/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
|
10 | /// <amd-module name="@tensorflow/tfjs-layers/dist/engine/training_utils" />
|
11 | import { Tensor } from '@tensorflow/tfjs-core';
|
/**
 * For multi-class classification problems, this object stores a mapping from
 * class index to the "weight" of that class. Classes with a higher weight have
 * a larger impact on loss, accuracy, and other metrics.
 *
 * This is useful when you want the model to "pay more attention" to examples
 * from an under-represented class, e.g., in unbalanced datasets.
 */
export declare type ClassWeight = {
  [classIndex: number]: number;
};
|
/**
 * Class weighting for a model with multiple outputs.
 *
 * Maps each output name of the model to the `ClassWeight` object (class index
 * to weight) used for that output.
 */
export declare type ClassWeightMap = {
  [outputName: string]: ClassWeight;
};
|
/**
 * Standardize class-weighting objects.
 *
 * This function takes a single class-weighting object, an array of them, or a
 * map from output name to class-weighting object. It compares the input
 * against the output name(s) of the model and, based on that, returns an
 * array of class-weighting objects whose length matches the number of outputs.
 *
 * @param classWeight Input class-weighting object(s).
 * @param outputNames All output name(s) of the model.
 * @returns An array of class-weighting objects. The length of the array matches
 *   the model's number of outputs.
 */
export declare function standardizeClassWeights(
  classWeight: ClassWeight | ClassWeight[] | ClassWeightMap,
  outputNames: string[]
): ClassWeight[];
|
/**
 * Standardize sample-weighting objects.
 *
 * The signature is identical to `standardizeClassWeights`. It accepts a single
 * weighting object, an array of them, or a map from output name to weighting
 * object, and returns an array of weighting objects.
 *
 * NOTE(review): the first parameter is named `classWeight` even though this
 * function is for sample weights. Presumably it performs the same
 * per-output normalization as `standardizeClassWeights`, including returning
 * one entry per model output. TODO confirm against the implementation.
 *
 * @param classWeight Input weighting object(s).
 * @param outputNames All output name(s) of the model.
 * @returns An array of weighting objects.
 */
export declare function standardizeSampleWeights(
  classWeight: ClassWeight | ClassWeight[] | ClassWeightMap,
  outputNames: string[]
): ClassWeight[];
|
/**
 * Standardize by-sample and/or by-class weights for training.
 *
 * Note that this function operates on one model output at a time. For a model
 * with multiple outputs, you must call this function once per output.
 *
 * @param y The target tensor that the by-sample and/or by-class weight is for.
 *   The values of `y` are assumed to encode the classes, either directly as an
 *   integer index or as a one-hot encoding.
 * @param sampleWeight Optional by-sample weights.
 * @param classWeight Optional by-class weights: an object mapping class indices
 *   (integers) to a weight (float) to apply to the model's loss for the samples
 *   from this class during training. This can be useful to tell the model to
 *   "pay more attention" to samples from an under-represented class.
 * @param sampleWeightMode Optional mode for the sample weights. The only value
 *   this parameter's type accepts is `'temporal'`.
 * @returns A Promise that resolves to a weight tensor whose first dimension has
 *   the same size as that of `y`.
 */
export declare function standardizeWeights(
  y: Tensor,
  sampleWeight?: Tensor,
  classWeight?: ClassWeight,
  sampleWeightMode?: 'temporal'
): Promise<Tensor>;
|
/**
 * Apply per-sample weights on the loss values from a number of samples.
 *
 * @param losses Loss tensor of shape `[batchSize]`.
 * @param sampleWeights Per-sample weight tensor of shape `[batchSize]`.
 * @returns Tensor of the same shape as `losses`.
 */
export declare function computeWeightedLoss(
  losses: Tensor,
  sampleWeights: Tensor
): Tensor<import('@tensorflow/tfjs-core').Rank>;
|