UNPKG

15.2 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2018 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17// tslint:disable-next-line: no-imports-from-dist
18import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';
19import { getParamValue } from './utils';
20export const executeOp = (node, tensorMap, context, ops = tfOps) => {
21 switch (node.op) {
22 case 'Cast': {
23 return [ops.cast(getParamValue('x', node, tensorMap, context), getParamValue('dtype', node, tensorMap, context))];
24 }
25 case 'ExpandDims': {
26 const axis = getParamValue('axis', node, tensorMap, context);
27 return [ops.expandDims(getParamValue('x', node, tensorMap, context), axis)];
28 }
29 case 'Squeeze': {
30 const axis = getParamValue('axis', node, tensorMap, context);
31 return [ops.squeeze(getParamValue('x', node, tensorMap, context), axis)];
32 }
33 case 'Reshape': {
34 return [ops.reshape(getParamValue('x', node, tensorMap, context), getParamValue('shape', node, tensorMap, context))];
35 }
36 case 'EnsureShape': {
37 return [ops.ensureShape(getParamValue('x', node, tensorMap, context), getParamValue('shape', node, tensorMap, context))];
38 }
39 case 'MirrorPad': {
40 return [ops.mirrorPad(getParamValue('x', node, tensorMap, context), getParamValue('padding', node, tensorMap, context), getParamValue('mode', node, tensorMap, context))];
41 }
42 case 'PadV2':
43 case 'Pad': {
44 return [ops.pad(getParamValue('x', node, tensorMap, context), getParamValue('padding', node, tensorMap, context), getParamValue('constantValue', node, tensorMap, context))];
45 }
46 case 'SpaceToBatchND': {
47 const blockShape = getParamValue('blockShape', node, tensorMap, context);
48 const paddings = getParamValue('paddings', node, tensorMap, context);
49 return [ops.spaceToBatchND(getParamValue('x', node, tensorMap, context), blockShape, paddings)];
50 }
51 case 'BatchToSpaceND': {
52 const blockShape = getParamValue('blockShape', node, tensorMap, context);
53 const crops = getParamValue('crops', node, tensorMap, context);
54 return [ops.batchToSpaceND(getParamValue('x', node, tensorMap, context), blockShape, crops)];
55 }
56 case 'DepthToSpace': {
57 const blockSize = getParamValue('blockSize', node, tensorMap, context);
58 const dataFormat = getParamValue('dataFormat', node, tensorMap, context).toUpperCase();
59 return [ops.depthToSpace(getParamValue('x', node, tensorMap, context), blockSize, dataFormat)];
60 }
61 case 'BroadcastTo': {
62 return [ops.broadcastTo(getParamValue('x', node, tensorMap, context), getParamValue('shape', node, tensorMap, context))];
63 }
64 case 'BroadcastArgs': {
65 return [ops.broadcastArgs(getParamValue('s0', node, tensorMap, context), getParamValue('s1', node, tensorMap, context))];
66 }
67 default:
68 throw TypeError(`Node type ${node.op} is not implemented`);
69 }
70};
/** Category label exported by this module; its value is the string 'transformation'. */
export const CATEGORY = 'transformation';
72//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"transformation_executor.js","sourceRoot":"","sources":["../../../../../../../tfjs-converter/src/operations/executors/transformation_executor.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAGH,iDAAiD;AACjD,OAAO,KAAK,KAAK,MAAM,kDAAkD,CAAC;AAM1E,OAAO,EAAC,aAAa,EAAC,MAAM,SAAS,CAAC;AAEtC,MAAM,CAAC,MAAM,SAAS,GAClB,CAAC,IAAU,EAAE,SAA0B,EAAE,OAAyB,EACjE,GAAG,GAAG,KAAK,EAAY,EAAE;IACxB,QAAQ,IAAI,CAAC,EAAE,EAAE;QACf,KAAK,MAAM,CAAC,CAAC;YACX,OAAO,CAAC,GAAG,CAAC,IAAI,CACZ,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACtD,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CACzB,CAAC,CAAC,CAAC;SAC9B;QACD,KAAK,YAAY,CAAC,CAAC;YACjB,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YAC9D,OAAO,CAAC,GAAG,CAAC,UAAU,CAClB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAAE,IAAI,CAAC,CAAC,CAAC;SACpE;QACD,KAAK,SAAS,CAAC,CAAC;YACd,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YAChE,OAAO,CAAC,GAAG,CAAC,OAAO,CACf,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAAE,IAAI,CAAC,CAAC,CAAC;SACpE;QAED,KAAK,SAAS,CAAC,CAAC;YACd,OAAO,CAAC,GAAG,CAAC,OAAO,CACf,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACtD,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC,CAAC,CAAC;SACpE;QACD,KAAK,aAAa,CAAC,CAAC;YAClB,OAAO,CAAC,GAAG,CAAC,WAAW,CACnB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACtD,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC,CAAC,CAAC;SACpE;QACD,KAAK,WAAW,CAAC,CAAC;YAChB,OAAO,CAAC,GAAG,CAAC,SAAS,CACjB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACtD,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CACtB,EAC3B,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAC/B,CAAC,CAAC,CAAC;SACvB;QACD,KAAK,OAAO,CAAC;QACb,KAAK,KAAK,CAAC,CAAC;YACV,OAAO,CAAC,GAAG,CAAC,GAAG,CACX,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACtD,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CACtB,EAC3B,aAAa,CAAC,eAAe,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAC7C,CAAC,CAAC,CAAC;SAClB;QACD,
KAAK,gBAAgB,CAAC,CAAC;YACrB,MAAM,UAAU,GACZ,aAAa,CAAC,YAAY,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YACtE,MAAM,QAAQ,GACV,aAAa,CAAC,UAAU,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAe,CAAC;YACtE,OAAO,CAAC,GAAG,CAAC,cAAc,CACtB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACtD,UAAU,EAAE,QAAQ,CAAC,CAAC,CAAC;SAC5B;QACD,KAAK,gBAAgB,CAAC,CAAC;YACrB,MAAM,UAAU,GACZ,aAAa,CAAC,YAAY,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YACtE,MAAM,KAAK,GACP,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAe,CAAC;YACnE,OAAO,CAAC,GAAG,CAAC,cAAc,CACtB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACtD,UAAU,EAAE,KAAK,CAAC,CAAC,CAAC;SACzB;QACD,KAAK,cAAc,CAAC,CAAC;YACnB,MAAM,SAAS,GACX,aAAa,CAAC,WAAW,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YACnE,MAAM,UAAU,GACX,aAAa,CAAC,YAAY,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAC7C,CAAC,WAAW,EACd,CAAC;YACX,OAAO,CAAC,GAAG,CAAC,YAAY,CACpB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,EACxD,SAAS,EAAE,UAAU,CAAC,CAAC,CAAC;SAC7B;QACD,KAAK,aAAa,CAAC,CAAC;YAClB,OAAO,CAAC,GAAG,CAAC,WAAW,CACnB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACtD,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC,CAAC,CAAC;SACpE;QACD,KAAK,eAAe,CAAC,CAAC;YACpB,OAAO,CAAC,GAAG,CAAC,aAAa,CACrB,aAAa,CAAC,IAAI,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACvD,aAAa,CAAC,IAAI,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC,CAAC,CAAC;SAC/D;QACD;YACE,MAAM,SAAS,CAAC,aAAa,IAAI,CAAC,EAAE,qBAAqB,CAAC,CAAC;KAC9D;AACH,CAAC,CAAC;AAEN,MAAM,CAAC,MAAM,QAAQ,GAAG,gBAAgB,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor, Tensor4D} from '@tensorflow/tfjs-core';\n// tslint:disable-next-line: no-imports-from-dist\nimport * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';\n\nimport {NamedTensorsMap} from '../../data/types';\nimport {ExecutionContext} from '../../executor/execution_context';\nimport {InternalOpExecutor, Node} from '../types';\n\nimport {getParamValue} from './utils';\n\nexport const executeOp: InternalOpExecutor =\n    (node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext,\n     ops = tfOps): Tensor[] => {\n      switch (node.op) {\n        case 'Cast': {\n          return [ops.cast(\n              getParamValue('x', node, tensorMap, context) as Tensor,\n              getParamValue('dtype', node, tensorMap, context) as 'int32' |\n                  'float32' | 'bool')];\n        }\n        case 'ExpandDims': {\n          const axis =\n              getParamValue('axis', node, tensorMap, context) as number;\n          return [ops.expandDims(\n              getParamValue('x', node, tensorMap, context) as Tensor, axis)];\n        }\n        case 'Squeeze': {\n          const axis =\n              getParamValue('axis', node, tensorMap, context) as number[];\n          return [ops.squeeze(\n              getParamValue('x', node, tensorMap, context) as Tensor, 
axis)];\n        }\n\n        case 'Reshape': {\n          return [ops.reshape(\n              getParamValue('x', node, tensorMap, context) as Tensor,\n              getParamValue('shape', node, tensorMap, context) as number[])];\n        }\n        case 'EnsureShape': {\n          return [ops.ensureShape(\n              getParamValue('x', node, tensorMap, context) as Tensor,\n              getParamValue('shape', node, tensorMap, context) as number[])];\n        }\n        case 'MirrorPad': {\n          return [ops.mirrorPad(\n              getParamValue('x', node, tensorMap, context) as Tensor,\n              getParamValue('padding', node, tensorMap, context) as\n                  Array<[number, number]>,\n              getParamValue('mode', node, tensorMap, context) as 'reflect' |\n                  'symmetric')];\n        }\n        case 'PadV2':\n        case 'Pad': {\n          return [ops.pad(\n              getParamValue('x', node, tensorMap, context) as Tensor,\n              getParamValue('padding', node, tensorMap, context) as\n                  Array<[number, number]>,\n              getParamValue('constantValue', node, tensorMap, context) as\n                  number)];\n        }\n        case 'SpaceToBatchND': {\n          const blockShape =\n              getParamValue('blockShape', node, tensorMap, context) as number[];\n          const paddings =\n              getParamValue('paddings', node, tensorMap, context) as number[][];\n          return [ops.spaceToBatchND(\n              getParamValue('x', node, tensorMap, context) as Tensor,\n              blockShape, paddings)];\n        }\n        case 'BatchToSpaceND': {\n          const blockShape =\n              getParamValue('blockShape', node, tensorMap, context) as number[];\n          const crops =\n              getParamValue('crops', node, tensorMap, context) as number[][];\n          return [ops.batchToSpaceND(\n              getParamValue('x', node, tensorMap, context) as Tensor,\n           
   blockShape, crops)];\n        }\n        case 'DepthToSpace': {\n          const blockSize =\n              getParamValue('blockSize', node, tensorMap, context) as number;\n          const dataFormat =\n              (getParamValue('dataFormat', node, tensorMap, context) as\n               string).toUpperCase() as 'NHWC' |\n              'NCHW';\n          return [ops.depthToSpace(\n              getParamValue('x', node, tensorMap, context) as Tensor4D,\n              blockSize, dataFormat)];\n        }\n        case 'BroadcastTo': {\n          return [ops.broadcastTo(\n              getParamValue('x', node, tensorMap, context) as Tensor,\n              getParamValue('shape', node, tensorMap, context) as number[])];\n        }\n        case 'BroadcastArgs': {\n          return [ops.broadcastArgs(\n              getParamValue('s0', node, tensorMap, context) as Tensor,\n              getParamValue('s1', node, tensorMap, context) as Tensor)];\n        }\n        default:\n          throw TypeError(`Node type ${node.op} is not implemented`);\n      }\n    };\n\nexport const CATEGORY = 'transformation';\n"]}
\No newline at end of file