1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 |
|
12 |
|
13 |
|
14 |
|
15 |
|
16 |
|
17 | import * as tf from '../index';
|
18 | import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
|
19 | import { expectArraysClose } from '../test_util';
|
/**
 * Specs for `tf.reciprocal`, executed against every environment in
 * `ALL_ENVS`. Covers the forward pass (finite values, 0, Infinity, NaN),
 * gradients for scalar / 1D / 2D inputs (with and without clones), and
 * argument validation.
 */
describeWithFlags('reciprocal', ALL_ENVS, () => {
  // Reference forward pass evaluated with JS number semantics:
  // 1 / 0 -> Infinity, 1 / Infinity -> 0, 1 / NaN -> NaN.
  const referenceReciprocal = (values) => values.map((v) => 1 / v);

  // Reference gradient of 1 / x scaled by the upstream gradient dy:
  // -dy / x^2, written in the same operation order as the hand-expanded
  // literals so the expected floats are unchanged.
  const reciprocalGrad = (x, dy) => -1 * dy * (1 / (x * x));

  // Inputs shared by the rank-1 and rank-2 gradient specs.
  const gradInputs = [-1, 2, 3, -5];
  const gradUpstream = [1, 2, 3, 4];
  const expectedGradients =
      gradInputs.map((x, i) => reciprocalGrad(x, gradUpstream[i]));

  it('1D array', async () => {
    const values = [2, 3, 0, NaN];
    const result = tf.reciprocal(tf.tensor1d(values));
    expectArraysClose(await result.data(), referenceReciprocal(values));
  });

  it('2D array', async () => {
    const values = [1, Infinity, 0, NaN];
    const result = tf.reciprocal(tf.tensor2d(values, [2, 2]));
    expect(result.shape).toEqual([2, 2]);
    expectArraysClose(await result.data(), referenceReciprocal(values));
  });

  it('reciprocal propagates NaNs', async () => {
    const values = [1.5, NaN];
    const result = tf.reciprocal(tf.tensor1d(values));
    expectArraysClose(await result.data(), referenceReciprocal(values));
  });

  it('gradients: Scalar', async () => {
    const input = tf.scalar(5);
    const upstream = tf.scalar(8);
    const dx = tf.grad((t) => tf.reciprocal(t))(input, upstream);
    expect(dx.shape).toEqual(input.shape);
    expect(dx.dtype).toEqual('float32');
    expectArraysClose(await dx.data(), [reciprocalGrad(5, 8)]);
  });

  it('gradient with clones', async () => {
    const input = tf.scalar(5);
    const upstream = tf.scalar(8);
    const dx =
        tf.grad((t) => tf.reciprocal(t.clone()).clone())(input, upstream);
    expect(dx.shape).toEqual(input.shape);
    expect(dx.dtype).toEqual('float32');
    expectArraysClose(await dx.data(), [reciprocalGrad(5, 8)]);
  });

  it('gradients: Tensor1D', async () => {
    const input = tf.tensor1d(gradInputs);
    const upstream = tf.tensor1d(gradUpstream);
    const dx = tf.grad((t) => tf.reciprocal(t))(input, upstream);
    expect(dx.shape).toEqual(input.shape);
    expect(dx.dtype).toEqual('float32');
    expectArraysClose(await dx.data(), expectedGradients);
  });

  it('gradients: Tensor2D', async () => {
    const input = tf.tensor2d(gradInputs, [2, 2]);
    const upstream = tf.tensor2d(gradUpstream, [2, 2]);
    const dx = tf.grad((t) => tf.reciprocal(t))(input, upstream);
    expect(dx.shape).toEqual(input.shape);
    expect(dx.dtype).toEqual('float32');
    expectArraysClose(await dx.data(), expectedGradients);
  });

  it('throws when passed a non-tensor', () => {
    const callWithPlainObject = () => tf.reciprocal({});
    expect(callWithPlainObject)
        .toThrowError(/Argument 'x' passed to 'reciprocal' must be a Tensor/);
  });

  it('accepts a tensor-like object', async () => {
    const values = [2, 3, 0, NaN];
    const result = tf.reciprocal(values);
    expectArraysClose(await result.data(), referenceReciprocal(values));
  });

  it('throws for string tensor', () => {
    const callWithString = () => tf.reciprocal('q');
    expect(callWithString)
        .toThrowError(/Argument 'x' passed to 'reciprocal' must be numeric/);
  });
});
|
88 |
|
\ | No newline at end of file |