1 | ### dot =====================================================================
|
2 |
|
3 | Tags
|
4 | ----
|
5 | scripting, JS, internal, treenode, general concept
|
6 |
|
7 | Parameters
|
8 | ----------
|
9 | a,b,...
|
10 |
|
11 | General description
|
12 | -------------------
|
13 |
|
14 | The inner (or dot) operator gives products of vectors,
|
15 | matrices, and tensors.
|
16 |
|
17 | Note that in Algebrite the elements of a vector/matrix
|
18 | can only be scalars. This allows, for example, matrix
|
19 | multiplication to be carried out with ordinary multiplication.
|
20 | As a consequence, block representations are not allowed.
|
21 |
|
22 | There is an awful lot of confusion among software packages
|
23 | about what dot and inner do.
|
24 |
|
25 | First off, the "dot" operator is different from the
|
26 | mathematical notion of dot product, which can be
|
27 | slightly confusing.
|
28 |
|
29 | The mathematical notion of dot product is here:
|
30 | http://mathworld.wolfram.com/DotProduct.html
|
31 |
|
32 | However, "dot" does that and a number of other things:
|
33 | in Algebrite, dot/inner does what Mathematica's Dot
|
34 | does, i.e.:
|
35 |
|
36 | scalar product of vectors:
|
37 |
|
38 | inner((a, b, c), (x, y, z))
|
39 | > a x + b y + c z
|
40 |
|
41 | products of matrices and vectors:
|
42 |
|
43 | inner(((a, b), (c,d)), (x, y))
|
44 | > (a x + b y,c x + d y)
|
45 |
|
46 | inner((x, y), ((a, b), (c,d)))
|
47 | > (a x + c y,b x + d y)
|
48 |
|
49 | inner((x, y), ((a, b), (c,d)), (r, s))
|
50 | > a r x + b s x + c r y + d s y
|
51 |
|
52 | matrix product:
|
53 |
|
54 | inner(((a,b),(c,d)),((r,s),(t,u)))
|
55 | > ((a r + b t,a s + b u),(c r + d t,c s + d u))
|
56 |
|
57 | The "dot/inner" operator is associative and
|
58 | distributive but not commutative.
|
59 |
|
60 | In Mathematica, Inner is a generalisation of Dot where
|
61 | the user can specify the multiplication and the addition
|
62 | operators.
|
63 | In Algebrite, however, inner and dot do exactly the same thing.
|
64 |
|
65 | https://reference.wolfram.com/language/ref/Dot.html
|
66 | https://reference.wolfram.com/language/ref/Inner.html
|
67 |
|
68 | http://uk.mathworks.com/help/matlab/ref/dot.html
|
69 | http://uk.mathworks.com/help/matlab/ref/mtimes.html
|
70 |
|
71 | ###
|
72 |
|
73 |
|
74 |
|
# Eval_inner: evaluates inner(a, b, ...), also written dot(...) or a·b.
#
# On entry p1 holds the unevaluated call, i.e. the list (inner a b ...);
# the result is left on the stack. Steps:
#   1. a call with more than two arguments is rewritten into nested
#      binary calls, inner(a,b,c) -> inner(a,inner(b,c)), and evaluated
#      again, so that the rest of the code only sees binary calls;
#   2. the nested calls are flattened into a list of factors;
#   3. factors that are literally the identity-matrix symbol are dropped;
#   4. adjacent factors x, y for which inv(x) - y evaluates to zero
#      (e.g. inv(a), a) are dropped in pairs, in one left-to-right pass;
#   5. the remaining factors are evaluated and combined left to right
#      with inner(). If no factor remains, the result is the
#      identity-matrix symbol.
# Stops with an error when fewer than two arguments are given.
Eval_inner = ->

  # collect the arguments a, b, ... of the call
  theArguments = []
  theArguments.push car(cdr(p1))
  secondArgument = car(cdr(cdr(p1)))
  if secondArgument == symbol(NIL)
    stop("inner needs at least two arguments")

  moretheArguments = cdr(cdr(p1))
  while moretheArguments != symbol(NIL)
    theArguments.push car(moretheArguments)
    moretheArguments = cdr(moretheArguments)

  # More than two arguments: build the right-nested chain
  # inner(a, inner(b, ... inner(y, z))) and evaluate that instead.
  if theArguments.length > 2
    push_symbol(INNER)
    push theArguments[theArguments.length-2]
    push theArguments[theArguments.length-1]
    list(3)
    for i in [2...theArguments.length]
      # stack: chain  ->  INNER, next argument to the left, chain
      push_symbol(INNER)
      swap()
      push theArguments[theArguments.length-i-1]
      swap()
      list(3)
    p1 = pop()
    Eval_inner()
    return

  # Flatten the (possibly nested) inner calls into their factors.
  # TODO we have to take a look at the whole sequence of operands
  # and make simplifications on that...
  operands = []
  get_innerprod_factors(p1, operands)

  # identity matrices don't change the product: drop them
  refinedOperands = []
  for i in [0...operands.length]
    if operands[i] != symbol(SYMBOL_IDENTITY_MATRIX)
      refinedOperands.push operands[i]
  operands = refinedOperands

  # Drop adjacent pairs x, y where y is the inverse of x, so that we can
  # answer that inv(a)·a is the identity matrix. Only symbolic inverses
  # are looked for: if either operand is a numeric tensor the pair is
  # kept and left to the actual multiplication.
  # This is a single left-to-right pass: after a dropped pair the scan
  # resumes right after it, so pairs that only become adjacent because
  # of a removal (as in a·b·inv(b)·inv(a)) are not removed by this pass.
  if operands.length > 1
    refinedOperands = []
    i = 0
    while i < operands.length
      if i + 1 < operands.length and !(isnumerictensor(operands[i]) or isnumerictensor(operands[i+1]))
        push operands[i]
        Eval()
        inv()
        push operands[i+1]
        Eval()
        subtract()
        difference = pop()
        if iszero(difference)
          # operands[i]·operands[i+1] is the identity: skip both
          i += 2
          continue
      refinedOperands.push operands[i]
      i++
    operands = refinedOperands

  # everything cancelled out: the product is the identity matrix
  if operands.length == 0
    push symbol(SYMBOL_IDENTITY_MATRIX)
    return

  # rebuild the call from the remaining factors ...
  push symbol(INNER)
  for i in [0...operands.length]
    push operands[i]
  list(operands.length + 1)
  p1 = pop()

  # ... then evaluate the factors and combine them left to right
  p1 = cdr(p1)
  push(car(p1))
  Eval()
  p1 = cdr(p1)
  while (iscons(p1))
    push(car(p1))
    Eval()
    inner()
    p1 = cdr(p1)
