1 | /**
|
2 | * @license
|
3 | * Copyright 2018 Google LLC
|
4 | *
|
5 | * Use of this source code is governed by an MIT-style
|
6 | * license that can be found in the LICENSE file or at
|
7 | * https://opensource.org/licenses/MIT.
|
8 | * =============================================================================
|
9 | */
|
10 | /**
|
11 | * Padding Layers.
|
12 | */
|
13 | // Porting Note: In Python Keras, the padding layers are in convolutional.py,
|
14 | // but we decided to put them in a separate file (padding.ts) for clarity.
|
15 | import * as tfc from '@tensorflow/tfjs-core';
|
16 | import { serialization, tidy } from '@tensorflow/tfjs-core';
|
17 | import { imageDataFormat } from '../backend/common';
|
18 | import { InputSpec, Layer } from '../engine/topology';
|
19 | import { ValueError } from '../errors';
|
20 | import { getExactlyOneShape, getExactlyOneTensor } from '../utils/types_utils';
|
21 | /**
|
22 | * Pads the middle dimension of a 3D tensor.
|
23 | *
|
24 | * @param x Input `tf.Tensor` to be padded.
|
25 | * @param padding `Array` of 2 integers, how many zeros to add at the start and
|
26 | * end of the middle dimension (i.e., dimension 1).
|
27 | * @return A padded 3D `tf.Tensor`.
|
28 | */
|
29 | export function temporalPadding(x, padding) {
|
30 | return tidy(() => {
|
31 | if (x.rank !== 3) {
|
32 | throw new ValueError(`temporalPadding expects input tensor to be 3-D, but received a ` +
|
33 | `${x.rank}-D tensor.`);
|
34 | }
|
35 | if (padding == null) {
|
36 | padding = [1, 1];
|
37 | }
|
38 | if (padding.length !== 2) {
|
39 | throw new ValueError(`temporalPadding expects input padding pattern to be a length-2 ` +
|
40 | `array, but received a length-${padding.length} array.`);
|
41 | }
|
42 | const pattern = [[0, 0], padding, [0, 0]];
|
43 | return tfc.pad(x, pattern);
|
44 | });
|
45 | }
|
46 | /**
|
47 | * Pads the 2nd and 3rd dimensions of a 4D tensor.
|
48 | *
|
49 | * @param x Input `tf.Tensor` to be padded.
|
50 | * @param padding `Array` of two `Array`s, each of which is an `Array` of two
|
51 | * integers. The amount of padding at the beginning and end of the 2nd and 3rd
|
52 | * dimensions, respectively.
|
53 | * @param dataFormat 'channelsLast' (default) or 'channelsFirst'.
|
54 | * @return Padded 4D `tf.Tensor`.
|
55 | */
|
56 | export function spatial2dPadding(x, padding, dataFormat) {
|
57 | return tidy(() => {
|
58 | if (x.rank !== 4) {
|
59 | throw new ValueError(`temporalPadding expects input tensor to be 4-D, but received a ` +
|
60 | `${x.rank}-D tensor.`);
|
61 | }
|
62 | if (padding == null) {
|
63 | padding = [[1, 1], [1, 1]];
|
64 | }
|
65 | if (padding.length !== 2 || padding[0].length !== 2 ||
|
66 | padding[1].length !== 2) {
|
67 | throw new ValueError('spatial2dPadding expects `padding` to be an Array of two Arrays, ' +
|
68 | 'each of which is an Array of two integers.');
|
69 | }
|
70 | if (dataFormat == null) {
|
71 | dataFormat = imageDataFormat();
|
72 | }
|
73 | if (dataFormat !== 'channelsLast' && dataFormat !== 'channelsFirst') {
|
74 | throw new ValueError(`Unknown data format: ${dataFormat}. ` +
|
75 | `Supported data formats are 'channelsLast' and 'channelsFirst.`);
|
76 | }
|
77 | let pattern;
|
78 | if (dataFormat === 'channelsFirst') {
|
79 | pattern = [[0, 0], [0, 0], padding[0], padding[1]];
|
80 | }
|
81 | else {
|
82 | pattern = [[0, 0], padding[0], padding[1], [0, 0]];
|
83 | }
|
84 | return tfc.pad(x, pattern);
|
85 | });
|
86 | }
|
/**
 * Expands the `padding` option accepted by `ZeroPadding2D` into the canonical
 * `[[top, bottom], [left, right]]` form. The argument is only read, never
 * written to.
 *
 * @param padding One of:
 *   - `null`/`undefined`: one element of padding on every side.
 *   - a number: the same symmetric padding for height and width.
 *   - an `Array` of 2 numbers: `[symmetricHeightPad, symmetricWidthPad]`.
 *   - an `Array` of 2 `Array`s: `[[topPad, bottomPad], [leftPad, rightPad]]`.
 * @return The padding as `[[top, bottom], [left, right]]`.
 * @throws ValueError If `padding` is an array whose length is not 2, if it
 *   mixes a number with a non-number, or if a nested array does not have
 *   length 2.
 */
