/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-converter/dist/executor/graph_model" />
import { InferenceModel, io, ModelPredictConfig, NamedTensorMap, Tensor } from '@tensorflow/tfjs-core';
import { NamedTensorsMap, TensorInfo } from '../data/types';
export declare const TFHUB_SEARCH_PARAM = "?tfjs-format=file";
export declare const DEFAULT_MODEL_NAME = "model.json";
/**
 * A `tf.GraphModel` is a directed, acyclic graph built from a
 * SavedModel GraphDef and allows inference execution.
 *
 * A `tf.GraphModel` can only be created by loading a model that has been
 * converted from a [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model)
 * with the command line converter tool, and it is loaded via `tf.loadGraphModel`.
 *
 * @doc {heading: 'Models', subheading: 'Classes'}
 */
export declare class GraphModel implements InferenceModel {
    private modelUrl;
    private loadOptions;
    private executor;
    private version;
    private handler;
    private artifacts;
    private initializer;
    private resourceManager;
    private signature;
    readonly modelVersion: string;
    readonly inputNodes: string[];
    readonly outputNodes: string[];
    readonly inputs: TensorInfo[];
    readonly outputs: TensorInfo[];
    readonly weights: NamedTensorsMap;
    readonly metadata: {};
    readonly modelSignature: {};
    /**
     * @param modelUrl url for the model, or an `io.IOHandler`.
     * @param loadOptions options for the HTTP request, which allow sending
     * credentials and custom headers, and an optional `onProgress` callback,
     * fired periodically before the load is completed.
     */
    constructor(modelUrl: string | io.IOHandler, loadOptions?: io.LoadOptions);
    private findIOHandler;
    /**
     * Loads the model and weight files, constructs the in-memory weight map,
     * and compiles the inference graph.
     */
    load(): Promise<boolean>;
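    /*
     * A minimal usage sketch (the model URL and request options below are
     * placeholders): most callers use `tf.loadGraphModel`, which constructs a
     * GraphModel and calls `load()` internally, but the two steps can also be
     * done by hand.
     *
     * ```js
     * const model = new GraphModel('https://example.com/model.json',
     *     {requestInit: {credentials: 'include'}});
     * const ok = await model.load();  // resolves to true once the graph is compiled
     * ```
     */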
    /**
     * Synchronously constructs the in-memory weight map and compiles the
     * inference graph. Also initializes hashtables, if the model has any.
     *
     * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
     */
    loadSync(artifacts: io.ModelArtifacts): boolean;
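    /*
     * A minimal sketch of `loadSync`, assuming the artifacts were fetched
     * beforehand through an `io.IOHandler` (here `tf.io.http`, with a
     * placeholder URL).
     *
     * ```js
     * const url = 'https://example.com/model.json';
     * const artifacts = await tf.io.http(url).load();
     * const model = new GraphModel(url);
     * model.loadSync(artifacts);  // no further network requests are made
     * ```
     */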
    /**
     * Saves the configuration and/or weights of the GraphModel.
     *
     * An `IOHandler` is an object that has a `save` method of the proper
     * signature defined. The `save` method manages the storing or
     * transmission of serialized data ("artifacts") that represent the
     * model's topology and weights onto or via a specific medium, such as
     * file downloads, local storage, IndexedDB in the web browser and HTTP
     * requests to a server. TensorFlow.js provides `IOHandler`
     * implementations for a number of frequently used saving mediums, such as
     * `tf.io.browserDownloads` and `tf.io.browserLocalStorage`. See `tf.io`
     * for more details.
     *
     * This method also allows you to refer to certain types of `IOHandler`s
     * as URL-like string shortcuts, such as 'localstorage://' and
     * 'indexeddb://'.
     *
     * Example 1: Save `model`'s topology and weights to browser [local
     * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
     * then load it back.
     *
     * ```js
     * const modelUrl =
     *    'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
     * const model = await tf.loadGraphModel(modelUrl);
     * const zeros = tf.zeros([1, 224, 224, 3]);
     * model.predict(zeros).print();
     *
     * const saveResults = await model.save('localstorage://my-model-1');
     *
     * const loadedModel = await tf.loadGraphModel('localstorage://my-model-1');
     * console.log('Prediction from loaded model:');
     * loadedModel.predict(zeros).print();
     * ```
     *
     * @param handlerOrURL An instance of `IOHandler` or a URL-like,
     * scheme-based string shortcut for `IOHandler`.
     * @param config Options for saving the model.
     * @returns A `Promise` of `SaveResult`, which summarizes the result of
     * the saving, such as byte sizes of the saved artifacts for the model's
     * topology and weight values.
     *
     * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
     */
    save(handlerOrURL: io.IOHandler | string, config?: io.SaveConfig): Promise<io.SaveResult>;
    /**
     * Execute the inference for the input tensors.
     *
     * @param input The input tensors. When the model has a single input, the
     * inputs param should be a `tf.Tensor`. For models with multiple inputs,
     * the inputs param should be a `tf.Tensor[]` if the input order is fixed,
     * or otherwise a NamedTensorMap.
     *
     * For models with multiple inputs, we recommend you use NamedTensorMap as
     * the input type. If you use `tf.Tensor[]`, the order of the array needs
     * to follow the order of the inputNodes array.
     * @see {@link GraphModel.inputNodes}
     *
     * You can also feed any intermediate nodes using the NamedTensorMap as the
     * input type. For example, given the graph
     * InputNode => Intermediate => OutputNode,
     * you can execute the subgraph Intermediate => OutputNode by calling
     * model.execute({'IntermediateNode': tf.tensor(...)});
     *
     * This is useful for models that use tf.dynamic_rnn, where the
     * intermediate state needs to be fed manually.
     *
     * For batch inference execution, the tensors for each input need to be
     * concatenated together. For example with mobilenet, the required input
     * shape is [1, 224, 224, 3], which represents [batch, height, width,
     * channel]. If we provide a batch of 100 images, the input tensor should
     * have the shape [100, 224, 224, 3].
     *
     * @param config Prediction configuration for specifying the batch size and
     * output node names. Currently the batch size option is ignored for graph
     * models.
     *
     * @returns Inference result tensors. The output is a single `tf.Tensor`
     * if the model has a single output node; otherwise a `Tensor[]` or
     * `NamedTensorMap` is returned for models with multiple outputs.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    predict(inputs: Tensor | Tensor[] | NamedTensorMap, config?: ModelPredictConfig): Tensor | Tensor[] | NamedTensorMap;
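    /*
     * A minimal sketch of the two input formats, assuming a MobileNet-style
     * model with a single input of shape [batch, 224, 224, 3]; the node name
     * 'input' is a hypothetical example and must match the model's
     * inputNodes.
     *
     * ```js
     * // Single input: pass the (optionally batched) tensor directly.
     * const batch = tf.zeros([4, 224, 224, 3]);
     * const out = model.predict(batch);
     *
     * // Named inputs: key each tensor by its input node name.
     * const namedOut = model.predict({'input': batch});
     * ```
     */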
    private normalizeInputs;
    private normalizeOutputs;
    /**
     * Executes inference for the model for given input tensors.
     * @param inputs tensor, tensor array or tensor map of the inputs for the
     * model, keyed by the input node names.
     * @param outputs output node name(s) from the TensorFlow model. If no
     * outputs are specified, the default outputs of the model are used. You
     * can inspect intermediate nodes of the model by adding them to the
     * outputs array.
     *
     * @returns A single tensor if a single output is provided, or if no
     * outputs are provided and there is only one default output; otherwise a
     * tensor array. The order of the tensor array matches the provided
     * outputs, or otherwise the outputNodes attribute of the model.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    execute(inputs: Tensor | Tensor[] | NamedTensorMap, outputs?: string | string[]): Tensor | Tensor[];
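    /*
     * A minimal sketch of `execute` with explicit output names; the node
     * names 'InputNode', 'Intermediate' and 'OutputNode' are hypothetical
     * and must match nodes in the converted graph.
     *
     * ```js
     * // Fetch an intermediate node alongside the default output.
     * const [intermediate, output] = model.execute(
     *     {'InputNode': inputTensor}, ['Intermediate', 'OutputNode']);
     * ```
     */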
    /**
     * Executes inference for the model for given input tensors in an async
     * fashion. Use this method when your model contains control flow ops.
     * @param inputs tensor, tensor array or tensor map of the inputs for the
     * model, keyed by the input node names.
     * @param outputs output node name(s) from the TensorFlow model. If no
     * outputs are specified, the default outputs of the model are used. You
     * can inspect intermediate nodes of the model by adding them to the
     * outputs array.
     *
     * @returns A Promise of a single tensor if a single output is provided,
     * or if no outputs are provided and there is only one default output;
     * otherwise a Promise of a tensor array.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    executeAsync(inputs: Tensor | Tensor[] | NamedTensorMap, outputs?: string | string[]): Promise<Tensor | Tensor[]>;
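    /*
     * A minimal sketch of `executeAsync` for a graph that contains control
     * flow ops (for example a converted tf.while_loop); the node names
     * 'input_ids' and 'logits' are hypothetical.
     *
     * ```js
     * const logits = await model.executeAsync(
     *     {'input_ids': idsTensor}, 'logits');
     * ```
     */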
    /**
     * Gets intermediate tensors for model debugging mode (flag
     * KEEP_INTERMEDIATE_TENSORS is true).
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    getIntermediateTensors(): NamedTensorsMap;
    /**
     * Disposes intermediate tensors for model debugging mode (flag
     * KEEP_INTERMEDIATE_TENSORS is true).
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    disposeIntermediateTensors(): void;
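    /*
     * A minimal debugging sketch, assuming the KEEP_INTERMEDIATE_TENSORS
     * flag is turned on through `tf.env().set` before running inference.
     *
     * ```js
     * tf.env().set('KEEP_INTERMEDIATE_TENSORS', true);
     * model.predict(inputTensor);
     * const intermediates = model.getIntermediateTensors();
     * console.log(Object.keys(intermediates));  // intermediate node names
     * model.disposeIntermediateTensors();       // release them when done
     * ```
     */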
    private convertTensorMapToTensorsMap;
    /**
     * Releases the memory used by the weight tensors and resourceManager.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    dispose(): void;
}
/**
 * Load a graph model given a URL to the model definition.
 *
 * Example of loading MobileNetV2 from a URL and making a prediction with a
 * zeros input:
 *
 * ```js
 * const modelUrl =
 *    'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
 * const model = await tf.loadGraphModel(modelUrl);
 * const zeros = tf.zeros([1, 224, 224, 3]);
 * model.predict(zeros).print();
 * ```
 *
 * Example of loading MobileNetV2 from a TF Hub URL and making a prediction
 * with a zeros input:
 *
 * ```js
 * const modelUrl =
 *    'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2';
 * const model = await tf.loadGraphModel(modelUrl, {fromTFHub: true});
 * const zeros = tf.zeros([1, 224, 224, 3]);
 * model.predict(zeros).print();
 * ```
 * @param modelUrl The url or an `io.IOHandler` that loads the model.
 * @param options Options for the HTTP request, which allow sending
 * credentials and custom headers.
 *
 * @doc {heading: 'Models', subheading: 'Loading'}
 */
export declare function loadGraphModel(modelUrl: string | io.IOHandler, options?: io.LoadOptions): Promise<GraphModel>;