// NOTE(review): removed UNPKG page chrome ("UNPKG", "9 kB TypeScript View Raw")
// that was captured along with this file; it is not part of the source.
/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-layers/dist/engine/training_tensors" />
import { Tensor, Tensor1D } from '@tensorflow/tfjs-core';
import { BaseCallback, CustomCallbackArgs, History, ModelLoggingVerbosity, YieldEveryOptions } from '../base_callbacks';
import { ClassWeight, ClassWeightMap } from './training_utils';
/**
 * Configuration interface for model training based on data as `tf.Tensor`s.
 */
export interface ModelFitArgs {
    /**
     * Number of samples per gradient update. If unspecified, it
     * will default to 32.
     */
    batchSize?: number;
    /**
     * Integer number of times to iterate over the training data arrays.
     */
    epochs?: number;
    /**
     * Verbosity level.
     *
     * Expected to be 0, 1, or 2. Default: 1.
     *
     * 0 - No printed message during fit() call.
     * 1 - In Node.js (tfjs-node), prints the progress bar, together with
     *     real-time updates of loss and metric values and training speed.
     *     In the browser: no action. This is the default.
     * 2 - Not implemented yet.
     */
    verbose?: ModelLoggingVerbosity;
    /**
     * List of callbacks to be called during training.
     * Can have one or more of the following callbacks:
     *   - `onTrainBegin(logs)`: called when training starts.
     *   - `onTrainEnd(logs)`: called when training ends.
     *   - `onEpochBegin(epoch, logs)`: called at the start of every epoch.
     *   - `onEpochEnd(epoch, logs)`: called at the end of every epoch.
     *   - `onBatchBegin(batch, logs)`: called at the start of every batch.
     *   - `onBatchEnd(batch, logs)`: called at the end of every batch.
     *   - `onYield(epoch, batch, logs)`: called every `yieldEvery` milliseconds
     *     with the current epoch, batch and logs. The logs are the same
     *     as in `onBatchEnd()`. Note that `onYield` can skip batches or
     *     epochs. See also docs for `yieldEvery` below.
     */
    callbacks?: BaseCallback[] | CustomCallbackArgs | CustomCallbackArgs[];
    /**
     * Float between 0 and 1: fraction of the training data
     * to be used as validation data. The model will set apart this fraction of
     * the training data, will not train on it, and will evaluate the loss and
     * any model metrics on this data at the end of each epoch.
     * The validation data is selected from the last samples in the `x` and `y`
     * data provided, before shuffling.
     */
    validationSplit?: number;
    /**
     * Data on which to evaluate the loss and any model
     * metrics at the end of each epoch. The model will not be trained on this
     * data. This could be a tuple [xVal, yVal] or a tuple [xVal, yVal,
     * valSampleWeights].
     * `validationData` will override `validationSplit`.
     */
    validationData?: [Tensor | Tensor[], Tensor | Tensor[]] | [Tensor | Tensor[], Tensor | Tensor[], Tensor | Tensor[]];
    /**
     * Whether to shuffle the training data before each epoch. Has
     * no effect when `stepsPerEpoch` is not `null`.
     */
    shuffle?: boolean;
    /**
     * Optional object mapping class indices (integers) to
     * a weight (float) to apply to the model's loss for the samples from this
     * class during training. This can be useful to tell the model to "pay more
     * attention" to samples from an under-represented class.
     *
     * If the model has multiple outputs, a class weight can be specified for
     * each of the outputs by setting this field to an array of weight objects
     * or an object that maps model output names (e.g., `model.outputNames[0]`)
     * to weight objects.
     */
    classWeight?: ClassWeight | ClassWeight[] | ClassWeightMap;
    /**
     * Optional array of the same length as x, containing
     * weights to apply to the model's loss for each sample. In the case of
     * temporal data, you can pass a 2D array with shape (samples,
     * sequenceLength), to apply a different weight to every timestep of every
     * sample. In this case you should make sure to specify
     * sampleWeightMode="temporal" in compile().
     */
    sampleWeight?: Tensor;
    /**
     * Epoch at which to start training (useful for resuming a previous training
     * run). When this is used, `epochs` is the index of the "final epoch".
     * The model is not trained for a number of iterations given by `epochs`,
     * but merely until the epoch of index `epochs` is reached.
     */
    initialEpoch?: number;
    /**
     * Total number of steps (batches of samples) before
     * declaring one epoch finished and starting the next epoch. When training
     * with Input Tensors such as TensorFlow data tensors, the default `null` is
     * equal to the number of unique samples in your dataset divided by the
     * batch size, or 1 if that cannot be determined.
     */
    stepsPerEpoch?: number;
    /**
     * Only relevant if `stepsPerEpoch` is specified. Total number of steps
     * (batches of samples) to validate before stopping.
     */
    validationSteps?: number;
    /**
     * Configures the frequency of yielding the main thread to other tasks.
     *
     * In the browser environment, yielding the main thread can improve the
     * responsiveness of the page during training. In the Node.js environment,
     * it can ensure tasks queued in the event loop can be handled in a timely
     * manner.
     *
     * The value can be one of the following:
     *   - `'auto'`: The yielding happens at a certain frame rate (currently set
     *     at 125ms). This is the default.
     *   - `'batch'`: yield every batch.
     *   - `'epoch'`: yield every epoch.
     *   - any `number`: yield every `number` milliseconds.
     *   - `'never'`: never yield. (yielding can still happen through `await
     *     nextFrame()` calls in custom callbacks.)
     */
    yieldEvery?: YieldEveryOptions;
}
/**
 * Validates a batch-size value before it is used for training or evaluation.
 *
 * @param batchSize The batch size to validate.
 * @throws Presumably if `batchSize` is not a positive integer — the
 *   implementation is not visible in this declaration file; confirm the
 *   exact validation rule and error type in `training_tensors.ts`.
 */
