1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 | import * as tfc from '@tensorflow/tfjs-core';
|
12 | import { io, ModelPredictConfig as ModelPredictArgs, NamedTensorMap, Optimizer, Scalar, Tensor } from '@tensorflow/tfjs-core';
|
13 | import { BaseCallback, History, ModelLoggingVerbosity } from '../base_callbacks';
|
14 | import { Shape } from '../keras_format/common';
|
15 | import { TrainingConfig } from '../keras_format/training_config';
|
16 | import { LossOrMetricFn, NamedTensor } from '../types';
|
17 | import { Container, ContainerArgs } from './container';
|
18 | import { Dataset } from './dataset_stub';
|
19 | import { DisposeResult } from './topology';
|
20 | import { ModelEvaluateDatasetArgs, ModelFitDatasetArgs } from './training_dataset';
|
21 | import { ModelFitArgs } from './training_tensors';
|
22 | import { ClassWeight, ClassWeightMap } from './training_utils';
|
23 |
|
24 |
|
25 |
|
/**
 * Boolean check over the polymorphic data argument accepted by the training
 * helpers in this module.
 *
 * NOTE(review): the implementation is not part of this declaration file.
 * From the name, it presumably returns `true` only when `x` is a single
 * `Tensor` (rather than an array or a name-keyed dictionary). Confirm this
 * against the implementation.
 *
 * @param x A tensor, an array of tensors, or a dictionary keyed by input name
 *   whose values are tensors or arrays of tensors.
 * @returns Whether `x` is classified as a data tensor.
 */
export declare function isDataTensor(
  x: Tensor | Tensor[] | Record<string, Tensor> | Record<string, Tensor[]>
): boolean;
|
31 |
|
32 |
|
33 |
|
/**
 * Boolean check over the polymorphic data argument accepted by the training
 * helpers in this module.
 *
 * NOTE(review): the implementation is not visible here. From the name, it
 * presumably returns `true` when `x` is an array of tensors. Confirm this.
 *
 * @param x A tensor, an array of tensors, or a dictionary keyed by input name.
 * @returns Whether `x` is classified as a data array.
 */
export declare function isDataArray(
  x: Tensor | Tensor[] | Record<string, Tensor>
): boolean;
|
37 |
|
38 |
|
39 |
|
/**
 * Boolean check over the polymorphic data argument accepted by the training
 * helpers in this module.
 *
 * NOTE(review): the implementation is not visible here. From the name, it
 * presumably returns `true` when `x` is a dictionary keyed by input name.
 * Confirm this.
 *
 * @param x A tensor, an array of tensors, or a dictionary keyed by input name.
 * @returns Whether `x` is classified as a data dictionary.
 */
export declare function isDataDict(
  x: Tensor | Tensor[] | Record<string, Tensor>
): boolean;
|
43 |
|
44 |
|
45 |
|
46 |
|
47 |
|
48 |
|
49 |
|
50 |
|
51 |
|
52 |
|
53 |
|
/**
 * Converts user-supplied data into a flat `Tensor[]`. The data may be a single
 * tensor, an array of tensors, or a dictionary keyed by name.
 *
 * NOTE(review): the implementation is not visible here. From the parameter
 * names, it presumably:
 * - matches entries against `names`;
 * - validates them against `shapes` (optionally including the batch axis,
 *   per `checkBatchAxis`);
 * - prefixes error messages with `exceptionPrefix`.
 * Confirm this against the implementation.
 *
 * @param data The data to standardize.
 * @param names Names the data is matched against.
 * @param shapes Optional expected shapes.
 * @param checkBatchAxis Optional flag for the batch-axis shape check.
 * @param exceptionPrefix Optional prefix for error messages.
 * @returns The standardized data as an array of tensors.
 */
