UNPKG

2.6 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
17import * as tf from '../index';
18import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
19import { expectArraysClose } from '../test_util';
/**
 * Tests for `tf.broadcastTo`, registered through `describeWithFlags` with
 * `ALL_ENVS` (presumably every test environment/backend — see jasmine_util).
 *
 * Every case checks two things:
 *   1. Forward value: broadcasting the input to the expected tensor's shape
 *      yields exactly the expected values.
 *   2. Gradient: the gradient of mean(broadcastTo(x, shape) * weights) with
 *      respect to x equals the gradient of mean(x * weights), where `mul`
 *      broadcasts x implicitly. This cross-checks the explicit broadcastTo
 *      gradient against the implicit-broadcasting gradient of `mul`.
 */
describeWithFlags('broadcastTo', ALL_ENVS, () => {
  /**
   * Asserts that `tf.broadcastTo(input, expected.shape)` produces `expected`
   * and that its gradient matches the gradient of implicit broadcasting.
   *
   * @param {tf.Tensor} input - tensor to broadcast.
   * @param {tf.Tensor2D} expected - the broadcast result. Its shape must be
   *     [3, 2] to match `weights` below.
   * @returns {Promise<void>} resolves once both assertions have run.
   */
  const expectBroadcastMatches = async (input, expected) => {
    // expectArraysClose takes (actual, expected); the broadcast output is the
    // actual value under test.
    const broadcasted = tf.broadcastTo(input, expected.shape);
    expectArraysClose(await broadcasted.array(), await expected.array());

    // Non-uniform weights make the gradient sensitive to how the broadcast
    // copies are summed back into the input's shape.
    const weights = tf.tensor2d([[4.7, 4.5], [-6.1, -6.6], [-8.1, -3.4]]);
    const explicitLoss = (x) =>
        tf.broadcastTo(x, expected.shape).mul(weights).mean().asScalar();
    const implicitLoss = (x) => x.mul(weights).mean().asScalar();

    const explicitGrad = tf.grad(explicitLoss)(input);
    const implicitGrad = tf.grad(implicitLoss)(input);
    expectArraysClose(await explicitGrad.array(), await implicitGrad.array());
  };

  it('[] -> [3,2]', async () => {
    await expectBroadcastMatches(
        tf.scalar(4.2), tf.tensor2d([[4.2, 4.2], [4.2, 4.2], [4.2, 4.2]]));
  });

  it('[2] -> [3,2]', async () => {
    await expectBroadcastMatches(
        tf.tensor1d([1, 2]), tf.tensor2d([[1, 2], [1, 2], [1, 2]]));
  });

  it('[3,1] -> [3,2]', async () => {
    await expectBroadcastMatches(
        tf.tensor2d([[1], [2], [3]]), tf.tensor2d([[1, 1], [2, 2], [3, 3]]));
  });
});
//# sourceMappingURL=broadcast_to_test.js.map
\No newline at end of file