1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 | import { Tensor } from '@tensorflow/tfjs-core';
|
12 | import { Container } from './engine/container';
|
13 | import { Logs, UnresolvedLogs } from './logs';
|
14 |
|
/**
 * Numeric verbosity levels: `SILENT` is 0 and `VERBOSE` is 1.
 *
 * Used in this file as the type of the `verbose` parameter of
 * `configureCallbacks`.
 */
15 | export declare enum ModelLoggingVerbosity {
|
16 | SILENT = 0,
|
17 | VERBOSE = 1
|
18 | }
|
19 |
|
/**
 * Exported constant declared with the literal value 125.
 *
 * NOTE(review): judging by the name, this is presumably the default yield
 * interval in milliseconds for `YieldEveryOptions` — confirm against the
 * implementation.
 */
20 | export declare const DEFAULT_YIELD_EVERY_MS = 125;
|
/**
 * String-keyed dictionary whose values are a number, string, or boolean, or an
 * array of one of those primitive types.
 *
 * This is the type accepted by `setParams` on `BaseCallback` and
 * `CallbackList`.
 */
21 | export type Params = {
|
22 | [key: string]: number | string | boolean | number[] | string[] | boolean[];
|
23 | };
|
/**
 * Accepted yield settings: one of the string literals `'auto'`, `'batch'`,
 * `'epoch'`, `'never'`, or a number.
 *
 * NOTE(review): a numeric value is presumably an interval in milliseconds
 * (compare `DEFAULT_YIELD_EVERY_MS`) — confirm against the implementation.
 */
24 | export type YieldEveryOptions = 'auto' | 'batch' | 'epoch' | 'never' | number;
|
25 |
|
26 |
|
27 |
|
28 |
|
29 |
|
30 |
|
31 |
|
32 |
|
33 |
|
34 |
|
35 |
|
36 |
|
37 |
|
38 |
|
39 |
|
40 |
|
41 |
|
42 |
|
/**
 * Abstract base class for callbacks.
 *
 * Declares two public properties (`validationData`, `params`), two setters
 * (`setParams`, `setModel`), and six asynchronous hooks for the beginning and
 * end of training, epochs, and batches. Every hook takes an optional
 * `UnresolvedLogs` argument and returns `Promise<void>`.
 */
43 | export declare abstract class BaseCallback {
|
/** Validation data held by the callback: a single `Tensor` or an array of them. */
44 | validationData: Tensor | Tensor[];
|
45 | |
46 |
|
47 |
|
/** Parameter dictionary; `setParams` accepts the same `Params` type. */
48 | params: Params;
|
49 | setParams(params: Params): void;
|
/** Epoch-level hooks; `epoch` is the epoch index. */
50 | onEpochBegin(epoch: number, logs?: UnresolvedLogs): Promise<void>;
|
51 | onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
|
/** Batch-level hooks; `batch` is the batch index. */
52 | onBatchBegin(batch: number, logs?: UnresolvedLogs): Promise<void>;
|
53 | onBatchEnd(batch: number, logs?: UnresolvedLogs): Promise<void>;
|
/** Training-level hooks; they take only the optional logs. */
54 | onTrainBegin(logs?: UnresolvedLogs): Promise<void>;
|
55 | onTrainEnd(logs?: UnresolvedLogs): Promise<void>;
|
/** Receives a `Container`; this declaration does not expose where it is stored. */
56 | setModel(model: Container): void;
|
57 | }
|
58 |
|
59 |
|
60 |
|
/**
 * Holds an array of `BaseCallback`s and exposes the same setter and hook
 * methods as `BaseCallback`.
 *
 * NOTE(review): each method presumably forwards to every entry in `callbacks`;
 * the declaration alone does not show this — confirm against the
 * implementation.
 */
