UNPKG

20 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2018 Google LLC
4 *
5 * Use of this source code is governed by an MIT-style
6 * license that can be found in the LICENSE file or at
7 * https://opensource.org/licenses/MIT.
8 * =============================================================================
9 */
10/* Original source: keras/callbacks.py */
11import { BaseCallback } from './base_callbacks';
12import { LayersModel } from './engine/training';
13import { NotImplementedError } from './errors';
14import { resolveScalarsInLogs } from './logs';
/**
 * Base class for Layers-API callbacks.
 *
 * Extends `BaseCallback` and restricts the attached model to a `LayersModel`.
 */
export class Callback extends BaseCallback {
  constructor(...args) {
    super(...args);
    /** The `LayersModel` being trained; `null` until `setModel()` is called. */
    this.model = null;
  }

  /**
   * Attaches the model that is being trained.
   *
   * @param model The model to attach; must be an instance of `LayersModel`.
   * @throws Error if `model` is any other kind of object (e.g. a plain
   *     Container).
   */
  setModel(model) {
    const isLayersModel = model instanceof LayersModel;
    if (isLayersModel) {
      this.model = model;
      return;
    }
    throw new Error('model must be a LayersModel, not some other Container');
  }
}
/**
 * Improvement test for "min"-style monitoring.
 *
 * @returns `true` iff `currVal` is strictly smaller than `prevVal`.
 */
function less(currVal, prevVal) {
  const isSmaller = currVal < prevVal;
  return isSmaller;
}
/**
 * Improvement test for "max"-style monitoring.
 *
 * @returns `true` iff `currVal` is strictly larger than `prevVal`.
 */
function greater(currVal, prevVal) {
  const isLarger = currVal > prevVal;
  return isLarger;
}
/**
 * A Callback that stops training when a monitored quantity has stopped
 * improving.
 *
 * After every epoch the monitored log entry is compared with the best value
 * seen so far. When the count of consecutive non-improving epochs reaches
 * `patience`, `model.stopTraining` is set to `true`.
 */
export class EarlyStopping extends Callback {
  /**
   * @param args Optional configuration; `null`/`undefined` means all defaults.
   *   - `monitor`: key in the epoch logs to watch. Defaults to 'val_loss'.
   *   - `minDelta`: minimum change that counts as an improvement; its sign is
   *     ignored. Defaults to 0.
   *   - `patience`: training stops once the number of consecutive
   *     non-improving epochs reaches this value (so 0 and 1 both stop on the
   *     first non-improving epoch). Defaults to 0.
   *   - `verbose`: if truthy, a message is logged at the end of training when
   *     this callback stopped it. Defaults to 0.
   *   - `mode`: 'min', 'max' or 'auto'. 'auto' behaves like 'max' when
   *     `monitor` contains the substring 'acc' and like 'min' otherwise. Any
   *     other value logs a warning and falls back to 'auto'.
   *   - `baseline`: if given, the initial "best" value that must be beaten.
   *   - `restoreBestWeights`: not supported.
   * @throws NotImplementedError if `args.restoreBestWeights` is truthy.
   */
  constructor(args) {
    super();
    if (args == null) {
      args = {};
    }
    if (args.restoreBestWeights) {
      throw new NotImplementedError('restoreBestWeights = True is not implemented in EarlyStopping yet.');
    }
    this.monitor = args.monitor || 'val_loss';
    this.minDelta = Math.abs(args.minDelta || 0);
    this.patience = args.patience || 0;
    this.verbose = args.verbose || 0;
    this.mode = args.mode || 'auto';
    this.baseline = args.baseline;

    if (['auto', 'min', 'max'].indexOf(this.mode) === -1) {
      console.warn(`EarlyStopping mode '${this.mode}' is invalid. ` +
          `Falling back to mode 'auto'.`);
      this.mode = 'auto';
    }

    if (this.mode === 'min') {
      this.monitorFunc = less;
    } else if (this.mode === 'max') {
      this.monitorFunc = greater;
    } else {
      // For mode === 'auto': accuracy-like metrics are maximized, everything
      // else is minimized.
      if (this.monitor.indexOf('acc') !== -1) {
        this.monitorFunc = greater;
      } else {
        this.monitorFunc = less;
      }
    }

    // onEpochEnd compares `current - minDelta` against `best`. Negating
    // minDelta in 'min' mode means `current` must be below `best` by more
    // than |minDelta| (and in 'max' mode above it by more than |minDelta|)
    // to count as an improvement.
    if (this.monitorFunc === less) {
      this.minDelta *= -1;
    }
  }

