UNPKG

75.2 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2018 Google LLC
4 *
5 * Use of this source code is governed by an MIT-style
6 * license that can be found in the LICENSE file or at
7 * https://opensource.org/licenses/MIT.
8 * =============================================================================
9 */
10/**
11 * deeplearn.js backend.
12 */
13import * as tfc from '@tensorflow/tfjs-core';
14import { onesLike as coreOnesLike, scalar, tensor1d, tidy, where, zerosLike as coreZerosLike } from '@tensorflow/tfjs-core';
15import { checkDataFormat } from '../common';
16import { NotImplementedError, ValueError } from '../errors';
17import * as math_utils from '../utils/math_utils';
18import { imageDataFormat } from './common';
19// tslint:enable
/* Selecting and reporting the tfjs-core backend. */

/**
 * Name of the backend most recently passed to `setBackend()`.
 * Starts out as WebGL (GPU).
 */
let backend = 'webgl';

/**
 * Activates `requestedBackend` in tfjs-core and remembers its name so that
 * `getBackend()` can report it.
 *
 * NOTE(review): whatever `tfc.setBackend` returns is ignored; the remembered
 * name is updated as long as the call does not throw.
 *
 * @param requestedBackend Name of the backend to activate, e.g. 'webgl'.
 */
export function setBackend(requestedBackend) {
  tfc.setBackend(requestedBackend);
  backend = requestedBackend;
}

/**
 * @returns The name recorded by the last `setBackend()` call, or 'webgl' if
 *   `setBackend()` has not been called.
 */
export function getBackend() {
  const current = backend;
  return current;
}
/**
 * Reports whether backend calls build a symbolic graph (true) or execute
 * eagerly (false).
 *
 * Callers use this to decide how to interpret user code. This backend always
 * executes immediately, so the answer is a constant.
 *
 * @returns Always `false`.
 */
export function isBackendSymbolic() {
  const buildsSymbolicGraph = false;
  return buildsSymbolicGraph;
}
/**
 * Counts the elements of a tensor.
 *
 * @param x Any object exposing a `shape` array.
 * @return Product of the entries of `x.shape`, or 1 for a scalar (empty
 *   shape).
 */
export function countParams(x) {
  const { shape } = x;
  if (shape.length === 0) {
    // A scalar holds exactly one value.
    return 1;
  }
  return shape.reduce((product, dim) => product * dim);
}
/**
 * Converts a tensor to another dtype.
 *
 * Thin wrapper over `tfc.cast`.
 *
 * @param x Input tensor.
 * @param dtype Target dtype, e.g. 'float32', 'int32' or 'bool'.
 * @returns A tensor with the requested `dtype`.
 */
export function cast(x, dtype) {
  const converted = tfc.cast(x, dtype);
  return converted;
}
/**
 * Inserts a size-1 dimension into the shape of `x`.
 *
 * @param x Input tensor.
 * @param axis Position of the new dimension. A negative value counts from the
 *   end, so the default -1 appends a trailing dimension.
 * @returns `x` reshaped with the extra dimension; `x.shape` is not modified.
 */
export function expandDims(x, axis = -1) {
  const { shape } = x;
  const position = axis < 0 ? shape.length + axis + 1 : axis;
  const expandedShape = [...shape.slice(0, position), 1, ...shape.slice(position)];
  return tfc.reshape(x, expandedShape);
}
/**
 * Repeats a 2D tensor `n` times along a new middle axis.
 *
 * An input of shape `[samples, dim]` becomes `[samples, n, dim]`.
 *
 * @param x Rank-2 input tensor.
 * @param n Integer, number of repetitions.
 * @returns The repeated tensor.
 * @throws ValueError: If `x` is not rank 2.
 */