function normalizeZeroPadding2DArg(padding) {
    if (padding == null) {
        return [[1, 1], [1, 1]];
    }
    if (typeof padding === 'number') {
        return [[padding, padding], [padding, padding]];
    }
    if (padding.length !== 2) {
        throw new ValueError(`ZeroPadding2D expects padding to be a length-2 array, but ` +
            `received a length-${padding.length} array.`);
    }
    const heightArg = padding[0];
    const widthArg = padding[1];
    if (typeof heightArg === 'number') {
        // Reject mixed input such as [1, [2, 3]]; it would otherwise produce a
        // nested, malformed width padding that only fails later inside pad().
        if (typeof widthArg !== 'number') {
            throw new ValueError(`ZeroPadding2D expects padding to be either two numbers or two ` +
                `length-2 arrays, but received a number followed by a non-number.`);
        }
        return [[heightArg, heightArg], [widthArg, widthArg]];
    }
    if (heightArg.length !== 2) {
        throw new ValueError(`ZeroPadding2D expects height padding to be a length-2 array, ` +
            `but received a length-${heightArg.length} array.`);
    }
    if (widthArg.length !== 2) {
        throw new ValueError(`ZeroPadding2D expects width padding to be a length-2 array, ` +
            `but received a length-${widthArg.length} array.`);
    }
    return [heightArg, widthArg];
}
/**
 * Computes the padded size of a single dimension.
 *
 * @param size Input size of the dimension; may be `null`/`undefined` or
 *   negative when it is not known.
 * @param pad `[before, after]` padding amounts for the dimension.
 * @return `size + before + after`, or `null` when `size` is `null`,
 *   `undefined` or negative.
 */
function paddedDimension(size, pad) {
    return size != null && size >= 0 ? size + pad[0] + pad[1] : null;
}
/**
 * Layer that pads the two spatial dimensions of a 4D input by delegating to
 * `spatial2dPadding`.
 *
 * `dataFormat` selects which axes are spatial: `'channelsFirst'` pads axes 2
 * and 3, anything else pads axes 1 and 2 (see `computeOutputShape`).
 */
