UNPKG

228 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
/**
 * TensorFlow.js Layers: Recurrent Neural Network Layers.
 */
13import * as tfc from '@tensorflow/tfjs-core';
14import { serialization, tidy, util } from '@tensorflow/tfjs-core';
15import { getActivation, serializeActivation } from '../activations';
16import * as K from '../backend/tfjs_backend';
17import { nameScope } from '../common';
18import { getConstraint, serializeConstraint } from '../constraints';
19import { InputSpec, SymbolicTensor } from '../engine/topology';
20import { Layer } from '../engine/topology';
21import { AttributeError, NotImplementedError, ValueError } from '../errors';
22import { getInitializer, Initializer, Ones, serializeInitializer } from '../initializers';
23import { getRegularizer, serializeRegularizer } from '../regularizers';
24import { assertPositiveInteger } from '../utils/generic_utils';
25import * as math_utils from '../utils/math_utils';
26import { getExactlyOneShape, getExactlyOneTensor, isArrayOfShapes } from '../utils/types_utils';
27import { batchGetValue, batchSetValue } from '../variables';
28import { deserialize } from './serialization';
/**
 * Normalizes the arguments of `RNN.apply()` into separate fields.
 *
 * A model loaded from file passes `initialState` and `constants` inside
 * `inputs` rather than as dedicated kwargs, laid out as
 * `[inputs, initialState0, initialState1, ..., constant0, constant1]`.
 * The trailing `numConstants` entries are treated as constants, and any
 * remaining entries after the first are treated as initial states.
 *
 * @param inputs Tensor or `Array` of tensors.
 * @param initialState Tensor, `Array` of tensors, or `null`/`undefined`.
 * @param constants Tensor, `Array` of tensors, or `null`/`undefined`.
 * @param numConstants Number of trailing entries of an `Array` `inputs` that
 *   are constants; ignored when `null`/`undefined`.
 * @returns `{inputs, initialState, constants}`. `inputs` is a single
 *   (non-Array) value. `initialState` and `constants` are wrapped in an
 *   `Array` when they are single values; a `null`/`undefined` value is
 *   returned unchanged.
 * @throws ValueError if `inputs` is an `Array` and either `initialState` or
 *   `constants` is also provided.
 */
export function standardizeArgs(inputs, initialState, constants, numConstants) {
  let primary = inputs;
  let stateList = initialState;
  let constantList = constants;
  if (Array.isArray(primary)) {
    if (stateList != null || constantList != null) {
      throw new ValueError('When inputs is an array, neither initialState or constants ' +
        'should be provided');
    }
    let packed = primary;
    if (numConstants != null) {
      // Constants occupy the tail of the packed Array.
      const boundary = packed.length - numConstants;
      constantList = packed.slice(boundary, packed.length);
      packed = packed.slice(0, boundary);
    }
    if (packed.length > 1) {
      stateList = packed.slice(1, packed.length);
    }
    primary = packed[0];
  }
  // Wrap single values in an Array; leave null/undefined/Arrays untouched.
  const asListOrNull = (value) =>
    (value == null || Array.isArray(value)) ? value : [value];
  return {
    inputs: primary,
    initialState: asListOrNull(stateList),
    constants: asListOrNull(constantList),
  };
}
/**
 * Iterates over the time dimension of a tensor.
 *
 * The whole computation runs inside `tfc.tidy()`, so intermediate tensors are
 * released once the returned tensors are determined.
 *
 * @param stepFunction RNN step function.
 *   Parameters:
 *     inputs: tensor with shape `[samples, ...]` (no time dimension),
 *       representing input for the batch of samples at a certain time step.
 *     states: an Array of tensors.
 *   Returns:
 *     outputs: tensor with shape `[samples, outputDim]` (no time dimension).
 *     newStates: list of tensors, same length and shapes as `states`. The first
 *       state in the list must be the output tensor at the previous timestep.
 * @param inputs Tensor of temporal data of shape `[samples, time, ...]` (at
 *   least 3D).
 * @param initialStates `Array` of tensors (no time dimension) containing the
 *   initial values of the states used in the step function.
 * @param goBackwards If `true`, do the iteration over the time dimension in
 *   reverse order and return the reversed sequence.
 * @param mask Binary tensor with shape `[sample, time]` or `[sample, time, 1]`,
 *   with a zero for every element that is masked. At a masked step the output
 *   falls back to `states[0]` and every state keeps its previous value.
 * @param constants An Array of constant values passed at each step. Not
 *   supported yet: a non-null value throws `NotImplementedError`.
 * @param unroll Whether to unroll the RNN or to use a symbolic loop. *Not*
 *   applicable to this imperative backend; a warning is logged if `true` and
 *   the value is otherwise ignored.
 * @param needPerStepOutputs Whether the per-step outputs are to be
 *   concatenated into a single tensor and returned (as the second return
 *   value). Default: `false`. This arg is included so that the relatively
 *   expensive concatenation of the stepwise outputs can be omitted unless
 *   the stepwise outputs need to be kept (e.g., for an LSTM layer of which
 *   `returnSequence` is `true`.)
 * @returns An Array: `[lastOutput, outputs, newStates]`.
 *   lastOutput: the latest output of the RNN, of shape `[samples, ...]`.
 *   outputs: tensor with shape `[samples, time, ...]` where each entry
 *     `output[s, t]` is the output of the step function at time `t` for sample
 *     `s`. This return value is provided if and only if the
 *     `needPerStepOutputs` is set as `true`. If it is set as `false`, this
 *     return value will be `undefined`.
 *   newStates: Array of tensors, latest states returned by the step function,
 *     of shape `(samples, ...)`.
 * @throws ValueError If input dimension is less than 3.
 * @throws NotImplementedError If `constants` is provided.
 */
export function rnn(stepFunction, inputs, initialStates, goBackwards = false, mask, constants, unroll = false, needPerStepOutputs = false) {
  return tfc.tidy(() => {
    const ndim = inputs.shape.length;
    if (ndim < 3) {
      throw new ValueError(`Input should be at least 3D, but is ${ndim}D.`);
    }
    // Reject unsupported options before doing any tensor work.
    if (constants != null) {
      throw new NotImplementedError(
        'The rnn() function of the deeplearn.js backend does not support ' +
        'constants yet.');
    }
    // Transpose to time-major, i.e., from [batch, time, ...] to [time, batch,
    // ...].
    const axes = [1, 0].concat(math_utils.range(2, ndim));
    inputs = tfc.transpose(inputs, axes);
    // Porting Note: the unroll option is ignored by the imperative backend.
    if (unroll) {
      console.warn(
        'Backend rnn(): the unroll = true option is not applicable to the ' +
        'imperative deeplearn.js backend.');
    }
    if (mask != null) {
      // Normalize any mask values to exactly 0.0 / 1.0 in float32.
      mask = tfc.cast(tfc.cast(mask, 'bool'), 'float32');
      if (mask.rank === ndim - 1) {
        mask = tfc.expandDims(mask, -1);
      }
      mask = tfc.transpose(mask, axes);
    }
    if (goBackwards) {
      inputs = tfc.reverse(inputs, 0);
      if (mask != null) {
        mask = tfc.reverse(mask, 0);
      }
    }
    // Porting Note: PyKeras with TensorFlow backend uses a symbolic loop
    // (tf.while_loop). But for the imperative backend, we just use the usual
    // control flow to iterate over the time steps in the inputs.
    // Porting Note: PyKeras patches a "_use_learning_phase" attribute to
    // outputs. This is not idiomatic here; whether we are in a learning
    // (i.e., training) phase is passed to the RNN in a different way.
    const perStepOutputs = [];
    let lastOutput;
    let states = initialStates;
    const timeSteps = inputs.shape[0];
    const perStepInputs = tfc.unstack(inputs);
    const perStepMasks = mask != null ? tfc.unstack(mask) : undefined;
    for (let t = 0; t < timeSteps; ++t) {
      const currentInput = perStepInputs[t];
      const stepOutputs = tfc.tidy(() => stepFunction(currentInput, states));
      if (mask == null) {
        lastOutput = stepOutputs[0];
        states = stepOutputs[1];
      } else {
        const maskedOutputs = tfc.tidy(() => {
          const stepMask = perStepMasks[t];
          const negStepMask = tfc.sub(tfc.onesLike(stepMask), stepMask);
          // Blend: take the new value where mask == 1 and the previous one
          // where mask == 0.
          // TODO(cais): Would tfc.where() be better for performance?
          const output = tfc.add(
            tfc.mul(stepOutputs[0], stepMask), tfc.mul(states[0], negStepMask));
          const newStates = states.map((state, i) => tfc.add(
            tfc.mul(stepOutputs[1][i], stepMask), tfc.mul(state, negStepMask)));
          return { output, newStates };
        });
        lastOutput = maskedOutputs.output;
        states = maskedOutputs.newStates;
      }
      if (needPerStepOutputs) {
        perStepOutputs.push(lastOutput);
      }
    }
    let outputs;
    if (needPerStepOutputs) {
      // Stack along axis 1 to restore batch-major layout [batch, time, ...].
      const axis = 1;
      outputs = tfc.stack(perStepOutputs, axis);
    }
    return [lastOutput, outputs, states];
  });
}
/**
 * Base recurrent layer.
 *
 * Wraps an RNN cell and runs it over the time dimension of its input with
 * `rnn()`. If an `Array` of cells is given, it is wrapped in a
 * `StackedRNNCells`. The cell must expose a `stateSize` attribute (a number
 * or an `Array` of numbers, one per state).
 */
