1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 |
|
12 |
|
13 |
|
14 |
|
15 |
|
16 |
|
17 |
|
18 |
|
19 |
|
20 |
|
21 |
|
22 |
|
// Baseline size for the shared (inner) dimension in the batched vector/matrix
// tests of this file, which all use MATMUL_SHARED_DIM_THRESHOLD + 1.
// NOTE(review): presumably mirrors a backend-side threshold — confirm.
23 | export const MATMUL_SHARED_DIM_THRESHOLD = 1000;
|
24 | import * as tf from '../index';
|
25 | import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
|
26 | import { expectArraysClose, expectArraysEqual } from '../test_util';
|
27 | describeWithFlags('matmul', ALL_ENVS, () => {
|
// Multiplies a 2x3 matrix by a 3x2 matrix and checks the 2x2 product.
it('A x B', async () => {
  const lhs = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
  const rhs = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
  const product = tf.matMul(lhs, rhs);

  expect(product.shape).toEqual([2, 2]);
  const actual = await product.data();
  expectArraysClose(actual, [0, 8, -3, 20]);
});
|
// Checks an 8x4 by 4x8 product against its precomputed 8x8 result.
it('[8,4]x[4,8]', async () => {
  const lhsValues = [
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    17, 18, 19, 20, 21, 22, 23, 24, 1, 2, 3, 4, 5, 6, 7, 8
  ];
  const rhsValues = [
    0, 1, -3, 2, 1, -1, 0, 5, 6, 7, 8, 0, -2, -2, 1, 9,
    11, 10, 0, 1, -3, 2, 1, -1, 1, 2, 3, 4, 5, 6, 7, 8
  ];
  const expected = [
    49, 53, 25, 21, 8, 25, 33, 52, 121, 133, 57, 49, 12,
    45, 69, 136, 193, 213, 89, 77, 16, 65, 105, 220, 265, 293,
    121, 105, 20, 85, 141, 304, 337, 373, 153, 133, 24, 105, 177,
    388, 409, 453, 185, 161, 28, 125, 213, 472, 49, 53, 25, 21,
    8, 25, 33, 52, 121, 133, 57, 49, 12, 45, 69, 136
  ];

  const lhs = tf.tensor2d(lhsValues, [8, 4]);
  const rhs = tf.tensor2d(rhsValues, [4, 8]);
  const product = tf.matMul(lhs, rhs);

  // Data is read before the shape check, as in the original ordering.
  const actual = await product.data();
  expect(product.shape).toEqual([8, 8]);
  expectArraysClose(actual, expected);
});
|
55 | it('broadcast with unequal batch dims', async () => {
|
56 | const a = tf.tensor3d([
|
57 | 2, 1, 3, 2, 1, 1, 1, 5, 6, 7, 8, 1,
|
58 | 2, 2, 1, 9, 11, 10, 1, 1, 3, 2, 1, 1
|
59 | ], [4, 3, 2]);
|
60 | const b = tf.tensor3d([1, 0.5], [1, 2, 1]);
|
61 | const c = tf.matMul(a, b);
|
62 | expect(c.shape).toEqual([4, 3, 1]);
|
63 | expectArraysClose(await c.data(), [2.5, 4, 1.5, 3.5, 9.5, 8.5, 3, 5.5, 16, 1.5, 4, 1.5]);
|
64 | });
|
65 | it('broadcast with unequal ranks', async () => {
|
66 | const a = tf.tensor5d([
|
67 | 2, 1, 3, 2, 1, 1, 1, 5, 6, 7, 8, 1,
|
68 | 2, 2, 1, 9, 11, 10, 1, 1, 3, 2, 1, 1
|
69 | ], [1, 2, 2, 3, 2]);
|
70 | const b = tf.tensor2d([1, 0.5], [2, 1]);
|
71 | const c = tf.matMul(a, b);
|
72 | expect(c.shape).toEqual([1, 2, 2, 3, 1]);
|
73 | expectArraysClose(await c.data(), [2.5, 4, 1.5, 3.5, 9.5, 8.5, 3, 5.5, 16, 1.5, 4, 1.5]);
|
74 | });
|
// Feeds a matMul result into an element-wise mul and checks the 2x3 output.
it('matmul followed by mul', async () => {
  const lhs = tf.tensor2d([1, 2, 3, 4], [2, 2]);
  const rhs = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
  const product = tf.matMul(lhs, rhs);

  const factors = tf.tensor2d([0, 1, 0.5, 0, 0.25, 2], [2, 3]);
  const scaled = tf.mul(product, factors);

  const actual = await scaled.data();
  expect(scaled.shape).toEqual([2, 3]);
  expectArraysClose(actual, [0, 12, 7.5, 0, 6.5, 66]);
});
|
// Mixed-dtype operands: the result dtype is expected to be the wider of the two.
it('upcasts when dtypes dont match', async () => {
  const lhsValues = [1, 2, 3, 4, 5, 6];
  const rhsValues = [0, 1, -3, 2, 2, 1];

  // float32 x int32 is expected to produce float32.
  const floatTimesInt = tf.matMul(
      tf.tensor(lhsValues, [2, 3], 'float32'),
      tf.tensor(rhsValues, [3, 2], 'int32'));
  expect(floatTimesInt.shape).toEqual([2, 2]);
  expect(floatTimesInt.dtype).toBe('float32');
  expectArraysClose(await floatTimesInt.data(), [0, 8, -3, 20]);

  // int32 x bool is expected to produce int32.
  const intTimesBool = tf.matMul(
      tf.tensor(lhsValues, [2, 3], 'int32'),
      tf.tensor(rhsValues, [3, 2], 'bool'));
  expect(intTimesBool.shape).toEqual([2, 2]);
  expect(intTimesBool.dtype).toBe('int32');
  expectArraysClose(await intTimesBool.data(), [5, 6, 11, 15]);
});
|
97 | it('A x B^t', async () => {
|
98 | const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
|
99 | const b = tf.tensor2d([1, 0, 2, 4, 3, 0], [2, 3]);
|
100 | const transposeA = false;
|
101 | const transposeB = true;
|
102 | const c = tf.matMul(a, b, transposeA, transposeB);
|
103 | const expected = [7, 10, 16, 31];
|
104 | expectArraysClose(await c.data(), expected);
|
105 | });
|
106 | it('A^t x B', async () => {
|
107 | const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
|
108 | const b = tf.tensor2d([1, 0, 2, 4, 3, 0], [2, 3]);
|
109 | const transposeA = true;
|
110 | const transposeB = false;
|
111 | const c = tf.matMul(a, b, transposeA, transposeB);
|
112 | const expected = [17, 12, 2, 22, 15, 4, 27, 18, 6];
|
113 | expectArraysClose(await c.data(), expected);
|
114 | });
|
115 | it('A^t x B^t', async () => {
|
116 | const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
|
117 | const b = tf.tensor2d([1, 0, 2, 4, 3, 0], [2, 3]);
|
118 | const transposeA = true;
|
119 | const transposeB = true;
|
120 | const c = tf.matMul(a, b, transposeA, transposeB);
|
121 | const expected = [11, 13, 14, 20];
|
122 | expectArraysClose(await c.data(), expected);
|
123 | });
|
124 | it('A x B^t shapes do not match', () => {
|
125 | const a = tf.zeros([2, 3]);
|
126 | const b = tf.zeros([3, 2]);
|
127 | const f = () => {
|
128 | const transposeA = false;
|
129 | const transposeB = true;
|
130 | tf.matMul(a, b, transposeA, transposeB);
|
131 | };
|
132 | expect(f).toThrowError();
|
133 | });
|
134 | it('A^t x B shapes do not match', () => {
|
135 | const a = tf.zeros([2, 3]);
|
136 | const b = tf.zeros([3, 2]);
|
137 | const f = () => {
|
138 | const transposeA = true;
|
139 | const transposeB = false;
|
140 | tf.matMul(a, b, transposeA, transposeB);
|
141 | };
|
142 | expect(f).toThrowError();
|
143 | });
|
144 | it('A^t x B^t shapes do not match', () => {
|
145 | const a = tf.zeros([3, 2]);
|
146 | const b = tf.zeros([3, 2]);
|
147 | const f = () => {
|
148 | const transposeA = true;
|
149 | const transposeB = true;
|
150 | tf.matMul(a, b, transposeA, transposeB);
|
151 | };
|
152 | expect(f).toThrowError();
|
153 | });
|
154 | it('matmul throws when inner dimensions dont match', () => {
|
155 | const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
|
156 | const b = tf.tensor2d([0, 1, -3, 2, 2, 1, 2, 2], [4, 2]);
|
157 | expect(() => tf.matMul(a, b)).toThrowError();
|
158 | });
|
159 | it('matmul throws when passed non matrices', () => {
|
160 |
|
161 | const a = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 2]);
|
162 | const b = tf.tensor2d([0, 1, -3, 2, 2, 1, 2, 2], [4, 2]);
|
163 | expect(() => tf.matMul(a, b)).toThrowError();
|
164 | expect(() => tf.matMul(b, a)).toThrowError();
|
165 | });
|
166 | it('matmul throws when passed a vector', () => {
|
167 |
|
168 | const v = tf.tensor1d([2, 3]);
|
169 | const matrix = tf.tensor2d([1, 2, 3, 4], [2, 2]);
|
170 | expect(() => tf.matMul(matrix, v)).toThrowError();
|
171 | });
|
172 | it('Vector times matrix', async () => {
|
173 | const v = tf.tensor1d([2, 3]);
|
174 | const matrix = tf.tensor2d([1, 2, 3, 4], [2, 2]);
|
175 | const result = tf.dot(v, matrix);
|
176 | const expected = [11, 16];
|
177 | expectArraysClose(await result.data(), expected);
|
178 | });
|
179 | it('Vector times matrix with implicit reshape', async () => {
|
180 | const v = tf.tensor1d([2, 3]);
|
181 | const matrix = tf.tensor2d([1, 2, 3, 4], [2, 2]);
|
182 | const result = tf.dot(v, matrix);
|
183 | const expected = [11, 16];
|
184 | expectArraysClose(await result.data(), expected);
|
185 | });
|
186 | it('Matrix times vector', async () => {
|
187 | const matrix = tf.tensor2d([1, 2, 3, 4], [2, 2]);
|
188 | const v = tf.tensor1d([2, 3]);
|
189 | const result = tf.dot(matrix, v);
|
190 | const expected = [8, 18];
|
191 | expectArraysClose(await result.data(), expected);
|
192 | });
|
// Batched [1, K] x [K, 1] products with K = MATMUL_SHARED_DIM_THRESHOLD + 1.
// Only element 10 of the shared buffer is non-zero (value 2), so the first
// batch entry yields 2 * 2 = 4 and the remaining entries yield 0.
it('batched matmul with the matrices being vectors', async () => {
  const batchSize = 3;
  const innerDim = MATMUL_SHARED_DIM_THRESHOLD + 1;

  const buffer = new Float32Array(batchSize * innerDim);
  buffer[10] = 2;

  const rowVectors = tf.tensor(buffer, [batchSize, 1, innerDim]);
  const colVectors = tf.tensor(buffer, [batchSize, innerDim, 1]);
  const product = tf.matMul(rowVectors, colVectors);

  expect(product.shape).toEqual([batchSize, 1, 1]);
  const actual = await product.data();
  expectArraysClose(actual, [4, 0, 0]);
});
|
// Runs two batched products of identical shape back to back, disposing the
// first result before computing the second, and checks both outputs.
it('batched matmul called twice so memory of output is reused', async () => {
  const batch = 3;
  const n = 2;
  const shape = [batch, n, n];

  // Builds a zero-filled buffer with the given [index, value] pairs set.
  const sparseBuffer = (entries) => {
    const buffer = new Float32Array(batch * n * n);
    for (const [index, value] of entries) {
      buffer[index] = value;
    }
    return buffer;
  };

  const firstValues = sparseBuffer([[0, 2], [4, 3], [8, 4]]);
  const firstA = tf.tensor(firstValues, shape);
  const firstB = tf.tensor(firstValues, shape);
  const firstResult = tf.matMul(firstA, firstB);
  expect(firstResult.shape).toEqual([batch, n, n]);
  expectArraysClose(
      await firstResult.data(), [4, 0, 0, 0, 9, 0, 0, 0, 16, 0, 0, 0]);

  // Release the first output before the second multiplication runs.
  firstResult.dispose();

  const secondValues = sparseBuffer([[3, 2], [7, 3], [11, 4]]);
  const secondA = tf.tensor(secondValues, shape);
  const secondB = tf.tensor(secondValues, shape);
  const secondResult = tf.matMul(secondA, secondB);
  expect(secondResult.shape).toEqual([batch, n, n]);
  expectArraysClose(
      await secondResult.data(), [0, 0, 0, 4, 0, 0, 0, 9, 0, 0, 0, 16]);
});
|
// Table-driven registration of the eight batched vector/matrix tests.
//
// These tests previously existed as eight copy-pasted bodies that differed
// only in operand shapes, transpose flags and the expected output. They are
// registered here under the same names and in the same order.
//
// Operand kinds:
//   SPARSE - a Float32Array of batch * sharedDim zeros with element 10 set to
//            2, wrapped in a tensor of the given shape.
//   ONES   - tf.ones of the given shape.
{
  const batch = 3;
  const sharedDim = MATMUL_SHARED_DIM_THRESHOLD + 1;
  const SPARSE = 'sparse';
  const ONES = 'ones';

  // Builds one operand of the requested kind and shape.
  const makeOperand = (kind, shape) => {
    if (kind === ONES) {
      return tf.ones(shape);
    }
    const values = new Float32Array(batch * sharedDim);
    values[10] = 2;
    return tf.tensor(values, shape);
  };

  const cases = [
    {
      name: 'batched matmul with the matrices being vectors transposedA',
      a: [SPARSE, [batch, sharedDim, 1]],
      b: [SPARSE, [batch, sharedDim, 1]],
      transposeA: true,
      transposeB: false,
      outShape: [batch, 1, 1],
      expected: [4, 0, 0],
    },
    {
      name: 'batched matmul with the matrices being vectors transposedB',
      a: [SPARSE, [batch, 1, sharedDim]],
      b: [SPARSE, [batch, 1, sharedDim]],
      transposeA: false,
      transposeB: true,
      outShape: [batch, 1, 1],
      expected: [4, 0, 0],
    },
    {
      name: 'batched matmul with matrix x vector',
      a: [ONES, [batch, 2, sharedDim]],
      b: [SPARSE, [batch, sharedDim, 1]],
      transposeA: false,
      transposeB: false,
      outShape: [batch, 2, 1],
      expected: [2, 2, 0, 0, 0, 0],
    },
    {
      name: 'batched matmul with matrix x vector transposedA',
      a: [ONES, [batch, sharedDim, 2]],
      b: [SPARSE, [batch, sharedDim, 1]],
      transposeA: true,
      transposeB: false,
      outShape: [batch, 2, 1],
      expected: [2, 2, 0, 0, 0, 0],
    },
    {
      name: 'batched matmul with matrix x vector transposedB',
      a: [ONES, [batch, 2, sharedDim]],
      b: [SPARSE, [batch, 1, sharedDim]],
      transposeA: false,
      transposeB: true,
      outShape: [batch, 2, 1],
      expected: [2, 2, 0, 0, 0, 0],
    },
    {
      name: 'batched matmul with vector x matrix',
      a: [SPARSE, [batch, 1, sharedDim]],
      b: [ONES, [batch, sharedDim, 2]],
      transposeA: false,
      transposeB: false,
      outShape: [batch, 1, 2],
      expected: [2, 2, 0, 0, 0, 0],
    },
    {
      name: 'batched matmul with vector x matrix transposedA',
      a: [SPARSE, [batch, sharedDim, 1]],
      b: [ONES, [batch, sharedDim, 2]],
      transposeA: true,
      transposeB: false,
      outShape: [batch, 1, 2],
      expected: [2, 2, 0, 0, 0, 0],
    },
    {
      name: 'batched matmul with vector x matrix transposedB',
      a: [SPARSE, [batch, 1, sharedDim]],
      b: [ONES, [batch, 2, sharedDim]],
      transposeA: false,
      transposeB: true,
      outShape: [batch, 1, 2],
      expected: [2, 2, 0, 0, 0, 0],
    },
  ];

  for (const testCase of cases) {
    it(testCase.name, async () => {
      // Operands are created inside the test so each case owns its tensors.
      const a = makeOperand(...testCase.a);
      const b = makeOperand(...testCase.b);
      const result =
          tf.matMul(a, b, testCase.transposeA, testCase.transposeB);
      expect(result.shape).toEqual(testCase.outShape);
      expectArraysClose(await result.data(), testCase.expected);
    });
  }
}
|
329 | it('Matrix * vector propagates NaNs', async () => {
|
330 | const matrix = tf.tensor2d([1, 2, 3, 4], [2, 2]);
|
331 | const v = tf.tensor1d([2, NaN]);
|
332 | const result = tf.dot(matrix, v);
|
333 | const expected = [NaN, NaN];
|
334 | expectArraysClose(await result.data(), expected);
|
335 | });
|
336 | it('matrix times vector throws when not passed a matrix', () => {
|
337 | const v = tf.tensor1d([2, 3]);
|
338 |
|
339 | const matrix = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]);
|
340 | expect(() => tf.dot(matrix, v)).toThrowError();
|
341 | });
|
342 | it('Dot product', async () => {
|
343 | const v1 = tf.tensor1d([2, 3]);
|
344 | const v2 = tf.tensor1d([2, 1]);
|
345 | const result = tf.dot(v1, v2);
|
346 | expectArraysClose(await result.data(), [7]);
|
347 | });
|
348 | it('Dot product propagates NaNs', async () => {
|
349 | const v1 = tf.tensor1d([2, NaN]);
|
350 | const v2 = tf.tensor1d([2, 1]);
|
351 | const result = tf.dot(v1, v2);
|
352 | expectArraysEqual(await result.data(), [NaN]);
|
353 | });
|
354 | it('Dot product throws when vectors are different size', () => {
|
355 | const v1 = tf.tensor1d([2, 3, 3]);
|
356 | const v2 = tf.tensor1d([2, 1]);
|
357 | expect(() => tf.dot(v1, v2)).toThrowError();
|
358 | expect(() => tf.dot(v2, v1)).toThrowError();
|
359 | });
|
360 | it('Outer product', async () => {
|
361 | const v1 = tf.tensor1d([2, 3]);
|
362 | const v2 = tf.tensor1d([2, 1]);
|
363 | const result = tf.outerProduct(v1, v2);
|
364 | const expected = [4, 2, 6, 3];
|
365 | expect(result.shape).toEqual([2, 2]);
|
366 | expectArraysClose(await result.data(), expected);
|
367 | });
|
368 | it('outer product accepts a tensor-like object', async () => {
|
369 | const v1 = [2, 3];
|
370 | const v2 = [2, 1];
|
371 | const result = tf.outerProduct(v1, v2);
|
372 | const expected = [4, 2, 6, 3];
|
373 | expect(result.shape).toEqual([2, 2]);
|
374 | expectArraysClose(await result.data(), expected);
|
375 | });
|
// Checks matMul gradients without transposes against sums computed from the
// input buffers: da[i][j] = sum_k dy[i][k] * b[j][k] and
// db[i][j] = sum_k a[k][i] * dy[k][j].
it('gradients: A * B', async () => {
  const aT = tf.tensor2d([1, 2, 3, 10, 20, 30], [2, 3]);
  const bT = tf.tensor2d([2, 3, 4, 1, 2, 3], [3, 2]);
  const dyT = tf.tensor2d([1, 10, 20, 30], [2, 2]);

  const gradFn = tf.grads((x, y) => tf.matMul(x, y, false, false));
  const [da, db] = gradFn([aT, bT], dyT);

  // Row-major list of sum_k term(i, j, k) for i < rows and j < cols.
  const contract = (rows, cols, depth, term) => {
    const flat = [];
    for (let i = 0; i < rows; i++) {
      for (let j = 0; j < cols; j++) {
        let acc = 0;
        for (let k = 0; k < depth; k++) {
          acc += term(i, j, k);
        }
        flat.push(acc);
      }
    }
    return flat;
  };

  expect(da.shape).toEqual(aT.shape);
  const a = await aT.buffer();
  const dy = await dyT.buffer();
  const b = await bT.buffer();
  const daActual = await da.data();
  const daExpected =
      contract(2, 3, 2, (i, j, k) => dy.get(i, k) * b.get(j, k));
  // The da comparison uses an explicit 1e-1 tolerance.
  expectArraysClose(daActual, daExpected, 1e-1);

  expect(db.shape).toEqual(b.shape);
  const dbActual = await db.data();
  const dbExpected =
      contract(3, 2, 2, (i, j, k) => a.get(k, i) * dy.get(k, j));
  expectArraysClose(dbActual, dbExpected);
});
|
// Gradients still have the input shapes when inputs and output are cloned.
it('gradient with clones', () => {
  const lhs = tf.tensor2d([1, 2, 3, 10, 20, 30], [2, 3]);
  const rhs = tf.tensor2d([2, 3, 4, 1, 2, 3], [3, 2]);

  const clonedMatMul = (x, y) => tf.matMul(x.clone(), y.clone()).clone();
  const gradFn = tf.grads(clonedMatMul);
  const [dLhs, dRhs] = gradFn([lhs, rhs]);

  expect(dLhs.shape).toEqual(lhs.shape);
  expect(dRhs.shape).toEqual(rhs.shape);
});
|
// Checks matMul gradients with transposeB against sums computed from the
// input buffers: da[i][j] = sum_k dy[i][k] * b[k][j] and
// db[i][j] = sum_k dy[k][i] * a[k][j].
it('gradients: a * bT', async () => {
  const aT = tf.tensor2d([1, 2, 3, 10, 20, 30], [3, 2]);
  const bT = tf.tensor2d([2, 3, 4, 1, 2, 3], [3, 2]);
  const dyT = tf.tensor2d([1, 10, 20, 30, 40, 50, 60, 70, 80], [3, 3]);

  const gradFn = tf.grads((x, y) => tf.matMul(x, y, false, true));
  const [da, db] = gradFn([aT, bT], dyT);

  // Row-major list of sum_k term(i, j, k) for i < rows and j < cols.
  const contract = (rows, cols, depth, term) => {
    const flat = [];
    for (let i = 0; i < rows; i++) {
      for (let j = 0; j < cols; j++) {
        let acc = 0;
        for (let k = 0; k < depth; k++) {
          acc += term(i, j, k);
        }
        flat.push(acc);
      }
    }
    return flat;
  };

  expect(da.shape).toEqual(aT.shape);
  const a = await aT.buffer();
  const dy = await dyT.buffer();
  const b = await bT.buffer();
  const daActual = await da.data();
  const daExpected =
      contract(3, 2, 3, (i, j, k) => dy.get(i, k) * b.get(k, j));
  expectArraysClose(daActual, daExpected);

  expect(db.shape).toEqual(b.shape);
  const dbActual = await db.data();
  const dbExpected =
      contract(3, 2, 3, (i, j, k) => dy.get(k, i) * a.get(k, j));
  expectArraysClose(dbActual, dbExpected);
});
|
// Checks matMul gradients with transposeA against sums computed from the
// input buffers: da[i][j] = sum_k dy[j][k] * b[i][k] and
// db[i][j] = sum_k dy[k][j] * a[i][k].
it('gradients: aT * b', async () => {
  const aT = tf.tensor2d([1, 2, 3, 10, 20, 30], [3, 2]);
  const bT = tf.tensor2d([2, 3, 4, 1, 2, 3], [3, 2]);
  const dyT = tf.tensor2d([1, 10, 20, 30], [2, 2]);

  const gradFn = tf.grads((x, y) => tf.matMul(x, y, true, false));
  const [da, db] = gradFn([aT, bT], dyT);

  // Row-major list of sum_k term(i, j, k) for i < rows and j < cols.
  const contract = (rows, cols, depth, term) => {
    const flat = [];
    for (let i = 0; i < rows; i++) {
      for (let j = 0; j < cols; j++) {
        let acc = 0;
        for (let k = 0; k < depth; k++) {
          acc += term(i, j, k);
        }
        flat.push(acc);
      }
    }
    return flat;
  };

  expect(da.shape).toEqual(aT.shape);
  const a = await aT.buffer();
  const dy = await dyT.buffer();
  const b = await bT.buffer();
  const daActual = await da.data();
  const daExpected =
      contract(3, 2, 2, (i, j, k) => dy.get(j, k) * b.get(i, k));
  expectArraysClose(daActual, daExpected);

  expect(db.shape).toEqual(b.shape);
  const dbActual = await db.data();
  const dbExpected =
      contract(3, 2, 2, (i, j, k) => dy.get(k, j) * a.get(i, k));
  expectArraysClose(dbActual, dbExpected);
});
|
// Checks matMul gradients with both operands transposed against sums
// computed from the input buffers: da[i][j] = sum_k dy[j][k] * b[k][i] and
// db[i][j] = sum_k dy[k][i] * a[j][k].
it('gradients: aT * bT', async () => {
  const aT = tf.tensor2d([1, 2, 3, 10, 20, 30], [3, 2]);
  const bT = tf.tensor2d([2, 3, 4, 1, 2, 3], [2, 3]);
  const dyT = tf.tensor2d([1, 10, 20, 30], [2, 2]);

  const gradFn = tf.grads((x, y) => tf.matMul(x, y, true, true));
  const [da, db] = gradFn([aT, bT], dyT);

  // Row-major list of sum_k term(i, j, k) for i < rows and j < cols.
  const contract = (rows, cols, depth, term) => {
    const flat = [];
    for (let i = 0; i < rows; i++) {
      for (let j = 0; j < cols; j++) {
        let acc = 0;
        for (let k = 0; k < depth; k++) {
          acc += term(i, j, k);
        }
        flat.push(acc);
      }
    }
    return flat;
  };

  expect(da.shape).toEqual(aT.shape);
  const a = await aT.buffer();
  const dy = await dyT.buffer();
  const b = await bT.buffer();
  const daActual = await da.data();
  const daExpected =
      contract(3, 2, 2, (i, j, k) => dy.get(j, k) * b.get(k, i));
  expectArraysClose(daActual, daExpected);

  expect(db.shape).toEqual(b.shape);
  const dbActual = await db.data();
  const dbExpected =
      contract(2, 3, 2, (i, j, k) => dy.get(k, i) * a.get(j, k));
  expectArraysClose(dbActual, dbExpected);
});
|
524 | it('throws when passed a as a non-tensor', () => {
|
525 | expect(() => tf.matMul({}, tf.tensor2d([2], [1, 1])))
|
526 | .toThrowError(/Argument 'a' passed to 'matMul' must be a Tensor/);
|
527 | });
|
528 | it('throws when passed b as a non-tensor', () => {
|
529 | expect(() => tf.matMul(tf.tensor2d([2], [1, 1]), {}))
|
530 | .toThrowError(/Argument 'b' passed to 'matMul' must be a Tensor/);
|
531 | });
|
532 | it('accepts a tensor-like object', async () => {
|
533 | const a = [[1, 2, 3], [4, 5, 6]];
|
534 | const b = [[0, 1], [-3, 2], [2, 1]];
|
535 | const c = tf.matMul(a, b);
|
536 | expect(c.shape).toEqual([2, 2]);
|
537 | expectArraysClose(await c.data(), [0, 8, -3, 20]);
|
538 | });
|
539 | it('accepts a tensor-like object chained', async () => {
|
540 | const a = tf.tensor2d([[1, 2, 3], [4, 5, 6]], [2, 3]);
|
541 | const b = [[0, 1], [-3, 2], [2, 1]];
|
542 | const c = a.matMul(b);
|
543 | expect(c.shape).toEqual([2, 2]);
|
544 | expectArraysClose(await c.data(), [0, 8, -3, 20]);
|
545 | });
|
546 | it('a * b where a has zero in its shape', async () => {
|
547 | const a = tf.tensor2d([], [0, 3]);
|
548 | const b = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
|
549 | const c = tf.matMul(a, b);
|
550 | expect(c.shape).toEqual([0, 2]);
|
551 | expect(c.rank).toBe(2);
|
552 | expect(c.size).toBe(0);
|
553 | expectArraysClose(await c.data(), []);
|
554 | });
|
555 | it('(a * b) * c where a has zero in its shape, so a*b does also', async () => {
|
556 | const a = tf.tensor2d([], [0, 3]);
|
557 | const b = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
|
558 | const ab = tf.matMul(a, b);
|
559 | expect(ab.shape).toEqual([0, 2]);
|
560 | expectArraysClose(await ab.data(), []);
|
561 | const c = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
|
562 | const res = tf.matMul(ab, c);
|
563 | expect(res.shape).toEqual([0, 3]);
|
564 | expectArraysClose(await res.data(), []);
|
565 | });
|
// String-valued operands are rejected with a message naming argument 'a'.
it('throws error for string tensor', () => {
  const multiplyStrings = () => tf.matMul([['a']], [['b']]);
  expect(multiplyStrings)
      .toThrowError(/Argument 'a' passed to 'matMul' must be numeric tensor/);
});
|
570 | });
|
571 | describeWithFlags('matmulBatch', ALL_ENVS, () => {
|
// Batched product: five [2, 3] matrices times five [3, 2] matrices.
it('A x B', async () => {
  const lhsValues = [
    -5, -5, -6, 8, -2, -8, 4, -7, -6, -9, -1, 3, 7, -2, 5,
    -6, 3, 8, 7, -8, 1, 4, -4, 6, 4, -4, -9, -5, 2, -2
  ];
  const rhsValues = [
    -8, -4, -1, 0, -7, 0, 3, 3, 6, 2, -1, 8, -4, 9, -6,
    5, 8, 9, -9, 7, 0, -1, -1, -10, -7, 3, 4, 6, 3, -4
  ];
  const expected = [
    87, 20, -6, -32, -24, -50, -36, -5, 24, 98,
    70, 33, -64, 47, -42, -28, -71, 24, 37, 5
  ];

  const lhs = tf.tensor3d(lhsValues, [5, 2, 3]);
  const rhs = tf.tensor3d(rhsValues, [5, 3, 2]);
  const product = tf.matMul(lhs, rhs);

  expect(product.shape).toEqual([5, 2, 2]);
  expectArraysClose(await product.data(), expected);
});
|
// 4-D batched product: [4, 5, 2, 3] times [4, 5, 3, 2].
// The output shape is asserted as well as the data, because
// expectArraysClose only compares the flat values (80 = 4 * 5 * 2 * 2).
it('A x B in 4D', async () => {
  const a = tf.tensor4d([
    -2, 3, 5, -5, 3, 9, -3, -5, 1, 1, -9, 9, -6, 6, -8,
    -7, -1, 3, 9, -7, -7, 2, 10, -6, -8, -6, 9, -6, 4, -1,
    9, -6, 10, 8, -9, 5, -8, -7, 0, 2, -5, -1, -9, -4, 3,
    -2, 6, -4, 7, 1, -5, -4, 9, -8, -6, -8, 4, -1, 4, 3,
    -7, 8, -7, 5, -3, -2, -4, 9, 2, -1, 1, -10, -3, 5, -4,
    6, -8, -8, 9, -3, -5, 10, 3, -3, -3, 9, 3, -3, 2, -8,
    10, 1, 9, -2, -2, -3, -4, 6, -10, -1, 8, -8, 7, 3, -2,
    3, 6, -2, -2, -4, 1, -5, -4, 0, 5, 1, 9, -8, -2, -1
  ], [4, 5, 2, 3]);
  const b = tf.tensor4d([
    -4, -3, -2, -6, 6, -1, -4, -1, 7, -4, 8, -9, -9, 0, -1,
    -4, -6, -7, -3, -4, -7, 6, -8, 1, -2, 1, -1, -3, 8, -5,
    9, -2, 5, 9, -2, 2, -5, -5, -8, -1, -2, -3, -2, -10, 6,
    -3, 0, 1, 6, 7, 1, 2, -4, -5, 2, -5, -7, 9, 3, -6,
    6, 4, -4, 6, 10, -3, -2, 8, 10, -8, 10, -1, -9, -7, -8,
    -3, 1, 1, -2, -9, -7, -6, -1, 0, 7, -9, -7, -5, 0, -4,
    -4, -7, 2, 4, 6, 6, -4, -6, -8, 3, -8, -9, 6, 9, -4,
    1, -1, 0, 8, 9, 0, -5, 3, -1, 5, 0, -10, 7, -2, 6
  ], [4, 5, 3, 2]);
  const transposeA = false;
  const transposeB = false;
  const c = tf.matMul(a, b, transposeA, transposeB);
  expect(c.shape).toEqual([4, 5, 2, 2]);
  expectArraysClose(await c.data(), [
    32, -17, 68, -12, -15, 14, 5, -46, 96, 32, 46, -17, 78, -85,
    -28, 46, 94, -35, 0, -13, 31, -52, 17, -87, 96, 47, 32, -2,
    -6, 105, 40, -2, 63, 76, 17, 30, 56, -66, -21, 23, -144, 41,
    22, 8, 118, -106, -88, -6, -17, 2, 2, -26, 8, -63, -38, -108,
    -84, -30, -35, 49, 16, -12, -14, -12, 48, 132, 4, 102, 32, 66,
    -4, 33, -13, 1, -40, -25, -3, 61, -18, -20
  ]);
});
|
// Batched product with the second operand transposed: both inputs are
// [5, 2, 3] and the result is [5, 2, 2].
it('A x B^t', async () => {
  const lhsValues = [
    -5, -5, -6, 8, -2, -8, 4, -7, -6, -9, -1, 3, 7, -2, 5,
    -6, 3, 8, 7, -8, 1, 4, -4, 6, 4, -4, -9, -5, 2, -2
  ];
  const rhsValues = [
    -8, -4, -1, 0, -7, 0, 3, 3, 6, 2, -1, 8, -4, 9, -6,
    5, 8, 9, -9, 7, 0, -1, -1, -10, -7, 3, 4, 6, 3, -4
  ];
  const expected = [
    66, 35, -48, 14, -45, -33, -12, 7, -76, 64,
    3, 66, -119, -9, -64, -60, -76, 48, 33, -16
  ];

  const lhs = tf.tensor3d(lhsValues, [5, 2, 3]);
  const rhs = tf.tensor3d(rhsValues, [5, 2, 3]);
  const product = tf.matMul(lhs, rhs, false, true);

  expect(product.shape).toEqual([5, 2, 2]);
  expectArraysClose(await product.data(), expected);
});
|
// Batched product with the first operand transposed: a is [5, 2, 3]
// (treated as [5, 3, 2]) and b is [5, 2, 3].
// The output shape is asserted as well as the data, because
// expectArraysClose only compares the flat values (45 = 5 * 3 * 3).
it('A^t x B', async () => {
  const a = tf.tensor3d([
    -5, -5, -6, 8, -2, -8, 4, -7, -6, -9, -1, 3, 7, -2, 5,
    -6, 3, 8, 7, -8, 1, 4, -4, 6, 4, -4, -9, -5, 2, -2
  ], [5, 2, 3]);
  const b = tf.tensor3d([
    -8, -4, -1, 0, -7, 0, 3, 3, 6, 2, -1, 8, -4, 9, -6,
    5, 8, 9, -9, 7, 0, -1, -1, -10, -7, 3, 4, 6, 3, -4
  ], [5, 2, 3]);
  const transposeA = true;
  const transposeB = false;
  const c = tf.matMul(a, b, transposeA, transposeB);
  expect(c.shape).toEqual([5, 3, 3]);
  expectArraysClose(await c.data(), [
    40, -36, 5, 40, 34, 5, 48, 80, 6, -6, 21, -48, -23, -20, -50,
    -12, -21, -12, -58, 15, -96, 23, 6, 39, 20, 109, 42, -67, 45, -40,
    76, -52, 40, -15, 1, -60, -58, -3, 36, 40, -6, -24, 51, -33, -28
  ]);
});
|
// 4-D batched product with the first operand transposed: both inputs are
// [4, 5, 2, 3], with a treated as [4, 5, 3, 2].
// The output shape is asserted as well as the data, because
// expectArraysClose only compares the flat values (180 = 4 * 5 * 3 * 3).
it('A^t x B in 4D', async () => {
  const a = tf.tensor4d([
    -2, 3, 5, -5, 3, 9, -3, -5, 1, 1, -9, 9, -6, 6, -8,
    -7, -1, 3, 9, -7, -7, 2, 10, -6, -8, -6, 9, -6, 4, -1,
    9, -6, 10, 8, -9, 5, -8, -7, 0, 2, -5, -1, -9, -4, 3,
    -2, 6, -4, 7, 1, -5, -4, 9, -8, -6, -8, 4, -1, 4, 3,
    -7, 8, -7, 5, -3, -2, -4, 9, 2, -1, 1, -10, -3, 5, -4,
    6, -8, -8, 9, -3, -5, 10, 3, -3, -3, 9, 3, -3, 2, -8,
    10, 1, 9, -2, -2, -3, -4, 6, -10, -1, 8, -8, 7, 3, -2,
    3, 6, -2, -2, -4, 1, -5, -4, 0, 5, 1, 9, -8, -2, -1
  ], [4, 5, 2, 3]);
  const b = tf.tensor4d([
    -4, -3, -2, -6, 6, -1, -4, -1, 7, -4, 8, -9, -9, 0, -1,
    -4, -6, -7, -3, -4, -7, 6, -8, 1, -2, 1, -1, -3, 8, -5,
    9, -2, 5, 9, -2, 2, -5, -5, -8, -1, -2, -3, -2, -10, 6,
    -3, 0, 1, 6, 7, 1, 2, -4, -5, 2, -5, -7, 9, 3, -6,
    6, 4, -4, 6, 10, -3, -2, 8, 10, -8, 10, -1, -9, -7, -8,
    -3, 1, 1, -2, -9, -7, -6, -1, 0, 7, -9, -7, -5, 0, -4,
    -4, -7, 2, 4, 6, 6, -4, -6, -8, 3, -8, -9, 6, 9, -4,
    1, -1, 0, 8, 9, 0, -5, 3, -1, 5, 0, -10, 7, -2, 6
  ], [4, 5, 2, 3]);
  const transposeA = true;
  const transposeB = false;
  const c = tf.matMul(a, b, transposeA, transposeB);
  expect(c.shape).toEqual([4, 5, 3, 3]);
  expectArraysClose(await c.data(), [
    38, -24, 9, -30, 9, -9, -74, 39, -19, 8, 11, -30, 56, -67,
    46, -40, 71, -74, 82, 42, 55, -50, 6, 1, 60, -18, -13, -15,
    -52, -61, 81, -52, 59, -15, 76, 43, 34, -56, 38, 0, 26, -14,
    -15, 1, -4, 153, -34, 61, -135, 30, -48, 135, -30, 60, 38, 36,
    58, 40, 45, 71, 1, 2, 3, 24, 90, -56, -10, 40, -18, 6,
    -30, 14, 34, 65, 27, 24, -29, -44, -46, -3, 35, -21, 27, 48,
    20, 52, 32, 35, -11, -46, -12, 22, 13, 30, 2, -23, -54, -48,
    34, 16, -42, -39, -26, 82, 89, 76, -84, 30, 9, 27, 30, -21,
    -43, -48, 60, 20, 24, -78, -91, -63, -12, 24, 21, 28, 48, 35,
    -6, 27, 33, 53, -81, -71, 61, -27, 11, -48, -82, 8, -12, -19,
    -10, -48, -81, 0, 13, 32, 41, 0, -100, -120, 16, 124, 152, 45,
    60, -28, 24, 21, -12, -14, -16, 8, 9, -33, 5, -12, -48, 4,
    8, 9, 0, -31, 16, -98, -9, 4, -22, 38, 2, -96
  ]);
});
|
// Batched product with both operands transposed: a is [5, 3, 2] (treated as
// [5, 2, 3]) and b is [5, 2, 3] (treated as [5, 3, 2]).
// The output shape is asserted as well as the data, because
// expectArraysClose only compares the flat values (20 = 5 * 2 * 2).
it('A^t x B^t', async () => {
  const a = tf.tensor3d([
    -5, -5, -6, 8, -2, -8, 4, -7, -6, -9, -1, 3, 7, -2, 5,
    -6, 3, 8, 7, -8, 1, 4, -4, 6, 4, -4, -9, -5, 2, -2
  ], [5, 3, 2]);
  const b = tf.tensor3d([
    -8, -4, -1, 0, -7, 0, 3, 3, 6, 2, -1, 8, -4, 9, -6,
    5, 8, 9, -9, 7, 0, -1, -1, -10, -7, 3, 4, 6, 3, -4
  ], [5, 2, 3]);
  const transposeA = true;
  const transposeB = true;
  const c = tf.matMul(a, b, transposeA, transposeB);
  expect(c.shape).toEqual([5, 2, 2]);
  expectArraysClose(await c.data(), [
    66, 42, 16, -56, -12, 6, -30, 19, -1, 102,
    -94, 14, -56, 32, 100, -56, -47, -11, 5, -31
  ]);
});
|
// Operands with batch sizes 4 and 5 must be rejected.
it('batch dimensions do not match', () => {
  const lhsValues = [
    -5, -5, -6, 8, -2, -8, 4, -7, -6, -9, -1, 3,
    7, -2, 5, -6, 3, 8, 7, -8, 1, 4, -4, 6
  ];
  const rhsValues = [
    -8, -4, -1, 0, -7, 0, 3, 3, 6, 2, -1, 8, -4, 9, -6,
    5, 8, 9, -9, 7, 0, -1, -1, -10, -7, 3, 4, 6, 3, -4
  ];
  const lhs = tf.tensor3d(lhsValues, [4, 3, 2]);
  const rhs = tf.tensor3d(rhsValues, [5, 2, 3]);

  expect(() => {
    tf.matMul(lhs, rhs, false, false);
  }).toThrowError();
});
|
// Pushes dy back through a batched matMul (no transposes) with tf.grads and
// checks the shape and values of the gradient for each operand.
it('gradients: A x B', async () => {
  const lhs = tf.tensor3d(
    [
      -5, -5, -6, 8, -2, -8, 4, -7, -6, -9, -1, 3, 7, -2, 5,
      -6, 3, 8, 7, -8, 1, 4, -4, 6, 4, -4, -9, -5, 2, -2
    ],
    [5, 2, 3]
  );
  const rhs = tf.tensor3d(
    [
      -8, -4, -1, 0, -7, 0, 3, 3, 6, 2, -1, 8, -4, 9, -6,
      5, 8, 9, -9, 7, 0, -1, -1, -10, -7, 3, 4, 6, 3, -4
    ],
    [5, 3, 2]
  );
  const upstream = tf.tensor3d(
    [8, 2, -3, -2, -8, 4, 5, 7, 4, -4, -4, 5, 8, 10, 1, 0, 6, 6, -4, 7],
    [5, 2, 2]
  );

  const gradFn = tf.grads((x, y) => tf.matMul(x, y, false, false));
  const [gradLhs, gradRhs] = gradFn([lhs, rhs], upstream);

  // Gradient with respect to the left operand.
  expect(gradLhs.shape).toEqual(lhs.shape);
  expectArraysClose(await gradLhs.data(), [
    -72, -8, -56, 32, 3, 21, -12, -40, 40, 36, 44, 51, -52, -44, -4,
    61, 49, 13, -2, -10, -108, -9, 0, -1, -24, 60, -6, 49, 26, -40
  ]);

  // Gradient with respect to the right operand.
  expect(gradRhs.shape).toEqual(rhs.shape);
  expectArraysClose(await gradRhs.data(), [
    -64, -26, -34, -6, -24, 4, -77, -47, 51, -35, 63, -3, 52, -58, -20,
    23, -12, 20, 60, 70, -68, -80, 14, 10, 44, -11, -32, -10, -46, -68
  ]);
});
|
// Rank-4 variant of the plain matMul gradient check: operands are
// [4, 5, 2, 3] and [4, 5, 3, 2], the incoming gradient is [4, 5, 2, 2].
it('4d gradients: A x B', async () => {
  const lhsValues = [
    -2, 3, 5, -5, 3, 9, -3, -5, 1, 1, -9, 9, -6, 6, -8,
    -7, -1, 3, 9, -7, -7, 2, 10, -6, -8, -6, 9, -6, 4, -1,
    9, -6, 10, 8, -9, 5, -8, -7, 0, 2, -5, -1, -9, -4, 3,
    -2, 6, -4, 7, 1, -5, -4, 9, -8, -6, -8, 4, -1, 4, 3,
    -7, 8, -7, 5, -3, -2, -4, 9, 2, -1, 1, -10, -3, 5, -4,
    6, -8, -8, 9, -3, -5, 10, 3, -3, -3, 9, 3, -3, 2, -8,
    10, 1, 9, -2, -2, -3, -4, 6, -10, -1, 8, -8, 7, 3, -2,
    3, 6, -2, -2, -4, 1, -5, -4, 0, 5, 1, 9, -8, -2, -1
  ];
  const rhsValues = [
    -4, -3, -2, -6, 6, -1, -4, -1, 7, -4, 8, -9, -9, 0, -1,
    -4, -6, -7, -3, -4, -7, 6, -8, 1, -2, 1, -1, -3, 8, -5,
    9, -2, 5, 9, -2, 2, -5, -5, -8, -1, -2, -3, -2, -10, 6,
    -3, 0, 1, 6, 7, 1, 2, -4, -5, 2, -5, -7, 9, 3, -6,
    6, 4, -4, 6, 10, -3, -2, 8, 10, -8, 10, -1, -9, -7, -8,
    -3, 1, 1, -2, -9, -7, -6, -1, 0, 7, -9, -7, -5, 0, -4,
    -4, -7, 2, 4, 6, 6, -4, -6, -8, 3, -8, -9, 6, 9, -4,
    1, -1, 0, 8, 9, 0, -5, 3, -1, 5, 0, -10, 7, -2, 6
  ];
  const upstreamValues = [
    8, -7, 0, -9, -5, -5, 0, 3, 7, -4, 6, -8, -8, 0, -1, -8,
    -9, -7, -4, -9, 2, 3, 5, 8, -5, -7, 3, -10, -5, -9, -5, 1,
    7, 1, -9, -10, 8, 5, 0, 8, -6, 4, 0, -5, 8, -7, -2, 1,
    -8, 9, 9, -7, 1, 7, -2, 5, -2, 9, 1, -5, 7, 5, -7, -6,
    6, 7, -8, 7, 4, -5, 4, -5, 3, -4, -5, 4, -6, 3, -8, 10
  ];
  const lhs = tf.tensor4d(lhsValues, [4, 5, 2, 3]);
  const rhs = tf.tensor4d(rhsValues, [4, 5, 3, 2]);
  const upstream = tf.tensor4d(upstreamValues, [4, 5, 2, 2]);

  const gradFn = tf.grads((x, y) => tf.matMul(x, y, false, false));
  const [gradLhs, gradRhs] = gradFn([lhs, rhs], upstream);

  // Gradient with respect to the left operand.
  expect(gradLhs.shape).toEqual(lhs.shape);
  expectArraysClose(await gradLhs.data(), [
    -11, 26, 55, 27, 54, 9, 25, -15, 5, -3, -12, -27, -63, 9,
    -14, -54, 26, 20, 24, 56, 64, 35, -41, 0, 11, 30, -37, -1,
    31, 13, 12, 37, 2, 29, 97, 6, 60, 47, 31, 35, -14, 24,
    100, -3, -9, 0, -33, 1, 49, 9, -33, -124, -29, 86, -9, -11,
    -6, -40, 72, -48, -20, 48, -72, -20, -30, 15, -72, 136, 87, 12,
    -28, -21, 9, 37, 1, -32, -51, 2, -65, -49, -1, -41, -16, 2,
    -95, -31, -36, 52, 18, 20, -63, 34, 72, 70, -38, -78, -66, -27,
    -111, -10, 85, 1, -21, -21, -4, -21, -21, -4, -12, 20, 13, -4,
    -20, -19, -30, 81, 30, -40, 150, 76
  ]);

  // Gradient with respect to the right operand.
  expect(gradRhs.shape).toEqual(rhs.shape);
  expectArraysClose(await gradRhs.data(), [
    -16, 59, 24, -48, 40, -116, 15, 18, 25, -2, -5, 22, -84, 80,
    36, -16, -38, 8, -74, -16, 46, -80, 62, 48, 96, 110, 38, 6,
    -77, -54, 58, 91, -57, -90, 45, 70, 46, 36, 20, 99, -3, 10,
    55, 79, -10, 42, 5, -31, 85, 47, -74, -89, 37, 75, -48, -38,
    -64, -8, 32, 44, 42, -53, -48, 47, 42, -18, -30, 27, 70, -62,
    36, -24, 78, -69, -112, 101, -40, 20, -11, 113, -9, -6, 1, -50,
    3, -12, -16, 71, -14, 67, 84, 62, 21, 17, 84, 63, -16, -35,
    -28, 98, 4, -126, 40, -50, 36, -45, -16, 20, 19, -12, 8, 0,
    3, -4, 34, -65, 10, -17, -46, 17
  ]);
});
|
// Gradient check for a batched matMul whose second operand is transposed
// (transposeA = false, transposeB = true).
it('gradients: A x B^t', async () => {
  const lhs = tf.tensor3d(
    [
      -5, -5, -6, 8, -2, -8, 4, -7, -6, -9, -1, 3, 7, -2, 5,
      -6, 3, 8, 7, -8, 1, 4, -4, 6, 4, -4, -9, -5, 2, -2
    ],
    [5, 3, 2]
  );
  const rhs = tf.tensor3d(
    [
      -8, -4, -1, 0, -7, 0, 3, 3, 6, 2, -1, 8, -4, 9, -6,
      5, 8, 9, -9, 7, 0, -1, -1, -10, -7, 3, 4, 6, 3, -4
    ],
    [5, 3, 2]
  );
  // The negative zeros are part of the original fixture and are kept as-is.
  const upstream = tf.tensor3d(
    [
      -0, 7, 5, 0, -9, 5, -7, 6, -5, -3, -2, -2, -4, 10, -3,
      5, -1, 3, -2, -9, 4, -5, 7, 9, -10, -8, -8, -5, -0, -1,
      3, 3, 4, 9, -7, 6, -2, -9, 5, 1, -5, -3, -1, 9, 4
    ],
    [5, 3, 3]
  );

  const gradFn = tf.grads((x, y) => tf.matMul(x, y, false, true));
  const [gradLhs, gradRhs] = gradFn([lhs, rhs], upstream);

  // Gradient with respect to the left operand.
  expect(gradLhs.shape).toEqual(lhs.shape);
  expectArraysClose(await gradLhs.data(), [
    -42, 0, -26, 0, 85, 28, -19, -29, 51, -16, 6, 37, 94, -27, 50,
    71, 24, -202, 46, -25, -31, -22, -87, 10, -7, -80, -36, -15, 55, 35
  ]);

  // Gradient with respect to the right operand.
  expect(gradRhs.shape).toEqual(rhs.shape);
  expectArraysClose(await gradRhs.data(), [
    14, 56, 7, -155, -45, 55, 7, 72, -67, -79, 7, 50, -69, -46, -52,
    -88, 49, -126, -68, 106, 31, -30, -27, 60, -19, 5, 27, 43, 55, -13
  ]);
});
|
// Rank-4 gradient check with the second operand transposed: both operands
// are [4, 5, 3, 2] and the incoming gradient is [4, 5, 3, 3].
it('4d gradients: A x B^t', async () => {
  const lhsValues = [
    -2, 3, 5, -5, 3, 9, -3, -5, 1, 1, -9, 9, -6, 6, -8,
    -7, -1, 3, 9, -7, -7, 2, 10, -6, -8, -6, 9, -6, 4, -1,
    9, -6, 10, 8, -9, 5, -8, -7, 0, 2, -5, -1, -9, -4, 3,
    -2, 6, -4, 7, 1, -5, -4, 9, -8, -6, -8, 4, -1, 4, 3,
    -7, 8, -7, 5, -3, -2, -4, 9, 2, -1, 1, -10, -3, 5, -4,
    6, -8, -8, 9, -3, -5, 10, 3, -3, -3, 9, 3, -3, 2, -8,
    10, 1, 9, -2, -2, -3, -4, 6, -10, -1, 8, -8, 7, 3, -2,
    3, 6, -2, -2, -4, 1, -5, -4, 0, 5, 1, 9, -8, -2, -1
  ];
  const rhsValues = [
    -4, -3, -2, -6, 6, -1, -4, -1, 7, -4, 8, -9, -9, 0, -1,
    -4, -6, -7, -3, -4, -7, 6, -8, 1, -2, 1, -1, -3, 8, -5,
    9, -2, 5, 9, -2, 2, -5, -5, -8, -1, -2, -3, -2, -10, 6,
    -3, 0, 1, 6, 7, 1, 2, -4, -5, 2, -5, -7, 9, 3, -6,
    6, 4, -4, 6, 10, -3, -2, 8, 10, -8, 10, -1, -9, -7, -8,
    -3, 1, 1, -2, -9, -7, -6, -1, 0, 7, -9, -7, -5, 0, -4,
    -4, -7, 2, 4, 6, 6, -4, -6, -8, 3, -8, -9, 6, 9, -4,
    1, -1, 0, 8, 9, 0, -5, 3, -1, 5, 0, -10, 7, -2, 6
  ];
  const upstreamValues = [
    5, -1, -5, -4, -1, 9, 1, -2, 10, 7, -1, 6, -8, 8, -3,
    9, -4, 2, -4, -8, 8, 4, 8, -10, -8, -8, 6, 6, -5, 9,
    -1, -7, -5, -3, -3, 2, -6, 5, 8, -9, 5, -8, -3, 8, 6,
    2, 8, 5, 9, 7, 6, 2, -3, 10, 7, 7, -3, 4, -3, -6,
    -8, -8, 9, 0, -8, -3, -2, -2, 8, 2, 3, -6, 3, 6, -3,
    7, 7, -9, -3, 8, 7, 7, -1, -6, 5, 2, -1, -1, 1, 5,
    0, -4, 3, -4, -10, 1, -2, -8, -9, -6, 4, 4, -7, -1, -1,
    -9, 7, 1, -1, 8, 0, -2, -7, 5, 7, 8, 9, -3, -8, -6,
    -7, -8, -1, 8, -4, 7, 5, -9, 9, 3, 0, -10, 7, -9, 4,
    -7, 5, -2, -2, 3, 3, -6, 2, 0, 8, -5, -10, 3, -7, 0,
    -6, 2, 3, -1, 3, 3, -10, 1, 3, -7, -1, 8, -2, -1, -1,
    -3, -9, 7, 4, -6, 3, 0, -7, -4, -5, -8, -6, 10, -6, 4
  ];
  const lhs = tf.tensor4d(lhsValues, [4, 5, 3, 2]);
  const rhs = tf.tensor4d(rhsValues, [4, 5, 3, 2]);
  const upstream = tf.tensor4d(upstreamValues, [4, 5, 3, 3]);

  const gradFn = tf.grads((x, y) => tf.matMul(x, y, false, true));
  const [gradLhs, gradRhs] = gradFn([lhs, rhs], upstream);

  // Gradient with respect to the left operand.
  expect(gradLhs.shape).toEqual(lhs.shape);
  expectArraysClose(await gradLhs.data(), [
    -48, -4, 72, 9, 60, -1, 13, -57, 64, 3, -48, -11, -4, -24,
    16, 38, 44, -10, -55, -45, 92, -43, 14, -4, 71, -61, -51, 16,
    46, -57, 48, 78, 104, 57, -17, -11, -85, -33, 16, 1, 86, 21,
    -48, 21, -8, 34, 14, -35, 36, 48, 85, 108, -38, -40, 3, -8,
    -7, -1, 6, -16, 46, -33, 26, -79, -70, -29, 92, -84, -6, -47,
    98, -129, -55, -17, 79, 40, -118, -64, 68, 75, 71, 111, 5, -48,
    98, -36, 21, 13, 112, -34, 26, 57, 32, 44, 28, 50, 88, 27,
    44, -39, -16, 15, -21, -6, -67, -89, -46, -64, -19, -12, -3, 11,
    41, 63, 78, -73, 67, -92, 102, -18
  ]);

  // Gradient with respect to the right operand.
  expect(gradRhs.shape).toEqual(rhs.shape);
  expectArraysClose(await gradRhs.data(), [
    -27, 44, -9, -16, 85, 30, -110, 38, 47, -23, -39, -15, 0, -76,
    -8, -128, 26, 136, 31, -26, -26, 39, 136, -85, -45, 93, 37, -68,
    -112, -6, 90, 70, 169, -7, 15, 68, -16, -33, -16, -47, -21, 0,
    6, -4, 84, 24, 15, 20, -41, -1, 79, -86, 87, -23, -26, -64,
    18, 9, 52, 64, 34, -16, 122, -66, -1, 47, 1, 43, -11, -33,
    -17, 27, -45, -73, -60, -66, -92, -42, 32, -85, -44, -44, -28, -13,
    8, -20, 9, -9, -49, 79, -76, 15, 73, -7, 7, -8, -110, 93,
    106, -39, 64, -84, -29, -19, 13, 14, 63, 2, -15, 23, 17, 49,
    -3, -31, -65, 30, -95, 63, -82, 40
  ]);
});
|
// Gradient check for a batched matMul whose first operand is transposed
// (transposeA = true, transposeB = false).
it('gradients: A^t x B', async () => {
  const lhs = tf.tensor3d(
    [
      -5, -5, -6, 8, -2, -8, 4, -7, -6, -9, -1, 3, 7, -2, 5,
      -6, 3, 8, 7, -8, 1, 4, -4, 6, 4, -4, -9, -5, 2, -2
    ],
    [5, 3, 2]
  );
  const rhs = tf.tensor3d(
    [
      -8, -4, -1, 0, -7, 0, 3, 3, 6, 2, -1, 8, -4, 9, -6,
      5, 8, 9, -9, 7, 0, -1, -1, -10, -7, 3, 4, 6, 3, -4
    ],
    [5, 3, 2]
  );
  const upstream = tf.tensor3d(
    [8, 2, -3, -2, -8, 4, 5, 7, 4, -4, -4, 5, 8, 10, 1, 0, 6, 6, -4, 7],
    [5, 2, 2]
  );

  const gradFn = tf.grads((x, y) => tf.matMul(x, y, true, false));
  const [gradLhs, gradRhs] = gradFn([lhs, rhs], upstream);

  // Gradient with respect to the left operand.
  expect(gradLhs.shape).toEqual(lhs.shape);
  expectArraysClose(await gradLhs.data(), [
    -72, 32, -8, 3, -56, 21, -12, 36, -40, 44, 40, 51, -52, 61, -44,
    49, -4, 13, -2, -9, -10, 0, -108, -1, -24, 49, 60, 26, -6, -40
  ]);

  // Gradient with respect to the right operand.
  expect(gradRhs.shape).toEqual(rhs.shape);
  expectArraysClose(await gradRhs.data(), [
    -25, 0, -72, -28, 8, 12, -67, -33, 3, -87, 23, 17, 36, -38, 44,
    -50, -20, 28, 48, 70, 12, 10, -26, -40, 40, -4, -34, -89, 20, -2
  ]);
});
|
// Gradient check for a batched matMul with both operands transposed
// (transposeA = true, transposeB = true).
it('gradients: A^t x B^t', async () => {
  const lhs = tf.tensor3d(
    [
      -5, -5, -6, 8, -2, -8, 4, -7, -6, -9, -1, 3, 7, -2, 5,
      -6, 3, 8, 7, -8, 1, 4, -4, 6, 4, -4, -9, -5, 2, -2
    ],
    [5, 3, 2]
  );
  const rhs = tf.tensor3d(
    [
      -8, -4, -1, 0, -7, 0, 3, 3, 6, 2, -1, 8, -4, 9, -6,
      5, 8, 9, -9, 7, 0, -1, -1, -10, -7, 3, 4, 6, 3, -4
    ],
    [5, 2, 3]
  );
  const upstream = tf.tensor3d(
    [8, 2, -3, -2, -8, 4, 5, 7, 4, -4, -4, 5, 8, 10, 1, 0, 6, 6, -4, 7],
    [5, 2, 2]
  );

  const gradFn = tf.grads((x, y) => tf.matMul(x, y, true, true));
  const [gradLhs, gradRhs] = gradFn([lhs, rhs], upstream);

  // Gradient with respect to the left operand.
  expect(gradLhs.shape).toEqual(lhs.shape);
  expectArraysClose(await gradLhs.data(), [
    -64, 24, -46, 26, -8, 3, -16, 29, -28, 8, -16, 86, -36, 41, 4,
    4, -60, 69, -82, -9, 46, 7, -100, 0, -6, 70, 36, 9, 0, -44
  ]);

  // Gradient with respect to the right operand.
  expect(gradRhs.shape).toEqual(rhs.shape);
  expectArraysClose(await gradRhs.data(), [
    -25, -72, 8, 0, -28, 12, -67, 3, 23, -33, -87, 17, 36, 44, -20,
    -38, -50, 28, 48, 12, -26, 70, 10, -40, 40, -34, 20, -4, -89, -2
  ]);
});
|
945 | });
|
// Test suite for tf.dot (plus two matMul sanity checks), run in every
// environment listed in ALL_ENVS.
describeWithFlags('dot', ALL_ENVS, () => {
  // Shared fixtures; rebuilt before every test by the beforeEach hook below.
  let vec2;
  let mat2x2;
  let mat2x3;
  let rank3Single;
  let scalarOne;
  let rank3Pair;

  beforeEach(() => {
    vec2 = tf.tensor1d([1, 2]);
    mat2x2 = tf.tensor2d([[1, 2], [3, 4]]);
    mat2x3 = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
    rank3Single = tf.tensor3d([1, 2], [1, 1, 2]);
    scalarOne = tf.scalar(1);
    rank3Pair = tf.tensor3d([1, 2, 1, 2], [2, 1, 2]);
  });

  it('vector-vector', async () => {
    const result = tf.dot(vec2, vec2);
    expectArraysClose(await result.data(), [5]);
    // A vector-vector dot product has an empty (scalar) shape.
    expect(result.shape).toEqual([]);
  });

  it('vector-matrix', async () => {
    const withSquare = tf.dot(vec2, mat2x2);
    const withWide = tf.dot(vec2, mat2x3);
    expect(withSquare.shape).toEqual([2]);
    expect(withWide.shape).toEqual([3]);
    expectArraysClose(await withSquare.data(), [7, 10]);
    expectArraysClose(await withWide.data(), [9, 12, 15]);
  });

  it('matrix-vector', async () => {
    // Uses the chained tensor method rather than the tf.dot function.
    const result = mat2x2.dot(vec2);
    expect(result.shape).toEqual([2]);
    expectArraysClose(await result.data(), [5, 11]);
  });

  it('matrix-matrix', async () => {
    const squared = tf.dot(mat2x2, mat2x2);
    const squareTimesWide = tf.dot(mat2x2, mat2x3);
    expect(squared.shape).toEqual([2, 2]);
    expect(squareTimesWide.shape).toEqual([2, 3]);
    expectArraysClose(await squared.data(), [7, 10, 15, 22]);
    expectArraysClose(await squareTimesWide.data(), [9, 12, 15, 19, 26, 33]);
  });

  it('matmul A x B asymmetric', async () => {
    const lhs = tf.tensor2d([1, 2, 3, 4], [2, 2]);
    const rhs = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const product = tf.matMul(lhs, rhs);
    const productData = await product.data();
    expect(product.shape).toEqual([2, 3]);
    expectArraysClose(productData, [9, 12, 15, 19, 26, 33]);
  });

  it('throws error on incompatible dimensions', () => {
    const attempt = () => tf.dot(mat2x3, rank3Pair);
    expect(attempt).toThrowError();
  });

  it('throws error when inputs are not rank 1 or 2', () => {
    // Second operand is rank 3, then rank 0.
    const withRank3 = () => tf.dot(vec2, rank3Single);
    const withScalar = () => tf.dot(vec2, scalarOne);
    expect(withRank3).toThrowError();
    expect(withScalar).toThrowError();
  });

  it('accepts a tensor-like object', async () => {
    // A plain array is passed in place of a tensor.
    const values = [1, 2, 3];
    const result = tf.dot(values, values);
    expectArraysClose(await result.data(), [14]);
    expect(result.shape).toEqual([]);
  });

  it('throws error for string tensors', () => {
    const attempt = () => tf.dot('a', 'b');
    expect(attempt).toThrowError(
      /Argument 't1' passed to 'dot' must be numeric tensor/
    );
  });

  it('ensure no memory leak', async () => {
    // Snapshot the tensor and data-id counts before allocating anything here.
    const tensorsBefore = tf.memory().numTensors;
    const dataIdsBefore = tf.engine().backend.numDataIds();

    const lhs = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const rhs = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
    const product = tf.matMul(lhs, rhs);
    expect(product.shape).toEqual([2, 2]);
    expectArraysClose(await product.data(), [0, 8, -3, 20]);

    // Release everything this test created, in creation order.
    lhs.dispose();
    rhs.dispose();
    product.dispose();

    // Both counters must be back at their starting values.
    const tensorsAfter = tf.memory().numTensors;
    const dataIdsAfter = tf.engine().backend.numDataIds();
    expect(tensorsAfter).toBe(tensorsBefore);
    expect(dataIdsAfter).toBe(dataIdsBefore);
  });
});
|
1029 |
|
\ | No newline at end of file |