1 | /**
|
2 | * @license
|
3 | * Copyright 2018 Google LLC
|
4 | *
|
5 | * Use of this source code is governed by an MIT-style
|
6 | * license that can be found in the LICENSE file or at
|
7 | * https://opensource.org/licenses/MIT.
|
8 | * =============================================================================
|
9 | */
|
10 | /// <amd-module name="@tensorflow/tfjs-layers/dist/models" />
|
11 | import { io, Optimizer, Scalar, serialization, Tensor } from '@tensorflow/tfjs-core';
|
12 | import { History } from './base_callbacks';
|
13 | import { Dataset } from './engine/dataset_stub';
|
14 | import { Layer } from './engine/topology';
|
15 | import { LayersModel, ModelCompileArgs, ModelEvaluateArgs } from './engine/training';
|
16 | import { ModelEvaluateDatasetArgs, ModelFitDatasetArgs } from './engine/training_dataset';
|
17 | import { ModelFitArgs } from './engine/training_tensors';
|
18 | import { Shape } from './keras_format/common';
|
19 | import { PyJsonDict } from './keras_format/types';
|
20 | import { Kwargs } from './types';
|
/**
 * Parses a JSON model configuration file and returns a model instance.
 *
 * ```js
 * // This example shows how to serialize a model using `toJSON()` and
 * // deserialize it as another model using `tf.models.modelFromJSON()`.
 * // Note: this example serializes and deserializes only the topology
 * // of the model; the weights of the loaded model will be different
 * // from those of the original model, due to random weight
 * // initialization.
 * // To load the topology and weights of a model, use `tf.loadLayersModel()`.
 * const model1 = tf.sequential();
 * model1.add(tf.layers.repeatVector({inputShape: [2], n: 4}));
 * // Serialize `model1` as a JSON object.
 * const model1JSON = model1.toJSON(null, false);
 * model1.summary();
 *
 * const model2 = await tf.models.modelFromJSON(model1JSON);
 * model2.summary();
 * ```
 *
 * @param modelAndWeightsConfig A `ModelAndWeightsConfig` object encoding a
 *     model topology and, optionally, a weights manifest. It can also be only
 *     the topology JSON object of the model (a `PyJsonDict`), in which case
 *     the weights will not be loaded.
 * @param customObjects Optional dictionary mapping names (strings) to custom
 *     classes or functions to be considered during deserialization.
 * @returns A `Promise` of a TensorFlow.js Layers `tf.LayersModel` instance
 *     (uncompiled).
 */
export declare function modelFromJSON(
    modelAndWeightsConfig: ModelAndWeightsConfig | PyJsonDict,
    customObjects?: serialization.ConfigDict): Promise<LayersModel>;
|
/**
 * Options for loading a saved model in TensorFlow.js format.
 */
export interface ModelAndWeightsConfig {
  /**
   * A JSON object (typed as `PyJsonDict`) containing the model config.
   *
   * This can be either of the following two formats:
   *   - A model architecture-only config, i.e., a format consistent with the
   *     return value of `keras.Model.to_json()`.
   *   - A full model config, containing not only model architecture, but also
   *     training options and state, i.e., a format consistent with the return
   *     value of `keras.models.save_model()`.
   */
  modelTopology: PyJsonDict;

  /**
   * A weights manifest in TensorFlow.js format.
   */
  weightsManifest?: io.WeightsManifestConfig;

  /**
   * Path to prepend to the paths in `weightsManifest` before fetching.
   *
   * The path may optionally end in a slash ('/').
   */
  pathPrefix?: string;
}
|
/**
 * Optional arguments accepted by `predict()`.
 */
export interface ModelPredictArgs {
  /**
   * Verbosity mode (optional). Defaults to `false`.
   */
  verbose?: boolean;