|
209 |
|
# inner(): binary inner (dot) product on the stack.
#
# Pops the right operand, then the left operand, and pushes their
# product. The operands are kept in the registers p1 (left) and p2
# (right), which are saved on entry and restored before returning.
# Cases, in the order they are tried:
#   - both operands are negative terms: both minus signs are removed;
#   - inner(inner(a, b), c) is re-associated to inner(a, inner(b, c));
#   - the identity-matrix symbol on either side is dropped;
#   - tensor·tensor is computed element-wise by inner_f();
#   - neither operand a numeric tensor and p1 - inv(p2) is zero:
#     the result is the identity-matrix symbol;
#   - a sum on either side is distributed (only when expanding is set);
#   - if either side is a number, ordinary scalar multiplication
#     (tensor_times_scalar / scalar_times_tensor / multiply) is used;
#   - anything else is left unevaluated as inner(p1, p2).
inner = ->
  save()
  p2 = pop()
  p1 = pop()

  # (-A)·(-B) = A·B, so drop both signs. More generally, for scalars a
  # and b, inner(a*M1, b*M2) = a*b*inner(M1, M2) (the scalars are the
  # only commuting part), but for now only the signs are handled.
  if isnegativeterm(p2) and isnegativeterm(p1)
    push p2
    negate()
    p2 = pop()
    push p1
    negate()
    p1 = pop()

  # inner is associative: use the canonical right-nested form
  #   inner(inner(a, b), c) -> inner(a, inner(b, c))
  # so that equal products can be recognised as equal.
  if isinnerordot(p1)
    push car(cdr(cdr(p1)))   # b
    push p2                  # c
    p1 = car(cdr(p1))        # a
    inner()
    p2 = pop()               # inner(b, c)

  # The identity matrix is neutral on either side. This is a plain
  # symbol comparison; Eval_testeq could be used for a deeper check.
  if p1 == symbol(SYMBOL_IDENTITY_MATRIX)
    push p2
    restore()
    return

  if p2 == symbol(SYMBOL_IDENTITY_MATRIX)
    push p1
    restore()
    return

  if istensor(p1) and istensor(p2)
    inner_f()
  else
    # Two consecutive operands that are (symbolically) the inverse of
    # one another multiply to the identity matrix.
    unless isnumerictensor(p1) or isnumerictensor(p2)
      push p1
      push p2
      inv()
      subtract()
      if iszero(pop())
        push_symbol(SYMBOL_IDENTITY_MATRIX)
        restore()
        return

    # In expanding mode, distribute over a sum on the left ...
    if expanding and isadd(p1)
      push zero
      p1 = cdr(p1)
      while iscons(p1)
        push car(p1)
        push p2
        inner()
        add()
        p1 = cdr(p1)
      restore()
      return

    # ... or over a sum on the right.
    if expanding and isadd(p2)
      push zero
      p2 = cdr(p2)
      while iscons(p2)
        push p1
        push car(p2)
        inner()
        add()
        p2 = cdr(p2)
      restore()
      return

    # Each side is now a number, a tensor or something unknown
    # (tensor·tensor was handled above).
    # Neither side a number (unknown·unknown, unknown·tensor,
    # tensor·unknown): ordinary multiplication cannot be used here, so
    # the product is kept unevaluated.
    unless isnum(p1) or isnum(p2)
      push_symbol(INNER)
      push p1
      push p2
      list(3)
      restore()
      return

    push p1
    push p2
    if istensor(p1) and isnum(p2)
      tensor_times_scalar()
    else if isnum(p1) and istensor(p2)
      scalar_times_tensor()
    else
      # number·number, number·unknown, unknown·number
      multiply()

  restore()
