1 | 'use strict'
|
2 |
|
3 | const clone = require('../../utils/object').clone
|
4 | const format = require('../../utils/string').format
|
5 |
|
6 | function factory (type, config, load, typed) {
|
7 | const matrix = load(require('../../type/matrix/function/matrix'))
|
8 | const add = load(require('../arithmetic/add'))
|
9 |
|
10 | |
11 |
|
12 |
|
13 |
|
14 |
|
15 |
|
16 |
|
17 |
|
18 |
|
19 |
|
20 |
|
21 |
|
22 |
|
23 |
|
24 |
|
25 |
|
26 |
|
27 |
|
28 |
|
29 |
|
30 |
|
31 |
|
32 |
|
33 |
|
34 |
|
35 |
|
36 |
|
/**
 * Signature table of the typed `trace` function.
 *
 * - Array:        wrapped with `matrix` and handed to `_denseTrace`
 * - SparseMatrix: handled by `_sparseTrace`
 * - DenseMatrix:  handled by `_denseTrace`
 * - any:          every other value is passed to `clone` and returned
 */
const signatures = {
  Array: function _arrayTrace (x) {
    const wrapped = matrix(x)
    return _denseTrace(wrapped)
  },

  SparseMatrix: _sparseTrace,

  DenseMatrix: _denseTrace,

  any: clone
}

const trace = typed('trace', signatures)
|
50 |
|
/**
 * Compute the trace of a dense matrix.
 *
 * The matrix is read through its `_size` and `_data` fields:
 * - one dimension: only a single-element vector is accepted, and a clone of
 *   that element is returned;
 * - two dimensions: the matrix must be square, and the elements `data[i][i]`
 *   are accumulated with `add`, starting from the number 0 (so a 0 x 0 matrix
 *   yields 0).
 *
 * @param {DenseMatrix} m   Matrix exposing `_size` and `_data`
 * @return {*}              The trace of `m`
 * @throws {RangeError}     When `m` is not square, or has a number of
 *                          dimensions other than one or two
 */
function _denseTrace (m) {
  const size = m._size
  const data = m._data

  if (size.length === 1) {
    // a vector only counts as square when it holds exactly one element
    if (size[0] !== 1) {
      throw new RangeError('Matrix must be square (size: ' + format(size) + ')')
    }
    return clone(data[0])
  }

  if (size.length !== 2) {
    throw new RangeError('Matrix must be two dimensional (size: ' + format(size) + ')')
  }

  const rows = size[0]
  const cols = size[1]
  if (rows !== cols) {
    throw new RangeError('Matrix must be square (size: ' + format(size) + ')')
  }

  // walk the main diagonal; `add` is used instead of `+` so that the element
  // type decides how the addition is carried out
  let sum = 0
  for (let i = 0; i < rows; i++) {
    sum = add(sum, data[i][i])
  }
  return sum
}
|
83 |
|
/**
 * Compute the trace of a sparse matrix.
 *
 * The matrix is read through its `_values`, `_index`, `_ptr` and `_size`
 * fields. For column `j`, the stored entries are those with
 * `ptr[j] <= k < ptr[j + 1]`; entry `k` sits on row `index[k]` and holds
 * `values[k]`. Stored diagonal entries are accumulated with `add`, starting
 * from the number 0.
 *
 * @param {SparseMatrix} m   Matrix exposing `_values`, `_index`, `_ptr`, `_size`
 * @return {*}               The trace of `m` (0 when nothing is stored)
 * @throws {RangeError}      When `m` is not square
 */
function _sparseTrace (m) {
  const values = m._values
  const index = m._index
  const ptr = m._ptr
  const size = m._size

  const rows = size[0]
  const columns = size[1]

  if (rows !== columns) {
    throw new RangeError('Matrix must be square (size: ' + format(size) + ')')
  }

  let sum = 0

  // nothing stored at all: skip the column walk
  if (values.length === 0) {
    return sum
  }

  for (let j = 0; j < columns; j++) {
    const end = ptr[j + 1]

    for (let k = ptr[j]; k < end; k++) {
      // Keep scanning until the diagonal entry shows up. Bailing out on the
      // first row index greater than j would only be valid if the row indices
      // of a column were guaranteed to be in ascending order; with an unsorted
      // column it would drop a stored diagonal value.
      if (index[k] === j) {
        sum = add(sum, values[k])
        // at most one diagonal entry is expected per column
        break
      }
    }
  }

  return sum
}
|
127 |
|
128 | trace.toTex = { 1: `\\mathrm{tr}\\left(\${args[0]}\\right)` }
|
129 |
|
130 | return trace
|
131 | }
|
132 |
|
133 | exports.name = 'trace'
|
134 | exports.factory = factory
|