UNPKG

11.6 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2018 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17import { convertToTensor } from '../tensor_util_env';
18import * as util from '../util';
19import { cast } from './cast';
20import { matMul } from './mat_mul';
21import { oneHot } from './one_hot';
22import { op } from './operation';
23import { transpose } from './transpose';
/**
 * Builds a confusion matrix by comparing ground-truth class labels against
 * predicted class labels.
 *
 * ```js
 * const labels = tf.tensor1d([0, 1, 2, 1, 0], 'int32');
 * const predictions = tf.tensor1d([0, 2, 2, 1, 0], 'int32');
 * const numClasses = 3;
 * const out = tf.math.confusionMatrix(labels, predictions, numClasses);
 * out.print();
 * // Expected output matrix:
 * // [[2, 0, 0],
 * //  [0, 1, 1],
 * //  [0, 0, 1]]
 * ```
 *
 * Preconditions are checked with `util.assert` in this order: `numClasses`
 * (only when it is non-null), rank of `labels`, rank of `predictions`,
 * matching example counts, and finally `numClasses` again as a required
 * positive integer.
 *
 * @param labels The target labels, assumed to be 0-based integer class
 *   indices. Expected shape is `[numExamples]`, with `numExamples` being the
 *   number of examples.
 * @param predictions The predicted classes, assumed to be 0-based integer
 *   class indices. Must have the same shape as `labels`.
 * @param numClasses Total number of classes, as a positive integer. It must
 *   be larger than the largest element found in `labels` and `predictions`.
 * @returns The confusion matrix as an int32 2D tensor. The entry at row `r`
 *   and column `c` counts how many examples whose actual class is `r` were
 *   predicted as class `c`.
 *
 * @doc {heading: 'Operations', subheading: 'Evaluation'}
 */