export declare function standardizeInputData(
  data: Tensor | Tensor[] | Record<string, Tensor>,
  names: string[],
  shapes?: Shape[],
  checkBatchAxis?: boolean,
  exceptionPrefix?: string
): Tensor[];
|
57 |
|
58 |
|
59 |
|
60 |
|
61 |
|
62 |
|
63 |
|
/**
 * Consistency check across the input, target, and optional weight tensors.
 *
 * NOTE(review): the implementation is not visible here. From the name, it
 * presumably raises an error when the arrays disagree in length or sample
 * count. Confirm which error type is thrown.
 *
 * @param inputs Input tensors.
 * @param targets Target tensors.
 * @param weights Optional sample-weight tensors.
 */
export declare function checkArrayLengths(
  inputs: Tensor[],
  targets: Tensor[],
  weights?: Tensor[]
): void;
|
65 |
|
66 |
|
67 |
|
68 |
|
69 |
|
70 |
|
71 |
|
72 |
|
73 |
|
74 |
|
75 |
|
76 |
|
77 |
|
/**
 * Expands a metrics specification into a nested array of metrics.
 *
 * The specification may take any of these forms:
 * - a single name or function;
 * - an array of names or functions;
 * - a dictionary keyed by output name.
 *
 * NOTE(review): given the `outputNames` argument and the nested return type,
 * the result is presumably one inner array per output name. Confirm this
 * against the implementation.
 *
 * @param metrics The metrics specification.
 * @param outputNames Names of the model outputs.
 * @returns The metrics grouped into nested arrays.
 */
export declare function collectMetrics(
  metrics:
    | string
    | LossOrMetricFn
    | Array<string | LossOrMetricFn>
    | Record<string, string | LossOrMetricFn>,
  outputNames: string[]
): Array<Array<string | LossOrMetricFn>>;
|
/**
 * Optional configuration for `LayersModel.evaluate()`.
 */
export interface ModelEvaluateArgs {
  /**
   * Number of samples per batch; evaluation is computed in batches.
   *
   * NOTE(review): the default used when this is omitted is chosen by the
   * implementation (32 upstream). Confirm before relying on it.
   */
  batchSize?: number;

  /**
   * Verbosity mode used while evaluating.
   */
  verbose?: ModelLoggingVerbosity;

  /**
   * Optional tensor of per-sample weights.
   *
   * NOTE(review): presumably scales each sample's contribution to the loss
   * and metrics. Confirm this against the implementation.
   */
  sampleWeight?: Tensor;

  /**
   * Total number of steps (batches of samples) before the evaluation round is
   * declared finished. Ignored when left `undefined`.
   */
  steps?: number;
}
|
102 |
|
103 |
|
104 |
|
/**
 * Configuration for `LayersModel.compile()`: the optimizer, loss, and
 * (optionally) metrics used for fitting and evaluating a model.
 */
export interface ModelCompileArgs {
  /**
   * The optimizer. This is either an `Optimizer` instance or the name of an
   * optimizer as a string (e.g. `'sgd'`).
   */
  optimizer: string | Optimizer;

  /**
   * The loss to minimize. It may be specified as any of:
   * - a loss name (e.g. `'meanSquaredError'`);
   * - a loss function;
   * - an array of names or functions;
   * - a dictionary keyed by output name.
   *
   * NOTE(review): for multi-output models, the array/dictionary forms
   * presumably assign one loss per output. Confirm this.
   */
  loss: string | string[] | {
    [outputName: string]: string;
  } | LossOrMetricFn | LossOrMetricFn[] | {
    [outputName: string]: LossOrMetricFn;
  };

