1 | /**
|
2 | * @license
|
3 | * Copyright 2018 Google LLC. All Rights Reserved.
|
4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | * you may not use this file except in compliance with the License.
|
6 | * You may obtain a copy of the License at
|
7 | *
|
8 | * http://www.apache.org/licenses/LICENSE-2.0
|
9 | *
|
10 | * Unless required by applicable law or agreed to in writing, software
|
11 | * distributed under the License is distributed on an "AS IS" BASIS,
|
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 | * See the License for the specific language governing permissions and
|
14 | * limitations under the License.
|
15 | * =============================================================================
|
16 | */
|
17 | // tslint:disable-next-line: no-imports-from-dist
|
18 | import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';
|
19 | import { getPadding, getParamValue } from './utils';
|
20 | function fusedConvAndDepthWiseParams(node, tensorMap, context) {
|
21 | const [extraOp, activationFunc] = getParamValue('fusedOps', node, tensorMap, context);
|
22 | const isBiasAdd = extraOp === 'biasadd';
|
23 | const noBiasAdd = !isBiasAdd;
|
24 | const isPrelu = activationFunc === 'prelu';
|
25 | const isBatchNorm = extraOp === 'fusedbatchnorm';
|
26 | const numArgs = getParamValue('numArgs', node, tensorMap, context);
|
27 | if (isBiasAdd) {
|
28 | if (isPrelu && numArgs !== 2) {
|
29 | throw new Error('FusedConv2d and DepthwiseConv2d with BiasAdd and Prelu ' +
|
30 | 'must have two extra arguments: bias and alpha.');
|
31 | }
|
32 | if (!isPrelu && isBiasAdd && numArgs !== 1) {
|
33 | throw new Error('FusedConv2d and DepthwiseConv2d with BiasAdd must have ' +
|
34 | 'one extra argument: bias.');
|
35 | }
|
36 | }
|
37 | if (isBatchNorm) {
|
38 | throw new Error('FusedConv2d and DepthwiseConv2d with FusedBatchNorm is not supported');
|
39 | }
|
40 | const stride = getParamValue('strides', node, tensorMap, context);
|
41 | const pad = getPadding(node, tensorMap, context);
|
42 | const dataFormat = getParamValue('dataFormat', node, tensorMap, context)
|
43 | .toUpperCase();
|
44 | const dilations = getParamValue('dilations', node, tensorMap, context);
|
45 | let [biasArg, preluArg] = getParamValue('args', node, tensorMap, context);
|
46 | if (noBiasAdd) {
|
47 | preluArg = biasArg;
|
48 | biasArg = undefined;
|
49 | }
|
50 | const leakyreluAlpha = getParamValue('leakyreluAlpha', node, tensorMap, context);
|
51 | return {
|
52 | stride,
|
53 | pad,
|
54 | dataFormat,
|
55 | dilations,
|
56 | biasArg,
|
57 | preluArg,
|
58 | activationFunc,
|
59 | leakyreluAlpha
|
60 | };
|
61 | }
|
62 | export const executeOp = (node, tensorMap, context, ops = tfOps) => {
|
63 | switch (node.op) {
|
64 | case 'Conv1D': {
|
65 | const stride = getParamValue('stride', node, tensorMap, context);
|
66 | const pad = getParamValue('pad', node, tensorMap, context);
|
67 | const dataFormat = getParamValue('dataFormat', node, tensorMap, context)
|
68 | .toUpperCase();
|
69 | const dilation = getParamValue('dilation', node, tensorMap, context);
|
70 | return [ops.conv1d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), stride, pad, dataFormat, dilation)];
|
71 | }
|
72 | case 'Conv2D': {
|
73 | const stride = getParamValue('strides', node, tensorMap, context);
|
74 | const pad = getPadding(node, tensorMap, context);
|
75 | const dataFormat = getParamValue('dataFormat', node, tensorMap, context)
|
76 | .toUpperCase();
|
77 | const dilations = getParamValue('dilations', node, tensorMap, context);
|
78 | return [ops.conv2d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [stride[1], stride[2]], pad, dataFormat, [dilations[1], dilations[2]])];
|
79 | }
|
80 | case '_FusedConv2D': {
|
81 | const { stride, pad, dataFormat, dilations, biasArg, preluArg, activationFunc, leakyreluAlpha } = fusedConvAndDepthWiseParams(node, tensorMap, context);
|
82 | return [ops.fused.conv2d({
|
83 | x: getParamValue('x', node, tensorMap, context),
|
84 | filter: getParamValue('filter', node, tensorMap, context),
|
85 | strides: [stride[1], stride[2]],
|
86 | pad: pad,
|
87 | dataFormat: dataFormat,
|
88 | dilations: [dilations[1], dilations[2]],
|
89 | bias: biasArg,
|
90 | activation: activationFunc,
|
91 | preluActivationWeights: preluArg,
|
92 | leakyreluAlpha
|
93 | })];
|
94 | }
|
95 | case 'FusedDepthwiseConv2dNative': {
|
96 | const { stride, pad, dataFormat, dilations, biasArg, preluArg, activationFunc, leakyreluAlpha, } = fusedConvAndDepthWiseParams(node, tensorMap, context);
|
97 | return [ops.fused.depthwiseConv2d({
|
98 | x: getParamValue('x', node, tensorMap, context),
|
99 | filter: getParamValue('filter', node, tensorMap, context),
|
100 | strides: [stride[1], stride[2]],
|
101 | pad: pad,
|
102 | dataFormat: dataFormat,
|
103 | dilations: [dilations[1], dilations[2]],
|
104 | bias: biasArg,
|
105 | activation: activationFunc,
|
106 | preluActivationWeights: preluArg,
|
107 | leakyreluAlpha
|
108 | })];
|
109 | }
|
110 | case 'Conv2DBackpropInput':
|
111 | case 'Conv2dTranspose': {
|
112 | const shape = getParamValue('outputShape', node, tensorMap, context);
|
113 | const stride = getParamValue('strides', node, tensorMap, context);
|
114 | const pad = getPadding(node, tensorMap, context);
|
115 | return [ops.conv2dTranspose(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), shape, [stride[1], stride[2]], pad)];
|
116 | }
|
117 | case 'DepthwiseConv2dNative':
|
118 | case 'DepthwiseConv2d': {
|
119 | const stride = getParamValue('strides', node, tensorMap, context);
|
120 | const pad = getPadding(node, tensorMap, context);
|
121 | const dilations = getParamValue('dilations', node, tensorMap, context);
|
122 | const dataFormat = getParamValue('dataFormat', node, tensorMap, context)
|
123 | .toUpperCase();
|
124 | return [ops.depthwiseConv2d(getParamValue('input', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [stride[1], stride[2]], pad, dataFormat, [dilations[1], dilations[2]])];
|
125 | }
|
126 | case 'Conv3D': {
|
127 | const stride = getParamValue('strides', node, tensorMap, context);
|
128 | const pad = getParamValue('pad', node, tensorMap, context);
|
129 | const dataFormat = getParamValue('dataFormat', node, tensorMap, context)
|
130 | .toUpperCase();
|
131 | const dilations = getParamValue('dilations', node, tensorMap, context);
|
132 | return [ops.conv3d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [stride[1], stride[2], stride[3]], pad, dataFormat, [dilations[1], dilations[2], dilations[3]])];
|
133 | }
|
134 | case 'AvgPool': {
|
135 | const stride = getParamValue('strides', node, tensorMap, context);
|
136 | const pad = getParamValue('pad', node, tensorMap, context);
|
137 | const kernelSize = getParamValue('kernelSize', node, tensorMap, context);
|
138 | return [ops.avgPool(getParamValue('x', node, tensorMap, context), [kernelSize[1], kernelSize[2]], [stride[1], stride[2]], pad)];
|
139 | }
|
140 | case 'MaxPool': {
|
141 | const stride = getParamValue('strides', node, tensorMap, context);
|
142 | const pad = getParamValue('pad', node, tensorMap, context);
|
143 | const kernelSize = getParamValue('kernelSize', node, tensorMap, context);
|
144 | return [ops.maxPool(getParamValue('x', node, tensorMap, context), [kernelSize[1], kernelSize[2]], [stride[1], stride[2]], pad)];
|
145 | }
|
146 | case 'MaxPoolWithArgmax': {
|
147 | const stride = getParamValue('strides', node, tensorMap, context);
|
148 | const pad = getParamValue('pad', node, tensorMap, context);
|
149 | const kernelSize = getParamValue('kernelSize', node, tensorMap, context);
|
150 | const includeBatchInIndex = getParamValue('includeBatchInIndex', node, tensorMap, context);
|
151 | const { result, indexes } = ops.maxPoolWithArgmax(getParamValue('x', node, tensorMap, context), [kernelSize[1], kernelSize[2]], [stride[1], stride[2]], pad, includeBatchInIndex);
|
152 | return [result, indexes];
|
153 | }
|
154 | case 'AvgPool3D': {
|
155 | const stride = getParamValue('strides', node, tensorMap, context);
|
156 | const pad = getParamValue('pad', node, tensorMap, context);
|
157 | const kernelSize = getParamValue('kernelSize', node, tensorMap, context);
|
158 | return [ops.avgPool3d(getParamValue('x', node, tensorMap, context), [kernelSize[1], kernelSize[2], kernelSize[3]], [stride[1], stride[2], stride[3]], pad)];
|
159 | }
|
160 | case 'MaxPool3D': {
|
161 | const stride = getParamValue('strides', node, tensorMap, context);
|
162 | const pad = getParamValue('pad', node, tensorMap, context);
|
163 | const kernelSize = getParamValue('kernelSize', node, tensorMap, context);
|
164 | return [ops.maxPool3d(getParamValue('x', node, tensorMap, context), [kernelSize[1], kernelSize[2], kernelSize[3]], [stride[1], stride[2], stride[3]], pad)];
|
165 | }
|
166 | case 'Dilation2D': {
|
167 | const strides = getParamValue('strides', node, tensorMap, context);
|
168 | const pad = getParamValue('pad', node, tensorMap, context);
|
169 | const dilations = getParamValue('dilations', node, tensorMap, context);
|
170 | // strides: [1, stride_height, stride_width, 1].
|
171 | const strideHeight = strides[1];
|
172 | const strideWidth = strides[2];
|
173 | // dilations: [1, dilation_height, dilation_width, 1].
|
174 | const dilationHeight = dilations[1];
|
175 | const dilationWidth = dilations[2];
|
176 | return [ops.dilation2d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [strideHeight, strideWidth], pad, [dilationHeight, dilationWidth], 'NHWC' /* dataFormat */)];
|
177 | }
|
178 | default:
|
179 | throw TypeError(`Node type ${node.op} is not implemented`);
|
180 | }
|
181 | };
|
/**
 * Category identifier exported alongside `executeOp` for this executor
 * module. Presumably used by the op registry to route convolution/pooling
 * nodes here; confirm against the caller.
 */
