UNPKG

6.42 kBJavaScriptView Raw
1import { convertToTensor, convertToTensorArray } from '../tensor_util_env';
2import { op } from './operation';
/**
 * Computes the next states and outputs of a stack of LSTMCells.
 *
 * Each cell output is used as input to the next cell.
 *
 * Returns `[cellState, cellOutput]`.
 *
 * Derived from tf.contrib.rnn.MultiRNNCell.
 *
 * @param lstmCells Array of LSTMCell functions. Cell `i` is invoked as
 *     `lstmCells[i](input, c[i], h[i])` and its result is read at index 0
 *     (new cell state) and index 1 (new cell output).
 * @param data The input to the cell.
 * @param c Array of previous cell states.
 * @param h Array of previous cell outputs.
 * @returns A two-element array `[newC, newH]`; each inner array has one
 *     entry per cell in `lstmCells`, in the same order.
 *
 * @doc {heading: 'Operations', subheading: 'RNN'}
 */
function multiRNNCell_(lstmCells, data, c, h) {
    const $data = convertToTensor(data, 'data', 'multiRNNCell');
    const $c = convertToTensorArray(c, 'c', 'multiRNNCell');
    const $h = convertToTensorArray(h, 'h', 'multiRNNCell');
    const newC = [];
    const newH = [];
    let input = $data;
    for (let i = 0; i < lstmCells.length; i++) {
        const output = lstmCells[i](input, $c[i], $h[i]);
        newC.push(output[0]);
        newH.push(output[1]);
        // The output of this layer is the input of the next layer.
        input = output[1];
    }
    return [newC, newH];
}
/**
 * Public `multiRNNCell` binding: the value returned by `op` (imported from
 * './operation') when handed an object whose single property,
 * `multiRNNCell_`, is the implementation function defined in this module.
 *
 * NOTE(review): presumably `op` uses the property key to name the op;
 * confirm against './operation'.
 */
