// UNPKG
//
// 17.2 kB | JavaScript | View Raw
"use strict";
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * TypeScript `__extends` helper: ES5-style class inheritance.
 *
 * An `__extends` already present on `this` is reused. Otherwise:
 * - static members are inherited from `b`, via `Object.setPrototypeOf`, via
 *   `__proto__`, or (as a last resort) by copying `b`'s own enumerable
 *   properties onto `d`;
 * - `d.prototype` becomes an object inheriting from `b.prototype` (or from
 *   `null` when `b` is `null`) whose `constructor` is `d`.
 *
 * @param d Derived constructor.
 * @param b Base constructor, or `null`.
 * @throws TypeError if `b` is neither a function nor `null`.
 */
var __extends = (this && this.__extends) || (function () {
    var extendStatics = function (d, b) {
        // Pick the best available strategy once, then replace this function.
        extendStatics = Object.setPrototypeOf ||
            ({ __proto__: [] } instanceof Array && function (d, b) { d.__proto__ = b; }) ||
            function (d, b) {
                // Call hasOwnProperty through Object.prototype so that a base
                // which shadows it (or has a null prototype) cannot break the copy.
                for (var p in b) if (Object.prototype.hasOwnProperty.call(b, p)) d[p] = b[p];
            };
        return extendStatics(d, b);
    };
    return function (d, b) {
        // Fail early with a readable message when the base is missing (e.g. an
        // undefined import) instead of an opaque setPrototypeOf error.
        if (typeof b !== "function" && b !== null)
            throw new TypeError("Class extends value " + String(b) + " is not a constructor or null");
        extendStatics(d, b);
        function __() { this.constructor = d; }
        d.prototype = b === null ? Object.create(b) : (__.prototype = b.prototype, new __());
    };
})();
/**
 * TypeScript `__awaiter` helper: runs `generator` as an async function.
 *
 * Each yielded value is awaited. Its fulfillment value is sent back into the
 * generator with `next()`, and a rejection reason is thrown into it with
 * `throw()`. The returned promise resolves with the generator's return value,
 * or rejects with any error that escapes the generator.
 * An `__awaiter` already present on `this` is reused.
 *
 * @param thisArg `this` binding for the generator function.
 * @param _arguments Arguments for the generator function (may be undefined).
 * @param P Promise constructor to use; defaults to the global `Promise`.
 * @param generator Generator function (native or `__generator`-based).
 * @return A `P` promise for the generator's completion value.
 */
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
    // Reuse values that already are `P` instances instead of wrapping them in
    // a new promise.
    function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
    return new (P || (P = Promise))(function (resolve, reject) {
        function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
        function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
        function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }
        step((generator = generator.apply(thisArg, _arguments || [])).next());
    });
};
/**
 * TypeScript `__generator` helper: an ES5 implementation of the generator
 * protocol, driven by a compiled state-machine function `body`.
 *
 * `body(_)` is called with the state object `_` and resumes at `_.label`. It
 * returns an instruction `[opcode, value]`:
 *   0 next, 1 throw, 2 return, 3 break (jump to label), 4 yield,
 *   5 yield* (delegate to iterator `value`), 6 an exception is propagating,
 *   7 end of a finally block.
 * `_.sent()` returns the value passed to `next()`, or rethrows the value passed
 * to `throw()`. `_.trys` holds the active [try, catch, finally, end] label
 * tuples, and `_.ops` holds operations deferred behind a `finally`.
 * An `__generator` already present on `this` is reused.
 *
 * @param thisArg `this` binding for `body`.
 * @param body Compiled state-machine function.
 * @return An iterator object with `next`, `throw` and `return`.
 */
