1 | import { convertToTensor, convertToTensorArray } from '../tensor_util_env';
|
2 | import { op } from './operation';
|
/**
 * Computes the next states and outputs of a stack of LSTMCells.
 *
 * The cells are run in order, and each cell's output is used as the input
 * to the next cell in the stack.
 *
 * Returns `[cellStates, cellOutputs]`: two arrays with one entry per cell.
 * Entry `i` of each array holds the new state / output of `lstmCells[i]`.
 *
 * Derived from tf.contrib.rnn.MultiRNNCell.
 *
 * @param lstmCells Array of LSTMCell functions. Each is called as
 *     `cell(input, c, h)` and is expected to return a `[newC, newH]` pair.
 * @param data The input to the first cell.
 * @param c Array of previous cell states, one per cell.
 * @param h Array of previous cell outputs, one per cell.
 * @returns `[newC, newH]`, the per-cell new states and new outputs.
 * @throws {Error} If `c` or `h` has fewer entries than `lstmCells`.
 *
 * @doc {heading: 'Operations', subheading: 'RNN'}
 */
function multiRNNCell_(lstmCells, data, c, h) {
  const $data = convertToTensor(data, 'data', 'multiRNNCell');
  const $c = convertToTensorArray(c, 'c', 'multiRNNCell');
  const $h = convertToTensorArray(h, 'h', 'multiRNNCell');

  // Fail fast with a descriptive message rather than handing `undefined`
  // state to a cell, where it would surface as an obscure error (or not at all).
  if ($c.length < lstmCells.length || $h.length < lstmCells.length) {
    throw new Error(
        `Error in multiRNNCell: expected at least ${lstmCells.length} ` +
        `entries in both c and h (one per cell), but got ${$c.length} ` +
        `and ${$h.length}.`);
  }

  const newC = [];
  const newH = [];
  let input = $data;
  for (let i = 0; i < lstmCells.length; i++) {
    const [cellState, cellOutput] = lstmCells[i](input, $c[i], $h[i]);
    newC.push(cellState);
    newH.push(cellOutput);
    // This cell's output becomes the input of the next cell in the stack.
    input = cellOutput;
  }
  return [newC, newH];
}
|
// Public op: the `multiRNNCell_` implementation wrapped by `op`. The object
// key is spelled out explicitly so the name handed to `op` stays identical.
export const multiRNNCell = op({ multiRNNCell_: multiRNNCell_ });
|
40 | //# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoibXVsdGlfcm5uX2NlbGwuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWNvcmUvc3JjL29wcy9tdWx0aV9ybm5fY2VsbC50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFpQkEsT0FBTyxFQUFDLGVBQWUsRUFBRSxvQkFBb0IsRUFBQyxNQUFNLG9CQUFvQixDQUFDO0FBRXpFLE9BQU8sRUFBQyxFQUFFLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFTL0I7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBQ0gsU0FBUyxhQUFhLENBQ2xCLFNBQXlCLEVBQUUsSUFBeUIsRUFDcEQsQ0FBNkIsRUFDN0IsQ0FBNkI7SUFDL0IsTUFBTSxLQUFLLEdBQUcsZUFBZSxDQUFDLElBQUksRUFBRSxNQUFNLEVBQUUsY0FBYyxDQUFDLENBQUM7SUFDNUQsTUFBTSxFQUFFLEdBQUcsb0JBQW9CLENBQUMsQ0FBQyxFQUFFLEdBQUcsRUFBRSxjQUFjLENBQUMsQ0FBQztJQUN4RCxNQUFNLEVBQUUsR0FBRyxvQkFBb0IsQ0FBQyxDQUFDLEVBQUUsR0FBRyxFQUFFLGNBQWMsQ0FBQyxDQUFDO0lBRXhELElBQUksS0FBSyxHQUFHLEtBQUssQ0FBQztJQUNsQixNQUFNLFNBQVMsR0FBRyxFQUFFLENBQUM7SUFDckIsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLFNBQVMsQ0FBQyxNQUFNLEVBQUUsQ0FBQyxFQUFFLEVBQUU7UUFDekMsTUFBTSxNQUFNLEdBQUcsU0FBUyxDQUFDLENBQUMsQ0FBQyxDQUFDLEtBQUssRUFBRSxFQUFFLENBQUMsQ0FBQyxDQUFDLEVBQUUsRUFBRSxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7UUFDakQsU0FBUyxDQUFDLElBQUksQ0FBQyxNQUFNLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUMxQixTQUFTLENBQUMsSUFBSSxDQUFDLE1BQU0sQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO1FBQzFCLEtBQUssR0FBRyxNQUFNLENBQUMsQ0FBQyxDQUFDLENBQUM7S0FDbkI7SUFDRCxNQUFNLElBQUksR0FBZSxFQUFFLENBQUM7SUFDNUIsTUFBTSxJQUFJLEdBQWUsRUFBRSxDQUFDO0lBQzVCLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxTQUFTLENBQUMsTUFBTSxFQUFFLENBQUMsSUFBSSxDQUFDLEVBQUU7UUFDNUMsSUFBSSxDQUFDLElBQUksQ0FBQyxTQUFTLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUN4QixJQUFJLENBQUMsSUFBSSxDQUFDLFNBQVMsQ0FBQyxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUMsQ0FBQztLQUM3QjtJQUNELE9BQU8sQ0FBQyxJQUFJLEVBQUUsSUFBSSxDQUFDLENBQUM7QUFDdEIsQ0FBQztBQUNELE1BQU0sQ0FBQyxNQUFNLFlBQVksR0FBRyxFQUFFLENBQUMsRUFBQyxhQUFhLEVBQUMsQ0FBQyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZ
W5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuaW1wb3J0IHtUZW5zb3IyRH0gZnJvbSAnLi4vdGVuc29yJztcbmltcG9ydCB7Y29udmVydFRvVGVuc29yLCBjb252ZXJ0VG9UZW5zb3JBcnJheX0gZnJvbSAnLi4vdGVuc29yX3V0aWxfZW52JztcbmltcG9ydCB7VGVuc29yTGlrZX0gZnJvbSAnLi4vdHlwZXMnO1xuaW1wb3J0IHtvcH0gZnJvbSAnLi9vcGVyYXRpb24nO1xuXG4vKipcbiAqIEBkb2NhbGlhcyAoZGF0YTogVGVuc29yMkQsIGM6IFRlbnNvcjJELCBoOiBUZW5zb3IyRCk6IFtUZW5zb3IyRCwgVGVuc29yMkRdXG4gKi9cbmV4cG9ydCB0eXBlIExTVE1DZWxsRnVuYyA9IHtcbiAgKGRhdGE6IFRlbnNvcjJELCBjOiBUZW5zb3IyRCwgaDogVGVuc29yMkQpOiBbVGVuc29yMkQsIFRlbnNvcjJEXTtcbn07XG5cbi8qKlxuICogQ29tcHV0ZXMgdGhlIG5leHQgc3RhdGVzIGFuZCBvdXRwdXRzIG9mIGEgc3RhY2sgb2YgTFNUTUNlbGxzLlxuICpcbiAqIEVhY2ggY2VsbCBvdXRwdXQgaXMgdXNlZCBhcyBpbnB1dCB0byB0aGUgbmV4dCBjZWxsLlxuICpcbiAqIFJldHVybnMgYFtjZWxsU3RhdGUsIGNlbGxPdXRwdXRdYC5cbiAqXG4gKiBEZXJpdmVkIGZyb20gdGYuY29udHJpYi5ybi5NdWx0aVJOTkNlbGwuXG4gKlxuICogQHBhcmFtIGxzdG1DZWxscyBBcnJheSBvZiBMU1RNQ2VsbCBmdW5jdGlvbnMuXG4gKiBAcGFyYW0gZGF0YSBUaGUgaW5wdXQgdG8gdGhlIGNlbGwuXG4gKiBAcGFyYW0gYyBBcnJheSBvZiBwcmV2aW91cyBjZWxsIHN0YXRlcy5cbiAqIEBwYXJhbSBoIEFycmF5IG9mIHByZXZpb3VzIGNlbGwgb3V0cHV0cy5cbiAqXG4gKiBAZG9jIHtoZWFkaW5nOiAnT3BlcmF0aW9ucycsIHN1YmhlYWRpbmc6ICdSTk4nfVxuICovXG5mdW5jdGlvbiBtdWx0aVJOTkNlbGxfKFxuICAgIGxzdG1DZ
WxsczogTFNUTUNlbGxGdW5jW10sIGRhdGE6IFRlbnNvcjJEfFRlbnNvckxpa2UsXG4gICAgYzogQXJyYXk8VGVuc29yMkR8VGVuc29yTGlrZT4sXG4gICAgaDogQXJyYXk8VGVuc29yMkR8VGVuc29yTGlrZT4pOiBbVGVuc29yMkRbXSwgVGVuc29yMkRbXV0ge1xuICBjb25zdCAkZGF0YSA9IGNvbnZlcnRUb1RlbnNvcihkYXRhLCAnZGF0YScsICdtdWx0aVJOTkNlbGwnKTtcbiAgY29uc3QgJGMgPSBjb252ZXJ0VG9UZW5zb3JBcnJheShjLCAnYycsICdtdWx0aVJOTkNlbGwnKTtcbiAgY29uc3QgJGggPSBjb252ZXJ0VG9UZW5zb3JBcnJheShoLCAnaCcsICdtdWx0aVJOTkNlbGwnKTtcblxuICBsZXQgaW5wdXQgPSAkZGF0YTtcbiAgY29uc3QgbmV3U3RhdGVzID0gW107XG4gIGZvciAobGV0IGkgPSAwOyBpIDwgbHN0bUNlbGxzLmxlbmd0aDsgaSsrKSB7XG4gICAgY29uc3Qgb3V0cHV0ID0gbHN0bUNlbGxzW2ldKGlucHV0LCAkY1tpXSwgJGhbaV0pO1xuICAgIG5ld1N0YXRlcy5wdXNoKG91dHB1dFswXSk7XG4gICAgbmV3U3RhdGVzLnB1c2gob3V0cHV0WzFdKTtcbiAgICBpbnB1dCA9IG91dHB1dFsxXTtcbiAgfVxuICBjb25zdCBuZXdDOiBUZW5zb3IyRFtdID0gW107XG4gIGNvbnN0IG5ld0g6IFRlbnNvcjJEW10gPSBbXTtcbiAgZm9yIChsZXQgaSA9IDA7IGkgPCBuZXdTdGF0ZXMubGVuZ3RoOyBpICs9IDIpIHtcbiAgICBuZXdDLnB1c2gobmV3U3RhdGVzW2ldKTtcbiAgICBuZXdILnB1c2gobmV3U3RhdGVzW2kgKyAxXSk7XG4gIH1cbiAgcmV0dXJuIFtuZXdDLCBuZXdIXTtcbn1cbmV4cG9ydCBjb25zdCBtdWx0aVJOTkNlbGwgPSBvcCh7bXVsdGlSTk5DZWxsX30pO1xuIl19 |
\ | No newline at end of file |