UNPKG

82.2 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2018 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17import { tidy, util } from '@tensorflow/tfjs-core';
18import { getNodeNameAndIndex, getParamValue, getTensor, getTensorsForCurrentContenxt, parseNodeName } from '../operations/executors/utils';
19import { executeOp } from '../operations/operation_executor';
20import { ExecutionContext } from './execution_context';
21import { getExecutionSubgraph, getNodesInTopologicalOrder, isControlFlow } from './model_analysis';
/**
 * Executes a converted TensorFlow graph (or one of its function sub-graphs).
 *
 * The executor offers a synchronous path (`execute`) that runs a pre-compiled,
 * topologically ordered node list, and an asynchronous path (`executeAsync`)
 * that discovers the execution order at runtime with a work stack, which is
 * what graphs containing control flow or dynamic ops require.
 */
export class GraphExecutor {
  /**
   *
   * @param graph Graph the model or function graph to be executed.
   * @param parent When building function executor you need to set the parent
   * executor. Since the weights and function executor maps are set at parent
   * level, that function executor can access the function maps and weight maps
   * through the parent.
   */
  constructor(graph, parent) {
    this.graph = graph;
    this.parent = parent;
    // Cache of ordered node lists, keyed by `getCompilationKey`.
    this.compiledMap = new Map();
    this._weightMap = {};
    // Initialised eagerly so `weightIds` is always iterable, even when the
    // executor is used before a weight map has been assigned.
    this._weightIds = [];
    this.SEPERATOR = ',';
    this._functionExecutorMap = {};
    this._outputs = graph.outputs;
    this._inputs = graph.inputs;
    this._initNodes = graph.initNodes;
    this._signature = graph.signature;
    // Fall back to an empty map so the `functions` getter never has to deal
    // with a missing function library.
    this._functions = graph.functions != null ? graph.functions : {};
    // create sub-graph executors
    Object.keys(this._functions).forEach(name => {
      this._functionExecutorMap[name] =
          new GraphExecutor(this._functions[name], this);
    });
  }

  /** Ids of all weight tensors; delegated to the parent executor if any. */
  get weightIds() {
    return this.parent ? this.parent.weightIds : this._weightIds;
  }

  /** Function-name -> executor map; delegated to the parent if any. */
  get functionExecutorMap() {
    return this.parent ? this.parent.functionExecutorMap :
                         this._functionExecutorMap;
  }

  /** Node-name -> tensor-array weight map; delegated to the parent if any. */
  get weightMap() {
    return this.parent ? this.parent.weightMap : this._weightMap;
  }

  /**
   * Stores the weight map and records the id of every weight tensor.
   * @param weightMap Map from node name to an array of tensors.
   */
  set weightMap(weightMap) {
    // Collected with plain loops rather than `[].concat(...arrays)`, which
    // passes one argument per map entry and can exceed the engine's argument
    // limit for graphs with a very large number of weights.
    const weightIds = [];
    Object.keys(weightMap).forEach(key => {
      weightMap[key].forEach(tensor => weightIds.push(tensor.id));
    });
    this._weightIds = weightIds;
    this._weightMap = weightMap;
  }

  /**
   * Set `ResourceManager` shared by executors of a model.
   * @param resourceManager: `ResourceManager` of the `GraphModel`.
   */
  set resourceManager(resourceManager) {
    this._resourceManager = resourceManager;
  }

  /** Name, shape and dtype (when declared in `attrParams`) of each input. */
  get inputs() {
    return this._inputs.map(node => {
      return {
        name: node.name,
        shape: node.attrParams['shape'] ? node.attrParams['shape'].value :
                                          undefined,
        dtype: node.attrParams['dtype'] ? node.attrParams['dtype'].value :
                                          undefined
      };
    });
  }

  /** Name, shape and dtype (when declared in `attrParams`) of each output. */
  get outputs() {
    return this._outputs.map(node => {
      return {
        name: node.name,
        shape: node.attrParams['shape'] ? node.attrParams['shape'].value :
                                          undefined,
        dtype: node.attrParams['dtype'] ? node.attrParams['dtype'].value :
                                          undefined
      };
    });
  }

  /** Input names, preferring a node's signature key over its node name. */
  get inputNodes() {
    return this._inputs.map(node => node.signatureKey || node.name);
  }

