UNPKG

12.2 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
17// tslint:disable-next-line: no-imports-from-dist
18import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';
19import { cloneTensor, getParamValue, getTensor } from './utils';
/**
 * Executes a node belonging to the 'graph' op category: constants,
 * placeholders, identity-style pass-throughs, shape queries, NoOp and Print.
 *
 * @param node The node to execute; `node.op` selects the branch and
 *     `node.name` keys lookups into `tensorMap`.
 * @param tensorMap Map from node name to the tensors stored for that node.
 * @param context Execution context, forwarded to the param/tensor helpers.
 * @param ops Op implementations used to build new tensors; defaults to the
 *     tfjs-core converter ops.
 * @returns The node's output tensors as an array.
 * @throws TypeError when `node.op` is not handled by this executor.
 */
export const executeOp = (node, tensorMap, context, ops = tfOps) => {
  switch (node.op) {
    case 'Const': {
      // Return the tensors already stored under this node's name.
      return tensorMap[node.name];
    }
    case 'PlaceholderWithDefault': {
      const def = getParamValue('default', node, tensorMap, context);
      // Use the tensor registered under this node's name when there is one;
      // otherwise fall back to the node's 'default' input.
      return [getTensor(node.name, tensorMap, context) || def];
    }
    case 'Placeholder': {
      return [getTensor(node.name, tensorMap, context)];
    }
    case 'Identity':
    case 'StopGradient':
    case 'FakeQuantWithMinMaxVars': { // This op is currently ignored.
      const data = getParamValue('x', node, tensorMap, context);
      return [cloneTensor(data)];
    }
    case 'IdentityN': {
      return getParamValue('x', node, tensorMap, context)
          .map((t) => cloneTensor(t));
    }
    case 'Snapshot': {
      const snapshot = getParamValue('x', node, tensorMap, context);
      return [cloneTensor(snapshot)];
    }
    case 'Shape': {
      const x = getParamValue('x', node, tensorMap, context);
      return [ops.tensor1d(x.shape, 'int32')];
    }
    case 'ShapeN': {
      // Emit int32 shapes, matching 'Shape' above; without an explicit dtype
      // tensor1d would infer float32 from the plain number[] shape.
      return getParamValue('x', node, tensorMap, context)
          .map((t) => ops.tensor1d(t.shape, 'int32'));
    }
    case 'Size': {
      const x = getParamValue('x', node, tensorMap, context);
      return [ops.scalar(x.size, 'int32')];
    }
    case 'Rank': {
      const x = getParamValue('x', node, tensorMap, context);
      return [ops.scalar(x.rank, 'int32')];
    }
    case 'NoOp': {
      return [ops.scalar(1)];
    }
    case 'Print': {
      const input = getParamValue('x', node, tensorMap, context);
      const data = getParamValue('data', node, tensorMap, context);
      const message = getParamValue('message', node, tensorMap, context);
      const summarize = getParamValue('summarize', node, tensorMap, context);
      console.warn('The graph has a tf.print() operation, ' +
          'usually used for debugging, which slows down performance.');
      console.log(message);
      // dataSync() blocks until the values are available; print at most
      // `summarize` leading values of each data tensor.
      for (const tensor of data) {
        console.log(Array.prototype.slice.call(tensor.dataSync())
            .slice(0, summarize));
      }
      // Print is a pass-through: the 'x' input is returned unchanged.
      return [input];
    }
    default:
      throw new TypeError(`Node type ${node.op} is not implemented`);
  }
};

