/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { CustomCallback, LayersModel } from '@tensorflow/tfjs';
/**
 * Module-level holder for the progress-bar constructor and log function.
 *
 * NOTE(review): `ProgressBar: any` and `log: Function` are deliberately loose
 * types — presumably exposed as a mutable object so tests can stub out the
 * terminal progress bar and logging; confirm against the implementation
 * before tightening them.
 */
export declare const progressBarHelper: {
    ProgressBar: any;
    log: Function;
};
/**
 * Terminal-based progress bar callback for tf.Model.fit().
 */
export declare class ProgbarLogger extends CustomCallback {
    private numTrainBatchesPerEpoch;
    private progressBar;
    private currentEpochBegin;
    private epochDurationMillis;
    private usPerStep;
    private batchesInLatestEpoch;
    private terminalWidth;
    // Minimum interval between progress-bar re-renders; the `_MS` suffix
    // indicates milliseconds — confirm in the implementation.
    private readonly RENDER_THROTTLE_MS;
    /**
     * Constructor of ProgbarLogger.
     */
    constructor();
    private formatLogsAsMetricsContent;
    private isFieldRelevant;
}
/**
 * Get a succinct string representation of a number.
 *
 * Uses decimal notation if the number isn't too small.
 * Otherwise, uses engineering notation.
 *
 * @param x Input number.
 * @return Succinct string representing `x`.
 */
export declare function getSuccinctNumberDisplay(x: number): string;
/**
 * Determine the number of decimal places to display.
 *
 * @param x Number to display.
 * @return Number of decimal places to display for `x`.
 */
export declare function getDisplayDecimalPlaces(x: number): number;
/** Configuration options for `TensorBoardCallback` / `tensorBoard()`. */
export interface TensorBoardCallbackArgs {
    /**
     * The frequency at which loss and metric values are written to logs.
     *
     * Currently supported options are:
     *
     * - 'batch': Write logs at the end of every batch of training, in addition
     *   to the end of every epoch of training.
     * - 'epoch': Write logs at the end of every epoch of training.
     *
     * Note that writing logs too often slows down the training.
     *
     * Default: 'epoch'.
     */
    updateFreq?: 'batch' | 'epoch';
    /**
     * The frequency (in epochs) at which to compute activation and weight
     * histograms for the layers of the model.
     *
     * If set to 0, histograms won't be computed.
     *
     * Validation data (or split) must be specified for histogram visualizations.
     *
     * Default: 0.
     */
    histogramFreq?: number;
}
/**
 * Callback for logging to TensorBoard during training.
 *
 * Users are expected to access this class through the `tensorBoardCallback()`
 * factory method instead.
 */
export declare class TensorBoardCallback extends CustomCallback {
    readonly logdir: string;
    private model;
    // Separate summary writers for training and validation metrics.
    private trainWriter;
    private valWriter;
    private batchesSeen;
    private readonly args;
    constructor(logdir?: string, args?: TensorBoardCallbackArgs);
    setModel(model: LayersModel): void;
    private logMetrics;
    private logWeights;
    private ensureTrainWriterCreated;
    private ensureValWriterCreated;
}
/**
 * Callback for logging to TensorBoard during training.
 *
 * Writes the loss and metric values (if any) to the specified log directory
 * (`logdir`) which can be ingested and visualized by TensorBoard.
 * This callback is usually passed as a callback to `tf.Model.fit()` or
 * `tf.Model.fitDataset()` calls during model training. The frequency at which
 * the values are logged can be controlled with the `updateFreq` field of the
 * configuration object (2nd argument).
 *
 * Usage example:
 * ```js
 * // Construct a toy multilayer-perceptron regressor for demo purpose.
 * const model = tf.sequential();
 * model.add(
 *     tf.layers.dense({units: 100, activation: 'relu', inputShape: [200]}));
 * model.add(tf.layers.dense({units: 1}));
 * model.compile({
 *   loss: 'meanSquaredError',
 *   optimizer: 'sgd',
 *   metrics: ['MAE']
 * });
 *
 * // Generate some random fake data for demo purpose.
 * const xs = tf.randomUniform([10000, 200]);
 * const ys = tf.randomUniform([10000, 1]);
 * const valXs = tf.randomUniform([1000, 200]);
 * const valYs = tf.randomUniform([1000, 1]);
 *
 * // Start model training process.
 * await model.fit(xs, ys, {
 *   epochs: 100,
 *   validationData: [valXs, valYs],
 *   // Add the tensorBoard callback here.
 *   callbacks: tf.node.tensorBoard('/tmp/fit_logs_1')
 * });
 * ```
 *
 * Then you can use the following commands to point tensorboard
 * to the logdir:
 *
 * ```sh
 * pip install tensorboard  # Unless you've already installed it.
 * tensorboard --logdir /tmp/fit_logs_1
 * ```
 *
 * @param logdir Directory to which the logs will be written.
 * @param args Optional configuration arguments.
 * @returns An instance of `TensorBoardCallback`, which is a subclass of
 *   `tf.CustomCallback`.
 *
 * @doc {heading: 'TensorBoard', namespace: 'node'}
 */
export declare function tensorBoard(logdir?: string, args?: TensorBoardCallbackArgs): TensorBoardCallback;