UNPKG

15.2 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
17// tslint:disable-next-line: no-imports-from-dist
18import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';
19import { getParamValue } from './utils';
/**
 * Executes one graph node whose `op` is a shape/layout transformation.
 *
 * Every parameter is resolved through
 * `getParamValue(name, node, tensorMap, context)` and forwarded to the
 * function of the matching name on `ops`.
 *
 * @param {object} node Graph node; `node.op` selects which operation runs.
 * @param {*} tensorMap Passed through unchanged to `getParamValue`.
 * @param {*} context Passed through unchanged to `getParamValue`.
 * @param {object} [ops=tfOps] Object supplying the op implementations
 *     (`cast`, `expandDims`, `squeeze`, ...). Callers may inject their own.
 * @returns {Array} Single-element array holding the value returned by the op.
 * @throws {TypeError} If `node.op` is not one of the op names handled here.
 */
export const executeOp = (node, tensorMap, context, ops = tfOps) => {
    // Shorthand: every lookup shares the same node/tensorMap/context triple.
    const param = (name) => getParamValue(name, node, tensorMap, context);
    switch (node.op) {
        case 'Cast': {
            return [ops.cast(param('x'), param('dtype'))];
        }
        case 'ExpandDims': {
            const axis = param('axis');
            return [ops.expandDims(param('x'), axis)];
        }
        case 'Squeeze': {
            const axis = param('axis');
            return [ops.squeeze(param('x'), axis)];
        }
        case 'Reshape': {
            return [ops.reshape(param('x'), param('shape'))];
        }
        case 'EnsureShape': {
            return [ops.ensureShape(param('x'), param('shape'))];
        }
        case 'MirrorPad': {
            return [ops.mirrorPad(param('x'), param('padding'), param('mode'))];
        }
        // Both pad variants are routed to the same op.
        case 'PadV2':
        case 'Pad': {
            return [ops.pad(param('x'), param('padding'), param('constantValue'))];
        }
        case 'SpaceToBatchND': {
            const blockShape = param('blockShape');
            const paddings = param('paddings');
            return [ops.spaceToBatchND(param('x'), blockShape, paddings)];
        }
        case 'BatchToSpaceND': {
            const blockShape = param('blockShape');
            const crops = param('crops');
            return [ops.batchToSpaceND(param('x'), blockShape, crops)];
        }
        case 'DepthToSpace': {
            const blockSize = param('blockSize');
            const rawFormat = param('dataFormat');
            // A missing dataFormat must not crash on `.toUpperCase()`.
            // Forward `undefined` instead so the op can apply its own default
            // (assumes ops.depthToSpace has one -- confirm against tfjs-core).
            const dataFormat = rawFormat == null ? undefined : rawFormat.toUpperCase();
            return [ops.depthToSpace(param('x'), blockSize, dataFormat)];
        }
        case 'BroadcastTo': {
            return [ops.broadcastTo(param('x'), param('shape'))];
        }
        case 'BroadcastArgs': {
            return [ops.broadcastArgs(param('s0'), param('s1'))];
        }
        default:
            throw new TypeError(`Node type ${node.op} is not implemented`);
    }
};
/** Name of the op category this module handles. */
export const CATEGORY = 'transformation';
72//# sourceMappingURL=data:application/json;base64,
\No newline at end of file