UNPKG

16.8 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2018 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17import * as tf from '../index';
18import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
19import { expectArraysClose } from '../test_util';
20import { sizeFromShape } from '../util';
/**
 * Produces deterministic input and filter values for forward-pass tests.
 * Every value is a small fraction in (0, 1], which keeps the convolution
 * sums small and avoids overflows.
 *
 * @param {number} totalSizeTensor Number of input values to generate.
 * @param {number} totalSizeFilter Number of filter values to generate.
 * @returns {{input: number[], filter: number[]}} `input[i]` equals
 *     `(i + 1) / totalSizeTensor` and `filter[i]` equals
 *     `(i + 1) / totalSizeFilter`.
 */
function generateCaseInputs(totalSizeTensor, totalSizeFilter) {
  // Yields [1/size, 2/size, ..., size/size].
  const normalizedRamp = (size) => {
    const values = new Array(size);
    for (let step = 1; step <= size; step++) {
      values[step - 1] = step / size;
    }
    return values;
  };
  const input = normalizedRamp(totalSizeTensor);
  const filter = normalizedRamp(totalSizeFilter);
  return { input, filter };
}
/**
 * Produces deterministic input and filter values for gradient tests.
 * Unlike the forward-pass generator, the values are the unscaled integers
 * 1, 2, 3, ... so the expected gradients are whole numbers.
 *
 * @param {number} totalSizeTensor Number of input values to generate.
 * @param {number} totalSizeFilter Number of filter values to generate.
 * @returns {{input: number[], filter: number[]}} `input[i]` and `filter[i]`
 *     both equal `i + 1`.
 */
function generateGradientCaseInputs(totalSizeTensor, totalSizeFilter) {
  // Yields [1, 2, ..., size].
  const countFromOne = (size) => {
    const values = new Array(size);
    let next = 1;
    while (next <= size) {
      values[next - 1] = next;
      next += 1;
    }
    return values;
  };
  return {
    input: countFromOne(totalSizeTensor),
    filter: countFromOne(totalSizeFilter)
  };
}
/**
 * Builds a 5D input and a 5D filter from generated fractional data and
 * returns the result of `tf.conv3d` applied to them.
 *
 * The input tensor is created with shape
 * `[batch, inDepth, inHeight, inWidth, inChannels]` and the filter with shape
 * `[fDepth, fHeight, fWidth, inChannels, outChannels]`.
 *
 * @param batch First dimension of the input shape.
 * @param inDepth Second dimension of the input shape.
 * @param inHeight Third dimension of the input shape.
 * @param inWidth Fourth dimension of the input shape.
 * @param inChannels Last input dimension; also the fourth filter dimension.
 * @param outChannels Last dimension of the filter shape.
 * @param fDepth First dimension of the filter shape.
 * @param fHeight Second dimension of the filter shape.
 * @param fWidth Third dimension of the filter shape.
 * @param pad Padding argument forwarded to `tf.conv3d`.
 * @param stride Stride argument forwarded to `tf.conv3d`.
 * @returns Whatever `tf.conv3d` returns for these arguments.
 */
function runConv3DTestCase(
    batch, inDepth, inHeight, inWidth, inChannels, outChannels, fDepth,
    fHeight, fWidth, pad, stride) {
  const xShape = [batch, inDepth, inHeight, inWidth, inChannels];
  const wShape = [fDepth, fHeight, fWidth, inChannels, outChannels];
  const { input, filter } =
      generateCaseInputs(sizeFromShape(xShape), sizeFromShape(wShape));
  // Note the argument order expected by conv3d: stride comes before pad.
  return tf.conv3d(
      tf.tensor5d(input, xShape), tf.tensor5d(filter, wShape), stride, pad);
}
/**
 * Computes the gradients of `tf.conv3d` with respect to its input and its
 * filter, using integer-valued generated data, and checks that each gradient
 * has the same shape as the tensor it was taken against.
 *
 * The differentiated function clones its arguments and its result, so the
 * gradient is exercised through `clone()` as well as through `conv3d`.
 *
 * The input tensor is created with shape
 * `[batch, inDepth, inHeight, inWidth, inChannels]` and the filter with shape
 * `[fDepth, fHeight, fWidth, inChannels, outChannels]`.
 *
 * @param pad Padding argument forwarded to `tf.conv3d`.
 * @param stride Stride argument forwarded to `tf.conv3d`.
 * @returns A two-element array `[dx, dfilter]`.
 */
