1 | /**
|
2 | * @license
|
3 | * Copyright 2018 Google LLC
|
4 | *
|
5 | * Use of this source code is governed by an MIT-style
|
6 | * license that can be found in the LICENSE file or at
|
7 | * https://opensource.org/licenses/MIT.
|
8 | * =============================================================================
|
9 | */
|
10 | /**
|
11 | * deeplearn.js backend.
|
12 | */
|
13 | import * as tfc from '@tensorflow/tfjs-core';
|
14 | import { onesLike as coreOnesLike, scalar, tensor1d, tidy, where, zerosLike as coreZerosLike } from '@tensorflow/tfjs-core';
|
15 | import { checkDataFormat } from '../common';
|
16 | import { NotImplementedError, ValueError } from '../errors';
|
17 | import * as math_utils from '../utils/math_utils';
|
18 | import { imageDataFormat } from './common';
|
19 | // tslint:enable
|
/* Selecting and querying the backend. */
// Name of the backend most recently requested through `setBackend()`.
// It starts out as 'webgl'.
let backend = 'webgl';
/**
 * Switches the tfjs-core backend and records the requested name.
 *
 * The request is forwarded to `tfc.setBackend()` first; the module-level
 * record is only updated afterwards, so it keeps its previous value if that
 * call throws synchronously.
 *
 * @param requestedBackend Name of the backend to activate.
 */
export function setBackend(requestedBackend) {
  tfc.setBackend(requestedBackend);
  backend = requestedBackend;
}
/**
 * Returns the backend name recorded by the latest `setBackend()` call, or
 * 'webgl' if `setBackend()` has not been called yet.
 *
 * @returns The recorded backend name.
 */
export function getBackend() {
  return backend;
}
|
/**
 * Reports whether backend calls build a symbolic graph.
 *
 * Callers use the answer to decide how user code is interpreted: `true`
 * means backend calls construct a symbolic graph, `false` means they execute
 * immediately. This backend always answers `false`.
 *
 * @returns Always `false`.
 */
export function isBackendSymbolic() {
  const symbolic = false;
  return symbolic;
}
|
/**
 * Counts the elements of a Tensor by multiplying the entries of its shape.
 *
 * @param x The Tensor; only its `shape` is read.
 * @return Number of elements in `x` (1 when the shape is empty, i.e. a
 *   scalar).
 */
export function countParams(x) {
  const dims = x.shape;
  // An empty shape denotes a scalar, which holds exactly one value.
  if (dims.length === 0) {
    return 1;
  }
  let total = dims[0];
  for (let i = 1; i < dims.length; ++i) {
    total = total * dims[i];
  }
  return total;
}
|
/**
 * Converts a tensor to another dtype by delegating to `tfc.cast()`.
 *
 * @param x Input tensor.
 * @param dtype String: 'float32'|'int32'|'bool'.
 * @returns Tensor of the specified `dtype`.
 */
export function cast(x, dtype) {
  const converted = tfc.cast(x, dtype);
  return converted;
}
|
/**
 * Inserts a dimension of size 1 at position `axis` of the shape.
 *
 * A negative `axis` counts from the end of the *output* shape, so the default
 * of -1 appends the new dimension after the last existing one.
 *
 * @param x Input tensor.
 * @param axis Position where to add the new axis.
 * @returns Result of the dimension expansion.
 */
export function expandDims(x, axis = -1) {
  // Work on a copy so that the shape of `x` itself is never modified.
  const outShape = Array.from(x.shape);
  const insertAt = axis < 0 ? outShape.length + axis + 1 : axis;
  outShape.splice(insertAt, 0, 1);
  return tfc.reshape(x, outShape);
}
|
/**
 * Repeats a 2D tensor along a newly inserted middle axis.
 *
 * For an input of shape `[samples, dim]` and `n` equal to 2, for example,
 * the output has shape `[samples, 2, dim]`.
 *
 * @param x Input tensor.
 * @param n Integer, number of times to repeat.
 * @returns The result of the repeat operation.
 * @throws ValueError: If input tensor is not 2D.
 */
export function repeat(x, n) {
  return tidy(() => {
    const rank = x.shape.length;
    if (rank !== 2) {
      throw new ValueError(`repeat() expects a rank-2 tensor, but received a rank-${rank} tensor.`);
    }
    // [samples, dim] -> [samples, 1, dim] -> [samples, n, dim]
    const withMiddleAxis = expandDims(x, 1);
    return tile(withMiddleAxis, [1, n, 1]);
  });
}
|
/**
 * Reshapes a Tensor into a single dimension.
 *
 * The length of that dimension is `math_utils.arrayProd(x.shape)`.
 *
 * @param x Input tensor.
 * @return The result of the flattening `x`.
 */
export function flatten(x) {
  const elementCount = math_utils.arrayProd(x.shape);
  return tfc.reshape(x, [elementCount]);
}
|
/**
 * Flattens every sample of a batch: an nD tensor becomes a 2D tensor that
 * keeps the 0th dimension and merges all remaining ones.
 *
 * @param x The tensor to flatten. The rank of this tensor is required to be 2
 *   or higher.
 * @return The result of the flattening.
 * @throws ValueError: If the rank of `x` is 1 or lower.
 */
export function batchFlatten(x) {
  if (x.rank <= 1) {
    throw new ValueError(`batchFlatten requires a minimum rank of 2. Got rank: ${x.rank}.`);
  }
  const batchSize = x.shape[0];
  // Product of every dimension from index 1 onwards.
  const perSample = math_utils.arrayProd(x.shape, 1);
  return tfc.reshape(x, [batchSize, perSample]);
}
|
/**
 * Takes a slice along the first axis, keeping all other axes whole.
 *
 * Ranks 1 to 4 go through the rank-specific `tfc.slice1d` .. `tfc.slice4d`
 * ops; ranks 5 and 6 go through the generic `tfc.slice`.
 *
 * @param array input `tf.Tensor`.
 * @param start starting index, inclusive.
 * @param size size of the slice along the first axis.
 * @returns result of the slicing.
 * @throws ValueError: If the rank of `array` is not one of 1 through 6.
 */
export function sliceAlongFirstAxis(array, start, size) {
  return tidy(() => {
    const rank = array.rank;
    if (rank === 1) {
      return tfc.slice1d(array, start, size);
    }
    if (![2, 3, 4, 5, 6].includes(rank)) {
      throw new ValueError(`sliceAlongFirstAxis() received an unsupported tensor rank: ${rank}`);
    }
    // Offset `start` / extent `size` on axis 0; full extent everywhere else.
    const begin = [start];
    const extent = [size];
    for (let d = 1; d < rank; ++d) {
      begin.push(0);
      extent.push(array.shape[d]);
    }
    if (rank === 2) {
      return tfc.slice2d(array, begin, extent);
    }
    if (rank === 3) {
      return tfc.slice3d(array, begin, extent);
    }
    if (rank === 4) {
      return tfc.slice4d(array, begin, extent);
    }
    return tfc.slice(array, begin, extent);
  });
}
|
/**
 * Takes a slice along the last axis, keeping all other axes whole.
 *
 * @param array input `tf.Tensor`.
 * @param start starting index, inclusive.
 * @param size size of the slice along the last axis.
 * @returns result of the slicing.
 * @throws ValueError: If the rank of `array` is not one of 1 through 4.
 */
export function sliceAlongLastAxis(array, start, size) {
  return tidy(() => {
    const rank = array.rank;
    if (rank === 1) {
      return tfc.slice1d(array, start, size);
    }
    if (rank !== 2 && rank !== 3 && rank !== 4) {
      throw new ValueError(`sliceAlongLastAxis() received an unsupported tensor rank: ${rank}`);
    }
    // Full extent on every leading axis; `start` / `size` on the final one.
    const begin = [];
    const extent = [];
    for (let d = 0; d < rank - 1; ++d) {
      begin.push(0);
      extent.push(array.shape[d]);
    }
    begin.push(start);
    extent.push(size);
    if (rank === 2) {
      return tfc.slice2d(array, begin, extent);
    }
    if (rank === 3) {
      return tfc.slice3d(array, begin, extent);
    }
    return tfc.slice4d(array, begin, extent);
  });
}
|
/**
 * Takes a slice along the specified axis, keeping all other axes whole.
 *
 * Note that `axis` is 1-based: `1` selects the first axis and `array.rank`
 * selects the last one. For a rank-1 tensor `axis` is ignored.
 *
 * @param array input `tf.Tensor`; ranks 1 through 4 are supported.
 * @param start starting index along the chosen axis, inclusive.
 * @param size size of the slice along the chosen axis.
 * @param axis 1-based index of the axis to slice along.
 * @returns result of the slicing.
 * @throws ValueError: If the rank of `array` is unsupported, or if `axis` is
 *   not an integer in `[1, array.rank]`.
 */
export function sliceAlongAxis(array, start, size, axis) {
  return tidy(() => {
    const rank = array.rank;
    if (rank === 1) {
      return tfc.slice1d(array, start, size);
    }
    if (rank !== 2 && rank !== 3 && rank !== 4) {
      // This message previously named sliceAlongLastAxis(), which pointed
      // callers at the wrong function.
      throw new ValueError(`sliceAlongAxis() received an unsupported tensor rank: ${rank}`);
    }
    if (axis === 1) {
      return sliceAlongFirstAxis(array, start, size);
    }
    if (axis === rank) {
      return sliceAlongLastAxis(array, start, size);
    }
    // Only interior axes remain: axis 2 for rank 3, axes 2 and 3 for rank 4.
    const isInteriorAxis = (axis === 2 || axis === 3) && axis < rank;
    if (!isInteriorAxis) {
      throw new ValueError(`The axis is not within the rank of the tensor ${axis}`);
    }
    const begin = [];
    const extent = [];
    for (let d = 0; d < rank; ++d) {
      const isChosen = d === axis - 1;
      begin.push(isChosen ? start : 0);
      extent.push(isChosen ? size : array.shape[d]);
    }
    return rank === 3 ? tfc.slice3d(array, begin, extent) : tfc.slice4d(array, begin, extent);
  });
}
|
/**
 * Concatenates a list of tensors alongside the specified axis.
 *
 * Negative axes count from the end (`-1` is the last axis, `-2` the one
 * before it, and so on). An axis equal to the rank of the first tensor is
 * also accepted and means the last axis.
 *
 * @param tensors `Array` of tensors to concatenate. The rank of the first
 *   tensor is used to interpret `axis`.
 * @param axis Concatenation axis.
 * @returns The result of the concatenation.
 */
