UNPKG

24.1 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2018 Google LLC
4 *
5 * Use of this source code is governed by an MIT-style
6 * license that can be found in the LICENSE file or at
7 * https://opensource.org/licenses/MIT.
8 * =============================================================================
9 */
10/**
11 * TensorFlow.js Layers: Depthwise Convolutional Layers
12 */
13import * as tfc from '@tensorflow/tfjs-core';
14import { serialization, tidy } from '@tensorflow/tfjs-core';
15import { imageDataFormat } from '../backend/common';
16import * as K from '../backend/tfjs_backend';
17import { checkDataFormat } from '../common';
18import { getConstraint, serializeConstraint } from '../constraints';
19import { ValueError } from '../errors';
20import { getInitializer, serializeInitializer } from '../initializers';
21import { getRegularizer, serializeRegularizer } from '../regularizers';
22import { convOutputLength } from '../utils/conv_utils';
23import { getExactlyOneShape, getExactlyOneTensor } from '../utils/types_utils';
24import { BaseConv, preprocessConv2DInput } from './convolutional';
/**
 * Depthwise 2D convolution: each input channel is convolved with its own
 * filters (the kernel's third dimension indexes input channels and its fourth
 * the per-channel multiplier); channels are not mixed.
 *
 * The input is converted to NHWC layout with `preprocessConv2DInput`,
 * convolved with `tfc.depthwiseConv2d`, and transposed back to NCHW when
 * `dataFormat` is 'channelsFirst', so the output uses the same layout as `x`.
 *
 * @param x Input tensor. Must be 4-D.
 * @param depthwiseKernel Convolution kernel for depthwise convolution. Must be
 *   4-D.
 * @param strides Strides (Array of two integers). Default: [1, 1].
 * @param padding Padding mode. 'same' is forwarded as 'same'; every other value
 *   is treated as 'valid'. Default: 'valid'.
 * @param dataFormat Data format ('channelsFirst' or 'channelsLast'). When
 *   null/undefined, the global default from `imageDataFormat()` is used.
 * @param dilationRate Array of two integers, dilation rates for the depthwise
 *   convolution; forwarded unchanged to `tfc.depthwiseConv2d`.
 * @returns Output tensor, in the same data format as `x`.
 * @throws ValueError If `x` or `depthwiseKernel` is not 4-D.
 */
export function depthwiseConv2d(x, depthwiseKernel, strides = [1, 1], padding = 'valid', dataFormat, dilationRate) {
  return tidy(() => {
    if (dataFormat == null) {
      dataFormat = imageDataFormat();
    }
    checkDataFormat(dataFormat);
    // Validate ranks before any layout conversion so malformed inputs are
    // rejected with the documented ValueError rather than whatever error (if
    // any) the preprocessing step would raise, and without wasted work.
    if (x.rank !== 4) {
      throw new ValueError(`Input for depthwiseConv2d is required to be 4-D, but is instead ` +
          `${x.rank}-D`);
    }
    if (depthwiseKernel.rank !== 4) {
      throw new ValueError(`depthwiseKernel is required to be 4-D, but is instead ` +
          `${depthwiseKernel.rank}-D`);
    }
    let y = preprocessConv2DInput(x, dataFormat);
    // The core op always runs in NHWC; only 'same' is forwarded, anything else
    // falls back to 'valid'.
    y = tfc.depthwiseConv2d(y, depthwiseKernel, strides, padding === 'same' ? 'same' : 'valid', 'NHWC', dilationRate);
    if (dataFormat === 'channelsFirst') {
      // NHWC -> NCHW to match the caller's layout.
      y = tfc.transpose(y, [0, 3, 1, 2]);
    }
    return y;
  });
}
/**
 * Depthwise 2D convolution layer.
 *
 * Owns a depthwise kernel of shape
 * `[kernelSize[0], kernelSize[1], inputChannels, depthMultiplier]` and, when
 * `useBias` is set, a bias of length `inputChannels * depthMultiplier`.
 */
