UNPKG

48.6 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2017 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
/**
 * Empirically determined minimal shared dimension in matmul before the op is
 * forwarded to `a.mul(b).sum()` in order to take advantage of GPU parallelism.
 * See https://github.com/tensorflow/tfjs-core/pull/1379 for benchmarks.
 *
 * Copied from the webgl backend.
 *
 * TODO(yassogba, annyuan): copy the tests that explicitly exercise this
 * threshold over to the webgl backend.
 */
export const MATMUL_SHARED_DIM_THRESHOLD = 1000;
24import * as tf from '../index';
25import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
26import { expectArraysClose, expectArraysEqual } from '../test_util';
// Unit tests for tf.matMul, tf.dot and tf.outerProduct: forward values,
// output shapes, argument validation and gradients. Registered through
// describeWithFlags with ALL_ENVS.
describeWithFlags('matmul', ALL_ENVS, () => {
  /**
   * Builds the fixture shared by the "batched matmul with ..." tests: a
   * zero-filled Float32Array of `batch * sharedDim` elements whose only
   * non-zero entry is a 2 at flat index 10. `sharedDim` is one past
   * MATMUL_SHARED_DIM_THRESHOLD.
   *
   * @returns {{batch: number, sharedDim: number, values: Float32Array}}
   */
  const makeSparseBatch = () => {
    const batch = 3;
    const sharedDim = MATMUL_SHARED_DIM_THRESHOLD + 1;
    const values = new Float32Array(batch * sharedDim);
    values[10] = 2;
    return {batch, sharedDim, values};
  };

  it('A x B', async () => {
    const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);

    const c = tf.matMul(a, b);

    expect(c.shape).toEqual([2, 2]);
    expectArraysClose(await c.data(), [0, 8, -3, 20]);
  });

  it('[8,4]x[4,8]', async () => {
    const a = tf.tensor2d(
        [
          1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 16,
          17, 18, 19, 20, 21, 22, 23, 24, 1,  2,  3,  4,  5,  6,  7,  8
        ],
        [8, 4]);
    const b = tf.tensor2d(
        [
          0,  1,  -3, 2, 1,  -1, 0, 5,  6, 7, 8, 0, -2, -2, 1, 9,
          11, 10, 0,  1, -3, 2,  1, -1, 1, 2, 3, 4, 5,  6,  7, 8
        ],
        [4, 8]);

    const c = tf.matMul(a, b);
    const cData = await c.data();

    expect(c.shape).toEqual([8, 8]);
    expectArraysClose(cData, [
      49,  53,  25,  21,  8,   25,  33,  52,  121, 133, 57,  49,  12,
      45,  69,  136, 193, 213, 89,  77,  16,  65,  105, 220, 265, 293,
      121, 105, 20,  85,  141, 304, 337, 373, 153, 133, 24,  105, 177,
      388, 409, 453, 185, 161, 28,  125, 213, 472, 49,  53,  25,  21,
      8,   25,  33,  52,  121, 133, 57,  49,  12,  45,  69,  136
    ]);
  });

  it('broadcast with unequal batch dims', async () => {
    const a = tf.tensor3d(
        [
          2, 1, 3, 2, 1,  1,  1, 5, 6, 7, 8, 1,
          2, 2, 1, 9, 11, 10, 1, 1, 3, 2, 1, 1
        ],
        [4, 3, 2]);
    const b = tf.tensor3d([1, 0.5], [1, 2, 1]);

    const c = tf.matMul(a, b);

    expect(c.shape).toEqual([4, 3, 1]);
    expectArraysClose(
        await c.data(),
        [2.5, 4, 1.5, 3.5, 9.5, 8.5, 3, 5.5, 16, 1.5, 4, 1.5]);
  });

  it('broadcast with unequal ranks', async () => {
    const a = tf.tensor5d(
        [
          2, 1, 3, 2, 1,  1,  1, 5, 6, 7, 8, 1,
          2, 2, 1, 9, 11, 10, 1, 1, 3, 2, 1, 1
        ],
        [1, 2, 2, 3, 2]);
    const b = tf.tensor2d([1, 0.5], [2, 1]);

    const c = tf.matMul(a, b);

    expect(c.shape).toEqual([1, 2, 2, 3, 1]);
    expectArraysClose(
        await c.data(),
        [2.5, 4, 1.5, 3.5, 9.5, 8.5, 3, 5.5, 16, 1.5, 4, 1.5]);
  });

  it('matmul followed by mul', async () => {
    const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
    const b = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const c = tf.matMul(a, b);

    const f = tf.tensor2d([0, 1, 0.5, 0, 0.25, 2], [2, 3]);
    const d = tf.mul(c, f);
    const dData = await d.data();

    expect(d.shape).toEqual([2, 3]);
    expectArraysClose(dData, [0, 12, 7.5, 0, 6.5, 66]);
  });

  it('upcasts when dtypes dont match', async () => {
    const a = [1, 2, 3, 4, 5, 6];
    const b = [0, 1, -3, 2, 2, 1];

    // float32 x int32 -> float32
    let c = tf.matMul(
        tf.tensor(a, [2, 3], 'float32'), tf.tensor(b, [3, 2], 'int32'));
    expect(c.shape).toEqual([2, 2]);
    expect(c.dtype).toBe('float32');
    expectArraysClose(await c.data(), [0, 8, -3, 20]);

    // int32 x bool -> int32
    c = tf.matMul(tf.tensor(a, [2, 3], 'int32'), tf.tensor(b, [3, 2], 'bool'));
    expect(c.shape).toEqual([2, 2]);
    expect(c.dtype).toBe('int32');
    expectArraysClose(await c.data(), [5, 6, 11, 15]);
  });

  it('A x B^t', async () => {
    const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const b = tf.tensor2d([1, 0, 2, 4, 3, 0], [2, 3]);
    const transposeA = false;
    const transposeB = true;

    const c = tf.matMul(a, b, transposeA, transposeB);

    // [2, 3] x [2, 3]^T -> [2, 2]
    expect(c.shape).toEqual([2, 2]);
    expectArraysClose(await c.data(), [7, 10, 16, 31]);
  });

  it('A^t x B', async () => {
    const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const b = tf.tensor2d([1, 0, 2, 4, 3, 0], [2, 3]);
    const transposeA = true;
    const transposeB = false;

    const c = tf.matMul(a, b, transposeA, transposeB);

    // [2, 3]^T x [2, 3] -> [3, 3]
    expect(c.shape).toEqual([3, 3]);
    expectArraysClose(await c.data(), [17, 12, 2, 22, 15, 4, 27, 18, 6]);
  });

  it('A^t x B^t', async () => {
    const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
    const b = tf.tensor2d([1, 0, 2, 4, 3, 0], [2, 3]);
    const transposeA = true;
    const transposeB = true;

    const c = tf.matMul(a, b, transposeA, transposeB);

    // [3, 2]^T x [2, 3]^T -> [2, 2]
    expect(c.shape).toEqual([2, 2]);
    expectArraysClose(await c.data(), [11, 13, 14, 20]);
  });

  it('A x B^t shapes do not match', () => {
    const a = tf.zeros([2, 3]);
    const b = tf.zeros([3, 2]);
    const transposeA = false;
    const transposeB = true;

    expect(() => tf.matMul(a, b, transposeA, transposeB)).toThrowError();
  });

  it('A^t x B shapes do not match', () => {
    const a = tf.zeros([2, 3]);
    const b = tf.zeros([3, 2]);
    const transposeA = true;
    const transposeB = false;

    expect(() => tf.matMul(a, b, transposeA, transposeB)).toThrowError();
  });

  it('A^t x B^t shapes do not match', () => {
    const a = tf.zeros([3, 2]);
    const b = tf.zeros([3, 2]);
    const transposeA = true;
    const transposeB = true;

    expect(() => tf.matMul(a, b, transposeA, transposeB)).toThrowError();
  });

  it('matmul throws when inner dimensions dont match', () => {
    const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const b = tf.tensor2d([0, 1, -3, 2, 2, 1, 2, 2], [4, 2]);

    expect(() => tf.matMul(a, b)).toThrowError();
  });

  it('matmul throws when passed non matrices', () => {
    // NOTE(review): 'broadcast with unequal ranks' above shows mixed ranks are
    // accepted, so these calls presumably throw because the inner dimensions
    // differ (2 vs 4, and 2 vs 3) rather than because of rank - confirm.
    const a = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 2]);
    const b = tf.tensor2d([0, 1, -3, 2, 2, 1, 2, 2], [4, 2]);

    expect(() => tf.matMul(a, b)).toThrowError();
    expect(() => tf.matMul(b, a)).toThrowError();
  });

  it('matmul throws when passed a vector', () => {
    const v = tf.tensor1d([2, 3]);
    const matrix = tf.tensor2d([1, 2, 3, 4], [2, 2]);

    expect(() => tf.matMul(matrix, v)).toThrowError();
  });

  it('Vector times matrix', async () => {
    const v = tf.tensor1d([2, 3]);
    const matrix = tf.tensor2d([1, 2, 3, 4], [2, 2]);

    const result = tf.dot(v, matrix);

    expectArraysClose(await result.data(), [11, 16]);
  });

  it('Vector times matrix with implicit reshape', async () => {
    // NOTE(review): this body is identical to 'Vector times matrix' above;
    // confirm whether a distinct reshape scenario was intended.
    const v = tf.tensor1d([2, 3]);
    const matrix = tf.tensor2d([1, 2, 3, 4], [2, 2]);

    const result = tf.dot(v, matrix);

    expectArraysClose(await result.data(), [11, 16]);
  });

  it('Matrix times vector', async () => {
    const matrix = tf.tensor2d([1, 2, 3, 4], [2, 2]);
    const v = tf.tensor1d([2, 3]);

    const result = tf.dot(matrix, v);

    expectArraysClose(await result.data(), [8, 18]);
  });

  it('batched matmul with the matrices being vectors', async () => {
    const {batch, sharedDim, values} = makeSparseBatch();
    const a = tf.tensor(values, [batch, 1, sharedDim]);
    const b = tf.tensor(values, [batch, sharedDim, 1]);

    const result = tf.matMul(a, b);

    expect(result.shape).toEqual([batch, 1, 1]);
    expectArraysClose(await result.data(), [4, 0, 0]);
  });

  it('batched matmul called twice so memory of output is reused', async () => {
    const batch = 3;
    const n = 2;

    const vals = new Float32Array(batch * n * n);
    vals[0] = 2;
    vals[4] = 3;
    vals[8] = 4;
    const a = tf.tensor(vals, [batch, n, n]);
    const b = tf.tensor(vals, [batch, n, n]);
    const result = tf.matMul(a, b);
    expect(result.shape).toEqual([batch, n, n]);
    expectArraysClose(
        await result.data(), [4, 0, 0, 0, 9, 0, 0, 0, 16, 0, 0, 0]);

    // Dispose the first output, so memory of the second output (which has the
    // same shape), could be reused.
    result.dispose();

    const vals2 = new Float32Array(batch * n * n);
    vals2[3] = 2;
    vals2[7] = 3;
    vals2[11] = 4;
    const a2 = tf.tensor(vals2, [batch, n, n]);
    const b2 = tf.tensor(vals2, [batch, n, n]);
    const result2 = tf.matMul(a2, b2);
    expect(result2.shape).toEqual([batch, n, n]);
    expectArraysClose(
        await result2.data(), [0, 0, 0, 4, 0, 0, 0, 9, 0, 0, 0, 16]);
  });

  it('batched matmul with the matrices being vectors transposedA', async () => {
    const {batch, sharedDim, values} = makeSparseBatch();
    const a = tf.tensor(values, [batch, sharedDim, 1]);
    const b = tf.tensor(values, [batch, sharedDim, 1]);
    const transposeA = true;
    const transposeB = false;

    const result = tf.matMul(a, b, transposeA, transposeB);

    expect(result.shape).toEqual([batch, 1, 1]);
    expectArraysClose(await result.data(), [4, 0, 0]);
  });

  it('batched matmul with the matrices being vectors transposedB', async () => {
    const {batch, sharedDim, values} = makeSparseBatch();
    const a = tf.tensor(values, [batch, 1, sharedDim]);
    const b = tf.tensor(values, [batch, 1, sharedDim]);
    const transposeA = false;
    const transposeB = true;

    const result = tf.matMul(a, b, transposeA, transposeB);

    expect(result.shape).toEqual([batch, 1, 1]);
    expectArraysClose(await result.data(), [4, 0, 0]);
  });

  it('batched matmul with matrix x vector', async () => {
    const {batch, sharedDim, values} = makeSparseBatch();
    const a = tf.ones([batch, 2, sharedDim]);
    const b = tf.tensor(values, [batch, sharedDim, 1]);

    const result = tf.matMul(a, b);

    expect(result.shape).toEqual([batch, 2, 1]);
    expectArraysClose(await result.data(), [2, 2, 0, 0, 0, 0]);
  });

  it('batched matmul with matrix x vector transposedA', async () => {
    const {batch, sharedDim, values} = makeSparseBatch();
    const a = tf.ones([batch, sharedDim, 2]);
    const b = tf.tensor(values, [batch, sharedDim, 1]);
    const transposeA = true;
    const transposeB = false;

    const result = tf.matMul(a, b, transposeA, transposeB);

    expect(result.shape).toEqual([batch, 2, 1]);
    expectArraysClose(await result.data(), [2, 2, 0, 0, 0, 0]);
  });

  it('batched matmul with matrix x vector transposedB', async () => {
    const {batch, sharedDim, values} = makeSparseBatch();
    const a = tf.ones([batch, 2, sharedDim]);
    const b = tf.tensor(values, [batch, 1, sharedDim]);
    const transposeA = false;
    const transposeB = true;

    const result = tf.matMul(a, b, transposeA, transposeB);

    expect(result.shape).toEqual([batch, 2, 1]);
    expectArraysClose(await result.data(), [2, 2, 0, 0, 0, 0]);
  });

  it('batched matmul with vector x matrix', async () => {
    const {batch, sharedDim, values} = makeSparseBatch();
    const a = tf.tensor(values, [batch, 1, sharedDim]);
    const b = tf.ones([batch, sharedDim, 2]);

    const result = tf.matMul(a, b);

    expect(result.shape).toEqual([batch, 1, 2]);
    expectArraysClose(await result.data(), [2, 2, 0, 0, 0, 0]);
  });

  it('batched matmul with vector x matrix transposedA', async () => {
    const {batch, sharedDim, values} = makeSparseBatch();
    const a = tf.tensor(values, [batch, sharedDim, 1]);
    const b = tf.ones([batch, sharedDim, 2]);
    const transposeA = true;
    const transposeB = false;

    const result = tf.matMul(a, b, transposeA, transposeB);

    expect(result.shape).toEqual([batch, 1, 2]);
    expectArraysClose(await result.data(), [2, 2, 0, 0, 0, 0]);
  });

  it('batched matmul with vector x matrix transposedB', async () => {
    const {batch, sharedDim, values} = makeSparseBatch();
    const a = tf.tensor(values, [batch, 1, sharedDim]);
    const b = tf.ones([batch, 2, sharedDim]);
    const transposeA = false;
    const transposeB = true;

    const result = tf.matMul(a, b, transposeA, transposeB);

    expect(result.shape).toEqual([batch, 1, 2]);
    expectArraysClose(await result.data(), [2, 2, 0, 0, 0, 0]);
  });

  it('Matrix * vector propagates NaNs', async () => {
    const matrix = tf.tensor2d([1, 2, 3, 4], [2, 2]);
    const v = tf.tensor1d([2, NaN]);

    const result = tf.dot(matrix, v);

    expectArraysClose(await result.data(), [NaN, NaN]);
  });

  it('matrix times vector throws when not passed a matrix', () => {
    const v = tf.tensor1d([2, 3]);
    const matrix = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]);

    expect(() => tf.dot(matrix, v)).toThrowError();
  });

  it('Dot product', async () => {
    const v1 = tf.tensor1d([2, 3]);
    const v2 = tf.tensor1d([2, 1]);

    const result = tf.dot(v1, v2);

    expectArraysClose(await result.data(), [7]);
  });

  it('Dot product propagates NaNs', async () => {
    const v1 = tf.tensor1d([2, NaN]);
    const v2 = tf.tensor1d([2, 1]);

    const result = tf.dot(v1, v2);

    expectArraysEqual(await result.data(), [NaN]);
  });

  it('Dot product throws when vectors are different size', () => {
    const v1 = tf.tensor1d([2, 3, 3]);
    const v2 = tf.tensor1d([2, 1]);

    expect(() => tf.dot(v1, v2)).toThrowError();
    expect(() => tf.dot(v2, v1)).toThrowError();
  });

  it('Outer product', async () => {
    const v1 = tf.tensor1d([2, 3]);
    const v2 = tf.tensor1d([2, 1]);

    const result = tf.outerProduct(v1, v2);

    expect(result.shape).toEqual([2, 2]);
    expectArraysClose(await result.data(), [4, 2, 6, 3]);
  });

  it('outer product accepts a tensor-like object', async () => {
    const v1 = [2, 3];
    const v2 = [2, 1];

    const result = tf.outerProduct(v1, v2);

    expect(result.shape).toEqual([2, 2]);
    expectArraysClose(await result.data(), [4, 2, 6, 3]);
  });

  it('gradients: A * B', async () => {
    const aT = tf.tensor2d([1, 2, 3, 10, 20, 30], [2, 3]);
    const bT = tf.tensor2d([2, 3, 4, 1, 2, 3], [3, 2]);
    const dyT = tf.tensor2d([1, 10, 20, 30], [2, 2]);
    const transposeA = false;
    const transposeB = false;

    const grads = tf.grads((a, b) => tf.matMul(a, b, transposeA, transposeB));
    const [da, db] = grads([aT, bT], dyT);

    // Buffers give element access for spelling out the expected products.
    const a = await aT.buffer();
    const dy = await dyT.buffer();
    const b = await bT.buffer();

    // da = dy * bT
    expect(da.shape).toEqual(aT.shape);
    expectArraysClose(
        await da.data(),
        [
          dy.get(0, 0) * b.get(0, 0) + dy.get(0, 1) * b.get(0, 1),
          dy.get(0, 0) * b.get(1, 0) + dy.get(0, 1) * b.get(1, 1),
          dy.get(0, 0) * b.get(2, 0) + dy.get(0, 1) * b.get(2, 1),
          dy.get(1, 0) * b.get(0, 0) + dy.get(1, 1) * b.get(0, 1),
          dy.get(1, 0) * b.get(1, 0) + dy.get(1, 1) * b.get(1, 1),
          dy.get(1, 0) * b.get(2, 0) + dy.get(1, 1) * b.get(2, 1)
        ],
        1e-1);

    // db = aT * dy
    expect(db.shape).toEqual(bT.shape);
    expectArraysClose(await db.data(), [
      a.get(0, 0) * dy.get(0, 0) + a.get(1, 0) * dy.get(1, 0),
      a.get(0, 0) * dy.get(0, 1) + a.get(1, 0) * dy.get(1, 1),
      a.get(0, 1) * dy.get(0, 0) + a.get(1, 1) * dy.get(1, 0),
      a.get(0, 1) * dy.get(0, 1) + a.get(1, 1) * dy.get(1, 1),
      a.get(0, 2) * dy.get(0, 0) + a.get(1, 2) * dy.get(1, 0),
      a.get(0, 2) * dy.get(0, 1) + a.get(1, 2) * dy.get(1, 1)
    ]);
  });

  it('gradient with clones', () => {
    const a = tf.tensor2d([1, 2, 3, 10, 20, 30], [2, 3]);
    const b = tf.tensor2d([2, 3, 4, 1, 2, 3], [3, 2]);

    const grads = tf.grads((a, b) => tf.matMul(a.clone(), b.clone()).clone());
    const [da, db] = grads([a, b]);

    expect(da.shape).toEqual(a.shape);
    expect(db.shape).toEqual(b.shape);
  });

  it('gradients: a * bT', async () => {
    const aT = tf.tensor2d([1, 2, 3, 10, 20, 30], [3, 2]);
    const bT = tf.tensor2d([2, 3, 4, 1, 2, 3], [3, 2]);
    const dyT = tf.tensor2d([1, 10, 20, 30, 40, 50, 60, 70, 80], [3, 3]);
    const transposeA = false;
    const transposeB = true;

    const grads = tf.grads((a, b) => tf.matMul(a, b, transposeA, transposeB));
    const [da, db] = grads([aT, bT], dyT);

    const a = await aT.buffer();
    const dy = await dyT.buffer();
    const b = await bT.buffer();

    // da = dy * b
    expect(da.shape).toEqual(aT.shape);
    expectArraysClose(await da.data(), [
      dy.get(0, 0) * b.get(0, 0) + dy.get(0, 1) * b.get(1, 0) +
          dy.get(0, 2) * b.get(2, 0),
      dy.get(0, 0) * b.get(0, 1) + dy.get(0, 1) * b.get(1, 1) +
          dy.get(0, 2) * b.get(2, 1),
      dy.get(1, 0) * b.get(0, 0) + dy.get(1, 1) * b.get(1, 0) +
          dy.get(1, 2) * b.get(2, 0),
      dy.get(1, 0) * b.get(0, 1) + dy.get(1, 1) * b.get(1, 1) +
          dy.get(1, 2) * b.get(2, 1),
      dy.get(2, 0) * b.get(0, 0) + dy.get(2, 1) * b.get(1, 0) +
          dy.get(2, 2) * b.get(2, 0),
      dy.get(2, 0) * b.get(0, 1) + dy.get(2, 1) * b.get(1, 1) +
          dy.get(2, 2) * b.get(2, 1)
    ]);

    // db = dyT * a
    expect(db.shape).toEqual(b.shape);
    expectArraysClose(await db.data(), [
      dy.get(0, 0) * a.get(0, 0) + dy.get(1, 0) * a.get(1, 0) +
          dy.get(2, 0) * a.get(2, 0),
      dy.get(0, 0) * a.get(0, 1) + dy.get(1, 0) * a.get(1, 1) +
          dy.get(2, 0) * a.get(2, 1),
      dy.get(0, 1) * a.get(0, 0) + dy.get(1, 1) * a.get(1, 0) +
          dy.get(2, 1) * a.get(2, 0),
      dy.get(0, 1) * a.get(0, 1) + dy.get(1, 1) * a.get(1, 1) +
          dy.get(2, 1) * a.get(2, 1),
      dy.get(0, 2) * a.get(0, 0) + dy.get(1, 2) * a.get(1, 0) +
          dy.get(2, 2) * a.get(2, 0),
      dy.get(0, 2) * a.get(0, 1) + dy.get(1, 2) * a.get(1, 1) +
          dy.get(2, 2) * a.get(2, 1)
    ]);
  });

  it('gradients: aT * b', async () => {
    const aT = tf.tensor2d([1, 2, 3, 10, 20, 30], [3, 2]);
    const bT = tf.tensor2d([2, 3, 4, 1, 2, 3], [3, 2]);
    const dyT = tf.tensor2d([1, 10, 20, 30], [2, 2]);
    const transposeA = true;
    const transposeB = false;

    const grads = tf.grads((a, b) => tf.matMul(a, b, transposeA, transposeB));
    const [da, db] = grads([aT, bT], dyT);

    const a = await aT.buffer();
    const dy = await dyT.buffer();
    const b = await bT.buffer();

    // da = b * dyT
    expect(da.shape).toEqual(aT.shape);
    expectArraysClose(await da.data(), [
      dy.get(0, 0) * b.get(0, 0) + dy.get(0, 1) * b.get(0, 1),
      dy.get(1, 0) * b.get(0, 0) + dy.get(1, 1) * b.get(0, 1),
      dy.get(0, 0) * b.get(1, 0) + dy.get(0, 1) * b.get(1, 1),
      dy.get(1, 0) * b.get(1, 0) + dy.get(1, 1) * b.get(1, 1),
      dy.get(0, 0) * b.get(2, 0) + dy.get(0, 1) * b.get(2, 1),
      dy.get(1, 0) * b.get(2, 0) + dy.get(1, 1) * b.get(2, 1)
    ]);

    // db = a * dy
    expect(db.shape).toEqual(b.shape);
    expectArraysClose(await db.data(), [
      dy.get(0, 0) * a.get(0, 0) + dy.get(1, 0) * a.get(0, 1),
      dy.get(0, 1) * a.get(0, 0) + dy.get(1, 1) * a.get(0, 1),
      dy.get(0, 0) * a.get(1, 0) + dy.get(1, 0) * a.get(1, 1),
      dy.get(0, 1) * a.get(1, 0) + dy.get(1, 1) * a.get(1, 1),
      dy.get(0, 0) * a.get(2, 0) + dy.get(1, 0) * a.get(2, 1),
      dy.get(0, 1) * a.get(2, 0) + dy.get(1, 1) * a.get(2, 1)
    ]);
  });

  it('gradients: aT * bT', async () => {
    const aT = tf.tensor2d([1, 2, 3, 10, 20, 30], [3, 2]);
    const bT = tf.tensor2d([2, 3, 4, 1, 2, 3], [2, 3]);
    const dyT = tf.tensor2d([1, 10, 20, 30], [2, 2]);
    const transposeA = true;
    const transposeB = true;

    const grads = tf.grads((a, b) => tf.matMul(a, b, transposeA, transposeB));
    const [da, db] = grads([aT, bT], dyT);

    const a = await aT.buffer();
    const dy = await dyT.buffer();
    const b = await bT.buffer();

    // da = bT * dyT
    expect(da.shape).toEqual(aT.shape);
    expectArraysClose(await da.data(), [
      dy.get(0, 0) * b.get(0, 0) + dy.get(0, 1) * b.get(1, 0),
      dy.get(1, 0) * b.get(0, 0) + dy.get(1, 1) * b.get(1, 0),
      dy.get(0, 0) * b.get(0, 1) + dy.get(0, 1) * b.get(1, 1),
      dy.get(1, 0) * b.get(0, 1) + dy.get(1, 1) * b.get(1, 1),
      dy.get(0, 0) * b.get(0, 2) + dy.get(0, 1) * b.get(1, 2),
      dy.get(1, 0) * b.get(0, 2) + dy.get(1, 1) * b.get(1, 2)
    ]);

    // db = dyT * aT
    expect(db.shape).toEqual(b.shape);
    expectArraysClose(await db.data(), [
      dy.get(0, 0) * a.get(0, 0) + dy.get(1, 0) * a.get(0, 1),
      dy.get(0, 0) * a.get(1, 0) + dy.get(1, 0) * a.get(1, 1),
      dy.get(0, 0) * a.get(2, 0) + dy.get(1, 0) * a.get(2, 1),
      dy.get(0, 1) * a.get(0, 0) + dy.get(1, 1) * a.get(0, 1),
      dy.get(0, 1) * a.get(1, 0) + dy.get(1, 1) * a.get(1, 1),
      dy.get(0, 1) * a.get(2, 0) + dy.get(1, 1) * a.get(2, 1)
    ]);
  });

  it('throws when passed a as a non-tensor', () => {
    expect(() => tf.matMul({}, tf.tensor2d([2], [1, 1])))
        .toThrowError(/Argument 'a' passed to 'matMul' must be a Tensor/);
  });

  it('throws when passed b as a non-tensor', () => {
    expect(() => tf.matMul(tf.tensor2d([2], [1, 1]), {}))
        .toThrowError(/Argument 'b' passed to 'matMul' must be a Tensor/);
  });

  it('accepts a tensor-like object', async () => {
    const a = [[1, 2, 3], [4, 5, 6]];     // 2x3
    const b = [[0, 1], [-3, 2], [2, 1]];  // 3x2

    const c = tf.matMul(a, b);

    expect(c.shape).toEqual([2, 2]);
    expectArraysClose(await c.data(), [0, 8, -3, 20]);
  });

  it('accepts a tensor-like object chained', async () => {
    const a = tf.tensor2d([[1, 2, 3], [4, 5, 6]], [2, 3]);  // 2x3
    const b = [[0, 1], [-3, 2], [2, 1]];                    // 3x2

    const c = a.matMul(b);

    expect(c.shape).toEqual([2, 2]);
    expectArraysClose(await c.data(), [0, 8, -3, 20]);
  });

  it('a * b where a has zero in its shape', async () => {
    const a = tf.tensor2d([], [0, 3]);
    const b = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);

    const c = tf.matMul(a, b);

    expect(c.shape).toEqual([0, 2]);
    expect(c.rank).toBe(2);
    expect(c.size).toBe(0);
    expectArraysClose(await c.data(), []);
  });

  it('(a * b) * c where a has zero in its shape, so a*b does also', async () => {
    const a = tf.tensor2d([], [0, 3]);
    const b = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
    const ab = tf.matMul(a, b);
    expect(ab.shape).toEqual([0, 2]);
    expectArraysClose(await ab.data(), []);

    const c = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const res = tf.matMul(ab, c);
    expect(res.shape).toEqual([0, 3]);
    expectArraysClose(await res.data(), []);
  });

  it('throws error for string tensor', () => {
    expect(() => tf.matMul([['a']], [['b']]))
        .toThrowError(/Argument 'a' passed to 'matMul' must be numeric tensor/);
  });
});
571describeWithFlags('matmulBatch', ALL_ENVS, () => {
572 it('A x B', async () => {
573 const a = tf.tensor3d([
574 -5, -5, -6, 8, -2, -8, 4, -7, -6, -9, -1, 3, 7, -2, 5,
575 -6, 3, 8, 7, -8, 1, 4, -4, 6, 4, -4, -9, -5, 2, -2
576 ], [5, 2, 3]);
577 const b = tf.tensor3d([
578 -8, -4, -1, 0, -7, 0, 3, 3, 6, 2, -1, 8, -4, 9, -6,
579 5, 8, 9, -9, 7, 0, -1, -1, -10, -7, 3, 4, 6, 3, -4
580 ], [5, 3, 2]);
581 const c = tf.matMul(a, b);
582 expect(c.shape).toEqual([5, 2, 2]);
583 expectArraysClose(await c.data(), [
584 87, 20, -6, -32, -24, -50, -36, -5, 24, 98,
585 70, 33, -64, 47, -42, -28, -71, 24, 37, 5
586 ]);
587 });
588 it('A x B in 4D', async () => {
589 const a = tf.tensor4d([
590 -2, 3, 5, -5, 3, 9, -3, -5, 1, 1, -9, 9, -6, 6, -8,
591 -7, -1, 3, 9, -7, -7, 2, 10, -6, -8, -6, 9, -6, 4, -1,
592 9, -6, 10, 8, -9, 5, -8, -7, 0, 2, -5, -1, -9, -4, 3,
593 -2, 6, -4, 7, 1, -5, -4, 9, -8, -6, -8, 4, -1, 4, 3,
594 -7, 8, -7, 5, -3, -2, -4, 9, 2, -1, 1, -10, -3, 5, -4,
595 6, -8, -8, 9, -3, -5, 10, 3, -3, -3, 9, 3, -3, 2, -8,
596 10, 1, 9, -2, -2, -3, -4, 6, -10, -1, 8, -8, 7, 3, -2,
597 3, 6, -2, -2, -4, 1, -5, -4, 0, 5, 1, 9, -8, -2, -1
598 ], [4, 5, 2, 3]);
599 const b = tf.tensor4d([
600 -4, -3, -2, -6, 6, -1, -4, -1, 7, -4, 8, -9, -9, 0, -1,
601 -4, -6, -7, -3, -4, -7, 6, -8, 1, -2, 1, -1, -3, 8, -5,
602 9, -2, 5, 9, -2, 2, -5, -5, -8, -1, -2, -3, -2, -10, 6,
603 -3, 0, 1, 6, 7, 1, 2, -4, -5, 2, -5, -7, 9, 3, -6,
604 6, 4, -4, 6, 10, -3, -2, 8, 10, -8, 10, -1, -9, -7, -8,
605 -3, 1, 1, -2, -9, -7, -6, -1, 0, 7, -9, -7, -5, 0, -4,
606 -4, -7, 2, 4, 6, 6, -4, -6, -8, 3, -8, -9, 6, 9, -4,
607 1, -1, 0, 8, 9, 0, -5, 3, -1, 5, 0, -10, 7, -2, 6
608 ], [4, 5, 3, 2]);
609 const transposeA = false;
610 const transposeB = false;
611 const c = tf.matMul(a, b, transposeA, transposeB);
612 expectArraysClose(await c.data(), [
613 32, -17, 68, -12, -15, 14, 5, -46, 96, 32, 46, -17, 78, -85,
614 -28, 46, 94, -35, 0, -13, 31, -52, 17, -87, 96, 47, 32, -2,
615 -6, 105, 40, -2, 63, 76, 17, 30, 56, -66, -21, 23, -144, 41,
616 22, 8, 118, -106, -88, -6, -17, 2, 2, -26, 8, -63, -38, -108,
617 -84, -30, -35, 49, 16, -12, -14, -12, 48, 132, 4, 102, 32, 66,
618 -4, 33, -13, 1, -40, -25, -3, 61, -18, -20
619 ]);
620 });
621 it('A x B^t', async () => {
622 const a = tf.tensor3d([
623 -5, -5, -6, 8, -2, -8, 4, -7, -6, -9, -1, 3, 7, -2, 5,
624 -6, 3, 8, 7, -8, 1, 4, -4, 6, 4, -4, -9, -5, 2, -2
625 ], [5, 2, 3]);
626 const b = tf.tensor3d([
627 -8, -4, -1, 0, -7, 0, 3, 3, 6, 2, -1, 8, -4, 9, -6,
628 5, 8, 9, -9, 7, 0, -1, -1, -10, -7, 3, 4, 6, 3, -4
629 ], [5, 2, 3]);
630 const transposeA = false;
631 const transposeB = true;
632 const c = tf.matMul(a, b, transposeA, transposeB);
633 expect(c.shape).toEqual([5, 2, 2]);
634 expectArraysClose(await c.data(), [
635 66, 35, -48, 14, -45, -33, -12, 7, -76, 64,
636 3, 66, -119, -9, -64, -60, -76, 48, 33, -16
637 ]);
638 });
639 it('A^t x B', async () => {
640 const a = tf.tensor3d([
641 -5, -5, -6, 8, -2, -8, 4, -7, -6, -9, -1, 3, 7, -2, 5,
642 -6, 3, 8, 7, -8, 1, 4, -4, 6, 4, -4, -9, -5, 2, -2
643 ], [5, 2, 3]);
644 const b = tf.tensor3d([
645 -8, -4, -1, 0, -7, 0, 3, 3, 6, 2, -1, 8, -4, 9, -6,
646 5, 8, 9, -9, 7, 0, -1, -1, -10, -7, 3, 4, 6, 3, -4
647 ], [5, 2, 3]);
648 const transposeA = true;
649 const transposeB = false;
650 const c = tf.matMul(a, b, transposeA, transposeB);
651 expectArraysClose(await c.data(), [
652 40, -36, 5, 40, 34, 5, 48, 80, 6, -6, 21, -48, -23, -20, -50,
653 -12, -21, -12, -58, 15, -96, 23, 6, 39, 20, 109, 42, -67, 45, -40,
654 76, -52, 40, -15, 1, -60, -58, -3, 36, 40, -6, -24, 51, -33, -28
655 ]);
656 });
657 it('A^t x B in 4D', async () => {
658 const a = tf.tensor4d([
659 -2, 3, 5, -5, 3, 9, -3, -5, 1, 1, -9, 9, -6, 6, -8,
660 -7, -1, 3, 9, -7, -7, 2, 10, -6, -8, -6, 9, -6, 4, -1,
661 9, -6, 10, 8, -9, 5, -8, -7, 0, 2, -5, -1, -9, -4, 3,
662 -2, 6, -4, 7, 1, -5, -4, 9, -8, -6, -8, 4, -1, 4, 3,
663 -7, 8, -7, 5, -3, -2, -4, 9, 2, -1, 1, -10, -3, 5, -4,
664 6, -8, -8, 9, -3, -5, 10, 3, -3, -3, 9, 3, -3, 2, -8,
665 10, 1, 9, -2, -2, -3, -4, 6, -10, -1, 8, -8, 7, 3, -2,
666 3, 6, -2, -2, -4, 1, -5, -4, 0, 5, 1, 9, -8, -2, -1
667 ], [4, 5, 2, 3]);
668 const b = tf.tensor4d([
669 -4, -3, -2, -6, 6, -1, -4, -1, 7, -4, 8, -9, -9, 0, -1,
670 -4, -6, -7, -3, -4, -7, 6, -8, 1, -2, 1, -1, -3, 8, -5,
671 9, -2, 5, 9, -2, 2, -5, -5, -8, -1, -2, -3, -2, -10, 6,
672 -3, 0, 1, 6, 7, 1, 2, -4, -5, 2, -5, -7, 9, 3, -6,
673 6, 4, -4, 6, 10, -3, -2, 8, 10, -8, 10, -1, -9, -7, -8,
674 -3, 1, 1, -2, -9, -7, -6, -1, 0, 7, -9, -7, -5, 0, -4,
675 -4, -7, 2, 4, 6, 6, -4, -6, -8, 3, -8, -9, 6, 9, -4,
676 1, -1, 0, 8, 9, 0, -5, 3, -1, 5, 0, -10, 7, -2, 6
677 ], [4, 5, 2, 3]);
678 const transposeA = true;
679 const transposeB = false;
680 const c = tf.matMul(a, b, transposeA, transposeB);
681 expectArraysClose(await c.data(), [
682 38, -24, 9, -30, 9, -9, -74, 39, -19, 8, 11, -30, 56, -67,
683 46, -40, 71, -74, 82, 42, 55, -50, 6, 1, 60, -18, -13, -15,
684 -52, -61, 81, -52, 59, -15, 76, 43, 34, -56, 38, 0, 26, -14,
685 -15, 1, -4, 153, -34, 61, -135, 30, -48, 135, -30, 60, 38, 36,
686 58, 40, 45, 71, 1, 2, 3, 24, 90, -56, -10, 40, -18, 6,
687 -30, 14, 34, 65, 27, 24, -29, -44, -46, -3, 35, -21, 27, 48,
688 20, 52, 32, 35, -11, -46, -12, 22, 13, 30, 2, -23, -54, -48,
689 34, 16, -42, -39, -26, 82, 89, 76, -84, 30, 9, 27, 30, -21,
690 -43, -48, 60, 20, 24, -78, -91, -63, -12, 24, 21, 28, 48, 35,
691 -6, 27, 33, 53, -81, -71, 61, -27, 11, -48, -82, 8, -12, -19,
692 -10, -48, -81, 0, 13, 32, 41, 0, -100, -120, 16, 124, 152, 45,
693 60, -28, 24, 21, -12, -14, -16, 8, 9, -33, 5, -12, -48, 4,
694 8, 9, 0, -31, 16, -98, -9, 4, -22, 38, 2, -96
695 ]);
696 });
697 it('A^t x B^t', async () => {
698 const a = tf.tensor3d([
699 -5, -5, -6, 8, -2, -8, 4, -7, -6, -9, -1, 3, 7, -2, 5,
700 -6, 3, 8, 7, -8, 1, 4, -4, 6, 4, -4, -9, -5, 2, -2
701 ], [5, 3, 2]);
702 const b = tf.tensor3d([
703 -8, -4, -1, 0, -7, 0, 3, 3, 6, 2, -1, 8, -4, 9, -6,
704 5, 8, 9, -9, 7, 0, -1, -1, -10, -7, 3, 4, 6, 3, -4
705 ], [5, 2, 3]);
706 const transposeA = true;
707 const transposeB = true;
708 const c = tf.matMul(a, b, transposeA, transposeB);
709 expectArraysClose(await c.data(), [
710 66, 42, 16, -56, -12, 6, -30, 19, -1, 102,
711 -94, 14, -56, 32, 100, -56, -47, -11, 5, -31
712 ]);
713 });
714 it('batch dimensions do not match', () => {
715 const a = tf.tensor3d([
716 -5, -5, -6, 8, -2, -8, 4, -7, -6, -9, -1, 3,
717 7, -2, 5, -6, 3, 8, 7, -8, 1, 4, -4, 6
718 ], [4, 3, 2]);
719 const b = tf.tensor3d([
720 -8, -4, -1, 0, -7, 0, 3, 3, 6, 2, -1, 8, -4, 9, -6,
721 5, 8, 9, -9, 7, 0, -1, -1, -10, -7, 3, 4, 6, 3, -4
722 ], [5, 2, 3]);
723 const f = () => {
724 tf.matMul(a, b, false, false);
725 };
726 expect(f).toThrowError();
727 });
728 it('gradients: A x B', async () => {
729 const a = tf.tensor3d([
730 -5, -5, -6, 8, -2, -8, 4, -7, -6, -9, -1, 3, 7, -2, 5,
731 -6, 3, 8, 7, -8, 1, 4, -4, 6, 4, -4, -9, -5, 2, -2
732 ], [5, 2, 3]);
733 const b = tf.tensor3d([
734 -8, -4, -1, 0, -7, 0, 3, 3, 6, 2, -1, 8, -4, 9, -6,
735 5, 8, 9, -9, 7, 0, -1, -1, -10, -7, 3, 4, 6, 3, -4
736 ], [5, 3, 2]);
737 const dy = tf.tensor3d([8, 2, -3, -2, -8, 4, 5, 7, 4, -4, -4, 5, 8, 10, 1, 0, 6, 6, -4, 7], [5, 2, 2]);
738 const grads = tf.grads((a, b) => tf.matMul(a, b, false, false));
739 const [da, db] = grads([a, b], dy);
740 // da = dy * bT
741 expect(da.shape).toEqual(a.shape);
742 expectArraysClose(await da.data(), [
743 -72, -8, -56, 32, 3, 21, -12, -40, 40, 36, 44, 51, -52, -44, -4,
744 61, 49, 13, -2, -10, -108, -9, 0, -1, -24, 60, -6, 49, 26, -40
745 ]);
746 // db = aT * dy
747 expect(db.shape).toEqual(b.shape);
748 expectArraysClose(await db.data(), [
749 -64, -26, -34, -6, -24, 4, -77, -47, 51, -35, 63, -3, 52, -58, -20,
750 23, -12, 20, 60, 70, -68, -80, 14, 10, 44, -11, -32, -10, -46, -68
751 ]);
752 });
it('4d gradients: A x B', async () => {
    // Rank-4 operands: [4, 5, 2, 3] x [4, 5, 3, 2] -> [4, 5, 2, 2].
    const lhs = tf.tensor4d([
        -2, 3, 5, -5, 3, 9, -3, -5, 1, 1, -9, 9, -6, 6, -8,
        -7, -1, 3, 9, -7, -7, 2, 10, -6, -8, -6, 9, -6, 4, -1,
        9, -6, 10, 8, -9, 5, -8, -7, 0, 2, -5, -1, -9, -4, 3,
        -2, 6, -4, 7, 1, -5, -4, 9, -8, -6, -8, 4, -1, 4, 3,
        -7, 8, -7, 5, -3, -2, -4, 9, 2, -1, 1, -10, -3, 5, -4,
        6, -8, -8, 9, -3, -5, 10, 3, -3, -3, 9, 3, -3, 2, -8,
        10, 1, 9, -2, -2, -3, -4, 6, -10, -1, 8, -8, 7, 3, -2,
        3, 6, -2, -2, -4, 1, -5, -4, 0, 5, 1, 9, -8, -2, -1
    ], [4, 5, 2, 3]);
    const rhs = tf.tensor4d([
        -4, -3, -2, -6, 6, -1, -4, -1, 7, -4, 8, -9, -9, 0, -1,
        -4, -6, -7, -3, -4, -7, 6, -8, 1, -2, 1, -1, -3, 8, -5,
        9, -2, 5, 9, -2, 2, -5, -5, -8, -1, -2, -3, -2, -10, 6,
        -3, 0, 1, 6, 7, 1, 2, -4, -5, 2, -5, -7, 9, 3, -6,
        6, 4, -4, 6, 10, -3, -2, 8, 10, -8, 10, -1, -9, -7, -8,
        -3, 1, 1, -2, -9, -7, -6, -1, 0, 7, -9, -7, -5, 0, -4,
        -4, -7, 2, 4, 6, 6, -4, -6, -8, 3, -8, -9, 6, 9, -4,
        1, -1, 0, 8, 9, 0, -5, 3, -1, 5, 0, -10, 7, -2, 6
    ], [4, 5, 3, 2]);
    // Upstream gradient fed into the backward pass.
    const upstream = tf.tensor4d([
        8, -7, 0, -9, -5, -5, 0, 3, 7, -4, 6, -8, -8, 0, -1, -8,
        -9, -7, -4, -9, 2, 3, 5, 8, -5, -7, 3, -10, -5, -9, -5, 1,
        7, 1, -9, -10, 8, 5, 0, 8, -6, 4, 0, -5, 8, -7, -2, 1,
        -8, 9, 9, -7, 1, 7, -2, 5, -2, 9, 1, -5, 7, 5, -7, -6,
        6, 7, -8, 7, 4, -5, 4, -5, 3, -4, -5, 4, -6, 3, -8, 10
    ], [4, 5, 2, 2]);
    const gradFn = tf.grads((x, y) => tf.matMul(x, y, false, false));
    const gradients = gradFn([lhs, rhs], upstream);
    const dLhs = gradients[0];
    const dRhs = gradients[1];
    // da = dy * bT
    const expectedDLhs = [
        -11, 26, 55, 27, 54, 9, 25, -15, 5, -3, -12, -27, -63, 9,
        -14, -54, 26, 20, 24, 56, 64, 35, -41, 0, 11, 30, -37, -1,
        31, 13, 12, 37, 2, 29, 97, 6, 60, 47, 31, 35, -14, 24,
        100, -3, -9, 0, -33, 1, 49, 9, -33, -124, -29, 86, -9, -11,
        -6, -40, 72, -48, -20, 48, -72, -20, -30, 15, -72, 136, 87, 12,
        -28, -21, 9, 37, 1, -32, -51, 2, -65, -49, -1, -41, -16, 2,
        -95, -31, -36, 52, 18, 20, -63, 34, 72, 70, -38, -78, -66, -27,
        -111, -10, 85, 1, -21, -21, -4, -21, -21, -4, -12, 20, 13, -4,
        -20, -19, -30, 81, 30, -40, 150, 76
    ];
    expect(dLhs.shape).toEqual(lhs.shape);
    expectArraysClose(await dLhs.data(), expectedDLhs);
    // db = aT * dy
    const expectedDRhs = [
        -16, 59, 24, -48, 40, -116, 15, 18, 25, -2, -5, 22, -84, 80,
        36, -16, -38, 8, -74, -16, 46, -80, 62, 48, 96, 110, 38, 6,
        -77, -54, 58, 91, -57, -90, 45, 70, 46, 36, 20, 99, -3, 10,
        55, 79, -10, 42, 5, -31, 85, 47, -74, -89, 37, 75, -48, -38,
        -64, -8, 32, 44, 42, -53, -48, 47, 42, -18, -30, 27, 70, -62,
        36, -24, 78, -69, -112, 101, -40, 20, -11, 113, -9, -6, 1, -50,
        3, -12, -16, 71, -14, 67, 84, 62, 21, 17, 84, 63, -16, -35,
        -28, 98, 4, -126, 40, -50, 36, -45, -16, 20, 19, -12, 8, 0,
        3, -4, 34, -65, 10, -17, -46, 17
    ];
    expect(dRhs.shape).toEqual(rhs.shape);
    expectArraysClose(await dRhs.data(), expectedDRhs);
});
it('gradients: A x B^t', async () => {
    // Both operands are [5, 3, 2]; the second is transposed by matMul, so the
    // product (and therefore the upstream gradient) has shape [5, 3, 3].
    const lhs = tf.tensor3d([
        -5, -5, -6, 8, -2, -8, 4, -7, -6, -9, -1, 3, 7, -2, 5,
        -6, 3, 8, 7, -8, 1, 4, -4, 6, 4, -4, -9, -5, 2, -2
    ], [5, 3, 2]);
    const rhs = tf.tensor3d([
        -8, -4, -1, 0, -7, 0, 3, 3, 6, 2, -1, 8, -4, 9, -6,
        5, 8, 9, -9, 7, 0, -1, -1, -10, -7, 3, 4, 6, 3, -4
    ], [5, 3, 2]);
    const upstream = tf.tensor3d([
        -0, 7, 5, 0, -9, 5, -7, 6, -5, -3, -2, -2, -4, 10, -3,
        5, -1, 3, -2, -9, 4, -5, 7, 9, -10, -8, -8, -5, -0, -1,
        3, 3, 4, 9, -7, 6, -2, -9, 5, 1, -5, -3, -1, 9, 4
    ], [5, 3, 3]);
    const gradFn = tf.grads((x, y) => tf.matMul(x, y, false, true));
    const gradients = gradFn([lhs, rhs], upstream);
    const dLhs = gradients[0];
    const dRhs = gradients[1];
    // Gradient with respect to the first operand.
    const expectedDLhs = [
        -42, 0, -26, 0, 85, 28, -19, -29, 51, -16, 6, 37, 94, -27, 50,
        71, 24, -202, 46, -25, -31, -22, -87, 10, -7, -80, -36, -15, 55, 35
    ];
    expect(dLhs.shape).toEqual(lhs.shape);
    expectArraysClose(await dLhs.data(), expectedDLhs);
    // Gradient with respect to the second operand.
    const expectedDRhs = [
        14, 56, 7, -155, -45, 55, 7, 72, -67, -79, 7, 50, -69, -46, -52,
        -88, 49, -126, -68, 106, 31, -30, -27, 60, -19, 5, 27, 43, 55, -13
    ];
    expect(dRhs.shape).toEqual(rhs.shape);
    expectArraysClose(await dRhs.data(), expectedDRhs);
});
it('4d gradients: A x B^t', async () => {
    // Both operands are [4, 5, 3, 2]; the second is transposed by matMul, so
    // the product (and therefore the upstream gradient) is [4, 5, 3, 3].
    const lhs = tf.tensor4d([
        -2, 3, 5, -5, 3, 9, -3, -5, 1, 1, -9, 9, -6, 6, -8,
        -7, -1, 3, 9, -7, -7, 2, 10, -6, -8, -6, 9, -6, 4, -1,
        9, -6, 10, 8, -9, 5, -8, -7, 0, 2, -5, -1, -9, -4, 3,
        -2, 6, -4, 7, 1, -5, -4, 9, -8, -6, -8, 4, -1, 4, 3,
        -7, 8, -7, 5, -3, -2, -4, 9, 2, -1, 1, -10, -3, 5, -4,
        6, -8, -8, 9, -3, -5, 10, 3, -3, -3, 9, 3, -3, 2, -8,
        10, 1, 9, -2, -2, -3, -4, 6, -10, -1, 8, -8, 7, 3, -2,
        3, 6, -2, -2, -4, 1, -5, -4, 0, 5, 1, 9, -8, -2, -1
    ], [4, 5, 3, 2]);
    const rhs = tf.tensor4d([
        -4, -3, -2, -6, 6, -1, -4, -1, 7, -4, 8, -9, -9, 0, -1,
        -4, -6, -7, -3, -4, -7, 6, -8, 1, -2, 1, -1, -3, 8, -5,
        9, -2, 5, 9, -2, 2, -5, -5, -8, -1, -2, -3, -2, -10, 6,
        -3, 0, 1, 6, 7, 1, 2, -4, -5, 2, -5, -7, 9, 3, -6,
        6, 4, -4, 6, 10, -3, -2, 8, 10, -8, 10, -1, -9, -7, -8,
        -3, 1, 1, -2, -9, -7, -6, -1, 0, 7, -9, -7, -5, 0, -4,
        -4, -7, 2, 4, 6, 6, -4, -6, -8, 3, -8, -9, 6, 9, -4,
        1, -1, 0, 8, 9, 0, -5, 3, -1, 5, 0, -10, 7, -2, 6
    ], [4, 5, 3, 2]);
    const upstream = tf.tensor4d([
        5, -1, -5, -4, -1, 9, 1, -2, 10, 7, -1, 6, -8, 8, -3,
        9, -4, 2, -4, -8, 8, 4, 8, -10, -8, -8, 6, 6, -5, 9,
        -1, -7, -5, -3, -3, 2, -6, 5, 8, -9, 5, -8, -3, 8, 6,
        2, 8, 5, 9, 7, 6, 2, -3, 10, 7, 7, -3, 4, -3, -6,
        -8, -8, 9, 0, -8, -3, -2, -2, 8, 2, 3, -6, 3, 6, -3,
        7, 7, -9, -3, 8, 7, 7, -1, -6, 5, 2, -1, -1, 1, 5,
        0, -4, 3, -4, -10, 1, -2, -8, -9, -6, 4, 4, -7, -1, -1,
        -9, 7, 1, -1, 8, 0, -2, -7, 5, 7, 8, 9, -3, -8, -6,
        -7, -8, -1, 8, -4, 7, 5, -9, 9, 3, 0, -10, 7, -9, 4,
        -7, 5, -2, -2, 3, 3, -6, 2, 0, 8, -5, -10, 3, -7, 0,
        -6, 2, 3, -1, 3, 3, -10, 1, 3, -7, -1, 8, -2, -1, -1,
        -3, -9, 7, 4, -6, 3, 0, -7, -4, -5, -8, -6, 10, -6, 4
    ], [4, 5, 3, 3]);
    const gradFn = tf.grads((x, y) => tf.matMul(x, y, false, true));
    const gradients = gradFn([lhs, rhs], upstream);
    const dLhs = gradients[0];
    const dRhs = gradients[1];
    // Gradient with respect to the first operand.
    const expectedDLhs = [
        -48, -4, 72, 9, 60, -1, 13, -57, 64, 3, -48, -11, -4, -24,
        16, 38, 44, -10, -55, -45, 92, -43, 14, -4, 71, -61, -51, 16,
        46, -57, 48, 78, 104, 57, -17, -11, -85, -33, 16, 1, 86, 21,
        -48, 21, -8, 34, 14, -35, 36, 48, 85, 108, -38, -40, 3, -8,
        -7, -1, 6, -16, 46, -33, 26, -79, -70, -29, 92, -84, -6, -47,
        98, -129, -55, -17, 79, 40, -118, -64, 68, 75, 71, 111, 5, -48,
        98, -36, 21, 13, 112, -34, 26, 57, 32, 44, 28, 50, 88, 27,
        44, -39, -16, 15, -21, -6, -67, -89, -46, -64, -19, -12, -3, 11,
        41, 63, 78, -73, 67, -92, 102, -18
    ];
    expect(dLhs.shape).toEqual(lhs.shape);
    expectArraysClose(await dLhs.data(), expectedDLhs);
    // Gradient with respect to the second operand.
    const expectedDRhs = [
        -27, 44, -9, -16, 85, 30, -110, 38, 47, -23, -39, -15, 0, -76,
        -8, -128, 26, 136, 31, -26, -26, 39, 136, -85, -45, 93, 37, -68,
        -112, -6, 90, 70, 169, -7, 15, 68, -16, -33, -16, -47, -21, 0,
        6, -4, 84, 24, 15, 20, -41, -1, 79, -86, 87, -23, -26, -64,
        18, 9, 52, 64, 34, -16, 122, -66, -1, 47, 1, 43, -11, -33,
        -17, 27, -45, -73, -60, -66, -92, -42, 32, -85, -44, -44, -28, -13,
        8, -20, 9, -9, -49, 79, -76, 15, 73, -7, 7, -8, -110, 93,
        106, -39, 64, -84, -29, -19, 13, 14, 63, 2, -15, 23, 17, 49,
        -3, -31, -65, 30, -95, 63, -82, 40
    ];
    expect(dRhs.shape).toEqual(rhs.shape);
    expectArraysClose(await dRhs.data(), expectedDRhs);
});
it('gradients: A^t x B', async () => {
    // Both operands are [5, 3, 2]; the first is transposed by matMul, so the
    // product (and therefore the upstream gradient) has shape [5, 2, 2].
    const lhs = tf.tensor3d([
        -5, -5, -6, 8, -2, -8, 4, -7, -6, -9, -1, 3, 7, -2, 5,
        -6, 3, 8, 7, -8, 1, 4, -4, 6, 4, -4, -9, -5, 2, -2
    ], [5, 3, 2]);
    const rhs = tf.tensor3d([
        -8, -4, -1, 0, -7, 0, 3, 3, 6, 2, -1, 8, -4, 9, -6,
        5, 8, 9, -9, 7, 0, -1, -1, -10, -7, 3, 4, 6, 3, -4
    ], [5, 3, 2]);
    const upstream = tf.tensor3d([8, 2, -3, -2, -8, 4, 5, 7, 4, -4, -4, 5, 8, 10, 1, 0, 6, 6, -4, 7], [5, 2, 2]);
    const gradFn = tf.grads((x, y) => tf.matMul(x, y, true, false));
    const gradients = gradFn([lhs, rhs], upstream);
    const dLhs = gradients[0];
    const dRhs = gradients[1];
    // Gradient with respect to the first operand.
    const expectedDLhs = [
        -72, 32, -8, 3, -56, 21, -12, 36, -40, 44, 40, 51, -52, 61, -44,
        49, -4, 13, -2, -9, -10, 0, -108, -1, -24, 49, 60, 26, -6, -40
    ];
    expect(dLhs.shape).toEqual(lhs.shape);
    expectArraysClose(await dLhs.data(), expectedDLhs);
    // Gradient with respect to the second operand.
    const expectedDRhs = [
        -25, 0, -72, -28, 8, 12, -67, -33, 3, -87, 23, 17, 36, -38, 44,
        -50, -20, 28, 48, 70, 12, 10, -26, -40, 40, -4, -34, -89, 20, -2
    ];
    expect(dRhs.shape).toEqual(rhs.shape);
    expectArraysClose(await dRhs.data(), expectedDRhs);
});
it('gradients: A^t x B^t', async () => {
    // Operands are [5, 3, 2] and [5, 2, 3]; matMul transposes both, so the
    // product (and therefore the upstream gradient) has shape [5, 2, 2].
    const lhs = tf.tensor3d([
        -5, -5, -6, 8, -2, -8, 4, -7, -6, -9, -1, 3, 7, -2, 5,
        -6, 3, 8, 7, -8, 1, 4, -4, 6, 4, -4, -9, -5, 2, -2
    ], [5, 3, 2]);
    const rhs = tf.tensor3d([
        -8, -4, -1, 0, -7, 0, 3, 3, 6, 2, -1, 8, -4, 9, -6,
        5, 8, 9, -9, 7, 0, -1, -1, -10, -7, 3, 4, 6, 3, -4
    ], [5, 2, 3]);
    const upstream = tf.tensor3d([8, 2, -3, -2, -8, 4, 5, 7, 4, -4, -4, 5, 8, 10, 1, 0, 6, 6, -4, 7], [5, 2, 2]);
    const gradFn = tf.grads((x, y) => tf.matMul(x, y, true, true));
    const gradients = gradFn([lhs, rhs], upstream);
    const dLhs = gradients[0];
    const dRhs = gradients[1];
    // Gradient with respect to the first operand.
    const expectedDLhs = [
        -64, 24, -46, 26, -8, 3, -16, 29, -28, 8, -16, 86, -36, 41, 4,
        4, -60, 69, -82, -9, 46, 7, -100, 0, -6, 70, 36, 9, 0, -44
    ];
    expect(dLhs.shape).toEqual(lhs.shape);
    expectArraysClose(await dLhs.data(), expectedDLhs);
    // Gradient with respect to the second operand.
    const expectedDRhs = [
        -25, -72, 8, 0, -28, 12, -67, 3, 23, -33, -87, 17, 36, 44, -20,
        -38, -50, 28, 48, 12, -26, 70, 10, -40, 40, -34, 20, -4, -89, -2
    ];
    expect(dRhs.shape).toEqual(rhs.shape);
    expectArraysClose(await dRhs.data(), expectedDRhs);
});
945});
describeWithFlags('dot', ALL_ENVS, () => {
    // Shared fixtures, rebuilt before every spec by the beforeEach hook.
    let vec; // 1-D, two elements
    let square; // 2-D, 2 x 2
    let wide; // 2-D, 2 x 3
    let rank3; // 3-D, shape [1, 1, 2]
    let scalar; // 0-D
    let rank3Pair; // 3-D, shape [2, 1, 2]
    beforeEach(() => {
        vec = tf.tensor1d([1, 2]);
        square = tf.tensor2d([[1, 2], [3, 4]]);
        wide = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
        rank3 = tf.tensor3d([1, 2], [1, 1, 2]);
        scalar = tf.scalar(1);
        rank3Pair = tf.tensor3d([1, 2, 1, 2], [2, 1, 2]);
    });
    it('vector-vector', async () => {
        const product = tf.dot(vec, vec);
        expectArraysClose(await product.data(), [5]);
        expect(product.shape).toEqual([]);
    });
    it('vector-matrix', async () => {
        const vecTimesSquare = tf.dot(vec, square);
        const vecTimesWide = tf.dot(vec, wide);
        expect(vecTimesSquare.shape).toEqual([2]);
        expect(vecTimesWide.shape).toEqual([3]);
        expectArraysClose(await vecTimesSquare.data(), [7, 10]);
        expectArraysClose(await vecTimesWide.data(), [9, 12, 15]);
    });
    it('matrix-vector', async () => {
        // Exercises the chained tensor method rather than tf.dot.
        const squareTimesVec = square.dot(vec);
        expect(squareTimesVec.shape).toEqual([2]);
        expectArraysClose(await squareTimesVec.data(), [5, 11]);
    });
    it('matrix-matrix', async () => {
        const squareTimesSquare = tf.dot(square, square);
        const squareTimesWide = tf.dot(square, wide);
        expect(squareTimesSquare.shape).toEqual([2, 2]);
        expect(squareTimesWide.shape).toEqual([2, 3]);
        expectArraysClose(await squareTimesSquare.data(), [7, 10, 15, 22]);
        expectArraysClose(await squareTimesWide.data(), [9, 12, 15, 19, 26, 33]);
    });
    it('matmul A x B asymmetric', async () => {
        const left = tf.tensor2d([1, 2, 3, 4], [2, 2]);
        const right = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
        const product = tf.matMul(left, right);
        // The values are read back before the shape is asserted.
        const productValues = await product.data();
        expect(product.shape).toEqual([2, 3]);
        expectArraysClose(productValues, [9, 12, 15, 19, 26, 33]);
    });
    it('throws error on incompatible dimensions', () => {
        const dotWideWithRank3Pair = () => tf.dot(wide, rank3Pair);
        expect(dotWideWithRank3Pair).toThrowError();
    });
    it('throws error when inputs are not rank 1 or 2', () => {
        // Second operand is rank 3 in the first call and rank 0 in the second.
        const dotWithRank3 = () => tf.dot(vec, rank3);
        const dotWithScalar = () => tf.dot(vec, scalar);
        expect(dotWithRank3).toThrowError();
        expect(dotWithScalar).toThrowError();
    });
    it('accepts a tensor-like object', async () => {
        // A plain array is passed instead of a tensor.
        const values = [1, 2, 3];
        const product = tf.dot(values, values);
        expectArraysClose(await product.data(), [14]);
        expect(product.shape).toEqual([]);
    });
    it('throws error for string tensors', () => {
        const dotStrings = () => tf.dot('a', 'b');
        expect(dotStrings)
            .toThrowError(/Argument 't1' passed to 'dot' must be numeric tensor/);
    });
    it('ensure no memory leak', async () => {
        // Snapshot the tensor and data-id counts, run a matMul, dispose every
        // tensor created here, then check both counts are back at the snapshot.
        const tensorsAtStart = tf.memory().numTensors;
        const dataIdsAtStart = tf.engine().backend.numDataIds();
        const left = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
        const right = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
        const product = tf.matMul(left, right);
        expect(product.shape).toEqual([2, 2]);
        expectArraysClose(await product.data(), [0, 8, -3, 20]);
        left.dispose();
        right.dispose();
        product.dispose();
        const tensorsAtEnd = tf.memory().numTensors;
        const dataIdsAtEnd = tf.engine().backend.numDataIds();
        expect(tensorsAtEnd).toBe(tensorsAtStart);
        expect(dataIdsAtEnd).toBe(dataIdsAtStart);
    });
});
1029//# sourceMappingURL=mat_mul_test.js.map
\No newline at end of file