export function concatenate(tensors, axis = -1) {
  const rank = tensors[0].rank;
  // Porting Note: tfc.concat() requires axis to be in the interval
  // [-rank, rank), so in-range negative axes are forwarded untouched.
  // Previously every negative axis was collapsed onto the last axis, which
  // made e.g. axis = -2 concatenate along the wrong dimension.
  const scalarWithNegativeAxis = axis < 0 && rank === 0;
  if (axis === rank || scalarWithNegativeAxis) {
    // `axis === rank` is outside tfc.concat()'s interval; it and the scalar
    // case keep their historical mapping to -1.
    axis = -1;
  }
  // Porting Note: Sparse concat is not supported yet.
  return tfc.concat(tensors, axis);
}
|
/**
 * Joins two tensors along the first dimension.
 *
 * The op is chosen from the rank of `a`: `tfc.concat1d` for rank 1, and
 * `tfc.concat2d` / `tfc.concat3d` / `tfc.concat4d` with axis 0 for ranks 2
 * to 4.
 *
 * @param a The 1st `tf.Tensor` to concatenate.
 * @param b The 2nd `tf.Tensor` to concatenate.
 * @returns Result of the concatenation.
 * @throws ValueError: If the rank of `a` is not one of 1 through 4.
 */
export function concatAlongFirstAxis(a, b) {
  const rank = a.rank;
  const pair = [a, b];
  if (rank === 1) {
    return tfc.concat1d(pair);
  }
  if (rank === 2) {
    return tfc.concat2d(pair, 0);
  }
  if (rank === 3) {
    return tfc.concat3d(pair, 0);
  }
  if (rank === 4) {
    return tfc.concat4d(pair, 0);
  }
  throw new ValueError(`concatAlongFirstAxis() received an unsupported tensor rank: ${rank}`);
}
|
/**
 * Builds a tensor by tiling `x` according to `n`.
 *
 * @param x A tensor.
 * @param n An Array of integers or a single integer. If an Array, the length
 *   must be the same as the number of dimensions in `x`. If a single integer,
 *   it will be treated as an Array of length 1.
 * @returns The tiled tensor produced by `tfc.tile()`.
 * @throws ValueError: If the number of entries in `n` differs from the rank
 *   of `x`.
 */
export function tile(x, n) {
  const reps = Array.isArray(n) ? n : [n];
  if (reps.length !== x.rank) {
    throw new ValueError(`The length of input n (${reps.length}) does not match the number of dimensions in input x (${x.rank})`);
  }
  return tfc.tile(x, reps);
}
|
302 | /* Creation of random tensors. */
|
/**
 * Creates a tensor of normally distributed values via `tfc.randomNormal()`.
 *
 * @param shape Shape of the tensor.
 * @param mean mean value of the normal distribution. Defaults to 0.
 * @param stddev standard deviation of the normal distribution. Defaults
 *   to 1.
 * @param dtype Forwarded unchanged to `tfc.randomNormal()`.
 * @param seed Forwarded unchanged to `tfc.randomNormal()`.
 * @return The normal tensor.
 */
export function randomNormal(shape, mean = 0.0, stddev = 1.0, dtype, seed) {
  const sampled = tfc.randomNormal(shape, mean, stddev, dtype, seed);
  return sampled;
}
|
316 | /* Linear Algebra */
|
/**
 * Multiplies two tensors and returns the result as a tensor.
 *
 * For 2D tensors, this is equivalent to matrix multiplication (matMul).
 * For tensors of higher ranks, it follows the Theano behavior,
 * (e.g. `(2, 3) * (4, 3, 5) -> (2, 4, 5)`): a sum product over the last axis
 * of `a` and the second-to-last axis of `b`.
 *
 * Higher-rank inputs are first flattened to 2D, multiplied with
 * `tfc.fused.matMul`, and the product is reshaped to the final output shape.
 *
 * @param a A tensor of at least rank 2.
 * @param b A tensor of at least rank 2.
 * @param activation (optional) A string identifying the activation
 *   function. It is handed to `tfc.fused.matMul` unchanged.
 * @param bias (optional) Bias tensor; when truthy it is passed through
 *   `reshapeBias()` and handed to `tfc.fused.matMul`.
 * @return Result of the dot operation.
 * @throws NotImplementedError: If either input has rank below 2, or if `b`
 *   has rank >= 3 and its second-to-last dimension differs from the last
 *   dimension of `a`.
 */
export function dot(a, b, activation, bias) {
  if (a.rank < 2 || b.rank < 2) {
    throw new NotImplementedError(`dot requires both inputs to be rank >= 2 but got x shape = ${a.shape} and y shape = ${b.shape}`);
  }
  if (b.rank >= 3) {
    const [aInner] = a.shape.slice(-1);
    const [bInner] = b.shape.slice(-2);
    if (aInner !== bInner) {
      throw new NotImplementedError(`If rank y >= 3, then the second last dim of y must equal the last dim of x but got x shape = ${a.shape} and  y shape = ${b.shape}`);
    }
  }
  // Basic 2D x 2D case: a single fused matMul, no reshaping needed.
  // tfc.fused.matMul only fuses certain activation functions. Unsupported
  // activation functions are treated as 'linear' activations, which is
  // equivalent to a no-op.
  if (a.rank === 2 && b.rank === 2) {
    return tfc.fused.matMul({
      a,
      b,
      transposeA: false,
      transposeB: false,
      bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null,
      activation
    });
  }
  // Collapse `a` to 2D, remembering all of its dims except the last one.
  const aLeadingDims = a.shape.slice();
  const aLastDim = aLeadingDims.pop();
  const a2d = tfc.reshape(a, [-1, aLastDim]);
  // Split the shape of `b` into the contracted (second-to-last) dim and the
  // dims that survive into the output shape.
  const bRemaining = b.shape.slice();
  const bLastDim = bRemaining.pop();
  const bContractedDim = bRemaining.pop();
  const bOutputDims = bRemaining.concat([bLastDim]);
  // Move the contracted axis of `b` to the front:
  // [r-2, 0, 1, 2, ... r-4, r-3, r-1] where r is the rank of `b`.
  const bRank = b.rank;
  const perm = [bRank - 2];
  for (let i = 0; i < bRank - 2; ++i) {
    perm.push(i);
  }
  perm.push(bRank - 1);
  const b2d = tfc.reshape(tfc.transpose(b, perm), [bContractedDim, -1]);
  // Multiply as 2D tensors, then restore the full output shape. The bias is
  // reshaped against the rank of the already-flattened `a`.
  const product = tfc.fused.matMul({
    a: a2d,
    b: b2d,
    transposeA: false,
    transposeB: false,
    bias: bias ? reshapeBias(a2d.rank, bias, imageDataFormat()) : null,
    activation
  });
  return tfc.reshape(product, aLeadingDims.concat(bOutputDims));
}
|
/**
 * Compute the sign Tensor of an input Tensor.
 *
 * Elements of the input `tf.Tensor` that are === 0 are mapped to 0.
 * Elements of the input `tf.Tensor` that are > 0 are mapped to 1.
 * Elements of the input `tf.Tensor` that are < 0 are mapped to -1.
 *
 * @param x Input `tf.Tensor`.
 * @return The sign `tf.Tensor`.
 */
export function sign(x) {
  // TODO(cais): Move to the core.
  return tidy(() => {
    const zerosLikeX = coreZerosLike(x);
    const onesLikeX = coreOnesLike(x);
    // The same zeros tensor serves as comparison operand for both tests and
    // as the output value for zero entries; no second copy is created.
    const nonZeroSign = where(tfc.greater(x, zerosLikeX), onesLikeX, tfc.mul(-1, onesLikeX));
    return where(tfc.equal(x, zerosLikeX), zerosLikeX, nonZeroSign);
  });
}
|
/**
 * Computes the one-hot representation of an integer tensor.
 *
 * Only rank-1 `indices` are accepted at present. The indices are cast to
 * 'int32' before encoding and the encoded result is cast to 'float32'.
 *
 * @param indices 1D tensor of class indices.
 * @param numClasses Integer, number of classes to consider.
 * @returns One-hot representation of the input as a 'float32' tensor.
 * @throws Error: If `indices` is not rank 1.
 */
export function oneHot(indices, numClasses) {
  return tidy(() => {
    if (indices.rank !== 1) {
      throw new Error('Only 1D one-hot tensors are supported in the deeplearn backend, at present.');
    }
    const intIndices = tfc.cast(indices, 'int32');
    const encoded = tfc.oneHot(intIndices, numClasses);
    return tfc.cast(encoded, 'float32');
  });
}
|
436 | /* Elementary math functions. */
|
/**
 * Picks the elements at `indices` out of the tensor `reference`.
 *
 * An `Array` of indices is turned into a 1D 'int32' tensor; a tensor of
 * indices is cast to 'int32'. The actual lookup is done by `tfc.gather()`.
 *
 * @param reference A tensor.
 * @param indices An integer tensor of indices or an `Array` of integers.
 * @param axis Axis along which to perform the gather operation.
 * @returns The result of the gathering as a tensor.
 */
export function gather(reference, indices, axis) {
  return tidy(() => {
    const intIndices = Array.isArray(indices) ?
        tensor1d(indices, 'int32') :
        tfc.cast(indices, 'int32');
    return tfc.gather(reference, intIndices, axis);
  });
}
|
/**
 * Squares a tensor element-wise by multiplying it with itself.
 *
 * @param x Input tensor.
 * @return element-wise x^2
 */
export function square(x) {
  const squared = tfc.mul(x, x);
  return squared;
}
|
/**
 * Element-wise exponentiation.
 *
 * Porting Note: In PyKeras, `a` (the exponent) is a Python integer, which
 *   takes advantage of the backend's (e.g., TensorFlow's) automatic
 *   conversion to tensor. Here we allow `a` to be either a number or a
 *   tensor.
 *
 * @param x The base tensor.
 * @param a The exponent, tensor or number. If a number, it is rounded to the
 *   nearest integer and converted to an 'int32' scalar tensor.
 * @returns A tensor of the same shape as `x`.
 * @throws NotImplementedError: If the exponent tensor is not of dtype
 *   'int32'.
 */
