UNPKG

21 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2018 Google LLC
4 *
5 * Use of this source code is governed by an MIT-style
6 * license that can be found in the LICENSE file or at
7 * https://opensource.org/licenses/MIT.
8 * =============================================================================
9 */
10import { argMax, clone, dispose, mul, reshape, tensor1d, tidy } from '@tensorflow/tfjs-core';
/**
 * Normalize a sample- or class-weighting specification into one entry per
 * model output.
 *
 * Accepted forms of `xWeight`:
 *   - `null`/`undefined` or an empty array: no weighting; every output gets
 *     `null`.
 *   - Single-output model: a one-element array (returned as-is), an object
 *     that has the output's name as an own key (its value is used), or any
 *     other value (wrapped in a one-element array).
 *   - Multi-output model: an array with exactly one entry per output
 *     (returned as-is), or an object mapping output names to weighting
 *     objects. Outputs absent from that map get `null`.
 *
 * @param xWeight The weighting specification supplied by the caller.
 * @param outputNames Names of all outputs of the model, in order.
 * @param weightType 'sampleWeight' or 'classWeight'; used only in error
 *     messages.
 * @returns An array with one weighting entry (or `null`) per model output.
 * @throws Error if an array of the wrong length is given for a multi-output
 *     model, if an output-name map contains keys that are not output names,
 *     or if `xWeight` has an unrecognized form.
 */
function standardizeSampleOrClassWeights(xWeight, outputNames, weightType) {
  const numOutputs = outputNames.length;
  if (xWeight == null || (Array.isArray(xWeight) && xWeight.length === 0)) {
    return outputNames.map(name => null);
  }
  if (numOutputs === 1) {
    if (Array.isArray(xWeight) && xWeight.length === 1) {
      return xWeight;
    }
    // Own-property check: `in` would also match Object.prototype members,
    // so an output named e.g. 'constructor' would pick up a function.
    if (typeof xWeight === 'object' &&
        Object.prototype.hasOwnProperty.call(xWeight, outputNames[0])) {
      return [xWeight[outputNames[0]]];
    }
    return [xWeight];
  }
  if (Array.isArray(xWeight)) {
    if (xWeight.length !== numOutputs) {
      throw new Error(`Provided ${weightType} is an array of ${xWeight.length} ` +
          `element(s), but the model has ${numOutputs} outputs. ` +
          `Make sure a set of weights is provided for each model output.`);
    }
    return xWeight;
  }
  const keys = typeof xWeight === 'object' ? Object.keys(xWeight) : [];
  // A map from output name to weighting object is recognized by its first
  // value being an object (a plain weighting object maps to numbers).
  if (keys.length > 0 && typeof xWeight[keys[0]] === 'object') {
    // Reject keys that name no output: silently ignoring them would hide
    // typos and leave the intended output unweighted.
    const unexpectedKeys = keys.filter(key => outputNames.indexOf(key) === -1);
    if (unexpectedKeys.length > 0) {
      throw new Error(`Unknown entries in ${weightType} dictionary: ` +
          `${JSON.stringify(unexpectedKeys)}. Only expected the following ` +
          `keys: ${JSON.stringify(outputNames)}`);
    }
    return outputNames.map(
        outputName => Object.prototype.hasOwnProperty.call(xWeight, outputName) ?
            xWeight[outputName] :
            null);
  }
  throw new Error(`The model has multiple (${numOutputs}) outputs, ` +
      `so ${weightType} must be either an array with ` +
      `${numOutputs} elements or an object with ${outputNames} keys. ` +
      `Provided ${weightType} not understood: ${JSON.stringify(xWeight)}`);
}
/**
 * Normalizes class weights so there is exactly one entry per model output.
 *
 * `classWeight` may be a single class-weighting object, an array of such
 * objects, or an object mapping output names to class-weighting objects. It
 * is matched against the model's output name(s), and the result holds one
 * class-weighting entry for each output.
 *
 * @param classWeight Class-weighting object(s) as supplied by the caller.
 * @param outputNames Name(s) of every output of the model.
 * @return Array of class-weighting objects whose length equals the model's
 *   number of outputs.
 */