export function repeat(x, n) {
  return tidy(() => {
    const rank = x.shape.length;
    if (rank !== 2) {
      throw new ValueError(`repeat() expects a rank-2 tensor, but received a rank-${rank} tensor.`);
    }
    // [samples, dim] -> [samples, 1, dim] -> [samples, n, dim]
    const withMiddleAxis = expandDims(x, 1);
    return tile(withMiddleAxis, [1, n, 1]);
  });
}
/**
 * Reshapes a tensor into 1D.
 *
 * @param x Input tensor.
 * @return A rank-1 view of `x` whose length is the product of `x.shape`.
 */
export function flatten(x) {
  const totalSize = math_utils.arrayProd(x.shape);
  return tfc.reshape(x, [totalSize]);
}
/**
 * Collapses every dimension after the first into one.
 *
 * Each sample in a batch is flattened, so `[batch, d1, d2, ...]` becomes
 * `[batch, d1 * d2 * ...]`.
 *
 * @param x Tensor of rank 2 or higher.
 * @return The 2D result.
 * @throws ValueError: If `x` has rank 0 or 1.
 */
export function batchFlatten(x) {
  if (x.rank <= 1) {
    throw new ValueError(`batchFlatten requires a minimum rank of 2. Got rank: ${x.rank}.`);
  }
  const batchSize = x.shape[0];
  const featuresPerSample = math_utils.arrayProd(x.shape, 1);
  return tfc.reshape(x, [batchSize, featuresPerSample]);
}
/**
 * Takes a contiguous slice along the first (batch) axis, keeping every other
 * dimension whole.
 *
 * @param array Input `tf.Tensor` of rank 1 to 6.
 * @param start Starting index along the first axis, inclusive.
 * @param size Number of entries to keep along the first axis.
 * @returns The sliced tensor.
 * @throws ValueError: If `array` has a rank outside 1-6.
 */
export function sliceAlongFirstAxis(array, start, size) {
  return tidy(() => {
    const rank = array.rank;
    const trailingDims = array.shape.slice(1);
    // Begin at `start` on axis 0 and at 0 everywhere else; keep all trailing dims.
    const begin = [start, ...trailingDims.map(() => 0)];
    const extent = [size, ...trailingDims];
    if (rank === 1) {
      return tfc.slice1d(array, start, size);
    }
    if (rank === 2) {
      return tfc.slice2d(array, begin, extent);
    }
    if (rank === 3) {
      return tfc.slice3d(array, begin, extent);
    }
    if (rank === 4) {
      return tfc.slice4d(array, begin, extent);
    }
    if (rank === 5 || rank === 6) {
      return tfc.slice(array, begin, extent);
    }
    throw new ValueError(`sliceAlongFirstAxis() received an unsupported tensor rank: ${array.rank}`);
  });
}
/**
 * Takes a contiguous slice along the last axis, keeping every other
 * dimension whole.
 *
 * @param array Input `tf.Tensor` of rank 1 to 4.
 * @param start Starting index along the last axis, inclusive.
 * @param size Number of entries to keep along the last axis.
 * @returns The sliced tensor.
 * @throws ValueError: If `array` has a rank outside 1-4.
 */
export function sliceAlongLastAxis(array, start, size) {
  return tidy(() => {
    const rank = array.rank;
    const leadingDims = array.shape.slice(0, -1);
    // Begin at 0 on every leading axis and at `start` on the last one.
    const begin = [...leadingDims.map(() => 0), start];
    const extent = [...leadingDims, size];
    if (rank === 1) {
      return tfc.slice1d(array, start, size);
    }
    if (rank === 2) {
      return tfc.slice2d(array, begin, extent);
    }
    if (rank === 3) {
      return tfc.slice3d(array, begin, extent);
    }
    if (rank === 4) {
      return tfc.slice4d(array, begin, extent);
    }
    throw new ValueError(`sliceAlongLastAxis() received an unsupported tensor rank: ${array.rank}`);
  });
}
/**
 * Do slicing along the specified axis.
 *
 * NOTE: `axis` is 1-based here. `1` selects the first axis and `array.rank`
 * selects the last one. For a rank-1 input the only axis is sliced and `axis`
 * is not consulted.
 *
 * @param array Input `tf.Tensor` of rank 1 to 4.
 * @param start Starting index along the chosen axis, inclusive.
 * @param size Size of the slice along the chosen axis.
 * @param axis 1-based index of the axis to slice along.
 * @returns Result of the slicing.
 * @throws ValueError: If `array` has a rank outside 1-4, or if `axis` is not
 *   an integer in `[1, array.rank]`.
 */
