1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 |
|
12 |
|
13 |
|
14 |
|
15 |
|
16 |
|
17 | import * as tf from '../index';
|
18 | import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
|
19 | import { expectArraysClose } from '../test_util';
|
/**
 * Specs for `tf.neg`, the element-wise negation op, registered for every
 * environment in ALL_ENVS.
 *
 * Covers:
 *  - forward values and NaN propagation,
 *  - gradients for scalar, 1-D and 2-D inputs (including through clone()),
 *  - argument validation (non-tensor and string inputs).
 */
describeWithFlags('neg', ALL_ENVS, () => {
  it('basic', async () => {
    const a = tf.tensor1d([1, -3, 2, 7, -4]);
    const result = tf.neg(a);
    expectArraysClose(await result.data(), [-1, 3, -2, -7, 4]);
  });

  it('propagate NaNs', async () => {
    const a = tf.tensor1d([1, -3, 2, 7, NaN]);
    const result = tf.neg(a);
    const expected = [-1, 3, -2, -7, NaN];
    expectArraysClose(await result.data(), expected);
  });

  // d(-x)/dx = -1, so in every gradient spec below the expected gradient is
  // the upstream gradient `dy` (the second argument to the function returned
  // by tf.grad) multiplied element-wise by -1.
  it('gradients: Scalar', async () => {
    const a = tf.scalar(4);
    const dy = tf.scalar(8);
    const da = tf.grad(x => tf.neg(x))(a, dy);
    expect(da.shape).toEqual(a.shape);
    expect(da.dtype).toEqual('float32');
    expectArraysClose(await da.data(), [8 * -1]);
  });

  // Same scalar case, but neg is wrapped in clone() on both sides, to check
  // that the gradient still flows correctly through the cloned tensors.
  // Given a distinct name so a failure here is distinguishable from the
  // plain scalar spec in test reports.
  it('gradient with clones', async () => {
    const a = tf.scalar(4);
    const dy = tf.scalar(8);
    const da = tf.grad(x => tf.neg(x.clone()).clone())(a, dy);
    expect(da.shape).toEqual(a.shape);
    expect(da.dtype).toEqual('float32');
    expectArraysClose(await da.data(), [8 * -1]);
  });

  it('gradients: Tensor1D', async () => {
    const a = tf.tensor1d([1, 2, -3, 5]);
    const dy = tf.tensor1d([1, 2, 3, 4]);
    const da = tf.grad(x => tf.neg(x))(a, dy);
    expect(da.shape).toEqual(a.shape);
    expect(da.dtype).toEqual('float32');
    expectArraysClose(await da.data(), [1 * -1, 2 * -1, 3 * -1, 4 * -1]);
  });

  it('gradients: Tensor2D', async () => {
    const a = tf.tensor2d([3, -1, -2, 3], [2, 2]);
    const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]);
    const da = tf.grad(x => tf.neg(x))(a, dy);
    expect(da.shape).toEqual(a.shape);
    expect(da.dtype).toEqual('float32');
    expectArraysClose(await da.data(), [1 * -1, 2 * -1, 3 * -1, 4 * -1]);
  });

  it('throws when passed a non-tensor', () => {
    expect(() => tf.neg({}))
      .toThrowError(/Argument 'x' passed to 'neg' must be a Tensor/);
  });

  // A plain JS array is accepted and converted to a tensor.
  it('accepts a tensor-like object', async () => {
    const result = tf.neg([1, -3, 2, 7, -4]);
    expectArraysClose(await result.data(), [-1, 3, -2, -7, 4]);
  });

  // A string input is rejected: neg only operates on numeric tensors.
  it('throws for string tensor', () => {
    expect(() => tf.neg('q'))
      .toThrowError(/Argument 'x' passed to 'neg' must be numeric/);
  });
});
|
77 |
|
\ | No newline at end of file |