1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 |
|
12 |
|
13 |
|
14 |
|
15 |
|
16 |
|
17 | import * as tf from '../index';
|
18 | import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
|
19 | import { expectArraysClose, expectArraysEqual } from '../test_util';
|
/**
 * Tests for `tf.all`: a logical-AND reduction over bool tensors.
 *
 * Covers full reductions, per-axis reductions (positive, negative and
 * array-valued axes), `keepDims`, tensor-like inputs, and the dtype/type
 * validation errors raised for non-bool and non-tensor arguments.
 */
describeWithFlags('all', ALL_ENVS, () => {
  it('Tensor1D', async () => {
    const allFalse = tf.tensor1d([0, 0, 0], 'bool');
    expectArraysClose(await tf.all(allFalse).data(), 0);

    const mixed = tf.tensor1d([1, 0, 1], 'bool');
    expectArraysClose(await tf.all(mixed).data(), 0);

    const allTrue = tf.tensor1d([1, 1, 1], 'bool');
    expectArraysClose(await tf.all(allTrue).data(), 1);
  });

  it('ignores NaNs', async () => {
    // NOTE(review): the expected 1 means the NaN entry does not make the
    // result false; presumably NaN is stored as `true` when the bool tensor
    // is built — confirm against the tensor conversion code.
    const a = tf.tensor1d([1, NaN, 1], 'bool');
    expectArraysEqual(await tf.all(a).data(), 1);
  });

  it('2D', async () => {
    const a = tf.tensor2d([1, 1, 0, 0], [2, 2], 'bool');
    expectArraysClose(await tf.all(a).data(), 0);
  });

  it('2D axis=[0,1]', async () => {
    const a = tf.tensor2d([1, 1, 0, 0, 1, 0], [2, 3], 'bool');
    expectArraysClose(await tf.all(a, [0, 1]).data(), 0);
  });

  it('2D, axis=0', async () => {
    // [[1, 1], [0, 0]] -> column-wise AND.
    const a = tf.tensor2d([1, 1, 0, 0], [2, 2], 'bool');
    const r = tf.all(a, 0);
    expect(r.shape).toEqual([2]);
    expectArraysClose(await r.data(), [0, 0]);
  });

  it('2D, axis=1', async () => {
    // [[1, 1], [0, 0]] -> row-wise AND.
    const a = tf.tensor2d([1, 1, 0, 0], [2, 2], 'bool');
    const r = tf.all(a, 1);
    expect(r.shape).toEqual([2]);
    expectArraysClose(await r.data(), [1, 0]);
  });

  it('2D, axis=0, keepDims', async () => {
    const a = tf.tensor2d([1, 1, 0, 0, 1, 0], [2, 3], 'bool');
    // With keepDims the reduced axis is retained with size 1.
    const r = a.all(0, true);
    expect(r.shape).toEqual([1, 3]);
    expectArraysClose(await r.data(), [0, 1, 0]);
  });

  it('2D, axis=1 provided as a number', async () => {
    const a = tf.tensor2d([1, 1, 0, 0, 1, 0], [2, 3], 'bool');
    const r = tf.all(a, 1);
    expect(r.shape).toEqual([2]);
    expectArraysClose(await r.data(), [0, 0]);
  });

  it('2D, axis = -1 provided as a number', async () => {
    // A negative axis counts from the end, so -1 is the last axis here.
    const a = tf.tensor2d([1, 1, 0, 0, 1, 0], [2, 3], 'bool');
    const r = tf.all(a, -1);
    expect(r.shape).toEqual([2]);
    expectArraysClose(await r.data(), [0, 0]);
  });

  it('2D, axis=[1]', async () => {
    const a = tf.tensor2d([1, 1, 0, 0, 1, 0], [2, 3], 'bool');
    const r = tf.all(a, [1]);
    expect(r.shape).toEqual([2]);
    expectArraysClose(await r.data(), [0, 0]);
  });

  it('throws when dtype is not boolean', () => {
    // No dtype given, so the tensor is float32.
    const a = tf.tensor2d([1, 1, 0, 0], [2, 2]);
    expect(() => tf.all(a))
        .toThrowError(/Argument 'x' passed to 'all' must be bool tensor, but got float/);
  });

  it('throws when passed a non-tensor', () => {
    expect(() => tf.all({}))
        .toThrowError(/Argument 'x' passed to 'all' must be a Tensor/);
  });

  it('accepts a tensor-like object', async () => {
    const a = [0, 0, 0];
    expectArraysClose(await tf.all(a).data(), 0);
  });

  it('throws error for string tensor', () => {
    expect(() => tf.all(['a']))
        .toThrowError(/Argument 'x' passed to 'all' must be bool tensor, but got string/);
  });
});
|
89 |
|
\ | No newline at end of file |