// UNPKG
//
// 3.77 kB · JavaScript · View Raw
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
17import { ENGINE } from '../engine';
18import { Conv2DBackpropFilter } from '../kernel_names';
19import * as util from '../util';
20import { op } from './operation';
21import { reshape } from './reshape';
/**
 * Computes the derivative of the filter of a 2D convolution.
 *
 * Rank-3 `x` / `dy` are promoted to rank 4 by prepending a batch dimension of
 * size 1. The arguments are validated and then dispatched to the registered
 * `Conv2DBackpropFilter` kernel.
 *
 * @param x The input tensor, of rank 4 or rank 3 of shape
 *     [batch, height, width, inChannels]. If rank 3, batch of 1 is assumed.
 * @param dy The dy image, of rank 4 or rank 3, of shape
 *     [batch, height, width, outDepth]. If rank 3, batch of 1 is assumed.
 * @param filterShape The shape of the filter, length 4,
 *     [filterHeight, filterWidth, inDepth, outDepth].
 * @param strides The strides of the convolution: [strideHeight,
 *     strideWidth].
 * @param pad A string from: 'same', 'valid'. The type of padding algorithm
 *     used in the forward prop of the op. When `dimRoundingMode` is given,
 *     `pad` must instead be an integer.
 * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to
 *     "NHWC". Specify the data format of the input and output data. With the
 *     default format "NHWC", the data is stored in the order of: [batch,
 *     height, width, channels]; with "NCHW" the channel count is read from
 *     axis 1 instead of axis 3.
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 * @returns Whatever the `Conv2DBackpropFilter` kernel returns (presumably the
 *     filter gradient with shape `filterShape`; confirm against the kernel).
 * @throws If the ranks, filter shape, or channel depths are inconsistent, or
 *     if `dimRoundingMode` is set while `pad` is not an integer.
 */
function conv2DBackpropFilter_(x, dy, filterShape, strides, pad, dataFormat = 'NHWC', dimRoundingMode) {
  // Promote rank-3 inputs to rank 4 by prepending a batch dimension of 1.
  let x4D = x;
  if (x.rank === 3) {
    x4D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
  }
  let dy4D = dy;
  if (dy.rank === 3) {
    dy4D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
  }

  util.assert(
      x4D.rank === 4,
      () => `Error in conv2dDerFilter: input must be rank 4, but got shape ` +
          `${x4D.shape}.`);
  util.assert(
      dy4D.rank === 4,
      () => `Error in conv2dDerFilter: dy must be rank 4, but got shape ` +
          `${dy4D.shape}.`);
  util.assert(
      filterShape.length === 4,
      () => `Error in conv2dDerFilter: filterShape must be length 4, but got ` +
          `${filterShape}.`);

  // The channel axis depends on the layout: last axis for NHWC, axis 1
  // otherwise (NCHW).
  const channelAxis = dataFormat === 'NHWC' ? 3 : 1;
  const inDepth = x4D.shape[channelAxis];
  const outDepth = dy4D.shape[channelAxis];
  util.assert(
      inDepth === filterShape[2],
      () => `Error in conv2dDerFilter: depth of input (${inDepth}) must ` +
          `match input depth in filter (${filterShape[2]}).`);
  util.assert(
      outDepth === filterShape[3],
      () => `Error in conv2dDerFilter: depth of dy (${outDepth}) must ` +
          `match output depth for filter (${filterShape[3]}).`);

  // An explicit rounding mode is only meaningful with a numeric pad.
  if (dimRoundingMode != null) {
    util.assert(
        util.isInt(pad),
        () => `Error in conv2dDerFilter: pad must be an integer when using ` +
            `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
  }

  const inputs = { x: x4D, dy: dy4D };
  const attrs = { strides, pad, dataFormat, dimRoundingMode, filterShape };
  return ENGINE.runKernel(Conv2DBackpropFilter, inputs, attrs);
}
export const conv2DBackpropFilter = op({ conv2DBackpropFilter_ });
//# sourceMappingURL=conv2d_backprop_filter.js.map
\No newline at end of file