  /**
   * Batch size, as an integer (optional). Defaults to 32 when unspecified.
   */
  batchSize?: number;
}
|
/**
 * Load a model composed of Layer objects, including its topology and optionally
 * weights. See the Tutorial named "How to import a Keras Model" for usage
 * examples.
 *
 * This method is applicable to:
 *
 * 1. Models created with the `tf.layers.*`, `tf.sequential`, and
 *    `tf.model` APIs of TensorFlow.js and later saved with the
 *    `tf.LayersModel.save` method.
 * 2. Models converted from Keras or TensorFlow tf.keras using the
 *    [tensorflowjs_converter](https://github.com/tensorflow/tfjs/tree/master/tfjs-converter).
 *
 * This method is *not* applicable to TensorFlow `SavedModel`s or their
 * converted forms. For those models, use `tf.loadGraphModel`.
 *
 * Example 1. Load a model from an HTTP server.
 *
 * ```js
 * const model = await tf.loadLayersModel(
 *     'https://storage.googleapis.com/tfjs-models/tfjs/iris_v1/model.json');
 * model.summary();
 * ```
 *
 * Example 2. Save `model`'s topology and weights to browser [local
 * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
 * then load it back.
 *
 * ```js
 * const model = tf.sequential(
 *     {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
 * console.log('Prediction from original model:');
 * model.predict(tf.ones([1, 3])).print();
 *
 * const saveResults = await model.save('localstorage://my-model-1');
 *
 * const loadedModel = await tf.loadLayersModel('localstorage://my-model-1');
 * console.log('Prediction from loaded model:');
 * loadedModel.predict(tf.ones([1, 3])).print();
 * ```
 *
 * Example 3. Save `model`'s topology and weights to browser
 * [IndexedDB](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API);
 * then load it back.
 *
 * ```js
 * const model = tf.sequential(
 *     {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
 * console.log('Prediction from original model:');
 * model.predict(tf.ones([1, 3])).print();
 *
 * const saveResults = await model.save('indexeddb://my-model-1');
 *
 * const loadedModel = await tf.loadLayersModel('indexeddb://my-model-1');
 * console.log('Prediction from loaded model:');
 * loadedModel.predict(tf.ones([1, 3])).print();
 * ```
 *
 * Example 4. Load a model from user-selected files from HTML
 * [file input
 * elements](https://developer.mozilla.org/en-US/docs/Web/HTML/Element/input/file).
 *
 * ```js
 * // Note: this code snippet will not work without the HTML elements in the
 * //   page
 * const jsonUpload = document.getElementById('json-upload');
 * const weightsUpload = document.getElementById('weights-upload');
 *
 * const model = await tf.loadLayersModel(
 *     tf.io.browserFiles([jsonUpload.files[0], weightsUpload.files[0]]));
 * ```
 *
 * @param pathOrIOHandler Can be either of the two formats
 *   1. A string path to the `ModelAndWeightsConfig` JSON describing
 *      the model in the canonical TensorFlow.js format. For file://
 *      (tfjs-node-only), http:// and https:// schemes, the path can be
 *      either absolute or relative. The content of the JSON file is assumed
 *      to be a JSON object with the following fields and values:
 *      - 'modelTopology': A JSON object that can be either of:
 *        1. a model architecture JSON consistent with the format of the
 *           return value of `keras.Model.to_json()`
 *        2. a full model JSON in the format of `keras.models.save_model()`.
 *      - 'weightsManifest': A TensorFlow.js weights manifest.
 *      See the Python converter function `save_model()` for more details.
 *      It is also assumed that model weights can be accessed from relative
 *      paths described by the `paths` fields in weights manifest.
 *   2. A `tf.io.IOHandler` object that loads model artifacts with its `load`
 *      method.
 * @param options Optional configuration arguments for the model loading,
 *   including:
 *   - `strict`: Require that the provided weights exactly match those
 *     required by the layers. Default true. Passing false means that both
 *     extra weights and missing weights will be silently ignored.
 *   - `onProgress`: A progress callback of the form:
 *     `(fraction: number) => void`. This callback can be used to monitor the
 *     model-loading process.
 * @returns A `Promise` of `tf.LayersModel`, with the topology and weights
 *     loaded.
 *
 * @doc {heading: 'Models', subheading: 'Loading'}
 */
