1 | "use strict";
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 |
|
12 |
|
13 |
|
14 |
|
15 |
|
16 |
|
17 |
|
// TypeScript-emitted helper that runs a generator as if it were an async
// function. If an `__awaiter` is already defined on `this`, that one is reused.
//   thisArg     - receiver the generator function is applied to.
//   _arguments  - arguments passed to the generator function ([] when absent).
//   P           - Promise constructor to use (defaults to the global Promise).
//   generator   - generator function whose yields act as `await` points.
// Returns a P-promise resolved with the generator's return value, or rejected
// with any error it throws.
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
    return new (P || (P = Promise))(function (resolve, reject) {
        // Resume the generator with an awaited value that fulfilled.
        function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
        // Resume the generator by throwing an awaited value's rejection reason into it.
        function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
        // Finished: resolve with the return value. Otherwise wrap the yielded
        // value in P so plain values and thenables are both awaited.
        function step(result) { result.done ? resolve(result.value) : new P(function (resolve) { resolve(result.value); }).then(fulfilled, rejected); }
        // Start the generator; a synchronous throw here rejects via the P constructor.
        step((generator = generator.apply(thisArg, _arguments || [])).next());
    });
};
|
// TypeScript-emitted helper that emulates a generator object with a state
// machine. `body` is called repeatedly with the state object `_`. Each call
// returns an instruction array [opcode, operand]; this file's callers return
// [4, x] to yield/await x and [2, x] to return x.
// Returns an iterator-like object `g` with next/throw/return methods
// (op codes 0/1/2).
var __generator = (this && this.__generator) || function (thisArg, body) {
    // _.label: resume position inside body; _.sent(): value passed in by next(),
    // or rethrows the error passed in by throw(); _.trys / _.ops: stacks used for
    // try/catch/finally emulation. f: "currently executing" flag. y: delegate
    // iterator for op 5. t: scratch. g: the generator object returned.
    var _ = { label: 0, sent: function() { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g;
    return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function() { return this; }), g;
    // Builds next/throw/return: each drives the machine with [opcode, value].
    function verb(n) { return function (v) { return step([n, v]); }; }
    function step(op) {
        // Re-entrant calls (e.g. next() from inside body) are rejected.
        if (f) throw new TypeError("Generator is already executing.");
        // `_` is set to 0 once the machine has completed, ending the loop.
        while (_) try {
            // While delegating (op 5), forward next/throw/return to the inner iterator y.
            if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t;
            if (y = 0, t) op = [op[0] & 2, t.value];
            switch (op[0]) {
                // next/throw: remember the op so body can read it via _.sent().
                case 0: case 1: t = op; break;
                // yield: advance the label and hand the value to the caller.
                case 4: _.label++; return { value: op[1], done: false };
                // delegate: start forwarding to the iterator in op[1].
                case 5: _.label++; y = op[1]; op = [0]; continue;
                // end of a finally region: resume the pending op, pop the try entry.
                case 7: op = _.ops.pop(); _.trys.pop(); continue;
                default:
                    // t = innermost try entry. Its slots are read below as bounds
                    // [0]/[3], catch label [1] and finally label [2].
                    // With no try entry, a throw (6) or return (2) completes the machine.
                    if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; }
                    // op 3: jump to label op[1] when no try or the target is inside it.
                    if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; }
                    // exception: jump to the catch label if not yet past it.
                    if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; }
                    // otherwise run the finally block, saving op to resume afterwards.
                    if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; }
                    if (t[2]) _.ops.pop();
                    _.trys.pop(); continue;
            }
            op = body.call(thisArg, _);
        // Errors thrown by body become op 6 so the try/catch emulation can route them.
        } catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; }
        // Finished: uncaught throw (op 1 or 6) propagates; otherwise report done.
        if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true };
    }
};
|
53 | Object.defineProperty(exports, "__esModule", { value: true });
|
54 | var tfjs_1 = require("@tensorflow/tfjs");
|
55 | var fs = require("fs");
|
56 | var util_1 = require("util");
|
57 | var nodejs_kernel_backend_1 = require("./nodejs_kernel_backend");
|
58 | var readFile = util_1.promisify(fs.readFile);
|
59 |
|
60 | var messages = require('./proto/api_pb');
|
// Suffix appended to a SavedModel directory path to reach the serialized proto.
var SAVED_MODEL_FILE_NAME = '/saved_model.pb';
// Signature key that getMetaGraphsFromSavedModel skips when collecting
// signatures (presumably the model's init-op entry, not a servable signature).
var SAVED_MODEL_INIT_OP_KEY = '__saved_model_init_op';
// Registry of live TFSavedModel instances: key = instance jsid,
// value = { path, tags, sessionId }. loadSavedModel consults it to reuse a
// backend session; TFSavedModel.dispose uses it to decide when to free one.
var loadedSavedModelPathMap = new Map();
// Counter handing out a fresh jsid to every TFSavedModel created by loadSavedModel.
var nextTFSavedModelId = 0;
|
74 |
|
75 |
|
76 |
|
77 |
|
78 |
|
79 |
|
80 |
|
81 |
|
/**
 * Reverse lookup for enum-like objects.
 *
 * @param {Object} object - enum object mapping names to values.
 * @param {*} value - value to search for (compared with ===).
 * @returns {string|undefined} the first own enumerable key (in Object.keys
 *     order) whose value is `value`, or undefined if none matches.
 */
