UNPKG

225 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2018 Google LLC
4 *
5 * Use of this source code is governed by an MIT-style
6 * license that can be found in the LICENSE file or at
7 * https://opensource.org/licenses/MIT.
8 * =============================================================================
9 */
10/* Original Source: engine/training.py */
11import * as tfc from '@tensorflow/tfjs-core';
12import { io, Optimizer, scalar, serialization, Tensor, tensor1d, util } from '@tensorflow/tfjs-core';
13import * as K from '../backend/tfjs_backend';
14import { nameScope } from '../common';
15import { NotImplementedError, RuntimeError, ValueError } from '../errors';
16import { deserialize } from '../layers/serialization';
17import * as losses from '../losses';
18import * as Metrics from '../metrics';
19import * as optimizers from '../optimizers';
20import { checkUserDefinedMetadata } from '../user_defined_metadata';
21import { count, pyListRepeat, singletonOrArray, toCamelCase, toSnakeCase, unique } from '../utils/generic_utils';
22import { printSummary } from '../utils/layer_utils';
23import { range } from '../utils/math_utils';
24import { convertPythonicToTs } from '../utils/serialization_utils';
25import { version } from '../version';
26import { Container } from './container';
27import { execute, FeedDict } from './executor';
28import { evaluateDataset, fitDataset } from './training_dataset';
29import { checkBatchSize, disposeNewTensors, ensureTensorsRank2OrHigher, fitTensors, makeBatches, sliceArrays, sliceArraysByIndices } from './training_tensors';
30import { computeWeightedLoss, standardizeClassWeights, standardizeWeights } from './training_utils';
/**
 * Polymorphic input data, case 1: a single Tensor.
 *
 * @param x The value to inspect.
 * @returns `true` when `x` is an instance of `Tensor`, `false` otherwise.
 */
export function isDataTensor(x) {
    const isTensor = x instanceof Tensor;
    return isTensor;
}
/**
 * Polymorphic input data, case 2: an Array (of Tensors).
 *
 * Only the container is checked; the elements are not inspected.
 *
 * @param x The value to inspect.
 * @returns `true` when `x` is an Array, `false` otherwise.
 */
export function isDataArray(x) {
    const isArray = Array.isArray(x);
    return isArray;
}
/**
 * Polymorphic input data, case 3: a "dict" of Tensors.
 *
 * Defined by exclusion: anything that is neither a single Tensor nor an Array
 * is treated as a dict.
 *
 * @param x The value to inspect.
 * @returns `true` when `x` is neither a Tensor nor an Array.
 */
export function isDataDict(x) {
    return !(isDataTensor(x) || isDataArray(x));
}
/**
 * Normalizes inputs and targets provided by users.
 *
 * The user data may be a single Tensor, an Array of Tensors, or a dict that
 * maps names to Tensors. The result is always an Array ordered like `names`.
 *
 * @param data User-provided input data (polymorphic).
 * @param names An Array of expected Tensor names.
 * @param shapes Optional Array of expected Tensor shapes. `null` entries and
 *   `null`/negative dimensions are not checked.
 * @param checkBatchAxis Whether to check that the batch axis of the arrays
 *   match the expected value found in `shapes`.
 * @param exceptionPrefix String prefix used for exception formatting.
 * @returns List of standardized input Tensors (one Tensor per model input).
 *   When `names` is empty the result is `[]`; when `data` is `null` or
 *   `undefined` the result is one `null` per name.
 * @throws ValueError: in case of improperly formatted user data.
 */
export function standardizeInputData(data, names, shapes, checkBatchAxis = true, exceptionPrefix = '') {
    if (names == null || names.length === 0) {
        // The model expects no data. Only a non-empty Array, a non-empty dict
        // or a singleton Tensor counts as "some data got sent"; an empty Array
        // or empty dict is accepted.
        if (data != null) {
            let gotUnexpectedData;
            if (isDataArray(data)) {
                gotUnexpectedData = data.length > 0;
            }
            else if (isDataDict(data)) {
                // Object.keys() also works for dicts without a prototype.
                gotUnexpectedData = Object.keys(data).length > 0;
            }
            else {
                // `data` is a singleton Tensor in this case.
                gotUnexpectedData = true;
            }
            if (gotUnexpectedData) {
                throw new ValueError(`Error when checking model ${exceptionPrefix} expected no data, ` +
                    `but got ${data}`);
            }
        }
        return [];
    }
    if (data == null) {
        return names.map(() => null);
    }
    let arrays;
    if (isDataDict(data)) {
        // Pick the Tensors out of the dict in the order given by `names`.
        arrays = [];
        for (const name of names) {
            if (data[name] == null) {
                throw new ValueError(`No data provided for "${name}". Need data for each key in: ` +
                    `${names}`);
            }
            arrays.push(data[name]);
        }
    }
    else if (isDataArray(data)) {
        if (data.length !== names.length) {
            throw new ValueError(`Error when checking model ${exceptionPrefix}: the Array of ` +
                `Tensors that you are passing to your model is not the size the ` +
                `model expected. Expected to see ${names.length} Tensor(s), but ` +
                `instead got the following list of Tensor(s): ${data}`);
        }
        arrays = data;
    }
    else {
        // Singleton Tensor: only valid for a model with exactly one input.
        if (names.length > 1) {
            throw new ValueError(`The model ${exceptionPrefix} expects ${names.length} Tensor(s), ` +
                `but only received one Tensor. Found: Tensor with shape ${data.shape}`);
        }
        arrays = [data];
    }
    arrays = ensureTensorsRank2OrHigher(arrays);
    // Check shape compatibility.
    if (shapes != null) {
        for (let i = 0; i < names.length; ++i) {
            if (shapes[i] == null) {
                continue;
            }
            const array = arrays[i];
            if (array.shape.length !== shapes[i].length) {
                throw new ValueError(`Error when checking ${exceptionPrefix}: expected ${names[i]} ` +
                    `to have ${shapes[i].length} dimension(s). but got array with ` +
                    `shape ${array.shape}`);
            }
            for (let j = 0; j < shapes[i].length; ++j) {
                if (j === 0 && !checkBatchAxis) {
                    // Skip the first (batch) axis.
                    continue;
                }
                const dim = array.shape[j];
                const refDim = shapes[i][j];
                if (refDim != null && refDim >= 0 && dim !== refDim) {
                    throw new ValueError(`${exceptionPrefix} expected a batch of elements where each ` +
                        `example has shape [${shapes[i].slice(1, shapes[i].length)}] ` +
                        `(i.e.,tensor shape [*,${shapes[i].slice(1, shapes[i].length)}])` +
                        ` but the ${exceptionPrefix} received an input with ${array.shape[0]}` +
                        ` examples, each with shape [${array.shape.slice(1, array.shape.length)}]` +
                        ` (tensor shape [${array.shape}])`);
                }
            }
        }
    }
    return arrays;
}
/**
 * Validates that user-provided Tensors agree on the number of samples.
 *
 * The sample count of a Tensor is the size of its first axis. All inputs must
 * share one count, all targets must share one count, and the two counts must
 * match.
 *
 * @param inputs `Array` of `tf.Tensor`s for inputs.
 * @param targets `Array` of `tf.Tensor`s for targets.
 * @param weights Optional `Array` of `tf.Tensor`s for sample weights.
 *   Currently not checked.
 * @throws ValueError: in case of incorrectly formatted data.
 */
export function checkArrayLengths(inputs, targets, weights) {
    const sampleCountOf = (tensor) => tensor.shape[0];
    const shapeOf = (tensor) => tensor.shape;
    const inputCounts = unique(inputs.map(sampleCountOf));
    inputCounts.sort();
    const targetCounts = unique(targets.map(sampleCountOf));
    targetCounts.sort();
    // TODO(cais): Check `weights` as well.
    if (inputCounts.length > 1) {
        throw new ValueError(`All input Tensors (x) should have the same number of samples. ` +
            `Got array shapes: ` +
            `${JSON.stringify(inputs.map(shapeOf))}`);
    }
    if (targetCounts.length > 1) {
        throw new ValueError(`All target Tensors (y) should have the same number of samples. ` +
            `Got array shapes: ` +
            `${JSON.stringify(targets.map(shapeOf))}`);
    }
    const bothPresent = inputCounts.length > 0 && targetCounts.length > 0;
    if (bothPresent && !util.arraysEqual(inputCounts, targetCounts)) {
        throw new ValueError(`Input Tensors should have the same number of samples as target ` +
            `Tensors. Found ${inputCounts[0]} input sample(s) and ${targetCounts[0]} target ` +
            `sample(s).`);
    }
}
/**
 * Validation on the compatibility of targets and loss functions.
 *
 * This helps prevent users from using loss functions incorrectly:
 *  - `categoricalCrossentropy` rejects targets whose last axis has size 1.
 *  - `meanSquaredError`, `binaryCrossentropy` and `categoricalCrossentropy`
 *    require every non-batch target dimension to equal the corresponding
 *    output dimension (output dimensions that are `null` are not checked).
 * Outputs whose loss entry is `null`/`undefined` are skipped.
 *
 * @param targets `Array` of `tf.Tensor`s of targets.
 * @param lossFns `Array` of loss functions.
 * @param outputShapes `Array` of shapes of model outputs.
 * @throws ValueError: when a target is incompatible with its loss function.
 */