export declare function loadLayersModel(
    pathOrIOHandler: string | io.IOHandler,
    options?: io.LoadOptions): Promise<LayersModel>;
|
/**
 * Load a model and optionally its weights, using an IOHandler object.
 *
 * @param handler The instance of `IOHandler` to be used during the model
 *     loading.
 * @param customObjects Any optional custom objects to be used during model
 *     loading.
 * @param options Optional loading configuration (`io.LoadOptions`). Its
 *     `strict` field controls whether the weight loading will be done in
 *     strict mode. Default: `true`.
 * @returns A `Promise` of the loaded `tf.LayersModel`.
 */
export declare function loadLayersModelFromIOHandler(
    handler: io.IOHandler, customObjects?: serialization.ConfigDict,
    options?: io.LoadOptions): Promise<LayersModel>;
|
/**
 * Constructor arguments for a `Sequential` model. Every field is optional.
 */
export interface SequentialArgs {
  /** Optional name to give the model. */
  name?: string;

  /** Layers to stack, one after another, to form the model. */
  layers?: Layer[];
}
|
/**
 * A model with a stack of layers, feeding linearly from one to the next.
 *
 * `tf.sequential` is a factory function that creates an instance of
 * `tf.Sequential`.
 *
 * ```js
 * // Define a model for linear regression.
 * const model = tf.sequential();
 * model.add(tf.layers.dense({units: 1, inputShape: [1]}));
 *
 * // Prepare the model for training: Specify the loss and the optimizer.
 * model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
 *
 * // Generate some synthetic data for training.
 * const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
 * const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);
 *
 * // Train the model using the data then do inference on a data point the
 * // model hasn't seen:
 * await model.fit(xs, ys);
 * model.predict(tf.tensor2d([5], [1, 1])).print();
 * ```
 *
 * @doc {heading: 'Models', subheading: 'Classes'}
 */
export declare class Sequential extends LayersModel {
  /** @nocollapse */
  static className: string;
  private model;
  constructor(args?: SequentialArgs);
  private checkShape;

  /**
   * Adds a layer instance on top of the layer stack.
   *
   * ```js
   * const model = tf.sequential();
   * model.add(tf.layers.dense({units: 8, inputShape: [1]}));
   * model.add(tf.layers.dense({units: 4, activation: 'relu6'}));
   * model.add(tf.layers.dense({units: 1, activation: 'relu6'}));
   * // Note that the untrained model is random at this point.
   * model.predict(tf.randomNormal([10, 1])).print();
   * ```
   * @param layer Layer instance.
   *
   * @exception ValueError In case the `layer` argument does not know its
   *     input shape.
   * @exception ValueError In case the `layer` argument has multiple output
   *     tensors, or is already connected somewhere else (forbidden in
   *     `Sequential` models).
   *
   * @doc {heading: 'Models', subheading: 'Classes'}
   */
  add(layer: Layer): void;

  /**
   * Removes the last layer in the model.
   *
   * @exception TypeError if there are no layers in the model.
   */
  pop(): void;

  call(inputs: Tensor | Tensor[], kwargs: Kwargs): Tensor | Tensor[];

  build(inputShape?: Shape | Shape[]): void;

  countParams(): number;

  /**
   * Print a text summary of the Sequential model's layers.
   *
   * The summary includes
   * - Name and type of all layers that comprise the model.
   * - Output shape(s) of the layers
   * - Number of weight parameters of each layer
   * - The total number of trainable and non-trainable parameters of the
   *   model.
   *
   * ```js
   * const model = tf.sequential();
   * model.add(
   *     tf.layers.dense({units: 100, inputShape: [10], activation: 'relu'}));
   * model.add(tf.layers.dense({units: 1, activation: 'sigmoid'}));
   *
   * model.summary();
   * ```
   *
   * @param lineLength Custom line length, in number of characters.
   * @param positions Custom widths of each of the columns, as either
   *     fractions of `lineLength` (e.g., `[0.5, 0.75, 1]`) or absolute number
   *     of characters (e.g., `[30, 50, 65]`). Each number corresponds to
   *     right-most (i.e., ending) position of a column.
   * @param printFn Custom print function. Can be used to replace the default
   *     `console.log`. For example, you can use `x => {}` to mute the printed
   *     messages in the console.
   *
   * @doc {heading: 'Models', subheading: 'Classes'}
   */
  summary(
      lineLength?: number, positions?: number[],
      printFn?: (message?: any, ...optionalParams: any[]) => void): void;

