UNPKG

6.14 kB · JavaScript · View Raw
1/**
2 * @license
3 * Copyright 2017 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17import * as tf from '../index';
18import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
19import { expectArraysClose } from '../test_util';
// Tests for tf.movingAverage across all registered backends (ALL_ENVS).
describeWithFlags('movingAverage', ALL_ENVS, () => {
  // Reference values for `zeroDebias` = `true` were generated with the
  // following TensorFlow 1.x-style script (it uses `tf.Session`):
  //
  // ```python
  // import tensorflow as tf
  // from tensorflow.python.training.moving_averages import (
  //     assign_moving_average)
  //
  // with tf.Session() as sess:
  //   v = tf.get_variable("v1", shape=[2, 2], dtype=tf.float32,
  //                       initializer=tf.zeros_initializer)
  //   x = tf.Variable([[1.0, 2.0], [3.0, 4.0]])
  //   inc_x = x.assign_add([[10.0, 10.0], [10.0, 10.0]])
  //   update = assign_moving_average(v, x, 0.6)
  //
  //   sess.run(tf.global_variables_initializer())
  //
  //   sess.run(update)
  //   print(sess.run(v))
  //
  //   sess.run(inc_x)
  //   sess.run(update)
  //   print(sess.run(v))
  // ```
  //
  // The expected values below agree with
  //   v + (1 - decay) * (x - v) / (1 - decay^step):
  //   step 1: 0 + 0.4 * (1 - 0) / (1 - 0.6)    = 1     (v0 is zero, so v1 == x)
  //   step 2: 1 + 0.4 * (11 - 1) / (1 - 0.36) = 7.25  (same for each element)
  it('zeroDebias=true, decay and step are numbers', async () => {
    const v0 = tf.tensor2d([[0, 0], [0, 0]], [2, 2]);
    const x = tf.tensor2d([[1, 2], [3, 4]], [2, 2]);
    const decay = 0.6;
    const v1 = tf.movingAverage(v0, x, decay, 1);
    expectArraysClose(await v1.array(), [[1, 2], [3, 4]]);
    const y = tf.tensor2d([[11, 12], [13, 14]], [2, 2]);
    const v2 = tf.movingAverage(v1, y, decay, 2);
    expectArraysClose(await v2.array(), [[7.25, 8.25], [9.25, 10.25]]);
  });

  // Same scenario as above, but decay and step are passed as scalar tensors.
  it('zeroDebias=true, decay and step are scalars', async () => {
    const v0 = tf.tensor2d([[0, 0], [0, 0]], [2, 2]);
    const x = tf.tensor2d([[1, 2], [3, 4]], [2, 2]);
    const decay = tf.scalar(0.6);
    const v1 = tf.movingAverage(v0, x, decay, tf.scalar(1));
    expectArraysClose(await v1.array(), [[1, 2], [3, 4]]);
    const y = tf.tensor2d([[11, 12], [13, 14]], [2, 2]);
    const v2 = tf.movingAverage(v1, y, decay, tf.scalar(2));
    expectArraysClose(await v2.array(), [[7.25, 8.25], [9.25, 10.25]]);
  });

  // Reference values for `zeroDebias` = `false` were generated with the
  // following TensorFlow 1.x-style script:
  //
  // ```python
  // import tensorflow as tf
  // from tensorflow.python.training.moving_averages import (
  //     assign_moving_average)
  //
  // with tf.Session() as sess:
  //   v = tf.get_variable("v1", shape=[2, 2], dtype=tf.float32,
  //                       initializer=tf.zeros_initializer)
  //   x = tf.Variable([[1.0, 2.0], [3.0, 4.0]])
  //   inc_x = x.assign_add([[10.0, 10.0], [10.0, 10.0]])
  //   update = assign_moving_average(v, x, 0.6, zero_debias=False)
  //
  //   sess.run(tf.global_variables_initializer())
  //
  //   sess.run(update)
  //   print(sess.run(v))
  //
  //   sess.run(inc_x)
  //   sess.run(update)
  //   print(sess.run(v))
  // ```
  //
  // Without debiasing the expected values agree with v + (1 - decay) * (x - v):
  //   step 1: 0 + 0.4 * (1 - 0)          = 0.4
  //   step 2: 0.4 + 0.4 * (11 - 0.4)     = 4.64  (same for each element)
  // No step is needed in this mode, so `null` is passed for it.
  it('zeroDebias=false, decay and step are numbers', async () => {
    const v0 = tf.tensor2d([[0, 0], [0, 0]], [2, 2]);
    const x = tf.tensor2d([[1, 2], [3, 4]], [2, 2]);
    const decay = 0.6;
    const v1 = tf.movingAverage(v0, x, decay, null, false);
    expectArraysClose(await v1.array(), [[0.4, 0.8], [1.2, 1.6]]);
    const y = tf.tensor2d([[11, 12], [13, 14]], [2, 2]);
    const v2 = tf.movingAverage(v1, y, decay, null, false);
    expectArraysClose(await v2.array(), [[4.64, 5.28], [5.92, 6.56]]);
  });

  it('zeroDebias=false, decay is scalar', async () => {
    const v0 = tf.tensor2d([[0, 0], [0, 0]], [2, 2]);
    const x = tf.tensor2d([[1, 2], [3, 4]], [2, 2]);
    const decay = tf.scalar(0.6);
    const v1 = tf.movingAverage(v0, x, decay, null, false);
    expectArraysClose(await v1.array(), [[0.4, 0.8], [1.2, 1.6]]);
    const y = tf.tensor2d([[11, 12], [13, 14]], [2, 2]);
    const v2 = tf.movingAverage(v1, y, decay, null, false);
    expectArraysClose(await v2.array(), [[4.64, 5.28], [5.92, 6.56]]);
  });

  // zeroDebias defaults to true, which requires a step.
  it('zeroDebias=true, no step throws error', () => {
    const v0 = tf.tensor2d([[0, 0], [0, 0]], [2, 2]);
    const x = tf.tensor2d([[1, 2], [3, 4]], [2, 2]);
    const decay = tf.scalar(0.6);
    expect(() => tf.movingAverage(v0, x, decay, null)).toThrowError();
  });

  it('shape mismatch in v and x throws error', () => {
    const v0 = tf.tensor2d([[0, 0], [0, 0]], [2, 2]);
    const x = tf.tensor2d([[1, 2]], [1, 2]);
    const decay = tf.scalar(0.6);
    // Pass a valid step so the shape mismatch is the only possible cause of
    // the error. With `step = null` the call would also throw for the missing
    // step (see 'zeroDebias=true, no step throws error'), which would let this
    // test pass even without a shape check.
    expect(() => tf.movingAverage(v0, x, decay, 1)).toThrowError();
  });

  it('throws when passed v as a non-tensor', () => {
    const x = tf.tensor2d([[1, 2], [3, 4]], [2, 2]);
    expect(() => tf.movingAverage({}, x, 1))
        .toThrowError(/Argument 'v' passed to 'movingAverage' must be a Tensor/);
  });

  it('throws when passed x as a non-tensor', () => {
    const v = tf.tensor2d([[0, 0], [0, 0]], [2, 2]);
    expect(() => tf.movingAverage(v, {}, 1))
        .toThrowError(/Argument 'x' passed to 'movingAverage' must be a Tensor/);
  });

  // Plain nested arrays are accepted in place of tensors for v and x.
  it('accepts a tensor-like object', async () => {
    const v0 = [[0, 0], [0, 0]]; // 2x2
    const x = [[1, 2], [3, 4]]; // 2x2
    const decay = 0.6;
    const v1 = tf.movingAverage(v0, x, decay, 1);
    expectArraysClose(await v1.array(), [[1, 2], [3, 4]]);
  });
});
139//# sourceMappingURL=moving_average_test.js.map
\No newline at end of file