export class RNN extends Layer {
  constructor(args) {
    super(args);
    let cell;
    if (args.cell == null) {
      throw new ValueError('cell property is missing for the constructor of RNN.');
    } else if (Array.isArray(args.cell)) {
      cell = new StackedRNNCells({ cells: args.cell });
    } else {
      cell = args.cell;
    }
    if (cell.stateSize == null) {
      throw new ValueError(
        'The RNN cell should have an attribute `stateSize` (tuple of ' +
        'integers, one integer per RNN state).');
    }
    this.cell = cell;
    this.returnSequences =
      args.returnSequences == null ? false : args.returnSequences;
    this.returnState = args.returnState == null ? false : args.returnState;
    this.goBackwards = args.goBackwards == null ? false : args.goBackwards;
    this._stateful = args.stateful == null ? false : args.stateful;
    this.unroll = args.unroll == null ? false : args.unroll;
    this.supportsMasking = true;
    this.inputSpec = [new InputSpec({ ndim: 3 })];
    this.stateSpec = null;
    this.states_ = null;
    // TODO(cais): Add constantsSpec and numConstants.
    this.numConstants = null;
    // TODO(cais): Look into the use of initial_state in the kwargs of the
    // constructor.
    // Old state tensors retained (when resetStates() is called with
    // training === true) until the next no-arg resetStates() call.
    this.keptStates = [];
  }

  // Porting Note: This is the equivalent of `RNN.states` property getter in
  // PyKeras.
  /**
   * Returns the current state tensors, or an `Array` of `null`s (one per
   * state of the cell) if the states have not been set.
   */
  getStates() {
    if (this.states_ == null) {
      const numStates = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1;
      return math_utils.range(0, numStates).map(x => null);
    } else {
      return this.states_;
    }
  }

  // Porting Note: This is the equivalent of the `RNN.states` property setter in
  // PyKeras.
  setStates(states) {
    this.states_ = states;
  }

  /**
   * Output shape is `[batch, time, units]` if `returnSequences`, else
   * `[batch, units]`, where `units` is the first entry of `cell.stateSize`.
   * If `returnState`, one `[batch, dim]` shape per state is appended.
   */
  computeOutputShape(inputShape) {
    if (isArrayOfShapes(inputShape)) {
      inputShape = inputShape[0];
    }
    // TODO(cais): Remove the casting once stacked RNN cells become supported.
    let stateSize = this.cell.stateSize;
    if (!Array.isArray(stateSize)) {
      stateSize = [stateSize];
    }
    const outputDim = stateSize[0];
    let outputShape;
    if (this.returnSequences) {
      outputShape = [inputShape[0], inputShape[1], outputDim];
    } else {
      outputShape = [inputShape[0], outputDim];
    }
    if (this.returnState) {
      const stateShape = [];
      for (const dim of stateSize) {
        stateShape.push([inputShape[0], dim]);
      }
      return [outputShape].concat(stateShape);
    } else {
      return outputShape;
    }
  }

  /**
   * The input mask is propagated only when `returnSequences` is true; state
   * outputs (if `returnState`) are never masked.
   */
  computeMask(inputs, mask) {
    return tfc.tidy(() => {
      if (Array.isArray(mask)) {
        mask = mask[0];
      }
      const outputMask = this.returnSequences ? mask : null;
      if (this.returnState) {
        const stateMask = this.states.map(s => null);
        return [outputMask].concat(stateMask);
      } else {
        return outputMask;
      }
    });
  }

  /**
   * Get the current state tensors of the RNN.
   *
   * If the state hasn't been set, return an array of `null`s of the correct
   * length.
   */
  get states() {
    if (this.states_ == null) {
      const numStates = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1;
      const output = [];
      for (let i = 0; i < numStates; ++i) {
        output.push(null);
      }
      return output;
    } else {
      return this.states_;
    }
  }

  set states(s) {
    this.states_ = s;
  }

  build(inputShape) {
    // Note inputShape will be an Array of Shapes of initial states and
    // constants if these are passed in apply().
    if (this.numConstants != null) {
      throw new NotImplementedError('Constants support is not implemented in RNN yet.');
    }
    if (isArrayOfShapes(inputShape)) {
      inputShape = inputShape[0];
    }
    // Only a stateful RNN pins the batch dimension of its input spec.
    const batchSize = this.stateful ? inputShape[0] : null;
    const inputDim = inputShape.slice(2);
    this.inputSpec[0] = new InputSpec({ shape: [batchSize, null, ...inputDim] });
    // Allow cell (if RNNCell Layer) to build before we set or validate
    // stateSpec. The cell sees one time step: [batch, ...features].
    const stepInputShape = [inputShape[0]].concat(inputShape.slice(2));
    this.cell.build(stepInputShape);
    // Set or validate stateSpec.
    let stateSize;
    if (Array.isArray(this.cell.stateSize)) {
      stateSize = this.cell.stateSize;
    } else {
      stateSize = [this.cell.stateSize];
    }
    if (this.stateSpec != null) {
      if (!util.arraysEqual(this.stateSpec.map(spec => spec.shape[spec.shape.length - 1]), stateSize)) {
        throw new ValueError(
          `An initialState was passed that is not compatible with ` +
          `cell.stateSize. Received stateSpec=${this.stateSpec}; ` +
          `However cell.stateSize is ${this.cell.stateSize}`);
      }
    } else {
      this.stateSpec =
        stateSize.map(dim => new InputSpec({ shape: [null, dim] }));
    }
    if (this.stateful) {
      this.resetStates();
    }
  }

