UNPKG

23.2 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2018 Google LLC
4 *
5 * Use of this source code is governed by an MIT-style
6 * license that can be found in the LICENSE file or at
7 * https://opensource.org/licenses/MIT.
8 * =============================================================================
9 */
10/**
11 * TensorFlow.js Layers: Embedding Layer.
12 *
13 * Original source: keras/layers/embeddings.py
14 */
15import { notEqual, reshape, serialization, tidy, zerosLike } from '@tensorflow/tfjs-core';
16import * as K from '../backend/tfjs_backend';
17import { getConstraint, serializeConstraint } from '../constraints';
18import { Layer } from '../engine/topology';
19import { ValueError } from '../errors';
20import { getInitializer, serializeInitializer } from '../initializers';
21import { getRegularizer, serializeRegularizer } from '../regularizers';
22import * as generic_utils from '../utils/generic_utils';
23import { getExactlyOneShape, getExactlyOneTensor } from '../utils/types_utils';
/**
 * Embedding layer: maps integer indices to dense vectors.
 *
 * Holds a single weight matrix `embeddings` of shape `[inputDim, outputDim]`.
 * Each input index selects one row of that matrix, so the output has the
 * input's shape with `outputDim` appended as a trailing axis.
 */
export class Embedding extends Layer {
  /**
   * @param {Object} args Layer arguments, also forwarded to `Layer`. Fields
   *   read here: `batchInputShape`, `inputShape`, `batchSize`, `inputLength`
   *   (a number or an array of numbers/nulls), `inputDim`, `outputDim`,
   *   `embeddingsInitializer`, `embeddingsRegularizer`, `activityRegularizer`,
   *   `embeddingsConstraint`, `maskZero`.
   */
  constructor(args) {
    super(args);
    this.embeddings = null;
    this.DEFAULT_EMBEDDINGS_INITIALIZER = 'randomUniform';
    if (args.batchInputShape == null && args.inputShape == null) {
      // Porting Note: This logic is copied from Layer's constructor, since we
      // can't do exactly what the Python constructor does for Embedding().
      // Specifically, the super constructor can not be called after the
      // mutation of the `config` argument.
      const batchSize = args.batchSize != null ? args.batchSize : null;
      if (args.inputLength == null) {
        // Fix super-constructor to what it would have done if
        // 'config.inputShape' were (None, )
        this.batchInputShape = [batchSize, null];
      } else {
        // Fix super-constructor to what it would have done if
        // 'config.inputShape' were (config.inputLength, )
        this.batchInputShape =
            [batchSize].concat(generic_utils.toList(args.inputLength));
      }
    }
    this.inputDim = args.inputDim;
    generic_utils.assertPositiveInteger(this.inputDim, 'inputDim');
    this.outputDim = args.outputDim;
    generic_utils.assertPositiveInteger(this.outputDim, 'outputDim');
    this.embeddingsInitializer = getInitializer(
        args.embeddingsInitializer || this.DEFAULT_EMBEDDINGS_INITIALIZER);
    this.embeddingsRegularizer = getRegularizer(args.embeddingsRegularizer);
    this.activityRegularizer = getRegularizer(args.activityRegularizer);
    this.embeddingsConstraint = getConstraint(args.embeddingsConstraint);
    this.maskZero = args.maskZero;
    // Masking is only produced (see computeMask) when maskZero is set.
    this.supportsMasking = args.maskZero;
    this.inputLength = args.inputLength;
  }

  /**
   * Creates the `[inputDim, outputDim]` embeddings weight. The shape does not
   * depend on `inputShape`, which is therefore unused.
   */
  build(inputShape) {
    // NOTE(review): the positional `true` is presumably the `trainable` flag
    // of Layer.addWeight — confirm against its signature.
    this.embeddings = this.addWeight(
        'embeddings', [this.inputDim, this.outputDim], this.dtype,
        this.embeddingsInitializer, this.embeddingsRegularizer, true,
        this.embeddingsConstraint);
    this.built = true;
  }

