1 | /**
|
2 | * @license
|
3 | * Copyright 2019 Google LLC. All Rights Reserved.
|
4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | * you may not use this file except in compliance with the License.
|
6 | * You may obtain a copy of the License at
|
7 | *
|
8 | * http://www.apache.org/licenses/LICENSE-2.0
|
9 | *
|
10 | * Unless required by applicable law or agreed to in writing, software
|
11 | * distributed under the License is distributed on an "AS IS" BASIS,
|
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 | * See the License for the specific language governing permissions and
|
14 | * limitations under the License.
|
15 | * =============================================================================
|
16 | */
|
17 | import { getTensor } from '../executors/utils';
|
18 | import { getBoolArrayParam, getBoolParam, getDtypeArrayParam, getDtypeParam, getNumberParam, getNumericArrayParam, getStringArrayParam, getStringParam, getTensorShapeArrayParam, getTensorShapeParam } from '../operation_mapper';
|
/**
 * Helper class for looking up the inputs and attribute params of a node in
 * the model graph.
 *
 * Every input and every raw attribute of the node is resolved eagerly in the
 * constructor and exposed through the `inputs` and `attrs` properties.
 */
export class NodeValueImpl {
  /**
   * @param node Graph node. `node.inputNames` is iterated to resolve the
   *     inputs; `node.rawAttrs` may be null/undefined, in which case `attrs`
   *     stays an empty object.
   * @param tensorMap Forwarded unchanged to `getTensor`.
   * @param context Forwarded unchanged to `getTensor`.
   */
  constructor(node, tensorMap, context) {
    this.node = node;
    this.tensorMap = tensorMap;
    this.context = context;
    this.inputs = [];
    this.attrs = {};
    // tensorMap/context must be assigned before this point: getInput and
    // getAttr read them from `this`.
    this.inputs = node.inputNames.map(name => this.getInput(name));
    if (node.rawAttrs != null) {
      this.attrs = Object.keys(node.rawAttrs).reduce((attrs, key) => {
        attrs[key] = this.getAttr(key);
        return attrs;
      }, {});
    }
  }

  /**
   * Return the value of an input, as resolved by `getTensor` against this
   * node's tensor map and execution context.
   * @param name String: name of the input.
   */
  getInput(name) {
    return getTensor(name, this.tensorMap, this.context);
  }

  /**
   * Return the value of an attribute param.
   *
   * The raw attribute is inspected to see which of its fields is populated
   * (`tensor`, `i`/`f`, `s`, `b`, `shape`, `type`, or the same fields nested
   * under `list`), and the matching typed getter is used to decode it.
   *
   * `defaultValue` is returned when the node has no raw attributes, when the
   * attribute is not present, or when none of the known fields is populated.
   *
   * @param name String: name of the attribute.
   * @param defaultValue Optional fallback value; also forwarded to the typed
   *     getters.
   */
  getAttr(name, defaultValue) {
    const rawAttrs = this.node.rawAttrs;
    const value = rawAttrs == null ? undefined : rawAttrs[name];
    // A missing attribute (or a node without attributes) falls back to the
    // caller's default instead of failing on a property access of undefined.
    if (value == null) {
      return defaultValue;
    }
    if (value.tensor != null) {
      // NOTE(review): the attribute name is used as the tensor name for the
      // lookup -- confirm this matches how tensor attrs are registered.
      return getTensor(name, this.tensorMap, this.context);
    }
    if (value.i != null || value.f != null) {
      return getNumberParam(rawAttrs, name, defaultValue);
    }
    if (value.s != null) {
      return getStringParam(rawAttrs, name, defaultValue);
    }
    if (value.b != null) {
      return getBoolParam(rawAttrs, name, defaultValue);
    }
    if (value.shape != null) {
      return getTensorShapeParam(rawAttrs, name, defaultValue);
    }
    if (value.type != null) {
      return getDtypeParam(rawAttrs, name, defaultValue);
    }
    const list = value.list;
    if (list != null) {
      if (list.i != null || list.f != null) {
        return getNumericArrayParam(rawAttrs, name, defaultValue);
      }
      if (list.s != null) {
        return getStringArrayParam(rawAttrs, name, defaultValue);
      }
      if (list.shape != null) {
        return getTensorShapeArrayParam(rawAttrs, name, defaultValue);
      }
      if (list.b != null) {
        return getBoolArrayParam(rawAttrs, name, defaultValue);
      }
      if (list.type != null) {
        return getDtypeArrayParam(rawAttrs, name, defaultValue);
      }
    }
    return defaultValue;
  }
}
|
89 | //# sourceMappingURL=data:application/json;base64, |
\ | No newline at end of file |