UNPKG

3.02 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2020 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17import { ENGINE } from '../engine';
18import { AvgPoolGrad } from '../kernel_names';
19import { convertToTensor } from '../tensor_util_env';
20import * as util from '../util';
21import { op } from './operation';
22import { reshape } from './reshape';
/**
 * Computes the backprop of a 2D average pool.
 *
 * Both `dy` and `input` must have the same rank. Rank-3 tensors are treated
 * as a batch of one: they are reshaped to rank 4 before the kernel runs and
 * the kernel result is reshaped back to rank 3 before it is returned.
 *
 * @param dy The dy error, of rank 4 or rank 3 of shape
 *     [batchSize, height, width, channels]. If rank 3, batch of 1 is
 *     assumed.
 * @param input The input image, of rank 4 or rank 3 of shape
 *     [batchSize, height, width, channels]. If rank 3, batch of 1 is
 *     assumed.
 * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
 *     `filterSize` is a single number, then `filterHeight == filterWidth`.
 * @param strides The strides of the pooling: `[strideHeight, strideWidth]`.
 *     If `strides` is a single number, then `strideHeight == strideWidth`.
 * @param pad A string from: 'same', 'valid'. The type of padding algorithm
 *     used in the forward prop of the op.
 * @returns The result of the `AvgPoolGrad` kernel, with the leading batch
 *     dimension dropped again when the arguments were rank 3.
 */
function avgPoolGrad_(dy, input, filterSize, strides, pad) {
  const dyTensor = convertToTensor(dy, 'dy', 'avgPoolGrad');
  const inputTensor = convertToTensor(input, 'input', 'avgPoolGrad');

  util.assert(
    inputTensor.rank === dyTensor.rank,
    () => `Rank of input (${inputTensor.rank}) does not match rank of dy (${dyTensor.rank})`
  );

  // Rank-3 arguments are promoted to rank 4 by prepending a batch dim of 1.
  const needsBatchDim = inputTensor.rank === 3;
  const prependBatchDim = (tensor) => {
    const [dim0, dim1, dim2] = tensor.shape;
    return reshape(tensor, [1, dim0, dim1, dim2]);
  };

  const input4D = needsBatchDim ? prependBatchDim(inputTensor) : inputTensor;
  const dy4D = needsBatchDim ? prependBatchDim(dyTensor) : dyTensor;

  util.assert(
    dy4D.rank === 4,
    () => `Error in avgPoolGrad: dy must be rank 4 but got rank ${dy4D.rank}.`
  );
  util.assert(
    input4D.rank === 4,
    () => `Error in avgPoolGrad: input must be rank 4 but got rank ${input4D.rank}.`
  );

  const result = ENGINE.runKernel(
    AvgPoolGrad,
    { dy: dy4D, input: input4D },
    { filterSize, strides, pad }
  );

  if (!needsBatchDim) {
    return result;
  }

  // Undo the batch promotion so the caller gets back a rank-3 tensor.
  const [, height, width, channels] = result.shape;
  return reshape(result, [height, width, channels]);
}
/**
 * Public `avgPoolGrad` binding: the value obtained by passing
 * `avgPoolGrad_` through `op` (imported from './operation').
 */
const avgPoolGrad = op({ avgPoolGrad_ });
export { avgPoolGrad };
66//# sourceMappingURL=avg_pool_grad.js.map
\No newline at end of file