UNPKG

24.1 kB | JavaScript | View Raw
/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
/**
 * TensorFlow.js Layers: Depthwise Convolutional Layers
 */
13import * as tfc from '@tensorflow/tfjs-core';
14import { serialization, tidy } from '@tensorflow/tfjs-core';
15import { imageDataFormat } from '../backend/common';
16import * as K from '../backend/tfjs_backend';
17import { checkDataFormat } from '../common';
18import { getConstraint, serializeConstraint } from '../constraints';
19import { ValueError } from '../errors';
20import { getInitializer, serializeInitializer } from '../initializers';
21import { getRegularizer, serializeRegularizer } from '../regularizers';
22import { convOutputLength } from '../utils/conv_utils';
23import { getExactlyOneShape, getExactlyOneTensor } from '../utils/types_utils';
24import { BaseConv, preprocessConv2DInput } from './convolutional';
/**
 * 2D depthwise convolution.
 *
 * Validates input ranks, brings `x` into the NHWC layout via
 * `preprocessConv2DInput`, runs `tfc.depthwiseConv2d` with `'NHWC'`, and, for
 * `'channelsFirst'`, transposes the result back with permutation
 * `[0, 3, 1, 2]`. All work runs inside `tidy`.
 *
 * @param x Input tensor. Must be 4-D.
 * @param depthwiseKernel Convolution kernel for depthwise convolution. Must be
 *   4-D.
 * @param strides Strides (Array of two integers). Defaults to `[1, 1]`.
 * @param padding Padding mode. `'same'` is passed through; any other value
 *   (including e.g. `'causal'`) is treated as `'valid'`.
 * @param dataFormat Data format. When `null`/`undefined`, falls back to
 *   `imageDataFormat()`.
 * @param dilationRate Array of two integers, dilation rates for the depthwise
 *   convolution. Passed through to `tfc.depthwiseConv2d` unchanged.
 * @returns Output tensor, in the layout selected by `dataFormat`.
 * @throws ValueError If `x` or `depthwiseKernel` is not 4-D.
 */
export function depthwiseConv2d(x, depthwiseKernel, strides = [1, 1], padding = 'valid', dataFormat, dilationRate) {
  return tidy(() => {
    const format = dataFormat == null ? imageDataFormat() : dataFormat;
    checkDataFormat(format);
    // Validate ranks before any tensor work. NOTE(review): preprocessConv2DInput
    // presumably transposes 'channelsFirst' input into NHWC (the conv below is
    // called with 'NHWC'), which would fail on non-4-D input with a less
    // informative error than the ValueError documented above.
    if (x.rank !== 4) {
      throw new ValueError(`Input for depthwiseConv2d is required to be 4-D, but is instead ` +
        `${x.rank}-D`);
    }
    if (depthwiseKernel.rank !== 4) {
      throw new ValueError(`depthwiseKernel is required to be 4-D, but is instead ` +
        `${depthwiseKernel.rank}-D`);
    }
    let y = preprocessConv2DInput(x, format);
    y = tfc.depthwiseConv2d(y, depthwiseKernel, strides, padding === 'same' ? 'same' : 'valid', 'NHWC', dilationRate);
    if (format === 'channelsFirst') {
      // Convert the NHWC result back to the caller's NCHW layout.
      y = tfc.transpose(y, [0, 3, 1, 2]);
    }
    return y;
  });
}
/**
 * Depthwise 2D convolution layer.
 *
 * Owns a single kernel of shape
 * `[kernelSize[0], kernelSize[1], inputChannels, depthMultiplier]` and an
 * optional bias of length `inputChannels * depthMultiplier`, which is also the
 * number of output channels. Fields such as `kernelSize`, `strides`,
 * `padding`, `dataFormat`, `useBias` and `activation` are read here but not
 * set here; presumably they are initialized by the `BaseConv` constructor.
 */
export class DepthwiseConv2D extends BaseConv {
  /**
   * @param args Layer arguments. Passed to `BaseConv` with rank 2. Also reads
   *   `depthMultiplier` (defaults to 1 when null/undefined),
   *   `depthwiseInitializer` (falls back to `this.DEFAULT_KERNEL_INITIALIZER`
   *   when falsy), `depthwiseConstraint` and `depthwiseRegularizer`.
   */
  constructor(args) {
    super(2, args);
    this.depthwiseKernel = null;
    this.depthMultiplier =
      args.depthMultiplier == null ? 1 : args.depthMultiplier;
    this.depthwiseInitializer = getInitializer(args.depthwiseInitializer || this.DEFAULT_KERNEL_INITIALIZER);
    this.depthwiseConstraint = getConstraint(args.depthwiseConstraint);
    this.depthwiseRegularizer = getRegularizer(args.depthwiseRegularizer);
  }

