// Retrieved from UNPKG
//
// 20.7 kB, JavaScript (View Raw)
"use strict";
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Runs a generator-based body as an async function: every value the generator
 * yields is wrapped in `P` and awaited, and its settled result is fed back into
 * the generator (fulfilment via `next`, rejection via `throw`).
 *
 * An existing `this.__awaiter` is reused when one is defined on `this`.
 *
 * @param {*} thisArg - `this` binding used when invoking `generator`.
 * @param {Array|undefined} _arguments - arguments for `generator` (`[]` if absent).
 * @param {Function|undefined} P - Promise constructor; the global `Promise` if absent.
 * @param {Function} generator - generator function to drive.
 * @returns {Promise} resolves with the generator's return value, or rejects with
 *     the first error that escapes the generator.
 */
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
    return new (P || (P = Promise))(function (resolve, reject) {
        // Resume the generator with the value the previous `yield` resolved to.
        function fulfilled(value) {
            try {
                step(generator.next(value));
            } catch (e) {
                reject(e);
            }
        }
        // Throw a rejection back into the generator so its own try/catch can see it.
        function rejected(value) {
            try {
                step(generator["throw"](value));
            } catch (e) {
                reject(e);
            }
        }
        // Settle on completion; otherwise wait for the yielded value via P.
        function step(result) {
            result.done
                ? resolve(result.value)
                : new P(function (resolve) { resolve(result.value); }).then(fulfilled, rejected);
        }
        // From here on `generator` holds the running iterator, not the function.
        step((generator = generator.apply(thisArg, _arguments || [])).next());
    });
};
/**
 * Builds an iterator object (`next` / `throw` / `return`) from a compiled
 * state-machine `body`, emulating a generator without generator syntax.
 *
 * `body` is called with the shared state `_` and returns an opcode tuple
 * `[op, value]`. Behavior per opcode, as implemented in `step` below:
 *   0 / 1  resume with a value / with an error (read back via `_.sent()`)
 *   2      return (finishes the iterator)
 *   3      jump to label `value`
 *   4      yield `value` to the caller (`done: false`)
 *   5      delegate to the iterator `value` until it finishes
 *   6      pending exception (produced when `body` throws)
 *   7      leave the innermost try region and resume the deferred op
 * An existing `this.__generator` is reused when one is defined on `this`.
 *
 * @param {*} thisArg - `this` binding used when calling `body`.
 * @param {Function} body - state machine `(state) => [op, value]`.
 * @returns {Object} iterator; also iterable when `Symbol` is available.
 */
