UNPKG

20 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
/* Original source: keras/callbacks.py */
11import { BaseCallback } from './base_callbacks';
12import { LayersModel } from './engine/training';
13import { NotImplementedError } from './errors';
14import { resolveScalarsInLogs } from './logs';
/**
 * Callback that keeps a reference to the `LayersModel` it is attached to.
 *
 * Subclasses can read `this.model` once `setModel` has been called.
 */
export class Callback extends BaseCallback {
  constructor(...forwarded) {
    super(...forwarded);
    /**
     * The `LayersModel` this callback is attached to; `null` until
     * `setModel` is called.
     */
    this.model = null;
  }

  /**
   * Attaches this callback to a model.
   *
   * @param model The model being trained; must be a `LayersModel`.
   * @throws Error if `model` is not an instance of `LayersModel`.
   */
  setModel(model) {
    const isLayersModel = model instanceof LayersModel;
    if (!isLayersModel) {
      throw new Error('model must be a LayersModel, not some other Container');
    }
    this.model = model;
  }
}
/**
 * Comparator for metrics where a smaller value is better.
 *
 * @param candidate Value being tested.
 * @param reference Value to compare against.
 * @returns `true` iff `candidate` is strictly less than `reference`.
 */
function less(candidate, reference) {
  const isSmaller = candidate < reference;
  return isSmaller;
}
/**
 * Comparator for metrics where a larger value is better.
 *
 * @param candidate Value being tested.
 * @param reference Value to compare against.
 * @returns `true` iff `candidate` is strictly greater than `reference`.
 */
function greater(candidate, reference) {
  const isLarger = candidate > reference;
  return isLarger;
}
/**
 * A Callback that stops training when a monitored quantity has stopped
 * improving.
 *
 * At the end of every epoch the value stored in the logs under `monitor` is
 * compared with the best value seen so far. An epoch counts as an improvement
 * only if it beats the best value by more than `|minDelta|`, in the direction
 * chosen by `mode`. Once the number of consecutive non-improving epochs is at
 * least `patience`, `model.stopTraining` is set to `true`.
 */
export class EarlyStopping extends Callback {
  /**
   * @param args Optional configuration object. Recognized fields (falsy values
   *   fall back to the defaults):
   *   - `monitor`: log key to watch (default `'val_loss'`).
   *   - `minDelta`: minimum change that counts as an improvement; its absolute
   *     value is used (default `0`).
   *   - `patience`: non-improving epochs tolerated before stopping (default `0`).
   *   - `verbose`: if truthy, a message is logged when training is stopped early.
   *   - `mode`: `'auto'`, `'min'` or `'max'` (default `'auto'`). Any other value
   *     triggers a warning and falls back to `'auto'`.
   *   - `baseline`: if not null, used as the initial best value.
   *   - `restoreBestWeights`: not supported.
   * @throws NotImplementedError if `restoreBestWeights` is truthy.
   */
  constructor(args) {
    super();
    const options = args == null ? {} : args;
    if (options.restoreBestWeights) {
      throw new NotImplementedError('restoreBestWeights = True is not implemented in EarlyStopping yet.');
    }
    this.monitor = options.monitor || 'val_loss';
    this.minDelta = Math.abs(options.minDelta || 0);
    this.patience = options.patience || 0;
    this.verbose = options.verbose || 0;
    this.mode = options.mode || 'auto';
    this.baseline = options.baseline;

    const validModes = ['auto', 'min', 'max'];
    if (!validModes.includes(this.mode)) {
      console.warn(`EarlyStopping mode '${this.mode}' is invalid. ` +
        `Falling back to mode 'auto'.`);
      this.mode = 'auto';
    }

    let comparator;
    switch (this.mode) {
      case 'min':
        comparator = less;
        break;
      case 'max':
        comparator = greater;
        break;
      default:
        // 'auto': metric names containing 'acc' are maximized; everything
        // else is minimized.
        comparator = this.monitor.indexOf('acc') !== -1 ? greater : less;
        break;
    }
    this.monitorFunc = comparator;

    // Store minDelta with a sign such that `current - minDelta` always moves
    // the candidate away from "better". An improvement must therefore exceed
    // |minDelta| in both min and max modes.
    if (comparator === less) {
      this.minDelta = -this.minDelta;
    }
  }

