UNPKG

9.76 kB · text/x-c · View Raw
1/**
2 * @license
3 * Copyright 2018 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17
18#include <node_api.h>
19#include "tfjs_backend.h"
20#include "utils.h"
21
22namespace tfnodejs {
23
24TFJSBackend *gBackend = nullptr;
25
26static void AssignIntProperty(napi_env env, napi_value exports,
27 const char *name, int32_t value) {
28 napi_value js_value;
29 napi_status nstatus = napi_create_int32(env, value, &js_value);
30 ENSURE_NAPI_OK(env, nstatus);
31
32 napi_property_descriptor property = {name, nullptr, nullptr,
33 nullptr, nullptr, js_value,
34 napi_default, nullptr};
35 nstatus = napi_define_properties(env, exports, 1, &property);
36 ENSURE_NAPI_OK(env, nstatus);
37}
38
39static napi_value CreateTensor(napi_env env, napi_callback_info info) {
40 napi_status nstatus;
41
42 // Create tensor takes 3 params: shape, dtype, typed-array/array:
43 size_t argc = 3;
44 napi_value args[3];
45 napi_value js_this;
46 nstatus = napi_get_cb_info(env, info, &argc, args, &js_this, nullptr);
47 ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
48
49 if (argc < 3) {
50 NAPI_THROW_ERROR(env,
51 "Invalid number of args passed to createTensor(). "
52 "Expecting 3 args but got %d.",
53 argc);
54 return nullptr;
55 }
56
57 ENSURE_VALUE_IS_ARRAY_RETVAL(env, args[0], nullptr);
58 ENSURE_VALUE_IS_NUMBER_RETVAL(env, args[1], nullptr);
59
60 // The third array can either be a typed array or an array:
61 bool is_typed_array;
62 nstatus = napi_is_typedarray(env, args[2], &is_typed_array);
63 ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
64 if (!is_typed_array) {
65 ENSURE_VALUE_IS_ARRAY_RETVAL(env, args[2], nullptr);
66 }
67
68 return gBackend->CreateTensor(env, args[0], args[1], args[2]);
69}
70
71static napi_value DeleteTensor(napi_env env, napi_callback_info info) {
72 napi_status nstatus;
73
74 // Delete tensor takes 1 param: tensor ID;
75 size_t argc = 1;
76 napi_value args[1];
77 napi_value js_this;
78 nstatus = napi_get_cb_info(env, info, &argc, args, &js_this, nullptr);
79 ENSURE_NAPI_OK_RETVAL(env, nstatus, js_this);
80
81 if (argc < 1) {
82 NAPI_THROW_ERROR(env,
83 "Invalid number of args passed to deleteTensor(). "
84 "Expecting 1 arg but got %d.",
85 argc);
86 return js_this;
87 }
88
89 ENSURE_VALUE_IS_NUMBER_RETVAL(env, args[0], js_this);
90
91 gBackend->DeleteTensor(env, args[0]);
92 return js_this;
93}
94
95static napi_value TensorDataSync(napi_env env, napi_callback_info info) {
96 napi_status nstatus;
97
98 // Tensor data-sync takes 1 param: tensor ID;
99 size_t argc = 1;
100 napi_value args[1];
101 napi_value js_this;
102 nstatus = napi_get_cb_info(env, info, &argc, args, &js_this, nullptr);
103 ENSURE_NAPI_OK_RETVAL(env, nstatus, js_this);
104
105 if (argc < 1) {
106 NAPI_THROW_ERROR(env,
107 "Invalid number of args passed to tensorDataSync(). "
108 "Expecting 1 arg but got %d.",
109 argc);
110 return nullptr;
111 }
112
113 ENSURE_VALUE_IS_NUMBER_RETVAL(env, args[0], js_this);
114
115 return gBackend->GetTensorData(env, args[0]);
116}
117
118static napi_value ExecuteOp(napi_env env, napi_callback_info info) {
119 napi_status nstatus;
120
121 // Create tensor takes 4 params: op-name, op-attrs, input-tensor-ids,
122 // num-outputs:
123 size_t argc = 4;
124 napi_value args[4];
125 napi_value js_this;
126 nstatus = napi_get_cb_info(env, info, &argc, args, &js_this, nullptr);
127 ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
128
129 if (argc < 4) {
130 NAPI_THROW_ERROR(env,
131 "Invalid number of args passed to executeOp(). Expecting "
132 "4 args but got %d.",
133 argc);
134 return nullptr;
135 }
136
137 ENSURE_VALUE_IS_STRING_RETVAL(env, args[0], nullptr);
138 ENSURE_VALUE_IS_ARRAY_RETVAL(env, args[1], nullptr);
139 ENSURE_VALUE_IS_ARRAY_RETVAL(env, args[2], nullptr);
140 ENSURE_VALUE_IS_NUMBER_RETVAL(env, args[3], nullptr);
141
142 return gBackend->ExecuteOp(env, args[0], args[1], args[2], args[3]);
143}
144
145static napi_value IsUsingGPUDevice(napi_env env, napi_callback_info info) {
146 napi_value result;
147
148 napi_status nstatus;
149 nstatus = napi_get_boolean(env, gBackend->is_gpu_device, &result);
150 ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
151
152 return result;
153}
154
155static napi_value LoadSavedModel(napi_env env, napi_callback_info info) {
156 napi_status nstatus;
157
158 // Load saved model takes 2 params: export_dir, tags:
159 size_t argc = 2;
160 napi_value args[2];
161 napi_value js_this;
162 nstatus = napi_get_cb_info(env, info, &argc, args, &js_this, nullptr);
163 ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
164
165 if (argc < 2) {
166 NAPI_THROW_ERROR(env,
167 "Invalid number of args passed to LoadSavedModel(). "
168 "Expecting 2 args but got %d.",
169 argc);
170 return nullptr;
171 }
172
173 ENSURE_VALUE_IS_STRING_RETVAL(env, args[0], nullptr);
174 ENSURE_VALUE_IS_STRING_RETVAL(env, args[1], nullptr);
175
176 return gBackend->LoadSavedModel(env, args[0], args[1]);
177}
178
179static napi_value DeleteSavedModel(napi_env env, napi_callback_info info) {
180 napi_status nstatus;
181
182 // Delete SavedModel takes 1 param: savedModel ID;
183 size_t argc = 1;
184 napi_value args[1];
185 napi_value js_this;
186 nstatus = napi_get_cb_info(env, info, &argc, args, &js_this, nullptr);
187 ENSURE_NAPI_OK_RETVAL(env, nstatus, js_this);
188
189 if (argc < 1) {
190 NAPI_THROW_ERROR(env,
191 "Invalid number of args passed to deleteSavedModel(). "
192 "Expecting 1 arg but got %d.",
193 argc);
194 return js_this;
195 }
196
197 ENSURE_VALUE_IS_NUMBER_RETVAL(env, args[0], js_this);
198
199 gBackend->DeleteSavedModel(env, args[0]);
200 return js_this;
201}
202
203static napi_value RunSavedModel(napi_env env, napi_callback_info info) {
204 napi_status nstatus;
205
206 // Run SavedModel takes 4 params: session_id, input_tensor_ids,
207 // input_op_names, output_op_names.
208 size_t argc = 4;
209 napi_value args[4];
210 napi_value js_this;
211 nstatus = napi_get_cb_info(env, info, &argc, args, &js_this, nullptr);
212 ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
213
214 if (argc < 4) {
215 NAPI_THROW_ERROR(env, "Invalid number of args passed to RunSavedModel()");
216 return nullptr;
217 }
218
219 ENSURE_VALUE_IS_NUMBER_RETVAL(env, args[0], nullptr);
220 ENSURE_VALUE_IS_ARRAY_RETVAL(env, args[1], nullptr);
221 ENSURE_VALUE_IS_STRING_RETVAL(env, args[2], nullptr);
222 ENSURE_VALUE_IS_STRING_RETVAL(env, args[3], nullptr);
223
224 return gBackend->RunSavedModel(env, args[0], args[1], args[2], args[3]);
225}
226
227static napi_value GetNumOfSavedModels(napi_env env, napi_callback_info info) {
228 // Delete SavedModel takes 0 param;
229 return gBackend->GetNumOfSavedModels(env);
230}
231
232static napi_value InitTFNodeJSBinding(napi_env env, napi_value exports) {
233 napi_status nstatus;
234
235 gBackend = TFJSBackend::Create(env);
236 ENSURE_VALUE_IS_NOT_NULL_RETVAL(env, gBackend, nullptr);
237
238 // TF version
239 napi_value tf_version;
240 nstatus = napi_create_string_latin1(env, TF_Version(), -1, &tf_version);
241 ENSURE_NAPI_OK_RETVAL(env, nstatus, exports);
242
243 // Set all export values list here.
244 napi_property_descriptor exports_properties[] = {
245 {"createTensor", nullptr, CreateTensor, nullptr, nullptr, nullptr,
246 napi_default, nullptr},
247 {"deleteTensor", nullptr, DeleteTensor, nullptr, nullptr, nullptr,
248 napi_default, nullptr},
249 {"tensorDataSync", nullptr, TensorDataSync, nullptr, nullptr, nullptr,
250 napi_default, nullptr},
251 {"executeOp", nullptr, ExecuteOp, nullptr, nullptr, nullptr, napi_default,
252 nullptr},
253 {"loadSavedModel", nullptr, LoadSavedModel, nullptr, nullptr, nullptr,
254 napi_default, nullptr},
255 {"deleteSavedModel", nullptr, DeleteSavedModel, nullptr, nullptr, nullptr,
256 napi_default, nullptr},
257 {"runSavedModel", nullptr, RunSavedModel, nullptr, nullptr, nullptr,
258 napi_default, nullptr},
259 {"TF_Version", nullptr, nullptr, nullptr, nullptr, tf_version,
260 napi_default, nullptr},
261 {"isUsingGpuDevice", nullptr, IsUsingGPUDevice, nullptr, nullptr, nullptr,
262 napi_default, nullptr},
263 {"getNumOfSavedModels", nullptr, GetNumOfSavedModels, nullptr, nullptr,
264 nullptr, napi_default, nullptr},
265 };
266 nstatus = napi_define_properties(env, exports, ARRAY_SIZE(exports_properties),
267 exports_properties);
268 ENSURE_NAPI_OK_RETVAL(env, nstatus, exports);
269
270 // Export TF property types to JS
271#define EXPORT_INT_PROPERTY(v) AssignIntProperty(env, exports, #v, v)
272 // Types
273 EXPORT_INT_PROPERTY(TF_FLOAT);
274 EXPORT_INT_PROPERTY(TF_INT32);
275 EXPORT_INT_PROPERTY(TF_INT64);
276 EXPORT_INT_PROPERTY(TF_BOOL);
277 EXPORT_INT_PROPERTY(TF_COMPLEX64);
278 EXPORT_INT_PROPERTY(TF_STRING);
279 EXPORT_INT_PROPERTY(TF_RESOURCE);
280 EXPORT_INT_PROPERTY(TF_UINT8);
281
282 // Op AttrType
283 EXPORT_INT_PROPERTY(TF_ATTR_STRING);
284 EXPORT_INT_PROPERTY(TF_ATTR_INT);
285 EXPORT_INT_PROPERTY(TF_ATTR_FLOAT);
286 EXPORT_INT_PROPERTY(TF_ATTR_BOOL);
287 EXPORT_INT_PROPERTY(TF_ATTR_TYPE);
288 EXPORT_INT_PROPERTY(TF_ATTR_SHAPE);
289#undef EXPORT_INT_PROPERTY
290
291 return exports;
292}
293
// Registers InitTFNodeJSBinding as the initializer of the native module
// named `tfe_binding`.
294NAPI_MODULE(tfe_binding, InitTFNodeJSBinding)
295
296} // namespace tfnodejs