UNPKG

9.1 kBTypeScriptView Raw
1/**
2 * @license
3 * Copyright 2018 Google LLC
4 *
5 * Use of this source code is governed by an MIT-style
6 * license that can be found in the LICENSE file or at
7 * https://opensource.org/licenses/MIT.
8 * =============================================================================
9 */
10/// <amd-module name="@tensorflow/tfjs-layers/dist/base_callbacks" />
11import { Tensor } from '@tensorflow/tfjs-core';
12import { Container } from './engine/container';
13import { Logs, UnresolvedLogs } from './logs';
14/**
 * Verbosity logging level when fitting a model.
 *
 * Declared as a numeric enum: `SILENT` is 0 and `VERBOSE` is 1.
 */
15export declare enum ModelLoggingVerbosity {
16 SILENT = 0,
17 VERBOSE = 1
18}
19/**
 * How often to yield to the main thread when training (in ms).
 *
 * NOTE(review): presumably the fallback used when no explicit yield interval
 * is supplied; the consumer is not visible in this declaration file. Confirm.
 */
20export declare const DEFAULT_YIELD_EVERY_MS = 125;
/**
 * String-keyed dictionary of parameter values. Each value is a number, a
 * string or a boolean, or an array of one of those. This is the type of
 * `BaseCallback.params` and of the argument accepted by `setParams`.
 */
21export declare type Params = {
22 [key: string]: number | string | boolean | number[] | string[] | boolean[];
23};
/**
 * Accepted values for a "yield every" setting: one of the string modes
 * `'auto'`, `'batch'`, `'epoch'`, `'never'`, or a number.
 *
 * NOTE(review): a numeric value is presumably an interval in milliseconds
 * (cf. `DEFAULT_YIELD_EVERY_MS`); confirm against the implementation.
 */
