/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
|
17 |
|
18 |
|
19 |
|
20 |
|
21 |
|
22 |
|
23 |
|
24 |
|
25 |
|
26 |
|
27 |
|
28 | namespace tfnodejs {
|
29 |
|
30 | class TFJSBackend {
|
31 | public:
|
32 | // Creates, initializes, and returns a TFJSBackend instance. If initialization
|
33 | // fails, a nullptr is returned.
|
34 | static TFJSBackend *Create(napi_env env);
|
35 |
|
36 | // Creates a new Tensor with given shape and data and returns an ID that
|
37 | // refernces the new Tensor.
|
38 | // - shape_value (number[])
|
39 | // - dtype_value (number)
|
40 | // - array_value (TypedArray|Array)
|
41 | napi_value CreateTensor(napi_env env, napi_value shape_value,
|
42 | napi_value dtype_value, napi_value array_value);
|
43 |
|
44 | // Deletes a created Tensor.
|
45 | // - tensor_id_value (number)
|
46 | void DeleteTensor(napi_env env, napi_value tensor_id_value);
|
47 |
|
48 | // Returns a typed-array as a `napi_value` with the data associated with the
|
49 | // TF/TFE pointers.
|
50 | // - tensor_id_value (number)
|
51 | napi_value GetTensorData(napi_env env, napi_value tensor_id_value);
|
52 |
|
53 | // Executes a TFE Op and returns an array of objects containing tensor
|
54 | // attributes (id, dtype, shape).
|
55 | // - op_name_value (string)
|
56 | // - op_attr_inputs (array of TFE Op attributes)
|
57 | // - input_tensor_ids (array of input tensor IDs)
|
58 | // - num_output_values (number)
|
59 | napi_value ExecuteOp(napi_env env, napi_value op_name_value,
|
60 | napi_value op_attr_inputs, napi_value input_tensor_ids,
|
61 | napi_value num_output_values);
|
62 |
|
63 | // Load a SavedModel from a path:
|
64 | // - export_dir (string)
|
65 | // - tags_value (string)
|
66 | napi_value LoadSavedModel(napi_env env, napi_value export_dir,
|
67 | napi_value tags_value);
|
68 |
|
69 | // Delete the SavedModel corresponding TF_Session and TF_Graph
|
70 | // - saved_model_id (number)
|
71 | void DeleteSavedModel(napi_env env, napi_value saved_model_id);
|
72 |
|
73 | // Execute a session from SavedModel with the provided inputs:
|
74 | // - saved_model_id (number)
|
75 | // - input_tensor_ids (array of input tensor IDs)
|
76 | // - input_op_names (array of input op names)
|
77 | // - output_op_names (array of output op names)
|
78 | napi_value RunSavedModel(napi_env env, napi_value saved_model_id,
|
79 | napi_value input_tensor_ids,
|
80 | napi_value input_op_names,
|
81 | napi_value output_op_names);
|
82 |
|
83 | // Get number of loaded SavedModel in the backend:
|
84 | napi_value GetNumOfSavedModels(napi_env env);
|
85 |
|
86 | private:
|
87 | TFJSBackend(napi_env env);
|
88 | ~TFJSBackend();
|
89 |
|
90 | int32_t InsertHandle(TFE_TensorHandle *tfe_handle);
|
91 | int32_t InsertSavedModel(TF_Session *tf_session, TF_Graph *tf_graph);
|
92 | napi_value GenerateOutputTensorInfo(napi_env env, TFE_TensorHandle *handle);
|
93 |
|
94 | TFE_Context *tfe_context_;
|
95 | std::unordered_map<int32_t, TFE_TensorHandle *> tfe_handle_map_;
|
96 | std::unordered_map<int32_t, std::pair<TF_Session *, TF_Graph *>>
|
97 | tf_savedmodel_map_;
|
98 | int32_t next_tensor_id_;
|
99 | int32_t next_savedmodel_id_;
|
100 | std::string device_name;
|
101 |
|
102 | public:
|
103 | bool is_gpu_device;
|
104 | };
|
105 |
|
106 | } // namespace tfnodejs
|
107 |
|
108 |
|