/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { InferenceModel, io, ModelPredictConfig, NamedTensorMap, Tensor } from '@tensorflow/tfjs-core';
import { NamedTensorsMap, TensorInfo } from '../data/types';
export declare const TFHUB_SEARCH_PARAM = "?tfjs-format=file";
export declare const DEFAULT_MODEL_NAME = "model.json";
/**
 * A `tf.GraphModel` is a directed, acyclic graph built from a
 * SavedModel GraphDef and allows inference execution.
 *
 * A `tf.GraphModel` can only be created from a
 * [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model) that
 * has been converted with the command line converter tool and is then loaded
 * via `tf.loadGraphModel`.
 *
 * @doc {heading: 'Models', subheading: 'Classes'}
 */
export declare class GraphModel implements InferenceModel {
    private modelUrl;
    private loadOptions;
    private executor;
    private version;
    private handler;
    private artifacts;
    private initializer;
    private resourceManager;
    private signature;
    readonly modelVersion: string;
    readonly inputNodes: string[];
    readonly outputNodes: string[];
    readonly inputs: TensorInfo[];
    readonly outputs: TensorInfo[];
    readonly weights: NamedTensorsMap;
    readonly metadata: {};
    readonly modelSignature: {};
    /**
     * @param modelUrl url for the model, or an `io.IOHandler`.
     * @param loadOptions options for loading the model, such as a `RequestInit`
     *     for sending credentials and custom headers, and an optional
     *     `onProgress` callback that is fired periodically before the load
     *     completes.
     */
    constructor(modelUrl: string | io.IOHandler, loadOptions?: io.LoadOptions);
    private findIOHandler;
    /**
     * Loads the model and weight files, constructs the in-memory weight map
     * and compiles the inference graph.
     */
    load(): Promise<boolean>;
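    /*
     * A minimal sketch of constructing a `GraphModel` and calling `load()`
     * directly, as declared above; most callers use `tf.loadGraphModel`, which
     * performs the same steps. The model URL is an illustrative placeholder.
     *
     * ```js
     * const model = new tf.GraphModel('https://example.com/model.json');
     * const ok = await model.load();  // resolves to true once weights are mapped
     * ```
     */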
    /**
     * Synchronously constructs the in-memory weight map and compiles the
     * inference graph. Also initializes the hash tables, if any.
     *
     * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
     */
    loadSync(artifacts: io.ModelArtifacts): boolean;
    /**
     * Save the configuration and/or weights of the GraphModel.
     *
     * An `IOHandler` is an object that has a `save` method of the proper
     * signature defined. The `save` method manages the storing or
     * transmission of serialized data ("artifacts") that represent the
     * model's topology and weights onto or via a specific medium, such as
     * file downloads, local storage, IndexedDB in the web browser and HTTP
     * requests to a server. TensorFlow.js provides `IOHandler`
     * implementations for a number of frequently used saving mediums, such as
     * `tf.io.browserDownloads` and `tf.io.browserLocalStorage`. See `tf.io`
     * for more details.
     *
     * This method also allows you to refer to certain types of `IOHandler`s
     * as URL-like string shortcuts, such as 'localstorage://' and
     * 'indexeddb://'.
     *
     * Example 1: Save `model`'s topology and weights to browser [local
     * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
     * then load it back.
     *
     * ```js
     * const modelUrl =
     *    'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
     * const model = await tf.loadGraphModel(modelUrl);
     * const zeros = tf.zeros([1, 224, 224, 3]);
     * model.predict(zeros).print();
     *
     * const saveResults = await model.save('localstorage://my-model-1');
     *
     * const loadedModel = await tf.loadGraphModel('localstorage://my-model-1');
     * console.log('Prediction from loaded model:');
     * loadedModel.predict(zeros).print();
     * ```
     *
     * @param handlerOrURL An instance of `IOHandler` or a URL-like,
     *     scheme-based string shortcut for `IOHandler`.
     * @param config Options for saving the model.
     * @returns A `Promise` of `SaveResult`, which summarizes the result of
     *     the saving, such as byte sizes of the saved artifacts for the
     *     model's topology and weight values.
     *
     * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
     */
    save(handlerOrURL: io.IOHandler | string, config?: io.SaveConfig): Promise<io.SaveResult>;
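    /*
     * A brief sketch of the 'indexeddb://' shortcut mentioned above, assuming a
     * `GraphModel` has already been loaded as `model`; the key 'my-model-1' is
     * arbitrary.
     *
     * ```js
     * // Persist the model's topology and weights to the browser's IndexedDB.
     * await model.save('indexeddb://my-model-1');
     * // Later, restore it from the same key.
     * const restored = await tf.loadGraphModel('indexeddb://my-model-1');
     * ```
     */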
    /**
     * Execute the inference for the input tensors.
     *
     * @param inputs The input tensors. When the model has a single input, this
     * param should be a `tf.Tensor`. For models with multiple inputs, it
     * should be either a `tf.Tensor`[] if the input order is fixed, or
     * otherwise a NamedTensorMap.
     *
     * For models with multiple inputs, we recommend you use NamedTensorMap as
     * the input type. If you use `tf.Tensor`[], the order of the array needs
     * to follow the order of the inputNodes array.
     * @see {@link GraphModel.inputNodes}
     *
     * You can also feed any intermediate nodes using the NamedTensorMap as the
     * input type. For example, given the graph
     * InputNode => Intermediate => OutputNode,
     * you can execute the subgraph Intermediate => OutputNode by calling
     * model.execute({'IntermediateNode': tf.tensor(...)});
     *
     * This is useful for models that use tf.dynamic_rnn, where the
     * intermediate state needs to be fed manually.
     *
     * For batch inference execution, the tensors for each input need to be
     * concatenated together. For example with mobilenet, the required input
     * shape is [1, 224, 224, 3], which represents [batch, height, width,
     * channel]. If we provide a batch of 100 images, the input tensor should
     * have the shape [100, 224, 224, 3].
     *
     * @param config Prediction configuration for specifying the batch size and
     * output node names. Currently the batch size option is ignored for graph
     * models.
     *
     * @returns Inference result tensors. The output is a single `tf.Tensor` if
     * the model has a single output node, otherwise a Tensor[] or
     * NamedTensorMap[] is returned for models with multiple outputs.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    predict(inputs: Tensor | Tensor[] | NamedTensorMap, config?: ModelPredictConfig): Tensor | Tensor[] | NamedTensorMap;
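    /*
     * A minimal sketch of calling `predict` with a NamedTensorMap, assuming the
     * loaded model has a single input node named 'input'; inspect
     * `model.inputNodes` for the actual names.
     *
     * ```js
     * const batch = tf.zeros([2, 224, 224, 3]);      // batch of 2 images
     * const out = model.predict({'input': batch});   // keyed by input node name
     * out.print();
     * ```
     */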
    private normalizeInputs;
    private normalizeOutputs;
    /**
     * Executes inference for the model for the given input tensors.
     * @param inputs tensor, tensor array or tensor map of the inputs for the
     * model, keyed by the input node names.
     * @param outputs output node name(s) from the TensorFlow model. If no
     * outputs are specified, the default outputs of the model are used. You
     * can inspect intermediate nodes of the model by adding them to the
     * outputs array.
     *
     * @returns A single tensor if a single output is provided, or if no
     * outputs are provided and there is only one default output; otherwise a
     * tensor array is returned. The order of the tensor array is the same as
     * the outputs if provided, otherwise the order of the outputNodes
     * attribute of the model.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    execute(inputs: Tensor | Tensor[] | NamedTensorMap, outputs?: string | string[]): Tensor | Tensor[];
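    /*
     * A short sketch of requesting specific output nodes with `execute`; the
     * node names below are illustrative only (inspect `model.outputNodes` for
     * the real ones). The returned array follows the requested order.
     *
     * ```js
     * const input = tf.zeros([1, 224, 224, 3]);
     * const [logits, features] =
     *     model.execute(input, ['logits_node', 'intermediate_node']);
     * ```
     */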
    /**
     * Executes inference for the model for the given input tensors in async
     * fashion. Use this method when your model contains control flow ops.
     * @param inputs tensor, tensor array or tensor map of the inputs for the
     * model, keyed by the input node names.
     * @param outputs output node name(s) from the TensorFlow model. If no
     * outputs are specified, the default outputs of the model are used. You
     * can inspect intermediate nodes of the model by adding them to the
     * outputs array.
     *
     * @returns A Promise of a single tensor if a single output is provided, or
     * if no outputs are provided and there is only one default output;
     * otherwise a Promise of a tensor array is returned.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    executeAsync(inputs: Tensor | Tensor[] | NamedTensorMap, outputs?: string | string[]): Promise<Tensor | Tensor[]>;
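    /*
     * A minimal sketch of `executeAsync` for a model containing control flow
     * ops; usage mirrors `execute`, but the result is awaited.
     *
     * ```js
     * const input = tf.zeros([1, 224, 224, 3]);
     * const result = await model.executeAsync(input);  // Tensor or Tensor[]
     * ```
     */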
    private convertTensorMapToTensorsMap;
    /**
     * Releases the memory used by the weight tensors and resourceManager.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    dispose(): void;
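    /*
     * A brief sketch of releasing a model's memory when it is no longer
     * needed; tensors created by the caller still need to be disposed
     * separately (e.g. via `tf.dispose` or `tf.tidy`).
     *
     * ```js
     * const input = tf.zeros([1, 224, 224, 3]);
     * const out = model.predict(input);
     * tf.dispose([input, out]);  // dispose caller-created tensors
     * model.dispose();           // release the model's weight tensors
     * ```
     */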
}
/**
 * Load a graph model given a URL to the model definition.
 *
 * Example of loading MobileNetV2 from a URL and making a prediction with a
 * zeros input:
 *
 * ```js
 * const modelUrl =
 *    'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
 * const model = await tf.loadGraphModel(modelUrl);
 * const zeros = tf.zeros([1, 224, 224, 3]);
 * model.predict(zeros).print();
 * ```
 *
 * Example of loading MobileNetV2 from a TF Hub URL and making a prediction
 * with a zeros input:
 *
 * ```js
 * const modelUrl =
 *    'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2';
 * const model = await tf.loadGraphModel(modelUrl, {fromTFHub: true});
 * const zeros = tf.zeros([1, 224, 224, 3]);
 * model.predict(zeros).print();
 * ```
 * @param modelUrl The url or an `io.IOHandler` that loads the model.
 * @param options Options for the HTTP request, which allow sending credentials
 *     and custom headers.
 *
 * @doc {heading: 'Models', subheading: 'Loading'}
 */
export declare function loadGraphModel(modelUrl: string | io.IOHandler, options?: io.LoadOptions): Promise<GraphModel>;
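/*
 * A brief sketch of passing `io.LoadOptions` to `loadGraphModel`; `requestInit`
 * and `onProgress` are fields of `tf.io.LoadOptions`, and the header shown is a
 * hypothetical example.
 *
 * ```js
 * const model = await tf.loadGraphModel(modelUrl, {
 *   requestInit: {credentials: 'include', headers: {'Authorization': 'Bearer <token>'}},
 *   onProgress: (fraction) => console.log(`loading: ${(fraction * 100).toFixed(0)}%`),
 * });
 * ```
 */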