1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 |
|
12 |
|
13 |
|
14 |
|
15 |
|
16 |
|
17 | import * as tf from '../index';
|
18 | import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
|
19 | import { expectArraysClose } from '../test_util';
|
describeWithFlags('lstm', ALL_ENVS, () => {
  it('MultiRNNCell with 2 BasicLSTMCells', async () => {
    // Weights for the first cell: a [3, 4] kernel and a 4-element bias.
    const kernelA = tf.tensor2d(
        [
          0.26242125034332275, -0.8787832260131836, 0.781475305557251,
          1.337337851524353, 0.6180247068405151, -0.2760246992111206,
          -0.11299663782119751, -0.46332040429115295, -0.1765323281288147,
          0.6807947158813477, -0.8326982855796814, 0.6732975244522095,
        ],
        [3, 4]);
    const biasA = tf.tensor1d(
        [1.090713620185852, -0.8282332420349121, 0, 1.0889357328414917]);

    // Weights for the second cell: a [2, 4] kernel and a 4-element bias.
    const kernelB = tf.tensor2d(
        [
          -1.893059492111206, -1.0185645818710327, -0.6270437240600586,
          -2.1829540729522705, -0.4583775997161865, -0.5454602241516113,
          -0.3114445209503174, 0.8450229167938232,
        ],
        [2, 4]);
    const biasB = tf.tensor1d(
        [0.9906240105628967, 0.6248329877853394, 0, 1.0224634408950806]);

    const forgetBias = tf.scalar(1.0);

    // Wraps tf.basicLSTMCell so multiRNNCell can drive it with
    // (data, c, h) while the weights stay bound.
    const makeCell = (kernel, bias) => (input, prevC, prevH) =>
        tf.basicLSTMCell(forgetBias, kernel, bias, input, prevC, prevH);
    const cells = [makeCell(kernelA, biasA), makeCell(kernelB, biasB)];

    // One zero state per cell; the state width is the bias length / 4.
    const zeroStates = () => [
      tf.zeros([1, biasA.shape[0] / 4]),
      tf.zeros([1, biasB.shape[0] / 4]),
    ];
    const initialC = zeroStates();
    const initialH = zeroStates();

    // A [1, 2] float32 input with a single 1.0 at position (0, 0).
    const inputBuf = tf.buffer([1, 2], 'float32');
    inputBuf.set(1.0, 0, 0);

    // NOTE(review): the result is presumably [c states, h states], mirroring
    // the argument order of multiRNNCell; each holds one entry per cell.
    const [firstOut, secondOut] =
        tf.multiRNNCell(cells, inputBuf.toTensor(), initialC, initialH);
    expectArraysClose(await firstOut[0].data(), [-0.7440074682235718]);
    expectArraysClose(await firstOut[1].data(), [0.7460772395133972]);
    expectArraysClose(await secondOut[0].data(), [-0.5802832245826721]);
    expectArraysClose(await secondOut[1].data(), [0.5745711922645569]);
  });
});
|
describeWithFlags('multiRNN throws when passed non-tensor', ALL_ENVS, () => {
  /**
   * Builds the shared fixture for the argument-validation tests below.
   *
   * Creates two zero-weight basic LSTM cells and zero-filled `c` / `h` state
   * arrays (one [1, bias.length / 4] tensor per cell). The first cell has a
   * [3, 4] kernel and the second a [2, 4] kernel; both have 4-element biases.
   * Call it inside each `it` so the tensors are created under the backend
   * that describeWithFlags activates for that test.
   *
   * @returns {{cells: Function[], c: Tensor[], h: Tensor[]}} The cells plus
   *     freshly created zero states.
   */
  const createFixture = () => {
    const lstmKernel1 = tf.zeros([3, 4]);
    const lstmBias1 = tf.zeros([4]);
    const lstmKernel2 = tf.zeros([2, 4]);
    const lstmBias2 = tf.zeros([4]);
    const forgetBias = tf.scalar(1.0);
    // Parameter names differ from the `c`/`h` fixture fields to avoid
    // shadowing.
    const lstm1 = (data, prevC, prevH) => tf.basicLSTMCell(
        forgetBias, lstmKernel1, lstmBias1, data, prevC, prevH);
    const lstm2 = (data, prevC, prevH) => tf.basicLSTMCell(
        forgetBias, lstmKernel2, lstmBias2, data, prevC, prevH);
    const zeroStates = () => [
      tf.zeros([1, lstmBias1.shape[0] / 4]),
      tf.zeros([1, lstmBias2.shape[0] / 4]),
    ];
    return {cells: [lstm1, lstm2], c: zeroStates(), h: zeroStates()};
  };

  it('input: data', () => {
    const {cells, c, h} = createFixture();
    expect(() => tf.multiRNNCell(cells, {}, c, h))
        .toThrowError(/Argument 'data' passed to 'multiRNNCell' must be a Tensor/);
  });

  it('input: c', () => {
    const {cells, h} = createFixture();
    const data = tf.zeros([1, 2]);
    expect(() => tf.multiRNNCell(cells, data, [{}], h))
        .toThrowError(/Argument 'c\[0\]' passed to 'multiRNNCell' must be a Tensor/);
  });

  it('input: h', () => {
    const {cells, c} = createFixture();
    const data = tf.zeros([1, 2]);
    expect(() => tf.multiRNNCell(cells, data, c, [{}]))
        .toThrowError(/Argument 'h\[0\]' passed to 'multiRNNCell' must be a Tensor/);
  });
});
|
108 |
|
\ | No newline at end of file |