UNPKG

7.44 kBTypeScriptView Raw
1/**
2 * @license
3 * Copyright 2018 Google LLC
4 *
5 * Use of this source code is governed by an MIT-style
6 * license that can be found in the LICENSE file or at
7 * https://opensource.org/licenses/MIT.
8 * =============================================================================
9 */
10/// <amd-module name="@tensorflow/tfjs-layers/dist/engine/training_dataset" />
11/**
12 * Interfaces and methods for training models using TensorFlow.js datasets.
13 */
14import * as tfc from '@tensorflow/tfjs-core';
15import { BaseCallback, CustomCallbackArgs, History, ModelLoggingVerbosity, YieldEveryOptions } from '../base_callbacks';
16import { TensorOrArrayOrMap } from '../types';
17import { Dataset, LazyIterator } from './dataset_stub';
18import { ClassWeight, ClassWeightMap } from './training_utils';
19/**
20 * Interface for configuring model training based on a dataset object.
21 */
22export interface ModelFitDatasetArgs<T> {
23 /**
24 * (Optional) Total number of steps (batches of samples) before
25 * declaring one epoch finished and starting the next epoch. It should
26 * typically be equal to the number of samples of your dataset divided by
27 * the batch size, so that a `fitDataset()` call can utilize the entire dataset.
28 * If it is not provided, the `done` return value of `iterator.next()` is used
29 * as the signal to finish an epoch.
30 */
31 batchesPerEpoch?: number;
32 /**
33 * Integer number of times to iterate over the training dataset.
34 */
35 epochs: number;
36 /**
37 * Verbosity level.
38 *
39 * Expected to be 0, 1, or 2. Default: 1.
40 *
41 * 0 - No printed message during fit() call.
42 * 1 - In Node.js (tfjs-node), prints the progress bar, together with
43 * real-time updates of loss and metric values and training speed.
44 * In the browser: no action. This is the default.
45 * 2 - Not implemented yet.
46 */
47 verbose?: ModelLoggingVerbosity;
48 /**
49 * List of callbacks to be called during training.
50 * Can have one or more of the following callbacks:
51 * - `onTrainBegin(logs)`: called when training starts.
52 * - `onTrainEnd(logs)`: called when training ends.
53 * - `onEpochBegin(epoch, logs)`: called at the start of every epoch.
54 * - `onEpochEnd(epoch, logs)`: called at the end of every epoch.
55 * - `onBatchBegin(batch, logs)`: called at the start of every batch.
56 * - `onBatchEnd(batch, logs)`: called at the end of every batch.
57 * - `onYield(epoch, batch, logs)`: called every `yieldEvery` milliseconds
58 * with the current epoch, batch and logs. The logs are the same
59 * as in `onBatchEnd()`. Note that `onYield` can skip batches or
60 * epochs. See also docs for `yieldEvery` below.
61 */
62 callbacks?: BaseCallback[] | CustomCallbackArgs | CustomCallbackArgs[];
63 /**
64 * Data on which to evaluate the loss and any model
65 * metrics at the end of each epoch. The model will not be trained on this
66 * data. This could be any of the following:
67 *
68 * - An array `[xVal, yVal]`, where the two values may be `tf.Tensor`,
69 * an array of Tensors, or a map of string to Tensor.
70 * - Similarly, an array `[xVal, yVal, valSampleWeights]`
71 * (not implemented yet).
72 * - a `Dataset` object with elements of the form `{xs: xVal, ys: yVal}`,
73 * where `xs` and `ys` are the feature and label tensors, respectively.
74 *
75 * If `validationData` is an Array of Tensor objects, each `tf.Tensor` will be
76 * sliced into batches during validation, using the parameter
77 * `validationBatchSize` (which defaults to 32). The entirety of the
78 * `tf.Tensor` objects will be used in the validation.
79 *
80 * If `validationData` is a dataset object, and the `validationBatches`
81 * parameter is specified, the validation will use `validationBatches` batches
82 * drawn from the dataset object. If `validationBatches` parameter is not
83 * specified, the validation will stop when the dataset is exhausted.
84 *
85 * The model will not be trained on this data.
86 */
87 validationData?: [
88 TensorOrArrayOrMap,
89 TensorOrArrayOrMap
90 ] | [TensorOrArrayOrMap, TensorOrArrayOrMap, TensorOrArrayOrMap] | Dataset<T>;
91 /**
92 * Optional batch size for validation.
93 *
94 * Used only if `validationData` is an array of `tf.Tensor` objects, i.e., not
95 * a dataset object.
96 *
97 * If not specified, its value defaults to 32.
98 */
99 validationBatchSize?: number;
100 /**
101 * (Optional) Only relevant if `validationData` is specified and is a dataset
102 * object.
103 *
104 * Total number of batches of samples to draw from `validationData` for
105 * validation purpose before stopping at the end of every epoch. If not
106 * specified, `evaluateDataset` will use `iterator.next().done` as the signal to
107 * stop validation.
108 */
109 validationBatches?: number;
110 /**
111 * Configures the frequency of yielding the main thread to other tasks.
112 *
113 * In the browser environment, yielding the main thread can improve the
114 * responsiveness of the page during training. In the Node.js environment,
115 * it can ensure tasks queued in the event loop can be handled in a timely
116 * manner.
117 *
118 * The value can be one of the following:
119 * - `'auto'`: The yielding happens at a certain frame rate (currently set
120 * at 125ms). This is the default.
121 * - `'batch'`: yield every batch.
122 * - `'epoch'`: yield every epoch.
123 * - a `number`: Will yield every `number` milliseconds.
124 * - `'never'`: never yield. (But yielding can still happen through `await
125 * nextFrame()` calls in custom callbacks.)
126 */
127 yieldEvery?: YieldEveryOptions;
128 /**
129 * Epoch at which to start training (useful for resuming a previous training
130 * run). When this is used, `epochs` is the index of the "final epoch".
131 * The model is not trained for a number of iterations given by `epochs`,
132 * but merely until the epoch of index `epochs` is reached.
133 */
134 initialEpoch?: number;
135 /**
136 * Optional object mapping class indices (integers) to
137 * a weight (float) to apply to the model's loss for the samples from this
138 * class during training. This can be useful to tell the model to "pay more
139 * attention" to samples from an under-represented class.
140 *
141 * If the model has multiple outputs, a class weight can be specified for
142 * each of the outputs by setting this field to an array of weight objects
143 * or an object that maps model output names (e.g., `model.outputNames[0]`)
144 * to weight objects.
145 */
146 classWeight?: ClassWeight | ClassWeight[] | ClassWeightMap;
147}
/**
 * A pair of values keyed `xs` and `ys`, each a `TensorOrArrayOrMap`.
 *
 * NOTE(review): presumably the element shape expected from a `Dataset` used
 * for training or validation; the `validationData` docs describe dataset
 * elements of the form `{xs: xVal, ys: yVal}`, with `xs` as the features and
 * `ys` as the labels. Confirm against the implementation.
 */
