1 | import { csPermute } from './csPermute'
|
2 | import { csPost } from './csPost'
|
3 | import { csEtree } from './csEtree'
|
4 | import { createCsAmd } from './csAmd'
|
5 | import { createCsCounts } from './csCounts'
|
6 | import { factory } from '../../../utils/factory'
|
7 |
|
// Name handed to factory() as the identifier of the function being created.
const name = 'csSqr'

// Names of the functions the factory callback receives (destructured there
// as add, multiply and transpose).
const dependencies = ['add', 'multiply', 'transpose']
|
14 |
|
15 | export const createCsSqr = factory(name, dependencies, ({ add, multiply, transpose }) => {
|
16 | const csAmd = createCsAmd({ add, multiply, transpose })
|
17 | const csCounts = createCsCounts({ transpose })
|
18 |
|
19 | |
20 |
|
21 |
|
22 |
|
23 |
|
24 |
|
25 |
|
26 |
|
27 |
|
28 |
|
29 |
|
30 |
|
31 | return function csSqr (order, a, qr) {
|
32 |
|
33 | const aptr = a._ptr
|
34 | const asize = a._size
|
35 |
|
36 | const n = asize[1]
|
37 |
|
38 | let k
|
39 |
|
40 | const s = {}
|
41 |
|
42 | s.q = csAmd(order, a)
|
43 |
|
44 | if (order && !s.q) { return null }
|
45 |
|
46 | if (qr) {
|
47 |
|
48 | const c = order ? csPermute(a, null, s.q, 0) : a
|
49 |
|
50 | s.parent = csEtree(c, 1)
|
51 |
|
52 | const post = csPost(s.parent, n)
|
53 |
|
54 | s.cp = csCounts(c, s.parent, post, 1)
|
55 |
|
56 | if (c && s.parent && s.cp && _vcount(c, s)) {
|
57 |
|
58 | for (s.unz = 0, k = 0; k < n; k++) { s.unz += s.cp[k] }
|
59 | }
|
60 | } else {
|
61 |
|
62 | s.unz = 4 * (aptr[n]) + n
|
63 | s.lnz = s.unz
|
64 | }
|
65 |
|
66 | return s
|
67 | }
|
68 |
|
69 | |
70 |
|
71 |
|
/**
 * Computes a row permutation and a nonzero count for matrix `a`, using the
 * parent array already stored in `s.parent`.
 *
 * Fields written on `s`:
 *  - `s.leftmost[i]`: smallest column index holding an entry in row i,
 *                     or -1 when row i has no entries
 *  - `s.pinv`:        row permutation; every column k gets exactly one row
 *                     (a fictitious row index >= m when no real row is
 *                     queued for k), remaining rows are numbered from n on
 *  - `s.m2`:          m plus the number of fictitious rows that were added
 *  - `s.lnz`:         one per column plus the number of rows still queued
 *                     for that column after its own row was removed
 *                     (NOTE(review): in CSparse's cs_vcount this is nnz(V),
 *                     the Householder vector count - confirm)
 *
 * @param {Object} a   Matrix exposing `_ptr`, `_index` and `_size`
 *                     (read as compressed-column storage: `_ptr[k]` ..
 *                     `_ptr[k + 1]` delimit the row indices of column k)
 * @param {Object} s   Analysis object; `s.parent[k]` is the parent of
 *                     column k or -1
 *
 * @return {boolean}   Always true
 */
function _vcount (a, s) {
  const aptr = a._ptr
  const aindex = a._index
  const asize = a._size

  const m = asize[0]
  const n = asize[1]

  s.pinv = []
  s.leftmost = []

  const parent = s.parent
  const pinv = s.pinv
  const leftmost = s.leftmost

  // Single workspace array holding four logical arrays, addressed by offset:
  //   next: m entries (per row), head / tail / nque: n entries each (per column).
  // head/next form one linked list of rows per column, tail is its last row
  // and nque its length.
  const w = []
  const next = 0
  const head = m
  const tail = m + n
  const nque = m + 2 * n

  // every column queue starts out empty
  for (let k = 0; k < n; k++) {
    w[head + k] = -1
    w[tail + k] = -1
    w[nque + k] = 0
  }

  for (let i = 0; i < m; i++) {
    leftmost[i] = -1
  }

  // columns are visited right to left, so the last write is the smallest k
  for (let k = n - 1; k >= 0; k--) {
    for (let p = aptr[k], p1 = aptr[k + 1]; p < p1; p++) {
      leftmost[aindex[p]] = k
    }
  }

  // rows are visited bottom-up and pushed at the head of the queue of their
  // leftmost column, so each queue ends up in ascending row order
  for (let i = m - 1; i >= 0; i--) {
    pinv[i] = -1
    const k = leftmost[i]
    // empty row: belongs to no queue
    if (k === -1) { continue }
    // the first row pushed stays the tail of the queue
    if (w[nque + k]++ === 0) { w[tail + k] = i }
    w[next + i] = w[head + k]
    w[head + k] = i
  }

  s.lnz = 0
  s.m2 = m

  for (let k = 0; k < n; k++) {
    // take the first row queued for column k
    let i = w[head + k]
    s.lnz++
    // nothing queued: invent a fictitious row beyond the real ones
    if (i < 0) { i = s.m2++ }
    pinv[i] = k
    // `nque` is an offset into `w`, not an array: the queue length lives at
    // w[nque + k]. Decrementing `nque[k]` would write a property on a number
    // primitive, which throws in strict mode and never updates the length.
    if (--w[nque + k] <= 0) { continue }
    s.lnz += w[nque + k]
    // hand the rows left in queue k over to the queue of k's parent
    const pa = parent[k]
    if (pa !== -1) {
      if (w[nque + pa] === 0) { w[tail + pa] = w[tail + k] }
      w[next + w[tail + k]] = w[head + pa]
      w[head + pa] = w[next + i]
      w[nque + pa] += w[nque + k]
    }
  }

  // rows not tied to any column are numbered after the n column rows
  let position = n
  for (let i = 0; i < m; i++) {
    if (pinv[i] < 0) { pinv[i] = position++ }
  }

  return true
}
|
155 | })
|