export declare function checkBatchSize(batchSize: number): void;
/**
 * Slice a Tensor or an Array of Tensors, by start and stop indices.
 *
 * Porting Note: The `_slice_arrays` function in PyKeras is covered by this
 *   function and `sliceArraysByIndices()` together.
 *
 * @param arrays The input `tf.Tensor` or `Array` of `tf.Tensor`s to slice.
 * @param start The starting index (inclusive).
 * @param stop The stopping index (exclusive).
 * @returns The result of the slicing. If `arrays` is an `Array` of
 *   `tf.Tensor`s, the slicing will be applied to all elements of the `Array`
 *   in the same way.
 */
export declare function sliceArrays(arrays: Tensor | Tensor[], start: number, stop: number): Tensor | Tensor[];
/**
 * Slice a Tensor or an Array of Tensors, by random-order indices.
 *
 * Porting Note: The `_slice_arrays` function in PyKeras is covered by this
 *   function and `sliceArrays()` together.
 *
 * @param arrays The input `tf.Tensor` or `Array` of `tf.Tensor`s to slice.
 *   If an `Array` of `tf.Tensor`s, all `tf.Tensor`s will be sliced in the
 *   same fashion.
 * @param indices The indices to use for slicing along the first (batch)
 *   dimension.
 * @returns Result(s) of the slicing.
 */
export declare function sliceArraysByIndices(arrays: Tensor | Tensor[], indices: Tensor1D): Tensor | Tensor[];
/**
 * Returns a list of batch indices (tuples of indices).
 *
 * @param size Integer, total size of the data to slice into batches.
 * @param batchSize Integer, batch size.
 * @returns An Array of [batchStart, batchEnd] tuples. batchStart is
 *   inclusive; batchEnd is exclusive. I.e., each batch consists of indices x
 *   that satisfy batchStart <= x < batchEnd.
 */
export declare function makeBatches(size: number, batchSize: number): Array<[number, number]>;
/**
 * Trains a model on `tf.Tensor`-based input and target data.
 *
 * @param model The model to train. Typed `any` in this declaration —
 *   presumably a `LayersModel`, likely left untyped to avoid a circular
 *   import; confirm against the implementation in `training_tensors.ts`.
 * @param x Input (feature) data: a single `tf.Tensor`, an `Array` of
 *   `tf.Tensor`s (for models with multiple inputs), or a map from input
 *   name to `tf.Tensor`.
 * @param y Target (label) data, in the same forms accepted for `x`.
 * @param args Optional training configuration (see `ModelFitArgs`).
 * @returns A `History` object recording loss and metric values over the
 *   course of training.
 */
export declare function fitTensors(model: any, x: Tensor | Tensor[] | {
    [inputName: string]: Tensor;
}, y: Tensor | Tensor[] | {
    [inputName: string]: Tensor;
}, args?: ModelFitArgs): Promise<History>;
/**
 * Ensure tensors all have a rank of at least 2.
 *
 * If a tensor has a rank of 1, it is dimension-expanded to rank 2.
 * If any tensor has a rank of 0 (i.e., is a scalar), an error will be thrown.
 *
 * @param tensors A single `tf.Tensor` or an `Array` of `tf.Tensor`s to check.
 * @returns An `Array` of `tf.Tensor`s, each of rank >= 2 (a single input
 *   tensor is returned wrapped in an `Array`, per the return type).
 * @throws If any input tensor is a scalar (rank 0).
 */
export declare function ensureTensorsRank2OrHigher(tensors: Tensor | Tensor[]): Tensor[];
/**
 * Compare a set of tensors with a reference (old) set, dispose the ones
 * in the new set that are not present in the reference set.
 *
 * This method is used for memory cleanup during calls such as
 * LayersModel.fit().
 *
 * @param tensors New set which may contain Tensors not present in
 *   `refTensors`.
 * @param refTensors Reference Tensor set.
 */
export declare function disposeNewTensors(tensors: Tensor | Tensor[] | {
    [inputName: string]: Tensor;
}, refTensors: Tensor | Tensor[] | {
    [inputName: string]: Tensor;
}): void;