UNPKG

19.1 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2017 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
17import * as tf from '../index';
18import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
19import { expectArraysClose } from '../test_util';
/**
 * Tests for `tf.stridedSlice` and the chained `Tensor.stridedSlice` method,
 * registered with `describeWithFlags` under the `ALL_ENVS` flag set.
 *
 * The mask arguments are bit fields indexed by position in the
 * `begin` / `end` / `strides` spec, as the expectations below show:
 * - `beginMask = 1` makes axis 0 start from its first element, ignoring
 *   `begin[0]`.
 * - `ellipsisMask = 2` turns spec entry 1 into an ellipsis that spans every
 *   axis not covered by the other entries.
 */
describeWithFlags('stridedSlice', ALL_ENVS, () => {
  it('with ellipsisMask=1', async () => {
    const t = tf.tensor2d([
      [1, 2, 3, 4, 5],
      [2, 3, 4, 5, 6],
      [3, 4, 5, 6, 7],
      [4, 5, 6, 7, 8],
      [5, 6, 7, 8, 9],
      [6, 7, 8, 9, 10],
      [7, 8, 9, 10, 11],
      [8, 8, 9, 10, 11],
      [9, 8, 9, 10, 11],
      [10, 8, 9, 10, 11],
    ]);
    const begin = [0, 4];
    const end = [0, 5];
    const strides = [1, 1];
    const beginMask = 0;
    const endMask = 0;
    const ellipsisMask = 1;
    // Spec entry 0 is the ellipsis, so every row is kept (begin[0] / end[0]
    // are ignored); entry 1 selects column 4 of the last axis.
    const output = t.stridedSlice(
        begin, end, strides, beginMask, endMask, ellipsisMask);
    expect(output.shape).toEqual([10, 1]);
    expectArraysClose(
        await output.data(), [5, 6, 7, 8, 9, 10, 11, 11, 11, 11]);
  });

  it('with ellipsisMask=1, begin / end masks and start / end normalization',
     async () => {
       const t = tf.randomNormal([1, 6, 2006, 4]);
       // beginMask 6 = bits 1 and 2, endMask 4 = bit 2, ellipsis at entry 0.
       const output =
           tf.stridedSlice(t, [0, 0, 0], [0, 2004, 0], [1, 1, 1], 6, 4, 1);
       expect(output.shape).toEqual([1, 6, 2004, 4]);
     });

  it('with ellipsisMask=1 and start / end normalization', async () => {
    const t = tf.tensor3d([
      [[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]
    ]);
    const begin = [1, 0];
    const end = [2, 1];
    const strides = [1, 1];
    const beginMask = 0;
    const endMask = 0;
    const ellipsisMask = 1;
    const output = tf.stridedSlice(
        t, begin, end, strides, beginMask, endMask, ellipsisMask);
    expect(output.shape).toEqual([3, 2, 1]);
    expectArraysClose(await output.data(), [1, 2, 3, 4, 5, 6]);
  });

  it('with ellipsisMask=2', async () => {
    const t = tf.tensor3d([
      [[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]
    ]);
    const begin = [1, 0, 0];
    const end = [2, 1, 3];
    const strides = [1, 1, 1];
    const beginMask = 0;
    const endMask = 0;
    const ellipsisMask = 2;
    const output = tf.stridedSlice(
        t, begin, end, strides, beginMask, endMask, ellipsisMask);
    expect(output.shape).toEqual([1, 2, 3]);
    expectArraysClose(await output.data(), [3, 3, 3, 4, 4, 4]);
  });

  it('with ellipsisMask=2 and start / end normalization', async () => {
    const t = tf.tensor4d([
      [[[1, 1], [1, 1], [1, 1]], [[2, 2], [2, 2], [2, 2]]],
      [[[3, 3], [3, 3], [3, 3]], [[4, 4], [4, 4], [4, 4]]],
      [[[5, 5], [5, 5], [5, 5]], [[6, 6], [6, 6], [6, 6]]]
    ]);
    const begin = [1, 0, 0];
    const end = [2, 1, 1];
    const strides = [1, 1, 1];
    const beginMask = 0;
    const endMask = 0;
    const ellipsisMask = 2;
    const output = tf.stridedSlice(
        t, begin, end, strides, beginMask, endMask, ellipsisMask);
    expect(output.shape).toEqual([1, 2, 3, 1]);
    expectArraysClose(await output.data(), [3, 3, 3, 4, 4, 4]);
  });

  it('stridedSlice should fail if ellipsis mask is set and newAxisMask or ' +
         'shrinkAxisMask are also set',
     async () => {
       const tensor = tf.tensor1d([0, 1, 2, 3]);
       // ellipsisMask = 1 combined with newAxisMask = 1.
       expect(() => tf.stridedSlice(tensor, [0], [3], [2], 0, 0, 1, 1))
           .toThrow();
       // ellipsisMask = 1 combined with shrinkAxisMask = 1.
       expect(() => tf.stridedSlice(tensor, [0], [3], [2], 0, 0, 1, 0, 1))
           .toThrow();
     });

  it('stridedSlice with first axis being new', async () => {
    // Python slice code: t[tf.newaxis,0:3]
    const t = tf.tensor1d([0, 1, 2, 3]);
    const begin = [0, 0];
    const end = [1, 3];
    const strides = [1, 2];
    const beginMask = 0;
    const endMask = 0;
    const ellipsisMask = 0;
    const newAxisMask = 1;
    const output = tf.stridedSlice(
        t, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask);
    expect(output.shape).toEqual([1, 2]);
    expectArraysClose(await output.data(), [0, 2]);
  });

  it('strided slice with several new axes', async () => {
    // Python slice code: t[1:2,tf.newaxis,0:3,tf.newaxis,2:5]
    const t = tf.zeros([2, 3, 4, 5]);
    const begin = [1, 0, 0, 0, 2];
    const end = [2, 1, 3, 1, 5];
    const strides = null;
    const beginMask = 0;
    const endMask = 0;
    const ellipsisMask = 0;
    const newAxisMask = 0b1010;
    const output = tf.stridedSlice(
        t, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask);
    expect(output.shape).toEqual([1, 1, 3, 1, 2, 5]);
    expectArraysClose(await output.data(), new Array(30).fill(0));
  });

  it('strided slice with new axes and shrink axes', async () => {
    // Python slice code: t[1:2,tf.newaxis,1,tf.newaxis,2,2:5]
    const t = tf.zeros([2, 3, 4, 5]);
    const begin = [1, 0, 1, 0, 2, 2];
    const end = [2, 1, 2, 1, 3, 5];
    const strides = null;
    const beginMask = 0;
    const endMask = 0;
    const ellipsisMask = 0;
    const newAxisMask = 0b1010;
    const shrinkAxisMask = 0b10100;
    const output = tf.stridedSlice(
        t, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask,
        shrinkAxisMask);
    expect(output.shape).toEqual([1, 1, 1, 3]);
    // The input is all zeros, so all 1 * 1 * 1 * 3 output values are zero.
    expectArraysClose(await output.data(), [0, 0, 0]);
  });

  it('stridedSlice should support 1d tensor', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [0], [3], [2]);
    expect(output.shape).toEqual([2]);
    expectArraysClose(await output.data(), [0, 2]);
  });

  it('stridedSlice with 1d tensor should be used by tensor directly',
     async () => {
       const t = tf.tensor1d([0, 1, 2, 3]);
       const output = t.stridedSlice([0], [3], [2]);
       expect(output.shape).toEqual([2]);
       expectArraysClose(await output.data(), [0, 2]);
     });

  it('stridedSlice should support 1d tensor empty result', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [10], [3], [2]);
    expect(output.shape).toEqual([0]);
    expectArraysClose(await output.data(), []);
  });

  it('stridedSlice should support 1d tensor negative begin', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [-3], [3], [1]);
    expect(output.shape).toEqual([2]);
    expectArraysClose(await output.data(), [1, 2]);
  });

  it('stridedSlice should support 1d tensor out of range begin', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [-5], [3], [1]);
    expect(output.shape).toEqual([3]);
    expectArraysClose(await output.data(), [0, 1, 2]);
  });

  it('stridedSlice should support 1d tensor negative end', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [1], [-2], [1]);
    expect(output.shape).toEqual([1]);
    expectArraysClose(await output.data(), [1]);
  });

  it('stridedSlice should support 1d tensor out of range end', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [-3], [5], [1]);
    expect(output.shape).toEqual([3]);
    expectArraysClose(await output.data(), [1, 2, 3]);
  });

  it('stridedSlice should support 1d tensor begin mask', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [1], [3], [1], 1);
    expect(output.shape).toEqual([3]);
    expectArraysClose(await output.data(), [0, 1, 2]);
  });

  it('stridedSlice should support 1d tensor negative begin and stride',
     async () => {
       const tensor = tf.tensor1d([0, 1, 2, 3]);
       const output = tf.stridedSlice(tensor, [-2], [-3], [-1]);
       expect(output.shape).toEqual([1]);
       expectArraysClose(await output.data(), [2]);
     });

  it('stridedSlice should support 1d tensor' +
         ' out of range begin and negative stride',
     async () => {
       const tensor = tf.tensor1d([0, 1, 2, 3]);
       const output = tf.stridedSlice(tensor, [5], [-2], [-1]);
       expect(output.shape).toEqual([1]);
       expectArraysClose(await output.data(), [3]);
     });

  it('stridedSlice should support 1d tensor negative end and stride',
     async () => {
       const tensor = tf.tensor1d([0, 1, 2, 3]);
       const output = tf.stridedSlice(tensor, [2], [-4], [-1]);
       expect(output.shape).toEqual([2]);
       expectArraysClose(await output.data(), [2, 1]);
     });

  it('stridedSlice should support 1d tensor' +
         ' out of range end and negative stride',
     async () => {
       const tensor = tf.tensor1d([0, 1, 2, 3]);
       const output = tf.stridedSlice(tensor, [-3], [-5], [-1]);
       expect(output.shape).toEqual([2]);
       expectArraysClose(await output.data(), [1, 0]);
     });

  it('stridedSlice should support 1d tensor end mask', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [1], [3], [1], 0, 1);
    expect(output.shape).toEqual([3]);
    expectArraysClose(await output.data(), [1, 2, 3]);
  });

  it('stridedSlice should support 1d tensor shrink axis mask', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [1], [3], [1], 0, 0, 0, 0, 1);
    expect(output.shape).toEqual([]);
    expectArraysClose(await output.data(), [1]);
  });

  it('stridedSlice should support 1d tensor negative stride', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [-1], [-4], [-1]);
    expect(output.shape).toEqual([3]);
    expectArraysClose(await output.data(), [3, 2, 1]);
  });

  it('stridedSlice should support 1d tensor even length stride', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [0], [2], [2]);
    expect(output.shape).toEqual([1]);
    expectArraysClose(await output.data(), [0]);
  });

  it('stridedSlice should support 1d tensor odd length stride', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [0], [3], [2]);
    expect(output.shape).toEqual([2]);
    expectArraysClose(await output.data(), [0, 2]);
  });

  it('stridedSlice should support 2d tensor identity', async () => {
    const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const output = tf.stridedSlice(tensor, [0, 0], [2, 3], [1, 1]);
    expect(output.shape).toEqual([2, 3]);
    expectArraysClose(await output.data(), [1, 2, 3, 4, 5, 6]);
  });

  it('stridedSlice should support 2d tensor', async () => {
    const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const output = tf.stridedSlice(tensor, [1, 0], [2, 2], [1, 1]);
    expect(output.shape).toEqual([1, 2]);
    expectArraysClose(await output.data(), [4, 5]);
  });

  it('stridedSlice should support 2d tensor strides', async () => {
    const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const output = tf.stridedSlice(tensor, [0, 0], [2, 3], [2, 2]);
    expect(output.shape).toEqual([1, 2]);
    expectArraysClose(await output.data(), [1, 3]);
  });

  it('stridedSlice with 2d tensor should be used by tensor directly',
     async () => {
       const t = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
       const output = t.stridedSlice([1, 0], [2, 2], [1, 1]);
       expect(output.shape).toEqual([1, 2]);
       expectArraysClose(await output.data(), [4, 5]);
     });

  it('stridedSlice should support 2d tensor negative strides', async () => {
    const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const output = tf.stridedSlice(tensor, [1, -1], [2, -4], [2, -1]);
    expect(output.shape).toEqual([1, 3]);
    expectArraysClose(await output.data(), [6, 5, 4]);
  });

  it('stridedSlice should support 2d tensor begin mask', async () => {
    const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const output = tf.stridedSlice(tensor, [1, 0], [2, 2], [1, 1], 1);
    expect(output.shape).toEqual([2, 2]);
    expectArraysClose(await output.data(), [1, 2, 4, 5]);
  });

  it('stridedSlice should support 2d tensor shrink mask', async () => {
    const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const output =
        tf.stridedSlice(tensor, [1, 0], [2, 2], [1, 1], 0, 0, 0, 0, 1);
    expect(output.shape).toEqual([2]);
    expectArraysClose(await output.data(), [4, 5]);
  });

  it('stridedSlice should support 2d tensor end mask', async () => {
    const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const output = tf.stridedSlice(tensor, [1, 0], [2, 2], [1, 1], 0, 2);
    expect(output.shape).toEqual([1, 3]);
    expectArraysClose(await output.data(), [4, 5, 6]);
  });

  it('stridedSlice should support 2d tensor' +
         ' negative strides and begin mask',
     async () => {
       const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
       const output = tf.stridedSlice(tensor, [1, -2], [2, -4], [1, -1], 2);
       expect(output.shape).toEqual([1, 3]);
       expectArraysClose(await output.data(), [6, 5, 4]);
     });

  it('stridedSlice should support 2d tensor' +
         ' negative strides and end mask',
     async () => {
       const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
       const output =
           tf.stridedSlice(tensor, [1, -2], [2, -3], [1, -1], 0, 2);
       expect(output.shape).toEqual([1, 2]);
       expectArraysClose(await output.data(), [5, 4]);
     });

  it('stridedSlice should support 3d tensor identity', async () => {
    const tensor =
        tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 2]);
    const output = tf.stridedSlice(tensor, [0, 0, 0], [2, 3, 2], [1, 1, 1]);
    expect(output.shape).toEqual([2, 3, 2]);
    expectArraysClose(
        await output.data(), [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
  });

  it('stridedSlice should support 3d tensor negative stride', async () => {
    const tensor =
        tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 2]);
    const output = tf.stridedSlice(
        tensor, [-1, -1, -1], [-3, -4, -3], [-1, -1, -1]);
    expect(output.shape).toEqual([2, 3, 2]);
    expectArraysClose(
        await output.data(), [12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]);
  });

  it('stridedSlice should support 3d tensor strided 2', async () => {
    const tensor =
        tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 2]);
    const output = tf.stridedSlice(tensor, [0, 0, 0], [2, 3, 2], [2, 2, 2]);
    expect(output.shape).toEqual([1, 2, 1]);
    expectArraysClose(await output.data(), [1, 5]);
  });

  it('stridedSlice should support 3d tensor shrink mask', async () => {
    const tensor =
        tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 2]);
    const output = tf.stridedSlice(
        tensor, [0, 0, 0], [2, 3, 2], [1, 1, 1], 0, 0, 0, 0, 1);
    expect(output.shape).toEqual([3, 2]);
    expectArraysClose(await output.data(), [1, 2, 3, 4, 5, 6]);
  });

  it('stridedSlice should support 4d with smaller length of begin array',
     async () => {
       const tensor = tf.tensor4d(
           [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 1, 2]);
       // `begin` has only 2 entries for a rank-4 input.
       const output = tf.stridedSlice(
           tensor, [1, 0], [2, 3, 1, 2], [1, 1, 1, 1], 0, 0, 0, 0, 0);
       expect(output.shape).toEqual([1, 3, 1, 2]);
       expectArraysClose(await output.data(), [7, 8, 9, 10, 11, 12]);
     });

  it('stridedSlice should support 4d with smaller length of end array',
     async () => {
       const tensor = tf.tensor4d(
           [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 1, 2]);
       // `end` has only 2 entries for a rank-4 input.
       const output = tf.stridedSlice(
           tensor, [1, 0, 0, 0], [2, 3], [1, 1, 1, 1], 0, 0, 0, 0, 0);
       expect(output.shape).toEqual([1, 3, 1, 2]);
       expectArraysClose(await output.data(), [7, 8, 9, 10, 11, 12]);
     });

  it('stridedSlice should support 4d with smaller length of stride array',
     async () => {
       const tensor = tf.tensor4d(
           [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 1, 2]);
       // `strides` has only 2 entries for a rank-4 input.
       const output = tf.stridedSlice(
           tensor, [1, 0, 0, 0], [2, 3, 1, 2], [1, 1], 0, 0, 0, 0, 0);
       expect(output.shape).toEqual([1, 3, 1, 2]);
       expectArraysClose(await output.data(), [7, 8, 9, 10, 11, 12]);
     });

  it('stridedSlice should throw when passed a non-tensor', () => {
    expect(() => tf.stridedSlice({}, [0], [0], [1]))
        .toThrowError(/Argument 'x' passed to 'stridedSlice' must be a Tensor/);
  });

  it('stridedSlice should handle negative end with ellipsisMask', async () => {
    const a = tf.ones([1, 240, 1, 10]);
    // beginMask 3 = bits 0 and 1, endMask 1 = bit 0, ellipsis at entry 2.
    const output =
        tf.stridedSlice(a, [0, 0, 0], [0, -1, 0], [1, 1, 1], 3, 1, 4);
    expect(output.shape).toEqual([1, 239, 1, 10]);
    // The input is all ones, so every one of the 239 * 10 outputs is 1.
    expectArraysClose(await output.data(), new Array(239 * 10).fill(1));
  });

  it('accepts a tensor-like object', async () => {
    const tensor = [0, 1, 2, 3];
    const output = tf.stridedSlice(tensor, [0], [3], [2]);
    expect(output.shape).toEqual([2]);
    expectArraysClose(await output.data(), [0, 2]);
  });

  it('ensure no memory leak', async () => {
    const numTensorsBefore = tf.memory().numTensors;
    const numDataIdBefore = tf.engine().backend.numDataIds();
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [0], [3], [2]);
    expect(output.shape).toEqual([2]);
    expectArraysClose(await output.data(), [0, 2]);
    tensor.dispose();
    output.dispose();
    const numTensorsAfter = tf.memory().numTensors;
    const numDataIdAfter = tf.engine().backend.numDataIds();
    expect(numTensorsAfter).toBe(numTensorsBefore);
    expect(numDataIdAfter).toBe(numDataIdBefore);
  });
});
//# sourceMappingURL=strided_slice_test.js.map
\No newline at end of file