UNPKG

3.31 kBJavaScriptView Raw
1'use strict'
2
3const clone = require('../../utils/object').clone
4const format = require('../../utils/string').format
5
6function factory (type, config, load, typed) {
7 const matrix = load(require('../../type/matrix/function/matrix'))
8 const add = load(require('../arithmetic/add'))
9
10 /**
11 * Calculate the trace of a matrix: the sum of the elements on the main
12 * diagonal of a square matrix.
13 *
14 * Syntax:
15 *
16 * math.trace(x)
17 *
18 * Examples:
19 *
20 * math.trace([[1, 2], [3, 4]]) // returns 5
21 *
22 * const A = [
23 * [1, 2, 3],
24 * [-1, 2, 3],
25 * [2, 0, 3]
26 * ]
27 * math.trace(A) // returns 6
28 *
29 * See also:
30 *
31 * diag
32 *
33 * @param {Array | Matrix} x A matrix
34 *
35 * @return {number} The trace of `x`
36 */
37 const trace = typed('trace', {
38
39 'Array': function _arrayTrace (x) {
40 // use dense matrix implementation
41 return _denseTrace(matrix(x))
42 },
43
44 'SparseMatrix': _sparseTrace,
45
46 'DenseMatrix': _denseTrace,
47
48 'any': clone
49 })
50
51 function _denseTrace (m) {
52 // matrix size & data
53 const size = m._size
54 const data = m._data
55
56 // process dimensions
57 switch (size.length) {
58 case 1:
59 // vector
60 if (size[0] === 1) {
61 // return data[0]
62 return clone(data[0])
63 }
64 throw new RangeError('Matrix must be square (size: ' + format(size) + ')')
65 case 2:
66 // two dimensional
67 const rows = size[0]
68 const cols = size[1]
69 if (rows === cols) {
70 // calulate sum
71 let sum = 0
72 // loop diagonal
73 for (let i = 0; i < rows; i++) { sum = add(sum, data[i][i]) }
74 // return trace
75 return sum
76 }
77 throw new RangeError('Matrix must be square (size: ' + format(size) + ')')
78 default:
79 // multi dimensional
80 throw new RangeError('Matrix must be two dimensional (size: ' + format(size) + ')')
81 }
82 }
83
84 function _sparseTrace (m) {
85 // matrix arrays
86 const values = m._values
87 const index = m._index
88 const ptr = m._ptr
89 const size = m._size
90 // check dimensions
91 const rows = size[0]
92 const columns = size[1]
93 // matrix must be square
94 if (rows === columns) {
95 // calulate sum
96 let sum = 0
97 // check we have data (avoid looping columns)
98 if (values.length > 0) {
99 // loop columns
100 for (let j = 0; j < columns; j++) {
101 // k0 <= k < k1 where k0 = _ptr[j] && k1 = _ptr[j+1]
102 const k0 = ptr[j]
103 const k1 = ptr[j + 1]
104 // loop k within [k0, k1[
105 for (let k = k0; k < k1; k++) {
106 // row index
107 const i = index[k]
108 // check row
109 if (i === j) {
110 // accumulate value
111 sum = add(sum, values[k])
112 // exit loop
113 break
114 }
115 if (i > j) {
116 // exit loop, no value on the diagonal for column j
117 break
118 }
119 }
120 }
121 }
122 // return trace
123 return sum
124 }
125 throw new RangeError('Matrix must be square (size: ' + format(size) + ')')
126 }
127
128 trace.toTex = { 1: `\\mathrm{tr}\\left(\${args[0]}\\right)` }
129
130 return trace
131}
132
// Public module interface: the function name and the factory that builds it.
Object.assign(exports, { name: 'trace', factory })