UNPKG

3.73 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
17import * as tf from '../index';
18import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
19import { expectArraysClose, expectArraysEqual } from '../test_util';
/**
 * Tests for `tf.all`, the logical-AND reduction over bool tensors.
 *
 * Every result checked here is a bool tensor whose values are exactly 0 or 1.
 * Results are therefore compared with the exact `expectArraysEqual` (as the
 * NaN case always did) instead of the tolerance-based `expectArraysClose`.
 */
describeWithFlags('all', ALL_ENVS, () => {
  it('Tensor1D', async () => {
    const allFalse = tf.tensor1d([0, 0, 0], 'bool');
    expectArraysEqual(await tf.all(allFalse).data(), 0);

    const mixed = tf.tensor1d([1, 0, 1], 'bool');
    expectArraysEqual(await tf.all(mixed).data(), 0);

    const allTrue = tf.tensor1d([1, 1, 1], 'bool');
    expectArraysEqual(await tf.all(allTrue).data(), 1);
  });

  it('ignores NaNs', async () => {
    // NaN is passed as an element of a bool tensor; the expected result
    // shows it is treated as truthy.
    const a = tf.tensor1d([1, NaN, 1], 'bool');
    expectArraysEqual(await tf.all(a).data(), 1);
  });

  it('2D', async () => {
    const a = tf.tensor2d([1, 1, 0, 0], [2, 2], 'bool');
    expectArraysEqual(await tf.all(a).data(), 0);
  });

  it('2D axis=[0,1]', async () => {
    const a = tf.tensor2d([1, 1, 0, 0, 1, 0], [2, 3], 'bool');
    expectArraysEqual(await tf.all(a, [0, 1]).data(), 0);
  });

  it('2D, axis=0', async () => {
    // Despite the name, this case covers both axis=0 and axis=1 on the same
    // [[1, 1], [0, 0]] input.
    const a = tf.tensor2d([1, 1, 0, 0], [2, 2], 'bool');

    const alongAxis0 = tf.all(a, 0);
    expect(alongAxis0.shape).toEqual([2]);
    expectArraysEqual(await alongAxis0.data(), [0, 0]);

    const alongAxis1 = tf.all(a, 1);
    expect(alongAxis1.shape).toEqual([2]);
    expectArraysEqual(await alongAxis1.data(), [1, 0]);
  });

  it('2D, axis=0, keepDims', async () => {
    const a = tf.tensor2d([1, 1, 0, 0, 1, 0], [2, 3], 'bool');
    const r = a.all(0, true /* keepDims */);
    // keepDims retains the reduced axis with size 1.
    expect(r.shape).toEqual([1, 3]);
    expectArraysEqual(await r.data(), [0, 1, 0]);
  });

  it('2D, axis=1 provided as a number', async () => {
    const a = tf.tensor2d([1, 1, 0, 0, 1, 0], [2, 3], 'bool');
    const r = tf.all(a, 1);
    expect(r.shape).toEqual([2]);
    expectArraysEqual(await r.data(), [0, 0]);
  });

  it('2D, axis = -1 provided as a number', async () => {
    // A negative axis counts from the end; -1 is the last axis of a 2D tensor.
    const a = tf.tensor2d([1, 1, 0, 0, 1, 0], [2, 3], 'bool');
    const r = tf.all(a, -1);
    expect(r.shape).toEqual([2]);
    expectArraysEqual(await r.data(), [0, 0]);
  });

  it('2D, axis=[1]', async () => {
    const a = tf.tensor2d([1, 1, 0, 0, 1, 0], [2, 3], 'bool');
    const r = tf.all(a, [1]);
    expect(r.shape).toEqual([2]);
    expectArraysEqual(await r.data(), [0, 0]);
  });

  it('throws when dtype is not boolean', () => {
    // No dtype argument, so the tensor is created as float32.
    const a = tf.tensor2d([1, 1, 0, 0], [2, 2]);
    expect(() => tf.all(a))
      .toThrowError(/Argument 'x' passed to 'all' must be bool tensor, but got float/);
  });

  it('throws when passed a non-tensor', () => {
    expect(() => tf.all({}))
      .toThrowError(/Argument 'x' passed to 'all' must be a Tensor/);
  });

  it('accepts a tensor-like object', async () => {
    // A plain JS array is accepted in place of a tensor.
    const a = [0, 0, 0];
    expectArraysEqual(await tf.all(a).data(), 0);
  });

  it('throws error for string tensor', () => {
    expect(() => tf.all(['a']))
      .toThrowError(/Argument 'x' passed to 'all' must be bool tensor, but got string/);
  });
});
//# sourceMappingURL=all_test.js.map
\No newline at end of file