  /**
   * Sets the weights of the model.
   *
   * @param weights Should be a list of Tensors with shapes and types matching
   *     the output of `model.getWeights()`.
   */
  setWeights(weights: Tensor[]): void;

  /**
   * Returns the loss value & metrics values for the model in test mode.
   *
   * Loss and metrics are specified during `compile()`, which needs to happen
   * before calls to `evaluate()`.
   *
   * Computation is done in batches.
   *
   * ```js
   * const model = tf.sequential({
   *   layers: [tf.layers.dense({units: 1, inputShape: [10]})]
   * });
   * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
   * const result = model.evaluate(tf.ones([8, 10]), tf.ones([8, 1]), {
   *   batchSize: 4,
   * });
   * result.print();
   * ```
   *
   * @param x `tf.Tensor` of test data, or an `Array` of `tf.Tensor`s if the
   *     model has multiple inputs.
   * @param y `tf.Tensor` of target data, or an `Array` of `tf.Tensor`s if the
   *     model has multiple outputs.
   * @param args A `ModelEvaluateArgs` object, containing optional fields.
   *
   * @return `Scalar` test loss (if the model has a single output and no
   *     metrics) or `Array` of `Scalar`s (if the model has multiple outputs
   *     and/or metrics). The attribute `model.metricsNames`
   *     will give you the display labels for the scalar outputs.
   *
   * @doc {heading: 'Models', subheading: 'Classes'}
   */
  evaluate(
      x: Tensor | Tensor[], y: Tensor | Tensor[],
      args?: ModelEvaluateArgs): Scalar | Scalar[];

  /**
   * Evaluate model using a dataset object.
   *
   * Note: Unlike `evaluate()`, this method is asynchronous (`async`).
   *
   * @param dataset A dataset object. Its `iterator()` method is expected
   *     to generate a dataset iterator object, the `next()` method of which
   *     is expected to produce data batches for evaluation. The return value
   *     of the `next()` call ought to contain a boolean `done` field and a
   *     `value` field. The `value` field is expected to be an array of two
   *     `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The
   *     former case is for models with exactly one input and one output (e.g.
   *     a sequential model). The latter case is for models with multiple
   *     inputs and/or multiple outputs. Of the two items in the array, the
   *     first is the input feature(s) and the second is the output target(s).
   * @param args A configuration object for the dataset-based evaluation.
   * @returns A `Promise` of the loss and metric values: a single `Scalar`,
   *     or an Array of `Scalar` objects.
   *
   * @doc {heading: 'Models', subheading: 'Classes'}
   */
  evaluateDataset(dataset: Dataset<{}>, args: ModelEvaluateDatasetArgs):
      Promise<Scalar | Scalar[]>;

  /**
   * Generates output predictions for the input samples.
   *
   * Computation is done in batches.
   *
   * Note: the "step" mode of predict() is currently not supported.
   *   This is because the TensorFlow.js core backend is imperative only.
   *
   * ```js
   * const model = tf.sequential({
   *   layers: [tf.layers.dense({units: 1, inputShape: [10]})]
   * });
   * model.predict(tf.ones([2, 10])).print();
   * ```
   *
   * @param x The input data, as a Tensor, or an `Array` of `tf.Tensor`s if
   *     the model has multiple inputs.
   * @param args A `ModelPredictArgs` object containing optional fields.
   *
   * @return `tf.Tensor`(s) of predictions.
   *
   * @exception ValueError In case of mismatch between the provided input data
   *     and the model's expectations, or in case a stateful model receives a
   *     number of samples that is not a multiple of the batch size.
   *
   * @doc {heading: 'Models', subheading: 'Classes'}
   */
  predict(x: Tensor | Tensor[], args?: ModelPredictArgs): Tensor | Tensor[];