24export declare type YieldEveryOptions = 'auto' | 'batch' | 'epoch' | 'never' | number;
25/**
26 * Abstract base class used to build new callbacks.
27 *
28 * The `logs` dictionary that callback methods take as argument will contain
29 * keys for quantities relevant to the current batch or epoch.
30 *
31 * Currently, the `.fit()` method of the `Sequential` model class
32 * will include the following quantities in the `logs` that
33 * it passes to its callbacks:
34 *
35 * onEpochEnd: Logs include `acc` and `loss`, and optionally include `valLoss`
36 * (if validation is enabled in `fit`), and `valAcc` (if validation and
37 * accuracy monitoring are enabled).
38 * onBatchBegin: Logs include `size`, the number of samples in the current
39 * batch.
40 * onBatchEnd: Logs include `loss`, and optionally `acc` (if accuracy monitoring
41 * is enabled).
42 */
43export declare abstract class BaseCallback {
/** Validation data, typed as a single `Tensor` or an array of `Tensor`s. */
44 validationData: Tensor | Tensor[];
45 /**
46 * Training parameters (e.g. verbosity, batch size, number of epochs...).
47 */
48 params: Params;
/**
 * Supplies the training parameters.
 * NOTE(review): presumably assigned to `params`; confirm in the implementation.
 */
49 setParams(params: Params): void;
/*
 * Hook methods. Each takes an optional `logs` argument (plus an epoch or
 * batch index where declared) and returns a promise that resolves with no
 * value. The class-level documentation lists the log keys passed to
 * `onEpochEnd`, `onBatchBegin` and `onBatchEnd`.
 */
50 onEpochBegin(epoch: number, logs?: UnresolvedLogs): Promise<void>;
51 onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
52 onBatchBegin(batch: number, logs?: UnresolvedLogs): Promise<void>;
53 onBatchEnd(batch: number, logs?: UnresolvedLogs): Promise<void>;
54 onTrainBegin(logs?: UnresolvedLogs): Promise<void>;
55 onTrainEnd(logs?: UnresolvedLogs): Promise<void>;
/**
 * Supplies the `Container` this callback is associated with.
 * NOTE(review): how the model is stored is not visible in this declaration.
 */
56 setModel(model: Container): void;
57}
58/**
59 * Container abstracting a list of callbacks.
60 */
61export declare class CallbackList {
/** The callbacks held by this list. */
62 callbacks: BaseCallback[];
/**
 * Queue length for keeping running statistics over callback execution time
 * (see the constructor's `queueLength` parameter).
 */
63 queueLength: number;
64 /**
65 * Constructor of CallbackList.
66 * @param callbacks Array of `Callback` instances.
67 * @param queueLength Queue length for keeping running statistics over
68 * callback execution time.
 *
 * Both parameters are optional; the defaults applied when they are omitted
 * are not visible in this declaration file.
69 */
70 constructor(callbacks?: BaseCallback[], queueLength?: number);
/**
 * Adds a single callback.
 * NOTE(review): presumably appended to `callbacks`; confirm.
 */
71 append(callback: BaseCallback): void;
/**
 * Supplies training parameters.
 * NOTE(review): presumably forwarded to every callback in the list; confirm.
 */
72 setParams(params: Params): void;
/**
 * Supplies the model (`Container`).
 * NOTE(review): presumably forwarded to every callback in the list; confirm.
 */
73 setModel(model: Container): void;
74 /**
75 * Called at the start of an epoch.
76 * @param epoch Index of epoch.
77 * @param logs Dictionary of logs.
78 */
79 onEpochBegin(epoch: number, logs?: UnresolvedLogs): Promise<void>;
80 /**
81 * Called at the end of an epoch.
82 * @param epoch Index of epoch.
83 * @param logs Dictionary of logs.
84 */
85 onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
86 /**
87 * Called right before processing a batch.
88 * @param batch Index of batch within the current epoch.
89 * @param logs Dictionary of logs.
90 */
91 onBatchBegin(batch: number, logs?: UnresolvedLogs): Promise<void>;
92 /**
93 * Called at the end of a batch.
94 * @param batch Index of batch within the current epoch.
95 * @param logs Dictionary of logs.
96 */
97 onBatchEnd(batch: number, logs?: UnresolvedLogs): Promise<void>;
98 /**
99 * Called at the beginning of training.
100 * @param logs Dictionary of logs.
101 */
102 onTrainBegin(logs?: UnresolvedLogs): Promise<void>;
103 /**
104 * Called at the end of training.
105 * @param logs Dictionary of logs.
106 */
107 onTrainEnd(logs?: UnresolvedLogs): Promise<void>;
108}
109/**
110 * Callback that accumulates epoch averages of metrics.
111 *
112 * This callback is automatically applied to every LayersModel.
113 */
114export declare class BaseLogger extends BaseCallback {
/*
 * Private state; the declaration file does not expose the types.
 * NOTE(review): given the class purpose (accumulating epoch averages),
 * `seen` is presumably a running sample count and `totals` the running
 * per-metric sums. Confirm in the implementation.
 */
115 private seen;
116 private totals;
/** Takes no arguments. */
117 constructor();
/** Overrides the base hook; this override declares no `logs` parameter. */
118 onEpochBegin(epoch: number): Promise<void>;
/** Overrides the base hook with the same signature. */
119 onBatchEnd(batch: number, logs?: UnresolvedLogs): Promise<void>;
/** Overrides the base hook with the same signature. */
120 onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
121}
122/**
123 * Callback that records events into a `History` object. This callback is
124 * automatically applied to every TF.js Layers model. The `History` object
125 * gets returned by the `fit` method of models.
126 */
127export declare class History extends BaseCallback {
/**
 * List of numbers.
 * NOTE(review): presumably the indices of the epochs recorded so far; confirm.
 */
128 epoch: number[];
/**
 * String-keyed record of value series; each entry is an array whose
 * elements are either numbers or `Tensor`s.
 */
129 history: {
130 [key: string]: Array<number | Tensor>;
131 };
/** Overrides the base hook with the same signature. */
132 onTrainBegin(logs?: UnresolvedLogs): Promise<void>;
/** Overrides the base hook with the same signature. */
133 onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
134 /**
135 * Await the values of all losses and metrics.
 *
 * @returns A promise that resolves with no value.
136 */
137 syncData(): Promise<void>;
138}
/**
 * Optional handlers and helper functions accepted by the `CustomCallback`
 * constructor. Every member is optional. Each `on*` handler may return either
 * nothing or a promise.
 */