export function sliceAlongAxis(array, start, size, axis) {
  return tidy(() => {
    const rank = array.rank;
    if (rank === 1) {
      return tfc.slice1d(array, start, size);
    }
    if (rank < 2 || rank > 4) {
      throw new ValueError(`sliceAlongAxis() received an unsupported tensor rank: ` +
          `${array.rank}`);
    }
    // The outermost and innermost axes have dedicated helpers.
    if (axis === 1) {
      return sliceAlongFirstAxis(array, start, size);
    }
    if (axis === rank) {
      return sliceAlongLastAxis(array, start, size);
    }
    const shape = array.shape;
    if (rank === 3 && axis === 2) {
      return tfc.slice3d(array, [0, start, 0], [shape[0], size, shape[2]]);
    }
    if (rank === 4 && axis === 2) {
      return tfc.slice4d(array, [0, start, 0, 0], [shape[0], size, shape[2], shape[3]]);
    }
    if (rank === 4 && axis === 3) {
      return tfc.slice4d(array, [0, 0, start, 0], [shape[0], shape[1], size, shape[3]]);
    }
    throw new ValueError(`The axis is not within the rank of the tensor ` +
        `${axis}`);
  });
}
/**
 * Concatenates a list of tensors along the specified axis.
 *
 * A negative `axis` counts from the end, as in PyKeras (`axis %= rank`):
 * -1 is the last axis, -2 the second-to-last, and so on. Previously every
 * negative value collapsed to the last axis.
 *
 * @param tensors `Array` of tensors to concatenate. The rank is taken from
 *   `tensors[0]`.
 * @param axis Concatenation axis. Defaults to -1 (the last axis).
 * @returns The result of the concatenation.
 */
export function concatenate(tensors, axis = -1) {
  const rank = tensors[0].rank;
  let resolvedAxis = axis;
  if (resolvedAxis < 0) {
    // Emulate Python's non-negative modulo; scalars fall back to axis 0.
    resolvedAxis = rank !== 0 ? ((resolvedAxis % rank) + rank) % rank : 0;
  }
  if (resolvedAxis === rank) {
    // Porting Note: This is necessary because tfc.concat() requires axis to be
    // in the interval [-rank, rank).
    resolvedAxis = -1;
  }
  // Porting Note: Sparse concat is not supported yet.
  return tfc.concat(tensors, resolvedAxis);
}
/**
 * Joins two tensors end to end along their first dimension.
 *
 * @param a The leading `tf.Tensor`; its rank (1-4) selects the concat op.
 * @param b The trailing `tf.Tensor`.
 * @returns The concatenated tensor.
 * @throws ValueError: If `a` has a rank outside 1-4.
 */
export function concatAlongFirstAxis(a, b) {
  const pair = [a, b];
  const rank = a.rank;
  if (rank === 1) {
    return tfc.concat1d(pair);
  }
  if (rank === 2) {
    return tfc.concat2d(pair, 0);
  }
  if (rank === 3) {
    return tfc.concat3d(pair, 0);
  }
  if (rank === 4) {
    return tfc.concat4d(pair, 0);
  }
  throw new ValueError(`concatAlongFirstAxis() received an unsupported tensor rank: ${a.rank}`);
}
/**
 * Repeats `x` according to `n`.
 *
 * @param x A tensor.
 * @param n Repetition counts, one per dimension of `x`. A bare number is
 *   wrapped into a one-element array, so it only fits a rank-1 `x`.
 * @returns The tiled tensor.
 * @throws ValueError: If the number of repetition counts differs from
 *   `x.rank`.
 */
