UNPKG

4.24 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2020 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17import { ENGINE } from '../engine';
18import { SpaceToBatchND } from '../kernel_names';
19import { convertToTensor } from '../tensor_util_env';
20import * as util from '../util';
21import { op } from './operation';
/**
 * This operation divides "spatial" dimensions `[1, ..., M]` of the input into
 * a grid of blocks of shape `blockShape`, and interleaves these blocks with
 * the "batch" dimension (0) such that in the output, the spatial
 * dimensions `[1, ..., M]` correspond to the position within the grid,
 * and the batch dimension combines both the position within a spatial block
 * and the original batch position. Prior to division into blocks,
 * the spatial dimensions of the input are optionally zero padded
 * according to `paddings`. See below for a precise description.
 *
 * ```js
 * const x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
 * const blockShape = [2, 2];
 * const paddings = [[0, 0], [0, 0]];
 *
 * x.spaceToBatchND(blockShape, paddings).print();
 * ```
 *
 * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape +
 * remainingShape`, where spatialShape has `M` dimensions.
 * @param blockShape A 1-D array. Must have shape `[M]`, all values must
 * be >= 1.
 * @param paddings A 2-D array. Must have shape `[M, 2]`, all values must be >=
 * 0. `paddings[i] = [padStart, padEnd]` specifies the amount to zero-pad
 * from input dimension `i + 1`, which corresponds to spatial dimension `i`. It
 * is required that
 * `(inputShape[i + 1] + padStart + padEnd) % blockShape[i] === 0`
 *
 * This operation is equivalent to the following steps:
 *
 * 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the input
 *    according to `paddings` to produce `padded` of shape paddedShape.
 *
 * 2. Reshape `padded` to `reshapedPadded` of shape:
 *    `[batch] + [paddedShape[1] / blockShape[0], blockShape[0], ...,
 *    paddedShape[M] / blockShape[M-1], blockShape[M-1]] + remainingShape`
 *
 * 3. Permute dimensions of `reshapedPadded` to produce `permutedReshapedPadded`
 *    of shape: `blockShape + [batch] + [paddedShape[1] / blockShape[0], ...,
 *    paddedShape[M] / blockShape[M-1]] + remainingShape`
 *
 * 4. Reshape `permutedReshapedPadded` to flatten `blockShape` into the
 *    batch dimension, producing an output tensor of shape:
 *    `[batch * prod(blockShape)] + [paddedShape[1] / blockShape[0], ...,
 *    paddedShape[M] / blockShape[M-1]] + remainingShape`
 *
 * @returns Whatever `ENGINE.runKernel` returns for the `SpaceToBatchND`
 * kernel, invoked with inputs `{x}` and attrs `{blockShape, paddings}`.
 * @throws Via `util.assert` (assumed to throw on a false condition) when the
 * input rank is too small, `paddings` does not have one `[padStart, padEnd]`
 * pair per block dimension, any block size is < 1, any padding is < 0, or a
 * padded spatial dimension is not divisible by its block size.
 *
 * @doc {heading: 'Tensors', subheading: 'Transformations'}
 */
function spaceToBatchND_(x, blockShape, paddings) {
  const $x = convertToTensor(x, 'x', 'spaceToBatchND');
  const numSpatialDims = blockShape.length;

  util.assert($x.rank >= 1 + numSpatialDims, () => `input rank ${$x.rank} should be > than [blockShape] ${blockShape.length}`);
  util.assert(paddings.length === numSpatialDims, () => `paddings.shape[0] ${paddings.length} must be equal to [blockShape] ${blockShape.length}`);

  // Enforce the documented value ranges explicitly. Without these checks a
  // negative block size or negative padding can still satisfy the
  // divisibility test below (e.g. 4 % -2 === 0) and reach the kernel.
  util.assert(blockShape.every((blockSize) => blockSize >= 1), () => `blockShape values ${blockShape.toString()} must all be >= 1`);
  util.assert(paddings.every((pad) => pad != null && pad.length === 2 && pad[0] >= 0 && pad[1] >= 0), () => `paddings ${paddings.toString()} must have shape [${numSpatialDims}, 2] and values must all be >= 0`);

  // Spatial dimension i lives at input axis i + 1 (axis 0 is the batch). The
  // rank assertion above guarantees every such axis exists.
  const spatialDimsDivisible = blockShape.every((blockSize, i) => {
    const [padStart, padEnd] = paddings[i];
    return ($x.shape[i + 1] + padStart + padEnd) % blockSize === 0;
  });
  util.assert(spatialDimsDivisible, () => `input spatial dimensions ${$x.shape.slice(1)} with paddings ${paddings.toString()} must be divisible by blockShapes ${blockShape.toString()}`);

  const inputs = { x: $x };
  const attrs = { blockShape, paddings };
  return ENGINE.runKernel(SpaceToBatchND, inputs, attrs);
}
// Exported binding: the result of handing the implementation, keyed by its
// own name, to `op` from './operation'.
// NOTE(review): presumably `op` derives the public op name from that key —
// confirm against './operation'.
export const spaceToBatchND = op({ spaceToBatchND_ });
88//# sourceMappingURL=space_to_batch_nd.js.map
\No newline at end of file