UNPKG

12.7 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2019 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17import { getTensor } from '../executors/utils';
18import { getBoolArrayParam, getBoolParam, getDtypeArrayParam, getDtypeParam, getNumberParam, getNumericArrayParam, getStringArrayParam, getStringParam, getTensorShapeArrayParam, getTensorShapeParam } from '../operation_mapper';
19/**
20 * Helper class for lookup inputs and params for nodes in the model graph.
21 */
/**
 * Helper class for looking up the inputs and attribute params of a node in
 * the model graph. Both `inputs` and `attrs` are resolved once, eagerly, in
 * the constructor.
 */
export class NodeValueImpl {
    /**
     * @param node Graph node whose `inputNames` and `rawAttrs` are resolved.
     * @param tensorMap Tensor map handed to `getTensor` for lookups.
     * @param context Execution context handed to `getTensor` for lookups.
     */
    constructor(node, tensorMap, context) {
        this.node = node;
        this.tensorMap = tensorMap;
        this.context = context;
        this.inputs = [];
        this.attrs = {};
        this.inputs = node.inputNames.map(name => this.getInput(name));
        if (node.rawAttrs != null) {
            this.attrs = Object.keys(node.rawAttrs)
                .reduce((attrs, key) => {
                attrs[key] = this.getAttr(key);
                return attrs;
            }, {});
        }
    }
    /**
     * Return the tensor for the named input of this node.
     * @param name String: name of the input.
     */
    getInput(name) {
        return getTensor(name, this.tensorMap, this.context);
    }
    /**
     * Return the decoded value of the named attribute param.
     *
     * The value kind is chosen from whichever field of the raw attribute is
     * set, checked in this order: tensor, number (i/f), string (s), bool (b),
     * shape, dtype (type), then the corresponding `list` variants.
     *
     * @param name String: name of the attribute param.
     * @param defaultValue Returned when the node has no raw attributes, the
     *     attribute is absent, or none of the recognised fields is set.
     */
    getAttr(name, defaultValue) {
        const rawAttrs = this.node.rawAttrs;
        // Guard like the constructor does: a missing attribute map or a
        // missing key falls back to the default instead of throwing a
        // TypeError on property access below.
        if (rawAttrs == null) {
            return defaultValue;
        }
        const value = rawAttrs[name];
        if (value == null) {
            return defaultValue;
        }
        if (value.tensor != null) {
            // Tensor-valued attributes are resolved through the tensor map
            // under the attribute's own name.
            return getTensor(name, this.tensorMap, this.context);
        }
        if (value.i != null || value.f != null) {
            return getNumberParam(rawAttrs, name, defaultValue);
        }
        if (value.s != null) {
            return getStringParam(rawAttrs, name, defaultValue);
        }
        if (value.b != null) {
            return getBoolParam(rawAttrs, name, defaultValue);
        }
        if (value.shape != null) {
            return getTensorShapeParam(rawAttrs, name, defaultValue);
        }
        if (value.type != null) {
            return getDtypeParam(rawAttrs, name, defaultValue);
        }
        if (value.list != null) {
            if (value.list.i != null || value.list.f != null) {
                return getNumericArrayParam(rawAttrs, name, defaultValue);
            }
            if (value.list.s != null) {
                return getStringArrayParam(rawAttrs, name, defaultValue);
            }
            if (value.list.shape != null) {
                return getTensorShapeArrayParam(rawAttrs, name, defaultValue);
            }
            if (value.list.b != null) {
                return getBoolArrayParam(rawAttrs, name, defaultValue);
            }
            if (value.list.type != null) {
                return getDtypeArrayParam(rawAttrs, name, defaultValue);
            }
        }
        // No recognised field is set on the attribute.
        return defaultValue;
    }
}
89//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"node_value_impl.js","sourceRoot":"","sources":["../../../../../../../tfjs-converter/src/operations/custom_op/node_value_impl.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAMH,OAAO,EAAC,SAAS,EAAC,MAAM,oBAAoB,CAAC;AAC7C,OAAO,EAAC,iBAAiB,EAAE,YAAY,EAAE,kBAAkB,EAAE,aAAa,EAAE,cAAc,EAAE,oBAAoB,EAAE,mBAAmB,EAAE,cAAc,EAAE,wBAAwB,EAAE,mBAAmB,EAAC,MAAM,qBAAqB,CAAC;AAGjO;;GAEG;AACH,MAAM,OAAO,aAAa;IAGxB,YACY,IAAU,EAAU,SAA0B,EAC9C,OAAyB;QADzB,SAAI,GAAJ,IAAI,CAAM;QAAU,cAAS,GAAT,SAAS,CAAiB;QAC9C,YAAO,GAAP,OAAO,CAAkB;QAJrB,WAAM,GAAa,EAAE,CAAC;QACtB,UAAK,GAA+B,EAAE,CAAC;QAIrD,IAAI,CAAC,MAAM,GAAG,IAAI,CAAC,UAAU,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC,IAAI,CAAC,QAAQ,CAAC,IAAI,CAAC,CAAC,CAAC;QAC/D,IAAI,IAAI,CAAC,QAAQ,IAAI,IAAI,EAAE;YACzB,IAAI,CAAC,KAAK,GAAG,MAAM,CAAC,IAAI,CAAC,IAAI,CAAC,QAAQ,CAAC;iBACrB,MAAM,CAAC,CAAC,KAAiC,EAAE,GAAG,EAAE,EAAE;gBACjD,KAAK,CAAC,GAAG,CAAC,GAAG,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC;gBAC/B,OAAO,KAAK,CAAC;YACf,CAAC,EAAE,EAAE,CAAC,CAAC;SACzB;IACH,CAAC;IAED;;;OAGG;IACK,QAAQ,CAAC,IAAY;QAC3B,OAAO,SAAS,CAAC,IAAI,EAAE,IAAI,CAAC,SAAS,EAAE,IAAI,CAAC,OAAO,CAAC,CAAC;IACvD,CAAC;IAED;;;OAGG;IACK,OAAO,CAAC,IAAY,EAAE,YAAwB;QACpD,MAAM,KAAK,GAAG,IAAI,CAAC,IAAI,CAAC,QAAQ,CAAC,IAAI,CAAC,CAAC;QACvC,IAAI,KAAK,CAAC,MAAM,IAAI,IAAI,EAAE;YACxB,OAAO,SAAS,CAAC,IAAI,EAAE,IAAI,CAAC,SAAS,EAAE,IAAI,CAAC,OAAO,CAAC,CAAC;SACtD;QACD,IAAI,KAAK,CAAC,CAAC,IAAI,IAAI,IAAI,KAAK,CAAC,CAAC,IAAI,IAAI,EAAE;YACtC,OAAO,cAAc,CAAC,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAAsB,CAAC,CAAC;SACzE;QACD,IAAI,KAAK,CAAC,CAAC,IAAI,IAAI,EAAE;YACnB,OAAO,cAAc,CAAC,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAAsB,CAAC,CAAC;SACzE;QACD,IAAI,KAAK,CAAC,CAAC,IAAI,IAAI,EAAE;YACnB,OAAO,YAAY,CAAC,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAAuB,CAAC,CAAC;SACxE;QACD,IAAI,KAAK,CAAC,KAAK,IAAI,IAAI,EAAE;YACvB,OAAO,mBAAmB,CACtB,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAAwB,CAAC,CAAC;SACzD;QACD,IAAI,KAAK,CAAC,IAAI,IAAI,IAAI,EAAE;YACtB,OAAO,aAAa,CAAC,IAAI,CAAC,IAAI,CA
AC,QAAQ,EAAE,IAAI,EAAE,YAAwB,CAAC,CAAC;SAC1E;QACD,IAAI,KAAK,CAAC,IAAI,IAAI,IAAI,EAAE;YACtB,IAAI,KAAK,CAAC,IAAI,CAAC,CAAC,IAAI,IAAI,IAAI,KAAK,CAAC,IAAI,CAAC,CAAC,IAAI,IAAI,EAAE;gBAChD,OAAO,oBAAoB,CACvB,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAAwB,CAAC,CAAC;aACzD;YACD,IAAI,KAAK,CAAC,IAAI,CAAC,CAAC,IAAI,IAAI,EAAE;gBACxB,OAAO,mBAAmB,CACtB,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAAwB,CAAC,CAAC;aACzD;YACD,IAAI,KAAK,CAAC,IAAI,CAAC,KAAK,IAAI,IAAI,EAAE;gBAC5B,OAAO,wBAAwB,CAC3B,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAA0B,CAAC,CAAC;aAC3D;YACD,IAAI,KAAK,CAAC,IAAI,CAAC,CAAC,IAAI,IAAI,EAAE;gBACxB,OAAO,iBAAiB,CACpB,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAAyB,CAAC,CAAC;aAC1D;YACD,IAAI,KAAK,CAAC,IAAI,CAAC,IAAI,IAAI,IAAI,EAAE;gBAC3B,OAAO,kBAAkB,CACrB,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAA0B,CAAC,CAAC;aAC3D;SACF;QAED,OAAO,YAAY,CAAC;IACtB,CAAC;CACF","sourcesContent":["/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {DataType, Tensor} from '@tensorflow/tfjs-core';\n\nimport {NamedTensorsMap} from '../../data/types';\nimport {ExecutionContext} from '../../executor/execution_context';\nimport {getTensor} from '../executors/utils';\nimport {getBoolArrayParam, getBoolParam, getDtypeArrayParam, getDtypeParam, getNumberParam, getNumericArrayParam, getStringArrayParam, getStringParam, getTensorShapeArrayParam, 
getTensorShapeParam} from '../operation_mapper';\nimport {GraphNode, Node, ValueType} from '../types';\n\n/**\n * Helper class for lookup inputs and params for nodes in the model graph.\n */\nexport class NodeValueImpl implements GraphNode {\n  public readonly inputs: Tensor[] = [];\n  public readonly attrs: {[key: string]: ValueType} = {};\n  constructor(\n      private node: Node, private tensorMap: NamedTensorsMap,\n      private context: ExecutionContext) {\n    this.inputs = node.inputNames.map(name => this.getInput(name));\n    if (node.rawAttrs != null) {\n      this.attrs = Object.keys(node.rawAttrs)\n                       .reduce((attrs: {[key: string]: ValueType}, key) => {\n                         attrs[key] = this.getAttr(key);\n                         return attrs;\n                       }, {});\n    }\n  }\n\n  /**\n   * Return the value of the attribute or input param.\n   * @param name String: name of attribute or input param.\n   */\n  private getInput(name: string): Tensor {\n    return getTensor(name, this.tensorMap, this.context);\n  }\n\n  /**\n   * Return the value of the attribute or input param.\n   * @param name String: name of attribute or input param.\n   */\n  private getAttr(name: string, defaultValue?: ValueType): ValueType {\n    const value = this.node.rawAttrs[name];\n    if (value.tensor != null) {\n      return getTensor(name, this.tensorMap, this.context);\n    }\n    if (value.i != null || value.f != null) {\n      return getNumberParam(this.node.rawAttrs, name, defaultValue as number);\n    }\n    if (value.s != null) {\n      return getStringParam(this.node.rawAttrs, name, defaultValue as string);\n    }\n    if (value.b != null) {\n      return getBoolParam(this.node.rawAttrs, name, defaultValue as boolean);\n    }\n    if (value.shape != null) {\n      return getTensorShapeParam(\n          this.node.rawAttrs, name, defaultValue as number[]);\n    }\n    if (value.type != null) {\n      return 
getDtypeParam(this.node.rawAttrs, name, defaultValue as DataType);\n    }\n    if (value.list != null) {\n      if (value.list.i != null || value.list.f != null) {\n        return getNumericArrayParam(\n            this.node.rawAttrs, name, defaultValue as number[]);\n      }\n      if (value.list.s != null) {\n        return getStringArrayParam(\n            this.node.rawAttrs, name, defaultValue as string[]);\n      }\n      if (value.list.shape != null) {\n        return getTensorShapeArrayParam(\n            this.node.rawAttrs, name, defaultValue as number[][]);\n      }\n      if (value.list.b != null) {\n        return getBoolArrayParam(\n            this.node.rawAttrs, name, defaultValue as boolean[]);\n      }\n      if (value.list.type != null) {\n        return getDtypeArrayParam(\n            this.node.rawAttrs, name, defaultValue as DataType[]);\n      }\n    }\n\n    return defaultValue;\n  }\n}\n"]}
\No newline at end of file