61 | export declare class CallbackList {
|
/** The callbacks held by this list. */
62 | callbacks: BaseCallback[];
|
/** NOTE(review): meaning of the queue length is not visible here — confirm. */
63 | queueLength: number;
|
64 | |
65 |
|
66 |
|
67 |
|
68 |
|
69 |
|
/** Both arguments are optional; defaults are not visible in this declaration. */
70 | constructor(callbacks?: BaseCallback[], queueLength?: number);
|
/** Takes a single callback; presumably adds it to `callbacks` — confirm. */
71 | append(callback: BaseCallback): void;
|
72 | setParams(params: Params): void;
|
73 | setModel(model: Container): void;
|
74 | /**
|
75 | * Called at the start of an epoch.
|
76 | * @param epoch Index of epoch.
|
77 | * @param logs Dictionary of logs.
|
78 | */
|
79 | onEpochBegin(epoch: number, logs?: UnresolvedLogs): Promise<void>;
|
80 | /**
|
81 | * Called at the end of an epoch.
|
82 | * @param epoch Index of epoch.
|
83 | * @param logs Dictionary of logs.
|
84 | */
|
85 | onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
|
86 | /**
|
87 | * Called right before processing a batch.
|
88 | * @param batch Index of batch within the current epoch.
|
89 | * @param logs Dictionary of logs.
|
90 | */
|
91 | onBatchBegin(batch: number, logs?: UnresolvedLogs): Promise<void>;
|
92 | /**
|
93 | * Called at the end of a batch.
|
94 | * @param batch Index of batch within the current epoch.
|
95 | * @param logs Dictionary of logs.
|
96 | */
|
97 | onBatchEnd(batch: number, logs?: UnresolvedLogs): Promise<void>;
|
98 | /**
|
99 | * Called at the beginning of training.
|
100 | * @param logs Dictionary of logs.
|
101 | */
|
102 | onTrainBegin(logs?: UnresolvedLogs): Promise<void>;
|
103 | /**
|
104 | * Called at the end of training.
|
105 | * @param logs Dictionary of logs.
|
106 | */
|
107 | onTrainEnd(logs?: UnresolvedLogs): Promise<void>;
|
108 | }
|
109 | /**
|
110 | * Callback that accumulates epoch averages of metrics.
|
111 | *
|
112 | * This callback is automatically applied to every LayersModel.
|
 *
 * Redeclares three hooks from `BaseCallback`: `onEpochBegin`, `onBatchEnd`,
 * and `onEpochEnd`.
113 | */
|
114 | export declare class BaseLogger extends BaseCallback {
|
/** Private state; types are not exposed in this declaration. */
115 | private seen;
|
116 | private totals;
|
117 | constructor();
|
/** Unlike the base signature, this override declares no `logs` parameter. */
118 | onEpochBegin(epoch: number): Promise<void>;
|
119 | onBatchEnd(batch: number, logs?: UnresolvedLogs): Promise<void>;
|
120 | onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
|
121 | }
|
122 | /**
|
123 | * Callback that records events into a `History` object. This callback is
|
124 | * automatically applied to every TF.js Layers model. The `History` object
|
125 | * gets returned by the `fit` method of models.
|
 *
 * Redeclares the `onTrainBegin` and `onEpochEnd` hooks and adds `syncData`.
126 | */
|
127 | export declare class History extends BaseCallback {
|
/** Array of numbers; presumably the recorded epoch indices — confirm. */
128 | epoch: number[];
|
/** String-keyed map where each entry is an array of numbers and/or `Tensor`s. */
129 | history: {
|
130 | [key: string]: Array<number | Tensor>;
|
131 | };
|
132 | onTrainBegin(logs?: UnresolvedLogs): Promise<void>;
|
133 | onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
|
134 | |
135 |
|
136 |
|
/**
 * Asynchronous and takes no arguments.
 * NOTE(review): presumably converts `Tensor` entries in `history` to numbers
 * — confirm against the implementation.
 */
137 | syncData(): Promise<void>;
|
138 | }
|
/**
 * Optional handler functions used to build a `CustomCallback` (see its
 * constructor). Every member is optional, and every handler may return either
 * `void` or `Promise<void>`.
 */