export function confusionMatrix_(labels, predictions, numClasses) {
    const $labels = convertToTensor(labels, 'labels', 'confusionMatrix');
    const $predictions = convertToTensor(predictions, 'predictions', 'confusionMatrix');

    // Lazily-built failure messages, one per precondition.
    const optionalNumClassesMsg = () => 'If provided, numClasses must be a positive integer, ' +
        `but got ${numClasses}`;
    const labelsRankMsg = () => `Expected the rank of labels to be 1, but got ${$labels.rank}`;
    const predictionsRankMsg = () => 'Expected the rank of predictions to be 1, ' +
        `but got ${$predictions.rank}`;
    const countMismatchMsg = () => 'Mismatch in the number of examples: ' +
        `${$labels.shape[0]} vs. ${$predictions.shape[0]}. ` +
        'Labels and predictions should have the same number of elements.';
    const requiredNumClassesMsg = () => 'numClasses is required to be a positive integer, but got ' +
        `${numClasses}`;

    const isPositiveInteger = numClasses > 0 && Number.isInteger(numClasses);

    // A null/undefined numClasses is let through here on purpose; it is
    // rejected by the final check below, after the shape checks have run.
    util.assert(numClasses == null || isPositiveInteger, optionalNumClassesMsg);
    util.assert($labels.rank === 1, labelsRankMsg);
    util.assert($predictions.rank === 1, predictionsRankMsg);
    util.assert($labels.shape[0] === $predictions.shape[0], countMismatchMsg);
    util.assert(isPositiveInteger, requiredNumClassesMsg);
    // TODO(cais): In the future, if oneHot supports tensors inputs for
    // `numClasses`, `confusionMatrix` can make `numClasses` optional.

    // Casts class indices to int32 and one-hot encodes them over numClasses.
    const toOneHot = (classIndices) => oneHot(cast(classIndices, 'int32'), numClasses);

    const labelsOneHot = toOneHot($labels);
    const predictionsOneHot = toOneHot($predictions);

    // transpose(labelsOneHot) x predictionsOneHot sums, for each
    // (actual, predicted) class pair, the examples that fall into it.
    const counts = matMul(transpose(labelsOneHot), predictionsOneHot);
    return cast(counts, 'int32');
}
// Public entry point: `confusionMatrix_` passed through `op()` as a
// single-key object whose key is the implementation's own name.
export const confusionMatrix = op({ confusionMatrix_: confusionMatrix_ });
75//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiY29uZnVzaW9uX21hdHJpeC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtY29yZS9zcmMvb3BzL2NvbmZ1c2lvbl9tYXRyaXgudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBR0gsT0FBTyxFQUFDLGVBQWUsRUFBQyxNQUFNLG9CQUFvQixDQUFDO0FBRW5ELE9BQU8sS0FBSyxJQUFJLE1BQU0sU0FBUyxDQUFDO0FBRWhDLE9BQU8sRUFBQyxJQUFJLEVBQUMsTUFBTSxRQUFRLENBQUM7QUFDNUIsT0FBTyxFQUFDLE1BQU0sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUNqQyxPQUFPLEVBQUMsTUFBTSxFQUFDLE1BQU0sV0FBVyxDQUFDO0FBQ2pDLE9BQU8sRUFBQyxFQUFFLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFDL0IsT0FBTyxFQUFDLFNBQVMsRUFBQyxNQUFNLGFBQWEsQ0FBQztBQUV0Qzs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7OztHQTRCRztBQUNILE1BQU0sVUFBVSxnQkFBZ0IsQ0FDNUIsTUFBMkIsRUFBRSxXQUFnQyxFQUM3RCxVQUFrQjtJQUNwQixNQUFNLE9BQU8sR0FBRyxlQUFlLENBQUMsTUFBTSxFQUFFLFFBQVEsRUFBRSxpQkFBaUIsQ0FBQyxDQUFDO0lBQ3JFLE1BQU0sWUFBWSxHQUNkLGVBQWUsQ0FBQyxXQUFXLEVBQUUsYUFBYSxFQUFFLGlCQUFpQixDQUFDLENBQUM7SUFFbkUsSUFBSSxDQUFDLE1BQU0sQ0FDUCxVQUFVLElBQUksSUFBSSxJQUFJLFVBQVUsR0FBRyxDQUFDLElBQUksTUFBTSxDQUFDLFNBQVMsQ0FBQyxVQUFVLENBQUMsRUFDcEUsR0FBRyxFQUFFLENBQUMsc0RBQXNEO1FBQ3hELFdBQVcsVUFBVSxFQUFFLENBQUMsQ0FBQztJQUNqQyxJQUFJLENBQUMsTUFBTSxDQUNQLE9BQU8sQ0FBQyxJQUFJLEtBQUssQ0FBQyxFQUNsQixHQUFHLEVBQUUsQ0FBQyxnREFBZ0QsT0FBTyxDQUFDLElBQUksRUFBRSxDQUFDLENBQUM7SUFDMUUsSUFBSSxDQUFDLE1BQU0sQ0FDUCxZQUFZLENBQUMsSUFBSSxLQUFLLENBQUMsRUFDdkIsR0FBRyxFQUFFLENBQUMsNENBQTRDO1FBQzlDLFdBQVcsWUFBWSxDQUFDLElBQUksRUFBRSxDQUFDLENBQUM7SUFDeEMsSUFBSSxDQUFDLE1BQU0sQ0FDUCxPQUFPLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxLQUFLLFlBQVksQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLEVBQzFDLEdBQUcsRUFBRSxDQUFDLHNDQUFzQztRQUN4QyxHQUFHLE9BQU8sQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLFFBQVEsWUFBWSxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsSUFBSTtRQUNwRCxpRUFBaUUsQ0FBQyxDQUFDO0lBQzNFLElBQUksQ0FBQyxNQUFNLENBQ1AsVUFBVSxHQUFHLENBQUMsSUFBSSxNQUFNLENBQUMsU0FBUyxDQUFDLFVBQVUsQ0FBQyxFQUM5QyxHQUFHLEVBQUUsQ0FBQywyREFBMkQ7UUFDN0QsR0FBRyxVQUFVLEVBQUUsQ0FBQyxDQUFDO0lBQ3pCLG1FQUFtRTtJQUNuRSxvRUFBb0U7SUFF
cEUsTUFBTSxZQUFZLEdBQUcsTUFBTSxDQUFDLElBQUksQ0FBQyxPQUFPLEVBQUUsT0FBTyxDQUFDLEVBQUUsVUFBVSxDQUFhLENBQUM7SUFDNUUsTUFBTSxpQkFBaUIsR0FDbkIsTUFBTSxDQUFDLElBQUksQ0FBQyxZQUFZLEVBQUUsT0FBTyxDQUFDLEVBQUUsVUFBVSxDQUFhLENBQUM7SUFDaEUsTUFBTSxhQUFhLEdBQWEsU0FBUyxDQUFDLFlBQVksQ0FBQyxDQUFDO0lBQ3hELE1BQU0sT0FBTyxHQUFhLE1BQU0sQ0FBQyxhQUFhLEVBQUUsaUJBQWlCLENBQUMsQ0FBQztJQUNuRSxPQUFPLElBQUksQ0FBQyxPQUFPLEVBQUUsT0FBTyxDQUFDLENBQUM7QUFDaEMsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLGVBQWUsR0FBRyxFQUFFLENBQUMsRUFBQyxnQkFBZ0IsRUFBQyxDQUFDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAxOCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7VGVuc29yMUQsIFRlbnNvcjJEfSBmcm9tICcuLi90ZW5zb3InO1xuaW1wb3J0IHtjb252ZXJ0VG9UZW5zb3J9IGZyb20gJy4uL3RlbnNvcl91dGlsX2Vudic7XG5pbXBvcnQge1RlbnNvckxpa2V9IGZyb20gJy4uL3R5cGVzJztcbmltcG9ydCAqIGFzIHV0aWwgZnJvbSAnLi4vdXRpbCc7XG5cbmltcG9ydCB7Y2FzdH0gZnJvbSAnLi9jYXN0JztcbmltcG9ydCB7bWF0TXVsfSBmcm9tICcuL21hdF9tdWwnO1xuaW1wb3J0IHtvbmVIb3R9IGZyb20gJy4vb25lX2hvdCc7XG5pbXBvcnQge29wfSBmcm9tICcuL29wZXJhdGlvbic7XG5pbXBvcnQge3RyYW5zcG9zZX0gZnJvbSAnLi90cmFuc3Bvc2UnO1xuXG4vKipcbiAqIENvbXB1dGVz
IHRoZSBjb25mdXNpb24gbWF0cml4IGZyb20gdHJ1ZSBsYWJlbHMgYW5kIHByZWRpY3RlZCBsYWJlbHMuXG4gKlxuICogYGBganNcbiAqIGNvbnN0IGxhYmVscyA9IHRmLnRlbnNvcjFkKFswLCAxLCAyLCAxLCAwXSwgJ2ludDMyJyk7XG4gKiBjb25zdCBwcmVkaWN0aW9ucyA9IHRmLnRlbnNvcjFkKFswLCAyLCAyLCAxLCAwXSwgJ2ludDMyJyk7XG4gKiBjb25zdCBudW1DbGFzc2VzID0gMztcbiAqIGNvbnN0IG91dCA9IHRmLm1hdGguY29uZnVzaW9uTWF0cml4KGxhYmVscywgcHJlZGljdGlvbnMsIG51bUNsYXNzZXMpO1xuICogb3V0LnByaW50KCk7XG4gKiAvLyBFeHBlY3RlZCBvdXRwdXQgbWF0cml4OlxuICogLy8gW1syLCAwLCAwXSxcbiAqIC8vICBbMCwgMSwgMV0sXG4gKiAvLyAgWzAsIDAsIDFdXVxuICogYGBgXG4gKlxuICogQHBhcmFtIGxhYmVscyBUaGUgdGFyZ2V0IGxhYmVscywgYXNzdW1lZCB0byBiZSAwLWJhc2VkIGludGVnZXJzXG4gKiAgIGZvciB0aGUgY2xhc3Nlcy4gVGhlIHNoYXBlIGlzIGBbbnVtRXhhbXBsZXNdYCwgd2hlcmVcbiAqICAgYG51bUV4YW1wbGVzYCBpcyB0aGUgbnVtYmVyIG9mIGV4YW1wbGVzIGluY2x1ZGVkLlxuICogQHBhcmFtIHByZWRpY3Rpb25zIFRoZSBwcmVkaWN0ZWQgY2xhc3NlcywgYXNzdW1lZCB0byBiZVxuICogICAwLWJhc2VkIGludGVnZXJzIGZvciB0aGUgY2xhc3Nlcy4gTXVzdCBoYXZlIHRoZSBzYW1lIHNoYXBlIGFzIGBsYWJlbHNgLlxuICogQHBhcmFtIG51bUNsYXNzZXMgTnVtYmVyIG9mIGFsbCBjbGFzc2VzLCBhcyBhbiBpbnRlZ2VyLlxuICogICBJdHMgdmFsdWUgbXVzdCBiZSBsYXJnZXIgdGhhbiB0aGUgbGFyZ2VzdCBlbGVtZW50IGluIGBsYWJlbHNgIGFuZFxuICogICBgcHJlZGljdGlvbnNgLlxuICogQHJldHVybnMgVGhlIGNvbmZ1c2lvbiBtYXRyaXggYXMgYSBpbnQzMi10eXBlIDJEIHRlbnNvci4gVGhlIHZhbHVlIGF0XG4gKiAgIHJvdyBgcmAgYW5kIGNvbHVtbiBgY2AgaXMgdGhlIG51bWJlciBvZiB0aW1lcyBleGFtcGxlcyBvZiBhY3R1YWwgY2xhc3NcbiAqICAgYHJgIHdlcmUgcHJlZGljdGVkIGFzIGNsYXNzIGBjYC5cbiAqXG4gKiBAZG9jIHtoZWFkaW5nOiAnT3BlcmF0aW9ucycsIHN1YmhlYWRpbmc6ICdFdmFsdWF0aW9uJ31cbiAqL1xuZXhwb3J0IGZ1bmN0aW9uIGNvbmZ1c2lvbk1hdHJpeF8oXG4gICAgbGFiZWxzOiBUZW5zb3IxRHxUZW5zb3JMaWtlLCBwcmVkaWN0aW9uczogVGVuc29yMUR8VGVuc29yTGlrZSxcbiAgICBudW1DbGFzc2VzOiBudW1iZXIpOiBUZW5zb3IyRCB7XG4gIGNvbnN0ICRsYWJlbHMgPSBjb252ZXJ0VG9UZW5zb3IobGFiZWxzLCAnbGFiZWxzJywgJ2NvbmZ1c2lvbk1hdHJpeCcpO1xuICBjb25zdCAkcHJlZGljdGlvbnMgPVxuICAgICAgY29udmVydFRvVGVuc29yKHByZWRpY3Rpb25zLCAncHJlZGljdGlvbnMnLCAnY29uZnVzaW9uTWF0cml4Jyk7XG5cbiAgdXRpbC5hc3NlcnQoXG4gICAgICBudW1DbGFzc2VzID09IG51bGwgfHwg
bnVtQ2xhc3NlcyA+IDAgJiYgTnVtYmVyLmlzSW50ZWdlcihudW1DbGFzc2VzKSxcbiAgICAgICgpID0+IGBJZiBwcm92aWRlZCwgbnVtQ2xhc3NlcyBtdXN0IGJlIGEgcG9zaXRpdmUgaW50ZWdlciwgYCArXG4gICAgICAgICAgYGJ1dCBnb3QgJHtudW1DbGFzc2VzfWApO1xuICB1dGlsLmFzc2VydChcbiAgICAgICRsYWJlbHMucmFuayA9PT0gMSxcbiAgICAgICgpID0+IGBFeHBlY3RlZCB0aGUgcmFuayBvZiBsYWJlbHMgdG8gYmUgMSwgYnV0IGdvdCAkeyRsYWJlbHMucmFua31gKTtcbiAgdXRpbC5hc3NlcnQoXG4gICAgICAkcHJlZGljdGlvbnMucmFuayA9PT0gMSxcbiAgICAgICgpID0+IGBFeHBlY3RlZCB0aGUgcmFuayBvZiBwcmVkaWN0aW9ucyB0byBiZSAxLCBgICtcbiAgICAgICAgICBgYnV0IGdvdCAkeyRwcmVkaWN0aW9ucy5yYW5rfWApO1xuICB1dGlsLmFzc2VydChcbiAgICAgICRsYWJlbHMuc2hhcGVbMF0gPT09ICRwcmVkaWN0aW9ucy5zaGFwZVswXSxcbiAgICAgICgpID0+IGBNaXNtYXRjaCBpbiB0aGUgbnVtYmVyIG9mIGV4YW1wbGVzOiBgICtcbiAgICAgICAgICBgJHskbGFiZWxzLnNoYXBlWzBdfSB2cy4gJHskcHJlZGljdGlvbnMuc2hhcGVbMF19LiBgICtcbiAgICAgICAgICBgTGFiZWxzIGFuZCBwcmVkaWN0aW9ucyBzaG91bGQgaGF2ZSB0aGUgc2FtZSBudW1iZXIgb2YgZWxlbWVudHMuYCk7XG4gIHV0aWwuYXNzZXJ0KFxuICAgICAgbnVtQ2xhc3NlcyA+IDAgJiYgTnVtYmVyLmlzSW50ZWdlcihudW1DbGFzc2VzKSxcbiAgICAgICgpID0+IGBudW1DbGFzc2VzIGlzIHJlcXVpcmVkIHRvIGJlIGEgcG9zaXRpdmUgaW50ZWdlciwgYnV0IGdvdCBgICtcbiAgICAgICAgICBgJHtudW1DbGFzc2VzfWApO1xuICAvLyBUT0RPKGNhaXMpOiBJbiB0aGUgZnV0dXJlLCBpZiBvbmVIb3Qgc3VwcG9ydHMgdGVuc29ycyBpbnB1dHMgZm9yXG4gIC8vICAgYG51bUNsYXNzZXNgLCBgY29uZnVzaW9uTWF0cml4YCBjYW4gbWFrZSBgbnVtQ2xhc3Nlc2Agb3B0aW9uYWwuXG5cbiAgY29uc3Qgb25lSG90TGFiZWxzID0gb25lSG90KGNhc3QoJGxhYmVscywgJ2ludDMyJyksIG51bUNsYXNzZXMpIGFzIFRlbnNvcjJEO1xuICBjb25zdCBvbmVIb3RQcmVkaWN0aW9ucyA9XG4gICAgICBvbmVIb3QoY2FzdCgkcHJlZGljdGlvbnMsICdpbnQzMicpLCBudW1DbGFzc2VzKSBhcyBUZW5zb3IyRDtcbiAgY29uc3Qgb25lSG90TGFiZWxzVDogVGVuc29yMkQgPSB0cmFuc3Bvc2Uob25lSG90TGFiZWxzKTtcbiAgY29uc3QgcHJvZHVjdDogVGVuc29yMkQgPSBtYXRNdWwob25lSG90TGFiZWxzVCwgb25lSG90UHJlZGljdGlvbnMpO1xuICByZXR1cm4gY2FzdChwcm9kdWN0LCAnaW50MzInKTtcbn1cblxuZXhwb3J0IGNvbnN0IGNvbmZ1c2lvbk1hdHJpeCA9IG9wKHtjb25mdXNpb25NYXRyaXhffSk7XG4iXX0=
\No newline at end of file