UNPKG

4.58 kBTypeScriptView Raw
1/**
2 * @license
3 * Copyright 2018 Google LLC
4 *
5 * Use of this source code is governed by an MIT-style
6 * license that can be found in the LICENSE file or at
7 * https://opensource.org/licenses/MIT.
8 * =============================================================================
9 */
10/// <amd-module name="@tensorflow/tfjs-layers/dist/callbacks" />
11import { BaseCallback } from './base_callbacks';
12import { Container } from './engine/container';
13import { LayersModel } from './engine/training';
14import { Logs } from './logs';
/**
 * Abstract base class for callbacks that hold a reference to the
 * `LayersModel` being trained.
 */
export declare abstract class Callback extends BaseCallback {
    /**
     * Reference to the model being trained. This is the TensorFlow.js
     * `LayersModel`, the counterpart of Keras' `keras.models.Model`.
     */
    model: LayersModel;
    /**
     * Associates a model with this callback.
     *
     * NOTE(review): the parameter is typed as the wider `Container`, while the
     * `model` field is typed as `LayersModel`. The implementation is not part
     * of this declaration file — confirm how non-`LayersModel` containers are
     * handled.
     */
    setModel(model: Container): void;
}
/**
 * Configuration accepted by the `EarlyStopping` callback.
 *
 * Every field is optional; the defaults noted below apply when a field is
 * omitted.
 */
export interface EarlyStoppingCallbackArgs {
    /**
     * Quantity to be monitored.
     *
     * Defaults to 'val_loss'.
     */
    monitor?: string;
    /**
     * Minimum change in the monitored quantity to qualify as an improvement,
     * i.e., an absolute change of less than `minDelta` will count as no
     * improvement.
     *
     * Defaults to 0.
     */
    minDelta?: number;
    /**
     * Number of epochs with no improvement after which training will be stopped.
     *
     * Defaults to 0.
     */
    patience?: number;
    /**
     * Verbosity mode.
     *
     * NOTE(review): the meaning of individual verbosity levels is not
     * documented in this file.
     */
    verbose?: number;
    /**
     * Mode: one of 'min', 'max', and 'auto'.
     * - In 'min' mode, training will be stopped when the quantity monitored has
     *   stopped decreasing.
     * - In 'max' mode, training will be stopped when the quantity monitored has
     *   stopped increasing.
     * - In 'auto' mode, the direction is inferred automatically from the name of
     *   the monitored quantity.
     *
     * Defaults to 'auto'.
     */
    mode?: 'auto' | 'min' | 'max';
    /**
     * Baseline value of the monitored quantity.
     *
     * If specified, training will be stopped if the model doesn't show
     * improvement over the baseline.
     */
    baseline?: number;
    /**
     * Whether to restore model weights from the epoch with the best value
     * of the monitored quantity. If `false`, the model weights obtained at the
     * last step of training are used.
     *
     * **`true` is not supported yet.**
     */
    restoreBestWeights?: boolean;
}
/**
 * Callback that halts training once a monitored quantity stops improving.
 *
 * Configured through `EarlyStoppingCallbackArgs`; the `earlyStopping()`
 * factory is an alternative to calling the constructor directly.
 */
export declare class EarlyStopping extends Callback {
    // --- Configuration (names match the fields of EarlyStoppingCallbackArgs) ---
    /** Name of the quantity being monitored. */
    protected readonly monitor: string;
    /** Minimum absolute change that is treated as an improvement. */
    protected readonly minDelta: number;
    /** Number of epochs without improvement tolerated before stopping. */
    protected readonly patience: number;
    /** Verbosity mode. */
    protected readonly verbose: number;
    /** Direction of improvement: 'auto', 'min', or 'max'. */
    protected readonly mode: 'auto' | 'min' | 'max';
    /** Baseline value of the monitored quantity. */
    protected readonly baseline: number;
    /**
     * Comparison predicate over a current and a previous value of the
     * monitored quantity. Presumably returns true when `currVal` is an
     * improvement on `prevVal` — confirm against the implementation.
     */
    protected monitorFunc: (currVal: number, prevVal: number) => boolean;

    // --- Internal training state (types are erased in this declaration) ---
    // Presumably: epochs waited without improvement, the epoch at which
    // training was stopped, and the best monitored value so far — TODO confirm.
    private wait;
    private stoppedEpoch;
    private best;

    /** @param args Optional configuration; see `EarlyStoppingCallbackArgs`. */
    constructor(args?: EarlyStoppingCallbackArgs);
    /** Training-start hook. */
    onTrainBegin(logs?: Logs): Promise<void>;
    /** Per-epoch hook, invoked with the epoch index and its logs. */
    onEpochEnd(epoch: number, logs?: Logs): Promise<void>;
    /** Training-end hook. */
    onTrainEnd(logs?: Logs): Promise<void>;
    private getMonitorValue;
}
/**
 * Factory function for a Callback that stops training when a monitored
 * quantity has stopped improving.
 *
 * Early stopping is a type of regularization, and protects the model against
 * overfitting.
 *
 * The following example, based on fake data, illustrates how this callback
 * can be used during `tf.LayersModel.fit()`:
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.dense({
 *   units: 3,
 *   activation: 'softmax',
 *   kernelInitializer: 'ones',
 *   inputShape: [2]
 * }));
 * const xs = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 * const ys = tf.tensor2d([[1, 0, 0], [0, 1, 0]], [2, 3]);
 * const xsVal = tf.tensor2d([4, 3, 2, 1], [2, 2]);
 * const ysVal = tf.tensor2d([[0, 0, 1], [0, 1, 0]], [2, 3]);
 * model.compile(
 *     {loss: 'categoricalCrossentropy', optimizer: 'sgd', metrics: ['acc']});
 *
 * // Without the EarlyStopping callback, the val_acc value would be:
 * //   0.5, 0.5, 0.5, 0.5, ...
 * // With val_acc being monitored, training should stop after the 2nd epoch.
 * const history = await model.fit(xs, ys, {
 *   epochs: 10,
 *   validationData: [xsVal, ysVal],
 *   callbacks: tf.callbacks.earlyStopping({monitor: 'val_acc'})
 * });
 *
 * // Expect to see a length-2 array.
 * console.log(history.history.val_acc);
 * ```
 *
 * @param args Optional configuration; see `EarlyStoppingCallbackArgs`.
 * @returns An `EarlyStopping` callback.
 *
 * @doc {
 *   heading: 'Callbacks',
 *   namespace: 'callbacks'
 * }
 */
export declare function earlyStopping(args?: EarlyStoppingCallbackArgs): EarlyStopping;
/**
 * Namespace object grouping the callback factory functions; presumably
 * surfaced to users as `tf.callbacks` (e.g. `tf.callbacks.earlyStopping(...)`).
 */
export declare const callbacks: {
    earlyStopping: typeof earlyStopping,
};