function getEnumKeyFromValue(object, value) {
  for (const candidate of Object.keys(object)) {
    if (object[candidate] === value) {
      return candidate;
    }
  }
  return undefined;
}
|
85 | exports.getEnumKeyFromValue = getEnumKeyFromValue;
|
86 |
|
87 |
|
88 |
|
89 |
|
90 |
|
/**
 * Reads and deserializes the `saved_model.pb` file of a SavedModel directory.
 *
 * @param {string} path - SavedModel directory; SAVED_MODEL_FILE_NAME is
 *     appended to it verbatim.
 * @returns {Promise<*>} the messages.SavedModel protobuf message.
 * @throws {Error} (as a rejection) when the file is not readable. The
 *     underlying fs error is attached as `cause`.
 */
async function readSavedModelProto(path) {
  const modelFilePath = path + SAVED_MODEL_FILE_NAME;
  try {
    fs.accessSync(modelFilePath, fs.constants.R_OK);
  } catch (error) {
    // Keep the original fs error so callers can tell ENOENT from EACCES, etc.
    // (Runtimes without Error `cause` support simply ignore the option.)
    throw new Error('There is no saved_model.pb file in the directory: ' + path, { cause: error });
  }
  const modelFile = await readFile(modelFilePath);
  return messages.SavedModel.deserializeBinary(new Uint8Array(modelFile));
}
|
113 | exports.readSavedModelProto = readSavedModelProto;
|
114 |
|
115 |
|
116 |
|
117 |
|
118 |
|
119 |
|
120 |
|
121 |
|
122 |
|
/**
 * Reads the SavedModel at `path` and summarizes every MetaGraph it contains.
 *
 * @param {string} path - SavedModel directory (must contain saved_model.pb).
 * @returns {Promise<Array<{tags: string[], signatureDefs: Object}>>} one entry
 *     per MetaGraph, in file order. `signatureDefs` maps each signature key
 *     (except SAVED_MODEL_INIT_OP_KEY) to `{inputs, outputs}`. Each of those
 *     maps a tensor alias to `{dtype, name, shape}`, where `shape` is the
 *     proto's dim list as returned by getDimList().
 * @throws {Error} (as a rejection) from readSavedModelProto, or from
 *     mapTFDtypeToJSDtype for an unsupported tensor dtype.
 */
