UNPKG

12 kBtext/x-cView Raw
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
17
18#ifndef TF_NODEJS_UTILS_H_
19#define TF_NODEJS_UTILS_H_
20
#include <stdarg.h>
#include <stdio.h>

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include <node_api.h>

#include "tensorflow/c/c_api.h"
#include "tf_auto_status.h"
29
// NOTE(review): no user of this constant is visible in this header; the name
// suggests a maximum tensor rank -- confirm against callers.
#define MAX_TENSOR_SHAPE 4

// Number of elements in a built-in C array (meaningless for pointers). The
// argument is parenthesized so any argument expression binds as a unit.
#define ARRAY_SIZE(array) (sizeof(array) / sizeof((array)[0]))

// Debug logging is compiled in but disabled unless DEBUG is defined to a
// non-zero value before this header is included.
#ifndef DEBUG
#define DEBUG 0
#endif

// Prints `message` together with its `file`:`line_number` location to stderr
// when DEBUG is non-zero. `line_number` is printed with %zu, so it must be a
// size_t. Wrapped in do/while(0) so it behaves as a single statement.
#define DEBUG_LOG(message, file, line_number)                         \
  do {                                                                \
    if (DEBUG)                                                        \
      fprintf(stderr, "** -%s:%zu\n-- %s\n", (file), (line_number),   \
              (message));                                             \
  } while (0)
