1 | /**
|
2 | * @license
|
3 | * Copyright 2020 Google Inc. All Rights Reserved.
|
4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | * you may not use this file except in compliance with the License.
|
6 | * You may obtain a copy of the License at
|
7 | *
|
8 | * http://www.apache.org/licenses/LICENSE-2.0
|
9 | *
|
10 | * Unless required by applicable law or agreed to in writing, software
|
11 | * distributed under the License is distributed on an "AS IS" BASIS,
|
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 | * See the License for the specific language governing permissions and
|
14 | * limitations under the License.
|
15 | * =============================================================================
|
16 | */
|
17 | import * as tf from '../index';
|
18 | import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
|
19 | import { expectArraysClose, expectArraysEqual } from '../test_util';
|
20 | import * as reduce_util from './reduce_util';
|
/**
 * Specs for tf.argMin, registered for every environment in ALL_ENVS.
 *
 * Covered here: 1-D through 4-D inputs; explicit, negative and omitted axis;
 * inputs twice reduce_util.PARALLELIZE_THRESHOLD long (presumably exercising
 * the parallel / windowed reduction path — see the test names); NaN entries;
 * tensor-like and bool inputs; gradients (expected to be all zeros); and
 * input validation errors for non-tensor and string arguments.
 */
describeWithFlags('argmin', ALL_ENVS, () => {
  /**
   * Returns a strictly decreasing Float32Array [n, n - 1, ..., 1], so its
   * minimum is always located at the final index, n - 1.
   */
  const descending = (n) => Float32Array.from({length: n}, (_, i) => n - i);

  it('Tensor1D', async () => {
    const idx = tf.argMin(tf.tensor1d([1, 0, 3, 2]));
    expectArraysEqual(await idx.data(), 1);
  });

  it('one value', async () => {
    const idx = tf.argMin(tf.tensor1d([10]));
    expectArraysEqual(await idx.data(), 0);
  });

  it('N > than parallelization threshold', async () => {
    const size = reduce_util.PARALLELIZE_THRESHOLD * 2;
    const idx = tf.argMin(tf.tensor1d(descending(size)));
    expect(idx.dtype).toBe('int32');
    // Values decrease monotonically, so the smallest one is the last.
    expectArraysEqual(await idx.data(), size - 1);
  });

  it('4D, N > than parallelization threshold', async () => {
    const size = reduce_util.PARALLELIZE_THRESHOLD * 2;
    const input = tf.tensor4d(descending(size), [1, 1, 1, size]);
    const idx = tf.argMin(input, -1);
    expect(idx.dtype).toBe('int32');
    expectArraysEqual(await idx.data(), size - 1);
  });

  it('min index corresponds to start of a non-initial window', async () => {
    const size = reduce_util.PARALLELIZE_THRESHOLD * 2;
    // Every entry is 0 except one -1 placed at offset 2 * windowSize, i.e.
    // the first element of the third window of computeOptimalWindowSize(size).
    const target = reduce_util.computeOptimalWindowSize(size) * 2;
    const values = new Float32Array(size);
    values[target] = -1;
    const idx = tf.argMin(tf.tensor1d(values));
    expect(idx.dtype).toBe('int32');
    expectArraysEqual(await idx.data(), target);
  });

  it('ignores NaNs', async () => {
    const idx = tf.argMin(tf.tensor1d([5, 0, NaN, -1, 3]));
    expectArraysEqual(await idx.data(), 3);
  });

  it('3D, ignores NaNs', async () => {
    const input = tf.tensor3d([5, 0, NaN, -1, 3], [1, 1, 5]);
    expectArraysEqual(await tf.argMin(input, -1).data(), 3);
  });

  it('2D, no axis specified', async () => {
    const input = tf.tensor2d([3, -1, 0, 100, -7, 2], [2, 3]);
    // Same input and expectation as the axis=0 case below: with no axis the
    // minimum is taken down each column.
    expectArraysEqual(await tf.argMin(input).data(), [0, 1, 0]);
  });

  it('2D, axis=0', async () => {
    const input = tf.tensor2d([3, -1, 0, 100, -7, 2], [2, 3]);
    const idx = tf.argMin(input, 0);
    expect(idx.shape).toEqual([3]);
    expect(idx.dtype).toBe('int32');
    expectArraysEqual(await idx.data(), [0, 1, 0]);
  });

  it('2D, axis=1', async () => {
    const input = tf.tensor2d([3, 2, 5, 100, -7, -8], [2, 3]);
    expectArraysEqual(await tf.argMin(input, 1).data(), [1, 2]);
  });

  it('2D, axis = -1', async () => {
    const input = tf.tensor2d([3, 2, 5, 100, -7, -8], [2, 3]);
    expectArraysEqual(await tf.argMin(input, -1).data(), [1, 2]);
  });

  it('throws when passed a non-tensor', () => {
    const callWithPlainObject = () => tf.argMin({});
    expect(callWithPlainObject)
        .toThrowError(/Argument 'x' passed to 'argMin' must be a Tensor/);
  });

  it('accepts a tensor-like object', async () => {
    const idx = tf.argMin([1, 0, 3, 2]);
    expectArraysEqual(await idx.data(), 1);
  });

  it('accepts tensor with bool values', async () => {
    const idx = tf.argMin(tf.tensor1d([0, 1], 'bool'));
    expect(idx.dtype).toBe('int32');
    expectArraysEqual(await idx.data(), 0);
  });

  it('has gradient', async () => {
    const input = tf.tensor2d([3, 2, 5, 100, -7, 2], [2, 3]);
    const upstream = tf.ones([3], 'float32');
    const gradFn = tf.grad((x) => tf.argMin(x));
    const dInput = gradFn(input, upstream);
    // The gradient w.r.t. the input is expected to be all zeros.
    expect(dInput.dtype).toBe('float32');
    expect(dInput.shape).toEqual([2, 3]);
    expectArraysClose(await dInput.data(), [0, 0, 0, 0, 0, 0]);
  });

  it('gradient with clones', async () => {
    const input = tf.tensor2d([3, 2, 5, 100, -7, 2], [2, 3]);
    const upstream = tf.ones([3], 'float32');
    // Cloning on both sides of argMin must not change the (zero) gradient.
    const gradFn = tf.grad((x) => tf.argMin(x.clone()).clone());
    const dInput = gradFn(input, upstream);
    expect(dInput.dtype).toBe('float32');
    expect(dInput.shape).toEqual([2, 3]);
    expectArraysClose(await dInput.data(), [0, 0, 0, 0, 0, 0]);
  });

  it('throws error for string tensor', () => {
    const callWithStrings = () => tf.argMin(['a']);
    expect(callWithStrings)
        .toThrowError(/Argument 'x' passed to 'argMin' must be numeric tensor/);
  });
});
|
131 | //# sourceMappingURL=data:application/json;base64,{"version":3,"file":"arg_min_test.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/ops/arg_min_test.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,KAAK,EAAE,MAAM,UAAU,CAAC;AAC/B,OAAO,EAAC,QAAQ,EAAE,iBAAiB,EAAC,MAAM,iBAAiB,CAAC;AAC5D,OAAO,EAAC,iBAAiB,EAAE,iBAAiB,EAAC,MAAM,cAAc,CAAC;AAElE,OAAO,KAAK,WAAW,MAAM,eAAe,CAAC;AAE7C,iBAAiB,CAAC,QAAQ,EAAE,QAAQ,EAAE,GAAG,EAAE;IACzC,EAAE,CAAC,UAAU,EAAE,KAAK,IAAI,EAAE;QACxB,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpC,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC;QAC5B,iBAAiB,CAAC,MAAM,MAAM,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,CAAC;IAC5C,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,WAAW,EAAE,KAAK,IAAI,EAAE;QACzB,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;QAC5B,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC;QAC5B,iBAAiB,CAAC,MAAM,MAAM,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,CAAC;IAC5C,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,oCAAoC,EAAE,KAAK,IAAI,EAAE;QAClD,MAAM,CAAC,GAAG,WAAW,CAAC,qBAAqB,GAAG,CAAC,CAAC;QAChD,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACnC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,EAAE;YAC1B,MAAM,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC;SACnB;QACD,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,MAAM,CAAC,CAAC;QAC9B,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC;QAC5B,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;QACnC,iBAAiB,CAAC,MAAM,MAAM,CAAC,IAAI,EAAE,EAAE,CAAC,GAAG,CAAC,CAAC,CAAC;IAChD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,wCAAwC,EAAE,KAAK,IAAI,EAAE;QACtD,MAAM,CAAC,GAAG,WAAW,CAAC,qBAAqB,GAAG,CAAC,CAAC;QAChD,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACnC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,EAAE;YAC1B,MAAM,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC;SACnB;QACD,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,MAAM,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC5C,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAChC,MAAM,CAAC,MAAM,CAAC,KAAK
,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;QACnC,iBAAiB,CAAC,MAAM,MAAM,CAAC,IAAI,EAAE,EAAE,CAAC,GAAG,CAAC,CAAC,CAAC;IAChD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,wDAAwD,EAAE,KAAK,IAAI,EAAE;QACtE,MAAM,CAAC,GAAG,WAAW,CAAC,qBAAqB,GAAG,CAAC,CAAC;QAChD,MAAM,UAAU,GAAG,WAAW,CAAC,wBAAwB,CAAC,CAAC,CAAC,CAAC;QAC3D,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACnC,MAAM,KAAK,GAAG,UAAU,GAAG,CAAC,CAAC;QAC7B,MAAM,CAAC,KAAK,CAAC,GAAG,CAAC,CAAC,CAAC;QACnB,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,MAAM,CAAC,CAAC;QAC9B,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC;QAC5B,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;QACnC,iBAAiB,CAAC,MAAM,MAAM,CAAC,IAAI,EAAE,EAAE,KAAK,CAAC,CAAC;IAChD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,cAAc,EAAE,KAAK,IAAI,EAAE;QAC5B,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,GAAG,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC1C,MAAM,GAAG,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC;QACzB,iBAAiB,CAAC,MAAM,GAAG,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,CAAC;IACzC,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,kBAAkB,EAAE,KAAK,IAAI,EAAE;QAChC,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,GAAG,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACrD,MAAM,GAAG,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC7B,iBAAiB,CAAC,MAAM,GAAG,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,CAAC;IACzC,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,uBAAuB,EAAE,KAAK,IAAI,EAAE;QACrC,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,GAAG,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACtD,iBAAiB,CAAC,MAAM,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IAC1D,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,YAAY,EAAE,KAAK,IAAI,EAAE;QAC1B,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,GAAG,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACtD,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;QAE1B,MAAM,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAC
7B,MAAM,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;QAC9B,iBAAiB,CAAC,MAAM,CAAC,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IAC/C,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,YAAY,EAAE,KAAK,IAAI,EAAE;QAC1B,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,GAAG,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACtD,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;QAC1B,iBAAiB,CAAC,MAAM,CAAC,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IAC5C,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,eAAe,EAAE,KAAK,IAAI,EAAE;QAC7B,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,GAAG,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACtD,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC3B,iBAAiB,CAAC,MAAM,CAAC,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IAC5C,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,iCAAiC,EAAE,GAAG,EAAE;QACzC,MAAM,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,MAAM,CAAC,EAAe,CAAC,CAAC;aACnC,YAAY,CAAC,kDAAkD,CAAC,CAAC;IACxE,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,8BAA8B,EAAE,KAAK,IAAI,EAAE;QAC5C,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACvC,iBAAiB,CAAC,MAAM,MAAM,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,CAAC;IAC5C,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,iCAAiC,EAAE,KAAK,IAAI,EAAE;QAC/C,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;QACtC,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC;QAC5B,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;QACnC,iBAAiB,CAAC,MAAM,MAAM,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,CAAC;IAC5C,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,cAAc,EAAE,KAAK,IAAI,EAAE;QAC5B,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,GAAG,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACrD,MAAM,EAAE,GAAG,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC;QACnC,MAAM,EAAE,GAAG,EAAE,CAAC,IAAI,CAAC,CAAC,CAAc,EAAE,EAAE,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,
CAAC,EAAE,EAAE,CAAC,CAAC;QAE5D,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;QACjC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACjC,iBAAiB,CAAC,MAAM,EAAE,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACzD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,sBAAsB,EAAE,KAAK,IAAI,EAAE;QACpC,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,GAAG,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACrD,MAAM,EAAE,GAAG,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC;QACnC,MAAM,EAAE,GAAG,EAAE,CAAC,IAAI,CAAC,CAAC,CAAc,EAAE,EAAE,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC;QAE5E,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;QACjC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACjC,iBAAiB,CAAC,MAAM,EAAE,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACzD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,gCAAgC,EAAE,GAAG,EAAE;QACxC,MAAM,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;aACzB,YAAY,CAAC,wDAAwD,CAAC,CAAC;IAC9E,CAAC,CAAC,CAAC;AACL,CAAC,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google Inc. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as tf from '../index';\nimport {ALL_ENVS, describeWithFlags} from '../jasmine_util';\nimport {expectArraysClose, expectArraysEqual} from '../test_util';\n\nimport * as reduce_util from './reduce_util';\n\ndescribeWithFlags('argmin', ALL_ENVS, () => {\n  it('Tensor1D', async () => {\n    const a = tf.tensor1d([1, 0, 3, 2]);\n    const result = tf.argMin(a);\n    expectArraysEqual(await result.data(), 1);\n  });\n\n  it('one value', async () => {\n    const a = tf.tensor1d([10]);\n    const result = tf.argMin(a);\n    expectArraysEqual(await result.data(), 0);\n  });\n\n  it('N > than parallelization threshold', async () => {\n    const n = reduce_util.PARALLELIZE_THRESHOLD * 2;\n    const values = new Float32Array(n);\n    for (let i = 0; i < n; i++) {\n      values[i] = n - i;\n    }\n    const a = tf.tensor1d(values);\n    const result = tf.argMin(a);\n    expect(result.dtype).toBe('int32');\n    expectArraysEqual(await result.data(), n - 1);\n  });\n\n  it('4D, N > than parallelization threshold', async () => {\n    const n = reduce_util.PARALLELIZE_THRESHOLD * 2;\n    const values = new Float32Array(n);\n    for (let i = 0; i < n; i++) {\n      values[i] = n - i;\n    }\n    const a = tf.tensor4d(values, [1, 1, 1, n]);\n    const result = tf.argMin(a, -1);\n    
expect(result.dtype).toBe('int32');\n    expectArraysEqual(await result.data(), n - 1);\n  });\n\n  it('min index corresponds to start of a non-initial window', async () => {\n    const n = reduce_util.PARALLELIZE_THRESHOLD * 2;\n    const windowSize = reduce_util.computeOptimalWindowSize(n);\n    const values = new Float32Array(n);\n    const index = windowSize * 2;\n    values[index] = -1;\n    const a = tf.tensor1d(values);\n    const result = tf.argMin(a);\n    expect(result.dtype).toBe('int32');\n    expectArraysEqual(await result.data(), index);\n  });\n\n  it('ignores NaNs', async () => {\n    const a = tf.tensor1d([5, 0, NaN, -1, 3]);\n    const res = tf.argMin(a);\n    expectArraysEqual(await res.data(), 3);\n  });\n\n  it('3D, ignores NaNs', async () => {\n    const a = tf.tensor3d([5, 0, NaN, -1, 3], [1, 1, 5]);\n    const res = tf.argMin(a, -1);\n    expectArraysEqual(await res.data(), 3);\n  });\n\n  it('2D, no axis specified', async () => {\n    const a = tf.tensor2d([3, -1, 0, 100, -7, 2], [2, 3]);\n    expectArraysEqual(await tf.argMin(a).data(), [0, 1, 0]);\n  });\n\n  it('2D, axis=0', async () => {\n    const a = tf.tensor2d([3, -1, 0, 100, -7, 2], [2, 3]);\n    const r = tf.argMin(a, 0);\n\n    expect(r.shape).toEqual([3]);\n    expect(r.dtype).toBe('int32');\n    expectArraysEqual(await r.data(), [0, 1, 0]);\n  });\n\n  it('2D, axis=1', async () => {\n    const a = tf.tensor2d([3, 2, 5, 100, -7, -8], [2, 3]);\n    const r = tf.argMin(a, 1);\n    expectArraysEqual(await r.data(), [1, 2]);\n  });\n\n  it('2D, axis = -1', async () => {\n    const a = tf.tensor2d([3, 2, 5, 100, -7, -8], [2, 3]);\n    const r = tf.argMin(a, -1);\n    expectArraysEqual(await r.data(), [1, 2]);\n  });\n\n  it('throws when passed a non-tensor', () => {\n    expect(() => tf.argMin({} as tf.Tensor))\n        .toThrowError(/Argument 'x' passed to 'argMin' must be a Tensor/);\n  });\n\n  it('accepts a tensor-like object', async () => {\n    const result = tf.argMin([1, 0, 
3, 2]);\n    expectArraysEqual(await result.data(), 1);\n  });\n\n  it('accepts tensor with bool values', async () => {\n    const t = tf.tensor1d([0, 1], 'bool');\n    const result = tf.argMin(t);\n    expect(result.dtype).toBe('int32');\n    expectArraysEqual(await result.data(), 0);\n  });\n\n  it('has gradient', async () => {\n    const a = tf.tensor2d([3, 2, 5, 100, -7, 2], [2, 3]);\n    const dy = tf.ones([3], 'float32');\n    const da = tf.grad((x: tf.Tensor2D) => tf.argMin(x))(a, dy);\n\n    expect(da.dtype).toBe('float32');\n    expect(da.shape).toEqual([2, 3]);\n    expectArraysClose(await da.data(), [0, 0, 0, 0, 0, 0]);\n  });\n\n  it('gradient with clones', async () => {\n    const a = tf.tensor2d([3, 2, 5, 100, -7, 2], [2, 3]);\n    const dy = tf.ones([3], 'float32');\n    const da = tf.grad((x: tf.Tensor2D) => tf.argMin(x.clone()).clone())(a, dy);\n\n    expect(da.dtype).toBe('float32');\n    expect(da.shape).toEqual([2, 3]);\n    expectArraysClose(await da.data(), [0, 0, 0, 0, 0, 0]);\n  });\n\n  it('throws error for string tensor', () => {\n    expect(() => tf.argMin(['a']))\n        .toThrowError(/Argument 'x' passed to 'argMin' must be numeric tensor/);\n  });\n});\n"]} |
\ | No newline at end of file |