// UNPKG
//
// 3.54 kB · JavaScript · View Raw
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { ENGINE } from '../engine';
import { Conv3DBackpropInputV2 } from '../kernel_names';
import * as util from '../util';
import { op } from './operation';
import { reshape } from './reshape';
/**
 * Computes the derivative of the input of a 3D convolution.
 *
 * @param xShape The shape of the input: `[batch, depth, height, width,
 *     inChannels]`. If length of 4, batch of 1 is assumed. Its length must
 *     equal `dy.rank`.
 * @param dy The derivative of the output, of rank 5 or rank 4 of shape
 *     `[batch, outDepth, outHeight, outWidth, outChannels]`. If rank 4, batch
 *     of 1 is assumed. Its last dimension must equal `filter.shape[4]`.
 * @param filter The filter, rank 5, of shape
 *     `[filterDepth, filterHeight, filterWidth, inChannels, outChannels]`.
 * @param strides The strides of the convolution: `[strideDepth, strideHeight,
 *     strideWidth]`.
 * @param pad The type of padding algorithm used:
 *    - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *    - `valid`: output will be smaller than input if filter is larger
 *       than 1x1.
 * @returns The result of the `Conv3DBackpropInputV2` kernel, which is given
 *     the (batch-promoted) `xShape` as its `inputShape` attribute. When `dy`
 *     was rank 4, the leading batch-of-1 dimension is removed again, so the
 *     result is rank 4.
 * @throws If any `util.assert` check fails:
 *     - `xShape.length` differs from `dy.rank`;
 *     - the (promoted) shapes are not rank 5;
 *     - the filter is not rank 5;
 *     - the channel counts disagree with the filter.
 */
function conv3DBackpropInput_(xShape, dy, filter, strides, pad) {
  util.assert(xShape.length === dy.rank, () => `Length of inShape ` +
      `(${xShape.length}) and rank of dy (${dy.rank}) must match`);

  // Promote an unbatched (rank-4) call to a batch of one so the kernel always
  // receives rank-5 operands; the batch dimension is stripped before return.
  let xShape5D = xShape;
  let dy5D = dy;
  let reshapedTo5D = false;
  if (dy.rank === 4) {
    reshapedTo5D = true;
    dy5D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]);
    xShape5D = [1, xShape[0], xShape[1], xShape[2], xShape[3]];
  }

  // Channel counts: the last axis of the input shape and of dy. These must
  // match the filter's input-channel (index 3) and output-channel (index 4)
  // axes respectively.
  const inDepth = xShape5D[4];
  const outDepth = dy5D.shape[4];

  util.assert(xShape5D.length === 5, () =>
      `Error in conv3dDerInput: inShape must be length 5, but got length ` +
      `${xShape5D.length}.`);
  util.assert(dy5D.rank === 5, () =>
      `Error in conv3dDerInput: dy must be rank 5, but got ` +
      `rank ${dy5D.rank}`);
  util.assert(filter.rank === 5, () =>
      `Error in conv3dDerInput: filter must be rank 5, but got ` +
      `rank ${filter.rank}`);
  util.assert(inDepth === filter.shape[3], () =>
      `Error in conv3dDerInput: depth of input (${inDepth}) must ` +
      `match input depth for filter ${filter.shape[3]}.`);
  util.assert(outDepth === filter.shape[4], () =>
      `Error in conv3dDerInput: depth of output (${outDepth}) must ` +
      `match output depth for filter ${filter.shape[4]}.`);

  const inputs = { dy: dy5D, filter };
  const attrs = { pad, strides, inputShape: xShape5D };
  const res = ENGINE.runKernel(Conv3DBackpropInputV2, inputs, attrs);

  if (reshapedTo5D) {
    // Undo the batch-of-one promotion performed above.
    return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
  }
  return res;
}

export const conv3DBackpropInput = op({ conv3DBackpropInput_ });
//# sourceMappingURL=conv3d_backprop_input.js.map
\No newline at end of file