var __generator = (this && this.__generator) || function (thisArg, body) {
    var _ = { label: 0, sent: function() { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g;
    return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function() { return this; }), g;
    function verb(n) { return function (v) { return step([n, v]); }; }
    function step(op) {
        if (f) throw new TypeError("Generator is already executing.");
        // On the very first resume, clear `g` (from then on it only serves as a
        // "not started" flag). If that first resume is `throw()` or `return()`,
        // finish immediately without running `body`, as native generators do.
        while (g && (g = 0, op[0] && (_ = 0)), _) try {
            if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t;
            if (y = 0, t) op = [op[0] & 2, t.value];
            switch (op[0]) {
                case 0: case 1: t = op; break;
                case 4: _.label++; return { value: op[1], done: false };
                case 5: _.label++; y = op[1]; op = [0]; continue;
                case 7: op = _.ops.pop(); _.trys.pop(); continue;
                default:
                    if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; }
                    if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; }
                    if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; }
                    if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; }
                    if (t[2]) _.ops.pop();
                    _.trys.pop(); continue;
            }
            op = body.call(thisArg, _);
        } catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; }
        if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true };
    }
};
66Object.defineProperty(exports, "__esModule", { value: true });
67var tfjs_1 = require("@tensorflow/tfjs");
68var path = require("path");
69var ProgressBar = require("progress");
70var tensorboard_1 = require("./tensorboard");
71// A helper class created for testing with the jasmine `spyOn` method, which
72// operates only on member methods of objects.
73// tslint:disable-next-line:no-any
74exports.progressBarHelper = {
75 ProgressBar: ProgressBar,
76 log: console.log
77};
/**
 * Terminal-based progress bar callback for tf.Model.fit().
 *
 * At the start of each epoch it logs "Epoch i / N". After every batch it
 * advances a terminal progress bar whose trailing text shows the current
 * losses and metrics. At the end of each epoch it logs the epoch duration,
 * the time per step and the final losses and metrics.
 */
