UNPKG

67.9 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
/**
 * Layers that augment the functionality of a base layer.
 */
13import * as tfc from '@tensorflow/tfjs-core';
14import { serialization, tidy } from '@tensorflow/tfjs-core';
15import * as K from '../backend/tfjs_backend';
16import { nameScope } from '../common';
17import { InputSpec, Layer, SymbolicTensor } from '../engine/topology';
18import { NotImplementedError, ValueError } from '../errors';
19import { VALID_BIDIRECTIONAL_MERGE_MODES } from '../keras_format/common';
20import * as generic_utils from '../utils/generic_utils';
21import { getExactlyOneShape, getExactlyOneTensor } from '../utils/types_utils';
22import { rnn, standardizeArgs } from './recurrent';
23import { deserialize } from './serialization';
/**
 * Abstract wrapper base class.
 *
 * Wrappers take another layer and augment it in various ways.
 * Do not use this class as a layer, it is only an abstract base class.
 * Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers.
 *
 * The wrapped layer is stored in `this.layer`; most weight, trainability and
 * loss accessors below simply delegate to it.
 */
export class Wrapper extends Layer {
  /**
   * @param args Layer constructor arguments; `args.layer` is the layer to
   *   wrap.
   */
  constructor(args) {
    // Porting Note: In PyKeras, `self.layer` is set prior to the calling
    // `super()`. But we can't do that here due to TypeScript's restriction.
    // See: https://github.com/Microsoft/TypeScript/issues/8277
    // As a result, we have to add checks in `get trainable()` and
    // `set trainable()` below in order to prevent using `this.layer` when
    // its value is `undefined`. The super constructor does use the getter
    // and the setter of `this.layer`.
    super(args);
    this.layer = args.layer;
  }

  /**
   * Marks this wrapper as built. This base implementation does not build
   * `this.layer`.
   */
  build(inputShape) {
    this.built = true;
  }

  // TODO(cais): Implement activityRegularizer getter.

  /**
   * Trainability of the wrapped layer. Reads as `false` while `this.layer` is
   * still unset (i.e. while the base-class constructor is running).
   */
  get trainable() {
    // Porting Note: the check of `this.layer` here is necessary due to the
    // way the `constructor` of this class is written (see Porting Note
    // above).
    if (this.layer != null) {
      return this.layer.trainable;
    }
    return false;
  }

  /** Forwards to the wrapped layer; a no-op while `this.layer` is unset. */
  set trainable(value) {
    // Porting Note: the check of `this.layer` here is necessary due to the
    // way the `constructor` of this class is written (see Porting Note
    // above).
    if (this.layer != null) {
      this.layer.trainable = value;
    }
  }

  /** Trainable weights of the wrapped layer. */
  get trainableWeights() {
    return this.layer.trainableWeights;
  }

  // TODO(cais): Implement setter for trainableWeights.

  /** Non-trainable weights of the wrapped layer. */
  get nonTrainableWeights() {
    return this.layer.nonTrainableWeights;
  }

  // TODO(cais): Implement setter for nonTrainableWeights.

  /** The wrapped layer's `_updates` (reads a non-public field of `layer`). */
  get updates() {
    // tslint:disable-next-line:no-any
    return this.layer._updates;
  }

  // TODO(cais): Implement getUpdatesFor().

  /** Losses of the wrapped layer. */
  get losses() {
    return this.layer.losses;
  }

  // TODO(cais): Implement getLossesFor().

  /** Weight values of the wrapped layer. */
  getWeights() {
    return this.layer.getWeights();
  }

  /** Sets the weight values of the wrapped layer. */
  setWeights(weights) {
    this.layer.setWeights(weights);
  }

  /**
   * Serializable config: the base Layer config plus a nested `layer` entry
   * holding the wrapped layer's class name and config. Keys from the base
   * config are assigned last.
   */
  getConfig() {
    const config = {
      'layer': {
        'className': this.layer.getClassName(),
        'config': this.layer.getConfig(),
      }
    };
    const baseConfig = super.getConfig();
    Object.assign(config, baseConfig);
    return config;
  }

  /** Applies the setting to this wrapper and, if present, the wrapped layer. */
  setFastWeightInitDuringBuild(value) {
    super.setFastWeightInitDuringBuild(value);
    if (this.layer != null) {
      this.layer.setFastWeightInitDuringBuild(value);
    }
  }

