UNPKG

3.45 kBTypeScriptView Raw
1/**
2 * @license
3 * Copyright 2018 Google LLC
4 *
5 * Use of this source code is governed by an MIT-style
6 * license that can be found in the LICENSE file or at
7 * https://opensource.org/licenses/MIT.
8 * =============================================================================
9 */
10/// <amd-module name="@tensorflow/tfjs-layers/dist/engine/training_utils" />
11import { Tensor } from '@tensorflow/tfjs-core';
12/**
13 * For multi-class classification problems, this object is designed to store a
14 * mapping from class index to the "weight" of the class, where higher-weighted
15 * classes have a larger impact on loss, accuracy, and other metrics.
16 *
17 * This is useful for cases in which you want the model to "pay more attention"
18 * to examples from an under-represented class, e.g., in unbalanced datasets.
19 */
20export declare type ClassWeight = {
21 [classIndex: number]: number;
22};
23/**
24 * Class weighting for a model with multiple outputs.
25 *
26 * This object maps each output name (a string key) to the `ClassWeight` object for that output.
27 */
28export declare type ClassWeightMap = {
29 [outputName: string]: ClassWeight;
30};
31/**
32 * Standardize class weighting objects.
33 *
34 * This function takes a single class-weighting object, an array of them,
35 * or a map from output name to class-weighting object. It compares it to the
36 * output name(s) of the model, based on which it outputs an array of
37 * class-weighting objects whose length matches the number of outputs.
38 *
39 * @param classWeight Input class-weighting object(s).
40 * @param outputNames All output name(s) of the model.
41 * @returns An array of class-weighting objects. The length of the array matches
42 * the model's number of outputs.
43 */
44export declare function standardizeClassWeights(classWeight: ClassWeight | ClassWeight[] | ClassWeightMap, outputNames: string[]): ClassWeight[];
/**
 * Presumably the by-sample counterpart of `standardizeClassWeights` (inferred
 * from the name only; TODO confirm against the implementation).
 *
 * NOTE(review): this declaration's signature is identical to that of
 * `standardizeClassWeights`. Its first parameter is named `classWeight` and
 * typed as class-weighting object(s) even though the function name refers to
 * sample weights. Verify that this naming and typing is intended.
 *
 * @param classWeight A single weighting object, an array of them, or a map
 *   from output name to weighting object.
 * @param outputNames An array of strings; presumably the model's output
 *   name(s), as in `standardizeClassWeights` (TODO confirm).
 * @returns An array of weighting objects.
 */
45export declare function standardizeSampleWeights(classWeight: ClassWeight | ClassWeight[] | ClassWeightMap, outputNames: string[]): ClassWeight[];
46/**
47 * Standardize by-sample and/or by-class weights for training.
48 *
49 * Note that this function operates on one model output at a time. For a model
50 * with multiple outputs, you must call this function multiple times.
51 *
52 * @param y The target tensor that the by-sample and/or by-class weight is for.
53 * The values of y are assumed to encode the classes, either directly
54 * as an integer index, or as one-hot encoding.
55 * @param sampleWeight Optional by-sample weights.
56 * @param classWeight Optional by-class weights: an object mapping class indices
57 * (integers) to a weight (float) to apply to the model's loss for the
58 * samples from this class during training. This can be useful to tell the
59 * model to "pay more attention" to samples from an under-represented class.
60 * @param sampleWeightMode Optional mode for the sample weights. The declared type admits only `'temporal'`.
61 * @returns A Promise that resolves to the weight tensor, whose first-dimension size
62 * matches that of `y`.
63 */
64export declare function standardizeWeights(y: Tensor, sampleWeight?: Tensor, classWeight?: ClassWeight, sampleWeightMode?: 'temporal'): Promise<Tensor>;
65/**
66 * Apply per-sample weights on the loss values from a number of samples.
67 *
68 * @param losses Loss tensor of shape `[batchSize]`.
69 * @param sampleWeights Per-sample weight tensor of shape `[batchSize]`.
70 * @returns Tensor of the same shape as `losses`.
71 */
72export declare function computeWeightedLoss(losses: Tensor, sampleWeights: Tensor): Tensor<import("@tensorflow/tfjs-core").Rank>;