1 | /**
|
2 | * @license
|
3 | * Copyright 2018 Google LLC
|
4 | *
|
5 | * Use of this source code is governed by an MIT-style
|
6 | * license that can be found in the LICENSE file or at
|
7 | * https://opensource.org/licenses/MIT.
|
8 | * =============================================================================
|
9 | */
|
10 | /**
|
11 | * TensorFlow.js Layers: Embedding Layer.
|
12 | *
|
13 | * Original source: keras/layers/embeddings.py
|
14 | */
|
15 | import { notEqual, reshape, serialization, tidy, zerosLike } from '@tensorflow/tfjs-core';
|
16 | import * as K from '../backend/tfjs_backend';
|
17 | import { getConstraint, serializeConstraint } from '../constraints';
|
18 | import { Layer } from '../engine/topology';
|
19 | import { ValueError } from '../errors';
|
20 | import { getInitializer, serializeInitializer } from '../initializers';
|
21 | import { getRegularizer, serializeRegularizer } from '../regularizers';
|
22 | import * as generic_utils from '../utils/generic_utils';
|
23 | import { getExactlyOneShape, getExactlyOneTensor } from '../utils/types_utils';
|
/**
 * Embedding layer.
 *
 * Turns integer indices into dense vectors by looking each index up as a row
 * of a trainable `[inputDim, outputDim]` weight matrix (`this.embeddings`).
 * Inputs of varying rank are accepted; the output shape is the input shape
 * (with unknown dimensions filled from `inputLength` where it is given) plus a
 * trailing `outputDim` axis.
 */
export class Embedding extends Layer {
  /**
   * @param args Layer configuration. `inputDim` and `outputDim` are validated
   *     with `generic_utils.assertPositiveInteger`. When neither
   *     `batchInputShape` nor `inputShape` is supplied, `batchInputShape` is
   *     derived from `batchSize` and `inputLength`.
   */
  constructor(args) {
    super(args);
    this.embeddings = null;
    this.DEFAULT_EMBEDDINGS_INITIALIZER = 'randomUniform';
    if (args.batchInputShape == null && args.inputShape == null) {
      // Porting Note: This logic is copied from Layer's constructor, since we
      // can't do exactly what the Python constructor does for Embedding().
      // Specifically, the super constructor can not be called after the
      // mutation of the `config` argument.
      let batchSize = null;
      if (args.batchSize != null) {
        batchSize = args.batchSize;
      }
      if (args.inputLength == null) {
        // Fix super-constructor to what it would have done if
        // 'config.inputShape' were (None, )
        this.batchInputShape = [batchSize, null];
      } else {
        // Fix super-constructor to what it would have done if
        // 'config.inputShape' were (config.inputLength, )
        // (concat builds a new array, so args.inputLength is not aliased.)
        this.batchInputShape =
            [batchSize].concat(generic_utils.toList(args.inputLength));
      }
    }
    this.inputDim = args.inputDim;
    generic_utils.assertPositiveInteger(this.inputDim, 'inputDim');
    this.outputDim = args.outputDim;
    generic_utils.assertPositiveInteger(this.outputDim, 'outputDim');
    this.embeddingsInitializer = getInitializer(
        args.embeddingsInitializer || this.DEFAULT_EMBEDDINGS_INITIALIZER);
    this.embeddingsRegularizer = getRegularizer(args.embeddingsRegularizer);
    this.activityRegularizer = getRegularizer(args.activityRegularizer);
    this.embeddingsConstraint = getConstraint(args.embeddingsConstraint);
    this.maskZero = args.maskZero;
    // A mask is only produced by computeMask() when maskZero is set.
    this.supportsMasking = args.maskZero;
    this.inputLength = args.inputLength;
  }

  /**
   * Creates the `[inputDim, outputDim]` embeddings weight. `inputShape` is
   * unused: the weight size depends only on `inputDim` and `outputDim`.
   */
  build(inputShape) {
    // The positional `true` is presumably the `trainable` flag of addWeight.
    this.embeddings = this.addWeight(
        'embeddings', [this.inputDim, this.outputDim], this.dtype,
        this.embeddingsInitializer, this.embeddingsRegularizer, true,
        this.embeddingsConstraint);
    this.built = true;
  }