  // Reconstructs a wrapper from `config`. `config['layer']` is deserialized
  // (using `customObjects`) into the wrapped layer; every other key is passed
  // through to the constructor unchanged. The caller's `config` object is
  // copied, not modified, so the same config can be deserialized again.
  /** @nocollapse */
  static fromConfig(cls, config, customObjects = {}) {
    const layer = deserialize(config['layer'], customObjects);
    const remaining = Object.assign({}, config);
    delete remaining['layer'];
    const newConfig = Object.assign({ layer }, remaining);
    return new cls(newConfig);
  }
}
/**
 * Applies the wrapped layer separately to each temporal slice of its input.
 *
 * Axis 1 of the input is treated as the time axis: the wrapped layer is built
 * and called on shapes with that axis removed, and the per-step results are
 * stacked back along axis 1.
 */
export class TimeDistributed extends Wrapper {
  constructor(args) {
    super(args);
    this.supportsMasking = true;
  }

  /**
   * Validates a rank >= 3 input shape, records it in `inputSpec`, and builds
   * the wrapped layer (if not already built) on the shape without axis 1.
   * @throws ValueError if the input shape has fewer than 3 dimensions.
   */
  build(inputShape) {
    const shape = getExactlyOneShape(inputShape);
    if (shape.length < 3) {
      throw new ValueError(`TimeDistributed layer expects an input shape >= 3D, but received ` +
          `input shape ${JSON.stringify(shape)}`);
    }
    this.inputSpec = [{ shape }];
    if (!this.layer.built) {
      const [batchDim, , ...featureDims] = shape;
      this.layer.build([batchDim, ...featureDims]);
      this.layer.built = true;
    }
    super.build(shape);
  }

  /**
   * Output shape: the wrapped layer's output shape for a single timestep,
   * with the input's time dimension re-inserted at axis 1.
   */
  computeOutputShape(inputShape) {
    const shape = getExactlyOneShape(inputShape);
    const [batchDim, timesteps, ...featureDims] = shape;
    const [outBatch, ...outRest] =
        this.layer.computeOutputShape([batchDim, ...featureDims]);
    return [outBatch, timesteps, ...outRest];
  }

  /**
   * Runs the wrapped layer once per timestep via `rnn` and returns the
   * stacked per-step outputs.
   */
  call(inputs, kwargs) {
    return tidy(() => {
      // TODO(cais): Add 'training' and 'useLearningPhase' to kwargs.
      const sequence = getExactlyOneTensor(inputs);
      // Porting Note: In tfjs-layers, `inputs` are always concrete tensor
      // values. Hence the inputs can't have an undetermined first (batch)
      // dimension, which is why we always use the K.rnn approach here.
      const stepFn = (stepInput, stepStates) => {
        // TODO(cais): Add useLearningPhase.
        // `layer.call` may hand back a length-1 array of Tensor (e.g. when the
        // wrapped layer is a `Sequential`), hence `getExactlyOneTensor`.
        const stepOutput =
            getExactlyOneTensor(this.layer.call(stepInput, kwargs));
        return [stepOutput, []];
      };
      const [, perStepOutputs] = rnn(
          stepFn, sequence, [], false /* goBackwards */, null /* mask */,
          null /* constants */, false /* unroll */,
          true /* needPerStepOutputs */);
      // TODO(cais): Add activity regularization.
      // TODO(cais): Add useLearningPhase.
      return perStepOutputs;
    });
  }
}
/** @nocollapse */
TimeDistributed.className = 'TimeDistributed';
serialization.registerClass(TimeDistributed);
/**
 * Validates `value` as a Bidirectional merge mode by delegating to
 * `generic_utils.checkStringTypeUnionValue` with the allowed values in
 * `VALID_BIDIRECTIONAL_MERGE_MODES`.
 */
export function checkBidirectionalMergeMode(value) {
  const unionTypeName = 'BidirectionalMergeMode';
  generic_utils.checkStringTypeUnionValue(
      VALID_BIDIRECTIONAL_MERGE_MODES, unionTypeName, value);
}
/** Merge mode used by `Bidirectional` when `args.mergeMode` is undefined. */
const DEFAULT_BIDIRECTIONAL_MERGE_MODE = 'concat';
/**
 * Bidirectional wrapper for recurrent layers.
 *
 * Two copies of the wrapped layer are created from its config:
 * `forwardLayer` keeps the wrapped layer's `goBackwards` setting and
 * `backwardLayer` has it inverted. Their outputs are combined according to
 * `mergeMode`: 'concat' (the default), 'sum', 'ave', 'mul', or `null` to
 * return both outputs separately.
 */