function runGradientConv3DTestCase(
    batch, inDepth, inHeight, inWidth, inChannels, outChannels, fDepth,
    fHeight, fWidth, pad, stride) {
  const xShape = [batch, inDepth, inHeight, inWidth, inChannels];
  const wShape = [fDepth, fHeight, fWidth, inChannels, outChannels];
  const { input, filter } =
      generateGradientCaseInputs(sizeFromShape(xShape), sizeFromShape(wShape));
  const x = tf.tensor5d(input, xShape);
  const w = tf.tensor5d(filter, wShape);
  // Function under differentiation; every tensor passes through clone().
  const clonedConv = (xArg, filterArg) =>
      tf.conv3d(xArg.clone(), filterArg.clone(), stride, pad).clone();
  const [dx, dfilter] = tf.grads(clonedConv)([x, w]);
  expect(dx.shape).toEqual(x.shape);
  expect(dfilter.shape).toEqual(w.shape);
  return [dx, dfilter];
}
describeWithFlags('conv3d', ALL_ENVS, () => {
  // Forward-pass cases. Each entry becomes one spec, registered in the order
  // listed. `xDims` is [batch, inDepth, inHeight, inWidth, inChannels] and
  // `fDims` is [fDepth, fHeight, fWidth]; the spec name records the same
  // configuration in human-readable form.
  const forwardCases = [
    {
      name: 'x=[1, 2, 3, 1, 3] f=[1, 1, 1, 3, 3] s=1 d=1 p=valid',
      xDims: [1, 2, 3, 1, 3],
      outChannels: 3,
      fDims: [1, 1, 1],
      pad: 'valid',
      stride: 1,
      expected: [
        0.18518519, 0.22222222, 0.25925926, 0.40740741, 0.5, 0.59259259,
        0.62962963, 0.77777778, 0.92592593, 0.85185185, 1.05555556, 1.25925926,
        1.07407407, 1.33333333, 1.59259259, 1.2962963, 1.61111111, 1.92592593
      ]
    },
    {
      name: 'x=[1, 2, 1, 3, 3] f=[1, 1, 1, 3, 3] s=1 d=1 p=valid',
      xDims: [1, 2, 1, 3, 3],
      outChannels: 3,
      fDims: [1, 1, 1],
      pad: 'valid',
      stride: 1,
      expected: [
        0.18518519, 0.22222222, 0.25925926, 0.40740741, 0.5, 0.59259259,
        0.62962963, 0.77777778, 0.92592593, 0.85185185, 1.05555556, 1.25925926,
        1.07407407, 1.33333333, 1.59259259, 1.2962963, 1.61111111, 1.92592593
      ]
    },
    {
      name: 'x=[1, 1, 2, 3, 3] f=[1, 1, 1, 3, 3] s=1 d=1 p=valid',
      xDims: [1, 1, 2, 3, 3],
      outChannels: 3,
      fDims: [1, 1, 1],
      pad: 'valid',
      stride: 1,
      expected: [
        0.18518519, 0.22222222, 0.25925926, 0.40740741, 0.5, 0.59259259,
        0.62962963, 0.77777778, 0.92592593, 0.85185185, 1.05555556, 1.25925926,
        1.07407407, 1.33333333, 1.59259259, 1.2962963, 1.61111111, 1.92592593
      ]
    },
    {
      name: 'x=[1, 4, 2, 3, 3] f=[2, 2, 2, 3, 3] s=1 d=1 p=valid',
      xDims: [1, 4, 2, 3, 3],
      outChannels: 3,
      fDims: [2, 2, 2],
      pad: 'valid',
      stride: 1,
      expected: [
        3.77199074, 3.85069444, 3.92939815, 4.2650463, 4.35763889, 4.45023148,
        6.73032407, 6.89236111, 7.05439815, 7.22337963, 7.39930556, 7.57523148,
        9.68865741, 9.93402778, 10.17939815, 10.18171296, 10.44097222,
        10.70023148
      ]
    },
    {
      name: 'x=[1, 5, 8, 7, 1] f=[1, 2, 3, 1, 1] s=[2, 3, 1] d=1 p=same',
      xDims: [1, 5, 8, 7, 1],
      outChannels: 1,
      fDims: [1, 2, 3],
      pad: 'same',
      stride: [2, 3, 1],
      expected: [
        0.06071429, 0.08988095, 0.10238095, 0.11488095, 0.12738095, 0.13988095,
        0.08452381, 0.26071429, 0.35238095, 0.36488095, 0.37738095, 0.38988095,
        0.40238095, 0.23452381, 0.46071429, 0.61488095, 0.62738095, 0.63988095,
        0.65238095, 0.66488095, 0.38452381, 1.12738095, 1.48988095, 1.50238095,
        1.51488095, 1.52738095, 1.53988095, 0.88452381, 1.32738095, 1.75238095,
        1.76488095, 1.77738095, 1.78988095, 1.80238095, 1.03452381, 1.52738095,
        2.01488095, 2.02738095, 2.03988095, 2.05238095, 2.06488095, 1.18452381,
        2.19404762, 2.88988095, 2.90238095, 2.91488095, 2.92738095, 2.93988095,
        1.68452381, 2.39404762, 3.15238095, 3.16488095, 3.17738095, 3.18988095,
        3.20238095, 1.83452381, 2.59404762, 3.41488095, 3.42738095, 3.43988095,
        3.45238095, 3.46488095, 1.98452381
      ]
    },
    {
      name: 'x=[1, 4, 2, 3, 3] f=[2, 2, 2, 3, 3] s=2 d=1 p=valid',
      xDims: [1, 4, 2, 3, 3],
      outChannels: 3,
      fDims: [2, 2, 2],
      pad: 'valid',
      stride: 2,
      expected: [
        3.77199074, 3.85069444, 3.92939815, 9.68865741, 9.93402778, 10.17939815
      ]
    },
    {
      name: 'x=[1, 6, 7, 8, 2] f=[3, 2, 1, 2, 3] s=3 d=1 p=valid',
      xDims: [1, 6, 7, 8, 2],
      outChannels: 3,
      fDims: [3, 2, 1],
      pad: 'valid',
      stride: 3,
      expected: [
        1.51140873, 1.57167659, 1.63194444, 1.56349206, 1.62673611, 1.68998016,
        1.6155754, 1.68179563, 1.74801587, 1.9280754, 2.01215278, 2.09623016,
        1.98015873, 2.0672123, 2.15426587, 2.03224206, 2.12227183, 2.21230159,
        4.4280754, 4.65500992, 4.88194444, 4.48015873, 4.71006944, 4.93998016,
        4.53224206, 4.76512897, 4.99801587, 4.84474206, 5.09548611, 5.34623016,
        4.8968254, 5.15054563, 5.40426587, 4.94890873, 5.20560516, 5.46230159
      ]
    },
    {
      name: 'x=[1, 4, 2, 3, 3] f=[2, 2, 2, 3, 3] s=2 d=1 p=same',
      xDims: [1, 4, 2, 3, 3],
      outChannels: 3,
      fDims: [2, 2, 2],
      pad: 'same',
      stride: 2,
      expected: [
        3.77199074, 3.85069444, 3.92939815, 2.0162037, 2.06597222, 2.11574074,
        9.68865741, 9.93402778, 10.17939815, 4.59953704, 4.73263889, 4.86574074
      ]
    },
    {
      name: 'x=[1, 3, 3, 3, 1] f=[1, 1, 1, 1, 1] s=2 d=1 p=same',
      xDims: [1, 3, 3, 3, 1],
      outChannels: 1,
      fDims: [1, 1, 1],
      pad: 'same',
      stride: 2,
      expected: [
        0.03703704, 0.11111111, 0.25925926, 0.33333333, 0.7037037, 0.77777778,
        0.92592593, 1.0
      ]
    },
    {
      name: 'x=[1, 3, 3, 3, 1] f=[1, 1, 1, 1, 1] s=2 d=1 p=valid',
      xDims: [1, 3, 3, 3, 1],
      outChannels: 1,
      fDims: [1, 1, 1],
      pad: 'valid',
      stride: 2,
      expected: [
        0.03703704, 0.11111111, 0.25925926, 0.33333333, 0.7037037, 0.77777778,
        0.92592593, 1.0
      ]
    },
    {
      name: 'x=[1, 7, 7, 7, 1] f=[2, 2, 2, 1, 1] s=3 d=1 p=same',
      xDims: [1, 7, 7, 7, 1],
      outChannels: 1,
      fDims: [2, 2, 2],
      pad: 'same',
      stride: 3,
      expected: [
        0.54081633, 0.58017493, 0.28061224, 0.81632653, 0.85568513, 0.40306122,
        0.41873178, 0.4340379, 0.19642857, 2.46938776, 2.50874636, 1.1377551,
        2.74489796, 2.78425656, 1.26020408, 1.16873178, 1.1840379, 0.51785714,
        1.09511662, 1.10604956, 0.44642857, 1.17164723, 1.18258017, 0.47704082,
        0.3691691, 0.37244898, 0.125
      ]
    },
    {
      name: 'x=[1, 7, 7, 7, 1] f=[2, 2, 2, 1, 1] s=3 d=1 p=valid',
      xDims: [1, 7, 7, 7, 1],
      outChannels: 1,
      fDims: [2, 2, 2],
      pad: 'valid',
      stride: 3,
      expected: [
        0.540816, 0.580175, 0.816327, 0.855685, 2.469388, 2.508746, 2.744898,
        2.784257
      ]
    },
    {
      name: 'x=[1, 2, 1, 2, 1] f=[2, 1, 2, 1, 2] s=1 d=1 p=valid',
      xDims: [1, 2, 1, 2, 1],
      outChannels: 2,
      fDims: [2, 1, 2],
      pad: 'valid',
      stride: 1,
      expected: [1.5625, 1.875]
    }
  ];

  for (const testCase of forwardCases) {
    const { name, xDims, outChannels, fDims, pad, stride, expected } = testCase;
    it(name, async () => {
      const result =
          runConv3DTestCase(...xDims, outChannels, ...fDims, pad, stride);
      expectArraysClose(await result.data(), expected);
    });
  }

  it('gradient with clones, x=[1,3,6,1,1] filter=[2,2,1,1,1] s=1 d=1 p=valid', async () => {
    // Arguments: batch, inDepth, inHeight, inWidth, inChannels, outChannels,
    // fDepth, fHeight, fWidth, pad, stride.
    const [dx, dfilter] =
        runGradientConv3DTestCase(1, 3, 6, 1, 1, 1, 2, 2, 1, 'valid', 1);
    const expectedDx = [
      1.0, 3.0, 3.0, 3.0, 3.0, 2.0, 4.0, 10.0, 10.0, 10.0, 10.0, 6.0, 3.0,
      7.0, 7.0, 7.0, 7.0, 4.0
    ];
    const expectedDFilter = [60.0, 70.0, 120.0, 130.0];
    expectArraysClose(await dx.data(), expectedDx);
    expectArraysClose(await dfilter.data(), expectedDFilter);
  });

  it('throws when passed x as a non-tensor', () => {
    // A valid 1x1x1x1x1 filter, so that only the bad `x` can be at fault.
    const w = tf.tensor5d([2], [1, 1, 1, 1, 1]);
    const convWithBadInput = () => tf.conv3d({}, w, 1, 'valid');
    expect(convWithBadInput)
        .toThrowError(/Argument 'x' passed to 'conv3d' must be a Tensor/);
  });

  it('throws when passed filter as a non-tensor', () => {
    // A valid rank-4 input, so that only the bad filter can be at fault.
    const x = tf.tensor4d([1, 2, 3, 4], [2, 2, 1, 1]);
    const convWithBadFilter = () => tf.conv3d(x, {}, 1, 'valid');
    expect(convWithBadFilter)
        .toThrowError(/Argument 'filter' passed to 'conv3d' must be a Tensor/);
  });

  it('accepts a tensor-like object', async () => {
    // Plain nested arrays stand in for tensors here.
    const x = [[[[1], [2]], [[3], [4]]]];  // nested-array shape 1x2x2x1
    const w = [[[[[2]]]]];                 // nested-array shape 1x1x1x1x1
    const result = tf.conv3d(x, w, 1, 'valid');
    expectArraysClose(await result.data(), [2, 4, 6, 8]);
  });

  it('throws when data format not NDHWC', () => {
    const x = tf.tensor4d([1, 2, 3, 4], [2, 2, 1, 1]);
    const w = tf.tensor5d([2], [1, 1, 1, 1, 1]);
    // The arguments are otherwise valid; only the data format is unsupported.
    const convWithBadFormat = () => tf.conv3d(x, w, 1, 'valid', 'NCDHW');
    expect(convWithBadFormat).toThrowError();
  });
});
374//# sourceMappingURL=conv3d_test.js.map
\No newline at end of file