  /**
   * Reset the state tensors of the RNN.
   *
   * If the `states` argument is `undefined` or `null`, will set the
   * state tensor(s) of the RNN to all-zero tensors of the appropriate
   * shape(s).
   *
   * If `states` is provided, will set the state tensors of the RNN to its
   * value.
   *
   * @param states Optional externally-provided initial states.
   * @param training Whether this call is done during training. For stateful
   *   RNNs, this affects whether the old states are kept or discarded. In
   *   particular, if `training` is `true`, the old states will be kept so
   *   that subsequent backpropagation through time (BPTT) may work properly.
   *   Else, the old states will be discarded.
   * @throws AttributeError if the RNN is not stateful.
   * @throws ValueError if the batch size is unknown, or if `states` has the
   *   wrong number of entries or an entry has the wrong shape.
   */
  resetStates(states, training = false) {
    tidy(() => {
      if (!this.stateful) {
        throw new AttributeError('Cannot call resetStates() on an RNN Layer that is not stateful.');
      }
      const batchSize = this.inputSpec[0].shape[0];
      if (batchSize == null) {
        throw new ValueError(
          'If an RNN is stateful, it needs to know its batch size. Specify ' +
          'the batch size of your input tensors: \n' +
          '- If using a Sequential model, specify the batch size by ' +
          'passing a `batchInputShape` option to your first layer.\n' +
          '- If using the functional API, specify the batch size by ' +
          'passing a `batchShape` option to your Input layer.');
      }
      // Initialize state if null.
      if (this.states_ == null) {
        if (Array.isArray(this.cell.stateSize)) {
          this.states_ =
            this.cell.stateSize.map(dim => tfc.zeros([batchSize, dim]));
        } else {
          this.states_ = [tfc.zeros([batchSize, this.cell.stateSize])];
        }
      } else if (states == null) {
        // Dispose old state tensors.
        tfc.dispose(this.states_);
        // For stateful RNNs, fully dispose kept old states.
        if (this.keptStates != null) {
          tfc.dispose(this.keptStates);
          this.keptStates = [];
        }
        if (Array.isArray(this.cell.stateSize)) {
          this.states_ =
            this.cell.stateSize.map(dim => tfc.zeros([batchSize, dim]));
        } else {
          this.states_[0] = tfc.zeros([batchSize, this.cell.stateSize]);
        }
      } else {
        if (!Array.isArray(states)) {
          states = [states];
        }
        if (states.length !== this.states_.length) {
          throw new ValueError(
            `Layer ${this.name} expects ${this.states_.length} state(s), ` +
            `but it received ${states.length} state value(s). Input ` +
            `received: ${states}`);
        }
        if (training === true) {
          // Store old state tensors for complete disposal later, i.e., during
          // the next no-arg call to this method. We do not dispose the old
          // states immediately because BPTT (among other things) requires
          // them.
          this.keptStates.push(this.states_.slice());
        } else {
          tfc.dispose(this.states_);
        }
        for (let index = 0; index < this.states_.length; ++index) {
          const value = states[index];
          const dim = Array.isArray(this.cell.stateSize) ?
            this.cell.stateSize[index] :
            this.cell.stateSize;
          const expectedShape = [batchSize, dim];
          if (!util.arraysEqual(value.shape, expectedShape)) {
            throw new ValueError(
              `State ${index} is incompatible with layer ${this.name}: ` +
              `expected shape=${expectedShape}, received shape=${value.shape}`);
          }
          this.states_[index] = value;
        }
      }
      // Keep independent copies so the states survive the enclosing tidy().
      this.states_ = this.states_.map(state => tfc.keep(state.clone()));
    });
  }

  /**
   * Standardizes `initialState` / `constants` (from kwargs or packed into
   * `inputs`) and, when they are `SymbolicTensor`s, feeds them to the layer
   * as additional inputs with a temporarily extended `inputSpec`.
   *
   * The caller's `kwargs` object is not modified.
   */
  apply(inputs, kwargs) {
    // TODO(cais): Figure out whether initialState is in kwargs or inputs.
    // Work on a shallow copy so the caller's kwargs object is not mutated.
    kwargs = Object.assign({}, kwargs);
    const standardized = standardizeArgs(
      inputs, kwargs['initialState'], kwargs['constants'], this.numConstants);
    inputs = standardized.inputs;
    const initialState = standardized.initialState;
    const constants = standardized.constants;
    // If any of `initial_state` or `constants` are specified and are
    // `tf.SymbolicTensor`s, then add them to the inputs and temporarily modify
    // the input_spec to include them.
    let additionalInputs = [];
    let additionalSpecs = [];
    if (initialState != null) {
      kwargs['initialState'] = initialState;
      additionalInputs = additionalInputs.concat(initialState);
      this.stateSpec = initialState.map(state => new InputSpec({ shape: state.shape }));
      additionalSpecs = additionalSpecs.concat(this.stateSpec);
    }
    if (constants != null) {
      kwargs['constants'] = constants;
      additionalInputs = additionalInputs.concat(constants);
      // TODO(cais): Add this.constantsSpec.
      this.numConstants = constants.length;
    }
    const isTensor = additionalInputs[0] instanceof SymbolicTensor;
    if (isTensor) {
      // Compute full input spec, including state and constants.
      const fullInput = [inputs].concat(additionalInputs);
      const originalInputSpec = this.inputSpec;
      // Perform the call with temporarily replaced inputSpec; restore it even
      // if super.apply() throws.
      this.inputSpec = originalInputSpec.concat(additionalSpecs);
      try {
        return super.apply(fullInput, kwargs);
      } finally {
        this.inputSpec = originalInputSpec;
      }
    } else {
      return super.apply(inputs, kwargs);
    }
  }

  // tslint:disable-next-line:no-any
  call(inputs, kwargs) {
    // Input shape: `[samples, time (padded with zeros), input_dim]`.
    // Note that the .build() method of subclasses **must** define
    // this.inputSpec and this.stateSpec with complete input shapes.
    return tidy(() => {
      const mask = kwargs == null ? null : kwargs['mask'];
      const training = kwargs == null ? null : kwargs['training'];
      let initialState = kwargs == null ? null : kwargs['initialState'];
      inputs = getExactlyOneTensor(inputs);
      if (initialState == null) {
        if (this.stateful) {
          initialState = this.states_;
        } else {
          initialState = this.getInitialState(inputs);
        }
      }
      const numStates = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1;
      if (initialState.length !== numStates) {
        throw new ValueError(
          `RNN Layer has ${numStates} state(s) but was passed ` +
          `${initialState.length} initial state(s).`);
      }
      if (this.unroll) {
        console.warn('Ignoring unroll = true for RNN layer, due to imperative backend.');
      }
      const cellCallKwargs = { training };
      // TODO(cais): Add support for constants.
      const step = (inputs, states) => {
        // `inputs` and `states` are concatenated to form a single `Array` of
        // `tf.Tensor`s as the input to `cell.call()`.
        const outputs = this.cell.call([inputs].concat(states), cellCallKwargs);
        // Marshall the return value into output and new states.
        return [outputs[0], outputs.slice(1)];
      };
      // TODO(cais): Add support for constants.
      const rnnOutputs = rnn(step, inputs, initialState, this.goBackwards, mask, null, this.unroll, this.returnSequences);
      const lastOutput = rnnOutputs[0];
      const outputs = rnnOutputs[1];
      const states = rnnOutputs[2];
      if (this.stateful) {
        this.resetStates(states, training);
      }
      const output = this.returnSequences ? outputs : lastOutput;
      // TODO(cais): Properly set learning phase flag.
      if (this.returnState) {
        return [output].concat(states);
      } else {
        return output;
      }
    });
  }