export class Bidirectional extends Wrapper {
  /**
   * @param args Wrapper arguments. `args.layer` is the layer to wrap;
   *   `args.mergeMode` selects how the two directions are combined.
   * @throws NotImplementedError if `args.weights` is provided.
   */
  constructor(args) {
    super(args);
    // Note: When creating `this.forwardLayer`, the original Layer object
    // (`config.layer`) ought to be cloned. This is why we call
    // `getConfig()` followed by `deserialize()`. Without this cloning,
    // the layer names saved during serialization will incorrectly contain
    // the 'forward_' prefix. In Python Keras, this is done using
    // `copy.copy` (shallow copy), which does not have a simple equivalent
    // in JavaScript. JavaScript's `Object.assign()` does not copy
    // methods.
    const className = args.layer.getClassName();
    const forwardConfig = args.layer.getConfig();
    this.forwardLayer = deserialize({ className, config: forwardConfig });
    // The backward copy gets its own (shallow) config object with
    // `goBackwards` inverted, instead of mutating the object the forward
    // layer was deserialized from.
    const backwardConfig = Object.assign({}, forwardConfig);
    backwardConfig['goBackwards'] = forwardConfig['goBackwards'] !== true;
    this.backwardLayer = deserialize({ className, config: backwardConfig });
    this.forwardLayer.name = 'forward_' + this.forwardLayer.name;
    this.backwardLayer.name = 'backward_' + this.backwardLayer.name;
    // Only `undefined` selects the default: an explicit `null` is kept and
    // means "return both outputs unmerged" (see call()).
    this.mergeMode = args.mergeMode === undefined ?
        DEFAULT_BIDIRECTIONAL_MERGE_MODE :
        args.mergeMode;
    checkBidirectionalMergeMode(this.mergeMode);
    if (args.weights) {
      throw new NotImplementedError('weights support is not implemented for Bidirectional layer yet.');
    }
    this._stateful = args.layer.stateful;
    this.returnSequences = args.layer.returnSequences;
    this.returnState = args.layer.returnState;
    this.supportsMasking = true;
    this._trainable = true;
    this.inputSpec = args.layer.inputSpec;
    this.numConstants = null;
  }

  get trainable() {
    return this._trainable;
  }

  /** Stores the flag and propagates it to both direction sub-layers. */
  set trainable(value) {
    // The base-class constructor may invoke this setter before
    // `forwardLayer`/`backwardLayer` are assigned, hence the null checks.
    this._trainable = value;
    if (this.forwardLayer != null) {
      this.forwardLayer.trainable = value;
    }
    if (this.backwardLayer != null) {
      this.backwardLayer.trainable = value;
    }
  }

  /** Forward-layer weights followed by backward-layer weights. */
  getWeights() {
    return this.forwardLayer.getWeights().concat(this.backwardLayer.getWeights());
  }

  /**
   * Inverse of getWeights(): the first `floor(length / 2)` entries go to the
   * forward layer, the rest to the backward layer.
   */
  setWeights(weights) {
    const half = Math.floor(weights.length / 2);
    this.forwardLayer.setWeights(weights.slice(0, half));
    this.backwardLayer.setWeights(weights.slice(half));
  }

  /**
   * Output shape(s) derived from the forward layer's output shape(s). With
   * 'concat' the last dimension is doubled; with a `null` merge mode the
   * output shape appears twice. When `returnState` is set, the state shapes
   * are appended once per direction.
   */
  computeOutputShape(inputShape) {
    let layerShapes = this.forwardLayer.computeOutputShape(inputShape);
    // Normalize a single shape into a list of shapes.
    if (!(Array.isArray(layerShapes) && Array.isArray(layerShapes[0]))) {
      layerShapes = [layerShapes];
    }
    // Copy so the in-place doubling below cannot alter the forward layer's
    // returned shape.
    const outputShape = layerShapes[0].slice();
    const stateShape = this.returnState ? layerShapes.slice(1) : undefined;
    let outputShapes;
    if (this.mergeMode === 'concat') {
      outputShape[outputShape.length - 1] *= 2;
      outputShapes = [outputShape];
    }
    else if (this.mergeMode == null) {
      outputShapes = [outputShape, outputShape.slice()];
    }
    else {
      outputShapes = [outputShape];
    }
    if (this.returnState) {
      if (this.mergeMode == null) {
        return outputShapes.concat(stateShape).concat(stateShape.slice());
      }
      return [outputShape].concat(stateShape).concat(stateShape.slice());
    }
    return generic_utils.singletonOrArray(outputShapes);
  }