export class DepthwiseConv2D extends BaseConv {
  /**
   * @param args Layer configuration. Besides the options handled by
   *   `BaseConv`, reads `depthMultiplier` (default 1), `depthwiseInitializer`
   *   (falls back to `this.DEFAULT_KERNEL_INITIALIZER`),
   *   `depthwiseConstraint` and `depthwiseRegularizer`.
   */
  constructor(args) {
    super(2, args);
    // Created in build(), once the input channel count is known.
    this.depthwiseKernel = null;
    this.depthMultiplier =
        args.depthMultiplier == null ? 1 : args.depthMultiplier;
    this.depthwiseInitializer = getInitializer(args.depthwiseInitializer || this.DEFAULT_KERNEL_INITIALIZER);
    this.depthwiseConstraint = getConstraint(args.depthwiseConstraint);
    this.depthwiseRegularizer = getRegularizer(args.depthwiseRegularizer);
  }

  /**
   * Creates the depthwise kernel and (optionally) the bias weights.
   *
   * @param inputShape Shape of the single input; must have at least 4
   *   dimensions and a defined, non-negative channel dimension (axis 1 for
   *   'channelsFirst', axis 3 otherwise).
   * @throws ValueError If the rank is below 4 or the channel dimension is
   *   undefined or negative.
   */
  build(inputShape) {
    inputShape = getExactlyOneShape(inputShape);
    if (inputShape.length < 4) {
      throw new ValueError(`Inputs to DepthwiseConv2D should have rank 4. ` +
          `Received input shape: ${JSON.stringify(inputShape)}.`);
    }
    const channelAxis = this.dataFormat === 'channelsFirst' ? 1 : 3;
    if (inputShape[channelAxis] == null || inputShape[channelAxis] < 0) {
      throw new ValueError('The channel dimension of the inputs to DepthwiseConv2D should ' +
          `be defined, but is not (${inputShape[channelAxis]}).`);
    }
    const inputDim = inputShape[channelAxis];
    const depthwiseKernelShape = [
      this.kernelSize[0], this.kernelSize[1], inputDim, this.depthMultiplier
    ];
    this.depthwiseKernel = this.addWeight('depthwise_kernel', depthwiseKernelShape, null, this.depthwiseInitializer, this.depthwiseRegularizer, true, this.depthwiseConstraint);
    if (this.useBias) {
      this.bias = this.addWeight('bias', [inputDim * this.depthMultiplier], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
    }
    else {
      this.bias = null;
    }
    this.built = true;
  }

  /**
   * Applies the depthwise convolution, then the bias (if enabled), then the
   * activation (if set).
   *
   * @param inputs Exactly one input tensor.
   * @param kwargs Unused.
   * @returns The layer output.
   */
  call(inputs, kwargs) {
    return tidy(() => {
      inputs = getExactlyOneTensor(inputs);
      // Dilation is not forwarded (null): any configured dilationRate is
      // currently ignored here and in computeOutputShape.
      let outputs = depthwiseConv2d(inputs, this.depthwiseKernel.read(), this.strides, this.padding, this.dataFormat, null);
      // TODO(cais): Add support for dilation.
      if (this.useBias) {
        outputs = K.biasAdd(outputs, this.bias.read(), this.dataFormat);
      }
      if (this.activation != null) {
        outputs = this.activation.apply(outputs);
      }
      return outputs;
    });
  }

  /**
   * Computes the output shape for a single 4-D input shape. The channel count
   * becomes `inputChannels * depthMultiplier`; spatial sizes follow
   * `convOutputLength` with the layer's kernel size, padding and strides
   * (dilation is not taken into account, matching `call`).
   */
  computeOutputShape(inputShape) {
    inputShape = getExactlyOneShape(inputShape);
    const rows = this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1];
    const cols = this.dataFormat === 'channelsFirst' ? inputShape[3] : inputShape[2];
    const outFilters = this.dataFormat === 'channelsFirst' ?
        inputShape[1] * this.depthMultiplier :
        inputShape[3] * this.depthMultiplier;
    const outRows = convOutputLength(rows, this.kernelSize[0], this.padding, this.strides[0]);
    const outCols = convOutputLength(cols, this.kernelSize[1], this.padding, this.strides[1]);
    if (this.dataFormat === 'channelsFirst') {
      return [inputShape[0], outFilters, outRows, outCols];
    }
    else {
      // In this case, assume 'channelsLast'.
      return [inputShape[0], outRows, outCols, outFilters];
    }
  }

