# Getting started

**TensorFlow.js converter** is an open source library to load a pretrained
TensorFlow
[SavedModel](https://www.tensorflow.org/guide/saved_model)
or [TensorFlow Hub module](https://www.tensorflow.org/hub/)
into the browser and run inference through
[TensorFlow.js](https://js.tensorflow.org).

__Note__: _The session bundle format has been deprecated._

Importing your model is a 2-step process:

1. A Python pip package to convert a TensorFlow SavedModel or TensorFlow Hub
module to a web-friendly format. If you already have a converted model, or are
using an already hosted model (e.g. MobileNet), skip this step.
2. [JavaScript API](./src/executor/graph_model.ts), for loading and running
inference.

## Step 1: Converting a [TensorFlow SavedModel](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md), [TensorFlow Hub module](https://www.tensorflow.org/hub/), [Keras HDF5](https://keras.io/getting_started/faq/#what-are-my-options-for-saving-models), [tf.keras SavedModel](https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model), or [Flax/JAX model](http://github.com/google/flax) to a web-friendly format

__0. Please make sure that you run in a Docker container or a virtual environment.__

The script pulls its own subset of TensorFlow, which might conflict with the
existing TensorFlow/Keras installation.

__Note__: *Check that [`tf-nightly-cpu-2.0-preview`](https://pypi.org/project/tf-nightly-cpu-2.0-preview/#files) is available for your platform.*

Most of the time, this means that you have to use Python 3.6.8 in your local
environment. To force Python 3.6.8 in your local project, you can install
[`pyenv`](https://github.com/pyenv/pyenv) and proceed as follows in the target
directory:

```bash
pyenv install 3.6.8
pyenv local 3.6.8
```

Now, you can
[create and activate](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)
a `venv` virtual environment in your current folder:

```bash
virtualenv --no-site-packages venv
. venv/bin/activate
```

__1. Install the TensorFlow.js pip package:__

Install the library with the interactive CLI:

```bash
pip install tensorflowjs[wizard]
```

__2. Run the converter script provided by the pip package:__

There are three ways to trigger the model conversion, explained below:

- The conversion wizard: `tensorflowjs_wizard` ([go to section](https://github.com/tensorflow/tfjs/blob/master/tfjs-converter/README.md#conversion-wizard-tensorflowjs_wizard))
- Regular conversion script: `tensorflowjs_converter` ([go to section](https://github.com/tensorflow/tfjs/tree/master/tfjs-converter#regular-conversion-script-tensorflowjs_converter))
- Calling a converter function in Python (Flax/JAX) ([go to section](https://github.com/tensorflow/tfjs/tree/master/tfjs-converter#calling-a-converter-function-in-python-flaxjax))

## Conversion wizard: `tensorflowjs_wizard`

To start the conversion wizard:

```bash
tensorflowjs_wizard
```

This tool will walk you through the conversion process and provide you with
detailed explanations for each choice you need to make. Behind the scenes it
calls the converter script (`tensorflowjs_converter`) from the pip package.
This is the recommended way to convert a single model. Here is a screen
capture of the wizard in action. ![wizard](./tensorflowjs_wizard.gif)

There is also a dry run mode for the wizard, which will not perform the actual
conversion but only generate the corresponding `tensorflowjs_converter`
command. This generated command can be used in your own shell script.

```bash
tensorflowjs_wizard --dryrun
```

To convert a batch of models or integrate the conversion process into your own
script, you should use the `tensorflowjs_converter` script, for example in a
loop like the sketch below.

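A hedged sketch of such a batch conversion (the directory layout under
`models/` and the flag choices are assumptions for illustration):

```bash
# Convert every SavedModel under models/ into web_models/<name>.
for dir in models/*/; do
  name=$(basename "$dir")
  tensorflowjs_converter \
      --input_format=tf_saved_model \
      --output_format=tfjs_graph_model \
      "$dir" \
      "web_models/$name"
done
```
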
## Regular conversion script: `tensorflowjs_converter`

The converter expects a __TensorFlow SavedModel__, __TensorFlow Hub module__,
__TensorFlow.js JSON__ format, __Keras HDF5 model__, or __tf.keras SavedModel__
for input.

* __TensorFlow SavedModel__ example:

```bash
tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_format=tfjs_graph_model \
    --signature_name=serving_default \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model
```

* __TensorFlow Frozen Model__ example:

__Note:__ Frozen model is a deprecated format and support is added for backward compatibility purposes.

```bash
tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model
```

* __TensorFlow Hub module__ example:

```bash
tensorflowjs_converter \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
```

* __Keras HDF5 model__ example:

```bash
tensorflowjs_converter \
    --input_format=keras \
    /tmp/my_keras_model.h5 \
    /tmp/my_tfjs_model
```

* __tf.keras SavedModel__ example:

```bash
tensorflowjs_converter \
    --input_format=keras_saved_model \
    /tmp/my_tf_keras_saved_model/1542211770 \
    /tmp/my_tfjs_model
```

Note that the input path used above is a subfolder named with a Unix epoch
time (1542211770), generated automatically by TensorFlow when it saves a
tf.keras model in the SavedModel format.

### Conversion Flags

|Positional Arguments | Description |
|---|---|
|`input_path` | Full path of the saved model directory or TensorFlow Hub module handle or path.|
|`output_path` | Path for all output artifacts.|


| Options | Description |
|---|---|
|`--input_format` | The format of the input model: use `tf_saved_model` for SavedModel, `tf_hub` for TensorFlow Hub module, `tfjs_layers_model` for TensorFlow.js JSON format, and `keras` for Keras HDF5. |
|`--output_format`| The desired output format. Must be `tfjs_layers_model`, `tfjs_graph_model` or `keras`. Not all pairs of input-output formats are supported. Please file a [github issue](https://github.com/tensorflow/tfjs/issues) if your desired input-output pair is not supported.|
|<nobr>`--saved_model_tags`</nobr> | Only applicable to SavedModel conversion. Tags of the MetaGraphDef to load, in comma separated format. If there are no tags defined in the saved model, set it to the empty string: `saved_model_tags=''`. Defaults to `serve`.|
|`--signature_name` | Only applicable to TensorFlow SavedModel and Hub module conversion: signature to load. Defaults to `serving_default` for SavedModel and `default` for Hub module. See https://www.tensorflow.org/hub/common_signatures/.|
|`--strip_debug_ops` | Strips out the TensorFlow debug operations `Print`, `Assert`, and `CheckNumerics`. Defaults to `True`.|
|`--quantization_bytes` | (Deprecated) How many bytes to optionally quantize/compress the weights to. Valid values are 1 and 2, which will quantize int32 and float32 to 1 or 2 bytes respectively. The default (unquantized) size is 4 bytes.|
|`--quantize_float16` | Comma separated list of node names to apply float16 quantization. You can also use the wildcard symbol (\*) to apply quantization to multiple nodes (e.g., conv/\*/weights). When the flag is provided without any nodes, the default behavior is to match all nodes. |
|`--quantize_uint8` | Comma separated list of node names to apply 1-byte affine quantization. You can also use the wildcard symbol (\*) to apply quantization to multiple nodes (e.g., conv/\*/weights). When the flag is provided without any nodes, the default behavior is to match all nodes. |
|`--quantize_uint16` | Comma separated list of node names to apply 2-byte affine quantization. You can also use the wildcard symbol (\*) to apply quantization to multiple nodes (e.g., conv/\*/weights). When the flag is provided without any nodes, the default behavior is to match all nodes. |
|`--weight_shard_size_bytes` | Shard size (in bytes) of the weight files. Only supported when `output_format` is `tfjs_layers_model` or `tfjs_graph_model`. Default size is 4 MB (4194304 bytes).|
|<nobr>`--output_node_names`</nobr>| Only applicable to Frozen Model. The names of the output nodes, separated by commas.|
|<nobr>`--control_flow_v2`</nobr>| Only applicable to TF 2.x SavedModel. This flag improves performance on models with control flow ops. Defaults to `False`.|
|<nobr>`--metadata`</nobr>| Comma separated list of metadata json file paths, indexed by name. Prefer absolute paths. Example: 'metadata1:/metadata1.json,metadata2:/metadata2.json'.|
|<nobr>`--use_structured_outputs_names`</nobr>| Changes the output of a graph model to match the structured_outputs format instead of list format. Defaults to `False`.|

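As an illustration of how these flags combine, here is a sketch of a
SavedModel conversion that applies float16 quantization to all weights and
shards them into 1 MB files (the model paths are placeholders):

```bash
tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_format=tfjs_graph_model \
    --quantize_float16 \
    --weight_shard_size_bytes=1048576 \
    /tmp/my_saved_model \
    /tmp/my_web_model
```
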
__Note: If you want to convert a TensorFlow session bundle, you can install an older version of the tensorflowjs pip package, e.g. `pip install tensorflowjs==0.8.6`.__

### Format Conversion Support Tables

Note: Unless stated otherwise, we can infer the value of `--output_format` from
the value of `--input_format`, so the `--output_format` flag can be omitted in
most cases.

#### Python-to-JavaScript

| `--input_format` | `--output_format` | Description |
|---|---|---|
| `keras` | `tfjs_layers_model` | Convert a keras or tf.keras HDF5 model file to TensorFlow.js Layers model format. Use [`tf.loadLayersModel()`](https://js.tensorflow.org/api/latest/#loadLayersModel) to load the model in JavaScript. The loaded model supports the full inference and training (e.g., transfer learning) features of the original keras or tf.keras model. |
| `keras` | `tfjs_graph_model` | Convert a keras or tf.keras HDF5 model file to TensorFlow.js Graph model format. Use [`tf.loadGraphModel()`](https://js.tensorflow.org/api/latest/#loadGraphModel) to load the converted model in JavaScript. The loaded model supports only inference, but the speed of inference is generally faster than that of a tfjs_layers_model (see above row) thanks to the graph optimization performed by TensorFlow. Another limitation of this conversion route is that it does not support some layer types (e.g., recurrent layers such as LSTM) yet. |
| `keras_saved_model` | `tfjs_layers_model` | Convert a tf.keras SavedModel model file (from [`tf.contrib.saved_model.save_keras_model`](https://www.tensorflow.org/api_docs/python/tf/contrib/saved_model/save_keras_model)) to TensorFlow.js Layers model format. Use [`tf.loadLayersModel()`](https://js.tensorflow.org/api/latest/#loadLayersModel) to load the model in JavaScript. |
| `tf_hub` | `tfjs_graph_model` | Convert a [TF-Hub](https://www.tensorflow.org/hub) model file to TensorFlow.js graph model format. Use [`tf.loadGraphModel()`](https://js.tensorflow.org/api/latest/#loadGraphModel) to load the converted model in JavaScript. |
| `tf_saved_model` | `tfjs_graph_model` | Convert a [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model#build_and_load_a_savedmodel) to TensorFlow.js graph model format. Use [`tf.loadGraphModel()`](https://js.tensorflow.org/api/latest/#loadGraphModel) to load the converted model in JavaScript. |
| `tf_frozen_model` | `tfjs_graph_model` | Convert a [Frozen Model](https://medium.com/@sebastingarcaacosta/how-to-export-a-tensorflow-2-x-keras-model-to-a-frozen-and-optimized-graph-39740846d9eb) to TensorFlow.js graph model format. Use [`tf.loadGraphModel()`](https://js.tensorflow.org/api/latest/#loadGraphModel) to load the converted model in JavaScript. |

#### JavaScript-to-Python

| `--input_format` | `--output_format` | Description |
|---|---|---|
| `tfjs_layers_model` | `keras` | Convert a TensorFlow.js Layers model (JSON + binary weight file(s)) to a Keras HDF5 model file. Use [`keras.model.load_model()`](https://keras.io/getting-started/faq/#savingloading-whole-models-architecture-weights-optimizer-state) or [`tf.keras.models.load_model()`](https://www.tensorflow.org/api_docs/python/tf/keras/models/load_model) to load the converted model in Python. |
| `tfjs_layers_model` | `keras_saved_model` | Convert a TensorFlow.js Layers model (JSON + binary weight file(s)) to the tf.keras SavedModel format. This format is useful for subsequent uses such as [TensorFlow Serving](https://www.tensorflow.org/tfx/serving/serving_basic) and [conversion to TFLite](https://www.tensorflow.org/lite/convert). |

#### JavaScript-to-JavaScript

##### Converting tfjs_layers_model to tfjs_layers_model with weight sharding and quantization

The tfjs_layers_model-to-tfjs_layers_model conversion option serves the following
purposes:

1. It allows you to shard the binary weight file into multiple small shards
   to facilitate browser caching. This step is necessary for models with
   large-sized weights saved from TensorFlow.js (either browser or Node.js),
   because TensorFlow.js puts all weights in a single weight file
   ('group1-shard1of1.bin'). To shard the weight file, do:

   ```sh
   tensorflowjs_converter \
       --input_format tfjs_layers_model \
       --output_format tfjs_layers_model \
       original_model/model.json \
       sharded_model/
   ```

   The command above creates shards of size 4 MB (4194304 bytes) by default.
   Alternative shard sizes can be specified using the
   `--weight_shard_size_bytes` flag.

2. It allows you to reduce the on-the-wire size of the weights through
   16- or 8-bit quantization. For example:

   ```sh
   tensorflowjs_converter \
       --input_format tfjs_layers_model \
       --output_format tfjs_layers_model \
       --quantize_uint16 \
       original_model/model.json \
       quantized_model/
   ```

##### Converting tfjs_layers_model to tfjs_graph_model

Converting a `tfjs_layers_model` to a `tfjs_graph_model` usually leads to
faster inference speed in the browser and Node.js, thanks to the graph
optimization that goes into generating the tfjs_graph_models. For more details,
see the following document on TensorFlow's Grappler:
["TensorFlow Graph Optimizations" by R. Larsen and T. Shpeisman](https://ai.google/research/pubs/pub48051).

There are two caveats:

1. The model that results from this conversion does not support further
   training.
2. Certain layer types (e.g., recurrent layers such as LSTM) are not supported
   yet.

See the example command below:

```sh
tensorflowjs_converter \
    --input_format tfjs_layers_model \
    --output_format tfjs_graph_model \
    my_layers_model/model.json \
    my_graph_model/
```

The tfjs_layers_model-to-tfjs_graph_model conversion also supports weight
quantization, as in the sketch below.

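For example, a hedged sketch that adds float16 quantization of all weights to
the conversion above (same placeholder paths):

```sh
tensorflowjs_converter \
    --input_format tfjs_layers_model \
    --quantize_float16 \
    --output_format tfjs_graph_model \
    my_layers_model/model.json \
    my_graph_model/
```
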
### Web-friendly format

The conversion script above produces 2 types of files:

* `model.json` (the dataflow graph and weight manifest file)
* `group1-shard\*of\*` (collection of binary weight files)

For example, here is the MobileNet model converted and served at the
following location:

```
https://storage.cloud.google.com/tfjs-models/savedmodel/mobilenet_v1_1.0_224/model.json
https://storage.cloud.google.com/tfjs-models/savedmodel/mobilenet_v1_1.0_224/group1-shard1of5
...
https://storage.cloud.google.com/tfjs-models/savedmodel/mobilenet_v1_1.0_224/group1-shard5of5
```

## Calling a Converter Function in Python (Flax/JAX)

You can also convert your model to web format in Python by calling one of the
conversion functions. This is currently the only way to convert a Flax or JAX
model, since no standard serialization format exists to store a Module (only the
checkpoints).

Here is an example of how to convert a Flax model using the conversion
function `jax_conversion.convert_jax()`:

```py
import numpy as np
from flax import linen as nn
from jax import random
import jax.numpy as jnp
from tensorflowjs.converters import jax_conversion

tfjs_model_dir = '/tmp/tfjs_model'  # Output directory for the converted model.

module = nn.Dense(features=4)
inputs = jnp.ones((3, 4))
params = module.init(random.PRNGKey(0), inputs)['params']

jax_conversion.convert_jax(
    apply_fn=module.apply,
    params=params,
    input_signatures=[((3, 4), np.float32)],
    model_dir=tfjs_model_dir)
```

Note that when using dynamic shapes, an additional argument `polymorphic_shapes`
should be provided specifying values for the dynamic ("polymorphic")
dimensions. So in order to convert the same model as before, but now with a
dynamic first dimension, one should call `convert_jax` as follows:

```py
jax_conversion.convert_jax(
    apply_fn=module.apply,
    params=params,
    input_signatures=[((None, 4), np.float32)],
    polymorphic_shapes=["(b, 4)"],
    model_dir=tfjs_model_dir)
```

See
[here](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#shape-polymorphic-conversion)
for more details on the exact syntax for this argument.

When converting JAX models, you can also pass any [options that
`convert_tf_saved_model`
uses](https://github.com/tensorflow/tfjs/blob/master/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py#L951-L974).
For example, to quantize a model's weights, pass the `quantization_dtype_map`
option listing the weights that should be quantized.

```py
jax_conversion.convert_jax(
    apply_fn=module.apply,
    params=params,
    input_signatures=[((3, 4), np.float32)],
    model_dir=tfjs_model_dir,
    quantization_dtype_map={'float16': '*'})
```

## Step 2: Loading and running in the browser

If the original model was a `SavedModel`, use
[`tf.loadGraphModel()`](https://js.tensorflow.org/api/latest/#loadGraphModel).
If it was Keras, use
[`tf.loadLayersModel()`](https://js.tensorflow.org/api/latest/#loadLayersModel):

```typescript
import * as tf from '@tensorflow/tfjs';

const MODEL_URL = 'https://.../mobilenet/model.json';

// For Keras use tf.loadLayersModel().
const model = await tf.loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.predict(tf.browser.fromPixels(cat));
```

See our API docs for the
[`predict()`](https://js.tensorflow.org/api/latest/#tf.GraphModel.predict)
method. To see what other methods exist on a `Model`, see
[`tf.LayersModel`](https://js.tensorflow.org/api/latest/#class:LayersModel)
and [`tf.GraphModel`](https://js.tensorflow.org/api/latest/#class:GraphModel).
Also check out our working [MobileNet demo](./demo/mobilenet/README.md).

If your server requests credentials for accessing the model files, you can
provide the optional request options parameter:

```typescript
const model = await tf.loadGraphModel(MODEL_URL,
    {credentials: 'include'});
```

Please see the
[fetch() documentation](https://developer.mozilla.org/en-US/docs/Web/API/WindowOrWorkerGlobalScope/fetch)
for details.

### Native File System

TensorFlow.js can be used from Node.js. See the
[tfjs-node project](https://github.com/tensorflow/tfjs-node) for more details.
Unlike web browsers, Node.js can access the local file system directly.
Therefore, you can load the converted model from the local file system into
a Node.js program running TensorFlow.js. This is done by calling
`loadGraphModel` with the path to the model files:

```js
// Load the tfjs-node binding
import * as tf from '@tensorflow/tfjs-node';

const MODEL_PATH = 'file:///tmp/mobilenet/model.json';
const model = await tf.loadGraphModel(MODEL_PATH);
```

You can also load remote model files the same way as in the browser, but you
might need to polyfill the `fetch()` method, as in the sketch below.

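A minimal sketch of one common approach, assuming the `node-fetch` package is
installed and that your Node.js version does not already provide a global
`fetch` (the model URL is a placeholder):

```js
import * as tf from '@tensorflow/tfjs-node';
import fetch from 'node-fetch';

// Provide a global fetch() for remote model loading if one is missing.
if (typeof globalThis.fetch === 'undefined') {
  globalThis.fetch = fetch;
}

const model = await tf.loadGraphModel('https://example.org/model/model.json');
```
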
## Supported operations

Currently TensorFlow.js only supports a limited set of TensorFlow ops. See the
[full list](./docs/supported_ops.md).
If your model uses unsupported ops, the `tensorflowjs_converter` script will
fail and produce a list of the unsupported ops in your model. Please file issues
to let us know what ops you need support for.

## Manual forward pass and direct weights loading

If you want to manually write the forward pass with the ops API, you can load
the weights directly as a map from weight names to tensors:

```js
import * as tf from '@tensorflow/tfjs';

const modelUrl = 'https://example.org/model/model.json';

const response = await fetch(modelUrl);
const weightManifest = (await response.json())['weightsManifest'];
const weightMap = await tf.io.loadWeights(
    weightManifest, 'https://example.org/model');
```

`weightMap` maps a weight name to a tensor. You can use it to manually implement
the forward pass of the model:

```js
const input = tf.tensor(...);
tf.matMul(weightMap['fc1/weights'], input).add(weightMap['fc1/bias']);
```

## FAQ

__1. What TensorFlow models does the converter currently support?__

Image-based models (e.g., MobileNet and SqueezeNet) are the best supported.
Models with control flow ops (e.g. RNNs) are also supported.
The tensorflowjs_converter script will validate the model you have and show a
list of unsupported ops in your model. See [this list](./docs/supported_ops.md)
for which ops are currently supported.

__2. Will a model with large weights work?__

While the browser supports loading 100-500MB models, the page load time,
the inference time and the user experience would not be great. We recommend
using models that are designed for edge devices (e.g. phones). These models are
usually smaller than 30MB.

__3. Will the model and weight files be cached in the browser?__

Yes, we split the weights into files of 4MB chunks, which enables the
browser to cache them automatically. If the model architecture is less than 4MB
(most models are), it will also be cached.

__4. Can I quantize the weights over the wire?__

Yes, you can use the `--quantize_{float16, uint8, uint16}` flags to compress
weights with 1-byte integer quantization (`uint8`) or 2-byte integer
(`uint16`)/float (`float16`) quantization.
Quantizing to float16 may provide better accuracy than
2-byte affine integer scaling (`uint16`). 1-byte affine quantization,
i.e. `uint8`, provides a 4x size reduction at the cost of accuracy.
For example, we can quantize our MobileNet model using float16 quantization:

```bash
tensorflowjs_converter \
    --quantize_float16 \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
```

You can also quantize specific weights as well as weight groupings using
a wildcard replacement. For example:

```bash
tensorflowjs_converter \
    --quantize_float16="conv/*/weights"
```

This will quantize all weights that match the pattern `conv/*/weights`,
excluding biases and any weights that don't begin with `conv/`. This can be a
powerful tool to reduce model size while trying to maximize performance.

__5. Why is the predict() method for inference so much slower on the first call than the subsequent calls?__

The first call also includes the compilation time of the WebGL shader programs
for the model. After the first call the shader programs are cached, which makes
the subsequent calls much faster. You can warm up the cache by calling the
predict method with all-zero inputs right after the model has loaded, as in the
sketch below.

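A minimal warm-up sketch for a single-output image model; the input shape
`[1, 224, 224, 3]` is an assumption and should match your model's actual
input shape:

```js
// Run one inference on zeros to trigger shader compilation.
const warmupResult = model.predict(tf.zeros([1, 224, 224, 3]));
warmupResult.dataSync();  // Block until the GPU work completes.
warmupResult.dispose();   // Free the warm-up tensor.
```
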
__6. I have a model converted with a previous version of the TensorFlow.js converter (0.15.x) that is in .pb format. How do I convert it to the new JSON format?__

You can use the built-in migration tool to convert models generated by
previous versions. Here are the steps:

```bash
git clone git@github.com:tensorflow/tfjs-converter.git
cd tfjs-converter
yarn
yarn ts-node tools/pb2json_converter.ts pb_model_directory/ json_model_directory/
```

`pb_model_directory` is the directory where the model generated by a previous
version is located.
`json_model_directory` is the destination directory for the converted model.


## Development

To build **TensorFlow.js converter** from source, we need to prepare the dev
environment and clone the project.

Bazel builds Python from source, so we install the dependencies required to
build it. Since we will be using pip and C extensions, we also install the ssl,
foreign functions, and zlib development packages. On Debian, this is done with:

```bash
sudo apt-get build-dep python3
sudo apt install libssl-dev libffi-dev zlib1g-dev
```

See the [python developer guide](https://devguide.python.org/setup/#install-dependencies) for instructions on installing these for other platforms.

Then, we clone the project and install dependencies with:

```bash
git clone https://github.com/tensorflow/tfjs.git
cd tfjs
yarn # Installs dependencies.
```

We recommend using [Visual Studio Code](https://code.visualstudio.com/) for
development. Make sure to install the
[TSLint VSCode extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode.vscode-typescript-tslint-plugin)
and the npm [clang-format](https://github.com/angular/clang-format) `1.2.2`
or later with the
[Clang-Format VSCode extension](https://marketplace.visualstudio.com/items?itemName=xaver.clang-format)
for auto-formatting.

Before submitting a pull request, make sure the code passes all the tests and is
clean of lint errors:

```bash
cd tfjs-converter
yarn test
yarn lint
```

To run a subset of tests and/or on a specific browser:

```bash
yarn test --browsers=Chrome --grep='execute'
> ...
> Chrome 64.0.3282 (Linux 0.0.0): Executed 39 of 39 SUCCESS (0.129 secs / 0 secs)
```

To run the tests once and exit the karma process (helpful on Windows):

```bash
yarn test --single-run
```

To run all the Python tests:

```bash
yarn run-python-tests
```