// Retrieved from UNPKG (raw view, 45 kB JavaScript).
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import * as tf from '../index';
import { ALL_ENVS, describeWithFlags, SYNC_BACKEND_ENVS } from '../jasmine_util';
import { encodeStrings, expectArraysClose, expectArraysEqual } from '../test_util';
// Unit tests for tf.slice / Tensor.prototype.slice.
// The expected shapes below pin the argument semantics:
//  - a missing `size` (or a size entry of -1) runs to the end of that axis;
//  - a scalar or too-short `begin`/`size` applies to the leading axes, and the
//    remaining axes start at 0 and are kept whole.
describeWithFlags('slice ', ALL_ENVS, () => {
  describeWithFlags('ergonomics', ALL_ENVS, () => {
    it('slices 2x2x2 array into 2x1x1 no size', async () => {
      const a = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]);
      // No size: every axis runs from `begin` to its end -> shape [2, 1, 1].
      const result = a.slice([0, 1, 1]);
      expect(result.shape).toEqual([2, 1, 1]);
      expectArraysClose(await result.data(), [4, 8]);
    });

    it('slices 2x2x2 array into 1x2x2 with scalar begin no size', async () => {
      const a = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]);
      // A scalar begin applies to axis 0 only; the other axes are kept whole.
      const result = a.slice(1);
      expect(result.shape).toEqual([1, 2, 2]);
      expectArraysClose(await result.data(), [5, 6, 7, 8]);
    });

    it('slices 2x2x2 array using 2d begin and no size', async () => {
      const a = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]);
      // `begin` is shorter than the rank: the trailing axis starts at 0 and is
      // kept whole.
      const result = a.slice([0, 1]);
      expect(result.shape).toEqual([2, 1, 2]);
      expectArraysClose(await result.data(), [3, 4, 7, 8]);
    });

    it('slices 2x2x2 array using negative size', async () => {
      const a = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]);
      // A size of -1 on axis 0 means "through the end of that axis".
      const result = a.slice([0, 1], [-1, 1]);
      expect(result.shape).toEqual([2, 1, 2]);
      expectArraysClose(await result.data(), [3, 4, 7, 8]);
    });

    it('slices 2x2x2 array using 1d size', async () => {
      const a = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]);
      // Scalar begin and scalar size both apply to axis 0.
      const result = a.slice(0, 1);
      expect(result.shape).toEqual([1, 2, 2]);
      expectArraysClose(await result.data(), [1, 2, 3, 4]);
    });

    it('throws when passed a non-tensor', () => {
      expect(() => tf.slice({}, 0, 0))
          .toThrowError(/Argument 'x' passed to 'slice' must be a Tensor/);
    });

    it('accepts a tensor-like object', async () => {
      const a = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]];  // 2x2x2
      const result = tf.slice(a, [0, 1, 1]);
      expect(result.shape).toEqual([2, 1, 1]);
      expectArraysClose(await result.data(), [4, 8]);
    });

    it('should match source tensor dtype', () => {
      const a = tf.tensor1d([1, 2, 3, 4, 5], 'int32');
      const b = a.asType('float32');
      expect(tf.slice(b, 0).dtype).toEqual('float32');
    });

    // Synchronous on purpose: nothing is awaited, the check is a thrown error.
    it('throws when begin is negative', () => {
      const a = [[1, 2], [3, 4]];  // 2x2
      expect(() => tf.slice(a, [-1, 1], [1, 1]))
          .toThrowError(/slice\(\) does not support negative begin indexing./);
    });
  });

  // Tensor bookkeeping around slices. The describe name suggests a slice may
  // share storage with its input (assumption, not verified here). Either way,
  // reading a slice must not leave extra tensors alive.
  describeWithFlags('shallow slicing', ALL_ENVS, () => {
    it('shallow slice an input that was cast', async () => {
      const a = tf.tensor([[1, 2], [3, 4]], [2, 2], 'int32');
      const b = a.toFloat();
      const c = b.slice(1, 1);
      expect(c.dtype).toBe('float32');
      expect(c.shape).toEqual([1, 2]);
      expectArraysClose(await c.data(), [3, 4]);
    });

    it('delayed async read of sliced tensor has no mem leak', async () => {
      const a = tf.zeros([10]);
      const b = tf.slice(a, 0, 1);
      // Only the source and the slice should be tracked before the read...
      const nBefore = tf.memory().numTensors;
      expect(nBefore).toBe(2);
      await b.data();
      // ...and the read must not register any additional tensors.
      const nAfter = tf.memory().numTensors;
      expect(nAfter).toBe(2);
      tf.dispose([a, b]);
      expect(tf.memory().numTensors).toBe(0);
    });
  });

  // Same leak check via dataSync(), restricted to SYNC_BACKEND_ENVS.
  describeWithFlags('shallow slicing', SYNC_BACKEND_ENVS, () => {
    it('delayed sync read of sliced tensor has no mem leak', () => {
      const a = tf.zeros([10]);
      const b = tf.slice(a, 0, 1);
      const nBefore = tf.memory().numTensors;
      expect(nBefore).toBe(2);
      b.dataSync();
      const nAfter = tf.memory().numTensors;
      expect(nAfter).toBe(2);
      tf.dispose([a, b]);
      expect(tf.memory().numTensors).toBe(0);
    });
  });

  describeWithFlags('slice5d', ALL_ENVS, () => {
    it('slices 1x1x1x1x1 into shape 1x1x1x1x1 (effectively a copy)', async () => {
      const a = tf.tensor5d([[[[[5]]]]], [1, 1, 1, 1, 1]);
      const result = tf.slice(a, [0, 0, 0, 0, 0], [1, 1, 1, 1, 1]);
      expect(result.shape).toEqual([1, 1, 1, 1, 1]);
      expectArraysClose(await result.data(), [5]);
    });

    it('slices 2x2x2x2x2 array into 1x2x2x2x2 starting at [1,0,0,0,0]', async () => {
      const a = tf.tensor5d([
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
        12, 13, 14, 15, 16, 11, 22, 33, 44, 55, 66,
        77, 88, 111, 222, 333, 444, 555, 666, 777, 888
      ], [2, 2, 2, 2, 2]);
      // Axis-0 index 1 is the second half of the flat data (indices 16..31).
      const result = tf.slice(a, [1, 0, 0, 0, 0], [1, 2, 2, 2, 2]);
      expect(result.shape).toEqual([1, 2, 2, 2, 2]);
      expectArraysClose(await result.data(), [
        11, 22, 33, 44, 55, 66, 77, 88, 111, 222, 333, 444, 555, 666, 777,
        888
      ]);
    });

    it('slices 2x2x2x2x2 array into 2x1x1x1x1 starting at [0,1,1,1,1]', async () => {
      const a = tf.tensor5d([
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
        12, 13, 14, 15, 16, 11, 22, 33, 44, 55, 66,
        77, 88, 111, 222, 333, 444, 555, 666, 777, 888
      ], [2, 2, 2, 2, 2]);
      // The last element of each axis-0 half: flat indices 15 and 31.
      const result = tf.slice(a, [0, 1, 1, 1, 1], [2, 1, 1, 1, 1]);
      expect(result.shape).toEqual([2, 1, 1, 1, 1]);
      expectArraysClose(await result.data(), [16, 888]);
    });

    it('accepts a tensor-like object', async () => {
      const a = [[[[[5]]]]];  // 1x1x1x1x1
      const result = tf.slice(a, [0, 0, 0, 0, 0], [1, 1, 1, 1, 1]);
      expect(result.shape).toEqual([1, 1, 1, 1, 1]);
      expectArraysClose(await result.data(), [5]);
    });
  });

  describeWithFlags('slice6d', ALL_ENVS, () => {
    it('slices 1x1x1x1x1x1 into shape 1x1x1x1x1x1 (effectively a copy)', async () => {
      const a = tf.tensor6d([[[[[[5]]]]]], [1, 1, 1, 1, 1, 1]);
      const result = tf.slice(a, [0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]);
      expect(result.shape).toEqual([1, 1, 1, 1, 1, 1]);
      expectArraysClose(await result.data(), [5]);
    });

    it('slices 2x2x2x2x2x2 array into 1x2x2x2x2x2 starting at [1,0,0,0,0,0]', async () => {
      const a = tf.tensor6d([
        31, 32, 33, 34, 35, 36, 37, 38, 39, 310, 311,
        312, 313, 314, 315, 316, 311, 322, 333, 344, 355, 366,
        377, 388, 3111, 3222, 3333, 3444, 3555, 3666, 3777, 3888,
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
        12, 13, 14, 15, 16, 11, 22, 33, 44, 55, 66,
        77, 88, 111, 222, 333, 444, 555, 666, 777, 888
      ], [2, 2, 2, 2, 2, 2]);
      // Axis-0 index 1 is the second half of the flat data (indices 32..63).
      const result = tf.slice(a, [1, 0, 0, 0, 0, 0], [1, 2, 2, 2, 2, 2]);
      expect(result.shape).toEqual([1, 2, 2, 2, 2, 2]);
      expectArraysClose(await result.data(), [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
        12, 13, 14, 15, 16, 11, 22, 33, 44, 55, 66,
        77, 88, 111, 222, 333, 444, 555, 666, 777, 888
      ]);
    });

    it('slices 2x2x2x2x2x2 array into 2x1x1x1x1x1 starting at [0,1,1,1,1,1]', async () => {
      const a = tf.tensor6d([
        31, 32, 33, 34, 35, 36, 37, 38, 39, 310, 311,
        312, 313, 314, 315, 316, 311, 322, 333, 344, 355, 366,
        377, 388, 3111, 3222, 3333, 3444, 3555, 3666, 3777, 3888,
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
        12, 13, 14, 15, 16, 11, 22, 33, 44, 55, 66,
        77, 88, 111, 222, 333, 444, 555, 666, 777, 888
      ], [2, 2, 2, 2, 2, 2]);
      // The last element of each axis-0 half: flat indices 31 and 63.
      const result = tf.slice(a, [0, 1, 1, 1, 1, 1], [2, 1, 1, 1, 1, 1]);
      expect(result.shape).toEqual([2, 1, 1, 1, 1, 1]);
      expectArraysClose(await result.data(), [3888, 888]);
    });

    it('accepts a tensor-like object', async () => {
      const a = [[[[[[5]]]]]];  // 1x1x1x1x1x1
      const result = tf.slice(a, [0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]);
      expect(result.shape).toEqual([1, 1, 1, 1, 1, 1]);
      expectArraysClose(await result.data(), [5]);
    });
  });

  // String results are compared with expectArraysEqual. expectArraysClose
  // applies a numeric tolerance check, which cannot tell distinct non-numeric
  // strings apart, so it would not catch a slice that returned wrong elements.
  describeWithFlags('accepts string', ALL_ENVS, () => {
    it('slices 2x2x2 array into 2x1x1 no size.', async () => {
      const a = tf.tensor3d(
          ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight'],
          [2, 2, 2], 'string');
      const result = a.slice([0, 1, 1]);
      expect(result.shape).toEqual([2, 1, 1]);
      expectArraysEqual(await result.data(), ['four', 'eight']);
    });

    it('slices 2x2x2 array into 1x2x2 with scalar begin no size.', async () => {
      // No dtype argument: relies on dtype inference from the string values.
      const a = tf.tensor3d(
          ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight'],
          [2, 2, 2]);
      const result = a.slice(1);
      expect(result.shape).toEqual([1, 2, 2]);
      expectArraysEqual(await result.data(), ['five', 'six', 'seven', 'eight']);
    });

    it('slice encoded string.', async () => {
      // Input is given as encoded bytes; data() is expected to yield the
      // decoded strings.
      const bytes = encodeStrings([
        'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight'
      ]);
      const a = tf.tensor3d(bytes, [2, 2, 2], 'string');
      const result = a.slice([0, 1, 1]);
      expect(result.shape).toEqual([2, 1, 1]);
      expectArraysEqual(await result.data(), ['four', 'eight']);
    });
  });
});
//# sourceMappingURL=data:application/json;base64,