  /**
   * Output names, preferring a node's signature key over its node name and
   * appending `:<defaultOutput>` when the node declares one.
   */
  get outputNodes() {
    return this._outputs.map((node) => {
      const name = node.signatureKey || node.name;
      return node.defaultOutput ? (`${name}:${node.defaultOutput}`) : name;
    });
  }

  /** Map from function name to that function graph's signature. */
  get functions() {
    return Object.keys(this._functions).reduce((map, key) => {
      map[key] = this._functions[key].signature;
      return map;
    }, {});
  }

  /**
   * Builds the cache key for a compiled execution plan from the sorted input
   * and output node names, so the key is independent of argument order.
   */
  getCompilationKey(inputs, outputs) {
    const sortedInputs = inputs.map(node => node.name).sort();
    const sortedOutputs = outputs.map(node => node.name).sort();
    return sortedInputs.join(this.SEPERATOR) + '--' +
        sortedOutputs.join(this.SEPERATOR);
  }

  /**
   * Compiles the inference graph and returns the minimal set of nodes that are
   * required for execution, in the correct execution order.
   * @param inputs Tensor map keyed by input name.
   * @param outputs Output nodes to compute.
   * @throws Error if the sub-graph contains a dynamic node or if inputs
   * required to compute the outputs are missing.
   */
  compile(inputs, outputs) {
    const executionInfo = getExecutionSubgraph(
        inputs, outputs, this.weightMap, this._initNodes);
    const { missingInputs, dynamicNode, syncInputs } = executionInfo;
    if (dynamicNode != null) {
      throw new Error(
          `This execution contains the node '${dynamicNode.name}', which has ` +
          `the dynamic op '${dynamicNode.op}'. Please use ` +
          `model.executeAsync() instead. Alternatively, to avoid the ` +
          `dynamic ops, specify the inputs [${syncInputs}]`);
    }
    if (missingInputs.length > 0) {
      const outNames = outputs.map(n => n.name);
      const inNames = Object.keys(inputs);
      throw new Error(
          `Cannot compute the outputs [${outNames}] from the provided inputs ` +
          `[${inNames}]. Missing the following inputs: [${missingInputs}]`);
    }
    return getNodesInTopologicalOrder(
        this.graph, this.weightMap, executionInfo);
  }

  /**
   * Resolves the list of output names to execute.
   *
   * When `outputs` is omitted or empty the model's default outputs
   * (`outputNodes`) are used, as the `execute`/`executeAsync` contract
   * documents; signature keys are then translated to tensor names.
   * @param outputs Optional list of output names or signature keys.
   * @returns Output names with signature keys mapped to tensor names.
   */
  resolveOutputs(outputs) {
    const requested =
        (outputs == null || outputs.length === 0) ? this.outputNodes : outputs;
    return this.mapOutputs(requested);
  }

  /**
   * Executes the inference for given input tensors.
   * @param inputs Tensor map for the model inputs, keyed by the input node
   * names.
   * @param outputs Optional. output node name from the Tensorflow model, if
   * no outputs are specified, the default outputs of the model would be used.
   * You can inspect intermediate nodes of the model by adding them to the
   * outputs array.
   * @returns One tensor per requested output, in the order requested.
   * @throws Error if an op returns a promise, or if input/output validation
   * or compilation fails.
   */
  execute(inputs, outputs) {
    inputs = this.mapInputs(inputs);
    const names = Object.keys(inputs).sort();
    this.checkInputs(inputs);
    this.checkInputShapeAndType(inputs);
    outputs = this.resolveOutputs(outputs);
    this.checkOutputs(outputs);
    const inputNodes =
        names.map(name => this.graph.nodes[parseNodeName(name)[0]]);
    const outputNodeNames = outputs.map(name => parseNodeName(name)[0]);
    let outputNodes = outputNodeNames.map(name => this.graph.nodes[name]);
    // If no outputs are specified, then use the default outputs of the model.
    if (outputNodes.length === 0) {
      outputNodes = this._outputs;
    }
    const compilationKey = this.getCompilationKey(inputNodes, outputNodes);
    // Do nothing if the compiled graph cache contains the input.
    let orderedNodes = this.compiledMap.get(compilationKey);
    if (orderedNodes == null) {
      orderedNodes = this.compile(inputs, outputNodes);
      this.compiledMap.set(compilationKey, orderedNodes);
    }
    const tensorArrayMap = {};
    const tensorListMap = {};
    return tidy(() => {
      const context = new ExecutionContext(
          this.weightMap, tensorArrayMap, tensorListMap,
          this.functionExecutorMap);
      const tensorsMap = Object.assign({}, this.weightMap);
      // Place each input tensor at its output index under its node name.
      Object.keys(inputs).forEach(name => {
        const [nodeName, index] = parseNodeName(name);
        const tensors = [];
        tensors[index] = inputs[name];
        tensorsMap[nodeName] = tensors;
      });
      // Weights and inputs must never be disposed as intermediates.
      const tensorsToKeep = this.getFrozenTensorIds(tensorsMap);
      const intermediateTensorConsumerCount = {};
      for (let i = 0; i < orderedNodes.length; i++) {
        const node = orderedNodes[i];
        if (!tensorsMap[node.name]) {
          const tensors =
              executeOp(node, tensorsMap, context, this._resourceManager);
          if (util.isPromise(tensors)) {
            throw new Error(
                `The execution of the op '${node.op}' returned a promise. ` +
                `Please use model.executeAsync() instead.`);
          }
          tensorsMap[node.name] = tensors;
          this.checkTensorForDisposal(
              node.name, node, tensorsMap, context, tensorsToKeep,
              outputNodeNames, intermediateTensorConsumerCount);
        }
      }
      // dispose the context for the root executor
      if (this.parent == null) {
        context.dispose(tensorsToKeep);
      }
      return outputs.map(name => getTensor(name, tensorsMap, context));
    });
  }

