1 | # Getting started
|
2 |
|
3 | **TensorFlow.js converter** is an open source library to load a pretrained
|
4 | TensorFlow
|
5 | [SavedModel](https://www.tensorflow.org/guide/saved_model)
|
6 | or [TensorFlow Hub module](https://www.tensorflow.org/hub/)
|
7 | into the browser and run inference through
|
8 | [TensorFlow.js](https://js.tensorflow.org).
|
9 |
|
__Note__: _The session bundle format has been deprecated._
|
11 |
|
12 | A 2-step process to import your model:
|
13 |
|
14 | 1. A python pip package to convert a TensorFlow SavedModel or TensorFlow Hub
|
15 | module to a web friendly format. If you already have a converted model, or are
|
16 | using an already hosted model (e.g. MobileNet), skip this step.
|
17 | 2. [JavaScript API](./src/executor/graph_model.ts), for loading and running
|
18 | inference.
|
19 |
|
20 | ## Step 1: Converting a [TensorFlow SavedModel](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md), [TensorFlow Hub module](https://www.tensorflow.org/hub/), [Keras HDF5](https://keras.io/getting_started/faq/#what-are-my-options-for-saving-models), [tf.keras SavedModel](https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model), or [Flax/JAX model](http://github.com/google/flax) to a web-friendly format
|
21 |
|
22 | __0. Please make sure that you run in a Docker container or a virtual environment.__
|
23 |
|
24 | The script pulls its own subset of TensorFlow, which might conflict with the
|
25 | existing TensorFlow/Keras installation.
|
26 |
|
27 | __Note__: *Check that [`tf-nightly-cpu-2.0-preview`](https://pypi.org/project/tf-nightly-cpu-2.0-preview/#files) is available for your platform.*
|
28 |
|
Most of the time, this means that you have to use Python 3.6.8 in your local
|
30 | environment. To force Python 3.6.8 in your local project, you can install
|
31 | [`pyenv`](https://github.com/pyenv/pyenv) and proceed as follows in the target
|
32 | directory:
|
33 |
|
34 | ```bash
|
35 | pyenv install 3.6.8
|
36 | pyenv local 3.6.8
|
37 | ```
|
38 |
|
39 | Now, you can
|
40 | [create and activate](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)
|
41 | a `venv` virtual environment in your current folder:
|
42 |
|
```bash
virtualenv --no-site-packages venv
. venv/bin/activate
```

(Note: recent versions of virtualenv isolate from the system site-packages by
default and have removed the `--no-site-packages` flag, so omit it there.)
|
47 |
|
48 | __1. Install the TensorFlow.js pip package:__
|
49 |
|
Install the library with the interactive CLI:
|
51 | ```bash
|
52 | pip install tensorflowjs[wizard]
|
53 | ```
|
54 |
|
55 | __2. Run the converter script provided by the pip package:__
|
56 |
|
There are three ways to trigger the model conversion, explained below:
|
58 |
|
59 | - The conversion wizard: `tensorflowjs_wizard` ([go to section](https://github.com/tensorflow/tfjs/blob/master/tfjs-converter/README.md#conversion-wizard-tensorflowjs_wizard))
|
60 | - Regular conversion script: `tensorflowjs_converter` ([go to section](https://github.com/tensorflow/tfjs/tree/master/tfjs-converter#regular-conversion-script-tensorflowjs_converter))
|
61 | - Calling a converter function in Python (Flax/JAX) ([go to section](https://github.com/tensorflow/tfjs/tree/master/tfjs-converter#calling-a-converter-function-in-python-flaxjax))
|
62 |
|
63 | ## Conversion wizard: `tensorflowjs_wizard`
|
64 |
|
65 | To start the conversion wizard:
|
66 | ```bash
|
67 | tensorflowjs_wizard
|
68 | ```
|
69 |
|
This tool will walk you through the conversion process and provide you with
detailed explanations for each choice you need to make. Behind the scenes it
calls the converter script (`tensorflowjs_converter`) from the pip package.
This is the recommended way to convert a single model.
|
74 |
|
There is also a dry-run mode for the wizard, which will not perform the actual
conversion but only generate the `tensorflowjs_converter` command. The
generated command can be used in your own shell script:

```bash
tensorflowjs_wizard --dryrun
```

Here is a screen capture of the wizard in action. ![wizard](./tensorflowjs_wizard.gif)
|
83 |
|
To convert a batch of models or integrate the conversion process into your own
script, you should use the `tensorflowjs_converter` script (a sketch follows).
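For instance, assuming a set of hypothetical Keras HDF5 models (the names and
paths below are placeholders), a loop like this could drive the converter:

```bash
# Convert several Keras HDF5 models in one go; adjust names and paths as needed.
for m in model_a model_b model_c; do
  tensorflowjs_converter \
      --input_format=keras \
      "/tmp/${m}.h5" \
      "/tmp/web_models/${m}"
done
```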
|
86 |
|
87 | ## Regular conversion script: `tensorflowjs_converter`
|
88 |
|
89 | The converter expects a __TensorFlow SavedModel__, __TensorFlow Hub module__,
|
90 | __TensorFlow.js JSON__ format, __Keras HDF5 model__, or __tf.keras SavedModel__
|
91 | for input.
|
92 |
|
93 | * __TensorFlow SavedModel__ example:
|
94 |
|
95 | ```bash
|
96 | tensorflowjs_converter \
|
97 | --input_format=tf_saved_model \
|
98 | --output_format=tfjs_graph_model \
|
99 | --signature_name=serving_default \
|
100 | --saved_model_tags=serve \
|
101 | /mobilenet/saved_model \
|
102 | /mobilenet/web_model
|
103 | ```
|
104 | * __TensorFlow Frozen Model__ example:
|
105 |
|
__Note:__ Frozen model is a deprecated format, and support is kept for backward compatibility purposes.
|
107 |
|
108 | ```bash
|
tensorflowjs_converter \
|
110 | --input_format=tf_frozen_model \
|
111 | --output_node_names='MobilenetV1/Predictions/Reshape_1' \
|
112 | /mobilenet/frozen_model.pb \
|
113 | /mobilenet/web_model
|
114 | ```
|
115 |
|
* __TensorFlow Hub module__ example:
|
117 |
|
118 | ```bash
|
119 | tensorflowjs_converter \
|
120 | --input_format=tf_hub \
|
121 | 'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
|
122 | /mobilenet/web_model
|
123 | ```
|
124 |
|
125 | * __Keras HDF5 model__ example:
|
126 |
|
127 | ```bash
|
128 | tensorflowjs_converter \
|
129 | --input_format=keras \
|
130 | /tmp/my_keras_model.h5 \
|
131 | /tmp/my_tfjs_model
|
132 | ```
|
133 |
|
134 | * __tf.keras SavedModel__ example:
|
135 |
|
136 | ```bash
|
137 | tensorflowjs_converter \
|
138 | --input_format=keras_saved_model \
|
139 | /tmp/my_tf_keras_saved_model/1542211770 \
|
140 | /tmp/my_tfjs_model
|
141 | ```
|
142 |
|
Note that the input path used above is a subfolder that has a Unix epoch
time (1542211770) and is generated automatically by TensorFlow when it
saves a tf.keras model in the SavedModel format.
|
146 |
|
147 | ### Conversion Flags
|
148 |
|
149 | |Positional Arguments | Description |
|
150 | |---|---|
|
151 | |`input_path` | Full path of the saved model directory or TensorFlow Hub module handle or path.|
|
152 | |`output_path` | Path for all output artifacts.|
|
153 |
|
154 |
|
| Options | Description |
|
156 | |---|---|
|
|`--input_format` | The format of the input model: use `tf_saved_model` for SavedModel, `tf_frozen_model` for frozen model, `tf_hub` for TensorFlow Hub module, `tfjs_layers_model` for TensorFlow.js JSON format, `keras` for Keras HDF5, and `keras_saved_model` for tf.keras SavedModel. |
|
158 | |`--output_format`| The desired output format. Must be `tfjs_layers_model`, `tfjs_graph_model` or `keras`. Not all pairs of input-output formats are supported. Please file a [github issue](https://github.com/tensorflow/tfjs/issues) if your desired input-output pair is not supported.|
|
159 | |<nobr>`--saved_model_tags`</nobr> | Only applicable to SavedModel conversion. Tags of the MetaGraphDef to load, in comma separated format. If there are no tags defined in the saved model, set it to empty string `saved_model_tags=''`. Defaults to `serve`.|
|
160 | |`--signature_name` | Only applicable to TensorFlow SavedModel and Hub module conversion, signature to load. Defaults to `serving_default` for SavedModel and `default` for Hub module. See https://www.tensorflow.org/hub/common_signatures/.|
|
161 | |`--strip_debug_ops` | Strips out TensorFlow debug operations `Print`, `Assert`, `CheckNumerics`. Defaults to `True`.|
|
|`--quantization_bytes` | (Deprecated) How many bytes to optionally quantize/compress the weights to. Valid values are 1 and 2, which will quantize int32 and float32 to 1 or 2 bytes respectively. The default (unquantized) size is 4 bytes.|
|
163 | |`--quantize_float16` | Comma separated list of node names to apply float16 quantization. You can also use wildcard symbol (\*) to apply quantization to multiple nodes (e.g., conv/\*/weights). When the flag is provided without any nodes the default behavior will match all nodes. |
|
164 | |`--quantize_uint8` | Comma separated list of node names to apply 1-byte affine quantization. You can also use wildcard symbol (\*) to apply quantization to multiple nodes (e.g., conv/\*/weights). When the flag is provided without any nodes the default behavior will match all nodes. |
|
165 | |`--quantize_uint16` | Comma separated list of node names to apply 2-byte affine quantization. You can also use wildcard symbol (\*) to apply quantization to multiple nodes (e.g., conv/\*/weights). When the flag is provided without any nodes the default behavior will match all nodes. |
|
166 | |`--weight_shard_size_bytes` | Shard size (in bytes) of the weight files. Only supported when `output_format` is `tfjs_layers_model` or `tfjs_graph_model`. Default size is 4 MB (4194304 bytes).|
|
167 | |<nobr>`--output_node_names`</nobr>| Only applicable to Frozen Model. The names of the output nodes, separated by commas.|
|
|<nobr>`--control_flow_v2`</nobr>| Only applicable to TF 2.x SavedModel. This flag improves performance on models with control flow ops. Defaults to `False`.|
|
169 | |<nobr>`--metadata`</nobr>| Comma separated list of metadata json file paths, indexed by name. Prefer absolute path. Example: 'metadata1:/metadata1.json,metadata2:/metadata2.json'.|
|
170 | |<nobr>`--use_structured_outputs_names`</nobr>| Changes output of graph model to match the structured_outputs format instead of list format. Defaults to `False`.|
|
171 |
|
__Note: If you want to convert a TensorFlow session bundle, you can install older versions of the tensorflowjs pip package, e.g. `pip install tensorflowjs==0.8.6`.__
|
173 |
|
174 | ### Format Conversion Support Tables
|
175 |
|
Note: Unless stated otherwise, we can infer the value of `--output_format` from
the value of `--input_format`, so the `--output_format` flag can be omitted in
most cases (see the example below).
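For instance, the Keras HDF5 example from above still produces a
`tfjs_layers_model` when `--output_format` is omitted, since that is the
inferred default for `keras` input (paths are placeholders):

```bash
# --output_format=tfjs_layers_model is inferred from --input_format=keras.
tensorflowjs_converter \
    --input_format=keras \
    /tmp/my_keras_model.h5 \
    /tmp/my_tfjs_model
```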
|
179 |
|
180 | #### Python-to-JavaScript
|
181 |
|
182 | | `--input_format` | `--output_format` | Description |
|
183 | |---|---|---|
|
184 | | `keras` | `tfjs_layers_model` | Convert a keras or tf.keras HDF5 model file to TensorFlow.js Layers model format. Use [`tf.loadLayersModel()`](https://js.tensorflow.org/api/latest/#loadLayersModel) to load the model in JavaScript. The loaded model supports the full inference and training (e.g., transfer learning) features of the original keras or tf.keras model. |
|
185 | | `keras` | `tfjs_graph_model` | Convert a keras or tf.keras HDF5 model file to TensorFlow.js Graph model format. Use [`tf.loadGraphModel()`](https://js.tensorflow.org/api/latest/#loadGraphModel) to load the converted model in JavaScript. The loaded model supports only inference, but the speed of inference is generally faster than that of a tfjs_layers_model (see above row) thanks to the graph optimization performed by TensorFlow. Another limitation of this conversion route is that it does not support some layer types (e.g., recurrent layers such as LSTM) yet. |
|
186 | | `keras_saved_model` | `tfjs_layers_model` | Convert a tf.keras SavedModel model file (from [`tf.contrib.saved_model.save_keras_model`](https://www.tensorflow.org/api_docs/python/tf/contrib/saved_model/save_keras_model)) to TensorFlow.js Layers model format. Use [`tf.loadLayersModel()`](https://js.tensorflow.org/api/latest/#loadLayersModel) to load the model in JavaScript. |
|
187 | | `tf_hub` | `tfjs_graph_model` | Convert a [TF-Hub](https://www.tensorflow.org/hub) model file to TensorFlow.js graph model format. Use [`tf.loadGraphModel()`](https://js.tensorflow.org/api/latest/#loadGraphModel) to load the converted model in JavaScript. |
|
188 | | `tf_saved_model` | `tfjs_graph_model` | Convert a [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model#build_and_load_a_savedmodel) to TensorFlow.js graph model format. Use [`tf.loadGraphModel()`](https://js.tensorflow.org/api/latest/#loadGraphModel) to load the converted model in JavaScript. |
|
189 | | `tf_frozen_model` | `tfjs_graph_model` | Convert a [Frozen Model](https://medium.com/@sebastingarcaacosta/how-to-export-a-tensorflow-2-x-keras-model-to-a-frozen-and-optimized-graph-39740846d9eb) to TensorFlow.js graph model format. Use [`tf.loadGraphModel()`](https://js.tensorflow.org/api/latest/#loadGraphModel) to load the converted model in JavaScript. |
|
190 |
|
191 | #### JavaScript-to-Python
|
192 |
|
193 | | `--input_format` | `--output_format` | Description |
|
194 | |---|---|---|
|
195 | | `tfjs_layers_model` | `keras` | Convert a TensorFlow.js Layers model (JSON + binary weight file(s)) to a Keras HDF5 model file. Use [`keras.model.load_model()`](https://keras.io/getting-started/faq/#savingloading-whole-models-architecture-weights-optimizer-state) or [`tf.keras.models.load_model()`](https://www.tensorflow.org/api_docs/python/tf/keras/models/load_model) to load the converted model in Python. |
|
196 | | `tfjs_layers_model` | `keras_saved_model` | Convert a TensorFlow.js Layers model (JSON + binary weight file(s)) to the tf.keras SavedModel format. This format is useful for subsequent uses such as [TensorFlow Serving](https://www.tensorflow.org/tfx/serving/serving_basic) and [conversion to TFLite](https://www.tensorflow.org/lite/convert). |
|
197 |
|
198 | #### JavaScript-to-JavaScript
|
199 |
|
200 | ##### Converting tfjs_layers_model to tfjs_layers_model with weight sharding and quantization
|
201 |
|
The tfjs_layers_model-to-tfjs_layers_model conversion option serves the following
|
203 | purposes:
|
204 |
|
205 | 1. It allows you to shard the binary weight file into multiple small shards
|
206 | to facilitate browser caching. This step is necessary for models with
|
207 | large-sized weights saved from TensorFlow.js (either browser or Node.js),
|
208 | because TensorFlow.js puts all weights in a single weight file
|
209 | ('group1-shard1of1.bin'). To shard the weight file, do
|
210 |
|
211 | ```sh
|
212 | tensorflowjs_converter \
|
213 | --input_format tfjs_layers_model \
|
214 | --output_format tfjs_layers_model \
|
215 | original_model/model.json \
|
216 | sharded_model/
|
217 | ```
|
218 |
|
The command above creates shards of size 4 MB (4194304 bytes) by default.
Alternative shard sizes can be specified using the
`--weight_shard_size_bytes` flag (see the example after this list).
|
222 |
|
223 | 2. It allows you to reduce the on-the-wire size of the weights through
|
224 | 16- or 8-bit quantization. For example:
|
225 |
|
226 | ```sh
|
227 | tensorflowjs_converter \
|
228 | --input_format tfjs_layers_model \
|
229 | --output_format tfjs_layers_model \
|
230 | --quantize_uint16 \
|
original_model/model.json \
|
232 | quantized_model/
|
233 | ```
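For example, a 1 MB shard size (1048576 bytes) for the sharding command above
could be requested as follows (a sketch; the paths match the earlier example):

```sh
tensorflowjs_converter \
    --input_format tfjs_layers_model \
    --output_format tfjs_layers_model \
    --weight_shard_size_bytes 1048576 \
    original_model/model.json \
    sharded_model/
```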
|
234 |
|
235 | ##### Converting tfjs_layers_model to tfjs_graph_model
|
236 |
|
237 | Converting a `tfjs_layers_model` to a `tfjs_graph_model` usually leads to
|
238 | faster inference speed in the browser and Node.js, thanks to the graph
|
239 | optimization that goes into generating the tfjs_graph_models. For more details,
|
240 | see the following document on TensorFlow's Grappler:
|
241 | ["TensorFlow Graph Optimizations" by R. Larsen an T. Shpeisman](https://ai.google/research/pubs/pub48051).
|
242 |
|
243 | There are two caveats:
|
244 |
|
245 | 1. The model that results from this conversion does not support further
|
246 | training.
|
247 | 2. Certain layer types (e.g., recurrent layers such as LSTM) are not supported
|
248 | yet.
|
249 |
|
250 | See example command below:
|
251 |
|
252 | ```sh
|
253 | tensorflowjs_converter \
|
254 | --input_format tfjs_layers_model \
|
255 | --output_format tfjs_graph_model \
|
my_layers_model/model.json \
|
257 | my_graph_model/
|
258 | ```
|
259 |
|
The tfjs_layers_model-to-tfjs_graph_model conversion also supports weight quantization.
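For example, here is a sketch combining this conversion with float16
quantization of all weights, using the flags documented in the table above:

```sh
tensorflowjs_converter \
    --input_format tfjs_layers_model \
    --output_format tfjs_graph_model \
    --quantize_float16 \
    my_layers_model/model.json \
    my_graph_model/
```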
|
261 |
|
262 | ### Web-friendly format
|
263 |
|
The conversion script above produces two types of files:
|
265 |
|
266 | * `model.json` (the dataflow graph and weight manifest file)
|
267 | * `group1-shard\*of\*` (collection of binary weight files)
|
268 |
|
For example, here is the MobileNet model converted and served at the
following location:
|
271 |
|
```
|
273 | https://storage.cloud.google.com/tfjs-models/savedmodel/mobilenet_v1_1.0_224/model.json
|
274 | https://storage.cloud.google.com/tfjs-models/savedmodel/mobilenet_v1_1.0_224/group1-shard1of5
|
275 | ...
|
276 | https://storage.cloud.google.com/tfjs-models/savedmodel/mobilenet_v1_1.0_224/group1-shard5of5
|
277 | ```
|
278 |
|
279 | ## Calling a Converter Function in Python (Flax/JAX)
|
280 |
|
281 | You can also convert your model to web format in Python by calling one of the
|
282 | conversion functions. This is currently the only way to convert a Flax or JAX
|
283 | model, since no standard serialization format exists to store a Module (only the
|
284 | checkpoints).
|
285 |
|
Here we provide an example of how to convert a Flax module using the
conversion function `jax_conversion.convert_jax()` from `tensorflowjs.converters`.
|
288 |
|
289 | ```py
|
290 | import numpy as np
|
291 | from flax import linen as nn
|
292 | from jax import random
|
293 | import jax.numpy as jnp
|
294 | from tensorflowjs.converters import jax_conversion
|
295 |
|
296 | module = nn.Dense(features=4)
|
297 | inputs = jnp.ones((3, 4))
|
params = module.init(random.PRNGKey(0), inputs)['params']

# Destination directory for the converted model (placeholder path).
tfjs_model_dir = '/tmp/tfjs_jax_model'

jax_conversion.convert_jax(
|
301 | apply_fn=module.apply,
|
302 | params=params,
|
303 | input_signatures=[((3, 4), np.float32)],
|
304 | model_dir=tfjs_model_dir)
|
305 | ```
|
306 |
|
Note that when using dynamic shapes, an additional argument `polymorphic_shapes`
should be provided, specifying values for the dynamic ("polymorphic")
dimensions. So in order to convert the same model as before, but now with a
dynamic first dimension, one should call `convert_jax` as follows:
|
311 |
|
312 | ```py
|
313 | jax_conversion.convert_jax(
|
314 | apply_fn=module.apply,
|
315 | params=params,
|
316 | input_signatures=[((None, 4), np.float32)],
|
317 | polymorphic_shapes=["(b, 4)"],
|
318 | model_dir=tfjs_model_dir)
|
319 | ```
|
320 |
|
321 | See
|
322 | [here](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#shape-polymorphic-conversion)
|
323 | for more details on the exact syntax for this argument.
|
324 |
|
325 | When converting JAX models, you can also pass any [options that
|
326 | `convert_tf_saved_model`
|
327 | uses](https://github.com/tensorflow/tfjs/blob/master/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py#L951-L974).
|
328 | For example, to quantize a model's weights, pass the `quantization_dtype_map`
|
329 | option listing the weights that should be quantized.
|
330 |
|
331 | ```py
|
332 | jax_conversion.convert_jax(
|
333 | apply_fn=module.apply,
|
334 | params=params,
|
335 | input_signatures=[((3, 4), np.float32)],
|
336 | model_dir=tfjs_model_dir,
|
337 | quantization_dtype_map={'float16': '*'})
|
338 | ```
|
339 |
|
340 | ## Step 2: Loading and running in the browser
|
341 |
|
342 | If the original model was a `SavedModel`, use
|
343 | [`tf.loadGraphModel()`](https://js.tensorflow.org/api/latest/#loadGraphModel).
|
344 | If it was Keras, use
|
345 | [`tf.loadLayersModel()`](https://js.tensorflow.org/api/latest/#loadLayersModel):
|
346 |
|
347 | ```typescript
|
348 | import * as tf from '@tensorflow/tfjs';
|
349 |
|
350 | const MODEL_URL = 'https://.../mobilenet/model.json';
|
351 |
|
352 | // For Keras use tf.loadLayersModel().
|
353 | const model = await tf.loadGraphModel(MODEL_URL);
|
354 | const cat = document.getElementById('cat');
|
355 | model.predict(tf.browser.fromPixels(cat));
|
356 | ```
|
357 |
|
358 | See our API docs for the
|
359 | [`predict()`](https://js.tensorflow.org/api/latest/#tf.GraphModel.predict)
|
360 | method. To see what other methods exist on a `Model`, see
|
361 | [`tf.LayersModel`](https://js.tensorflow.org/api/latest/#class:LayersModel)
|
362 | and [`tf.GraphModel`](https://js.tensorflow.org/api/latest/#class:GraphModel).
|
363 | Also check out our working [MobileNet demo](./demo/mobilenet/README.md).
|
364 |
|
365 | If your server requests credentials for accessing the model files, you can
|
provide the optional `RequestInit` param.
|
367 |
|
368 | ```typescript
|
const model = await tf.loadGraphModel(MODEL_URL,
    {credentials: 'include'});
|
371 | ```
|
372 |
|
373 | Please see
|
374 | [fetch() documentation](https://developer.mozilla.org/en-US/docs/Web/API/WindowOrWorkerGlobalScope/fetch)
|
375 | for details.
|
376 |
|
377 | ### Native File System
|
378 |
|
379 | TensorFlow.js can be used from Node.js. See the
|
380 | [tfjs-node project](https://github.com/tensorflow/tfjs-node) for more details.
|
381 | Unlike web browsers, Node.js can access the local file system directly.
|
Therefore, you can load the same converted model from the local file system into
|
383 | a Node.js program running TensorFlow.js. This is done by calling
|
384 | `loadGraphModel` with the path to the model files:
|
385 |
|
386 | ```js
|
387 | // Load the tfjs-node binding
|
388 | import * as tf from '@tensorflow/tfjs-node';
|
389 |
|
390 | const MODEL_PATH = 'file:///tmp/mobilenet/model.json';
|
391 | const model = await tf.loadGraphModel(MODEL_PATH);
|
392 | ```
|
393 |
|
You can also load remote model files the same way as in the browser, but you
might need to polyfill the `fetch()` method.
|
397 |
|
398 | ## Supported operations
|
399 |
|
400 | Currently TensorFlow.js only supports a limited set of TensorFlow Ops. See the
|
401 | [full list](./docs/supported_ops.md).
|
402 | If your model uses unsupported ops, the `tensorflowjs_converter` script will
|
403 | fail and produce a list of the unsupported ops in your model. Please file issues
|
to let us know what ops you need support for.
|
405 |
|
406 | ## Manual forward pass and direct weights loading
|
407 |
|
408 | If you want to manually write the forward pass with the ops API, you can load
|
409 | the weights directly as a map from weight names to tensors:
|
410 |
|
411 | ```js
|
412 | import * as tf from '@tensorflow/tfjs';
|
413 |
|
414 | const modelUrl = "https://example.org/model/model.json";
|
415 |
|
416 | const response = await fetch(modelUrl);
|
const weightManifest = (await response.json())['weightsManifest'];
const weightMap = await tf.io.loadWeights(
    weightManifest, "https://example.org/model");
|
420 | ```
|
421 |
|
422 | `weightMap` maps a weight name to a tensor. You can use it to manually implement
|
423 | the forward pass of the model:
|
424 |
|
425 | ```js
|
426 | const input = tf.tensor(...);
|
427 | tf.matMul(weightMap['fc1/weights'], input).add(weightMap['fc1/bias']);
|
428 | ```
|
429 |
|
430 | ## FAQ
|
431 |
|
432 | __1. What TensorFlow models does the converter currently support?__
|
433 |
|
Image-based models (e.g., MobileNet and SqueezeNet) are the most
well-supported. Models with control flow ops (e.g. RNNs) are also supported.
|
436 | The tensorflowjs_converter script will validate the model you have and show a
|
437 | list of unsupported ops in your model. See [this list](./docs/supported_ops.md)
|
438 | for which ops are currently supported.
|
439 |
|
__2. Will models with large weights work?__
|
441 |
|
442 | While the browser supports loading 100-500MB models, the page load time,
|
443 | the inference time and the user experience would not be great. We recommend
|
444 | using models that are designed for edge devices (e.g. phones). These models are
|
445 | usually smaller than 30MB.
|
446 |
|
447 | __3. Will the model and weight files be cached in the browser?__
|
448 |
|
Yes, we are splitting the weights into 4 MB chunks, which enables the
browser to cache them automatically. If the model architecture is less than
4 MB (most models are), it will also be cached.
|
452 |
|
453 | __4. Can I quantize the weights over the wire?__
|
454 |
|
Yes, you can use the `--quantize_{float16, uint8, uint16}` flags to compress
weights with 1-byte integer quantization (`uint8`) or 2-byte integer
(`uint16`)/float (`float16`) quantization.
|
458 | Quantizing to float16 may provide better accuracy over
|
459 | 2 byte affine integer scaling (`uint16`). 1-byte affine quantization,
|
460 | i.e., `uint8` provides a 4x size reduction at the cost of accuracy.
|
461 | For example, we can quantize our MobileNet model using float16 quantization:
|
462 |
|
```bash
tensorflowjs_converter \
    --quantize_float16 \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
```
|
470 |
|
You can also quantize specific weights as well as weight groupings using
a wildcard replacement. For example:
```bash
tensorflowjs_converter \
    --quantize_float16="conv/*/weights" \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
```
This quantizes all weights that match the pattern `conv/*/weights`, excluding
biases and any weights that don't begin with `conv/`.
|
479 | This can be a powerful tool to reduce model size while trying to maximize
|
480 | performance.
|
481 |
|
482 | __5. Why is the predict() method for inference so much slower on the first call than the subsequent calls?__
|
483 |
|
The time of the first call also includes the compilation time of the WebGL
shader programs for the model. After the first call the shader programs are
cached, which makes subsequent calls much faster. You can warm up the cache by
calling the predict method with all-zero inputs right after the model has
loaded.
|
489 |
|
490 | __6. I have a model converted with a previous version of TensorFlow.js converter (0.15.x), that is in .pb format. How do I convert it to the new JSON format?__
|
491 |
|
492 | You can use the built-in migration tool to convert the models generated by
|
493 | previous versions. Here are the steps:
|
494 |
|
495 | ```bash
|
496 | git clone git@github.com:tensorflow/tfjs-converter.git
|
497 | cd tfjs-converter
|
498 | yarn
|
499 | yarn ts-node tools/pb2json_converter.ts pb_model_directory/ json_model_directory/
|
500 | ```
|
501 |
|
`pb_model_directory` is the directory where the model generated by the previous
version is located.
`json_model_directory` is the destination directory for the converted model.
|
505 |
|
506 |
|
507 | ## Development
|
508 |
|
509 | To build **TensorFlow.js converter** from source, we need to prepare the dev environment and clone the project.
|
510 |
|
Bazel builds Python from source, so we install the dependencies required to build it. Since we will be using pip and C extensions, we also install the ssl, foreign functions, and zlib development packages. On Debian, this is done with:
|
512 |
|
513 | ```bash
|
514 | sudo apt-get build-dep python3
|
515 | sudo apt install libssl-dev libffi-dev zlib1g-dev
|
516 | ```
|
517 |
|
518 | See the [python developer guide](https://devguide.python.org/setup/#install-dependencies) for instructions on installing these for other platforms.
|
519 |
|
520 | Then, we clone the project and install dependencies with:
|
521 |
|
522 | ```bash
|
523 | git clone https://github.com/tensorflow/tfjs.git
|
524 | cd tfjs
|
525 | yarn # Installs dependencies.
|
526 | ```
|
527 |
|
528 | We recommend using [Visual Studio Code](https://code.visualstudio.com/) for
|
529 | development. Make sure to install
|
530 | [TSLint VSCode extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode.vscode-typescript-tslint-plugin)
|
531 | and the npm [clang-format](https://github.com/angular/clang-format) `1.2.2`
|
532 | or later with the
|
533 | [Clang-Format VSCode extension](https://marketplace.visualstudio.com/items?itemName=xaver.clang-format)
|
534 | for auto-formatting.
|
535 |
|
536 | Before submitting a pull request, make sure the code passes all the tests and is
|
537 | clean of lint errors:
|
538 |
|
539 | ```bash
|
540 | cd tfjs-converter
|
541 | yarn test
|
542 | yarn lint
|
543 | ```
|
544 |
|
545 | To run a subset of tests and/or on a specific browser:
|
546 |
|
547 | ```bash
|
548 | yarn test --browsers=Chrome --grep='execute'
|
549 | > ...
|
550 | > Chrome 64.0.3282 (Linux 0.0.0): Executed 39 of 39 SUCCESS (0.129 secs / 0 secs)
|
551 | ```
|
552 |
|
553 | To run the tests once and exit the karma process (helpful on Windows):
|
554 |
|
555 | ```bash
|
556 | yarn test --single-run
|
557 | ```
|
558 |
|
To run all the Python tests:
|
560 |
|
561 | ```bash
|
562 | yarn run-python-tests
|
563 | ```
|