UNPKG

19.3 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2018 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17// tslint:disable-next-line: no-imports-from-dist
18import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';
19import { getParamValue } from './utils';
/**
 * Runs one node from the "creation" op family.
 *
 * Dispatches on `node.op`, resolves that op's named parameters with
 * `getParamValue(name, node, tensorMap, context)` (only the parameters of the
 * matched op are looked up), calls the corresponding function on `ops`, and
 * wraps its return value in a one-element array.
 *
 * @param node Graph node whose `op` field selects the operation.
 * @param tensorMap Forwarded unchanged to `getParamValue`.
 * @param context Forwarded unchanged to `getParamValue`.
 * @param ops Object providing the op functions; defaults to `tfOps`.
 * @returns `[result]`, where `result` is whatever the selected `ops` function returned.
 * @throws {TypeError} If `node.op` is not one of the op names handled below.
 */
export const executeOp = (node, tensorMap, context, ops = tfOps) => {
  // Shorthand for looking up a named parameter of this particular node.
  const param = (name) => getParamValue(name, node, tensorMap, context);

  let output;
  switch (node.op) {
    case 'Fill': {
      const shape = param('shape');
      const dtype = param('dtype');
      const value = param('value');
      output = ops.fill(shape, value, dtype);
      break;
    }
    case 'LinSpace': {
      const start = param('start');
      const stop = param('stop');
      const num = param('num');
      output = ops.linspace(start, stop, num);
      break;
    }
    case 'Multinomial': {
      const logits = param('logits');
      const numSamples = param('numSamples');
      const seed = param('seed');
      output = ops.multinomial(logits, numSamples, seed);
      break;
    }
    case 'OneHot': {
      const indices = param('indices');
      const depth = param('depth');
      const onValue = param('onValue');
      const offValue = param('offValue');
      const dtype = param('dtype');
      output = ops.oneHot(indices, depth, onValue, offValue, dtype);
      break;
    }
    case 'Ones':
      output = ops.ones(param('shape'), param('dtype'));
      break;
    case 'OnesLike':
      output = ops.onesLike(param('x'));
      break;
    case 'RandomStandardNormal':
      output = ops.randomStandardNormal(
          param('shape'), param('dtype'), param('seed'));
      break;
    case 'RandomUniform':
      // NOTE(review): unlike the other random ops in this switch, no 'seed'
      // parameter is forwarded here — confirm whether that is intentional.
      output = ops.randomUniform(
          param('shape'), param('minval'), param('maxval'), param('dtype'));
      break;
    case 'RandomUniformInt':
      output = ops.randomUniformInt(
          param('shape'), param('minval'), param('maxval'), param('seed'));
      break;
    case 'Range': {
      const start = param('start');
      const stop = param('stop');
      const step = param('step');
      output = ops.range(start, stop, step, param('dtype'));
      break;
    }
    case 'TruncatedNormal': {
      const shape = param('shape');
      const mean = param('mean');
      const stdDev = param('stdDev');
      const seed = param('seed');
      // dtype is read last, after seed, but passed before it.
      output = ops.truncatedNormal(shape, mean, stdDev, param('dtype'), seed);
      break;
    }
    case 'Zeros':
      output = ops.zeros(param('shape'), param('dtype'));
      break;
    case 'ZerosLike':
      output = ops.zerosLike(param('x'));
      break;
    default:
      throw TypeError(`Node type ${node.op} is not implemented`);
  }
  return [output];
};
/** Category name exported alongside `executeOp` by this executor module. */
export const CATEGORY = 'creation';
89//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"creation_executor.js","sourceRoot":"","sources":["../../../../../../../tfjs-converter/src/operations/executors/creation_executor.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAGH,iDAAiD;AACjD,OAAO,KAAK,KAAK,MAAM,kDAAkD,CAAC;AAM1E,OAAO,EAAC,aAAa,EAAC,MAAM,SAAS,CAAC;AAEtC,MAAM,CAAC,MAAM,SAAS,GAClB,CAAC,IAAU,EAAE,SAA0B,EAAE,OAAyB,EACjE,GAAG,GAAG,KAAK,EAAY,EAAE;IACxB,QAAQ,IAAI,CAAC,EAAE,EAAE;QACf,KAAK,MAAM,CAAC,CAAC;YACX,MAAM,KAAK,GACP,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YACjE,MAAM,KAAK,GACP,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YACjE,MAAM,KAAK,GACP,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YAC/D,OAAO,CAAC,GAAG,CAAC,IAAI,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC;SACxC;QACD,KAAK,UAAU,CAAC,CAAC;YACf,MAAM,KAAK,GACP,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YAC/D,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YAC9D,MAAM,GAAG,GAAG,aAAa,CAAC,KAAK,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YACrE,OAAO,CAAC,GAAG,CAAC,QAAQ,CAAC,KAAK,EAAE,IAAI,EAAE,GAAG,CAAC,CAAC,CAAC;SACzC;QACD,KAAK,aAAa,CAAC,CAAC;YAClB,MAAM,MAAM,GACR,aAAa,CAAC,QAAQ,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YAClE,MAAM,UAAU,GACZ,aAAa,CAAC,YAAY,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YACpE,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YAC9D,OAAO,CAAC,GAAG,CAAC,WAAW,CAAC,MAAM,EAAE,UAAU,EAAE,IAAI,CAAC,CAAC,CAAC;SACpD;QACD,KAAK,QAAQ,CAAC,CAAC;YACb,MAAM,OAAO,GACT,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YACnE,MAAM,KAAK,GACP,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YAC/D,MAAM,OAAO,GACT,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YACjE,MAAM,QAAQ,GACV,aAAa,CAAC,UAAU,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YAClE,MAAM,KAAK,GACP,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YACjE,OAAO,CAAC,GAAG,CAAC,MAAM,CAAC,OAAO,EAAE,KAAK,EAAE,OAAO,EAAE,QAAQ,EAAE,KAAK,CAAC,CAAC,CAAC;SAC/D;QACD,KAAK,MAAM,CAAC,
CAAC;YACX,OAAO,CAAC,GAAG,CAAC,IAAI,CACZ,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,EAC5D,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC,CAAC,CAAC;SACpE;QACD,KAAK,UAAU,CAAC,CAAC;YACf,OAAO,CAAC,GAAG,CAAC,QAAQ,CAChB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC,CAAC,CAAC;SAC9D;QACD,KAAK,sBAAsB,CAAC,CAAC;YAC3B,OAAO,CAAC,GAAG,CAAC,oBAAoB,CAC5B,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,EAC5D,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CACpC,EACX,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC,CAAC,CAAC;SACjE;QACD,KAAK,eAAe,CAAC,CAAC;YACpB,OAAO,CAAC,GAAG,CAAC,aAAa;gBACrB,kCAAkC;gBAClC,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAQ,EACvD,aAAa,CAAC,QAAQ,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAC3D,aAAa,CAAC,QAAQ,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAC3D,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC,CAAC,CAAC;SACpE;QACD,KAAK,kBAAkB,CAAC,CAAC;YACvB,OAAO,CAAC,GAAG,CAAC,gBAAgB,CACxB,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,EAC5D,aAAa,CAAC,QAAQ,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAC3D,aAAa,CAAC,QAAQ,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAC3D,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC,CAAC,CAAC;SACjE;QACD,KAAK,OAAO,CAAC,CAAC;YACZ,MAAM,KAAK,GACP,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YAC/D,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YAC9D,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YAC9D,OAAO,CAAC,GAAG,CAAC,KAAK,CACb,KAAK,EAAE,IAAI,EAAE,IAAI,EACjB,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CACpC,CAAC,CAAC,CAAC;SACnB;QACD,KAAK,iBAAiB,CAAC,CAAC;YACtB,MAAM,KAAK,GACP,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YACjE,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YAC9D,MAAM,MAAM,GACR,aAAa,CAAC,QAAQ,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YAChE,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YAC9D,OAAO,CAAC,GAAG,CAAC,eAAe,CACvB,KAAK,EAAE,IAAI,EAAE,MAAM,EACnB,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,O
AAO,CACpC,EACX,IAAI,CAAC,CAAC,CAAC;SACZ;QACD,KAAK,OAAO,CAAC,CAAC;YACZ,OAAO,CAAC,GAAG,CAAC,KAAK,CACb,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,EAC5D,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC,CAAC,CAAC;SACpE;QACD,KAAK,WAAW,CAAC,CAAC;YAChB,OAAO,CAAC,GAAG,CAAC,SAAS,CACjB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC,CAAC,CAAC;SAC9D;QACD;YACE,MAAM,SAAS,CAAC,aAAa,IAAI,CAAC,EAAE,qBAAqB,CAAC,CAAC;KAC9D;AACH,CAAC,CAAC;AAEN,MAAM,CAAC,MAAM,QAAQ,GAAG,UAAU,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {DataType, Tensor, Tensor1D} from '@tensorflow/tfjs-core';\n// tslint:disable-next-line: no-imports-from-dist\nimport * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';\n\nimport {NamedTensorsMap} from '../../data/types';\nimport {ExecutionContext} from '../../executor/execution_context';\nimport {InternalOpExecutor, Node} from '../types';\n\nimport {getParamValue} from './utils';\n\nexport const executeOp: InternalOpExecutor =\n    (node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext,\n     ops = tfOps): Tensor[] => {\n      switch (node.op) {\n        case 'Fill': {\n          const shape =\n              getParamValue('shape', node, tensorMap, context) as number[];\n          const dtype =\n              
getParamValue('dtype', node, tensorMap, context) as DataType;\n          const value =\n              getParamValue('value', node, tensorMap, context) as number;\n          return [ops.fill(shape, value, dtype)];\n        }\n        case 'LinSpace': {\n          const start =\n              getParamValue('start', node, tensorMap, context) as number;\n          const stop =\n              getParamValue('stop', node, tensorMap, context) as number;\n          const num = getParamValue('num', node, tensorMap, context) as number;\n          return [ops.linspace(start, stop, num)];\n        }\n        case 'Multinomial': {\n          const logits =\n              getParamValue('logits', node, tensorMap, context) as Tensor1D;\n          const numSamples =\n              getParamValue('numSamples', node, tensorMap, context) as number;\n          const seed =\n              getParamValue('seed', node, tensorMap, context) as number;\n          return [ops.multinomial(logits, numSamples, seed)];\n        }\n        case 'OneHot': {\n          const indices =\n              getParamValue('indices', node, tensorMap, context) as Tensor1D;\n          const depth =\n              getParamValue('depth', node, tensorMap, context) as number;\n          const onValue =\n              getParamValue('onValue', node, tensorMap, context) as number;\n          const offValue =\n              getParamValue('offValue', node, tensorMap, context) as number;\n          const dtype =\n              getParamValue('dtype', node, tensorMap, context) as DataType;\n          return [ops.oneHot(indices, depth, onValue, offValue, dtype)];\n        }\n        case 'Ones': {\n          return [ops.ones(\n              getParamValue('shape', node, tensorMap, context) as number[],\n              getParamValue('dtype', node, tensorMap, context) as DataType)];\n        }\n        case 'OnesLike': {\n          return [ops.onesLike(\n              getParamValue('x', node, tensorMap, context) as Tensor)];\n     
   }\n        case 'RandomStandardNormal': {\n          return [ops.randomStandardNormal(\n              getParamValue('shape', node, tensorMap, context) as number[],\n              getParamValue('dtype', node, tensorMap, context) as 'float32' |\n                  'int32',\n              getParamValue('seed', node, tensorMap, context) as number)];\n        }\n        case 'RandomUniform': {\n          return [ops.randomUniform(\n              // tslint:disable-next-line:no-any\n              getParamValue('shape', node, tensorMap, context) as any,\n              getParamValue('minval', node, tensorMap, context) as number,\n              getParamValue('maxval', node, tensorMap, context) as number,\n              getParamValue('dtype', node, tensorMap, context) as DataType)];\n        }\n        case 'RandomUniformInt': {\n          return [ops.randomUniformInt(\n              getParamValue('shape', node, tensorMap, context) as number[],\n              getParamValue('minval', node, tensorMap, context) as number,\n              getParamValue('maxval', node, tensorMap, context) as number,\n              getParamValue('seed', node, tensorMap, context) as number)];\n        }\n        case 'Range': {\n          const start =\n              getParamValue('start', node, tensorMap, context) as number;\n          const stop =\n              getParamValue('stop', node, tensorMap, context) as number;\n          const step =\n              getParamValue('step', node, tensorMap, context) as number;\n          return [ops.range(\n              start, stop, step,\n              getParamValue('dtype', node, tensorMap, context) as 'float32' |\n                  'int32')];\n        }\n        case 'TruncatedNormal': {\n          const shape =\n              getParamValue('shape', node, tensorMap, context) as number[];\n          const mean =\n              getParamValue('mean', node, tensorMap, context) as number;\n          const stdDev =\n              getParamValue('stdDev', node, 
tensorMap, context) as number;\n          const seed =\n              getParamValue('seed', node, tensorMap, context) as number;\n          return [ops.truncatedNormal(\n              shape, mean, stdDev,\n              getParamValue('dtype', node, tensorMap, context) as 'float32' |\n                  'int32',\n              seed)];\n        }\n        case 'Zeros': {\n          return [ops.zeros(\n              getParamValue('shape', node, tensorMap, context) as number[],\n              getParamValue('dtype', node, tensorMap, context) as DataType)];\n        }\n        case 'ZerosLike': {\n          return [ops.zerosLike(\n              getParamValue('x', node, tensorMap, context) as Tensor)];\n        }\n        default:\n          throw TypeError(`Node type ${node.op} is not implemented`);\n      }\n    };\n\nexport const CATEGORY = 'creation';\n"]}
\No newline at end of file