  /**
   * Collects the ids of every tensor currently in `tensorMap` (weights and
   * inputs at the time of the call) into a set.
   * @param tensorMap Map from node name to an array of tensors.
   * @returns Set of tensor ids.
   */
  getFrozenTensorIds(tensorMap) {
    // Built with loops rather than `concat.apply`, which is bounded by the
    // engine's maximum argument count. Input arrays may be sparse (a tensor
    // is stored at its output index), so empty slots are skipped.
    const ids = new Set();
    Object.keys(tensorMap).forEach(key => {
      tensorMap[key].forEach(tensor => {
        if (tensor != null) {
          ids.add(tensor.id);
        }
      });
    });
    return ids;
  }

  /**
   * Consumer-count based disposal of intermediate tensors.
   *
   * Registers the tensors produced by `node` with a consumer count equal to
   * the number of its children, then visits the tensors of each of `node`'s
   * inputs: a tensor whose count is exactly 1 is disposed, otherwise its count
   * is decremented. Tensors flagged `kept` or listed in `tensorsToKeep` are
   * never touched.
   * @param nodeName Key under which the node's tensors are stored.
   * @param node The node that has just been executed.
   * @param tensorMap Map from node name to tensors.
   * @param context Current execution context.
   * @param tensorsToKeep Ids of tensors that must not be disposed.
   * @param outputNames Names of the requested output nodes.
   * @param intermediateTensorConsumerCount Mutable tensor-id -> count map.
   */
  checkTensorForDisposal(
      nodeName, node, tensorMap, context, tensorsToKeep, outputNames,
      intermediateTensorConsumerCount) {
    // Skip output nodes and any control flow nodes, since its dependency is
    // tricky to track correctly.
    if (node.category === 'control' || outputNames.indexOf(nodeName) !== -1) {
      return;
    }
    tensorMap[nodeName].forEach(tensor => {
      if (tensor != null) {
        intermediateTensorConsumerCount[tensor.id] =
            (intermediateTensorConsumerCount[tensor.id] || 0) +
            node.children.length;
      }
    });
    node.inputs.forEach(input => {
      // Skip any control flow nodes, since its dependency is tricky to track
      // correctly.
      if (input.category !== 'control') {
        const tensors =
            getTensorsForCurrentContenxt(input.name, tensorMap, context);
        if (tensors != null) {
          tensors.forEach(tensor => {
            if (tensor && !tensor.kept && !tensorsToKeep.has(tensor.id)) {
              const count = intermediateTensorConsumerCount[tensor.id];
              if (count === 1) {
                tensor.dispose();
                delete intermediateTensorConsumerCount[tensor.id];
              } else if (count != null) {
                // only intermediate nodes has count set, inputs and weights
                // are not.
                intermediateTensorConsumerCount[tensor.id]--;
              }
            }
          });
        }
      }
    });
  }