  /**
   * Builds all-zero initial states of shape `[samples, dim]`, one per entry
   * of `cell.stateSize`. The zeros are derived from `inputs` so the batch
   * dimension matches; the reduction over axes 1 and 2 assumes 3D inputs.
   */
  getInitialState(inputs) {
    return tidy(() => {
      // Build an all-zero tensor of shape [samples, outputDim].
      // [Samples, timeSteps, inputDim].
      let initialState = tfc.zeros(inputs.shape);
      // [Samples].
      initialState = tfc.sum(initialState, [1, 2]);
      initialState = K.expandDims(initialState); // [Samples, 1].
      if (Array.isArray(this.cell.stateSize)) {
        return this.cell.stateSize.map(dim => dim > 1 ? K.tile(initialState, [1, dim]) : initialState);
      } else {
        return this.cell.stateSize > 1 ?
          [K.tile(initialState, [1, this.cell.stateSize])] :
          [initialState];
      }
    });
  }

  get trainableWeights() {
    if (!this.trainable) {
      return [];
    }
    // Porting Note: In TypeScript, `this` is always an instance of `Layer`.
    return this.cell.trainableWeights;
  }

  get nonTrainableWeights() {
    // Porting Note: In TypeScript, `this` is always an instance of `Layer`.
    if (!this.trainable) {
      return this.cell.weights;
    }
    return this.cell.nonTrainableWeights;
  }

  setFastWeightInitDuringBuild(value) {
    super.setFastWeightInitDuringBuild(value);
    if (this.cell != null) {
      this.cell.setFastWeightInitDuringBuild(value);
    }
  }

  getConfig() {
    const baseConfig = super.getConfig();
    const config = {
      returnSequences: this.returnSequences,
      returnState: this.returnState,
      goBackwards: this.goBackwards,
      stateful: this.stateful,
      unroll: this.unroll,
    };
    if (this.numConstants != null) {
      config['numConstants'] = this.numConstants;
    }
    const cellConfig = this.cell.getConfig();
    if (this.getClassName() === RNN.className) {
      config['cell'] = {
        'className': this.cell.getClassName(),
        'config': cellConfig,
      };
    }
    // this order is necessary, to prevent cell name from replacing layer name
    return Object.assign(Object.assign(Object.assign({}, cellConfig), baseConfig), config);
  }

  /**
   * Deserializes `config['cell']` and constructs the layer. The passed
   * `config` object is not modified, so it can be reused.
   *
   * @nocollapse
   */
  static fromConfig(cls, config, customObjects = {}) {
    const cellConfig = config['cell'];
    const cell = deserialize(cellConfig, customObjects);
    return new cls(Object.assign({}, config, { cell }));
  }
}
/** @nocollapse */
RNN.className = 'RNN';
serialization.registerClass(RNN);
// Porting Note: PyKeras has no counterpart to this shared base class. It
// exists so that code can rely on a common parent type for RNN cells instead
// of `has_attr(cell, ...)`-style checks.
/**
 * An RNNCell layer.
 *
 * As used by `RNN` in this module, a cell exposes `stateSize` and its
 * `call()` receives `[input, ...states]` and returns
 * `[output, ...newStates]`.
 *
 * @doc {heading: 'Layers', subheading: 'Classes'}
 */
export class RNNCell extends Layer {}
/**
 * Cell for `SimpleRNN`: `output = activation(x . kernel + bias +
 * prevOutput . recurrentKernel)`, with optional input and recurrent dropout.
 * The single state has size `units`.
 */