/** Op category name for the ops handled by `executeOp`. */
export const CATEGORY = 'graph';
71//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"graph_executor.js","sourceRoot":"","sources":["../../../../../../../tfjs-converter/src/operations/executors/graph_executor.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAGH,iDAAiD;AACjD,OAAO,KAAK,KAAK,MAAM,kDAAkD,CAAC;AAM1E,OAAO,EAAC,WAAW,EAAE,aAAa,EAAE,SAAS,EAAC,MAAM,SAAS,CAAC;AAE9D,MAAM,CAAC,MAAM,SAAS,GAClB,CAAC,IAAU,EAAE,SAA0B,EACtC,OAAyB,EAAE,GAAG,GAAG,KAAK,EAAY,EAAE;IACnD,QAAQ,IAAI,CAAC,EAAE,EAAE;QACf,KAAK,OAAO,CAAC,CAAC;YACZ,OAAO,SAAS,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;SAC7B;QACD,KAAK,wBAAwB;YAC3B,MAAM,GAAG,GACL,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YACjE,OAAO,CAAC,SAAS,CAAC,IAAI,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,IAAI,GAAG,CAAC,CAAC;QAC3D,KAAK,aAAa;YAChB,OAAO,CAAC,SAAS,CAAC,IAAI,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC,CAAC;QACpD,KAAK,UAAU,CAAC;QAChB,KAAK,cAAc,CAAC;QACpB,KAAK,yBAAyB,CAAC,CAAC,EAAG,gCAAgC;YACjE,MAAM,IAAI,GAAG,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YACpE,OAAO,CAAC,WAAW,CAAC,IAAI,CAAC,CAAC,CAAC;SAC5B;QACD,KAAK,WAAW;YACd,OAAQ,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAc;iBAC5D,GAAG,CAAC,CAAC,CAAS,EAAE,EAAE,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAC1C,KAAK,UAAU;YACb,MAAM,QAAQ,GACT,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC;YAC7D,OAAO,CAAC,WAAW,CAAC,QAAQ,CAAC,CAAC,CAAC;QACjC,KAAK,OAAO;YACV,OAAO,CAAC,GAAG,CAAC,QAAQ,CACf,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC,KAAK,EAC9D,OAAO,CAAC,CAAC,CAAC;QAChB,KAAK,QAAQ;YACX,OAAQ,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAc;iBAC5D,GAAG,CAAC,CAAC,CAAS,EAAE,EAAE,CAAC,GAAG,CAAC,QAAQ,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC;QACjD,KAAK,MAAM;YACT,OAAO,CAAC,GAAG,CAAC,MAAM,CACb,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC,IAAI,EAC7D,OAAO,CAAC,CAAC,CAAC;QAChB,KAAK,MAAM;YACT,OAAO,CAAC,GAAG,CAAC,MAAM,CACb,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC,IAAI,EAC7D,OAAO,CAAC,CAAC,CAAC;QAChB,KAAK,MAAM;YACT,OAAO,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC;QACzB,KAAK,OAAO;YACV,MAA
M,KAAK,GAAG,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YACrE,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YAChE,MAAM,OAAO,GACT,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YACjE,MAAM,SAAS,GACX,aAAa,CAAC,WAAW,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YACnE,OAAO,CAAC,IAAI,CACR,uCAAuC;gBACvC,2DAA2D,CAAC,CAAC;YACjE,OAAO,CAAC,GAAG,CAAC,OAAO,CAAC,CAAC;YACrB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE;gBACpC,OAAO,CAAC,GAAG,CAAC,KAAK,CAAC,SAAS,CAAC,KAAK,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,QAAQ,EAAE,CAAC;qBACzC,KAAK,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;aACvC;YACD,OAAO,CAAC,KAAK,CAAC,CAAC;QAEjB;YACE,MAAM,SAAS,CAAC,aAAa,IAAI,CAAC,EAAE,qBAAqB,CAAC,CAAC;KAC9D;AACH,CAAC,CAAC;AAEN,MAAM,CAAC,MAAM,QAAQ,GAAG,OAAO,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor} from '@tensorflow/tfjs-core';\n// tslint:disable-next-line: no-imports-from-dist\nimport * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';\n\nimport {NamedTensorsMap} from '../../data/types';\nimport {ExecutionContext} from '../../executor/execution_context';\nimport {InternalOpExecutor, Node} from '../types';\n\nimport {cloneTensor, getParamValue, getTensor} from './utils';\n\nexport const executeOp: InternalOpExecutor =\n    
(node: Node, tensorMap: NamedTensorsMap,\n     context: ExecutionContext, ops = tfOps): Tensor[] => {\n      switch (node.op) {\n        case 'Const': {\n          return tensorMap[node.name];\n        }\n        case 'PlaceholderWithDefault':\n          const def =\n              getParamValue('default', node, tensorMap, context) as Tensor;\n          return [getTensor(node.name, tensorMap, context) || def];\n        case 'Placeholder':\n          return [getTensor(node.name, tensorMap, context)];\n        case 'Identity':\n        case 'StopGradient':\n        case 'FakeQuantWithMinMaxVars': {  // This op is currently ignored.\n          const data = getParamValue('x', node, tensorMap, context) as Tensor;\n          return [cloneTensor(data)];\n        }\n        case 'IdentityN':\n          return (getParamValue('x', node, tensorMap, context) as Tensor[])\n              .map((t: Tensor) => cloneTensor(t));\n        case 'Snapshot':\n          const snapshot =\n              (getParamValue('x', node, tensorMap, context) as Tensor);\n          return [cloneTensor(snapshot)];\n        case 'Shape':\n          return [ops.tensor1d(\n              (getParamValue('x', node, tensorMap, context) as Tensor).shape,\n              'int32')];\n        case 'ShapeN':\n          return (getParamValue('x', node, tensorMap, context) as Tensor[])\n              .map((t: Tensor) => ops.tensor1d(t.shape));\n        case 'Size':\n          return [ops.scalar(\n              (getParamValue('x', node, tensorMap, context) as Tensor).size,\n              'int32')];\n        case 'Rank':\n          return [ops.scalar(\n              (getParamValue('x', node, tensorMap, context) as Tensor).rank,\n              'int32')];\n        case 'NoOp':\n          return [ops.scalar(1)];\n        case 'Print':\n          const input = getParamValue('x', node, tensorMap, context) as Tensor;\n          const data =\n              getParamValue('data', node, tensorMap, context) as Tensor[];\n          
const message =\n              getParamValue('message', node, tensorMap, context) as string;\n          const summarize =\n              getParamValue('summarize', node, tensorMap, context) as number;\n          console.warn(\n              'The graph has a tf.print() operation,' +\n              'usually used for debugging, which slows down performance.');\n          console.log(message);\n          for (let i = 0; i < data.length; i++) {\n            console.log(Array.prototype.slice.call(data[i].dataSync())\n                            .slice(0, summarize));\n          }\n          return [input];\n\n        default:\n          throw TypeError(`Node type ${node.op} is not implemented`);\n      }\n    };\n\nexport const CATEGORY = 'graph';\n"]}
\No newline at end of file