139export interface CustomCallbackArgs {
140 onTrainBegin?: (logs?: Logs) => void | Promise<void>;
141 onTrainEnd?: (logs?: Logs) => void | Promise<void>;
142 onEpochBegin?: (epoch: number, logs?: Logs) => void | Promise<void>;
143 onEpochEnd?: (epoch: number, logs?: Logs) => void | Promise<void>;
144 onBatchBegin?: (batch: number, logs?: Logs) => void | Promise<void>;
145 onBatchEnd?: (batch: number, logs?: Logs) => void | Promise<void>;
/** Unlike the other handlers, `logs` is a required argument here. */
146 onYield?: (epoch: number, batch: number, logs: Logs) => void | Promise<void>;
/*
 * Loosely typed as `Function`; the expected signatures are not visible here.
 * NOTE(review): presumably a clock source and a "wait for next frame"
 * function respectively; confirm against the implementation.
 */
147 nowFunc?: Function;
148 nextFrameFunc?: Function;
149}
150/**
151 * Custom callback for training.
152 */
153export declare class CustomCallback extends BaseCallback {
/*
 * Handler fields. Their types match the corresponding `on*` members of
 * `CustomCallbackArgs` (e.g. `trainBegin` matches `onTrainBegin`).
 * NOTE(review): presumably populated from the constructor's `args`; confirm.
 */
154 protected readonly trainBegin: (logs?: Logs) => void | Promise<void>;
155 protected readonly trainEnd: (logs?: Logs) => void | Promise<void>;
156 protected readonly epochBegin: (epoch: number, logs?: Logs) => void | Promise<void>;
157 protected readonly epochEnd: (epoch: number, logs?: Logs) => void | Promise<void>;
158 protected readonly batchBegin: (batch: number, logs?: Logs) => void | Promise<void>;
159 protected readonly batchEnd: (batch: number, logs?: Logs) => void | Promise<void>;
160 protected readonly yield: (epoch: number, batch: number, logs: Logs) => void | Promise<void>;
/* Private state; the declaration file does not expose the types. */
161 private yieldEvery;
162 private currentEpoch;
/* Same names and `Function` type as the optional members of `CustomCallbackArgs`. */
163 nowFunc: Function;
164 nextFrameFunc: Function;
/**
 * @param args Handlers and helper functions; see `CustomCallbackArgs`.
 * @param yieldEvery Optional yield setting; the default applied when it is
 *   omitted is not visible in this declaration file.
 */
165 constructor(args: CustomCallbackArgs, yieldEvery?: YieldEveryOptions);
/**
 * Takes the epoch index, batch index and logs (all required) and returns a
 * promise that resolves with no value.
 * NOTE(review): the name suggests it conditionally yields according to
 * `yieldEvery`; the body is not visible here, so confirm.
 */
166 maybeWait(epoch: number, batch: number, logs: UnresolvedLogs): Promise<void>;
/* Overrides of the `BaseCallback` hooks, with unchanged signatures. */
167 onEpochBegin(epoch: number, logs?: UnresolvedLogs): Promise<void>;
168 onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
169 onBatchBegin(batch: number, logs?: UnresolvedLogs): Promise<void>;
170 onBatchEnd(batch: number, logs?: UnresolvedLogs): Promise<void>;
171 onTrainBegin(logs?: UnresolvedLogs): Promise<void>;
172 onTrainEnd(logs?: UnresolvedLogs): Promise<void>;
173}
174/**
175 * Standardize callbacks or configurations of them to an Array of callbacks.
 *
 * @param callbacks A single `BaseCallback`, an array of them, a single
 *   `CustomCallbackArgs` configuration, or an array of configurations.
 * @param yieldEvery Yield setting; see `YieldEveryOptions`.
 *   NOTE(review): presumably passed to any `CustomCallback` built from a
 *   configuration; confirm in the implementation.
 * @returns An array of `BaseCallback` instances.
176 */
177export declare function standardizeCallbacks(callbacks: BaseCallback | BaseCallback[] | CustomCallbackArgs | CustomCallbackArgs[], yieldEvery: YieldEveryOptions): BaseCallback[];
/**
 * Constructor type: callable with `new` and no arguments, producing a
 * `BaseCallback`. This is the type accepted by
 * `CallbackConstructorRegistry.registerCallbackConstructor`.
 */