async function getMetaGraphsFromSavedModel(path) {
  // Walks a protobuf map of TensorInfo messages through its keys() iterator
  // and builds a plain {alias: {dtype, name, shape}} object from it.
  function describeTensorInfoMap(tensorInfoMap) {
    const described = {};
    const aliases = tensorInfoMap.keys();
    for (let cursor = aliases.next(); !cursor.done; cursor = aliases.next()) {
      const tensorInfo = tensorInfoMap.get(cursor.value);
      described[cursor.value] = {
        dtype: mapTFDtypeToJSDtype(getEnumKeyFromValue(messages.DataType, tensorInfo.getDtype())),
        name: tensorInfo.getName(),
        shape: tensorInfo.getTensorShape().getDimList(),
      };
    }
    return described;
  }

  const modelMessage = await readSavedModelProto(path);
  const metaGraphList = modelMessage.getMetaGraphsList();
  const metaGraphs = [];
  for (let index = 0; index < metaGraphList.length; index++) {
    const metaGraphMessage = metaGraphList[index];
    const tags = metaGraphMessage.getMetaInfoDef().getTagsList();
    const signatureDefMap = metaGraphMessage.getSignatureDefMap();
    const signatureDefs = {};
    const signatureKeys = signatureDefMap.keys();
    for (let cursor = signatureKeys.next(); !cursor.done; cursor = signatureKeys.next()) {
      // The SAVED_MODEL_INIT_OP_KEY entry is deliberately left out of the summary.
      if (cursor.value === SAVED_MODEL_INIT_OP_KEY) {
        continue;
      }
      const signatureDefMessage = signatureDefMap.get(cursor.value);
      const inputs = describeTensorInfoMap(signatureDefMessage.getInputsMap());
      const outputs = describeTensorInfoMap(signatureDefMessage.getOutputsMap());
      signatureDefs[cursor.value] = { inputs: inputs, outputs: outputs };
    }
    metaGraphs.push({ tags: tags, signatureDefs: signatureDefs });
  }
  return metaGraphs;
}
|
192 | exports.getMetaGraphsFromSavedModel = getMetaGraphsFromSavedModel;
|
193 |
|
194 |
|
195 |
|
196 |
|
197 |
|
198 |
|
199 |
|
200 |
|
201 |
|
/**
 * Finds the MetaGraph whose tag set equals `tags` and returns the graph tensor
 * names for the inputs and outputs of `signature` within it.
 *
 * @param {Array<{tags: string[], signatureDefs: Object}>} savedModelInfo -
 *     MetaGraph summaries in the shape built by getMetaGraphsFromSavedModel.
 * @param {string[]} tags - tag set to match (order-insensitive).
 * @param {string} signature - signature key to look up.
 * @returns {Array<Object<string, string>>} `[inputNodeNames, outputNodeNames]`.
 *     Each maps a signature tensor alias to the tensor's `name`.
 * @throws {Error} if the first MetaGraph matching `tags` lacks `signature`
 *     (as an own key), or if no MetaGraph matches `tags`.
 */
function getInputAndOutputNodeNameFromMetaGraphInfo(savedModelInfo, tags, signature) {
  for (const metaGraphInfo of savedModelInfo) {
    if (!stringArraysHaveSameElements(tags, metaGraphInfo.tags)) {
      continue;
    }
    const signatureDefs = metaGraphInfo.signatureDefs;
    // Own-property check: a bare `signatureDefs[signature]` would also accept
    // inherited keys such as 'toString' and then silently return empty maps.
    if (!Object.prototype.hasOwnProperty.call(signatureDefs, signature) ||
        signatureDefs[signature] == null) {
      throw new Error('The SavedModel does not have signature: ' + signature);
    }
    const { inputs, outputs } = signatureDefs[signature];
    const inputNodeNames = {};
    for (const alias of Object.keys(inputs)) {
      inputNodeNames[alias] = inputs[alias].name;
    }
    const outputNodeNames = {};
    for (const alias of Object.keys(outputs)) {
      outputNodeNames[alias] = outputs[alias].name;
    }
    return [inputNodeNames, outputNodeNames];
  }
  throw new Error("The SavedModel does not have tags: " + tags);
}
|
231 | exports.getInputAndOutputNodeNameFromMetaGraphInfo = getInputAndOutputNodeNameFromMetaGraphInfo;
|
232 |
|
233 |
|
234 |
|
235 |
|
236 |
|
237 |
|
238 |
|
/**
 * Handle to a SavedModel session loaded into the TensorFlow node backend.
 *
 * Instances are created by loadSavedModel(). Several instances may share one
 * backend session; the session is released only when the last instance
 * registered in loadedSavedModelPathMap with that sessionId is disposed.
 */
