1 | /**
|
2 | * @license
|
3 | * Copyright 2020 Google LLC. All Rights Reserved.
|
4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | * you may not use this file except in compliance with the License.
|
6 | * You may obtain a copy of the License at
|
7 | *
|
8 | * http://www.apache.org/licenses/LICENSE-2.0
|
9 | *
|
10 | * Unless required by applicable law or agreed to in writing, software
|
11 | * distributed under the License is distributed on an "AS IS" BASIS,
|
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 | * See the License for the specific language governing permissions and
|
14 | * limitations under the License.
|
15 | * =============================================================================
|
16 | */
|
17 | import { ENGINE } from '../engine';
|
18 | import { SpaceToBatchND } from '../kernel_names';
|
19 | import { convertToTensor } from '../tensor_util_env';
|
20 | import * as util from '../util';
|
21 | import { op } from './operation';
|
/**
 * This operation divides "spatial" dimensions `[1, ..., M]` of the input into
 * a grid of blocks of shape `blockShape`, and interleaves these blocks with
 * the "batch" dimension (0) such that in the output, the spatial
 * dimensions `[1, ..., M]` correspond to the position within the grid,
 * and the batch dimension combines both the position within a spatial block
 * and the original batch position. Prior to division into blocks,
 * the spatial dimensions of the input are optionally zero-padded
 * according to `paddings`.
 *
 * This operation is equivalent to the following steps:
 *
 * 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the input
 * according to `paddings` to produce `padded` of shape `paddedShape`.
 *
 * 2. Reshape `padded` to `reshapedPadded` of shape:
 * `[batch] + [paddedShape[1] / blockShape[0], blockShape[0], ...,
 * paddedShape[M] / blockShape[M-1], blockShape[M-1]] + remainingShape`
 *
 * 3. Permute dimensions of `reshapedPadded` to produce `permutedReshapedPadded`
 * of shape: `blockShape + [batch] + [paddedShape[1] / blockShape[0], ...,
 * paddedShape[M] / blockShape[M-1]] + remainingShape`
 *
 * 4. Reshape `permutedReshapedPadded` to flatten `blockShape` into the
 * batch dimension, producing an output tensor of shape:
 * `[batch * prod(blockShape)] + [paddedShape[1] / blockShape[0], ...,
 * paddedShape[M] / blockShape[M-1]] + remainingShape`
 *
 * ```js
 * const x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
 * const blockShape = [2, 2];
 * const paddings = [[0, 0], [0, 0]];
 *
 * x.spaceToBatchND(blockShape, paddings).print();
 * ```
 *
 * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape +
 * remainingShape`, where spatialShape has `M` dimensions.
 * @param blockShape A 1-D array. Must have shape `[M]`; all values must
 * be >= 1.
 * @param paddings A 2-D array. Must have shape `[M, 2]`; all values must be
 * >= 0. `paddings[i] = [padStart, padEnd]` specifies the amount to zero-pad
 * from input dimension `i + 1`, which corresponds to spatial dimension `i`. It
 * is required that
 * `(inputShape[i + 1] + padStart + padEnd) % blockShape[i] === 0`.
 * @returns Whatever the `SpaceToBatchND` kernel returns when run through
 * `ENGINE.runKernel` with input `x` and attributes `{blockShape, paddings}`.
 * @throws (via `util.assert`) if `x` has rank <= M, if `paddings` does not
 * have M rows, if any `blockShape` value is < 1, if any `paddings` row is not
 * a pair of values >= 0, or if a padded spatial dimension is not divisible by
 * its block size.
 *
 * @doc {heading: 'Tensors', subheading: 'Transformations'}
 */
function spaceToBatchND_(x, blockShape, paddings) {
  const $x = convertToTensor(x, 'x', 'spaceToBatchND');
  // M: the number of spatial dimensions that get split into blocks.
  const numSpatialDims = blockShape.length;

  util.assert(
      $x.rank >= 1 + numSpatialDims,
      () => `input rank ${$x.rank} should be > than [blockShape] ${blockShape.length}`);
  util.assert(
      paddings.length === numSpatialDims,
      () => `paddings.shape[0] ${paddings.length} must be equal to [blockShape] ${blockShape.length}`);

  // Enforce the documented value constraints up front. The modulo-based
  // divisibility check below would otherwise accept negative block sizes
  // (e.g. `4 % -2 === 0`) and misreport zero block sizes or malformed
  // padding rows (both produce NaN) as a divisibility failure.
  util.assert(
      blockShape.every((size) => size >= 1),
      () => `blockShape values ${blockShape.toString()} must all be >= 1`);
  util.assert(
      paddings.every(
          (pair) => pair != null && pair.length === 2 && pair[0] >= 0 &&
              pair[1] >= 0),
      () => `paddings ${paddings.toString()} must have shape [${numSpatialDims}, 2] with all values >= 0`);

  // Only dimensions 1..M are blocked; the batch dimension and any trailing
  // dimensions do not participate in the divisibility requirement.
  const spatialShape = $x.shape.slice(1, 1 + numSpatialDims);
  util.assert(
      spatialShape.every(
          (dim, i) =>
              (dim + paddings[i][0] + paddings[i][1]) % blockShape[i] === 0),
      () => `input spatial dimensions ${spatialShape} with paddings ${paddings.toString()} must be divisible by blockShapes ${blockShape.toString()}`);

  const inputs = { x: $x };
  const attrs = { blockShape, paddings };
  return ENGINE.runKernel(SpaceToBatchND, inputs, attrs);
}
|
// Public op: `spaceToBatchND_` wrapped by `op`, exported as `spaceToBatchND`.
const spaceToBatchND = op({ spaceToBatchND_ });
export { spaceToBatchND };
|
88 | //# sourceMappingURL=space_to_batch_nd.js.map |
\ | No newline at end of file |