  // Override warnOnIncompatibleInputShape because an embedding layer allows
  // the input to have varying ranks.
  warnOnIncompatibleInputShape(inputShape) { }

  /**
   * Returns null when `maskZero` is falsy. Otherwise returns the elementwise
   * result of `input != 0` for the single input tensor. `mask` is ignored.
   */
  computeMask(inputs, mask) {
    return tidy(() => {
      if (!this.maskZero) {
        return null;
      }
      inputs = getExactlyOneTensor(inputs);
      return notEqual(inputs, zerosLike(inputs));
    });
  }

  /**
   * Computes `[batch, ...lengths, outputDim]` for the single input shape.
   *
   * Without `inputLength` the input shape is passed through unchanged, with
   * `outputDim` appended. With `inputLength`, its entries must match the
   * non-batch input dimensions wherever both are known. Null entries in
   * `inputLength` are filled from the input shape.
   *
   * @throws ValueError if `inputLength` has the wrong number of entries or
   *     conflicts with a known input dimension.
   */
  computeOutputShape(inputShape) {
    inputShape = getExactlyOneShape(inputShape);
    if (this.inputLength == null) {
      return [...inputShape, this.outputDim];
    }
    // inputLength can be an array if input is 3D or higher. Copy it: toList
    // hands an array argument back as-is, and the null-filling below would
    // otherwise overwrite this.inputLength. That would make later calls with a
    // different length fail, and would corrupt getConfig() output.
    const inLens = generic_utils.toList(this.inputLength).slice();
    if (inLens.length !== inputShape.length - 1) {
      throw new ValueError(`"inputLength" is ${this.inputLength}, but received ` +
          `input shape has shape ${inputShape}`);
    }
    for (let k = 0; k < inLens.length; ++k) {
      const s1 = inLens[k];
      // Offset by one to skip the batch dimension.
      const s2 = inputShape[k + 1];
      if ((s1 != null) && (s2 != null) && (s1 !== s2)) {
        throw new ValueError(`"inputLength" is ${this.inputLength}, but received ` +
            `input shape has shape ${inputShape}`);
      } else if (s1 == null) {
        inLens[k] = s2;
      }
    }
    return [inputShape[0], ...inLens, this.outputDim];
  }

  /**
   * Looks up the embedding row for every index in the single input tensor.
   * Non-int32 inputs are cast to int32 first. The indices are flattened for
   * the gather, and the result is reshaped to
   * `computeOutputShape(input.shape)`.
   */
  call(inputs, kwargs) {
    return tidy(() => {
      this.invokeCallHook(inputs, kwargs);
      // Embedding layer accepts only a single input.
      let input = getExactlyOneTensor(inputs);
      if (input.dtype !== 'int32') {
        input = K.cast(input, 'int32');
      }
      const output =
          K.gather(this.embeddings.read(), reshape(input, [input.size]));
      return reshape(
          output, getExactlyOneShape(this.computeOutputShape(input.shape)));
    });
  }

