UNPKG

10.1 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2017 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17import { util } from '..';
18import * as tf from '../index';
19import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
20import { expectValuesInRange } from '../test_util';
21import { MPRandGauss, RandGamma, UniformRandom } from './rand_util';
22import { expectArrayInMeanStdRange, jarqueBeraNormalityTest } from './rand_util';
/**
 * Specs for `tf.rand`, registered through `describeWithFlags` with ALL_ENVS.
 *
 * For every tensor rank from 1D to 4D the suite checks that the tensor
 * returned by `tf.rand` reports the expected dtype:
 *  - 'float32' when the dtype argument is omitted,
 *  - 'float32' when 'float32' is passed explicitly,
 *  - 'int32' and 'bool' when those dtypes are requested,
 * and hands the generated data to `expectValuesInRange` together with the
 * bounds that were given to `util.randUniform`.
 */
describeWithFlags('rand', ALL_ENVS, () => {
  // One entry per tensor rank under test. `defaultDtypeMax` is the upper bound
  // passed to util.randUniform in the call that omits the dtype argument.
  const RANK_CASES = [
    { shape: [10], defaultDtypeMax: 2 },
    { shape: [3, 4], defaultDtypeMax: 2.5 },
    { shape: [3, 4, 5], defaultDtypeMax: 2.5 },
    { shape: [3, 4, 5, 6], defaultDtypeMax: 2.5 },
  ];

  for (const { shape, defaultDtypeMax } of RANK_CASES) {
    const rank = shape.length;

    it(`should return a random ${rank}D float32 array`, async () => {
      // Ensure the dtype defaults to float32 when no dtype is given.
      let result = tf.rand(shape, () => util.randUniform(0, defaultDtypeMax));
      expect(result.dtype).toBe('float32');
      expectValuesInRange(await result.data(), 0, defaultDtypeMax);

      // Ensure an explicitly requested float32 dtype is honored as well, so
      // both the default path and the explicit path are covered at each rank.
      result = tf.rand(shape, () => util.randUniform(0, 1.5), 'float32');
      expect(result.dtype).toBe('float32');
      expectValuesInRange(await result.data(), 0, 1.5);
    });

    it(`should return a random ${rank}D int32 array`, async () => {
      const result = tf.rand(shape, () => util.randUniform(0, 2), 'int32');
      expect(result.dtype).toBe('int32');
      expectValuesInRange(await result.data(), 0, 2);
    });

    it(`should return a random ${rank}D bool array`, async () => {
      const result = tf.rand(shape, () => util.randUniform(0, 1), 'bool');
      expect(result.dtype).toBe('bool');
      expectValuesInRange(await result.data(), 0, 1);
    });
  }
});
/**
 * Reports whether `n` is a primitive number with a fractional part.
 *
 * Non-number inputs (strings, null, undefined, objects) and NaN yield false,
 * as does any integer-valued number. Note that +/-Infinity yields true,
 * because it is a number that is not an integer.
 *
 * @param {*} n - Value to inspect.
 * @returns {boolean} True when `n` is a non-NaN, non-integer number.
 */
