1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 |
|
12 |
|
13 |
|
14 |
|
15 |
|
16 |
|
17 | import { convertToTensor } from '../tensor_util_env';
|
18 | import * as util from '../util';
|
19 | import { matMul } from './mat_mul';
|
20 | import { op } from './operation';
|
21 | import { reshape } from './reshape';
|
22 |
|
23 |
|
24 |
|
25 |
|
26 |
|
27 |
|
28 |
|
29 |
|
30 |
|
31 |
|
32 |
|
33 |
|
34 |
|
35 |
|
36 |
|
37 |
|
38 |
|
/**
 * Computes the dot product of two vectors and/or matrices via `matMul`.
 *
 * Supported rank combinations and the shape of the result:
 *   - rank 1 · rank 1 -> scalar (reshaped to `[]`)
 *   - rank 1 · rank 2 -> rank 1 (matMul output flattened to `[size]`)
 *   - rank 2 · rank 1 -> rank 1 (matMul output flattened to `[size]`)
 *   - rank 2 · rank 2 -> the `matMul` output, returned as-is
 *
 * @param t1 First operand; converted with `convertToTensor`. Must be rank 1 or 2.
 * @param t2 Second operand; converted with `convertToTensor`. Must be rank 1 or 2.
 * @returns The product tensor, shaped as listed above.
 * @throws Via `util.assert` when either input is not rank 1 or 2, or when the
 *     contracted ("inner") dimensions of the two inputs differ.
 */
function dot_(t1, t2) {
  const $t1 = convertToTensor(t1, 't1', 'dot');
  const $t2 = convertToTensor(t2, 't2', 'dot');

  util.assert(
      ($t1.rank === 1 || $t1.rank === 2) && ($t2.rank === 1 || $t2.rank === 2),
      () => `Error in dot: inputs must all be rank 1 or 2, but got ranks ` +
          `${$t1.rank} and ${$t2.rank}.`);

  // The contracted dimension: a vector contributes its whole length, a matrix
  // contributes its columns when on the left and its rows when on the right.
  const t1Inner = ($t1.rank === 1 ? $t1.size : $t1.shape[1]);
  const t2Inner = ($t2.rank === 1 ? $t2.size : $t2.shape[0]);

  util.assert(
      t1Inner === t2Inner,
      () => `Error in dot: inner dimensions of inputs must match, but got ` +
          `${t1Inner} and ${t2Inner}.`);

  if ($t1.rank === 1 && $t2.rank === 1) {
    // Row vector times column vector gives a 1x1 matrix; collapse it to a scalar.
    const t12D = reshape($t1, [1, -1]);
    const t22D = reshape($t2, [-1, 1]);
    const t1t2 = matMul(t12D, t22D);
    return reshape(t1t2, []);
  } else if ($t1.rank === 1 && $t2.rank === 2) {
    // Treat t1 as a single-row matrix. $t2 is already rank 2, so it goes to
    // matMul directly (reshaping it to its own shape would be a no-op op).
    const t12D = reshape($t1, [1, -1]);
    const t1t2 = matMul(t12D, $t2);
    return reshape(t1t2, [t1t2.size]);
  } else if ($t1.rank === 2 && $t2.rank === 1) {
    // Treat t2 as a single-column matrix, then flatten the column result.
    const t22D = reshape($t2, [-1, 1]);
    const t1t2 = matMul($t1, t22D);
    return reshape(t1t2, [t1t2.size]);
  } else {
    // Both inputs are rank 2: this is an ordinary matrix product.
    return matMul($t1, $t2);
  }
}
|
/**
 * Public `dot` op: the `dot_` implementation wrapped by `op`, which is handed
 * an object whose single key is `dot_` (presumably used by `op` to name the
 * operation — see `./operation`).
 */
export const dot = op({ dot_: dot_ });
|
72 |
|
\ | No newline at end of file |