178export declare type BaseCallbackConstructor = {
179 new (): BaseCallback;
180};
181/**
182 * A global registry for callback constructors to be used during
183 * LayersModel.fit().
184 */
185export declare class CallbackConstructorRegistry {
/* Private static storage; the declaration file does not expose its type. */
186 private static constructors;
187 /**
188 * Blocks public access to constructor.
189 */
190 private constructor();
191 /**
192 * Register a tf.LayersModel.fit() callback constructor.
193 *
194 * The registered callback constructor will be used to instantiate
195 * callbacks for every tf.LayersModel.fit() call afterwards.
196 *
197 * @param verbosityLevel Level of verbosity at which the `callbackConstructor`
198 * is to be registered.
199 * @param callbackConstructor A no-arg constructor for `tf.Callback`.
200 * @throws Error, if the same callbackConstructor has been registered before,
201 * either at the same or a different `verbosityLevel`.
202 */
203 static registerCallbackConstructor(verbosityLevel: number, callbackConstructor: BaseCallbackConstructor): void;
/*
 * Private static helper; its signature is not exposed in this declaration.
 * NOTE(review): presumably implements the duplicate-registration check
 * described under `@throws` above; confirm.
 */
204 private static checkForDuplicate;
205 /**
206 * Clear all registered callback constructors.
207 */
208 protected static clear(): void;
209 /**
210 * Create callbacks using the registered callback constructors.
211 *
212 * Given `verbosityLevel`, all constructors registered at that level or above
213 * will be called and the instantiated callbacks will be used.
214 *
215 * @param verbosityLevel Level of verbosity.
 * @returns The instantiated callbacks.
216 */
217 static createCallbacks(verbosityLevel: number): BaseCallback[];
218}
/**
 * Takes a list of callbacks plus training settings and returns a
 * `CallbackList` together with a `History` object.
 *
 * NOTE(review): only the signature is visible here. The parameter
 * descriptions below are read off the names and types and should be
 * confirmed against the implementation.
 *
 * @param callbacks Array of `BaseCallback` instances.
 * @param verbose Logging verbosity; see `ModelLoggingVerbosity`.
 * @param epochs Number (presumably the number of epochs to train).
 * @param initialEpoch Number (presumably the epoch index to start from).
 * @param numTrainSamples Number (presumably the training sample count).
 * @param stepsPerEpoch Number (presumably batches per epoch).
 * @param batchSize Number (presumably samples per batch).
 * @param doValidation Boolean (presumably whether validation is performed).
 * @param callbackMetrics Array of strings (presumably metric names).
 * @returns An object with `callbackList` and `history` properties.
 */
219export declare function configureCallbacks(callbacks: BaseCallback[], verbose: ModelLoggingVerbosity, epochs: number, initialEpoch: number, numTrainSamples: number, stepsPerEpoch: number, batchSize: number, doValidation: boolean, callbackMetrics: string[]): {
220 callbackList: CallbackList;
221 history: History;
222};