UNPKG

22 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2020 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import * as tf from '../index';
import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
import { expectArraysClose, expectArraysEqual } from '../test_util';
import * as reduce_util from './reduce_util';

/**
 * Test suite for `tf.argMin`, registered for every environment in ALL_ENVS.
 *
 * Covers:
 * - 1D/2D/3D/4D inputs, including no axis, explicit axes and negative axes.
 * - NaN handling and inputs sized past reduce_util.PARALLELIZE_THRESHOLD.
 * - Tensor-like and bool inputs.
 * - Rejection of non-tensor and string inputs.
 * - The gradient, which is expected to be all zeros.
 */
describeWithFlags('argmin', ALL_ENVS, () => {
    it('Tensor1D', async () => {
        const a = tf.tensor1d([1, 0, 3, 2]);
        const result = tf.argMin(a);
        expectArraysEqual(await result.data(), 1);
    });

    it('one value', async () => {
        const a = tf.tensor1d([10]);
        const result = tf.argMin(a);
        expectArraysEqual(await result.data(), 0);
    });

    it('N > than parallelization threshold', async () => {
        // Strictly decreasing values n, n-1, ..., 1, so the minimum is the
        // last element. n is twice PARALLELIZE_THRESHOLD, presumably so the
        // backend's parallelized reduction path is taken
        // (NOTE(review): confirm against reduce_util).
        const n = reduce_util.PARALLELIZE_THRESHOLD * 2;
        const values = new Float32Array(n);
        for (let i = 0; i < n; i++) {
            values[i] = n - i;
        }
        const a = tf.tensor1d(values);
        const result = tf.argMin(a);
        expect(result.dtype).toBe('int32');
        expectArraysEqual(await result.data(), n - 1);
    });

    it('4D, N > than parallelization threshold', async () => {
        // Same decreasing data as above, but reduced along the last axis of
        // a [1, 1, 1, n] tensor.
        const n = reduce_util.PARALLELIZE_THRESHOLD * 2;
        const values = new Float32Array(n);
        for (let i = 0; i < n; i++) {
            values[i] = n - i;
        }
        const a = tf.tensor4d(values, [1, 1, 1, n]);
        const result = tf.argMin(a, -1);
        expect(result.dtype).toBe('int32');
        expectArraysEqual(await result.data(), n - 1);
    });

    it('min index corresponds to start of a non-initial window', async () => {
        // All zeros except a single -1 at index windowSize * 2, i.e. the
        // first element of the third window as sized by
        // computeOptimalWindowSize(n).
        const n = reduce_util.PARALLELIZE_THRESHOLD * 2;
        const windowSize = reduce_util.computeOptimalWindowSize(n);
        const values = new Float32Array(n);
        const index = windowSize * 2;
        values[index] = -1;
        const a = tf.tensor1d(values);
        const result = tf.argMin(a);
        expect(result.dtype).toBe('int32');
        expectArraysEqual(await result.data(), index);
    });

    it('ignores NaNs', async () => {
        // The NaN at index 2 must not be reported; the true minimum -1 is at 3.
        const a = tf.tensor1d([5, 0, NaN, -1, 3]);
        const res = tf.argMin(a);
        expectArraysEqual(await res.data(), 3);
    });

    it('3D, ignores NaNs', async () => {
        const a = tf.tensor3d([5, 0, NaN, -1, 3], [1, 1, 5]);
        const res = tf.argMin(a, -1);
        expectArraysEqual(await res.data(), 3);
    });

    it('2D, no axis specified', async () => {
        // Expected values match the explicit axis=0 case below.
        const a = tf.tensor2d([3, -1, 0, 100, -7, 2], [2, 3]);
        expectArraysEqual(await tf.argMin(a).data(), [0, 1, 0]);
    });

    it('2D, axis=0', async () => {
        const a = tf.tensor2d([3, -1, 0, 100, -7, 2], [2, 3]);
        const r = tf.argMin(a, 0);
        expect(r.shape).toEqual([3]);
        expect(r.dtype).toBe('int32');
        expectArraysEqual(await r.data(), [0, 1, 0]);
    });

    it('2D, axis=1', async () => {
        const a = tf.tensor2d([3, 2, 5, 100, -7, -8], [2, 3]);
        const r = tf.argMin(a, 1);
        expectArraysEqual(await r.data(), [1, 2]);
    });

    it('2D, axis = -1', async () => {
        // Negative axis counts from the end; -1 is the same as axis=1 here.
        const a = tf.tensor2d([3, 2, 5, 100, -7, -8], [2, 3]);
        const r = tf.argMin(a, -1);
        expectArraysEqual(await r.data(), [1, 2]);
    });

    it('throws when passed a non-tensor', () => {
        // A plain object is neither a Tensor nor tensor-like and must be rejected.
        expect(() => tf.argMin({}))
            .toThrowError(/Argument 'x' passed to 'argMin' must be a Tensor/);
    });

    it('accepts a tensor-like object', async () => {
        const result = tf.argMin([1, 0, 3, 2]);
        expectArraysEqual(await result.data(), 1);
    });

    it('accepts tensor with bool values', async () => {
        const t = tf.tensor1d([0, 1], 'bool');
        const result = tf.argMin(t);
        expect(result.dtype).toBe('int32');
        expectArraysEqual(await result.data(), 0);
    });

    it('has gradient', async () => {
        // argMin returns indices, so the gradient w.r.t. the input is expected
        // to be all zeros, float32, with the input's shape.
        const a = tf.tensor2d([3, 2, 5, 100, -7, 2], [2, 3]);
        const dy = tf.ones([3], 'float32');
        const da = tf.grad((x) => tf.argMin(x))(a, dy);
        expect(da.dtype).toBe('float32');
        expect(da.shape).toEqual([2, 3]);
        expectArraysClose(await da.data(), [0, 0, 0, 0, 0, 0]);
    });

    it('gradient with clones', async () => {
        // Same as above, but with clone() on the input and the output, to make
        // sure the zero gradient still propagates through identity ops.
        const a = tf.tensor2d([3, 2, 5, 100, -7, 2], [2, 3]);
        const dy = tf.ones([3], 'float32');
        const da = tf.grad((x) => tf.argMin(x.clone()).clone())(a, dy);
        expect(da.dtype).toBe('float32');
        expect(da.shape).toEqual([2, 3]);
        expectArraysClose(await da.data(), [0, 0, 0, 0, 0, 0]);
    });

    it('throws error for string tensor', () => {
        expect(() => tf.argMin(['a']))
            .toThrowError(/Argument 'x' passed to 'argMin' must be numeric tensor/);
    });
});
131//# sourceMappingURL=data:application/json;base64,
\No newline at end of file