148export interface FitDatasetElement {
149 xs: TensorOrArrayOrMap;
150 ys: TensorOrArrayOrMap;
151}
152/**
153 * Interface for configuring model evaluation based on a dataset object.
154 */
155export interface ModelEvaluateDatasetArgs {
156 /**
157 * (Optional) Number of batches to draw from the dataset object before ending
158 * the evaluation. NOTE(review): when omitted, evaluation presumably runs until the dataset is exhausted; confirm.
159 */
160 batches?: number;
161 /**
162 * (Optional) Verbosity mode; accepted values are those of `ModelLoggingVerbosity`.
163 */
164 verbose?: ModelLoggingVerbosity;
165}
/**
 * Trains `model` using data from `dataset`, configured by `args`.
 *
 * This is a declaration only; the implementation lives outside this file.
 *
 * @param model The model to train. Typed as `any` in this declaration.
 *   NOTE(review): presumably a tfjs-layers model instance; confirm.
 * @param dataset The dataset object that supplies the training data.
 * @param args Training configuration; see `ModelFitDatasetArgs`.
 * @returns A promise that resolves to a `History` object.
 */
166export declare function fitDataset<T>(model: any, dataset: Dataset<T>, args: ModelFitDatasetArgs<T>): Promise<History>;
/**
 * Evaluates `model` using data from `dataset`, configured by `args`.
 *
 * This is a declaration only; the implementation lives outside this file.
 *
 * @param model The model to evaluate. Typed as `any` in this declaration.
 *   NOTE(review): presumably a tfjs-layers model instance; confirm.
 * @param dataset Either a `Dataset` or a `LazyIterator` that supplies the
 *   evaluation data.
 * @param args Evaluation configuration; see `ModelEvaluateDatasetArgs`.
 * @returns A promise that resolves to a single `tfc.Scalar` or an array of
 *   `tfc.Scalar`s. NOTE(review): presumably the loss and metric values;
 *   confirm.
 */
167export declare function evaluateDataset<T>(model: any, dataset: Dataset<T> | LazyIterator<T>, args: ModelEvaluateDatasetArgs): Promise<tfc.Scalar | tfc.Scalar[]>;