/**
 * @license
 * Copyright 2018 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
17import { CustomCallback } from '@tensorflow/tfjs';
18export declare const progressBarHelper: {
19 ProgressBar: any;
20 log: Function;
21};
22/**
23 * Terminal-based progress bar callback for tf.Model.fit().
24 */
25export declare class ProgbarLogger extends CustomCallback {
26 private numTrainBatchesPerEpoch;
27 private progressBar;
28 private currentEpochBegin;
29 private epochDurationMillis;
30 private usPerStep;
31 private batchesInLatestEpoch;
32 private terminalWidth;
33 private readonly RENDER_THROTTLE_MS;
34 /**
35 * Construtor of LoggingCallback.
36 */
37 constructor();
38 private formatLogsAsMetricsContent;
39 private isFieldRelevant;
40}
41/**
42 * Get a succint string representation of a number.
43 *
44 * Uses decimal notation if the number isn't too small.
45 * Otherwise, use engineering notation.
46 *
47 * @param x Input number.
48 * @return Succinct string representing `x`.
49 */
50export declare function getSuccinctNumberDisplay(x: number): string;
51/**
52 * Determine the number of decimal places to display.
53 *
54 * @param x Number to display.
55 * @return Number of decimal places to display for `x`.
56 */
57export declare function getDisplayDecimalPlaces(x: number): number;
58export interface TensorBoardCallbackArgs {
59 /**
60 * The frequency at which loss and metric values are written to logs.
61 *
62 * Currently supported options are:
63 *
64 * - 'batch': Write logs at the end of every batch of training, in addition
65 * to the end of every epoch of training.
66 * - 'epoch': Write logs at the end of every epoch of training.
67 *
68 * Note that writing logs too often slows down the training.
69 *
70 * Default: 'epoch'.
71 */
72 updateFreq?: 'batch' | 'epoch';
73}
74/**
75 * Callback for logging to TensorBoard during training.
76 *
77 * Users are expected to access this class through the `tensorBoardCallback()`
78 * factory method instead.
79 */
80export declare class TensorBoardCallback extends CustomCallback {
81 readonly logdir: string;
82 private trainWriter;
83 private valWriter;
84 private batchesSeen;
85 private epochsSeen;
86 private readonly args;
87 constructor(logdir?: string, args?: TensorBoardCallbackArgs);
88 private logMetrics;
89 private ensureTrainWriterCreated;
90 private ensureValWriterCreated;
91}
92/**
93 * Callback for logging to TensorBoard during training.
94 *
95 * Writes the loss and metric values (if any) to the specified log directory
96 * (`logdir`) which can be ingested and visualized by TensorBoard.
97 * This callback is usually passed as a callback to `tf.Model.fit()` or
98 * `tf.Model.fitDataset()` calls during model training. The frequency at which
99 * the values are logged can be controlled with the `updateFreq` field of the
100 * configuration object (2nd argument).
101 *
102 * Usage example:
103 * ```js
104 * // Constructor a toy multilayer-perceptron regressor for demo purpose.
105 * const model = tf.sequential();
106 * model.add(
107 * tf.layers.dense({units: 100, activation: 'relu', inputShape: [200]}));
108 * model.add(tf.layers.dense({units: 1}));
109 * model.compile({
110 * loss: 'meanSquaredError',
111 * optimizer: 'sgd',
112 * metrics: ['MAE']
113 * });
114 *
115 * // Generate some random fake data for demo purpose.
116 * const xs = tf.randomUniform([10000, 200]);
117 * const ys = tf.randomUniform([10000, 1]);
118 * const valXs = tf.randomUniform([1000, 200]);
119 * const valYs = tf.randomUniform([1000, 1]);
120 *
121 * // Start model training process.
122 * await model.fit(xs, ys, {
123 * epochs: 100,
124 * validationData: [valXs, valYs],
125 * // Add the tensorBoard callback here.
126 * callbacks: tf.node.tensorBoard('/tmp/fit_logs_1')
127 * });
128 * ```
129 *
130 * Then you can use the following commands to point tensorboard
131 * to the logdir:
132 *
133 * ```sh
134 * pip install tensorboard # Unless you've already installed it.
135 * tensorboard --logdir /tmp/fit_logs_1
136 * ```
137 *
138 * @param logdir Directory to which the logs will be written.
139 * @param args Optional configuration arguments.
140 * @returns An instance of `TensorBoardCallback`, which is a subclass of
141 * `tf.CustomCallback`.
142 */
143/**
144 * @doc {heading: 'TensorBoard', namespace: 'node'}
145 */
146export declare function tensorBoard(logdir?: string, args?: TensorBoardCallbackArgs): TensorBoardCallback;