/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-layers/dist/engine/training" />
import * as tfc from '@tensorflow/tfjs-core';
import { io, ModelPredictConfig as ModelPredictArgs, NamedTensorMap, Optimizer, Scalar, Tensor } from '@tensorflow/tfjs-core';
import { History, ModelLoggingVerbosity } from '../base_callbacks';
import { Shape } from '../keras_format/common';
import { TrainingConfig } from '../keras_format/training_config';
import { LossOrMetricFn, NamedTensor } from '../types';
import { Container, ContainerArgs } from './container';
import { Dataset } from './dataset_stub';
import { DisposeResult } from './topology';
import { ModelEvaluateDatasetArgs, ModelFitDatasetArgs } from './training_dataset';
import { ModelFitArgs } from './training_tensors';
import { ClassWeight, ClassWeightMap } from './training_utils';
/**
 * Helper function for polymorphic input data: 1. singleton Tensor.
 */
export declare function isDataTensor(x: Tensor | Tensor[] | {
    [inputName: string]: Tensor;
} | {
    [inputName: string]: Tensor[];
}): boolean;
/**
 * Helper function for polymorphic input data: 2. Array of Tensor.
 */
export declare function isDataArray(x: Tensor | Tensor[] | {
    [inputName: string]: Tensor;
}): boolean;
/**
 * Helper function for polymorphic input data: 3. "dict" of Tensor.
 */
export declare function isDataDict(x: Tensor | Tensor[] | {
    [inputName: string]: Tensor;
}): boolean;
/**
 * Normalizes inputs and targets provided by users.
 * @param data User-provided input data (polymorphic).
 * @param names An Array of expected Tensor names.
 * @param shapes Optional Array of expected Tensor shapes.
 * @param checkBatchAxis Whether to check that the batch axis of the arrays
 *   match the expected value found in `shapes`.
 * @param exceptionPrefix String prefix used for exception formatting.
 * @returns List of standardized input Tensors (one Tensor per model input).
 * @throws ValueError: in case of improperly formatted user data.
 */
export declare function standardizeInputData(data: Tensor | Tensor[] | {
    [inputName: string]: Tensor;
}, names: string[], shapes?: Shape[], checkBatchAxis?: boolean, exceptionPrefix?: string): Tensor[];
/**
 * User input validation for Tensors.
 * @param inputs `Array` of `tf.Tensor`s for inputs.
 * @param targets `Array` of `tf.Tensor`s for targets.
 * @param weights Optional `Array` of `tf.Tensor`s for sample weights.
 * @throws ValueError: in case of incorrectly formatted data.
 */
export declare function checkArrayLengths(inputs: Tensor[], targets: Tensor[], weights?: Tensor[]): void;
/**
 * Maps metric functions to model outputs.
 * @param metrics A shortcut string name, a metric function, or an `Array` or
 *   dict (`Object`) of metric functions.
 * @param outputNames An `Array` of the names of model outputs.
 * @returns An `Array` (one entry per model output) of `Array` of metric
 *   functions. For instance, if the model has 2 outputs, and for the first
 *   output we want to compute `binaryAccuracy` and `binaryCrossentropy`,
 *   and just `binaryAccuracy` for the second output, the `Array` would look
 *   like:
 *   `[[binaryAccuracy, binaryCrossentropy], [binaryAccuracy]]`
 * @throws TypeError: incompatible metrics format.
 */
export declare function collectMetrics(metrics: string | LossOrMetricFn | Array<string | LossOrMetricFn> | {
    [outputName: string]: string | LossOrMetricFn;
}, outputNames: string[]): Array<Array<string | LossOrMetricFn>>;
export interface ModelEvaluateArgs {
    /**
     * Batch size (Integer). If unspecified, it will default to 32.
     */
    batchSize?: number;
    /**
     * Verbosity mode.
     */
    verbose?: ModelLoggingVerbosity;
    /**
     * Tensor of weights to weight the contribution of different samples to the
     * loss and metrics.
     */
    sampleWeight?: Tensor;
    /**
     * integer: total number of steps (batches of samples)
     * before declaring the evaluation round finished. Ignored with the default
     * value of `undefined`.
     */
    steps?: number;
}
/**
 * Configuration for calls to `LayersModel.compile()`.
 */
export interface ModelCompileArgs {
    /**
     * An instance of `tf.train.Optimizer` or a string name for an Optimizer.
     */
    optimizer: string | Optimizer;
    /**
     * Objective function(s) or name(s) of objective function(s).
     * If the model has multiple outputs, you can use a different loss
     * on each output by passing a dictionary or an Array of losses.
     * The loss value that will be minimized by the model will then be the sum
     * of all individual losses.
     */
    loss: string | string[] | {
        [outputName: string]: string;
    } | LossOrMetricFn | LossOrMetricFn[] | {
        [outputName: string]: LossOrMetricFn;
    };
    /**
     * List of metrics to be evaluated by the model during training and testing.
     * Typically you will use `metrics=['accuracy']`.
     * To specify different metrics for different outputs of a multi-output
     * model, you could also pass a dictionary.
     */
    metrics?: string | LossOrMetricFn | Array<string | LossOrMetricFn> | {
        [outputName: string]: string | LossOrMetricFn;
    };
}
/**
 * A `tf.LayersModel` is a directed, acyclic graph of `tf.Layer`s plus methods
 * for training, evaluation, prediction and saving.
 *
 * `tf.LayersModel` is the basic unit of training, inference and evaluation in
 * TensorFlow.js. To create a `tf.LayersModel`, use `tf.model`.
 *
 * See also:
 * `tf.Sequential`, `tf.loadLayersModel`.
 *
 * @doc {heading: 'Models', subheading: 'Classes'}
 */
export declare class LayersModel extends Container implements tfc.InferenceModel {
    /** @nocollapse */
    static className: string;
    protected optimizer_: Optimizer;
    protected isOptimizerOwned: boolean;
    loss: string | string[] | {
        [outputName: string]: string;
    } | LossOrMetricFn | LossOrMetricFn[] | {
        [outputName: string]: LossOrMetricFn;
    };
    lossFunctions: LossOrMetricFn[];
    private feedOutputShapes;
    private feedLossFns;
    private collectedTrainableWeights;
    private testFunction;
    history: History;
    protected stopTraining_: boolean;
    protected isTraining: boolean;
    metrics: string | LossOrMetricFn | Array<string | LossOrMetricFn> | {
        [outputName: string]: string | LossOrMetricFn;
    };
    metricsNames: string[];
    metricsTensors: Array<[LossOrMetricFn, number]>;
    private userDefinedMetadata;
    constructor(args: ContainerArgs);
    /**
     * Print a text summary of the model's layers.
     *
     * The summary includes
     * - Name and type of all layers that comprise the model.
     * - Output shape(s) of the layers
     * - Number of weight parameters of each layer
     * - If the model has non-sequential-like topology, the inputs each layer
     *   receives
     * - The total number of trainable and non-trainable parameters of the model.
     *
     * ```js
     * const input1 = tf.input({shape: [10]});
     * const input2 = tf.input({shape: [20]});
     * const dense1 = tf.layers.dense({units: 4}).apply(input1);
     * const dense2 = tf.layers.dense({units: 8}).apply(input2);
     * const concat = tf.layers.concatenate().apply([dense1, dense2]);
     * const output =
     *     tf.layers.dense({units: 3, activation: 'softmax'}).apply(concat);
     *
     * const model = tf.model({inputs: [input1, input2], outputs: output});
     * model.summary();
     * ```
     *
     * @param lineLength Custom line length, in number of characters.
     * @param positions Custom widths of each of the columns, as either
     *   fractions of `lineLength` (e.g., `[0.5, 0.75, 1]`) or absolute number
     *   of characters (e.g., `[30, 50, 65]`). Each number corresponds to
     *   right-most (i.e., ending) position of a column.
     * @param printFn Custom print function. Can be used to replace the default
     *   `console.log`. For example, you can use `x => {}` to mute the printed
     *   messages in the console.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    summary(lineLength?: number, positions?: number[], printFn?: (message?: any, ...optionalParams: any[]) => void): void;
    /**
     * Configures and prepares the model for training and evaluation. Compiling
     * outfits the model with an optimizer, loss, and/or metrics. Calling `fit`
     * or `evaluate` on an un-compiled model will throw an error.
     *
     * @param args a `ModelCompileArgs` specifying the loss, optimizer, and
     *   metrics to be used for fitting and evaluating this model.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    compile(args: ModelCompileArgs): void;
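    /*
     * Illustrative usage of `compile()`: a minimal sketch of configuring a
     * two-output functional model, assuming the full TensorFlow.js bundle
     * (`tf.input`, `tf.layers`, `tf.model`, `tf.train.adam`). The layer sizes
     * and learning rate are arbitrary placeholders.
     *
     * ```js
     * const input = tf.input({shape: [8]});
     * const shared = tf.layers.dense({units: 16, activation: 'relu'}).apply(input);
     * const out1 = tf.layers.dense({units: 1}).apply(shared);
     * const out2 = tf.layers.dense({units: 3, activation: 'softmax'}).apply(shared);
     * const model = tf.model({inputs: input, outputs: [out1, out2]});
     *
     * // One loss per output; the total loss minimized is their sum.
     * model.compile({
     *   optimizer: tf.train.adam(0.01),
     *   loss: ['meanSquaredError', 'categoricalCrossentropy'],
     *   metrics: ['mse']
     * });
     * console.log(model.metricsNames);
     * ```
     */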
    /**
     * Check trainable weights count consistency.
     *
     * This will raise a warning if `this.trainableWeights` and
     * `this.collectedTrainableWeights` are inconsistent (i.e., have different
     * numbers of parameters).
     * Inconsistency will typically arise when one modifies `model.trainable`
     * without calling `model.compile()` again.
     */
    protected checkTrainableWeightsConsistency(): void;
    /**
     * Returns the loss value & metrics values for the model in test mode.
     *
     * Loss and metrics are specified during `compile()`, which needs to happen
     * before calls to `evaluate()`.
     *
     * Computation is done in batches.
     *
     * ```js
     * const model = tf.sequential({
     *   layers: [tf.layers.dense({units: 1, inputShape: [10]})]
     * });
     * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
     * const result = model.evaluate(
     *     tf.ones([8, 10]), tf.ones([8, 1]), {batchSize: 4});
     * result.print();
     * ```
     *
     * @param x `tf.Tensor` of test data, or an `Array` of `tf.Tensor`s if the
     *   model has multiple inputs.
     * @param y `tf.Tensor` of target data, or an `Array` of `tf.Tensor`s if the
     *   model has multiple outputs.
     * @param args A `ModelEvaluateArgs`, containing optional fields.
     *
     * @return `Scalar` test loss (if the model has a single output and no
     *   metrics) or `Array` of `Scalar`s (if the model has multiple outputs
     *   and/or metrics). The attribute `model.metricsNames`
     *   will give you the display labels for the scalar outputs.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    evaluate(x: Tensor | Tensor[], y: Tensor | Tensor[], args?: ModelEvaluateArgs): Scalar | Scalar[];
    /**
     * Evaluate model using a dataset object.
     *
     * Note: Unlike `evaluate()`, this method is asynchronous (`async`).
     *
     * @param dataset A dataset object. Its `iterator()` method is expected
     *   to generate a dataset iterator object, the `next()` method of which
     *   is expected to produce data batches for evaluation. The return value
     *   of the `next()` call ought to contain a boolean `done` field and a
     *   `value` field. The `value` field is expected to be an array of two
     *   `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former
     *   case is for models with exactly one input and one output (e.g.,
     *   a sequential model). The latter case is for models with multiple
     *   inputs and/or multiple outputs. Of the two items in the array, the
     *   first is the input feature(s) and the second is the output target(s).
     * @param args A configuration object for the dataset-based evaluation.
     * @returns Loss and metric values as an Array of `Scalar` objects.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    evaluateDataset(dataset: Dataset<{}>, args?: ModelEvaluateDatasetArgs): Promise<Scalar | Scalar[]>;
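    /*
     * A usage sketch for `evaluateDataset()`, assuming the full TensorFlow.js
     * bundle so that `tf.data` is available for building the dataset. The
     * dataset elements follow the [xs, ys] contract described above; the data
     * values are arbitrary placeholders.
     *
     * ```js
     * const model = tf.sequential({
     *   layers: [tf.layers.dense({units: 1, inputShape: [1]})]
     * });
     * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
     *
     * const xs = tf.data.array([[0], [1], [2], [3]]);
     * const ys = tf.data.array([[0], [2], [4], [6]]);
     * const dataset = tf.data.zip([xs, ys]).batch(2);
     *
     * const result = await model.evaluateDataset(dataset);
     * result.print();
     * ```
     */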
    /**
     * Get number of samples provided for training, evaluation or prediction.
     *
     * @param ins Input `tf.Tensor`.
     * @param batchSize Integer batch size, optional.
     * @param steps Total number of steps (batches of samples) before
     *   declaring loop finished. Optional.
     * @param stepsName The public API's parameter name for `steps`.
     * @returns Number of samples provided.
     */
    private checkNumSamples;
    /**
     * Execute internal tensors of the model with input data feed.
     * @param inputs Input data feed. Must match the inputs of the model.
     * @param outputs Names of the output tensors to be fetched. Must match
     *   names of the SymbolicTensors that belong to the graph.
     * @returns Fetched values for `outputs`.
     */
    execute(inputs: Tensor | Tensor[] | NamedTensorMap, outputs: string | string[]): Tensor | Tensor[];
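    /*
     * A sketch of `execute()`: fetching an intermediate activation by the name
     * of its symbolic tensor. Assumes the functional API (`tf.input`,
     * `tf.layers.dense`, `tf.model`); the layer sizes are arbitrary.
     *
     * ```js
     * const input = tf.input({shape: [4]});
     * const hidden = tf.layers.dense({units: 3, activation: 'relu'}).apply(input);
     * const output = tf.layers.dense({units: 1}).apply(hidden);
     * const model = tf.model({inputs: input, outputs: output});
     *
     * // `hidden` is a SymbolicTensor; its `name` identifies it in the graph.
     * const hiddenValues = model.execute(tf.ones([2, 4]), hidden.name);
     * hiddenValues.print();
     * ```
     */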
    /**
     * Retrieve the model's internal symbolic tensors from symbolic-tensor names.
     */
    private retrieveSymbolicTensors;
    /**
     * Helper method to loop over some data in batches.
     *
     * Porting Note: Not using the functional approach in the Python equivalent
     *   due to the imperative backend.
     * Porting Note: Does not support step mode currently.
     *
     * @param ins: input data
     * @param batchSize: integer batch size.
     * @param verbose: verbosity mode.
     * @returns: Predictions as `tf.Tensor` (if a single output) or an `Array` of
     *   `tf.Tensor` (if multiple outputs).
     */
    private predictLoop;
    /**
     * Generates output predictions for the input samples.
     *
     * Computation is done in batches.
     *
     * Note: the "step" mode of predict() is currently not supported.
     * This is because the TensorFlow.js core backend is imperative only.
     *
     * ```js
     * const model = tf.sequential({
     *   layers: [tf.layers.dense({units: 1, inputShape: [10]})]
     * });
     * model.predict(tf.ones([8, 10]), {batchSize: 4}).print();
     * ```
     *
     * @param x The input data, as a Tensor, or an `Array` of `tf.Tensor`s if
     *   the model has multiple inputs.
     * @param args A `ModelPredictArgs` object containing optional fields.
     *
     * @return Prediction results as a `tf.Tensor`(s).
     *
     * @exception ValueError In case of mismatch between the provided input data
     *   and the model's expectations, or in case a stateful model receives a
     *   number of samples that is not a multiple of the batch size.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    predict(x: Tensor | Tensor[], args?: ModelPredictArgs): Tensor | Tensor[];
    /**
     * Returns predictions for a single batch of samples.
     *
     * ```js
     * const model = tf.sequential({
     *   layers: [tf.layers.dense({units: 1, inputShape: [10]})]
     * });
     * model.predictOnBatch(tf.ones([8, 10])).print();
     * ```
     * @param x: Input samples, as a Tensor (for models with exactly one
     *   input) or an array of Tensors (for models with more than one input).
     * @return Tensor(s) of predictions
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    predictOnBatch(x: Tensor | Tensor[]): Tensor | Tensor[];
    protected standardizeUserDataXY(x: Tensor | Tensor[] | {
        [inputName: string]: Tensor;
    }, y: Tensor | Tensor[] | {
        [inputName: string]: Tensor;
    }, checkBatchAxis?: boolean, batchSize?: number): [Tensor[], Tensor[]];
    protected standardizeUserData(x: Tensor | Tensor[] | {
        [inputName: string]: Tensor;
    }, y: Tensor | Tensor[] | {
        [inputName: string]: Tensor;
    }, sampleWeight?: Tensor | Tensor[] | {
        [outputName: string]: Tensor;
    }, classWeight?: ClassWeight | ClassWeight[] | ClassWeightMap, checkBatchAxis?: boolean, batchSize?: number): Promise<[Tensor[], Tensor[], Tensor[]]>;
    /**
     * Loop over some test data in batches.
     * @param f A Function returning a list of tensors.
     * @param ins Array of tensors to be fed to `f`.
     * @param batchSize Integer batch size or `null` / `undefined`.
     * @param verbose verbosity mode.
     * @param steps Total number of steps (batches of samples) before
     *   declaring test finished. Ignored with the default value of `null` /
     *   `undefined`.
     * @returns Array of Scalars.
     */
    private testLoop;
    protected getDedupedMetricsNames(): string[];
    /**
     * Creates a function that performs the following actions:
     *
     * 1. computes the losses
     * 2. sums them to get the total loss
     * 3. calls the optimizer, which computes the gradients of the LayersModel's
     *    trainable weights w.r.t. the total loss and updates the variables
     * 4. calculates the metrics
     * 5. returns the values of the losses and metrics.
     */
    protected makeTrainFunction(): (data: Tensor[]) => Scalar[];
    /**
     * Create a function which, when invoked with an array of `tf.Tensor`s as a
     * batch of inputs, returns the prespecified loss and metrics of the model
     * under the batch of input data.
     */
    private makeTestFunction;
    /**
     * Trains the model for a fixed number of epochs (iterations on a
     * dataset).
     *
     * ```js
     * const model = tf.sequential({
     *   layers: [tf.layers.dense({units: 1, inputShape: [10]})]
     * });
     * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
     * for (let i = 1; i < 5 ; ++i) {
     *   const h = await model.fit(tf.ones([8, 10]), tf.ones([8, 1]), {
     *     batchSize: 4,
     *     epochs: 3
     *   });
     *   console.log("Loss after Epoch " + i + " : " + h.history.loss[0]);
     * }
     * ```
     *
     * @param x `tf.Tensor` of training data, or an array of `tf.Tensor`s if the
     *   model has multiple inputs. If all inputs in the model are named, you
     *   can also pass a dictionary mapping input names to `tf.Tensor`s.
     * @param y `tf.Tensor` of target (label) data, or an array of `tf.Tensor`s if
     *   the model has multiple outputs. If all outputs in the model are named,
     *   you can also pass a dictionary mapping output names to `tf.Tensor`s.
     * @param args A `ModelFitArgs`, containing optional fields.
     *
     * @return A `History` instance. Its `history` attribute contains all
     *   information collected during training.
     *
     * @exception ValueError In case of mismatch between the provided input
     *   data and what the model expects.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    fit(x: Tensor | Tensor[] | {
        [inputName: string]: Tensor;
    }, y: Tensor | Tensor[] | {
        [inputName: string]: Tensor;
    }, args?: ModelFitArgs): Promise<History>;
    /**
     * Trains the model using a dataset object.
     *
     * @param dataset A dataset object. Its `iterator()` method is expected
     *   to generate a dataset iterator object, the `next()` method of which
     *   is expected to produce data batches for training. The return value
     *   of the `next()` call ought to contain a boolean `done` field and a
     *   `value` field. The `value` field is expected to be an array of two
     *   `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former
     *   case is for models with exactly one input and one output (e.g.,
     *   a sequential model). The latter case is for models with multiple
     *   inputs and/or multiple outputs.
     *   Of the two items in the array, the first is the input feature(s) and
     *   the second is the output target(s).
     * @param args A `ModelFitDatasetArgs`, containing optional fields.
     *
     * @return A `History` instance. Its `history` attribute contains all
     *   information collected during training.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    fitDataset<T>(dataset: Dataset<T>, args: ModelFitDatasetArgs<T>): Promise<History>;
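    /*
     * A usage sketch for `fitDataset()`, assuming the full TensorFlow.js bundle
     * so that `tf.data` is available. Dataset elements are [xs, ys] arrays, as
     * described above; values and hyperparameters are arbitrary placeholders.
     *
     * ```js
     * const model = tf.sequential({
     *   layers: [tf.layers.dense({units: 1, inputShape: [1]})]
     * });
     * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
     *
     * const xs = tf.data.array([[0], [1], [2], [3]]);
     * const ys = tf.data.array([[1], [3], [5], [7]]);
     * const dataset = tf.data.zip([xs, ys]).batch(2);
     *
     * const history = await model.fitDataset(dataset, {epochs: 4});
     * console.log(history.history.loss);
     * ```
     */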
    /**
     * Runs a single gradient update on a single batch of data.
     *
     * This method differs from `fit()` and `fitDataset()` in the following
     * regards:
     * - It operates on exactly one batch of data.
     * - It returns only the loss and metric values, instead of
     *   returning the batch-by-batch loss and metric values.
     * - It doesn't support fine-grained options such as verbosity and
     *   callbacks.
     *
     * @param x Input data. It could be one of the following:
     *   - A `tf.Tensor`, or an Array of `tf.Tensor`s (in case the model has
     *     multiple inputs).
     *   - An Object mapping input names to corresponding `tf.Tensor` (if the
     *     model has named inputs).
     * @param y Target data. It could be either a `tf.Tensor` or multiple
     *   `tf.Tensor`s. It should be consistent with `x`.
     * @returns Training loss or losses (in case the model has
     *   multiple outputs), along with metrics (if any), as numbers.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    trainOnBatch(x: Tensor | Tensor[] | {
        [inputName: string]: Tensor;
    }, y: Tensor | Tensor[] | {
        [inputName: string]: Tensor;
    }): Promise<number | number[]>;
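    /*
     * A usage sketch for `trainOnBatch()`: one gradient update on a single
     * batch, returning the loss as a number. The data here is arbitrary.
     *
     * ```js
     * const model = tf.sequential({
     *   layers: [tf.layers.dense({units: 1, inputShape: [10]})]
     * });
     * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
     *
     * const xs = tf.ones([4, 10]);
     * const ys = tf.zeros([4, 1]);
     * const loss = await model.trainOnBatch(xs, ys);
     * console.log(loss);
     * ```
     */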
    /**
     * Extract weight values of the model.
     *
     * @param config: An instance of `io.SaveConfig`, which specifies
     *   model-saving options such as whether only trainable weights are to be
     *   saved.
     * @returns A `NamedTensorMap` mapping original weight names (i.e.,
     *   non-uniqueified weight names) to their values.
     */
    protected getNamedWeights(config?: io.SaveConfig): NamedTensor[];
    /**
     * Setter used for force stopping of LayersModel.fit() (i.e., training).
     *
     * Example:
     *
     * ```js
     * const input = tf.input({shape: [10]});
     * const output = tf.layers.dense({units: 1}).apply(input);
     * const model = tf.model({inputs: [input], outputs: [output]});
     * model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
     * const xs = tf.ones([8, 10]);
     * const ys = tf.zeros([8, 1]);
     *
     * const history = await model.fit(xs, ys, {
     *   epochs: 10,
     *   callbacks: {
     *     onEpochEnd: async (epoch, logs) => {
     *       if (epoch === 2) {
     *         model.stopTraining = true;
     *       }
     *     }
     *   }
     * });
     *
     * // There should be only 3 values in the loss array, instead of 10
     * // values, due to the stopping after 3 epochs.
     * console.log(history.history.loss);
     * ```
     */
    stopTraining: boolean;
    optimizer: Optimizer;
    dispose(): DisposeResult;
    private getLossIdentifiers;
    private getMetricIdentifiers;
    protected getTrainingConfig(): TrainingConfig;
    loadTrainingConfig(trainingConfig: TrainingConfig): void;
    /**
     * Save the configuration and/or weights of the LayersModel.
     *
     * An `IOHandler` is an object that has a `save` method of the proper
     * signature defined. The `save` method manages the storing or
     * transmission of serialized data ("artifacts") that represent the
     * model's topology and weights onto or via a specific medium, such as
     * file downloads, local storage, IndexedDB in the web browser and HTTP
     * requests to a server. TensorFlow.js provides `IOHandler`
     * implementations for a number of frequently used saving mediums, such as
     * `tf.io.browserDownloads` and `tf.io.browserLocalStorage`. See `tf.io`
     * for more details.
     *
     * This method also allows you to refer to certain types of `IOHandler`s
     * as URL-like string shortcuts, such as 'localstorage://' and
     * 'indexeddb://'.
     *
     * Example 1: Save `model`'s topology and weights to browser [local
     * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
     * then load it back.
     *
     * ```js
     * const model = tf.sequential(
     *     {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
     * console.log('Prediction from original model:');
     * model.predict(tf.ones([1, 3])).print();
     *
     * const saveResults = await model.save('localstorage://my-model-1');
     *
     * const loadedModel = await tf.loadLayersModel('localstorage://my-model-1');
     * console.log('Prediction from loaded model:');
     * loadedModel.predict(tf.ones([1, 3])).print();
     * ```
     *
     * Example 2. Saving `model`'s topology and weights to browser
     * [IndexedDB](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API);
     * then load it back.
     *
     * ```js
     * const model = tf.sequential(
     *     {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
     * console.log('Prediction from original model:');
     * model.predict(tf.ones([1, 3])).print();
     *
     * const saveResults = await model.save('indexeddb://my-model-1');
     *
     * const loadedModel = await tf.loadLayersModel('indexeddb://my-model-1');
     * console.log('Prediction from loaded model:');
     * loadedModel.predict(tf.ones([1, 3])).print();
     * ```
     *
     * Example 3. Saving `model`'s topology and weights as two files
     * (`my-model-1.json` and `my-model-1.weights.bin`) downloaded from
     * browser.
     *
     * ```js
     * const model = tf.sequential(
     *     {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
     * const saveResults = await model.save('downloads://my-model-1');
     * ```
     *
     * Example 4. Send `model`'s topology and weights to an HTTP server.
     * See the documentation of `tf.io.http` for more details
     * including specifying request parameters and implementation of the
     * server.
     *
     * ```js
     * const model = tf.sequential(
     *     {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
     * const saveResults = await model.save('http://my-server/model/upload');
     * ```
     *
     * @param handlerOrURL An instance of `IOHandler` or a URL-like,
     *   scheme-based string shortcut for `IOHandler`.
     * @param config Options for saving the model.
     * @returns A `Promise` of `SaveResult`, which summarizes the result of
     *   the saving, such as byte sizes of the saved artifacts for the model's
     *   topology and weight values.
     *
     * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
     */
    save(handlerOrURL: io.IOHandler | string, config?: io.SaveConfig): Promise<io.SaveResult>;
    /**
     * Set user-defined metadata.
     *
     * The set metadata will be serialized together with the topology
     * and weights of the model during `save()` calls.
     *
     * @param userDefinedMetadata The metadata to be set.
     */
    setUserDefinedMetadata(userDefinedMetadata: {}): void;
    /**
     * Get user-defined metadata.
     *
     * The metadata is supplied via one of the two routes:
     *   1. By calling `setUserDefinedMetadata()`.
     *   2. Loaded during model loading (if the model is constructed
     *      via `tf.loadLayersModel()`.)
     *
     * If no user-defined metadata is available from either of the
     * two routes, this function will return `undefined`.
     */
    getUserDefinedMetadata(): {};
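    /*
     * A sketch of setting and reading user-defined metadata. The metadata
     * object and the storage key are arbitrary placeholders.
     *
     * ```js
     * const model = tf.sequential({
     *   layers: [tf.layers.dense({units: 1, inputShape: [3]})]
     * });
     * model.setUserDefinedMetadata({createdBy: 'example', version: 1});
     * console.log(model.getUserDefinedMetadata());
     *
     * // The metadata is serialized together with the model during save():
     * // await model.save('localstorage://my-model-with-metadata');
     * ```
     */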
641}
642/**
643 * A `tf.Functional` is an alias to `tf.LayersModel`.
644 *
645 * See also:
646 * `tf.LayersModel`, `tf.Sequential`, `tf.loadLayersModel`.
647 */
648/** @doc {heading: 'Models', subheading: 'Classes'} */
649export declare class Functional extends LayersModel {
650 static className: string;
651}