1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 |
|
12 |
|
13 |
|
14 |
|
15 |
|
16 |
|
17 | import * as tf from '../index';
|
18 | import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
|
19 | import { expectArraysClose, expectArraysEqual } from '../test_util';
|
/**
 * Specs for tf.mean: full reductions, dtype promotion to float32, NaN
 * propagation, per-axis reductions (with and without keepDims), gradients,
 * and argument validation. Registered with the ALL_ENVS flag set.
 */
describeWithFlags('mean', ALL_ENVS, () => {
  // Builds a fresh 2D tensor from the six-value fixture shared by most specs;
  // the caller chooses the layout ([3, 2] or [2, 3]).
  const sampleTensor = (shape) => tf.tensor2d([1, 2, 3, 0, 0, 1], shape);

  it('basic', async () => {
    // The integers 0..31 laid out as a 16x2 matrix; their mean is 15.5.
    const input = tf.tensor2d(Array.from({length: 32}, (_, i) => i), [16, 2]);
    const result = tf.mean(input);
    expect(result.dtype).toBe('float32');
    expectArraysClose(await result.data(), 15.5);
  });

  it('propagates NaNs', async () => {
    const result = tf.mean(tf.tensor2d([1, 2, 3, NaN, 0, 1], [3, 2]));
    expect(result.dtype).toBe('float32');
    // A single NaN input is expected to make the whole reduction NaN.
    expectArraysEqual(await result.data(), NaN);
  });

  it('mean(int32) => float32', async () => {
    const ints = tf.tensor1d([1, 5, 7, 3], 'int32');
    const result = tf.mean(ints);
    expect(result.dtype).toBe('float32');
    expectArraysClose(await result.data(), 4);
  });

  it('mean(bool) => float32', async () => {
    const flags = tf.tensor1d([true, false, false, true, true], 'bool');
    const result = tf.mean(flags);
    expect(result.dtype).toBe('float32');
    // Three `true`s out of five entries average to 3 / 5.
    expectArraysClose(await result.data(), 3 / 5);
  });

  // Reductions over explicit axes of the shared fixture. `args` is spread
  // verbatim after the input tensor, so every case calls tf.mean with exactly
  // the arguments listed (including an explicit `null` axis where present).
  const axisCases = [
    {
      title: '2D array with keep dim',
      shape: [3, 2], args: [null, true], outShape: [1, 1], expected: [7 / 6],
    },
    {
      title: 'axis=0 in 2D array',
      shape: [3, 2], args: [[0]], outShape: [2], expected: [4 / 3, 1],
    },
    {
      title: 'axis=0 in 2D array, keepDims',
      shape: [3, 2], args: [[0], true], outShape: [1, 2], expected: [4 / 3, 1],
    },
    {
      title: 'axis=1 in 2D array',
      shape: [3, 2], args: [[1]], outShape: [3], expected: [1.5, 1.5, 0.5],
    },
    {
      title: 'axis = -1 in 2D array',
      shape: [3, 2], args: [[-1]], outShape: [3], expected: [1.5, 1.5, 0.5],
    },
    {
      title: '2D, axis=1 provided as number',
      shape: [2, 3], args: [1], outShape: [2], expected: [2, 1 / 3],
    },
    {
      title: 'axis=0,1 in 2D array',
      shape: [3, 2], args: [[0, 1]], outShape: [], expected: [7 / 6],
    },
  ];

  for (const {title, shape, args, outShape, expected} of axisCases) {
    it(title, async () => {
      const reduced = tf.mean(sampleTensor(shape), ...args);
      expect(reduced.shape).toEqual(outShape);
      expect(reduced.dtype).toBe('float32');
      expectArraysClose(await reduced.data(), expected);
    });
  }

  // Asserts that dx has x's shape and that every element equals dy / x.size,
  // i.e. the upstream gradient is spread evenly over all inputs.
  const expectUniformGradient = async (x, dy, dx) => {
    const dyValue = await dy.array();
    expect(dx.shape).toEqual(x.shape);
    const perElement = dyValue / x.size;
    expectArraysClose(await dx.data(), new Array(x.size).fill(perElement));
  };

  const gradientCases = [
    ['gradients', (x) => x.mean()],
    ['gradient with clones', (x) => x.clone().mean().clone()],
  ];

  for (const [title, forward] of gradientCases) {
    it(title, async () => {
      const x = sampleTensor([3, 2]);
      const dy = tf.scalar(1.5);
      const dx = tf.grad(forward)(x, dy);
      await expectUniformGradient(x, dy, dx);
    });
  }

  it('gradients throws for defined axis', () => {
    const x = sampleTensor([3, 2]);
    const dy = tf.scalar(1.5);
    // NOTE(review): mean over axis 1 of a [3, 2] tensor has shape [3] while
    // dy is a scalar; the throw may come from that dy/output shape mismatch
    // rather than from the axis itself — confirm what this spec should pin.
    expect(() => tf.grad((t) => t.mean(1))(x, dy)).toThrowError();
  });

  it('throws when passed a non-tensor', () => {
    const callWithObject = () => tf.mean({});
    expect(callWithObject)
        .toThrowError(/Argument 'x' passed to 'mean' must be a Tensor/);
  });

  it('accepts a tensor-like object', async () => {
    // A nested plain array is accepted in place of a tensor.
    const result = tf.mean([[1, 2, 3], [0, 0, 1]]);
    expect(result.dtype).toBe('float32');
    expectArraysClose(await result.data(), 7 / 6);
  });

  it('throws error for string tensor', () => {
    const callWithStrings = () => tf.mean(['a']);
    expect(callWithStrings)
        .toThrowError(/Argument 'x' passed to 'mean' must be numeric tensor/);
  });
});
|
138 |
|
\ | No newline at end of file |