UNPKG

5.36 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
17import * as tf from '../index';
18import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
19import { expectArraysClose } from '../test_util';
describeWithFlags('lstm', ALL_ENVS, () => {
  // Regression test: runs a two-layer stack of basicLSTMCells through
  // tf.multiRNNCell for one step and compares every returned state tensor
  // against fixed reference values.
  it('MultiRNNCell with 2 BasicLSTMCells', async () => {
    // Layer 1 weights: kernel [3, 4], bias of length 4.
    const kernel1 = tf.tensor2d([
      0.26242125034332275, -0.8787832260131836, 0.781475305557251,
      1.337337851524353, 0.6180247068405151, -0.2760246992111206,
      -0.11299663782119751, -0.46332040429115295, -0.1765323281288147,
      0.6807947158813477, -0.8326982855796814, 0.6732975244522095
    ], [3, 4]);
    const bias1 = tf.tensor1d(
        [1.090713620185852, -0.8282332420349121, 0, 1.0889357328414917]);

    // Layer 2 weights: kernel [2, 4], bias of length 4.
    const kernel2 = tf.tensor2d([
      -1.893059492111206, -1.0185645818710327, -0.6270437240600586,
      -2.1829540729522705, -0.4583775997161865, -0.5454602241516113,
      -0.3114445209503174, 0.8450229167938232
    ], [2, 4]);
    const bias2 = tf.tensor1d(
        [0.9906240105628967, 0.6248329877853394, 0, 1.0224634408950806]);

    const forgetBias = tf.scalar(1.0);

    // Each cell closes over its own weights; multiRNNCell supplies
    // (data, c, h) when it invokes the cell.
    const makeCell = (kernel, bias) => (data, c, h) =>
        tf.basicLSTMCell(forgetBias, kernel, bias, data, c, h);
    const cells = [makeCell(kernel1, bias1), makeCell(kernel2, bias2)];

    // One zero-filled initial state per layer, shaped [1, bias length / 4].
    const zeroStates = () =>
        [bias1, bias2].map((bias) => tf.zeros([1, bias.shape[0] / 4]));
    const c = zeroStates();
    const h = zeroStates();

    // Single-row float32 input [[1, 0]].
    const input = tf.buffer([1, 2], 'float32');
    input.set(1.0, 0, 0);

    const [nextC, nextH] = tf.multiRNNCell(cells, input.toTensor(), c, h);
    expectArraysClose(await nextC[0].data(), [-0.7440074682235718]);
    expectArraysClose(await nextC[1].data(), [0.7460772395133972]);
    expectArraysClose(await nextH[0].data(), [-0.5802832245826721]);
    expectArraysClose(await nextH[1].data(), [0.5745711922645569]);
  });
});
describeWithFlags('multiRNN throws when passed non-tensor', ALL_ENVS, () => {
  /**
   * Builds the fixtures shared by every test in this suite.
   *
   * Each call creates fresh all-zero weights for two basicLSTMCell closures
   * (kernels [3, 4] and [2, 4], biases [4], forget bias 1.0), so no tensors
   * are shared between tests.
   *
   * @returns {{cells: Function[], zeroStates: Function}} `cells` is the
   *     two-layer cell list passed to tf.multiRNNCell; `zeroStates()` returns
   *     a new array of zero-filled state tensors, one per cell, each shaped
   *     [1, bias length / 4].
   */
  const createFixtures = () => {
    const lstmKernel1 = tf.zeros([3, 4]);
    const lstmBias1 = tf.zeros([4]);
    const lstmKernel2 = tf.zeros([2, 4]);
    const lstmBias2 = tf.zeros([4]);
    const forgetBias = tf.scalar(1.0);
    const lstm1 = (data, c, h) =>
        tf.basicLSTMCell(forgetBias, lstmKernel1, lstmBias1, data, c, h);
    const lstm2 = (data, c, h) =>
        tf.basicLSTMCell(forgetBias, lstmKernel2, lstmBias2, data, c, h);
    const zeroStates = () => [
      tf.zeros([1, lstmBias1.shape[0] / 4]),
      tf.zeros([1, lstmBias2.shape[0] / 4])
    ];
    return {cells: [lstm1, lstm2], zeroStates};
  };

  it('input: data', () => {
    const {cells, zeroStates} = createFixtures();
    const c = zeroStates();
    const h = zeroStates();
    // A plain object in the `data` position must be rejected.
    expect(() => tf.multiRNNCell(cells, {}, c, h))
        .toThrowError(/Argument 'data' passed to 'multiRNNCell' must be a Tensor/);
  });

  it('input: c', () => {
    const {cells, zeroStates} = createFixtures();
    const h = zeroStates();
    const data = tf.zeros([1, 2]);
    // The error message must name the offending element, c[0].
    expect(() => tf.multiRNNCell(cells, data, [{}], h))
        .toThrowError(/Argument 'c\[0\]' passed to 'multiRNNCell' must be a Tensor/);
  });

  it('input: h', () => {
    const {cells, zeroStates} = createFixtures();
    const c = zeroStates();
    const data = tf.zeros([1, 2]);
    // The error message must name the offending element, h[0].
    expect(() => tf.multiRNNCell(cells, data, c, [{}]))
        .toThrowError(/Argument 'h\[0\]' passed to 'multiRNNCell' must be a Tensor/);
  });
});
//# sourceMappingURL=multi_rnn_cell_test.js.map
\No newline at end of file