UNPKG

3.88 kB · text/x-c · View Raw
1/**
2 * @license
3 * Copyright 2018 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17
18#ifndef TF_NODEJS_TFJS_BACKEND_H_
19#define TF_NODEJS_TFJS_BACKEND_H_
20
21#include <node_api.h>
22#include <memory>
23#include <string>
24#include <unordered_map>
25#include "tensorflow/c/c_api.h"
26#include "tensorflow/c/eager/c_api.h"
27
28namespace tfnodejs {
29
30class TFJSBackend {
31 public:
32 // Creates, initializes, and returns a TFJSBackend instance. If initialization
33 // fails, a nullptr is returned.
34 static TFJSBackend *Create(napi_env env);
35
36 // Creates a new Tensor with given shape and data and returns an ID that
37 // refernces the new Tensor.
38 // - shape_value (number[])
39 // - dtype_value (number)
40 // - array_value (TypedArray|Array)
41 napi_value CreateTensor(napi_env env, napi_value shape_value,
42 napi_value dtype_value, napi_value array_value);
43
44 // Deletes a created Tensor.
45 // - tensor_id_value (number)
46 void DeleteTensor(napi_env env, napi_value tensor_id_value);
47
48 // Returns a typed-array as a `napi_value` with the data associated with the
49 // TF/TFE pointers.
50 // - tensor_id_value (number)
51 napi_value GetTensorData(napi_env env, napi_value tensor_id_value);
52
53 // Executes a TFE Op and returns an array of objects containing tensor
54 // attributes (id, dtype, shape).
55 // - op_name_value (string)
56 // - op_attr_inputs (array of TFE Op attributes)
57 // - input_tensor_ids (array of input tensor IDs)
58 // - num_output_values (number)
59 napi_value ExecuteOp(napi_env env, napi_value op_name_value,
60 napi_value op_attr_inputs, napi_value input_tensor_ids,
61 napi_value num_output_values);
62
63 // Load a SavedModel from a path:
64 // - export_dir (string)
65 // - tags_value (string)
66 napi_value LoadSavedModel(napi_env env, napi_value export_dir,
67 napi_value tags_value);
68
69 // Delete the SavedModel corresponding TF_Session and TF_Graph
70 // - saved_model_id (number)
71 void DeleteSavedModel(napi_env env, napi_value saved_model_id);
72
73 // Execute a session from SavedModel with the provided inputs:
74 // - saved_model_id (number)
75 // - input_tensor_ids (array of input tensor IDs)
76 // - input_op_names (array of input op names)
77 // - output_op_names (array of output op names)
78 napi_value RunSavedModel(napi_env env, napi_value saved_model_id,
79 napi_value input_tensor_ids,
80 napi_value input_op_names,
81 napi_value output_op_names);
82
83 // Get number of loaded SavedModel in the backend:
84 napi_value GetNumOfSavedModels(napi_env env);
85
86 private:
87 TFJSBackend(napi_env env);
88 ~TFJSBackend();
89
90 int32_t InsertHandle(TFE_TensorHandle *tfe_handle);
91 int32_t InsertSavedModel(TF_Session *tf_session, TF_Graph *tf_graph);
92 napi_value GenerateOutputTensorInfo(napi_env env, TFE_TensorHandle *handle);
93
94 TFE_Context *tfe_context_;
95 std::unordered_map<int32_t, TFE_TensorHandle *> tfe_handle_map_;
96 std::unordered_map<int32_t, std::pair<TF_Session *, TF_Graph *>>
97 tf_savedmodel_map_;
98 int32_t next_tensor_id_;
99 int32_t next_savedmodel_id_;
100 std::string device_name;
101
102 public:
103 bool is_gpu_device;
104};
105
106} // namespace tfnodejs
107
108#endif // TF_NODEJS_TFJS_BACKEND_H_