1 | /**
|
2 | * @license
|
3 | * Copyright 2018 Google LLC
|
4 | *
|
5 | * Use of this source code is governed by an MIT-style
|
6 | * license that can be found in the LICENSE file or at
|
7 | * https://opensource.org/licenses/MIT.
|
8 | * =============================================================================
|
9 | */
|
10 | /// <amd-module name="@tensorflow/tfjs-layers/dist/engine/training_dataset" />
|
11 | /**
|
12 | * Interfaces and methods for training models using TensorFlow.js datasets.
|
13 | */
|
14 | import * as tfc from '@tensorflow/tfjs-core';
|
15 | import { BaseCallback, CustomCallbackArgs, History, ModelLoggingVerbosity, YieldEveryOptions } from '../base_callbacks';
|
16 | import { TensorOrArrayOrMap } from '../types';
|
17 | import { Dataset, LazyIterator } from './dataset_stub';
|
18 | import { ClassWeight, ClassWeightMap } from './training_utils';
|
/**
 * Interface for configuring model training based on a dataset object.
 *
 * @typeParam T Element type of the datasets involved. `fitDataset` uses the
 *   same `T` for its training dataset, and it is also the element type of a
 *   `Dataset` passed as `validationData`.
 */
export interface ModelFitDatasetArgs<T> {
  /**
   * (Optional) Total number of steps (batches of samples) before declaring
   * one epoch finished and starting the next epoch. It should typically be
   * equal to the number of samples in your dataset divided by the batch
   * size, so that the `fitDataset()` call can utilize the entire dataset.
   * If it is not provided, the `done` field of the result returned by
   * `iterator.next()` is used as the signal to finish an epoch.
   */
  batchesPerEpoch?: number;

  /**
   * Integer number of times to iterate over the training dataset.
   */
  epochs: number;

  /**
   * Verbosity level.
   *
   * Expected to be 0, 1, or 2. Default: 1.
   *
   * 0 - No printed messages during the `fitDataset()` call.
   * 1 - In Node.js (tfjs-node), prints the progress bar, together with
   *     real-time updates of loss and metric values and training speed.
   *     In the browser: no action. This is the default.
   * 2 - Not implemented yet.
   */
  verbose?: ModelLoggingVerbosity;

  /**
   * List of callbacks to be called during training.
   * Can have one or more of the following callbacks:
   *   - `onTrainBegin(logs)`: called when training starts.
   *   - `onTrainEnd(logs)`: called when training ends.
   *   - `onEpochBegin(epoch, logs)`: called at the start of every epoch.
   *   - `onEpochEnd(epoch, logs)`: called at the end of every epoch.
   *   - `onBatchBegin(batch, logs)`: called at the start of every batch.
   *   - `onBatchEnd(batch, logs)`: called at the end of every batch.
   *   - `onYield(epoch, batch, logs)`: called every `yieldEvery` milliseconds
   *     with the current epoch, batch and logs. The logs are the same as in
   *     `onBatchEnd()`. Note that `onYield` can skip batches or epochs. See
   *     also the docs for `yieldEvery` below.
   */
  callbacks?: BaseCallback[] | CustomCallbackArgs | CustomCallbackArgs[];

  /**
   * Data on which to evaluate the loss and any model metrics at the end of
   * each epoch. This could be any of the following:
   *
   *   - An array `[xVal, yVal]`, where the two values may be `tf.Tensor`,
   *     an array of Tensors, or a map of string to Tensor.
   *   - Similarly, an array `[xVal, yVal, valSampleWeights]`
   *     (not implemented yet).
   *   - A `Dataset` object with elements of the form `{xs: xVal, ys: yVal}`,
   *     where `xs` and `ys` are the feature and label tensors, respectively.
   *
   * If `validationData` is an Array of Tensor objects, each `tf.Tensor` will
   * be sliced into batches during validation, using the parameter
   * `validationBatchSize` (which defaults to 32). The entirety of the
   * `tf.Tensor` objects will be used in the validation.
   *
   * If `validationData` is a dataset object, and the `validationBatches`
   * parameter is specified, the validation will use `validationBatches`
   * batches drawn from the dataset object. If the `validationBatches`
   * parameter is not specified, the validation will stop when the dataset is
   * exhausted.
   *
   * The model will not be trained on this data.
   */
  validationData?:
    | [TensorOrArrayOrMap, TensorOrArrayOrMap]
    | [TensorOrArrayOrMap, TensorOrArrayOrMap, TensorOrArrayOrMap]
    | Dataset<T>;

