UNPKG

3.88 kB · JavaScript · View Raw
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import * as tf from '../index';
import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
import { expectArraysClose } from '../test_util';
/**
 * Tests for `tf.reciprocal`, expected to compute 1 / x element-wise.
 *
 * - Forward cases cover finite values, 0 (expected Infinity),
 *   Infinity (expected 0) and NaN propagation, for 1D and 2D inputs.
 * - Gradient cases check that the gradient equals -dy / x^2, the derivative
 *   of 1 / x scaled by the upstream gradient `dy`. They also check that the
 *   gradient keeps the input's shape and has dtype float32.
 * - Input-validation cases check the error messages raised for a non-tensor
 *   argument and for a string input.
 */
describeWithFlags('reciprocal', ALL_ENVS, () => {
  it('1D array', async () => {
    const a = tf.tensor1d([2, 3, 0, NaN]);
    const r = tf.reciprocal(a);
    expectArraysClose(await r.data(), [1 / 2, 1 / 3, Infinity, NaN]);
  });

  it('2D array', async () => {
    const a = tf.tensor2d([1, Infinity, 0, NaN], [2, 2]);
    const r = tf.reciprocal(a);
    expect(r.shape).toEqual([2, 2]);
    // 1 / Infinity is 0 and 1 / 0 is Infinity.
    expectArraysClose(await r.data(), [1 / 1, 0, Infinity, NaN]);
  });

  it('reciprocal propagates NaNs', async () => {
    const a = tf.tensor1d([1.5, NaN]);
    const r = tf.reciprocal(a);
    expectArraysClose(await r.data(), [1 / 1.5, NaN]);
  });

  it('gradients: Scalar', async () => {
    const a = tf.scalar(5);
    const dy = tf.scalar(8);
    const gradients = tf.grad(a => tf.reciprocal(a))(a, dy);
    expect(gradients.shape).toEqual(a.shape);
    expect(gradients.dtype).toEqual('float32');
    // d/dx (1 / x) = -1 / x^2, scaled by the upstream gradient dy.
    expectArraysClose(await gradients.data(), [-1 * 8 * (1 / (5 * 5))]);
  });

  it('gradient with clones', async () => {
    // Same expectation as the scalar case; the extra clone() calls before and
    // after reciprocal must not change the gradient.
    const a = tf.scalar(5);
    const dy = tf.scalar(8);
    const gradients = tf.grad(a => tf.reciprocal(a.clone()).clone())(a, dy);
    expect(gradients.shape).toEqual(a.shape);
    expect(gradients.dtype).toEqual('float32');
    expectArraysClose(await gradients.data(), [-1 * 8 * (1 / (5 * 5))]);
  });

  it('gradients: Tensor1D', async () => {
    const a = tf.tensor1d([-1, 2, 3, -5]);
    const dy = tf.tensor1d([1, 2, 3, 4]);
    const gradients = tf.grad(a => tf.reciprocal(a))(a, dy);
    expect(gradients.shape).toEqual(a.shape);
    expect(gradients.dtype).toEqual('float32');
    // Element-wise -dy[i] / a[i]^2; negative inputs square to positive values.
    expectArraysClose(await gradients.data(), [
      -1 * 1 * (1 / (-1 * -1)), -1 * 2 * (1 / (2 * 2)), -1 * 3 * (1 / (3 * 3)),
      -1 * 4 * (1 / (-5 * -5))
    ]);
  });

  it('gradients: Tensor2D', async () => {
    const a = tf.tensor2d([-1, 2, 3, -5], [2, 2]);
    const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]);
    const gradients = tf.grad(a => tf.reciprocal(a))(a, dy);
    expect(gradients.shape).toEqual(a.shape);
    expect(gradients.dtype).toEqual('float32');
    // Same values as the 1D case; only the shape differs.
    expectArraysClose(await gradients.data(), [
      -1 * 1 * (1 / (-1 * -1)), -1 * 2 * (1 / (2 * 2)), -1 * 3 * (1 / (3 * 3)),
      -1 * 4 * (1 / (-5 * -5))
    ]);
  });

  it('throws when passed a non-tensor', () => {
    expect(() => tf.reciprocal({}))
        .toThrowError(/Argument 'x' passed to 'reciprocal' must be a Tensor/);
  });

  it('accepts a tensor-like object', async () => {
    // A plain array is accepted in place of a tensor.
    const r = tf.reciprocal([2, 3, 0, NaN]);
    expectArraysClose(await r.data(), [1 / 2, 1 / 3, Infinity, NaN]);
  });

  it('throws for string tensor', () => {
    expect(() => tf.reciprocal('q'))
        .toThrowError(/Argument 'x' passed to 'reciprocal' must be numeric/);
  });
});
//# sourceMappingURL=reciprocal_test.js.map
\No newline at end of file