43
44namespace tfnodejs {
45
46#define NAPI_THROW_ERROR(env, message, ...) \
47 NapiThrowError(env, __FILE__, __LINE__, message, ##__VA_ARGS__);
48
49inline void NapiThrowError(napi_env env, const char *file,
50 const size_t line_number, const char *message, ...) {
51 char buffer[500];
52 va_list args;
53 va_start(args, message);
54 std::vsnprintf(buffer, 500, message, args);
55 va_end(args);
56 DEBUG_LOG(buffer, file, line_number);
57 napi_throw_error(env, nullptr, buffer);
58}
59
60#define ENSURE_NAPI_OK(env, status) \
61 if (!EnsureNapiOK(env, status, __FILE__, __LINE__)) return;
62#define ENSURE_NAPI_OK_RETVAL(env, status, retval) \
63 if (!EnsureNapiOK(env, status, __FILE__, __LINE__)) return retval;
64
65inline bool EnsureNapiOK(napi_env env, napi_status status, const char *file,
66 const size_t line_number) {
67 if (status != napi_ok) {
68 const napi_extended_error_info *error_info = 0;
69 napi_get_last_error_info(env, &error_info);
70 NapiThrowError(
71 env, file, line_number, "Invalid napi_status: %s\n",
72 error_info->error_message ? error_info->error_message : "unknown");
73 }
74 return status == napi_ok;
75}
76
77#define ENSURE_TF_OK(env, status) \
78 if (!EnsureTFOK(env, status, __FILE__, __LINE__)) return;
79#define ENSURE_TF_OK_RETVAL(env, status, retval) \
80 if (!EnsureTFOK(env, status, __FILE__, __LINE__)) return retval;
81
82inline bool EnsureTFOK(napi_env env, TF_AutoStatus &status, const char *file,
83 const size_t line_number) {
84 TF_Code tf_code = TF_GetCode(status.status);
85 if (tf_code != TF_OK) {
86 NapiThrowError(env, file, line_number, "Invalid TF_Status: %u\nMessage: %s",
87 TF_GetCode(status.status), TF_Message(status.status));
88 }
89 return tf_code == TF_OK;
90}
91
92#define ENSURE_CONSTRUCTOR_CALL(env, info) \
93 if (!EnsureConstructorCall(env, info, __FILE__, __LINE__)) return;
94#define ENSURE_CONSTRUCTOR_CALL_RETVAL(env, info, retval) \
95 if (!EnsureConstructorCall(env, info, __FILE__, __LINE__)) return retval;
96
97inline bool EnsureConstructorCall(napi_env env, napi_callback_info info,
98 const char *file, const size_t line_number) {
99 napi_value js_target;
100 napi_status nstatus = napi_get_new_target(env, info, &js_target);
101 ENSURE_NAPI_OK_RETVAL(env, nstatus, false);
102 bool is_target = js_target != nullptr;
103 if (!is_target) {
104 NapiThrowError(env, file, line_number,
105 "Function not used as a constructor!");
106 }
107 return is_target;
108}
109
110#define ENSURE_VALUE_IS_OBJECT(env, value) \
111 if (!EnsureValueIsObject(env, value, __FILE__, __LINE__)) return;
112#define ENSURE_VALUE_IS_OBJECT_RETVAL(env, value, retval) \
113 if (!EnsureValueIsObject(env, value, __FILE__, __LINE__)) return retval;
114
115inline bool EnsureValueIsObject(napi_env env, napi_value value,
116 const char *file, const size_t line_number) {
117 napi_valuetype type;
118 ENSURE_NAPI_OK_RETVAL(env, napi_typeof(env, value, &type), false);
119 bool is_object = type == napi_object;
120 if (!is_object) {
121 NapiThrowError(env, file, line_number, "Argument is not an object!");
122 }
123 return is_object;
124}
125
126#define ENSURE_VALUE_IS_STRING(env, value) \
127 if (!EnsureValueIsString(env, value, __FILE__, __LINE__)) return;
128#define ENSURE_VALUE_IS_STRING_RETVAL(env, value, retval) \
129 if (!EnsureValueIsString(env, value, __FILE__, __LINE__)) return retval;
130
131inline bool EnsureValueIsString(napi_env env, napi_value value,
132 const char *file, const size_t line_number) {
133 napi_valuetype type;
134 ENSURE_NAPI_OK_RETVAL(env, napi_typeof(env, value, &type), false);
135 bool is_string = type == napi_string;
136 if (!is_string) {
137 NapiThrowError(env, file, line_number, "Argument is not a string!");
138 }
139 return is_string;
140}
141
142#define ENSURE_VALUE_IS_NUMBER(env, value) \
143 if (!EnsureValueIsNumber(env, value, __FILE__, __LINE__)) return;
144#define ENSURE_VALUE_IS_NUMBER_RETVAL(env, value, retval) \
145 if (!EnsureValueIsNumber(env, value, __FILE__, __LINE__)) return retval;
146
147inline bool EnsureValueIsNumber(napi_env env, napi_value value,
148 const char *file, const size_t line_number) {
149 napi_valuetype type;
150 ENSURE_NAPI_OK_RETVAL(env, napi_typeof(env, value, &type), false);
151 bool is_number = type == napi_number;
152 if (!is_number) {
153 NapiThrowError(env, file, line_number, "Argument is not a number!");
154 }
155 return is_number;
156}
157
158#define ENSURE_VALUE_IS_ARRAY(env, value) \
159 if (!EnsureValueIsArray(env, value, __FILE__, __LINE__)) return;
160#define ENSURE_VALUE_IS_ARRAY_RETVAL(env, value, retval) \
161 if (!EnsureValueIsArray(env, value, __FILE__, __LINE__)) return retval;
162
163inline bool EnsureValueIsArray(napi_env env, napi_value value, const char *file,
164 const size_t line_number) {
165 bool is_array;
166 ENSURE_NAPI_OK_RETVAL(env, napi_is_array(env, value, &is_array), false);
167 if (!is_array) {
168 NapiThrowError(env, file, line_number, "Argument is not an array!");
169 }
170 return is_array;
171}
172
173#define ENSURE_VALUE_IS_TYPED_ARRAY(env, value) \
174 if (!EnsureValueIsTypedArray(env, value, __FILE__, __LINE__)) return;
175#define ENSURE_VALUE_IS_TYPED_ARRAY_RETVAL(env, value, retval) \
176 if (!EnsureValueIsTypedArray(env, value, __FILE__, __LINE__)) return retval;
177
178inline bool EnsureValueIsTypedArray(napi_env env, napi_value value,
179 const char *file,
180 const size_t line_number) {
181 bool is_array;
182 ENSURE_NAPI_OK_RETVAL(env, napi_is_typedarray(env, value, &is_array), false);
183 if (!is_array) {
184 NapiThrowError(env, file, line_number, "Argument is not a typed-array!");
185 }
186 return is_array;
187}
188
189#define ENSURE_VALUE_IS_LESS_THAN(env, value, max) \
190 if (!EnsureValueIsLessThan(env, value, max, __FILE__, __LINE__)) return;
191#define ENSURE_VALUE_IS_LESS_THAN_RETVAL(env, value, max, retval) \
192 if (!EnsureValueIsLessThan(env, value, max, __FILE__, __LINE__)) \
193 return retval;
194
195inline bool EnsureValueIsLessThan(napi_env env, uint32_t value, uint32_t max,
196 const char *file, const size_t line_number) {
197 if (value > max) {
198 NapiThrowError(env, file, line_number,
199 "Argument is greater than max: %u > %u", value, max);
200 return false;
201 } else {
202 return true;
203 }
204}
205
206#define REPORT_UNKNOWN_TF_DATA_TYPE(env, type) \
207 ReportUnknownTFDataType(env, type, __FILE__, __LINE__)
208
209inline void ReportUnknownTFDataType(napi_env env, TF_DataType type,
210 const char *file,
211 const size_t line_number) {
212 NapiThrowError(env, file, line_number, "Unhandled TF_DataType: %u\n", type);
213}
214
215#define REPORT_UNKNOWN_TF_ATTR_TYPE(env, type) \
216 ReportUnknownTFAttrType(env, type, __FILE__, __LINE__)
217
218inline void ReportUnknownTFAttrType(napi_env env, TF_AttrType type,
219 const char *file,
220 const size_t line_number) {
221 NapiThrowError(env, file, line_number, "Unhandled TF_AttrType: %u\n", type);
222}
223
224#define REPORT_UNKNOWN_TYPED_ARRAY_TYPE(env, type) \
225 ReportUnknownTypedArrayType(env, type, __FILE__, __LINE__)
226
227inline void ReportUnknownTypedArrayType(napi_env env, napi_typedarray_type type,
228 const char *file,
229 const size_t line_number) {
230 NapiThrowError(env, file, line_number, "Unhandled napi typed_array_type: %u",
231 type);
232}
233
234// Returns a vector with the shape values of an array.
235inline void ExtractArrayShape(napi_env env, napi_value array_value,
236 std::vector<int64_t> *result) {
237 napi_status nstatus;
238
239 uint32_t array_length;
240 nstatus = napi_get_array_length(env, array_value, &array_length);
241 ENSURE_NAPI_OK(env, nstatus);
242
243 for (uint32_t i = 0; i < array_length; i++) {
244 napi_value dimension_value;
245 nstatus = napi_get_element(env, array_value, i, &dimension_value);
246 ENSURE_NAPI_OK(env, nstatus);
247
248 int64_t dimension;
249 nstatus = napi_get_value_int64(env, dimension_value, &dimension);
250 ENSURE_NAPI_OK(env, nstatus);
251
252 result->push_back(dimension);
253 }
254}
255
256inline bool IsExceptionPending(napi_env env) {
257 bool has_exception = false;
258 ENSURE_NAPI_OK_RETVAL(env, napi_is_exception_pending(env, &has_exception),
259 has_exception);
260 return has_exception;
261}
262
263#define ENSURE_VALUE_IS_NOT_NULL(env, value) \
264 if (!EnsureValueIsNotNull(env, value, __FILE__, __LINE__)) return;
265#define ENSURE_VALUE_IS_NOT_NULL_RETVAL(env, value, retval) \
266 if (!EnsureValueIsNotNull(env, value, __FILE__, __LINE__)) return retval;
267
268inline bool EnsureValueIsNotNull(napi_env env, void *value, const char *file,
269 const size_t line_number) {
270 bool is_null = value == nullptr;
271 if (is_null) {
272 NapiThrowError(env, file, line_number, "Argument is null!");
273 }
274 return !is_null;
275}
276
277inline napi_status GetStringParam(napi_env env, napi_value string_value,
278 std::string &string) {
279 ENSURE_VALUE_IS_STRING_RETVAL(env, string_value, napi_invalid_arg);
280
281 napi_status nstatus;
282
283 size_t str_length;
284 nstatus =
285 napi_get_value_string_utf8(env, string_value, nullptr, 0, &str_length);
286 ENSURE_NAPI_OK_RETVAL(env, nstatus, nstatus);
287
288 char *buffer = (char *)(malloc(sizeof(char) * (str_length + 1)));
289 ENSURE_VALUE_IS_NOT_NULL_RETVAL(env, buffer, napi_generic_failure);
290
291 nstatus = napi_get_value_string_utf8(env, string_value, buffer,
292 str_length + 1, &str_length);
293 ENSURE_NAPI_OK_RETVAL(env, nstatus, nstatus);
294
295 string.assign(buffer, str_length);
296 free(buffer);
297 return napi_ok;
298}
299
300// Returns the number of elements in a Tensor.
301inline size_t GetTensorNumElements(TF_Tensor *tensor) {
302 size_t ret = 1;
303 for (int i = 0; i < TF_NumDims(tensor); ++i) {
304 ret *= TF_Dim(tensor, i);
305 }
306 return ret;
307}
308
// Splits `str` on ',' and returns the non-empty pieces as newly allocated,
// NUL-terminated C strings, in order. Empty pieces (from leading, trailing or
// repeated commas) are skipped, so "" and ",," yield an empty vector.
// Ownership: the caller owns every returned pointer and must release each one
// with `delete[]`.
inline std::vector<const char *> splitStringByComma(const std::string &str) {
  std::vector<const char *> tokens;
  size_t start = 0;
  while (start < str.size()) {
    size_t end = str.find(',', start);
    if (end == std::string::npos) end = str.size();
    const size_t len = end - start;
    if (len > 0) {
      // Allocate exactly the token's size (+1 for the terminator), not the
      // whole input's size for every token.
      char *cstr = new char[len + 1];
      std::memcpy(cstr, str.data() + start, len);
      cstr[len] = '\0';
      tokens.push_back(cstr);
    }
    start = end + 1;
  }
  return tokens;
}
326
327} // namespace tfnodejs
328
329#endif // TF_NODEJS_UTILS_H_