  /** Resets the per-fit state: wait counter, stop bookkeeping and best value. */
  async onTrainBegin(logs) {
    this.wait = 0;
    this.stoppedEpoch = 0;
    // Explicit flag instead of relying on `stoppedEpoch > 0`: training can be
    // stopped at epoch 0 (e.g. with a `baseline` and `patience` <= 1), and
    // that case must still be reported when `verbose` is set.
    this.stoppedEarly = false;
    if (this.baseline != null) {
      this.best = this.baseline;
    } else {
      this.best = this.monitorFunc === less ? Infinity : -Infinity;
    }
  }

  /**
   * Updates the best value / wait counter and requests a stop once patience
   * is exhausted. Returns without changing state when the monitored key is
   * missing from `logs` (a warning is emitted by getMonitorValue).
   */
  async onEpochEnd(epoch, logs) {
    await resolveScalarsInLogs(logs);
    const current = this.getMonitorValue(logs);
    if (current == null) {
      return;
    }
    if (this.monitorFunc(current - this.minDelta, this.best)) {
      this.best = current;
      this.wait = 0;
      // TODO(cais): Logic for restoreBestWeights.
    } else {
      this.wait++;
      if (this.wait >= this.patience) {
        this.stoppedEpoch = epoch;
        this.stoppedEarly = true;
        this.model.stopTraining = true;
      }
      // TODO(cais): Logic for restoreBestWeights.
    }
  }

  /** Logs the stopping epoch when `verbose` is set and this callback stopped training. */
  async onTrainEnd(logs) {
    if (this.stoppedEarly && this.verbose) {
      console.log(`Epoch ${this.stoppedEpoch}: early stopping.`);
    }
  }

  /**
   * Reads the monitored entry from `logs`, warning (and returning the
   * null/undefined value) when it is absent.
   */
  getMonitorValue(logs) {
    if (logs == null) {
      logs = {};
    }
    const monitorValue = logs[this.monitor];
    if (monitorValue == null) {
      console.warn(`Metric for EarlyStopping ${this.monitor} is not available. ` +
          `Available metrics are: ${Object.keys(logs)}`);
    }
    return monitorValue;
  }
}
/**
 * Creates a Callback that stops training once a monitored quantity has
 * stopped improving.
 *
 * Early stopping is a form of regularization that helps protect the model
 * against overfitting.
 *
 * The following example, based on fake data, shows how this callback can be
 * used with `tf.LayersModel.fit()`:
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.dense({
 *   units: 3,
 *   activation: 'softmax',
 *   kernelInitializer: 'ones',
 *   inputShape: [2]
 * }));
 * const xs = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 * const ys = tf.tensor2d([[1, 0, 0], [0, 1, 0]], [2, 3]);
 * const xsVal = tf.tensor2d([4, 3, 2, 1], [2, 2]);
 * const ysVal = tf.tensor2d([[0, 0, 1], [0, 1, 0]], [2, 3]);
 * model.compile(
 *     {loss: 'categoricalCrossentropy', optimizer: 'sgd', metrics: ['acc']});
 *
 * // Without the EarlyStopping callback, the val_acc value would be:
 * //   0.5, 0.5, 0.5, 0.5, ...
 * // With val_acc being monitored, training should stop after the 2nd epoch.
 * const history = await model.fit(xs, ys, {
 *   epochs: 10,
 *   validationData: [xsVal, ysVal],
 *   callbacks: tf.callbacks.earlyStopping({monitor: 'val_acc'})
 * });
 *
 * // Expect to see a length-2 array.
 * console.log(history.history.val_acc);
 * ```
 *
 * @param args Optional configuration forwarded unchanged to the
 *     `EarlyStopping` constructor.
 * @returns A new `EarlyStopping` instance.
 *
 * @doc {
 *   heading: 'Callbacks',
 *   namespace: 'callbacks'
 * }
 */
