1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 |
|
12 |
|
13 |
|
14 |
|
15 |
|
16 |
|
17 | import * as tf from '../index';
|
18 | import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
|
19 | import { expectArraysClose } from '../test_util';
|
describeWithFlags('broadcastTo', ALL_ENVS, () => {
  /**
   * Checks `tf.broadcastTo` against an explicitly tiled tensor, in both the
   * forward pass and the gradient.
   *
   * Forward: `tf.broadcastTo(input, expected.shape)` must equal `expected`.
   *
   * Backward: the gradient of a weighted mean over the explicitly broadcast
   * tensor must equal the gradient of the same weighted mean when `tf.mul`
   * broadcasts `input` implicitly. The explicit gradient must also come back
   * in `input`'s shape, i.e. the upstream gradient is reduced over the
   * broadcast dimensions.
   *
   * @param {tf.Tensor} input Tensor to broadcast.
   * @param {tf.Tensor} expected Expected broadcast result; its shape is the
   *     broadcast target. Every case in this suite targets [3, 2], which is
   *     the shape of the fixed weights used for the gradient check.
   */
  async function expectBroadcastToMatches(input, expected) {
    const broadcast = tf.broadcastTo(input, expected.shape);
    // expectArraysClose takes (actual, expected).
    expectArraysClose(await broadcast.array(), await expected.array());

    // Non-uniform, mixed-sign weights: each broadcast copy of an input
    // element contributes a different amount, so a wrong reduction in the
    // gradient changes the result.
    const weights = tf.tensor2d([[4.7, 4.5], [-6.1, -6.6], [-8.1, -3.4]]);
    const explicitLoss = (x) =>
        tf.mean(tf.mul(tf.broadcastTo(x, expected.shape), weights));
    const implicitLoss = (x) => tf.mean(tf.mul(x, weights));

    const explicitGrad = tf.grad(explicitLoss)(input);
    const implicitGrad = tf.grad(implicitLoss)(input);

    expect(explicitGrad.shape).toEqual(input.shape);
    expectArraysClose(await explicitGrad.array(), await implicitGrad.array());
  }

  it('[] -> [3,2]', async () => {
    await expectBroadcastToMatches(
        tf.scalar(4.2),
        tf.tensor2d([[4.2, 4.2], [4.2, 4.2], [4.2, 4.2]]));
  });

  it('[2] -> [3,2]', async () => {
    await expectBroadcastToMatches(
        tf.tensor1d([1, 2]),
        tf.tensor2d([[1, 2], [1, 2], [1, 2]]));
  });

  it('[3,1] -> [3,2]', async () => {
    await expectBroadcastToMatches(
        tf.tensor2d([[1], [2], [3]]),
        tf.tensor2d([[1, 1], [2, 2], [3, 3]]));
  });
});
|
49 |
|
\ | No newline at end of file |