function checkLossAndTargetCompatibility(targets, lossFns, outputShapes) {
    // TODO(cais): Dedicated test coverage?
    // Losses that expect the target to have the same shape as the output.
    const keyLosses = [
        losses.meanSquaredError,
        losses.binaryCrossentropy,
        losses.categoricalCrossentropy,
    ];
    for (let i = 0; i < targets.length; ++i) {
        const y = targets[i];
        const loss = lossFns[i];
        const shape = outputShapes[i];
        if (loss == null) {
            continue;
        }
        if (loss === losses.categoricalCrossentropy &&
            y.shape[y.shape.length - 1] === 1) {
            throw new ValueError(`You are passing a target array of shape ${y.shape} while using ` +
                `a loss 'categorical_crossentropy'. 'categorical_crossentropy' ` +
                `expects targets to be binary matrices (1s and 0s) of shape ` +
                `[samples, classes].`);
            // TODO(cais): Example code in error message.
        }
        if (keyLosses.includes(loss)) {
            // Compare all axes except the leading (batch) axis.
            const slicedYShape = y.shape.slice(1);
            const slicedShape = shape.slice(1);
            for (let j = 0; j < slicedYShape.length; ++j) {
                const targetDim = slicedYShape[j];
                const outDim = slicedShape[j];
                if (outDim != null && targetDim !== outDim) {
                    throw new ValueError(`A target Tensor with shape ${y.shape} was passed for an ` +
                        `output of shape ${shape}, while using a loss function that ` +
                        `expects targets to have the same shape as the output.`);
                }
            }
        }
    }
}
/**
 * Check inputs provided by the user.
 *
 * Porting Note: This corresponds to _standardize_input_data() in Python
 *   Keras. Because of the strong typing in TF.js, we do not need to convert
 *   the data. Specifically:
 *   1) in PyKeras, `data` can be `DataFrame` instances from pandas, for
 *      example. We don't need to worry about that here because there is no
 *      widely popular JavaScript/TypeScript equivalent of pandas (so far).
 *      If one becomes available in the future, we can add support.
 *   2) in PyKeras, inputs can be Python dict. But here we are stipulating
 *      that the data is either a single `tf.Tensor` or an Array of
 *      `tf.Tensor`s. We may add support for `Object` data inputs in the
 *      future when the need arises.
 *
 * Instead, we perform basic checks for number of parameters and shapes. The
 * function returns nothing; it only throws when a check fails.
 *
 * @param data: The input data.
 * @param names: Name for the inputs, from the model.
 * @param shapes: Expected shapes for the input data, from the model. `null`
 *   entries and `null` dimensions are not checked.
 * @param checkBatchAxis: Whether the size along the batch axis (i.e., the
 *   first dimension) will be checked for matching.
 * @param exceptionPrefix: Exception prefix message, used in generating error
 *   messages.
 * @throws ValueError: on incorrect number of inputs or mismatches in shapes.
 */
function checkInputData(data, names, shapes, checkBatchAxis = true, exceptionPrefix = '') {
    let arrays;
    if (Array.isArray(data)) {
        if (data.length !== names.length) {
            throw new ValueError(`Error when checking model ${exceptionPrefix}: the Array of ` +
                `Tensors that you are passing to your model is not the size ` +
                `the model expected. Expected to see ${names.length} Tensor(s),` +
                ` but instead got ${data.length} Tensor(s).`);
        }
        arrays = data;
    }
    else {
        // A single Tensor is only acceptable for a single-input model.
        if (names.length > 1) {
            throw new ValueError(`The model expects ${names.length} ${exceptionPrefix} Tensors, ` +
                `but only received one Tensor. Found: array with shape ` +
                `${JSON.stringify(data.shape)}.`);
        }
        arrays = [data];
    }
    if (shapes == null) {
        return;
    }
    for (let i = 0; i < names.length; ++i) {
        if (shapes[i] == null) {
            continue;
        }
        const array = arrays[i];
        if (array.shape.length !== shapes[i].length) {
            throw new ValueError(`Error when checking ${exceptionPrefix}: expected ${names[i]} ` +
                `to have ${shapes[i].length} dimension(s), but got array with ` +
                `shape ${JSON.stringify(array.shape)}`);
        }
        for (let j = 0; j < shapes[i].length; ++j) {
            if (j === 0 && !checkBatchAxis) {
                // Skip the first (batch) axis.
                continue;
            }
            const dim = array.shape[j];
            const refDim = shapes[i][j];
            if (refDim != null && refDim !== dim) {
                throw new ValueError(`Error when checking ${exceptionPrefix}: expected ` +
                    `${names[i]} to have shape ${JSON.stringify(shapes[i])} but ` +
                    `got array with shape ${JSON.stringify(array.shape)}.`);
            }
        }
    }
}
/**
 * Maps metric functions to model outputs.
 *
 * @param metrics A shortcut string name, a metric function, an `Array` or a
 *   dict (`Object`, keyed by output name) of metric names/functions.
 * @param outputNames An `Array` of the names of model outputs.
 * @returns An `Array` (one entry per model output) of `Array` of metric
 *   functions. For instance, if the model has 2 outputs, and for the first
 *   output we want to compute `binaryAccuracy` and `binaryCrossentropy`,
 *   and just `binaryAccuracy` for the second output, the `Array` would look
 *   like:
 *     `[[binaryAccuracy, binaryCrossentropy], [binaryAccuracy]]`
 *   When `metrics` is a string, function or Array, every output receives the
 *   same (shared) Array instance.
 * @throws TypeError: incompatible metrics format.
 */