139 | export interface CustomCallbackArgs {
|
140 | onTrainBegin?: (logs?: Logs) => void | Promise<void>;
|
141 | onTrainEnd?: (logs?: Logs) => void | Promise<void>;
|
142 | onEpochBegin?: (epoch: number, logs?: Logs) => void | Promise<void>;
|
143 | onEpochEnd?: (epoch: number, logs?: Logs) => void | Promise<void>;
|
144 | onBatchBegin?: (batch: number, logs?: Logs) => void | Promise<void>;
|
145 | onBatchEnd?: (batch: number, logs?: Logs) => void | Promise<void>;
|
/** Unlike the other handlers, `logs` is a required parameter here. */
146 | onYield?: (epoch: number, batch: number, logs: Logs) => void | Promise<void>;
|
/**
 * Typed only as `Function`, so call signatures are unchecked.
 * NOTE(review): presumably injectable replacements for the clock and the
 * frame scheduler — confirm against the implementation.
 */
147 | nowFunc?: Function;
|
148 | nextFrameFunc?: Function;
|
149 | }
|
150 |
|
151 |
|
152 |
|
/**
 * Callback constructed from a `CustomCallbackArgs` object and an optional
 * `YieldEveryOptions` value.
 *
 * Its protected readonly fields have the same signatures as the handlers in
 * `CustomCallbackArgs`. It redeclares all six `BaseCallback` hooks and adds
 * `maybeWait`.
 */
153 | export declare class CustomCallback extends BaseCallback {
|
/** Handler fields; signatures match the `on*` members of `CustomCallbackArgs`. */
154 | protected readonly trainBegin: (logs?: Logs) => void | Promise<void>;
|
155 | protected readonly trainEnd: (logs?: Logs) => void | Promise<void>;
|
156 | protected readonly epochBegin: (epoch: number, logs?: Logs) => void | Promise<void>;
|
157 | protected readonly epochEnd: (epoch: number, logs?: Logs) => void | Promise<void>;
|
158 | protected readonly batchBegin: (batch: number, logs?: Logs) => void | Promise<void>;
|
159 | protected readonly batchEnd: (batch: number, logs?: Logs) => void | Promise<void>;
|
160 | protected readonly yield: (epoch: number, batch: number, logs: Logs) => void | Promise<void>;
|
/** Private state; types are not exposed in this declaration. */
161 | private yieldEvery;
|
162 | private currentEpoch;
|
/** Public and typed as `Function`, matching the optional members of `CustomCallbackArgs`. */
163 | nowFunc: Function;
|
164 | nextFrameFunc: Function;
|
/** `yieldEvery` is optional; its default is not visible in this declaration. */
165 | constructor(args: CustomCallbackArgs, yieldEvery?: YieldEveryOptions);
|
/**
 * Asynchronous; all three arguments are required.
 * NOTE(review): presumably decides, based on `yieldEvery`, whether to yield
 * before continuing — confirm against the implementation.
 */
166 | maybeWait(epoch: number, batch: number, logs: UnresolvedLogs): Promise<void>;
|
167 | onEpochBegin(epoch: number, logs?: UnresolvedLogs): Promise<void>;
|
168 | onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
|
169 | onBatchBegin(batch: number, logs?: UnresolvedLogs): Promise<void>;
|
170 | onBatchEnd(batch: number, logs?: UnresolvedLogs): Promise<void>;
|
171 | onTrainBegin(logs?: UnresolvedLogs): Promise<void>;
|
172 | onTrainEnd(logs?: UnresolvedLogs): Promise<void>;
|
173 | }
|
174 | /**
|
175 | * Standardize callbacks or configurations of them to an Array of callbacks.
|
 *
 * @param callbacks A single `BaseCallback` or `CustomCallbackArgs`, or an
 *   array of either.
 * @param yieldEvery A `YieldEveryOptions` value (required here).
 *   NOTE(review): presumably forwarded to the `CustomCallback` constructor
 *   when wrapping `CustomCallbackArgs` — confirm against the implementation.
 * @returns An array of `BaseCallback`.
176 | */
|
177 | export declare function standardizeCallbacks(callbacks: BaseCallback | BaseCallback[] | CustomCallbackArgs | CustomCallbackArgs[], yieldEvery: YieldEveryOptions): BaseCallback[];
|
/**
 * Type of a constructor that takes no arguments and produces a `BaseCallback`.
 *
 * This is the type accepted by
 * `CallbackConstructorRegistry.registerCallbackConstructor`.
 */