export function tile(x, n) {
  const reps = Array.isArray(n) ? n : [n];
  if (reps.length !== x.rank) {
    throw new ValueError(`The length of input n (${reps.length}) does not match the number of dimensions in input x (${x.rank})`);
  }
  return tfc.tile(x, reps);
}
302/* Creation of random tensors. */
/**
 * Samples a tensor from a normal distribution.
 *
 * All arguments are forwarded unchanged to `tfc.randomNormal`.
 *
 * @param shape Shape of the generated tensor.
 * @param mean Mean of the distribution (default 0).
 * @param stddev Standard deviation of the distribution (default 1).
 * @param dtype Optional dtype of the result.
 * @param seed Optional random seed.
 * @return The sampled tensor.
 */
export function randomNormal(shape, mean = 0.0, stddev = 1.0, dtype, seed) {
  const sample = tfc.randomNormal(shape, mean, stddev, dtype, seed);
  return sample;
}
316/* Linear Algebra */
/**
 * Multiply two tensors and returns the result as a tensor.
 *
 * For 2D tensors, this is equivalent to matrix multiplication (matMul).
 * For tensors of higher ranks, it follows the Theano behavior
 * (e.g. `(2, 3) * (4, 3, 5) -> (2, 4, 5)`). From the Theano documentation:
 *
 *   For N dimensions it is a sum product over the last axis of x and the
 *   second-to-last of y.
 *
 * @param a A tensor of at least rank 2.
 * @param b A tensor of at least rank 2.
 * @param activation (optional) A string identifying the activation function
 *   to fuse into the matrix multiplication.
 * @param bias (optional) Bias added before the activation, along the last
 *   axis of the output. Must be rank 1 or 2.
 * @return Result of the dot operation, of shape
 *   `[...a.shape.slice(0, -1), ...b.shape.slice(0, -2), b.shape[b.rank - 1]]`.
 * @throws NotImplementedError if either input has rank < 2, or if `b` has
 *   rank >= 3 and its second-to-last dim differs from the last dim of `a`.
 */