export function collectMetrics(metrics, outputNames) {
    if (metrics == null || (Array.isArray(metrics) && metrics.length === 0)) {
        return outputNames.map(() => []);
    }
    let wrappedMetrics;
    if (typeof metrics === 'string' || typeof metrics === 'function') {
        wrappedMetrics = [metrics];
    }
    else if (Array.isArray(metrics) || typeof metrics === 'object') {
        wrappedMetrics = metrics;
    }
    else {
        throw new TypeError('Type of metrics argument not understood. Expected a string, ' +
            `function, Array, or Object, found: ${metrics}`);
    }
    if (Array.isArray(wrappedMetrics)) {
        // We then apply all metrics to all outputs.
        return outputNames.map(() => wrappedMetrics);
    }
    // In this case, metrics is a dict. Outputs without an entry get no
    // metrics; a non-Array entry is wrapped into a one-element Array.
    return outputNames.map((name) => {
        // Borrow hasOwnProperty so dicts created without a prototype also work.
        const hasEntry = Object.prototype.hasOwnProperty.call(wrappedMetrics, name);
        const outputMetrics = hasEntry ? wrappedMetrics[name] : [];
        return Array.isArray(outputMetrics) ? outputMetrics : [outputMetrics];
    });
}
// Format identifier string for Layers models. Its consumers are not visible in
// this part of the file — presumably written into saved model artifacts;
// TODO confirm against the save/load code.
const LAYERS_MODEL_FORMAT_NAME = 'layers-model';
348/**
349 * A `tf.LayersModel` is a directed, acyclic graph of `tf.Layer`s plus methods
350 * for training, evaluation, prediction and saving.
351 *
 * `tf.LayersModel` is the basic unit of training, inference and evaluation in
 * TensorFlow.js. To create a `tf.LayersModel`, use `tf.model`.
354 *
355 * See also:
356 * `tf.Sequential`, `tf.loadLayersModel`.
357 *
358 * @doc {heading: 'Models', subheading: 'Classes'}
359 */
360export class LayersModel extends Container {
/**
 * Creates the model by forwarding `args` unchanged to the `Container`
 * constructor.
 *
 * @param args Construction arguments, passed straight to `super`.
 */
constructor(args) {
    super(args);
    // Starts out false; presumably toggled while a fit call is running —
    // the code that flips it is not visible here, TODO confirm.
    this.isTraining = false;
}
365 /**
366 * Print a text summary of the model's layers.
367 *
368 * The summary includes
369 * - Name and type of all layers that comprise the model.
370 * - Output shape(s) of the layers
371 * - Number of weight parameters of each layer
372 * - If the model has non-sequential-like topology, the inputs each layer
373 * receives
374 * - The total number of trainable and non-trainable parameters of the model.
375 *
376 * ```js
377 * const input1 = tf.input({shape: [10]});
378 * const input2 = tf.input({shape: [20]});
379 * const dense1 = tf.layers.dense({units: 4}).apply(input1);
380 * const dense2 = tf.layers.dense({units: 8}).apply(input2);
381 * const concat = tf.layers.concatenate().apply([dense1, dense2]);
382 * const output =
383 * tf.layers.dense({units: 3, activation: 'softmax'}).apply(concat);
384 *
385 * const model = tf.model({inputs: [input1, input2], outputs: output});
386 * model.summary();
387 * ```
388 *
389 * @param lineLength Custom line length, in number of characters.
390 * @param positions Custom widths of each of the columns, as either
391 * fractions of `lineLength` (e.g., `[0.5, 0.75, 1]`) or absolute number
392 * of characters (e.g., `[30, 50, 65]`). Each number corresponds to
393 * right-most (i.e., ending) position of a column.
394 * @param printFn Custom print function. Can be used to replace the default
395 * `console.log`. For example, you can use `x => {}` to mute the printed
396 * messages in the console.
397 *
398 * @doc {heading: 'Models', subheading: 'Classes'}
399 */
400 summary(lineLength, positions, printFn = console.log) {
401 if (!this.built) {
402 throw new ValueError(`This model has never been called, thus its weights have not been ` +
403 `created yet. So no summary can be displayed. Build the model ` +
404 `first (e.g., by calling it on some test data).`);
405 }
406 printSummary(this, lineLength, positions, printFn);
407 }
408 /**
409 * Configures and prepares the model for training and evaluation. Compiling
410 * outfits the model with an optimizer, loss, and/or metrics. Calling `fit`
411 * or `evaluate` on an un-compiled model will throw an error.
412 *
413 * @param args a `ModelCompileArgs` specifying the loss, optimizer, and
414 * metrics to be used for fitting and evaluating this model.
415 *
416 * @doc {heading: 'Models', subheading: 'Classes'}
417 */
418 compile(args) {
419 if (args.loss == null) {
420 args.loss = [];
421 }
422 this.loss = args.loss;
423 if (typeof args.optimizer === 'string') {
424 this.optimizer_ = optimizers.getOptimizer(args.optimizer);
425 this.isOptimizerOwned = true;
426 }
427 else {
428 if (!(args.optimizer instanceof Optimizer)) {
429 throw new ValueError(`User-defined optimizer must be an instance of tf.Optimizer.`);
430 }
431 this.optimizer_ = args.optimizer;
432 this.isOptimizerOwned = false;
433 }
434 // TODO(cais): Add lossWeights.
435 // TODO(cais): Add sampleWeightMode.
436 // Prepare loss functions.
437 let lossFunctions = [];
438 if (!Array.isArray(args.loss) && typeof args.loss !== 'string' &&
439 typeof args.loss !== 'function') {
440 args.loss = args.loss;
441 for (const name in args.loss) {
442 if (this.outputNames.indexOf(name) === -1) {
443 throw new ValueError(`Unknown entry in loss dictionary: "${name}". ` +
444 `Only expected the following keys: ${this.outputNames}`);
445 }
446 }
447 for (const name of this.outputNames) {
448 if (args.loss[name] == null) {
449 console.warn(`Output "${name}" is missing from loss dictionary. We assume ` +
450 `this was done on purpose, and we will not be expecting data ` +
451 `to be passed to ${name} during training`);
452 }
453 lossFunctions.push(losses.get(args.loss[name]));
454 }
455 }
456 else if (Array.isArray(args.loss)) {
457 if (args.loss.length !== this.outputs.length) {
458 throw new ValueError(`When passing an Array as loss, it should have one entry per ` +
459 `model output. The model has ${this.outputs.length} output(s), ` +
460 `but you passed loss=${args.loss}.`);
461 }
462 const theLosses = args.loss;
463 lossFunctions = theLosses.map(l => losses.get(l));
464 }
465 else {
466 const lossFunction = losses.get(args.loss);
467 this.outputs.forEach(_ => {
468 lossFunctions.push(lossFunction);
469 });
470 }
471 this.lossFunctions = lossFunctions;
472 this.feedOutputNames = [];
473 this.feedOutputShapes = [];
474 this.feedLossFns = [];
475 for (let i = 0; i < this.outputs.length; ++i) {
476 // TODO(cais): Logic for skipping target(s).
477 const shape = this.internalOutputShapes[i];
478 const name = this.outputNames[i];
479 this.feedOutputNames.push(name);
480 this.feedOutputShapes.push(shape);
481 this.feedLossFns.push(this.lossFunctions[i]);
482 }
483 // TODO(cais): Add logic for output masks.
484 // TODO(cais): Add logic for sample weights.
485 const skipTargetIndices = [];
486 // Prepare metrics.
487 this.metrics = args.metrics;
488 // TODO(cais): Add weightedMetrics.
489 this.metricsNames = ['loss'];
490 this.metricsTensors = [];
491 // Compute total loss.
492 // Porting Note: In PyKeras, metrics_tensors are symbolic tensor objects.
493 // Here, metricsTensors are TypeScript functions. This difference is due
494 // to the difference in symbolic/imperative property of the backends.
495 nameScope('loss', () => {
496 for (let i = 0; i < this.outputs.length; ++i) {
497 if (skipTargetIndices.indexOf(i) !== -1) {
498 continue;
499 }
500 // TODO(cais): Add weightedLoss, sampleWeight and mask.
501 // The following line should be weightedLoss
502 const weightedLoss = this.lossFunctions[i];
503 if (this.outputs.length > 1) {
504 this.metricsTensors.push([weightedLoss, i]);
505 this.metricsNames.push(this.outputNames[i] + '_loss');
506 }
507 }
508 // Porting Note: Due to the imperative nature of the backend, we calculate
509 // the regularizer penalties in the totalLossFunction, instead of here.
510 });
511 const nestedMetrics = collectMetrics(args.metrics, this.outputNames);
512 // TODO(cais): Add nestedWeightedMetrics.
513 /**
514 * Helper function used in loop below.
515 */
516 const appendMetric = (outputIndex, metricName, metricTensor) => {
517 if (this.outputNames.length > 1) {
518 metricName = this.outputNames[outputIndex] + '_' + metricName;
519 }
520 this.metricsNames.push(metricName);
521 this.metricsTensors.push([metricTensor, outputIndex]);
522 };
523 nameScope('metric', () => {
524 for (let i = 0; i < this.outputs.length; ++i) {
525 if (skipTargetIndices.indexOf(i) !== -1) {
526 continue;
527 }
528 const outputMetrics = nestedMetrics[i];
529 // TODO(cais): Add weights and outputWeightedMetrics.
530 // TODO(cais): Add optional arg `weights` to the following function.
531 const handleMetrics = (metrics) => {
532 const metricNamePrefix = '';
533 let metricName;
534 let accFn;
535 let weightedMetricFn;
536 // TODO(cais): Use 'weights_' for weighted metrics.
537 for (const metric of metrics) {
538 if (typeof metric === 'string' &&
539 ['accuracy', 'acc', 'crossentropy', 'ce'].indexOf(metric) !==
540 -1) {
541 const outputShape = this.internalOutputShapes[i];
542 if (outputShape[outputShape.length - 1] === 1 ||
543 this.lossFunctions[i] === losses.binaryCrossentropy) {
544 // case: binary accuracy/crossentropy.
545 if (['accuracy', 'acc'].indexOf(metric) !== -1) {
546 accFn = Metrics.binaryAccuracy;
547 }
548 else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
549 accFn = Metrics.binaryCrossentropy;
550 }
551 }
552 else if (this.lossFunctions[i] ===
553 losses.sparseCategoricalCrossentropy) {
554 // case: categorical accuracy / crossentropy with sparse
555 // targets.
556 if (['accuracy', 'acc'].indexOf(metric) !== -1) {
557 accFn = Metrics.sparseCategoricalAccuracy;
558 }
559 else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
560 accFn = Metrics.sparseCategoricalCrossentropy;
561 }
562 }
563 else {
564 // case: categorical accuracy / crossentropy.
565 if (['accuracy', 'acc'].indexOf(metric) !== -1) {
566 accFn = Metrics.categoricalAccuracy;
567 }
568 else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
569 accFn = Metrics.categoricalCrossentropy;
570 }
571 }
572 let suffix;
573 if (['accuracy', 'acc'].indexOf(metric) !== -1) {
574 suffix = 'acc';
575 }
576 else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
577 suffix = 'ce';
578 }
579 // TODO(cais): Add weighting actually.
580 weightedMetricFn = accFn;
581 metricName = metricNamePrefix + suffix;
582 }
583 else {
584 const metricFn = Metrics.get(metric);
585 // TODO(cais): Add weighting actually.
586 weightedMetricFn = metricFn;
587 metricName =
588 metricNamePrefix + Metrics.getLossOrMetricName(metric);
589 }
590 // TODO(cais): Add weighting and masking to metricResult.
591 let metricResult;
592 nameScope(metricName, () => {
593 metricResult = weightedMetricFn;
594 });
595 appendMetric(i, metricName, metricResult);
596 }
597 };
598 handleMetrics(outputMetrics);
599 // TODO(cais): Call handleMetrics with weights.
600 }
601 });
602 // Porting Notes: Given the imperative backend of tfjs-core,
603 // there is no need for constructing the symbolic graph and placeholders.
604 this.collectedTrainableWeights = this.trainableWeights;
605 }
606 /**
607 * Check trainable weights count consistency.
608 *
609 * This will raise a warning if `this.trainableWeights` and
610 * `this.collectedTrainableWeights` are inconsistent (i.e., have different
611 * numbers of parameters).
612 * Inconsistency will typically arise when one modifies `model.trainable`
613 * without calling `model.compile()` again.
614 */
615 checkTrainableWeightsConsistency() {
616 if (this.collectedTrainableWeights == null) {
617 return;
618 }
619 if (this.trainableWeights.length !==
620 this.collectedTrainableWeights.length) {
621 console.warn('Discrepancy between trainableweights and collected trainable ' +
622 'weights. Did you set `model.trainable` without calling ' +
623 '`model.compile()` afterwards?');
624 }
625 }
626 /**
627 * Returns the loss value & metrics values for the model in test mode.
628 *
629 * Loss and metrics are specified during `compile()`, which needs to happen
630 * before calls to `evaluate()`.
631 *
632 * Computation is done in batches.
633 *
634 * ```js
635 * const model = tf.sequential({
636 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
637 * });
638 * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
639 * const result = model.evaluate(
640 * tf.ones([8, 10]), tf.ones([8, 1]), {batchSize: 4});
641 * result.print();
642 * ```
643 *
644 * @param x `tf.Tensor` of test data, or an `Array` of `tf.Tensor`s if the
645 * model has multiple inputs.
646 * @param y `tf.Tensor` of target data, or an `Array` of `tf.Tensor`s if the
647 * model has multiple outputs.
648 * @param args A `ModelEvaluateArgs`, containing optional fields.
649 *
650 * @return `Scalar` test loss (if the model has a single output and no
651 * metrics) or `Array` of `Scalar`s (if the model has multiple outputs
652 * and/or metrics). The attribute `model.metricsNames`
653 * will give you the display labels for the scalar outputs.
654 *
655 * @doc {heading: 'Models', subheading: 'Classes'}
656 */
657 evaluate(x, y, args = {}) {
658 const batchSize = args.batchSize == null ? 32 : args.batchSize;
659 checkBatchSize(batchSize);
660 // TODO(cais): Standardize `config.sampleWeights` as well.
661 // Validate user data.
662 const checkBatchAxis = true;
663 const standardizedOuts = this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize);
664 try {
665 // TODO(cais): If uses `useLearningPhase`, set the corresponding element
666 // of the input to 0.
667 const ins = standardizedOuts[0].concat(standardizedOuts[1]);
668 this.makeTestFunction();
669 const f = this.testFunction;
670 const testOuts = this.testLoop(f, ins, batchSize, args.verbose, args.steps);
671 return singletonOrArray(testOuts);
672 }
673 finally {
674 disposeNewTensors(standardizedOuts[0], x);
675 disposeNewTensors(standardizedOuts[1], y);
676 }
677 }
678 // TODO(cais): Add code snippet below once real dataset objects are
679 // available.
680 /**
681 * Evaluate model using a dataset object.
682 *
683 * Note: Unlike `evaluate()`, this method is asynchronous (`async`);
684 *
685 * @param dataset A dataset object. Its `iterator()` method is expected
686 * to generate a dataset iterator object, the `next()` method of which
687 * is expected to produce data batches for evaluation. The return value
688 * of the `next()` call ought to contain a boolean `done` field and a
689 * `value` field. The `value` field is expected to be an array of two
690 * `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former
691 * case is for models with exactly one input and one output (e.g..
692 * a sequential model). The latter case is for models with multiple
693 * inputs and/or multiple outputs. Of the two items in the array, the
694 * first is the input feature(s) and the second is the output target(s).
695 * @param args A configuration object for the dataset-based evaluation.
696 * @returns Loss and metric values as an Array of `Scalar` objects.
697 *
698 * @doc {heading: 'Models', subheading: 'Classes'}
699 */
700 async evaluateDataset(dataset, args) {
701 this.makeTestFunction();
702 return evaluateDataset(this, dataset, args);
703 }
704 /**
705 * Get number of samples provided for training, evaluation or prediction.
706 *
707 * @param ins Input `tf.Tensor`.
708 * @param batchSize Integer batch size, optional.
709 * @param steps Total number of steps (batches of samples) before
710 * declaring loop finished. Optional.
711 * @param stepsName The public API's parameter name for `steps`.
712 * @returns Number of samples provided.
713 */
714 checkNumSamples(ins, batchSize, steps, stepsName = 'steps') {
715 let numSamples;
716 if (steps != null) {
717 numSamples = null;
718 if (batchSize != null) {
719 throw new ValueError(`If ${stepsName} is set, batchSize must be null or undefined.` +
720 `Got batchSize = ${batchSize}`);
721 }
722 }
723 else if (ins != null) {
724 if (Array.isArray(ins)) {
725 numSamples = ins[0].shape[0];
726 }
727 else {
728 numSamples = ins.shape[0];
729 }
730 }
731 else {
732 throw new ValueError(`Either the input data should have a defined shape, or ` +
733 `${stepsName} shoud be specified.`);
734 }
735 return numSamples;
736 }
737 /**
738 * Execute internal tensors of the model with input data feed.
739 * @param inputs Input data feed. Must match the inputs of the model.
740 * @param outputs Names of the output tensors to be fetched. Must match
741 * names of the SymbolicTensors that belong to the graph.
742 * @returns Fetched values for `outputs`.
743 */
744 execute(inputs, outputs) {
745 if (Array.isArray(outputs) && outputs.length === 0) {
746 throw new ValueError('`outputs` is an empty Array, which is not allowed.');
747 }
748 const outputsIsArray = Array.isArray(outputs);
749 const outputNames = (outputsIsArray ? outputs : [outputs]);
750 const outputSymbolicTensors = this.retrieveSymbolicTensors(outputNames);
751 // Format the input into a FeedDict.
752 const feedDict = new FeedDict();
753 if (inputs instanceof Tensor) {
754 inputs = [inputs];
755 }
756 if (Array.isArray(inputs)) {
757 if (inputs.length !== this.inputs.length) {
758 throw new ValueError(`The number of inputs provided (${inputs.length}) ` +
759 `does not match the number of inputs of this model ` +
760 `(${this.inputs.length}).`);
761 }
762 for (let i = 0; i < this.inputs.length; ++i) {
763 feedDict.add(this.inputs[i], inputs[i]);
764 }
765 }
766 else {
767 for (const input of this.inputs) {
768 const tensorValue = inputs[input.name];
769 if (tensorValue == null) {
770 throw new ValueError(`No value is provided for the model's input ${input.name}`);
771 }
772 feedDict.add(input, tensorValue);
773 }
774 }
775 // Run execution.
776 const executeOutputs = execute(outputSymbolicTensors, feedDict);
777 return outputsIsArray ? executeOutputs : executeOutputs[0];
778 }
779 /**
780 * Retrieve the model's internal symbolic tensors from symbolic-tensor names.
781 */
782 retrieveSymbolicTensors(symbolicTensorNames) {
783 const outputSymbolicTensors = pyListRepeat(null, symbolicTensorNames.length);
784 let outputsRemaining = symbolicTensorNames.length;
785 for (const layer of this.layers) {
786 const layerOutputs = Array.isArray(layer.output) ? layer.output : [layer.output];
787 const layerOutputNames = layerOutputs.map(output => output.name);
788 for (let i = 0; i < symbolicTensorNames.length; ++i) {
789 const index = layerOutputNames.indexOf(symbolicTensorNames[i]);
790 if (index !== -1) {
791 outputSymbolicTensors[i] = layerOutputs[index];
792 outputsRemaining--;
793 }
794 if (outputsRemaining === 0) {
795 break;
796 }
797 }
798 if (outputsRemaining === 0) {
799 break;
800 }
801 }
802 if (outputsRemaining > 0) {
803 const remainingNames = [];
804 outputSymbolicTensors.forEach((tensor, i) => {
805 if (tensor == null) {
806 remainingNames.push(symbolicTensorNames[i]);
807 }
808 });
809 throw new ValueError(`Cannot find SymbolicTensors for output name(s): ` +
810 `${JSON.stringify(remainingNames)}`);
811 }
812 return outputSymbolicTensors;
813 }
814 /**
815 * Helper method to loop over some data in batches.
816 *
817 * Porting Note: Not using the functional approach in the Python equivalent
818 * due to the imperative backend.
819 * Porting Note: Does not support step mode currently.
820 *
821 * @param ins: input data
822 * @param batchSize: integer batch size.
823 * @param verbose: verbosity model
824 * @returns: Predictions as `tf.Tensor` (if a single output) or an `Array` of
825 * `tf.Tensor` (if multipe outputs).
826 */
827 predictLoop(ins, batchSize = 32, verbose = false) {
828 return tfc.tidy(() => {
829 const numSamples = this.checkNumSamples(ins);
830 if (verbose) {
831 throw new NotImplementedError('Verbose predictLoop() is not implemented yet.');
832 }
833 // Sample-based predictions.
834 // Porting Note: Tensor currently does not support sliced assignments as
835 // in numpy, e.g., x[1:3] = y. Therefore we use concatenation while
836 // iterating over the batches.
837 const batches = makeBatches(numSamples, batchSize);
838 const outsBatches = this.outputs.map(output => []);
839 // TODO(cais): Can the scope() be pushed down inside the for loop?
840 for (let batchIndex = 0; batchIndex < batches.length; ++batchIndex) {
841 const batchOuts = tfc.tidy(() => {
842 const batchStart = batches[batchIndex][0];
843 const batchEnd = batches[batchIndex][1];
844 // TODO(cais): Take care of the case of the last element is a flag for
845 // training/test.
846 const insBatch = sliceArrays(ins, batchStart, batchEnd);
847 // Construct the feeds for execute();
848 const feeds = [];
849 if (Array.isArray(insBatch)) {
850 for (let i = 0; i < insBatch.length; ++i) {
851 feeds.push({ key: this.inputs[i], value: insBatch[i] });
852 }
853 }
854 else {
855 feeds.push({ key: this.inputs[0], value: insBatch });
856 }
857 const feedDict = new FeedDict(feeds);
858 return execute(this.outputs, feedDict);
859 });
860 batchOuts.forEach((batchOut, i) => outsBatches[i].push(batchOut));
861 }
862 return singletonOrArray(outsBatches.map(batches => tfc.concat(batches, 0)));
863 });
864 }
865 /**
866 * Generates output predictions for the input samples.
867 *
868 * Computation is done in batches.
869 *
870 * Note: the "step" mode of predict() is currently not supported.
871 * This is because the TensorFlow.js core backend is imperative only.
872 *
873 * ```js
874 * const model = tf.sequential({
875 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
876 * });
877 * model.predict(tf.ones([8, 10]), {batchSize: 4}).print();
878 * ```
879 *
880 * @param x The input data, as a Tensor, or an `Array` of `tf.Tensor`s if
881 * the model has multiple inputs.
882 * @param args A `ModelPredictArgs` object containing optional fields.
883 *
884 * @return Prediction results as a `tf.Tensor`(s).
885 *
886 * @exception ValueError In case of mismatch between the provided input data
887 * and the model's expectations, or in case a stateful model receives a
888 * number of samples that is not a multiple of the batch size.
889 *
890 * @doc {heading: 'Models', subheading: 'Classes'}
891 */
892 predict(x, args = {}) {
893 const xsRank2OrHigher = ensureTensorsRank2OrHigher(x);
894 checkInputData(xsRank2OrHigher, this.inputNames, this.feedInputShapes, false);
895 try {
896 // TODO(cais): Take care of stateful models.
897 // if (this.stateful) ...
898 // TODO(cais): Take care of the learning_phase boolean flag.
899 // if (this.useLearningPhase) ...
900 const batchSize = args.batchSize == null ? 32 : args.batchSize;
901 checkBatchSize(batchSize);
902 return this.predictLoop(xsRank2OrHigher, batchSize);
903 }
904 finally {
905 disposeNewTensors(xsRank2OrHigher, x);
906 }
907 }
908 /**
909 * Returns predictions for a single batch of samples.
910 *
911 * ```js
912 * const model = tf.sequential({
913 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
914 * });
915 * model.predictOnBatch(tf.ones([8, 10])).print();
916 * ```
917 * @param x: Input samples, as a Tensor (for models with exactly one
918 * input) or an array of Tensors (for models with more than one input).
919 * @return Tensor(s) of predictions
920 *
921 * @doc {heading: 'Models', subheading: 'Classes'}
922 */
923 predictOnBatch(x) {
924 checkInputData(x, this.inputNames, this.feedInputShapes, true);
925 // TODO(cais): Take care of the learning_phase boolean flag.
926 // if (this.useLearningPhase) ...
927 const batchSize = (Array.isArray(x) ? x[0] : x).shape[0];
928 return this.predictLoop(x, batchSize);
929 }
930 standardizeUserDataXY(x, y, checkBatchAxis = true, batchSize) {
931 // TODO(cais): Add sampleWeight, classWeight
932 if (this.optimizer_ == null) {
933 throw new RuntimeError('You must compile a model before training/testing. Use ' +
934 'LayersModel.compile(modelCompileArgs).');
935 }
936 const outputShapes = [];
937 for (let i = 0; i < this.feedOutputShapes.length; ++i) {
938 const outputShape = this.feedOutputShapes[i];
939 const lossFn = this.feedLossFns[i];
940 if (lossFn === losses.sparseCategoricalCrossentropy) {
941 outputShapes.push(outputShape.slice(0, outputShape.length - 1).concat([1]));
942 }
943 else {
944 // Porting Note: Because of strong typing `lossFn` must be a function.
945 outputShapes.push(outputShape);
946 }
947 }
948 x = standardizeInputData(x, this.feedInputNames, this.feedInputShapes, false, 'input');
949 y = standardizeInputData(y, this.feedOutputNames, outputShapes, false, 'target');
950 // TODO(cais): Standardize sampleWeights & classWeights.
951 checkArrayLengths(x, y, null);
952 // TODO(cais): Check sampleWeights as well.
953 checkLossAndTargetCompatibility(y, this.feedLossFns, this.feedOutputShapes);
954 if (this.stateful && batchSize != null && batchSize > 0) {
955 if (x[0].shape[0] % batchSize !== 0) {
956 throw new ValueError(`In a stateful network, you should only pass inputs with a ` +
957 `number of samples that is divisible by the batch size ` +
958 `${batchSize}. Found: ${x[0].shape[0]} sample(s).`);
959 }
960 }
961 return [x, y];
962 }
963 async standardizeUserData(x, y, sampleWeight, classWeight, checkBatchAxis = true, batchSize) {
964 const [standardXs, standardYs] = this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize);
965 // TODO(cais): Handle sampleWeights.
966 if (sampleWeight != null) {
967 throw new Error('sample weight is not supported yet.');
968 }
969 let standardSampleWeights = null;
970 if (classWeight != null) {
971 const classWeights = standardizeClassWeights(classWeight, this.outputNames);
972 standardSampleWeights = [];
973 for (let i = 0; i < classWeights.length; ++i) {
974 standardSampleWeights.push(await standardizeWeights(standardYs[i], null, classWeights[i]));
975 }
976 }
977 // TODO(cais): Deal with the case of model.stateful == true.
978 return [standardXs, standardYs, standardSampleWeights];
979 }
980 /**
981 * Loop over some test data in batches.
982 * @param f A Function returning a list of tensors.
983 * @param ins Array of tensors to be fed to `f`.
984 * @param batchSize Integer batch size or `null` / `undefined`.
985 * @param verbose verbosity mode.
986 * @param steps Total number of steps (batches of samples) before
987 * declaring test finished. Ignored with the default value of `null` /
988 * `undefined`.
989 * @returns Array of Scalars.
990 */
991 testLoop(f, ins, batchSize, verbose = 0, steps) {
992 return tfc.tidy(() => {
993 const numSamples = this.checkNumSamples(ins, batchSize, steps, 'steps');
994 const outs = [];
995 if (verbose > 0) {
996 throw new NotImplementedError('Verbose mode is not implemented yet.');
997 }
998 // TODO(cais): Use `indicesForConversionToDense' to prevent slow down.
999 if (steps != null) {
1000 throw new NotImplementedError('steps mode in testLoop() is not implemented yet');
1001 }
1002 else {
1003 const batches = makeBatches(numSamples, batchSize);
1004 const indexArray = tensor1d(range(0, numSamples));
1005 for (let batchIndex = 0; batchIndex < batches.length; ++batchIndex) {
1006 const batchStart = batches[batchIndex][0];
1007 const batchEnd = batches[batchIndex][1];
1008 const batchIds = K.sliceAlongFirstAxis(indexArray, batchStart, batchEnd - batchStart);
1009 // TODO(cais): In ins, train flag can be a number, instead of an
1010 // Tensor? Do we need to handle this in tfjs-layers?
1011 const insBatch = sliceArraysByIndices(ins, batchIds);
1012 const batchOuts = f(insBatch);
1013 if (batchIndex === 0) {
1014 for (let i = 0; i < batchOuts.length; ++i) {
1015 outs.push(scalar(0));
1016 }
1017 }
1018 for (let i = 0; i < batchOuts.length; ++i) {
1019 const batchOut = batchOuts[i];
1020 outs[i] =
1021 tfc.add(outs[i], tfc.mul(batchEnd - batchStart, batchOut));
1022 }
1023 }
1024 for (let i = 0; i < outs.length; ++i) {
1025 outs[i] = tfc.div(outs[i], numSamples);
1026 }
1027 }
1028 return outs;
1029 });
1030 }
1031 getDedupedMetricsNames() {
1032 const outLabels = this.metricsNames;
1033 // Rename duplicated metrics names (can happen with an output layer
1034 // shared among multiple dataflows).
1035 const dedupedOutLabels = [];
1036 for (let i = 0; i < outLabels.length; ++i) {
1037 const label = outLabels[i];
1038 let newLabel = label;
1039 if (count(outLabels, label) > 1) {
1040 const dupIndex = count(outLabels.slice(0, i), label);
1041 newLabel += `_${dupIndex}`;
1042 }
1043 dedupedOutLabels.push(newLabel);
1044 }
1045 return dedupedOutLabels;
1046 }
1047 /**
1048 * Creates a function that performs the following actions:
1049 *
1050 * 1. computes the losses
1051 * 2. sums them to get the total loss
1052 * 3. call the optimizer computes the gradients of the LayersModel's
1053 * trainable weights w.r.t. the total loss and update the variables
1054 * 4. calculates the metrics
1055 * 5. returns the values of the losses and metrics.
1056 */
1057 makeTrainFunction() {
1058 return (data) => {
1059 const lossValues = [];
1060 const inputs = data.slice(0, this.inputs.length);
1061 const targets = data.slice(this.inputs.length, this.inputs.length + this.outputs.length);
1062 const sampleWeights = data.slice(this.inputs.length + this.outputs.length, this.inputs.length + this.outputs.length * 2);
1063 const metricsValues = [];
1064 // Create a function that computes the total loss based on the
1065 // inputs. This function is used for obtaining gradients through
1066 // backprop.
1067 const totalLossFunction = () => {
1068 const feeds = [];
1069 for (let i = 0; i < this.inputs.length; ++i) {
1070 feeds.push({ key: this.inputs[i], value: inputs[i] });
1071 }
1072 const feedDict = new FeedDict(feeds);
1073 const outputs = execute(this.outputs, feedDict, { 'training': true });
1074 // TODO(cais): Take care of the case of multiple outputs from a
1075 // single layer?
1076 let totalLoss;
1077 for (let i = 0; i < this.lossFunctions.length; ++i) {
1078 const lossFunction = this.lossFunctions[i];
1079 let loss = lossFunction(targets[i], outputs[i]);
1080 if (sampleWeights[i] != null) {
1081 loss = computeWeightedLoss(loss, sampleWeights[i]);
1082 }
1083 // TODO(cais): push Scalar instead.
1084 const meanLoss = tfc.mean(loss);
1085 // TODO(cais): Use a scope() instead, to avoid ownership.
1086 lossValues.push(meanLoss);
1087 if (i === 0) {
1088 totalLoss = loss;
1089 }
1090 else {
1091 totalLoss = tfc.add(totalLoss, loss);
1092 }
1093 }
1094 // Compute the metrics.
1095 // TODO(cais): These should probably be calculated outside
1096 // totalLossFunction to benefit speed?
1097 for (let i = 0; i < this.metricsTensors.length; ++i) {
1098 let weightedMetric;
1099 if (this.outputs.length > 1 && i < this.outputs.length) {
1100 weightedMetric = lossValues[i];
1101 }
1102 else {
1103 const metric = this.metricsTensors[i][0];
1104 const outputIndex = this.metricsTensors[i][1];
1105 weightedMetric =
1106 tfc.mean(metric(targets[outputIndex], outputs[outputIndex]));
1107 }
1108 tfc.keep(weightedMetric);
1109 // TODO(cais): Use a scope() instead, to avoid ownership.
1110 metricsValues.push(weightedMetric);
1111 }
1112 totalLoss = tfc.mean(totalLoss);
1113 // Add regularizer penalties.
1114 this.calculateLosses().forEach(regularizerLoss => {
1115 totalLoss = tfc.add(totalLoss, regularizerLoss);
1116 });
1117 return totalLoss;
1118 };
1119 const variables = this.collectedTrainableWeights.map(param => param.read());
1120 const returnCost = true;
1121 const totalLossValue = this.optimizer_.minimize(totalLossFunction, returnCost, variables);
1122 return [totalLossValue].concat(metricsValues);
1123 };
1124 }
1125 /**
1126 * Create a function which, when invoked with an array of `tf.Tensor`s as a
1127 * batch of inputs, returns the prespecified loss and metrics of the model
1128 * under the batch of input data.
1129 */
1130 makeTestFunction() {
1131 this.testFunction = (data) => {
1132 return tfc.tidy(() => {
1133 const valOutputs = [];
1134 let totalLoss;
1135 const inputs = data.slice(0, this.inputs.length);
1136 const targets = data.slice(this.inputs.length, this.inputs.length + this.outputs.length);
1137 const feeds = [];
1138 for (let i = 0; i < this.inputs.length; ++i) {
1139 feeds.push({ key: this.inputs[i], value: inputs[i] });
1140 }
1141 const feedDict = new FeedDict(feeds);
1142 const outputs = execute(this.outputs, feedDict);
1143 // Compute total loss.
1144 for (let i = 0; i < this.lossFunctions.length; ++i) {
1145 const lossFunction = this.lossFunctions[i];
1146 // TODO(cais): Add sample weighting and replace the simple
1147 // averaging.
1148 const loss = tfc.mean(lossFunction(targets[i], outputs[i]));
1149 if (i === 0) {
1150 totalLoss = loss;
1151 }
1152 else {
1153 totalLoss = tfc.add(totalLoss, loss);
1154 }
1155 valOutputs.push(totalLoss);
1156 }
1157 // Compute the metrics.
1158 for (let i = 0; i < this.metricsTensors.length; ++i) {
1159 const metric = this.metricsTensors[i][0];
1160 const outputIndex = this.metricsTensors[i][1];
1161 // TODO(cais): Replace K.mean() with a proper weighting function.
1162 const meanMetric = tfc.mean(metric(targets[outputIndex], outputs[outputIndex]));
1163 valOutputs.push(meanMetric);
1164 }
1165 return valOutputs;
1166 });
1167 };
1168 }
1169 /**
1170 * Trains the model for a fixed number of epochs (iterations on a
1171 * dataset).
1172 *
1173 * ```js
1174 * const model = tf.sequential({
1175 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
1176 * });
1177 * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
1178 * for (let i = 1; i < 5 ; ++i) {
1179 * const h = await model.fit(tf.ones([8, 10]), tf.ones([8, 1]), {
1180 * batchSize: 4,
1181 * epochs: 3
1182 * });
1183 * console.log("Loss after Epoch " + i + " : " + h.history.loss[0]);
1184 * }
1185 * ```
1186 *
1187 * @param x `tf.Tensor` of training data, or an array of `tf.Tensor`s if the
1188 * model has multiple inputs. If all inputs in the model are named, you
1189 * can also pass a dictionary mapping input names to `tf.Tensor`s.
1190 * @param y `tf.Tensor` of target (label) data, or an array of `tf.Tensor`s if
1191 * the model has multiple outputs. If all outputs in the model are named,
1192 * you can also pass a dictionary mapping output names to `tf.Tensor`s.
1193 * @param args A `ModelFitArgs`, containing optional fields.
1194 *
1195 * @return A `History` instance. Its `history` attribute contains all
1196 * information collected during training.
1197 *
1198 * @exception ValueError In case of mismatch between the provided input
1199 * data and what the model expects.
1200 *
1201 * @doc {heading: 'Models', subheading: 'Classes'}
1202 */
async fit(x, y, args = {}) {
    // Thin wrapper: the training loop itself lives in `fitTensors()`
    // (imported from './training_tensors'), which receives this model.
    return fitTensors(this, x, y, args);
}
1206 // TODO(cais): Add code snippet below when it's possible to instantiate
1207 // actual dataset objects.
/**
 * Trains the model using a dataset object.
 *
 * @param dataset A dataset object. Its `iterator()` method is expected
 *   to generate a dataset iterator object, the `next()` method of which
 *   is expected to produce data batches for training. The return value
 *   of the `next()` call ought to contain a boolean `done` field and a
 *   `value` field. The `value` field is expected to be an array of two
 *   `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former
 *   case is for models with exactly one input and one output (e.g.,
 *   a sequential model). The latter case is for models with multiple
 *   inputs and/or multiple outputs.
 *   Of the two items in the array, the first is the input feature(s) and
 *   the second is the output target(s).
 * @param args A `ModelFitDatasetArgs`, containing optional fields.
 *
 * @return A `History` instance. Its `history` attribute contains all
 *   information collected during training.
 *
 * @doc {heading: 'Models', subheading: 'Classes'}
 */
