/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { CustomCallback } from '@tensorflow/tfjs';
/**
 * Indirection holder for the progress-bar constructor and log function used
 * by `ProgbarLogger`. Exported (presumably) so tests can swap in mocks for
 * the terminal progress bar and logging — TODO confirm against the
 * implementation file.
 */
export declare const progressBarHelper: {
    ProgressBar: any;
    log: Function;
};
/**
 * Terminal-based progress bar callback for tf.Model.fit().
 *
 * Extends `tf.CustomCallback`; renders a per-epoch progress bar in the
 * terminal during training, throttled by `RENDER_THROTTLE_MS`.
 */
export declare class ProgbarLogger extends CustomCallback {
    private numTrainBatchesPerEpoch;
    private progressBar;
    private currentEpochBegin;
    private epochDurationMillis;
    private usPerStep;
    private batchesInLatestEpoch;
    private terminalWidth;
    private readonly RENDER_THROTTLE_MS;
    /**
     * Constructor of ProgbarLogger.
     */
    constructor();
    private formatLogsAsMetricsContent;
    private isFieldRelevant;
}
/**
 * Get a succinct string representation of a number.
 *
 * Uses decimal notation if the number isn't too small.
 * Otherwise, use engineering notation.
 *
 * @param x Input number.
 * @return Succinct string representing `x`.
 */
export declare function getSuccinctNumberDisplay(x: number): string;
/**
 * Determine the number of decimal places to display.
 *
 * @param x Number to display.
 * @return Number of decimal places to display for `x`.
 */
export declare function getDisplayDecimalPlaces(x: number): number;
/** Configuration options for `TensorBoardCallback` / `tensorBoard()`. */
export interface TensorBoardCallbackArgs {
    /**
     * The frequency at which loss and metric values are written to logs.
     *
     * Currently supported options are:
     *
     * - 'batch': Write logs at the end of every batch of training, in addition
     *   to the end of every epoch of training.
     * - 'epoch': Write logs at the end of every epoch of training.
     *
     * Note that writing logs too often slows down the training.
     *
     * Default: 'epoch'.
     */
    updateFreq?: 'batch' | 'epoch';
}
/**
 * Callback for logging to TensorBoard during training.
 *
 * Users are expected to access this class through the `tensorBoardCallback()`
 * factory method instead.
 */
export declare class TensorBoardCallback extends CustomCallback {
    readonly logdir: string;
    private trainWriter;
    private valWriter;
    private batchesSeen;
    private readonly args;
    /**
     * @param logdir Directory to which logs are written.
     * @param args Optional configuration (see `TensorBoardCallbackArgs`).
     */
    constructor(logdir?: string, args?: TensorBoardCallbackArgs);
    private logMetrics;
    private ensureTrainWriterCreated;
    private ensureValWriterCreated;
}
/**
 * Callback for logging to TensorBoard during training.
 *
 * Writes the loss and metric values (if any) to the specified log directory
 * (`logdir`) which can be ingested and visualized by TensorBoard.
 * This callback is usually passed as a callback to `tf.Model.fit()` or
 * `tf.Model.fitDataset()` calls during model training. The frequency at which
 * the values are logged can be controlled with the `updateFreq` field of the
 * configuration object (2nd argument).
 *
 * Usage example:
 * ```js
 * // Construct a toy multilayer-perceptron regressor for demo purpose.
 * const model = tf.sequential();
 * model.add(
 *     tf.layers.dense({units: 100, activation: 'relu', inputShape: [200]}));
 * model.add(tf.layers.dense({units: 1}));
 * model.compile({
 *   loss: 'meanSquaredError',
 *   optimizer: 'sgd',
 *   metrics: ['MAE']
 * });
 *
 * // Generate some random fake data for demo purpose.
 * const xs = tf.randomUniform([10000, 200]);
 * const ys = tf.randomUniform([10000, 1]);
 * const valXs = tf.randomUniform([1000, 200]);
 * const valYs = tf.randomUniform([1000, 1]);
 *
 * // Start model training process.
 * await model.fit(xs, ys, {
 *   epochs: 100,
 *   validationData: [valXs, valYs],
 *   // Add the tensorBoard callback here.
 *   callbacks: tf.node.tensorBoard('/tmp/fit_logs_1')
 * });
 * ```
 *
 * Then you can use the following commands to point tensorboard
 * to the logdir:
 *
 * ```sh
 * pip install tensorboard  # Unless you've already installed it.
 * tensorboard --logdir /tmp/fit_logs_1
 * ```
 *
 * @param logdir Directory to which the logs will be written.
 * @param args Optional configuration arguments.
 * @returns An instance of `TensorBoardCallback`, which is a subclass of
 *   `tf.CustomCallback`.
 *
 * @doc {heading: 'TensorBoard', namespace: 'node'}
 */
export declare function tensorBoard(logdir?: string, args?: TensorBoardCallbackArgs): TensorBoardCallback;