  /**
   * Creates the depthwise kernel and (when `useBias`) the bias weights.
   *
   * @param inputShape Input shape (or a single-element list of shapes).
   * @throws ValueError If the input has fewer than 4 dimensions, or its
   *   channel dimension is null/undefined or negative.
   */
  build(inputShape) {
    inputShape = getExactlyOneShape(inputShape);
    if (inputShape.length < 4) {
      throw new ValueError(`Inputs to DepthwiseConv2D should have rank 4. ` +
        `Received input shape: ${JSON.stringify(inputShape)}.`);
    }
    const channelAxis = this.dataFormat === 'channelsFirst' ? 1 : 3;
    if (inputShape[channelAxis] == null || inputShape[channelAxis] < 0) {
      throw new ValueError('The channel dimension of the inputs to DepthwiseConv2D should ' +
        `be defined, but is not (${inputShape[channelAxis]}).`);
    }
    const inputDim = inputShape[channelAxis];
    const depthwiseKernelShape = [
      this.kernelSize[0], this.kernelSize[1], inputDim, this.depthMultiplier
    ];
    this.depthwiseKernel = this.addWeight('depthwise_kernel', depthwiseKernelShape, null, this.depthwiseInitializer, this.depthwiseRegularizer, true, this.depthwiseConstraint);
    if (this.useBias) {
      this.bias = this.addWeight('bias', [inputDim * this.depthMultiplier], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
    }
    else {
      this.bias = null;
    }
    this.built = true;
  }

  /**
   * Runs the depthwise convolution, then the optional bias add and
   * activation.
   *
   * @param inputs Input tensor (or a single-element list of tensors).
   * @param kwargs Unused.
   * @returns Output tensor.
   */
  call(inputs, kwargs) {
    return tidy(() => {
      inputs = getExactlyOneTensor(inputs);
      // Dilation is passed as `null`, so no configured dilation rate is
      // applied here.
      // TODO(cais): Add support for dilation.
      let outputs = depthwiseConv2d(inputs, this.depthwiseKernel.read(), this.strides, this.padding, this.dataFormat, null);
      if (this.useBias) {
        outputs = K.biasAdd(outputs, this.bias.read(), this.dataFormat);
      }
      if (this.activation != null) {
        outputs = this.activation.apply(outputs);
      }
      return outputs;
    });
  }

  /**
   * Computes the output shape for a 4-D input shape. The channel count is
   * multiplied by `depthMultiplier`; spatial sizes come from
   * `convOutputLength` using the kernel size, padding and strides.
   *
   * @param inputShape Input shape (or a single-element list of shapes).
   * @returns `[batch, channels, rows, cols]` for 'channelsFirst', otherwise
   *   `[batch, rows, cols, channels]`.
   */
  computeOutputShape(inputShape) {
    inputShape = getExactlyOneShape(inputShape);
    const rows = this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1];
    const cols = this.dataFormat === 'channelsFirst' ? inputShape[3] : inputShape[2];
    const outFilters = this.dataFormat === 'channelsFirst' ?
      inputShape[1] * this.depthMultiplier :
      inputShape[3] * this.depthMultiplier;
    const outRows = convOutputLength(rows, this.kernelSize[0], this.padding, this.strides[0]);
    const outCols = convOutputLength(cols, this.kernelSize[1], this.padding, this.strides[1]);
    if (this.dataFormat === 'channelsFirst') {
      return [inputShape[0], outFilters, outRows, outCols];
    }
    else {
      // In this case, assume 'channelsLast'.
      return [inputShape[0], outRows, outCols, outFilters];
    }
  }

  /**
   * Returns the `BaseConv` config extended with the depthwise-specific
   * fields.
   */
  getConfig() {
    const config = super.getConfig();
    config['depthMultiplier'] = this.depthMultiplier;
    config['depthwiseInitializer'] =
      serializeInitializer(this.depthwiseInitializer);
    config['depthwiseRegularizer'] =
      serializeRegularizer(this.depthwiseRegularizer);
    // Serialize the constraint itself; this previously serialized the
    // regularizer under this key, so constraints were lost on round-trip.
    config['depthwiseConstraint'] =
      serializeConstraint(this.depthwiseConstraint);
    return config;
  }
}
/** @nocollapse */
DepthwiseConv2D.className = 'DepthwiseConv2D';
serialization.registerClass(DepthwiseConv2D);
139//# sourceMappingURL=data:application/json;base64,
\No newline at end of file