// Retrieved from UNPKG: 24.2 kB JavaScript (raw view).

/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import * as tf from '../index';
import { backend } from '../index';
import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
import { expectArraysClose, expectNumbersClose } from '../test_util';
/**
 * Specs for `tf.square`, registered for every test environment in `ALL_ENVS`.
 *
 * Covers:
 *  - element-wise squaring of rank 1, 2, 5 and 6 tensors (values and shape);
 *  - NaN propagation;
 *  - int32 input keeping its dtype (checked only on 32-bit-float backends);
 *  - gradients, i.e. d/dx x^2 = 2x scaled by the upstream gradient `dy`,
 *    including when the input/output pass through `clone()`;
 *  - argument validation for non-tensor, tensor-like and string inputs.
 */
describeWithFlags('square', ALL_ENVS, () => {
  it('1D array', async () => {
    const a = tf.tensor1d([2, 4, Math.sqrt(2)]);
    const r = tf.square(a);
    expectArraysClose(await r.data(), [4, 16, 2]);
  });

  it('2D array', async () => {
    const a = tf.tensor2d([1, 2, Math.sqrt(2), Math.sqrt(3)], [2, 2]);
    const r = tf.square(a);
    expect(r.shape).toEqual([2, 2]);
    expectArraysClose(await r.data(), [1, 4, 2, 3]);
  });

  it('5D array', async () => {
    const a = tf.tensor5d([1, 2, Math.sqrt(2), Math.sqrt(3)], [1, 1, 2, 2, 1]);
    const r = tf.square(a);
    expect(r.shape).toEqual([1, 1, 2, 2, 1]);
    expectArraysClose(await r.data(), [1, 4, 2, 3]);
  });

  it('6D array', async () => {
    const a = tf.tensor6d(
      [1, 2, Math.sqrt(2), Math.sqrt(3), 3, 4, Math.sqrt(7), Math.sqrt(13)],
      [1, 1, 2, 2, 2, 1]);
    const r = tf.square(a);
    expect(r.shape).toEqual(a.shape);
    expectArraysClose(await r.data(), [1, 4, 2, 3, 9, 16, 7, 13]);
  });

  it('square propagates NaNs', async () => {
    const a = tf.tensor1d([1.5, NaN]);
    const r = tf.square(a);
    expectArraysClose(await r.data(), [2.25, NaN]);
  });

  it('int32', async () => {
    // Only checked when the active backend reports 32-bit float precision;
    // on any other backend this spec runs no expectations at all.
    // TODO: Use skip() instead when it is implemented
    if (backend() && backend().floatPrecision() === 32) {
      const a = tf.tensor1d([2, 4, 40000], 'int32');
      const r = tf.square(a);
      expect(r.dtype).toEqual('int32');
      const data = await r.data();
      expectNumbersClose(data[0], 4);
      expectNumbersClose(data[1], 16);
      // Epsilon must be larger here for webgl1
      // TODO: Use expectArraysClose when it supports epsilons scaled by the
      // numbers being compared.
      expectNumbersClose(data[2], 1600000000, 1000 /* epsilon */);
    }
  });

  // Gradient specs: for y = x^2 the expected input gradient is 2 * x * dy.
  it('gradients: Scalar', async () => {
    const a = tf.scalar(5);
    const dy = tf.scalar(8);
    const gradients = tf.grad((x) => tf.square(x))(a, dy);
    expect(gradients.shape).toEqual(a.shape);
    expect(gradients.dtype).toEqual('float32');
    expectArraysClose(await gradients.data(), [2 * 5 * 8]);
  });

  // Same as the scalar case, but the gradient must flow through clone() on
  // both the input and the output. Previously this spec reused the title
  // 'gradients: Scalar', which made reports ambiguous.
  it('gradient with clones', async () => {
    const a = tf.scalar(5);
    const dy = tf.scalar(8);
    const gradients = tf.grad((x) => tf.square(x.clone()).clone())(a, dy);
    expect(gradients.shape).toEqual(a.shape);
    expect(gradients.dtype).toEqual('float32');
    expectArraysClose(await gradients.data(), [2 * 5 * 8]);
  });

  it('gradients: Tensor1D', async () => {
    const a = tf.tensor1d([-1, 2, 3, -5]);
    const dy = tf.tensor1d([1, 2, 3, 4]);
    const gradients = tf.grad((x) => tf.square(x))(a, dy);
    expect(gradients.shape).toEqual(a.shape);
    expect(gradients.dtype).toEqual('float32');
    expectArraysClose(await gradients.data(), [-2 * 1, 4 * 2, 6 * 3, -10 * 4]);
  });

  it('gradients: Tensor2D', async () => {
    const a = tf.tensor2d([-3, 1, 2, 3], [2, 2]);
    const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]);
    const gradients = tf.grad((x) => tf.square(x))(a, dy);
    expect(gradients.shape).toEqual(a.shape);
    expect(gradients.dtype).toEqual('float32');
    expectArraysClose(await gradients.data(), [-6 * 1, 2 * 2, 4 * 3, 6 * 4]);
  });

  it('gradients: Tensor5D', async () => {
    const a = tf.tensor5d([-3, 1, 2, 3], [1, 1, 1, 2, 2]);
    const dy = tf.tensor5d([1, 2, 3, 4], [1, 1, 1, 2, 2]);
    const gradients = tf.grad((x) => tf.square(x))(a, dy);
    expect(gradients.shape).toEqual(a.shape);
    expect(gradients.dtype).toEqual('float32');
    expectArraysClose(await gradients.data(), [-6 * 1, 2 * 2, 4 * 3, 6 * 4]);
  });

  it('gradients: Tensor6D', async () => {
    const a = tf.tensor6d([-3, 1, 2, 3, -4, 5, 12, 3], [1, 1, 1, 2, 2, 2]);
    const dy = tf.tensor6d([1, 2, 3, 4, 5, 6, 7, 8], [1, 1, 1, 2, 2, 2]);
    const gradients = tf.grad((x) => tf.square(x))(a, dy);
    expect(gradients.shape).toEqual(a.shape);
    expect(gradients.dtype).toEqual('float32');
    expectArraysClose(
      await gradients.data(),
      [-6 * 1, 2 * 2, 4 * 3, 6 * 4, -8 * 5, 10 * 6, 24 * 7, 6 * 8]);
  });

  it('throws when passed a non-tensor', () => {
    expect(() => tf.square({}))
      .toThrowError(/Argument 'x' passed to 'square' must be a Tensor/);
  });

  it('accepts a tensor-like object', async () => {
    const r = tf.square([2, 4, Math.sqrt(2)]);
    expectArraysClose(await r.data(), [4, 16, 2]);
  });

  it('throws for string tensor', () => {
    expect(() => tf.square('q'))
      .toThrowError(/Argument 'x' passed to 'square' must be numeric/);
  });
});
//# sourceMappingURL=data:application/json;base64,