// Source: unpkg.com raw view of @tensorflow/tfjs-node dist/saved_model.js
// (23.8 kB, JavaScript). Page-header text and fused line numbers from the
// scrape have been removed.
"use strict";
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
18var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
19 return new (P || (P = Promise))(function (resolve, reject) {
20 function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
21 function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
22 function step(result) { result.done ? resolve(result.value) : new P(function (resolve) { resolve(result.value); }).then(fulfilled, rejected); }
23 step((generator = generator.apply(thisArg, _arguments || [])).next());
24 });
25};
26var __generator = (this && this.__generator) || function (thisArg, body) {
27 var _ = { label: 0, sent: function() { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g;
28 return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function() { return this; }), g;
29 function verb(n) { return function (v) { return step([n, v]); }; }
30 function step(op) {
31 if (f) throw new TypeError("Generator is already executing.");
32 while (_) try {
33 if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t;
34 if (y = 0, t) op = [op[0] & 2, t.value];
35 switch (op[0]) {
36 case 0: case 1: t = op; break;
37 case 4: _.label++; return { value: op[1], done: false };
38 case 5: _.label++; y = op[1]; op = [0]; continue;
39 case 7: op = _.ops.pop(); _.trys.pop(); continue;
40 default:
41 if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; }
42 if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; }
43 if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; }
44 if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; }
45 if (t[2]) _.ops.pop();
46 _.trys.pop(); continue;
47 }
48 op = body.call(thisArg, _);
49 } catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; }
50 if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true };
51 }
52};
53Object.defineProperty(exports, "__esModule", { value: true });
54var tfjs_1 = require("@tensorflow/tfjs");
55var fs = require("fs");
56var util_1 = require("util");
57var nodejs_kernel_backend_1 = require("./nodejs_kernel_backend");
58var readFile = util_1.promisify(fs.readFile);
59// tslint:disable-next-line:no-require-imports
60var messages = require('./proto/api_pb');
61var SAVED_MODEL_FILE_NAME = '/saved_model.pb';
62var SAVED_MODEL_INIT_OP_KEY = '__saved_model_init_op';
63// This map is used to keep track of loaded SavedModel metagraph mapping
64// information. The map key is TFSavedModel id in JavaScript, value is
65// an object of path to the SavedModel, metagraph tags, and loaded Session ID in
66// the c++ bindings. When user loads a SavedModel signature, it will go through
67// entries in this map to find if the corresponding SavedModel session has
68// already been loaded in C++ addon and will reuse it if existing.
69var loadedSavedModelPathMap = new Map();
70// The ID of loaded TFSavedModel. This ID is used to keep track of loaded
71// TFSavedModel, so the loaded session in c++ bindings for the corresponding
72// TFSavedModel can be properly reused/disposed.
73var nextTFSavedModelId = 0;
74/**
75 * Get a key in an object by its value. This is used to get protobuf enum value
76 * from index.
77 *
78 * @param object
79 * @param value
80 */
81// tslint:disable-next-line:no-any
82function getEnumKeyFromValue(object, value) {
83 return Object.keys(object).find(function (key) { return object[key] === value; });
84}
85exports.getEnumKeyFromValue = getEnumKeyFromValue;
86/**
87 * Read SavedModel proto message from path.
88 *
89 * @param path Path to SavedModel folder.
90 */
91function readSavedModelProto(path) {
92 return __awaiter(this, void 0, void 0, function () {
93 var modelFile, array;
94 return __generator(this, function (_a) {
95 switch (_a.label) {
96 case 0:
97 // Load the SavedModel pb file and deserialize it into message.
98 try {
99 fs.accessSync(path + SAVED_MODEL_FILE_NAME, fs.constants.R_OK);
100 }
101 catch (error) {
102 throw new Error('There is no saved_model.pb file in the directory: ' + path);
103 }
104 return [4 /*yield*/, readFile(path + SAVED_MODEL_FILE_NAME)];
105 case 1:
106 modelFile = _a.sent();
107 array = new Uint8Array(modelFile);
108 return [2 /*return*/, messages.SavedModel.deserializeBinary(array)];
109 }
110 });
111 });
112}
113exports.readSavedModelProto = readSavedModelProto;
114/**
115 * Inspect the MetaGraphs of the SavedModel from the provided path. This
116 * function will return an array of `MetaGraphInfo` objects.
117 *
118 * @param path Path to SavedModel folder.
119 *
120 * @doc {heading: 'Models', subheading: 'SavedModel', namespace: 'node'}
121 */
122function getMetaGraphsFromSavedModel(path) {
123 return __awaiter(this, void 0, void 0, function () {
124 var result, modelMessage, metaGraphList, i, metaGraph, tags, signatureDef, signatureDefMap, signatureDefKeys, key, signatureDefEntry, inputsMapMessage, inputsMapKeys, inputs, inputsMapKey, inputTensor, inputTensorInfo, outputsMapMessage, outputsMapKeys, outputs, outputsMapKey, outputTensor, outputTensorInfo;
125 return __generator(this, function (_a) {
126 switch (_a.label) {
127 case 0:
128 result = [];
129 return [4 /*yield*/, readSavedModelProto(path)];
130 case 1:
131 modelMessage = _a.sent();
132 metaGraphList = modelMessage.getMetaGraphsList();
133 for (i = 0; i < metaGraphList.length; i++) {
134 metaGraph = {};
135 tags = metaGraphList[i].getMetaInfoDef().getTagsList();
136 metaGraph.tags = tags;
137 signatureDef = {};
138 signatureDefMap = metaGraphList[i].getSignatureDefMap();
139 signatureDefKeys = signatureDefMap.keys();
140 // Go through all signatureDefs
141 while (true) {
142 key = signatureDefKeys.next();
143 if (key.done) {
144 break;
145 }
146 // Skip TensorFlow internal Signature '__saved_model_init_op'.
147 if (key.value === SAVED_MODEL_INIT_OP_KEY) {
148 continue;
149 }
150 signatureDefEntry = signatureDefMap.get(key.value);
151 inputsMapMessage = signatureDefEntry.getInputsMap();
152 inputsMapKeys = inputsMapMessage.keys();
153 inputs = {};
154 while (true) {
155 inputsMapKey = inputsMapKeys.next();
156 if (inputsMapKey.done) {
157 break;
158 }
159 inputTensor = inputsMapMessage.get(inputsMapKey.value);
160 inputTensorInfo = {};
161 inputTensorInfo.dtype = mapTFDtypeToJSDtype(getEnumKeyFromValue(messages.DataType, inputTensor.getDtype()));
162 inputTensorInfo.name = inputTensor.getName();
163 inputTensorInfo.shape = inputTensor.getTensorShape().getDimList();
164 inputs[inputsMapKey.value] = inputTensorInfo;
165 }
166 outputsMapMessage = signatureDefEntry.getOutputsMap();
167 outputsMapKeys = outputsMapMessage.keys();
168 outputs = {};
169 while (true) {
170 outputsMapKey = outputsMapKeys.next();
171 if (outputsMapKey.done) {
172 break;
173 }
174 outputTensor = outputsMapMessage.get(outputsMapKey.value);
175 outputTensorInfo = {};
176 outputTensorInfo.dtype = mapTFDtypeToJSDtype(getEnumKeyFromValue(messages.DataType, outputTensor.getDtype()));
177 outputTensorInfo.name = outputTensor.getName();
178 outputTensorInfo.shape = outputTensor.getTensorShape().getDimList();
179 outputs[outputsMapKey.value] = outputTensorInfo;
180 }
181 signatureDef[key.value] = { inputs: inputs, outputs: outputs };
182 }
183 metaGraph.signatureDefs = signatureDef;
184 result.push(metaGraph);
185 }
186 return [2 /*return*/, result];
187 }
188 });
189 });
190}
191exports.getMetaGraphsFromSavedModel = getMetaGraphsFromSavedModel;
192/**
193 * Get input and output node names from SavedModel metagraphs info. The
194 * input.output node names will be used when executing a SavedModel signature.
195 *
196 * @param savedModelInfo The MetaGraphInfo array loaded through
197 * getMetaGraphsFromSavedModel().
198 * @param tags The tags of the MetaGraph to get input/output node names from.
199 * @param signature The signature to get input/output node names from.
200 */
201function getInputAndOutputNodeNameFromMetaGraphInfo(savedModelInfo, tags, signature) {
202 for (var i = 0; i < savedModelInfo.length; i++) {
203 var metaGraphInfo = savedModelInfo[i];
204 if (stringArraysHaveSameElements(tags, metaGraphInfo.tags)) {
205 if (metaGraphInfo.signatureDefs[signature] == null) {
206 throw new Error('The SavedModel does not have signature: ' + signature);
207 }
208 var inputNodeNames = {};
209 var outputNodeNames = {};
210 for (var _i = 0, _a = Object.keys(metaGraphInfo.signatureDefs); _i < _a.length; _i++) {
211 var signatureDef = _a[_i];
212 if (signatureDef === signature) {
213 for (var _b = 0, _c = Object.keys(metaGraphInfo.signatureDefs[signature].inputs); _b < _c.length; _b++) {
214 var tensorName = _c[_b];
215 inputNodeNames[tensorName] =
216 metaGraphInfo.signatureDefs[signature].inputs[tensorName].name;
217 }
218 for (var _d = 0, _e = Object.keys(metaGraphInfo.signatureDefs[signature].outputs); _d < _e.length; _d++) {
219 var tensorName = _e[_d];
220 outputNodeNames[tensorName] =
221 metaGraphInfo.signatureDefs[signature].outputs[tensorName].name;
222 }
223 }
224 }
225 return [inputNodeNames, outputNodeNames];
226 }
227 }
228 throw new Error("The SavedModel does not have tags: " + tags);
229}
230exports.getInputAndOutputNodeNameFromMetaGraphInfo = getInputAndOutputNodeNameFromMetaGraphInfo;
231/**
232 * A `tf.TFSavedModel` is a signature loaded from a SavedModel
233 * metagraph, and allows inference execution.
234 *
235 * @doc {heading: 'Models', subheading: 'SavedModel', namespace: 'node'}
236 */
237var TFSavedModel = /** @class */ (function () {
238 function TFSavedModel(sessionId, jsid, inputNodeNames, outputNodeNames, backend) {
239 this.sessionId = sessionId;
240 this.jsid = jsid;
241 this.inputNodeNames = inputNodeNames;
242 this.outputNodeNames = outputNodeNames;
243 this.backend = backend;
244 this.disposed = false;
245 }
246 Object.defineProperty(TFSavedModel.prototype, "inputs", {
247 /**
248 * Return the array of input tensor info.
249 *
250 * @doc {heading: 'Models', subheading: 'SavedModel'}
251 */
252 get: function () {
253 throw new Error('SavedModel inputs information is not available yet.');
254 },
255 enumerable: true,
256 configurable: true
257 });
258 Object.defineProperty(TFSavedModel.prototype, "outputs", {
259 /**
260 * Return the array of output tensor info.
261 *
262 * @doc {heading: 'Models', subheading: 'SavedModel'}
263 */
264 get: function () {
265 throw new Error('SavedModel outputs information is not available yet.');
266 },
267 enumerable: true,
268 configurable: true
269 });
270 /**
271 * Delete the SavedModel from nodeBackend and delete corresponding session in
272 * the C++ backend if the session is only used by this TFSavedModel.
273 *
274 * @doc {heading: 'Models', subheading: 'SavedModel'}
275 */
276 TFSavedModel.prototype.dispose = function () {
277 if (!this.disposed) {
278 this.disposed = true;
279 loadedSavedModelPathMap.delete(this.jsid);
280 for (var _i = 0, _a = Array.from(loadedSavedModelPathMap.keys()); _i < _a.length; _i++) {
281 var id = _a[_i];
282 var value = loadedSavedModelPathMap.get(id);
283 if (value.sessionId === this.sessionId) {
284 return;
285 }
286 }
287 this.backend.deleteSavedModel(this.sessionId);
288 }
289 else {
290 throw new Error('This SavedModel has already been deleted.');
291 }
292 };
293 /**
294 * Execute the inference for the input tensors.
295 *
296 * @param input The input tensors, when there is single input for the model,
297 * inputs param should be a Tensor. For models with multiple inputs, inputs
298 * params should be in either Tensor[] if the input order is fixed, or
299 * otherwise NamedTensorMap format. The keys in the NamedTensorMap are the
300 * name of input tensors in SavedModel signatureDef. It can be found through
301 * `tf.node.getMetaGraphsFromSavedModel()`.
302 *
303 * For batch inference execution, the tensors for each input need to be
304 * concatenated together. For example with mobilenet, the required input shape
305 * is [1, 244, 244, 3], which represents the [batch, height, width, channel].
306 * If we are provide a batched data of 100 images, the input tensor should be
307 * in the shape of [100, 244, 244, 3].
308 *
309 * @param config Prediction configuration for specifying the batch size.
310 *
311 * @returns Inference result tensors. The output would be single Tensor if
312 * model has single output node, otherwise Tensor[] or NamedTensorMap[] will
313 * be returned for model with multiple outputs.
314 *
315 * @doc {heading: 'Models', subheading: 'SavedModel'}
316 */
317 TFSavedModel.prototype.predict = function (inputs, config) {
318 var _this = this;
319 if (this.disposed) {
320 throw new Error('The TFSavedModel has already been deleted!');
321 }
322 else {
323 var inputTensors = [];
324 if (inputs instanceof tfjs_1.Tensor) {
325 inputTensors.push(inputs);
326 var result = this.backend.runSavedModel(this.sessionId, inputTensors, Object.values(this.inputNodeNames), Object.values(this.outputNodeNames));
327 return result.length > 1 ? result : result[0];
328 }
329 else if (Array.isArray(inputs)) {
330 inputTensors = inputs;
331 return this.backend.runSavedModel(this.sessionId, inputTensors, Object.values(this.inputNodeNames), Object.values(this.outputNodeNames));
332 }
333 else {
334 var inputTensorNames = Object.keys(this.inputNodeNames);
335 var providedInputNames = Object.keys(inputs);
336 if (!stringArraysHaveSameElements(inputTensorNames, providedInputNames)) {
337 throw new Error("The model signatureDef input names are " + inputTensorNames.join() + ", however the provided input names are " + providedInputNames.join() + ".");
338 }
339 var inputNodeNamesArray = [];
340 for (var i = 0; i < inputTensorNames.length; i++) {
341 inputTensors.push(inputs[inputTensorNames[i]]);
342 inputNodeNamesArray.push(this.inputNodeNames[inputTensorNames[i]]);
343 }
344 var outputTensorNames = Object.keys(this.outputNodeNames);
345 var outputNodeNamesArray = [];
346 for (var i = 0; i < outputTensorNames.length; i++) {
347 outputNodeNamesArray.push(this.outputNodeNames[outputTensorNames[i]]);
348 }
349 var outputTensors_1 = this.backend.runSavedModel(this.sessionId, inputTensors, inputNodeNamesArray, outputNodeNamesArray);
350 tfjs_1.util.assert(outputTensors_1.length === outputNodeNamesArray.length, function () { return 'Output tensors do not match output node names, ' +
351 ("receive " + outputTensors_1.length + ") output tensors but ") +
352 ("there are " + _this.outputNodeNames.length + " output nodes."); });
353 var outputMap = {};
354 for (var i = 0; i < outputTensorNames.length; i++) {
355 outputMap[outputTensorNames[i]] = outputTensors_1[i];
356 }
357 return outputMap;
358 }
359 }
360 };
361 /**
362 * Execute the inference for the input tensors and return activation
363 * values for specified output node names without batching.
364 *
365 * @param input The input tensors, when there is single input for the model,
366 * inputs param should be a Tensor. For models with multiple inputs, inputs
367 * params should be in either Tensor[] if the input order is fixed, or
368 * otherwise NamedTensorMap format.
369 *
370 * @param outputs string|string[]. List of output node names to retrieve
371 * activation from.
372 *
373 * @returns Activation values for the output nodes result tensors. The return
374 * type matches specified parameter outputs type. The output would be single
375 * Tensor if single output is specified, otherwise Tensor[] for multiple
376 * outputs.
377 *
378 * @doc {heading: 'Models', subheading: 'SavedModel'}
379 */
380 TFSavedModel.prototype.execute = function (inputs, outputs) {
381 throw new Error('execute() of TFSavedModel is not supported yet.');
382 };
383 return TFSavedModel;
384}());
385exports.TFSavedModel = TFSavedModel;
386/**
387 * Load a TensorFlow SavedModel from disk. TensorFlow SavedModel is different
388 * from TensorFlow.js model format. A SavedModel is a directory containing
389 * serialized signatures and the states needed to run them. The directory has a
390 * saved_model.pb (or saved_model.pbtxt) file storing the actual TensorFlow
391 * program, or model, and a set of named signatures, each identifying a
392 * function. The directory also has a variables directory contains a standard
393 * training checkpoint. The directory may also has a assets directory contains
394 * files used by the TensorFlow graph, for example text files used to initialize
395 * vocabulary tables. These are supported datatypes: float32, int32, complex64,
396 * string.For more information, see this guide:
397 * https://www.tensorflow.org/guide/saved_model.
398 *
399 * @param path The path to the SavedModel.
400 * @param tags The tags of the MetaGraph to load. The available tags of a
401 * SavedModel can be retrieved through tf.node.getMetaGraphsFromSavedModel()
402 * API. Defaults to ['serve'].
403 * @param signature The name of the SignatureDef to load. The available
404 * SignatureDefs of a SavedModel can be retrieved through
405 * tf.node.getMetaGraphsFromSavedModel() API. Defaults to 'serving_default'.
406 *
407 * @doc {heading: 'Models', subheading: 'SavedModel', namespace: 'node'}
408 */
409function loadSavedModel(path, tags, signature) {
410 if (tags === void 0) { tags = ['serve']; }
411 if (signature === void 0) { signature = 'serving_default'; }
412 return __awaiter(this, void 0, void 0, function () {
413 var backend, savedModelInfo, _a, inputNodeNames, outputNodeNames, sessionId, _i, _b, id_1, modelInfo, tagsString, id, savedModel;
414 return __generator(this, function (_c) {
415 switch (_c.label) {
416 case 0:
417 nodejs_kernel_backend_1.ensureTensorflowBackend();
418 backend = nodejs_kernel_backend_1.nodeBackend();
419 return [4 /*yield*/, getMetaGraphsFromSavedModel(path)];
420 case 1:
421 savedModelInfo = _c.sent();
422 _a = getInputAndOutputNodeNameFromMetaGraphInfo(savedModelInfo, tags, signature), inputNodeNames = _a[0], outputNodeNames = _a[1];
423 for (_i = 0, _b = Array.from(loadedSavedModelPathMap.keys()); _i < _b.length; _i++) {
424 id_1 = _b[_i];
425 modelInfo = loadedSavedModelPathMap.get(id_1);
426 if (modelInfo.path === path &&
427 stringArraysHaveSameElements(modelInfo.tags, tags)) {
428 sessionId = modelInfo.sessionId;
429 }
430 }
431 if (sessionId == null) {
432 tagsString = tags.join(',');
433 sessionId = backend.loadSavedModelMetaGraph(path, tagsString);
434 }
435 id = nextTFSavedModelId++;
436 savedModel = new TFSavedModel(sessionId, id, inputNodeNames, outputNodeNames, backend);
437 loadedSavedModelPathMap.set(id, { path: path, tags: tags, sessionId: sessionId });
438 return [2 /*return*/, savedModel];
439 }
440 });
441 });
442}
443exports.loadSavedModel = loadSavedModel;
444/**
445 * Compare if two unsorted arrays of string have the same elements.
446 * @param arrayA
447 * @param arrayB
448 */
449function stringArraysHaveSameElements(arrayA, arrayB) {
450 if (arrayA.length === arrayB.length &&
451 arrayA.sort().join() === arrayB.sort().join()) {
452 return true;
453 }
454 return false;
455}
456function mapTFDtypeToJSDtype(tfDtype) {
457 switch (tfDtype) {
458 case 'DT_FLOAT':
459 return 'float32';
460 case 'DT_INT32':
461 return 'int32';
462 case 'DT_BOOL':
463 return 'bool';
464 case 'DT_COMPLEX64':
465 return 'complex64';
466 case 'DT_STRING':
467 return 'string';
468 default:
469 throw new Error('Unsupported tensor DataType: ' + tfDtype +
470 ', try to modify the model in python to convert the datatype');
471 }
472}
473function getNumOfSavedModels() {
474 nodejs_kernel_backend_1.ensureTensorflowBackend();
475 var backend = nodejs_kernel_backend_1.nodeBackend();
476 return backend.getNumOfSavedModels();
477}
478exports.getNumOfSavedModels = getNumOfSavedModels;