export function dot(a, b, activation, bias) {
  if ((a.rank < 2) || (b.rank < 2)) {
    throw new NotImplementedError(`dot requires both inputs to be rank >= 2` +
        ` but got x shape = ${a.shape} and y shape = ${b.shape}`);
  }
  if (b.rank >= 3) {
    const xLastDim = a.shape.slice(-1)[0];
    const ySecondLastDim = b.shape.slice(-2)[0];
    if (xLastDim !== ySecondLastDim) {
      throw new NotImplementedError(`If rank y >= 3, then the second last dim` +
          ` of y must equal the last dim of x but got x shape = ${a.shape} and ` +
          ` y shape = ${b.shape}`);
    }
  }
  // tfc.fused.matMul only fuses certain activation functions. Unsupported
  // activation functions are treated as 'linear' activations, which is
  // equivalent to a no-op.
  const fusedMatMul = (a2d, b2d, bias2d) => tfc.fused.matMul({
    a: a2d,
    b: b2d,
    transposeA: false,
    transposeB: false,
    bias: bias2d,
    activation,
  });

  // Handle basic 2D x 2D case.
  if ((a.rank === 2) && (b.rank === 2)) {
    return fusedMatMul(a, b, bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null);
  }

  // Reshape x into the analogous 2D Tensor [prod(leading dims), lastDim].
  const aFirstDims = a.shape.slice();  // Holds all but the last dim of x after pop().
  const aLastDim = aFirstDims.pop();
  const a2d = tfc.reshape(a, [-1, aLastDim]);

  // Reshape y into the analogous 2D Tensor, and keep track of the
  // required dimensions to reproduce the output shape.
  const bLeadingDims = b.shape.slice();  // Holds all but the last two dims after pop().
  const bLastDim = bLeadingDims.pop();
  const ySecondLastDim = bLeadingDims.pop();
  // permutation should be like [r-2, 0, 1, 2, ... r-4, r-3, r-1]
  // where r is the rank of y.
  const perm = Array.from({ length: b.rank }, (_, i) => {
    if (i === 0) {
      return b.rank - 2;
    }
    else if (i <= b.rank - 2) {
      return i - 1;
    }
    return i;
  });
  const b2d = tfc.reshape(tfc.transpose(b, perm), [ySecondLastDim, -1]);

  let bias2d = bias ? reshapeBias(a2d.rank, bias, imageDataFormat()) : null;
  // The 2D product's columns enumerate (leading dims of y..., lastDim) with
  // lastDim varying fastest. A 1D bias of length lastDim must be repeated once
  // per leading position to line up with those columns before it is fused.
  const numLeadingBlocks = bLeadingDims.reduce((product, dim) => product * dim, 1);
  if (bias2d != null && bias2d.rank === 1 && numLeadingBlocks > 1) {
    bias2d = tfc.tile(bias2d, [numLeadingBlocks]);
  }

  // Multiply x and y as 2D Tensors, and then reshape back to original.
  const outputShape = [...aFirstDims, ...bLeadingDims, bLastDim];
  return tfc.reshape(fusedMatMul(a2d, b2d, bias2d), outputShape);
}
/**
 * Element-wise sign of a tensor.
 *
 * Zeros stay 0, positive entries become 1, and every other entry becomes -1.
 *
 * @param x Input `tf.Tensor`.
 * @return A tensor with the same shape as `x` holding -1, 0 or 1.
 */
export function sign(x) {
  // TODO(cais): Move to the core.
  return tidy(() => {
    const zeros = coreZerosLike(x);
    const ones = coreOnesLike(x);
    const negativeOnes = tfc.mul(-1, ones);
    const isPositive = tfc.greater(x, zeros);
    const isZero = tfc.equal(x, zeros);
    // Pick +/-1 first, then overwrite exact zeros with 0.
    const nonZeroSign = where(isPositive, ones, negativeOnes);
    return where(isZero, zeros, nonZeroSign);
  });
}
/**
 * Computes the one-hot representation of an integer tensor.
 *
 * Only rank-1 `indices` are supported by this backend at present.
 *
 * @param indices 1D tensor of class indices with shape `(batch_size,)`. It is
 *   cast to int32 before encoding.
 * @param numClasses Integer, number of classes to consider.
 * @returns 2D float32 one-hot representation with shape
 *   `(batch_size, numClasses)`.
 * @throws ValueError if `indices` is not rank 1.
 */
export function oneHot(indices, numClasses) {
  return tidy(() => {
    if (indices.rank !== 1) {
      // ValueError (a subclass of Error) keeps existing `catch (Error)` callers
      // working and matches the error type used elsewhere in this backend.
      throw new ValueError('Only 1D one-hot tensors are supported in the ' +
          'deeplearn backend, at present.');
    }
    const intIndices = tfc.cast(indices, 'int32');
    return tfc.cast(tfc.oneHot(intIndices, numClasses), 'float32');
  });
}
436/* Elementary math functions. */
/**
 * Picks slices of `reference` at the given indices.
 *
 * @param reference Tensor to gather from.
 * @param indices Plain `Array` of integers (converted to an int32 1D tensor)
 *   or a tensor (cast to int32).
 * @param axis Axis along which to gather; passed straight to `tfc.gather`.
 * @returns The gathered tensor.
 */