export function standardizeClassWeights(classWeight, outputNames) {
  const weightKind = 'classWeight';
  return standardizeSampleOrClassWeights(classWeight, outputNames, weightKind);
}
/**
 * Normalizes sample weights so there is exactly one entry per model output.
 *
 * Accepts the same forms as `standardizeClassWeights`: a single weighting
 * object, an array of them, or an object mapping output names to weighting
 * objects. Error messages refer to the input as `sampleWeight`.
 *
 * @param sampleWeight Sample-weighting object(s) as supplied by the caller.
 * @param outputNames Name(s) of every output of the model.
 * @return Array of weighting objects whose length equals the model's number
 *   of outputs.
 */
export function standardizeSampleWeights(sampleWeight, outputNames) {
  return standardizeSampleOrClassWeights(sampleWeight, outputNames, 'sampleWeight');
}
/**
 * Standardize by-sample and/or by-class weights for training.
 *
 * Note that this function operates on one model output at a time. For a model
 * with multiple outputs, you must call this function multiple times.
 *
 * @param y The target tensor that the by-sample and/or by-class weight is for.
 *   The values of y are assumed to encode the classes, either directly
 *   as an integer index (rank 1, or rank 2 with a last dimension of 1), or as
 *   one-hot encoding (rank 2 with a last dimension greater than 1).
 * @param sampleWeight By-sample weights. Not supported yet; must be omitted.
 * @param classWeight By-class weights: an object mapping class indices
 *   (integers) to a weight (float) to apply to the model's loss for the
 *   samples from this class during training. This can be useful to tell the
 *   model to "pay more attention" to samples from an under-represented class.
 * @param sampleWeightMode The mode for the sample weights. Not supported yet;
 *   must be omitted.
 * @return A Promise of a float32 weight tensor whose size matches the first
 *   dimension of `y`, or of `null` when `classWeight` is not provided.
 * @throws Error if `sampleWeight` or `sampleWeightMode` is given, if `y` does
 *   not have rank 1 or 2, if `y` has an empty second dimension, or if a class
 *   present in `y` has no entry in `classWeight`.
 */
export async function standardizeWeights(y, sampleWeight, classWeight, sampleWeightMode) {
  if (sampleWeight != null || sampleWeightMode != null) {
    // TODO(cais): Once 'temporal' mode is implemented, document it in the doc
    // string.
    throw new Error('Support sampleWeight is not implemented yet');
  }
  if (classWeight == null) {
    return null;
  }
  // Reduce y to a 1D tensor holding one class index per sample.
  const yClasses = tidy(() => {
    if (y.shape.length === 1) {
      // Assume class indices. Clone so that the `dispose` below releases a
      // copy rather than the caller's `y`.
      return clone(y);
    }
    if (y.shape.length === 2) {
      if (y.shape[1] > 1) {
        // Assume one-hot encoding of classes.
        const axis = 1;
        return argMax(y, axis);
      }
      if (y.shape[1] === 1) {
        // Class index.
        return reshape(y, [y.shape[0]]);
      }
      throw new Error(`Encountered unexpected last-dimension size (${y.shape[1]}) ` +
          `during handling of class weights. The size is expected to be ` +
          `>= 1.`);
    }
    throw new Error(`Unexpected rank of target (y) tensor (${y.rank}) during ` +
        `handling of class weights. The rank is expected to be 1 or 2.`);
  });
  let yClassIndices;
  try {
    yClassIndices = Array.from(await yClasses.data());
  }
  finally {
    // Release yClasses even if reading its data rejects.
    dispose(yClasses);
  }
  const classSampleWeight = [];
  for (const classIndex of yClassIndices) {
    if (classWeight[classIndex] == null) {
      throw new Error(`classWeight must contain all classes in the training data. ` +
          `The class ${classIndex} exists in the data but not in ` +
          `classWeight`);
    }
    classSampleWeight.push(classWeight[classIndex]);
  }
  return tensor1d(classSampleWeight, 'float32');
}
/**
 * Scales per-sample loss values by their per-sample weights.
 *
 * The two tensors are combined with an element-wise `mul`.
 *
 * @param losses Loss tensor of shape `[batchSize]`.
 * @param sampleWeights Per-sample weight tensor of shape `[batchSize]`.
 * @returns Tensor of the same shape as `losses`.
 */