async fitDataset(dataset, args) {
    // Thin wrapper: delegates to the module-level `fitDataset()` imported
    // from './training_dataset', passing this model as the first argument.
    return fitDataset(this, dataset, args);
}
1232 /**
1233 * Runs a single gradient update on a single batch of data.
1234 *
1235 * This method differs from `fit()` and `fitDataset()` in the following
1236 * regards:
1237 * - It operates on exactly one batch of data.
1238 * - It returns only the loss and matric values, instead of
1239 * returning the batch-by-batch loss and metric values.
1240 * - It doesn't support fine-grained options such as verbosity and
1241 * callbacks.
1242 *
1243 * @param x Input data. It could be one of the following:
1244 * - A `tf.Tensor`, or an Array of `tf.Tensor`s (in case the model has
1245 * multiple inputs).
1246 * - An Object mapping input names to corresponding `tf.Tensor` (if the
1247 * model has named inputs).
1248 * @param y Target darta. It could be either a `tf.Tensor` a multiple
1249 * `tf.Tensor`s. It should be consistent with `x`.
1250 * @returns Training loss or losses (in case the model has
1251 * multiple outputs), along with metrics (if any), as numbers.
1252 *
1253 * @doc {heading: 'Models', subheading: 'Classes'}
1254 */
1255 async trainOnBatch(x, y) {
1256 // TODO(cais): Support sampleWeight and classWeight.
1257 // TODO(cais): Support Dataset objects.
1258 const standardizeOut = await this.standardizeUserData(x, y);
1259 const inputs = standardizeOut[0];
1260 const targets = standardizeOut[1];
1261 const trainFunction = this.makeTrainFunction();
1262 const losses = trainFunction(inputs.concat(targets));
1263 const lossValues = [];
1264 for (const loss of losses) {
1265 const v = await loss.data();
1266 lossValues.push(v[0]);
1267 }
1268 tfc.dispose(losses);
1269 disposeNewTensors(standardizeOut[0], x);
1270 disposeNewTensors(standardizeOut[1], y);
1271 return singletonOrArray(lossValues);
1272 }
1273 /**
1274 * Extract weight values of the model.
1275 *
1276 * @param config: An instance of `io.SaveConfig`, which specifies
1277 * model-saving options such as whether only trainable weights are to be
1278 * saved.
1279 * @returns A `NamedTensorMap` mapping original weight names (i.e.,
1280 * non-uniqueified weight names) to their values.
1281 */
1282 getNamedWeights(config) {
1283 const namedWeights = [];
1284 const trainableOnly = config != null && config.trainableOnly;
1285 const weights = trainableOnly ? this.trainableWeights : this.weights;
1286 const weightValues = this.getWeights(trainableOnly);
1287 for (let i = 0; i < weights.length; ++i) {
1288 if (trainableOnly && !weights[i].trainable) {
1289 // Optionally skip non-trainable weights.
1290 continue;
1291 }
1292 namedWeights.push({ name: weights[i].originalName, tensor: weightValues[i] });
1293 }
1294 return namedWeights;
1295 }
1296 /**
1297 * Setter used for force stopping of LayersModel.fit() (i.e., training).
1298 *
1299 * Example:
1300 *
1301 * ```js
1302 * const input = tf.input({shape: [10]});
1303 * const output = tf.layers.dense({units: 1}).apply(input);
1304 * const model = tf.model({inputs: [input], outputs: [output]});
1305 * model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
1306 * const xs = tf.ones([8, 10]);
1307 * const ys = tf.zeros([8, 1]);
1308 *
1309 * const history = await model.fit(xs, ys, {
1310 * epochs: 10,
1311 * callbacks: {
1312 * onEpochEnd: async (epoch, logs) => {
1313 * if (epoch === 2) {
1314 * model.stopTraining = true;
1315 * }
1316 * }
1317 * }
1318 * });
1319 *
1320 * // There should be only 3 values in the loss array, instead of 10
1321 * values,
1322 * // due to the stopping after 3 epochs.
1323 * console.log(history.history.loss);
1324 * ```
1325 */
set stopTraining(stop) {
    // Stored on the backing field `stopTraining_`, which the getter below
    // returns.
    this.stopTraining_ = stop;
}
/**
 * Getter for the flag written by the `stopTraining` setter.
 */