  /**
   * Applies the layer, handling initial states passed either via
   * `kwargs.initialState` or as extra entries of an `inputs` array.
   * Initial states must contain an even number of entries: the first half
   * feeds `forwardLayer`, the second half `backwardLayer`.
   * @throws ValueError on an odd number of states or a mix of symbolic and
   *   concrete states.
   * @throws NotImplementedError if `constants` are supplied.
   */
  apply(inputs, kwargs) {
    let initialState = kwargs == null ? null : kwargs['initialState'];
    let constants = kwargs == null ? null : kwargs['constants'];
    if (kwargs == null) {
      kwargs = {};
    }
    const standardized = standardizeArgs(inputs, initialState, constants, this.numConstants);
    inputs = standardized.inputs;
    initialState = standardized.initialState;
    constants = standardized.constants;
    if (Array.isArray(inputs)) {
      initialState = inputs.slice(1);
      inputs = inputs[0];
    }
    if ((initialState == null || initialState.length === 0) &&
        constants == null) {
      return super.apply(inputs, kwargs);
    }
    const additionalInputs = [];
    const additionalSpecs = [];
    if (initialState != null) {
      const numStates = initialState.length;
      if (numStates % 2 > 0) {
        throw new ValueError('When passing `initialState` to a Bidirectional RNN, ' +
            'the state should be an Array containing the states of ' +
            'the underlying RNNs.');
      }
      kwargs['initialState'] = initialState;
      additionalInputs.push(...initialState);
      const stateSpecs = initialState
          .map(state => new InputSpec({ shape: state.shape }));
      this.forwardLayer.stateSpec = stateSpecs.slice(0, numStates / 2);
      this.backwardLayer.stateSpec = stateSpecs.slice(numStates / 2);
      additionalSpecs.push(...stateSpecs);
    }
    if (constants != null) {
      throw new NotImplementedError('Support for constants in Bidirectional layers is not ' +
          'implemented yet.');
    }
    const isSymbolicTensor = additionalInputs[0] instanceof SymbolicTensor;
    for (const tensor of additionalInputs) {
      if (tensor instanceof SymbolicTensor !== isSymbolicTensor) {
        throw new ValueError('The initial state of a Bidirectional layer cannot be ' +
            'specified as a mix of symbolic and non-symbolic tensors');
      }
    }
    if (isSymbolicTensor) {
      // Compute the full input and specs, including the states.
      const fullInput = [inputs].concat(additionalInputs);
      const fullInputSpec = this.inputSpec.concat(additionalSpecs);
      // Perform the call temporarily and replace inputSpec.
      // Note: with initial states symbolic calls and non-symbolic calls to
      // this method differ in how the initial states are passed. For
      // symbolic calls, the initial states are passed in the first arg, as
      // an Array of SymbolicTensors; for non-symbolic calls, they are
      // passed in the second arg as a part of the kwargs. Hence the need to
      // temporarily modify inputSpec here.
      // TODO(cais): Make refactoring so that this hacky code below is no
      // longer needed.
      const originalInputSpec = this.inputSpec;
      this.inputSpec = fullInputSpec;
      const output = super.apply(fullInput, kwargs);
      this.inputSpec = originalInputSpec;
      return output;
    }
    else {
      return super.apply(inputs, kwargs);
    }
  }

  /**
   * Runs both direction layers and merges their outputs per `mergeMode`.
   * If `kwargs.initialState` is set, its first half is given to the forward
   * layer and its second half to the backward layer. The caller's `kwargs`
   * object is never modified.
   */
  call(inputs, kwargs) {
    return tidy(() => {
      const initialState = kwargs['initialState'];
      let y;
      let yRev;
      if (initialState == null) {
        y = this.forwardLayer.call(inputs, kwargs);
        yRev = this.backwardLayer.call(inputs, kwargs);
      }
      else {
        const half = initialState.length / 2;
        const forwardState = initialState.slice(0, half);
        const backwardState = initialState.slice(half);
        // Give each sub-layer its own kwargs copy. Assigning into `kwargs`
        // itself would leave `kwargs.initialState` set to the backward half,
        // so a reused kwargs object would be split again on the next call.
        y = this.forwardLayer.call(inputs, Object.assign({}, kwargs, { initialState: forwardState }));
        yRev = this.backwardLayer.call(inputs, Object.assign({}, kwargs, { initialState: backwardState }));
      }
      let states;
      if (this.returnState) {
        // Each sub-layer returns [output, ...states]; collect the states of
        // both directions.
        if (Array.isArray(y)) {
          states = y.slice(1).concat(yRev.slice(1));
        }
        y = y[0];
        yRev = yRev[0];
      }
      if (this.returnSequences) {
        // Flip the backward layer's sequence output along axis 1 (assumed to
        // be the time axis) so it lines up with the forward output.
        yRev = tfc.reverse(yRev, 1);
      }
      let output;
      if (this.mergeMode === 'concat') {
        output = K.concatenate([y, yRev]);
      }
      else if (this.mergeMode === 'sum') {
        output = tfc.add(y, yRev);
      }
      else if (this.mergeMode === 'ave') {
        output = tfc.mul(.5, tfc.add(y, yRev));
      }
      else if (this.mergeMode === 'mul') {
        output = tfc.mul(y, yRev);
      }
      else if (this.mergeMode == null) {
        output = [y, yRev];
      }
      // TODO(cais): Properly set learning phase.
      if (this.returnState) {
        if (this.mergeMode == null) {
          return output.concat(states);
        }
        return [output].concat(states);
      }
      return output;
    });
  }