export function computeWeightedLoss(losses, sampleWeights) {
  const weightedLosses = mul(losses, sampleWeights);
  return weightedLosses;
}
156//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"training_utils.js","sourceRoot":"","sources":["../../../../../../tfjs-layers/src/engine/training_utils.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH,OAAO,EAAC,MAAM,EAAE,KAAK,EAAE,OAAO,EAAE,GAAG,EAAE,OAAO,EAAoB,QAAQ,EAAE,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAuB7G,SAAS,+BAA+B,CACpC,OAAiD,EAAE,WAAqB,EACxE,UAAwC;IAC1C,MAAM,UAAU,GAAG,WAAW,CAAC,MAAM,CAAC;IACtC,IAAI,OAAO,IAAI,IAAI,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,OAAO,CAAC,IAAI,OAAO,CAAC,MAAM,KAAK,CAAC,CAAC,EAAE;QACvE,OAAO,WAAW,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC,IAAI,CAAC,CAAC;KACtC;IACD,IAAI,UAAU,KAAK,CAAC,EAAE;QACpB,IAAI,KAAK,CAAC,OAAO,CAAC,OAAO,CAAC,IAAI,OAAO,CAAC,MAAM,KAAK,CAAC,EAAE;YAClD,OAAO,OAAO,CAAC;SAChB;aAAM,IAAI,OAAO,OAAO,KAAK,QAAQ,IAAI,WAAW,CAAC,CAAC,CAAC,IAAI,OAAO,EAAE;YACnE,OAAO,CAAE,OAA0B,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;SACtD;aAAM;YACL,OAAO,CAAC,OAAsB,CAAC,CAAC;SACjC;KACF;IACD,IAAI,KAAK,CAAC,OAAO,CAAC,OAAO,CAAC,EAAE;QAC1B,IAAI,OAAO,CAAC,MAAM,KAAK,UAAU,EAAE;YACjC,MAAM,IAAI,KAAK,CACX,YAAY,UAAU,mBAAmB,OAAO,CAAC,MAAM,GAAG;gBAC1D,iCAAiC,UAAU,YAAY;gBACvD,+DAA+D,CAAC,CAAC;SACtE;QACD,OAAO,OAAO,CAAC;KAChB;SAAM,IACH,OAAO,OAAO,KAAK,QAAQ,IAAI,MAAM,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC,MAAM,GAAG,CAAC;QAC9D,OAAQ,OAA0B,CAAC,MAAM,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC;YACvD,QAAQ,EAAE;QAChB,MAAM,MAAM,GAAkB,EAAE,CAAC;QACjC,WAAW,CAAC,OAAO,CAAC,UAAU,CAAC,EAAE;YAC/B,IAAI,UAAU,IAAI,OAAO,EAAE;gBACzB,MAAM,CAAC,IAAI,CAAE,OAA0B,CAAC,UAAU,CAAC,CAAC,CAAC;aACtD;iBAAM;gBACL,MAAM,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;aACnB;QACH,CAAC,CAAC,CAAC;QACH,OAAO,MAAM,CAAC;KACf;SAAM;QACL,MAAM,IAAI,KAAK,CACX,2BAA2B,UAAU,aAAa;YAClD,MAAM,UAAU,gCAAgC;YAChD,GAAG,UAAU,+BAA+B,WAAW,SAAS;YAChE,YAAY,UAAU,oBAAoB,IAAI,CAAC,SAAS,CAAC,OAAO,CAAC,EAAE,CAAC,CAAC;KAC1E;AACH,CAAC;AAED;;;;;;;;;;;;GAYG;AACH,MAAM,UAAU,uBAAuB,CACnC,WAAqD,EACrD,WAAqB;IACvB,OAAO,+BAA+B,CAClC,WAAW,EAAE,WAAW,EAAE,aAAa,CAAC,CAAC;AAC/C,CAAC;AAED,MAAM,UAAU,wBAAwB,CACpC,WAAqD,EACrD,WAAqB;IACvB,OAAO,+BAA+B,CAClC,WAAW,EAAE,WAAW,EAAE,cA
Ac,CAAC,CAAC;AAChD,CAAC;AAED;;;;;;;;;;;;;;;;;GAiBG;AACH,MAAM,CAAC,KAAK,UAAU,kBAAkB,CACpC,CAAS,EAAE,YAAqB,EAAE,WAAyB,EAC3D,gBAA6B;IAC/B,IAAI,YAAY,IAAI,IAAI,IAAI,gBAAgB,IAAI,IAAI,EAAE;QACpD,0EAA0E;QAC1E,UAAU;QACV,MAAM,IAAI,KAAK,CAAC,6CAA6C,CAAC,CAAC;KAChE;IAED,IAAI,WAAW,IAAI,IAAI,EAAE;QACvB,kCAAkC;QAClC,MAAM,QAAQ,GAAa,IAAI,CAAC,GAAG,EAAE;YACnC,IAAI,CAAC,CAAC,KAAK,CAAC,MAAM,KAAK,CAAC,EAAE;gBACxB,wBAAwB;gBACxB,OAAO,KAAK,CAAC,CAAC,CAAa,CAAC;aAC7B;iBAAM,IAAI,CAAC,CAAC,KAAK,CAAC,MAAM,KAAK,CAAC,EAAE;gBAC/B,IAAI,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,EAAE;oBAClB,sCAAsC;oBACtC,MAAM,IAAI,GAAG,CAAC,CAAC;oBACf,OAAO,MAAM,CAAC,CAAC,EAAE,IAAI,CAAC,CAAC;iBACxB;qBAAM,IAAI,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,EAAE;oBAC3B,eAAe;oBACf,OAAO,OAAO,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;iBACjC;qBAAM;oBACL,MAAM,IAAI,KAAK,CACX,+CAA+C,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,IAAI;wBAC7D,+DAA+D;wBAC/D,OAAO,CAAC,CAAC;iBACd;aACF;iBAAM;gBACL,MAAM,IAAI,KAAK,CACX,yCAAyC,CAAC,CAAC,IAAI,WAAW;oBAC1D,+DAA+D,CAAC,CAAC;aACtE;QACH,CAAC,CAAC,CAAC;QAEH,MAAM,aAAa,GAAG,KAAK,CAAC,IAAI,CAAC,MAAM,QAAQ,CAAC,IAAI,EAAE,CAAC,CAAC;QACxD,OAAO,CAAC,QAAQ,CAAC,CAAC;QAClB,MAAM,iBAAiB,GAAa,EAAE,CAAC;QACvC,aAAa,CAAC,OAAO,CAAC,UAAU,CAAC,EAAE;YACjC,IAAI,WAAW,CAAC,UAAU,CAAC,IAAI,IAAI,EAAE;gBACnC,MAAM,IAAI,KAAK,CACX,6DAA6D;oBAC7D,aAAa,UAAU,iCAAiC;oBACxD,aAAa,CAAC,CAAC;aACpB;iBAAM;gBACL,iBAAiB,CAAC,IAAI,CAAC,WAAW,CAAC,UAAU,CAAC,CAAC,CAAC;aACjD;QACH,CAAC,CAAC,CAAC;QAEH,OAAO,QAAQ,CAAC,iBAAiB,EAAE,SAAS,CAAC,CAAC;KAC/C;SAAM;QACL,OAAO,IAAI,CAAC;KACb;AACH,CAAC;AAED;;;;;;GAMG;AACH,MAAM,UAAU,mBAAmB,CAAC,MAAc,EAAE,aAAqB;IACvE,OAAO,GAAG,CAAC,MAAM,EAAE,aAAa,CAAC,CAAC;AACpC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\nimport {argMax, clone, dispose, mul, 
reshape, Tensor, Tensor1D, tensor1d, tidy} from '@tensorflow/tfjs-core';\n\n/**\n * For multi-class classification problems, this object is designed to store a\n * mapping from class index to the \"weight\" of the class, where higher weighted\n * classes have larger impact on loss, accuracy, and other metrics.\n *\n * This is useful for cases in which you want the model to \"pay more attention\"\n * to examples from an under-represented class, e.g., in unbalanced datasets.\n */\nexport type ClassWeight = {\n  [classIndex: number]: number\n};\n\n/**\n * Class weighting for a model with multiple outputs.\n *\n * This object maps each output name to a class-weighting object.\n */\nexport type ClassWeightMap = {\n  [outputName: string]: ClassWeight\n};\n\nfunction standardizeSampleOrClassWeights(\n    xWeight: ClassWeight|ClassWeight[]|ClassWeightMap, outputNames: string[],\n    weightType: 'sampleWeight'|'classWeight'): ClassWeight[] {\n  const numOutputs = outputNames.length;\n  if (xWeight == null || (Array.isArray(xWeight) && xWeight.length === 0)) {\n    return outputNames.map(name => null);\n  }\n  if (numOutputs === 1) {\n    if (Array.isArray(xWeight) && xWeight.length === 1) {\n      return xWeight;\n    } else if (typeof xWeight === 'object' && outputNames[0] in xWeight) {\n      return [(xWeight as ClassWeightMap)[outputNames[0]]];\n    } else {\n      return [xWeight as ClassWeight];\n    }\n  }\n  if (Array.isArray(xWeight)) {\n    if (xWeight.length !== numOutputs) {\n      throw new Error(\n          `Provided ${weightType} is an array of ${xWeight.length} ` +\n          `element(s), but the model has ${numOutputs} outputs. 
` +\n          `Make sure a set of weights is provided for each model output.`);\n    }\n    return xWeight;\n  } else if (\n      typeof xWeight === 'object' && Object.keys(xWeight).length > 0 &&\n      typeof (xWeight as ClassWeightMap)[Object.keys(xWeight)[0]] ===\n          'object') {\n    const output: ClassWeight[] = [];\n    outputNames.forEach(outputName => {\n      if (outputName in xWeight) {\n        output.push((xWeight as ClassWeightMap)[outputName]);\n      } else {\n        output.push(null);\n      }\n    });\n    return output;\n  } else {\n    throw new Error(\n        `The model has multiple (${numOutputs}) outputs, ` +\n        `so ${weightType} must be either an array with ` +\n        `${numOutputs} elements or an object with ${outputNames} keys. ` +\n        `Provided ${weightType} not understood: ${JSON.stringify(xWeight)}`);\n  }\n}\n\n/**\n * Standardize class weighting objects.\n *\n * This function takes a single class-weighting object, an array of them,\n * or a map from output name to class-weighting object. It compares it to the\n * output name(s) of the model, base on which it outputs an array of\n * class-weighting objects of which the length matches the number of outputs.\n *\n * @param classWeight Input class-weighting object(s).\n * @param outputNames All output name(s) of the model.\n * @return An array of class-weighting objects. 
The length of the array matches\n *   the model's number of outputs.\n */\nexport function standardizeClassWeights(\n    classWeight: ClassWeight|ClassWeight[]|ClassWeightMap,\n    outputNames: string[]): ClassWeight[] {\n  return standardizeSampleOrClassWeights(\n      classWeight, outputNames, 'classWeight');\n}\n\nexport function standardizeSampleWeights(\n    classWeight: ClassWeight|ClassWeight[]|ClassWeightMap,\n    outputNames: string[]): ClassWeight[] {\n  return standardizeSampleOrClassWeights(\n      classWeight, outputNames, 'sampleWeight');\n}\n\n/**\n * Standardize by-sample and/or by-class weights for training.\n *\n * Note that this function operates on one model output at a time. For a model\n * with multiple outputs, you must call this function multiple times.\n *\n * @param y The target tensor that the by-sample and/or by-class weight is for.\n *     The values of y are assumed to encode the classes, either directly\n *     as an integer index, or as one-hot encoding.\n * @param sampleWeight By-sample weights.\n * @param classWeight By-class weights: an object mapping class indices\n *     (integers) to a weight (float) to apply to the model's loss for the\n *     samples from this class during training. 
This can be useful to tell the\n *     model to \"pay more attention\" to samples from an under-represented class.\n * @param sampleWeightMode The mode for the sample weights.\n * @return A Promise of weight tensor, of which the size of the first dimension\n *     matches that of `y`.\n */\nexport async function standardizeWeights(\n    y: Tensor, sampleWeight?: Tensor, classWeight?: ClassWeight,\n    sampleWeightMode?: 'temporal'): Promise<Tensor> {\n  if (sampleWeight != null || sampleWeightMode != null) {\n    // TODO(cais): Once 'temporal' mode is implemented, document it in the doc\n    // string.\n    throw new Error('Support sampleWeight is not implemented yet');\n  }\n\n  if (classWeight != null) {\n    // Apply class weights per sample.\n    const yClasses: Tensor1D = tidy(() => {\n      if (y.shape.length === 1) {\n        // Assume class indices.\n        return clone(y) as Tensor1D;\n      } else if (y.shape.length === 2) {\n        if (y.shape[1] > 1) {\n          // Assume one-hot encoding of classes.\n          const axis = 1;\n          return argMax(y, axis);\n        } else if (y.shape[1] === 1) {\n          // Class index.\n          return reshape(y, [y.shape[0]]);\n        } else {\n          throw new Error(\n              `Encountered unexpected last-dimension size (${y.shape[1]}) ` +\n              `during handling of class weights. The size is expected to be ` +\n              `>= 1.`);\n        }\n      } else {\n        throw new Error(\n            `Unexpected rank of target (y) tensor (${y.rank}) during ` +\n            `handling of class weights. The rank is expected to be 1 or 2.`);\n      }\n    });\n\n    const yClassIndices = Array.from(await yClasses.data());\n    dispose(yClasses);\n    const classSampleWeight: number[] = [];\n    yClassIndices.forEach(classIndex => {\n      if (classWeight[classIndex] == null) {\n        throw new Error(\n            `classWeight must contain all classes in the training data. 
` +\n            `The class ${classIndex} exists in the data but not in ` +\n            `classWeight`);\n      } else {\n        classSampleWeight.push(classWeight[classIndex]);\n      }\n    });\n\n    return tensor1d(classSampleWeight, 'float32');\n  } else {\n    return null;\n  }\n}\n\n/**\n * Apply per-sample weights on the loss values from a number of samples.\n *\n * @param losses Loss tensor of shape `[batchSize]`.\n * @param sampleWeights Per-sample weight tensor of shape `[batchSize]`.\n * @returns Tensor of the same shape as`losses`.\n */\nexport function computeWeightedLoss(losses: Tensor, sampleWeights: Tensor) {\n  return mul(losses, sampleWeights);\n}\n"]}
\No newline at end of file