export function pow(x, a) {
  return tidy(() => {
    const exponent = typeof a === 'number' ? scalar(Math.round(a), 'int32') : a;
    if (exponent.dtype !== 'int32') {
      throw new NotImplementedError(`Non-int32 dtype (${exponent.dtype}) is not supported by pow() yet`);
    }
    return tfc.pow(x, exponent);
  });
}
|
/**
 * Reshapes a bias tensor so that it can be combined with a tensor of rank
 * `xRank` laid out in `dataFormat`.
 *
 * - `xRank` below 3: the bias is returned untouched.
 * - `xRank` of 3, 4 or 5 with a 1D bias: the bias length is placed on the
 *   channel axis (index 1 for 'channelsFirst', the last index for
 *   'channelsLast') and every other dimension is 1.
 * - `xRank` of 3, 4 or 5 with a non-1D bias: a leading 1 is prepended; for
 *   'channelsFirst' the dimension at index `xRank - 2` is additionally moved
 *   right behind that leading 1.
 *
 * NOTE(review): the non-1D reshapes index the bias shape as if the bias had
 * rank `xRank - 1`, while the guard below admits rank `xRank` - confirm
 * which rank is intended.
 *
 * @param xRank Rank of the tensor the bias will be combined with.
 * @param bias The bias tensor. Must be 1D or of rank `xRank`.
 * @param dataFormat 'channelsFirst' or 'channelsLast'.
 * @returns The bias, reshaped where necessary.
 * @throws ValueError: If the rank of `bias` is neither 1 nor `xRank`, or if
 *   `xRank` is larger than 5.
 */
function reshapeBias(xRank, bias, dataFormat) {
  const biasShape = bias.shape;
  if (bias.rank !== 1 && bias.rank !== xRank) {
    throw new ValueError(`Unexpected bias dimensions: ${bias.rank}; expected it to be 1 or ${xRank}`);
  }
  if (xRank < 3) {
    return bias;
  }
  if (xRank === 3 || xRank === 4 || xRank === 5) {
    const isVector = biasShape.length === 1;
    if (dataFormat === 'channelsFirst') {
      // Number of dimensions that follow the channel axis.
      const trailing = xRank - 2;
      const target = isVector ?
          [1, biasShape[0]].concat(new Array(trailing).fill(1)) :
          [1, biasShape[trailing]].concat(biasShape.slice(0, trailing));
      return tfc.reshape(bias, target);
    }
    if (dataFormat === 'channelsLast') {
      const target = isVector ?
          new Array(xRank - 1).fill(1).concat([biasShape[0]]) :
          [1].concat(biasShape);
      return tfc.reshape(bias, target);
    }
  }
  // Also reached when `dataFormat` is neither 'channelsFirst' nor
  // 'channelsLast'. The message reports the rank of the input tensor; it
  // previously printed the rank of the bias under this same label.
  throw new ValueError(`Unsupported input rank by biasAdd: ${xRank}`);
}
|
554 | /* Neural-network operations. */
|
/**
 * Adds a bias to a tensor.
 *
 * The bias is first brought into a compatible shape by `reshapeBias()` and
 * then added with `tfc.add()`.
 *
 * @param x The tensor to add the bias to.
 * @param bias The bias to add to `x`. Must be 1D or the same rank as `x`.
 * @param dataFormat (optional) Data format; when `null` or `undefined` the
 *   value of `imageDataFormat()` is used. It is validated with
 *   `checkDataFormat()`.
 * @return Result of the bias adding.
 * @throws ValueError: If the rank of `bias` is incorrect.
 */
export function biasAdd(x, bias, dataFormat) {
  return tidy(() => {
    const format = dataFormat == null ? imageDataFormat() : dataFormat;
    checkDataFormat(format);
    const shapedBias = reshapeBias(x.rank, bias, format);
    return tfc.add(x, shapedBias);
  });
}
|
/**
 * Exponential linear unit (ELU), computed by `tfc.elu()`.
 *
 * @param x A tensor or variable to compute the activation function for.
 * @param alpha A scalar, a scaling factor for the negative section. Only the
 *   default value 1 is supported at present.
 * @return Output of the ELU operation.
 * @throws NotImplementedError: If `alpha` is anything other than 1.
 */
export function elu(x, alpha = 1) {
  // TODO(cais): Add support for alpha values other than 1.
  const alphaSupported = alpha === 1;
  if (!alphaSupported) {
    throw new NotImplementedError(`Support for alpha values other than 1 (${alpha}) is not implemented yet.`);
  }
  return tfc.elu(x);
}
|
/**
 * Softsign of a tensor.
 *
 * Computed element-wise as x / (abs(x) + 1).
 *
 * @param x Input.
 * @returns Output.
 */
export function softsign(x) {
  return tidy(() => {
    const denominator = tfc.add(tfc.abs(x), 1);
    return tfc.div(x, denominator);
  });
}
|
/**
 * Sets entries in `x` to zero at random, while scaling the entire tensor.
 *
 * All arguments are forwarded unchanged to `tfc.dropout()`.
 *
 * @param x input tensor.
 * @param level fraction of the entries in the tensor that will be set to 0.
 * @param noiseShape shape of randomly generated keep/drop flags, must be
 *   broadcastable to the shape of `x`. Optional.
 * @param seed random seed to ensure determinism. Optional.
 * @returns Result of the dropout operation.
 */
export function dropout(x, level, noiseShape, seed) {
  return tidy(() => {
    const dropped = tfc.dropout(x, level, noiseShape, seed);
    return dropped;
  });
}
|
/**
 * Element-wise, segment-wise linear approximation of sigmoid.
 *
 * Evaluates `0.2 * x + 0.5` and clips the result to `[0, 1]`, which yields
 * `0.` for `x < -2.5` and `1.` for `x > 2.5`.
 *
 * @param x Input tensor.
 * @returns Output tensor.
 */
export function hardSigmoid(x) {
  return tidy(() => {
    const scaled = tfc.mul(.2, x);
    const shifted = tfc.add(.5, scaled);
    return tfc.clipByValue(shifted, 0, 1);
  });
}
|
/**
 * Invokes `x` in the training phase, and `alt` otherwise.
 *
 * Porting Note: We do not create placeholder tensors for the `training`
 *   boolean flag here, because there is no such thing in the TF.js imperative
 *   backend.
 *
 * Exactly one of the two functions is called per invocation.
 *
 * @param x The function to invoke when `training` is truthy.
 * @param alt The function to invoke when `training` is falsy.
 * @param training Boolean flag for whether training phase is active.
 *   Defaults to `false`.
 * @returns The return value of `x()` if `training` is truthy, or the return
 *   value of `alt()` otherwise.
 */