export function gather(reference, indices, axis) {
  return tidy(() => {
    const intIndices = Array.isArray(indices)
      ? tensor1d(indices, 'int32')
      : tfc.cast(indices, 'int32');
    return tfc.gather(reference, intIndices, axis);
  });
}
/**
 * Squares every element of a tensor.
 *
 * @param x Input tensor.
 * @return `x * x`, element-wise.
 */
export function square(x) {
  const squared = tfc.mul(x, x);
  return squared;
}
/**
 * Raises `x` to the power `a`, element-wise.
 *
 * Porting Note: In PyKeras the exponent is a Python integer that the backend
 * converts automatically. Here `a` may be a number or a tensor.
 *
 * @param x The base tensor.
 * @param a The exponent. A number is rounded to the nearest integer and
 *   wrapped in an int32 scalar. A tensor must already be int32.
 * @returns A tensor of the same shape as `x`.
 * @throws NotImplementedError: If the exponent's dtype is not int32.
 */
export function pow(x, a) {
  return tidy(() => {
    const exponent = typeof a === 'number' ? scalar(Math.round(a), 'int32') : a;
    if (exponent.dtype !== 'int32') {
      throw new NotImplementedError(`Non-int32 dtype (${exponent.dtype}) is not supported by pow() yet`);
    }
    return tfc.pow(x, exponent);
  });
}
/**
 * Reshapes a bias tensor so that it broadcasts against an input of rank
 * `xRank` in the given `dataFormat`.
 *
 * For inputs of rank 3-5 the bias must be either 1D (one value per channel)
 * or of rank `xRank - 1` (one value per non-batch position). This matches
 * PyKeras' `bias_add`, and it is the layout every non-1D branch below builds.
 * For inputs of rank < 3 the bias (rank 1 or `xRank`) is returned unchanged
 * and relies on broadcasting.
 *
 * @param xRank Rank of the tensor the bias will be added to.
 * @param bias Bias tensor.
 * @param dataFormat 'channelsFirst' or 'channelsLast'.
 * @returns The reshaped (or unchanged) bias.
 * @throws ValueError if the bias rank does not fit `xRank`, if `xRank` > 5,
 *   or if `dataFormat` is not recognized for inputs of rank 3-5.
 */
function reshapeBias(xRank, bias, dataFormat) {
  const biasShape = bias.shape;
  // For rank >= 3, non-1D biases cover all non-batch dims, so `xRank - 1`
  // dims. Low-rank inputs keep their original pass-through contract.
  const expectedRank = xRank >= 3 ? xRank - 1 : xRank;
  if (bias.rank !== 1 && bias.rank !== expectedRank) {
    throw new ValueError(`Unexpected bias dimensions: ${bias.rank}` +
        `; expected it to be 1 or ${expectedRank}`);
  }
  if (xRank < 3) {
    return bias;
  }
  if (xRank > 5) {
    throw new ValueError(`Unsupported input rank by biasAdd: ${xRank}`);
  }
  // `count` singleton dims, used to pad a 1D bias up to rank `xRank`.
  const ones = (count) => new Array(count).fill(1);
  if (dataFormat === 'channelsFirst') {
    if (bias.rank === 1) {
      // [1, C, 1, ..., 1]
      return tfc.reshape(bias, [1, biasShape[0], ...ones(xRank - 2)]);
    }
    // Move the bias' trailing dim into the channel slot right after the batch
    // dim. Like the original port, this is a reshape, not a transpose.
    const channels = biasShape[biasShape.length - 1];
    return tfc.reshape(bias, [1, channels, ...biasShape.slice(0, -1)]);
  }
  if (dataFormat === 'channelsLast') {
    if (bias.rank === 1) {
      // [1, ..., 1, C]
      return tfc.reshape(bias, [...ones(xRank - 1), biasShape[0]]);
    }
    return tfc.reshape(bias, [1, ...biasShape]);
  }
  throw new ValueError(`Unsupported data format for biasAdd: ${dataFormat}`);
}
554/* Neural-network operations. */
/**
 * Adds a bias to a tensor.
 *
 * @param x The tensor that receives the bias.
 * @param bias The bias. It is reshaped by `reshapeBias` to broadcast against
 *   `x`.
 * @param dataFormat Optional data format. When null or undefined,
 *   `imageDataFormat()` is used. The value is validated by
 *   `checkDataFormat`.
 * @return `x + bias`.
 * @throws ValueError: If the rank of `bias` is incorrect.
 */
