1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 |
|
12 |
|
13 |
|
14 |
|
15 |
|
16 |
|
17 |
|
18 | #include <node_api.h>
|
19 | #include "tfjs_backend.h"
|
20 | #include "utils.h"
|
21 |
|
22 | namespace tfnodejs {
|
23 |
|
24 | TFJSBackend *gBackend = nullptr;
|
25 |
|
26 | static void AssignIntProperty(napi_env env, napi_value exports,
|
27 | const char *name, int32_t value) {
|
28 | napi_value js_value;
|
29 | napi_status nstatus = napi_create_int32(env, value, &js_value);
|
30 | ENSURE_NAPI_OK(env, nstatus);
|
31 |
|
32 | napi_property_descriptor property = {name, nullptr, nullptr,
|
33 | nullptr, nullptr, js_value,
|
34 | napi_default, nullptr};
|
35 | nstatus = napi_define_properties(env, exports, 1, &property);
|
36 | ENSURE_NAPI_OK(env, nstatus);
|
37 | }
|
38 |
|
39 | static napi_value CreateTensor(napi_env env, napi_callback_info info) {
|
40 | napi_status nstatus;
|
41 |
|
42 |
|
43 | size_t argc = 3;
|
44 | napi_value args[3];
|
45 | napi_value js_this;
|
46 | nstatus = napi_get_cb_info(env, info, &argc, args, &js_this, nullptr);
|
47 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
48 |
|
49 | if (argc < 3) {
|
50 | NAPI_THROW_ERROR(env,
|
51 | "Invalid number of args passed to createTensor(). "
|
52 | "Expecting 3 args but got %d.",
|
53 | argc);
|
54 | return nullptr;
|
55 | }
|
56 |
|
57 | ENSURE_VALUE_IS_ARRAY_RETVAL(env, args[0], nullptr);
|
58 | ENSURE_VALUE_IS_NUMBER_RETVAL(env, args[1], nullptr);
|
59 |
|
60 |
|
61 | bool is_typed_array;
|
62 | nstatus = napi_is_typedarray(env, args[2], &is_typed_array);
|
63 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
64 | if (!is_typed_array) {
|
65 | ENSURE_VALUE_IS_ARRAY_RETVAL(env, args[2], nullptr);
|
66 | }
|
67 |
|
68 | return gBackend->CreateTensor(env, args[0], args[1], args[2]);
|
69 | }
|
70 |
|
71 | static napi_value DeleteTensor(napi_env env, napi_callback_info info) {
|
72 | napi_status nstatus;
|
73 |
|
74 |
|
75 | size_t argc = 1;
|
76 | napi_value args[1];
|
77 | napi_value js_this;
|
78 | nstatus = napi_get_cb_info(env, info, &argc, args, &js_this, nullptr);
|
79 | ENSURE_NAPI_OK_RETVAL(env, nstatus, js_this);
|
80 |
|
81 | if (argc < 1) {
|
82 | NAPI_THROW_ERROR(env,
|
83 | "Invalid number of args passed to deleteTensor(). "
|
84 | "Expecting 1 arg but got %d.",
|
85 | argc);
|
86 | return js_this;
|
87 | }
|
88 |
|
89 | ENSURE_VALUE_IS_NUMBER_RETVAL(env, args[0], js_this);
|
90 |
|
91 | gBackend->DeleteTensor(env, args[0]);
|
92 | return js_this;
|
93 | }
|
94 |
|
95 | static napi_value TensorDataSync(napi_env env, napi_callback_info info) {
|
96 | napi_status nstatus;
|
97 |
|
98 |
|
99 | size_t argc = 1;
|
100 | napi_value args[1];
|
101 | napi_value js_this;
|
102 | nstatus = napi_get_cb_info(env, info, &argc, args, &js_this, nullptr);
|
103 | ENSURE_NAPI_OK_RETVAL(env, nstatus, js_this);
|
104 |
|
105 | if (argc < 1) {
|
106 | NAPI_THROW_ERROR(env,
|
107 | "Invalid number of args passed to tensorDataSync(). "
|
108 | "Expecting 1 arg but got %d.",
|
109 | argc);
|
110 | return nullptr;
|
111 | }
|
112 |
|
113 | ENSURE_VALUE_IS_NUMBER_RETVAL(env, args[0], js_this);
|
114 |
|
115 | return gBackend->GetTensorData(env, args[0]);
|
116 | }
|
117 |
|
118 | static napi_value ExecuteOp(napi_env env, napi_callback_info info) {
|
119 | napi_status nstatus;
|
120 |
|
121 |
|
122 |
|
123 | size_t argc = 4;
|
124 | napi_value args[4];
|
125 | napi_value js_this;
|
126 | nstatus = napi_get_cb_info(env, info, &argc, args, &js_this, nullptr);
|
127 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
128 |
|
129 | if (argc < 4) {
|
130 | NAPI_THROW_ERROR(env,
|
131 | "Invalid number of args passed to executeOp(). Expecting "
|
132 | "4 args but got %d.",
|
133 | argc);
|
134 | return nullptr;
|
135 | }
|
136 |
|
137 | ENSURE_VALUE_IS_STRING_RETVAL(env, args[0], nullptr);
|
138 | ENSURE_VALUE_IS_ARRAY_RETVAL(env, args[1], nullptr);
|
139 | ENSURE_VALUE_IS_ARRAY_RETVAL(env, args[2], nullptr);
|
140 | ENSURE_VALUE_IS_NUMBER_RETVAL(env, args[3], nullptr);
|
141 |
|
142 | return gBackend->ExecuteOp(env, args[0], args[1], args[2], args[3]);
|
143 | }
|
144 |
|
145 | static napi_value IsUsingGPUDevice(napi_env env, napi_callback_info info) {
|
146 | napi_value result;
|
147 |
|
148 | napi_status nstatus;
|
149 | nstatus = napi_get_boolean(env, gBackend->is_gpu_device, &result);
|
150 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
151 |
|
152 | return result;
|
153 | }
|
154 |
|
155 | static napi_value LoadSavedModel(napi_env env, napi_callback_info info) {
|
156 | napi_status nstatus;
|
157 |
|
158 |
|
159 | size_t argc = 2;
|
160 | napi_value args[2];
|
161 | napi_value js_this;
|
162 | nstatus = napi_get_cb_info(env, info, &argc, args, &js_this, nullptr);
|
163 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
164 |
|
165 | if (argc < 2) {
|
166 | NAPI_THROW_ERROR(env,
|
167 | "Invalid number of args passed to LoadSavedModel(). "
|
168 | "Expecting 2 args but got %d.",
|
169 | argc);
|
170 | return nullptr;
|
171 | }
|
172 |
|
173 | ENSURE_VALUE_IS_STRING_RETVAL(env, args[0], nullptr);
|
174 | ENSURE_VALUE_IS_STRING_RETVAL(env, args[1], nullptr);
|
175 |
|
176 | return gBackend->LoadSavedModel(env, args[0], args[1]);
|
177 | }
|
178 |
|
179 | static napi_value DeleteSavedModel(napi_env env, napi_callback_info info) {
|
180 | napi_status nstatus;
|
181 |
|
182 |
|
183 | size_t argc = 1;
|
184 | napi_value args[1];
|
185 | napi_value js_this;
|
186 | nstatus = napi_get_cb_info(env, info, &argc, args, &js_this, nullptr);
|
187 | ENSURE_NAPI_OK_RETVAL(env, nstatus, js_this);
|
188 |
|
189 | if (argc < 1) {
|
190 | NAPI_THROW_ERROR(env,
|
191 | "Invalid number of args passed to deleteSavedModel(). "
|
192 | "Expecting 1 arg but got %d.",
|
193 | argc);
|
194 | return js_this;
|
195 | }
|
196 |
|
197 | ENSURE_VALUE_IS_NUMBER_RETVAL(env, args[0], js_this);
|
198 |
|
199 | gBackend->DeleteSavedModel(env, args[0]);
|
200 | return js_this;
|
201 | }
|
202 |
|
203 | static napi_value RunSavedModel(napi_env env, napi_callback_info info) {
|
204 | napi_status nstatus;
|
205 |
|
206 |
|
207 |
|
208 | size_t argc = 4;
|
209 | napi_value args[4];
|
210 | napi_value js_this;
|
211 | nstatus = napi_get_cb_info(env, info, &argc, args, &js_this, nullptr);
|
212 | ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
|
213 |
|
214 | if (argc < 4) {
|
215 | NAPI_THROW_ERROR(env, "Invalid number of args passed to RunSavedModel()");
|
216 | return nullptr;
|
217 | }
|
218 |
|
219 | ENSURE_VALUE_IS_NUMBER_RETVAL(env, args[0], nullptr);
|
220 | ENSURE_VALUE_IS_ARRAY_RETVAL(env, args[1], nullptr);
|
221 | ENSURE_VALUE_IS_STRING_RETVAL(env, args[2], nullptr);
|
222 | ENSURE_VALUE_IS_STRING_RETVAL(env, args[3], nullptr);
|
223 |
|
224 | return gBackend->RunSavedModel(env, args[0], args[1], args[2], args[3]);
|
225 | }
|
226 |
|
227 | static napi_value GetNumOfSavedModels(napi_env env, napi_callback_info info) {
|
228 |
|
229 | return gBackend->GetNumOfSavedModels(env);
|
230 | }
|
231 |
|
232 | static napi_value InitTFNodeJSBinding(napi_env env, napi_value exports) {
|
233 | napi_status nstatus;
|
234 |
|
235 | gBackend = TFJSBackend::Create(env);
|
236 | ENSURE_VALUE_IS_NOT_NULL_RETVAL(env, gBackend, nullptr);
|
237 |
|
238 |
|
239 | napi_value tf_version;
|
240 | nstatus = napi_create_string_latin1(env, TF_Version(), -1, &tf_version);
|
241 | ENSURE_NAPI_OK_RETVAL(env, nstatus, exports);
|
242 |
|
243 |
|
244 | napi_property_descriptor exports_properties[] = {
|
245 | {"createTensor", nullptr, CreateTensor, nullptr, nullptr, nullptr,
|
246 | napi_default, nullptr},
|
247 | {"deleteTensor", nullptr, DeleteTensor, nullptr, nullptr, nullptr,
|
248 | napi_default, nullptr},
|
249 | {"tensorDataSync", nullptr, TensorDataSync, nullptr, nullptr, nullptr,
|
250 | napi_default, nullptr},
|
251 | {"executeOp", nullptr, ExecuteOp, nullptr, nullptr, nullptr, napi_default,
|
252 | nullptr},
|
253 | {"loadSavedModel", nullptr, LoadSavedModel, nullptr, nullptr, nullptr,
|
254 | napi_default, nullptr},
|
255 | {"deleteSavedModel", nullptr, DeleteSavedModel, nullptr, nullptr, nullptr,
|
256 | napi_default, nullptr},
|
257 | {"runSavedModel", nullptr, RunSavedModel, nullptr, nullptr, nullptr,
|
258 | napi_default, nullptr},
|
259 | {"TF_Version", nullptr, nullptr, nullptr, nullptr, tf_version,
|
260 | napi_default, nullptr},
|
261 | {"isUsingGpuDevice", nullptr, IsUsingGPUDevice, nullptr, nullptr, nullptr,
|
262 | napi_default, nullptr},
|
263 | {"getNumOfSavedModels", nullptr, GetNumOfSavedModels, nullptr, nullptr,
|
264 | nullptr, napi_default, nullptr},
|
265 | };
|
266 | nstatus = napi_define_properties(env, exports, ARRAY_SIZE(exports_properties),
|
267 | exports_properties);
|
268 | ENSURE_NAPI_OK_RETVAL(env, nstatus, exports);
|
269 |
|
270 |
|
271 | #define EXPORT_INT_PROPERTY(v) AssignIntProperty(env, exports, #v, v)
|
272 |
|
273 | EXPORT_INT_PROPERTY(TF_FLOAT);
|
274 | EXPORT_INT_PROPERTY(TF_INT32);
|
275 | EXPORT_INT_PROPERTY(TF_INT64);
|
276 | EXPORT_INT_PROPERTY(TF_BOOL);
|
277 | EXPORT_INT_PROPERTY(TF_COMPLEX64);
|
278 | EXPORT_INT_PROPERTY(TF_STRING);
|
279 | EXPORT_INT_PROPERTY(TF_RESOURCE);
|
280 | EXPORT_INT_PROPERTY(TF_UINT8);
|
281 |
|
282 |
|
283 | EXPORT_INT_PROPERTY(TF_ATTR_STRING);
|
284 | EXPORT_INT_PROPERTY(TF_ATTR_INT);
|
285 | EXPORT_INT_PROPERTY(TF_ATTR_FLOAT);
|
286 | EXPORT_INT_PROPERTY(TF_ATTR_BOOL);
|
287 | EXPORT_INT_PROPERTY(TF_ATTR_TYPE);
|
288 | EXPORT_INT_PROPERTY(TF_ATTR_SHAPE);
|
289 | #undef EXPORT_INT_PROPERTY
|
290 |
|
291 | return exports;
|
292 | }
|
293 |
|
294 | NAPI_MODULE(tfe_binding, InitTFNodeJSBinding)
|
295 |
|
296 | }
|