export class SimpleRNNCell extends RNNCell {
  constructor(args) {
    super(args);
    this.DEFAULT_ACTIVATION = 'tanh';
    this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
    this.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal';
    this.DEFAULT_BIAS_INITIALIZER = 'zeros';
    this.units = args.units;
    assertPositiveInteger(this.units, `units`);
    this.activation = getActivation(args.activation == null ? this.DEFAULT_ACTIVATION : args.activation);
    this.useBias = args.useBias == null ? true : args.useBias;
    this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER);
    this.recurrentInitializer = getInitializer(args.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER);
    this.biasInitializer =
      getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER);
    this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
    this.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
    this.biasRegularizer = getRegularizer(args.biasRegularizer);
    this.kernelConstraint = getConstraint(args.kernelConstraint);
    this.recurrentConstraint = getConstraint(args.recurrentConstraint);
    this.biasConstraint = getConstraint(args.biasConstraint);
    // Dropout rates are clamped to [0, 1].
    this.dropout = math_utils.min([1, math_utils.max([0, args.dropout == null ? 0 : args.dropout])]);
    this.recurrentDropout = math_utils.min([
      1,
      math_utils.max([0, args.recurrentDropout == null ? 0 : args.recurrentDropout])
    ]);
    this.dropoutFunc = args.dropoutFunc;
    this.stateSize = this.units;
    // Created lazily on the first call(); cleared by the owning layer
    // (e.g. SimpleRNN.call()) before each new sequence.
    this.dropoutMask = null;
    this.recurrentDropoutMask = null;
  }

  /**
   * Creates `kernel` `[inputDim, units]`, `recurrent_kernel`
   * `[units, units]` and, if `useBias`, `bias` `[units]`.
   */
  build(inputShape) {
    inputShape = getExactlyOneShape(inputShape);
    // TODO(cais): Use regularizer.
    this.kernel = this.addWeight('kernel', [inputShape[inputShape.length - 1], this.units], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
    this.recurrentKernel = this.addWeight('recurrent_kernel', [this.units, this.units], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
    if (this.useBias) {
      this.bias = this.addWeight('bias', [this.units], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
    } else {
      this.bias = null;
    }
    this.built = true;
  }

  // Porting Note: PyKeras' equivalent of this method takes two tensor inputs:
  // `inputs` and `states`. Here, the two tensors are combined into an
  // `Tensor[]` Array as the first input argument.
  // Similarly, PyKeras' equivalent of this method returns two values:
  // `output` and `[output]`. Here the two are combined into one length-2
  // `Tensor[]`, consisting of `output` repeated.
  /**
   * @param inputs `[input, prevOutput]`.
   * @param kwargs Optional; `kwargs.training` (default `false`) is forwarded
   *   to dropout-mask generation.
   * @returns `[output, output]`.
   * @throws ValueError if `inputs` does not have exactly 2 entries.
   */
  call(inputs, kwargs) {
    return tidy(() => {
      if (inputs.length !== 2) {
        throw new ValueError(`SimpleRNNCell expects 2 input Tensors, got ${inputs.length}.`);
      }
      const stepInput = inputs[0];
      let prevOutput = inputs[1];
      // Tolerate a missing kwargs object, consistent with RNN.call().
      const training =
        kwargs == null || kwargs['training'] == null ? false : kwargs['training'];
      if (0 < this.dropout && this.dropout < 1 && this.dropoutMask == null) {
        this.dropoutMask = generateDropoutMask({
          ones: () => tfc.onesLike(stepInput),
          rate: this.dropout,
          training,
          dropoutFunc: this.dropoutFunc,
        });
      }
      if (0 < this.recurrentDropout && this.recurrentDropout < 1 &&
        this.recurrentDropoutMask == null) {
        this.recurrentDropoutMask = generateDropoutMask({
          ones: () => tfc.onesLike(prevOutput),
          rate: this.recurrentDropout,
          training,
          dropoutFunc: this.dropoutFunc,
        });
      }
      let h;
      const dpMask = this.dropoutMask;
      const recDpMask = this.recurrentDropoutMask;
      if (dpMask != null) {
        h = K.dot(tfc.mul(stepInput, dpMask), this.kernel.read());
      } else {
        h = K.dot(stepInput, this.kernel.read());
      }
      if (this.bias != null) {
        h = K.biasAdd(h, this.bias.read());
      }
      if (recDpMask != null) {
        prevOutput = tfc.mul(prevOutput, recDpMask);
      }
      let output = tfc.add(h, K.dot(prevOutput, this.recurrentKernel.read()));
      if (this.activation != null) {
        output = this.activation.apply(output);
      }
      // TODO(cais): Properly set learning phase on output tensor?
      return [output, output];
    });
  }

  getConfig() {
    const baseConfig = super.getConfig();
    const config = {
      units: this.units,
      activation: serializeActivation(this.activation),
      useBias: this.useBias,
      kernelInitializer: serializeInitializer(this.kernelInitializer),
      recurrentInitializer: serializeInitializer(this.recurrentInitializer),
      biasInitializer: serializeInitializer(this.biasInitializer),
      kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
      recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer),
      biasRegularizer: serializeRegularizer(this.biasRegularizer),
      activityRegularizer: serializeRegularizer(this.activityRegularizer),
      kernelConstraint: serializeConstraint(this.kernelConstraint),
      recurrentConstraint: serializeConstraint(this.recurrentConstraint),
      biasConstraint: serializeConstraint(this.biasConstraint),
      dropout: this.dropout,
      recurrentDropout: this.recurrentDropout,
    };
    return Object.assign(Object.assign({}, baseConfig), config);
  }
}
/** @nocollapse */
SimpleRNNCell.className = 'SimpleRNNCell';
serialization.registerClass(SimpleRNNCell);
/**
 * Fully-connected RNN layer backed by a `SimpleRNNCell` built from the same
 * `args`.
 */
export class SimpleRNN extends RNN {
  constructor(args) {
    // Pass a shallow copy carrying the cell, so the caller's `args` object
    // (e.g. a config handed to fromConfig()) is not mutated.
    super(Object.assign({}, args, { cell: new SimpleRNNCell(args) }));
    // TODO(cais): Add activityRegularizer.
  }

  /**
   * Clears any dropout masks cached on the cell by a previous call, so that
   * new masks are generated for this sequence, then delegates to `RNN.call()`
   * with the `mask`, `training` and `initialState` kwargs.
   */
  call(inputs, kwargs) {
    return tidy(() => {
      if (this.cell.dropoutMask != null) {
        tfc.dispose(this.cell.dropoutMask);
        this.cell.dropoutMask = null;
      }
      if (this.cell.recurrentDropoutMask != null) {
        tfc.dispose(this.cell.recurrentDropoutMask);
        this.cell.recurrentDropoutMask = null;
      }
      const mask = kwargs == null ? null : kwargs['mask'];
      const training = kwargs == null ? null : kwargs['training'];
      const initialState = kwargs == null ? null : kwargs['initialState'];
      return super.call(inputs, { mask, training, initialState });
    });
  }

  /** @nocollapse */
  static fromConfig(cls, config) {
    return new cls(config);
  }
}
/** @nocollapse */
SimpleRNN.className = 'SimpleRNN';
serialization.registerClass(SimpleRNN);
/**
 * Cell for the GRU (gated recurrent unit) layer.
 *
 * Owns the input kernel, the recurrent kernel and an optional bias, and
 * computes a single recurrence step in `call()`. Only the `resetAfter: false`
 * variant is supported; the constructor rejects `resetAfter: true`.
 */
export class GRUCell extends RNNCell {
    constructor(args) {
        super(args);
        this.DEFAULT_ACTIVATION = 'tanh';
        this.DEFAULT_RECURRENT_ACTIVATION = 'hardSigmoid';
        this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
        this.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal';
        this.DEFAULT_BIAS_INITIALIZER = 'zeros';
        if (args.resetAfter) {
            throw new ValueError(`GRUCell does not support reset_after parameter set to true.`);
        }
        this.units = args.units;
        assertPositiveInteger(this.units, 'units');
        this.activation = getActivation(args.activation === undefined ? this.DEFAULT_ACTIVATION :
            args.activation);
        this.recurrentActivation = getActivation(args.recurrentActivation === undefined ?
            this.DEFAULT_RECURRENT_ACTIVATION :
            args.recurrentActivation);
        this.useBias = args.useBias == null ? true : args.useBias;
        this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER);
        this.recurrentInitializer = getInitializer(args.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER);
        this.biasInitializer =
            getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER);
        this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
        this.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
        this.biasRegularizer = getRegularizer(args.biasRegularizer);
        this.kernelConstraint = getConstraint(args.kernelConstraint);
        this.recurrentConstraint = getConstraint(args.recurrentConstraint);
        this.biasConstraint = getConstraint(args.biasConstraint);
        // Dropout rates are clamped into [0, 1]; a missing rate means 0.
        this.dropout = math_utils.min([1, math_utils.max([0, args.dropout == null ? 0 : args.dropout])]);
        this.recurrentDropout = math_utils.min([
            1,
            math_utils.max([0, args.recurrentDropout == null ? 0 : args.recurrentDropout])
        ]);
        this.dropoutFunc = args.dropoutFunc;
        this.implementation = args.implementation;
        this.stateSize = this.units;
        // Dropout masks are created lazily on the first step; the `GRU` layer
        // disposes and clears them at the start of each of its `call()`s.
        this.dropoutMask = null;
        this.recurrentDropoutMask = null;
    }
    /**
     * Creates the cell weights: `kernel` of shape [inputDim, 3 * units],
     * `recurrent_kernel` of shape [units, 3 * units] and, when `useBias`,
     * `bias` of shape [3 * units].
     */
    build(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        const inputDim = inputShape[inputShape.length - 1];
        this.kernel = this.addWeight('kernel', [inputDim, this.units * 3], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
        this.recurrentKernel = this.addWeight('recurrent_kernel', [this.units, this.units * 3], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
        if (this.useBias) {
            this.bias = this.addWeight('bias', [this.units * 3], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
        }
        else {
            this.bias = null;
        }
        // Porting Notes: Unlike the PyKeras implementation, we perform slicing
        // of the weights and bias in the call() method, at execution time.
        this.built = true;
    }
    /**
     * Runs one GRU step.
     *
     * @param inputs `[x, hTMinus1]`: the input for this step and the previous
     *   hidden state.
     * @param kwargs Optional. `kwargs.training` is forwarded to dropout-mask
     *   generation; it defaults to false when `kwargs` or the flag is missing.
     * @returns `[h, h]`: the step output and the new hidden state.
     * @throws ValueError if `inputs` does not hold exactly 2 tensors.
     */
    call(inputs, kwargs) {
        return tidy(() => {
            if (inputs.length !== 2) {
                throw new ValueError(`GRUCell expects 2 input Tensors (inputs, h), got ` +
                    `${inputs.length}.`);
            }
            const training = kwargs == null || kwargs['training'] == null ? false : kwargs['training'];
            let x = inputs[0];
            let hTMinus1 = inputs[1]; // Previous memory state.
            // Note: For superior performance, TensorFlow.js always uses
            // implementation 2, regardless of the actual value of
            // config.implementation.
            if (0 < this.dropout && this.dropout < 1 && this.dropoutMask == null) {
                this.dropoutMask = generateDropoutMask({
                    ones: () => tfc.onesLike(x),
                    rate: this.dropout,
                    training,
                    count: 3,
                    dropoutFunc: this.dropoutFunc,
                });
            }
            if (0 < this.recurrentDropout && this.recurrentDropout < 1 &&
                this.recurrentDropoutMask == null) {
                this.recurrentDropoutMask = generateDropoutMask({
                    ones: () => tfc.onesLike(hTMinus1),
                    rate: this.recurrentDropout,
                    training,
                    count: 3,
                    dropoutFunc: this.dropoutFunc,
                });
            }
            const dpMask = this.dropoutMask;
            const recDpMask = this.recurrentDropoutMask;
            // Only mask [0] of each set is applied: the input is multiplied once
            // and then dotted with the full (all-gates) kernel.
            if (0 < this.dropout && this.dropout < 1) {
                x = tfc.mul(x, dpMask[0]);
            }
            let matrixX = K.dot(x, this.kernel.read());
            if (this.useBias) {
                matrixX = K.biasAdd(matrixX, this.bias.read());
            }
            if (0 < this.recurrentDropout && this.recurrentDropout < 1) {
                hTMinus1 = tfc.mul(hTMinus1, recDpMask[0]);
            }
            // The recurrent kernel is split into the z/r columns (rk1) and the
            // candidate-state columns (rk2), because the candidate needs
            // `r * hTMinus1` before its matmul.
            const recurrentKernelValue = this.recurrentKernel.read();
            const [rk1, rk2] = tfc.split(recurrentKernelValue, [2 * this.units, this.units], recurrentKernelValue.rank - 1);
            const matrixInner = K.dot(hTMinus1, rk1);
            const [xZ, xR, xH] = tfc.split(matrixX, 3, matrixX.rank - 1);
            const [recurrentZ, recurrentR] = tfc.split(matrixInner, 2, matrixInner.rank - 1);
            const z = this.recurrentActivation.apply(tfc.add(xZ, recurrentZ));
            const r = this.recurrentActivation.apply(tfc.add(xR, recurrentR));
            const recurrentH = K.dot(tfc.mul(r, hTMinus1), rk2);
            const hh = this.activation.apply(tfc.add(xH, recurrentH));
            // h = z * hTMinus1 + (1 - z) * hh
            const h = tfc.add(tfc.mul(z, hTMinus1), tfc.mul(tfc.add(1, tfc.neg(z)), hh));
            // TODO(cais): Add use_learning_phase flag properly.
            return [h, h];
        });
    }
    /**
     * Serializes the cell configuration on top of the base config.
     * `resetAfter` is always false because only that variant is supported.
     */
    getConfig() {
        const baseConfig = super.getConfig();
        const config = {
            units: this.units,
            activation: serializeActivation(this.activation),
            recurrentActivation: serializeActivation(this.recurrentActivation),
            useBias: this.useBias,
            kernelInitializer: serializeInitializer(this.kernelInitializer),
            recurrentInitializer: serializeInitializer(this.recurrentInitializer),
            biasInitializer: serializeInitializer(this.biasInitializer),
            kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
            recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer),
            biasRegularizer: serializeRegularizer(this.biasRegularizer),
            activityRegularizer: serializeRegularizer(this.activityRegularizer),
            kernelConstraint: serializeConstraint(this.kernelConstraint),
            recurrentConstraint: serializeConstraint(this.recurrentConstraint),
            biasConstraint: serializeConstraint(this.biasConstraint),
            dropout: this.dropout,
            recurrentDropout: this.recurrentDropout,
            implementation: this.implementation,
            resetAfter: false
        };
        return Object.assign(Object.assign({}, baseConfig), config);
    }
}
/** @nocollapse */
GRUCell.className = 'GRUCell';
serialization.registerClass(GRUCell);
/**
 * Gated Recurrent Unit layer: an `RNN` whose cell is a `GRUCell` constructed
 * from the same arguments.
 */
export class GRU extends RNN {
    constructor(args) {
        if (args.implementation === 0) {
            console.warn('`implementation=0` has been deprecated, and now defaults to ' +
                '`implementation=1`. Please update your layer call.');
        }
        args.cell = new GRUCell(args);
        super(args);
        // TODO(cais): Add activityRegularizer.
    }
    /**
     * Disposes any dropout masks the cell kept from a previous call, so that
     * fresh masks are generated for this sequence. It then delegates to
     * `RNN.call` with only `mask`, `training` and `initialState` taken from
     * `kwargs`.
     */
    call(inputs, kwargs) {
        return tidy(() => {
            if (this.cell.dropoutMask != null) {
                tfc.dispose(this.cell.dropoutMask);
                this.cell.dropoutMask = null;
            }
            if (this.cell.recurrentDropoutMask != null) {
                tfc.dispose(this.cell.recurrentDropoutMask);
                this.cell.recurrentDropoutMask = null;
            }
            const mask = kwargs == null ? null : kwargs['mask'];
            const training = kwargs == null ? null : kwargs['training'];
            const initialState = kwargs == null ? null : kwargs['initialState'];
            return super.call(inputs, { mask, training, initialState });
        });
    }
    /**
     * Rebuilds a `GRU` layer from a serialized config. Configs saved with the
     * deprecated `implementation: 0` are upgraded to `implementation: 1`
     * before construction.
     *
     * @nocollapse
     */
    static fromConfig(cls, config) {
        if (config['implementation'] === 0) {
            config['implementation'] = 1;
        }
        return new cls(config);
    }
}
/** @nocollapse */
GRU.className = 'GRU';
serialization.registerClass(GRU);
/**
 * Cell for the LSTM (long short-term memory) layer.
 *
 * Owns the input kernel, the recurrent kernel and an optional bias, and
 * computes a single recurrence step in `call()`. The state is `[h, c]`
 * (memory state and carry state), each of size `units`.
 */
export class LSTMCell extends RNNCell {
    constructor(args) {
        super(args);
        this.DEFAULT_ACTIVATION = 'tanh';
        this.DEFAULT_RECURRENT_ACTIVATION = 'hardSigmoid';
        this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
        this.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal';
        this.DEFAULT_BIAS_INITIALIZER = 'zeros';
        this.units = args.units;
        assertPositiveInteger(this.units, 'units');
        this.activation = getActivation(args.activation === undefined ? this.DEFAULT_ACTIVATION :
            args.activation);
        this.recurrentActivation = getActivation(args.recurrentActivation === undefined ?
            this.DEFAULT_RECURRENT_ACTIVATION :
            args.recurrentActivation);
        this.useBias = args.useBias == null ? true : args.useBias;
        this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER);
        this.recurrentInitializer = getInitializer(args.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER);
        this.biasInitializer =
            getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER);
        this.unitForgetBias = args.unitForgetBias;
        this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
        this.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
        this.biasRegularizer = getRegularizer(args.biasRegularizer);
        this.kernelConstraint = getConstraint(args.kernelConstraint);
        this.recurrentConstraint = getConstraint(args.recurrentConstraint);
        this.biasConstraint = getConstraint(args.biasConstraint);
        // Dropout rates are clamped into [0, 1]; a missing rate means 0.
        this.dropout = math_utils.min([1, math_utils.max([0, args.dropout == null ? 0 : args.dropout])]);
        this.recurrentDropout = math_utils.min([
            1,
            math_utils.max([0, args.recurrentDropout == null ? 0 : args.recurrentDropout])
        ]);
        this.dropoutFunc = args.dropoutFunc;
        this.implementation = args.implementation;
        this.stateSize = [this.units, this.units];
        // Dropout masks are created lazily on the first step; the `LSTM` layer
        // disposes and clears them at the start of each of its `call()`s.
        this.dropoutMask = null;
        this.recurrentDropoutMask = null;
    }
    /**
     * Creates the cell weights: `kernel` of shape [inputDim, 4 * units],
     * `recurrent_kernel` of shape [units, 4 * units] and, when `useBias`,
     * `bias` of shape [4 * units].
     *
     * The bias slices follow the gate order used in `call()`: input, forget,
     * cell candidate, output. With `unitForgetBias`, the forget slice is
     * initialized to ones and the other slices use `biasInitializer`.
     */
    build(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        const inputDim = inputShape[inputShape.length - 1];
        this.kernel = this.addWeight('kernel', [inputDim, this.units * 4], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
        this.recurrentKernel = this.addWeight('recurrent_kernel', [this.units, this.units * 4], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
        if (this.useBias) {
            let biasInitializer;
            if (this.unitForgetBias) {
                const capturedBiasInit = this.biasInitializer;
                const capturedUnits = this.units;
                class CustomInit extends Initializer {
                    apply(shape, dtype) {
                        // Input-gate, forget-gate, and cell+output-gate slices.
                        const bI = capturedBiasInit.apply([capturedUnits]);
                        const bF = (new Ones()).apply([capturedUnits]);
                        const bCAndH = capturedBiasInit.apply([capturedUnits * 2]);
                        return K.concatAlongFirstAxis(K.concatAlongFirstAxis(bI, bF), bCAndH);
                    }
                }
                /** @nocollapse */
                CustomInit.className = 'CustomInit';
                biasInitializer = new CustomInit();
            }
            else {
                biasInitializer = this.biasInitializer;
            }
            this.bias = this.addWeight('bias', [this.units * 4], null, biasInitializer, this.biasRegularizer, true, this.biasConstraint);
        }
        else {
            this.bias = null;
        }
        // Porting Notes: Unlike the PyKeras implementation, we perform slicing
        // of the weights and bias in the call() method, at execution time.
        this.built = true;
    }
    /**
     * Runs one LSTM step.
     *
     * @param inputs `[x, hTMinus1, cTMinus1]`: the input for this step, the
     *   previous memory state and the previous carry state.
     * @param kwargs Optional. `kwargs.training` is forwarded to dropout-mask
     *   generation; it defaults to false when `kwargs` or the flag is missing.
     * @returns `[h, h, c]`: the step output, the new memory state and the new
     *   carry state.
     * @throws ValueError if `inputs` does not hold exactly 3 tensors.
     */
    call(inputs, kwargs) {
        return tidy(() => {
            if (inputs.length !== 3) {
                throw new ValueError(`LSTMCell expects 3 input Tensors (inputs, h, c), got ` +
                    `${inputs.length}.`);
            }
            const training = kwargs == null || kwargs['training'] == null ? false : kwargs['training'];
            let x = inputs[0];
            let hTMinus1 = inputs[1]; // Previous memory state.
            const cTMinus1 = inputs[2]; // Previous carry state.
            if (0 < this.dropout && this.dropout < 1 && this.dropoutMask == null) {
                this.dropoutMask = generateDropoutMask({
                    ones: () => tfc.onesLike(x),
                    rate: this.dropout,
                    training,
                    count: 4,
                    dropoutFunc: this.dropoutFunc,
                });
            }
            if (0 < this.recurrentDropout && this.recurrentDropout < 1 &&
                this.recurrentDropoutMask == null) {
                this.recurrentDropoutMask = generateDropoutMask({
                    ones: () => tfc.onesLike(hTMinus1),
                    rate: this.recurrentDropout,
                    training,
                    count: 4,
                    dropoutFunc: this.dropoutFunc,
                });
            }
            const dpMask = this.dropoutMask;
            const recDpMask = this.recurrentDropoutMask;
            // Note: For superior performance, TensorFlow.js always uses
            // implementation 2 regardless of the actual value of
            // config.implementation. Only mask [0] of each set is applied,
            // because the input and state are each multiplied once before
            // the all-gates matmul.
            if (0 < this.dropout && this.dropout < 1) {
                x = tfc.mul(x, dpMask[0]);
            }
            let z = K.dot(x, this.kernel.read());
            if (0 < this.recurrentDropout && this.recurrentDropout < 1) {
                hTMinus1 = tfc.mul(hTMinus1, recDpMask[0]);
            }
            z = tfc.add(z, K.dot(hTMinus1, this.recurrentKernel.read()));
            if (this.useBias) {
                z = K.biasAdd(z, this.bias.read());
            }
            const [z0, z1, z2, z3] = tfc.split(z, 4, z.rank - 1);
            const i = this.recurrentActivation.apply(z0);
            const f = this.recurrentActivation.apply(z1);
            // c = f * cTMinus1 + i * activation(z2)
            const c = tfc.add(tfc.mul(f, cTMinus1), tfc.mul(i, this.activation.apply(z2)));
            const o = this.recurrentActivation.apply(z3);
            const h = tfc.mul(o, this.activation.apply(c));
            // TODO(cais): Add use_learning_phase flag properly.
            return [h, h, c];
        });
    }
    /** Serializes the cell configuration on top of the base config. */
    getConfig() {
        const baseConfig = super.getConfig();
        const config = {
            units: this.units,
            activation: serializeActivation(this.activation),
            recurrentActivation: serializeActivation(this.recurrentActivation),
            useBias: this.useBias,
            kernelInitializer: serializeInitializer(this.kernelInitializer),
            recurrentInitializer: serializeInitializer(this.recurrentInitializer),
            biasInitializer: serializeInitializer(this.biasInitializer),
            unitForgetBias: this.unitForgetBias,
            kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
            recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer),
            biasRegularizer: serializeRegularizer(this.biasRegularizer),
            activityRegularizer: serializeRegularizer(this.activityRegularizer),
            kernelConstraint: serializeConstraint(this.kernelConstraint),
            recurrentConstraint: serializeConstraint(this.recurrentConstraint),
            biasConstraint: serializeConstraint(this.biasConstraint),
            dropout: this.dropout,
            recurrentDropout: this.recurrentDropout,
            implementation: this.implementation,
        };
        return Object.assign(Object.assign({}, baseConfig), config);
    }
}
/** @nocollapse */
LSTMCell.className = 'LSTMCell';
serialization.registerClass(LSTMCell);
/**
 * Long Short-Term Memory layer: an `RNN` whose cell is an `LSTMCell`
 * constructed from the same arguments.
 */
