1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 |
|
12 |
|
13 |
|
14 |
|
15 |
|
16 |
|
17 | import * as tf from '../index';
|
18 | import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
|
19 | import { expectArraysClose } from '../test_util';
|
/**
 * Tests for tf.movingAverage(v, x, decay, step, zeroDebias).
 *
 * Every expected value below is consistent with the update rule
 *   v' = v + (x - v) * (1 - decay)
 * where, when zeroDebias is true, the update term is additionally divided
 * by (1 - decay ** step). With decay = 0.6 and v0 = 0:
 *   - debiased, step 1:    v1 = x * 0.4 / 0.4 = x
 *   - debiased, step 2:    v2 = v1 + (y - v1) * 0.4 / 0.64
 *   - not debiased:        v1 = 0.4 * x, then v2 = v1 + (y - v1) * 0.4
 */
describeWithFlags('movingAverage', ALL_ENVS, () => {
  it('zeroDebias=true, decay and step are numbers', async () => {
    const v0 = tf.tensor2d([[0, 0], [0, 0]], [2, 2]);
    const x = tf.tensor2d([[1, 2], [3, 4]], [2, 2]);
    const decay = 0.6;
    const v1 = tf.movingAverage(v0, x, decay, 1);
    expectArraysClose(await v1.array(), [[1, 2], [3, 4]]);
    const y = tf.tensor2d([[11, 12], [13, 14]], [2, 2]);
    const v2 = tf.movingAverage(v1, y, decay, 2);
    expectArraysClose(await v2.array(), [[7.25, 8.25], [9.25, 10.25]]);
  });

  it('zeroDebias=true, decay and step are scalars', async () => {
    const v0 = tf.tensor2d([[0, 0], [0, 0]], [2, 2]);
    const x = tf.tensor2d([[1, 2], [3, 4]], [2, 2]);
    const decay = tf.scalar(0.6);
    const v1 = tf.movingAverage(v0, x, decay, tf.scalar(1));
    expectArraysClose(await v1.array(), [[1, 2], [3, 4]]);
    const y = tf.tensor2d([[11, 12], [13, 14]], [2, 2]);
    const v2 = tf.movingAverage(v1, y, decay, tf.scalar(2));
    expectArraysClose(await v2.array(), [[7.25, 8.25], [9.25, 10.25]]);
  });

  it('zeroDebias=false, decay is number', async () => {
    const v0 = tf.tensor2d([[0, 0], [0, 0]], [2, 2]);
    const x = tf.tensor2d([[1, 2], [3, 4]], [2, 2]);
    const decay = 0.6;
    const v1 = tf.movingAverage(v0, x, decay, null, false);
    expectArraysClose(await v1.array(), [[0.4, 0.8], [1.2, 1.6]]);
    const y = tf.tensor2d([[11, 12], [13, 14]], [2, 2]);
    const v2 = tf.movingAverage(v1, y, decay, null, false);
    expectArraysClose(await v2.array(), [[4.64, 5.28], [5.92, 6.56]]);
  });

  it('zeroDebias=false, decay is scalar', async () => {
    const v0 = tf.tensor2d([[0, 0], [0, 0]], [2, 2]);
    const x = tf.tensor2d([[1, 2], [3, 4]], [2, 2]);
    const decay = tf.scalar(0.6);
    const v1 = tf.movingAverage(v0, x, decay, null, false);
    expectArraysClose(await v1.array(), [[0.4, 0.8], [1.2, 1.6]]);
    const y = tf.tensor2d([[11, 12], [13, 14]], [2, 2]);
    const v2 = tf.movingAverage(v1, y, decay, null, false);
    expectArraysClose(await v2.array(), [[4.64, 5.28], [5.92, 6.56]]);
  });

  it('zeroDebias=true, no step throws error', () => {
    const v0 = tf.tensor2d([[0, 0], [0, 0]], [2, 2]);
    const x = tf.tensor2d([[1, 2], [3, 4]], [2, 2]);
    const decay = tf.scalar(0.6);
    // Shapes match, so the only precondition violated is the missing step.
    expect(() => tf.movingAverage(v0, x, decay, null))
        .toThrowError(/step is required/);
  });

  it('shape mismatch in v and x throws error', () => {
    const v0 = tf.tensor2d([[0, 0], [0, 0]], [2, 2]);
    const x = tf.tensor2d([[1, 2]], [1, 2]);
    const decay = tf.scalar(0.6);
    // Pass a valid step so the shape mismatch is the only possible failure;
    // with step = null the missing-step check could also be what throws.
    expect(() => tf.movingAverage(v0, x, decay, 1))
        .toThrowError(/Shape mismatch in v and x/);
  });

  it('throws when passed v as a non-tensor', () => {
    const x = tf.tensor2d([[1, 2], [3, 4]], [2, 2]);
    expect(() => tf.movingAverage({}, x, 1))
        .toThrowError(/Argument 'v' passed to 'movingAverage' must be a Tensor/);
  });

  it('throws when passed x as a non-tensor', () => {
    const v = tf.tensor2d([[0, 0], [0, 0]], [2, 2]);
    expect(() => tf.movingAverage(v, {}, 1))
        .toThrowError(/Argument 'x' passed to 'movingAverage' must be a Tensor/);
  });

  it('accepts a tensor-like object', async () => {
    const v0 = [[0, 0], [0, 0]];
    const x = [[1, 2], [3, 4]];
    const decay = 0.6;
    const v1 = tf.movingAverage(v0, x, decay, 1);
    expectArraysClose(await v1.array(), [[1, 2], [3, 4]]);
  });
});
|
139 |
|
\ | No newline at end of file |