export function biasAdd(x, bias, dataFormat) {
  return tidy(() => {
    const format = dataFormat == null ? imageDataFormat() : dataFormat;
    checkDataFormat(format);
    const broadcastableBias = reshapeBias(x.rank, bias, format);
    return tfc.add(x, broadcastableBias);
  });
}
/**
 * Exponential linear unit (ELU).
 *
 * elu(x) = x for x > 0, and alpha * (exp(x) - 1) otherwise.
 *
 * @param x A tensor or variable to compute the activation function for.
 * @param alpha A scalar scaling factor for the negative section (default 1).
 * @return Output of the ELU operation.
 */
export function elu(x, alpha = 1) {
  if (alpha === 1) {
    // tfc.elu takes no alpha argument; it computes exactly the alpha = 1 case.
    return tfc.elu(x);
  }
  return tidy(() => {
    // Clamp the exponent to <= 0 so exp() stays finite for large positive x.
    // Otherwise the unused branch's gradient would become Inf * 0 = NaN.
    // For x > 0 the negative part is alpha * (exp(0) - 1) = 0.
    const negativePart = tfc.mul(alpha, tfc.sub(tfc.exp(tfc.minimum(x, 0)), 1));
    return tfc.add(tfc.relu(x), negativePart);
  });
}
/**
 * Softsign activation: `x / (|x| + 1)`, element-wise.
 *
 * @param x Input tensor.
 * @returns Tensor with values in (-1, 1).
 */
export function softsign(x) {
  return tidy(() => {
    const denominator = tfc.add(tfc.abs(x), 1);
    return tfc.div(x, denominator);
  });
}
/**
 * Randomly zeroes entries of `x` and rescales the rest.
 *
 * Delegates to `tfc.dropout` inside a `tidy` scope.
 *
 * @param x Input tensor.
 * @param level Fraction of entries to drop.
 * @param noiseShape Optional shape of the keep/drop mask; must broadcast to
 *   `x.shape`.
 * @param seed Optional random seed for reproducibility.
 * @returns The tensor after dropout.
 */
export function dropout(x, level, noiseShape, seed) {
  return tidy(() => {
    const dropped = tfc.dropout(x, level, noiseShape, seed);
    return dropped;
  });
}
/**
 * Piecewise-linear approximation of the sigmoid, element-wise.
 *
 * Computes `0.2 * x + 0.5` clipped to [0, 1]. That gives 0 for x < -2.5, 1 for
 * x > 2.5, and the linear ramp in between.
 *
 * @param x Input tensor.
 * @returns Output tensor.
 */
export function hardSigmoid(x) {
  return tidy(() => {
    const scaled = tfc.mul(.2, x);
    const shifted = tfc.add(.5, scaled);
    return tfc.clipByValue(shifted, 0, 1);
  });
}
/**
 * Runs `x` during training and `alt` otherwise.
 *
 * Porting Note: The TF.js imperative backend has no placeholder tensors, so
 * `training` is a plain value rather than a symbolic flag.
 *
 * @param x Thunk invoked when `training` is truthy.
 * @param alt Thunk invoked when `training` is falsy.
 * @param training Whether the training phase is active (default `false`).
 * @returns Whatever the invoked thunk returns.
 */
export function inTrainPhase(x, alt, training = false) {
  if (training) {
    return x();
  }
  return alt();
}
641//# sourceMappingURL=data:application/json;base64,
\No newline at end of file