// Source listing retrieved from UNPKG (159 kB, JavaScript, "View Raw").

/**
 * @license
 * Copyright 2017 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
17import * as conv_util from './conv_util';
// Tests for conv_util.computeConv2DInfo with the default channels-last layout:
// inShape is [batch, height, width, channels] and the filter shape is
// [filterHeight, filterWidth, inChannels, outChannels].
describe('conv_util computeConv2DInfo', () => {
  it('1x1 conv over 1x1 array with same pad', () => {
    const inShape = [1, 1, 1, 1];
    const stride = 1;
    const dilation = 1;
    const convInfo = conv_util.computeConv2DInfo(
        inShape, [1, 1, 1, 1], stride, dilation, 'same');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outHeight).toEqual(1);
    expect(convInfo.outWidth).toEqual(1);
    expect(convInfo.outChannels).toEqual(1);
    expect(convInfo.effectiveFilterWidth).toEqual(1);
    expect(convInfo.effectiveFilterHeight).toEqual(1);
  });

  it('2x2 conv over 3x3 array with same pad', () => {
    const inShape = [1, 3, 3, 1];
    const stride = 1;
    const dilation = 1;
    const convInfo = conv_util.computeConv2DInfo(
        inShape, [2, 2, 1, 1], stride, dilation, 'same');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outHeight).toEqual(3);
    expect(convInfo.outWidth).toEqual(3);
    expect(convInfo.outChannels).toEqual(1);
    expect(convInfo.effectiveFilterWidth).toEqual(2);
    expect(convInfo.effectiveFilterHeight).toEqual(2);
    // Should produce non-even padding with extra pixel at the right/bottom.
    expect(convInfo.padInfo.left).toBe(0);
    expect(convInfo.padInfo.right).toBe(1);
    expect(convInfo.padInfo.top).toBe(0);
    expect(convInfo.padInfo.bottom).toBe(1);
  });

  it('2x2 conv over 3x3 array with valid pad', () => {
    const inShape = [1, 3, 3, 1];
    const stride = 1;
    const dilation = 1;
    const convInfo = conv_util.computeConv2DInfo(
        inShape, [2, 2, 1, 1], stride, dilation, 'valid');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outHeight).toEqual(2);
    expect(convInfo.outWidth).toEqual(2);
    expect(convInfo.outChannels).toEqual(1);
    expect(convInfo.effectiveFilterWidth).toEqual(2);
    expect(convInfo.effectiveFilterHeight).toEqual(2);
  });

  it('3x3 conv over 5x5 array with same pad with stride 2', () => {
    const inShape = [1, 5, 5, 1];
    const stride = 2;
    const dilation = 1;
    const convInfo = conv_util.computeConv2DInfo(
        inShape, [3, 3, 1, 1], stride, dilation, 'same');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outHeight).toEqual(3);
    expect(convInfo.outWidth).toEqual(3);
    expect(convInfo.outChannels).toEqual(1);
    expect(convInfo.effectiveFilterWidth).toEqual(3);
    expect(convInfo.effectiveFilterHeight).toEqual(3);
    expect(convInfo.padInfo.left).toBe(1);
    expect(convInfo.padInfo.right).toBe(1);
    expect(convInfo.padInfo.top).toBe(1);
    expect(convInfo.padInfo.bottom).toBe(1);
  });

  it('2x2 conv over 3x3 array with valid pad with stride 2', () => {
    const inShape = [1, 3, 3, 1];
    const stride = 2;
    const dilation = 1;
    const convInfo = conv_util.computeConv2DInfo(
        inShape, [2, 2, 1, 1], stride, dilation, 'valid');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outHeight).toEqual(1);
    expect(convInfo.outWidth).toEqual(1);
    expect(convInfo.outChannels).toEqual(1);
    expect(convInfo.effectiveFilterWidth).toEqual(2);
    expect(convInfo.effectiveFilterHeight).toEqual(2);
  });

  it('2x1 conv over 3x3 array with valid pad with stride 1', () => {
    const inShape = [1, 3, 3, 1];
    const stride = 1;
    const dilation = 1;
    const convInfo = conv_util.computeConv2DInfo(
        inShape, [2, 1, 1, 1], stride, dilation, 'valid');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outHeight).toEqual(2);
    expect(convInfo.outWidth).toEqual(3);
    expect(convInfo.outChannels).toEqual(1);
    expect(convInfo.effectiveFilterWidth).toEqual(1);
    expect(convInfo.effectiveFilterHeight).toEqual(2);
  });

  it('2x1 conv over 3x3 array with valid pad with strides h=2, w=1', () => {
    const inShape = [1, 3, 3, 1];
    const strides = [2, 1];
    const dilation = 1;
    const convInfo = conv_util.computeConv2DInfo(
        inShape, [2, 1, 1, 1], strides, dilation, 'valid');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outHeight).toEqual(1);
    expect(convInfo.outWidth).toEqual(3);
    expect(convInfo.outChannels).toEqual(1);
    expect(convInfo.effectiveFilterWidth).toEqual(1);
    expect(convInfo.effectiveFilterHeight).toEqual(2);
  });

  it('1x2 conv over 3x3 array with valid pad with stride 1', () => {
    const inShape = [1, 3, 3, 1];
    const stride = 1;
    const dilation = 1;
    const convInfo = conv_util.computeConv2DInfo(
        inShape, [1, 2, 1, 1], stride, dilation, 'valid');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outHeight).toEqual(3);
    expect(convInfo.outWidth).toEqual(2);
    expect(convInfo.outChannels).toEqual(1);
    expect(convInfo.effectiveFilterWidth).toEqual(2);
    expect(convInfo.effectiveFilterHeight).toEqual(1);
  });

  it('1x2 conv over 3x3 array with valid pad with stride 1, batch=5', () => {
    const inShape = [5, 3, 3, 1];
    const stride = 1;
    const dilation = 1;
    const convInfo = conv_util.computeConv2DInfo(
        inShape, [1, 2, 1, 1], stride, dilation, 'valid');
    expect(convInfo.batchSize).toEqual(5);
    expect(convInfo.outHeight).toEqual(3);
    expect(convInfo.outWidth).toEqual(2);
    expect(convInfo.outChannels).toEqual(1);
    expect(convInfo.effectiveFilterWidth).toEqual(2);
    expect(convInfo.effectiveFilterHeight).toEqual(1);
  });

  // With dilation d, a filter of size f spans f + (f - 1) * (d - 1) pixels,
  // which is what the effectiveFilter* expectations below encode.
  it('2x2 conv over 3x3 array with same pad with dilations 2', () => {
    const inShape = [1, 3, 3, 1];
    const stride = 1;
    const dilations = 2;
    const convInfo = conv_util.computeConv2DInfo(
        inShape, [2, 2, 1, 1], stride, dilations, 'same');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outHeight).toEqual(3);
    expect(convInfo.outWidth).toEqual(3);
    expect(convInfo.outChannels).toEqual(1);
    // pad evenly on all sides
    expect(convInfo.padInfo.left).toBe(1);
    expect(convInfo.padInfo.right).toBe(1);
    expect(convInfo.padInfo.top).toBe(1);
    expect(convInfo.padInfo.bottom).toBe(1);
    expect(convInfo.effectiveFilterWidth).toEqual(3);
    expect(convInfo.effectiveFilterHeight).toEqual(3);
  });

  it('2x1 conv over 3x3 array with same pad with dilations 2', () => {
    const inShape = [1, 3, 3, 1];
    const stride = 1;
    const dilations = 2;
    const convInfo = conv_util.computeConv2DInfo(
        inShape, [2, 1, 1, 1], stride, dilations, 'same');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outHeight).toEqual(3);
    expect(convInfo.outWidth).toEqual(3);
    expect(convInfo.outChannels).toEqual(1);
    // pad top and bottom
    expect(convInfo.padInfo.left).toBe(0);
    expect(convInfo.padInfo.right).toBe(0);
    expect(convInfo.padInfo.top).toBe(1);
    expect(convInfo.padInfo.bottom).toBe(1);
    expect(convInfo.effectiveFilterWidth).toEqual(1);
    expect(convInfo.effectiveFilterHeight).toEqual(3);
  });

  it('3x4 conv over 8x8 array with same pad with dilations h=4 w=3', () => {
    const inShape = [1, 8, 8, 1];
    const stride = 1;
    const dilations = [4, 3];
    const convInfo = conv_util.computeConv2DInfo(
        inShape, [3, 4, 1, 1], stride, dilations, 'same');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outHeight).toEqual(8);
    expect(convInfo.outWidth).toEqual(8);
    expect(convInfo.outChannels).toEqual(1);
    expect(convInfo.effectiveFilterWidth).toEqual(10);
    expect(convInfo.effectiveFilterHeight).toEqual(9);
    expect(convInfo.padInfo.left).toBe(4);
    expect(convInfo.padInfo.right).toBe(5);
    expect(convInfo.padInfo.top).toBe(4);
    expect(convInfo.padInfo.bottom).toBe(4);
  });

  it('2x1 conv over 3x3 array with valid pad with dilations 2', () => {
    const inShape = [1, 3, 3, 1];
    const stride = 1;
    const dilations = 2;
    const convInfo = conv_util.computeConv2DInfo(
        inShape, [2, 1, 1, 1], stride, dilations, 'valid');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outHeight).toEqual(1);
    expect(convInfo.outWidth).toEqual(3);
    expect(convInfo.outChannels).toEqual(1);
    expect(convInfo.effectiveFilterWidth).toEqual(1);
    expect(convInfo.effectiveFilterHeight).toEqual(3);
  });

  it('2x2 conv over 3x3 array with valid pad with dilations 2', () => {
    const inShape = [1, 3, 3, 1];
    const stride = 1;
    const dilations = 2;
    const convInfo = conv_util.computeConv2DInfo(
        inShape, [2, 2, 1, 1], stride, dilations, 'valid');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outHeight).toEqual(1);
    expect(convInfo.outWidth).toEqual(1);
    expect(convInfo.outChannels).toEqual(1);
    expect(convInfo.effectiveFilterWidth).toEqual(3);
    expect(convInfo.effectiveFilterHeight).toEqual(3);
  });

  it('2x2 conv over 4x4 array with valid pad with dilations 2', () => {
    const inShape = [1, 4, 4, 1];
    const stride = 1;
    const dilations = 2;
    const convInfo = conv_util.computeConv2DInfo(
        inShape, [2, 2, 1, 1], stride, dilations, 'valid');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outHeight).toEqual(2);
    expect(convInfo.outWidth).toEqual(2);
    expect(convInfo.outChannels).toEqual(1);
    expect(convInfo.effectiveFilterWidth).toEqual(3);
    expect(convInfo.effectiveFilterHeight).toEqual(3);
  });
});
// Tests for conv_util.computeConv3DInfo with the default channels-last layout:
// inShape is [batch, depth, height, width, channels] and the filter shape is
// [filterDepth, filterHeight, filterWidth, inChannels, outChannels].
describe('conv_util computeConv3DInfo', () => {
  it('1x1x1 conv over 1x1x1 array with same pad', () => {
    const inShape = [1, 1, 1, 1, 1];
    const stride = 1;
    const dilation = 1;
    const convInfo = conv_util.computeConv3DInfo(
        inShape, [1, 1, 1, 1, 1], stride, dilation, 'same');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outDepth).toEqual(1);
    expect(convInfo.outHeight).toEqual(1);
    expect(convInfo.outWidth).toEqual(1);
    expect(convInfo.outChannels).toEqual(1);
  });

  it('2x2x2 conv over 3x3x3 array with same pad', () => {
    const inShape = [1, 3, 3, 3, 1];
    const stride = 1;
    const dilation = 1;
    const convInfo = conv_util.computeConv3DInfo(
        inShape, [2, 2, 2, 1, 1], stride, dilation, 'same');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outDepth).toEqual(3);
    expect(convInfo.outHeight).toEqual(3);
    expect(convInfo.outWidth).toEqual(3);
    expect(convInfo.outChannels).toEqual(1);
    // Should produce non-even padding with extra pixel at the back/right/bottom
    expect(convInfo.padInfo.front).toBe(0);
    expect(convInfo.padInfo.back).toBe(1);
    expect(convInfo.padInfo.left).toBe(0);
    expect(convInfo.padInfo.right).toBe(1);
    expect(convInfo.padInfo.top).toBe(0);
    expect(convInfo.padInfo.bottom).toBe(1);
  });

  it('2x2x2 conv over 3x3x3 array with valid pad', () => {
    const inShape = [1, 3, 3, 3, 1];
    const stride = 1;
    const dilation = 1;
    const convInfo = conv_util.computeConv3DInfo(
        inShape, [2, 2, 2, 1, 1], stride, dilation, 'valid');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outDepth).toEqual(2);
    expect(convInfo.outHeight).toEqual(2);
    expect(convInfo.outWidth).toEqual(2);
    expect(convInfo.outChannels).toEqual(1);
  });

  it('3x3x3 conv over 5x5x5 array with same pad with stride 2', () => {
    const inShape = [1, 5, 5, 5, 1];
    const stride = 2;
    const dilation = 1;
    const convInfo = conv_util.computeConv3DInfo(
        inShape, [3, 3, 3, 1, 1], stride, dilation, 'same');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outDepth).toEqual(3);
    expect(convInfo.outHeight).toEqual(3);
    expect(convInfo.outWidth).toEqual(3);
    expect(convInfo.outChannels).toEqual(1);
    expect(convInfo.padInfo.front).toBe(1);
    expect(convInfo.padInfo.back).toBe(1);
    expect(convInfo.padInfo.left).toBe(1);
    expect(convInfo.padInfo.right).toBe(1);
    expect(convInfo.padInfo.top).toBe(1);
    expect(convInfo.padInfo.bottom).toBe(1);
  });

  it('2x2x2 conv over 3x3x3 array with valid pad with stride 2', () => {
    const inShape = [1, 3, 3, 3, 1];
    const stride = 2;
    const dilation = 1;
    const convInfo = conv_util.computeConv3DInfo(
        inShape, [2, 2, 2, 1, 1], stride, dilation, 'valid');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outDepth).toEqual(1);
    expect(convInfo.outHeight).toEqual(1);
    expect(convInfo.outWidth).toEqual(1);
    expect(convInfo.outChannels).toEqual(1);
  });

  it('2x1x1 conv over 3x3x3 array with valid pad with stride 1', () => {
    const inShape = [1, 3, 3, 3, 1];
    const stride = 1;
    const dilation = 1;
    const convInfo = conv_util.computeConv3DInfo(
        inShape, [2, 1, 1, 1, 1], stride, dilation, 'valid');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outDepth).toEqual(2);
    expect(convInfo.outHeight).toEqual(3);
    expect(convInfo.outWidth).toEqual(3);
    expect(convInfo.outChannels).toEqual(1);
  });

  it('2x1x1 conv over 3x3x3 array with valid pad with strides d=2, h=1, w=1',
     () => {
       const inShape = [1, 3, 3, 3, 1];
       const strides = [2, 1, 1];
       const dilation = 1;
       const convInfo = conv_util.computeConv3DInfo(
           inShape, [2, 1, 1, 1, 1], strides, dilation, 'valid');
       expect(convInfo.batchSize).toEqual(1);
       expect(convInfo.outDepth).toEqual(1);
       expect(convInfo.outHeight).toEqual(3);
       expect(convInfo.outWidth).toEqual(3);
       expect(convInfo.outChannels).toEqual(1);
     });

  it('1x2x2 conv over 3x3x3 array with valid pad with stride 1', () => {
    const inShape = [1, 3, 3, 3, 1];
    const stride = 1;
    const dilation = 1;
    const convInfo = conv_util.computeConv3DInfo(
        inShape, [1, 2, 2, 1, 1], stride, dilation, 'valid');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outDepth).toEqual(3);
    expect(convInfo.outHeight).toEqual(2);
    expect(convInfo.outWidth).toEqual(2);
    expect(convInfo.outChannels).toEqual(1);
  });

  it('1x2x2 conv over 3x3x3 array with valid pad with stride 1, batch=5', () => {
    const inShape = [5, 3, 3, 3, 1];
    const stride = 1;
    const dilation = 1;
    const convInfo = conv_util.computeConv3DInfo(
        inShape, [1, 2, 2, 1, 1], stride, dilation, 'valid');
    expect(convInfo.batchSize).toEqual(5);
    expect(convInfo.outDepth).toEqual(3);
    expect(convInfo.outHeight).toEqual(2);
    expect(convInfo.outWidth).toEqual(2);
    expect(convInfo.outChannels).toEqual(1);
  });

  it('2x2x2 conv over 3x3x3 array with same pad with dilations 2', () => {
    const inShape = [1, 3, 3, 3, 1];
    const stride = 1;
    const dilations = 2;
    const convInfo = conv_util.computeConv3DInfo(
        inShape, [2, 2, 2, 1, 1], stride, dilations, 'same');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outDepth).toEqual(3);
    expect(convInfo.outHeight).toEqual(3);
    expect(convInfo.outWidth).toEqual(3);
    expect(convInfo.outChannels).toEqual(1);
    // pad evenly on all sides
    expect(convInfo.padInfo.front).toBe(1);
    expect(convInfo.padInfo.back).toBe(1);
    expect(convInfo.padInfo.left).toBe(1);
    expect(convInfo.padInfo.right).toBe(1);
    expect(convInfo.padInfo.top).toBe(1);
    expect(convInfo.padInfo.bottom).toBe(1);
  });

  it('2x1x1 conv over 3x3x3 array with same pad with dilations 2', () => {
    const inShape = [1, 3, 3, 3, 1];
    const stride = 1;
    const dilations = 2;
    const convInfo = conv_util.computeConv3DInfo(
        inShape, [2, 1, 1, 1, 1], stride, dilations, 'same');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outDepth).toEqual(3);
    expect(convInfo.outHeight).toEqual(3);
    expect(convInfo.outWidth).toEqual(3);
    expect(convInfo.outChannels).toEqual(1);
    // Only the depth axis has a filter extent > 1, so pad front and back only.
    expect(convInfo.padInfo.front).toBe(1);
    expect(convInfo.padInfo.back).toBe(1);
    expect(convInfo.padInfo.left).toBe(0);
    expect(convInfo.padInfo.right).toBe(0);
    expect(convInfo.padInfo.top).toBe(0);
    expect(convInfo.padInfo.bottom).toBe(0);
  });

  it('3x4x4 conv over 8x8x8 array with same pad with dilations d=4 h=3 w=3',
     () => {
       const inShape = [1, 8, 8, 8, 1];
       const stride = 1;
       const dilations = [4, 3, 3];
       const convInfo = conv_util.computeConv3DInfo(
           inShape, [3, 4, 4, 1, 1], stride, dilations, 'same');
       expect(convInfo.batchSize).toEqual(1);
       expect(convInfo.outDepth).toEqual(8);
       expect(convInfo.outHeight).toEqual(8);
       expect(convInfo.outWidth).toEqual(8);
       expect(convInfo.outChannels).toEqual(1);
       // With dilation d, a filter of size f spans f + (f - 1) * (d - 1).
       expect(convInfo.effectiveFilterDepth).toEqual(9);
       expect(convInfo.effectiveFilterHeight).toEqual(10);
       expect(convInfo.effectiveFilterWidth).toEqual(10);
       expect(convInfo.padInfo.front).toBe(4);
       expect(convInfo.padInfo.back).toBe(4);
       expect(convInfo.padInfo.left).toBe(4);
       expect(convInfo.padInfo.right).toBe(5);
       expect(convInfo.padInfo.top).toBe(4);
       expect(convInfo.padInfo.bottom).toBe(5);
     });

  it('2x1x1 conv over 3x3x3 array with valid pad with dilations 2', () => {
    const inShape = [1, 3, 3, 3, 1];
    const stride = 1;
    const dilations = 2;
    const convInfo = conv_util.computeConv3DInfo(
        inShape, [2, 1, 1, 1, 1], stride, dilations, 'valid');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outDepth).toEqual(1);
    expect(convInfo.outHeight).toEqual(3);
    expect(convInfo.outWidth).toEqual(3);
    expect(convInfo.outChannels).toEqual(1);
  });

  it('2x2x2 conv over 3x3x3 array with valid pad with dilations 2', () => {
    const inShape = [1, 3, 3, 3, 1];
    const stride = 1;
    const dilations = 2;
    const convInfo = conv_util.computeConv3DInfo(
        inShape, [2, 2, 2, 1, 1], stride, dilations, 'valid');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outDepth).toEqual(1);
    expect(convInfo.outHeight).toEqual(1);
    expect(convInfo.outWidth).toEqual(1);
    expect(convInfo.outChannels).toEqual(1);
  });

  it('2x2x2 conv over 4x4x4 array with valid pad with dilations 2', () => {
    const inShape = [1, 4, 4, 4, 1];
    const stride = 1;
    const dilations = 2;
    const convInfo = conv_util.computeConv3DInfo(
        inShape, [2, 2, 2, 1, 1], stride, dilations, 'valid');
    expect(convInfo.batchSize).toEqual(1);
    expect(convInfo.outDepth).toEqual(2);
    expect(convInfo.outHeight).toEqual(2);
    expect(convInfo.outWidth).toEqual(2);
    expect(convInfo.outChannels).toEqual(1);
  });
});
describe('conv_util computeConv2DInfo with depthwise=true', () => {
  // All cases use stride 1 and dilation 1 with a square filter. The trailing
  // (null, true) arguments pass a null roundingMode and enable depthwise mode.
  const depthwiseConvInfo = (inShape, filterSize, channels, multiplier, pad) =>
    conv_util.computeConv2DInfo(
      inShape, [filterSize, filterSize, channels, multiplier], 1, 1, pad, null,
      true);

  it('1x1 filter over 1x1 array with same pad', () => {
    const channels = 1;
    const info =
      depthwiseConvInfo([1, 1, 1, channels], 1, channels, 1, 'same');
    expect(info.batchSize).toEqual(1);
    expect(info.outHeight).toEqual(1);
    expect(info.outWidth).toEqual(1);
    expect(info.outChannels).toEqual(1);
    expect(info.effectiveFilterWidth).toEqual(1);
    expect(info.effectiveFilterHeight).toEqual(1);
  });

  it('2x2 filter over 3x3 array with same pad, chMul=3, depth=2', () => {
    const channels = 2;
    const multiplier = 3;
    const side = 3;
    const info = depthwiseConvInfo(
      [1, side, side, channels], 2, channels, multiplier, 'same');
    expect(info.batchSize).toEqual(1);
    expect(info.outHeight).toEqual(3);
    expect(info.outWidth).toEqual(3);
    // Depthwise output depth is inChannels * channel multiplier (2 * 3).
    expect(info.outChannels).toEqual(channels * multiplier);
    expect(info.effectiveFilterWidth).toEqual(2);
    expect(info.effectiveFilterHeight).toEqual(2);
  });

  it('2x2 filter over 3x3 array with valid pad, chMul=3, depth=2', () => {
    const channels = 2;
    const multiplier = 3;
    const side = 3;
    const info = depthwiseConvInfo(
      [1, side, side, channels], 2, channels, multiplier, 'valid');
    expect(info.batchSize).toEqual(1);
    expect(info.outHeight).toEqual(2);
    expect(info.outWidth).toEqual(2);
    expect(info.outChannels).toEqual(channels * multiplier);
    expect(info.effectiveFilterWidth).toEqual(2);
    expect(info.effectiveFilterHeight).toEqual(2);
  });
});
describe('conv_util computeConv3DInfo with depthwise=true', () => {
  // All cases use stride 1 and dilation 1 with a cubic filter. Unlike the 2D
  // helper, computeConv3DInfo receives the depthwise flag directly after pad.
  const depthwiseConv3DInfo = (inShape, filterSize, channels, multiplier, pad) =>
    conv_util.computeConv3DInfo(
      inShape, [filterSize, filterSize, filterSize, channels, multiplier], 1, 1,
      pad, true);

  it('1x1x1 filter over 1x1x1 array with same pad', () => {
    const channels = 1;
    const info =
      depthwiseConv3DInfo([1, 1, 1, 1, channels], 1, channels, 1, 'same');
    expect(info.batchSize).toEqual(1);
    expect(info.outDepth).toEqual(1);
    expect(info.outHeight).toEqual(1);
    expect(info.outWidth).toEqual(1);
    expect(info.outChannels).toEqual(1);
  });

  it('2x2x2 filter over 3x3x3 array with same pad, chMul=3, depth=2', () => {
    const channels = 2;
    const multiplier = 3;
    const side = 3;
    const info = depthwiseConv3DInfo(
      [1, side, side, side, channels], 2, channels, multiplier, 'same');
    expect(info.batchSize).toEqual(1);
    expect(info.outDepth).toEqual(3);
    expect(info.outHeight).toEqual(3);
    expect(info.outWidth).toEqual(3);
    // Depthwise output depth is inChannels * channel multiplier (2 * 3).
    expect(info.outChannels).toEqual(channels * multiplier);
  });

  it('2x2x2 filter over 3x3x3 array with valid pad, chMul=3, depth=2', () => {
    const channels = 2;
    const multiplier = 3;
    const side = 3;
    const info = depthwiseConv3DInfo(
      [1, side, side, side, channels], 2, channels, multiplier, 'valid');
    expect(info.batchSize).toEqual(1);
    expect(info.outDepth).toEqual(2);
    expect(info.outHeight).toEqual(2);
    expect(info.outWidth).toEqual(2);
    expect(info.outChannels).toEqual(channels * multiplier);
  });
});
describe('conv_util computeConv2DInfo channelsFirst', () => {
  // Inputs are [batch, channels, height, width]: a 2x2 filter over a 3x3
  // image with stride 1 and dilation 1, depthwise disabled.
  const channelsFirstInfo = (inDepth, outDepth, pad) =>
    conv_util.computeConv2DInfo(
      [1, inDepth, 3, 3], [2, 2, inDepth, outDepth], 1, 1, pad, null, false,
      'channelsFirst');

  // Checks each named side of padInfo, in the key order of `expected`.
  const expectPadding = (padInfo, expected) => {
    for (const [side, pixels] of Object.entries(expected)) {
      expect(padInfo[side]).toBe(pixels);
    }
  };

  it('2x2 conv over 3x3 array with same pad', () => {
    const info = channelsFirstInfo(2, 4, 'same');
    expect(info.batchSize).toEqual(1);
    expect(info.outHeight).toEqual(3);
    expect(info.outWidth).toEqual(3);
    expect(info.outChannels).toEqual(4);
    expect(info.outShape).toEqual([1, 4, 3, 3]);
    expect(info.effectiveFilterWidth).toEqual(2);
    expect(info.effectiveFilterHeight).toEqual(2);
    // Should produce non-even padding with extra pixel at the right/bottom.
    expectPadding(info.padInfo, { left: 0, right: 1, top: 0, bottom: 1 });
  });

  it('2x2 conv over 3x3 array with valid pad', () => {
    const info = channelsFirstInfo(6, 16, 'valid');
    expect(info.batchSize).toEqual(1);
    expect(info.outHeight).toEqual(2);
    expect(info.outWidth).toEqual(2);
    expect(info.outChannels).toEqual(16);
    expect(info.outShape).toEqual([1, 16, 2, 2]);
    expect(info.effectiveFilterWidth).toEqual(2);
    expect(info.effectiveFilterHeight).toEqual(2);
    // Should produce no padding.
    expectPadding(info.padInfo, { left: 0, right: 0, top: 0, bottom: 0 });
  });
});
describe('conv_util computeConv3DInfo channelsFirst', () => {
  // Inputs are [batch, channels, depth, height, width]: a 2x2x2 filter over a
  // 3x3x3 volume with stride 1 and dilation 1, depthwise disabled.
  const channelsFirst3DInfo = (inDepth, outDepth, pad) =>
    conv_util.computeConv3DInfo(
      [1, inDepth, 3, 3, 3], [2, 2, 2, inDepth, outDepth], 1, 1, pad, false,
      'channelsFirst');

  // Checks each named side of padInfo, in the key order of `expected`.
  const expectPadding = (padInfo, expected) => {
    for (const [side, pixels] of Object.entries(expected)) {
      expect(padInfo[side]).toBe(pixels);
    }
  };

  it('2x2x2 conv over 3x3x3 array with same pad', () => {
    const info = channelsFirst3DInfo(2, 4, 'same');
    expect(info.batchSize).toEqual(1);
    expect(info.outDepth).toEqual(3);
    expect(info.outHeight).toEqual(3);
    expect(info.outWidth).toEqual(3);
    expect(info.outChannels).toEqual(4);
    expect(info.outShape).toEqual([1, 4, 3, 3, 3]);
    // Should produce non-even padding with extra pixel at the back/right/bottom
    expectPadding(info.padInfo, {
      front: 0, back: 1, left: 0, right: 1, top: 0, bottom: 1,
    });
  });

  it('2x2x2 conv over 3x3x3 array with valid pad', () => {
    const info = channelsFirst3DInfo(6, 16, 'valid');
    expect(info.batchSize).toEqual(1);
    expect(info.outDepth).toEqual(2);
    expect(info.outHeight).toEqual(2);
    expect(info.outWidth).toEqual(2);
    expect(info.outChannels).toEqual(16);
    expect(info.outShape).toEqual([1, 16, 2, 2, 2]);
    // Should produce no padding.
    expectPadding(info.padInfo, {
      front: 0, back: 0, left: 0, right: 0, top: 0, bottom: 0,
    });
  });
});
describe('conv_util computeConv2DInfo roundingMode', () => {
  // A 2x2 filter with stride 2 and pad 1 over a 5x5 input yields a fractional
  // output size (3.5), so the expectations differ by rounding mode:
  // truncate/floor give 3, round/ceil give 4.
  const batchSize = 1;
  const inSize = 5;
  const inChannels = 6;
  const inShape = [batchSize, inSize, inSize, inChannels];
  const fSize = 2;
  const chMul = 12;
  const stride = 2;
  const dilation = 1;
  const pad = 1;

  // Forwards any extra arguments (the rounding mode) after `pad`; calling it
  // with no arguments leaves roundingMode out of the call entirely.
  const convInfoWith = (...roundingArgs) => conv_util.computeConv2DInfo(
    inShape, [fSize, fSize, inChannels, chMul], stride, dilation, pad,
    ...roundingArgs);

  it('Default truncate the output dimension of Conv Layer', () => {
    expect(convInfoWith().outShape).toEqual([batchSize, 3, 3, chMul]);
  });

  it('Floor the output dimension of Conv Layer', () => {
    expect(convInfoWith('floor').outShape).toEqual([batchSize, 3, 3, chMul]);
  });

  it('Round the output dimension of Conv Layer', () => {
    expect(convInfoWith('round').outShape).toEqual([batchSize, 4, 4, chMul]);
  });

  it('Ceil the output dimension of Conv Layer', () => {
    expect(convInfoWith('ceil').outShape).toEqual([batchSize, 4, 4, chMul]);
  });
});
// Output-shape rounding tests for conv_util.computePool2DInfo. A 2x2 window
// with stride 2 and pad 1 over a 5x5 input yields a fractional output size
// (3.5), so truncate/floor give 3 while round/ceil give 4.
describe('conv_util computePoolInfo roundingMode', () => {
  const inChannels = 6;
  const batchSize = 1;
  const inSize = 5;
  const inShape = [batchSize, inSize, inSize, inChannels];
  const fSize = 2;
  const stride = 2;
  const dilation = 1;
  const pad = 1;

  // computePool2DInfo takes (inShape, filterSize, strides, dilations, pad,
  // roundingMode?), matching the order computePool3DInfo is called with in
  // this file. Dilation and pad are both 1 here, so a swapped order would
  // still pass silently; keep them in signature order.

  it('Default truncate the output dimension of Pool Layer', () => {
    // Omit roundingMode so the default (truncation) path is what is tested.
    const poolInfo = conv_util.computePool2DInfo(
        inShape, [fSize, fSize], stride, dilation, pad);
    expect(poolInfo.outShape).toEqual([batchSize, 3, 3, inChannels]);
  });

  it('Floor the output dimension of Pool Layer', () => {
    const poolInfo = conv_util.computePool2DInfo(
        inShape, [fSize, fSize], stride, dilation, pad, 'floor');
    expect(poolInfo.outShape).toEqual([batchSize, 3, 3, inChannels]);
  });

  it('Round the output dimension of Pool Layer', () => {
    const poolInfo = conv_util.computePool2DInfo(
        inShape, [fSize, fSize], stride, dilation, pad, 'round');
    expect(poolInfo.outShape).toEqual([batchSize, 4, 4, inChannels]);
  });

  it('Ceil the output dimension of Pool Layer', () => {
    const poolInfo = conv_util.computePool2DInfo(
        inShape, [fSize, fSize], stride, dilation, pad, 'ceil');
    expect(poolInfo.outShape).toEqual([batchSize, 4, 4, inChannels]);
  });
});
describe('conv_util computePool3DInfo', () => {
  // Every fixture in this suite uses a single batch and a single channel, a
  // cubic input, a cubic filter, and the same stride/dilation on all three
  // spatial axes. This helper checks the fields those cases share, so each
  // `it` only states the expected output extent and the effective (dilated)
  // filter extent.
  function expectCubicPoolInfo(poolInfo, outSize, effectiveFilterSize) {
    expect(poolInfo.batchSize).toEqual(1);
    expect(poolInfo.outDepth).toEqual(outSize);
    expect(poolInfo.outHeight).toEqual(outSize);
    expect(poolInfo.outWidth).toEqual(outSize);
    expect(poolInfo.outChannels).toEqual(1);
    expect(poolInfo.effectiveFilterDepth).toEqual(effectiveFilterSize);
    expect(poolInfo.effectiveFilterWidth).toEqual(effectiveFilterSize);
    expect(poolInfo.effectiveFilterHeight).toEqual(effectiveFilterSize);
  }

  // Checks padding that is the same on every spatial axis: `before` is the
  // expected top/left/front padding and `after` the bottom/right/back padding.
  function expectUniformPadInfo(padInfo, before, after, type) {
    expect(padInfo.top).toEqual(before);
    expect(padInfo.bottom).toEqual(after);
    expect(padInfo.left).toEqual(before);
    expect(padInfo.right).toEqual(after);
    expect(padInfo.front).toEqual(before);
    expect(padInfo.back).toEqual(after);
    expect(padInfo.type).toEqual(type);
  }

  it('1x1x1 pool over 1x1x1 array with valid pad', () => {
    const inShape = [1, 1, 1, 1, 1];
    const filterSize = 1;
    const stride = 1;
    const dilation = 1;
    const poolInfo = conv_util.computePool3DInfo(inShape, filterSize, stride, dilation, 'valid');
    expectCubicPoolInfo(poolInfo, 1, 1);
  });

  it('1x1x1 pool over 3x3x3 array with valid pad', () => {
    const inShape = [1, 3, 3, 3, 1];
    const filterSize = 1;
    const stride = 1;
    const dilation = 1;
    const poolInfo = conv_util.computePool3DInfo(inShape, filterSize, stride, dilation, 'valid');
    expectCubicPoolInfo(poolInfo, 3, 1);
  });

  it('2x2x2 pool over 3x3x3 array with same pad', () => {
    const inShape = [1, 3, 3, 3, 1];
    const filterSize = 2;
    const stride = 1;
    const dilation = 1;
    const poolInfo = conv_util.computePool3DInfo(inShape, filterSize, stride, dilation, 'same');
    expectCubicPoolInfo(poolInfo, 3, 2);
    // An even filter under 'same' padding puts the extra pixel after the
    // input (bottom/right/back), as asserted here.
    expectUniformPadInfo(poolInfo.padInfo, 0, 1, 'SAME');
  });

  it('2x2x2 pool over 3x3x3 array with valid pad', () => {
    const inShape = [1, 3, 3, 3, 1];
    const filterSize = 2;
    const stride = 1;
    const dilation = 1;
    const poolInfo = conv_util.computePool3DInfo(inShape, filterSize, stride, dilation, 'valid');
    expectCubicPoolInfo(poolInfo, 2, 2);
  });

  it('2x2x2 pool over 4x4x4 array with valid pad, stride 2', () => {
    const inShape = [1, 4, 4, 4, 1];
    const filterSize = 2;
    const stride = 2;
    const dilation = 1;
    const poolInfo = conv_util.computePool3DInfo(inShape, filterSize, stride, dilation, 'valid');
    expectCubicPoolInfo(poolInfo, 2, 2);
  });

  it('2x2x2 pool over 3x3x3 array with valid pad, dilation 2', () => {
    const inShape = [1, 3, 3, 3, 1];
    const filterSize = 2;
    const stride = 1;
    const dilation = 2;
    const poolInfo = conv_util.computePool3DInfo(inShape, filterSize, stride, dilation, 'valid');
    // A 2-wide filter with dilation 2 spans 3 input pixels per axis.
    expectCubicPoolInfo(poolInfo, 1, 3);
  });

  it('2x2x2 pool over 3x3x3 array with pad 1, roundingMode floor', () => {
    const inShape = [1, 3, 3, 3, 1];
    const filterSize = 2;
    const stride = 1;
    const dilation = 1;
    const poolInfo = conv_util.computePool3DInfo(inShape, filterSize, stride, dilation, 1, 'floor');
    expectCubicPoolInfo(poolInfo, 4, 2);
    expectUniformPadInfo(poolInfo.padInfo, 1, 1, 'NUMBER');
  });

  it('throws unknown dataFormat', () => {
    const inShape = [1, 3, 3, 3, 1];
    const filterSize = 2;
    const stride = 1;
    const dilation = 1;
    const fakeDataFormat = 'fakeFormat';
    // Match the message so that an unrelated exception (e.g. from a malformed
    // argument) cannot satisfy this assertion by accident.
    expect(() => conv_util.computePool3DInfo(inShape, filterSize, stride, dilation, 1, 'floor', fakeDataFormat))
      .toThrowError(/Unknown dataFormat/);
  });
});
describe('conv_util convertConv2DDataFormat', () => {
  // Maps the TF-style 2D data-format strings onto the channel-position names
  // used internally by conv_util.
  it('convert NHWC to channelsLast', () => {
    expect(conv_util.convertConv2DDataFormat('NHWC')).toEqual('channelsLast');
  });

  it('convert NCHW to channelsFirst', () => {
    expect(conv_util.convertConv2DDataFormat('NCHW')).toEqual('channelsFirst');
  });

  it('throws unknown dataFormat', () => {
    // Match the message so that an unrelated exception cannot satisfy this
    // assertion by accident.
    expect(() => conv_util.convertConv2DDataFormat('FakeFormat'))
      .toThrowError(/Unknown dataFormat/);
  });
});
837//# sourceMappingURL=data:application/json;base64,
\No newline at end of file