  /**
   * Extends the base convolution config with the depthwise-specific options.
   */
  getConfig() {
    const config = super.getConfig();
    config['depthMultiplier'] = this.depthMultiplier;
    config['depthwiseInitializer'] =
        serializeInitializer(this.depthwiseInitializer);
    config['depthwiseRegularizer'] =
        serializeRegularizer(this.depthwiseRegularizer);
    // Serialize the depthwise constraint itself (not the regularizer) so a
    // save/load round trip restores the configured constraint.
    config['depthwiseConstraint'] =
        serializeConstraint(this.depthwiseConstraint);
    return config;
  }
}
/** @nocollapse */
DepthwiseConv2D.className = 'DepthwiseConv2D';
// Registration presumably relies on `className`; keep it after the assignment.
serialization.registerClass(DepthwiseConv2D);
139//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"convolutional_depthwise.js","sourceRoot":"","sources":["../../../../../../tfjs-layers/src/layers/convolutional_depthwise.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH;;GAEG;AAEH,OAAO,KAAK,GAAG,MAAM,uBAAuB,CAAC;AAC7C,OAAO,EAAC,aAAa,EAAoB,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAE5E,OAAO,EAAC,eAAe,EAAC,MAAM,mBAAmB,CAAC;AAClD,OAAO,KAAK,CAAC,MAAM,yBAAyB,CAAC;AAC7C,OAAO,EAAC,eAAe,EAAC,MAAM,WAAW,CAAC;AAC1C,OAAO,EAAmC,aAAa,EAAE,mBAAmB,EAAC,MAAM,gBAAgB,CAAC;AACpG,OAAO,EAAC,UAAU,EAAC,MAAM,WAAW,CAAC;AACrC,OAAO,EAAC,cAAc,EAAsC,oBAAoB,EAAC,MAAM,iBAAiB,CAAC;AAEzG,OAAO,EAAC,cAAc,EAAsC,oBAAoB,EAAC,MAAM,iBAAiB,CAAC;AAEzG,OAAO,EAAC,gBAAgB,EAAC,MAAM,qBAAqB,CAAC;AACrD,OAAO,EAAC,kBAAkB,EAAE,mBAAmB,EAAC,MAAM,sBAAsB,CAAC;AAG7E,OAAO,EAAC,QAAQ,EAAoC,qBAAqB,EAAC,MAAM,iBAAiB,CAAC;AAElG;;;;;;;;;;;GAWG;AACH,MAAM,UAAU,eAAe,CAC3B,CAAS,EAAE,eAAuB,EAAE,UAA4B,CAAC,CAAC,EAAE,CAAC,CAAC,EACtE,OAAO,GAAG,OAAO,EAAE,UAAuB,EAC1C,YAA+B;IACjC,OAAO,IAAI,CAAC,GAAG,EAAE;QACf,IAAI,UAAU,IAAI,IAAI,EAAE;YACtB,UAAU,GAAG,eAAe,EAAE,CAAC;SAChC;QACD,eAAe,CAAC,UAAU,CAAC,CAAC;QAC5B,IAAI,CAAC,GAAG,qBAAqB,CAAC,CAAC,EAAE,UAAU,CAAC,CAAC;QAC7C,IAAI,CAAC,CAAC,IAAI,KAAK,CAAC,EAAE;YAChB,MAAM,IAAI,UAAU,CAChB,kEAAkE;gBAClE,GAAG,CAAC,CAAC,IAAI,IAAI,CAAC,CAAC;SACpB;QACD,IAAI,eAAe,CAAC,IAAI,KAAK,CAAC,EAAE;YAC9B,MAAM,IAAI,UAAU,CAChB,wDAAwD;gBACxD,GAAG,eAAe,CAAC,IAAI,IAAI,CAAC,CAAC;SAClC;QACD,CAAC,GAAG,GAAG,CAAC,eAAe,CACnB,CAAa,EAAE,eAA2B,EAAE,OAAO,EACnD,OAAO,KAAK,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,OAAO,EAAE,MAAM,EAAE,YAAY,CAAC,CAAC;QACjE,IAAI,UAAU,KAAK,eAAe,EAAE;YAClC,CAAC,GAAG,GAAG,CAAC,SAAS,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;SACpC;QACD,OAAO,CAAC,CAAC;IACX,CAAC,CAAC,CAAC;AACL,CAAC;AAoCD,MAAM,OAAO,eAAgB,SAAQ,QAAQ;IAU3C,YAAY,IAA8B;QACxC,KAAK,CAAC,CAAC,EAAE,IAAqB,CAAC,CAAC;QAH1B,oBAAe,GAAkB,IAAI,CAAC;QAI5C,IAAI,CAAC,eAAe;YAChB,IAAI,CAAC,eAAe,IAAI,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,eAAe,CAAC;QAC5D,IAAI,CAAC,oBAAoB,GAAG,
cAAc,CACtC,IAAI,CAAC,oBAAoB,IAAI,IAAI,CAAC,0BAA0B,CAAC,CAAC;QAClE,IAAI,CAAC,mBAAmB,GAAG,aAAa,CAAC,IAAI,CAAC,mBAAmB,CAAC,CAAC;QACnE,IAAI,CAAC,oBAAoB,GAAG,cAAc,CAAC,IAAI,CAAC,oBAAoB,CAAC,CAAC;IACxE,CAAC;IAEQ,KAAK,CAAC,UAAyB;QACtC,UAAU,GAAG,kBAAkB,CAAC,UAAU,CAAC,CAAC;QAC5C,IAAI,UAAU,CAAC,MAAM,GAAG,CAAC,EAAE;YACzB,MAAM,IAAI,UAAU,CAChB,gDAAgD;gBAChD,yBAAyB,IAAI,CAAC,SAAS,CAAC,UAAU,CAAC,GAAG,CAAC,CAAC;SAC7D;QACD,MAAM,WAAW,GAAG,IAAI,CAAC,UAAU,KAAK,eAAe,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAChE,IAAI,UAAU,CAAC,WAAW,CAAC,IAAI,IAAI,IAAI,UAAU,CAAC,WAAW,CAAC,GAAG,CAAC,EAAE;YAClE,MAAM,IAAI,UAAU,CAChB,gEAAgE;gBAChE,2BAA2B,UAAU,CAAC,WAAW,CAAC,IAAI,CAAC,CAAC;SAC7D;QACD,MAAM,QAAQ,GAAG,UAAU,CAAC,WAAW,CAAC,CAAC;QACzC,MAAM,oBAAoB,GAAU;YAClC,IAAI,CAAC,UAAU,CAAC,CAAC,CAAC,EAAE,IAAI,CAAC,UAAU,CAAC,CAAC,CAAC,EAAE,QAAQ,EAAE,IAAI,CAAC,eAAe;SACvE,CAAC;QAEF,IAAI,CAAC,eAAe,GAAG,IAAI,CAAC,SAAS,CACjC,kBAAkB,EAAE,oBAAoB,EAAE,IAAI,EAC9C,IAAI,CAAC,oBAAoB,EAAE,IAAI,CAAC,oBAAoB,EAAE,IAAI,EAC1D,IAAI,CAAC,mBAAmB,CAAC,CAAC;QAC9B,IAAI,IAAI,CAAC,OAAO,EAAE;YAChB,IAAI,CAAC,IAAI,GAAG,IAAI,CAAC,SAAS,CACtB,MAAM,EAAE,CAAC,QAAQ,GAAG,IAAI,CAAC,eAAe,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,eAAe,EACrE,IAAI,CAAC,eAAe,EAAE,IAAI,EAAE,IAAI,CAAC,cAAc,CAAC,CAAC;SACtD;aAAM;YACL,IAAI,CAAC,IAAI,GAAG,IAAI,CAAC;SAClB;QACD,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC;IACpB,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,MAAM,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;YACrC,IAAI,OAAO,GAAG,eAAe,CACzB,MAAM,EAAE,IAAI,CAAC,eAAe,CAAC,IAAI,EAAE,EAAE,IAAI,CAAC,OAA2B,EACrE,IAAI,CAAC,OAAO,EAAE,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,CAAC;YACzC,wCAAwC;YACxC,IAAI,IAAI,CAAC,OAAO,EAAE;gBAChB,OAAO,GAAG,CAAC,CAAC,OAAO,CAAC,OAAO,EAAE,IAAI,CAAC,IAAI,CAAC,IAAI,EAAE,EAAE,IAAI,CAAC,UAAU,CAAC,CAAC;aACjE;YACD,IAAI,IAAI,CAAC,UAAU,IAAI,IAAI,EAAE;gBAC3B,OAAO,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;aAC1C;YACD,OAAO,OAAO,CAAC;QACjB,CAAC,CAAC,CAAC;IACL,CAAC;IAEQ,kBAAkB,CAAC,UAAyB;QACnD,UAAU,GAAG,kBAAkB,CAAC,UAAU,CAAC,CAAC;QAC5C,MAAM,IAAI,GACN,IAAI,CAAC,UAAU,KAAK
,eAAe,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;QACxE,MAAM,IAAI,GACN,IAAI,CAAC,UAAU,KAAK,eAAe,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;QACxE,MAAM,UAAU,GAAG,IAAI,CAAC,UAAU,KAAK,eAAe,CAAC,CAAC;YACpD,UAAU,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,eAAe,CAAC,CAAC;YACtC,UAAU,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,eAAe,CAAC;QACzC,MAAM,OAAO,GAAG,gBAAgB,CAC5B,IAAI,EAAE,IAAI,CAAC,UAAU,CAAC,CAAC,CAAC,EAAE,IAAI,CAAC,OAAO,EAAE,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC;QAC7D,MAAM,OAAO,GAAG,gBAAgB,CAC5B,IAAI,EAAE,IAAI,CAAC,UAAU,CAAC,CAAC,CAAC,EAAE,IAAI,CAAC,OAAO,EAAE,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC;QAC7D,IAAI,IAAI,CAAC,UAAU,KAAK,eAAe,EAAE;YACvC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,EAAE,UAAU,EAAE,OAAO,EAAE,OAAO,CAAC,CAAC;SACtD;aAAM;YACL,uCAAuC;YACvC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,EAAE,OAAO,EAAE,OAAO,EAAE,UAAU,CAAC,CAAC;SACtD;IACH,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACjC,MAAM,CAAC,iBAAiB,CAAC,GAAG,IAAI,CAAC,eAAe,CAAC;QACjD,MAAM,CAAC,sBAAsB,CAAC;YAC1B,oBAAoB,CAAC,IAAI,CAAC,oBAAoB,CAAC,CAAC;QACpD,MAAM,CAAC,sBAAsB,CAAC;YAC1B,oBAAoB,CAAC,IAAI,CAAC,oBAAoB,CAAC,CAAC;QACpD,MAAM,CAAC,qBAAqB,CAAC;YACzB,mBAAmB,CAAC,IAAI,CAAC,oBAAoB,CAAC,CAAC;QACnD,OAAO,MAAM,CAAC;IAChB,CAAC;;AAnGD,kBAAkB;AACX,yBAAS,GAAG,iBAAiB,CAAC;AAoGvC,aAAa,CAAC,aAAa,CAAC,eAAe,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n/**\n * TensorFlow.js Layers: Depthwise Convolutional Layers\n */\n\nimport * as tfc from '@tensorflow/tfjs-core';\nimport {serialization, Tensor, Tensor4D, tidy} from '@tensorflow/tfjs-core';\n\nimport {imageDataFormat} from '../backend/common';\nimport * as K from '../backend/tfjs_backend';\nimport {checkDataFormat} from '../common';\nimport {Constraint, 
ConstraintIdentifier, getConstraint, serializeConstraint} from '../constraints';\nimport {ValueError} from '../errors';\nimport {getInitializer, Initializer, InitializerIdentifier, serializeInitializer} from '../initializers';\nimport {DataFormat, Shape} from '../keras_format/common';\nimport {getRegularizer, Regularizer, RegularizerIdentifier, serializeRegularizer} from '../regularizers';\nimport {Kwargs} from '../types';\nimport {convOutputLength} from '../utils/conv_utils';\nimport {getExactlyOneShape, getExactlyOneTensor} from '../utils/types_utils';\nimport {LayerVariable} from '../variables';\n\nimport {BaseConv, BaseConvLayerArgs, ConvLayerArgs, preprocessConv2DInput} from './convolutional';\n\n/**\n * 2D convolution with separable filters.\n * @param x Input tensor.\n * @param depthwiseKernel Convolution kernel for depthwise convolution.\n * @param strides Strides (Array of two integers).\n * @param padding Padding model.\n * @param dataFormat Data format.\n * @param dilationRate Array of two integers, dilation rates for the separable\n *   convolution.\n * @returns Output tensor.\n * @throws ValueError If depthwiseKernel is not a 4D array.\n */\nexport function depthwiseConv2d(\n    x: Tensor, depthwiseKernel: Tensor, strides: [number, number] = [1, 1],\n    padding = 'valid', dataFormat?: DataFormat,\n    dilationRate?: [number, number]): Tensor {\n  return tidy(() => {\n    if (dataFormat == null) {\n      dataFormat = imageDataFormat();\n    }\n    checkDataFormat(dataFormat);\n    let y = preprocessConv2DInput(x, dataFormat);\n    if (x.rank !== 4) {\n      throw new ValueError(\n          `Input for depthwiseConv2d is required to be 4-D, but is instead ` +\n          `${x.rank}-D`);\n    }\n    if (depthwiseKernel.rank !== 4) {\n      throw new ValueError(\n          `depthwiseKernel is required to be 4-D, but is instead ` +\n          `${depthwiseKernel.rank}-D`);\n    }\n    y = tfc.depthwiseConv2d(\n        y as Tensor4D, depthwiseKernel as 
Tensor4D, strides,\n        padding === 'same' ? 'same' : 'valid', 'NHWC', dilationRate);\n    if (dataFormat === 'channelsFirst') {\n      y = tfc.transpose(y, [0, 3, 1, 2]);\n    }\n    return y;\n  });\n}\n\nexport declare interface DepthwiseConv2DLayerArgs extends BaseConvLayerArgs {\n  /**\n   * An integer or Array of 2 integers, specifying the width and height of the\n   * 2D convolution window. Can be a single integer to specify the same value\n   * for all spatial dimensions.\n   */\n  kernelSize: number|[number, number];\n\n  /**\n   * The number of depthwise convolution output channels for each input\n   * channel.\n   * The total number of depthwise convolution output channels will be equal to\n   * `filtersIn * depthMultiplier`.\n   * Default: 1.\n   */\n  depthMultiplier?: number;\n\n  /**\n   * Initializer for the depthwise kernel matrix.\n   * Default: GlorotNormal.\n   */\n  depthwiseInitializer?: InitializerIdentifier|Initializer;\n\n  /**\n   * Constraint for the depthwise kernel matrix.\n   */\n  depthwiseConstraint?: ConstraintIdentifier|Constraint;\n\n  /**\n   * Regularizer function for the depthwise kernel matrix.\n   */\n  depthwiseRegularizer?: RegularizerIdentifier|Regularizer;\n}\n\nexport class DepthwiseConv2D extends BaseConv {\n  /** @nocollapse */\n  static className = 'DepthwiseConv2D';\n  private readonly depthMultiplier: number;\n  private readonly depthwiseInitializer: Initializer;\n  private readonly depthwiseConstraint: Constraint;\n  private readonly depthwiseRegularizer: Regularizer;\n\n  private depthwiseKernel: LayerVariable = null;\n\n  constructor(args: DepthwiseConv2DLayerArgs) {\n    super(2, args as ConvLayerArgs);\n    this.depthMultiplier =\n        args.depthMultiplier == null ? 
1 : args.depthMultiplier;\n    this.depthwiseInitializer = getInitializer(\n        args.depthwiseInitializer || this.DEFAULT_KERNEL_INITIALIZER);\n    this.depthwiseConstraint = getConstraint(args.depthwiseConstraint);\n    this.depthwiseRegularizer = getRegularizer(args.depthwiseRegularizer);\n  }\n\n  override build(inputShape: Shape|Shape[]): void {\n    inputShape = getExactlyOneShape(inputShape);\n    if (inputShape.length < 4) {\n      throw new ValueError(\n          `Inputs to DepthwiseConv2D should have rank 4. ` +\n          `Received input shape: ${JSON.stringify(inputShape)}.`);\n    }\n    const channelAxis = this.dataFormat === 'channelsFirst' ? 1 : 3;\n    if (inputShape[channelAxis] == null || inputShape[channelAxis] < 0) {\n      throw new ValueError(\n          'The channel dimension of the inputs to DepthwiseConv2D should ' +\n          `be defined, but is not (${inputShape[channelAxis]}).`);\n    }\n    const inputDim = inputShape[channelAxis];\n    const depthwiseKernelShape: Shape = [\n      this.kernelSize[0], this.kernelSize[1], inputDim, this.depthMultiplier\n    ];\n\n    this.depthwiseKernel = this.addWeight(\n        'depthwise_kernel', depthwiseKernelShape, null,\n        this.depthwiseInitializer, this.depthwiseRegularizer, true,\n        this.depthwiseConstraint);\n    if (this.useBias) {\n      this.bias = this.addWeight(\n          'bias', [inputDim * this.depthMultiplier], null, this.biasInitializer,\n          this.biasRegularizer, true, this.biasConstraint);\n    } else {\n      this.bias = null;\n    }\n    this.built = true;\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {\n    return tidy(() => {\n      inputs = getExactlyOneTensor(inputs);\n      let outputs = depthwiseConv2d(\n          inputs, this.depthwiseKernel.read(), this.strides as [number, number],\n          this.padding, this.dataFormat, null);\n      // TODO(cais): Add support for dilation.\n      if (this.useBias) {\n        
outputs = K.biasAdd(outputs, this.bias.read(), this.dataFormat);\n      }\n      if (this.activation != null) {\n        outputs = this.activation.apply(outputs);\n      }\n      return outputs;\n    });\n  }\n\n  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {\n    inputShape = getExactlyOneShape(inputShape);\n    const rows =\n        this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1];\n    const cols =\n        this.dataFormat === 'channelsFirst' ? inputShape[3] : inputShape[2];\n    const outFilters = this.dataFormat === 'channelsFirst' ?\n        inputShape[1] * this.depthMultiplier :\n        inputShape[3] * this.depthMultiplier;\n    const outRows = convOutputLength(\n        rows, this.kernelSize[0], this.padding, this.strides[0]);\n    const outCols = convOutputLength(\n        cols, this.kernelSize[1], this.padding, this.strides[1]);\n    if (this.dataFormat === 'channelsFirst') {\n      return [inputShape[0], outFilters, outRows, outCols];\n    } else {\n      // In this case, assume 'channelsLast'.\n      return [inputShape[0], outRows, outCols, outFilters];\n    }\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config = super.getConfig();\n    config['depthMultiplier'] = this.depthMultiplier;\n    config['depthwiseInitializer'] =\n        serializeInitializer(this.depthwiseInitializer);\n    config['depthwiseRegularizer'] =\n        serializeRegularizer(this.depthwiseRegularizer);\n    config['depthwiseConstraint'] =\n        serializeConstraint(this.depthwiseRegularizer);\n    return config;\n  }\n}\nserialization.registerClass(DepthwiseConv2D);\n"]}
\No newline at end of file