  /** Resets both sub-layers. The `states` argument is currently ignored. */
  resetStates(states) {
    this.forwardLayer.resetStates();
    this.backwardLayer.resetStates();
  }

  /** Builds each sub-layer inside a name scope named after that sub-layer. */
  build(inputShape) {
    nameScope(this.forwardLayer.name, () => {
      this.forwardLayer.build(inputShape);
    });
    nameScope(this.backwardLayer.name, () => {
      this.backwardLayer.build(inputShape);
    });
    this.built = true;
  }

  /**
   * Output mask(s): the (first) input mask is propagated only when
   * `returnSequences` is set, duplicated for a `null` merge mode, and
   * followed by `null` masks for each direction's states when `returnState`
   * is set.
   */
  computeMask(inputs, mask) {
    if (Array.isArray(mask)) {
      mask = mask[0];
    }
    let outputMask;
    if (this.returnSequences) {
      if (this.mergeMode == null) {
        outputMask = [mask, mask];
      }
      else {
        outputMask = mask;
      }
    }
    else {
      if (this.mergeMode == null) {
        outputMask = [null, null];
      }
      else {
        outputMask = null;
      }
    }
    if (this.returnState) {
      const states = this.forwardLayer.states;
      const stateMask = states.map(state => null);
      if (Array.isArray(outputMask)) {
        return outputMask.concat(stateMask).concat(stateMask);
      }
      else {
        return [outputMask].concat(stateMask).concat(stateMask);
      }
    }
    else {
      return outputMask;
    }
  }

  /** Forward-layer trainable weights followed by backward-layer ones. */
  get trainableWeights() {
    return this.forwardLayer.trainableWeights.concat(this.backwardLayer.trainableWeights);
  }

  /** Forward-layer non-trainable weights followed by backward-layer ones. */
  get nonTrainableWeights() {
    return this.forwardLayer.nonTrainableWeights.concat(this.backwardLayer.nonTrainableWeights);
  }

  // TODO(cais): Implement constraints().

  /** Applies the setting to this wrapper and both sub-layers (if present). */
  setFastWeightInitDuringBuild(value) {
    super.setFastWeightInitDuringBuild(value);
    if (this.forwardLayer != null) {
      this.forwardLayer.setFastWeightInitDuringBuild(value);
    }
    if (this.backwardLayer != null) {
      this.backwardLayer.setFastWeightInitDuringBuild(value);
    }
  }

  /** Wrapper config plus `mergeMode`; keys from the base config win. */
  getConfig() {
    const config = {
      'mergeMode': this.mergeMode,
    };
    // TODO(cais): Add logic for `numConstants` once the property is added.
    const baseConfig = super.getConfig();
    Object.assign(config, baseConfig);
    return config;
  }

  // Reconstructs a Bidirectional layer from `config`. `customObjects` is
  // accepted for parity with `Wrapper.fromConfig` and forwarded when
  // deserializing the wrapped layer. The caller's `config` is not modified,
  // and the unsupported-`numConstants` check runs before any layer is
  // created.
  /** @nocollapse */
  static fromConfig(cls, config, customObjects = {}) {
    // TODO(cais): Add logic for `numConstants` once the property is added.
    if (config['numConstants'] != null) {
      throw new NotImplementedError(`Deserialization of a Bidirectional layer with numConstants ` +
          `present is not supported yet.`);
    }
    const rnnLayer = deserialize(config['layer'], customObjects);
    // tslint:disable-next-line:no-any
    const newConfig = Object.assign({}, config);
    newConfig['layer'] = rnnLayer;
    return new cls(newConfig);
  }
}
/** @nocollapse */
Bidirectional.className = 'Bidirectional';
serialization.registerClass(Bidirectional);
484//# sourceMappingURL=data:application/json;base64,
\No newline at end of file