  /**
   * Returns predictions for a single batch of samples.
   *
   * @param x Input samples, as a `tf.Tensor`.
   * @return Tensor(s) of predictions
   */
  predictOnBatch(x: Tensor): Tensor | Tensor[];

  /**
   * See `LayersModel.compile`.
   *
   * @param args Compilation arguments, e.g. the loss, the optimizer and any
   *     metrics to be reported by `evaluate()`.
   */
  compile(args: ModelCompileArgs): void;

  get optimizer(): Optimizer;
  set optimizer(optimizer: Optimizer);

  /**
   * Trains the model for a fixed number of epochs (iterations on a dataset).
   *
   * ```js
   * const model = tf.sequential({
   *   layers: [tf.layers.dense({units: 1, inputShape: [10]})]
   * });
   * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
   * const history = await model.fit(tf.ones([8, 10]), tf.ones([8, 1]), {
   *   batchSize: 4,
   *   epochs: 3
   * });
   * console.log(history.history.loss[0]);
   * ```
   *
   * @param x `tf.Tensor` of training data, or an array of `tf.Tensor`s if the
   *     model has multiple inputs. If all inputs in the model are named, you
   *     can also pass a dictionary mapping input names to `tf.Tensor`s.
   * @param y `tf.Tensor` of target (label) data, or an array of `tf.Tensor`s
   *     if the model has multiple outputs. If all outputs in the model are
   *     named, you can also pass a dictionary mapping output names to
   *     `tf.Tensor`s.
   * @param args A `ModelFitArgs` object, containing optional fields.
   *
   * @return A `History` instance. Its `history` attribute contains all
   *     information collected during training.
   *
   * @exception ValueError In case of mismatch between the provided input data
   *     and what the model expects.
   *
   * @doc {heading: 'Models', subheading: 'Classes'}
   */
  fit(x: Tensor | Tensor[] | {
    [inputName: string]: Tensor;
  }, y: Tensor | Tensor[] | {
    [outputName: string]: Tensor;
  }, args?: ModelFitArgs): Promise<History>;

  /**
   * Trains the model using a dataset object.
   *
   * ```js
   * const xArray = [
   *   [1, 1, 1, 1, 1, 1, 1, 1, 1],
   *   [1, 1, 1, 1, 1, 1, 1, 1, 1],
   *   [1, 1, 1, 1, 1, 1, 1, 1, 1],
   *   [1, 1, 1, 1, 1, 1, 1, 1, 1],
   * ];
   * const yArray = [1, 1, 1, 1];
   * // Create a dataset from the JavaScript array.
   * const xDataset = tf.data.array(xArray);
   * const yDataset = tf.data.array(yArray);
   * // Zip combines the `x` and `y` Datasets into a single Dataset, the
   * // iterator of which will return an object containing two tensors,
   * // corresponding to `x` and `y`. The call to `batch(4)` will bundle
   * // four such samples into a single object, with the same keys now pointing
   * // to tensors that hold 4 examples, organized along the batch dimension.
   * // The call to `shuffle(4)` causes each iteration through the dataset to
   * // happen in a different order. The size of the shuffle window is 4.
   * const xyDataset = tf.data.zip({xs: xDataset, ys: yDataset})
   *     .batch(4)
   *     .shuffle(4);
   * const model = tf.sequential({
   *   layers: [tf.layers.dense({units: 1, inputShape: [9]})]
   * });
   * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
   * const history = await model.fitDataset(xyDataset, {
   *   epochs: 4,
   *   callbacks: {onEpochEnd: (epoch, logs) => console.log(logs.loss)}
   * });
   * ```
   *
   * @param dataset A dataset object. Its `iterator()` method is expected to
   *     generate a dataset iterator object, the `next()` method of which is
   *     expected to produce data batches for training. The return value of
   *     the `next()` call ought to contain a boolean `done` field and a
   *     `value` field.
   *
   *     The `value` field is expected to be an object with fields
   *     `xs` and `ys`, which point to the feature tensor and the target
   *     tensor, respectively. This case is for models with exactly one input
   *     and one output (e.g. a sequential model). For example:
   *     ```js
   *     {value: {xs: xsTensor, ys: ysTensor}, done: false}
   *     ```
   *
   *     If the model has multiple inputs, the `xs` field of `value` should
   *     be an object mapping input names to their respective feature tensors.
   *     For example:
   *     ```js
   *     {
   *       value: {
   *         xs: {
   *           input_1: xsTensor1,
   *           input_2: xsTensor2
   *         },
   *         ys: ysTensor
   *       },
   *       done: false
   *     }
   *     ```
   *     If the model has multiple outputs, the `ys` field of `value` should
   *     be an object mapping output names to their respective target tensors.
   *     For example:
   *     ```js
   *     {
   *       value: {
   *         xs: xsTensor,
   *         ys: {
   *           output_1: ysTensor1,
   *           output_2: ysTensor2
   *         },
   *       },
   *       done: false
   *     }
   *     ```
   * @param args A `ModelFitDatasetArgs`, containing optional fields.
   *
   * @return A `History` instance. Its `history` attribute contains all
   *     information collected during training.
   *
   * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
   */
  fitDataset<T>(dataset: Dataset<T>, args: ModelFitDatasetArgs<T>):
      Promise<History>;

