1 | /**
|
2 | * @license
|
3 | * Copyright 2018 Google LLC. All Rights Reserved.
|
4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | * you may not use this file except in compliance with the License.
|
6 | * You may obtain a copy of the License at
|
7 | *
|
8 | * http://www.apache.org/licenses/LICENSE-2.0
|
9 | *
|
10 | * Unless required by applicable law or agreed to in writing, software
|
11 | * distributed under the License is distributed on an "AS IS" BASIS,
|
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 | * See the License for the specific language governing permissions and
|
14 | * limitations under the License.
|
15 | * =============================================================================
|
16 | */
|
17 | // tslint:disable-next-line: no-imports-from-dist
|
18 | import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';
|
19 | import { cloneTensor, getParamValue, getTensor } from './utils';
|
/**
 * Executes a single node belonging to the 'graph' op category: constants,
 * placeholders, identity-style pass-throughs, shape queries, NoOp and Print.
 *
 * @param {Object} node - Graph node; `node.op` selects the operation and
 *     `node.name` is used to look up tensors in `tensorMap`.
 * @param {Object} tensorMap - Map from node name to the tensors produced for it.
 * @param {Object} context - Execution context, forwarded to the param helpers.
 * @param {Object} [ops=tfOps] - Op implementations (`tensor1d`, `scalar`);
 *     injectable so callers/tests can substitute them.
 * @returns {Array} The output tensor(s) of the node.
 * @throws {TypeError} If `node.op` is not one of the ops handled here.
 */
export const executeOp = (node, tensorMap, context, ops = tfOps) => {
  switch (node.op) {
    case 'Const': {
      return tensorMap[node.name];
    }
    case 'PlaceholderWithDefault': {
      const def = getParamValue('default', node, tensorMap, context);
      // A fed value for this node wins; otherwise fall back to the default input.
      return [getTensor(node.name, tensorMap, context) || def];
    }
    case 'Placeholder':
      return [getTensor(node.name, tensorMap, context)];
    case 'Identity':
    case 'StopGradient':
    case 'FakeQuantWithMinMaxVars': { // This op is currently ignored.
      const data = getParamValue('x', node, tensorMap, context);
      return [cloneTensor(data)];
    }
    case 'IdentityN':
      return getParamValue('x', node, tensorMap, context)
          .map((t) => cloneTensor(t));
    case 'Snapshot': {
      const snapshot = getParamValue('x', node, tensorMap, context);
      return [cloneTensor(snapshot)];
    }
    case 'Shape':
      return [ops.tensor1d(
          getParamValue('x', node, tensorMap, context).shape, 'int32')];
    case 'ShapeN':
      // Shapes are integer-valued: emit int32 exactly like 'Shape' does instead
      // of relying on tensor1d's inferred dtype for plain numbers.
      return getParamValue('x', node, tensorMap, context)
          .map((t) => ops.tensor1d(t.shape, 'int32'));
    case 'Size':
      return [ops.scalar(
          getParamValue('x', node, tensorMap, context).size, 'int32')];
    case 'Rank':
      return [ops.scalar(
          getParamValue('x', node, tensorMap, context).rank, 'int32')];
    case 'NoOp':
      // NoOp has no meaningful result; a dummy scalar is returned so the node
      // still yields an output (presumably for executor bookkeeping — confirm).
      return [ops.scalar(1)];
    case 'Print': {
      const input = getParamValue('x', node, tensorMap, context);
      const data = getParamValue('data', node, tensorMap, context);
      const message = getParamValue('message', node, tensorMap, context);
      const summarize = getParamValue('summarize', node, tensorMap, context);
      console.warn('The graph has a tf.print() operation, ' +
          'usually used for debugging, which slows down performance.');
      console.log(message);
      // Log at most `summarize` values of each data tensor; when `summarize`
      // is undefined, slice(0, undefined) keeps every value.
      for (const tensor of data) {
        console.log(
            Array.prototype.slice.call(tensor.dataSync()).slice(0, summarize));
      }
      // Print is a pass-through: its output is its first input.
      return [input];
    }
    default:
      throw new TypeError(`Node type ${node.op} is not implemented`);
  }
};
|
/**
 * Category tag for this executor module; presumably the op dispatcher uses it
 * to route 'graph' ops to `executeOp` — confirm against the caller.
 */
export const CATEGORY = "graph";
|
71 | //# sourceMappingURL=data:application/json;base64, |
\ | No newline at end of file |