/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
///
import { BaseCallback } from './base_callbacks';
import { Container } from './engine/container';
import { LayersModel } from './engine/training';
import { Logs } from './logs';
/**
* Abstract callback base class that extends `BaseCallback` and is bound to a
* `LayersModel`.
*
* Concrete callbacks in this file (e.g. `EarlyStopping`) derive from it.
*/
export declare abstract class Callback extends BaseCallback {
/**
* Reference to the model being trained (described upstream as an instance of
* `keras.models.Model`).
*/
model: LayersModel;
/**
* Associates a model with this callback.
*
* NOTE(review): the parameter is typed as the more general `Container`, while
* the `model` property is a `LayersModel`. The implementation is not visible
* in this declaration; presumably the argument is stored in `model` -
* confirm against the implementation.
*
* @param model The model (container) to attach to this callback.
*/
setModel(model: Container): void;
}
/**
* Optional configuration accepted by the `EarlyStopping` constructor and the
* `earlyStopping()` factory function. Every field may be omitted.
*/
export interface EarlyStoppingCallbackArgs {
/**
* Quantity to be monitored.
*
* Defaults to 'val_loss'.
*/
monitor?: string;
/**
* Minimum change in the monitored quantity to qualify as improvement,
* i.e., an absolute change of less than `minDelta` will count as no
* improvement.
*
* Defaults to 0.
*/
minDelta?: number;
/**
* Number of epochs with no improvement after which training will be stopped.
*
* Defaults to 0.
*/
patience?: number;
/**
* Verbosity mode.
*
* NOTE(review): the accepted values and their meaning are not specified in
* this declaration - confirm against the implementation.
*/
verbose?: number;
/**
* Mode: one of 'min', 'max', and 'auto'.
* - In 'min' mode, training will be stopped when the quantity monitored has
* stopped decreasing.
* - In 'max' mode, training will be stopped when the quantity monitored has
* stopped increasing.
* - In 'auto' mode, the direction is inferred automatically from the name of
* the monitored quantity.
*
* Defaults to 'auto'.
*/
mode?: 'auto' | 'min' | 'max';
/**
* Baseline value of the monitored quantity.
*
* If specified, training will be stopped if the model doesn't show
* improvement over the baseline.
*/
baseline?: number;
/**
* Whether to restore model weights from the epoch with the best value
* of the monitored quantity. If `false`, the model weights obtained at the
* last step of training are used.
*
* **`true` is not supported yet.**
*/
restoreBestWeights?: boolean;
}
/**
* A Callback that stops training when a monitored quantity has stopped
* improving.
*/
export declare class EarlyStopping extends Callback {
/*
* Configuration fields. Their names and types mirror the same-named options
* of `EarlyStoppingCallbackArgs`; the assignment itself happens in the
* constructor, whose body is not visible in this declaration. There is no
* field corresponding to `restoreBestWeights`.
*/
/** Name of the quantity being monitored. */
protected readonly monitor: string;
/** Minimum change in the monitored quantity that counts as an improvement. */
protected readonly minDelta: number;
/** Number of epochs with no improvement to tolerate before stopping. */
protected readonly patience: number;
/** Baseline value of the monitored quantity. */
protected readonly baseline: number;
/** Verbosity mode. */
protected readonly verbose: number;
/** Direction mode: 'auto', 'min' or 'max'. */
protected readonly mode: 'auto' | 'min' | 'max';
/**
* Comparison applied to a current and a previous value of the monitored
* quantity.
*
* NOTE(review): presumably returns `true` when `currVal` is an improvement
* over `prevVal` in the selected mode - confirm against the implementation.
*/
protected monitorFunc: (currVal: number, prevVal: number) => boolean;
/*
* Internal bookkeeping. Types are elided in the declaration output;
* NOTE(review): presumably the no-improvement epoch counter, the epoch at
* which training was stopped, and the best monitored value seen so far -
* confirm against the implementation.
*/
private wait;
private stoppedEpoch;
private best;
/**
* @param args Optional configuration; see `EarlyStoppingCallbackArgs` for the
* available options and their defaults.
*/
constructor(args?: EarlyStoppingCallbackArgs);
/**
* Asynchronous hook for the beginning of training.
*
* @param logs Optional logs object.
* @returns A promise that resolves with no value.
*/
onTrainBegin(logs?: Logs): Promise<void>;
/**
* Asynchronous hook for the end of an epoch.
*
* @param epoch Index of the epoch.
* @param logs Optional logs object for the epoch.
* @returns A promise that resolves with no value.
*/
onEpochEnd(epoch: number, logs?: Logs): Promise<void>;
/**
* Asynchronous hook for the end of training.
*
* @param logs Optional logs object.
* @returns A promise that resolves with no value.
*/
onTrainEnd(logs?: Logs): Promise<void>;
/** Private helper; its signature is elided in the declaration output. */
private getMonitorValue;
}
/**
* Factory function for a Callback that stops training when a monitored
* quantity has stopped improving.
*
* Early stopping is a type of regularization, and protects the model against
* overfitting.
*
* The following example based on fake data illustrates how this callback
* can be used during `tf.LayersModel.fit()`:
*
* ```js
* const model = tf.sequential();
* model.add(tf.layers.dense({
*   units: 3,
*   activation: 'softmax',
*   kernelInitializer: 'ones',
*   inputShape: [2]
* }));
* const xs = tf.tensor2d([1, 2, 3, 4], [2, 2]);
* const ys = tf.tensor2d([[1, 0, 0], [0, 1, 0]], [2, 3]);
* const xsVal = tf.tensor2d([4, 3, 2, 1], [2, 2]);
* const ysVal = tf.tensor2d([[0, 0, 1], [0, 1, 0]], [2, 3]);
* model.compile(
*     {loss: 'categoricalCrossentropy', optimizer: 'sgd', metrics: ['acc']});
*
* // Without the EarlyStopping callback, the val_acc value would be:
* //   0.5, 0.5, 0.5, 0.5, ...
* // With val_acc being monitored, training should stop after the 2nd epoch.
* const history = await model.fit(xs, ys, {
*   epochs: 10,
*   validationData: [xsVal, ysVal],
*   callbacks: tf.callbacks.earlyStopping({monitor: 'val_acc'})
* });
*
* // Expect to see a length-2 array.
* console.log(history.history.val_acc);
* ```
*
* @param args Optional configuration; see `EarlyStoppingCallbackArgs` for the
* available options and their defaults.
* @returns An `EarlyStopping` callback instance.
*
* @doc {
* heading: 'Callbacks',
* namespace: 'callbacks'
* }
*/
export declare function earlyStopping(args?: EarlyStoppingCallbackArgs): EarlyStopping;
/**
* Object grouping the callback factory functions declared in this file.
* It currently exposes only `earlyStopping`.
*/
export declare const callbacks: {
earlyStopping: typeof earlyStopping;
};