  /**
   * Executes the inference for given input tensors in Async fashion.
   * @param inputs Tensor map for the model inputs, keyed by the input node
   * names.
   * @param outputs output node name from the Tensorflow model, if no outputs
   * are specified, the default outputs of the model would be used. You can
   * inspect intermediate nodes of the model by adding them to the outputs
   * array.
   */
  async executeAsync(inputs, outputs) {
    return this._executeAsync(inputs, outputs);
  }

  /**
   * Executes the inference for given input tensors in Async fashion.
   * @param inputs Tensor map for the model inputs, keyed by the input node
   * names.
   * @param outputs Optional. output node name from the Tensorflow model,
   * if no outputs are specified, the default outputs of the model would be
   * used. You can inspect intermediate nodes of the model by adding them to the
   * outputs array.
   * @param isFunctionExecution Optional. Flag for executing a function. When
   * set, inputs and outputs are used as given, without mapping or validation.
   * @param tensorArrayMap Optional, global TensorArray map by id. Used for
   * function execution.
   * @param tensorListMap Optional, global TensorList map by id. Used for
   * function execution.
   * @returns One tensor per requested output, in the order requested.
   */
  async _executeAsync(
      inputs, outputs, isFunctionExecution = false, tensorArrayMap = {},
      tensorListMap = {}) {
    if (!isFunctionExecution) {
      inputs = this.mapInputs(inputs);
      this.checkInputs(inputs);
      this.checkInputShapeAndType(inputs);
      outputs = this.resolveOutputs(outputs);
      this.checkOutputs(outputs);
    }
    const context = new ExecutionContext(
        this.weightMap, tensorArrayMap, tensorListMap,
        this.functionExecutorMap);
    // Graph with control flow op requires runtime evaluation of the execution
    // order, while without control flow the execution order is pre-determined
    // in the compile method.
    const tensorMap = await this.executeWithControlFlow(
        inputs, context, outputs, isFunctionExecution);
    const results = outputs.map(name => getTensor(name, tensorMap, context));
    // dispose all the intermediate tensors
    const outputIds = results.map(t => t.id);
    const inputIds = Object.keys(inputs).map(name => inputs[name].id);
    const keepIds = new Set([...outputIds, ...inputIds, ...this.weightIds]);
    Object.keys(tensorMap).forEach(key => {
      const tensorArray = tensorMap[key];
      tensorArray.forEach(tensor => {
        if (tensor && !tensor.kept && !tensor.isDisposed &&
            !keepIds.has(tensor.id)) {
          tensor.dispose();
        }
      });
    });
    // dispose the context for the root executor
    if (this.parent == null) {
      context.dispose(keepIds);
    }
    return results;
  }

  /**
   * Executes this executor's graph as a function call.
   * @param inputs Positional input tensors; the i-th tensor is bound to the
   * name of the i-th entry of `this.inputs`.
   * @param tensorArrayMap TensorArray map passed through to the execution.
   * @param tensorListMap TensorList map passed through to the execution.
   */
  async executeFunctionAsync(inputs, tensorArrayMap, tensorListMap) {
    const mappedInputs = inputs.reduce((map, tensor, index) => {
      map[this.inputs[index].name] = tensor;
      return map;
    }, {});
    return this._executeAsync(
        mappedInputs, this.outputNodes, true, tensorArrayMap, tensorListMap);
  }

