UNPKG

12.2 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
17// tslint:disable-next-line: no-imports-from-dist
18import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';
19import { cloneTensor, getParamValue, getTensor } from './utils';
/**
 * Executes a single graph-level node (constants, placeholders, identity-like
 * ops, shape/size/rank queries, NoOp and Print) by dispatching on `node.op`.
 *
 * @param node The graph node to execute; `node.op` selects the behavior and
 *     `node.name` is used to look up constants/placeholders.
 * @param tensorMap Map from node name to the tensors produced so far.
 * @param context Execution context forwarded to `getParamValue`/`getTensor`.
 * @param ops Namespace providing `tensor1d` and `scalar`; defaults to the
 *     tfjs-core ops bundle imported as `tfOps`.
 * @returns The node's outputs. Every case returns an array, except `Const`,
 *     which returns whatever is stored in `tensorMap[node.name]`.
 * @throws {TypeError} If `node.op` is not one of the ops handled here.
 */
export const executeOp = (node, tensorMap, context, ops = tfOps) => {
  // Every case is wrapped in its own braces so that `const` declarations are
  // scoped to that case rather than shared across the whole switch.
  switch (node.op) {
    case 'Const': {
      return tensorMap[node.name];
    }
    case 'PlaceholderWithDefault': {
      const def = getParamValue('default', node, tensorMap, context);
      // Prefer a tensor registered under the node's name; otherwise use the
      // node's 'default' param.
      return [getTensor(node.name, tensorMap, context) || def];
    }
    case 'Placeholder': {
      return [getTensor(node.name, tensorMap, context)];
    }
    case 'Identity':
    case 'StopGradient':
    case 'FakeQuantWithMinMaxVars': { // This op is currently ignored.
      const data = getParamValue('x', node, tensorMap, context);
      return [cloneTensor(data)];
    }
    case 'IdentityN': {
      return getParamValue('x', node, tensorMap, context)
          .map((t) => cloneTensor(t));
    }
    case 'Snapshot': {
      const snapshot = getParamValue('x', node, tensorMap, context);
      return [cloneTensor(snapshot)];
    }
    case 'Shape': {
      const { shape } = getParamValue('x', node, tensorMap, context);
      return [ops.tensor1d(shape, 'int32')];
    }
    case 'ShapeN': {
      // Pass 'int32' explicitly, matching 'Shape'; without it the dtype is
      // left to tensor1d's inference.
      return getParamValue('x', node, tensorMap, context)
          .map((t) => ops.tensor1d(t.shape, 'int32'));
    }
    case 'Size': {
      const { size } = getParamValue('x', node, tensorMap, context);
      return [ops.scalar(size, 'int32')];
    }
    case 'Rank': {
      const { rank } = getParamValue('x', node, tensorMap, context);
      return [ops.scalar(rank, 'int32')];
    }
    case 'NoOp': {
      return [ops.scalar(1)];
    }
    case 'Print': {
      const input = getParamValue('x', node, tensorMap, context);
      const data = getParamValue('data', node, tensorMap, context);
      const message = getParamValue('message', node, tensorMap, context);
      const summarize = getParamValue('summarize', node, tensorMap, context);
      console.warn('The graph has a tf.print() operation, ' +
          'usually used for debugging, which slows down performance.');
      console.log(message);
      for (const tensor of data) {
        // dataSync() is a synchronous read; only the first `summarize`
        // entries of each tensor are logged.
        console.log(Array.from(tensor.dataSync()).slice(0, summarize));
      }
      // The 'x' input is returned unchanged as the node's only output.
      return [input];
    }
    default: {
      throw new TypeError(`Node type ${node.op} is not implemented`);
    }
  }
};
/**
 * Category label exported alongside `executeOp` for this module.
 * NOTE(review): presumably the key used to route graph nodes to this
 * executor — confirm against the code that consumes it.
 */
export const CATEGORY = 'graph';
71//# sourceMappingURL=data:application/json;base64,
\No newline at end of file