UNPKG

29.7 kB · JavaScript · View Raw
1/**
2 * @license
3 * Copyright 2017 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17import * as tf from '../index';
18import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
19import { expectArraysClose } from '../test_util';
/**
 * Tests for tf.maxPoolWithArgmax, run in every registered environment
 * (ALL_ENVS).
 *
 * Each case checks both outputs of the op:
 *   - `result`:  the max-pooled values;
 *   - `indexes`: the position of each selected maximum in the row-major
 *     flattened input.
 *
 * As the expected values below demonstrate, an index is
 * `(y * width + x) * channels + c` by default. When `includeBatchInIndex` is
 * true, the batch offset is folded in as well, giving
 * `((b * height + y) * width + x) * channels + c`. Compare the paired
 * `includeBatchInIndex=true` cases.
 */
describeWithFlags('maxPoolWithArgmax', ALL_ENVS, () => {
  it('x=[1,1,1,1] f=[1,1] s=1 d=1 [0] => [0]', async () => {
    const x = tf.tensor4d([0], [1, 1, 1, 1]);
    const padding = 0;
    const { result, indexes } = tf.maxPoolWithArgmax(x, [1, 1], [1, 1], padding);
    expectArraysClose(await result.data(), [0]);
    expectArraysClose(await indexes.data(), [0]);
  });

  it('x=[2,2,2,1] f=[2,2] s=1 p=valid', async () => {
    const x = tf.tensor4d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2, 1]);
    const { result, indexes } = tf.maxPoolWithArgmax(x, 2, 1, 'valid');
    expect(result.shape).toEqual([2, 1, 1, 1]);
    expectArraysClose(await result.data(), [4, 8]);
    expect(indexes.shape).toEqual([2, 1, 1, 1]);
    // Without the batch offset, both batches report the same in-image index.
    expectArraysClose(await indexes.data(), [3, 3]);
  });

  it('x=[2,2,2,1] f=[2,2] s=1 p=valid includeBatchInIndex=true', async () => {
    const x = tf.tensor4d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2, 1]);
    const { result, indexes } = tf.maxPoolWithArgmax(x, 2, 1, 'valid', true);
    expect(result.shape).toEqual([2, 1, 1, 1]);
    expectArraysClose(await result.data(), [4, 8]);
    expect(indexes.shape).toEqual([2, 1, 1, 1]);
    // The second batch's index is shifted by one full image (2 * 2 * 1 = 4).
    expectArraysClose(await indexes.data(), [3, 7]);
  });

  it('x=[1,3,3,1] f=[2,2] s=1 p=0', async () => {
    // Feed forward.
    const x = tf.tensor4d([1, 2, 3, 4, 5, 6, 7, 9, 8], [1, 3, 3, 1]);
    const { result, indexes } = tf.maxPoolWithArgmax(x, 2, 1, 0);
    expect(result.shape).toEqual([1, 2, 2, 1]);
    expectArraysClose(await result.data(), [5, 6, 9, 9]);
    expect(indexes.shape).toEqual([1, 2, 2, 1]);
    expectArraysClose(await indexes.data(), [4, 5, 7, 7]);
  });

  it('x=[1,3,3,1] f=[2,2] s=1 p=same', async () => {
    const x = tf.tensor4d([1, 2, 3, 4, 5, 6, 7, 9, 8], [1, 3, 3, 1]);
    const { result, indexes } = tf.maxPoolWithArgmax(x, 2, 1, 'same');
    // 'same' keeps the 3x3 spatial size, so windows in the last row/column
    // extend past the edge of the input.
    expect(result.shape).toEqual([1, 3, 3, 1]);
    expectArraysClose(await result.data(), [5, 6, 6, 9, 9, 8, 9, 9, 8]);
    expect(indexes.shape).toEqual([1, 3, 3, 1]);
    expectArraysClose(await indexes.data(), [4, 5, 5, 7, 7, 8, 7, 7, 8]);
  });

  it('x=[2,3,3,1] f=[2,2] s=1', async () => {
    const x = tf.tensor4d(
        [1, 2, 3, 4, 5, 6, 7, 9, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9], [2, 3, 3, 1]);
    const { result, indexes } = tf.maxPoolWithArgmax(x, 2, 1, 0);
    expect(result.shape).toEqual([2, 2, 2, 1]);
    expectArraysClose(await result.data(), [5, 6, 9, 9, 5, 6, 8, 9]);
    expect(indexes.shape).toEqual([2, 2, 2, 1]);
    expectArraysClose(await indexes.data(), [4, 5, 7, 7, 4, 5, 7, 8]);
  });

  it('x=[2,3,3,1] f=[2,2] s=1 includeBatchInIndex=true', async () => {
    const x = tf.tensor4d(
        [1, 2, 3, 4, 5, 6, 7, 9, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9], [2, 3, 3, 1]);
    const { result, indexes } = tf.maxPoolWithArgmax(x, 2, 1, 0, true);
    expect(result.shape).toEqual([2, 2, 2, 1]);
    expectArraysClose(await result.data(), [5, 6, 9, 9, 5, 6, 8, 9]);
    expect(indexes.shape).toEqual([2, 2, 2, 1]);
    // Second-batch indexes are offset by one image (3 * 3 * 1 = 9).
    expectArraysClose(await indexes.data(), [4, 5, 7, 7, 13, 14, 16, 17]);
  });

  it('x=[1,3,3,1] f=[2,2] s=1 ignores NaNs', async () => {
    const x = tf.tensor4d([NaN, 1, 2, 3, 4, 5, 6, 7, 9], [1, 3, 3, 1]);
    const { result, indexes } = tf.maxPoolWithArgmax(x, 2, 1, 0);
    expect(result.shape).toEqual([1, 2, 2, 1]);
    // The window containing NaN still yields its largest real value (4).
    expectArraysClose(await result.data(), [4, 5, 7, 9]);
    expect(indexes.shape).toEqual([1, 2, 2, 1]);
    expectArraysClose(await indexes.data(), [4, 5, 7, 8]);
  });

  it('x=[1,3,3,2] f=[2,2] s=1', async () => {
    // Feed forward.
    const x = tf.tensor4d(
        [1, 99, 2, 88, 3, 77, 4, 66, 5, 55, 6, 44, 7, 33, 9, 22, 8, 11],
        [1, 3, 3, 2]);
    const { result, indexes } = tf.maxPoolWithArgmax(x, 2, 1, 0);
    expect(result.shape).toEqual([1, 2, 2, 2]);
    expectArraysClose(await result.data(), [5, 99, 6, 88, 9, 66, 9, 55]);
    expect(indexes.shape).toEqual([1, 2, 2, 2]);
    // Indexes include the channel: even values are channel 0, odd channel 1.
    expectArraysClose(await indexes.data(), [8, 1, 10, 3, 14, 7, 14, 9]);
  });

  it('x=[1,4,4,1] f=[2,2] s=2', async () => {
    // Feed forward.
    const x = tf.tensor4d(
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], [1, 4, 4, 1]);
    const { result, indexes } = tf.maxPoolWithArgmax(x, 2, 2, 0);
    expect(result.shape).toEqual([1, 2, 2, 1]);
    expectArraysClose(await result.data(), [5, 7, 13, 15]);
    expect(indexes.shape).toEqual([1, 2, 2, 1]);
    expectArraysClose(await indexes.data(), [5, 7, 13, 15]);
  });

  it('x=[1,2,2,1] f=[2,2] s=1 p=same', async () => {
    // Feed forward.
    const x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
    const { result, indexes } = tf.maxPoolWithArgmax(x, 2, 1, 'same');
    expect(result.shape).toEqual([1, 2, 2, 1]);
    expectArraysClose(await result.data(), [4, 4, 4, 4]);
    expect(indexes.shape).toEqual([1, 2, 2, 1]);
    expectArraysClose(await indexes.data(), [3, 3, 3, 3]);
  });

  it('throws when x is not rank 4', () => {
    // Deliberately pass a rank-3 tensor; the op must reject it.
    const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 3, 3]);
    expect(() => tf.maxPoolWithArgmax(x, 2, 1, 0)).toThrowError();
  });

  it('throws when passed a non-tensor', () => {
    expect(() => tf.maxPoolWithArgmax({}, 2, 1, 'valid'))
        .toThrowError(
            /Argument 'x' passed to 'maxPoolWithArgmax' must be a Tensor/);
  });

  it('accepts a tensor-like object', async () => {
    const x = [[[[0]]]];  // Nested array of shape [1,1,1,1].
    const { result, indexes } = tf.maxPoolWithArgmax(x, 1, 1, 0);
    expectArraysClose(await result.data(), [0]);
    expect(indexes.shape).toEqual([1, 1, 1, 1]);
    expectArraysClose(await indexes.data(), [0]);
  });
});
129//# sourceMappingURL=data:application/json;base64,
\No newline at end of file