1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 |
|
12 |
|
13 |
|
14 |
|
15 |
|
16 |
|
17 | import * as tf from '../index';
|
18 | import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
|
19 | import { expectArraysClose } from '../test_util';
|
describeWithFlags('logSoftmax', ALL_ENVS, () => {
  // Shared fixtures, kept as plain arrays rather than tensors so nothing is
  // allocated when this describe callback runs; each `it` builds its own
  // tensors from them.
  // logSoftmax(x) = x - log(sum(exp(x))); for [2, 1, 3] the log-sum-exp term
  // is ~3.407606, giving the values below.
  const LOGITS_1D = [2, 1, 3];
  const EXPECTED_1D = [-1.407606, -2.4076061, -0.407606];
  const LOGITS_2D = [[2, 1, 3], [1, 3, 2]];
  // Row-major flattening of the per-row results for LOGITS_2D along axis 1.
  const EXPECTED_2D = [
    -1.407606, -2.4076061, -0.407606,
    -2.4076061, -0.4076061, -1.4076061,
  ];

  it('regular test', async () => {
    const y = tf.logSoftmax(tf.tensor1d(LOGITS_1D));
    expectArraysClose(await y.data(), EXPECTED_1D);
  });

  it('Huge difference', async () => {
    // A naive exp-then-log would overflow on exp(1000); the result must
    // still be finite and exact.
    const y = tf.logSoftmax(tf.tensor1d([-1000, +1000]));
    expectArraysClose(await y.data(), [-2000, 0]);
  });

  it('Propagates NaNs', async () => {
    const a = tf.tensor1d([2, 1, NaN]);
    const y = tf.logSoftmax(a);
    // A single NaN logit makes the shared log-sum-exp term NaN, so every
    // output element is NaN.
    expectArraysClose(await y.data(), [NaN, NaN, NaN]);
  });

  it('2D, axis=1', async () => {
    const y = tf.logSoftmax(tf.tensor2d(LOGITS_2D, [2, 3]), 1);
    expect(y.rank).toBe(2);
    expect(y.shape).toEqual([2, 3]);
    expectArraysClose(await y.data(), EXPECTED_2D);
  });

  it('2D, implicit axis=1', async () => {
    const y = tf.logSoftmax(tf.tensor2d(LOGITS_2D, [2, 3]));
    expect(y.rank).toBe(2);
    expect(y.shape).toEqual([2, 3]);
    expectArraysClose(await y.data(), EXPECTED_2D);
  });

  it('2D, explicit axis=-1', async () => {
    // -1 denotes the last axis, so this must match the axis=1 result.
    const y = tf.logSoftmax(tf.tensor2d(LOGITS_2D, [2, 3]), -1);
    expect(y.rank).toBe(2);
    expect(y.shape).toEqual([2, 3]);
    expectArraysClose(await y.data(), EXPECTED_2D);
  });

  it('1D gradient', async () => {
    const x = tf.tensor1d([1, 2, 10]);
    const dy = tf.tensor1d([1, 2, 3]);
    const dx = tf.grad((logits) => logits.logSoftmax())(x, dy);
    expect(dx.shape).toEqual(x.shape);
    // Expected values follow dx = dy - sum(dy) * softmax(x).
    expectArraysClose(await dx.data(), [0.9992599, 1.9979881, -2.9972477]);
  });

  it('2D, axis=0 throws error', () => {
    // Normalizing over a non-last axis is expected to be rejected.
    const f = () => {
      tf.logSoftmax(tf.tensor2d(LOGITS_2D, [2, 3]), 0);
    };
    expect(f).toThrowError();
  });

  it('throws when passed a non-tensor', () => {
    expect(() => tf.logSoftmax({}))
      .toThrowError(/Argument 'logits' passed to 'logSoftmax' must be a Tensor/);
  });

  it('accepts a tensor-like object', async () => {
    // A plain nested array is converted to a tensor by the op itself.
    const y = tf.logSoftmax(LOGITS_1D);
    expectArraysClose(await y.data(), EXPECTED_1D);
  });
});
|
68 |
|
\ | No newline at end of file |