1 | /**
|
2 | * @license
|
3 | * Copyright 2020 Google LLC. All Rights Reserved.
|
4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | * you may not use this file except in compliance with the License.
|
6 | * You may obtain a copy of the License at
|
7 | *
|
8 | * http://www.apache.org/licenses/LICENSE-2.0
|
9 | *
|
10 | * Unless required by applicable law or agreed to in writing, software
|
11 | * distributed under the License is distributed on an "AS IS" BASIS,
|
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 | * See the License for the specific language governing permissions and
|
14 | * limitations under the License.
|
15 | * =============================================================================
|
16 | */
|
17 | import * as tf from '../index';
|
18 | import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
|
19 | import { expectArraysClose, expectNumbersClose } from '../test_util';
|
20 | import { backend } from '../index';
|
/**
 * Spec suite for `tf.square`, registered through `describeWithFlags` with the
 * `ALL_ENVS` constraint.
 *
 * Sections:
 *   - forward values for 1D/2D/5D/6D inputs, NaN propagation and int32 input;
 *   - gradients, where the expected value is `2 * a * dy` element-wise;
 *   - argument validation (non-tensor, tensor-like array, string input).
 */
describeWithFlags('square', ALL_ENVS, () => {
  // ---------------------------------------------------------------------
  // Forward pass
  // ---------------------------------------------------------------------

  it('1D array', async () => {
    const input = tf.tensor1d([2, 4, Math.sqrt(2)]);
    const squared = tf.square(input);
    expectArraysClose(await squared.data(), [4, 16, 2]);
  });

  it('2D array', async () => {
    const input = tf.tensor2d([1, 2, Math.sqrt(2), Math.sqrt(3)], [2, 2]);
    const squared = tf.square(input);
    // The op must keep the input shape.
    expect(squared.shape).toEqual([2, 2]);
    expectArraysClose(await squared.data(), [1, 4, 2, 3]);
  });

  it('5D array', async () => {
    const input =
        tf.tensor5d([1, 2, Math.sqrt(2), Math.sqrt(3)], [1, 1, 2, 2, 1]);
    const squared = tf.square(input);
    expect(squared.shape).toEqual([1, 1, 2, 2, 1]);
    expectArraysClose(await squared.data(), [1, 4, 2, 3]);
  });

  it('6D array', async () => {
    const values =
        [1, 2, Math.sqrt(2), Math.sqrt(3), 3, 4, Math.sqrt(7), Math.sqrt(13)];
    const input = tf.tensor6d(values, [1, 1, 2, 2, 2, 1]);
    const squared = tf.square(input);
    expect(squared.shape).toEqual(input.shape);
    expectArraysClose(await squared.data(), [1, 4, 2, 3, 9, 16, 7, 13]);
  });

  it('square propagates NaNs', async () => {
    const input = tf.tensor1d([1.5, NaN]);
    const squared = tf.square(input);
    expectArraysClose(await squared.data(), [2.25, NaN]);
  });

  it('int32', async () => {
    // The assertions only run when the active backend reports 32-bit float
    // precision; otherwise the spec finishes without any expectation.
    // NOTE(review): presumably 40000^2 = 1.6e9 is not representable accurately
    // at lower precision - confirm.
    // TODO: Use skip() instead when it is implemented
    const activeBackend = backend();
    if (!activeBackend || activeBackend.floatPrecision() !== 32) {
      return;
    }

    const input = tf.tensor1d([2, 4, 40000], 'int32');
    const squared = tf.square(input);
    // Squaring must not change the dtype.
    expect(squared.dtype).toEqual('int32');

    const data = await squared.data();
    expectNumbersClose(data[0], 4);
    expectNumbersClose(data[1], 16);
    // Epsilon must be larger here for webgl1
    // TODO: Use expectArraysClose when it supports epsilons scaled by the
    // numbers being compared.
    expectNumbersClose(data[2], 1600000000, 1000 /* epsilon */);
  });

  // ---------------------------------------------------------------------
  // Gradients: each expected entry is 2 * a[i] * dy[i].
  // ---------------------------------------------------------------------

  it('gradients: Scalar', async () => {
    const a = tf.scalar(5);
    const dy = tf.scalar(8);

    const gradients = tf.grad(x => tf.square(x))(a, dy);

    expect(gradients.shape).toEqual(a.shape);
    expect(gradients.dtype).toEqual('float32');
    expectArraysClose(await gradients.data(), [2 * 5 * 8]);
  });

  // Same computation as the scalar spec above, but with clone() applied to
  // both the input and the output of square inside the differentiated
  // function. Titled distinctly so a failure report identifies which of the
  // two scalar specs broke.
  it('gradient with clones', async () => {
    const a = tf.scalar(5);
    const dy = tf.scalar(8);

    const gradients = tf.grad(x => tf.square(x.clone()).clone())(a, dy);

    expect(gradients.shape).toEqual(a.shape);
    expect(gradients.dtype).toEqual('float32');
    expectArraysClose(await gradients.data(), [2 * 5 * 8]);
  });

  it('gradients: Tensor1D', async () => {
    const a = tf.tensor1d([-1, 2, 3, -5]);
    const dy = tf.tensor1d([1, 2, 3, 4]);

    const gradients = tf.grad(x => tf.square(x))(a, dy);

    expect(gradients.shape).toEqual(a.shape);
    expect(gradients.dtype).toEqual('float32');
    expectArraysClose(await gradients.data(), [-2, 4 * 2, 6 * 3, -10 * 4]);
  });

  it('gradients: Tensor2D', async () => {
    const a = tf.tensor2d([-3, 1, 2, 3], [2, 2]);
    const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]);

    const gradients = tf.grad(x => tf.square(x))(a, dy);

    expect(gradients.shape).toEqual(a.shape);
    expect(gradients.dtype).toEqual('float32');
    expectArraysClose(await gradients.data(), [-6 * 1, 2 * 2, 4 * 3, 6 * 4]);
  });

  it('gradients: Tensor5D', async () => {
    const a = tf.tensor5d([-3, 1, 2, 3], [1, 1, 1, 2, 2]);
    const dy = tf.tensor5d([1, 2, 3, 4], [1, 1, 1, 2, 2]);

    const gradients = tf.grad(x => tf.square(x))(a, dy);

    expect(gradients.shape).toEqual(a.shape);
    expect(gradients.dtype).toEqual('float32');
    expectArraysClose(await gradients.data(), [-6 * 1, 2 * 2, 4 * 3, 6 * 4]);
  });

  it('gradients: Tensor6D', async () => {
    const a = tf.tensor6d([-3, 1, 2, 3, -4, 5, 12, 3], [1, 1, 1, 2, 2, 2]);
    const dy = tf.tensor6d([1, 2, 3, 4, 5, 6, 7, 8], [1, 1, 1, 2, 2, 2]);

    const gradients = tf.grad(x => tf.square(x))(a, dy);

    expect(gradients.shape).toEqual(a.shape);
    expect(gradients.dtype).toEqual('float32');
    expectArraysClose(
        await gradients.data(),
        [-6 * 1, 2 * 2, 4 * 3, 6 * 4, -8 * 5, 10 * 6, 24 * 7, 6 * 8]);
  });

  // ---------------------------------------------------------------------
  // Argument validation
  // ---------------------------------------------------------------------

  it('throws when passed a non-tensor', () => {
    const squareOfPlainObject = () => tf.square({});
    expect(squareOfPlainObject)
        .toThrowError(/Argument 'x' passed to 'square' must be a Tensor/);
  });

  it('accepts a tensor-like object', async () => {
    // A plain number array is passed directly instead of a tensor.
    const squared = tf.square([2, 4, Math.sqrt(2)]);
    expectArraysClose(await squared.data(), [4, 16, 2]);
  });

  it('throws for string tensor', () => {
    const squareOfString = () => tf.square('q');
    expect(squareOfString)
        .toThrowError(/Argument 'x' passed to 'square' must be numeric/);
  });
});
|
126 | //# sourceMappingURL=data:application/json;base64,{"version":3,"file":"square_test.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/ops/square_test.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,KAAK,EAAE,MAAM,UAAU,CAAC;AAC/B,OAAO,EAAC,QAAQ,EAAE,iBAAiB,EAAC,MAAM,iBAAiB,CAAC;AAC5D,OAAO,EAAC,iBAAiB,EAAE,kBAAkB,EAAC,MAAM,cAAc,CAAC;AACnE,OAAO,EAAE,OAAO,EAAE,MAAM,UAAU,CAAC;AAEnC,iBAAiB,CAAC,QAAQ,EAAE,QAAQ,EAAE,GAAG,EAAE;IACzC,EAAE,CAAC,UAAU,EAAE,KAAK,IAAI,EAAE;QACxB,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAC5C,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC;QACvB,iBAAiB,CAAC,MAAM,CAAC,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC;IAChD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,UAAU,EAAE,KAAK,IAAI,EAAE;QACxB,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAClE,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC;QACvB,MAAM,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAChC,iBAAiB,CAAC,MAAM,CAAC,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IAClD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,UAAU,EAAE,KAAK,IAAI,EAAE;QACxB,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC3E,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC;QACvB,MAAM,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACzC,iBAAiB,CAAC,MAAM,CAAC,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IAClD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,UAAU,EAAE,KAAK,IAAI,EAAE;QACxB,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CACjB,CAAC,CAAC,EAAE,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,IA
AI,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC,EACrE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACxB,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC;QACvB,MAAM,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QACjC,iBAAiB,CAAC,MAAM,CAAC,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;IAChE,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,wBAAwB,EAAE,KAAK,IAAI,EAAE;QACtC,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAClC,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC;QACvB,iBAAiB,CAAC,MAAM,CAAC,CAAC,IAAI,EAAE,EAAE,CAAC,IAAI,EAAE,GAAG,CAAC,CAAC,CAAC;IACjD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,OAAO,EAAE,KAAK,IAAI,EAAE;QACrB,IAAI,OAAO,EAAE,IAAI,OAAO,EAAE,CAAC,cAAc,EAAE,KAAK,EAAE,EAAE;YAClD,kDAAkD;YAClD,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,KAAK,CAAC,EAAE,OAAO,CAAC,CAAC;YAC9C,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC;YACvB,MAAM,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,OAAO,CAAC,CAAC;YACjC,MAAM,IAAI,GAAG,MAAM,CAAC,CAAC,IAAI,EAAE,CAAC;YAC5B,kBAAkB,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;YAC/B,kBAAkB,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC;YAChC,yCAAyC;YACzC,sEAAsE;YACtE,0BAA0B;YAC1B,kBAAkB,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,UAAa,EAAE,IAAK,CAAC,aAAa,CAAC,CAAC;SACjE;IACH,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,mBAAmB,EAAE,KAAK,IAAI,EAAE;QACjC,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC;QACvB,MAAM,EAAE,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC;QAExB,MAAM,SAAS,GAAG,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC;QAEpD,MAAM,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QACzC,MAAM,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,SAAS,CAAC,CAAC;QAC3C,iBAAiB,CAAC,MAAM,SAAS,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC;IACzD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,mBAAmB,EAAE,KAAK,IAAI,EAAE;QACjC,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,C
AAC;QACvB,MAAM,EAAE,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC;QAExB,MAAM,SAAS,GAAG,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC;QAEpE,MAAM,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QACzC,MAAM,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,SAAS,CAAC,CAAC;QAC3C,iBAAiB,CAAC,MAAM,SAAS,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC;IACzD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,qBAAqB,EAAE,KAAK,IAAI,EAAE;QACnC,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;QACtC,MAAM,EAAE,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAErC,MAAM,SAAS,GAAG,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC;QAEpD,MAAM,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QACzC,MAAM,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,SAAS,CAAC,CAAC;QAC3C,iBAAiB,CAAC,MAAM,SAAS,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,GAAG,CAAC,CAAC,CAAC,CAAC;IACzE,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,qBAAqB,EAAE,KAAK,IAAI,EAAE;QACnC,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC7C,MAAM,EAAE,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAE7C,MAAM,SAAS,GAAG,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC;QAEpD,MAAM,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QACzC,MAAM,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,SAAS,CAAC,CAAC;QAC3C,iBAAiB,CAAC,MAAM,SAAS,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC;IAC3E,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,qBAAqB,EAAE,KAAK,IAAI,EAAE;QACnC,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAA
C,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACtD,MAAM,EAAE,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAEtD,MAAM,SAAS,GAAG,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC;QAEpD,MAAM,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QACzC,MAAM,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,SAAS,CAAC,CAAC;QAC3C,iBAAiB,CAAC,MAAM,SAAS,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC;IAC3E,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,qBAAqB,EAAE,KAAK,IAAI,EAAE;QACnC,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACvE,MAAM,EAAE,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAErE,MAAM,SAAS,GAAG,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC;QAEpD,MAAM,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QACzC,MAAM,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,SAAS,CAAC,CAAC;QAC3C,iBAAiB,CACb,MAAM,SAAS,CAAC,IAAI,EAAE,EACtB,CAAC,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC,GAAG,CAAC,EAAE,EAAE,GAAG,CAAC,EAAE,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC;IACpE,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,iCAAiC,EAAE,GAAG,EAAE;QACzC,MAAM,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,MAAM,CAAC,EAAe,CAAC,CAAC;aACnC,YAAY,CAAC,kDAAkD,CAAC,CAAC;IACxE,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,8BAA8B,EAAE,KAAK,IAAI,EAAE;QAC5C,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAC1C,iBAAiB,CAAC,MAAM,
CAAC,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC;IAChD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,0BAA0B,EAAE,GAAG,EAAE;QAClC,MAAM,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC;aACvB,YAAY,CAAC,iDAAiD,CAAC,CAAC;IACvE,CAAC,CAAC,CAAC;AACL,CAAC,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as tf from '../index';\nimport {ALL_ENVS, describeWithFlags} from '../jasmine_util';\nimport {expectArraysClose, expectNumbersClose} from '../test_util';\nimport { backend } from '../index';\n\ndescribeWithFlags('square', ALL_ENVS, () => {\n  it('1D array', async () => {\n    const a = tf.tensor1d([2, 4, Math.sqrt(2)]);\n    const r = tf.square(a);\n    expectArraysClose(await r.data(), [4, 16, 2]);\n  });\n\n  it('2D array', async () => {\n    const a = tf.tensor2d([1, 2, Math.sqrt(2), Math.sqrt(3)], [2, 2]);\n    const r = tf.square(a);\n    expect(r.shape).toEqual([2, 2]);\n    expectArraysClose(await r.data(), [1, 4, 2, 3]);\n  });\n\n  it('5D array', async () => {\n    const a = tf.tensor5d([1, 2, Math.sqrt(2), Math.sqrt(3)], [1, 1, 2, 2, 1]);\n    const r = tf.square(a);\n    expect(r.shape).toEqual([1, 1, 2, 2, 1]);\n    expectArraysClose(await r.data(), [1, 4, 2, 3]);\n  });\n\n  it('6D array', async () => {\n    const a = tf.tensor6d(\n        [1, 2, 
Math.sqrt(2), Math.sqrt(3), 3, 4, Math.sqrt(7), Math.sqrt(13)],\n        [1, 1, 2, 2, 2, 1]);\n    const r = tf.square(a);\n    expect(r.shape).toEqual(a.shape);\n    expectArraysClose(await r.data(), [1, 4, 2, 3, 9, 16, 7, 13]);\n  });\n\n  it('square propagates NaNs', async () => {\n    const a = tf.tensor1d([1.5, NaN]);\n    const r = tf.square(a);\n    expectArraysClose(await r.data(), [2.25, NaN]);\n  });\n\n  it('int32', async () => {\n    if (backend() && backend().floatPrecision() === 32) {\n      // TODO: Use skip() instead when it is implemented\n      const a = tf.tensor1d([2, 4, 40000], 'int32');\n      const r = tf.square(a);\n      expect(r.dtype).toEqual('int32');\n      const data = await r.data();\n      expectNumbersClose(data[0], 4);\n      expectNumbersClose(data[1], 16);\n      // Epsilon must be larger here for webgl1\n      // TODO: Use expectArraysClose when it supports epsilons scaled by the\n      // numbers being compared.\n      expectNumbersClose(data[2], 1_600_000_000, 1_000 /* epsilon */);\n    }\n  });\n\n  it('gradients: Scalar', async () => {\n    const a = tf.scalar(5);\n    const dy = tf.scalar(8);\n\n    const gradients = tf.grad(a => tf.square(a))(a, dy);\n\n    expect(gradients.shape).toEqual(a.shape);\n    expect(gradients.dtype).toEqual('float32');\n    expectArraysClose(await gradients.data(), [2 * 5 * 8]);\n  });\n\n  it('gradients: Scalar', async () => {\n    const a = tf.scalar(5);\n    const dy = tf.scalar(8);\n\n    const gradients = tf.grad(a => tf.square(a.clone()).clone())(a, dy);\n\n    expect(gradients.shape).toEqual(a.shape);\n    expect(gradients.dtype).toEqual('float32');\n    expectArraysClose(await gradients.data(), [2 * 5 * 8]);\n  });\n\n  it('gradients: Tensor1D', async () => {\n    const a = tf.tensor1d([-1, 2, 3, -5]);\n    const dy = tf.tensor1d([1, 2, 3, 4]);\n\n    const gradients = tf.grad(a => tf.square(a))(a, dy);\n\n    expect(gradients.shape).toEqual(a.shape);\n    
expect(gradients.dtype).toEqual('float32');\n    expectArraysClose(await gradients.data(), [-2, 4 * 2, 6 * 3, -10 * 4]);\n  });\n\n  it('gradients: Tensor2D', async () => {\n    const a = tf.tensor2d([-3, 1, 2, 3], [2, 2]);\n    const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]);\n\n    const gradients = tf.grad(a => tf.square(a))(a, dy);\n\n    expect(gradients.shape).toEqual(a.shape);\n    expect(gradients.dtype).toEqual('float32');\n    expectArraysClose(await gradients.data(), [-6 * 1, 2 * 2, 4 * 3, 6 * 4]);\n  });\n\n  it('gradients: Tensor5D', async () => {\n    const a = tf.tensor5d([-3, 1, 2, 3], [1, 1, 1, 2, 2]);\n    const dy = tf.tensor5d([1, 2, 3, 4], [1, 1, 1, 2, 2]);\n\n    const gradients = tf.grad(a => tf.square(a))(a, dy);\n\n    expect(gradients.shape).toEqual(a.shape);\n    expect(gradients.dtype).toEqual('float32');\n    expectArraysClose(await gradients.data(), [-6 * 1, 2 * 2, 4 * 3, 6 * 4]);\n  });\n\n  it('gradients: Tensor6D', async () => {\n    const a = tf.tensor6d([-3, 1, 2, 3, -4, 5, 12, 3], [1, 1, 1, 2, 2, 2]);\n    const dy = tf.tensor6d([1, 2, 3, 4, 5, 6, 7, 8], [1, 1, 1, 2, 2, 2]);\n\n    const gradients = tf.grad(a => tf.square(a))(a, dy);\n\n    expect(gradients.shape).toEqual(a.shape);\n    expect(gradients.dtype).toEqual('float32');\n    expectArraysClose(\n        await gradients.data(),\n        [-6 * 1, 2 * 2, 4 * 3, 6 * 4, -8 * 5, 10 * 6, 24 * 7, 6 * 8]);\n  });\n\n  it('throws when passed a non-tensor', () => {\n    expect(() => tf.square({} as tf.Tensor))\n        .toThrowError(/Argument 'x' passed to 'square' must be a Tensor/);\n  });\n\n  it('accepts a tensor-like object', async () => {\n    const r = tf.square([2, 4, Math.sqrt(2)]);\n    expectArraysClose(await r.data(), [4, 16, 2]);\n  });\n\n  it('throws for string tensor', () => {\n    expect(() => tf.square('q'))\n        .toThrowError(/Argument 'x' passed to 'square' must be numeric/);\n  });\n});\n"]} |
\ | No newline at end of file |