1 | /**
|
2 | * @license
|
3 | * Copyright 2018 Google LLC
|
4 | *
|
5 | * Use of this source code is governed by an MIT-style
|
6 | * license that can be found in the LICENSE file or at
|
7 | * https://opensource.org/licenses/MIT.
|
8 | * =============================================================================
|
9 | */
|
10 | /* Original source: keras/callbacks.py */
|
11 | import { BaseCallback } from './base_callbacks';
|
12 | import { LayersModel } from './engine/training';
|
13 | import { NotImplementedError } from './errors';
|
14 | import { resolveScalarsInLogs } from './logs';
|
/**
 * Base class for Layers-API training callbacks.
 *
 * Subclasses receive the model being trained via `setModel()`; only
 * `LayersModel` instances are accepted.
 */
export class Callback extends BaseCallback {
    constructor(...ctorArgs) {
        super(...ctorArgs);
        /** Instance of `keras.models.Model`. Reference of the model being trained. */
        this.model = null;
    }
    /**
     * Attach the model this callback observes.
     *
     * @param model The container being trained.
     * @throws Error if `model` is not a `LayersModel`.
     */
    setModel(model) {
        const isLayersModel = model instanceof LayersModel;
        if (!isLayersModel) {
            throw new Error('model must be a LayersModel, not some other Container');
        }
        this.model = model;
    }
}
|
/**
 * Improvement test for 'min' mode: true iff `candidate` is strictly below
 * `reference`.
 */
function less(candidate, reference) {
    return reference > candidate;
}
|
/**
 * Improvement test for 'max' mode: true iff `candidate` is strictly above
 * `reference`.
 */
function greater(candidate, reference) {
    return reference < candidate;
}
|
/**
 * A Callback that stops training when a monitored quantity has stopped
 * improving.
 *
 * At the end of each epoch the value named by `monitor` is read from the
 * epoch logs. If it beats the best value seen so far by more than `minDelta`
 * (in the direction selected by `mode`), it becomes the new best and the
 * wait counter resets. Otherwise the counter is incremented. Once the counter
 * reaches `patience`, `model.stopTraining` is set to `true`.
 */