get stopTraining() {
    return this.stopTraining_;
}
1332 get optimizer() {
1333 return this.optimizer_;
1334 }
1335 set optimizer(optimizer) {
1336 if (this.optimizer_ !== optimizer) {
1337 this.optimizer_ = optimizer;
1338 this.isOptimizerOwned = false;
1339 }
1340 }
1341 dispose() {
1342 const result = super.dispose();
1343 if (result.refCountAfterDispose === 0 && this.optimizer != null &&
1344 this.isOptimizerOwned) {
1345 const numTensorsBeforeOptmizerDisposal = tfc.memory().numTensors;
1346 this.optimizer_.dispose();
1347 result.numDisposedVariables +=
1348 numTensorsBeforeOptmizerDisposal - tfc.memory().numTensors;
1349 }
1350 return result;
1351 }
1352 getLossIdentifiers() {
1353 let lossNames;
1354 if (typeof this.loss === 'string') {
1355 lossNames = toSnakeCase(this.loss);
1356 }
1357 else if (Array.isArray(this.loss)) {
1358 for (const loss of this.loss) {
1359 if (typeof loss !== 'string') {
1360 throw new Error('Serialization of non-string loss is not supported.');
1361 }
1362 }
1363 lossNames = this.loss.map(name => toSnakeCase(name));
1364 }
1365 else {
1366 const outputNames = Object.keys(this.loss);
1367 lossNames = {};
1368 const losses = this.loss;
1369 for (const outputName of outputNames) {
1370 if (typeof losses[outputName] === 'string') {
1371 lossNames[outputName] =
1372 toSnakeCase(losses[outputName]);
1373 }
1374 else {
1375 throw new Error('Serialization of non-string loss is not supported.');
1376 }
1377 }
1378 }
1379 return lossNames;
1380 }
1381 getMetricIdentifiers() {
1382 if (typeof this.metrics === 'string' ||
1383 typeof this.metrics === 'function') {
1384 return [toSnakeCase(Metrics.getLossOrMetricName(this.metrics))];
1385 }
1386 else if (Array.isArray(this.metrics)) {
1387 return this.metrics.map(metric => toSnakeCase(Metrics.getLossOrMetricName(metric)));
1388 }
1389 else {
1390 const metricsIdentifiers = {};
1391 for (const key in this.metrics) {
1392 metricsIdentifiers[key] =
1393 toSnakeCase(Metrics.getLossOrMetricName(this.metrics[key]));
1394 }
1395 return metricsIdentifiers;
1396 }
1397 }
1398 getTrainingConfig() {
1399 return {
1400 loss: this.getLossIdentifiers(),
1401 metrics: this.getMetricIdentifiers(),
1402 optimizer_config: {
1403 class_name: this.optimizer.getClassName(),
1404 config: this.optimizer.getConfig()
1405 }
1406 };
1407 // TODO(cais): Add weight_metrics when they are supported.
1408 // TODO(cais): Add sample_weight_mode when it's supported.
1409 // TODO(cais): Add loss_weights when it's supported.
1410 }
1411 loadTrainingConfig(trainingConfig) {
1412 if (trainingConfig.weighted_metrics != null) {
1413 throw new Error('Loading weight_metrics is not supported yet.');
1414 }
1415 if (trainingConfig.loss_weights != null) {
1416 throw new Error('Loading loss_weights is not supported yet.');
1417 }
1418 if (trainingConfig.sample_weight_mode != null) {
1419 throw new Error('Loading sample_weight_mode is not supported yet.');
1420 }
1421 const tsConfig = convertPythonicToTs(trainingConfig.optimizer_config);
1422 const optimizer = deserialize(tsConfig);
1423 let loss;
1424 if (typeof trainingConfig.loss === 'string') {
1425 loss = toCamelCase(trainingConfig.loss);
1426 }
1427 else if (Array.isArray(trainingConfig.loss)) {
1428 loss = trainingConfig.loss.map(lossEntry => toCamelCase(lossEntry));
1429 }
1430 else if (trainingConfig.loss != null) {
1431 loss = {};
1432 for (const key in trainingConfig.loss) {
1433 loss[key] = toCamelCase(trainingConfig.loss[key]);
1434 }
1435 }
1436 let metrics;
1437 if (Array.isArray(trainingConfig.metrics)) {
1438 metrics = trainingConfig.metrics.map(metric => toCamelCase(metric));
1439 }
1440 else if (trainingConfig.metrics != null) {
1441 metrics = {};
1442 for (const key in trainingConfig.metrics) {
1443 metrics[key] = toCamelCase(trainingConfig.metrics[key]);
1444 }
1445 }
1446 this.compile({ loss, metrics, optimizer });
1447 }
1448 /**
1449 * Save the configuration and/or weights of the LayersModel.
1450 *
1451 * An `IOHandler` is an object that has a `save` method of the proper
1452 * signature defined. The `save` method manages the storing or
1453 * transmission of serialized data ("artifacts") that represent the
1454 * model's topology and weights onto or via a specific medium, such as
1455 * file downloads, local storage, IndexedDB in the web browser and HTTP
1456 * requests to a server. TensorFlow.js provides `IOHandler`
1457 * implementations for a number of frequently used saving mediums, such as
1458 * `tf.io.browserDownloads` and `tf.io.browserLocalStorage`. See `tf.io`
1459 * for more details.
1460 *
1461 * This method also allows you to refer to certain types of `IOHandler`s
1462 * as URL-like string shortcuts, such as 'localstorage://' and
1463 * 'indexeddb://'.
1464 *
1465 * Example 1: Save `model`'s topology and weights to browser [local
1466 * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
1467 * then load it back.
1468 *
1469 * ```js
1470 * const model = tf.sequential(
1471 * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
1472 * console.log('Prediction from original model:');
1473 * model.predict(tf.ones([1, 3])).print();
1474 *
1475 * const saveResults = await model.save('localstorage://my-model-1');
1476 *
1477 * const loadedModel = await tf.loadLayersModel('localstorage://my-model-1');
1478 * console.log('Prediction from loaded model:');
1479 * loadedModel.predict(tf.ones([1, 3])).print();
1480 * ```
1481 *
1482 * Example 2. Saving `model`'s topology and weights to browser
1483 * [IndexedDB](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API);
1484 * then load it back.
1485 *
1486 * ```js
1487 * const model = tf.sequential(
1488 * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
1489 * console.log('Prediction from original model:');
1490 * model.predict(tf.ones([1, 3])).print();
1491 *
1492 * const saveResults = await model.save('indexeddb://my-model-1');
1493 *
1494 * const loadedModel = await tf.loadLayersModel('indexeddb://my-model-1');
1495 * console.log('Prediction from loaded model:');
1496 * loadedModel.predict(tf.ones([1, 3])).print();
1497 * ```
1498 *
1499 * Example 3. Saving `model`'s topology and weights as two files
1500 * (`my-model-1.json` and `my-model-1.weights.bin`) downloaded from
1501 * browser.
1502 *
1503 * ```js
1504 * const model = tf.sequential(
1505 * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
1506 * const saveResults = await model.save('downloads://my-model-1');
1507 * ```
1508 *
1509 * Example 4. Send `model`'s topology and weights to an HTTP server.
1510 * See the documentation of `tf.io.http` for more details
1511 * including specifying request parameters and implementation of the
1512 * server.
1513 *
1514 * ```js
1515 * const model = tf.sequential(
1516 * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
1517 * const saveResults = await model.save('http://my-server/model/upload');
1518 * ```
1519 *
1520 * @param handlerOrURL An instance of `IOHandler` or a URL-like,
1521 * scheme-based string shortcut for `IOHandler`.
1522 * @param config Options for saving the model.
1523 * @returns A `Promise` of `SaveResult`, which summarizes the result of
1524 * the saving, such as byte sizes of the saved artifacts for the model's
1525 * topology and weight values.
1526 *
1527 * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
1528 */
1529 async save(handlerOrURL, config) {
1530 if (typeof handlerOrURL === 'string') {
1531 const handlers = io.getSaveHandlers(handlerOrURL);
1532 if (handlers.length === 0) {
1533 throw new ValueError(`Cannot find any save handlers for URL '${handlerOrURL}'`);
1534 }
1535 else if (handlers.length > 1) {
1536 throw new ValueError(`Found more than one (${handlers.length}) save handlers for ` +
1537 `URL '${handlerOrURL}'`);
1538 }
1539 handlerOrURL = handlers[0];
1540 }
1541 if (handlerOrURL.save == null) {
1542 throw new ValueError('LayersModel.save() cannot proceed because the IOHandler ' +
1543 'provided does not have the `save` attribute defined.');
1544 }
1545 const weightDataAndSpecs = await io.encodeWeights(this.getNamedWeights(config));
1546 const returnString = false;
1547 const unusedArg = null;
1548 const modelConfig = this.toJSON(unusedArg, returnString);
1549 const modelArtifacts = {
1550 modelTopology: modelConfig,
1551 format: LAYERS_MODEL_FORMAT_NAME,
1552 generatedBy: `TensorFlow.js tfjs-layers v${version}`,
1553 convertedBy: null,
1554 };
1555 const includeOptimizer = config == null ? false : config.includeOptimizer;
1556 if (includeOptimizer && this.optimizer != null) {
1557 modelArtifacts.trainingConfig = this.getTrainingConfig();
1558 const weightType = 'optimizer';
1559 const { data: optimizerWeightData, specs: optimizerWeightSpecs } = await io.encodeWeights(await this.optimizer.getWeights(), weightType);
1560 weightDataAndSpecs.specs.push(...optimizerWeightSpecs);
1561 weightDataAndSpecs.data = io.concatenateArrayBuffers([weightDataAndSpecs.data, optimizerWeightData]);
1562 }
1563 if (this.userDefinedMetadata != null) {
1564 // Check serialized size of user-defined metadata.
1565 const checkSize = true;
1566 checkUserDefinedMetadata(this.userDefinedMetadata, this.name, checkSize);
1567 modelArtifacts.userDefinedMetadata = this.userDefinedMetadata;
1568 }
1569 modelArtifacts.weightData = weightDataAndSpecs.data;
1570 modelArtifacts.weightSpecs = weightDataAndSpecs.specs;
1571 return handlerOrURL.save(modelArtifacts);
1572 }
/**
 * Set user-defined metadata.
 *
 * The set metadata will be serialized together with the topology
 * and weights of the model during `save()` calls.
 *
 * @param userDefinedMetadata The metadata to attach to the model. It is
 *   passed through `checkUserDefinedMetadata()` (together with the model's
 *   name) before being stored.
 */