  // Override warnOnIncompatibleInputShape because an embedding layer allows
  // the input to have varying ranks.
  warnOnIncompatibleInputShape(inputShape) { }

  /**
   * Returns `null` when `maskZero` is falsy. Otherwise returns an element-wise
   * `inputs != 0` tensor, so index 0 is treated as padding. The incoming
   * `mask` is ignored.
   */
  computeMask(inputs, mask) {
    return tidy(() => {
      if (!this.maskZero) {
        return null;
      }
      inputs = getExactlyOneTensor(inputs);
      return notEqual(inputs, zerosLike(inputs));
    });
  }

  /**
   * Computes the output shape for a given input shape.
   *
   * Without `inputLength`, the result is `[...inputShape, outputDim]`.
   * With `inputLength`, its entries must match the non-batch input
   * dimensions. Non-null entries must agree with the input. Null entries
   * take the input's size for that dimension.
   *
   * @throws {ValueError} If the rank or any fixed length does not match.
   */
  computeOutputShape(inputShape) {
    inputShape = getExactlyOneShape(inputShape);
    if (this.inputLength == null) {
      return [...inputShape, this.outputDim];
    }
    const mismatch = () => new ValueError(
        `"inputLength" is ${this.inputLength}, but received ` +
        `input shape has shape ${inputShape}`);
    // inputLength can be an array if input is 3D or higher. Copy it:
    // toList may hand back the stored array itself, and null entries are
    // filled in below. Without the copy, that fill-in would leak into
    // this.inputLength and break later calls with different lengths.
    const inLens = [...generic_utils.toList(this.inputLength)];
    if (inLens.length !== inputShape.length - 1) {
      throw mismatch();
    }
    for (let k = 0; k < inLens.length; ++k) {
      const expected = inLens[k];
      const actual = inputShape[k + 1];
      if (expected != null && actual != null && expected !== actual) {
        throw mismatch();
      }
      if (expected == null) {
        inLens[k] = actual;
      }
    }
    return [inputShape[0], ...inLens, this.outputDim];
  }

  /**
   * Looks up one embedding row per input index.
   *
   * Non-int32 inputs are cast to int32. Indices are flattened, rows are
   * gathered from the embeddings matrix, and the result is reshaped to
   * `computeOutputShape(input.shape)`.
   */
  call(inputs, kwargs) {
    return tidy(() => {
      this.invokeCallHook(inputs, kwargs);
      // Embedding layer accepts only a single input.
      let input = getExactlyOneTensor(inputs);
      if (input.dtype !== 'int32') {
        input = K.cast(input, 'int32');
      }
      const output =
          K.gather(this.embeddings.read(), reshape(input, [input.size]));
      return reshape(
          output, getExactlyOneShape(this.computeOutputShape(input.shape)));
    });
  }

  /**
   * Serializes this layer's arguments. Keys from the base `Layer` config are
   * merged in last and win on collision.
   */
  getConfig() {
    const config = {
      inputDim: this.inputDim,
      outputDim: this.outputDim,
      embeddingsInitializer: serializeInitializer(this.embeddingsInitializer),
      embeddingsRegularizer: serializeRegularizer(this.embeddingsRegularizer),
      activityRegularizer: serializeRegularizer(this.activityRegularizer),
      embeddingsConstraint: serializeConstraint(this.embeddingsConstraint),
      maskZero: this.maskZero,
      inputLength: this.inputLength
    };
    const baseConfig = super.getConfig();
    Object.assign(config, baseConfig);
    return config;
  }
}
// Serialization name. It is assigned before registerClass, which presumably
// reads it; @nocollapse keeps the property name stable under minification.
/** @nocollapse */
Embedding.className = 'Embedding';
serialization.registerClass(Embedding);
139//# sourceMappingURL=data:application/json;base64,
\No newline at end of file