  /**
   * When there are control flow nodes in the graph, the graph execution use
   * ExecutionContext to keep track of the frames and loop iterators.
   * @param inputs placeholder tensors for the graph.
   * @param context the execution context object for current execution.
   * @param outputNames Optional. output node name from the Tensorflow model,
   * if no outputs are specified, the default outputs of the model would be
   * used. You can inspect intermediate nodes of the model by adding them to the
   * outputs array.
   * @param isFunctionExecution Flag for executing a function.
   * @returns The map from node name to the tensors computed for it.
   * @throws Error if a requested non-control-flow output has no tensor after
   * execution.
   */
  async executeWithControlFlow(
      inputs, context, outputNames, isFunctionExecution) {
    const names = Object.keys(inputs);
    const inputNodes =
        names.map(name => this.graph.nodes[parseNodeName(name)[0]]);
    const outputNodeNames = outputNames.map(name => parseNodeName(name)[0]);
    let outputNodes = outputNodeNames.map(name => this.graph.nodes[name]);
    // If no outputs are specified, then use the default outputs of the model.
    if (outputNodes.length === 0) {
      outputNodes = this._outputs;
    }
    const { usedNodes, missingInputs, dynamicNode, syncInputs } =
        getExecutionSubgraph(
            inputs, outputNodes, this.weightMap, this._initNodes);
    // First nodes to execute include inputNodes, weights, and initNodes.
    const stack = [
      ...inputNodes, ...this.graph.weights, ...(this._initNodes || [])
    ].map(node => {
      return { node, contexts: context.currentContext };
    });
    const tensorsMap = Object.assign({}, this.weightMap);
    // Place each input tensor at its output index under its node name.
    Object.keys(inputs).forEach(name => {
      const [nodeName, index] = parseNodeName(name);
      const tensors = [];
      tensors[index] = inputs[name];
      tensorsMap[nodeName] = tensors;
    });
    const intermediateTensorConsumerCount = {};
    const tensorsToKeep = this.getFrozenTensorIds(tensorsMap);
    const added = {};
    // Each pass drains the stack; resolved async ops may push further nodes,
    // so keep going until nothing is left to process.
    while (stack.length > 0) {
      const promises = this.processStack(
          inputNodes, stack, context, tensorsMap, added, tensorsToKeep,
          outputNodeNames, intermediateTensorConsumerCount, usedNodes);
      await Promise.all(promises);
    }
    if (dynamicNode == null && !isFunctionExecution) {
      console.warn(
          `This model execution did not contain any nodes with control flow ` +
          `or dynamic output shapes. You can use model.execute() instead.`);
    }
    const missingOutputs =
        outputNodes
            .filter(
                node => !isControlFlow(node) &&
                    !getTensor(node.name, tensorsMap, context))
            .map(node => node.name);
    if (missingOutputs.length > 0) {
      let alternativeMsg = '';
      if (dynamicNode != null) {
        alternativeMsg =
            `Alternatively, to avoid the dynamic ops, use model.execute() ` +
            `and specify the inputs [${syncInputs}]`;
      }
      throw new Error(
          `Cannot compute the outputs [${missingOutputs}] from the provided ` +
          `inputs [${names}]. Consider providing the following inputs: ` +
          `[${missingInputs}]. ${alternativeMsg}`);
    }
    return tensorsMap;
  }

  /**
   * Pops and executes nodes until the stack is empty, scheduling each node's
   * ready children as it goes.
   * @returns Promises for the ops that returned a promise; when one resolves
   * its result is stored and its children are scheduled onto `stack`.
   */
  processStack(
      inputNodes, stack, context, tensorMap, added, tensorsToKeep, outputNames,
      intermediateTensorConsumerCount, usedNodes) {
    const promises = [];
    while (stack.length > 0) {
      const item = stack.pop();
      context.currentContext = item.contexts;
      let nodeName = '';
      // The tensor of the Enter op with isConstant set should be set
      // in the parent scope, so it will be available as constant for the
      // whole loop.
      if (item.node.op === 'Enter' &&
          getParamValue('isConstant', item.node, tensorMap, context)) {
        [nodeName] = getNodeNameAndIndex(item.node.name, context);
      }
      // only process nodes that are not in the tensorMap yet, this include
      // inputNodes and internal initNodes.
      if (tensorMap[item.node.name] == null) {
        const tensors =
            executeOp(item.node, tensorMap, context, this._resourceManager);
        if (!nodeName) {
          [nodeName] = getNodeNameAndIndex(item.node.name, context);
        }
        // Captured so the async continuation can restore the context that
        // was current when the op was started.
        const currentContext = context.currentContext;
        if (util.isPromise(tensors)) {
          promises.push(tensors.then(t => {
            tensorMap[nodeName] = t;
            context.currentContext = currentContext;
            this.checkTensorForDisposal(
                nodeName, item.node, tensorMap, context, tensorsToKeep,
                outputNames, intermediateTensorConsumerCount);
            this.processChildNodes(
                item.node, stack, context, tensorMap, added, usedNodes);
            return t;
          }));
        } else {
          tensorMap[nodeName] = tensors;
          this.checkTensorForDisposal(
              nodeName, item.node, tensorMap, context, tensorsToKeep,
              outputNames, intermediateTensorConsumerCount);
          this.processChildNodes(
              item.node, stack, context, tensorMap, added, usedNodes);
        }
      } else {
        this.processChildNodes(
            item.node, stack, context, tensorMap, added, usedNodes);
      }
    }
    return promises;
  }