var TFSavedModel = (function () {
  /**
   * @param {*} sessionId - backend session id (as returned by
   *     backend.loadSavedModelMetaGraph in loadSavedModel).
   * @param {number} jsid - this instance's key in loadedSavedModelPathMap.
   * @param {Object<string, string>} inputNodeNames - input alias -> graph tensor name.
   * @param {Object<string, string>} outputNodeNames - output alias -> graph tensor name.
   * @param {*} backend - node backend exposing runSavedModel/deleteSavedModel.
   */
  function TFSavedModel(sessionId, jsid, inputNodeNames, outputNodeNames, backend) {
    this.sessionId = sessionId;
    this.jsid = jsid;
    this.inputNodeNames = inputNodeNames;
    this.outputNodeNames = outputNodeNames;
    this.backend = backend;
    this.disposed = false;
  }

  // Signature input metadata is not exposed yet; reading it always throws.
  Object.defineProperty(TFSavedModel.prototype, "inputs", {
    get: function () {
      throw new Error('SavedModel inputs information is not available yet.');
    },
    enumerable: true,
    configurable: true
  });

  // Signature output metadata is not exposed yet; reading it always throws.
  Object.defineProperty(TFSavedModel.prototype, "outputs", {
    get: function () {
      throw new Error('SavedModel outputs information is not available yet.');
    },
    enumerable: true,
    configurable: true
  });

  /**
   * Unregisters this instance and deletes the backend session if no other
   * registered instance still uses it.
   * @throws {Error} if this instance was already disposed.
   */
  TFSavedModel.prototype.dispose = function () {
    if (this.disposed) {
      throw new Error('This SavedModel has already been deleted.');
    }
    this.disposed = true;
    loadedSavedModelPathMap.delete(this.jsid);
    for (const modelInfo of loadedSavedModelPathMap.values()) {
      // Another live instance shares the session; keep it alive.
      if (modelInfo.sessionId === this.sessionId) {
        return;
      }
    }
    this.backend.deleteSavedModel(this.sessionId);
  };

  /**
   * Runs the model's signature on the given inputs.
   *
   * - A single Tensor is fed to the input node(s). The result is one output,
   *   or an array of outputs if the backend returns more than one.
   * - An array of tensors is fed positionally, in Object.values(inputNodeNames)
   *   order. The backend's result array is returned as-is.
   * - An object keyed by input alias must name exactly the signature's inputs.
   *   The result is an object keyed by output alias.
   *
   * @param {*} inputs - Tensor, Tensor[] or {alias: Tensor}.
   * @param {*} config - currently unused.
   * @throws {Error} if disposed, if named inputs do not match the signature,
   *     or if the backend returns a different number of outputs than expected.
   */
  TFSavedModel.prototype.predict = function (inputs, config) {
    if (this.disposed) {
      throw new Error('The TFSavedModel has already been deleted!');
    }
    if (inputs instanceof tfjs_1.Tensor) {
      const result = this.backend.runSavedModel(this.sessionId, [inputs], Object.values(this.inputNodeNames), Object.values(this.outputNodeNames));
      return result.length > 1 ? result : result[0];
    }
    if (Array.isArray(inputs)) {
      return this.backend.runSavedModel(this.sessionId, inputs, Object.values(this.inputNodeNames), Object.values(this.outputNodeNames));
    }

    const inputTensorNames = Object.keys(this.inputNodeNames);
    const providedInputNames = Object.keys(inputs);
    if (!stringArraysHaveSameElements(inputTensorNames, providedInputNames)) {
      throw new Error("The model signatureDef input names are " + inputTensorNames.join() + ", however the provided input names are " + providedInputNames.join() + ".");
    }
    // Tensors and node names are built from the same alias list, so they stay paired.
    const inputTensors = inputTensorNames.map((alias) => inputs[alias]);
    const inputNodeNamesArray = inputTensorNames.map((alias) => this.inputNodeNames[alias]);
    const outputTensorNames = Object.keys(this.outputNodeNames);
    const outputNodeNamesArray = outputTensorNames.map((alias) => this.outputNodeNames[alias]);

    const outputTensors = this.backend.runSavedModel(this.sessionId, inputTensors, inputNodeNamesArray, outputNodeNamesArray);
    // outputNodeNames is an object (no .length); report the requested node count.
    tfjs_1.util.assert(outputTensors.length === outputNodeNamesArray.length, () => 'Output tensors do not match output node names, ' +
        ('received ' + outputTensors.length + ' output tensors but ') +
        ('there are ' + outputNodeNamesArray.length + ' output nodes.'));

    const outputMap = {};
    outputTensorNames.forEach((alias, i) => {
      outputMap[alias] = outputTensors[i];
    });
    return outputMap;
  };

  /** Not supported yet; always throws. */
  TFSavedModel.prototype.execute = function (inputs, outputs) {
    throw new Error('execute() of TFSavedModel is not supported yet.');
  };

  return TFSavedModel;
}());
|
382 | exports.TFSavedModel = TFSavedModel;
|
383 |
|
384 |
|
385 |
|
386 |
|
387 |
|
388 |
|
389 |
|
390 |
|
391 |
|
392 |
|
393 |
|
394 |
|
395 |
|
396 |
|
397 |
|
398 |
|
399 |
|
400 |
|
401 |
|
402 |
|
403 |
|
404 |
|
/**
 * Loads a SavedModel from disk and returns a TFSavedModel bound to one
 * signature of the MetaGraph identified by `tags`.
 *
 * A backend session already loaded for the same path and tag set is reused;
 * otherwise a new one is created via backend.loadSavedModelMetaGraph.
 *
 * @param {string} path - SavedModel directory.
 * @param {string[]} [tags=['serve']] - MetaGraph tag set.
 * @param {string} [signature='serving_default'] - signature key.
 * @returns {Promise<TFSavedModel>}
 * @throws {Error} (as a rejection) if the model cannot be read, or if the
 *     tags or signature are not present.
 */