export class LSTM extends RNN {
    constructor(args) {
        if (args.implementation === 0) {
            console.warn('`implementation=0` has been deprecated, and now defaults to ' +
                '`implementation=1`. Please update your layer call.');
        }
        args.cell = new LSTMCell(args);
        super(args);
        // TODO(cais): Add activityRegularizer.
    }
    /**
     * Disposes any dropout masks the cell kept from a previous call, so that
     * fresh masks are generated for this sequence. It then delegates to
     * `RNN.call` with only `mask`, `training` and `initialState` taken from
     * `kwargs`.
     */
    call(inputs, kwargs) {
        return tidy(() => {
            if (this.cell.dropoutMask != null) {
                tfc.dispose(this.cell.dropoutMask);
                this.cell.dropoutMask = null;
            }
            if (this.cell.recurrentDropoutMask != null) {
                tfc.dispose(this.cell.recurrentDropoutMask);
                this.cell.recurrentDropoutMask = null;
            }
            const mask = kwargs == null ? null : kwargs['mask'];
            const training = kwargs == null ? null : kwargs['training'];
            const initialState = kwargs == null ? null : kwargs['initialState'];
            return super.call(inputs, { mask, training, initialState });
        });
    }
    /**
     * Rebuilds an `LSTM` layer from a serialized config. Configs saved with
     * the deprecated `implementation: 0` are upgraded to
     * `implementation: 1` before construction.
     *
     * @nocollapse
     */
    static fromConfig(cls, config) {
        if (config['implementation'] === 0) {
            config['implementation'] = 1;
        }
        return new cls(config);
    }
}
/** @nocollapse */
LSTM.className = 'LSTM';
serialization.registerClass(LSTM);
/**
 * Wrapper that lets a stack of RNN cells behave like a single cell.
 *
 * The input is passed through the cells in order, each cell's output feeding
 * the next. The states of all cells are exposed as one flat list in reverse
 * cell order.
 */