var ProgbarLogger = /** @class */ (function (_super) {
    __extends(ProgbarLogger, _super);
    /**
     * Constructor of ProgbarLogger.
     */
    function ProgbarLogger() {
        var _this = _super.call(this, {
            onTrainBegin: function (logs) { return __awaiter(_this, void 0, void 0, function () {
                var samples, batchSize, steps;
                return __generator(this, function (_a) {
                    samples = this.params.samples;
                    batchSize = this.params.batchSize;
                    steps = this.params.steps;
                    if (samples != null || steps != null) {
                        this.numTrainBatchesPerEpoch =
                            samples != null ? Math.ceil(samples / batchSize) : steps;
                    }
                    else {
                        // Undetermined number of batches per epoch, e.g., due to
                        // `fitDataset()` without `batchesPerEpoch`.
                        this.numTrainBatchesPerEpoch = 0;
                    }
                    return [2 /*return*/];
                });
            }); },
            onEpochBegin: function (epoch, logs) { return __awaiter(_this, void 0, void 0, function () {
                return __generator(this, function (_a) {
                    exports.progressBarHelper.log("Epoch " + (epoch + 1) + " / " + this.params.epochs);
                    this.currentEpochBegin = tfjs_1.util.now();
                    this.epochDurationMillis = null;
                    this.usPerStep = null;
                    this.batchesInLatestEpoch = 0;
                    // NOTE(review): presumably undefined when stderr is not a TTY,
                    // which makes the width computations below NaN.
                    this.terminalWidth = process.stderr.columns;
                    return [2 /*return*/];
                });
            }); },
            onBatchEnd: function (batch, logs) { return __awaiter(_this, void 0, void 0, function () {
                var maxMetricsStringLength, tickTokens;
                return __generator(this, function (_a) {
                    switch (_a.label) {
                        case 0:
                            this.batchesInLatestEpoch++;
                            if (batch === 0) {
                                // A fresh bar for every epoch. `total` has one extra slot
                                // for the final tick issued by `onEpochEnd`.
                                this.progressBar = new exports.progressBarHelper.ProgressBar('eta=:eta :bar :placeholderForLossesAndMetrics', {
                                    width: Math.floor(0.5 * this.terminalWidth),
                                    total: this.numTrainBatchesPerEpoch + 1,
                                    head: ">",
                                    renderThrottle: this.RENDER_THROTTLE_MS
                                });
                            }
                            // The bar takes half of the terminal; the metrics text gets
                            // the other half minus 12 columns (presumably for the
                            // "eta=..." prefix).
                            maxMetricsStringLength = Math.floor(this.terminalWidth * 0.5 - 12);
                            tickTokens = {
                                placeholderForLossesAndMetrics: this.formatLogsAsMetricsContent(logs, maxMetricsStringLength)
                            };
                            if (this.numTrainBatchesPerEpoch === 0) {
                                // Undetermined number of batches per epoch: refresh the
                                // metrics text without advancing the bar.
                                this.progressBar.tick(0, tickTokens);
                            }
                            else {
                                this.progressBar.tick(tickTokens);
                            }
                            return [4 /*yield*/, tfjs_1.nextFrame()];
                        case 1:
                            _a.sent();
                            if (batch === this.numTrainBatchesPerEpoch - 1) {
                                // Last batch of an epoch of known size: measure here. The
                                // fallback in `onEpochEnd` also counts validation time.
                                this.epochDurationMillis = tfjs_1.util.now() - this.currentEpochBegin;
                                // NOTE(review): when `samples` is known this is time per
                                // sample, not per batch, although it is printed as "us/step".
                                this.usPerStep = this.params.samples != null ?
                                    this.epochDurationMillis / this.params.samples * 1e3 :
                                    this.epochDurationMillis / this.batchesInLatestEpoch * 1e3;
                            }
                            return [2 /*return*/];
                    }
                });
            }); },
            onEpochEnd: function (epoch, logs) { return __awaiter(_this, void 0, void 0, function () {
                var lossesAndMetricsString;
                return __generator(this, function (_a) {
                    switch (_a.label) {
                        case 0:
                            if (this.epochDurationMillis == null) {
                                // In cases where the number of batches per epoch is not determined,
                                // the calculation of the per-step duration is done at the end of the
                                // epoch. N.B., this includes the time spent on validation.
                                this.epochDurationMillis = tfjs_1.util.now() - this.currentEpochBegin;
                                this.usPerStep =
                                    this.epochDurationMillis / this.batchesInLatestEpoch * 1e3;
                            }
                            // The bar is only created by `onBatchEnd` (on batch 0); if no
                            // batch has run yet there is nothing to advance.
                            if (this.progressBar != null) {
                                this.progressBar.tick({ placeholderForLossesAndMetrics: '' });
                            }
                            lossesAndMetricsString = this.formatLogsAsMetricsContent(logs);
                            exports.progressBarHelper.log(this.epochDurationMillis.toFixed(0) + "ms " +
                                (this.usPerStep.toFixed(0) + "us/step - ") +
                                ("" + lossesAndMetricsString));
                            return [4 /*yield*/, tfjs_1.nextFrame()];
                        case 1:
                            _a.sent();
                            return [2 /*return*/];
                    }
                });
            }); },
        }) || this;
        // Passed to the progress bar as its `renderThrottle` option.
        _this.RENDER_THROTTLE_MS = 50;
        return _this;
    }
    /**
     * Format the entries of `logs` as a "key=value key=value ..." string.
     *
     * Keys are sorted alphabetically, and keys rejected by `isFieldRelevant`
     * are skipped. Values are rendered with `getSuccinctNumberDisplay`.
     *
     * @param logs Object mapping loss/metric names to values (expected to be
     *   numbers). A null/undefined `logs` yields an empty string.
     * @param maxMetricsLength Optional maximum length of the result. Longer
     *   results are cut short and end with '...'.
     * @return The formatted string.
     */
    ProgbarLogger.prototype.formatLogsAsMetricsContent = function (logs, maxMetricsLength) {
        var metricsContent = '';
        var keys = logs == null ? [] : Object.keys(logs).sort();
        for (var i = 0; i < keys.length; i++) {
            var key = keys[i];
            if (this.isFieldRelevant(key)) {
                metricsContent += key + "=" + getSuccinctNumberDisplay(logs[key]) + " ";
            }
        }
        if (maxMetricsLength != null && metricsContent.length > maxMetricsLength) {
            // Cut off metrics strings that are too long to avoid new lines being
            // constantly created. Clamp at 0: on a very narrow terminal
            // `maxMetricsLength - 3` is negative, and `slice()` would then count
            // from the end and keep almost the whole string.
            metricsContent =
                metricsContent.slice(0, Math.max(0, maxMetricsLength - 3)) + '...';
        }
        return metricsContent;
    };
    /**
     * Whether the log entry `key` should be displayed; 'batch' and 'size'
     * are not.
     */
    ProgbarLogger.prototype.isFieldRelevant = function (key) {
        return key !== 'batch' && key !== 'size';
    };
    return ProgbarLogger;
}(tfjs_1.CustomCallback));
exports.ProgbarLogger = ProgbarLogger;
// Digits after the decimal point in exponential notation. Also the number of
// decimal places used for zero, non-finite values and values whose magnitude
// exceeds 1.
var BASE_NUM_DIGITS = 2;
// Numbers that would need more decimal places than this are shown in
// exponential notation instead.
var MAX_NUM_DECIMAL_PLACES = 4;
/**
 * Get a succinct string representation of a number.
 *
 * Uses fixed-point (decimal) notation when at most `MAX_NUM_DECIMAL_PLACES`
 * decimal places are needed (e.g. 0.5 -> '0.500'). Otherwise uses scientific
 * notation with `BASE_NUM_DIGITS` fraction digits (e.g. 0.001 -> '1.00e-3').
 *
 * @param x Input number.
 * @return Succinct string representing `x`.
 */