async function loadSavedModel(path, tags, signature) {
  const tagList = tags === undefined ? ['serve'] : tags;
  const signatureKey = signature === undefined ? 'serving_default' : signature;
  nodejs_kernel_backend_1.ensureTensorflowBackend();
  const backend = nodejs_kernel_backend_1.nodeBackend();
  const savedModelInfo = await getMetaGraphsFromSavedModel(path);
  const [inputNodeNames, outputNodeNames] =
      getInputAndOutputNodeNameFromMetaGraphInfo(savedModelInfo, tagList, signatureKey);

  // Reuse a session already loaded for this path + tag set (the last match wins).
  let sessionId;
  for (const modelInfo of loadedSavedModelPathMap.values()) {
    if (modelInfo.path === path && stringArraysHaveSameElements(modelInfo.tags, tagList)) {
      sessionId = modelInfo.sessionId;
    }
  }
  if (sessionId == null) {
    sessionId = backend.loadSavedModelMetaGraph(path, tagList.join(','));
  }

  const id = nextTFSavedModelId++;
  const savedModel = new TFSavedModel(sessionId, id, inputNodeNames, outputNodeNames, backend);
  // Store a private copy of the tags. Later mutation of the caller's array must
  // not change which cache entry matches, or get shared with other instances.
  loadedSavedModelPathMap.set(id, { path: path, tags: tagList.slice(), sessionId: sessionId });
  return savedModel;
}
|
439 | exports.loadSavedModel = loadSavedModel;
|
440 |
|
441 |
|
442 |
|
443 |
|
444 |
|
/**
 * Tests whether two string arrays contain the same elements with the same
 * multiplicities, ignoring order.
 *
 * Neither input is modified: sorting happens on copies. The previous in-place
 * sort reordered callers' arrays, including user-supplied tag lists. Elements
 * are compared one by one rather than via join(). join() made e.g.
 * ['a,b', 'c'] and ['a', 'b,c'] compare equal.
 *
 * @param {string[]} arrayA
 * @param {string[]} arrayB
 * @returns {boolean}
 */
function stringArraysHaveSameElements(arrayA, arrayB) {
  if (arrayA.length !== arrayB.length) {
    return false;
  }
  const sortedA = arrayA.slice().sort();
  const sortedB = arrayB.slice().sort();
  return sortedA.every((value, i) => value === sortedB[i]);
}
|
/**
 * Translates a TensorFlow DataType enum name into the tfjs dtype string.
 *
 * @param {string} tfDtype - DataType enum key, e.g. 'DT_FLOAT'.
 * @returns {string} one of 'float32', 'int32', 'bool', 'complex64', 'string'.
 * @throws {Error} for any other DataType.
 */
function mapTFDtypeToJSDtype(tfDtype) {
  // A Map (rather than a plain object) so keys like 'toString' never match.
  const jsDtypeByTfDtype = new Map([
    ['DT_FLOAT', 'float32'],
    ['DT_INT32', 'int32'],
    ['DT_BOOL', 'bool'],
    ['DT_COMPLEX64', 'complex64'],
    ['DT_STRING', 'string'],
  ]);
  if (!jsDtypeByTfDtype.has(tfDtype)) {
    throw new Error('Unsupported tensor DataType: ' + tfDtype +
        ', try to modify the model in python to convert the datatype');
  }
  return jsDtypeByTfDtype.get(tfDtype);
}
|
/**
 * Reports how many SavedModel sessions the node backend currently holds.
 * Makes sure the TensorFlow backend is active before querying it.
 *
 * @returns {number} the backend's getNumOfSavedModels() result.
 */
function getNumOfSavedModels() {
  nodejs_kernel_backend_1.ensureTensorflowBackend();
  return nodejs_kernel_backend_1.nodeBackend().getNumOfSavedModels();
}
|
474 | exports.getNumOfSavedModels = getNumOfSavedModels;
|