export class ZeroPadding2D extends Layer {
    /**
     * @param args Optional layer config. `args.padding` may be a number, an
     *   `Array` of 2 numbers, or an `Array` of 2 length-2 `Array`s (default
     *   `[[1, 1], [1, 1]]`). `args.dataFormat` defaults to the value returned
     *   by `imageDataFormat()`. The object is passed on to the `Layer`
     *   constructor and is not modified here.
     * @throws ValueError If `args.padding` is malformed.
     */
    constructor(args) {
        if (args == null) {
            args = {};
        }
        super(args);
        this.dataFormat =
            args.dataFormat == null ? imageDataFormat() : args.dataFormat;
        this.padding = normalizeZeroPadding2DArg(args.padding);
        this.inputSpec = [new InputSpec({ ndim: 4 })];
    }
    /**
     * Computes the output shape for a given input shape.
     *
     * @param inputShape Input shape(s); reduced to a single shape with
     *   `getExactlyOneShape`.
     * @return A 4-element shape in which the two spatial dimensions are grown
     *   by the configured padding. A spatial dimension that is `null`,
     *   `undefined` or negative in the input is reported as `null`.
     */
    computeOutputShape(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        const heightPad = this.padding[0];
        const widthPad = this.padding[1];
        if (this.dataFormat === 'channelsFirst') {
            return [
                inputShape[0], inputShape[1],
                paddedDimension(inputShape[2], heightPad),
                paddedDimension(inputShape[3], widthPad)
            ];
        }
        return [
            inputShape[0],
            paddedDimension(inputShape[1], heightPad),
            paddedDimension(inputShape[2], widthPad),
            inputShape[3]
        ];
    }
    /**
     * Applies the padding to the input.
     *
     * @param inputs Input tensor(s); reduced to a single tensor with
     *   `getExactlyOneTensor`.
     * @param kwargs Unused by this layer.
     * @return The padded tensor produced by `spatial2dPadding`.
     */
    call(inputs, kwargs) {
        return tidy(() => spatial2dPadding(getExactlyOneTensor(inputs), this.padding, this.dataFormat));
    }
    /**
     * Returns the layer configuration.
     *
     * @return An object holding `padding` and `dataFormat`, merged with the
     *   base-class config. On a key collision the base-class entry wins,
     *   because it is assigned last.
     */
    getConfig() {
        const config = {
            padding: this.padding,
            dataFormat: this.dataFormat,
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}
/** @nocollapse */
ZeroPadding2D.className = 'ZeroPadding2D';
// Make the class known to the tfjs-core serialization registry.
serialization.registerClass(ZeroPadding2D);
|
184 | //# sourceMappingURL=data:application/json;base64,{"version":3,"file":"padding.js","sourceRoot":"","sources":["../../../../../../tfjs-layers/src/layers/padding.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH;;GAEG;AAEH,6EAA6E;AAC7E,4EAA4E;AAE5E,OAAO,KAAK,GAAG,MAAM,uBAAuB,CAAC;AAC7C,OAAO,EAAC,aAAa,EAAU,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAElE,OAAO,EAAC,eAAe,EAAC,MAAM,mBAAmB,CAAC;AAClD,OAAO,EAAC,SAAS,EAAE,KAAK,EAAY,MAAM,oBAAoB,CAAC;AAC/D,OAAO,EAAC,UAAU,EAAC,MAAM,WAAW,CAAC;AAGrC,OAAO,EAAC,kBAAkB,EAAE,mBAAmB,EAAC,MAAM,sBAAsB,CAAC;AAE7E;;;;;;;GAOG;AACH,MAAM,UAAU,eAAe,CAAC,CAAS,EAAE,OAA0B;IACnE,OAAO,IAAI,CAAC,GAAG,EAAE;QACf,IAAI,CAAC,CAAC,IAAI,KAAK,CAAC,EAAE;YAChB,MAAM,IAAI,UAAU,CAChB,iEAAiE;gBACjE,GAAG,CAAC,CAAC,IAAI,YAAY,CAAC,CAAC;SAC5B;QAED,IAAI,OAAO,IAAI,IAAI,EAAE;YACnB,OAAO,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;SAClB;QACD,IAAI,OAAO,CAAC,MAAM,KAAK,CAAC,EAAE;YACxB,MAAM,IAAI,UAAU,CAChB,iEAAiE;gBACjE,gCAAgC,OAAO,CAAC,MAAM,SAAS,CAAC,CAAC;SAC9D;QAED,MAAM,OAAO,GAA4B,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACnE,OAAO,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;IAC7B,CAAC,CAAC,CAAC;AACL,CAAC;AAED;;;;;;;;;GASG;AACH,MAAM,UAAU,gBAAgB,CAC5B,CAAS,EAAE,OAA8C,EACzD,UAAuB;IACzB,OAAO,IAAI,CAAC,GAAG,EAAE;QACf,IAAI,CAAC,CAAC,IAAI,KAAK,CAAC,EAAE;YAChB,MAAM,IAAI,UAAU,CAChB,iEAAiE;gBACjE,GAAG,CAAC,CAAC,IAAI,YAAY,CAAC,CAAC;SAC5B;QAED,IAAI,OAAO,IAAI,IAAI,EAAE;YACnB,OAAO,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;SAC5B;QACD,IAAI,OAAO,CAAC,MAAM,KAAK,CAAC,IAAI,OAAO,CAAC,CAAC,CAAC,CAAC,MAAM,KAAK,CAAC;YAC/C,OAAO,CAAC,CAAC,CAAC,CAAC,MAAM,KAAK,CAAC,EAAE;YAC3B,MAAM,IAAI,UAAU,CAChB,mEAAmE;gBACnE,4CAA4C,CAAC,CAAC;SACnD;QAED,IAAI,UAAU,IAAI,IAAI,EAAE;YACtB,UAAU,GAAG,eAAe,EAAE,CAAC;SAChC;QACD,IAAI,UAAU,KAAK,cAAc,IAAI,UAAU,KAAK,eAAe,EAAE;YACnE,MAAM,IAAI,UAAU,CAChB,wBAAwB,UAAU,IAAI;gBACtC,+DAA+D,CAAC,CAAC;SACtE;QAED,IAAI,OAAgC,CAAC;QACrC,IAAI,UAAU,KAAK,eAAe,EAAE;YAClC,OAAO,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,
EAAE,OAAO,CAAC,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC;SACpD;aAAM;YACL,OAAO,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;SACpD;QAED,OAAO,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;IAC7B,CAAC,CAAC,CAAC;AACL,CAAC;AA2BD,MAAM,OAAO,aAAc,SAAQ,KAAK;IAMtC,YAAY,IAA6B;QACvC,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,CAAC,UAAU;YACX,IAAI,CAAC,UAAU,IAAI,IAAI,CAAC,CAAC,CAAC,eAAe,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,UAAU,CAAC;QAClE,uEAAuE;QACvE,0BAA0B;QAC1B,IAAI,IAAI,CAAC,OAAO,IAAI,IAAI,EAAE;YACxB,IAAI,CAAC,OAAO,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;SACjC;aAAM,IAAI,OAAO,IAAI,CAAC,OAAO,KAAK,QAAQ,EAAE;YAC3C,IAAI,CAAC,OAAO;gBACR,CAAC,CAAC,IAAI,CAAC,OAAO,EAAE,IAAI,CAAC,OAAO,CAAC,EAAE,CAAC,IAAI,CAAC,OAAO,EAAE,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC;SAClE;aAAM;YACL,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,OAAO,CAAC;YAC5B,IAAI,IAAI,CAAC,OAAO,CAAC,MAAM,KAAK,CAAC,EAAE;gBAC7B,MAAM,IAAI,UAAU,CAChB,4DAA4D;oBAC5D,qBAAqB,IAAI,CAAC,OAAO,CAAC,MAAM,SAAS,CAAC,CAAC;aACxD;YAED,IAAI,aAA+B,CAAC;YACpC,IAAI,YAA8B,CAAC;YACnC,IAAI,OAAO,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,QAAQ,EAAE;gBACvC,aAAa,GAAG,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC;gBACnD,YAAY,GAAG,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC,CAAW,EAAE,IAAI,CAAC,OAAO,CAAC,CAAC,CAAW,CAAC,CAAC;aACvE;iBAAM;gBACL,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,OAA+C,CAAC;gBAEpE,IAAI,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,MAAM,KAAK,CAAC,EAAE;oBAChC,MAAM,IAAI,UAAU,CAChB,+DAA+D;wBAC/D,yBAAyB,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,MAAM,SAAS,CAAC,CAAC;iBAC/D;gBACD,aAAa,GAAG,IAAI,CAAC,OAAO,CAAC,CAAC,CAAqB,CAAC;gBAEpD,IAAI,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,MAAM,KAAK,CAAC,EAAE;oBAChC,MAAM,IAAI,UAAU,CAChB,8DAA8D;wBAC9D,yBAAyB,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,MAAM,SAAS,CAAC,CAAC;iBAC/D;gBACD,YAAY,GAAG,IAAI,CAAC,OAAO,CAAC,CAAC,CAAqB,CAAC;aACpD;YACD,IAAI,CAAC,OAAO,GAAG,CAAC,aAAa,EAAE,YAAY,CAAC,CAAC;SAC9C;QACD,IAAI,CAAC,
SAAS,GAAG,CAAC,IAAI,SAAS,CAAC,EAAC,IAAI,EAAE,CAAC,EAAC,CAAC,CAAC,CAAC;IAC9C,CAAC;IAEQ,kBAAkB,CAAC,UAAyB;QACnD,UAAU,GAAG,kBAAkB,CAAC,UAAU,CAAC,CAAC;QAE5C,IAAI,IAAY,CAAC;QACjB,IAAI,IAAY,CAAC;QACjB,IAAI,IAAI,CAAC,UAAU,KAAK,eAAe,EAAE;YACvC,IAAI,UAAU,CAAC,CAAC,CAAC,IAAI,IAAI,IAAI,UAAU,CAAC,CAAC,CAAC,IAAI,CAAC,EAAE;gBAC/C,IAAI,GAAG,UAAU,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;aAChE;iBAAM;gBACL,IAAI,GAAG,IAAI,CAAC;aACb;YACD,IAAI,UAAU,CAAC,CAAC,CAAC,IAAI,IAAI,IAAI,UAAU,CAAC,CAAC,CAAC,IAAI,CAAC,EAAE;gBAC/C,IAAI,GAAG,UAAU,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;aAChE;iBAAM;gBACL,IAAI,GAAG,IAAI,CAAC;aACb;YACD,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,CAAC,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,CAAC;SACnD;aAAM;YACL,IAAI,UAAU,CAAC,CAAC,CAAC,IAAI,IAAI,IAAI,UAAU,CAAC,CAAC,CAAC,IAAI,CAAC,EAAE;gBAC/C,IAAI,GAAG,UAAU,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;aAChE;iBAAM;gBACL,IAAI,GAAG,IAAI,CAAC;aACb;YACD,IAAI,UAAU,CAAC,CAAC,CAAC,IAAI,IAAI,IAAI,UAAU,CAAC,CAAC,CAAC,IAAI,CAAC,EAAE;gBAC/C,IAAI,GAAG,UAAU,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;aAChE;iBAAM;gBACL,IAAI,GAAG,IAAI,CAAC;aACb;YACD,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,EAAE,IAAI,EAAE,IAAI,EAAE,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC;SACnD;IACH,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,OAAO,IAAI,CACP,GAAG,EAAE,CAAC,gBAAgB,CAClB,mBAAmB,CAAC,MAAM,CAAC,EAAE,IAAI,CAAC,OAAO,EAAE,IAAI,CAAC,UAAU,CAAC,CAAC,CAAC;IACvE,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAA6B;YACvC,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,UAAU,EAAE,IAAI,CAAC,UAAU;SAC5B,CAAC;QACF,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;;AArGD,kBAAkB;AACX,uBAAS,GAAG,eAAe,CAAC;AAsGrC,aAAa,CAAC,aAAa,CAAC,aAAa,CAAC,CAAC","sourcesContent":[
"/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n/**\n * Padding Layers.\n */\n\n// Porting Note: In Python Keras, the padding layers are in convolutional.py,\n//   but we decided to put them in a separate file (padding.ts) for clarity.\n\nimport * as tfc from '@tensorflow/tfjs-core';\nimport {serialization, Tensor, tidy} from '@tensorflow/tfjs-core';\n\nimport {imageDataFormat} from '../backend/common';\nimport {InputSpec, Layer, LayerArgs} from '../engine/topology';\nimport {ValueError} from '../errors';\nimport {DataFormat, Shape} from '../keras_format/common';\nimport {Kwargs} from '../types';\nimport {getExactlyOneShape, getExactlyOneTensor} from '../utils/types_utils';\n\n/**\n * Pads the middle dimension of a 3D tensor.\n *\n * @param x Input `tf.Tensor` to be padded.\n * @param padding `Array` of 2 integers, how many zeros to add at the start and\n *   end of the middle dimension (i.e., dimension 1).\n * @return A padded 3D `tf.Tensor`.\n */\nexport function temporalPadding(x: Tensor, padding?: [number, number]): Tensor {\n  return tidy(() => {\n    if (x.rank !== 3) {\n      throw new ValueError(\n          `temporalPadding expects input tensor to be 3-D, but received a ` +\n          `${x.rank}-D tensor.`);\n    }\n\n    if (padding == null) {\n      padding = [1, 1];\n    }\n    if (padding.length !== 2) {\n      throw new ValueError(\n          `temporalPadding expects input padding pattern to be a length-2 ` +\n          `array, but received a length-${padding.length} array.`);\n    }\n\n    const pattern: Array<[number, number]> = [[0, 0], padding, [0, 0]];\n    return tfc.pad(x, pattern);\n  });\n}\n\n/**\n * Pads the 2nd and 3rd dimensions of a 4D tensor.\n *\n * @param x Input `tf.Tensor` to be padded.\n * 
@param padding `Array` of two `Array`s, each of which is an `Array` of two\n *   integers. The amount of padding at the beginning and end of the 2nd and 3rd\n *   dimensions, respectively.\n * @param dataFormat 'channelsLast' (default) or 'channelsFirst'.\n * @return Padded 4D `tf.Tensor`.\n */\nexport function spatial2dPadding(\n    x: Tensor, padding?: [[number, number], [number, number]],\n    dataFormat?: DataFormat): Tensor {\n  return tidy(() => {\n    if (x.rank !== 4) {\n      throw new ValueError(\n          `temporalPadding expects input tensor to be 4-D, but received a ` +\n          `${x.rank}-D tensor.`);\n    }\n\n    if (padding == null) {\n      padding = [[1, 1], [1, 1]];\n    }\n    if (padding.length !== 2 || padding[0].length !== 2 ||\n        padding[1].length !== 2) {\n      throw new ValueError(\n          'spatial2dPadding expects `padding` to be an Array of two Arrays, ' +\n          'each of which is an Array of two integers.');\n    }\n\n    if (dataFormat == null) {\n      dataFormat = imageDataFormat();\n    }\n    if (dataFormat !== 'channelsLast' && dataFormat !== 'channelsFirst') {\n      throw new ValueError(\n          `Unknown data format: ${dataFormat}. 
` +\n          `Supported data formats are 'channelsLast' and 'channelsFirst.`);\n    }\n\n    let pattern: Array<[number, number]>;\n    if (dataFormat === 'channelsFirst') {\n      pattern = [[0, 0], [0, 0], padding[0], padding[1]];\n    } else {\n      pattern = [[0, 0], padding[0], padding[1], [0, 0]];\n    }\n\n    return tfc.pad(x, pattern);\n  });\n}\n\nexport declare interface ZeroPadding2DLayerArgs extends LayerArgs {\n  /**\n   * Integer, or `Array` of 2 integers, or `Array` of 2 `Array`s, each of\n   * which is an `Array` of 2 integers.\n   * - If integer, the same symmetric padding is applied to width and height.\n   * - If `Array` of 2 integers, interpreted as two different symmetric values\n   *   for height and width:\n   *   `[symmetricHeightPad, symmetricWidthPad]`.\n   * - If `Array` of 2 `Array`s, interpreted as:\n   *   `[[topPad, bottomPad], [leftPad, rightPad]]`.\n   */\n  padding?: number|[number, number]|[[number, number], [number, number]];\n\n  /**\n   * One of `'channelsLast'` (default) and `'channelsFirst'`.\n   *\n   * The ordering of the dimensions in the inputs.\n   * `channelsLast` corresponds to inputs with shape\n   * `[batch, height, width, channels]` while `channelsFirst`\n   * corresponds to inputs with shape\n   * `[batch, channels, height, width]`.\n   */\n  dataFormat?: DataFormat;\n}\n\nexport class ZeroPadding2D extends Layer {\n  /** @nocollapse */\n  static className = 'ZeroPadding2D';\n  readonly dataFormat: DataFormat;\n  readonly padding: [[number, number], [number, number]];\n\n  constructor(args?: ZeroPadding2DLayerArgs) {\n    if (args == null) {\n      args = {};\n    }\n    super(args);\n\n    this.dataFormat =\n        args.dataFormat == null ? 
imageDataFormat() : args.dataFormat;\n    // TODO(cais): Maybe refactor the following logic surrounding `padding`\n    //   into a helper method.\n    if (args.padding == null) {\n      this.padding = [[1, 1], [1, 1]];\n    } else if (typeof args.padding === 'number') {\n      this.padding =\n          [[args.padding, args.padding], [args.padding, args.padding]];\n    } else {\n      args.padding = args.padding;\n      if (args.padding.length !== 2) {\n        throw new ValueError(\n            `ZeroPadding2D expects padding to be a length-2 array, but ` +\n            `received a length-${args.padding.length} array.`);\n      }\n\n      let heightPadding: [number, number];\n      let widthPadding: [number, number];\n      if (typeof args.padding[0] === 'number') {\n        heightPadding = [args.padding[0], args.padding[0]];\n        widthPadding = [args.padding[1] as number, args.padding[1] as number];\n      } else {\n        args.padding = args.padding as [[number, number], [number, number]];\n\n        if (args.padding[0].length !== 2) {\n          throw new ValueError(\n              `ZeroPadding2D expects height padding to be a length-2 array, ` +\n              `but received a length-${args.padding[0].length} array.`);\n        }\n        heightPadding = args.padding[0] as [number, number];\n\n        if (args.padding[1].length !== 2) {\n          throw new ValueError(\n              `ZeroPadding2D expects width padding to be a length-2 array, ` +\n              `but received a length-${args.padding[1].length} array.`);\n        }\n        widthPadding = args.padding[1] as [number, number];\n      }\n      this.padding = [heightPadding, widthPadding];\n    }\n    this.inputSpec = [new InputSpec({ndim: 4})];\n  }\n\n  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {\n    inputShape = getExactlyOneShape(inputShape);\n\n    let rows: number;\n    let cols: number;\n    if (this.dataFormat === 'channelsFirst') {\n      if (inputShape[2] != 
null && inputShape[2] >= 0) {\n        rows = inputShape[2] + this.padding[0][0] + this.padding[0][1];\n      } else {\n        rows = null;\n      }\n      if (inputShape[3] != null && inputShape[3] >= 0) {\n        cols = inputShape[3] + this.padding[1][0] + this.padding[1][1];\n      } else {\n        cols = null;\n      }\n      return [inputShape[0], inputShape[1], rows, cols];\n    } else {\n      if (inputShape[1] != null && inputShape[1] >= 0) {\n        rows = inputShape[1] + this.padding[0][0] + this.padding[0][1];\n      } else {\n        rows = null;\n      }\n      if (inputShape[2] != null && inputShape[2] >= 0) {\n        cols = inputShape[2] + this.padding[1][0] + this.padding[1][1];\n      } else {\n        cols = null;\n      }\n      return [inputShape[0], rows, cols, inputShape[3]];\n    }\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {\n    return tidy(\n        () => spatial2dPadding(\n            getExactlyOneTensor(inputs), this.padding, this.dataFormat));\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config: serialization.ConfigDict = {\n      padding: this.padding,\n      dataFormat: this.dataFormat,\n    };\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n}\nserialization.registerClass(ZeroPadding2D);\n"]} |
\ | No newline at end of file |