/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-converter/dist/executor/graph_model" />
import { InferenceModel, io, ModelPredictConfig, NamedTensorMap, Tensor } from '@tensorflow/tfjs-core';
import { NamedTensorsMap, TensorInfo } from '../data/types';
export declare const TFHUB_SEARCH_PARAM = "?tfjs-format=file";
export declare const DEFAULT_MODEL_NAME = "model.json";
type Url = string | io.IOHandler | io.IOHandlerSync;
type UrlIOHandler<T extends Url> = T extends string ? io.IOHandler : T;
/**
 * A `tf.GraphModel` is a directed, acyclic graph built from a
 * SavedModel GraphDef that allows inference execution.
 *
 * A `tf.GraphModel` can only be created from a
 * [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model) that
 * has been converted with the command line converter tool and then loaded
 * via `tf.loadGraphModel`.
 *
 * @doc {heading: 'Models', subheading: 'Classes'}
 */
export declare class GraphModel<ModelURL extends Url = string | io.IOHandler> implements InferenceModel {
    private modelUrl;
    private loadOptions;
    private executor;
    private version;
    private handler;
    private artifacts;
    private initializer;
    private resourceIdToCapturedInput;
    private resourceManager;
    private signature;
    private initializerSignature;
    private structuredOutputKeys;
    private readonly io;
    get modelVersion(): string;
    get inputNodes(): string[];
    get outputNodes(): string[];
    get inputs(): TensorInfo[];
    get outputs(): TensorInfo[];
    get weights(): NamedTensorsMap;
    get metadata(): {};
    get modelSignature(): {};
    get modelStructuredOutputKeys(): {};
    /**
     * @param modelUrl url for the model, or an `io.IOHandler`.
     * @param loadOptions options for loading the model, which allows sending
     * credentials and custom headers with the request, and an optional
     * `onProgress` callback that is fired periodically before the load is
     * completed.
     */
    constructor(modelUrl: ModelURL, loadOptions?: io.LoadOptions, tfio?: typeof io);
    private findIOHandler;
    /**
     * Loads the model and weight files, constructs the in-memory weight map
     * and compiles the inference graph.
     */
    load(): UrlIOHandler<ModelURL> extends io.IOHandlerSync ? boolean : Promise<boolean>;
    /**
     * Synchronously constructs the in-memory weight map and
     * compiles the inference graph.
     *
     * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
     */
    loadSync(artifacts: io.ModelArtifacts): boolean;
    private loadStreaming;
    private loadWithWeightMap;
    /**
     * Save the configuration and/or weights of the GraphModel.
     *
     * An `IOHandler` is an object that has a `save` method of the proper
     * signature defined. The `save` method manages the storing or
     * transmission of serialized data ("artifacts") that represent the
     * model's topology and weights onto or via a specific medium, such as
     * file downloads, local storage, IndexedDB in the web browser and HTTP
     * requests to a server. TensorFlow.js provides `IOHandler`
     * implementations for a number of frequently used saving mediums, such as
     * `tf.io.browserDownloads` and `tf.io.browserLocalStorage`. See `tf.io`
     * for more details.
     *
     * This method also allows you to refer to certain types of `IOHandler`s
     * as URL-like string shortcuts, such as 'localstorage://' and
     * 'indexeddb://'.
     *
     * Example 1: Save `model`'s topology and weights to browser [local
     * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
     * then load it back.
     *
     * ```js
     * const modelUrl =
     *    'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
     * const model = await tf.loadGraphModel(modelUrl);
     * const zeros = tf.zeros([1, 224, 224, 3]);
     * model.predict(zeros).print();
     *
     * const saveResults = await model.save('localstorage://my-model-1');
     *
     * const loadedModel = await tf.loadGraphModel('localstorage://my-model-1');
     * console.log('Prediction from loaded model:');
     * loadedModel.predict(zeros).print();
     * ```
     *
     * @param handlerOrURL An instance of `IOHandler` or a URL-like,
     * scheme-based string shortcut for `IOHandler`.
     * @param config Options for saving the model.
     * @returns A `Promise` of `SaveResult`, which summarizes the result of
     * the saving, such as byte sizes of the saved artifacts for the model's
     * topology and weight values.
     *
     * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
     */
    save(handlerOrURL: io.IOHandler | string, config?: io.SaveConfig): Promise<io.SaveResult>;
    private addStructuredOutputNames;
    /**
     * Executes inference for the input tensors.
     *
     * @param inputs The input tensors. When the model has a single input,
     * the inputs param should be a `tf.Tensor`. For models with multiple
     * inputs, the inputs param should be either a `tf.Tensor[]`, if the
     * input order is fixed, or a NamedTensorMap otherwise.
     *
     * For models with multiple inputs, we recommend using NamedTensorMap as
     * the input type. If you use `tf.Tensor[]`, the order of the array needs
     * to follow the order of the inputNodes array.
     * @see {@link GraphModel.inputNodes}
     *
     * You can also feed any intermediate nodes using the NamedTensorMap as
     * the input type. For example, given the graph
     * InputNode => Intermediate => OutputNode,
     * you can execute the subgraph Intermediate => OutputNode by calling
     * model.execute({'IntermediateNode': tf.tensor(...)});
     *
     * This is useful for models that use tf.dynamic_rnn, where the
     * intermediate state needs to be fed manually.
     *
     * For batch inference execution, the tensors for each input need to be
     * concatenated together. For example with mobilenet, the required input
     * shape is [1, 224, 224, 3], which represents [batch, height, width,
     * channel]. If we provide a batch of 100 images, the input tensor should
     * have the shape [100, 224, 224, 3].
     *
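     * Example (a minimal sketch; assumes `model` was loaded via
     * `tf.loadGraphModel`, and the multi-input node names below are
     * hypothetical and must match your own model's `inputNodes`):
     *
     * ```js
     * // Single-input model: pass one tensor (here a batch of 100 images).
     * const batched = tf.zeros([100, 224, 224, 3]);
     * const result = model.predict(batched);
     *
     * // Multi-input model: pass a NamedTensorMap keyed by input node names.
     * const namedResult = multiInputModel.predict({
     *   'input_ids': tf.zeros([1, 128], 'int32'),
     *   'attention_mask': tf.ones([1, 128], 'int32')
     * });
     * ```
     *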
     * @param config Prediction configuration for specifying the batch size.
     * Currently the batch size option is ignored for graph model.
     *
     * @returns Inference result tensors. If the model is converted and it
     * originally had structured_outputs in tensorflow, then a NamedTensorMap
     * will be returned matching the structured_outputs. If no
     * structured_outputs are present, the output will be a single `tf.Tensor`
     * if the model has a single output node, and a `tf.Tensor[]` otherwise.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    predict(inputs: Tensor | Tensor[] | NamedTensorMap, config?: ModelPredictConfig): Tensor | Tensor[] | NamedTensorMap;
    /**
     * Executes inference for the input tensors asynchronously. Use this
     * method when your model contains control flow ops.
     *
     * @param inputs The input tensors. When the model has a single input,
     * the inputs param should be a `tf.Tensor`. For models with multiple
     * inputs, the inputs param should be either a `tf.Tensor[]`, if the
     * input order is fixed, or a NamedTensorMap otherwise.
     *
     * For models with multiple inputs, we recommend using NamedTensorMap as
     * the input type. If you use `tf.Tensor[]`, the order of the array needs
     * to follow the order of the inputNodes array.
     * @see {@link GraphModel.inputNodes}
     *
     * You can also feed any intermediate nodes using the NamedTensorMap as
     * the input type. For example, given the graph
     * InputNode => Intermediate => OutputNode,
     * you can execute the subgraph Intermediate => OutputNode by calling
     * model.execute({'IntermediateNode': tf.tensor(...)});
     *
     * This is useful for models that use tf.dynamic_rnn, where the
     * intermediate state needs to be fed manually.
     *
     * For batch inference execution, the tensors for each input need to be
     * concatenated together. For example with mobilenet, the required input
     * shape is [1, 224, 224, 3], which represents [batch, height, width,
     * channel]. If we provide a batch of 100 images, the input tensor should
     * have the shape [100, 224, 224, 3].
     *
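     * Example (a minimal sketch; assumes `model` contains control flow ops
     * and was loaded via `tf.loadGraphModel`):
     *
     * ```js
     * const zeros = tf.zeros([1, 224, 224, 3]);
     * const result = await model.predictAsync(zeros);
     * ```
     *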
     * @param config Prediction configuration for specifying the batch size.
     * Currently the batch size option is ignored for graph model.
     *
     * @returns A Promise of inference result tensors. If the model is
     * converted and it originally had structured_outputs in tensorflow, then
     * a NamedTensorMap will be returned matching the structured_outputs. If
     * no structured_outputs are present, the output will be a single
     * `tf.Tensor` if the model has a single output node, and a `tf.Tensor[]`
     * otherwise.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    predictAsync(inputs: Tensor | Tensor[] | NamedTensorMap, config?: ModelPredictConfig): Promise<Tensor | Tensor[] | NamedTensorMap>;
    private normalizeInputs;
    private normalizeOutputs;
    private executeInitializerGraph;
    private executeInitializerGraphAsync;
    private setResourceIdToCapturedInput;
    /**
     * Executes inference for the model for the given input tensors.
     * @param inputs tensor, tensor array or tensor map of the inputs for the
     * model, keyed by the input node names.
     * @param outputs output node name(s) from the TensorFlow model; if no
     * outputs are specified, the default outputs of the model are used.
     * You can inspect intermediate nodes of the model by adding them to the
     * outputs array.
     *
     * @returns A single tensor if a single output is provided, or no outputs
     * are provided and there is only one default output; otherwise a tensor
     * array. The order of the tensor array is the same as the outputs if
     * provided, otherwise the order of the outputNodes attribute of the
     * model.
     *
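     * Example (a sketch; the output node names below are hypothetical and
     * would need to match nodes that actually exist in your model):
     *
     * ```js
     * const [logits, probabilities] = model.execute(
     *     tf.zeros([1, 224, 224, 3]),
     *     ['MobilenetV2/Logits/output', 'MobilenetV2/Predictions/Reshape_1']);
     * probabilities.print();
     * ```
     *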
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    execute(inputs: Tensor | Tensor[] | NamedTensorMap, outputs?: string | string[]): Tensor | Tensor[];
    /**
     * Executes inference for the model for the given input tensors
     * asynchronously. Use this method when your model contains control flow
     * ops.
     * @param inputs tensor, tensor array or tensor map of the inputs for the
     * model, keyed by the input node names.
     * @param outputs output node name(s) from the TensorFlow model; if no
     * outputs are specified, the default outputs of the model are used. You
     * can inspect intermediate nodes of the model by adding them to the
     * outputs array.
     *
     * @returns A Promise of a single tensor if a single output is provided,
     * or no outputs are provided and there is only one default output;
     * otherwise a tensor array.
     *
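     * Example (a sketch; assumes a control-flow model loaded via
     * `tf.loadGraphModel`, with a hypothetical output node name):
     *
     * ```js
     * const boxes = await model.executeAsync(
     *     tf.zeros([1, 224, 224, 3]), 'detection_boxes');
     * boxes.print();
     * ```
     *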
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    executeAsync(inputs: Tensor | Tensor[] | NamedTensorMap, outputs?: string | string[]): Promise<Tensor | Tensor[]>;
    /**
     * Gets intermediate tensors for model debugging mode (when the flag
     * KEEP_INTERMEDIATE_TENSORS is true).
     *
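     * Example (a sketch; assumes the flag is enabled before execution and
     * that `model` was loaded via `tf.loadGraphModel`):
     *
     * ```js
     * tf.env().set('KEEP_INTERMEDIATE_TENSORS', true);
     * model.execute(tf.zeros([1, 224, 224, 3]));
     * const intermediates = model.getIntermediateTensors();
     * console.log(Object.keys(intermediates));
     * model.disposeIntermediateTensors();
     * ```
     *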
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    getIntermediateTensors(): NamedTensorsMap;
    /**
     * Disposes intermediate tensors for model debugging mode (when the flag
     * KEEP_INTERMEDIATE_TENSORS is true).
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    disposeIntermediateTensors(): void;
    private convertTensorMapToTensorsMap;
    /**
     * Releases the memory used by the weight tensors and resourceManager.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    dispose(): void;
}
/**
 * Load a graph model given a URL to the model definition.
 *
 * Example of loading MobileNetV2 from a URL and making a prediction with a
 * zeros input:
 *
 * ```js
 * const modelUrl =
 *    'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
 * const model = await tf.loadGraphModel(modelUrl);
 * const zeros = tf.zeros([1, 224, 224, 3]);
 * model.predict(zeros).print();
 * ```
 *
 * Example of loading MobileNetV2 from a TF Hub URL and making a prediction
 * with a zeros input:
 *
 * ```js
 * const modelUrl =
 *    'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2';
 * const model = await tf.loadGraphModel(modelUrl, {fromTFHub: true});
 * const zeros = tf.zeros([1, 224, 224, 3]);
 * model.predict(zeros).print();
 * ```
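 *
 * Example of passing HTTP request options and a progress callback (a
 * sketch; the URL is a placeholder):
 *
 * ```js
 * const model = await tf.loadGraphModel('https://example.com/model.json', {
 *   requestInit: {credentials: 'include'},
 *   onProgress: (fraction) => console.log(`loading: ${fraction * 100}%`)
 * });
 * ```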
 * @param modelUrl The url or an `io.IOHandler` that loads the model.
 * @param options Options for the HTTP request, which allows sending
 * credentials and custom headers.
 *
 * @doc {heading: 'Models', subheading: 'Loading'}
 */
export declare function loadGraphModel(modelUrl: string | io.IOHandler, options?: io.LoadOptions, tfio?: typeof io): Promise<GraphModel>;
/**
 * Load a graph model given a synchronous IO handler with a 'load' method.
 *
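 * Example (a sketch; assumes `modelJson` (an `io.ModelJSON` object) and
 * `weights` (an `ArrayBuffer` of weight data) were bundled with the
 * application):
 *
 * ```js
 * const model = tf.loadGraphModelSync([modelJson, weights]);
 * model.predict(tf.zeros([1, 224, 224, 3])).print();
 * ```
 *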
 * @param modelSource The `io.IOHandlerSync` that loads the model, or the
 * `io.ModelArtifacts` that encode the model, or a tuple of
 * `[io.ModelJSON, ArrayBuffer]` of which the first element encodes the
 * model and the second contains the weights.
 *
 * @doc {heading: 'Models', subheading: 'Loading'}
 */
export declare function loadGraphModelSync(modelSource: io.IOHandlerSync | io.ModelArtifacts | [io.ModelJSON, /* Weights */ ArrayBuffer]): GraphModel<io.IOHandlerSync>;
export {};