1 | /**
|
2 | * @license
|
3 | * Copyright 2019 Google LLC. All Rights Reserved.
|
4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | * you may not use this file except in compliance with the License.
|
6 | * You may obtain a copy of the License at
|
7 | *
|
8 | * http://www.apache.org/licenses/LICENSE-2.0
|
9 | *
|
10 | * Unless required by applicable law or agreed to in writing, software
|
11 | * distributed under the License is distributed on an "AS IS" BASIS,
|
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 | * See the License for the specific language governing permissions and
|
14 | * limitations under the License.
|
15 | * =============================================================================
|
16 | */
|
17 | import { getTensor } from '../executors/utils';
|
18 | import { getBoolArrayParam, getBoolParam, getDtypeArrayParam, getDtypeParam, getNumberParam, getNumericArrayParam, getStringArrayParam, getStringParam, getTensorShapeArrayParam, getTensorShapeParam } from '../operation_mapper';
|
/**
 * Helper class that resolves a model-graph node's input tensors and decodes
 * its raw attribute values.
 *
 * Both lookups happen eagerly in the constructor; the results are exposed as
 * `inputs` (one entry per name in `node.inputNames`) and `attrs` (one entry per
 * key in `node.rawAttrs`).
 */
export class NodeValueImpl {
  /**
   * @param node Graph node. Its `inputNames` are resolved to tensors and its
   *     `rawAttrs` (when present) are decoded into `attrs`.
   * @param tensorMap Tensor map passed to `getTensor` for every lookup.
   * @param context Execution context passed to `getTensor` for every lookup.
   */
  constructor(node, tensorMap, context) {
    this.node = node;
    this.tensorMap = tensorMap;
    this.context = context;
    // `inputs` is assigned before `attrs` to keep the same property order the
    // compiled field initializers produced.
    this.inputs = node.inputNames.map(name => this.getInput(name));
    this.attrs = {};
    if (node.rawAttrs != null) {
      this.attrs = Object.keys(node.rawAttrs).reduce((attrs, key) => {
        attrs[key] = this.getAttr(key);
        return attrs;
      }, {});
    }
  }

  /**
   * Resolves one input of this node.
   * @param name Input name, as listed in `node.inputNames`.
   * @returns The result of `getTensor` for `name` in this node's tensor map and
   *     execution context.
   */
  getInput(name) {
    return getTensor(name, this.tensorMap, this.context);
  }

