1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 |
|
12 |
|
13 |
|
14 |
|
15 |
|
16 |
|
17 |
|
18 | #include "tfjs_backend.h"
|
19 |
|
20 | #include <algorithm>
|
21 | #include <cstring>
|
22 | #include <memory>
|
23 | #include <set>
|
24 | #include <string>
|
25 | #include "napi_auto_ref.h"
|
26 | #include "tf_auto_tensor.h"
|
27 | #include "tfe_auto_op.h"
|
28 | #include "utils.h"
|
29 |
|
30 | namespace tfnodejs {
|
31 |
|
32 |
|
// Intern pool for op attribute names. AssignOpAttr hands the c_str() of an
// entry in this set to the TFE_OpSetAttr* calls; std::set never invalidates
// existing elements on insert, so those pointers remain valid afterwards.
// Nothing in this file removes entries.
// NOTE(review): accessed without any locking — assumes all calls come from
// the single JS thread; confirm.
static std::set<std::string> ATTR_NAME_SET{};
|
34 |
|
35 |
|
36 | static void DeallocTensor(void *data, size_t len, void *arg) {
|
37 | NapiAutoRef *auto_ref = static_cast<NapiAutoRef *>(arg);
|
38 | if (!auto_ref) {
|
39 | #if DEBUG
|
40 | fprintf(stderr, "Invalid NapiAutoRef reference passed to V8 cleanup\n");
|
41 | #endif
|
42 | return;
|
43 | }
|
44 | if (auto_ref->Cleanup() != napi_ok) {
|
45 | #if DEBUG
|
46 | fprintf(stderr, "Exception cleaning up napi_ref instance\n");
|
47 | #endif
|
48 | }
|
49 | delete auto_ref;
|
50 | }
|
51 |
|
52 |
|
53 | TFE_TensorHandle *CreateTFE_TensorHandleFromTypedArray(napi_env env,
|
54 | int64_t *shape,
|
55 | uint32_t shape_length,
|
56 | TF_DataType dtype,
|
57 | napi_value array_value) {
|
58 | napi_status nstatus;
|
59 | napi_typedarray_type array_type;
|
60 | size_t array_length;
|
61 | void *array_data;
|
62 | nstatus =
|
63 | napi_get_typedarray_info(env, array_value, &array_type, &array_length,
|
64 | &array_data, nullptr, nullptr);
|
65 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
66 |
|
67 |
|
68 |
|
69 | size_t width = 0;
|
70 | switch (array_type) {
|
71 | case napi_float32_array:
|
72 | if (dtype != TF_FLOAT) {
|
73 | NAPI_THROW_ERROR(env, "Tensor type does not match Float32Array");
|
74 | return nullptr;
|
75 | }
|
76 | width = sizeof(float);
|
77 | break;
|
78 | case napi_int32_array:
|
79 | if (dtype != TF_INT32 && dtype != TF_INT64) {
|
80 |
|
81 |
|
82 |
|
83 | NAPI_THROW_ERROR(env, "Tensor type does not match Int32Array");
|
84 | return nullptr;
|
85 | }
|
86 | width = sizeof(int32_t);
|
87 | break;
|
88 | case napi_uint8_array:
|
89 | if (dtype != TF_BOOL && dtype != TF_UINT8) {
|
90 | NAPI_THROW_ERROR(env, "Tensor type does not match Uint8Array");
|
91 | return nullptr;
|
92 | }
|
93 | width = sizeof(uint8_t);
|
94 | break;
|
95 | default:
|
96 | REPORT_UNKNOWN_TYPED_ARRAY_TYPE(env, array_type);
|
97 | return nullptr;
|
98 | }
|
99 |
|
100 |
|
101 | if (dtype == TF_INT64) {
|
102 |
|
103 |
|
104 |
|
105 | if (width * 2 != TF_DataTypeSize(dtype)) {
|
106 | NAPI_THROW_ERROR(
|
107 | env,
|
108 | "Byte size of elements differs between JavaScript VM "
|
109 | "(%zu * 2 = %zu) and TensorFlow (%zu) for int64-type tensor",
|
110 | width, width * 2, TF_DataTypeSize(dtype));
|
111 | return nullptr;
|
112 | }
|
113 | } else {
|
114 | if (width != TF_DataTypeSize(dtype)) {
|
115 | NAPI_THROW_ERROR(env,
|
116 | "Byte size of elements differs between JavaScript VM "
|
117 | "(%zu) and TensorFlow (%zu)",
|
118 | width, TF_DataTypeSize(dtype));
|
119 | return nullptr;
|
120 | }
|
121 | }
|
122 |
|
123 |
|
124 | size_t num_elements = 1;
|
125 | for (size_t i = 0; i < shape_length; i++) {
|
126 | num_elements *= shape[i];
|
127 | }
|
128 |
|
129 |
|
130 | if (dtype == TF_INT64) {
|
131 |
|
132 |
|
133 |
|
134 |
|
135 | if (array_length != num_elements * 2) {
|
136 | NAPI_THROW_ERROR(
|
137 | env,
|
138 | "Shape does not match two times typed-array in bindData() "
|
139 | "(num_elements * 2 = %zu, array_length=%zu) for int64 data type",
|
140 | num_elements * 2, array_length);
|
141 | return nullptr;
|
142 | }
|
143 | } else {
|
144 | if (num_elements != array_length) {
|
145 | NAPI_THROW_ERROR(env,
|
146 | "Shape does not match typed-array in bindData() "
|
147 | "(num_elements=%zu, array_length=%zu)",
|
148 | num_elements, array_length);
|
149 | return nullptr;
|
150 | }
|
151 | }
|
152 |
|
153 |
|
154 |
|
155 |
|
156 | NapiAutoRef *auto_ref = new NapiAutoRef();
|
157 | nstatus = auto_ref->Init(env, array_value);
|
158 | if (nstatus != napi_ok) {
|
159 | delete auto_ref;
|
160 | }
|
161 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
162 |
|
163 |
|
164 |
|
165 |
|
166 | const size_t byte_size =
|
167 | dtype == TF_INT64 ? num_elements * width * 2 : num_elements * width;
|
168 |
|
169 | TF_AutoTensor tensor(TF_NewTensor(dtype, shape, shape_length, array_data,
|
170 | byte_size, DeallocTensor, auto_ref));
|
171 |
|
172 | TF_AutoStatus tf_status;
|
173 | TFE_TensorHandle *tfe_tensor_handle =
|
174 | TFE_NewTensorHandle(tensor.tensor, tf_status.status);
|
175 | if (TF_GetCode(tf_status.status) != TF_OK) {
|
176 | delete auto_ref;
|
177 | TFE_DeleteTensorHandle(tfe_tensor_handle);
|
178 | }
|
179 | ENSURE_TF_OK_RETVAL(env, tf_status, nullptr);
|
180 |
|
181 | return tfe_tensor_handle;
|
182 | }
|
183 |
|
184 |
|
185 | TFE_TensorHandle *CreateTFE_TensorHandleFromStringArray(
|
186 | napi_env env, int64_t *shape, uint32_t shape_length, TF_DataType dtype,
|
187 | napi_value array_value) {
|
188 | napi_status nstatus;
|
189 |
|
190 | uint32_t array_length;
|
191 | nstatus = napi_get_array_length(env, array_value, &array_length);
|
192 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
193 |
|
194 | size_t offsets_size = array_length * sizeof(uint64_t);
|
195 | size_t data_size = offsets_size;
|
196 |
|
197 | for (uint32_t i = 0; i < array_length; ++i) {
|
198 | napi_value cur_value;
|
199 | nstatus = napi_get_element(env, array_value, i, &cur_value);
|
200 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
201 | ENSURE_VALUE_IS_TYPED_ARRAY_RETVAL(env, cur_value, nullptr);
|
202 |
|
203 | size_t cur_array_length;
|
204 | napi_typedarray_type array_type;
|
205 | nstatus =
|
206 | napi_get_typedarray_info(env, cur_value, &array_type, &cur_array_length,
|
207 | nullptr, nullptr, nullptr);
|
208 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
209 |
|
210 |
|
211 | if (array_type != napi_uint8_array) {
|
212 | NAPI_THROW_ERROR(env, "Unsupported array type - expecting Uint8Array");
|
213 | return nullptr;
|
214 | }
|
215 |
|
216 | data_size += TF_StringEncodedSize(cur_array_length);
|
217 | }
|
218 |
|
219 | TF_AutoStatus tf_status;
|
220 | TF_AutoTensor tensor(
|
221 | TF_AllocateTensor(TF_STRING, shape, shape_length, data_size));
|
222 |
|
223 | void *tensor_data = TF_TensorData(tensor.tensor);
|
224 | uint64_t *offsets = (uint64_t *)tensor_data;
|
225 |
|
226 | char *str_data_start = (char *)tensor_data + offsets_size;
|
227 | char *cur_str_data = str_data_start;
|
228 |
|
229 | for (uint32_t i = 0; i < array_length; ++i) {
|
230 | napi_value cur_value;
|
231 | nstatus = napi_get_element(env, array_value, i, &cur_value);
|
232 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
233 |
|
234 | size_t cur_array_length;
|
235 | void *buffer = nullptr;
|
236 | nstatus = napi_get_typedarray_info(
|
237 | env, cur_value, nullptr, &cur_array_length, &buffer, nullptr, nullptr);
|
238 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
239 |
|
240 | size_t encoded_size =
|
241 | TF_StringEncode(reinterpret_cast<char *>(buffer), cur_array_length,
|
242 | cur_str_data, data_size, tf_status.status);
|
243 | ENSURE_TF_OK_RETVAL(env, tf_status, nullptr);
|
244 |
|
245 | offsets[i] = cur_str_data - str_data_start;
|
246 | cur_str_data += encoded_size;
|
247 | }
|
248 |
|
249 | TFE_TensorHandle *tfe_tensor_handle =
|
250 | TFE_NewTensorHandle(tensor.tensor, tf_status.status);
|
251 | ENSURE_TF_OK_RETVAL(env, tf_status, nullptr);
|
252 | return tfe_tensor_handle;
|
253 | }
|
254 |
|
255 | TFE_TensorHandle *CreateTFE_TensorHandleFromJSValues(napi_env env,
|
256 | int64_t *shape,
|
257 | uint32_t shape_length,
|
258 | TF_DataType dtype,
|
259 | napi_value array_value) {
|
260 | bool is_typed_array;
|
261 | napi_status nstatus = napi_is_typedarray(env, array_value, &is_typed_array);
|
262 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
263 | if (is_typed_array) {
|
264 | return CreateTFE_TensorHandleFromTypedArray(env, shape, shape_length, dtype,
|
265 | array_value);
|
266 | } else {
|
267 | return CreateTFE_TensorHandleFromStringArray(env, shape, shape_length,
|
268 | dtype, array_value);
|
269 | }
|
270 | }
|
271 |
|
272 | TFE_TensorHandle *CopyTFE_TensorHandleToDevice(napi_env env,
|
273 | const char *device_name,
|
274 | TFE_TensorHandle *handle,
|
275 | TFE_Context *tfe_context) {
|
276 | TF_AutoStatus tf_status;
|
277 |
|
278 | TFE_TensorHandle *new_handle = TFE_TensorHandleCopyToDevice(
|
279 | handle, tfe_context, device_name, tf_status.status);
|
280 | ENSURE_TF_OK_RETVAL(env, tf_status, nullptr);
|
281 |
|
282 | return new_handle;
|
283 | }
|
284 |
|
285 | void CopyTFE_TensorHandleDataToTypedArray(napi_env env,
|
286 | TFE_Context *tfe_context,
|
287 | TFE_TensorHandle *tfe_tensor_handle,
|
288 | TF_DataType tensor_data_type,
|
289 | napi_typedarray_type array_type,
|
290 | napi_value *result) {
|
291 | TF_AutoStatus tf_status;
|
292 |
|
293 | TF_AutoTensor tensor(
|
294 | TFE_TensorHandleResolve(tfe_tensor_handle, tf_status.status));
|
295 | ENSURE_TF_OK(env, tf_status);
|
296 |
|
297 |
|
298 | size_t num_elements = GetTensorNumElements(tensor.tensor);
|
299 |
|
300 | if (tensor_data_type == TF_COMPLEX64) {
|
301 |
|
302 | num_elements *= 2;
|
303 | }
|
304 |
|
305 | size_t byte_length = TF_TensorByteSize(tensor.tensor);
|
306 |
|
307 | napi_value array_buffer_value;
|
308 | void *array_buffer_data;
|
309 | napi_status nstatus;
|
310 | nstatus = napi_create_arraybuffer(env, byte_length, &array_buffer_data,
|
311 | &array_buffer_value);
|
312 | ENSURE_NAPI_OK(env, nstatus);
|
313 |
|
314 |
|
315 |
|
316 | memcpy(array_buffer_data, TF_TensorData(tensor.tensor), byte_length);
|
317 |
|
318 | nstatus = napi_create_typedarray(env, array_type, num_elements,
|
319 | array_buffer_value, 0, result);
|
320 | ENSURE_NAPI_OK(env, nstatus);
|
321 | }
|
322 |
|
323 | void CopyTFE_TensorHandleDataToStringArray(napi_env env,
|
324 | TFE_Context *tfe_context,
|
325 | TFE_TensorHandle *tfe_tensor_handle,
|
326 | napi_value *result) {
|
327 | TF_AutoStatus tf_status;
|
328 |
|
329 | TF_AutoTensor tensor(
|
330 | TFE_TensorHandleResolve(tfe_tensor_handle, tf_status.status));
|
331 | ENSURE_TF_OK(env, tf_status);
|
332 |
|
333 | if (TF_TensorType(tensor.tensor) != TF_STRING) {
|
334 | NAPI_THROW_ERROR(env, "Tensor is not of type TF_STRING");
|
335 | return;
|
336 | }
|
337 |
|
338 | void *tensor_data = TF_TensorData(tensor.tensor);
|
339 | ENSURE_VALUE_IS_NOT_NULL(env, tensor_data);
|
340 |
|
341 | size_t byte_length = TF_TensorByteSize(tensor.tensor);
|
342 | const char *limit = static_cast<const char *>(tensor_data) + byte_length;
|
343 |
|
344 | size_t num_elements = GetTensorNumElements(tensor.tensor);
|
345 |
|
346 |
|
347 | const uint64_t *offsets = static_cast<const uint64_t *>(tensor_data);
|
348 | const size_t offsets_size = sizeof(uint64_t) * num_elements;
|
349 |
|
350 |
|
351 | const char *data = static_cast<const char *>(tensor_data) + offsets_size;
|
352 |
|
353 | TF_AutoStatus status;
|
354 |
|
355 |
|
356 | napi_status nstatus;
|
357 | nstatus = napi_create_array_with_length(env, num_elements, result);
|
358 |
|
359 | const size_t expected_tensor_size =
|
360 | (limit - static_cast<const char *>(tensor_data));
|
361 | if (expected_tensor_size != byte_length) {
|
362 | NAPI_THROW_ERROR(env,
|
363 | "Invalid/corrupt TF_STRING tensor. Expected size: %zu, "
|
364 | "byte_length: %zu",
|
365 | expected_tensor_size, byte_length);
|
366 | return;
|
367 | }
|
368 |
|
369 | for (uint64_t i = 0; i < num_elements; i++) {
|
370 | const char *start = data + offsets[i];
|
371 | const char *str_ptr = nullptr;
|
372 | size_t str_len = 0;
|
373 |
|
374 | TF_StringDecode(start, limit - start, &str_ptr, &str_len, status.status);
|
375 | ENSURE_TF_OK(env, tf_status);
|
376 |
|
377 | napi_value array_buffer_value;
|
378 | void *array_buffer_data;
|
379 | nstatus = napi_create_arraybuffer(env, str_len, &array_buffer_data,
|
380 | &array_buffer_value);
|
381 | ENSURE_NAPI_OK(env, nstatus);
|
382 |
|
383 |
|
384 |
|
385 |
|
386 | memcpy(array_buffer_data, str_ptr, str_len);
|
387 |
|
388 | napi_value typed_array_value;
|
389 | nstatus = napi_create_typedarray(env, napi_uint8_array, str_len,
|
390 | array_buffer_value, 0, &typed_array_value);
|
391 | ENSURE_NAPI_OK(env, nstatus);
|
392 |
|
393 | nstatus = napi_set_element(env, *result, i, typed_array_value);
|
394 | ENSURE_NAPI_OK(env, nstatus);
|
395 | }
|
396 | }
|
397 |
|
398 | void CopyTFE_TensorHandleDataToResourceArray(
|
399 | napi_env env, TFE_Context *tfe_context, TFE_TensorHandle *tfe_tensor_handle,
|
400 | napi_value *result) {
|
401 | TF_AutoStatus tf_status;
|
402 |
|
403 | TF_AutoTensor tensor(
|
404 | TFE_TensorHandleResolve(tfe_tensor_handle, tf_status.status));
|
405 | ENSURE_TF_OK(env, tf_status);
|
406 |
|
407 | if (TF_TensorType(tensor.tensor) != TF_RESOURCE) {
|
408 | NAPI_THROW_ERROR(env, "Tensor is not of type TF_RESOURCE");
|
409 | return;
|
410 | }
|
411 |
|
412 | void *tensor_data = TF_TensorData(tensor.tensor);
|
413 | ENSURE_VALUE_IS_NOT_NULL(env, tensor_data);
|
414 |
|
415 | size_t num_elements = GetTensorNumElements(tensor.tensor);
|
416 | if (num_elements != 1) {
|
417 | NAPI_THROW_ERROR(env,
|
418 | "For DT_RESOURCE tensors, Node.js binding currently "
|
419 | "supports only exactly 1 element, but encountered "
|
420 | "DT_RESOURCE tensor with %zu elements.",
|
421 | num_elements);
|
422 | }
|
423 |
|
424 | TF_AutoStatus status;
|
425 |
|
426 |
|
427 | napi_status nstatus;
|
428 | size_t byte_length = TF_TensorByteSize(tensor.tensor);
|
429 | nstatus = napi_create_array_with_length(env, byte_length, result);
|
430 | ENSURE_NAPI_OK(env, nstatus);
|
431 |
|
432 | napi_value array_buffer_value;
|
433 | void *array_buffer_data = nullptr;
|
434 | nstatus = napi_create_arraybuffer(env, byte_length, &array_buffer_data,
|
435 | &array_buffer_value);
|
436 | ENSURE_NAPI_OK(env, nstatus);
|
437 |
|
438 |
|
439 |
|
440 | memcpy(array_buffer_data, tensor_data, byte_length);
|
441 |
|
442 |
|
443 | nstatus = napi_create_typedarray(env, napi_uint8_array, byte_length,
|
444 | array_buffer_value, 0, result);
|
445 | ENSURE_NAPI_OK(env, nstatus);
|
446 | }
|
447 |
|
448 |
|
449 | void CopyTFE_TensorHandleDataToJSData(napi_env env, TFE_Context *tfe_context,
|
450 | TFE_TensorHandle *tfe_tensor_handle,
|
451 | napi_value *result) {
|
452 | if (tfe_context == nullptr) {
|
453 | NAPI_THROW_ERROR(env, "Invalid TFE_Context");
|
454 | return;
|
455 | }
|
456 | if (tfe_tensor_handle == nullptr) {
|
457 | NAPI_THROW_ERROR(env, "Invalid TFE_TensorHandle");
|
458 | return;
|
459 | }
|
460 |
|
461 |
|
462 | napi_typedarray_type typed_array_type;
|
463 | bool is_string = false;
|
464 | bool is_resource = false;
|
465 | TF_DataType tensor_data_type = TFE_TensorHandleDataType(tfe_tensor_handle);
|
466 | switch (tensor_data_type) {
|
467 | case TF_COMPLEX64:
|
468 | case TF_FLOAT:
|
469 | typed_array_type = napi_float32_array;
|
470 | break;
|
471 | case TF_INT32:
|
472 | typed_array_type = napi_int32_array;
|
473 | break;
|
474 | case TF_BOOL:
|
475 | typed_array_type = napi_uint8_array;
|
476 | break;
|
477 | case TF_STRING:
|
478 | is_string = true;
|
479 | break;
|
480 | case TF_RESOURCE:
|
481 |
|
482 | typed_array_type = napi_uint8_array;
|
483 | is_resource = true;
|
484 | break;
|
485 | default:
|
486 | REPORT_UNKNOWN_TF_DATA_TYPE(env,
|
487 | TFE_TensorHandleDataType(tfe_tensor_handle));
|
488 | return;
|
489 | }
|
490 |
|
491 | if (is_string) {
|
492 | CopyTFE_TensorHandleDataToStringArray(env, tfe_context, tfe_tensor_handle,
|
493 | result);
|
494 | } else if (is_resource) {
|
495 | CopyTFE_TensorHandleDataToResourceArray(env, tfe_context, tfe_tensor_handle,
|
496 | result);
|
497 | } else {
|
498 | CopyTFE_TensorHandleDataToTypedArray(env, tfe_context, tfe_tensor_handle,
|
499 | tensor_data_type, typed_array_type,
|
500 | result);
|
501 | }
|
502 | }
|
503 |
|
504 | void GetTFE_TensorHandleShape(napi_env env, TFE_TensorHandle *handle,
|
505 | napi_value *result) {
|
506 | napi_status nstatus;
|
507 |
|
508 | TF_AutoStatus tf_status;
|
509 | uint32_t num_dims = TFE_TensorHandleNumDims(handle, tf_status.status);
|
510 | ENSURE_TF_OK(env, tf_status);
|
511 |
|
512 | if (num_dims == 0) {
|
513 | nstatus = napi_create_array_with_length(env, 0, result);
|
514 | ENSURE_NAPI_OK(env, nstatus);
|
515 | } else {
|
516 | nstatus = napi_create_array_with_length(env, num_dims, result);
|
517 | ENSURE_NAPI_OK(env, nstatus);
|
518 |
|
519 | for (uint32_t i = 0; i < num_dims; i++) {
|
520 | napi_value cur_dim;
|
521 | nstatus = napi_create_int64(
|
522 | env, TFE_TensorHandleDim(handle, i, tf_status.status), &cur_dim);
|
523 | ENSURE_TF_OK(env, tf_status);
|
524 | ENSURE_NAPI_OK(env, nstatus);
|
525 |
|
526 | nstatus = napi_set_element(env, *result, i, cur_dim);
|
527 | ENSURE_NAPI_OK(env, nstatus);
|
528 | }
|
529 | }
|
530 | }
|
531 |
|
532 | inline bool IsArray(napi_env env, napi_status &nstatus, napi_value *val) {
|
533 | bool is_array;
|
534 | nstatus = napi_is_array(env, *val, &is_array);
|
535 | ENSURE_NAPI_OK_RETVAL(env, nstatus, false);
|
536 | return is_array;
|
537 | }
|
538 |
|
539 | void GetTFE_TensorHandleType(napi_env env, TFE_TensorHandle *handle,
|
540 | napi_value *result) {
|
541 | napi_status nstatus;
|
542 |
|
543 | TF_DataType dtype = TFE_TensorHandleDataType(handle);
|
544 | nstatus = napi_create_int32(env, dtype, result);
|
545 | ENSURE_NAPI_OK(env, nstatus);
|
546 | }
|
547 |
|
548 | void AssignOpAttr(napi_env env, TFE_Op *tfe_op, napi_value attr_value) {
|
549 | napi_status nstatus;
|
550 |
|
551 | napi_value attr_name_value;
|
552 | nstatus = napi_get_named_property(env, attr_value, "name", &attr_name_value);
|
553 | ENSURE_NAPI_OK(env, nstatus);
|
554 |
|
555 | std::string attr_name_string;
|
556 | nstatus = GetStringParam(env, attr_name_value, attr_name_string);
|
557 | ENSURE_NAPI_OK(env, nstatus);
|
558 |
|
559 |
|
560 |
|
561 |
|
562 | const char *attr_name =
|
563 | ATTR_NAME_SET.insert(attr_name_string.c_str()).first->c_str();
|
564 |
|
565 | napi_value attr_type_value;
|
566 | nstatus = napi_get_named_property(env, attr_value, "type", &attr_type_value);
|
567 | ENSURE_NAPI_OK(env, nstatus);
|
568 |
|
569 | TF_AttrType tf_attr_type;
|
570 | nstatus = napi_get_value_int32(env, attr_type_value,
|
571 | reinterpret_cast<int32_t *>(&tf_attr_type));
|
572 | ENSURE_NAPI_OK(env, nstatus);
|
573 |
|
574 | napi_value js_value;
|
575 | nstatus = napi_get_named_property(env, attr_value, "value", &js_value);
|
576 | ENSURE_NAPI_OK(env, nstatus);
|
577 |
|
578 | switch (tf_attr_type) {
|
579 | case TF_ATTR_STRING: {
|
580 |
|
581 |
|
582 | std::string str_value;
|
583 | nstatus = GetStringParam(env, js_value, str_value);
|
584 | ENSURE_NAPI_OK(env, nstatus);
|
585 |
|
586 | TFE_OpSetAttrString(tfe_op, attr_name, str_value.c_str(),
|
587 | str_value.size());
|
588 | break;
|
589 | }
|
590 |
|
591 | case TF_ATTR_INT: {
|
592 | if (IsArray(env, nstatus, &js_value)) {
|
593 | uint32_t length;
|
594 | nstatus = napi_get_array_length(env, js_value, &length);
|
595 | ENSURE_NAPI_OK(env, nstatus);
|
596 | std::unique_ptr<int64_t[]> data(new int64_t[length]);
|
597 | for (uint32_t i = 0; i < length; ++i) {
|
598 | napi_value element;
|
599 | nstatus = napi_get_element(env, js_value, i, &element);
|
600 | ENSURE_NAPI_OK(env, nstatus);
|
601 | int32_t value;
|
602 | nstatus = napi_get_value_int32(env, element, &value);
|
603 | ENSURE_NAPI_OK(env, nstatus);
|
604 | data[i] = value;
|
605 | }
|
606 | TFE_OpSetAttrIntList(tfe_op, attr_name, data.get(),
|
607 | static_cast<int>(length));
|
608 | } else {
|
609 | int64_t value;
|
610 | nstatus = napi_get_value_int64(env, js_value, &value);
|
611 | ENSURE_NAPI_OK(env, nstatus);
|
612 |
|
613 | TFE_OpSetAttrInt(tfe_op, attr_name, value);
|
614 | }
|
615 | break;
|
616 | }
|
617 |
|
618 | case TF_ATTR_FLOAT: {
|
619 | if (IsArray(env, nstatus, &js_value)) {
|
620 | uint32_t length;
|
621 | nstatus = napi_get_array_length(env, js_value, &length);
|
622 | ENSURE_NAPI_OK(env, nstatus);
|
623 | std::unique_ptr<float[]> data(new float[length]);
|
624 | for (uint32_t i = 0; i < length; ++i) {
|
625 | napi_value element;
|
626 | nstatus = napi_get_element(env, js_value, i, &element);
|
627 | ENSURE_NAPI_OK(env, nstatus);
|
628 | double value;
|
629 | nstatus = napi_get_value_double(env, element, &value);
|
630 | ENSURE_NAPI_OK(env, nstatus);
|
631 | data[i] = static_cast<float>(value);
|
632 | }
|
633 | TFE_OpSetAttrFloatList(tfe_op, attr_name, data.get(),
|
634 | static_cast<int>(length));
|
635 | } else {
|
636 | double value;
|
637 | nstatus = napi_get_value_double(env, js_value, &value);
|
638 | ENSURE_NAPI_OK(env, nstatus);
|
639 | TFE_OpSetAttrFloat(tfe_op, attr_name, static_cast<float>(value));
|
640 | }
|
641 | break;
|
642 | }
|
643 |
|
644 | case TF_ATTR_BOOL: {
|
645 | if (IsArray(env, nstatus, &js_value)) {
|
646 | uint32_t length;
|
647 | nstatus = napi_get_array_length(env, js_value, &length);
|
648 | ENSURE_NAPI_OK(env, nstatus);
|
649 | std::unique_ptr<unsigned char[]> data(new unsigned char[length]);
|
650 | for (uint32_t i = 0; i < length; ++i) {
|
651 | napi_value element;
|
652 | nstatus = napi_get_element(env, js_value, i, &element);
|
653 | ENSURE_NAPI_OK(env, nstatus);
|
654 | bool value;
|
655 | nstatus = napi_get_value_bool(env, element, &value);
|
656 | ENSURE_NAPI_OK(env, nstatus);
|
657 | data[i] = value ? 1 : 0;
|
658 | }
|
659 | TFE_OpSetAttrBoolList(tfe_op, attr_name, data.get(),
|
660 | static_cast<int>(length));
|
661 | } else {
|
662 | bool value;
|
663 | nstatus = napi_get_value_bool(env, js_value, &value);
|
664 | ENSURE_NAPI_OK(env, nstatus);
|
665 | TFE_OpSetAttrBool(tfe_op, attr_name, value ? 1 : 0);
|
666 | }
|
667 | break;
|
668 | }
|
669 |
|
670 | case TF_ATTR_TYPE: {
|
671 | TF_DataType tf_data_type;
|
672 | nstatus = napi_get_value_int32(
|
673 | env, js_value, reinterpret_cast<int32_t *>(&tf_data_type));
|
674 | ENSURE_NAPI_OK(env, nstatus);
|
675 |
|
676 | TFE_OpSetAttrType(tfe_op, attr_name, tf_data_type);
|
677 | break;
|
678 | }
|
679 |
|
680 | case TF_ATTR_SHAPE: {
|
681 | std::vector<int64_t> shape_vector;
|
682 | ExtractArrayShape(env, js_value, &shape_vector);
|
683 |
|
684 | TF_AutoStatus tf_status;
|
685 | TFE_OpSetAttrShape(tfe_op, attr_name, shape_vector.data(),
|
686 | shape_vector.size(), tf_status.status);
|
687 | ENSURE_TF_OK(env, tf_status);
|
688 | break;
|
689 | }
|
690 |
|
691 | default:
|
692 | REPORT_UNKNOWN_TF_ATTR_TYPE(env, tf_attr_type);
|
693 | break;
|
694 | }
|
695 | }
|
696 |
|
697 | TFJSBackend::TFJSBackend(napi_env env)
|
698 | : next_tensor_id_(0), next_savedmodel_id_(0) {
|
699 | TF_AutoStatus tf_status;
|
700 | TFE_ContextOptions *tfe_options = TFE_NewContextOptions();
|
701 | tfe_context_ = TFE_NewContext(tfe_options, tf_status.status);
|
702 | if (TF_GetCode(tf_status.status) != TF_OK) {
|
703 | NAPI_THROW_ERROR(env, "Exception creating TFE_Context");
|
704 | }
|
705 |
|
706 | TFE_DeleteContextOptions(tfe_options);
|
707 |
|
708 | TF_DeviceList *device_list =
|
709 | TFE_ContextListDevices(tfe_context_, tf_status.status);
|
710 | if (TF_GetCode(tf_status.status) != TF_OK) {
|
711 | NAPI_THROW_ERROR(env, "Exception creating TFE_Context");
|
712 | }
|
713 |
|
714 |
|
715 |
|
716 | std::string cpu_device_name;
|
717 | const int num_devices = TF_DeviceListCount(device_list);
|
718 | for (int i = 0; i < num_devices; i++) {
|
719 | const char *device_type =
|
720 | TF_DeviceListType(device_list, i, tf_status.status);
|
721 | ENSURE_TF_OK(env, tf_status);
|
722 |
|
723 |
|
724 | if (strcmp(device_type, "CPU") == 0) {
|
725 | cpu_device_name =
|
726 | std::string(TF_DeviceListName(device_list, i, tf_status.status));
|
727 | ENSURE_TF_OK(env, tf_status);
|
728 | } else if (strcmp(device_type, "GPU") == 0) {
|
729 | device_name =
|
730 | std::string(TF_DeviceListName(device_list, i, tf_status.status));
|
731 | ENSURE_TF_OK(env, tf_status);
|
732 | }
|
733 | }
|
734 |
|
735 |
|
736 | if (device_name.empty()) {
|
737 | device_name = cpu_device_name;
|
738 | is_gpu_device = false;
|
739 | } else {
|
740 | is_gpu_device = true;
|
741 | }
|
742 | TF_DeleteDeviceList(device_list);
|
743 | }
|
744 |
|
745 | TFJSBackend::~TFJSBackend() {
|
746 | for (auto &kv : tfe_handle_map_) {
|
747 | TFE_DeleteTensorHandle(kv.second);
|
748 | }
|
749 | for (auto &kv : tf_savedmodel_map_) {
|
750 | TF_AutoStatus tf_status;
|
751 | TF_DeleteSession(kv.second.first, tf_status.status);
|
752 | TF_DeleteGraph(kv.second.second);
|
753 | }
|
754 | if (tfe_context_ != nullptr) {
|
755 | TFE_DeleteContext(tfe_context_);
|
756 | }
|
757 | }
|
758 |
|
759 | TFJSBackend *TFJSBackend::Create(napi_env env) { return new TFJSBackend(env); }
|
760 |
|
761 | int32_t TFJSBackend::InsertHandle(TFE_TensorHandle *tfe_handle) {
|
762 | return tfe_handle_map_.insert(std::make_pair(next_tensor_id_++, tfe_handle))
|
763 | .first->first;
|
764 | }
|
765 |
|
766 | int32_t TFJSBackend::InsertSavedModel(TF_Session *tf_session,
|
767 | TF_Graph *tf_graph) {
|
768 |
|
769 |
|
770 | return tf_savedmodel_map_
|
771 | .insert(std::make_pair(next_savedmodel_id_++,
|
772 | std::make_pair(tf_session, tf_graph)))
|
773 | .first->first;
|
774 | }
|
775 |
|
776 | napi_value TFJSBackend::CreateTensor(napi_env env, napi_value shape_value,
|
777 | napi_value dtype_value,
|
778 | napi_value array_value) {
|
779 | napi_status nstatus;
|
780 |
|
781 | std::vector<int64_t> shape_vector;
|
782 | ExtractArrayShape(env, shape_value, &shape_vector);
|
783 |
|
784 | if (IsExceptionPending(env)) {
|
785 | return nullptr;
|
786 | }
|
787 |
|
788 | int32_t dtype_int32;
|
789 | nstatus = napi_get_value_int32(env, dtype_value, &dtype_int32);
|
790 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
791 |
|
792 | TFE_TensorHandle *tfe_handle = CreateTFE_TensorHandleFromJSValues(
|
793 | env, shape_vector.data(), shape_vector.size(),
|
794 | static_cast<TF_DataType>(dtype_int32), array_value);
|
795 |
|
796 |
|
797 | if (IsExceptionPending(env)) {
|
798 | return nullptr;
|
799 | }
|
800 |
|
801 |
|
802 |
|
803 | if (dtype_int32 != TF_INT32 && dtype_int32 != TF_STRING) {
|
804 |
|
805 |
|
806 | TFE_TensorHandle *new_handle = CopyTFE_TensorHandleToDevice(
|
807 | env, device_name.c_str(), tfe_handle, tfe_context_);
|
808 |
|
809 | TFE_DeleteTensorHandle(tfe_handle);
|
810 | tfe_handle = new_handle;
|
811 | }
|
812 |
|
813 | napi_value output_tensor_id;
|
814 | nstatus = napi_create_int32(env, InsertHandle(tfe_handle), &output_tensor_id);
|
815 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
816 | return output_tensor_id;
|
817 | }
|
818 |
|
819 | void TFJSBackend::DeleteTensor(napi_env env, napi_value tensor_id_value) {
|
820 | int32_t tensor_id;
|
821 | ENSURE_NAPI_OK(env, napi_get_value_int32(env, tensor_id_value, &tensor_id));
|
822 |
|
823 | auto tensor_entry = tfe_handle_map_.find(tensor_id);
|
824 | if (tensor_entry == tfe_handle_map_.end()) {
|
825 | NAPI_THROW_ERROR(env,
|
826 | "Delete called on a Tensor not referenced (tensor_id: %d)",
|
827 | tensor_id);
|
828 | return;
|
829 | }
|
830 |
|
831 | TFE_DeleteTensorHandle(tensor_entry->second);
|
832 | tfe_handle_map_.erase(tensor_entry);
|
833 | }
|
834 |
|
835 | napi_value TFJSBackend::GetTensorData(napi_env env,
|
836 | napi_value tensor_id_value) {
|
837 | int32_t tensor_id;
|
838 | ENSURE_NAPI_OK_RETVAL(
|
839 | env, napi_get_value_int32(env, tensor_id_value, &tensor_id), nullptr);
|
840 |
|
841 | auto tensor_entry = tfe_handle_map_.find(tensor_id);
|
842 | if (tensor_entry == tfe_handle_map_.end()) {
|
843 | NAPI_THROW_ERROR(
|
844 | env, "Get data called on a Tensor not referenced (tensor_id: %d)",
|
845 | tensor_id);
|
846 | return nullptr;
|
847 | }
|
848 |
|
849 | napi_value js_value;
|
850 | CopyTFE_TensorHandleDataToJSData(env, tfe_context_, tensor_entry->second,
|
851 | &js_value);
|
852 | return js_value;
|
853 | }
|
854 |
|
855 | napi_value TFJSBackend::ExecuteOp(napi_env env, napi_value op_name_value,
|
856 | napi_value op_attr_inputs,
|
857 | napi_value input_tensor_ids,
|
858 | napi_value num_output_values) {
|
859 | napi_status nstatus;
|
860 |
|
861 | std::string op_name;
|
862 | nstatus = GetStringParam(env, op_name_value, op_name);
|
863 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
864 |
|
865 | TF_AutoStatus tf_status;
|
866 | TFE_AutoOp tfe_op(TFE_NewOp(tfe_context_, op_name.c_str(), tf_status.status));
|
867 | ENSURE_TF_OK_RETVAL(env, tf_status, nullptr);
|
868 |
|
869 | uint32_t num_input_ids;
|
870 | nstatus = napi_get_array_length(env, input_tensor_ids, &num_input_ids);
|
871 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
872 |
|
873 | for (uint32_t i = 0; i < num_input_ids; i++) {
|
874 | napi_value cur_input_id;
|
875 | nstatus = napi_get_element(env, input_tensor_ids, i, &cur_input_id);
|
876 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
877 |
|
878 | int32_t cur_input_tensor_id;
|
879 | nstatus = napi_get_value_int32(env, cur_input_id, &cur_input_tensor_id);
|
880 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
881 |
|
882 | auto input_tensor_entry = tfe_handle_map_.find(cur_input_tensor_id);
|
883 | if (input_tensor_entry == tfe_handle_map_.end()) {
|
884 | NAPI_THROW_ERROR(env, "Input Tensor ID not referenced (tensor_id: %d)",
|
885 | cur_input_tensor_id);
|
886 | return nullptr;
|
887 | }
|
888 |
|
889 | TFE_OpAddInput(tfe_op.op, input_tensor_entry->second, tf_status.status);
|
890 | ENSURE_TF_OK_RETVAL(env, tf_status, nullptr);
|
891 | }
|
892 |
|
893 | uint32_t op_attrs_length;
|
894 | nstatus = napi_get_array_length(env, op_attr_inputs, &op_attrs_length);
|
895 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
896 |
|
897 | for (uint32_t i = 0; i < op_attrs_length; i++) {
|
898 | napi_value cur_op_attr;
|
899 | nstatus = napi_get_element(env, op_attr_inputs, i, &cur_op_attr);
|
900 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
901 |
|
902 | AssignOpAttr(env, tfe_op.op, cur_op_attr);
|
903 |
|
904 |
|
905 | if (IsExceptionPending(env)) {
|
906 | return nullptr;
|
907 | }
|
908 | }
|
909 |
|
910 | int32_t num_outputs;
|
911 | nstatus = napi_get_value_int32(env, num_output_values, &num_outputs);
|
912 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
913 |
|
914 |
|
915 |
|
916 | std::vector<TFE_TensorHandle *> result_handles(num_outputs, nullptr);
|
917 |
|
918 | int size = result_handles.size();
|
919 | TFE_Execute(tfe_op.op, result_handles.data(), &size, tf_status.status);
|
920 | ENSURE_TF_OK_RETVAL(env, tf_status, nullptr);
|
921 |
|
922 | napi_value output_tensor_infos;
|
923 | nstatus = napi_create_array_with_length(env, size, &output_tensor_infos);
|
924 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
925 |
|
926 | for (int32_t i = 0; i < num_outputs; i++) {
|
927 | TFE_TensorHandle *handle = result_handles[i];
|
928 | napi_value tensor_info_value = GenerateOutputTensorInfo(env, handle);
|
929 |
|
930 | nstatus = napi_set_element(env, output_tensor_infos, i, tensor_info_value);
|
931 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
932 | }
|
933 |
|
934 | return output_tensor_infos;
|
935 | }
|
936 |
|
937 |
|
938 |
|
939 |
|
940 |
|
941 | napi_value TFJSBackend::GenerateOutputTensorInfo(napi_env env,
|
942 | TFE_TensorHandle *handle) {
|
943 | napi_status nstatus;
|
944 |
|
945 |
|
946 | napi_value tensor_info_value;
|
947 | nstatus = napi_create_object(env, &tensor_info_value);
|
948 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
949 |
|
950 |
|
951 | napi_value output_tensor_id_value;
|
952 | nstatus =
|
953 | napi_create_int32(env, InsertHandle(handle), &output_tensor_id_value);
|
954 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
955 |
|
956 | nstatus = napi_set_named_property(env, tensor_info_value, "id",
|
957 | output_tensor_id_value);
|
958 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
959 |
|
960 |
|
961 | napi_value shape_value;
|
962 | GetTFE_TensorHandleShape(env, handle, &shape_value);
|
963 |
|
964 | nstatus =
|
965 | napi_set_named_property(env, tensor_info_value, "shape", shape_value);
|
966 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
967 |
|
968 |
|
969 | napi_value type_value;
|
970 | GetTFE_TensorHandleType(env, handle, &type_value);
|
971 |
|
972 | nstatus =
|
973 | napi_set_named_property(env, tensor_info_value, "dtype", type_value);
|
974 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
975 |
|
976 | return tensor_info_value;
|
977 | }
|
978 |
|
979 | napi_value TFJSBackend::LoadSavedModel(napi_env env,
|
980 | napi_value export_dir_value,
|
981 | napi_value tags_value) {
|
982 | TF_SessionOptions *session_options = TF_NewSessionOptions();
|
983 |
|
984 | TF_Buffer *run_options = TF_NewBufferFromString("", 0);
|
985 |
|
986 | std::string export_dir_string;
|
987 | napi_status nstatus;
|
988 | nstatus = GetStringParam(env, export_dir_value, export_dir_string);
|
989 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
990 | const char *export_dir = export_dir_string.c_str();
|
991 |
|
992 | std::string tags;
|
993 | nstatus = GetStringParam(env, tags_value, tags);
|
994 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
995 |
|
996 | std::vector<const char *> tags_ptrs = splitStringByComma(tags);
|
997 |
|
998 | TF_Graph *graph = TF_NewGraph();
|
999 |
|
1000 | TF_Buffer *metagraph = TF_NewBuffer();
|
1001 |
|
1002 | TF_AutoStatus tf_status;
|
1003 |
|
1004 | TF_Session *session = TF_LoadSessionFromSavedModel(
|
1005 | session_options, run_options, export_dir, tags_ptrs.data(),
|
1006 | tags_ptrs.size(), graph, metagraph, tf_status.status);
|
1007 |
|
1008 |
|
1009 | TF_DeleteSessionOptions(session_options);
|
1010 | TF_DeleteBuffer(run_options);
|
1011 | TF_DeleteBuffer(metagraph);
|
1012 |
|
1013 | if (TF_GetCode(tf_status.status) != TF_OK) {
|
1014 | NAPI_THROW_ERROR(env, "Failed to load SavedModel: %s",
|
1015 | TF_Message(tf_status.status));
|
1016 | return nullptr;
|
1017 | }
|
1018 |
|
1019 | napi_value output_session_id;
|
1020 | nstatus = napi_create_int32(env, InsertSavedModel(session, graph),
|
1021 | &output_session_id);
|
1022 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
1023 | return output_session_id;
|
1024 | }
|
1025 |
|
1026 | void TFJSBackend::DeleteSavedModel(napi_env env,
|
1027 | napi_value savedmodel_id_value) {
|
1028 | int32_t savedmodel_id;
|
1029 | ENSURE_NAPI_OK(
|
1030 | env, napi_get_value_int32(env, savedmodel_id_value, &savedmodel_id));
|
1031 |
|
1032 | auto savedmodel_entry = tf_savedmodel_map_.find(savedmodel_id);
|
1033 | if (savedmodel_entry == tf_savedmodel_map_.end()) {
|
1034 | NAPI_THROW_ERROR(
|
1035 | env, "Delete called on a SavedModel not found (savedmodel_id: %d)",
|
1036 | savedmodel_id);
|
1037 | return;
|
1038 | }
|
1039 |
|
1040 | TF_AutoStatus tf_status;
|
1041 | TF_DeleteSession(savedmodel_entry->second.first, tf_status.status);
|
1042 | if (TF_GetCode(tf_status.status) != TF_OK) {
|
1043 | NAPI_THROW_ERROR(env, "Failed to delete SavedModel: %s",
|
1044 | TF_Message(tf_status.status));
|
1045 | return;
|
1046 | }
|
1047 |
|
1048 |
|
1049 | TF_DeleteGraph(savedmodel_entry->second.second);
|
1050 | tf_savedmodel_map_.erase(savedmodel_entry);
|
1051 | }
|
1052 |
|
1053 | napi_value TFJSBackend::RunSavedModel(napi_env env,
|
1054 | napi_value savedmodel_id_value,
|
1055 | napi_value input_tensor_ids,
|
1056 | napi_value input_op_names_value,
|
1057 | napi_value output_op_names_value) {
|
1058 | napi_status nstatus;
|
1059 | TF_AutoStatus tf_status;
|
1060 |
|
1061 | int32_t savedmodel_id;
|
1062 | nstatus = napi_get_value_int32(env, savedmodel_id_value, &savedmodel_id);
|
1063 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
1064 |
|
1065 |
|
1066 | auto savedmodel_entry = tf_savedmodel_map_.find(savedmodel_id);
|
1067 | if (savedmodel_entry == tf_savedmodel_map_.end()) {
|
1068 | NAPI_THROW_ERROR(env, "SavedModel ID not found (savedmodel_id: %d)",
|
1069 | savedmodel_id);
|
1070 | return nullptr;
|
1071 | }
|
1072 |
|
1073 | std::string input_op_names;
|
1074 | nstatus = GetStringParam(env, input_op_names_value, input_op_names);
|
1075 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
1076 | std::string output_op_names;
|
1077 | nstatus = GetStringParam(env, output_op_names_value, output_op_names);
|
1078 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
1079 |
|
1080 |
|
1081 | std::vector<const char *> input_op_name_array =
|
1082 | splitStringByComma(input_op_names);
|
1083 | std::vector<const char *> output_op_name_array =
|
1084 | splitStringByComma(output_op_names);
|
1085 |
|
1086 | std::vector<TF_Output> inputs;
|
1087 | std::vector<TF_Output> outputs;
|
1088 |
|
1089 | uint32_t num_input_ids;
|
1090 | nstatus = napi_get_array_length(env, input_tensor_ids, &num_input_ids);
|
1091 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
1092 |
|
1093 | if (input_op_name_array.size() != num_input_ids) {
|
1094 | NAPI_THROW_ERROR(env,
|
1095 | "Length of input op names (%d) does not match the length "
|
1096 | "of input tensors (%d).",
|
1097 | input_op_name_array.size(), num_input_ids);
|
1098 | return nullptr;
|
1099 | }
|
1100 |
|
1101 | std::vector<TF_Tensor *> input_values;
|
1102 |
|
1103 | for (uint32_t i = 0; i < num_input_ids; i++) {
|
1104 | napi_value cur_input_id;
|
1105 | nstatus = napi_get_element(env, input_tensor_ids, i, &cur_input_id);
|
1106 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
1107 |
|
1108 | int32_t cur_input_tensor_id;
|
1109 | nstatus = napi_get_value_int32(env, cur_input_id, &cur_input_tensor_id);
|
1110 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
1111 |
|
1112 |
|
1113 | auto tensor_entry = tfe_handle_map_.find(cur_input_tensor_id);
|
1114 | if (tensor_entry == tfe_handle_map_.end()) {
|
1115 | NAPI_THROW_ERROR(env, "Input Tensor ID not found (tensor_id: %d)",
|
1116 | cur_input_tensor_id);
|
1117 | return nullptr;
|
1118 | }
|
1119 | TF_Tensor *inputTensor =
|
1120 | TFE_TensorHandleResolve(tensor_entry->second, tf_status.status);
|
1121 |
|
1122 | if (TF_GetCode(tf_status.status) != TF_OK) {
|
1123 | NAPI_THROW_ERROR(
|
1124 | env, "Failed to get input tensor (tensor_id: %d) for session.",
|
1125 | cur_input_tensor_id);
|
1126 | return nullptr;
|
1127 | }
|
1128 |
|
1129 |
|
1130 | input_values.push_back(inputTensor);
|
1131 |
|
1132 |
|
1133 |
|
1134 | std::string name(input_op_name_array[i]);
|
1135 | int index = name.find(":");
|
1136 | std::string input_op_name = name.substr(0, index);
|
1137 | const char *input_op_index = name.substr(index + 1).c_str();
|
1138 | int input_tensor_index;
|
1139 | if (strlen(input_op_index) == 0) {
|
1140 | input_tensor_index = 0;
|
1141 | } else {
|
1142 | input_tensor_index = atoi(input_op_index);
|
1143 | }
|
1144 |
|
1145 |
|
1146 |
|
1147 |
|
1148 | TF_Operation *input_op = TF_GraphOperationByName(
|
1149 | savedmodel_entry->second.second, input_op_name.c_str());
|
1150 | if (input_op == nullptr) {
|
1151 | NAPI_THROW_ERROR(env, "Input op name can not be found in the graph.");
|
1152 | return nullptr;
|
1153 | }
|
1154 | TF_Output in = {input_op, input_tensor_index};
|
1155 | inputs.push_back(in);
|
1156 | }
|
1157 |
|
1158 |
|
1159 | for (uint32_t i = 0; i < output_op_name_array.size(); i++) {
|
1160 |
|
1161 |
|
1162 | std::string name(output_op_name_array[i]);
|
1163 | int index = name.find(":");
|
1164 | std::string output_op_name = name.substr(0, index);
|
1165 | const char *output_op_index = name.substr(index + 1).c_str();
|
1166 | int output_tensor_index;
|
1167 | if (strlen(output_op_index) == 0) {
|
1168 | output_tensor_index = 0;
|
1169 | } else {
|
1170 | output_tensor_index = atoi(output_op_index);
|
1171 | }
|
1172 |
|
1173 | TF_Operation *output_op = TF_GraphOperationByName(
|
1174 | savedmodel_entry->second.second, output_op_name.c_str());
|
1175 | if (output_op == nullptr) {
|
1176 | NAPI_THROW_ERROR(env, "Output op name can not be found in the graph.");
|
1177 | return nullptr;
|
1178 | }
|
1179 | TF_Output out = {output_op, output_tensor_index};
|
1180 | outputs.push_back(out);
|
1181 | }
|
1182 |
|
1183 | std::vector<TF_Tensor *> output_values(outputs.size(), nullptr);
|
1184 |
|
1185 | TF_SessionRun(savedmodel_entry->second.first, nullptr, inputs.data(),
|
1186 | input_values.data(), num_input_ids, outputs.data(),
|
1187 | output_values.data(), output_op_name_array.size(), nullptr, 0,
|
1188 | nullptr, tf_status.status);
|
1189 |
|
1190 | if (TF_GetCode(tf_status.status) != TF_OK) {
|
1191 | NAPI_THROW_ERROR(env, "Session fail to run with error: %s",
|
1192 | TF_Message(tf_status.status));
|
1193 | return nullptr;
|
1194 | }
|
1195 |
|
1196 | napi_value output_tensor_infos;
|
1197 | nstatus = napi_create_array_with_length(env, 1, &output_tensor_infos);
|
1198 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
1199 |
|
1200 |
|
1201 | for (uint32_t i = 0; i < output_op_name_array.size(); i++) {
|
1202 | TFE_TensorHandle *tfe_handle =
|
1203 | TFE_NewTensorHandle(output_values[i], tf_status.status);
|
1204 |
|
1205 | TF_DeleteTensor(output_values[i]);
|
1206 |
|
1207 | napi_value tensor_info_value = GenerateOutputTensorInfo(env, tfe_handle);
|
1208 |
|
1209 | nstatus = napi_set_element(env, output_tensor_infos, i, tensor_info_value);
|
1210 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
1211 | }
|
1212 |
|
1213 | for (uint32_t i = 0; i < num_input_ids; i++) {
|
1214 |
|
1215 | TF_DeleteTensor(input_values[i]);
|
1216 | }
|
1217 |
|
1218 | return output_tensor_infos;
|
1219 | }
|
1220 |
|
1221 | napi_value TFJSBackend::GetNumOfSavedModels(napi_env env) {
|
1222 | napi_status nstatus;
|
1223 | napi_value num_saved_models;
|
1224 | nstatus =
|
1225 | napi_create_int32(env, tf_savedmodel_map_.size(), &num_saved_models);
|
1226 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
1227 | return num_saved_models;
|
1228 | }
|
1229 |
|
1230 | }
|