UNPKG

3.77 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
17import * as tf from '../index';
18import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
19import { expectArraysClose } from '../test_util';
/**
 * Test suite for `tf.step`, run against every environment in `ALL_ENVS`.
 *
 * Covers forward results (with and without an explicit alpha), NaN
 * propagation, gradients for scalar/1D/2D inputs, and argument validation.
 */
describeWithFlags('step kernel', ALL_ENVS, () => {
  // ---- Forward pass ----

  it('with 1d tensor', async () => {
    const a = tf.tensor1d([1, -2, -.01, 3, -0.1]);
    const result = tf.step(a);
    // Positive entries are expected to map to 1, negative entries to 0.
    expectArraysClose(await result.data(), [1, 0, 0, 1, 0]);
  });

  it('with 1d tensor and alpha', async () => {
    const a = tf.tensor1d([1, -2, -.01, 3, NaN]);
    const result = tf.step(a, 0.1);
    // With alpha = 0.1 the negative entries are expected to yield 0.1
    // instead of 0; the NaN entry is expected to stay NaN.
    expectArraysClose(await result.data(), [1, 0.1, 0.1, 1, NaN]);
  });

  it('with 2d tensor', async () => {
    const a = tf.tensor2d([1, -5, -3, 4], [2, 2]);
    const result = tf.step(a);
    // The output must keep the input's shape.
    expect(result.shape).toEqual([2, 2]);
    expectArraysClose(await result.data(), [1, 0, 0, 1]);
  });

  it('propagates NaNs', async () => {
    const a = tf.tensor1d([1, -2, -.01, 3, NaN]);
    const result = tf.step(a);
    expectArraysClose(await result.data(), [1, 0, 0, 1, NaN]);
  });

  // ---- Gradients ----
  // Each gradient test expects an all-zero float32 gradient with the same
  // shape as the input, regardless of the upstream gradient `dy`.

  it('gradients: Scalar', async () => {
    const a = tf.scalar(-4);
    const dy = tf.scalar(8);
    // The callback parameter is named `x` so it does not shadow `a`.
    const gradients = tf.grad(x => tf.step(x))(a, dy);
    expect(gradients.shape).toEqual(a.shape);
    expect(gradients.dtype).toEqual('float32');
    expectArraysClose(await gradients.data(), [0]);
  });

  it('gradient with clones', async () => {
    const a = tf.scalar(-4);
    const dy = tf.scalar(8);
    // Cloning both the input and the output must not change the gradient.
    const gradients = tf.grad(x => tf.step(x.clone()).clone())(a, dy);
    expect(gradients.shape).toEqual(a.shape);
    expect(gradients.dtype).toEqual('float32');
    expectArraysClose(await gradients.data(), [0]);
  });

  it('gradients: Tensor1D', async () => {
    const a = tf.tensor1d([1, 2, -3, 5]);
    const dy = tf.tensor1d([1, 2, 3, 4]);
    const gradients = tf.grad(x => tf.step(x))(a, dy);
    expect(gradients.shape).toEqual(a.shape);
    expect(gradients.dtype).toEqual('float32');
    expectArraysClose(await gradients.data(), [0, 0, 0, 0]);
  });

  it('gradients: Tensor2D', async () => {
    const a = tf.tensor2d([3, -1, -2, 3], [2, 2]);
    const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]);
    const gradients = tf.grad(x => tf.step(x))(a, dy);
    expect(gradients.shape).toEqual(a.shape);
    expect(gradients.dtype).toEqual('float32');
    expectArraysClose(await gradients.data(), [0, 0, 0, 0]);
  });

  // ---- Argument validation ----

  it('throws when passed a non-tensor', () => {
    expect(() => tf.step({}))
        .toThrowError(/Argument 'x' passed to 'step' must be a Tensor/);
  });

  it('accepts a tensor-like object', async () => {
    // A plain number array is passed directly instead of a tensor.
    const result = tf.step([1, -2, -.01, 3, -0.1]);
    expectArraysClose(await result.data(), [1, 0, 0, 1, 0]);
  });

  it('throws for string tensor', () => {
    expect(() => tf.step('q'))
        .toThrowError(/Argument 'x' passed to 'step' must be numeric/);
  });
});
87//# sourceMappingURL=step_test.js.map
\No newline at end of file