UNPKG

6.51 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2018 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17import * as tf from '../index';
18import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
19import { expectArraysEqual } from '../test_util';
/**
 * Unit tests for confusionMatrix().
 *
 * Expected outputs are written row-major, one matrix row per source line:
 * entry (i, j) counts the examples whose label is i and whose prediction is j.
 *
 * The Python reference snippets below are TensorFlow 1.x code. In TF 2.x eager
 * execution is on by default and the op is `tf.math.confusion_matrix`.
 */
describeWithFlags('confusionMatrix', ALL_ENVS, () => {
    /**
     * Asserts that `out` holds the row-major values `expected`, has dtype
     * int32, and has shape [numClasses, numClasses].
     */
    async function expectConfusionMatrix(out, numClasses, expected) {
        expectArraysEqual(await out.data(), expected);
        expect(out.dtype).toBe('int32');
        expect(out.shape).toEqual([numClasses, numClasses]);
    }
    // Reference (Python) TensorFlow code:
    //
    // ```py
    // import tensorflow as tf
    //
    // tf.enable_eager_execution()
    //
    // labels = tf.constant([0, 1, 2, 1, 0])
    // predictions = tf.constant([0, 2, 2, 1, 0])
    // out = tf.confusion_matrix(labels, predictions, 3)
    //
    // print(out)
    // ```
    it('3x3 all cases present in both labels and predictions', async () => {
        const labels = tf.tensor1d([0, 1, 2, 1, 0], 'int32');
        const predictions = tf.tensor1d([0, 2, 2, 1, 0], 'int32');
        const numClasses = 3;
        const out = tf.math.confusionMatrix(labels, predictions, numClasses);
        await expectConfusionMatrix(out, numClasses, [
            2, 0, 0,
            0, 1, 1,
            0, 0, 1,
        ]);
    });
    it('float32 arguments are accepted', async () => {
        // Same data as the 3x3 case above; the result is still int32.
        const labels = tf.tensor1d([0, 1, 2, 1, 0], 'float32');
        const predictions = tf.tensor1d([0, 2, 2, 1, 0], 'float32');
        const numClasses = 3;
        const out = tf.math.confusionMatrix(labels, predictions, numClasses);
        await expectConfusionMatrix(out, numClasses, [
            2, 0, 0,
            0, 1, 1,
            0, 0, 1,
        ]);
    });
    // Reference (Python) TensorFlow code:
    //
    // ```py
    // import tensorflow as tf
    //
    // tf.enable_eager_execution()
    //
    // labels = tf.constant([3, 3, 2, 2, 1, 1, 0, 0])
    // predictions = tf.constant([2, 2, 2, 2, 0, 0, 0, 0])
    // out = tf.confusion_matrix(labels, predictions, 4)
    //
    // print(out)
    // ```
    it('4x4 all cases present in labels, but not predictions', async () => {
        const labels = tf.tensor1d([3, 3, 2, 2, 1, 1, 0, 0], 'int32');
        const predictions = tf.tensor1d([2, 2, 2, 2, 0, 0, 0, 0], 'int32');
        const numClasses = 4;
        const out = tf.math.confusionMatrix(labels, predictions, numClasses);
        await expectConfusionMatrix(out, numClasses, [
            2, 0, 0, 0,
            2, 0, 0, 0,
            0, 0, 2, 0,
            0, 0, 2, 0,
        ]);
    });
    it('4x4 all cases present in predictions, but not labels', async () => {
        const labels = tf.tensor1d([2, 2, 2, 2, 0, 0, 0, 0], 'int32');
        const predictions = tf.tensor1d([3, 3, 2, 2, 1, 1, 0, 0], 'int32');
        const numClasses = 4;
        const out = tf.math.confusionMatrix(labels, predictions, numClasses);
        await expectConfusionMatrix(out, numClasses, [
            2, 2, 0, 0,
            0, 0, 0, 0,
            0, 0, 2, 2,
            0, 0, 0, 0,
        ]);
    });
    it('Plain arrays as inputs', async () => {
        const labels = [3, 3, 2, 2, 1, 1, 0, 0];
        const predictions = [2, 2, 2, 2, 0, 0, 0, 0];
        const numClasses = 4;
        const out = tf.math.confusionMatrix(labels, predictions, numClasses);
        await expectConfusionMatrix(out, numClasses, [
            2, 0, 0, 0,
            2, 0, 0, 0,
            0, 0, 2, 0,
            0, 0, 2, 0,
        ]);
    });
    it('Int32Arrays as inputs', async () => {
        const labels = new Int32Array([3, 3, 2, 2, 1, 1, 0, 0]);
        const predictions = new Int32Array([2, 2, 2, 2, 0, 0, 0, 0]);
        const numClasses = 4;
        const out = tf.math.confusionMatrix(labels, predictions, numClasses);
        await expectConfusionMatrix(out, numClasses, [
            2, 0, 0, 0,
            2, 0, 0, 0,
            0, 0, 2, 0,
            0, 0, 2, 0,
        ]);
    });
    // Reference (Python) TensorFlow code:
    //
    // ```py
    // import tensorflow as tf
    //
    // tf.enable_eager_execution()
    //
    // labels = tf.constant([0, 4])
    // predictions = tf.constant([4, 0])
    // out = tf.confusion_matrix(labels, predictions, 5)
    //
    // print(out)
    // ```
    it('5x5 predictions and labels both missing some cases', async () => {
        const labels = tf.tensor1d([0, 4], 'int32');
        const predictions = tf.tensor1d([4, 0], 'int32');
        const numClasses = 5;
        const out = tf.math.confusionMatrix(labels, predictions, numClasses);
        await expectConfusionMatrix(out, numClasses, [
            0, 0, 0, 0, 1,
            0, 0, 0, 0, 0,
            0, 0, 0, 0, 0,
            0, 0, 0, 0, 0,
            1, 0, 0, 0, 0,
        ]);
    });
    it('Invalid numClasses leads to Error', () => {
        // Non-integer, zero and negative class counts must all be rejected.
        expect(() => tf.math.confusionMatrix(tf.tensor1d([0, 1]), tf.tensor1d([1, 0]), 2.5))
            .toThrowError(/numClasses .* positive integer.* got 2\.5/);
        expect(() => tf.math.confusionMatrix(tf.tensor1d([0, 1]), tf.tensor1d([1, 0]), 0))
            .toThrowError(/numClasses .* positive integer.* got 0/);
        expect(() => tf.math.confusionMatrix(tf.tensor1d([0, 1]), tf.tensor1d([1, 0]), -1))
            .toThrowError(/numClasses .* positive integer.* got -1/);
    });
    it('Incorrect tensor rank leads to Error', () => {
        // Both labels and predictions must be rank-1.
        expect(() => tf.math.confusionMatrix(tf.scalar(0), tf.scalar(0), 1))
            .toThrowError(/rank .* 1.*got 0/);
        expect(() => tf.math.confusionMatrix(tf.zeros([3, 3]), tf.zeros([9]), 2))
            .toThrowError(/rank .* 1.*got 2/);
        expect(() => tf.math.confusionMatrix(tf.zeros([9]), tf.zeros([3, 3]), 2))
            .toThrowError(/rank .* 1.*got 2/);
    });
    it('Mismatch in lengths leads to Error', () => {
        expect(() => tf.math.confusionMatrix(tf.zeros([3]), tf.zeros([9]), 2))
            .toThrowError(/Mismatch .* 3 vs.* 9/);
    });
});
154//# sourceMappingURL=confusion_matrix_test.js.map
\No newline at end of file