1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 |
|
12 |
|
13 |
|
14 |
|
15 |
|
16 |
|
17 | import * as tf from '../index';
|
18 | import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
|
19 | import { expectArraysEqual } from '../test_util';
|
20 |
|
21 |
|
22 |
|
/**
 * Tests for tf.math.confusionMatrix.
 *
 * Every expected matrix below is written one row per line: row i collects the
 * examples whose true label is i, and column j those predicted as class j.
 * `data()` is compared against these values flattened row by row.
 */
describeWithFlags('confusionMatrix', ALL_ENVS, () => {
  it('3x3 all cases present in both labels and predictions', async () => {
    const yTrue = tf.tensor1d([0, 1, 2, 1, 0], 'int32');
    const yPred = tf.tensor1d([0, 2, 2, 1, 0], 'int32');
    const n = 3;
    const matrix = tf.math.confusionMatrix(yTrue, yPred, n);
    const expected = [
      2, 0, 0,
      0, 1, 1,
      0, 0, 1,
    ];
    expectArraysEqual(await matrix.data(), expected);
    expect(matrix.dtype).toBe('int32');
    expect(matrix.shape).toEqual([n, n]);
  });

  it('float32 arguments are accepted', async () => {
    // Float inputs still produce an int32 count matrix.
    const yTrue = tf.tensor1d([0, 1, 2, 1, 0], 'float32');
    const yPred = tf.tensor1d([0, 2, 2, 1, 0], 'float32');
    const n = 3;
    const matrix = tf.math.confusionMatrix(yTrue, yPred, n);
    const expected = [
      2, 0, 0,
      0, 1, 1,
      0, 0, 1,
    ];
    expectArraysEqual(await matrix.data(), expected);
    expect(matrix.dtype).toBe('int32');
    expect(matrix.shape).toEqual([n, n]);
  });

  it('4x4 all cases present in labels, but not predictions', async () => {
    // Classes 1 and 3 are never predicted, so columns 1 and 3 stay zero.
    const yTrue = tf.tensor1d([3, 3, 2, 2, 1, 1, 0, 0], 'int32');
    const yPred = tf.tensor1d([2, 2, 2, 2, 0, 0, 0, 0], 'int32');
    const n = 4;
    const matrix = tf.math.confusionMatrix(yTrue, yPred, n);
    const expected = [
      2, 0, 0, 0,
      2, 0, 0, 0,
      0, 0, 2, 0,
      0, 0, 2, 0,
    ];
    expectArraysEqual(await matrix.data(), expected);
    expect(matrix.dtype).toBe('int32');
    expect(matrix.shape).toEqual([n, n]);
  });

  it('4x4 all cases present in predictions, but not labels', async () => {
    // Classes 1 and 3 never occur as labels, so rows 1 and 3 stay zero.
    const yTrue = tf.tensor1d([2, 2, 2, 2, 0, 0, 0, 0], 'int32');
    const yPred = tf.tensor1d([3, 3, 2, 2, 1, 1, 0, 0], 'int32');
    const n = 4;
    const matrix = tf.math.confusionMatrix(yTrue, yPred, n);
    const expected = [
      2, 2, 0, 0,
      0, 0, 0, 0,
      0, 0, 2, 2,
      0, 0, 0, 0,
    ];
    expectArraysEqual(await matrix.data(), expected);
    expect(matrix.dtype).toBe('int32');
    expect(matrix.shape).toEqual([n, n]);
  });

  it('Plain arrays as inputs', async () => {
    const yTrue = [3, 3, 2, 2, 1, 1, 0, 0];
    const yPred = [2, 2, 2, 2, 0, 0, 0, 0];
    const n = 4;
    const matrix = tf.math.confusionMatrix(yTrue, yPred, n);
    const expected = [
      2, 0, 0, 0,
      2, 0, 0, 0,
      0, 0, 2, 0,
      0, 0, 2, 0,
    ];
    expectArraysEqual(await matrix.data(), expected);
    expect(matrix.dtype).toBe('int32');
    expect(matrix.shape).toEqual([n, n]);
  });

  it('Int32Arrays as inputs', async () => {
    const yTrue = new Int32Array([3, 3, 2, 2, 1, 1, 0, 0]);
    const yPred = new Int32Array([2, 2, 2, 2, 0, 0, 0, 0]);
    const n = 4;
    const matrix = tf.math.confusionMatrix(yTrue, yPred, n);
    const expected = [
      2, 0, 0, 0,
      2, 0, 0, 0,
      0, 0, 2, 0,
      0, 0, 2, 0,
    ];
    expectArraysEqual(await matrix.data(), expected);
    expect(matrix.dtype).toBe('int32');
    expect(matrix.shape).toEqual([n, n]);
  });

  it('5x5 predictions and labels both missing some cases', async () => {
    // Only the (0 -> 4) and (4 -> 0) cells are populated.
    const yTrue = tf.tensor1d([0, 4], 'int32');
    const yPred = tf.tensor1d([4, 0], 'int32');
    const n = 5;
    const matrix = tf.math.confusionMatrix(yTrue, yPred, n);
    const expected = [
      0, 0, 0, 0, 1,
      0, 0, 0, 0, 0,
      0, 0, 0, 0, 0,
      0, 0, 0, 0, 0,
      1, 0, 0, 0, 0,
    ];
    expectArraysEqual(await matrix.data(), expected);
    expect(matrix.dtype).toBe('int32');
    expect(matrix.shape).toEqual([n, n]);
  });

  it('Invalid numClasses leads to Error', () => {
    // A fractional class count must be rejected.
    const fractionalClassCount = () => tf.math.confusionMatrix(
        tf.tensor1d([0, 1]), tf.tensor1d([1, 0]), 2.5);
    expect(fractionalClassCount)
        .toThrowError(/numClasses .* positive integer.* got 2\.5/);
  });

  it('Incorrect tensor rank leads to Error', () => {
    // Both labels and predictions are required to be rank-1.
    const scalarInputs = () =>
        tf.math.confusionMatrix(tf.scalar(0), tf.scalar(0), 1);
    const matrixLabels = () =>
        tf.math.confusionMatrix(tf.zeros([3, 3]), tf.zeros([9]), 2);
    const matrixPredictions = () =>
        tf.math.confusionMatrix(tf.zeros([9]), tf.zeros([3, 3]), 2);
    expect(scalarInputs).toThrowError(/rank .* 1.*got 0/);
    expect(matrixLabels).toThrowError(/rank .* 1.*got 2/);
    expect(matrixPredictions).toThrowError(/rank .* 1.*got 2/);
  });

  it('Mismatch in lengths leads to Error', () => {
    // Labels and predictions must have the same number of elements.
    const unequalLengths = () =>
        tf.math.confusionMatrix(tf.zeros([3]), tf.zeros([9]), 2);
    expect(unequalLengths).toThrowError(/Mismatch .* 3 vs.* 9/);
  });
});
|
154 |
|
\ | No newline at end of file |