  /**
   * Serializes the layer configuration. Base Layer config keys are merged in
   * last, so they win on any key collision.
   */
  getConfig() {
    const config = {
      inputDim: this.inputDim,
      outputDim: this.outputDim,
      embeddingsInitializer: serializeInitializer(this.embeddingsInitializer),
      embeddingsRegularizer: serializeRegularizer(this.embeddingsRegularizer),
      activityRegularizer: serializeRegularizer(this.activityRegularizer),
      embeddingsConstraint: serializeConstraint(this.embeddingsConstraint),
      maskZero: this.maskZero,
      inputLength: this.inputLength
    };
    const baseConfig = super.getConfig();
    Object.assign(config, baseConfig);
    return config;
  }
}
|
// Static serialization name for this layer class. The `@nocollapse` annotation
// must stay directly attached to this assignment.
136 | /** @nocollapse */
|
137 | Embedding.className = 'Embedding';
|
// Register the class with tfjs-core's `serialization` registry. Kept after the
// `className` assignment above because the registry presumably reads that
// static name (TODO confirm against serialization.registerClass).
138 | serialization.registerClass(Embedding);
|
139 | //# sourceMappingURL=data:application/json;base64,{"version":3,"file":"embeddings.js","sourceRoot":"","sources":["../../../../../../tfjs-layers/src/layers/embeddings.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH;;;;GAIG;AACH,OAAO,EAAC,QAAQ,EAAE,OAAO,EAAE,aAAa,EAAU,IAAI,EAAE,SAAS,EAAC,MAAM,uBAAuB,CAAC;AAEhG,OAAO,KAAK,CAAC,MAAM,yBAAyB,CAAC;AAC7C,OAAO,EAAmC,aAAa,EAAE,mBAAmB,EAAC,MAAM,gBAAgB,CAAC;AACpG,OAAO,EAAC,KAAK,EAAY,MAAM,oBAAoB,CAAC;AACpD,OAAO,EAAC,UAAU,EAAC,MAAM,WAAW,CAAC;AACrC,OAAO,EAAC,cAAc,EAAsC,oBAAoB,EAAC,MAAM,iBAAiB,CAAC;AAEzG,OAAO,EAAC,cAAc,EAAsC,oBAAoB,EAAC,MAAM,iBAAiB,CAAC;AAEzG,OAAO,KAAK,aAAa,MAAM,wBAAwB,CAAC;AACxD,OAAO,EAAC,kBAAkB,EAAE,mBAAmB,EAAC,MAAM,sBAAsB,CAAC;AAiD7E,MAAM,OAAO,SAAU,SAAQ,KAAK;IAgBlC,YAAY,IAAwB;QAClC,KAAK,CAAC,IAAI,CAAC,CAAC;QARN,eAAU,GAAkB,IAAI,CAAC;QAEhC,mCAA8B,GACnC,eAAe,CAAC;QAMlB,IAAI,IAAI,CAAC,eAAe,IAAI,IAAI,IAAI,IAAI,CAAC,UAAU,IAAI,IAAI,EAAE;YAC3D,wEAAwE;YACxE,qEAAqE;YACrE,kEAAkE;YAClE,qCAAqC;YACrC,IAAI,SAAS,GAAW,IAAI,CAAC;YAC7B,IAAI,IAAI,CAAC,SAAS,IAAI,IAAI,EAAE;gBAC1B,SAAS,GAAG,IAAI,CAAC,SAAS,CAAC;aAC5B;YACD,IAAI,IAAI,CAAC,WAAW,IAAI,IAAI,EAAE;gBAC5B,sDAAsD;gBACtD,oCAAoC;gBACpC,IAAI,CAAC,eAAe,GAAG,CAAC,SAAS,EAAE,IAAI,CAAC,CAAC;aAC1C;iBAAM;gBACL,sDAAsD;gBACtD,kDAAkD;gBAClD,IAAI,CAAC,eAAe;oBAChB,CAAC,SAAS,CAAC,CAAC,MAAM,CAAC,aAAa,CAAC,MAAM,CAAC,IAAI,CAAC,WAAW,CAAC,CAAC,CAAC;aAChE;SACF;QACD,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC;QAC9B,aAAa,CAAC,qBAAqB,CAAC,IAAI,CAAC,QAAQ,EAAE,UAAU,CAAC,CAAC;QAC/D,IAAI,CAAC,SAAS,GAAG,IAAI,CAAC,SAAS,CAAC;QAChC,aAAa,CAAC,qBAAqB,CAAC,IAAI,CAAC,SAAS,EAAE,WAAW,CAAC,CAAC;QACjE,IAAI,CAAC,qBAAqB,GAAG,cAAc,CACvC,IAAI,CAAC,qBAAqB,IAAI,IAAI,CAAC,8BAA8B,CAAC,CAAC;QACvE,IAAI,CAAC,qBAAqB,GAAG,cAAc,CAAC,IAAI,CAAC,qBAAqB,CAAC,CAAC;QACxE,IAAI,CAAC,mBAAmB,GAAG,cAAc,CAAC,IAAI,CAAC,mBAAmB,CAAC,CAAC;QACpE,IAAI,CAAC,oBAAoB,GAAG,aAAa,CAAC,IAAI,CAAC,oBAAoB,CAAC,CAAC;QACrE,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC;QAC9B,IAAI,CAAC,eAAe,GAAG,IAAI,CAAC,QAAQ,CAAC;QACrC,IAAI,CAAC,WAAW,GAAG,IAAI,CAAC,WAAW,CAAC;IACtC,C
AAC;IAEe,KAAK,CAAC,UAAyB;QAC7C,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,SAAS,CAC5B,YAAY,EAAE,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,CAAC,SAAS,CAAC,EAAE,IAAI,CAAC,KAAK,EACzD,IAAI,CAAC,qBAAqB,EAAE,IAAI,CAAC,qBAAqB,EAAE,IAAI,EAC5D,IAAI,CAAC,oBAAoB,CAAC,CAAC;QAC/B,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC;IACpB,CAAC;IAED,0EAA0E;IAC1E,mCAAmC;IAChB,4BAA4B,CAAC,UAAiB,IAAG,CAAC;IAE5D,WAAW,CAAC,MAAuB,EAAE,IAAsB;QAElE,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE;gBAClB,OAAO,IAAI,CAAC;aACb;iBAAM;gBACL,MAAM,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;gBACrC,OAAO,QAAQ,CAAC,MAAM,EAAE,SAAS,CAAC,MAAM,CAAC,CAAC,CAAC;aAC5C;QACH,CAAC,CAAC,CAAC;IACL,CAAC;IAEQ,kBAAkB,CAAC,UAAyB;QACnD,UAAU,GAAG,kBAAkB,CAAC,UAAU,CAAC,CAAC;QAC5C,IAAI,IAAI,CAAC,WAAW,IAAI,IAAI,EAAE;YAC5B,OAAO,CAAC,GAAG,UAAU,EAAE,IAAI,CAAC,SAAS,CAAC,CAAC;SACxC;QACD,wDAAwD;QACxD,MAAM,MAAM,GAAa,aAAa,CAAC,MAAM,CAAC,IAAI,CAAC,WAAW,CAAC,CAAC;QAChE,IAAI,MAAM,CAAC,MAAM,KAAK,UAAU,CAAC,MAAM,GAAG,CAAC,EAAE;YAC3C,MAAM,IAAI,UAAU,CAChB,oBAAoB,IAAI,CAAC,WAAW,iBAAiB;gBACrD,yBAAyB,UAAU,EAAE,CAAC,CAAC;SAC5C;aAAM;YACL,IAAI,CAAC,GAAG,CAAC,CAAC;YACV,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;gBACtC,MAAM,EAAE,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;gBACrB,MAAM,EAAE,GAAG,UAAU,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;gBAC7B,IAAI,CAAC,EAAE,IAAI,IAAI,CAAC,IAAI,CAAC,EAAE,IAAI,IAAI,CAAC,IAAI,CAAC,EAAE,KAAK,EAAE,CAAC,EAAE;oBAC/C,MAAM,IAAI,UAAU,CAChB,oBAAoB,IAAI,CAAC,WAAW,iBAAiB;wBACrD,yBAAyB,UAAU,EAAE,CAAC,CAAC;iBAC5C;qBAAM,IAAI,EAAE,IAAI,IAAI,EAAE;oBACrB,MAAM,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC;iBAChB;gBACD,CAAC,EAAE,CAAC;aACL;SACF;QACD,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,EAAE,GAAG,MAAM,EAAE,IAAI,CAAC,SAAS,CAAC,CAAC;IACpD,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,IAAI,CAAC,cAAc,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;YACpC,+CAA+C;YAC/C,IAAI,KAAK,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;YACxC,IAAI,KAAK,CAAC,KAAK,KAAK,OAAO,EAAE;gBAC3B,KAAK,GAAG,CAAC,CAAC,IAAI,CAAC,KAAK,EAAE,OAAO,CAAC,CAAC;aAChC;YACD,MAAM,MAAM,GACR,CAAC,CAAC,MAAM,CAAC,IAAI,CAAC,UAAU,CAAC,IAAI,EAAE,EAAE,OAAO
,CAAC,KAAK,EAAE,CAAC,KAAK,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;YACnE,OAAO,OAAO,CACV,MAAM,EAAE,kBAAkB,CAAC,IAAI,CAAC,kBAAkB,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QACxE,CAAC,CAAC,CAAC;IACL,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAAG;YACb,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,SAAS,EAAE,IAAI,CAAC,SAAS;YACzB,qBAAqB,EAAE,oBAAoB,CAAC,IAAI,CAAC,qBAAqB,CAAC;YACvE,qBAAqB,EAAE,oBAAoB,CAAC,IAAI,CAAC,qBAAqB,CAAC;YACvE,mBAAmB,EAAE,oBAAoB,CAAC,IAAI,CAAC,mBAAmB,CAAC;YACnE,oBAAoB,EAAE,mBAAmB,CAAC,IAAI,CAAC,oBAAoB,CAAC;YACpE,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,WAAW,EAAE,IAAI,CAAC,WAAW;SAC9B,CAAC;QACF,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;;AArID,kBAAkB;AACX,mBAAS,GAAG,WAAW,CAAC;AAsIjC,aAAa,CAAC,aAAa,CAAC,SAAS,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n/**\n * TensorFlow.js Layers: Embedding Layer.\n *\n * Original source: keras/constraints.py\n */\nimport {notEqual, reshape, serialization, Tensor, tidy, zerosLike} from '@tensorflow/tfjs-core';\n\nimport * as K from '../backend/tfjs_backend';\nimport {Constraint, ConstraintIdentifier, getConstraint, serializeConstraint} from '../constraints';\nimport {Layer, LayerArgs} from '../engine/topology';\nimport {ValueError} from '../errors';\nimport {getInitializer, Initializer, InitializerIdentifier, serializeInitializer} from '../initializers';\nimport {Shape} from '../keras_format/common';\nimport {getRegularizer, Regularizer, RegularizerIdentifier, serializeRegularizer} from '../regularizers';\nimport {Kwargs} from '../types';\nimport * as generic_utils from '../utils/generic_utils';\nimport {getExactlyOneShape, getExactlyOneTensor} from '../utils/types_utils';\nimport {LayerVariable} from 
'../variables';\n\nexport declare interface EmbeddingLayerArgs extends LayerArgs {\n  /**\n   * Integer > 0. Size of the vocabulary, i.e. maximum integer index + 1.\n   */\n  inputDim: number;\n  /**\n   * Integer >= 0. Dimension of the dense embedding.\n   */\n  outputDim: number;\n  /**\n   * Initializer for the `embeddings` matrix.\n   */\n  embeddingsInitializer?: InitializerIdentifier|Initializer;\n  /**\n   * Regularizer function applied to the `embeddings` matrix.\n   */\n  embeddingsRegularizer?: RegularizerIdentifier|Regularizer;\n  /**\n   * Regularizer function applied to the activation.\n   */\n  activityRegularizer?: RegularizerIdentifier|Regularizer;\n  /**\n   * Constraint function applied to the `embeddings` matrix.\n   */\n  embeddingsConstraint?: ConstraintIdentifier|Constraint;\n  /**\n   * Whether the input value 0 is a special \"padding\" value that should be\n   * masked out. This is useful when using recurrent layers which may take\n   * variable length input.\n   *\n   * If this is `True` then all subsequent layers in the model need to support\n   * masking or an exception will be raised. 
If maskZero is set to `True`, as a\n   * consequence, index 0 cannot be used in the vocabulary (inputDim should\n   * equal size of vocabulary + 1).\n   */\n  maskZero?: boolean;\n  /**\n   * Length of input sequences, when it is constant.\n   *\n   * This argument is required if you are going to connect `flatten` then\n   * `dense` layers upstream (without it, the shape of the dense outputs cannot\n   * be computed).\n   */\n  inputLength?: number|number[];\n}\n\nexport class Embedding extends Layer {\n  /** @nocollapse */\n  static className = 'Embedding';\n  private inputDim: number;\n  private outputDim: number;\n  private embeddingsInitializer: Initializer;\n  private maskZero: boolean;\n  private inputLength: number|number[];\n\n  private embeddings: LayerVariable = null;\n\n  readonly DEFAULT_EMBEDDINGS_INITIALIZER: InitializerIdentifier =\n      'randomUniform';\n  private readonly embeddingsRegularizer?: Regularizer;\n  private readonly embeddingsConstraint?: Constraint;\n\n  constructor(args: EmbeddingLayerArgs) {\n    super(args);\n    if (args.batchInputShape == null && args.inputShape == null) {\n      // Porting Note: This logic is copied from Layer's constructor, since we\n      // can't do exactly what the Python constructor does for Embedding().\n      // Specifically, the super constructor can not be called after the\n      // mutation of the `config` argument.\n      let batchSize: number = null;\n      if (args.batchSize != null) {\n        batchSize = args.batchSize;\n      }\n      if (args.inputLength == null) {\n        // Fix super-constructor to what it would have done if\n        // 'config.inputShape' were (None, )\n        this.batchInputShape = [batchSize, null];\n      } else {\n        // Fix super-constructor to what it would have done if\n        // 'config.inputShape' were (config.inputLength, )\n        this.batchInputShape =\n            [batchSize].concat(generic_utils.toList(args.inputLength));\n      }\n    }\n    
this.inputDim = args.inputDim;\n    generic_utils.assertPositiveInteger(this.inputDim, 'inputDim');\n    this.outputDim = args.outputDim;\n    generic_utils.assertPositiveInteger(this.outputDim, 'outputDim');\n    this.embeddingsInitializer = getInitializer(\n        args.embeddingsInitializer || this.DEFAULT_EMBEDDINGS_INITIALIZER);\n    this.embeddingsRegularizer = getRegularizer(args.embeddingsRegularizer);\n    this.activityRegularizer = getRegularizer(args.activityRegularizer);\n    this.embeddingsConstraint = getConstraint(args.embeddingsConstraint);\n    this.maskZero = args.maskZero;\n    this.supportsMasking = args.maskZero;\n    this.inputLength = args.inputLength;\n  }\n\n  public override build(inputShape: Shape|Shape[]): void {\n    this.embeddings = this.addWeight(\n        'embeddings', [this.inputDim, this.outputDim], this.dtype,\n        this.embeddingsInitializer, this.embeddingsRegularizer, true,\n        this.embeddingsConstraint);\n    this.built = true;\n  }\n\n  // Override warnOnIncompatibleInputShape because an embedding layer allows\n  // the input to have varying ranks.\n  protected override warnOnIncompatibleInputShape(inputShape: Shape) {}\n\n  override computeMask(inputs: Tensor|Tensor[], mask?: Tensor|Tensor[]):\n      Tensor {\n    return tidy(() => {\n      if (!this.maskZero) {\n        return null;\n      } else {\n        inputs = getExactlyOneTensor(inputs);\n        return notEqual(inputs, zerosLike(inputs));\n      }\n    });\n  }\n\n  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {\n    inputShape = getExactlyOneShape(inputShape);\n    if (this.inputLength == null) {\n      return [...inputShape, this.outputDim];\n    }\n    // inputLength can be an array if input is 3D or higher.\n    const inLens: number[] = generic_utils.toList(this.inputLength);\n    if (inLens.length !== inputShape.length - 1) {\n      throw new ValueError(\n          `\"inputLength\" is ${this.inputLength}, but received ` +\n     
     `input shape has shape ${inputShape}`);\n    } else {\n      let i = 0;\n      for (let k = 0; k < inLens.length; ++k) {\n        const s1 = inLens[k];\n        const s2 = inputShape[k + 1];\n        if ((s1 != null) && (s2 != null) && (s1 !== s2)) {\n          throw new ValueError(\n              `\"inputLength\" is ${this.inputLength}, but received ` +\n              `input shape has shape ${inputShape}`);\n        } else if (s1 == null) {\n          inLens[i] = s2;\n        }\n        i++;\n      }\n    }\n    return [inputShape[0], ...inLens, this.outputDim];\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {\n    return tidy(() => {\n      this.invokeCallHook(inputs, kwargs);\n      // Embedding layer accepts only a single input.\n      let input = getExactlyOneTensor(inputs);\n      if (input.dtype !== 'int32') {\n        input = K.cast(input, 'int32');\n      }\n      const output =\n          K.gather(this.embeddings.read(), reshape(input, [input.size]));\n      return reshape(\n          output, getExactlyOneShape(this.computeOutputShape(input.shape)));\n    });\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config = {\n      inputDim: this.inputDim,\n      outputDim: this.outputDim,\n      embeddingsInitializer: serializeInitializer(this.embeddingsInitializer),\n      embeddingsRegularizer: serializeRegularizer(this.embeddingsRegularizer),\n      activityRegularizer: serializeRegularizer(this.activityRegularizer),\n      embeddingsConstraint: serializeConstraint(this.embeddingsConstraint),\n      maskZero: this.maskZero,\n      inputLength: this.inputLength\n    };\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n}\nserialization.registerClass(Embedding);\n"]} |
\ | No newline at end of file |