UNPKG

4.94 kB · JavaScript · View Raw
1import { factory } from '../../../utils/factory'
2import { DimensionError } from '../../../error/DimensionError'
3
const name = 'algorithm04'
const dependencies = ['typed', 'equalScalar']

export const createAlgorithm04 = /* #__PURE__ */ factory(name, dependencies, ({ typed, equalScalar }) => {
  /**
   * Iterates over the stored (nonzero) items of SparseMatrix A and SparseMatrix B and merges them into C.
   * The callback f(Aij, Bij) is invoked only where BOTH A(i,j) and B(i,j) are stored, i.e. at most
   * MIN(NNZA, NNZB) times. Entries stored in only one operand are copied to C unchanged.
   *
   *
   *          ┌  f(Aij, Bij)  ; A(i,j) !== 0 && B(i,j) !== 0
   * C(i,j) = ┤  A(i,j)       ; A(i,j) !== 0
   *          └  B(i,j)       ; B(i,j) !== 0
   *
   * Results of f that compare equal to zero (using equalScalar) are not stored in C.
   * When either operand has no values array (pattern matrix), C is a pattern matrix holding the
   * union of both nonzero patterns and the callback is never invoked.
   *
   * Row indices within a column of C are not sorted: rows taken from A come first, in A's order,
   * followed by rows present only in B, in B's order.
   *
   * @param {Matrix}   a                 The SparseMatrix instance (A)
   * @param {Matrix}   b                 The SparseMatrix instance (B)
   * @param {Function} callback          The f(Aij,Bij) operation to invoke
   *
   * @return {Matrix}                    SparseMatrix (C)
   *
   * @throws {DimensionError} when A and B have a different number of dimensions
   * @throws {RangeError} when the row or column counts of A and B differ
   *
   * see https://github.com/josdejong/mathjs/pull/346#issuecomment-97620294
   */
  return function algorithm04 (a, b, callback) {
    // sparse matrix arrays (A)
    const avalues = a._values
    const aindex = a._index
    const aptr = a._ptr
    const asize = a._size
    const adt = a._datatype
    // sparse matrix arrays (B)
    const bvalues = b._values
    const bindex = b._index
    const bptr = b._ptr
    const bsize = b._size
    const bdt = b._datatype

    // validate dimensions
    if (asize.length !== bsize.length) { throw new DimensionError(asize.length, bsize.length) }

    // check rows & columns
    if (asize[0] !== bsize[0] || asize[1] !== bsize[1]) { throw new RangeError('Dimension mismatch. Matrix A (' + asize + ') must match Matrix B (' + bsize + ')') }

    const rows = asize[0]
    const columns = asize[1]

    // datatype of C; only set when both operands declare the same string datatype
    let dt
    // zero test, zero value and callback; replaced by (dt, dt)-specific versions below when dt is known
    let eq = equalScalar
    let zero = 0
    let cf = callback

    if (typeof adt === 'string' && adt === bdt) {
      dt = adt
      // find the signatures that match (dt, dt)
      eq = typed.find(equalScalar, [dt, dt])
      // convert 0 to the same datatype so the zero test compares like with like
      zero = typed.convert(0, dt)
      cf = typed.find(callback, [dt, dt])
    }

    // values are only computed when both operands carry values (neither is a pattern matrix)
    const hasValues = Boolean(avalues && bvalues)

    // result arrays
    const cvalues = hasValues ? [] : undefined
    const cindex = []
    const cptr = []
    // the arrays are filled AFTER C is created, so this relies on createSparseMatrix keeping
    // references to them (not copies); every update below therefore mutates them in place
    const c = a.createSparseMatrix({
      values: cvalues,
      index: cindex,
      ptr: cptr,
      size: [rows, columns],
      datatype: dt
    })

    // workspaces: row -> value of A(i,j) (possibly already combined with B(i,j)) / B(i,j) for the current column
    const xa = hasValues ? [] : undefined
    const xb = hasValues ? [] : undefined
    // marks: wa[i] === mark when row i of the current column is kept from A (or A combined with B),
    //        wb[i] === mark when row i of the current column is present only in B
    const wa = []
    const wb = []

    for (let j = 0; j < columns; j++) {
      const start = cindex.length
      cptr[j] = start
      // mark is unique per column, so the mark arrays never need to be cleared
      const mark = j + 1

      // scatter A(:,j)
      const aEnd = aptr[j + 1]
      for (let k = aptr[j]; k < aEnd; k++) {
        const i = aindex[k]
        cindex.push(i)
        wa[i] = mark
        if (xa) { xa[i] = avalues[k] }
      }

      // merge B(:,j)
      const bEnd = bptr[j + 1]
      for (let k = bptr[j]; k < bEnd; k++) {
        const i = bindex[k]
        if (wa[i] === mark) {
          // both A(i,j) and B(i,j) are stored: combine them (only when values are tracked)
          if (xa) {
            const v = cf(xa[i], bvalues[k])
            if (!eq(v, zero)) {
              xa[i] = v
            } else {
              // result is zero: drop the mark so the row is removed during compaction
              wa[i] = null
            }
          }
        } else {
          // only B(i,j) is stored
          cindex.push(i)
          wb[i] = mark
          if (xb) { xb[i] = bvalues[k] }
        }
      }

      if (hasValues) {
        // Compact column j in place: keep rows that are still marked, write their values and drop
        // the rest. A single write cursor keeps this linear in the column's size, whereas removing
        // rows one at a time with splice shifts the tail of the column on every removal.
        let write = start
        for (let k = start; k < cindex.length; k++) {
          const i = cindex[k]
          if (wa[i] === mark) {
            // A(i,j) != 0 (alone, or combined with a nonzero result of f)
            cindex[write] = i
            cvalues[write] = xa[i]
            write++
          } else if (wb[i] === mark) {
            // B(i,j) != 0 only
            cindex[write] = i
            cvalues[write] = xb[i]
            write++
          }
        }
        // truncate in place (C holds this very array)
        cindex.length = write
      }
    }
    // close the last column
    cptr[columns] = cindex.length

    return c
  }
})