UNPKG

12.7 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2019 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17import { getTensor } from '../executors/utils';
18import { getBoolArrayParam, getBoolParam, getDtypeArrayParam, getDtypeParam, getNumberParam, getNumericArrayParam, getStringArrayParam, getStringParam, getTensorShapeArrayParam, getTensorShapeParam } from '../operation_mapper';
/**
 * Helper class for looking up the inputs and attributes of a node in the
 * model graph.
 *
 * On construction every name in `node.inputNames` is resolved through
 * `getInput` and stored in `inputs`; when `node.rawAttrs` is present, every
 * key in it is resolved through `getAttr` and stored in `attrs`.
 */
export class NodeValueImpl {
    /**
     * @param node Graph node; `inputNames` and `rawAttrs` are read from it.
     * @param tensorMap Passed through unchanged to `getTensor`.
     * @param context Passed through unchanged to `getTensor`.
     */
    constructor(node, tensorMap, context) {
        this.node = node;
        this.tensorMap = tensorMap;
        this.context = context;
        this.inputs = node.inputNames.map(name => this.getInput(name));
        // Stays empty when the node carries no raw attributes.
        this.attrs = {};
        if (node.rawAttrs != null) {
            for (const key of Object.keys(node.rawAttrs)) {
                this.attrs[key] = this.getAttr(key);
            }
        }
    }
    /**
     * Return the value registered for an input of the node.
     * @param name String: name of the input, forwarded to `getTensor`
     *     together with this instance's tensor map and context.
     */
    getInput(name) {
        return getTensor(name, this.tensorMap, this.context);
    }
    /**
     * Return the value of a raw attribute of the node.
     *
     * The first populated field of the raw attribute (`tensor`, `i`/`f`, `s`,
     * `b`, `shape`, `type`, or the same fields inside `list`) selects which
     * getter is used to read it.
     *
     * @param name String: name of the attribute.
     * @param defaultValue Returned when the node has no raw attributes, the
     *     attribute is absent, or none of the recognised fields is populated.
     *     It is also forwarded to the selected getter.
     */
    getAttr(name, defaultValue) {
        const rawAttrs = this.node.rawAttrs;
        const value = rawAttrs == null ? null : rawAttrs[name];
        // A missing attribute (or a node without rawAttrs) falls back to the
        // default instead of failing on a property read of undefined.
        if (value == null) {
            return defaultValue;
        }
        if (value.tensor != null) {
            // NOTE(review): the lookup uses the attribute name as the tensor
            // name; confirm against getTensor's contract.
            return getTensor(name, this.tensorMap, this.context);
        }
        if (value.i != null || value.f != null) {
            return getNumberParam(rawAttrs, name, defaultValue);
        }
        if (value.s != null) {
            return getStringParam(rawAttrs, name, defaultValue);
        }
        if (value.b != null) {
            return getBoolParam(rawAttrs, name, defaultValue);
        }
        if (value.shape != null) {
            return getTensorShapeParam(rawAttrs, name, defaultValue);
        }
        if (value.type != null) {
            return getDtypeParam(rawAttrs, name, defaultValue);
        }
        if (value.list != null) {
            const list = value.list;
            if (list.i != null || list.f != null) {
                return getNumericArrayParam(rawAttrs, name, defaultValue);
            }
            if (list.s != null) {
                return getStringArrayParam(rawAttrs, name, defaultValue);
            }
            if (list.shape != null) {
                return getTensorShapeArrayParam(rawAttrs, name, defaultValue);
            }
            if (list.b != null) {
                return getBoolArrayParam(rawAttrs, name, defaultValue);
            }
            if (list.type != null) {
                return getDtypeArrayParam(rawAttrs, name, defaultValue);
            }
        }
        return defaultValue;
    }
}
89//# sourceMappingURL=data:application/json;base64,
\No newline at end of file