function isFloat(n) {
  if (typeof n !== 'number' || Number.isNaN(n)) {
    return false;
  }
  return !Number.isInteger(n);
}
// Specs for the MPRandGauss generator imported from './rand_util'.
// Judging by the spec names and inline markers, the constructor arguments are
// (mean, stdv, dtype, truncated, seed) - confirm against rand_util.
describe('MPRandGauss', () => {
  const SEED = 2002;
  const EPSILON = 0.05;

  // Pulls `count` consecutive values from `generator`, preserving call order.
  const drawValues = (generator, count) =>
    Array.from({ length: count }, () => generator.nextValue());

  it('should default to float32 numbers', () => {
    const generator = new MPRandGauss(0, 1.5);
    const sample = generator.nextValue();
    expect(isFloat(sample)).toBe(true);
  });

  it('should handle a mean/stdv of float32 numbers', () => {
    const generator =
        new MPRandGauss(0, 1.5, 'float32', false /* truncated */, SEED);
    const samples = drawValues(generator, 10000);
    expectArrayInMeanStdRange(samples, 0, 1.5, EPSILON);
    jarqueBeraNormalityTest(samples);
  });

  it('should handle int32 numbers', () => {
    const generator = new MPRandGauss(0, 1, 'int32');
    const sample = generator.nextValue();
    expect(isFloat(sample)).toBe(false);
  });

  it('should handle a mean/stdv of int32 numbers', () => {
    const generator =
        new MPRandGauss(0, 2, 'int32', false /* truncated */, SEED);
    const samples = drawValues(generator, 10000);
    expectArrayInMeanStdRange(samples, 0, 2, EPSILON);
    jarqueBeraNormalityTest(samples);
  });

  it('Should not have a more than 2x std-d from mean for truncated values', () => {
    const stdv = 1.5;
    // The mean is 0, so the absolute value is the distance from the mean.
    const limit = stdv * 2;
    const generator = new MPRandGauss(0, stdv, 'float32', true /* truncated */);
    drawValues(generator, 1000).forEach((sample) => {
      expect(Math.abs(sample)).toBeLessThan(limit);
    });
  });
});
// Specs for the RandGamma generator imported from './rand_util'.
// Judging by the spec names, the constructor arguments are
// (alpha, beta, dtype, seed) - confirm against rand_util.
describe('RandGamma', () => {
  const SEED = 2002;
  const SAMPLE_COUNT = 10000;

  // Pulls `count` consecutive values from `generator`, preserving call order.
  const drawValues = (generator, count) =>
    Array.from({ length: count }, () => generator.nextValue());

  it('should default to float32 numbers', () => {
    const generator = new RandGamma(2, 2, 'float32');
    const sample = generator.nextValue();
    expect(isFloat(sample)).toBe(true);
  });

  it('should handle an alpha/beta of float32 numbers', () => {
    const generator = new RandGamma(2, 2, 'float32', SEED);
    const samples = drawValues(generator, SAMPLE_COUNT);
    expectValuesInRange(samples, 0, 30);
  });

  it('should handle int32 numbers', () => {
    const generator = new RandGamma(2, 2, 'int32');
    const sample = generator.nextValue();
    expect(isFloat(sample)).toBe(false);
  });

  it('should handle an alpha/beta of int32 numbers', () => {
    const generator = new RandGamma(2, 2, 'int32', SEED);
    const samples = drawValues(generator, SAMPLE_COUNT);
    expectValuesInRange(samples, 0, 30);
  });
});
// Specs for the UniformRandom generator imported from './rand_util'.
// Each multi-sample spec builds a fresh generator per sample and reads only
// its first value.
describe('UniformRandom', () => {
  const SAMPLE_COUNT = 10;

  // Builds `count` generators via `makeGenerator`, one after another, and
  // collects the first value produced by each.
  const firstValues = (count, makeGenerator) =>
    Array.from({ length: count }, () => makeGenerator().nextValue());

  it('float32, no seed', () => {
    const min = 0.2;
    const max = 0.24;
    const samples = firstValues(
        SAMPLE_COUNT, () => new UniformRandom(min, max, 'float32'));
    expect(Math.min(...samples)).toBeGreaterThanOrEqual(min);
    expect(Math.max(...samples)).toBeLessThan(max);
  });

  it('int32, no seed', () => {
    const min = 13;
    const max = 37;
    const samples = firstValues(
        SAMPLE_COUNT, () => new UniformRandom(min, max, 'int32'));
    samples.forEach((sample) => {
      expect(Number.isInteger(sample)).toEqual(true);
    });
    expect(Math.min(...samples)).toBeGreaterThanOrEqual(min);
    expect(Math.max(...samples)).toBeLessThanOrEqual(max);
  });

  it('seed is number', () => {
    const min = -1.2;
    const max = -0.4;
    const seed = 1337;
    const samples = firstValues(
        SAMPLE_COUNT, () => new UniformRandom(min, max, 'float32', seed));
    samples.forEach((sample) => {
      expect(sample).toBeGreaterThanOrEqual(min);
      expect(sample).toBeLessThan(max);
    });
    // Assert deterministic results: if the smallest and largest samples are
    // equal, every generator built with the same seed produced the same value.
    expect(Math.min(...samples)).toEqual(Math.max(...samples));
  });

  it('seed === null', () => {
    const generator = new UniformRandom(0, 1, 'float32', null);
    const sample = generator.nextValue();
    expect(sample).toBeGreaterThanOrEqual(0);
    expect(sample).toBeLessThan(1);
  });

  it('seed === undefined', () => {
    const generator = new UniformRandom(0, 1, 'float32', undefined);
    const sample = generator.nextValue();
    expect(sample).toBeGreaterThanOrEqual(0);
    expect(sample).toBeLessThan(1);
  });
});
249//# sourceMappingURL=rand_test.js.map
\No newline at end of file