export function earlyStopping(args) {
  const callback = new EarlyStopping(args);
  return callback;
}
/** Namespace object grouping the callback factory functions (`earlyStopping`). */
export const callbacks = { earlyStopping: earlyStopping };
171//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"callbacks.js","sourceRoot":"","sources":["../../../../../tfjs-layers/src/callbacks.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH,yCAAyC;AAEzC,OAAO,EAAC,YAAY,EAAC,MAAM,kBAAkB,CAAC;AAE9C,OAAO,EAAC,WAAW,EAAC,MAAM,mBAAmB,CAAC;AAC9C,OAAO,EAAC,mBAAmB,EAAC,MAAM,UAAU,CAAC;AAC7C,OAAO,EAAO,oBAAoB,EAAC,MAAM,QAAQ,CAAC;AAElD,MAAM,OAAgB,QAAS,SAAQ,YAAY;IAAnD;;QACE,8EAA8E;QAC9E,UAAK,GAAgB,IAAI,CAAC;IAQ5B,CAAC;IANU,QAAQ,CAAC,KAAgB;QAChC,IAAI,CAAC,CAAC,KAAK,YAAY,WAAW,CAAC,EAAE;YACnC,MAAM,IAAI,KAAK,CAAC,uDAAuD,CAAC,CAAC;SAC1E;QACD,IAAI,CAAC,KAAK,GAAG,KAAK,CAAC;IACrB,CAAC;CACF;AA4DD,SAAS,IAAI,CAAC,OAAe,EAAE,OAAe;IAC5C,OAAO,OAAO,GAAG,OAAO,CAAC;AAC3B,CAAC;AAED,SAAS,OAAO,CAAC,OAAe,EAAE,OAAe;IAC/C,OAAO,OAAO,GAAG,OAAO,CAAC;AAC3B,CAAC;AAED;;;GAGG;AACH,MAAM,OAAO,aAAc,SAAQ,QAAQ;IAczC,YAAY,IAAgC;QAC1C,KAAK,EAAE,CAAC;QACR,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,IAAI,IAAI,CAAC,kBAAkB,EAAE;YAC3B,MAAM,IAAI,mBAAmB,CACzB,oEAAoE,CAAC,CAAC;SAC3E;QAED,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,OAAO,IAAI,UAAU,CAAC;QAC1C,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,IAAI,CAAC,CAAC,CAAC;QAC7C,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,IAAI,CAAC,CAAC;QACnC,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,OAAO,IAAI,CAAC,CAAC;QACjC,IAAI,CAAC,IAAI,GAAG,IAAI,CAAC,IAAI,IAAI,MAAM,CAAC;QAChC,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC;QAE9B,IAAI,CAAC,MAAM,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,OAAO,CAAC,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,EAAE;YACpD,OAAO,CAAC,IAAI,CACR,uBAAuB,IAAI,CAAC,IAAI,gBAAgB;gBAChD,8BAA8B,CAAC,CAAC;YACpC,IAAI,CAAC,IAAI,GAAG,MAAM,CAAC;SACpB;QAED,IAAI,IAAI,CAAC,IAAI,KAAK,KAAK,EAAE;YACvB,IAAI,CAAC,WAAW,GAAG,IAAI,CAAC;SACzB;aAAM,IAAI,IAAI,CAAC,IAAI,KAAK,KAAK,EAAE;YAC9B,IAAI,CAAC,WAAW,GAAG,OAAO,CAAC;SAC5B;aAAM;YACL,uBAAuB;YACvB,IAAI,IAAI,CAAC,OAAO,CAAC,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE;gBACtC,IAAI,CAAC,WAAW,GAAG,OAAO,CAAC;aAC5B;iBAAM;gBACL,IAAI,CAAC,WAAW,GAAG,IAAI,CAAC;aACzB;SACF;QAED,IAAI,IAAI,CAAC,WAAW,KAAK,IAAI,EAAE;YAC7B,IAAI,CAAC,QA
AQ,IAAI,CAAC,CAAC,CAAC;SACrB;IACH,CAAC;IAEQ,KAAK,CAAC,YAAY,CAAC,IAAW;QACrC,IAAI,CAAC,IAAI,GAAG,CAAC,CAAC;QACd,IAAI,CAAC,YAAY,GAAG,CAAC,CAAC;QACtB,IAAI,IAAI,CAAC,QAAQ,IAAI,IAAI,EAAE;YACzB,IAAI,CAAC,IAAI,GAAG,IAAI,CAAC,QAAQ,CAAC;SAC3B;aAAM;YACL,IAAI,CAAC,IAAI,GAAG,IAAI,CAAC,WAAW,KAAK,IAAI,CAAC,CAAC,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,QAAQ,CAAC;SAC9D;IACH,CAAC;IAEQ,KAAK,CAAC,UAAU,CAAC,KAAa,EAAE,IAAW;QAClD,MAAM,oBAAoB,CAAC,IAAI,CAAC,CAAC;QACjC,MAAM,OAAO,GAAG,IAAI,CAAC,eAAe,CAAC,IAAI,CAAC,CAAC;QAC3C,IAAI,OAAO,IAAI,IAAI,EAAE;YACnB,OAAO;SACR;QAED,IAAI,IAAI,CAAC,WAAW,CAAC,OAAO,GAAG,IAAI,CAAC,QAAQ,EAAE,IAAI,CAAC,IAAI,CAAC,EAAE;YACxD,IAAI,CAAC,IAAI,GAAG,OAAO,CAAC;YACpB,IAAI,CAAC,IAAI,GAAG,CAAC,CAAC;YACd,4CAA4C;SAC7C;aAAM;YACL,IAAI,CAAC,IAAI,EAAE,CAAC;YACZ,IAAI,IAAI,CAAC,IAAI,IAAI,IAAI,CAAC,QAAQ,EAAE;gBAC9B,IAAI,CAAC,YAAY,GAAG,KAAK,CAAC;gBAC1B,IAAI,CAAC,KAAK,CAAC,YAAY,GAAG,IAAI,CAAC;aAChC;YACD,4CAA4C;SAC7C;IACH,CAAC;IAEQ,KAAK,CAAC,UAAU,CAAC,IAAW;QACnC,IAAI,IAAI,CAAC,YAAY,GAAG,CAAC,IAAI,IAAI,CAAC,OAAO,EAAE;YACzC,OAAO,CAAC,GAAG,CAAC,SAAS,IAAI,CAAC,YAAY,mBAAmB,CAAC,CAAC;SAC5D;IACH,CAAC;IAEO,eAAe,CAAC,IAAU;QAChC,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,MAAM,YAAY,GAAG,IAAI,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;QACxC,IAAI,YAAY,IAAI,IAAI,EAAE;YACxB,OAAO,CAAC,IAAI,CACR,4BAA4B,IAAI,CAAC,OAAO,qBAAqB;gBAC7D,0BAA0B,MAAM,CAAC,IAAI,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC;SACpD;QACD,OAAO,YAAY,CAAC;IACtB,CAAC;CACF;AAED;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GA0CG;AACH,MAAM,UAAU,aAAa,CAAC,IAAgC;IAC5D,OAAO,IAAI,aAAa,CAAC,IAAI,CAAC,CAAC;AACjC,CAAC;AAED,MAAM,CAAC,MAAM,SAAS,GAAG,EAAC,aAAa,EAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n/* Original source: keras/callbacks.py */\n\nimport {BaseCallback} from 
'./base_callbacks';\nimport {Container} from './engine/container';\nimport {LayersModel} from './engine/training';\nimport {NotImplementedError} from './errors';\nimport {Logs, resolveScalarsInLogs} from './logs';\n\nexport abstract class Callback extends BaseCallback {\n  /** Instance of `keras.models.Model`. Reference of the model being trained. */\n  model: LayersModel = null;\n\n  override setModel(model: Container): void {\n    if (!(model instanceof LayersModel)) {\n      throw new Error('model must be a LayersModel, not some other Container');\n    }\n    this.model = model;\n  }\n}\n\nexport interface EarlyStoppingCallbackArgs {\n  /**\n   * Quantity to be monitored.\n   *\n   * Defaults to 'val_loss'.\n   */\n  monitor?: string;\n\n  /**\n   * Minimum change in the monitored quantity to qualify as improvement,\n   * i.e., an absolute change of less than `minDelta` will count as no\n   * improvement.\n   *\n   * Defaults to 0.\n   */\n  minDelta?: number;\n\n  /**\n   * Number of epochs with no improvement after which training will be stopped.\n   *\n   * Defaults to 0.\n   */\n  patience?: number;\n\n  /** Verbosity mode. */\n  verbose?: number;\n\n  /**\n   * Mode: one of 'min', 'max', and 'auto'.\n   * - In 'min' mode, training will be stopped when the quantity monitored has\n   *   stopped decreasing.\n   * - In 'max' mode, training will be stopped when the quantity monitored has\n   *   stopped increasing.\n   * - In 'auto' mode, the direction is inferred automatically from the name of\n   *   the monitored quantity.\n   *\n   * Defaults to 'auto'.\n   */\n  mode?: 'auto'|'min'|'max';\n\n  /**\n   * Baseline value of the monitored quantity.\n   *\n   * If specified, training will be stopped if the model doesn't show\n   * improvement over the baseline.\n   */\n  baseline?: number;\n\n  /**\n   * Whether to restore model weights from the epoch with the best value\n   * of the monitored quantity. 
If `False`, the model weights obtained at the\n   * last step of training are used.\n   *\n   * **`True` is not supported yet.**\n   */\n  restoreBestWeights?: boolean;\n}\n\nfunction less(currVal: number, prevVal: number) {\n  return currVal < prevVal;\n}\n\nfunction greater(currVal: number, prevVal: number) {\n  return currVal > prevVal;\n}\n\n/**\n * A Callback that stops training when a monitored quantity has stopped\n * improving.\n */\nexport class EarlyStopping extends Callback {\n  protected readonly monitor: string;\n  protected readonly minDelta: number;\n  protected readonly patience: number;\n  protected readonly baseline: number;\n  protected readonly verbose: number;\n  protected readonly mode: 'auto'|'min'|'max';\n\n  protected monitorFunc: (currVal: number, prevVal: number) => boolean;\n\n  private wait: number;\n  private stoppedEpoch: number;\n  private best: number;\n\n  constructor(args?: EarlyStoppingCallbackArgs) {\n    super();\n    if (args == null) {\n      args = {};\n    }\n    if (args.restoreBestWeights) {\n      throw new NotImplementedError(\n          'restoreBestWeights = True is not implemented in EarlyStopping yet.');\n    }\n\n    this.monitor = args.monitor || 'val_loss';\n    this.minDelta = Math.abs(args.minDelta || 0);\n    this.patience = args.patience || 0;\n    this.verbose = args.verbose || 0;\n    this.mode = args.mode || 'auto';\n    this.baseline = args.baseline;\n\n    if (['auto', 'min', 'max'].indexOf(this.mode) === -1) {\n      console.warn(\n          `EarlyStopping mode '${this.mode}' is invalid. 
` +\n          `Falling back to mode 'auto'.`);\n      this.mode = 'auto';\n    }\n\n    if (this.mode === 'min') {\n      this.monitorFunc = less;\n    } else if (this.mode === 'max') {\n      this.monitorFunc = greater;\n    } else {\n      // For mode === 'auto'.\n      if (this.monitor.indexOf('acc') !== -1) {\n        this.monitorFunc = greater;\n      } else {\n        this.monitorFunc = less;\n      }\n    }\n\n    if (this.monitorFunc === less) {\n      this.minDelta *= -1;\n    }\n  }\n\n  override async onTrainBegin(logs?: Logs) {\n    this.wait = 0;\n    this.stoppedEpoch = 0;\n    if (this.baseline != null) {\n      this.best = this.baseline;\n    } else {\n      this.best = this.monitorFunc === less ? Infinity : -Infinity;\n    }\n  }\n\n  override async onEpochEnd(epoch: number, logs?: Logs) {\n    await resolveScalarsInLogs(logs);\n    const current = this.getMonitorValue(logs);\n    if (current == null) {\n      return;\n    }\n\n    if (this.monitorFunc(current - this.minDelta, this.best)) {\n      this.best = current;\n      this.wait = 0;\n      // TODO(cais): Logic for restoreBestWeights.\n    } else {\n      this.wait++;\n      if (this.wait >= this.patience) {\n        this.stoppedEpoch = epoch;\n        this.model.stopTraining = true;\n      }\n      // TODO(cais): Logic for restoreBestWeights.\n    }\n  }\n\n  override async onTrainEnd(logs?: Logs) {\n    if (this.stoppedEpoch > 0 && this.verbose) {\n      console.log(`Epoch ${this.stoppedEpoch}: early stopping.`);\n    }\n  }\n\n  private getMonitorValue(logs: Logs) {\n    if (logs == null) {\n      logs = {};\n    }\n    const monitorValue = logs[this.monitor];\n    if (monitorValue == null) {\n      console.warn(\n          `Metric for EarlyStopping ${this.monitor} is not available. 
` +\n          `Available metrics are: ${Object.keys(logs)}`);\n    }\n    return monitorValue;\n  }\n}\n\n/**\n * Factory function for a Callback that stops training when a monitored\n * quantity has stopped improving.\n *\n * Early stopping is a type of regularization, and protects model against\n * overfitting.\n *\n * The following example based on fake data illustrates how this callback\n * can be used during `tf.LayersModel.fit()`:\n *\n * ```js\n * const model = tf.sequential();\n * model.add(tf.layers.dense({\n *   units: 3,\n *   activation: 'softmax',\n *   kernelInitializer: 'ones',\n *   inputShape: [2]\n * }));\n * const xs = tf.tensor2d([1, 2, 3, 4], [2, 2]);\n * const ys = tf.tensor2d([[1, 0, 0], [0, 1, 0]], [2, 3]);\n * const xsVal = tf.tensor2d([4, 3, 2, 1], [2, 2]);\n * const ysVal = tf.tensor2d([[0, 0, 1], [0, 1, 0]], [2, 3]);\n * model.compile(\n *     {loss: 'categoricalCrossentropy', optimizer: 'sgd', metrics: ['acc']});\n *\n * // Without the EarlyStopping callback, the val_acc value would be:\n * //   0.5, 0.5, 0.5, 0.5, ...\n * // With val_acc being monitored, training should stop after the 2nd epoch.\n * const history = await model.fit(xs, ys, {\n *   epochs: 10,\n *   validationData: [xsVal, ysVal],\n *   callbacks: tf.callbacks.earlyStopping({monitor: 'val_acc'})\n * });\n *\n * // Expect to see a length-2 array.\n * console.log(history.history.val_acc);\n * ```\n *\n * @doc {\n *   heading: 'Callbacks',\n *   namespace: 'callbacks'\n * }\n */\nexport function earlyStopping(args?: EarlyStoppingCallbackArgs) {\n  return new EarlyStopping(args);\n}\n\nexport const callbacks = {earlyStopping};\n"]}
\No newline at end of file