// Retrieved from UNPKG (raw JavaScript view, 5.79 kB).
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import * as tf from '../index';
import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
import { expectArraysClose, expectArraysEqual } from '../test_util';
/**
 * Tests for tf.mean, registered for every environment in ALL_ENVS.
 *
 * Covers full and per-axis reductions (positive, negative and numeric axis
 * arguments), keepDims, promotion of int32/bool inputs to a float32 result,
 * NaN propagation, gradients, and argument validation.
 */
describeWithFlags('mean', ALL_ENVS, () => {
  it('basic', async () => {
    // The values 0..31 laid out as a 16x2 matrix; their mean is 31 / 2.
    const a = tf.tensor2d([
      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
      16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31
    ], [16, 2]);
    const r = tf.mean(a);
    expect(r.dtype).toBe('float32');
    expectArraysClose(await r.data(), 15.5);
  });

  it('propagates NaNs', async () => {
    const a = tf.tensor2d([1, 2, 3, NaN, 0, 1], [3, 2]);
    const r = tf.mean(a);
    expect(r.dtype).toBe('float32');
    // Exact equality is used because NaN is never "close" to anything.
    expectArraysEqual(await r.data(), NaN);
  });

  it('mean(int32) => float32', async () => {
    const a = tf.tensor1d([1, 5, 7, 3], 'int32');
    const r = tf.mean(a);
    expect(r.dtype).toBe('float32');
    expectArraysClose(await r.data(), 4);
  });

  it('mean(bool) => float32', async () => {
    // true counts as 1 and false as 0, so the mean is the fraction of trues.
    const a = tf.tensor1d([true, false, false, true, true], 'bool');
    const r = tf.mean(a);
    expect(r.dtype).toBe('float32');
    expectArraysClose(await r.data(), 3 / 5);
  });

  it('2D array with keep dim', async () => {
    const a = tf.tensor2d([1, 2, 3, 0, 0, 1], [3, 2]);
    // axis = null reduces over every dimension; keepDims leaves them as 1s.
    const res = tf.mean(a, null, true /* keepDims */);
    expect(res.shape).toEqual([1, 1]);
    expect(res.dtype).toBe('float32');
    expectArraysClose(await res.data(), [7 / 6]);
  });

  it('axis=0 in 2D array', async () => {
    // Columns are [1, 3, 0] and [2, 0, 1].
    const a = tf.tensor2d([1, 2, 3, 0, 0, 1], [3, 2]);
    const res = tf.mean(a, [0]);
    expect(res.shape).toEqual([2]);
    expect(res.dtype).toBe('float32');
    expectArraysClose(await res.data(), [4 / 3, 1]);
  });

  it('axis=0 in 2D array, keepDims', async () => {
    const a = tf.tensor2d([1, 2, 3, 0, 0, 1], [3, 2]);
    const res = tf.mean(a, [0], true /* keepDims */);
    expect(res.shape).toEqual([1, 2]);
    expect(res.dtype).toBe('float32');
    expectArraysClose(await res.data(), [4 / 3, 1]);
  });

  it('axis=1 in 2D array', async () => {
    // Rows are [1, 2], [3, 0] and [0, 1].
    const a = tf.tensor2d([1, 2, 3, 0, 0, 1], [3, 2]);
    const res = tf.mean(a, [1]);
    expect(res.dtype).toBe('float32');
    expect(res.shape).toEqual([3]);
    expectArraysClose(await res.data(), [1.5, 1.5, 0.5]);
  });

  it('axis = -1 in 2D array', async () => {
    // For a rank-2 tensor, -1 must select the same axis as 1 (see test above).
    const a = tf.tensor2d([1, 2, 3, 0, 0, 1], [3, 2]);
    const res = tf.mean(a, [-1]);
    expect(res.dtype).toBe('float32');
    expect(res.shape).toEqual([3]);
    expectArraysClose(await res.data(), [1.5, 1.5, 0.5]);
  });

  it('2D, axis=1 provided as number', async () => {
    // Note the [2, 3] shape here: rows are [1, 2, 3] and [0, 0, 1].
    const a = tf.tensor2d([1, 2, 3, 0, 0, 1], [2, 3]);
    const res = tf.mean(a, 1);
    expect(res.shape).toEqual([2]);
    expect(res.dtype).toBe('float32');
    expectArraysClose(await res.data(), [2, 1 / 3]);
  });

  it('axis=0,1 in 2D array', async () => {
    const a = tf.tensor2d([1, 2, 3, 0, 0, 1], [3, 2]);
    const res = tf.mean(a, [0, 1]);
    expect(res.shape).toEqual([]);
    expect(res.dtype).toBe('float32');
    expectArraysClose(await res.data(), [7 / 6]);
  });

  it('gradients', async () => {
    const a = tf.tensor2d([1, 2, 3, 0, 0, 1], [3, 2]);
    const dy = tf.scalar(1.5);
    const da = tf.grad(x => x.mean())(a, dy);
    const dyVal = await dy.array();
    expect(da.shape).toEqual(a.shape);
    // d(mean)/dx_i = 1 / size, so every element receives dy / size.
    expectArraysClose(
        await da.data(), new Array(a.size).fill(dyVal / a.size));
  });

  it('gradient with clones', async () => {
    // Same as above, but with clone() ops on both sides of the mean.
    const a = tf.tensor2d([1, 2, 3, 0, 0, 1], [3, 2]);
    const dy = tf.scalar(1.5);
    const da = tf.grad(x => x.clone().mean().clone())(a, dy);
    const dyVal = await dy.array();
    expect(da.shape).toEqual(a.shape);
    expectArraysClose(
        await da.data(), new Array(a.size).fill(dyVal / a.size));
  });

  it('gradients throws for defined axis', () => {
    const a = tf.tensor2d([1, 2, 3, 0, 0, 1], [3, 2]);
    const dy = tf.scalar(1.5);
    // x.mean(1) yields shape [3] while dy is a scalar; the test only asserts
    // that some error is raised, not which one.
    expect(() => tf.grad(x => x.mean(1))(a, dy)).toThrowError();
  });

  it('throws when passed a non-tensor', () => {
    expect(() => tf.mean({}))
        .toThrowError(/Argument 'x' passed to 'mean' must be a Tensor/);
  });

  it('accepts a tensor-like object', async () => {
    // A nested JS array is accepted in place of a Tensor.
    const r = tf.mean([[1, 2, 3], [0, 0, 1]]);
    expect(r.dtype).toBe('float32');
    expectArraysClose(await r.data(), 7 / 6);
  });

  it('throws error for string tensor', () => {
    expect(() => tf.mean(['a']))
        .toThrowError(/Argument 'x' passed to 'mean' must be numeric tensor/);
  });
});
//# sourceMappingURL=mean_test.js.map