  /**
   * Pushes onto the stack every child of `node` that is part of the used
   * sub-graph, has not been added yet, and whose inputs are available.
   */
  processChildNodes(node, stack, context, tensorMap, added, usedNodes) {
    node.children.forEach((childNode) => {
      const [nodeName] = getNodeNameAndIndex(childNode.name, context);
      if (added[nodeName] || !usedNodes.has(childNode.name)) {
        return;
      }
      // Merge op can be pushed if any of its inputs has value.
      if (childNode.op === 'Merge') {
        if (childNode.inputNames.some(name => {
              return !!getTensor(name, tensorMap, context);
            })) {
          added[nodeName] = true;
          stack.push({ contexts: context.currentContext, node: childNode });
        }
      } else if (childNode.inputNames.every(name => {
                   // Otherwise all inputs must have a value.
                   return !!getTensor(name, tensorMap, context);
                 })) {
        added[nodeName] = true;
        stack.push({ contexts: context.currentContext, node: childNode });
      }
    });
  }

  /**
   * Releases the memory used by the weight tensors.
   */
  dispose() {
    Object.keys(this.weightMap)
        .forEach(
            key => this.weightMap[key].forEach(tensor => tensor.dispose()));
  }

  /**
   * Asserts that every provided input matches the shape and dtype declared on
   * its graph node, when the node declares them. A declared dimension of -1
   * matches any size.
   * @param inputs Tensor map keyed by input name.
   */
  checkInputShapeAndType(inputs) {
    Object.keys(inputs).forEach(name => {
      const input = inputs[name];
      const [nodeName] = parseNodeName(name);
      const node = this.graph.nodes[nodeName];
      if (node.attrParams['shape'] && node.attrParams['shape'].value) {
        const shape = node.attrParams['shape'].value;
        const match = shape.length === input.shape.length &&
            input.shape.every(
                (dim, index) => shape[index] === -1 || shape[index] === dim);
        util.assert(
            match,
            () => `The shape of dict['${node.name}'] provided in ` +
                `model.execute(dict) must be [${shape}], but was ` +
                `[${input.shape}]`);
      }
      if (node.attrParams['dtype'] && node.attrParams['dtype'].value) {
        util.assert(
            input.dtype === node.attrParams['dtype'].value,
            () => `The dtype of dict['${node.name}'] provided in ` +
                `model.execute(dict) must be ` +
                `${node.attrParams['dtype'].value}, but was ${input.dtype}`);
      }
    });
  }

  /**
   * Returns a copy of `inputs` in which keys that match a signature input are
   * replaced by that signature entry's `name`; other keys are kept as is.
   */
  mapInputs(inputs) {
    const result = {};
    for (const inputName in inputs) {
      if (this._signature != null && this._signature.inputs != null &&
          this._signature.inputs[inputName] != null) {
        const tensor = this._signature.inputs[inputName];
        result[tensor.name] = inputs[inputName];
      } else {
        result[inputName] = inputs[inputName];
      }
    }
    return result;
  }

  /**
   * Verifies that every key of `inputs` refers to a node of the graph.
   * @throws Error listing the keys that are not part of the graph.
   */
  checkInputs(inputs) {
    const notInGraph = Object.keys(inputs).filter(name => {
      const [nodeName] = parseNodeName(name);
      return this.graph.nodes[nodeName] == null;
    });
    if (notInGraph.length > 0) {
      throw new Error(
          `The dict provided in model.execute(dict) has ` +
          `keys: [${notInGraph}] that are not part of graph`);
    }
  }

  /**
   * Replaces each output name that matches a signature output by that
   * signature entry's `name`; other names are returned unchanged.
   */
  mapOutputs(outputs) {
    return outputs.map(name => {
      if (this._signature != null && this._signature.outputs != null &&
          this._signature.outputs[name] != null) {
        const tensor = this._signature.outputs[name];
        return tensor.name;
      }
      return name;
    });
  }

  /**
   * Verifies that every requested output refers to a node of the graph.
   * @throws Error naming the first output that is not found.
   */
  checkOutputs(outputs) {
    outputs.forEach(name => {
      const [normalizedName] = parseNodeName(name);
      if (!this.graph.nodes[normalizedName]) {
        throw new Error(`The output '${name}' is not found in the graph`);
      }
    });
  }
}
511//# sourceMappingURL=data:application/json;base64,
\No newline at end of file