function getSuccinctNumberDisplay(x) {
    var decimalPlaces = getDisplayDecimalPlaces(x);
    return decimalPlaces > MAX_NUM_DECIMAL_PLACES ?
        x.toExponential(BASE_NUM_DIGITS) :
        x.toFixed(decimalPlaces);
}
exports.getSuccinctNumberDisplay = getSuccinctNumberDisplay;
/**
 * Determine the number of decimal places to display.
 *
 * Non-finite values, zero and values whose magnitude exceeds 1 get
 * `BASE_NUM_DIGITS` places. Other values (0 < |x| <= 1) get enough places to
 * show `BASE_NUM_DIGITS` + 1 significant digits, e.g. 1 -> 2, 0.5 -> 3,
 * 0.01 -> 4.
 *
 * @param x Number to display.
 * @return Number of decimal places to display for `x`.
 */
function getDisplayDecimalPlaces(x) {
    if (!Number.isFinite(x) || x === 0 || x > 1 || x < -1) {
        return BASE_NUM_DIGITS;
    }
    // floor(log10(|x|)) is the (non-positive) exponent of the leading digit.
    return BASE_NUM_DIGITS - Math.floor(Math.log10(Math.abs(x)));
}
exports.getDisplayDecimalPlaces = getDisplayDecimalPlaces;
/**
 * Callback for logging to TensorBoard during training.
 *
 * Users are expected to access this class through the `tensorBoard()`
 * factory function instead.
 */
