// UNPKG
//
// 45.8 kB JavaScript (View Raw)
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import * as tf from '../index';
import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
import { expectArraysClose } from '../test_util';

/**
 * Tests for tf.dilation2d, run under every environment in ALL_ENVS.
 *
 * Covers:
 *  - forward results for 'valid' and 'same' padding, including multi-channel
 *    input, batches, non-square filters, dilations and uneven strides;
 *  - argument validation (input rank, filter rank, data format);
 *  - gradients with respect to both the input and the filter.
 */
describeWithFlags('dilation2d', ALL_ENVS, () => {
  it('valid padding.', async () => {
    const inputShape = [1, 2, 2, 1];
    const filterShape = [2, 2, 1];
    const x = tf.tensor4d([.1, .2, .3, .4], inputShape);
    const filter = tf.tensor3d([.4, .3, .1, .0], filterShape);

    const result = tf.dilation2d(x, filter, 1 /* strides */, 'valid');

    // A single window: max(.1 + .4, .2 + .3, .3 + .1, .4 + .0) = .5.
    expect(result.shape).toEqual([1, 1, 1, 1]);
    expectArraysClose(await result.data(), [.5]);
  });

  it('same padding.', async () => {
    const inputShape = [1, 2, 2, 1];
    const filterShape = [2, 2, 1];
    const x = tf.tensor4d([.1, .2, .3, .4], inputShape);
    const filter = tf.tensor3d([.4, .3, .1, .0], filterShape);

    const result = tf.dilation2d(x, filter, 1 /* strides */, 'same');

    // The extra padding for the even-sized filter is expected on the
    // bottom/right: the last output only sees x[1][1] + filter[0][0] = .8.
    expect(result.shape).toEqual([1, 2, 2, 1]);
    expectArraysClose(await result.data(), [.5, .6, .7, .8]);
  });

  it('same padding depth 3.', async () => {
    const inputShape = [1, 2, 2, 3];
    const filterShape = [2, 2, 3];
    const x = tf.tensor4d(
        [.1, .2, .0, .2, .3, .1, .3, .4, .2, .4, .5, .3], inputShape);
    const filter = tf.tensor3d(
        [.4, .5, .3, .3, .4, .2, .1, .2, .0, .0, .1, -.1], filterShape);

    const result = tf.dilation2d(x, filter, 1 /* strides */, 'same');

    expect(result.shape).toEqual([1, 2, 2, 3]);
    expectArraysClose(
        await result.data(),
        [.5, .7, .3, .6, .8, .4, .7, .9, .5, .8, 1., .6]);
  });

  it('same padding batch 2.', async () => {
    const inputShape = [2, 2, 2, 1];
    const filterShape = [2, 2, 1];
    const x = tf.tensor4d([.1, .2, .3, .4, .2, .3, .4, .5], inputShape);
    const filter = tf.tensor3d([.4, .3, .1, .0], filterShape);

    const result = tf.dilation2d(x, filter, 1 /* strides */, 'same');

    expect(result.shape).toEqual([2, 2, 2, 1]);
    expectArraysClose(
        await result.data(), [.5, .6, .7, .8, .6, .7, .8, .9]);
  });

  it('same padding filter 2.', async () => {
    const inputShape = [1, 3, 3, 1];
    const filterShape = [2, 2, 1];
    const x = tf.tensor4d([.1, .2, .3, .4, .5, .6, .7, .8, .9], inputShape);
    const filter = tf.tensor3d([.4, .3, .1, .2], filterShape);

    const result = tf.dilation2d(x, filter, 1 /* strides */, 'same');

    expect(result.shape).toEqual([1, 3, 3, 1]);
    expectArraysClose(
        await result.data(), [.7, .8, .7, 1, 1.1, 1, 1.1, 1.2, 1.3]);
  });

  it('valid padding non-square-window.', async () => {
    const inputShape = [1, 2, 2, 1];
    const filterShape = [1, 2, 1];
    const x = tf.tensor4d([.1, .2, .3, .4], inputShape);
    const filter = tf.tensor3d([.4, .3], filterShape);

    const result = tf.dilation2d(x, filter, 1 /* strides */, 'valid');

    expect(result.shape).toEqual([1, 2, 1, 1]);
    expectArraysClose(await result.data(), [.5, .7]);
  });

  it('same padding dilations 2.', async () => {
    const inputShape = [1, 3, 3, 1];
    const filterShape = [2, 2, 1];
    const x = tf.tensor4d([.1, .2, .3, .4, .5, .6, .7, .8, .9], inputShape);
    const filter = tf.tensor3d([.4, .3, .1, .2], filterShape);

    const result = tf.dilation2d(x, filter, 1 /* strides */, 'same', 2);

    // Because dilations = 2, the effective filter is [3, 3, 1]:
    // filter_eff = [[[.4], [.0], [.3]],
    //               [[.0], [.0], [.0]],
    //               [[.1], [.0], [.2]]]
    expect(result.shape).toEqual([1, 3, 3, 1]);
    expectArraysClose(
        await result.data(), [.7, .8, .6, 1., 1.1, .9, .8, .9, .9]);
  });

  it('valid padding uneven stride.', async () => {
    const inputShape = [1, 3, 4, 1];
    const filterShape = [2, 2, 1];
    const x = tf.tensor4d(
        [.1, .2, .3, .4, .5, .6, .7, .8, .9, 1., 1.1, 1.2], inputShape);
    const filter = tf.tensor3d([.4, .3, .1, .2], filterShape);

    // Stride 1 along height and 2 along width: 2 x 2 valid windows.
    const result = tf.dilation2d(x, filter, [1, 2] /* strides */, 'valid');

    expect(result.shape).toEqual([1, 2, 2, 1]);
    expectArraysClose(await result.data(), [.8, 1., 1.2, 1.4]);
  });

  it('throws when input rank is not 3 or 4.', () => {
    const filterShape = [1, 1, 1];
    const x = tf.tensor1d([.5]);
    const filter = tf.tensor3d([.4], filterShape);

    expect(() => tf.dilation2d(x, filter, 1, 'valid')).toThrowError();
  });

  it('throws when filter is not rank 3.', () => {
    const inputShape = [1, 2, 2, 1];
    const filterShape = [2, 2];
    const x = tf.tensor4d([.1, .2, .3, .4], inputShape);
    const filter = tf.tensor2d([.4, .3, .1, .0], filterShape);

    expect(() => tf.dilation2d(x, filter, 1, 'valid')).toThrowError();
  });

  it('throws when data format is not NHWC.', () => {
    const inputShape = [1, 2, 2, 1];
    const filterShape = [2, 2, 1];
    const x = tf.tensor4d([.1, .2, .3, .4], inputShape);
    const filter = tf.tensor3d([.4, .3, .1, .0], filterShape);
    const dataFormat = 'NCHW';

    expect(
        () => tf.dilation2d(
            x, filter, 1 /* strides */, 'valid', 1, dataFormat))
        .toThrowError();
  });

  it('dilation gradient valid padding.', async () => {
    const inputShape = [1, 3, 3, 1];
    const filterShape = [1, 1, 1];
    const x = tf.tensor4d([.1, .2, .3, .4, .5, .6, .7, .8, .9], inputShape);
    const filter = tf.tensor3d([.5], filterShape);
    const dy = tf.tensor4d([.2, .3, .4, .2, .1, 1., .2, .3, .4], inputShape);

    const grads = tf.grads(
        (input, kernel) => input.dilation2d(kernel, 1, 'valid'));
    const [dx, dfilter] = grads([x, filter], dy);

    // With a 1x1 filter each input pixel is its own window maximum, so
    // dx equals dy and dfilter is the sum of dy (3.1).
    expect(dx.shape).toEqual(x.shape);
    expectArraysClose(await dx.data(), [.2, .3, .4, .2, .1, 1., .2, .3, .4]);
    expect(dfilter.shape).toEqual(filterShape);
    expectArraysClose(await dfilter.data(), [3.1]);
  });

  it('dilation gradient same padding.', async () => {
    const inputShape = [1, 3, 3, 1];
    const filterShape = [1, 1, 1];
    const x = tf.tensor4d([.1, .2, .3, .4, .5, .6, .7, .8, .9], inputShape);
    const filter = tf.tensor3d([.5], filterShape);
    const dy = tf.tensor4d([.2, .3, .4, .2, .1, 1., .2, .3, .4], inputShape);

    const grads = tf.grads(
        (input, kernel) => input.dilation2d(kernel, 1, 'same'));
    const [dx, dfilter] = grads([x, filter], dy);

    // A 1x1 filter needs no padding, so this matches the 'valid' case.
    expect(dx.shape).toEqual(x.shape);
    expectArraysClose(await dx.data(), [.2, .3, .4, .2, .1, 1., .2, .3, .4]);
    expect(dfilter.shape).toEqual(filterShape);
    expectArraysClose(await dfilter.data(), [3.1]);
  });

  it('dilation gradient same padding depth 3.', async () => {
    const inputShape = [1, 2, 2, 3];
    const filterShape = [1, 1, 3];
    const x = tf.tensor4d(
        [.1, .2, .0, .2, .3, .1, .3, .4, .2, .4, .5, .3], inputShape);
    const filter = tf.tensor3d([.4, .5, .6], filterShape);
    const dy = tf.tensor4d(
        [.2, .3, .4, .2, .1, 1., .2, .3, .4, .8, -.1, .1], inputShape);

    const grads = tf.grads(
        (input, kernel) => input.dilation2d(kernel, 1, 'same'));
    const [dx, dfilter] = grads([x, filter], dy);

    // 1x1 filter: dx equals dy, and dfilter is the per-channel sum of dy.
    expect(dx.shape).toEqual(x.shape);
    expectArraysClose(
        await dx.data(), [.2, .3, .4, .2, .1, 1., .2, .3, .4, .8, -.1, .1]);
    expect(dfilter.shape).toEqual(filterShape);
    expectArraysClose(await dfilter.data(), [1.4, .6, 1.9]);
  });

  it('dilation gradient valid padding filter 2.', async () => {
    const inputShape = [1, 3, 3, 1];
    const filterShape = [2, 2, 1];
    const dyShape = [1, 2, 2, 1];
    const x = tf.tensor4d([.1, .2, .3, .4, .5, .6, .7, .8, .9], inputShape);
    const filter = tf.tensor3d([.4, .3, .1, .2], filterShape);
    const dy = tf.tensor4d([.2, .3, .4, .2], dyShape);

    const grads = tf.grads(
        (input, kernel) => input.dilation2d(kernel, 1, 'valid'));
    const [dx, dfilter] = grads([x, filter], dy);

    // Every window's maximum is reached at its bottom-right tap (filter
    // value .2), so all of dy flows to that tap: dfilter[1][1] = 1.1.
    expect(dx.shape).toEqual(x.shape);
    expectArraysClose(await dx.data(), [0, 0, 0, 0, .2, .3, 0, .4, .2]);
    expect(dfilter.shape).toEqual(filterShape);
    expectArraysClose(await dfilter.data(), [0, 0, 0, 1.1]);
  });

  it('dilation gradient same padding filter 2 depth 3.', async () => {
    const inputShape = [1, 3, 3, 3];
    const filterShape = [2, 2, 3];
    const x = tf.tensor4d(
        [
          .1, .2, .3, .4, .5, .6, .7, .8, .9, .3, .2, .3, .4, .5,
          .1, .9, .6, .3, .4, .5, .6, .2, .3, .5, .1, .2, .3
        ],
        inputShape);
    const filter = tf.tensor3d(
        [.4, .3, .1, .2, .2, .1, .7, .3, .8, .4, .9, .1], filterShape);
    const dy = tf.tensor4d(
        [
          .2, .3, .4, .2, .1, .5, 0, .8, .7, .1, .2, .1, .2, .3,
          .4, .5, .6, .6, .6, .7, .8, .3, .2, .1, .2, .4, .2
        ],
        inputShape);

    const grads = tf.grads(
        (input, kernel) => input.dilation2d(kernel, 1, 'same'));
    const [dx, dfilter] = grads([x, filter], dy);

    expect(dx.shape).toEqual(x.shape);
    expectArraysClose(await dx.data(), [
      0, 0, 0, 0, 0, 0, 0, .8, .5, .2, 0, .4, 0, .3,
      0, .9, .7, .7, .7, .7, .9, .3, .4, .5, .2, .7, .8
    ]);
    expect(dfilter.shape).toEqual(filterShape);
    expectArraysClose(
        await dfilter.data(),
        [1.6, 2.7, 1.1, .2, 0, .5, .3, 0, 2.2, .2, .9, 0]);
  });
});
//# sourceMappingURL=data:application/json;base64,
\No newline at end of file