UNPKG

19.6 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2017 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17import * as tf from '../index';
18import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
19import { expectArraysClose } from '../test_util';
20describeWithFlags('multinomial', ALL_ENVS, () => {
21 const NUM_SAMPLES = 1000;
22 // Allowed deviation in each estimated probability (a fraction: 0.05 is 5 percentage points).
23 const EPSILON = 0.05;
24 const SEED = 3.14;
// Two outcomes with equal weights: each should be drawn about half the time.
it('Flip a fair coin and check bounds', async () => {
  const equalWeights = tf.tensor1d([1, 1]);
  const samples = tf.multinomial(equalWeights, NUM_SAMPLES, SEED);

  expect(samples.dtype).toBe('int32');
  expect(samples.shape).toEqual([NUM_SAMPLES]);

  const drawn = await samples.data();
  const frequencies = computeProbs(drawn, 2);
  expectArraysClose(frequencies, [0.5, 0.5], EPSILON);
});
// Outcome 0 is given a far larger logit than outcome 1, so the test expects
// outcome 0 on (within EPSILON of) every draw.
it('Flip a two-sided coin with 100% of heads', async () => {
  const headsOnlyLogits = tf.tensor1d([1, -100]);
  const samples = tf.multinomial(headsOnlyLogits, NUM_SAMPLES, SEED);

  expect(samples.dtype).toBe('int32');
  expect(samples.shape).toEqual([NUM_SAMPLES]);

  const drawn = await samples.data();
  const frequencies = computeProbs(drawn, 2);
  expectArraysClose(frequencies, [1, 0], EPSILON);
});
// Mirror image of the all-heads case: outcome 1 carries the large logit, so
// the test expects outcome 1 on (within EPSILON of) every draw.
it('Flip a two-sided coin with 100% of tails', async () => {
  const tailsOnlyLogits = tf.tensor1d([-100, 1]);
  const samples = tf.multinomial(tailsOnlyLogits, NUM_SAMPLES, SEED);

  expect(samples.dtype).toBe('int32');
  expect(samples.shape).toEqual([NUM_SAMPLES]);

  const drawn = await samples.data();
  const frequencies = computeProbs(drawn, 2);
  expectArraysClose(frequencies, [0, 1], EPSILON);
});
// A distribution with only one outcome is expected to be rejected.
it('Flip a single-sided coin throws error', () => {
  const singleOutcome = tf.tensor1d([1]);
  const drawSamples = () => tf.multinomial(singleOutcome, NUM_SAMPLES, SEED);
  expect(drawSamples).toThrowError();
});
// Ten equally weighted outcomes: every sampled index must be a valid outcome.
it('Flip a ten-sided coin and check bounds', async () => {
  const numOutcomes = 10;
  const logits = tf.fill([numOutcomes], 1).as1D();
  const result = tf.multinomial(logits, NUM_SAMPLES, SEED);
  expect(result.dtype).toBe('int32');
  expect(result.shape).toEqual([NUM_SAMPLES]);

  const samples = await result.data();
  // Check the raw samples directly. The length check below only notices
  // indices that are too large (they make the counts array grow); a negative
  // index would be stored under a non-index key and leave the length as is.
  const allInRange = samples.every((s) => s >= 0 && s < numOutcomes);
  expect(allInRange).toBe(true);

  const outcomeProbs = computeProbs(samples, numOutcomes);
  expect(outcomeProbs.length).toBeLessThanOrEqual(numOutcomes);
});
// Batched sampling: three rows of logits, each putting all of its weight on a
// different outcome.
it('Flip 3 three-sided coins, each coin is 100% biases', async () => {
  const numOutcomes = 3;
  const logits = tf.tensor2d([[-100, -100, 1], [-100, 1, -100], [1, -100, -100]], [3, numOutcomes]);
  const result = tf.multinomial(logits, NUM_SAMPLES, SEED);
  expect(result.dtype).toBe('int32');
  expect(result.shape).toEqual([3, NUM_SAMPLES]);

  // Read the samples back once instead of once per coin.
  const samples = await result.data();

  // Expected distribution per coin, in row order:
  // first coin always gets the last event, second the middle, third the first.
  const expectedPerCoin = [
    [0, 0, 1],
    [0, 1, 0],
    [1, 0, 0],
  ];
  expectedPerCoin.forEach((expected, coin) => {
    // Each coin's draws occupy one consecutive run of NUM_SAMPLES values.
    const start = coin * NUM_SAMPLES;
    const coinSamples = samples.slice(start, start + NUM_SAMPLES);
    const outcomeProbs = computeProbs(coinSamples, numOutcomes);
    expectArraysClose(outcomeProbs, expected, EPSILON);
  });
});
// Rank-3 input is expected to be rejected.
it('passing Tensor3D throws error', () => {
  const rank3Input = tf.zeros([3, 2, 2]);
  const drawSamples = () =>
    tf.multinomial(rank3Input, 3, SEED, /* normalized= */ true);
  expect(drawSamples).toThrowError();
});
// A plain object passed as the first argument must produce the argument
// validation error for 'logits'.
it('throws when passed a non-tensor', () => {
  const expectedMessage =
    /Argument 'logits' passed to 'multinomial' must be a Tensor/;
  const drawSamples = () => tf.multinomial({}, NUM_SAMPLES, SEED);
  expect(drawSamples).toThrowError(expectedMessage);
});
// A plain number array is accepted in place of a tensor; the weights favour
// outcome 1, which the test expects on (within EPSILON of) every draw.
it('accepts a tensor-like object for logits (biased coin)', async () => {
  const samples = tf.multinomial([-100, 1], NUM_SAMPLES, SEED);

  expect(samples.dtype).toBe('int32');
  expect(samples.shape).toEqual([NUM_SAMPLES]);

  const drawn = await samples.data();
  const frequencies = computeProbs(drawn, 2);
  expectArraysClose(frequencies, [0, 1], EPSILON);
});
/**
 * Turns a list of sampled outcome indices into per-outcome relative
 * frequencies.
 *
 * @param events Array-like of sampled outcome indices.
 * @param numOutcomes Number of outcomes that get a zero-initialised slot.
 * @returns Array where entry `i` is the fraction of `events` equal to `i`.
 *     An index at or beyond `numOutcomes` makes the array longer than
 *     `numOutcomes`; callers can use the length to detect such samples.
 */
function computeProbs(events, numOutcomes) {
  const tally = [];
  while (tally.length < numOutcomes) {
    tally.push(0);
  }

  const total = events.length;
  for (let pos = 0; pos < total; pos += 1) {
    tally[events[pos]] += 1;
  }

  // keys() visits every index up to length, so slots skipped by an
  // out-of-range event are normalised too rather than left as holes.
  for (const slot of tally.keys()) {
    tally[slot] /= total;
  }
  return tally;
}
112});
113//# sourceMappingURL=data:application/json;base64,
\No newline at end of file