export class EarlyStopping extends Callback {
    /**
     * @param args Optional configuration object:
     *   - `monitor`: log key to watch (default `'val_loss'`).
     *   - `minDelta`: minimum change that counts as improvement; its sign is
     *     ignored (default 0).
     *   - `patience`: epochs without improvement before stopping (default 0).
     *   - `verbose`: when truthy, log a message at the end of training if
     *     training was stopped early (default 0).
     *   - `mode`: `'auto'`, `'min'`, or `'max'` (default `'auto'`). Invalid
     *     values fall back to `'auto'` with a warning.
     *   - `baseline`: initial value the monitored quantity must beat.
     *   - `restoreBestWeights`: not supported; must be falsy.
     * @throws NotImplementedError if `args.restoreBestWeights` is truthy.
     */
    constructor(args) {
        super();
        if (args == null) {
            args = {};
        }
        if (args.restoreBestWeights) {
            throw new NotImplementedError('restoreBestWeights = True is not implemented in EarlyStopping yet.');
        }
        this.monitor = args.monitor || 'val_loss';
        this.minDelta = Math.abs(args.minDelta || 0);
        this.patience = args.patience || 0;
        this.verbose = args.verbose || 0;
        this.mode = args.mode || 'auto';
        this.baseline = args.baseline;
        if (['auto', 'min', 'max'].indexOf(this.mode) === -1) {
            console.warn(`EarlyStopping mode '${this.mode}' is invalid. ` +
                `Falling back to mode 'auto'.`);
            this.mode = 'auto';
        }
        if (this.mode === 'min') {
            this.monitorFunc = less;
        }
        else if (this.mode === 'max') {
            this.monitorFunc = greater;
        }
        else {
            // For mode === 'auto': accuracy-like metrics are maximized,
            // everything else is minimized.
            if (this.monitor.indexOf('acc') !== -1) {
                this.monitorFunc = greater;
            }
            else {
                this.monitorFunc = less;
            }
        }
        // Sign minDelta so that `current - minDelta` must beat `best` by more
        // than |minDelta| in the improving direction: in 'min' mode this
        // becomes `current + |minDelta| < best`.
        if (this.monitorFunc === less) {
            this.minDelta *= -1;
        }
    }
    /** Reset the wait counter, stop state, and best value for a new run. */
    async onTrainBegin(logs) {
        this.wait = 0;
        this.stoppedEpoch = 0;
        // Separate flag: `stoppedEpoch` alone cannot distinguish "never
        // stopped" from "stopped after epoch 0".
        this.stopped = false;
        if (this.baseline != null) {
            this.best = this.baseline;
        }
        else {
            this.best = this.monitorFunc === less ? Infinity : -Infinity;
        }
    }
    /**
     * Compare the monitored value against the best so far and request a stop
     * once `patience` non-improving epochs have accumulated. Epochs whose logs
     * lack the monitored key are skipped (a warning is emitted).
     */
    async onEpochEnd(epoch, logs) {
        await resolveScalarsInLogs(logs);
        const current = this.getMonitorValue(logs);
        if (current == null) {
            return;
        }
        if (this.monitorFunc(current - this.minDelta, this.best)) {
            this.best = current;
            this.wait = 0;
            // TODO(cais): Logic for restoreBestWeights.
        }
        else {
            this.wait++;
            if (this.wait >= this.patience) {
                this.stoppedEpoch = epoch;
                this.stopped = true;
                this.model.stopTraining = true;
            }
            // TODO(cais): Logic for restoreBestWeights.
        }
    }
    /** In verbose mode, report the epoch at which training was stopped. */
    async onTrainEnd(logs) {
        if (this.stopped && this.verbose) {
            console.log(`Epoch ${this.stoppedEpoch}: early stopping.`);
        }
    }
    /**
     * Look up the monitored quantity in `logs`.
     *
     * @returns The logged value, or `undefined`/`null` (with a console
     *   warning) when the key is missing.
     */
    getMonitorValue(logs) {
        if (logs == null) {
            logs = {};
        }
        const monitorValue = logs[this.monitor];
        if (monitorValue == null) {
            console.warn(`Metric for EarlyStopping ${this.monitor} is not available. ` +
                `Available metrics are: ${Object.keys(logs)}`);
        }
        return monitorValue;
    }
}
|
/**
 * Factory function for a Callback that stops training when a monitored
 * quantity has stopped improving.
 *
 * Early stopping is a form of regularization that protects the model against
 * overfitting.
 *
 * The following example, based on fake data, illustrates how this callback
 * can be used during `tf.LayersModel.fit()`:
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.dense({
 *   units: 3,
 *   activation: 'softmax',
 *   kernelInitializer: 'ones',
 *   inputShape: [2]
 * }));
 * const xs = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 * const ys = tf.tensor2d([[1, 0, 0], [0, 1, 0]], [2, 3]);
 * const xsVal = tf.tensor2d([4, 3, 2, 1], [2, 2]);
 * const ysVal = tf.tensor2d([[0, 0, 1], [0, 1, 0]], [2, 3]);
 * model.compile(
 *     {loss: 'categoricalCrossentropy', optimizer: 'sgd', metrics: ['acc']});
 *
 * // Without the EarlyStopping callback, the val_acc value would be:
 * //   0.5, 0.5, 0.5, 0.5, ...
 * // With val_acc being monitored, training should stop after the 2nd epoch.
 * const history = await model.fit(xs, ys, {
 *   epochs: 10,
 *   validationData: [xsVal, ysVal],
 *   callbacks: tf.callbacks.earlyStopping({monitor: 'val_acc'})
 * });
 *
 * // Expect to see a length-2 array.
 * console.log(history.history.val_acc);
 * ```
 *
 * @param args Optional EarlyStopping configuration, forwarded unchanged to
 *   the `EarlyStopping` constructor.
 * @returns A new `EarlyStopping` callback instance.
 *
 * @doc {
 *   heading: 'Callbacks',
 *   namespace: 'callbacks'
 * }
 */
