1 | /**
|
2 | * @license
|
3 | * Copyright 2018 Google LLC
|
4 | *
|
5 | * Use of this source code is governed by an MIT-style
|
6 | * license that can be found in the LICENSE file or at
|
7 | * https://opensource.org/licenses/MIT.
|
8 | * =============================================================================
|
9 | */
|
10 | /// <amd-module name="@tensorflow/tfjs-layers/dist/engine/training_tensors" />
|
11 | import { Tensor, Tensor1D } from '@tensorflow/tfjs-core';
|
12 | import { BaseCallback, CustomCallbackArgs, History, ModelLoggingVerbosity, YieldEveryOptions } from '../base_callbacks';
|
13 | import { ClassWeight, ClassWeightMap } from './training_utils';
|
/**
 * Interface for configuring model training based on data as `tf.Tensor`s.
 */
export interface ModelFitArgs {
    /**
     * Number of samples per gradient update. If unspecified, it
     * will default to 32.
     */
    batchSize?: number;
    /**
     * Integer number of times to iterate over the training data arrays.
     */
    epochs?: number;
    /**
     * Verbosity level.
     *
     * Expected to be 0, 1, or 2. Default: 1.
     *
     * 0 - No printed message during fit() call.
     * 1 - In Node.js (tfjs-node), prints the progress bar, together with
     *     real-time updates of loss and metric values and training speed.
     *     In the browser: no action. This is the default.
     * 2 - Not implemented yet.
     */
    verbose?: ModelLoggingVerbosity;
    /**
     * List of callbacks to be called during training.
     * Can have one or more of the following callbacks:
     *   - `onTrainBegin(logs)`: called when training starts.
     *   - `onTrainEnd(logs)`: called when training ends.
     *   - `onEpochBegin(epoch, logs)`: called at the start of every epoch.
     *   - `onEpochEnd(epoch, logs)`: called at the end of every epoch.
     *   - `onBatchBegin(batch, logs)`: called at the start of every batch.
     *   - `onBatchEnd(batch, logs)`: called at the end of every batch.
     *   - `onYield(epoch, batch, logs)`: called every `yieldEvery` milliseconds
     *     with the current epoch, batch and logs. The logs are the same
     *     as in `onBatchEnd()`. Note that `onYield` can skip batches or
     *     epochs. See also docs for `yieldEvery` below.
     */
    callbacks?: BaseCallback[] | CustomCallbackArgs | CustomCallbackArgs[];
    /**
     * Float between 0 and 1: fraction of the training data
     * to be used as validation data. The model will set apart this fraction of
     * the training data, will not train on it, and will evaluate the loss and
     * any model metrics on this data at the end of each epoch.
     * The validation data is selected from the last samples in the `x` and `y`
     * data provided, before shuffling.
     */
    validationSplit?: number;
    /**
     * Data on which to evaluate the loss and any model
     * metrics at the end of each epoch. The model will not be trained on this
     * data. This could be a tuple [xVal, yVal] or a tuple [xVal, yVal,
     * valSampleWeights].
     * `validationData` will override `validationSplit`.
     */
    validationData?: [Tensor | Tensor[], Tensor | Tensor[]] | [Tensor | Tensor[], Tensor | Tensor[], Tensor | Tensor[]];
    /**
     * Whether to shuffle the training data before each epoch. Has
     * no effect when `stepsPerEpoch` is not `null`.
     */
    shuffle?: boolean;
    /**
     * Optional object mapping class indices (integers) to
     * a weight (float) to apply to the model's loss for the samples from this
     * class during training. This can be useful to tell the model to "pay more
     * attention" to samples from an under-represented class.
     *
     * If the model has multiple outputs, a class weight can be specified for
     * each of the outputs by setting this field to an array of weight objects
     * or an object that maps model output names (e.g., `model.outputNames[0]`)
     * to weight objects.
     */
    classWeight?: ClassWeight | ClassWeight[] | ClassWeightMap;
    /**
     * Optional array of the same length as x, containing
     * weights to apply to the model's loss for each sample. In the case of
     * temporal data, you can pass a 2D array with shape (samples,
     * sequenceLength), to apply a different weight to every timestep of every
     * sample. In this case you should make sure to specify
     * sampleWeightMode="temporal" in compile().
     */
    sampleWeight?: Tensor;
    /**
     * Epoch at which to start training (useful for resuming a previous training
     * run). When this is used, `epochs` is the index of the "final epoch".
     * The model is not trained for a number of iterations given by `epochs`,
     * but merely until the epoch of index `epochs` is reached.
     */
    initialEpoch?: number;
    /**
     * Total number of steps (batches of samples) before
     * declaring one epoch finished and starting the next epoch. When training
     * with Input Tensors such as TensorFlow data tensors, the default `null` is
     * equal to the number of unique samples in your dataset divided by the
     * batch size, or 1 if that cannot be determined.
     */
    stepsPerEpoch?: number;
    /**
     * Only relevant if `stepsPerEpoch` is specified. Total number of steps
     * (batches of samples) to validate before stopping.
     */
    validationSteps?: number;
    /**
     * Configures the frequency of yielding the main thread to other tasks.
     *
     * In the browser environment, yielding the main thread can improve the
     * responsiveness of the page during training. In the Node.js environment,
     * it can ensure tasks queued in the event loop can be handled in a timely
     * manner.
     *
     * The value can be one of the following:
     *   - `'auto'`: The yielding happens at a certain frame rate (currently set
     *     at 125ms). This is the default.
     *   - `'batch'`: yield every batch.
     *   - `'epoch'`: yield every epoch.
     *   - any `number`: yield every `number` milliseconds.
     *   - `'never'`: never yield. (yielding can still happen through `await
     *     nextFrame()` calls in custom callbacks.)
     */
    yieldEvery?: YieldEveryOptions;
}
|
/**
 * Validate a batch size value before it is used by the training loop.
 *
 * @param batchSize The batch size to check. Returns `void`, so invalid
 *   values are presumably rejected by throwing — the implementation is not
 *   visible in this declaration file; confirm against the source.
 */
