1 | /**
|
2 | * @license
|
3 | * Copyright 2018 Google LLC. All Rights Reserved.
|
4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | * you may not use this file except in compliance with the License.
|
6 | * You may obtain a copy of the License at
|
7 | *
|
8 | * http://www.apache.org/licenses/LICENSE-2.0
|
9 | *
|
10 | * Unless required by applicable law or agreed to in writing, software
|
11 | * distributed under the License is distributed on an "AS IS" BASIS,
|
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 | * See the License for the specific language governing permissions and
|
14 | * limitations under the License.
|
15 | * =============================================================================
|
16 | */
|
17 | // tslint:disable-next-line: no-imports-from-dist
|
18 | import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';
|
19 | import { getParamValue } from './utils';
|
/**
 * Executes one graph node whose op is handled by this 'matrices' executor.
 *
 * Every input/attribute is read through `getParamValue` and the work is
 * delegated to the matching op on `ops`.
 *
 * @param node Graph node to execute; `node.op` selects the operation.
 * @param tensorMap Tensor map forwarded to `getParamValue`.
 * @param context Execution context forwarded to `getParamValue`.
 * @param ops Op implementations to call; defaults to the tfjs-core
 *     `ops_for_converter` module and may be overridden by the caller.
 * @returns A single-element array containing the op's result.
 * @throws Error If a `_FusedMatMul` node with a BiasAdd extra op does not
 *     declare the expected number of extra arguments.
 * @throws TypeError If `node.op` is not one of the ops handled here.
 */
export const executeOp = (node, tensorMap, context, ops = tfOps) => {
  // Every lookup uses the same node/tensorMap/context triple.
  const param = (name) => getParamValue(name, node, tensorMap, context);

  switch (node.op) {
    case 'BatchMatMul':
    case 'BatchMatMulV2':
    case 'MatMul':
      return [ops.matMul(
          param('a'), param('b'), param('transposeA'), param('transposeB'))];

    case 'Einsum':
      return [ops.einsum(param('equation'), ...param('tensors'))];

    case 'Transpose':
      return [ops.transpose(param('x'), param('perm'))];

    case '_FusedMatMul': {
      // Braces scope these bindings to this case instead of the whole switch.
      const [extraOp, activationFunc] = param('fusedOps');
      const isBiasAdd = extraOp === 'biasadd';
      const isPrelu = activationFunc === 'prelu';
      const numArgs = param('numArgs');
      const leakyreluAlpha = param('leakyreluAlpha');

      if (isBiasAdd) {
        // With PReLU the alpha weights arrive as a second extra argument.
        if (isPrelu && numArgs !== 2) {
          throw new Error('Fused MatMul with BiasAdd and Prelu must have two ' +
              'extra arguments: bias and alpha.');
        }
        if (!isPrelu && numArgs !== 1) {
          throw new Error(
              'Fused MatMul with BiasAdd must have one extra argument: bias.');
        }
      }
      const [biasArg, preluArg] = param('args');
      return [ops.fused.matMul({
        a: param('a'),
        b: param('b'),
        transposeA: param('transposeA'),
        transposeB: param('transposeB'),
        bias: biasArg,
        activation: activationFunc,
        preluActivationWeights: preluArg,
        leakyreluAlpha
      })];
    }

    case 'MatrixBandPart':
      return [ops.linalg.bandPart(
          param('a'), param('numLower'), param('numUpper'))];

    default:
      throw new TypeError(`Node type ${node.op} is not implemented`);
  }
};
|
/**
 * Category name for the ops handled by `executeOp` in this module.
 * NOTE(review): presumably used to route graph nodes to this executor —
 * confirm against the op registry.
 */