const multiRNNCell = op({ multiRNNCell_ });
export { multiRNNCell };
40//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoibXVsdGlfcm5uX2NlbGwuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWNvcmUvc3JjL29wcy9tdWx0aV9ybm5fY2VsbC50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFpQkEsT0FBTyxFQUFDLGVBQWUsRUFBRSxvQkFBb0IsRUFBQyxNQUFNLG9CQUFvQixDQUFDO0FBRXpFLE9BQU8sRUFBQyxFQUFFLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFTL0I7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBQ0gsU0FBUyxhQUFhLENBQ2xCLFNBQXlCLEVBQUUsSUFBeUIsRUFDcEQsQ0FBNkIsRUFDN0IsQ0FBNkI7SUFDL0IsTUFBTSxLQUFLLEdBQUcsZUFBZSxDQUFDLElBQUksRUFBRSxNQUFNLEVBQUUsY0FBYyxDQUFDLENBQUM7SUFDNUQsTUFBTSxFQUFFLEdBQUcsb0JBQW9CLENBQUMsQ0FBQyxFQUFFLEdBQUcsRUFBRSxjQUFjLENBQUMsQ0FBQztJQUN4RCxNQUFNLEVBQUUsR0FBRyxvQkFBb0IsQ0FBQyxDQUFDLEVBQUUsR0FBRyxFQUFFLGNBQWMsQ0FBQyxDQUFDO0lBRXhELElBQUksS0FBSyxHQUFHLEtBQUssQ0FBQztJQUNsQixNQUFNLFNBQVMsR0FBRyxFQUFFLENBQUM7SUFDckIsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLFNBQVMsQ0FBQyxNQUFNLEVBQUUsQ0FBQyxFQUFFLEVBQUU7UUFDekMsTUFBTSxNQUFNLEdBQUcsU0FBUyxDQUFDLENBQUMsQ0FBQyxDQUFDLEtBQUssRUFBRSxFQUFFLENBQUMsQ0FBQyxDQUFDLEVBQUUsRUFBRSxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7UUFDakQsU0FBUyxDQUFDLElBQUksQ0FBQyxNQUFNLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUMxQixTQUFTLENBQUMsSUFBSSxDQUFDLE1BQU0sQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO1FBQzFCLEtBQUssR0FBRyxNQUFNLENBQUMsQ0FBQyxDQUFDLENBQUM7S0FDbkI7SUFDRCxNQUFNLElBQUksR0FBZSxFQUFFLENBQUM7SUFDNUIsTUFBTSxJQUFJLEdBQWUsRUFBRSxDQUFDO0lBQzVCLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxTQUFTLENBQUMsTUFBTSxFQUFFLENBQUMsSUFBSSxDQUFDLEVBQUU7UUFDNUMsSUFBSSxDQUFDLElBQUksQ0FBQyxTQUFTLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUN4QixJQUFJLENBQUMsSUFBSSxDQUFDLFNBQVMsQ0FBQyxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUMsQ0FBQztLQUM3QjtJQUNELE9BQU8sQ0FBQyxJQUFJLEVBQUUsSUFBSSxDQUFDLENBQUM7QUFDdEIsQ0FBQztBQUNELE1BQU0sQ0FBQyxNQUFNLFlBQVksR0FBRyxFQUFFLENBQUMsRUFBQyxhQUFhLEVBQUMsQ0FBQyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5z
ZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuaW1wb3J0IHtUZW5zb3IyRH0gZnJvbSAnLi4vdGVuc29yJztcbmltcG9ydCB7Y29udmVydFRvVGVuc29yLCBjb252ZXJ0VG9UZW5zb3JBcnJheX0gZnJvbSAnLi4vdGVuc29yX3V0aWxfZW52JztcbmltcG9ydCB7VGVuc29yTGlrZX0gZnJvbSAnLi4vdHlwZXMnO1xuaW1wb3J0IHtvcH0gZnJvbSAnLi9vcGVyYXRpb24nO1xuXG4vKipcbiAqIEBkb2NhbGlhcyAoZGF0YTogVGVuc29yMkQsIGM6IFRlbnNvcjJELCBoOiBUZW5zb3IyRCk6IFtUZW5zb3IyRCwgVGVuc29yMkRdXG4gKi9cbmV4cG9ydCB0eXBlIExTVE1DZWxsRnVuYyA9IHtcbiAgKGRhdGE6IFRlbnNvcjJELCBjOiBUZW5zb3IyRCwgaDogVGVuc29yMkQpOiBbVGVuc29yMkQsIFRlbnNvcjJEXTtcbn07XG5cbi8qKlxuICogQ29tcHV0ZXMgdGhlIG5leHQgc3RhdGVzIGFuZCBvdXRwdXRzIG9mIGEgc3RhY2sgb2YgTFNUTUNlbGxzLlxuICpcbiAqIEVhY2ggY2VsbCBvdXRwdXQgaXMgdXNlZCBhcyBpbnB1dCB0byB0aGUgbmV4dCBjZWxsLlxuICpcbiAqIFJldHVybnMgYFtjZWxsU3RhdGUsIGNlbGxPdXRwdXRdYC5cbiAqXG4gKiBEZXJpdmVkIGZyb20gdGYuY29udHJpYi5ybi5NdWx0aVJOTkNlbGwuXG4gKlxuICogQHBhcmFtIGxzdG1DZWxscyBBcnJheSBvZiBMU1RNQ2VsbCBmdW5jdGlvbnMuXG4gKiBAcGFyYW0gZGF0YSBUaGUgaW5wdXQgdG8gdGhlIGNlbGwuXG4gKiBAcGFyYW0gYyBBcnJheSBvZiBwcmV2aW91cyBjZWxsIHN0YXRlcy5cbiAqIEBwYXJhbSBoIEFycmF5IG9mIHByZXZpb3VzIGNlbGwgb3V0cHV0cy5cbiAqXG4gKiBAZG9jIHtoZWFkaW5nOiAnT3BlcmF0aW9ucycsIHN1YmhlYWRpbmc6ICdSTk4nfVxuICovXG5mdW5jdGlvbiBtdWx0aVJOTkNlbGxfKFxuICAgIGxzdG1DZWxs
czogTFNUTUNlbGxGdW5jW10sIGRhdGE6IFRlbnNvcjJEfFRlbnNvckxpa2UsXG4gICAgYzogQXJyYXk8VGVuc29yMkR8VGVuc29yTGlrZT4sXG4gICAgaDogQXJyYXk8VGVuc29yMkR8VGVuc29yTGlrZT4pOiBbVGVuc29yMkRbXSwgVGVuc29yMkRbXV0ge1xuICBjb25zdCAkZGF0YSA9IGNvbnZlcnRUb1RlbnNvcihkYXRhLCAnZGF0YScsICdtdWx0aVJOTkNlbGwnKTtcbiAgY29uc3QgJGMgPSBjb252ZXJ0VG9UZW5zb3JBcnJheShjLCAnYycsICdtdWx0aVJOTkNlbGwnKTtcbiAgY29uc3QgJGggPSBjb252ZXJ0VG9UZW5zb3JBcnJheShoLCAnaCcsICdtdWx0aVJOTkNlbGwnKTtcblxuICBsZXQgaW5wdXQgPSAkZGF0YTtcbiAgY29uc3QgbmV3U3RhdGVzID0gW107XG4gIGZvciAobGV0IGkgPSAwOyBpIDwgbHN0bUNlbGxzLmxlbmd0aDsgaSsrKSB7XG4gICAgY29uc3Qgb3V0cHV0ID0gbHN0bUNlbGxzW2ldKGlucHV0LCAkY1tpXSwgJGhbaV0pO1xuICAgIG5ld1N0YXRlcy5wdXNoKG91dHB1dFswXSk7XG4gICAgbmV3U3RhdGVzLnB1c2gob3V0cHV0WzFdKTtcbiAgICBpbnB1dCA9IG91dHB1dFsxXTtcbiAgfVxuICBjb25zdCBuZXdDOiBUZW5zb3IyRFtdID0gW107XG4gIGNvbnN0IG5ld0g6IFRlbnNvcjJEW10gPSBbXTtcbiAgZm9yIChsZXQgaSA9IDA7IGkgPCBuZXdTdGF0ZXMubGVuZ3RoOyBpICs9IDIpIHtcbiAgICBuZXdDLnB1c2gobmV3U3RhdGVzW2ldKTtcbiAgICBuZXdILnB1c2gobmV3U3RhdGVzW2kgKyAxXSk7XG4gIH1cbiAgcmV0dXJuIFtuZXdDLCBuZXdIXTtcbn1cbmV4cG9ydCBjb25zdCBtdWx0aVJOTkNlbGwgPSBvcCh7bXVsdGlSTk5DZWxsX30pO1xuIl19
\No newline at end of file