export function earlyStopping(args) {
    const callback = new EarlyStopping(args);
    return callback;
}
|
/** Namespace object exposing the callback factories (`tf.callbacks.*`). */
export const callbacks = { earlyStopping: earlyStopping };
|
171 | //# sourceMappingURL=data:application/json;base64,{"version":3,"file":"callbacks.js","sourceRoot":"","sources":["../../../../../tfjs-layers/src/callbacks.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH,yCAAyC;AAEzC,OAAO,EAAC,YAAY,EAAC,MAAM,kBAAkB,CAAC;AAE9C,OAAO,EAAC,WAAW,EAAC,MAAM,mBAAmB,CAAC;AAC9C,OAAO,EAAC,mBAAmB,EAAC,MAAM,UAAU,CAAC;AAC7C,OAAO,EAAO,oBAAoB,EAAC,MAAM,QAAQ,CAAC;AAElD,MAAM,OAAgB,QAAS,SAAQ,YAAY;IAAnD;;QACE,8EAA8E;QAC9E,UAAK,GAAgB,IAAI,CAAC;IAQ5B,CAAC;IANC,QAAQ,CAAC,KAAgB;QACvB,IAAI,CAAC,CAAC,KAAK,YAAY,WAAW,CAAC,EAAE;YACnC,MAAM,IAAI,KAAK,CAAC,uDAAuD,CAAC,CAAC;SAC1E;QACD,IAAI,CAAC,KAAK,GAAG,KAAK,CAAC;IACrB,CAAC;CACF;AA4DD,SAAS,IAAI,CAAC,OAAe,EAAE,OAAe;IAC5C,OAAO,OAAO,GAAG,OAAO,CAAC;AAC3B,CAAC;AAED,SAAS,OAAO,CAAC,OAAe,EAAE,OAAe;IAC/C,OAAO,OAAO,GAAG,OAAO,CAAC;AAC3B,CAAC;AAED;;;GAGG;AACH,MAAM,OAAO,aAAc,SAAQ,QAAQ;IAczC,YAAY,IAAgC;QAC1C,KAAK,EAAE,CAAC;QACR,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,IAAI,IAAI,CAAC,kBAAkB,EAAE;YAC3B,MAAM,IAAI,mBAAmB,CACzB,oEAAoE,CAAC,CAAC;SAC3E;QAED,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,OAAO,IAAI,UAAU,CAAC;QAC1C,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,IAAI,CAAC,CAAC,CAAC;QAC7C,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,IAAI,CAAC,CAAC;QACnC,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,OAAO,IAAI,CAAC,CAAC;QACjC,IAAI,CAAC,IAAI,GAAG,IAAI,CAAC,IAAI,IAAI,MAAM,CAAC;QAChC,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC;QAE9B,IAAI,CAAC,MAAM,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,OAAO,CAAC,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,EAAE;YACpD,OAAO,CAAC,IAAI,CACR,uBAAuB,IAAI,CAAC,IAAI,gBAAgB;gBAChD,8BAA8B,CAAC,CAAC;YACpC,IAAI,CAAC,IAAI,GAAG,MAAM,CAAC;SACpB;QAED,IAAI,IAAI,CAAC,IAAI,KAAK,KAAK,EAAE;YACvB,IAAI,CAAC,WAAW,GAAG,IAAI,CAAC;SACzB;aAAM,IAAI,IAAI,CAAC,IAAI,KAAK,KAAK,EAAE;YAC9B,IAAI,CAAC,WAAW,GAAG,OAAO,CAAC;SAC5B;aAAM;YACL,uBAAuB;YACvB,IAAI,IAAI,CAAC,OAAO,CAAC,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE;gBACtC,IAAI,CAAC,WAAW,GAAG,OAAO,CAAC;aAC5B;iBAAM;gBACL,IAAI,CAAC,WAAW,GAAG,IAAI,CAAC;aACzB;SACF;QAED,IAAI,IAAI,CAAC,WAAW,KAAK,IAAI,EAAE;YAC7B,IAAI,CAAC
,QAAQ,IAAI,CAAC,CAAC,CAAC;SACrB;IACH,CAAC;IAED,KAAK,CAAC,YAAY,CAAC,IAAW;QAC5B,IAAI,CAAC,IAAI,GAAG,CAAC,CAAC;QACd,IAAI,CAAC,YAAY,GAAG,CAAC,CAAC;QACtB,IAAI,IAAI,CAAC,QAAQ,IAAI,IAAI,EAAE;YACzB,IAAI,CAAC,IAAI,GAAG,IAAI,CAAC,QAAQ,CAAC;SAC3B;aAAM;YACL,IAAI,CAAC,IAAI,GAAG,IAAI,CAAC,WAAW,KAAK,IAAI,CAAC,CAAC,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,QAAQ,CAAC;SAC9D;IACH,CAAC;IAED,KAAK,CAAC,UAAU,CAAC,KAAa,EAAE,IAAW;QACzC,MAAM,oBAAoB,CAAC,IAAI,CAAC,CAAC;QACjC,MAAM,OAAO,GAAG,IAAI,CAAC,eAAe,CAAC,IAAI,CAAC,CAAC;QAC3C,IAAI,OAAO,IAAI,IAAI,EAAE;YACnB,OAAO;SACR;QAED,IAAI,IAAI,CAAC,WAAW,CAAC,OAAO,GAAG,IAAI,CAAC,QAAQ,EAAE,IAAI,CAAC,IAAI,CAAC,EAAE;YACxD,IAAI,CAAC,IAAI,GAAG,OAAO,CAAC;YACpB,IAAI,CAAC,IAAI,GAAG,CAAC,CAAC;YACd,4CAA4C;SAC7C;aAAM;YACL,IAAI,CAAC,IAAI,EAAE,CAAC;YACZ,IAAI,IAAI,CAAC,IAAI,IAAI,IAAI,CAAC,QAAQ,EAAE;gBAC9B,IAAI,CAAC,YAAY,GAAG,KAAK,CAAC;gBAC1B,IAAI,CAAC,KAAK,CAAC,YAAY,GAAG,IAAI,CAAC;aAChC;YACD,4CAA4C;SAC7C;IACH,CAAC;IAED,KAAK,CAAC,UAAU,CAAC,IAAW;QAC1B,IAAI,IAAI,CAAC,YAAY,GAAG,CAAC,IAAI,IAAI,CAAC,OAAO,EAAE;YACzC,OAAO,CAAC,GAAG,CAAC,SAAS,IAAI,CAAC,YAAY,mBAAmB,CAAC,CAAC;SAC5D;IACH,CAAC;IAEO,eAAe,CAAC,IAAU;QAChC,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,MAAM,YAAY,GAAG,IAAI,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;QACxC,IAAI,YAAY,IAAI,IAAI,EAAE;YACxB,OAAO,CAAC,IAAI,CACR,4BAA4B,IAAI,CAAC,OAAO,qBAAqB;gBAC7D,0BAA0B,MAAM,CAAC,IAAI,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC;SACpD;QACD,OAAO,YAAY,CAAC;IACtB,CAAC;CACF;AAED;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GA0CG;AACH,MAAM,UAAU,aAAa,CAAC,IAAgC;IAC5D,OAAO,IAAI,aAAa,CAAC,IAAI,CAAC,CAAC;AACjC,CAAC;AAED,MAAM,CAAC,MAAM,SAAS,GAAG,EAAC,aAAa,EAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n/* Original source: keras/callbacks.py */\n\nimport {BaseCallback} from 
'./base_callbacks';\nimport {Container} from './engine/container';\nimport {LayersModel} from './engine/training';\nimport {NotImplementedError} from './errors';\nimport {Logs, resolveScalarsInLogs} from './logs';\n\nexport abstract class Callback extends BaseCallback {\n  /** Instance of `keras.models.Model`. Reference of the model being trained. */\n  model: LayersModel = null;\n\n  setModel(model: Container): void {\n    if (!(model instanceof LayersModel)) {\n      throw new Error('model must be a LayersModel, not some other Container');\n    }\n    this.model = model;\n  }\n}\n\nexport interface EarlyStoppingCallbackArgs {\n  /**\n   * Quantity to be monitored.\n   *\n   * Defaults to 'val_loss'.\n   */\n  monitor?: string;\n\n  /**\n   * Minimum change in the monitored quantity to qualify as improvement,\n   * i.e., an absolute change of less than `minDelta` will count as no\n   * improvement.\n   *\n   * Defaults to 0.\n   */\n  minDelta?: number;\n\n  /**\n   * Number of epochs with no improvement after which training will be stopped.\n   *\n   * Defaults to 0.\n   */\n  patience?: number;\n\n  /** Verbosity mode. */\n  verbose?: number;\n\n  /**\n   * Mode: one of 'min', 'max', and 'auto'.\n   * - In 'min' mode, training will be stopped when the quantity monitored has\n   *   stopped decreasing.\n   * - In 'max' mode, training will be stopped when the quantity monitored has\n   *   stopped increasing.\n   * - In 'auto' mode, the direction is inferred automatically from the name of\n   *   the monitored quantity.\n   *\n   * Defaults to 'auto'.\n   */\n  mode?: 'auto'|'min'|'max';\n\n  /**\n   * Baseline value of the monitored quantity.\n   *\n   * If specified, training will be stopped if the model doesn't show\n   * improvement over the baseline.\n   */\n  baseline?: number;\n\n  /**\n   * Whether to restore model weights from the epoch with the best value\n   * of the monitored quantity. 
If `False`, the model weights obtained at the\n   * at the last step of training are used.\n   *\n   * **`True` is not supported yet.**\n   */\n  restoreBestWeights?: boolean;\n}\n\nfunction less(currVal: number, prevVal: number) {\n  return currVal < prevVal;\n}\n\nfunction greater(currVal: number, prevVal: number) {\n  return currVal > prevVal;\n}\n\n/**\n * A Callback that stops training when a monitored quantity has stopped\n * improving.\n */\nexport class EarlyStopping extends Callback {\n  protected readonly monitor: string;\n  protected readonly minDelta: number;\n  protected readonly patience: number;\n  protected readonly baseline: number;\n  protected readonly verbose: number;\n  protected readonly mode: 'auto'|'min'|'max';\n\n  protected monitorFunc: (currVal: number, prevVal: number) => boolean;\n\n  private wait: number;\n  private stoppedEpoch: number;\n  private best: number;\n\n  constructor(args?: EarlyStoppingCallbackArgs) {\n    super();\n    if (args == null) {\n      args = {};\n    }\n    if (args.restoreBestWeights) {\n      throw new NotImplementedError(\n          'restoreBestWeights = True is not implemented in EarlyStopping yet.');\n    }\n\n    this.monitor = args.monitor || 'val_loss';\n    this.minDelta = Math.abs(args.minDelta || 0);\n    this.patience = args.patience || 0;\n    this.verbose = args.verbose || 0;\n    this.mode = args.mode || 'auto';\n    this.baseline = args.baseline;\n\n    if (['auto', 'min', 'max'].indexOf(this.mode) === -1) {\n      console.warn(\n          `EarlyStopping mode '${this.mode}' is invalid. 
` +\n          `Falling back to mode 'auto'.`);\n      this.mode = 'auto';\n    }\n\n    if (this.mode === 'min') {\n      this.monitorFunc = less;\n    } else if (this.mode === 'max') {\n      this.monitorFunc = greater;\n    } else {\n      // For mode === 'auto'.\n      if (this.monitor.indexOf('acc') !== -1) {\n        this.monitorFunc = greater;\n      } else {\n        this.monitorFunc = less;\n      }\n    }\n\n    if (this.monitorFunc === less) {\n      this.minDelta *= -1;\n    }\n  }\n\n  async onTrainBegin(logs?: Logs) {\n    this.wait = 0;\n    this.stoppedEpoch = 0;\n    if (this.baseline != null) {\n      this.best = this.baseline;\n    } else {\n      this.best = this.monitorFunc === less ? Infinity : -Infinity;\n    }\n  }\n\n  async onEpochEnd(epoch: number, logs?: Logs) {\n    await resolveScalarsInLogs(logs);\n    const current = this.getMonitorValue(logs);\n    if (current == null) {\n      return;\n    }\n\n    if (this.monitorFunc(current - this.minDelta, this.best)) {\n      this.best = current;\n      this.wait = 0;\n      // TODO(cais): Logic for restoreBestWeights.\n    } else {\n      this.wait++;\n      if (this.wait >= this.patience) {\n        this.stoppedEpoch = epoch;\n        this.model.stopTraining = true;\n      }\n      // TODO(cais): Logic for restoreBestWeights.\n    }\n  }\n\n  async onTrainEnd(logs?: Logs) {\n    if (this.stoppedEpoch > 0 && this.verbose) {\n      console.log(`Epoch ${this.stoppedEpoch}: early stopping.`);\n    }\n  }\n\n  private getMonitorValue(logs: Logs) {\n    if (logs == null) {\n      logs = {};\n    }\n    const monitorValue = logs[this.monitor];\n    if (monitorValue == null) {\n      console.warn(\n          `Metric for EarlyStopping ${this.monitor} is not available. 
` +\n          `Available metrics are: ${Object.keys(logs)}`);\n    }\n    return monitorValue;\n  }\n}\n\n/**\n * Factory function for a Callback that stops training when a monitored\n * quantity has stopped improving.\n *\n * Early stopping is a type of regularization, and protects model against\n * overfitting.\n *\n * The following example based on fake data illustrates how this callback\n * can be used during `tf.LayersModel.fit()`:\n *\n * ```js\n * const model = tf.sequential();\n * model.add(tf.layers.dense({\n *   units: 3,\n *   activation: 'softmax',\n *   kernelInitializer: 'ones',\n *   inputShape: [2]\n * }));\n * const xs = tf.tensor2d([1, 2, 3, 4], [2, 2]);\n * const ys = tf.tensor2d([[1, 0, 0], [0, 1, 0]], [2, 3]);\n * const xsVal = tf.tensor2d([4, 3, 2, 1], [2, 2]);\n * const ysVal = tf.tensor2d([[0, 0, 1], [0, 1, 0]], [2, 3]);\n * model.compile(\n *     {loss: 'categoricalCrossentropy', optimizer: 'sgd', metrics: ['acc']});\n *\n * // Without the EarlyStopping callback, the val_acc value would be:\n * //   0.5, 0.5, 0.5, 0.5, ...\n * // With val_acc being monitored, training should stop after the 2nd epoch.\n * const history = await model.fit(xs, ys, {\n *   epochs: 10,\n *   validationData: [xsVal, ysVal],\n *   callbacks: tf.callbacks.earlyStopping({monitor: 'val_acc'})\n * });\n *\n * // Expect to see a length-2 array.\n * console.log(history.history.val_acc);\n * ```\n *\n * @doc {\n *   heading: 'Callbacks',\n *   namespace: 'callbacks'\n * }\n */\nexport function earlyStopping(args?: EarlyStoppingCallbackArgs) {\n  return new EarlyStopping(args);\n}\n\nexport const callbacks = {earlyStopping};\n"]} |
\ | No newline at end of file |