var TensorBoardCallback = /** @class */ (function (_super) {
    __extends(TensorBoardCallback, _super);
    /**
     * @param logdir Directory under which summaries are written (default
     *   './logs'). Scalars for keys starting with 'val_' go to `<logdir>/val`,
     *   all others to `<logdir>/train`.
     * @param args Optional configuration. `args.updateFreq` is 'epoch'
     *   (default: only per-epoch scalars) or 'batch' (per-batch scalars as
     *   well). The object is shallow-copied, never modified.
     * @throws Via `tf.util.assert` if `updateFreq` is neither 'batch' nor
     *   'epoch'.
     */
    function TensorBoardCallback(logdir, args) {
        if (logdir === void 0) { logdir = './logs'; }
        var _this = _super.call(this, {
            onBatchEnd: function (batch, logs) { return __awaiter(_this, void 0, void 0, function () {
                return __generator(this, function (_a) {
                    // Counted on every batch so the per-batch step keeps
                    // increasing across epochs.
                    this.batchesSeen++;
                    if (this.args.updateFreq !== 'epoch') {
                        this.logMetrics(logs, 'batch_', this.batchesSeen);
                    }
                    return [2 /*return*/];
                });
            }); },
            onEpochEnd: function (epoch, logs) { return __awaiter(_this, void 0, void 0, function () {
                return __generator(this, function (_a) {
                    this.logMetrics(logs, 'epoch_', epoch + 1);
                    return [2 /*return*/];
                });
            }); },
            onTrainEnd: function (logs) { return __awaiter(_this, void 0, void 0, function () {
                return __generator(this, function (_a) {
                    if (this.trainWriter != null) {
                        this.trainWriter.flush();
                    }
                    if (this.valWriter != null) {
                        this.valWriter.flush();
                    }
                    return [2 /*return*/];
                });
            }); }
        }) || this;
        _this.logdir = logdir;
        // Copy the caller's options so that filling in the default
        // `updateFreq` below does not mutate an object the caller owns.
        _this.args = Object.assign({}, args);
        if (_this.args.updateFreq == null) {
            _this.args.updateFreq = 'epoch';
        }
        tfjs_1.util.assert(['batch', 'epoch'].indexOf(_this.args.updateFreq) !== -1, function () { return "Expected updateFreq to be 'batch' or 'epoch', but got " +
            ("" + _this.args.updateFreq); });
        _this.batchesSeen = 0;
        return _this;
    }
    /**
     * Write each entry of `logs` as a scalar summary at `step`.
     *
     * 'batch', 'size' and 'num_steps' are skipped. For keys starting with
     * 'val_', that prefix is removed and the scalar goes to the validation
     * writer; all other keys go to the training writer. `prefix` is prepended
     * to every scalar name.
     */
    TensorBoardCallback.prototype.logMetrics = function (logs, prefix, step) {
        var VAL_PREFIX = 'val_';
        var keys = logs == null ? [] : Object.keys(logs);
        for (var i = 0; i < keys.length; i++) {
            var key = keys[i];
            if (key === 'batch' || key === 'size' || key === 'num_steps') {
                continue;
            }
            if (key.startsWith(VAL_PREFIX)) {
                this.ensureValWriterCreated();
                this.valWriter.scalar(prefix + key.slice(VAL_PREFIX.length), logs[key], step);
            }
            else {
                this.ensureTrainWriterCreated();
                this.trainWriter.scalar("" + prefix + key, logs[key], step);
            }
        }
    };
    /** Create the summary writer for `<logdir>/train` on first use. */
    TensorBoardCallback.prototype.ensureTrainWriterCreated = function () {
        // Previously the factory was called again for every logged scalar.
        if (this.trainWriter == null) {
            this.trainWriter = tensorboard_1.summaryFileWriter(path.join(this.logdir, 'train'));
        }
    };
    /** Create the summary writer for `<logdir>/val` on first use. */
    TensorBoardCallback.prototype.ensureValWriterCreated = function () {
        if (this.valWriter == null) {
            this.valWriter = tensorboard_1.summaryFileWriter(path.join(this.logdir, 'val'));
        }
    };
    return TensorBoardCallback;
}(tfjs_1.CustomCallback));
exports.TensorBoardCallback = TensorBoardCallback;
/**
 * Callback for logging to TensorBoard during training.
 *
 * Writes the loss and metric values (if any) to the specified log directory
 * (`logdir`), which can be ingested and visualized by TensorBoard.
 * This callback is usually passed as a callback to `tf.Model.fit()` or
 * `tf.Model.fitDataset()` calls during model training. The frequency at which
 * the values are logged can be controlled with the `updateFreq` field of the
 * configuration object (2nd argument).
 *
 * Usage example:
 * ```js
 * // Construct a toy multilayer-perceptron regressor for demo purposes.
 * const model = tf.sequential();
 * model.add(
 *     tf.layers.dense({units: 100, activation: 'relu', inputShape: [200]}));
 * model.add(tf.layers.dense({units: 1}));
 * model.compile({
 *   loss: 'meanSquaredError',
 *   optimizer: 'sgd',
 *   metrics: ['MAE']
 * });
 *
 * // Generate some random fake data for demo purposes.
 * const xs = tf.randomUniform([10000, 200]);
 * const ys = tf.randomUniform([10000, 1]);
 * const valXs = tf.randomUniform([1000, 200]);
 * const valYs = tf.randomUniform([1000, 1]);
 *
 * // Start model training process.
 * await model.fit(xs, ys, {
 *   epochs: 100,
 *   validationData: [valXs, valYs],
 *    // Add the tensorBoard callback here.
 *   callbacks: tf.node.tensorBoard('/tmp/fit_logs_1')
 * });
 * ```
 *
 * Then you can use the following commands to point TensorBoard
 * to the logdir:
 *
 * ```sh
 * pip install tensorboard  # Unless you've already installed it.
 * tensorboard --logdir /tmp/fit_logs_1
 * ```
 *
 * @param logdir Directory to which the logs will be written. Defaults to
 *   './logs'.
 * @param args Optional configuration arguments (e.g. `updateFreq`: 'batch' or
 *   'epoch').
 * @returns An instance of `TensorBoardCallback`, which is a subclass of
 *   `tf.CustomCallback`.
 *
 * @doc {heading: 'TensorBoard', namespace: 'node'}
 */
function tensorBoard(logdir, args) {
    if (logdir === void 0) { logdir = './logs'; }
    return new TensorBoardCallback(logdir, args);
}
exports.tensorBoard = tensorBoard;