1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 |
|
12 |
|
13 |
|
14 |
|
15 |
|
16 |
|
17 | import * as tf from '../index';
|
18 | import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
|
19 | import { expectArraysClose } from '../test_util';
|
/**
 * Jasmine tests for tf.stridedSlice and Tensor#stridedSlice, registered for
 * every environment in ALL_ENVS.
 *
 * Argument order used throughout (after the input tensor):
 *   begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask,
 *   shrinkAxisMask
 * Bit i of each mask applies to entry i of the begin / end / strides spec
 * (e.g. newAxisMask = 0b1010 inserts new axes at spec entries 1 and 3).
 */
describeWithFlags('stridedSlice', ALL_ENVS, () => {
  it('with ellipsisMask=1', async () => {
    const t = tf.tensor2d([
      [1, 2, 3, 4, 5],
      [2, 3, 4, 5, 6],
      [3, 4, 5, 6, 7],
      [4, 5, 6, 7, 8],
      [5, 6, 7, 8, 9],
      [6, 7, 8, 9, 10],
      [7, 8, 9, 10, 11],
      [8, 8, 9, 10, 11],
      [9, 8, 9, 10, 11],
      [10, 8, 9, 10, 11],
    ]);
    const begin = [0, 4];
    const end = [0, 5];
    const strides = [1, 1];
    const beginMask = 0;
    const endMask = 0;
    const ellipsisMask = 1;
    // Spec entry 0 is the ellipsis (keeps every row); entry 1 picks column 4.
    const output = t.stridedSlice(
      begin, end, strides, beginMask, endMask, ellipsisMask);
    expect(output.shape).toEqual([10, 1]);
    expectArraysClose(
      await output.data(), [5, 6, 7, 8, 9, 10, 11, 11, 11, 11]);
  });

  // Only the shape is asserted: the input comes from tf.randomNormal.
  it('with ellipsisMask=1, begin / end masks and start / end normalization',
    () => {
      const t = tf.randomNormal([1, 6, 2006, 4]);
      const output =
        tf.stridedSlice(t, [0, 0, 0], [0, 2004, 0], [1, 1, 1], 6, 4, 1);
      expect(output.shape).toEqual([1, 6, 2004, 4]);
    });

  it('with ellipsisMask=1 and start / end normalization', async () => {
    const t = tf.tensor3d([
      [[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]
    ]);
    const begin = [1, 0];
    const end = [2, 1];
    const strides = [1, 1];
    const beginMask = 0;
    const endMask = 0;
    const ellipsisMask = 1;
    const output = tf.stridedSlice(
      t, begin, end, strides, beginMask, endMask, ellipsisMask);
    expect(output.shape).toEqual([3, 2, 1]);
    expectArraysClose(await output.data(), [1, 2, 3, 4, 5, 6]);
  });

  it('with ellipsisMask=2', async () => {
    const t = tf.tensor3d([
      [[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]
    ]);
    const begin = [1, 0, 0];
    const end = [2, 1, 3];
    const strides = [1, 1, 1];
    const beginMask = 0;
    const endMask = 0;
    const ellipsisMask = 2;
    const output = tf.stridedSlice(
      t, begin, end, strides, beginMask, endMask, ellipsisMask);
    expect(output.shape).toEqual([1, 2, 3]);
    expectArraysClose(await output.data(), [3, 3, 3, 4, 4, 4]);
  });

  it('with ellipsisMask=2 and start / end normalization', async () => {
    const t = tf.tensor4d([
      [[[1, 1], [1, 1], [1, 1]], [[2, 2], [2, 2], [2, 2]]],
      [[[3, 3], [3, 3], [3, 3]], [[4, 4], [4, 4], [4, 4]]],
      [[[5, 5], [5, 5], [5, 5]], [[6, 6], [6, 6], [6, 6]]]
    ]);
    const begin = [1, 0, 0];
    const end = [2, 1, 1];
    const strides = [1, 1, 1];
    const beginMask = 0;
    const endMask = 0;
    const ellipsisMask = 2;
    const output = tf.stridedSlice(
      t, begin, end, strides, beginMask, endMask, ellipsisMask);
    expect(output.shape).toEqual([1, 2, 3, 1]);
    expectArraysClose(await output.data(), [3, 3, 3, 4, 4, 4]);
  });

  it('stridedSlice should fail if ellipsis mask is set and newAxisMask or ' +
    'shrinkAxisMask are also set', () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    // ellipsisMask=1 together with newAxisMask=1.
    expect(() => tf.stridedSlice(tensor, [0], [3], [2], 0, 0, 1, 1))
      .toThrow();
    // ellipsisMask=1 together with shrinkAxisMask=1.
    expect(() => tf.stridedSlice(tensor, [0], [3], [2], 0, 0, 1, 0, 1))
      .toThrow();
  });

  it('stridedSlice with first axis being new', async () => {
    const t = tf.tensor1d([0, 1, 2, 3]);
    const begin = [0, 0];
    const end = [1, 3];
    const strides = [1, 2];
    const beginMask = 0;
    const endMask = 0;
    const ellipsisMask = 0;
    const newAxisMask = 1;
    const output = tf.stridedSlice(
      t, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask);
    expect(output.shape).toEqual([1, 2]);
    expectArraysClose(await output.data(), [0, 2]);
  });

  it('strided slice with several new axes', async () => {
    const t = tf.zeros([2, 3, 4, 5]);
    const begin = [1, 0, 0, 0, 2];
    const end = [2, 1, 3, 1, 5];
    const strides = null;
    const beginMask = 0;
    const endMask = 0;
    const ellipsisMask = 0;
    const newAxisMask = 0b1010;
    const output = tf.stridedSlice(
      t, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask);
    expect(output.shape).toEqual([1, 1, 3, 1, 2, 5]);
    expectArraysClose(await output.data(), new Array(30).fill(0));
  });

  it('strided slice with new axes and shrink axes', async () => {
    const t = tf.zeros([2, 3, 4, 5]);
    const begin = [1, 0, 1, 0, 2, 2];
    const end = [2, 1, 2, 1, 3, 5];
    const strides = null;
    const beginMask = 0;
    const endMask = 0;
    const ellipsisMask = 0;
    const newAxisMask = 0b1010;
    const shrinkAxisMask = 0b10100;
    const output = tf.stridedSlice(
      t, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask,
      shrinkAxisMask);
    expect(output.shape).toEqual([1, 1, 1, 3]);
    // The input is all zeros, so the 1 * 1 * 1 * 3 result is three zeros.
    expectArraysClose(await output.data(), [0, 0, 0]);
  });

  it('stridedSlice should support 1d tensor', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [0], [3], [2]);
    expect(output.shape).toEqual([2]);
    expectArraysClose(await output.data(), [0, 2]);
  });

  it('stridedSlice with 1d tensor should be used by tensor directly',
    async () => {
      const t = tf.tensor1d([0, 1, 2, 3]);
      const output = t.stridedSlice([0], [3], [2]);
      expect(output.shape).toEqual([2]);
      expectArraysClose(await output.data(), [0, 2]);
    });

  it('stridedSlice should support 1d tensor empty result', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [10], [3], [2]);
    expect(output.shape).toEqual([0]);
    expectArraysClose(await output.data(), []);
  });

  it('stridedSlice should support 1d tensor negative begin', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [-3], [3], [1]);
    expect(output.shape).toEqual([2]);
    expectArraysClose(await output.data(), [1, 2]);
  });

  it('stridedSlice should support 1d tensor out of range begin', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [-5], [3], [1]);
    expect(output.shape).toEqual([3]);
    expectArraysClose(await output.data(), [0, 1, 2]);
  });

  it('stridedSlice should support 1d tensor negative end', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [1], [-2], [1]);
    expect(output.shape).toEqual([1]);
    expectArraysClose(await output.data(), [1]);
  });

  it('stridedSlice should support 1d tensor out of range end', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [-3], [5], [1]);
    expect(output.shape).toEqual([3]);
    expectArraysClose(await output.data(), [1, 2, 3]);
  });

  it('stridedSlice should support 1d tensor begin mask', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [1], [3], [1], 1);
    expect(output.shape).toEqual([3]);
    expectArraysClose(await output.data(), [0, 1, 2]);
  });

  it('stridedSlice should support 1d tensor negative begin and stride',
    async () => {
      const tensor = tf.tensor1d([0, 1, 2, 3]);
      const output = tf.stridedSlice(tensor, [-2], [-3], [-1]);
      expect(output.shape).toEqual([1]);
      expectArraysClose(await output.data(), [2]);
    });

  it('stridedSlice should support 1d tensor' +
    ' out of range begin and negative stride', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [5], [-2], [-1]);
    expect(output.shape).toEqual([1]);
    expectArraysClose(await output.data(), [3]);
  });

  it('stridedSlice should support 1d tensor negative end and stride',
    async () => {
      const tensor = tf.tensor1d([0, 1, 2, 3]);
      const output = tf.stridedSlice(tensor, [2], [-4], [-1]);
      expect(output.shape).toEqual([2]);
      expectArraysClose(await output.data(), [2, 1]);
    });

  it('stridedSlice should support 1d tensor' +
    ' out of range end and negative stride', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [-3], [-5], [-1]);
    expect(output.shape).toEqual([2]);
    expectArraysClose(await output.data(), [1, 0]);
  });

  it('stridedSlice should support 1d tensor end mask', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [1], [3], [1], 0, 1);
    expect(output.shape).toEqual([3]);
    expectArraysClose(await output.data(), [1, 2, 3]);
  });

  it('stridedSlice should support 1d tensor shrink axis mask', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [1], [3], [1], 0, 0, 0, 0, 1);
    expect(output.shape).toEqual([]);
    expectArraysClose(await output.data(), [1]);
  });

  it('stridedSlice should support 1d tensor negative stride', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [-1], [-4], [-1]);
    expect(output.shape).toEqual([3]);
    expectArraysClose(await output.data(), [3, 2, 1]);
  });

  it('stridedSlice should support 1d tensor even length stride', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [0], [2], [2]);
    expect(output.shape).toEqual([1]);
    expectArraysClose(await output.data(), [0]);
  });

  it('stridedSlice should support 1d tensor odd length stride', async () => {
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [0], [3], [2]);
    expect(output.shape).toEqual([2]);
    expectArraysClose(await output.data(), [0, 2]);
  });

  it('stridedSlice should support 2d tensor identity', async () => {
    const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const output = tf.stridedSlice(tensor, [0, 0], [2, 3], [1, 1]);
    expect(output.shape).toEqual([2, 3]);
    expectArraysClose(await output.data(), [1, 2, 3, 4, 5, 6]);
  });

  it('stridedSlice should support 2d tensor', async () => {
    const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const output = tf.stridedSlice(tensor, [1, 0], [2, 2], [1, 1]);
    expect(output.shape).toEqual([1, 2]);
    expectArraysClose(await output.data(), [4, 5]);
  });

  it('stridedSlice should support 2d tensor strides', async () => {
    const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const output = tf.stridedSlice(tensor, [0, 0], [2, 3], [2, 2]);
    expect(output.shape).toEqual([1, 2]);
    expectArraysClose(await output.data(), [1, 3]);
  });

  it('stridedSlice with 2d tensor should be used by tensor directly',
    async () => {
      const t = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
      const output = t.stridedSlice([1, 0], [2, 2], [1, 1]);
      expect(output.shape).toEqual([1, 2]);
      expectArraysClose(await output.data(), [4, 5]);
    });

  it('stridedSlice should support 2d tensor negative strides', async () => {
    const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const output = tf.stridedSlice(tensor, [1, -1], [2, -4], [2, -1]);
    expect(output.shape).toEqual([1, 3]);
    expectArraysClose(await output.data(), [6, 5, 4]);
  });

  it('stridedSlice should support 2d tensor begin mask', async () => {
    const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const output = tf.stridedSlice(tensor, [1, 0], [2, 2], [1, 1], 1);
    expect(output.shape).toEqual([2, 2]);
    expectArraysClose(await output.data(), [1, 2, 4, 5]);
  });

  it('stridedSlice should support 2d tensor shrink mask', async () => {
    const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const output =
      tf.stridedSlice(tensor, [1, 0], [2, 2], [1, 1], 0, 0, 0, 0, 1);
    expect(output.shape).toEqual([2]);
    expectArraysClose(await output.data(), [4, 5]);
  });

  it('stridedSlice should support 2d tensor end mask', async () => {
    const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const output = tf.stridedSlice(tensor, [1, 0], [2, 2], [1, 1], 0, 2);
    expect(output.shape).toEqual([1, 3]);
    expectArraysClose(await output.data(), [4, 5, 6]);
  });

  it('stridedSlice should support 2d tensor' +
    ' negative strides and begin mask', async () => {
    const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const output = tf.stridedSlice(tensor, [1, -2], [2, -4], [1, -1], 2);
    expect(output.shape).toEqual([1, 3]);
    expectArraysClose(await output.data(), [6, 5, 4]);
  });

  it('stridedSlice should support 2d tensor' +
    ' negative strides and end mask', async () => {
    const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    const output = tf.stridedSlice(tensor, [1, -2], [2, -3], [1, -1], 0, 2);
    expect(output.shape).toEqual([1, 2]);
    expectArraysClose(await output.data(), [5, 4]);
  });

  it('stridedSlice should support 3d tensor identity', async () => {
    const tensor =
      tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 2]);
    const output = tf.stridedSlice(tensor, [0, 0, 0], [2, 3, 2], [1, 1, 1]);
    expect(output.shape).toEqual([2, 3, 2]);
    expectArraysClose(
      await output.data(), [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
  });

  it('stridedSlice should support 3d tensor negative stride', async () => {
    const tensor =
      tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 2]);
    const output =
      tf.stridedSlice(tensor, [-1, -1, -1], [-3, -4, -3], [-1, -1, -1]);
    expect(output.shape).toEqual([2, 3, 2]);
    expectArraysClose(
      await output.data(), [12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]);
  });

  it('stridedSlice should support 3d tensor strided 2', async () => {
    const tensor =
      tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 2]);
    const output = tf.stridedSlice(tensor, [0, 0, 0], [2, 3, 2], [2, 2, 2]);
    expect(output.shape).toEqual([1, 2, 1]);
    expectArraysClose(await output.data(), [1, 5]);
  });

  it('stridedSlice should support 3d tensor shrink mask', async () => {
    const tensor =
      tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 2]);
    const output = tf.stridedSlice(
      tensor, [0, 0, 0], [2, 3, 2], [1, 1, 1], 0, 0, 0, 0, 1);
    expect(output.shape).toEqual([3, 2]);
    expectArraysClose(await output.data(), [1, 2, 3, 4, 5, 6]);
  });

  // The next three tests pass a begin / end / strides spec shorter than the
  // rank of the (4d) input and expect the remaining dimensions to be kept.
  it('stridedSlice should support 4d with smaller length of begin array',
    async () => {
      const tensor = tf.tensor4d(
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 1, 2]);
      const output = tf.stridedSlice(
        tensor, [1, 0], [2, 3, 1, 2], [1, 1, 1, 1], 0, 0, 0, 0, 0);
      expect(output.shape).toEqual([1, 3, 1, 2]);
      expectArraysClose(await output.data(), [7, 8, 9, 10, 11, 12]);
    });

  it('stridedSlice should support 4d with smaller length of end array',
    async () => {
      const tensor = tf.tensor4d(
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 1, 2]);
      const output = tf.stridedSlice(
        tensor, [1, 0, 0, 0], [2, 3], [1, 1, 1, 1], 0, 0, 0, 0, 0);
      expect(output.shape).toEqual([1, 3, 1, 2]);
      expectArraysClose(await output.data(), [7, 8, 9, 10, 11, 12]);
    });

  it('stridedSlice should support 4d with smaller length of stride array',
    async () => {
      const tensor = tf.tensor4d(
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 1, 2]);
      const output = tf.stridedSlice(
        tensor, [1, 0, 0, 0], [2, 3, 1, 2], [1, 1], 0, 0, 0, 0, 0);
      expect(output.shape).toEqual([1, 3, 1, 2]);
      expectArraysClose(await output.data(), [7, 8, 9, 10, 11, 12]);
    });

  it('stridedSlice should throw when passed a non-tensor', () => {
    expect(() => tf.stridedSlice({}, [0], [0], [1]))
      .toThrowError(/Argument 'x' passed to 'stridedSlice' must be a Tensor/);
  });

  it('stridedSlice should handle negative end with ellipsisMask', () => {
    const a = tf.ones([1, 240, 1, 10]);
    // beginMask=3, endMask=1, ellipsisMask=4 (spec entry 2 is the ellipsis).
    const output =
      tf.stridedSlice(a, [0, 0, 0], [0, -1, 0], [1, 1, 1], 3, 1, 4);
    expect(output.shape).toEqual([1, 239, 1, 10]);
  });

  it('accepts a tensor-like object', async () => {
    const tensor = [0, 1, 2, 3];
    const output = tf.stridedSlice(tensor, [0], [3], [2]);
    expect(output.shape).toEqual([2]);
    expectArraysClose(await output.data(), [0, 2]);
  });

  it('ensure no memory leak', async () => {
    const numTensorsBefore = tf.memory().numTensors;
    const numDataIdBefore = tf.engine().backend.numDataIds();
    const tensor = tf.tensor1d([0, 1, 2, 3]);
    const output = tf.stridedSlice(tensor, [0], [3], [2]);
    expect(output.shape).toEqual([2]);
    expectArraysClose(await output.data(), [0, 2]);
    tensor.dispose();
    output.dispose();
    const numTensorsAfter = tf.memory().numTensors;
    const numDataIdAfter = tf.engine().backend.numDataIds();
    expect(numTensorsAfter).toBe(numTensorsBefore);
    expect(numDataIdAfter).toBe(numDataIdBefore);
  });
});
|
387 |
|
\ | No newline at end of file |