  /**
   * Decodes one attribute of this node, dispatching on which field of the raw
   * attribute value is populated.
   * @param name Attribute name (a key of `node.rawAttrs`).
   * @param defaultValue Forwarded to the typed `get*Param` helpers, and returned
   *     directly when the attribute is absent or of an unhandled kind (for
   *     example a `list` with none of its recognised fields set).
   * @returns The decoded attribute value, or `defaultValue`.
   */
  getAttr(name, defaultValue) {
    const rawAttrs = this.node.rawAttrs;
    const value = rawAttrs != null ? rawAttrs[name] : undefined;
    // A missing/null entry used to crash with a TypeError on `value.tensor`;
    // treat it like any other unrecognised attribute instead.
    if (value == null) {
      return defaultValue;
    }
    if (value.tensor != null) {
      // Tensor-valued attributes are looked up through the tensor map using
      // the attribute name itself, the same way getInput resolves inputs.
      return getTensor(name, this.tensorMap, this.context);
    }
    if (value.i != null || value.f != null) {
      return getNumberParam(rawAttrs, name, defaultValue);
    }
    if (value.s != null) {
      return getStringParam(rawAttrs, name, defaultValue);
    }
    if (value.b != null) {
      return getBoolParam(rawAttrs, name, defaultValue);
    }
    if (value.shape != null) {
      return getTensorShapeParam(rawAttrs, name, defaultValue);
    }
    if (value.type != null) {
      return getDtypeParam(rawAttrs, name, defaultValue);
    }
    const list = value.list;
    if (list != null) {
      if (list.i != null || list.f != null) {
        return getNumericArrayParam(rawAttrs, name, defaultValue);
      }
      if (list.s != null) {
        return getStringArrayParam(rawAttrs, name, defaultValue);
      }
      if (list.shape != null) {
        return getTensorShapeArrayParam(rawAttrs, name, defaultValue);
      }
      if (list.b != null) {
        return getBoolArrayParam(rawAttrs, name, defaultValue);
      }
      if (list.type != null) {
        return getDtypeArrayParam(rawAttrs, name, defaultValue);
      }
    }
    return defaultValue;
  }
}
|
89 | //# sourceMappingURL=data:application/json;base64,{"version":3,"file":"node_value_impl.js","sourceRoot":"","sources":["../../../../../../../tfjs-converter/src/operations/custom_op/node_value_impl.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAMH,OAAO,EAAC,SAAS,EAAC,MAAM,oBAAoB,CAAC;AAC7C,OAAO,EAAC,iBAAiB,EAAE,YAAY,EAAE,kBAAkB,EAAE,aAAa,EAAE,cAAc,EAAE,oBAAoB,EAAE,mBAAmB,EAAE,cAAc,EAAE,wBAAwB,EAAE,mBAAmB,EAAC,MAAM,qBAAqB,CAAC;AAGjO;;GAEG;AACH,MAAM,OAAO,aAAa;IAGxB,YACY,IAAU,EAAU,SAA0B,EAC9C,OAAyB;QADzB,SAAI,GAAJ,IAAI,CAAM;QAAU,cAAS,GAAT,SAAS,CAAiB;QAC9C,YAAO,GAAP,OAAO,CAAkB;QAJrB,WAAM,GAAa,EAAE,CAAC;QACtB,UAAK,GAA+B,EAAE,CAAC;QAIrD,IAAI,CAAC,MAAM,GAAG,IAAI,CAAC,UAAU,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC,IAAI,CAAC,QAAQ,CAAC,IAAI,CAAC,CAAC,CAAC;QAC/D,IAAI,IAAI,CAAC,QAAQ,IAAI,IAAI,EAAE;YACzB,IAAI,CAAC,KAAK,GAAG,MAAM,CAAC,IAAI,CAAC,IAAI,CAAC,QAAQ,CAAC;iBACrB,MAAM,CAAC,CAAC,KAAiC,EAAE,GAAG,EAAE,EAAE;gBACjD,KAAK,CAAC,GAAG,CAAC,GAAG,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC;gBAC/B,OAAO,KAAK,CAAC;YACf,CAAC,EAAE,EAAE,CAAC,CAAC;SACzB;IACH,CAAC;IAED;;;OAGG;IACK,QAAQ,CAAC,IAAY;QAC3B,OAAO,SAAS,CAAC,IAAI,EAAE,IAAI,CAAC,SAAS,EAAE,IAAI,CAAC,OAAO,CAAC,CAAC;IACvD,CAAC;IAED;;;OAGG;IACK,OAAO,CAAC,IAAY,EAAE,YAAwB;QACpD,MAAM,KAAK,GAAG,IAAI,CAAC,IAAI,CAAC,QAAQ,CAAC,IAAI,CAAC,CAAC;QACvC,IAAI,KAAK,CAAC,MAAM,IAAI,IAAI,EAAE;YACxB,OAAO,SAAS,CAAC,IAAI,EAAE,IAAI,CAAC,SAAS,EAAE,IAAI,CAAC,OAAO,CAAC,CAAC;SACtD;QACD,IAAI,KAAK,CAAC,CAAC,IAAI,IAAI,IAAI,KAAK,CAAC,CAAC,IAAI,IAAI,EAAE;YACtC,OAAO,cAAc,CAAC,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAAsB,CAAC,CAAC;SACzE;QACD,IAAI,KAAK,CAAC,CAAC,IAAI,IAAI,EAAE;YACnB,OAAO,cAAc,CAAC,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAAsB,CAAC,CAAC;SACzE;QACD,IAAI,KAAK,CAAC,CAAC,IAAI,IAAI,EAAE;YACnB,OAAO,YAAY,CAAC,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAAuB,CAAC,CAAC;SACxE;QACD,IAAI,KAAK,CAAC,KAAK,IAAI,IAAI,EAAE;YACvB,OAAO,mBAAmB,CACtB,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAAwB,CAAC,CAAC;SACzD;QACD,IAAI,KAAK,CAAC,IAAI,IAAI,IAAI,EAAE;YACtB,OAAO,aAAa,CAAC,IAAI,CAAC,IAAI
,CAAC,QAAQ,EAAE,IAAI,EAAE,YAAwB,CAAC,CAAC;SAC1E;QACD,IAAI,KAAK,CAAC,IAAI,IAAI,IAAI,EAAE;YACtB,IAAI,KAAK,CAAC,IAAI,CAAC,CAAC,IAAI,IAAI,IAAI,KAAK,CAAC,IAAI,CAAC,CAAC,IAAI,IAAI,EAAE;gBAChD,OAAO,oBAAoB,CACvB,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAAwB,CAAC,CAAC;aACzD;YACD,IAAI,KAAK,CAAC,IAAI,CAAC,CAAC,IAAI,IAAI,EAAE;gBACxB,OAAO,mBAAmB,CACtB,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAAwB,CAAC,CAAC;aACzD;YACD,IAAI,KAAK,CAAC,IAAI,CAAC,KAAK,IAAI,IAAI,EAAE;gBAC5B,OAAO,wBAAwB,CAC3B,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAA0B,CAAC,CAAC;aAC3D;YACD,IAAI,KAAK,CAAC,IAAI,CAAC,CAAC,IAAI,IAAI,EAAE;gBACxB,OAAO,iBAAiB,CACpB,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAAyB,CAAC,CAAC;aAC1D;YACD,IAAI,KAAK,CAAC,IAAI,CAAC,IAAI,IAAI,IAAI,EAAE;gBAC3B,OAAO,kBAAkB,CACrB,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAA0B,CAAC,CAAC;aAC3D;SACF;QAED,OAAO,YAAY,CAAC;IACtB,CAAC;CACF","sourcesContent":["/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {DataType, Tensor} from '@tensorflow/tfjs-core';\n\nimport {NamedTensorsMap} from '../../data/types';\nimport {ExecutionContext} from '../../executor/execution_context';\nimport {getTensor} from '../executors/utils';\nimport {getBoolArrayParam, getBoolParam, getDtypeArrayParam, getDtypeParam, getNumberParam, getNumericArrayParam, getStringArrayParam, getStringParam, 
getTensorShapeArrayParam, getTensorShapeParam} from '../operation_mapper';\nimport {GraphNode, Node, ValueType} from '../types';\n\n/**\n * Helper class for lookup inputs and params for nodes in the model graph.\n */\nexport class NodeValueImpl implements GraphNode {\n  public readonly inputs: Tensor[] = [];\n  public readonly attrs: {[key: string]: ValueType} = {};\n  constructor(\n      private node: Node, private tensorMap: NamedTensorsMap,\n      private context: ExecutionContext) {\n    this.inputs = node.inputNames.map(name => this.getInput(name));\n    if (node.rawAttrs != null) {\n      this.attrs = Object.keys(node.rawAttrs)\n                       .reduce((attrs: {[key: string]: ValueType}, key) => {\n                         attrs[key] = this.getAttr(key);\n                         return attrs;\n                       }, {});\n    }\n  }\n\n  /**\n   * Return the value of the attribute or input param.\n   * @param name String: name of attribute or input param.\n   */\n  private getInput(name: string): Tensor {\n    return getTensor(name, this.tensorMap, this.context);\n  }\n\n  /**\n   * Return the value of the attribute or input param.\n   * @param name String: name of attribute or input param.\n   */\n  private getAttr(name: string, defaultValue?: ValueType): ValueType {\n    const value = this.node.rawAttrs[name];\n    if (value.tensor != null) {\n      return getTensor(name, this.tensorMap, this.context);\n    }\n    if (value.i != null || value.f != null) {\n      return getNumberParam(this.node.rawAttrs, name, defaultValue as number);\n    }\n    if (value.s != null) {\n      return getStringParam(this.node.rawAttrs, name, defaultValue as string);\n    }\n    if (value.b != null) {\n      return getBoolParam(this.node.rawAttrs, name, defaultValue as boolean);\n    }\n    if (value.shape != null) {\n      return getTensorShapeParam(\n          this.node.rawAttrs, name, defaultValue as number[]);\n    }\n    if (value.type != null) {\n      return 
getDtypeParam(this.node.rawAttrs, name, defaultValue as DataType);\n    }\n    if (value.list != null) {\n      if (value.list.i != null || value.list.f != null) {\n        return getNumericArrayParam(\n            this.node.rawAttrs, name, defaultValue as number[]);\n      }\n      if (value.list.s != null) {\n        return getStringArrayParam(\n            this.node.rawAttrs, name, defaultValue as string[]);\n      }\n      if (value.list.shape != null) {\n        return getTensorShapeArrayParam(\n            this.node.rawAttrs, name, defaultValue as number[][]);\n      }\n      if (value.list.b != null) {\n        return getBoolArrayParam(\n            this.node.rawAttrs, name, defaultValue as boolean[]);\n      }\n      if (value.list.type != null) {\n        return getDtypeArrayParam(\n            this.node.rawAttrs, name, defaultValue as DataType[]);\n      }\n    }\n\n    return defaultValue;\n  }\n}\n"]} |
\ | No newline at end of file |