  /**
   * Optional batch size for validation.
   *
   * Used only if `validationData` is an array of `tf.Tensor` objects, i.e.,
   * not a dataset object.
   *
   * If not specified, its value defaults to 32.
   */
  validationBatchSize?: number;

  /**
   * (Optional) Only relevant if `validationData` is specified and is a
   * dataset object.
   *
   * Total number of batches of samples to draw from `validationData` for
   * validation purposes before stopping at the end of every epoch. If not
   * specified, `evaluateDataset` will use `iterator.next().done` as the
   * signal to stop validation.
   */
  validationBatches?: number;

  /**
   * Configures the frequency of yielding the main thread to other tasks.
   *
   * In the browser environment, yielding the main thread can improve the
   * responsiveness of the page during training. In the Node.js environment,
   * it can ensure tasks queued in the event loop can be handled in a timely
   * manner.
   *
   * The value can be one of the following:
   *   - `'auto'`: The yielding happens at a certain frame rate (currently set
   *     at 125ms). This is the default.
   *   - `'batch'`: yield every batch.
   *   - `'epoch'`: yield every epoch.
   *   - a `number`: yield every `number` milliseconds.
   *   - `'never'`: never yield. (Yielding can still happen through
   *     `await nextFrame()` calls in custom callbacks.)
   */
  yieldEvery?: YieldEveryOptions;

  /**
   * Epoch at which to start training (useful for resuming a previous training
   * run). When this is used, `epochs` is the index of the "final epoch".
   * The model is not trained for a number of iterations given by `epochs`,
   * but merely until the epoch of index `epochs` is reached.
   */
  initialEpoch?: number;

  /**
   * Optional object mapping class indices (integers) to a weight (float) to
   * apply to the model's loss for the samples from this class during
   * training. This can be useful to tell the model to "pay more attention"
   * to samples from an under-represented class.
   *
   * If the model has multiple outputs, a class weight can be specified for
   * each of the outputs by setting this field to an array of weight objects,
   * or to an object that maps model output names (e.g.,
   * `model.outputNames[0]`) to weight objects.
   */
  classWeight?: ClassWeight | ClassWeight[] | ClassWeightMap;
}
|
/**
 * A single element of a dataset used for training or validation, in the
 * `{xs, ys}` form.
 */
export interface FitDatasetElement {
  /** Feature (input) values: a Tensor, an array of Tensors, or a map of them. */
  xs: TensorOrArrayOrMap;
  /** Label (target) values: a Tensor, an array of Tensors, or a map of them. */
  ys: TensorOrArrayOrMap;
}
|
/**
 * Interface for configuring model evaluation based on a dataset object.
 */
export interface ModelEvaluateDatasetArgs {
  /**
   * Number of batches to draw from the dataset object before ending the
   * evaluation.
   *
   * NOTE(review): when omitted, evaluation presumably runs until the
   * dataset/iterator reports `done` — confirm against the implementation.
   */
  batches?: number;

  /**
   * Verbosity mode, expressed as a `ModelLoggingVerbosity` value.
   */
  verbose?: ModelLoggingVerbosity;
}
|
/**
 * Trains `model` on batches drawn from `dataset`, configured by `args`.
 *
 * @param model The model to train. NOTE(review): typed as `any` in this
 *   declaration; presumably a `LayersModel` — confirm before narrowing.
 * @param dataset Source of training batches; its element type `T` is shared
 *   with `args` (and with a `Dataset` given as `args.validationData`).
 * @param args Training configuration (epochs, callbacks, validation, ...).
 * @returns A Promise that resolves to the training `History`.
 */
export declare function fitDataset<T>(
  model: any,
  dataset: Dataset<T>,
  args: ModelFitDatasetArgs<T>
): Promise<History>;
|
/**
 * Evaluates `model` on data drawn from a dataset or a lazy iterator.
 *
 * @param model The model to evaluate. NOTE(review): typed as `any` in this
 *   declaration; presumably a `LayersModel` — confirm before narrowing.
 * @param dataset A `Dataset` or a `LazyIterator` supplying elements of type
 *   `T`.
 * @param args Evaluation options (number of batches, verbosity).
 * @returns A Promise resolving to a single scalar or an array of scalars.
 *   NOTE(review): presumably the loss, followed by metric values when the
 *   model has metrics — confirm against the implementation.
 */
export declare function evaluateDataset<T>(
  model: any,
  dataset: Dataset<T> | LazyIterator<T>,
  args: ModelEvaluateDatasetArgs
): Promise<tfc.Scalar | Array<tfc.Scalar>>;
|