/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import * as tf from '../index';
import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
import { expectArraysClose } from '../test_util';

/**
 * Test suite for `tf.sqrt`, registered for every environment in `ALL_ENVS`.
 *
 * Covers forward values, NaN propagation, gradients for scalar, 1-D and 2-D
 * inputs (including a pass through `clone()`), and argument validation.
 */
describeWithFlags('sqrt', ALL_ENVS, () => {
    // Analytic derivative of sqrt at `x`, scaled by the upstream gradient `dy`.
    const sqrtGrad = (x, dy) => dy / (2 * Math.sqrt(x));

    it('sqrt', async () => {
        const values = [2, 4];
        const input = tf.tensor1d(values);
        const result = tf.sqrt(input);
        const expected = values.map((v) => Math.sqrt(v));
        expectArraysClose(await result.data(), expected);
    });

    it('sqrt propagates NaNs', async () => {
        const input = tf.tensor1d([1, NaN]);
        const result = tf.sqrt(input);
        const expected = [Math.sqrt(1), NaN];
        expectArraysClose(await result.data(), expected);
    });

    it('gradients: Scalar', async () => {
        const input = tf.scalar(4);
        const upstream = tf.scalar(8);
        const gradFn = tf.grad((x) => tf.sqrt(x));
        const inputGrad = gradFn(input, upstream);
        expect(inputGrad.shape).toEqual(input.shape);
        expect(inputGrad.dtype).toEqual('float32');
        expectArraysClose(await inputGrad.data(), [sqrtGrad(4, 8)]);
    });

    it('gradient with clones', async () => {
        const input = tf.scalar(4);
        const upstream = tf.scalar(8);
        // Clone both the input and the output so the gradient has to flow
        // through the clone ops as well as sqrt.
        const gradFn = tf.grad((x) => tf.sqrt(x.clone()).clone());
        const inputGrad = gradFn(input, upstream);
        expect(inputGrad.shape).toEqual(input.shape);
        expect(inputGrad.dtype).toEqual('float32');
        expectArraysClose(await inputGrad.data(), [sqrtGrad(4, 8)]);
    });

    it('gradients: Tensor1D', async () => {
        const xs = [1, 2, 3, 5];
        const dys = [1, 2, 3, 4];
        const input = tf.tensor1d(xs);
        const upstream = tf.tensor1d(dys);
        const gradFn = tf.grad((x) => tf.sqrt(x));
        const inputGrad = gradFn(input, upstream);
        expect(inputGrad.shape).toEqual(input.shape);
        expect(inputGrad.dtype).toEqual('float32');
        const expected = xs.map((x, i) => sqrtGrad(x, dys[i]));
        expectArraysClose(await inputGrad.data(), expected, 1e-1);
    });

    it('gradients: Tensor2D', async () => {
        const xs = [3, 1, 2, 3];
        const dys = [1, 2, 3, 4];
        const shape = [2, 2];
        const input = tf.tensor2d(xs, shape);
        const upstream = tf.tensor2d(dys, shape);
        const gradFn = tf.grad((x) => tf.sqrt(x));
        const inputGrad = gradFn(input, upstream);
        expect(inputGrad.shape).toEqual(input.shape);
        expect(inputGrad.dtype).toEqual('float32');
        const expected = xs.map((x, i) => sqrtGrad(x, dys[i]));
        expectArraysClose(await inputGrad.data(), expected, 1e-1);
    });

    it('throws when passed a non-tensor', () => {
        const callWithObject = () => tf.sqrt({});
        expect(callWithObject)
            .toThrowError(/Argument 'x' passed to 'sqrt' must be a Tensor/);
    });

    it('accepts a tensor-like object', async () => {
        const values = [2, 4];
        const result = tf.sqrt(values);
        const expected = values.map((v) => Math.sqrt(v));
        expectArraysClose(await result.data(), expected);
    });

    it('throws for string tensor', () => {
        const callWithString = () => tf.sqrt('q');
        expect(callWithString)
            .toThrowError(/Argument 'x' passed to 'sqrt' must be numeric/);
    });
});
//# sourceMappingURL=sqrt_test.js.map