export class StackedRNNCells extends RNNCell {
    constructor(args) {
        super(args);
        this.cells = args.cells;
    }
    /**
     * Flat list of state sizes, in reverse order of the cell stack.
     *
     * This preserves the requirement `stack.stateSize[0] === outputDim`.
     * E.g., states of a 2-layer LSTM would be `[h2, c2, h1, c1]`, assuming one
     * LSTM has states `[h, c]`.
     */
    get stateSize() {
        const stateSize = [];
        for (const cell of this.cells.slice().reverse()) {
            if (Array.isArray(cell.stateSize)) {
                stateSize.push(...cell.stateSize);
            }
            else {
                stateSize.push(cell.stateSize);
            }
        }
        return stateSize;
    }
    /**
     * Runs every cell on one step.
     *
     * @param inputs `[x, ...states]`, where `states` is flat and in reverse
     *   cell order (matching `stateSize`).
     * @param kwargs Forwarded unchanged to each cell's `call()`.
     * @returns `[output, ...newStates]`, where `output` comes from the last
     *   cell and `newStates` is flat and in reverse cell order.
     */
    call(inputs, kwargs) {
        return tidy(() => {
            let states = inputs.slice(1);
            // Recover per-cell states.
            const nestedStates = [];
            for (const cell of this.cells.slice().reverse()) {
                if (Array.isArray(cell.stateSize)) {
                    nestedStates.push(states.splice(0, cell.stateSize.length));
                }
                else {
                    nestedStates.push(states.splice(0, 1));
                }
            }
            nestedStates.reverse();
            // Call the cells in order and store the returned states.
            const newNestedStates = [];
            let callInputs;
            for (let i = 0; i < this.cells.length; ++i) {
                const cell = this.cells[i];
                states = nestedStates[i];
                // TODO(cais): Take care of constants.
                if (i === 0) {
                    callInputs = [inputs[0]].concat(states);
                }
                else {
                    callInputs = [callInputs[0]].concat(states);
                }
                callInputs = cell.call(callInputs, kwargs);
                newNestedStates.push(callInputs.slice(1));
            }
            // Format the new states as a flat list in reverse cell order.
            states = [];
            for (const cellStates of newNestedStates.slice().reverse()) {
                states.push(...cellStates);
            }
            return [callInputs[0]].concat(states);
        });
    }
    /**
     * Builds each cell in turn. Every cell after the first sees an input
     * shape of `[batch, outputDim of the previous cell]`.
     */
    build(inputShape) {
        if (isArrayOfShapes(inputShape)) {
            // TODO(cais): Take care of input constants.
            // const constantShape = inputShape.slice(1);
            inputShape = inputShape[0];
        }
        let outputDim;
        this.cells.forEach((cell, i) => {
            nameScope(`RNNCell_${i}`, () => {
                // TODO(cais): Take care of input constants.
                cell.build(inputShape);
                if (Array.isArray(cell.stateSize)) {
                    outputDim = cell.stateSize[0];
                }
                else {
                    outputDim = cell.stateSize;
                }
                inputShape = [inputShape[0], outputDim];
            });
        });
        this.built = true;
    }
    /** Serializes the base config plus a `{className, config}` entry per cell. */
    getConfig() {
        const baseConfig = super.getConfig();
        const getCellConfig = (cell) => {
            return {
                'className': cell.getClassName(),
                'config': cell.getConfig(),
            };
        };
        const cellConfigs = this.cells.map(getCellConfig);
        const config = { 'cells': cellConfigs };
        return Object.assign(Object.assign({}, baseConfig), config);
    }
    /** @nocollapse */
    static fromConfig(cls, config, customObjects = {}) {
        const cells = [];
        for (const cellConfig of config['cells']) {
            cells.push(deserialize(cellConfig, customObjects));
        }
        return new cls({ cells });
    }
    get trainableWeights() {
        if (!this.trainable) {
            return [];
        }
        const weights = [];
        for (const cell of this.cells) {
            weights.push(...cell.trainableWeights);
        }
        return weights;
    }
    get nonTrainableWeights() {
        const weights = [];
        for (const cell of this.cells) {
            weights.push(...cell.nonTrainableWeights);
        }
        if (!this.trainable) {
            const trainableWeights = [];
            for (const cell of this.cells) {
                trainableWeights.push(...cell.trainableWeights);
            }
            return trainableWeights.concat(weights);
        }
        return weights;
    }
    /**
     * Retrieve the weights of all cells in the stack.
     *
     * @returns A flat `Array` of `tf.Tensor`s, cell by cell in stack order.
     */
    getWeights() {
        const weights = [];
        for (const cell of this.cells) {
            weights.push(...cell.weights);
        }
        return batchGetValue(weights);
    }
    /**
     * Set the weights of all cells in the stack.
     *
     * The flat `weights` array is consumed cell by cell in stack order, the
     * same layout `getWeights()` produces. The caller's array is not
     * modified.
     *
     * @param weights An `Array` of `tf.Tensor`s with shapes and types matching
     *   the output of `getWeights()`.
     * @throws ValueError if the number of tensors does not match the total
     *   number of cell weights.
     */
    setWeights(weights) {
        const expected = this.cells.reduce((total, cell) => total + cell.weights.length, 0);
        if (weights.length !== expected) {
            throw new ValueError(`StackedRNNCells.setWeights() expected ${expected} weight ` +
                `tensors, but received ${weights.length}.`);
        }
        const tuples = [];
        let offset = 0;
        for (const cell of this.cells) {
            const cellWeights = cell.weights;
            // Take this cell's slice of the flat list; the previous
            // `splice(numParams)` returned the *tail* of the array instead.
            for (let i = 0; i < cellWeights.length; ++i) {
                tuples.push([cellWeights[i], weights[offset + i]]);
            }
            offset += cellWeights.length;
        }
        batchSetValue(tuples);
    }
}
/** @nocollapse */
StackedRNNCells.className = 'StackedRNNCells';
serialization.registerClass(StackedRNNCells);
/**
 * Builds one or more dropout masks shaped like the tensor `ones()` returns.
 *
 * In the training phase, a mask is the dropout of a fresh tensor of ones. It
 * uses `dropoutFunc` when one is supplied and `K.dropout` otherwise. Outside
 * training, the mask is the tensor of ones itself. Each returned mask is a
 * clone passed through `tfc.keep`, so it survives an enclosing `tidy()`.
 *
 * @param args.ones Thunk that produces a tensor of ones with the mask shape.
 * @param args.rate Dropout rate.
 * @param args.training Phase flag forwarded to `K.inTrainPhase`; defaults
 *   to false when omitted.
 * @param args.count Number of masks to build. A falsy value or anything
 *   <= 1 yields a single tensor rather than an array.
 * @param args.dropoutFunc Optional replacement for `K.dropout`.
 * @returns A kept tensor when `count <= 1`, otherwise an array of `count`
 *   kept tensors.
 */
export function generateDropoutMask(args) {
    const { ones, rate, training = false, count = 1, dropoutFunc } = args;
    const dropFromOnes = () => {
        const base = ones();
        if (dropoutFunc == null) {
            return K.dropout(base, rate);
        }
        return dropoutFunc(base, rate);
    };
    const buildMask = () => K.inTrainPhase(dropFromOnes, ones, training);
    const wantsSingleMask = !count || count <= 1;
    if (wantsSingleMask) {
        const lone = buildMask();
        return tfc.keep(lone.clone());
    }
    // All raw masks are built first, then each one is cloned and kept.
    const rawMasks = Array.from(Array(count), () => buildMask());
    return rawMasks.map((raw) => tfc.keep(raw.clone()));
}
1346//# sourceMappingURL=data:application/json;base64,
\No newline at end of file