1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 | import { BaseCallback } from './base_callbacks';
|
12 | import { Container } from './engine/container';
|
13 | import { LayersModel } from './engine/training';
|
14 | import { Logs } from './logs';
|
/**
 * Abstract base class for callbacks that are attached to a `LayersModel`.
 *
 * Extends `BaseCallback` and declares the `model` property together with the
 * `setModel` hook used to attach a model to the callback.
 */
export declare abstract class Callback extends BaseCallback {
  /** The model this callback is associated with. */
  model: LayersModel;

  /**
   * Associates a model with this callback.
   *
   * NOTE(review): the parameter is typed as `Container` while the `model`
   * property is typed as `LayersModel`; presumably callers always pass a
   * `LayersModel` — confirm against the implementation.
   *
   * @param model The model to attach to this callback.
   */
  setModel(model: Container): void;
}
|
/**
 * Configuration accepted by the `EarlyStopping` callback and by the
 * `earlyStopping()` factory. Every field is optional.
 */
export interface EarlyStoppingCallbackArgs {
  /**
   * Name of the quantity to monitor, e.g. `'val_acc'`.
   */
  monitor?: string;

  /**
   * NOTE(review): presumably the minimum change in the monitored quantity
   * that counts as an improvement — confirm against the implementation.
   */
  minDelta?: number;

  /**
   * NOTE(review): presumably the number of epochs without improvement that
   * are tolerated before training is stopped — confirm against the
   * implementation.
   */
  patience?: number;

  /**
   * Verbosity level.
   */
  verbose?: number;

  /**
   * Comparison mode: one of `'auto'`, `'min'` or `'max'`.
   *
   * NOTE(review): presumably `'min'` treats a decrease of the monitored
   * quantity as improvement, `'max'` treats an increase as improvement, and
   * `'auto'` infers the direction from the name of the monitored quantity —
   * confirm against the implementation.
   */
  mode?: 'auto' | 'min' | 'max';

  /**
   * NOTE(review): presumably a baseline value that the monitored quantity
   * must improve upon — confirm against the implementation.
   */
  baseline?: number;

  /**
   * NOTE(review): presumably controls whether the weights from the best
   * epoch are restored when training stops; verify that the implementation
   * supports this option before relying on it.
   */
  restoreBestWeights?: boolean;
}
|
71 |
|
72 |
|
73 |
|
74 |
|
/**
 * Callback that stops training when a monitored quantity has stopped
 * improving. Instances are also returned by the `earlyStopping()` factory.
 */
export declare class EarlyStopping extends Callback {
  /** Name of the monitored quantity. */
  protected readonly monitor: string;
  /** See `EarlyStoppingCallbackArgs.minDelta`. */
  protected readonly minDelta: number;
  /** See `EarlyStoppingCallbackArgs.patience`. */
  protected readonly patience: number;
  /** See `EarlyStoppingCallbackArgs.baseline`. */
  protected readonly baseline: number;
  /** Verbosity level. */
  protected readonly verbose: number;
  /** Comparison mode: `'auto'`, `'min'` or `'max'`. */
  protected readonly mode: 'auto' | 'min' | 'max';
  /**
   * Comparison applied to a current and a previous value of the monitored
   * quantity.
   *
   * NOTE(review): presumably returns `true` when `currVal` is an improvement
   * over `prevVal` — confirm against the implementation.
   */
  protected monitorFunc: (currVal: number, prevVal: number) => boolean;
  // Internal bookkeeping; types are elided in the declaration output.
  private wait;
  private stoppedEpoch;
  private best;

  /**
   * @param args Optional early-stopping configuration.
   */
  constructor(args?: EarlyStoppingCallbackArgs);

  /**
   * Hook invoked when training begins.
   *
   * @param logs Optional logs supplied by the training loop.
   */
  onTrainBegin(logs?: Logs): Promise<void>;

  /**
   * Hook invoked at the end of each epoch.
   *
   * @param epoch Index of the epoch that just ended.
   * @param logs Optional logs supplied by the training loop.
   */
  onEpochEnd(epoch: number, logs?: Logs): Promise<void>;

  /**
   * Hook invoked when training ends.
   *
   * @param logs Optional logs supplied by the training loop.
   */
  onTrainEnd(logs?: Logs): Promise<void>;

  // Internal helper; its signature is elided in the declaration output.
  private getMonitorValue;
}
|
/**
 * Factory function for a Callback that stops training when a monitored
 * quantity has stopped improving.
 *
 * Early stopping is a type of regularization, and protects the model against
 * overfitting.
 *
 * The following example based on fake data illustrates how this callback
 * can be used during `tf.LayersModel.fit()`:
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.dense({
 *   units: 3,
 *   activation: 'softmax',
 *   kernelInitializer: 'ones',
 *   inputShape: [2]
 * }));
 * const xs = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 * const ys = tf.tensor2d([[1, 0, 0], [0, 1, 0]], [2, 3]);
 * const xsVal = tf.tensor2d([4, 3, 2, 1], [2, 2]);
 * const ysVal = tf.tensor2d([[0, 0, 1], [0, 1, 0]], [2, 3]);
 * model.compile(
 *     {loss: 'categoricalCrossentropy', optimizer: 'sgd', metrics: ['acc']});
 *
 * // Without the EarlyStopping callback, the val_acc value would be:
 * //   0.5, 0.5, 0.5, 0.5, ...
 * // With val_acc being monitored, training should stop after the 2nd epoch.
 * const history = await model.fit(xs, ys, {
 *   epochs: 10,
 *   validationData: [xsVal, ysVal],
 *   callbacks: tf.callbacks.earlyStopping({monitor: 'val_acc'})
 * });
 *
 * // Expect to see a length-2 array.
 * console.log(history.history.val_acc);
 * ```
 *
 * @param args Optional early-stopping configuration.
 * @returns A new `EarlyStopping` callback.
 *
 * @doc {
 *   heading: 'Callbacks',
 *   namespace: 'callbacks'
 * }
 */
export declare function earlyStopping(args?: EarlyStoppingCallbackArgs): EarlyStopping;
|
/**
 * Namespace object grouping the callback factory functions; exposes
 * `earlyStopping`.
 */
export declare const callbacks: {
  earlyStopping: typeof earlyStopping;
};
|