export declare function checkBatchSize(batchSize: number): void;
|
/**
 * Slice a Tensor or an Array of Tensors, by start and stop indices.
 *
 * Porting Note: The `_slice_arrays` function in PyKeras is covered by this
 * function and `sliceArraysByIndices()` together.
 *
 * @param arrays The input `tf.Tensor` or `Array` of `tf.Tensor`s to slice.
 * @param start The starting index (inclusive).
 * @param stop The stopping index (exclusive).
 * @returns The result of the slicing. If `arrays` is an `Array` of
 *   `tf.Tensor`s, the slicing will be applied to all elements of the `Array`
 *   in the same way.
 */
export declare function sliceArrays(arrays: Tensor | Tensor[], start: number, stop: number): Tensor | Tensor[];
|
/**
 * Slice a Tensor or an Array of Tensors, by random-order indices.
 *
 * Porting Note: The `_slice_arrays` function in PyKeras is covered by this
 * function and `sliceArrays()` together.
 *
 * @param arrays The input `tf.Tensor` or `Array` of `tf.Tensor`s to slice.
 *   If an `Array` of `tf.Tensor`s, all `tf.Tensor`s will be sliced in the
 *   same fashion.
 * @param indices A rank-1 `tf.Tensor` of indices to use for slicing along
 *   the first (batch) dimension.
 * @returns Result(s) of the slicing: a single `tf.Tensor` if `arrays` was a
 *   single `tf.Tensor`, else an `Array` of `tf.Tensor`s.
 */
export declare function sliceArraysByIndices(arrays: Tensor | Tensor[], indices: Tensor1D): Tensor | Tensor[];
|
/**
 * Returns a list of batch indices (tuples of indices).
 *
 * @param size Integer, total size of the data to slice into batches.
 * @param batchSize Integer, batch size.
 * @returns An Array of [batchStart, batchEnd] tuples. batchStart is
 *   inclusive; batchEnd is exclusive. I.e., each batch consists of indices x
 *   that satisfy batchStart <= x < batchEnd.
 */
export declare function makeBatches(size: number, batchSize: number): Array<[number, number]>;
|
/**
 * Train a model on `tf.Tensor`-based data, per the options in `ModelFitArgs`.
 *
 * NOTE(review): `model` is typed `any` in this declaration; presumably it is
 * the LayersModel being trained (likely `any` to avoid a circular import) —
 * confirm against the implementation source.
 *
 * @param model The model to train.
 * @param x Input data: a single `tf.Tensor`, an `Array` of `tf.Tensor`s, or
 *   a map from input name to `tf.Tensor`.
 * @param y Target data, in the same forms accepted for `x`.
 * @param args Optional configuration for the training run (batch size,
 *   epochs, callbacks, validation, etc.). See `ModelFitArgs`.
 * @returns A `Promise` resolving to a `History` object.
 */
export declare function fitTensors(model: any, x: Tensor | Tensor[] | {
    [inputName: string]: Tensor;
}, y: Tensor | Tensor[] | {
    [inputName: string]: Tensor;
}, args?: ModelFitArgs): Promise<History>;
|
/**
 * Ensure tensors all have a rank of at least 2.
 *
 * If a tensor has a rank of 1, it is dimension-expanded to rank 2.
 * If any tensor has a rank of 0 (i.e., is a scalar), an error will be thrown.
 *
 * @param tensors A single `tf.Tensor` or an `Array` of `tf.Tensor`s.
 * @returns An `Array` of `tf.Tensor`s, each of rank >= 2 (always an `Array`,
 *   even when a single `tf.Tensor` is passed in).
 */
export declare function ensureTensorsRank2OrHigher(tensors: Tensor | Tensor[]): Tensor[];
|
/**
 * Compare a set of tensors with a reference (old) set, discard the ones
 * in the new set that are not present in the reference set.
 *
 * This method is used for memory cleanup during calls such as
 * LayersModel.fit().
 *
 * @param tensors New set which may contain Tensors not present in
 *   `refTensors`.
 * @param refTensors Reference Tensor set.
 */
export declare function disposeNewTensors(tensors: Tensor | Tensor[] | {
    [inputName: string]: Tensor;
}, refTensors: Tensor | Tensor[] | {
    [inputName: string]: Tensor;
}): void;
|