setUserDefinedMetadata(userDefinedMetadata) {
    checkUserDefinedMetadata(userDefinedMetadata, this.name);
    this.userDefinedMetadata = userDefinedMetadata;
}
/**
 * Get user-defined metadata.
 *
 * The metadata is supplied via one of the two routes:
 *   1. By calling `setUserDefinedMetadata()`.
 *   2. Loaded during model loading (if the model is constructed
 *      via `tf.loadLayersModel()`.)
 *
 * If no user-defined metadata is available from either of the
 * two routes, this function will return `undefined`.
 *
 * @returns The value currently stored in `this.userDefinedMetadata`.
 */
getUserDefinedMetadata() {
    return this.userDefinedMetadata;
}
1599}
// The class name is 'Model' rather than 'LayersModel' for backwards
// compatibility since this class name shows up in the serialization format.
/** @nocollapse */
LayersModel.className = 'Model';
// Register the class with the tfjs-core `serialization` module.
serialization.registerClass(LayersModel);
/**
 * A `tf.Functional` is an alias to `tf.LayersModel`.
 *
 * The class body is empty: it adds nothing to `LayersModel` beyond its own
 * `className` used for registration below.
 *
 * See also:
 *   `tf.LayersModel`, `tf.Sequential`, `tf.loadLayersModel`.
 */
/** @doc {heading: 'Models', subheading: 'Classes'} */
export class Functional extends LayersModel {
}
// Registered under its own class name, 'Functional'.
Functional.className = 'Functional';
serialization.registerClass(Functional);
1616//# sourceMappingURL=data:application/json;base64,
\No newline at end of file