export function inTrainPhase(x, alt, training = false) {
  if (training) {
    return x();
  }
  return alt();
}
|
641 | //# sourceMappingURL=data:application/json;base64,{"version":3,"file":"tfjs_backend.js","sourceRoot":"","sources":["../../../../../../tfjs-layers/src/backend/tfjs_backend.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH;;GAEG;AAEH,OAAO,KAAK,GAAG,MAAM,uBAAuB,CAAC;AAC7C,OAAO,EAAC,QAAQ,IAAI,YAAY,EAAE,MAAM,EAAoB,QAAQ,EAA0C,IAAI,EAAE,KAAK,EAAE,SAAS,IAAI,aAAa,EAAC,MAAM,uBAAuB,CAAC;AACpL,OAAO,EAAC,eAAe,EAAC,MAAM,WAAW,CAAC;AAC1C,OAAO,EAAC,mBAAmB,EAAE,UAAU,EAAC,MAAM,WAAW,CAAC;AAG1D,OAAO,KAAK,UAAU,MAAM,qBAAqB,CAAC;AAElD,OAAO,EAAC,eAAe,EAAC,MAAM,UAAU,CAAC;AAEzC,gBAAgB;AAEhB,oDAAoD;AAEpD,+CAA+C;AAC/C,IAAI,OAAO,GAAkB,OAAO,CAAC;AAErC,MAAM,UAAU,UAAU,CAAC,gBAA+B;IACxD,GAAG,CAAC,UAAU,CAAC,gBAAgB,CAAC,CAAC;IACjC,OAAO,GAAG,gBAAgB,CAAC;AAC7B,CAAC;AAED,MAAM,UAAU,UAAU;IACxB,OAAO,OAAO,CAAC;AACjB,CAAC;AAED;;;;;;GAMG;AACH,MAAM,UAAU,iBAAiB;IAC/B,OAAO,KAAK,CAAC;AACf,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,WAAW,CAAC,CAAW;IACrC,MAAM,KAAK,GAAG,CAAC,CAAC,KAAK,CAAC;IACtB,IAAI,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE;QACpB,OAAO,KAAK,CAAC,MAAM,CAAC,CAAC,CAAS,EAAE,CAAS,EAAE,EAAE,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;KACtD;SAAM;QACL,UAAU;QACV,OAAO,CAAC,CAAC;KACV;AACH,CAAC;AAED;;;;;GAKG;AACH,MAAM,UAAU,IAAI,CAAC,CAAS,EAAE,KAAmB;IACjD,OAAO,GAAG,CAAC,IAAI,CAAC,CAAC,EAAE,KAAK,CAAC,CAAC;AAC5B,CAAC;AAED;;;;;GAKG;AACH,MAAM,UAAU,UAAU,CAAC,CAAS,EAAE,IAAI,GAAG,CAAC,CAAC;IAC7C,MAAM,QAAQ,GAAG,CAAC,CAAC,KAAK,CAAC,KAAK,EAAE,CAAC;IACjC,IAAI,IAAI,GAAG,CAAC,EAAE;QACZ,IAAI,GAAG,QAAQ,CAAC,MAAM,GAAG,IAAI,GAAG,CAAC,CAAC;KACnC;IACD,QAAQ,CAAC,MAAM,CAAC,IAAI,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;IAC5B,OAAO,GAAG,CAAC,OAAO,CAAC,CAAC,EAAE,QAAQ,CAAC,CAAC;AAClC,CAAC;AAED;;;;;;;;;;GAUG;AACH,MAAM,UAAU,MAAM,CAAC,CAAS,EAAE,CAAS;IACzC,OAAO,IAAI,CAAC,GAAG,EAAE;QACf,IAAI,CAAC,CAAC,KAAK,CAAC,MAAM,KAAK,CAAC,EAAE;YACxB,MAAM,IAAI,UAAU,CAChB,mDAAmD;gBACnD,QAAQ,CAAC,CAAC,KAAK,CAAC,MAAM,UAAU,CAAC,CAAC;SACvC;QACD,MAAM,CAAC,GAAG,UAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;QAC3B,OAAO,IAAI,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IAC5B,CAAC,CAAC,CAAC;AACL,CAAC;AAED
;;;;GAIG;AACH,MAAM,UAAU,OAAO,CAAC,CAAS;IAC/B,MAAM,QAAQ,GAAG,CAAC,UAAU,CAAC,SAAS,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC;IACjD,OAAO,GAAG,CAAC,OAAO,CAAC,CAAC,EAAE,QAAQ,CAAC,CAAC;AAClC,CAAC;AAED;;;;;;;GAOG;AACH,MAAM,UAAU,YAAY,CAAC,CAAS;IACpC,IAAI,CAAC,CAAC,IAAI,IAAI,CAAC,EAAE;QACf,MAAM,IAAI,UAAU,CAChB,wDAAwD,CAAC,CAAC,IAAI,GAAG,CAAC,CAAC;KACxE;IACD,MAAM,QAAQ,GAAG,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,SAAS,CAAC,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,CAAC,CAAC;IAChE,OAAO,GAAG,CAAC,OAAO,CAAC,CAAC,EAAE,QAAQ,CAAC,CAAC;AAClC,CAAC;AAED;;;;;;;GAOG;AACH,MAAM,UAAU,mBAAmB,CAC/B,KAAa,EAAE,KAAa,EAAE,IAAY;IAC5C,OAAO,IAAI,CAAC,GAAG,EAAE;QACf,QAAQ,KAAK,CAAC,IAAI,EAAE;YAClB,KAAK,CAAC;gBACJ,OAAO,GAAG,CAAC,OAAO,CAAC,KAAiB,EAAE,KAAK,EAAE,IAAI,CAAC,CAAC;YACrD,KAAK,CAAC;gBACJ,OAAO,GAAG,CAAC,OAAO,CACd,KAAiB,EAAE,CAAC,KAAK,EAAE,CAAC,CAAC,EAAE,CAAC,IAAI,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAC7D,KAAK,CAAC;gBACJ,OAAO,GAAG,CAAC,OAAO,CACd,KAAiB,EAAE,CAAC,KAAK,EAAE,CAAC,EAAE,CAAC,CAAC,EAChC,CAAC,IAAI,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAC9C,KAAK,CAAC;gBACJ,OAAO,GAAG,CAAC,OAAO,CACd,KAAiB,EAAE,CAAC,KAAK,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACnC,CAAC,IAAI,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAC9D,KAAK,CAAC;gBACJ,OAAO,GAAG,CAAC,KAAK,CAAC,KAAiB,EAAE,CAAC,KAAK,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE;oBACvD,IAAI,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC;iBACrE,CAAC,CAAC;YACL,KAAK,CAAC;gBACJ,OAAO,GAAG,CAAC,KAAK,CAAC,KAAK,EAAE,CAAC,KAAK,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE;oBAC9C,IAAI,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC;oBACpE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC;iBACf,CAAC,CAAC;YACL;gBACE,MAAM,IAAI,UAAU,CAChB,6DAA6D;oBAC7D,GAAG,KAAK,CAAC,IA
AI,EAAE,CAAC,CAAC;SACxB;IACH,CAAC,CAAC,CAAC;AACL,CAAC;AAED;;;;;;;GAOG;AACH,MAAM,UAAU,kBAAkB,CAC9B,KAAa,EAAE,KAAa,EAAE,IAAY;IAC5C,OAAO,IAAI,CAAC,GAAG,EAAE;QACf,QAAQ,KAAK,CAAC,IAAI,EAAE;YAClB,KAAK,CAAC;gBACJ,OAAO,GAAG,CAAC,OAAO,CAAC,KAAiB,EAAE,KAAK,EAAE,IAAI,CAAC,CAAC;YACrD,KAAK,CAAC;gBACJ,OAAO,GAAG,CAAC,OAAO,CACd,KAAiB,EAAE,CAAC,CAAC,EAAE,KAAK,CAAC,EAAE,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,IAAI,CAAC,CAAC,CAAC;YAC7D,KAAK,CAAC;gBACJ,OAAO,GAAG,CAAC,OAAO,CACd,KAAiB,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,KAAK,CAAC,EAChC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,IAAI,CAAC,CAAC,CAAC;YAC9C,KAAK,CAAC;gBACJ,OAAO,GAAG,CAAC,OAAO,CACd,KAAiB,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,KAAK,CAAC,EACnC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,IAAI,CAAC,CAAC,CAAC;YAC9D;gBACE,MAAM,IAAI,UAAU,CAChB,4DAA4D;oBAC5D,GAAG,KAAK,CAAC,IAAI,EAAE,CAAC,CAAC;SACxB;IACH,CAAC,CAAC,CAAC;AACL,CAAC;AAED;;;;;;;;GAQG;AACH,MAAM,UAAU,cAAc,CAC1B,KAAa,EAAE,KAAa,EAAE,IAAY,EAAE,IAAY;IAC1D,OAAO,IAAI,CAAC,GAAG,EAAE;QACf,QAAQ,KAAK,CAAC,IAAI,EAAE;YAClB,KAAK,CAAC;gBACJ,OAAO,GAAG,CAAC,OAAO,CAAC,KAAiB,EAAE,KAAK,EAAE,IAAI,CAAC,CAAC;YACrD,KAAK,CAAC;gBACJ,QAAQ,IAAI,EAAE;oBACZ,KAAK,CAAC;wBACJ,OAAO,mBAAmB,CAAC,KAAK,EAAE,KAAK,EAAE,IAAI,CAAC,CAAC;oBACjD,KAAK,CAAC;wBACJ,OAAO,kBAAkB,CAAC,KAAK,EAAE,KAAK,EAAE,IAAI,CAAC,CAAC;oBAChD;wBACE,MAAM,IAAI,UAAU,CAChB,gDAAgD;4BAChD,GAAG,IAAI,EAAE,CAAC,CAAC;iBAClB;YACH,KAAK,CAAC;gBACJ,QAAQ,IAAI,EAAE;oBACZ,KAAK,CAAC;wBACJ,OAAO,mBAAmB,CAAC,KAAK,EAAE,KAAK,EAAE,IAAI,CAAC,CAAC;oBACjD,KAAK,CAAC;wBACJ,OAAO,GAAG,CAAC,OAAO,CACd,KAAiB,EAAE,CAAC,CAAC,EAAE,KAAK,EAAE,CAAC,CAAC,EAChC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,IAAI,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;oBAC9C,KAAK,CAAC;wBACJ,OAAO,kBAAkB,CAAC,KAAK,EAAE,KAAK,EAAE,IAAI,CAAC,CAAC;oBAChD;wBACE,MAAM,IAAI,UAAU,CAChB,gDAAgD;4BAChD,GAAG,IAAI,EAAE,CAAC,CAAC;iBAClB;YACH,KAAK,CAAC;gBACJ,QAAQ,IAAI,EAAE;oBACZ,KAAK,CAAC;wBACJ,OAAO,mBAAmB,CAAC,KAAK,EAAE,K
AAK,EAAE,IAAI,CAAC,CAAC;oBACjD,KAAK,CAAC;wBACJ,OAAO,GAAG,CAAC,OAAO,CACd,KAAiB,EAAE,CAAC,CAAC,EAAE,KAAK,EAAE,CAAC,EAAE,CAAC,CAAC,EACnC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,IAAI,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;oBAC9D,KAAK,CAAC;wBACJ,OAAO,GAAG,CAAC,OAAO,CACd,KAAiB,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,KAAK,EAAE,CAAC,CAAC,EACnC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,IAAI,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;oBAC9D,KAAK,CAAC;wBACJ,OAAO,kBAAkB,CAAC,KAAK,EAAE,KAAK,EAAE,IAAI,CAAC,CAAC;oBAChD;wBACE,MAAM,IAAI,UAAU,CAChB,gDAAgD;4BAChD,GAAG,IAAI,EAAE,CAAC,CAAC;iBAClB;YACH;gBACE,MAAM,IAAI,UAAU,CAChB,4DAA4D;oBAC5D,GAAG,KAAK,CAAC,IAAI,EAAE,CAAC,CAAC;SACxB;IACH,CAAC,CAAC,CAAC;AACL,CAAC;AAED;;;;;GAKG;AACH,MAAM,UAAU,WAAW,CAAC,OAAiB,EAAE,IAAI,GAAG,CAAC,CAAC;IACtD,IAAI,IAAY,CAAC;IACjB,IAAI,IAAI,GAAG,CAAC,EAAE;QACZ,IAAI,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC;QACvB,IAAI,IAAI,KAAK,CAAC,EAAE;YACd,IAAI,GAAG,IAAI,CAAC;SACb;aAAM;YACL,IAAI,GAAG,CAAC,CAAC;SACV;KACF;IACD,IAAI,IAAI,KAAK,OAAO,CAAC,CAAC,CAAC,CAAC,IAAI,EAAE;QAC5B,2EAA2E;QAC3E,mCAAmC;QACnC,IAAI,GAAG,CAAC,CAAC,CAAC;KACX;IACD,oDAAoD;IACpD,OAAO,GAAG,CAAC,MAAM,CAAC,OAAO,EAAE,IAAI,CAAC,CAAC;AACnC,CAAC;AAED;;;;;;GAMG;AACH,MAAM,UAAU,oBAAoB,CAAC,CAAS,EAAE,CAAS;IACvD,QAAQ,CAAC,CAAC,IAAI,EAAE;QACd,KAAK,CAAC;YACJ,OAAO,GAAG,CAAC,QAAQ,CAAC,CAAC,CAAa,EAAE,CAAa,CAAC,CAAC,CAAC;QACtD,KAAK,CAAC;YACJ,OAAO,GAAG,CAAC,QAAQ,CAAC,CAAC,CAAa,EAAE,CAAa,CAAC,EAAE,CAAC,CAAC,CAAC;QACzD,KAAK,CAAC;YACJ,OAAO,GAAG,CAAC,QAAQ,CAAC,CAAC,CAAa,EAAE,CAAa,CAAC,EAAE,CAAC,CAAC,CAAC;QACzD,KAAK,CAAC;YACJ,OAAO,GAAG,CAAC,QAAQ,CAAC,CAAC,CAAa,EAAE,CAAa,CAAC,EAAE,CAAC,CAAC,CAAC;QACzD;YACE,MAAM,IAAI,UAAU,CAChB,iDAAiD;gBACjD,gBAAgB,CAAC,CAAC,IAAI,EAAE,CAAC,CAAC;KACjC;AACH,CAAC;AAED;;;;;;GAMG;AACH,MAAM,UAAU,IAAI,CAAC,CAAS,EAAE,CAAkB;IAChD,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE;QACrB,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC;KACT;IACD,IAAI,CAAC,CAAC,IAAI,KAAK,CAAC,CAAC,MAAM,EAAE;QACvB,MAAM,IAAI,U
AAU,CAChB,0BAA0B,CAAC,CAAC,MAAM,mBAAmB;YACrD,wCAAwC,CAAC,CAAC,IAAI,GAAG,CAAC,CAAC;KACxD;IACD,OAAO,GAAG,CAAC,IAAI,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;AACxB,CAAC;AAED,iCAAiC;AAEjC;;;;;;;;;GASG;AACH,MAAM,UAAU,YAAY,CACxB,KAAY,EAAE,IAAI,GAAG,GAAG,EAAE,MAAM,GAAG,GAAG,EAAE,KAAyB,EACjE,IAAa;IACf,OAAO,GAAG,CAAC,YAAY,CAAC,KAAK,EAAE,IAAI,EAAE,MAAM,EAAE,KAAK,EAAE,IAAI,CAAC,CAAC;AAC5D,CAAC;AAED,oBAAoB;AAEpB;;;;;;;;;;;;;;;GAeG;AACH,MAAM,UAAU,GAAG,CACf,CAAS,EAAE,CAAS,EAAE,UAAiC,EACvD,IAAa;IACf,IAAI,CAAC,CAAC,CAAC,IAAI,GAAG,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,IAAI,GAAG,CAAC,CAAC,EAAE;QAChC,MAAM,IAAI,mBAAmB,CACzB,0CAA0C;YAC1C,sBAAsB,CAAC,CAAC,KAAK,kBAAkB,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC;KAC/D;IACD,IAAI,CAAC,CAAC,IAAI,IAAI,CAAC,EAAE;QACf,MAAM,QAAQ,GAAG,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QACtC,MAAM,cAAc,GAAG,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAC5C,IAAI,QAAQ,KAAK,cAAc,EAAE;YAC/B,MAAM,IAAI,mBAAmB,CACzB,0CAA0C;gBAC1C,wDACI,CAAC,CAAC,KAAK,OAAO;gBAClB,cAAc,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC;SAC9B;KACF;IACD,6BAA6B;IAC7B,IAAI,CAAC,CAAC,CAAC,IAAI,KAAK,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,IAAI,KAAK,CAAC,CAAC,EAAE;QACpC,MAAM,UAAU,GAAG,KAAK,CAAC;QACzB,MAAM,UAAU,GAAG,KAAK,CAAC;QACzB,wEAAwE;QACxE,qEAAqE;QACrE,yBAAyB;QACzB,OAAO,GAAG,CAAC,KAAK,CAAC,MAAM,CAAC;YACtB,CAAC;YACD,CAAC,EAAE,CAAa;YAChB,UAAU;YACV,UAAU;YACV,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,WAAW,CAAC,CAAC,CAAC,IAAI,EAAE,IAAI,EAAE,eAAe,EAAE,CAAC,CAAC,CAAC,CAAC,IAAI;YAChE,UAAU;SACX,CAAC,CAAC;KACJ;SAAM;QACL,0CAA0C;QAC1C,MAAM,UAAU,GAAG,CAAC,CAAC,KAAK,CAAC,KAAK,EAAE,CAAC,CAAE,mCAAmC;QACxE,MAAM,QAAQ,GAAG,UAAU,CAAC,GAAG,EAAE,CAAC;QAClC,CAAC,GAAG,GAAG,CAAC,OAAO,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,CAAC,CAAC;QAEnC,gEAAgE;QAChE,qDAAqD;QACrD,MAAM,MAAM,GAAG,CAAC,CAAC,KAAK,CAAC,KAAK,EAAE,CAAC;QAC/B,MAAM,QAAQ,GAAG,MAAM,CAAC,GAAG,EAAE,CAAC;QAC9B,MAAM,cAAc,GAAG,MAAM,CAAC,GAAG,EAAE,CAAC;QACpC,MAAM,UAAU,GAAG,CAAC,GAAG,MAAM,EAAE,QAAQ,CAAC,CAAC;QACzC,+DAA+D;QAC/D,4BAA4B;QAC5B,MAAM,IAAI,GAAG,KAAK,CAAC,IAAI,CAAC,EAAC,MAAM,EAAE
,CAAC,CAAC,IAAI,EAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE;YACjD,IAAI,CAAC,KAAK,CAAC,EAAE;gBACX,OAAO,CAAC,CAAC,IAAI,GAAG,CAAC,CAAC;aACnB;iBAAM,IAAI,CAAC,IAAI,CAAC,CAAC,IAAI,GAAG,CAAC,EAAE;gBAC1B,OAAO,CAAC,GAAG,CAAC,CAAC;aACd;YACD,OAAO,CAAC,CAAC;QACX,CAAC,CAAC,CAAC;QACH,CAAC,GAAG,GAAG,CAAC,OAAO,CAAC,GAAG,CAAC,SAAS,CAAC,CAAC,EAAE,IAAI,CAAC,EAAE,CAAC,cAAc,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;QAE9D,qEAAqE;QACrE,MAAM,WAAW,GAAG,CAAC,GAAG,UAAU,EAAE,GAAG,UAAU,CAAC,CAAC;QACnD,MAAM,UAAU,GAAG,KAAK,CAAC;QACzB,MAAM,UAAU,GAAG,KAAK,CAAC;QACzB,OAAO,GAAG,CAAC,OAAO,CACd,GAAG,CAAC,KAAK,CAAC,MAAM,CAAC;YACf,CAAC;YACD,CAAC;YACD,UAAU;YACV,UAAU;YACV,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,WAAW,CAAC,CAAC,CAAC,IAAI,EAAE,IAAI,EAAE,eAAe,EAAE,CAAC,CAAC,CAAC,CAAC,IAAI;YAChE,UAAU;SACX,CAAC,EACF,WAAW,CAAC,CAAC;KAClB;AACH,CAAC;AAED;;;;;;;;;GASG;AACH,MAAM,UAAU,IAAI,CAAC,CAAS;IAC5B,gCAAgC;IAChC,OAAO,IAAI,CAAC,GAAG,EAAE;QACf,MAAM,UAAU,GAAG,aAAa,CAAC,CAAC,CAAC,CAAC;QACpC,MAAM,SAAS,GAAG,YAAY,CAAC,CAAC,CAAC,CAAC;QAClC,OAAO,KAAK,CACR,GAAG,CAAC,KAAK,CAAC,CAAC,EAAE,UAAU,CAAC,EAAE,UAAU,EACpC,KAAK,CACD,GAAG,CAAC,OAAO,CAAC,CAAC,EAAE,aAAa,CAAC,CAAC,CAAC,CAAC,EAAE,SAAS,EAC3C,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC;IACnC,CAAC,CAAC,CAAC;AACL,CAAC;AAED;;;;;;;GAOG;AACH,MAAM,UAAU,MAAM,CAAC,OAAe,EAAE,UAAkB;IACxD,OAAO,IAAI,CAAC,GAAG,EAAE;QACf,IAAI,OAAO,CAAC,IAAI,KAAK,CAAC,EAAE;YACtB,MAAM,IAAI,KAAK,CACX,+CAA+C;gBAC/C,gCAAgC,CAAC,CAAC;SACvC;QACD,OAAO,GAAG,GAAG,CAAC,IAAI,CAAC,OAAO,EAAE,OAAO,CAAC,CAAC;QACrC,OAAO,GAAG,CAAC,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,OAAmB,EAAE,UAAU,CAAC,EAAE,SAAS,CAAC,CAAC;IAC1E,CAAC,CAAC,CAAC;AACL,CAAC;AAED,gCAAgC;AAEhC;;;;;;GAMG;AACH,MAAM,UAAU,MAAM,CAClB,SAAiB,EAAE,OAA0B,EAAE,IAAa;IAC9D,OAAO,IAAI,CAAC,GAAG,EAAE;QACf,IAAI,KAAK,CAAC,OAAO,CAAC,OAAO,CAAC,EAAE;YAC1B,OAAO,GAAG,QAAQ,CAAC,OAAO,EAAE,OAAO,CAAC,CAAC;SACtC;aAAM;YACL,OAAO,GAAG,GAAG,CAAC,IAAI,CAAC,OAAO,EAAE,OAAO,CAAC,CAAC;SACtC;QACD,OAAO,GAAG,CAAC,MAAM,CAAC,SAAS,EAAE,OAAO,EAAE,IAAI,CAAC,CAAC;IAC9C,CAAC,CAAC,CAAC;AACL,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU
,MAAM,CAAC,CAAS;IAC9B,OAAO,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;AACvB,CAAC;AAED;;;;;;;;;;;GAWG;AACH,MAAM,UAAU,GAAG,CAAC,CAAS,EAAE,CAAgB;IAC7C,OAAO,IAAI,CAAC,GAAG,EAAE;QACf,IAAI,OAAO,CAAC,CAAC,CAAC,KAAK,QAAQ,EAAE;YAC3B,CAAC,GAAG,MAAM,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;SACpC;QACD,IAAI,CAAC,CAAC,KAAK,KAAK,OAAO,EAAE;YACvB,MAAM,IAAI,mBAAmB,CACzB,oBAAoB,CAAC,CAAC,KAAK,iCAAiC,CAAC,CAAC;SACnE;QACD,OAAO,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;IACvB,CAAC,CAAC,CAAC;AACL,CAAC;AAED;;GAEG;AACH,SAAS,WAAW,CAAC,KAAa,EAAE,IAAY,EAAE,UAAkB;IAClE,MAAM,SAAS,GAAG,IAAI,CAAC,KAAK,CAAC;IAE7B,IAAI,IAAI,CAAC,IAAI,KAAK,CAAC,IAAI,IAAI,CAAC,IAAI,KAAK,KAAK,EAAE;QAC1C,MAAM,IAAI,UAAU,CAChB,+BAA+B,IAAI,CAAC,IAAI,EAAE;YAC1C,4BAA4B,KAAK,EAAE,CAAC,CAAC;KAC1C;IAED,IAAI,KAAK,KAAK,CAAC,EAAE;QACf,IAAI,UAAU,KAAK,eAAe,EAAE;YAClC,IAAI,SAAS,CAAC,MAAM,KAAK,CAAC,EAAE;gBAC1B,OAAO,GAAG,CAAC,OAAO,CAAC,IAAI,EAAE,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;aACtD;iBAAM;gBACL,OAAO,GAAG,CAAC,OAAO,CACd,IAAI,EAAE,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;aACxE;SACF;aAAM,IAAI,UAAU,KAAK,cAAc,EAAE;YACxC,IAAI,SAAS,CAAC,MAAM,KAAK,CAAC,EAAE;gBAC1B,OAAO,GAAG,CAAC,OAAO,CAAC,IAAI,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;aACtD;iBAAM;gBACL,OAAO,GAAG,CAAC,OAAO,CAAC,IAAI,EAAE,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,SAAS,CAAC,CAAC,CAAC;aACjD;SACF;KACF;SAAM,IAAI,KAAK,KAAK,CAAC,EAAE;QACtB,IAAI,UAAU,KAAK,eAAe,EAAE;YAClC,IAAI,SAAS,CAAC,MAAM,KAAK,CAAC,EAAE;gBAC1B,OAAO,GAAG,CAAC,OAAO,CAAC,IAAI,EAAE,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;aACnD;iBAAM;gBACL,OAAO,GAAG,CAAC,OAAO,CAAC,IAAI,EAAE,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;aACzE;SACF;aAAM,IAAI,UAAU,KAAK,cAAc,EAAE;YACxC,IAAI,SAAS,CAAC,MAAM,KAAK,CAAC,EAAE;gBAC1B,OAAO,GAAG,CAAC,OAAO,CAAC,IAAI,EAAE,CAAC,CAAC,EAAE,CAAC,E
AAE,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;aACnD;iBAAM;gBACL,OAAO,GAAG,CAAC,OAAO,CAAC,IAAI,EAAE,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,SAAS,CAAC,CAAC,CAAC;aACjD;SACF;KACF;SAAM,IAAI,KAAK,KAAK,CAAC,EAAE;QACtB,IAAI,UAAU,KAAK,eAAe,EAAE;YAClC,IAAI,SAAS,CAAC,MAAM,KAAK,CAAC,EAAE;gBAC1B,OAAO,GAAG,CAAC,OAAO,CAAC,IAAI,EAAE,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;aAChD;iBAAM;gBACL,OAAO,GAAG,CAAC,OAAO,CAAC,IAAI,EAAE,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;aAC3D;SACF;aAAM,IAAI,UAAU,KAAK,cAAc,EAAE;YACxC,IAAI,SAAS,CAAC,MAAM,KAAK,CAAC,EAAE;gBAC1B,OAAO,GAAG,CAAC,OAAO,CAAC,IAAI,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;aAChD;iBAAM;gBACL,OAAO,GAAG,CAAC,OAAO,CAAC,IAAI,EAAE,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,SAAS,CAAC,CAAC,CAAC;aACjD;SACF;KACF;SAAM,IAAI,KAAK,GAAG,CAAC,EAAE;QACpB,OAAO,IAAI,CAAC;KACb;IACD,MAAM,IAAI,UAAU,CAAC,sCAAsC,IAAI,CAAC,IAAI,EAAE,CAAC,CAAC;AAC1E,CAAC;AAED,gCAAgC;AAEhC;;;;;;;GAOG;AACH,MAAM,UAAU,OAAO,CACnB,CAAS,EAAE,IAAY,EAAE,UAAuB;IAClD,OAAO,IAAI,CAAC,GAAG,EAAE;QACf,IAAI,UAAU,IAAI,IAAI,EAAE;YACtB,UAAU,GAAG,eAAe,EAAE,CAAC;SAChC;QACD,eAAe,CAAC,UAAU,CAAC,CAAC;QAE5B,OAAO,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,WAAW,CAAC,CAAC,CAAC,IAAI,EAAE,IAAI,EAAE,UAAU,CAAC,CAAC,CAAC;IAC3D,CAAC,CAAC,CAAC;AACL,CAAC;AAED;;;;;GAKG;AACH,MAAM,UAAU,GAAG,CAAC,CAAS,EAAE,KAAK,GAAG,CAAC;IACtC,yDAAyD;IACzD,IAAI,KAAK,KAAK,CAAC,EAAE;QACf,MAAM,IAAI,mBAAmB,CACzB,0CAA0C,KAAK,uBAAuB;YACtE,MAAM,CAAC,CAAC;KACb;IACD,OAAO,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC;AACpB,CAAC;AAED;;;;;;;GAOG;AACH,MAAM,UAAU,QAAQ,CAAC,CAAS;IAChC,OAAO,IAAI,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;AACxD,CAAC;AAED;;;;;;;;;GASG;AACH,MAAM,UAAU,OAAO,CACnB,CAAS,EAAE,KAAa,EAAE,UAAqB,EAAE,IAAa;IAChE,OAAO,IAAI,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,OAAO,CAAC,CAAC,EAAE,KAAK,EAAE,UAAU,EAAE,IAAI,CAAC,CAAC,CAAC;AAC7D,CAAC;AAED;;;;;;;;GAQG;AACH,MAAM,UAAU,WAAW,CAAC,CAAS;IACnC,OAAO,IAAI,CAAC,GAAG,EAAE;QACf,MAAM,CA
AC,GAAG,GAAG,CAAC,GAAG,CAAC,EAAE,EAAE,GAAG,CAAC,GAAG,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC;QACtC,OAAO,GAAG,CAAC,WAAW,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;IAClC,CAAC,CAAC,CAAC;AACL,CAAC;AAED;;;;;;;;;;;;GAYG;AACH,MAAM,UAAU,YAAY,CAAI,CAAU,EAAE,GAAY,EAAE,QAAQ,GAAG,KAAK;IACxE,OAAO,QAAQ,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC;AAChC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n/**\n * deeplearn.js backend.\n */\n\nimport * as tfc from '@tensorflow/tfjs-core';\nimport {onesLike as coreOnesLike, scalar, Tensor, Tensor1D, tensor1d, Tensor2D, Tensor3D, Tensor4D, Tensor5D, tidy, where, zerosLike as coreZerosLike} from '@tensorflow/tfjs-core';\nimport {checkDataFormat} from '../common';\nimport {NotImplementedError, ValueError} from '../errors';\nimport {DataFormat, Shape} from '../keras_format/common';\nimport {HasShape} from '../types';\nimport * as math_utils from '../utils/math_utils';\n\nimport {imageDataFormat} from './common';\n\n// tslint:enable\n\n/* Setting and getting backend from deeplearn.js. */\n\n// Default deeplearn.js backend is WebGL (GPU).\nlet backend: 'cpu'|'webgl' = 'webgl';\n\nexport function setBackend(requestedBackend: 'cpu'|'webgl') {\n  tfc.setBackend(requestedBackend);\n  backend = requestedBackend;\n}\n\nexport function getBackend(): 'cpu'|'webgl' {\n  return backend;\n}\n\n/**\n * Indicates whether the backend is operating symbolically.\n *\n * This function will be used to determine how to interpret user code. 
If\n * it returns true, calls to the backend construct a symbolic graph; if\n * it returns false, calls to the backend execute immediately.\n */\nexport function isBackendSymbolic(): boolean {\n  return false;\n}\n\n/**\n * Get the number of elements in a Tensor.\n * @param x The Tensor.\n * @return Number of elements in `x`.\n */\nexport function countParams(x: HasShape): number {\n  const shape = x.shape;\n  if (shape.length > 0) {\n    return shape.reduce((a: number, b: number) => a * b);\n  } else {\n    // Scalar.\n    return 1;\n  }\n}\n\n/**\n * Casts a tensor to a different dtype and returns it.\n * @param x Input tensor.\n * @param dtype String: 'float32'|'int32'|'bool'.\n * @returns Tensor of the specified `dtype`.\n */\nexport function cast(x: Tensor, dtype: tfc.DataType): Tensor {\n  return tfc.cast(x, dtype);\n}\n\n/**\n * Adds a 1-sized dimension at index \"axis\".\n * @param x Input tensor.\n * @param axis Position where to add the new axis.\n * @returns Result of the dimension expansion.\n */\nexport function expandDims(x: Tensor, axis = -1): Tensor {\n  const outShape = x.shape.slice();\n  if (axis < 0) {\n    axis = outShape.length + axis + 1;\n  }\n  outShape.splice(axis, 0, 1);\n  return tfc.reshape(x, outShape);\n}\n\n/**\n * Repeats a 2D tensor.\n *\n * If `x` has shape `[samples, dim]` and `n` is 2, for example, the output\n * will have shape `[samples, 2, dim]`.\n *\n * @param x Input tensor.\n * @param n Integer, number of times to repeat.\n * @returns The result of the repeat operation.\n * @throws ValueError: If input tensor is not 2D.\n */\nexport function repeat(x: Tensor, n: number): Tensor {\n  return tidy(() => {\n    if (x.shape.length !== 2) {\n      throw new ValueError(\n          `repeat() expects a rank-2 tensor, but received a ` +\n          `rank-${x.shape.length} tensor.`);\n    }\n    const y = expandDims(x, 1);\n    return tile(y, [1, n, 1]);\n  });\n}\n\n/**\n * Flatten a Tensor into 1D.\n * @param x Input tensor.\n * 
@return The result of the flattening `x`.\n */\nexport function flatten(x: Tensor): Tensor {\n  const newShape = [math_utils.arrayProd(x.shape)];\n  return tfc.reshape(x, newShape);\n}\n\n/**\n * Turn a nD tensor into a 2D tensor with same 0th dimension.\n * In other words, it flattens each data samples of a batch.\n *\n * @param x The tensor to flatten. The rank of this tensor is required to be 2\n *   or higher.\n * @return The result of the flattening.\n */\nexport function batchFlatten(x: Tensor): Tensor {\n  if (x.rank <= 1) {\n    throw new ValueError(\n        `batchFlatten requires a minimum rank of 2. Got rank: ${x.rank}.`);\n  }\n  const newShape = [x.shape[0], math_utils.arrayProd(x.shape, 1)];\n  return tfc.reshape(x, newShape);\n}\n\n/**\n * Do slicing along the first axis.\n * @param array input `tf.Tensor`.\n * @param start starting index, inclusive.\n * @param size size of the slice along the first axis.\n * @returns result of the slicing.\n * @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`.\n */\nexport function sliceAlongFirstAxis(\n    array: Tensor, start: number, size: number): Tensor {\n  return tidy(() => {\n    switch (array.rank) {\n      case 1:\n        return tfc.slice1d(array as Tensor1D, start, size);\n      case 2:\n        return tfc.slice2d(\n            array as Tensor2D, [start, 0], [size, array.shape[1]]);\n      case 3:\n        return tfc.slice3d(\n            array as Tensor3D, [start, 0, 0],\n            [size, array.shape[1], array.shape[2]]);\n      case 4:\n        return tfc.slice4d(\n            array as Tensor4D, [start, 0, 0, 0],\n            [size, array.shape[1], array.shape[2], array.shape[3]]);\n      case 5:\n        return tfc.slice(array as Tensor5D, [start, 0, 0, 0, 0], [\n          size, array.shape[1], array.shape[2], array.shape[3], array.shape[4]\n        ]);\n      case 6:\n        return tfc.slice(array, [start, 0, 0, 0, 0, 0], [\n          size, array.shape[1], array.shape[2], 
array.shape[3], array.shape[4],\n          array.shape[5]\n        ]);\n      default:\n        throw new ValueError(\n            `sliceAlongFirstAxis() received an unsupported tensor rank: ` +\n            `${array.rank}`);\n    }\n  });\n}\n\n/**\n * Do slicing along the last axis.\n * @param array input `tf.Tensor`.\n * @param start starting index, inclusive.\n * @param size size of the slice along the last axis.\n * @returns result of the slicing.\n * @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`.\n */\nexport function sliceAlongLastAxis(\n    array: Tensor, start: number, size: number): Tensor {\n  return tidy(() => {\n    switch (array.rank) {\n      case 1:\n        return tfc.slice1d(array as Tensor1D, start, size);\n      case 2:\n        return tfc.slice2d(\n            array as Tensor2D, [0, start], [array.shape[0], size]);\n      case 3:\n        return tfc.slice3d(\n            array as Tensor3D, [0, 0, start],\n            [array.shape[0], array.shape[1], size]);\n      case 4:\n        return tfc.slice4d(\n            array as Tensor4D, [0, 0, 0, start],\n            [array.shape[0], array.shape[1], array.shape[2], size]);\n      default:\n        throw new ValueError(\n            `sliceAlongLastAxis() received an unsupported tensor rank: ` +\n            `${array.rank}`);\n    }\n  });\n}\n\n/**\n * Do slicing along the sepcified axis.\n * @param array input `tf.Tensor`.\n * @param start starting index, inclusive.\n * @param size of the slice along the chosen axis.\n * @param choose an axis.\n * @returns result of the slicing.\n * @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`.\n */\nexport function sliceAlongAxis(\n    array: Tensor, start: number, size: number, axis: number): Tensor {\n  return tidy(() => {\n    switch (array.rank) {\n      case 1:\n        return tfc.slice1d(array as Tensor1D, start, size);\n      case 2:\n        switch (axis) {\n          case 1:\n            return 
sliceAlongFirstAxis(array, start, size);\n          case 2:\n            return sliceAlongLastAxis(array, start, size);\n          default:\n            throw new ValueError(\n                `The axis is not within the rank of the tensor ` +\n                `${axis}`);\n        }\n      case 3:\n        switch (axis) {\n          case 1:\n            return sliceAlongFirstAxis(array, start, size);\n          case 2:\n            return tfc.slice3d(\n                array as Tensor3D, [0, start, 0],\n                [array.shape[0], size, array.shape[2]]);\n          case 3:\n            return sliceAlongLastAxis(array, start, size);\n          default:\n            throw new ValueError(\n                `The axis is not within the rank of the tensor ` +\n                `${axis}`);\n        }\n      case 4:\n        switch (axis) {\n          case 1:\n            return sliceAlongFirstAxis(array, start, size);\n          case 2:\n            return tfc.slice4d(\n                array as Tensor4D, [0, start, 0, 0],\n                [array.shape[0], size, array.shape[2], array.shape[3]]);\n          case 3:\n            return tfc.slice4d(\n                array as Tensor4D, [0, 0, start, 0],\n                [array.shape[0], array.shape[1], size, array.shape[3]]);\n          case 4:\n            return sliceAlongLastAxis(array, start, size);\n          default:\n            throw new ValueError(\n                `The axis is not within the rank of the tensor ` +\n                `${axis}`);\n        }\n      default:\n        throw new ValueError(\n            `sliceAlongLastAxis() received an unsupported tensor rank: ` +\n            `${array.rank}`);\n    }\n  });\n}\n\n/**\n * Concatenates a list of tensors alongside the specified axis.\n * @param tensors `Array` of tensors to concatenate.\n * @param axis Concatenation axis.\n * @returns The result of the concatenation.\n */\nexport function concatenate(tensors: Tensor[], axis = -1): Tensor {\n  let rank: 
number;\n  if (axis < 0) {\n    rank = tensors[0].rank;\n    if (rank !== 0) {\n      axis = rank;\n    } else {\n      axis = 0;\n    }\n  }\n  if (axis === tensors[0].rank) {\n    // Porting Note: This is necessary because tfc.concat() requires axis to be\n    //   in the interval [-rank, rank).\n    axis = -1;\n  }\n  // Porting Note: Sparse concat is not supported yet.\n  return tfc.concat(tensors, axis);\n}\n\n/**\n * Concatenate two arrays along the first dimension.\n * @param a The 1st `tf.Tensor` to concatenate.\n * @param b The 2nd `tf.Tensor` to concatenate.\n * @returns Result of the concatenation.\n * @throws ValueError: If `a` is of an unsupported subtype of `tf.Tensor`.\n */\nexport function concatAlongFirstAxis(a: Tensor, b: Tensor): Tensor {\n  switch (a.rank) {\n    case 1:\n      return tfc.concat1d([a as Tensor1D, b as Tensor1D]);\n    case 2:\n      return tfc.concat2d([a as Tensor2D, b as Tensor2D], 0);\n    case 3:\n      return tfc.concat3d([a as Tensor3D, b as Tensor3D], 0);\n    case 4:\n      return tfc.concat4d([a as Tensor4D, b as Tensor4D], 0);\n    default:\n      throw new ValueError(\n          `concatAlongFirstAxis() received an unsupported ` +\n          `tensor rank: ${a.rank}`);\n  }\n}\n\n/**\n * Creates a tensor by tiling `x` by `n`.\n * @param x A tensor.\n * @param n An Array of integers or a single integer. If an Array, the length\n *   must be the same as the number of dimensions in `x`. If a single integer,\n *   it will be treated as an Array of length 1.\n */\nexport function tile(x: Tensor, n: number|number[]): Tensor {\n  if (!Array.isArray(n)) {\n    n = [n];\n  }\n  if (x.rank !== n.length) {\n    throw new ValueError(\n        `The length of input n (${n.length}) does not match ` +\n        `the number of dimensions in input x (${x.rank})`);\n  }\n  return tfc.tile(x, n);\n}\n\n/* Creation of random tensors. 
*/\n\n/**\n * Get a tensor with normal distribution of values.\n *\n * @param shape Shape of the tensor.\n * @param mean mean value of the normal distribution.\n * @param stddev standard deviation of the normal distribution.\n * @param dtype\n * @param seed\n * @return The normal tensor.\n */\nexport function randomNormal(\n    shape: Shape, mean = 0.0, stddev = 1.0, dtype?: 'float32'|'int32',\n    seed?: number): Tensor {\n  return tfc.randomNormal(shape, mean, stddev, dtype, seed);\n}\n\n/* Linear Algebra */\n\n/**\n * Multiply two tensors and returns the result as a tensor.\n *\n * For 2D tensors, this is equivalent to matrix multiplication (matMul).\n * For tensors of higher ranks, it follows the Theano behavior,\n * (e.g. `(2, 3) * (4, 3, 5) -> (2, 4, 5)`).  From the Theano documentation:\n *\n * For N dimensions it is a sum product over the last axis of x and the\n * second-to-last of y:\n *\n * @param a A tensor of at least rank 2.\n * @param b A tensor of at least rank 2.\n * @param activation (optional) A string identifying the activation\n *   function.\n * @return Result of the dot operation.\n */\nexport function dot(\n    a: Tensor, b: Tensor, activation?: tfc.fused.Activation,\n    bias?: Tensor): Tensor {\n  if ((a.rank < 2) || (b.rank < 2)) {\n    throw new NotImplementedError(\n        `dot requires both inputs to be rank >= 2` +\n        ` but got x shape = ${a.shape} and y shape = ${b.shape}`);\n  }\n  if (b.rank >= 3) {\n    const xLastDim = a.shape.slice(-1)[0];\n    const ySecondLastDim = b.shape.slice(-2)[0];\n    if (xLastDim !== ySecondLastDim) {\n      throw new NotImplementedError(\n          `If rank y >= 3, then the second last dim` +\n          ` of y must equal the last dim of x but got x shape = ${\n              a.shape} and ` +\n          ` y shape = ${b.shape}`);\n    }\n  }\n  // Handle basic 2D x 2D case.\n  if ((a.rank === 2) && (b.rank === 2)) {\n    const transposeA = false;\n    const transposeB = false;\n    // 
tfc.fused.matMul only fuses certain activation functions. Unsupported\n    // activation functions are treated as 'linear' activations, which is\n    // equivalent to a no-op.\n    return tfc.fused.matMul({\n      a,\n      b: b as Tensor2D,\n      transposeA,\n      transposeB,\n      bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null,\n      activation\n    });\n  } else {\n    // Reshape x into the analogous 2D Tensor.\n    const aFirstDims = a.shape.slice();  // Holds all but the last dim of x.\n    const aLastDim = aFirstDims.pop();\n    a = tfc.reshape(a, [-1, aLastDim]);\n\n    // Reshape y into the analogous 2D Tensor, and keep track of the\n    // required dimensions to reproduce the output shape.\n    const bShape = b.shape.slice();\n    const bLastDim = bShape.pop();\n    const ySecondLastDim = bShape.pop();\n    const yOtherDims = [...bShape, bLastDim];\n    // permutation should be like [r-2, 0, 1, 2, ... r-4, r-3, r-1]\n    // where r is the rank of y.\n    const perm = Array.from({length: b.rank}, (_, i) => {\n      if (i === 0) {\n        return b.rank - 2;\n      } else if (i <= b.rank - 2) {\n        return i - 1;\n      }\n      return i;\n    });\n    b = tfc.reshape(tfc.transpose(b, perm), [ySecondLastDim, -1]);\n\n    // Multiply x and y as 2D Tensors, and then reshape back to original.\n    const outputShape = [...aFirstDims, ...yOtherDims];\n    const transposeA = false;\n    const transposeB = false;\n    return tfc.reshape(\n        tfc.fused.matMul({\n          a,\n          b,\n          transposeA,\n          transposeB,\n          bias: bias ? 
reshapeBias(a.rank, bias, imageDataFormat()) : null,\n          activation\n        }),\n        outputShape);\n  }\n}\n\n/**\n * Compute the sign Tensor of an input Tensor.\n *\n * Elements of the input `tf.Tensor` that are === 0 are mapped to 0.\n * Elements of the input `tf.Tensor` that are > 0 are mapped to 1.\n * Elements of the input `tf.Tensor` that are < 0 are mapped to -1.\n *\n * @param x Input `tf.Tensor`.\n * @return The sign `tf.Tensor`.\n */\nexport function sign(x: Tensor): Tensor {\n  // TODO(cais): Move to the core.\n  return tidy(() => {\n    const zerosLikeX = coreZerosLike(x);\n    const onesLikeX = coreOnesLike(x);\n    return where(\n        tfc.equal(x, zerosLikeX), zerosLikeX,\n        where(\n            tfc.greater(x, coreZerosLike(x)), onesLikeX,\n            tfc.mul(-1, onesLikeX)));\n  });\n}\n\n/**\n * Computes the one-hot representation of an integer tensor.\n * @param indices nD integer tensor of shape\n *   `(batch_size, dim1, dim2, ... dim(n-1))`\n * @param numClasses Integer, number of classes to consider.\n * @returns (n + 1)D one hot representation of the input\n *   with shape `(batch_size, dim1, dim2, ... dim(n-1), num_classes)`\n */\nexport function oneHot(indices: Tensor, numClasses: number): Tensor {\n  return tidy(() => {\n    if (indices.rank !== 1) {\n      throw new Error(\n          'Only 1D one-hot tensors are supported in the ' +\n          'deeplearn backend, at present.');\n    }\n    indices = tfc.cast(indices, 'int32');\n    return tfc.cast(tfc.oneHot(indices as Tensor1D, numClasses), 'float32');\n  });\n}\n\n/* Elementary math functions. 
*/\n\n/**\n * Retrieves the elements of indices `indices` in the tensor `reference`.\n * @param reference A tensor.\n * @param indices An integer tensor of indices or an `Array` of integers.\n * @param axis Axis along which to perform the gather operation.\n * @returns The result of the gathering as a tensor.\n */\nexport function gather(\n    reference: Tensor, indices: number[]|Tensor1D, axis?: number): Tensor {\n  return tidy(() => {\n    if (Array.isArray(indices)) {\n      indices = tensor1d(indices, 'int32');\n    } else {\n      indices = tfc.cast(indices, 'int32');\n    }\n    return tfc.gather(reference, indices, axis);\n  });\n}\n\n/**\n * Element-wise square.\n * @param x Input tensor.\n * @return element-wise x^2\n */\nexport function square(x: Tensor): Tensor {\n  return tfc.mul(x, x);\n}\n\n/**\n * Element-wise exponentiation.\n *\n * Porting Note: In PyKeras, `a` (the exponent) is a Python integer, which\n *   takes advatnage of the backend's (e.g., TensorFlow's) automatic\n * conversion to tensor. Here we allow `a` to be either a number or a tensor.\n *\n * @param x The base tensor.\n * @param a The exponent, tensor or number. 
If a number, it is rounded to the\n *   nearest integer and converted to a tensor.\n * @returns A tensor of the same shape as `x`.\n */\nexport function pow(x: Tensor, a: Tensor|number): Tensor {\n  return tidy(() => {\n    if (typeof (a) === 'number') {\n      a = scalar(Math.round(a), 'int32');\n    }\n    if (a.dtype !== 'int32') {\n      throw new NotImplementedError(\n          `Non-int32 dtype (${a.dtype}) is not supported by pow() yet`);\n    }\n    return tfc.pow(x, a);\n  });\n}\n\n/**\n * Reshapes bias tensor according to rank of x.\n */\nfunction reshapeBias(xRank: number, bias: Tensor, dataFormat: string) {\n  const biasShape = bias.shape;\n\n  if (bias.rank !== 1 && bias.rank !== xRank) {\n    throw new ValueError(\n        `Unexpected bias dimensions: ${bias.rank}` +\n        `; expected it to be 1 or ${xRank}`);\n  }\n\n  if (xRank === 5) {\n    if (dataFormat === 'channelsFirst') {\n      if (biasShape.length === 1) {\n        return tfc.reshape(bias, [1, biasShape[0], 1, 1, 1]);\n      } else {\n        return tfc.reshape(\n            bias, [1, biasShape[3], biasShape[0], biasShape[1], biasShape[2]]);\n      }\n    } else if (dataFormat === 'channelsLast') {\n      if (biasShape.length === 1) {\n        return tfc.reshape(bias, [1, 1, 1, 1, biasShape[0]]);\n      } else {\n        return tfc.reshape(bias, [1].concat(biasShape));\n      }\n    }\n  } else if (xRank === 4) {\n    if (dataFormat === 'channelsFirst') {\n      if (biasShape.length === 1) {\n        return tfc.reshape(bias, [1, biasShape[0], 1, 1]);\n      } else {\n        return tfc.reshape(bias, [1, biasShape[2], biasShape[0], biasShape[1]]);\n      }\n    } else if (dataFormat === 'channelsLast') {\n      if (biasShape.length === 1) {\n        return tfc.reshape(bias, [1, 1, 1, biasShape[0]]);\n      } else {\n        return tfc.reshape(bias, [1].concat(biasShape));\n      }\n    }\n  } else if (xRank === 3) {\n    if (dataFormat === 'channelsFirst') {\n      if (biasShape.length 
=== 1) {\n        return tfc.reshape(bias, [1, biasShape[0], 1]);\n      } else {\n        return tfc.reshape(bias, [1, biasShape[1], biasShape[0]]);\n      }\n    } else if (dataFormat === 'channelsLast') {\n      if (biasShape.length === 1) {\n        return tfc.reshape(bias, [1, 1, biasShape[0]]);\n      } else {\n        return tfc.reshape(bias, [1].concat(biasShape));\n      }\n    }\n  } else if (xRank < 3) {\n    return bias;\n  }\n  throw new ValueError(`Unsupported input rank by biasAdd: ${bias.rank}`);\n}\n\n/* Neural-network operations. */\n\n/**\n * Add a bias to a tensor.\n *\n * @param x The tensor to add the bias to.\n * @param bias The bias to add to `x`. Must be 1D or the same rank as `x`.\n * @return Result of the bias adding.\n * @throws ValueError: If the rank of `bias` is incorrect.\n */\nexport function biasAdd(\n    x: Tensor, bias: Tensor, dataFormat?: DataFormat): Tensor {\n  return tidy(() => {\n    if (dataFormat == null) {\n      dataFormat = imageDataFormat();\n    }\n    checkDataFormat(dataFormat);\n\n    return tfc.add(x, reshapeBias(x.rank, bias, dataFormat));\n  });\n}\n\n/**\n * Exponential linear unit (ELU).\n * @param x A tensor or variable to compute the activation function for.\n * @param alpha: A scalar, a scaling factor for the negative section.\n * @return Output of the ELU operation.\n */\nexport function elu(x: Tensor, alpha = 1): Tensor {\n  // TODO(cais): Add support for alpha values other than 1.\n  if (alpha !== 1) {\n    throw new NotImplementedError(\n        `Support for alpha values other than 1 (${alpha}) is not implemented ` +\n        `yet.`);\n  }\n  return tfc.elu(x);\n}\n\n/**\n * Softsign of a tensor.\n *\n * Defined as x / (abs(x) + 1), element-wise.\n *\n * @param x: Input.\n * @returns Output.\n */\nexport function softsign(x: Tensor): Tensor {\n  return tidy(() => tfc.div(x, tfc.add(tfc.abs(x), 1)));\n}\n\n/**\n * Sets entries in `x` to zero at random, while scaling the entire tensor.\n *\n * @param x 
input tensor.\n * @param level fraction of the entries in the tensor that will be set to 0.\n * @param noiseShape shape of randomly generated keep/drop flags, must be\n *   broadcastable to the shape of `x`. Optional.\n * @param seed random seed to ensure determinism. Optional.\n * @returns Result of the dropout operation.\n */\nexport function dropout(\n    x: Tensor, level: number, noiseShape?: number[], seed?: number): Tensor {\n  return tidy(() => tfc.dropout(x, level, noiseShape, seed));\n}\n\n/**\n * Element-wise, segment-wise linear approximation of sigmoid.\n *\n * Returns `0.` if `x < -2.5`, `1.` if `x > 2.5`.\n * In `-2.5 <= x <= 2.5`, returns `0.2 * x + 0.5`.\n *\n * @param x Input tensor.\n * @returns Output tensor.\n */\nexport function hardSigmoid(x: Tensor): Tensor {\n  return tidy(() => {\n    const y = tfc.add(.5, tfc.mul(.2, x));\n    return tfc.clipByValue(y, 0, 1);\n  });\n}\n\n/**\n * Invoke `x` in the training phase, and `alt` otherwise.\n *\n * Porting Note: We do not create placeholder tensors for the `training`\n * boolean flag here, because there is no such thing in the TF.js imperative\n * backend.\n *\n * @param x The function to invoke iff `training` is `true`.\n * @param alt The function to invoke iff `training` is `false`.\n * @param training Boolean flag for whether training phase is active.\n * @returns The return value of `x()` if `training` is `true`, or the return\n *   value of `alt()` if `training` is `false`.\n */\nexport function inTrainPhase<T>(x: () => T, alt: () => T, training = false): T {\n  return training ? x() : alt();\n}\n"]} |
\ | No newline at end of file |