export const CATEGORY = 'matrices';
|
63 | //# sourceMappingURL=data:application/json;base64,{"version":3,"file":"matrices_executor.js","sourceRoot":"","sources":["../../../../../../../tfjs-converter/src/operations/executors/matrices_executor.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAGH,iDAAiD;AACjD,OAAO,KAAK,KAAK,MAAM,kDAAkD,CAAC;AAM1E,OAAO,EAAC,aAAa,EAAC,MAAM,SAAS,CAAC;AAEtC,MAAM,CAAC,MAAM,SAAS,GAClB,CAAC,IAAU,EAAE,SAA0B,EAAE,OAAyB,EACjE,GAAG,GAAG,KAAK,EAAY,EAAE;IACxB,QAAQ,IAAI,CAAC,EAAE,EAAE;QACf,KAAK,aAAa,CAAC;QACnB,KAAK,eAAe,CAAC;QACrB,KAAK,QAAQ;YACX,OAAO,CAAC,GAAG,CAAC,MAAM,CACd,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,EACxD,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,EACxD,aAAa,CAAC,YAAY,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,EAChE,aAAa,CAAC,YAAY,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CACzC,CAAC,CAAC,CAAC;QAEpB,KAAK,QAAQ;YACX,OAAO,CAAC,GAAG,CAAC,MAAM,CACd,aAAa,CAAC,UAAU,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAC7D,GAAG,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CACxC,CAAC,CAAC,CAAC;QAErB,KAAK,WAAW;YACd,OAAO,CAAC,GAAG,CAAC,SAAS,CACjB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACtD,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC,CAAC,CAAC;QAEpE,KAAK,cAAc;YACjB,MAAM,CAAC,OAAO,EAAE,cAAc,CAAC,GAC1B,aAAa,CAAC,UAAU,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAc,CAAC;YAEtE,MAAM,SAAS,GAAG,OAAO,KAAK,SAAS,CAAC;YACxC,MAAM,OAAO,GAAG,cAAc,KAAK,OAAO,CAAC;YAE3C,MAAM,OAAO,GACR,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC;YACnE,MAAM,cAAc,GAChB,aAAa,CAAC,gBAAgB,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAClD,CAAC;YAEX,IAAI,SAAS,EAAE;gBACb,IAAI,OAAO,IAAI,OAAO,KAAK,CAAC,EAAE;oBAC5B,MAAM,IAAI,KAAK,CACX,oDAAoD;wBACpD,kCAAkC,CAAC,CAAC;iBACzC;gBACD,IAAI,CAAC,OAAO,IAAI,OAAO,KAAK,CAAC,EAAE;oBAC7B,MAAM,IAAI,KAAK,CACX,+DAA+D,CAAC,CAAC;iBACtE;aACF;YACD,MAAM,CAAC,OAAO,EAAE,QAAQ,CAAC,GACrB,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YAChE,OAAO,CAAC,GAAG,CAAC,KAAK,CAAC,MAAM,CAAC;oBACvB,CAAC,EAAE,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa;oBAC3D,CAAC,EAAE,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO
,CAAa;oBAC3D,UAAU,EAAE,aAAa,CAAC,YAAY,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CACrD;oBACX,UAAU,EAAE,aAAa,CAAC,YAAY,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CACrD;oBACX,IAAI,EAAE,OAAO;oBACb,UAAU,EAAE,cAAwC;oBACpD,sBAAsB,EAAE,QAAQ;oBAChC,cAAc;iBACf,CAAC,CAAC,CAAC;QAEN,KAAK,gBAAgB;YACnB,OAAO,CAAC,GAAG,CAAC,MAAM,CAAC,QAAQ,CACvB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,EACxD,aAAa,CAAC,UAAU,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAC7D,aAAa,CAAC,UAAU,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC,CAAC,CAAC;QAEtE;YACE,MAAM,SAAS,CAAC,aAAa,IAAI,CAAC,EAAE,qBAAqB,CAAC,CAAC;KAC9D;AACH,CAAC,CAAC;AAEN,MAAM,CAAC,MAAM,QAAQ,GAAG,UAAU,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Scalar, Tensor, Tensor2D} from '@tensorflow/tfjs-core';\n// tslint:disable-next-line: no-imports-from-dist\nimport * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';\n\nimport {NamedTensorsMap} from '../../data/types';\nimport {ExecutionContext} from '../../executor/execution_context';\nimport {InternalOpExecutor, Node} from '../types';\n\nimport {getParamValue} from './utils';\n\nexport const executeOp: InternalOpExecutor =\n    (node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext,\n     ops = tfOps): Tensor[] => {\n      switch (node.op) {\n        case 'BatchMatMul':\n        case 'BatchMatMulV2':\n 
       case 'MatMul':\n          return [ops.matMul(\n              getParamValue('a', node, tensorMap, context) as Tensor2D,\n              getParamValue('b', node, tensorMap, context) as Tensor2D,\n              getParamValue('transposeA', node, tensorMap, context) as boolean,\n              getParamValue('transposeB', node, tensorMap, context) as\n                  boolean)];\n\n        case 'Einsum':\n          return [ops.einsum(\n              getParamValue('equation', node, tensorMap, context) as string,\n              ...getParamValue('tensors', node, tensorMap, context) as\n                  Tensor[])];\n\n        case 'Transpose':\n          return [ops.transpose(\n              getParamValue('x', node, tensorMap, context) as Tensor,\n              getParamValue('perm', node, tensorMap, context) as number[])];\n\n        case '_FusedMatMul':\n          const [extraOp, activationFunc] =\n              (getParamValue('fusedOps', node, tensorMap, context) as string[]);\n\n          const isBiasAdd = extraOp === 'biasadd';\n          const isPrelu = activationFunc === 'prelu';\n\n          const numArgs =\n              (getParamValue('numArgs', node, tensorMap, context) as number);\n          const leakyreluAlpha =\n              getParamValue('leakyreluAlpha', node, tensorMap, context) as\n              number;\n\n          if (isBiasAdd) {\n            if (isPrelu && numArgs !== 2) {\n              throw new Error(\n                  'Fused MatMul with BiasAdd and Prelu must have two ' +\n                  'extra arguments: bias and alpha.');\n            }\n            if (!isPrelu && numArgs !== 1) {\n              throw new Error(\n                  'Fused MatMul with BiasAdd must have one extra argument: bias.');\n            }\n          }\n          const [biasArg, preluArg] =\n              getParamValue('args', node, tensorMap, context) as Tensor[];\n          return [ops.fused.matMul({\n            a: getParamValue('a', node, tensorMap, context) 
as Tensor2D,\n            b: getParamValue('b', node, tensorMap, context) as Tensor2D,\n            transposeA: getParamValue('transposeA', node, tensorMap, context) as\n                boolean,\n            transposeB: getParamValue('transposeB', node, tensorMap, context) as\n                boolean,\n            bias: biasArg,\n            activation: activationFunc as tfOps.fused.Activation,\n            preluActivationWeights: preluArg,\n            leakyreluAlpha\n          })];\n\n        case 'MatrixBandPart':\n          return [ops.linalg.bandPart(\n              getParamValue('a', node, tensorMap, context) as Tensor2D,\n              getParamValue('numLower', node, tensorMap, context) as Scalar,\n              getParamValue('numUpper', node, tensorMap, context) as Scalar)];\n\n        default:\n          throw TypeError(`Node type ${node.op} is not implemented`);\n      }\n    };\n\nexport const CATEGORY = 'matrices';\n"]} |
\ | No newline at end of file |