1 | import { factory } from '../../../utils/factory'
|
2 | import { DimensionError } from '../../../error/DimensionError'
|
3 |
|
const name = 'algorithm07'
const dependencies = ['typed', 'DenseMatrix']

export const createAlgorithm07 = factory(name, dependencies, ({ typed, DenseMatrix }) => {
  /**
   * Iterates over SparseMatrix A and SparseMatrix B items (zero and nonzero)
   * and invokes the callback function f(Aij, Bij) for every position, so the
   * callback is invoked rows * columns times.
   *
   *   C(i,j) = f(Aij, Bij)
   *
   * Positions with no stored entry in an operand are passed to the callback as
   * `zero` (the number 0, or 0 converted to the shared datatype, see below).
   *
   * @param {Matrix}   a         The SparseMatrix instance (A)
   * @param {Matrix}   b         The SparseMatrix instance (B)
   * @param {Function} callback  The f(Aij, Bij) operation to invoke
   *
   * @return {Matrix}            DenseMatrix (C) of size [rows, columns]
   *
   * @throws {DimensionError}    When A and B have a different number of dimensions
   * @throws {RangeError}        When the row or column counts of A and B differ
   */
  return function algorithm07 (a, b, callback) {
    const asize = a._size
    const adt = a._datatype
    const bsize = b._size
    const bdt = b._datatype

    // validate dimensions
    if (asize.length !== bsize.length) { throw new DimensionError(asize.length, bsize.length) }

    // check rows & columns
    if (asize[0] !== bsize[0] || asize[1] !== bsize[1]) { throw new RangeError('Dimension mismatch. Matrix A (' + asize + ') must match Matrix B (' + bsize + ')') }

    const rows = asize[0]
    const columns = asize[1]

    // When both operands declare the same string datatype, the result takes
    // that datatype, the implicit zero is converted to it, and the callback is
    // resolved for the [dt, dt] signature via typed.find. Otherwise the raw
    // callback is used with a plain numeric 0.
    let dt
    let zero = 0
    let cf = callback
    if (typeof adt === 'string' && adt === bdt) {
      dt = adt
      zero = typed.convert(0, dt)
      cf = typed.find(callback, [dt, dt])
    }

    // result rows, filled column by column below
    const cdata = []
    for (let i = 0; i < rows; i++) { cdata[i] = [] }

    // workspaces: x* holds the values scattered from the current column,
    // w* holds the mark saying whether x*[i] belongs to the current column
    const xa = []
    const xb = []
    const wa = []
    const wb = []

    for (let j = 0; j < columns; j++) {
      // A distinct mark per column makes entries left over from earlier
      // columns stale without having to clear the workspaces.
      const mark = j + 1
      _scatter(a, j, wa, xa, mark)
      _scatter(b, j, wb, xb, mark)
      for (let i = 0; i < rows; i++) {
        const va = wa[i] === mark ? xa[i] : zero
        const vb = wb[i] === mark ? xb[i] : zero
        cdata[i][j] = cf(va, vb)
      }
    }

    // Construct the result only once cdata is complete. The data then already
    // matches [rows, columns] at construction time, and correctness no longer
    // depends on the constructor keeping a live reference to cdata.
    return new DenseMatrix({
      data: cdata,
      size: [rows, columns],
      datatype: dt
    })
  }

  /**
   * Scatters the stored entries of column j of sparse matrix m into a workspace.
   * For every stored entry k in ptr[j] .. ptr[j + 1] - 1, with row i = index[k],
   * it sets w[i] = mark and x[i] = values[k]. Rows without a stored entry are
   * left untouched.
   *
   * @param {Matrix}   m     Sparse matrix whose `_values`, `_index` and `_ptr` are read
   * @param {number}   j     Column to scatter
   * @param {Array}    w     Mark workspace, mutated
   * @param {Array}    x     Value workspace, mutated
   * @param {number}   mark  Mark identifying column j
   */
  function _scatter (m, j, w, x, mark) {
    const values = m._values
    const index = m._index
    const ptr = m._ptr
    for (let k = ptr[j], k1 = ptr[j + 1]; k < k1; k++) {
      const i = index[k]
      w[i] = mark
      x[i] = values[k]
    }
  }
})
|