var __generator = (this && this.__generator) || function (thisArg, body) {
    // _: shared state (label = next case to run, sent() = value/error passed to the
    //    last next/throw, trys = stack of try regions, ops = ops deferred while a
    //    finally block runs); set to 0 once the iterator is finished.
    // f: "currently executing" flag, y: delegate iterator for op 5, t: scratch,
    // g: the iterator object handed back to the caller.
    var _ = { label: 0, sent: function() { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g;
    return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function() { return this; }), g;
    // next / throw / return all funnel into `step` with the matching opcode.
    function verb(n) { return function (v) { return step([n, v]); }; }
    function step(op) {
        if (f) throw new TypeError("Generator is already executing.");
        while (_) try {
            // While delegating (op 5), forward the call to the delegate and surface
            // its results until it reports `done`.
            if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t;
            if (y = 0, t) op = [op[0] & 2, t.value];
            switch (op[0]) {
                case 0: case 1: t = op; break;
                case 4: _.label++; return { value: op[1], done: false };
                case 5: _.label++; y = op[1]; op = [0]; continue;
                case 7: op = _.ops.pop(); _.trys.pop(); continue;
                default:
                    // Entries of `_.trys` are read as [start, catchLabel, finallyLabel, end].
                    if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; }
                    if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; }
                    if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; }
                    if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; }
                    if (t[2]) _.ops.pop();
                    _.trys.pop(); continue;
            }
            op = body.call(thisArg, _);
        } catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; }
        // Ops 1 and 6 (bit mask 5) finish by rethrowing; everything else completes.
        if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true };
    }
};
// Module-level `this`, passed as `thisArg` to `__awaiter` by the async test
// hooks and specs in this file.
var _this = this;
Object.defineProperty(exports, "__esModule", { value: true });
var tfjs_1 = require("@tensorflow/tfjs");
var fs = require("fs");
var path = require("path");
var util_1 = require("util");
var tfn = require("./index");
// tslint:disable-next-line:no-require-imports
var rimraf = require('rimraf');
// tslint:disable-next-line:no-require-imports
var tmp = require('tmp');
// Promise-returning form of rimraf, so the afterEach hooks can wait for the
// temporary log directory to be removed.
var rimrafPromise = util_1.promisify(rimraf);
describe('tensorboard', function () {
    // Fresh temporary log directory per test (beforeEach), deleted in afterEach.
    var tmpLogDir;
    beforeEach(function () {
        tmpLogDir = tmp.dirSync().name;
    });
    afterEach(function () { return __awaiter(_this, void 0, void 0, function () {
        return __generator(this, function (_a) {
            switch (_a.label) {
                case 0:
                    if (!(tmpLogDir != null)) return [3 /*break*/, 2];
                    return [4 /*yield*/, rimrafPromise(tmpLogDir)];
                case 1:
                    _a.sent();
                    _a.label = 2;
                case 2: return [2 /*return*/];
            }
        });
    }); });
    it('Create summaryFileWriter and write scalar', function () {
        var writer = tfn.node.summaryFileWriter(tmpLogDir);
        writer.scalar('foo', 42, 0);
        writer.flush();
        // Currently, we only verify that the file exists and the size
        // increases in a sensible way as we write more scalars to it.
        // The difficulty is in reading the protobuf contents of the event
        // file in JavaScript/TypeScript.
        var fileNames = fs.readdirSync(tmpLogDir);
        expect(fileNames.length).toEqual(1);
        var eventFilePath = path.join(tmpLogDir, fileNames[0]);
        var fileSize0 = fs.statSync(eventFilePath).size;
        writer.scalar('foo', 43, 1);
        writer.flush();
        var fileSize1 = fs.statSync(eventFilePath).size;
        var incrementPerScalar = fileSize1 - fileSize0;
        expect(incrementPerScalar).toBeGreaterThan(0);
        writer.scalar('foo', 44, 2);
        writer.scalar('foo', 45, 3);
        writer.flush();
        var fileSize2 = fs.statSync(eventFilePath).size;
        expect(fileSize2 - fileSize1).toEqual(2 * incrementPerScalar);
    });
    it('Writing tf.Scalar works', function () {
        var writer = tfn.node.summaryFileWriter(tmpLogDir);
        writer.scalar('foo', tfjs_1.scalar(42), 0);
        writer.flush();
        // Currently, we only verify that the file exists and the size
        // increases in a sensible way as we write more scalars to it.
        // The difficulty is in reading the protobuf contents of the event
        // file in JavaScript/TypeScript.
        var fileNames = fs.readdirSync(tmpLogDir);
        expect(fileNames.length).toEqual(1);
    });
    it('No crosstalk between two summary writers', function () {
        var logDir1 = path.join(tmpLogDir, '1');
        var writer1 = tfn.node.summaryFileWriter(logDir1);
        writer1.scalar('foo', 42, 0);
        writer1.flush();
        var logDir2 = path.join(tmpLogDir, '2');
        var writer2 = tfn.node.summaryFileWriter(logDir2);
        writer2.scalar('foo', 1.337, 0);
        writer2.flush();
        // Currently, we only verify that the file exists and the size
        // increases in a sensible way as we write more scalars to it.
        // The difficulty is in reading the protobuf contents of the event
        // file in JavaScript/TypeScript.
        var fileNames = fs.readdirSync(logDir1);
        expect(fileNames.length).toEqual(1);
        var eventFilePath1 = path.join(logDir1, fileNames[0]);
        var fileSize1Num0 = fs.statSync(eventFilePath1).size;
        fileNames = fs.readdirSync(logDir2);
        expect(fileNames.length).toEqual(1);
        var eventFilePath2 = path.join(logDir2, fileNames[0]);
        var fileSize2Num0 = fs.statSync(eventFilePath2).size;
        expect(fileSize2Num0).toBeGreaterThan(0);
        writer1.scalar('foo', 43, 1);
        writer1.flush();
        var fileSize1Num1 = fs.statSync(eventFilePath1).size;
        var incrementPerScalar = fileSize1Num1 - fileSize1Num0;
        expect(incrementPerScalar).toBeGreaterThan(0);
        writer1.scalar('foo', 44, 2);
        writer1.scalar('foo', 45, 3);
        writer1.flush();
        var fileSize1Num2 = fs.statSync(eventFilePath1).size;
        expect(fileSize1Num2 - fileSize1Num1).toEqual(2 * incrementPerScalar);
        // Writing through writer1 must not touch writer2's file...
        var fileSize2Num1 = fs.statSync(eventFilePath2).size;
        expect(fileSize2Num1).toEqual(fileSize2Num0);
        writer2.scalar('foo', 1.336, 1);
        writer2.scalar('foo', 1.335, 2);
        writer2.flush();
        // ...and vice versa.
        var fileSize1Num3 = fs.statSync(eventFilePath1).size;
        expect(fileSize1Num3).toEqual(fileSize1Num2);
        var fileSize2Num2 = fs.statSync(eventFilePath2).size;
        expect(fileSize2Num2 - fileSize2Num1).toEqual(2 * incrementPerScalar);
    });
    it('Writing into existing directory works', function () {
        // Pre-create the writer's own log directory. (tmpLogDir itself always
        // exists already, because beforeEach creates it.)
        var logDir = path.join(tmpLogDir, '22');
        fs.mkdirSync(logDir, { recursive: true });
        var writer = tfn.node.summaryFileWriter(logDir);
        writer.scalar('foo', 42, 0);
        writer.flush();
        var fileNames = fs.readdirSync(tmpLogDir);
        expect(fileNames.length).toEqual(1);
        // The event file must be written inside the pre-existing directory.
        var logFileNames = fs.readdirSync(logDir);
        expect(logFileNames.length).toEqual(1);
    });
    it('empty logdir leads to error', function () {
        expect(function () { return tfn.node.summaryFileWriter(''); }).toThrowError(/empty string/);
    });
});
describe('tensorBoard callback', function () {
    // Fresh temporary log directory per test (beforeEach), deleted in afterEach.
    var tmpLogDir;
    beforeEach(function () {
        tmpLogDir = tmp.dirSync().name;
    });
    afterEach(function () { return __awaiter(_this, void 0, void 0, function () {
        return __generator(this, function (_a) {
            switch (_a.label) {
                case 0:
                    if (!(tmpLogDir != null)) return [3 /*break*/, 2];
                    return [4 /*yield*/, rimrafPromise(tmpLogDir)];
                case 1:
                    _a.sent();
                    _a.label = 2;
                case 2: return [2 /*return*/];
            }
        });
    }); });
    /**
     * Builds the compiled model shared by the tests below: a 10-feature input,
     * dense(5, relu) -> dense(1), trained with MSE loss and SGD, reporting MAE.
     */
    function createModelForTest() {
        var model = tfn.sequential();
        model.add(tfn.layers.dense({ units: 5, activation: 'relu', inputShape: [10] }));
        model.add(tfn.layers.dense({ units: 1 }));
        model.compile({ loss: 'meanSquaredError', optimizer: 'sgd', metrics: ['MAE'] });
        return model;
    }
    it('fit(): default epoch updateFreq, with validation', function () { return __awaiter(_this, void 0, void 0, function () {
        var model, xs, ys, valXs, valYs, subDirs, trainLogDir, trainFiles, trainFileSize0, valLogDir, valFiles, valFileSize0, history, trainFileSize1, valFileSize1;
        return __generator(this, function (_a) {
            switch (_a.label) {
                case 0:
                    model = createModelForTest();
                    xs = tfn.randomUniform([100, 10]);
                    ys = tfn.randomUniform([100, 1]);
                    valXs = tfn.randomUniform([10, 10]);
                    valYs = tfn.randomUniform([10, 1]);
                    // Warm-up training.
                    return [4 /*yield*/, model.fit(xs, ys, {
                            epochs: 1,
                            verbose: 0,
                            validationData: [valXs, valYs],
                            callbacks: tfn.node.tensorBoard(tmpLogDir)
                        })];
                case 1:
                    _a.sent();
                    subDirs = fs.readdirSync(tmpLogDir);
                    expect(subDirs).toContain('train');
                    expect(subDirs).toContain('val');
                    trainLogDir = path.join(tmpLogDir, 'train');
                    trainFiles = fs.readdirSync(trainLogDir);
                    trainFileSize0 = fs.statSync(path.join(trainLogDir, trainFiles[0])).size;
                    expect(trainFileSize0).toBeGreaterThan(0);
                    valLogDir = path.join(tmpLogDir, 'val');
                    valFiles = fs.readdirSync(valLogDir);
                    valFileSize0 = fs.statSync(path.join(valLogDir, valFiles[0])).size;
                    expect(valFileSize0).toBeGreaterThan(0);
                    // With updateFreq === epoch, the train and val subset should have generated
                    // the same amount of logs.
                    expect(valFileSize0).toEqual(trainFileSize0);
                    return [4 /*yield*/, model.fit(xs, ys, {
                            epochs: 3,
                            verbose: 0,
                            validationData: [valXs, valYs],
                            callbacks: tfn.node.tensorBoard(tmpLogDir)
                        })];
                case 2:
                    history = _a.sent();
                    expect(history.history.loss.length).toEqual(3);
                    expect(history.history.val_loss.length).toEqual(3);
                    expect(history.history.MAE.length).toEqual(3);
                    expect(history.history.val_MAE.length).toEqual(3);
                    trainFileSize1 = fs.statSync(path.join(trainLogDir, trainFiles[0])).size;
                    valFileSize1 = fs.statSync(path.join(valLogDir, valFiles[0])).size;
                    // We currently only assert that new content has been written to the log
                    // file.
                    expect(trainFileSize1).toBeGreaterThan(trainFileSize0);
                    expect(valFileSize1).toBeGreaterThan(valFileSize0);
                    // With updateFreq === epoch, the train and val subset should have generated
                    // the same amount of logs.
                    expect(valFileSize1).toEqual(trainFileSize1);
                    return [2 /*return*/];
            }
        });
    }); });
    it('fit(): batch updateFreq, with validation', function () { return __awaiter(_this, void 0, void 0, function () {
        var model, xs, ys, valXs, valYs, subDirs, trainLogDir, trainFiles, trainFileSize0, valLogDir, valFiles, valFileSize0, history, trainFileSize1, valFileSize1;
        return __generator(this, function (_a) {
            switch (_a.label) {
                case 0:
                    model = createModelForTest();
                    xs = tfn.randomUniform([100, 10]);
                    ys = tfn.randomUniform([100, 1]);
                    valXs = tfn.randomUniform([10, 10]);
                    valYs = tfn.randomUniform([10, 1]);
                    // Warm-up training.
                    return [4 /*yield*/, model.fit(xs, ys, {
                            epochs: 1,
                            verbose: 0,
                            validationData: [valXs, valYs],
                            // Use batch updateFreq here.
                            callbacks: tfn.node.tensorBoard(tmpLogDir, { updateFreq: 'batch' })
                        })];
                case 1:
                    _a.sent();
                    subDirs = fs.readdirSync(tmpLogDir);
                    expect(subDirs).toContain('train');
                    expect(subDirs).toContain('val');
                    trainLogDir = path.join(tmpLogDir, 'train');
                    trainFiles = fs.readdirSync(trainLogDir);
                    trainFileSize0 = fs.statSync(path.join(trainLogDir, trainFiles[0])).size;
                    expect(trainFileSize0).toBeGreaterThan(0);
                    valLogDir = path.join(tmpLogDir, 'val');
                    valFiles = fs.readdirSync(valLogDir);
                    valFileSize0 = fs.statSync(path.join(valLogDir, valFiles[0])).size;
                    expect(valFileSize0).toBeGreaterThan(0);
                    // The train subset should have generated more logs than the val subset,
                    // because the train subset gets logged every batch, while the val subset
                    // gets logged every epoch.
                    expect(trainFileSize0).toBeGreaterThan(valFileSize0);
                    // Keep batch updateFreq for this phase as well, so the assertions below
                    // actually exercise per-batch logging (the default is per-epoch).
                    return [4 /*yield*/, model.fit(xs, ys, {
                            epochs: 3,
                            verbose: 0,
                            validationData: [valXs, valYs],
                            callbacks: tfn.node.tensorBoard(tmpLogDir, { updateFreq: 'batch' })
                        })];
                case 2:
                    history = _a.sent();
                    expect(history.history.loss.length).toEqual(3);
                    expect(history.history.val_loss.length).toEqual(3);
                    expect(history.history.MAE.length).toEqual(3);
                    expect(history.history.val_MAE.length).toEqual(3);
                    trainFileSize1 = fs.statSync(path.join(trainLogDir, trainFiles[0])).size;
                    valFileSize1 = fs.statSync(path.join(valLogDir, valFiles[0])).size;
                    // We currently only assert that new content has been written to the log
                    // file.
                    expect(trainFileSize1).toBeGreaterThan(trainFileSize0);
                    expect(valFileSize1).toBeGreaterThan(valFileSize0);
                    // The train subset should have generated more logs than the val subset,
                    // because the train subset gets logged every batch, while the val subset
                    // gets logged every epoch.
                    expect(trainFileSize1).toBeGreaterThan(valFileSize1);
                    return [2 /*return*/];
            }
        });
    }); });
    it('fit(): with initialEpoch', function () { return __awaiter(_this, void 0, void 0, function () {
        var model, xs, ys, valXs, valYs, callback, trainWriterScalarSpy, valWriterScalarSpy, trainCallArgs, valCallArgs;
        return __generator(this, function (_a) {
            switch (_a.label) {
                case 0:
                    model = createModelForTest();
                    xs = tfn.randomUniform([100, 10]);
                    ys = tfn.randomUniform([100, 1]);
                    valXs = tfn.randomUniform([10, 10]);
                    valYs = tfn.randomUniform([10, 1]);
                    // Warm-up training (no TensorBoard callback involved).
                    return [4 /*yield*/, model.fit(xs, ys, {
                            epochs: 2,
                            validationData: [valXs, valYs],
                            verbose: 0,
                        })];
                case 1:
                    _a.sent();
                    callback = tfn.node.tensorBoard(tmpLogDir, { updateFreq: 'epoch' });
                    // The warm-up fit did not use this callback, so create its writers
                    // explicitly: they must exist before their `scalar` methods can be spied on.
                    // tslint:disable-next-line:no-any
                    callback.ensureTrainWriterCreated();
                    // tslint:disable-next-line:no-any
                    callback.ensureValWriterCreated();
                    trainWriterScalarSpy = spyOn(callback.trainWriter, 'scalar');
                    valWriterScalarSpy = spyOn(callback.valWriter, 'scalar');
                    // Train for 2 more epochs, using initialEpoch and callback.
                    return [4 /*yield*/, model.fit(xs, ys, {
                            epochs: 4,
                            initialEpoch: 2,
                            validationData: [valXs, valYs],
                            verbose: 0,
                            callbacks: [callback],
                        })];
                case 2:
                    _a.sent();
                    // 2 epochs x (loss + MAE) = 4 scalar writes per subset.
                    expect(trainWriterScalarSpy).toHaveBeenCalledTimes(4);
                    trainCallArgs = trainWriterScalarSpy.calls.allArgs();
                    // Assert that the epoch numbers used to log the epoch-end loss and metric
                    // reflect initialEpoch.
                    expect(trainCallArgs[0][0]).toEqual('epoch_loss');
                    expect(trainCallArgs[0][2]).toEqual(3);
                    expect(trainCallArgs[1][0]).toEqual('epoch_MAE');
                    expect(trainCallArgs[1][2]).toEqual(3);
                    expect(trainCallArgs[2][0]).toEqual('epoch_loss');
                    expect(trainCallArgs[2][2]).toEqual(4);
                    expect(trainCallArgs[3][0]).toEqual('epoch_MAE');
                    expect(trainCallArgs[3][2]).toEqual(4);
                    expect(valWriterScalarSpy).toHaveBeenCalledTimes(4);
                    valCallArgs = valWriterScalarSpy.calls.allArgs();
                    expect(valCallArgs[0][0]).toEqual('epoch_loss');
                    expect(valCallArgs[0][2]).toEqual(3);
                    expect(valCallArgs[1][0]).toEqual('epoch_MAE');
                    expect(valCallArgs[1][2]).toEqual(3);
                    expect(valCallArgs[2][0]).toEqual('epoch_loss');
                    expect(valCallArgs[2][2]).toEqual(4);
                    expect(valCallArgs[3][0]).toEqual('epoch_MAE');
                    expect(valCallArgs[3][2]).toEqual(4);
                    return [2 /*return*/];
            }
        });
    }); });
    it('Invalid updateFreq value causes error', function () { return __awaiter(_this, void 0, void 0, function () {
        return __generator(this, function (_a) {
            expect(function () { return tfn.node.tensorBoard(tmpLogDir, {
                // tslint:disable-next-line:no-any
                updateFreq: 'foo'
            }); }).toThrowError(/Expected updateFreq/);
            return [2 /*return*/];
        });
    }); });
});