export const CATEGORY = 'convolution';
|
183 | //# sourceMappingURL=data:application/json;base64,{"version":3,"file":"convolution_executor.js","sourceRoot":"","sources":["../../../../../../../tfjs-converter/src/operations/executors/convolution_executor.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAGH,iDAAiD;AACjD,OAAO,KAAK,KAAK,MAAM,kDAAkD,CAAC;AAM1E,OAAO,EAAC,UAAU,EAAE,aAAa,EAAC,MAAM,SAAS,CAAC;AAElD,SAAS,2BAA2B,CAChC,IAAU,EAAE,SAA0B,EAAE,OAAyB;IACnE,MAAM,CAAC,OAAO,EAAE,cAAc,CAAC,GAC1B,aAAa,CAAC,UAAU,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAc,CAAC;IAEtE,MAAM,SAAS,GAAG,OAAO,KAAK,SAAS,CAAC;IACxC,MAAM,SAAS,GAAG,CAAC,SAAS,CAAC;IAC7B,MAAM,OAAO,GAAG,cAAc,KAAK,OAAO,CAAC;IAC3C,MAAM,WAAW,GAAG,OAAO,KAAK,gBAAgB,CAAC;IAEjD,MAAM,OAAO,GACR,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC;IACnE,IAAI,SAAS,EAAE;QACb,IAAI,OAAO,IAAI,OAAO,KAAK,CAAC,EAAE;YAC5B,MAAM,IAAI,KAAK,CACX,yDAAyD;gBACzD,gDAAgD,CAAC,CAAC;SACvD;QACD,IAAI,CAAC,OAAO,IAAI,SAAS,IAAI,OAAO,KAAK,CAAC,EAAE;YAC1C,MAAM,IAAI,KAAK,CACX,yDAAyD;gBACzD,2BAA2B,CAAC,CAAC;SAClC;KACF;IACD,IAAI,WAAW,EAAE;QACf,MAAM,IAAI,KAAK,CACX,sEAAsE,CAAC,CAAC;KAC7E;IACD,MAAM,MAAM,GAAG,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;IAC9E,MAAM,GAAG,GAAG,UAAU,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;IACjD,MAAM,UAAU,GACX,aAAa,CAAC,YAAY,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY;SAC5D,WAAW,EAAE,CAAC;IACvB,MAAM,SAAS,GACX,aAAa,CAAC,WAAW,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;IACrE,IAAI,CAAC,OAAO,EAAE,QAAQ,CAAC,GACnB,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;IAChE,IAAI,SAAS,EAAE;QACb,QAAQ,GAAG,OAAO,CAAC;QACnB,OAAO,GAAG,SAAS,CAAC;KACrB;IACD,MAAM,cAAc,GAChB,aAAa,CAAC,gBAAgB,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;IAExE,OAAO;QACL,MAAM;QACN,GAAG;QACH,UAAU;QACV,SAAS;QACT,OAAO;QACP,QAAQ;QACR,cAAc;QACd,cAAc;KACf,CAAC;AACJ,CAAC;AAED,MAAM,CAAC,MAAM,SAAS,GAClB,CAAC,IAAU,EAAE,SAA0B,EACtC,OAAyB,EAAE,GAAG,GAAG,KAAK,EAAY,EAAE;IACnD,QAAQ,IAAI,CAAC,EAAE,EAAE;QACf,KAAK,QAAQ,CAAC,CAAC;YACb,MAAM,MAAM,GACR,aAAa,CAAC,QAAQ,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YAChE,MAAM,GAAG,GAAG,aAAa,CAAC,KAAK,E
AAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;YAC3D,MAAM,UAAU,GACX,aAAa,CAAC,YAAY,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY;iBAC5D,WAAW,EAAE,CAAC;YACvB,MAAM,QAAQ,GACV,aAAa,CAAC,UAAU,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YAClE,OAAO,CAAC,GAAG,CAAC,MAAM,CACd,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,EACxD,aAAa,CAAC,QAAQ,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,EAC7D,MAAM,EAAE,GAAuB,EAAE,UAA2B,EAC5D,QAAQ,CAAC,CAAC,CAAC;SAChB;QACD,KAAK,QAAQ,CAAC,CAAC;YACb,MAAM,MAAM,GACR,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YACnE,MAAM,GAAG,GAAG,UAAU,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;YACjD,MAAM,UAAU,GACX,aAAa,CAAC,YAAY,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY;iBAC5D,WAAW,EAAE,CAAC;YACvB,MAAM,SAAS,GACX,aAAa,CAAC,WAAW,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YACrE,OAAO,CAAC,GAAG,CAAC,MAAM,CACd,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAC/B,EACZ,aAAa,CAAC,QAAQ,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,EAC7D,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC,EAAE,GAAuB,EAC/C,UAA6B,EAAE,CAAC,SAAS,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;SACnE;QACD,KAAK,cAAc,CAAC,CAAC;YACnB,MAAM,EACJ,MAAM,EACN,GAAG,EACH,UAAU,EACV,SAAS,EACT,OAAO,EACP,QAAQ,EACR,cAAc,EACd,cAAc,EACf,GAAG,2BAA2B,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;YAE1D,OAAO,CAAC,GAAG,CAAC,KAAK,CAAC,MAAM,CAAC;oBACvB,CAAC,EAAE,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAClC;oBACZ,MAAM,EAAE,aAAa,CAAC,QAAQ,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAC5C;oBACZ,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC;oBAC/B,GAAG,EAAE,GAAuB;oBAC5B,UAAU,EAAE,UAA6B;oBACzC,SAAS,EAAE,CAAC,SAAS,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC;oBACvC,IAAI,EAAE,OAAO;oBACb,UAAU,EAAE,cAAwC;oBACpD,sBAAsB,EAAE,QAAQ;oBAChC,cAAc;iBACf,CAAC,CAAC,CAAC;SACL;QAED,KAAK,4BAA4B,CAAC,CAAC;YACjC,MAAM,EACJ,MAAM,EACN,GAAG,EACH,UAAU,EACV,SAAS,EACT,OAAO,EACP,QAAQ,EACR,cAAc,EACd,cAAc,GACf,GAAG,2BAA2B,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;YAE1D,OAAO,CAAC,GAAG,CAAC,KAAK,CAAC,eAAe,CAAC;oBAChC,CAAC,EAAE,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAClC;oBACZ
,MAAM,EAAE,aAAa,CAAC,QAAQ,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAC5C;oBACZ,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC;oBAC/B,GAAG,EAAE,GAAuB;oBAC5B,UAAU,EAAE,UAA6B;oBACzC,SAAS,EAAE,CAAC,SAAS,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC;oBACvC,IAAI,EAAE,OAAO;oBACb,UAAU,EAAE,cAAwC;oBACpD,sBAAsB,EAAE,QAAQ;oBAChC,cAAc;iBACf,CAAC,CAAC,CAAC;SACL;QACD,KAAK,qBAAqB,CAAC;QAC3B,KAAK,iBAAiB,CAAC,CAAC;YACtB,MAAM,KAAK,GAAG,aAAa,CACT,aAAa,EAAE,IAAI,EAAE,SAAS,EAC9B,OAAO,CACW,CAAC;YACrC,MAAM,MAAM,GACR,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YACnE,MAAM,GAAG,GAAG,UAAU,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;YACjD,OAAO,CAAC,GAAG,CAAC,eAAe,CACvB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAC/B,EACZ,aAAa,CAAC,QAAQ,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,EAC7D,KAAK,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC,EAAE,GAAuB,CAAC,CAAC,CAAC;SAC9D;QACD,KAAK,uBAAuB,CAAC;QAC7B,KAAK,iBAAiB,CAAC,CAAC;YACtB,MAAM,MAAM,GACR,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YACnE,MAAM,GAAG,GAAG,UAAU,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;YACjD,MAAM,SAAS,GACX,aAAa,CAAC,WAAW,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YACrE,MAAM,UAAU,GACX,aAAa,CAAC,YAAY,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY;iBAC5D,WAAW,EAAE,CAAC;YAEvB,OAAO,CAAC,GAAG,CAAC,eAAe,CACvB,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CACnC,EACZ,aAAa,CAAC,QAAQ,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,EAC7D,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC,EAAE,GAAuB,EAC/C,UAA6B,EAAE,CAAC,SAAS,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;SACnE;QACD,KAAK,QAAQ,CAAC,CAAC;YACb,MAAM,MAAM,GACR,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YACnE,MAAM,GAAG,GAAG,aAAa,CAAC,KAAK,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;YAC3D,MAAM,UAAU,GACX,aAAa,CAAC,YAAY,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY;iBAC5D,WAAW,EAAE,CAAC;YACvB,MAAM,SAAS,GACX,aAAa,CAAC,WAAW,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YACrE,OAAO,CAAC,GAAG,CAAC,MAAM,CACd,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CACxB,EACnB,aAAa,CAAC,QAAQ,EAAE,I
AAI,EAAE,SAAS,EAAE,OAAO,CAC7B,EACnB,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC,EAAE,GAAuB,EAC1D,UAA+B,EAC/B,CAAC,SAAS,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;SAClD;QACD,KAAK,SAAS,CAAC,CAAC;YACd,MAAM,MAAM,GACR,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YACnE,MAAM,GAAG,GAAG,aAAa,CAAC,KAAK,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;YAC3D,MAAM,UAAU,GACZ,aAAa,CAAC,YAAY,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YAEtE,OAAO,CAAC,GAAG,CAAC,OAAO,CACf,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAC/B,EACZ,CAAC,UAAU,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC,EACtD,GAAuB,CAAC,CAAC,CAAC;SAC/B;QACD,KAAK,SAAS,CAAC,CAAC;YACd,MAAM,MAAM,GACR,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YACnE,MAAM,GAAG,GAAG,aAAa,CAAC,KAAK,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;YAC3D,MAAM,UAAU,GACZ,aAAa,CAAC,YAAY,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YAEtE,OAAO,CAAC,GAAG,CAAC,OAAO,CACf,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAC/B,EACZ,CAAC,UAAU,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC,EACtD,GAAuB,CAAC,CAAC,CAAC;SAC/B;QACD,KAAK,mBAAmB,CAAC,CAAC;YACxB,MAAM,MAAM,GACR,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YACnE,MAAM,GAAG,GAAG,aAAa,CAAC,KAAK,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;YAC3D,MAAM,UAAU,GACZ,aAAa,CAAC,YAAY,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YACtE,MAAM,mBAAmB,GACrB,aAAa,CAAC,qBAAqB,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CACtD,CAAC;YACZ,MAAM,EAAC,MAAM,EAAE,OAAO,EAAC,GAAG,GAAG,CAAC,iBAAiB,CAC3C,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,EACxD,CAAC,UAAU,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC,EACtD,GAAuB,EAAE,mBAAmB,CAAC,CAAC;YAClD,OAAO,CAAC,MAAM,EAAE,OAAO,CAAC,CAAC;SAC1B;QACD,KAAK,WAAW,CAAC,CAAC;YAChB,MAAM,MAAM,GACR,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YACnE,MAAM,GAAG,G
AAG,aAAa,CAAC,KAAK,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;YAC3D,MAAM,UAAU,GACZ,aAAa,CAAC,YAAY,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YAEtE,OAAO,CAAC,GAAG,CAAC,SAAS,CACjB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,EACxD,CAAC,UAAU,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,CAAC,CAAC,CAAC,EAC7C,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC,EAAE,GAAuB,CAAC,CAAC,CAAC;SAClE;QAED,KAAK,WAAW,CAAC,CAAC;YAChB,MAAM,MAAM,GACR,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YACnE,MAAM,GAAG,GAAG,aAAa,CAAC,KAAK,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;YAC3D,MAAM,UAAU,GACZ,aAAa,CAAC,YAAY,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YAEtE,OAAO,CAAC,GAAG,CAAC,SAAS,CACjB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,EACxD,CAAC,UAAU,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,CAAC,CAAC,CAAC,EAC7C,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC,EAAE,GAAuB,CAAC,CAAC,CAAC;SAClE;QAED,KAAK,YAAY,CAAC,CAAC;YACjB,MAAM,OAAO,GACT,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YACnE,MAAM,GAAG,GAAG,aAAa,CAAC,KAAK,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;YAC3D,MAAM,SAAS,GACX,aAAa,CAAC,WAAW,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YAErE,gDAAgD;YAChD,MAAM,YAAY,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC;YAChC,MAAM,WAAW,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC;YAE/B,sDAAsD;YACtD,MAAM,cAAc,GAAG,SAAS,CAAC,CAAC,CAAC,CAAC;YACpC,MAAM,aAAa,GAAG,SAAS,CAAC,CAAC,CAAC,CAAC;YAEnC,OAAO,CAAC,GAAG,CAAC,UAAU,CAClB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAC/B,EACZ,aAAa,CAAC,QAAQ,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,EAC7D,CAAC,YAAY,EAAE,WAAW,CAAC,EAAE,GAAuB,EACpD,CAAC,cAAc,EAAE,aAAa,CAAC,EAAE,MAAM,CAAC,gBAAgB,CAAC,CAAC,CAAC;SAChE;QAED;YACE,MAAM,SAAS,CAAC,aAAa,IAAI,CAAC,EAAE,qBAAqB,CAAC,CAAC;KAC9D;AACH,CAAC,CAAC;AAEN,MAAM,CAAC,MAAM,QAAQ,GAAG,aAAa,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Rank, Tensor, Tensor3D, Tensor4D, Tensor5D} from '@tensorflow/tfjs-core';\n// tslint:disable-next-line: no-imports-from-dist\nimport * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';\n\nimport {NamedTensorsMap} from '../../data/types';\nimport {ExecutionContext} from '../../executor/execution_context';\nimport {InternalOpExecutor, Node} from '../types';\n\nimport {getPadding, getParamValue} from './utils';\n\nfunction fusedConvAndDepthWiseParams(\n    node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext) {\n  const [extraOp, activationFunc] =\n      (getParamValue('fusedOps', node, tensorMap, context) as string[]);\n\n  const isBiasAdd = extraOp === 'biasadd';\n  const noBiasAdd = !isBiasAdd;\n  const isPrelu = activationFunc === 'prelu';\n  const isBatchNorm = extraOp === 'fusedbatchnorm';\n\n  const numArgs =\n      (getParamValue('numArgs', node, tensorMap, context) as number);\n  if (isBiasAdd) {\n    if (isPrelu && numArgs !== 2) {\n      throw new Error(\n          'FusedConv2d and DepthwiseConv2d with BiasAdd and Prelu ' +\n          'must have two extra arguments: bias and alpha.');\n    }\n    if (!isPrelu && isBiasAdd && numArgs !== 1) {\n      throw new Error(\n          'FusedConv2d and DepthwiseConv2d with BiasAdd must have ' +\n          'one extra 
argument: bias.');\n    }\n  }\n  if (isBatchNorm) {\n    throw new Error(\n        'FusedConv2d and DepthwiseConv2d with FusedBatchNorm is not supported');\n  }\n  const stride = getParamValue('strides', node, tensorMap, context) as number[];\n  const pad = getPadding(node, tensorMap, context);\n  const dataFormat =\n      (getParamValue('dataFormat', node, tensorMap, context) as string)\n          .toUpperCase();\n  const dilations =\n      getParamValue('dilations', node, tensorMap, context) as number[];\n  let [biasArg, preluArg] =\n      getParamValue('args', node, tensorMap, context) as Tensor[];\n  if (noBiasAdd) {\n    preluArg = biasArg;\n    biasArg = undefined;\n  }\n  const leakyreluAlpha =\n      getParamValue('leakyreluAlpha', node, tensorMap, context) as number;\n\n  return {\n    stride,\n    pad,\n    dataFormat,\n    dilations,\n    biasArg,\n    preluArg,\n    activationFunc,\n    leakyreluAlpha\n  };\n}\n\nexport const executeOp: InternalOpExecutor =\n    (node: Node, tensorMap: NamedTensorsMap,\n     context: ExecutionContext, ops = tfOps): Tensor[] => {\n      switch (node.op) {\n        case 'Conv1D': {\n          const stride =\n              getParamValue('stride', node, tensorMap, context) as number;\n          const pad = getParamValue('pad', node, tensorMap, context);\n          const dataFormat =\n              (getParamValue('dataFormat', node, tensorMap, context) as string)\n                  .toUpperCase();\n          const dilation =\n              getParamValue('dilation', node, tensorMap, context) as number;\n          return [ops.conv1d(\n              getParamValue('x', node, tensorMap, context) as Tensor3D,\n              getParamValue('filter', node, tensorMap, context) as Tensor3D,\n              stride, pad as 'valid' | 'same', dataFormat as 'NWC' | 'NCW',\n              dilation)];\n        }\n        case 'Conv2D': {\n          const stride =\n              getParamValue('strides', node, tensorMap, context) as number[];\n  
        const pad = getPadding(node, tensorMap, context);\n          const dataFormat =\n              (getParamValue('dataFormat', node, tensorMap, context) as string)\n                  .toUpperCase();\n          const dilations =\n              getParamValue('dilations', node, tensorMap, context) as number[];\n          return [ops.conv2d(\n              getParamValue('x', node, tensorMap, context) as Tensor3D |\n                  Tensor4D,\n              getParamValue('filter', node, tensorMap, context) as Tensor4D,\n              [stride[1], stride[2]], pad as 'valid' | 'same',\n              dataFormat as 'NHWC' | 'NCHW', [dilations[1], dilations[2]])];\n        }\n        case '_FusedConv2D': {\n          const {\n            stride,\n            pad,\n            dataFormat,\n            dilations,\n            biasArg,\n            preluArg,\n            activationFunc,\n            leakyreluAlpha\n          } = fusedConvAndDepthWiseParams(node, tensorMap, context);\n\n          return [ops.fused.conv2d({\n            x: getParamValue('x', node, tensorMap, context) as Tensor3D |\n                Tensor4D,\n            filter: getParamValue('filter', node, tensorMap, context) as\n                Tensor4D,\n            strides: [stride[1], stride[2]],\n            pad: pad as 'valid' | 'same',\n            dataFormat: dataFormat as 'NHWC' | 'NCHW',\n            dilations: [dilations[1], dilations[2]],\n            bias: biasArg,\n            activation: activationFunc as tfOps.fused.Activation,\n            preluActivationWeights: preluArg,\n            leakyreluAlpha\n          })];\n        }\n\n        case 'FusedDepthwiseConv2dNative': {\n          const {\n            stride,\n            pad,\n            dataFormat,\n            dilations,\n            biasArg,\n            preluArg,\n            activationFunc,\n            leakyreluAlpha,\n          } = fusedConvAndDepthWiseParams(node, tensorMap, context);\n\n          return 
[ops.fused.depthwiseConv2d({\n            x: getParamValue('x', node, tensorMap, context) as Tensor3D |\n                Tensor4D,\n            filter: getParamValue('filter', node, tensorMap, context) as\n                Tensor4D,\n            strides: [stride[1], stride[2]],\n            pad: pad as 'valid' | 'same',\n            dataFormat: dataFormat as 'NHWC' | 'NCHW',\n            dilations: [dilations[1], dilations[2]],\n            bias: biasArg,\n            activation: activationFunc as tfOps.fused.Activation,\n            preluActivationWeights: preluArg,\n            leakyreluAlpha\n          })];\n        }\n        case 'Conv2DBackpropInput':\n        case 'Conv2dTranspose': {\n          const shape = getParamValue(\n                            'outputShape', node, tensorMap,\n                            context) as [number, number, number] |\n              [number, number, number, number];\n          const stride =\n              getParamValue('strides', node, tensorMap, context) as number[];\n          const pad = getPadding(node, tensorMap, context);\n          return [ops.conv2dTranspose(\n              getParamValue('x', node, tensorMap, context) as Tensor3D |\n                  Tensor4D,\n              getParamValue('filter', node, tensorMap, context) as Tensor4D,\n              shape, [stride[1], stride[2]], pad as 'valid' | 'same')];\n        }\n        case 'DepthwiseConv2dNative':\n        case 'DepthwiseConv2d': {\n          const stride =\n              getParamValue('strides', node, tensorMap, context) as number[];\n          const pad = getPadding(node, tensorMap, context);\n          const dilations =\n              getParamValue('dilations', node, tensorMap, context) as number[];\n          const dataFormat =\n              (getParamValue('dataFormat', node, tensorMap, context) as string)\n                  .toUpperCase();\n\n          return [ops.depthwiseConv2d(\n              getParamValue('input', node, tensorMap, context) as 
Tensor3D |\n                  Tensor4D,\n              getParamValue('filter', node, tensorMap, context) as Tensor4D,\n              [stride[1], stride[2]], pad as 'valid' | 'same',\n              dataFormat as 'NHWC' | 'NCHW', [dilations[1], dilations[2]])];\n        }\n        case 'Conv3D': {\n          const stride =\n              getParamValue('strides', node, tensorMap, context) as number[];\n          const pad = getParamValue('pad', node, tensorMap, context);\n          const dataFormat =\n              (getParamValue('dataFormat', node, tensorMap, context) as string)\n                  .toUpperCase();\n          const dilations =\n              getParamValue('dilations', node, tensorMap, context) as number[];\n          return [ops.conv3d(\n              getParamValue('x', node, tensorMap, context) as Tensor4D |\n                  Tensor<Rank.R5>,\n              getParamValue('filter', node, tensorMap, context) as\n                  Tensor<Rank.R5>,\n              [stride[1], stride[2], stride[3]], pad as 'valid' | 'same',\n              dataFormat as 'NDHWC' | 'NCDHW',\n              [dilations[1], dilations[2], dilations[3]])];\n        }\n        case 'AvgPool': {\n          const stride =\n              getParamValue('strides', node, tensorMap, context) as number[];\n          const pad = getParamValue('pad', node, tensorMap, context);\n          const kernelSize =\n              getParamValue('kernelSize', node, tensorMap, context) as number[];\n\n          return [ops.avgPool(\n              getParamValue('x', node, tensorMap, context) as Tensor3D |\n                  Tensor4D,\n              [kernelSize[1], kernelSize[2]], [stride[1], stride[2]],\n              pad as 'valid' | 'same')];\n        }\n        case 'MaxPool': {\n          const stride =\n              getParamValue('strides', node, tensorMap, context) as number[];\n          const pad = getParamValue('pad', node, tensorMap, context);\n          const kernelSize =\n              
getParamValue('kernelSize', node, tensorMap, context) as number[];\n\n          return [ops.maxPool(\n              getParamValue('x', node, tensorMap, context) as Tensor3D |\n                  Tensor4D,\n              [kernelSize[1], kernelSize[2]], [stride[1], stride[2]],\n              pad as 'valid' | 'same')];\n        }\n        case 'MaxPoolWithArgmax': {\n          const stride =\n              getParamValue('strides', node, tensorMap, context) as number[];\n          const pad = getParamValue('pad', node, tensorMap, context);\n          const kernelSize =\n              getParamValue('kernelSize', node, tensorMap, context) as number[];\n          const includeBatchInIndex =\n              getParamValue('includeBatchInIndex', node, tensorMap, context) as\n              boolean;\n          const {result, indexes} = ops.maxPoolWithArgmax(\n              getParamValue('x', node, tensorMap, context) as Tensor4D,\n              [kernelSize[1], kernelSize[2]], [stride[1], stride[2]],\n              pad as 'valid' | 'same', includeBatchInIndex);\n          return [result, indexes];\n        }\n        case 'AvgPool3D': {\n          const stride =\n              getParamValue('strides', node, tensorMap, context) as number[];\n          const pad = getParamValue('pad', node, tensorMap, context);\n          const kernelSize =\n              getParamValue('kernelSize', node, tensorMap, context) as number[];\n\n          return [ops.avgPool3d(\n              getParamValue('x', node, tensorMap, context) as Tensor5D,\n              [kernelSize[1], kernelSize[2], kernelSize[3]],\n              [stride[1], stride[2], stride[3]], pad as 'valid' | 'same')];\n        }\n\n        case 'MaxPool3D': {\n          const stride =\n              getParamValue('strides', node, tensorMap, context) as number[];\n          const pad = getParamValue('pad', node, tensorMap, context);\n          const kernelSize =\n              getParamValue('kernelSize', node, tensorMap, context) as 
number[];\n\n          return [ops.maxPool3d(\n              getParamValue('x', node, tensorMap, context) as Tensor5D,\n              [kernelSize[1], kernelSize[2], kernelSize[3]],\n              [stride[1], stride[2], stride[3]], pad as 'valid' | 'same')];\n        }\n\n        case 'Dilation2D': {\n          const strides =\n              getParamValue('strides', node, tensorMap, context) as number[];\n          const pad = getParamValue('pad', node, tensorMap, context);\n          const dilations =\n              getParamValue('dilations', node, tensorMap, context) as number[];\n\n          // strides: [1, stride_height, stride_width, 1].\n          const strideHeight = strides[1];\n          const strideWidth = strides[2];\n\n          // dilations: [1, dilation_height, dilation_width, 1].\n          const dilationHeight = dilations[1];\n          const dilationWidth = dilations[2];\n\n          return [ops.dilation2d(\n              getParamValue('x', node, tensorMap, context) as Tensor3D |\n                  Tensor4D,\n              getParamValue('filter', node, tensorMap, context) as Tensor3D,\n              [strideHeight, strideWidth], pad as 'valid' | 'same',\n              [dilationHeight, dilationWidth], 'NHWC' /* dataFormat */)];\n        }\n\n        default:\n          throw TypeError(`Node type ${node.op} is not implemented`);\n      }\n    };\n\nexport const CATEGORY = 'convolution';\n"]} |
\ | No newline at end of file |