|
348 |
|
# inner_f(): inner product of the tensors in registers p1 (left) and
# p2 (right); the result is pushed on the stack.
#
# The last dimension of p1 is contracted with the first dimension of
# p2, so the result has rank p1.tensor.ndim + p2.tensor.ndim - 2. A
# rank-0 result is pushed as its single element rather than as a
# tensor. Stops with an error when the contracted dimensions differ
# or the result rank exceeds MAXDIM.
# Assigns p3; the caller is expected to save()/restore() around the
# call (inner() does).
inner_f = ->

  n = p1.tensor.dim[p1.tensor.ndim - 1]

  if (n != p2.tensor.dim[0])
    stop("inner: tensor dimension check")

  ndim = p1.tensor.ndim + p2.tensor.ndim - 2

  if (ndim > MAXDIM)
    stop("inner: rank of result exceeds maximum")

  a = p1.tensor.elem
  b = p2.tensor.elem

  #---------------------------------------------------------------------
  #
  # Both element arrays are viewed as matrices, using the row-major
  # index arithmetic below (last index varies fastest):
  #
  #   A is ak x n : ak = product of all dimensions of A but the last
  #   B is n x bk : bk = product of all dimensions of B but the first
  #
  # Example:
  #
  #   A[3][3][4]  B[4][4][3]
  #
  #   ak = 3 * 3 = 9    n = 4    bk = 4 * 3 = 12
  #
  #---------------------------------------------------------------------

  ak = 1
  for i in [0...(p1.tensor.ndim - 1)]
    ak *= p1.tensor.dim[i]

  bk = 1
  for i in [1...p2.tensor.ndim]
    bk *= p2.tensor.dim[i]

  # the accumulation below relies on alloc_tensor() initialising every
  # element to a value that add() accepts (presumably zero) -- confirm
  p3 = alloc_tensor(ak * bk)

  c = p3.tensor.elem

  #---------------------------------------------------------------------
  #
  # c[i][k] = sum over j of a[i][j] * b[j][k]
  #
  # Element (j, k) of the n x bk view of B is b[j * bk + k]: rows are
  # bk locations apart, and the first bk locations hold row 0.
  #
  # Example: n = 2, bk = 6 (B[2][2][3])
  #
  #   b111 b112 b113 b121 b122 b123   <- row j = 0, k = 0 ... 5
  #   b211 b212 b213 b221 b222 b223   <- row j = 1, k = 0 ... 5
  #
  # Loop order copied from GiNaC (http://www.ginac.de/): zero entries
  # of A are skipped, since they contribute nothing.
  #
  #---------------------------------------------------------------------
  for i in [0...ak]
    for j in [0...n]
      if (iszero(a[i * n + j]))
        continue
      for k in [0...bk]
        push(a[i * n + j])
        push(b[j * bk + k])
        multiply()
        push(c[i * bk + k])
        add()
        c[i * bk + k] = pop()

  if (ndim == 0)
    push(p3.tensor.elem[0])
  else
    # result dimensions: all of A's but the last, then all of B's but
    # the first
    p3.tensor.ndim = ndim
    for i in [0...(p1.tensor.ndim - 1)]
      p3.tensor.dim[i] = p1.tensor.dim[i]
    j = p1.tensor.ndim - 1
    for i in [0...(p2.tensor.ndim - 1)]
      p3.tensor.dim[j + i] = p2.tensor.dim[i + 1]
    push(p3)