  /**
   * Optional metrics to compute during training and evaluation. They may be
   * given as a single name or function, an array of them, or a dictionary
   * keyed by output name. The display labels of the resulting scalar values
   * are available via `model.metricsNames`.
   */
  metrics?: string | LossOrMetricFn | Array<string | LossOrMetricFn> | {
    [outputName: string]: string | LossOrMetricFn;
  };
}
|
132 |
|
133 |
|
134 |
|
135 |
|
136 |
|
137 |
|
138 |
|
139 |
|
140 |
|
141 |
|
142 |
|
143 |
|
export declare class LayersModel extends Container implements tfc.InferenceModel {
  static className: string;
  protected optimizer_: Optimizer;
  protected isOptimizerOwned: boolean;
  /** Loss specification; same shape as `ModelCompileArgs.loss`. */
  loss: string | string[] | {
    [outputName: string]: string;
  } | LossOrMetricFn | LossOrMetricFn[] | {
    [outputName: string]: LossOrMetricFn;
  };
  lossFunctions: LossOrMetricFn[];
  private feedOutputShapes;
  private feedLossFns;
  private collectedTrainableWeights;
  private testFunction;
  history: History;
  protected stopTraining_: boolean;
  protected isTraining: boolean;
  /** Metrics specification; same shape as `ModelCompileArgs.metrics`. */
  metrics: string | LossOrMetricFn | Array<string | LossOrMetricFn> | {
    [outputName: string]: string | LossOrMetricFn;
  };
  /** Display labels for the scalar outputs of `evaluate()`. */
  metricsNames: string[];
  metricsTensors: Array<[LossOrMetricFn, number]>;
  private userDefinedMetadata;
  constructor(args: ContainerArgs);
  /**
   * Print a text summary of the model's layers.
   *
   * The summary includes
   * - Name and type of all layers that comprise the model.
   * - Output shape(s) of the layers
   * - Number of weight parameters of each layer
   * - If the model has non-sequential-like topology, the inputs each layer
   *   receives
   * - The total number of trainable and non-trainable parameters of the model.
   *
   * ```js
   * const input1 = tf.input({shape: [10]});
   * const input2 = tf.input({shape: [20]});
   * const dense1 = tf.layers.dense({units: 4}).apply(input1);
   * const dense2 = tf.layers.dense({units: 8}).apply(input2);
   * const concat = tf.layers.concatenate().apply([dense1, dense2]);
   * const output =
   *     tf.layers.dense({units: 3, activation: 'softmax'}).apply(concat);
   *
   * const model = tf.model({inputs: [input1, input2], outputs: output});
   * model.summary();
   * ```
   *
   * @param lineLength Custom line length, in number of characters.
   * @param positions Custom widths of each of the columns, as either
   *   fractions of `lineLength` (e.g., `[0.5, 0.75, 1]`) or absolute number
   *   of characters (e.g., `[30, 50, 65]`). Each number corresponds to
   *   right-most (i.e., ending) position of a column.
   * @param printFn Custom print function. Can be used to replace the default
   *   `console.log`. For example, you can use `x => {}` to mute the printed
   *   messages in the console.
   *
   * @doc {heading: 'Models', subheading: 'Classes'}
   */
  summary(lineLength?: number, positions?: number[], printFn?: (message?: any, ...optionalParams: any[]) => void): void;
  /**
   * Configures and prepares the model for training and evaluation. Compiling
   * outfits the model with an optimizer, loss, and/or metrics. Calling `fit`
   * or `evaluate` on an un-compiled model will throw an error.
   *
   * @param args a `ModelCompileArgs` specifying the loss, optimizer, and
   *   metrics to be used for fitting and evaluating this model.
   *
   * @doc {heading: 'Models', subheading: 'Classes'}
   */
  compile(args: ModelCompileArgs): void;
  /**
   * Check trainable weights count consistency.
   *
   * This will raise a warning if `this.trainableWeights` and
   * `this.collectedTrainableWeights` are inconsistent (i.e., have different
   * numbers of parameters).
   * Inconsistency will typically arise when one modifies `model.trainable`
   * without calling `model.compile()` again.
   */
  protected checkTrainableWeightsConsistency(): void;
  /**
   * Returns the loss value & metrics values for the model in test mode.
   *
   * Loss and metrics are specified during `compile()`, which needs to happen
   * before calls to `evaluate()`.
   *
   * Computation is done in batches.
   *
   * ```js
   * const model = tf.sequential({
   *   layers: [tf.layers.dense({units: 1, inputShape: [10]})]
   * });
   * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
   * const result = model.evaluate(
   *     tf.ones([8, 10]), tf.ones([8, 1]), {batchSize: 4});
   * result.print();
   * ```
   *
   * @param x `tf.Tensor` of test data, or an `Array` of `tf.Tensor`s if the
   *   model has multiple inputs.
   * @param y `tf.Tensor` of target data, or an `Array` of `tf.Tensor`s if the
   *   model has multiple outputs.
   * @param args A `ModelEvaluateArgs`, containing optional fields.
   *
   * @return `Scalar` test loss (if the model has a single output and no
   *   metrics) or `Array` of `Scalar`s (if the model has multiple outputs
   *   and/or metrics). The attribute `model.metricsNames`
   *   will give you the display labels for the scalar outputs.
   *
   * @doc {heading: 'Models', subheading: 'Classes'}
   */
  evaluate(x: Tensor | Tensor[], y: Tensor | Tensor[], args?: ModelEvaluateArgs): Scalar | Scalar[];
  /**
   * Evaluate model using a dataset object.
   *
   * Note: Unlike `evaluate()`, this method is asynchronous (`async`).
   *
   * @param dataset A dataset object. Its `iterator()` method is expected
   *   to generate a dataset iterator object, the `next()` method of which
   *   is expected to produce data batches for evaluation. The return value
   *   of the `next()` call ought to contain a boolean `done` field and a
   *   `value` field. The `value` field is expected to be an array of two
   *   `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former
   *   case is for models with exactly one input and one output (e.g.
   *   a sequential model). The latter case is for models with multiple
   *   inputs and/or multiple outputs. Of the two items in the array, the
   *   first is the input feature(s) and the second is the output target(s).
   * @param args A configuration object for the dataset-based evaluation.
   * @returns A `Promise` that resolves to the loss and metric values, either
   *   as a single `Scalar` or as an Array of `Scalar` objects.
   *
   * @doc {heading: 'Models', subheading: 'Classes'}
   */
  evaluateDataset(dataset: Dataset<{}>, args?: ModelEvaluateDatasetArgs): Promise<Scalar | Scalar[]>;
  /**
   * Get number of samples provided for training, evaluation or prediction.
   *
   * @param ins Input `tf.Tensor`.
   * @param batchSize Integer batch size, optional.
   * @param steps Total number of steps (batches of samples) before
   *   declaring loop finished. Optional.
   * @param stepsName The public API's parameter name for `steps`.
   * @returns Number of samples provided.
   */
  private checkNumSamples;
  /**
   * Execute internal tensors of the model with input data feed.
   * @param inputs Input data feed. Must match the inputs of the model.
   * @param outputs Names of the output tensors to be fetched. Must match
   *   names of the SymbolicTensors that belong to the graph.
   * @returns Fetched values for `outputs`.
   */
  execute(inputs: Tensor | Tensor[] | NamedTensorMap, outputs: string | string[]): Tensor | Tensor[];
  /**
   * Retrieve the model's internal symbolic tensors from symbolic-tensor names.
   */
  private retrieveSymbolicTensors;
  /**
   * Helper method to loop over some data in batches.
   *
   * Porting Note: Not using the functional approach in the Python equivalent
   *   due to the imperative backend.
   * Porting Note: Does not support step mode currently.
   *
   * @param ins: input data
   * @param batchSize: integer batch size.
   * @param verbose: verbosity mode
   * @returns: Predictions as `tf.Tensor` (if a single output) or an `Array` of
   *   `tf.Tensor` (if multiple outputs).
   */
  private predictLoop;
  /**
   * Generates output predictions for the input samples.
   *
   * Computation is done in batches.
   *
   * Note: the "step" mode of predict() is currently not supported.
   *   This is because the TensorFlow.js core backend is imperative only.
   *
   * ```js
   * const model = tf.sequential({
   *   layers: [tf.layers.dense({units: 1, inputShape: [10]})]
   * });
   * model.predict(tf.ones([8, 10]), {batchSize: 4}).print();
   * ```
   *
   * @param x The input data, as a Tensor, or an `Array` of `tf.Tensor`s if
   *   the model has multiple inputs.
   * @param args A `ModelPredictArgs` object containing optional fields.
   *
   * @return Prediction results as a `tf.Tensor`(s).
   *
   * @exception ValueError In case of mismatch between the provided input data
   *   and the model's expectations, or in case a stateful model receives a
   *   number of samples that is not a multiple of the batch size.
   *
   * @doc {heading: 'Models', subheading: 'Classes'}
   */
  predict(x: Tensor | Tensor[], args?: ModelPredictArgs): Tensor | Tensor[];
  /**
   * Returns predictions for a single batch of samples.
   *
   * ```js
   * const model = tf.sequential({
   *   layers: [tf.layers.dense({units: 1, inputShape: [10]})]
   * });
   * model.predictOnBatch(tf.ones([8, 10])).print();
   * ```
   * @param x: Input samples, as a Tensor (for models with exactly one
   *   input) or an array of Tensors (for models with more than one input).
   * @return Tensor(s) of predictions
   *
   * @doc {heading: 'Models', subheading: 'Classes'}
   */
  predictOnBatch(x: Tensor | Tensor[]): Tensor | Tensor[];
  protected standardizeUserDataXY(x: Tensor | Tensor[] | {
    [inputName: string]: Tensor;
  }, y: Tensor | Tensor[] | {
    [outputName: string]: Tensor;
  }, checkBatchAxis?: boolean, batchSize?: number): [Tensor[], Tensor[]];
  protected standardizeUserData(x: Tensor | Tensor[] | {
    [inputName: string]: Tensor;
  }, y: Tensor | Tensor[] | {
    [outputName: string]: Tensor;
  }, sampleWeight?: Tensor | Tensor[] | {
    [outputName: string]: Tensor;
  }, classWeight?: ClassWeight | ClassWeight[] | ClassWeightMap, checkBatchAxis?: boolean, batchSize?: number): Promise<[Tensor[], Tensor[], Tensor[]]>;
  /**
   * Loop over some test data in batches.
   * @param f A Function returning a list of tensors.
   * @param ins Array of tensors to be fed to `f`.
   * @param batchSize Integer batch size or `null` / `undefined`.
   * @param verbose verbosity mode.
   * @param steps Total number of steps (batches of samples) before
   *   declaring test finished. Ignored with the default value of `null` /
   *   `undefined`.
   * @returns Array of Scalars.
   */
  private testLoop;
  protected getDedupedMetricsNames(): string[];
  /**
   * Creates a function that performs the following actions:
   *
   * 1. computes the losses
   * 2. sums them to get the total loss
   * 3. calls the optimizer to compute the gradients of the LayersModel's
   *    trainable weights w.r.t. the total loss and to update the variables
   * 4. calculates the metrics
   * 5. returns the values of the losses and metrics.
   */
  protected makeTrainFunction(): (data: Tensor[]) => Scalar[];
  /**
   * Create a function which, when invoked with an array of `tf.Tensor`s as a
   * batch of inputs, returns the prespecified loss and metrics of the model
   * under the batch of input data.
   */
  private makeTestFunction;
  /**
   * Trains the model for a fixed number of epochs (iterations on a
   * dataset).
   *
   * ```js
   * const model = tf.sequential({
   *     layers: [tf.layers.dense({units: 1, inputShape: [10]})]
   * });
   * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
   * for (let i = 1; i < 5 ; ++i) {
   *   const h = await model.fit(tf.ones([8, 10]), tf.ones([8, 1]), {
   *       batchSize: 4,
   *       epochs: 3
   *   });
   *   console.log("Loss after Epoch " + i + " : " + h.history.loss[0]);
   * }
   * ```
   *
   * @param x `tf.Tensor` of training data, or an array of `tf.Tensor`s if the
   *   model has multiple inputs. If all inputs in the model are named, you
   *   can also pass a dictionary mapping input names to `tf.Tensor`s.
   * @param y `tf.Tensor` of target (label) data, or an array of `tf.Tensor`s if
   *   the model has multiple outputs. If all outputs in the model are named,
   *   you can also pass a dictionary mapping output names to `tf.Tensor`s.
   * @param args A `ModelFitArgs`, containing optional fields.
   *
   * @return A `History` instance. Its `history` attribute contains all
   *   information collected during training.
   *
   * @exception ValueError In case of mismatch between the provided input
   *   data and what the model expects.
   *
   * @doc {heading: 'Models', subheading: 'Classes'}
   */
  fit(x: Tensor | Tensor[] | {
    [inputName: string]: Tensor;
  }, y: Tensor | Tensor[] | {
    [outputName: string]: Tensor;
  }, args?: ModelFitArgs): Promise<History>;
  /**
   * Abstract fit function for `f(ins)`.
   * @param f A Function returning a list of tensors. For training, this
   *   function is expected to perform the updates to the variables.
   * @param ins List of tensors to be fed to `f`.
   * @param outLabels List of strings, display names of the outputs of `f`.
   * @param batchSize Integer batch size or `== null` if unknown. Default : 32.
   * @param epochs Number of times to iterate over the data. Default : 1.
   * @param verbose Verbosity mode: 0, 1, or 2. Default: 1.
   * @param callbacks List of callbacks to be called during training.
   * @param valF Function to call for validation.
   * @param valIns List of tensors to be fed to `valF`.
   * @param shuffle Whether to shuffle the data at the beginning of every
   *   epoch. Default : true.
   * @param callbackMetrics List of strings, the display names of the metrics
   *   passed to the callbacks. They should be the concatenation of the
   *   display names of the outputs of `f` and the list of display names
   *   of the outputs of `valF`.
   * @param initialEpoch Epoch at which to start training (useful for
   *   resuming a previous training run). Default : 0.
   * @param stepsPerEpoch Total number of steps (batches of samples) before
   *   declaring one epoch finished and starting the next epoch. Ignored with
   *   the default value of `undefined` or `null`.
   * @param validationSteps Number of steps to run validation for (only if
   *   doing validation from data tensors). Not applicable for tfjs-layers.
   * @returns A `History` object.
   */
  fitLoop(f: (data: Tensor[]) => Scalar[], ins: Tensor[], outLabels?: string[], batchSize?: number, epochs?: number, verbose?: number, callbacks?: BaseCallback[], valF?: (data: Tensor[]) => Scalar[], valIns?: Tensor[], shuffle?: boolean | string, callbackMetrics?: string[], initialEpoch?: number, stepsPerEpoch?: number, validationSteps?: number): Promise<History>;
  /**
   * Trains the model using a dataset object.
   *
   * @param dataset A dataset object. Its `iterator()` method is expected
   *   to generate a dataset iterator object, the `next()` method of which
   *   is expected to produce data batches for training. The return value
   *   of the `next()` call ought to contain a boolean `done` field and a
   *   `value` field. The `value` field is expected to be an array of two
   *   `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former
   *   case is for models with exactly one input and one output (e.g.
   *   a sequential model). The latter case is for models with multiple
   *   inputs and/or multiple outputs.
   *   Of the two items in the array, the first is the input feature(s) and
   *   the second is the output target(s).
   * @param args A `ModelFitDatasetArgs`, containing optional fields.
   *
   * @return A `History` instance. Its `history` attribute contains all
   *   information collected during training.
   *
   * @doc {heading: 'Models', subheading: 'Classes'}
   */
  fitDataset<T>(dataset: Dataset<T>, args: ModelFitDatasetArgs<T>): Promise<History>;
  /**
   * Runs a single gradient update on a single batch of data.
   *
   * This method differs from `fit()` and `fitDataset()` in the following
   * regards:
   *   - It operates on exactly one batch of data.
   *   - It returns only the loss and metric values, instead of
   *     returning the batch-by-batch loss and metric values.
   *   - It doesn't support fine-grained options such as verbosity and
   *     callbacks.
   *
   * @param x Input data. It could be one of the following:
   *   - A `tf.Tensor`, or an Array of `tf.Tensor`s (in case the model has
   *     multiple inputs).
   *   - An Object mapping input names to corresponding `tf.Tensor` (if the
   *     model has named inputs).
   * @param y Target data. It could be either a `tf.Tensor` or multiple
   *   `tf.Tensor`s. It should be consistent with `x`.
   * @returns Training loss or losses (in case the model has
   *   multiple outputs), along with metrics (if any), as numbers.
   *
   * @doc {heading: 'Models', subheading: 'Classes'}
   */
  trainOnBatch(x: Tensor | Tensor[] | {
    [inputName: string]: Tensor;
  }, y: Tensor | Tensor[] | {
    [outputName: string]: Tensor;
  }): Promise<number | number[]>;
  /**
   * Extract weight values of the model.
   *
   * @param config: An instance of `io.SaveConfig`, which specifies
   *   model-saving options such as whether only trainable weights are to be
   *   saved.
   * @returns An array of `NamedTensor`s pairing original weight names (i.e.,
   *   non-uniqueified weight names) with their values.
   */
  protected getNamedWeights(config?: io.SaveConfig): NamedTensor[];
  /**
   * Setter used to force-stop `LayersModel.fit()` (i.e., training).
   *
   * Example:
   *
   * ```js
   * const input = tf.input({shape: [10]});
   * const output = tf.layers.dense({units: 1}).apply(input);
   * const model = tf.model({inputs: [input], outputs: [output]});
   * model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
   * const xs = tf.ones([8, 10]);
   * const ys = tf.zeros([8, 1]);
   *
   * const history = await model.fit(xs, ys, {
   *   epochs: 10,
   *   callbacks: {
   *     onEpochEnd: async (epoch, logs) => {
   *       if (epoch === 2) {
   *         model.stopTraining = true;
   *       }
   *     }
   *   }
   * });
   *
   * // There should be only 3 values in the loss array, instead of 10
   * // values, because the training stops after 3 epochs.
   * console.log(history.history.loss);
   * ```
   */
  set stopTraining(stop: boolean);
  get stopTraining(): boolean;
  get optimizer(): Optimizer;
  set optimizer(optimizer: Optimizer);
  dispose(): DisposeResult;
  private getLossIdentifiers;
  private getMetricIdentifiers;
  protected getTrainingConfig(): TrainingConfig;
  loadTrainingConfig(trainingConfig: TrainingConfig): void;
  /**
   * Save the configuration and/or weights of the LayersModel.
   *
   * An `IOHandler` is an object that has a `save` method of the proper
   * signature defined. The `save` method manages the storing or
   * transmission of serialized data ("artifacts") that represent the
   * model's topology and weights onto or via a specific medium, such as
   * file downloads, local storage, IndexedDB in the web browser and HTTP
   * requests to a server. TensorFlow.js provides `IOHandler`
   * implementations for a number of frequently used saving mediums, such as
   * `tf.io.browserDownloads` and `tf.io.browserLocalStorage`. See `tf.io`
   * for more details.
   *
   * This method also allows you to refer to certain types of `IOHandler`s
   * as URL-like string shortcuts, such as 'localstorage://' and
   * 'indexeddb://'.
   *
   * Example 1: Save `model`'s topology and weights to browser [local
   * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
   * then load it back.
   *
   * ```js
   * const model = tf.sequential(
   *     {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
   * console.log('Prediction from original model:');
   * model.predict(tf.ones([1, 3])).print();
   *
   * const saveResults = await model.save('localstorage://my-model-1');
   *
   * const loadedModel = await tf.loadLayersModel('localstorage://my-model-1');
   * console.log('Prediction from loaded model:');
   * loadedModel.predict(tf.ones([1, 3])).print();
   * ```
   *
   * Example 2. Saving `model`'s topology and weights to browser
   * [IndexedDB](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API);
   * then load it back.
   *
   * ```js
   * const model = tf.sequential(
   *     {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
   * console.log('Prediction from original model:');
   * model.predict(tf.ones([1, 3])).print();
   *
   * const saveResults = await model.save('indexeddb://my-model-1');
   *
   * const loadedModel = await tf.loadLayersModel('indexeddb://my-model-1');
   * console.log('Prediction from loaded model:');
   * loadedModel.predict(tf.ones([1, 3])).print();
   * ```
   *
   * Example 3. Saving `model`'s topology and weights as two files
   * (`my-model-1.json` and `my-model-1.weights.bin`) downloaded from
   * browser.
   *
   * ```js
   * const model = tf.sequential(
   *     {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
   * const saveResults = await model.save('downloads://my-model-1');
   * ```
   *
   * Example 4. Send `model`'s topology and weights to an HTTP server.
   * See the documentation of `tf.io.http` for more details
   * including specifying request parameters and implementation of the
   * server.
   *
   * ```js
   * const model = tf.sequential(
   *     {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
   * const saveResults = await model.save('http://my-server/model/upload');
   * ```
   *
   * @param handlerOrURL An instance of `IOHandler` or a URL-like,
   *   scheme-based string shortcut for `IOHandler`.
   * @param config Options for saving the model.
   * @returns A `Promise` of `SaveResult`, which summarizes the result of
   *   the saving, such as byte sizes of the saved artifacts for the model's
   *   topology and weight values.
   *
   * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
   */
  save(handlerOrURL: io.IOHandler | string, config?: io.SaveConfig): Promise<io.SaveResult>;
  /**
   * Set user-defined metadata.
   *
   * The set metadata will be serialized together with the topology
   * and weights of the model during `save()` calls.
   *
   * @param userDefinedMetadata The metadata object to attach to the model.
   */
  setUserDefinedMetadata(userDefinedMetadata: {}): void;
  /**
   * Get user-defined metadata.
   *
   * The metadata is supplied via one of the two routes:
   *   1. By calling `setUserDefinedMetadata()`.
   *   2. Loaded during model loading (if the model is constructed
   *      via `tf.loadLayersModel()`.)
   *
   * If no user-defined metadata is available from either of the
   * two routes, this function will return `undefined`.
   *
   * NOTE(review): the declared return type `{}` does not include `undefined`
   * under `strictNullChecks`. Callers should still guard against `undefined`
   * despite what the type says.
   */
  getUserDefinedMetadata(): {};
}
|
/**
 * A `tf.Functional` is a `tf.LayersModel` under another name. It is declared
 * as a subclass of `tf.LayersModel` that adds no members of its own beyond
 * its static `className`.
 *
 * Because it is a subclass rather than a type alias, a `tf.Functional` can
 * be used wherever a `tf.LayersModel` is expected. The reverse is not true:
 * `instanceof tf.Functional` does not match models that were constructed
 * directly as `tf.LayersModel`.
 *
 * See also:
 *   `tf.LayersModel`, `tf.Sequential`, `tf.loadLayersModel`.
 */
/** @doc {heading: 'Models', subheading: 'Classes'} */
export declare class Functional extends LayersModel {
  static className: string;
}
|
682 |
|
\ | No newline at end of file |