  /**
   * Runs a single gradient update on a single batch of data.
   *
   * This method differs from `fit()` and `fitDataset()` in the following
   * regards:
   *   - It operates on exactly one batch of data.
   *   - It returns only the loss and metric values, instead of
   *     returning the batch-by-batch loss and metric values.
   *   - It doesn't support fine-grained options such as verbosity and
   *     callbacks.
   *
   * @param x Input data. It could be one of the following:
   *   - A `tf.Tensor`, or an Array of `tf.Tensor`s (in case the model has
   *     multiple inputs).
   *   - An Object mapping input names to corresponding `tf.Tensor` (if the
   *     model has named inputs).
   * @param y Target data. It could be either a `tf.Tensor` or multiple
   *     `tf.Tensor`s. It should be consistent with `x`.
   * @returns Training loss or losses (in case the model has
   *     multiple outputs), along with metrics (if any), as numbers.
   *
   * @doc {heading: 'Models', subheading: 'Classes'}
   */
  trainOnBatch(x: Tensor | Tensor[] | {
    [inputName: string]: Tensor;
  }, y: Tensor | Tensor[] | {
    [outputName: string]: Tensor;
  }): Promise<number | number[]>;

  /** @nocollapse */
  static fromConfig<T extends serialization.Serializable>(
      cls: serialization.SerializableConstructor<T>,
      config: serialization.ConfigDict,
      customObjects?: serialization.ConfigDict,
      fastWeightInit?: boolean): T;

  /**
   * Setter used to force-stop `LayersModel.fit()` (i.e., training).
   *
   * Example:
   *
   * ```js
   * const model = tf.sequential();
   * model.add(tf.layers.dense({units: 1, inputShape: [10]}));
   * model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
   * const xs = tf.ones([8, 10]);
   * const ys = tf.zeros([8, 1]);
   *
   * const history = await model.fit(xs, ys, {
   *   epochs: 10,
   *   callbacks: {
   *     onEpochEnd: async (epoch, logs) => {
   *       if (epoch === 2) {
   *         model.stopTraining = true;
   *       }
   *     }
   *   }
   * });
   *
   * // There should be only 3 values in the loss array, instead of 10 values,
   * // due to the stopping after 3 epochs.
   * console.log(history.history.loss);
   * ```
   */
  set stopTraining(stop: boolean);
  get stopTraining(): boolean;

  getConfig(): any;
}
|