|
446 |
|
# Example inputs whose inner-product factors get extracted:
# Algebrite.run('c·(b+a)ᵀ·inv((a+b)ᵀ)·d').toString();
# Algebrite.run('c*(b+a)ᵀ·inv((a+b)ᵀ)·d').toString();
# Algebrite.run('(c·(b+a)ᵀ)·(inv((a+b)ᵀ)·d)').toString();
#
# get_innerprod_factors(tree, factors_accumulator)
#
# Appends to the array factors_accumulator the factors of a chain of
# inner/dot calls, left to right: inner(a, inner(b, c)),
# inner(inner(a, b), c) and inner(a, b, c) all yield a, b, c.
# A tree that is not an inner/dot call is appended as a single factor
# (via add_factor_to_accumulator); a one-element list (x) is handled as x.
get_innerprod_factors = (tree, factors_accumulator) ->
  # console.log "extracting inner prod. factors from " + tree

  if !iscons(tree)
    add_factor_to_accumulator(tree, factors_accumulator)
    return

  # NOTE(review): this also unwraps a single-element list that is a
  # factor in its own right -- confirm that is intended.
  if cdr(tree) == symbol(NIL)
    get_innerprod_factors(car(tree), factors_accumulator)
    return

  if isinnerordot(tree)
    # Recurse on every operand, not just on the first one and the raw
    # tail: with three or more operands the tail (b c ...) is not itself
    # an inner call and used to be added as one bogus factor.
    tree = cdr(tree)
    while iscons(tree)
      get_innerprod_factors(car(tree), factors_accumulator)
      tree = cdr(tree)
    return

  add_factor_to_accumulator(tree, factors_accumulator)
|
468 |
|
# Appends tree to the array factors_accumulator, unless tree is
# symbol(NIL), which is ignored.
add_factor_to_accumulator = (tree, factors_accumulator) ->
  return if tree == symbol(NIL)
  # console.log ">> adding to factors_accumulator: " + tree
  factors_accumulator.push(tree)
|
473 |
|