178 | export declare type BaseCallbackConstructor = {
|
179 | new (): BaseCallback;
|
180 | };
|
181 |
|
182 |
|
183 |
|
184 |
|
/**
 * Static registry of callback constructors.
 *
 * The constructor is declared `private`, so the class cannot be instantiated
 * from outside; the registry is used only through its static methods.
 */
185 | export declare class CallbackConstructorRegistry {
|
/** Private static storage; its type is not exposed in this declaration. */
186 | private static constructors;
|
187 | |
188 |
|
189 |
|
190 | private constructor();
|
191 | /**
|
192 | * Register a tf.LayersModel.fit() callback constructor.
|
193 | *
|
194 | * The registered callback constructor will be used to instantiate
|
195 | * callbacks for every tf.LayersModel.fit() call afterwards.
|
196 | *
|
197 | * @param verbosityLevel Level of verbosity at which the `callbackConstructor`
|
198 | * is to be registered.
|
199 | * @param callbackConstructor A no-arg constructor for `tf.Callback`.
|
200 | * @throws Error, if the same callbackConstructor has been registered before,
|
201 | * either at the same or a different `verbosityLevel`.
|
202 | */
|
203 | static registerCallbackConstructor(verbosityLevel: number, callbackConstructor: BaseCallbackConstructor): void;
|
/** Private static member; presumably implements the duplicate check described above — confirm. */
204 | private static checkForDuplicate;
|
205 | /**
|
206 | * Clear all registered callback constructors.
|
207 | */
|
208 | protected static clear(): void;
|
209 | /**
|
210 | * Create callbacks using the registered callback constructors.
|
211 | *
|
212 | * Given `verbosityLevel`, all constructors registered at that level or above
|
213 | * will be called and the instantiated callbacks will be used.
|
214 | *
|
215 | * @param verbosityLevel Level of verbosity.
|
 * @returns An array of `BaseCallback` instances.
216 | */
|
217 | static createCallbacks(verbosityLevel: number): BaseCallback[];
|
218 | }
|
/**
 * Takes an array of callbacks plus training configuration values and returns
 * an object holding a `CallbackList` and a `History`.
 *
 * NOTE(review): parameter meanings below are read from their names and types
 * only — confirm against the implementation.
 *
 * @param callbacks Array of `BaseCallback`.
 * @param verbose A `ModelLoggingVerbosity` value.
 * @param epochs Number; presumably the number of epochs.
 * @param initialEpoch Number; presumably the starting epoch index.
 * @param numTrainSamples Number; presumably the training sample count.
 * @param stepsPerEpoch Number; presumably the steps in one epoch.
 * @param batchSize Number; presumably the batch size.
 * @param doValidation Boolean; presumably whether validation is performed.
 * @param callbackMetrics Array of strings; presumably metric names.
 * @returns An object with `callbackList` and `history` properties.
 */
219 | export declare function configureCallbacks(callbacks: BaseCallback[], verbose: ModelLoggingVerbosity, epochs: number, initialEpoch: number, numTrainSamples: number, stepsPerEpoch: number, batchSize: number, doValidation: boolean, callbackMetrics: string[]): {
|
220 | callbackList: CallbackList;
|
221 | history: History;
|
222 | };
|