  /**
   * Resets the patience counter and the best value before training starts.
   * The best value starts at `baseline` when one was given. Otherwise it is
   * the worst value possible for the chosen direction.
   */
  async onTrainBegin(logs) {
    this.wait = 0;
    this.stoppedEpoch = 0;
    const worstPossible = this.monitorFunc === less ? Infinity : -Infinity;
    this.best = this.baseline == null ? worstPossible : this.baseline;
  }

  /**
   * Updates the best value and patience counter from this epoch's logs, and
   * requests a stop when patience is exhausted. Does nothing if the monitored
   * value is missing from the logs.
   *
   * @param epoch Index of the epoch that just finished.
   * @param logs Epoch-end logs. They are passed through `resolveScalarsInLogs`
   *   (from `./logs`) before being read.
   */
  async onEpochEnd(epoch, logs) {
    await resolveScalarsInLogs(logs);
    const current = this.getMonitorValue(logs);
    if (current == null) {
      return;
    }
    const improved = this.monitorFunc(current - this.minDelta, this.best);
    if (improved) {
      this.best = current;
      this.wait = 0;
      // TODO(cais): Logic for restoreBestWeights.
      return;
    }
    this.wait++;
    if (this.wait >= this.patience) {
      this.stoppedEpoch = epoch;
      this.model.stopTraining = true;
    }
    // TODO(cais): Logic for restoreBestWeights.
  }

  /**
   * Logs a message when `verbose` is set and training was stopped at an
   * epoch index greater than 0. Stopping at epoch 0 is not reported.
   */
  async onTrainEnd(logs) {
    const stoppedEarly = this.stoppedEpoch > 0;
    if (stoppedEarly && this.verbose) {
      console.log(`Epoch ${this.stoppedEpoch}: early stopping.`);
    }
  }

  /**
   * Reads the monitored value from `logs`. Warns and lists the available
   * keys when the value is null or undefined.
   *
   * @param logs Log object; `null`/`undefined` is treated as empty.
   * @returns The value stored under `this.monitor` (possibly undefined).
   */
  getMonitorValue(logs) {
    const available = logs == null ? {} : logs;
    const value = available[this.monitor];
    if (value == null) {
      console.warn(`Metric for EarlyStopping ${this.monitor} is not available. ` +
        `Available metrics are: ${Object.keys(available)}`);
    }
    return value;
  }
}
/**
 * Factory function for a Callback that stops training when a monitored
 * quantity has stopped improving.
 *
 * Early stopping is a form of regularization that protects the model against
 * overfitting.
 *
 * The following example, based on fake data, shows how this callback can be
 * used during `tf.LayersModel.fit()`:
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.dense({
 *   units: 3,
 *   activation: 'softmax',
 *   kernelInitializer: 'ones',
 *   inputShape: [2]
 * }));
 * const xs = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 * const ys = tf.tensor2d([[1, 0, 0], [0, 1, 0]], [2, 3]);
 * const xsVal = tf.tensor2d([4, 3, 2, 1], [2, 2]);
 * const ysVal = tf.tensor2d([[0, 0, 1], [0, 1, 0]], [2, 3]);
 * model.compile(
 *     {loss: 'categoricalCrossentropy', optimizer: 'sgd', metrics: ['acc']});
 *
 * // Without the EarlyStopping callback, the val_acc value would be:
 * //   0.5, 0.5, 0.5, 0.5, ...
 * // With val_acc being monitored (and the default patience of 0), training
 * // should stop after the 2nd epoch.
 * const history = await model.fit(xs, ys, {
 *   epochs: 10,
 *   validationData: [xsVal, ysVal],
 *   callbacks: tf.callbacks.earlyStopping({monitor: 'val_acc'})
 * });
 *
 * // Expect to see a length-2 array.
 * console.log(history.history.val_acc);
 * ```
 *
 * @param args Optional configuration forwarded unchanged to the
 *   `EarlyStopping` constructor (`monitor`, `minDelta`, `patience`, `verbose`,
 *   `mode`, `baseline`, `restoreBestWeights`).
 * @returns A new `EarlyStopping` instance.
 *
 * @doc {
 *   heading: 'Callbacks',
 *   namespace: 'callbacks'
 * }
 */
export function earlyStopping(args) {
  const callback = new EarlyStopping(args);
  return callback;
}
/** Namespace object grouping the callback factories exported by this module. */
export const callbacks = {
  earlyStopping: earlyStopping,
};
171//# sourceMappingURL=data:application/json;base64,
\No newline at end of file