UNPKG

4.16 MB · JavaScript · View Raw
1/**
2 * @license
3 * Copyright 2022 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17(function (global, factory) {
18 typeof exports === 'object' && typeof module !== 'undefined' ? factory(exports) :
19 typeof define === 'function' && define.amd ? define(['exports'], factory) :
20 (global = global || self, factory(global.tf = global.tf || {}));
21}(this, (function (exports) { 'use strict';
22
23 /**
24 * @license
25 * Copyright 2020 Google LLC. All Rights Reserved.
26 * Licensed under the Apache License, Version 2.0 (the "License");
27 * you may not use this file except in compliance with the License.
28 * You may obtain a copy of the License at
29 *
30 * http://www.apache.org/licenses/LICENSE-2.0
31 *
32 * Unless required by applicable law or agreed to in writing, software
33 * distributed under the License is distributed on an "AS IS" BASIS,
34 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35 * See the License for the specific language governing permissions and
36 * limitations under the License.
37 * =============================================================================
38 */
// Comparison tolerance used by backends whose highest float precision is
// 32 bits (see KernelBackend.epsilon()).
const EPSILON_FLOAT32 = 1e-7;
// Looser tolerance for backends limited to 16-bit floats.
const EPSILON_FLOAT16 = 1e-4;
/** Convenient class for storing tensor-related data. */
class DataStorage {
  /**
   * @param backend The backend that owns this storage.
   * @param dataMover Used to migrate data into this backend on a cache miss.
   */
  constructor(backend, dataMover) {
    this.backend = backend;
    this.dataMover = dataMover;
    // WeakMap lets entries be garbage-collected once a dataId is unreachable.
    this.data = new WeakMap();
    this.dataIdsCount = 0;
  }
  /** Returns the value for `dataId`, first moving it into this backend if absent. */
  get(dataId) {
    const present = this.data.has(dataId);
    if (!present) {
      this.dataMover.moveData(this.backend, dataId);
    }
    return this.data.get(dataId);
  }
  /** Stores `value` under `dataId` and bumps the id counter. */
  set(dataId, value) {
    this.data.set(dataId, value);
    this.dataIdsCount += 1;
  }
  /** Whether `dataId` is present, without triggering a data move. */
  has(dataId) {
    return this.data.has(dataId);
  }
  /** Removes `dataId`; returns whether an entry existed. */
  delete(dataId) {
    this.dataIdsCount -= 1;
    return this.data.delete(dataId);
  }
  /** Count of ids tracked so far (set calls minus delete calls). */
  numDataIds() {
    return this.dataIdsCount;
  }
}
/**
 * The interface that defines the kernels that should be implemented when
 * adding a new backend. New backends don't need to implement every one of the
 * methods, this can be done gradually (throw an error for unimplemented
 * methods).
 */
class KernelBackend {
    /** Returns the current reference count of the data bucket. */
    refCount(dataId) {
        return notYetImplemented('refCount');
    }
    /** Increments the reference count of the data bucket. */
    incRef(dataId) {
        return notYetImplemented('incRef');
    }
    /** Whether this backend can time kernels; defaults to true. */
    timerAvailable() {
        return true;
    }
    /** Times the execution of `f`. */
    time(f) {
        return notYetImplemented('time');
    }
    /** Asynchronously reads the values behind `dataId`. */
    read(dataId) {
        return notYetImplemented('read');
    }
    /** Synchronously reads the values behind `dataId`. */
    readSync(dataId) {
        return notYetImplemented('readSync');
    }
    /** Reads data as a GPU-resident resource (backend-specific). */
    readToGPU(dataId, options) {
        return notYetImplemented('readToGPU');
    }
    /** Number of data buckets currently held by this backend. */
    numDataIds() {
        return notYetImplemented('numDataIds');
    }
    /** Releases the data behind `dataId` (optionally ignoring refcounts). */
    disposeData(dataId, force) {
        return notYetImplemented('disposeData');
    }
    /** Uploads `values` with the given shape/dtype; returns a new dataId. */
    write(values, shape, dtype) {
        return notYetImplemented('write');
    }
    /** Adopts data (e.g. moved from another backend) under `dataId`. */
    move(dataId, values, shape, dtype, refCount) {
        return notYetImplemented('move');
    }
    /** Reports memory usage info for this backend. */
    memory() {
        return notYetImplemented('memory');
    }
    /** Returns the highest precision for floats in bits (e.g. 16 or 32) */
    floatPrecision() {
        return notYetImplemented('floatPrecision');
    }
    /** Returns the smallest representable number. */
    epsilon() {
        // Picks the tolerance matching the backend's float precision.
        return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16;
    }
    /** Disposes all backend resources. */
    dispose() {
        return notYetImplemented('dispose');
    }
}
/**
 * Throws a uniform "not implemented" error for a backend kernel.
 * @param kernelName Name of the missing kernel, included in the message.
 */
function notYetImplemented(kernelName) {
  const message =
      `'${kernelName}' not yet implemented or not found in the registry. ` +
      `This kernel may not be supported by the tfjs backend you have chosen`;
  throw new Error(message);
}
129
130 /**
131 * @license
132 * Copyright 2020 Google LLC. All Rights Reserved.
133 * Licensed under the Apache License, Version 2.0 (the "License");
134 * you may not use this file except in compliance with the License.
135 * You may obtain a copy of the License at
136 *
137 * http://www.apache.org/licenses/LICENSE-2.0
138 *
139 * Unless required by applicable law or agreed to in writing, software
140 * distributed under the License is distributed on an "AS IS" BASIS,
141 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
142 * See the License for the specific language governing permissions and
143 * limitations under the License.
144 * =============================================================================
145 */
/**
 * Shuffles the array in-place using Fisher-Yates algorithm.
 *
 * ```js
 * const a = [1, 2, 3, 4, 5];
 * tf.util.shuffle(a);
 * console.log(a);
 * ```
 *
 * @param array The array to shuffle in-place.
 *
 * @doc {heading: 'Util', namespace: 'util'}
 */
// tslint:disable-next-line:no-any
function shuffle(array) {
  // Walk from the end, swapping each slot with a random earlier (or same) slot.
  for (let remaining = array.length; remaining > 0;) {
    const pick = (Math.random() * remaining) | 0;
    remaining--;
    const tmp = array[remaining];
    array[remaining] = array[pick];
    array[pick] = tmp;
  }
}
/**
 * Shuffles two arrays in-place the same way using Fisher-Yates algorithm.
 *
 * ```js
 * const a = [1,2,3,4,5];
 * const b = [11,22,33,44,55];
 * tf.util.shuffleCombo(a, b);
 * console.log(a, b);
 * ```
 *
 * @param array The first array to shuffle in-place.
 * @param array2 The second array to shuffle in-place with the same permutation
 *     as the first array.
 * @throws Error if the two arrays differ in length.
 *
 * @doc {heading: 'Util', namespace: 'util'}
 */
function shuffleCombo(
// tslint:disable-next-line:no-any
array,
// tslint:disable-next-line:no-any
array2) {
  if (array.length !== array2.length) {
    // Fixed: the original message concatenated three sentences with no
    // separators ("...together First array length was 5Second array...").
    throw new Error(
        `Array sizes must match to be shuffled together. ` +
        `First array length was ${array.length}. ` +
        `Second array length was ${array2.length}.`);
  }
  let counter = array.length;
  // While there are elements in the arrays
  while (counter > 0) {
    // Pick a random index
    const index = (Math.random() * counter) | 0;
    // Decrease counter by 1
    counter--;
    // And swap the last element of each array with it, applying the SAME
    // permutation to both so paired entries stay aligned.
    [array[counter], array[index]] = [array[index], array[counter]];
    [array2[counter], array2[index]] = [array2[index], array2[counter]];
  }
}
/**
 * Clamps a value to a specified range.
 * @param min Lower bound (wins over `max` if the range is inverted, matching
 *     the original `Math.max(min, Math.min(x, max))` order).
 * @param x Value to clamp.
 * @param max Upper bound.
 */
function clamp(min, x, max) {
  const upperBounded = Math.min(x, max);
  return Math.max(min, upperBounded);
}
/** Returns `val` if it is even, otherwise the next larger integer (`val + 1`). */
function nearestLargerEven(val) {
  if (val % 2 === 0) {
    return val;
  }
  return val + 1;
}
/** Swaps the elements at indices `left` and `right` of an indexable object. */
function swap(object, left, right) {
  [object[left], object[right]] = [object[right], object[left]];
}
/** Returns the sum of the numbers in `arr` (0 for an empty array). */
function sum(arr) {
  let total = 0;
  for (const value of arr) {
    total += value;
  }
  return total;
}
/**
 * Returns a sample from a uniform [a, b) distribution.
 *
 * @param a The minimum support (inclusive).
 * @param b The maximum support (exclusive).
 * @return A pseudorandom number on the half-open interval [a,b).
 */
function randUniform(a, b) {
  const t = Math.random();
  // Linear interpolation between the endpoints; t < 1 keeps b exclusive.
  return (b * t) + a * (1 - t);
}
/** Returns the squared Euclidean distance between two vectors. */
function distSquared(a, b) {
  let total = 0;
  for (let i = 0; i < a.length; i++) {
    // Number() coercion mirrors the original, tolerating numeric strings.
    const delta = Number(a[i]) - Number(b[i]);
    total += delta * delta;
  }
  return total;
}
/**
 * Asserts that the expression is true. Otherwise throws an error with the
 * provided message.
 *
 * ```js
 * const x = 2;
 * tf.util.assert(x === 2, 'x is not 2');
 * ```
 *
 * @param expr The expression to assert (as a boolean).
 * @param msg A function that returns the message to report when throwing an
 *     error. We use a function for performance reasons.
 *
 * @doc {heading: 'Util', namespace: 'util'}
 */
function assert(expr, msg) {
  if (expr) {
    return;
  }
  // `msg` may be a plain string or a lazy message factory.
  const message = typeof msg === 'string' ? msg : msg();
  throw new Error(message);
}
/** Asserts that two shapes are element-wise equal, with an optional message prefix. */
function assertShapesMatch(shapeA, shapeB, errorMessagePrefix = '') {
  const buildMessage = () =>
      errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`;
  assert(arraysEqual(shapeA, shapeB), buildMessage);
}
/** Throws if `a` is null or undefined (valid tensor-constructor input check). */
function assertNonNull(a) {
  if (a == null) {
    throw new Error(
        `The input to the tensor constructor must be a non-null value.`);
  }
}
// NOTE: We explicitly type out what T extends instead of any so that
// util.flatten on a nested array of number doesn't try to infer T as a
// number[][], causing us to explicitly type util.flatten<number>().
/**
 * Flattens an arbitrarily nested array.
 *
 * ```js
 * const a = [[1, 2], [3, 4], [5, [6, [7]]]];
 * const flat = tf.util.flatten(a);
 * console.log(flat);
 * ```
 *
 * @param arr The nested array to flatten.
 * @param result The destination array which holds the elements.
 * @param skipTypedArray If true, avoids flattening the typed arrays. Defaults
 *     to false.
 *
 * @doc {heading: 'Util', namespace: 'util'}
 */
function flatten(arr, result = [], skipTypedArray = false) {
  if (result == null) {
    result = [];
  }
  const isTyped = arr instanceof Float32Array || arr instanceof Int32Array ||
      arr instanceof Uint8Array || arr instanceof Uint8ClampedArray;
  // Plain arrays are always recursed into; typed arrays only when
  // skipTypedArray is false (same precedence as the original `||`/`&&` mix).
  if (Array.isArray(arr) || (isTyped && !skipTypedArray)) {
    for (let i = 0; i < arr.length; ++i) {
      flatten(arr[i], result, skipTypedArray);
    }
  } else {
    result.push(arr);
  }
  return result;
}
/**
 * Returns the size (number of elements) of the tensor given its shape.
 *
 * ```js
 * const shape = [3, 4, 2];
 * const size = tf.util.sizeFromShape(shape);
 * console.log(size);
 * ```
 *
 * @doc {heading: 'Util', namespace: 'util'}
 */
function sizeFromShape(shape) {
  // Product of all dims; the initial value 1 also covers the scalar [] case.
  return shape.reduce((size, dim) => size * dim, 1);
}
/** Whether `shape` describes a scalar (a rank-0 tensor). */
function isScalarShape(shape) {
  return shape.length === 0;
}
/**
 * Element-wise equality of two array-likes. Identical references compare
 * equal; a null/undefined operand compares unequal.
 */
function arraysEqual(n1, n2) {
  if (n1 === n2) {
    return true;
  }
  if (n1 == null || n2 == null) {
    return false;
  }
  if (n1.length !== n2.length) {
    return false;
  }
  for (let i = 0; i < n1.length; ++i) {
    if (n1[i] !== n2[i]) {
      return false;
    }
  }
  return true;
}
/** Whether `a` has no fractional part (uses `%`, so numeric strings coerce). */
function isInt(a) {
  return (a % 1) === 0;
}
/**
 * Hyperbolic tangent. Delegates to `Math.tanh` when available; otherwise
 * falls back to the exp-based identity with explicit ±Infinity handling.
 */
function tanh(x) {
  // tslint:disable-next-line:no-any
  if (Math.tanh != null) {
    // tslint:disable-next-line:no-any
    return Math.tanh(x);
  }
  if (x === Infinity) {
    return 1;
  }
  if (x === -Infinity) {
    return -1;
  }
  const e2x = Math.exp(2 * x);
  return (e2x - 1) / (e2x + 1);
}
/** Returns a near-square [width, height] shape holding at least `size` elements. */
function sizeToSquarishShape(size) {
  const width = Math.ceil(Math.sqrt(size));
  const height = Math.ceil(size / width);
  return [width, height];
}
/**
 * Creates a new array with randomized indicies to a given quantity.
 *
 * ```js
 * const randomTen = tf.util.createShuffledIndices(10);
 * console.log(randomTen);
 * ```
 *
 * @param number Quantity of how many shuffled indicies to create.
 *
 * @doc {heading: 'Util', namespace: 'util'}
 */
function createShuffledIndices(n) {
  const indices = new Uint32Array(n);
  for (let i = 0; i < n; ++i) {
    indices[i] = i;
  }
  // Fisher-Yates shuffle of [0, n), done in place.
  for (let remaining = n; remaining > 0;) {
    const pick = (Math.random() * remaining) | 0;
    remaining--;
    const tmp = indices[remaining];
    indices[remaining] = indices[pick];
    indices[pick] = tmp;
  }
  return indices;
}
/**
 * Right-pads string `a` with spaces up to `size` characters; returns `a`
 * unchanged if it is already at least that long.
 */
function rightPad(a, size) {
  // padEnd is a no-op when the target length is <= the current length,
  // matching the original early return.
  return a.padEnd(size, ' ');
}
/**
 * Polls `checkFn` until it returns true, resolving the returned promise.
 * @param checkFn Predicate polled once per attempt.
 * @param delayFn Maps the attempt count to a setTimeout backoff in ms.
 * @param maxCounter Optional attempt cap; the promise rejects (with no
 *     reason) once `tryCount` reaches it.
 */
function repeatedTry(checkFn, delayFn = (counter) => 0, maxCounter) {
  return new Promise((resolve, reject) => {
    let attempts = 0;
    const attempt = () => {
      if (checkFn()) {
        resolve();
        return;
      }
      attempts++;
      // Compute the backoff BEFORE the cap check, as the original does
      // (delayFn may have side effects).
      const backoff = delayFn(attempts);
      if (maxCounter != null && attempts >= maxCounter) {
        reject();
        return;
      }
      setTimeout(attempt, backoff);
    };
    attempt();
  });
}
/**
 * Given the full size of the array and a shape that may contain -1 as the
 * implicit dimension, returns the inferred shape where -1 is replaced.
 * E.g. For shape=[2, -1, 3] and size=24, it will return [2, 4, 3].
 *
 * @param shape The shape, which may contain -1 in some dimension.
 * @param size The full size (number of elements) of the array.
 * @return The inferred shape where -1 is replaced with the inferred size.
 */
function inferFromImplicitShape(shape, size) {
  let known = 1;       // product of the explicitly-sized dims
  let inferredDim = -1; // index of the single allowed -1 dim
  for (let i = 0; i < shape.length; ++i) {
    const dim = shape[i];
    if (dim >= 0) {
      known *= dim;
    } else if (dim === -1) {
      if (inferredDim !== -1) {
        throw Error(`Shapes can only have 1 implicit size. ` +
            `Found -1 at dim ${inferredDim} and dim ${i}`);
      }
      inferredDim = i;
    } else if (dim < 0) {
      throw Error(`Shapes can not be < 0. Found ${dim} at dim ${i}`);
    }
  }
  if (inferredDim === -1) {
    // Fully explicit shape: just validate the size (0-sized is not checked).
    if (size > 0 && size !== known) {
      throw Error(`Size(${size}) must match the product of shape ${shape}`);
    }
    return shape;
  }
  if (known === 0) {
    throw Error(`Cannot infer the missing size in [${shape}] when ` +
        `there are 0 elements`);
  }
  if (size % known !== 0) {
    throw Error(`The implicit shape can't be a fractional number. ` +
        `Got ${size} / ${known}`);
  }
  const result = shape.slice();
  result[inferredDim] = size / known;
  return result;
}
/**
 * Normalizes an axis argument: null means "all axes", scalars become arrays,
 * negative axes are wrapped to positive indices. Validates range and
 * integrality.
 */
function parseAxisParam(axis, shape) {
  const rank = shape.length;
  // Normalize input to an array of axes (null => every axis).
  axis = axis == null ? shape.map((s, i) => i) : [].concat(axis);
  // Check for valid range.
  if (!axis.every(ax => ax >= -rank && ax < rank)) {
    throw new Error(
        `All values in axis param must be in range [-${rank}, ${rank}) but ` +
        `got axis ${axis}`);
  }
  // Check for only integers.
  if (!axis.every(ax => ax % 1 === 0)) {
    throw new Error(
        `All values in axis param must be integers but ` +
        `got axis ${axis}`);
  }
  // Handle negative axis.
  return axis.map(a => a < 0 ? rank + a : a);
}
/** Reduces the shape by removing all dimensions of shape 1. */
function squeezeShape(shape, axis) {
    const newShape = [];
    const keptDims = [];
    // An explicitly-empty axis array is treated like a null axis: squeeze all
    // size-1 dims.
    const isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;
    const axes = (axis == null || isEmptyArray) ?
        null :
        parseAxisParam(axis, shape).sort();
    let j = 0; // cursor into the sorted `axes` list
    for (let i = 0; i < shape.length; ++i) {
        if (axes != null) {
            // A requested axis must actually have size 1 to be squeezable.
            if (axes[j] === i && shape[i] !== 1) {
                throw new Error(`Can't squeeze axis ${i} since its dim '${shape[i]}' is not 1`);
            }
            // Size-1 dims that were NOT requested (cursor is past-end or still
            // ahead of i) are kept.
            if ((axes[j] == null || axes[j] > i) && shape[i] === 1) {
                newShape.push(shape[i]);
                keptDims.push(i);
            }
            // Advance the cursor once the current axis has been processed.
            if (axes[j] <= i) {
                j++;
            }
        }
        // Dims of size != 1 are always kept, with or without an axis list.
        if (shape[i] !== 1) {
            newShape.push(shape[i]);
            keptDims.push(i);
        }
    }
    // keptDims holds the original indices of the dims present in newShape.
    return { newShape, keptDims };
}
/**
 * Allocates a zero-filled typed array of `size` for the given numeric dtype.
 * A null/undefined dtype defaults to float32. Throws for any other dtype
 * (including 'string').
 */
function getTypedArrayFromDType(dtype, size) {
  if (dtype == null || dtype === 'float32') {
    return new Float32Array(size);
  }
  switch (dtype) {
    case 'int32':
      return new Int32Array(size);
    case 'bool':
      return new Uint8Array(size);
    default:
      throw new Error(`Unknown data type ${dtype}`);
  }
}
/**
 * Like getTypedArrayFromDType but additionally supports 'string', for which a
 * plain JS Array of the given length is returned.
 */
function getArrayFromDType(dtype, size) {
  if (dtype == null || dtype === 'float32') {
    return new Float32Array(size);
  }
  switch (dtype) {
    case 'int32':
      return new Int32Array(size);
    case 'bool':
      return new Uint8Array(size);
    case 'string':
      return new Array(size);
    default:
      throw new Error(`Unknown data type ${dtype}`);
  }
}
/**
 * Throws if any value in `vals` is NaN or ±Infinity; `dtype` is only used in
 * the error message.
 */
function checkConversionForErrors(vals, dtype) {
  for (let i = 0; i < vals.length; i++) {
    const num = vals[i];
    // isFinite(NaN) is false, so this single check covers both the NaN and
    // infinity cases of the original `isNaN(num) || !isFinite(num)`.
    if (!isFinite(num)) {
      throw Error(`A tensor of type ${dtype} being uploaded contains ${num}.`);
    }
  }
}
/** Returns true if the dtype is valid. */
function isValidDtype(dtype) {
  return ['bool', 'complex64', 'float32', 'int32', 'string'].includes(dtype);
}
/**
 * Returns true if the new type can't encode the old type without loss of
 * precision.
 */
function hasEncodingLoss(oldType, newType) {
  switch (newType) {
    case 'complex64':
      // complex64 can encode everything.
      return false;
    case 'float32':
      // float32 encodes everything except complex64.
      return oldType === 'complex64';
    case 'int32':
      // int32 loses float32/complex64 precision.
      return oldType === 'float32' || oldType === 'complex64';
    case 'bool':
      // bool only encodes bool.
      return oldType !== 'bool';
    default:
      return true;
  }
}
/** Whether `a` is one of the typed-array types tfjs stores data in. */
function isTypedArray(a) {
  const ctors = [Float32Array, Int32Array, Uint8Array, Uint8ClampedArray];
  return ctors.some(ctor => a instanceof ctor);
}
/** Returns the byte width of a single element of the given dtype. */
function bytesPerElement(dtype) {
  switch (dtype) {
    case 'float32':
    case 'int32':
      return 4;
    case 'complex64':
      return 8;
    case 'bool':
      return 1;
    default:
      throw new Error(`Unknown dtype ${dtype}`);
  }
}
/**
 * Returns the approximate number of bytes allocated in the string array - 2
 * bytes per character. Computing the exact bytes for a native string in JS is
 * not possible since it depends on the encoding of the html page that serves
 * the website.
 */
function bytesFromStringArray(arr) {
  if (arr == null) {
    return 0;
  }
  // Sums the `length` of each entry (byte length for encoded Uint8Arrays).
  return arr.reduce((bytes, x) => bytes + x.length, 0);
}
/** Returns true if the value is a string. */
function isString(value) {
  if (typeof value === 'string') {
    return true;
  }
  // Also accept boxed String objects.
  return value instanceof String;
}
/** Whether `value` is a primitive boolean. */
function isBoolean(value) {
  return typeof value === 'boolean';
}
/** Whether `value` is a primitive number. */
function isNumber(value) {
  return typeof value === 'number';
}
/**
 * Infers a tfjs dtype from a JS value. Nested arrays are probed via their
 * first element; unknown types default to 'float32'.
 */
function inferDtype(values) {
  if (Array.isArray(values)) {
    // Recurse into the first element for nested arrays.
    return inferDtype(values[0]);
  }
  if (values instanceof Float32Array) {
    return 'float32';
  }
  if (values instanceof Int32Array || values instanceof Uint8Array ||
      values instanceof Uint8ClampedArray) {
    return 'int32';
  }
  if (typeof values === 'number') {
    return 'float32';
  }
  if (typeof values === 'string' || values instanceof String) {
    return 'string';
  }
  if (typeof values === 'boolean') {
    return 'bool';
  }
  return 'float32';
}
/** Duck-typed check that `f` is callable (has constructor, call and apply). */
function isFunction(f) {
  return Boolean(f && f.constructor && f.call && f.apply);
}
/**
 * Returns the smallest divisor of `size` that is >= `start`, or `size` itself
 * if none exists below it.
 */
function nearestDivisor(size, start) {
  for (let candidate = start; candidate < size; ++candidate) {
    if (size % candidate === 0) {
      return candidate;
    }
  }
  return size;
}
/**
 * Computes row-major strides for `shape`. The last dimension has an implicit
 * stride of 1, so the result has rank-1 entries (empty for rank < 2).
 */
function computeStrides(shape) {
  const rank = shape.length;
  if (rank < 2) {
    return [];
  }
  const strides = new Array(rank - 1);
  // Accumulate the running product of trailing dims from right to left.
  let running = shape[rank - 1];
  strides[rank - 2] = running;
  for (let i = rank - 3; i >= 0; --i) {
    running *= shape[i + 1];
    strides[i] = running;
  }
  return strides;
}
/**
 * Recursively builds a nested JS array of the given shape from the flat
 * buffer `a`, starting at `offset`. Complex data doubles the element count
 * at the innermost level.
 */
function createNestedArray(offset, shape, a, isComplex = false) {
  const result = [];
  if (shape.length === 1) {
    const count = shape[0] * (isComplex ? 2 : 1);
    for (let i = 0; i < count; i++) {
      result[i] = a[offset + i];
    }
    return result;
  }
  const outer = shape[0];
  const innerShape = shape.slice(1);
  const innerLen = innerShape.reduce((acc, c) => acc * c) * (isComplex ? 2 : 1);
  for (let i = 0; i < outer; i++) {
    result[i] = createNestedArray(offset + i * innerLen, innerShape, a, isComplex);
  }
  return result;
}
// Provide a nested array of TypedArray in given shape.
function toNestedArray(shape, a, isComplex = false) {
  if (shape.length === 0) {
    // Scalar type should return a single number.
    return a[0];
  }
  const multiplier = isComplex ? 2 : 1;
  const size = shape.reduce((acc, c) => acc * c) * multiplier;
  if (size === 0) {
    // A tensor with shape zero should be turned into empty list.
    return [];
  }
  if (size !== a.length) {
    throw new Error(`[${shape}] does not match the input size ${a.length}${isComplex ? ' for a complex tensor' : ''}.`);
  }
  return createNestedArray(0, shape, a, isComplex);
}
/** Allocates a typed array of `size` for `dtype`, filled with ones. */
function makeOnesTypedArray(size, dtype) {
  const array = makeZerosTypedArray(size, dtype);
  // makeZerosTypedArray only ever returns typed arrays, so fill() is safe.
  array.fill(1);
  return array;
}
/**
 * Allocates a zero-filled typed array of `size` for `dtype`. Note that
 * 'complex64' maps to a Float32Array of the SAME length as requested.
 */
function makeZerosTypedArray(size, dtype) {
  if (dtype == null || dtype === 'float32' || dtype === 'complex64') {
    return new Float32Array(size);
  }
  switch (dtype) {
    case 'int32':
      return new Int32Array(size);
    case 'bool':
      return new Uint8Array(size);
    default:
      throw new Error(`Unknown data type ${dtype}`);
  }
}
/**
 * Make nested `TypedArray` filled with zeros.
 * @param shape The shape information for the nested array.
 * @param dtype dtype of the array element.
 */
function makeZerosNestedTypedArray(shape, dtype) {
  const size = shape.reduce((prev, curr) => prev * curr, 1);
  let flat;
  if (dtype == null || dtype === 'float32') {
    flat = new Float32Array(size);
  } else if (dtype === 'int32') {
    flat = new Int32Array(size);
  } else if (dtype === 'bool') {
    flat = new Uint8Array(size);
  } else {
    throw new Error(`Unknown data type ${dtype}`);
  }
  return toNestedArray(shape, flat);
}
/** Throws unless every dim in `shape` is a non-negative integer. */
function assertNonNegativeIntegerDimensions(shape) {
  shape.forEach(dimSize => {
    if (!(Number.isInteger(dimSize) && dimSize >= 0)) {
      throw new Error(
          `Tensor must have a shape comprised of positive integers but got ` +
          `shape [${shape}].`);
    }
  });
}
/**
 * Computes flat index for a given location (multidimentionsal index) in a
 * Tensor/multidimensional array.
 *
 * @param locs Location in the tensor.
 * @param rank Rank of the tensor.
 * @param strides Tensor strides.
 */
function locToIndex(locs, rank, strides) {
  if (rank === 0) {
    return 0;
  }
  if (rank === 1) {
    return locs[0];
  }
  // Last axis has implicit stride 1; accumulate strided offsets for the rest.
  let index = locs[locs.length - 1];
  for (let i = 0; i < locs.length - 1; ++i) {
    index += strides[i] * locs[i];
  }
  return index;
}
/**
 * Computes the location (multidimensional index) in a tensor/multidimentional
 * array for a given flat index.
 *
 * @param index Index in flat array.
 * @param rank Rank of tensor.
 * @param strides Strides of tensor.
 */
function indexToLoc(index, rank, strides) {
  if (rank === 0) {
    return [];
  }
  if (rank === 1) {
    return [index];
  }
  const locs = new Array(rank);
  // Peel off one stride at a time; the remainder becomes the last coordinate.
  for (let i = 0; i < locs.length - 1; ++i) {
    const coord = Math.floor(index / strides[i]);
    locs[i] = coord;
    index -= coord * strides[i];
  }
  locs[locs.length - 1] = index;
  return locs;
}
/**
 * This method asserts whether an object is a Promise instance.
 * @param object
 */
// tslint:disable-next-line: no-any
function isPromise(object) {
  // We chose to not use 'obj instanceOf Promise' for two reasons:
  // 1. It only reliably works for es6 Promise, not other Promise
  // implementations.
  // 2. It doesn't work with framework that uses zone.js. zone.js monkey patch
  // the async calls, so it is possible the obj (patched) is comparing to a
  // pre-patched Promise.
  const then = object && object.then;
  return then && typeof then === 'function';
}
804
805 /**
806 * @license
807 * Copyright 2017 Google LLC. All Rights Reserved.
808 * Licensed under the Apache License, Version 2.0 (the "License");
809 * you may not use this file except in compliance with the License.
810 * You may obtain a copy of the License at
811 *
812 * http://www.apache.org/licenses/LICENSE-2.0
813 *
814 * Unless required by applicable law or agreed to in writing, software
815 * distributed under the License is distributed on an "AS IS" BASIS,
816 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
817 * See the License for the specific language governing permissions and
818 * limitations under the License.
819 * =============================================================================
820 */
// Expects flags from URL in the format ?tfjsflags=FLAG1:1,FLAG2:true.
// Query-string key under which flag overrides are parsed (see
// Environment.populateURLFlags).
const TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags';
/**
 * The environment contains evaluated flags as well as the registered platform.
 * This is always used as a global singleton and can be retrieved with
 * `tf.env()`.
 *
 * @doc {heading: 'Environment'}
 */
class Environment {
    // tslint:disable-next-line: no-any
    constructor(global) {
        this.global = global;
        // Evaluated flag values, keyed by flag name.
        this.flags = {};
        // Registered {evaluationFn, setHook} entries, keyed by flag name.
        this.flagRegistry = {};
        // Flag overrides parsed from the page URL's query string.
        this.urlFlags = {};
        // Jasmine spies on this in 'environment_test.ts'
        this.getQueryParams = getQueryParams;
        this.populateURLFlags();
    }
    /** Registers the platform; warns (outside tests/prod) if one was already set. */
    setPlatform(platformName, platform) {
        if (this.platform != null) {
            if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
                console.warn(`Platform ${this.platformName} has already been set. ` +
                    `Overwriting the platform with ${platformName}.`);
            }
        }
        this.platformName = platformName;
        this.platform = platform;
    }
    /**
     * Registers a flag's evaluation function and optional set hook, then
     * immediately applies any URL override recorded for that flag.
     */
    registerFlag(flagName, evaluationFn, setHook) {
        this.flagRegistry[flagName] = { evaluationFn, setHook };
        // Override the flag value from the URL. This has to happen here because
        // the environment is initialized before flags get registered.
        if (this.urlFlags[flagName] != null) {
            const flagValue = this.urlFlags[flagName];
            if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
                console.warn(`Setting feature override from URL ${flagName}: ${flagValue}.`);
            }
            this.set(flagName, flagValue);
        }
    }
    /** Returns the flag value, evaluating (possibly asynchronously) and caching it. */
    async getAsync(flagName) {
        if (flagName in this.flags) {
            return this.flags[flagName];
        }
        this.flags[flagName] = await this.evaluateFlag(flagName);
        return this.flags[flagName];
    }
    /**
     * Returns the cached flag value, synchronously evaluating it if needed.
     * Throws if the flag's evaluation function returns a Promise.
     */
    get(flagName) {
        if (flagName in this.flags) {
            return this.flags[flagName];
        }
        const flagValue = this.evaluateFlag(flagName);
        if (isPromise(flagValue)) {
            throw new Error(`Flag ${flagName} cannot be synchronously evaluated. ` +
                `Please use getAsync() instead.`);
        }
        this.flags[flagName] = flagValue;
        return this.flags[flagName];
    }
    /** `get`, for flags whose value is a number. */
    getNumber(flagName) {
        return this.get(flagName);
    }
    /** `get`, for flags whose value is a boolean. */
    getBool(flagName) {
        return this.get(flagName);
    }
    /** Returns the map of all evaluated flags. */
    getFlags() {
        return this.flags;
    }
    // For backwards compatibility.
    get features() {
        return this.flags;
    }
    /** Sets a registered flag's value and invokes its set hook, if any. */
    set(flagName, value) {
        if (this.flagRegistry[flagName] == null) {
            throw new Error(`Cannot set flag ${flagName} as it has not been registered.`);
        }
        this.flags[flagName] = value;
        if (this.flagRegistry[flagName].setHook != null) {
            this.flagRegistry[flagName].setHook(value);
        }
    }
    /** Runs the registered evaluation function for `flagName`. */
    evaluateFlag(flagName) {
        if (this.flagRegistry[flagName] == null) {
            throw new Error(`Cannot evaluate flag '${flagName}': no evaluation function found.`);
        }
        return this.flagRegistry[flagName].evaluationFn();
    }
    /** Replaces all evaluated flag values with a shallow copy of `flags`. */
    setFlags(flags) {
        this.flags = Object.assign({}, flags);
    }
    /** Clears evaluated flags and URL overrides, then re-parses the URL. */
    reset() {
        this.flags = {};
        this.urlFlags = {};
        this.populateURLFlags();
    }
    /** Parses `?tfjsflags=FLAG1:1,FLAG2:true`-style overrides into urlFlags. */
    populateURLFlags() {
        // Bail out in environments without a window.location (e.g. Node).
        if (typeof this.global === 'undefined' ||
            typeof this.global.location === 'undefined' ||
            typeof this.global.location.search === 'undefined') {
            return;
        }
        const urlParams = this.getQueryParams(this.global.location.search);
        if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) {
            const keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(',');
            keyValues.forEach(keyValue => {
                const [key, value] = keyValue.split(':');
                this.urlFlags[key] = parseValue(key, value);
            });
        }
    }
}
/**
 * Parses a `?a=1&b=2` query string into a plain object of decoded key/value
 * pairs. Keys without a value map to ''.
 */
function getQueryParams(queryString) {
  const params = {};
  queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, (s, ...t) => {
    const [name, value] = t;
    // Inlined from `decodeParam`: decode both key and (possibly missing) value.
    params[decodeURIComponent(name)] = decodeURIComponent(value || '');
    return t.join('=');
  });
  return params;
}
/** URI-decodes `name`/`value` and stores them on `params` ('' for missing values). */
function decodeParam(params, name, value) {
  const key = decodeURIComponent(name);
  params[key] = decodeURIComponent(value || '');
}
/**
 * Parses a URL flag value: 'true'/'false' (case-insensitive) become booleans,
 * numeric strings become numbers; anything else throws.
 */
function parseValue(flagName, value) {
  value = value.toLowerCase();
  if (value === 'true' || value === 'false') {
    return value === 'true';
  }
  // A round-trip through Number that reproduces the string means it's numeric.
  if (`${+value}` === value) {
    return +value;
  }
  throw new Error(`Could not parse value flag value ${value} for flag ${flagName}.`);
}
/**
 * Returns the current environment (a global singleton).
 *
 * The environment object contains the evaluated feature values as well as the
 * active platform.
 *
 * @doc {heading: 'Environment'}
 */
function env() {
    return exports.ENV;
}
// Global singleton Environment; null until setEnvironmentGlobal is called.
exports.ENV = null;
/** Installs `environment` as the singleton returned by `env()`. */
function setEnvironmentGlobal(environment) {
    exports.ENV = environment;
}
970
971 /**
972 * @license
973 * Copyright 2020 Google LLC. All Rights Reserved.
974 * Licensed under the Apache License, Version 2.0 (the "License");
975 * you may not use this file except in compliance with the License.
976 * You may obtain a copy of the License at
977 *
978 * http://www.apache.org/licenses/LICENSE-2.0
979 *
980 * Unless required by applicable law or agreed to in writing, software
981 * distributed under the License is distributed on an "AS IS" BASIS,
982 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
983 * See the License for the specific language governing permissions and
984 * limitations under the License.
985 * =============================================================================
986 */
// Note that the identifier globalNameSpace is scoped to this module, but will
// always resolve to the same global object regardless of how the module is
// resolved.
// tslint:disable-next-line:no-any
let globalNameSpace;
// tslint:disable-next-line:no-any
/**
 * Lazily resolves and caches the host global object, trying `window`,
 * `global`, `process` and `self` in that order.
 */
function getGlobalNamespace() {
    if (globalNameSpace == null) {
        // tslint:disable-next-line:no-any
        let ns;
        if (typeof (window) !== 'undefined') {
            // Browser main thread.
            ns = window;
        }
        else if (typeof (global) !== 'undefined') {
            // Node.js.
            ns = global;
        }
        else if (typeof (process) !== 'undefined') {
            // NOTE(review): `process` here is used as a namespace object, not a
            // true global scope — presumably a fallback for odd Node-like hosts.
            ns = process;
        }
        else if (typeof (self) !== 'undefined') {
            // Web worker.
            ns = self;
        }
        else {
            throw new Error('Could not find a global object');
        }
        globalNameSpace = ns;
    }
    return globalNameSpace;
}
// tslint:disable-next-line:no-any
/** Returns (creating on first use) the `_tfGlobals` Map on the global object. */
function getGlobalMap() {
  const namespace = getGlobalNamespace();
  if (namespace._tfGlobals == null) {
    namespace._tfGlobals = new Map();
  }
  return namespace._tfGlobals;
}
/**
 * Returns a globally accessible 'singleton' object.
 *
 * @param key the name of the object
 * @param init a function to initialize to initialize this object
 *             the first time it is fetched.
 */
function getGlobal(key, init) {
  const globalMap = getGlobalMap();
  if (!globalMap.has(key)) {
    // First access: create and memoize the singleton.
    globalMap.set(key, init());
  }
  return globalMap.get(key);
}
1042
    // Official kernel names. These string constants are used as keys into the
    // kernel and gradient registries (see kernelRegistry / gradRegistry below).
    const Abs = 'Abs';
    const Acos = 'Acos';
    const Acosh = 'Acosh';
    const Add = 'Add';
    const AddN = 'AddN';
    const All = 'All';
    const Any = 'Any';
    const ArgMax = 'ArgMax';
    const ArgMin = 'ArgMin';
    const Asin = 'Asin';
    const Asinh = 'Asinh';
    const Atan = 'Atan';
    const Atanh = 'Atanh';
    const Atan2 = 'Atan2';
    const AvgPool = 'AvgPool';
    const AvgPoolGrad = 'AvgPoolGrad';
    const AvgPool3D = 'AvgPool3D';
    const AvgPool3DGrad = 'AvgPool3DGrad';
    const BatchMatMul = 'BatchMatMul';
    const BatchToSpaceND = 'BatchToSpaceND';
    const Bincount = 'Bincount';
    const BroadcastTo = 'BroadcastTo';
    const BroadcastArgs = 'BroadcastArgs';
    const Cast = 'Cast';
    const Ceil = 'Ceil';
    const ClipByValue = 'ClipByValue';
    const Complex = 'Complex';
    const ComplexAbs = 'ComplexAbs';
    const Concat = 'Concat';
    const Conv2D = 'Conv2D';
    const Conv2DBackpropFilter = 'Conv2DBackpropFilter';
    const Conv2DBackpropInput = 'Conv2DBackpropInput';
    const Conv3D = 'Conv3D';
    const Conv3DBackpropFilterV2 = 'Conv3DBackpropFilterV2';
    const Conv3DBackpropInputV2 = 'Conv3DBackpropInputV2';
    const Cos = 'Cos';
    const Cosh = 'Cosh';
    const Cumprod = 'Cumprod';
    const Cumsum = 'Cumsum';
    const CropAndResize = 'CropAndResize';
    const DenseBincount = 'DenseBincount';
    const DepthToSpace = 'DepthToSpace';
    const DepthwiseConv2dNative = 'DepthwiseConv2dNative';
    const DepthwiseConv2dNativeBackpropFilter = 'DepthwiseConv2dNativeBackpropFilter';
    const DepthwiseConv2dNativeBackpropInput = 'DepthwiseConv2dNativeBackpropInput';
    const Diag = 'Diag';
    const Dilation2D = 'Dilation2D';
    const Dilation2DBackpropInput = 'Dilation2DBackpropInput';
    const Dilation2DBackpropFilter = 'Dilation2DBackpropFilter';
    const RealDiv = 'RealDiv';
    const Einsum = 'Einsum';
    const Elu = 'Elu';
    const EluGrad = 'EluGrad';
    const Erf = 'Erf';
    const Equal = 'Equal';
    const Exp = 'Exp';
    const ExpandDims = 'ExpandDims';
    const Expm1 = 'Expm1';
    const FFT = 'FFT';
    const Fill = 'Fill';
    const FlipLeftRight = 'FlipLeftRight';
    const Floor = 'Floor';
    const FloorDiv = 'FloorDiv';
    const FusedBatchNorm = 'FusedBatchNorm';
    const GatherV2 = 'GatherV2';
    const GatherNd = 'GatherNd';
    const Greater = 'Greater';
    const GreaterEqual = 'GreaterEqual';
    const Identity = 'Identity';
    const IFFT = 'IFFT';
    const Imag = 'Imag';
    const IsFinite = 'IsFinite';
    const IsInf = 'IsInf';
    const IsNan = 'IsNan';
    const LeakyRelu = 'LeakyRelu';
    const Less = 'Less';
    const LessEqual = 'LessEqual';
    const LinSpace = 'LinSpace';
    const Log = 'Log';
    const Log1p = 'Log1p';
    const LogicalAnd = 'LogicalAnd';
    const LogicalNot = 'LogicalNot';
    const LogicalOr = 'LogicalOr';
    const LogSoftmax = 'LogSoftmax';
    const LowerBound = 'LowerBound';
    const LRN = 'LRN';
    const LRNGrad = 'LRNGrad';
    const Max = 'Max';
    const Maximum = 'Maximum';
    const MaxPool = 'MaxPool';
    const MaxPoolGrad = 'MaxPoolGrad';
    const MaxPool3D = 'MaxPool3D';
    const MaxPool3DGrad = 'MaxPool3DGrad';
    const MaxPoolWithArgmax = 'MaxPoolWithArgmax';
    const Mean = 'Mean';
    const Min = 'Min';
    const Minimum = 'Minimum';
    const MirrorPad = 'MirrorPad';
    const Mod = 'Mod';
    const Multinomial = 'Multinomial';
    const Multiply = 'Multiply';
    const Neg = 'Neg';
    const NotEqual = 'NotEqual';
    const NonMaxSuppressionV3 = 'NonMaxSuppressionV3';
    const NonMaxSuppressionV4 = 'NonMaxSuppressionV4';
    const NonMaxSuppressionV5 = 'NonMaxSuppressionV5';
    const OnesLike = 'OnesLike';
    const OneHot = 'OneHot';
    const Pack = 'Pack';
    const PadV2 = 'PadV2';
    const Pool = 'Pool';
    const Pow = 'Pow';
    const Prelu = 'Prelu';
    const Prod = 'Prod';
    const Range = 'Range';
    const Real = 'Real';
    const Reciprocal = 'Reciprocal';
    const Relu = 'Relu';
    const Reshape = 'Reshape';
    const ResizeNearestNeighbor = 'ResizeNearestNeighbor';
    const ResizeNearestNeighborGrad = 'ResizeNearestNeighborGrad';
    const ResizeBilinear = 'ResizeBilinear';
    const ResizeBilinearGrad = 'ResizeBilinearGrad';
    const Relu6 = 'Relu6';
    const Reverse = 'Reverse';
    const Round = 'Round';
    const Rsqrt = 'Rsqrt';
    const ScatterNd = 'ScatterNd';
    const SearchSorted = 'SearchSorted';
    const Select = 'Select';
    const Selu = 'Selu';
    const Slice = 'Slice';
    const Sin = 'Sin';
    const Sinh = 'Sinh';
    const Sign = 'Sign';
    const Sigmoid = 'Sigmoid';
    const Softplus = 'Softplus';
    const Sqrt = 'Sqrt';
    const Sum = 'Sum';
    const SpaceToBatchND = 'SpaceToBatchND';
    const SplitV = 'SplitV';
    const Softmax = 'Softmax';
    const SparseFillEmptyRows = 'SparseFillEmptyRows';
    const SparseReshape = 'SparseReshape';
    const SparseSegmentMean = 'SparseSegmentMean';
    const SparseSegmentSum = 'SparseSegmentSum';
    const SparseToDense = 'SparseToDense';
    const SquaredDifference = 'SquaredDifference';
    const Square = 'Square';
    const StridedSlice = 'StridedSlice';
    const StringNGrams = 'StringNGrams';
    const StringSplit = 'StringSplit';
    const StringToHashBucketFast = 'StringToHashBucketFast';
    const Sub = 'Sub';
    const Tan = 'Tan';
    const Tanh = 'Tanh';
    const Tile = 'Tile';
    const TopK = 'TopK';
    const Transform = 'Transform';
    const Transpose = 'Transpose';
    const Unique = 'Unique';
    const Unpack = 'Unpack';
    const UnsortedSegmentSum = 'UnsortedSegmentSum';
    const UpperBound = 'UpperBound';
    const ZerosLike = 'ZerosLike';
    /**
     * TensorFlow.js-only kernels (no corresponding TensorFlow op).
     */
    const Step = 'Step';
    const FromPixels = 'FromPixels';
    const RotateWithOffset = 'RotateWithOffset';
    const _FusedMatMul = '_FusedMatMul';
    const FusedConv2D = 'FusedConv2D';
    const FusedDepthwiseConv2D = 'FusedDepthwiseConv2D';
1217
1218 /**
1219 * @license
1220 * Copyright 2018 Google LLC. All Rights Reserved.
1221 * Licensed under the Apache License, Version 2.0 (the "License");
1222 * you may not use this file except in compliance with the License.
1223 * You may obtain a copy of the License at
1224 *
1225 * http://www.apache.org/licenses/LICENSE-2.0
1226 *
1227 * Unless required by applicable law or agreed to in writing, software
1228 * distributed under the License is distributed on an "AS IS" BASIS,
1229 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1230 * See the License for the specific language governing permissions and
1231 * limitations under the License.
1232 * =============================================================================
1233 */
1234 function warn(...msg) {
1235 if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
1236 console.warn(...msg);
1237 }
1238 }
1239 function log(...msg) {
1240 if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
1241 console.log(...msg);
1242 }
1243 }
1244
1245 /**
1246 * @license
1247 * Copyright 2019 Google LLC. All Rights Reserved.
1248 * Licensed under the Apache License, Version 2.0 (the "License");
1249 * you may not use this file except in compliance with the License.
1250 * You may obtain a copy of the License at
1251 *
1252 * http://www.apache.org/licenses/LICENSE-2.0
1253 *
1254 * Unless required by applicable law or agreed to in writing, software
1255 * distributed under the License is distributed on an "AS IS" BASIS,
1256 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1257 * See the License for the specific language governing permissions and
1258 * limitations under the License.
1259 * =============================================================================
1260 */
    // Global singleton registries (shared across module copies via getGlobal):
    // kernelRegistry maps `${backendName}_${kernelName}` -> kernel config;
    // gradRegistry maps kernelName -> gradient config.
    const kernelRegistry = getGlobal('kernelRegistry', () => new Map());
    const gradRegistry = getGlobal('gradRegistry', () => new Map());
1263 /**
1264 * Returns the kernel function (code) associated with the provided names.
1265 *
1266 * @param kernelName The official name of the kernel.
1267 * @param backendName The official name of the backend.
1268 */
1269 function getKernel(kernelName, backendName) {
1270 const key = makeKey(kernelName, backendName);
1271 return kernelRegistry.get(key);
1272 }
1273 /**
1274 * Returns the registered gradient info associated with the provided kernel.
1275 * @param kernelName The official TF kernel name.
1276 */
1277 function getGradient(kernelName) {
1278 return gradRegistry.get(kernelName);
1279 }
1280 function getKernelsForBackend(backendName) {
1281 const it = kernelRegistry.entries();
1282 const result = [];
1283 while (true) {
1284 const { done, value } = it.next();
1285 if (done) {
1286 break;
1287 }
1288 const [key, config] = value;
1289 const [backend,] = key.split('_');
1290 if (backend === backendName) {
1291 result.push(config);
1292 }
1293 }
1294 return result;
1295 }
1296 /**
1297 * Registers the function (forward pass) for the kernel in a global registry.
1298 *
1299 * @param config A config object with the following properties:
1300 * - `kernelName` The official name of the kernel.
1301 * - `backendName` The official name of the backend.
1302 * - `kernelFunc` The function to run during the forward pass of the kernel.
1303 * - `setupFunc` Optional. Gets called once, after the backend initializes.
1304 * - `disposeFunc` Optional. Gets called once, right before the backend is
1305 * disposed.
1306 */
1307 function registerKernel(config) {
1308 const { kernelName, backendName } = config;
1309 const key = makeKey(kernelName, backendName);
1310 if (kernelRegistry.has(key)) {
1311 warn(`The kernel '${kernelName}' for backend ` +
1312 `'${backendName}' is already registered`);
1313 }
1314 kernelRegistry.set(key, config);
1315 }
1316 /**
1317 * Registers a gradient function for a given kernel in the global registry,
1318 * to be used during the back-propagation of that kernel.
1319 *
1320 * @param config An object with the following properties:
1321 * - `kernelName` The name of the kernel that the gradient function is for.
1322 * - `gradFunc` The function to run during back-propagation.
1323 */
1324 function registerGradient(config) {
1325 const { kernelName } = config;
1326 if (gradRegistry.has(kernelName)) {
1327 // TODO (yassogba) after 3.0 assess whether we need to keep this gated
1328 // to debug mode.
1329 if (env().getBool('DEBUG')) {
1330 warn(`Overriding the gradient for '${kernelName}'`);
1331 }
1332 }
1333 gradRegistry.set(kernelName, config);
1334 }
1335 /**
1336 * Removes the kernel function from the registry.
1337 *
1338 * @param kernelName The official name of the kernel.
1339 * @param backendName The official name of the backend.
1340 *
1341 */
1342 function unregisterKernel(kernelName, backendName) {
1343 const key = makeKey(kernelName, backendName);
1344 if (!kernelRegistry.has(key)) {
1345 throw new Error(`The kernel '${kernelName}' for backend ` +
1346 `'${backendName}' is not registered`);
1347 }
1348 kernelRegistry.delete(key);
1349 }
1350 /** Removes the registered gradient from the global registry. */
1351 function unregisterGradient(kernelName) {
1352 if (!gradRegistry.has(kernelName)) {
1353 throw new Error(`The gradient '${kernelName}' for backend is not registered`);
1354 }
1355 gradRegistry.delete(kernelName);
1356 }
1357 /**
1358 * Finds kernels that have already been registered to a backend and re-registers
1359 * them for a new backend. Useful for registering custom backends.
1360 * @param registeredBackendName Already registered backend.
1361 * @param newBackendName New backend.
1362 */
1363 function copyRegisteredKernels(registeredBackendName, newBackendName) {
1364 const kernels = getKernelsForBackend(registeredBackendName);
1365 kernels.forEach(kernelConfig => {
1366 const newKernelConfig = Object.assign({}, kernelConfig, { backendName: newBackendName });
1367 registerKernel(newKernelConfig);
1368 });
1369 }
1370 function makeKey(kernelName, backendName) {
1371 return `${backendName}_${kernelName}`;
1372 }
1373
    // CommonJS-style alias for the bundled long.js library. The `Long` function
    // declaration below is hoisted, so this assignment is valid here.
    var long_1 = Long;
1375
    /**
     * wasm optimizations, to do native i64 multiplication and divide
     */
    var wasm = null;

    try {
        // Precompiled WebAssembly module exporting i64 helpers used by long.js
        // (mul, div_s, div_u, rem_s, rem_u, get_high). The byte array is the
        // module's binary encoding; do not edit it by hand.
        wasm = new WebAssembly.Instance(new WebAssembly.Module(new Uint8Array([
            0, 97, 115, 109, 1, 0, 0, 0, 1, 13, 2, 96, 0, 1, 127, 96, 4, 127, 127, 127, 127, 1, 127, 3, 7, 6, 0, 1, 1, 1, 1, 1, 6, 6, 1, 127, 1, 65, 0, 11, 7, 50, 6, 3, 109, 117, 108, 0, 1, 5, 100, 105, 118, 95, 115, 0, 2, 5, 100, 105, 118, 95, 117, 0, 3, 5, 114, 101, 109, 95, 115, 0, 4, 5, 114, 101, 109, 95, 117, 0, 5, 8, 103, 101, 116, 95, 104, 105, 103, 104, 0, 0, 10, 191, 1, 6, 4, 0, 35, 0, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 126, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 127, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 128, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 129, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 130, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11
        ])), {}).exports;
    } catch (e) {
        // no wasm support :( -- `wasm` stays null and the pure-JS code paths
        // are used instead.
    }
1388
1389 /**
1390 * Constructs a 64 bit two's-complement integer, given its low and high 32 bit values as *signed* integers.
1391 * See the from* functions below for more convenient ways of constructing Longs.
1392 * @exports Long
1393 * @class A Long class for representing a 64 bit two's-complement integer value.
1394 * @param {number} low The low (signed) 32 bits of the long
1395 * @param {number} high The high (signed) 32 bits of the long
1396 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
1397 * @constructor
1398 */
1399 function Long(low, high, unsigned) {
1400
1401 /**
1402 * The low 32 bits as a signed value.
1403 * @type {number}
1404 */
1405 this.low = low | 0;
1406
1407 /**
1408 * The high 32 bits as a signed value.
1409 * @type {number}
1410 */
1411 this.high = high | 0;
1412
1413 /**
1414 * Whether unsigned or not.
1415 * @type {boolean}
1416 */
1417 this.unsigned = !!unsigned;
1418 }
1419
    // The internal representation of a long is the two given signed, 32-bit values.
    // We use 32-bit pieces because these are the size of integers on which
    // Javascript performs bit-operations. For operations like addition and
    // multiplication, we split each number into 16 bit pieces, which can easily be
    // multiplied within Javascript's floating-point representation without overflow
    // or change in sign.
    //
    // In the algorithms below, we frequently reduce the negative case to the
    // positive case by negating the input(s) and then post-processing the result.
    // Note that we must ALWAYS check specially whether those values are MIN_VALUE
    // (-2^63) because -MIN_VALUE == MIN_VALUE (since 2^63 cannot be represented as
    // a positive number, it overflows back into a negative). Not handling this
    // case would often result in infinite recursion.
    //
    // Common constant values ZERO, ONE, NEG_ONE, etc. are defined below the from*
    // methods on which they depend.

    /**
     * An indicator used to reliably determine if an object is a Long or not.
     * Declared here for documentation only; the actual (non-enumerable) value
     * is installed by the defineProperty call below and read by isLong().
     * @type {boolean}
     * @const
     * @private
     */
    Long.prototype.__isLong__;

    // Install the brand as a non-enumerable, non-writable property.
    Object.defineProperty(Long.prototype, "__isLong__", { value: true });
1446
1447 /**
1448 * @function
1449 * @param {*} obj Object
1450 * @returns {boolean}
1451 * @inner
1452 */
1453 function isLong(obj) {
1454 return (obj && obj["__isLong__"]) === true;
1455 }
1456
    /**
     * Tests if the specified object is a Long.
     * @function
     * @param {*} obj Object
     * @returns {boolean}
     */
    Long.isLong = isLong;

    /**
     * A cache of the Long representations of small integer values, populated
     * lazily by fromInt for signed values in [-128, 128).
     * @type {!Object}
     * @inner
     */
    var INT_CACHE = {};

    /**
     * A cache of the Long representations of small unsigned integer values,
     * populated lazily by fromInt for unsigned values in [0, 256).
     * @type {!Object}
     * @inner
     */
    var UINT_CACHE = {};
1478
1479 /**
1480 * @param {number} value
1481 * @param {boolean=} unsigned
1482 * @returns {!Long}
1483 * @inner
1484 */
1485 function fromInt(value, unsigned) {
1486 var obj, cachedObj, cache;
1487 if (unsigned) {
1488 value >>>= 0;
1489 if (cache = (0 <= value && value < 256)) {
1490 cachedObj = UINT_CACHE[value];
1491 if (cachedObj)
1492 return cachedObj;
1493 }
1494 obj = fromBits(value, (value | 0) < 0 ? -1 : 0, true);
1495 if (cache)
1496 UINT_CACHE[value] = obj;
1497 return obj;
1498 } else {
1499 value |= 0;
1500 if (cache = (-128 <= value && value < 128)) {
1501 cachedObj = INT_CACHE[value];
1502 if (cachedObj)
1503 return cachedObj;
1504 }
1505 obj = fromBits(value, value < 0 ? -1 : 0, false);
1506 if (cache)
1507 INT_CACHE[value] = obj;
1508 return obj;
1509 }
1510 }
1511
1512 /**
1513 * Returns a Long representing the given 32 bit integer value.
1514 * @function
1515 * @param {number} value The 32 bit integer in question
1516 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
1517 * @returns {!Long} The corresponding Long value
1518 */
1519 Long.fromInt = fromInt;
1520
1521 /**
1522 * @param {number} value
1523 * @param {boolean=} unsigned
1524 * @returns {!Long}
1525 * @inner
1526 */
1527 function fromNumber(value, unsigned) {
1528 if (isNaN(value))
1529 return unsigned ? UZERO : ZERO;
1530 if (unsigned) {
1531 if (value < 0)
1532 return UZERO;
1533 if (value >= TWO_PWR_64_DBL)
1534 return MAX_UNSIGNED_VALUE;
1535 } else {
1536 if (value <= -TWO_PWR_63_DBL)
1537 return MIN_VALUE;
1538 if (value + 1 >= TWO_PWR_63_DBL)
1539 return MAX_VALUE;
1540 }
1541 if (value < 0)
1542 return fromNumber(-value, unsigned).neg();
1543 return fromBits((value % TWO_PWR_32_DBL) | 0, (value / TWO_PWR_32_DBL) | 0, unsigned);
1544 }
1545
1546 /**
1547 * Returns a Long representing the given value, provided that it is a finite number. Otherwise, zero is returned.
1548 * @function
1549 * @param {number} value The number in question
1550 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
1551 * @returns {!Long} The corresponding Long value
1552 */
1553 Long.fromNumber = fromNumber;
1554
    /**
     * Thin constructor wrapper: builds a Long directly from its two 32-bit halves.
     * @param {number} lowBits
     * @param {number} highBits
     * @param {boolean=} unsigned
     * @returns {!Long}
     * @inner
     */
    function fromBits(lowBits, highBits, unsigned) {
        return new Long(lowBits, highBits, unsigned);
    }

    /**
     * Returns a Long representing the 64 bit integer that comes by concatenating the given low and high bits. Each is
     * assumed to use 32 bits.
     * @function
     * @param {number} lowBits The low 32 bits
     * @param {number} highBits The high 32 bits
     * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
     * @returns {!Long} The corresponding Long value
     */
    Long.fromBits = fromBits;
1576
    /**
     * Local alias for Math.pow, kept so minifiers can shorten repeated uses.
     * @function
     * @param {number} base
     * @param {number} exponent
     * @returns {number}
     * @inner
     */
    var pow_dbl = Math.pow; // Used 4 times (4*8 to 15+4)
1585
    /**
     * Parses a string of digits in the given radix into a Long, processing the
     * digits in chunks of 8 to limit calls to the emulated 64-bit arithmetic.
     * @param {string} str
     * @param {(boolean|number)=} unsigned
     * @param {number=} radix
     * @returns {!Long}
     * @inner
     */
    function fromString(str, unsigned, radix) {
        if (str.length === 0)
            throw Error('empty string');
        if (str === "NaN" || str === "Infinity" || str === "+Infinity" || str === "-Infinity")
            return ZERO;
        if (typeof unsigned === 'number') {
            // For goog.math.long compatibility: fromString(str, radix) form,
            // where the second argument is actually the radix.
            radix = unsigned,
            unsigned = false;
        } else {
            unsigned = !! unsigned;
        }
        radix = radix || 10;
        if (radix < 2 || 36 < radix)
            throw RangeError('radix');

        var p;
        // A '-' is only legal as the very first character.
        if ((p = str.indexOf('-')) > 0)
            throw Error('interior hyphen');
        else if (p === 0) {
            // Negative: parse the magnitude, then negate.
            return fromString(str.substring(1), unsigned, radix).neg();
        }

        // Do several (8) digits each time through the loop, so as to
        // minimize the calls to the very expensive emulated div.
        var radixToPower = fromNumber(pow_dbl(radix, 8));

        var result = ZERO;
        for (var i = 0; i < str.length; i += 8) {
            var size = Math.min(8, str.length - i),
                value = parseInt(str.substring(i, i + size), radix);
            if (size < 8) {
                // Final short chunk: scale by radix^size, not radix^8.
                var power = fromNumber(pow_dbl(radix, size));
                result = result.mul(power).add(fromNumber(value));
            } else {
                result = result.mul(radixToPower);
                result = result.add(fromNumber(value));
            }
        }
        result.unsigned = unsigned;
        return result;
    }

    /**
     * Returns a Long representation of the given string, written using the specified radix.
     * @function
     * @param {string} str The textual representation of the Long
     * @param {(boolean|number)=} unsigned Whether unsigned or not, defaults to signed
     * @param {number=} radix The radix in which the text is written (2-36), defaults to 10
     * @returns {!Long} The corresponding Long value
     */
    Long.fromString = fromString;
1645
1646 /**
1647 * @function
1648 * @param {!Long|number|string|!{low: number, high: number, unsigned: boolean}} val
1649 * @param {boolean=} unsigned
1650 * @returns {!Long}
1651 * @inner
1652 */
1653 function fromValue(val, unsigned) {
1654 if (typeof val === 'number')
1655 return fromNumber(val, unsigned);
1656 if (typeof val === 'string')
1657 return fromString(val, unsigned);
1658 // Throws for non-objects, converts non-instanceof Long:
1659 return fromBits(val.low, val.high, typeof unsigned === 'boolean' ? unsigned : val.unsigned);
1660 }
1661
1662 /**
1663 * Converts the specified value to a Long using the appropriate from* function for its type.
1664 * @function
1665 * @param {!Long|number|string|!{low: number, high: number, unsigned: boolean}} val Value
1666 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
1667 * @returns {!Long}
1668 */
1669 Long.fromValue = fromValue;
1670
    // NOTE: the compiler should inline these constant values below and then remove
    // these variables, so there should be no runtime penalty for these.

    /** 2^16 as a double. @type {number} @const @inner */
    var TWO_PWR_16_DBL = 1 << 16;

    /** 2^24 as a double. @type {number} @const @inner */
    var TWO_PWR_24_DBL = 1 << 24;

    /** 2^32 as a double. @type {number} @const @inner */
    var TWO_PWR_32_DBL = TWO_PWR_16_DBL * TWO_PWR_16_DBL;

    /** 2^64 as a double. @type {number} @const @inner */
    var TWO_PWR_64_DBL = TWO_PWR_32_DBL * TWO_PWR_32_DBL;

    /** 2^63 as a double. @type {number} @const @inner */
    var TWO_PWR_63_DBL = TWO_PWR_64_DBL / 2;
1708
    /** 2^24 as a Long. @type {!Long} @const @inner */
    var TWO_PWR_24 = fromInt(TWO_PWR_24_DBL);

    /** @type {!Long} @inner */
    var ZERO = fromInt(0);

    /**
     * Signed zero.
     * @type {!Long}
     */
    Long.ZERO = ZERO;

    /** @type {!Long} @inner */
    var UZERO = fromInt(0, true);

    /**
     * Unsigned zero.
     * @type {!Long}
     */
    Long.UZERO = UZERO;

    /** @type {!Long} @inner */
    var ONE = fromInt(1);

    /**
     * Signed one.
     * @type {!Long}
     */
    Long.ONE = ONE;

    /** @type {!Long} @inner */
    var UONE = fromInt(1, true);

    /**
     * Unsigned one.
     * @type {!Long}
     */
    Long.UONE = UONE;

    /** @type {!Long} @inner */
    var NEG_ONE = fromInt(-1);

    /**
     * Signed negative one.
     * @type {!Long}
     */
    Long.NEG_ONE = NEG_ONE;

    /** 2^63 - 1: all bits set except the sign bit. @type {!Long} @inner */
    var MAX_VALUE = fromBits(0xFFFFFFFF|0, 0x7FFFFFFF|0, false);

    /**
     * Maximum signed value.
     * @type {!Long}
     */
    Long.MAX_VALUE = MAX_VALUE;

    /** 2^64 - 1: all 64 bits set, unsigned. @type {!Long} @inner */
    var MAX_UNSIGNED_VALUE = fromBits(0xFFFFFFFF|0, 0xFFFFFFFF|0, true);

    /**
     * Maximum unsigned value.
     * @type {!Long}
     */
    Long.MAX_UNSIGNED_VALUE = MAX_UNSIGNED_VALUE;

    /** -2^63: only the sign bit set. @type {!Long} @inner */
    var MIN_VALUE = fromBits(0, 0x80000000|0, false);

    /**
     * Minimum signed value.
     * @type {!Long}
     */
    Long.MIN_VALUE = MIN_VALUE;

    /**
     * Shorthand for Long.prototype, used when attaching the instance methods below.
     * @alias Long.prototype
     * @inner
     */
    var LongPrototype = Long.prototype;
1817
1818 /**
1819 * Converts the Long to a 32 bit integer, assuming it is a 32 bit integer.
1820 * @returns {number}
1821 */
1822 LongPrototype.toInt = function toInt() {
1823 return this.unsigned ? this.low >>> 0 : this.low;
1824 };
1825
1826 /**
1827 * Converts the Long to a the nearest floating-point representation of this value (double, 53 bit mantissa).
1828 * @returns {number}
1829 */
1830 LongPrototype.toNumber = function toNumber() {
1831 if (this.unsigned)
1832 return ((this.high >>> 0) * TWO_PWR_32_DBL) + (this.low >>> 0);
1833 return this.high * TWO_PWR_32_DBL + (this.low >>> 0);
1834 };
1835
    /**
     * Converts the Long to a string written in the specified radix.
     * @param {number=} radix Radix (2-36), defaults to 10
     * @returns {string}
     * @override
     * @throws {RangeError} If `radix` is out of range
     */
    LongPrototype.toString = function toString(radix) {
        radix = radix || 10;
        if (radix < 2 || 36 < radix)
            throw RangeError('radix');
        if (this.isZero())
            return '0';
        if (this.isNegative()) { // Unsigned Longs are never negative
            if (this.eq(MIN_VALUE)) {
                // We need to change the Long value before it can be negated, so we remove
                // the bottom-most digit in this base and then recurse to do the rest.
                var radixLong = fromNumber(radix),
                    div = this.div(radixLong),
                    rem1 = div.mul(radixLong).sub(this);
                return div.toString(radix) + rem1.toInt().toString(radix);
            } else
                return '-' + this.neg().toString(radix);
        }

        // Do several (6) digits each time through the loop, so as to
        // minimize the calls to the very expensive emulated div.
        var radixToPower = fromNumber(pow_dbl(radix, 6), this.unsigned),
            rem = this;
        var result = '';
        while (true) {
            // Peel off the lowest chunk of up to 6 digits in the target radix.
            var remDiv = rem.div(radixToPower),
                intval = rem.sub(remDiv.mul(radixToPower)).toInt() >>> 0,
                digits = intval.toString(radix);
            rem = remDiv;
            if (rem.isZero())
                return digits + result;
            else {
                // Interior chunks must be zero-padded to exactly 6 digits.
                while (digits.length < 6)
                    digits = '0' + digits;
                result = '' + digits + result;
            }
        }
    };
1880
    /**
     * Gets the high 32 bits as a signed integer.
     * @returns {number} Signed high bits
     */
    LongPrototype.getHighBits = function getHighBits() {
        return this.high;
    };

    /**
     * Gets the high 32 bits as an unsigned integer.
     * @returns {number} Unsigned high bits
     */
    LongPrototype.getHighBitsUnsigned = function getHighBitsUnsigned() {
        // `>>> 0` reinterprets the signed 32-bit word as unsigned.
        return this.high >>> 0;
    };

    /**
     * Gets the low 32 bits as a signed integer.
     * @returns {number} Signed low bits
     */
    LongPrototype.getLowBits = function getLowBits() {
        return this.low;
    };

    /**
     * Gets the low 32 bits as an unsigned integer.
     * @returns {number} Unsigned low bits
     */
    LongPrototype.getLowBitsUnsigned = function getLowBitsUnsigned() {
        // `>>> 0` reinterprets the signed 32-bit word as unsigned.
        return this.low >>> 0;
    };
1912
1913 /**
1914 * Gets the number of bits needed to represent the absolute value of this Long.
1915 * @returns {number}
1916 */
1917 LongPrototype.getNumBitsAbs = function getNumBitsAbs() {
1918 if (this.isNegative()) // Unsigned Longs are never negative
1919 return this.eq(MIN_VALUE) ? 64 : this.neg().getNumBitsAbs();
1920 var val = this.high != 0 ? this.high : this.low;
1921 for (var bit = 31; bit > 0; bit--)
1922 if ((val & (1 << bit)) != 0)
1923 break;
1924 return this.high != 0 ? bit + 33 : bit + 1;
1925 };
1926
1927 /**
1928 * Tests if this Long's value equals zero.
1929 * @returns {boolean}
1930 */
1931 LongPrototype.isZero = function isZero() {
1932 return this.high === 0 && this.low === 0;
1933 };
1934
1935 /**
1936 * Tests if this Long's value equals zero. This is an alias of {@link Long#isZero}.
1937 * @returns {boolean}
1938 */
1939 LongPrototype.eqz = LongPrototype.isZero;
1940
1941 /**
1942 * Tests if this Long's value is negative.
1943 * @returns {boolean}
1944 */
1945 LongPrototype.isNegative = function isNegative() {
1946 return !this.unsigned && this.high < 0;
1947 };
1948
1949 /**
1950 * Tests if this Long's value is positive.
1951 * @returns {boolean}
1952 */
1953 LongPrototype.isPositive = function isPositive() {
1954 return this.unsigned || this.high >= 0;
1955 };
1956
1957 /**
1958 * Tests if this Long's value is odd.
1959 * @returns {boolean}
1960 */
1961 LongPrototype.isOdd = function isOdd() {
1962 return (this.low & 1) === 1;
1963 };
1964
1965 /**
1966 * Tests if this Long's value is even.
1967 * @returns {boolean}
1968 */
1969 LongPrototype.isEven = function isEven() {
1970 return (this.low & 1) === 0;
1971 };
1972
1973 /**
1974 * Tests if this Long's value equals the specified's.
1975 * @param {!Long|number|string} other Other value
1976 * @returns {boolean}
1977 */
1978 LongPrototype.equals = function equals(other) {
1979 if (!isLong(other))
1980 other = fromValue(other);
1981 if (this.unsigned !== other.unsigned && (this.high >>> 31) === 1 && (other.high >>> 31) === 1)
1982 return false;
1983 return this.high === other.high && this.low === other.low;
1984 };
1985
1986 /**
1987 * Tests if this Long's value equals the specified's. This is an alias of {@link Long#equals}.
1988 * @function
1989 * @param {!Long|number|string} other Other value
1990 * @returns {boolean}
1991 */
1992 LongPrototype.eq = LongPrototype.equals;
1993
1994 /**
1995 * Tests if this Long's value differs from the specified's.
1996 * @param {!Long|number|string} other Other value
1997 * @returns {boolean}
1998 */
1999 LongPrototype.notEquals = function notEquals(other) {
2000 return !this.eq(/* validates */ other);
2001 };
2002
2003 /**
2004 * Tests if this Long's value differs from the specified's. This is an alias of {@link Long#notEquals}.
2005 * @function
2006 * @param {!Long|number|string} other Other value
2007 * @returns {boolean}
2008 */
2009 LongPrototype.neq = LongPrototype.notEquals;
2010
2011 /**
2012 * Tests if this Long's value differs from the specified's. This is an alias of {@link Long#notEquals}.
2013 * @function
2014 * @param {!Long|number|string} other Other value
2015 * @returns {boolean}
2016 */
2017 LongPrototype.ne = LongPrototype.notEquals;
2018
2019 /**
2020 * Tests if this Long's value is less than the specified's.
2021 * @param {!Long|number|string} other Other value
2022 * @returns {boolean}
2023 */
2024 LongPrototype.lessThan = function lessThan(other) {
2025 return this.comp(/* validates */ other) < 0;
2026 };
2027
2028 /**
2029 * Tests if this Long's value is less than the specified's. This is an alias of {@link Long#lessThan}.
2030 * @function
2031 * @param {!Long|number|string} other Other value
2032 * @returns {boolean}
2033 */
2034 LongPrototype.lt = LongPrototype.lessThan;
2035
2036 /**
2037 * Tests if this Long's value is less than or equal the specified's.
2038 * @param {!Long|number|string} other Other value
2039 * @returns {boolean}
2040 */
2041 LongPrototype.lessThanOrEqual = function lessThanOrEqual(other) {
2042 return this.comp(/* validates */ other) <= 0;
2043 };
2044
2045 /**
2046 * Tests if this Long's value is less than or equal the specified's. This is an alias of {@link Long#lessThanOrEqual}.
2047 * @function
2048 * @param {!Long|number|string} other Other value
2049 * @returns {boolean}
2050 */
2051 LongPrototype.lte = LongPrototype.lessThanOrEqual;
2052
2053 /**
2054 * Tests if this Long's value is less than or equal the specified's. This is an alias of {@link Long#lessThanOrEqual}.
2055 * @function
2056 * @param {!Long|number|string} other Other value
2057 * @returns {boolean}
2058 */
2059 LongPrototype.le = LongPrototype.lessThanOrEqual;
2060
2061 /**
2062 * Tests if this Long's value is greater than the specified's.
2063 * @param {!Long|number|string} other Other value
2064 * @returns {boolean}
2065 */
2066 LongPrototype.greaterThan = function greaterThan(other) {
2067 return this.comp(/* validates */ other) > 0;
2068 };
2069
2070 /**
2071 * Tests if this Long's value is greater than the specified's. This is an alias of {@link Long#greaterThan}.
2072 * @function
2073 * @param {!Long|number|string} other Other value
2074 * @returns {boolean}
2075 */
2076 LongPrototype.gt = LongPrototype.greaterThan;
2077
2078 /**
2079 * Tests if this Long's value is greater than or equal the specified's.
2080 * @param {!Long|number|string} other Other value
2081 * @returns {boolean}
2082 */
2083 LongPrototype.greaterThanOrEqual = function greaterThanOrEqual(other) {
2084 return this.comp(/* validates */ other) >= 0;
2085 };
2086
2087 /**
2088 * Tests if this Long's value is greater than or equal the specified's. This is an alias of {@link Long#greaterThanOrEqual}.
2089 * @function
2090 * @param {!Long|number|string} other Other value
2091 * @returns {boolean}
2092 */
2093 LongPrototype.gte = LongPrototype.greaterThanOrEqual;
2094
2095 /**
2096 * Tests if this Long's value is greater than or equal the specified's. This is an alias of {@link Long#greaterThanOrEqual}.
2097 * @function
2098 * @param {!Long|number|string} other Other value
2099 * @returns {boolean}
2100 */
2101 LongPrototype.ge = LongPrototype.greaterThanOrEqual;
2102
    /**
     * Compares this Long's value with the specified's.
     * @param {!Long|number|string} other Other value
     * @returns {number} 0 if they are the same, 1 if the this is greater and -1
     *  if the given one is greater
     */
    LongPrototype.compare = function compare(other) {
        if (!isLong(other))
            other = fromValue(other);
        if (this.eq(other))
            return 0;
        var thisNeg = this.isNegative(),
            otherNeg = other.isNegative();
        // Differing sign bits decide the order immediately.
        if (thisNeg && !otherNeg)
            return -1;
        if (!thisNeg && otherNeg)
            return 1;
        // At this point the sign bits are the same
        // Signed path: the subtraction cannot overflow because the signs match.
        if (!this.unsigned)
            return this.sub(other).isNegative() ? -1 : 1;
        // Both are positive if at least one is unsigned
        // Unsigned path: compare the 32-bit halves as unsigned ints (>>> 0),
        // high word first, then low word to break ties.
        return (other.high >>> 0) > (this.high >>> 0) || (other.high === this.high && (other.low >>> 0) > (this.low >>> 0)) ? -1 : 1;
    };

    /**
     * Compares this Long's value with the specified's. This is an alias of {@link Long#compare}.
     * @function
     * @param {!Long|number|string} other Other value
     * @returns {number} 0 if they are the same, 1 if the this is greater and -1
     *  if the given one is greater
     */
    LongPrototype.comp = LongPrototype.compare;
2135
2136 /**
2137 * Negates this Long's value.
2138 * @returns {!Long} Negated Long
2139 */
2140 LongPrototype.negate = function negate() {
2141 if (!this.unsigned && this.eq(MIN_VALUE))
2142 return MIN_VALUE;
2143 return this.not().add(ONE);
2144 };
2145
2146 /**
2147 * Negates this Long's value. This is an alias of {@link Long#negate}.
2148 * @function
2149 * @returns {!Long} Negated Long
2150 */
2151 LongPrototype.neg = LongPrototype.negate;
2152
    /**
     * Returns the sum of this and the specified Long.
     * @param {!Long|number|string} addend Addend
     * @returns {!Long} Sum
     */
    LongPrototype.add = function add(addend) {
        if (!isLong(addend))
            addend = fromValue(addend);

        // Divide each number into 4 chunks of 16 bits, and then sum the chunks.
        // 16-bit chunks keep every intermediate sum exactly representable in a
        // JS double, so no precision is lost.

        var a48 = this.high >>> 16;
        var a32 = this.high & 0xFFFF;
        var a16 = this.low >>> 16;
        var a00 = this.low & 0xFFFF;

        var b48 = addend.high >>> 16;
        var b32 = addend.high & 0xFFFF;
        var b16 = addend.low >>> 16;
        var b00 = addend.low & 0xFFFF;

        // Add chunk-wise from least to most significant, propagating the carry
        // (bits above 16) into the next chunk before masking each one back down.
        var c48 = 0, c32 = 0, c16 = 0, c00 = 0;
        c00 += a00 + b00;
        c16 += c00 >>> 16;
        c00 &= 0xFFFF;
        c16 += a16 + b16;
        c32 += c16 >>> 16;
        c16 &= 0xFFFF;
        c32 += a32 + b32;
        c48 += c32 >>> 16;
        c32 &= 0xFFFF;
        c48 += a48 + b48;
        // The final carry out of bit 63 is discarded: 64-bit wraparound.
        c48 &= 0xFFFF;
        return fromBits((c16 << 16) | c00, (c48 << 16) | c32, this.unsigned);
    };
2188
2189 /**
2190 * Returns the difference of this and the specified Long.
2191 * @param {!Long|number|string} subtrahend Subtrahend
2192 * @returns {!Long} Difference
2193 */
2194 LongPrototype.subtract = function subtract(subtrahend) {
2195 if (!isLong(subtrahend))
2196 subtrahend = fromValue(subtrahend);
2197 return this.add(subtrahend.neg());
2198 };
2199
2200 /**
2201 * Returns the difference of this and the specified Long. This is an alias of {@link Long#subtract}.
2202 * @function
2203 * @param {!Long|number|string} subtrahend Subtrahend
2204 * @returns {!Long} Difference
2205 */
2206 LongPrototype.sub = LongPrototype.subtract;
2207
    /**
     * Returns the product of this and the specified Long.
     * @param {!Long|number|string} multiplier Multiplier
     * @returns {!Long} Product
     */
    LongPrototype.multiply = function multiply(multiplier) {
        if (this.isZero())
            return ZERO;
        if (!isLong(multiplier))
            multiplier = fromValue(multiplier);

        // use wasm support if present
        if (wasm) {
            var low = wasm.mul(this.low,
                               this.high,
                               multiplier.low,
                               multiplier.high);
            return fromBits(low, wasm.get_high(), this.unsigned);
        }

        if (multiplier.isZero())
            return ZERO;
        // MIN_VALUE cannot be negated (-MIN_VALUE === MIN_VALUE), so handle it
        // via the multiplier's parity: odd keeps the low bit pattern, even wraps to 0.
        if (this.eq(MIN_VALUE))
            return multiplier.isOdd() ? MIN_VALUE : ZERO;
        if (multiplier.eq(MIN_VALUE))
            return this.isOdd() ? MIN_VALUE : ZERO;

        // Normalize signs so that the chunked multiply below only sees
        // non-negative operands.
        if (this.isNegative()) {
            if (multiplier.isNegative())
                return this.neg().mul(multiplier.neg());
            else
                return this.neg().mul(multiplier).neg();
        } else if (multiplier.isNegative())
            return this.mul(multiplier.neg()).neg();

        // If both longs are small, use float multiplication
        // (both < 2^24, so the 48-bit product is exact in a double).
        if (this.lt(TWO_PWR_24) && multiplier.lt(TWO_PWR_24))
            return fromNumber(this.toNumber() * multiplier.toNumber(), this.unsigned);

        // Divide each long into 4 chunks of 16 bits, and then add up 4x4 products.
        // We can skip products that would overflow.

        var a48 = this.high >>> 16;
        var a32 = this.high & 0xFFFF;
        var a16 = this.low >>> 16;
        var a00 = this.low & 0xFFFF;

        var b48 = multiplier.high >>> 16;
        var b32 = multiplier.high & 0xFFFF;
        var b16 = multiplier.low >>> 16;
        var b00 = multiplier.low & 0xFFFF;

        // Accumulate partial products least-significant chunk first, carrying
        // the overflow (bits above 16) into the next chunk before masking.
        var c48 = 0, c32 = 0, c16 = 0, c00 = 0;
        c00 += a00 * b00;
        c16 += c00 >>> 16;
        c00 &= 0xFFFF;
        c16 += a16 * b00;
        c32 += c16 >>> 16;
        c16 &= 0xFFFF;
        c16 += a00 * b16;
        c32 += c16 >>> 16;
        c16 &= 0xFFFF;
        c32 += a32 * b00;
        c48 += c32 >>> 16;
        c32 &= 0xFFFF;
        c32 += a16 * b16;
        c48 += c32 >>> 16;
        c32 &= 0xFFFF;
        c32 += a00 * b32;
        c48 += c32 >>> 16;
        c32 &= 0xFFFF;
        c48 += a48 * b00 + a32 * b16 + a16 * b32 + a00 * b48;
        c48 &= 0xFFFF;
        return fromBits((c16 << 16) | c00, (c48 << 16) | c32, this.unsigned);
    };

    /**
     * Returns the product of this and the specified Long. This is an alias of {@link Long#multiply}.
     * @function
     * @param {!Long|number|string} multiplier Multiplier
     * @returns {!Long} Product
     */
    LongPrototype.mul = LongPrototype.multiply;
2291
    /**
     * Returns this Long divided by the specified. The result is signed if this Long is signed or
     * unsigned if this Long is unsigned.
     * @param {!Long|number|string} divisor Divisor
     * @returns {!Long} Quotient
     * @throws {Error} If `divisor` is zero
     */
    LongPrototype.divide = function divide(divisor) {
        if (!isLong(divisor))
            divisor = fromValue(divisor);
        if (divisor.isZero())
            throw Error('division by zero');

        // use wasm support if present
        if (wasm) {
            // guard against signed division overflow: the largest
            // negative number / -1 would be 1 larger than the largest
            // positive number, due to two's complement.
            if (!this.unsigned &&
                this.high === -0x80000000 &&
                divisor.low === -1 && divisor.high === -1) {
                // be consistent with non-wasm code path
                return this;
            }
            var low = (this.unsigned ? wasm.div_u : wasm.div_s)(
                this.low,
                this.high,
                divisor.low,
                divisor.high
            );
            return fromBits(low, wasm.get_high(), this.unsigned);
        }

        if (this.isZero())
            return this.unsigned ? UZERO : ZERO;
        var approx, rem, res;
        if (!this.unsigned) {
            // This section is only relevant for signed longs and is derived from the
            // closure library as a whole.
            if (this.eq(MIN_VALUE)) {
                if (divisor.eq(ONE) || divisor.eq(NEG_ONE))
                    return MIN_VALUE; // recall that -MIN_VALUE == MIN_VALUE
                else if (divisor.eq(MIN_VALUE))
                    return ONE;
                else {
                    // At this point, we have |other| >= 2, so |this/other| < |MIN_VALUE|.
                    // Halve the dividend first so the recursive division operates on a
                    // representable magnitude, then correct with the remainder.
                    var halfThis = this.shr(1);
                    approx = halfThis.div(divisor).shl(1);
                    if (approx.eq(ZERO)) {
                        return divisor.isNegative() ? ONE : NEG_ONE;
                    } else {
                        rem = this.sub(divisor.mul(approx));
                        res = approx.add(rem.div(divisor));
                        return res;
                    }
                }
            } else if (divisor.eq(MIN_VALUE))
                return this.unsigned ? UZERO : ZERO;
            // Normalize signs so that the approximation loop below only sees
            // non-negative operands.
            if (this.isNegative()) {
                if (divisor.isNegative())
                    return this.neg().div(divisor.neg());
                return this.neg().div(divisor).neg();
            } else if (divisor.isNegative())
                return this.div(divisor.neg()).neg();
            res = ZERO;
        } else {
            // The algorithm below has not been made for unsigned longs. It's therefore
            // required to take special care of the MSB prior to running it.
            if (!divisor.unsigned)
                divisor = divisor.toUnsigned();
            if (divisor.gt(this))
                return UZERO;
            if (divisor.gt(this.shru(1))) // 15 >>> 1 = 7 ; with divisor = 8 ; true
                return UONE;
            res = UZERO;
        }

        // Repeat the following until the remainder is less than other: find a
        // floating-point that approximates remainder / other *from below*, add this
        // into the result, and subtract it from the remainder. It is critical that
        // the approximate value is less than or equal to the real value so that the
        // remainder never becomes negative.
        rem = this;
        while (rem.gte(divisor)) {
            // Approximate the result of division. This may be a little greater or
            // smaller than the actual value.
            approx = Math.max(1, Math.floor(rem.toNumber() / divisor.toNumber()));

            // We will tweak the approximate result by changing it in the 48-th digit or
            // the smallest non-fractional digit, whichever is larger.
            var log2 = Math.ceil(Math.log(approx) / Math.LN2),
                delta = (log2 <= 48) ? 1 : pow_dbl(2, log2 - 48),

                // Decrease the approximation until it is smaller than the remainder. Note
                // that if it is too large, the product overflows and is negative.
                approxRes = fromNumber(approx),
                approxRem = approxRes.mul(divisor);
            while (approxRem.isNegative() || approxRem.gt(rem)) {
                approx -= delta;
                approxRes = fromNumber(approx, this.unsigned);
                approxRem = approxRes.mul(divisor);
            }

            // We know the answer can't be zero... and actually, zero would cause
            // infinite recursion since we would make no progress.
            if (approxRes.isZero())
                approxRes = ONE;

            res = res.add(approxRes);
            rem = rem.sub(approxRem);
        }
        return res;
    };

    /**
     * Returns this Long divided by the specified. This is an alias of {@link Long#divide}.
     * @function
     * @param {!Long|number|string} divisor Divisor
     * @returns {!Long} Quotient
     */
    LongPrototype.div = LongPrototype.divide;
2412
2413 /**
2414 * Returns this Long modulo the specified.
2415 * @param {!Long|number|string} divisor Divisor
2416 * @returns {!Long} Remainder
2417 */
2418 LongPrototype.modulo = function modulo(divisor) {
2419 if (!isLong(divisor))
2420 divisor = fromValue(divisor);
2421
2422 // use wasm support if present
2423 if (wasm) {
2424 var low = (this.unsigned ? wasm.rem_u : wasm.rem_s)(
2425 this.low,
2426 this.high,
2427 divisor.low,
2428 divisor.high
2429 );
2430 return fromBits(low, wasm.get_high(), this.unsigned);
2431 }
2432
2433 return this.sub(this.div(divisor).mul(divisor));
2434 };
2435
2436 /**
2437 * Returns this Long modulo the specified. This is an alias of {@link Long#modulo}.
2438 * @function
2439 * @param {!Long|number|string} divisor Divisor
2440 * @returns {!Long} Remainder
2441 */
2442 LongPrototype.mod = LongPrototype.modulo;
2443
2444 /**
2445 * Returns this Long modulo the specified. This is an alias of {@link Long#modulo}.
2446 * @function
2447 * @param {!Long|number|string} divisor Divisor
2448 * @returns {!Long} Remainder
2449 */
2450 LongPrototype.rem = LongPrototype.modulo;
2451
2452 /**
2453 * Returns the bitwise NOT of this Long.
2454 * @returns {!Long}
2455 */
2456 LongPrototype.not = function not() {
2457 return fromBits(~this.low, ~this.high, this.unsigned);
2458 };
2459
2460 /**
2461 * Returns the bitwise AND of this Long and the specified.
2462 * @param {!Long|number|string} other Other Long
2463 * @returns {!Long}
2464 */
2465 LongPrototype.and = function and(other) {
2466 if (!isLong(other))
2467 other = fromValue(other);
2468 return fromBits(this.low & other.low, this.high & other.high, this.unsigned);
2469 };
2470
2471 /**
2472 * Returns the bitwise OR of this Long and the specified.
2473 * @param {!Long|number|string} other Other Long
2474 * @returns {!Long}
2475 */
2476 LongPrototype.or = function or(other) {
2477 if (!isLong(other))
2478 other = fromValue(other);
2479 return fromBits(this.low | other.low, this.high | other.high, this.unsigned);
2480 };
2481
2482 /**
2483 * Returns the bitwise XOR of this Long and the given one.
2484 * @param {!Long|number|string} other Other Long
2485 * @returns {!Long}
2486 */
2487 LongPrototype.xor = function xor(other) {
2488 if (!isLong(other))
2489 other = fromValue(other);
2490 return fromBits(this.low ^ other.low, this.high ^ other.high, this.unsigned);
2491 };
2492
2493 /**
2494 * Returns this Long with bits shifted to the left by the given amount.
2495 * @param {number|!Long} numBits Number of bits
2496 * @returns {!Long} Shifted Long
2497 */
2498 LongPrototype.shiftLeft = function shiftLeft(numBits) {
2499 if (isLong(numBits))
2500 numBits = numBits.toInt();
2501 if ((numBits &= 63) === 0)
2502 return this;
2503 else if (numBits < 32)
2504 return fromBits(this.low << numBits, (this.high << numBits) | (this.low >>> (32 - numBits)), this.unsigned);
2505 else
2506 return fromBits(0, this.low << (numBits - 32), this.unsigned);
2507 };
2508
2509 /**
2510 * Returns this Long with bits shifted to the left by the given amount. This is an alias of {@link Long#shiftLeft}.
2511 * @function
2512 * @param {number|!Long} numBits Number of bits
2513 * @returns {!Long} Shifted Long
2514 */
2515 LongPrototype.shl = LongPrototype.shiftLeft;
2516
2517 /**
2518 * Returns this Long with bits arithmetically shifted to the right by the given amount.
2519 * @param {number|!Long} numBits Number of bits
2520 * @returns {!Long} Shifted Long
2521 */
2522 LongPrototype.shiftRight = function shiftRight(numBits) {
2523 if (isLong(numBits))
2524 numBits = numBits.toInt();
2525 if ((numBits &= 63) === 0)
2526 return this;
2527 else if (numBits < 32)
2528 return fromBits((this.low >>> numBits) | (this.high << (32 - numBits)), this.high >> numBits, this.unsigned);
2529 else
2530 return fromBits(this.high >> (numBits - 32), this.high >= 0 ? 0 : -1, this.unsigned);
2531 };
2532
2533 /**
2534 * Returns this Long with bits arithmetically shifted to the right by the given amount. This is an alias of {@link Long#shiftRight}.
2535 * @function
2536 * @param {number|!Long} numBits Number of bits
2537 * @returns {!Long} Shifted Long
2538 */
2539 LongPrototype.shr = LongPrototype.shiftRight;
2540
2541 /**
2542 * Returns this Long with bits logically shifted to the right by the given amount.
2543 * @param {number|!Long} numBits Number of bits
2544 * @returns {!Long} Shifted Long
2545 */
2546 LongPrototype.shiftRightUnsigned = function shiftRightUnsigned(numBits) {
2547 if (isLong(numBits))
2548 numBits = numBits.toInt();
2549 numBits &= 63;
2550 if (numBits === 0)
2551 return this;
2552 else {
2553 var high = this.high;
2554 if (numBits < 32) {
2555 var low = this.low;
2556 return fromBits((low >>> numBits) | (high << (32 - numBits)), high >>> numBits, this.unsigned);
2557 } else if (numBits === 32)
2558 return fromBits(high, 0, this.unsigned);
2559 else
2560 return fromBits(high >>> (numBits - 32), 0, this.unsigned);
2561 }
2562 };
2563
2564 /**
2565 * Returns this Long with bits logically shifted to the right by the given amount. This is an alias of {@link Long#shiftRightUnsigned}.
2566 * @function
2567 * @param {number|!Long} numBits Number of bits
2568 * @returns {!Long} Shifted Long
2569 */
2570 LongPrototype.shru = LongPrototype.shiftRightUnsigned;
2571
2572 /**
2573 * Returns this Long with bits logically shifted to the right by the given amount. This is an alias of {@link Long#shiftRightUnsigned}.
2574 * @function
2575 * @param {number|!Long} numBits Number of bits
2576 * @returns {!Long} Shifted Long
2577 */
2578 LongPrototype.shr_u = LongPrototype.shiftRightUnsigned;
2579
2580 /**
2581 * Converts this Long to signed.
2582 * @returns {!Long} Signed long
2583 */
2584 LongPrototype.toSigned = function toSigned() {
2585 if (!this.unsigned)
2586 return this;
2587 return fromBits(this.low, this.high, false);
2588 };
2589
2590 /**
2591 * Converts this Long to unsigned.
2592 * @returns {!Long} Unsigned long
2593 */
2594 LongPrototype.toUnsigned = function toUnsigned() {
2595 if (this.unsigned)
2596 return this;
2597 return fromBits(this.low, this.high, true);
2598 };
2599
2600 /**
2601 * Converts this Long to its byte representation.
2602 * @param {boolean=} le Whether little or big endian, defaults to big endian
2603 * @returns {!Array.<number>} Byte representation
2604 */
2605 LongPrototype.toBytes = function toBytes(le) {
2606 return le ? this.toBytesLE() : this.toBytesBE();
2607 };
2608
2609 /**
2610 * Converts this Long to its little endian byte representation.
2611 * @returns {!Array.<number>} Little endian byte representation
2612 */
2613 LongPrototype.toBytesLE = function toBytesLE() {
2614 var hi = this.high,
2615 lo = this.low;
2616 return [
2617 lo & 0xff,
2618 lo >>> 8 & 0xff,
2619 lo >>> 16 & 0xff,
2620 lo >>> 24 ,
2621 hi & 0xff,
2622 hi >>> 8 & 0xff,
2623 hi >>> 16 & 0xff,
2624 hi >>> 24
2625 ];
2626 };
2627
2628 /**
2629 * Converts this Long to its big endian byte representation.
2630 * @returns {!Array.<number>} Big endian byte representation
2631 */
2632 LongPrototype.toBytesBE = function toBytesBE() {
2633 var hi = this.high,
2634 lo = this.low;
2635 return [
2636 hi >>> 24 ,
2637 hi >>> 16 & 0xff,
2638 hi >>> 8 & 0xff,
2639 hi & 0xff,
2640 lo >>> 24 ,
2641 lo >>> 16 & 0xff,
2642 lo >>> 8 & 0xff,
2643 lo & 0xff
2644 ];
2645 };
2646
2647 /**
2648 * Creates a Long from its byte representation.
2649 * @param {!Array.<number>} bytes Byte representation
2650 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
2651 * @param {boolean=} le Whether little or big endian, defaults to big endian
2652 * @returns {Long} The corresponding Long value
2653 */
2654 Long.fromBytes = function fromBytes(bytes, unsigned, le) {
2655 return le ? Long.fromBytesLE(bytes, unsigned) : Long.fromBytesBE(bytes, unsigned);
2656 };
2657
2658 /**
2659 * Creates a Long from its little endian byte representation.
2660 * @param {!Array.<number>} bytes Little endian byte representation
2661 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
2662 * @returns {Long} The corresponding Long value
2663 */
2664 Long.fromBytesLE = function fromBytesLE(bytes, unsigned) {
2665 return new Long(
2666 bytes[0] |
2667 bytes[1] << 8 |
2668 bytes[2] << 16 |
2669 bytes[3] << 24,
2670 bytes[4] |
2671 bytes[5] << 8 |
2672 bytes[6] << 16 |
2673 bytes[7] << 24,
2674 unsigned
2675 );
2676 };
2677
2678 /**
2679 * Creates a Long from its big endian byte representation.
2680 * @param {!Array.<number>} bytes Big endian byte representation
2681 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
2682 * @returns {Long} The corresponding Long value
2683 */
2684 Long.fromBytesBE = function fromBytesBE(bytes, unsigned) {
2685 return new Long(
2686 bytes[4] << 24 |
2687 bytes[5] << 16 |
2688 bytes[6] << 8 |
2689 bytes[7],
2690 bytes[0] << 24 |
2691 bytes[1] << 16 |
2692 bytes[2] << 8 |
2693 bytes[3],
2694 unsigned
2695 );
2696 };
2697
    // Frozen, ES-module-shaped namespace object wrapping the long.js export,
    // used as a fallback binding below.
    var LongExports = /*#__PURE__*/Object.freeze({
        __proto__: null,
        'default': long_1,
        __moduleExports: long_1
    });
2703
2704 /**
2705 * @license
2706 * Copyright 2021 Google LLC. All Rights Reserved.
2707 * Licensed under the Apache License, Version 2.0 (the "License");
2708 * you may not use this file except in compliance with the License.
2709 * You may obtain a copy of the License at
2710 *
2711 * http://www.apache.org/licenses/LICENSE-2.0
2712 *
2713 * Unless required by applicable law or agreed to in writing, software
2714 * distributed under the License is distributed on an "AS IS" BASIS,
2715 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2716 * See the License for the specific language governing permissions and
2717 * limitations under the License.
2718 * =============================================================================
2719 */
    // Prefer the loaded long.js module object; fall back to the frozen
    // namespace wrapper when the direct binding is unavailable.
    // tslint:disable-next-line
    const Long$1 =
    // tslint:disable-next-line
    long_1 || LongExports;
2724 function hexToLong(hex) {
2725 return Long$1.fromString(hex, true, 16);
2726 }
    // Some primes between 2^63 and 2^64 for various uses.
    // These are the mixing constants consumed by the hash helpers below.
    // Hex 0xc3a5c85c97cb3127
    const k0 = hexToLong('c3a5c85c97cb3127');
    // Hex 0xb492b66fbe98f273
    const k1 = hexToLong('b492b66fbe98f273');
    // Hex 0x9ae16a3b2f90404f
    const k2 = hexToLong('9ae16a3b2f90404f');
2734 function shiftMix(val) {
2735 return val.xor(val.shru(47));
2736 }
2737 function fetch$1(s, offset, numBytes) {
2738 const bytes = s.slice(offset, offset + numBytes);
2739 return Long$1.fromBytes(Array.from(bytes), true, true);
2740 }
2741 function fetch64(s, offset) {
2742 return fetch$1(s, offset, 8);
2743 }
2744 function fetch32(s, offset) {
2745 return fetch$1(s, offset, 4);
2746 }
2747 function rotate64(val, shift) {
2748 // Avoid shifting by 64: doing so yields an undefined result.
2749 return shift === 0 ? val : val.shru(shift).or(val.shl(64 - shift));
2750 }
2751 function hashLen16(u, v, mul = hexToLong('9ddfea08eb382d69')) {
2752 // Murmur-inspired hashing.
2753 let a = u.xor(v).mul(mul);
2754 a = a.xor(a.shru(47));
2755 let b = v.xor(a).mul(mul);
2756 b = b.xor(b.shru(47));
2757 b = b.mul(mul);
2758 return b;
2759 }
    // Return a 16-byte hash for 48 bytes. Quick and dirty.
    // Callers do best to use "random-looking" values for a and b.
    function weakHashLen32WithSeeds(w, x, y, z, a, b) {
        a = a.add(w);
        b = rotate64(b.add(a).add(z), 21);
        // Snapshot of `a` before x/y are folded in; mixed back into b below.
        const c = a;
        a = a.add(x);
        a = a.add(y);
        b = b.add(rotate64(a, 44));
        // Returns a pair of 64-bit Longs.
        return [a.add(z), b.add(c)];
    }
2771 function weakHashLen32WithSeedsStr(s, offset, a, b) {
2772 return weakHashLen32WithSeeds(fetch64(s, offset), fetch64(s, offset + 8), fetch64(s, offset + 16), fetch64(s, offset + 24), a, b);
2773 }
    // Hash for inputs of 0 to 16 bytes. Strategy depends on length:
    // >= 8 uses two (possibly overlapping) 64-bit loads, >= 4 uses two 32-bit
    // loads, 1..3 mixes individual bytes, and 0 returns the constant k2.
    function hashLen0to16(s, len = s.length) {
        if (len >= 8) {
            const mul = k2.add(len * 2);
            const a = fetch64(s, 0).add(k2);
            const b = fetch64(s, len - 8); // overlaps `a` when len < 16
            const c = rotate64(b, 37).mul(mul).add(a);
            const d = rotate64(a, 25).add(b).mul(mul);
            return hashLen16(c, d, mul);
        }
        if (len >= 4) {
            const mul = k2.add(len * 2);
            const a = fetch32(s, 0);
            return hashLen16(a.shl(3).add(len), fetch32(s, len - 4), mul);
        }
        if (len > 0) {
            const a = s[0];        // first byte
            const b = s[len >> 1]; // middle byte
            const c = s[len - 1];  // last byte
            const y = a + (b << 8);
            const z = len + (c << 2);
            return shiftMix(k2.mul(y).xor(k0.mul(z))).mul(k2);
        }
        // Empty input hashes to the constant k2.
        return k2;
    }
    // Hash for inputs of 17 to 32 bytes: four 64-bit loads (the two ends of
    // the buffer, so the middle bytes are covered by overlap) folded through
    // hashLen16.
    function hashLen17to32(s, len = s.length) {
        const mul = k2.add(len * 2);
        const a = fetch64(s, 0).mul(k1);
        const b = fetch64(s, 8);
        const c = fetch64(s, len - 8).mul(mul);
        const d = fetch64(s, len - 16).mul(k2);
        return hashLen16(rotate64(a.add(b), 43).add(rotate64(c, 30)).add(d), a.add(rotate64(b.add(k2), 18)).add(c), mul);
    }
    // Hash for inputs of 33 to 64 bytes: eight 64-bit loads (front and back of
    // the buffer, overlapping in the middle) combined through two hashLen16
    // rounds.
    function hashLen33to64(s, len = s.length) {
        const mul = k2.add(len * 2);
        const a = fetch64(s, 0).mul(k2);
        const b = fetch64(s, 8);
        const c = fetch64(s, len - 8).mul(mul);
        const d = fetch64(s, len - 16).mul(k2);
        const y = rotate64(a.add(b), 43).add(rotate64(c, 30)).add(d);
        const z = hashLen16(y, a.add(rotate64(b.add(k2), 18)).add(c), mul);
        const e = fetch64(s, 16).mul(mul);
        const f = fetch64(s, 24);
        const g = y.add(fetch64(s, len - 32)).mul(mul);
        const h = z.add(fetch64(s, len - 24)).mul(mul);
        return hashLen16(rotate64(e.add(f), 43).add(rotate64(g, 30)).add(h), e.add(rotate64(f.add(a), 18)).add(g), mul);
    }
    // Computes a 64-bit fingerprint (Long) of a byte array. Inputs up to 64
    // bytes dispatch to the specialized hashLen* helpers; longer inputs run a
    // 64-byte block loop over rolling state (x, y, z, v, w) and then mix in
    // the final 64 bytes.
    function fingerPrint64(s, len = s.length) {
        const seed = Long$1.fromNumber(81, true);
        if (len <= 32) {
            if (len <= 16) {
                return hashLen0to16(s, len);
            }
            else {
                return hashLen17to32(s, len);
            }
        }
        else if (len <= 64) {
            return hashLen33to64(s, len);
        }
        // For strings over 64 bytes we loop. Internal state consists of
        // 56 bytes: v, w, x, y, and z.
        let x = seed;
        let y = seed.mul(k1).add(113);
        let z = shiftMix(y.mul(k2).add(113)).mul(k2);
        let v = [Long$1.UZERO, Long$1.UZERO];
        let w = [Long$1.UZERO, Long$1.UZERO];
        x = x.mul(k2).add(fetch64(s, 0));
        let offset = 0;
        // Set end so that after the loop we have 1 to 64 bytes left to process.
        const end = ((len - 1) >> 6) * 64;
        const last64 = end + ((len - 1) & 63) - 63;
        do {
            x = rotate64(x.add(y).add(v[0]).add(fetch64(s, offset + 8)), 37).mul(k1);
            y = rotate64(y.add(v[1]).add(fetch64(s, offset + 48)), 42).mul(k1);
            x = x.xor(w[1]);
            y = y.add(v[0]).add(fetch64(s, offset + 40));
            z = rotate64(z.add(w[0]), 33).mul(k1);
            v = weakHashLen32WithSeedsStr(s, offset, v[1].mul(k1), x.add(w[0]));
            w = weakHashLen32WithSeedsStr(s, offset + 32, z.add(w[1]), y.add(fetch64(s, offset + 16)));
            // Swap z and x for the next round.
            [z, x] = [x, z];
            offset += 64;
        } while (offset !== end);
        const mul = k1.add(z.and(0xff).shl(1));
        // Point to the last 64 bytes of input.
        offset = last64;
        w[0] = w[0].add((len - 1) & 63);
        v[0] = v[0].add(w[0]);
        w[0] = w[0].add(v[0]);
        // Final round over the trailing 64 bytes, using `mul` instead of k1.
        x = rotate64(x.add(y).add(v[0]).add(fetch64(s, offset + 8)), 37).mul(mul);
        y = rotate64(y.add(v[1]).add(fetch64(s, offset + 48)), 42).mul(mul);
        x = x.xor(w[1].mul(9));
        y = y.add(v[0].mul(9).add(fetch64(s, offset + 40)));
        z = rotate64(z.add(w[0]), 33).mul(mul);
        v = weakHashLen32WithSeedsStr(s, offset, v[1].mul(mul), x.add(w[0]));
        w = weakHashLen32WithSeedsStr(s, offset + 32, z.add(w[1]), y.add(fetch64(s, offset + 16)));
        [z, x] = [x, z];
        return hashLen16(hashLen16(v[0], w[0], mul).add(shiftMix(y).mul(k0)).add(z), hashLen16(v[1], w[1], mul).add(x), mul);
    }
2872
2873 /**
2874 * @license
2875 * Copyright 2017 Google LLC. All Rights Reserved.
2876 * Licensed under the Apache License, Version 2.0 (the "License");
2877 * you may not use this file except in compliance with the License.
2878 * You may obtain a copy of the License at
2879 *
2880 * http://www.apache.org/licenses/LICENSE-2.0
2881 *
2882 * Unless required by applicable law or agreed to in writing, software
2883 * distributed under the License is distributed on an "AS IS" BASIS,
2884 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2885 * See the License for the specific language governing permissions and
2886 * limitations under the License.
2887 * =============================================================================
2888 */
    /**
     * Create typed array for scalar value. Used for storing in `DataStorage`.
     * Strings are encoded to bytes; numeric scalars become a 1-element typed
     * array of the requested dtype.
     */
    function createScalarValue(value, dtype) {
        if (dtype === 'string') {
            return encodeString(value);
        }
        return toTypedArray([value], dtype);
    }
2898 function noConversionNeeded(a, dtype) {
2899 return (a instanceof Float32Array && dtype === 'float32') ||
2900 (a instanceof Int32Array && dtype === 'int32') ||
2901 (a instanceof Uint8Array && dtype === 'bool');
2902 }
2903 function toTypedArray(a, dtype) {
2904 if (dtype === 'string') {
2905 throw new Error('Cannot convert a string[] to a TypedArray');
2906 }
2907 if (Array.isArray(a)) {
2908 a = flatten(a);
2909 }
2910 if (env().getBool('DEBUG')) {
2911 checkConversionForErrors(a, dtype);
2912 }
2913 if (noConversionNeeded(a, dtype)) {
2914 return a;
2915 }
2916 if (dtype == null || dtype === 'float32' || dtype === 'complex64') {
2917 return new Float32Array(a);
2918 }
2919 else if (dtype === 'int32') {
2920 return new Int32Array(a);
2921 }
2922 else if (dtype === 'bool') {
2923 const bool = new Uint8Array(a.length);
2924 for (let i = 0; i < bool.length; ++i) {
2925 if (Math.round(a[i]) !== 0) {
2926 bool[i] = 1;
2927 }
2928 }
2929 return bool;
2930 }
2931 else {
2932 throw new Error(`Unknown data type ${dtype}`);
2933 }
2934 }
2935 /**
2936 * Returns the current high-resolution time in milliseconds relative to an
2937 * arbitrary time in the past. It works across different platforms (node.js,
2938 * browsers).
2939 *
2940 * ```js
2941 * console.log(tf.util.now());
2942 * ```
2943 *
2944 * @doc {heading: 'Util', namespace: 'util'}
2945 */
2946 function now() {
2947 return env().platform.now();
2948 }
2949 /**
2950 * Returns a platform-specific implementation of
2951 * [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API).
2952 *
2953 * If `fetch` is defined on the global object (`window`, `process`, etc.),
2954 * `tf.util.fetch` returns that function.
2955 *
2956 * If not, `tf.util.fetch` returns a platform-specific solution.
2957 *
2958 * ```js
2959 * const resource = await tf.util.fetch('https://unpkg.com/@tensorflow/tfjs');
2960 * // handle response
2961 * ```
2962 *
2963 * @doc {heading: 'Util'}
2964 */
2965 function fetch$2(path, requestInits) {
2966 return env().platform.fetch(path, requestInits);
2967 }
2968 /**
2969 * Encodes the provided string into bytes using the provided encoding scheme.
2970 *
2971 * @param s The string to encode.
2972 * @param encoding The encoding scheme. Defaults to utf-8.
2973 *
2974 * @doc {heading: 'Util'}
2975 */
2976 function encodeString(s, encoding = 'utf-8') {
2977 encoding = encoding || 'utf-8';
2978 return env().platform.encode(s, encoding);
2979 }
2980 /**
2981 * Decodes the provided bytes into a string using the provided encoding scheme.
2982 * @param bytes The bytes to decode.
2983 *
2984 * @param encoding The encoding scheme. Defaults to utf-8.
2985 *
2986 * @doc {heading: 'Util'}
2987 */
2988 function decodeString(bytes, encoding = 'utf-8') {
2989 encoding = encoding || 'utf-8';
2990 return env().platform.decode(bytes, encoding);
2991 }
2992
    // Public `tf.util` namespace: bundles the helper functions above into a
    // single frozen, prototype-less record. Key order is preserved as written
    // and is observable via enumeration, so entries are not reordered.
    var util = /*#__PURE__*/Object.freeze({
        __proto__: null,
        createScalarValue: createScalarValue,
        toTypedArray: toTypedArray,
        now: now,
        fetch: fetch$2,
        encodeString: encodeString,
        decodeString: decodeString,
        shuffle: shuffle,
        shuffleCombo: shuffleCombo,
        clamp: clamp,
        nearestLargerEven: nearestLargerEven,
        swap: swap,
        sum: sum,
        randUniform: randUniform,
        distSquared: distSquared,
        assert: assert,
        assertShapesMatch: assertShapesMatch,
        assertNonNull: assertNonNull,
        flatten: flatten,
        sizeFromShape: sizeFromShape,
        isScalarShape: isScalarShape,
        arraysEqual: arraysEqual,
        isInt: isInt,
        tanh: tanh,
        sizeToSquarishShape: sizeToSquarishShape,
        createShuffledIndices: createShuffledIndices,
        rightPad: rightPad,
        repeatedTry: repeatedTry,
        inferFromImplicitShape: inferFromImplicitShape,
        parseAxisParam: parseAxisParam,
        squeezeShape: squeezeShape,
        getTypedArrayFromDType: getTypedArrayFromDType,
        getArrayFromDType: getArrayFromDType,
        checkConversionForErrors: checkConversionForErrors,
        isValidDtype: isValidDtype,
        hasEncodingLoss: hasEncodingLoss,
        isTypedArray: isTypedArray,
        bytesPerElement: bytesPerElement,
        bytesFromStringArray: bytesFromStringArray,
        isString: isString,
        isBoolean: isBoolean,
        isNumber: isNumber,
        inferDtype: inferDtype,
        isFunction: isFunction,
        nearestDivisor: nearestDivisor,
        computeStrides: computeStrides,
        toNestedArray: toNestedArray,
        makeOnesTypedArray: makeOnesTypedArray,
        makeZerosTypedArray: makeZerosTypedArray,
        makeZerosNestedTypedArray: makeZerosNestedTypedArray,
        assertNonNegativeIntegerDimensions: assertNonNegativeIntegerDimensions,
        locToIndex: locToIndex,
        indexToLoc: indexToLoc,
        isPromise: isPromise,
        hexToLong: hexToLong,
        fingerPrint64: fingerPrint64
    });
3051
3052 /**
3053 * @license
3054 * Copyright 2018 Google LLC. All Rights Reserved.
3055 * Licensed under the Apache License, Version 2.0 (the "License");
3056 * you may not use this file except in compliance with the License.
3057 * You may obtain a copy of the License at
3058 *
3059 * http://www.apache.org/licenses/LICENSE-2.0
3060 *
3061 * Unless required by applicable law or agreed to in writing, software
3062 * distributed under the License is distributed on an "AS IS" BASIS,
3063 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3064 * See the License for the specific language governing permissions and
3065 * limitations under the License.
3066 * =============================================================================
3067 */
    /**
     * Times kernel executions using the backend's timer (when available) and
     * optionally scans results for non-finite values. Profiles are reported
     * through an injected Logger.
     */
    class Profiler {
        constructor(backendTimer, logger) {
            this.backendTimer = backendTimer;
            this.logger = logger;
            // Fall back to the console-based Logger when none is supplied.
            if (logger == null) {
                this.logger = new Logger();
            }
        }
        /**
         * Runs the kernel function `f` and returns a profile record.
         *
         * @param kernelName Name of the kernel, used in logs/warnings.
         * @param inputs The kernel's named inputs (attached to the profile).
         * @param f Zero-arg function that executes the kernel and returns its
         *     output tensors.
         * @returns A profile whose `timeMs`/`extraInfo` fields are promises that
         *     resolve when timing completes.
         */
        profileKernel(kernelName, inputs, f) {
            let outputs;
            const holdResultWrapperFn = () => {
                outputs = f();
            };
            let timer;
            const start = now();
            if (this.backendTimer.timerAvailable()) {
                timer = this.backendTimer.time(holdResultWrapperFn);
            }
            else {
                // No backend timer: fall back to wall-clock time, forcing the
                // outputs to resolve via dataSync() so the work is included.
                holdResultWrapperFn();
                for (const output of outputs) {
                    output.dataSync();
                }
                timer = Promise.resolve({ kernelMs: now() - start });
            }
            if (env().getBool('CHECK_COMPUTATION_FOR_ERRORS')) {
                for (let i = 0; i < outputs.length; i++) {
                    const output = outputs[i];
                    // Dangling promise here because we don't want to propagate up
                    // asynchronicity.
                    output.data().then(tensorVals => {
                        checkComputationForErrors(tensorVals, output.dtype, kernelName);
                    });
                }
            }
            const kernelProfile = {
                kernelName,
                outputs,
                inputs,
                timeMs: timer.then(timing => timing.kernelMs),
                // Not every backend timer provides extra info; default to ''.
                extraInfo: timer.then(timing => timing.getExtraProfileInfo != null ?
                    timing.getExtraProfileInfo() :
                    '')
            };
            return kernelProfile;
        }
        /** Logs one line per output tensor once its data and timing resolve. */
        logKernelProfile(kernelProfile) {
            const { kernelName, outputs, timeMs, inputs, extraInfo } = kernelProfile;
            outputs.forEach(result => {
                Promise.all([result.data(), timeMs, extraInfo]).then(valueContainer => {
                    this.logger.logKernelProfile(kernelName, result, valueContainer[0], valueContainer[1], inputs, valueContainer[2]);
                });
            });
        }
    }
3123 function checkComputationForErrors(vals, dtype, kernelName) {
3124 if (dtype !== 'float32') {
3125 // Only floating point computations will generate NaN values
3126 return false;
3127 }
3128 for (let i = 0; i < vals.length; i++) {
3129 const num = vals[i];
3130 if (isNaN(num) || !isFinite(num)) {
3131 // Throwing custom exception so behavior is testable.
3132 console.warn(`Found ${num} in the result of '${kernelName}'`);
3133 return true;
3134 }
3135 }
3136 return false;
3137 }
3138 class Logger {
3139 logKernelProfile(name, result, vals, timeMs, inputs, extraInfo) {
3140 const time = typeof timeMs === 'number' ? rightPad(`${timeMs}ms`, 9) :
3141 timeMs['error'];
3142 const paddedName = rightPad(name, 25);
3143 const rank = result.rank;
3144 const size = result.size;
3145 const shape = rightPad(result.shape.toString(), 14);
3146 let inputShapesDescription = '';
3147 for (const name in inputs) {
3148 const input = inputs[name];
3149 if (input != null) {
3150 // The input might be a non-tensor (e.g HTMLImageElement), in which case
3151 // we claim the output shape as input shape.
3152 const inputShape = input.shape || result.shape;
3153 const inputRank = inputShape.length;
3154 inputShapesDescription +=
3155 `${name}: ${inputRank}D ${inputRank > 0 ? inputShape : ''} `;
3156 }
3157 }
3158 console.log(`%c${paddedName}\t%c${time}\t%c${rank}D ${shape}\t%c${size}\t%c${inputShapesDescription}\t%c${extraInfo}`, 'font-weight:bold', 'color:red', 'color:blue', 'color: orange', 'color: green', 'color: steelblue');
3159 }
3160 }
3161
3162 /**
3163 * @license
3164 * Copyright 2017 Google LLC. All Rights Reserved.
3165 * Licensed under the Apache License, Version 2.0 (the "License");
3166 * you may not use this file except in compliance with the License.
3167 * You may obtain a copy of the License at
3168 *
3169 * http://www.apache.org/licenses/LICENSE-2.0
3170 *
3171 * Unless required by applicable law or agreed to in writing, software
3172 * distributed under the License is distributed on an "AS IS" BASIS,
3173 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3174 * See the License for the specific language governing permissions and
3175 * limitations under the License.
3176 * =============================================================================
3177 */
3178 /**
3179 * Computes a list of TapeNodes that connect x to y, filtering everything else
3180 * out and preserving the order of the original tape elements.
3181 *
3182 * @param tape The tape elements to filter.
3183 * @param xs The input Tensors.
3184 * @param y The output Tensor.
3185 */
3186 function getFilteredNodesXToY(tape, xs, y) {
3187 // Forward pass to compute all the nodes and Tensors that are transitively a
3188 // function of x.
3189 const tensorsFromX = {};
3190 const nodesFromX = {};
3191 for (let i = 0; i < xs.length; i++) {
3192 tensorsFromX[xs[i].id] = true;
3193 }
3194 for (let i = 0; i < tape.length; i++) {
3195 const node = tape[i];
3196 const nodeInputs = node.inputs;
3197 for (const inputName in nodeInputs) {
3198 const input = nodeInputs[inputName];
3199 let anyInputFromX = false;
3200 for (let j = 0; j < xs.length; j++) {
3201 if (tensorsFromX[input.id]) {
3202 node.outputs.forEach(output => tensorsFromX[output.id] = true);
3203 anyInputFromX = true;
3204 nodesFromX[node.id] = true;
3205 break;
3206 }
3207 }
3208 if (anyInputFromX) {
3209 break;
3210 }
3211 }
3212 }
3213 // Backward pass to find all of the nodes and Tensors that lead to y.
3214 const tensorsLeadToY = {};
3215 tensorsLeadToY[y.id] = true;
3216 const nodesToY = {};
3217 for (let i = tape.length - 1; i >= 0; i--) {
3218 const node = tape[i];
3219 const nodeInputs = node.inputs;
3220 // If any of the outputs lead to y, mark all of the inputs as leading to y.
3221 for (let j = 0; j < node.outputs.length; j++) {
3222 if (tensorsLeadToY[node.outputs[j].id]) {
3223 for (const inputName in nodeInputs) {
3224 tensorsLeadToY[nodeInputs[inputName].id] = true;
3225 nodesToY[node.id] = true;
3226 }
3227 break;
3228 }
3229 }
3230 }
3231 // Return the paths that come from x and lead to y.
3232 const filteredTape = [];
3233 for (let i = 0; i < tape.length; i++) {
3234 const node = tape[i];
3235 if (nodesFromX[node.id] && nodesToY[node.id]) {
3236 // Prune the inputs from the node that aren't a function of x.
3237 const prunedInputs = {};
3238 for (const inputName in node.inputs) {
3239 const nodeInput = node.inputs[inputName];
3240 if (tensorsFromX[nodeInput.id]) {
3241 prunedInputs[inputName] = nodeInput;
3242 }
3243 }
3244 // Copy the node and overwrite inputsAndArgs to the pruned version.
3245 const prunedNode = Object.assign({}, node);
3246 prunedNode.inputs = prunedInputs;
3247 prunedNode.outputs = node.outputs;
3248 filteredTape.push(prunedNode);
3249 }
3250 }
3251 return filteredTape;
3252 }
    /**
     * Backpropagate gradients through the filtered TapeNodes.
     *
     * @param tensorAccumulatedGradientMap A map of Tensor to its gradient. This map
     * is mutated by this method.
     * @param filteredTape The filtered TapeNodes to backprop through.
     * @param tidy Function wrapping each per-input gradient computation so
     * intermediates are cleaned up (injected to avoid a circular dependency).
     * @param add Function used to accumulate two gradient tensors into one.
     */
    function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy, add) {
        // Walk the tape backward and keep a map of Tensor to its gradient.
        for (let i = filteredTape.length - 1; i >= 0; i--) {
            const node = filteredTape[i];
            const dys = [];
            node.outputs.forEach(o => {
                const gradTensor = tensorAccumulatedGradientMap[o.id];
                if (gradTensor != null) {
                    dys.push(gradTensor);
                }
                else {
                    // This particular output is not in the back-propagation subgraph, so it
                    // does not affect the final output, thus we put null for its dy.
                    dys.push(null);
                }
            });
            if (node.gradient == null) {
                throw new Error(`Cannot compute gradient: gradient function not found ` +
                    `for ${node.kernelName}.`);
            }
            // Backprop dy through this node and accumulate gradients over the inputs.
            const inputGradients = node.gradient(dys);
            for (const inputName in node.inputs) {
                if (!(inputName in inputGradients)) {
                    throw new Error(`Cannot backprop through input ${inputName}. ` +
                        `Available gradients found: ${Object.keys(inputGradients)}.`);
                }
                // Call the gradient function.
                const dx = tidy(() => inputGradients[inputName]());
                if (dx.dtype !== 'float32') {
                    throw new Error(`Error in gradient for op ${node.kernelName}. The gradient of input ` +
                        `${inputName} must have 'float32' dtype, but has '${dx.dtype}'`);
                }
                const x = node.inputs[inputName];
                if (!arraysEqual(dx.shape, x.shape)) {
                    throw new Error(`Error in gradient for op ${node.kernelName}. The gradient of input ` +
                        `'${inputName}' has shape '${dx.shape}', which does not match ` +
                        `the shape of the input '${x.shape}'`);
                }
                if (tensorAccumulatedGradientMap[x.id] == null) {
                    // First gradient seen for this input: store it directly.
                    tensorAccumulatedGradientMap[x.id] = dx;
                }
                else {
                    // Subsequent gradients are summed into the accumulator; the
                    // previous accumulator tensor is disposed only after the sum
                    // has been computed.
                    const curGradient = tensorAccumulatedGradientMap[x.id];
                    tensorAccumulatedGradientMap[x.id] = add(curGradient, dx);
                    curGradient.dispose();
                }
            }
        }
    }
3310
3311 /**
3312 * @license
3313 * Copyright 2018 Google LLC. All Rights Reserved.
3314 * Licensed under the Apache License, Version 2.0 (the "License");
3315 * you may not use this file except in compliance with the License.
3316 * You may obtain a copy of the License at
3317 *
3318 * http://www.apache.org/licenses/LICENSE-2.0
3319 *
3320 * Unless required by applicable law or agreed to in writing, software
3321 * distributed under the License is distributed on an "AS IS" BASIS,
3322 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3323 * See the License for the specific language governing permissions and
3324 * limitations under the License.
3325 * =============================================================================
3326 */
    // Maximum number of values along a dimension before we elide the middle
    // with an ellipsis when printing.
    const FORMAT_LIMIT_NUM_VALS = 20;
    // Number of first and last values to show when displaying a, b,...,y, z.
    const FORMAT_NUM_FIRST_LAST_VALS = 3;
    // Number of significant digits to show for numeric values.
    const FORMAT_NUM_SIG_DIGITS = 7;
    /**
     * Renders a tensor's flat values as a human-readable string: a 'Tensor'
     * header, optional verbose metadata (dtype/rank/shape), then the nested
     * value lines produced by subTensorToString.
     *
     * @param vals Flat value storage for the tensor.
     * @param shape The tensor's shape.
     * @param dtype The tensor's dtype (affects value formatting).
     * @param verbose Whether to include dtype/rank/shape metadata lines.
     */
    function tensorToString(vals, shape, dtype, verbose) {
        const strides = computeStrides(shape);
        // Per-column pad widths so values in rank>1 tensors line up.
        const padPerCol = computeMaxSizePerColumn(vals, shape, dtype, strides);
        const rank = shape.length;
        const valsLines = subTensorToString(vals, shape, dtype, strides, padPerCol);
        const lines = ['Tensor'];
        if (verbose) {
            lines.push(` dtype: ${dtype}`);
            lines.push(` rank: ${rank}`);
            lines.push(` shape: [${shape}]`);
            lines.push(` values:`);
        }
        lines.push(valsLines.map(l => ' ' + l).join('\n'));
        return lines.join('\n');
    }
3348 function computeMaxSizePerColumn(vals, shape, dtype, strides) {
3349 const n = sizeFromShape(shape);
3350 const numCols = strides[strides.length - 1];
3351 const padPerCol = new Array(numCols).fill(0);
3352 const rank = shape.length;
3353 const valuesOrTuples = dtype === 'complex64' ? createComplexTuples(vals) : vals;
3354 if (rank > 1) {
3355 for (let row = 0; row < n / numCols; row++) {
3356 const offset = row * numCols;
3357 for (let j = 0; j < numCols; j++) {
3358 padPerCol[j] = Math.max(padPerCol[j], valToString(valuesOrTuples[offset + j], 0, dtype).length);
3359 }
3360 }
3361 }
3362 return padPerCol;
3363 }
3364 function valToString(val, pad, dtype) {
3365 let valStr;
3366 if (Array.isArray(val)) {
3367 valStr = `${parseFloat(val[0].toFixed(FORMAT_NUM_SIG_DIGITS))} + ` +
3368 `${parseFloat(val[1].toFixed(FORMAT_NUM_SIG_DIGITS))}j`;
3369 }
3370 else if (isString(val)) {
3371 valStr = `'${val}'`;
3372 }
3373 else if (dtype === 'bool') {
3374 valStr = boolNumToString(val);
3375 }
3376 else {
3377 valStr = parseFloat(val.toFixed(FORMAT_NUM_SIG_DIGITS)).toString();
3378 }
3379 return rightPad(valStr, pad);
3380 }
3381 function boolNumToString(v) {
3382 return v === 0 ? 'false' : 'true';
3383 }
    /**
     * Recursively renders the flat values of a (sub-)tensor of `shape` into
     * display lines, eliding the middle with '...' when a dimension exceeds
     * FORMAT_LIMIT_NUM_VALS.
     *
     * @param isLast Whether this sub-tensor is the last child of its parent;
     * controls the trailing separator/newlines appended to the closing bracket.
     * @returns An array of display lines (one element for rank <= 1).
     */
    function subTensorToString(vals, shape, dtype, strides, padPerCol, isLast = true) {
        // complex64 stores two numbers (re, im) per logical element.
        const storagePerElement = dtype === 'complex64' ? 2 : 1;
        const size = shape[0];
        const rank = shape.length;
        if (rank === 0) {
            // Scalar: a single formatted value, no brackets.
            if (dtype === 'complex64') {
                const complexTuple = createComplexTuples(vals);
                return [valToString(complexTuple[0], 0, dtype)];
            }
            if (dtype === 'bool') {
                return [boolNumToString(vals[0])];
            }
            return [vals[0].toString()];
        }
        if (rank === 1) {
            if (size > FORMAT_LIMIT_NUM_VALS) {
                // Too many values: show the first/last few around an ellipsis.
                const firstValsSize = FORMAT_NUM_FIRST_LAST_VALS * storagePerElement;
                let firstVals = Array.from(vals.slice(0, firstValsSize));
                let lastVals = Array.from(vals.slice((size - FORMAT_NUM_FIRST_LAST_VALS) * storagePerElement, size * storagePerElement));
                if (dtype === 'complex64') {
                    firstVals = createComplexTuples(firstVals);
                    lastVals = createComplexTuples(lastVals);
                }
                return [
                    '[' +
                        firstVals.map((x, i) => valToString(x, padPerCol[i], dtype))
                            .join(', ') +
                        ', ..., ' +
                        lastVals
                            .map((x, i) => valToString(x, padPerCol[size - FORMAT_NUM_FIRST_LAST_VALS + i], dtype))
                            .join(', ') +
                        ']'
                ];
            }
            const displayVals = dtype === 'complex64' ? createComplexTuples(vals) :
                Array.from(vals);
            return [
                '[' +
                    displayVals.map((x, i) => valToString(x, padPerCol[i], dtype))
                        .join(', ') +
                    ']'
            ];
        }
        // The array is rank 2 or more.
        const subshape = shape.slice(1);
        const substrides = strides.slice(1);
        // Number of stored values per slice along the first dimension.
        const stride = strides[0] * storagePerElement;
        const lines = [];
        if (size > FORMAT_LIMIT_NUM_VALS) {
            for (let i = 0; i < FORMAT_NUM_FIRST_LAST_VALS; i++) {
                const start = i * stride;
                const end = start + stride;
                lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, false /* isLast */));
            }
            lines.push('...');
            for (let i = size - FORMAT_NUM_FIRST_LAST_VALS; i < size; i++) {
                const start = i * stride;
                const end = start + stride;
                lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */));
            }
        }
        else {
            for (let i = 0; i < size; i++) {
                const start = i * stride;
                const end = start + stride;
                lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */));
            }
        }
        // Rank-2 rows get a trailing comma; deeper ranks separate via newlines.
        const sep = rank === 2 ? ',' : '';
        lines[0] = '[' + lines[0] + sep;
        for (let i = 1; i < lines.length - 1; i++) {
            lines[i] = ' ' + lines[i] + sep;
        }
        // One extra blank line per additional dimension beyond 2.
        let newLineSep = ',\n';
        for (let i = 2; i < rank; i++) {
            newLineSep += '\n';
        }
        lines[lines.length - 1] =
            ' ' + lines[lines.length - 1] + ']' + (isLast ? '' : newLineSep);
        return lines;
    }
3465 function createComplexTuples(vals) {
3466 const complexTuples = [];
3467 for (let i = 0; i < vals.length; i += 2) {
3468 complexTuples.push([vals[i], vals[i + 1]]);
3469 }
3470 return complexTuples;
3471 }
3472
3473 /**
3474 * @license
3475 * Copyright 2017 Google LLC. All Rights Reserved.
3476 * Licensed under the Apache License, Version 2.0 (the "License");
3477 * you may not use this file except in compliance with the License.
3478 * You may obtain a copy of the License at
3479 *
3480 * http://www.apache.org/licenses/LICENSE-2.0
3481 *
3482 * Unless required by applicable law or agreed to in writing, software
3483 * distributed under the License is distributed on an "AS IS" BASIS,
3484 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3485 * See the License for the specific language governing permissions and
3486 * limitations under the License.
3487 * =============================================================================
3488 */
3489 /**
3490 * A mutable object, similar to `tf.Tensor`, that allows users to set values
3491 * at locations before converting to an immutable `tf.Tensor`.
3492 *
3493 * See `tf.buffer` for creating a tensor buffer.
3494 *
3495 * @doc {heading: 'Tensors', subheading: 'Classes'}
3496 */
3497 class TensorBuffer {
3498 constructor(shape, dtype, values) {
3499 this.dtype = dtype;
3500 this.shape = shape.slice();
3501 this.size = sizeFromShape(shape);
3502 if (values != null) {
3503 const n = values.length;
3504 assert(n === this.size, () => `Length of values '${n}' does not match the size ` +
3505 `inferred by the shape '${this.size}'.`);
3506 }
3507 if (dtype === 'complex64') {
3508 throw new Error(`complex64 dtype TensorBuffers are not supported. Please create ` +
3509 `a TensorBuffer for the real and imaginary parts separately and ` +
3510 `call tf.complex(real, imag).`);
3511 }
3512 this.values = values || getArrayFromDType(dtype, this.size);
3513 this.strides = computeStrides(shape);
3514 }
3515 /**
3516 * Sets a value in the buffer at a given location.
3517 *
3518 * @param value The value to set.
3519 * @param locs The location indices.
3520 *
3521 * @doc {heading: 'Tensors', subheading: 'Creation'}
3522 */
3523 set(value, ...locs) {
3524 if (locs.length === 0) {
3525 locs = [0];
3526 }
3527 assert(locs.length === this.rank, () => `The number of provided coordinates (${locs.length}) must ` +
3528 `match the rank (${this.rank})`);
3529 const index = this.locToIndex(locs);
3530 this.values[index] = value;
3531 }
3532 /**
3533 * Returns the value in the buffer at the provided location.
3534 *
3535 * @param locs The location indices.
3536 *
3537 * @doc {heading: 'Tensors', subheading: 'Creation'}
3538 */
3539 get(...locs) {
3540 if (locs.length === 0) {
3541 locs = [0];
3542 }
3543 let i = 0;
3544 for (const loc of locs) {
3545 if (loc < 0 || loc >= this.shape[i]) {
3546 const msg = `Requested out of range element at ${locs}. ` +
3547 ` Buffer shape=${this.shape}`;
3548 throw new Error(msg);
3549 }
3550 i++;
3551 }
3552 let index = locs[locs.length - 1];
3553 for (let i = 0; i < locs.length - 1; ++i) {
3554 index += this.strides[i] * locs[i];
3555 }
3556 return this.values[index];
3557 }
3558 locToIndex(locs) {
3559 if (this.rank === 0) {
3560 return 0;
3561 }
3562 else if (this.rank === 1) {
3563 return locs[0];
3564 }
3565 let index = locs[locs.length - 1];
3566 for (let i = 0; i < locs.length - 1; ++i) {
3567 index += this.strides[i] * locs[i];
3568 }
3569 return index;
3570 }
3571 indexToLoc(index) {
3572 if (this.rank === 0) {
3573 return [];
3574 }
3575 else if (this.rank === 1) {
3576 return [index];
3577 }
3578 const locs = new Array(this.shape.length);
3579 for (let i = 0; i < locs.length - 1; ++i) {
3580 locs[i] = Math.floor(index / this.strides[i]);
3581 index -= locs[i] * this.strides[i];
3582 }
3583 locs[locs.length - 1] = index;
3584 return locs;
3585 }
3586 get rank() {
3587 return this.shape.length;
3588 }
3589 /**
3590 * Creates an immutable `tf.Tensor` object from the buffer.
3591 *
3592 * @doc {heading: 'Tensors', subheading: 'Creation'}
3593 */
3594 toTensor() {
3595 return trackerFn().makeTensor(this.values, this.shape, this.dtype);
3596 }
3597 }
    // Module-level injection points that break circular dependencies between
    // the Tensor class and the engine/ops modules. They are populated at
    // startup via the setters below.
    // For tracking tensor creation and disposal.
    let trackerFn = null;
    // Used by chaining methods to call into ops.
    let opHandler = null;
    // Used to warn about deprecated methods.
    let deprecationWarningFn = null;
    // This here so that we can use this method on dev branches and keep the
    // functionality at master.
    // tslint:disable-next-line:no-unused-expression
    [deprecationWarningFn];
    /**
     * An external consumer can register itself as the tensor tracker. This way
     * the Tensor class can notify the tracker for every tensor created and
     * disposed.
     *
     * @param fn Zero-arg function returning the active tensor tracker.
     */
    function setTensorTracker(fn) {
        trackerFn = fn;
    }
    /**
     * An external consumer can register itself as the op handler. This way the
     * Tensor class can have chaining methods that call into ops via the op
     * handler.
     *
     * @param handler Object exposing the op implementations Tensor delegates to.
     */
    function setOpHandler(handler) {
        opHandler = handler;
    }
    /**
     * Sets the deprecation warning function to be used by this file. This way the
     * Tensor class can be a leaf but still use the environment.
     *
     * @param fn Function invoked with a deprecation message.
     */
    function setDeprecationWarningFn(fn) {
        deprecationWarningFn = fn;
    }
    /**
     * A `tf.Tensor` object represents an immutable, multidimensional array of
     * numbers that has a shape and a data type.
     *
     * For performance reasons, functions that create tensors do not necessarily
     * perform a copy of the data passed to them (e.g. if the data is passed as a
     * `Float32Array`), and changes to the data will change the tensor. This is not
     * a feature and is not supported. To avoid this behavior, use the tensor before
     * changing the input data or create a copy with `copy = tf.add(yourTensor, 0)`.
     *
     * See `tf.tensor` for details on how to create a `tf.Tensor`.
     *
     * @doc {heading: 'Tensors', subheading: 'Classes'}
     */
    class Tensor {
        constructor(shape, dtype, dataId, id) {
            /** Whether this tensor has been globally kept. */
            this.kept = false;
            this.isDisposedInternal = false;
            this.shape = shape.slice();
            this.dtype = dtype || 'float32';
            this.size = sizeFromShape(shape);
            this.strides = computeStrides(shape);
            // Handle to the backing data in the tracker; all reads below go
            // through trackerFn().read*/readToGPU with this id.
            this.dataId = dataId;
            this.id = id;
            this.rankType = (this.rank < 5 ? this.rank.toString() : 'higher');
        }
        get rank() {
            return this.shape.length;
        }
        /**
         * Returns a promise of `tf.TensorBuffer` that holds the underlying data.
         *
         * @doc {heading: 'Tensors', subheading: 'Classes'}
         */
        async buffer() {
            const vals = await this.data();
            return opHandler.buffer(this.shape, this.dtype, vals);
        }
        /**
         * Returns a `tf.TensorBuffer` that holds the underlying data.
         * @doc {heading: 'Tensors', subheading: 'Classes'}
         */
        bufferSync() {
            return opHandler.buffer(this.shape, this.dtype, this.dataSync());
        }
        /**
         * Returns the tensor data as a nested array. The transfer of data is done
         * asynchronously.
         *
         * @doc {heading: 'Tensors', subheading: 'Classes'}
         */
        async array() {
            const vals = await this.data();
            return toNestedArray(this.shape, vals, this.dtype === 'complex64');
        }
        /**
         * Returns the tensor data as a nested array. The transfer of data is done
         * synchronously.
         *
         * @doc {heading: 'Tensors', subheading: 'Classes'}
         */
        arraySync() {
            return toNestedArray(this.shape, this.dataSync(), this.dtype === 'complex64');
        }
        /**
         * Asynchronously downloads the values from the `tf.Tensor`. Returns a
         * promise of `TypedArray` that resolves when the computation has finished.
         *
         * @doc {heading: 'Tensors', subheading: 'Classes'}
         */
        async data() {
            this.throwIfDisposed();
            const data = trackerFn().read(this.dataId);
            // String tensors are stored as raw bytes; decode them to JS strings
            // here, surfacing a helpful error for non-utf-8 data.
            if (this.dtype === 'string') {
                const bytes = await data;
                try {
                    return bytes.map(b => decodeString(b));
                }
                catch (_a) {
                    throw new Error('Failed to decode the string bytes into utf-8. ' +
                        'To get the original bytes, call tensor.bytes().');
                }
            }
            return data;
        }
        /**
         * Copy the tensor's data to a new GPU resource. Comparing to the `dataSync()`
         * and `data()`, this method prevents data from being downloaded to CPU.
         *
         * For WebGL backend, the data will be stored on a densely packed texture.
         * This means that the texture will use the RGBA channels to store value.
         *
         * For WebGPU backend, the data will be stored on a buffer. There is no
         * parameter, so can not use an user defined size to create the buffer.
         *
         * @param options:
         *     For WebGL,
         *         - customTexShape: Optional. If set, will use the user defined
         *     texture shape to create the texture.
         *
         * @returns For WebGL backend, a GPUData contains the new texture and
         *     its information.
         *     {
         *        tensorRef: The tensor that is associated with this texture,
         *        texture: WebGLTexture,
         *        texShape: [number, number] // [height, width]
         *     }
         *
         *     For WebGPU backend, a GPUData contains the new buffer and
         *     its information.
         *     {
         *        tensorRef: The tensor that is associated with this buffer,
         *        buffer: GPUBuffer,
         *        bufSize: number
         *     }
         *
         *     Remember to dispose the GPUData after it is used by
         *     `res.tensorRef.dispose()`.
         *
         * @doc {heading: 'Tensors', subheading: 'Classes'}
         */
        dataToGPU(options) {
            this.throwIfDisposed();
            return trackerFn().readToGPU(this.dataId, options);
        }
        /**
         * Synchronously downloads the values from the `tf.Tensor`. This blocks the
         * UI thread until the values are ready, which can cause performance issues.
         *
         * @doc {heading: 'Tensors', subheading: 'Classes'}
         */
        dataSync() {
            this.throwIfDisposed();
            const data = trackerFn().readSync(this.dataId);
            // Same string-decoding behavior as data(), but synchronous.
            if (this.dtype === 'string') {
                try {
                    return data.map(b => decodeString(b));
                }
                catch (_a) {
                    throw new Error('Failed to decode the string bytes into utf-8. ' +
                        'To get the original bytes, call tensor.bytes().');
                }
            }
            return data;
        }
        /** Returns the underlying bytes of the tensor's data. */
        async bytes() {
            this.throwIfDisposed();
            const data = await trackerFn().read(this.dataId);
            if (this.dtype === 'string') {
                // Already an array of byte arrays; return it undecoded.
                return data;
            }
            else {
                // View the typed array's backing buffer as raw bytes (no copy).
                return new Uint8Array(data.buffer);
            }
        }
        /**
         * Disposes `tf.Tensor` from memory.
         *
         * @doc {heading: 'Tensors', subheading: 'Classes'}
         */
        dispose() {
            // Idempotent: repeated calls after the first are no-ops.
            if (this.isDisposed) {
                return;
            }
            trackerFn().disposeTensor(this);
            this.isDisposedInternal = true;
        }
        get isDisposed() {
            return this.isDisposedInternal;
        }
        throwIfDisposed() {
            if (this.isDisposed) {
                throw new Error(`Tensor is disposed.`);
            }
        }
        /**
         * Prints the `tf.Tensor`. See `tf.print` for details.
         *
         * @param verbose Whether to print verbose information about the tensor,
         *    including dtype and size.
         *
         * @doc {heading: 'Tensors', subheading: 'Classes'}
         */
        print(verbose = false) {
            return opHandler.print(this, verbose);
        }
        /**
         * Returns a copy of the tensor. See `tf.clone` for details.
         * @doc {heading: 'Tensors', subheading: 'Classes'}
         */
        clone() {
            this.throwIfDisposed();
            return opHandler.clone(this);
        }
        /**
         * Returns a human-readable description of the tensor. Useful for logging.
         *
         * @doc {heading: 'Tensors', subheading: 'Classes'}
         */
        toString(verbose = false) {
            const vals = this.dataSync();
            return tensorToString(vals, this.shape, this.dtype, verbose);
        }
        /** Casts this tensor to `dtype` via the registered op handler. */
        cast(dtype) {
            this.throwIfDisposed();
            return opHandler.cast(this, dtype);
        }
        /** Wraps this tensor in a trainable (by default) `Variable`. */
        variable(trainable = true, name, dtype) {
            this.throwIfDisposed();
            return trackerFn().makeVariable(this, trainable, name, dtype);
        }
    }
3845 Object.defineProperty(Tensor, Symbol.hasInstance, {
3846 value: (instance) => {
3847 // Implementation note: we should use properties of the object that will be
3848 // defined before the constructor body has finished executing (methods).
3849 // This is because when this code is transpiled by babel, babel will call
3850 // classCallCheck before the constructor body is run.
3851 // See https://github.com/tensorflow/tfjs/issues/3384 for backstory.
3852 return !!instance && instance.data != null && instance.dataSync != null &&
3853 instance.throwIfDisposed != null;
3854 }
3855 });
3856 function getGlobalTensorClass() {
3857 // Use getGlobal so that we can augment the Tensor class across package
3858 // boundaries becase the node resolution alg may result in different modules
3859 // being returned for this file depending on the path they are loaded from.
3860 return getGlobal('Tensor', () => {
3861 return Tensor;
3862 });
3863 }
3864 // Global side effect. Cache global reference to Tensor class
3865 getGlobalTensorClass();
    /**
     * A mutable `tf.Tensor`, useful for persisting state, e.g. for training.
     *
     * @doc {heading: 'Tensors', subheading: 'Classes'}
     */
    class Variable extends Tensor {
        constructor(initialValue, trainable, name, tensorId) {
            // Aliases the initial value's backing data (same dataId) rather than
            // copying it.
            super(initialValue.shape, initialValue.dtype, initialValue.dataId, tensorId);
            this.trainable = trainable;
            this.name = name;
        }
        /**
         * Assign a new `tf.Tensor` to this variable. The new `tf.Tensor` must have
         * the same shape and dtype as the old `tf.Tensor`.
         *
         * @param newValue New tensor to be assigned to this variable.
         *
         * @doc {heading: 'Tensors', subheading: 'Classes'}
         */
        assign(newValue) {
            if (newValue.dtype !== this.dtype) {
                throw new Error(`dtype of the new value (${newValue.dtype}) and ` +
                    `previous value (${this.dtype}) must match`);
            }
            if (!arraysEqual(newValue.shape, this.shape)) {
                throw new Error(`shape of the new value (${newValue.shape}) and ` +
                    `previous value (${this.shape}) must match`);
            }
            // Release the variable's current backing data, then alias the new
            // tensor's data and bump its refcount so the data outlives newValue.
            trackerFn().disposeTensor(this);
            this.dataId = newValue.dataId;
            trackerFn().incRef(this, null /* backend */);
        }
        dispose() {
            trackerFn().disposeVariable(this);
            this.isDisposedInternal = true;
        }
    }
Object.defineProperty(Variable, Symbol.hasInstance, {
    // A Variable is any Tensor that carries an assign() method.
    value: (instance) => (instance instanceof Tensor &&
        instance.assign instanceof Function)
});
3909
3910 /**
3911 * @license
3912 * Copyright 2017 Google LLC. All Rights Reserved.
3913 * Licensed under the Apache License, Version 2.0 (the "License");
3914 * you may not use this file except in compliance with the License.
3915 * You may obtain a copy of the License at
3916 *
3917 * http://www.apache.org/licenses/LICENSE-2.0
3918 *
3919 * Unless required by applicable law or agreed to in writing, software
3920 * distributed under the License is distributed on an "AS IS" BASIS,
3921 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3922 * See the License for the specific language governing permissions and
3923 * limitations under the License.
3924 * =============================================================================
3925 */
// TS-style string enum of tensor ranks, exposed as exports.Rank.
(function (Rank) {
    for (const rank of ['R0', 'R1', 'R2', 'R3', 'R4', 'R5', 'R6']) {
        Rank[rank] = rank;
    }
})(exports.Rank || (exports.Rank = {}));
// Looks for upcasting types. Used, for example, in operations with mixed dtype
// inputs. Each map answers: "what dtype results from combining <the map's
// dtype> with the key dtype?"
var UpcastInt32AndMap = {
    'float32': 'float32',
    'int32': 'int32',
    'bool': 'int32',
    'complex64': 'complex64'
};
var UpcastBoolAndMap = {
    'float32': 'float32',
    'int32': 'int32',
    'bool': 'bool',
    'complex64': 'complex64'
};
var UpcastFloat32AndMap = {
    'float32': 'float32',
    'int32': 'float32',
    'bool': 'float32',
    'complex64': 'complex64'
};
var UpcastComplex64AndMap = {
    'float32': 'complex64',
    'int32': 'complex64',
    'bool': 'complex64',
    'complex64': 'complex64'
};
const upcastTypeMap = {
    'float32': UpcastFloat32AndMap,
    'int32': UpcastInt32AndMap,
    'bool': UpcastBoolAndMap,
    'complex64': UpcastComplex64AndMap
};
/**
 * Returns the dtype that results from combining `typeA` and `typeB`.
 * 'string' only combines with 'string'; any other mix throws.
 */
function upcastType(typeA, typeB) {
    if (typeA === 'string' || typeB === 'string') {
        if (typeA === 'string' && typeB === 'string') {
            return 'string';
        }
        throw new Error(`Can not upcast ${typeA} with ${typeB}`);
    }
    return upcastTypeMap[typeA][typeB];
}
/** Returns the output type after summation. */
function sumOutType(type) {
    return upcastType(type, 'int32');
}
3984
3985 /**
3986 * @license
3987 * Copyright 2018 Google LLC. All Rights Reserved.
3988 * Licensed under the Apache License, Version 2.0 (the "License");
3989 * you may not use this file except in compliance with the License.
3990 * You may obtain a copy of the License at
3991 *
3992 * http://www.apache.org/licenses/LICENSE-2.0
3993 *
3994 * Unless required by applicable law or agreed to in writing, software
3995 * distributed under the License is distributed on an "AS IS" BASIS,
3996 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3997 * See the License for the specific language governing permissions and
3998 * limitations under the License.
3999 * =============================================================================
4000 */
/**
 * Casts `a` and `b` to their common upcast dtype when they differ;
 * returns the pair unchanged when the dtypes already match.
 */
function makeTypesMatch(a, b) {
    if (a.dtype === b.dtype) {
        return [a, b];
    }
    const targetType = upcastType(a.dtype, b.dtype);
    return [a.cast(targetType), b.cast(targetType)];
}
/** Throws unless `a` and `b` share the same dtype. */
function assertTypesMatch(a, b) {
    const sameDtype = a.dtype === b.dtype;
    assert(sameDtype, () => `The dtypes of the first(${a.dtype}) and` +
        ` second(${b.dtype}) input must match`);
}
/** Returns whether `tensorList` contains a tensor with the same id. */
function isTensorInList(tensor, tensorList) {
    for (const candidate of tensorList) {
        if (candidate.id === tensor.id) {
            return true;
        }
    }
    return false;
}
/**
 * Extracts any `Tensor`s found within the provided object.
 *
 * @param container an object that may be a `Tensor` or may directly contain
 *     `Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. In general it
 *     is safe to pass any object here, except that `Promise`s are not
 *     supported.
 * @returns An array of `Tensors` found within the passed object. If the
 *     argument is simply a `Tensor', a list containing that `Tensor` is
 *     returned. If the object is not a `Tensor` or does not
 *     contain `Tensors`, an empty list is returned.
 */
function getTensorsInContainer(result) {
    const found = [];
    walkTensorContainer(result, found, new Set());
    return found;
}
/** Depth-first walk that pushes every Tensor inside `container` onto `list`. */
function walkTensorContainer(container, list, seen) {
    if (container == null) {
        return;
    }
    if (container instanceof Tensor) {
        list.push(container);
        return;
    }
    if (!isIterable(container)) {
        return;
    }
    // Iteration over keys works also for arrays.
    for (const key in container) {
        const value = container[key];
        if (seen.has(value)) {
            continue; // Guard against cycles / shared references.
        }
        seen.add(value);
        walkTensorContainer(value, list, seen);
    }
}
// tslint:disable-next-line:no-any
function isIterable(obj) {
    // Note: typeof null === 'object', but callers null-check before calling.
    return Array.isArray(obj) || typeof obj === 'object';
}
4058
// Frozen namespace object mirroring the tensor_util module exports.
var tensor_util = /*#__PURE__*/Object.freeze({
    __proto__: null,
    makeTypesMatch,
    assertTypesMatch,
    isTensorInList,
    getTensorsInContainer
});
4066
4067 /**
4068 * @license
4069 * Copyright 2018 Google LLC. All Rights Reserved.
4070 * Licensed under the Apache License, Version 2.0 (the "License");
4071 * you may not use this file except in compliance with the License.
4072 * You may obtain a copy of the License at
4073 *
4074 * http://www.apache.org/licenses/LICENSE-2.0
4075 *
4076 * Unless required by applicable law or agreed to in writing, software
4077 * distributed under the License is distributed on an "AS IS" BASIS,
4078 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
4079 * See the License for the specific language governing permissions and
4080 * limitations under the License.
4081 * =============================================================================
4082 */
/**
 * Returns true when the invocation targets a registered kernel (it carries a
 * kernelName); custom-grad invocations carry forward/backward functions
 * instead.
 */
function isRegisteredKernelInvocation(kernelInvocation) {
    const { kernelName } = kernelInvocation;
    return kernelName !== undefined && kernelName !== null;
}
/** Per-engine bookkeeping: memory counters, tape state, scopes, profiling. */
class EngineState {
    constructor() {
        // Variables registered via makeVariable(); public since optimizers use it.
        this.registeredVariables = {};
        this.nextTapeNodeId = 0;
        // Memory accounting.
        this.numBytes = 0;
        this.numTensors = 0;
        this.numStringTensors = 0;
        this.numDataBuffers = 0;
        // Number of nested tf.grad() statements when computing higher-order
        // gradients. E.g. `1` for first-order gradients and `2` for second-order
        // gradients. Used to track if the tape should be removed after a backprop.
        this.gradientDepth = 0;
        // Number of nested kernel calls. When kernel depth is greater than 1, we
        // turn off the tape.
        this.kernelDepth = 0;
        this.scopeStack = [];
        // Stack of per-kernel data-move counters (kernels can call kernels).
        this.numDataMovesStack = [];
        this.nextScopeId = 0;
        this.tensorInfo = new WeakMap();
        this.profiling = false;
        this.activeProfile = {
            newBytes: 0,
            newTensors: 0,
            peakBytes: 0,
            kernels: [],
            result: null,
            get kernelNames() {
                // Unique kernel names, in first-seen order.
                const names = new Set(this.kernels.map((k) => k.name));
                return [...names];
            }
        };
    }
    /** Disposes every registered variable. */
    dispose() {
        for (const variable of Object.values(this.registeredVariables)) {
            variable.dispose();
        }
    }
}
4128 class Engine {
constructor(ENV) {
    // Environment (feature flags such as DEBUG / IS_TEST).
    this.ENV = ENV;
    // Initialized backend instances by name.
    this.registry = {};
    // Registered backend factories (+ priority) by name.
    this.registryFactory = {};
    // Monotonic id used to invalidate stale async backend initializations.
    this.pendingBackendInitId = 0;
    this.state = new EngineState();
}
/**
 * Resolves when the engine has a usable backend: waits for an in-flight
 * initialization if any, otherwise tries each registered factory in
 * priority order and activates the first that succeeds.
 * @throws if every registered backend fails to initialize.
 */
async ready() {
    if (this.pendingBackendInit != null) {
        // An init is already in progress; resolve once it settles.
        return this.pendingBackendInit.then(() => { });
    }
    if (this.backendInstance != null) {
        return;
    }
    const sortedBackends = this.getSortedBackends();
    for (let i = 0; i < sortedBackends.length; i++) {
        const backendName = sortedBackends[i];
        // `.success` may be a boolean (sync init) or a promise (async init);
        // await handles both.
        const success = await this.initializeBackend(backendName).success;
        if (success) {
            await this.setBackend(backendName);
            return;
        }
    }
    throw new Error(`Could not initialize any backends, all backend initializations ` +
        `failed.`);
}
/**
 * Returns the active backend instance, lazily initializing the
 * highest-priority backend when none is set.
 * @throws if the backend requires async initialization (callers must await
 *     tf.ready() or tf.setBackend() first).
 */
get backend() {
    if (this.pendingBackendInit != null) {
        throw new Error(`Backend '${this.backendName}' has not yet been initialized. Make ` +
            `sure to await tf.ready() or await tf.setBackend() before calling ` +
            `other methods`);
    }
    if (this.backendInstance == null) {
        const { name, asyncInit } = this.initializeBackendsAndReturnBest();
        if (asyncInit) {
            throw new Error(`The highest priority backend '${name}' has not yet been ` +
                `initialized. Make sure to await tf.ready() or ` +
                `await tf.setBackend() before calling other methods`);
        }
        // Sync init succeeded above, so setBackend's sync prefix (no awaits on
        // this path) sets backendInstance before this getter returns.
        this.setBackend(name);
    }
    return this.backendInstance;
}
/** Returns the names of all registered backend factories. */
backendNames() {
    return Object.keys(this.registryFactory);
}
4175 findBackend(backendName) {
4176 if (!(backendName in this.registry)) {
4177 // If the backend hasn't been initialized but we have a registry entry for
4178 // it, initialize it and return it.
4179 if (backendName in this.registryFactory) {
4180 const { asyncInit } = this.initializeBackend(backendName);
4181 if (asyncInit) {
4182 // Backend is not ready yet.
4183 return null;
4184 }
4185 }
4186 else {
4187 return null;
4188 }
4189 }
4190 return this.registry[backendName];
4191 }
4192 findBackendFactory(backendName) {
4193 if (!(backendName in this.registryFactory)) {
4194 return null;
4195 }
4196 return this.registryFactory[backendName].factory;
4197 }
4198 registerBackend(backendName, factory, priority = 1) {
4199 if (backendName in this.registryFactory) {
4200 warn(`${backendName} backend was already registered. ` +
4201 `Reusing existing backend factory.`);
4202 return false;
4203 }
4204 this.registryFactory[backendName] = { factory, priority };
4205 return true;
4206 }
/**
 * Activates the named backend, initializing it on first use. Resolves to
 * whether activation succeeded. Note the already-initialized path completes
 * synchronously (no awaits), which the `backend` getter relies on.
 */
async setBackend(backendName) {
    if (this.registryFactory[backendName] == null) {
        throw new Error(`Backend name '${backendName}' not found in registry`);
    }
    this.backendName = backendName;
    if (this.registry[backendName] == null) {
        this.backendInstance = null;
        const { success, asyncInit } = this.initializeBackend(backendName);
        // `success` is a promise for async factories, a boolean otherwise.
        const result = asyncInit ? await success : success;
        if (!result) {
            return false;
        }
    }
    this.backendInstance = this.registry[backendName];
    this.setupRegisteredKernels();
    // Reset the profiler.
    this.profiler = new Profiler(this.backendInstance);
    return true;
}
4226 setupRegisteredKernels() {
4227 const kernels = getKernelsForBackend(this.backendName);
4228 kernels.forEach(kernel => {
4229 if (kernel.setupFunc != null) {
4230 kernel.setupFunc(this.backendInstance);
4231 }
4232 });
4233 }
4234 disposeRegisteredKernels(backendName) {
4235 const kernels = getKernelsForBackend(backendName);
4236 kernels.forEach(kernel => {
4237 if (kernel.disposeFunc != null) {
4238 kernel.disposeFunc(this.registry[backendName]);
4239 }
4240 });
4241 }
/**
 * Initializes a backend by looking up the backend name in the factory
 * registry and calling the factory method. Returns `{success, asyncInit}`
 * where `success` is a boolean (sync factory) or a promise of one (async
 * factory). Throws an error if there is no backend in the factory registry.
 */
initializeBackend(backendName) {
    const registryFactoryEntry = this.registryFactory[backendName];
    if (registryFactoryEntry == null) {
        throw new Error(`Cannot initialize backend ${backendName}, no registration found.`);
    }
    try {
        const backend = registryFactoryEntry.factory();
        /* Test if the factory returns a promise.
           Done in a more liberal way than previous
           'Promise.resolve(backend)===backend' as we needed to account for
           custom Promise implementations (e.g. Angular) */
        if (backend && !(backend instanceof KernelBackend) &&
            typeof backend.then === 'function') {
            // Async factory: tag this init with a monotonically increasing id so
            // a later setBackend()/removeBackend() can invalidate it.
            const promiseId = ++this.pendingBackendInitId;
            const success = backend
                .then(backendInstance => {
                // Outdated promise. Another backend was set in the meantime.
                if (promiseId < this.pendingBackendInitId) {
                    return false;
                }
                this.registry[backendName] = backendInstance;
                this.pendingBackendInit = null;
                return true;
            })
                .catch(err => {
                // Outdated promise. Another backend was set in the meantime.
                if (promiseId < this.pendingBackendInitId) {
                    return false;
                }
                this.pendingBackendInit = null;
                warn(`Initialization of backend ${backendName} failed`);
                warn(err.stack || err.message);
                return false;
            });
            this.pendingBackendInit = success;
            return { success, asyncInit: true };
        }
        else {
            // Sync factory: the instance is usable immediately.
            this.registry[backendName] = backend;
            return { success: true, asyncInit: false };
        }
    }
    catch (err) {
        // A sync factory threw: report and treat as a failed init.
        warn(`Initialization of backend ${backendName} failed`);
        warn(err.stack || err.message);
        return { success: false, asyncInit: false };
    }
}
/**
 * Unregisters a backend, disposing its instance if initialized and clearing
 * the active backend when it is the one being removed.
 */
removeBackend(backendName) {
    if (!(backendName in this.registryFactory)) {
        throw new Error(`${backendName} backend not found in registry`);
    }
    if (this.backendName === backendName && this.pendingBackendInit != null) {
        // There is a pending promise of the backend we want to remove. Make it
        // obsolete (initializeBackend's handlers compare against this id).
        this.pendingBackendInitId++;
    }
    if (backendName in this.registry) {
        this.disposeRegisteredKernels(backendName);
        this.registry[backendName].dispose();
        delete this.registry[backendName];
    }
    delete this.registryFactory[backendName];
    // Unset the backend if it is active.
    if (this.backendName === backendName) {
        this.pendingBackendInit = null;
        this.backendName = null;
        this.backendInstance = null;
    }
}
4319 getSortedBackends() {
4320 if (Object.keys(this.registryFactory).length === 0) {
4321 throw new Error('No backend found in registry.');
4322 }
4323 return Object.keys(this.registryFactory).sort((a, b) => {
4324 // Highest priority comes first.
4325 return this.registryFactory[b].priority -
4326 this.registryFactory[a].priority;
4327 });
4328 }
4329 initializeBackendsAndReturnBest() {
4330 const sortedBackends = this.getSortedBackends();
4331 for (let i = 0; i < sortedBackends.length; i++) {
4332 const backendName = sortedBackends[i];
4333 const { success, asyncInit } = this.initializeBackend(backendName);
4334 if (asyncInit || success) {
4335 return { name: backendName, asyncInit };
4336 }
4337 }
4338 throw new Error(`Could not initialize any backends, all backend initializations ` +
4339 `failed.`);
4340 }
/**
 * Moves the data behind `dataId` from its current backend to `backend`,
 * preserving the backend-level reference count.
 */
moveData(backend, dataId) {
    const info = this.state.tensorInfo.get(dataId);
    const srcBackend = info.backend;
    // Materialize the values on the source backend before deleting them there.
    const values = this.readSync(dataId);
    const refCount = srcBackend.refCount(dataId);
    // Delete the tensor from the old backend and move it to the new
    // backend.
    srcBackend.disposeData(dataId, true);
    info.backend = backend;
    backend.move(dataId, values, info.shape, info.dtype, refCount);
    if (this.shouldCheckForMemLeaks()) {
        // Track the number of moves during a kernel execution to correctly
        // detect memory leaks.
        this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
    }
}
/**
 * Runs `fn` inside a new scope; tensors created in the scope and not
 * returned (or kept) are disposed when the scope ends.
 *
 * @param nameOrFn The function to run, or a scope name when called with two
 *     arguments.
 * @param fn The function to run when a name is given.
 */
tidy(nameOrFn, fn) {
    let name = null;
    if (fn == null) {
        // Called with only 1 argument.
        if (typeof nameOrFn !== 'function') {
            throw new Error('Please provide a function to tidy()');
        }
        fn = nameOrFn;
    }
    else {
        // Called with 2 arguments.
        if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) {
            throw new Error('When calling with two arguments, the first argument ' +
                'to tidy() must be a string');
        }
        if (typeof fn !== 'function') {
            throw new Error('When calling with two arguments, the 2nd argument ' +
                'to tidy() must be a function');
        }
        name = nameOrFn;
        // TODO(nsthorat,smilkov): Do operation logging and performance
        // profiling.
    }
    let result;
    // `result` is captured by the endScope closure so the scope can keep the
    // returned tensors alive.
    return this.scopedRun(() => this.startScope(name), () => this.endScope(result), () => {
        result = fn();
        if (result instanceof Promise) {
            console.error('Cannot return a Promise inside of tidy.');
        }
        return result;
    });
}
4389 scopedRun(start, end, f) {
4390 start();
4391 try {
4392 const res = f();
4393 end();
4394 return res;
4395 }
4396 catch (ex) {
4397 end();
4398 throw ex;
4399 }
4400 }
/** Returns a unique tensor id from a counter shared across engines (static). */
nextTensorId() {
    return Engine.nextTensorId++;
}
/** Returns a unique variable id from a shared static counter. */
nextVariableId() {
    return Engine.nextVariableId++;
}
/**
 * This method is called instead of the public-facing tensor.clone() when
 * saving a tensor for backwards pass. It makes sure to add the clone
 * operation to the tape regardless of being called inside a kernel
 * execution.
 */
clone(x) {
    const y = ENGINE.runKernel(Identity, { x });
    const inputs = { x };
    // Gradient of clone: cast the incoming gradient to float32.
    const grad = (dy) => ({
        x: () => {
            const dtype = 'float32';
            const gradInputs = { x: dy };
            const attrs = { dtype };
            return ENGINE.runKernel(Cast, gradInputs,
            // tslint:disable-next-line: no-unnecessary-type-assertion
            attrs);
        }
    });
    const saved = [];
    // Record the node on the tape directly, bypassing the kernelDepth check
    // that normally disables taping inside kernel execution.
    this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {});
    return y;
}
/**
 * Execute a kernel with the given name and return the output tensor(s).
 *
 * @param kernelName The name of the kernel to execute.
 * @param inputs A map of input names to tensors.
 * @param attrs A map of attribute names to their values. An attribute is a
 *     primitive (non-tensor) input to the kernel.
 * @throws if the kernel is not registered for the active backend.
 */
runKernel(kernelName, inputs, attrs) {
    if (this.backendName == null) {
        // backend has not been initialized yet (backend initialization is lazy
        // can be deferred until an op/ kernel is run).
        // The below getter has side effects that will try to initialize the
        // backend and set properties like this.backendName
        // tslint:disable-next-line: no-unused-expression
        this.backend;
    }
    const hasKernel = getKernel(kernelName, this.backendName) != null;
    if (!hasKernel) {
        throw new Error(`Kernel '${kernelName}' not registered for backend '${this.backendName}'`);
    }
    return this.runKernelFunc({ kernelName, inputs, attrs });
}
/** Memory-leak checks run only when the IS_TEST environment flag is set. */
shouldCheckForMemLeaks() {
    return this.ENV.getBool('IS_TEST');
}
/**
 * Asserts the backend did not leak data ids while running `kernelName`.
 *
 * @param kernelName Name used in the error message.
 * @param numDataIdsBefore Backend data-id count captured before the kernel.
 * @param outInfos Tensor infos produced by the kernel.
 * @throws if the backend holds more data ids than the outputs plus any data
 *     moves can account for.
 */
checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos) {
    const numDataIdsAfter = this.backend.numDataIds();
    // Count the number of data ids associated with the result of the kernel.
    let numOutputDataIds = 0;
    outInfos.forEach(info => {
        // Complex numbers allocate 3 data ids, one for 'real', one for
        // 'imaginary', and one for the container that holds the former two.
        numOutputDataIds += (info.dtype === 'complex64' ? 3 : 1);
    });
    // Account for the number of moves during kernel execution. A "data move"
    // can happen in the middle of a kernel execution, placing a new (key,value)
    // pair in the data storage. Since data moves have net zero effect (we
    // always remove the data from the old backend), we have to cancel them out
    // when detecting memory leaks.
    const numMoves = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1];
    const dataIdsLeaked = numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves;
    if (dataIdsLeaked > 0) {
        throw new Error(`Backend '${this.backendName}' has an internal memory leak ` +
            `(${dataIdsLeaked} data ids) after running '${kernelName}'`);
    }
}
/**
 * Internal helper method to execute a kernel Func
 *
 * Use `runKernel` to execute kernels from outside of engine.
 */
runKernelFunc(kernelParams) {
    let outputs;
    let saved = [];
    const isTapeOn = this.isTapeOn();
    // Snapshot memory counters so profiling can report per-kernel deltas.
    const startingBytecount = this.state.numBytes;
    const startingNumTensors = this.state.numTensors;
    if (this.shouldCheckForMemLeaks()) {
        // One data-move counter per (possibly nested) kernel execution.
        this.state.numDataMovesStack.push(0);
    }
    let kernelFunc;
    if (this.backendName == null) {
        // backend has not been initialized yet (backend initialization is lazy
        // can be deferred until an op/ kernel is run).
        // The below getter has side effects that will try to initialize the
        // backend and set properties like this.backendName
        // tslint:disable-next-line: no-unused-expression
        this.backend;
    }
    let out;
    const kernelOrScopeName = isRegisteredKernelInvocation(kernelParams) ?
        kernelParams.kernelName :
        this.state.activeScope != null ? this.state.activeScope.name : '';
    // Create the kernelFunc from either a registered kernel OR passed in
    // forward/backward functions (used by custom grad). In this context a
    // kernelFunc wraps a kernel implementation with some bookkeeping.
    if (isRegisteredKernelInvocation(kernelParams)) {
        const { kernelName, inputs, attrs } = kernelParams;
        if (this.backendName == null) {
            // Lazy backend init again; the getter above may not have run if
            // kernelParams checks re-entered. Same side-effecting getter.
            // tslint:disable-next-line: no-unused-expression
            this.backend;
        }
        const kernel = getKernel(kernelName, this.backendName);
        assert(kernel != null, () => `Cannot find registered kernel '${kernelName}' for backend '${this.backendName}'`);
        kernelFunc = () => {
            const numDataIdsBefore = this.backend.numDataIds();
            out = kernel.kernelFunc({ inputs, attrs, backend: this.backend });
            const outInfos = Array.isArray(out) ? out : [out];
            if (this.shouldCheckForMemLeaks()) {
                this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos);
            }
            const outTensors = outInfos.map((outInfo) => {
                // todo (yassogba) remove this option (Tensor) when node backend
                // methods have been modularized and they all return tensorInfo.
                // TensorInfos do not have a rank attribute.
                if (outInfo.rank != null) {
                    return outInfo;
                }
                return this.makeTensorFromTensorInfo(outInfo);
            });
            // Save any required inputs and outputs.
            // Do not save unless we are recording to the tape. Otherwise it would
            // cause a mem leak since there would be no backprop for these tensors
            // (which would otherwise dispose them).
            if (isTapeOn) {
                const tensorsToSave = this.getTensorsForGradient(kernelName, inputs, outTensors);
                saved = this.saveTensorsForBackwardMode(tensorsToSave);
            }
            return outTensors;
        };
    }
    else {
        const { forwardFunc } = kernelParams;
        // Running a customGrad op.
        const saveFunc = (tensors) => {
            // Do not save unless we are recording to the tape. Otherwise it would
            // cause a mem leak since we would never run backprop, which disposes
            // the kept tensors.
            if (!isTapeOn) {
                return;
            }
            saved = tensors.map(tensor => this.keep(this.clone(tensor)));
        };
        kernelFunc = () => {
            const numDataIdsBefore = this.backend.numDataIds();
            // tidy() disposes intermediates created by the forward function.
            out = this.tidy(() => forwardFunc(this.backend, saveFunc));
            const outs = (Array.isArray(out) ? out : [out]);
            if (this.shouldCheckForMemLeaks()) {
                // Scope name is used to print a more helpful error message if needed.
                this.checkKernelForMemLeak(kernelOrScopeName, numDataIdsBefore, outs);
            }
            return outs;
        };
    }
    //
    // Run the kernelFunc. Optionally profiling it.
    //
    const { inputs, attrs } = kernelParams;
    const backwardsFunc = isRegisteredKernelInvocation(kernelParams) ?
        null :
        kernelParams.backwardsFunc;
    let kernelProfile;
    this.scopedRun(
    // Stop recording to a tape when running a kernel.
    () => this.state.kernelDepth++, () => this.state.kernelDepth--, () => {
        if (!this.ENV.getBool('DEBUG') && !this.state.profiling) {
            outputs = kernelFunc();
        }
        else {
            kernelProfile = this.profiler.profileKernel(kernelOrScopeName, inputs, () => kernelFunc());
            if (this.ENV.getBool('DEBUG')) {
                this.profiler.logKernelProfile(kernelProfile);
            }
            outputs = kernelProfile.outputs;
        }
    });
    if (isTapeOn) {
        this.addTapeNode(kernelOrScopeName, inputs, outputs, backwardsFunc, saved, attrs);
    }
    if (this.state.profiling) {
        this.state.activeProfile.kernels.push({
            name: kernelOrScopeName,
            bytesAdded: this.state.numBytes - startingBytecount,
            totalBytesSnapshot: this.state.numBytes,
            tensorsAdded: this.state.numTensors - startingNumTensors,
            totalTensorsSnapshot: this.state.numTensors,
            inputShapes: Object.keys(inputs).map(key => inputs[key] != null ? inputs[key].shape : null),
            outputShapes: outputs.map(item => item.shape),
            kernelTimeMs: kernelProfile.timeMs,
            extraInfo: kernelProfile.extraInfo
        });
    }
    // Mirror the array-ness of the raw kernel result in the return value.
    return (Array.isArray(out) ? outputs : outputs[0]);
}
4614 /**
4615 * Saves tensors used in forward mode for use in backward mode.
4616 *
4617 * @param tensors the list of tensors to save.
4618 */
4619 saveTensorsForBackwardMode(tensors) {
4620 const saved = tensors.map(tensor => this.keep(this.clone(tensor)));
4621 return saved;
4622 }
4623 /**
4624 * Returns a list of tensors to save for a given gradient calculation.
4625 *
4626 * @param kernelName name of kernel to look up gradient for.
4627 * @param inputs a map of input tensors.
4628 * @param outputs an array of output tensors from forward mode of kernel.
4629 */
4630 getTensorsForGradient(kernelName, inputs, outputs) {
4631 const gradConfig = getGradient(kernelName);
4632 if (gradConfig != null) {
4633 const inputsToSave = gradConfig.inputsToSave || [];
4634 const outputsToSave = gradConfig.outputsToSave || [];
4635 // If saveAllInputs is true, all inputs will be saved. Otherwise, inputs
4636 // specified in inputsToSave will be saved.
4637 let inputTensorsToSave;
4638 if (gradConfig.saveAllInputs) {
4639 assert(Array.isArray(inputs), () => 'saveAllInputs is true, expected inputs to be an array.');
4640 inputTensorsToSave = Object.keys(inputs).map((key) => inputs[key]);
4641 }
4642 else {
4643 inputTensorsToSave = inputsToSave.map((inputName) => inputs[inputName]);
4644 }
4645 const outputTensorsToSave = outputs.filter((_, i) => outputsToSave[i]);
4646 return inputTensorsToSave.concat(outputTensorsToSave);
4647 }
4648 // We return an empty list rather than throw an error because the kernel we
4649 // are looking up may not actually be relevant to backproping through the
4650 // overall function
4651 //
4652 // See 'does not error if irrelevant (pruned) ops are missing grads' test
4653 // in gradients_test.ts for an example.
4654 return [];
4655 }
/**
 * Internal method used by public APIs for tensor creation. Makes a new
 * tensor with the provided shape, dtype and values. It always
 * creates a new data id and writes the values to the underlying backend.
 */
makeTensor(values, shape, dtype, backend) {
    if (values == null) {
        throw new Error('Values passed to engine.makeTensor() are null');
    }
    dtype = dtype || 'float32';
    backend = backend || this.backend;
    let backendVals = values;
    // Encode string values (via encodeString) before handing them to the
    // backend; the first-element check assumes a homogeneous array.
    if (dtype === 'string' && isString(values[0])) {
        backendVals = values.map(d => encodeString(d));
    }
    const dataId = backend.write(backendVals, shape, dtype);
    const t = new Tensor(shape, dtype, dataId, this.nextTensorId());
    this.trackTensor(t, backend);
    // Count bytes for string tensors. trackTensor counted 0 bytes for strings,
    // so adjust to the actual encoded size here.
    if (dtype === 'string') {
        const info = this.state.tensorInfo.get(dataId);
        const newBytes = bytesFromStringArray(backendVals);
        this.state.numBytes += newBytes - info.bytes;
        info.bytes = newBytes;
    }
    return t;
}
4683 /**
4684 * Internal method used by backends. Makes a new tensor
4685 * that is a wrapper around an existing data id. It doesn't create
4686 * a new data id, only increments the ref count used in memory tracking.
4687 * @deprecated
4688 */
4689 makeTensorFromDataId(dataId, shape, dtype, backend) {
4690 dtype = dtype || 'float32';
4691 const tensorInfo = { dataId, shape, dtype };
4692 return this.makeTensorFromTensorInfo(tensorInfo, backend);
4693 }
4694 /**
4695 * Internal method used by backends. Makes a new tensor that is a wrapper
4696 * around an existing data id in TensorInfo. It doesn't create a new data id,
4697 * only increments the ref count used in memory tracking.
4698 */
4699 makeTensorFromTensorInfo(tensorInfo, backend) {
4700 const { dataId, shape, dtype } = tensorInfo;
4701 const t = new Tensor(shape, dtype, dataId, this.nextTensorId());
4702 this.trackTensor(t, backend);
4703 return t;
4704 }
/**
 * Creates a `Variable` backed by `initialValue` and registers it by name.
 *
 * @param initialValue Tensor providing shape/dtype/data for the variable.
 * @param trainable Whether optimizers should update it. Defaults to true.
 * @param name Optional unique name; autogenerated from a counter if omitted.
 * @param dtype Optional dtype; initialValue is cast when it differs.
 * @throws if a variable with the same name is already registered.
 */
makeVariable(initialValue, trainable = true, name, dtype) {
    name = name || this.nextVariableId().toString();
    if (dtype != null && dtype !== initialValue.dtype) {
        initialValue = initialValue.cast(dtype);
    }
    const v = new Variable(initialValue, trainable, name, this.nextTensorId());
    if (this.state.registeredVariables[v.name] != null) {
        throw new Error(`Variable with name ${v.name} was already registered`);
    }
    this.state.registeredVariables[v.name] = v;
    // Variables share the initial value's dataId, so bump its refCount.
    this.incRef(v, this.backend);
    return v;
}
/**
 * Registers a tensor with the engine's memory accounting and, for
 * non-variables, with the active scope's auto-dispose tracking.
 */
trackTensor(a, backend) {
    this.state.numTensors++;
    if (a.dtype === 'string') {
        this.state.numStringTensors++;
    }
    // Bytes for complex numbers are counted by their components. Bytes for
    // string tensors are counted when writing values.
    let bytes = 0;
    if (a.dtype !== 'complex64' && a.dtype !== 'string') {
        bytes = a.size * bytesPerElement(a.dtype);
    }
    this.state.numBytes += bytes;
    if (!this.state.tensorInfo.has(a.dataId)) {
        this.state.numDataBuffers++;
        this.state.tensorInfo.set(a.dataId, {
            backend: backend || this.backend,
            dtype: a.dtype,
            shape: a.shape,
            bytes
        });
    }
    // Variables live until explicitly disposed; everything else is tracked by
    // the current scope.
    if (!(a instanceof Variable)) {
        this.track(a);
    }
}
        // Track the tensor by dataId and increase the refCount for the dataId in the
        // backend.
        // TODO(pyu10055): This is currently used by makeVariable method, to increase
        // refCount on the backend for the dataId. It can potentially be replaced with
        // Identity op instead of calling backend directly.
        incRef(a, backend) {
            // Update engine-level bookkeeping, then bump the backend's ref count.
            this.trackTensor(a, backend);
            this.backend.incRef(a.dataId);
        }
4752 removeDataId(dataId, backend) {
4753 if (this.state.tensorInfo.has(dataId) &&
4754 this.state.tensorInfo.get(dataId).backend === backend) {
4755 this.state.tensorInfo.delete(dataId);
4756 this.state.numDataBuffers--;
4757 }
4758 }
        /**
         * Releases a tensor: updates counters and asks the owning backend to
         * dispose the underlying data. Safe to call more than once.
         */
        disposeTensor(a) {
            // Unknown dataId (e.g. already disposed): nothing to do.
            if (!this.state.tensorInfo.has(a.dataId)) {
                return;
            }
            const info = this.state.tensorInfo.get(a.dataId);
            this.state.numTensors--;
            if (a.dtype === 'string') {
                this.state.numStringTensors--;
                // String bytes were recorded at write time; use the stored value.
                this.state.numBytes -= info.bytes;
            }
            // Don't count bytes for complex numbers as they are counted by their
            // components.
            if (a.dtype !== 'complex64' && a.dtype !== 'string') {
                const bytes = a.size * bytesPerElement(a.dtype);
                this.state.numBytes -= bytes;
            }
            // Remove the reference to dataId if backend dispose the data successfully
            // (the backend may keep the buffer while other refs remain).
            if (info.backend.disposeData(a.dataId)) {
                this.removeDataId(a.dataId, info.backend);
            }
            // TODO(nsthorat): Construct an error and save the stack trace for
            // debugging when in debug mode. Creating a stack trace is too expensive
            // to do unconditionally.
        }
4783 disposeVariables() {
4784 for (const varName in this.state.registeredVariables) {
4785 const v = this.state.registeredVariables[varName];
4786 this.disposeVariable(v);
4787 }
4788 }
4789 disposeVariable(v) {
4790 this.disposeTensor(v);
4791 if (this.state.registeredVariables[v.name] != null) {
4792 delete this.state.registeredVariables[v.name];
4793 }
4794 }
4795 memory() {
4796 const info = this.backend.memory();
4797 info.numTensors = this.state.numTensors;
4798 info.numDataBuffers = this.state.numDataBuffers;
4799 info.numBytes = this.state.numBytes;
4800 if (this.state.numStringTensors > 0) {
4801 info.unreliable = true;
4802 if (info.reasons == null) {
4803 info.reasons = [];
4804 }
4805 info.reasons.push('Memory usage by string tensors is approximate ' +
4806 '(2 bytes per character)');
4807 }
4808 return info;
4809 }
        /**
         * Runs `query` with profiling enabled and returns the active profile
         * (kernels, peak/new bytes, new tensor count, resolved timings).
         */
        async profile(query) {
            this.state.profiling = true;
            const startBytes = this.state.numBytes;
            const startNumTensors = this.state.numTensors;
            // Kernels executed while profiling append themselves to activeProfile.
            this.state.activeProfile.kernels = [];
            this.state.activeProfile.result = await query();
            this.state.profiling = false;
            // NOTE(review): Math.max over an empty kernel list yields -Infinity;
            // presumably `query` always executes at least one kernel — confirm.
            this.state.activeProfile.peakBytes = Math.max(...this.state.activeProfile.kernels.map(d => d.totalBytesSnapshot));
            this.state.activeProfile.newBytes = this.state.numBytes - startBytes;
            this.state.activeProfile.newTensors =
                this.state.numTensors - startNumTensors;
            // Kernel timings are recorded as promises; resolve them before returning.
            for (const kernel of this.state.activeProfile.kernels) {
                kernel.kernelTimeMs = await kernel.kernelTimeMs;
                kernel.extraInfo = await kernel.extraInfo;
            }
            return this.state.activeProfile;
        }
4827 isTapeOn() {
4828 return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
4829 }
        /**
         * Pushes a node onto the active gradient tape for the given kernel
         * invocation, attaching a gradient closure when one is available.
         */
        addTapeNode(kernelName, inputs, outputs, gradientsFunc, saved, attrs) {
            const tapeNode = { id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved };
            // A gradient registered for the kernel takes precedence over the
            // caller-supplied gradient function.
            const gradConfig = getGradient(kernelName);
            if (gradConfig != null) {
                gradientsFunc = gradConfig.gradFunc;
            }
            if (gradientsFunc != null) {
                tapeNode.gradient = (dys) => {
                    // TODO(smilkov): To optimize back-prop, pass dys that are not used in
                    // the backprop graph to the user as null instead of zeros
                    dys = dys.map((dy, i) => {
                        if (dy == null) {
                            const output = outputs[i];
                            const vals = makeZerosTypedArray(output.size, output.dtype);
                            return this.makeTensor(vals, output.shape, output.dtype);
                        }
                        return dy;
                    });
                    // Grad functions of ops with single outputs expect a dy, while ops
                    // with multiple outputs expect dys (array of dy).
                    return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs);
                };
            }
            this.state.activeTape.push(tapeNode);
        }
        /**
         * Marks a tensor as kept so automatic scope cleanup will not dispose
         * it. Returns the same tensor for chaining.
         */
        keep(result) {
            result.kept = true;
            return result;
        }
        /** Opens a gradient-recording scope. */
        startTape() {
            // Only the outermost gradient scope allocates a fresh tape;
            // nested scopes share it.
            if (this.state.gradientDepth === 0) {
                this.state.activeTape = [];
            }
            this.state.gradientDepth++;
        }
        /** Closes the innermost gradient scope opened by startTape(). */
        endTape() {
            this.state.gradientDepth--;
        }
4868 /**
4869 * Start a scope. Use this with endScope() to achieve the same functionality
4870 * as scope() without the need for a function closure.
4871 */
4872 startScope(name) {
4873 const scopeInfo = {
4874 track: [],
4875 name: 'unnamed scope',
4876 id: this.state.nextScopeId++
4877 };
4878 if (name) {
4879 scopeInfo.name = name;
4880 }
4881 this.state.scopeStack.push(scopeInfo);
4882 this.state.activeScope = scopeInfo;
4883 }
4884 /**
4885 * End a scope. Use this with startScope() to achieve the same functionality
4886 * as scope() without the need for a function closure.
4887 */
4888 endScope(result) {
4889 const tensorsToTrackInParent = getTensorsInContainer(result);
4890 const tensorsToTrackInParentSet = new Set(tensorsToTrackInParent.map(t => t.id));
4891 // Dispose the arrays tracked in this scope.
4892 for (let i = 0; i < this.state.activeScope.track.length; i++) {
4893 const tensor = this.state.activeScope.track[i];
4894 if (!tensor.kept && !tensorsToTrackInParentSet.has(tensor.id)) {
4895 tensor.dispose();
4896 }
4897 }
4898 const oldScope = this.state.scopeStack.pop();
4899 this.state.activeScope = this.state.scopeStack.length === 0 ?
4900 null :
4901 this.state.scopeStack[this.state.scopeStack.length - 1];
4902 // Track the current result in the parent scope.
4903 tensorsToTrackInParent.forEach(tensor => {
4904 // Only track the tensor if was allocated in the inner scope and is not
4905 // globally kept.
4906 if (!tensor.kept && tensor.scopeId === oldScope.id) {
4907 this.track(tensor);
4908 }
4909 });
4910 }
4911 /**
4912 * Returns gradients of `f` with respect to each of the `xs`. The gradients
4913 * returned are of the same length as `xs`, but some might be null if `f`
4914 * was not a function of that `x`. It also takes optional dy to multiply the
4915 * gradient, which defaults to `1`.
4916 */
4917 gradients(f, xs, dy, allowNoGradients = false) {
4918 assert(xs.length > 0, () => 'gradients() received an empty list of xs.');
4919 if (dy != null && dy.dtype !== 'float32') {
4920 throw new Error(`dy must have 'float32' dtype, but has '${dy.dtype}'`);
4921 }
4922 const y = this.scopedRun(() => this.startTape(), () => this.endTape(), () => this.tidy('forward', f));
4923 assert(y instanceof Tensor, () => 'The result y returned by f() must be a tensor.');
4924 // Filter out the nodes that don't connect x => y.
4925 const filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y);
4926 if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) {
4927 throw new Error('Cannot compute gradient of y=f(x) with respect to x. Make sure ' +
4928 'that the f you passed encloses all operations that lead from x ' +
4929 'to y.');
4930 }
4931 return this.tidy('backward', () => {
4932 const accumulatedGradientMap = {};
4933 accumulatedGradientMap[y.id] = (dy == null) ? ones(y.shape) : dy;
4934 // Backprop gradients through the filtered nodes.
4935 backpropagateGradients(accumulatedGradientMap, filteredTape,
4936 // Pass the tidy function to avoid circular dep with `tape.ts`.
4937 f => this.tidy(f),
4938 // Pass an add function to avoide a circular dep with `tape.ts`.
4939 add);
4940 const grads = xs.map(x => accumulatedGradientMap[x.id]);
4941 if (this.state.gradientDepth === 0) {
4942 // This means that we are not computing higher-order gradients
4943 // and can clean up the tape.
4944 this.state.activeTape.forEach(node => {
4945 for (const tensor of node.saved) {
4946 tensor.dispose();
4947 }
4948 });
4949 this.state.activeTape = null;
4950 }
4951 return { value: y, grads };
4952 });
4953 }
        /**
         * Wraps `f` so its result uses the caller-provided gradient function
         * (`res.gradFunc`) instead of the built-in op gradients.
         */
        customGrad(f) {
            assert(isFunction(f), () => 'The f passed in customGrad(f) must be a function.');
            return (...inputs) => {
                assert(inputs.every(t => t instanceof Tensor), () => 'The args passed in customGrad(f)(x1, x2,...) must all be ' +
                    'tensors');
                // `res` is shared between the forward and backward closures.
                let res;
                const inputMap = {};
                inputs.forEach((input, i) => {
                    inputMap[i] = input;
                });
                const forwardFunc = (_, save) => {
                    res = f(...[...inputs, save]);
                    assert(res.value instanceof Tensor, () => 'The function f passed in customGrad(f) must return an ' +
                        'object where `obj.value` is a tensor');
                    assert(isFunction(res.gradFunc), () => 'The function f passed in customGrad(f) must return an ' +
                        'object where `obj.gradFunc` is a function.');
                    return res.value;
                };
                const backwardsFunc = (dy, saved) => {
                    const gradRes = res.gradFunc(dy, saved);
                    // gradFunc may return one tensor (single input) or a list.
                    const grads = Array.isArray(gradRes) ? gradRes : [gradRes];
                    assert(grads.length === inputs.length, () => 'The function f passed in customGrad(f) must return an ' +
                        'object where `obj.gradFunc` is a function that returns ' +
                        'the same number of tensors as inputs passed to f(...).');
                    assert(grads.every(t => t instanceof Tensor), () => 'The function f passed in customGrad(f) must return an ' +
                        'object where `obj.gradFunc` is a function that returns ' +
                        'a list of only tensors.');
                    // Map positional input index -> gradient thunk.
                    const gradMap = {};
                    grads.forEach((grad, i) => {
                        gradMap[i] = () => grad;
                    });
                    return gradMap;
                };
                return this.runKernelFunc({
                    forwardFunc,
                    backwardsFunc,
                    inputs: inputMap,
                });
            };
        }
4994 readSync(dataId) {
4995 // Route the read to the correct backend.
4996 const info = this.state.tensorInfo.get(dataId);
4997 return info.backend.readSync(dataId);
4998 }
4999 read(dataId) {
5000 // Route the read to the correct backend.
5001 const info = this.state.tensorInfo.get(dataId);
5002 return info.backend.read(dataId);
5003 }
5004 readToGPU(dataId, options) {
5005 // Route the read to the correct backend.
5006 const info = this.state.tensorInfo.get(dataId);
5007 return info.backend.readToGPU(dataId, options);
5008 }
        /**
         * Times `query` via the backend and adds wall-clock duration
         * (`wallMs`) to the backend's timing info.
         */
        async time(query) {
            const start = now();
            const timingInfo = await this.backend.time(query);
            timingInfo.wallMs = now() - start;
            return timingInfo;
        }
5015 /**
5016 * Tracks a Tensor in the current scope to be automatically cleaned up
5017 * when the current scope ends, and returns the value.
5018 *
5019 * @param result The Tensor to track in the current scope.
5020 */
5021 track(result) {
5022 if (this.state.activeScope != null) {
5023 result.scopeId = this.state.activeScope.id;
5024 this.state.activeScope.track.push(result);
5025 }
5026 return result;
5027 }
        /** Map of variable name -> Variable for all registered variables. */
        get registeredVariables() {
            return this.state.registeredVariables;
        }
5031 /**
5032 * Resets the engine state. Removes all backends but does not remove
5033 * registered backend factories.
5034 */
5035 reset() {
5036 // Make any pending promise obsolete.
5037 this.pendingBackendInitId++;
5038 this.state.dispose();
5039 this.ENV.reset();
5040 this.state = new EngineState();
5041 for (const backendName in this.registry) {
5042 this.disposeRegisteredKernels(backendName);
5043 this.registry[backendName].dispose();
5044 delete this.registry[backendName];
5045 }
5046 this.backendName = null;
5047 this.backendInstance = null;
5048 this.pendingBackendInit = null;
5049 }
5050 }
    // Module-level monotonic id counters shared by all Engine instances.
    Engine.nextTensorId = 0;
    Engine.nextVariableId = 0;
5053 function ones(shape) {
5054 const values = makeOnesTypedArray(sizeFromShape(shape), 'float32');
5055 return ENGINE.makeTensor(values, shape, 'float32');
5056 }
    /**
     * Returns the page-wide Engine singleton, creating it on first call.
     * Stored on a global namespace so multiple bundled copies of the
     * library share one engine.
     */
    function getOrMakeEngine() {
        const ns = getGlobalNamespace();
        if (ns._tfengine == null) {
            const environment = new Environment(ns);
            ns._tfengine = new Engine(environment);
        }
        setEnvironmentGlobal(ns._tfengine.ENV);
        // Tell the current tensor interface that the global engine is responsible
        // for tracking.
        setTensorTracker(() => ns._tfengine);
        return ns._tfengine;
    }
    const ENGINE = getOrMakeEngine();
5070 /**
5071 * A implementation of the add op for use within engine and tape.
5072 *
5073 * This allows us to avoid a circular dependency between add.ts and engine.
5074 * It is exported to be available in tape tests.
5075 */
5076 function add(a, b) {
5077 // We duplicate Add here to avoid a circular dependency with add.ts.
5078 const inputs = { a, b };
5079 return ENGINE.runKernel(Add, inputs);
5080 }
5081
5082 /**
5083 * @license
5084 * Copyright 2017 Google LLC. All Rights Reserved.
5085 * Licensed under the Apache License, Version 2.0 (the "License");
5086 * you may not use this file except in compliance with the License.
5087 * You may obtain a copy of the License at
5088 *
5089 * http://www.apache.org/licenses/LICENSE-2.0
5090 *
5091 * Unless required by applicable law or agreed to in writing, software
5092 * distributed under the License is distributed on an "AS IS" BASIS,
5093 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5094 * See the License for the specific language governing permissions and
5095 * limitations under the License.
5096 * =============================================================================
5097 */
5098 // tslint:disable-next-line:no-any
5099 function _isNavigatorDefined() {
5100 return typeof navigator !== 'undefined' && navigator != null;
5101 }
    // Test-only override for isMobile(); `undefined` means "no mock active".
    let isMobileMockValue;
    /** Forces isMobile() to return `value` (used in tests). */
    function mockIsMobile(value) {
        isMobileMockValue = value;
    }
5106 function isMobile(nav) {
5107 if (isMobileMockValue !== undefined) {
5108 return isMobileMockValue;
5109 }
5110 if (nav || _isNavigatorDefined()) {
5111 if (!nav) {
5112 nav = navigator;
5113 }
5114 if (nav.product === 'ReactNative') {
5115 return true;
5116 }
5117 const a = nav.userAgent || nav.vendor ||
5118 // tslint:disable-next-line:no-any
5119 (typeof window !== 'undefined' ? window.opera : '');
5120 // Use `navigator.userAgentData.mobile` as fallback.
5121 if (!a) {
5122 // tslint:disable-next-line:no-any
5123 const navAny = nav;
5124 return navAny.userAgentData && navAny.userAgentData.mobile;
5125 }
5126 // tslint:disable-next-line:max-line-length
5127 return /(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i
5128 .test(a) ||
5129 // tslint:disable-next-line:max-line-length
5130 /1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v )|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i
5131 .test(a.substr(0, 4));
5132 }
5133 return false;
5134 }
5135 function isBrowser() {
5136 return (typeof window !== 'undefined' && window.document != null) ||
5137 //@ts-ignore
5138 (typeof WorkerGlobalScope !== 'undefined');
5139 }
5140
    // Frozen namespace object mirroring the device_util module's exports.
    var device_util = /*#__PURE__*/Object.freeze({
        __proto__: null,
        mockIsMobile: mockIsMobile,
        isMobile: isMobile,
        isBrowser: isBrowser
    });
5147
5148 /**
5149 * @license
5150 * Copyright 2019 Google LLC. All Rights Reserved.
5151 * Licensed under the Apache License, Version 2.0 (the "License");
5152 * you may not use this file except in compliance with the License.
5153 * You may obtain a copy of the License at
5154 *
5155 * http://www.apache.org/licenses/LICENSE-2.0
5156 *
5157 * Unless required by applicable law or agreed to in writing, software
5158 * distributed under the License is distributed on an "AS IS" BASIS,
5159 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5160 * See the License for the specific language governing permissions and
5161 * limitations under the License.
5162 * =============================================================================
5163 */
    // Environment flag registrations; each flag has a default-value thunk
    // and (optionally) a hook run when the flag is set.
    const ENV = env();
    /**
     * This file contains environment-related flag registrations.
     */
    /** Whether to enable debug mode. */
    ENV.registerFlag('DEBUG', () => false, debugValue => {
        if (debugValue) {
            console.warn('Debugging mode is ON. The output of every math call will ' +
                'be downloaded to CPU and checked for NaNs. ' +
                'This significantly impacts performance.');
        }
    });
    /** Whether we are in a browser (as versus, say, node.js) environment. */
    ENV.registerFlag('IS_BROWSER', () => isBrowser());
    /** Whether we are in a node.js (as versus, say, browser) environment. */
    ENV.registerFlag('IS_NODE', () => (typeof process !== 'undefined') &&
        (typeof process.versions !== 'undefined') &&
        (typeof process.versions.node !== 'undefined'));
    /** Whether this browser is Chrome. */
    ENV.registerFlag('IS_CHROME', () => typeof navigator !== 'undefined' && navigator != null &&
        navigator.userAgent != null && /Chrome/.test(navigator.userAgent) &&
        /Google Inc/.test(navigator.vendor));
    /**
     * True when the environment is "production" where we disable safety checks
     * to gain performance.
     */
    ENV.registerFlag('PROD', () => false);
    /**
     * Whether to do sanity checks when inferring a shape from user-provided
     * values, used when creating a new tensor.
     */
    ENV.registerFlag('TENSORLIKE_CHECK_SHAPE_CONSISTENCY', () => ENV.getBool('DEBUG'));
    /** Whether deprecation warnings are enabled. */
    ENV.registerFlag('DEPRECATION_WARNINGS_ENABLED', () => true);
    /** True if running unit tests. */
    ENV.registerFlag('IS_TEST', () => false);
    /** Whether to check computation result for errors. */
    ENV.registerFlag('CHECK_COMPUTATION_FOR_ERRORS', () => true);
    /** Whether the backend needs to wrap input to imageBitmap. */
    ENV.registerFlag('WRAP_TO_IMAGEBITMAP', () => false);
    /** Experimental flag, whether enter compile only phase. */
    ENV.registerFlag('ENGINE_COMPILE_ONLY', () => false);
5206
5207 /**
5208 * @license
5209 * Copyright 2018 Google LLC. All Rights Reserved.
5210 * Licensed under the Apache License, Version 2.0 (the "License");
5211 * you may not use this file except in compliance with the License.
5212 * You may obtain a copy of the License at
5213 *
5214 * http://www.apache.org/licenses/LICENSE-2.0
5215 *
5216 * Unless required by applicable law or agreed to in writing, software
5217 * distributed under the License is distributed on an "AS IS" BASIS,
5218 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5219 * See the License for the specific language governing permissions and
5220 * limitations under the License.
5221 * =============================================================================
5222 */
5223 function inferShape(val, dtype) {
5224 let firstElem = val;
5225 if (isTypedArray(val)) {
5226 return dtype === 'string' ? [] : [val.length];
5227 }
5228 if (!Array.isArray(val)) {
5229 return []; // Scalar.
5230 }
5231 const shape = [];
5232 while (Array.isArray(firstElem) ||
5233 isTypedArray(firstElem) && dtype !== 'string') {
5234 shape.push(firstElem.length);
5235 firstElem = firstElem[0];
5236 }
5237 if (Array.isArray(val) &&
5238 env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) {
5239 deepAssertShapeConsistency(val, shape, []);
5240 }
5241 return shape;
5242 }
5243 function deepAssertShapeConsistency(val, shape, indices) {
5244 indices = indices || [];
5245 if (!(Array.isArray(val)) && !isTypedArray(val)) {
5246 assert(shape.length === 0, () => `Element arr[${indices.join('][')}] is a primitive, ` +
5247 `but should be an array/TypedArray of ${shape[0]} elements`);
5248 return;
5249 }
5250 assert(shape.length > 0, () => `Element arr[${indices.join('][')}] should be a primitive, ` +
5251 `but is an array of ${val.length} elements`);
5252 assert(val.length === shape[0], () => `Element arr[${indices.join('][')}] should have ${shape[0]} ` +
5253 `elements, but has ${val.length} elements`);
5254 const subShape = shape.slice(1);
5255 for (let i = 0; i < val.length; ++i) {
5256 deepAssertShapeConsistency(val[i], subShape, indices.concat(i));
5257 }
5258 }
5259 function assertDtype(expectedDtype, actualDType, argName, functionName) {
5260 if (expectedDtype === 'string_or_numeric') {
5261 return;
5262 }
5263 if (expectedDtype == null) {
5264 throw new Error(`Expected dtype cannot be null.`);
5265 }
5266 if (expectedDtype !== 'numeric' && expectedDtype !== actualDType ||
5267 expectedDtype === 'numeric' && actualDType === 'string') {
5268 throw new Error(`Argument '${argName}' passed to '${functionName}' must ` +
5269 `be ${expectedDtype} tensor, but got ${actualDType} tensor`);
5270 }
5271 }
    /**
     * Converts a TensorLike (nested array, typed array, or primitive) to a
     * Tensor, validating dtype against `parseAsDtype`. Tensors pass through
     * after a dtype check.
     */
    function convertToTensor(x, argName, functionName, parseAsDtype = 'numeric') {
        if (x instanceof Tensor) {
            assertDtype(parseAsDtype, x.dtype, argName, functionName);
            return x;
        }
        let inferredDtype = inferDtype(x);
        // If the user expects a bool/int/float, use that info to update the
        // inferredDtype when it is not a string.
        if (inferredDtype !== 'string' &&
            ['bool', 'int32', 'float32'].indexOf(parseAsDtype) >= 0) {
            inferredDtype = parseAsDtype;
        }
        assertDtype(parseAsDtype, inferredDtype, argName, functionName);
        // Reject values that are not TensorLike (null, functions, plain
        // objects, ...).
        if ((x == null) ||
            (!isTypedArray(x) && !Array.isArray(x) && typeof x !== 'number' &&
                typeof x !== 'boolean' && typeof x !== 'string')) {
            const type = x == null ? 'null' : x.constructor.name;
            throw new Error(`Argument '${argName}' passed to '${functionName}' must be a ` +
                `Tensor or TensorLike, but got '${type}'`);
        }
        const inferredShape = inferShape(x, inferredDtype);
        // Wrap primitives so downstream code can treat the value as an array.
        if (!isTypedArray(x) && !Array.isArray(x)) {
            x = [x];
        }
        const skipTypedArray = true;
        // Numeric values become TypedArrays; strings stay as a flat copy.
        const values = inferredDtype !== 'string' ?
            toTypedArray(x, inferredDtype) :
            flatten(x, [], skipTypedArray);
        return ENGINE.makeTensor(values, inferredShape, inferredDtype);
    }
5302 function convertToTensorArray(arg, argName, functionName, parseAsDtype = 'numeric') {
5303 if (!Array.isArray(arg)) {
5304 throw new Error(`Argument ${argName} passed to ${functionName} must be a ` +
5305 '`Tensor[]` or `TensorLike[]`');
5306 }
5307 const tensors = arg;
5308 return tensors.map((t, i) => convertToTensor(t, `${argName}[${i}]`, functionName, parseAsDtype));
5309 }
5310
5311 /**
5312 * @license
5313 * Copyright 2018 Google LLC. All Rights Reserved.
5314 * Licensed under the Apache License, Version 2.0 (the "License");
5315 * you may not use this file except in compliance with the License.
5316 * You may obtain a copy of the License at
5317 *
5318 * http://www.apache.org/licenses/LICENSE-2.0
5319 *
5320 * Unless required by applicable law or agreed to in writing, software
5321 * distributed under the License is distributed on an "AS IS" BASIS,
5322 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5323 * See the License for the specific language governing permissions and
5324 * limitations under the License.
5325 * =============================================================================
5326 */
5327 const OP_SCOPE_SUFFIX = '__op';
5328 /**
5329 * Used for wrapping functions that perform math operations on
5330 * Tensors. The function will be wrapped in a named scope that cleans all
5331 * memory usage after the function is done.
5332 */
5333 function op(f) {
5334 const keys = Object.keys(f);
5335 if (keys.length !== 1) {
5336 throw new Error(`Please provide an object with a single key ` +
5337 `(operation name) mapping to a function. Got an object with ` +
5338 `${keys.length} keys.`);
5339 }
5340 let opName = keys[0];
5341 const fn = f[opName];
5342 // Strip the underscore from the end of the function name.
5343 if (opName.endsWith('_')) {
5344 opName = opName.substring(0, opName.length - 1);
5345 }
5346 // add an __op suffix to distinguish ops from kernels in tf.profile
5347 opName = opName + OP_SCOPE_SUFFIX;
5348 // tslint:disable-next-line:no-any
5349 const f2 = (...args) => {
5350 ENGINE.startScope(opName);
5351 try {
5352 const result = fn(...args);
5353 if (isPromise(result)) {
5354 console.error('Cannot return a Promise inside of tidy.');
5355 }
5356 ENGINE.endScope(result);
5357 return result;
5358 }
5359 catch (ex) {
5360 ENGINE.endScope(null);
5361 throw ex;
5362 }
5363 };
5364 Object.defineProperty(f2, 'name', { value: opName, configurable: true });
5365 // tslint:disable-next-line:no-any
5366 return f2;
5367 }
5368
5369 /**
5370 * @license
5371 * Copyright 2020 Google LLC. All Rights Reserved.
5372 * Licensed under the Apache License, Version 2.0 (the "License");
5373 * you may not use this file except in compliance with the License.
5374 * You may obtain a copy of the License at
5375 *
5376 * http://www.apache.org/licenses/LICENSE-2.0
5377 *
5378 * Unless required by applicable law or agreed to in writing, software
5379 * distributed under the License is distributed on an "AS IS" BASIS,
5380 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5381 * See the License for the specific language governing permissions and
5382 * limitations under the License.
5383 * =============================================================================
5384 */
5385 /**
5386 * Converts two real numbers to a complex number.
5387 *
5388 * Given a tensor `real` representing the real part of a complex number, and a
5389 * tensor `imag` representing the imaginary part of a complex number, this
5390 * operation returns complex numbers elementwise of the form [r0, i0, r1, i1],
5391 * where r represents the real part and i represents the imag part.
5392 *
5393 * The input tensors real and imag must have the same shape.
5394 *
5395 * ```js
5396 * const real = tf.tensor1d([2.25, 3.25]);
5397 * const imag = tf.tensor1d([4.75, 5.75]);
5398 * const complex = tf.complex(real, imag);
5399 *
5400 * complex.print();
5401 * ```
5402 *
5403 * @doc {heading: 'Tensors', subheading: 'Creation'}
5404 */
    function complex_(real, imag) {
        const $real = convertToTensor(real, 'real', 'complex');
        const $imag = convertToTensor(imag, 'imag', 'complex');
        // Real and imaginary parts pair elementwise, so shapes must agree.
        assertShapesMatch($real.shape, $imag.shape, `real and imag shapes, ${$real.shape} and ${$imag.shape}, ` +
            `must match in call to tf.complex().`);
        const inputs = { real: $real, imag: $imag };
        return ENGINE.runKernel(Complex, inputs);
    }
    const complex = op({ complex_ });
5414
5415 /**
5416 * @license
5417 * Copyright 2018 Google LLC. All Rights Reserved.
5418 * Licensed under the Apache License, Version 2.0 (the "License");
5419 * you may not use this file except in compliance with the License.
5420 * You may obtain a copy of the License at
5421 *
5422 * http://www.apache.org/licenses/LICENSE-2.0
5423 *
5424 * Unless required by applicable law or agreed to in writing, software
5425 * distributed under the License is distributed on an "AS IS" BASIS,
5426 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5427 * See the License for the specific language governing permissions and
5428 * limitations under the License.
5429 * =============================================================================
5430 */
5431 /** This is shared code across all tensor creation methods. */
5432 function makeTensor(values, shape, inferredShape, dtype) {
5433 if (dtype == null) {
5434 dtype = inferDtype(values);
5435 }
5436 if (dtype === 'complex64') {
5437 throw new Error(`Cannot construct a complex64 tensor directly. ` +
5438 `Please use tf.complex(real, imag).`);
5439 }
5440 if (!isTypedArray(values) && !Array.isArray(values) &&
5441 typeof values !== 'number' && typeof values !== 'boolean' &&
5442 typeof values !== 'string') {
5443 throw new Error('values passed to tensor(values) must be a number/boolean/string or ' +
5444 'an array of numbers/booleans/strings, or a TypedArray');
5445 }
5446 if (shape != null) {
5447 assertNonNegativeIntegerDimensions(shape);
5448 const providedSize = sizeFromShape(shape);
5449 const inferredSize = sizeFromShape(inferredShape);
5450 assert(providedSize === inferredSize, () => `Based on the provided shape, [${shape}], the tensor should have ` +
5451 `${providedSize} values but has ${inferredSize}`);
5452 for (let i = 0; i < inferredShape.length; ++i) {
5453 const inferred = inferredShape[i];
5454 const flatDimsDontMatch = i === inferredShape.length - 1 ?
5455 inferred !== sizeFromShape(shape.slice(i)) :
5456 true;
5457 assert(inferredShape[i] === shape[i] || !flatDimsDontMatch, () => `Error creating a new Tensor. Inferred shape ` +
5458 `(${inferredShape}) does not match the provided ` +
5459 `shape (${shape}). `);
5460 }
5461 }
5462 if (!isTypedArray(values) && !Array.isArray(values)) {
5463 values = [values];
5464 }
5465 shape = shape || inferredShape;
5466 values = dtype !== 'string' ?
5467 toTypedArray(values, dtype) :
5468 flatten(values, [], true);
5469 return ENGINE.makeTensor(values, shape, dtype);
5470 }
5471
5472 /**
5473 * @license
5474 * Copyright 2018 Google LLC. All Rights Reserved.
5475 * Licensed under the Apache License, Version 2.0 (the "License");
5476 * you may not use this file except in compliance with the License.
5477 * You may obtain a copy of the License at
5478 *
5479 * http://www.apache.org/licenses/LICENSE-2.0
5480 *
5481 * Unless required by applicable law or agreed to in writing, software
5482 * distributed under the License is distributed on an "AS IS" BASIS,
5483 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5484 * See the License for the specific language governing permissions and
5485 * limitations under the License.
5486 * =============================================================================
5487 */
5488 /**
5489 * Creates a `tf.Tensor` with the provided values, shape and dtype.
5490 *
5491 * ```js
5492 * // Pass an array of values to create a vector.
5493 * tf.tensor([1, 2, 3, 4]).print();
5494 * ```
5495 *
5496 * ```js
5497 * // Pass a nested array of values to make a matrix or a higher
5498 * // dimensional tensor.
5499 * tf.tensor([[1, 2], [3, 4]]).print();
5500 * ```
5501 *
5502 * ```js
5503 * // Pass a flat array and specify a shape yourself.
5504 * tf.tensor([1, 2, 3, 4], [2, 2]).print();
5505 * ```
5506 *
5507 * @param values The values of the tensor. Can be nested array of numbers,
5508 * or a flat array, or a `TypedArray`. If the values are strings,
5509 * they will be encoded as utf-8 and kept as `Uint8Array[]`.
5510 * @param shape The shape of the tensor. Optional. If not provided,
5511 * it is inferred from `values`.
5512 * @param dtype The data type.
5513 *
5514 * @doc {heading: 'Tensors', subheading: 'Creation'}
5515 */
5516 function tensor(values, shape, dtype) {
5517 const inferredShape = inferShape(values, dtype);
5518 return makeTensor(values, shape, inferredShape, dtype);
5519 }
5520
5521 /**
5522 * @license
5523 * Copyright 2018 Google LLC. All Rights Reserved.
5524 * Licensed under the Apache License, Version 2.0 (the "License");
5525 * you may not use this file except in compliance with the License.
5526 * You may obtain a copy of the License at
5527 *
5528 * http://www.apache.org/licenses/LICENSE-2.0
5529 *
5530 * Unless required by applicable law or agreed to in writing, software
5531 * distributed under the License is distributed on an "AS IS" BASIS,
5532 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5533 * See the License for the specific language governing permissions and
5534 * limitations under the License.
5535 * =============================================================================
5536 */
5537 /* Type definitions for exporting and importing of models. */
5538 /**
5539 * A map from Tensor dtype to number of bytes per element of the Tensor.
5540 */
    const DTYPE_VALUE_SIZE_MAP = {
        'float32': 4,
        'float16': 2,
        'int32': 4,
        'uint16': 2,
        'uint8': 1,
        'bool': 1,
        'complex64': 8 // two float32 values (real + imaginary) per element
    };
5550
5551 /**
5552 * @license
5553 * Copyright 2018 Google LLC. All Rights Reserved.
5554 * Licensed under the Apache License, Version 2.0 (the "License");
5555 * you may not use this file except in compliance with the License.
5556 * You may obtain a copy of the License at
5557 *
5558 * http://www.apache.org/licenses/LICENSE-2.0
5559 *
5560 * Unless required by applicable law or agreed to in writing, software
5561 * distributed under the License is distributed on an "AS IS" BASIS,
5562 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5563 * See the License for the specific language governing permissions and
5564 * limitations under the License.
5565 * =============================================================================
5566 */
5567 /** Number of bytes reserved for the length of the string. (32bit integer). */
5568 const NUM_BYTES_STRING_LENGTH = 4;
5569 /**
5570 * Encode a map from names to weight values as an ArrayBuffer, along with an
5571 * `Array` of `WeightsManifestEntry` as specification of the encoded weights.
5572 *
5573 * This function does not perform sharding.
5574 *
5575 * This function is the reverse of `decodeWeights`.
5576 *
5577 * @param tensors A map ("dict") from names to tensors.
5578 * @param group Group to which the weights belong (optional).
5579 * @returns A `Promise` of
5580 * - A flat `ArrayBuffer` with all the binary values of the `Tensor`s
5581 * concatenated.
5582 * - An `Array` of `WeightManifestEntry`s, carrying information including
5583 * tensor names, `dtype`s and shapes.
5584 * @throws Error: on unsupported tensor `dtype`.
5585 */
5586 async function encodeWeights(tensors, group) {
5587 // TODO(adarob, cais): Support quantization.
5588 const specs = [];
5589 const dataPromises = [];
5590 const names = Array.isArray(tensors) ?
5591 tensors.map(tensor => tensor.name) :
5592 Object.keys(tensors);
5593 for (let i = 0; i < names.length; ++i) {
5594 const name = names[i];
5595 const t = Array.isArray(tensors) ? tensors[i].tensor : tensors[name];
5596 if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool' &&
5597 t.dtype !== 'string' && t.dtype !== 'complex64') {
5598 throw new Error(`Unsupported dtype in weight '${name}': ${t.dtype}`);
5599 }
5600 const spec = { name, shape: t.shape, dtype: t.dtype };
5601 if (t.dtype === 'string') {
5602 const utf8bytes = new Promise(async (resolve) => {
5603 const vals = await t.bytes();
5604 const totalNumBytes = vals.reduce((p, c) => p + c.length, 0) +
5605 NUM_BYTES_STRING_LENGTH * vals.length;
5606 const bytes = new Uint8Array(totalNumBytes);
5607 let offset = 0;
5608 for (let i = 0; i < vals.length; i++) {
5609 const val = vals[i];
5610 const bytesOfLength = new Uint8Array(new Uint32Array([val.length]).buffer);
5611 bytes.set(bytesOfLength, offset);
5612 offset += NUM_BYTES_STRING_LENGTH;
5613 bytes.set(val, offset);
5614 offset += val.length;
5615 }
5616 resolve(bytes);
5617 });
5618 dataPromises.push(utf8bytes);
5619 }
5620 else {
5621 dataPromises.push(t.data());
5622 }
5623 if (group != null) {
5624 spec.group = group;
5625 }
5626 specs.push(spec);
5627 }
5628 const tensorValues = await Promise.all(dataPromises);
5629 return { data: concatenateTypedArrays(tensorValues), specs };
5630 }
5631 /**
5632 * Decode flat ArrayBuffer as weights.
5633 *
5634 * This function does not handle sharding.
5635 *
5636 * This function is the reverse of `encodeWeights`.
5637 *
5638 * @param buffer A flat ArrayBuffer carrying the binary values of the tensors
5639 * concatenated in the order specified in `specs`.
5640 * @param specs Specifications of the names, dtypes and shapes of the tensors
5641 * whose value are encoded by `buffer`.
5642 * @return A map from tensor name to tensor value, with the names corresponding
5643 * to names in `specs`.
5644 * @throws Error, if any of the tensors has unsupported dtype.
5645 */
    function decodeWeights(buffer, specs) {
        // TODO(adarob, cais): Support quantization.
        // Maps weight name -> decoded tensor.
        const out = {};
        // Lazily-created float16 -> float32 decoder; only built if some spec
        // actually carries float16 quantization.
        let float16Decode;
        // Running byte offset into `buffer`; advanced after each weight.
        let offset = 0;
        for (const spec of specs) {
            const name = spec.name;
            const dtype = spec.dtype;
            const shape = spec.shape;
            const size = sizeFromShape(shape);
            let values;
            if ('quantization' in spec) {
                const quantization = spec.quantization;
                if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {
                    // Affine quantization requires both `min` and `scale` metadata.
                    if (!('min' in quantization && 'scale' in quantization)) {
                        throw new Error(`Weight ${spec.name} with quantization ${quantization.dtype} ` +
                            `doesn't have corresponding metadata min and scale.`);
                    }
                }
                else if (quantization.dtype === 'float16') {
                    if (dtype !== 'float32') {
                        throw new Error(`Weight ${spec.name} is quantized with ${quantization.dtype} ` +
                            `which only supports weights of type float32 not ${dtype}.`);
                    }
                }
                else {
                    throw new Error(`Weight ${spec.name} has unknown ` +
                        `quantization dtype ${quantization.dtype}. ` +
                        `Supported quantization dtypes are: ` +
                        `'uint8', 'uint16', and 'float16'.`);
                }
                const quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype];
                const byteBuffer = buffer.slice(offset, offset + size * quantizationSizeFactor);
                const quantizedArray = (quantization.dtype === 'uint8') ?
                    new Uint8Array(byteBuffer) :
                    new Uint16Array(byteBuffer);
                if (dtype === 'float32') {
                    if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {
                        // Dequantize: v * scale + min.
                        values = new Float32Array(quantizedArray.length);
                        for (let i = 0; i < quantizedArray.length; i++) {
                            const v = quantizedArray[i];
                            values[i] = v * quantization.scale + quantization.min;
                        }
                    }
                    else if (quantization.dtype === 'float16') {
                        if (float16Decode === undefined) {
                            float16Decode = getFloat16Decoder();
                        }
                        values = float16Decode(quantizedArray);
                    }
                    else {
                        throw new Error(`Unsupported quantization type ${quantization.dtype} ` +
                            `for weight type float32.`);
                    }
                }
                else if (dtype === 'int32') {
                    if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') {
                        throw new Error(`Unsupported quantization type ${quantization.dtype} ` +
                            `for weight type int32.`);
                    }
                    // Dequantize and round to the nearest integer.
                    values = new Int32Array(quantizedArray.length);
                    for (let i = 0; i < quantizedArray.length; i++) {
                        const v = quantizedArray[i];
                        values[i] = Math.round(v * quantization.scale + quantization.min);
                    }
                }
                else {
                    throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);
                }
                offset += size * quantizationSizeFactor;
            }
            else if (dtype === 'string') {
                // Strings are stored as a Uint32 byte-length prefix followed by
                // the raw bytes, repeated once per element (see encodeWeights).
                const size = sizeFromShape(spec.shape);
                values = [];
                for (let i = 0; i < size; i++) {
                    const byteLength = new Uint32Array(buffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0];
                    offset += NUM_BYTES_STRING_LENGTH;
                    const bytes = new Uint8Array(buffer.slice(offset, offset + byteLength));
                    values.push(bytes);
                    offset += byteLength;
                }
            }
            else {
                const dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype];
                const byteBuffer = buffer.slice(offset, offset + size * dtypeFactor);
                if (dtype === 'float32') {
                    values = new Float32Array(byteBuffer);
                }
                else if (dtype === 'int32') {
                    values = new Int32Array(byteBuffer);
                }
                else if (dtype === 'bool') {
                    values = new Uint8Array(byteBuffer);
                }
                else if (dtype === 'complex64') {
                    // complex64 is stored as interleaved [real, imag] float32
                    // pairs; split them and combine via `complex`. The
                    // intermediate component tensors are disposed immediately.
                    values = new Float32Array(byteBuffer);
                    const real = new Float32Array(values.length / 2);
                    const image = new Float32Array(values.length / 2);
                    for (let i = 0; i < real.length; i++) {
                        real[i] = values[i * 2];
                        image[i] = values[i * 2 + 1];
                    }
                    const realTensor = tensor(real, shape, 'float32');
                    const imageTensor = tensor(image, shape, 'float32');
                    out[name] = complex(realTensor, imageTensor);
                    realTensor.dispose();
                    imageTensor.dispose();
                }
                else {
                    throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);
                }
                offset += size * dtypeFactor;
            }
            // complex64 tensors were already placed into `out` above.
            if (dtype !== 'complex64') {
                out[name] = tensor(values, shape, dtype);
            }
        }
        return out;
    }
5765 /**
5766 * Concatenate TypedArrays into an ArrayBuffer.
5767 */
5768 function concatenateTypedArrays(xs) {
5769 // TODO(adarob, cais): Support quantization.
5770 if (xs === null) {
5771 throw new Error(`Invalid input value: ${JSON.stringify(xs)}`);
5772 }
5773 let totalByteLength = 0;
5774 // `normalizedXs` is here for this reason: a `TypedArray`'s `buffer'
5775 // can have a different byte length from that of the `TypedArray` itself,
5776 // for example, when the `TypedArray` is created from an offset in an
5777 // `ArrayBuffer`. `normliazedXs` holds `TypedArray`s whose `buffer`s match
5778 // the `TypedArray` in byte length. If an element of `xs` does not show
5779 // this property, a new `TypedArray` that satisfy this property will be
5780 // constructed and pushed into `normalizedXs`.
5781 const normalizedXs = [];
5782 xs.forEach((x) => {
5783 totalByteLength += x.byteLength;
5784 // tslint:disable:no-any
5785 normalizedXs.push(x.byteLength === x.buffer.byteLength ? x :
5786 new x.constructor(x));
5787 if (!(x instanceof Float32Array || x instanceof Int32Array ||
5788 x instanceof Uint8Array)) {
5789 throw new Error(`Unsupported TypedArray subtype: ${x.constructor.name}`);
5790 }
5791 // tslint:enable:no-any
5792 });
5793 const y = new Uint8Array(totalByteLength);
5794 let offset = 0;
5795 normalizedXs.forEach((x) => {
5796 y.set(new Uint8Array(x.buffer), offset);
5797 offset += x.byteLength;
5798 });
5799 return y.buffer;
5800 }
    // Use Buffer on Node.js instead of Blob/atob/btoa: true when `Buffer`
    // exists and at least one of the browser primitives is missing.
    const useNodeBuffer = typeof Buffer !== 'undefined' &&
        (typeof Blob === 'undefined' || typeof atob === 'undefined' ||
            typeof btoa === 'undefined');
5805 /**
5806 * Calculate the byte length of a JavaScript string.
5807 *
5808 * Note that a JavaScript string can contain wide characters, therefore the
5809 * length of the string is not necessarily equal to the byte length.
5810 *
5811 * @param str Input string.
5812 * @returns Byte length.
5813 */
5814 function stringByteLength(str) {
5815 if (useNodeBuffer) {
5816 return Buffer.byteLength(str);
5817 }
5818 return new Blob([str]).size;
5819 }
5820 /**
5821 * Encode an ArrayBuffer as a base64 encoded string.
5822 *
5823 * @param buffer `ArrayBuffer` to be converted.
5824 * @returns A string that base64-encodes `buffer`.
5825 */
5826 function arrayBufferToBase64String(buffer) {
5827 if (useNodeBuffer) {
5828 return Buffer.from(buffer).toString('base64');
5829 }
5830 const buf = new Uint8Array(buffer);
5831 let s = '';
5832 for (let i = 0, l = buf.length; i < l; i++) {
5833 s += String.fromCharCode(buf[i]);
5834 }
5835 return btoa(s);
5836 }
5837 /**
5838 * Decode a base64 string as an ArrayBuffer.
5839 *
5840 * @param str Base64 string.
5841 * @returns Decoded `ArrayBuffer`.
5842 */
5843 function base64StringToArrayBuffer(str) {
5844 if (useNodeBuffer) {
5845 const buf = Buffer.from(str, 'base64');
5846 return buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength);
5847 }
5848 const s = atob(str);
5849 const buffer = new Uint8Array(s.length);
5850 for (let i = 0; i < s.length; ++i) {
5851 buffer.set([s.charCodeAt(i)], i);
5852 }
5853 return buffer.buffer;
5854 }
5855 /**
5856 * Concatenate a number of ArrayBuffers into one.
5857 *
5858 * @param buffers A number of array buffers to concatenate.
5859 * @returns Result of concatenating `buffers` in order.
5860 */
5861 function concatenateArrayBuffers(buffers) {
5862 if (buffers.length === 1) {
5863 return buffers[0];
5864 }
5865 let totalByteLength = 0;
5866 buffers.forEach((buffer) => {
5867 totalByteLength += buffer.byteLength;
5868 });
5869 const temp = new Uint8Array(totalByteLength);
5870 let offset = 0;
5871 buffers.forEach((buffer) => {
5872 temp.set(new Uint8Array(buffer), offset);
5873 offset += buffer.byteLength;
5874 });
5875 return temp.buffer;
5876 }
5877 /**
5878 * Get the basename of a path.
5879 *
5880 * Behaves in a way analogous to Linux's basename command.
5881 *
5882 * @param path
5883 */
5884 function basename(path) {
5885 const SEPARATOR = '/';
5886 path = path.trim();
5887 while (path.endsWith(SEPARATOR)) {
5888 path = path.slice(0, path.length - 1);
5889 }
5890 const items = path.split(SEPARATOR);
5891 return items[items.length - 1];
5892 }
5893 /**
5894 * Create `ModelJSON` from `ModelArtifacts`.
5895 *
5896 * @param artifacts Model artifacts, describing the model and its weights.
5897 * @param manifest Weight manifest, describing where the weights of the
5898 * `ModelArtifacts` are stored, and some metadata about them.
5899 * @returns Object representing the `model.json` file describing the model
5900 * artifacts and weights
5901 */
5902 function getModelJSONForModelArtifacts(artifacts, manifest) {
5903 const result = {
5904 modelTopology: artifacts.modelTopology,
5905 format: artifacts.format,
5906 generatedBy: artifacts.generatedBy,
5907 convertedBy: artifacts.convertedBy,
5908 weightsManifest: manifest
5909 };
5910 if (artifacts.signature != null) {
5911 result.signature = artifacts.signature;
5912 }
5913 if (artifacts.userDefinedMetadata != null) {
5914 result.userDefinedMetadata = artifacts.userDefinedMetadata;
5915 }
5916 if (artifacts.modelInitializer != null) {
5917 result.modelInitializer = artifacts.modelInitializer;
5918 }
5919 if (artifacts.trainingConfig != null) {
5920 result.trainingConfig = artifacts.trainingConfig;
5921 }
5922 return result;
5923 }
5924 /**
5925 * Create `ModelArtifacts` from a JSON file.
5926 *
5927 * @param modelJSON Object containing the parsed JSON of `model.json`
5928 * @param loadWeights Function that takes the JSON file's weights manifest,
5929 * reads weights from the listed path(s), and returns a Promise of the
5930 * weight manifest entries along with the weights data.
5931 * @returns A Promise of the `ModelArtifacts`, as described by the JSON file.
5932 */
5933 async function getModelArtifactsForJSON(modelJSON, loadWeights) {
5934 const modelArtifacts = {
5935 modelTopology: modelJSON.modelTopology,
5936 format: modelJSON.format,
5937 generatedBy: modelJSON.generatedBy,
5938 convertedBy: modelJSON.convertedBy
5939 };
5940 if (modelJSON.trainingConfig != null) {
5941 modelArtifacts.trainingConfig = modelJSON.trainingConfig;
5942 }
5943 if (modelJSON.weightsManifest != null) {
5944 const [weightSpecs, weightData] = await loadWeights(modelJSON.weightsManifest);
5945 modelArtifacts.weightSpecs = weightSpecs;
5946 modelArtifacts.weightData = weightData;
5947 }
5948 if (modelJSON.signature != null) {
5949 modelArtifacts.signature = modelJSON.signature;
5950 }
5951 if (modelJSON.userDefinedMetadata != null) {
5952 modelArtifacts.userDefinedMetadata = modelJSON.userDefinedMetadata;
5953 }
5954 if (modelJSON.modelInitializer != null) {
5955 modelArtifacts.modelInitializer = modelJSON.modelInitializer;
5956 }
5957 return modelArtifacts;
5958 }
5959 /**
5960 * Populate ModelArtifactsInfo fields for a model with JSON topology.
5961 * @param modelArtifacts
5962 * @returns A ModelArtifactsInfo object.
5963 */
5964 function getModelArtifactsInfoForJSON(modelArtifacts) {
5965 if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
5966 throw new Error('Expected JSON model topology, received ArrayBuffer.');
5967 }
5968 return {
5969 dateSaved: new Date(),
5970 modelTopologyType: 'JSON',
5971 modelTopologyBytes: modelArtifacts.modelTopology == null ?
5972 0 :
5973 stringByteLength(JSON.stringify(modelArtifacts.modelTopology)),
5974 weightSpecsBytes: modelArtifacts.weightSpecs == null ?
5975 0 :
5976 stringByteLength(JSON.stringify(modelArtifacts.weightSpecs)),
5977 weightDataBytes: modelArtifacts.weightData == null ?
5978 0 :
5979 modelArtifacts.weightData.byteLength,
5980 };
5981 }
5982 /**
 * Computes mantissa table for casting Float16 to Float32
5984 * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
5985 *
5986 * @returns Uint32Array, 2048 mantissa lookup values.
5987 */
5988 function computeFloat16MantisaTable() {
5989 const convertMantissa = (i) => {
5990 let m = i << 13;
5991 let e = 0;
5992 while ((m & 0x00800000) === 0) {
5993 e -= 0x00800000;
5994 m <<= 1;
5995 }
5996 m &= ~0x00800000;
5997 e += 0x38800000;
5998 return m | e;
5999 };
6000 const mantisaTable = new Uint32Array(2048);
6001 mantisaTable[0] = 0;
6002 for (let i = 1; i < 1024; i++) {
6003 mantisaTable[i] = convertMantissa(i);
6004 }
6005 for (let i = 1024; i < 2048; i++) {
6006 mantisaTable[i] = 0x38000000 + ((i - 1024) << 13);
6007 }
6008 return mantisaTable;
6009 }
6010 /**
6011 * Computes exponent table for casting Float16 to Float32
6012 * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
6013 *
6014 * @returns Uint32Array, 64 exponent lookup values.
6015 */
6016 function computeFloat16ExponentTable() {
6017 const exponentTable = new Uint32Array(64);
6018 exponentTable[0] = 0;
6019 exponentTable[31] = 0x47800000;
6020 exponentTable[32] = 0x80000000;
6021 exponentTable[63] = 0xc7800000;
6022 for (let i = 1; i < 31; i++) {
6023 exponentTable[i] = i << 23;
6024 }
6025 for (let i = 33; i < 63; i++) {
6026 exponentTable[i] = 0x80000000 + ((i - 32) << 23);
6027 }
6028 return exponentTable;
6029 }
6030 /**
6031 * Computes offset table for casting Float16 to Float32
6032 * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
6033 *
 * @returns Uint32Array, 64 offset lookup values.
6035 */
6036 function computeFloat16OffsetTable() {
6037 const offsetTable = new Uint32Array(64);
6038 for (let i = 0; i < 64; i++) {
6039 offsetTable[i] = 1024;
6040 }
6041 offsetTable[0] = offsetTable[32] = 0;
6042 return offsetTable;
6043 }
6044 /**
6045 * Retrieve a Float16 decoder which will decode a ByteArray of Float16 values
6046 * to a Float32Array.
6047 *
6048 * @returns Function (buffer: Uint16Array) => Float32Array which decodes
6049 * the Uint16Array of Float16 bytes to a Float32Array.
6050 */
6051 function getFloat16Decoder() {
6052 // Algorithm is based off of
6053 // http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
6054 // Cache lookup tables
6055 const mantisaTable = computeFloat16MantisaTable();
6056 const exponentTable = computeFloat16ExponentTable();
6057 const offsetTable = computeFloat16OffsetTable();
6058 return (quantizedArray) => {
6059 const buffer = new ArrayBuffer(4 * quantizedArray.length);
6060 const bufferUint32View = new Uint32Array(buffer);
6061 for (let index = 0; index < quantizedArray.length; index++) {
6062 const float16Bits = quantizedArray[index];
6063 const float32Bits = mantisaTable[offsetTable[float16Bits >> 10] + (float16Bits & 0x3ff)] +
6064 exponentTable[float16Bits >> 10];
6065 bufferUint32View[index] = float32Bits;
6066 }
6067 return new Float32Array(buffer);
6068 };
6069 }
6070
6071 /**
6072 * @license
6073 * Copyright 2018 Google LLC. All Rights Reserved.
6074 * Licensed under the Apache License, Version 2.0 (the "License");
6075 * you may not use this file except in compliance with the License.
6076 * You may obtain a copy of the License at
6077 *
6078 * http://www.apache.org/licenses/LICENSE-2.0
6079 *
6080 * Unless required by applicable law or agreed to in writing, software
6081 * distributed under the License is distributed on an "AS IS" BASIS,
6082 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
6083 * See the License for the specific language governing permissions and
6084 * limitations under the License.
6085 * =============================================================================
6086 */
6087 class IORouterRegistry {
6088 constructor() {
6089 this.saveRouters = [];
6090 this.loadRouters = [];
6091 }
6092 static getInstance() {
6093 if (IORouterRegistry.instance == null) {
6094 IORouterRegistry.instance = new IORouterRegistry();
6095 }
6096 return IORouterRegistry.instance;
6097 }
6098 /**
6099 * Register a save-handler router.
6100 *
6101 * @param saveRouter A function that maps a URL-like string onto an instance
6102 * of `IOHandler` with the `save` method defined or `null`.
6103 */
6104 static registerSaveRouter(saveRouter) {
6105 IORouterRegistry.getInstance().saveRouters.push(saveRouter);
6106 }
6107 /**
6108 * Register a load-handler router.
6109 *
6110 * @param loadRouter A function that maps a URL-like string onto an instance
6111 * of `IOHandler` with the `load` method defined or `null`.
6112 */
6113 static registerLoadRouter(loadRouter) {
6114 IORouterRegistry.getInstance().loadRouters.push(loadRouter);
6115 }
6116 /**
6117 * Look up IOHandler for saving, given a URL-like string.
6118 *
6119 * @param url
6120 * @returns If only one match is found, an instance of IOHandler with the
6121 * `save` method defined. If no match is found, `null`.
6122 * @throws Error, if more than one match is found.
6123 */
6124 static getSaveHandlers(url) {
6125 return IORouterRegistry.getHandlers(url, 'save');
6126 }
6127 /**
6128 * Look up IOHandler for loading, given a URL-like string.
6129 *
6130 * @param url
6131 * @param loadOptions Optional, custom load options.
6132 * @returns All valid handlers for `url`, given the currently registered
6133 * handler routers.
6134 */
6135 static getLoadHandlers(url, loadOptions) {
6136 return IORouterRegistry.getHandlers(url, 'load', loadOptions);
6137 }
6138 static getHandlers(url, handlerType, loadOptions) {
6139 const validHandlers = [];
6140 const routers = handlerType === 'load' ?
6141 IORouterRegistry.getInstance().loadRouters :
6142 IORouterRegistry.getInstance().saveRouters;
6143 routers.forEach(router => {
6144 const handler = router(url, loadOptions);
6145 if (handler !== null) {
6146 validHandlers.push(handler);
6147 }
6148 });
6149 return validHandlers;
6150 }
6151 }
6152 const registerSaveRouter = (loudRouter) => IORouterRegistry.registerSaveRouter(loudRouter);
6153 const registerLoadRouter = (loudRouter) => IORouterRegistry.registerLoadRouter(loudRouter);
6154 const getSaveHandlers = (url) => IORouterRegistry.getSaveHandlers(url);
6155 const getLoadHandlers = (url, loadOptions) => IORouterRegistry.getLoadHandlers(url, loadOptions);
6156
6157 /**
6158 * @license
6159 * Copyright 2018 Google LLC. All Rights Reserved.
6160 * Licensed under the Apache License, Version 2.0 (the "License");
6161 * you may not use this file except in compliance with the License.
6162 * You may obtain a copy of the License at
6163 *
6164 * http://www.apache.org/licenses/LICENSE-2.0
6165 *
6166 * Unless required by applicable law or agreed to in writing, software
6167 * distributed under the License is distributed on an "AS IS" BASIS,
6168 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
6169 * See the License for the specific language governing permissions and
6170 * limitations under the License.
6171 * =============================================================================
6172 */
    const DATABASE_NAME = 'tensorflowjs';
    const DATABASE_VERSION = 1;
    // Model data and ModelArtifactsInfo (metadata) are stored in two separate
    // stores for efficient access of the list of stored models and their metadata.
    // 1. The object store for model data: topology, weights and weight manifests.
    const MODEL_STORE_NAME = 'models_store';
    // 2. The object store for ModelArtifactsInfo, including meta-information such
    //    as the type of topology (JSON vs binary), byte size of the topology, byte
    //    size of the weights, etc.
    const INFO_STORE_NAME = 'model_info_store';
6183 /**
6184 * Delete the entire database for tensorflow.js, including the models store.
6185 */
6186 async function deleteDatabase() {
6187 const idbFactory = getIndexedDBFactory();
6188 return new Promise((resolve, reject) => {
6189 const deleteRequest = idbFactory.deleteDatabase(DATABASE_NAME);
6190 deleteRequest.onsuccess = () => resolve();
6191 deleteRequest.onerror = error => reject(error);
6192 });
6193 }
6194 function getIndexedDBFactory() {
6195 if (!env().getBool('IS_BROWSER')) {
6196 // TODO(cais): Add more info about what IOHandler subtypes are available.
6197 // Maybe point to a doc page on the web and/or automatically determine
6198 // the available IOHandlers and print them in the error message.
6199 throw new Error('Failed to obtain IndexedDB factory because the current environment' +
6200 'is not a web browser.');
6201 }
6202 // tslint:disable-next-line:no-any
6203 const theWindow = typeof window === 'undefined' ? self : window;
6204 const factory = theWindow.indexedDB || theWindow.mozIndexedDB ||
6205 theWindow.webkitIndexedDB || theWindow.msIndexedDB ||
6206 theWindow.shimIndexedDB;
6207 if (factory == null) {
6208 throw new Error('The current browser does not appear to support IndexedDB.');
6209 }
6210 return factory;
6211 }
    // Invoked from an open request's `onupgradeneeded`: creates the two object
    // stores (model data and model info), both keyed by the model path.
    function setUpDatabase(openRequest) {
        const db = openRequest.result;
        db.createObjectStore(MODEL_STORE_NAME, { keyPath: 'modelPath' });
        db.createObjectStore(INFO_STORE_NAME, { keyPath: 'modelPath' });
    }
6217 /**
6218 * IOHandler subclass: Browser IndexedDB.
6219 *
6220 * See the doc string of `browserIndexedDB` for more details.
6221 */
6222 class BrowserIndexedDB {
6223 constructor(modelPath) {
6224 this.indexedDB = getIndexedDBFactory();
6225 if (modelPath == null || !modelPath) {
6226 throw new Error('For IndexedDB, modelPath must not be null, undefined or empty.');
6227 }
6228 this.modelPath = modelPath;
6229 }
6230 async save(modelArtifacts) {
6231 // TODO(cais): Support saving GraphDef models.
6232 if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
6233 throw new Error('BrowserLocalStorage.save() does not support saving model topology ' +
6234 'in binary formats yet.');
6235 }
6236 return this.databaseAction(this.modelPath, modelArtifacts);
6237 }
6238 async load() {
6239 return this.databaseAction(this.modelPath);
6240 }
6241 /**
6242 * Perform database action to put model artifacts into or read model artifacts
6243 * from IndexedDB object store.
6244 *
6245 * Whether the action is put or get depends on whether `modelArtifacts` is
6246 * specified. If it is specified, the action will be put; otherwise the action
6247 * will be get.
6248 *
6249 * @param modelPath A unique string path for the model.
6250 * @param modelArtifacts If specified, it will be the model artifacts to be
6251 * stored in IndexedDB.
6252 * @returns A `Promise` of `SaveResult`, if the action is put, or a `Promise`
6253 * of `ModelArtifacts`, if the action is get.
6254 */
6255 databaseAction(modelPath, modelArtifacts) {
6256 return new Promise((resolve, reject) => {
6257 const openRequest = this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
6258 openRequest.onupgradeneeded = () => setUpDatabase(openRequest);
6259 openRequest.onsuccess = () => {
6260 const db = openRequest.result;
6261 if (modelArtifacts == null) {
6262 // Read model out from object store.
6263 const modelTx = db.transaction(MODEL_STORE_NAME, 'readonly');
6264 const modelStore = modelTx.objectStore(MODEL_STORE_NAME);
6265 const getRequest = modelStore.get(this.modelPath);
6266 getRequest.onsuccess = () => {
6267 if (getRequest.result == null) {
6268 db.close();
6269 return reject(new Error(`Cannot find model with path '${this.modelPath}' ` +
6270 `in IndexedDB.`));
6271 }
6272 else {
6273 resolve(getRequest.result.modelArtifacts);
6274 }
6275 };
6276 getRequest.onerror = error => {
6277 db.close();
6278 return reject(getRequest.error);
6279 };
6280 modelTx.oncomplete = () => db.close();
6281 }
6282 else {
6283 // Put model into object store.
6284 const modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts);
6285 // First, put ModelArtifactsInfo into info store.
6286 const infoTx = db.transaction(INFO_STORE_NAME, 'readwrite');
6287 let infoStore = infoTx.objectStore(INFO_STORE_NAME);
6288 const putInfoRequest = infoStore.put({ modelPath: this.modelPath, modelArtifactsInfo });
6289 let modelTx;
6290 putInfoRequest.onsuccess = () => {
6291 // Second, put model data into model store.
6292 modelTx = db.transaction(MODEL_STORE_NAME, 'readwrite');
6293 const modelStore = modelTx.objectStore(MODEL_STORE_NAME);
6294 const putModelRequest = modelStore.put({
6295 modelPath: this.modelPath,
6296 modelArtifacts,
6297 modelArtifactsInfo
6298 });
6299 putModelRequest.onsuccess = () => resolve({ modelArtifactsInfo });
6300 putModelRequest.onerror = error => {
6301 // If the put-model request fails, roll back the info entry as
6302 // well.
6303 infoStore = infoTx.objectStore(INFO_STORE_NAME);
6304 const deleteInfoRequest = infoStore.delete(this.modelPath);
6305 deleteInfoRequest.onsuccess = () => {
6306 db.close();
6307 return reject(putModelRequest.error);
6308 };
6309 deleteInfoRequest.onerror = error => {
6310 db.close();
6311 return reject(putModelRequest.error);
6312 };
6313 };
6314 };
6315 putInfoRequest.onerror = error => {
6316 db.close();
6317 return reject(putInfoRequest.error);
6318 };
6319 infoTx.oncomplete = () => {
6320 if (modelTx == null) {
6321 db.close();
6322 }
6323 else {
6324 modelTx.oncomplete = () => db.close();
6325 }
6326 };
6327 }
6328 };
6329 openRequest.onerror = error => reject(openRequest.error);
6330 });
6331 }
6332 }
6333 BrowserIndexedDB.URL_SCHEME = 'indexeddb://';
6334 const indexedDBRouter = (url) => {
6335 if (!env().getBool('IS_BROWSER')) {
6336 return null;
6337 }
6338 else {
6339 if (!Array.isArray(url) && url.startsWith(BrowserIndexedDB.URL_SCHEME)) {
6340 return browserIndexedDB(url.slice(BrowserIndexedDB.URL_SCHEME.length));
6341 }
6342 else {
6343 return null;
6344 }
6345 }
6346 };
6347 IORouterRegistry.registerSaveRouter(indexedDBRouter);
6348 IORouterRegistry.registerLoadRouter(indexedDBRouter);
6349 /**
6350 * Creates a browser IndexedDB IOHandler for saving and loading models.
6351 *
6352 * ```js
6353 * const model = tf.sequential();
6354 * model.add(
6355 * tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'}));
6356 *
 * const saveResult = await model.save('indexeddb://MyModel');
6358 * console.log(saveResult);
6359 * ```
6360 *
6361 * @param modelPath A unique identifier for the model to be saved. Must be a
6362 * non-empty string.
 * @returns An instance of `BrowserIndexedDB` (subclass of `IOHandler`),
6364 * which can be used with, e.g., `tf.Model.save`.
6365 */
6366 function browserIndexedDB(modelPath) {
6367 return new BrowserIndexedDB(modelPath);
6368 }
6369 function maybeStripScheme(key) {
6370 return key.startsWith(BrowserIndexedDB.URL_SCHEME) ?
6371 key.slice(BrowserIndexedDB.URL_SCHEME.length) :
6372 key;
6373 }
    /**
     * Manager for listing and deleting models stored in IndexedDB.
     *
     * Uses the same database/store layout as `BrowserIndexedDB`: an info store
     * (one small record per model) and a model store (the model artifacts).
     */
    class BrowserIndexedDBManager {
        constructor() {
            this.indexedDB = getIndexedDBFactory();
        }
        /**
         * Lists all models saved in IndexedDB.
         *
         * @returns A promise of a map from model path to its artifacts info.
         */
        async listModels() {
            return new Promise((resolve, reject) => {
                const openRequest = this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
                openRequest.onupgradeneeded = () => setUpDatabase(openRequest);
                openRequest.onsuccess = () => {
                    const db = openRequest.result;
                    const tx = db.transaction(INFO_STORE_NAME, 'readonly');
                    const store = tx.objectStore(INFO_STORE_NAME);
                    // tslint:disable:max-line-length
                    // Need to cast `store` as `any` here because TypeScript's DOM
                    // library does not have the `getAll()` method even though the
                    // method is supported in the latest version of most mainstream
                    // browsers:
                    // https://developer.mozilla.org/en-US/docs/Web/API/IDBObjectStore/getAll
                    // tslint:enable:max-line-length
                    // tslint:disable-next-line:no-any
                    const getAllInfoRequest = store.getAll();
                    getAllInfoRequest.onsuccess = () => {
                        const out = {};
                        for (const item of getAllInfoRequest.result) {
                            out[item.modelPath] = item.modelArtifactsInfo;
                        }
                        resolve(out);
                    };
                    getAllInfoRequest.onerror = error => {
                        db.close();
                        return reject(getAllInfoRequest.error);
                    };
                    // Close the connection once the read-only transaction is done.
                    tx.oncomplete = () => db.close();
                };
                openRequest.onerror = error => reject(openRequest.error);
            });
        }
        /**
         * Deletes the model at `path` from both the info and model stores.
         *
         * @param path Model path; a leading 'indexeddb://' scheme is stripped.
         * @returns A promise of the deleted model's artifacts info; rejects if
         *   no model exists at `path`.
         */
        async removeModel(path) {
            path = maybeStripScheme(path);
            return new Promise((resolve, reject) => {
                const openRequest = this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
                openRequest.onupgradeneeded = () => setUpDatabase(openRequest);
                openRequest.onsuccess = () => {
                    const db = openRequest.result;
                    const infoTx = db.transaction(INFO_STORE_NAME, 'readwrite');
                    const infoStore = infoTx.objectStore(INFO_STORE_NAME);
                    const getInfoRequest = infoStore.get(path);
                    // Holds the model-store transaction, if one gets started, so
                    // `infoTx.oncomplete` below knows when it is safe to close.
                    let modelTx;
                    getInfoRequest.onsuccess = () => {
                        if (getInfoRequest.result == null) {
                            db.close();
                            return reject(new Error(`Cannot find model with path '${path}' ` +
                                `in IndexedDB.`));
                        }
                        else {
                            // First, delete the entry in the info store.
                            const deleteInfoRequest = infoStore.delete(path);
                            const deleteModelData = () => {
                                // Second, delete the entry in the model store.
                                modelTx = db.transaction(MODEL_STORE_NAME, 'readwrite');
                                const modelStore = modelTx.objectStore(MODEL_STORE_NAME);
                                const deleteModelRequest = modelStore.delete(path);
                                deleteModelRequest.onsuccess = () => resolve(getInfoRequest.result.modelArtifactsInfo);
                                // NOTE(review): rejects with `getInfoRequest.error`
                                // rather than `deleteModelRequest.error` — looks
                                // unintentional; confirm before changing.
                                deleteModelRequest.onerror = error => reject(getInfoRequest.error);
                            };
                            // Proceed with deleting model data regardless of whether deletion
                            // of info data succeeds or not.
                            deleteInfoRequest.onsuccess = deleteModelData;
                            deleteInfoRequest.onerror = error => {
                                deleteModelData();
                                db.close();
                                return reject(getInfoRequest.error);
                            };
                        }
                    };
                    getInfoRequest.onerror = error => {
                        db.close();
                        return reject(getInfoRequest.error);
                    };
                    // Delay closing the connection until the model-store
                    // transaction (if any was started) has also completed.
                    infoTx.oncomplete = () => {
                        if (modelTx == null) {
                            db.close();
                        }
                        else {
                            modelTx.oncomplete = () => db.close();
                        }
                    };
                };
                openRequest.onerror = error => reject(openRequest.error);
            });
        }
    }
6466
6467 /**
6468 * @license
6469 * Copyright 2018 Google LLC. All Rights Reserved.
6470 * Licensed under the Apache License, Version 2.0 (the "License");
6471 * you may not use this file except in compliance with the License.
6472 * You may obtain a copy of the License at
6473 *
6474 * http://www.apache.org/licenses/LICENSE-2.0
6475 *
6476 * Unless required by applicable law or agreed to in writing, software
6477 * distributed under the License is distributed on an "AS IS" BASIS,
6478 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
6479 * See the License for the specific language governing permissions and
6480 * limitations under the License.
6481 * =============================================================================
6482 */
    // Local-storage key layout: each artifact is stored under a key of the
    // form `${PATH_PREFIX}${PATH_SEPARATOR}${modelPath}${PATH_SEPARATOR}${suffix}`.
    const PATH_SEPARATOR = '/';
    const PATH_PREFIX = 'tensorflowjs_models';
    // Per-model artifact suffixes.
    const INFO_SUFFIX = 'info';
    const MODEL_TOPOLOGY_SUFFIX = 'model_topology';
    const WEIGHT_SPECS_SUFFIX = 'weight_specs';
    const WEIGHT_DATA_SUFFIX = 'weight_data';
    const MODEL_METADATA_SUFFIX = 'model_metadata';
6490 /**
6491 * Purge all tensorflow.js-saved model artifacts from local storage.
6492 *
6493 * @returns Paths of the models purged.
6494 */
6495 function purgeLocalStorageArtifacts() {
6496 if (!env().getBool('IS_BROWSER') || typeof window === 'undefined' ||
6497 typeof window.localStorage === 'undefined') {
6498 throw new Error('purgeLocalStorageModels() cannot proceed because local storage is ' +
6499 'unavailable in the current environment.');
6500 }
6501 const LS = window.localStorage;
6502 const purgedModelPaths = [];
6503 for (let i = 0; i < LS.length; ++i) {
6504 const key = LS.key(i);
6505 const prefix = PATH_PREFIX + PATH_SEPARATOR;
6506 if (key.startsWith(prefix) && key.length > prefix.length) {
6507 LS.removeItem(key);
6508 const modelName = getModelPathFromKey(key);
6509 if (purgedModelPaths.indexOf(modelName) === -1) {
6510 purgedModelPaths.push(modelName);
6511 }
6512 }
6513 }
6514 return purgedModelPaths;
6515 }
6516 function getModelKeys(path) {
6517 return {
6518 info: [PATH_PREFIX, path, INFO_SUFFIX].join(PATH_SEPARATOR),
6519 topology: [PATH_PREFIX, path, MODEL_TOPOLOGY_SUFFIX].join(PATH_SEPARATOR),
6520 weightSpecs: [PATH_PREFIX, path, WEIGHT_SPECS_SUFFIX].join(PATH_SEPARATOR),
6521 weightData: [PATH_PREFIX, path, WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR),
6522 modelMetadata: [PATH_PREFIX, path, MODEL_METADATA_SUFFIX].join(PATH_SEPARATOR)
6523 };
6524 }
6525 function removeItems(keys) {
6526 for (const key of Object.values(keys)) {
6527 window.localStorage.removeItem(key);
6528 }
6529 }
6530 /**
6531 * Get model path from a local-storage key.
6532 *
6533 * E.g., 'tensorflowjs_models/my/model/1/info' --> 'my/model/1'
6534 *
6535 * @param key
6536 */
6537 function getModelPathFromKey(key) {
6538 const items = key.split(PATH_SEPARATOR);
6539 if (items.length < 3) {
6540 throw new Error(`Invalid key format: ${key}`);
6541 }
6542 return items.slice(1, items.length - 1).join(PATH_SEPARATOR);
6543 }
6544 function maybeStripScheme$1(key) {
6545 return key.startsWith(BrowserLocalStorage.URL_SCHEME) ?
6546 key.slice(BrowserLocalStorage.URL_SCHEME.length) :
6547 key;
6548 }
    /**
     * IOHandler subclass: Browser Local Storage.
     *
     * See the doc string to `browserLocalStorage` for more details.
     */
    class BrowserLocalStorage {
        /**
         * Constructor of BrowserLocalStorage.
         *
         * @param modelPath Non-empty identifier under which the model's
         *   artifacts are keyed in `window.localStorage`.
         * @throws Error if local storage is unavailable or `modelPath` is empty.
         */
        constructor(modelPath) {
            if (!env().getBool('IS_BROWSER') || typeof window === 'undefined' ||
                typeof window.localStorage === 'undefined') {
                // TODO(cais): Add more info about what IOHandler subtypes are
                // available.
                // Maybe point to a doc page on the web and/or automatically determine
                // the available IOHandlers and print them in the error message.
                throw new Error('The current environment does not support local storage.');
            }
            this.LS = window.localStorage;
            if (modelPath == null || !modelPath) {
                throw new Error('For local storage, modelPath must not be null, undefined or empty.');
            }
            this.modelPath = modelPath;
            this.keys = getModelKeys(this.modelPath);
        }
        /**
         * Save model artifacts to browser local storage.
         *
         * See the documentation to `browserLocalStorage` for details on the saved
         * artifacts.
         *
         * @param modelArtifacts The model artifacts to be stored.
         * @returns An instance of SaveResult.
         * @throws Error if the topology is binary, or if any setItem fails
         *   (e.g. the storage quota is exceeded).
         */
        async save(modelArtifacts) {
            if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
                throw new Error('BrowserLocalStorage.save() does not support saving model topology ' +
                    'in binary formats yet.');
            }
            else {
                const topology = JSON.stringify(modelArtifacts.modelTopology);
                const weightSpecs = JSON.stringify(modelArtifacts.weightSpecs);
                const modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts);
                try {
                    // Write the info entry first, then the individual artifacts.
                    this.LS.setItem(this.keys.info, JSON.stringify(modelArtifactsInfo));
                    this.LS.setItem(this.keys.topology, topology);
                    this.LS.setItem(this.keys.weightSpecs, weightSpecs);
                    // Weight values are binary; store them base64-encoded.
                    this.LS.setItem(this.keys.weightData, arrayBufferToBase64String(modelArtifacts.weightData));
                    // Note that JSON.stringify doesn't write out keys that have undefined
                    // values, so for some keys, we set undefined instead of a null-ish
                    // value.
                    const metadata = {
                        format: modelArtifacts.format,
                        generatedBy: modelArtifacts.generatedBy,
                        convertedBy: modelArtifacts.convertedBy,
                        signature: modelArtifacts.signature != null ?
                            modelArtifacts.signature :
                            undefined,
                        userDefinedMetadata: modelArtifacts.userDefinedMetadata != null ?
                            modelArtifacts.userDefinedMetadata :
                            undefined,
                        modelInitializer: modelArtifacts.modelInitializer != null ?
                            modelArtifacts.modelInitializer :
                            undefined,
                        trainingConfig: modelArtifacts.trainingConfig != null ?
                            modelArtifacts.trainingConfig :
                            undefined
                    };
                    this.LS.setItem(this.keys.modelMetadata, JSON.stringify(metadata));
                    return { modelArtifactsInfo };
                }
                catch (err) {
                    // If saving failed, clean up all items saved so far.
                    removeItems(this.keys);
                    throw new Error(`Failed to save model '${this.modelPath}' to local storage: ` +
                        `size quota being exceeded is a possible cause of this failure: ` +
                        `modelTopologyBytes=${modelArtifactsInfo.modelTopologyBytes}, ` +
                        `weightSpecsBytes=${modelArtifactsInfo.weightSpecsBytes}, ` +
                        `weightDataBytes=${modelArtifactsInfo.weightDataBytes}.`);
                }
            }
        }
        /**
         * Load a model from local storage.
         *
         * See the documentation to `browserLocalStorage` for details on the saved
         * artifacts.
         *
         * @returns The loaded model (if loading succeeds).
         * @throws Error if the info, topology, weight specs, or weight data
         *   entry is missing, or if the topology type is not 'JSON'.
         */
        async load() {
            const info = JSON.parse(this.LS.getItem(this.keys.info));
            if (info == null) {
                throw new Error(`In local storage, there is no model with name '${this.modelPath}'`);
            }
            if (info.modelTopologyType !== 'JSON') {
                throw new Error('BrowserLocalStorage does not support loading non-JSON model ' +
                    'topology yet.');
            }
            const out = {};
            // Load topology.
            const topology = JSON.parse(this.LS.getItem(this.keys.topology));
            if (topology == null) {
                throw new Error(`In local storage, the topology of model '${this.modelPath}' ` +
                    `is missing.`);
            }
            out.modelTopology = topology;
            // Load weight specs.
            const weightSpecs = JSON.parse(this.LS.getItem(this.keys.weightSpecs));
            if (weightSpecs == null) {
                throw new Error(`In local storage, the weight specs of model '${this.modelPath}' ` +
                    `are missing.`);
            }
            out.weightSpecs = weightSpecs;
            // Load meta-data fields, if present (the metadata entry is optional).
            const metadataString = this.LS.getItem(this.keys.modelMetadata);
            if (metadataString != null) {
                const metadata = JSON.parse(metadataString);
                out.format = metadata.format;
                out.generatedBy = metadata.generatedBy;
                out.convertedBy = metadata.convertedBy;
                if (metadata.signature != null) {
                    out.signature = metadata.signature;
                }
                if (metadata.userDefinedMetadata != null) {
                    out.userDefinedMetadata = metadata.userDefinedMetadata;
                }
                if (metadata.modelInitializer != null) {
                    out.modelInitializer = metadata.modelInitializer;
                }
                if (metadata.trainingConfig != null) {
                    out.trainingConfig = metadata.trainingConfig;
                }
            }
            // Load weight data (stored base64-encoded; decode back to binary).
            const weightDataBase64 = this.LS.getItem(this.keys.weightData);
            if (weightDataBase64 == null) {
                throw new Error(`In local storage, the binary weight values of model ` +
                    `'${this.modelPath}' are missing.`);
            }
            out.weightData = base64StringToArrayBuffer(weightDataBase64);
            return out;
        }
    }
6690 BrowserLocalStorage.URL_SCHEME = 'localstorage://';
6691 const localStorageRouter = (url) => {
6692 if (!env().getBool('IS_BROWSER')) {
6693 return null;
6694 }
6695 else {
6696 if (!Array.isArray(url) && url.startsWith(BrowserLocalStorage.URL_SCHEME)) {
6697 return browserLocalStorage(url.slice(BrowserLocalStorage.URL_SCHEME.length));
6698 }
6699 else {
6700 return null;
6701 }
6702 }
6703 };
6704 IORouterRegistry.registerSaveRouter(localStorageRouter);
6705 IORouterRegistry.registerLoadRouter(localStorageRouter);
    /**
     * Factory function for local storage IOHandler.
     *
     * This `IOHandler` supports both `save` and `load`.
     *
     * For each model's saved artifacts, five items are saved to local storage.
     * - `${PATH_PREFIX}/${modelPath}/info`: Contains meta-info about the
     *   model, such as date saved, type of the topology, size in bytes, etc.
     * - `${PATH_PREFIX}/${modelPath}/model_topology`: Model topology. For Keras-
     *   style models, this is a stringized JSON.
     * - `${PATH_PREFIX}/${modelPath}/weight_specs`: Weight specs of the
     *   model, can be used to decode the saved binary weight values (see
     *   item below).
     * - `${PATH_PREFIX}/${modelPath}/weight_data`: Concatenated binary
     *   weight values, stored as a base64-encoded string.
     * - `${PATH_PREFIX}/${modelPath}/model_metadata`: Model format,
     *   training config, and other optional metadata.
     *
     * Saving may throw an `Error` if the total size of the artifacts exceed the
     * browser-specific quota.
     *
     * @param modelPath A unique identifier for the model to be saved. Must be a
     *   non-empty string.
     * @returns An instance of `IOHandler`, which can be used with, e.g.,
     *   `tf.Model.save`.
     */
    function browserLocalStorage(modelPath) {
        return new BrowserLocalStorage(modelPath);
    }
    /** Manager for listing and deleting models stored in browser local storage. */
    class BrowserLocalStorageManager {
        constructor() {
            assert(env().getBool('IS_BROWSER'), () => 'Current environment is not a web browser');
            // NOTE(review): this assert passes when `window` is entirely
            // undefined, yet the next line dereferences `window` — presumably
            // IS_BROWSER implies `window` exists; confirm.
            assert(typeof window === 'undefined' ||
                typeof window.localStorage !== 'undefined', () => 'Current browser does not appear to support localStorage');
            this.LS = window.localStorage;
        }
        /**
         * Lists all models saved in local storage.
         *
         * Scans for keys that start with the tfjs prefix and end with the
         * '/info' suffix; each such entry corresponds to one saved model.
         *
         * @returns A promise of a map from model path to its artifacts info.
         */
        async listModels() {
            const out = {};
            const prefix = PATH_PREFIX + PATH_SEPARATOR;
            const suffix = PATH_SEPARATOR + INFO_SUFFIX;
            for (let i = 0; i < this.LS.length; ++i) {
                const key = this.LS.key(i);
                if (key.startsWith(prefix) && key.endsWith(suffix)) {
                    const modelPath = getModelPathFromKey(key);
                    out[modelPath] = JSON.parse(this.LS.getItem(key));
                }
            }
            return out;
        }
        /**
         * Deletes all local-storage entries for the model at `path`.
         *
         * @param path Model path; a leading 'localstorage://' scheme is stripped.
         * @returns A promise of the deleted model's artifacts info.
         * @throws Error if no model exists at `path`.
         */
        async removeModel(path) {
            path = maybeStripScheme$1(path);
            const keys = getModelKeys(path);
            if (this.LS.getItem(keys.info) == null) {
                throw new Error(`Cannot find model at path '${path}'`);
            }
            const info = JSON.parse(this.LS.getItem(keys.info));
            removeItems(keys);
            return info;
        }
    }
6764
6765 /**
6766 * @license
6767 * Copyright 2018 Google LLC. All Rights Reserved.
6768 * Licensed under the Apache License, Version 2.0 (the "License");
6769 * you may not use this file except in compliance with the License.
6770 * You may obtain a copy of the License at
6771 *
6772 * http://www.apache.org/licenses/LICENSE-2.0
6773 *
6774 * Unless required by applicable law or agreed to in writing, software
6775 * distributed under the License is distributed on an "AS IS" BASIS,
6776 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
6777 * See the License for the specific language governing permissions and
6778 * limitations under the License.
6779 * =============================================================================
6780 */
    // Separator between a URL scheme and its path, e.g. 'localstorage://model'.
    const URL_SCHEME_SUFFIX = '://';
    /**
     * Singleton registry mapping URL schemes (e.g. 'localstorage',
     * 'indexeddb') to their `ModelStoreManager` instances.
     */
    class ModelStoreManagerRegistry {
        constructor() {
            this.managers = {};
        }
        static getInstance() {
            if (ModelStoreManagerRegistry.instance == null) {
                ModelStoreManagerRegistry.instance = new ModelStoreManagerRegistry();
            }
            return ModelStoreManagerRegistry.instance;
        }
        /**
         * Register a model-store manager for a URL scheme.
         *
         * @param scheme The scheme to register under, e.g. 'localstorage' or
         *   'localstorage://' (a trailing '://' is stripped). Must be non-empty
         *   and not already registered.
         * @param manager The model-store manager instance for the scheme.
         */
        static registerManager(scheme, manager) {
            assert(scheme != null, () => 'scheme must not be undefined or null.');
            if (scheme.endsWith(URL_SCHEME_SUFFIX)) {
                scheme = scheme.slice(0, scheme.indexOf(URL_SCHEME_SUFFIX));
            }
            assert(scheme.length > 0, () => 'scheme must not be an empty string.');
            const registry = ModelStoreManagerRegistry.getInstance();
            assert(registry.managers[scheme] == null, () => `A model store manager is already registered for scheme '${scheme}'.`);
            registry.managers[scheme] = manager;
        }
        /**
         * Looks up the manager registered for `scheme`.
         * @throws Error if no manager is registered for the scheme.
         */
        static getManager(scheme) {
            const manager = this.getInstance().managers[scheme];
            if (manager == null) {
                throw new Error(`Cannot find model manager for scheme '${scheme}'`);
            }
            return manager;
        }
        /** Returns all currently registered schemes. */
        static getSchemes() {
            return Object.keys(this.getInstance().managers);
        }
    }
6819 /**
6820 * Helper method for parsing a URL string into a scheme and a path.
6821 *
6822 * @param url E.g., 'localstorage://my-model'
6823 * @returns A dictionary with two fields: scheme and path.
6824 * Scheme: e.g., 'localstorage' in the example above.
6825 * Path: e.g., 'my-model' in the example above.
6826 */
6827 function parseURL(url) {
6828 if (url.indexOf(URL_SCHEME_SUFFIX) === -1) {
6829 throw new Error(`The url string provided does not contain a scheme. ` +
6830 `Supported schemes are: ` +
6831 `${ModelStoreManagerRegistry.getSchemes().join(',')}`);
6832 }
6833 return {
6834 scheme: url.split(URL_SCHEME_SUFFIX)[0],
6835 path: url.split(URL_SCHEME_SUFFIX)[1],
6836 };
6837 }
6838 async function cloneModelInternal(sourceURL, destURL, deleteSource = false) {
6839 assert(sourceURL !== destURL, () => `Old path and new path are the same: '${sourceURL}'`);
6840 const loadHandlers = IORouterRegistry.getLoadHandlers(sourceURL);
6841 assert(loadHandlers.length > 0, () => `Copying failed because no load handler is found for source URL ${sourceURL}.`);
6842 assert(loadHandlers.length < 2, () => `Copying failed because more than one (${loadHandlers.length}) ` +
6843 `load handlers for source URL ${sourceURL}.`);
6844 const loadHandler = loadHandlers[0];
6845 const saveHandlers = IORouterRegistry.getSaveHandlers(destURL);
6846 assert(saveHandlers.length > 0, () => `Copying failed because no save handler is found for destination ` +
6847 `URL ${destURL}.`);
6848 assert(saveHandlers.length < 2, () => `Copying failed because more than one (${loadHandlers.length}) ` +
6849 `save handlers for destination URL ${destURL}.`);
6850 const saveHandler = saveHandlers[0];
6851 const sourceScheme = parseURL(sourceURL).scheme;
6852 const sourcePath = parseURL(sourceURL).path;
6853 const sameMedium = sourceScheme === parseURL(sourceURL).scheme;
6854 const modelArtifacts = await loadHandler.load();
6855 // If moving within the same storage medium, remove the old model as soon as
6856 // the loading is done. Without doing this, it is possible that the combined
6857 // size of the two models will cause the cloning to fail.
6858 if (deleteSource && sameMedium) {
6859 await ModelStoreManagerRegistry.getManager(sourceScheme)
6860 .removeModel(sourcePath);
6861 }
6862 const saveResult = await saveHandler.save(modelArtifacts);
6863 // If moving between mediums, the deletion is done after the save succeeds.
6864 // This guards against the case in which saving to the destination medium
6865 // fails.
6866 if (deleteSource && !sameMedium) {
6867 await ModelStoreManagerRegistry.getManager(sourceScheme)
6868 .removeModel(sourcePath);
6869 }
6870 return saveResult.modelArtifactsInfo;
6871 }
6872 /**
6873 * List all models stored in registered storage mediums.
6874 *
6875 * For a web browser environment, the registered mediums are Local Storage and
6876 * IndexedDB.
6877 *
6878 * ```js
6879 * // First create and save a model.
6880 * const model = tf.sequential();
6881 * model.add(tf.layers.dense(
6882 * {units: 1, inputShape: [10], activation: 'sigmoid'}));
6883 * await model.save('localstorage://demo/management/model1');
6884 *
6885 * // Then list existing models.
6886 * console.log(JSON.stringify(await tf.io.listModels()));
6887 *
6888 * // Delete the model.
6889 * await tf.io.removeModel('localstorage://demo/management/model1');
6890 *
6891 * // List models again.
6892 * console.log(JSON.stringify(await tf.io.listModels()));
6893 * ```
6894 *
6895 * @returns A `Promise` of a dictionary mapping URLs of existing models to
6896 * their model artifacts info. URLs include medium-specific schemes, e.g.,
6897 * 'indexeddb://my/model/1'. Model artifacts info include type of the
6898 * model's topology, byte sizes of the topology, weights, etc.
6899 *
6900 * @doc {
6901 * heading: 'Models',
6902 * subheading: 'Management',
6903 * namespace: 'io',
6904 * ignoreCI: true
6905 * }
6906 */
6907 async function listModels() {
6908 const schemes = ModelStoreManagerRegistry.getSchemes();
6909 const out = {};
6910 for (const scheme of schemes) {
6911 const schemeOut = await ModelStoreManagerRegistry.getManager(scheme).listModels();
6912 for (const path in schemeOut) {
6913 const url = scheme + URL_SCHEME_SUFFIX + path;
6914 out[url] = schemeOut[path];
6915 }
6916 }
6917 return out;
6918 }
6919 /**
6920 * Remove a model specified by URL from a reigstered storage medium.
6921 *
6922 * ```js
6923 * // First create and save a model.
6924 * const model = tf.sequential();
6925 * model.add(tf.layers.dense(
6926 * {units: 1, inputShape: [10], activation: 'sigmoid'}));
6927 * await model.save('localstorage://demo/management/model1');
6928 *
6929 * // Then list existing models.
6930 * console.log(JSON.stringify(await tf.io.listModels()));
6931 *
6932 * // Delete the model.
6933 * await tf.io.removeModel('localstorage://demo/management/model1');
6934 *
6935 * // List models again.
6936 * console.log(JSON.stringify(await tf.io.listModels()));
6937 * ```
6938 *
6939 * @param url A URL to a stored model, with a scheme prefix, e.g.,
6940 * 'localstorage://my-model-1', 'indexeddb://my/model/2'.
6941 * @returns ModelArtifactsInfo of the deleted model (if and only if deletion
6942 * is successful).
6943 * @throws Error if deletion fails, e.g., if no model exists at `path`.
6944 *
6945 * @doc {
6946 * heading: 'Models',
6947 * subheading: 'Management',
6948 * namespace: 'io',
6949 * ignoreCI: true
6950 * }
6951 */
6952 async function removeModel(url) {
6953 const schemeAndPath = parseURL(url);
6954 const manager = ModelStoreManagerRegistry.getManager(schemeAndPath.scheme);
6955 return manager.removeModel(schemeAndPath.path);
6956 }
6957 /**
6958 * Copy a model from one URL to another.
6959 *
6960 * This function supports:
6961 *
6962 * 1. Copying within a storage medium, e.g.,
6963 * `tf.io.copyModel('localstorage://model-1', 'localstorage://model-2')`
6964 * 2. Copying between two storage mediums, e.g.,
6965 * `tf.io.copyModel('localstorage://model-1', 'indexeddb://model-1')`
6966 *
6967 * ```js
6968 * // First create and save a model.
6969 * const model = tf.sequential();
6970 * model.add(tf.layers.dense(
6971 * {units: 1, inputShape: [10], activation: 'sigmoid'}));
6972 * await model.save('localstorage://demo/management/model1');
6973 *
6974 * // Then list existing models.
6975 * console.log(JSON.stringify(await tf.io.listModels()));
6976 *
6977 * // Copy the model, from Local Storage to IndexedDB.
6978 * await tf.io.copyModel(
6979 * 'localstorage://demo/management/model1',
6980 * 'indexeddb://demo/management/model1');
6981 *
6982 * // List models again.
6983 * console.log(JSON.stringify(await tf.io.listModels()));
6984 *
6985 * // Remove both models.
6986 * await tf.io.removeModel('localstorage://demo/management/model1');
6987 * await tf.io.removeModel('indexeddb://demo/management/model1');
6988 * ```
6989 *
6990 * @param sourceURL Source URL of copying.
6991 * @param destURL Destination URL of copying.
6992 * @returns ModelArtifactsInfo of the copied model (if and only if copying
6993 * is successful).
6994 * @throws Error if copying fails, e.g., if no model exists at `sourceURL`, or
6995 * if `oldPath` and `newPath` are identical.
6996 *
6997 * @doc {
6998 * heading: 'Models',
6999 * subheading: 'Management',
7000 * namespace: 'io',
7001 * ignoreCI: true
7002 * }
7003 */
7004 async function copyModel(sourceURL, destURL) {
7005 const deleteSource = false;
7006 return cloneModelInternal(sourceURL, destURL, deleteSource);
7007 }
7008 /**
7009 * Move a model from one URL to another.
7010 *
7011 * This function supports:
7012 *
7013 * 1. Moving within a storage medium, e.g.,
7014 * `tf.io.moveModel('localstorage://model-1', 'localstorage://model-2')`
7015 * 2. Moving between two storage mediums, e.g.,
7016 * `tf.io.moveModel('localstorage://model-1', 'indexeddb://model-1')`
7017 *
7018 * ```js
7019 * // First create and save a model.
7020 * const model = tf.sequential();
7021 * model.add(tf.layers.dense(
7022 * {units: 1, inputShape: [10], activation: 'sigmoid'}));
7023 * await model.save('localstorage://demo/management/model1');
7024 *
7025 * // Then list existing models.
7026 * console.log(JSON.stringify(await tf.io.listModels()));
7027 *
7028 * // Move the model, from Local Storage to IndexedDB.
7029 * await tf.io.moveModel(
7030 * 'localstorage://demo/management/model1',
7031 * 'indexeddb://demo/management/model1');
7032 *
7033 * // List models again.
7034 * console.log(JSON.stringify(await tf.io.listModels()));
7035 *
7036 * // Remove the moved model.
7037 * await tf.io.removeModel('indexeddb://demo/management/model1');
7038 * ```
7039 *
7040 * @param sourceURL Source URL of moving.
7041 * @param destURL Destination URL of moving.
7042 * @returns ModelArtifactsInfo of the copied model (if and only if copying
7043 * is successful).
7044 * @throws Error if moving fails, e.g., if no model exists at `sourceURL`, or
7045 * if `oldPath` and `newPath` are identical.
7046 *
7047 * @doc {
7048 * heading: 'Models',
7049 * subheading: 'Management',
7050 * namespace: 'io',
7051 * ignoreCI: true
7052 * }
7053 */
7054 async function moveModel(sourceURL, destURL) {
7055 const deleteSource = true;
7056 return cloneModelInternal(sourceURL, destURL, deleteSource);
7057 }
7058
7059 /**
7060 * @license
7061 * Copyright 2019 Google LLC. All Rights Reserved.
7062 * Licensed under the Apache License, Version 2.0 (the "License");
7063 * you may not use this file except in compliance with the License.
7064 * You may obtain a copy of the License at
7065 *
7066 * http://www.apache.org/licenses/LICENSE-2.0
7067 *
7068 * Unless required by applicable law or agreed to in writing, software
7069 * distributed under the License is distributed on an "AS IS" BASIS,
7070 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
7071 * See the License for the specific language governing permissions and
7072 * limitations under the License.
7073 * =============================================================================
7074 */
7075 class PlatformBrowser {
7076 fetch(path, init) {
7077 return fetch(path, init);
7078 }
7079 now() {
7080 return performance.now();
7081 }
7082 encode(text, encoding) {
7083 if (encoding !== 'utf-8' && encoding !== 'utf8') {
7084 throw new Error(`Browser's encoder only supports utf-8, but got ${encoding}`);
7085 }
7086 if (this.textEncoder == null) {
7087 this.textEncoder = new TextEncoder();
7088 }
7089 return this.textEncoder.encode(text);
7090 }
7091 decode(bytes, encoding) {
7092 return new TextDecoder(encoding).decode(bytes);
7093 }
7094 }
    // When running in a browser, install the browser platform and register
    // the storage-medium managers for model management (list/remove/copy).
    if (env().get('IS_BROWSER')) {
        env().setPlatform('browser', new PlatformBrowser());
        // Register LocalStorage IOHandler
        try {
            ModelStoreManagerRegistry.registerManager(BrowserLocalStorage.URL_SCHEME, new BrowserLocalStorageManager());
        }
        catch (err) {
            // Best-effort: registration throws if the scheme is already
            // registered or localStorage is unsupported; ignore and continue.
        }
        // Register IndexedDB IOHandler
        try {
            ModelStoreManagerRegistry.registerManager(BrowserIndexedDB.URL_SCHEME, new BrowserIndexedDBManager());
        }
        catch (err) {
            // Best-effort: ignore registration failures (e.g. duplicate scheme
            // or IndexedDB unavailable).
        }
    }
7110
7111 /**
7112 * @license
7113 * Copyright 2019 Google LLC. All Rights Reserved.
7114 * Licensed under the Apache License, Version 2.0 (the "License");
7115 * you may not use this file except in compliance with the License.
7116 * You may obtain a copy of the License at
7117 *
7118 * http://www.apache.org/licenses/LICENSE-2.0
7119 *
7120 * Unless required by applicable law or agreed to in writing, software
7121 * distributed under the License is distributed on an "AS IS" BASIS,
7122 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
7123 * See the License for the specific language governing permissions and
7124 * limitations under the License.
7125 * =============================================================================
7126 */
    // We are wrapping this within an object so it can be stubbed by Jasmine.
    const getNodeFetch = {
        // tslint:disable-next-line:no-require-imports
        importFetch: () => require('node-fetch')
    };
    // Lazily-initialized `node-fetch` implementation; used by
    // `PlatformNode.fetch` only when no global `fetch` is available.
    let systemFetch;
    // These getters and setters are for testing so we don't export a mutable
    // variable.
    function resetSystemFetch() {
        systemFetch = null;
    }
    function setSystemFetch(fetchFn) {
        systemFetch = fetchFn;
    }
    function getSystemFetch() {
        return systemFetch;
    }
7144 class PlatformNode {
7145 constructor() {
7146 // tslint:disable-next-line:no-require-imports
7147 this.util = require('util');
7148 // According to the spec, the built-in encoder can do only UTF-8 encoding.
7149 // https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/TextEncoder
7150 this.textEncoder = new this.util.TextEncoder();
7151 }
7152 fetch(path, requestInits) {
7153 if (env().global.fetch != null) {
7154 return env().global.fetch(path, requestInits);
7155 }
7156 if (systemFetch == null) {
7157 systemFetch = getNodeFetch.importFetch();
7158 }
7159 return systemFetch(path, requestInits);
7160 }
7161 now() {
7162 const time = process.hrtime();
7163 return time[0] * 1000 + time[1] / 1000000;
7164 }
7165 encode(text, encoding) {
7166 if (encoding !== 'utf-8' && encoding !== 'utf8') {
7167 throw new Error(`Node built-in encoder only supports utf-8, but got ${encoding}`);
7168 }
7169 return this.textEncoder.encode(text);
7170 }
7171 decode(bytes, encoding) {
7172 if (bytes.length === 0) {
7173 return '';
7174 }
7175 return new this.util.TextDecoder(encoding).decode(bytes);
7176 }
7177 }
    // Install the Node platform only when running under Node and not in a
    // browser context. NOTE(review): presumably both flags can be set in
    // hybrid environments, with the browser platform taking precedence here
    // — confirm against the env-flag definitions.
    if (env().get('IS_NODE') && !env().get('IS_BROWSER')) {
        env().setPlatform('node', new PlatformNode());
    }
7181
7182 /**
7183 * @license
7184 * Copyright 2020 Google Inc. All Rights Reserved.
7185 * Licensed under the Apache License, Version 2.0 (the "License");
7186 * you may not use this file except in compliance with the License.
7187 * You may obtain a copy of the License at
7188 *
7189 * http://www.apache.org/licenses/LICENSE-2.0
7190 *
7191 * Unless required by applicable law or agreed to in writing, software
7192 * distributed under the License is distributed on an "AS IS" BASIS,
7193 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
7194 * See the License for the specific language governing permissions and
7195 * limitations under the License.
7196 * =============================================================================
7197 */
7198 /**
7199 * Creates an empty `tf.TensorBuffer` with the specified `shape` and `dtype`.
7200 *
7201 * The values are stored in CPU as `TypedArray`. Fill the buffer using
7202 * `buffer.set()`, or by modifying directly `buffer.values`.
7203 *
7204 * When done, call `buffer.toTensor()` to get an immutable `tf.Tensor` with
7205 * those values.
7206 *
7207 * ```js
7208 * // Create a buffer and set values at particular indices.
7209 * const buffer = tf.buffer([2, 2]);
7210 * buffer.set(3, 0, 0);
7211 * buffer.set(5, 1, 0);
7212 *
7213 * // Convert the buffer back to a tensor.
7214 * buffer.toTensor().print();
7215 * ```
7216 *
7217 * @param shape An array of integers defining the output tensor shape.
7218 * @param dtype The dtype of the buffer. Defaults to 'float32'.
7219 * @param values The values of the buffer as `TypedArray`. Defaults to
7220 * zeros.
7221 *
7222 * @doc {heading: 'Tensors', subheading: 'Creation'}
7223 */
7224 function buffer(shape, dtype = 'float32', values) {
7225 dtype = dtype || 'float32';
7226 assertNonNegativeIntegerDimensions(shape);
7227 return new TensorBuffer(shape, dtype, values);
7228 }
7229
7230 /**
7231 * @license
7232 * Copyright 2020 Google Inc. All Rights Reserved.
7233 * Licensed under the Apache License, Version 2.0 (the "License");
7234 * you may not use this file except in compliance with the License.
7235 * You may obtain a copy of the License at
7236 *
7237 * http://www.apache.org/licenses/LICENSE-2.0
7238 *
7239 * Unless required by applicable law or agreed to in writing, software
7240 * distributed under the License is distributed on an "AS IS" BASIS,
7241 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
7242 * See the License for the specific language governing permissions and
7243 * limitations under the License.
7244 * =============================================================================
7245 */
7246 /**
7247 * Casts a `tf.Tensor` to a new dtype.
7248 *
7249 * ```js
7250 * const x = tf.tensor1d([1.5, 2.5, 3]);
7251 * tf.cast(x, 'int32').print();
7252 * ```
7253 * @param x The input tensor to be casted.
7254 * @param dtype The dtype to cast the input tensor to.
7255 *
7256 * @doc {heading: 'Tensors', subheading: 'Transformations'}
7257 */
7258 function cast_(x, dtype) {
7259 const $x = convertToTensor(x, 'x', 'cast');
7260 // Sanity checks.
7261 if (!isValidDtype(dtype)) {
7262 throw new Error(`Failed to cast to unknown dtype ${dtype}`);
7263 }
7264 if (dtype === 'string' && $x.dtype !== 'string' ||
7265 dtype !== 'string' && $x.dtype === 'string') {
7266 throw new Error('Only strings can be casted to strings');
7267 }
7268 const inputs = { x: $x };
7269 const attrs = { dtype };
7270 return ENGINE.runKernel(Cast, inputs, attrs);
7271 }
7272 const cast = op({ cast_ });
7273
7274 /**
7275 * @license
7276 * Copyright 2020 Google LLC. All Rights Reserved.
7277 * Licensed under the Apache License, Version 2.0 (the "License");
7278 * you may not use this file except in compliance with the License.
7279 * You may obtain a copy of the License at
7280 *
7281 * http://www.apache.org/licenses/LICENSE-2.0
7282 *
7283 * Unless required by applicable law or agreed to in writing, software
7284 * distributed under the License is distributed on an "AS IS" BASIS,
7285 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
7286 * See the License for the specific language governing permissions and
7287 * limitations under the License.
7288 * =============================================================================
7289 */
7290 /**
7291 * Creates a new tensor with the same values and shape as the specified
7292 * tensor.
7293 *
7294 * ```js
7295 * const x = tf.tensor([1, 2]);
7296 *
7297 * x.clone().print();
7298 * ```
7299 *
7300 * @param x The tensor to clone.
7301 *
7302 * @doc {heading: 'Tensors', subheading: 'Creation'}
7303 */
7304 function clone_(x) {
7305 const $x = convertToTensor(x, 'x', 'clone', 'string_or_numeric');
7306 const inputs = { x: $x };
7307 // Note this op is called tf.identity in python. Hence the kernel name used
7308 // here.
7309 return ENGINE.runKernel(Identity, inputs);
7310 }
7311 const clone = op({ clone_ });
7312
7313 /**
7314 * @license
7315 * Copyright 2020 Google Inc. All Rights Reserved.
7316 * Licensed under the Apache License, Version 2.0 (the "License");
7317 * you may not use this file except in compliance with the License.
7318 * You may obtain a copy of the License at
7319 *
7320 * http://www.apache.org/licenses/LICENSE-2.0
7321 *
7322 * Unless required by applicable law or agreed to in writing, software
7323 * distributed under the License is distributed on an "AS IS" BASIS,
7324 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
7325 * See the License for the specific language governing permissions and
7326 * limitations under the License.
7327 * =============================================================================
7328 */
7329 /**
7330 * Prints information about the `tf.Tensor` including its data.
7331 *
7332 * ```js
7333 * const verbose = true;
7334 * tf.tensor2d([1, 2, 3, 4], [2, 2]).print(verbose);
7335 * ```
7336 * @param x The tensor to be printed.
7337 * @param verbose Whether to print verbose information about the ` Tensor`,
7338 * including dtype and size.
7339 *
7340 * @doc {heading: 'Tensors', subheading: 'Creation'}
7341 */
7342 function print(x, verbose = false) {
7343 console.log(x.toString(verbose));
7344 }
7345
7346 /**
7347 * @license
7348 * Copyright 2020 Google Inc. All Rights Reserved.
7349 * Licensed under the Apache License, Version 2.0 (the "License");
7350 * you may not use this file except in compliance with the License.
7351 * You may obtain a copy of the License at
7352 *
7353 * http://www.apache.org/licenses/LICENSE-2.0
7354 *
7355 * Unless required by applicable law or agreed to in writing, software
7356 * distributed under the License is distributed on an "AS IS" BASIS,
7357 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
7358 * See the License for the specific language governing permissions and
7359 * limitations under the License.
7360 * =============================================================================
7361 */
    // Initialize the global engine singleton as a side effect of loading this
    // module, then install the core op handler so Tensor methods (e.g.
    // `tensor.print()`, `tensor.cast()`) can delegate to these ops without
    // creating circular imports.
    getOrMakeEngine();
    const opHandler$1 = {
        buffer,
        cast,
        clone,
        print
    };
    setOpHandler(opHandler$1);
7370
7371 /**
7372 * @license
7373 * Copyright 2018 Google LLC. All Rights Reserved.
7374 * Licensed under the Apache License, Version 2.0 (the "License");
7375 * you may not use this file except in compliance with the License.
7376 * You may obtain a copy of the License at
7377 *
7378 * http://www.apache.org/licenses/LICENSE-2.0
7379 *
7380 * Unless required by applicable law or agreed to in writing, software
7381 * distributed under the License is distributed on an "AS IS" BASIS,
7382 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
7383 * See the License for the specific language governing permissions and
7384 * limitations under the License.
7385 * =============================================================================
7386 */
7387 const DEFAULT_FILE_NAME_PREFIX = 'model';
7388 const DEFAULT_JSON_EXTENSION_NAME = '.json';
7389 const DEFAULT_WEIGHT_DATA_EXTENSION_NAME = '.weights.bin';
7390 function defer(f) {
7391 return new Promise(resolve => setTimeout(resolve)).then(f);
7392 }
7393 class BrowserDownloads {
7394 constructor(fileNamePrefix) {
7395 if (!env().getBool('IS_BROWSER')) {
7396 // TODO(cais): Provide info on what IOHandlers are available under the
7397 // current environment.
7398 throw new Error('browserDownloads() cannot proceed because the current environment ' +
7399 'is not a browser.');
7400 }
7401 if (fileNamePrefix.startsWith(BrowserDownloads.URL_SCHEME)) {
7402 fileNamePrefix = fileNamePrefix.slice(BrowserDownloads.URL_SCHEME.length);
7403 }
7404 if (fileNamePrefix == null || fileNamePrefix.length === 0) {
7405 fileNamePrefix = DEFAULT_FILE_NAME_PREFIX;
7406 }
7407 this.modelJsonFileName = fileNamePrefix + DEFAULT_JSON_EXTENSION_NAME;
7408 this.weightDataFileName =
7409 fileNamePrefix + DEFAULT_WEIGHT_DATA_EXTENSION_NAME;
7410 }
7411 async save(modelArtifacts) {
7412 if (typeof (document) === 'undefined') {
7413 throw new Error('Browser downloads are not supported in ' +
7414 'this environment since `document` is not present');
7415 }
7416 const weightsURL = window.URL.createObjectURL(new Blob([modelArtifacts.weightData], { type: 'application/octet-stream' }));
7417 if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
7418 throw new Error('BrowserDownloads.save() does not support saving model topology ' +
7419 'in binary formats yet.');
7420 }
7421 else {
7422 const weightsManifest = [{
7423 paths: ['./' + this.weightDataFileName],
7424 weights: modelArtifacts.weightSpecs
7425 }];
7426 const modelJSON = getModelJSONForModelArtifacts(modelArtifacts, weightsManifest);
7427 const modelJsonURL = window.URL.createObjectURL(new Blob([JSON.stringify(modelJSON)], { type: 'application/json' }));
7428 // If anchor elements are not provided, create them without attaching them
7429 // to parents, so that the downloaded file names can be controlled.
7430 const jsonAnchor = this.modelJsonAnchor == null ?
7431 document.createElement('a') :
7432 this.modelJsonAnchor;
7433 jsonAnchor.download = this.modelJsonFileName;
7434 jsonAnchor.href = modelJsonURL;
7435 // Trigger downloads by evoking a click event on the download anchors.
7436 // When multiple downloads are started synchronously, Firefox will only
7437 // save the last one.
7438 await defer(() => jsonAnchor.dispatchEvent(new MouseEvent('click')));
7439 if (modelArtifacts.weightData != null) {
7440 const weightDataAnchor = this.weightDataAnchor == null ?
7441 document.createElement('a') :
7442 this.weightDataAnchor;
7443 weightDataAnchor.download = this.weightDataFileName;
7444 weightDataAnchor.href = weightsURL;
7445 await defer(() => weightDataAnchor.dispatchEvent(new MouseEvent('click')));
7446 }
7447 return { modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts) };
7448 }
7449 }
7450 }
7451 BrowserDownloads.URL_SCHEME = 'downloads://';
7452 class BrowserFiles {
7453 constructor(files) {
7454 if (files == null || files.length < 1) {
7455 throw new Error(`When calling browserFiles, at least 1 file is required, ` +
7456 `but received ${files}`);
7457 }
7458 this.jsonFile = files[0];
7459 this.weightsFiles = files.slice(1);
7460 }
7461 async load() {
7462 return new Promise((resolve, reject) => {
7463 const jsonReader = new FileReader();
7464 jsonReader.onload = (event) => {
7465 // tslint:disable-next-line:no-any
7466 const modelJSON = JSON.parse(event.target.result);
7467 const modelTopology = modelJSON.modelTopology;
7468 if (modelTopology == null) {
7469 reject(new Error(`modelTopology field is missing from file ${this.jsonFile.name}`));
7470 return;
7471 }
7472 const weightsManifest = modelJSON.weightsManifest;
7473 if (weightsManifest == null) {
7474 reject(new Error(`weightManifest field is missing from file ${this.jsonFile.name}`));
7475 return;
7476 }
7477 if (this.weightsFiles.length === 0) {
7478 resolve({ modelTopology });
7479 return;
7480 }
7481 const modelArtifactsPromise = getModelArtifactsForJSON(modelJSON, (weightsManifest) => this.loadWeights(weightsManifest));
7482 resolve(modelArtifactsPromise);
7483 };
7484 jsonReader.onerror = error => reject(`Failed to read model topology and weights manifest JSON ` +
7485 `from file '${this.jsonFile.name}'. BrowserFiles supports loading ` +
7486 `Keras-style tf.Model artifacts only.`);
7487 jsonReader.readAsText(this.jsonFile);
7488 });
7489 }
7490 loadWeights(weightsManifest) {
7491 const weightSpecs = [];
7492 const paths = [];
7493 for (const entry of weightsManifest) {
7494 weightSpecs.push(...entry.weights);
7495 paths.push(...entry.paths);
7496 }
7497 const pathToFile = this.checkManifestAndWeightFiles(weightsManifest);
7498 const promises = paths.map(path => this.loadWeightsFile(path, pathToFile[path]));
7499 return Promise.all(promises).then(buffers => [weightSpecs, concatenateArrayBuffers(buffers)]);
7500 }
7501 loadWeightsFile(path, file) {
7502 return new Promise((resolve, reject) => {
7503 const weightFileReader = new FileReader();
7504 weightFileReader.onload = (event) => {
7505 // tslint:disable-next-line:no-any
7506 const weightData = event.target.result;
7507 resolve(weightData);
7508 };
7509 weightFileReader.onerror = error => reject(`Failed to weights data from file of path '${path}'.`);
7510 weightFileReader.readAsArrayBuffer(file);
7511 });
7512 }
7513 /**
7514 * Check the compatibility between weights manifest and weight files.
7515 */
7516 checkManifestAndWeightFiles(manifest) {
7517 const basenames = [];
7518 const fileNames = this.weightsFiles.map(file => basename(file.name));
7519 const pathToFile = {};
7520 for (const group of manifest) {
7521 group.paths.forEach(path => {
7522 const pathBasename = basename(path);
7523 if (basenames.indexOf(pathBasename) !== -1) {
7524 throw new Error(`Duplicate file basename found in weights manifest: ` +
7525 `'${pathBasename}'`);
7526 }
7527 basenames.push(pathBasename);
7528 if (fileNames.indexOf(pathBasename) === -1) {
7529 throw new Error(`Weight file with basename '${pathBasename}' is not provided.`);
7530 }
7531 else {
7532 pathToFile[path] = this.weightsFiles[fileNames.indexOf(pathBasename)];
7533 }
7534 });
7535 }
7536 if (basenames.length !== this.weightsFiles.length) {
7537 throw new Error(`Mismatch in the number of files in weights manifest ` +
7538 `(${basenames.length}) and the number of weight files provided ` +
7539 `(${this.weightsFiles.length}).`);
7540 }
7541 return pathToFile;
7542 }
7543 }
7544 const browserDownloadsRouter = (url) => {
7545 if (!env().getBool('IS_BROWSER')) {
7546 return null;
7547 }
7548 else {
7549 if (!Array.isArray(url) && url.startsWith(BrowserDownloads.URL_SCHEME)) {
7550 return browserDownloads(url.slice(BrowserDownloads.URL_SCHEME.length));
7551 }
7552 else {
7553 return null;
7554 }
7555 }
7556 };
7557 IORouterRegistry.registerSaveRouter(browserDownloadsRouter);
7558 /**
7559 * Creates an IOHandler that triggers file downloads from the browser.
7560 *
7561 * The returned `IOHandler` instance can be used as model exporting methods such
7562 * as `tf.Model.save` and supports only saving.
7563 *
7564 * ```js
7565 * const model = tf.sequential();
7566 * model.add(tf.layers.dense(
7567 * {units: 1, inputShape: [10], activation: 'sigmoid'}));
7568 * const saveResult = await model.save('downloads://mymodel');
7569 * // This will trigger downloading of two files:
7570 * // 'mymodel.json' and 'mymodel.weights.bin'.
7571 * console.log(saveResult);
7572 * ```
7573 *
7574 * @param fileNamePrefix Prefix name of the files to be downloaded. For use with
7575 * `tf.Model`, `fileNamePrefix` should follow either of the following two
7576 * formats:
7577 * 1. `null` or `undefined`, in which case the default file
7578 * names will be used:
7579 * - 'model.json' for the JSON file containing the model topology and
7580 * weights manifest.
7581 * - 'model.weights.bin' for the binary file containing the binary weight
7582 * values.
7583 * 2. A single string or an Array of a single string, as the file name prefix.
7584 * For example, if `'foo'` is provided, the downloaded JSON
7585 * file and binary weights file will be named 'foo.json' and
7586 * 'foo.weights.bin', respectively.
7587 * @param config Additional configuration for triggering downloads.
7588 * @returns An instance of `BrowserDownloads` `IOHandler`.
7589 *
7590 * @doc {
7591 * heading: 'Models',
7592 * subheading: 'Loading',
7593 * namespace: 'io',
7594 * ignoreCI: true
7595 * }
7596 */
7597 function browserDownloads(fileNamePrefix = 'model') {
7598 return new BrowserDownloads(fileNamePrefix);
7599 }
7600 /**
7601 * Creates an IOHandler that loads model artifacts from user-selected files.
7602 *
7603 * This method can be used for loading from files such as user-selected files
7604 * in the browser.
7605 * When used in conjunction with `tf.loadLayersModel`, an instance of
7606 * `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts.
7607 *
7608 * ```js
7609 * // Note: This code snippet won't run properly without the actual file input
7610 * // elements in the HTML DOM.
7611 *
7612 * // Suppose there are two HTML file input (`<input type="file" ...>`)
7613 * // elements.
7614 * const uploadJSONInput = document.getElementById('upload-json');
7615 * const uploadWeightsInput = document.getElementById('upload-weights');
7616 * const model = await tf.loadLayersModel(tf.io.browserFiles(
7617 * [uploadJSONInput.files[0], uploadWeightsInput.files[0]]));
7618 * ```
7619 *
7620 * @param files `File`s to load from. Currently, this function supports only
7621 * loading from files that contain Keras-style models (i.e., `tf.Model`s), for
7622 * which an `Array` of `File`s is expected (in that order):
7623 * - A JSON file containing the model topology and weight manifest.
7624 * - Optionally, One or more binary files containing the binary weights.
7625 * These files must have names that match the paths in the `weightsManifest`
7626 * contained by the aforementioned JSON file, or errors will be thrown
7627 * during loading. These weights files have the same format as the ones
7628 * generated by `tensorflowjs_converter` that comes with the `tensorflowjs`
7629 * Python PIP package. If no weights files are provided, only the model
7630 * topology will be loaded from the JSON file above.
7631 * @returns An instance of `Files` `IOHandler`.
7632 *
7633 * @doc {
7634 * heading: 'Models',
7635 * subheading: 'Loading',
7636 * namespace: 'io',
7637 * ignoreCI: true
7638 * }
7639 */
7640 function browserFiles(files) {
7641 return new BrowserFiles(files);
7642 }
7643
7644 /**
7645 * @license
7646 * Copyright 2019 Google LLC. All Rights Reserved.
7647 * Licensed under the Apache License, Version 2.0 (the "License");
7648 * you may not use this file except in compliance with the License.
7649 * You may obtain a copy of the License at
7650 *
7651 * http://www.apache.org/licenses/LICENSE-2.0
7652 *
7653 * Unless required by applicable law or agreed to in writing, software
7654 * distributed under the License is distributed on an "AS IS" BASIS,
7655 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
7656 * See the License for the specific language governing permissions and
7657 * limitations under the License.
7658 * =============================================================================
7659 */
7660 /**
7661 * Monitor Promise.all progress, fire onProgress callback function.
7662 *
7663 * @param promises Promise list going to be monitored
7664 * @param onProgress Callback function. Fired when a promise resolved.
7665 * @param startFraction Optional fraction start. Default to 0.
7666 * @param endFraction Optional fraction end. Default to 1.
7667 */
7668 function monitorPromisesProgress(promises, onProgress, startFraction, endFraction) {
7669 checkPromises(promises);
7670 startFraction = startFraction == null ? 0 : startFraction;
7671 endFraction = endFraction == null ? 1 : endFraction;
7672 checkFraction(startFraction, endFraction);
7673 let resolvedPromise = 0;
7674 const registerMonitor = (promise) => {
7675 promise.then(value => {
7676 const fraction = startFraction +
7677 ++resolvedPromise / promises.length * (endFraction - startFraction);
7678 // pass fraction as parameter to callback function.
7679 onProgress(fraction);
7680 return value;
7681 });
7682 return promise;
7683 };
7684 function checkPromises(promises) {
7685 assert(promises != null && Array.isArray(promises) && promises.length > 0, () => 'promises must be a none empty array');
7686 }
7687 function checkFraction(startFraction, endFraction) {
7688 assert(startFraction >= 0 && startFraction <= 1, () => `Progress fraction must be in range [0, 1], but ` +
7689 `got startFraction ${startFraction}`);
7690 assert(endFraction >= 0 && endFraction <= 1, () => `Progress fraction must be in range [0, 1], but ` +
7691 `got endFraction ${endFraction}`);
7692 assert(endFraction >= startFraction, () => `startFraction must be no more than endFraction, but ` +
7693 `got startFraction ${startFraction} and endFraction ` +
7694 `${endFraction}`);
7695 }
7696 return Promise.all(promises.map(registerMonitor));
7697 }
7698
7699 /**
7700 * @license
7701 * Copyright 2018 Google LLC. All Rights Reserved.
7702 * Licensed under the Apache License, Version 2.0 (the "License");
7703 * you may not use this file except in compliance with the License.
7704 * You may obtain a copy of the License at
7705 *
7706 * http://www.apache.org/licenses/LICENSE-2.0
7707 *
7708 * Unless required by applicable law or agreed to in writing, software
7709 * distributed under the License is distributed on an "AS IS" BASIS,
7710 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
7711 * See the License for the specific language governing permissions and
7712 * limitations under the License.
7713 * =============================================================================
7714 */
7715 /**
7716 * Reads binary weights data from a number of URLs.
7717 *
7718 * @param fetchURLs URLs to send the HTTP requests at, using `fetch` calls.
7719 * @param requestOptions RequestInit (options) for the HTTP requests.
7720 * @param fetchFunc Optional overriding value for the `window.fetch` function.
7721 * @param onProgress Optional, progress callback function, fired periodically
7722 * before the load is completed.
7723 * @returns A `Promise` of an Array of `ArrayBuffer`. The Array has the same
7724 * length as `fetchURLs`.
7725 */
7726 async function loadWeightsAsArrayBuffer(fetchURLs, loadOptions) {
7727 if (loadOptions == null) {
7728 loadOptions = {};
7729 }
7730 const fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch :
7731 loadOptions.fetchFunc;
7732 // Create the requests for all of the weights in parallel.
7733 const requests = fetchURLs.map(fetchURL => fetchFunc(fetchURL, loadOptions.requestInit, { isBinary: true }));
7734 const fetchStartFraction = 0;
7735 const fetchEndFraction = 0.5;
7736 const responses = loadOptions.onProgress == null ?
7737 await Promise.all(requests) :
7738 await monitorPromisesProgress(requests, loadOptions.onProgress, fetchStartFraction, fetchEndFraction);
7739 const bufferPromises = responses.map(response => response.arrayBuffer());
7740 const bufferStartFraction = 0.5;
7741 const bufferEndFraction = 1;
7742 const buffers = loadOptions.onProgress == null ?
7743 await Promise.all(bufferPromises) :
7744 await monitorPromisesProgress(bufferPromises, loadOptions.onProgress, bufferStartFraction, bufferEndFraction);
7745 return buffers;
7746 }
7747 /**
7748 * Reads a weights manifest JSON configuration, fetches the weights and
7749 * returns them as `Tensor`s.
7750 *
7751 * @param manifest The weights manifest JSON.
7752 * @param filePathPrefix The path prefix for filenames given in the manifest.
7753 * Defaults to the empty string.
7754 * @param weightNames The names of the weights to be fetched.
7755 */
7756 async function loadWeights(manifest, filePathPrefix = '', weightNames, requestInit) {
7757 // TODO(nsthorat): Groups are currently fetched atomically. If you need a
7758 // single weight from a group, the whole group will be fetched. At a future
7759 // date, we should support fetching only the individual shards within a
7760 // group that are needed to reconstruct the requested weight.
7761 // TODO(cais): Use `decodeWeights` for implementation.
7762 const fetchWeights = (fetchUrls) => loadWeightsAsArrayBuffer(fetchUrls, { requestInit });
7763 const loadWeights = weightsLoaderFactory(fetchWeights);
7764 return loadWeights(manifest, filePathPrefix, weightNames);
7765 }
    /**
     * Creates a function, which reads a weights manifest JSON configuration,
     * fetches the weight files using the specified function and returns them as
     * `Tensor`s.
     *
     * ```js
     * // example for creating a nodejs weight loader, which reads the weight files
     * // from disk using fs.readFileSync
     *
     * import * as fs from 'fs'
     *
     * const fetchWeightsFromDisk = (filePaths: string[]) =>
     *     filePaths.map(filePath => fs.readFileSync(filePath).buffer)
     *
     * const loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk)
     *
     * const manifest = JSON.parse(
     *     fs.readFileSync('./my_model-weights_manifest').toString()
     * )
     * const weightMap = await loadWeights(manifest, './')
     * ```
     * @param fetchWeightsFunction The function used for fetching the weight
     *     files; receives the list of URLs and must resolve to one
     *     ArrayBuffer per URL.
     * @returns Weight loading function taking (manifest, filePathPrefix,
     *     weightNames) and resolving to a name-to-Tensor map.
     */
    function weightsLoaderFactory(fetchWeightsFunction) {
        return async (manifest, filePathPrefix = '', weightNames) => {
            // Collect all the groups, weights, and their relative offsets to be
            // fetched.
            const groupIndicesToFetchMap = manifest.map(() => false);
            const groupWeightsToFetch = {};
            // Parallel to `weightNames`; flips to true once a requested name is
            // seen in the manifest. Empty when all weights are requested.
            const weightsFound = weightNames != null ? weightNames.map(() => false) : [];
            const allManifestWeightNames = [];
            manifest.forEach((manifestGroupConfig, groupIndex) => {
                // Byte offset of the current entry within its group's buffer.
                let groupOffset = 0;
                manifestGroupConfig.weights.forEach(weightsEntry => {
                    // Quantized entries store their on-disk dtype under
                    // `quantization.dtype`.
                    const rawDtype = ('quantization' in weightsEntry) ?
                        weightsEntry.quantization.dtype :
                        weightsEntry.dtype;
                    const weightsBytes = DTYPE_VALUE_SIZE_MAP[rawDtype] *
                        sizeFromShape(weightsEntry.shape);
                    // Marks this entry's group for fetching and records where the
                    // entry lives inside the group buffer.
                    const enqueueWeightsForFetchingFn = () => {
                        groupIndicesToFetchMap[groupIndex] = true;
                        if (groupWeightsToFetch[groupIndex] == null) {
                            groupWeightsToFetch[groupIndex] = [];
                        }
                        groupWeightsToFetch[groupIndex].push({
                            manifestEntry: weightsEntry,
                            groupOffset,
                            sizeBytes: weightsBytes
                        });
                    };
                    if (weightNames != null) {
                        weightNames.forEach((weightName, weightIndex) => {
                            if (weightName === weightsEntry.name) {
                                enqueueWeightsForFetchingFn();
                                weightsFound[weightIndex] = true;
                            }
                        });
                    }
                    else {
                        enqueueWeightsForFetchingFn();
                    }
                    allManifestWeightNames.push(weightsEntry.name);
                    // Offsets advance even for entries that are not fetched, so
                    // they stay correct relative to the full group buffer.
                    groupOffset += weightsBytes;
                });
            });
            if (!weightsFound.every(found => found)) {
                const weightsNotFound = weightNames.filter((_, i) => !weightsFound[i]);
                throw new Error(`Could not find weights in manifest with names: ` +
                    `${weightsNotFound.join(', ')}. \n` +
                    `Manifest JSON has weights with names: ` +
                    `${allManifestWeightNames.join(', ')}.`);
            }
            // Convert the one-hot boolean groupId => shouldFetch map to a list of group
            // IDs.
            const groupIndicesToFetch = groupIndicesToFetchMap.reduce((accumulator, shouldFetch, i) => {
                if (shouldFetch) {
                    accumulator.push(i);
                }
                return accumulator;
            }, []);
            const fetchUrls = [];
            groupIndicesToFetch.forEach(i => {
                manifest[i].paths.forEach(filepath => {
                    const fetchUrl = filePathPrefix +
                        (!filePathPrefix.endsWith('/') ? '/' : '') + filepath;
                    fetchUrls.push(fetchUrl);
                });
            });
            const buffers = await fetchWeightsFunction(fetchUrls);
            const weightsTensorMap = {};
            // Index of the first fetched buffer belonging to the current group.
            let bufferIndexOffset = 0;
            groupIndicesToFetch.forEach(i => {
                const numBuffers = manifest[i].paths.length;
                let groupBytes = 0;
                // NOTE: the `let i` in the for-loops below shadows the group
                // index `i` of the enclosing forEach, but only inside the loop
                // bodies; uses of `i` after the loops refer to the group index.
                for (let i = 0; i < numBuffers; i++) {
                    groupBytes += buffers[bufferIndexOffset + i].byteLength;
                }
                // Create a buffer for the whole group.
                const groupBuffer = new ArrayBuffer(groupBytes);
                const groupByteBuffer = new Uint8Array(groupBuffer);
                let groupBufferOffset = 0;
                // Concatenate this group's shard buffers into the group buffer.
                for (let i = 0; i < numBuffers; i++) {
                    const buffer = new Uint8Array(buffers[bufferIndexOffset + i]);
                    groupByteBuffer.set(buffer, groupBufferOffset);
                    groupBufferOffset += buffer.byteLength;
                }
                // Slice each requested entry out of the group buffer and decode
                // it into named tensors.
                const weightsEntries = groupWeightsToFetch[i];
                weightsEntries.forEach(weightsEntry => {
                    const byteBuffer = groupBuffer.slice(weightsEntry.groupOffset, weightsEntry.groupOffset + weightsEntry.sizeBytes);
                    const nameToTensorMap = decodeWeights(byteBuffer, [weightsEntry.manifestEntry]);
                    for (const name in nameToTensorMap) {
                        weightsTensorMap[name] = nameToTensorMap[name];
                    }
                });
                bufferIndexOffset += numBuffers;
            });
            return weightsTensorMap;
        };
    }
7886
7887 /**
7888 * @license
7889 * Copyright 2018 Google LLC. All Rights Reserved.
7890 * Licensed under the Apache License, Version 2.0 (the "License");
7891 * you may not use this file except in compliance with the License.
7892 * You may obtain a copy of the License at
7893 *
7894 * http://www.apache.org/licenses/LICENSE-2.0
7895 *
7896 * Unless required by applicable law or agreed to in writing, software
7897 * distributed under the License is distributed on an "AS IS" BASIS,
7898 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
7899 * See the License for the specific language governing permissions and
7900 * limitations under the License.
7901 * =============================================================================
7902 */
    // MIME types used when transferring model artifacts over HTTP: binary
    // weight data vs. the model JSON payload.
    const OCTET_STREAM_MIME_TYPE = 'application/octet-stream';
    const JSON_TYPE = 'application/json';
7905 class HTTPRequest {
7906 constructor(path, loadOptions) {
7907 this.DEFAULT_METHOD = 'POST';
7908 if (loadOptions == null) {
7909 loadOptions = {};
7910 }
7911 this.weightPathPrefix = loadOptions.weightPathPrefix;
7912 this.onProgress = loadOptions.onProgress;
7913 this.weightUrlConverter = loadOptions.weightUrlConverter;
7914 if (loadOptions.fetchFunc != null) {
7915 assert(typeof loadOptions.fetchFunc === 'function', () => 'Must pass a function that matches the signature of ' +
7916 '`fetch` (see ' +
7917 'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)');
7918 this.fetch = loadOptions.fetchFunc;
7919 }
7920 else {
7921 this.fetch = env().platform.fetch;
7922 }
7923 assert(path != null && path.length > 0, () => 'URL path for http must not be null, undefined or ' +
7924 'empty.');
7925 if (Array.isArray(path)) {
7926 assert(path.length === 2, () => 'URL paths for http must have a length of 2, ' +
7927 `(actual length is ${path.length}).`);
7928 }
7929 this.path = path;
7930 if (loadOptions.requestInit != null &&
7931 loadOptions.requestInit.body != null) {
7932 throw new Error('requestInit is expected to have no pre-existing body, but has one.');
7933 }
7934 this.requestInit = loadOptions.requestInit || {};
7935 }
7936 async save(modelArtifacts) {
7937 if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
7938 throw new Error('BrowserHTTPRequest.save() does not support saving model topology ' +
7939 'in binary formats yet.');
7940 }
7941 const init = Object.assign({ method: this.DEFAULT_METHOD }, this.requestInit);
7942 init.body = new FormData();
7943 const weightsManifest = [{
7944 paths: ['./model.weights.bin'],
7945 weights: modelArtifacts.weightSpecs,
7946 }];
7947 const modelTopologyAndWeightManifest = getModelJSONForModelArtifacts(modelArtifacts, weightsManifest);
7948 init.body.append('model.json', new Blob([JSON.stringify(modelTopologyAndWeightManifest)], { type: JSON_TYPE }), 'model.json');
7949 if (modelArtifacts.weightData != null) {
7950 init.body.append('model.weights.bin', new Blob([modelArtifacts.weightData], { type: OCTET_STREAM_MIME_TYPE }), 'model.weights.bin');
7951 }
7952 const response = await this.fetch(this.path, init);
7953 if (response.ok) {
7954 return {
7955 modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts),
7956 responses: [response],
7957 };
7958 }
7959 else {
7960 throw new Error(`BrowserHTTPRequest.save() failed due to HTTP response status ` +
7961 `${response.status}.`);
7962 }
7963 }
7964 /**
7965 * Load model artifacts via HTTP request(s).
7966 *
7967 * See the documentation to `tf.io.http` for details on the saved
7968 * artifacts.
7969 *
7970 * @returns The loaded model artifacts (if loading succeeds).
7971 */
7972 async load() {
7973 const modelConfigRequest = await this.fetch(this.path, this.requestInit);
7974 if (!modelConfigRequest.ok) {
7975 throw new Error(`Request to ${this.path} failed with status code ` +
7976 `${modelConfigRequest.status}. Please verify this URL points to ` +
7977 `the model JSON of the model to load.`);
7978 }
7979 let modelJSON;
7980 try {
7981 modelJSON = await modelConfigRequest.json();
7982 }
7983 catch (e) {
7984 let message = `Failed to parse model JSON of response from ${this.path}.`;
7985 // TODO(nsthorat): Remove this after some time when we're comfortable that
7986 // .pb files are mostly gone.
7987 if (this.path.endsWith('.pb')) {
7988 message += ' Your path contains a .pb file extension. ' +
7989 'Support for .pb models have been removed in TensorFlow.js 1.0 ' +
7990 'in favor of .json models. You can re-convert your Python ' +
7991 'TensorFlow model using the TensorFlow.js 1.0 conversion scripts ' +
7992 'or you can convert your.pb models with the \'pb2json\'' +
7993 'NPM script in the tensorflow/tfjs-converter repository.';
7994 }
7995 else {
7996 message += ' Please make sure the server is serving valid ' +
7997 'JSON for this request.';
7998 }
7999 throw new Error(message);
8000 }
8001 // We do not allow both modelTopology and weightsManifest to be missing.
8002 const modelTopology = modelJSON.modelTopology;
8003 const weightsManifest = modelJSON.weightsManifest;
8004 if (modelTopology == null && weightsManifest == null) {
8005 throw new Error(`The JSON from HTTP path ${this.path} contains neither model ` +
8006 `topology or manifest for weights.`);
8007 }
8008 return getModelArtifactsForJSON(modelJSON, (weightsManifest) => this.loadWeights(weightsManifest));
8009 }
        /**
         * Fetches every weight file named in `weightsManifest` and returns the
         * flattened weight specs together with one concatenated ArrayBuffer.
         *
         * @param weightsManifest Manifest groups, each with `paths` and `weights`.
         * @returns A tuple of [flattened weight specs, concatenated weight data].
         */
        async loadWeights(weightsManifest) {
            // For an Array path, entry [1] is the URL weights resolve against
            // (entry [0] is the model JSON URL).
            const weightPath = Array.isArray(this.path) ? this.path[1] : this.path;
            const [prefix, suffix] = parseUrl(weightPath);
            // An explicit weightPathPrefix option overrides the derived directory.
            const pathPrefix = this.weightPathPrefix || prefix;
            const weightSpecs = [];
            for (const entry of weightsManifest) {
                weightSpecs.push(...entry.weights);
            }
            const fetchURLs = [];
            const urlPromises = [];
            for (const weightsGroup of weightsManifest) {
                for (const path of weightsGroup.paths) {
                    if (this.weightUrlConverter != null) {
                        // Custom converter resolves the final URL asynchronously;
                        // collected first, awaited in one batch below.
                        urlPromises.push(this.weightUrlConverter(path));
                    }
                    else {
                        fetchURLs.push(pathPrefix + path + suffix);
                    }
                }
            }
            if (this.weightUrlConverter) {
                fetchURLs.push(...await Promise.all(urlPromises));
            }
            const buffers = await loadWeightsAsArrayBuffer(fetchURLs, {
                requestInit: this.requestInit,
                fetchFunc: this.fetch,
                onProgress: this.onProgress
            });
            // Buffers come back in fetchURLs order, matching the flattened specs.
            return [weightSpecs, concatenateArrayBuffers(buffers)];
        }
8040 }
8041 HTTPRequest.URL_SCHEME_REGEX = /^https?:\/\//;
8042 /**
8043 * Extract the prefix and suffix of the url, where the prefix is the path before
8044 * the last file, and suffix is the search params after the last file.
8045 * ```
8046 * const url = 'http://tfhub.dev/model/1/tensorflowjs_model.pb?tfjs-format=file'
8047 * [prefix, suffix] = parseUrl(url)
8048 * // prefix = 'http://tfhub.dev/model/1/'
8049 * // suffix = '?tfjs-format=file'
8050 * ```
8051 * @param url the model url to be parsed.
8052 */
8053 function parseUrl(url) {
8054 const lastSlash = url.lastIndexOf('/');
8055 const lastSearchParam = url.lastIndexOf('?');
8056 const prefix = url.substring(0, lastSlash);
8057 const suffix = lastSearchParam > lastSlash ? url.substring(lastSearchParam) : '';
8058 return [prefix + '/', suffix];
8059 }
8060 function isHTTPScheme(url) {
8061 return url.match(HTTPRequest.URL_SCHEME_REGEX) != null;
8062 }
8063 const httpRouter = (url, loadOptions) => {
8064 if (typeof fetch === 'undefined' &&
8065 (loadOptions == null || loadOptions.fetchFunc == null)) {
8066 // `http` uses `fetch` or `node-fetch`, if one wants to use it in
8067 // an environment that is not the browser or node they have to setup a
8068 // global fetch polyfill.
8069 return null;
8070 }
8071 else {
8072 let isHTTP = true;
8073 if (Array.isArray(url)) {
8074 isHTTP = url.every(urlItem => isHTTPScheme(urlItem));
8075 }
8076 else {
8077 isHTTP = isHTTPScheme(url);
8078 }
8079 if (isHTTP) {
8080 return http(url, loadOptions);
8081 }
8082 }
8083 return null;
8084 };
8085 IORouterRegistry.registerSaveRouter(httpRouter);
8086 IORouterRegistry.registerLoadRouter(httpRouter);
8087 /**
8088 * Creates an IOHandler subtype that sends model artifacts to HTTP server.
8089 *
8090 * An HTTP request of the `multipart/form-data` mime type will be sent to the
8091 * `path` URL. The form data includes artifacts that represent the topology
8092 * and/or weights of the model. In the case of Keras-style `tf.Model`, two
8093 * blobs (files) exist in form-data:
8094 * - A JSON file consisting of `modelTopology` and `weightsManifest`.
8095 * - A binary weights file consisting of the concatenated weight values.
8096 * These files are in the same format as the one generated by
8097 * [tfjs_converter](https://js.tensorflow.org/tutorials/import-keras.html).
8098 *
8099 * The following code snippet exemplifies the client-side code that uses this
8100 * function:
8101 *
8102 * ```js
8103 * const model = tf.sequential();
8104 * model.add(
8105 * tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'}));
8106 *
8107 * const saveResult = await model.save(tf.io.http(
8108 * 'http://model-server:5000/upload', {requestInit: {method: 'PUT'}}));
8109 * console.log(saveResult);
8110 * ```
8111 *
8112 * If the default `POST` method is to be used, without any custom parameters
8113 * such as headers, you can simply pass an HTTP or HTTPS URL to `model.save`:
8114 *
8115 * ```js
8116 * const saveResult = await model.save('http://model-server:5000/upload');
8117 * ```
8118 *
8119 * The following GitHub Gist
8120 * https://gist.github.com/dsmilkov/1b6046fd6132d7408d5257b0976f7864
8121 * implements a server based on [flask](https://github.com/pallets/flask) that
 * can receive the request. Upon receiving the model artifacts via the request,
 * this particular server reconstitutes instances of [Keras
 * Models](https://keras.io/models/model/) in memory.
8125 *
8126 *
8127 * @param path A URL path to the model.
8128 * Can be an absolute HTTP path (e.g.,
8129 * 'http://localhost:8000/model-upload)') or a relative path (e.g.,
8130 * './model-upload').
8131 * @param requestInit Request configurations to be used when sending
8132 * HTTP request to server using `fetch`. It can contain fields such as
8133 * `method`, `credentials`, `headers`, `mode`, etc. See
8134 * https://developer.mozilla.org/en-US/docs/Web/API/Request/Request
8135 * for more information. `requestInit` must not have a body, because the
8136 * body will be set by TensorFlow.js. File blobs representing the model
8137 * topology (filename: 'model.json') and the weights of the model (filename:
8138 * 'model.weights.bin') will be appended to the body. If `requestInit` has a
8139 * `body`, an Error will be thrown.
8140 * @param loadOptions Optional configuration for the loading. It includes the
8141 * following fields:
8142 * - weightPathPrefix Optional, this specifies the path prefix for weight
8143 * files, by default this is calculated from the path param.
8144 * - fetchFunc Optional, custom `fetch` function. E.g., in Node.js,
8145 * the `fetch` from node-fetch can be used here.
8146 * - onProgress Optional, progress callback function, fired periodically
8147 * before the load is completed.
8148 * @returns An instance of `IOHandler`.
8149 *
8150 * @doc {
8151 * heading: 'Models',
8152 * subheading: 'Loading',
8153 * namespace: 'io',
8154 * ignoreCI: true
8155 * }
8156 */
8157 function http(path, loadOptions) {
8158 return new HTTPRequest(path, loadOptions);
8159 }
8160 /**
8161 * Deprecated. Use `tf.io.http`.
8162 * @param path
8163 * @param loadOptions
8164 */
8165 function browserHTTPRequest(path, loadOptions) {
8166 return http(path, loadOptions);
8167 }
8168
8169 /**
8170 * @license
8171 * Copyright 2018 Google LLC. All Rights Reserved.
8172 * Licensed under the Apache License, Version 2.0 (the "License");
8173 * you may not use this file except in compliance with the License.
8174 * You may obtain a copy of the License at
8175 *
8176 * http://www.apache.org/licenses/LICENSE-2.0
8177 *
8178 * Unless required by applicable law or agreed to in writing, software
8179 * distributed under the License is distributed on an "AS IS" BASIS,
8180 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8181 * See the License for the specific language governing permissions and
8182 * limitations under the License.
8183 * =============================================================================
8184 */
8185 class PassthroughLoader {
8186 constructor(modelArtifacts) {
8187 this.modelArtifacts = modelArtifacts;
8188 }
8189 load() {
8190 return this.modelArtifacts;
8191 }
8192 }
8193 class PassthroughSaver {
8194 constructor(saveHandler) {
8195 this.saveHandler = saveHandler;
8196 }
8197 save(modelArtifacts) {
8198 return this.saveHandler(modelArtifacts);
8199 }
8200 }
    /**
     * Adapts a (possibly synchronous) IOHandler into an async one: whichever of
     * `load`/`save` the wrapped handler defines is mirrored here, with its
     * result wrapped in a Promise.
     */
    class PassthroughAsync {
        constructor(handler) {
            if (handler.load) {
                // Called as `handler.load()` so `this` inside the wrapped method
                // still refers to the handler. Promise.resolve passes a returned
                // promise through unchanged and wraps plain values.
                this.load = () => Promise.resolve(handler.load());
            }
            if (handler.save) {
                this.save = (modelArtifacts) => Promise.resolve(handler.save(modelArtifacts));
            }
        }
    }
8211 /**
8212 * Creates an IOHandler that loads model artifacts from memory.
8213 *
8214 * When used in conjunction with `tf.loadLayersModel`, an instance of
8215 * `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts.
8216 *
8217 * ```js
8218 * const model = await tf.loadLayersModel(tf.io.fromMemory(
8219 * modelTopology, weightSpecs, weightData));
8220 * ```
8221 *
8222 * @param modelArtifacts a object containing model topology (i.e., parsed from
8223 * the JSON format).
8224 * @param weightSpecs An array of `WeightsManifestEntry` objects describing the
8225 * names, shapes, types, and quantization of the weight data. Optional.
8226 * @param weightData A single `ArrayBuffer` containing the weight data,
8227 * concatenated in the order described by the weightSpecs. Optional.
8228 * @param trainingConfig Model training configuration. Optional.
8229 *
8230 * @returns A passthrough `IOHandler` that simply loads the provided data.
8231 */
8232 function fromMemory(modelArtifacts, weightSpecs, weightData, trainingConfig) {
8233 const args = arguments;
8234 return new PassthroughAsync(fromMemorySync(...args));
8235 }
8236 /**
8237 * Creates an IOHandler that loads model artifacts from memory.
8238 *
8239 * When used in conjunction with `tf.loadLayersModel`, an instance of
8240 * `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts.
8241 *
8242 * ```js
8243 * const model = await tf.loadLayersModel(tf.io.fromMemory(
8244 * modelTopology, weightSpecs, weightData));
8245 * ```
8246 *
8247 * @param modelArtifacts a object containing model topology (i.e., parsed from
8248 * the JSON format).
8249 * @param weightSpecs An array of `WeightsManifestEntry` objects describing the
8250 * names, shapes, types, and quantization of the weight data. Optional.
8251 * @param weightData A single `ArrayBuffer` containing the weight data,
8252 * concatenated in the order described by the weightSpecs. Optional.
8253 * @param trainingConfig Model training configuration. Optional.
8254 *
8255 * @returns A passthrough `IOHandlerSync` that simply loads the provided data.
8256 */
8257 function fromMemorySync(modelArtifacts, weightSpecs, weightData, trainingConfig) {
8258 if (arguments.length === 1) {
8259 const isModelArtifacts = modelArtifacts.modelTopology != null ||
8260 modelArtifacts.weightSpecs != null;
8261 if (isModelArtifacts) {
8262 return new PassthroughLoader(modelArtifacts);
8263 }
8264 else {
8265 // Legacy support: with only modelTopology.
8266 // TODO(cais): Remove this deprecated API.
8267 console.warn('Please call tf.io.fromMemory() with only one argument. ' +
8268 'The argument should be of type ModelArtifacts. ' +
8269 'The multi-argument signature of tf.io.fromMemory() has been ' +
8270 'deprecated and will be removed in a future release.');
8271 return new PassthroughLoader({ modelTopology: modelArtifacts });
8272 }
8273 }
8274 else {
8275 // Legacy support.
8276 // TODO(cais): Remove this deprecated API.
8277 console.warn('Please call tf.io.fromMemory() with only one argument. ' +
8278 'The argument should be of type ModelArtifacts. ' +
8279 'The multi-argument signature of tf.io.fromMemory() has been ' +
8280 'deprecated and will be removed in a future release.');
8281 return new PassthroughLoader({
8282 modelTopology: modelArtifacts,
8283 weightSpecs,
8284 weightData,
8285 trainingConfig
8286 });
8287 }
8288 }
8289 /**
8290 * Creates an IOHandler that passes saved model artifacts to a callback.
8291 *
8292 * ```js
8293 * function handleSave(artifacts) {
8294 * // ... do something with the artifacts ...
8295 * return {modelArtifactsInfo: {...}, ...};
8296 * }
8297 *
8298 * const saveResult = model.save(tf.io.withSaveHandler(handleSave));
8299 * ```
8300 *
8301 * @param saveHandler A function that accepts a `ModelArtifacts` and returns a
8302 * promise that resolves to a `SaveResult`.
8303 */
8304 function withSaveHandler(saveHandler) {
8305 return new PassthroughSaver(saveHandler);
8306 }
8307 /**
8308 * Creates an IOHandlerSync that passes saved model artifacts to a callback.
8309 *
8310 * ```js
8311 * function handleSave(artifacts) {
8312 * // ... do something with the artifacts ...
8313 * return {modelArtifactsInfo: {...}, ...};
8314 * }
8315 *
8316 * const saveResult = model.save(tf.io.withSaveHandler(handleSave));
8317 * ```
8318 *
8319 * @param saveHandler A function that accepts a `ModelArtifacts` and returns a
8320 * `SaveResult`.
8321 */
8322 function withSaveHandlerSync(saveHandler) {
8323 return new PassthroughSaver(saveHandler);
8324 }
8325
8326 /**
8327 * @license
8328 * Copyright 2018 Google LLC. All Rights Reserved.
8329 * Licensed under the Apache License, Version 2.0 (the "License");
8330 * you may not use this file except in compliance with the License.
8331 * You may obtain a copy of the License at
8332 *
8333 * http://www.apache.org/licenses/LICENSE-2.0
8334 *
8335 * Unless required by applicable law or agreed to in writing, software
8336 * distributed under the License is distributed on an "AS IS" BASIS,
8337 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8338 * See the License for the specific language governing permissions and
8339 * limitations under the License.
8340 * =============================================================================
8341 */
8342
    // Public `tf.io` namespace: aggregates the IO handlers, routers, and
    // weight-encoding utilities defined above into a single frozen export.
    var io = /*#__PURE__*/Object.freeze({
        __proto__: null,
        browserFiles: browserFiles,
        browserHTTPRequest: browserHTTPRequest,
        concatenateArrayBuffers: concatenateArrayBuffers,
        decodeWeights: decodeWeights,
        encodeWeights: encodeWeights,
        fromMemory: fromMemory,
        fromMemorySync: fromMemorySync,
        getLoadHandlers: getLoadHandlers,
        getModelArtifactsForJSON: getModelArtifactsForJSON,
        getModelArtifactsInfoForJSON: getModelArtifactsInfoForJSON,
        getSaveHandlers: getSaveHandlers,
        http: http,
        isHTTPScheme: isHTTPScheme,
        loadWeights: loadWeights,
        registerLoadRouter: registerLoadRouter,
        registerSaveRouter: registerSaveRouter,
        weightsLoaderFactory: weightsLoaderFactory,
        withSaveHandler: withSaveHandler,
        withSaveHandlerSync: withSaveHandlerSync,
        copyModel: copyModel,
        listModels: listModels,
        moveModel: moveModel,
        removeModel: removeModel
    });
8369
8370 /**
8371 * @license
8372 * Copyright 2020 Google LLC. All Rights Reserved.
8373 * Licensed under the Apache License, Version 2.0 (the "License");
8374 * you may not use this file except in compliance with the License.
8375 * You may obtain a copy of the License at
8376 *
8377 * http://www.apache.org/licenses/LICENSE-2.0
8378 *
8379 * Unless required by applicable law or agreed to in writing, software
8380 * distributed under the License is distributed on an "AS IS" BASIS,
8381 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8382 * See the License for the specific language governing permissions and
8383 * limitations under the License.
8384 * =============================================================================
8385 */
8386 /**
8387 * Computes the dot product of two matrices, A * B. These must be matrices.
8388 *
8389 * ```js
8390 * const a = tf.tensor2d([1, 2], [1, 2]);
8391 * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]);
8392 *
8393 * a.matMul(b).print(); // or tf.matMul(a, b)
8394 * ```
8395 * @param a First matrix in dot product operation.
8396 * @param b Second matrix in dot product operation.
8397 * @param transposeA If true, `a` is transposed before multiplication.
8398 * @param transposeB If true, `b` is transposed before multiplication.
8399 *
8400 * @doc {heading: 'Operations', subheading: 'Matrices'}
8401 */
8402 function matMul_(a, b, transposeA = false, transposeB = false) {
8403 let $a = convertToTensor(a, 'a', 'matMul');
8404 let $b = convertToTensor(b, 'b', 'matMul');
8405 [$a, $b] = makeTypesMatch($a, $b);
8406 const inputs = { a: $a, b: $b };
8407 const attrs = { transposeA, transposeB };
8408 return ENGINE.runKernel(BatchMatMul, inputs, attrs);
8409 }
8410 const matMul = op({ matMul_ });
8411
8412 /**
8413 * @license
8414 * Copyright 2020 Google LLC. All Rights Reserved.
8415 * Licensed under the Apache License, Version 2.0 (the "License");
8416 * you may not use this file except in compliance with the License.
8417 * You may obtain a copy of the License at
8418 *
8419 * http://www.apache.org/licenses/LICENSE-2.0
8420 *
8421 * Unless required by applicable law or agreed to in writing, software
8422 * distributed under the License is distributed on an "AS IS" BASIS,
8423 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8424 * See the License for the specific language governing permissions and
8425 * limitations under the License.
8426 * =============================================================================
8427 */
8428 /**
8429 * Creates a one-hot `tf.Tensor`. The locations represented by `indices` take
8430 * value `onValue` (defaults to 1), while all other locations take value
8431 * `offValue` (defaults to 0). If `indices` is rank `R`, the output has rank
8432 * `R+1` with the last axis of size `depth`.
8433 * `indices` used to encode prediction class must start from 0. For example,
8434 * if you have 3 classes of data, class 1 should be encoded as 0, class 2
8435 * should be 1, and class 3 should be 2.
8436 *
8437 * ```js
8438 * tf.oneHot(tf.tensor1d([0, 1], 'int32'), 3).print();
8439 * ```
8440 *
8441 * @param indices `tf.Tensor` of indices with dtype `int32`. Indices must
8442 * start from 0.
8443 * @param depth The depth of the one hot dimension.
8444 * @param onValue A number used to fill in the output when the index matches
8445 * the location.
8446 * @param offValue A number used to fill in the output when the index does
8447 * not match the location.
8448 *
8449 * @doc {heading: 'Tensors', subheading: 'Creation'}
8450 */
8451 function oneHot_(indices, depth, onValue = 1, offValue = 0) {
8452 if (depth < 2) {
8453 throw new Error(`Error in oneHot: depth must be >=2, but it is ${depth}`);
8454 }
8455 const $indices = convertToTensor(indices, 'indices', 'oneHot', 'int32');
8456 const inputs = { indices: $indices };
8457 const attrs = { depth, onValue, offValue };
8458 return ENGINE.runKernel(OneHot, inputs, attrs);
8459 }
8460 const oneHot = op({ oneHot_ });
8461
8462 /**
8463 * @license
8464 * Copyright 2018 Google LLC. All Rights Reserved.
8465 * Licensed under the Apache License, Version 2.0 (the "License");
8466 * you may not use this file except in compliance with the License.
8467 * You may obtain a copy of the License at
8468 *
8469 * http://www.apache.org/licenses/LICENSE-2.0
8470 *
8471 * Unless required by applicable law or agreed to in writing, software
8472 * distributed under the License is distributed on an "AS IS" BASIS,
8473 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8474 * See the License for the specific language governing permissions and
8475 * limitations under the License.
8476 * =============================================================================
8477 */
8478 /**
8479 * Enables production mode which disables correctness checks in favor of
8480 * performance.
8481 *
8482 * @doc {heading: 'Environment'}
8483 */
    function enableProdMode() {
        // 'PROD' disables correctness checks in favor of performance.
        env().set('PROD', true);
    }
8487 /**
8488 * Enables debug mode which will log information about all executed kernels:
8489 * the elapsed time of the kernel execution, as well as the rank, shape, and
8490 * size of the output tensor.
8491 *
8492 * Debug mode will significantly slow down your application as it will
8493 * download the result of every operation to the CPU. This should not be used in
8494 * production. Debug mode does not affect the timing information of the kernel
8495 * execution as we do not measure download time in the kernel execution time.
8496 *
8497 * See also: `tf.profile`, `tf.memory`.
8498 *
8499 * @doc {heading: 'Environment'}
8500 */
    function enableDebugMode() {
        // 'DEBUG' enables per-kernel logging; significantly slows execution.
        env().set('DEBUG', true);
    }
8504 /** Globally disables deprecation warnings */
    function disableDeprecationWarnings() {
        // Flag is consulted by deprecationWarn(); warn once so the silence is
        // visible in the console.
        env().set('DEPRECATION_WARNINGS_ENABLED', false);
        console.warn(`TensorFlow.js deprecation warnings have been disabled.`);
    }
8509 /** Warn users about deprecated functionality. */
    function deprecationWarn(msg) {
        // Honors the global kill-switch set via disableDeprecationWarnings().
        if (env().getBool('DEPRECATION_WARNINGS_ENABLED')) {
            console.warn(msg + ' You can disable deprecation warnings with ' +
                'tf.disableDeprecationWarnings().');
        }
    }
8516 setDeprecationWarningFn(deprecationWarn);
8517 /**
8518 * Dispose all variables kept in backend engine.
8519 *
8520 * @doc {heading: 'Environment'}
8521 */
    function disposeVariables() {
        // Delegates to the global engine, which owns all variables.
        ENGINE.disposeVariables();
    }
8525 /**
8526 * It returns the global engine that keeps track of all tensors and backends.
8527 *
8528 * @doc {heading: 'Environment'}
8529 */
    function engine() {
        // Exposes the singleton engine tracking all tensors and backends.
        return ENGINE;
    }
8533 /**
8534 * Returns memory info at the current time in the program. The result is an
8535 * object with the following properties:
8536 *
8537 * - `numBytes`: Number of bytes allocated (undisposed) at this time.
8538 * - `numTensors`: Number of unique tensors allocated.
8539 * - `numDataBuffers`: Number of unique data buffers allocated
8540 * (undisposed) at this time, which is ≤ the number of tensors
8541 * (e.g. `a.reshape(newShape)` makes a new Tensor that shares the same
8542 * data buffer with `a`).
8543 * - `unreliable`: True if the memory usage is unreliable. See `reasons` when
8544 * `unreliable` is true.
8545 * - `reasons`: `string[]`, reasons why the memory is unreliable, present if
8546 * `unreliable` is true.
8547 *
8548 * WebGL Properties:
8549 * - `numBytesInGPU`: Number of bytes allocated (undisposed) in the GPU only at
8550 * this time.
8551 *
8552 * @doc {heading: 'Performance', subheading: 'Memory'}
8553 */
    function memory() {
        // Snapshot of current allocation counts; see doc comment for fields.
        return ENGINE.memory();
    }
8557 /**
8558 * Executes the provided function `f()` and returns a promise that resolves
8559 * with information about the function's memory use:
8560 * - `newBytes`: the number of new bytes allocated
8561 * - `newTensors`: the number of new tensors created
8562 * - `peakBytes`: the peak number of bytes allocated
8563 * - `kernels`: an array of objects for each kernel involved that reports
8564 * their input and output shapes, number of bytes used, and number of new
8565 * tensors created.
8566 * - `kernelNames`: an array of unique strings with just the names of the
8567 * kernels in the `kernels` array.
8568 *
8569 * ```js
8570 * const profile = await tf.profile(() => {
8571 * const x = tf.tensor1d([1, 2, 3]);
8572 * let x2 = x.square();
8573 * x2.dispose();
8574 * x2 = x.square();
8575 * x2.dispose();
8576 * return x;
8577 * });
8578 *
8579 * console.log(`newBytes: ${profile.newBytes}`);
8580 * console.log(`newTensors: ${profile.newTensors}`);
8581 * console.log(`byte usage over all kernels: ${profile.kernels.map(k =>
8582 * k.totalBytesSnapshot)}`);
8583 * ```
8584 *
8585 *
8586 * @doc {heading: 'Performance', subheading: 'Profile'}
8587 */
    function profile(f) {
        // Runs `f` under the engine's profiler; resolves with kernel stats.
        return ENGINE.profile(f);
    }
8591 /**
8592 * Executes the provided function `fn` and after it is executed, cleans up all
8593 * intermediate tensors allocated by `fn` except those returned by `fn`.
8594 * `fn` must not return a Promise (async functions not allowed). The returned
8595 * result can be a complex object.
8596 *
8597 * Using this method helps avoid memory leaks. In general, wrap calls to
8598 * operations in `tf.tidy` for automatic memory cleanup.
8599 *
8600 * NOTE: Variables do *not* get cleaned up when inside a tidy(). If you want to
8601 * dispose variables, please use `tf.disposeVariables` or call dispose()
8602 * directly on variables.
8603 *
8604 * ```js
8605 * // y = 2 ^ 2 + 1
8606 * const y = tf.tidy(() => {
8607 * // a, b, and one will be cleaned up when the tidy ends.
8608 * const one = tf.scalar(1);
8609 * const a = tf.scalar(2);
8610 * const b = a.square();
8611 *
8612 * console.log('numTensors (in tidy): ' + tf.memory().numTensors);
8613 *
8614 * // The value returned inside the tidy function will return
8615 * // through the tidy, in this case to the variable y.
8616 * return b.add(one);
8617 * });
8618 *
8619 * console.log('numTensors (outside tidy): ' + tf.memory().numTensors);
8620 * y.print();
8621 * ```
8622 *
8623 * @param nameOrFn The name of the closure, or the function to execute.
8624 * If a name is provided, the 2nd argument should be the function.
8625 * If debug mode is on, the timing and the memory usage of the function
8626 * will be tracked and displayed on the console using the provided name.
8627 * @param fn The function to execute.
8628 *
8629 * @doc {heading: 'Performance', subheading: 'Memory'}
8630 */
    function tidy(nameOrFn, fn) {
        // Engine scopes the call and disposes intermediates not returned by fn.
        return ENGINE.tidy(nameOrFn, fn);
    }
8634 /**
8635 * Disposes any `tf.Tensor`s found within the provided object.
8636 *
8637 * @param container an object that may be a `tf.Tensor` or may directly
8638 * contain `tf.Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. If
8639 * the object is not a `tf.Tensor` or does not contain `Tensors`, nothing
8640 * happens. In general it is safe to pass any object here, except that
8641 * `Promise`s are not supported.
8642 *
8643 * @doc {heading: 'Performance', subheading: 'Memory'}
8644 */
8645 function dispose(container) {
8646 const tensors = getTensorsInContainer(container);
8647 tensors.forEach(tensor => tensor.dispose());
8648 }
8649 /**
8650 * Keeps a `tf.Tensor` generated inside a `tf.tidy` from being disposed
8651 * automatically.
8652 *
8653 * ```js
8654 * let b;
8655 * const y = tf.tidy(() => {
8656 * const one = tf.scalar(1);
8657 * const a = tf.scalar(2);
8658 *
8659 * // b will not be cleaned up by the tidy. a and one will be cleaned up
8660 * // when the tidy ends.
8661 * b = tf.keep(a.square());
8662 *
8663 * console.log('numTensors (in tidy): ' + tf.memory().numTensors);
8664 *
8665 * // The value returned inside the tidy function will return
8666 * // through the tidy, in this case to the variable y.
8667 * return b.add(one);
8668 * });
8669 *
8670 * console.log('numTensors (outside tidy): ' + tf.memory().numTensors);
8671 * console.log('y:');
8672 * y.print();
8673 * console.log('b:');
8674 * b.print();
8675 * ```
8676 *
8677 * @param result The tensor to keep from being disposed.
8678 *
8679 * @doc {heading: 'Performance', subheading: 'Memory'}
8680 */
    function keep(result) {
        // Marks the tensor so enclosing tidy() scopes will not dispose it.
        return ENGINE.keep(result);
    }
8684 /**
8685 * Executes `f()` and returns a promise that resolves with timing
8686 * information.
8687 *
8688 * The result is an object with the following properties:
8689 *
8690 * - `wallMs`: Wall execution time.
8691 * - `kernelMs`: Kernel execution time, ignoring data transfer. If using the
8692 * WebGL backend and the query timer extension is not available, this will
8693 * return an error object.
8694 * - On `WebGL` The following additional properties exist:
8695 * - `uploadWaitMs`: CPU blocking time on texture uploads.
8696 * - `downloadWaitMs`: CPU blocking time on texture downloads (readPixels).
8697 *
8698 * ```js
8699 * const x = tf.randomNormal([20, 20]);
8700 * const time = await tf.time(() => x.matMul(x));
8701 *
8702 * console.log(`kernelMs: ${time.kernelMs}, wallTimeMs: ${time.wallMs}`);
8703 * ```
8704 *
8705 * @param f The function to execute and time.
8706 *
8707 * @doc {heading: 'Performance', subheading: 'Timing'}
8708 */
    function time(f) {
        // Delegates timing (wall + kernel ms) to the active backend via ENGINE.
        return ENGINE.time(f);
    }
8712 /**
8713 * Sets the backend (cpu, webgl, wasm, etc) responsible for creating tensors and
8714 * executing operations on those tensors. Returns a promise that resolves
8715 * to a boolean if the backend initialization was successful.
8716 *
8717 * Note this disposes the current backend, if any, as well as any tensors
8718 * associated with it. A new backend is initialized, even if it is of the
8719 * same type as the previous one.
8720 *
8721 * @param backendName The name of the backend. Currently supports
8722 * `'webgl'|'cpu'` in the browser, `'tensorflow'` under node.js
8723 * (requires tfjs-node), and `'wasm'` (requires tfjs-backend-wasm).
8724 *
8725 * @doc {heading: 'Backends'}
8726 */
    function setBackend(backendName) {
        // Resolves to a boolean indicating whether initialization succeeded.
        return ENGINE.setBackend(backendName);
    }
8730 /**
8731 * Returns a promise that resolves when the currently selected backend (or the
8732 * highest priority one) has initialized. Await this promise when you are using
8733 * a backend that has async initialization.
8734 *
8735 * @doc {heading: 'Backends'}
8736 */
    function ready() {
        // Resolves once the selected (or highest-priority) backend initializes.
        return ENGINE.ready();
    }
8740 /**
8741 * Returns the current backend name (cpu, webgl, etc). The backend is
8742 * responsible for creating tensors and executing operations on those tensors.
8743 *
8744 * @doc {heading: 'Backends'}
8745 */
    function getBackend() {
        // Name only (e.g. 'cpu', 'webgl'); use backend() for the instance.
        return ENGINE.backendName;
    }
8749 /**
8750 * Removes a backend and the registered factory.
8751 *
8752 * @doc {heading: 'Backends'}
8753 */
    function removeBackend(name) {
        // Removes both the backend instance and its registered factory.
        ENGINE.removeBackend(name);
    }
8757 /**
8758 * Finds the backend registered under the provided name. Returns null if the
8759 * name is not in the registry, or the registration hasn't finished yet.
8760 */
    function findBackend(name) {
        // Null when the name is unknown or registration hasn't finished.
        return ENGINE.findBackend(name);
    }
8764 /**
8765 * Finds the backend factory registered under the provided name. Returns a
8766 * function that produces a new backend when called. Returns null if the name
8767 * is not in the registry.
8768 */
    function findBackendFactory(name) {
        // Returns the factory function (not an instance), or null if unknown.
        return ENGINE.findBackendFactory(name);
    }
8772 /**
8773 * Registers a global backend. The registration should happen when importing
8774 * a module file (e.g. when importing `backend_webgl.ts`), and is used for
8775 * modular builds (e.g. custom tfjs bundle with only webgl support).
8776 *
8777 * @param factory The backend factory function. When called, it should
8778 * return a backend instance, or a promise of an instance.
8779 * @param priority The priority of the backend (higher = more important).
8780 * In case multiple backends are registered, the priority is used to find
8781 * the best backend. Defaults to 1.
8782 * @return False if there is already a registered backend under this name, true
8783 * if not.
8784 *
8785 * @doc {heading: 'Backends'}
8786 */
    function registerBackend(name, factory, priority = 1) {
        // Delegates registration (and priority bookkeeping) to the global engine;
        // returns false if `name` is already registered.
        return ENGINE.registerBackend(name, factory, priority);
    }
8790 /**
8791 * Gets the current backend. If no backends have been initialized, this will
8792 * attempt to initialize the best backend. Will throw an error if the highest
8793 * priority backend has async initialization, in which case, you should call
8794 * 'await tf.ready()' before running other code.
8795 *
8796 * @doc {heading: 'Backends'}
8797 */
    function backend() {
        // May trigger synchronous initialization of the best backend; throws if
        // that backend initializes asynchronously (use `await tf.ready()` first).
        return ENGINE.backend;
    }
8801 /**
8802 * Sets the global platform.
8803 *
8804 * @param platformName The name of this platform.
8805 * @param platform A platform implementation.
8806 */
    function setPlatform(platformName, platform) {
        // Installs a platform implementation (fetch, encode/decode, etc.) on the
        // shared environment.
        env().setPlatform(platformName, platform);
    }
8810
8811 /**
8812 * @license
8813 * Copyright 2020 Google LLC. All Rights Reserved.
8814 * Licensed under the Apache License, Version 2.0 (the "License");
8815 * you may not use this file except in compliance with the License.
8816 * You may obtain a copy of the License at
8817 *
8818 * http://www.apache.org/licenses/LICENSE-2.0
8819 *
8820 * Unless required by applicable law or agreed to in writing, software
8821 * distributed under the License is distributed on an "AS IS" BASIS,
8822 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8823 * See the License for the specific language governing permissions and
8824 * limitations under the License.
8825 * =============================================================================
8826 */
8827 /**
8828 * Returns the imaginary part of a complex (or real) tensor.
8829 *
8830 * Given a tensor input, this operation returns a tensor of type float that is
8831 * the imaginary part of each element in input considered as a complex number.
8832 * If input is real, a tensor of all zeros is returned.
8833 *
8834 * ```js
8835 * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]);
8836 * tf.imag(x).print();
8837 * ```
8838 *
8839 * @doc {heading: 'Tensors', subheading: 'Creation'}
8840 */
8841 function imag_(input) {
8842 const $input = convertToTensor(input, 'input', 'imag');
8843 const inputs = { input: $input };
8844 return ENGINE.runKernel(Imag, inputs);
8845 }
8846 const imag = op({ imag_ });
8847
8848 /**
8849 * @license
8850 * Copyright 2018 Google LLC. All Rights Reserved.
8851 * Licensed under the Apache License, Version 2.0 (the "License");
8852 * you may not use this file except in compliance with the License.
8853 * You may obtain a copy of the License at
8854 *
8855 * http://www.apache.org/licenses/LICENSE-2.0
8856 *
8857 * Unless required by applicable law or agreed to in writing, software
8858 * distributed under the License is distributed on an "AS IS" BASIS,
8859 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8860 * See the License for the specific language governing permissions and
8861 * limitations under the License.
8862 * =============================================================================
8863 */
8864 /**
8865 * Computes `-1 * x` element-wise.
8866 *
8867 * ```js
8868 * const x = tf.tensor2d([1, 2, -2, 0], [2, 2]);
8869 *
8870 * x.neg().print(); // or tf.neg(x)
8871 * ```
8872 *
8873 * @param x The input tensor.
8874 *
8875 * @doc {heading: 'Operations', subheading: 'Basic math'}
8876 */
8877 function neg_(x) {
8878 const $x = convertToTensor(x, 'x', 'neg');
8879 const inputs = { x: $x };
8880 return ENGINE.runKernel(Neg, inputs);
8881 }
8882 const neg = op({ neg_ });
8883
8884 /**
8885 * @license
8886 * Copyright 2020 Google LLC. All Rights Reserved.
8887 * Licensed under the Apache License, Version 2.0 (the "License");
8888 * you may not use this file except in compliance with the License.
8889 * You may obtain a copy of the License at
8890 *
8891 * http://www.apache.org/licenses/LICENSE-2.0
8892 *
8893 * Unless required by applicable law or agreed to in writing, software
8894 * distributed under the License is distributed on an "AS IS" BASIS,
8895 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8896 * See the License for the specific language governing permissions and
8897 * limitations under the License.
8898 * =============================================================================
8899 */
8900 /**
8901 * Returns the real part of a complex (or real) tensor.
8902 *
8903 * Given a tensor input, this operation returns a tensor of type float that is
8904 * the real part of each element in input considered as a complex number.
8905 *
8906 * If the input is real, it simply makes a clone.
8907 *
8908 * ```js
8909 * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]);
8910 * tf.real(x).print();
8911 * ```
8912 *
8913 * @doc {heading: 'Tensors', subheading: 'Creation'}
8914 */
8915 function real_(input) {
8916 const $input = convertToTensor(input, 'input', 'real');
8917 const inputs = { input: $input };
8918 return ENGINE.runKernel(Real, inputs);
8919 }
8920 const real = op({ real_ });
8921
8922 /**
8923 * @license
8924 * Copyright 2018 Google LLC. All Rights Reserved.
8925 * Licensed under the Apache License, Version 2.0 (the "License");
8926 * you may not use this file except in compliance with the License.
8927 * You may obtain a copy of the License at
8928 *
8929 * http://www.apache.org/licenses/LICENSE-2.0
8930 *
8931 * Unless required by applicable law or agreed to in writing, software
8932 * distributed under the License is distributed on an "AS IS" BASIS,
8933 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8934 * See the License for the specific language governing permissions and
8935 * limitations under the License.
8936 * =============================================================================
8937 */
8938 /**
8939 * Transposes the `tf.Tensor`. Permutes the dimensions according to `perm`.
8940 *
8941 * The returned `tf.Tensor`'s dimension `i` will correspond to the input
8942 * dimension `perm[i]`. If `perm` is not given, it is set to `[n-1...0]`,
8943 * where `n` is the rank of the input `tf.Tensor`. Hence by default, this
8944 * operation performs a regular matrix transpose on 2-D input `tf.Tensor`s.
8945 *
8946 * ```js
8947 * const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
8948 *
8949 * a.transpose().print(); // or tf.transpose(a)
8950 * ```
8951 *
8952 * @param x The tensor to transpose.
8953 * @param perm The permutation of the dimensions of a.
8954 * @param conjugate Will conjugate complex input if true.
8955 *
8956 * @doc {heading: 'Operations', subheading: 'Matrices'}
8957 */
8958 function transpose_(x, perm, conjugate) {
8959 const $x = convertToTensor(x, 'x', 'transpose');
8960 if (perm == null) {
8961 perm = $x.shape.map((s, i) => i).reverse();
8962 }
8963 assert($x.rank === perm.length, () => `Error in transpose: rank of input ${$x.rank} ` +
8964 `must match length of perm ${perm}.`);
8965 perm.forEach(axis => {
8966 assert(axis >= 0 && axis < $x.rank, () => `All entries in 'perm' must be between 0 and ${$x.rank - 1}` +
8967 ` but got ${perm}`);
8968 });
8969 if ($x.rank <= 1) {
8970 return $x.clone();
8971 }
8972 const inputs = { x: $x };
8973 const attrs = { perm };
8974 if ($x.dtype === 'complex64') {
8975 return tidy(() => {
8976 let $real = real($x);
8977 let $imag = imag($x);
8978 $real = ENGINE.runKernel(Transpose, { x: $real }, attrs);
8979 $imag = ENGINE.runKernel(Transpose, { x: $imag }, attrs);
8980 if (conjugate) {
8981 $imag = neg($imag);
8982 }
8983 return complex($real, $imag);
8984 });
8985 }
8986 return ENGINE.runKernel(Transpose, inputs, attrs);
8987 }
8988 const transpose = op({ transpose_ });
8989
8990 /**
8991 * @license
8992 * Copyright 2018 Google LLC. All Rights Reserved.
8993 * Licensed under the Apache License, Version 2.0 (the "License");
8994 * you may not use this file except in compliance with the License.
8995 * You may obtain a copy of the License at
8996 *
8997 * http://www.apache.org/licenses/LICENSE-2.0
8998 *
8999 * Unless required by applicable law or agreed to in writing, software
9000 * distributed under the License is distributed on an "AS IS" BASIS,
9001 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9002 * See the License for the specific language governing permissions and
9003 * limitations under the License.
9004 * =============================================================================
9005 */
9006 /**
9007 * Computes the confusion matrix from true labels and predicted labels.
9008 *
9009 * ```js
9010 * const labels = tf.tensor1d([0, 1, 2, 1, 0], 'int32');
9011 * const predictions = tf.tensor1d([0, 2, 2, 1, 0], 'int32');
9012 * const numClasses = 3;
9013 * const out = tf.math.confusionMatrix(labels, predictions, numClasses);
9014 * out.print();
9015 * // Expected output matrix:
9016 * // [[2, 0, 0],
9017 * // [0, 1, 1],
9018 * // [0, 0, 1]]
9019 * ```
9020 *
9021 * @param labels The target labels, assumed to be 0-based integers
9022 * for the classes. The shape is `[numExamples]`, where
9023 * `numExamples` is the number of examples included.
9024 * @param predictions The predicted classes, assumed to be
9025 * 0-based integers for the classes. Must have the same shape as `labels`.
9026 * @param numClasses Number of all classes, as an integer.
9027 * Its value must be larger than the largest element in `labels` and
9028 * `predictions`.
9029 * @returns The confusion matrix as a int32-type 2D tensor. The value at
9030 * row `r` and column `c` is the number of times examples of actual class
9031 * `r` were predicted as class `c`.
9032 *
9033 * @doc {heading: 'Operations', subheading: 'Evaluation'}
9034 */
9035 function confusionMatrix_(labels, predictions, numClasses) {
9036 const $labels = convertToTensor(labels, 'labels', 'confusionMatrix');
9037 const $predictions = convertToTensor(predictions, 'predictions', 'confusionMatrix');
9038 assert(numClasses == null || numClasses > 0 && Number.isInteger(numClasses), () => `If provided, numClasses must be a positive integer, ` +
9039 `but got ${numClasses}`);
9040 assert($labels.rank === 1, () => `Expected the rank of labels to be 1, but got ${$labels.rank}`);
9041 assert($predictions.rank === 1, () => `Expected the rank of predictions to be 1, ` +
9042 `but got ${$predictions.rank}`);
9043 assert($labels.shape[0] === $predictions.shape[0], () => `Mismatch in the number of examples: ` +
9044 `${$labels.shape[0]} vs. ${$predictions.shape[0]}. ` +
9045 `Labels and predictions should have the same number of elements.`);
9046 assert(numClasses > 0 && Number.isInteger(numClasses), () => `numClasses is required to be a positive integer, but got ` +
9047 `${numClasses}`);
9048 // TODO(cais): In the future, if oneHot supports tensors inputs for
9049 // `numClasses`, `confusionMatrix` can make `numClasses` optional.
9050 const oneHotLabels = oneHot(cast($labels, 'int32'), numClasses);
9051 const oneHotPredictions = oneHot(cast($predictions, 'int32'), numClasses);
9052 const oneHotLabelsT = transpose(oneHotLabels);
9053 const product = matMul(oneHotLabelsT, oneHotPredictions);
9054 return cast(product, 'int32');
9055 }
9056 const confusionMatrix = op({ confusionMatrix_ });
9057
9058 /**
9059 * @license
9060 * Copyright 2018 Google LLC. All Rights Reserved.
9061 * Licensed under the Apache License, Version 2.0 (the "License");
9062 * you may not use this file except in compliance with the License.
9063 * You may obtain a copy of the License at
9064 *
9065 * http://www.apache.org/licenses/LICENSE-2.0
9066 *
9067 * Unless required by applicable law or agreed to in writing, software
9068 * distributed under the License is distributed on an "AS IS" BASIS,
9069 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9070 * See the License for the specific language governing permissions and
9071 * limitations under the License.
9072 * =============================================================================
9073 */
9074
    // Namespace object exposed to users as `tf.math`.
    var math = /*#__PURE__*/Object.freeze({
        __proto__: null,
        confusionMatrix: confusionMatrix
    });
9079
9080 /**
9081 * @license
9082 * Copyright 2017 Google LLC. All Rights Reserved.
9083 * Licensed under the Apache License, Version 2.0 (the "License");
9084 * you may not use this file except in compliance with the License.
9085 * You may obtain a copy of the License at
9086 *
9087 * http://www.apache.org/licenses/LICENSE-2.0
9088 *
9089 * Unless required by applicable law or agreed to in writing, software
9090 * distributed under the License is distributed on an "AS IS" BASIS,
9091 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9092 * See the License for the specific language governing permissions and
9093 * limitations under the License.
9094 * =============================================================================
9095 */
9096 /**
9097 * Returns the dimensions in the input shape that are broadcasted to
9098 * produce the provided output shape.
9099 *
9100 * The returned dimensions are 0-indexed and sorted. An example:
9101 * inShape = [4, 1, 3]
9102 * outShape = [5, 4, 3, 3]
9103 * result = [1]. Dimension 1 (2nd dimension of input) gets broadcasted 1 => 3.
9104 */
9105 function getBroadcastDims(inShape, outShape) {
9106 const inRank = inShape.length;
9107 const dims = [];
9108 for (let i = 0; i < inRank; i++) {
9109 const dim = inRank - 1 - i;
9110 const a = inShape[dim] || 1;
9111 const b = outShape[outShape.length - 1 - i] || 1;
9112 if (b > 1 && a === 1) {
9113 dims.unshift(dim);
9114 }
9115 }
9116 return dims;
9117 }
9118 /**
9119 * Returns the axes in the output space that should be reduced to produce
9120 * the input space.
9121 */
9122 function getReductionAxes(inShape, outShape) {
9123 const result = [];
9124 for (let i = 0; i < outShape.length; i++) {
9125 const inDim = inShape[inShape.length - i - 1];
9126 const outAxis = outShape.length - i - 1;
9127 const outDim = outShape[outAxis];
9128 if (inDim == null || (inDim === 1 && outDim > 1)) {
9129 result.unshift(outAxis);
9130 }
9131 }
9132 return result;
9133 }
9134 function assertAndGetBroadcastShape(shapeA, shapeB) {
9135 const result = [];
9136 const l = Math.max(shapeA.length, shapeB.length);
9137 for (let i = 0; i < l; i++) {
9138 let a = shapeA[shapeA.length - i - 1];
9139 if (a == null) {
9140 a = 1;
9141 }
9142 let b = shapeB[shapeB.length - i - 1];
9143 if (b == null) {
9144 b = 1;
9145 }
9146 if (a === 1) {
9147 result.unshift(b);
9148 }
9149 else if (b === 1) {
9150 result.unshift(a);
9151 }
9152 else if (a !== b) {
9153 const errMsg = `Operands could not be broadcast together with shapes ` +
9154 `${shapeA} and ${shapeB}.`;
9155 throw Error(errMsg);
9156 }
9157 else {
9158 result.unshift(a);
9159 }
9160 }
9161 return result;
9162 }
9163
    // Namespace object exposed as `tf.broadcast_util`.
    var broadcast_util = /*#__PURE__*/Object.freeze({
        __proto__: null,
        getBroadcastDims: getBroadcastDims,
        getReductionAxes: getReductionAxes,
        assertAndGetBroadcastShape: assertAndGetBroadcastShape
    });
9170
9171 /**
9172 * @license
9173 * Copyright 2018 Google LLC. All Rights Reserved.
9174 * Licensed under the Apache License, Version 2.0 (the "License");
9175 * you may not use this file except in compliance with the License.
9176 * You may obtain a copy of the License at
9177 *
9178 * http://www.apache.org/licenses/LICENSE-2.0
9179 *
9180 * Unless required by applicable law or agreed to in writing, software
9181 * distributed under the License is distributed on an "AS IS" BASIS,
9182 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9183 * See the License for the specific language governing permissions and
9184 * limitations under the License.
9185 * =============================================================================
9186 */
9187 /**
9188 * Creates rank-3 `tf.Tensor` with the provided values, shape and dtype.
9189 *
9190 * The same functionality can be achieved with `tf.tensor`, but in general
9191 * we recommend using `tf.tensor3d` as it makes the code more readable.
9192 *
9193 * ```js
9194 * // Pass a nested array.
9195 * tf.tensor3d([[[1], [2]], [[3], [4]]]).print();
9196 * ```
9197 * ```js
9198 * // Pass a flat array and specify a shape.
9199 * tf.tensor3d([1, 2, 3, 4], [2, 2, 1]).print();
9200 * ```
9201 *
9202 * @param values The values of the tensor. Can be nested array of numbers,
9203 * or a flat array, or a `TypedArray`.
9204 * @param shape The shape of the tensor. If not provided, it is inferred from
9205 * `values`.
9206 * @param dtype The data type.
9207 *
9208 * @doc {heading: 'Tensors', subheading: 'Creation'}
9209 */
9210 function tensor3d(values, shape, dtype) {
9211 assertNonNull(values);
9212 if (shape != null && shape.length !== 3) {
9213 throw new Error('tensor3d() requires shape to have three numbers');
9214 }
9215 const inferredShape = inferShape(values, dtype);
9216 if (inferredShape.length !== 3 && inferredShape.length !== 1) {
9217 throw new Error('tensor3d() requires values to be number[][][] or flat/TypedArray');
9218 }
9219 if (inferredShape.length === 1 && shape == null) {
9220 throw new Error('tensor3d() requires shape to be provided when `values` ' +
9221 'are a flat array');
9222 }
9223 return makeTensor(values, shape, inferredShape, dtype);
9224 }
9225
9226 /**
9227 * @license
9228 * Copyright 2019 Google LLC. All Rights Reserved.
9229 * Licensed under the Apache License, Version 2.0 (the "License");
9230 * you may not use this file except in compliance with the License.
9231 * You may obtain a copy of the License at
9232 *
9233 * http://www.apache.org/licenses/LICENSE-2.0
9234 *
9235 * Unless required by applicable law or agreed to in writing, software
9236 * distributed under the License is distributed on an "AS IS" BASIS,
9237 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9238 * See the License for the specific language governing permissions and
9239 * limitations under the License.
9240 * =============================================================================
9241 */
    // Lazily created 2D canvas context shared by fromPixels_ to rasterize
    // image/video/ImageBitmap sources into raw pixel data.
    let fromPixels2DContext;
9243 /**
9244 * Creates a `tf.Tensor` from an image.
9245 *
9246 * ```js
9247 * const image = new ImageData(1, 1);
9248 * image.data[0] = 100;
9249 * image.data[1] = 150;
9250 * image.data[2] = 200;
9251 * image.data[3] = 255;
9252 *
9253 * tf.browser.fromPixels(image).print();
9254 * ```
9255 *
9256 * @param pixels The input image to construct the tensor from. The
9257 * supported image types are all 4-channel. You can also pass in an image
9258 * object with following attributes:
9259 * `{data: Uint8Array; width: number; height: number}`
9260 * @param numChannels The number of channels of the output tensor. A
9261 * numChannels value less than 4 allows you to ignore channels. Defaults to
9262 * 3 (ignores alpha channel of input image).
9263 *
9264 * @returns A Tensor3D with the shape `[height, width, numChannels]`.
9265 *
9266 * Note: fromPixels can be lossy in some cases, same image may result in
9267 * slightly different tensor values, if rendered by different rendering
9268 * engines. This means that results from different browsers, or even same
9269 * browser with CPU and GPU rendering engines can be different. See discussion
9270 * in details:
9271 * https://github.com/tensorflow/tfjs/issues/5482
9272 *
9273 * @doc {heading: 'Browser', namespace: 'browser', ignoreCI: true}
9274 */
9275 function fromPixels_(pixels, numChannels = 3) {
9276 // Sanity checks.
9277 if (numChannels > 4) {
9278 throw new Error('Cannot construct Tensor with more than 4 channels from pixels.');
9279 }
9280 if (pixels == null) {
9281 throw new Error('pixels passed to tf.browser.fromPixels() can not be null');
9282 }
9283 let isPixelData = false;
9284 let isImageData = false;
9285 let isVideo = false;
9286 let isImage = false;
9287 let isCanvasLike = false;
9288 let isImageBitmap = false;
9289 if (pixels.data instanceof Uint8Array) {
9290 isPixelData = true;
9291 }
9292 else if (typeof (ImageData) !== 'undefined' && pixels instanceof ImageData) {
9293 isImageData = true;
9294 }
9295 else if (typeof (HTMLVideoElement) !== 'undefined' &&
9296 pixels instanceof HTMLVideoElement) {
9297 isVideo = true;
9298 }
9299 else if (typeof (HTMLImageElement) !== 'undefined' &&
9300 pixels instanceof HTMLImageElement) {
9301 isImage = true;
9302 // tslint:disable-next-line: no-any
9303 }
9304 else if (pixels.getContext != null) {
9305 isCanvasLike = true;
9306 }
9307 else if (typeof (ImageBitmap) !== 'undefined' && pixels instanceof ImageBitmap) {
9308 isImageBitmap = true;
9309 }
9310 else {
9311 throw new Error('pixels passed to tf.browser.fromPixels() must be either an ' +
9312 `HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData ` +
9313 `in browser, or OffscreenCanvas, ImageData in webworker` +
9314 ` or {data: Uint32Array, width: number, height: number}, ` +
9315 `but was ${pixels.constructor.name}`);
9316 }
9317 if (isVideo) {
9318 const HAVE_CURRENT_DATA_READY_STATE = 2;
9319 if (isVideo &&
9320 pixels.readyState <
9321 HAVE_CURRENT_DATA_READY_STATE) {
9322 throw new Error('The video element has not loaded data yet. Please wait for ' +
9323 '`loadeddata` event on the <video> element.');
9324 }
9325 }
9326 // If the current backend has 'FromPixels' registered, it has a more
9327 // efficient way of handling pixel uploads, so we call that.
9328 const kernel = getKernel(FromPixels, ENGINE.backendName);
9329 if (kernel != null) {
9330 const inputs = { pixels };
9331 const attrs = { numChannels };
9332 return ENGINE.runKernel(FromPixels, inputs, attrs);
9333 }
9334 const [width, height] = isVideo ?
9335 [
9336 pixels.videoWidth,
9337 pixels.videoHeight
9338 ] :
9339 [pixels.width, pixels.height];
9340 let vals;
9341 if (isCanvasLike) {
9342 vals =
9343 // tslint:disable-next-line:no-any
9344 pixels.getContext('2d').getImageData(0, 0, width, height).data;
9345 }
9346 else if (isImageData || isPixelData) {
9347 vals = pixels.data;
9348 }
9349 else if (isImage || isVideo || isImageBitmap) {
9350 if (fromPixels2DContext == null) {
9351 if (typeof document === 'undefined') {
9352 if (typeof OffscreenCanvas !== 'undefined' &&
9353 typeof OffscreenCanvasRenderingContext2D !== 'undefined') {
9354 // @ts-ignore
9355 fromPixels2DContext = new OffscreenCanvas(1, 1).getContext('2d');
9356 }
9357 else {
9358 throw new Error('Cannot parse input in current context. ' +
9359 'Reason: OffscreenCanvas Context2D rendering is not supported.');
9360 }
9361 }
9362 else {
9363 fromPixels2DContext = document.createElement('canvas').getContext('2d');
9364 }
9365 }
9366 fromPixels2DContext.canvas.width = width;
9367 fromPixels2DContext.canvas.height = height;
9368 fromPixels2DContext.drawImage(pixels, 0, 0, width, height);
9369 vals = fromPixels2DContext.getImageData(0, 0, width, height).data;
9370 }
9371 let values;
9372 if (numChannels === 4) {
9373 values = new Int32Array(vals);
9374 }
9375 else {
9376 const numPixels = width * height;
9377 values = new Int32Array(numPixels * numChannels);
9378 for (let i = 0; i < numPixels; i++) {
9379 for (let channel = 0; channel < numChannels; ++channel) {
9380 values[i * numChannels + channel] = vals[i * 4 + channel];
9381 }
9382 }
9383 }
9384 const outShape = [height, width, numChannels];
9385 return tensor3d(values, outShape, 'int32');
9386 }
9387 // Helper functions for |fromPixelsAsync| to check whether the input can
9388 // be wrapped into imageBitmap.
9389 function isPixelData(pixels) {
9390 return (pixels != null) && (pixels.data instanceof Uint8Array);
9391 }
9392 function isImageBitmapFullySupported() {
9393 return typeof window !== 'undefined' &&
9394 typeof (ImageBitmap) !== 'undefined' &&
9395 window.hasOwnProperty('createImageBitmap');
9396 }
9397 function isNonEmptyPixels(pixels) {
9398 return pixels != null && pixels.width !== 0 && pixels.height !== 0;
9399 }
9400 function canWrapPixelsToImageBitmap(pixels) {
9401 return isImageBitmapFullySupported() && !(pixels instanceof ImageBitmap) &&
9402 isNonEmptyPixels(pixels) && !isPixelData(pixels);
9403 }
9404 /**
9405 * Creates a `tf.Tensor` from an image in async way.
9406 *
9407 * ```js
9408 * const image = new ImageData(1, 1);
9409 * image.data[0] = 100;
9410 * image.data[1] = 150;
9411 * image.data[2] = 200;
9412 * image.data[3] = 255;
9413 *
9414 * (await tf.browser.fromPixelsAsync(image)).print();
9415 * ```
9416 * This API is the async version of fromPixels. The API will first
9417 * check |WRAP_TO_IMAGEBITMAP| flag, and try to wrap the input to
9418 * imageBitmap if the flag is set to true.
9419 *
9420 * @param pixels The input image to construct the tensor from. The
9421 * supported image types are all 4-channel. You can also pass in an image
9422 * object with following attributes:
9423 * `{data: Uint8Array; width: number; height: number}`
9424 * @param numChannels The number of channels of the output tensor. A
9425 * numChannels value less than 4 allows you to ignore channels. Defaults to
9426 * 3 (ignores alpha channel of input image).
9427 *
9428 * @doc {heading: 'Browser', namespace: 'browser', ignoreCI: true}
9429 */
    async function fromPixelsAsync(pixels, numChannels = 3) {
        let inputs = null;
        // Check whether the backend needs to wrap |pixels| to imageBitmap and
        // whether |pixels| can be wrapped to imageBitmap.
        if (env().getBool('WRAP_TO_IMAGEBITMAP') &&
            canWrapPixelsToImageBitmap(pixels)) {
            // Force the imageBitmap creation to not do any premultiply alpha
            // ops.
            let imageBitmap;
            try {
                // wrap in try-catch block, because createImageBitmap may not work
                // properly in some browsers, e.g.
                // https://bugzilla.mozilla.org/show_bug.cgi?id=1335594
                // tslint:disable-next-line: no-any
                imageBitmap = await createImageBitmap(pixels, { premultiplyAlpha: 'none' });
            }
            catch (e) {
                // Best-effort: fall back to passing the original input through.
                imageBitmap = null;
            }
            // createImageBitmap will clip the source size.
            // In some cases, the input will have larger size than its content.
            // E.g. new Image(10, 10) but with 1 x 1 content. Using
            // createImageBitmap will clip the size from 10 x 10 to 1 x 1, which
            // is not correct. We should avoid wrapping such resource to
            // imageBitmap.
            if (imageBitmap != null && imageBitmap.width === pixels.width &&
                imageBitmap.height === pixels.height) {
                inputs = imageBitmap;
            }
            else {
                inputs = pixels;
            }
        }
        else {
            inputs = pixels;
        }
        // Delegate the actual tensor construction to the sync implementation.
        return fromPixels_(inputs, numChannels);
    }
9468 /**
9469 * Draws a `tf.Tensor` of pixel values to a byte array or optionally a
9470 * canvas.
9471 *
9472 * When the dtype of the input is 'float32', we assume values in the range
9473 * [0-1]. Otherwise, when input is 'int32', we assume values in the range
9474 * [0-255].
9475 *
9476 * Returns a promise that resolves when the canvas has been drawn to.
9477 *
9478 * @param img A rank-2 tensor with shape `[height, width]`, or a rank-3 tensor
9479 * of shape `[height, width, numChannels]`. If rank-2, draws grayscale. If
9480 * rank-3, must have depth of 1, 3 or 4. When depth of 1, draws
9481 * grayscale. When depth of 3, we draw with the first three components of
9482 * the depth dimension corresponding to r, g, b and alpha = 1. When depth of
9483 * 4, all four components of the depth dimension correspond to r, g, b, a.
9484 * @param canvas The canvas to draw to.
9485 *
9486 * @doc {heading: 'Browser', namespace: 'browser'}
9487 */
    async function toPixels(img, canvas) {
        let $img = convertToTensor(img, 'img', 'toPixels');
        if (!(img instanceof Tensor)) {
            // Assume int32 if user passed a native array.
            const originalImgTensor = $img;
            $img = cast(originalImgTensor, 'int32');
            originalImgTensor.dispose();
        }
        if ($img.rank !== 2 && $img.rank !== 3) {
            throw new Error(`toPixels only supports rank 2 or 3 tensors, got rank ${$img.rank}.`);
        }
        const [height, width] = $img.shape.slice(0, 2);
        // Rank-2 tensors are treated as single-channel (grayscale) images.
        const depth = $img.rank === 2 ? 1 : $img.shape[2];
        if (depth > 4 || depth === 2) {
            throw new Error(`toPixels only supports depth of size ` +
                `1, 3 or 4 but got ${depth}`);
        }
        if ($img.dtype !== 'float32' && $img.dtype !== 'int32') {
            throw new Error(`Unsupported type for toPixels: ${$img.dtype}.` +
                ` Please use float32 or int32 tensors.`);
        }
        const data = await $img.data();
        // float32 values are expected in [0, 1]; scale them to bytes [0, 255].
        const multiplier = $img.dtype === 'float32' ? 255 : 1;
        const bytes = new Uint8ClampedArray(width * height * 4);
        for (let i = 0; i < height * width; ++i) {
            // Default to opaque black; alpha stays 255 unless depth === 4.
            const rgba = [0, 0, 0, 255];
            for (let d = 0; d < depth; d++) {
                const value = data[i * depth + d];
                if ($img.dtype === 'float32') {
                    if (value < 0 || value > 1) {
                        throw new Error(`Tensor values for a float32 Tensor must be in the ` +
                            `range [0 - 1] but encountered ${value}.`);
                    }
                }
                else if ($img.dtype === 'int32') {
                    if (value < 0 || value > 255) {
                        throw new Error(`Tensor values for a int32 Tensor must be in the ` +
                            `range [0 - 255] but encountered ${value}.`);
                    }
                }
                if (depth === 1) {
                    // Grayscale: replicate the single channel into r, g and b.
                    rgba[0] = value * multiplier;
                    rgba[1] = value * multiplier;
                    rgba[2] = value * multiplier;
                }
                else {
                    rgba[d] = value * multiplier;
                }
            }
            const j = i * 4;
            bytes[j + 0] = Math.round(rgba[0]);
            bytes[j + 1] = Math.round(rgba[1]);
            bytes[j + 2] = Math.round(rgba[2]);
            bytes[j + 3] = Math.round(rgba[3]);
        }
        // Optionally paint the byte buffer onto the provided canvas.
        if (canvas != null) {
            canvas.width = width;
            canvas.height = height;
            const ctx = canvas.getContext('2d');
            const imageData = new ImageData(bytes, width, height);
            ctx.putImageData(imageData, 0, 0);
        }
        // Dispose the intermediate tensor if one was created by the cast above.
        if ($img !== img) {
            $img.dispose();
        }
        return bytes;
    }
    const fromPixels = op({ fromPixels_ });

    // Namespace object exposed to users as `tf.browser`.
    var browser = /*#__PURE__*/Object.freeze({
        __proto__: null,
        fromPixelsAsync: fromPixelsAsync,
        toPixels: toPixels,
        fromPixels: fromPixels
    });
9563
9564 /**
9565 * Validate gather nd inputs.
9566 *
9567 * @param tensor The tensor contains the source values.
9568 * @param indices The tensor contains the indices to slice the source.
9569 *
9570 * @returns [resultShape, numUpdates, sliceSize, strides]
9571 */
9572 function prepareAndValidate(tensor, indices) {
9573 const tensorRank = tensor.shape.length;
9574 const indicesRank = indices.shape.length;
9575 if (tensorRank < 1) {
9576 throw new Error('tf.gatherND() expects the input to be rank 1 or higher,' +
9577 ` but the rank was ${tensorRank}.`);
9578 }
9579 if (indicesRank < 1) {
9580 throw new Error('tf.gatherND() expects the indices to be rank 1 or higher,' +
9581 ` but the rank was ${indicesRank}.`);
9582 }
9583 if (indices.dtype !== 'int32') {
9584 throw new Error('tf.gatherND() expects the indices to be int32 type,' +
9585 ` but the dtype was ${indices.dtype}.`);
9586 }
9587 if (indices.shape[indicesRank - 1] > tensorRank) {
9588 throw new Error('index innermost dimension length must be <= tensor rank; saw: ' +
9589 `${indices.shape[indicesRank - 1]} vs. ${tensorRank}`);
9590 }
9591 if (sizeFromShape(tensor.shape) === 0) {
9592 throw new Error('Requested more than 0 entries, but input is empty.' +
9593 ` Input shape: ${tensor.shape}.`);
9594 }
9595 const indicesShape = indices.shape;
9596 const sliceRank = indicesShape[indicesShape.length - 1];
9597 // The result shape is
9598 // indices.shape[:-1] + params.shape[indices.shape[-1]:]
9599 let nResult = 1;
9600 for (let i = 0; i < indicesShape.length - 1; ++i) {
9601 nResult *= indicesShape[i];
9602 }
9603 const inputShape = tensor.shape;
9604 const resultShape = indicesShape.slice();
9605 resultShape.pop();
9606 let sliceSize = 1;
9607 for (let i = sliceRank; i < tensorRank; ++i) {
9608 sliceSize *= inputShape[i];
9609 resultShape.push(inputShape[i]);
9610 }
9611 const strides = [...computeStrides(tensor.shape).map(stride => stride / sliceSize),
9612 1].slice(0, sliceRank);
9613 return [resultShape, nResult, sliceSize, strides];
9614 }
9615
    // Namespace object mirroring the original `gather_nd_util` module layout,
    // re-exported frozen so consumers cannot mutate it.
    var gather_nd_util = /*#__PURE__*/Object.freeze({
        __proto__: null,
        prepareAndValidate: prepareAndValidate
    });
9620
9621 /**
9622 * Check whether updates.shape = indices.shape[:batchDim] +
9623 * shape[sliceDim:]
9624 *
9625 * @param x The input tensor.
9626 */
9627 function validateUpdateShape(shape, indices, updates) {
9628 const sliceDim = (indices.rank > 1) ? indices.shape[indices.rank - 1] : 1;
9629 const batchDim = (indices.rank > 1) ? indices.rank - 1 : 1;
9630 const shapeError = 'Must have updates.shape = indices.shape[:batchDim] + ' +
9631 `shape[sliceDim:], got updates.shape: ${updates.shape}` +
9632 `, indices.shape: ${indices.shape}, shape: ${shape}` +
9633 `, sliceDim: ${sliceDim}, and batchDim: ${batchDim}.`;
9634 if (updates.rank < batchDim) {
9635 throw new Error(shapeError + ` update.rank < ${batchDim}. `);
9636 }
9637 if (shape.length < sliceDim + (updates.rank - batchDim)) {
9638 throw new Error(shapeError +
9639 ` Output shape length < ${sliceDim + (updates.rank - batchDim)}`);
9640 }
9641 if (updates.rank !== batchDim + shape.length - sliceDim) {
9642 throw new Error(shapeError + ` update.rank != ${batchDim + shape.length - sliceDim}`);
9643 }
9644 for (let d = 0; d < batchDim; ++d) {
9645 if (updates.shape[d] !== indices.shape[d]) {
9646 throw new Error(shapeError +
9647 ` updates.shape[${d}] (${updates.shape[d]}) != indices.shape[${d}] (${indices.shape[d]}).`);
9648 }
9649 }
9650 for (let d = 0; d < updates.rank - batchDim; ++d) {
9651 if (updates.shape[d + batchDim] !== shape[d + sliceDim]) {
9652 throw new Error(shapeError +
9653 ` updates.shape[${d + batchDim}] (${updates.shape[d + batchDim]}) != shape[${d + batchDim}] (${shape[d + batchDim]})`);
9654 }
9655 }
9656 }
9657 /**
9658 * Validate scatter nd inputs.
9659 *
9660 * @param update The tensor contains the update values.
9661 * @param indices The tensor contains the indices for the update values.
9662 * @param shape The shape of the output tensor.
9663 */
9664 function validateInput(updates, indices, shape) {
9665 if (indices.rank < 1) {
9666 throw new Error('tf.scatterND() expects the indices to be rank 1 or higher,' +
9667 ` but the rank was ${indices.rank}.`);
9668 }
9669 if (updates.rank < 1) {
9670 throw new Error('tf.scatterND() expects the updates to be rank 1 or higher,' +
9671 ` but the rank was ${updates.rank}.`);
9672 }
9673 if (indices.dtype !== 'int32') {
9674 throw new Error(`The dtype of 'indices' should be int32, but got dtype: ${indices.dtype}`);
9675 }
9676 if (shape.length < 1) {
9677 throw new Error(`Output rank must be greater or equal to 1, but got shape: ${shape}`);
9678 }
9679 if (shape.length === 0) {
9680 if (indices.size === 0) {
9681 throw new Error(`Indices specified for empty output. indices shape: ${indices.shape}`);
9682 }
9683 if (updates.size === 0) {
9684 throw new Error(`Updates specified for empty output. updates shape: ${updates.shape}`);
9685 }
9686 }
9687 validateUpdateShape(shape, indices, updates);
9688 }
9689 /**
9690 * Calculate the shape information for the output.
9691 *
9692 * @param update The tensor contains the update values.
9693 * @param indices The tensor contains the indices for the update values.
9694 * @param shape The shape of the output tensor.
9695 *
9696 * @returns ScatterShapeInfo
9697 */
9698 function calculateShapes(updates, indices, shape) {
9699 // Calculate the number of dimensions in indices
9700 const indicesRank = indices.shape.length;
9701 const sliceRank = (indicesRank > 1) ? indices.shape[indicesRank - 1] : 1;
9702 // Calculate the number of elements that make up each slice of our updated
9703 // tensor. This allows us to work with flattened tensors and copy over whole
9704 // slices at a time.
9705 const totalNd = shape.length;
9706 let sliceSize = 1;
9707 for (let i = sliceRank; i < totalNd; ++i) {
9708 sliceSize *= shape[i];
9709 }
9710 const safeSliceDim = (sliceRank < 1) ? 1 : sliceRank;
9711 const numUpdates = sizeFromShape(indices.shape) / safeSliceDim;
9712 const strides = [...computeStrides(shape.slice(0, sliceRank)), 1];
9713 const outputSize = sizeFromShape(shape);
9714 return { sliceRank, numUpdates, sliceSize, strides, outputSize };
9715 }
9716
    // Namespace object mirroring the original `scatter_nd_util` module layout,
    // re-exported frozen so consumers cannot mutate it.
    var scatter_nd_util = /*#__PURE__*/Object.freeze({
        __proto__: null,
        validateUpdateShape: validateUpdateShape,
        validateInput: validateInput,
        calculateShapes: calculateShapes
    });
9723
9724 /**
9725 * @license
9726 * Copyright 2021 Google LLC. All Rights Reserved.
9727 * Licensed under the Apache License, Version 2.0 (the "License");
9728 * you may not use this file except in compliance with the License.
9729 * You may obtain a copy of the License at
9730 *
9731 * http://www.apache.org/licenses/LICENSE-2.0
9732 *
9733 * Unless required by applicable law or agreed to in writing, software
9734 * distributed under the License is distributed on an "AS IS" BASIS,
9735 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9736 * See the License for the specific language governing permissions and
9737 * limitations under the License.
9738 * =============================================================================
9739 */
    // Sentinel values stored in denseSpec.finalShapeGatherIndices (see
    // buildDenseSpec/sliceInfo below) to mark dims created by newAxisMask
    // (NEW_AXIS) or removed by shrinkAxisMask (SHRINK_AXIS).
    const NEW_AXIS = -2;
    const SHRINK_AXIS = -1;
9742 function assertParamsValid(input, begin, size) {
9743 const inputRank = input.shape.length;
9744 assert(inputRank === begin.length, () => `Error in slice${inputRank}D: Length of begin ${begin} must ` +
9745 `match the rank of the array (${inputRank}).`);
9746 assert(inputRank === size.length, () => `Error in slice${inputRank}D: Length of size ${size} must ` +
9747 `match the rank of the array (${inputRank}).`);
9748 for (let i = 0; i < inputRank; ++i) {
9749 assert(begin[i] + size[i] <= input.shape[i], () => `Error in slice${inputRank}D: begin[${i}] + size[${i}] ` +
9750 `(${begin[i] + size[i]}) would overflow input.shape[${i}] (${input.shape[i]})`);
9751 }
9752 }
9753 /** Converts a binary mask to an array of axes. Used in stridedSlice(). */
9754 function maskToAxes(mask) {
9755 const axes = [];
9756 let axis = 0;
9757 while (mask > 0) {
9758 if (mask & 1) {
9759 axes.push(axis);
9760 }
9761 mask /= 2;
9762 axis++;
9763 }
9764 return axes;
9765 }
9766 /** Computes the output shape given the strided slice params. */
9767 function computeOutShape(begin, end, strides) {
9768 const size = [];
9769 for (let axis = 0; axis < begin.length; axis++) {
9770 size[axis] = Math.ceil((end[axis] - begin[axis]) / strides[axis]);
9771 }
9772 return size;
9773 }
9774 // Creates full selection at the elided dimensions. If the dimension matches
9775 // the ellipsis mask, override the current stride value. Otherwise, insert.
9776 function stridesWithElidedDims(strides, ellipsisInsertionIndex, numElidedAxes, inputShape) {
9777 const newStrides = [...strides];
9778 for (let i = newStrides.length; i < inputShape.length; i++) {
9779 newStrides.push(1);
9780 }
9781 for (let i = 0; i < numElidedAxes; i++) {
9782 if (i === 0) {
9783 newStrides[ellipsisInsertionIndex] = 1;
9784 }
9785 else {
9786 newStrides.splice(ellipsisInsertionIndex, 0 /* num elements to delete */, 1 /* element to add */);
9787 newStrides.pop();
9788 }
9789 }
9790 return newStrides;
9791 }
9792 function unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, normalizedAxis) {
9793 if (normalizedAxis <= ellipsisInsertionIndex) {
9794 return normalizedAxis;
9795 }
9796 return normalizedAxis - (numElidedAxes - 1);
9797 }
9798 function getElidedAxes(numElidedAxes, ellipsisInsertionIndex) {
9799 const elidedAxes = [];
9800 for (let i = 0; i < numElidedAxes; i++) {
9801 elidedAxes.push(ellipsisInsertionIndex + i);
9802 }
9803 return elidedAxes;
9804 }
9805 // Normalize the start, end and strides.
9806 function getNormalizedAxes(inputShape, ellipsisAxes, numInterpolatedAxes, begin, end, strides, beginMask, endMask, ellipsisMask) {
9807 const inputRank = inputShape.length;
9808 let normalizedBegin = new Array(inputRank), normalizedEnd = new Array(inputRank), normalizedStrides = new Array(inputRank);
9809 if (ellipsisAxes.length && numInterpolatedAxes > 0) {
9810 const fullIndex = ellipsisAxes[0];
9811 // The ellipsis applies to the masked index as well as any dimensions
9812 // that are interpolated.
9813 const numElidedAxes = numInterpolatedAxes + 1;
9814 normalizedBegin = startIndicesWithElidedDims(beginMask, fullIndex, numElidedAxes, begin, inputShape);
9815 normalizedEnd = stopIndicesWithElidedDims(endMask, fullIndex, numElidedAxes, end, inputShape);
9816 normalizedStrides =
9817 stridesWithElidedDims(strides, fullIndex, numElidedAxes, inputShape);
9818 }
9819 else {
9820 for (let axis = 0; axis < inputRank; axis++) {
9821 normalizedBegin[axis] = startForAxis(beginMask, begin, strides, inputShape, axis, ellipsisMask);
9822 normalizedEnd[axis] =
9823 stopForAxis(endMask, end, strides, inputShape, axis, ellipsisMask);
9824 normalizedStrides[axis] = stridesForAxis(strides, axis, ellipsisMask);
9825 }
9826 }
9827 return {
9828 begin: normalizedBegin,
9829 end: normalizedEnd,
9830 strides: normalizedStrides
9831 };
9832 }
9833 // Creates full selection at the elided dimensions. If the dimension matches
9834 // the ellipsis mask, override the current start value. Otherwise, insert.
9835 function startIndicesWithElidedDims(beginMask, ellipsisInsertionIndex, numElidedAxes, originalBegin, inputShape) {
9836 const newIndices = [...inputShape];
9837 const elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex);
9838 for (let axis = 0; axis < newIndices.length; axis++) {
9839 if (elidedAxes.indexOf(axis) > -1) {
9840 newIndices[axis] = 0;
9841 }
9842 else {
9843 const originalAxis = unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis);
9844 let originalValue = originalBegin[originalAxis];
9845 if (beginMask & 1 << originalAxis) {
9846 originalValue = 0;
9847 }
9848 newIndices[axis] = originalValue;
9849 }
9850 }
9851 return newIndices;
9852 }
9853 // Creates full selection at the elided dimensions. If the dimension matches
9854 // the ellipsis mask, override the current stop value. Otherwise, insert.
9855 function stopIndicesWithElidedDims(endMask, ellipsisInsertionIndex, numElidedAxes, originalEnd, inputShape) {
9856 const newIndices = [...inputShape];
9857 const elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex);
9858 for (let axis = 0; axis < newIndices.length; axis++) {
9859 if (elidedAxes.indexOf(axis) > -1) {
9860 newIndices[axis] = Number.MAX_SAFE_INTEGER;
9861 }
9862 else {
9863 const originalAxis = unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis);
9864 let originalValue = originalEnd[originalAxis];
9865 if (endMask & 1 << originalAxis) {
9866 originalValue = Number.MAX_SAFE_INTEGER;
9867 }
9868 newIndices[axis] = originalValue;
9869 }
9870 }
9871 for (let i = 0; i < newIndices.length; i++) {
9872 // Handle negative indices
9873 const axisSize = inputShape[i];
9874 if (newIndices[i] < 0) {
9875 newIndices[i] += axisSize;
9876 }
9877 newIndices[i] = clamp(0, newIndices[i], inputShape[i]);
9878 }
9879 return newIndices;
9880 }
9881 function stridesForAxis(strides, axis, ellipsisMask) {
9882 let stride = strides[axis];
9883 if (ellipsisMask & (1 << axis) || stride == null) {
9884 stride = 1;
9885 }
9886 return stride;
9887 }
9888 function startForAxis(beginMask, startIndices, strides, inputShape, axis, ellipsisMask) {
9889 // Begin with the specified index
9890 let start = startIndices[axis];
9891 const stride = strides[axis] || 1;
9892 // Check the axis bit from right of masked axes, or the begin index is not set
9893 // for the axis.
9894 if (beginMask & 1 << axis || ellipsisMask & 1 << axis || start == null) {
9895 if (stride > 0) {
9896 // Forward iteration - use the first element. These values will get
9897 // clamped below (Note: We could have set them to 0 and axis_size-1, but
9898 // use lowest() and max() to maintain symmetry with StopForAxis())
9899 start = Number.MIN_SAFE_INTEGER;
9900 }
9901 else {
9902 // Backward iteration - use the last element.
9903 start = Number.MAX_SAFE_INTEGER;
9904 }
9905 }
9906 // Handle negative indices
9907 const axisSize = inputShape[axis];
9908 if (start < 0) {
9909 start += axisSize;
9910 }
9911 // Clamping
9912 start = clamp(0, start, axisSize - 1);
9913 return start;
9914 }
9915 function stopForAxis(endMask, stopIndices, strides, inputShape, axis, ellipsisMask) {
9916 // Begin with the specified index
9917 let stop = stopIndices[axis];
9918 const stride = strides[axis] || 1;
9919 // Check the axis bit from right of masked axes, or if the stop index is not
9920 // set for this axis.
9921 if (endMask & (1 << axis) || ellipsisMask & (1 << axis) || stop == null) {
9922 if (stride > 0) {
9923 // Forward iteration - use the last element. These values will get
9924 // clamped below
9925 stop = Number.MAX_SAFE_INTEGER;
9926 }
9927 else {
9928 // Backward iteration - use the first element.
9929 stop = Number.MIN_SAFE_INTEGER;
9930 }
9931 }
9932 // Handle negative indices
9933 const axisSize = inputShape[axis];
9934 if (stop < 0) {
9935 stop += axisSize;
9936 }
9937 // Clamping
9938 // Because the end index points one past the last element, we need slightly
9939 // different clamping ranges depending on the direction.
9940 if (stride > 0) {
9941 // Forward iteration
9942 stop = clamp(0, stop, axisSize);
9943 }
9944 else {
9945 // Backward iteration
9946 stop = clamp(-1, stop, axisSize - 1);
9947 }
9948 return stop;
9949 }
9950 /**
9951 * Returns true if the slice occupies a continous set of elements in the
9952 * 'flat' space.
9953 */
9954 function isSliceContinous(shape, begin, size) {
9955 // Index of the first axis that has size > 1.
9956 let firstNonOneAxis = size.length;
9957 for (let i = 0; i < size.length; i++) {
9958 if (size[i] > 1) {
9959 firstNonOneAxis = i;
9960 break;
9961 }
9962 }
9963 for (let i = firstNonOneAxis + 1; i < size.length; i++) {
9964 if (begin[i] > 0 || size[i] !== shape[i]) {
9965 return false;
9966 }
9967 }
9968 return true;
9969 }
9970 function computeFlatOffset(begin, strides) {
9971 let flatOffset = begin.length > 0 ? begin[begin.length - 1] : 1;
9972 for (let i = 0; i < begin.length - 1; i++) {
9973 flatOffset += begin[i] * strides[i];
9974 }
9975 return flatOffset;
9976 }
9977 function parseSliceParams(x, begin, size) {
9978 // The following logic allows for more ergonomic calls.
9979 let begin_;
9980 const xRank = x.shape.length;
9981 if (typeof begin === 'number') {
9982 begin_ = [begin, ...new Array(xRank - 1).fill(0)];
9983 }
9984 else if (begin.length < xRank) {
9985 begin_ = begin.concat(new Array(xRank - begin.length).fill(0));
9986 }
9987 else {
9988 begin_ = begin.slice();
9989 }
9990 begin_.forEach(d => {
9991 assert(d !== -1, () => 'slice() does not support negative begin indexing.');
9992 });
9993 let size_;
9994 if (size == null) {
9995 size_ = new Array(xRank).fill(-1);
9996 }
9997 else if (typeof size === 'number') {
9998 size_ = [size, ...new Array(xRank - 1).fill(-1)];
9999 }
10000 else if (size.length < xRank) {
10001 size_ = size.concat(new Array(xRank - size.length).fill(-1));
10002 }
10003 else {
10004 size_ = size;
10005 }
10006 size_ = size_.map((d, i) => {
10007 if (d >= 0) {
10008 return d;
10009 }
10010 else {
10011 assert(d === -1, () => `Negative size values should be exactly -1 but got ` +
10012 `${d} for the slice() size at index ${i}.`);
10013 return x.shape[i] - begin_[i];
10014 }
10015 });
10016 return [begin_, size_];
10017 }
    // Convert the slicing specification from a sparse representation to a dense
    // representation. This means that all ellipses and newaxis are expanded out.
    /**
     * Expands a strided-slice spec (begin/end/strides arrays plus bit masks,
     * one bit per *sparse* user-written dimension) into a dense per-input-dim
     * form and derives the output shapes.
     *
     * @param xShape Shape of the tensor being sliced; -1 marks unknown dims.
     * @param begin Sparse begin indices.
     * @param end Sparse end indices.
     * @param strides Sparse strides, or null for all-ones strides.
     * @param beginMask Bit i set => begin for sparse dim i is implicit.
     * @param endMask Bit i set => end for sparse dim i is implicit.
     * @param ellipsisMask Bit i set => sparse dim i is an ellipsis; at most one
     *     bit may be set.
     * @param newAxisMask Bit i set => sparse dim i inserts a size-1 axis.
     * @param shrinkAxisMask Bit i set => sparse dim i indexes (removes) an axis.
     * @returns Dense begin/end/strides plus finalShape, finalShapeSparse and
     *     the isIdentity / sliceDim0 / isSimpleSlice optimization flags.
     */
    function sliceInfo(xShape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) {
        // Default to stride 1 on every sparse dim when none are given.
        let stridesNonNull;
        if (strides == null) {
            stridesNonNull = new Array(begin.length);
            stridesNonNull.fill(1);
        }
        else {
            stridesNonNull = strides;
        }
        // Only one non-zero bit is allowed in ellipsisMask, which means ellipsisMask
        // is a power of 2. Use bit compares to ensure ellipsisMask is 0 or a power
        // of 2. When i is a power of 2, i & (i - 1) is always 0.
        // Also ref:
        // https://stackoverflow.com/questions/600293/how-to-check-if-a-number-is-a-power-of-2
        if (ellipsisMask != null && (ellipsisMask & (ellipsisMask - 1)) !== 0) {
            throw new Error('Multiple ellipses in slice is not allowed.');
        }
        // Step 1: Account for ellipsis and new axis.
        // Check for ellipsis and count how many non-newaxis there are after.
        let ellipsisSeen = false;
        const sparseSpec = {
            dims: stridesNonNull.length,
            numAddAxisAfterEllipsis: 0,
            begin: begin.slice(),
            end: end.slice(),
            strides: stridesNonNull.slice(),
            beginMask,
            endMask,
            ellipsisMask,
            newAxisMask,
            shrinkAxisMask
        };
        for (let i = 0; i < sparseSpec.dims; i++) {
            if (ellipsisSeen && ((1 << i) & newAxisMask) !== 0) {
                sparseSpec.numAddAxisAfterEllipsis++;
            }
            if ((1 << i) & ellipsisMask) {
                ellipsisSeen = true;
            }
        }
        // If no ellipsis insert one at the end.
        if (!ellipsisSeen) {
            sparseSpec.ellipsisMask |= (1 << sparseSpec.dims);
            sparseSpec.dims++; // this effects loop iteration below
        }
        // Step 2: Make a sparse spec into a full index spec.
        //
        // The sparse spec does not correspond to the number of dimensions.
        // Make a dense spec that corresponds to the number of dimensions.
        //
        // For example suppose foo[...,3:] on foo.shape = [2, 2, 3] then we need to
        // produce the missing beginMask for the first two dimensions i.e. from
        // beginMaskSpec = 0, endMaskSpec = 2, we achieve beginMask = 6 (110),
        // endMask = 7 (111).
        const denseSpec = {
            dims: xShape.length,
            beginMask: 0,
            endMask: 0,
            beginValid: false,
            endValid: false
        };
        // Populates denseSpec.begin/end/strides/masks/finalShapeGatherIndices.
        buildDenseSpec(sparseSpec, denseSpec);
        // Step 3: Make implicit ranges (non-zero beginMasks and endMasks) explicit
        // and bounds check.
        let isIdentity = true;
        let sliceDim0 = true;
        let isSimpleSlice = true;
        const processingShape = [];
        const finalShape = [];
        for (let i = 0; i < xShape.length; ++i) {
            if (denseSpec.strides[i] === 0) {
                throw Error(`strides[${i}] must be non-zero`);
            }
            const shrinkI = !!(denseSpec.shrinkAxisMask & (1 << i));
            const dimI = xShape[i];
            // Unknown dimension: the interval cannot be bounds-checked here.
            if (dimI === -1) {
                processingShape.push(shrinkI ? 1 : -1);
                continue;
            }
            const masks = [denseSpec.beginMask & (1 << i), denseSpec.endMask & (1 << i)];
            const validRange = [
                denseSpec.strides[i] > 0 ? 0 : -1,
                denseSpec.strides[i] > 0 ? dimI : dimI - 1
            ];
            if (shrinkI && denseSpec.strides[i] <= 0) {
                throw Error('only stride 1 allowed on non-range indexing.');
            }
            isSimpleSlice = isSimpleSlice && (denseSpec.strides[i] === 1);
            const beginAndEndMasked = !!((denseSpec.beginMask & (1 << i)) && (denseSpec.endMask & (1 << i)));
            if (denseSpec.beginValid && denseSpec.endValid) {
                if (shrinkI) {
                    // If we are shrinking, the end index is now possibly incorrect. In
                    // particular foo[-1] produces sparseBegin = -1, sparseEnd = 0.
                    // and canonical puts these to n-1 and 0, which implies a degenerate
                    // interval. Fortunately, it is now safe to re-create end as begin + 1.
                    const xFwd = denseSpec.begin[i] < 0 ? dimI + denseSpec.begin[i] :
                        denseSpec.begin[i];
                    denseSpec.begin[i] = xFwd;
                    denseSpec.end[i] = denseSpec.begin[i] + 1;
                    if (xFwd < 0 || xFwd >= dimI) {
                        throw Error(`slice index ${denseSpec.begin[i]} of dimension ${i} out of bounds.`);
                    }
                }
                else {
                    denseSpec.begin[i] = canonical(denseSpec.begin[i], 0, denseSpec.strides[i], dimI, masks, validRange);
                    denseSpec.end[i] = canonical(denseSpec.end[i], 1, denseSpec.strides[i], dimI, masks, validRange);
                }
                // Update optimization values
                const takeAllInDimension = denseSpec.strides[i] === 1 &&
                    denseSpec.begin[i] === 0 && denseSpec.end[i] === dimI;
                isIdentity = isIdentity && takeAllInDimension;
                sliceDim0 = sliceDim0 &&
                    ((i === 0 && denseSpec.strides[i] === 1) || takeAllInDimension);
            }
            else {
                isIdentity =
                    isIdentity && ((denseSpec.strides[i] === 1) && beginAndEndMasked);
                sliceDim0 = sliceDim0 &&
                    ((i === 0 && denseSpec.strides[i] === 1) || beginAndEndMasked);
            }
            // Compute the processing shape (the intermediate Eigen will produce)
            let intervalLength;
            let knownInterval = false;
            if (denseSpec.beginValid && denseSpec.endValid) {
                intervalLength = denseSpec.end[i] - denseSpec.begin[i];
                knownInterval = true;
            }
            else if (shrinkI) {
                // The dimension is still known as 1 for the processingShape, but will be
                // discarded for the final shape.
                intervalLength = 1;
                knownInterval = true;
            }
            else if (beginAndEndMasked) {
                // Even if we don't have values for begin or end, we do know that this
                // dimension covers the whole interval. If we have shape information for
                // this dimension, that tells us the interval length.
                if (dimI >= 0) {
                    if (denseSpec.strides[i] < 0) {
                        intervalLength = -dimI;
                    }
                    else {
                        intervalLength = dimI;
                    }
                    knownInterval = true;
                }
            }
            if (knownInterval) {
                let sizeI;
                // Hold zero if the interval is degenerate, otherwise account for
                // remainder
                if (intervalLength === 0 ||
                    ((intervalLength < 0) !== (denseSpec.strides[i] < 0))) {
                    sizeI = 0;
                }
                else {
                    sizeI = Math.trunc(intervalLength / denseSpec.strides[i]) +
                        (intervalLength % denseSpec.strides[i] !== 0 ? 1 : 0);
                }
                processingShape.push(sizeI);
            }
            else {
                processingShape.push(-1);
            }
        }
        // Step 4: Compute the final shape
        //
        // newAxis will increase dimension by 1 (with a one-size dimension)
        // slices like foo[3, ...] will reduce dimension by 1.
        // This cannot be done earlier, because it depends on Step 3.
        for (let denseDim = 0; denseDim < denseSpec.finalShapeGatherIndices.length; ++denseDim) {
            const gatherIndex = denseSpec.finalShapeGatherIndices[denseDim];
            if (gatherIndex >= 0) {
                finalShape.push(processingShape[gatherIndex]);
            }
            else if (gatherIndex === NEW_AXIS) {
                finalShape.push(1);
            }
        }
        // NOTE(review): this filter indexes finalShapeGatherIndices with the
        // position inside finalShape; when SHRINK_AXIS entries are present the
        // two arrays may be offset relative to each other — confirm against
        // upstream before relying on finalShapeSparse with shrink axes.
        const finalShapeSparse = finalShape.filter((dim, i) => denseSpec.finalShapeGatherIndices[i] !== NEW_AXIS);
        return {
            finalShapeSparse,
            finalShape,
            isIdentity,
            sliceDim0,
            isSimpleSlice,
            begin: denseSpec.begin,
            end: denseSpec.end,
            strides: denseSpec.strides
        };
    }
    /**
     * Expands `sparse` (the user spec, one entry per sparse dim) into `dense`
     * (one entry per input dim), mutating `dense` in place.
     *
     * The single ellipsis bit expands into however many full-range input dims
     * are needed; newAxis entries consume no input dim and are recorded as
     * NEW_AXIS in finalShapeGatherIndices; shrink-axis entries are recorded as
     * SHRINK_AXIS so sliceInfo can drop them from the final shape.
     */
    function buildDenseSpec(sparse, dense) {
        dense.beginMask = 0;
        dense.endMask = 0;
        dense.shrinkAxisMask = 0;
        // Next dense (input) dimension to be filled.
        let fullIndex = 0;
        dense.beginValid = sparse.begin != null;
        dense.endValid = sparse.end != null;
        dense.begin = new Array(dense.dims);
        dense.end = new Array(dense.dims);
        dense.strides = new Array(dense.dims);
        dense.finalShapeGatherIndices = [];
        dense.finalShapeGatherIndicesSparse = [];
        dense.inputShapeGatherIndicesSparse = new Array(dense.dims);
        for (let i = 0; i < sparse.dims; i++) {
            if ((1 << i) & sparse.ellipsisMask) {
                // Only the bit that has ellipsis will fall in this condition.
                // Expand the ellipsis into the appropriate indices
                // Note: this only works because we guaranteed one ellipsis.
                const nextIndex = Math.min(dense.dims - (sparse.dims - i) + 1 + sparse.numAddAxisAfterEllipsis, dense.dims);
                for (; fullIndex < nextIndex; fullIndex++) {
                    // newAxis aren't real axis so you have to skip.
                    // Each elided dim becomes a full-range, stride-1, both-masks-set
                    // selection.
                    dense.begin[fullIndex] = 0;
                    dense.end[fullIndex] = 0;
                    dense.strides[fullIndex] = 1;
                    dense.beginMask |= (1 << fullIndex);
                    dense.endMask |= (1 << fullIndex);
                    dense.finalShapeGatherIndices.push(fullIndex);
                    dense.finalShapeGatherIndicesSparse.push(-1);
                    dense.inputShapeGatherIndicesSparse[fullIndex] = i;
                }
            }
            else if ((1 << i) & sparse.newAxisMask) {
                // Only the bit that has newAxis will fall in this condition.
                // A new axis adds a size-1 output dim but consumes no input dim.
                dense.finalShapeGatherIndices.push(NEW_AXIS);
                dense.finalShapeGatherIndicesSparse.push(-1);
            }
            else {
                if (fullIndex === dense.begin.length) {
                    throw Error(`Index out of range using input dim ${fullIndex}; input ` +
                        `has only ${dense.dims} dims, ${dense.begin.length}.`);
                }
                // Gather slicing spec into appropriate index.
                if (sparse.begin != null) {
                    dense.begin[fullIndex] = sparse.begin[i];
                }
                if (sparse.end != null) {
                    dense.end[fullIndex] = sparse.end[i];
                }
                dense.strides[fullIndex] = sparse.strides[i];
                if (sparse.beginMask & (1 << i)) {
                    dense.beginMask |= (1 << fullIndex);
                }
                if (sparse.endMask & (1 << i)) {
                    dense.endMask |= (1 << fullIndex);
                }
                // If shrink, record where to get the dimensionality from (i.e. newAxis)
                // creates a fake 1 size dimension. Also remember shrink axis (now in
                // dense form) so we can ignore dense.end below.
                if (sparse.shrinkAxisMask & (1 << i)) {
                    dense.finalShapeGatherIndices.push(SHRINK_AXIS);
                    dense.finalShapeGatherIndicesSparse.push(-1);
                    dense.shrinkAxisMask |= (1 << fullIndex);
                }
                else {
                    dense.finalShapeGatherIndices.push(fullIndex);
                    // Remember that where in the sparse shape the dense dim comes from.
                    dense.finalShapeGatherIndicesSparse.push(i);
                }
                dense.inputShapeGatherIndicesSparse[fullIndex] = i;
                fullIndex++;
            }
        }
    }
10284 function canonical(x, c, strideI, dimI, masks, validRange) {
10285 if (masks[c]) {
10286 return strideI > 0 ? validRange[c] : validRange[(c + 1) & 1];
10287 }
10288 else {
10289 const xFwd = x < 0 ? dimI + x : x; // make negative indices positive
10290 return xFwd < validRange[0] ? validRange[0] :
10291 xFwd > validRange[1] ? validRange[1] : xFwd;
10292 }
10293 }
10294
    // Namespace object mirroring the original `slice_util` module layout,
    // re-exported frozen so consumers cannot mutate it.
    var slice_util = /*#__PURE__*/Object.freeze({
        __proto__: null,
        assertParamsValid: assertParamsValid,
        maskToAxes: maskToAxes,
        computeOutShape: computeOutShape,
        stridesWithElidedDims: stridesWithElidedDims,
        getNormalizedAxes: getNormalizedAxes,
        startIndicesWithElidedDims: startIndicesWithElidedDims,
        stopIndicesWithElidedDims: stopIndicesWithElidedDims,
        stridesForAxis: stridesForAxis,
        startForAxis: startForAxis,
        stopForAxis: stopForAxis,
        isSliceContinous: isSliceContinous,
        computeFlatOffset: computeFlatOffset,
        parseSliceParams: parseSliceParams,
        sliceInfo: sliceInfo
    });
10312
10313 /**
10314 * @license
10315 * Copyright 2018 Google LLC. All Rights Reserved.
10316 * Licensed under the Apache License, Version 2.0 (the "License");
10317 * you may not use this file except in compliance with the License.
10318 * You may obtain a copy of the License at
10319 *
10320 * http://www.apache.org/licenses/LICENSE-2.0
10321 *
10322 * Unless required by applicable law or agreed to in writing, software
10323 * distributed under the License is distributed on an "AS IS" BASIS,
10324 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10325 * See the License for the specific language governing permissions and
10326 * limitations under the License.
10327 * =============================================================================
10328 */
10329 /**
10330 * Serializable defines the serialization contract.
10331 *
10332 * TFJS requires serializable classes to return their className when asked
10333 * to avoid issues with minification.
10334 */
10335 class Serializable {
10336 /**
10337 * Return the class name for this class to use in serialization contexts.
10338 *
10339 * Generally speaking this will be the same thing that constructor.name
10340 * would have returned. However, the class name needs to be robust
10341 * against minification for serialization/deserialization to work properly.
10342 *
10343 * There's also places such as initializers.VarianceScaling, where
10344 * implementation details between different languages led to different
10345 * class hierarchies and a non-leaf node is used for serialization purposes.
10346 */
10347 getClassName() {
10348 return this.constructor
10349 .className;
10350 }
10351 /**
10352 * Creates an instance of T from a ConfigDict.
10353 *
10354 * This works for most descendants of serializable. A few need to
10355 * provide special handling.
10356 * @param cls A Constructor for the class to instantiate.
10357 * @param config The Configuration for the object.
10358 */
10359 /** @nocollapse */
10360 static fromConfig(cls, config) {
10361 return new cls(config);
10362 }
10363 }
10364 /**
10365 * Maps string keys to class constructors.
10366 *
10367 * Used during (de)serialization from the cross-language JSON format, which
10368 * requires the class name in the serialization format matches the class
10369 * names as used in Python, should it exist.
10370 */
10371 class SerializationMap {
10372 constructor() {
10373 this.classNameMap = {};
10374 }
10375 /**
10376 * Returns the singleton instance of the map.
10377 */
10378 static getMap() {
10379 if (SerializationMap.instance == null) {
10380 SerializationMap.instance = new SerializationMap();
10381 }
10382 return SerializationMap.instance;
10383 }
10384 /**
10385 * Registers the class as serializable.
10386 */
10387 static register(cls) {
10388 SerializationMap.getMap().classNameMap[cls.className] =
10389 [cls, cls.fromConfig];
10390 }
10391 }
10392 /**
10393 * Register a class with the serialization map of TensorFlow.js.
10394 *
10395 * This is often used for registering custom Layers, so they can be
10396 * serialized and deserialized.
10397 *
10398 * Example:
10399 *
10400 * ```js
10401 * class MyCustomLayer extends tf.layers.Layer {
10402 * static className = 'MyCustomLayer';
10403 *
10404 * constructor(config) {
10405 * super(config);
10406 * }
10407 * }
10408 * tf.serialization.registerClass(MyCustomLayer);
10409 * ```
10410 *
10411 * @param cls The class to be registered. It must have a public static member
10412 * called `className` defined and the value must be a non-empty string.
10413 *
10414 * @doc {heading: 'Models', subheading: 'Serialization', ignoreCI: true}
10415 */
10416 function registerClass(cls) {
10417 assert(cls.className != null, () => `Class being registered does not have the static className ` +
10418 `property defined.`);
10419 assert(typeof cls.className === 'string', () => `className is required to be a string, but got type ` +
10420 typeof cls.className);
10421 assert(cls.className.length > 0, () => `Class being registered has an empty-string as its className, ` +
10422 `which is disallowed.`);
10423 SerializationMap.register(cls);
10424 }
10425
    // Namespace object mirroring the original `serialization` module layout,
    // re-exported frozen so consumers cannot mutate it.
    var serialization = /*#__PURE__*/Object.freeze({
        __proto__: null,
        Serializable: Serializable,
        SerializationMap: SerializationMap,
        registerClass: registerClass
    });
10432
10433 /**
10434 * @license
10435 * Copyright 2017 Google LLC. All Rights Reserved.
10436 * Licensed under the Apache License, Version 2.0 (the "License");
10437 * you may not use this file except in compliance with the License.
10438 * You may obtain a copy of the License at
10439 *
10440 * http://www.apache.org/licenses/LICENSE-2.0
10441 *
10442 * Unless required by applicable law or agreed to in writing, software
10443 * distributed under the License is distributed on an "AS IS" BASIS,
10444 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10445 * See the License for the specific language governing permissions and
10446 * limitations under the License.
10447 * =============================================================================
10448 */
    // Default tolerances for the expect* helpers below; which one applies is
    // decided per-backend by testEpsilon() based on float precision.
    const TEST_EPSILON_FLOAT32 = 1e-3;
    const TEST_EPSILON_FLOAT16 = 1e-1;
10451 function expectArraysClose(actual, expected, epsilon) {
10452 if (epsilon == null) {
10453 epsilon = testEpsilon();
10454 }
10455 return expectArraysPredicate(actual, expected, (a, b) => areClose(a, b, epsilon));
10456 }
10457 function testEpsilon() {
10458 return ENGINE.backend.floatPrecision() === 32 ? TEST_EPSILON_FLOAT32 :
10459 TEST_EPSILON_FLOAT16;
10460 }
    /**
     * Walks `actual` and `expected` and throws unless every pair of flattened
     * elements satisfies `predicate`.
     *
     * Structural checks performed before the element loop:
     *  - constructor-name equality, but only when both operands are typed
     *    arrays or neither is (mixed typed/untyped comparisons are allowed);
     *  - shape equality (via inferShape) when both are nested JS arrays;
     *  - flattened-length equality.
     */
    function expectArraysPredicate(actual, expected, predicate) {
        // Skip the class check for mixed typed/untyped pairs; re-enable it
        // when both sides are typed arrays (Float32Array vs Int32Array etc.
        // must not compare as equal types).
        let checkClassType = true;
        if (isTypedArray(actual) || isTypedArray(expected)) {
            checkClassType = false;
        }
        if (isTypedArray(actual) && isTypedArray(expected)) {
            checkClassType = true;
        }
        if (checkClassType) {
            const aType = actual.constructor.name;
            const bType = expected.constructor.name;
            if (aType !== bType) {
                throw new Error(`Arrays are of different type. Actual: ${aType}. ` +
                    `Expected: ${bType}`);
            }
        }
        if (Array.isArray(actual) && Array.isArray(expected)) {
            const actualShape = inferShape(actual);
            const expectedShape = inferShape(expected);
            if (!arraysEqual(actualShape, expectedShape)) {
                throw new Error(`Arrays have different shapes. ` +
                    `Actual: [${actualShape}]. Expected: [${expectedShape}]`);
            }
        }
        // Compare over flattened views; typed arrays are already flat.
        const actualFlat = isTypedArray(actual) ? actual : flatten(actual);
        const expectedFlat = isTypedArray(expected) ?
            expected :
            flatten(expected);
        if (actualFlat.length !== expectedFlat.length) {
            throw new Error(`Arrays have different lengths actual: ${actualFlat.length} vs ` +
                `expected: ${expectedFlat.length}.\n` +
                `Actual: ${actualFlat}.\n` +
                `Expected: ${expectedFlat}.`);
        }
        for (let i = 0; i < expectedFlat.length; ++i) {
            const a = actualFlat[i];
            const e = expectedFlat[i];
            if (!predicate(a, e)) {
                throw new Error(`Arrays differ: actual[${i}] = ${a}, expected[${i}] = ${e}.\n` +
                    `Actual: ${actualFlat}.\n` +
                    `Expected: ${expectedFlat}.`);
            }
        }
    }
10505 function expectPromiseToFail(fn, done) {
10506 fn().then(() => done.fail(), () => done());
10507 }
    /**
     * Asserts exact equality: strings compared with loose `==`, numbers via
     * areClose with zero tolerance.
     *
     * NOTE(review): a scalar `expected` is wrapped into `exp = [expected]`,
     * but only the string branch passes `exp`; the numeric branch passes the
     * unwrapped `expected`. This keeps scalar-vs-scalar comparisons working,
     * while array-vs-scalar numeric inputs will fail the constructor check in
     * expectArraysPredicate — confirm intentional before changing.
     */
    function expectArraysEqual(actual, expected) {
        const exp = typeof expected === 'string' || typeof expected === 'number' ||
            typeof expected === 'boolean' ?
            [expected] :
            expected;
        if (isString(actual) || isString(actual[0]) ||
            isString(expected) || isString(expected[0])) {
            // tslint:disable-next-line: triple-equals
            return expectArraysPredicate(actual, exp, (a, b) => a == b);
        }
        return expectArraysPredicate(actual, expected, (a, b) => areClose(a, b, 0));
    }
10520 function expectNumbersClose(a, e, epsilon) {
10521 if (epsilon == null) {
10522 epsilon = testEpsilon();
10523 }
10524 if (!areClose(a, e, epsilon)) {
10525 throw new Error(`Numbers differ: actual === ${a}, expected === ${e}`);
10526 }
10527 }
10528 function areClose(a, e, epsilon) {
10529 if (!isFinite(a) && !isFinite(e)) {
10530 return true;
10531 }
10532 if (isNaN(a) || isNaN(e) || Math.abs(a - e) > epsilon) {
10533 return false;
10534 }
10535 return true;
10536 }
10537 function expectValuesInRange(actual, low, high) {
10538 for (let i = 0; i < actual.length; i++) {
10539 if (actual[i] < low || actual[i] > high) {
10540 throw new Error(`Value out of range:${actual[i]} low: ${low}, high: ${high}`);
10541 }
10542 }
10543 }
10544 function expectArrayBuffersEqual(actual, expected) {
10545 // Safari does not like comparing ArrayBuffers directly. Wrapping in
10546 // a Float32Array solves this issue.
10547 const actualArray = new Float32Array(actual);
10548 const expectedArray = new Float32Array(expected);
10549 if (actualArray.length !== expectedArray.length) {
10550 throw new Error('Expected ArrayBuffer to be of length ' +
10551 `${expectedArray.length}, but it was ${actualArray.length}`);
10552 }
10553 for (let i = 0; i < expectedArray.length; i++) {
10554 if (actualArray[i] !== expectedArray[i]) {
10555 throw new Error(`Expected ArrayBuffer value at ${i} to be ` +
10556 `${expectedArray[i]} but got ${actualArray[i]} instead`);
10557 }
10558 }
10559 }
10560 /** Encodes strings into utf-8 bytes. */
10561 function encodeStrings(a) {
10562 for (let i = 0; i < a.length; i++) {
10563 const val = a[i];
10564 if (Array.isArray(val)) {
10565 encodeStrings(val);
10566 }
10567 else {
10568 a[i] = encodeString(val);
10569 }
10570 }
10571 return a;
10572 }
10573
    // Frozen namespace re-exported as `tf.test_util`.
    // NOTE(review): TEST_EPSILON_FLOAT32, expectArraysPredicate and areClose
    // are module-private (not exported here) — confirm intentional.
    var test_util = /*#__PURE__*/Object.freeze({
        __proto__: null,
        TEST_EPSILON_FLOAT16: TEST_EPSILON_FLOAT16,
        expectArraysClose: expectArraysClose,
        testEpsilon: testEpsilon,
        expectPromiseToFail: expectPromiseToFail,
        expectArraysEqual: expectArraysEqual,
        expectNumbersClose: expectNumbersClose,
        expectValuesInRange: expectValuesInRange,
        expectArrayBuffersEqual: expectArrayBuffersEqual,
        encodeStrings: encodeStrings
    });
10586
    /** @license See the LICENSE file. */
    // This code is auto-generated, do not modify this file!
    // Version string of this tfjs-core build (exposed as `tf.version_core`).
    const version = '3.18.0';
10590
10591 /**
10592 * @license
10593 * Copyright 2020 Google LLC. All Rights Reserved.
10594 * Licensed under the Apache License, Version 2.0 (the "License");
10595 * you may not use this file except in compliance with the License.
10596 * You may obtain a copy of the License at
10597 *
10598 * http://www.apache.org/licenses/LICENSE-2.0
10599 *
10600 * Unless required by applicable law or agreed to in writing, software
10601 * distributed under the License is distributed on an "AS IS" BASIS,
10602 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10603 * See the License for the specific language governing permissions and
10604 * limitations under the License.
10605 * =============================================================================
10606 */
10607 /**
10608 * Adds two `tf.Tensor`s element-wise, A + B. Supports broadcasting.
10609 *
10610 *
10611 * ```js
10612 * const a = tf.tensor1d([1, 2, 3, 4]);
10613 * const b = tf.tensor1d([10, 20, 30, 40]);
10614 *
10615 * a.add(b).print(); // or tf.add(a, b)
10616 * ```
10617 *
10618 * ```js
10619 * // Broadcast add a with b.
10620 * const a = tf.scalar(5);
10621 * const b = tf.tensor1d([10, 20, 30, 40]);
10622 *
10623 * a.add(b).print(); // or tf.add(a, b)
10624 * ```
10625 * @param a The first `tf.Tensor` to add.
10626 * @param b The second `tf.Tensor` to add. Must have the same type as `a`.
10627 *
10628 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
10629 */
10630 function add_(a, b) {
10631 let $a = convertToTensor(a, 'a', 'add');
10632 let $b = convertToTensor(b, 'b', 'add');
10633 [$a, $b] = makeTypesMatch($a, $b);
10634 const inputs = { a: $a, b: $b };
10635 return ENGINE.runKernel(Add, inputs);
10636 }
10637 const add$1 = op({ add_ });
10638
10639 /**
10640 * @license
10641 * Copyright 2020 Google LLC. All Rights Reserved.
10642 * Licensed under the Apache License, Version 2.0 (the "License");
10643 * you may not use this file except in compliance with the License.
10644 * You may obtain a copy of the License at
10645 *
10646 * http://www.apache.org/licenses/LICENSE-2.0
10647 *
10648 * Unless required by applicable law or agreed to in writing, software
10649 * distributed under the License is distributed on an "AS IS" BASIS,
10650 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10651 * See the License for the specific language governing permissions and
10652 * limitations under the License.
10653 * =============================================================================
10654 */
10655 /**
10656 * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.
10657 * The result is rounded with floor function.
10658 *
10659 *
10660 * ```js
10661 * const a = tf.tensor1d([1, 4, 9, 16]);
10662 * const b = tf.tensor1d([1, 2, 3, 4]);
10663 *
10664 * a.floorDiv(b).print(); // or tf.div(a, b)
10665 * ```
10666 *
10667 * ```js
10668 * // Broadcast div a with b.
10669 * const a = tf.tensor1d([2, 4, 6, 8]);
10670 * const b = tf.scalar(2);
10671 *
10672 * a.floorDiv(b).print(); // or tf.floorDiv(a, b)
10673 * ```
10674 *
10675 * @param a The first tensor as the numerator.
10676 * @param b The second tensor as the denominator. Must have the same dtype as
10677 * `a`.
10678 *
10679 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
10680 */
10681 function floorDiv_(a, b) {
10682 let $a = convertToTensor(a, 'a', 'floorDiv');
10683 let $b = convertToTensor(b, 'b', 'floorDiv');
10684 [$a, $b] = makeTypesMatch($a, $b);
10685 const inputs = { a: $a, b: $b };
10686 return ENGINE.runKernel(FloorDiv, inputs);
10687 }
10688 const floorDiv = op({ floorDiv_ });
10689
10690 /**
10691 * @license
10692 * Copyright 2020 Google LLC. All Rights Reserved.
10693 * Licensed under the Apache License, Version 2.0 (the "License");
10694 * you may not use this file except in compliance with the License.
10695 * You may obtain a copy of the License at
10696 *
10697 * http://www.apache.org/licenses/LICENSE-2.0
10698 *
10699 * Unless required by applicable law or agreed to in writing, software
10700 * distributed under the License is distributed on an "AS IS" BASIS,
10701 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10702 * See the License for the specific language governing permissions and
10703 * limitations under the License.
10704 * =============================================================================
10705 */
10706 /**
10707 * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.
10708 *
10709 * ```js
10710 * const a = tf.tensor1d([1, 4, 9, 16]);
10711 * const b = tf.tensor1d([1, 2, 3, 4]);
10712 *
10713 * a.div(b).print(); // or tf.div(a, b)
10714 * ```
10715 *
10716 * ```js
10717 * // Broadcast div a with b.
10718 * const a = tf.tensor1d([2, 4, 6, 8]);
10719 * const b = tf.scalar(2);
10720 *
10721 * a.div(b).print(); // or tf.div(a, b)
10722 * ```
10723 *
10724 * @param a The first tensor as the numerator.
10725 * @param b The second tensor as the denominator. Must have the same dtype as
10726 * `a`.
10727 *
10728 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
10729 */
10730 function div_(a, b) {
10731 let $a = convertToTensor(a, 'a', 'div');
10732 let $b = convertToTensor(b, 'b', 'div');
10733 [$a, $b] = makeTypesMatch($a, $b);
10734 if ($a.dtype === 'int32' && $b.dtype === 'int32') {
10735 return floorDiv($a, $b);
10736 }
10737 const inputs = { a: $a, b: $b };
10738 const attrs = {};
10739 // tslint:disable-next-line: no-unnecessary-type-assertion
10740 return ENGINE.runKernel(RealDiv, inputs, attrs);
10741 }
10742 const div = op({ div_ });
10743
10744 /**
10745 * @license
10746 * Copyright 2020 Google LLC. All Rights Reserved.
10747 * Licensed under the Apache License, Version 2.0 (the "License");
10748 * you may not use this file except in compliance with the License.
10749 * You may obtain a copy of the License at
10750 *
10751 * http://www.apache.org/licenses/LICENSE-2.0
10752 *
10753 * Unless required by applicable law or agreed to in writing, software
10754 * distributed under the License is distributed on an "AS IS" BASIS,
10755 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10756 * See the License for the specific language governing permissions and
10757 * limitations under the License.
10758 * =============================================================================
10759 */
10760 /**
10761 * Multiplies two `tf.Tensor`s element-wise, A * B. Supports broadcasting.
10762 *
10763 * We also expose `tf.mulStrict` which has the same signature as this op and
10764 * asserts that `a` and `b` are the same shape (does not broadcast).
10765 *
10766 * ```js
10767 * const a = tf.tensor1d([1, 2, 3, 4]);
10768 * const b = tf.tensor1d([2, 3, 4, 5]);
10769 *
10770 * a.mul(b).print(); // or tf.mul(a, b)
10771 * ```
10772 *
10773 * ```js
10774 * // Broadcast mul a with b.
10775 * const a = tf.tensor1d([1, 2, 3, 4]);
10776 * const b = tf.scalar(5);
10777 *
10778 * a.mul(b).print(); // or tf.mul(a, b)
10779 * ```
10780 * @param a The first tensor to multiply.
10781 * @param b The second tensor to multiply. Must have the same dtype as `a`.
10782 *
10783 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
10784 */
10785 function mul_(a, b) {
10786 let $a = convertToTensor(a, 'a', 'mul');
10787 let $b = convertToTensor(b, 'b', 'mul');
10788 [$a, $b] = makeTypesMatch($a, $b);
10789 const inputs = { a: $a, b: $b };
10790 return ENGINE.runKernel(Multiply, inputs);
10791 }
10792 const mul = op({ mul_ });
10793
10794 /**
10795 * @license
10796 * Copyright 2018 Google LLC. All Rights Reserved.
10797 * Licensed under the Apache License, Version 2.0 (the "License");
10798 * you may not use this file except in compliance with the License.
10799 * You may obtain a copy of the License at
10800 *
10801 * http://www.apache.org/licenses/LICENSE-2.0
10802 *
10803 * Unless required by applicable law or agreed to in writing, software
10804 * distributed under the License is distributed on an "AS IS" BASIS,
10805 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10806 * See the License for the specific language governing permissions and
10807 * limitations under the License.
10808 * =============================================================================
10809 */
10810 /**
10811 * Computes absolute value element-wise: `abs(x)`
10812 *
10813 * ```js
10814 * const x = tf.tensor1d([-1, 2, -3, 4]);
10815 *
10816 * x.abs().print(); // or tf.abs(x)
10817 * ```
10818 * @param x The input `tf.Tensor`.
10819 *
10820 * @doc {heading: 'Operations', subheading: 'Basic math'}
10821 */
10822 function abs_(x) {
10823 const $x = convertToTensor(x, 'x', 'abs');
10824 if ($x.dtype === 'complex64') {
10825 const inputs = { x: $x };
10826 return ENGINE.runKernel(ComplexAbs, inputs);
10827 }
10828 else {
10829 const inputs = { x: $x };
10830 return ENGINE.runKernel(Abs, inputs);
10831 }
10832 }
10833 const abs = op({ abs_ });
10834
10835 /**
10836 * @license
10837 * Copyright 2018 Google LLC. All Rights Reserved.
10838 * Licensed under the Apache License, Version 2.0 (the "License");
10839 * you may not use this file except in compliance with the License.
10840 * You may obtain a copy of the License at
10841 *
10842 * http://www.apache.org/licenses/LICENSE-2.0
10843 *
10844 * Unless required by applicable law or agreed to in writing, software
10845 * distributed under the License is distributed on an "AS IS" BASIS,
10846 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10847 * See the License for the specific language governing permissions and
10848 * limitations under the License.
10849 * =============================================================================
10850 */
10851 /**
10852 * Computes acos of the input `tf.Tensor` element-wise: `acos(x)`
10853 *
10854 * ```js
10855 * const x = tf.tensor1d([0, 1, -1, .7]);
10856 *
10857 * x.acos().print(); // or tf.acos(x)
10858 * ```
10859 * @param x The input tensor.
10860 * @doc {heading: 'Operations', subheading: 'Basic math'}
10861 */
10862 function acos_(x) {
10863 const $x = convertToTensor(x, 'x', 'acos');
10864 const inputs = { x: $x };
10865 return ENGINE.runKernel(Acos, inputs);
10866 }
10867 const acos = op({ acos_ });
10868
10869 /**
10870 * @license
10871 * Copyright 2018 Google LLC. All Rights Reserved.
10872 * Licensed under the Apache License, Version 2.0 (the "License");
10873 * you may not use this file except in compliance with the License.
10874 * You may obtain a copy of the License at
10875 *
10876 * http://www.apache.org/licenses/LICENSE-2.0
10877 *
10878 * Unless required by applicable law or agreed to in writing, software
10879 * distributed under the License is distributed on an "AS IS" BASIS,
10880 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10881 * See the License for the specific language governing permissions and
10882 * limitations under the License.
10883 * =============================================================================
10884 */
10885 /**
10886 * Computes the inverse hyperbolic cos of the input `tf.Tensor` element-wise:
10887 * `acosh(x)`
10888 *
10889 * ```js
10890 * const x = tf.tensor1d([10, 1, 3, 5.7]);
10891 *
10892 * x.acosh().print(); // or tf.acosh(x)
10893 * ```
10894 * @param x The input tensor.
10895 *
10896 * @doc {heading: 'Operations', subheading: 'Basic math'}
10897 */
10898 function acosh_(x) {
10899 const $x = convertToTensor(x, 'x', 'acosh');
10900 const inputs = { x: $x };
10901 return ENGINE.runKernel(Acosh, inputs);
10902 }
10903 const acosh = op({ acosh_ });
10904
10905 /**
10906 * @license
10907 * Copyright 2020 Google LLC. All Rights Reserved.
10908 * Licensed under the Apache License, Version 2.0 (the "License");
10909 * you may not use this file except in compliance with the License.
10910 * You may obtain a copy of the License at
10911 *
10912 * http://www.apache.org/licenses/LICENSE-2.0
10913 *
10914 * Unless required by applicable law or agreed to in writing, software
10915 * distributed under the License is distributed on an "AS IS" BASIS,
10916 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10917 * See the License for the specific language governing permissions and
10918 * limitations under the License.
10919 * =============================================================================
10920 */
10921 /**
10922 * Adds a list of `tf.Tensor`s element-wise, each with the same shape and dtype.
10923 *
10924 * ```js
10925 * const a = tf.tensor1d([1, 2]);
10926 * const b = tf.tensor1d([3, 4]);
10927 * const c = tf.tensor1d([5, 6]);
10928 *
10929 * tf.addN([a, b, c]).print();
10930 * ```
10931 * @param tensors A list of tensors with the same shape and dtype.
10932 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
10933 */
10934 function addN_(tensors) {
10935 assert(Array.isArray(tensors), () => 'The argument passed to tf.addN() must be a list of tensors');
10936 assert(tensors.length >= 1, () => `Must pass at least one tensor to tf.addN(), but got ` +
10937 `${tensors.length}`);
10938 const $tensors = tensors.map((t, i) => convertToTensor(t, `tensors${i}`, 'addN'));
10939 const firstTensor = $tensors[0];
10940 $tensors.forEach(t => {
10941 if (t.dtype !== firstTensor.dtype) {
10942 throw new Error('All tensors passed to tf.addN() must have the same dtype');
10943 }
10944 });
10945 $tensors.forEach(t => {
10946 if (!arraysEqual(t.shape, firstTensor.shape)) {
10947 throw new Error('All tensors passed to tf.addN() must have the same shape');
10948 }
10949 });
10950 const inputs = $tensors;
10951 return ENGINE.runKernel(AddN, inputs);
10952 }
10953 const addN = op({ addN_ });
10954
10955 /**
10956 * @license
10957 * Copyright 2020 Google LLC. All Rights Reserved.
10958 * Licensed under the Apache License, Version 2.0 (the "License");
10959 * you may not use this file except in compliance with the License.
10960 * You may obtain a copy of the License at
10961 *
10962 * http://www.apache.org/licenses/LICENSE-2.0
10963 *
10964 * Unless required by applicable law or agreed to in writing, software
10965 * distributed under the License is distributed on an "AS IS" BASIS,
10966 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10967 * See the License for the specific language governing permissions and
10968 * limitations under the License.
10969 * =============================================================================
10970 */
10971 /**
10972 * Computes the logical and of elements across dimensions of a `tf.Tensor`.
10973 *
10974 * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
10975 * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
10976 * `axes`. If `keepDims` is true, the reduced dimensions are retained with
10977 * length 1. If `axes` has no entries, all dimensions are reduced, and an
10978 * `tf.Tensor` with a single element is returned.
10979 *
10980 * ```js
10981 * const x = tf.tensor1d([1, 1, 1], 'bool');
10982 *
10983 * x.all().print(); // or tf.all(x)
10984 * ```
10985 *
10986 * ```js
10987 * const x = tf.tensor2d([1, 1, 0, 0], [2, 2], 'bool');
10988 *
10989 * const axis = 1;
10990 * x.all(axis).print(); // or tf.all(x, axis)
10991 * ```
10992 *
10993 * @param x The input tensor. Must be of dtype bool.
10994 * @param axis The dimension(s) to reduce. By default it reduces
10995 * all dimensions.
10996 * @param keepDims If true, retains reduced dimensions with size 1.
10997 *
10998 * @doc {heading: 'Operations', subheading: 'Reduction'}
10999 */
11000 function all_(x, axis = null, keepDims = false) {
11001 const $x = convertToTensor(x, 'x', 'all', 'bool');
11002 const inputs = { x: $x };
11003 const attrs = { axis, keepDims };
11004 return ENGINE.runKernel(All, inputs, attrs);
11005 }
11006 const all = op({ all_ });
11007
11008 /**
11009 * @license
11010 * Copyright 2020 Google LLC. All Rights Reserved.
11011 * Licensed under the Apache License, Version 2.0 (the "License");
11012 * you may not use this file except in compliance with the License.
11013 * You may obtain a copy of the License at
11014 *
11015 * http://www.apache.org/licenses/LICENSE-2.0
11016 *
11017 * Unless required by applicable law or agreed to in writing, software
11018 * distributed under the License is distributed on an "AS IS" BASIS,
11019 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11020 * See the License for the specific language governing permissions and
11021 * limitations under the License.
11022 * =============================================================================
11023 */
11024 /**
11025 * Computes the logical or of elements across dimensions of a `tf.Tensor`.
11026 *
11027 * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
11028 * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
11029 * `axes`. If `keepDims` is true, the reduced dimensions are retained with
11030 * length 1. If `axes` has no entries, all dimensions are reduced, and an
11031 * `tf.Tensor` with a single element is returned.
11032 *
11033 * ```js
11034 * const x = tf.tensor1d([1, 1, 1], 'bool');
11035 *
11036 * x.any().print(); // or tf.any(x)
11037 * ```
11038 *
11039 * ```js
11040 * const x = tf.tensor2d([1, 1, 0, 0], [2, 2], 'bool');
11041 *
11042 * const axis = 1;
11043 * x.any(axis).print(); // or tf.any(x, axis)
11044 * ```
11045 *
11046 * @param x The input tensor. Must be of dtype bool.
11047 * @param axis The dimension(s) to reduce. By default it reduces
11048 * all dimensions.
11049 * @param keepDims If true, retains reduced dimensions with size 1.
11050 *
11051 * @doc {heading: 'Operations', subheading: 'Reduction'}
11052 */
11053 function any_(x, axis = null, keepDims = false) {
11054 const $x = convertToTensor(x, 'x', 'any', 'bool');
11055 const inputs = { x: $x };
11056 const attrs = { axis, keepDims };
11057 return ENGINE.runKernel(Any, inputs, attrs);
11058 }
11059 // tslint:disable-next-line:variable-name
11060 const any = op({ any_ });
11061
11062 /**
11063 * @license
11064 * Copyright 2020 Google Inc. All Rights Reserved.
11065 * Licensed under the Apache License, Version 2.0 (the "License");
11066 * you may not use this file except in compliance with the License.
11067 * You may obtain a copy of the License at
11068 *
11069 * http://www.apache.org/licenses/LICENSE-2.0
11070 *
11071 * Unless required by applicable law or agreed to in writing, software
11072 * distributed under the License is distributed on an "AS IS" BASIS,
11073 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11074 * See the License for the specific language governing permissions and
11075 * limitations under the License.
11076 * =============================================================================
11077 */
11078 /**
11079 * Returns the indices of the maximum values along an `axis`.
11080 *
11081 * The result has the same shape as `input` with the dimension along `axis`
11082 * removed.
11083 *
11084 * ```js
11085 * const x = tf.tensor1d([1, 2, 3]);
11086 *
11087 * x.argMax().print(); // or tf.argMax(x)
11088 * ```
11089 *
11090 * ```js
11091 * const x = tf.tensor2d([1, 2, 4, 3], [2, 2]);
11092 *
11093 * const axis = 1;
11094 * x.argMax(axis).print(); // or tf.argMax(x, axis)
11095 * ```
11096 *
11097 * @param x The input tensor.
11098 * @param axis The dimension to reduce. Defaults to 0 (outer-most dimension).
11099 *
11100 * @doc {heading: 'Operations', subheading: 'Reduction'}
11101 */
11102 function argMax_(x, axis = 0) {
11103 const $x = convertToTensor(x, 'x', 'argMax');
11104 const inputs = { x: $x };
11105 const attrs = { axis };
11106 return ENGINE.runKernel(ArgMax, inputs, attrs);
11107 }
11108 const argMax = op({ argMax_ });
11109
11110 /**
11111 * @license
11112 * Copyright 2020 Google Inc. All Rights Reserved.
11113 * Licensed under the Apache License, Version 2.0 (the "License");
11114 * you may not use this file except in compliance with the License.
11115 * You may obtain a copy of the License at
11116 *
11117 * http://www.apache.org/licenses/LICENSE-2.0
11118 *
11119 * Unless required by applicable law or agreed to in writing, software
11120 * distributed under the License is distributed on an "AS IS" BASIS,
11121 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11122 * See the License for the specific language governing permissions and
11123 * limitations under the License.
11124 * =============================================================================
11125 */
11126 /**
11127 * Returns the indices of the minimum values along an `axis`.
11128 *
11129 * The result has the same shape as `input` with the dimension along `axis`
11130 * removed.
11131 *
11132 * ```js
11133 * const x = tf.tensor1d([1, 2, 3]);
11134 *
11135 * x.argMin().print(); // or tf.argMin(x)
11136 * ```
11137 *
11138 * ```js
11139 * const x = tf.tensor2d([1, 2, 4, 3], [2, 2]);
11140 *
11141 * const axis = 1;
11142 * x.argMin(axis).print(); // or tf.argMin(x, axis)
11143 * ```
11144 *
11145 * @param x The input tensor.
11146 * @param axis The dimension to reduce. Defaults to 0 (outer-most dimension).
11147 *
11148 * @doc {heading: 'Operations', subheading: 'Reduction'}
11149 */
11150 function argMin_(x, axis = 0) {
11151 const $x = convertToTensor(x, 'x', 'argMin');
11152 const inputs = { x: $x };
11153 const attrs = { axis };
11154 return ENGINE.runKernel(ArgMin, inputs, attrs);
11155 }
11156 const argMin = op({ argMin_ });
11157
11158 /**
11159 * @license
11160 * Copyright 2018 Google LLC. All Rights Reserved.
11161 * Licensed under the Apache License, Version 2.0 (the "License");
11162 * you may not use this file except in compliance with the License.
11163 * You may obtain a copy of the License at
11164 *
11165 * http://www.apache.org/licenses/LICENSE-2.0
11166 *
11167 * Unless required by applicable law or agreed to in writing, software
11168 * distributed under the License is distributed on an "AS IS" BASIS,
11169 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11170 * See the License for the specific language governing permissions and
11171 * limitations under the License.
11172 * =============================================================================
11173 */
11174 /**
11175 * Computes asin of the input `tf.Tensor` element-wise: `asin(x)`
11176 *
11177 * ```js
11178 * const x = tf.tensor1d([0, 1, -1, .7]);
11179 *
11180 * x.asin().print(); // or tf.asin(x)
11181 * ```
11182 * @param x The input tensor.
11183 * @doc {heading: 'Operations', subheading: 'Basic math'}
11184 */
11185 function asin_(x) {
11186 const $x = convertToTensor(x, 'x', 'asin');
11187 const inputs = { x: $x };
11188 return ENGINE.runKernel(Asin, inputs);
11189 }
11190 const asin = op({ asin_ });
11191
11192 /**
11193 * @license
11194 * Copyright 2018 Google LLC. All Rights Reserved.
11195 * Licensed under the Apache License, Version 2.0 (the "License");
11196 * you may not use this file except in compliance with the License.
11197 * You may obtain a copy of the License at
11198 *
11199 * http://www.apache.org/licenses/LICENSE-2.0
11200 *
11201 * Unless required by applicable law or agreed to in writing, software
11202 * distributed under the License is distributed on an "AS IS" BASIS,
11203 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11204 * See the License for the specific language governing permissions and
11205 * limitations under the License.
11206 * =============================================================================
11207 */
11208 /**
11209 * Computes inverse hyperbolic sin of the input `tf.Tensor` element-wise:
11210 * `asinh(x)`
11211 *
11212 * ```js
11213 * const x = tf.tensor1d([0, 1, -1, .7]);
11214 *
11215 * x.asinh().print(); // or tf.asinh(x)
11216 * ```
11217 * @param x The input tensor.
11218 *
11219 * @doc {heading: 'Operations', subheading: 'Basic math'}
11220 */
11221 function asinh_(x) {
11222 const $x = convertToTensor(x, 'x', 'asinh');
11223 const inputs = { x: $x };
11224 return ENGINE.runKernel(Asinh, inputs);
11225 }
11226 const asinh = op({ asinh_ });
11227
11228 /**
11229 * @license
11230 * Copyright 2018 Google LLC. All Rights Reserved.
11231 * Licensed under the Apache License, Version 2.0 (the "License");
11232 * you may not use this file except in compliance with the License.
11233 * You may obtain a copy of the License at
11234 *
11235 * http://www.apache.org/licenses/LICENSE-2.0
11236 *
11237 * Unless required by applicable law or agreed to in writing, software
11238 * distributed under the License is distributed on an "AS IS" BASIS,
11239 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11240 * See the License for the specific language governing permissions and
11241 * limitations under the License.
11242 * =============================================================================
11243 */
11244 /**
11245 * Computes atan of the input `tf.Tensor` element-wise: `atan(x)`
11246 *
11247 * ```js
11248 * const x = tf.tensor1d([0, 1, -1, .7]);
11249 *
11250 * x.atan().print(); // or tf.atan(x)
11251 * ```
11252 * @param x The input tensor.
11253 *
11254 * @doc {heading: 'Operations', subheading: 'Basic math'}
11255 */
11256 function atan_(x) {
11257 const $x = convertToTensor(x, 'x', 'atan');
11258 const inputs = { x: $x };
11259 return ENGINE.runKernel(Atan, inputs);
11260 }
11261 const atan = op({ atan_ });
11262
11263 /**
11264 * @license
11265 * Copyright 2020 Google LLC. All Rights Reserved.
11266 * Licensed under the Apache License, Version 2.0 (the "License");
11267 * you may not use this file except in compliance with the License.
11268 * You may obtain a copy of the License at
11269 *
11270 * http://www.apache.org/licenses/LICENSE-2.0
11271 *
11272 * Unless required by applicable law or agreed to in writing, software
11273 * distributed under the License is distributed on an "AS IS" BASIS,
11274 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11275 * See the License for the specific language governing permissions and
11276 * limitations under the License.
11277 * =============================================================================
11278 */
/**
 * Computes arctangent of `tf.Tensor`s a / b element-wise: `atan2(a, b)`.
 * Supports broadcasting.
 *
 * ```js
 * const a = tf.tensor1d([1.0, 1.0, -1.0, .7]);
 * const b = tf.tensor1d([2.0, 13.0, 3.5, .21]);
 *
 * tf.atan2(a, b).print()
 * ```
 *
 * @param a The first tensor.
 * @param b The second tensor. Must have the same dtype as `a`.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function atan2_(a, b) {
    // Upcast both operands to a common dtype before running the kernel.
    const [$a, $b] = makeTypesMatch(
        convertToTensor(a, 'a', 'atan2'), convertToTensor(b, 'b', 'atan2'));
    return ENGINE.runKernel(Atan2, { a: $a, b: $b });
}
const atan2 = op({ atan2_ });
11303
11304 /**
11305 * @license
11306 * Copyright 2018 Google LLC. All Rights Reserved.
11307 * Licensed under the Apache License, Version 2.0 (the "License");
11308 * you may not use this file except in compliance with the License.
11309 * You may obtain a copy of the License at
11310 *
11311 * http://www.apache.org/licenses/LICENSE-2.0
11312 *
11313 * Unless required by applicable law or agreed to in writing, software
11314 * distributed under the License is distributed on an "AS IS" BASIS,
11315 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11316 * See the License for the specific language governing permissions and
11317 * limitations under the License.
11318 * =============================================================================
11319 */
/**
 * Computes inverse hyperbolic tan of the input `tf.Tensor` element-wise:
 * `atanh(x)`
 *
 * ```js
 * const x = tf.tensor1d([0, .1, -.1, .7]);
 *
 * x.atanh().print(); // or tf.atanh(x)
 * ```
 * @param x The input tensor.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function atanh_(x) {
    // Coerce the input into a tensor, then dispatch the Atanh kernel.
    const $x = convertToTensor(x, 'x', 'atanh');
    return ENGINE.runKernel(Atanh, { x: $x });
}
const atanh = op({ atanh_ });
11339
11340 /**
11341 * @license
11342 * Copyright 2020 Google LLC. All Rights Reserved.
11343 * Licensed under the Apache License, Version 2.0 (the "License");
11344 * you may not use this file except in compliance with the License.
11345 * You may obtain a copy of the License at
11346 *
11347 * http://www.apache.org/licenses/LICENSE-2.0
11348 *
11349 * Unless required by applicable law or agreed to in writing, software
11350 * distributed under the License is distributed on an "AS IS" BASIS,
11351 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11352 * See the License for the specific language governing permissions and
11353 * limitations under the License.
11354 * =============================================================================
11355 */
/**
 * Computes the convolution info for a 2D morphological dilation.
 *
 * @param inputShape Input tensor shape is of the following dimensions:
 *     `[batch, height, width, inChannels]`.
 * @param filterShape The filter shape is of the following dimensions:
 *     `[filterHeight, filterWidth, depth]`.
 * @param strides The strides of the sliding window for each dimension of the
 *     input tensor: `[strideHeight, strideWidth]`.
 *     If `strides` is a single number,
 *     then `strideHeight == strideWidth`.
 * @param pad The type of padding algorithm.
 *    - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *    - `valid`: output will be smaller than input if filter is larger
 *       than 1x1x1.
 *    - For more info, see this guide:
 *     [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
 *          https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 * @param dataFormat The data format of the input and output data.
 *     Defaults to 'NHWC'.
 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`.
 *     Defaults to `[1, 1]`. If `dilations` is a single number, then
 *     `dilationHeight == dilationWidth`.
 */
function computeDilation2DInfo(inputShape, filterShape, strides, pad, dataFormat = 'NHWC', dilations) {
    // `computeConv2DInfo` wants a 4D filter `[h, w, depth, outDepth]`;
    // dilation2d has no separate output depth, so reuse the input channels.
    const depth = inputShape[3];
    const conv2dFilterShape = [...filterShape, depth];
    return computeConv2DInfo(
        inputShape, conv2dFilterShape, strides, dilations, pad,
        null /* roundingMode */, null /* depthWise */,
        convertConv2DDataFormat(dataFormat));
}
/**
 * Computes the information for a forward pass of a 2D pooling operation
 * by synthesizing a depth-preserving filter shape and delegating to
 * `computeConv2DInfo`.
 */
function computePool2DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat = 'channelsLast') {
    const [filterHeight, filterWidth] = parseTupleParam(filterSize);
    // Channel axis position depends on the layout.
    let channelAxis;
    if (dataFormat === 'channelsLast') {
        channelAxis = 3;
    } else if (dataFormat === 'channelsFirst') {
        channelAxis = 1;
    } else {
        throw new Error(`Unknown dataFormat ${dataFormat}`);
    }
    const channels = inShape[channelAxis];
    const filterShape = [filterHeight, filterWidth, channels, channels];
    return computeConv2DInfo(
        inShape, filterShape, strides, dilations, pad, roundingMode, false,
        dataFormat);
}
/**
 * Computes the information for a forward pass of a pooling3D operation.
 */
function computePool3DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat = 'NDHWC') {
    const [filterDepth, filterHeight, filterWidth] = parse3TupleParam(filterSize);
    // Map the 3D layout string onto the channelsLast/channelsFirst vocabulary
    // used by computeConv3DInfo, and locate the channel axis.
    let $dataFormat;
    let channelAxis;
    if (dataFormat === 'NDHWC') {
        $dataFormat = 'channelsLast';
        channelAxis = 4;
    } else if (dataFormat === 'NCDHW') {
        $dataFormat = 'channelsFirst';
        channelAxis = 1;
    } else {
        throw new Error(`Unknown dataFormat ${dataFormat}`);
    }
    const channels = inShape[channelAxis];
    const filterShape =
        [filterDepth, filterHeight, filterWidth, channels, channels];
    return computeConv3DInfo(
        inShape, filterShape, strides, dilations, pad, false, $dataFormat,
        roundingMode);
}
/**
 * Computes the information for a forward pass of a convolution/pooling
 * operation.
 */
function computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, depthwise = false, dataFormat = 'channelsLast') {
    let batchSize = -1;
    let inHeight = -1;
    let inWidth = -1;
    let inChannels = -1;
    // Unpack the input shape according to the layout.
    if (dataFormat === 'channelsLast') {
        [batchSize, inHeight, inWidth, inChannels] = inShape;
    } else if (dataFormat === 'channelsFirst') {
        [batchSize, inChannels, inHeight, inWidth] = inShape;
    } else {
        throw new Error(`Unknown dataFormat ${dataFormat}`);
    }
    // Third filter dim (in-channels) is unused here.
    const [filterHeight, filterWidth, , filterChannels] = filterShape;
    const [strideHeight, strideWidth] = parseTupleParam(strides);
    const [dilationHeight, dilationWidth] = parseTupleParam(dilations);
    // Dilated filters behave like larger dense filters for padding purposes.
    const effectiveFilterHeight =
        getEffectiveFilterSize(filterHeight, dilationHeight);
    const effectiveFilterWidth =
        getEffectiveFilterSize(filterWidth, dilationWidth);
    const { padInfo, outHeight, outWidth } = getPadAndOutInfo(
        pad, inHeight, inWidth, strideHeight, strideWidth,
        effectiveFilterHeight, effectiveFilterWidth, roundingMode, dataFormat);
    // Depthwise convs multiply channels instead of replacing them.
    const outChannels = depthwise ? filterChannels * inChannels : filterChannels;
    const outShape = dataFormat === 'channelsFirst' ?
        [batchSize, outChannels, outHeight, outWidth] :
        [batchSize, outHeight, outWidth, outChannels];
    return {
        batchSize,
        dataFormat,
        inHeight,
        inWidth,
        inChannels,
        outHeight,
        outWidth,
        outChannels,
        padInfo,
        strideHeight,
        strideWidth,
        filterHeight,
        filterWidth,
        effectiveFilterHeight,
        effectiveFilterWidth,
        dilationHeight,
        dilationWidth,
        inShape,
        outShape,
        filterShape
    };
}
/**
 * Computes the information for a forward pass of a 3D convolution/pooling
 * operation.
 */
function computeConv3DInfo(inShape, filterShape, strides, dilations, pad, depthwise = false, dataFormat = 'channelsLast', roundingMode) {
    let batchSize = -1;
    let inDepth = -1;
    let inHeight = -1;
    let inWidth = -1;
    let inChannels = -1;
    // Unpack the 5D input shape according to the layout.
    if (dataFormat === 'channelsLast') {
        [batchSize, inDepth, inHeight, inWidth, inChannels] = inShape;
    } else if (dataFormat === 'channelsFirst') {
        [batchSize, inChannels, inDepth, inHeight, inWidth] = inShape;
    } else {
        throw new Error(`Unknown dataFormat ${dataFormat}`);
    }
    // Fourth filter dim (in-channels) is unused here.
    const [filterDepth, filterHeight, filterWidth, , filterChannels] =
        filterShape;
    const [strideDepth, strideHeight, strideWidth] = parse3TupleParam(strides);
    const [dilationDepth, dilationHeight, dilationWidth] =
        parse3TupleParam(dilations);
    // Dilated filters behave like larger dense filters for padding purposes.
    const effectiveFilterDepth =
        getEffectiveFilterSize(filterDepth, dilationDepth);
    const effectiveFilterHeight =
        getEffectiveFilterSize(filterHeight, dilationHeight);
    const effectiveFilterWidth =
        getEffectiveFilterSize(filterWidth, dilationWidth);
    const { padInfo, outDepth, outHeight, outWidth } = get3DPadAndOutInfo(
        pad, inDepth, inHeight, inWidth, strideDepth, strideHeight,
        strideWidth, effectiveFilterDepth, effectiveFilterHeight,
        effectiveFilterWidth, roundingMode);
    // Depthwise convs multiply channels instead of replacing them.
    const outChannels = depthwise ? filterChannels * inChannels : filterChannels;
    const outShape = dataFormat === 'channelsFirst' ?
        [batchSize, outChannels, outDepth, outHeight, outWidth] :
        [batchSize, outDepth, outHeight, outWidth, outChannels];
    return {
        batchSize,
        dataFormat,
        inDepth,
        inHeight,
        inWidth,
        inChannels,
        outDepth,
        outHeight,
        outWidth,
        outChannels,
        padInfo,
        strideDepth,
        strideHeight,
        strideWidth,
        filterDepth,
        filterHeight,
        filterWidth,
        effectiveFilterDepth,
        effectiveFilterHeight,
        effectiveFilterWidth,
        dilationDepth,
        dilationHeight,
        dilationWidth,
        inShape,
        outShape,
        filterShape
    };
}
/**
 * Computes `[outputRows, outputCols]` for a 2D window sliding over
 * `inShape = [rows, cols]` with a square field, a uniform stride and
 * symmetric zero padding. Falls back to the default "same-ish" pad when
 * `zeroPad` is null/undefined.
 */
function computeOutputShape2D(inShape, fieldSize, stride, zeroPad, roundingMode) {
    if (zeroPad == null) {
        zeroPad = computeDefaultPad(inShape, fieldSize, stride);
    }
    const [inputRows, inputCols] = inShape;
    // Standard conv output-size formula, rounded per `roundingMode`.
    const outputSize = (inputSize) =>
        round((inputSize - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
    return [outputSize(inputRows), outputSize(inputCols)];
}
/**
 * Computes `[outputDepths, outputRows, outputCols, outChannels]` for a 3D
 * window sliding over `inShape = [depth, rows, cols, ...]` with a cubic
 * field, a uniform stride and symmetric zero padding. Falls back to the
 * default pad when `zeroPad` is null/undefined.
 */
function computeOutputShape4D(inShape, fieldSize, outChannels, stride, zeroPad, roundingMode) {
    if (zeroPad == null) {
        zeroPad = computeDefaultPad(inShape, fieldSize, stride);
    }
    const [inputDepth, inputRows, inputCols] = inShape;
    // Standard conv output-size formula, rounded per `roundingMode`.
    const outputSize = (inputSize) =>
        round((inputSize - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
    return [
        outputSize(inputDepth), outputSize(inputRows), outputSize(inputCols),
        outChannels
    ];
}
/**
 * Default symmetric padding so that (with stride 1) output size equals input
 * size. Only the first spatial dim of `inputShape` is consulted — assumes a
 * square input for this purpose.
 */
function computeDefaultPad(inputShape, fieldSize, stride, dilation = 1) {
    const effectiveFieldSize = getEffectiveFilterSize(fieldSize, dilation);
    const numerator = inputShape[0] * (stride - 1) - stride + effectiveFieldSize;
    return Math.floor(numerator / 2);
}
/**
 * Normalizes a stride/dilation/filter-size parameter to a 3-tuple.
 * A scalar is broadcast to all three dims; a pair gets 1 appended as the
 * third dim; a 3-tuple is returned unchanged.
 */
function parseTupleParam(param) {
    if (typeof param === 'number') {
        return [param, param, param];
    }
    return param.length === 2 ? [param[0], param[1], 1] : param;
}
/**
 * Normalizes a 3D stride/dilation parameter: broadcasts a scalar to a
 * 3-tuple, otherwise returns the array as-is.
 */
function parse3TupleParam(param) {
    if (typeof param === 'number') {
        return [param, param, param];
    }
    return param;
}
/* See https://www.tensorflow.org/api_docs/python/tf/nn/atrous_conv2d
 * Atrous convolution is equivalent to standard convolution with upsampled
 * filters with effective_filter_height =
 * filter_height + (filter_height - 1) * (dilation - 1)
 * and effective_filter_width =
 * filter_width + (filter_width - 1) * (dilation - 1),
 * produced by inserting dilation - 1 zeros along consecutive elements across
 * the filters' spatial dimensions.
 * When there is a dilation, this converts a filter dimension to the
 * effective filter dimension, so it can be used in a standard convolution.
 */
function getEffectiveFilterSize(filterSize, dilation) {
    // Dilation of 1 (or less) means the filter is dense — no expansion.
    return dilation <= 1 ?
        filterSize :
        filterSize + (filterSize - 1) * (dilation - 1);
}
/**
 * Resolves a 2D padding spec (explicit number, 'same', 'valid', or an
 * explicit `[[b,b],[t,b],[l,r],[c,c]]`-style array) into concrete per-side
 * padding plus the output spatial dimensions.
 */
function getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, filterHeight, filterWidth, roundingMode, dataFormat) {
    let padInfo;
    let outHeight;
    let outWidth;
    if (typeof pad === 'number') {
        padInfo = {
            top: pad,
            bottom: pad,
            left: pad,
            right: pad,
            type: pad === 0 ? 'VALID' : 'NUMBER'
        };
        // NOTE: matches computeOutputShape2D's contract — filterHeight and
        // strideHeight are applied uniformly to both spatial dims.
        [outHeight, outWidth] = computeOutputShape2D(
            [inHeight, inWidth], filterHeight, strideHeight, pad, roundingMode);
    } else if (pad === 'same') {
        // Output covers the full input; distribute any needed padding, with
        // the extra unit (if odd) going to bottom/right.
        outHeight = Math.ceil(inHeight / strideHeight);
        outWidth = Math.ceil(inWidth / strideWidth);
        const padAlongHeight =
            Math.max(0, (outHeight - 1) * strideHeight + filterHeight - inHeight);
        const padAlongWidth =
            Math.max(0, (outWidth - 1) * strideWidth + filterWidth - inWidth);
        const top = Math.floor(padAlongHeight / 2);
        const left = Math.floor(padAlongWidth / 2);
        padInfo = {
            top,
            bottom: padAlongHeight - top,
            left,
            right: padAlongWidth - left,
            type: 'SAME'
        };
    } else if (pad === 'valid') {
        padInfo = { top: 0, bottom: 0, left: 0, right: 0, type: 'VALID' };
        outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
        outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
    } else if (typeof pad === 'object') {
        // Explicit per-side padding; the batch/channel entries are skipped
        // based on the layout.
        const [[top, bottom], [left, right]] = dataFormat === 'channelsLast' ?
            [pad[1], pad[2]] :
            [pad[2], pad[3]];
        const allZero = top === 0 && bottom === 0 && left === 0 && right === 0;
        padInfo = { top, bottom, left, right, type: allZero ? 'VALID' : 'EXPLICIT' };
        outHeight = round(
            (inHeight - filterHeight + top + bottom) / strideHeight + 1,
            roundingMode);
        outWidth = round(
            (inWidth - filterWidth + left + right) / strideWidth + 1,
            roundingMode);
    } else {
        throw Error(`Unknown padding parameter: ${pad}`);
    }
    return { padInfo, outHeight, outWidth };
}
/**
 * Resolves a 3D padding spec (explicit number, 'same' or 'valid') into
 * concrete per-side padding plus the output depth/height/width.
 */
function get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, filterDepth, filterHeight, filterWidth, roundingMode) {
    let padInfo;
    let outDepth;
    let outHeight;
    let outWidth;
    if (typeof pad === 'number') {
        padInfo = {
            top: pad,
            bottom: pad,
            left: pad,
            right: pad,
            front: pad,
            back: pad,
            type: pad === 0 ? 'VALID' : 'NUMBER'
        };
        // NOTE: matches computeOutputShape4D's contract — filterDepth and
        // strideDepth are applied uniformly to all three spatial dims.
        [outDepth, outHeight, outWidth] = computeOutputShape4D(
            [inDepth, inHeight, inWidth, 1], filterDepth, 1, strideDepth, pad,
            roundingMode);
    } else if (pad === 'same') {
        // Output covers the full input; distribute any needed padding, with
        // the extra unit (if odd) going to back/bottom/right.
        outDepth = Math.ceil(inDepth / strideDepth);
        outHeight = Math.ceil(inHeight / strideHeight);
        outWidth = Math.ceil(inWidth / strideWidth);
        const padAlongDepth = (outDepth - 1) * strideDepth + filterDepth - inDepth;
        const padAlongHeight =
            (outHeight - 1) * strideHeight + filterHeight - inHeight;
        const padAlongWidth = (outWidth - 1) * strideWidth + filterWidth - inWidth;
        const front = Math.floor(padAlongDepth / 2);
        const top = Math.floor(padAlongHeight / 2);
        const left = Math.floor(padAlongWidth / 2);
        padInfo = {
            top,
            bottom: padAlongHeight - top,
            left,
            right: padAlongWidth - left,
            front,
            back: padAlongDepth - front,
            type: 'SAME'
        };
    } else if (pad === 'valid') {
        padInfo = {
            top: 0,
            bottom: 0,
            left: 0,
            right: 0,
            front: 0,
            back: 0,
            type: 'VALID'
        };
        outDepth = Math.ceil((inDepth - filterDepth + 1) / strideDepth);
        outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
        outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
    } else {
        throw Error(`Unknown padding parameter: ${pad}`);
    }
    return { padInfo, outDepth, outHeight, outWidth };
}
/**
 * Rounds a value depending on the rounding mode
 * @param value
 * @param roundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 */
function round(value, roundingMode) {
    if (!roundingMode) {
        // Default: truncate toward zero.
        return Math.trunc(value);
    }
    if (roundingMode === 'round') {
        // used for Caffe Conv
        return Math.round(value);
    }
    if (roundingMode === 'ceil') {
        // used for Caffe Pool
        return Math.ceil(value);
    }
    if (roundingMode === 'floor') {
        return Math.floor(value);
    }
    throw new Error(`Unknown roundingMode ${roundingMode}`);
}
/**
 * Returns true when every dimension of the (scalar-or-tuple) parameter
 * normalizes to 1.
 */
function tupleValuesAreOne(param) {
    return parseTupleParam(param).every(dim => dim === 1);
}
/**
 * Convolution kernels only support dilation when stride is 1 (and vice
 * versa); true when at least one of the two is all-ones.
 */
function eitherStridesOrDilationsAreOne(strides, dilations) {
    if (tupleValuesAreOne(strides)) {
        return true;
    }
    return tupleValuesAreOne(dilations);
}
/**
 * Convert Conv2D dataFormat from 'NHWC'|'NCHW' to
 * 'channelsLast'|'channelsFirst'
 * @param dataFormat in 'NHWC'|'NCHW' mode
 * @return dataFormat in 'channelsLast'|'channelsFirst' mode
 * @throws unknown dataFormat
 */
function convertConv2DDataFormat(dataFormat) {
    switch (dataFormat) {
        case 'NHWC':
            return 'channelsLast';
        case 'NCHW':
            return 'channelsFirst';
        default:
            throw new Error(`Unknown dataFormat ${dataFormat}`);
    }
}
/**
 * Check validity of pad when using dimRoundingMode.
 * @param opDesc A string of op description
 * @param pad The type of padding algorithm.
 *   - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *   - `valid` output will be smaller than input if filter is larger
 *       than 1x1.
 *   - For more info, see this guide:
 *     [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
 *          https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 * @throws unknown padding parameter
 */
function checkPadOnDimRoundingMode(opDesc, pad, dimRoundingMode) {
    // Only integer (or explicit integer-array) padding is compatible with a
    // rounding mode; nothing to check otherwise.
    if (dimRoundingMode == null) {
        return;
    }
    const mustBeInt = (got) => `Error in ${opDesc}: pad must be an integer when using ` +
        `dimRoundingMode ${dimRoundingMode} but got pad ${got}.`;
    if (typeof pad === 'string') {
        throw Error(mustBeInt(pad));
    }
    if (typeof pad === 'number') {
        assert(isInt(pad), () => mustBeInt(pad));
        return;
    }
    if (typeof pad === 'object') {
        // Explicit padding: every entry of every [before, after] pair must be
        // an integer.
        pad.forEach(pair => {
            pair.forEach(value => {
                assert(isInt(value), () => mustBeInt(value));
            });
        });
        return;
    }
    throw Error(`Error in ${opDesc}: Unknown padding parameter: ${pad}`);
}
11777
11778 /**
11779 * @license
11780 * Copyright 2020 Google LLC. All Rights Reserved.
11781 * Licensed under the Apache License, Version 2.0 (the "License");
11782 * you may not use this file except in compliance with the License.
11783 * You may obtain a copy of the License at
11784 *
11785 * http://www.apache.org/licenses/LICENSE-2.0
11786 *
11787 * Unless required by applicable law or agreed to in writing, software
11788 * distributed under the License is distributed on an "AS IS" BASIS,
11789 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11790 * See the License for the specific language governing permissions and
11791 * limitations under the License.
11792 * =============================================================================
11793 */
/**
 * Reshapes a `tf.Tensor` to a given shape.
 *
 * Given an input tensor, returns a new tensor with the same values as the
 * input tensor with shape `shape`.
 *
 * If one component of shape is the special value -1, the size of that
 * dimension is computed so that the total size remains constant. In
 * particular, a shape of [-1] flattens into 1-D. At most one component of
 * shape can be -1.
 *
 * If shape is 1-D or higher, then the operation returns a tensor with shape
 * shape filled with the values of tensor. In this case, the number of
 * elements implied by shape must be the same as the number of elements in
 * tensor.
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 3, 4]);
 * x.reshape([2, 2]).print();
 * ```
 *
 * @param x The input tensor to be reshaped.
 * @param shape An array of integers defining the output tensor shape.
 *
 * @doc {heading: 'Tensors', subheading: 'Transformations'}
 */
function reshape_(x, shape) {
    // 'string_or_numeric' lets string tensors be reshaped too.
    const $x = convertToTensor(x, 'x', 'reshape', 'string_or_numeric');
    return ENGINE.runKernel(Reshape, { x: $x }, { shape });
}
const reshape = op({ reshape_ });
11827
11828 /**
11829 * @license
11830 * Copyright 2020 Google LLC. All Rights Reserved.
11831 * Licensed under the Apache License, Version 2.0 (the "License");
11832 * you may not use this file except in compliance with the License.
11833 * You may obtain a copy of the License at
11834 *
11835 * http://www.apache.org/licenses/LICENSE-2.0
11836 *
11837 * Unless required by applicable law or agreed to in writing, software
11838 * distributed under the License is distributed on an "AS IS" BASIS,
11839 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11840 * See the License for the specific language governing permissions and
11841 * limitations under the License.
11842 * =============================================================================
11843 */
/**
 * Computes the 2D average pooling of an image.
 *
 * @param x The input tensor, of rank 4 or rank 3 of shape
 *     `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
 * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
 *     `filterSize` is a single number, then `filterHeight == filterWidth`.
 * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
 *     `strides` is a single number, then `strideHeight == strideWidth`.
 * @param pad The type of padding algorithm:
 *    - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *    - `valid`: output will be smaller than input if filter is larger
 *       than 1x1.
 *    - For more info, see this guide:
 *     [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
 *         https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 */
function avgPool_(x, filterSize, strides, pad, dimRoundingMode) {
    const $x = convertToTensor(x, 'x', 'avgPool', 'float32');
    const dilations = 1;
    assert(
        eitherStridesOrDilationsAreOne(strides, dilations),
        () => 'Error in avgPool: Either strides or dilations must be 1. ' +
            `Got strides ${strides} and dilations '${dilations}'`);
    // Promote a rank-3 input to a single-image batch.
    const reshapedTo4D = $x.rank === 3;
    const x4D = reshapedTo4D ?
        reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]) :
        $x;
    assert(
        x4D.rank === 4,
        () => `Error in avgPool: x must be rank 4 but got rank ${x4D.rank}.`);
    checkPadOnDimRoundingMode('avgPool', pad, dimRoundingMode);
    const attrs = { filterSize, strides, pad, dimRoundingMode };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    const res = cast(ENGINE.runKernel(AvgPool, { x: x4D }, attrs), $x.dtype);
    // Undo the batch promotion so callers get back the rank they passed in.
    if (reshapedTo4D) {
        return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
    }
    return res;
}
const avgPool = op({ avgPool_ });
11888
11889 /**
11890 * @license
11891 * Copyright 2020 Google LLC. All Rights Reserved.
11892 * Licensed under the Apache License, Version 2.0 (the "License");
11893 * you may not use this file except in compliance with the License.
11894 * You may obtain a copy of the License at
11895 *
11896 * http://www.apache.org/licenses/LICENSE-2.0
11897 *
11898 * Unless required by applicable law or agreed to in writing, software
11899 * distributed under the License is distributed on an "AS IS" BASIS,
11900 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11901 * See the License for the specific language governing permissions and
11902 * limitations under the License.
11903 * =============================================================================
11904 */
/**
 * Computes the 3D average pooling.
 *
 * ```js
 * const x = tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]);
 * const result = tf.avgPool3d(x, 2, 1, 'valid');
 * result.print();
 * ```
 *
 * @param x The input tensor, of rank 5 or rank 4 of shape
 *     `[batch, depth, height, width, inChannels]`.
 * @param filterSize The filter size:
 *     `[filterDepth, filterHeight, filterWidth]`.
 *     If `filterSize` is a single number,
 *     then `filterDepth == filterHeight == filterWidth`.
 * @param strides The strides of the pooling:
 *     `[strideDepth, strideHeight, strideWidth]`.
 *     If `strides` is a single number,
 *     then `strideDepth == strideHeight == strideWidth`.
 * @param pad The type of padding algorithm.
 *    - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *    - `valid`: output will be smaller than input if filter is larger
 *       than 1x1x1.
 *    - For more info, see this guide:
 *     [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
 *          https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 * @param dataFormat An optional string from: "NDHWC", "NCDHW". Defaults to
 *     "NDHWC". Specify the data format of the input and output data. With the
 *     default format "NDHWC", the data is stored in the order of: [batch,
 *     depth, height, width, channels]. Only "NDHWC" is currently supported.
 *
 * @doc {heading: 'Operations', subheading: 'Convolution'}
 */
function avgPool3d_(x, filterSize, strides, pad, dimRoundingMode, dataFormat = 'NDHWC') {
    const $x = convertToTensor(x, 'x', 'avgPool3d', 'float32');
    // Promote a rank-4 input to a single-volume batch.
    const reshapedTo5D = $x.rank === 4;
    const x5D = reshapedTo5D ?
        reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]]) :
        $x;
    assert(
        x5D.rank === 5,
        () => `Error in avgPool3d: x must be rank 5 but got rank ${x5D.rank}.`);
    assert(
        dataFormat === 'NDHWC',
        () => `Error in avgPool3d: Only NDHWC is currently supported, ` +
            `but got dataFormat of ${dataFormat}`);
    checkPadOnDimRoundingMode('avgPool3d', pad, dimRoundingMode);
    const attrs = { filterSize, strides, pad, dimRoundingMode, dataFormat };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    const res =
        cast(ENGINE.runKernel(AvgPool3D, { x: x5D }, attrs), x5D.dtype);
    // Undo the batch promotion so callers get back the rank they passed in.
    if (reshapedTo5D) {
        return reshape(
            res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
    }
    return res;
}
const avgPool3d = op({ avgPool3d_ });
11964
11965 /**
11966 * @license
11967 * Copyright 2020 Google LLC. All Rights Reserved.
11968 * Licensed under the Apache License, Version 2.0 (the "License");
11969 * you may not use this file except in compliance with the License.
11970 * You may obtain a copy of the License at
11971 *
11972 * http://www.apache.org/licenses/LICENSE-2.0
11973 *
11974 * Unless required by applicable law or agreed to in writing, software
11975 * distributed under the License is distributed on an "AS IS" BASIS,
11976 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11977 * See the License for the specific language governing permissions and
11978 * limitations under the License.
11979 * =============================================================================
11980 */
/**
 * Concatenates a list of `tf.Tensor`s along a given axis.
 *
 * The tensors ranks and types must match, and their sizes must match in all
 * dimensions except `axis`.
 *
 * Also available are stricter rank-specific methods that assert that
 * `tensors` are of the given rank:
 *   - `tf.concat1d`
 *   - `tf.concat2d`
 *   - `tf.concat3d`
 *   - `tf.concat4d`
 *
 * Except `tf.concat1d` (which does not have axis param), all methods have
 * same signature as this method.
 *
 * ```js
 * const a = tf.tensor1d([1, 2]);
 * const b = tf.tensor1d([3, 4]);
 * a.concat(b).print();  // or a.concat(b)
 * ```
 *
 * ```js
 * const a = tf.tensor1d([1, 2]);
 * const b = tf.tensor1d([3, 4]);
 * const c = tf.tensor1d([5, 6]);
 * tf.concat([a, b, c]).print();
 * ```
 *
 * ```js
 * const a = tf.tensor2d([[1, 2], [10, 20]]);
 * const b = tf.tensor2d([[3, 4], [30, 40]]);
 * const axis = 1;
 * tf.concat([a, b], axis).print();
 * ```
 * @param tensors A list of tensors to concatenate.
 * @param axis The axis to concate along. Defaults to 0 (the first dim).
 *
 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
 */
function concat_(tensors, axis = 0) {
    assert(tensors.length >= 1, () => 'Pass at least one tensor to concat');
    const $tensors =
        convertToTensorArray(tensors, 'tensors', 'concat', 'string_or_numeric');
    // complex64 tensors may only be concatenated with other complex64 tensors.
    if ($tensors[0].dtype === 'complex64') {
        for (const tensor of $tensors) {
            if (tensor.dtype !== 'complex64') {
                throw new Error(`Cannot concatenate complex64 tensors with a tensor
          with dtype ${tensor.dtype}. `);
            }
        }
    }
    // Concatenating a single tensor is just a copy.
    if ($tensors.length === 1) {
        return clone($tensors[0]);
    }
    return ENGINE.runKernel(Concat, $tensors, { axis });
}
const concat = op({ concat_ });
12040
12041 /**
12042 * @license
12043 * Copyright 2018 Google LLC. All Rights Reserved.
12044 * Licensed under the Apache License, Version 2.0 (the "License");
12045 * you may not use this file except in compliance with the License.
12046 * You may obtain a copy of the License at
12047 *
12048 * http://www.apache.org/licenses/LICENSE-2.0
12049 *
12050 * Unless required by applicable law or agreed to in writing, software
12051 * distributed under the License is distributed on an "AS IS" BASIS,
12052 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12053 * See the License for the specific language governing permissions and
12054 * limitations under the License.
12055 * =============================================================================
12056 */
12057 /**
12058 * Computes sigmoid element-wise, `1 / (1 + exp(-x))`
12059 *
12060 * ```js
12061 * const x = tf.tensor1d([0, -1, 2, -3]);
12062 *
12063 * x.sigmoid().print(); // or tf.sigmoid(x)
12064 * ```
12065 * @param x The input tensor.
12066 *
12067 * @doc {heading: 'Operations', subheading: 'Basic math'}
12068 */
12069 function sigmoid_(x) {
12070 const $x = convertToTensor(x, 'x', 'sigmoid', 'float32');
12071 const inputs = { x: $x };
12072 return ENGINE.runKernel(Sigmoid, inputs);
12073 }
12074 const sigmoid = op({ sigmoid_ });
12075
12076 /**
12077 * @license
12078 * Copyright 2018 Google LLC. All Rights Reserved.
12079 * Licensed under the Apache License, Version 2.0 (the "License");
12080 * you may not use this file except in compliance with the License.
12081 * You may obtain a copy of the License at
12082 *
12083 * http://www.apache.org/licenses/LICENSE-2.0
12084 *
12085 * Unless required by applicable law or agreed to in writing, software
12086 * distributed under the License is distributed on an "AS IS" BASIS,
12087 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12088 * See the License for the specific language governing permissions and
12089 * limitations under the License.
12090 * =============================================================================
12091 */
12092 /**
12093 * Extracts a slice from a `tf.Tensor` starting at coordinates `begin`
12094 * and is of size `size`.
12095 *
12096 * Also available are stricter rank-specific methods with the same signature
12097 * as this method that assert that `x` is of the given rank:
12098 * - `tf.slice1d`
12099 * - `tf.slice2d`
12100 * - `tf.slice3d`
12101 * - `tf.slice4d`
12102 *
12103 * ```js
12104 * const x = tf.tensor1d([1, 2, 3, 4]);
12105 *
12106 * x.slice([1], [2]).print();
12107 * ```
12108 *
12109 * ```js
12110 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
12111 *
12112 * x.slice([1, 0], [1, 2]).print();
12113 * ```
12114 * @param x The input `tf.Tensor` to slice from.
12115 * @param begin The coordinates to start the slice from. The length can be
12116 * less than the rank of x - the rest of the axes will have implicit 0 as
12117 * start. Can also be a single number, in which case it specifies the
12118 * first axis.
12119 * @param size The size of the slice. The length can be less than the rank of
12120 * x - the rest of the axes will have implicit -1. A value of -1 requests
12121 * the rest of the dimensions in the axis. Can also be a single number,
12122 * in which case it specifies the size of the first axis.
12123 *
12124 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
12125 */
12126 function slice_(x, begin, size) {
12127 const $x = convertToTensor(x, 'x', 'slice', 'string_or_numeric');
12128 if ($x.rank === 0) {
12129 throw new Error('Slicing scalar is not possible');
12130 }
12131 const inputs = { x: $x };
12132 const attrs = { begin, size };
12133 return ENGINE.runKernel(Slice, inputs, attrs);
12134 }
12135 const slice = op({ slice_ });
12136
12137 /**
12138 * @license
12139 * Copyright 2018 Google LLC. All Rights Reserved.
12140 * Licensed under the Apache License, Version 2.0 (the "License");
12141 * you may not use this file except in compliance with the License.
12142 * You may obtain a copy of the License at
12143 *
12144 * http://www.apache.org/licenses/LICENSE-2.0
12145 *
12146 * Unless required by applicable law or agreed to in writing, software
12147 * distributed under the License is distributed on an "AS IS" BASIS,
12148 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12149 * See the License for the specific language governing permissions and
12150 * limitations under the License.
12151 * =============================================================================
12152 */
12153 /**
12154 * Computes hyperbolic tangent of the input `tf.Tensor` element-wise: `tanh(x)`
12155 *
12156 * ```js
12157 * const x = tf.tensor1d([0, 1, -1, 70]);
12158 *
12159 * x.tanh().print(); // or tf.tanh(x)
12160 * ```
12161 * @param x The input tensor.
12162 *
12163 * @doc {heading: 'Operations', subheading: 'Basic math'}
12164 */
12165 function tanh_(x) {
12166 const $x = convertToTensor(x, 'x', 'tanh', 'float32');
12167 const inputs = { x: $x };
12168 return ENGINE.runKernel(Tanh, inputs);
12169 }
12170 const tanh$1 = op({ tanh_ });
12171
12172 /**
12173 * @license
12174 * Copyright 2020 Google LLC. All Rights Reserved.
12175 * Licensed under the Apache License, Version 2.0 (the "License");
12176 * you may not use this file except in compliance with the License.
12177 * You may obtain a copy of the License at
12178 *
12179 * http://www.apache.org/licenses/LICENSE-2.0
12180 *
12181 * Unless required by applicable law or agreed to in writing, software
12182 * distributed under the License is distributed on an "AS IS" BASIS,
12183 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12184 * See the License for the specific language governing permissions and
12185 * limitations under the License.
12186 * =============================================================================
12187 */
12188 /**
12189 * Computes the next state and output of a BasicLSTMCell.
12190 *
12191 * Returns `[newC, newH]`.
12192 *
12193 * Derived from tf.contrib.rnn.BasicLSTMCell.
12194 *
12195 * @param forgetBias Forget bias for the cell.
12196 * @param lstmKernel The weights for the cell.
12197 * @param lstmBias The bias for the cell.
12198 * @param data The input to the cell.
12199 * @param c Previous cell state.
12200 * @param h Previous cell output.
12201 *
12202 * @doc {heading: 'Operations', subheading: 'RNN'}
12203 */
12204 function basicLSTMCell_(forgetBias, lstmKernel, lstmBias, data, c, h) {
12205 const $forgetBias = convertToTensor(forgetBias, 'forgetBias', 'basicLSTMCell');
12206 const $lstmKernel = convertToTensor(lstmKernel, 'lstmKernel', 'basicLSTMCell');
12207 const $lstmBias = convertToTensor(lstmBias, 'lstmBias', 'basicLSTMCell');
12208 const $data = convertToTensor(data, 'data', 'basicLSTMCell');
12209 const $c = convertToTensor(c, 'c', 'basicLSTMCell');
12210 const $h = convertToTensor(h, 'h', 'basicLSTMCell');
12211 const combined = concat([$data, $h], 1);
12212 const weighted = matMul(combined, $lstmKernel);
12213 const res = add$1(weighted, $lstmBias);
12214 // i = input_gate, j = new_input, f = forget_gate, o = output_gate
12215 const batchSize = res.shape[0];
12216 const sliceCols = res.shape[1] / 4;
12217 const sliceSize = [batchSize, sliceCols];
12218 const i = slice(res, [0, 0], sliceSize);
12219 const j = slice(res, [0, sliceCols], sliceSize);
12220 const f = slice(res, [0, sliceCols * 2], sliceSize);
12221 const o = slice(res, [0, sliceCols * 3], sliceSize);
12222 const newC = add$1(mul(sigmoid(i), tanh$1(j)), mul($c, sigmoid(add$1($forgetBias, f))));
12223 const newH = mul(tanh$1(newC), sigmoid(o));
12224 return [newC, newH];
12225 }
12226 const basicLSTMCell = op({ basicLSTMCell_ });
12227
12228 /**
12229 * @license
12230 * Copyright 2020 Google LLC. All Rights Reserved.
12231 * Licensed under the Apache License, Version 2.0 (the "License");
12232 * you may not use this file except in compliance with the License.
12233 * You may obtain a copy of the License at
12234 *
12235 * http://www.apache.org/licenses/LICENSE-2.0
12236 *
12237 * Unless required by applicable law or agreed to in writing, software
12238 * distributed under the License is distributed on an "AS IS" BASIS,
12239 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12240 * See the License for the specific language governing permissions and
12241 * limitations under the License.
12242 * =============================================================================
12243 */
12244 /**
12245 * This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of
12246 * shape `blockShape + [batch]`, interleaves these blocks back into the grid
12247 * defined by the spatial dimensions `[1, ..., M]`, to obtain a result with
12248 * the same rank as the input. The spatial dimensions of this intermediate
12249 * result are then optionally cropped according to `crops` to produce the
12250 * output. This is the reverse of `tf.spaceToBatchND`. See below for a precise
12251 * description.
12252 *
12253 * ```js
12254 * const x = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]);
12255 * const blockShape = [2, 2];
12256 * const crops = [[0, 0], [0, 0]];
12257 *
12258 * x.batchToSpaceND(blockShape, crops).print();
12259 * ```
12260 *
12261 * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape +
12262 * remainingShape`, where spatialShape has `M` dimensions.
12263 * @param blockShape A 1-D array. Must have shape `[M]`, all values must
12264 * be >= 1.
12265 * @param crops A 2-D array. Must have shape `[M, 2]`, all values must be >= 0.
12266 * `crops[i] = [cropStart, cropEnd]` specifies the amount to crop from input
12267 * dimension `i + 1`, which corresponds to spatial dimension `i`. It is required
12268 * that `cropStart[i] + cropEnd[i] <= blockShape[i] * inputShape[i + 1]`
12269 *
12270 * This operation is equivalent to the following steps:
12271 *
12272 * 1. Reshape `x` to `reshaped` of shape: `[blockShape[0], ...,
12273 * blockShape[M-1], batch / prod(blockShape), x.shape[1], ...,
12274 * x.shape[N-1]]`
12275 *
12276 * 2. Permute dimensions of `reshaped`to produce `permuted` of shape `[batch /
12277 * prod(blockShape),x.shape[1], blockShape[0], ..., x.shape[M],
12278 * blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]`
12279 *
12280 * 3. Reshape `permuted` to produce `reshapedPermuted` of shape `[batch /
12281 * prod(blockShape),x.shape[1] * blockShape[0], ..., x.shape[M] *
12282 * blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]`
12283 *
12284 * 4. Crop the start and end of dimensions `[1, ..., M]` of `reshapedPermuted`
12285 * according to `crops` to produce the output of shape: `[batch /
12286 * prod(blockShape),x.shape[1] * blockShape[0] - crops[0,0] - crops[0,1],
12287 * ..., x.shape[M] * blockShape[M-1] - crops[M-1,0] -
12288 * crops[M-1,1],x.shape[M+1], ..., x.shape[N-1]]`
12289 *
12290 * @doc {heading: 'Tensors', subheading: 'Transformations'}
12291 */
12292 function batchToSpaceND_(x, blockShape, crops) {
12293 const $x = convertToTensor(x, 'x', 'batchToSpaceND');
12294 const prod = blockShape.reduce((a, b) => a * b);
12295 assert($x.rank >= 1 + blockShape.length, () => `input rank is ${$x.rank} but should be > than blockShape.length ${blockShape.length}`);
12296 assert(crops.length === blockShape.length, () => `crops.length is ${crops.length} but should be equal to blockShape.length ${blockShape.length}`);
12297 assert($x.shape[0] % prod === 0, () => `input tensor batch is ${$x.shape[0]} but is not divisible by the product of ` +
12298 `the elements of blockShape ${blockShape.join(' * ')} === ${prod}`);
12299 const inputs = { x: $x };
12300 const attrs = { blockShape, crops };
12301 return ENGINE.runKernel(BatchToSpaceND, inputs, attrs);
12302 }
12303 const batchToSpaceND = op({ batchToSpaceND_ });
12304
12305 function xAs4D(x) {
12306 let x4D;
12307 if (x.rank === 0 || x.rank === 1) {
12308 x4D = reshape(x, [1, 1, 1, x.size]);
12309 }
12310 else if (x.rank === 2) {
12311 x4D = reshape(x, [1, 1, x.shape[0], x.shape[1]]);
12312 }
12313 else if (x.rank === 3) {
12314 x4D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
12315 }
12316 else {
12317 x4D = x;
12318 }
12319 return x4D;
12320 }
12321
12322 /**
12323 * @license
12324 * Copyright 2020 Google LLC. All Rights Reserved.
12325 * Licensed under the Apache License, Version 2.0 (the "License");
12326 * you may not use this file except in compliance with the License.
12327 * You may obtain a copy of the License at
12328 *
12329 * http://www.apache.org/licenses/LICENSE-2.0
12330 *
12331 * Unless required by applicable law or agreed to in writing, software
12332 * distributed under the License is distributed on an "AS IS" BASIS,
12333 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12334 * See the License for the specific language governing permissions and
12335 * limitations under the License.
12336 * =============================================================================
12337 */
12338 /**
12339 * Batch normalization.
12340 *
12341 * As described in
12342 * [http://arxiv.org/abs/1502.03167](http://arxiv.org/abs/1502.03167).
12343 *
12344 * Mean, variance, scale, and offset can be of two shapes:
12345 * - The same shape as the input.
12346 * - In the common case, the depth dimension is the last dimension of x, so
12347 * the values would be an `tf.Tensor1D` of shape [depth].
12348 *
12349 * Also available are stricter rank-specific methods with the same signature
12350 * as this method that assert that parameters passed are of given rank
12351 * - `tf.batchNorm2d`
12352 * - `tf.batchNorm3d`
12353 * - `tf.batchNorm4d`
12354 *
12355 * @param x The input Tensor.
12356 * @param mean A mean Tensor.
12357 * @param variance A variance Tensor.
12358 * @param offset An offset Tensor.
12359 * @param scale A scale Tensor.
12360 * @param varianceEpsilon A small float number to avoid dividing by 0.
12361 *
12362 * @doc {heading: 'Operations', subheading: 'Normalization'}
12363 */
12364 function batchNorm_(x, mean, variance, offset, scale, varianceEpsilon) {
12365 if (varianceEpsilon == null) {
12366 varianceEpsilon = 0.001;
12367 }
12368 const $x = convertToTensor(x, 'x', 'batchNorm');
12369 const $mean = convertToTensor(mean, 'mean', 'batchNorm');
12370 const $variance = convertToTensor(variance, 'variance', 'batchNorm');
12371 let $scale;
12372 if (scale != null) {
12373 $scale = convertToTensor(scale, 'scale', 'batchNorm');
12374 }
12375 let $offset;
12376 if (offset != null) {
12377 $offset = convertToTensor(offset, 'offset', 'batchNorm');
12378 }
12379 assert($mean.rank === $variance.rank, () => 'Batch normalization gradient requires mean and variance to have ' +
12380 'equal ranks.');
12381 assert($offset == null || $mean.rank === $offset.rank, () => 'Batch normalization gradient requires mean and offset to have ' +
12382 'equal ranks.');
12383 assert($scale == null || $mean.rank === $scale.rank, () => 'Batch normalization gradient requires mean and scale to have ' +
12384 'equal ranks.');
12385 const x4D = xAs4D($x);
12386 const inputs = {
12387 x: x4D,
12388 scale: $scale,
12389 offset: $offset,
12390 mean: $mean,
12391 variance: $variance
12392 };
12393 const attrs = { varianceEpsilon };
12394 // tslint:disable-next-line: no-unnecessary-type-assertion
12395 const res = ENGINE.runKernel(FusedBatchNorm, inputs, attrs);
12396 return reshape(res, $x.shape);
12397 }
12398 const batchNorm = op({ batchNorm_ });
12399
12400 /**
12401 * Batch normalization, strictly for 2D. For the more relaxed version, see
12402 * `tf.batchNorm`.
12403 *
12404 * @param x The input Tensor.
12405 * @param mean A mean Tensor.
12406 * @param variance A variance Tensor.
12407 * @param offset An offset Tensor.
12408 * @param scale A scale Tensor.
12409 * @param varianceEpsilon A small float number to avoid dividing by 0.
12410 */
12411 function batchNorm2d_(x, mean, variance, offset, scale, varianceEpsilon) {
12412 const $x = convertToTensor(x, 'x', 'batchNorm');
12413 const $mean = convertToTensor(mean, 'mean', 'batchNorm');
12414 const $variance = convertToTensor(variance, 'variance', 'batchNorm');
12415 let $scale;
12416 if (scale != null) {
12417 $scale = convertToTensor(scale, 'scale', 'batchNorm');
12418 }
12419 let $offset;
12420 if (offset != null) {
12421 $offset = convertToTensor(offset, 'offset', 'batchNorm');
12422 }
12423 assert($x.rank === 2, () => `Error in batchNorm2D: x must be rank 2 but got rank ` +
12424 `${$x.rank}.`);
12425 assert($mean.rank === 2 || $mean.rank === 1, () => `Error in batchNorm2D: mean must be rank 2 or rank 1 but ` +
12426 `got rank ${$mean.rank}.`);
12427 assert($variance.rank === 2 || $variance.rank === 1, () => `Error in batchNorm2D: variance must be rank 2 or rank 1 ` +
12428 `but got rank ${$variance.rank}.`);
12429 if ($scale != null) {
12430 assert($scale.rank === 2 || $scale.rank === 1, () => `Error in batchNorm2D: scale must be rank 2 or rank 1 ` +
12431 `but got rank ${$scale.rank}.`);
12432 }
12433 if ($offset != null) {
12434 assert($offset.rank === 2 || $offset.rank === 1, () => `Error in batchNorm2D: offset must be rank 2 or rank 1 ` +
12435 `but got rank ${$offset.rank}.`);
12436 }
12437 return batchNorm($x, $mean, $variance, $offset, $scale, varianceEpsilon);
12438 }
12439 const batchNorm2d = op({ batchNorm2d_ });
12440
12441 /**
12442 * Batch normalization, strictly for 3D. For the more relaxed version, see
12443 * `tf.batchNorm`.
12444 *
12445 * @param x The input Tensor.
12446 * @param mean A mean Tensor.
12447 * @param variance A variance Tensor.
12448 * @param offset An offset Tensor.
12449 * @param scale A scale Tensor.
12450 * @param varianceEpsilon A small float number to avoid dividing by 0.
12451 */
12452 function batchNorm3d_(x, mean, variance, offset, scale, varianceEpsilon) {
12453 const $x = convertToTensor(x, 'x', 'batchNorm');
12454 const $mean = convertToTensor(mean, 'mean', 'batchNorm');
12455 const $variance = convertToTensor(variance, 'variance', 'batchNorm');
12456 let $scale;
12457 if (scale != null) {
12458 $scale = convertToTensor(scale, 'scale', 'batchNorm');
12459 }
12460 let $offset;
12461 if (offset != null) {
12462 $offset = convertToTensor(offset, 'offset', 'batchNorm');
12463 }
12464 assert($x.rank === 3, () => `Error in batchNorm3D: x must be rank 3 but got rank ` +
12465 `${$x.rank}.`);
12466 assert($mean.rank === 3 || $mean.rank === 1, () => `Error in batchNorm3D: mean must be rank 3 or rank 1 but ` +
12467 `got rank ${$mean.rank}.`);
12468 assert($variance.rank === 3 || $variance.rank === 1, () => `Error in batchNorm3D: variance must be rank 3 or rank 1 ` +
12469 `but got rank ${$variance.rank}.`);
12470 if ($scale != null) {
12471 assert($scale.rank === 3 || $scale.rank === 1, () => `Error in batchNorm3D: scale must be rank 3 or rank 1 ` +
12472 `but got rank ${$scale.rank}.`);
12473 }
12474 if ($offset != null) {
12475 assert($offset.rank === 3 || $offset.rank === 1, () => `Error in batchNorm3D: offset must be rank 3 or rank 1 ` +
12476 `but got rank ${$offset.rank}.`);
12477 }
12478 return batchNorm($x, $mean, $variance, $offset, $scale, varianceEpsilon);
12479 }
12480 const batchNorm3d = op({ batchNorm3d_ });
12481
12482 /**
12483 * Batch normalization, strictly for 4D. For the more relaxed version, see
12484 * `tf.batchNorm`.
12485 *
12486 * @param x The input Tensor.
12487 * @param mean A mean Tensor.
12488 * @param variance A variance Tensor.
12489 * @param offset An offset Tensor.
12490 * @param scale A scale Tensor.
12491 * @param varianceEpsilon A small float number to avoid dividing by 0.
12492 */
12493 function batchNorm4d_(x, mean, variance, offset, scale, varianceEpsilon) {
12494 const $x = convertToTensor(x, 'x', 'batchNorm');
12495 const $mean = convertToTensor(mean, 'mean', 'batchNorm');
12496 const $variance = convertToTensor(variance, 'variance', 'batchNorm');
12497 let $scale;
12498 if (scale != null) {
12499 $scale = convertToTensor(scale, 'scale', 'batchNorm');
12500 }
12501 let $offset;
12502 if (offset != null) {
12503 $offset = convertToTensor(offset, 'offset', 'batchNorm');
12504 }
12505 assert($x.rank === 4, () => `Error in batchNorm4D: x must be rank 4 but got rank ` +
12506 `${$x.rank}.`);
12507 assert($mean.rank === 4 || $mean.rank === 1, () => `Error in batchNorm4D: mean must be rank 4 or rank 1 but ` +
12508 `got rank ${$mean.rank}.`);
12509 assert($variance.rank === 4 || $variance.rank === 1, () => `Error in batchNorm4D: variance must be rank 4 or rank 1 ` +
12510 `but got rank ${$variance.rank}.`);
12511 if ($scale != null) {
12512 assert($scale.rank === 4 || $scale.rank === 1, () => `Error in batchNorm4D: scale must be rank 4 or rank 1 ` +
12513 `but got rank ${$scale.rank}.`);
12514 }
12515 if ($offset != null) {
12516 assert($offset.rank === 4 || $offset.rank === 1, () => `Error in batchNorm4D: offset must be rank 4 or rank 1 ` +
12517 `but got rank ${$offset.rank}.`);
12518 }
12519 return batchNorm($x, $mean, $variance, $offset, $scale, varianceEpsilon);
12520 }
12521 const batchNorm4d = op({ batchNorm4d_ });
12522
12523 /**
12524 * @license
12525 * Copyright 2020 Google LLC. All Rights Reserved.
12526 * Licensed under the Apache License, Version 2.0 (the "License");
12527 * you may not use this file except in compliance with the License.
12528 * You may obtain a copy of the License at
12529 *
12530 * http://www.apache.org/licenses/LICENSE-2.0
12531 *
12532 * Unless required by applicable law or agreed to in writing, software
12533 * distributed under the License is distributed on an "AS IS" BASIS,
12534 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12535 * See the License for the specific language governing permissions and
12536 * limitations under the License.
12537 * =============================================================================
12538 */
12539 /**
12540 * Outputs a vector with length `size` and the same dtype as `weights`.
12541 *
12542 * If `weights` are empty, then index `i` stores the number of times the value
12543 * `i` is counted in `x`. If `weights` are non-empty, then index `i` stores the
12544 * sum of the value in `weights` at each index where the corresponding value in
12545 * `x` is `i`.
12546 *
12547 * Values in `x` outside of the range [0, size) are ignored.
12548 *
12549 * @param x The input int tensor, rank 1.
12550 * @param weights The weights tensor, must have the same shape as x, or a
12551 * length-0 Tensor, in which case it acts as all weights equal to 1.
12552 * @param size Non-negative integer.
12553 *
12554 * @doc {heading: 'Operations', subheading: 'Reduction'}
12555 */
12556 function bincount_(x, weights, size) {
12557 const $x = convertToTensor(x, 'x', 'bincount');
12558 const $weights = convertToTensor(weights, 'weights', 'bincount');
12559 assert($x.dtype === 'int32', () => `Error in bincount: input ` +
12560 `dtype must be int32, but got ${$x.dtype}`);
12561 assert(size >= 0, () => `size must be non-negative, but got ${size}.`);
12562 assert($weights.size === $x.size || $weights.size === 0, () => `Error in bincount: weights must have the same size as input or` +
12563 `0-length, but got input shape: ${$x.shape}, weights shape: ` +
12564 `${$weights.shape}.`);
12565 const inputs = { x: $x, weights: $weights };
12566 const attrs = { size };
12567 return ENGINE.runKernel(Bincount, inputs, attrs);
12568 }
12569 const bincount = op({ bincount_ });
12570
12571 /**
12572 * @license
12573 * Copyright 2021 Google LLC. All Rights Reserved.
12574 * Licensed under the Apache License, Version 2.0 (the "License");
12575 * you may not use this file except in compliance with the License.
12576 * You may obtain a copy of the License at
12577 *
12578 * http://www.apache.org/licenses/LICENSE-2.0
12579 *
12580 * Unless required by applicable law or agreed to in writing, software
12581 * distributed under the License is distributed on an "AS IS" BASIS,
12582 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12583 * See the License for the specific language governing permissions and
12584 * limitations under the License.
12585 * =============================================================================
12586 */
12587 /**
12588 * Return the shape of s0 op s1 with broadcast.
12589 *
12590 * compute r0, the broadcasted shape as a tensor.
12591 * s0, s1 and r0 are all integer vectors.
12592 *
12593 * This function returns the shape of the result of an operation between
12594 * two tensors of size s0 and s1 performed with broadcast.
12595 *
12596 * @param s0 A tensor representing a shape
12597 * @param s1 A tensor representing a shape
12598 *
12599 * @doc {heading: 'Tensors', subheading: 'Transformations'}
12600 */
12601 function broadcastArgs_(s0, s1) {
12602 const shape1Input = convertToTensor(s0, 's0', 'broadcastArgs', 'int32');
12603 const shape2Input = convertToTensor(s1, 's1', 'broadcastArgs', 'int32');
12604 if (shape1Input.rank !== 1) {
12605 throw new Error('broadcastArgs(): first input must be a vector (rank=1). ' +
12606 `Has rank ${shape1Input.rank}`);
12607 }
12608 if (shape2Input.rank !== 1) {
12609 throw new Error('broadcastArgs(): second input must be a vector (rank=1). ' +
12610 `Has rank ${shape2Input.rank}`);
12611 }
12612 const inputs = { s0: shape1Input, s1: shape2Input };
12613 return ENGINE.runKernel(BroadcastArgs, inputs);
12614 }
12615 const broadcastArgs = op({ broadcastArgs_ });
12616
12617 /**
12618 * @license
12619 * Copyright 2020 Google LLC. All Rights Reserved.
12620 * Licensed under the Apache License, Version 2.0 (the "License");
12621 * you may not use this file except in compliance with the License.
12622 * You may obtain a copy of the License at
12623 *
12624 * http://www.apache.org/licenses/LICENSE-2.0
12625 *
12626 * Unless required by applicable law or agreed to in writing, software
12627 * distributed under the License is distributed on an "AS IS" BASIS,
12628 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12629 * See the License for the specific language governing permissions and
12630 * limitations under the License.
12631 * =============================================================================
12632 */
12633 /**
12634 * Broadcast an array to a compatible shape NumPy-style.
12635 *
12636 * The tensor's shape is compared to the broadcast shape from end to beginning.
12637 * Ones are prepended to the tensor's shape until is has the same length as
12638 * the broadcast shape. If input.shape[i]==shape[i], the (i+1)-th axis is
12639 * already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then
12640 * the input tensor is tiled N times along that axis (using tf.tile).
12641 *
12642 * @param input The tensor that is to be broadcasted.
12643 * @param shape The input is to be broadcast to this shape.
12644 *
12645 * @doc {heading: 'Tensors', subheading: 'Transformations'}
12646 */
12647 function broadcastTo_(x, shape) {
12648 let input = convertToTensor(x, 'broadcastTo', 'x');
12649 const xShape = input.shape;
12650 if (shape.some(d => !(d > 0) || d % 1 !== 0)) {
12651 throw new Error(`broadcastTo(): Invalid broadcast shape [${shape}].`);
12652 }
12653 if (shape.length < input.rank) {
12654 throw new Error(`broadcastTo(): shape.length=${shape.length} < input.rank=${input.rank}.`);
12655 }
12656 if (shape.length > input.rank) {
12657 const newShape = input.shape.slice();
12658 while (newShape.length < shape.length) {
12659 newShape.unshift(1);
12660 }
12661 input = reshape(input, newShape);
12662 }
12663 const inputShape = input.shape;
12664 const reps = Array.from(shape);
12665 for (let i = shape.length - 1; i >= 0; i--) {
12666 if (inputShape[i] === shape[i]) {
12667 reps[i] = 1;
12668 }
12669 else if (input.shape[i] !== 1) {
12670 throw new Error(`broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].`);
12671 }
12672 }
12673 const axes = reps.map((n, i) => n > 1 ? i : -1).filter(i => i >= 0);
12674 if (axes.length === 0) {
12675 return clone(input);
12676 }
12677 // TODO call broadcastTo kernel directly once backends implement broadcstTo
12678 const inputs = { x: input };
12679 const attrs = { reps };
12680 return ENGINE.runKernel(Tile, inputs, attrs);
12681 }
12682 const broadcastTo = op({ broadcastTo_ });
12683
12684 /**
12685 * @license
12686 * Copyright 2018 Google LLC. All Rights Reserved.
12687 * Licensed under the Apache License, Version 2.0 (the "License");
12688 * you may not use this file except in compliance with the License.
12689 * You may obtain a copy of the License at
12690 *
12691 * http://www.apache.org/licenses/LICENSE-2.0
12692 *
12693 * Unless required by applicable law or agreed to in writing, software
12694 * distributed under the License is distributed on an "AS IS" BASIS,
12695 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12696 * See the License for the specific language governing permissions and
12697 * limitations under the License.
12698 * =============================================================================
12699 */
12700 /**
12701 * Computes ceiling of input `tf.Tensor` element-wise: `ceil(x)`
12702 *
12703 * ```js
12704 * const x = tf.tensor1d([.6, 1.1, -3.3]);
12705 *
12706 * x.ceil().print(); // or tf.ceil(x)
12707 * ```
12708 * @param x The input Tensor.
12709 *
12710 * @doc {heading: 'Operations', subheading: 'Basic math'}
12711 */
12712 function ceil_(x) {
12713 const $x = convertToTensor(x, 'x', 'ceil', 'float32');
12714 const inputs = { x: $x };
12715 return ENGINE.runKernel(Ceil, inputs);
12716 }
12717 const ceil = op({ ceil_ });
12718
12719 /**
12720 * @license
12721 * Copyright 2018 Google LLC. All Rights Reserved.
12722 * Licensed under the Apache License, Version 2.0 (the "License");
12723 * you may not use this file except in compliance with the License.
12724 * You may obtain a copy of the License at
12725 *
12726 * http://www.apache.org/licenses/LICENSE-2.0
12727 *
12728 * Unless required by applicable law or agreed to in writing, software
12729 * distributed under the License is distributed on an "AS IS" BASIS,
12730 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12731 * See the License for the specific language governing permissions and
12732 * limitations under the License.
12733 * =============================================================================
12734 */
12735 /**
12736 * Clips values element-wise. `max(min(x, clipValueMax), clipValueMin)`
12737 *
12738 * ```js
12739 * const x = tf.tensor1d([-1, 2, -3, 4]);
12740 *
12741 * x.clipByValue(-2, 3).print(); // or tf.clipByValue(x, -2, 3)
12742 * ```
12743 * @param x The input tensor.
12744 * @param clipValueMin Lower-bound of range to be clipped to.
12745 * @param clipValueMax Upper-bound of range to be clipped to.
12746 *
12747 * @doc {heading: 'Operations', subheading: 'Basic math'}
12748 */
12749 function clipByValue_(x, clipValueMin, clipValueMax) {
12750 const $x = convertToTensor(x, 'x', 'clipByValue');
12751 assert((clipValueMin <= clipValueMax), () => `Error in clip: min (${clipValueMin}) must be ` +
12752 `less than or equal to max (${clipValueMax}).`);
12753 const inputs = { x: $x };
12754 const attrs = { clipValueMin, clipValueMax };
12755 return ENGINE.runKernel(ClipByValue, inputs, attrs);
12756 }
12757 const clipByValue = op({ clipByValue_ });
12758
12759 /**
12760 * Concatenates a list of`tf.Tensor1D`s along an axis. See `concat` for details.
12761 *
12762 * For example, if:
12763 * A: shape(3) = |r1, g1, b1|
12764 * B: shape(2) = |r2, g2|
12765 * C = tf.concat1d([A, B]) == |r1, g1, b1, r2, g2|
12766 *
12767 * @param tensors A list of`tf.Tensor`s to concatenate.
12768 * @return The concatenated array.
12769 */
12770 function concat1d_(tensors) {
12771 return concat(tensors, 0 /* axis */);
12772 }
12773 const concat1d = op({ concat1d_ });
12774
12775 /**
12776 * Concatenates a list of`tf.Tensor2D`s along an axis. See `concat` for details.
12777 *
12778 * For example, if:
12779 * A: shape(2, 3) = | r1, g1, b1 |
12780 * | r2, g2, b2 |
12781 *
12782 * B: shape(2, 3) = | r3, g3, b3 |
12783 * | r4, g4, b4 |
12784 *
12785 * C = tf.concat2d([A, B], axis)
12786 *
12787 * if axis = 0:
12788 * C: shape(4, 3) = | r1, g1, b1 |
12789 * | r2, g2, b2 |
12790 * | r3, g3, b3 |
12791 * | r4, g4, b4 |
12792 *
12793 * if axis = 1:
12794 * C = shape(2, 6) = | r1, g1, b1, r3, g3, b3 |
12795 * | r2, g2, b2, r4, g4, b4 |
12796 *
12797 *
12798 * @param tensors A list of `tf.Tensor`s to concatenate.
12799 * @param axis The axis to concatenate along.
12800 * @return The concatenated array.
12801 */
12802 function concat2d_(tensors, axis) {
12803 return concat(tensors, axis);
12804 }
12805 const concat2d = op({ concat2d_ });
12806
12807 /**
12808 * Concatenates a list of `tf.Tensor3D`s along an axis.
12809 * See `concat` for details.
12810 *
12811 * For example, if:
12812 * A: shape(2, 1, 3) = | r1, g1, b1 |
12813 * | r2, g2, b2 |
12814 *
12815 * B: shape(2, 1, 3) = | r3, g3, b3 |
12816 * | r4, g4, b4 |
12817 *
12818 * C = tf.concat3d([A, B], axis)
12819 *
12820 * if axis = 0:
12821 * C: shape(4, 1, 3) = | r1, g1, b1 |
12822 * | r2, g2, b2 |
12823 * | r3, g3, b3 |
12824 * | r4, g4, b4 |
12825 *
12826 * if axis = 1:
12827 * C: shape(2, 2, 3) = | r1, g1, b1, r3, g3, b3 |
12828 * | r2, g2, b2, r4, g4, b4 |
12829 *
12830 * if axis = 2:
12831 * C = shape(2, 1, 6) = | r1, g1, b1, r3, g3, b3 |
12832 * | r2, g2, b2, r4, g4, b4 |
12833 *
 * @param tensors A list of `tf.Tensor`s to concatenate.
 * @param axis The axis to concatenate along.
12836 * @return The concatenated array.
12837 */
12838 function concat3d_(tensors, axis) {
12839 return concat(tensors, axis);
12840 }
12841 const concat3d = op({ concat3d_ });
12842
12843 /**
12844 * Concatenates a list of `tf.Tensor4D`s along an axis.
12845 * See `concat` for details.
12846 *
12847 * @param tensors A list of `tf.Tensor`s to concatenate.
 * @param axis The axis to concatenate along.
12849 * @return The concatenated array.
12850 */
12851 function concat4d_(tensors, axis) {
12852 return concat(tensors, axis);
12853 }
12854 const concat4d = op({ concat4d_ });
12855
12856 /**
12857 * @license
12858 * Copyright 2020 Google LLC. All Rights Reserved.
12859 * Licensed under the Apache License, Version 2.0 (the "License");
12860 * you may not use this file except in compliance with the License.
12861 * You may obtain a copy of the License at
12862 *
12863 * http://www.apache.org/licenses/LICENSE-2.0
12864 *
12865 * Unless required by applicable law or agreed to in writing, software
12866 * distributed under the License is distributed on an "AS IS" BASIS,
12867 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12868 * See the License for the specific language governing permissions and
12869 * limitations under the License.
12870 * =============================================================================
12871 */
12872 /**
12873 * Computes a 2D convolution over the input x.
12874 *
12875 * @param x The input tensor, of rank 4 or rank 3, of shape
12876 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
12877 * assumed.
12878 * @param filter The filter, rank 4, of shape
12879 * `[filterHeight, filterWidth, inDepth, outDepth]`.
12880 * @param strides The strides of the convolution: `[strideHeight,
12881 * strideWidth]`.
12882 * @param pad The type of padding algorithm.
12883 * - `same` and stride 1: output will be of same size as input,
12884 * regardless of filter size.
12885 * - `valid`: output will be smaller than input if filter is larger
12886 * than 1x1.
12887 * - For more info, see this guide:
12888 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
12889 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
12890 * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
12891 * "NHWC". Specify the data format of the input and output data. With the
12892 * default format "NHWC", the data is stored in the order of: [batch,
12893 * height, width, channels].
12894 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
12895 * in which we sample input values across the height and width dimensions
12896 * in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single
12897 * number, then `dilationHeight == dilationWidth`. If it is greater than
12898 * 1, then all values of `strides` must be 1.
12899 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
12900 * provided, it will default to truncate.
12901 *
12902 * @doc {heading: 'Operations', subheading: 'Convolution'}
12903 */
12904 function conv2d_(x, filter, strides, pad, dataFormat = 'NHWC', dilations = [1, 1], dimRoundingMode) {
12905 const $x = convertToTensor(x, 'x', 'conv2d', 'float32');
12906 const $filter = convertToTensor(filter, 'filter', 'conv2d', 'float32');
12907 let x4D = $x;
12908 let reshapedTo4D = false;
12909 if ($x.rank === 3) {
12910 reshapedTo4D = true;
12911 x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
12912 }
12913 assert(x4D.rank === 4, () => `Error in conv2d: input must be rank 4, but got rank ${x4D.rank}.`);
12914 assert($filter.rank === 4, () => `Error in conv2d: filter must be rank 4, but got rank ` +
12915 `${$filter.rank}.`);
12916 checkPadOnDimRoundingMode('conv2d', pad, dimRoundingMode);
12917 const inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
12918 assert(inDepth === $filter.shape[2], () => `Error in conv2d: depth of input (${inDepth}) must match ` +
12919 `input depth for filter ${$filter.shape[2]}.`);
12920 assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in conv2D: Either strides or dilations must be 1. ' +
12921 `Got strides ${strides} and dilations '${dilations}'`);
12922 const inputs = { x: x4D, filter: $filter };
12923 const attrs = { strides, pad, dataFormat, dilations, dimRoundingMode };
12924 // tslint:disable-next-line: no-unnecessary-type-assertion
12925 const res = ENGINE.runKernel(Conv2D, inputs, attrs);
12926 if (reshapedTo4D) {
12927 return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
12928 }
12929 return res;
12930 }
12931 const conv2d = op({ conv2d_ });
12932
12933 /**
12934 * Computes a 1D convolution over the input x.
12935 *
12936 * @param x The input tensor, of rank 3 or rank 2, of shape
12937 * `[batch, width, inChannels]`. If rank 2, batch of 1 is assumed.
12938 * @param filter The filter, rank 3, of shape
12939 * `[filterWidth, inDepth, outDepth]`.
12940 * @param stride The number of entries by which the filter is moved right at
12941 * each step.
12942 * @param pad The type of padding algorithm.
12943 * - `same` and stride 1: output will be of same size as input,
12944 * regardless of filter size.
12945 * - `valid`: output will be smaller than input if filter is larger
12946 * than 1x1.
12947 * - For more info, see this guide:
12948 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
12949 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
12950 * @param dataFormat An optional string from "NWC", "NCW". Defaults to "NWC",
12951 * the data is stored in the order of [batch, in_width, in_channels]. Only
12952 * "NWC" is currently supported.
12953 * @param dilation The dilation rate in which we sample input values in
12954 * atrous convolution. Defaults to `1`. If it is greater than 1, then
12955 * stride must be `1`.
12956 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
12957 * provided, it will default to truncate.
12958 *
12959 * @doc {heading: 'Operations', subheading: 'Convolution'}
12960 */
12961 function conv1d_(x, filter, stride, pad, dataFormat = 'NWC', dilation = 1, dimRoundingMode) {
12962 const $x = convertToTensor(x, 'x', 'conv1d');
12963 const $filter = convertToTensor(filter, 'filter', 'conv1d');
12964 let x3D = $x;
12965 let reshapedTo3D = false;
12966 if ($x.rank === 2) {
12967 reshapedTo3D = true;
12968 x3D = reshape($x, [1, $x.shape[0], $x.shape[1]]);
12969 }
12970 assert(x3D.rank === 3, () => `Error in conv1d: input must be rank 3, but got rank ${x3D.rank}.`);
12971 assert($filter.rank === 3, () => `Error in conv1d: filter must be rank 3, but got rank ` +
12972 `${$filter.rank}.`);
12973 checkPadOnDimRoundingMode('conv1d', pad, dimRoundingMode);
12974 assert(x3D.shape[2] === $filter.shape[1], () => `Error in conv1d: depth of input (${x3D.shape[2]}) must match ` +
12975 `input depth for filter ${$filter.shape[1]}.`);
12976 assert(eitherStridesOrDilationsAreOne(stride, dilation), () => 'Error in conv1D: Either stride or dilation must be 1. ' +
12977 `Got stride ${stride} and dilation '${dilation}'`);
12978 assert(dataFormat === 'NWC', () => `Error in conv1d: got dataFormat of ${dataFormat} but only NWC is currently supported.`);
12979 const filter4D = reshape($filter, [1, $filter.shape[0], $filter.shape[1], $filter.shape[2]]);
12980 const input4D = reshape(x3D, [x3D.shape[0], 1, x3D.shape[1], x3D.shape[2]]);
12981 const strides = [1, stride];
12982 const dilations = [1, dilation];
12983 const conv2dDataFormat = 'NHWC';
12984 const res = conv2d(input4D, filter4D, strides, pad, conv2dDataFormat, dilations, dimRoundingMode);
12985 if (reshapedTo3D) {
12986 return reshape(res, [res.shape[2], res.shape[3]]);
12987 }
12988 return reshape(res, [res.shape[0], res.shape[2], res.shape[3]]);
12989 }
12990 const conv1d = op({ conv1d_ });
12991
12992 /**
12993 * @license
12994 * Copyright 2020 Google LLC. All Rights Reserved.
12995 * Licensed under the Apache License, Version 2.0 (the "License");
12996 * you may not use this file except in compliance with the License.
12997 * You may obtain a copy of the License at
12998 *
12999 * http://www.apache.org/licenses/LICENSE-2.0
13000 *
13001 * Unless required by applicable law or agreed to in writing, software
13002 * distributed under the License is distributed on an "AS IS" BASIS,
13003 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13004 * See the License for the specific language governing permissions and
13005 * limitations under the License.
13006 * =============================================================================
13007 */
13008 /**
13009 * Computes the derivative of the input of a 2D convolution.
13010 *
13011 * @param xShape The shape of the input: [batch, height, width, inDepth].
13012 * If length of 3, batch of 1 is assumed.
13013 * @param dy The derivative of the output, of rank 4 or rank 3 of shape
13014 * `[batch, outHeight, outWidth, outDepth]`. If rank 3, batch of 1 is
13015 * assumed.
13016 * @param filter The filter, rank 4, of shape
13017 * `[filterHeight, filterWidth, inDepth, outDepth]`.
13018 * @param strides The strides of the convolution: `[strideHeight,
13019 * strideWidth]`.
13020 * @param pad The type of padding algorithm used:
13021 * - `same` and stride 1: output will be of same size as input,
13022 * regardless of filter size.
13023 * - `valid`: output will be smaller than input if filter is larger
13024 * than 1x1.
13025 * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
13026 * "NHWC". Specify the data format of the input and output data. With the
13027 * default format "NHWC", the data is stored in the order of: [batch,
13028 * height, width, channels].
13029 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
13030 * provided, it will default to truncate.
13031 */
    function conv2DBackpropInput_(xShape, dy, filter, strides, pad, dataFormat = 'NHWC', dimRoundingMode) {
        // The forward-input shape and the output gradient must describe tensors
        // of the same rank (both rank 3 or both rank 4).
        assert(xShape.length === dy.rank, () => `Length of inShape ` +
            `(${xShape.length}) and rank of dy (${dy.rank}) must match`);
        let xShape4D = xShape;
        let dy4D = dy;
        let reshapedTo4D = false;
        // Promote a rank-3 dy (and its matching input shape) to rank 4 by
        // prepending a batch dimension of 1; undone before returning.
        if (dy.rank === 3) {
            reshapedTo4D = true;
            dy4D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
            xShape4D = [1, xShape[0], xShape[1], xShape[2]];
        }
        assert(xShape4D.length === 4, () => `Error in conv2dDerInput: inShape must be length 4, but got length ` +
            `${xShape4D.length}.`);
        assert(dy4D.rank === 4, () => `Error in conv2dDerInput: dy must be rank 4, but got ` +
            `rank ${dy4D.rank}`);
        assert(filter.rank === 4, () => `Error in conv2dDerInput: filter must be rank 4, but got ` +
            `rank ${filter.rank}`);
        // The position of the channel axis depends on the data format.
        const inDepth = dataFormat === 'NHWC' ? xShape4D[3] : xShape4D[1];
        const outDepth = dataFormat === 'NHWC' ? dy4D.shape[3] : dy4D.shape[1];
        // Channel counts must agree with the filter's inDepth/outDepth axes.
        assert(inDepth === filter.shape[2], () => `Error in conv2dDerInput: depth of input (${inDepth}) must ` +
            `match input depth for filter ${filter.shape[2]}.`);
        assert(outDepth === filter.shape[3], () => `Error in conv2dDerInput: depth of output (${outDepth}) must ` +
            `match output depth for filter ${filter.shape[3]}.`);
        checkPadOnDimRoundingMode('conv2dDerInput', pad, dimRoundingMode);
        const inputs = { dy: dy4D, filter };
        const attrs = { strides, pad, dataFormat, dimRoundingMode, inputShape: xShape4D };
        // tslint:disable-next-line: no-unnecessary-type-assertion
        const res = ENGINE.runKernel(Conv2DBackpropInput, inputs, attrs);
        // Strip the batch dimension if one was added above.
        if (reshapedTo4D) {
            return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
        }
        return res;
    }
    const conv2DBackpropInput = op({ conv2DBackpropInput_ });
13066
13067 /**
13068 * Computes the transposed 2D convolution of an image, also known as a
13069 * deconvolution.
13070 *
13071 * @param x The input image, of rank 4 or rank 3, of shape
13072 * `[batch, height, width, inDepth]`. If rank 3, batch of 1 is assumed.
13073 * @param filter The filter, rank 4, of shape
13074 * `[filterHeight, filterWidth, outDepth, inDepth]`.
13075 * `inDepth` must match `inDepth` in `x`.
13076 * @param outputShape Output shape, of rank 4 or rank 3:
13077 * `[batch, height, width, outDepth]`. If rank 3, batch of 1 is assumed.
13078 * @param strides The strides of the original convolution:
13079 * `[strideHeight, strideWidth]`.
13080 * @param pad The type of padding algorithm used in the non-transpose version
13081 * of the op.
13082 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
13083 * provided, it will default to truncate.
13084 *
13085 * @doc {heading: 'Operations', subheading: 'Convolution'}
13086 */
13087 function conv2dTranspose_(x, filter, outputShape, strides, pad, dimRoundingMode) {
13088 const $x = convertToTensor(x, 'x', 'conv2dTranspose');
13089 const $filter = convertToTensor(filter, 'filter', 'conv2dTranspose');
13090 return conv2DBackpropInput(outputShape, $x, $filter, strides, pad, 'NHWC', dimRoundingMode);
13091 }
13092 const conv2dTranspose = op({ conv2dTranspose_ });
13093
13094 /**
13095 * @license
13096 * Copyright 2020 Google LLC. All Rights Reserved.
13097 * Licensed under the Apache License, Version 2.0 (the "License");
13098 * you may not use this file except in compliance with the License.
13099 * You may obtain a copy of the License at
13100 *
13101 * http://www.apache.org/licenses/LICENSE-2.0
13102 *
13103 * Unless required by applicable law or agreed to in writing, software
13104 * distributed under the License is distributed on an "AS IS" BASIS,
13105 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13106 * See the License for the specific language governing permissions and
13107 * limitations under the License.
13108 * =============================================================================
13109 */
13110 /**
13111 * Computes a 3D convolution over the input x.
13112 *
13113 * @param x The input tensor, of rank 5 or rank 4, of shape
13114 * `[batch, depth, height, width, channels]`. If rank 4,
13115 * batch of 1 is assumed.
13116 * @param filter The filter, rank 5, of shape
13117 * `[filterDepth, filterHeight, filterWidth, inChannels, outChannels]`.
13118 * inChannels must match between input and filter.
13119 * @param strides The strides of the convolution: `[strideDepth, strideHeight,
13120 * strideWidth]`.
13121 * @param pad The type of padding algorithm.
13122 * - `same` and stride 1: output will be of same size as input,
13123 * regardless of filter size.
13124 * - `valid`: output will be smaller than input if filter is larger
13125 * than 1x1.
13126 * - For more info, see this guide:
13127 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
13128 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
13129 * @param dataFormat: An optional string from: "NDHWC", "NCDHW". Defaults to
13130 * "NDHWC". Specify the data format of the input and output data. With the
13131 * default format "NDHWC", the data is stored in the order of: [batch,
13132 * depth, height, width, channels]. Only "NDHWC" is currently supported.
13133 * @param dilations The dilation rates: `[dilationDepth, dilationHeight,
13134 * dilationWidth]` in which we sample input values across the height
13135 * and width dimensions in atrous convolution. Defaults to `[1, 1, 1]`.
13136 * If `dilations` is a single number, then
13137 * `dilationDepth == dilationHeight == dilationWidth`. If it is greater
13138 * than 1, then all values of `strides` must be 1.
13139 *
13140 * @doc {heading: 'Operations', subheading: 'Convolution'}
13141 */
13142 function conv3d_(x, filter, strides, pad, dataFormat = 'NDHWC', dilations = [1, 1, 1]) {
13143 const $x = convertToTensor(x, 'x', 'conv3d');
13144 const $filter = convertToTensor(filter, 'filter', 'conv3d');
13145 let x5D = $x;
13146 let reshapedTo5D = false;
13147 if ($x.rank === 4) {
13148 reshapedTo5D = true;
13149 x5D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]]);
13150 }
13151 assert(x5D.rank === 5, () => `Error in conv3d: input must be rank 5, but got rank ${x5D.rank}.`);
13152 assert($filter.rank === 5, () => `Error in conv3d: filter must be rank 5, but got rank ` +
13153 `${$filter.rank}.`);
13154 assert(x5D.shape[4] === $filter.shape[3], () => `Error in conv3d: depth of input (${x5D.shape[4]}) must match ` +
13155 `input depth for filter ${$filter.shape[3]}.`);
13156 assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in conv3D: Either strides or dilations must be 1. ' +
13157 `Got strides ${strides} and dilations '${dilations}'`);
13158 assert(dataFormat === 'NDHWC', () => `Error in conv3d: got dataFormat of ${dataFormat} but only NDHWC is currently supported.`);
13159 const inputs = { x: x5D, filter: $filter };
13160 const attrs = { strides, pad, dataFormat, dilations };
13161 // tslint:disable-next-line: no-unnecessary-type-assertion
13162 const res = ENGINE.runKernel(Conv3D, inputs, attrs);
13163 if (reshapedTo5D) {
13164 return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
13165 }
13166 return res;
13167 }
13168 const conv3d = op({ conv3d_ });
13169
13170 /**
13171 * @license
13172 * Copyright 2020 Google LLC. All Rights Reserved.
13173 * Licensed under the Apache License, Version 2.0 (the "License");
13174 * you may not use this file except in compliance with the License.
13175 * You may obtain a copy of the License at
13176 *
13177 * http://www.apache.org/licenses/LICENSE-2.0
13178 *
13179 * Unless required by applicable law or agreed to in writing, software
13180 * distributed under the License is distributed on an "AS IS" BASIS,
13181 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13182 * See the License for the specific language governing permissions and
13183 * limitations under the License.
13184 * =============================================================================
13185 */
13186 /**
13187 * Computes the derivative of the input of a 3D convolution.
13188 *
13189 * @param xShape The shape of the input: [batch, depth, height, width,
13190 * in_channels]. If length of 4, batch of 1 is assumed.
13191 * @param dy The derivative of the output, of rank 5 or rank 4 of shape
13192 * `[batch, outDepth, outHeight, outWidth, in_channels]`.
13193 * If rank 4, batch of 1 is assumed.
13194 * @param filter The filter, rank 5, of shape
13195 * `[filterDepth, filterHeight, filterWidth, inDepth, outDepth]`.
13196 * @param strides The strides of the convolution: `[strideDepth, strideHeight,
13197 * strideWidth]`.
13198 * @param pad The type of padding algorithm used:
13199 * - `same` and stride 1: output will be of same size as input,
13200 * regardless of filter size.
13201 * - `valid`: output will be smaller than input if filter is larger
13202 * than 1x1.
13203 */
    function conv3DBackpropInput_(xShape, dy, filter, strides, pad) {
        // The forward-input shape and the output gradient must describe tensors
        // of the same rank (both rank 4 or both rank 5).
        assert(xShape.length === dy.rank, () => `Length of inShape ` +
            `(${xShape.length}) and rank of dy (${dy.rank}) must match`);
        let xShape5D = xShape;
        let dy5D = dy;
        let reshapedTo5D = false;
        // Promote a rank-4 dy (and its matching input shape) to rank 5 by
        // prepending a batch dimension of 1; undone before returning.
        if (dy.rank === 4) {
            reshapedTo5D = true;
            dy5D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]);
            xShape5D = [1, xShape[0], xShape[1], xShape[2], xShape[3]];
        }
        // NDHWC layout: channels are the last axis.
        const inDepth = xShape5D[4];
        const outDepth = dy5D.shape[4];
        assert(xShape5D.length === 5, () => `Error in conv3dDerInput: inShape must be length 5, but got length ` +
            `${xShape5D.length}.`);
        assert(dy5D.rank === 5, () => `Error in conv3dDerInput: dy must be rank 5, but got ` +
            `rank ${dy5D.rank}`);
        assert(filter.rank === 5, () => `Error in conv3dDerInput: filter must be rank 5, but got ` +
            `rank ${filter.rank}`);
        // Channel counts must agree with the filter's inDepth/outDepth axes.
        assert(inDepth === filter.shape[3], () => `Error in conv3dDerInput: depth of input (${inDepth}) must ` +
            `match input depth for filter ${filter.shape[3]}.`);
        assert(outDepth === filter.shape[4], () => `Error in conv3dDerInput: depth of output (${outDepth}) must ` +
            `match output depth for filter ${filter.shape[4]}.`);
        const inputs = { dy: dy5D, filter };
        const attrs = { pad, strides, inputShape: xShape5D };
        // tslint:disable-next-line: no-unnecessary-type-assertion
        const res = ENGINE.runKernel(Conv3DBackpropInputV2, inputs, attrs);
        // Strip the batch dimension if one was added above.
        if (reshapedTo5D) {
            return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
        }
        return res;
    }
    const conv3DBackpropInput = op({ conv3DBackpropInput_ });
13237
13238 /**
13239 * Computes the transposed 3D convolution of a volume, also known as a
13240 * deconvolution.
13241 *
13242 * @param x The input image, of rank 5 or rank 4, of shape
13243 * `[batch, depth, height, width, inDepth]`. If rank 4, batch of 1 is assumed.
 * @param filter The filter, rank 5, of shape
 *     `[depth, filterHeight, filterWidth, outDepth, inDepth]`.
 *     `inDepth` must match `inDepth` in `x`.
 * @param outputShape Output shape, of rank 5 or rank 4:
 *     `[batch, depth, height, width, outDepth]`. If rank 4, batch of 1 is
 *     assumed.
13250 * @param strides The strides of the original convolution:
13251 * `[strideDepth, strideHeight, strideWidth]`.
13252 * @param pad The type of padding algorithm used in the non-transpose version
13253 * of the op.
13254 *
13255 * @doc {heading: 'Operations', subheading: 'Convolution'}
13256 */
13257 function conv3dTranspose_(x, filter, outputShape, strides, pad) {
13258 const $x = convertToTensor(x, 'x', 'conv3dTranspose');
13259 const $filter = convertToTensor(filter, 'filter', 'conv3dTranspose');
13260 return conv3DBackpropInput(outputShape, $x, $filter, strides, pad);
13261 }
13262 const conv3dTranspose = op({ conv3dTranspose_ });
13263
13264 /**
13265 * @license
13266 * Copyright 2018 Google LLC. All Rights Reserved.
13267 * Licensed under the Apache License, Version 2.0 (the "License");
13268 * you may not use this file except in compliance with the License.
13269 * You may obtain a copy of the License at
13270 *
13271 * http://www.apache.org/licenses/LICENSE-2.0
13272 *
13273 * Unless required by applicable law or agreed to in writing, software
13274 * distributed under the License is distributed on an "AS IS" BASIS,
13275 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13276 * See the License for the specific language governing permissions and
13277 * limitations under the License.
13278 * =============================================================================
13279 */
13280 /**
13281 * Computes cos of the input `tf.Tensor` element-wise: `cos(x)`
13282 *
13283 * ```js
13284 * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);
13285 *
13286 * x.cos().print(); // or tf.cos(x)
13287 * ```
13288 * @param x The input tensor. Must be float32 type.
13289 *
13290 * @doc {heading: 'Operations', subheading: 'Basic math'}
13291 */
13292 function cos_(x) {
13293 const $x = convertToTensor(x, 'x', 'cos', 'float32');
13294 const inputs = { x: $x };
13295 return ENGINE.runKernel(Cos, inputs);
13296 }
13297 const cos = op({ cos_ });
13298
13299 /**
13300 * @license
13301 * Copyright 2018 Google LLC. All Rights Reserved.
13302 * Licensed under the Apache License, Version 2.0 (the "License");
13303 * you may not use this file except in compliance with the License.
13304 * You may obtain a copy of the License at
13305 *
13306 * http://www.apache.org/licenses/LICENSE-2.0
13307 *
13308 * Unless required by applicable law or agreed to in writing, software
13309 * distributed under the License is distributed on an "AS IS" BASIS,
13310 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13311 * See the License for the specific language governing permissions and
13312 * limitations under the License.
13313 * =============================================================================
13314 */
13315 /**
13316 * Computes hyperbolic cos of the input `tf.Tensor` element-wise: `cosh(x)`
13317 *
13318 * ```js
13319 * const x = tf.tensor1d([0, 1, -1, .7]);
13320 *
13321 * x.cosh().print(); // or tf.cosh(x)
13322 * ```
13323 * @param x The input tensor. Must be float32 type.
13324 *
13325 * @doc {heading: 'Operations', subheading: 'Basic math'}
13326 */
13327 function cosh_(x) {
13328 const $x = convertToTensor(x, 'x', 'cosh', 'float32');
13329 const inputs = { x: $x };
13330 return ENGINE.runKernel(Cosh, inputs);
13331 }
13332 const cosh = op({ cosh_ });
13333
13334 /**
13335 * @license
13336 * Copyright 2022 Google LLC. All Rights Reserved.
13337 * Licensed under the Apache License, Version 2.0 (the 'License');
13338 * you may not use this file except in compliance with the License.
13339 * You may obtain a copy of the License at
13340 *
13341 * http://www.apache.org/licenses/LICENSE-2.0
13342 *
13343 * Unless required by applicable law or agreed to in writing, software
13344 * distributed under the License is distributed on an 'AS IS' BASIS,
13345 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13346 * See the License for the specific language governing permissions and
13347 * limitations under the License.
13348 * =============================================================================
13349 */
13350 /**
13351 * Computes the cumulative product of a `tf.Tensor` along `axis`.
13352 *
13353 * ```js
13354 * const x = tf.tensor([1, 2, 3, 4]);
13355 * x.cumprod().print();
13356 * ```
13357 * ```js
13358 * const x = tf.tensor([[1, 2], [3, 4]]);
13359 * x.cumprod().print();
13360 * ```
13361 *
13362 * @param x The input tensor to cumulatively multiply.
13363 * @param axis The axis along which to multiply. Optional. Defaults to 0.
13364 * @param exclusive Whether to perform exclusive cumulative product. Optional.
13365 * Defaults to false. If set to true then the product of each tensor entry
13366 * does not include its own value, but only the values previous to it
13367 * along the specified axis.
13368 * @param reverse Whether to multiply in the opposite direction. Optional.
13369 * Defaults to false.
13370 *
13371 * @doc {heading: 'Operations', subheading: 'Scan'}
13372 */
13373 function cumprod_(x, axis = 0, exclusive = false, reverse = false) {
13374 const $x = convertToTensor(x, 'x', 'cumprod');
13375 const inputs = { x: $x };
13376 const attrs = { axis, exclusive, reverse };
13377 return ENGINE.runKernel(Cumprod, inputs, attrs);
13378 }
13379 const cumprod = op({ cumprod_ });
13380
13381 /**
13382 * @license
13383 * Copyright 2018 Google LLC. All Rights Reserved.
13384 * Licensed under the Apache License, Version 2.0 (the "License");
13385 * you may not use this file except in compliance with the License.
13386 * You may obtain a copy of the License at
13387 *
13388 * http://www.apache.org/licenses/LICENSE-2.0
13389 *
13390 * Unless required by applicable law or agreed to in writing, software
13391 * distributed under the License is distributed on an "AS IS" BASIS,
13392 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13393 * See the License for the specific language governing permissions and
13394 * limitations under the License.
13395 * =============================================================================
13396 */
13397 /**
13398 * Computes the cumulative sum of a `tf.Tensor` along `axis`.
13399 *
13400 * ```js
13401 * const x = tf.tensor([1, 2, 3, 4]);
13402 * x.cumsum().print();
13403 * ```
13404 * ```js
13405 * const x = tf.tensor([[1, 2], [3, 4]]);
13406 * x.cumsum().print();
13407 * ```
13408 *
13409 * @param x The input tensor to be summed.
13410 * @param axis The axis along which to sum. Optional. Defaults to 0.
13411 * @param exclusive Whether to perform exclusive cumulative sum. Optional.
13412 * Defaults to false. If set to true then the sum of each tensor entry
13413 * does not include its own value, but only the values previous to it
13414 * along the specified axis.
13415 * @param reverse Whether to sum in the opposite direction. Optional.
13416 * Defaults to false.
13417 *
13418 * @doc {heading: 'Operations', subheading: 'Scan'}
13419 */
13420 function cumsum_(x, axis = 0, exclusive = false, reverse = false) {
13421 const $x = convertToTensor(x, 'x', 'cumsum');
13422 const inputs = { x: $x };
13423 const attrs = { axis, exclusive, reverse };
13424 return ENGINE.runKernel(Cumsum, inputs, attrs);
13425 }
13426 const cumsum = op({ cumsum_ });
13427
13428 /**
13429 * @license
13430 * Copyright 2020 Google LLC. All Rights Reserved.
13431 * Licensed under the Apache License, Version 2.0 (the "License");
13432 * you may not use this file except in compliance with the License.
13433 * You may obtain a copy of the License at
13434 *
13435 * http://www.apache.org/licenses/LICENSE-2.0
13436 *
13437 * Unless required by applicable law or agreed to in writing, software
13438 * distributed under the License is distributed on an "AS IS" BASIS,
13439 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13440 * See the License for the specific language governing permissions and
13441 * limitations under the License.
13442 * =============================================================================
13443 */
13444 /**
13445 * Outputs a vector with length `size` and the same dtype as `weights`.
13446 *
13447 * If `weights` are empty, then index `i` stores the number of times the value
13448 * `i` is counted in `x`. If `weights` are non-empty, then index `i` stores the
13449 * sum of the value in `weights` at each index where the corresponding value in
13450 * `x` is `i`.
13451 *
13452 * Values in `x` outside of the range [0, size) are ignored.
13453 *
13454 * @param x The input int tensor, rank 1 or rank 2.
13455 * @param weights The weights tensor, must have the same shape as x, or a
13456 * length-0 Tensor, in which case it acts as all weights equal to 1.
13457 * @param size Non-negative integer.
13458 * @param binaryOutput Optional. Whether the kernel should count the appearance
13459 * or number of occurrences. Defaults to False.
13460 *
13461 * @doc {heading: 'Operations', subheading: 'Reduction'}
13462 */
13463 function denseBincount_(x, weights, size, binaryOutput = false) {
13464 const $x = convertToTensor(x, 'x', 'denseBincount');
13465 const $weights = convertToTensor(weights, 'weights', 'denseBincount');
13466 assert($x.dtype === 'int32', () => `Error in denseBincount: input ` +
13467 `dtype must be int32, but got ${$x.dtype}`);
13468 assert($x.rank <= 2, () => `Error in denseBincount: input must be at most rank 2, but got ` +
13469 `rank ${$x.rank}.`);
13470 assert(size >= 0, () => `size must be non-negative, but got ${size}.`);
13471 assert($weights.size === $x.size || $weights.size === 0, () => `Error in denseBincount: weights must have the same shape as x or ` +
13472 `0-length, but got x shape: ${$x.shape}, weights shape: ` +
13473 `${$weights.shape}.`);
13474 const inputs = { x: $x, weights: $weights };
13475 const attrs = { size, binaryOutput };
13476 return ENGINE.runKernel(DenseBincount, inputs, attrs);
13477 }
13478 const denseBincount = op({ denseBincount_ });
13479
13480 /**
13481 * @license
13482 * Copyright 2020 Google LLC. All Rights Reserved.
13483 * Licensed under the Apache License, Version 2.0 (the "License");
13484 * you may not use this file except in compliance with the License.
13485 * You may obtain a copy of the License at
13486 *
13487 * http://www.apache.org/licenses/LICENSE-2.0
13488 *
13489 * Unless required by applicable law or agreed to in writing, software
13490 * distributed under the License is distributed on an "AS IS" BASIS,
13491 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13492 * See the License for the specific language governing permissions and
13493 * limitations under the License.
13494 * =============================================================================
13495 */
13496 /**
13497 * Rearranges data from depth into blocks of spatial data. More specifically,
13498 * this op outputs a copy of the input tensor where values from the `depth`
13499 * dimension are moved in spatial blocks to the `height` and `width` dimensions.
13500 * The attr `blockSize` indicates the input block size and how the data is
13501 * moved.
13502 *
13503 * - Chunks of data of size `blockSize * blockSize` from depth are rearranged
13504 * into non-overlapping blocks of size `blockSize x blockSize`
13505 *
     * - The width of the output tensor is `inputWidth * blockSize`, whereas the
13507 * height is `inputHeight * blockSize`
13508 *
13509 * - The Y, X coordinates within each block of the output image are determined
13510 * by the high order component of the input channel index
13511 *
13512 * - The depth of the input tensor must be divisible by `blockSize *
13513 * blockSize`
13514 *
13515 * The `dataFormat` attr specifies the layout of the input and output tensors
13516 * with the following options: "NHWC": [ `batch, height, width, channels` ]
13517 * "NCHW": [ `batch, channels, height, width` ]
13518 *
13519 * ```js
13520 * const x = tf.tensor4d([1, 2, 3, 4], [1, 1, 1, 4]);
13521 * const blockSize = 2;
13522 * const dataFormat = "NHWC";
13523 *
13524 * tf.depthToSpace(x, blockSize, dataFormat).print();
13525 * ```
13526 *
13527 * @param x The input tensor of rank 4
     * @param blockSize An `int` that is `>= 2`. The size of the spatial block
13529 * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to "NHWC"
13530 *
13531 * @doc {heading: 'Tensors', subheading: 'Transformations'}
13532 */
    /**
     * Implementation of `tf.depthToSpace`: rearranges data from the depth
     * dimension into spatial blocks by dispatching the DepthToSpace kernel
     * after validating the input shape against `blockSize`.
     */
    function depthToSpace_(x, blockSize, dataFormat = 'NHWC') {
        const $x = convertToTensor(x, 'x', 'depthToSpace', 'float32');
        // NHWC stores dims as [batch, H, W, C]; NCHW as [batch, C, H, W].
        const inputHeight = (dataFormat === 'NHWC') ? $x.shape[1] : $x.shape[2];
        const inputWidth = (dataFormat === 'NHWC') ? $x.shape[2] : $x.shape[3];
        const inputDepth = (dataFormat === 'NHWC') ? $x.shape[3] : $x.shape[1];
        assert(blockSize > 1, () => `blockSize should be > 1 for depthToSpace, but was: ${blockSize}`);
        // Guard against integer overflow producing negative output dimensions.
        assert(inputHeight * blockSize >= 0, () => `Negative dimension size caused by overflow when multiplying
    ${inputHeight} and ${blockSize} for depthToSpace with input shape
    ${$x.shape}`);
        assert(inputWidth * blockSize >= 0, () => `Negative dimension size caused by overflow when multiplying
    ${inputWidth} and ${blockSize} for depthToSpace with input shape
    ${$x.shape}`);
        // Each output spatial block consumes blockSize*blockSize input channels.
        assert((inputDepth % (blockSize * blockSize) === 0), () => `Dimension size must be evenly divisible by ${blockSize * blockSize} but is ${inputDepth} for depthToSpace with input shape ${$x.shape}`);
        const inputs = { x: $x };
        const attrs = { blockSize, dataFormat };
        return ENGINE.runKernel(DepthToSpace, inputs, attrs);
    }
    const depthToSpace = op({ depthToSpace_ });
13551
13552 /**
13553 * @license
13554 * Copyright 2020 Google LLC. All Rights Reserved.
13555 * Licensed under the Apache License, Version 2.0 (the "License");
13556 * you may not use this file except in compliance with the License.
13557 * You may obtain a copy of the License at
13558 *
13559 * http://www.apache.org/licenses/LICENSE-2.0
13560 *
13561 * Unless required by applicable law or agreed to in writing, software
13562 * distributed under the License is distributed on an "AS IS" BASIS,
13563 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13564 * See the License for the specific language governing permissions and
13565 * limitations under the License.
13566 * =============================================================================
13567 */
13568 /**
13569 * Depthwise 2D convolution.
13570 *
13571 * Given a 4D `input` array and a `filter` array of shape
13572 * `[filterHeight, filterWidth, inChannels, channelMultiplier]` containing
13573 * `inChannels` convolutional filters of depth 1, this op applies a
13574 * different filter to each input channel (expanding from 1 channel to
13575 * `channelMultiplier` channels for each), then concatenates the results
13576 * together. The output has `inChannels * channelMultiplier` channels.
13577 *
13578 * See
13579 * [https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d](
13580 * https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d)
13581 * for more details.
13582 *
13583 * @param x The input tensor, of rank 4 or rank 3, of shape
13584 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
13585 * assumed.
13586 * @param filter The filter tensor, rank 4, of shape
13587 * `[filterHeight, filterWidth, inChannels, channelMultiplier]`.
13588 * @param strides The strides of the convolution: `[strideHeight,
13589 * strideWidth]`. If strides is a single number, then `strideHeight ==
13590 * strideWidth`.
13591 * @param pad The type of padding algorithm.
13592 * - `same` and stride 1: output will be of same size as input,
13593 * regardless of filter size.
13594 * - `valid`: output will be smaller than input if filter is larger
13595 * than 1x1.
13596 * - For more info, see this guide:
13597 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
13598 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
13599 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
13600 * in which we sample input values across the height and width dimensions
13601 * in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single
13602 * number, then `dilationHeight == dilationWidth`. If it is greater than
13603 * 1, then all values of `strides` must be 1.
13604 * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
13605 * "NHWC". Specify the data format of the input and output data. With the
13606 * default format "NHWC", the data is stored in the order of: [batch,
13607 * height, width, channels]. Only "NHWC" is currently supported.
13608 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
13609 * provided, it will default to truncate.
13610 *
13611 * @doc {heading: 'Operations', subheading: 'Convolution'}
13612 */
13613 function depthwiseConv2d_(x, filter, strides, pad, dataFormat = 'NHWC', dilations = [1, 1], dimRoundingMode) {
13614 const $x = convertToTensor(x, 'x', 'depthwiseConv2d', 'float32');
13615 const $filter = convertToTensor(filter, 'filter', 'depthwiseConv2d', 'float32');
13616 let x4D = $x;
13617 let reshapedTo4D = false;
13618 if ($x.rank === 3) {
13619 reshapedTo4D = true;
13620 x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
13621 }
13622 assert(x4D.rank === 4, () => `Error in depthwiseConv2d: input must be rank 4, but got ` +
13623 `rank ${x4D.rank}.`);
13624 assert($filter.rank === 4, () => `Error in depthwiseConv2d: filter must be rank 4, but got rank ` +
13625 `${$filter.rank}.`);
13626 assert(x4D.shape[3] === $filter.shape[2], () => `Error in depthwiseConv2d: number of input channels ` +
13627 `(${x4D.shape[3]}) must match the inChannels dimension in ` +
13628 `filter ${$filter.shape[2]}.`);
13629 checkPadOnDimRoundingMode('depthwiseConv2d', pad, dimRoundingMode);
13630 const inputs = { x: x4D, filter: $filter };
13631 const attrs = { strides, pad, dataFormat, dilations, dimRoundingMode };
13632 // tslint:disable-next-line: no-unnecessary-type-assertion
13633 const res = ENGINE.runKernel(DepthwiseConv2dNative, inputs, attrs);
13634 if (reshapedTo4D) {
13635 return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
13636 }
13637 return res;
13638 }
13639 const depthwiseConv2d = op({ depthwiseConv2d_ });
13640
13641 /**
13642 * @license
13643 * Copyright 2020 Google LLC. All Rights Reserved.
13644 * Licensed under the Apache License, Version 2.0 (the "License");
13645 * you may not use this file except in compliance with the License.
13646 * You may obtain a copy of the License at
13647 *
13648 * http://www.apache.org/licenses/LICENSE-2.0
13649 *
13650 * Unless required by applicable law or agreed to in writing, software
13651 * distributed under the License is distributed on an "AS IS" BASIS,
13652 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13653 * See the License for the specific language governing permissions and
13654 * limitations under the License.
13655 * =============================================================================
13656 */
13657 /**
13658 * Returns a diagonal tensor with a given diagonal values.
13659 *
13660 * Given a diagonal, this operation returns a tensor with the diagonal and
13661 * everything else padded with zeros.
13662 *
13663 * Assume the input has dimensions `[D1,..., Dk]`, then the output is a tensor
13664 * of rank 2k with dimensions `[D1,..., Dk, D1,..., Dk]`
13665 *
13666 * ```js
13667 * const x = tf.tensor1d([1, 2, 3, 4]);
13668 *
13669 * tf.diag(x).print()
13670 * ```
13671 * ```js
13672 * const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 6, 8], [4, 2])
13673 *
13674 * tf.diag(x).print()
13675 * ```
13676 * @param x The input tensor.
13677 *
13678 * @doc {heading: 'Tensors', subheading: 'Creation'}
13679 */
13680 function diag_(x) {
13681 const $x = convertToTensor(x, 'x', 'diag');
13682 const inputs = { x: $x };
13683 return ENGINE.runKernel(Diag, inputs);
13684 }
13685 const diag = op({ diag_ });
13686
13687 /**
13688 * @license
13689 * Copyright 2020 Google LLC. All Rights Reserved.
13690 * Licensed under the Apache License, Version 2.0 (the "License");
13691 * you may not use this file except in compliance with the License.
13692 * You may obtain a copy of the License at
13693 *
13694 * http://www.apache.org/licenses/LICENSE-2.0
13695 *
13696 * Unless required by applicable law or agreed to in writing, software
13697 * distributed under the License is distributed on an "AS IS" BASIS,
13698 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13699 * See the License for the specific language governing permissions and
13700 * limitations under the License.
13701 * =============================================================================
13702 */
13703 /**
13704 * Computes the grayscale dilation over the input `x`.
13705 *
13706 * @param x The input tensor, rank 3 or rank 4 of shape
13707 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
13708 * @param filter The filter tensor, rank 3, of shape
13709 * `[filterHeight, filterWidth, depth]`.
13710 * @param strides The strides of the sliding window for each dimension of the
13711 * input tensor: `[strideHeight, strideWidth]`.
13712 * If `strides` is a single number,
13713 * then `strideHeight == strideWidth`.
13714 * @param pad The type of padding algorithm.
13715 * - `same` and stride 1: output will be of same size as input,
13716 * regardless of filter size.
13717 * - `valid`: output will be smaller than input if filter is larger
     * than 1x1.
13719 * - For more info, see this guide:
13720 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
13721 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
13722 * @param dataFormat Specify the data format of the input and output data.
13723 * Defaults to 'NHWC'. Only 'NHWC' is currently supported. With the
13724 * default format "NHWC", the data is stored in the order of: [batch,
13725 * height, width, channels].
13726 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
13727 * in which we sample input values across the height and width dimensions
13728 * for atrous morphological dilation. Defaults to `[1, 1]`. If `dilations`
13729 * is a single number, then `dilationHeight == dilationWidth`. If it is
13730 * greater than 1, then all values of `strides` must be 1.
13731 *
13732 * @doc {heading: 'Operations', subheading: 'Convolution'}
13733 */
13734 function dilation2d_(x, filter, strides, pad, dilations = [1, 1], dataFormat = 'NHWC') {
13735 const $x = convertToTensor(x, 'x', 'dilation2d');
13736 const $filter = convertToTensor(filter, 'filter', 'dilation2d');
13737 assert($x.rank === 3 || $x.rank === 4, () => `Error in dilation2d: input must be rank 3 or 4, but got rank ` +
13738 `${$x.rank}.`);
13739 assert($filter.rank === 3, () => `Error in dilation2d: filter must be rank 3, but got rank ` +
13740 `${$filter.rank}.`);
13741 assert(dataFormat === 'NHWC', () => `Error in dilation2d: Only NHWC is currently supported, ` +
13742 `but got dataFormat of ${dataFormat}`);
13743 let x4D = $x;
13744 let reshapedTo4D = false;
13745 if ($x.rank === 3) {
13746 x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
13747 reshapedTo4D = true;
13748 }
13749 const inputs = { x: x4D, filter: $filter };
13750 const attrs = { strides, pad, dilations };
13751 // tslint:disable-next-line: no-unnecessary-type-assertion
13752 const res = ENGINE.runKernel(Dilation2D, inputs, attrs);
13753 if (reshapedTo4D) {
13754 return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
13755 }
13756 return res;
13757 }
13758 const dilation2d = op({ dilation2d_ });
13759
13760 /**
13761 * @license
13762 * Copyright 2020 Google LLC. All Rights Reserved.
13763 * Licensed under the Apache License, Version 2.0 (the "License");
13764 * you may not use this file except in compliance with the License.
13765 * You may obtain a copy of the License at
13766 *
13767 * http://www.apache.org/licenses/LICENSE-2.0
13768 *
13769 * Unless required by applicable law or agreed to in writing, software
13770 * distributed under the License is distributed on an "AS IS" BASIS,
13771 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13772 * See the License for the specific language governing permissions and
13773 * limitations under the License.
13774 * =============================================================================
13775 */
13776 /**
13777 * Returns the truth value of (a == b) element-wise. Supports broadcasting.
13778 *
13779 * ```js
13780 * const a = tf.tensor1d([1, 2, 3]);
13781 * const b = tf.tensor1d([2, 2, 2]);
13782 *
13783 * a.equal(b).print();
13784 * ```
13785 *
13786 * @param a The first input tensor.
13787 * @param b The second input tensor. Must have the same dtype as `a`.
13788 *
13789 * @doc {heading: 'Operations', subheading: 'Logical'}
13790 */
13791 function equal_(a, b) {
13792 let $a = convertToTensor(a, 'a', 'equal', 'string_or_numeric');
13793 let $b = convertToTensor(b, 'b', 'equal', 'string_or_numeric');
13794 [$a, $b] = makeTypesMatch($a, $b);
13795 assertAndGetBroadcastShape($a.shape, $b.shape);
13796 const inputs = { a: $a, b: $b };
13797 return ENGINE.runKernel(Equal, inputs);
13798 }
13799 const equal = op({ equal_ });
13800
13801 /**
13802 * @license
13803 * Copyright 2020 Google LLC. All Rights Reserved.
13804 * Licensed under the Apache License, Version 2.0 (the "License");
13805 * you may not use this file except in compliance with the License.
13806 * You may obtain a copy of the License at
13807 *
13808 * http://www.apache.org/licenses/LICENSE-2.0
13809 *
13810 * Unless required by applicable law or agreed to in writing, software
13811 * distributed under the License is distributed on an "AS IS" BASIS,
13812 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13813 * See the License for the specific language governing permissions and
13814 * limitations under the License.
13815 * =============================================================================
13816 */
13817 /**
13818 * Returns the elements, either `a` or `b` depending on the `condition`.
13819 *
13820 * If the condition is true, select from `a`, otherwise select from `b`.
13821 *
13822 * ```js
13823 * const cond = tf.tensor1d([false, false, true], 'bool');
13824 * const a = tf.tensor1d([1 , 2, 3]);
13825 * const b = tf.tensor1d([-1, -2, -3]);
13826 *
13827 * a.where(cond, b).print();
13828 * ```
13829 *
13830 * @param condition The input condition. Must be of dtype bool.
13831 * @param a If `condition` is rank 1, `a` may have a higher rank but
13832 * its first dimension must match the size of `condition`.
13833 * @param b A tensor with the same dtype as `a` and with shape that is
13834 * compatible with `a`.
13835 * @return A tensor with same dtype as `a` and `b`, and shape that is
13836 * broadcastable from `a` and `b`.
13837 *
13838 * @doc {heading: 'Operations', subheading: 'Logical'}
13839 */
13840 function where_(condition, a, b) {
13841 const $a = convertToTensor(a, 'a', 'where');
13842 const $b = convertToTensor(b, 'b', 'where');
13843 const $condition = convertToTensor(condition, 'condition', 'where', 'bool');
13844 // TODO: move this logic to forward function when the broadcastTo op is
13845 // implemented in WASM.
13846 // Find the broadcastable shape for $condition, $a, and $b.
13847 const broadcastShape = assertAndGetBroadcastShape(assertAndGetBroadcastShape($condition.shape, $a.shape), $b.shape);
13848 const $broadcastedCondition = broadcastTo($condition, broadcastShape);
13849 const $broadcastedA = broadcastTo($a, broadcastShape);
13850 const $broadcastedB = broadcastTo($b, broadcastShape);
13851 const inputs = {
13852 condition: $broadcastedCondition,
13853 t: $broadcastedA,
13854 e: $broadcastedB
13855 };
13856 return ENGINE.runKernel(Select, inputs);
13857 }
13858 const where = op({ where_ });
13859
13860 /**
13861 * @license
13862 * Copyright 2018 Google LLC. All Rights Reserved.
13863 * Licensed under the Apache License, Version 2.0 (the "License");
13864 * you may not use this file except in compliance with the License.
13865 * You may obtain a copy of the License at
13866 *
13867 * http://www.apache.org/licenses/LICENSE-2.0
13868 *
13869 * Unless required by applicable law or agreed to in writing, software
13870 * distributed under the License is distributed on an "AS IS" BASIS,
13871 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13872 * See the License for the specific language governing permissions and
13873 * limitations under the License.
13874 * =============================================================================
13875 */
13876 /**
13877 * Creates a `tf.Tensor` with all elements set to 0 with the same shape as the
13878 * given tensor.
13879 *
13880 * ```js
13881 * const x = tf.tensor([1, 2]);
13882 * tf.zerosLike(x).print();
13883 * ```
13884 *
13885 * @param x The tensor of required shape.
13886 *
13887 * @doc {heading: 'Tensors', subheading: 'Creation'}
13888 */
13889 function zerosLike_(x) {
13890 const $x = convertToTensor(x, 'x', 'zerosLike');
13891 const inputs = { x: $x };
13892 return ENGINE.runKernel(ZerosLike, inputs);
13893 }
13894 const zerosLike = op({ zerosLike_ });
13895
13896 /**
13897 * @license
13898 * Copyright 2020 Google LLC. All Rights Reserved.
13899 * Licensed under the Apache License, Version 2.0 (the "License");
13900 * you may not use this file except in compliance with the License.
13901 * You may obtain a copy of the License at
13902 *
13903 * http://www.apache.org/licenses/LICENSE-2.0
13904 *
13905 * Unless required by applicable law or agreed to in writing, software
13906 * distributed under the License is distributed on an "AS IS" BASIS,
13907 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13908 * See the License for the specific language governing permissions and
13909 * limitations under the License.
13910 * =============================================================================
13911 */
13912 /**
13913 * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting. Return 0
13914 * if denominator is 0.
13915 *
13916 *
13917 * ```js
13918 * const a = tf.tensor1d([1, 4, 9, 16]);
13919 * const b = tf.tensor1d([1, 2, 3, 4]);
13920 * const c = tf.tensor1d([0, 0, 0, 0]);
13921 *
13922 * a.divNoNan(b).print(); // or tf.divNoNan(a, b)
13923 * a.divNoNan(c).print(); // or tf.divNoNan(a, c)
13924 * ```
13925 *
13926 * ```js
13927 * // Broadcast div a with b.
13928 * const a = tf.tensor1d([2, 4, 6, 8]);
13929 * const b = tf.scalar(2);
13930 * const c = tf.scalar(0);
13931 *
13932 * a.divNoNan(b).print(); // or tf.divNoNan(a, b)
13933 * a.divNoNan(c).print(); // or tf.divNoNan(a, c)
13934 * ```
13935 *
13936 * @param a The first tensor as the numerator.
13937 * @param b The second tensor as the denominator. Must have the same dtype as
13938 * `a`.
13939 *
13940 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
13941 */
13942 function divNoNan_(a, b) {
13943 // TODO: Make this into its own kernel.
13944 let $a = convertToTensor(a, 'a', 'div');
13945 let $b = convertToTensor(b, 'b', 'div');
13946 [$a, $b] = makeTypesMatch($a, $b);
13947 const divResult = div($a, $b);
13948 const zeros = zerosLike(divResult);
13949 const bEqualsZero = equal($b, zeros);
13950 return where(bEqualsZero, zeros, divResult);
13951 }
13952 const divNoNan = op({ divNoNan_ });
13953
13954 /**
13955 * @license
13956 * Copyright 2020 Google LLC. All Rights Reserved.
13957 * Licensed under the Apache License, Version 2.0 (the "License");
13958 * you may not use this file except in compliance with the License.
13959 * You may obtain a copy of the License at
13960 *
13961 * http://www.apache.org/licenses/LICENSE-2.0
13962 *
13963 * Unless required by applicable law or agreed to in writing, software
13964 * distributed under the License is distributed on an "AS IS" BASIS,
13965 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13966 * See the License for the specific language governing permissions and
13967 * limitations under the License.
13968 * =============================================================================
13969 */
13970 /**
13971 * Computes the dot product of two matrices and/or vectors, `t1` and `t2`.
13972 *
13973 * ```js
13974 * const a = tf.tensor1d([1, 2]);
13975 * const b = tf.tensor2d([[1, 2], [3, 4]]);
13976 * const c = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
13977 *
13978 * a.dot(b).print(); // or tf.dot(a, b)
13979 * b.dot(a).print();
13980 * b.dot(c).print();
13981 * ```
13982 * @param t1 The first tensor in the dot operation.
13983 * @param t2 The second tensor in the dot operation.
13984 *
13985 * @doc {heading: 'Operations', subheading: 'Matrices'}
13986 */
13987 function dot_(t1, t2) {
13988 const $t1 = convertToTensor(t1, 't1', 'dot');
13989 const $t2 = convertToTensor(t2, 't2', 'dot');
13990 assert(($t1.rank === 1 || $t1.rank === 2) && ($t2.rank === 1 || $t2.rank === 2), () => `Error in dot: inputs must all be rank 1 or 2, but got ranks ` +
13991 `${$t1.rank} and ${$t2.rank}.`);
13992 const t1Inner = ($t1.rank === 1 ? $t1.size : $t1.shape[1]);
13993 const t2Inner = ($t2.rank === 1 ? $t2.size : $t2.shape[0]);
13994 assert(t1Inner === t2Inner, () => `Error in dot: inner dimensions of inputs must match, but got ` +
13995 `${t1Inner} and ${t2Inner}.`);
13996 if ($t1.rank === 1 && $t2.rank === 1) {
13997 const t12D = reshape($t1, [1, -1]);
13998 const t22D = reshape($t2, [-1, 1]);
13999 const t1t2 = matMul(t12D, t22D);
14000 return reshape(t1t2, []);
14001 }
14002 else if ($t1.rank === 1 && $t2.rank === 2) {
14003 const t12D = reshape($t1, [1, -1]);
14004 const t22D = reshape($t2, [$t2.shape[0], $t2.shape[1]]);
14005 const t1t2 = matMul(t12D, t22D);
14006 return reshape(t1t2, [t1t2.size]);
14007 }
14008 else if ($t1.rank === 2 && $t2.rank === 1) {
14009 const t22D = reshape($t2, [-1, 1]);
14010 const t1t2 = matMul($t1, t22D);
14011 return reshape(t1t2, [t1t2.size]);
14012 }
14013 else {
14014 const t22D = reshape($t2, [$t2.shape[0], $t2.shape[1]]);
14015 const t1t2 = matMul($t1, t22D);
14016 return t1t2;
14017 }
14018 }
14019 const dot = op({ dot_ });
14020
14021 /**
14022 * @license
14023 * Copyright 2021 Google LLC. All Rights Reserved.
14024 * Licensed under the Apache License, Version 2.0 (the "License");
14025 * you may not use this file except in compliance with the License.
14026 * You may obtain a copy of the License at
14027 *
14028 * http://www.apache.org/licenses/LICENSE-2.0
14029 *
14030 * Unless required by applicable law or agreed to in writing, software
14031 * distributed under the License is distributed on an "AS IS" BASIS,
14032 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14033 * See the License for the specific language governing permissions and
14034 * limitations under the License.
14035 * =============================================================================
14036 */
14037 /**
14038 * Tensor contraction over specified indices and outer product.
14039 *
14040 * `einsum` allows defining Tensors by defining their element-wise computation.
14041 * This computation is based on
14042 * [Einstein summation](https://en.wikipedia.org/wiki/Einstein_notation).
14043 *
14044 * Some special cases include:
14045 *
14046 * Matrix multiplication:
14047 * ```js
14048 * const x = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
14049 * const y = tf.tensor2d([[0, 1], [2, 3], [4, 5]]);
14050 * x.print();
14051 * y.print();
14052 * tf.einsum('ij,jk->ik', x, y).print();
14053 * ```
14054 *
14055 * Dot product:
14056 * ```js
14057 * const x = tf.tensor1d([1, 2, 3]);
14058 * const y = tf.tensor1d([0, 1, 2]);
14059 * x.print();
14060 * y.print();
14061 * tf.einsum('i,i->', x, y).print();
14062 * ```
14063 *
14064 * Batch dot product:
14065 * ```js
14066 * const x = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
14067 * const y = tf.tensor2d([[0, 1, 2], [3, 4, 5]]);
14068 * x.print();
14069 * y.print();
14070 * tf.einsum('bi,bi->b', x, y).print();
14071 * ```
14072 *
     * Outer product:
14074 * ```js
14075 * const x = tf.tensor1d([1, 3, 5]);
14076 * const y = tf.tensor1d([2, 4, 6]);
14077 * x.print();
14078 * y.print();
14079 * tf.einsum('i,j->ij', x, y).print();
14080 * ```
14081 *
14082 * Matrix transpose:
14083 * ```js
14084 * const x = tf.tensor2d([[1, 2], [3, 4]]);
14085 * x.print();
14086 * tf.einsum('ij->ji', x).print();
14087 * ```
14088 *
14089 * Batch matrix transpose:
14090 * ```js
14091 * const x = tf.tensor3d([[[1, 2], [3, 4]], [[-1, -2], [-3, -4]]]);
14092 * x.print();
14093 * tf.einsum('bij->bji', x).print();
14094 * ```
14095 *
14096 * Limitations:
14097 *
14098 * This implementation of einsum has the following limitations:
14099 *
14100 * - Does not support >2 input tensors.
14101 * - Does not support duplicate axes for any given input tensor. E.g., equation
     *   'ii->' is not supported.
14103 * - The `...` notation is not supported.
14104 *
14105 * @param equation a string describing the contraction, in the same format as
14106 * [numpy.einsum](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html).
14107 * @param tensors the input(s) to contract (each one a Tensor), whose shapes
14108 * should be consistent with equation.
14109 * @returns The output tensor.
14110 *
14111 * @doc {heading: 'Tensors', subheading: 'Matrices'}
14112 */
14113 function einsum_(equation, ...tensors) {
14114 const $tensors = tensors.map((t, i) => convertToTensor(t, `tensors${i}`, 'einsum'));
14115 const attrs = { equation };
14116 return ENGINE.runKernel(Einsum, $tensors, attrs);
14117 }
14118 const einsum = op({ einsum_ });
14119
14120 /**
14121 * @license
14122 * Copyright 2020 Google LLC. All Rights Reserved.
14123 * Licensed under the Apache License, Version 2.0 (the "License");
14124 * you may not use this file except in compliance with the License.
14125 * You may obtain a copy of the License at
14126 *
14127 * http://www.apache.org/licenses/LICENSE-2.0
14128 *
14129 * Unless required by applicable law or agreed to in writing, software
14130 * distributed under the License is distributed on an "AS IS" BASIS,
14131 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14132 * See the License for the specific language governing permissions and
14133 * limitations under the License.
14134 * =============================================================================
14135 */
14136 /**
14137 * Computes exponential linear element-wise: `x > 0 ? x : (e ^ x) - 1`.
14138 *
14139 * ```js
14140 * const x = tf.tensor1d([-1, 1, -3, 2]);
14141 *
14142 * x.elu().print(); // or tf.elu(x)
14143 * ```
14144 * @param x The input tensor.
14145 *
14146 * @doc {heading: 'Operations', subheading: 'Basic math'}
14147 */
14148 function elu_(x) {
14149 const $x = convertToTensor(x, 'x', 'elu', 'float32');
14150 const inputs = { x: $x };
14151 return ENGINE.runKernel(Elu, inputs);
14152 }
14153 const elu = op({ elu_ });
14154
14155 /**
14156 * @license
14157 * Copyright 2018 Google LLC. All Rights Reserved.
14158 * Licensed under the Apache License, Version 2.0 (the "License");
14159 * you may not use this file except in compliance with the License.
14160 * You may obtain a copy of the License at
14161 *
14162 * http://www.apache.org/licenses/LICENSE-2.0
14163 *
14164 * Unless required by applicable law or agreed to in writing, software
14165 * distributed under the License is distributed on an "AS IS" BASIS,
14166 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14167 * See the License for the specific language governing permissions and
14168 * limitations under the License.
14169 * =============================================================================
14170 */
14171 /**
     * Computes the Gauss error function of the input `tf.Tensor` element-wise:
14173 * `erf(x)`
14174 *
14175 * ```js
14176 * const x = tf.tensor1d([0, .1, -.1, .7]);
14177 *
14178 * x.erf().print(); // or tf.erf(x);
14179 * ```
14180 * @param x The input tensor.
14181 *
14182 * @doc {heading: 'Operations', subheading: 'Basic math'}
14183 */
14184 function erf_(x) {
14185 let $x = convertToTensor(x, 'x', 'erf');
14186 assert($x.dtype === 'int32' || $x.dtype === 'float32', () => 'Input dtype must be `int32` or `float32`.');
14187 if ($x.dtype === 'int32') {
14188 $x = cast($x, 'float32');
14189 }
14190 const inputs = { x: $x };
14191 return ENGINE.runKernel(Erf, inputs);
14192 }
14193 const erf = op({ erf_ });
14194
14195 /**
14196 * @license
14197 * Copyright 2017 Google LLC. All Rights Reserved.
14198 * Licensed under the Apache License, Version 2.0 (the "License");
14199 * you may not use this file except in compliance with the License.
14200 * You may obtain a copy of the License at
14201 *
14202 * http://www.apache.org/licenses/LICENSE-2.0
14203 *
14204 * Unless required by applicable law or agreed to in writing, software
14205 * distributed under the License is distributed on an "AS IS" BASIS,
14206 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14207 * See the License for the specific language governing permissions and
14208 * limitations under the License.
14209 * =============================================================================
14210 */
/**
 * Returns true if `axes` names exactly the inner-most (trailing)
 * dimensions of a rank-`rank` array, in ascending order.
 */
function axesAreInnerMostDims(axes, rank) {
    // The j-th entry of a trailing run of axes must equal
    // rank - axes.length + j.
    return axes.every((axis, j) => axis === rank - axes.length + j);
}
/**
 * Interleaves an output location and a reduced location back into a full
 * rank-`outputLoc.length + reduceLoc.length` coordinate: dimensions listed
 * in `axes` are filled from `reduceLoc`, all others from `outputLoc`.
 */
function combineLocations(outputLoc, reduceLoc, axes) {
    const reducedAxes = new Set(axes);
    const totalRank = outputLoc.length + reduceLoc.length;
    const location = [];
    let nextOut = 0;
    let nextReduce = 0;
    for (let dim = 0; dim < totalRank; dim++) {
        location.push(reducedAxes.has(dim) ? reduceLoc[nextReduce++] :
            outputLoc[nextOut++]);
    }
    return location;
}
/**
 * Splits `aShape` into the shape that survives a reduction over `axes`
 * (outShape) and the shape of the reduced dimensions (reduceShape).
 * Returns `[outShape, reduceShape]`.
 */
function computeOutAndReduceShapes(aShape, axes) {
    const reducedAxes = new Set(axes);
    const outShape = aShape.filter((_, dim) => !reducedAxes.has(dim));
    const reduceShape = axes.map(dim => aShape[dim]);
    return [outShape, reduceShape];
}
/**
 * Re-inserts length-1 entries at each reduced axis so that a reduced shape
 * matches the rank of the original input (the `keepDims` shape).
 */
function expandShapeToKeepDim(shape, axes) {
    const ones = axes.map(() => 1);
    return combineLocations(shape, ones, axes);
}
/** Throws unless `axes` addresses only the inner-most dims of a rank-`rank` input. */
function assertAxesAreInnerMostDims(msg, axes, rank) {
    const innerMost = axesAreInnerMostDims(axes, rank);
    assert(innerMost, () => `${msg} supports only inner-most axes for now. ` +
        `Got axes ${axes} and rank-${rank} input.`);
}
/**
 * Returns the axes permutation to be used with `tf.transpose`, if such
 * permutation is necessary. Otherwise it returns null. This method is used
 * by operations that operate only on inner-most axes: the permutation moves
 * every non-reduced axis to the front and every reduced axis to the back.
 */
function getAxesPermutation(axes, rank) {
    if (axesAreInnerMostDims(axes, rank)) {
        return null;
    }
    const perm = [];
    for (let dim = 0; dim < rank; ++dim) {
        if (!axes.includes(dim)) {
            perm.push(dim);
        }
    }
    return perm.concat(axes);
}
/** Returns the axes permutation that undoes the original permutation. */
function getUndoAxesPermutation(axes) {
    // Tag each axis with its position, order by axis value (stable sort),
    // then read the positions back out: the inverse permutation.
    const tagged = axes.map((axis, position) => ({ axis, position }));
    tagged.sort((lhs, rhs) => lhs.axis - rhs.axis);
    return tagged.map(entry => entry.position);
}
/** Returns the last `numAxes` axis indices of a rank-`rank` tensor, in order. */
function getInnerMostAxes(numAxes, rank) {
    return Array.from({ length: numAxes }, (_, i) => rank - numAxes + i);
}
14288
14289 /**
14290 * @license
14291 * Copyright 2020 Google LLC. All Rights Reserved.
14292 * Licensed under the Apache License, Version 2.0 (the "License");
14293 * you may not use this file except in compliance with the License.
14294 * You may obtain a copy of the License at
14295 *
14296 * http://www.apache.org/licenses/LICENSE-2.0
14297 *
14298 * Unless required by applicable law or agreed to in writing, software
14299 * distributed under the License is distributed on an "AS IS" BASIS,
14300 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14301 * See the License for the specific language governing permissions and
14302 * limitations under the License.
14303 * =============================================================================
14304 */
14305 /**
14306 * Computes the maximum of elements across dimensions of a `tf.Tensor`.
14307 *
14308 * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
14309 * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
14310 * `axes`. If `keepDims` is true, the reduced dimensions are retained with
14311 * length 1. If `axes` has no entries, all dimensions are reduced, and an
14312 * `tf.Tensor` with a single element is returned.
14313 *
14314 * ```js
14315 * const x = tf.tensor1d([1, 2, 3]);
14316 *
14317 * x.max().print(); // or tf.max(x)
14318 * ```
14319 *
14320 * ```js
14321 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
14322 *
14323 * const axis = 1;
14324 * x.max(axis).print(); // or tf.max(x, axis)
14325 * ```
14326 *
14327 * @param x The input tensor.
14328 * @param axis The dimension(s) to reduce. By default it reduces
14329 * all dimensions.
14330 * @param keepDims If true, retains reduced dimensions with size 1.
14331 *
14332 * @doc {heading: 'Operations', subheading: 'Reduction'}
14333 */
function max_(x, axis = null, keepDims = false) {
    const $x = convertToTensor(x, 'x', 'max');
    // The Max kernel names its reduction attribute `reductionIndices`.
    const attrs = { reductionIndices: axis, keepDims };
    return ENGINE.runKernel(Max, { x: $x }, attrs);
}
const max = op({ max_ });
14341
14342 /**
14343 * @license
14344 * Copyright 2020 Google Inc. All Rights Reserved.
14345 * Licensed under the Apache License, Version 2.0 (the "License");
14346 * you may not use this file except in compliance with the License.
14347 * You may obtain a copy of the License at
14348 *
14349 * http://www.apache.org/licenses/LICENSE-2.0
14350 *
14351 * Unless required by applicable law or agreed to in writing, software
14352 * distributed under the License is distributed on an "AS IS" BASIS,
14353 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14354 * See the License for the specific language governing permissions and
14355 * limitations under the License.
14356 * =============================================================================
14357 */
14358 /**
14359 * Computes the minimum value from the input.
14360 *
14361 * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
14362 * is true, the rank of the array is reduced by 1 for each entry in `axes`.
14363 * If `keepDims` is true, the reduced dimensions are retained with length 1.
14364 * If `axes` has no entries, all dimensions are reduced, and an array with a
14365 * single element is returned.
14366 *
14367 * ```js
14368 * const x = tf.tensor1d([1, 2, 3]);
14369 *
14370 * x.min().print(); // or tf.min(x)
14371 * ```
14372 *
14373 * ```js
14374 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
14375 *
14376 * const axis = 1;
14377 * x.min(axis).print(); // or tf.min(x, axis)
14378 * ```
14379 *
14380 * @param x The input Tensor.
14381 * @param axis The dimension(s) to reduce. By default it reduces
14382 * all dimensions.
14383 * @param keepDims If true, retains reduced dimensions with size 1.
14384 *
14385 * @doc {heading: 'Operations', subheading: 'Reduction'}
14386 */
function min_(x, axis = null, keepDims = false) {
    const $x = convertToTensor(x, 'x', 'min');
    const inputs = { x: $x };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    return ENGINE.runKernel(Min, inputs, { axis, keepDims });
}
const min = op({ min_ });
14395
14396 /**
14397 * @license
14398 * Copyright 2020 Google LLC. All Rights Reserved.
14399 * Licensed under the Apache License, Version 2.0 (the "License");
14400 * you may not use this file except in compliance with the License.
14401 * You may obtain a copy of the License at
14402 *
14403 * http://www.apache.org/licenses/LICENSE-2.0
14404 *
14405 * Unless required by applicable law or agreed to in writing, software
14406 * distributed under the License is distributed on an "AS IS" BASIS,
14407 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14408 * See the License for the specific language governing permissions and
14409 * limitations under the License.
14410 * =============================================================================
14411 */
14412 /**
14413 * Computes the power of one `tf.Tensor` to another. Supports broadcasting.
14414 *
14415 * Given a `tf.Tensor` x and a `tf.Tensor` y, this operation computes x^y for
14416 * corresponding elements in x and y. The result's dtype will be the upcasted
14417 * type of the `base` and `exp` dtypes.
14418 *
14419 * ```js
14420 * const a = tf.tensor([[2, 3], [4, 5]])
14421 * const b = tf.tensor([[1, 2], [3, 0]]).toInt();
14422 *
14423 * a.pow(b).print(); // or tf.pow(a, b)
14424 * ```
14425 *
14426 * ```js
14427 * const a = tf.tensor([[1, 2], [3, 4]])
14428 * const b = tf.tensor(2).toInt();
14429 *
14430 * a.pow(b).print(); // or tf.pow(a, b)
14431 * ```
14432 * We also expose `powStrict` which has the same signature as this op and
14433 * asserts that `base` and `exp` are the same shape (does not broadcast).
14434 *
14435 * @param base The base `tf.Tensor` to pow element-wise.
14436 * @param exp The exponent `tf.Tensor` to pow element-wise.
14437 *
14438 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
14439 */
function pow_(base, exp) {
    let $base = convertToTensor(base, 'base', 'pow');
    let $exp = convertToTensor(exp, 'exp', 'pow');
    // Upcast both operands to a common dtype before running the kernel.
    [$base, $exp] = makeTypesMatch($base, $exp);
    return ENGINE.runKernel(Pow, { a: $base, b: $exp });
}
const pow = op({ pow_ });
14448
14449 /**
14450 * @license
14451 * Copyright 2018 Google LLC. All Rights Reserved.
14452 * Licensed under the Apache License, Version 2.0 (the "License");
14453 * you may not use this file except in compliance with the License.
14454 * You may obtain a copy of the License at
14455 *
14456 * http://www.apache.org/licenses/LICENSE-2.0
14457 *
14458 * Unless required by applicable law or agreed to in writing, software
14459 * distributed under the License is distributed on an "AS IS" BASIS,
14460 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14461 * See the License for the specific language governing permissions and
14462 * limitations under the License.
14463 * =============================================================================
14464 */
14465 /**
14466 * Creates rank-0 `tf.Tensor` (scalar) with the provided value and dtype.
14467 *
14468 * The same functionality can be achieved with `tf.tensor`, but in general
14469 * we recommend using `tf.scalar` as it makes the code more readable.
14470 *
14471 * ```js
14472 * tf.scalar(3.14).print();
14473 * ```
14474 *
14475 * @param value The value of the scalar.
14476 * @param dtype The data type.
14477 *
14478 * @doc {heading: 'Tensors', subheading: 'Creation'}
14479 */
function scalar(value, dtype) {
    // A scalar must be a primitive — the only non-primitive exceptions are
    // complex64 pairs and Uint8Array-encoded strings, checked below.
    const isNonPrimitive = (isTypedArray(value) && dtype !== 'string') || Array.isArray(value);
    if (isNonPrimitive && dtype !== 'complex64') {
        throw new Error('Error creating a new Scalar: value must be a primitive ' +
            '(number|boolean|string)');
    }
    if (dtype === 'string' && isTypedArray(value) &&
        !(value instanceof Uint8Array)) {
        throw new Error('When making a scalar from encoded string, ' +
            'the value must be `Uint8Array`.');
    }
    // Rank 0: both the requested and the inferred shape are empty.
    return makeTensor(value, [], [], dtype);
}
14495
14496 /**
14497 * @license
14498 * Copyright 2018 Google LLC. All Rights Reserved.
14499 * Licensed under the Apache License, Version 2.0 (the "License");
14500 * you may not use this file except in compliance with the License.
14501 * You may obtain a copy of the License at
14502 *
14503 * http://www.apache.org/licenses/LICENSE-2.0
14504 *
14505 * Unless required by applicable law or agreed to in writing, software
14506 * distributed under the License is distributed on an "AS IS" BASIS,
14507 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14508 * See the License for the specific language governing permissions and
14509 * limitations under the License.
14510 * =============================================================================
14511 */
14512 /**
14513 * Computes square root of the input `tf.Tensor` element-wise: `y = sqrt(x)`
14514 *
14515 * ```js
14516 * const x = tf.tensor1d([1, 2, 4, -1]);
14517 *
14518 * x.sqrt().print(); // or tf.sqrt(x)
14519 * ```
14520 * @param x The input tensor.
14521 *
14522 * @doc {heading: 'Operations', subheading: 'Basic math'}
14523 */
function sqrt_(x) {
    // Input is converted (and validated) as float32 before the kernel runs.
    const $x = convertToTensor(x, 'x', 'sqrt', 'float32');
    return ENGINE.runKernel(Sqrt, { x: $x });
}
const sqrt = op({ sqrt_ });
14530
14531 /**
14532 * @license
14533 * Copyright 2019 Google LLC. All Rights Reserved.
14534 * Licensed under the Apache License, Version 2.0 (the "License");
14535 * you may not use this file except in compliance with the License.
14536 * You may obtain a copy of the License at
14537 *
14538 * http://www.apache.org/licenses/LICENSE-2.0
14539 *
14540 * Unless required by applicable law or agreed to in writing, software
14541 * distributed under the License is distributed on an "AS IS" BASIS,
14542 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14543 * See the License for the specific language governing permissions and
14544 * limitations under the License.
14545 * =============================================================================
14546 */
14547 /**
14548 * Computes square of `x` element-wise: `x ^ 2`
14549 *
14550 * ```js
14551 * const x = tf.tensor1d([1, 2, Math.sqrt(2), -1]);
14552 *
14553 * x.square().print(); // or tf.square(x)
14554 * ```
14555 * @param x The input Tensor.
14556 *
14557 * @doc {heading: 'Operations', subheading: 'Basic math'}
14558 */
function square_(x) {
    const $x = convertToTensor(x, 'x', 'square');
    // Square takes no attributes; pass an empty attrs object.
    return ENGINE.runKernel('Square', { x: $x }, {});
}
const square = op({ square_ });
14565
14566 /**
14567 * @license
14568 * Copyright 2018 Google LLC. All Rights Reserved.
14569 * Licensed under the Apache License, Version 2.0 (the "License");
14570 * you may not use this file except in compliance with the License.
14571 * You may obtain a copy of the License at
14572 *
14573 * http://www.apache.org/licenses/LICENSE-2.0
14574 *
14575 * Unless required by applicable law or agreed to in writing, software
14576 * distributed under the License is distributed on an "AS IS" BASIS,
14577 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14578 * See the License for the specific language governing permissions and
14579 * limitations under the License.
14580 * =============================================================================
14581 */
14582 /**
14583 * Computes the sum of elements across dimensions of a `tf.Tensor`.
14584 *
14585 * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
14586 * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
14587 * `axes`. If `keepDims` is true, the reduced dimensions are retained with
14588 * length 1. If axes has no entries, all dimensions are reduced, and a
14589 * `tf.Tensor` with a single element is returned.
14590 *
14591 * ```js
14592 * const x = tf.tensor1d([1, 2, 3]);
14593 *
14594 * x.sum().print(); // or tf.sum(x)
14595 * ```
14596 *
14597 * ```js
14598 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
14599 *
14600 * const axis = 1;
14601 * x.sum(axis).print(); // or tf.sum(x, axis)
14602 * ```
14603 *
14604 * @param x The input tensor to compute the sum over. If the dtype is `bool`
14605 * it will be converted to `int32` and the output dtype will be `int32`.
14606 * @param axis The dimension(s) to reduce. By default it reduces
14607 * all dimensions.
14608 * @param keepDims If true, retains reduced dimensions with size 1.
14609 *
14610 * @doc {heading: 'Operations', subheading: 'Reduction'}
14611 */
function sum_(x, axis = null, keepDims = false) {
    let $x = convertToTensor(x, 'x', 'sum');
    // Booleans are summed as 0/1 integers, so the output dtype is int32.
    if ($x.dtype === 'bool') {
        $x = cast($x, 'int32');
    }
    return ENGINE.runKernel(Sum, { x: $x }, { axis, keepDims });
}
const sum$1 = op({ sum_ });
14622
14623 /**
14624 * @license
14625 * Copyright 2018 Google LLC. All Rights Reserved.
14626 * Licensed under the Apache License, Version 2.0 (the "License");
14627 * you may not use this file except in compliance with the License.
14628 * You may obtain a copy of the License at
14629 *
14630 * http://www.apache.org/licenses/LICENSE-2.0
14631 *
14632 * Unless required by applicable law or agreed to in writing, software
14633 * distributed under the License is distributed on an "AS IS" BASIS,
14634 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14635 * See the License for the specific language governing permissions and
14636 * limitations under the License.
14637 * =============================================================================
14638 */
14639 /**
14640 * Computes the norm of scalar, vectors, and matrices.
14641 * This function can compute several different vector norms (the 1-norm, the
14642 * Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0)
14643 * and matrix norms (Frobenius, 1-norm, and inf-norm).
14644 *
14645 * ```js
14646 * const x = tf.tensor1d([1, 2, 3, 4]);
14647 *
14648 * x.norm().print(); // or tf.norm(x)
14649 * ```
14650 *
14651 * @param x The input array.
14652 * @param ord Optional. Order of the norm. Supported norm types are
14653 * following:
14654 *
14655 * | ord | norm for matrices | norm for vectors
14656 * |------------|---------------------------|---------------------
14657 * |'euclidean' |Frobenius norm |2-norm
14658 * |'fro' |Frobenius norm |
14659 * |Infinity |max(sum(abs(x), axis=1)) |max(abs(x))
14660 * |-Infinity |min(sum(abs(x), axis=1)) |min(abs(x))
14661 * |1 |max(sum(abs(x), axis=0)) |sum(abs(x))
14662 * |2 | |sum(abs(x)^2)^1/2*
14663 *
 * @param axis Optional. If axis is null (the default), the input is
 * considered a vector and a single vector norm is computed over the entire
 * set of values in the Tensor, i.e. norm(x, ord) is equivalent
 * to norm(x.reshape([-1]), ord). If axis is an integer, the input
 * is considered a batch of vectors, and axis determines the axis in x
 * over which to compute vector norms. If axis is a 2-tuple of integers it is
 * considered a batch of matrices and axis determines the axes in NDArray
 * over which to compute a matrix norm.
 * @param keepDims Optional. If true, the norm has the same dimensionality
 * as the input.
14674 *
14675 * @doc {heading: 'Operations', subheading: 'Matrices'}
14676 */
function norm_(x, ord = 'euclidean', axis = null, keepDims = false) {
    x = convertToTensor(x, 'x', 'norm');
    const norm = normImpl(x, ord, axis);
    if (!keepDims) {
        return reshape(norm, norm.shape);
    }
    // Re-insert length-1 dims at each reduced axis so the result keeps the
    // input's rank.
    const axes = parseAxisParam(axis, x.shape);
    return reshape(norm, expandShapeToKeepDim(norm.shape, axes));
}
/**
 * Core implementation behind `norm`: dispatches on the tensor's rank, the
 * order `p`, and `axis` to one of the supported vector or matrix norms.
 * Recurses at most once, to flatten a higher-rank tensor into a vector when
 * no axis is given. Throws for unsupported `p`/`axis` combinations.
 */
function normImpl(x, p, axis = null) {
    // A scalar's norm is its absolute value, regardless of `p`.
    if (x.rank === 0) {
        return abs(x);
    }
    // consider vector when no axis is specified
    if (x.rank !== 1 && axis === null) {
        return normImpl(reshape(x, [-1]), p, axis);
    }
    // vector
    if (x.rank === 1 || typeof axis === 'number' ||
        Array.isArray(axis) && axis.length === 1) {
        if (p === 1) {
            return sum$1(abs(x), axis);
        }
        if (p === Infinity) {
            return max(abs(x), axis);
        }
        if (p === -Infinity) {
            return min(abs(x), axis);
        }
        if (p === 'euclidean' || p === 2) {
            // norm(x, 2) = sum(abs(xi) ^ 2) ^ 1/2
            return sqrt(sum$1(pow(abs(x), scalar(2, 'int32')), axis));
        }
        throw new Error(`Error in norm: invalid ord value: ${p}`);
    }
    // matrix (assumption axis[0] < axis[1])
    if (Array.isArray(axis) && axis.length === 2) {
        if (p === 1) {
            // Max column sum. The inner sum reduces over axis[0] without
            // keepDims, so every later axis index shifts down by one —
            // hence `axis[1] - 1` for the outer max.
            return max(sum$1(abs(x), axis[0]), axis[1] - 1);
        }
        if (p === Infinity) {
            // Max row sum. axis[1] > axis[0] (by assumption), so reducing
            // over axis[1] first leaves axis[0] unchanged.
            return max(sum$1(abs(x), axis[1]), axis[0]);
        }
        if (p === -Infinity) {
            // Min row sum; same axis bookkeeping as the Infinity case.
            return min(sum$1(abs(x), axis[1]), axis[0]);
        }
        if (p === 'fro' || p === 'euclidean') {
            // norm(x) = sqrt(sum(pow(x, 2)))
            return sqrt(sum$1(square(x), axis));
        }
        throw new Error(`Error in norm: invalid ord value: ${p}`);
    }
    throw new Error(`Error in norm: invalid axis: ${axis}`);
}
const norm = op({ norm_ });
14733
14734 /**
14735 * @license
14736 * Copyright 2022 Google LLC. All Rights Reserved.
14737 * Licensed under the Apache License, Version 2.0 (the "License");
14738 * you may not use this file except in compliance with the License.
14739 * You may obtain a copy of the License at
14740 *
14741 * http://www.apache.org/licenses/LICENSE-2.0
14742 *
14743 * Unless required by applicable law or agreed to in writing, software
14744 * distributed under the License is distributed on an "AS IS" BASIS,
14745 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14746 * See the License for the specific language governing permissions and
14747 * limitations under the License.
14748 * =============================================================================
14749 */
14750 /**
14751 * Computes the euclidean norm of scalar, vectors, and matrices.
14752 *
14753 * ```js
14754 * const x = tf.tensor1d([1, 2, 3, 4]);
14755 *
14756 * x.euclideanNorm().print(); // or tf.euclideanNorm(x)
14757 * ```
14758 *
14759 * @param x The input array.
 * @param x The input array.
 * @param axis Optional. If axis is null (the default), the input is
 * considered a vector and a single vector norm is computed over the entire
 * set of values in the Tensor, i.e. euclideanNorm(x) is equivalent
 * to euclideanNorm(x.reshape([-1])). If axis is an integer, the input
 * is considered a batch of vectors, and axis determines the axis in x
 * over which to compute vector norms. If axis is a 2-tuple of integers it is
 * considered a batch of matrices and axis determines the axes in NDArray
 * over which to compute a matrix norm.
 * @param keepDims Optional. If true, the norm has the same dimensionality
 * as the input.
14770 *
14771 * @doc {heading: 'Operations', subheading: 'Matrices'}
14772 */
function euclideanNorm_(x, axis = null, keepDims = false) {
    // Thin wrapper: the general norm with ord fixed to 'euclidean'.
    return norm(x, 'euclidean', axis, keepDims);
}
const euclideanNorm = op({ euclideanNorm_ });
14777
14778 /**
14779 * @license
14780 * Copyright 2018 Google LLC. All Rights Reserved.
14781 * Licensed under the Apache License, Version 2.0 (the "License");
14782 * you may not use this file except in compliance with the License.
14783 * You may obtain a copy of the License at
14784 *
14785 * http://www.apache.org/licenses/LICENSE-2.0
14786 *
14787 * Unless required by applicable law or agreed to in writing, software
14788 * distributed under the License is distributed on an "AS IS" BASIS,
14789 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14790 * See the License for the specific language governing permissions and
14791 * limitations under the License.
14792 * =============================================================================
14793 */
14794 /**
14795 * Computes exponential of the input `tf.Tensor` element-wise. `e ^ x`
14796 *
14797 * ```js
14798 * const x = tf.tensor1d([1, 2, -3]);
14799 *
14800 * x.exp().print(); // or tf.exp(x)
14801 * ```
14802 * @param x The input tensor.
14803 *
14804 * @doc {heading: 'Operations', subheading: 'Basic math'}
14805 */
function exp_(x) {
    const $x = convertToTensor(x, 'x', 'exp');
    return ENGINE.runKernel(Exp, { x: $x });
}
const exp = op({ exp_ });
14812
14813 /**
14814 * @license
14815 * Copyright 2020 Google LLC. All Rights Reserved.
14816 * Licensed under the Apache License, Version 2.0 (the "License");
14817 * you may not use this file except in compliance with the License.
14818 * You may obtain a copy of the License at
14819 *
14820 * http://www.apache.org/licenses/LICENSE-2.0
14821 *
14822 * Unless required by applicable law or agreed to in writing, software
14823 * distributed under the License is distributed on an "AS IS" BASIS,
14824 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14825 * See the License for the specific language governing permissions and
14826 * limitations under the License.
14827 * =============================================================================
14828 */
14829 /**
14830 * Returns a `tf.Tensor` that has expanded rank, by inserting a dimension
14831 * into the tensor's shape.
14832 *
14833 * ```js
14834 * const x = tf.tensor1d([1, 2, 3, 4]);
14835 * const axis = 1;
14836 * x.expandDims(axis).print();
14837 * ```
14838 *
14839 * @param x The input tensor whose dimensions to be expanded.
14840 * @param axis The dimension index at which to insert shape of `1`. Defaults
14841 * to 0 (the first dimension).
14842 *
14843 * @doc {heading: 'Tensors', subheading: 'Transformations'}
14844 */
function expandDims_(x, axis = 0) {
    const $x = convertToTensor(x, 'x', 'expandDims', 'string_or_numeric');
    // axis === rank is allowed: it appends a trailing dimension.
    assert(axis <= $x.rank, () => 'Axis must be <= rank of the tensor');
    return ENGINE.runKernel(ExpandDims, { input: $x }, { dim: axis });
}
const expandDims = op({ expandDims_ });
14853
14854 /**
14855 * @license
14856 * Copyright 2018 Google LLC. All Rights Reserved.
14857 * Licensed under the Apache License, Version 2.0 (the "License");
14858 * you may not use this file except in compliance with the License.
14859 * You may obtain a copy of the License at
14860 *
14861 * http://www.apache.org/licenses/LICENSE-2.0
14862 *
14863 * Unless required by applicable law or agreed to in writing, software
14864 * distributed under the License is distributed on an "AS IS" BASIS,
14865 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14866 * See the License for the specific language governing permissions and
14867 * limitations under the License.
14868 * =============================================================================
14869 */
14870 /**
14871 * Computes exponential of the input `tf.Tensor` minus one element-wise.
14872 * `e ^ x - 1`
14873 *
14874 * ```js
14875 * const x = tf.tensor1d([1, 2, -3]);
14876 *
14877 * x.expm1().print(); // or tf.expm1(x)
14878 * ```
14879 * @param x The input tensor.
14880 *
14881 * @doc {heading: 'Operations', subheading: 'Basic math'}
14882 */
function expm1_(x) {
    const $x = convertToTensor(x, 'x', 'expm1');
    return ENGINE.runKernel(Expm1, { x: $x });
}
const expm1 = op({ expm1_ });
14889
14890 /**
14891 * @license
14892 * Copyright 2020 Google LLC. All Rights Reserved.
14893 * Licensed under the Apache License, Version 2.0 (the "License");
14894 * you may not use this file except in compliance with the License.
14895 * You may obtain a copy of the License at
14896 *
14897 * http://www.apache.org/licenses/LICENSE-2.0
14898 *
14899 * Unless required by applicable law or agreed to in writing, software
14900 * distributed under the License is distributed on an "AS IS" BASIS,
14901 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14902 * See the License for the specific language governing permissions and
14903 * limitations under the License.
14904 * =============================================================================
14905 */
14906 /**
14907 * Construct a tensor by repeating it the number of times given by reps.
14908 *
14909 * This operation creates a new tensor by replicating `input` `reps`
14910 * times. The output tensor's i'th dimension has `input.shape[i] *
14911 * reps[i]` elements, and the values of `input` are replicated
14912 * `reps[i]` times along the i'th dimension. For example, tiling
14913 * `[a, b, c, d]` by `[2]` produces `[a, b, c, d, a, b, c, d]`.
14914 *
14915 * ```js
14916 * const a = tf.tensor1d([1, 2]);
14917 *
14918 * a.tile([2]).print(); // or a.tile([2])
14919 * ```
14920 *
14921 * ```js
14922 * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
14923 *
14924 * a.tile([1, 2]).print(); // or a.tile([1, 2])
14925 * ```
14926 * @param x The tensor to tile.
14927 * @param reps Determines the number of replications per dimension.
14928 *
14929 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
14930 */
/**
 * Replicates `x` `reps[i]` times along each dimension i. `reps` must have
 * one entry per dimension of `x`.
 *
 * @param x The tensor to tile.
 * @param reps Determines the number of replications per dimension.
 * @throws If `x.rank !== reps.length`.
 */
function tile_(x, reps) {
    const $x = convertToTensor(x, 'x', 'tile', 'string_or_numeric');
    // Fixed: the message used to say "Error in transpose" — a copy-paste
    // mistake; this op is tile.
    assert($x.rank === reps.length, () => `Error in tile: rank of input ${$x.rank} ` +
        `must match length of reps ${reps}.`);
    const inputs = { x: $x };
    const attrs = { reps };
    return ENGINE.runKernel(Tile, inputs, attrs);
}
const tile = op({ tile_ });
14940
14941 /**
14942 * @license
14943 * Copyright 2020 Google LLC. All Rights Reserved.
14944 * Licensed under the Apache License, Version 2.0 (the "License");
14945 * you may not use this file except in compliance with the License.
14946 * You may obtain a copy of the License at
14947 *
14948 * http://www.apache.org/licenses/LICENSE-2.0
14949 *
14950 * Unless required by applicable law or agreed to in writing, software
14951 * distributed under the License is distributed on an "AS IS" BASIS,
14952 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14953 * See the License for the specific language governing permissions and
14954 * limitations under the License.
14955 * =============================================================================
14956 */
14957 /**
14958 * Create an identity matrix.
14959 *
14960 * @param numRows Number of rows.
14961 * @param numColumns Number of columns. Defaults to `numRows`.
14962 * @param batchShape If provided, will add the batch shape to the beginning
14963 * of the shape of the returned `tf.Tensor` by repeating the identity
14964 * matrix.
14965 * @param dtype Data type.
14966 * @returns Identity matrix of the specified size and data type, possibly
14967 * with batch repetition if `batchShape` is specified.
14968 *
14969 * @doc {heading: 'Tensors', subheading: 'Creation'}
14970 */
function eye_(numRows, numColumns, batchShape, dtype = 'float32') {
    if (numColumns == null) {
        numColumns = numRows;
    }
    // Write 1s along the main diagonal of an all-zero buffer.
    const buff = buffer([numRows, numColumns], dtype);
    const diagLen = numRows <= numColumns ? numRows : numColumns;
    for (let i = 0; i < diagLen; ++i) {
        buff.set(1, i, i);
    }
    const out = reshape(buff.toTensor(), [numRows, numColumns]);
    if (batchShape == null) {
        return out;
    }
    // Prepend one length-1 dim per batch dim, then tile the identity matrix
    // across the batch.
    switch (batchShape.length) {
        case 1:
            return tile(expandDims(out, 0), [batchShape[0], 1, 1]);
        case 2:
            // tslint:disable-next-line:no-unnecessary-type-assertion
            return tile(expandDims(expandDims(out, 0), 0), [batchShape[0], batchShape[1], 1, 1]);
        case 3:
            // tslint:disable-next-line:no-unnecessary-type-assertion
            return tile(expandDims(expandDims(expandDims(out, 0), 0), 0), [
                batchShape[0], batchShape[1], batchShape[2], 1, 1
            ]);
        default:
            throw new Error(`eye() currently supports only 1D and 2D ` +
                // tslint:disable-next-line:no-any
                `batchShapes, but received ${batchShape.length}D.`);
    }
}
const eye = op({ eye_ });
15006
15007 /**
15008 * @license
15009 * Copyright 2020 Google LLC. All Rights Reserved.
15010 * Licensed under the Apache License, Version 2.0 (the "License");
15011 * you may not use this file except in compliance with the License.
15012 * You may obtain a copy of the License at
15013 *
15014 * http://www.apache.org/licenses/LICENSE-2.0
15015 *
15016 * Unless required by applicable law or agreed to in writing, software
15017 * distributed under the License is distributed on an "AS IS" BASIS,
15018 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15019 * See the License for the specific language governing permissions and
15020 * limitations under the License.
15021 * =============================================================================
15022 */
15023 /**
15024 * Creates a `tf.Tensor` filled with a scalar value.
15025 *
15026 * ```js
15027 * tf.fill([2, 2], 4).print();
15028 * ```
15029 *
15030 * @param shape An array of integers defining the output tensor shape.
15031 * @param value The scalar value to fill the tensor with.
 * @param dtype The type of an element in the resulting tensor. Defaults to
 * 'float32'.
15034 *
15035 * @doc {heading: 'Tensors', subheading: 'Creation'}
15036 */
function fill(shape, value, dtype) {
    // Fill takes no tensor inputs; shape, value and dtype all travel to the
    // kernel as attributes.
    return ENGINE.runKernel(Fill, {}, { shape, value, dtype });
}
15041
15042 /**
15043 * @license
15044 * Copyright 2018 Google LLC. All Rights Reserved.
15045 * Licensed under the Apache License, Version 2.0 (the "License");
15046 * you may not use this file except in compliance with the License.
15047 * You may obtain a copy of the License at
15048 *
15049 * http://www.apache.org/licenses/LICENSE-2.0
15050 *
15051 * Unless required by applicable law or agreed to in writing, software
15052 * distributed under the License is distributed on an "AS IS" BASIS,
15053 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15054 * See the License for the specific language governing permissions and
15055 * limitations under the License.
15056 * =============================================================================
15057 */
15058 /**
15059 * Computes floor of input `tf.Tensor` element-wise: `floor(x)`.
15060 *
15061 * ```js
15062 * const x = tf.tensor1d([.6, 1.1, -3.3]);
15063 *
15064 * x.floor().print(); // or tf.floor(x)
15065 * ```
15066 * @param x The input tensor.
15067 *
15068 * @doc {heading: 'Operations', subheading: 'Basic math'}
15069 */
function floor_(x) {
    // Coerce the input (tensor or TensorLike) to a float32 tensor before
    // dispatching to the Floor kernel.
    const input = convertToTensor(x, 'x', 'floor', 'float32');
    return ENGINE.runKernel(Floor, { x: input });
}
const floor = op({ floor_ });
15076
15077 /**
15078 * @license
15079 * Copyright 2018 Google LLC. All Rights Reserved.
15080 * Licensed under the Apache License, Version 2.0 (the "License");
15081 * you may not use this file except in compliance with the License.
15082 * You may obtain a copy of the License at
15083 *
15084 * http://www.apache.org/licenses/LICENSE-2.0
15085 *
15086 * Unless required by applicable law or agreed to in writing, software
15087 * distributed under the License is distributed on an "AS IS" BASIS,
15088 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15089 * See the License for the specific language governing permissions and
15090 * limitations under the License.
15091 * =============================================================================
15092 */
15093 /**
15094 * Gather slices from tensor `x`'s axis `axis` according to `indices`.
15095 *
15096 * ```js
15097 * const x = tf.tensor1d([1, 2, 3, 4]);
15098 * const indices = tf.tensor1d([1, 3, 3], 'int32');
15099 *
15100 * x.gather(indices).print();
15101 * ```
15102 *
15103 * ```js
15104 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
15105 * const indices = tf.tensor1d([1, 1, 0], 'int32');
15106 *
15107 * x.gather(indices).print();
15108 * ```
15109 * @param x The input tensor whose slices to be gathered.
15110 * @param indices The indices of the values to extract.
15111 * @param axis The axis over which to select values. Defaults to 0.
15112 * @param batchDims Optional. The number of batch dimensions. It must be less
15113 * than or equal to rank(indices). Defaults to 0.
15114 * The output tensor will have shape of
15115 * `x.shape[:axis] + indices.shape[batchDims:] + x.shape[axis + 1:]`
15116 *
15117 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
15118 */
function gather_(x, indices, axis = 0, batchDims = 0) {
    // x may be any dtype; indices are forced to int32 for the kernel.
    const xTensor = convertToTensor(x, 'x', 'gather');
    const idxTensor = convertToTensor(indices, 'indices', 'gather', 'int32');
    return ENGINE.runKernel(GatherV2, { x: xTensor, indices: idxTensor }, { axis, batchDims });
}
const gather = op({ gather_ });
15127
15128 /**
15129 * @license
15130 * Copyright 2020 Google LLC. All Rights Reserved.
15131 * Licensed under the Apache License, Version 2.0 (the "License");
15132 * you may not use this file except in compliance with the License.
15133 * You may obtain a copy of the License at
15134 *
15135 * http://www.apache.org/licenses/LICENSE-2.0
15136 *
15137 * Unless required by applicable law or agreed to in writing, software
15138 * distributed under the License is distributed on an "AS IS" BASIS,
15139 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15140 * See the License for the specific language governing permissions and
15141 * limitations under the License.
15142 * =============================================================================
15143 */
15144 /**
15145 * Returns the truth value of (a > b) element-wise. Supports broadcasting.
15146 *
15147 * ```js
15148 * const a = tf.tensor1d([1, 2, 3]);
15149 * const b = tf.tensor1d([2, 2, 2]);
15150 *
15151 * a.greater(b).print();
15152 * ```
15153 *
15154 * @param a The first input tensor.
15155 * @param b The second input tensor. Must have the same dtype as `a`.
15156 *
15157 * @doc {heading: 'Operations', subheading: 'Logical'}
15158 */
function greater_(a, b) {
    // Both operands accept strings or numerics; unify their dtypes before
    // the comparison.
    let lhs = convertToTensor(a, 'a', 'greater', 'string_or_numeric');
    let rhs = convertToTensor(b, 'b', 'greater', 'string_or_numeric');
    [lhs, rhs] = makeTypesMatch(lhs, rhs);
    // Validates broadcast compatibility; the computed shape is not needed.
    assertAndGetBroadcastShape(lhs.shape, rhs.shape);
    return ENGINE.runKernel(Greater, { a: lhs, b: rhs });
}
const greater = op({ greater_ });
15168
15169 /**
15170 * @license
15171 * Copyright 2020 Google LLC. All Rights Reserved.
15172 * Licensed under the Apache License, Version 2.0 (the "License");
15173 * you may not use this file except in compliance with the License.
15174 * You may obtain a copy of the License at
15175 *
15176 * http://www.apache.org/licenses/LICENSE-2.0
15177 *
15178 * Unless required by applicable law or agreed to in writing, software
15179 * distributed under the License is distributed on an "AS IS" BASIS,
15180 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15181 * See the License for the specific language governing permissions and
15182 * limitations under the License.
15183 * =============================================================================
15184 */
15185 /**
15186 * Returns the truth value of (a >= b) element-wise. Supports broadcasting.
15187 *
15188 * ```js
15189 * const a = tf.tensor1d([1, 2, 3]);
15190 * const b = tf.tensor1d([2, 2, 2]);
15191 *
15192 * a.greaterEqual(b).print();
15193 * ```
15194 *
15195 * @param a The first input tensor.
15196 * @param b The second input tensor. Must have the same dtype as `a`.
15197 *
15198 * @doc {heading: 'Operations', subheading: 'Logical'}
15199 */
function greaterEqual_(a, b) {
    // Accepts string or numeric operands; dtypes are unified before the
    // element-wise comparison.
    let lhs = convertToTensor(a, 'a', 'greaterEqual', 'string_or_numeric');
    let rhs = convertToTensor(b, 'b', 'greaterEqual', 'string_or_numeric');
    [lhs, rhs] = makeTypesMatch(lhs, rhs);
    // Throws if the shapes cannot broadcast together.
    assertAndGetBroadcastShape(lhs.shape, rhs.shape);
    return ENGINE.runKernel(GreaterEqual, { a: lhs, b: rhs });
}
const greaterEqual = op({ greaterEqual_ });
15209
15210 /**
15211 * @license
15212 * Copyright 2018 Google LLC. All Rights Reserved.
15213 * Licensed under the Apache License, Version 2.0 (the "License");
15214 * you may not use this file except in compliance with the License.
15215 * You may obtain a copy of the License at
15216 *
15217 * http://www.apache.org/licenses/LICENSE-2.0
15218 *
15219 * Unless required by applicable law or agreed to in writing, software
15220 * distributed under the License is distributed on an "AS IS" BASIS,
15221 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15222 * See the License for the specific language governing permissions and
15223 * limitations under the License.
15224 * =============================================================================
15225 */
15226 /**
15227 * Returns which elements of x are finite.
15228 *
15229 * ```js
15230 * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
15231 *
 * x.isFinite().print(); // or tf.isFinite(x)
15233 * ```
15234 * @param x The input Tensor.
15235 *
15236 * @doc {heading: 'Operations', subheading: 'Basic math'}
15237 */
function isFinite_(x) {
    // Element-wise finiteness test; dispatches straight to the IsFinite kernel.
    const input = convertToTensor(x, 'x', 'isFinite');
    return ENGINE.runKernel(IsFinite, { x: input });
}
const isFinite$1 = op({ isFinite_ });
15244
15245 /**
15246 * @license
15247 * Copyright 2018 Google LLC. All Rights Reserved.
15248 * Licensed under the Apache License, Version 2.0 (the "License");
15249 * you may not use this file except in compliance with the License.
15250 * You may obtain a copy of the License at
15251 *
15252 * http://www.apache.org/licenses/LICENSE-2.0
15253 *
15254 * Unless required by applicable law or agreed to in writing, software
15255 * distributed under the License is distributed on an "AS IS" BASIS,
15256 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15257 * See the License for the specific language governing permissions and
15258 * limitations under the License.
15259 * =============================================================================
15260 */
15261 /**
15262 * Returns which elements of x are Infinity or -Infinity.
15263 *
15264 * ```js
15265 * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
15266 *
 * x.isInf().print(); // or tf.isInf(x)
15268 * ```
15269 * @param x The input Tensor.
15270 *
15271 * @doc {heading: 'Operations', subheading: 'Basic math'}
15272 */
function isInf_(x) {
    // Element-wise +/-Infinity test; dispatches straight to the IsInf kernel.
    const input = convertToTensor(x, 'x', 'isInf');
    return ENGINE.runKernel(IsInf, { x: input });
}
const isInf = op({ isInf_ });
15279
15280 /**
15281 * @license
15282 * Copyright 2018 Google LLC. All Rights Reserved.
15283 * Licensed under the Apache License, Version 2.0 (the "License");
15284 * you may not use this file except in compliance with the License.
15285 * You may obtain a copy of the License at
15286 *
15287 * http://www.apache.org/licenses/LICENSE-2.0
15288 *
15289 * Unless required by applicable law or agreed to in writing, software
15290 * distributed under the License is distributed on an "AS IS" BASIS,
15291 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15292 * See the License for the specific language governing permissions and
15293 * limitations under the License.
15294 * =============================================================================
15295 */
15296 /**
 * Returns which elements of x are NaN.
15298 *
15299 * ```js
15300 * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
15301 *
15302 * x.isNaN().print(); // or tf.isNaN(x)
15303 * ```
15304 * @param x The input Tensor.
15305 *
15306 * @doc {heading: 'Operations', subheading: 'Basic math'}
15307 */
function isNaN_(x) {
    // Element-wise NaN test; dispatches straight to the IsNan kernel.
    const input = convertToTensor(x, 'x', 'isNaN');
    return ENGINE.runKernel(IsNan, { x: input });
}
const isNaN$1 = op({ isNaN_ });
15314
15315 /**
15316 * @license
15317 * Copyright 2020 Google LLC. All Rights Reserved.
15318 * Licensed under the Apache License, Version 2.0 (the "License");
15319 * you may not use this file except in compliance with the License.
15320 * You may obtain a copy of the License at
15321 *
15322 * http://www.apache.org/licenses/LICENSE-2.0
15323 *
15324 * Unless required by applicable law or agreed to in writing, software
15325 * distributed under the License is distributed on an "AS IS" BASIS,
15326 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15327 * See the License for the specific language governing permissions and
15328 * limitations under the License.
15329 * =============================================================================
15330 */
15331 /**
15332 * Computes leaky rectified linear element-wise.
15333 *
15334 * See
15335 * [http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf](
15336 * http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf)
15337 *
15338 * ```js
15339 * const x = tf.tensor1d([-1, 2, -3, 4]);
15340 *
15341 * x.leakyRelu(0.1).print(); // or tf.leakyRelu(x, 0.1)
15342 * ```
15343 * @param x The input tensor.
15344 * @param alpha The scaling factor for negative values, defaults to 0.2.
15345 *
15346 * @doc {heading: 'Operations', subheading: 'Basic math'}
15347 */
function leakyRelu_(x, alpha = 0.2) {
    // alpha scales the negative part; it travels to the kernel as an attr.
    const input = convertToTensor(x, 'x', 'leakyRelu');
    return ENGINE.runKernel(LeakyRelu, { x: input }, { alpha });
}
const leakyRelu = op({ leakyRelu_ });
15355
15356 /**
15357 * @license
15358 * Copyright 2020 Google LLC. All Rights Reserved.
15359 * Licensed under the Apache License, Version 2.0 (the "License");
15360 * you may not use this file except in compliance with the License.
15361 * You may obtain a copy of the License at
15362 *
15363 * http://www.apache.org/licenses/LICENSE-2.0
15364 *
15365 * Unless required by applicable law or agreed to in writing, software
15366 * distributed under the License is distributed on an "AS IS" BASIS,
15367 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15368 * See the License for the specific language governing permissions and
15369 * limitations under the License.
15370 * =============================================================================
15371 */
15372 /**
15373 * Returns the truth value of (a < b) element-wise. Supports broadcasting.
15374 *
15375 * ```js
15376 * const a = tf.tensor1d([1, 2, 3]);
15377 * const b = tf.tensor1d([2, 2, 2]);
15378 *
15379 * a.less(b).print();
15380 * ```
15381 * @param a The first input tensor.
15382 * @param b The second input tensor. Must have the same dtype as `a`.
15383 *
15384 * @doc {heading: 'Operations', subheading: 'Logical'}
15385 */
function less_(a, b) {
    // Accepts string or numeric operands; dtypes are unified before the
    // element-wise comparison.
    let lhs = convertToTensor(a, 'a', 'less', 'string_or_numeric');
    let rhs = convertToTensor(b, 'b', 'less', 'string_or_numeric');
    [lhs, rhs] = makeTypesMatch(lhs, rhs);
    // Throws if the shapes cannot broadcast together.
    assertAndGetBroadcastShape(lhs.shape, rhs.shape);
    return ENGINE.runKernel(Less, { a: lhs, b: rhs });
}
const less = op({ less_ });
15395
15396 /**
15397 * @license
15398 * Copyright 2020 Google LLC. All Rights Reserved.
15399 * Licensed under the Apache License, Version 2.0 (the "License");
15400 * you may not use this file except in compliance with the License.
15401 * You may obtain a copy of the License at
15402 *
15403 * http://www.apache.org/licenses/LICENSE-2.0
15404 *
15405 * Unless required by applicable law or agreed to in writing, software
15406 * distributed under the License is distributed on an "AS IS" BASIS,
15407 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15408 * See the License for the specific language governing permissions and
15409 * limitations under the License.
15410 * =============================================================================
15411 */
15412 /**
15413 * Returns the truth value of (a <= b) element-wise. Supports broadcasting.
15414 *
15415 * ```js
15416 * const a = tf.tensor1d([1, 2, 3]);
15417 * const b = tf.tensor1d([2, 2, 2]);
15418 *
15419 * a.lessEqual(b).print();
15420 * ```
15421 *
15422 * @param a The first input tensor.
15423 * @param b The second input tensor. Must have the same dtype as `a`.
15424 *
15425 * @doc {heading: 'Operations', subheading: 'Logical'}
15426 */
function lessEqual_(a, b) {
    // Accepts string or numeric operands; dtypes are unified before the
    // element-wise comparison.
    let lhs = convertToTensor(a, 'a', 'lessEqual', 'string_or_numeric');
    let rhs = convertToTensor(b, 'b', 'lessEqual', 'string_or_numeric');
    [lhs, rhs] = makeTypesMatch(lhs, rhs);
    // Throws if the shapes cannot broadcast together.
    assertAndGetBroadcastShape(lhs.shape, rhs.shape);
    return ENGINE.runKernel(LessEqual, { a: lhs, b: rhs });
}
const lessEqual = op({ lessEqual_ });
15436
15437 /**
15438 * @license
15439 * Copyright 2018 Google LLC. All Rights Reserved.
15440 * Licensed under the Apache License, Version 2.0 (the "License");
15441 * you may not use this file except in compliance with the License.
15442 * You may obtain a copy of the License at
15443 *
15444 * http://www.apache.org/licenses/LICENSE-2.0
15445 *
15446 * Unless required by applicable law or agreed to in writing, software
15447 * distributed under the License is distributed on an "AS IS" BASIS,
15448 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15449 * See the License for the specific language governing permissions and
15450 * limitations under the License.
15451 * =============================================================================
15452 */
15453 /**
15454 * Return an evenly spaced sequence of numbers over the given interval.
15455 *
15456 * ```js
15457 * tf.linspace(0, 9, 10).print();
15458 * ```
15459 * @param start The start value of the sequence.
15460 * @param stop The end value of the sequence.
15461 * @param num The number of values to generate.
15462 *
15463 * @doc {heading: 'Tensors', subheading: 'Creation'}
15464 */
function linspace(start, stop, num) {
    // Guard: a non-positive count has no meaningful sequence.
    if (num <= 0) {
        throw new Error('The number of values should be positive.');
    }
    // LinSpace takes no tensor inputs; all arguments are kernel attrs.
    return ENGINE.runKernel(LinSpace, {}, { start, stop, num });
}
15472
15473 /**
15474 * @license
15475 * Copyright 2020 Google LLC. All Rights Reserved.
15476 * Licensed under the Apache License, Version 2.0 (the "License");
15477 * you may not use this file except in compliance with the License.
15478 * You may obtain a copy of the License at
15479 *
15480 * http://www.apache.org/licenses/LICENSE-2.0
15481 *
15482 * Unless required by applicable law or agreed to in writing, software
15483 * distributed under the License is distributed on an "AS IS" BASIS,
15484 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15485 * See the License for the specific language governing permissions and
15486 * limitations under the License.
15487 * =============================================================================
15488 */
15489 /**
15490 * Normalizes the activation of a local neighborhood across or within
15491 * channels.
15492 *
15493 * @param x The input tensor. The 4-D input tensor is treated as a 3-D array
15494 * of 1D vectors (along the last dimension), and each vector is
15495 * normalized independently.
15496 * @param depthRadius The number of adjacent channels in the 1D normalization
15497 * window.
15498 * @param bias A constant bias term for the basis.
15499 * @param alpha A scale factor, usually positive.
15500 * @param beta An exponent.
15501 *
15502 * @doc {heading: 'Operations', subheading: 'Normalization'}
15503 */
function localResponseNormalization_(x, depthRadius = 5, bias = 1, alpha = 1, beta = 0.5) {
    const $x = convertToTensor(x, 'x', 'localResponseNormalization');
    // Only rank-3 (unbatched) and rank-4 (batched) inputs are supported.
    assert($x.rank === 4 || $x.rank === 3, () => `Error in localResponseNormalization: x must be rank 3 or 4 but got
         rank ${$x.rank}.`);
    assert(isInt(depthRadius), () => `Error in localResponseNormalization: depthRadius must be an ` +
        `integer but got depthRadius ${depthRadius}.`);
    let x4D = $x;
    let reshapedTo4D = false;
    // A rank-3 input is promoted to rank-4 (batch of 1) so the LRN kernel
    // always sees a 4-D tensor; the result is squeezed back below.
    if ($x.rank === 3) {
        reshapedTo4D = true;
        x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
    }
    const inputs = { x: x4D };
    const attrs = { depthRadius, bias, alpha, beta };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    const res = ENGINE.runKernel(LRN, inputs, attrs);
    // Drop the synthetic batch dimension if one was added above.
    if (reshapedTo4D) {
        return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
    }
    else {
        return res;
    }
}
const localResponseNormalization = op({ localResponseNormalization_ });
15528
15529 /**
15530 * @license
15531 * Copyright 2018 Google LLC. All Rights Reserved.
15532 * Licensed under the Apache License, Version 2.0 (the "License");
15533 * you may not use this file except in compliance with the License.
15534 * You may obtain a copy of the License at
15535 *
15536 * http://www.apache.org/licenses/LICENSE-2.0
15537 *
15538 * Unless required by applicable law or agreed to in writing, software
15539 * distributed under the License is distributed on an "AS IS" BASIS,
15540 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15541 * See the License for the specific language governing permissions and
15542 * limitations under the License.
15543 * =============================================================================
15544 */
15545 /**
15546 * Computes natural logarithm of the input `tf.Tensor` element-wise: `ln(x)`
15547 *
15548 * ```js
15549 * const x = tf.tensor1d([1, 2, Math.E]);
15550 *
15551 * x.log().print(); // or tf.log(x)
15552 * ```
15553 * @param x The input tensor.
15554 *
15555 * @doc {heading: 'Operations', subheading: 'Basic math'}
15556 */
function log_(x) {
    // Natural logarithm; the input is coerced to float32 before dispatch.
    const input = convertToTensor(x, 'x', 'log', 'float32');
    return ENGINE.runKernel(Log, { x: input });
}
const log$1 = op({ log_ });
15563
15564 /**
15565 * @license
15566 * Copyright 2018 Google LLC. All Rights Reserved.
15567 * Licensed under the Apache License, Version 2.0 (the "License");
15568 * you may not use this file except in compliance with the License.
15569 * You may obtain a copy of the License at
15570 *
15571 * http://www.apache.org/licenses/LICENSE-2.0
15572 *
15573 * Unless required by applicable law or agreed to in writing, software
15574 * distributed under the License is distributed on an "AS IS" BASIS,
15575 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15576 * See the License for the specific language governing permissions and
15577 * limitations under the License.
15578 * =============================================================================
15579 */
15580 /**
15581 * Computes natural logarithm of the input `tf.Tensor` plus one
15582 * element-wise: `ln(1 + x)`
15583 *
15584 * ```js
15585 * const x = tf.tensor1d([1, 2, Math.E - 1]);
15586 *
15587 * x.log1p().print(); // or tf.log1p(x)
15588 * ```
15589 * @param x The input tensor.
15590 *
15591 * @doc {heading: 'Operations', subheading: 'Basic math'}
15592 */
function log1p_(x) {
    // ln(1 + x), element-wise; dispatches straight to the Log1p kernel.
    const input = convertToTensor(x, 'x', 'log1p');
    return ENGINE.runKernel(Log1p, { x: input });
}
const log1p = op({ log1p_ });
15599
15600 /**
15601 * @license
15602 * Copyright 2018 Google LLC. All Rights Reserved.
15603 * Licensed under the Apache License, Version 2.0 (the "License");
15604 * you may not use this file except in compliance with the License.
15605 * You may obtain a copy of the License at
15606 *
15607 * http://www.apache.org/licenses/LICENSE-2.0
15608 *
15609 * Unless required by applicable law or agreed to in writing, software
15610 * distributed under the License is distributed on an "AS IS" BASIS,
15611 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15612 * See the License for the specific language governing permissions and
15613 * limitations under the License.
15614 * =============================================================================
15615 */
15616 /**
15617 * Provided `f(x)`, returns another function `g(x, dy?)`, which gives the
15618 * gradient of `f(x)` with respect to `x`.
15619 *
15620 * If `dy` is provided, the gradient of `f(x).mul(dy).sum()` with respect to
15621 * `x` is computed instead. `f(x)` must take a single tensor `x` and return a
15622 * single tensor `y`. If `f()` takes multiple inputs, use `tf.grads` instead.
15623 *
15624 * ```js
15625 * // f(x) = x ^ 2
15626 * const f = x => x.square();
15627 * // f'(x) = 2x
15628 * const g = tf.grad(f);
15629 *
15630 * const x = tf.tensor1d([2, 3]);
15631 * g(x).print();
15632 * ```
15633 *
15634 * ```js
15635 * // f(x) = x ^ 3
15636 * const f = x => x.pow(tf.scalar(3, 'int32'));
15637 * // f'(x) = 3x ^ 2
15638 * const g = tf.grad(f);
15639 * // f''(x) = 6x
15640 * const gg = tf.grad(g);
15641 *
15642 * const x = tf.tensor1d([2, 3]);
15643 * gg(x).print();
15644 * ```
15645 *
15646 * @param f The function f(x), to compute gradient for.
15647 *
15648 * @doc {heading: 'Training', subheading: 'Gradients'}
15649 */
function grad(f) {
    assert(isFunction(f), () => 'The f passed in grad(f) must be a function');
    // Returns g(x, dy?): the gradient of f at x, optionally weighted by dy.
    return (x, dy) => {
        // x can be of any dtype, thus 'string_or_numeric' here.
        const xTensor = convertToTensor(x, 'x', 'tf.grad', 'string_or_numeric');
        const dyTensor = dy == null ? null : convertToTensor(dy, 'dy', 'tf.grad');
        return ENGINE.tidy(() => {
            const { value, grads } = ENGINE.gradients(() => f(xTensor), [xTensor], dyTensor);
            if (dyTensor != null) {
                // dy must have the same shape as f's output for the weighted sum.
                assertShapesMatch(value.shape, dyTensor.shape, 'The shape of dy passed in grad(f)(x, dy) must match the shape ' +
                    'returned by f(x)');
            }
            checkGrads(grads);
            // Single input => single gradient.
            return grads[0];
        });
    };
}
15667 /**
15668 * Provided `f(x1, x2,...)`, returns another function `g([x1, x2,...], dy?)`,
15669 * which gives an array of gradients of `f()` with respect to each input
15670 * [`x1`,`x2`,...].
15671 *
15672 * If `dy` is passed when calling `g()`, the gradient of
15673 * `f(x1,...).mul(dy).sum()` with respect to each input is computed instead.
15674 * The provided `f` must take one or more tensors and return a single tensor
15675 * `y`. If `f()` takes a single input, we recommend using `tf.grad` instead.
15676 *
15677 * ```js
15678 * // f(a, b) = a * b
15679 * const f = (a, b) => a.mul(b);
15680 * // df / da = b, df / db = a
15681 * const g = tf.grads(f);
15682 *
15683 * const a = tf.tensor1d([2, 3]);
15684 * const b = tf.tensor1d([-2, -3]);
15685 * const [da, db] = g([a, b]);
15686 * console.log('da');
15687 * da.print();
15688 * console.log('db');
15689 * db.print();
15690 * ```
15691 *
15692 * @param f The function `f(x1, x2,...)` to compute gradients for.
15693 *
15694 * @doc {heading: 'Training', subheading: 'Gradients'}
15695 */
function grads(f) {
    assert(isFunction(f), () => 'The f passed in grads(f) must be a function');
    // Returns g(args, dy?): one gradient per input tensor in args.
    return (args, dy) => {
        assert(Array.isArray(args), () => 'The args passed in grads(f)(args) must be an array ' +
            'of `Tensor`s or `TensorLike`s');
        // args can be of any dtype, thus 'string_or_numeric' here.
        const argTensors = convertToTensorArray(args, 'args', 'tf.grads', 'string_or_numeric');
        const dyTensor = dy == null ? null : convertToTensor(dy, 'dy', 'tf.grads');
        return ENGINE.tidy(() => {
            const { value, grads } = ENGINE.gradients(() => f(...argTensors), argTensors, dyTensor);
            if (dyTensor != null) {
                // dy must match the shape of f's output for the weighted sum.
                assertShapesMatch(value.shape, dyTensor.shape, 'The shape of dy passed in grads(f)([x1,...], dy) must ' +
                    'match the shape returned by f([x1,...])');
            }
            checkGrads(grads);
            return grads;
        });
    };
}
15715 /**
15716 * Like `tf.grad`, but also returns the value of `f()`. Useful when `f()`
15717 * returns a metric you want to show.
15718 *
15719 * The result is a rich object with the following properties:
15720 * - grad: The gradient of `f(x)` w.r.t `x` (result of `tf.grad`).
15721 * - value: The value returned by `f(x)`.
15722 *
15723 * ```js
15724 * // f(x) = x ^ 2
15725 * const f = x => x.square();
15726 * // f'(x) = 2x
15727 * const g = tf.valueAndGrad(f);
15728 *
15729 * const x = tf.tensor1d([2, 3]);
15730 * const {value, grad} = g(x);
15731 *
15732 * console.log('value');
15733 * value.print();
15734 * console.log('grad');
15735 * grad.print();
15736 * ```
15737 *
15738 * @doc {heading: 'Training', subheading: 'Gradients'}
15739 */
function valueAndGrad(f) {
    assert(isFunction(f), () => 'The f passed in valueAndGrad(f) must be a function');
    // Like grad(f), but also surfaces the value f produced.
    return (x, dy) => {
        // Unlike grad(), inputs must already be tensors here.
        assert(x instanceof Tensor, () => 'The x passed in valueAndGrad(f)(x) must be a tensor');
        assert(dy == null || dy instanceof Tensor, () => 'The dy passed in valueAndGrad(f)(x, dy) must be a tensor');
        const result = ENGINE.gradients(() => f(x), [x], dy);
        checkGrads(result.grads);
        return { grad: result.grads[0], value: result.value };
    };
}
15750 /**
15751 * Like `tf.grads`, but returns also the value of `f()`. Useful when `f()`
15752 * returns a metric you want to show.
15753 *
15754 * The result is a rich object with the following properties:
15755 * - grads: The gradients of `f()` w.r.t each input (result of `tf.grads`).
15756 * - value: The value returned by `f(x)`.
15757 *
15758 * ```js
15759 * // f(a, b) = a * b
15760 * const f = (a, b) => a.mul(b);
15761 * // df/da = b, df/db = a
15762 * const g = tf.valueAndGrads(f);
15763 *
15764 * const a = tf.tensor1d([2, 3]);
15765 * const b = tf.tensor1d([-2, -3]);
15766 * const {value, grads} = g([a, b]);
15767 *
15768 * const [da, db] = grads;
15769 *
15770 * console.log('value');
15771 * value.print();
15772 *
15773 * console.log('da');
15774 * da.print();
15775 * console.log('db');
15776 * db.print();
15777 * ```
15778 *
15779 * @doc {heading: 'Training', subheading: 'Gradients'}
15780 */
function valueAndGrads(f) {
    assert(isFunction(f), () => 'The f passed in valueAndGrads(f) must be a function');
    // Like grads(f), but also surfaces the value f produced.
    return (args, dy) => {
        // Unlike grads(), every element of args must already be a tensor.
        assert(Array.isArray(args) && args.every(arg => arg instanceof Tensor), () => 'The args passed in valueAndGrads(f)(args) must be array of ' +
            'tensors');
        assert(dy == null || dy instanceof Tensor, () => 'The dy passed in valueAndGrads(f)(args, dy) must be a tensor');
        const res = ENGINE.gradients(() => f(...args), args, dy);
        if (dy != null) {
            // dy must match the shape of f's output for the weighted sum.
            assertShapesMatch(res.value.shape, dy.shape, 'The shape of dy passed in valueAndGrads(f)([x1,...], dy) must ' +
                'match the shape returned by f([x1,...])');
        }
        checkGrads(res.grads);
        return res;
    };
}
15796 /**
15797 * Computes and returns the gradient of f(x) with respect to the list of
15798 * trainable variables provided by `varList`. If no list is provided, it
15799 * defaults to all trainable variables.
15800 *
15801 * ```js
15802 * const a = tf.variable(tf.tensor1d([3, 4]));
15803 * const b = tf.variable(tf.tensor1d([5, 6]));
15804 * const x = tf.tensor1d([1, 2]);
15805 *
15806 * // f(a, b) = a * x ^ 2 + b * x
15807 * const f = () => a.mul(x.square()).add(b.mul(x)).sum();
15808 * // df/da = x ^ 2, df/db = x
15809 * const {value, grads} = tf.variableGrads(f);
15810 *
15811 * Object.keys(grads).forEach(varName => grads[varName].print());
15812 * ```
15813 *
15814 * @param f The function to execute. f() should return a scalar.
15815 * @param varList The list of variables to compute the gradients with respect
15816 * to. Defaults to all trainable variables.
15817 * @returns An object with the following keys and values:
15818 * - `value`: The value of the function `f`.
15819 * - `grads`: A map from the names of the variables to the gradients.
15820 * If the `varList` argument is provided explicitly and contains a subset of
15821 * non-trainable variables, this map in the return value will contain keys
15822 * that map the names of the non-trainable variables to `null`.
15823 *
15824 * @doc {heading: 'Training', subheading: 'Gradients'}
15825 */
15826 function variableGrads(f, varList) {
15827 assert(isFunction(f), () => 'The f passed in variableGrads(f) must be a function');
15828 assert(varList == null ||
15829 Array.isArray(varList) && varList.every(v => v instanceof Variable), () => 'The varList passed in variableGrads(f, varList) must be an array ' +
15830 'of variables');
15831 const specifiedVarList = varList != null;
15832 if (!specifiedVarList) {
15833 // Get all of the trainable variables.
15834 varList = [];
15835 for (const varName in ENGINE.registeredVariables) {
15836 varList.push(ENGINE.registeredVariables[varName]);
15837 }
15838 }
15839 const specifiedNonTrainable = specifiedVarList ? varList.filter(variable => !variable.trainable) : null;
15840 // Prune non-trainable variables.
15841 const originalVarCount = varList.length;
15842 varList = varList.filter(variable => variable.trainable);
15843 assert(varList.length > 0, () => `variableGrads() expects at least one of the input variables to ` +
15844 `be trainable, but none of the ${originalVarCount} variables is ` +
15845 `trainable.`);
15846 const allowNoGradients = true;
15847 const { value, grads } = ENGINE.gradients(f, varList, null, allowNoGradients);
15848 assert(grads.some(g => g != null), () => 'Cannot find a connection between any variable and the result of ' +
15849 'the loss function y=f(x). Please make sure the operations that ' +
15850 'use variables are inside the function f passed to minimize().');
15851 assert(value.rank === 0, () => `The f passed in variableGrads(f) must return a scalar, but it ` +
15852 `returned a rank-${value.rank} tensor`);
15853 const namedGrads = {};
15854 varList.forEach((v, i) => {
15855 if (grads[i] != null) {
15856 namedGrads[v.name] = grads[i];
15857 }
15858 });
15859 if (specifiedNonTrainable != null) {
15860 // If varList is explicitly provided and contains non-trainable values,
15861 // add them to the returned gradients with `null` values.
15862 specifiedNonTrainable.forEach(v => namedGrads[v.name] = null);
15863 }
15864 return { value, grads: namedGrads };
15865 }
15866 /**
15867 * Overrides the gradient computation of a function `f`.
15868 *
15869 * Takes a function
15870 * `f(...inputs, save) => {value: Tensor, gradFunc: (dy, saved) => Tensor[]}`
15871 * and returns another function `g(...inputs)` which takes the same inputs as
15872 * `f`. When called, `g` returns `f().value`. In backward mode, custom gradients
15873 * with respect to each input of `f` are computed using `f().gradFunc`.
15874 *
 * The `save` function passed to `f` should be used for saving tensors needed
 * in the gradient. And the `saved` passed to the `gradFunc` is a
 * `NamedTensorMap`, which contains those saved tensors.
15878 *
15879 * ```js
15880 * const customOp = tf.customGrad((x, save) => {
15881 * // Save x to make sure it's available later for the gradient.
15882 * save([x]);
15883 * // Override gradient of our custom x ^ 2 op to be dy * abs(x);
15884 * return {
15885 * value: x.square(),
15886 * // Note `saved.x` which points to the `x` we saved earlier.
15887 * gradFunc: (dy, saved) => [dy.mul(saved[0].abs())]
15888 * };
15889 * });
15890 *
15891 * const x = tf.tensor1d([-1, -2, 3]);
15892 * const dx = tf.grad(x => customOp(x));
15893 *
15894 * console.log(`f(x):`);
15895 * customOp(x).print();
15896 * console.log(`f'(x):`);
15897 * dx(x).print();
15898 * ```
15899 *
15900 * @param f The function to evaluate in forward mode, which should return
15901 * `{value: Tensor, gradFunc: (dy, saved) => Tensor[]}`, where `gradFunc`
15902 * returns the custom gradients of `f` with respect to its inputs.
15903 *
15904 * @doc {heading: 'Training', subheading: 'Gradients'}
15905 */
15906 function customGrad(f) {
15907 return ENGINE.customGrad(f);
15908 }
15909 function checkGrads(grads) {
15910 const numNullGradients = grads.filter(g => g == null).length;
15911 if (numNullGradients > 0) {
15912 throw new Error(`Cannot compute gradient of y=f(x) with respect to x. Make sure that
15913 the f you passed encloses all operations that lead from x to y.`);
15914 }
15915 }
15916
15917 /**
15918 * @license
15919 * Copyright 2018 Google LLC. All Rights Reserved.
15920 * Licensed under the Apache License, Version 2.0 (the "License");
15921 * you may not use this file except in compliance with the License.
15922 * You may obtain a copy of the License at
15923 *
15924 * http://www.apache.org/licenses/LICENSE-2.0
15925 *
15926 * Unless required by applicable law or agreed to in writing, software
15927 * distributed under the License is distributed on an "AS IS" BASIS,
15928 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15929 * See the License for the specific language governing permissions and
15930 * limitations under the License.
15931 * =============================================================================
15932 */
15933 /**
15934 * Computes softplus of the input `tf.Tensor` element-wise: `log(exp(x) + 1)`
15935 *
15936 * ```js
15937 * const x = tf.tensor1d([0, 1, -1, .7]);
15938 *
15939 * x.softplus().print(); // or tf.softplus(x)
15940 * ```
15941 * @param x The input tensor.
15942 *
15943 * @doc {heading: 'Operations', subheading: 'Basic math'}
15944 */
15945 function softplus_(x) {
15946 const $x = convertToTensor(x, 'x', 'softplus');
15947 const inputs = { x: $x };
15948 return ENGINE.runKernel(Softplus, inputs);
15949 }
15950 const softplus = op({ softplus_ });
15951
15952 /**
15953 * @license
15954 * Copyright 2018 Google LLC. All Rights Reserved.
15955 * Licensed under the Apache License, Version 2.0 (the "License");
15956 * you may not use this file except in compliance with the License.
15957 * You may obtain a copy of the License at
15958 *
15959 * http://www.apache.org/licenses/LICENSE-2.0
15960 *
15961 * Unless required by applicable law or agreed to in writing, software
15962 * distributed under the License is distributed on an "AS IS" BASIS,
15963 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15964 * See the License for the specific language governing permissions and
15965 * limitations under the License.
15966 * =============================================================================
15967 */
15968 /**
15969 * Computes log sigmoid of the input `tf.Tensor` element-wise:
15970 * `logSigmoid(x)`. For numerical stability, we use `-tf.softplus(-x)`.
15971 *
15972 * ```js
15973 * const x = tf.tensor1d([0, 1, -1, .7]);
15974 *
15975 * x.logSigmoid().print(); // or tf.logSigmoid(x)
15976 * ```
15977 * @param x The input tensor.
15978 *
15979 * @doc {heading: 'Operations', subheading: 'Basic math'}
15980 */
15981 function logSigmoid_(x) {
15982 const $x = convertToTensor(x, 'x', 'logSigmoid');
15983 // Use a custom gradient to maintain previous implementation.
15984 // There is no LogSigmoid kernel in TF so we can't use engine.runKernel
15985 // directly
15986 const customOp = customGrad((x) => {
15987 // TODO(yassogba) we can remove the chained softplus call here only
15988 // after backends have modualrized softplus at which point we can call
15989 // engine runKernel(..., Sotfplus, ...) directly.
15990 const value = neg(softplus(neg(x)));
15991 const gradFunc = (dy) => {
15992 const derX = mul(dy, sigmoid(neg(x)));
15993 return derX;
15994 };
15995 return { value, gradFunc };
15996 });
15997 return customOp($x);
15998 }
15999 const logSigmoid = op({ logSigmoid_ });
16000
16001 /**
16002 * @license
16003 * Copyright 2020 Google LLC. All Rights Reserved.
16004 * Licensed under the Apache License, Version 2.0 (the "License");
16005 * you may not use this file except in compliance with the License.
16006 * You may obtain a copy of the License at
16007 *
16008 * http://www.apache.org/licenses/LICENSE-2.0
16009 *
16010 * Unless required by applicable law or agreed to in writing, software
16011 * distributed under the License is distributed on an "AS IS" BASIS,
16012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16013 * See the License for the specific language governing permissions and
16014 * limitations under the License.
16015 * =============================================================================
16016 */
16017 /**
16018 * Subtracts two `tf.Tensor`s element-wise, A - B. Supports broadcasting.
16019 *
16020 * ```js
16021 * const a = tf.tensor1d([10, 20, 30, 40]);
16022 * const b = tf.tensor1d([1, 2, 3, 4]);
16023 *
16024 * a.sub(b).print(); // or tf.sub(a, b)
16025 * ```
16026 *
16027 * ```js
16028 * // Broadcast subtract a with b.
16029 * const a = tf.tensor1d([10, 20, 30, 40]);
16030 * const b = tf.scalar(5);
16031 *
16032 * a.sub(b).print(); // or tf.sub(a, b)
16033 * ```
16034 * @param a The first `tf.Tensor` to subtract from.
16035 * @param b The second `tf.Tensor` to be subtracted. Must have the same dtype as
16036 * `a`.
16037 *
16038 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
16039 */
16040 function sub_(a, b) {
16041 let $a = convertToTensor(a, 'a', 'sub');
16042 let $b = convertToTensor(b, 'b', 'sub');
16043 [$a, $b] = makeTypesMatch($a, $b);
16044 const inputs = { a: $a, b: $b };
16045 return ENGINE.runKernel(Sub, inputs);
16046 }
16047 const sub = op({ sub_ });
16048
16049 /**
16050 * @license
16051 * Copyright 2020 Google Inc. All Rights Reserved.
16052 * Licensed under the Apache License, Version 2.0 (the "License");
16053 * you may not use this file except in compliance with the License.
16054 * You may obtain a copy of the License at
16055 *
16056 * http://www.apache.org/licenses/LICENSE-2.0
16057 *
16058 * Unless required by applicable law or agreed to in writing, software
16059 * distributed under the License is distributed on an "AS IS" BASIS,
16060 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16061 * See the License for the specific language governing permissions and
16062 * limitations under the License.
16063 * =============================================================================
16064 */
16065 /**
16066 * Computes the log softmax.
16067 *
16068 * ```js
16069 * const a = tf.tensor1d([1, 2, 3]);
16070 *
16071 * a.logSoftmax().print(); // or tf.logSoftmax(a)
16072 * ```
16073 *
16074 * ```js
16075 * const a = tf.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]);
16076 *
16077 * a.logSoftmax().print(); // or tf.logSoftmax(a)
16078 * ```
16079 *
16080 * @param logits The logits array.
16081 * @param axis The dimension softmax would be performed on. Defaults to `-1`
16082 * which indicates the last dimension.
16083 *
16084 * @doc {heading: 'Operations', subheading: 'Normalization'}
16085 */
16086 function logSoftmax_(logits, axis = -1) {
16087 const $logits = convertToTensor(logits, 'logits', 'logSoftmax');
16088 if (axis === -1) {
16089 axis = $logits.rank - 1;
16090 }
16091 if (axis !== $logits.rank - 1) {
16092 throw Error('Log Softmax along a non-last dimension is not yet supported. ' +
16093 `Logits was rank ${$logits.rank} and axis was ${axis}`);
16094 }
16095 // const forward: ForwardFunc<Tensor> = (backend, save) => {
16096 // const keepDims = true;
16097 // const xMax = max(logits, axis, true);
16098 // const shifted = sub(logits, xMax);
16099 // const value =
16100 // sub(cast(shifted, 'float32'), log(sum(exp(shifted), axis,
16101 // keepDims)));
16102 // save([value]);
16103 // return value;
16104 // };
16105 // Use a custom gradient for numerical stability.
16106 const customOp = customGrad((logits, save) => {
16107 const keepDims = true;
16108 const xMax = max(logits, axis, true);
16109 const shifted = sub(logits, xMax);
16110 const value = sub(cast(shifted, 'float32'), log$1(sum$1(exp(shifted), axis, keepDims)));
16111 save([value]);
16112 const gradFunc = (dy, saved) => {
16113 const [value] = saved;
16114 const keepDims = true;
16115 const softmax = exp(value);
16116 return sub(dy, mul(sum$1(dy, axis, keepDims), softmax));
16117 };
16118 return { value, gradFunc };
16119 });
16120 return customOp($logits);
16121 // TODO Use Engine.runKernel when CPU/WebGL/WASM backends implement this.
16122 // const inputs: LogSoftmaxInputs = {logits: $logits};
16123 // const attrs: LogSoftmaxAttrs = {axis};
16124 // return ENGINE.runKernel(
16125 // LogSoftmax, inputs as {} as NamedTensorMap,
16126 // attrs as {} as NamedAttrMap);
16127 }
16128 const logSoftmax = op({ logSoftmax_ });
16129
16130 /**
16131 * @license
16132 * Copyright 2020 Google LLC. All Rights Reserved.
16133 * Licensed under the Apache License, Version 2.0 (the "License");
16134 * you may not use this file except in compliance with the License.
16135 * You may obtain a copy of the License at
16136 *
16137 * http://www.apache.org/licenses/LICENSE-2.0
16138 *
16139 * Unless required by applicable law or agreed to in writing, software
16140 * distributed under the License is distributed on an "AS IS" BASIS,
16141 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16142 * See the License for the specific language governing permissions and
16143 * limitations under the License.
16144 * =============================================================================
16145 */
16146 /**
16147 * Computes the log(sum(exp(elements across the reduction dimensions)).
16148 *
16149 * Reduces the input along the dimensions given in `axis`. Unless `keepDims`
16150 * is true, the rank of the array is reduced by 1 for each entry in `axis`.
16151 * If `keepDims` is true, the reduced dimensions are retained with length 1.
16152 * If `axis` has no entries, all dimensions are reduced, and an array with a
16153 * single element is returned.
16154 *
16155 * ```js
16156 * const x = tf.tensor1d([1, 2, 3]);
16157 *
16158 * x.logSumExp().print(); // or tf.logSumExp(x)
16159 * ```
16160 *
16161 * ```js
16162 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
16163 *
16164 * const axis = 1;
16165 * x.logSumExp(axis).print(); // or tf.logSumExp(a, axis)
16166 * ```
16167 * @param x The input tensor.
16168 * @param axis The dimension(s) to reduce. If null (the default),
16169 * reduces all dimensions.
16170 * @param keepDims If true, retains reduced dimensions with length
16171 * of 1. Defaults to false.
16172 *
16173 * @doc {heading: 'Operations', subheading: 'Reduction'}
16174 */
16175 function logSumExp_(x, axis = null, keepDims = false) {
16176 const $x = convertToTensor(x, 'x', 'logSumExp');
16177 const axes = parseAxisParam(axis, $x.shape);
16178 const xMax = max($x, axes, true /* keepDims */);
16179 const a = sub($x, xMax);
16180 const b = exp(a);
16181 const c = sum$1(b, axes);
16182 const d = log$1(c);
16183 const res = add$1(reshape(xMax, d.shape), d);
16184 if (keepDims) {
16185 const newShape = expandShapeToKeepDim(res.shape, axes);
16186 return reshape(res, newShape);
16187 }
16188 return res;
16189 }
16190 const logSumExp = op({ logSumExp_ });
16191
16192 /**
16193 * @license
16194 * Copyright 2020 Google LLC. All Rights Reserved.
16195 * Licensed under the Apache License, Version 2.0 (the "License");
16196 * you may not use this file except in compliance with the License.
16197 * You may obtain a copy of the License at
16198 *
16199 * http://www.apache.org/licenses/LICENSE-2.0
16200 *
16201 * Unless required by applicable law or agreed to in writing, software
16202 * distributed under the License is distributed on an "AS IS" BASIS,
16203 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16204 * See the License for the specific language governing permissions and
16205 * limitations under the License.
16206 * =============================================================================
16207 */
16208 /**
16209 * Returns the truth value of `a AND b` element-wise. Supports broadcasting.
16210 *
16211 * ```js
16212 * const a = tf.tensor1d([false, false, true, true], 'bool');
16213 * const b = tf.tensor1d([false, true, false, true], 'bool');
16214 *
16215 * a.logicalAnd(b).print();
16216 * ```
16217 *
16218 * @param a The first input tensor. Must be of dtype bool.
16219 * @param b The second input tensor. Must be of dtype bool.
16220 *
16221 * @doc {heading: 'Operations', subheading: 'Logical'}
16222 */
16223 function logicalAnd_(a, b) {
16224 const $a = convertToTensor(a, 'a', 'logicalAnd', 'bool');
16225 const $b = convertToTensor(b, 'b', 'logicalAnd', 'bool');
16226 assertAndGetBroadcastShape($a.shape, $b.shape);
16227 const inputs = { a: $a, b: $b };
16228 return ENGINE.runKernel(LogicalAnd, inputs);
16229 }
16230 const logicalAnd = op({ logicalAnd_ });
16231
16232 /**
16233 * @license
16234 * Copyright 2020 Google LLC. All Rights Reserved.
16235 * Licensed under the Apache License, Version 2.0 (the "License");
16236 * you may not use this file except in compliance with the License.
16237 * You may obtain a copy of the License at
16238 *
16239 * http://www.apache.org/licenses/LICENSE-2.0
16240 *
16241 * Unless required by applicable law or agreed to in writing, software
16242 * distributed under the License is distributed on an "AS IS" BASIS,
16243 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16244 * See the License for the specific language governing permissions and
16245 * limitations under the License.
16246 * =============================================================================
16247 */
16248 /**
16249 * Returns the truth value of `NOT x` element-wise.
16250 *
16251 * ```js
16252 * const a = tf.tensor1d([false, true], 'bool');
16253 *
16254 * a.logicalNot().print();
16255 * ```
16256 *
16257 * @param x The input tensor. Must be of dtype 'bool'.
16258 *
16259 * @doc {heading: 'Operations', subheading: 'Logical'}
16260 */
16261 function logicalNot_(x) {
16262 const $x = convertToTensor(x, 'x', 'logicalNot', 'bool');
16263 const inputs = { x: $x };
16264 return ENGINE.runKernel(LogicalNot, inputs);
16265 }
16266 const logicalNot = op({ logicalNot_ });
16267
16268 /**
16269 * @license
16270 * Copyright 2020 Google LLC. All Rights Reserved.
16271 * Licensed under the Apache License, Version 2.0 (the "License");
16272 * you may not use this file except in compliance with the License.
16273 * You may obtain a copy of the License at
16274 *
16275 * http://www.apache.org/licenses/LICENSE-2.0
16276 *
16277 * Unless required by applicable law or agreed to in writing, software
16278 * distributed under the License is distributed on an "AS IS" BASIS,
16279 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16280 * See the License for the specific language governing permissions and
16281 * limitations under the License.
16282 * =============================================================================
16283 */
16284 /**
16285 * Returns the truth value of `a OR b` element-wise. Supports broadcasting.
16286 *
16287 * ```js
16288 * const a = tf.tensor1d([false, false, true, true], 'bool');
16289 * const b = tf.tensor1d([false, true, false, true], 'bool');
16290 *
16291 * a.logicalOr(b).print();
16292 * ```
16293 * @param a The first input tensor. Must be of dtype bool.
16294 * @param b The second input tensor. Must be of dtype bool.
16295 *
16296 * @doc {heading: 'Operations', subheading: 'Logical'}
16297 */
16298 function logicalOr_(a, b) {
16299 const $a = convertToTensor(a, 'a', 'logicalOr', 'bool');
16300 const $b = convertToTensor(b, 'b', 'logicalOr', 'bool');
16301 assertAndGetBroadcastShape($a.shape, $b.shape);
16302 const inputs = { a: $a, b: $b };
16303 return ENGINE.runKernel(LogicalOr, inputs);
16304 }
16305 const logicalOr = op({ logicalOr_ });
16306
16307 /**
16308 * @license
16309 * Copyright 2020 Google LLC. All Rights Reserved.
16310 * Licensed under the Apache License, Version 2.0 (the "License");
16311 * you may not use this file except in compliance with the License.
16312 * You may obtain a copy of the License at
16313 *
16314 * http://www.apache.org/licenses/LICENSE-2.0
16315 *
16316 * Unless required by applicable law or agreed to in writing, software
16317 * distributed under the License is distributed on an "AS IS" BASIS,
16318 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16319 * See the License for the specific language governing permissions and
16320 * limitations under the License.
16321 * =============================================================================
16322 */
16323 /**
16324 * Returns the truth value of `a XOR b` element-wise. Supports broadcasting.
16325 *
16326 * ```js
16327 * const a = tf.tensor1d([false, false, true, true], 'bool');
16328 * const b = tf.tensor1d([false, true, false, true], 'bool');
16329 *
16330 * a.logicalXor(b).print();
16331 * ```
16332 *
16333 * @param a The first input tensor. Must be of dtype bool.
16334 * @param b The second input tensor. Must be of dtype bool.
16335 *
16336 * @doc {heading: 'Operations', subheading: 'Logical'}
16337 */
16338 function logicalXor_(a, b) {
16339 const $a = convertToTensor(a, 'a', 'logicalXor', 'bool');
16340 const $b = convertToTensor(b, 'b', 'logicalXor', 'bool');
16341 assertAndGetBroadcastShape($a.shape, $b.shape);
16342 // x ^ y = (x | y) & ~(x & y)
16343 return logicalAnd(logicalOr(a, b), logicalNot(logicalAnd(a, b)));
16344 }
16345 const logicalXor = op({ logicalXor_ });
16346
16347 /**
16348 * @license
16349 * Copyright 2022 Google LLC. All Rights Reserved.
16350 * Licensed under the Apache License, Version 2.0 (the "License");
16351 * you may not use this file except in compliance with the License.
16352 * You may obtain a copy of the License at
16353 *
16354 * http://www.apache.org/licenses/LICENSE-2.0
16355 *
16356 * Unless required by applicable law or agreed to in writing, software
16357 * distributed under the License is distributed on an "AS IS" BASIS,
16358 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16359 * See the License for the specific language governing permissions and
16360 * limitations under the License.
16361 * =============================================================================
16362 */
    // NOTE(review): despite the name, this value is 2^31 — one more than the
    // true int32 max (2147483647). The `>= INT32_MAX` bounds checks in
    // searchSorted therefore accept sizes up to 2^31 - 1; confirm the
    // off-by-one is intentional before renaming or changing the value.
    const INT32_MAX = 2147483648;
16364 /**
16365 * Searches for where a value would go in a sorted sequence.
16366 *
16367 * This is not a method for checking containment (like javascript in).
16368 *
16369 * The typical use case for this operation is "binning", "bucketing", or
16370 * "discretizing". The values are assigned to bucket-indices based on the edges
16371 * listed in 'sortedSequence'. This operation returns the bucket-index for each
16372 * value.
16373 *
16374 * The side argument controls which index is returned if a value lands exactly
16375 * on an edge.
16376 *
16377 * The axis is not settable for this operation. It always operates on the
16378 * innermost dimension (axis=-1). The operation will accept any number of outer
16379 * dimensions.
16380 *
16381 * Note: This operation assumes that 'sortedSequence' is sorted along the
16382 * innermost axis, maybe using 'sort(..., axis=-1)'. If the sequence is not
16383 * sorted no error is raised and the content of the returned tensor is not well
16384 * defined.
16385 *
16386 * ```js
16387 * const edges = tf.tensor1d([-1, 3.3, 9.1, 10.0]);
16388 * let values = tf.tensor1d([0.0, 4.1, 12.0]);
16389 * const result1 = tf.searchSorted(edges, values, 'left');
16390 * result1.print(); // [1, 2, 4]
16391 *
16392 * const seq = tf.tensor1d([0, 3, 9, 10, 10]);
16393 * values = tf.tensor1d([0, 4, 10]);
16394 * const result2 = tf.searchSorted(seq, values, 'left');
16395 * result2.print(); // [0, 2, 3]
16396 * const result3 = tf.searchSorted(seq, values, 'right');
16397 * result3.print(); // [1, 2, 5]
16398 *
16399 * const sortedSequence = tf.tensor2d([[0., 3., 8., 9., 10.],
16400 * [1., 2., 3., 4., 5.]]);
16401 * values = tf.tensor2d([[9.8, 2.1, 4.3],
16402 * [0.1, 6.6, 4.5, ]]);
16403 * const result4 = tf.searchSorted(sortedSequence, values, 'left');
16404 * result4.print(); // [[4, 1, 2], [0, 5, 4]]
16405 * ```
16406 * @param sortedSequence: N-D. Sorted sequence.
16407 * @param values: N-D. Search values.
16408 * @param side: 'left'|'right'. Defaults to 'left'. 'left' corresponds to lower
16409 * bound and 'right' to upper bound.
16410 * @return An N-D int32 tensor the size of values containing the result of
16411 * applying either lower bound or upper bound (depending on side) to each
16412 * value. The result is not a global index to the entire Tensor, but the
16413 * index in the last dimension.
16414 * @doc {heading: 'Operations', subheading: 'Evaluation'}
16415 */
16416 function searchSorted_(sortedSequence, values, side = 'left') {
16417 const $sortedSequence = convertToTensor(sortedSequence, 'sortedSequence', 'searchSorted');
16418 const $values = convertToTensor(values, 'values', 'searchSorted');
16419 const sequenceSize = $sortedSequence.shape[$sortedSequence.shape.length - 1];
16420 const valuesSize = $values.shape[$values.shape.length - 1];
16421 const $sortedSequence2D = reshape($sortedSequence, [-1, sequenceSize]);
16422 const $values2D = reshape($values, [-1, valuesSize]);
16423 if ($sortedSequence2D.rank < 2) {
16424 throw new Error(`Sorted input argument must be at least 2-dimensional`);
16425 }
16426 if ($sortedSequence2D.shape[0] !== $values2D.shape[0]) {
16427 throw new Error(`Leading dimension of 'sortedSequence' and 'values' must match.`);
16428 }
16429 if (sizeFromShape($values2D.shape) >= INT32_MAX) {
16430 throw new Error(`values tensor size must less than ${INT32_MAX}`);
16431 }
16432 if ($sortedSequence2D.shape[1] >= INT32_MAX) {
16433 throw new Error(`trailing dim_size must less than ${INT32_MAX} for int32 output type, was ${$sortedSequence2D.shape[1]}`);
16434 }
16435 const inputs = {
16436 sortedSequence: $sortedSequence2D,
16437 values: $values2D,
16438 };
16439 const attrs = { side };
16440 return ENGINE.runKernel(SearchSorted, inputs, attrs);
16441 }
16442 const searchSorted = op({ searchSorted_ });
16443
16444 /**
16445 * @license
16446 * Copyright 2022 Google LLC. All Rights Reserved.
16447 * Licensed under the Apache License, Version 2.0 (the "License");
16448 * you may not use this file except in compliance with the License.
16449 * You may obtain a copy of the License at
16450 *
16451 * http://www.apache.org/licenses/LICENSE-2.0
16452 *
16453 * Unless required by applicable law or agreed to in writing, software
16454 * distributed under the License is distributed on an "AS IS" BASIS,
16455 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16456 * See the License for the specific language governing permissions and
16457 * limitations under the License.
16458 * =============================================================================
16459 */
16460 /**
16461 * Searches for where a value would go in a sorted sequence.
16462 *
16463 * This is not a method for checking containment (like javascript in).
16464 *
16465 * The typical use case for this operation is "binning", "bucketing", or
16466 * "discretizing". The values are assigned to bucket-indices based on the edges
16467 * listed in 'sortedSequence'. This operation returns the bucket-index for each
16468 * value.
16469 *
16470 * The index returned corresponds to the first edge greater than or equal to the
16471 * value.
16472 *
16473 * The axis is not settable for this operation. It always operates on the
16474 * innermost dimension (axis=-1). The operation will accept any number of outer
16475 * dimensions.
16476 *
16477 * Note: This operation assumes that 'lowerBound' is sorted along the
16478 * innermost axis, maybe using 'sort(..., axis=-1)'. If the sequence is not
16479 * sorted no error is raised and the content of the returned tensor is not well
16480 * defined.
16481 *
16482 * ```js
16483 * const edges = tf.tensor1d([-1, 3.3, 9.1, 10.0]);
16484 * let values = tf.tensor1d([0.0, 4.1, 12.0]);
16485 * const result1 = tf.lowerBound(edges, values);
16486 * result1.print(); // [1, 2, 4]
16487 *
16488 * const seq = tf.tensor1d([0, 3, 9, 10, 10]);
16489 * values = tf.tensor1d([0, 4, 10]);
16490 * const result2 = tf.lowerBound(seq, values);
16491 * result2.print(); // [0, 2, 3]
16492 *
16493 * const sortedSequence = tf.tensor2d([[0., 3., 8., 9., 10.],
16494 * [1., 2., 3., 4., 5.]]);
16495 * values = tf.tensor2d([[9.8, 2.1, 4.3],
16496 * [0.1, 6.6, 4.5, ]]);
16497 * const result3 = tf.lowerBound(sortedSequence, values);
16498 * result3.print(); // [[4, 1, 2], [0, 5, 4]]
16499 * ```
16500 * @param sortedSequence: N-D. Sorted sequence.
16501 * @param values: N-D. Search values.
16502 * @return An N-D int32 tensor the size of values containing the result of
16503 * applying lower bound to each value. The result is not a global index to
16504 * the entire Tensor, but the index in the last dimension.
16505 * @doc {heading: 'Operations', subheading: 'Evaluation'}
16506 */
16507 function lowerBound(sortedSequence, values) {
16508 return searchSorted(sortedSequence, values, 'left');
16509 }
16510
16511 /**
16512 * @license
16513 * Copyright 2020 Google LLC. All Rights Reserved.
16514 * Licensed under the Apache License, Version 2.0 (the "License");
16515 * you may not use this file except in compliance with the License.
16516 * You may obtain a copy of the License at
16517 *
16518 * http://www.apache.org/licenses/LICENSE-2.0
16519 *
16520 * Unless required by applicable law or agreed to in writing, software
16521 * distributed under the License is distributed on an "AS IS" BASIS,
16522 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16523 * See the License for the specific language governing permissions and
16524 * limitations under the License.
16525 * =============================================================================
16526 */
16527 /**
16528 * Computes the 2D max pooling of an image.
16529 *
16530 * @param x The input tensor, of rank 4 or rank 3 of shape
16531 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
16532 * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
16533 * `filterSize` is a single number, then `filterHeight == filterWidth`.
16534 * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
16535 * `strides` is a single number, then `strideHeight == strideWidth`.
16536 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
16537 * in which we sample input values across the height and width dimensions
16538 * in dilated pooling. Defaults to `[1, 1]`. If `dilations` is a single
16539 * number, then `dilationHeight == dilationWidth`. If it is greater than
16540 * 1, then all values of `strides` must be 1.
16541 * @param pad The type of padding algorithm.
16542 * - `same` and stride 1: output will be of same size as input,
16543 * regardless of filter size.
16544 * - `valid`: output will be smaller than input if filter is larger
16545 * than 1x1.
16546 * - For more info, see this guide:
16547 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
16548 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
16549 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
16550 * provided, it will default to truncate.
16551 */
function maxPool_(x, filterSize, strides, pad, dimRoundingMode) {
    const $x = convertToTensor(x, 'x', 'maxPool');
    // Dilation is not configurable for this op.
    const dilations = 1;
    const addedBatchDim = $x.rank === 3;
    // Promote a rank-3 input to rank 4 by prepending a batch of size 1.
    const x4D = addedBatchDim ?
        reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]) :
        $x;
    assert(x4D.rank === 4, () => `Error in maxPool: input must be rank 4 but got rank ${x4D.rank}.`);
    assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. ' +
        `Got strides ${strides} and dilations '${dilations}'`);
    checkPadOnDimRoundingMode('maxPool', pad, dimRoundingMode);
    const inputs = { x: x4D };
    const attrs = { filterSize, strides, pad, dimRoundingMode };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    const res = ENGINE.runKernel(MaxPool, inputs, attrs);
    // Strip the synthetic batch dimension if one was added above.
    return addedBatchDim ?
        reshape(res, [res.shape[1], res.shape[2], res.shape[3]]) :
        res;
}
const maxPool = op({ maxPool_ });
16575
16576 /**
16577 * @license
16578 * Copyright 2020 Google LLC. All Rights Reserved.
16579 * Licensed under the Apache License, Version 2.0 (the "License");
16580 * you may not use this file except in compliance with the License.
16581 * You may obtain a copy of the License at
16582 *
16583 * http://www.apache.org/licenses/LICENSE-2.0
16584 *
16585 * Unless required by applicable law or agreed to in writing, software
16586 * distributed under the License is distributed on an "AS IS" BASIS,
16587 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16588 * See the License for the specific language governing permissions and
16589 * limitations under the License.
16590 * =============================================================================
16591 */
16592 /**
16593 * Computes the 3D max pooling.
16594 *
16595 * ```js
16596 * const x = tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]);
16597 * const result = tf.maxPool3d(x, 2, 1, 'valid');
16598 * result.print();
16599 * ```
16600 *
16601 * @param x The input tensor, of rank 5 or rank 4 of shape
16602 * `[batch, depth, height, width, inChannels]`.
16603 * @param filterSize The filter size:
16604 * `[filterDepth, filterHeight, filterWidth]`.
16605 * If `filterSize` is a single number,
16606 * then `filterDepth == filterHeight == filterWidth`.
16607 * @param strides The strides of the pooling:
16608 * `[strideDepth, strideHeight, strideWidth]`.
16609 * If `strides` is a single number,
16610 * then `strideDepth == strideHeight == strideWidth`.
16611 * @param pad The type of padding algorithm.
16612 * - `same` and stride 1: output will be of same size as input,
16613 * regardless of filter size.
16614 * - `valid`: output will be smaller than input if filter is larger
 * than 1x1x1.
16616 * - For more info, see this guide:
16617 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
16618 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
16619 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
16620 * provided, it will default to truncate.
16621 * @param dataFormat An optional string from: "NDHWC", "NCDHW". Defaults to
16622 * "NDHWC". Specify the data format of the input and output data. With the
16623 * default format "NDHWC", the data is stored in the order of: [batch,
16624 * depth, height, width, channels]. Only "NDHWC" is currently supported.
16625 * @doc {heading: 'Operations', subheading: 'Convolution'}
16626 */
function maxPool3d_(x, filterSize = [1, 1, 1], strides, pad, dimRoundingMode, dataFormat = 'NDHWC') {
    const $x = convertToTensor(x, 'x', 'maxPool3d');
    const addedBatchDim = $x.rank === 4;
    // A rank-4 input is treated as a single-example batch.
    const x5D = addedBatchDim ?
        reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]]) :
        $x;
    assert(x5D.rank === 5, () => `Error in maxPool3d: x must be rank 5 but got rank ${x5D.rank}.`);
    assert(dataFormat === 'NDHWC', () => `Error in maxPool3d: Only NDHWC is currently supported, ` +
        `but got dataFormat of ${dataFormat}`);
    checkPadOnDimRoundingMode('maxPool3d', pad, dimRoundingMode);
    const inputs = { x: x5D };
    const attrs = { filterSize, strides, pad, dimRoundingMode, dataFormat };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    const res = ENGINE.runKernel(MaxPool3D, inputs, attrs);
    // Undo the batch promotion performed above, if any.
    return addedBatchDim ?
        reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]) :
        res;
}
const maxPool3d = op({ maxPool3d_ });
16649
16650 /**
16651 * @license
16652 * Copyright 2018 Google LLC. All Rights Reserved.
16653 * Licensed under the Apache License, Version 2.0 (the "License");
16654 * you may not use this file except in compliance with the License.
16655 * You may obtain a copy of the License at
16656 *
16657 * http://www.apache.org/licenses/LICENSE-2.0
16658 *
16659 * Unless required by applicable law or agreed to in writing, software
16660 * distributed under the License is distributed on an "AS IS" BASIS,
16661 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16662 * See the License for the specific language governing permissions and
16663 * limitations under the License.
16664 * =============================================================================
16665 */
16666 /**
16667 * Computes the 2D max pooling of an image with Argmax index.
16668 * The indices in argmax are flattened, so that a maximum value at position `[b,
16669 * y, x, c]` becomes flattened index: `(y * width + x) * channels + c` if
16670 * include_batch_in_index is False; `((b * height + y) * width + x) * channels
16671 * +c` if include_batch_in_index is True.
16672 *
16673 * The indices returned are always in `[0, height) x [0, width)` before
16674 * flattening.
16675 *
16676 * @param x The input tensor, of rank 4 or rank 3 of shape
16677 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
16678 * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
16679 * `filterSize` is a single number, then `filterHeight == filterWidth`.
16680 * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
16681 * `strides` is a single number, then `strideHeight == strideWidth`.
16682 * @param dataFormat An optional string from: "NDHWC", "NCDHW". Defaults to
16683 * "NDHWC". Specify the data format of the input and output data. With the
16684 * default format "NDHWC", the data is stored in the order of: [batch,
16685 * depth, height, width, channels]. Only "NDHWC" is currently supported.
16686 * @param pad The type of padding algorithm.
16687 * - `same` and stride 1: output will be of same size as input,
16688 * regardless of filter size.
16689 * - `valid`: output will be smaller than input if filter is larger
16690 * than 1x1.
16691 * - For more info, see this guide:
16692 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
16693 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 * @param includeBatchInIndex Defaults to False. Whether to include batch
 *     dimension in flattened index of argmax.
16696 *
16697 * @doc {heading: 'Operations', subheading: 'Convolution'}
16698 */
function maxPoolWithArgmax_(x, filterSize, strides, pad, includeBatchInIndex = false) {
    const $x = convertToTensor(x, 'x', 'maxPoolWithArgmax');
    const inputs = { x: $x };
    const attrs = { filterSize, strides, pad, includeBatchInIndex };
    // The kernel returns a pair: the pooled values and the flattened argmax
    // indices.
    // tslint:disable-next-line: no-unnecessary-type-assertion
    const [result, indexes] = ENGINE.runKernel(MaxPoolWithArgmax, inputs, attrs);
    return { result, indexes };
}
const maxPoolWithArgmax = op({ maxPoolWithArgmax_ });
16708
16709 /**
16710 * @license
16711 * Copyright 2020 Google LLC. All Rights Reserved.
16712 * Licensed under the Apache License, Version 2.0 (the "License");
16713 * you may not use this file except in compliance with the License.
16714 * You may obtain a copy of the License at
16715 *
16716 * http://www.apache.org/licenses/LICENSE-2.0
16717 *
16718 * Unless required by applicable law or agreed to in writing, software
16719 * distributed under the License is distributed on an "AS IS" BASIS,
16720 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16721 * See the License for the specific language governing permissions and
16722 * limitations under the License.
16723 * =============================================================================
16724 */
16725 /**
16726 * Returns the max of a and b (`a > b ? a : b`) element-wise.
16727 * Supports broadcasting.
16728 *
16729 * We also expose `tf.maximumStrict` which has the same signature as this op and
16730 * asserts that `a` and `b` are the same shape (does not broadcast).
16731 *
16732 * ```js
16733 * const a = tf.tensor1d([1, 4, 3, 16]);
16734 * const b = tf.tensor1d([1, 2, 9, 4]);
16735 *
16736 * a.maximum(b).print(); // or tf.maximum(a, b)
16737 * ```
16738 *
16739 * ```js
16740 * // Broadcast maximum a with b.
16741 * const a = tf.tensor1d([2, 4, 6, 8]);
16742 * const b = tf.scalar(5);
16743 *
16744 * a.maximum(b).print(); // or tf.maximum(a, b)
16745 * ```
16746 *
16747 * @param a The first tensor.
16748 * @param b The second tensor. Must have the same type as `a`.
16749 *
16750 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
16751 */
function maximum_(a, b) {
    let $a = convertToTensor(a, 'a', 'maximum');
    let $b = convertToTensor(b, 'b', 'maximum');
    [$a, $b] = makeTypesMatch($a, $b);
    // The kernel has no bool variant, so booleans are promoted to int32.
    if ($a.dtype === 'bool') {
        [$a, $b] = [cast($a, 'int32'), cast($b, 'int32')];
    }
    // Validates that the two shapes are broadcast-compatible.
    assertAndGetBroadcastShape($a.shape, $b.shape);
    return ENGINE.runKernel(Maximum, { a: $a, b: $b });
}
const maximum = op({ maximum_ });
16765
16766 /**
16767 * @license
16768 * Copyright 2020 Google Inc. All Rights Reserved.
16769 * Licensed under the Apache License, Version 2.0 (the "License");
16770 * you may not use this file except in compliance with the License.
16771 * You may obtain a copy of the License at
16772 *
16773 * http://www.apache.org/licenses/LICENSE-2.0
16774 *
16775 * Unless required by applicable law or agreed to in writing, software
16776 * distributed under the License is distributed on an "AS IS" BASIS,
16777 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16778 * See the License for the specific language governing permissions and
16779 * limitations under the License.
16780 * =============================================================================
16781 */
16782 /**
16783 * Computes the mean of elements across dimensions of a `tf.Tensor`.
16784 *
16785 * Reduces `x` along the dimensions given in `axis`. Unless `keepDims` is
16786 * true, the rank of the `tf.Tensor` is reduced by 1 for each entry in `axis`.
16787 * If `keepDims` is true, the reduced dimensions are retained with length 1.
16788 * If `axis` has no entries, all dimensions are reduced, and a `tf.Tensor` with
16789 * a single element is returned.
16790 *
16791 * ```js
16792 * const x = tf.tensor1d([1, 2, 3]);
16793 *
16794 * x.mean().print(); // or tf.mean(a)
16795 * ```
16796 *
16797 * ```js
16798 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
16799 *
16800 * const axis = 1;
16801 * x.mean(axis).print(); // or tf.mean(x, axis)
16802 * ```
16803 *
16804 * @param x The input tensor.
16805 * @param axis The dimension(s) to reduce. By default it reduces
16806 * all dimensions.
16807 * @param keepDims If true, retains reduced dimensions with size 1.
16808 *
16809 * @doc {heading: 'Operations', subheading: 'Reduction'}
16810 */
function mean_(x, axis = null, keepDims = false) {
    // Reduction over the given axes is delegated entirely to the Mean kernel.
    const inputs = { x: convertToTensor(x, 'x', 'mean') };
    return ENGINE.runKernel(Mean, inputs, { axis, keepDims });
}
const mean = op({ mean_ });
16818
16819 /**
16820 * @license
16821 * Copyright 2018 Google LLC. All Rights Reserved.
16822 * Licensed under the Apache License, Version 2.0 (the "License");
16823 * you may not use this file except in compliance with the License.
16824 * You may obtain a copy of the License at
16825 *
16826 * http://www.apache.org/licenses/LICENSE-2.0
16827 *
16828 * Unless required by applicable law or agreed to in writing, software
16829 * distributed under the License is distributed on an "AS IS" BASIS,
16830 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16831 * See the License for the specific language governing permissions and
16832 * limitations under the License.
16833 * =============================================================================
16834 */
16835 /**
16836 * Creates a `tf.Tensor` with all elements set to 0.
16837 *
16838 * ```js
16839 * tf.zeros([2, 2]).print();
16840 * ```
16841 *
16842 * @param shape An array of integers defining the output tensor shape.
16843 * @param dtype The type of an element in the resulting tensor. Can
16844 * be 'float32', 'int32' or 'bool'. Defaults to 'float'.
16845 *
16846 * @doc {heading: 'Tensors', subheading: 'Creation'}
16847 */
function zeros(shape, dtype = 'float32') {
    // complex64 is represented as a pair of float32 tensors (real, imag),
    // each of which is all zeros.
    if (dtype === 'complex64') {
        return complex(zeros(shape, 'float32'), zeros(shape, 'float32'));
    }
    const values = makeZerosTypedArray(sizeFromShape(shape), dtype);
    return ENGINE.makeTensor(values, shape, dtype);
}
16857
16858 /**
16859 * @license
16860 * Copyright 2018 Google LLC. All Rights Reserved.
16861 * Licensed under the Apache License, Version 2.0 (the "License");
16862 * you may not use this file except in compliance with the License.
16863 * You may obtain a copy of the License at
16864 *
16865 * http://www.apache.org/licenses/LICENSE-2.0
16866 *
16867 * Unless required by applicable law or agreed to in writing, software
16868 * distributed under the License is distributed on an "AS IS" BASIS,
16869 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16870 * See the License for the specific language governing permissions and
16871 * limitations under the License.
16872 * =============================================================================
16873 */
16874 /**
16875 * Creates a `tf.Tensor` with all elements set to 1.
16876 *
16877 * ```js
16878 * tf.ones([2, 2]).print();
16879 * ```
16880 *
16881 * @param shape An array of integers defining the output tensor shape.
16882 * @param dtype The type of an element in the resulting tensor. Defaults to
16883 * 'float'.
16884 *
16885 * @doc {heading: 'Tensors', subheading: 'Creation'}
16886 */
function ones$1(shape, dtype = 'float32') {
    // A complex "one" has real part 1 and imaginary part 0.
    if (dtype === 'complex64') {
        return complex(ones$1(shape, 'float32'), zeros(shape, 'float32'));
    }
    const values = makeOnesTypedArray(sizeFromShape(shape), dtype);
    return ENGINE.makeTensor(values, shape, dtype);
}
16896
16897 /**
16898 * @license
16899 * Copyright 2021 Google LLC. All Rights Reserved.
16900 * Licensed under the Apache License, Version 2.0 (the "License");
16901 * you may not use this file except in compliance with the License.
16902 * You may obtain a copy of the License at
16903 *
16904 * http://www.apache.org/licenses/LICENSE-2.0
16905 *
16906 * Unless required by applicable law or agreed to in writing, software
16907 * distributed under the License is distributed on an "AS IS" BASIS,
16908 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16909 * See the License for the specific language governing permissions and
16910 * limitations under the License.
16911 * =============================================================================
16912 */
16913 /**
16914 * Broadcasts parameters for evaluation on an N-D grid.
16915 *
16916 * Given N one-dimensional coordinate arrays `*args`, returns a list `outputs`
16917 * of N-D coordinate arrays for evaluating expressions on an N-D grid.
16918 *
16919 * Notes:
16920 * `meshgrid` supports cartesian ('xy') and matrix ('ij') indexing conventions.
16921 * When the `indexing` argument is set to 'xy' (the default), the broadcasting
16922 * instructions for the first two dimensions are swapped.
16923 * Examples:
16924 * Calling `const [X, Y] = meshgrid(x, y)` with the tensors
16925 *
16926 * ```javascript
16927 * const x = [1, 2, 3];
16928 * const y = [4, 5, 6];
16929 * const [X, Y] = tf.meshgrid(x, y);
16930 * // X = [[1, 2, 3],
16931 * // [1, 2, 3],
16932 * // [1, 2, 3]]
16933 * // Y = [[4, 4, 4],
16934 * // [5, 5, 5],
16935 * // [6, 6, 6]]
16936 * ```
16937 *
16938 * @param x Tensor with rank geq 1.
16939 * @param y Tensor with rank geq 1.
16940 * @param indexing
16941 *
16942 * @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
16943 */
function meshgrid(x, y, { indexing = 'xy' } = {}) {
    if (indexing !== 'xy' && indexing !== 'ij') {
        throw new TypeError(`${indexing} is not a valid third argument to meshgrid`);
    }
    // With no coordinate arrays there is nothing to broadcast.
    if (x === undefined) {
        return [];
    }
    let $x = convertToTensor(x, 'x', 'meshgrid', x instanceof Tensor ? x.dtype : 'float32');
    if (y === undefined) {
        return [$x];
    }
    let $y = convertToTensor(y, 'y', 'meshgrid', y instanceof Tensor ? y.dtype : 'float32');
    const xSize = sizeFromShape($x.shape);
    const ySize = sizeFromShape($y.shape);
    if (indexing === 'xy') {
        // Cartesian indexing: x varies along columns, y along rows. The grid
        // is materialized by multiplying with a ones vector on either side.
        $x = reshape($x, [1, -1]);
        $y = reshape($y, [-1, 1]);
        return [
            matMul(ones$1([ySize, 1], $x.dtype), $x),
            matMul($y, ones$1([1, xSize], $y.dtype)),
        ];
    }
    // Matrix ('ij') indexing: x varies along rows, y along columns.
    $x = reshape($x, [-1, 1]);
    $y = reshape($y, [1, -1]);
    return [
        matMul($x, ones$1([1, ySize], $x.dtype)),
        matMul(ones$1([xSize, 1], $y.dtype), $y),
    ];
}
16973
16974 /**
16975 * @license
16976 * Copyright 2020 Google LLC. All Rights Reserved.
16977 * Licensed under the Apache License, Version 2.0 (the "License");
16978 * you may not use this file except in compliance with the License.
16979 * You may obtain a copy of the License at
16980 *
16981 * http://www.apache.org/licenses/LICENSE-2.0
16982 *
16983 * Unless required by applicable law or agreed to in writing, software
16984 * distributed under the License is distributed on an "AS IS" BASIS,
16985 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16986 * See the License for the specific language governing permissions and
16987 * limitations under the License.
16988 * =============================================================================
16989 */
16990 /**
16991 * Returns the min of a and b (`a < b ? a : b`) element-wise.
16992 * Supports broadcasting.
16993 *
16994 * We also expose `minimumStrict` which has the same signature as this op and
16995 * asserts that `a` and `b` are the same shape (does not broadcast).
16996 *
16997 * ```js
16998 * const a = tf.tensor1d([1, 4, 3, 16]);
16999 * const b = tf.tensor1d([1, 2, 9, 4]);
17000 *
17001 * a.minimum(b).print(); // or tf.minimum(a, b)
17002 * ```
17003 *
17004 * ```js
17005 * // Broadcast minimum a with b.
17006 * const a = tf.tensor1d([2, 4, 6, 8]);
17007 * const b = tf.scalar(5);
17008 *
17009 * a.minimum(b).print(); // or tf.minimum(a, b)
17010 * ```
17011 *
17012 * @param a The first tensor.
17013 * @param b The second tensor. Must have the same type as `a`.
17014 *
17015 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
17016 */
function minimum_(a, b) {
    let $a = convertToTensor(a, 'a', 'minimum');
    let $b = convertToTensor(b, 'b', 'minimum');
    [$a, $b] = makeTypesMatch($a, $b);
    // The kernel has no bool variant, so booleans are promoted to int32.
    if ($a.dtype === 'bool') {
        [$a, $b] = [cast($a, 'int32'), cast($b, 'int32')];
    }
    // Validates that the two shapes are broadcast-compatible.
    assertAndGetBroadcastShape($a.shape, $b.shape);
    return ENGINE.runKernel(Minimum, { a: $a, b: $b });
}
const minimum = op({ minimum_ });
17030
17031 /**
17032 * @license
17033 * Copyright 2020 Google LLC. All Rights Reserved.
17034 * Licensed under the Apache License, Version 2.0 (the "License");
17035 * you may not use this file except in compliance with the License.
17036 * You may obtain a copy of the License at
17037 *
17038 * http://www.apache.org/licenses/LICENSE-2.0
17039 *
17040 * Unless required by applicable law or agreed to in writing, software
17041 * distributed under the License is distributed on an "AS IS" BASIS,
17042 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17043 * See the License for the specific language governing permissions and
17044 * limitations under the License.
17045 * =============================================================================
17046 */
17047 /**
17048 * Pads a `tf.Tensor` using mirror padding.
17049 *
17050 * This operation implements the `REFLECT` and `SYMMETRIC` modes of pad.
17051 *
17052 * ```js
17053 * const x = tf.range(0, 9).reshape([1, 1, 3, 3]);
17054 * x.mirrorPad([[0, 0], [0, 0], [2, 2], [2, 2]], 'reflect').print();
17055 * ```
17056 * @param x The tensor to pad.
17057 * @param paddings An array of length `R` (the rank of the tensor), where
17058 * each element is a length-2 tuple of ints `[padBefore, padAfter]`,
17059 * specifying how much to pad along each dimension of the tensor.
17060 * In "reflect" mode, the padded regions do not include the borders,
17061 * while in "symmetric" mode the padded regions do include the borders.
17062 * For example, if the input is `[1, 2, 3]` and paddings is `[0, 2]`,
17063 * then the output is `[1, 2, 3, 2, 1]` in "reflect" mode, and
17064 * `[1, 2, 3, 3, 2]` in "symmetric" mode.
17065 * If `mode` is "reflect" then both `paddings[D, 0]` and `paddings[D, 1]`
17066 * must be no greater than `x.shape[D] - 1`. If mode is "symmetric"
17067 * then both `paddings[D, 0]` and `paddings[D, 1]` must be no greater than
17068 * `x.shape[D]`
17069 * @param mode String to specify padding mode. Can be `'reflect' | 'symmetric'`
17070 */
17071 /** @doc {heading: 'Tensors', subheading: 'Transformations'} */
function mirrorPad_(x, paddings, mode) {
    assert(mode === 'reflect' || mode === 'symmetric', () => `Invalid mode. Mode must be either reflect or symmetric. Got ${mode}.`);
    const $x = convertToTensor(x, 'x', 'mirrorPad');
    if ($x.rank === 0) {
        throw new Error('mirrorPad(scalar) is not defined. Pass non-scalar to mirrorPad');
    }
    assert(paddings.length === $x.rank, () => `Padding doesn't match input. Must be ${$x.rank}. Got ${paddings.length}.`);
    // In 'reflect' mode the border element itself is excluded, so the maximum
    // legal padding on a dimension is one less than its size.
    const shapeOffset = mode === 'reflect' ? 1 : 0;
    for (let dim = 0; dim < $x.rank; dim++) {
        assert(paddings[dim].length === 2, () => `Invalid number of paddings. Must be length of 2 each.`);
        const maxPad = $x.shape[dim] - shapeOffset;
        const [before, after] = paddings[dim];
        assert(before >= 0 && before <= maxPad && after >= 0 && after <= maxPad, () => `Padding in dimension ${dim} cannot be greater than or equal ` +
            `to ${maxPad} or less than 0 for input of shape ${$x.shape}`);
    }
    return ENGINE.runKernel(MirrorPad, { x: $x }, { paddings, mode });
}
const mirrorPad = op({ mirrorPad_ });
17095
17096 /**
17097 * @license
17098 * Copyright 2020 Google LLC. All Rights Reserved.
17099 * Licensed under the Apache License, Version 2.0 (the "License");
17100 * you may not use this file except in compliance with the License.
17101 * You may obtain a copy of the License at
17102 *
17103 * http://www.apache.org/licenses/LICENSE-2.0
17104 *
17105 * Unless required by applicable law or agreed to in writing, software
17106 * distributed under the License is distributed on an "AS IS" BASIS,
17107 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17108 * See the License for the specific language governing permissions and
17109 * limitations under the License.
17110 * =============================================================================
17111 */
17112 /**
17113 * Returns the mod of a and b element-wise.
17114 * `floor(x / y) * y + mod(x, y) = x`
17115 * Supports broadcasting.
17116 *
17117 * We also expose `tf.modStrict` which has the same signature as this op and
17118 * asserts that `a` and `b` are the same shape (does not broadcast).
17119 *
17120 * ```js
17121 * const a = tf.tensor1d([1, 4, 3, 16]);
17122 * const b = tf.tensor1d([1, 2, 9, 4]);
17123 *
17124 * a.mod(b).print(); // or tf.mod(a, b)
17125 * ```
17126 *
17127 * ```js
17128 * // Broadcast a mod b.
17129 * const a = tf.tensor1d([2, 4, 6, 8]);
17130 * const b = tf.scalar(5);
17131 *
17132 * a.mod(b).print(); // or tf.mod(a, b)
17133 * ```
17134 *
17135 * @param a The first tensor.
17136 * @param b The second tensor. Must have the same type as `a`.
17137 *
17138 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
17139 */
function mod_(a, b) {
    let $a = convertToTensor(a, 'a', 'mod');
    let $b = convertToTensor(b, 'b', 'mod');
    // Upcast both operands to a common dtype before dispatching.
    [$a, $b] = makeTypesMatch($a, $b);
    return ENGINE.runKernel(Mod, { a: $a, b: $b });
}
const mod = op({ mod_ });
17148
17149 /**
17150 * @license
17151 * Copyright 2020 Google LLC. All Rights Reserved.
17152 * Licensed under the Apache License, Version 2.0 (the "License");
17153 * you may not use this file except in compliance with the License.
17154 * You may obtain a copy of the License at
17155 *
17156 * http://www.apache.org/licenses/LICENSE-2.0
17157 *
17158 * Unless required by applicable law or agreed to in writing, software
17159 * distributed under the License is distributed on an "AS IS" BASIS,
17160 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17161 * See the License for the specific language governing permissions and
17162 * limitations under the License.
17163 * =============================================================================
17164 */
17165 /**
17166 * Calculates the mean and variance of `x`. The mean and variance are
17167 * calculated by aggregating the contents of `x` across `axes`. If `x` is
17168 * 1-D and `axes = [0]` this is just the mean and variance of a vector.
17169 *
17170 * @param x The input tensor.
 * @param axis The dimension(s) along which to compute mean and
 *     variance. By default it reduces all dimensions.
17173 * @param keepDims If true, the moments have the same dimensionality as the
17174 * input.
17175 * @return An object with two keys: `mean` and `variance`.
17176 *
17177 * @doc {heading: 'Operations', subheading: 'Normalization'}
17178 */
function moments_(x, axis = null, keepDims = false) {
    x = convertToTensor(x, 'x', 'moments');
    const axes = parseAxisParam(axis, x.shape);
    const xMean = mean(x, axes, keepDims);
    // Shape used to broadcast the mean back over the reduced axes; when
    // keepDims is set, the mean already has the right shape.
    const broadcastShape = keepDims ?
        xMean.shape :
        expandShapeToKeepDim(xMean.shape, axes);
    const devSquared = square(sub(cast(x, 'float32'), reshape(xMean, broadcastShape)));
    const variance = mean(devSquared, axes, keepDims);
    return { mean: xMean, variance };
}
const moments = op({ moments_ });
17192
17193 /**
17194 * Computes the next states and outputs of a stack of LSTMCells.
17195 *
17196 * Each cell output is used as input to the next cell.
17197 *
17198 * Returns `[cellState, cellOutput]`.
17199 *
 * Derived from tf.contrib.rnn.MultiRNNCell.
17201 *
17202 * @param lstmCells Array of LSTMCell functions.
17203 * @param data The input to the cell.
17204 * @param c Array of previous cell states.
17205 * @param h Array of previous cell outputs.
17206 *
17207 * @doc {heading: 'Operations', subheading: 'RNN'}
17208 */
function multiRNNCell_(lstmCells, data, c, h) {
    const $data = convertToTensor(data, 'data', 'multiRNNCell');
    const $c = convertToTensorArray(c, 'c', 'multiRNNCell');
    const $h = convertToTensorArray(h, 'h', 'multiRNNCell');
    const newC = [];
    const newH = [];
    let input = $data;
    for (let i = 0; i < lstmCells.length; i++) {
        // Each cell maps (input, prevState, prevOutput) -> [state, output].
        const [cellState, cellOutput] = lstmCells[i](input, $c[i], $h[i]);
        newC.push(cellState);
        newH.push(cellOutput);
        // The output of this cell feeds the next cell in the stack.
        input = cellOutput;
    }
    return [newC, newH];
}
const multiRNNCell = op({ multiRNNCell_ });
17230
17231 /**
17232 * @license
17233 * Copyright 2020 Google LLC. All Rights Reserved.
17234 * Licensed under the Apache License, Version 2.0 (the "License");
17235 * you may not use this file except in compliance with the License.
17236 * You may obtain a copy of the License at
17237 *
17238 * http://www.apache.org/licenses/LICENSE-2.0
17239 *
17240 * Unless required by applicable law or agreed to in writing, software
17241 * distributed under the License is distributed on an "AS IS" BASIS,
17242 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17243 * See the License for the specific language governing permissions and
17244 * limitations under the License.
17245 * =============================================================================
17246 */
17247 /**
17248 * Creates a `tf.Tensor` with values drawn from a multinomial distribution.
17249 *
17250 * ```js
17251 * const probs = tf.tensor([.75, .25]);
17252 * tf.multinomial(probs, 3).print();
17253 * ```
17254 *
17255 * @param logits 1D array with unnormalized log-probabilities, or
17256 * 2D array of shape `[batchSize, numOutcomes]`. See the `normalized`
17257 * parameter.
17258 * @param numSamples Number of samples to draw for each row slice.
17259 * @param seed The seed number.
17260 * @param normalized Whether the provided `logits` are normalized true
17261 * probabilities (sum to 1). Defaults to false.
17262 * @return 1D array of shape `[numSamples]`, or 2D array of shape
17263 * `[batchSize, numSamples]`, depending on the rank of the input.
17264 *
17265 * @doc {heading: 'Tensors', subheading: 'Random'}
17266 */
function multinomial_(logits, numSamples, seed, normalized = false) {
    const $logits = convertToTensor(logits, 'logits', 'multinomial');
    const numOutcomes = $logits.size;
    const origRank = $logits.rank;
    if (numOutcomes < 2) {
        throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ` +
            `${numOutcomes}.`);
    }
    if (origRank > 2) {
        throw new Error(`Rank of probabilities must be 1 or 2, but is ${origRank}`);
    }
    // TODO(lina128): Investigate correct seed behavior. The code seems not to
    // allow setting seed to 0 (falsy seeds are replaced by a random one).
    seed = seed || Math.random();
    // The kernel only accepts (and returns) rank-2 tensors, so a vector input
    // is treated as a single-row batch.
    const logits2D = origRank === 1 ? reshape($logits, [1, -1]) : $logits;
    // tslint:disable-next-line: no-unnecessary-type-assertion
    const res = ENGINE.runKernel(Multinomial, { logits: logits2D }, { numSamples, seed, normalized });
    // tslint:disable-next-line:no-unnecessary-type-assertion
    return origRank === 1 ? reshape(res, [res.size]) : res;
}
const multinomial = op({ multinomial_ });
17291
17292 /**
17293 * @license
17294 * Copyright 2020 Google LLC. All Rights Reserved.
17295 * Licensed under the Apache License, Version 2.0 (the "License");
17296 * you may not use this file except in compliance with the License.
17297 * You may obtain a copy of the License at
17298 *
17299 * http://www.apache.org/licenses/LICENSE-2.0
17300 *
17301 * Unless required by applicable law or agreed to in writing, software
17302 * distributed under the License is distributed on an "AS IS" BASIS,
17303 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17304 * See the License for the specific language governing permissions and
17305 * limitations under the License.
17306 * =============================================================================
17307 */
17308 /**
17309 * Returns the truth value of (a != b) element-wise. Supports broadcasting.
17310 *
17311 * ```js
17312 * const a = tf.tensor1d([1, 2, 3]);
17313 * const b = tf.tensor1d([0, 2, 3]);
17314 *
17315 * a.notEqual(b).print();
17316 * ```
17317 * @param a The first input tensor.
17318 * @param b The second input tensor. Must have the same dtype as `a`.
17319 *
17320 * @doc {heading: 'Operations', subheading: 'Logical'}
17321 */
function notEqual_(a, b) {
    let $a = convertToTensor(a, 'a', 'notEqual', 'string_or_numeric');
    let $b = convertToTensor(b, 'b', 'notEqual', 'string_or_numeric');
    // Upcast to a common dtype, then validate broadcast compatibility.
    [$a, $b] = makeTypesMatch($a, $b);
    assertAndGetBroadcastShape($a.shape, $b.shape);
    return ENGINE.runKernel(NotEqual, { a: $a, b: $b });
}
const notEqual = op({ notEqual_ });
17331
17332 /**
17333 * @license
17334 * Copyright 2018 Google LLC. All Rights Reserved.
17335 * Licensed under the Apache License, Version 2.0 (the "License");
17336 * you may not use this file except in compliance with the License.
17337 * You may obtain a copy of the License at
17338 *
17339 * http://www.apache.org/licenses/LICENSE-2.0
17340 *
17341 * Unless required by applicable law or agreed to in writing, software
17342 * distributed under the License is distributed on an "AS IS" BASIS,
17343 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17344 * See the License for the specific language governing permissions and
17345 * limitations under the License.
17346 * =============================================================================
17347 */
17348 /**
17349 * Creates a `tf.Tensor` with all elements set to 1 with the same shape as the
17350 * given tensor.
17351 *
17352 * ```js
17353 * const x = tf.tensor([1, 2]);
17354 * tf.onesLike(x).print();
17355 * ```
17356 * @param x A tensor.
17357 *
17358 * @doc {heading: 'Tensors', subheading: 'Creation'}
17359 */
17360 function onesLike_(x) {
17361 const $x = convertToTensor(x, 'x', 'onesLike');
17362 const inputs = { x: $x };
17363 return ENGINE.runKernel(OnesLike, inputs);
17364 }
17365 const onesLike = op({ onesLike_ });
17366
17367 /**
17368 * Computes the outer product of two vectors, `v1` and `v2`.
17369 *
17370 * ```js
17371 * const a = tf.tensor1d([1, 2, 3]);
17372 * const b = tf.tensor1d([3, 4, 5]);
17373 *
17374 * tf.outerProduct(a, b).print();
17375 * ```
17376 * @param v1 The first vector in the outer product operation.
17377 * @param v2 The second vector in the outer product operation.
17378 *
17379 * @doc {heading: 'Operations', subheading: 'Matrices'}
17380 */
17381 function outerProduct_(v1, v2) {
17382 const $v1 = convertToTensor(v1, 'v1', 'outerProduct');
17383 const $v2 = convertToTensor(v2, 'v2', 'outerProduct');
17384 assert($v1.rank === 1 && $v2.rank === 1, () => `Error in outerProduct: inputs must be rank 1, but got ranks ` +
17385 `${$v1.rank} and ${$v2.rank}.`);
17386 const v12D = reshape($v1, [-1, 1]);
17387 const v22D = reshape($v2, [1, -1]);
17388 return matMul(v12D, v22D);
17389 }
17390 const outerProduct = op({ outerProduct_ });
17391
17392 /**
17393 * @license
17394 * Copyright 2020 Google LLC. All Rights Reserved.
17395 * Licensed under the Apache License, Version 2.0 (the "License");
17396 * you may not use this file except in compliance with the License.
17397 * You may obtain a copy of the License at
17398 *
17399 * http://www.apache.org/licenses/LICENSE-2.0
17400 *
17401 * Unless required by applicable law or agreed to in writing, software
17402 * distributed under the License is distributed on an "AS IS" BASIS,
17403 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17404 * See the License for the specific language governing permissions and
17405 * limitations under the License.
17406 * =============================================================================
17407 */
17408 /**
17409 * Pads a `tf.Tensor` with a given value and paddings.
17410 *
17411 * This operation implements `CONSTANT` mode. For `REFLECT` and `SYMMETRIC`,
17412 * refer to `tf.mirrorPad`
17413 *
17414 * Also available are stricter rank-specific methods with the same signature
17415 * as this method that assert that `paddings` is of given length.
17416 * - `tf.pad1d`
17417 * - `tf.pad2d`
17418 * - `tf.pad3d`
17419 * - `tf.pad4d`
17420 *
17421 * ```js
17422 * const x = tf.tensor1d([1, 2, 3, 4]);
17423 * x.pad([[1, 2]]).print();
17424 * ```
17425 * @param x The tensor to pad.
17426 * @param paddings An array of length `R` (the rank of the tensor), where
17427 * each element is a length-2 tuple of ints `[padBefore, padAfter]`,
17428 * specifying how much to pad along each dimension of the tensor.
17429 * @param constantValue The pad value to use. Defaults to 0.
17430 *
17431 * @doc {heading: 'Tensors', subheading: 'Transformations'}
17432 */
17433 function pad_(x, paddings, constantValue = 0) {
17434 const $x = convertToTensor(x, 'x', 'pad');
17435 if ($x.rank === 0) {
17436 throw new Error('pad(scalar) is not defined. Pass non-scalar to pad');
17437 }
17438 const attrs = { paddings, constantValue };
17439 const inputs = { x: $x };
17440 return ENGINE.runKernel(PadV2, inputs, attrs);
17441 }
17442 const pad = op({ pad_ });
17443
17444 /**
17445 * Pads a `tf.Tensor1D` with a given value and paddings. See `pad` for details.
17446 */
17447 function pad1d_(x, paddings, constantValue = 0) {
17448 assert(paddings.length === 2, () => 'Invalid number of paddings. Must be length of 2.');
17449 return pad(x, [paddings], constantValue);
17450 }
17451 const pad1d = op({ pad1d_ });
17452
17453 /**
17454 * Pads a `tf.Tensor2D` with a given value and paddings. See `pad` for details.
17455 */
17456 function pad2d_(x, paddings, constantValue = 0) {
17457 assert(paddings.length === 2 && paddings[0].length === 2 &&
17458 paddings[1].length === 2, () => 'Invalid number of paddings. Must be length of 2 each.');
17459 return pad(x, paddings, constantValue);
17460 }
17461 const pad2d = op({ pad2d_ });
17462
17463 /**
17464 * Pads a `tf.Tensor3D` with a given value and paddings. See `pad` for details.
17465 */
17466 function pad3d_(x, paddings, constantValue = 0) {
17467 assert(paddings.length === 3 && paddings[0].length === 2 &&
17468 paddings[1].length === 2 && paddings[2].length === 2, () => 'Invalid number of paddings. Must be length of 2 each.');
17469 return pad(x, paddings, constantValue);
17470 }
17471 const pad3d = op({ pad3d_ });
17472
17473 /**
17474 * Pads a `tf.Tensor4D` with a given value and paddings. See `pad` for details.
17475 */
17476 function pad4d_(x, paddings, constantValue = 0) {
17477 assert(paddings.length === 4 && paddings[0].length === 2 &&
17478 paddings[1].length === 2 && paddings[2].length === 2 &&
17479 paddings[3].length === 2, () => 'Invalid number of paddings. Must be length of 2 each.');
17480 return pad(x, paddings, constantValue);
17481 }
17482 const pad4d = op({ pad4d_ });
17483
17484 /**
17485 * @license
17486 * Copyright 2020 Google LLC. All Rights Reserved.
17487 * Licensed under the Apache License, Version 2.0 (the "License");
17488 * you may not use this file except in compliance with the License.
17489 * You may obtain a copy of the License at
17490 *
17491 * http://www.apache.org/licenses/LICENSE-2.0
17492 *
17493 * Unless required by applicable law or agreed to in writing, software
17494 * distributed under the License is distributed on an "AS IS" BASIS,
17495 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17496 * See the License for the specific language governing permissions and
17497 * limitations under the License.
17498 * =============================================================================
17499 */
17500 /**
17501 * This operation divides "spatial" dimensions `[1, ..., M]` of the input into
17502 * a grid of blocks of shape `blockShape`, and interleaves these blocks with
17503 * the "batch" dimension (0) such that in the output, the spatial
17504 * dimensions `[1, ..., M]` correspond to the position within the grid,
17505 * and the batch dimension combines both the position within a spatial block
17506 * and the original batch position. Prior to division into blocks,
17507 * the spatial dimensions of the input are optionally zero padded
17508 * according to `paddings`. See below for a precise description.
17509 *
17510 * ```js
17511 * const x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
17512 * const blockShape = [2, 2];
17513 * const paddings = [[0, 0], [0, 0]];
17514 *
17515 * x.spaceToBatchND(blockShape, paddings).print();
17516 * ```
17517 *
17518 * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape +
17519 * remainingShape`, where spatialShape has `M` dimensions.
17520 * @param blockShape A 1-D array. Must have shape `[M]`, all values must
17521 * be >= 1.
17522 * @param paddings A 2-D array. Must have shape `[M, 2]`, all values must be >=
17523 * 0. `paddings[i] = [padStart, padEnd]` specifies the amount to zero-pad
17524 * from input dimension `i + 1`, which corresponds to spatial dimension `i`. It
17525 * is required that
17526 * `(inputShape[i + 1] + padStart + padEnd) % blockShape[i] === 0`
17527 *
17528 * This operation is equivalent to the following steps:
17529 *
17530 * 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the input
17531 * according to `paddings` to produce `padded` of shape paddedShape.
17532 *
17533 * 2. Reshape `padded` to `reshapedPadded` of shape:
17534 * `[batch] + [paddedShape[1] / blockShape[0], blockShape[0], ...,
17535 * paddedShape[M] / blockShape[M-1], blockShape[M-1]] + remainingShape`
17536 *
17537 * 3. Permute dimensions of `reshapedPadded` to produce `permutedReshapedPadded`
17538 * of shape: `blockShape + [batch] + [paddedShape[1] / blockShape[0], ...,
17539 * paddedShape[M] / blockShape[M-1]] + remainingShape`
17540 *
17541 * 4. Reshape `permutedReshapedPadded` to flatten `blockShape` into the
17542 * batch dimension, producing an output tensor of shape:
17543 * `[batch * prod(blockShape)] + [paddedShape[1] / blockShape[0], ...,
17544 * paddedShape[M] / blockShape[M-1]] + remainingShape`
17545 *
17546 * @doc {heading: 'Tensors', subheading: 'Transformations'}
17547 */
17548 function spaceToBatchND_(x, blockShape, paddings) {
17549 const $x = convertToTensor(x, 'x', 'spaceToBatchND');
17550 assert($x.rank >= 1 + blockShape.length, () => `input rank ${$x.rank} should be > than [blockShape] ${blockShape.length}`);
17551 assert(paddings.length === blockShape.length, () => `paddings.shape[0] ${paddings.length} must be equal to [blockShape] ${blockShape.length}`);
17552 assert($x.shape.reduce((a, b, i) => {
17553 if (i > 0 && i <= blockShape.length) {
17554 return a &&
17555 ((b + paddings[i - 1][0] + paddings[i - 1][1]) %
17556 blockShape[i - 1] ===
17557 0);
17558 }
17559 return a;
17560 }, true), () => `input spatial dimensions ${$x.shape.slice(1)} with paddings ${paddings.toString()} must be divisible by blockShapes ${blockShape.toString()}`);
17561 const inputs = { x: $x };
17562 const attrs = { blockShape, paddings };
17563 return ENGINE.runKernel(SpaceToBatchND, inputs, attrs);
17564 }
17565 const spaceToBatchND = op({ spaceToBatchND_ });
17566
17567 /**
17568 * @license
17569 * Copyright 2018 Google LLC. All Rights Reserved.
17570 * Licensed under the Apache License, Version 2.0 (the "License");
17571 * you may not use this file except in compliance with the License.
17572 * You may obtain a copy of the License at
17573 *
17574 * http://www.apache.org/licenses/LICENSE-2.0
17575 *
17576 * Unless required by applicable law or agreed to in writing, software
17577 * distributed under the License is distributed on an "AS IS" BASIS,
17578 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17579 * See the License for the specific language governing permissions and
17580 * limitations under the License.
17581 * =============================================================================
17582 */
17583 /**
17584 * Performs an N-D pooling operation
17585 *
17586 * @param input The input tensor, of rank 4 or rank 3 of shape
17587 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
17588 * @param windowShape The filter size: `[filterHeight, filterWidth]`. If
17589 * `filterSize` is a single number, then `filterHeight == filterWidth`.
17590 * @param poolingType The type of pooling, either 'max' or 'avg'.
17591 * @param pad The type of padding algorithm:
17592 * - `same` and stride 1: output will be of same size as input,
17593 * regardless of filter size.
17594 * - `valid`: output will be smaller than input if filter is larger
17595 * than 1x1.
17596 * - For more info, see this guide:
17597 * [https://www.tensorflow.org/api_guides/python/nn#Convolution](
17598 * https://www.tensorflow.org/api_guides/python/nn#Convolution)
17599 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
17600 * in which we sample input values across the height and width dimensions
17601 * in dilated pooling. Defaults to `[1, 1]`. If `dilationRate` is a single
17602 * number, then `dilationHeight == dilationWidth`. If it is greater than
17603 * 1, then all values of `strides` must be 1.
17604 * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
17605 * `strides` is a single number, then `strideHeight == strideWidth`.
17606 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
17607 * provided, it will default to truncate.
17608 *
17609 * @doc {heading: 'Operations', subheading: 'Convolution'}
17610 */
    function pool_(input, windowShape, poolingType, pad, dilations, strides, dimRoundingMode) {
        // Normalize defaulted arguments.
        if (dilations == null) {
            dilations = [1, 1];
        }
        if (strides == null) {
            strides = 1;
        }
        // Numeric pad 0 is treated as 'valid' padding.
        if (pad === 0) {
            pad = 'valid';
        }
        const $x = convertToTensor(input, 'x', 'maxPool');
        // Promote a rank-3 input to rank-4 with batch size 1; undone before
        // returning.
        let x4D = $x;
        let reshapedTo4D = false;
        if ($x.rank === 3) {
            reshapedTo4D = true;
            x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
        }
        assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in pool: Either strides or dilations must be 1. ' +
            `Got strides ${strides} and dilations '${dilations}'`);
        const convInfo = computePool2DInfo(x4D.shape, windowShape, strides, dilations, pad);
        const dilation = [convInfo.dilationHeight, convInfo.dilationWidth];
        // The following implementation does batchToSpace(pool(spaceToBatch(x)))
        // whenever dilation > 1 since the TF kernels do not support dilation > 1.
        // tslint:disable-next-line:max-line-length
        // https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/nn_ops.py#L1037
        let basePadding;
        if (pad === 'same') {
            // 'same' padding must account for the dilated filter extent.
            basePadding = withSpaceToBatchBasePaddings([convInfo.filterHeight, convInfo.filterWidth], dilation);
        }
        else {
            basePadding = [[0, 0], [0, 0]];
        }
        const isDilationOne = dilation[0] === 1 && dilation[1] === 1;
        const [adjustedPadding, adjustedCrops] = requiredSpaceToBatchPaddings([convInfo.inHeight, convInfo.inWidth], dilation, basePadding);
        // When dilation is emulated via spaceToBatch, the inner pool runs with
        // 'valid' padding (padding was already applied by spaceToBatchND).
        const convertedPad = isDilationOne ? pad : 'valid';
        const convertedX = isDilationOne ? x4D : spaceToBatchND(x4D, dilation, adjustedPadding);
        const forwardOp = poolingType === 'avg' ?
            () => avgPool(convertedX, windowShape, strides, convertedPad, dimRoundingMode) :
            () => maxPool(convertedX, windowShape, strides, convertedPad, dimRoundingMode);
        const y = forwardOp();
        // Undo the spaceToBatch transform, cropping away the extra padding.
        const res = isDilationOne ? y : batchToSpaceND(y, dilation, adjustedCrops);
        if (reshapedTo4D) {
            // Drop the synthetic batch dimension added for rank-3 inputs.
            return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
        }
        return res;
    }
17657 // Helper function to compute crops and paddings for pool with dilation > 1.
17658 // tslint:disable-next-line:max-line-length
17659 // https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/array_ops.py#L2184
17660 function requiredSpaceToBatchPaddings(inputShape, blockShape, basePadding) {
17661 const padStart = basePadding.map(b => b[0]);
17662 const origPadEnd = basePadding.map(b => b[1]);
17663 const fullInputShape = inputShape.concat(padStart, origPadEnd);
17664 const padEndExtra = blockShape.map((b, i) => (b - fullInputShape[i] % b) % b);
17665 const padEnd = origPadEnd.map((s, i) => s + padEndExtra[i]);
17666 const paddings = blockShape.map((_, i) => [padStart[i], padEnd[i]]);
17667 const crops = blockShape.map((_, i) => [0, padEndExtra[i]]);
17668 return [paddings, crops];
17669 }
17670 // Helper function to compute base paddings for pool with dilation > 1.
17671 // tslint:disable-next-line:max-line-length
17672 // https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/nn_ops.py#L524
17673 function withSpaceToBatchBasePaddings(filterShape, dilation) {
17674 // Spatial dimensions of the filters and the upsampled filters in which we
17675 // introduce (rate - 1) zeros between consecutive filter values.
17676 const dilatedFilterShape = filterShape.map((s, i) => {
17677 return s + (s - 1) * (dilation[i] - 1);
17678 });
17679 const padExtraShape = dilatedFilterShape.map(s => s - 1);
17680 // When padding is odd, we pad more at end, following the same
17681 // convention as conv2d.
17682 const padExtraStart = padExtraShape.map(s => Math.floor(s / 2));
17683 const padExtraEnd = padExtraShape.map((s, i) => s - padExtraStart[i]);
17684 return padExtraShape.map((_, i) => {
17685 return [padExtraStart[i], padExtraEnd[i]];
17686 });
17687 }
    // Public export of pool_, wrapped by op() like the other ops in this file.
    const pool = op({ pool_ });
17689
17690 /**
17691 * @license
17692 * Copyright 2020 Google LLC. All Rights Reserved.
17693 * Licensed under the Apache License, Version 2.0 (the "License");
17694 * you may not use this file except in compliance with the License.
17695 * You may obtain a copy of the License at
17696 *
17697 * http://www.apache.org/licenses/LICENSE-2.0
17698 *
17699 * Unless required by applicable law or agreed to in writing, software
17700 * distributed under the License is distributed on an "AS IS" BASIS,
17701 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17702 * See the License for the specific language governing permissions and
17703 * limitations under the License.
17704 * =============================================================================
17705 */
17706 /**
17707 * Computes leaky rectified linear element-wise with parametric alphas.
17708 *
17709 * `x < 0 ? alpha * x : f(x) = x`
17710 *
17711 * ```js
17712 * const x = tf.tensor1d([-1, 2, -3, 4]);
17713 * const alpha = tf.scalar(0.1);
17714 *
17715 * x.prelu(alpha).print(); // or tf.prelu(x, alpha)
17716 * ```
17717 * @param x The input tensor.
17718 * @param alpha Scaling factor for negative values.
17719 *
17720 * @doc {heading: 'Operations', subheading: 'Basic math'}
17721 */
17722 function prelu_(x, alpha) {
17723 const $x = convertToTensor(x, 'x', 'prelu');
17724 const $alpha = convertToTensor(alpha, 'alpha', 'prelu');
17725 const inputs = { x: $x, alpha: $alpha };
17726 return ENGINE.runKernel(Prelu, inputs);
17727 }
17728 const prelu = op({ prelu_ });
17729
17730 /**
17731 * @license
17732 * Copyright 2020 Google LLC. All Rights Reserved.
17733 * Licensed under the Apache License, Version 2.0 (the "License");
17734 * you may not use this file except in compliance with the License.
17735 * You may obtain a copy of the License at
17736 *
17737 * http://www.apache.org/licenses/LICENSE-2.0
17738 *
17739 * Unless required by applicable law or agreed to in writing, software
17740 * distributed under the License is distributed on an "AS IS" BASIS,
17741 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17742 * See the License for the specific language governing permissions and
17743 * limitations under the License.
17744 * =============================================================================
17745 */
17746 /**
17747 * Computes the product of elements across dimensions of a `tf.Tensor`.
17748 *
17749 * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
17750 * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
17751 * `axes`. If `keepDims` is true, the reduced dimensions are retained with
17752 * length 1. If `axes` has no entries, all dimensions are reduced, and a
17753 * `tf.Tensor` with a single element is returned.
17754 *
17755 * ```js
17756 * const x = tf.tensor1d([1, 2, 3]);
17757 *
17758 * x.prod().print(); // or tf.prod(x)
17759 * ```
17760 *
17761 * ```js
17762 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
17763 *
17764 * const axis = 1;
17765 * x.prod(axis).print(); // or tf.prod(x, axis)
17766 * ```
17767 *
17768 * @param x The input tensor to compute the product over. If the dtype is `bool`
17769 * it will be converted to `int32` and the output dtype will be `int32`.
17770 * @param axis The dimension(s) to reduce. By default it reduces
17771 * all dimensions.
17772 * @param keepDims If true, retains reduced dimensions with size 1.
17773 *
17774 * @doc {heading: 'Operations', subheading: 'Reduction'}
17775 */
17776 function prod_(x, axis = null, keepDims = false) {
17777 let $x = convertToTensor(x, 'x', 'prod');
17778 if ($x.dtype === 'bool') {
17779 // bool is not an allowed type for the underlying kernel.
17780 $x = cast($x, 'int32');
17781 }
17782 const inputs = { x: $x };
17783 const attrs = { axis, keepDims };
17784 return ENGINE.runKernel(Prod, inputs, attrs);
17785 }
17786 const prod = op({ prod_ });
17787
17788 /**
17789 * @license
17790 * Copyright 2020 Google LLC. All Rights Reserved.
17791 * Licensed under the Apache License, Version 2.0 (the "License");
17792 * you may not use this file except in compliance with the License.
17793 * You may obtain a copy of the License at
17794 *
17795 * http://www.apache.org/licenses/LICENSE-2.0
17796 *
17797 * Unless required by applicable law or agreed to in writing, software
17798 * distributed under the License is distributed on an "AS IS" BASIS,
17799 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17800 * See the License for the specific language governing permissions and
17801 * limitations under the License.
17802 * =============================================================================
17803 */
17804 /**
17805 * Creates a `tf.Tensor` with values sampled from a random number generator
17806 * function defined by the user.
17807 *
17808 * @param shape An array of integers defining the output tensor shape.
17809 * @param randFunction A random number generator function which is called
17810 * for each element in the output tensor.
17811 * @param dtype The data type of the output tensor. Defaults to 'float32'.
17812 *
17813 * @doc {heading: 'Tensors', subheading: 'Random'}
17814 */
17815 function rand_(shape, randFunction, dtype) {
17816 const size = sizeFromShape(shape);
17817 let values = null;
17818 if (dtype == null || dtype === 'float32') {
17819 values = new Float32Array(size);
17820 }
17821 else if (dtype === 'int32') {
17822 values = new Int32Array(size);
17823 }
17824 else if (dtype === 'bool') {
17825 values = new Uint8Array(size);
17826 }
17827 else {
17828 throw new Error(`Unknown data type ${dtype}`);
17829 }
17830 for (let i = 0; i < size; i++) {
17831 values[i] = randFunction();
17832 }
17833 return ENGINE.makeTensor(values, shape, dtype);
17834 }
17835 const rand = op({ rand_ });
17836
    // Best-effort handle on the global object across environments: standard
    // `globalThis`, browser `window`, Node `global`, worker `self`, else {}.
    var commonjsGlobal = typeof globalThis !== 'undefined' ? globalThis : typeof window !== 'undefined' ? window : typeof global !== 'undefined' ? global : typeof self !== 'undefined' ? self : {};
17838
17839 function unwrapExports (x) {
17840 return x && x.__esModule && Object.prototype.hasOwnProperty.call(x, 'default') ? x['default'] : x;
17841 }
17842
17843 function createCommonjsModule(fn, module) {
17844 return module = { exports: {} }, fn(module, module.exports), module.exports;
17845 }
17846
17847 function getCjsExportFromNamespace (n) {
17848 return n && n['default'] || n;
17849 }
17850
17851 function commonjsRequire () {
17852 throw new Error('Dynamic requires are not currently supported by @rollup/plugin-commonjs');
17853 }
17854
    // Alea PRNG (from the seedrandom library), bundled as a CommonJS module.
    var alea = createCommonjsModule(function (module) {
    // A port of an algorithm by Johannes Baagøe <baagoe@baagoe.com>, 2010
    // http://baagoe.com/en/RandomMusings/javascript/
    // https://github.com/nquinlan/better-random-numbers-for-javascript-mirror
    // Original work is under MIT license -

    // Copyright (C) 2010 by Johannes Baagøe <baagoe@baagoe.org>
    //
    // Permission is hereby granted, free of charge, to any person obtaining a copy
    // of this software and associated documentation files (the "Software"), to deal
    // in the Software without restriction, including without limitation the rights
    // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    // copies of the Software, and to permit persons to whom the Software is
    // furnished to do so, subject to the following conditions:
    //
    // The above copyright notice and this permission notice shall be included in
    // all copies or substantial portions of the Software.
    //
    // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    // THE SOFTWARE.



    (function(global, module, define) {

    // Generator state: three fractional values s0..s2 and a carry c.
    function Alea(seed) {
      var me = this, mash = Mash();

      me.next = function() {
        var t = 2091639 * me.s0 + me.c * 2.3283064365386963e-10; // 2^-32
        me.s0 = me.s1;
        me.s1 = me.s2;
        return me.s2 = t - (me.c = t | 0);
      };

      // Apply the seeding algorithm from Baagoe.
      me.c = 1;
      me.s0 = mash(' ');
      me.s1 = mash(' ');
      me.s2 = mash(' ');
      me.s0 -= mash(seed);
      if (me.s0 < 0) { me.s0 += 1; }
      me.s1 -= mash(seed);
      if (me.s1 < 0) { me.s1 += 1; }
      me.s2 -= mash(seed);
      if (me.s2 < 0) { me.s2 += 1; }
      // Release the hash closure once seeding is done.
      mash = null;
    }

    // Copies generator state from f to t (used for state save/restore).
    function copy(f, t) {
      t.c = f.c;
      t.s0 = f.s0;
      t.s1 = f.s1;
      t.s2 = f.s2;
      return t;
    }

    // Public factory: returns a prng function with int32/double/quick/state
    // helpers attached; opts.state restores a previously saved state.
    function impl(seed, opts) {
      var xg = new Alea(seed),
          state = opts && opts.state,
          prng = xg.next;
      prng.int32 = function() { return (xg.next() * 0x100000000) | 0; };
      prng.double = function() {
        return prng() + (prng() * 0x200000 | 0) * 1.1102230246251565e-16; // 2^-53
      };
      prng.quick = prng;
      if (state) {
        if (typeof(state) == 'object') copy(state, xg);
        prng.state = function() { return copy(xg, {}); };
      }
      return prng;
    }

    // Baagoe's string-hashing routine used to derive the seed state.
    function Mash() {
      var n = 0xefc8249d;

      var mash = function(data) {
        data = data.toString();
        for (var i = 0; i < data.length; i++) {
          n += data.charCodeAt(i);
          var h = 0.02519603282416938 * n;
          n = h >>> 0;
          h -= n;
          h *= n;
          n = h >>> 0;
          h -= n;
          n += h * 0x100000000; // 2^32
        }
        return (n >>> 0) * 2.3283064365386963e-10; // 2^-32
      };

      return mash;
    }


    if (module && module.exports) {
      module.exports = impl;
    } else if (define && define.amd) {
      define(function() { return impl; });
    } else {
      this.alea = impl;
    }

    })(
      commonjsGlobal,
      ('object') == 'object' && module, // present in node.js
      (typeof undefined) == 'function' && undefined // present with an AMD loader
    );
    });
17969
    // xor128 PRNG (from the seedrandom library), bundled as a CommonJS module.
    var xor128 = createCommonjsModule(function (module) {
    // A Javascript implementation of the "xor128" prng algorithm by
    // George Marsaglia. See http://www.jstatsoft.org/v08/i14/paper

    (function(global, module, define) {

    // Generator state: four 32-bit words x, y, z, w.
    function XorGen(seed) {
      var me = this, strseed = '';

      me.x = 0;
      me.y = 0;
      me.z = 0;
      me.w = 0;

      // Set up generator function.
      me.next = function() {
        var t = me.x ^ (me.x << 11);
        me.x = me.y;
        me.y = me.z;
        me.z = me.w;
        return me.w ^= (me.w >>> 19) ^ t ^ (t >>> 8);
      };

      if (seed === (seed | 0)) {
        // Integer seed.
        me.x = seed;
      } else {
        // String seed.
        strseed += seed;
      }

      // Mix in string seed, then discard an initial batch of 64 values.
      for (var k = 0; k < strseed.length + 64; k++) {
        me.x ^= strseed.charCodeAt(k) | 0;
        me.next();
      }
    }

    // Copies generator state from f to t (used for state save/restore).
    function copy(f, t) {
      t.x = f.x;
      t.y = f.y;
      t.z = f.z;
      t.w = f.w;
      return t;
    }

    // Public factory: returns a prng in [0, 1) with int32/double/quick/state
    // helpers; opts.state restores a previously saved state.
    function impl(seed, opts) {
      var xg = new XorGen(seed),
          state = opts && opts.state,
          prng = function() { return (xg.next() >>> 0) / 0x100000000; };
      prng.double = function() {
        do {
          var top = xg.next() >>> 11,
              bot = (xg.next() >>> 0) / 0x100000000,
              result = (top + bot) / (1 << 21);
        } while (result === 0);
        return result;
      };
      prng.int32 = xg.next;
      prng.quick = prng;
      if (state) {
        if (typeof(state) == 'object') copy(state, xg);
        prng.state = function() { return copy(xg, {}); };
      }
      return prng;
    }

    if (module && module.exports) {
      module.exports = impl;
    } else if (define && define.amd) {
      define(function() { return impl; });
    } else {
      this.xor128 = impl;
    }

    })(
      commonjsGlobal,
      ('object') == 'object' && module, // present in node.js
      (typeof undefined) == 'function' && undefined // present with an AMD loader
    );
    });
18051
    // xorwow PRNG (from the seedrandom library), bundled as a CommonJS module.
    var xorwow = createCommonjsModule(function (module) {
    // A Javascript implementation of the "xorwow" prng algorithm by
    // George Marsaglia. See http://www.jstatsoft.org/v08/i14/paper

    (function(global, module, define) {

    // Generator state: five 32-bit xorshift words x..v plus a Weyl counter d.
    function XorGen(seed) {
      var me = this, strseed = '';

      // Set up generator function.
      me.next = function() {
        var t = (me.x ^ (me.x >>> 2));
        me.x = me.y; me.y = me.z; me.z = me.w; me.w = me.v;
        return (me.d = (me.d + 362437 | 0)) +
           (me.v = (me.v ^ (me.v << 4)) ^ (t ^ (t << 1))) | 0;
      };

      me.x = 0;
      me.y = 0;
      me.z = 0;
      me.w = 0;
      me.v = 0;

      if (seed === (seed | 0)) {
        // Integer seed.
        me.x = seed;
      } else {
        // String seed.
        strseed += seed;
      }

      // Mix in string seed, then discard an initial batch of 64 values.
      for (var k = 0; k < strseed.length + 64; k++) {
        me.x ^= strseed.charCodeAt(k) | 0;
        if (k == strseed.length) {
          me.d = me.x << 10 ^ me.x >>> 4;
        }
        me.next();
      }
    }

    // Copies generator state from f to t (used for state save/restore).
    function copy(f, t) {
      t.x = f.x;
      t.y = f.y;
      t.z = f.z;
      t.w = f.w;
      t.v = f.v;
      t.d = f.d;
      return t;
    }

    // Public factory: returns a prng in [0, 1) with int32/double/quick/state
    // helpers; opts.state restores a previously saved state.
    function impl(seed, opts) {
      var xg = new XorGen(seed),
          state = opts && opts.state,
          prng = function() { return (xg.next() >>> 0) / 0x100000000; };
      prng.double = function() {
        do {
          var top = xg.next() >>> 11,
              bot = (xg.next() >>> 0) / 0x100000000,
              result = (top + bot) / (1 << 21);
        } while (result === 0);
        return result;
      };
      prng.int32 = xg.next;
      prng.quick = prng;
      if (state) {
        if (typeof(state) == 'object') copy(state, xg);
        prng.state = function() { return copy(xg, {}); };
      }
      return prng;
    }

    if (module && module.exports) {
      module.exports = impl;
    } else if (define && define.amd) {
      define(function() { return impl; });
    } else {
      this.xorwow = impl;
    }

    })(
      commonjsGlobal,
      ('object') == 'object' && module, // present in node.js
      (typeof undefined) == 'function' && undefined // present with an AMD loader
    );
    });
18138
18139 var xorshift7 = createCommonjsModule(function (module) {
18140 // A Javascript implementaion of the "xorshift7" algorithm by
18141 // François Panneton and Pierre L'ecuyer:
18142 // "On the Xorgshift Random Number Generators"
18143 // http://saluc.engr.uconn.edu/refs/crypto/rng/panneton05onthexorshift.pdf
18144
18145 (function(global, module, define) {
18146
// xorshift7: 256-bit state held as an 8-word circular array, advanced by
// xoring seven shifted taps per output word.
function XorGen(seed) {
  var me = this;

  // Set up generator function.
  me.next = function() {
    // Update xor generator: combine seven shifted taps of the state array.
    var X = me.x, i = me.i, t, v, w;
    t = X[i]; t ^= (t >>> 7); v = t ^ (t << 24);
    t = X[(i + 1) & 7]; v ^= t ^ (t >>> 10);
    t = X[(i + 3) & 7]; v ^= t ^ (t >>> 3);
    t = X[(i + 4) & 7]; v ^= t ^ (t << 7);
    t = X[(i + 7) & 7]; t = t ^ (t << 13); v ^= t ^ (t << 9);
    X[i] = v;
    me.i = (i + 1) & 7;
    return v;
  };

  // Seed the 8-word state array from an integer or string seed.
  function init(me, seed) {
    var j, w, X = [];

    if (seed === (seed | 0)) {
      // Seed state array using a 32-bit integer.
      w = X[0] = seed;
    } else {
      // Seed state using a string.
      seed = '' + seed;
      for (j = 0; j < seed.length; ++j) {
        X[j & 7] = (X[j & 7] << 15) ^
            (seed.charCodeAt(j) + X[(j + 1) & 7] << 13);
      }
    }
    // Enforce an array length of 8, not all zeroes.
    while (X.length < 8) X.push(0);
    for (j = 0; j < 8 && X[j] === 0; ++j);
    if (j == 8) w = X[7] = -1; else w = X[j];

    me.x = X;
    me.i = 0;

    // Discard an initial 256 values.
    for (j = 256; j > 0; --j) {
      me.next();
    }
  }

  init(me, seed);
}
18194
// Duplicate xorshift7 state onto t: cursor i plus a copy of the 8-word array.
function copy(f, t) {
  t.i = f.i;
  t.x = f.x.slice();
  return t;
}
18200
// Public prng facade for xorshift7; a null seed defaults to the current time.
// opts.state restores saved state; prng.state() snapshots it.
function impl(seed, opts) {
  if (seed == null) seed = +(new Date);
  var xg = new XorGen(seed),
      state = opts && opts.state,
      // Uniform float in [0, 1) from the high 32 bits.
      prng = function() { return (xg.next() >>> 0) / 0x100000000; };
  prng.double = function() {
    // Combine two outputs for a full 53-bit mantissa; reject exact zero.
    do {
      var top = xg.next() >>> 11,
          bot = (xg.next() >>> 0) / 0x100000000,
          result = (top + bot) / (1 << 21);
    } while (result === 0);
    return result;
  };
  prng.int32 = xg.next;
  prng.quick = prng;
  if (state) {
    // A saved state object carries the `x` array (see copy()).
    if (state.x) copy(state, xg);
    prng.state = function() { return copy(xg, {}); };
  }
  return prng;
}
18222
18223 if (module && module.exports) {
18224 module.exports = impl;
18225 } else if (define && define.amd) {
18226 define(function() { return impl; });
18227 } else {
18228 this.xorshift7 = impl;
18229 }
18230
18231 })(
18232 commonjsGlobal,
18233 ('object') == 'object' && module, // present in node.js
18234 (typeof undefined) == 'function' && undefined // present with an AMD loader
18235 );
18236 });
18237
18238 var xor4096 = createCommonjsModule(function (module) {
18239 // A Javascript implementaion of Richard Brent's Xorgens xor4096 algorithm.
18240 //
18241 // This fast non-cryptographic random number generator is designed for
18242 // use in Monte-Carlo algorithms. It combines a long-period xorshift
18243 // generator with a Weyl generator, and it passes all common batteries
18244 // of stasticial tests for randomness while consuming only a few nanoseconds
18245 // for each prng generated. For background on the generator, see Brent's
18246 // paper: "Some long-period random number generators using shifts and xors."
18247 // http://arxiv.org/pdf/1004.3115v1.pdf
18248 //
18249 // Usage:
18250 //
18251 // var xor4096 = require('xor4096');
18252 // random = xor4096(1); // Seed with int32 or string.
18253 // assert.equal(random(), 0.1520436450538547); // (0, 1) range, 53 bits.
18254 // assert.equal(random.int32(), 1806534897); // signed int32, 32 bits.
18255 //
18256 // For nonzero numeric keys, this impelementation provides a sequence
18257 // identical to that by Brent's xorgens 3 implementaion in C. This
18258 // implementation also provides for initalizing the generator with
18259 // string seeds, or for saving and restoring the state of the generator.
18260 //
18261 // On Chrome, this prng benchmarks about 2.1 times slower than
18262 // Javascript's built-in Math.random().
18263
18264 (function(global, module, define) {
18265
// xor4096 (Brent): a 4096-bit xorshift over a 128-word circular array,
// combined with a 32-bit Weyl generator (w).
function XorGen(seed) {
  var me = this;

  // Set up generator function.
  me.next = function() {
    var w = me.w,
        X = me.X, i = me.i, t, v;
    // Update Weyl generator.
    me.w = w = (w + 0x61c88647) | 0;
    // Update xor generator.
    v = X[(i + 34) & 127];
    t = X[i = ((i + 1) & 127)];
    v ^= v << 13;
    t ^= t << 17;
    v ^= v >>> 15;
    t ^= t >>> 12;
    // Update Xor generator array state.
    v = X[i] = v ^ t;
    me.i = i;
    // Result is the combination.
    return (v + (w ^ (w >>> 16))) | 0;
  };

  // Seed the 128-word circular array and initial Weyl word.
  function init(me, seed) {
    var t, v, i, j, w, X = [], limit = 128;
    if (seed === (seed | 0)) {
      // Numeric seeds initialize v, which is used to generate X.
      v = seed;
      seed = null;
    } else {
      // String seeds are mixed into v and X one character at a time.
      seed = seed + '\0';
      v = 0;
      limit = Math.max(limit, seed.length);
    }
    // Initialize circular array and weyl value.
    for (i = 0, j = -32; j < limit; ++j) {
      // Put the unicode characters into the array, and shuffle them.
      if (seed) v ^= seed.charCodeAt((j + 32) % seed.length);
      // After 32 shuffles, take v as the starting w value.
      if (j === 0) w = v;
      v ^= v << 10;
      v ^= v >>> 15;
      v ^= v << 4;
      v ^= v >>> 13;
      if (j >= 0) {
        w = (w + 0x61c88647) | 0;     // Weyl.
        t = (X[j & 127] ^= (v + w));  // Combine xor and weyl to init array.
        i = (0 == t) ? i + 1 : 0;     // Count zeroes.
      }
    }
    // We have detected all zeroes; make the key nonzero.
    if (i >= 128) {
      X[(seed && seed.length || 0) & 127] = -1;
    }
    // Run the generator 512 times to further mix the state before using it.
    // Factoring this as a function slows the main generator, so it is just
    // unrolled here. The weyl generator is not advanced while warming up.
    i = 127;
    for (j = 4 * 128; j > 0; --j) {
      v = X[(i + 34) & 127];
      t = X[i = ((i + 1) & 127)];
      v ^= v << 13;
      t ^= t << 17;
      v ^= v >>> 15;
      t ^= t >>> 12;
      X[i] = v ^ t;
    }
    // Storing state as object members is faster than using closure variables.
    me.w = w;
    me.X = X;
    me.i = i;
  }

  init(me, seed);
}
18342
// Snapshot xor4096 state onto t: Weyl word w, cursor i, and a copy of X.
function copy(f, t) {
  t.w = f.w;
  t.i = f.i;
  t.X = f.X.slice();
  return t;
}
18349
// Public prng facade for xor4096; a null seed defaults to the current time.
// opts.state restores saved state; prng.state() snapshots it.
function impl(seed, opts) {
  if (seed == null) seed = +(new Date);
  var xg = new XorGen(seed),
      state = opts && opts.state,
      // Uniform float in [0, 1) from the high 32 bits.
      prng = function() { return (xg.next() >>> 0) / 0x100000000; };
  prng.double = function() {
    // Combine two outputs for a full 53-bit mantissa; reject exact zero.
    do {
      var top = xg.next() >>> 11,
          bot = (xg.next() >>> 0) / 0x100000000,
          result = (top + bot) / (1 << 21);
    } while (result === 0);
    return result;
  };
  prng.int32 = xg.next;
  prng.quick = prng;
  if (state) {
    // A saved state object carries the `X` array (see copy()).
    if (state.X) copy(state, xg);
    prng.state = function() { return copy(xg, {}); };
  }
  return prng;
}
18371
18372 if (module && module.exports) {
18373 module.exports = impl;
18374 } else if (define && define.amd) {
18375 define(function() { return impl; });
18376 } else {
18377 this.xor4096 = impl;
18378 }
18379
18380 })(
18381 commonjsGlobal, // window object or global
18382 ('object') == 'object' && module, // present in node.js
18383 (typeof undefined) == 'function' && undefined // present with an AMD loader
18384 );
18385 });
18386
18387 var tychei = createCommonjsModule(function (module) {
18388 // A Javascript implementaion of the "Tyche-i" prng algorithm by
18389 // Samuel Neves and Filipe Araujo.
18390 // See https://eden.dei.uc.pt/~sneves/pubs/2011-snfa2.pdf
18391
18392 (function(global, module, define) {
18393
// Tyche-i: four 32-bit state words (a, b, c, d) advanced by an inverted
// quarter-round derived from ChaCha.
function XorGen(seed) {
  var me = this, strseed = '';

  // Set up generator function.
  me.next = function() {
    var b = me.b, c = me.c, d = me.d, a = me.a;
    b = (b << 25) ^ (b >>> 7) ^ c;
    c = (c - d) | 0;
    d = (d << 24) ^ (d >>> 8) ^ a;
    a = (a - b) | 0;
    me.b = b = (b << 20) ^ (b >>> 12) ^ c;
    me.c = c = (c - d) | 0;
    me.d = (d << 16) ^ (c >>> 16) ^ a;
    return me.a = (a - b) | 0;
  };

  /* The following is non-inverted tyche, which has better internal
   * bit diffusion, but which is about 25% slower than tyche-i in JS.
  me.next = function() {
    var a = me.a, b = me.b, c = me.c, d = me.d;
    a = (me.a + me.b | 0) >>> 0;
    d = me.d ^ a; d = d << 16 ^ d >>> 16;
    c = me.c + d | 0;
    b = me.b ^ c; b = b << 12 ^ d >>> 20;
    me.a = a = a + b | 0;
    d = d ^ a; me.d = d = d << 8 ^ d >>> 24;
    me.c = c = c + d | 0;
    b = b ^ c;
    return me.b = (b << 7 ^ b >>> 25);
  }
  */

  // Initial state constants.
  me.a = 0;
  me.b = 0;
  me.c = 2654435769 | 0;
  me.d = 1367130551;

  if (seed === Math.floor(seed)) {
    // Integer seed: high 32 bits into a, low 32 bits into b.
    me.a = (seed / 0x100000000) | 0;
    me.b = seed | 0;
  } else {
    // String seed.
    strseed += seed;
  }

  // Mix in string seed, then discard an initial batch of 20 values.
  for (var k = 0; k < strseed.length + 20; k++) {
    me.b ^= strseed.charCodeAt(k) | 0;
    me.next();
  }
}
18446
// Transfer the four 32-bit Tyche-i state words from f onto t.
function copy(f, t) {
  ['a', 'b', 'c', 'd'].forEach(function (k) {
    t[k] = f[k];
  });
  return t;
}
18454
// Public prng facade for Tyche-i.
// opts.state (an object) restores saved state; prng.state() snapshots it.
function impl(seed, opts) {
  var xg = new XorGen(seed),
      state = opts && opts.state,
      // Uniform float in [0, 1) from the high 32 bits.
      prng = function() { return (xg.next() >>> 0) / 0x100000000; };
  prng.double = function() {
    // Combine two outputs for a full 53-bit mantissa; reject exact zero.
    do {
      var top = xg.next() >>> 11,
          bot = (xg.next() >>> 0) / 0x100000000,
          result = (top + bot) / (1 << 21);
    } while (result === 0);
    return result;
  };
  prng.int32 = xg.next;
  prng.quick = prng;
  if (state) {
    if (typeof(state) == 'object') copy(state, xg);
    prng.state = function() { return copy(xg, {}); };
  }
  return prng;
}
18475
18476 if (module && module.exports) {
18477 module.exports = impl;
18478 } else if (define && define.amd) {
18479 define(function() { return impl; });
18480 } else {
18481 this.tychei = impl;
18482 }
18483
18484 })(
18485 commonjsGlobal,
18486 ('object') == 'object' && module, // present in node.js
18487 (typeof undefined) == 'function' && undefined // present with an AMD loader
18488 );
18489 });
18490
18491 var seedrandom = createCommonjsModule(function (module) {
18492 /*
18493 Copyright 2014 David Bau.
18494
18495 Permission is hereby granted, free of charge, to any person obtaining
18496 a copy of this software and associated documentation files (the
18497 "Software"), to deal in the Software without restriction, including
18498 without limitation the rights to use, copy, modify, merge, publish,
18499 distribute, sublicense, and/or sell copies of the Software, and to
18500 permit persons to whom the Software is furnished to do so, subject to
18501 the following conditions:
18502
18503 The above copyright notice and this permission notice shall be
18504 included in all copies or substantial portions of the Software.
18505
18506 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18507 EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18508 MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
18509 IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
18510 CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
18511 TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
18512 SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
18513
18514 */
18515
18516 (function (pool, math) {
18517 //
18518 // The following constants are related to IEEE 754 limits.
18519 //
18520 var global = this,
18521 width = 256, // each RC4 output is 0 <= x < 256
18522 chunks = 6, // at least six RC4 outputs for each double
18523 digits = 52, // there are 52 significant digits in a double
18524 rngname = 'random', // rngname: name for Math.random and Math.seedrandom
18525 startdenom = math.pow(width, chunks),
18526 significance = math.pow(2, digits),
18527 overflow = significance * 2,
18528 mask = width - 1,
18529 nodecrypto; // node.js crypto module, initialized at the bottom.
18530
18531 //
18532 // seedrandom()
18533 // This is the seedrandom function described above.
18534 //
// seedrandom(seed, options, callback): builds an ARC4-backed PRNG.
// With options.entropy, the seed is mixed with the accumulated entropy
// pool; with a null seed, the generator autoseeds from the environment.
function seedrandom(seed, options, callback) {
  var key = [];
  options = (options == true) ? { entropy: true } : (options || {});

  // Flatten the seed string or build one from local entropy if needed.
  var shortseed = mixkey(flatten(
    options.entropy ? [seed, tostring(pool)] :
    (seed == null) ? autoseed() : seed, 3), key);

  // Use the seed to initialize an ARC4 generator.
  var arc4 = new ARC4(key);

  // This function returns a random double in [0, 1) that contains
  // randomness in every bit of the mantissa of the IEEE 754 value.
  var prng = function() {
    var n = arc4.g(chunks),             // Start with a numerator n < 2 ^ 48
        d = startdenom,                 //   and denominator d = 2 ^ 48.
        x = 0;                          //   and no 'extra last byte'.
    while (n < significance) {          // Fill up all significant digits by
      n = (n + x) * width;              //   shifting numerator and
      d *= width;                       //   denominator and generating a
      x = arc4.g(1);                    //   new least-significant-byte.
    }
    while (n >= overflow) {             // To avoid rounding up, before adding
      n /= 2;                           //   last byte, shift everything
      d /= 2;                           //   right using integer math until
      x >>>= 1;                         //   we have exactly the desired bits.
    }
    return (n + x) / d;                 // Form the number within [0, 1).
  };

  prng.int32 = function() { return arc4.g(4) | 0; };
  prng.quick = function() { return arc4.g(4) / 0x100000000; };
  prng.double = prng;

  // Mix the randomness into accumulated entropy.
  mixkey(tostring(arc4.S), pool);

  // Calling convention: what to return as a function of prng, seed, is_math.
  return (options.pass || callback ||
      function(prng, seed, is_math_call, state) {
        if (state) {
          // Load the arc4 state from the given state if it has an S array.
          if (state.S) { copy(state, arc4); }
          // Only provide the .state method if requested via options.state.
          prng.state = function() { return copy(arc4, {}); };
        }

        // If called as a method of Math (Math.seedrandom()), mutate
        // Math.random because that is how seedrandom.js has worked since v1.0.
        if (is_math_call) { math[rngname] = prng; return seed; }

        // Otherwise, it is a newer calling convention, so return the
        // prng directly.
        else return prng;
      })(
  prng,
  shortseed,
  'global' in options ? options.global : (this == math),
  options.state);
}
18596 math['seed' + rngname] = seedrandom;
18597
18598 //
18599 // ARC4
18600 //
18601 // An ARC4 implementation. The constructor takes a key in the form of
18602 // an array of at most (width) integers that should be 0 <= x < (width).
18603 //
18604 // The g(count) method returns a pseudorandom integer that concatenates
18605 // the next (count) outputs from ARC4. Its return value is a number x
18606 // that is in the range 0 <= x < (width ^ count).
18607 //
// An ARC4 (RC4) stream generator. The constructor runs the standard key
// schedule over `key`, then immediately discards the first `width` (256)
// outputs (RC4-drop[256]) before the generator is used.
function ARC4(key) {
  var t, keylen = key.length,
      me = this, i = 0, j = me.i = me.j = 0, s = me.S = [];

  // The empty key [] is treated as [0].
  if (!keylen) { key = [keylen++]; }

  // Set up S using the standard key scheduling algorithm.
  while (i < width) {
    s[i] = i++;
  }
  for (i = 0; i < width; i++) {
    s[i] = s[j = mask & (j + key[i % keylen] + (t = s[i]))];
    s[j] = t;
  }

  // The "g" method returns the next (count) outputs as one number.
  (me.g = function(count) {
    // Using instance members instead of closure state nearly doubles speed.
    var t, r = 0,
        i = me.i, j = me.j, s = me.S;
    while (count--) {
      t = s[i = mask & (i + 1)];
      r = r * width + s[mask & ((s[i] = s[j = mask & (j + t)]) + (s[j] = t))];
    }
    me.i = i; me.j = j;
    return r;
    // For robust unpredictability, the function call below automatically
    // discards an initial batch of values. This is called RC4-drop[256].
    // See http://google.com/search?q=rsa+fluhrer+response&btnI
  })(width);
}
18640
18641 //
18642 // copy()
18643 // Copies internal state of ARC4 to or from a plain object.
18644 //
// Copies internal ARC4 state (cursors i, j and the S array) to or from a
// plain object. The S array is duplicated, never shared.
function copy(f, t) {
  t.j = f.j;
  t.i = f.i;
  t.S = f.S.slice();
  return t;
}
18651
18652 //
18653 // flatten()
18654 // Converts an object tree to nested arrays of strings.
18655 //
// Converts an object tree to nested arrays of strings, recursing at most
// `depth` levels. Non-string leaves get a trailing '\0' sentinel so they
// cannot collide with genuine string seeds.
function flatten(obj, depth) {
  var typ = typeof obj;
  var result = [];
  if (depth && typ == 'object') {
    for (var prop in obj) {
      try {
        result.push(flatten(obj[prop], depth - 1));
      } catch (e) {}
    }
  }
  if (result.length) {
    return result;
  }
  return typ == 'string' ? obj : obj + '\0';
}
18665
18666 //
18667 // mixkey()
18668 // Mixes a string seed into a key that is an array of integers, and
18669 // returns a shortened string seed that is equivalent to the result key.
18670 //
// Mixes a string seed into `key` (an array of integers, indexed modulo
// mask + 1) and returns a shortened string seed equivalent to the result.
function mixkey(seed, key) {
  var stringseed = seed + '', smear = 0;
  for (var j = 0; j < stringseed.length; j++) {
    smear ^= key[mask & j] * 19;
    key[mask & j] = mask & (smear + stringseed.charCodeAt(j));
  }
  return tostring(key);
}
18679
18680 //
18681 // autoseed()
18682 // Returns an object for autoseeding, using window.crypto and Node crypto
18683 // module if available.
18684 //
// Returns an object for autoseeding, using window.crypto and the Node
// crypto module if available; falls back to weak environmental entropy
// (time, navigator, screen, pool) when no crypto source exists.
function autoseed() {
  try {
    var out;
    if (nodecrypto && (out = nodecrypto.randomBytes)) {
      // The use of 'out' to remember randomBytes makes tight minified code.
      out = out(width);
    } else {
      out = new Uint8Array(width);
      (global.crypto || global.msCrypto).getRandomValues(out);
    }
    return tostring(out);
  } catch (e) {
    // No crypto available (or it threw): gather weak entropy instead.
    var browser = global.navigator,
        plugins = browser && browser.plugins;
    return [+new Date, global, plugins, global.screen, tostring(pool)];
  }
}
18702
18703 //
18704 // tostring()
18705 // Converts an array of charcodes to a string
18706 //
// Converts an array (or typed array) of charcodes to a string.
function tostring(a) {
  return String.fromCharCode(...a);
}
18710
18711 //
18712 // When seedrandom.js is loaded, we immediately mix a few bits
18713 // from the built-in RNG into the entropy pool. Because we do
18714 // not want to interfere with deterministic PRNG state later,
18715 // seedrandom will not call math.random on its own again after
18716 // initialization.
18717 //
18718 mixkey(math.random(), pool);
18719
18720 //
18721 // Nodejs and AMD support: export the implementation as a module using
18722 // either convention.
18723 //
18724 if (('object') == 'object' && module.exports) {
18725 module.exports = seedrandom;
18726 // When in node.js, try using crypto package for autoseeding.
18727 try {
18728 nodecrypto = require('crypto');
18729 } catch (ex) {}
18730 } else if ((typeof undefined) == 'function' && undefined.amd) {
18731 undefined(function() { return seedrandom; });
18732 }
18733
18734 // End anonymous scope, and pass initial values.
18735 })(
18736 [], // pool: entropy pool starts empty
18737 Math // math: package containing random, pow, and seedrandom
18738 );
18739 });
18740
18741 // A library of seedable RNGs implemented in Javascript.
18742 //
18743 // Usage:
18744 //
18745 // var seedrandom = require('seedrandom');
18746 // var random = seedrandom(1); // or any seed.
18747 // var x = random(); // 0 <= x < 1. Every bit is random.
18748 // var x = random.quick(); // 0 <= x < 1. 32 bits of randomness.
18749
18750 // alea, a 53-bit multiply-with-carry generator by Johannes Baagøe.
18751 // Period: ~2^116
18752 // Reported to pass all BigCrush tests.
18753
18754
18755 // xor128, a pure xor-shift generator by George Marsaglia.
18756 // Period: 2^128-1.
18757 // Reported to fail: MatrixRank and LinearComp.
18758
18759
18760 // xorwow, George Marsaglia's 160-bit xor-shift combined plus weyl.
18761 // Period: 2^192-2^32
18762 // Reported to fail: CollisionOver, SimpPoker, and LinearComp.
18763
18764
18765 // xorshift7, by François Panneton and Pierre L'ecuyer, takes
18766 // a different approach: it adds robustness by allowing more shifts
18767 // than Marsaglia's original three. It is a 7-shift generator
18768 // with 256 bits, that passes BigCrush with no systmatic failures.
18769 // Period 2^256-1.
18770 // No systematic BigCrush failures reported.
18771
18772
18773 // xor4096, by Richard Brent, is a 4096-bit xor-shift with a
18774 // very long period that also adds a Weyl generator. It also passes
18775 // BigCrush with no systematic failures. Its long period may
18776 // be useful if you have many generators and need to avoid
18777 // collisions.
18778 // Period: 2^4128-2^32.
18779 // No systematic BigCrush failures reported.
18780
18781
18782 // Tyche-i, by Samuel Neves and Filipe Araujo, is a bit-shifting random
18783 // number generator derived from ChaCha, a modern stream cipher.
18784 // https://eden.dei.uc.pt/~sneves/pubs/2011-snfa2.pdf
18785 // Period: ~2^127
18786 // No systematic BigCrush failures reported.
18787
18788
18789 // The original ARC4-based prng included in this library.
18790 // Period: ~2^1600
18791
18792
// Attach each seedable generator to the main seedrandom export, mirroring
// the seedrandom package's public API (seedrandom.alea, .xor128, ...).
seedrandom.alea = alea;
seedrandom.xor128 = xor128;
seedrandom.xorwow = xorwow;
seedrandom.xorshift7 = xorshift7;
seedrandom.xor4096 = xor4096;
seedrandom.tychei = tychei;

var seedrandom$1 = seedrandom;
// The Alea generator is what the TF.js samplers below use for seeding.
var seedrandom_1 = seedrandom$1.alea;
18802
18803 /**
18804 * @license
18805 * Copyright 2018 Google LLC. All Rights Reserved.
18806 * Licensed under the Apache License, Version 2.0 (the "License");
18807 * you may not use this file except in compliance with the License.
18808 * You may obtain a copy of the License at
18809 *
18810 * http://www.apache.org/licenses/LICENSE-2.0
18811 *
18812 * Unless required by applicable law or agreed to in writing, software
18813 * distributed under the License is distributed on an "AS IS" BASIS,
18814 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18815 * See the License for the specific language governing permissions and
18816 * limitations under the License.
18817 * =============================================================================
18818 */
18819 // https://en.wikipedia.org/wiki/Marsaglia_polar_method
class MPRandGauss {
    /**
     * Seedable Gaussian sampler using the Marsaglia polar method.
     *
     * @param mean Mean of the distribution.
     * @param stdDeviation Standard deviation of the distribution.
     * @param dtype Output dtype; non-float32 samples are rounded.
     * @param truncated If true, resamples values beyond 2 std devs of mean.
     * @param seed Optional seed; defaults to a random seed when null.
     */
    constructor(mean, stdDeviation, dtype, truncated, seed) {
        this.mean = mean;
        this.stdDev = stdDeviation;
        this.dtype = dtype;
        this.nextVal = NaN; // cached second sample from the polar method
        this.truncated = truncated;
        if (this.truncated) {
            this.upper = this.mean + this.stdDev * 2;
            this.lower = this.mean - this.stdDev * 2;
        }
        // Null check (not truthiness) so that an explicit seed of 0 is
        // honored instead of silently replaced by Math.random(); this
        // matches UniformRandom's seed handling.
        const seedValue = seed != null ? seed : Math.random();
        this.random = seedrandom_1(seedValue.toString());
    }
    /** Returns next sample from a Gaussian distribution. */
    nextValue() {
        // Reuse the cached second value from the previous polar draw.
        if (!isNaN(this.nextVal)) {
            const value = this.nextVal;
            this.nextVal = NaN;
            return value;
        }
        let resultX, resultY;
        let isValid = false;
        while (!isValid) {
            let v1, v2, s;
            // Rejection-sample a nonzero point inside the unit circle.
            do {
                v1 = 2 * this.random() - 1;
                v2 = 2 * this.random() - 1;
                s = v1 * v1 + v2 * v2;
            } while (s >= 1 || s === 0);
            const mul = Math.sqrt(-2.0 * Math.log(s) / s);
            resultX = this.mean + this.stdDev * v1 * mul;
            resultY = this.mean + this.stdDev * v2 * mul;
            if (!this.truncated || this.isValidTruncated(resultX)) {
                isValid = true;
            }
        }
        // Cache the second sample only if it also satisfies truncation.
        if (!this.truncated || this.isValidTruncated(resultY)) {
            this.nextVal = this.convertValue(resultY);
        }
        return this.convertValue(resultX);
    }
    /** Handles proper rounding for non-floating-point numbers. */
    convertValue(value) {
        if (this.dtype == null || this.dtype === 'float32') {
            return value;
        }
        return Math.round(value);
    }
    /** Returns true if less than 2-standard-deviations from the mean. */
    isValidTruncated(value) {
        return value <= this.upper && value >= this.lower;
    }
}
18874 // Marsaglia, George, and Wai Wan Tsang. 2000. "A Simple Method for Generating
18875 // Gamma Variables."
class RandGamma {
    /**
     * Seedable gamma sampler (Marsaglia & Tsang squeeze method).
     *
     * @param alpha Shape parameter of the gamma distribution.
     * @param beta Rate parameter; stored internally as the scale 1/beta.
     * @param dtype Output dtype; non-float32 samples are rounded.
     * @param seed Optional seed; defaults to a random seed when null.
     */
    constructor(alpha, beta, dtype, seed) {
        this.alpha = alpha;
        this.beta = 1 / beta; // convert rate to scale parameter
        this.dtype = dtype;
        // Null check (not truthiness) so that an explicit seed of 0 is
        // honored instead of silently replaced by Math.random(); this
        // matches UniformRandom's seed handling.
        const seedValue = seed != null ? seed : Math.random();
        this.randu = seedrandom_1(seedValue.toString());
        this.randn = new MPRandGauss(0, 1, dtype, false, this.randu());
        // For alpha < 1, sample at a boosted shape and rescale by
        // u^(1/alpha) in nextValue().
        if (alpha < 1) {
            this.d = alpha + (2 / 3);
        }
        else {
            this.d = alpha - (1 / 3);
        }
        this.c = 1 / Math.sqrt(9 * this.d);
    }
    /** Returns next sample from a gamma distribution. */
    nextValue() {
        let x2, v0, v1, x, u, v;
        while (true) {
            // Draw x ~ N(0, 1) until v = (1 + c*x)^3 is positive.
            do {
                x = this.randn.nextValue();
                v = 1 + (this.c * x);
            } while (v <= 0);
            v *= v * v;
            x2 = x * x;
            // Squeeze acceptance test (v0) with logarithmic fallback (v1).
            v0 = 1 - (0.331 * x2 * x2);
            v1 = (0.5 * x2) + (this.d * (1 - v + Math.log(v)));
            u = this.randu();
            if (u < v0 || Math.log(u) < v1) {
                break;
            }
        }
        v = (1 / this.beta) * this.d * v;
        if (this.alpha < 1) {
            v *= Math.pow(this.randu(), 1 / this.alpha);
        }
        return this.convertValue(v);
    }
    /** Handles proper rounding for non-floating-point numbers. */
    convertValue(value) {
        if (this.dtype === 'float32') {
            return value;
        }
        return Math.round(value);
    }
}
class UniformRandom {
    /**
     * Seedable uniform sampler over [min, max).
     * Non-float dtypes round each sample to the nearest integer.
     */
    constructor(min = 0, max = 1, dtype, seed) {
        /** Handles proper rounding for non floating point numbers. */
        this.canReturnFloat = () => (this.dtype == null || this.dtype === 'float32');
        this.min = min;
        this.range = max - min;
        this.dtype = dtype;
        if (seed == null) {
            seed = Math.random();
        }
        if (typeof seed === 'number') {
            seed = seed.toString();
        }
        // An integer dtype over a span <= 1 could only ever produce a
        // single rounded value, so reject it up front.
        if (!this.canReturnFloat() && this.range <= 1) {
            throw new Error(`The difference between ${min} - ${max} <= 1 and dtype is not float`);
        }
        this.random = seedrandom_1(seed);
    }
    convertValue(value) {
        return this.canReturnFloat() ? value : Math.round(value);
    }
    nextValue() {
        const sample = this.min + this.range * this.random();
        return this.convertValue(sample);
    }
}
// Throws unless `values` pass the Jarque-Bera normality test.
// https://en.wikipedia.org/wiki/Jarque%E2%80%93Bera_test
function jarqueBeraNormalityTest(values) {
    // Critical value of chi-square with 2 degrees of freedom at 0.95:
    // http://www.itl.nist.gov/div898/handbook/eda/section3/eda3674.htm
    const CHI_SQUARE_2DEG = 5.991;
    const n = values.length;
    const skew = skewness(values);
    const kurt = kurtosis(values);
    const jb = n / 6 * (Math.pow(skew, 2) + 0.25 * Math.pow(kurt - 3, 2));
    if (jb > CHI_SQUARE_2DEG) {
        throw new Error(`Invalid p-value for JB: ${jb}`);
    }
}
// Asserts that `actual` has the expected mean and standard deviation,
// within `epsilon` (defaulting to the backend-dependent test tolerance).
function expectArrayInMeanStdRange(actual, expectedMean, expectedStdDev, epsilon) {
    const eps = epsilon == null ? testEpsilon() : epsilon;
    const actualMean = mean$1(actual);
    expectNumbersClose(actualMean, expectedMean, eps);
    expectNumbersClose(standardDeviation(actual, actualMean), expectedStdDev, eps);
}
// Arithmetic mean of a numeric array (NaN for an empty array).
function mean$1(values) {
    let total = 0;
    for (const v of values) {
        total += v;
    }
    return total / values.length;
}
// Population standard deviation of `values` about a precomputed `mean`.
function standardDeviation(values, mean) {
    let sumOfSquares = 0;
    for (const v of values) {
        const diff = v - mean;
        sumOfSquares += diff * diff;
    }
    return Math.sqrt(sumOfSquares / values.length);
}
// Sample kurtosis of `values` (fourth moment over squared second moment).
// https://en.wikipedia.org/wiki/Kurtosis
function kurtosis(values) {
    const n = values.length;
    const valuesMean = mean$1(values);
    let sum2 = 0;
    let sum4 = 0;
    for (const value of values) {
        const v = value - valuesMean;
        sum2 += Math.pow(v, 2);
        sum4 += Math.pow(v, 4);
    }
    return (1 / n) * sum4 / Math.pow((1 / n) * sum2, 2);
}
// Sample skewness of `values` (third moment over the 3/2 power of the
// Bessel-corrected second moment).
// https://en.wikipedia.org/wiki/Skewness
function skewness(values) {
    const n = values.length;
    const valuesMean = mean$1(values);
    let sum2 = 0;
    let sum3 = 0;
    for (const value of values) {
        const v = value - valuesMean;
        sum2 += Math.pow(v, 2);
        sum3 += Math.pow(v, 3);
    }
    return (1 / n) * sum3 / Math.pow((1 / (n - 1)) * sum2, 3 / 2);
}
19013
19014 /**
19015 * @license
19016 * Copyright 2020 Google LLC. All Rights Reserved.
19017 * Licensed under the Apache License, Version 2.0 (the "License");
19018 * you may not use this file except in compliance with the License.
19019 * You may obtain a copy of the License at
19020 *
19021 * http://www.apache.org/licenses/LICENSE-2.0
19022 *
19023 * Unless required by applicable law or agreed to in writing, software
19024 * distributed under the License is distributed on an "AS IS" BASIS,
19025 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19026 * See the License for the specific language governing permissions and
19027 * limitations under the License.
19028 * =============================================================================
19029 */
19030 /**
19031 * Creates a `tf.Tensor` with values sampled from a gamma distribution.
19032 *
19033 * ```js
19034 * tf.randomGamma([2, 2], 1).print();
19035 * ```
19036 *
19037 * @param shape An array of integers defining the output tensor shape.
19038 * @param alpha The shape parameter of the gamma distribution.
19039 * @param beta The inverse scale parameter of the gamma distribution. Defaults
19040 * to 1.
19041 * @param dtype The data type of the output. Defaults to float32.
19042 * @param seed The seed for the random number generator.
19043 *
19044 * @doc {heading: 'Tensors', subheading: 'Random'}
19045 */
function randomGamma_(shape, alpha, beta = 1, dtype = 'float32', seed) {
    // Defaults above cover `undefined`; these guards also catch an
    // explicit null, which bypasses parameter defaults.
    if (beta == null) {
        beta = 1;
    }
    if (dtype == null) {
        dtype = 'float32';
    }
    if (dtype !== 'float32' && dtype !== 'int32') {
        throw new Error(`Unsupported data type ${dtype}`);
    }
    const rgamma = new RandGamma(alpha, beta, dtype, seed);
    const out = buffer(shape, dtype);
    const vals = out.values;
    // Fill the buffer with independent gamma draws.
    for (let i = 0; i < vals.length; i++) {
        vals[i] = rgamma.nextValue();
    }
    return out.toTensor();
}
const randomGamma = op({ randomGamma_ });
19064
19065 /**
19066 * @license
19067 * Copyright 2020 Google LLC. All Rights Reserved.
19068 * Licensed under the Apache License, Version 2.0 (the "License");
19069 * you may not use this file except in compliance with the License.
19070 * You may obtain a copy of the License at
19071 *
19072 * http://www.apache.org/licenses/LICENSE-2.0
19073 *
19074 * Unless required by applicable law or agreed to in writing, software
19075 * distributed under the License is distributed on an "AS IS" BASIS,
19076 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19077 * See the License for the specific language governing permissions and
19078 * limitations under the License.
19079 * =============================================================================
19080 */
19081 /**
19082 * Creates a `tf.Tensor` with values sampled from a normal distribution.
19083 *
19084 * ```js
19085 * tf.randomNormal([2, 2]).print();
19086 * ```
19087 *
19088 * @param shape An array of integers defining the output tensor shape.
19089 * @param mean The mean of the normal distribution.
19090 * @param stdDev The standard deviation of the normal distribution.
19091 * @param dtype The data type of the output.
19092 * @param seed The seed for the random number generator.
19093 *
19094 * @doc {heading: 'Tensors', subheading: 'Random'}
19095 */
19096 function randomNormal_(shape, mean = 0, stdDev = 1, dtype, seed) {
19097 if (dtype != null && dtype === 'bool') {
19098 throw new Error(`Unsupported data type ${dtype}`);
19099 }
19100 const randGauss = new MPRandGauss(mean, stdDev, dtype, false /* truncated */, seed);
19101 const res = buffer(shape, dtype);
19102 for (let i = 0; i < res.values.length; i++) {
19103 res.values[i] = randGauss.nextValue();
19104 }
19105 return res.toTensor();
19106 }
19107 const randomNormal = op({ randomNormal_ });
19108
19109 /**
19110 * @license
19111 * Copyright 2020 Google LLC. All Rights Reserved.
19112 * Licensed under the Apache License, Version 2.0 (the "License");
19113 * you may not use this file except in compliance with the License.
19114 * You may obtain a copy of the License at
19115 *
19116 * http://www.apache.org/licenses/LICENSE-2.0
19117 *
19118 * Unless required by applicable law or agreed to in writing, software
19119 * distributed under the License is distributed on an "AS IS" BASIS,
19120 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19121 * See the License for the specific language governing permissions and
19122 * limitations under the License.
19123 * =============================================================================
19124 */
19125 /**
19126 * Creates a `tf.Tensor` with values sampled from a uniform distribution.
19127 *
19128 * The generated values follow a uniform distribution in the range [minval,
19129 * maxval). The lower bound minval is included in the range, while the upper
19130 * bound maxval is excluded.
19131 *
19132 * ```js
19133 * tf.randomUniform([2, 2]).print();
19134 * ```
19135 *
19136 * @param shape An array of integers defining the output tensor shape.
19137 * @param minval The lower bound on the range of random values to generate.
19138 * Defaults to 0.
19139 * @param maxval The upper bound on the range of random values to generate.
19140 * Defaults to 1.
19141 * @param dtype The data type of the output tensor. Defaults to 'float32'.
19142 *
19143 * @doc {heading: 'Tensors', subheading: 'Random'}
19144 */
19145 function randomUniform_(shape, minval = 0, maxval = 1, dtype = 'float32', seed) {
19146 const res = buffer(shape, dtype);
19147 const random = new UniformRandom(minval, maxval, null, seed);
19148 for (let i = 0; i < res.values.length; i++) {
19149 res.values[i] = random.nextValue();
19150 }
19151 return res.toTensor();
19152 }
19153 const randomUniform = op({ randomUniform_ });
19154
19155 /**
19156 * @license
19157 * Copyright 2018 Google LLC. All Rights Reserved.
19158 * Licensed under the Apache License, Version 2.0 (the "License");
19159 * you may not use this file except in compliance with the License.
19160 * You may obtain a copy of the License at
19161 *
19162 * http://www.apache.org/licenses/LICENSE-2.0
19163 *
19164 * Unless required by applicable law or agreed to in writing, software
19165 * distributed under the License is distributed on an "AS IS" BASIS,
19166 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19167 * See the License for the specific language governing permissions and
19168 * limitations under the License.
19169 * =============================================================================
19170 */
19171 /**
19172 * Creates a new `tf.Tensor1D` filled with the numbers in the range provided.
19173 *
19174 * The tensor is a is half-open interval meaning it includes start, but
19175 * excludes stop. Decrementing ranges and negative step values are also
     * supported.
19177 *
19178 *
19179 * ```js
19180 * tf.range(0, 9, 2).print();
19181 * ```
19182 *
19183 * @param start An integer start value
19184 * @param stop An integer stop value
19185 * @param step An integer increment (will default to 1 or -1)
19186 * @param dtype The data type of the output tensor. Defaults to 'float32'.
19187 *
19188 * @doc {heading: 'Tensors', subheading: 'Creation'}
19189 */
19190 function range(start, stop, step = 1, dtype = 'float32') {
19191 if (step === 0) {
19192 throw new Error('Cannot have a step of zero');
19193 }
19194 const attrs = { start, stop, step, dtype };
19195 return ENGINE.runKernel(Range, {} /* inputs */, attrs);
19196 }
19197
19198 /**
19199 * @license
19200 * Copyright 2018 Google LLC. All Rights Reserved.
19201 * Licensed under the Apache License, Version 2.0 (the "License");
19202 * you may not use this file except in compliance with the License.
19203 * You may obtain a copy of the License at
19204 *
19205 * http://www.apache.org/licenses/LICENSE-2.0
19206 *
19207 * Unless required by applicable law or agreed to in writing, software
19208 * distributed under the License is distributed on an "AS IS" BASIS,
19209 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19210 * See the License for the specific language governing permissions and
19211 * limitations under the License.
19212 * =============================================================================
19213 */
19214 /**
19215 * Computes reciprocal of x element-wise: `1 / x`
19216 *
19217 * ```js
19218 * const x = tf.tensor1d([0, 1, 2]);
19219 *
19220 * x.reciprocal().print(); // or tf.reciprocal(x)
19221 * ```
19222 * @param x The input tensor.
19223 *
19224 * @doc {heading: 'Operations', subheading: 'Basic math'}
19225 */
19226 function reciprocal_(x) {
19227 const $x = convertToTensor(x, 'x', 'reciprocal');
19228 const inputs = { x: $x };
19229 return ENGINE.runKernel(Reciprocal, inputs);
19230 }
19231 const reciprocal = op({ reciprocal_ });
19232
19233 /**
19234 * @license
19235 * Copyright 2020 Google LLC. All Rights Reserved.
19236 * Licensed under the Apache License, Version 2.0 (the "License");
19237 * you may not use this file except in compliance with the License.
19238 * You may obtain a copy of the License at
19239 *
19240 * http://www.apache.org/licenses/LICENSE-2.0
19241 *
19242 * Unless required by applicable law or agreed to in writing, software
19243 * distributed under the License is distributed on an "AS IS" BASIS,
19244 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19245 * See the License for the specific language governing permissions and
19246 * limitations under the License.
19247 * =============================================================================
19248 */
19249 /**
19250 * Computes rectified linear element-wise: `max(x, 0)`.
19251 *
19252 * ```js
19253 * const x = tf.tensor1d([-1, 2, -3, 4]);
19254 *
19255 * x.relu().print(); // or tf.relu(x)
19256 * ```
19257 * @param x The input tensor. If the dtype is `bool`, the output dtype will be
     *     `int32`.
19259 *
19260 * @doc {heading: 'Operations', subheading: 'Basic math'}
19261 */
19262 function relu_(x) {
19263 const $x = convertToTensor(x, 'x', 'relu');
19264 const inputs = { x: $x };
19265 return ENGINE.runKernel(Relu, inputs);
19266 }
19267 const relu = op({ relu_ });
19268
19269 /**
19270 * @license
19271 * Copyright 2020 Google LLC. All Rights Reserved.
19272 * Licensed under the Apache License, Version 2.0 (the "License");
19273 * you may not use this file except in compliance with the License.
19274 * You may obtain a copy of the License at
19275 *
19276 * http://www.apache.org/licenses/LICENSE-2.0
19277 *
19278 * Unless required by applicable law or agreed to in writing, software
19279 * distributed under the License is distributed on an "AS IS" BASIS,
19280 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19281 * See the License for the specific language governing permissions and
19282 * limitations under the License.
19283 * =============================================================================
19284 */
19285 /**
19286 * Computes rectified linear 6 element-wise: `min(max(x, 0), 6)`.
19287 *
19288 * ```js
19289 * const x = tf.tensor1d([-1, 2, -3, 8]);
19290 *
19291 * x.relu6().print(); // or tf.relu6(x)
19292 * ```
19293 * @param x The input tensor. If the dtype is `bool`, the output dtype will be
     *     `int32`.
19295 *
19296 * @doc {heading: 'Operations', subheading: 'Basic math'}
19297 */
19298 function relu6_(x) {
19299 const $x = convertToTensor(x, 'x', 'relu6');
19300 const inputs = { x: $x };
19301 return ENGINE.runKernel(Relu6, inputs);
19302 }
19303 const relu6 = op({ relu6_ });
19304
19305 /**
19306 * @license
19307 * Copyright 2018 Google LLC. All Rights Reserved.
19308 * Licensed under the Apache License, Version 2.0 (the "License");
19309 * you may not use this file except in compliance with the License.
19310 * You may obtain a copy of the License at
19311 *
19312 * http://www.apache.org/licenses/LICENSE-2.0
19313 *
19314 * Unless required by applicable law or agreed to in writing, software
19315 * distributed under the License is distributed on an "AS IS" BASIS,
19316 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19317 * See the License for the specific language governing permissions and
19318 * limitations under the License.
19319 * =============================================================================
19320 */
19321 /**
19322 * Reverses a `tf.Tensor` along a specified axis.
19323 *
19324 * Also available are stricter rank-specific methods that assert that `x` is
19325 * of the given rank:
19326 * - `tf.reverse1d`
19327 * - `tf.reverse2d`
19328 * - `tf.reverse3d`
19329 * - `tf.reverse4d`
19330 *
19331 * Except `tf.reverse1d` (which does not have axis param), all methods have
19332 * same signature as this method.
19333 *
19334 * ```js
19335 * const x = tf.tensor1d([1, 2, 3, 4]);
19336 *
19337 * x.reverse().print();
19338 * ```
19339 *
19340 * ```js
19341 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
19342 *
19343 * const axis = 1;
19344 * x.reverse(axis).print();
19345 * ```
19346 * @param x The input tensor to be reversed.
19347 * @param axis The set of dimensions to reverse. Must be in the
19348 * range [-rank(x), rank(x)). Defaults to all axes.
19349 *
19350 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
19351 */
19352 function reverse_(x, axis) {
19353 const $x = convertToTensor(x, 'x', 'reverse');
19354 const inputs = { x: $x };
19355 const attrs = { dims: axis };
19356 return ENGINE.runKernel(Reverse, inputs, attrs);
19357 }
19358 const reverse = op({ reverse_ });
19359
19360 /**
19361 * @license
19362 * Copyright 2020 Google LLC. All Rights Reserved.
19363 * Licensed under the Apache License, Version 2.0 (the "License");
19364 * you may not use this file except in compliance with the License.
19365 * You may obtain a copy of the License at
19366 *
19367 * http://www.apache.org/licenses/LICENSE-2.0
19368 *
19369 * Unless required by applicable law or agreed to in writing, software
19370 * distributed under the License is distributed on an "AS IS" BASIS,
19371 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19372 * See the License for the specific language governing permissions and
19373 * limitations under the License.
19374 * =============================================================================
19375 */
19376 /**
19377 * Reverses a `tf.Tensor1D`.
19378 *
19379 * @param x The input tensor.
19380 */
19381 function reverse1d_(x) {
19382 const $x = convertToTensor(x, 'x', 'reverse');
19383 assert($x.rank === 1, () => `Error in reverse1D: x must be rank 1 but got rank ${$x.rank}.`);
19384 return reverse($x, 0);
19385 }
19386 const reverse1d = op({ reverse1d_ });
19387
19388 /**
19389 * @license
19390 * Copyright 2020 Google LLC. All Rights Reserved.
19391 * Licensed under the Apache License, Version 2.0 (the "License");
19392 * you may not use this file except in compliance with the License.
19393 * You may obtain a copy of the License at
19394 *
19395 * http://www.apache.org/licenses/LICENSE-2.0
19396 *
19397 * Unless required by applicable law or agreed to in writing, software
19398 * distributed under the License is distributed on an "AS IS" BASIS,
19399 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19400 * See the License for the specific language governing permissions and
19401 * limitations under the License.
19402 * =============================================================================
19403 */
19404 /**
19405 * Reverses a `tf.Tensor2D` along a specified axis.
19406 *
19407 * @param x The input tensor.
19408 * @param axis The set of dimensions to reverse. Must be in the
19409 * range [-rank(x), rank(x)). Defaults to all axes.
19410 */
19411 function reverse2d_(x, axis) {
19412 const $x = convertToTensor(x, 'x', 'reverse');
19413 assert($x.rank === 2, () => `Error in reverse2D: x must be rank 2 but got rank ${$x.rank}.`);
19414 return reverse($x, axis);
19415 }
19416 const reverse2d = op({ reverse2d_ });
19417
19418 /**
19419 * @license
19420 * Copyright 2020 Google LLC. All Rights Reserved.
19421 * Licensed under the Apache License, Version 2.0 (the "License");
19422 * you may not use this file except in compliance with the License.
19423 * You may obtain a copy of the License at
19424 *
19425 * http://www.apache.org/licenses/LICENSE-2.0
19426 *
19427 * Unless required by applicable law or agreed to in writing, software
19428 * distributed under the License is distributed on an "AS IS" BASIS,
19429 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19430 * See the License for the specific language governing permissions and
19431 * limitations under the License.
19432 * =============================================================================
19433 */
19434 /**
19435 * Reverses a `tf.Tensor3D` along a specified axis.
19436 *
19437 * @param x The input tensor.
19438 * @param axis The set of dimensions to reverse. Must be in the
19439 * range [-rank(x), rank(x)). Defaults to all axes.
19440 */
19441 function reverse3d_(x, axis) {
19442 const $x = convertToTensor(x, 'x', 'reverse');
19443 assert($x.rank === 3, () => `Error in reverse3D: x must be rank 3 but got rank ${$x.rank}.`);
19444 return reverse($x, axis);
19445 }
19446 const reverse3d = op({ reverse3d_ });
19447
19448 /**
19449 * @license
19450 * Copyright 2020 Google LLC. All Rights Reserved.
19451 * Licensed under the Apache License, Version 2.0 (the "License");
19452 * you may not use this file except in compliance with the License.
19453 * You may obtain a copy of the License at
19454 *
19455 * http://www.apache.org/licenses/LICENSE-2.0
19456 *
19457 * Unless required by applicable law or agreed to in writing, software
19458 * distributed under the License is distributed on an "AS IS" BASIS,
19459 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19460 * See the License for the specific language governing permissions and
19461 * limitations under the License.
19462 * =============================================================================
19463 */
19464 /**
19465 * Reverses a `tf.Tensor4D` along a specified axis.
19466 *
19467 * @param x The input tensor.
19468 * @param axis The set of dimensions to reverse. Must be in the
19469 * range [-rank(x), rank(x)). Defaults to all axes.
19470 */
19471 function reverse4d_(x, axis) {
19472 const $x = convertToTensor(x, 'x', 'reverse');
19473 assert($x.rank === 4, () => `Error in reverse4D: x must be rank 4 but got rank ${$x.rank}.`);
19474 return reverse($x, axis);
19475 }
19476 const reverse4d = op({ reverse4d_ });
19477
19478 /**
19479 * @license
19480 * Copyright 2018 Google LLC. All Rights Reserved.
19481 * Licensed under the Apache License, Version 2.0 (the "License");
19482 * you may not use this file except in compliance with the License.
19483 * You may obtain a copy of the License at
19484 *
19485 * http://www.apache.org/licenses/LICENSE-2.0
19486 *
19487 * Unless required by applicable law or agreed to in writing, software
19488 * distributed under the License is distributed on an "AS IS" BASIS,
19489 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19490 * See the License for the specific language governing permissions and
19491 * limitations under the License.
19492 * =============================================================================
19493 */
19494 /**
19495 * Computes round of input `tf.Tensor` element-wise: `round(x)`.
19496 * It implements banker's rounding.
19497 *
19498 * ```js
19499 * const x = tf.tensor1d([.6, 1.1, -3.3]);
19500 *
19501 * x.round().print(); // or tf.round(x)
19502 * ```
19503 * @param x The input tensor.
19504 *
19505 * @doc {heading: 'Operations', subheading: 'Basic math'}
19506 */
19507 function round_(x) {
19508 const $x = convertToTensor(x, 'x', 'round');
19509 const inputs = { x: $x };
19510 return ENGINE.runKernel(Round, inputs);
19511 }
19512 const round$1 = op({ round_ });
19513
19514 /**
19515 * @license
19516 * Copyright 2018 Google LLC. All Rights Reserved.
19517 * Licensed under the Apache License, Version 2.0 (the "License");
19518 * you may not use this file except in compliance with the License.
19519 * You may obtain a copy of the License at
19520 *
19521 * http://www.apache.org/licenses/LICENSE-2.0
19522 *
19523 * Unless required by applicable law or agreed to in writing, software
19524 * distributed under the License is distributed on an "AS IS" BASIS,
19525 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19526 * See the License for the specific language governing permissions and
19527 * limitations under the License.
19528 * =============================================================================
19529 */
19530 /**
19531 * Computes reciprocal of square root of the input `tf.Tensor` element-wise:
19532 * `y = 1 / sqrt(x)`
19533 *
19534 * ```js
19535 * const x = tf.tensor1d([1, 2, 4, -1]);
19536 *
19537 * x.rsqrt().print(); // or tf.rsqrt(x)
19538 * ```
19539 * @param x The input tensor.
19540 *
19541 * @doc {heading: 'Operations', subheading: 'Basic math'}
19542 */
19543 function rsqrt_(x) {
19544 const $x = convertToTensor(x, 'x', 'rsqrt', 'float32');
19545 const inputs = { x: $x };
19546 return ENGINE.runKernel(Rsqrt, inputs);
19547 }
19548 const rsqrt = op({ rsqrt_ });
19549
19550 /**
19551 * @license
19552 * Copyright 2020 Google LLC. All Rights Reserved.
19553 * Licensed under the Apache License, Version 2.0 (the "License");
19554 * you may not use this file except in compliance with the License.
19555 * You may obtain a copy of the License at
19556 *
19557 * http://www.apache.org/licenses/LICENSE-2.0
19558 *
19559 * Unless required by applicable law or agreed to in writing, software
19560 * distributed under the License is distributed on an "AS IS" BASIS,
19561 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19562 * See the License for the specific language governing permissions and
19563 * limitations under the License.
19564 * =============================================================================
19565 */
19566 /**
19567 * Computes scaled exponential linear element-wise.
19568 *
19569 * `x < 0 ? scale * alpha * (exp(x) - 1) : x`
19570 *
19571 * ```js
19572 * const x = tf.tensor1d([-1, 2, -3, 4]);
19573 *
19574 * x.selu().print(); // or tf.selu(x)
19575 * ```
19576 * @param x The input tensor.
19577 *
19578 * @doc {heading: 'Operations', subheading: 'Basic math'}
19579 */
19580 function selu_(x) {
19581 const $x = convertToTensor(x, 'x', 'selu');
19582 const inputs = { x: $x };
19583 return ENGINE.runKernel(Selu, inputs);
19584 }
19585 const selu = op({ selu_ });
19586
19587 /**
19588 * 2-D convolution with separable filters.
19589 *
19590 * Performs a depthwise convolution that acts separately on channels followed
19591 * by a pointwise convolution that mixes channels. Note that this is
19592 * separability between dimensions [1, 2] and 3, not spatial separability
19593 * between dimensions 1 and 2.
19594 *
19595 * See
19596 * [https://www.tensorflow.org/api_docs/python/tf/nn/separable_conv2d](
19597 * https://www.tensorflow.org/api_docs/python/tf/nn/separable_conv2d)
19598 * for more details.
19599 *
19600 * @param x The input tensor, of rank 4 or rank 3, of shape
19601 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
19602 * assumed.
19603 * @param depthwiseFilter The depthwise filter tensor, rank 4, of shape
19604 * `[filterHeight, filterWidth, inChannels, channelMultiplier]`. This is
19605 * the filter used in the first step.
19606 * @param pointwiseFilter The pointwise filter tensor, rank 4, of shape
19607 * `[1, 1, inChannels * channelMultiplier, outChannels]`. This is
19608 * the filter used in the second step.
19609 * @param strides The strides of the convolution: `[strideHeight,
19610 * strideWidth]`. If strides is a single number, then `strideHeight ==
19611 * strideWidth`.
19612 * @param pad The type of padding algorithm.
19613 * - `same` and stride 1: output will be of same size as input,
19614 * regardless of filter size.
19615 * - `valid`: output will be smaller than input if filter is larger
19616 * than 1x1.
19617 * - For more info, see this guide:
19618 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
19619 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
19620 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
19621 * in which we sample input values across the height and width dimensions
19622 * in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single
19623 * number, then `dilationHeight == dilationWidth`. If it is greater than
19624 * 1, then all values of `strides` must be 1.
19625 * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
19626 * "NHWC". Specify the data format of the input and output data. With the
19627 * default format "NHWC", the data is stored in the order of: [batch,
19628 * height, width, channels]. Only "NHWC" is currently supported.
19629 *
19630 * @doc {heading: 'Operations', subheading: 'Convolution'}
19631 */
19632 function separableConv2d_(x, depthwiseFilter, pointwiseFilter, strides, pad, dilation = [1, 1], dataFormat = 'NHWC') {
19633 const $x = convertToTensor(x, 'x', 'separableConv2d');
19634 const $depthwiseFilter = convertToTensor(depthwiseFilter, 'depthwiseFilter', 'separableConv2d');
19635 const $pointwiseFilter = convertToTensor(pointwiseFilter, 'pointwiseFilter', 'separableConv2d');
19636 let x4D = $x;
19637 let reshapedTo4D = false;
19638 if ($x.rank === 3) {
19639 reshapedTo4D = true;
19640 x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
19641 }
19642 if (dataFormat === 'NCHW') {
19643 throw new Error('separableConv2d currently does not support dataFormat NCHW; only ' +
19644 'NHWC is supported');
19645 }
19646 assert(x4D.rank === 4, () => `Error in separableConv2d: input must be rank 4, but got ` +
19647 `rank ${x4D.rank}.`);
19648 assert($depthwiseFilter.rank === 4, () => `Error in separableConv2d: depthwise filter must be rank 4, but ` +
19649 `got rank ${$depthwiseFilter.rank}.`);
19650 assert($pointwiseFilter.rank === 4, () => `Error in separableConv2d: pointwise filter must be rank 4, but ` +
19651 `got rank ${$depthwiseFilter.rank}.`);
19652 assert($pointwiseFilter.shape[0] === 1, () => `Error in separableConv2d: the first dimension of pointwise filter ` +
19653 ` must be 1, but got ${$pointwiseFilter.shape[0]}.`);
19654 assert($pointwiseFilter.shape[1] === 1, () => `Error in separableConv2d: the second dimension of pointwise ` +
19655 `filter must be 1, but got ${$pointwiseFilter.shape[1]}.`);
19656 const inChannels = $depthwiseFilter.shape[2];
19657 const channelMultiplier = $depthwiseFilter.shape[3];
19658 assert($pointwiseFilter.shape[2] === inChannels * channelMultiplier, () => `Error in separableConv2d: the third dimension of pointwise filter ` +
19659 `must be ${inChannels * channelMultiplier}, ` +
19660 `but got ${$pointwiseFilter.shape[2]}.`);
19661 const depthwise = depthwiseConv2d(x4D, $depthwiseFilter, strides, pad, dataFormat, dilation);
19662 const pointwiseStride = 1;
19663 const res = conv2d(depthwise, $pointwiseFilter, pointwiseStride, 'valid', dataFormat);
19664 if (reshapedTo4D) {
19665 return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
19666 }
19667 return res;
19668 }
19669 const separableConv2d = op({ separableConv2d_ });
19670
19671 /**
19672 * @license
19673 * Copyright 2020 Google Inc. All Rights Reserved.
19674 * Licensed under the Apache License, Version 2.0 (the "License");
19675 * you may not use this file except in compliance with the License.
19676 * You may obtain a copy of the License at
19677 *
19678 * http://www.apache.org/licenses/LICENSE-2.0
19679 *
19680 * Unless required by applicable law or agreed to in writing, software
19681 * distributed under the License is distributed on an "AS IS" BASIS,
19682 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19683 * See the License for the specific language governing permissions and
19684 * limitations under the License.
19685 * =============================================================================
19686 */
19687 /**
19688 * Computes the difference between two lists of numbers.
19689 *
19690 * Given a Tensor `x` and a Tensor `y`, this operation returns a Tensor `out`
19691 * that represents all values that are in `x` but not in `y`. The returned
19692 * Tensor `out` is sorted in the same order that the numbers appear in `x`
19693 * (duplicates are preserved). This operation also returns a Tensor indices that
19694 * represents the position of each out element in `x`. In other words:
19695 *
19696 * `out[i] = x[idx[i]] for i in [0, 1, ..., out.length - 1]`
19697 *
19698 * ```js
19699 * const x = [1, 2, 3, 4, 5, 6];
19700 * const y = [1, 3, 5];
19701 *
19702 * const [out, indices] = await tf.setdiff1dAsync(x, y);
19703 * out.print(); // [2, 4, 6]
19704 * indices.print(); // [1, 3, 5]
19705 * ```
19706 *
19707 * @param x 1-D Tensor. Values to keep.
19708 * @param y 1-D Tensor. Must have the same type as x. Values to exclude in the
19709 * output.
19710 * @returns Promise of Tensor tuple [out, indices].
19711 * out: Tensor with the same type as x.
19712 * indices: A Tensor of type int32.
19713 *
19714 * @doc {heading: 'Tensors', subheading: 'Transformations'}
19715 */
19716 async function setdiff1dAsync_(x, y) {
19717 const $x = convertToTensor(x, 'x', 'setdiff1d');
19718 const $y = convertToTensor(y, 'y', 'setdiff1d');
19719 assert($x.dtype === $y.dtype, () => `x and y should have the same dtype, but got x (${$x.dtype}) and y (${$y.dtype}).`);
19720 assert($x.rank === 1, () => `x should be 1D tensor, but got x (${$x.shape}).`);
19721 assert($y.rank === 1, () => `y should be 1D tensor, but got y (${$y.shape}).`);
19722 const xVals = await $x.data();
19723 const yVals = await $y.data();
19724 const ySet = new Set(yVals);
19725 let outputSize = 0;
19726 for (let i = 0; i < xVals.length; i++) {
19727 if (!ySet.has(xVals[i])) {
19728 outputSize++;
19729 }
19730 }
19731 const buffer = new TensorBuffer([outputSize], $x.dtype);
19732 const indices = new TensorBuffer([outputSize], 'int32');
19733 for (let i = 0, p = 0; i < xVals.length; i++) {
19734 if (!ySet.has(xVals[i])) {
19735 buffer.values[p] = xVals[i];
19736 indices.values[p] = i;
19737 p++;
19738 }
19739 }
19740 return [buffer.toTensor(), indices.toTensor()];
19741 }
19742 const setdiff1dAsync = setdiff1dAsync_;
19743
19744 /**
19745 * @license
19746 * Copyright 2018 Google LLC. All Rights Reserved.
19747 * Licensed under the Apache License, Version 2.0 (the "License");
19748 * you may not use this file except in compliance with the License.
19749 * You may obtain a copy of the License at
19750 *
19751 * http://www.apache.org/licenses/LICENSE-2.0
19752 *
19753 * Unless required by applicable law or agreed to in writing, software
19754 * distributed under the License is distributed on an "AS IS" BASIS,
19755 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19756 * See the License for the specific language governing permissions and
19757 * limitations under the License.
19758 * =============================================================================
19759 */
19760 /**
19761 * Returns an element-wise indication of the sign of a number.
19762 *
19763 * ```js
19764 * const x = tf.tensor1d([.6, 1.1, -3.3, NaN, 0]);
19765 *
19766 * x.sign().print(); // or tf.sign(x)
19767 * ```
19768 * @param x The input Tensor.
19769 *
19770 * @doc {heading: 'Operations', subheading: 'Basic math'}
19771 */
19772 function sign_(x) {
19773 const $x = convertToTensor(x, 'x', 'sign');
19774 const inputs = { x: $x };
19775 return ENGINE.runKernel(Sign, inputs);
19776 }
19777 const sign = op({ sign_ });
19778
19779 /**
19780 * @license
19781 * Copyright 2018 Google LLC. All Rights Reserved.
19782 * Licensed under the Apache License, Version 2.0 (the "License");
19783 * you may not use this file except in compliance with the License.
19784 * You may obtain a copy of the License at
19785 *
19786 * http://www.apache.org/licenses/LICENSE-2.0
19787 *
19788 * Unless required by applicable law or agreed to in writing, software
19789 * distributed under the License is distributed on an "AS IS" BASIS,
19790 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19791 * See the License for the specific language governing permissions and
19792 * limitations under the License.
19793 * =============================================================================
19794 */
19795 /**
19796 * Computes sin of the input Tensor element-wise: `sin(x)`
19797 *
19798 * ```js
19799 * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);
19800 *
19801 * x.sin().print(); // or tf.sin(x)
19802 * ```
19803 * @param x The input tensor.
19804 *
19805 * @doc {heading: 'Operations', subheading: 'Basic math'}
19806 */
19807 function sin_(x) {
19808 const $x = convertToTensor(x, 'x', 'sin', 'float32');
19809 const inputs = { x: $x };
19810 return ENGINE.runKernel(Sin, inputs);
19811 }
19812 const sin = op({ sin_ });
19813
19814 /**
19815 * @license
19816 * Copyright 2018 Google LLC. All Rights Reserved.
19817 * Licensed under the Apache License, Version 2.0 (the "License");
19818 * you may not use this file except in compliance with the License.
19819 * You may obtain a copy of the License at
19820 *
19821 * http://www.apache.org/licenses/LICENSE-2.0
19822 *
19823 * Unless required by applicable law or agreed to in writing, software
19824 * distributed under the License is distributed on an "AS IS" BASIS,
19825 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19826 * See the License for the specific language governing permissions and
19827 * limitations under the License.
19828 * =============================================================================
19829 */
19830 /**
19831 * Computes hyperbolic sin of the input `tf.Tensor` element-wise: `sinh(x)`
19832 *
19833 * ```js
19834 * const x = tf.tensor1d([0, 1, -1, .7]);
19835 *
19836 * x.sinh().print(); // or tf.sinh(x)
19837 * ```
19838 * @param x The input tensor.
19839 *
19840 * @doc {heading: 'Operations', subheading: 'Basic math'}
19841 */
19842 function sinh_(x) {
19843 const $x = convertToTensor(x, 'x', 'sinh');
19844 const inputs = { x: $x };
19845 return ENGINE.runKernel(Sinh, inputs);
19846 }
19847 const sinh = op({ sinh_ });
19848
19849 /**
19850 * @license
19851 * Copyright 2018 Google LLC. All Rights Reserved.
19852 * Licensed under the Apache License, Version 2.0 (the "License");
19853 * you may not use this file except in compliance with the License.
19854 * You may obtain a copy of the License at
19855 *
19856 * http://www.apache.org/licenses/LICENSE-2.0
19857 *
19858 * Unless required by applicable law or agreed to in writing, software
19859 * distributed under the License is distributed on an "AS IS" BASIS,
19860 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19861 * See the License for the specific language governing permissions and
19862 * limitations under the License.
19863 * =============================================================================
19864 */
19865 /**
19866 * Extracts a 1D slice from 1D array starting at coordinates `begin` and is
19867 * of length `size`. See `slice` for details.
19868 */
19869 function slice1d_(x, begin, size) {
19870 const $x = convertToTensor(x, 'x', 'slice1d');
19871 assert($x.rank === 1, () => `slice1d expects a rank-1 tensor, but got a rank-${$x.rank} tensor`);
19872 return slice($x, [begin], [size]);
19873 }
19874 const slice1d = op({ slice1d_ });
19875
19876 /**
19877 * @license
19878 * Copyright 2018 Google LLC. All Rights Reserved.
19879 * Licensed under the Apache License, Version 2.0 (the "License");
19880 * you may not use this file except in compliance with the License.
19881 * You may obtain a copy of the License at
19882 *
19883 * http://www.apache.org/licenses/LICENSE-2.0
19884 *
19885 * Unless required by applicable law or agreed to in writing, software
19886 * distributed under the License is distributed on an "AS IS" BASIS,
19887 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19888 * See the License for the specific language governing permissions and
19889 * limitations under the License.
19890 * =============================================================================
19891 */
19892 /**
19893 * Extracts a 2D slice from a 2D array starting at coordinates `begin` and
19894 * is of size `size`. See `slice` for details.
19895 */
19896 function slice2d_(x, begin, size) {
19897 const $x = convertToTensor(x, 'x', 'slice2d');
19898 assert($x.rank === 2, () => `slice2d expects a rank-2 tensor, but got a rank-${$x.rank} tensor`);
19899 return slice($x, begin, size);
19900 }
19901 const slice2d = op({ slice2d_ });
19902
19903 /**
19904 * @license
19905 * Copyright 2018 Google LLC. All Rights Reserved.
19906 * Licensed under the Apache License, Version 2.0 (the "License");
19907 * you may not use this file except in compliance with the License.
19908 * You may obtain a copy of the License at
19909 *
19910 * http://www.apache.org/licenses/LICENSE-2.0
19911 *
19912 * Unless required by applicable law or agreed to in writing, software
19913 * distributed under the License is distributed on an "AS IS" BASIS,
19914 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19915 * See the License for the specific language governing permissions and
19916 * limitations under the License.
19917 * =============================================================================
19918 */
19919 /**
19920 * Extracts a 3D slice from a 3D array starting at coordinates `begin` and
19921 * is of size `size`. See `slice` for details.
19922 */
19923 function slice3d_(x, begin, size) {
19924 const $x = convertToTensor(x, 'x', 'slice3d');
19925 assert($x.rank === 3, () => `slice3d expects a rank-3 tensor, but got a rank-${$x.rank} tensor`);
19926 return slice($x, begin, size);
19927 }
19928 const slice3d = op({ slice3d_ });
19929
19930 /**
19931 * @license
19932 * Copyright 2018 Google LLC. All Rights Reserved.
19933 * Licensed under the Apache License, Version 2.0 (the "License");
19934 * you may not use this file except in compliance with the License.
19935 * You may obtain a copy of the License at
19936 *
19937 * http://www.apache.org/licenses/LICENSE-2.0
19938 *
19939 * Unless required by applicable law or agreed to in writing, software
19940 * distributed under the License is distributed on an "AS IS" BASIS,
19941 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19942 * See the License for the specific language governing permissions and
19943 * limitations under the License.
19944 * =============================================================================
19945 */
19946 /**
19947 * Extracts a 4D slice from a 4D array starting at coordinates `begin` and
19948 * is of size `size`. See `slice` for details.
19949 */
19950 function slice4d_(x, begin, size) {
19951 const $x = convertToTensor(x, 'x', 'slice4d');
19952 assert($x.rank === 4, () => `slice4d expects a rank-4 tensor, but got a rank-${$x.rank} tensor`);
19953 return slice($x, begin, size);
19954 }
19955 const slice4d = op({ slice4d_ });
19956
19957 /**
19958 * @license
19959 * Copyright 2018 Google LLC. All Rights Reserved.
19960 * Licensed under the Apache License, Version 2.0 (the "License");
19961 * you may not use this file except in compliance with the License.
19962 * You may obtain a copy of the License at
19963 *
19964 * http://www.apache.org/licenses/LICENSE-2.0
19965 *
19966 * Unless required by applicable law or agreed to in writing, software
19967 * distributed under the License is distributed on an "AS IS" BASIS,
19968 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19969 * See the License for the specific language governing permissions and
19970 * limitations under the License.
19971 * =============================================================================
19972 */
19973 /**
19974 * Computes the softmax normalized vector given the logits.
19975 *
19976 * ```js
19977 * const a = tf.tensor1d([1, 2, 3]);
19978 *
19979 * a.softmax().print(); // or tf.softmax(a)
19980 * ```
19981 *
19982 * ```js
19983 * const a = tf.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]);
19984 *
19985 * a.softmax().print(); // or tf.softmax(a)
19986 * ```
19987 *
19988 * @param logits The logits array.
19989 * @param dim The dimension softmax would be performed on. Defaults to `-1`
19990 * which indicates the last dimension.
19991 *
19992 * @doc {heading: 'Operations', subheading: 'Normalization'}
19993 */
19994 function softmax_(logits, dim = -1) {
19995 const $logits = convertToTensor(logits, 'logits', 'softmax', 'float32');
19996 if (dim === -1) {
19997 dim = $logits.rank - 1;
19998 }
19999 if (dim !== $logits.rank - 1) {
20000 throw Error('Softmax along a non-last dimension is not yet supported. ' +
20001 `Logits was rank ${$logits.rank} and dim was ${dim}`);
20002 }
20003 const inputs = { logits: $logits };
20004 const attrs = { dim };
20005 return ENGINE.runKernel(Softmax, inputs, attrs);
20006 }
20007 const softmax = op({ softmax_ });
20008
20009 /**
20010 * @license
20011 * Copyright 2020 Google LLC. All Rights Reserved.
20012 * Licensed under the Apache License, Version 2.0 (the "License");
20013 * you may not use this file except in compliance with the License.
20014 * You may obtain a copy of the License at
20015 *
20016 * http://www.apache.org/licenses/LICENSE-2.0
20017 *
20018 * Unless required by applicable law or agreed to in writing, software
20019 * distributed under the License is distributed on an "AS IS" BASIS,
20020 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20021 * See the License for the specific language governing permissions and
20022 * limitations under the License.
20023 * =============================================================================
20024 */
20025 /**
20026 * Fast Fourier transform.
20027 *
20028 * Computes the 1-dimensional discrete Fourier transform over the inner-most
20029 * dimension of input.
20030 *
20031 * ```js
20032 * const real = tf.tensor1d([1, 2, 3]);
20033 * const imag = tf.tensor1d([1, 2, 3]);
20034 * const x = tf.complex(real, imag);
20035 *
20036 * x.fft().print(); // tf.spectral.fft(x).print();
20037 * ```
20038 * @param input The complex input to compute an fft over.
20039 *
20040 * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
20041 */
20042 function fft_(input) {
20043 assert(input.dtype === 'complex64', () => `The dtype for tf.spectral.fft() must be complex64 ` +
20044 `but got ${input.dtype}.`);
20045 const inputs = { input };
20046 return ENGINE.runKernel(FFT, inputs);
20047 }
20048 const fft = op({ fft_ });
20049
20050 /**
20051 * @license
20052 * Copyright 2020 Google LLC. All Rights Reserved.
20053 * Licensed under the Apache License, Version 2.0 (the "License");
20054 * you may not use this file except in compliance with the License.
20055 * You may obtain a copy of the License at
20056 *
20057 * http://www.apache.org/licenses/LICENSE-2.0
20058 *
20059 * Unless required by applicable law or agreed to in writing, software
20060 * distributed under the License is distributed on an "AS IS" BASIS,
20061 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20062 * See the License for the specific language governing permissions and
20063 * limitations under the License.
20064 * =============================================================================
20065 */
20066 /**
20067 * Inverse fast Fourier transform.
20068 *
20069 * Computes the inverse 1-dimensional discrete Fourier transform over the
20070 * inner-most dimension of input.
20071 *
20072 * ```js
20073 * const real = tf.tensor1d([1, 2, 3]);
20074 * const imag = tf.tensor1d([1, 2, 3]);
20075 * const x = tf.complex(real, imag);
20076 *
20077 * x.ifft().print(); // tf.spectral.ifft(x).print();
20078 * ```
20079 * @param input The complex input to compute an ifft over.
20080 *
20081 * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
20082 */
20083 function ifft_(input) {
20084 assert(input.dtype === 'complex64', () => `The dtype for tf.spectral.ifft() must be complex64 ` +
20085 `but got ${input.dtype}.`);
20086 const inputs = { input };
20087 return ENGINE.runKernel(IFFT, inputs);
20088 }
20089 const ifft = op({ ifft_ });
20090
20091 /**
20092 * @license
20093 * Copyright 2018 Google LLC. All Rights Reserved.
20094 * Licensed under the Apache License, Version 2.0 (the "License");
20095 * you may not use this file except in compliance with the License.
20096 * You may obtain a copy of the License at
20097 *
20098 * http://www.apache.org/licenses/LICENSE-2.0
20099 *
20100 * Unless required by applicable law or agreed to in writing, software
20101 * distributed under the License is distributed on an "AS IS" BASIS,
20102 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20103 * See the License for the specific language governing permissions and
20104 * limitations under the License.
20105 * =============================================================================
20106 */
20107 /**
20108 * Inversed real value input fast Fourier transform.
20109 *
20110 * Computes the 1-dimensional inversed discrete Fourier transform over the
20111 * inner-most dimension of the real input.
20112 *
20113 * ```js
20114 * const real = tf.tensor1d([1, 2, 3]);
20115 * const imag = tf.tensor1d([0, 0, 0]);
20116 * const x = tf.complex(real, imag);
20117 *
20118 * x.irfft().print();
20119 * ```
20120 * @param input The real value input to compute an irfft over.
20121 *
20122 * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
20123 */
    /**
     * Inverse FFT of a real-valued signal: reconstructs the full complex
     * spectrum from the non-redundant half stored in `input`, runs `ifft`,
     * and returns only the real part of the result.
     */
    function irfft_(input) {
        // Treat the input as `batch` rows of length `innerDimensionSize`
        // (the inner-most dimension), regardless of its original rank.
        const innerDimensionSize = input.shape[input.shape.length - 1];
        const batch = input.size / innerDimensionSize;
        let ret;
        if (innerDimensionSize <= 2) {
            // With at most 2 spectral bins there are no redundant conjugate
            // components to reconstruct; run the inverse FFT directly.
            const complexInput = reshape(input, [batch, innerDimensionSize]);
            ret = ifft(complexInput);
        }
        else {
            // The length of unique components of the DFT of a real-valued signal
            // is 2 * (input_len - 1)
            const outputShape = [batch, 2 * (innerDimensionSize - 1)];
            const realInput = reshape(real(input), [batch, innerDimensionSize]);
            const imagInput = reshape(imag(input), [batch, innerDimensionSize]);
            // Mirror the interior bins (excluding DC at index 0 and the
            // Nyquist bin at the end) to rebuild the conjugate-symmetric
            // half: real part reversed, imaginary part reversed and negated.
            const realConjugate = reverse(slice(realInput, [0, 1], [batch, innerDimensionSize - 2]), 1);
            const imagConjugate = mul(reverse(slice(imagInput, [0, 1], [batch, innerDimensionSize - 2]), 1), scalar(-1));
            const r = concat([realInput, realConjugate], 1);
            const i = concat([imagInput, imagConjugate], 1);
            const complexInput = reshape(complex(r, i), [outputShape[0], outputShape[1]]);
            ret = ifft(complexInput);
        }
        // The inverse transform of a conjugate-symmetric spectrum is real;
        // drop the (numerically ~zero) imaginary part.
        ret = real(ret);
        // reshape the result if the input is 3D tensor.
        if (input.rank === 3 && input.shape[0] !== 0) {
            const temp = ret;
            // NOTE: this `batch` intentionally shadows the outer `batch`
            // (rows = input.shape[0], not total batch*rows).
            const batch = input.shape[0];
            ret = reshape(ret, [batch, ret.shape[0] / batch, ret.shape[1]]);
            // Dispose the pre-reshape tensor to avoid leaking its buffer.
            temp.dispose();
        }
        return ret;
    }
    const irfft = op({ irfft_ });
20156
20157 /**
20158 * @license
20159 * Copyright 2020 Google LLC. All Rights Reserved.
20160 * Licensed under the Apache License, Version 2.0 (the "License");
20161 * you may not use this file except in compliance with the License.
20162 * You may obtain a copy of the License at
20163 *
20164 * http://www.apache.org/licenses/LICENSE-2.0
20165 *
20166 * Unless required by applicable law or agreed to in writing, software
20167 * distributed under the License is distributed on an "AS IS" BASIS,
20168 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20169 * See the License for the specific language governing permissions and
20170 * limitations under the License.
20171 * =============================================================================
20172 */
20173 /**
20174 * Splits a `tf.Tensor` into sub tensors.
20175 *
20176 * If `numOrSizeSplits` is a number, splits `x` along dimension `axis`
20177 * into `numOrSizeSplits` smaller tensors.
20178 * Requires that `numOrSizeSplits` evenly divides `x.shape[axis]`.
20179 *
20180 * If `numOrSizeSplits` is a number array, splits `x` into
20181 * `numOrSizeSplits.length` pieces. The shape of the `i`-th piece has the
20182 * same size as `x` except along dimension `axis` where the size is
20183 * `numOrSizeSplits[i]`.
20184 *
20185 * ```js
20186 * const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]);
20187 * const [a, b] = tf.split(x, 2, 1);
20188 * a.print();
20189 * b.print();
20190 *
20191 * const [c, d, e] = tf.split(x, [1, 2, 1], 1);
20192 * c.print();
20193 * d.print();
20194 * e.print();
20195 * ```
20196 *
20197 * @param x The input tensor to split.
20198 * @param numOrSizeSplits Either an integer indicating the number of
20199 * splits along the axis or an array of integers containing the sizes of
20200 * each output tensor along the axis. If a number then it must evenly divide
20201 * `x.shape[axis]`; otherwise the sum of sizes must match `x.shape[axis]`.
20202 * Can contain one -1 indicating that dimension is to be inferred.
20203 * @param axis The dimension along which to split. Defaults to 0 (the first
20204 * dim).
20205 *
20206 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
20207 */
20208 function split_(x, numOrSizeSplits, axis = 0) {
20209 const $x = convertToTensor(x, 'x', 'split');
20210 const inputs = { x: $x };
20211 const attr = { numOrSizeSplits, axis };
20212 return ENGINE.runKernel(SplitV, inputs, attr);
20213 }
20214 const split = op({ split_ });
20215
20216 /**
20217 * @license
20218 * Copyright 2018 Google LLC. All Rights Reserved.
20219 * Licensed under the Apache License, Version 2.0 (the "License");
20220 * you may not use this file except in compliance with the License.
20221 * You may obtain a copy of the License at
20222 *
20223 * http://www.apache.org/licenses/LICENSE-2.0
20224 *
20225 * Unless required by applicable law or agreed to in writing, software
20226 * distributed under the License is distributed on an "AS IS" BASIS,
20227 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20228 * See the License for the specific language governing permissions and
20229 * limitations under the License.
20230 * =============================================================================
20231 */
20232 /**
20233 * Real value input fast Fourier transform.
20234 *
20235 * Computes the 1-dimensional discrete Fourier transform over the
20236 * inner-most dimension of the real input.
20237 *
20238 * ```js
20239 * const real = tf.tensor1d([1, 2, 3]);
20240 *
20241 * real.rfft().print();
20242 * ```
20243 * @param input The real value input to compute an rfft over.
20244 *
20245 * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
20246 */
    /**
     * FFT of a real-valued signal over its inner-most dimension.
     *
     * The input is optionally cropped or zero-padded to `fftLength`, lifted
     * to complex64 with a zero imaginary part, transformed with `fft`, and
     * only the non-redundant first `floor(n/2) + 1` bins are returned.
     */
    function rfft_(input, fftLength) {
        assert(input.dtype === 'float32', () => `The dtype for rfft() must be real value but got ${input.dtype}`);
        // Treat the input as `batch` rows of length `innerDimensionSize`.
        let innerDimensionSize = input.shape[input.shape.length - 1];
        const batch = input.size / innerDimensionSize;
        let adjustedInput;
        if (fftLength != null && fftLength < innerDimensionSize) {
            // Need to crop
            const begin = input.shape.map(v => 0);
            const size = input.shape.map(v => v);
            size[input.shape.length - 1] = fftLength;
            adjustedInput = slice(input, begin, size);
            innerDimensionSize = fftLength;
        }
        else if (fftLength != null && fftLength > innerDimensionSize) {
            // Need to pad with zeros
            const zerosShape = input.shape.map(v => v);
            zerosShape[input.shape.length - 1] = fftLength - innerDimensionSize;
            adjustedInput = concat([input, zeros(zerosShape)], input.shape.length - 1);
            innerDimensionSize = fftLength;
        }
        else {
            // fftLength is null or already matches: use the input as-is.
            adjustedInput = input;
        }
        // Complement the input with zero imaginary numbers.
        const zerosInput = zerosLike(adjustedInput);
        const complexInput = reshape(complex(adjustedInput, zerosInput), [batch, innerDimensionSize]);
        const ret = fft(complexInput);
        // Exclude complex conjugations. These conjugations are put symmetrically.
        const half = Math.floor(innerDimensionSize / 2) + 1;
        const realValues = real(ret);
        const imagValues = imag(ret);
        // Split each of the real/imag parts into [kept half, discarded
        // conjugate half] along the last axis; only index [0] is used.
        const realComplexConjugate = split(realValues, [half, innerDimensionSize - half], realValues.shape.length - 1);
        const imagComplexConjugate = split(imagValues, [half, innerDimensionSize - half], imagValues.shape.length - 1);
        // Output keeps the input's shape except the last dim becomes `half`.
        const outputShape = adjustedInput.shape.slice();
        outputShape[adjustedInput.shape.length - 1] = half;
        return reshape(complex(realComplexConjugate[0], imagComplexConjugate[0]), outputShape);
    }
    const rfft = op({ rfft_ });
20285
20286 /**
20287 * @license
20288 * Copyright 2020 Google LLC. All Rights Reserved.
20289 * Licensed under the Apache License, Version 2.0 (the "License");
20290 * you may not use this file except in compliance with the License.
20291 * You may obtain a copy of the License at
20292 *
20293 * http://www.apache.org/licenses/LICENSE-2.0
20294 *
20295 * Unless required by applicable law or agreed to in writing, software
20296 * distributed under the License is distributed on an "AS IS" BASIS,
20297 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20298 * See the License for the specific language governing permissions and
20299 * limitations under the License.
20300 * =============================================================================
20301 */
20302 /**
20303 * Returns (a - b) * (a - b) element-wise.
20304 * Supports broadcasting.
20305 *
20306 * ```js
20307 * const a = tf.tensor1d([1, 4, 3, 16]);
20308 * const b = tf.tensor1d([1, 2, 9, 4]);
20309 *
20310 * a.squaredDifference(b).print(); // or tf.squaredDifference(a, b)
20311 * ```
20312 *
20313 * ```js
20314 * // Broadcast squared difference a with b.
20315 * const a = tf.tensor1d([2, 4, 6, 8]);
20316 * const b = tf.scalar(5);
20317 *
20318 * a.squaredDifference(b).print(); // or tf.squaredDifference(a, b)
20319 * ```
20320 *
20321 * @param a The first tensor.
20322 * @param b The second tensor. Must have the same type as `a`.
20323 *
20324 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
20325 */
20326 function squaredDifference_(a, b) {
20327 let $a = convertToTensor(a, 'a', 'squaredDifference');
20328 let $b = convertToTensor(b, 'b', 'squaredDifference');
20329 [$a, $b] = makeTypesMatch($a, $b);
20330 assertAndGetBroadcastShape($a.shape, $b.shape);
20331 const inputs = { a: $a, b: $b };
20332 const attrs = {};
20333 return ENGINE.runKernel(SquaredDifference, inputs, attrs);
20334 }
20335 const squaredDifference = op({ squaredDifference_ });
20336
20337 /**
20338 * @license
20339 * Copyright 2020 Google LLC. All Rights Reserved.
20340 * Licensed under the Apache License, Version 2.0 (the "License");
20341 * you may not use this file except in compliance with the License.
20342 * You may obtain a copy of the License at
20343 *
20344 * http://www.apache.org/licenses/LICENSE-2.0
20345 *
20346 * Unless required by applicable law or agreed to in writing, software
20347 * distributed under the License is distributed on an "AS IS" BASIS,
20348 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20349 * See the License for the specific language governing permissions and
20350 * limitations under the License.
20351 * =============================================================================
20352 */
20353 /**
20354 * Removes dimensions of size 1 from the shape of a `tf.Tensor`.
20355 *
20356 * ```js
20357 * const x = tf.tensor([1, 2, 3, 4], [1, 1, 4]);
20358 * x.squeeze().print();
20359 * ```
20360 *
20361 * @param x The input tensor to be squeezed.
20362 * @param axis An optional list of numbers. If specified, only
20363 * squeezes the dimensions listed. The dimension index starts at 0. It
20364 * is an error to squeeze a dimension that is not 1.
20365 *
20366 * @doc {heading: 'Tensors', subheading: 'Transformations'}
20367 */
20368 function squeeze_(x, axis) {
20369 const $x = convertToTensor(x, 'x', 'squeeze');
20370 return reshape($x, squeezeShape($x.shape, axis).newShape);
20371 }
20372 const squeeze = op({ squeeze_ });
20373
20374 /**
20375 * @license
20376 * Copyright 2020 Google LLC. All Rights Reserved.
20377 * Licensed under the Apache License, Version 2.0 (the "License");
20378 * you may not use this file except in compliance with the License.
20379 * You may obtain a copy of the License at
20380 *
20381 * http://www.apache.org/licenses/LICENSE-2.0
20382 *
20383 * Unless required by applicable law or agreed to in writing, software
20384 * distributed under the License is distributed on an "AS IS" BASIS,
20385 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20386 * See the License for the specific language governing permissions and
20387 * limitations under the License.
20388 * =============================================================================
20389 */
20390 /**
20391 * Stacks a list of rank-`R` `tf.Tensor`s into one rank-`(R+1)` `tf.Tensor`.
20392 *
20393 * ```js
20394 * const a = tf.tensor1d([1, 2]);
20395 * const b = tf.tensor1d([3, 4]);
20396 * const c = tf.tensor1d([5, 6]);
20397 * tf.stack([a, b, c]).print();
20398 * ```
20399 *
20400 * @param tensors A list of tensor objects with the same shape and dtype.
20401 * @param axis The axis to stack along. Defaults to 0 (the first dim).
20402 *
20403 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
20404 */
20405 function stack_(tensors, axis = 0) {
20406 const $tensors = convertToTensorArray(tensors, 'tensors', 'stack', 'string_or_numeric');
20407 assert($tensors.length >= 1, () => 'Pass at least one tensor to tf.stack');
20408 if ($tensors.length > 0) {
20409 assert(axis <= $tensors[0].rank, () => 'Axis must be <= rank of the tensor');
20410 }
20411 const inputs = $tensors;
20412 const attrs = { axis };
20413 return ENGINE.runKernel(Pack, inputs, attrs);
20414 }
20415 const stack = op({ stack_ });
20416
20417 /**
20418 * @license
20419 * Copyright 2018 Google LLC. All Rights Reserved.
20420 * Licensed under the Apache License, Version 2.0 (the "License");
20421 * you may not use this file except in compliance with the License.
20422 * You may obtain a copy of the License at
20423 *
20424 * http://www.apache.org/licenses/LICENSE-2.0
20425 *
20426 * Unless required by applicable law or agreed to in writing, software
20427 * distributed under the License is distributed on an "AS IS" BASIS,
20428 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20429 * See the License for the specific language governing permissions and
20430 * limitations under the License.
20431 * =============================================================================
20432 */
20433 /**
20434 * Computes step of the input `tf.Tensor` element-wise: `x > 0 ? 1 : alpha * x`
20435 *
20436 * ```js
20437 * const x = tf.tensor1d([0, 2, -1, -3]);
20438 *
20439 * x.step(.5).print(); // or tf.step(x, .5)
20440 * ```
20441 * @param x The input tensor.
20442 * @param alpha The gradient when input is negative.
20443 *
20444 * @doc {heading: 'Operations', subheading: 'Basic math'}
20445 */
20446 function step_(x, alpha = 0.0) {
20447 const $x = convertToTensor(x, 'x', 'step');
20448 const inputs = { x: $x };
20449 const attrs = { alpha };
20450 return ENGINE.runKernel(Step, inputs, attrs);
20451 }
20452 const step = op({ step_ });
20453
20454 /**
20455 * @license
20456 * Copyright 2018 Google LLC. All Rights Reserved.
20457 * Licensed under the Apache License, Version 2.0 (the "License");
20458 * you may not use this file except in compliance with the License.
20459 * You may obtain a copy of the License at
20460 *
20461 * http://www.apache.org/licenses/LICENSE-2.0
20462 *
20463 * Unless required by applicable law or agreed to in writing, software
20464 * distributed under the License is distributed on an "AS IS" BASIS,
20465 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20466 * See the License for the specific language governing permissions and
20467 * limitations under the License.
20468 * =============================================================================
20469 */
20470 /**
20471 * Extracts a strided slice of a tensor.
20472 *
20473 * Roughly speaking, this op extracts a slice of size (end-begin)/stride from
20474 * the given input tensor (x). Starting at the location specified by begin the
20475 * slice continues by adding stride to the index until all dimensions are not
20476 * less than end. Note that a stride can be negative, which causes a reverse
20477 * slice.
20478 *
20479 * ```js
20480 * const t = tf.tensor3d([1, 1, 1 ,2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6],
20481 * [3, 2, 3]);
20482 * t.stridedSlice([1, 0, 0], [2, 1, 3], [1, 1, 1]).print() // [[[3, 3, 3]]]
20483 * t.stridedSlice([1, 0, 0], [2, 2, 3], [1, 1, 1]).print() // [[[3, 3, 3],
20484 * // [4, 4, 4]]]
20485 * t.stridedSlice([1, -1, 0], [2, -3, 3], [1, -1, 1]).print() // [[[4, 4, 4],
20486 * // [3, 3, 3]]]
20487 * ```
20488 *
20489 * @param x The tensor to stride slice.
20490 * @param begin The coordinates to start the slice from.
20491 * @param end: The coordinates to end the slice at.
20492 * @param strides: The size of the slice.
20493 * @param beginMask: If the ith bit of beginMask is set, begin[i] is ignored
20494 * and the fullest possible range in that dimension is used instead.
20495 * @param endMask: If the ith bit of endMask is set, end[i] is ignored
20496 * and the fullest possible range in that dimension is used instead.
20497 * @param shrinkAxisMask: a bitmask where bit i implies that
20498 * the ith specification should shrink the dimensionality. begin and end must
20499 * imply a slice of size 1 in the dimension.
20500 *
20501 * @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
20502 */
20503 function stridedSlice_(x, begin, end, strides, beginMask = 0, endMask = 0, ellipsisMask = 0, newAxisMask = 0, shrinkAxisMask = 0) {
20504 const $x = convertToTensor(x, 'x', 'stridedSlice', 'string_or_numeric');
20505 const inputs = { x: $x };
20506 const attrs = {
20507 begin,
20508 end,
20509 strides,
20510 beginMask,
20511 endMask,
20512 ellipsisMask,
20513 newAxisMask,
20514 shrinkAxisMask
20515 };
20516 return ENGINE.runKernel(StridedSlice, inputs, attrs);
20517 }
20518 const stridedSlice = op({ stridedSlice_ });
20519
20520 /**
20521 * @license
20522 * Copyright 2018 Google LLC. All Rights Reserved.
20523 * Licensed under the Apache License, Version 2.0 (the "License");
20524 * you may not use this file except in compliance with the License.
20525 * You may obtain a copy of the License at
20526 *
20527 * http://www.apache.org/licenses/LICENSE-2.0
20528 *
20529 * Unless required by applicable law or agreed to in writing, software
20530 * distributed under the License is distributed on an "AS IS" BASIS,
20531 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20532 * See the License for the specific language governing permissions and
20533 * limitations under the License.
20534 * =============================================================================
20535 */
20536 /**
20537 * Computes tan of the input `tf.Tensor` element-wise, `tan(x)`
20538 *
20539 * ```js
20540 * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);
20541 *
20542 * x.tan().print(); // or tf.tan(x)
20543 * ```
20544 * @param x The input tensor.
20545 *
20546 * @doc {heading: 'Operations', subheading: 'Basic math'}
20547 */
20548 function tan_(x) {
20549 const $x = convertToTensor(x, 'x', 'tan', 'float32');
20550 const inputs = { x: $x };
20551 return ENGINE.runKernel(Tan, inputs);
20552 }
20553 const tan = op({ tan_ });
20554
20555 /**
20556 * @license
20557 * Copyright 2018 Google LLC. All Rights Reserved.
20558 * Licensed under the Apache License, Version 2.0 (the "License");
20559 * you may not use this file except in compliance with the License.
20560 * You may obtain a copy of the License at
20561 *
20562 * http://www.apache.org/licenses/LICENSE-2.0
20563 *
20564 * Unless required by applicable law or agreed to in writing, software
20565 * distributed under the License is distributed on an "AS IS" BASIS,
20566 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20567 * See the License for the specific language governing permissions and
20568 * limitations under the License.
20569 * =============================================================================
20570 */
20571 /**
20572 * Creates rank-1 `tf.Tensor` with the provided values, shape and dtype.
20573 *
20574 * The same functionality can be achieved with `tf.tensor`, but in general
20575 * we recommend using `tf.tensor1d` as it makes the code more readable.
20576 *
20577 * ```js
20578 * tf.tensor1d([1, 2, 3]).print();
20579 * ```
20580 *
20581 * @param values The values of the tensor. Can be array of numbers,
20582 * or a `TypedArray`.
20583 * @param dtype The data type.
20584 *
20585 * @doc {heading: 'Tensors', subheading: 'Creation'}
20586 */
20587 function tensor1d(values, dtype) {
20588 assertNonNull(values);
20589 const inferredShape = inferShape(values, dtype);
20590 if (inferredShape.length !== 1) {
20591 throw new Error('tensor1d() requires values to be a flat/TypedArray');
20592 }
20593 const shape = null;
20594 return makeTensor(values, shape, inferredShape, dtype);
20595 }
20596
20597 /**
20598 * @license
20599 * Copyright 2018 Google LLC. All Rights Reserved.
20600 * Licensed under the Apache License, Version 2.0 (the "License");
20601 * you may not use this file except in compliance with the License.
20602 * You may obtain a copy of the License at
20603 *
20604 * http://www.apache.org/licenses/LICENSE-2.0
20605 *
20606 * Unless required by applicable law or agreed to in writing, software
20607 * distributed under the License is distributed on an "AS IS" BASIS,
20608 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20609 * See the License for the specific language governing permissions and
20610 * limitations under the License.
20611 * =============================================================================
20612 */
20613 /**
20614 * Creates rank-2 `tf.Tensor` with the provided values, shape and dtype.
20615 *
20616 * The same functionality can be achieved with `tf.tensor`, but in general
20617 * we recommend using `tf.tensor2d` as it makes the code more readable.
20618 *
20619 * ```js
20620 * // Pass a nested array.
20621 * tf.tensor2d([[1, 2], [3, 4]]).print();
20622 * ```
20623 * ```js
20624 * // Pass a flat array and specify a shape.
20625 * tf.tensor2d([1, 2, 3, 4], [2, 2]).print();
20626 * ```
20627 *
20628 * @param values The values of the tensor. Can be nested array of numbers,
20629 * or a flat array, or a `TypedArray`.
20630 * @param shape The shape of the tensor. If not provided, it is inferred from
20631 * `values`.
20632 * @param dtype The data type.
20633 *
20634 * @doc {heading: 'Tensors', subheading: 'Creation'}
20635 */
20636 function tensor2d(values, shape, dtype) {
20637 assertNonNull(values);
20638 if (shape != null && shape.length !== 2) {
20639 throw new Error('tensor2d() requires shape to have two numbers');
20640 }
20641 const inferredShape = inferShape(values, dtype);
20642 if (inferredShape.length !== 2 && inferredShape.length !== 1) {
20643 throw new Error('tensor2d() requires values to be number[][] or flat/TypedArray');
20644 }
20645 if (inferredShape.length === 1 && shape == null) {
20646 throw new Error('tensor2d() requires shape to be provided when `values` ' +
20647 'are a flat/TypedArray');
20648 }
20649 return makeTensor(values, shape, inferredShape, dtype);
20650 }
20651
20652 /**
20653 * @license
20654 * Copyright 2018 Google LLC. All Rights Reserved.
20655 * Licensed under the Apache License, Version 2.0 (the "License");
20656 * you may not use this file except in compliance with the License.
20657 * You may obtain a copy of the License at
20658 *
20659 * http://www.apache.org/licenses/LICENSE-2.0
20660 *
20661 * Unless required by applicable law or agreed to in writing, software
20662 * distributed under the License is distributed on an "AS IS" BASIS,
20663 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20664 * See the License for the specific language governing permissions and
20665 * limitations under the License.
20666 * =============================================================================
20667 */
20668 /**
20669 * Creates rank-4 `tf.Tensor` with the provided values, shape and dtype.
20670 *
20671 * The same functionality can be achieved with `tf.tensor`, but in general
20672 * we recommend using `tf.tensor4d` as it makes the code more readable.
20673 *
20674 * ```js
20675 * // Pass a nested array.
20676 * tf.tensor4d([[[[1], [2]], [[3], [4]]]]).print();
20677 * ```
20678 * ```js
20679 * // Pass a flat array and specify a shape.
20680 * tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]).print();
20681 * ```
20682 *
20683 * @param values The values of the tensor. Can be nested array of numbers,
20684 * or a flat array, or a `TypedArray`.
20685 * @param shape The shape of the tensor. Optional. If not provided,
20686 * it is inferred from `values`.
20687 * @param dtype The data type.
20688 *
20689 * @doc {heading: 'Tensors', subheading: 'Creation'}
20690 */
20691 function tensor4d(values, shape, dtype) {
20692 assertNonNull(values);
20693 if (shape != null && shape.length !== 4) {
20694 throw new Error('tensor4d() requires shape to have four numbers');
20695 }
20696 const inferredShape = inferShape(values, dtype);
20697 if (inferredShape.length !== 4 && inferredShape.length !== 1) {
20698 throw new Error('tensor4d() requires values to be number[][][][] or flat/TypedArray');
20699 }
20700 if (inferredShape.length === 1 && shape == null) {
20701 throw new Error('tensor4d() requires shape to be provided when `values` ' +
20702 'are a flat array');
20703 }
20704 return makeTensor(values, shape, inferredShape, dtype);
20705 }
20706
20707 /**
20708 * @license
20709 * Copyright 2018 Google LLC. All Rights Reserved.
20710 * Licensed under the Apache License, Version 2.0 (the "License");
20711 * you may not use this file except in compliance with the License.
20712 * You may obtain a copy of the License at
20713 *
20714 * http://www.apache.org/licenses/LICENSE-2.0
20715 *
20716 * Unless required by applicable law or agreed to in writing, software
20717 * distributed under the License is distributed on an "AS IS" BASIS,
20718 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20719 * See the License for the specific language governing permissions and
20720 * limitations under the License.
20721 * =============================================================================
20722 */
20723 /**
20724 * Creates rank-5 `tf.Tensor` with the provided values, shape and dtype.
20725 *
20726 * The same functionality can be achieved with `tf.tensor`, but in general
20727 * we recommend using `tf.tensor5d` as it makes the code more readable.
20728 *
20729 * ```js
20730 * // Pass a nested array.
20731 * tf.tensor5d([[[[[1],[2]],[[3],[4]]],[[[5],[6]],[[7],[8]]]]]).print();
20732 * ```
20733 * ```js
20734 * // Pass a flat array and specify a shape.
20735 * tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]).print();
20736 * ```
20737 *
20738 * @param values The values of the tensor. Can be nested array of numbers,
20739 * or a flat array, or a `TypedArray`.
20740 * @param shape The shape of the tensor. Optional. If not provided,
20741 * it is inferred from `values`.
20742 * @param dtype The data type.
20743 *
20744 * @doc {heading: 'Tensors', subheading: 'Creation'}
20745 */
20746 function tensor5d(values, shape, dtype) {
20747 assertNonNull(values);
20748 if (shape != null && shape.length !== 5) {
20749 throw new Error('tensor5d() requires shape to have five numbers');
20750 }
20751 const inferredShape = inferShape(values, dtype);
20752 if (inferredShape.length !== 5 && inferredShape.length !== 1) {
20753 throw new Error('tensor5d() requires values to be ' +
20754 'number[][][][][] or flat/TypedArray');
20755 }
20756 if (inferredShape.length === 1 && shape == null) {
20757 throw new Error('tensor5d() requires shape to be provided when `values` ' +
20758 'are a flat array');
20759 }
20760 return makeTensor(values, shape, inferredShape, dtype);
20761 }
20762
20763 /**
20764 * @license
20765 * Copyright 2018 Google LLC. All Rights Reserved.
20766 * Licensed under the Apache License, Version 2.0 (the "License");
20767 * you may not use this file except in compliance with the License.
20768 * You may obtain a copy of the License at
20769 *
20770 * http://www.apache.org/licenses/LICENSE-2.0
20771 *
20772 * Unless required by applicable law or agreed to in writing, software
20773 * distributed under the License is distributed on an "AS IS" BASIS,
20774 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20775 * See the License for the specific language governing permissions and
20776 * limitations under the License.
20777 * =============================================================================
20778 */
20779 /**
20780 * Creates rank-6 `tf.Tensor` with the provided values, shape and dtype.
20781 *
20782 * The same functionality can be achieved with `tf.tensor`, but in general
20783 * we recommend using `tf.tensor6d` as it makes the code more readable.
20784 *
20785 * ```js
20786 * // Pass a nested array.
20787 * tf.tensor6d([[[[[[1],[2]],[[3],[4]]],[[[5],[6]],[[7],[8]]]]]]).print();
20788 * ```
20789 * ```js
20790 * // Pass a flat array and specify a shape.
20791 * tf.tensor6d([1, 2, 3, 4, 5, 6, 7, 8], [1, 1, 2, 2, 2, 1]).print();
20792 * ```
20793 *
20794 * @param values The values of the tensor. Can be nested array of numbers,
20795 * or a flat array, or a `TypedArray`.
20796 * @param shape The shape of the tensor. Optional. If not provided,
20797 * it is inferred from `values`.
20798 * @param dtype The data type.
20799 *
20800 * @doc {heading: 'Tensors', subheading: 'Creation'}
20801 */
20802 function tensor6d(values, shape, dtype) {
20803 assertNonNull(values);
20804 if (shape != null && shape.length !== 6) {
20805 throw new Error('tensor6d() requires shape to have six numbers');
20806 }
20807 const inferredShape = inferShape(values, dtype);
20808 if (inferredShape.length !== 6 && inferredShape.length !== 1) {
20809 throw new Error('tensor6d() requires values to be number[][][][][][] or ' +
20810 'flat/TypedArray');
20811 }
20812 if (inferredShape.length === 1 && shape == null) {
20813 throw new Error('tensor6d() requires shape to be provided when `values` ' +
20814 'are a flat array');
20815 }
20816 shape = shape ||
20817 inferredShape;
20818 return makeTensor(values, shape, inferredShape, dtype);
20819 }
20820
20821 /**
20822 * @license
20823 * Copyright 2018 Google LLC. All Rights Reserved.
20824 * Licensed under the Apache License, Version 2.0 (the "License");
20825 * you may not use this file except in compliance with the License.
20826 * You may obtain a copy of the License at
20827 *
20828 * http://www.apache.org/licenses/LICENSE-2.0
20829 *
20830 * Unless required by applicable law or agreed to in writing, software
20831 * distributed under the License is distributed on an "AS IS" BASIS,
20832 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20833 * See the License for the specific language governing permissions and
20834 * limitations under the License.
20835 * =============================================================================
20836 */
20837 /**
20838 * Finds the values and indices of the `k` largest entries along the last
20839 * dimension.
20840 *
20841 * If the input is a vector (rank=1), finds the k largest entries in the vector
20842 * and outputs their values and indices as vectors. Thus values[j] is the j-th
20843 * largest entry in input, and its index is indices[j].
20844 * For higher rank inputs, computes the top k entries along the last dimension.
20845 *
20846 * If two elements are equal, the lower-index element appears first.
20847 *
20848 * ```js
20849 * const a = tf.tensor2d([[1, 5], [4, 3]]);
20850 * const {values, indices} = tf.topk(a);
20851 * values.print();
20852 * indices.print();
20853 * ```
20854 * @param x 1-D or higher `tf.Tensor` with last dimension being at least `k`.
20855 * @param k Number of top elements to look for along the last dimension.
20856 * @param sorted If true, the resulting `k` elements will be sorted by the
20857 * values in descending order.
20858 *
20859 * @doc {heading: 'Operations', subheading: 'Evaluation'}
20860 */
20861 function topk_(x, k = 1, sorted = true) {
20862 const $x = convertToTensor(x, 'x', 'topk');
20863 if ($x.rank === 0) {
20864 throw new Error('topk() expects the input to be of rank 1 or higher');
20865 }
20866 const lastDim = $x.shape[$x.shape.length - 1];
20867 if (k < 0) {
20868 throw new Error(`'k' passed to topk() must be >= 0 but got ${k}`);
20869 }
20870 if (k > lastDim) {
20871 throw new Error(`'k' passed to topk() must be <= the last dimension (${lastDim}) ` +
20872 `but got ${k}`);
20873 }
20874 const inputs = { x: $x };
20875 const attrs = { k, sorted };
20876 const [values, indices] = ENGINE.runKernel(TopK, inputs, attrs);
20877 return { values, indices };
20878 }
20879 const topk = op({ topk_ });
20880
20881 /**
20882 * @license
20883 * Copyright 2020 Google LLC. All Rights Reserved.
20884 * Licensed under the Apache License, Version 2.0 (the "License");
20885 * you may not use this file except in compliance with the License.
20886 * You may obtain a copy of the License at
20887 *
20888 * http://www.apache.org/licenses/LICENSE-2.0
20889 *
20890 * Unless required by applicable law or agreed to in writing, software
20891 * distributed under the License is distributed on an "AS IS" BASIS,
20892 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20893 * See the License for the specific language governing permissions and
20894 * limitations under the License.
20895 * =============================================================================
20896 */
20897 /**
20898 * Creates a `tf.Tensor` with values sampled from a truncated normal
20899 * distribution.
20900 *
20901 * ```js
20902 * tf.truncatedNormal([2, 2]).print();
20903 * ```
20904 *
20905 * The generated values follow a normal distribution with specified mean and
20906 * standard deviation, except that values whose magnitude is more than 2
20907 * standard deviations from the mean are dropped and re-picked.
20908 *
20909 * @param shape An array of integers defining the output tensor shape.
20910 * @param mean The mean of the normal distribution.
20911 * @param stdDev The standard deviation of the normal distribution.
20912 * @param dtype The data type of the output tensor.
20913 * @param seed The seed for the random number generator.
20914 *
20915 * @doc {heading: 'Tensors', subheading: 'Creation'}
20916 */
20917 function truncatedNormal_(shape, mean = 0, stdDev = 1, dtype, seed) {
20918 if (dtype != null && dtype === 'bool') {
20919 throw new Error(`Unsupported data type $ { dtype }`);
20920 }
20921 const randGauss = new MPRandGauss(mean, stdDev, dtype, true /* truncated */, seed);
20922 const res = buffer(shape, dtype);
20923 for (let i = 0; i < res.values.length; i++) {
20924 res.values[i] = randGauss.nextValue();
20925 }
20926 return res.toTensor();
20927 }
    // Public op-wrapped entry point for truncatedNormal_.
    const truncatedNormal = op({ truncatedNormal_ });
20929
20930 /**
20931 * @license
20932 * Copyright 2020 Google LLC. All Rights Reserved.
20933 * Licensed under the Apache License, Version 2.0 (the "License");
20934 * you may not use this file except in compliance with the License.
20935 * You may obtain a copy of the License at
20936 *
20937 * http://www.apache.org/licenses/LICENSE-2.0
20938 *
20939 * Unless required by applicable law or agreed to in writing, software
20940 * distributed under the License is distributed on an "AS IS" BASIS,
20941 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20942 * See the License for the specific language governing permissions and
20943 * limitations under the License.
20944 * =============================================================================
20945 */
20946 /**
20947 * Finds unique elements along an axis of a tensor.
20948 *
20949 * It returns a tensor `values` containing all of the unique elements along the
20950 * `axis` of the given tensor `x` in the same order that they occur along the
20951 * `axis` in `x`; `x` does not need to be sorted. It also returns a tensor
20952 * `indices` the same size as the number of the elements in `x` along the `axis`
20953 * dimension. It contains the index in the unique output `values`.
20954 *
20955 * ```js
20956 * // A 1-D tensor
20957 * const a = tf.tensor1d([1, 1, 2, 4, 4, 4, 7, 8, 8]);
20958 * const {values, indices} = tf.unique(a);
20959 * values.print(); // [1, 2, 4, 7, 8,]
20960 * indices.print(); // [0, 0, 1, 2, 2, 2, 3, 4, 4]
20961 * ```
20962 *
20963 * ```js
20964 * // A 2-D tensor with axis=0
20965 * //
20966 * // 'a' is: [[1, 0, 0],
20967 * // [1, 0, 0],
20968 * // [2, 0, 0]]
20969 * const a = tf.tensor2d([[1, 0, 0], [1, 0, 0], [2, 0, 0]]);
20970 * const {values, indices} = tf.unique(a, 0)
20971 * values.print(); // [[1, 0, 0],
20972 * // [2, 0, 0]]
20973 * indices.print(); // [0, 0, 1]
20974 * ```
20975 *
20976 * ```js
20977 * // A 2-D tensor with axis=1
20978 * //
20979 * // 'a' is: [[1, 0, 0],
20980 * // [1, 0, 0],
20981 * // [2, 0, 0]]
20982 * const a = tf.tensor2d([[1, 0, 0], [1, 0, 0], [2, 0, 0]]);
20983 * const {values, indices} = tf.unique(a, 1)
20984 * values.print(); // [[1, 0],
20985 * // [1, 0],
20986 * // [2, 0]]
20987 * indices.print(); // [0, 1, 1]
20988 * ```
20989 * @param x A tensor (int32, string, bool).
20990 * @param axis The axis of the tensor to find the unique elements.
20991 * @returns [uniqueElements, indices] (see above for details)
20992 *
20993 * @doc {heading: 'Operations', subheading: 'Evaluation'}
20994 */
20995 function unique_(x, axis = 0) {
20996 const $x = convertToTensor(x, 'x', 'unique', 'string_or_numeric');
20997 assert($x.rank > 0, () => 'The input tensor must be at least 1D');
20998 const inputs = { x: $x };
20999 const attrs = { axis };
21000 const [values, indices] = ENGINE.runKernel(Unique, inputs, attrs);
21001 return { values, indices };
21002 }
21003 const unique = op({ unique_ });
21004
21005 /**
21006 * @license
21007 * Copyright 2020 Google LLC. All Rights Reserved.
21008 * Licensed under the Apache License, Version 2.0 (the "License");
21009 * you may not use this file except in compliance with the License.
21010 * You may obtain a copy of the License at
21011 *
21012 * http://www.apache.org/licenses/LICENSE-2.0
21013 *
21014 * Unless required by applicable law or agreed to in writing, software
21015 * distributed under the License is distributed on an "AS IS" BASIS,
21016 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21017 * See the License for the specific language governing permissions and
21018 * limitations under the License.
21019 * =============================================================================
21020 */
21021 /**
21022 * Computes the sum along segments of a `tf.Tensor`.
21023 *
21024 * ```js
21025 * const x = tf.tensor1d([1, 2, 3, 4]);
21026 * const segmentIds = tf.tensor1d([1, 2, 0, 1], 'int32');
21027 * const numSegments = 3;
21028 *
21029 * x.unsortedSegmentSum(segmentIds, numSegments).print()
21030 * //or tf.unsortedSegmentSum(x, segmentIds, numSegments)
21031 * ```
21032 * @param x The `tf.Tensor` that will be summed along its segments.
21033 * @param segmentIds A `tf.Tensor1D` whose rank is equal to the rank of `x`'s
21034 * dimension along the `axis`. Maps each element of `x` to a segment.
21035 * @param numSegments The number of distinct `segmentIds`.
21036 *
21037 * @doc {heading: 'Operations', subheading: 'Segment'}
21038 */
21039 function unsortedSegmentSum_(x, segmentIds, numSegments) {
21040 const $x = convertToTensor(x, 'x', 'unsortedSegmentSum');
21041 const $segmentIds = convertToTensor(segmentIds, 'segmentIds', 'unsortedSegmentSum', 'int32');
21042 assert(isInt(numSegments), () => 'numSegments must be of dtype int');
21043 const inputs = { x: $x, segmentIds: $segmentIds };
21044 const attrs = { numSegments };
21045 return ENGINE.runKernel(UnsortedSegmentSum, inputs, attrs);
21046 }
21047 const unsortedSegmentSum = op({ unsortedSegmentSum_ });
21048
21049 /**
21050 * @license
21051 * Copyright 2020 Google LLC. All Rights Reserved.
21052 * Licensed under the Apache License, Version 2.0 (the "License");
21053 * you may not use this file except in compliance with the License.
21054 * You may obtain a copy of the License at
21055 *
21056 * http://www.apache.org/licenses/LICENSE-2.0
21057 *
21058 * Unless required by applicable law or agreed to in writing, software
21059 * distributed under the License is distributed on an "AS IS" BASIS,
21060 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21061 * See the License for the specific language governing permissions and
21062 * limitations under the License.
21063 * =============================================================================
21064 */
21065 /**
21066 * Unstacks a `tf.Tensor` of rank-`R` into a list of rank-`(R-1)` `tf.Tensor`s.
21067 *
21068 * ```js
21069 * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
21070 *
21071 * tf.unstack(a).forEach(tensor => tensor.print());
21072 * ```
21073 *
21074 * @param x A tensor object.
21075 * @param axis The axis to unstack along. Defaults to 0 (the first dim).
21076 *
21077 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
21078 */
21079 function unstack_(x, axis = 0) {
21080 const $x = convertToTensor(x, 'x', 'unstack', 'string_or_numeric');
21081 assert(axis >= -$x.shape.length && axis < $x.shape.length, () => `Axis = ${axis} is not in [-${$x.shape.length}, ${$x.shape.length})`);
21082 const inputs = { value: $x };
21083 const attrs = { axis };
21084 return ENGINE.runKernel(Unpack, inputs, attrs);
21085 }
21086 const unstack = op({ unstack_ });
21087
21088 /**
21089 * @license
21090 * Copyright 2022 Google LLC. All Rights Reserved.
21091 * Licensed under the Apache License, Version 2.0 (the "License");
21092 * you may not use this file except in compliance with the License.
21093 * You may obtain a copy of the License at
21094 *
21095 * http://www.apache.org/licenses/LICENSE-2.0
21096 *
21097 * Unless required by applicable law or agreed to in writing, software
21098 * distributed under the License is distributed on an "AS IS" BASIS,
21099 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21100 * See the License for the specific language governing permissions and
21101 * limitations under the License.
21102 * =============================================================================
21103 */
21104 /**
21105 * Searches for where a value would go in a sorted sequence.
21106 *
21107 * This is not a method for checking containment (like javascript in).
21108 *
21109 * The typical use case for this operation is "binning", "bucketing", or
21110 * "discretizing". The values are assigned to bucket-indices based on the edges
21111 * listed in 'sortedSequence'. This operation returns the bucket-index for each
21112 * value.
21113 *
21114 * The index returned corresponds to the first edge greater than the value.
21115 *
21116 * The axis is not settable for this operation. It always operates on the
21117 * innermost dimension (axis=-1). The operation will accept any number of outer
21118 * dimensions.
21119 *
21120 * Note: This operation assumes that 'upperBound' is sorted along the
21121 * innermost axis, maybe using 'sort(..., axis=-1)'. If the sequence is not
21122 * sorted no error is raised and the content of the returned tensor is not well
21123 * defined.
21124 *
21125 * ```js
21126 * const seq = tf.tensor1d([0, 3, 9, 10, 10]);
21127 * const values = tf.tensor1d([0, 4, 10]);
21128 * const result = tf.upperBound(seq, values);
21129 * result.print(); // [1, 2, 5]
21130 * ```
21131 * @param sortedSequence: N-D. Sorted sequence.
21132 * @param values: N-D. Search values.
21133 * @return An N-D int32 tensor the size of values containing the result of
21134 * applying upper bound to each value. The result is not a global index to
21135 * the entire Tensor, but the index in the last dimension.
21136 * @doc {heading: 'Operations', subheading: 'Evaluation'}
21137 */
21138 function upperBound(sortedSequence, values) {
21139 return searchSorted(sortedSequence, values, 'right');
21140 }
21141
21142 /**
21143 * @license
21144 * Copyright 2018 Google LLC. All Rights Reserved.
21145 * Licensed under the Apache License, Version 2.0 (the "License");
21146 * you may not use this file except in compliance with the License.
21147 * You may obtain a copy of the License at
21148 *
21149 * http://www.apache.org/licenses/LICENSE-2.0
21150 *
21151 * Unless required by applicable law or agreed to in writing, software
21152 * distributed under the License is distributed on an "AS IS" BASIS,
21153 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21154 * See the License for the specific language governing permissions and
21155 * limitations under the License.
21156 * =============================================================================
21157 */
21158 /**
21159 * Creates a new variable with the provided initial value.
21160 * ```js
21161 * const x = tf.variable(tf.tensor([1, 2, 3]));
21162 * x.assign(tf.tensor([4, 5, 6]));
21163 *
21164 * x.print();
21165 * ```
21166 *
21167 * @param initialValue Initial value for the tensor.
21168 * @param trainable If true, optimizers are allowed to update it.
21169 * @param name Name of the variable. Defaults to a unique id.
21170 * @param dtype If set, initialValue will be converted to the given type.
21171 *
21172 * @doc {heading: 'Tensors', subheading: 'Creation'}
21173 */
    function variable(initialValue, trainable = true, name, dtype) {
        // Thin wrapper: variable creation is owned by the global engine.
        return ENGINE.makeVariable(initialValue, trainable, name, dtype);
    }
21177
21178 /**
21179 * @license
21180 * Copyright 2018 Google LLC. All Rights Reserved.
21181 * Licensed under the Apache License, Version 2.0 (the "License");
21182 * you may not use this file except in compliance with the License.
21183 * You may obtain a copy of the License at
21184 *
21185 * http://www.apache.org/licenses/LICENSE-2.0
21186 *
21187 * Unless required by applicable law or agreed to in writing, software
21188 * distributed under the License is distributed on an "AS IS" BASIS,
21189 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21190 * See the License for the specific language governing permissions and
21191 * limitations under the License.
21192 * =============================================================================
21193 */
21194 function whereImpl(condShape, condVals) {
21195 const indices = [];
21196 for (let i = 0; i < condVals.length; i++) {
21197 if (condVals[i]) {
21198 indices.push(i);
21199 }
21200 }
21201 const inBuffer = buffer(condShape, 'int32');
21202 const out = buffer([indices.length, condShape.length], 'int32');
21203 for (let i = 0; i < indices.length; i++) {
21204 const loc = inBuffer.indexToLoc(indices[i]);
21205 const offset = i * condShape.length;
21206 out.values.set(loc, offset);
21207 }
21208 return out.toTensor();
21209 }
21210
21211 /**
21212 * @license
21213 * Copyright 2020 Google LLC. All Rights Reserved.
21214 * Licensed under the Apache License, Version 2.0 (the "License");
21215 * you may not use this file except in compliance with the License.
21216 * You may obtain a copy of the License at
21217 *
21218 * http://www.apache.org/licenses/LICENSE-2.0
21219 *
21220 * Unless required by applicable law or agreed to in writing, software
21221 * distributed under the License is distributed on an "AS IS" BASIS,
21222 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21223 * See the License for the specific language governing permissions and
21224 * limitations under the License.
21225 * =============================================================================
21226 */
21227 /**
21228 * Returns the coordinates of true elements of condition.
21229 *
21230 * The coordinates are returned in a 2-D tensor where the first dimension (rows)
21231 * represents the number of true elements, and the second dimension (columns)
21232 * represents the coordinates of the true elements. Keep in mind, the shape of
21233 * the output tensor can vary depending on how many true values there are in
21234 * input. Indices are output in row-major order. The resulting tensor has the
21235 * shape `[numTrueElems, condition.rank]`.
21236 *
21237 * This is analogous to calling the python `tf.where(cond)` without an x or y.
21238 *
21239 * ```js
21240 * const cond = tf.tensor1d([false, false, true], 'bool');
21241 * const result = await tf.whereAsync(cond);
21242 * result.print();
21243 * ```
21244 *
21245 * @doc {heading: 'Operations', subheading: 'Logical'}
21246 */
21247 async function whereAsync_(condition) {
21248 const $condition = convertToTensor(condition, 'condition', 'whereAsync', 'bool');
21249 const vals = await $condition.data();
21250 const res = whereImpl($condition.shape, vals);
21251 if (condition !== $condition) {
21252 $condition.dispose();
21253 }
21254 return res;
21255 }
21256 const whereAsync = whereAsync_;
21257
21258 /**
21259 * @license
21260 * Copyright 2018 Google LLC. All Rights Reserved.
21261 * Licensed under the Apache License, Version 2.0 (the "License");
21262 * you may not use this file except in compliance with the License.
21263 * You may obtain a copy of the License at
21264 *
21265 * http://www.apache.org/licenses/LICENSE-2.0
21266 *
21267 * Unless required by applicable law or agreed to in writing, software
21268 * distributed under the License is distributed on an "AS IS" BASIS,
21269 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21270 * See the License for the specific language governing permissions and
21271 * limitations under the License.
21272 * =============================================================================
21273 */
21274 /**
21275 * Apply boolean mask to tensor.
21276 *
21277 * ```js
21278 * const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
21279 * const mask = tf.tensor1d([1, 0, 1], 'bool');
21280 * const result = await tf.booleanMaskAsync(tensor, mask);
21281 * result.print();
21282 * ```
21283 *
21284 * @param tensor N-D tensor.
21285 * @param mask K-D boolean tensor, K <= N and K must be known statically.
21286 * @param axis A 0-D int Tensor representing the axis in tensor to mask from.
21287 * By default, axis is 0 which will mask from the first dimension.
21288 * Otherwise K + axis <= N.
21289 *
21290 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
21291 */
    async function booleanMaskAsync_(tensor, mask, axis) {
        const $tensor = convertToTensor(tensor, 'tensor', 'boolMask');
        const $mask = convertToTensor(mask, 'mask', 'boolMask', 'bool');
        // Masking covers the K mask dimensions starting at `axisFrom`.
        const axisFrom = axis == null ? 0 : axis;
        const maskDim = $mask.rank;
        const tensorShape = $tensor.shape;
        assert(maskDim > 0, () => 'mask cannot be scalar');
        assertShapesMatch(tensorShape.slice(axisFrom, axisFrom + maskDim), $mask.shape, `mask's shape must match the first K dimensions of tensor's shape,`);
        // Total element count across the masked dimensions.
        let leadingSize = 1;
        for (let i = axisFrom; i < axisFrom + maskDim; i++) {
            leadingSize *= tensorShape[i];
        }
        // Collapse the K masked dimensions into one so a flat index gather can
        // select the kept slices.
        const targetTensorShape = tensorShape.slice(0, axisFrom)
            .concat([leadingSize], tensorShape.slice(axisFrom + maskDim));
        const reshapedTensor = reshape($tensor, targetTensorShape);
        const reshapedMask = reshape($mask, [-1]);
        // Asynchronously find the flat positions where the mask is true;
        // whereAsync yields a [numTrue, 1] tensor of coordinates.
        const positivePositions = await whereAsync(reshapedMask);
        const indices = squeeze(positivePositions, [1]);
        const res = gather(reshapedTensor, indices, axisFrom);
        // Ensure no memory leak.
        if (tensor !== $tensor) {
            $tensor.dispose();
        }
        if (mask !== $mask) {
            $mask.dispose();
        }
        indices.dispose();
        reshapedTensor.dispose();
        reshapedMask.dispose();
        positivePositions.dispose();
        return res;
    }
    const booleanMaskAsync = booleanMaskAsync_;
21325
21326 /**
21327 * @license
21328 * Copyright 2018 Google LLC. All Rights Reserved.
21329 * Licensed under the Apache License, Version 2.0 (the "License");
21330 * you may not use this file except in compliance with the License.
21331 * You may obtain a copy of the License at
21332 *
21333 * http://www.apache.org/licenses/LICENSE-2.0
21334 *
21335 * Unless required by applicable law or agreed to in writing, software
21336 * distributed under the License is distributed on an "AS IS" BASIS,
21337 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21338 * See the License for the specific language governing permissions and
21339 * limitations under the License.
21340 * =============================================================================
21341 */
/**
 * Compute the moving average of a variable.
 *
 * Without zeroDebias, the update is `v += (1 - decay) * (x - v)`. With
 * zeroDebias (the default), the delta term is additionally divided by
 * `1 - decay ^ step` to correct for the (assumed) zero-initialization of `v`.
 * For details on the zero-debiasing algorithm, see:
 * https://arxiv.org/abs/1412.6980
 *
 * This function is completely stateless: it does not track the step count.
 * The caller maintains the step count and passes it in as `step`.
 *
 * @param v The current moving average value.
 * @param x New input value, must have the same shape and dtype as `v`.
 * @param decay The decay factor. Typical values are 0.95 and 0.99.
 * @param step Step count; required when `zeroDebias` is true.
 * @param zeroDebias: Whether zeroDebias is to be performed (default: `true`).
 * @returns The new moving average value.
 *
 * @doc {heading: 'Operations', subheading: 'Moving Average'}
 */
function movingAverage_(v, x, decay, step, zeroDebias = true) {
    const $v = convertToTensor(v, 'v', 'movingAverage');
    const $x = convertToTensor(x, 'x', 'movingAverage');
    const $decay = convertToTensor(decay, 'decay', 'movingAverage');
    assertTypesMatch($v, $x);
    assert(arraysEqual($v.shape, $x.shape), () => 'Shape mismatch in v and x');
    const unity = scalar(1);
    const oneMinusDecay = sub(unity, $decay);
    // delta = (1 - decay) * (x - v)
    let delta = mul(sub($x, $v), oneMinusDecay);
    if (zeroDebias) {
        assert(step != null, () => 'When using zeroDebias: true, step is required.');
        const $step = convertToTensor(step, 'step', 'movingAverage');
        // Scale the delta by 1 / (1 - decay^step) to debias the
        // zero-initialization of v.
        const debiasFactor = sub(unity, pow($decay, $step));
        delta = div(delta, debiasFactor);
    }
    return add$1($v, delta);
}
const movingAverage = op({ movingAverage_ });
21387
21388 /**
21389 * @license
21390 * Copyright 2018 Google LLC. All Rights Reserved.
21391 * Licensed under the Apache License, Version 2.0 (the "License");
21392 * you may not use this file except in compliance with the License.
21393 * You may obtain a copy of the License at
21394 *
21395 * http://www.apache.org/licenses/LICENSE-2.0
21396 *
21397 * Unless required by applicable law or agreed to in writing, software
21398 * distributed under the License is distributed on an "AS IS" BASIS,
21399 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21400 * See the License for the specific language governing permissions and
21401 * limitations under the License.
21402 * =============================================================================
21403 */
/**
 * Creates a new tensor by applying sparse updates to individual
 * values or slices within a zero tensor of the given shape, according to
 * `indices`. This operator is the inverse of the `tf.gatherND` operator,
 * which extracts values or slices from a given tensor.
 *
 * ```js
 * const indices = tf.tensor2d([4, 3, 1, 7], [4, 1], 'int32');
 * const updates = tf.tensor1d([9, 10, 11, 12]);
 * const shape = [8];
 * tf.scatterND(indices, updates, shape).print() //[0, 11, 0, 10, 9, 0, 0, 12]
 * ```
 *
 * @param indices The tensor containing the indices into the output tensor.
 * @param updates The tensor containing the values for the indices.
 * @param shape: The shape of the output tensor.
 *
 * @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
 */
function scatterND_(indices, updates, shape) {
    const $indices = convertToTensor(indices, 'indices', 'scatterND', 'int32');
    const $updates = convertToTensor(updates, 'updates', 'scatterND');
    // Validate shapes/dtypes before handing off to the kernel.
    validateInput($updates, $indices, shape);
    // tslint:disable-next-line: no-unnecessary-type-assertion
    return ENGINE.runKernel(ScatterNd, { indices: $indices, updates: $updates }, { shape });
}
const scatterND = op({ scatterND_ });
21433
/**
 * Validate sparseToDense inputs.
 *
 * @param sparseIndices A 0-D, 1-D, or 2-D Tensor of type int32.
 * sparseIndices[i] contains the complete index where sparseValues[i] will be
 * placed.
 * @param sparseValues A 0-D or 1-D Tensor. Values
 * corresponding to each row of sparseIndices, or a scalar value to be used for
 * all sparse indices.
 * @param outputShape number[]. Shape of the dense output tensor.
 * @param defaultValues Scalar tensor whose dtype must match the dtype of
 * sparseValues.
 * @throws Error if any of the constraints above is violated.
 */
function validateInput$1(sparseIndices, sparseValues, outputShape, defaultValues) {
    if (sparseIndices.dtype !== 'int32') {
        throw new Error('tf.sparseToDense() expects the indices to be int32 type,' +
            ` but the dtype was ${sparseIndices.dtype}.`);
    }
    if (sparseIndices.rank > 2) {
        throw new Error('sparseIndices should be a scalar, vector, or matrix,' +
            ` but got shape ${sparseIndices.shape}.`);
    }
    // A scalar index encodes one 1-D position; a vector encodes a list of 1-D
    // positions; an [n, d] matrix encodes n positions of d dimensions each.
    const numElems = sparseIndices.rank > 0 ? sparseIndices.shape[0] : 1;
    const numDims = sparseIndices.rank > 1 ? sparseIndices.shape[1] : 1;
    if (outputShape.length !== numDims) {
        // Fixed message typo: was 'elements:,' which rendered as "elements:, 5".
        throw new Error('outputShape has incorrect number of elements:' +
            ` ${outputShape.length}, should be: ${numDims}.`);
    }
    const numValues = sparseValues.size;
    if (!(sparseValues.rank === 0 ||
        sparseValues.rank === 1 && numValues === numElems)) {
        throw new Error('sparseValues has incorrect shape ' +
            `${sparseValues.shape}, should be [] or [${numElems}]`);
    }
    if (sparseValues.dtype !== defaultValues.dtype) {
        throw new Error('sparseValues.dtype must match defaultValues.dtype');
    }
}
21472
21473 /**
21474 * @license
21475 * Copyright 2018 Google LLC. All Rights Reserved.
21476 * Licensed under the Apache License, Version 2.0 (the "License");
21477 * you may not use this file except in compliance with the License.
21478 * You may obtain a copy of the License at
21479 *
21480 * http://www.apache.org/licenses/LICENSE-2.0
21481 *
21482 * Unless required by applicable law or agreed to in writing, software
21483 * distributed under the License is distributed on an "AS IS" BASIS,
21484 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21485 * See the License for the specific language governing permissions and
21486 * limitations under the License.
21487 * =============================================================================
21488 */
/**
 * Converts a sparse representation into a dense tensor.
 *
 * Builds a dense array of shape `outputShape` such that:
 *
 * // If sparseIndices is scalar
 * dense[i] = (i == sparseIndices ? sparseValues : defaultValue)
 *
 * // If sparseIndices is a vector, then for each i
 * dense[sparseIndices[i]] = sparseValues[i]
 *
 * // If sparseIndices is an n by d matrix, then for each i in [0, n)
 * dense[sparseIndices[i][0], ..., sparseIndices[i][d-1]] = sparseValues[i]
 *
 * All other values in dense are set to defaultValue. If sparseValues is a
 * scalar, all sparse indices are set to this single value. If indices are
 * repeated, the final value is summed over all values for those indices.
 *
 * ```js
 * const indices = tf.tensor1d([4, 5, 6, 1, 2, 3], 'int32');
 * const values = tf.tensor1d([10, 11, 12, 13, 14, 15], 'float32');
 * const shape = [8];
 * tf.sparseToDense(indices, values, shape).print();
 * ```
 *
 * @param sparseIndices A 0-D, 1-D, or 2-D Tensor of type int32.
 * sparseIndices[i] contains the complete index where sparseValues[i] will be
 * placed.
 * @param sparseValues A 0-D or 1-D Tensor. Values
 * corresponding to each row of sparseIndices, or a scalar value to be used for
 * all sparse indices.
 * @param outputShape Shape of the dense output tensor. the type is inferred.
 * @param defaultValue Scalar. Value to set for indices not specified in
 * sparseIndices. Defaults to zero.
 *
 * @doc {heading: 'Operations', subheading: 'Normalization'}
 */
function sparseToDense_(sparseIndices, sparseValues, outputShape, defaultValue = 0) {
    const $sparseIndices = convertToTensor(sparseIndices, 'sparseIndices', 'sparseToDense', 'int32');
    const $sparseValues = convertToTensor(sparseValues, 'sparseValues', 'sparseToDense', 'string_or_numeric');
    // The default value must share the values' dtype.
    const $defaultValue = convertToTensor(defaultValue, 'defaultValue', 'sparseToDense', $sparseValues.dtype);
    validateInput$1($sparseIndices, $sparseValues, outputShape, $defaultValue);
    return ENGINE.runKernel(SparseToDense, {
        sparseIndices: $sparseIndices,
        sparseValues: $sparseValues,
        defaultValue: $defaultValue
    }, { outputShape });
}
const sparseToDense = op({ sparseToDense_ });
21541
21542 /**
21543 * @license
21544 * Copyright 2018 Google LLC. All Rights Reserved.
21545 * Licensed under the Apache License, Version 2.0 (the "License");
21546 * you may not use this file except in compliance with the License.
21547 * You may obtain a copy of the License at
21548 *
21549 * http://www.apache.org/licenses/LICENSE-2.0
21550 *
21551 * Unless required by applicable law or agreed to in writing, software
21552 * distributed under the License is distributed on an "AS IS" BASIS,
21553 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21554 * See the License for the specific language governing permissions and
21555 * limitations under the License.
21556 * =============================================================================
21557 */
/**
 * Gather slices from input tensor into a Tensor with shape specified by
 * `indices`.
 *
 * `indices` is a K-dimensional integer tensor, best thought of as a
 * (K-1)-dimensional tensor of indices into input, where each element defines a
 * slice of input:
 * output[\\(i_0, ..., i_{K-2}\\)] = input[indices[\\(i_0, ..., i_{K-2}\\)]]
 *
 * Whereas in `tf.gather`, `indices` defines slices into the first dimension of
 * input, in `tf.gatherND`, `indices` defines slices into the first N dimensions
 * of input, where N = indices.shape[-1]. The last dimension of `indices` can
 * be at most the rank of input: indices.shape[-1] <= input.rank.
 *
 * The last dimension of `indices` selects elements
 * (if indices.shape[-1] == input.rank) or slices
 * (if indices.shape[-1] < input.rank) along dimension indices.shape[-1] of
 * input. The output tensor has shape
 * indices.shape[:-1] + input.shape[indices.shape[-1]:]
 *
 * Note that on CPU, if an out of bound index is found, an error is returned.
 * On GPU, if an out of bound index is found, a 0 is stored in the
 * corresponding output value.
 *
 * ```js
 * const indices = tf.tensor2d([0, 1, 1, 0], [2,2], 'int32');
 * const input = tf.tensor2d([9, 10, 11, 12], [2, 2]);
 * tf.gatherND(input, indices).print() // [10, 11]
 * ```
 *
 * @param x The tensor from which to gather values.
 * @param indices Index tensor, must be of type int32.
 *
 * @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
 */
function gatherND_(x, indices) {
    const $indices = convertToTensor(indices, 'indices', 'gatherND', 'int32');
    const $x = convertToTensor(x, 'x', 'gatherND', 'string_or_numeric');
    return ENGINE.runKernel(GatherNd, { params: $x, indices: $indices });
}
const gatherND = op({ gatherND_ });
21603
21604 /**
21605 * @license
21606 * Copyright 2019 Google LLC. All Rights Reserved.
21607 * Licensed under the Apache License, Version 2.0 (the "License");
21608 * you may not use this file except in compliance with the License.
21609 * You may obtain a copy of the License at
21610 *
21611 * http://www.apache.org/licenses/LICENSE-2.0
21612 *
21613 * Unless required by applicable law or agreed to in writing, software
21614 * distributed under the License is distributed on an "AS IS" BASIS,
21615 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21616 * See the License for the specific language governing permissions and
21617 * limitations under the License.
21618 * =============================================================================
21619 */
/**
 * Normalize noise shape based on the provided tensor and noise shape.
 *
 * @param x Tensor.
 * @param noiseShape The shape for the randomly generated keep/drop flags, as
 *     an array of numbers. Optional.
 * @returns Normalized noise shape: a copy of `x.shape` when `noiseShape` is
 *     absent; otherwise `noiseShape`, with each null entry replaced by the
 *     corresponding dimension of `x` when the ranks match.
 */
function getNoiseShape(x, noiseShape) {
    if (noiseShape == null) {
        return x.shape.slice();
    }
    if (arraysEqual(x.shape, noiseShape)) {
        return noiseShape;
    }
    if (x.shape.length === noiseShape.length) {
        // Fill each null "hole" in the noise shape with x's dimension size.
        return x.shape.map((dim, i) => noiseShape[i] == null && dim != null ? dim : noiseShape[i]);
    }
    return noiseShape;
}
21649
21650 /**
21651 * @license
21652 * Copyright 2018 Google LLC. All Rights Reserved.
21653 * Licensed under the Apache License, Version 2.0 (the "License");
21654 * you may not use this file except in compliance with the License.
21655 * You may obtain a copy of the License at
21656 *
21657 * http://www.apache.org/licenses/LICENSE-2.0
21658 *
21659 * Unless required by applicable law or agreed to in writing, software
21660 * distributed under the License is distributed on an "AS IS" BASIS,
21661 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21662 * See the License for the specific language governing permissions and
21663 * limitations under the License.
21664 * =============================================================================
21665 */
/**
 * Computes dropout.
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 2, 1]);
 * const rate = 0.75;
 * const output = tf.dropout(x, rate);
 * output.print();
 * ```
 *
 * @param x A floating point Tensor or TensorLike.
 * @param rate A float in the range [0, 1). The probability that each element
 *     of x is discarded.
 * @param noiseShape An array of numbers of type int32, representing the
 *     shape for randomly generated keep/drop flags. Null entries are replaced
 *     with x's corresponding dimension size. Optional.
 * @param seed Used to create random seeds. Optional.
 * @returns A Tensor of the same shape of x.
 *
 * @doc {heading: 'Operations', subheading: 'Dropout'}
 */
function dropout_(x, rate, noiseShape, seed) {
    const $x = convertToTensor(x, 'x', 'dropout');
    assert($x.dtype === 'float32', () => `x has to be a floating point tensor since it's going to be ` +
        `scaled, but got a ${$x.dtype} tensor instead.`);
    assert(rate >= 0 && rate < 1, () => `rate must be a float in the range [0, 1), but got ${rate}.`);
    // rate 0 keeps everything; return a clone (or the converted tensor).
    if (rate === 0) {
        return x instanceof Tensor ? $x.clone() : $x;
    }
    const $noiseShape = getNoiseShape($x, noiseShape);
    const keepProb = 1 - rate;
    // floor(uniform[0,1) + keepProb) yields a 0/1 keep mask; dividing by
    // keepProb rescales kept activations (inverted dropout).
    const randomTensor = randomUniform($noiseShape, 0, 1, 'float32', seed);
    const multiplier = div(floor(add$1(randomTensor, keepProb)), keepProb);
    return mul($x, multiplier);
}
const dropout = op({ dropout_ });
21702
21703 /**
21704 * @license
21705 * Copyright 2019 Google LLC. All Rights Reserved.
21706 * Licensed under the Apache License, Version 2.0 (the "License");
21707 * you may not use this file except in compliance with the License.
21708 * You may obtain a copy of the License at
21709 *
21710 * http://www.apache.org/licenses/LICENSE-2.0
21711 *
21712 * Unless required by applicable law or agreed to in writing, software
21713 * distributed under the License is distributed on an "AS IS" BASIS,
21714 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21715 * See the License for the specific language governing permissions and
21716 * limitations under the License.
21717 * =============================================================================
21718 */
/** Returns the smallest 2**N (integer N) such that 2**N >= value. */
function enclosingPowerOfTwo(value) {
    const exponent = Math.ceil(Math.log(value) / Math.log(2.0));
    return Math.floor(Math.pow(2, exponent));
}
/**
 * Builds a length-`windowLength` cosine window tensor with values
 * a - b * cos(2*pi*i / (windowLength + even - 1)), where `even` corrects the
 * denominator for even-length windows.
 */
function cosineWindow(windowLength, a, b) {
    const even = 1 - windowLength % 2;
    const values = new Float32Array(windowLength);
    // Loop-invariant denominator of the cosine argument.
    const denom = windowLength + even - 1;
    for (let i = 0; i < windowLength; ++i) {
        values[i] = a - b * Math.cos((2.0 * Math.PI * i) / denom);
    }
    return tensor1d(values, 'float32');
}
21732
21733 /**
21734 * @license
21735 * Copyright 2019 Google LLC. All Rights Reserved.
21736 * Licensed under the Apache License, Version 2.0 (the "License");
21737 * you may not use this file except in compliance with the License.
21738 * You may obtain a copy of the License at
21739 *
21740 * http://www.apache.org/licenses/LICENSE-2.0
21741 *
21742 * Unless required by applicable law or agreed to in writing, software
21743 * distributed under the License is distributed on an "AS IS" BASIS,
21744 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21745 * See the License for the specific language governing permissions and
21746 * limitations under the License.
21747 * =============================================================================
21748 */
21749 /**
21750 * Returns whether the targets are in the top K predictions.
21751 *
21752 * ```js
21753 * const predictions = tf.tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]);
21754 * const targets = tf.tensor1d([2, 0]);
21755 * const precision = await tf.inTopKAsync(predictions, targets);
21756 * precision.print();
21757 * ```
21758 * @param predictions 2-D or higher `tf.Tensor` with last dimension being
21759 * at least `k`.
21760 * @param targets 1-D or higher `tf.Tensor`.
21761 * @param k Optional Number of top elements to look at for computing precision,
21762 * default to 1.
21763 *
21764 * @doc {heading: 'Operations', subheading: 'Evaluation'}
21765 */
async function inTopKAsync_(predictions, targets, k = 1) {
    const $predictions = convertToTensor(predictions, 'predictions', 'inTopK');
    const $targets = convertToTensor(targets, 'targets', 'inTopK');
    // `predictions` must carry exactly one extra (score) dimension relative to
    // `targets`, and the leading dimensions must agree element-wise.
    assert($predictions.rank > 1, () => 'inTopK() expects the predictions to be of rank 2 or higher, ' +
        `but got ${$predictions.rank}`);
    assert($predictions.rank - 1 === $targets.rank, () => `predictions rank should be 1 larger than ` +
        `targets rank, but got predictions rank ` +
        `${$predictions.rank} and targets rank ${$targets.rank}`);
    assertShapesMatch($predictions.shape.slice(0, $predictions.shape.length - 1), $targets.shape, `predictions's shape should be align with the targets' shape, ` +
        'except the last dimension.');
    const lastDim = $predictions.shape[$predictions.shape.length - 1];
    assert(k > 0 && k <= lastDim, () => `'k' passed to inTopK() must be > 0 && <= the predictions last ` +
        `dimension (${lastDim}), but got ${k}`);
    // Download both tensors to the CPU; the top-k lookup below runs in plain JS.
    const predictionsVals = await $predictions.data();
    const targetsVals = await $targets.data();
    // Reshape predictionsVals into a 2d tensor [batch, lastDim]
    // and look up topK along lastDim.
    const [batch, size] = [predictionsVals.length / lastDim, lastDim];
    const precision = getTypedArrayFromDType('bool', batch);
    for (let b = 0; b < batch; b++) {
        const offset = b * size;
        // View (not a copy) of this batch row's scores.
        const vals = predictionsVals.subarray(offset, offset + size);
        const valAndInd = [];
        for (let i = 0; i < vals.length; i++) {
            valAndInd.push({ value: vals[i], index: i });
        }
        // Sort scores descending; sort stability (guaranteed since ES2019)
        // keeps tied scores in ascending index order.
        valAndInd.sort((a, b) => b.value - a.value);
        precision[b] = 0;
        // The target is "in top k" if its index is among the k best scores.
        for (let i = 0; i < k; i++) {
            if (valAndInd[i].index === targetsVals[b]) {
                precision[b] = 1;
                break;
            }
        }
    }
    // Dispose tensors that convertToTensor created from TensorLike inputs.
    if (predictions !== $predictions) {
        $predictions.dispose();
    }
    if (targets !== $targets) {
        $targets.dispose();
    }
    // Output precision has the same shape as targets.
    return tensor(precision, $targets.shape, 'bool');
}
const inTopKAsync = inTopKAsync_;
21811
21812 /**
21813 * @license
21814 * Copyright 2020 Google LLC. All Rights Reserved.
21815 * Licensed under the Apache License, Version 2.0 (the "License");
21816 * you may not use this file except in compliance with the License.
21817 * You may obtain a copy of the License at
21818 *
21819 * http://www.apache.org/licenses/LICENSE-2.0
21820 *
21821 * Unless required by applicable law or agreed to in writing, software
21822 * distributed under the License is distributed on an "AS IS" BASIS,
21823 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21824 * See the License for the specific language governing permissions and
21825 * limitations under the License.
21826 * =============================================================================
21827 */
/**
 * Computes the derivative of the filter of a 2D convolution.
 *
 * @param x The input tensor, of rank 4 or rank 3 of shape
 *     [batch, height, width, inChannels]. If rank 3, batch of 1 is assumed.
 * @param dy The dy image, of rank 4 or rank 3, of shape
 *     [batch, height, width, outDepth]. If rank 3, batch of 1 is assumed.
 * @param filterShape The shape of the filter, length 4,
 *     [filterHeight, filterWidth, inDepth, outDepth].
 * @param strides The strides of the convolution: [strideHeight,
 *     strideWidth].
 * @param pad A string from: 'same', 'valid'. The type of padding algorithm
 *     used in the forward prop of the op.
 * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
 *     "NHWC". Specify the data format of the input and output data. With the
 *     default format "NHWC", the data is stored in the order of: [batch,
 *     height, width, channels].
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 * @returns The gradient with respect to the filter, of shape `filterShape`.
 */
function conv2DBackpropFilter_(x, dy, filterShape, strides, pad, dataFormat = 'NHWC', dimRoundingMode) {
    // Promote rank-3 inputs to rank 4 by prepending a batch dimension of 1.
    let x4D = x;
    if (x.rank === 3) {
        x4D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
    }
    let dy4D = dy;
    if (dy4D.rank === 3) {
        dy4D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
    }
    assert(x4D.rank === 4, () => `Error in conv2dDerFilter: input must be rank 4, but got shape ` +
        `${x4D.shape}.`);
    assert(dy4D.rank === 4, () => `Error in conv2dDerFilter: dy must be rank 4, but got shape ` +
        `${dy4D.shape}.`);
    assert(filterShape.length === 4, () => `Error in conv2dDerFilter: filterShape must be length 4, but got ` +
        `${filterShape}.`);
    // The channel axis depends on the data format (NHWC: last; NCHW: second).
    const inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
    const outDepth = dataFormat === 'NHWC' ? dy4D.shape[3] : dy4D.shape[1];
    assert(inDepth === filterShape[2], () => `Error in conv2dDerFilter: depth of input (${inDepth}) must ` +
        `match input depth in filter (${filterShape[2]}).`);
    assert(outDepth === filterShape[3], () => `Error in conv2dDerFilter: depth of dy (${outDepth}) must ` +
        `match output depth for filter (${filterShape[3]}).`);
    checkPadOnDimRoundingMode('conv2dDerFilter', pad, dimRoundingMode);
    const inputs = { x: x4D, dy: dy4D };
    const attrs = { strides, pad, dataFormat, dimRoundingMode, filterShape };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    return ENGINE.runKernel(Conv2DBackpropFilter, inputs, attrs);
}
const conv2DBackpropFilter = op({ conv2DBackpropFilter_ });
21876
21877 /**
21878 * @license
21879 * Copyright 2019 Google LLC. All Rights Reserved.
21880 * Licensed under the Apache License, Version 2.0 (the "License");
21881 * you may not use this file except in compliance with the License.
21882 * You may obtain a copy of the License at
21883 *
21884 * http://www.apache.org/licenses/LICENSE-2.0
21885 *
21886 * Unless required by applicable law or agreed to in writing, software
21887 * distributed under the License is distributed on an "AS IS" BASIS,
21888 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21889 * See the License for the specific language governing permissions and
21890 * limitations under the License.
21891 * =============================================================================
21892 */
// Returns gradient for fused activation.
function getFusedDyActivation(dy, y, activation) {
    switch (activation) {
        // A missing or linear activation passes the gradient through unchanged.
        case null:
        case undefined:
        case 'linear':
            return dy;
        case 'relu':
            // d/dx relu(x) is the unit step of the forward output.
            return mul(dy, step(y));
        default:
            throw new Error(`Cannot compute gradient for fused activation ${activation}.`);
    }
}
// Returns gradient for fused bias.
function getFusedBiasGradient(bias, dyActivation) {
    // Sum over the axes along which the bias was broadcast, then restore the
    // bias's own shape.
    const reduceAxes = getReductionAxes(bias.shape, dyActivation.shape);
    const summed = reduceAxes.length > 0 ? sum$1(dyActivation, reduceAxes) : dyActivation;
    return reshape(summed, bias.shape);
}
/**
 * Applies the named activation to `x`. `preluActivationWeights` and
 * `leakyreluAlpha` are only consulted for the 'prelu' and 'leakyrelu'
 * activations, respectively.
 */
function applyActivation(x, activation, preluActivationWeights, leakyreluAlpha) {
    switch (activation) {
        case 'linear':
            return x;
        case 'relu':
            return relu(x);
        case 'elu':
            return elu(x);
        case 'relu6':
            return relu6(x);
        case 'prelu':
            return prelu(x, preluActivationWeights);
        case 'leakyrelu':
            return leakyRelu(x, leakyreluAlpha);
        case 'sigmoid':
            return sigmoid(x);
        default:
            throw new Error(`Unknown fused activation ${activation}.`);
    }
}
// Whether we should call fused ops.
const shouldFuse = (gradientDepth, activation) => {
    // Fall back to unfused ops when a gradient is being recorded, unless the
    // activation is linear (whose gradient the fused path handles).
    return !(gradientDepth > 0) || activation === 'linear';
};
21941
21942 /**
21943 * @license
21944 * Copyright 2019 Google LLC. All Rights Reserved.
21945 * Licensed under the Apache License, Version 2.0 (the "License");
21946 * you may not use this file except in compliance with the License.
21947 * You may obtain a copy of the License at
21948 *
21949 * http://www.apache.org/licenses/LICENSE-2.0
21950 *
21951 * Unless required by applicable law or agreed to in writing, software
21952 * distributed under the License is distributed on an "AS IS" BASIS,
21953 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21954 * See the License for the specific language governing permissions and
21955 * limitations under the License.
21956 * =============================================================================
21957 */
21958 /**
21959 * Computes a 2D convolution over the input x, optionally fused with adding a
21960 * bias and applying an activation.
21961 *
21962 * ```js
21963 * const inputDepth = 2;
21964 * const inShape = [2, 2, 2, inputDepth];
21965 * const outputDepth = 2;
21966 * const fSize = 1;
21967 * const pad = 0;
21968 * const strides = 1;
21969 *
21970 * const x = tf.tensor4d( [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
21971 * 16], inShape);
21972 * const w = tf.tensor4d([-1, 1, -2, 0.5], [fSize, fSize, inputDepth,
21973 * outputDepth]);
21974 *
21975 * tf.fused.conv2d({ x, filter: w, strides, pad, dataFormat: 'NHWC',
21976 * dilations: [1, 1], bias: tf.scalar(5), activation: 'relu' }).print();
21977 * ```
21978 *
21979 * @param obj An object with the following properties:
21980 * @param x The input tensor, of rank 4 or rank 3, of shape
21981 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
21982 * assumed.
21983 * @param filter The filter, rank 4, of shape
21984 * `[filterHeight, filterWidth, inDepth, outDepth]`.
21985 * @param strides The strides of the convolution: `[strideHeight,
21986 * strideWidth]`.
21987 * @param pad The type of padding algorithm.
21988 * - `same` and stride 1: output will be of same size as input,
21989 * regardless of filter size.
21990 * - `valid` output will be smaller than input if filter is larger
21991 * than 1x1.
21992 * - For more info, see this guide:
21993 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
21994 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
21995 * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to
21996 * "NHWC". Specify the data format of the input and output data. With the
21997 * default format "NHWC", the data is stored in the order of: [batch,
21998 * height, width, channels]. Only "NHWC" is currently supported.
21999 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
22000 * in which we sample input values across the height and width dimensions
22001 * in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single
22002 * number, then `dilationHeight == dilationWidth`. If it is greater than
22003 * 1, then all values of `strides` must be 1.
22004 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
22005 * provided, it will default to truncate.
22006 * @param bias Tensor to be added to the result.
22007 * @param activation Name of activation kernel (defaults to `linear`) to be
22008 * applied
22009 * after biasAdd.
22010 * @param preluActivationWeights Tensor of prelu weights to be applied as part
22011 * of a `prelu` activation, typically the same shape as `x`.
22012 * @param leakyreluAlpha Optional. Alpha to be applied as part of a `leakyrelu`
22013 * activation.
22014 */
/**
 * Computes a 2D convolution over the input x, optionally fused with adding a
 * bias and applying an activation.
 *
 * When no gradient is being recorded (or the activation cannot be fused),
 * this falls back to the unfused ops conv2d + add + activation. Otherwise a
 * single FusedConv2D kernel is dispatched through customGrad so the backward
 * pass can reuse the fused forward result.
 */
function fusedConv2d_({ x, filter, strides, pad, dataFormat = 'NHWC', dilations = [1, 1], dimRoundingMode, bias, activation = 'linear', preluActivationWeights, leakyreluAlpha }) {
    // Defensive: the parameter default only covers `undefined`; this also
    // maps an explicitly-passed null (or other falsy value) to 'linear'.
    activation = activation || 'linear';
    if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
        // Unfused fallback path: plain conv2d, then bias add, then activation.
        // TODO: Transpose bias and preluActivationWeights properly for NCHW
        // format before computation.
        assert(dataFormat === 'NHWC', () => `Error in fused conv2d: got dataFormat of ${dataFormat} but ` +
            `only NHWC is currently supported for the case of gradient depth ` +
            `is 0 and the activation is not linear.`);
        let result = conv2d(x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
        if (bias != null) {
            result = add$1(result, bias);
        }
        return applyActivation(result, activation, preluActivationWeights, leakyreluAlpha);
    }
    const $x = convertToTensor(x, 'x', 'conv2d', 'float32');
    const $filter = convertToTensor(filter, 'filter', 'conv2d', 'float32');
    // A rank-3 input is treated as a single-image batch; remember that we
    // added the batch dimension so it can be squeezed off the result.
    let x4D = $x;
    let reshapedTo4D = false;
    if ($x.rank === 3) {
        reshapedTo4D = true;
        x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
    }
    assert(x4D.rank === 4, () => `Error in fused conv2d: input must be rank 4, but got rank ` +
        `${x4D.rank}.`);
    assert($filter.rank === 4, () => `Error in fused conv2d: filter must be rank 4, but got rank ` +
        `${$filter.rank}.`);
    checkPadOnDimRoundingMode('fused conv2d', pad, dimRoundingMode);
    // Channel axis depends on the data format (NHWC: last, NCHW: second).
    const inputChannels = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
    assert($filter.shape[2] === inputChannels, () => `Error in conv2d: depth of input (${inputChannels}) must match ` +
        `input depth for filter ${$filter.shape[2]}.`);
    assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in conv2D: Either strides or dilations must be 1. ' +
        `Got strides ${strides} and dilations '${dilations}'`);
    const convInfo = computeConv2DInfo(x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode);
    let $bias;
    if (bias != null) {
        $bias = convertToTensor(bias, 'bias', 'fused conv2d');
        [$bias] = makeTypesMatch($bias, $x);
        // According to TensorFlow, the bias is supposed be a 1-D tensor or a
        // scalar.
        if (dataFormat === 'NHWC') {
            assertAndGetBroadcastShape(convInfo.outShape, $bias.shape);
        }
        else {
            assert($bias.shape.length <= 1, () => `Error in fused conv2d: only supports scalar or 1-D Tensor ` +
                `bias for NCHW format but got the bias of ` +
                `rank-${$bias.shape.length}.`);
            assert($bias.shape.length === 0 || $bias.shape[0] === convInfo.outChannels ||
                $bias.shape[0] === 1, () => `Error in fused conv2d: bias shape (${$bias.shape}) is not ` +
                `compatible with the number of output channels ` +
                `(${convInfo.outChannels})`);
        }
    }
    let $preluActivationWeights;
    if (preluActivationWeights != null) {
        // PReLU's activation weights could be a scalar, a 1-D tensor or a 3-D
        // tensor.
        const alphaShape = preluActivationWeights.shape;
        assert(alphaShape.length <= 1 || alphaShape.length === 3, () => `Error in fused conv2d: only supports scalar, 1-D Tensor or ` +
            `3-D Tensor PReLU activation weights but got a tensor of ` +
            `rank-${alphaShape.length}.`);
        if (alphaShape.length === 1) {
            // Whether the data format is NCHW or NHWC, the 1-D PReLU activation
            // weights tensor should be aligned with the output channels of conv2d
            // result.
            assert(alphaShape[0] === 1 || alphaShape[0] === convInfo.outChannels, () => `Error in fused conv2d: PReLU activation weights ` +
                `(${alphaShape}) is not compatible with the number of output ` +
                `channels (${convInfo.outChannels}).`);
        }
        else if (alphaShape.length === 3) {
            // Whether the data format is NCHW or NHWC, the PReLU activation weights
            // tensor should has the compatible shape with the result of conv2d.
            try {
                assertAndGetBroadcastShape(alphaShape, convInfo.outShape);
            }
            catch (e) {
                const errMsg = `Error in fused conv2d: PReLU activation weights (${alphaShape}) ` +
                    `is not compatible with the output shape of the conv2d ` +
                    `(${convInfo.outShape}).`;
                throw Error(errMsg);
            }
        }
        $preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused conv2d');
    }
    // Backward pass over the saved forward tensors. Only NHWC and dilation
    // rate 1 are currently supported in gradients.
    const grad = (dy, saved) => {
        assert(dataFormat === 'NHWC', () => `Error in gradient of fused conv2D: got dataFormat of ${dataFormat} but only NHWC is currently supported.`);
        const [$filter, x4D, y, $bias] = saved;
        // Backprop dy through the fused activation first (e.g. zero out
        // positions where a relu clipped the forward output).
        const dyActivation = getFusedDyActivation(dy, y, activation);
        assert(tupleValuesAreOne(dilations), () => 'Error in gradient of fused conv2D: ' +
            `dilation rates greater than 1 ` +
            `are not yet supported in gradients. Got dilations '${dilations}'`);
        const xDer = conv2DBackpropInput(x4D.shape, dyActivation, $filter, strides, pad);
        const filterDer = conv2DBackpropFilter(x4D, dyActivation, $filter.shape, strides, pad);
        const der = [xDer, filterDer];
        // A bias derivative is appended only when a bias was saved, matching
        // the arity of the customGrad inputs below.
        if ($bias != null) {
            const biasDer = getFusedBiasGradient($bias, dyActivation);
            der.push(biasDer);
        }
        return der;
    };
    const inputs = {
        x: x4D,
        filter: $filter,
        bias: $bias,
        preluActivationWeights: $preluActivationWeights
    };
    const attrs = {
        strides,
        pad,
        dataFormat,
        dilations,
        dimRoundingMode,
        activation,
        leakyreluAlpha
    };
    // Depending on the params passed in we will have a different number of
    // inputs and thus a different number of elements in the gradient.
    if (bias == null) {
        const customOp = customGrad((x4D, filter, save) => {
            let res =
            // tslint:disable-next-line: no-unnecessary-type-assertion
            ENGINE.runKernel(FusedConv2D, inputs, attrs);
            save([filter, x4D, res]);
            if (reshapedTo4D) {
                // Squeeze the batch dimension that was added for rank-3 input.
                // tslint:disable-next-line: no-unnecessary-type-assertion
                res = reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
            }
            return { value: res, gradFunc: grad };
        });
        return customOp(x4D, $filter);
    }
    else {
        const customOpWithBias = customGrad((x4D, filter, bias, save) => {
            let res = ENGINE.runKernel(FusedConv2D, inputs, attrs);
            save([filter, x4D, res, bias]);
            if (reshapedTo4D) {
                // tslint:disable-next-line: no-unnecessary-type-assertion
                res = reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
            }
            return { value: res, gradFunc: grad };
        });
        return customOpWithBias(x4D, $filter, $bias);
    }
}
const conv2d$1 = op({ fusedConv2d_ });
22159
22160 /**
22161 * @license
22162 * Copyright 2020 Google LLC. All Rights Reserved.
22163 * Licensed under the Apache License, Version 2.0 (the "License");
22164 * you may not use this file except in compliance with the License.
22165 * You may obtain a copy of the License at
22166 *
22167 * http://www.apache.org/licenses/LICENSE-2.0
22168 *
22169 * Unless required by applicable law or agreed to in writing, software
22170 * distributed under the License is distributed on an "AS IS" BASIS,
22171 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22172 * See the License for the specific language governing permissions and
22173 * limitations under the License.
22174 * =============================================================================
22175 */
/**
 * Computes the derivative of the filter of a depthwise conv2d, given the
 * input x and the upstream gradient dy. Rank-3 arguments are promoted to a
 * batch of 1 before the kernel is dispatched.
 */
function depthwiseConv2dNativeBackpropFilter_(x, dy, filterShape, strides, pad, dilations = [1, 1], dimRoundingMode) {
    // Promote rank-3 tensors to rank 4 by prepending a unit batch dimension.
    const x4D = x.rank === 3 ?
        reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]) :
        x;
    const dy4D = dy.rank === 3 ?
        reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]) :
        dy;
    const inputs = { x: x4D, dy: dy4D };
    const attrs = { strides, pad, dimRoundingMode, dilations, filterShape };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    return ENGINE.runKernel(DepthwiseConv2dNativeBackpropFilter, inputs, attrs);
}
const depthwiseConv2dNativeBackpropFilter = op({ depthwiseConv2dNativeBackpropFilter_ });
22191
22192 /**
22193 * @license
22194 * Copyright 2020 Google LLC. All Rights Reserved.
22195 * Licensed under the Apache License, Version 2.0 (the "License");
22196 * you may not use this file except in compliance with the License.
22197 * You may obtain a copy of the License at
22198 *
22199 * http://www.apache.org/licenses/LICENSE-2.0
22200 *
22201 * Unless required by applicable law or agreed to in writing, software
22202 * distributed under the License is distributed on an "AS IS" BASIS,
22203 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22204 * See the License for the specific language governing permissions and
22205 * limitations under the License.
22206 * =============================================================================
22207 */
/**
 * Computes the derivative of the input of a depthwise conv2d, given the
 * original input shape, the upstream gradient dy and the filter. A rank-3
 * dy is promoted to a batch of 1 and the result is squeezed back to rank 3.
 */
function depthwiseConv2dNativeBackpropInput_(xShape, dy, filter, strides, pad, dilations = [1, 1], dimRoundingMode) {
    const wasRank3 = dy.rank === 3;
    // Promote a rank-3 dy to rank 4 by prepending a unit batch dimension.
    const dy4D = wasRank3 ?
        reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]) :
        dy;
    const inputs = { dy: dy4D, filter };
    const attrs = { strides, pad, dimRoundingMode, dilations, inputShape: xShape };
    const res =
    // tslint:disable-next-line: no-unnecessary-type-assertion
    ENGINE.runKernel(DepthwiseConv2dNativeBackpropInput, inputs, attrs);
    // Drop the batch dimension again if we added one above.
    return wasRank3 ?
        reshape(res, [res.shape[1], res.shape[2], res.shape[3]]) :
        res;
}
const depthwiseConv2dNativeBackpropInput = op({ depthwiseConv2dNativeBackpropInput_ });
22226
22227 /**
22228 * @license
22229 * Copyright 2019 Google LLC. All Rights Reserved.
22230 * Licensed under the Apache License, Version 2.0 (the "License");
22231 * you may not use this file except in compliance with the License.
22232 * You may obtain a copy of the License at
22233 *
22234 * http://www.apache.org/licenses/LICENSE-2.0
22235 *
22236 * Unless required by applicable law or agreed to in writing, software
22237 * distributed under the License is distributed on an "AS IS" BASIS,
22238 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22239 * See the License for the specific language governing permissions and
22240 * limitations under the License.
22241 * =============================================================================
22242 */
/**
 * Computes depthwise 2D convolution, optionally fused with adding a
 * bias and applying an activation.
 *
 * Given a 4D `input` array and a `filter` array of shape
 * `[filterHeight, filterWidth, inChannels, channelMultiplier]` containing
 * `inChannels` convolutional filters of depth 1, this op applies a
 * different filter to each input channel (expanding from 1 channel to
 * `channelMultiplier` channels for each), then concatenates the results
 * together. The output has `inChannels * channelMultiplier` channels.
 *
 * See
 * [https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d](
 *     https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d)
 * for more details.
 *
 * @param obj An object with the following properties:
 * @param x The input tensor, of rank 4 or rank 3, of shape
 *     `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
 * assumed.
 * @param filter The filter tensor, rank 4, of shape
 *     `[filterHeight, filterWidth, inChannels, channelMultiplier]`.
 * @param strides The strides of the convolution: `[strideHeight,
 *     strideWidth]`. If strides is a single number, then `strideHeight ==
 * strideWidth`.
 * @param pad The type of padding algorithm.
 *   - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *   - `valid`: output will be smaller than input if filter is larger
 *       than 1x1.
 *   - For more info, see this guide:
 *     [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
 *          https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
 *     in which we sample input values across the height and width dimensions
 *     in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single
 *     number, then `dilationHeight == dilationWidth`. If it is greater than
 *     1, then all values of `strides` must be 1.
 * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
 *     "NHWC". Specify the data format of the input and output data. With the
 *     default format "NHWC", the data is stored in the order of: [batch,
 *     height, width, channels]. Only "NHWC" is currently supported.
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 * @param bias Tensor to be added to the result.
 * @param activation Name of activation kernel (defaults to `linear`).
 * @param preluActivationWeights Tensor of prelu weights to be applied as part
 *     of a `prelu` activation, typically the same shape as `x`.
 * @param leakyreluAlpha Optional. Alpha to be applied as part of a `leakyrelu`
 *     activation.
 */
function fusedDepthwiseConv2d_({ x, filter, strides, pad, dataFormat = 'NHWC', dilations = [1, 1], dimRoundingMode, bias, activation = 'linear', preluActivationWeights, leakyreluAlpha }) {
    if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
        // Unfused fallback path: plain depthwiseConv2d, then bias add, then
        // activation.
        let result = depthwiseConv2d(x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
        if (bias != null) {
            result = add$1(result, bias);
        }
        return applyActivation(result, activation, preluActivationWeights, leakyreluAlpha);
    }
    const $x = convertToTensor(x, 'x', 'depthwiseConv2d', 'float32');
    const $filter = convertToTensor(filter, 'filter', 'depthwiseConv2d', 'float32');
    // A rank-3 input is treated as a single-image batch; remember that we
    // added the batch dimension so it can be squeezed off the result.
    let x4D = $x;
    let reshapedTo4D = false;
    if ($x.rank === 3) {
        reshapedTo4D = true;
        x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
    }
    assert(x4D.rank === 4, () => `Error in fused depthwiseConv2d: input must be rank 4, but got ` +
        `rank ${x4D.rank}.`);
    assert($filter.rank === 4, () => `Error in fused depthwiseConv2d: filter must be rank 4, ` +
        `but got rank ${$filter.rank}.`);
    assert(x4D.shape[3] === $filter.shape[2], () => `Error in fused depthwiseConv2d: number of input channels ` +
        `(${x4D.shape[3]}) must match the inChannels dimension in ` +
        `filter ${$filter.shape[2]}.`);
    // Guard against an explicitly-passed null (the parameter default only
    // covers `undefined`).
    if (dilations == null) {
        dilations = [1, 1];
    }
    assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in fused depthwiseConv2d: Either strides or dilations must ' +
        `be 1. Got strides ${strides} and dilations '${dilations}'`);
    checkPadOnDimRoundingMode('fused depthwiseConv2d', pad, dimRoundingMode);
    const convInfo = computeConv2DInfo(x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
    let $bias;
    if (bias != null) {
        $bias = convertToTensor(bias, 'bias', 'fused conv2d');
        [$bias] = makeTypesMatch($bias, $x);
        assertAndGetBroadcastShape(convInfo.outShape, $bias.shape);
    }
    let $preluActivationWeights;
    if (preluActivationWeights != null) {
        $preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused depthwiseConv2d');
    }
    // Backward pass over the saved forward tensors. Only dilation rate 1 is
    // currently supported in gradients.
    const grad = (dy, saved) => {
        assert(tupleValuesAreOne(dilations), () => 'Error in gradient of fused depthwiseConv2d: dilation rates ' +
            `greater than 1 are not yet supported. Got dilations ` +
            `'${dilations}'`);
        // Bug fix: destructure the saved bias as `$bias` (shadowing the outer
        // variable) so the bias gradient is computed from the tensor saved by
        // the customGrad forward function rather than the closure-captured
        // one, matching the grad implementation of fused conv2d above.
        const [$filter, x4D, y, $bias] = saved;
        // Backprop dy through the fused activation first (e.g. zero out
        // positions where a relu clipped the forward output).
        const dyActivation = getFusedDyActivation(dy, y, activation);
        const xDer = depthwiseConv2dNativeBackpropInput(x4D.shape, dyActivation, $filter, strides, pad, dilations, dimRoundingMode);
        const filterDer = depthwiseConv2dNativeBackpropFilter(x4D, dyActivation, $filter.shape, strides, pad, dilations, dimRoundingMode);
        // A bias derivative is appended only when a bias was saved, matching
        // the arity of the customGrad inputs below.
        if ($bias != null) {
            const biasDer = getFusedBiasGradient($bias, dyActivation);
            return [xDer, filterDer, biasDer];
        }
        return [xDer, filterDer];
    };
    const inputs = {
        x: x4D,
        filter: $filter,
        bias: $bias,
        preluActivationWeights: $preluActivationWeights
    };
    const attrs = {
        strides,
        pad,
        dataFormat,
        dilations,
        dimRoundingMode,
        activation,
        leakyreluAlpha
    };
    // Depending on the params passed in we will have a different number of
    // inputs and thus a different number of elements in the gradient.
    if (bias == null) {
        const customOp = customGrad((x4D, filter, save) => {
            // tslint:disable-next-line: no-unnecessary-type-assertion
            let res = ENGINE.runKernel(FusedDepthwiseConv2D, inputs, attrs);
            save([filter, x4D, res]);
            if (reshapedTo4D) {
                // Squeeze the batch dimension that was added for rank-3 input.
                // tslint:disable-next-line: no-unnecessary-type-assertion
                res = reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
            }
            return { value: res, gradFunc: grad };
        });
        return customOp(x4D, $filter);
    }
    else {
        const customOpWithBias = customGrad((x4D, filter, bias, save) => {
            // tslint:disable-next-line: no-unnecessary-type-assertion
            let res = ENGINE.runKernel(FusedDepthwiseConv2D, inputs, attrs);
            save([filter, x4D, res, bias]);
            if (reshapedTo4D) {
                // tslint:disable-next-line: no-unnecessary-type-assertion
                res = reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
            }
            return { value: res, gradFunc: grad };
        });
        return customOpWithBias(x4D, $filter, $bias);
    }
}
const depthwiseConv2d$1 = op({ fusedDepthwiseConv2d_ });
22393
22394 /**
22395 * @license
22396 * Copyright 2019 Google LLC. All Rights Reserved.
22397 * Licensed under the Apache License, Version 2.0 (the "License");
22398 * you may not use this file except in compliance with the License.
22399 * You may obtain a copy of the License at
22400 *
22401 * http://www.apache.org/licenses/LICENSE-2.0
22402 *
22403 * Unless required by applicable law or agreed to in writing, software
22404 * distributed under the License is distributed on an "AS IS" BASIS,
22405 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22406 * See the License for the specific language governing permissions and
22407 * limitations under the License.
22408 * =============================================================================
22409 */
/**
 * Computes the dot product of two matrices with optional activation and bias.
 *
 * ```js
 * const a = tf.tensor2d([-1, -2], [1, 2]);
 * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 * const bias = tf.tensor2d([1, 2], [1, 2]);
 *
 * tf.fused.matMul({a, b, bias, activation: 'relu'}).print();
 * ```
 *
 * @param obj An object with the following properties:
 * - `a` First matrix in dot product operation.
 * - `b` Second matrix in dot product operation.
 * - `transposeA` If true, `a` is transposed before multiplication.
 * - `transposeB` If true, `b` is transposed before multiplication.
 * - `bias` Matrix to be added to the result.
 * - `activation` Name of activation kernel (defaults to `linear`).
 * - `preluActivationWeights` Tensor of prelu weights.
 * - `leakyreluAlpha` Alpha of leakyrelu.
 */
function fusedMatMul_({ a, b, transposeA = false, transposeB = false, bias, activation = 'linear', preluActivationWeights, leakyreluAlpha, }) {
    if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
        // Unfused fallback path: plain matMul, then bias add, then activation.
        let result = matMul(a, b, transposeA, transposeB);
        if (bias != null) {
            result = add$1(result, bias);
        }
        return applyActivation(result, activation, preluActivationWeights, leakyreluAlpha);
    }
    let $a = convertToTensor(a, 'a', 'fused matMul');
    let $b = convertToTensor(b, 'b', 'fused matMul');
    [$a, $b] = makeTypesMatch($a, $b);
    // Inner (contracted) and outer (kept) dimensions of each operand,
    // accounting for the transpose flags.
    const innerShapeA = transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1];
    const innerShapeB = transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2];
    const outerShapeA = transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2];
    const outerShapeB = transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1];
    // Leading dimensions are flattened into a single batch dimension.
    const outerDimsA = $a.shape.slice(0, -2);
    const outerDimsB = $b.shape.slice(0, -2);
    const batchDimA = sizeFromShape(outerDimsA);
    const batchDimB = sizeFromShape(outerDimsB);
    assert(innerShapeA === innerShapeB, () => `Error in fused matMul: inner shapes (${innerShapeA}) and (` +
        `${innerShapeB}) of Tensors with shapes ${$a.shape} and ` +
        `${$b.shape} and transposeA=${transposeA}` +
        ` and transposeB=${transposeB} must match.`);
    // The batch dimensions of a and b must be broadcast-compatible; the
    // broadcast shape forms the leading dims of the output.
    const outShapeOuterDims = assertAndGetBroadcastShape($a.shape.slice(0, -2), $b.shape.slice(0, -2));
    const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
    // Collapse each operand to rank 3: [batch, rows, cols] for the kernel.
    const a3D = transposeA ?
        reshape($a, [batchDimA, innerShapeA, outerShapeA]) :
        reshape($a, [batchDimA, outerShapeA, innerShapeA]);
    const b3D = transposeB ?
        reshape($b, [batchDimB, outerShapeB, innerShapeB]) :
        reshape($b, [batchDimB, innerShapeB, outerShapeB]);
    let $bias;
    if (bias != null) {
        $bias = convertToTensor(bias, 'bias', 'fused matMul');
        [$bias] = makeTypesMatch($bias, $a);
        assertAndGetBroadcastShape(outShape, $bias.shape);
    }
    let $preluActivationWeights;
    if (preluActivationWeights != null) {
        $preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused matMul');
    }
    // Backward pass: one matMul per operand, with transpose flags chosen per
    // the forward transpose configuration.
    const grad = (dy, saved) => {
        const [a3D, b3D, y, $bias] = saved;
        // we reshape dy because the result of the forward is not
        // necessarily going to be a 3d tensor due to a reshape done at the end of
        // the customOp.
        const dyActivation = getFusedDyActivation(reshape(dy, y.shape), y, activation);
        let aDer;
        let bDer;
        if (!transposeA && !transposeB) {
            aDer = matMul(dyActivation, b3D, false, true);
            bDer = matMul(a3D, dyActivation, true, false);
        }
        else if (!transposeA && transposeB) {
            aDer = matMul(dyActivation, b3D, false, false);
            bDer = matMul(dyActivation, a3D, true, false);
        }
        else if (transposeA && !transposeB) {
            aDer = matMul(b3D, dyActivation, false, true);
            bDer = matMul(a3D, dyActivation, false, false);
        }
        else {
            aDer = matMul(b3D, dyActivation, true, true);
            bDer = matMul(dyActivation, a3D, true, true);
        }
        // A bias derivative is appended only when a bias was saved, matching
        // the arity of the customGrad inputs below.
        if (bias != null) {
            const biasDer = getFusedBiasGradient($bias, dyActivation);
            return [aDer, bDer, biasDer];
        }
        else {
            return [aDer, bDer];
        }
    };
    const inputs = {
        a: a3D,
        b: b3D,
        bias: $bias,
        preluActivationWeights: $preluActivationWeights
    };
    const attrs = { transposeA, transposeB, activation, leakyreluAlpha };
    // Depending on the params passed in we will have a different number of
    // inputs and thus a different number of elements in the gradient.
    if (bias == null) {
        const customOp = customGrad((a3D, b3D, save) => {
            const res =
            // tslint:disable-next-line: no-unnecessary-type-assertion
            ENGINE.runKernel(_FusedMatMul, inputs, attrs);
            save([a3D, b3D, res]);
            // Expand the rank-3 kernel result back to the broadcast out shape.
            return { value: reshape(res, outShape), gradFunc: grad };
        });
        return customOp(a3D, b3D);
    }
    else {
        const customOpWithBias = customGrad((a3D, b3D, $bias, save) => {
            const res =
            // tslint:disable-next-line: no-unnecessary-type-assertion
            ENGINE.runKernel(_FusedMatMul, inputs, attrs);
            save([a3D, b3D, res, $bias]);
            return { value: reshape(res, outShape), gradFunc: grad };
        });
        return customOpWithBias(a3D, b3D, $bias);
    }
}
const matMul$1 = op({ fusedMatMul_ });
22535
22536 /**
22537 * @license
22538 * Copyright 2019 Google LLC. All Rights Reserved.
22539 * Licensed under the Apache License, Version 2.0 (the "License");
22540 * you may not use this file except in compliance with the License.
22541 * You may obtain a copy of the License at
22542 *
22543 * http://www.apache.org/licenses/LICENSE-2.0
22544 *
22545 * Unless required by applicable law or agreed to in writing, software
22546 * distributed under the License is distributed on an "AS IS" BASIS,
22547 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22548 * See the License for the specific language governing permissions and
22549 * limitations under the License.
22550 * =============================================================================
22551 */
22552
// Public `tf.fused` namespace: the fused-kernel variants of conv2d,
// depthwiseConv2d and matMul defined above, frozen as a module-like object.
var fused_ops = /*#__PURE__*/Object.freeze({
    __proto__: null,
    conv2d: conv2d$1,
    depthwiseConv2d: depthwiseConv2d$1,
    matMul: matMul$1
});
22559
22560 /**
22561 * @license
22562 * Copyright 2019 Google LLC. All Rights Reserved.
22563 * Licensed under the Apache License, Version 2.0 (the "License");
22564 * you may not use this file except in compliance with the License.
22565 * You may obtain a copy of the License at
22566 *
22567 * http://www.apache.org/licenses/LICENSE-2.0
22568 *
22569 * Unless required by applicable law or agreed to in writing, software
22570 * distributed under the License is distributed on an "AS IS" BASIS,
22571 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22572 * See the License for the specific language governing permissions and
22573 * limitations under the License.
22574 * =============================================================================
22575 */
/**
 * Generate a hamming window.
 *
 * See: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
 *
 * ```js
 * tf.signal.hammingWindow(10).print();
 * ```
 * @param windowLength The length of window
 *
 * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
 */
function hammingWindow_(windowLength) {
    // Hamming coefficients: w[n] = 0.54 - 0.46 * cos(2*pi*n / (N - 1)).
    const alpha = 0.54;
    const beta = 0.46;
    return cosineWindow(windowLength, alpha, beta);
}
const hammingWindow = op({ hammingWindow_ });
22592
22593 /**
22594 * @license
22595 * Copyright 2019 Google LLC. All Rights Reserved.
22596 * Licensed under the Apache License, Version 2.0 (the "License");
22597 * you may not use this file except in compliance with the License.
22598 * You may obtain a copy of the License at
22599 *
22600 * http://www.apache.org/licenses/LICENSE-2.0
22601 *
22602 * Unless required by applicable law or agreed to in writing, software
22603 * distributed under the License is distributed on an "AS IS" BASIS,
22604 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22605 * See the License for the specific language governing permissions and
22606 * limitations under the License.
22607 * =============================================================================
22608 */
/**
 * Generate a Hann window.
 *
 * See: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
 *
 * ```js
 * tf.signal.hannWindow(10).print();
 * ```
 * @param windowLength The length of window
 *
 * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
 */
function hannWindow_(windowLength) {
    // Hann coefficients: w[n] = 0.5 - 0.5 * cos(2*pi*n / (N - 1)).
    const alpha = 0.5;
    const beta = 0.5;
    return cosineWindow(windowLength, alpha, beta);
}
const hannWindow = op({ hannWindow_ });
22625
22626 /**
22627 * @license
22628 * Copyright 2019 Google LLC. All Rights Reserved.
22629 * Licensed under the Apache License, Version 2.0 (the "License");
22630 * you may not use this file except in compliance with the License.
22631 * You may obtain a copy of the License at
22632 *
22633 * http://www.apache.org/licenses/LICENSE-2.0
22634 *
22635 * Unless required by applicable law or agreed to in writing, software
22636 * distributed under the License is distributed on an "AS IS" BASIS,
22637 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22638 * See the License for the specific language governing permissions and
22639 * limitations under the License.
22640 * =============================================================================
22641 */
/**
 * Expands input into frames of frameLength.
 * Slides a window size with frameStep.
 *
 * ```js
 * tf.signal.frame([1, 2, 3], 2, 1).print();
 * ```
 * @param signal The input tensor to be expanded
 * @param frameLength Length of each frame
 * @param frameStep The frame hop size in samples.
 * @param padEnd Whether to pad the end of signal with padValue.
 * @param padValue A number to use where the input signal does
 *     not exist when padEnd is True.
 *
 * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
 */
function frame_(signal, frameLength, frameStep, padEnd = false, padValue = 0) {
    const frames = [];
    let begin = 0;
    // Collect every full frame that fits entirely within the signal.
    for (; begin + frameLength <= signal.size; begin += frameStep) {
        frames.push(slice(signal, begin, frameLength));
    }
    if (padEnd) {
        // Emit the remaining partial frames, padded out to frameLength with
        // padValue.
        for (; begin < signal.size; begin += frameStep) {
            const deficit = (begin + frameLength) - signal.size;
            const padded = concat([
                slice(signal, begin, frameLength - deficit), fill([deficit], padValue)
            ]);
            frames.push(padded);
        }
    }
    // With no frames at all, return an empty [0, frameLength] tensor.
    return frames.length === 0 ?
        tensor2d([], [0, frameLength]) :
        reshape(concat(frames), [frames.length, frameLength]);
}
const frame = op({ frame_ });
22681
22682 /**
22683 * @license
22684 * Copyright 2019 Google LLC. All Rights Reserved.
22685 * Licensed under the Apache License, Version 2.0 (the "License");
22686 * you may not use this file except in compliance with the License.
22687 * You may obtain a copy of the License at
22688 *
22689 * http://www.apache.org/licenses/LICENSE-2.0
22690 *
22691 * Unless required by applicable law or agreed to in writing, software
22692 * distributed under the License is distributed on an "AS IS" BASIS,
22693 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22694 * See the License for the specific language governing permissions and
22695 * limitations under the License.
22696 * =============================================================================
22697 */
/**
 * Computes the Short-time Fourier Transform of signals
 * See: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
 *
 * ```js
 * const input = tf.tensor1d([1, 1, 1, 1, 1])
 * tf.signal.stft(input, 3, 1).print();
 * ```
 * @param signal 1-dimensional real value tensor.
 * @param frameLength The window length of samples.
 * @param frameStep The number of samples to step.
 * @param fftLength The size of the FFT to apply. Defaults to the smallest
 *     power of two enclosing frameLength.
 * @param windowFn A callable that takes a window length and returns 1-d tensor.
 *
 * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
 */
function stft_(signal, frameLength, frameStep, fftLength, windowFn = hannWindow) {
    // `== null` intentionally matches both null and undefined.
    const effectiveFftLength = fftLength == null ?
        enclosingPowerOfTwo(frameLength) :
        fftLength;
    const framedSignal = frame(signal, frameLength, frameStep);
    const windowedSignal = mul(framedSignal, windowFn(frameLength));
    return rfft(windowedSignal, effectiveFftLength);
}
const stft = op({ stft_ });
22723
22724 /**
22725 * @license
22726 * Copyright 2020 Google LLC. All Rights Reserved.
22727 * Licensed under the Apache License, Version 2.0 (the "License");
22728 * you may not use this file except in compliance with the License.
22729 * You may obtain a copy of the License at
22730 *
22731 * http://www.apache.org/licenses/LICENSE-2.0
22732 *
22733 * Unless required by applicable law or agreed to in writing, software
22734 * distributed under the License is distributed on an "AS IS" BASIS,
22735 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22736 * See the License for the specific language governing permissions and
22737 * limitations under the License.
22738 * =============================================================================
22739 */
/**
 * Extracts crops from the input image tensor and resizes them using bilinear
 * sampling or nearest neighbor sampling (possibly with aspect ratio change)
 * to a common output size specified by cropSize.
 *
 * @param image 4d tensor of shape `[batch,imageHeight,imageWidth, depth]`,
 *     where imageHeight and imageWidth must be positive, specifying the
 *     batch of images from which to take crops
 * @param boxes 2d float32 tensor of shape `[numBoxes, 4]`. Each entry is
 *     `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the normalized
 *     coordinates of the box in the boxInd[i]'th image in the batch
 * @param boxInd 1d int32 tensor of shape `[numBoxes]` with values in range
 *     `[0, batch)` that specifies the image that the `i`-th box refers to.
 * @param cropSize 1d int32 tensor of 2 elements `[cropHeight, cropWidth]`
 *     specifying the size to which all crops are resized to.
 * @param method Optional string from `'bilinear' | 'nearest'`,
 *     defaults to bilinear, which specifies the sampling method for resizing
 * @param extrapolationValue The value used for sample positions that fall
 *     outside of the source image. Defaults to 0.
 * @return A 4D tensor of the shape `[numBoxes,cropHeight,cropWidth,depth]`
 *
 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
 */
function cropAndResize_(image, boxes, boxInd, cropSize, method = 'bilinear', extrapolationValue = 0) {
    const $image = convertToTensor(image, 'image', 'cropAndResize');
    const $boxes = convertToTensor(boxes, 'boxes', 'cropAndResize', 'float32');
    const $boxInd = convertToTensor(boxInd, 'boxInd', 'cropAndResize', 'int32');
    const numBoxes = $boxes.shape[0];
    assert($image.rank === 4, () => 'Error in cropAndResize: image must be rank 4,' +
        `but got rank ${$image.rank}.`);
    assert($boxes.rank === 2 && $boxes.shape[1] === 4, () => `Error in cropAndResize: boxes must be have size [${numBoxes},4] ` +
        `but had shape ${$boxes.shape}.`);
    // Bug fix: the error message previously interpolated $boxes.shape here,
    // reporting the wrong tensor's shape when boxInd is invalid.
    assert($boxInd.rank === 1 && $boxInd.shape[0] === numBoxes, () => `Error in cropAndResize: boxInd must be have size [${numBoxes}] ` +
        `but had shape ${$boxInd.shape}.`);
    assert(cropSize.length === 2, () => `Error in cropAndResize: cropSize must be of length 2, but got ` +
        `length ${cropSize.length}.`);
    assert(cropSize[0] >= 1 && cropSize[1] >= 1, () => `cropSize must be atleast [1,1], but was ${cropSize}`);
    assert(method === 'bilinear' || method === 'nearest', () => `method must be bilinear or nearest, but was ${method}`);
    const inputs = { image: $image, boxes: $boxes, boxInd: $boxInd };
    const attrs = { method, extrapolationValue, cropSize };
    const res = ENGINE.runKernel(CropAndResize, inputs, attrs);
    return res;
}
const cropAndResize = op({ cropAndResize_ });
22784
22785 /**
22786 * @license
22787 * Copyright 2020 Google LLC. All Rights Reserved.
22788 * Licensed under the Apache License, Version 2.0 (the "License");
22789 * you may not use this file except in compliance with the License.
22790 * You may obtain a copy of the License at
22791 *
22792 * http://www.apache.org/licenses/LICENSE-2.0
22793 *
22794 * Unless required by applicable law or agreed to in writing, software
22795 * distributed under the License is distributed on an "AS IS" BASIS,
22796 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22797 * See the License for the specific language governing permissions and
22798 * limitations under the License.
22799 * =============================================================================
22800 */
22801 /**
22802 * Flips the image left to right. Currently available in the CPU, WebGL, and
22803 * WASM backends.
22804 *
22805 * @param image 4d tensor of shape `[batch, imageHeight, imageWidth, depth]`.
22806 */
22807 /** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */
22808 function flipLeftRight_(image) {
22809 const $image = convertToTensor(image, 'image', 'flipLeftRight', 'float32');
22810 assert($image.rank === 4, () => 'Error in flipLeftRight: image must be rank 4,' +
22811 `but got rank ${$image.rank}.`);
22812 const inputs = { image: $image };
22813 const res = ENGINE.runKernel(FlipLeftRight, inputs, {});
22814 return res;
22815 }
22816 const flipLeftRight = op({ flipLeftRight_ });
22817
22818 /**
22819 * @license
22820 * Copyright 2021 Google LLC. All Rights Reserved.
22821 * Licensed under the Apache License, Version 2.0 (the "License");
22822 * you may not use this file except in compliance with the License.
22823 * You may obtain a copy of the License at
22824 *
22825 * http://www.apache.org/licenses/LICENSE-2.0
22826 *
22827 * Unless required by applicable law or agreed to in writing, software
22828 * distributed under the License is distributed on an "AS IS" BASIS,
22829 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22830 * See the License for the specific language governing permissions and
22831 * limitations under the License.
22832 * =============================================================================
22833 */
22834 /**
22835 * Converts images from grayscale to RGB format.
22836 *
22837 * @param image A grayscale tensor to convert. The `image`'s last dimension must
22838 * be size 1 with at least a two-dimensional shape.
22839 *
22840 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
22841 */
22842 function grayscaleToRGB_(image) {
22843 const $image = convertToTensor(image, 'image', 'grayscaleToRGB');
22844 const lastDimsIdx = $image.rank - 1;
22845 const lastDims = $image.shape[lastDimsIdx];
22846 assert($image.rank >= 2, () => 'Error in grayscaleToRGB: images must be at least rank 2, ' +
22847 `but got rank ${$image.rank}.`);
22848 assert(lastDims === 1, () => 'Error in grayscaleToRGB: last dimension of a grayscale image ' +
22849 `should be size 1, but got size ${lastDims}.`);
22850 const reps = new Array($image.rank);
22851 reps.fill(1, 0, lastDimsIdx);
22852 reps[lastDimsIdx] = 3;
22853 return tile($image, reps);
22854 }
22855 const grayscaleToRGB = op({ grayscaleToRGB_ });
22856
22857 /**
22858 * @license
22859 * Copyright 2020 Google LLC. All Rights Reserved.
22860 * Licensed under the Apache License, Version 2.0 (the "License");
22861 * you may not use this file except in compliance with the License.
22862 * You may obtain a copy of the License at
22863 *
22864 * http://www.apache.org/licenses/LICENSE-2.0
22865 *
22866 * Unless required by applicable law or agreed to in writing, software
22867 * distributed under the License is distributed on an "AS IS" BASIS,
22868 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22869 * See the License for the specific language governing permissions and
22870 * limitations under the License.
22871 * =============================================================================
22872 */
22873 /**
22874 * Rotates the input image tensor counter-clockwise with an optional offset
22875 * center of rotation. Currently available in the CPU, WebGL, and WASM backends.
22876 *
22877 * @param image 4d tensor of shape `[batch, imageHeight, imageWidth, depth]`.
22878 * @param radians The amount of rotation.
22879 * @param fillValue The value to fill in the empty space leftover
22880 * after rotation. Can be either a single grayscale value (0-255), or an
22881 * array of three numbers `[red, green, blue]` specifying the red, green,
22882 * and blue channels. Defaults to `0` (black).
22883 * @param center The center of rotation. Can be either a single value (0-1), or
22884 * an array of two numbers `[centerX, centerY]`. Defaults to `0.5` (rotates
22885 * the image around its center).
22886 *
22887 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
22888 */
22889 function rotateWithOffset_(image, radians, fillValue = 0, center = 0.5) {
22890 const $image = convertToTensor(image, 'image', 'rotateWithOffset', 'float32');
22891 assert($image.rank === 4, () => 'Error in rotateWithOffset: image must be rank 4,' +
22892 `but got rank ${$image.rank}.`);
22893 const inputs = { image: $image };
22894 const attrs = { radians, fillValue, center };
22895 const res = ENGINE.runKernel(RotateWithOffset, inputs, attrs);
22896 return res;
22897 }
22898 const rotateWithOffset = op({ rotateWithOffset_ });
22899
22900 /**
22901 * @license
22902 * Copyright 2020 Google LLC. All Rights Reserved.
22903 * Licensed under the Apache License, Version 2.0 (the "License");
22904 * you may not use this file except in compliance with the License.
22905 * You may obtain a copy of the License at
22906 *
22907 * http://www.apache.org/licenses/LICENSE-2.0
22908 *
22909 * Unless required by applicable law or agreed to in writing, software
22910 * distributed under the License is distributed on an "AS IS" BASIS,
22911 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22912 * See the License for the specific language governing permissions and
22913 * limitations under the License.
22914 * =============================================================================
22915 */
22916 function nonMaxSuppSanityCheck(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
22917 if (iouThreshold == null) {
22918 iouThreshold = 0.5;
22919 }
22920 if (scoreThreshold == null) {
22921 scoreThreshold = Number.NEGATIVE_INFINITY;
22922 }
22923 if (softNmsSigma == null) {
22924 softNmsSigma = 0.0;
22925 }
22926 const numBoxes = boxes.shape[0];
22927 maxOutputSize = Math.min(maxOutputSize, numBoxes);
22928 assert(0 <= iouThreshold && iouThreshold <= 1, () => `iouThreshold must be in [0, 1], but was '${iouThreshold}'`);
22929 assert(boxes.rank === 2, () => `boxes must be a 2D tensor, but was of rank '${boxes.rank}'`);
22930 assert(boxes.shape[1] === 4, () => `boxes must have 4 columns, but 2nd dimension was ${boxes.shape[1]}`);
22931 assert(scores.rank === 1, () => 'scores must be a 1D tensor');
22932 assert(scores.shape[0] === numBoxes, () => `scores has incompatible shape with boxes. Expected ${numBoxes}, ` +
22933 `but was ${scores.shape[0]}`);
22934 assert(0 <= softNmsSigma && softNmsSigma <= 1, () => `softNmsSigma must be in [0, 1], but was '${softNmsSigma}'`);
22935 return { maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma };
22936 }
22937
22938 /**
22939 * @license
22940 * Copyright 2020 Google LLC. All Rights Reserved.
22941 * Licensed under the Apache License, Version 2.0 (the "License");
22942 * you may not use this file except in compliance with the License.
22943 * You may obtain a copy of the License at
22944 *
22945 * http://www.apache.org/licenses/LICENSE-2.0
22946 *
22947 * Unless required by applicable law or agreed to in writing, software
22948 * distributed under the License is distributed on an "AS IS" BASIS,
22949 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22950 * See the License for the specific language governing permissions and
22951 * limitations under the License.
22952 * =============================================================================
22953 */
22954 /**
22955 * Performs non maximum suppression of bounding boxes based on
22956 * iou (intersection over union).
22957 *
22958 * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
22959 * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
22960 * the bounding box.
22961 * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
22962 * @param maxOutputSize The maximum number of boxes to be selected.
22963 * @param iouThreshold A float representing the threshold for deciding whether
22964 * boxes overlap too much with respect to IOU. Must be between [0, 1].
22965 * Defaults to 0.5 (50% box overlap).
22966 * @param scoreThreshold A threshold for deciding when to remove boxes based
22967 * on score. Defaults to -inf, which means any score is accepted.
22968 * @return A 1D tensor with the selected box indices.
22969 *
22970 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
22971 */
22972 function nonMaxSuppression_(boxes, scores, maxOutputSize, iouThreshold = 0.5, scoreThreshold = Number.NEGATIVE_INFINITY) {
22973 const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression', 'float32');
22974 const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression', 'float32');
22975 const inputs = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold);
22976 maxOutputSize = inputs.maxOutputSize;
22977 iouThreshold = inputs.iouThreshold;
22978 scoreThreshold = inputs.scoreThreshold;
22979 const attrs = { maxOutputSize, iouThreshold, scoreThreshold };
22980 return ENGINE.runKernel(NonMaxSuppressionV3, { boxes: $boxes, scores: $scores }, attrs);
22981 }
22982 const nonMaxSuppression = op({ nonMaxSuppression_ });
22983
22984 /**
22985 * @license
22986 * Copyright 2019 Google LLC. All Rights Reserved.
22987 * Licensed under the Apache License, Version 2.0 (the "License");
22988 * you may not use this file except in compliance with the License.
22989 * You may obtain a copy of the License at
22990 *
22991 * http://www.apache.org/licenses/LICENSE-2.0
22992 *
22993 * Unless required by applicable law or agreed to in writing, software
22994 * distributed under the License is distributed on an "AS IS" BASIS,
22995 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22996 * See the License for the specific language governing permissions and
22997 * limitations under the License.
22998 * =============================================================================
22999 */
23000 /**
23001 * Inserts a value into a sorted array. This method allows duplicate, meaning it
23002 * allows inserting duplicate value, in which case, the element will be inserted
23003 * at the lowest index of the value.
23004 * @param arr The array to modify.
23005 * @param element The element to insert.
23006 * @param comparator Optional. If no comparator is specified, elements are
23007 * compared using array_util.defaultComparator, which is suitable for Strings
23008 * and Numbers in ascending arrays. If the array contains multiple instances of
23009 * the target value, the left-most instance will be returned. To provide a
23010 * comparator, it should take 2 arguments to compare and return a negative,
23011 * zero, or a positive number.
23012 */
23013 function binaryInsert(arr, element, comparator) {
23014 const index = binarySearch(arr, element, comparator);
23015 const insertionPoint = index < 0 ? -(index + 1) : index;
23016 arr.splice(insertionPoint, 0, element);
23017 }
23018 /**
23019 * Searches the array for the target using binary search, returns the index
23020 * of the found element, or position to insert if element not found. If no
23021 * comparator is specified, elements are compared using array_
23022 * util.defaultComparator, which is suitable for Strings and Numbers in
23023 * ascending arrays. If the array contains multiple instances of the target
23024 * value, the left-most instance will be returned.
23025 * @param arr The array to be searched in.
23026 * @param target The target to be searched for.
23027 * @param comparator Should take 2 arguments to compare and return a negative,
23028 * zero, or a positive number.
23029 * @return Lowest index of the target value if found, otherwise the insertion
23030 * point where the target should be inserted, in the form of
23031 * (-insertionPoint - 1).
23032 */
23033 function binarySearch(arr, target, comparator) {
23034 return binarySearch_(arr, target, comparator || defaultComparator);
23035 }
23036 /**
23037 * Compares its two arguments for order.
23038 * @param a The first element to be compared.
23039 * @param b The second element to be compared.
23040 * @return A negative number, zero, or a positive number as the first
23041 * argument is less than, equal to, or greater than the second.
23042 */
23043 function defaultComparator(a, b) {
23044 return a > b ? 1 : a < b ? -1 : 0;
23045 }
23046 function binarySearch_(arr, target, comparator) {
23047 let left = 0;
23048 let right = arr.length;
23049 let middle = 0;
23050 let found = false;
23051 while (left < right) {
23052 middle = left + ((right - left) >>> 1);
23053 const compareResult = comparator(target, arr[middle]);
23054 if (compareResult > 0) {
23055 left = middle + 1;
23056 }
23057 else {
23058 right = middle;
23059 // If compareResult is 0, the value is found. We record it is found,
23060 // and then keep looking because there may be duplicate.
23061 found = !compareResult;
23062 }
23063 }
23064 return found ? left : -left - 1;
23065 }
23066
23067 /**
23068 * @license
23069 * Copyright 2020 Google LLC. All Rights Reserved.
23070 * Licensed under the Apache License, Version 2.0 (the "License");
23071 * you may not use this file except in compliance with the License.
23072 * You may obtain a copy of the License at
23073 *
23074 * http://www.apache.org/licenses/LICENSE-2.0
23075 *
23076 * Unless required by applicable law or agreed to in writing, software
23077 * distributed under the License is distributed on an "AS IS" BASIS,
23078 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23079 * See the License for the specific language governing permissions and
23080 * limitations under the License.
23081 * =============================================================================
23082 */
23083 function nonMaxSuppressionV3Impl(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) {
23084 return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, 0 /* softNmsSigma */);
23085 }
23086 function nonMaxSuppressionV4Impl(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize) {
23087 return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, 0 /* softNmsSigma */, false /* returnScoresTensor */, padToMaxOutputSize /* padToMaxOutputSize */, true
23088 /* returnValidOutputs */ );
23089 }
23090 function nonMaxSuppressionV5Impl(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
23091 return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, true /* returnScoresTensor */);
23092 }
    /**
     * CPU reference implementation of non-max suppression shared by the V3, V4
     * and V5 variants (plain NMS and Soft-NMS).
     *
     * @param boxes Flat typed array of box corners, 4 values per box.
     * @param scores Typed array with one score per box.
     * @param maxOutputSize Maximum number of boxes to select.
     * @param iouThreshold Candidates whose IOU with a selected box reaches this
     *     value are dropped outright.
     * @param scoreThreshold Candidates scoring at or below this value are
     *     dropped.
     * @param softNmsSigma Soft-NMS sigma; 0 disables score decay, reducing this
     *     to classic NMS.
     * @param returnScoresTensor When true, include `selectedScores` in the
     *     result (V5).
     * @param padToMaxOutputSize When true, zero-pad the outputs up to
     *     maxOutputSize (V4).
     * @param returnValidOutputs When true, include `validOutputs`, the count of
     *     real (unpadded) selections (V4).
     * @return Object with `selectedIndices` and, depending on the flags,
     *     `selectedScores` and/or `validOutputs` (plain arrays, not tensors).
     */
    function nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, returnScoresTensor = false, padToMaxOutputSize = false, returnValidOutputs = false) {
        // The list is sorted in ascending order, so that we can always pop the
        // candidate with the largest score in O(1) time.
        const candidates = [];
        for (let i = 0; i < scores.length; i++) {
            if (scores[i] > scoreThreshold) {
                candidates.push({ score: scores[i], boxIndex: i, suppressBeginIndex: 0 });
            }
        }
        candidates.sort(ascendingComparator);
        // If softNmsSigma is 0, the outcome of this algorithm is exactly same as
        // before.
        const scale = softNmsSigma > 0 ? (-0.5 / softNmsSigma) : 0.0;
        const selectedIndices = [];
        const selectedScores = [];
        while (selectedIndices.length < maxOutputSize && candidates.length > 0) {
            const candidate = candidates.pop();
            const { score: originalScore, boxIndex, suppressBeginIndex } = candidate;
            if (originalScore < scoreThreshold) {
                break;
            }
            // Overlapping boxes are likely to have similar scores, therefore we
            // iterate through the previously selected boxes backwards in order to
            // see if candidate's score should be suppressed. We use
            // suppressBeginIndex to track and ensure a candidate can be suppressed
            // by a selected box no more than once. Also, if the overlap exceeds
            // iouThreshold, we simply ignore the candidate.
            let ignoreCandidate = false;
            for (let j = selectedIndices.length - 1; j >= suppressBeginIndex; --j) {
                const iou = intersectionOverUnion(boxes, boxIndex, selectedIndices[j]);
                if (iou >= iouThreshold) {
                    ignoreCandidate = true;
                    break;
                }
                // Soft-NMS: decay the score by a Gaussian penalty of the overlap.
                candidate.score =
                    candidate.score * suppressWeight(iouThreshold, scale, iou);
                if (candidate.score <= scoreThreshold) {
                    break;
                }
            }
            // At this point, if `candidate.score` has not dropped below
            // `scoreThreshold`, then we know that we went through all of the
            // previous selections and can safely update `suppressBeginIndex` to the
            // end of the selected array. Then we can re-insert the candidate with
            // the updated score and suppressBeginIndex back in the candidate list.
            // If on the other hand, `candidate.score` has dropped below the score
            // threshold, we will not add it back to the candidates list.
            candidate.suppressBeginIndex = selectedIndices.length;
            if (!ignoreCandidate) {
                // Candidate has passed all the tests, and is not suppressed, so
                // select the candidate.
                if (candidate.score === originalScore) {
                    selectedIndices.push(boxIndex);
                    selectedScores.push(candidate.score);
                }
                else if (candidate.score > scoreThreshold) {
                    // Candidate's score is suppressed but is still high enough to be
                    // considered, so add back to the candidates list.
                    binaryInsert(candidates, candidate, ascendingComparator);
                }
            }
        }
        // NonMaxSuppressionV4 feature: padding output to maxOutputSize.
        const validOutputs = selectedIndices.length;
        const elemsToPad = maxOutputSize - validOutputs;
        if (padToMaxOutputSize && elemsToPad > 0) {
            selectedIndices.push(...new Array(elemsToPad).fill(0));
            selectedScores.push(...new Array(elemsToPad).fill(0.0));
        }
        const result = { selectedIndices };
        if (returnScoresTensor) {
            result['selectedScores'] = selectedScores;
        }
        if (returnValidOutputs) {
            result['validOutputs'] = validOutputs;
        }
        return result;
    }
23171 function intersectionOverUnion(boxes, i, j) {
23172 const iCoord = boxes.subarray(i * 4, i * 4 + 4);
23173 const jCoord = boxes.subarray(j * 4, j * 4 + 4);
23174 const yminI = Math.min(iCoord[0], iCoord[2]);
23175 const xminI = Math.min(iCoord[1], iCoord[3]);
23176 const ymaxI = Math.max(iCoord[0], iCoord[2]);
23177 const xmaxI = Math.max(iCoord[1], iCoord[3]);
23178 const yminJ = Math.min(jCoord[0], jCoord[2]);
23179 const xminJ = Math.min(jCoord[1], jCoord[3]);
23180 const ymaxJ = Math.max(jCoord[0], jCoord[2]);
23181 const xmaxJ = Math.max(jCoord[1], jCoord[3]);
23182 const areaI = (ymaxI - yminI) * (xmaxI - xminI);
23183 const areaJ = (ymaxJ - yminJ) * (xmaxJ - xminJ);
23184 if (areaI <= 0 || areaJ <= 0) {
23185 return 0.0;
23186 }
23187 const intersectionYmin = Math.max(yminI, yminJ);
23188 const intersectionXmin = Math.max(xminI, xminJ);
23189 const intersectionYmax = Math.min(ymaxI, ymaxJ);
23190 const intersectionXmax = Math.min(xmaxI, xmaxJ);
23191 const intersectionArea = Math.max(intersectionYmax - intersectionYmin, 0.0) *
23192 Math.max(intersectionXmax - intersectionXmin, 0.0);
23193 return intersectionArea / (areaI + areaJ - intersectionArea);
23194 }
23195 // A Gaussian penalty function, this method always returns values in [0, 1].
23196 // The weight is a function of similarity, the more overlap two boxes are, the
23197 // smaller the weight is, meaning highly overlapping boxe will be significantly
23198 // penalized. On the other hand, a non-overlapping box will not be penalized.
23199 function suppressWeight(iouThreshold, scale, iou) {
23200 const weight = Math.exp(scale * iou * iou);
23201 return iou <= iouThreshold ? weight : 0.0;
23202 }
23203 function ascendingComparator(c1, c2) {
23204 // For objects with same scores, we make the object with the larger index go
23205 // first. In an array that pops from the end, this means that the object with
23206 // the smaller index will be popped first. This ensures the same output as
23207 // the TensorFlow python version.
23208 return (c1.score - c2.score) ||
23209 ((c1.score === c2.score) && (c2.boxIndex - c1.boxIndex));
23210 }
23211
23212 /**
23213 * @license
23214 * Copyright 2020 Google LLC. All Rights Reserved.
23215 * Licensed under the Apache License, Version 2.0 (the "License");
23216 * you may not use this file except in compliance with the License.
23217 * You may obtain a copy of the License at
23218 *
23219 * http://www.apache.org/licenses/LICENSE-2.0
23220 *
23221 * Unless required by applicable law or agreed to in writing, software
23222 * distributed under the License is distributed on an "AS IS" BASIS,
23223 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23224 * See the License for the specific language governing permissions and
23225 * limitations under the License.
23226 * =============================================================================
23227 */
23228 /**
23229 * Performs non maximum suppression of bounding boxes based on
23230 * iou (intersection over union).
23231 *
23232 * This is the async version of `nonMaxSuppression`
23233 *
23234 * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
23235 * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
23236 * the bounding box.
23237 * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
23238 * @param maxOutputSize The maximum number of boxes to be selected.
23239 * @param iouThreshold A float representing the threshold for deciding whether
23240 * boxes overlap too much with respect to IOU. Must be between [0, 1].
23241 * Defaults to 0.5 (50% box overlap).
23242 * @param scoreThreshold A threshold for deciding when to remove boxes based
23243 * on score. Defaults to -inf, which means any score is accepted.
23244 * @return A 1D tensor with the selected box indices.
23245 *
23246 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
23247 */
23248 async function nonMaxSuppressionAsync_(boxes, scores, maxOutputSize, iouThreshold = 0.5, scoreThreshold = Number.NEGATIVE_INFINITY) {
23249 const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync');
23250 const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync');
23251 const inputs = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold);
23252 maxOutputSize = inputs.maxOutputSize;
23253 iouThreshold = inputs.iouThreshold;
23254 scoreThreshold = inputs.scoreThreshold;
23255 const boxesAndScores = await Promise.all([$boxes.data(), $scores.data()]);
23256 const boxesVals = boxesAndScores[0];
23257 const scoresVals = boxesAndScores[1];
23258 // We call a cpu based impl directly with the typedarray data here rather
23259 // than a kernel because all kernels are synchronous (and thus cannot await
23260 // .data()).
23261 const { selectedIndices } = nonMaxSuppressionV3Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold);
23262 if ($boxes !== boxes) {
23263 $boxes.dispose();
23264 }
23265 if ($scores !== scores) {
23266 $scores.dispose();
23267 }
23268 return tensor1d(selectedIndices, 'int32');
23269 }
23270 const nonMaxSuppressionAsync = nonMaxSuppressionAsync_;
23271
23272 /**
23273 * @license
23274 * Copyright 2020 Google LLC. All Rights Reserved.
23275 * Licensed under the Apache License, Version 2.0 (the "License");
23276 * you may not use this file except in compliance with the License.
23277 * You may obtain a copy of the License at
23278 *
23279 * http://www.apache.org/licenses/LICENSE-2.0
23280 *
23281 * Unless required by applicable law or agreed to in writing, software
23282 * distributed under the License is distributed on an "AS IS" BASIS,
23283 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23284 * See the License for the specific language governing permissions and
23285 * limitations under the License.
23286 * =============================================================================
23287 */
23288 /**
23289 * Performs non maximum suppression of bounding boxes based on
23290 * iou (intersection over union).
23291 *
23292 * This op also supports a Soft-NMS mode (c.f.
23293 * Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score
23294 * of other overlapping boxes, therefore favoring different regions of the image
23295 * with high scores. To enable this Soft-NMS mode, set the `softNmsSigma`
23296 * parameter to be larger than 0.
23297 *
23298 * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
23299 * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
23300 * the bounding box.
23301 * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
23302 * @param maxOutputSize The maximum number of boxes to be selected.
23303 * @param iouThreshold A float representing the threshold for deciding whether
23304 * boxes overlap too much with respect to IOU. Must be between [0, 1].
23305 * Defaults to 0.5 (50% box overlap).
23306 * @param scoreThreshold A threshold for deciding when to remove boxes based
23307 * on score. Defaults to -inf, which means any score is accepted.
23308 * @param softNmsSigma A float representing the sigma parameter for Soft NMS.
23309 * When sigma is 0, it falls back to nonMaxSuppression.
23310 * @return A map with the following properties:
23311 * - selectedIndices: A 1D tensor with the selected box indices.
23312 * - selectedScores: A 1D tensor with the corresponding scores for each
23313 * selected box.
23314 *
23315 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
23316 */
23317 function nonMaxSuppressionWithScore_(boxes, scores, maxOutputSize, iouThreshold = 0.5, scoreThreshold = Number.NEGATIVE_INFINITY, softNmsSigma = 0.0) {
23318 const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression');
23319 const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression');
23320 const params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
23321 maxOutputSize = params.maxOutputSize;
23322 iouThreshold = params.iouThreshold;
23323 scoreThreshold = params.scoreThreshold;
23324 softNmsSigma = params.softNmsSigma;
23325 const inputs = { boxes: $boxes, scores: $scores };
23326 const attrs = { maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma };
23327 // tslint:disable-next-line: no-unnecessary-type-assertion
23328 const result = ENGINE.runKernel(NonMaxSuppressionV5, inputs, attrs);
23329 return { selectedIndices: result[0], selectedScores: result[1] };
23330 }
23331 const nonMaxSuppressionWithScore = op({ nonMaxSuppressionWithScore_ });
23332
23333 /**
23334 * @license
23335 * Copyright 2020 Google LLC. All Rights Reserved.
23336 * Licensed under the Apache License, Version 2.0 (the "License");
23337 * you may not use this file except in compliance with the License.
23338 * You may obtain a copy of the License at
23339 *
23340 * http://www.apache.org/licenses/LICENSE-2.0
23341 *
23342 * Unless required by applicable law or agreed to in writing, software
23343 * distributed under the License is distributed on an "AS IS" BASIS,
23344 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23345 * See the License for the specific language governing permissions and
23346 * limitations under the License.
23347 * =============================================================================
23348 */
23349 /**
23350 * Asynchronously performs non maximum suppression of bounding boxes based on
23351 * iou (intersection over union).
23352 *
23353 * This op also supports a Soft-NMS mode (c.f.
23354 * Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score
23355 * of other overlapping boxes, therefore favoring different regions of the image
23356 * with high scores. To enable this Soft-NMS mode, set the `softNmsSigma`
23357 * parameter to be larger than 0.
23358 *
23359 * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
23360 * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
23361 * the bounding box.
23362 * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
23363 * @param maxOutputSize The maximum number of boxes to be selected.
23364 * @param iouThreshold A float representing the threshold for deciding whether
23365 * boxes overlap too much with respect to IOU. Must be between [0, 1].
23366 * Defaults to 0.5 (50% box overlap).
23367 * @param scoreThreshold A threshold for deciding when to remove boxes based
23368 * on score. Defaults to -inf, which means any score is accepted.
23369 * @param softNmsSigma A float representing the sigma parameter for Soft NMS.
23370 * When sigma is 0, it falls back to nonMaxSuppression.
23371 * @return A map with the following properties:
23372 * - selectedIndices: A 1D tensor with the selected box indices.
23373 * - selectedScores: A 1D tensor with the corresponding scores for each
23374 * selected box.
23375 *
23376 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
23377 */
async function nonMaxSuppressionWithScoreAsync_(boxes, scores, maxOutputSize, iouThreshold = 0.5, scoreThreshold = Number.NEGATIVE_INFINITY, softNmsSigma = 0.0) {
    const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync');
    const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync');
    // Validate and normalize the numeric parameters before use.
    const params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
    const [boxesVals, scoresVals] = await Promise.all([$boxes.data(), $scores.data()]);
    // A CPU-based impl is invoked directly with the typed arrays rather than a
    // kernel, because kernels are synchronous and thus cannot await .data().
    const { selectedIndices, selectedScores } = nonMaxSuppressionV5Impl(boxesVals, scoresVals, params.maxOutputSize, params.iouThreshold, params.scoreThreshold, params.softNmsSigma);
    // Dispose intermediates that were created by convertToTensor.
    if ($boxes !== boxes) {
        $boxes.dispose();
    }
    if ($scores !== scores) {
        $scores.dispose();
    }
    return {
        selectedIndices: tensor1d(selectedIndices, 'int32'),
        selectedScores: tensor1d(selectedScores)
    };
}
const nonMaxSuppressionWithScoreAsync = nonMaxSuppressionWithScoreAsync_;
23405
23406 /**
23407 * @license
23408 * Copyright 2020 Google LLC. All Rights Reserved.
23409 * Licensed under the Apache License, Version 2.0 (the "License");
23410 * you may not use this file except in compliance with the License.
23411 * You may obtain a copy of the License at
23412 *
23413 * http://www.apache.org/licenses/LICENSE-2.0
23414 *
23415 * Unless required by applicable law or agreed to in writing, software
23416 * distributed under the License is distributed on an "AS IS" BASIS,
23417 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23418 * See the License for the specific language governing permissions and
23419 * limitations under the License.
23420 * =============================================================================
23421 */
23422 /**
23423 * Asynchronously performs non maximum suppression of bounding boxes based on
23424 * iou (intersection over union), with an option to pad results.
23425 *
23426 * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
23427 * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
23428 * the bounding box.
23429 * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
23430 * @param maxOutputSize The maximum number of boxes to be selected.
23431 * @param iouThreshold A float representing the threshold for deciding whether
23432 * boxes overlap too much with respect to IOU. Must be between [0, 1].
23433 * Defaults to 0.5 (50% box overlap).
23434 * @param scoreThreshold A threshold for deciding when to remove boxes based
23435 * on score. Defaults to -inf, which means any score is accepted.
23436 * @param padToMaxOutputSize Defaults to false. If true, size of output
23437 *     `selectedIndices` is padded to maxOutputSize.
23438 * @return A map with the following properties:
23439 * - selectedIndices: A 1D tensor with the selected box indices.
23440 * - validOutputs: A scalar denoting how many elements in `selectedIndices`
23441 * are valid. Valid elements occur first, then padding.
23442 *
23443 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
23444 */
function nonMaxSuppressionPadded_(boxes, scores, maxOutputSize, iouThreshold = 0.5, scoreThreshold = Number.NEGATIVE_INFINITY, padToMaxOutputSize = false) {
    const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression');
    const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression');
    // Validate and normalize the numeric parameters (soft-NMS is not used here).
    const params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, null /* softNmsSigma */);
    const inputs = { boxes: $boxes, scores: $scores };
    const attrs = {
        maxOutputSize: params.maxOutputSize,
        iouThreshold: params.iouThreshold,
        scoreThreshold: params.scoreThreshold,
        padToMaxOutputSize
    };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    const [selectedIndices, validOutputs] = ENGINE.runKernel(NonMaxSuppressionV4, inputs, attrs);
    return { selectedIndices, validOutputs };
}
const nonMaxSuppressionPadded = op({ nonMaxSuppressionPadded_ });
23464
23465 /**
23466 * @license
23467 * Copyright 2020 Google LLC. All Rights Reserved.
23468 * Licensed under the Apache License, Version 2.0 (the "License");
23469 * you may not use this file except in compliance with the License.
23470 * You may obtain a copy of the License at
23471 *
23472 * http://www.apache.org/licenses/LICENSE-2.0
23473 *
23474 * Unless required by applicable law or agreed to in writing, software
23475 * distributed under the License is distributed on an "AS IS" BASIS,
23476 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23477 * See the License for the specific language governing permissions and
23478 * limitations under the License.
23479 * =============================================================================
23480 */
23481 /**
23482 * Asynchronously performs non maximum suppression of bounding boxes based on
23483 * iou (intersection over union), with an option to pad results.
23484 *
23485 * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
23486 * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
23487 * the bounding box.
23488 * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
23489 * @param maxOutputSize The maximum number of boxes to be selected.
23490 * @param iouThreshold A float representing the threshold for deciding whether
23491 * boxes overlap too much with respect to IOU. Must be between [0, 1].
23492 * Defaults to 0.5 (50% box overlap).
23493 * @param scoreThreshold A threshold for deciding when to remove boxes based
23494 * on score. Defaults to -inf, which means any score is accepted.
23495 * @param padToMaxOutputSize Defaults to false. If true, size of output
23496 *     `selectedIndices` is padded to maxOutputSize.
23497 * @return A map with the following properties:
23498 * - selectedIndices: A 1D tensor with the selected box indices.
23499 * - validOutputs: A scalar denoting how many elements in `selectedIndices`
23500 * are valid. Valid elements occur first, then padding.
23501 *
23502 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
23503 */
async function nonMaxSuppressionPaddedAsync_(boxes, scores, maxOutputSize, iouThreshold = 0.5, scoreThreshold = Number.NEGATIVE_INFINITY, padToMaxOutputSize = false) {
    const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync');
    const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync');
    // Validate and normalize the numeric parameters (soft-NMS is not used here).
    const params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, null /* softNmsSigma */);
    const data = await Promise.all([$boxes.data(), $scores.data()]);
    const boxesVals = data[0];
    const scoresVals = data[1];
    // A CPU-based impl is invoked directly with the typed arrays rather than a
    // kernel, because kernels are synchronous and thus cannot await .data().
    const { selectedIndices, validOutputs } = nonMaxSuppressionV4Impl(boxesVals, scoresVals, params.maxOutputSize, params.iouThreshold, params.scoreThreshold, padToMaxOutputSize);
    // Dispose intermediates that were created by convertToTensor.
    if ($boxes !== boxes) {
        $boxes.dispose();
    }
    if ($scores !== scores) {
        $scores.dispose();
    }
    return {
        selectedIndices: tensor1d(selectedIndices, 'int32'),
        validOutputs: scalar(validOutputs, 'int32')
    };
}
const nonMaxSuppressionPaddedAsync = nonMaxSuppressionPaddedAsync_;
23528
23529 /**
23530 * @license
23531 * Copyright 2020 Google LLC. All Rights Reserved.
23532 * Licensed under the Apache License, Version 2.0 (the "License");
23533 * you may not use this file except in compliance with the License.
23534 * You may obtain a copy of the License at
23535 *
23536 * http://www.apache.org/licenses/LICENSE-2.0
23537 *
23538 * Unless required by applicable law or agreed to in writing, software
23539 * distributed under the License is distributed on an "AS IS" BASIS,
23540 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23541 * See the License for the specific language governing permissions and
23542 * limitations under the License.
23543 * =============================================================================
23544 */
23545 /**
23546 * Bilinear resize a single 3D image or a batch of 3D images to a new shape.
23547 *
23548 * @param images The images, of rank 4 or rank 3, of shape
23549 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
23550 * @param size The new shape `[newHeight, newWidth]` to resize the
23551 * images to. Each channel is resized individually.
23552 * @param alignCorners Defaults to `false`. If true, rescale
23553 * input by `(new_height - 1) / (height - 1)`, which exactly aligns the 4
23554 * corners of images and resized images. If false, rescale by
23555 * `new_height / height`. Treat similarly the width dimension.
23556 * @param halfPixelCenters Defaults to `false`. Whether to assume pixel centers
23557 * are at 0.5, which would make the floating point coordinates of the top
23558 * left pixel 0.5, 0.5.
23559 *
23560 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
23561 */
function resizeBilinear_(images, size, alignCorners = false, halfPixelCenters = false) {
    const $images = convertToTensor(images, 'images', 'resizeBilinear');
    assert($images.rank === 3 || $images.rank === 4, () => `Error in resizeBilinear: x must be rank 3 or 4, but got ` +
        `rank ${$images.rank}.`);
    assert(size.length === 2, () => `Error in resizeBilinear: new shape must 2D, but got shape ` +
        `${size}.`);
    assert(halfPixelCenters === false || alignCorners === false, () => `Error in resizeBilinear: If halfPixelCenters is true, ` +
        `alignCorners must be false.`);
    let batchImages = $images;
    let reshapedTo4D = false;
    if ($images.rank === 3) {
        // The kernel expects a rank-4 batch; wrap a single image in a batch of 1.
        reshapedTo4D = true;
        batchImages = reshape($images, [1, $images.shape[0], $images.shape[1], $images.shape[2]]);
    }
    // (Removed a dead `const [] = size;` no-op destructuring left over from
    // compilation; `size` is validated above and passed through in attrs.)
    const inputs = { images: batchImages };
    const attrs = { alignCorners, halfPixelCenters, size };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    const res = ENGINE.runKernel(ResizeBilinear, inputs, attrs);
    if (reshapedTo4D) {
        // Drop the synthetic batch dimension added above.
        return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
    }
    return res;
}
const resizeBilinear = op({ resizeBilinear_ });
23587
23588 /**
23589 * @license
23590 * Copyright 2020 Google LLC. All Rights Reserved.
23591 * Licensed under the Apache License, Version 2.0 (the "License");
23592 * you may not use this file except in compliance with the License.
23593 * You may obtain a copy of the License at
23594 *
23595 * http://www.apache.org/licenses/LICENSE-2.0
23596 *
23597 * Unless required by applicable law or agreed to in writing, software
23598 * distributed under the License is distributed on an "AS IS" BASIS,
23599 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23600 * See the License for the specific language governing permissions and
23601 * limitations under the License.
23602 * =============================================================================
23603 */
23604 /**
23605 * NearestNeighbor resize a batch of 3D images to a new shape.
23606 *
23607 * @param images The images, of rank 4 or rank 3, of shape
23608 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
23609 * @param size The new shape `[newHeight, newWidth]` to resize the
23610 * images to. Each channel is resized individually.
23611 * @param alignCorners Defaults to False. If true, rescale
23612 * input by `(new_height - 1) / (height - 1)`, which exactly aligns the 4
23613 * corners of images and resized images. If false, rescale by
23614 * `new_height / height`. Treat similarly the width dimension.
23615 * @param halfPixelCenters Defaults to `false`. Whether to assumes pixels are of
23616 * half the actual dimensions, and yields more accurate resizes. This flag
23617 * would also make the floating point coordinates of the top left pixel
23618 * 0.5, 0.5.
23619 *
23620 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
23621 */
function resizeNearestNeighbor_(images, size, alignCorners = false, halfPixelCenters = false) {
    const $images = convertToTensor(images, 'images', 'resizeNearestNeighbor');
    assert($images.rank === 3 || $images.rank === 4, () => `Error in resizeNearestNeighbor: x must be rank 3 or 4, but got ` +
        `rank ${$images.rank}.`);
    assert(size.length === 2, () => `Error in resizeNearestNeighbor: new shape must 2D, but got shape ` +
        `${size}.`);
    assert($images.dtype === 'float32' || $images.dtype === 'int32', () => '`images` must have `int32` or `float32` as dtype');
    assert(halfPixelCenters === false || alignCorners === false, () => `Error in resizeNearestNeighbor: If halfPixelCenters is true, ` +
        `alignCorners must be false.`);
    let batchImages = $images;
    let reshapedTo4D = false;
    if ($images.rank === 3) {
        // The kernel expects a rank-4 batch; wrap a single image in a batch of 1.
        reshapedTo4D = true;
        batchImages = reshape($images, [1, $images.shape[0], $images.shape[1], $images.shape[2]]);
    }
    // (Removed a dead `const [] = size;` no-op destructuring left over from
    // compilation; `size` is validated above and passed through in attrs.)
    const inputs = { images: batchImages };
    const attrs = { alignCorners, halfPixelCenters, size };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    const res = ENGINE.runKernel(ResizeNearestNeighbor, inputs, attrs);
    if (reshapedTo4D) {
        // Drop the synthetic batch dimension added above.
        return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
    }
    return res;
}
const resizeNearestNeighbor = op({ resizeNearestNeighbor_ });
23648
23649 /**
23650 * @license
23651 * Copyright 2021 Google LLC. All Rights Reserved.
23652 * Licensed under the Apache License, Version 2.0 (the "License");
23653 * you may not use this file except in compliance with the License.
23654 * You may obtain a copy of the License at
23655 *
23656 * https://www.apache.org/licenses/LICENSE-2.0
23657 *
23658 * Unless required by applicable law or agreed to in writing, software
23659 * distributed under the License is distributed on an "AS IS" BASIS,
23660 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23661 * See the License for the specific language governing permissions and
23662 * limitations under the License.
23663 * =============================================================================
23664 */
23665 /**
23666 * Performs image binarization with corresponding threshold
23667 * (depends on the method)value, which creates a binary image from a grayscale.
23668 * @param image 3d tensor of shape [imageHeight,imageWidth, depth],
23669 * where imageHeight and imageWidth must be positive.The image color
23670 * range should be [0, 255].
23671 * @param method Optional string from `'binary' | 'otsu'`
23672 * which specifies the method for thresholding. Defaults to 'binary'.
23673 * @param inverted Optional boolean which specifies
23674 *     if colours should be inverted. Defaults to false.
23675 * @param threshValue Optional number which defines threshold value from 0 to 1.
23676 * Defaults to 0.5.
23677 * @return A 3d tensor of shape [imageHeight,imageWidth, depth], which
23678 * contains binarized image.
23679 */
function threshold_(image, method = 'binary', inverted = false, threshValue = 0.5) {
    const $image = convertToTensor(image, 'image', 'threshold');
    /* 0.2989, 0.5870, 0.1140 represent the luma coefficients in CCIR601.
    Reference for converting between RGB and grayscale: https://en.wikipedia.org/wiki/Luma_%28video%29 */
    const RED_INTENCITY_COEF = 0.2989;
    const GREEN_INTENCITY_COEF = 0.5870;
    const BLUE_INTENCITY_COEF = 0.1140;
    const totalPixelsInImage = $image.shape[0] * $image.shape[1];
    let $threshold = mul(tensor1d([threshValue]), 255);
    let r, g, b, grayscale;
    assert($image.rank === 3, () => 'Error in threshold: image must be rank 3,' +
        `but got rank ${$image.rank}.`);
    assert($image.shape[2] === 3 || $image.shape[2] === 1, () => 'Error in threshold: ' +
        'image color channel must be equal to 3 or 1' +
        `but got ${$image.shape[2]}.`);
    assert($image.dtype === 'int32' || $image.dtype === 'float32', () => 'Error in dtype: image dtype must be int32 or float32,' +
        `but got dtype ${$image.dtype}.`);
    assert(method === 'otsu' || method === 'binary', () => `Method must be binary or otsu, but was ${method}`);
    if ($image.shape[2] === 3) {
        [r, g, b] = split($image, [1, 1, 1], -1);
        const $r = mul(r, RED_INTENCITY_COEF);
        const $g = mul(g, GREEN_INTENCITY_COEF);
        const $b = mul(b, BLUE_INTENCITY_COEF);
        grayscale = add$1(add$1($r, $g), $b);
    }
    else {
        // Fix: use the converted tensor `$image`, not the raw `image` argument,
        // so TensorLike inputs (plain nested arrays) take the same path as
        // tensors — consistent with the 3-channel branch above.
        grayscale = $image;
    }
    if (method === 'otsu') {
        // Otsu's method derives the threshold from the 256-bin histogram of the
        // rounded grayscale intensities.
        const $histogram = bincount(cast(round$1(grayscale), 'int32'), tensor([]), 256);
        $threshold = otsu($histogram, totalPixelsInImage);
    }
    const invCondition = inverted ?
        lessEqual(grayscale, $threshold) : greater(grayscale, $threshold);
    // Boolean mask scaled to the [0, 255] image range.
    const result = cast(mul(invCondition, 255), 'int32');
    return result;
}
/**
 * Computes Otsu's threshold for a 256-bin intensity histogram by maximizing
 * the between-class variance over all candidate split points.
 *
 * @param histogram 1D tensor of bin counts.
 * @param total Total number of pixels in the image (normalizes the weights).
 * @returns 1D tensor holding the best threshold bin index (-1 if none found).
 */
function otsu(histogram, total) {
    let bestThresh = tensor1d([-1]);
    let bestInBetVar = tensor1d([0]);
    let cInBetVar = tensor1d([0]);
    for (let index = 0; index < histogram.size - 1; index++) {
        const classFirst = slice(histogram, 0, index + 1);
        const classSecond = slice(histogram, index + 1);
        const weightForeground = div(sum$1(classFirst), total);
        const weightBack = div(sum$1(classSecond), total);
        const meanFirstDivA = sum$1(mul(classFirst, range(0, classFirst.size)));
        const meanFirst = div(meanFirstDivA, sum$1(classFirst));
        // Offset the second class's bin indices so they continue after the split.
        const meanSecFill = fill(classSecond.shape, classFirst.size);
        const meanSecAdd = add$1(range(0, classSecond.size), meanSecFill);
        const meanSecMul = mul(classSecond, meanSecAdd);
        const meanSec = div(sum$1(meanSecMul), sum$1(classSecond));
        // Between-class variance = w_fg * w_bg * (meanFirst - meanSec)^2.
        // Compute the mean difference once instead of building two identical
        // tensors for the two factors of the square.
        const meanDiff = sub(meanFirst, meanSec);
        const cInBetVarMul = mul(weightForeground, weightBack);
        cInBetVar = mul(mul(cInBetVarMul, meanDiff), meanDiff);
        const condition = greater(cInBetVar, bestInBetVar);
        bestInBetVar = where(condition, cInBetVar, bestInBetVar);
        bestThresh = where(condition, tensor1d([index]), bestThresh);
    }
    return bestThresh;
}
const threshold = op({ threshold_ });
23744
23745 /**
23746 * @license
23747 * Copyright 2021 Google LLC. All Rights Reserved.
23748 * Licensed under the Apache License, Version 2.0 (the "License");
23749 * you may not use this file except in compliance with the License.
23750 * You may obtain a copy of the License at
23751 *
23752 * http://www.apache.org/licenses/LICENSE-2.0
23753 *
23754 * Unless required by applicable law or agreed to in writing, software
23755 * distributed under the License is distributed on an "AS IS" BASIS,
23756 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23757 * See the License for the specific language governing permissions and
23758 * limitations under the License.
23759 * =============================================================================
23760 */
23761 /**
23762 * Applies the given transform(s) to the image(s).
23763 *
23764 * @param image 4d tensor of shape `[batch, imageHeight, imageWidth, depth]`.
23765 * @param transforms Projective transform matrix/matrices. A tensor1d of length
23766 * 8 or tensor of size N x 8. If one row of transforms is [a0, a1, a2, b0
23767 * b1, b2, c0, c1], then it maps the output point (x, y) to a transformed
23768 * input point (x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k),
23769 * where k = c0 x + c1 y + 1. The transforms are inverted compared to the
23770 * transform mapping input points to output points.
23771 * @param interpolation Interpolation mode.
23772 * Supported values: 'nearest', 'bilinear'. Default to 'nearest'.
23773 * @param fillMode Points outside the boundaries of the input are filled
23774 * according to the given mode, one of 'constant', 'reflect', 'wrap',
23775 * 'nearest'. Default to 'constant'.
23776 * 'reflect': (d c b a | a b c d | d c b a ) The input is extended by
23777 * reflecting about the edge of the last pixel.
23778 * 'constant': (k k k k | a b c d | k k k k) The input is extended by
23779 * filling all values beyond the edge with the same constant value k.
23780 * 'wrap': (a b c d | a b c d | a b c d) The input is extended by
23781 * wrapping around to the opposite edge.
23782 * 'nearest': (a a a a | a b c d | d d d d) The input is extended by
23783 * the nearest pixel.
23784 * @param fillValue A float represents the value to be filled outside the
23785 * boundaries when fillMode is 'constant'.
23786 * @param outputShape Output dimension after the transform, [height, width].
23787 *     If undefined, output is the same size as input image.
23788 *
23789 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
23790 */
function transform_(image, transforms, interpolation = 'nearest', fillMode = 'constant', fillValue = 0, outputShape) {
    const $image = convertToTensor(image, 'image', 'transform', 'float32');
    const $transforms = convertToTensor(transforms, 'transforms', 'transform', 'float32');
    assert($image.rank === 4, () => 'Error in transform: image must be rank 4,' +
        `but got rank ${$image.rank}.`);
    // Transforms must be [batch, 8] or broadcastable [1, 8].
    const batchMatches = $transforms.shape[0] === $image.shape[0] ||
        $transforms.shape[0] === 1;
    assert($transforms.rank === 2 && batchMatches && $transforms.shape[1] === 8, () => `Error in transform: Input transform should be batch x 8 or 1 x 8`);
    assert(outputShape == null || outputShape.length === 2, () => 'Error in transform: outputShape must be [height, width] or null, ' +
        `but got ${outputShape}.`);
    const inputs = { image: $image, transforms: $transforms };
    const attrs = { interpolation, fillMode, fillValue, outputShape };
    return ENGINE.runKernel(Transform, inputs, attrs);
}
const transform = op({ transform_ });
23807
23808 /**
23809 * @license
23810 * Copyright 2020 Google LLC. All Rights Reserved.
23811 * Licensed under the Apache License, Version 2.0 (the "License");
23812 * you may not use this file except in compliance with the License.
23813 * You may obtain a copy of the License at
23814 *
23815 * http://www.apache.org/licenses/LICENSE-2.0
23816 *
23817 * Unless required by applicable law or agreed to in writing, software
23818 * distributed under the License is distributed on an "AS IS" BASIS,
23819 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23820 * See the License for the specific language governing permissions and
23821 * limitations under the License.
23822 * =============================================================================
23823 */
23824 /**
23825 * Copy a tensor setting everything outside a central band in each innermost
23826 * matrix to zero.
23827 *
23828 * The band part is computed as follows: Assume input has `k` dimensions
23829 * `[I, J, K, ..., M, N]`, then the output is a tensor with the same shape where
23830 * `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
23831 * The indicator function
23832 * `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower))`
23833 * `&& (num_upper < 0 || (n-m) <= num_upper)`
23834 *
23835 * ```js
23836 * const x = tf.tensor2d([[ 0, 1, 2, 3],
23837 * [-1, 0, 1, 2],
23838 * [-2, -1, 0, 1],
23839 * [-3, -2, -1, 0]]);
23840 * let y = tf.linalg.bandPart(x, 1, -1);
23841 * y.print(); // [[ 0, 1, 2, 3],
23842 * // [-1, 0, 1, 2],
23843 * // [ 0, -1, 0, 1],
23844 * // [ 0, 0 , -1, 0]]
23845 * let z = tf.linalg.bandPart(x, 2, 1);
23846 * z.print(); // [[ 0, 1, 0, 0],
23847 * // [-1, 0, 1, 0],
23848 * // [-2, -1, 0, 1],
23849 * // [ 0, -2, -1, 0]]
23850 * ```
23851 *
23852 * @param x Rank `k` tensor
23853 * @param numLower Number of subdiagonals to keep.
23854 * If negative, keep entire lower triangle.
23855 * @param numUpper Number of subdiagonals to keep.
23856 * If negative, keep entire upper triangle.
23857 * @returns Rank `k` tensor of the same shape as input.
23858 * The extracted banded tensor.
23859 *
23860 * @doc {heading:'Operations', subheading:'Linear Algebra', namespace:'linalg'}
23861 */
function bandPart_(a, numLower, numUpper) {
    assert(numLower % 1 === 0, () => `bandPart(): numLower must be an integer, got ${numLower}.`);
    assert(numUpper % 1 === 0, () => `bandPart(): numUpper must be an integer, got ${numUpper}.`);
    const $a = convertToTensor(a, 'a', 'bandPart');
    assert($a.rank >= 2, () => `bandPart(): Rank must be at least 2, got ${$a.rank}.`);
    const shape = $a.shape;
    const [M, N] = $a.shape.slice(-2);
    if (!(numLower <= M)) {
        throw new Error(`bandPart(): numLower (${numLower})` +
            ` must not be greater than the number of rows (${M}).`);
    }
    if (!(numUpper <= N)) {
        throw new Error(`bandPart(): numUpper (${numUpper})` +
            ` must not be greater than the number of columns (${N}).`);
    }
    // A negative band width means "keep the entire triangle on that side".
    const lower = numLower < 0 ? M : numLower;
    const upper = numUpper < 0 ? N : numUpper;
    // Build the in-band indicator: lower <= (row - col) and (row - col) >= -upper.
    const i = reshape(range(0, M, 1, 'int32'), [-1, 1]);
    const j = range(0, N, 1, 'int32');
    const ij = sub(i, j);
    const inBand = logicalAnd(lessEqual(ij, scalar(lower, 'int32')), greaterEqual(ij, scalar(-upper, 'int32')));
    const zero = zeros([M, N], $a.dtype);
    // Apply the mask to every innermost matrix, then restore the input shape.
    const mats = unstack(reshape($a, [-1, M, N]));
    const banded = mats.map(mat => where(inBand, mat, zero));
    return reshape(stack(banded), shape);
}
const bandPart = op({ bandPart_ });
23892
23893 /**
23894 * @license
23895 * Copyright 2020 Google LLC. All Rights Reserved.
23896 * Licensed under the Apache License, Version 2.0 (the "License");
23897 * you may not use this file except in compliance with the License.
23898 * You may obtain a copy of the License at
23899 *
23900 * http://www.apache.org/licenses/LICENSE-2.0
23901 *
23902 * Unless required by applicable law or agreed to in writing, software
23903 * distributed under the License is distributed on an "AS IS" BASIS,
23904 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23905 * See the License for the specific language governing permissions and
23906 * limitations under the License.
23907 * =============================================================================
23908 */
23909 /**
23910 * Gram-Schmidt orthogonalization.
23911 *
23912 * ```js
23913 * const x = tf.tensor2d([[1, 2], [3, 4]]);
23914 * let y = tf.linalg.gramSchmidt(x);
23915 * y.print();
23916 * console.log('Orthogonalized:');
23917 * y.dot(y.transpose()).print(); // should be nearly the identity matrix.
23918 * console.log('First row direction maintained:');
23919 * const data = await y.array();
23920 * console.log(data[0][1] / data[0][0]); // should be nearly 2.
23921 * ```
23922 *
23923 * @param xs The vectors to be orthogonalized, in one of the two following
23924 * formats:
23925 * - An Array of `tf.Tensor1D`.
23926 * - A `tf.Tensor2D`, i.e., a matrix, in which case the vectors are the rows
23927 * of `xs`.
23928 * In each case, all the vectors must have the same length and the length
23929 * must be greater than or equal to the number of vectors.
23930 * @returns The orthogonalized and normalized vectors or matrix.
23931 * Orthogonalization means that the vectors or the rows of the matrix
23932 * are orthogonal (zero inner products). Normalization means that each
23933 * vector or each row of the matrix has an L2 norm that equals `1`.
23934 *
23935 * @doc {heading:'Operations', subheading:'Linear Algebra', namespace:'linalg'}
23936 */
function gramSchmidt_(xs) {
    const inputIsTensor2D = !Array.isArray(xs);
    if (inputIsTensor2D) {
        // Slice the matrix into its rows, each squeezed down to a 1D vector.
        xs = split(xs, xs.shape[0], 0).map(x => squeeze(x, [0]));
    }
    else {
        assert(xs != null && xs.length > 0, () => 'Gram-Schmidt process: input must not be null, undefined, or ' +
            'empty');
        const dim = xs[0].shape[0];
        for (let i = 1; i < xs.length; ++i) {
            assert(xs[i].shape[0] === dim, () => 'Gram-Schmidt: Non-unique lengths found in the input vectors: ' +
                `(${xs[i].shape[0]} vs. ${dim})`);
        }
    }
    assert(xs.length <= xs[0].shape[0], () => `Gram-Schmidt: Number of vectors (${xs.length}) exceeds ` +
        `number of dimensions (${xs[0].shape[0]}).`);
    const ys = [];
    for (let i = 0; i < xs.length; ++i) {
        ys.push(ENGINE.tidy(() => {
            // Subtract the projection onto every previously orthogonalized
            // vector, then normalize to unit L2 norm.
            let x = xs[i];
            for (let j = 0; j < i; ++j) {
                const proj = mul(sum$1(mul(ys[j], x)), ys[j]);
                x = sub(x, proj);
            }
            return div(x, norm(x, 'euclidean'));
        }));
    }
    return inputIsTensor2D ? stack(ys, 0) : ys;
}
const gramSchmidt = op({ gramSchmidt_ });
23977
23978 /**
23979 * @license
23980 * Copyright 2020 Google LLC. All Rights Reserved.
23981 * Licensed under the Apache License, Version 2.0 (the "License");
23982 * you may not use this file except in compliance with the License.
23983 * You may obtain a copy of the License at
23984 *
23985 * http://www.apache.org/licenses/LICENSE-2.0
23986 *
23987 * Unless required by applicable law or agreed to in writing, software
23988 * distributed under the License is distributed on an "AS IS" BASIS,
23989 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23990 * See the License for the specific language governing permissions and
23991 * limitations under the License.
23992 * =============================================================================
23993 */
23994 /**
23995 * Compute QR decomposition of m-by-n matrix using Householder transformation.
23996 *
23997 * Implementation based on
23998 * [http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf]
23999 * (http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf)
24000 *
24001 * ```js
24002 * const a = tf.tensor2d([[1, 2], [3, 4]]);
24003 * let [q, r] = tf.linalg.qr(a);
24004 * console.log('Q');
24005 * q.print();
24006 * console.log('R');
24007 * r.print();
24008 * console.log('Orthogonalized');
24009 * q.dot(q.transpose()).print() // should be nearly the identity matrix.
24010 * console.log('Reconstructed');
24011 * q.dot(r).print(); // should be nearly [[1, 2], [3, 4]];
24012 * ```
24013 *
24014 * @param x The `tf.Tensor` to be QR-decomposed. Must have rank >= 2. Suppose
24015 * it has the shape `[..., M, N]`.
24016 * @param fullMatrices An optional boolean parameter. Defaults to `false`.
24017 * If `true`, compute full-sized `Q`. If `false` (the default),
24018 * compute only the leading N columns of `Q` and `R`.
24019 * @returns An `Array` of two `tf.Tensor`s: `[Q, R]`. `Q` is a unitary matrix,
24020 * i.e., its columns all have unit norm and are mutually orthogonal.
24021 * If `M >= N`,
24022 * If `fullMatrices` is `false` (default),
24023 * - `Q` has a shape of `[..., M, N]`,
24024 * - `R` has a shape of `[..., N, N]`.
     * If `fullMatrices` is `true`,
24026 * - `Q` has a shape of `[..., M, M]`,
24027 * - `R` has a shape of `[..., M, N]`.
24028 * If `M < N`,
24029 * - `Q` has a shape of `[..., M, M]`,
24030 * - `R` has a shape of `[..., M, N]`.
24031 * @throws If the rank of `x` is less than 2.
24032 *
24033 * @doc {heading:'Operations',
24034 * subheading:'Linear Algebra',
24035 * namespace:'linalg'}
24036 */
24037 function qr_(x, fullMatrices = false) {
24038 assert(x.rank >= 2, () => `qr() requires input tensor to have a rank >= 2, but got rank ${x.rank}`);
24039 if (x.rank === 2) {
24040 return qr2d(x, fullMatrices);
24041 }
24042 else {
24043 // Rank > 2.
24044 // TODO(cais): Below we split the input into individual 2D tensors,
24045 // perform QR decomposition on them and then stack the results back
24046 // together. We should explore whether this can be parallelized.
24047 const outerDimsProd = x.shape.slice(0, x.shape.length - 2)
24048 .reduce((value, prev) => value * prev);
24049 const x2ds = unstack(reshape(x, [
24050 outerDimsProd, x.shape[x.shape.length - 2],
24051 x.shape[x.shape.length - 1]
24052 ]), 0);
24053 const q2ds = [];
24054 const r2ds = [];
24055 x2ds.forEach(x2d => {
24056 const [q2d, r2d] = qr2d(x2d, fullMatrices);
24057 q2ds.push(q2d);
24058 r2ds.push(r2d);
24059 });
24060 const q = reshape(stack(q2ds, 0), x.shape);
24061 const r = reshape(stack(r2ds, 0), x.shape);
24062 return [q, r];
24063 }
24064 }
    /**
     * QR decomposition of a single 2D tensor via Householder reflections.
     *
     * Helper for `qr_` above; assumes `x` is rank 2 with shape `[m, n]`.
     * Returns `[q, r]` where, if `fullMatrices` is false and `m > n`, `q` is
     * `[m, n]` and `r` is `[n, n]`; otherwise `q` is `[m, m]` and `r` is
     * `[m, n]`. All intermediates are created inside `ENGINE.tidy` scopes so
     * temporary tensors are disposed promptly.
     */
    function qr2d(x, fullMatrices = false) {
        return ENGINE.tidy(() => {
            assert(x.shape.length === 2, () => `qr2d() requires a 2D Tensor, but got a ${x.shape.length}D Tensor.`);
            const m = x.shape[0];
            const n = x.shape[1];
            let q = eye(m); // Orthogonal transform so far.
            let r = clone(x); // Transformed matrix so far.
            const one2D = tensor2d([[1]], [1, 1]);
            let w = clone(one2D);
            // One Householder step per column (or per row when m < n).
            const iters = m >= n ? n : m;
            for (let j = 0; j < iters; ++j) {
                // This tidy within the for-loop ensures we clean up temporary
                // tensors as soon as they are no longer needed.
                const rTemp = r;
                const wTemp = w;
                const qTemp = q;
                [w, r, q] = ENGINE.tidy(() => {
                    // Find H = I - tau * w * w', to put zeros below R(j, j).
                    const rjEnd1 = slice(r, [j, j], [m - j, 1]);
                    const normX = norm(rjEnd1);
                    const rjj = slice(r, [j, j], [1, 1]);
                    // The sign() function returns 0 on 0, which causes division by zero.
                    const s = where(greater(rjj, 0), tensor2d([[-1]]), tensor2d([[1]]));
                    const u1 = sub(rjj, mul(s, normX));
                    const wPre = div(rjEnd1, u1);
                    if (wPre.shape[0] === 1) {
                        // Single-row reflector: w is just [[1]].
                        w = clone(one2D);
                    }
                    else {
                        // Householder vector with its first entry normalized to 1.
                        w = concat([
                            one2D,
                            slice(wPre, [1, 0], [wPre.shape[0] - 1, wPre.shape[1]])
                        ], 0);
                    }
                    const tau = neg(div(matMul(s, u1), normX));
                    // -- R := HR, Q := QH.
                    const rjEndAll = slice(r, [j, 0], [m - j, n]);
                    const tauTimesW = mul(tau, w);
                    const wT = transpose(w);
                    if (j === 0) {
                        r = sub(rjEndAll, matMul(tauTimesW, matMul(wT, rjEndAll)));
                    }
                    else {
                        // Only rows j..m-1 are affected; keep rows 0..j-1 as-is.
                        const rTimesTau = sub(rjEndAll, matMul(tauTimesW, matMul(wT, rjEndAll)));
                        r = concat([slice(r, [0, 0], [j, n]), rTimesTau], 0);
                    }
                    const tawTimesWT = transpose(tauTimesW);
                    const qAllJEnd = slice(q, [0, j], [m, q.shape[1] - j]);
                    if (j === 0) {
                        q = sub(qAllJEnd, matMul(matMul(qAllJEnd, w), tawTimesWT));
                    }
                    else {
                        // Only columns j.. are affected; keep columns 0..j-1 as-is.
                        const qTimesTau = sub(qAllJEnd, matMul(matMul(qAllJEnd, w), tawTimesWT));
                        q = concat([slice(q, [0, 0], [m, j]), qTimesTau], 1);
                    }
                    return [w, r, q];
                });
                // Release the previous iteration's w, r, q, which were kept
                // alive across the inner tidy above.
                dispose([rTemp, wTemp, qTemp]);
            }
            // Reduced factorization: trim Q to [m, n] and R to [n, n].
            if (!fullMatrices && m > n) {
                q = slice(q, [0, 0], [m, n]);
                r = slice(r, [0, 0], [n, n]);
            }
            return [q, r];
        });
    }
    const qr = op({ qr_ });
24132
24133 /**
24134 * @license
24135 * Copyright 2020 Google LLC. All Rights Reserved.
24136 * Licensed under the Apache License, Version 2.0 (the "License");
24137 * you may not use this file except in compliance with the License.
24138 * You may obtain a copy of the License at
24139 *
24140 * http://www.apache.org/licenses/LICENSE-2.0
24141 *
24142 * Unless required by applicable law or agreed to in writing, software
24143 * distributed under the License is distributed on an "AS IS" BASIS,
24144 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24145 * See the License for the specific language governing permissions and
24146 * limitations under the License.
24147 * =============================================================================
24148 */
    // TypeScript `enum Reduction` emit: builds a reverse-mapped object on
    // `exports.Reduction`, so both name -> number (Reduction.MEAN === 1)
    // and number -> name (Reduction[1] === 'MEAN') lookups work.
    (function (Reduction) {
        Reduction[Reduction["NONE"] = 0] = "NONE";
        Reduction[Reduction["MEAN"] = 1] = "MEAN";
        Reduction[Reduction["SUM"] = 2] = "SUM";
        Reduction[Reduction["SUM_BY_NONZERO_WEIGHTS"] = 3] = "SUM_BY_NONZERO_WEIGHTS";
    })(exports.Reduction || (exports.Reduction = {}));
24155
24156 /**
24157 * Computes the weighted loss between two tensors.
24158 *
24159 * @param losses Tensor of shape `[batch_size, d1, ... dN]`.
24160 * @param weights Tensor whose rank is either 0, or the same rank as
24161 * `losses`, and must be broadcastable to `losses` (i.e., all
24162 * dimensions must be either `1`, or the same as the corresponding
24163 * `losses` dimension).
24164 *
24165 * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
24166 */
24167 function computeWeightedLoss_(losses, weights, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
24168 const $losses = convertToTensor(losses, 'losses', 'computeWeightedLoss');
24169 let $weights = null;
24170 if (weights != null) {
24171 $weights = convertToTensor(weights, 'weights', 'computeWeightedLoss');
24172 }
24173 const weightedLoss = ($weights == null) ? $losses : mul($losses, $weights);
24174 if (reduction === exports.Reduction.NONE) {
24175 return weightedLoss;
24176 }
24177 if (reduction === exports.Reduction.SUM) {
24178 return sum$1(weightedLoss);
24179 }
24180 if (reduction === exports.Reduction.MEAN) {
24181 if ($weights == null) {
24182 return mean(weightedLoss);
24183 }
24184 else {
24185 const broadcastFactor = $losses.size / $weights.size;
24186 const result = div(sum$1(weightedLoss), sum$1($weights));
24187 return broadcastFactor > 1 ? div(result, scalar(broadcastFactor)) :
24188 result;
24189 }
24190 }
24191 if (reduction === exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
24192 if ($weights == null) {
24193 return div(sum$1(weightedLoss), scalar($losses.size));
24194 }
24195 else {
24196 const broadcastedWeights = mul($weights, ones$1($losses.shape));
24197 const numNonZeros = cast(sum$1(notEqual(broadcastedWeights, scalar(0))), 'float32');
24198 return div(sum$1(weightedLoss), numNonZeros);
24199 }
24200 }
24201 throw Error(`Unknown reduction: ${reduction}`);
24202 }
24203 const computeWeightedLoss = op({ computeWeightedLoss_ });
24204
24205 /**
24206 * @license
24207 * Copyright 2020 Google LLC. All Rights Reserved.
24208 * Licensed under the Apache License, Version 2.0 (the "License");
24209 * you may not use this file except in compliance with the License.
24210 * You may obtain a copy of the License at
24211 *
24212 * http://www.apache.org/licenses/LICENSE-2.0
24213 *
24214 * Unless required by applicable law or agreed to in writing, software
24215 * distributed under the License is distributed on an "AS IS" BASIS,
24216 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24217 * See the License for the specific language governing permissions and
24218 * limitations under the License.
24219 * =============================================================================
24220 */
24221 /**
24222 * Computes the absolute difference loss between two tensors.
24223 *
24224 * @param labels The ground truth output tensor, same dimensions as
24225 * 'predictions'.
24226 * @param predictions The predicted outputs.
24227 * @param weights Tensor whose rank is either 0, or the same rank as
24228 * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
24229 * must be either `1`, or the same as the corresponding `losses`
24230 * dimension).
24231 * @param reduction Type of reduction to apply to loss. Should be of type
24232 * `Reduction`
24233 *
24234 * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
24235 */
24236 function absoluteDifference_(labels, predictions, weights, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
24237 const $labels = convertToTensor(labels, 'labels', 'absoluteDifference');
24238 const $predictions = convertToTensor(predictions, 'predictions', 'absoluteDifference');
24239 let $weights = null;
24240 if (weights != null) {
24241 $weights = convertToTensor(weights, 'weights', 'absoluteDifference');
24242 }
24243 assertShapesMatch($labels.shape, $predictions.shape, 'Error in absoluteDifference: ');
24244 const losses = abs(sub($labels, $predictions));
24245 return computeWeightedLoss(losses, $weights, reduction);
24246 }
24247 const absoluteDifference = op({ absoluteDifference_ });
24248
24249 /**
24250 * Computes the cosine distance loss between two tensors.
24251 *
24252 * @param labels The ground truth output tensor, same dimensions as
24253 * 'predictions'.
24254 * @param predictions The predicted outputs.
24255 * @param axis The dimension along which the cosine distance is computed.
24256 * @param weights Tensor whose rank is either 0, or the same rank as
24257 * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
24258 * must be either `1`, or the same as the corresponding `losses`
24259 * dimension).
24260 * @param reduction Type of reduction to apply to loss. Should be of type
24261 * `Reduction`
24262 *
24263 * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
24264 */
24265 function cosineDistance_(labels, predictions, axis, weights, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
24266 const $labels = convertToTensor(labels, 'labels', 'cosineDistance');
24267 const $predictions = convertToTensor(predictions, 'predictions', 'cosineDistance');
24268 let $weights = null;
24269 if (weights != null) {
24270 $weights = convertToTensor(weights, 'weights', 'cosineDistance');
24271 }
24272 assertShapesMatch($labels.shape, $predictions.shape, 'Error in cosineDistance: ');
24273 const one = scalar(1);
24274 const losses = sub(one, sum$1(mul($labels, $predictions), axis, true));
24275 return computeWeightedLoss(losses, $weights, reduction);
24276 }
24277 const cosineDistance = op({ cosineDistance_ });
24278
24279 /**
24280 * Computes the Hinge loss between two tensors.
24281 *
24282 * @param labels The ground truth output tensor, same dimensions as
24283 * 'predictions'.
24284 * @param predictions The predicted outputs.
24285 * @param weights Tensor whose rank is either 0, or the same rank as
24286 * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
24287 * must be either `1`, or the same as the corresponding `losses`
24288 * dimension).
24289 * @param reduction Type of reduction to apply to loss. Should be of type
24290 * `Reduction`
24291 *
24292 * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
24293 */
24294 function hingeLoss_(labels, predictions, weights, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
24295 let $labels = convertToTensor(labels, 'labels', 'hingeLoss');
24296 const $predictions = convertToTensor(predictions, 'predictions', 'hingeLoss');
24297 let $weights = null;
24298 if (weights != null) {
24299 $weights = convertToTensor(weights, 'weights', 'hingeLoss');
24300 }
24301 assertShapesMatch($labels.shape, $predictions.shape, 'Error in hingeLoss: ');
24302 const one = scalar(1);
24303 // Convert binary labels to (-1, 1)
24304 $labels = sub(mul(scalar(2), $labels), one);
24305 const losses = relu(sub(one, mul($labels, $predictions)));
24306 return computeWeightedLoss(losses, $weights, reduction);
24307 }
24308 const hingeLoss = op({ hingeLoss_ });
24309
24310 /**
24311 * @license
24312 * Copyright 2020 Google LLC. All Rights Reserved.
24313 * Licensed under the Apache License, Version 2.0 (the "License");
24314 * you may not use this file except in compliance with the License.
24315 * You may obtain a copy of the License at
24316 *
24317 * http://www.apache.org/licenses/LICENSE-2.0
24318 *
24319 * Unless required by applicable law or agreed to in writing, software
24320 * distributed under the License is distributed on an "AS IS" BASIS,
24321 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24322 * See the License for the specific language governing permissions and
24323 * limitations under the License.
24324 * =============================================================================
24325 */
24326 /**
24327 * Computes the huber loss between two tensors.
24328 *
24329 * @param labels The ground truth output tensor, same dimensions as
24330 * 'predictions'.
24331 * @param predictions The predicted outputs.
24332 * @param weights Tensor whose rank is either 0, or the same rank as
24333 * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
24334 * must be either `1`, or the same as the corresponding `losses`
24335 * dimension).
24336 * @param delta Point where huber loss changes from quadratic to linear.
24337 * @param reduction Type of reduction to apply to loss. Should be of type
24338 * `Reduction`.
24339 *
24340 * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
24341 */
24342 function huberLoss_(labels, predictions, weights, delta = 1.0, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
24343 const $labels = convertToTensor(labels, 'labels', 'huberLoss');
24344 const $predictions = convertToTensor(predictions, 'predictions', 'huberLoss');
24345 let $weights = null;
24346 if (weights != null) {
24347 $weights = convertToTensor(weights, 'weights', 'huberLoss');
24348 }
24349 assertShapesMatch($labels.shape, $predictions.shape, 'Error in huberLoss: ');
24350 const deltaScalar = scalar(delta);
24351 const error = abs(sub($predictions, $labels));
24352 const quadratic = minimum(error, deltaScalar);
24353 const linear = sub(error, quadratic);
24354 const losses = add$1(mul(scalar(0.5), square(quadratic)), mul(deltaScalar, linear));
24355 return computeWeightedLoss(losses, $weights, reduction);
24356 }
24357 const huberLoss = op({ huberLoss_ });
24358
24359 /**
24360 * @license
24361 * Copyright 2020 Google LLC. All Rights Reserved.
24362 * Licensed under the Apache License, Version 2.0 (the "License");
24363 * you may not use this file except in compliance with the License.
24364 * You may obtain a copy of the License at
24365 *
24366 * http://www.apache.org/licenses/LICENSE-2.0
24367 *
24368 * Unless required by applicable law or agreed to in writing, software
24369 * distributed under the License is distributed on an "AS IS" BASIS,
24370 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24371 * See the License for the specific language governing permissions and
24372 * limitations under the License.
24373 * =============================================================================
24374 */
24375 /**
24376 * Computes the log loss between two tensors.
24377 *
24378 * @param labels The ground truth output tensor, same dimensions as
24379 * 'predictions'.
24380 * @param predictions The predicted outputs.
24381 * @param weights Tensor whose rank is either 0, or the same rank as
24382 * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
24383 * must be either `1`, or the same as the corresponding `losses`
24384 * dimension).
24385 * @param epsilon A small increment to avoid taking log of zero
24386 * @param reduction Type of reduction to apply to loss. Should be of type
24387 * `Reduction`
24388 *
24389 * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
24390 */
24391 function logLoss_(labels, predictions, weights, epsilon = 1e-7, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
24392 const $labels = convertToTensor(labels, 'labels', 'logLoss');
24393 const $predictions = convertToTensor(predictions, 'predictions', 'logLoss');
24394 let $weights = null;
24395 if (weights != null) {
24396 $weights = convertToTensor(weights, 'weights', 'logLoss');
24397 }
24398 assertShapesMatch($labels.shape, $predictions.shape, 'Error in logLoss: ');
24399 const one = scalar(1);
24400 const epsilonScalar = scalar(epsilon);
24401 const l1 = neg(mul($labels, log$1(add$1($predictions, epsilonScalar))));
24402 const l2 = mul(sub(one, $labels), log$1(add$1(sub(one, $predictions), epsilonScalar)));
24403 const losses = sub(l1, l2);
24404 return computeWeightedLoss(losses, $weights, reduction);
24405 }
24406 const logLoss = op({ logLoss_ });
24407
24408 /**
24409 * @license
24410 * Copyright 2020 Google LLC. All Rights Reserved.
24411 * Licensed under the Apache License, Version 2.0 (the "License");
24412 * you may not use this file except in compliance with the License.
24413 * You may obtain a copy of the License at
24414 *
24415 * http://www.apache.org/licenses/LICENSE-2.0
24416 *
24417 * Unless required by applicable law or agreed to in writing, software
24418 * distributed under the License is distributed on an "AS IS" BASIS,
24419 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24420 * See the License for the specific language governing permissions and
24421 * limitations under the License.
24422 * =============================================================================
24423 */
24424 /**
24425 * Computes the mean squared error between two tensors.
24426 *
24427 * @param labels The ground truth output tensor, same dimensions as
24428 * 'predictions'.
24429 * @param predictions The predicted outputs.
24430 * @param weights Tensor whose rank is either 0, or the same rank as
24431 * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
24432 * must be either `1`, or the same as the corresponding `losses`
24433 * dimension).
24434 * @param reduction Type of reduction to apply to loss. Should be of type
24435 * `Reduction`
24436 *
24437 * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
24438 */
24439 function meanSquaredError_(labels, predictions, weights, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
24440 const $labels = convertToTensor(labels, 'labels', 'meanSquaredError');
24441 const $predictions = convertToTensor(predictions, 'predictions', 'meanSquaredError');
24442 let $weights = null;
24443 if (weights != null) {
24444 $weights = convertToTensor(weights, 'weights', 'meanSquaredError');
24445 }
24446 assertShapesMatch($labels.shape, $predictions.shape, 'Error in meanSquaredError: ');
24447 const losses = squaredDifference($labels, $predictions);
24448 return computeWeightedLoss(losses, $weights, reduction);
24449 }
24450 const meanSquaredError = op({ meanSquaredError_ });
24451
24452 /**
24453 * @license
24454 * Copyright 2020 Google LLC. All Rights Reserved.
24455 * Licensed under the Apache License, Version 2.0 (the "License");
24456 * you may not use this file except in compliance with the License.
24457 * You may obtain a copy of the License at
24458 *
24459 * http://www.apache.org/licenses/LICENSE-2.0
24460 *
24461 * Unless required by applicable law or agreed to in writing, software
24462 * distributed under the License is distributed on an "AS IS" BASIS,
24463 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24464 * See the License for the specific language governing permissions and
24465 * limitations under the License.
24466 * =============================================================================
24467 */
    /**
     * Computes element-wise sigmoid cross entropy given `labels` and
     * unscaled `logits`, using the numerically stable formulation derived
     * in the comment below. The sigmoid is applied implicitly; do not pass
     * already-sigmoided values.
     *
     * Both inputs must have identical shapes; the result has that shape too.
     */
    function sigmoidCrossEntropyWithLogits_(labels, logits) {
        const $labels = convertToTensor(labels, 'labels', 'sigmoidCrossEntropyWithLogits');
        const $logits = convertToTensor(logits, 'logits', 'sigmoidCrossEntropyWithLogits');
        assertShapesMatch($labels.shape, $logits.shape, 'Error in sigmoidCrossEntropyWithLogits: ');
        /**
         * Implementation Details:
         *
         * For brevity, let `x = logits`, `z = labels`. The logistic loss is
         * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
         * = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
         * = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
         * = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
         * = (1 - z) * x + log(1 + exp(-x))
         * = x - x * z + log(1 + exp(-x))
         *
         * For x < 0, to avoid overflow in exp(-x), we reformulate the above
         * x - x * z + log(1 + exp(-x))
         * = log(exp(x)) - x * z + log(1 + exp(-x))
         * = - x * z + log(1 + exp(x))
         *
         * Hence, to ensure stability and avoid overflow, the implementation uses
         * this equivalent formulation:
         * max(x, 0) - x * z + log(1 + exp(-abs(x)))
         */
        const maxOutput = relu($logits); // max(x, 0)
        const outputXTarget = mul($logits, $labels); // x * z
        const sigmoidOutput = log1p(exp(neg(abs($logits)))); // log(1 + exp(-|x|))
        return add$1(sub(maxOutput, outputXTarget), sigmoidOutput);
    }
24497 /**
24498 * Computes the sigmoid cross entropy loss between two tensors.
24499 *
24500 * If labelSmoothing is nonzero, smooth the labels towards 1/2:
24501 *
24502 * newMulticlassLabels = multiclassLabels * (1 - labelSmoothing)
24503 * + 0.5 * labelSmoothing
24504 *
24505 * @param multiClassLabels The ground truth output tensor of shape
24506 * [batch_size, num_classes], same dimensions as 'predictions'.
24507 * @param logits The predicted outputs.
24508 * @param weights Tensor whose rank is either 0, or the same rank as
24509 * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
24510 * must be either `1`, or the same as the corresponding `losses`
24511 * dimension).
24512 * @param labelSmoothing If greater than 0, then smooth the labels.
24513 * @param reduction Type of reduction to apply to loss. Should be of type
24514 * `Reduction`
24515 *
24516 * @doc { heading: 'Training', subheading: 'Losses', namespace: 'losses' }
24517 */
24518 function sigmoidCrossEntropy_(multiClassLabels, logits, weights, labelSmoothing = 0, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
24519 let $multiClassLabels = convertToTensor(multiClassLabels, 'multiClassLabels', 'sigmoidCrossEntropy');
24520 const $logits = convertToTensor(logits, 'logits', 'sigmoidCrossEntropy');
24521 let $weights = null;
24522 if (weights != null) {
24523 $weights = convertToTensor(weights, 'weights', 'sigmoidCrossEntropy');
24524 }
24525 assertShapesMatch($multiClassLabels.shape, $logits.shape, 'Error in sigmoidCrossEntropy: ');
24526 if (labelSmoothing > 0) {
24527 const labelSmoothingScalar = scalar(labelSmoothing);
24528 const one = scalar(1);
24529 const half = scalar(0.5);
24530 $multiClassLabels =
24531 add$1(mul($multiClassLabels, sub(one, labelSmoothingScalar)), mul(half, labelSmoothingScalar));
24532 }
24533 const losses = sigmoidCrossEntropyWithLogits_($multiClassLabels, $logits);
24534 return computeWeightedLoss(losses, $weights, reduction);
24535 }
24536 const sigmoidCrossEntropy = op({ sigmoidCrossEntropy_ });
24537
24538 /**
24539 * @license
24540 * Copyright 2020 Google LLC. All Rights Reserved.
24541 * Licensed under the Apache License, Version 2.0 (the "License");
24542 * you may not use this file except in compliance with the License.
24543 * You may obtain a copy of the License at
24544 *
24545 * http://www.apache.org/licenses/LICENSE-2.0
24546 *
24547 * Unless required by applicable law or agreed to in writing, software
24548 * distributed under the License is distributed on an "AS IS" BASIS,
24549 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24550 * See the License for the specific language governing permissions and
24551 * limitations under the License.
24552 * =============================================================================
24553 */
24554 /**
24555 * Computes softmax cross entropy between logits and labels.
24556 *
24557 * Measures the probability error in discrete classification tasks in which
24558 * the classes are mutually exclusive (each entry is in exactly one class).
24559 * For example, each CIFAR-10 image is labeled with one and only one label: an
24560 * image can be a dog or a truck, but not both.
24561 *
24562 * `NOTE`: While the classes are mutually exclusive, their probabilities need
24563 * not be. All that is required is that each row of labels is a valid
24564 * probability distribution. If they are not, the computation of the gradient
24565 * will be incorrect.
24566 *
24567 * `WARNING`: This op expects unscaled logits, since it performs a softmax on
24568 * logits internally for efficiency. Do not call this op with the output of
24569 * softmax, as it will produce incorrect results.
24570 *
24571 * logits and labels must have the same shape, e.g. [batch_size, num_classes]
24572 * and the same dtype.
24573 * @param labels The labels array.
24574 * @param logits The logits array.
24575 * @param dim The dimension softmax would be performed on. Defaults to `-1`
24576 * which indicates the last dimension.
24577 */
    function softmaxCrossEntropyWithLogits_(labels, logits, dim = -1) {
        // Normalize dim = -1 to the last axis; any other non-last axis is
        // unsupported (see throw below).
        if (dim === -1) {
            dim = logits.rank - 1;
        }
        if (dim !== logits.rank - 1) {
            throw Error(`Softmax cross entropy along a non-last dimension is not yet ` +
                `supported. Labels / logits was rank ${logits.rank} ` +
                `and dim was ${dim}`);
        }
        // Use a custom gradient for numerical stability.
        const customOp = customGrad((labels, logits, save) => {
            // Reference:
            // 1. http://cs231n.github.io/linear-classify/#softmax
            // 2. https://blog.feedly.com/tricks-of-the-trade-logsumexp/
            const keepDims = true;
            // log-softmax: logits - logSumExp(logits), computed stably.
            const lse = logSumExp(logits, [dim], keepDims);
            const logResult = sub(cast(logits, 'float32'), lse);
            // Save inputs needed by the gradient function below.
            save([labels, logResult]);
            // Per-example loss: -sum(labels * log-softmax) along `dim`.
            const costVector = neg(mul(logResult, labels));
            const value = sum$1(costVector, [dim]);
            const gradFunc = (dy, saved) => {
                const [labels, logResult] = saved;
                // Re-expand dy so it broadcasts against the full-rank inputs.
                const dyShape = expandShapeToKeepDim(dy.shape, [dim]);
                // Gradients w.r.t. [labels, logits]:
                //   d/dlabels = dy * (labels - softmax(logits))  (as computed here)
                //   d/dlogits = dy * (softmax(logits) - labels)
                return [
                    mul(reshape(dy, dyShape), sub(cast(labels, 'float32'), exp(logResult))),
                    mul(reshape(dy, dyShape), sub(exp(logResult), cast(labels, 'float32'))),
                ];
            };
            return { value, gradFunc };
        });
        return customOp(labels, logits);
    }
24610 /**
24611 * Computes the softmax cross entropy loss between two tensors.
24612 *
24613 * If labelSmoothing is nonzero, smooth the labels towards 1/2:
24614 *
24615 * newOnehotLabels = onehotLabels * (1 - labelSmoothing)
24616 * + labelSmoothing / numClasses
24617 *
24618 * @param onehotLabels One hot encoded labels
24619 * [batch_size, num_classes], same dimensions as 'predictions'.
24620 * @param logits The predicted outputs.
24621 * @param weights Tensor whose rank is either 0, or 1, and must be
24622 * broadcastable to `loss` of shape [batch_size]
24623 * @param labelSmoothing If greater than 0, then smooth the labels.
24624 * @param reduction Type of reduction to apply to loss. Should be of type
24625 * `Reduction`
24626 *
24627 * @doc { heading: 'Training', subheading: 'Losses', namespace: 'losses' }
24628 */
24629 function softmaxCrossEntropy_(onehotLabels, logits, weights, labelSmoothing = 0, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
24630 let $onehotLabels = convertToTensor(onehotLabels, 'onehotLabels', 'softmaxCrossEntropy');
24631 const $logits = convertToTensor(logits, 'logits', 'softmaxCrossEntropy');
24632 let $weights = null;
24633 if (weights != null) {
24634 $weights = convertToTensor(weights, 'weights', 'softmaxCrossEntropy');
24635 }
24636 assertShapesMatch($onehotLabels.shape, $logits.shape, 'Error in softmaxCrossEntropy: ');
24637 if (labelSmoothing > 0) {
24638 const labelSmoothingScalar = scalar(labelSmoothing);
24639 const one = scalar(1);
24640 const numClasses = scalar($onehotLabels.shape[1]);
24641 $onehotLabels =
24642 add$1(mul($onehotLabels, sub(one, labelSmoothingScalar)), div(labelSmoothingScalar, numClasses));
24643 }
24644 const losses = softmaxCrossEntropyWithLogits_($onehotLabels, $logits);
24645 return computeWeightedLoss(losses, $weights, reduction);
24646 }
24647 const softmaxCrossEntropy = op({ softmaxCrossEntropy_ });
24648
24649 /**
24650 * @license
24651 * Copyright 2021 Google LLC. All Rights Reserved.
24652 * Licensed under the Apache License, Version 2.0 (the "License");
24653 * you may not use this file except in compliance with the License.
24654 * You may obtain a copy of the License at
24655 *
24656 * http://www.apache.org/licenses/LICENSE-2.0
24657 *
24658 * Unless required by applicable law or agreed to in writing, software
24659 * distributed under the License is distributed on an "AS IS" BASIS,
24660 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24661 * See the License for the specific language governing permissions and
24662 * limitations under the License.
24663 * =============================================================================
24664 */
24665 /**
24666 * The input SparseTensor is represented via the map of inputs {`indices`,
24667 * `values`, `denseShape`}. The output SparseTensor has the same `denseShape`
24668 * but with indices `outputIndices` and values `outputValues`. This op inserts a
24669 * single entry for every row that doesn't have any values. The index is created
24670 * as `[row, 0, ..., 0]` and the inserted value is `defaultValue`.
24671 *
24672 * For example, suppose `spInput` has shape [5, 6] and non-empty values:
24673 * [0, 1]: a
24674 * [0, 3]: b
24675 * [2, 0]: c
24676 * [3, 1]: d
24677 *
24678 * Rows 1 and 4 are empty, so the output will be of shape [5, 6] with values:
24679 * [0, 1]: a
24680 * [0, 3]: b
24681 * [1, 0]: `defaultValue`
24682 * [2, 0]: c
24683 * [3, 1]: d
24684 * [4, 0]: `defaultValue`
24685 *
24686 * The output SparseTensor will be in row-major order and will have the same
24687 * shape as the input.
24688 *
24689 * This op also returns an indicator vector shaped [dense_shape[0]] such that
24690 * emptyRowIndicator[i] = True iff row i was an empty row.
24691 *
24692 * And a reverse index map vector shaped [indices.shape[0]] that is used during
24693 * backpropagation, reverseIndexMap[i] = outi s.t. indices[i, j] ==
24694 * outputIndices[outi, j] for all j
24695 *
24696 * ```js
24697 * const result = tf.sparse.sparseFillEmptyRows(
24698 * [[0, 0], [1, 0], [1, 3], [1, 4], [3, 2], [3, 3]],
24699 * [0, 10, 13, 14, 32, 33], [5, 6], -1);
24700 * console.log(result);
24701 * result['outputIndices'].print(); // [[0, 0], [1, 0], [1, 3], [1, 4],
24702 * // [2, 0], [3, 2], [3, 3], [4, 0]]
24703 * result['outputValues'].print(); // [0, 10, 13, 14,-1, 32, 33, -1]
24704 * result['emptyRowIndicator'].print(); // [false, false, true, false, true]
24705 * result['reverseIndexMap'].print(); // [0, 1, 2, 3, 5, 6]
24706 * ```
24707 * @param indices: 2-D. the indices of the sparse tensor.
24708 * @param values: 1-D. the values of the sparse tensor.
24709 * @param denseShape: 1-D. the shape of the sparse tensor.
24710 * @param defaultValue: 0-D. default value to insert into location [row, 0, ...,
24711 * 0] for rows missing from the input sparse tensor.
24712 * @return A map with the following properties:
24713 * - outputIndices
24714 * - outputValues: 1-D. the values of the filled sparse tensor.
24715 * - emptyRowIndicator: 1-D. whether the dense row was missing in the input
24716 * sparse tensor.
24717 * - reverseIndexMap: 1-D. a map from the input indices to the output
24718 * indices.
24719 * @doc {heading: 'Operations', subheading: 'Sparse'}
24720 */
24721 function sparseFillEmptyRows_(indices, values, denseShape, defaultValue) {
24722 const $indices = convertToTensor(indices, 'indices', 'sparseFillEmptyRows', 'int32');
24723 const $values = convertToTensor(values, 'values', 'sparseFillEmptyRows');
24724 const $denseShape = convertToTensor(denseShape, 'denseShape', 'sparseFillEmptyRows', 'int32');
24725 const $defaultValue = convertToTensor(defaultValue, 'defaultValue', 'sparseFillEmptyRows', $values.dtype);
24726 if ($indices.rank !== 2) {
24727 throw new Error(`Indices should be Tensor2D but received shape
24728 ${$indices.shape}`);
24729 }
24730 if ($values.rank !== 1) {
24731 throw new Error(`Values should be Tensor1D but received shape ${$values.shape}`);
24732 }
24733 if ($denseShape.rank !== 1) {
24734 throw new Error(`Dense shape should be Tensor1D but received shape ${$denseShape.shape}`);
24735 }
24736 if ($defaultValue.rank !== 0) {
24737 throw new Error(`Default value should be a scalar but received shape ${$defaultValue.shape}`);
24738 }
24739 const inputs = {
24740 indices: $indices,
24741 values: $values,
24742 denseShape: $denseShape,
24743 defaultValue: $defaultValue
24744 };
24745 const result = ENGINE.runKernel(SparseFillEmptyRows, inputs);
24746 return {
24747 outputIndices: result[0],
24748 outputValues: result[1],
24749 emptyRowIndicator: result[2],
24750 reverseIndexMap: result[3]
24751 };
24752 }
24753 const sparseFillEmptyRows = op({ sparseFillEmptyRows_ });
24754
24755 /**
24756 * @license
24757 * Copyright 2021 Google LLC. All Rights Reserved.
24758 * Licensed under the Apache License, Version 2.0 (the "License");
24759 * you may not use this file except in compliance with the License.
24760 * You may obtain a copy of the License at
24761 *
24762 * http://www.apache.org/licenses/LICENSE-2.0
24763 *
24764 * Unless required by applicable law or agreed to in writing, software
24765 * distributed under the License is distributed on an "AS IS" BASIS,
24766 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24767 * See the License for the specific language governing permissions and
24768 * limitations under the License.
24769 * =============================================================================
24770 */
24771 /**
24772 * This operation has the same semantics as reshape on the represented dense
24773 * tensor. The `inputIndices` are recomputed based on the requested `newShape`.
24774 * If one component of `newShape` is the special value -1, the size of that
24775 * dimension is computed so that the total dense size remains constant. At most
24776 * one component of `newShape` can be -1. The number of dense elements implied
24777 * by `newShape` must be the same as the number of dense elements originally
24778 * implied by `inputShape`. Reshaping does not affect the order of values in the
24779 * SparseTensor. If the input tensor has rank R_in and N non-empty values, and
24780 * `newShape` has length R_out, then `inputIndices` has shape [N, R_in],
24781 * `inputShape` has length R_in, `outputIndices` has shape [N, R_out], and
24782 * `outputShape` has length R_out.
24783 *
24784 * ```js
24785 * const result = tf.sparse.sparseReshape(
24786 * [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]],
24787 * [2, 3, 6], [9, -1]);
24788 * console.log(result);
24789 * result['outputIndices'].print(); //[[0, 0], [0, 1], [1, 2], [4, 2], [8, 1]]
24790 * result['outputShape'].print(); // [9, 4]
24791 * ```
24792 * @param inputIndices: 2-D. N x R_in matrix with the indices of non-empty
24793 * values in a SparseTensor.
24794 * @param inputShape: 1-D. R_in Tensor1D with the input SparseTensor's dense
24795 * shape.
24796 * @param newShape: 1-D. R_out Tensor1D with the requested new dense shape.
24797 * @return A map with the following properties:
24798 * - outputIndices: 2-D. N x R_out matrix with the updated indices of
24799 * non-empty values in the output SparseTensor.
24800 * - outputShape: 1-D. R_out vector with the full dense shape of the output
24801 * SparseTensor. This is the same as newShape but with any -1 dimensions
24802 * filled in.
24803 * @doc {heading: 'Operations', subheading: 'Sparse'}
24804 */
24805 function sparseReshape_(inputIndices, inputShape, newShape) {
24806 const $inputIndices = convertToTensor(inputIndices, 'inputIndices', 'sparseReshape', 'int32');
24807 const $inputShape = convertToTensor(inputShape, 'inputShape', 'sparseReshape', 'int32');
24808 const $newShape = convertToTensor(newShape, 'newShape', 'sparseReshape', 'int32');
24809 if ($inputIndices.rank !== 2) {
24810 throw new Error(`Input indices should be Tensor2D but received shape
24811 ${$inputIndices.shape}`);
24812 }
24813 if ($inputShape.rank !== 1) {
24814 throw new Error(`Input shape should be Tensor1D but received shape ${$inputShape.shape}`);
24815 }
24816 if ($newShape.rank !== 1) {
24817 throw new Error(`New shape should be Tensor1D but received shape ${$newShape.shape}`);
24818 }
24819 const inputs = {
24820 inputIndices: $inputIndices,
24821 inputShape: $inputShape,
24822 newShape: $newShape
24823 };
24824 const result = ENGINE.runKernel(SparseReshape, inputs);
24825 return { outputIndices: result[0], outputShape: result[1] };
24826 }
24827 const sparseReshape = op({ sparseReshape_ });
24828
24829 /**
24830 * @license
24831 * Copyright 2021 Google LLC. All Rights Reserved.
24832 * Licensed under the Apache License, Version 2.0 (the "License");
24833 * you may not use this file except in compliance with the License.
24834 * You may obtain a copy of the License at
24835 *
24836 * http://www.apache.org/licenses/LICENSE-2.0
24837 *
24838 * Unless required by applicable law or agreed to in writing, software
24839 * distributed under the License is distributed on an "AS IS" BASIS,
24840 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24841 * See the License for the specific language governing permissions and
24842 * limitations under the License.
24843 * =============================================================================
24844 */
24845 /**
24846 * Computes the mean along sparse segments of a tensor.
24847 *
24848 * ```js
24849 * const c = tf.tensor2d([[1,2,3,4], [-1,-2,-3,-4], [6,7,8,9]]);
24850 * // Select two rows, one segment.
24851 * const result1 = tf.sparse.sparseSegmentMean(c,
24852 * tf.tensor1d([0, 1], 'int32'),
24853 * tf.tensor1d([0, 0], 'int32'));
24854 * result1.print(); // [[0, 0, 0, 0]]
24855 *
24856 * // Select two rows, two segments.
24857 * const result2 = tf.sparse.sparseSegmentMean(c,
24858 * tf.tensor1d([0, 1], 'int32'),
24859 * tf.tensor1d([0, 1], 'int32'));
24860 * result2.print(); // [[1, 2, 3, 4], [-1, -2, -3, -4]]
24861 *
24862 * // Select all rows, two segments.
24863 * const result3 = tf.sparse.sparseSegmentMean(c,
24864 * tf.tensor1d([0, 1, 2], 'int32'),
24865 * tf.tensor1d([0, 1, 1], 'int32'));
24866 * result3.print(); // [[1.0, 2.0, 3.0, 4.0], [2.5, 2.5, 2.5, 2.5]]
24867 * ```
24868 * @param data: A Tensor of at least one dimension with data that will be
24869 * assembled in the output.
24870 * @param indices: A 1-D Tensor with indices into data. Has same rank as
24871 * segmentIds.
24872 * @param segmentIds: A 1-D Tensor with indices into the output Tensor. Values
24873 * should be sorted and can be repeated.
 * @return Has same shape as data, except for dimension 0 which has size equal
 * to the number of segments.
24876 *
24877 * @doc {heading: 'Operations', subheading: 'Sparse'}
24878 */
24879 function sparseSegmentMean_(data, indices, segmentIds) {
24880 const $data = convertToTensor(data, 'data', 'sparseSegmentMean');
24881 const $indices = convertToTensor(indices, 'indices', 'sparseSegmentMean', 'int32');
24882 const $segmentIds = convertToTensor(segmentIds, 'segmentIds', 'sparseSegmentMean', 'int32');
24883 if ($data.rank < 1) {
24884 throw new Error(`Data should be at least 1 dimensional but received scalar`);
24885 }
24886 if ($indices.rank !== 1) {
24887 throw new Error(`Indices should be Tensor1D but received shape
24888 ${$indices.shape}`);
24889 }
24890 if ($segmentIds.rank !== 1) {
24891 throw new Error(`Segment ids should be Tensor1D but received shape
24892 ${$segmentIds.shape}`);
24893 }
24894 const inputs = {
24895 data: $data,
24896 indices: $indices,
24897 segmentIds: $segmentIds
24898 };
24899 return ENGINE.runKernel(SparseSegmentMean, inputs);
24900 }
24901 const sparseSegmentMean = op({ sparseSegmentMean_ });
24902
24903 /**
24904 * @license
24905 * Copyright 2021 Google LLC. All Rights Reserved.
24906 * Licensed under the Apache License, Version 2.0 (the "License");
24907 * you may not use this file except in compliance with the License.
24908 * You may obtain a copy of the License at
24909 *
24910 * http://www.apache.org/licenses/LICENSE-2.0
24911 *
24912 * Unless required by applicable law or agreed to in writing, software
24913 * distributed under the License is distributed on an "AS IS" BASIS,
24914 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24915 * See the License for the specific language governing permissions and
24916 * limitations under the License.
24917 * =============================================================================
24918 */
24919 /**
24920 * Computes the sum along sparse segments of a tensor.
24921 *
24922 * ```js
24923 * const c = tf.tensor2d([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]);
24924 * // Select two rows, one segment.
24925 * const result1 = tf.sparse.sparseSegmentSum(c,
24926 * tf.tensor1d([0, 1], 'int32'),
24927 * tf.tensor1d([0, 0], 'int32'));
24928 * result1.print(); // [[0, 0, 0, 0]]
24929 *
 * // Select two rows, two segments.
24931 * const result2 = tf.sparse.sparseSegmentSum(c,
24932 * tf.tensor1d([0, 1], 'int32'),
24933 * tf.tensor1d([0, 1], 'int32'));
24934 * result2.print(); // [[1, 2, 3, 4], [-1, -2, -3, -4]]
24935 *
24936 * // Select all rows, two segments.
24937 * const result3 = tf.sparse.sparseSegmentSum(c,
24938 * tf.tensor1d([0, 1, 2], 'int32'),
24939 * tf.tensor1d([0, 0, 1], 'int32'));
24940 * result3.print(); // [[0, 0, 0, 0], [5, 6, 7, 8]]
24941 * ```
24942 * @param data: A Tensor of at least one dimension with data that will be
24943 * assembled in the output.
24944 * @param indices: A 1-D Tensor with indices into data. Has same rank as
24945 * segmentIds.
24946 * @param segmentIds: A 1-D Tensor with indices into the output Tensor. Values
24947 * should be sorted and can be repeated.
 * @return Has same shape as data, except for dimension 0 which has size equal
 * to the number of segments.
24950 *
24951 * @doc {heading: 'Operations', subheading: 'Sparse'}
24952 */
24953 function sparseSegmentSum_(data, indices, segmentIds) {
24954 const $data = convertToTensor(data, 'data', 'sparseSegmentSum');
24955 const $indices = convertToTensor(indices, 'indices', 'sparseSegmentSum', 'int32');
24956 const $segmentIds = convertToTensor(segmentIds, 'segmentIds', 'sparseSegmentSum', 'int32');
24957 if ($data.rank < 1) {
24958 throw new Error(`Data should be at least 1 dimensional but received scalar`);
24959 }
24960 if ($indices.rank !== 1) {
24961 throw new Error(`Indices should be Tensor1D but received shape
24962 ${$indices.shape}`);
24963 }
24964 if ($segmentIds.rank !== 1) {
24965 throw new Error(`Segment ids should be Tensor1D but received shape
24966 ${$segmentIds.shape}`);
24967 }
24968 const inputs = {
24969 data: $data,
24970 indices: $indices,
24971 segmentIds: $segmentIds
24972 };
24973 return ENGINE.runKernel(SparseSegmentSum, inputs);
24974 }
24975 const sparseSegmentSum = op({ sparseSegmentSum_ });
24976
24977 /**
24978 * @license
24979 * Copyright 2021 Google LLC. All Rights Reserved.
24980 * Licensed under the Apache License, Version 2.0 (the "License");
24981 * you may not use this file except in compliance with the License.
24982 * You may obtain a copy of the License at
24983 *
24984 * http://www.apache.org/licenses/LICENSE-2.0
24985 *
24986 * Unless required by applicable law or agreed to in writing, software
24987 * distributed under the License is distributed on an "AS IS" BASIS,
24988 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24989 * See the License for the specific language governing permissions and
24990 * limitations under the License.
24991 * =============================================================================
24992 */
24993 /**
24994 * Creates ngrams from ragged string data.
24995 *
24996 * This op accepts a ragged tensor with 1 ragged dimension containing only
24997 * strings and outputs a ragged tensor with 1 ragged dimension containing ngrams
24998 * of that string, joined along the innermost axis.
24999 *
25000 * ```js
25001 * const result = tf.string.stringNGrams(
25002 * ['a', 'b', 'c', 'd'], tf.tensor1d([0, 2, 4], 'int32'),
25003 * '|', [1, 2], 'LP', 'RP', -1, false);
25004 * result['nGrams'].print(); // ['a', 'b', 'LP|a', 'a|b', 'b|RP',
25005 * // 'c', 'd', 'LP|c', 'c|d', 'd|RP']
25006 * result['nGramsSplits'].print(); // [0, 5, 10]
25007 * ```
25008 * @param data: The values tensor of the ragged string tensor to make ngrams out
25009 * of. Must be a 1D string tensor.
25010 * @param dataSplits: The splits tensor of the ragged string tensor to make
25011 * ngrams out of.
25012 * @param separator: The string to append between elements of the token. Use ""
25013 * for no separator.
25014 * @param nGramWidths: The sizes of the ngrams to create.
25015 * @param leftPad: The string to use to pad the left side of the ngram sequence.
25016 * Only used if pad_width !== 0.
25017 * @param rightPad: The string to use to pad the right side of the ngram
25018 * sequence. Only used if pad_width !== 0.
 * @param padWidth: The number of padding elements to add to each side of each
 *   sequence. Note that padding will never be greater than `nGramWidths`-1
 *   regardless of this value. If `padWidth`=-1, then add max(`nGramWidths`)-1
 *   elements.
25023 * @param preserveShortSequences: If true, then ensure that at least one ngram
25024 * is generated for each input sequence. In particular, if an input sequence
25025 * is shorter than min(ngramWidth) + 2*padWidth, then generate a single
25026 * ngram containing the entire sequence. If false, then no ngrams are
25027 * generated for these short input sequences.
25028 * @return A map with the following properties:
25029 * - nGrams: The values tensor of the output ngrams ragged tensor.
25030 * - nGramsSplits: The splits tensor of the output ngrams ragged tensor.
25031 *
25032 * @doc {heading: 'Operations', subheading: 'String'}
25033 */
25034 function stringNGrams_(data, dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) {
25035 const $data = convertToTensor(data, 'data', 'stringNGrams', 'string');
25036 if ($data.dtype !== 'string') {
25037 throw new Error('Data must be of datatype string');
25038 }
25039 if ($data.shape.length !== 1) {
25040 throw new Error(`Data must be a vector, saw: ${$data.shape}`);
25041 }
25042 const $dataSplits = convertToTensor(dataSplits, 'dataSplits', 'stringNGrams');
25043 if ($dataSplits.dtype !== 'int32') {
25044 throw new Error('Data splits must be of datatype int32');
25045 }
25046 const attrs = {
25047 separator,
25048 nGramWidths,
25049 leftPad,
25050 rightPad,
25051 padWidth,
25052 preserveShortSequences
25053 };
25054 const inputs = { data: $data, dataSplits: $dataSplits };
25055 const result = ENGINE.runKernel(StringNGrams, inputs, attrs);
25056 return { nGrams: result[0], nGramsSplits: result[1] };
25057 }
25058 const stringNGrams = op({ stringNGrams_ });
25059
25060 /**
25061 * @license
25062 * Copyright 2021 Google LLC. All Rights Reserved.
25063 * Licensed under the Apache License, Version 2.0 (the "License");
25064 * you may not use this file except in compliance with the License.
25065 * You may obtain a copy of the License at
25066 *
25067 * http://www.apache.org/licenses/LICENSE-2.0
25068 *
25069 * Unless required by applicable law or agreed to in writing, software
25070 * distributed under the License is distributed on an "AS IS" BASIS,
25071 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25072 * See the License for the specific language governing permissions and
25073 * limitations under the License.
25074 * =============================================================================
25075 */
25076 /**
25077 * Split elements of `input` based on `delimiter` into a SparseTensor .
25078 *
25079 * Let N be the size of source (typically N will be the batch size). Split each
25080 * element of `input` based on `delimiter` and return a SparseTensor containing
 * the split tokens. Empty tokens are ignored if `skipEmpty` is set to True.
25082 *
25083 * `delimiter` can be empty, or a string of split characters. If `delimiter` is
25084 * an empty string, each element of `input` is split into individual
25085 * character strings. Otherwise every character of `delimiter` is a potential
25086 * split point.
25087 *
25088 * ```js
25089 * const result = tf.string.stringSplit(['hello world', 'a b c'], ' ');
25090 * result['indices'].print(); // [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]]
25091 * result['values'].print(); // ['hello', 'world', 'a', 'b', 'c']
25092 * result['shape'].print(); // [2, 3]
25093 * ```
25094 * @param input: 1-D. Strings to split.
25095 * @param delimiter: 0-D. Delimiter characters, or empty string.
25096 * @param skipEmpty: Optional. If true, skip the empty strings from the result.
25097 * Defaults to true.
25098 * @return A map with the following properties:
25099 * - indices: A dense matrix of int32 representing the indices of the sparse
25100 * tensor.
 * - values: A vector of strings corresponding to the split values.
25102 * - shape: a length-2 vector of int32 representing the shape of the sparse
25103 * tensor, where the first value is N and the second value is the maximum number
25104 * of tokens in a single input entry.
25105 *
25106 * @doc {heading: 'Operations', subheading: 'String'}
25107 */
25108 function stringSplit_(input, delimiter, skipEmpty = true) {
25109 const $input = convertToTensor(input, 'input', 'stringSplit', 'string');
25110 const $delimiter = convertToTensor(delimiter, 'delimiter', 'stringSplit', 'string');
25111 if ($input.rank !== 1) {
25112 throw new Error(`Input should be Tensor1D but received shape ${$input.shape}`);
25113 }
25114 if ($delimiter.rank !== 0) {
25115 throw new Error(`Delimiter should be a scalar but received shape ${$delimiter.shape}`);
25116 }
25117 const attrs = { skipEmpty };
25118 const inputs = { input: $input, delimiter: $delimiter };
25119 const result = ENGINE.runKernel(StringSplit, inputs, attrs);
25120 return { indices: result[0], values: result[1], shape: result[2] };
25121 }
25122 const stringSplit = op({ stringSplit_ });
25123
25124 /**
25125 * @license
25126 * Copyright 2021 Google LLC. All Rights Reserved.
25127 * Licensed under the Apache License, Version 2.0 (the "License");
25128 * you may not use this file except in compliance with the License.
25129 * You may obtain a copy of the License at
25130 *
25131 * http://www.apache.org/licenses/LICENSE-2.0
25132 *
25133 * Unless required by applicable law or agreed to in writing, software
25134 * distributed under the License is distributed on an "AS IS" BASIS,
25135 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25136 * See the License for the specific language governing permissions and
25137 * limitations under the License.
25138 * =============================================================================
25139 */
25140 /**
25141 * Converts each string in the input Tensor to its hash mod by a number of
25142 * buckets.
25143 *
25144 * The hash function is deterministic on the content of the string within the
25145 * process and will never change. However, it is not suitable for cryptography.
25146 * This function may be used when CPU time is scarce and inputs are trusted or
25147 * unimportant. There is a risk of adversaries constructing inputs that all hash
25148 * to the same bucket.
25149 *
25150 * ```js
25151 * const result = tf.string.stringToHashBucketFast(
25152 * ['Hello', 'TensorFlow', '2.x'], 3);
25153 * result.print(); // [0, 2, 2]
25154 * ```
25155 * @param input: The strings to assign a hash bucket.
25156 * @param numBuckets: The number of buckets.
25157 * @return A Tensor of the same shape as the input tensor.
25158 *
25159 * @doc {heading: 'Operations', subheading: 'String'}
25160 */
25161 function stringToHashBucketFast_(input, numBuckets) {
25162 const $input = convertToTensor(input, 'input', 'stringToHashBucketFast', 'string');
25163 const attrs = { numBuckets };
25164 if (numBuckets <= 0) {
25165 throw new Error(`Number of buckets must be at least 1`);
25166 }
25167 const inputs = { input: $input };
25168 return ENGINE.runKernel(StringToHashBucketFast, inputs, attrs);
25169 }
25170 const stringToHashBucketFast = op({ stringToHashBucketFast_ });
25171
25172 /**
25173 * @license
25174 * Copyright 2020 Google LLC. All Rights Reserved.
25175 * Licensed under the Apache License, Version 2.0 (the "License");
25176 * you may not use this file except in compliance with the License.
25177 * You may obtain a copy of the License at
25178 *
25179 * http://www.apache.org/licenses/LICENSE-2.0
25180 *
25181 * Unless required by applicable law or agreed to in writing, software
25182 * distributed under the License is distributed on an "AS IS" BASIS,
25183 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25184 * See the License for the specific language governing permissions and
25185 * limitations under the License.
25186 * =============================================================================
25187 */
    // Namespace objects grouping related ops; exposed on the public `tf` root.
    // Fourier-transform ops.
    const spectral = {
        fft,
        ifft,
        rfft,
        irfft
    };
    // Signal-processing ops (windows, framing, STFT).
    const signal = {
        hammingWindow,
        hannWindow,
        frame,
        stft,
    };
    // Image-processing ops.
    const image = {
        flipLeftRight,
        grayscaleToRGB,
        resizeNearestNeighbor,
        resizeBilinear,
        rotateWithOffset,
        cropAndResize,
        nonMaxSuppression,
        nonMaxSuppressionAsync,
        nonMaxSuppressionWithScore,
        nonMaxSuppressionWithScoreAsync,
        nonMaxSuppressionPadded,
        nonMaxSuppressionPaddedAsync,
        threshold,
        transform
    };
    // Linear-algebra ops.
    const linalg = {
        bandPart,
        gramSchmidt,
        qr
    };
    // Loss functions.
    const losses = {
        absoluteDifference,
        computeWeightedLoss,
        cosineDistance,
        hingeLoss,
        huberLoss,
        logLoss,
        meanSquaredError,
        sigmoidCrossEntropy,
        softmaxCrossEntropy
    };
    // SparseTensor ops.
    const sparse = {
        sparseFillEmptyRows,
        sparseReshape,
        sparseSegmentMean,
        sparseSegmentSum
    };
    // String ops. Named `string` (shadows nothing at this scope).
    // tslint:disable-next-line:variable-name
    const string = {
        stringNGrams,
        stringSplit,
        stringToHashBucketFast
    };
25244
25245 /**
25246 * @license
25247 * Copyright 2018 Google LLC. All Rights Reserved.
25248 * Licensed under the Apache License, Version 2.0 (the "License");
25249 * you may not use this file except in compliance with the License.
25250 * You may obtain a copy of the License at
25251 *
25252 * http://www.apache.org/licenses/LICENSE-2.0
25253 *
25254 * Unless required by applicable law or agreed to in writing, software
25255 * distributed under the License is distributed on an "AS IS" BASIS,
25256 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25257 * See the License for the specific language governing permissions and
25258 * limitations under the License.
25259 * =============================================================================
25260 */
25261 /** @doc {heading: 'Training', subheading: 'Classes', namespace: 'train'} */
25262 class Optimizer extends Serializable {
25263 /**
25264 * Executes `f()` and minimizes the scalar output of `f()` by computing
25265 * gradients of y with respect to the list of trainable variables provided by
25266 * `varList`. If no list is provided, it defaults to all trainable variables.
25267 *
25268 * @param f The function to execute and whose output to minimize.
25269 * @param returnCost Whether to return the scalar cost value produced by
25270 * executing `f()`.
25271 * @param varList An optional list of variables to update. If specified, only
25272 * the trainable variables in varList will be updated by minimize. Defaults to
25273 * all trainable variables.
25274 *
25275 * @doc {heading: 'Training', subheading: 'Optimizers'}
25276 */
25277 minimize(f, returnCost = false, varList) {
25278 const { value, grads } = this.computeGradients(f, varList);
25279 if (varList != null) {
25280 const gradArray = varList.map(v => ({ name: v.name, tensor: grads[v.name] }));
25281 this.applyGradients(gradArray);
25282 }
25283 else {
25284 this.applyGradients(grads);
25285 }
25286 // Dispose gradients.
25287 dispose(grads);
25288 if (returnCost) {
25289 return value;
25290 }
25291 else {
25292 value.dispose();
25293 return null;
25294 }
25295 }
25296 /**
25297 * The number of iterations that this optimizer instance has been invoked for.
25298 */
25299 get iterations() {
25300 if (this.iterations_ == null) {
25301 this.iterations_ = 0;
25302 }
25303 return this.iterations_;
25304 }
25305 incrementIterations() {
25306 this.iterations_ = this.iterations + 1;
25307 }
25308 /**
25309 * Executes f() and computes the gradient of the scalar output of f() with
25310 * respect to the list of trainable variables provided by `varList`. If no
25311 * list is provided, it defaults to all trainable variables.
25312 *
25313 * @param f The function to execute and whose output to use for computing
25314 * gradients with respect to variables.
25315 * @param varList An optional list of variables to compute gradients with
25316 * respect to. If specified, only the trainable variables in varList will have
25317 * gradients computed with respect to. Defaults to all trainable variables.
25318 *
25319 * @doc {heading: 'Training', subheading: 'Optimizers'}
25320 */
25321 computeGradients(f, varList) {
25322 return variableGrads(f, varList);
25323 }
25324 /**
25325 * Dispose the variables (if any) owned by this optimizer instance.
25326 */
25327 dispose() {
25328 if (this.iterations_ != null) {
25329 dispose(this.iterations_);
25330 }
25331 }
25332 async saveIterations() {
25333 if (this.iterations_ == null) {
25334 this.iterations_ = 0;
25335 }
25336 return {
25337 name: 'iter',
25338 // TODO(cais): Use 'int64' type when available.
25339 tensor: scalar(this.iterations_, 'int32')
25340 };
25341 }
25342 async getWeights() {
25343 throw new Error('getWeights() is not implemented for this optimizer yet.');
25344 }
25345 async setWeights(weightValues) {
25346 throw new Error(`setWeights() is not implemented for this optimizer class ` +
25347 `${this.getClassName()}`);
25348 }
25349 /**
25350 * Extract the first element of the weight values and set it
25351 * as the iterations counter variable of this instance of optimizer.
25352 *
25353 * @param weightValues
25354 * @returns Weight values with the first element consumed and excluded.
25355 */
25356 async extractIterations(weightValues) {
25357 this.iterations_ = (await weightValues[0].tensor.data())[0];
25358 return weightValues.slice(1);
25359 }
25360 }
25361 Object.defineProperty(Optimizer, Symbol.hasInstance, {
25362 value: (instance) => {
25363 return instance.minimize != null && instance.computeGradients != null &&
25364 instance.applyGradients != null;
25365 }
25366 });
25367
25368 /**
25369 * @license
25370 * Copyright 2018 Google LLC. All Rights Reserved.
25371 * Licensed under the Apache License, Version 2.0 (the "License");
25372 * you may not use this file except in compliance with the License.
25373 * You may obtain a copy of the License at
25374 *
25375 * http://www.apache.org/licenses/LICENSE-2.0
25376 *
25377 * Unless required by applicable law or agreed to in writing, software
25378 * distributed under the License is distributed on an "AS IS" BASIS,
25379 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25380 * See the License for the specific language governing permissions and
25381 * limitations under the License.
25382 * =============================================================================
25383 */
25384 /** @doclink Optimizer */
25385 class AdadeltaOptimizer extends Optimizer {
25386 constructor(learningRate, rho, epsilon = null) {
25387 super();
25388 this.learningRate = learningRate;
25389 this.rho = rho;
25390 this.epsilon = epsilon;
25391 this.accumulatedGrads = [];
25392 this.accumulatedUpdates = [];
25393 if (epsilon == null) {
25394 this.epsilon = ENGINE.backend.epsilon();
25395 }
25396 }
25397 applyGradients(variableGradients) {
25398 const variableNames = Array.isArray(variableGradients) ?
25399 variableGradients.map(item => item.name) :
25400 Object.keys(variableGradients);
25401 variableNames.forEach((name, i) => {
25402 const value = ENGINE.registeredVariables[name];
25403 const trainable = false;
25404 if (this.accumulatedGrads[i] == null) {
25405 this.accumulatedGrads[i] = {
25406 originalName: `${name}/accum_grad`,
25407 variable: tidy(() => zerosLike(value).variable(trainable))
25408 };
25409 }
25410 if (this.accumulatedUpdates[i] == null) {
25411 this.accumulatedUpdates[i] = {
25412 originalName: `${name}/accum_var`,
25413 variable: tidy(() => zerosLike(value).variable(trainable))
25414 };
25415 }
25416 const gradient = Array.isArray(variableGradients) ?
25417 variableGradients[i].tensor :
25418 variableGradients[name];
25419 if (gradient == null) {
25420 return;
25421 }
25422 const accumulatedGrad = this.accumulatedGrads[i].variable;
25423 const accumulatedUpdate = this.accumulatedUpdates[i].variable;
25424 tidy(() => {
25425 const newAccumulatedGrad = add$1(mul(accumulatedGrad, this.rho), mul(square(gradient), 1 - this.rho));
25426 const updates = mul(div(sqrt(add$1(accumulatedUpdate, this.epsilon)), sqrt(add$1(accumulatedGrad, this.epsilon))), gradient);
25427 const newAccumulatedUpdate = add$1(mul(accumulatedUpdate, this.rho), mul(square(updates), 1 - this.rho));
25428 accumulatedGrad.assign(newAccumulatedGrad);
25429 accumulatedUpdate.assign(newAccumulatedUpdate);
25430 const newValue = add$1(mul(updates, -this.learningRate), value);
25431 value.assign(newValue);
25432 });
25433 });
25434 this.incrementIterations();
25435 }
25436 dispose() {
25437 if (this.accumulatedUpdates != null) {
25438 dispose(this.accumulatedGrads.map(v => v.variable));
25439 dispose(this.accumulatedUpdates.map(v => v.variable));
25440 }
25441 }
25442 async getWeights() {
25443 // Order matters for Python compatibility.
25444 const variables = [...this.accumulatedGrads, ...this.accumulatedUpdates];
25445 return [await this.saveIterations()].concat(variables.map(v => ({ name: v.originalName, tensor: v.variable })));
25446 }
25447 async setWeights(weightValues) {
25448 weightValues = await this.extractIterations(weightValues);
25449 const variableCount = weightValues.length / 2;
25450 const trainable = false;
25451 this.accumulatedGrads =
25452 weightValues.slice(0, variableCount).map(v => ({
25453 originalName: v.name,
25454 variable: v.tensor.variable(trainable)
25455 }));
25456 this.accumulatedUpdates =
25457 weightValues.slice(variableCount, variableCount * 2)
25458 .map(v => ({
25459 originalName: v.name,
25460 variable: v.tensor.variable(trainable)
25461 }));
25462 }
25463 getConfig() {
25464 return {
25465 'learningRate': this.learningRate,
25466 'rho': this.rho,
25467 'epsilon': this.epsilon
25468 };
25469 }
25470 /** @nocollapse */
25471 static fromConfig(cls, config) {
25472 return new cls(config['learningRate'], config['rho'], config['epsilon']);
25473 }
25474 }
    // Registered under the canonical name 'Adadelta' so serialized models can
    // round-trip with the Python/Keras implementation.
    /** @nocollapse */
    AdadeltaOptimizer.className = 'Adadelta'; // Name matters for Python compatibility.
    registerClass(AdadeltaOptimizer);
25478
25479 /**
25480 * @license
25481 * Copyright 2018 Google LLC. All Rights Reserved.
25482 * Licensed under the Apache License, Version 2.0 (the "License");
25483 * you may not use this file except in compliance with the License.
25484 * You may obtain a copy of the License at
25485 *
25486 * http://www.apache.org/licenses/LICENSE-2.0
25487 *
25488 * Unless required by applicable law or agreed to in writing, software
25489 * distributed under the License is distributed on an "AS IS" BASIS,
25490 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25491 * See the License for the specific language governing permissions and
25492 * limitations under the License.
25493 * =============================================================================
25494 */
    /** @doclink Optimizer */
    class AdagradOptimizer extends Optimizer {
        /**
         * Optimizer implementing the Adagrad algorithm: each variable's
         * effective learning rate shrinks as its squared gradients accumulate.
         *
         * @param learningRate Base learning rate for the update rule.
         * @param initialAccumulatorValue Starting value for each
         *     squared-gradient accumulator; should be positive.
         */
        constructor(learningRate, initialAccumulatorValue = 0.1) {
            super();
            this.learningRate = learningRate;
            this.initialAccumulatorValue = initialAccumulatorValue;
            // Per-variable accumulators, created lazily in applyGradients()
            // once each variable's shape is known.
            this.accumulatedGrads = [];
        }
        /**
         * Applies one Adagrad update step.
         *
         * @param variableGradients Either an array of `{name, tensor}` pairs
         *     or a map from variable name to gradient tensor. Entries whose
         *     gradient is null/undefined are skipped.
         */
        applyGradients(variableGradients) {
            const variableNames = Array.isArray(variableGradients) ?
                variableGradients.map(item => item.name) :
                Object.keys(variableGradients);
            variableNames.forEach((name, i) => {
                const value = ENGINE.registeredVariables[name];
                if (this.accumulatedGrads[i] == null) {
                    const trainable = false;
                    // Lazily create the accumulator, filled with the configured
                    // initial value and matching the variable's shape.
                    this.accumulatedGrads[i] = {
                        originalName: `${name}/accumulator`,
                        variable: tidy(() => fill(value.shape, this.initialAccumulatorValue)
                            .variable(trainable))
                    };
                }
                const gradient = Array.isArray(variableGradients) ?
                    variableGradients[i].tensor :
                    variableGradients[name];
                if (gradient == null) {
                    return;
                }
                const accumulatedGrad = this.accumulatedGrads[i].variable;
                tidy(() => {
                    // accumulator += grad^2
                    const newAccumulatedGrad = add$1(accumulatedGrad, square(gradient));
                    accumulatedGrad.assign(newAccumulatedGrad);
                    // value -= lr * grad / sqrt(accumulator + epsilon); epsilon
                    // is taken from the backend to match its float precision.
                    const newValue = add$1(mul(div(gradient, sqrt(add$1(newAccumulatedGrad, ENGINE.backend.epsilon()))), -this.learningRate), value);
                    value.assign(newValue);
                });
            });
            this.incrementIterations();
        }
        /** Releases the per-variable accumulator tensors. */
        dispose() {
            if (this.accumulatedGrads != null) {
                dispose(this.accumulatedGrads.map(v => v.variable));
            }
        }
        /** Serializes state: iterations tensor first, then the accumulators. */
        async getWeights() {
            // Order matters for Python compatibility.
            return [await this.saveIterations()].concat(this.accumulatedGrads.map(v => ({ name: v.originalName, tensor: v.variable })));
        }
        /** Restores state previously produced by `getWeights()`. */
        async setWeights(weightValues) {
            weightValues = await this.extractIterations(weightValues);
            const trainable = false;
            this.accumulatedGrads = weightValues.map(v => ({ originalName: v.name, variable: v.tensor.variable(trainable) }));
        }
        /** Returns the serializable configuration of this optimizer. */
        getConfig() {
            return {
                'learningRate': this.learningRate,
                'initialAccumulatorValue': this.initialAccumulatorValue,
            };
        }
        /** @nocollapse */
        static fromConfig(cls, config) {
            return new cls(config['learningRate'], config['initialAccumulatorValue']);
        }
    }
    /** @nocollapse */
    AdagradOptimizer.className = 'Adagrad'; // Note: Name matters for Python compatibility.
    registerClass(AdagradOptimizer);
25561
25562 /**
25563 * @license
25564 * Copyright 2018 Google LLC. All Rights Reserved.
25565 * Licensed under the Apache License, Version 2.0 (the "License");
25566 * you may not use this file except in compliance with the License.
25567 * You may obtain a copy of the License at
25568 *
25569 * http://www.apache.org/licenses/LICENSE-2.0
25570 *
25571 * Unless required by applicable law or agreed to in writing, software
25572 * distributed under the License is distributed on an "AS IS" BASIS,
25573 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25574 * See the License for the specific language governing permissions and
25575 * limitations under the License.
25576 * =============================================================================
25577 */
    class AdamOptimizer extends Optimizer {
        /**
         * Optimizer implementing the Adam algorithm
         * (https://arxiv.org/abs/1412.6980).
         *
         * @param learningRate Step size.
         * @param beta1 Exponential decay rate for the 1st-moment estimates.
         * @param beta2 Exponential decay rate for the 2nd-moment estimates.
         * @param epsilon Small constant for numerical stability; defaults to
         *     the backend's epsilon when null.
         */
        constructor(learningRate, beta1, beta2, epsilon = null) {
            super();
            this.learningRate = learningRate;
            this.beta1 = beta1;
            this.beta2 = beta2;
            this.epsilon = epsilon;
            // Per-variable moment estimates, created lazily in applyGradients().
            this.accumulatedFirstMoment = [];
            this.accumulatedSecondMoment = [];
            tidy(() => {
                // accB* will be updated by batch: after t steps they hold
                // beta1^t and beta2^t, used for bias correction.
                this.accBeta1 = scalar(beta1).variable();
                this.accBeta2 = scalar(beta2).variable();
            });
            if (epsilon == null) {
                this.epsilon = ENGINE.backend.epsilon();
            }
        }
        /**
         * Applies one Adam update step.
         *
         * @param variableGradients Either an array of `{name, tensor}` pairs
         *     or a map from variable name to gradient tensor. Entries whose
         *     gradient is null/undefined are skipped.
         */
        applyGradients(variableGradients) {
            const varNames = Array.isArray(variableGradients) ?
                variableGradients.map(v => v.name) :
                Object.keys(variableGradients);
            tidy(() => {
                // Bias-correction denominators (1 - beta^t) for this step.
                const oneMinusAccBeta1 = sub(1, this.accBeta1);
                const oneMinusAccBeta2 = sub(1, this.accBeta2);
                varNames.forEach((name, i) => {
                    const value = ENGINE.registeredVariables[name];
                    const trainable = false;
                    if (this.accumulatedFirstMoment[i] == null) {
                        this.accumulatedFirstMoment[i] = {
                            originalName: `${name}/m`,
                            variable: tidy(() => zerosLike(value).variable(trainable))
                        };
                    }
                    if (this.accumulatedSecondMoment[i] == null) {
                        this.accumulatedSecondMoment[i] = {
                            originalName: `${name}/v`,
                            variable: tidy(() => zerosLike(value).variable(trainable))
                        };
                    }
                    const gradient = Array.isArray(variableGradients) ?
                        variableGradients[i].tensor :
                        variableGradients[name];
                    if (gradient == null) {
                        return;
                    }
                    const firstMoment = this.accumulatedFirstMoment[i].variable;
                    const secondMoment = this.accumulatedSecondMoment[i].variable;
                    // m = beta1 * m + (1 - beta1) * grad
                    const newFirstMoment = add$1(mul(firstMoment, this.beta1), mul(gradient, 1 - this.beta1));
                    // v = beta2 * v + (1 - beta2) * grad^2
                    const newSecondMoment = add$1(mul(secondMoment, this.beta2), mul(square(gradient), 1 - this.beta2));
                    const biasCorrectedFirstMoment = div(newFirstMoment, oneMinusAccBeta1);
                    const biasCorrectedSecondMoment = div(newSecondMoment, oneMinusAccBeta2);
                    firstMoment.assign(newFirstMoment);
                    secondMoment.assign(newSecondMoment);
                    // value -= lr * mHat / (sqrt(vHat) + epsilon)
                    const newValue = add$1(mul(div(biasCorrectedFirstMoment, add$1(sqrt(biasCorrectedSecondMoment), this.epsilon)), -this.learningRate), value);
                    value.assign(newValue);
                });
                // Advance the beta powers for the next step's bias correction.
                this.accBeta1.assign(mul(this.accBeta1, this.beta1));
                this.accBeta2.assign(mul(this.accBeta2, this.beta2));
            });
            this.incrementIterations();
        }
        /** Releases the beta-power scalars and per-variable moment tensors. */
        dispose() {
            this.accBeta1.dispose();
            this.accBeta2.dispose();
            if (this.accumulatedFirstMoment != null) {
                dispose(this.accumulatedFirstMoment.map(v => v.variable));
            }
            if (this.accumulatedSecondMoment != null) {
                dispose(this.accumulatedSecondMoment.map(v => v.variable));
            }
        }
        /**
         * Serializes state: iterations tensor first, then all first moments,
         * then all second moments.
         */
        async getWeights() {
            // Order matters for Python compatibility.
            const variables = [...this.accumulatedFirstMoment, ...this.accumulatedSecondMoment];
            return [await this.saveIterations()].concat(variables.map(v => ({ name: v.originalName, tensor: v.variable })));
        }
        /** Restores state previously produced by `getWeights()`. */
        async setWeights(weightValues) {
            weightValues = await this.extractIterations(weightValues);
            tidy(() => {
                // Rebuild the beta powers from the restored iteration count.
                this.accBeta1.assign(pow(this.beta1, this.iterations_ + 1));
                this.accBeta2.assign(pow(this.beta2, this.iterations_ + 1));
            });
            // First half of the weights are first moments, second half are
            // second moments (same layout as getWeights()).
            const variableCount = weightValues.length / 2;
            const trainable = false;
            this.accumulatedFirstMoment =
                weightValues.slice(0, variableCount).map(v => ({
                    originalName: v.name,
                    variable: v.tensor.variable(trainable)
                }));
            this.accumulatedSecondMoment =
                weightValues.slice(variableCount, variableCount * 2)
                    .map(v => ({
                    originalName: v.name,
                    variable: v.tensor.variable(trainable)
                }));
        }
        /** Returns the serializable configuration of this optimizer. */
        getConfig() {
            return {
                'learningRate': this.learningRate,
                'beta1': this.beta1,
                'beta2': this.beta2,
                'epsilon': this.epsilon,
            };
        }
        /** @nocollapse */
        static fromConfig(cls, config) {
            return new cls(config['learningRate'], config['beta1'], config['beta2'], config['epsilon']);
        }
    }
    /** @nocollapse */
    AdamOptimizer.className = 'Adam'; // Note: Name matters for Python compatibility.
    registerClass(AdamOptimizer);
25691
25692 /**
25693 * @license
25694 * Copyright 2018 Google LLC. All Rights Reserved.
25695 * Licensed under the Apache License, Version 2.0 (the "License");
25696 * you may not use this file except in compliance with the License.
25697 * You may obtain a copy of the License at
25698 *
25699 * http://www.apache.org/licenses/LICENSE-2.0
25700 *
25701 * Unless required by applicable law or agreed to in writing, software
25702 * distributed under the License is distributed on an "AS IS" BASIS,
25703 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25704 * See the License for the specific language governing permissions and
25705 * limitations under the License.
25706 * =============================================================================
25707 */
    class AdamaxOptimizer extends Optimizer {
        /**
         * Optimizer implementing the Adamax variant of Adam
         * (https://arxiv.org/abs/1412.6980), which replaces the second-moment
         * estimate with an exponentially weighted infinity norm.
         *
         * @param learningRate Step size.
         * @param beta1 Exponential decay rate for the 1st-moment estimates.
         * @param beta2 Decay rate for the weighted infinity norm.
         * @param epsilon Small constant for numerical stability; defaults to
         *     the backend's epsilon when null.
         * @param decay Learning-rate decay applied per update.
         */
        constructor(learningRate, beta1, beta2, epsilon = null, decay = 0.0) {
            super();
            this.learningRate = learningRate;
            this.beta1 = beta1;
            this.beta2 = beta2;
            this.epsilon = epsilon;
            this.decay = decay;
            // Per-variable state, created lazily in applyGradients().
            this.accumulatedFirstMoment = [];
            this.accumulatedWeightedInfNorm = [];
            tidy(() => {
                // `iteration` drives the learning-rate decay; `accBeta1`
                // tracks beta1^t for bias correction.
                this.iteration = scalar(0).variable();
                this.accBeta1 = scalar(beta1).variable();
            });
            if (epsilon == null) {
                this.epsilon = ENGINE.backend.epsilon();
            }
        }
        /**
         * Applies one Adamax update step.
         *
         * @param variableGradients Either an array of `{name, tensor}` pairs
         *     or a map from variable name to gradient tensor. Entries whose
         *     gradient is null/undefined are skipped.
         */
        applyGradients(variableGradients) {
            const variableNames = Array.isArray(variableGradients) ?
                variableGradients.map(item => item.name) :
                Object.keys(variableGradients);
            tidy(() => {
                const oneMinusAccBeta1 = sub(1, this.accBeta1);
                // Effective (negative) learning rate with per-step decay.
                const lr = div(-this.learningRate, add$1(mul(this.iteration, this.decay), 1));
                variableNames.forEach((name, i) => {
                    const value = ENGINE.registeredVariables[name];
                    const trainable = false;
                    if (this.accumulatedFirstMoment[i] == null) {
                        this.accumulatedFirstMoment[i] = {
                            originalName: `${name}/m`,
                            variable: zerosLike(value).variable(trainable)
                        };
                    }
                    if (this.accumulatedWeightedInfNorm[i] == null) {
                        this.accumulatedWeightedInfNorm[i] = {
                            originalName: `${name}/v`,
                            variable: zerosLike(value).variable(trainable)
                        };
                    }
                    const gradient = Array.isArray(variableGradients) ?
                        variableGradients[i].tensor :
                        variableGradients[name];
                    if (gradient == null) {
                        return;
                    }
                    const firstMoment = this.accumulatedFirstMoment[i].variable;
                    const weightedInfNorm = this.accumulatedWeightedInfNorm[i].variable;
                    // m = beta1 * m + (1 - beta1) * grad
                    const newFirstMoment = add$1(mul(firstMoment, this.beta1), mul(gradient, 1 - this.beta1));
                    // u = max(beta2 * u, |grad|)
                    const ut0 = mul(weightedInfNorm, this.beta2);
                    const ut1 = abs(gradient);
                    const newWeightedInfNorm = maximum(ut0, ut1);
                    firstMoment.assign(newFirstMoment);
                    weightedInfNorm.assign(newWeightedInfNorm);
                    // value += lr / (1 - beta1^t) * m / (u + epsilon)
                    const newValue = add$1(mul(div(lr, oneMinusAccBeta1), div(newFirstMoment, add$1(newWeightedInfNorm, this.epsilon))), value);
                    value.assign(newValue);
                });
                this.iteration.assign(add$1(this.iteration, 1));
                this.accBeta1.assign(mul(this.accBeta1, this.beta1));
            });
            this.incrementIterations();
        }
        /** Releases the tracking scalars and per-variable state tensors. */
        dispose() {
            this.accBeta1.dispose();
            this.iteration.dispose();
            if (this.accumulatedFirstMoment != null) {
                dispose(this.accumulatedFirstMoment.map(v => v.variable));
            }
            if (this.accumulatedWeightedInfNorm != null) {
                dispose(this.accumulatedWeightedInfNorm.map(v => v.variable));
            }
        }
        // Weight (de)serialization is intentionally unimplemented for Adamax.
        async getWeights() {
            throw new Error('getWeights() is not implemented for Adamax yet.');
        }
        async setWeights(weightValues) {
            throw new Error('setWeights() is not implemented for Adamax yet.');
        }
        /** Returns the serializable configuration of this optimizer. */
        getConfig() {
            return {
                'learningRate': this.learningRate,
                'beta1': this.beta1,
                'beta2': this.beta2,
                'epsilon': this.epsilon,
                'decay': this.decay
            };
        }
        /** @nocollapse */
        static fromConfig(cls, config) {
            return new cls(config['learningRate'], config['beta1'], config['beta2'], config['epsilon'], config['decay']);
        }
    }
    /** @nocollapse */
    AdamaxOptimizer.className = 'Adamax'; // Note: Name matters for Python compatibility.
    registerClass(AdamaxOptimizer);
25803
25804 /**
25805 * @license
25806 * Copyright 2018 Google LLC. All Rights Reserved.
25807 * Licensed under the Apache License, Version 2.0 (the "License");
25808 * you may not use this file except in compliance with the License.
25809 * You may obtain a copy of the License at
25810 *
25811 * http://www.apache.org/licenses/LICENSE-2.0
25812 *
25813 * Unless required by applicable law or agreed to in writing, software
25814 * distributed under the License is distributed on an "AS IS" BASIS,
25815 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25816 * See the License for the specific language governing permissions and
25817 * limitations under the License.
25818 * =============================================================================
25819 */
25820 /** @doclink Optimizer */
25821 class SGDOptimizer extends Optimizer {
25822 constructor(learningRate) {
25823 super();
25824 this.learningRate = learningRate;
25825 this.setLearningRate(learningRate);
25826 }
25827 applyGradients(variableGradients) {
25828 const varNames = Array.isArray(variableGradients) ?
25829 variableGradients.map(v => v.name) :
25830 Object.keys(variableGradients);
25831 varNames.forEach((name, i) => {
25832 const gradient = Array.isArray(variableGradients) ?
25833 variableGradients[i].tensor :
25834 variableGradients[name];
25835 if (gradient == null) {
25836 return;
25837 }
25838 const value = ENGINE.registeredVariables[name];
25839 tidy(() => {
25840 const newValue = add$1(mul(this.c, gradient), value);
25841 value.assign(newValue);
25842 });
25843 });
25844 this.incrementIterations();
25845 }
25846 /**
25847 * Sets the learning rate of the optimizer.
25848 */
25849 setLearningRate(learningRate) {
25850 this.learningRate = learningRate;
25851 if (this.c != null) {
25852 this.c.dispose();
25853 }
25854 this.c = keep(scalar(-learningRate));
25855 }
25856 dispose() {
25857 this.c.dispose();
25858 }
25859 async getWeights() {
25860 return [await this.saveIterations()];
25861 }
25862 async setWeights(weightValues) {
25863 weightValues = await this.extractIterations(weightValues);
25864 if (weightValues.length !== 0) {
25865 throw new Error('SGD optimizer does not have settable weights.');
25866 }
25867 }
25868 getConfig() {
25869 return { 'learningRate': this.learningRate };
25870 }
25871 /** @nocollapse */
25872 static fromConfig(cls, config) {
25873 return new cls(config['learningRate']);
25874 }
25875 }
25876 /** @nocollapse */
25877 SGDOptimizer.className = 'SGD'; // Note: Name matters for Python compatibility.
25878 registerClass(SGDOptimizer);
25879
25880 /**
25881 * @license
25882 * Copyright 2018 Google LLC. All Rights Reserved.
25883 * Licensed under the Apache License, Version 2.0 (the "License");
25884 * you may not use this file except in compliance with the License.
25885 * You may obtain a copy of the License at
25886 *
25887 * http://www.apache.org/licenses/LICENSE-2.0
25888 *
25889 * Unless required by applicable law or agreed to in writing, software
25890 * distributed under the License is distributed on an "AS IS" BASIS,
25891 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25892 * See the License for the specific language governing permissions and
25893 * limitations under the License.
25894 * =============================================================================
25895 */
25896 /** @doclink Optimizer */
25897 class MomentumOptimizer extends SGDOptimizer {
25898 constructor(learningRate, momentum, useNesterov = false) {
25899 super(learningRate);
25900 this.learningRate = learningRate;
25901 this.momentum = momentum;
25902 this.useNesterov = useNesterov;
25903 this.accumulations = [];
25904 this.m = scalar(this.momentum);
25905 }
25906 applyGradients(variableGradients) {
25907 const variableNames = Array.isArray(variableGradients) ?
25908 variableGradients.map(item => item.name) :
25909 Object.keys(variableGradients);
25910 variableNames.forEach((name, i) => {
25911 const value = ENGINE.registeredVariables[name];
25912 if (this.accumulations[i] == null) {
25913 const trainable = false;
25914 this.accumulations[i] = {
25915 originalName: `${name}/momentum`,
25916 variable: tidy(() => zerosLike(value).variable(trainable))
25917 };
25918 }
25919 const accumulation = this.accumulations[i].variable;
25920 const gradient = Array.isArray(variableGradients) ?
25921 variableGradients[i].tensor :
25922 variableGradients[name];
25923 if (gradient == null) {
25924 return;
25925 }
25926 tidy(() => {
25927 let newValue;
25928 const newAccumulation = add$1(mul(this.m, accumulation), gradient);
25929 if (this.useNesterov) {
25930 newValue = add$1(mul(this.c, add$1(gradient, mul(newAccumulation, this.m))), value);
25931 }
25932 else {
25933 newValue = add$1(mul(this.c, newAccumulation), value);
25934 }
25935 accumulation.assign(newAccumulation);
25936 value.assign(newValue);
25937 });
25938 });
25939 this.incrementIterations();
25940 }
25941 dispose() {
25942 this.m.dispose();
25943 if (this.accumulations != null) {
25944 dispose(this.accumulations.map(v => v.variable));
25945 }
25946 }
25947 /**
25948 * Sets the momentum of the optimizer.
25949 *
25950 * @param momentum
25951 */
25952 setMomentum(momentum) {
25953 this.momentum = momentum;
25954 }
25955 async getWeights() {
25956 // Order matters for Python compatibility.
25957 return [await this.saveIterations()].concat(this.accumulations.map(v => ({ name: v.originalName, tensor: v.variable })));
25958 }
25959 async setWeights(weightValues) {
25960 weightValues = await this.extractIterations(weightValues);
25961 const trainable = false;
25962 this.accumulations = weightValues.map(v => ({ originalName: v.name, variable: v.tensor.variable(trainable) }));
25963 }
25964 getConfig() {
25965 return {
25966 'learningRate': this.learningRate,
25967 'momentum': this.momentum,
25968 'useNesterov': this.useNesterov
25969 };
25970 }
25971 /** @nocollapse */
25972 static fromConfig(cls, config) {
25973 return new cls(config['learningRate'], config['momentum'], config['useNesterov']);
25974 }
25975 }
25976 /** @nocollapse */
25977 MomentumOptimizer.className = 'Momentum'; // Name matters for Python compatibility.
25978 registerClass(MomentumOptimizer);
25979
25980 /**
25981 * @license
25982 * Copyright 2018 Google LLC. All Rights Reserved.
25983 * Licensed under the Apache License, Version 2.0 (the "License");
25984 * you may not use this file except in compliance with the License.
25985 * You may obtain a copy of the License at
25986 *
25987 * http://www.apache.org/licenses/LICENSE-2.0
25988 *
25989 * Unless required by applicable law or agreed to in writing, software
25990 * distributed under the License is distributed on an "AS IS" BASIS,
25991 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25992 * See the License for the specific language governing permissions and
25993 * limitations under the License.
25994 * =============================================================================
25995 */
25996 /** @doclink Optimizer */
25997 class RMSPropOptimizer extends Optimizer {
25998 constructor(learningRate, decay = 0.9, momentum = 0.0, epsilon = null, centered = false) {
25999 super();
26000 this.learningRate = learningRate;
26001 this.decay = decay;
26002 this.momentum = momentum;
26003 this.epsilon = epsilon;
26004 this.accumulatedMeanSquares = [];
26005 this.accumulatedMoments = [];
26006 this.accumulatedMeanGrads = [];
26007 this.centered = centered;
26008 if (epsilon == null) {
26009 this.epsilon = ENGINE.backend.epsilon();
26010 }
26011 if (learningRate == null) {
26012 throw new Error(`learningRate for RMSPropOptimizer must be defined.`);
26013 }
26014 }
26015 applyGradients(variableGradients) {
26016 const variableNames = Array.isArray(variableGradients) ?
26017 variableGradients.map(item => item.name) :
26018 Object.keys(variableGradients);
26019 variableNames.forEach((name, i) => {
26020 const value = ENGINE.registeredVariables[name];
26021 const trainable = false;
26022 if (this.accumulatedMeanSquares[i] == null) {
26023 this.accumulatedMeanSquares[i] = {
26024 originalName: `${name}/rms`,
26025 variable: tidy(() => zerosLike(value).variable(trainable))
26026 };
26027 }
26028 if (this.accumulatedMoments[i] == null) {
26029 this.accumulatedMoments[i] = {
26030 originalName: `${name}/momentum`,
26031 variable: tidy(() => zerosLike(value).variable(trainable))
26032 };
26033 }
26034 if (this.accumulatedMeanGrads[i] == null && this.centered) {
26035 this.accumulatedMeanGrads[i] = {
26036 originalName: `${name}/mg`,
26037 variable: tidy(() => zerosLike(value).variable(trainable))
26038 };
26039 }
26040 const gradient = Array.isArray(variableGradients) ?
26041 variableGradients[i].tensor :
26042 variableGradients[name];
26043 if (gradient == null) {
26044 return;
26045 }
26046 const accumulatedMeanSquare = this.accumulatedMeanSquares[i].variable;
26047 const accumulatedMoments = this.accumulatedMoments[i].variable;
26048 tidy(() => {
26049 const newAccumulatedMeanSquare = add$1(mul(accumulatedMeanSquare, this.decay), mul(square(gradient), 1 - this.decay));
26050 if (this.centered) {
26051 const accumulatedMeanGrad = this.accumulatedMeanGrads[i].variable;
26052 // Centered gradient
26053 const newAccumulatedMeanGrad = add$1(mul(accumulatedMeanGrad, this.decay), mul(gradient, 1 - this.decay));
26054 const gradContribution = div(mul(gradient, this.learningRate), sqrt(sub(newAccumulatedMeanSquare, add$1(square(newAccumulatedMeanGrad), this.epsilon))));
26055 const newAccumulatedMoments = add$1(mul(accumulatedMoments, this.momentum), gradContribution);
26056 accumulatedMeanSquare.assign(newAccumulatedMeanSquare);
26057 accumulatedMeanGrad.assign(newAccumulatedMeanGrad);
26058 accumulatedMoments.assign(newAccumulatedMoments);
26059 const newValue = sub(value, newAccumulatedMoments);
26060 value.assign(newValue);
26061 }
26062 else {
26063 // Plain gradient
26064 const newAccumulatedMeanSquare = add$1(mul(accumulatedMeanSquare, this.decay), mul(square(gradient), 1 - this.decay));
26065 const newAccumulatedMoments = add$1(mul(accumulatedMoments, this.momentum), div(mul(gradient, this.learningRate), sqrt(add$1(newAccumulatedMeanSquare, this.epsilon))));
26066 accumulatedMeanSquare.assign(newAccumulatedMeanSquare);
26067 accumulatedMoments.assign(newAccumulatedMoments);
26068 const newValue = sub(value, newAccumulatedMoments);
26069 value.assign(newValue);
26070 }
26071 });
26072 });
26073 this.incrementIterations();
26074 }
26075 dispose() {
26076 if (this.accumulatedMeanSquares != null) {
26077 dispose(this.accumulatedMeanSquares.map(v => v.variable));
26078 }
26079 if (this.accumulatedMeanGrads != null && this.centered) {
26080 dispose(this.accumulatedMeanGrads.map(v => v.variable));
26081 }
26082 if (this.accumulatedMoments != null) {
26083 dispose(this.accumulatedMoments.map(v => v.variable));
26084 }
26085 }
26086 async getWeights() {
26087 // Order matters for Python compatibility.
26088 const variables = [...this.accumulatedMeanSquares, ...this.accumulatedMoments];
26089 if (this.centered) {
26090 variables.push(...this.accumulatedMeanGrads);
26091 }
26092 return [await this.saveIterations()].concat(variables.map(v => ({ name: v.originalName, tensor: v.variable })));
26093 }
26094 async setWeights(weightValues) {
26095 weightValues = await this.extractIterations(weightValues);
26096 const variableCount = this.centered ? weightValues.length / 3 : weightValues.length / 2;
26097 const trainable = false;
26098 this.accumulatedMeanSquares =
26099 weightValues.slice(0, variableCount).map(v => ({
26100 originalName: v.name,
26101 variable: v.tensor.variable(trainable)
26102 }));
26103 this.accumulatedMoments =
26104 weightValues.slice(variableCount, variableCount * 2)
26105 .map(v => ({
26106 originalName: v.name,
26107 variable: v.tensor.variable(trainable)
26108 }));
26109 if (this.centered) {
26110 this.accumulatedMeanGrads =
26111 weightValues.slice(variableCount * 2, variableCount * 3)
26112 .map(v => ({
26113 originalName: v.name,
26114 variable: v.tensor.variable(trainable)
26115 }));
26116 }
26117 }
26118 getConfig() {
26119 return {
26120 'learningRate': this.learningRate,
26121 'decay': this.decay,
26122 'momentum': this.momentum,
26123 'epsilon': this.epsilon,
26124 'centered': this.centered
26125 };
26126 }
26127 /** @nocollapse */
26128 static fromConfig(cls, config) {
26129 return new cls(config['learningRate'], config['decay'], config['momentum'], config['epsilon'], config['centered']);
26130 }
26131 }
26132 /** @nocollapse */
26133 RMSPropOptimizer.className = 'RMSProp'; // Note: Name matters for Python compatibility.
26134 registerClass(RMSPropOptimizer);
26135
26136 /**
26137 * @license
26138 * Copyright 2018 Google LLC. All Rights Reserved.
26139 * Licensed under the Apache License, Version 2.0 (the "License");
26140 * you may not use this file except in compliance with the License.
26141 * You may obtain a copy of the License at
26142 *
26143 * http://www.apache.org/licenses/LICENSE-2.0
26144 *
26145 * Unless required by applicable law or agreed to in writing, software
26146 * distributed under the License is distributed on an "AS IS" BASIS,
26147 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26148 * See the License for the specific language governing permissions and
26149 * limitations under the License.
26150 * =============================================================================
26151 */
26152 class OptimizerConstructors {
26153 /**
26154 * Constructs a `tf.SGDOptimizer` that uses stochastic gradient descent.
26155 *
26156 * ```js
26157 * // Fit a quadratic function by learning the coefficients a, b, c.
26158 * const xs = tf.tensor1d([0, 1, 2, 3]);
26159 * const ys = tf.tensor1d([1.1, 5.9, 16.8, 33.9]);
26160 *
26161 * const a = tf.scalar(Math.random()).variable();
26162 * const b = tf.scalar(Math.random()).variable();
26163 * const c = tf.scalar(Math.random()).variable();
26164 *
26165 * // y = a * x^2 + b * x + c.
26166 * const f = x => a.mul(x.square()).add(b.mul(x)).add(c);
26167 * const loss = (pred, label) => pred.sub(label).square().mean();
26168 *
26169 * const learningRate = 0.01;
26170 * const optimizer = tf.train.sgd(learningRate);
26171 *
26172 * // Train the model.
26173 * for (let i = 0; i < 10; i++) {
26174 * optimizer.minimize(() => loss(f(xs), ys));
26175 * }
26176 *
26177 * // Make predictions.
26178 * console.log(
26179 * `a: ${a.dataSync()}, b: ${b.dataSync()}, c: ${c.dataSync()}`);
26180 * const preds = f(xs).dataSync();
26181 * preds.forEach((pred, i) => {
26182 * console.log(`x: ${i}, pred: ${pred}`);
26183 * });
26184 * ```
26185 *
26186 * @param learningRate The learning rate to use for the SGD algorithm.
26187 *
26188 * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
26189 */
26190 static sgd(learningRate) {
26191 return new SGDOptimizer(learningRate);
26192 }
26193 /**
26194 * Constructs a `tf.MomentumOptimizer` that uses momentum gradient
26195 * descent.
26196 *
26197 * See
26198 * [http://proceedings.mlr.press/v28/sutskever13.pdf](
26199 * http://proceedings.mlr.press/v28/sutskever13.pdf)
26200 *
26201 * @param learningRate The learning rate to use for the Momentum gradient
26202 * descent algorithm.
26203 * @param momentum The momentum to use for the momentum gradient descent
26204 * algorithm.
26205 *
26206 * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
26207 */
26208 static momentum(learningRate, momentum, useNesterov = false) {
26209 return new MomentumOptimizer(learningRate, momentum, useNesterov);
26210 }
26211 /**
26212 * Constructs a `tf.RMSPropOptimizer` that uses RMSProp gradient
26213 * descent. This implementation uses plain momentum and is not centered
26214 * version of RMSProp.
26215 *
26216 * See
26217 * [http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf](
26218 * http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
26219 *
26220 * @param learningRate The learning rate to use for the RMSProp gradient
26221 * descent algorithm.
26222 * @param decay The discounting factor for the history/coming gradient.
26223 * @param momentum The momentum to use for the RMSProp gradient descent
26224 * algorithm.
26225 * @param epsilon Small value to avoid zero denominator.
26226 * @param centered If true, gradients are normalized by the estimated
26227 * variance of the gradient.
26228 *
26229 * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
26230 */
26231 static rmsprop(learningRate, decay = .9, momentum = 0.0, epsilon = null, centered = false) {
26232 return new RMSPropOptimizer(learningRate, decay, momentum, epsilon, centered);
26233 }
26234 /**
26235 * Constructs a `tf.AdamOptimizer` that uses the Adam algorithm.
26236 * See [https://arxiv.org/abs/1412.6980](https://arxiv.org/abs/1412.6980)
26237 *
26238 * @param learningRate The learning rate to use for the Adam gradient
26239 * descent algorithm.
26240 * @param beta1 The exponential decay rate for the 1st moment estimates.
26241 * @param beta2 The exponential decay rate for the 2nd moment estimates.
26242 * @param epsilon A small constant for numerical stability.
26243 *
26244 * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
26245 */
26246 static adam(learningRate = 0.001, beta1 = 0.9, beta2 = 0.999, epsilon = null) {
26247 return new AdamOptimizer(learningRate, beta1, beta2, epsilon);
26248 }
26249 /**
26250 * Constructs a `tf.AdadeltaOptimizer` that uses the Adadelta algorithm.
26251 * See [https://arxiv.org/abs/1212.5701](https://arxiv.org/abs/1212.5701)
26252 *
26253 * @param learningRate The learning rate to use for the Adadelta gradient
26254 * descent algorithm.
26255 * @param rho The learning rate decay over each update.
26256 * @param epsilon A constant epsilon used to better condition the grad
26257 * update.
26258 *
26259 * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
26260 */
26261 static adadelta(learningRate = .001, rho = .95, epsilon = null) {
26262 return new AdadeltaOptimizer(learningRate, rho, epsilon);
26263 }
26264 /**
26265 * Constructs a `tf.AdamaxOptimizer` that uses the Adamax algorithm.
26266 * See [https://arxiv.org/abs/1412.6980](https://arxiv.org/abs/1412.6980)
26267 *
26268 * @param learningRate The learning rate to use for the Adamax gradient
26269 * descent algorithm.
26270 * @param beta1 The exponential decay rate for the 1st moment estimates.
26271 * @param beta2 The exponential decay rate for the 2nd moment estimates.
26272 * @param epsilon A small constant for numerical stability.
26273 * @param decay The learning rate decay over each update.
26274 *
26275 * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
26276 */
        static adamax(learningRate = 0.002, beta1 = 0.9, beta2 = 0.999, epsilon = null, decay = 0.0) {
            // Thin factory over AdamaxOptimizer.
            // NOTE(review): epsilon=null appears to defer the concrete value to
            // the optimizer — confirm in AdamaxOptimizer.
            return new AdamaxOptimizer(learningRate, beta1, beta2, epsilon, decay);
        }
26280 /**
26281 * Constructs a `tf.AdagradOptimizer` that uses the Adagrad algorithm.
26282 * See
26283 * [http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf](
26284 * http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
26285 * or
26286 * [http://ruder.io/optimizing-gradient-descent/index.html#adagrad](
26287 * http://ruder.io/optimizing-gradient-descent/index.html#adagrad)
26288 *
26289 * @param learningRate The learning rate to use for the Adagrad gradient
26290 * descent algorithm.
26291 * @param initialAccumulatorValue Starting value for the accumulators, must be
26292 * positive.
26293 *
26294 * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
26295 */
        static adagrad(learningRate, initialAccumulatorValue = 0.1) {
            // Thin factory over AdagradOptimizer. Note that, unlike the other
            // factories in this class, learningRate has no default value here.
            return new AdagradOptimizer(learningRate, initialAccumulatorValue);
        }
26299 }
26300
26301 /**
26302 * @license
26303 * Copyright 2018 Google LLC. All Rights Reserved.
26304 * Licensed under the Apache License, Version 2.0 (the "License");
26305 * you may not use this file except in compliance with the License.
26306 * You may obtain a copy of the License at
26307 *
26308 * http://www.apache.org/licenses/LICENSE-2.0
26309 *
26310 * Unless required by applicable law or agreed to in writing, software
26311 * distributed under the License is distributed on an "AS IS" BASIS,
26312 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26313 * See the License for the specific language governing permissions and
26314 * limitations under the License.
26315 * =============================================================================
26316 */
    // Reference each optimizer class once so the bindings are used.
    // NOTE(review): presumably this exists to keep bundlers/tree-shaking from
    // dropping the classes — confirm against the build setup.
    // tslint:disable-next-line:no-unused-expression
    [MomentumOptimizer, SGDOptimizer, AdadeltaOptimizer, AdagradOptimizer,
        RMSPropOptimizer, AdamaxOptimizer, AdamOptimizer];
    // `tf.train` namespace: factory methods for the built-in optimizers,
    // each delegating to the corresponding static method on
    // `OptimizerConstructors` (e.g. `tf.train.sgd(...)`).
    const train = {
        sgd: OptimizerConstructors.sgd,
        momentum: OptimizerConstructors.momentum,
        adadelta: OptimizerConstructors.adadelta,
        adagrad: OptimizerConstructors.adagrad,
        rmsprop: OptimizerConstructors.rmsprop,
        adamax: OptimizerConstructors.adamax,
        adam: OptimizerConstructors.adam
    };
26329
26330 /**
26331 * @license
26332 * Copyright 2017 Google LLC. All Rights Reserved.
26333 * Licensed under the Apache License, Version 2.0 (the "License");
26334 * you may not use this file except in compliance with the License.
26335 * You may obtain a copy of the License at
26336 *
26337 * http://www.apache.org/licenses/LICENSE-2.0
26338 *
26339 * Unless required by applicable law or agreed to in writing, software
26340 * distributed under the License is distributed on an "AS IS" BASIS,
26341 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26342 * See the License for the specific language governing permissions and
26343 * limitations under the License.
26344 * =============================================================================
26345 */
26346 const delayCallback = (() => {
26347 if (typeof requestAnimationFrame !== 'undefined') {
26348 return requestAnimationFrame;
26349 }
26350 else if (typeof setImmediate !== 'undefined') {
26351 return setImmediate;
26352 }
26353 return (f) => f(); // no delays
26354 })();
26355 /**
26356 * Returns a promise that resolve when a requestAnimationFrame has completed.
26357 *
26358 * On Node.js this uses setImmediate instead of requestAnimationFrame.
26359 *
26360 * This is simply a sugar method so that users can do the following:
26361 * `await tf.nextFrame();`
26362 *
26363 * @doc {heading: 'Performance', subheading: 'Timing'}
26364 */
26365 function nextFrame() {
26366 return new Promise(resolve => delayCallback(() => resolve()));
26367 }
26368
26369 /**
26370 * @license
26371 * Copyright 2017 Google LLC. All Rights Reserved.
26372 * Licensed under the Apache License, Version 2.0 (the "License");
26373 * you may not use this file except in compliance with the License.
26374 * You may obtain a copy of the License at
26375 *
26376 * http://www.apache.org/licenses/LICENSE-2.0
26377 *
26378 * Unless required by applicable law or agreed to in writing, software
26379 * distributed under the License is distributed on an "AS IS" BASIS,
26380 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26381 * See the License for the specific language governing permissions and
26382 * limitations under the License.
26383 * =============================================================================
26384 */
26385 function assertParamsConsistent(shapes, axis) {
26386 const rank = shapes[0].length;
26387 shapes.forEach((shape, i) => {
26388 assert(shape.length === rank, () => `Error in concat${rank}D: rank of tensors[${i}] must be the same ` +
26389 `as the rank of the rest (${rank})`);
26390 });
26391 assert(axis >= 0 && axis < rank, () => `Error in concat${rank}D: axis must be between 0 and ${rank - 1}.`);
26392 const firstShape = shapes[0];
26393 shapes.forEach((shape, i) => {
26394 for (let r = 0; r < rank; r++) {
26395 assert((r === axis) || (shape[r] === firstShape[r]), () => `Error in concat${rank}D: Shape of tensors[${i}] (${shape}) ` +
26396 `does not match the shape of the rest (${firstShape}) ` +
26397 `along the non-concatenated axis ${i}.`);
26398 }
26399 });
26400 }
26401 function computeOutShape$1(shapes, axis) {
26402 const outputShape = shapes[0].slice();
26403 for (let i = 1; i < shapes.length; i++) {
26404 outputShape[axis] += shapes[i][axis];
26405 }
26406 return outputShape;
26407 }
26408
26409 /**
26410 * @license
26411 * Copyright 2017 Google LLC. All Rights Reserved.
26412 * Licensed under the Apache License, Version 2.0 (the "License");
26413 * you may not use this file except in compliance with the License.
26414 * You may obtain a copy of the License at
26415 *
26416 * http://www.apache.org/licenses/LICENSE-2.0
26417 *
26418 * Unless required by applicable law or agreed to in writing, software
26419 * distributed under the License is distributed on an "AS IS" BASIS,
26420 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26421 * See the License for the specific language governing permissions and
26422 * limitations under the License.
26423 * =============================================================================
26424 */
26425 const PARALLELIZE_THRESHOLD = 30;
26426 function computeOptimalWindowSize(inSize) {
26427 if (inSize <= PARALLELIZE_THRESHOLD) {
26428 return inSize;
26429 }
26430 return nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));
26431 }
26432
26433 /**
26434 * @license
26435 * Copyright 2020 Google LLC. All Rights Reserved.
26436 * Licensed under the Apache License, Version 2.0 (the "License");
26437 * you may not use this file except in compliance with the License.
26438 * You may obtain a copy of the License at
26439 *
26440 * http://www.apache.org/licenses/LICENSE-2.0
26441 *
26442 * Unless required by applicable law or agreed to in writing, software
26443 * distributed under the License is distributed on an "AS IS" BASIS,
26444 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26445 * See the License for the specific language governing permissions and
26446 * limitations under the License.
26447 * =============================================================================
26448 */
26449 // Returns the image center in pixels.
26450 function getImageCenter(center, imageHeight, imageWidth) {
26451 const centerX = imageWidth * (typeof center === 'number' ? center : center[0]);
26452 const centerY = imageHeight * (typeof center === 'number' ? center : center[1]);
26453 return [centerX, centerY];
26454 }
26455
26456 /**
26457 * @license
26458 * Copyright 2018 Google LLC. All Rights Reserved.
26459 * Licensed under the Apache License, Version 2.0 (the "License");
26460 * you may not use this file except in compliance with the License.
26461 * You may obtain a copy of the License at
26462 *
26463 * http://www.apache.org/licenses/LICENSE-2.0
26464 *
26465 * Unless required by applicable law or agreed to in writing, software
26466 * distributed under the License is distributed on an "AS IS" BASIS,
26467 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26468 * See the License for the specific language governing permissions and
26469 * limitations under the License.
26470 * =============================================================================
26471 */
26472 /**
26473 * Gets the new shape of the input Tensor after it's been reshaped
26474 * to:
26475 * [blockShape[0], ..., blockShape[M-1], batch / prod(blockShape),
26476 * inputShape[1], ..., inputShape[N-1]]
26477 *
26478 * See step 1: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
26479 */
26480 function getReshaped(inputShape, blockShape, prod, batchToSpace = true) {
26481 let reshaped = [];
26482 if (batchToSpace) {
26483 reshaped = reshaped.concat(blockShape.slice(0));
26484 reshaped.push(inputShape[0] / prod);
26485 reshaped = reshaped.concat(inputShape.slice(1));
26486 }
26487 else {
26488 reshaped = reshaped.concat(inputShape[0]);
26489 const spatialLength = blockShape.length;
26490 for (let i = 0; i < spatialLength; ++i) {
26491 reshaped =
26492 reshaped.concat([inputShape[i + 1] / blockShape[i], blockShape[i]]);
26493 }
26494 reshaped = reshaped.concat(inputShape.slice(spatialLength + 1));
26495 }
26496 return reshaped;
26497 }
26498 /**
26499 * Gets the permutation that will transpose the dimensions of the
26500 * reshaped tensor to shape:
26501 *
26502 * [batch / prod(block_shape),inputShape[1], blockShape[0], ...,
26503 * inputShape[M], blockShape[M-1],inputShape[M+1], ..., inputShape[N-1]]
26504 *
26505 * see step 2: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
26506 */
26507 function getPermuted(reshapedRank, blockShapeRank, batchToSpace = true) {
26508 const permuted = [];
26509 if (batchToSpace) {
26510 permuted.push(blockShapeRank);
26511 for (let i = blockShapeRank + 1; i < reshapedRank; ++i) {
26512 if (i <= 2 * blockShapeRank) {
26513 permuted.push(i);
26514 permuted.push(i - (blockShapeRank + 1));
26515 }
26516 else {
26517 permuted.push(i);
26518 }
26519 }
26520 }
26521 else {
26522 const permutedBeforeBatch = [];
26523 const permutedAfterBatch = [];
26524 for (let i = 1; i < reshapedRank; ++i) {
26525 if (i >= blockShapeRank * 2 + 1 || i % 2 === 1) {
26526 permutedAfterBatch.push(i);
26527 }
26528 else {
26529 permutedBeforeBatch.push(i);
26530 }
26531 }
26532 permuted.push(...permutedBeforeBatch);
26533 permuted.push(0);
26534 permuted.push(...permutedAfterBatch);
26535 }
26536 return permuted;
26537 }
26538 /**
26539 * Gets the shape of the reshaped and permuted input Tensor before any cropping
26540 * is applied. The new shape will be:
26541 *
26542 * [batch / prod(blockShape),inputShape[1] * blockShape[0], ...,
26543 * inputShape[M] * blockShape[M-1],inputShape[M+1], ..., inputShape[N-1]]
26544 *
26545 * See step 3: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
26546 */
26547 function getReshapedPermuted(inputShape, blockShape, prod, batchToSpace = true) {
26548 const reshapedPermuted = [];
26549 if (batchToSpace) {
26550 reshapedPermuted.push(inputShape[0] / prod);
26551 }
26552 else {
26553 reshapedPermuted.push(inputShape[0] * prod);
26554 }
26555 for (let i = 1; i < inputShape.length; ++i) {
26556 if (i <= blockShape.length) {
26557 if (batchToSpace) {
26558 reshapedPermuted.push(blockShape[i - 1] * inputShape[i]);
26559 }
26560 else {
26561 reshapedPermuted.push(inputShape[i] / blockShape[i - 1]);
26562 }
26563 }
26564 else {
26565 reshapedPermuted.push(inputShape[i]);
26566 }
26567 }
26568 return reshapedPermuted;
26569 }
26570 /**
26571 * Converts the crops argument into the beginning coordinates of a slice
26572 * operation.
26573 */
26574 function getSliceBeginCoords(crops, blockShape) {
26575 const sliceBeginCoords = [0];
26576 for (let i = 0; i < blockShape; ++i) {
26577 sliceBeginCoords.push(crops[i][0]);
26578 }
26579 return sliceBeginCoords;
26580 }
26581 /**
26582 * Converts the crops argument into the size of a slice operation. When
26583 * combined with getSliceBeginCoords this function allows the reshaped and
26584 * permuted Tensor to be cropped to its final output shape of:
26585 *
26586 * inputShape[1] * blockShape[0] - crops[0,0] - crops[0,1], ...,
26587 * inputShape[M] * blockShape[M-1] -crops[M-1,0] -
26588 * crops[M-1,1],inputShape[M+1], ..., inputShape[N-1]]
26589 *
26590 * See step 4: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
26591 */
26592 function getSliceSize(uncroppedShape, crops, blockShape) {
26593 const sliceSize = uncroppedShape.slice(0, 1);
26594 for (let i = 0; i < blockShape; ++i) {
26595 sliceSize.push(uncroppedShape[i + 1] - crops[i][0] - crops[i][1]);
26596 }
26597 return sliceSize;
26598 }
26599
26600 /**
26601 * @license
26602 * Copyright 2018 Google LLC. All Rights Reserved.
26603 * Licensed under the Apache License, Version 2.0 (the "License");
26604 * you may not use this file except in compliance with the License.
26605 * You may obtain a copy of the License at
26606 *
26607 * http://www.apache.org/licenses/LICENSE-2.0
26608 *
26609 * Unless required by applicable law or agreed to in writing, software
26610 * distributed under the License is distributed on an "AS IS" BASIS,
26611 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26612 * See the License for the specific language governing permissions and
26613 * limitations under the License.
26614 * =============================================================================
26615 */
    // Constants for the SELU activation (Klambauer et al. 2017,
    // "Self-Normalizing Neural Networks"): SELU_SCALE is the scale factor
    // lambda, and SELU_SCALEALPHA is the precomputed product lambda * alpha.
    const SELU_SCALEALPHA = 1.7580993408473768599402175208123;
    const SELU_SCALE = 1.0507009873554804934193349852946;
26618
26619 /**
26620 * @license
26621 * Copyright 2018 Google LLC. All Rights Reserved.
26622 * Licensed under the Apache License, Version 2.0 (the "License");
26623 * you may not use this file except in compliance with the License.
26624 * You may obtain a copy of the License at
26625 *
26626 * http://www.apache.org/licenses/LICENSE-2.0
26627 *
26628 * Unless required by applicable law or agreed to in writing, software
26629 * distributed under the License is distributed on an "AS IS" BASIS,
26630 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26631 * See the License for the specific language governing permissions and
26632 * limitations under the License.
26633 * =============================================================================
26634 */
    // Coefficients of the Abramowitz & Stegun rational approximation 7.1.26
    // for the error function erf(x): t = 1/(1 + p*x), and the polynomial
    // a1*t + a2*t^2 + ... + a5*t^5 multiplies exp(-x^2).
    const ERF_P = 0.3275911;
    const ERF_A1 = 0.254829592;
    const ERF_A2 = -0.284496736;
    const ERF_A3 = 1.421413741;
    const ERF_A4 = -1.453152027;
    const ERF_A5 = 1.061405429;
26641
26642 /**
26643 * @license
26644 * Copyright 2018 Google LLC. All Rights Reserved.
26645 * Licensed under the Apache License, Version 2.0 (the "License");
26646 * you may not use this file except in compliance with the License.
26647 * You may obtain a copy of the License at
26648 *
26649 * http://www.apache.org/licenses/LICENSE-2.0
26650 *
26651 * Unless required by applicable law or agreed to in writing, software
26652 * distributed under the License is distributed on an "AS IS" BASIS,
26653 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26654 * See the License for the specific language governing permissions and
26655 * limitations under the License.
26656 * =============================================================================
26657 */
26658 /**
26659 * Merges real and imaginary Float32Arrays into a single complex Float32Array.
26660 *
26661 * The memory layout is interleaved as follows:
26662 * real: [r0, r1, r2]
26663 * imag: [i0, i1, i2]
26664 * complex: [r0, i0, r1, i1, r2, i2]
26665 *
26666 * This is the inverse of splitRealAndImagArrays.
26667 *
26668 * @param real The real values of the complex tensor values.
26669 * @param imag The imag values of the complex tensor values.
26670 * @returns A complex tensor as a Float32Array with merged values.
26671 */
26672 function mergeRealAndImagArrays(real, imag) {
26673 if (real.length !== imag.length) {
26674 throw new Error(`Cannot merge real and imag arrays of different lengths. real:` +
26675 `${real.length}, imag: ${imag.length}.`);
26676 }
26677 const result = new Float32Array(real.length * 2);
26678 for (let i = 0; i < result.length; i += 2) {
26679 result[i] = real[i / 2];
26680 result[i + 1] = imag[i / 2];
26681 }
26682 return result;
26683 }
26684 /**
26685 * Splits a complex Float32Array into real and imag parts.
26686 *
26687 * The memory layout is interleaved as follows:
26688 * complex: [r0, i0, r1, i1, r2, i2]
26689 * real: [r0, r1, r2]
26690 * imag: [i0, i1, i2]
26691 *
26692 * This is the inverse of mergeRealAndImagArrays.
26693 *
26694 * @param complex The complex tensor values.
26695 * @returns An object with real and imag Float32Array components of the complex
26696 * tensor.
26697 */
26698 function splitRealAndImagArrays(complex) {
26699 const real = new Float32Array(complex.length / 2);
26700 const imag = new Float32Array(complex.length / 2);
26701 for (let i = 0; i < complex.length; i += 2) {
26702 real[i / 2] = complex[i];
26703 imag[i / 2] = complex[i + 1];
26704 }
26705 return { real, imag };
26706 }
26707 /**
26708 * Extracts even indexed complex values in the given array.
26709 * @param complex The complex tensor values
26710 */
26711 function complexWithEvenIndex(complex) {
26712 const len = Math.ceil(complex.length / 4);
26713 const real = new Float32Array(len);
26714 const imag = new Float32Array(len);
26715 for (let i = 0; i < complex.length; i += 4) {
26716 real[Math.floor(i / 4)] = complex[i];
26717 imag[Math.floor(i / 4)] = complex[i + 1];
26718 }
26719 return { real, imag };
26720 }
26721 /**
26722 * Extracts odd indexed comple values in the given array.
26723 * @param complex The complex tensor values
26724 */
26725 function complexWithOddIndex(complex) {
26726 const len = Math.floor(complex.length / 4);
26727 const real = new Float32Array(len);
26728 const imag = new Float32Array(len);
26729 for (let i = 2; i < complex.length; i += 4) {
26730 real[Math.floor(i / 4)] = complex[i];
26731 imag[Math.floor(i / 4)] = complex[i + 1];
26732 }
26733 return { real, imag };
26734 }
26735 /**
26736 * Get the map representing a complex value in the given array.
26737 * @param complex The complex tensor values.
26738 * @param index An index of the target complex value.
26739 */
26740 function getComplexWithIndex(complex, index) {
26741 const real = complex[index * 2];
26742 const imag = complex[index * 2 + 1];
26743 return { real, imag };
26744 }
26745 /**
26746 * Insert a given complex value into the TypedArray.
26747 * @param data The array in which the complex value is inserted.
26748 * @param c The complex value to be inserted.
26749 * @param index An index of the target complex value.
26750 */
26751 function assignToTypedArray(data, real, imag, index) {
26752 data[index * 2] = real;
26753 data[index * 2 + 1] = imag;
26754 }
26755 /**
26756 * Make the list of exponent terms used by FFT.
26757 */
26758 function exponents(n, inverse) {
26759 const real = new Float32Array(n / 2);
26760 const imag = new Float32Array(n / 2);
26761 for (let i = 0; i < Math.ceil(n / 2); i++) {
26762 const x = (inverse ? 2 : -2) * Math.PI * (i / n);
26763 real[i] = Math.cos(x);
26764 imag[i] = Math.sin(x);
26765 }
26766 return { real, imag };
26767 }
26768 /**
26769 * Make the exponent term used by FFT.
26770 */
26771 function exponent(k, n, inverse) {
26772 const x = (inverse ? 2 : -2) * Math.PI * (k / n);
26773 const real = Math.cos(x);
26774 const imag = Math.sin(x);
26775 return { real, imag };
26776 }
26777
26778 /**
26779 * @license
26780 * Copyright 2021 Google LLC. All Rights Reserved.
26781 * Licensed under the Apache License, Version 2.0 (the "License");
26782 * you may not use this file except in compliance with the License.
26783 * You may obtain a copy of the License at
26784 *
26785 * http://www.apache.org/licenses/LICENSE-2.0
26786 *
26787 * Unless required by applicable law or agreed to in writing, software
26788 * distributed under the License is distributed on an "AS IS" BASIS,
26789 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26790 * See the License for the specific language governing permissions and
26791 * limitations under the License.
26792 * =============================================================================
26793 */
    // Token separating input subscripts from the output subscript in an einsum
    // equation, e.g. "ij,jk->ik".
    const ARROW = '->';
    // Global regex used to count/strip arrow occurrences in an equation.
    const ARROW_REGEX = /->/g;
    // Separator between the subscripts of different input tensors.
    const COMMA = ',';
    // Ellipsis notation; currently rejected by decodeEinsumEquation.
    const ELLIPSIS = '...';
26798 /**
26799 * Parse an equation for einsum.
26800 *
26801 * @param equation The einsum equation (e.g., "ij,jk->ik").
26802 * @param numTensors Number of tensors provided along with `equation`. Used to
26803 * check matching number of input tensors.
26804 * @returns An object consisting of the following fields:
26805 * - allDims: all dimension names as strings.
26806 * - summedDims: a list of all dimensions being summed over, as indices to
26807 * the elements of `allDims`.
26808 * - idDims: indices of the dimensions in each input tensor, as indices to
26809 * the elements of `allDims.
26810 */
26811 function decodeEinsumEquation(equation, numTensors) {
26812 equation = equation.replace(/\s/g, ''); // Remove witespace in equation.
26813 const numArrows = (equation.length - equation.replace(ARROW_REGEX, '').length) /
26814 ARROW.length;
26815 if (numArrows < 1) {
26816 throw new Error('Equations without an arrow are not supported.');
26817 }
26818 else if (numArrows > 1) {
26819 throw new Error(`Equation must contain exactly one arrow ("${ARROW}").`);
26820 }
26821 const [inputString, outputString] = equation.split(ARROW);
26822 assert(inputString.indexOf(ELLIPSIS) === -1, () => `The ellipsis notation ("${ELLIPSIS}") is not supported yet.`);
26823 const inputTerms = inputString.split(COMMA);
26824 const numInputs = inputTerms.length;
26825 if (numTensors !== numInputs) {
26826 throw new Error(`Expected ${numInputs} input tensors, received ${numTensors}`);
26827 }
26828 if (numInputs > 2) {
26829 throw new Error('Support for more than 2 input tensors is not implemented yet.');
26830 }
26831 const allDims = [];
26832 for (let i = 0; i < outputString.length; ++i) {
26833 const dimName = outputString[i];
26834 if (!inputTerms.some(inputTerm => inputTerm.indexOf(dimName) !== -1)) {
26835 throw new Error(`Output subscripts contain the label ${dimName} ` +
26836 `not present in the input subscripts.`);
26837 }
26838 if (allDims.indexOf(dimName) === -1) {
26839 allDims.push(dimName);
26840 }
26841 }
26842 for (let i = 0; i < inputString.length; ++i) {
26843 const dimName = inputString[i];
26844 if (allDims.indexOf(dimName) === -1 && dimName !== COMMA) {
26845 allDims.push(dimName);
26846 }
26847 }
26848 const idDims = new Array(inputTerms.length);
26849 for (let i = 0; i < numInputs; ++i) {
26850 if (new Set(inputTerms[i].split('')).size !== inputTerms[i].length) {
26851 throw new Error(`Found duplicate axes in input component ${inputTerms[i]}. ` +
26852 `Support for duplicate axes in input is not implemented yet.`);
26853 }
26854 idDims[i] = [];
26855 for (let j = 0; j < inputTerms[i].length; ++j) {
26856 idDims[i].push(allDims.indexOf(inputTerms[i][j]));
26857 }
26858 }
26859 const numDims = allDims.length; // Number of unique dimensions.
26860 const numOutDims = outputString.length; // Number of output dimensions.
26861 const summedDims = []; // Dimensions being summed over.
26862 for (let i = numOutDims; i < numDims; ++i) {
26863 summedDims.push(i);
26864 }
26865 return { allDims, summedDims, idDims };
26866 }
26867 /**
26868 * Get the permutation for a given input tensor.
26869 *
26870 * @param nDims Total number of dimension of all tensors involved in the einsum
26871 * operation.
26872 * @param idDims Dimension indices involve in the tensor in question.
26873 * @returns An object consisting of the following fields:
26874 * - permutationIndices: Indices to permute the axes of the tensor with.
26875 * - expandDims: Indices to the dimension that need to be expanded from the
26876 * tensor after permutation.
26877 */
26878 function getEinsumPermutation(nDims, idDims) {
26879 let permutationIndices = new Array(nDims);
26880 permutationIndices.fill(-1);
26881 for (let i = 0; i < idDims.length; ++i) {
26882 permutationIndices[idDims[i]] = i;
26883 }
26884 const expandDims = [];
26885 for (let i = 0; i < nDims; ++i) {
26886 if (permutationIndices[i] === -1) {
26887 expandDims.push(i);
26888 }
26889 }
26890 permutationIndices = permutationIndices.filter(d => d !== -1);
26891 return { permutationIndices, expandDims };
26892 }
26893 /**
26894 * Checks that the dimension sizes from different input tensors match the
26895 * equation.
26896 */
26897 function checkEinsumDimSizes(nDims, idDims, tensors) {
26898 const dimSizes = new Array(nDims);
26899 for (let i = 0; i < tensors.length; ++i) {
26900 const shape = tensors[i].shape;
26901 for (let j = 0; j < idDims[i].length; ++j) {
26902 if (dimSizes[idDims[i][j]] === undefined) {
26903 dimSizes[idDims[i][j]] = shape[j];
26904 }
26905 else {
26906 assert(dimSizes[idDims[i][j]] === shape[j], () => `Expected dimension ${dimSizes[idDims[i][j]]} at axis ${j} ` +
26907 `of input shaped ${JSON.stringify(shape)}, ` +
26908 `but got dimension ${shape[j]}`);
26909 }
26910 }
26911 }
26912 }
26913 /**
26914 * Gets path of computation for einsum.
26915 *
26916 * @param summedDims indices to the dimensions being summed over.
26917 * @param idDims A look up table for the dimensions present in each input
26918 * tensor. Each consituent array contains indices for the dimensions in the
26919 * corresponding input tensor.
26920 *
26921 * @return A map with two fields:
26922 * - path: The path of computation, with each element indicating the dimension
26923 * being summed over after the element-wise multiplication in that step.
26924 * - steps: With the same length as `path`. Each element contains the indices
26925 * to the input tensors being used for element-wise multiplication in the
26926 * corresponding step.
26927 */
26928 function getEinsumComputePath(summedDims, idDims) {
26929 const path = summedDims;
26930 const steps = [];
26931 let nSteps = 0;
26932 if (summedDims.length === 0) {
26933 // Einsum that involes no summing: e.g., transpose and outer product.
26934 path.push(-1);
26935 }
26936 nSteps = summedDims.length + 1;
26937 for (let i = 0; i < nSteps; ++i) {
26938 steps.push([]);
26939 }
26940 const computedTermIndices = [];
26941 for (let i = 0; i < path.length; ++i) {
26942 const summedDim = path[i];
26943 const termIndices = findTermsWithDim(idDims, summedDim);
26944 for (const termIndex of termIndices) {
26945 if (computedTermIndices.indexOf(termIndex) === -1) {
26946 steps[i].push(termIndex);
26947 computedTermIndices.push(termIndex);
26948 }
26949 }
26950 }
26951 return { path, steps };
26952 }
26953 /** Determines if an axes permutation is the identity permutation. */
26954 function isIdentityPermutation(perm) {
26955 return perm.every((dim, index) => dim === index);
26956 }
26957 function findTermsWithDim(idDims, dim) {
26958 const termIndices = [];
26959 for (let i = 0; i < idDims.length; ++i) {
26960 if (idDims[i].length === 0 || idDims[i].indexOf(dim) !== -1 || dim === -1) {
26961 termIndices.push(i);
26962 }
26963 }
26964 return termIndices;
26965 }
26966
26967 /**
26968 * Prepare the split size array. When the input is a number, the axis is evenly
26969 * divided among the split size. When the input contains the negative value, the
26970 * rest of the axis is allocated toward that.
26971 */
26972 function prepareSplitSize(x, numOrSizeSplits, axis = 0) {
26973 let splitSizes = [];
26974 if (typeof (numOrSizeSplits) === 'number') {
26975 assert(x.shape[axis] % numOrSizeSplits === 0, () => 'Number of splits must evenly divide the axis.');
26976 splitSizes =
26977 new Array(numOrSizeSplits).fill(x.shape[axis] / numOrSizeSplits);
26978 }
26979 else {
26980 const numOfNegs = numOrSizeSplits.reduce((count, value) => {
26981 if (value === -1) {
26982 count += 1;
26983 }
26984 return count;
26985 }, 0);
26986 assert(numOfNegs <= 1, () => 'There should be only one negative value in split array.');
26987 const negIndex = numOrSizeSplits.indexOf(-1);
26988 // Allow the number of split array to be -1, which indicates the rest
26989 // of dimension is allocated to that split.
26990 if (negIndex !== -1) {
26991 const total = numOrSizeSplits.reduce((a, b) => b > 0 ? a + b : a);
26992 numOrSizeSplits[negIndex] = x.shape[axis] - total;
26993 }
26994 assert(x.shape[axis] === numOrSizeSplits.reduce((a, b) => a + b), () => 'The sum of sizes must match the size of the axis dimension.');
26995 splitSizes = numOrSizeSplits;
26996 }
26997 return splitSizes;
26998 }
26999
27000 /**
27001 * @license
27002 * Copyright 2021 Google LLC. All Rights Reserved.
27003 * Licensed under the Apache License, Version 2.0 (the "License");
27004 * you may not use this file except in compliance with the License.
27005 * You may obtain a copy of the License at
27006 *
27007 * http://www.apache.org/licenses/LICENSE-2.0
27008 *
27009 * Unless required by applicable law or agreed to in writing, software
27010 * distributed under the License is distributed on an "AS IS" BASIS,
27011 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27012 * See the License for the specific language governing permissions and
27013 * limitations under the License.
27014 * =============================================================================
27015 */
27016 /**
27017 * Generates sparse fill empty rows indices, dense shape mismatch error message.
27018 *
27019 * @param indicesLength The first dimension of indices.
27020 */
27021 function getSparseFillEmptyRowsIndicesDenseShapeMismatch(indicesLength) {
27022 return `Received SparseTensor with denseShape[0] = 0 but
27023 indices.shape[0] = ${indicesLength}`;
27024 }
27025 /**
27026 * Generates sparse fill empty rows negative index error message.
27027 *
27028 * @param index The index with a negative value.
27029 * @param value The negative value.
27030 */
27031 function getSparseFillEmptyRowsNegativeIndexErrorMessage(index, value) {
27032 return `indices(${index}, 0) is invalid: ${value} < 0`;
27033 }
27034 /**
27035 * Generates sparse fill empty rows out of range index error message.
27036 *
27037 * @param index The index with an out of range value.
27038 * @param value The out of range value.
27039 * @param limit The upper limit for indices.
27040 */
27041 function getSparseFillEmptyRowsOutOfRangeIndexErrorMessage(index, value, limit) {
27042 return `indices(${index}, 0) is invalid: ${value} >= ${limit}`;
27043 }
27044
27045 /**
27046 * @license
27047 * Copyright 2021 Google LLC. All Rights Reserved.
27048 * Licensed under the Apache License, Version 2.0 (the "License");
27049 * you may not use this file except in compliance with the License.
27050 * You may obtain a copy of the License at
27051 *
27052 * http://www.apache.org/licenses/LICENSE-2.0
27053 *
27054 * Unless required by applicable law or agreed to in writing, software
27055 * distributed under the License is distributed on an "AS IS" BASIS,
27056 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27057 * See the License for the specific language governing permissions and
27058 * limitations under the License.
27059 * =============================================================================
27060 */
27061 /**
27062 * Generates sparse reshape multiple negative 1 output dimension error message.
27063 *
27064 * @param dim1 The first dimension with a negative 1 value.
27065 * @param dim2 The second dimension with a negative 1 value.
27066 */
function getSparseReshapeMultipleNegativeOneOutputDimErrorMessage(dim1, dim2) {
    // Name both offending output dimensions in the message.
    const offenders = `${dim1} and ${dim2}`;
    return `only one output dimension may be -1, not both ${offenders}`;
}
27070 /**
27071 * Generates sparse reshape negative output dimension error message.
27072 *
27073 * @param dim The dimension with a negative value.
27074 * @param value The negative value.
27075 */
function getSparseReshapeNegativeOutputDimErrorMessage(dim, value) {
    // Report the dimension index first, then the invalid (negative) size.
    const detail = `not ${value}`;
    return `size ${dim} must be non-negative, ${detail}`;
}
27079 /**
27080 * Generates sparse reshape empty tensor zero output dimension error message.
27081 *
27082 */
function getSparseReshapeEmptyTensorZeroOutputDimErrorMessage() {
    // Fixed message: with an empty tensor and a zero-sized known dim, the
    // inferred (-1) dim is ambiguous.
    return 'reshape cannot infer the missing input size for an empty tensor unless all specified input sizes are non-zero';
}
27087 /**
27088 * Generates sparse reshape input output multiple mismatch error message.
27089 *
27090 * @param inputShape the input shape.
27091 * @param outputShape the requested output shape.
27092 */
function getSparseReshapeInputOutputMultipleErrorMessage(inputShape, outputShape) {
    // Total element counts of the dense forms of the input and output shapes.
    const inputSize = sizeFromShape(inputShape);
    const outputSize = sizeFromShape(outputShape);
    // NOTE(review): the template literal spans two source lines, so the
    // message contains an embedded newline; also note the pre-existing stray
    // space in "outputShape= " — kept as-is since callers/tests may match the
    // exact text.
    return `Input to reshape is a SparseTensor with ${inputSize}
  dense values, but the requested shape requires a multiple of ${outputSize}. inputShape=${inputShape} outputShape= ${outputShape}`;
}
27099 /**
27100 * Generates sparse reshape input output inequality error message.
27101 *
27102 * @param inputShape the input shape.
27103 * @param outputShape the requested output shape.
27104 */
function getSparseReshapeInputOutputMismatchErrorMessage(inputShape, outputShape) {
    // Element counts of the dense forms of both shapes; they must match for a
    // reshape with no inferred (-1) dimension.
    const denseValueCount = sizeFromShape(inputShape);
    const requestedCount = sizeFromShape(outputShape);
    return `Input to reshape is a tensor with ${denseValueCount} dense values, but the requested shape has ${requestedCount}. inputShape=${inputShape} outputShape=${outputShape}`;
}
27110
27111 /**
27112 * @license
27113 * Copyright 2021 Google LLC. All Rights Reserved.
27114 * Licensed under the Apache License, Version 2.0 (the "License");
27115 * you may not use this file except in compliance with the License.
27116 * You may obtain a copy of the License at
27117 *
27118 * http://www.apache.org/licenses/LICENSE-2.0
27119 *
27120 * Unless required by applicable law or agreed to in writing, software
27121 * distributed under the License is distributed on an "AS IS" BASIS,
27122 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27123 * See the License for the specific language governing permissions and
27124 * limitations under the License.
27125 * =============================================================================
27126 */
27127 /**
27128 * Generates sparse segment reduction negative segment ids error message.
27129 *
27130 */
function getSparseSegmentReductionNegativeSegmentIdsErrorMessage() {
    // Fixed message; no interpolation needed.
    return 'segment ids must be >= 0';
}
27134 /**
27135 * Generates sparse segment reduction non increasing segment ids error message.
27136 *
27137 */
function getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage() {
    // Fixed message; no interpolation needed.
    return 'segment ids are not increasing';
}
27141 /**
27142 * Generates sparse segment reduction segment id out of range error message.
27143 *
27144 * @param segmentId The segment id index that is out of range.
27145 * @param outputRows Upper bound of valid segment id values.
27146 */
function getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage(segmentId, outputRows) {
    // Valid segment ids lie in the half-open interval [0, outputRows).
    const validRange = `[0, ${outputRows})`;
    return `Segment id ${segmentId} out of range ${validRange}, possibly because segmentIds input is not sorted.`;
}
27150 /**
27151 * Generates sparse segment reduction input indice out of range error message.
27152 *
27153 * @param index The index that holds the out of range value.
27154 * @param indexValue The value that is out of range.
27155 * @param inputRows Upper bound of valid index values.
27156 */
function getSparseSegmentReductionIndicesOutOfRangeErrorMessage(index, indexValue, inputRows) {
    // Valid index values lie in the half-open interval [0, inputRows).
    const validRange = `[0, ${inputRows})`;
    return `Bad: indices[${index}] == ${indexValue} out of range ${validRange}`;
}
27160
27161 /**
27162 * @license
27163 * Copyright 2018 Google LLC. All Rights Reserved.
27164 * Licensed under the Apache License, Version 2.0 (the "License");
27165 * you may not use this file except in compliance with the License.
27166 * You may obtain a copy of the License at
27167 *
27168 * http://www.apache.org/licenses/LICENSE-2.0
27169 *
27170 * Unless required by applicable law or agreed to in writing, software
27171 * distributed under the License is distributed on an "AS IS" BASIS,
27172 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27173 * See the License for the specific language governing permissions and
27174 * limitations under the License.
27175 * =============================================================================
27176 */
/**
 * Picks a window size for segment reductions: a divisor of `inSize` that is
 * either greater than `numSegments` or equal to `inSize` itself.
 */
function segOpComputeOptimalWindowSize(inSize, numSegments) {
    // Small inputs are processed in a single window; no parallelization.
    if (inSize <= PARALLELIZE_THRESHOLD) {
        return inSize;
    }
    // Start from the divisor of inSize nearest to sqrt(inSize), then grow it
    // until the window spans more than one segment's worth of elements or
    // covers the entire input.
    let windowSize = nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));
    while (windowSize <= numSegments && windowSize !== inSize) {
        windowSize = nearestDivisor(inSize, windowSize + 1);
    }
    return windowSize;
}
/**
 * Returns `aShape` with the dimension at `axis` replaced by `numSegments`;
 * all other dimensions are kept as-is.
 */
function computeOutShape$2(aShape, axis, numSegments) {
    return aShape.map((dimSize, dim) => dim === axis ? numSegments : dimSize);
}
// Validates gather() parameters and derives the sizes and output shape used
// by gather kernels. `x` and `indices` only need `.shape` here.
function collectGatherOpShapeInfo(x, indices, axis, batchDims) {
    const indicesRank = indices.shape.length;
    const xRank = x.shape.length;
    // batchDims may be negative, in which case it counts back from
    // indices' rank (normalized below).
    if (batchDims !== 0) {
        if (batchDims < -indicesRank || batchDims > indicesRank) {
            throw new Error(`Expect batchDims in the range of [-${indicesRank}, ${indicesRank}], but got ${batchDims}`);
        }
    }
    if (batchDims < 0) {
        batchDims += indicesRank;
    }
    // NOTE(review): the check uses `>` (so batchDims === xRank passes) even
    // though the message says "less than"; kept as-is to preserve behavior.
    // The error string intentionally contains an embedded newline.
    if (batchDims > xRank) {
        throw new Error(`batchDims (${batchDims}) must be less than rank(x) (
    ${xRank}).`);
    }
    if (axis < batchDims) {
        throw new Error(`batchDims (${batchDims}) must be less than or equal to axis (${axis}).`);
    }
    // The leading batch dimensions of x and indices must match exactly.
    for (let i = 0; i < batchDims; ++i) {
        if (x.shape[i] !== indices.shape[i]) {
            throw new Error(`x.shape[${i}]: ${x.shape[i]} should be equal to indices.shape[${i}]: ${indices.shape[i]}.`);
        }
    }
    const dimSize = x.shape[axis];
    const outputShape = [];
    let batchSize = 1; // product of the shared batch dims
    let outerSize = 1; // product of x's dims between batchDims and axis
    let sliceSize = 1; // product of x's dims after axis
    for (let i = 0; i < batchDims; ++i) {
        outputShape.push(x.shape[i]);
        batchSize *= x.shape[i];
    }
    for (let i = batchDims; i < axis; i++) {
        outputShape.push(x.shape[i]);
        outerSize *= x.shape[i];
    }
    // indices' non-batch dims replace x's `axis` dim in the output shape.
    for (let i = batchDims; i < indicesRank; i++) {
        outputShape.push(indices.shape[i]);
    }
    for (let i = axis + 1; i < xRank; i++) {
        outputShape.push(x.shape[i]);
        sliceSize *= x.shape[i];
    }
    return { batchSize, sliceSize, outerSize, dimSize, outputShape };
}
27255
// Frozen namespace object re-exporting the segment/gather shape helpers
// above; exposed to consumers as `backend_util.segment_util`.
var segment_util = /*#__PURE__*/Object.freeze({
    __proto__: null,
    segOpComputeOptimalWindowSize: segOpComputeOptimalWindowSize,
    computeOutShape: computeOutShape$2,
    collectGatherOpShapeInfo: collectGatherOpShapeInfo
});
27262
27263 /**
27264 * @license
27265 * Copyright 2018 Google LLC. All Rights Reserved.
27266 * Licensed under the Apache License, Version 2.0 (the "License");
27267 * you may not use this file except in compliance with the License.
27268 * You may obtain a copy of the License at
27269 *
27270 * http://www.apache.org/licenses/LICENSE-2.0
27271 *
27272 * Unless required by applicable law or agreed to in writing, software
27273 * distributed under the License is distributed on an "AS IS" BASIS,
27274 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27275 * See the License for the specific language governing permissions and
27276 * limitations under the License.
27277 * =============================================================================
27278 */
/**
 * Decodes each element of `vals` (encoded string bytes) into a utf-8 string.
 * Wraps any decode failure in a single descriptive Error.
 */
function fromUint8ToStringArray(vals) {
    try {
        const decoded = [];
        // Deliberately call decodeString with a single argument per element
        // (it may accept an optional second parameter).
        for (const bytes of vals) {
            decoded.push(decodeString(bytes));
        }
        return decoded;
    }
    catch (err) {
        throw new Error(`Failed to decode encoded string bytes into utf-8, error: ${err}`);
    }
}
/** Encodes each string in `strings` into its byte representation. */
function fromStringArrayToUint8(strings) {
    const encoded = [];
    // Call encodeString with a single argument per element (it may accept an
    // optional second parameter, so we avoid passing map's index through).
    for (const s of strings) {
        encoded.push(encodeString(s));
    }
    return encoded;
}
27291
// Frozen namespace object aggregating backend helper utilities (shape math,
// conv/pool info, broadcasting, einsum, sparse error messages, etc.);
// exposed to consumers as `tf.backend_util`.
var backend_util = /*#__PURE__*/Object.freeze({
    __proto__: null,
    slice_util: slice_util,
    segment_util: segment_util,
    fromUint8ToStringArray: fromUint8ToStringArray,
    fromStringArrayToUint8: fromStringArrayToUint8,
    upcastType: upcastType,
    axesAreInnerMostDims: axesAreInnerMostDims,
    combineLocations: combineLocations,
    computeOutAndReduceShapes: computeOutAndReduceShapes,
    expandShapeToKeepDim: expandShapeToKeepDim,
    assertAxesAreInnerMostDims: assertAxesAreInnerMostDims,
    getAxesPermutation: getAxesPermutation,
    getUndoAxesPermutation: getUndoAxesPermutation,
    getInnerMostAxes: getInnerMostAxes,
    getBroadcastDims: getBroadcastDims,
    getReductionAxes: getReductionAxes,
    assertAndGetBroadcastShape: assertAndGetBroadcastShape,
    assertParamsConsistent: assertParamsConsistent,
    computeOutShape: computeOutShape$1,
    computeDilation2DInfo: computeDilation2DInfo,
    computePool2DInfo: computePool2DInfo,
    computePool3DInfo: computePool3DInfo,
    computeConv2DInfo: computeConv2DInfo,
    computeConv3DInfo: computeConv3DInfo,
    computeDefaultPad: computeDefaultPad,
    tupleValuesAreOne: tupleValuesAreOne,
    eitherStridesOrDilationsAreOne: eitherStridesOrDilationsAreOne,
    convertConv2DDataFormat: convertConv2DDataFormat,
    checkPadOnDimRoundingMode: checkPadOnDimRoundingMode,
    getFusedDyActivation: getFusedDyActivation,
    getFusedBiasGradient: getFusedBiasGradient,
    applyActivation: applyActivation,
    shouldFuse: shouldFuse,
    PARALLELIZE_THRESHOLD: PARALLELIZE_THRESHOLD,
    computeOptimalWindowSize: computeOptimalWindowSize,
    getImageCenter: getImageCenter,
    getReshaped: getReshaped,
    getPermuted: getPermuted,
    getReshapedPermuted: getReshapedPermuted,
    getSliceBeginCoords: getSliceBeginCoords,
    getSliceSize: getSliceSize,
    prepareAndValidate: prepareAndValidate,
    validateUpdateShape: validateUpdateShape,
    validateInput: validateInput,
    calculateShapes: calculateShapes,
    SELU_SCALEALPHA: SELU_SCALEALPHA,
    SELU_SCALE: SELU_SCALE,
    ERF_P: ERF_P,
    ERF_A1: ERF_A1,
    ERF_A2: ERF_A2,
    ERF_A3: ERF_A3,
    ERF_A4: ERF_A4,
    ERF_A5: ERF_A5,
    warn: warn,
    log: log,
    mergeRealAndImagArrays: mergeRealAndImagArrays,
    splitRealAndImagArrays: splitRealAndImagArrays,
    complexWithEvenIndex: complexWithEvenIndex,
    complexWithOddIndex: complexWithOddIndex,
    getComplexWithIndex: getComplexWithIndex,
    assignToTypedArray: assignToTypedArray,
    exponents: exponents,
    exponent: exponent,
    decodeEinsumEquation: decodeEinsumEquation,
    getEinsumPermutation: getEinsumPermutation,
    checkEinsumDimSizes: checkEinsumDimSizes,
    getEinsumComputePath: getEinsumComputePath,
    isIdentityPermutation: isIdentityPermutation,
    prepareSplitSize: prepareSplitSize,
    getSparseFillEmptyRowsIndicesDenseShapeMismatch: getSparseFillEmptyRowsIndicesDenseShapeMismatch,
    getSparseFillEmptyRowsNegativeIndexErrorMessage: getSparseFillEmptyRowsNegativeIndexErrorMessage,
    getSparseFillEmptyRowsOutOfRangeIndexErrorMessage: getSparseFillEmptyRowsOutOfRangeIndexErrorMessage,
    getSparseReshapeMultipleNegativeOneOutputDimErrorMessage: getSparseReshapeMultipleNegativeOneOutputDimErrorMessage,
    getSparseReshapeNegativeOutputDimErrorMessage: getSparseReshapeNegativeOutputDimErrorMessage,
    getSparseReshapeEmptyTensorZeroOutputDimErrorMessage: getSparseReshapeEmptyTensorZeroOutputDimErrorMessage,
    getSparseReshapeInputOutputMultipleErrorMessage: getSparseReshapeInputOutputMultipleErrorMessage,
    getSparseReshapeInputOutputMismatchErrorMessage: getSparseReshapeInputOutputMismatchErrorMessage,
    getSparseSegmentReductionNegativeSegmentIdsErrorMessage: getSparseSegmentReductionNegativeSegmentIdsErrorMessage,
    getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage: getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage,
    getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage: getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage,
    getSparseSegmentReductionIndicesOutOfRangeErrorMessage: getSparseSegmentReductionIndicesOutOfRangeErrorMessage
});
27375
27376 /**
27377 * @license
27378 * Copyright 2020 Google LLC. All Rights Reserved.
27379 * Licensed under the Apache License, Version 2.0 (the "License");
27380 * you may not use this file except in compliance with the License.
27381 * You may obtain a copy of the License at
27382 *
27383 * http://www.apache.org/licenses/LICENSE-2.0
27384 *
27385 * Unless required by applicable law or agreed to in writing, software
27386 * distributed under the License is distributed on an "AS IS" BASIS,
27387 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27388 * See the License for the specific language governing permissions and
27389 * limitations under the License.
27390 * =============================================================================
27391 */
27392
// Frozen namespace object re-exporting shared CPU kernel implementations
// (non-max suppression variants and where); exposed as `tf.kernel_impls`.
var kernel_impls = /*#__PURE__*/Object.freeze({
    __proto__: null,
    nonMaxSuppressionV3Impl: nonMaxSuppressionV3Impl,
    nonMaxSuppressionV4Impl: nonMaxSuppressionV4Impl,
    nonMaxSuppressionV5Impl: nonMaxSuppressionV5Impl,
    whereImpl: whereImpl
});
27400
27401 /**
27402 * @license
27403 * Copyright 2020 Google Inc. All Rights Reserved.
27404 * Licensed under the Apache License, Version 2.0 (the "License");
27405 * you may not use this file except in compliance with the License.
27406 * You may obtain a copy of the License at
27407 *
27408 * http://www.apache.org/licenses/LICENSE-2.0
27409 *
27410 * Unless required by applicable law or agreed to in writing, software
27411 * distributed under the License is distributed on an "AS IS" BASIS,
27412 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27413 * See the License for the specific language governing permissions and
27414 * limitations under the License.
27415 * =============================================================================
27416 */
27417
27418 /**
27419 * @license
27420 * Copyright 2017 Google LLC. All Rights Reserved.
27421 * Licensed under the Apache License, Version 2.0 (the "License");
27422 * you may not use this file except in compliance with the License.
27423 * You may obtain a copy of the License at
27424 *
27425 * http://www.apache.org/licenses/LICENSE-2.0
27426 *
27427 * Unless required by applicable law or agreed to in writing, software
27428 * distributed under the License is distributed on an "AS IS" BASIS,
27429 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27430 * See the License for the specific language governing permissions and
27431 * limitations under the License.
27432 * =============================================================================
27433 */
27434
27435 /**
27436 * @license
27437 * Copyright 2020 Google LLC. All Rights Reserved.
27438 * Licensed under the Apache License, Version 2.0 (the "License");
27439 * you may not use this file except in compliance with the License.
27440 * You may obtain a copy of the License at
27441 *
27442 * http://www.apache.org/licenses/LICENSE-2.0
27443 *
27444 * Unless required by applicable law or agreed to in writing, software
27445 * distributed under the License is distributed on an "AS IS" BASIS,
27446 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27447 * See the License for the specific language governing permissions and
27448 * limitations under the License.
27449 * =============================================================================
27450 */
// Gradient for Abs: dy * sign(x), where step(x, -1) yields +1 for x > 0 and
// -1 otherwise (matching the forward op's saved input `x`).
const absGradConfig = {
    kernelName: Abs,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        return {
            x: () => {
                const sign = step(cast(x, 'float32'), -1);
                return mul(dy, sign);
            }
        };
    }
};
27459
27460 /**
27461 * @license
27462 * Copyright 2020 Google LLC. All Rights Reserved.
27463 * Licensed under the Apache License, Version 2.0 (the "License");
27464 * you may not use this file except in compliance with the License.
27465 * You may obtain a copy of the License at
27466 *
27467 * http://www.apache.org/licenses/LICENSE-2.0
27468 *
27469 * Unless required by applicable law or agreed to in writing, software
27470 * distributed under the License is distributed on an "AS IS" BASIS,
27471 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27472 * See the License for the specific language governing permissions and
27473 * limitations under the License.
27474 * =============================================================================
27475 */
// Gradient for Acos: d/dx acos(x) = -1 / sqrt(1 - x^2), applied to dy.
const acosGradConfig = {
    kernelName: Acos,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        return {
            x: () => neg(div(dy, sqrt(sub(scalar(1), square(cast(x, 'float32'))))))
        };
    }
};
27490
27491 /**
27492 * @license
27493 * Copyright 2020 Google LLC. All Rights Reserved.
27494 * Licensed under the Apache License, Version 2.0 (the "License");
27495 * you may not use this file except in compliance with the License.
27496 * You may obtain a copy of the License at
27497 *
27498 * http://www.apache.org/licenses/LICENSE-2.0
27499 *
27500 * Unless required by applicable law or agreed to in writing, software
27501 * distributed under the License is distributed on an "AS IS" BASIS,
27502 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27503 * See the License for the specific language governing permissions and
27504 * limitations under the License.
27505 * =============================================================================
27506 */
// Gradient for Acosh: d/dx acosh(x) = 1 / sqrt(x^2 - 1), applied to dy.
const acoshGradConfig = {
    kernelName: Acosh,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        return {
            x: () => div(dy, sqrt(sub(square(cast(x, 'float32')), 1)))
        };
    }
};
27520
27521 /**
27522 * @license
27523 * Copyright 2020 Google LLC. All Rights Reserved.
27524 * Licensed under the Apache License, Version 2.0 (the "License");
27525 * you may not use this file except in compliance with the License.
27526 * You may obtain a copy of the License at
27527 *
27528 * http://www.apache.org/licenses/LICENSE-2.0
27529 *
27530 * Unless required by applicable law or agreed to in writing, software
27531 * distributed under the License is distributed on an "AS IS" BASIS,
27532 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27533 * See the License for the specific language governing permissions and
27534 * limitations under the License.
27535 * =============================================================================
27536 */
// Gradient for Add: each operand receives dy summed over the axes that were
// broadcast in the forward pass, reshaped back to that operand's shape.
const addGradConfig = {
    kernelName: Add,
    inputsToSave: ['a', 'b'],
    gradFunc: (dy, saved) => {
        const [a, b] = saved;
        const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
        // Builds the derivative thunk for one operand shape.
        const makeDer = (shape) => () => {
            const reduceAxes = getReductionAxes(shape, outShape);
            const summed = reduceAxes.length > 0 ? sum$1(dy, reduceAxes) : dy;
            return reshape(summed, shape);
        };
        return { a: makeDer(a.shape), b: makeDer(b.shape) };
    }
};
27562
27563 /**
27564 * @license
27565 * Copyright 2020 Google LLC. All Rights Reserved.
27566 * Licensed under the Apache License, Version 2.0 (the "License");
27567 * you may not use this file except in compliance with the License.
27568 * You may obtain a copy of the License at
27569 *
27570 * http://www.apache.org/licenses/LICENSE-2.0
27571 *
27572 * Unless required by applicable law or agreed to in writing, software
27573 * distributed under the License is distributed on an "AS IS" BASIS,
27574 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27575 * See the License for the specific language governing permissions and
27576 * limitations under the License.
27577 * =============================================================================
27578 */
// Gradient for AddN: every input receives an independent clone of dy.
const addNGradConfig = {
    kernelName: AddN,
    saveAllInputs: true,
    gradFunc: (dy, saved) => {
        const ders = {};
        for (let i = 0; i < saved.length; i++) {
            ders[i] = () => dy.clone();
        }
        return ders;
    }
};
27590
27591 /**
27592 * @license
27593 * Copyright 2020 Google Inc. All Rights Reserved.
27594 * Licensed under the Apache License, Version 2.0 (the "License");
27595 * you may not use this file except in compliance with the License.
27596 * You may obtain a copy of the License at
27597 *
27598 * http://www.apache.org/licenses/LICENSE-2.0
27599 *
27600 * Unless required by applicable law or agreed to in writing, software
27601 * distributed under the License is distributed on an "AS IS" BASIS,
27602 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27603 * See the License for the specific language governing permissions and
27604 * limitations under the License.
27605 * =============================================================================
27606 */
// Gradient for ArgMax: the op is not differentiable w.r.t. x, so the
// gradient is all zeros with x's shape.
const argMaxGradConfig = {
    kernelName: ArgMax,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const x = saved[0];
        return { x: () => zerosLike(x) };
    }
};
27615
27616 /**
27617 * @license
27618 * Copyright 2020 Google Inc. All Rights Reserved.
27619 * Licensed under the Apache License, Version 2.0 (the "License");
27620 * you may not use this file except in compliance with the License.
27621 * You may obtain a copy of the License at
27622 *
27623 * http://www.apache.org/licenses/LICENSE-2.0
27624 *
27625 * Unless required by applicable law or agreed to in writing, software
27626 * distributed under the License is distributed on an "AS IS" BASIS,
27627 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27628 * See the License for the specific language governing permissions and
27629 * limitations under the License.
27630 * =============================================================================
27631 */
// Gradient for ArgMin: the op is not differentiable w.r.t. x, so the
// gradient is all zeros with x's shape.
const argMinGradConfig = {
    kernelName: ArgMin,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const x = saved[0];
        return { x: () => zerosLike(x) };
    }
};
27640
27641 /**
27642 * @license
27643 * Copyright 2020 Google LLC. All Rights Reserved.
27644 * Licensed under the Apache License, Version 2.0 (the "License");
27645 * you may not use this file except in compliance with the License.
27646 * You may obtain a copy of the License at
27647 *
27648 * http://www.apache.org/licenses/LICENSE-2.0
27649 *
27650 * Unless required by applicable law or agreed to in writing, software
27651 * distributed under the License is distributed on an "AS IS" BASIS,
27652 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27653 * See the License for the specific language governing permissions and
27654 * limitations under the License.
27655 * =============================================================================
27656 */
// Gradient for Asin: d/dx asin(x) = 1 / sqrt(1 - x^2), applied to dy.
const asinGradConfig = {
    kernelName: Asin,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        return {
            x: () => {
                const oneMinusXSquared = sub(scalar(1), square(cast(x, 'float32')));
                return div(dy, sqrt(oneMinusXSquared));
            }
        };
    }
};
27665
27666 /**
27667 * @license
27668 * Copyright 2020 Google LLC. All Rights Reserved.
27669 * Licensed under the Apache License, Version 2.0 (the "License");
27670 * you may not use this file except in compliance with the License.
27671 * You may obtain a copy of the License at
27672 *
27673 * http://www.apache.org/licenses/LICENSE-2.0
27674 *
27675 * Unless required by applicable law or agreed to in writing, software
27676 * distributed under the License is distributed on an "AS IS" BASIS,
27677 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27678 * See the License for the specific language governing permissions and
27679 * limitations under the License.
27680 * =============================================================================
27681 */
// Gradient for Asinh: d/dx asinh(x) = 1 / sqrt(1 + x^2), applied to dy.
const asinhGradConfig = {
    kernelName: Asinh,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        return {
            x: () => div(dy, sqrt(add$1(scalar(1), square(cast(x, 'float32')))))
        };
    }
};
27695
27696 /**
27697 * @license
27698 * Copyright 2020 Google LLC. All Rights Reserved.
27699 * Licensed under the Apache License, Version 2.0 (the "License");
27700 * you may not use this file except in compliance with the License.
27701 * You may obtain a copy of the License at
27702 *
27703 * http://www.apache.org/licenses/LICENSE-2.0
27704 *
27705 * Unless required by applicable law or agreed to in writing, software
27706 * distributed under the License is distributed on an "AS IS" BASIS,
27707 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27708 * See the License for the specific language governing permissions and
27709 * limitations under the License.
27710 * =============================================================================
27711 */
// Gradient for Atan2(a, b): dA = dy * b / (a^2 + b^2), dB = -dy * a / (a^2 +
// b^2); each result is summed over broadcast axes and reshaped back to the
// operand's shape.
const atan2GradConfig = {
    kernelName: Atan2,
    inputsToSave: ['a', 'b'],
    gradFunc: (dy, saved) => {
        const [a, b] = saved;
        const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
        // Undo forward-pass broadcasting for a gradient tensor.
        const reduceToShape = (t, shape) => {
            const reduceAxes = getReductionAxes(shape, outShape);
            const summed = reduceAxes.length > 0 ? sum$1(t, reduceAxes) : t;
            return reshape(summed, shape);
        };
        const derA = () => {
            const denom = add$1(square(a), square(b));
            return reduceToShape(mul(dy, div(b, denom)), a.shape);
        };
        const derB = () => {
            const denom = add$1(square(a), square(b));
            return reduceToShape(neg(mul(dy, div(a, denom))), b.shape);
        };
        return { a: derA, b: derB };
    }
};
27739
27740 /**
27741 * @license
27742 * Copyright 2020 Google LLC. All Rights Reserved.
27743 * Licensed under the Apache License, Version 2.0 (the "License");
27744 * you may not use this file except in compliance with the License.
27745 * You may obtain a copy of the License at
27746 *
27747 * http://www.apache.org/licenses/LICENSE-2.0
27748 *
27749 * Unless required by applicable law or agreed to in writing, software
27750 * distributed under the License is distributed on an "AS IS" BASIS,
27751 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27752 * See the License for the specific language governing permissions and
27753 * limitations under the License.
27754 * =============================================================================
27755 */
// Gradient for Atan: d/dx atan(x) = 1 / (1 + x^2), applied to dy.
const atanGradConfig = {
    kernelName: Atan,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        return {
            x: () => {
                const denom = add$1(square(cast(x, 'float32')), 1);
                return div(dy, denom);
            }
        };
    }
};
27764
27765 /**
27766 * @license
27767 * Copyright 2020 Google LLC. All Rights Reserved.
27768 * Licensed under the Apache License, Version 2.0 (the "License");
27769 * you may not use this file except in compliance with the License.
27770 * You may obtain a copy of the License at
27771 *
27772 * http://www.apache.org/licenses/LICENSE-2.0
27773 *
27774 * Unless required by applicable law or agreed to in writing, software
27775 * distributed under the License is distributed on an "AS IS" BASIS,
27776 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27777 * See the License for the specific language governing permissions and
27778 * limitations under the License.
27779 * =============================================================================
27780 */
// Gradient for Atanh: d/dx atanh(x) = 1 / (1 - x^2), applied to dy.
const atanhGradConfig = {
    kernelName: Atanh,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        return {
            x: () => {
                const denom = sub(scalar(1), square(cast(x, 'float32')));
                return div(dy, denom);
            }
        };
    }
};
27789
27790 /**
27791 * @license
27792 * Copyright 2020 Google LLC. All Rights Reserved.
27793 * Licensed under the Apache License, Version 2.0 (the "License");
27794 * you may not use this file except in compliance with the License.
27795 * You may obtain a copy of the License at
27796 *
27797 * http://www.apache.org/licenses/LICENSE-2.0
27798 *
27799 * Unless required by applicable law or agreed to in writing, software
27800 * distributed under the License is distributed on an "AS IS" BASIS,
27801 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27802 * See the License for the specific language governing permissions and
27803 * limitations under the License.
27804 * =============================================================================
27805 */
27806 /**
27807 * Computes the backprop of a 3d avg pool.
27808 *
 * @param dy The dy error, of rank 5 of shape
 *     [batchSize, depth, height, width, channels] is assumed.
 * @param input The original input image, of rank 5 or rank 4 of shape
 *     [batchSize, depth, height, width, channels].
 * @param filterSize The filter size:
 *     `[filterDepth, filterHeight, filterWidth]`. If
 *     `filterSize` is a single number,
 *     then `filterDepth == filterHeight == filterWidth`.
27818 * @param strides The strides of the pooling:
27819 * `[strideDepth, strideHeight, strideWidth]`. If
27820 * `strides` is a single number, then `strideHeight == strideWidth`.
27821 * @param pad A string from: 'same', 'valid'. The type of padding algorithm
27822 * used in the forward prop of the op.
27823 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
27824 * provided, it will default to truncate.
27825 */
// Computes the backprop of a 3D avg pool. Rank-4 inputs are temporarily
// promoted to rank 5 (batch of 1) and the result is demoted back afterwards.
function avgPool3dGrad_(dy, input, filterSize, strides, pad, dimRoundingMode) {
    const dyTensor = convertToTensor(dy, 'dy', 'avgPool3dGrad');
    const inputTensor = convertToTensor(input, 'input', 'avgPool3dGrad');
    // Promote rank-4 tensors to rank 5 by prepending a batch dim of 1.
    const promoted = inputTensor.rank === 4;
    const dy5D = promoted ?
        reshape(dyTensor, [
            1, dyTensor.shape[0], dyTensor.shape[1], dyTensor.shape[2],
            dyTensor.shape[3]
        ]) :
        dyTensor;
    const input5D = promoted ?
        reshape(inputTensor, [
            1, inputTensor.shape[0], inputTensor.shape[1], inputTensor.shape[2],
            inputTensor.shape[3]
        ]) :
        inputTensor;
    assert(dy5D.rank === 5, () => `Error in avgPool3dGrad: dy must be rank 5 but got rank ${dy5D.rank}.`);
    assert(input5D.rank === 5, () => `Error in avgPool3dGrad: input must be rank 5 but got rank ${input5D.rank}.`);
    checkPadOnDimRoundingMode('avgPool3dGrad', pad, dimRoundingMode);
    const inputs = { dy: dy5D, input: input5D };
    const attrs = { filterSize, strides, pad, dimRoundingMode };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    const res = ENGINE.runKernel(AvgPool3DGrad, inputs, attrs);
    // Strip the synthetic batch dimension if we added one.
    if (promoted) {
        return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
    }
    return res;
}
const avgPool3dGrad = op({ avgPool3dGrad_ });
27854
27855 /**
27856 * @license
27857 * Copyright 2020 Google LLC. All Rights Reserved.
27858 * Licensed under the Apache License, Version 2.0 (the "License");
27859 * you may not use this file except in compliance with the License.
27860 * You may obtain a copy of the License at
27861 *
27862 * http://www.apache.org/licenses/LICENSE-2.0
27863 *
27864 * Unless required by applicable law or agreed to in writing, software
27865 * distributed under the License is distributed on an "AS IS" BASIS,
27866 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27867 * See the License for the specific language governing permissions and
27868 * limitations under the License.
27869 * =============================================================================
27870 */
/**
 * Gradient config for the AvgPool3D kernel: saves the forward input `x`
 * and routes `dy` through the avgPool3dGrad op with the same attrs.
 */
const avgPool3DGradConfig = {
    kernelName: AvgPool3D,
    inputsToSave: ['x'],
    gradFunc: (dy, saved, attrs) => {
        const x = saved[0];
        const { filterSize, strides, pad, dimRoundingMode } = attrs;
        const gradX = () => avgPool3dGrad(dy, x, filterSize, strides, pad, dimRoundingMode);
        return { x: gradX };
    }
};
27882
27883 /**
27884 * @license
27885 * Copyright 2020 Google LLC. All Rights Reserved.
27886 * Licensed under the Apache License, Version 2.0 (the "License");
27887 * you may not use this file except in compliance with the License.
27888 * You may obtain a copy of the License at
27889 *
27890 * http://www.apache.org/licenses/LICENSE-2.0
27891 *
27892 * Unless required by applicable law or agreed to in writing, software
27893 * distributed under the License is distributed on an "AS IS" BASIS,
27894 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27895 * See the License for the specific language governing permissions and
27896 * limitations under the License.
27897 * =============================================================================
27898 */
/**
 * Computes the backprop of an 2D avg pool.
 *
 * @param dy The dy error, of rank 4 or rank 3 of shape
 *     [batchSize, height, width, channels]. If rank 3, batch of 1 is
 *     assumed.
 * @param input The input image, of rank 4 or rank 3 of shape
 *     [batchSize, height, width, channels]. If rank 3, batch of 1 is
 *     assumed.
 * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
 *     `filterSize` is a single number, then `filterHeight == filterWidth`.
 * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
 *     `strides` is a single number, then `strideHeight == strideWidth`.
 * @param pad The type of padding algorithm used in the forward prop of the op.
 *     'same', 'valid', for more info, see this guide:
 *     [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
 *          https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 */
function avgPoolGrad_(dy, input, filterSize, strides, pad) {
    const $dy = convertToTensor(dy, 'dy', 'avgPoolGrad');
    const $input = convertToTensor(input, 'input', 'avgPoolGrad');
    assert($input.rank === $dy.rank, () => `Rank of input (${$input.rank}) does not match rank of dy (${$dy.rank})`);
    let input4D = $input;
    let dy4D = $dy;
    let addedBatchDim = false;
    if ($input.rank === 3) {
        // Promote single-example (rank-3) tensors to batched rank-4 form.
        addedBatchDim = true;
        const [h, w, c] = $input.shape;
        input4D = reshape($input, [1, h, w, c]);
        const [dh, dw, dc] = $dy.shape;
        dy4D = reshape($dy, [1, dh, dw, dc]);
    }
    assert(dy4D.rank === 4, () => `Error in avgPoolGrad: dy must be rank 4 but got rank ` +
        `${dy4D.rank}.`);
    assert(input4D.rank === 4, () => `Error in avgPoolGrad: input must be rank 4 but got rank ` +
        `${input4D.rank}.`);
    const inputs = { dy: dy4D, input: input4D };
    const attrs = { filterSize, strides, pad };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    const res = ENGINE.runKernel(AvgPoolGrad, inputs, attrs);
    if (addedBatchDim) {
        // Strip the synthetic batch dimension before returning.
        return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
    }
    return res;
}
const avgPoolGrad = op({ avgPoolGrad_ });
27944
27945 /**
27946 * @license
27947 * Copyright 2020 Google LLC. All Rights Reserved.
27948 * Licensed under the Apache License, Version 2.0 (the "License");
27949 * you may not use this file except in compliance with the License.
27950 * You may obtain a copy of the License at
27951 *
27952 * http://www.apache.org/licenses/LICENSE-2.0
27953 *
27954 * Unless required by applicable law or agreed to in writing, software
27955 * distributed under the License is distributed on an "AS IS" BASIS,
27956 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27957 * See the License for the specific language governing permissions and
27958 * limitations under the License.
27959 * =============================================================================
27960 */
/**
 * Gradient config for the AvgPool kernel: saves the forward input `x` and
 * delegates to the avgPoolGrad op with the same pooling attrs.
 */
const avgPoolGradConfig = {
    kernelName: AvgPool,
    inputsToSave: ['x'],
    gradFunc: (dy, saved, attrs) => {
        const x = saved[0];
        const { filterSize, strides, pad } = attrs;
        const gradX = () => avgPoolGrad(dy, x, filterSize, strides, pad);
        return { x: gradX };
    }
};
27970
27971 /**
27972 * @license
27973 * Copyright 2020 Google LLC. All Rights Reserved.
27974 * Licensed under the Apache License, Version 2.0 (the "License");
27975 * you may not use this file except in compliance with the License.
27976 * You may obtain a copy of the License at
27977 *
27978 * http://www.apache.org/licenses/LICENSE-2.0
27979 *
27980 * Unless required by applicable law or agreed to in writing, software
27981 * distributed under the License is distributed on an "AS IS" BASIS,
27982 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27983 * See the License for the specific language governing permissions and
27984 * limitations under the License.
27985 * =============================================================================
27986 */
/**
 * Gradient config for BatchMatMul.
 *
 * For y = matMul(a, b, transposeA, transposeB), each input gradient is
 * itself a matmul of dy with the other saved input; the transpose flags in
 * each branch are chosen so the result has the (untransposed) shape of the
 * corresponding saved input.
 */
const batchMatMulGradConfig = {
    kernelName: BatchMatMul,
    inputsToSave: ['a', 'b'],
    gradFunc: (dy, saved, attrs) => {
        const [a, b] = saved;
        const { transposeA, transposeB } = attrs;
        if (!transposeA && !transposeB) {
            // y = a . b:       da = dy . b^T,   db = a^T . dy.
            return {
                a: () => matMul(dy, b, false, true),
                b: () => matMul(a, dy, true, false)
            };
        }
        else if (!transposeA && transposeB) {
            // y = a . b^T:     da = dy . b,     db = dy^T . a.
            return {
                a: () => matMul(dy, b, false, false),
                b: () => matMul(dy, a, true, false)
            };
        }
        else if (transposeA && !transposeB) {
            // y = a^T . b:     da = b . dy^T,   db = a . dy.
            return {
                a: () => matMul(b, dy, false, true),
                b: () => matMul(a, dy, false, false)
            };
        }
        else {
            // y = a^T . b^T:   da = b^T . dy^T, db = dy^T . a^T.
            return {
                a: () => matMul(b, dy, true, true),
                b: () => matMul(dy, a, true, true)
            };
        }
    }
};
28019
28020 /**
28021 * @license
28022 * Copyright 2020 Google LLC. All Rights Reserved.
28023 * Licensed under the Apache License, Version 2.0 (the "License");
28024 * you may not use this file except in compliance with the License.
28025 * You may obtain a copy of the License at
28026 *
28027 * http://www.apache.org/licenses/LICENSE-2.0
28028 *
28029 * Unless required by applicable law or agreed to in writing, software
28030 * distributed under the License is distributed on an "AS IS" BASIS,
28031 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28032 * See the License for the specific language governing permissions and
28033 * limitations under the License.
28034 * =============================================================================
28035 */
/**
 * Gradient config for BatchToSpaceND: the inverse operation is
 * spaceToBatchND with the same blockShape and crops.
 */
const batchToSpaceNDGradConfig = {
    kernelName: BatchToSpaceND,
    gradFunc: (dy, saved, attrs) => {
        const { blockShape, crops } = attrs;
        const gradX = () => spaceToBatchND(dy, blockShape, crops);
        return { x: gradX };
    }
};
28043
28044 /**
28045 * @license
28046 * Copyright 2020 Google LLC. All Rights Reserved.
28047 * Licensed under the Apache License, Version 2.0 (the "License");
28048 * you may not use this file except in compliance with the License.
28049 * You may obtain a copy of the License at
28050 *
28051 * http://www.apache.org/licenses/LICENSE-2.0
28052 *
28053 * Unless required by applicable law or agreed to in writing, software
28054 * distributed under the License is distributed on an "AS IS" BASIS,
28055 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28056 * See the License for the specific language governing permissions and
28057 * limitations under the License.
28058 * =============================================================================
28059 */
/**
 * Gradient config for BroadcastTo: sums dy over every axis that was
 * broadcast in the forward pass.
 *
 * NOTE(review): the loop compares `inputShape[i]` and `outputShape[i]`
 * left-aligned, so this assumes `inputShape` already has the same rank as
 * `shape` (i.e. the forward op padded the input shape with leading 1s) —
 * TODO confirm against the broadcastTo op implementation.
 */
const broadcastToGradConfig = {
    kernelName: BroadcastTo,
    gradFunc: (dy, saved, attrs) => {
        const broadCastToAttrs = attrs;
        const inputShape = broadCastToAttrs.inputShape;
        const outputShape = broadCastToAttrs.shape;
        // reps[i] stays > 1 exactly where the input was broadcast (a dim of
        // 1 expanded to outputShape[i]); matching dims are marked with 1.
        const reps = Array.from(outputShape);
        for (let i = inputShape.length - 1; i >= 0; i--) {
            if (inputShape[i] === outputShape[i]) {
                reps[i] = 1;
            }
            else if (inputShape[i] !== 1) {
                throw new Error(`broadcastTo(): [${inputShape}] cannot be broadcast to [${outputShape}].`);
            }
        }
        // Axes along which dy must be reduced to recover the input's shape.
        const axes = [];
        for (let i = 0; i < reps.length; i++) {
            if (reps[i] > 1) {
                axes.push(i);
            }
        }
        // keepDims so the summed gradient keeps the input's rank.
        return { x: () => sum$1(dy, axes, true /* keepDims */) };
    }
};
28084
28085 /**
28086 * @license
28087 * Copyright 2020 Google LLC. All Rights Reserved.
28088 * Licensed under the Apache License, Version 2.0 (the "License");
28089 * you may not use this file except in compliance with the License.
28090 * You may obtain a copy of the License at
28091 *
28092 * http://www.apache.org/licenses/LICENSE-2.0
28093 *
28094 * Unless required by applicable law or agreed to in writing, software
28095 * distributed under the License is distributed on an "AS IS" BASIS,
28096 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28097 * See the License for the specific language governing permissions and
28098 * limitations under the License.
28099 * =============================================================================
28100 */
/** Gradient config for Cast: the gradient passes straight through. */
const castGradConfig = {
    kernelName: Cast,
    gradFunc: (dy) => ({ x: () => dy.clone() })
};
28107
28108 /**
28109 * @license
28110 * Copyright 2020 Google LLC. All Rights Reserved.
28111 * Licensed under the Apache License, Version 2.0 (the "License");
28112 * you may not use this file except in compliance with the License.
28113 * You may obtain a copy of the License at
28114 *
28115 * http://www.apache.org/licenses/LICENSE-2.0
28116 *
28117 * Unless required by applicable law or agreed to in writing, software
28118 * distributed under the License is distributed on an "AS IS" BASIS,
28119 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28120 * See the License for the specific language governing permissions and
28121 * limitations under the License.
28122 * =============================================================================
28123 */
/** Gradient config for Ceil: the step function has zero gradient a.e. */
const ceilGradConfig = {
    kernelName: Ceil,
    gradFunc: (dy) => {
        // TODO(manrajgrover): Return null for gradients when backprop supports it.
        const zeroGrad = () => zerosLike(dy);
        return { x: zeroGrad };
    }
};
28131
28132 /**
28133 * @license
28134 * Copyright 2020 Google LLC. All Rights Reserved.
28135 * Licensed under the Apache License, Version 2.0 (the "License");
28136 * you may not use this file except in compliance with the License.
28137 * You may obtain a copy of the License at
28138 *
28139 * http://www.apache.org/licenses/LICENSE-2.0
28140 *
28141 * Unless required by applicable law or agreed to in writing, software
28142 * distributed under the License is distributed on an "AS IS" BASIS,
28143 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28144 * See the License for the specific language governing permissions and
28145 * limitations under the License.
28146 * =============================================================================
28147 */
/**
 * Gradient config for ClipByValue: dy flows through only where the saved
 * input x lies within [clipValueMin, clipValueMax] (inclusive); elsewhere
 * the gradient is zero.
 */
const clipByValueGradConfig = {
    kernelName: ClipByValue,
    inputsToSave: ['x'],
    gradFunc: (dy, saved, attrs) => {
        const [x] = saved;
        const { clipValueMin, clipValueMax } = attrs;
        return {
            x: () => {
                const inRange = logicalAnd(greaterEqual(x, clipValueMin), lessEqual(x, clipValueMax));
                return where(inRange, dy, zerosLike(dy));
            },
        };
    }
};
28159
28160 /**
28161 * @license
28162 * Copyright 2020 Google LLC. All Rights Reserved.
28163 * Licensed under the Apache License, Version 2.0 (the "License");
28164 * you may not use this file except in compliance with the License.
28165 * You may obtain a copy of the License at
28166 *
28167 * http://www.apache.org/licenses/LICENSE-2.0
28168 *
28169 * Unless required by applicable law or agreed to in writing, software
28170 * distributed under the License is distributed on an "AS IS" BASIS,
28171 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28172 * See the License for the specific language governing permissions and
28173 * limitations under the License.
28174 * =============================================================================
28175 */
/**
 * Gradient config for ComplexAbs: reuses the Abs gradient function
 * (absGradConfig.gradFunc) with the saved input `x`.
 */
const complexAbsGradConfig = {
    kernelName: ComplexAbs,
    inputsToSave: ['x'],
    gradFunc: absGradConfig.gradFunc,
};
28181
28182 /**
28183 * @license
28184 * Copyright 2020 Google LLC. All Rights Reserved.
28185 * Licensed under the Apache License, Version 2.0 (the "License");
28186 * you may not use this file except in compliance with the License.
28187 * You may obtain a copy of the License at
28188 *
28189 * http://www.apache.org/licenses/LICENSE-2.0
28190 *
28191 * Unless required by applicable law or agreed to in writing, software
28192 * distributed under the License is distributed on an "AS IS" BASIS,
28193 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28194 * See the License for the specific language governing permissions and
28195 * limitations under the License.
28196 * =============================================================================
28197 */
/**
 * Gradient config for Concat: split dy back into one slice per saved
 * input along the (normalized) concat axis.
 */
const concatGradConfig = {
    kernelName: Concat,
    saveAllInputs: true,
    gradFunc: (dy, saved, attrs) => {
        const { axis } = attrs;
        // Normalize a possibly-negative axis against the first input's shape.
        const $axis = parseAxisParam(axis, saved[0].shape)[0];
        // Each input contributes a slice sized by its extent along $axis.
        const sizeSplits = saved.map((t) => t.shape[$axis]);
        const pieces = split(dy, sizeSplits, $axis);
        return pieces.map((piece) => () => piece);
    }
};
28210
28211 /**
28212 * @license
28213 * Copyright 2020 Google LLC. All Rights Reserved.
28214 * Licensed under the Apache License, Version 2.0 (the "License");
28215 * you may not use this file except in compliance with the License.
28216 * You may obtain a copy of the License at
28217 *
28218 * http://www.apache.org/licenses/LICENSE-2.0
28219 *
28220 * Unless required by applicable law or agreed to in writing, software
28221 * distributed under the License is distributed on an "AS IS" BASIS,
28222 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28223 * See the License for the specific language governing permissions and
28224 * limitations under the License.
28225 * =============================================================================
28226 */
/**
 * Gradient config for Conv2D. Saves the forward input and filter; the
 * input gradient is computed by conv2DBackpropInput and the filter
 * gradient by conv2DBackpropFilter, both with the forward attrs.
 */
const conv2DGradConfig = {
    kernelName: Conv2D,
    inputsToSave: ['x', 'filter'],
    gradFunc: (dy, saved, attrs) => {
        const [x4D, $filter] = saved;
        const { dilations, strides, pad, dataFormat } = attrs;
        // Gradients are only implemented for dilation rates of 1.
        assert(tupleValuesAreOne(dilations), () => 'Error in gradient of conv2D: dilation rates greater than 1 ' +
            `are not yet supported in gradients. Got dilations '${dilations}'`);
        return {
            x: () => conv2DBackpropInput(x4D.shape, dy, $filter, strides, pad, dataFormat),
            filter: () => conv2DBackpropFilter(x4D, dy, $filter.shape, strides, pad, dataFormat)
        };
    }
};
28241
28242 /**
28243 * @license
28244 * Copyright 2020 Google LLC. All Rights Reserved.
28245 * Licensed under the Apache License, Version 2.0 (the "License");
28246 * you may not use this file except in compliance with the License.
28247 * You may obtain a copy of the License at
28248 *
28249 * http://www.apache.org/licenses/LICENSE-2.0
28250 *
28251 * Unless required by applicable law or agreed to in writing, software
28252 * distributed under the License is distributed on an "AS IS" BASIS,
28253 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28254 * See the License for the specific language governing permissions and
28255 * limitations under the License.
28256 * =============================================================================
28257 */
/**
 * Gradient config for Conv2DBackpropInput (second-order conv gradient).
 * Given ddx — the gradient w.r.t. this op's output — the gradient w.r.t.
 * `dy` is a forward conv2d of ddx with the filter, and the gradient
 * w.r.t. the filter is conv2DBackpropFilter applied to (ddx, dy).
 */
const conv2DBackpropInputGradConfig = {
    kernelName: Conv2DBackpropInput,
    inputsToSave: ['dy', 'filter'],
    gradFunc: (ddx, saved, attrs) => {
        const [dy, filter] = saved;
        const { strides, pad, dataFormat, dimRoundingMode } = attrs;
        return {
            dy: () => conv2d(ddx, filter, strides, pad, dataFormat, 1 /* dilations */, dimRoundingMode),
            filter: () => conv2DBackpropFilter(ddx, dy, filter.shape, strides, pad, dataFormat, dimRoundingMode)
        };
    }
};
28270
28271 /**
28272 * @license
28273 * Copyright 2020 Google LLC. All Rights Reserved.
28274 * Licensed under the Apache License, Version 2.0 (the "License");
28275 * you may not use this file except in compliance with the License.
28276 * You may obtain a copy of the License at
28277 *
28278 * http://www.apache.org/licenses/LICENSE-2.0
28279 *
28280 * Unless required by applicable law or agreed to in writing, software
28281 * distributed under the License is distributed on an "AS IS" BASIS,
28282 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28283 * See the License for the specific language governing permissions and
28284 * limitations under the License.
28285 * =============================================================================
28286 */
/**
 * Computes the derivative of the filter of a 3D convolution.
 *
 * @param x The input tensor, of rank 5 or rank 4 of shape
 *     [batch, depth, height, width, inChannels]. If rank 4, batch of 1 is
 *     assumed.
 * @param dy The dy image, of rank 5 or rank 4, of shape
 *     [batch, depth, height, width, outDepth]. If rank 4, batch of 1 is
 *     assumed.
 * @param filterShape The shape of the filter, length 5,
 *     [filterDepth, filterHeight, filterWidth, inDepth, outDepth].
 * @param strides The strides of the convolution: [strideDepth, strideHeight,
 *     strideWidth].
 * @param pad A string from: 'same', 'valid'. The type of padding algorithm
 *     used in the forward prop of the op.
 */
function conv3DBackpropFilter_(x, dy, filterShape, strides, pad) {
    // Promote rank-4 (unbatched) tensors to rank 5 with a batch of 1.
    let x5D = x;
    if (x.rank === 4) {
        x5D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2], x.shape[3]]);
    }
    let dy5D = dy;
    if (dy5D.rank === 4) {
        dy5D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]);
    }
    assert(x5D.rank === 5, () => `Error in conv3dDerFilter: input must be rank 5, but got shape ` +
        `${x5D.shape}.`);
    assert(dy5D.rank === 5, () => `Error in conv3dDerFilter: dy must be rank 5, but got shape ` +
        `${dy5D.shape}.`);
    assert(filterShape.length === 5, () => `Error in conv3dDerFilter: filterShape must be length 5, but got ` +
        `${filterShape}.`);
    // Fix: the two depth-mismatch messages below had unbalanced parentheses
    // ("input ${…})" and "filter (${…}."); both are now properly paired.
    assert(x5D.shape[4] === filterShape[3], () => `Error in conv3dDerFilter: depth of input (${x5D.shape[4]}) must ` +
        `match input depth in filter (${filterShape[3]}).`);
    assert(dy5D.shape[4] === filterShape[4], () => `Error in conv3dDerFilter: depth of dy (${dy5D.shape[4]}) must ` +
        `match output depth for filter (${filterShape[4]}).`);
    const inputs = { x: x5D, dy: dy5D };
    const attrs = { strides, pad, filterShape };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    return ENGINE.runKernel(Conv3DBackpropFilterV2, inputs, attrs);
}
const conv3DBackpropFilter = op({ conv3DBackpropFilter_ });
28328
28329 /**
28330 * @license
28331 * Copyright 2020 Google LLC. All Rights Reserved.
28332 * Licensed under the Apache License, Version 2.0 (the "License");
28333 * you may not use this file except in compliance with the License.
28334 * You may obtain a copy of the License at
28335 *
28336 * http://www.apache.org/licenses/LICENSE-2.0
28337 *
28338 * Unless required by applicable law or agreed to in writing, software
28339 * distributed under the License is distributed on an "AS IS" BASIS,
28340 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28341 * See the License for the specific language governing permissions and
28342 * limitations under the License.
28343 * =============================================================================
28344 */
/**
 * Gradient config for Conv3D: delegates to the 3D backprop-input and
 * backprop-filter ops. Dilation rates must all be 1.
 */
const conv3DGradConfig = {
    kernelName: Conv3D,
    inputsToSave: ['x', 'filter'],
    gradFunc: (dy, saved, attrs) => {
        const { dilations, strides, pad } = attrs;
        assert(tupleValuesAreOne(dilations), () => 'Error in gradient of conv3D: dilation rates greater than 1 are ' +
            `not yet supported in gradients. Got dilations '${dilations}'`);
        const x5D = saved[0];
        const $filter = saved[1];
        return {
            x: () => conv3DBackpropInput(x5D.shape, dy, $filter, strides, pad),
            filter: () => conv3DBackpropFilter(x5D, dy, $filter.shape, strides, pad)
        };
    }
};
28359
28360 /**
28361 * @license
28362 * Copyright 2020 Google LLC. All Rights Reserved.
28363 * Licensed under the Apache License, Version 2.0 (the "License");
28364 * you may not use this file except in compliance with the License.
28365 * You may obtain a copy of the License at
28366 *
28367 * http://www.apache.org/licenses/LICENSE-2.0
28368 *
28369 * Unless required by applicable law or agreed to in writing, software
28370 * distributed under the License is distributed on an "AS IS" BASIS,
28371 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28372 * See the License for the specific language governing permissions and
28373 * limitations under the License.
28374 * =============================================================================
28375 */
/** Gradient config for Cos: d/dx cos(x) = -sin(x). */
const cosGradConfig = {
    kernelName: Cos,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const x = saved[0];
        return {
            x: () => {
                // Upcast x to float32 before differentiating.
                const dydx = neg(sin(cast(x, 'float32')));
                return mul(dydx, dy);
            }
        };
    }
};
28384
28385 /**
28386 * @license
28387 * Copyright 2020 Google LLC. All Rights Reserved.
28388 * Licensed under the Apache License, Version 2.0 (the "License");
28389 * you may not use this file except in compliance with the License.
28390 * You may obtain a copy of the License at
28391 *
28392 * http://www.apache.org/licenses/LICENSE-2.0
28393 *
28394 * Unless required by applicable law or agreed to in writing, software
28395 * distributed under the License is distributed on an "AS IS" BASIS,
28396 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28397 * See the License for the specific language governing permissions and
28398 * limitations under the License.
28399 * =============================================================================
28400 */
/** Gradient config for Cosh: d/dx cosh(x) = sinh(x). */
const coshGradConfig = {
    kernelName: Cosh,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const x = saved[0];
        return {
            x: () => {
                // Upcast x to float32 before differentiating.
                const dydx = sinh(cast(x, 'float32'));
                return mul(dydx, dy);
            }
        };
    }
};
28409
28410 /**
28411 * @license
28412 * Copyright 2020 Google LLC. All Rights Reserved.
28413 * Licensed under the Apache License, Version 2.0 (the "License");
28414 * you may not use this file except in compliance with the License.
28415 * You may obtain a copy of the License at
28416 *
28417 * http://www.apache.org/licenses/LICENSE-2.0
28418 *
28419 * Unless required by applicable law or agreed to in writing, software
28420 * distributed under the License is distributed on an "AS IS" BASIS,
28421 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28422 * See the License for the specific language governing permissions and
28423 * limitations under the License.
28424 * =============================================================================
28425 */
/**
 * Gradient config for Cumsum: the gradient of a cumulative sum is a
 * cumulative sum of dy in the opposite direction (`!reverse`) with the
 * same `exclusive` flag.
 *
 * NOTE(review): `permutation` is computed for `[axis]` but dy is not
 * permuted before the cumsum — only the result is transposed back. This
 * looks asymmetric unless the forward cumsum op normalizes the axis
 * itself; confirm against the cumsum implementation.
 */
const cumsumGradConfig = {
    kernelName: Cumsum,
    inputsToSave: ['x'],
    gradFunc: (dy, saved, attrs) => {
        const [x] = saved;
        const { axis, exclusive, reverse } = attrs;
        return {
            x: () => {
                // null when `axis` is already the innermost axis of x.
                const permutation = getAxesPermutation([axis], x.rank);
                let out = cumsum(dy, axis, exclusive, !reverse);
                if (permutation != null) {
                    out = transpose(out, permutation);
                }
                return out;
            }
        };
    }
};
28444
28445 /**
28446 * @license
28447 * Copyright 2020 Google LLC. All Rights Reserved.
28448 * Licensed under the Apache License, Version 2.0 (the "License");
28449 * you may not use this file except in compliance with the License.
28450 * You may obtain a copy of the License at
28451 *
28452 * http://www.apache.org/licenses/LICENSE-2.0
28453 *
28454 * Unless required by applicable law or agreed to in writing, software
28455 * distributed under the License is distributed on an "AS IS" BASIS,
28456 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28457 * See the License for the specific language governing permissions and
28458 * limitations under the License.
28459 * =============================================================================
28460 */
/**
 * Gradient config for DepthwiseConv2dNative. Validates that the saved
 * input/filter are rank 4, that channel counts match, and that dilation
 * rates are 1 (the only case with gradient support), then delegates to
 * the depthwise backprop-input and backprop-filter ops.
 */
const depthwiseConv2dNativeGradConfig = {
    kernelName: DepthwiseConv2dNative,
    inputsToSave: ['x', 'filter'],
    gradFunc: (dy, saved, attrs) => {
        const { dilations, strides, pad, dimRoundingMode } = attrs;
        // Default to no dilation when the attr is absent.
        const $dilations = dilations == null ? [1, 1] : dilations;
        assert(tupleValuesAreOne($dilations), () => 'Error in gradient of depthwiseConv2dNative: dilation rates ' +
            `greater than 1 are not yet supported. Got dilations ` +
            `'${$dilations}'`);
        const [x, filter] = saved;
        assert(x.rank === 4, () => `Error in gradient of depthwiseConv2dNative: input must be ` +
            `rank 4, but got rank ${x.rank}.`);
        assert(filter.rank === 4, () => `Error in gradient of depthwiseConv2dNative: filter must be ` +
            `rank 4, but got rank ${filter.rank}.`);
        assert(x.shape[3] === filter.shape[2], () => `Error in gradient of depthwiseConv2d: number of input ` +
            `channels (${x.shape[3]}) must match the inChannels dimension ` +
            `in filter ${filter.shape[2]}.`);
        assert(eitherStridesOrDilationsAreOne(strides, $dilations), () => 'Error in gradient of depthwiseConv2d: Either strides or ' +
            `dilations must be 1. Got strides ${strides} and dilations ` +
            `'${$dilations}'.`);
        checkPadOnDimRoundingMode('depthwiseConv2d', pad, dimRoundingMode);
        return {
            x: () => depthwiseConv2dNativeBackpropInput(x.shape, dy, filter, strides, pad, $dilations, dimRoundingMode),
            filter: () => depthwiseConv2dNativeBackpropFilter(x, dy, filter.shape, strides, pad, $dilations, dimRoundingMode),
        };
    }
};
28488
28489 /**
28490 * @license
28491 * Copyright 2020 Google LLC. All Rights Reserved.
28492 * Licensed under the Apache License, Version 2.0 (the "License");
28493 * you may not use this file except in compliance with the License.
28494 * You may obtain a copy of the License at
28495 *
28496 * http://www.apache.org/licenses/LICENSE-2.0
28497 *
28498 * Unless required by applicable law or agreed to in writing, software
28499 * distributed under the License is distributed on an "AS IS" BASIS,
28500 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28501 * See the License for the specific language governing permissions and
28502 * limitations under the License.
28503 * =============================================================================
28504 */
/**
 * Gradient config for Dilation2D: both gradients run dedicated backprop
 * kernels over the same {x, filter, dy} inputs and forward attrs.
 */
const dilation2dGradConfig = {
    kernelName: Dilation2D,
    inputsToSave: ['x', 'filter'],
    gradFunc: (dy, saved, attrs) => {
        const [x, filter] = saved;
        const backpropInputs = { x, filter, dy };
        return {
            x: () => ENGINE.runKernel(Dilation2DBackpropInput, backpropInputs, attrs),
            filter: () => ENGINE.runKernel(Dilation2DBackpropFilter, backpropInputs, attrs)
        };
    }
};
28518
28519 /**
28520 * @license
28521 * Copyright 2020 Google LLC. All Rights Reserved.
28522 * Licensed under the Apache License, Version 2.0 (the "License");
28523 * you may not use this file except in compliance with the License.
28524 * You may obtain a copy of the License at
28525 *
28526 * http://www.apache.org/licenses/LICENSE-2.0
28527 *
28528 * Unless required by applicable law or agreed to in writing, software
28529 * distributed under the License is distributed on an "AS IS" BASIS,
28530 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28531 * See the License for the specific language governing permissions and
28532 * limitations under the License.
28533 * =============================================================================
28534 */
/**
 * Gradient config for Elu: saves the forward output y and delegates to
 * the EluGrad kernel with {dy, y}.
 */
const eluGradConfig = {
    kernelName: Elu,
    outputsToSave: [true],
    gradFunc: (dy, saved) => {
        const y = saved[0];
        return {
            x: () => ENGINE.runKernel(EluGrad, { dy, y })
        };
    }
};
28544
28545 /**
28546 * @license
28547 * Copyright 2020 Google LLC. All Rights Reserved.
28548 * Licensed under the Apache License, Version 2.0 (the "License");
28549 * you may not use this file except in compliance with the License.
28550 * You may obtain a copy of the License at
28551 *
28552 * http://www.apache.org/licenses/LICENSE-2.0
28553 *
28554 * Unless required by applicable law or agreed to in writing, software
28555 * distributed under the License is distributed on an "AS IS" BASIS,
28556 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28557 * See the License for the specific language governing permissions and
28558 * limitations under the License.
28559 * =============================================================================
28560 */
/** Gradient config for Erf: d/dx erf(x) = (2 / sqrt(pi)) * exp(-x^2). */
const erfGradConfig = {
    kernelName: Erf,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const x = saved[0];
        const derivative = mul(exp(neg(square(x))), 2 / Math.sqrt(Math.PI));
        return { x: () => mul(dy, derivative) };
    }
};
28570
28571 /**
28572 * @license
28573 * Copyright 2020 Google LLC. All Rights Reserved.
28574 * Licensed under the Apache License, Version 2.0 (the "License");
28575 * you may not use this file except in compliance with the License.
28576 * You may obtain a copy of the License at
28577 *
28578 * http://www.apache.org/licenses/LICENSE-2.0
28579 *
28580 * Unless required by applicable law or agreed to in writing, software
28581 * distributed under the License is distributed on an "AS IS" BASIS,
28582 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28583 * See the License for the specific language governing permissions and
28584 * limitations under the License.
28585 * =============================================================================
28586 */
28587 const expGradConfig = {
28588 kernelName: Exp,
28589 outputsToSave: [true],
28590 gradFunc: (dy, saved) => {
28591 const [y] = saved;
28592 return { x: () => mul(dy, y) };
28593 }
28594 };
28595
28596 /**
28597 * @license
28598 * Copyright 2020 Google LLC. All Rights Reserved.
28599 * Licensed under the Apache License, Version 2.0 (the "License");
28600 * you may not use this file except in compliance with the License.
28601 * You may obtain a copy of the License at
28602 *
28603 * http://www.apache.org/licenses/LICENSE-2.0
28604 *
28605 * Unless required by applicable law or agreed to in writing, software
28606 * distributed under the License is distributed on an "AS IS" BASIS,
28607 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28608 * See the License for the specific language governing permissions and
28609 * limitations under the License.
28610 * =============================================================================
28611 */
28612 const expandDimsGradConfig = {
28613 kernelName: ExpandDims,
28614 inputsToSave: ['input'],
28615 gradFunc: (dy, saved) => {
28616 const [input] = saved;
28617 return { input: () => reshape(dy, input.shape) };
28618 }
28619 };
28620
28621 /**
28622 * @license
28623 * Copyright 2020 Google LLC. All Rights Reserved.
28624 * Licensed under the Apache License, Version 2.0 (the "License");
28625 * you may not use this file except in compliance with the License.
28626 * You may obtain a copy of the License at
28627 *
28628 * http://www.apache.org/licenses/LICENSE-2.0
28629 *
28630 * Unless required by applicable law or agreed to in writing, software
28631 * distributed under the License is distributed on an "AS IS" BASIS,
28632 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28633 * See the License for the specific language governing permissions and
28634 * limitations under the License.
28635 * =============================================================================
28636 */
28637 const expm1GradConfig = {
28638 kernelName: Expm1,
28639 inputsToSave: ['x'],
28640 gradFunc: (dy, saved) => {
28641 const [x] = saved;
28642 return { x: () => mul(dy, exp(x)) };
28643 }
28644 };
28645
28646 /**
28647 * @license
28648 * Copyright 2020 Google LLC. All Rights Reserved.
28649 * Licensed under the Apache License, Version 2.0 (the "License");
28650 * you may not use this file except in compliance with the License.
28651 * You may obtain a copy of the License at
28652 *
28653 * http://www.apache.org/licenses/LICENSE-2.0
28654 *
28655 * Unless required by applicable law or agreed to in writing, software
28656 * distributed under the License is distributed on an "AS IS" BASIS,
28657 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28658 * See the License for the specific language governing permissions and
28659 * limitations under the License.
28660 * =============================================================================
28661 */
28662 const floorGradConfig = {
28663 kernelName: Floor,
28664 gradFunc: (dy) => {
28665 return { x: () => zerosLike(dy) };
28666 }
28667 };
28668
28669 /**
28670 * @license
28671 * Copyright 2020 Google LLC. All Rights Reserved.
28672 * Licensed under the Apache License, Version 2.0 (the "License");
28673 * you may not use this file except in compliance with the License.
28674 * You may obtain a copy of the License at
28675 *
28676 * http://www.apache.org/licenses/LICENSE-2.0
28677 *
28678 * Unless required by applicable law or agreed to in writing, software
28679 * distributed under the License is distributed on an "AS IS" BASIS,
28680 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28681 * See the License for the specific language governing permissions and
28682 * limitations under the License.
28683 * =============================================================================
28684 */
28685 const floorDivGradConfig = {
28686 kernelName: FloorDiv,
28687 inputsToSave: ['a', 'b'],
28688 gradFunc: (dy, saved) => {
28689 const [a, b] = saved;
28690 const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
28691 const derA = () => {
28692 const res = div(dy, cast(b, 'float32'));
28693 const reduceAxes = getReductionAxes(a.shape, outShape);
28694 if (reduceAxes.length > 0) {
28695 return reshape(sum$1(res, reduceAxes), a.shape);
28696 }
28697 return res;
28698 };
28699 const derB = () => {
28700 let res = mul(dy, cast(a, 'float32'));
28701 const reduceAxes = getReductionAxes(b.shape, outShape);
28702 if (reduceAxes.length > 0) {
28703 res = reshape(sum$1(res, reduceAxes), b.shape);
28704 }
28705 const tmp = square(b);
28706 return neg(div(res, cast(tmp, 'float32')));
28707 };
28708 return { a: derA, b: derB };
28709 }
28710 };
28711
28712 /**
28713 * @license
28714 * Copyright 2020 Google LLC. All Rights Reserved.
28715 * Licensed under the Apache License, Version 2.0 (the "License");
28716 * you may not use this file except in compliance with the License.
28717 * You may obtain a copy of the License at
28718 *
28719 * http://www.apache.org/licenses/LICENSE-2.0
28720 *
28721 * Unless required by applicable law or agreed to in writing, software
28722 * distributed under the License is distributed on an "AS IS" BASIS,
28723 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28724 * See the License for the specific language governing permissions and
28725 * limitations under the License.
28726 * =============================================================================
28727 */
    /**
     * Gradient config for FusedBatchNorm (batch normalization).
     *
     * Saved tensors: x (input), mean, variance, scale. All gradients share the
     * normalization factor rsqrt(variance + varianceEpsilon). When `mean` is
     * rank 1 (per-channel statistics), per-channel gradients are reduced over
     * the remaining (batch/spatial) axes of `x`.
     */
    const fusedBatchNormGradConfig = {
        kernelName: FusedBatchNorm,
        inputsToSave: ['x', 'mean', 'variance', 'scale'],
        gradFunc: (dy, saved, attrs) => {
            const { varianceEpsilon } = attrs;
            const [x, mean, variance, scale] = saved;
            // `scale` is optional in the forward op; default to 1.
            const scaleValue = scale == null ? scalar(1) : scale;
            // Axes of `x` reduced when computing per-channel statistics.
            const reductionAxes = getReductionAxes(mean.shape, x.shape);
            const tileShape = [];
            if (mean.rank === 1) {
                // Tile a per-channel vector across every non-channel dim of `x`.
                for (let i = 0; i < x.shape.length - 1; ++i) {
                    tileShape.push(x.shape[i]);
                }
                tileShape.push(1);
            }
            const xMinusMean = sub(x, mean);
            const dyTimesScaleValue = mul(dy, scaleValue);
            // 1 / sqrt(variance + epsilon): the shared normalization factor.
            const oneOverSqrtVariance = rsqrt(add$1(variance, scalar(varianceEpsilon)));
            // -0.5 * (variance + epsilon)^(-3/2), used by the variance gradient.
            const minusHalfRCube = mul(mul(mul(oneOverSqrtVariance, oneOverSqrtVariance), oneOverSqrtVariance), scalar(-0.5));
            // dL/dx = dy * rsqrt(variance + eps) * scale.
            // NOTE(review): the [1, 1, 1, C] reshape assumes a rank-4 (NHWC) `x`
            // when `mean` is rank 1 — confirm for other input ranks.
            const derX = () => {
                if (mean.rank === 1) {
                    return reshape(mul(mul(dy, tile(reshape(oneOverSqrtVariance, [1, 1, 1, mean.shape[0]]), tileShape)), scaleValue), x.shape);
                }
                else {
                    return reshape(mul(mul(dy, oneOverSqrtVariance), scaleValue), x.shape);
                }
            };
            // dL/d(mean) = -rsqrt(variance + eps) * dy * scale (reduced).
            const derMean = () => {
                let meanDer = mul(mul(oneOverSqrtVariance, scalar(-1)), dyTimesScaleValue);
                if (mean.rank === 1) {
                    meanDer = sum$1(meanDer, reductionAxes);
                }
                return reshape(meanDer, mean.shape);
            };
            // dL/d(variance) = -0.5 * (var + eps)^(-3/2) * (x - mean) * dy * scale.
            const derVariance = () => {
                let varianceDer = mul(mul(minusHalfRCube, xMinusMean), dyTimesScaleValue);
                if (mean.rank === 1) {
                    varianceDer = sum$1(varianceDer, reductionAxes);
                }
                return reshape(varianceDer, mean.shape);
            };
            // dL/d(scale) = dy * (x - mean) * rsqrt(variance + eps) (reduced).
            const derScale = () => {
                const xMinusMean2TimesRsqrt = mul(xMinusMean, oneOverSqrtVariance);
                let scaleDer = mul(dy, xMinusMean2TimesRsqrt);
                if (mean.rank === 1) {
                    scaleDer = sum$1(scaleDer, reductionAxes);
                }
                return reshape(scaleDer, mean.shape);
            };
            // dL/d(offset) = dy, reduced over the non-channel axes.
            const derOffset = () => {
                let offsetDer = dy;
                if (mean.rank === 1) {
                    offsetDer = sum$1(offsetDer, reductionAxes);
                }
                return reshape(offsetDer, mean.shape);
            };
            return {
                x: derX,
                mean: derMean,
                variance: derVariance,
                scale: derScale,
                offset: derOffset
            };
        }
    };
28793
28794 /**
28795 * @license
28796 * Copyright 2020 Google LLC. All Rights Reserved.
28797 * Licensed under the Apache License, Version 2.0 (the "License");
28798 * you may not use this file except in compliance with the License.
28799 * You may obtain a copy of the License at
28800 *
28801 * http://www.apache.org/licenses/LICENSE-2.0
28802 *
28803 * Unless required by applicable law or agreed to in writing, software
28804 * distributed under the License is distributed on an "AS IS" BASIS,
28805 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28806 * See the License for the specific language governing permissions and
28807 * limitations under the License.
28808 * =============================================================================
28809 */
28810 const gatherGradConfig = {
28811 kernelName: GatherV2,
28812 inputsToSave: ['x', 'indices'],
28813 gradFunc: (dy, saved, attrs) => {
28814 const [x, indices] = saved;
28815 const { axis } = attrs;
28816 const parsedAxis = parseAxisParam(axis, x.shape)[0];
28817 const derX = () => {
28818 const paramsShape = x.shape;
28819 const indicesSize = indices.size;
28820 const outerShape = paramsShape.slice(0, parsedAxis);
28821 const outerDims = outerShape.length;
28822 const innerShape = paramsShape.slice(axis, paramsShape.length).slice(1);
28823 const innerDims = innerShape.length;
28824 const outerAxesIndices = arrayRange(0, outerDims);
28825 const innerAxesIndices = arrayRange(outerDims + 1, outerDims + 1 + innerDims);
28826 const valuesShape = arrayConcat([outerShape, [indicesSize], innerShape]);
28827 const values = reshape(dy, valuesShape);
28828 const reshapedIndices = reshape(indices, [indicesSize]);
28829 const transposeDims = arrayConcat([[outerDims], outerAxesIndices, innerAxesIndices]);
28830 const valuesTranspose = transpose(values, transposeDims);
28831 let paramsGrad = unsortedSegmentSum(valuesTranspose, reshapedIndices, x.shape[parsedAxis]);
28832 const invertTransposeDims = getUndoAxesPermutation(transposeDims);
28833 paramsGrad = transpose(paramsGrad, invertTransposeDims);
28834 return paramsGrad;
28835 };
28836 return { x: derX, indices: () => indices };
28837 }
28838 };
28839 function arrayRange(start, stop) {
28840 const result = [];
28841 for (let i = start; i < stop; ++i) {
28842 result.push(i);
28843 }
28844 return result;
28845 }
28846 function arrayConcat(arrays) {
28847 const result = [];
28848 for (let i = 0; i < arrays.length; ++i) {
28849 for (let j = 0; j < arrays[i].length; ++j) {
28850 result.push(arrays[i][j]);
28851 }
28852 }
28853 return result;
28854 }
28855
28856 /**
28857 * @license
28858 * Copyright 2020 Google LLC. All Rights Reserved.
28859 * Licensed under the Apache License, Version 2.0 (the "License");
28860 * you may not use this file except in compliance with the License.
28861 * You may obtain a copy of the License at
28862 *
28863 * http://www.apache.org/licenses/LICENSE-2.0
28864 *
28865 * Unless required by applicable law or agreed to in writing, software
28866 * distributed under the License is distributed on an "AS IS" BASIS,
28867 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28868 * See the License for the specific language governing permissions and
28869 * limitations under the License.
28870 * =============================================================================
28871 */
28872 const greaterEqualGradConfig = {
28873 kernelName: GreaterEqual,
28874 inputsToSave: ['a', 'b'],
28875 gradFunc: (dy, saved) => {
28876 const [a, b] = saved;
28877 return { a: () => zerosLike(a), b: () => zerosLike(b) };
28878 }
28879 };
28880
28881 /**
28882 * @license
28883 * Copyright 2020 Google LLC. All Rights Reserved.
28884 * Licensed under the Apache License, Version 2.0 (the "License");
28885 * you may not use this file except in compliance with the License.
28886 * You may obtain a copy of the License at
28887 *
28888 * http://www.apache.org/licenses/LICENSE-2.0
28889 *
28890 * Unless required by applicable law or agreed to in writing, software
28891 * distributed under the License is distributed on an "AS IS" BASIS,
28892 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28893 * See the License for the specific language governing permissions and
28894 * limitations under the License.
28895 * =============================================================================
28896 */
28897 const identityGradConfig = {
28898 kernelName: Identity,
28899 gradFunc: (dy) => {
28900 return { x: () => cast(dy, 'float32') };
28901 }
28902 };
28903
28904 /**
28905 * @license
28906 * Copyright 2020 Google LLC. All Rights Reserved.
28907 * Licensed under the Apache License, Version 2.0 (the "License");
28908 * you may not use this file except in compliance with the License.
28909 * You may obtain a copy of the License at
28910 *
28911 * http://www.apache.org/licenses/LICENSE-2.0
28912 *
28913 * Unless required by applicable law or agreed to in writing, software
28914 * distributed under the License is distributed on an "AS IS" BASIS,
28915 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28916 * See the License for the specific language governing permissions and
28917 * limitations under the License.
28918 * =============================================================================
28919 */
28920 const isFiniteGradConfig = {
28921 kernelName: IsFinite,
28922 gradFunc: (dy) => {
28923 // TODO(nsthorat): Let gradients be null for cases where we want to stop
28924 // backpropgation.
28925 return { x: () => zerosLike(dy) };
28926 }
28927 };
28928
28929 /**
28930 * @license
28931 * Copyright 2020 Google LLC. All Rights Reserved.
28932 * Licensed under the Apache License, Version 2.0 (the "License");
28933 * you may not use this file except in compliance with the License.
28934 * You may obtain a copy of the License at
28935 *
28936 * http://www.apache.org/licenses/LICENSE-2.0
28937 *
28938 * Unless required by applicable law or agreed to in writing, software
28939 * distributed under the License is distributed on an "AS IS" BASIS,
28940 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28941 * See the License for the specific language governing permissions and
28942 * limitations under the License.
28943 * =============================================================================
28944 */
28945 const isInfGradConfig = {
28946 kernelName: IsInf,
28947 gradFunc: (dy) => {
28948 // TODO(nsthorat): Let gradients be null for cases where we want to stop
28949 // backpropgation.
28950 return { x: () => zerosLike(dy) };
28951 }
28952 };
28953
28954 /**
28955 * @license
28956 * Copyright 2020 Google LLC. All Rights Reserved.
28957 * Licensed under the Apache License, Version 2.0 (the "License");
28958 * you may not use this file except in compliance with the License.
28959 * You may obtain a copy of the License at
28960 *
28961 * http://www.apache.org/licenses/LICENSE-2.0
28962 *
28963 * Unless required by applicable law or agreed to in writing, software
28964 * distributed under the License is distributed on an "AS IS" BASIS,
28965 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28966 * See the License for the specific language governing permissions and
28967 * limitations under the License.
28968 * =============================================================================
28969 */
28970 const isNanGradConfig = {
28971 kernelName: IsNan,
28972 gradFunc: (dy) => {
28973 // TODO(nsthorat): Let gradients be null for cases where we want to stop
28974 // backpropgation.
28975 return { x: () => zerosLike(dy) };
28976 }
28977 };
28978
28979 /**
28980 * @license
28981 * Copyright 2020 Google LLC. All Rights Reserved.
28982 * Licensed under the Apache License, Version 2.0 (the "License");
28983 * you may not use this file except in compliance with the License.
28984 * You may obtain a copy of the License at
28985 *
28986 * http://www.apache.org/licenses/LICENSE-2.0
28987 *
28988 * Unless required by applicable law or agreed to in writing, software
28989 * distributed under the License is distributed on an "AS IS" BASIS,
28990 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28991 * See the License for the specific language governing permissions and
28992 * limitations under the License.
28993 * =============================================================================
28994 */
28995 const leakyReluGradConfig = {
28996 kernelName: LeakyRelu,
28997 inputsToSave: ['x'],
28998 gradFunc: (dy, saved, attrs) => {
28999 const [x] = saved;
29000 const { alpha } = attrs;
29001 const mask = greater(x, 0);
29002 // Returns `gradients * (features > 0) + alpha * gradients * (features <=
29003 // 0)`.
29004 return { x: () => where(mask, dy, mul(dy, alpha)) };
29005 }
29006 };
29007
29008 /**
29009 * @license
29010 * Copyright 2020 Google LLC. All Rights Reserved.
29011 * Licensed under the Apache License, Version 2.0 (the "License");
29012 * you may not use this file except in compliance with the License.
29013 * You may obtain a copy of the License at
29014 *
29015 * http://www.apache.org/licenses/LICENSE-2.0
29016 *
29017 * Unless required by applicable law or agreed to in writing, software
29018 * distributed under the License is distributed on an "AS IS" BASIS,
29019 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29020 * See the License for the specific language governing permissions and
29021 * limitations under the License.
29022 * =============================================================================
29023 */
29024 const log1pGradConfig = {
29025 kernelName: Log1p,
29026 inputsToSave: ['x'],
29027 gradFunc: (dy, saved) => {
29028 const [x] = saved;
29029 return { x: () => div(dy, add$1(x, 1)) };
29030 }
29031 };
29032
29033 /**
29034 * @license
29035 * Copyright 2020 Google LLC. All Rights Reserved.
29036 * Licensed under the Apache License, Version 2.0 (the "License");
29037 * you may not use this file except in compliance with the License.
29038 * You may obtain a copy of the License at
29039 *
29040 * http://www.apache.org/licenses/LICENSE-2.0
29041 *
29042 * Unless required by applicable law or agreed to in writing, software
29043 * distributed under the License is distributed on an "AS IS" BASIS,
29044 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29045 * See the License for the specific language governing permissions and
29046 * limitations under the License.
29047 * =============================================================================
29048 */
29049 const logGradConfig = {
29050 kernelName: Log,
29051 inputsToSave: ['x'],
29052 gradFunc: (dy, saved) => {
29053 const [x] = saved;
29054 return { x: () => div(dy, cast(x, 'float32')) };
29055 }
29056 };
29057
29058 /**
29059 * @license
29060 * Copyright 2020 Google LLC. All Rights Reserved.
29061 * Licensed under the Apache License, Version 2.0 (the "License");
29062 * you may not use this file except in compliance with the License.
29063 * You may obtain a copy of the License at
29064 *
29065 * http://www.apache.org/licenses/LICENSE-2.0
29066 *
29067 * Unless required by applicable law or agreed to in writing, software
29068 * distributed under the License is distributed on an "AS IS" BASIS,
29069 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29070 * See the License for the specific language governing permissions and
29071 * limitations under the License.
29072 * =============================================================================
29073 */
29074 const logSoftmaxGradConfig = {
29075 kernelName: LogSoftmax,
29076 inputsToSave: [],
29077 outputsToSave: [true],
29078 gradFunc: (dy, saved, attrs) => {
29079 const [value] = saved;
29080 const { axis } = attrs;
29081 return {
29082 logits: () => {
29083 const keepDims = true;
29084 const softmax = exp(value);
29085 return sub(dy, mul(sum$1(dy, axis, keepDims), softmax));
29086 }
29087 };
29088 }
29089 };
29090
29091 /**
29092 * @license
29093 * Copyright 2020 Google LLC. All Rights Reserved.
29094 * Licensed under the Apache License, Version 2.0 (the "License");
29095 * you may not use this file except in compliance with the License.
29096 * You may obtain a copy of the License at
29097 *
29098 * http://www.apache.org/licenses/LICENSE-2.0
29099 *
29100 * Unless required by applicable law or agreed to in writing, software
29101 * distributed under the License is distributed on an "AS IS" BASIS,
29102 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29103 * See the License for the specific language governing permissions and
29104 * limitations under the License.
29105 * =============================================================================
29106 */
29107 function localResponseNormalizationBackprop_(x, y, dy, depthRadius = 5, bias = 1, alpha = 1, beta = 0.5) {
29108 const inputs = { x, y, dy };
29109 const attrs = { depthRadius, bias, alpha, beta };
29110 return ENGINE.runKernel(LRNGrad, inputs, attrs);
29111 }
29112 const localResponseNormalizationBackprop = op({ localResponseNormalizationBackprop_ });
29113
29114 /**
29115 * @license
29116 * Copyright 2020 Google LLC. All Rights Reserved.
29117 * Licensed under the Apache License, Version 2.0 (the "License");
29118 * you may not use this file except in compliance with the License.
29119 * You may obtain a copy of the License at
29120 *
29121 * http://www.apache.org/licenses/LICENSE-2.0
29122 *
29123 * Unless required by applicable law or agreed to in writing, software
29124 * distributed under the License is distributed on an "AS IS" BASIS,
29125 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29126 * See the License for the specific language governing permissions and
29127 * limitations under the License.
29128 * =============================================================================
29129 */
29130 const lrnGradConfig = {
29131 kernelName: LRN,
29132 inputsToSave: ['x'],
29133 outputsToSave: [true],
29134 gradFunc: (dy, saved, attrs) => {
29135 const [x, y] = saved;
29136 const { depthRadius, bias, alpha, beta } = attrs;
29137 return {
29138 x: () => localResponseNormalizationBackprop(x, y, dy, depthRadius, bias, alpha, beta)
29139 };
29140 }
29141 };
29142
29143 /**
29144 * @license
29145 * Copyright 2020 Google LLC. All Rights Reserved.
29146 * Licensed under the Apache License, Version 2.0 (the "License");
29147 * you may not use this file except in compliance with the License.
29148 * You may obtain a copy of the License at
29149 *
29150 * http://www.apache.org/licenses/LICENSE-2.0
29151 *
29152 * Unless required by applicable law or agreed to in writing, software
29153 * distributed under the License is distributed on an "AS IS" BASIS,
29154 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29155 * See the License for the specific language governing permissions and
29156 * limitations under the License.
29157 * =============================================================================
29158 */
29159 /**
29160 * Gradient helper function for the min and max operations.
29161 */
29162 function gradForMinAndMax(dy, y, xOrig, origAxes) {
29163 if (y.rank < xOrig.rank) {
29164 y = reshape(y, expandShapeToKeepDim(y.shape, origAxes));
29165 }
29166 if (dy.rank < xOrig.rank) {
29167 dy = reshape(dy, expandShapeToKeepDim(dy.shape, origAxes));
29168 }
29169 return {
29170 x: () => {
29171 const dx = mul(dy, cast(equal(xOrig, y), dy.dtype));
29172 return dx;
29173 }
29174 };
29175 }
29176
29177 /**
29178 * @license
29179 * Copyright 2020 Google LLC. All Rights Reserved.
29180 * Licensed under the Apache License, Version 2.0 (the "License");
29181 * you may not use this file except in compliance with the License.
29182 * You may obtain a copy of the License at
29183 *
29184 * http://www.apache.org/licenses/LICENSE-2.0
29185 *
29186 * Unless required by applicable law or agreed to in writing, software
29187 * distributed under the License is distributed on an "AS IS" BASIS,
29188 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29189 * See the License for the specific language governing permissions and
29190 * limitations under the License.
29191 * =============================================================================
29192 */
29193 const maxGradConfig = {
29194 kernelName: Max,
29195 inputsToSave: ['x'],
29196 outputsToSave: [true],
29197 gradFunc: (dy, saved, attrs) => {
29198 const maxAttrs = attrs;
29199 const { reductionIndices } = maxAttrs;
29200 const x = saved[0];
29201 const y = saved[1];
29202 const origAxes = parseAxisParam(reductionIndices, x.shape);
29203 const maxGrad = gradForMinAndMax(dy, y, x, origAxes);
29204 return {
29205 x: () => {
29206 return maxGrad['x']();
29207 }
29208 };
29209 }
29210 };
29211
29212 /**
29213 * @license
29214 * Copyright 2020 Google LLC. All Rights Reserved.
29215 * Licensed under the Apache License, Version 2.0 (the "License");
29216 * you may not use this file except in compliance with the License.
29217 * You may obtain a copy of the License at
29218 *
29219 * http://www.apache.org/licenses/LICENSE-2.0
29220 *
29221 * Unless required by applicable law or agreed to in writing, software
29222 * distributed under the License is distributed on an "AS IS" BASIS,
29223 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29224 * See the License for the specific language governing permissions and
29225 * limitations under the License.
29226 * =============================================================================
29227 */
29228 const maximumGradConfig = {
29229 kernelName: Maximum,
29230 inputsToSave: ['a', 'b'],
29231 gradFunc: (dy, saved) => {
29232 const [a, b] = saved;
29233 const derA = () => mul(dy, cast(greaterEqual(a, b), 'float32'));
29234 const derB = () => mul(dy, cast(less(a, b), 'float32'));
29235 return { a: derA, b: derB };
29236 }
29237 };
29238
29239 /**
29240 * @license
29241 * Copyright 2020 Google LLC. All Rights Reserved.
29242 * Licensed under the Apache License, Version 2.0 (the "License");
29243 * you may not use this file except in compliance with the License.
29244 * You may obtain a copy of the License at
29245 *
29246 * http://www.apache.org/licenses/LICENSE-2.0
29247 *
29248 * Unless required by applicable law or agreed to in writing, software
29249 * distributed under the License is distributed on an "AS IS" BASIS,
29250 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29251 * See the License for the specific language governing permissions and
29252 * limitations under the License.
29253 * =============================================================================
29254 */
29255 /**
29256 * Computes the backprop of a 3d max pool.
29257 *
29258 * @param dy The dy error, of rank 5 of shape
     * [batchSize, depth, height, width, channels]. If rank 4, a batch of 1 is
     * assumed.
29261 * @param input The original input image, of rank 5 or rank 4 of shape
29262 * [batchSize, depth, height, width, channels].
29263 * @param output The original output image, of rank 5 of shape
29264 * [batchSize, outDepth, outHeight, outWidth, channels].
29265 * @param filterSize The filter size:
29266 * `[filterDepth, filterHeight, filterWidth]`.
29267 * `filterSize` is a single number,
29268 * then `filterDepth == filterHeight == filterWidth`.
29269 * @param strides The strides of the pooling:
29270 * `[strideDepth, strideHeight, strideWidth]`. If
29271 * `strides` is a single number, then `strideHeight == strideWidth`.
29272 * @param pad A string from: 'same', 'valid'. The type of padding algorithm
29273 * used in the forward prop of the op.
29274 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
29275 * provided, it will default to truncate.
29276 */
    function maxPool3dGrad_(dy, input, output, filterSize, strides, pad, dimRoundingMode) {
        // Upgrade TensorLike arguments to Tensors.
        const $dy = convertToTensor(dy, 'dy', 'maxPool3dGrad');
        const $input = convertToTensor(input, 'input', 'maxPool3dGrad');
        const $output = convertToTensor(output, 'output', 'maxPool3dGrad');
        let dy5D = $dy;
        let input5D = $input;
        let output5D = $output;
        let reshapedTo5D = false;
        // A rank-4 input is treated as a single example: prepend a unit batch
        // dimension to all three tensors, and strip it again before returning.
        if ($input.rank === 4) {
            reshapedTo5D = true;
            dy5D = reshape($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]]);
            input5D = reshape($input, [
                1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3]
            ]);
            output5D = reshape($output, [
                1, $output.shape[0], $output.shape[1], $output.shape[2], $output.shape[3]
            ]);
        }
        assert(dy5D.rank === 5, () => `Error in maxPool3dGrad: dy must be rank 5 but got rank ` +
            `${dy5D.rank}.`);
        assert(input5D.rank === 5, () => `Error in maxPool3dGrad: input must be rank 5 but got rank ` +
            `${input5D.rank}.`);
        assert(output5D.rank === 5, () => `Error in maxPool3dGrad: output must be rank 5 but got rank ` +
            `${output5D.rank}.`);
        // Validate that `pad` and `dimRoundingMode` are mutually consistent.
        checkPadOnDimRoundingMode('maxPool3dGrad', pad, dimRoundingMode);
        const inputs = { dy: dy5D, input: input5D, output: output5D };
        const attrs = { filterSize, strides, pad, dimRoundingMode };
        // tslint:disable-next-line: no-unnecessary-type-assertion
        const res = ENGINE.runKernel(MaxPool3DGrad, inputs, attrs);
        if (reshapedTo5D) {
            // Drop the synthetic batch dimension added above.
            return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
        }
        return res;
    }
    const maxPool3dGrad = op({ maxPool3dGrad_ });
29312
29313 /**
29314 * @license
29315 * Copyright 2020 Google LLC. All Rights Reserved.
29316 * Licensed under the Apache License, Version 2.0 (the "License");
29317 * you may not use this file except in compliance with the License.
29318 * You may obtain a copy of the License at
29319 *
29320 * http://www.apache.org/licenses/LICENSE-2.0
29321 *
29322 * Unless required by applicable law or agreed to in writing, software
29323 * distributed under the License is distributed on an "AS IS" BASIS,
29324 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29325 * See the License for the specific language governing permissions and
29326 * limitations under the License.
29327 * =============================================================================
29328 */
29329 const maxPool3DGradConfig = {
29330 kernelName: MaxPool3D,
29331 inputsToSave: ['x'],
29332 outputsToSave: [true],
29333 gradFunc: (dy, saved, attrs) => {
29334 const [x, y] = saved;
29335 const { filterSize, strides, pad, dimRoundingMode } = attrs;
29336 return {
29337 x: () => maxPool3dGrad(dy, x, y, filterSize, strides, pad, dimRoundingMode)
29338 };
29339 }
29340 };
29341
29342 /**
29343 * @license
29344 * Copyright 2020 Google LLC. All Rights Reserved.
29345 * Licensed under the Apache License, Version 2.0 (the "License");
29346 * you may not use this file except in compliance with the License.
29347 * You may obtain a copy of the License at
29348 *
29349 * http://www.apache.org/licenses/LICENSE-2.0
29350 *
29351 * Unless required by applicable law or agreed to in writing, software
29352 * distributed under the License is distributed on an "AS IS" BASIS,
29353 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29354 * See the License for the specific language governing permissions and
29355 * limitations under the License.
29356 * =============================================================================
29357 */
/**
 * Computes the backprop of a 2D max pool.
 *
 * @param dy The dy error, of rank 4 or rank 3 of shape
 *     [batchSize, height, width, channels]. If rank 3, batch of 1 is assumed.
 * @param input The original input image, of rank 4, of shape
 *     [batchSize, height, width, channels].
 * @param output The original output image, of rank 4, of shape
 *     [batchSize, outHeight, outWidth, channels].
 * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
 *     `filterSize` is a single number, then `filterHeight == filterWidth`.
 * @param strides The strides of the pooling: `[strideHeight, strideWidth]`.
 *     If `strides` is a single number, then `strideHeight == strideWidth`.
 * @param pad The type of padding algorithm used in the forward prop of the
 *     op: 'same', 'valid', for more info, see this guide:
 *     [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
 *     https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 */
function maxPoolGrad_(dy, input, output, filterSize, strides, pad, dimRoundingMode) {
    const opName = 'maxPoolGrad';
    const $dy = convertToTensor(dy, 'dy', opName);
    const $input = convertToTensor(input, 'input', opName);
    const $output = convertToTensor(output, 'output', opName);
    // dy and the original input must agree in rank, and both must be rank 4.
    assert($input.rank === $dy.rank, () => `Rank of input (${$input.rank}) does not match rank of dy (${$dy.rank})`);
    assert($dy.rank === 4, () => `Error in maxPoolGrad: dy must be rank 4 but got rank ${$dy.rank}.`);
    assert($input.rank === 4, () => `Error in maxPoolGrad: input must be rank 4 but got rank ${$input.rank}.`);
    checkPadOnDimRoundingMode(opName, pad, dimRoundingMode);
    const inputs = { dy: $dy, input: $input, output: $output };
    const attrs = { filterSize, strides, pad, dimRoundingMode };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    return ENGINE.runKernel(MaxPoolGrad, inputs, attrs);
}
const maxPoolGrad = op({ maxPoolGrad_ });
29396
29397 /**
29398 * @license
29399 * Copyright 2020 Google LLC. All Rights Reserved.
29400 * Licensed under the Apache License, Version 2.0 (the "License");
29401 * you may not use this file except in compliance with the License.
29402 * You may obtain a copy of the License at
29403 *
29404 * http://www.apache.org/licenses/LICENSE-2.0
29405 *
29406 * Unless required by applicable law or agreed to in writing, software
29407 * distributed under the License is distributed on an "AS IS" BASIS,
29408 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29409 * See the License for the specific language governing permissions and
29410 * limitations under the License.
29411 * =============================================================================
29412 */
/**
 * Gradient config for the MaxPool kernel: routes the upstream gradient
 * through maxPoolGrad using the saved input and output tensors.
 */
const maxPoolGradConfig = {
    kernelName: MaxPool,
    inputsToSave: ['x'],
    outputsToSave: [true],
    gradFunc: (dy, saved, attrs) => {
        const [x, y] = saved;
        // Also forward dimRoundingMode so the backward pass applies the same
        // output-shape rounding as the forward op (consistent with
        // maxPool3DGradConfig; `undefined` preserves the old behavior).
        const { filterSize, strides, pad, dimRoundingMode } = attrs;
        return {
            x: () => maxPoolGrad(dy, x, y, filterSize, strides, pad, dimRoundingMode)
        };
    }
};
29425
29426 /**
29427 * @license
29428 * Copyright 2020 Google LLC. All Rights Reserved.
29429 * Licensed under the Apache License, Version 2.0 (the "License");
29430 * you may not use this file except in compliance with the License.
29431 * You may obtain a copy of the License at
29432 *
29433 * http://www.apache.org/licenses/LICENSE-2.0
29434 *
29435 * Unless required by applicable law or agreed to in writing, software
29436 * distributed under the License is distributed on an "AS IS" BASIS,
29437 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29438 * See the License for the specific language governing permissions and
29439 * limitations under the License.
29440 * =============================================================================
29441 */
/**
 * Gradient config for Mean: distributes dy evenly across every element that
 * was averaged, i.e. dx = broadcast(dy) / reduceSize.
 */
const meanGradConfig = {
    kernelName: Mean,
    inputsToSave: ['x'],
    gradFunc: (dy, saved, attrs) => {
        const [x] = saved;
        const { axis } = attrs;
        const axes = parseAxisParam(axis, x.shape);
        const [, reduceShape] = computeOutAndReduceShapes(x.shape, axes);
        const reduceSize = sizeFromShape(reduceShape);
        const derX = () => {
            // Re-insert the reduced axes as size-1 dims so dy broadcasts over x.
            const expandedDyShape = x.shape.slice();
            for (const ax of axes) {
                expandedDyShape[ax] = 1;
            }
            const expandedDy = reshape(dy, expandedDyShape);
            return div(mul(expandedDy, ones$1(x.shape, 'float32')), reduceSize);
        };
        return { x: derX };
    }
};
29464
29465 /**
29466 * @license
29467 * Copyright 2020 Google LLC. All Rights Reserved.
29468 * Licensed under the Apache License, Version 2.0 (the "License");
29469 * you may not use this file except in compliance with the License.
29470 * You may obtain a copy of the License at
29471 *
29472 * http://www.apache.org/licenses/LICENSE-2.0
29473 *
29474 * Unless required by applicable law or agreed to in writing, software
29475 * distributed under the License is distributed on an "AS IS" BASIS,
29476 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29477 * See the License for the specific language governing permissions and
29478 * limitations under the License.
29479 * =============================================================================
29480 */
/**
 * Gradient config for Min: delegates to the shared min/max gradient helper,
 * which routes dy to the positions that produced the minimum.
 */
const minGradConfig = {
    kernelName: Min,
    inputsToSave: ['x'],
    outputsToSave: [true],
    gradFunc: (dy, saved, attrs) => {
        const { axis } = attrs;
        const [x, y] = saved;
        const origAxes = parseAxisParam(axis, x.shape);
        const grad = gradForMinAndMax(dy, y, x, origAxes);
        return { x: () => grad['x']() };
    }
};
29498
29499 /**
29500 * @license
29501 * Copyright 2020 Google LLC. All Rights Reserved.
29502 * Licensed under the Apache License, Version 2.0 (the "License");
29503 * you may not use this file except in compliance with the License.
29504 * You may obtain a copy of the License at
29505 *
29506 * http://www.apache.org/licenses/LICENSE-2.0
29507 *
29508 * Unless required by applicable law or agreed to in writing, software
29509 * distributed under the License is distributed on an "AS IS" BASIS,
29510 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29511 * See the License for the specific language governing permissions and
29512 * limitations under the License.
29513 * =============================================================================
29514 */
/**
 * Gradient config for Minimum: dy flows to `a` where a <= b and to `b`
 * where a > b (ties go to `a`).
 */
const minimumGradConfig = {
    kernelName: Minimum,
    inputsToSave: ['a', 'b'],
    gradFunc: (dy, saved) => {
        const [a, b] = saved;
        return {
            a: () => mul(dy, cast(lessEqual(a, b), 'float32')),
            b: () => mul(dy, cast(greater(a, b), 'float32'))
        };
    }
};
29525
29526 /**
29527 * @license
29528 * Copyright 2020 Google LLC. All Rights Reserved.
29529 * Licensed under the Apache License, Version 2.0 (the "License");
29530 * you may not use this file except in compliance with the License.
29531 * You may obtain a copy of the License at
29532 *
29533 * http://www.apache.org/licenses/LICENSE-2.0
29534 *
29535 * Unless required by applicable law or agreed to in writing, software
29536 * distributed under the License is distributed on an "AS IS" BASIS,
29537 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29538 * See the License for the specific language governing permissions and
29539 * limitations under the License.
29540 * =============================================================================
29541 */
/**
 * Gradient config for MirrorPad: padding adds values around the original
 * tensor, so the gradient slices the original region back out of dy.
 */
const mirrorPadGradConfig = {
    kernelName: MirrorPad,
    inputsToSave: ['x'],
    gradFunc: (dy, saved, attrs) => {
        const [x] = saved;
        const { paddings } = attrs;
        // Start of the un-padded region along each dimension.
        const sliceBegin = paddings.map(pair => pair[0]);
        return { x: () => slice(dy, sliceBegin, x.shape) };
    }
};
29554
29555 /**
29556 * @license
29557 * Copyright 2020 Google LLC. All Rights Reserved.
29558 * Licensed under the Apache License, Version 2.0 (the "License");
29559 * you may not use this file except in compliance with the License.
29560 * You may obtain a copy of the License at
29561 *
29562 * http://www.apache.org/licenses/LICENSE-2.0
29563 *
29564 * Unless required by applicable law or agreed to in writing, software
29565 * distributed under the License is distributed on an "AS IS" BASIS,
29566 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29567 * See the License for the specific language governing permissions and
29568 * limitations under the License.
29569 * =============================================================================
29570 */
/**
 * Gradient config for Mod (a - b * floor(a / b)):
 * da = dy and db = -floor(a / b) * dy, each summed over broadcast axes and
 * reshaped back to the operand's shape.
 */
const modGradConfig = {
    kernelName: Mod,
    inputsToSave: ['a', 'b'],
    gradFunc: (dy, saved) => {
        const [a, b] = saved;
        const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
        const derA = () => {
            const axes = getReductionAxes(a.shape, outShape);
            return axes.length > 0 ? reshape(sum$1(dy, axes), a.shape) : dy;
        };
        const derB = () => {
            const res = mul(dy, neg(floor(div(a, b))));
            const axes = getReductionAxes(b.shape, outShape);
            return axes.length > 0 ? reshape(sum$1(res, axes), b.shape) : res;
        };
        return { a: derA, b: derB };
    }
};
29595
29596 /**
29597 * @license
29598 * Copyright 2020 Google LLC. All Rights Reserved.
29599 * Licensed under the Apache License, Version 2.0 (the "License");
29600 * you may not use this file except in compliance with the License.
29601 * You may obtain a copy of the License at
29602 *
29603 * http://www.apache.org/licenses/LICENSE-2.0
29604 *
29605 * Unless required by applicable law or agreed to in writing, software
29606 * distributed under the License is distributed on an "AS IS" BASIS,
29607 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29608 * See the License for the specific language governing permissions and
29609 * limitations under the License.
29610 * =============================================================================
29611 */
/**
 * Gradient config for Multiply: d(a*b)/da = b and d(a*b)/db = a, with
 * broadcast axes summed out and reshaped back to each operand's shape.
 */
const multiplyGradConfig = {
    kernelName: Multiply,
    inputsToSave: ['a', 'b'],
    gradFunc: (dy, saved) => {
        const [a, b] = saved;
        const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
        const derA = () => {
            const res = mul(dy, cast(b, 'float32'));
            const axes = getReductionAxes(a.shape, outShape);
            return axes.length > 0 ? reshape(sum$1(res, axes), a.shape) : res;
        };
        const derB = () => {
            const res = mul(dy, cast(a, 'float32'));
            const axes = getReductionAxes(b.shape, outShape);
            return axes.length > 0 ? reshape(sum$1(res, axes), b.shape) : res;
        };
        return { a: derA, b: derB };
    }
};
29637
29638 /**
29639 * @license
29640 * Copyright 2020 Google LLC. All Rights Reserved.
29641 * Licensed under the Apache License, Version 2.0 (the "License");
29642 * you may not use this file except in compliance with the License.
29643 * You may obtain a copy of the License at
29644 *
29645 * http://www.apache.org/licenses/LICENSE-2.0
29646 *
29647 * Unless required by applicable law or agreed to in writing, software
29648 * distributed under the License is distributed on an "AS IS" BASIS,
29649 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29650 * See the License for the specific language governing permissions and
29651 * limitations under the License.
29652 * =============================================================================
29653 */
/** Gradient config for Neg: d(-x)/dx = -1, so dx = -dy. */
const negGradConfig = {
    kernelName: Neg,
    gradFunc: (dy) => ({ x: () => neg(dy) })
};
29660
29661 /**
29662 * @license
29663 * Copyright 2020 Google LLC. All Rights Reserved.
29664 * Licensed under the Apache License, Version 2.0 (the "License");
29665 * you may not use this file except in compliance with the License.
29666 * You may obtain a copy of the License at
29667 *
29668 * http://www.apache.org/licenses/LICENSE-2.0
29669 *
29670 * Unless required by applicable law or agreed to in writing, software
29671 * distributed under the License is distributed on an "AS IS" BASIS,
29672 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29673 * See the License for the specific language governing permissions and
29674 * limitations under the License.
29675 * =============================================================================
29676 */
/**
 * Gradient config for OneHot: indices are integral, so their gradient is
 * defined as all zeros.
 */
const oneHotGradConfig = {
    kernelName: OneHot,
    inputsToSave: ['indices'],
    gradFunc: (dy, saved) => {
        const [indices] = saved;
        return { indices: () => zeros(indices.shape, 'float32') };
    }
};
29685
29686 /**
29687 * @license
29688 * Copyright 2020 Google LLC. All Rights Reserved.
29689 * Licensed under the Apache License, Version 2.0 (the "License");
29690 * you may not use this file except in compliance with the License.
29691 * You may obtain a copy of the License at
29692 *
29693 * http://www.apache.org/licenses/LICENSE-2.0
29694 *
29695 * Unless required by applicable law or agreed to in writing, software
29696 * distributed under the License is distributed on an "AS IS" BASIS,
29697 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29698 * See the License for the specific language governing permissions and
29699 * limitations under the License.
29700 * =============================================================================
29701 */
/** Gradient config for OnesLike: the output is constant, so dx is zero. */
const onesLikeGradConfig = {
    kernelName: OnesLike,
    gradFunc: (dy) => ({ x: () => zerosLike(dy) })
};
29708
29709 /**
29710 * @license
29711 * Copyright 2020 Google LLC. All Rights Reserved.
29712 * Licensed under the Apache License, Version 2.0 (the "License");
29713 * you may not use this file except in compliance with the License.
29714 * You may obtain a copy of the License at
29715 *
29716 * http://www.apache.org/licenses/LICENSE-2.0
29717 *
29718 * Unless required by applicable law or agreed to in writing, software
29719 * distributed under the License is distributed on an "AS IS" BASIS,
29720 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29721 * See the License for the specific language governing permissions and
29722 * limitations under the License.
29723 * =============================================================================
29724 */
/**
 * Gradient config for Pack (stack): unstacks dy along the packed axis,
 * yielding one gradient function per input tensor.
 */
const packGradConfig = {
    kernelName: Pack,
    saveAllInputs: true,
    gradFunc: (dy, saved, attrs) => {
        const { axis } = attrs;
        return unstack(dy, axis).map(t => () => t);
    }
};
29734
29735 /**
29736 * @license
29737 * Copyright 2020 Google LLC. All Rights Reserved.
29738 * Licensed under the Apache License, Version 2.0 (the "License");
29739 * you may not use this file except in compliance with the License.
29740 * You may obtain a copy of the License at
29741 *
29742 * http://www.apache.org/licenses/LICENSE-2.0
29743 *
29744 * Unless required by applicable law or agreed to in writing, software
29745 * distributed under the License is distributed on an "AS IS" BASIS,
29746 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29747 * See the License for the specific language governing permissions and
29748 * limitations under the License.
29749 * =============================================================================
29750 */
/**
 * Gradient config for PadV2: padding adds values around the original tensor,
 * so the gradient slices the original region back out of dy.
 */
const padV2GradConfig = {
    kernelName: PadV2,
    inputsToSave: ['x'],
    gradFunc: (dy, saved, attrs) => {
        const [x] = saved;
        const { paddings } = attrs;
        // Start of the un-padded region along each dimension.
        const sliceBegin = paddings.map(pair => pair[0]);
        return { x: () => slice(dy, sliceBegin, x.shape) };
    }
};
29763
29764 /**
29765 * @license
29766 * Copyright 2020 Google LLC. All Rights Reserved.
29767 * Licensed under the Apache License, Version 2.0 (the "License");
29768 * you may not use this file except in compliance with the License.
29769 * You may obtain a copy of the License at
29770 *
29771 * http://www.apache.org/licenses/LICENSE-2.0
29772 *
29773 * Unless required by applicable law or agreed to in writing, software
29774 * distributed under the License is distributed on an "AS IS" BASIS,
29775 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29776 * See the License for the specific language governing permissions and
29777 * limitations under the License.
29778 * =============================================================================
29779 */
/**
 * Gradient config for Pow: d(a^b)/da = b * a^(b-1) and d(a^b)/db = a^b *
 * ln(a). Broadcast axes are summed out before reshaping back to each
 * operand's shape; ln(a) is only taken where a > 0 (zero elsewhere).
 */
const powGradConfig = {
    kernelName: Pow,
    inputsToSave: ['a', 'b'],
    outputsToSave: [true],
    gradFunc: (dy, saved) => {
        const [base, exp, y] = saved;
        const outShape = assertAndGetBroadcastShape(base.shape, exp.shape);
        const derBase = () => {
            const expFloat = cast(exp, 'float32');
            let res = mul(dy, mul(expFloat, pow(base, sub(expFloat, scalar(1)))));
            const reduceAxes = getReductionAxes(base.shape, outShape);
            if (reduceAxes.length > 0) {
                res = sum$1(res, reduceAxes);
            }
            return reshape(res, base.shape);
        };
        const derExp = () => {
            // Guard against log of non-positive base values.
            const positive = greater(base, 0);
            const logBase = where(positive, log$1(base), zerosLike(base));
            let res = mul(dy, mul(y, logBase));
            const reduceAxes = getReductionAxes(exp.shape, outShape);
            if (reduceAxes.length > 0) {
                res = sum$1(res, reduceAxes);
            }
            return reshape(res, exp.shape);
        };
        return { a: derBase, b: derExp };
    }
};
29811
29812 /**
29813 * @license
29814 * Copyright 2020 Google LLC. All Rights Reserved.
29815 * Licensed under the Apache License, Version 2.0 (the "License");
29816 * you may not use this file except in compliance with the License.
29817 * You may obtain a copy of the License at
29818 *
29819 * http://www.apache.org/licenses/LICENSE-2.0
29820 *
29821 * Unless required by applicable law or agreed to in writing, software
29822 * distributed under the License is distributed on an "AS IS" BASIS,
29823 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29824 * See the License for the specific language governing permissions and
29825 * limitations under the License.
29826 * =============================================================================
29827 */
/**
 * Gradient config for Prelu: dy passes through where x > 0; elsewhere it is
 * scaled by alpha (for x) or by x (for alpha), with alpha's broadcast axes
 * summed out.
 */
const preluGradConfig = {
    kernelName: Prelu,
    inputsToSave: ['x', 'alpha'],
    gradFunc: (dy, saved) => {
        const [x, alpha] = saved;
        const positive = greater(x, 0);
        const derX = () => where(positive, dy, mul(dy, alpha));
        const derAlpha = () => {
            let res = where(positive, zerosLike(dy), mul(dy, x));
            const reduceAxes = getReductionAxes(alpha.shape, dy.shape);
            if (reduceAxes.length > 0) {
                res = sum$1(res, reduceAxes);
            }
            return reshape(res, alpha.shape);
        };
        return { x: derX, alpha: derAlpha };
    }
};
29847
29848 /**
29849 * @license
29850 * Copyright 2022 Google Inc. All Rights Reserved.
29851 * Licensed under the Apache License, Version 2.0 (the "License");
29852 * you may not use this file except in compliance with the License.
29853 * You may obtain a copy of the License at
29854 *
29855 * http://www.apache.org/licenses/LICENSE-2.0
29856 *
29857 * Unless required by applicable law or agreed to in writing, software
29858 * distributed under the License is distributed on an "AS IS" BASIS,
29859 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29860 * See the License for the specific language governing permissions and
29861 * limitations under the License.
29862 * =============================================================================
29863 */
29864 // Gradient for product operation on a single axis.
/**
 * Gradient of prod over a single axis: dx_i = dy * (product of all other
 * elements along the axis), computed as the exclusive cumulative product
 * from the left times the exclusive cumulative product from the right.
 */
function prodGradFn_(x, dy, axis) {
    // Re-insert the reduced axis as a size-1 dim so dy broadcasts over x.
    const expandedYShape = x.shape.slice();
    expandedYShape[axis] = 1;
    const expandedDy = reshape(dy, expandedYShape);
    // Exclusive products before and after each position along `axis`.
    const leftProd = cumprod(x, axis, true, false);
    const rightProd = cumprod(x, axis, true, true);
    return mul(expandedDy, mul(leftProd, rightProd));
}
29878 // Support gradients when the product is done on many axes at once.
29879 // This done py pushing all the axes on which the product is applied into a
29880 // single axis.
/**
 * Gradient of prod over multiple axes at once: permutes the reduced axes to
 * the end of the tensor, collapses them into a single axis, applies the
 * single-axis gradient, then undoes the reshape and the permutation.
 */
function prodsGradFn_(x, dy, axis) {
    // Move all axes for doing prod over to the end of the tensor.
    const xRank = x.shape.length;
    const finalProdAxis = xRank - axis.length;
    const xPermutation = getAxesPermutation(axis, xRank);
    let permutedX = x;
    if (xPermutation != null) {
        permutedX = transpose(x, xPermutation);
    }
    // Reshape all the prod dimensions into a single one, and compute prod
    // gradients on that.
    const newShape = permutedX.shape.slice();
    const removedShape = newShape.splice(xRank - axis.length, axis.length);
    const endPartShape = removedShape.reduce((p, c) => p * c, 1);
    newShape.push(endPartShape);
    // Use the functional `reshape` op — consistent with the rest of this
    // module — instead of the chained tensor method, which depends on
    // chained-op registration.
    const reshapedPermutedX = reshape(permutedX, newShape);
    let prodGrad = prodGradFn_(reshapedPermutedX, dy, finalProdAxis);
    // Undo the re-shaping now we have the dx vector, and permute back to
    // original axes order.
    prodGrad = reshape(prodGrad, permutedX.shape);
    if (xPermutation != null) {
        const undoPermutation = getUndoAxesPermutation(xPermutation);
        prodGrad = transpose(prodGrad, undoPermutation);
    }
    return prodGrad;
}
29907 // Running example:
29908 // [
29909 // [
29910 // [3.0, 4.0],
29911 // [5.0, 6.0],
29912 // [7.0, 8.0]
29913 // ],
29914 // [
29915 // [3.0, 5.0],
29916 // [0.0, 6.0],
29917 // [5.0, 6.0]
29918 // ]
29919 // ]
29920 //
/**
 * Gradient config for Prod: normalizes the axis attr to an array (null or
 * undefined means "all axes") and delegates to the multi-axis prod gradient.
 */
const prodGradConfig = {
    kernelName: Prod,
    inputsToSave: ['x'],
    gradFunc: (dy, saved, attrs) => {
        const [x] = saved;
        const { axis } = attrs;
        let axisArr;
        if (axis == null) {
            // Reduce over every axis of x.
            axisArr = x.shape.map((_, i) => i);
        }
        else {
            axisArr = typeof axis === 'number' ? [axis] : axis;
        }
        return { x: () => prodsGradFn_(x, dy, axisArr) };
    }
};
29940
29941 /**
29942 * @license
29943 * Copyright 2020 Google LLC. All Rights Reserved.
29944 * Licensed under the Apache License, Version 2.0 (the "License");
29945 * you may not use this file except in compliance with the License.
29946 * You may obtain a copy of the License at
29947 *
29948 * http://www.apache.org/licenses/LICENSE-2.0
29949 *
29950 * Unless required by applicable law or agreed to in writing, software
29951 * distributed under the License is distributed on an "AS IS" BASIS,
29952 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29953 * See the License for the specific language governing permissions and
29954 * limitations under the License.
29955 * =============================================================================
29956 */
/**
 * Gradient config for RealDiv: d(a/b)/da = 1/b and d(a/b)/db = -a/b^2,
 * with broadcast axes summed out and reshaped back to each operand's shape.
 */
const divGradConfig = {
    kernelName: RealDiv,
    inputsToSave: ['a', 'b'],
    gradFunc: (dy, saved) => {
        const [a, b] = saved;
        const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
        const derA = () => {
            const res = div(dy, cast(b, 'float32'));
            const axes = getReductionAxes(a.shape, outShape);
            return axes.length > 0 ? reshape(sum$1(res, axes), a.shape) : res;
        };
        const derB = () => {
            let res = mul(dy, cast(a, 'float32'));
            const axes = getReductionAxes(b.shape, outShape);
            if (axes.length > 0) {
                res = reshape(sum$1(res, axes), b.shape);
            }
            return neg(div(res, cast(square(b), 'float32')));
        };
        return { a: derA, b: derB };
    }
};
29983
29984 /**
29985 * @license
29986 * Copyright 2020 Google LLC. All Rights Reserved.
29987 * Licensed under the Apache License, Version 2.0 (the "License");
29988 * you may not use this file except in compliance with the License.
29989 * You may obtain a copy of the License at
29990 *
29991 * http://www.apache.org/licenses/LICENSE-2.0
29992 *
29993 * Unless required by applicable law or agreed to in writing, software
29994 * distributed under the License is distributed on an "AS IS" BASIS,
29995 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29996 * See the License for the specific language governing permissions and
29997 * limitations under the License.
29998 * =============================================================================
29999 */
/** Gradient config for Reciprocal: d(1/x)/dx = -1 / x^2. */
const reciprocalGradConfig = {
    kernelName: Reciprocal,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        const derX = () => div(dy, neg(square(x)));
        return { x: derX };
    }
};
30008
30009 /**
30010 * @license
30011 * Copyright 2020 Google LLC. All Rights Reserved.
30012 * Licensed under the Apache License, Version 2.0 (the "License");
30013 * you may not use this file except in compliance with the License.
30014 * You may obtain a copy of the License at
30015 *
30016 * http://www.apache.org/licenses/LICENSE-2.0
30017 *
30018 * Unless required by applicable law or agreed to in writing, software
30019 * distributed under the License is distributed on an "AS IS" BASIS,
30020 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30021 * See the License for the specific language governing permissions and
30022 * limitations under the License.
30023 * =============================================================================
30024 */
/**
 * Gradient config for Relu6: dy passes through only where the input was in
 * the active range (x > 0 via step, and x <= 6).
 */
const relu6GradConfig = {
    kernelName: Relu6,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        const activeMask = mul(lessEqual(x, 6), step(x));
        const derX = () => mul(dy, cast(activeMask, 'float32'));
        return { x: derX };
    }
};
30034
30035 /**
30036 * @license
30037 * Copyright 2020 Google LLC. All Rights Reserved.
30038 * Licensed under the Apache License, Version 2.0 (the "License");
30039 * you may not use this file except in compliance with the License.
30040 * You may obtain a copy of the License at
30041 *
30042 * http://www.apache.org/licenses/LICENSE-2.0
30043 *
30044 * Unless required by applicable law or agreed to in writing, software
30045 * distributed under the License is distributed on an "AS IS" BASIS,
30046 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30047 * See the License for the specific language governing permissions and
30048 * limitations under the License.
30049 * =============================================================================
30050 */
/**
 * Gradient config for Relu: dy passes through where x was positive
 * (step(x) masks out the non-positive region).
 */
const reluGradConfig = {
    kernelName: Relu,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        const derX = () => mul(dy, cast(step(x), 'float32'));
        return { x: derX };
    }
};
30059
30060 /**
30061 * @license
30062 * Copyright 2020 Google Inc. All Rights Reserved.
30063 * Licensed under the Apache License, Version 2.0 (the "License");
30064 * you may not use this file except in compliance with the License.
30065 * You may obtain a copy of the License at
30066 *
30067 * http://www.apache.org/licenses/LICENSE-2.0
30068 *
30069 * Unless required by applicable law or agreed to in writing, software
30070 * distributed under the License is distributed on an "AS IS" BASIS,
30071 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30072 * See the License for the specific language governing permissions and
30073 * limitations under the License.
30074 * =============================================================================
30075 */
/** Gradient config for Reshape: reshape dy back to the input's shape. */
const reshapeGradConfig = {
    kernelName: Reshape,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        const derX = () => reshape(dy, x.shape);
        return { x: derX };
    }
};
30084
30085 /**
30086 * @license
30087 * Copyright 2020 Google LLC. All Rights Reserved.
30088 * Licensed under the Apache License, Version 2.0 (the "License");
30089 * you may not use this file except in compliance with the License.
30090 * You may obtain a copy of the License at
30091 *
30092 * http://www.apache.org/licenses/LICENSE-2.0
30093 *
30094 * Unless required by applicable law or agreed to in writing, software
30095 * distributed under the License is distributed on an "AS IS" BASIS,
30096 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30097 * See the License for the specific language governing permissions and
30098 * limitations under the License.
30099 * =============================================================================
30100 */
/**
 * Gradient config for ResizeBilinear: delegates to the dedicated
 * ResizeBilinearGrad kernel with the forward op's attrs.
 */
const resizeBilinearGradConfig = {
    kernelName: ResizeBilinear,
    inputsToSave: ['images'],
    gradFunc: (dy, saved, attrs) => {
        const [images] = saved;
        const derImages = () => {
            const inputs = { dy, images };
            // tslint:disable-next-line: no-unnecessary-type-assertion
            return ENGINE.runKernel(ResizeBilinearGrad, inputs, attrs);
        };
        return { images: derImages };
    }
};
30113
30114 /**
30115 * @license
30116 * Copyright 2020 Google LLC. All Rights Reserved.
30117 * Licensed under the Apache License, Version 2.0 (the "License");
30118 * you may not use this file except in compliance with the License.
30119 * You may obtain a copy of the License at
30120 *
30121 * http://www.apache.org/licenses/LICENSE-2.0
30122 *
30123 * Unless required by applicable law or agreed to in writing, software
30124 * distributed under the License is distributed on an "AS IS" BASIS,
30125 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30126 * See the License for the specific language governing permissions and
30127 * limitations under the License.
30128 * =============================================================================
30129 */
/**
 * Gradient config for ResizeNearestNeighbor: delegates to the dedicated
 * ResizeNearestNeighborGrad kernel with the forward op's attrs.
 */
const resizeNearestNeighborGradConfig = {
    kernelName: ResizeNearestNeighbor,
    inputsToSave: ['images'],
    gradFunc: (dy, saved, attrs) => {
        const [images] = saved;
        const derImages = () => {
            const inputs = { dy, images };
            // tslint:disable-next-line: no-unnecessary-type-assertion
            return ENGINE.runKernel(ResizeNearestNeighborGrad, inputs, attrs);
        };
        return { images: derImages };
    }
};
30142
30143 /**
30144 * @license
30145 * Copyright 2020 Google LLC. All Rights Reserved.
30146 * Licensed under the Apache License, Version 2.0 (the "License");
30147 * you may not use this file except in compliance with the License.
30148 * You may obtain a copy of the License at
30149 *
30150 * http://www.apache.org/licenses/LICENSE-2.0
30151 *
30152 * Unless required by applicable law or agreed to in writing, software
30153 * distributed under the License is distributed on an "AS IS" BASIS,
30154 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30155 * See the License for the specific language governing permissions and
30156 * limitations under the License.
30157 * =============================================================================
30158 */
30159 const reverseGradConfig = {
30160 kernelName: Reverse,
30161 gradFunc: (dy, saved, attrs) => {
30162 const { dims } = attrs;
30163 const axes = parseAxisParam(dims, dy.shape);
30164 return { x: () => reverse(dy, axes) };
30165 }
30166 };
30167
30168 /**
30169 * @license
30170 * Copyright 2020 Google LLC. All Rights Reserved.
30171 * Licensed under the Apache License, Version 2.0 (the "License");
30172 * you may not use this file except in compliance with the License.
30173 * You may obtain a copy of the License at
30174 *
30175 * http://www.apache.org/licenses/LICENSE-2.0
30176 *
30177 * Unless required by applicable law or agreed to in writing, software
30178 * distributed under the License is distributed on an "AS IS" BASIS,
30179 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30180 * See the License for the specific language governing permissions and
30181 * limitations under the License.
30182 * =============================================================================
30183 */
30184 const roundGradConfig = {
30185 kernelName: Round,
30186 gradFunc: (dy) => {
30187 // TODO(nsthorat): Let gradients be null for cases where we want to stop
30188 // backpropgation.
30189 return { x: () => zerosLike(dy) };
30190 }
30191 };
30192
30193 /**
30194 * @license
30195 * Copyright 2020 Google LLC. All Rights Reserved.
30196 * Licensed under the Apache License, Version 2.0 (the "License");
30197 * you may not use this file except in compliance with the License.
30198 * You may obtain a copy of the License at
30199 *
30200 * http://www.apache.org/licenses/LICENSE-2.0
30201 *
30202 * Unless required by applicable law or agreed to in writing, software
30203 * distributed under the License is distributed on an "AS IS" BASIS,
30204 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30205 * See the License for the specific language governing permissions and
30206 * limitations under the License.
30207 * =============================================================================
30208 */
30209 const rsqrtGradConfig = {
30210 kernelName: Rsqrt,
30211 inputsToSave: ['x'],
30212 gradFunc: (dy, saved) => {
30213 const [x] = saved;
30214 return { x: () => neg(div(dy, mul(pow(x, 1.5), 2))) };
30215 }
30216 };
30217
30218 /**
30219 * @license
30220 * Copyright 2020 Google LLC. All Rights Reserved.
30221 * Licensed under the Apache License, Version 2.0 (the "License");
30222 * you may not use this file except in compliance with the License.
30223 * You may obtain a copy of the License at
30224 *
30225 * http://www.apache.org/licenses/LICENSE-2.0
30226 *
30227 * Unless required by applicable law or agreed to in writing, software
30228 * distributed under the License is distributed on an "AS IS" BASIS,
30229 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30230 * See the License for the specific language governing permissions and
30231 * limitations under the License.
30232 * =============================================================================
30233 */
30234 const selectGradConfig = {
30235 kernelName: Select,
30236 inputsToSave: ['condition'],
30237 gradFunc: (dy, saved) => {
30238 const [condition] = saved;
30239 return {
30240 // TODO(julianoks): Return null for condition gradient
30241 // when backprop supports it.
30242 condition: () => cast(zerosLike(condition), 'float32'),
30243 t: () => mul(dy, cast(condition, dy.dtype)),
30244 e: () => mul(dy, cast(logicalNot(condition), dy.dtype))
30245 };
30246 }
30247 };
30248
30249 /**
30250 * @license
30251 * Copyright 2020 Google LLC. All Rights Reserved.
30252 * Licensed under the Apache License, Version 2.0 (the "License");
30253 * you may not use this file except in compliance with the License.
30254 * You may obtain a copy of the License at
30255 *
30256 * http://www.apache.org/licenses/LICENSE-2.0
30257 *
30258 * Unless required by applicable law or agreed to in writing, software
30259 * distributed under the License is distributed on an "AS IS" BASIS,
30260 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30261 * See the License for the specific language governing permissions and
30262 * limitations under the License.
30263 * =============================================================================
30264 */
30265 const seluGradConfig = {
30266 kernelName: Selu,
30267 inputsToSave: ['x'],
30268 gradFunc: (dy, saved) => {
30269 const [x] = saved;
30270 return {
30271 x: () => {
30272 const mask = greater(x, scalar(0));
30273 const scaleAlpha = scalar(SELU_SCALEALPHA);
30274 const scale = scalar(SELU_SCALE);
30275 const greaterThanZeroDer = mul(dy, scale);
30276 const lessEqualZeroDer = mul(mul(dy, scaleAlpha), exp(cast(x, 'float32')));
30277 return where(mask, greaterThanZeroDer, lessEqualZeroDer);
30278 }
30279 };
30280 }
30281 };
30282
30283 /**
30284 * @license
30285 * Copyright 2020 Google LLC. All Rights Reserved.
30286 * Licensed under the Apache License, Version 2.0 (the "License");
30287 * you may not use this file except in compliance with the License.
30288 * You may obtain a copy of the License at
30289 *
30290 * http://www.apache.org/licenses/LICENSE-2.0
30291 *
30292 * Unless required by applicable law or agreed to in writing, software
30293 * distributed under the License is distributed on an "AS IS" BASIS,
30294 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30295 * See the License for the specific language governing permissions and
30296 * limitations under the License.
30297 * =============================================================================
30298 */
30299 const sigmoidGradConfig = {
30300 kernelName: Sigmoid,
30301 outputsToSave: [true],
30302 gradFunc: (dy, saved) => {
30303 const [y] = saved;
30304 return { x: () => mul(dy, mul(y, sub(scalar(1), y))) };
30305 }
30306 };
30307
30308 /**
30309 * @license
30310 * Copyright 2020 Google LLC. All Rights Reserved.
30311 * Licensed under the Apache License, Version 2.0 (the "License");
30312 * you may not use this file except in compliance with the License.
30313 * You may obtain a copy of the License at
30314 *
30315 * http://www.apache.org/licenses/LICENSE-2.0
30316 *
30317 * Unless required by applicable law or agreed to in writing, software
30318 * distributed under the License is distributed on an "AS IS" BASIS,
30319 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30320 * See the License for the specific language governing permissions and
30321 * limitations under the License.
30322 * =============================================================================
30323 */
30324 const signGradConfig = {
30325 kernelName: Sign,
30326 gradFunc: (dy) => {
30327 return { x: () => zerosLike(dy) };
30328 }
30329 };
30330
30331 /**
30332 * @license
30333 * Copyright 2020 Google LLC. All Rights Reserved.
30334 * Licensed under the Apache License, Version 2.0 (the "License");
30335 * you may not use this file except in compliance with the License.
30336 * You may obtain a copy of the License at
30337 *
30338 * http://www.apache.org/licenses/LICENSE-2.0
30339 *
30340 * Unless required by applicable law or agreed to in writing, software
30341 * distributed under the License is distributed on an "AS IS" BASIS,
30342 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30343 * See the License for the specific language governing permissions and
30344 * limitations under the License.
30345 * =============================================================================
30346 */
30347 const sinGradConfig = {
30348 kernelName: Sin,
30349 inputsToSave: ['x'],
30350 gradFunc: (dy, saved) => {
30351 const [x] = saved;
30352 return { x: () => mul(cos(cast(x, 'float32')), dy) };
30353 }
30354 };
30355
30356 /**
30357 * @license
30358 * Copyright 2020 Google LLC. All Rights Reserved.
30359 * Licensed under the Apache License, Version 2.0 (the "License");
30360 * you may not use this file except in compliance with the License.
30361 * You may obtain a copy of the License at
30362 *
30363 * http://www.apache.org/licenses/LICENSE-2.0
30364 *
30365 * Unless required by applicable law or agreed to in writing, software
30366 * distributed under the License is distributed on an "AS IS" BASIS,
30367 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30368 * See the License for the specific language governing permissions and
30369 * limitations under the License.
30370 * =============================================================================
30371 */
30372 const sinhGradConfig = {
30373 kernelName: Sinh,
30374 inputsToSave: ['x'],
30375 gradFunc: (dy, saved) => {
30376 const [x] = saved;
30377 return { x: () => mul(cosh(cast(x, 'float32')), dy) };
30378 }
30379 };
30380
30381 /**
30382 * @license
30383 * Copyright 2020 Google LLC. All Rights Reserved.
30384 * Licensed under the Apache License, Version 2.0 (the "License");
30385 * you may not use this file except in compliance with the License.
30386 * You may obtain a copy of the License at
30387 *
30388 * http://www.apache.org/licenses/LICENSE-2.0
30389 *
30390 * Unless required by applicable law or agreed to in writing, software
30391 * distributed under the License is distributed on an "AS IS" BASIS,
30392 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30393 * See the License for the specific language governing permissions and
30394 * limitations under the License.
30395 * =============================================================================
30396 */
30397 const sliceGradConfig = {
30398 kernelName: Slice,
30399 inputsToSave: ['x'],
30400 gradFunc: (dy, saved, attrs) => {
30401 const [x] = saved;
30402 const { begin, size } = attrs;
30403 const inputShape = x.shape;
30404 const [begin_, size_] = parseSliceParams(x, begin, size);
30405 // Create an Nx2 padding where the first column represents how many
30406 // zeros are prepended (at start) for each dimension, and the second
30407 // column indicates how many zeros are appended (at end).
30408 // The number of zeros to append is the shape of the input
30409 // elementwise-subtracted by both the begin vector and sizes vector.
30410 const paddings = [];
30411 for (let i = 0; i < dy.rank; i++) {
30412 paddings.push([begin_[i], inputShape[i] - begin_[i] - size_[i]]);
30413 }
30414 return { x: () => pad(dy, paddings) };
30415 }
30416 };
30417
30418 /**
30419 * @license
30420 * Copyright 2020 Google LLC. All Rights Reserved.
30421 * Licensed under the Apache License, Version 2.0 (the "License");
30422 * you may not use this file except in compliance with the License.
30423 * You may obtain a copy of the License at
30424 *
30425 * http://www.apache.org/licenses/LICENSE-2.0
30426 *
30427 * Unless required by applicable law or agreed to in writing, software
30428 * distributed under the License is distributed on an "AS IS" BASIS,
30429 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30430 * See the License for the specific language governing permissions and
30431 * limitations under the License.
30432 * =============================================================================
30433 */
30434 const softmaxGradConfig = {
30435 kernelName: Softmax,
30436 outputsToSave: [true],
30437 gradFunc: (dy, saved, attrs) => {
30438 const [y] = saved;
30439 const { dim } = attrs;
30440 const keepDims = true;
30441 const dyTimesY = mul(dy, y);
30442 return {
30443 logits: () => sub(dyTimesY, mul(sum$1(dyTimesY, [dim], keepDims), y))
30444 };
30445 }
30446 };
30447
30448 /**
30449 * @license
30450 * Copyright 2020 Google LLC. All Rights Reserved.
30451 * Licensed under the Apache License, Version 2.0 (the "License");
30452 * you may not use this file except in compliance with the License.
30453 * You may obtain a copy of the License at
30454 *
30455 * http://www.apache.org/licenses/LICENSE-2.0
30456 *
30457 * Unless required by applicable law or agreed to in writing, software
30458 * distributed under the License is distributed on an "AS IS" BASIS,
30459 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30460 * See the License for the specific language governing permissions and
30461 * limitations under the License.
30462 * =============================================================================
30463 */
30464 const softplusGradConfig = {
30465 kernelName: Softplus,
30466 inputsToSave: ['x'],
30467 gradFunc: (dy, saved) => {
30468 const [x] = saved;
30469 return { x: () => mul(dy, sigmoid(x)) };
30470 }
30471 };
30472
30473 /**
30474 * @license
30475 * Copyright 2020 Google LLC. All Rights Reserved.
30476 * Licensed under the Apache License, Version 2.0 (the "License");
30477 * you may not use this file except in compliance with the License.
30478 * You may obtain a copy of the License at
30479 *
30480 * http://www.apache.org/licenses/LICENSE-2.0
30481 *
30482 * Unless required by applicable law or agreed to in writing, software
30483 * distributed under the License is distributed on an "AS IS" BASIS,
30484 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30485 * See the License for the specific language governing permissions and
30486 * limitations under the License.
30487 * =============================================================================
30488 */
30489 const spaceToBatchNDGradConfig = {
30490 kernelName: SpaceToBatchND,
30491 gradFunc: (dy, saved, attrs) => {
30492 const { blockShape, paddings } = attrs;
30493 return { x: () => batchToSpaceND(dy, blockShape, paddings) };
30494 }
30495 };
30496
30497 /**
30498 * @license
30499 * Copyright 2020 Google LLC. All Rights Reserved.
30500 * Licensed under the Apache License, Version 2.0 (the "License");
30501 * you may not use this file except in compliance with the License.
30502 * You may obtain a copy of the License at
30503 *
30504 * http://www.apache.org/licenses/LICENSE-2.0
30505 *
30506 * Unless required by applicable law or agreed to in writing, software
30507 * distributed under the License is distributed on an "AS IS" BASIS,
30508 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30509 * See the License for the specific language governing permissions and
30510 * limitations under the License.
30511 * =============================================================================
30512 */
30513 const splitVGradConfig = {
30514 kernelName: SplitV,
30515 gradFunc: (dy, saved, attrs) => {
30516 const { axis } = attrs;
30517 return { x: () => concat(dy, axis) };
30518 }
30519 };
30520
30521 /**
30522 * @license
30523 * Copyright 2020 Google LLC. All Rights Reserved.
30524 * Licensed under the Apache License, Version 2.0 (the "License");
30525 * you may not use this file except in compliance with the License.
30526 * You may obtain a copy of the License at
30527 *
30528 * http://www.apache.org/licenses/LICENSE-2.0
30529 *
30530 * Unless required by applicable law or agreed to in writing, software
30531 * distributed under the License is distributed on an "AS IS" BASIS,
30532 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30533 * See the License for the specific language governing permissions and
30534 * limitations under the License.
30535 * =============================================================================
30536 */
30537 const sqrtGradConfig = {
30538 kernelName: Sqrt,
30539 inputsToSave: ['x'],
30540 gradFunc: (dy, saved) => {
30541 const [x] = saved;
30542 return { x: () => div(dy, mul(sqrt(cast(x, 'float32')), 2)) };
30543 }
30544 };
30545
30546 /**
30547 * @license
30548 * Copyright 2019 Google LLC. All Rights Reserved.
30549 * Licensed under the Apache License, Version 2.0 (the "License");
30550 * you may not use this file except in compliance with the License.
30551 * You may obtain a copy of the License at
30552 *
30553 * http://www.apache.org/licenses/LICENSE-2.0
30554 *
30555 * Unless required by applicable law or agreed to in writing, software
30556 * distributed under the License is distributed on an "AS IS" BASIS,
30557 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30558 * See the License for the specific language governing permissions and
30559 * limitations under the License.
30560 * =============================================================================
30561 */
30562 const squareGradConfig = {
30563 kernelName: Square,
30564 inputsToSave: ['x'],
30565 gradFunc: (dy, saved) => {
30566 const [x] = saved;
30567 return { x: () => mul(dy, mul(cast(x, 'float32'), 2)) };
30568 }
30569 };
30570
30571 /**
30572 * @license
30573 * Copyright 2020 Google LLC. All Rights Reserved.
30574 * Licensed under the Apache License, Version 2.0 (the "License");
30575 * you may not use this file except in compliance with the License.
30576 * You may obtain a copy of the License at
30577 *
30578 * http://www.apache.org/licenses/LICENSE-2.0
30579 *
30580 * Unless required by applicable law or agreed to in writing, software
30581 * distributed under the License is distributed on an "AS IS" BASIS,
30582 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30583 * See the License for the specific language governing permissions and
30584 * limitations under the License.
30585 * =============================================================================
30586 */
30587 const squaredDifferenceGradConfig = {
30588 kernelName: SquaredDifference,
30589 inputsToSave: ['a', 'b'],
30590 gradFunc: (dy, saved) => {
30591 const [a, b] = saved;
30592 const two = scalar(2);
30593 const derA = () => mul(dy, mul(two, sub(a, b)));
30594 const derB = () => mul(dy, mul(two, sub(b, a)));
30595 return { a: derA, b: derB };
30596 }
30597 };
30598
30599 /**
30600 * @license
30601 * Copyright 2020 Google LLC. All Rights Reserved.
30602 * Licensed under the Apache License, Version 2.0 (the "License");
30603 * you may not use this file except in compliance with the License.
30604 * You may obtain a copy of the License at
30605 *
30606 * http://www.apache.org/licenses/LICENSE-2.0
30607 *
30608 * Unless required by applicable law or agreed to in writing, software
30609 * distributed under the License is distributed on an "AS IS" BASIS,
30610 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30611 * See the License for the specific language governing permissions and
30612 * limitations under the License.
30613 * =============================================================================
30614 */
30615 const stepGradConfig = {
30616 kernelName: Step,
30617 gradFunc: (dy) => {
30618 // TODO(manrajgrover): Return null for gradients when backprop supports
30619 // it.
30620 return { x: () => zerosLike(dy) };
30621 }
30622 };
30623
30624 /**
30625 * @license
30626 * Copyright 2020 Google LLC. All Rights Reserved.
30627 * Licensed under the Apache License, Version 2.0 (the "License");
30628 * you may not use this file except in compliance with the License.
30629 * You may obtain a copy of the License at
30630 *
30631 * http://www.apache.org/licenses/LICENSE-2.0
30632 *
30633 * Unless required by applicable law or agreed to in writing, software
30634 * distributed under the License is distributed on an "AS IS" BASIS,
30635 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30636 * See the License for the specific language governing permissions and
30637 * limitations under the License.
30638 * =============================================================================
30639 */
30640 const subGradConfig = {
30641 kernelName: Sub,
30642 inputsToSave: ['a', 'b'],
30643 gradFunc: (dy, saved) => {
30644 const [a, b] = saved;
30645 const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
30646 const derA = () => {
30647 let res = dy;
30648 const reduceAxes = getReductionAxes(a.shape, outShape);
30649 if (reduceAxes.length > 0) {
30650 res = sum$1(res, reduceAxes);
30651 }
30652 return reshape(res, a.shape);
30653 };
30654 const derB = () => {
30655 let res = dy;
30656 const reduceAxes = getReductionAxes(b.shape, outShape);
30657 if (reduceAxes.length > 0) {
30658 res = sum$1(res, reduceAxes);
30659 }
30660 return reshape(neg(res), b.shape);
30661 };
30662 return { a: derA, b: derB };
30663 }
30664 };
30665
30666 /**
30667 * @license
30668 * Copyright 2020 Google Inc. All Rights Reserved.
30669 * Licensed under the Apache License, Version 2.0 (the "License");
30670 * you may not use this file except in compliance with the License.
30671 * You may obtain a copy of the License at
30672 *
30673 * http://www.apache.org/licenses/LICENSE-2.0
30674 *
30675 * Unless required by applicable law or agreed to in writing, software
30676 * distributed under the License is distributed on an "AS IS" BASIS,
30677 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30678 * See the License for the specific language governing permissions and
30679 * limitations under the License.
30680 * =============================================================================
30681 */
30682 const sumGradConfig = {
30683 kernelName: Sum,
30684 inputsToSave: ['x'],
30685 gradFunc: (dy, saved, attrs) => {
30686 const [x] = saved;
30687 const expandedDyShape = x.shape.slice();
30688 const { axis } = attrs;
30689 const axes = parseAxisParam(axis, x.shape);
30690 axes.forEach(axis => {
30691 expandedDyShape[axis] = 1;
30692 });
30693 const expandedDy = reshape(dy, expandedDyShape);
30694 const derX = mul(expandedDy, ones$1(x.shape, 'float32'));
30695 return { x: () => derX };
30696 }
30697 };
30698
30699 /**
30700 * @license
30701 * Copyright 2020 Google LLC. All Rights Reserved.
30702 * Licensed under the Apache License, Version 2.0 (the "License");
30703 * you may not use this file except in compliance with the License.
30704 * You may obtain a copy of the License at
30705 *
30706 * http://www.apache.org/licenses/LICENSE-2.0
30707 *
30708 * Unless required by applicable law or agreed to in writing, software
30709 * distributed under the License is distributed on an "AS IS" BASIS,
30710 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30711 * See the License for the specific language governing permissions and
30712 * limitations under the License.
30713 * =============================================================================
30714 */
30715 const tanGradConfig = {
30716 kernelName: Tan,
30717 inputsToSave: ['x'],
30718 gradFunc: (dy, saved) => {
30719 const [x] = saved;
30720 return { x: () => div(dy, square(cos(x))) };
30721 }
30722 };
30723
30724 /**
30725 * @license
30726 * Copyright 2020 Google LLC. All Rights Reserved.
30727 * Licensed under the Apache License, Version 2.0 (the "License");
30728 * you may not use this file except in compliance with the License.
30729 * You may obtain a copy of the License at
30730 *
30731 * http://www.apache.org/licenses/LICENSE-2.0
30732 *
30733 * Unless required by applicable law or agreed to in writing, software
30734 * distributed under the License is distributed on an "AS IS" BASIS,
30735 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30736 * See the License for the specific language governing permissions and
30737 * limitations under the License.
30738 * =============================================================================
30739 */
30740 const tanhGradConfig = {
30741 kernelName: Tanh,
30742 outputsToSave: [true],
30743 gradFunc: (dy, saved) => {
30744 const [y] = saved;
30745 return { x: () => mul(sub(scalar(1), square(y)), dy) };
30746 }
30747 };
30748
30749 /**
30750 * @license
30751 * Copyright 2020 Google LLC. All Rights Reserved.
30752 * Licensed under the Apache License, Version 2.0 (the "License");
30753 * you may not use this file except in compliance with the License.
30754 * You may obtain a copy of the License at
30755 *
30756 * http://www.apache.org/licenses/LICENSE-2.0
30757 *
30758 * Unless required by applicable law or agreed to in writing, software
30759 * distributed under the License is distributed on an "AS IS" BASIS,
30760 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30761 * See the License for the specific language governing permissions and
30762 * limitations under the License.
30763 * =============================================================================
30764 */
    // Gradient config for Tile: the input gradient is the sum of the output
    // gradient over every tiled copy. For each repetition index along each
    // dimension, a slice of dy the size of x is extracted and accumulated.
    // Only ranks 1-4 are supported; higher ranks throw.
    const tileGradConfig = {
        kernelName: Tile,
        inputsToSave: ['x'],
        gradFunc: (dy, saved, attrs) => {
            const [x] = saved;
            const { reps } = attrs;
            const derX = () => {
                // Accumulator starts at zero and has the same shape as x.
                let xGrad = zerosLike(x);
                // TODO(cais): Maybe reduce memory footprint by avoiding repeated
                // slicing.
                if (x.rank === 1) {
                    // One copy of x per repetition; slices start at multiples
                    // of x's extent.
                    for (let i = 0; i < reps[0]; ++i) {
                        xGrad = add$1(xGrad, slice(dy, [i * x.shape[0]], [x.shape[0]]));
                    }
                }
                else if (x.rank === 2) {
                    // Nested loops cover the 2-D grid of tiled copies.
                    for (let i = 0; i < reps[0]; ++i) {
                        for (let j = 0; j < reps[1]; ++j) {
                            xGrad = add$1(xGrad, slice(dy, [i * x.shape[0], j * x.shape[1]], [
                                x.shape[0], x.shape[1]
                            ]));
                        }
                    }
                }
                else if (x.rank === 3) {
                    for (let i = 0; i < reps[0]; ++i) {
                        for (let j = 0; j < reps[1]; ++j) {
                            for (let k = 0; k < reps[2]; ++k) {
                                xGrad =
                                    add$1(xGrad, slice(dy, [i * x.shape[0], j * x.shape[1], k * x.shape[2]], [x.shape[0], x.shape[1], x.shape[2]]));
                            }
                        }
                    }
                }
                else if (x.rank === 4) {
                    for (let i = 0; i < reps[0]; ++i) {
                        for (let j = 0; j < reps[1]; ++j) {
                            for (let k = 0; k < reps[2]; ++k) {
                                for (let l = 0; l < reps[3]; ++l) {
                                    xGrad =
                                        add$1(xGrad, slice(dy, [
                                            i * x.shape[0], j * x.shape[1], k * x.shape[2],
                                            l * x.shape[3]
                                        ], [x.shape[0], x.shape[1], x.shape[2], x.shape[3]]));
                                }
                            }
                        }
                    }
                }
                else {
                    // Ranks above 4 are not handled by the explicit loop nests.
                    throw new Error(`Gradient for tile operation is not implemented for rank-` +
                        `${x.rank} tensors yet.`);
                }
                return xGrad;
            };
            return { x: derX };
        },
    };
30823
30824 /**
30825 * @license
30826 * Copyright 2020 Google LLC. All Rights Reserved.
30827 * Licensed under the Apache License, Version 2.0 (the "License");
30828 * you may not use this file except in compliance with the License.
30829 * You may obtain a copy of the License at
30830 *
30831 * http://www.apache.org/licenses/LICENSE-2.0
30832 *
30833 * Unless required by applicable law or agreed to in writing, software
30834 * distributed under the License is distributed on an "AS IS" BASIS,
30835 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30836 * See the License for the specific language governing permissions and
30837 * limitations under the License.
30838 * =============================================================================
30839 */
30840 const transposeGradConfig = {
30841 kernelName: Transpose,
30842 gradFunc: (dy, saved, attrs) => {
30843 const transposeAttrs = attrs;
30844 const { perm } = transposeAttrs;
30845 const undoPerm = getUndoAxesPermutation(perm);
30846 return { x: () => transpose(dy, undoPerm) };
30847 }
30848 };
30849
30850 /**
30851 * @license
30852 * Copyright 2020 Google Inc. All Rights Reserved.
30853 * Licensed under the Apache License, Version 2.0 (the "License");
30854 * you may not use this file except in compliance with the License.
30855 * You may obtain a copy of the License at
30856 *
30857 * http://www.apache.org/licenses/LICENSE-2.0
30858 *
30859 * Unless required by applicable law or agreed to in writing, software
30860 * distributed under the License is distributed on an "AS IS" BASIS,
30861 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30862 * See the License for the specific language governing permissions and
30863 * limitations under the License.
30864 * =============================================================================
30865 */
30866 const unpackGradConfig = {
30867 kernelName: Unpack,
30868 gradFunc: (dy, saved, attrs) => {
30869 const unpackAttrs = attrs;
30870 const { axis } = unpackAttrs;
30871 return { value: () => stack(dy, axis) };
30872 }
30873 };
30874
30875 /**
30876 * @license
30877 * Copyright 2020 Google LLC. All Rights Reserved.
30878 * Licensed under the Apache License, Version 2.0 (the "License");
30879 * you may not use this file except in compliance with the License.
30880 * You may obtain a copy of the License at
30881 *
30882 * http://www.apache.org/licenses/LICENSE-2.0
30883 *
30884 * Unless required by applicable law or agreed to in writing, software
30885 * distributed under the License is distributed on an "AS IS" BASIS,
30886 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30887 * See the License for the specific language governing permissions and
30888 * limitations under the License.
30889 * =============================================================================
30890 */
30891 const unsortedSegmentSumGradConfig = {
30892 kernelName: UnsortedSegmentSum,
30893 inputsToSave: ['segmentIds'],
30894 gradFunc: (dy, saved) => {
30895 const [segmentIds] = saved;
30896 const derX = () => {
30897 return gatherDropNegatives(dy, segmentIds);
30898 };
30899 return { x: derX };
30900 }
30901 };
30902 function gatherDropNegatives(x, indices) {
30903 // Helper function for unsorted segment ops. Gathers params for
30904 // positive segment ids and gathers 0 for inputs with negative segment id.
30905 // Mirrors _GatherDropNegatives from tensorflow/python/ops/math_grad.py
30906 const zeroClippedIndices = maximum(indices, zerosLike(indices));
30907 const gathered = gather(x, zeroClippedIndices);
30908 let isPositive = greaterEqual(indices, scalar(0, 'int32'));
30909 const numIters = gathered.rank - isPositive.rank;
30910 for (let i = 0; i < numIters; ++i) {
30911 isPositive = expandDims(isPositive, i + 1);
30912 }
30913 isPositive = logicalAnd(isPositive, ones$1(gathered.shape, 'bool'));
30914 const zeroSlice = zerosLike(gathered);
30915 return where(isPositive, gathered, zeroSlice);
30916 }
30917
30918 /**
30919 * @license
30920 * Copyright 2020 Google LLC. All Rights Reserved.
30921 * Licensed under the Apache License, Version 2.0 (the "License");
30922 * you may not use this file except in compliance with the License.
30923 * You may obtain a copy of the License at
30924 *
30925 * http://www.apache.org/licenses/LICENSE-2.0
30926 *
30927 * Unless required by applicable law or agreed to in writing, software
30928 * distributed under the License is distributed on an "AS IS" BASIS,
30929 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30930 * See the License for the specific language governing permissions and
30931 * limitations under the License.
30932 * =============================================================================
30933 */
30934 const zerosLikeGradConfig = {
30935 kernelName: ZerosLike,
30936 gradFunc: (dy) => {
30937 return { x: () => zerosLike(dy) };
30938 }
30939 };
30940
30941 /**
30942 * @license
30943 * Copyright 2020 Google LLC. All Rights Reserved.
30944 * Licensed under the Apache License, Version 2.0 (the "License");
30945 * you may not use this file except in compliance with the License.
30946 * You may obtain a copy of the License at
30947 *
30948 * http://www.apache.org/licenses/LICENSE-2.0
30949 *
30950 * Unless required by applicable law or agreed to in writing, software
30951 * distributed under the License is distributed on an "AS IS" BASIS,
30952 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30953 * See the License for the specific language governing permissions and
30954 * limitations under the License.
30955 * =============================================================================
30956 */
30957 // Export all kernel configs here so that the package can auto register them
30958 const gradConfigs = [
30959 absGradConfig,
30960 acosGradConfig,
30961 acoshGradConfig,
30962 addGradConfig,
30963 addNGradConfig,
30964 argMaxGradConfig,
30965 argMinGradConfig,
30966 asinGradConfig,
30967 asinhGradConfig,
30968 atan2GradConfig,
30969 atanGradConfig,
30970 atanhGradConfig,
30971 avgPool3DGradConfig,
30972 avgPoolGradConfig,
30973 batchMatMulGradConfig,
30974 batchToSpaceNDGradConfig,
30975 broadcastToGradConfig,
30976 castGradConfig,
30977 ceilGradConfig,
30978 clipByValueGradConfig,
30979 complexAbsGradConfig,
30980 concatGradConfig,
30981 conv2DBackpropInputGradConfig,
30982 conv2DGradConfig,
30983 conv3DGradConfig,
30984 cosGradConfig,
30985 coshGradConfig,
30986 cumsumGradConfig,
30987 depthwiseConv2dNativeGradConfig,
30988 dilation2dGradConfig,
30989 divGradConfig,
30990 eluGradConfig,
30991 erfGradConfig,
30992 expGradConfig,
30993 expandDimsGradConfig,
30994 expm1GradConfig,
30995 floorDivGradConfig,
30996 floorGradConfig,
30997 fusedBatchNormGradConfig,
30998 gatherGradConfig,
30999 greaterEqualGradConfig,
31000 identityGradConfig,
31001 isFiniteGradConfig,
31002 isInfGradConfig,
31003 isNanGradConfig,
31004 leakyReluGradConfig,
31005 log1pGradConfig,
31006 logGradConfig,
31007 logSoftmaxGradConfig,
31008 lrnGradConfig,
31009 maxGradConfig,
31010 maxGradConfig,
31011 maximumGradConfig,
31012 maxPool3DGradConfig,
31013 maxPoolGradConfig,
31014 meanGradConfig,
31015 minGradConfig,
31016 minimumGradConfig,
31017 mirrorPadGradConfig,
31018 modGradConfig,
31019 multiplyGradConfig,
31020 negGradConfig,
31021 oneHotGradConfig,
31022 onesLikeGradConfig,
31023 packGradConfig,
31024 padV2GradConfig,
31025 padV2GradConfig,
31026 powGradConfig,
31027 preluGradConfig,
31028 prodGradConfig,
31029 reciprocalGradConfig,
31030 relu6GradConfig,
31031 reluGradConfig,
31032 reshapeGradConfig,
31033 resizeBilinearGradConfig,
31034 resizeNearestNeighborGradConfig,
31035 reverseGradConfig,
31036 roundGradConfig,
31037 rsqrtGradConfig,
31038 selectGradConfig,
31039 seluGradConfig,
31040 sigmoidGradConfig,
31041 signGradConfig,
31042 sinGradConfig,
31043 sinhGradConfig,
31044 sliceGradConfig,
31045 softmaxGradConfig,
31046 softplusGradConfig,
31047 spaceToBatchNDGradConfig,
31048 spaceToBatchNDGradConfig,
31049 splitVGradConfig,
31050 splitVGradConfig,
31051 sqrtGradConfig,
31052 squaredDifferenceGradConfig,
31053 squareGradConfig,
31054 stepGradConfig,
31055 subGradConfig,
31056 sumGradConfig,
31057 tanGradConfig,
31058 tanhGradConfig,
31059 tileGradConfig,
31060 transposeGradConfig,
31061 unpackGradConfig,
31062 unsortedSegmentSumGradConfig,
31063 zerosLikeGradConfig
31064 ];
31065 for (const gradientConfig of gradConfigs) {
31066 registerGradient(gradientConfig);
31067 }
31068
31069 /**
31070 * @license
31071 * Copyright 2020 Google LLC. All Rights Reserved.
31072 * Licensed under the Apache License, Version 2.0 (the "License");
31073 * you may not use this file except in compliance with the License.
31074 * You may obtain a copy of the License at
31075 *
31076 * http://www.apache.org/licenses/LICENSE-2.0
31077 *
31078 * Unless required by applicable law or agreed to in writing, software
31079 * distributed under the License is distributed on an "AS IS" BASIS,
31080 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31081 * See the License for the specific language governing permissions and
31082 * limitations under the License.
31083 * =============================================================================
31084 */
    /**
     * Chainable Tensor method: delegates to the free function `abs(this)`.
     * Throws if this tensor has been disposed.
     */
    getGlobalTensorClass().prototype.abs = function () {
        this.throwIfDisposed();
        return abs(this);
    };
31089
31090 /**
31091 * @license
31092 * Copyright 2020 Google LLC. All Rights Reserved.
31093 * Licensed under the Apache License, Version 2.0 (the "License");
31094 * you may not use this file except in compliance with the License.
31095 * You may obtain a copy of the License at
31096 *
31097 * http://www.apache.org/licenses/LICENSE-2.0
31098 *
31099 * Unless required by applicable law or agreed to in writing, software
31100 * distributed under the License is distributed on an "AS IS" BASIS,
31101 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31102 * See the License for the specific language governing permissions and
31103 * limitations under the License.
31104 * =============================================================================
31105 */
    /**
     * Chainable Tensor method: delegates to the free function `acos(this)`.
     * Throws if this tensor has been disposed.
     */
    getGlobalTensorClass().prototype.acos = function () {
        this.throwIfDisposed();
        return acos(this);
    };
31110
31111 /**
31112 * @license
31113 * Copyright 2020 Google LLC. All Rights Reserved.
31114 * Licensed under the Apache License, Version 2.0 (the "License");
31115 * you may not use this file except in compliance with the License.
31116 * You may obtain a copy of the License at
31117 *
31118 * http://www.apache.org/licenses/LICENSE-2.0
31119 *
31120 * Unless required by applicable law or agreed to in writing, software
31121 * distributed under the License is distributed on an "AS IS" BASIS,
31122 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31123 * See the License for the specific language governing permissions and
31124 * limitations under the License.
31125 * =============================================================================
31126 */
    /**
     * Chainable Tensor method: delegates to the free function `acosh(this)`.
     * Throws if this tensor has been disposed.
     */
    getGlobalTensorClass().prototype.acosh = function () {
        this.throwIfDisposed();
        return acosh(this);
    };
31131
31132 /**
31133 * @license
31134 * Copyright 2020 Google LLC. All Rights Reserved.
31135 * Licensed under the Apache License, Version 2.0 (the "License");
31136 * you may not use this file except in compliance with the License.
31137 * You may obtain a copy of the License at
31138 *
31139 * http://www.apache.org/licenses/LICENSE-2.0
31140 *
31141 * Unless required by applicable law or agreed to in writing, software
31142 * distributed under the License is distributed on an "AS IS" BASIS,
31143 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31144 * See the License for the specific language governing permissions and
31145 * limitations under the License.
31146 * =============================================================================
31147 */
    /**
     * Chainable Tensor method: delegates to `add$1(this, b)`.
     * @param b The other operand to add to this tensor.
     * Throws if this tensor has been disposed.
     */
    getGlobalTensorClass().prototype.add = function (b) {
        this.throwIfDisposed();
        return add$1(this, b);
    };
31152
31153 /**
31154 * @license
31155 * Copyright 2020 Google LLC. All Rights Reserved.
31156 * Licensed under the Apache License, Version 2.0 (the "License");
31157 * you may not use this file except in compliance with the License.
31158 * You may obtain a copy of the License at
31159 *
31160 * http://www.apache.org/licenses/LICENSE-2.0
31161 *
31162 * Unless required by applicable law or agreed to in writing, software
31163 * distributed under the License is distributed on an "AS IS" BASIS,
31164 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31165 * See the License for the specific language governing permissions and
31166 * limitations under the License.
31167 * =============================================================================
31168 */
    /**
     * Chainable Tensor method: delegates to `all(this, axis, keepDims)`.
     * @param axis Axis/axes to reduce over (forwarded as-is).
     * @param keepDims Whether reduced dimensions are retained (forwarded as-is).
     * Throws if this tensor has been disposed.
     */
    getGlobalTensorClass().prototype.all = function (axis, keepDims) {
        this.throwIfDisposed();
        return all(this, axis, keepDims);
    };
31173
31174 /**
31175 * @license
31176 * Copyright 2020 Google LLC. All Rights Reserved.
31177 * Licensed under the Apache License, Version 2.0 (the "License");
31178 * you may not use this file except in compliance with the License.
31179 * You may obtain a copy of the License at
31180 *
31181 * http://www.apache.org/licenses/LICENSE-2.0
31182 *
31183 * Unless required by applicable law or agreed to in writing, software
31184 * distributed under the License is distributed on an "AS IS" BASIS,
31185 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31186 * See the License for the specific language governing permissions and
31187 * limitations under the License.
31188 * =============================================================================
31189 */
    /**
     * Chainable Tensor method: delegates to `any(this, axis, keepDims)`.
     * @param axis Axis/axes to reduce over (forwarded as-is).
     * @param keepDims Whether reduced dimensions are retained (forwarded as-is).
     * Throws if this tensor has been disposed.
     */
    getGlobalTensorClass().prototype.any = function (axis, keepDims) {
        this.throwIfDisposed();
        return any(this, axis, keepDims);
    };
31194
31195 /**
31196 * @license
31197 * Copyright 2020 Google LLC. All Rights Reserved.
31198 * Licensed under the Apache License, Version 2.0 (the "License");
31199 * you may not use this file except in compliance with the License.
31200 * You may obtain a copy of the License at
31201 *
31202 * http://www.apache.org/licenses/LICENSE-2.0
31203 *
31204 * Unless required by applicable law or agreed to in writing, software
31205 * distributed under the License is distributed on an "AS IS" BASIS,
31206 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31207 * See the License for the specific language governing permissions and
31208 * limitations under the License.
31209 * =============================================================================
31210 */
    /**
     * Chainable Tensor method: delegates to `argMax(this, axis)`.
     * @param axis Axis along which to find the maximum index (forwarded as-is).
     * Throws if this tensor has been disposed.
     */
    getGlobalTensorClass().prototype.argMax = function (axis) {
        this.throwIfDisposed();
        return argMax(this, axis);
    };
31215
31216 /**
31217 * @license
31218 * Copyright 2020 Google LLC. All Rights Reserved.
31219 * Licensed under the Apache License, Version 2.0 (the "License");
31220 * you may not use this file except in compliance with the License.
31221 * You may obtain a copy of the License at
31222 *
31223 * http://www.apache.org/licenses/LICENSE-2.0
31224 *
31225 * Unless required by applicable law or agreed to in writing, software
31226 * distributed under the License is distributed on an "AS IS" BASIS,
31227 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31228 * See the License for the specific language governing permissions and
31229 * limitations under the License.
31230 * =============================================================================
31231 */
    /**
     * Chainable Tensor method: delegates to `argMin(this, axis)`.
     * @param axis Axis along which to find the minimum index (forwarded as-is).
     * Throws if this tensor has been disposed.
     */
    getGlobalTensorClass().prototype.argMin = function (axis) {
        this.throwIfDisposed();
        return argMin(this, axis);
    };
31236
31237 /**
31238 * @license
31239 * Copyright 2020 Google LLC. All Rights Reserved.
31240 * Licensed under the Apache License, Version 2.0 (the "License");
31241 * you may not use this file except in compliance with the License.
31242 * You may obtain a copy of the License at
31243 *
31244 * http://www.apache.org/licenses/LICENSE-2.0
31245 *
31246 * Unless required by applicable law or agreed to in writing, software
31247 * distributed under the License is distributed on an "AS IS" BASIS,
31248 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31249 * See the License for the specific language governing permissions and
31250 * limitations under the License.
31251 * =============================================================================
31252 */
    /**
     * Converts a size-1 `tf.Tensor` to a `tf.Scalar` by reshaping to rank 0.
     * Asserts that the tensor holds exactly one element before reshaping.
     * Throws if this tensor has been disposed.
     * @doc {heading: 'Tensors', subheading: 'Classes'}
     */
    getGlobalTensorClass().prototype.asScalar = function () {
        this.throwIfDisposed();
        assert(this.size === 1, () => 'The array must have only 1 element.');
        return reshape(this, []);
    };
31262
31263 /**
31264 * @license
31265 * Copyright 2020 Google LLC. All Rights Reserved.
31266 * Licensed under the Apache License, Version 2.0 (the "License");
31267 * you may not use this file except in compliance with the License.
31268 * You may obtain a copy of the License at
31269 *
31270 * http://www.apache.org/licenses/LICENSE-2.0
31271 *
31272 * Unless required by applicable law or agreed to in writing, software
31273 * distributed under the License is distributed on an "AS IS" BASIS,
31274 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31275 * See the License for the specific language governing permissions and
31276 * limitations under the License.
31277 * =============================================================================
31278 */
    /**
     * Casts a `tf.Tensor` to a specified dtype. Delegates to `cast(this, dtype)`;
     * behaves identically to the `cast` method below.
     *
     * @param dtype Data-type to cast the tensor to.
     *
     * @doc {heading: 'Tensors', subheading: 'Classes'}
     */
    getGlobalTensorClass().prototype.asType = function (dtype) {
        this.throwIfDisposed();
        return cast(this, dtype);
    };
31290
31291 /**
31292 * @license
31293 * Copyright 2020 Google LLC. All Rights Reserved.
31294 * Licensed under the Apache License, Version 2.0 (the "License");
31295 * you may not use this file except in compliance with the License.
31296 * You may obtain a copy of the License at
31297 *
31298 * http://www.apache.org/licenses/LICENSE-2.0
31299 *
31300 * Unless required by applicable law or agreed to in writing, software
31301 * distributed under the License is distributed on an "AS IS" BASIS,
31302 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31303 * See the License for the specific language governing permissions and
31304 * limitations under the License.
31305 * =============================================================================
31306 */
    /**
     * Converts a `tf.Tensor` to a `tf.Tensor1D` by flattening to shape
     * `[this.size]`. Throws if this tensor has been disposed.
     * @doc {heading: 'Tensors', subheading: 'Classes'}
     */
    getGlobalTensorClass().prototype.as1D = function () {
        this.throwIfDisposed();
        return reshape(this, [this.size]);
    };
31315
31316 /**
31317 * @license
31318 * Copyright 2020 Google LLC. All Rights Reserved.
31319 * Licensed under the Apache License, Version 2.0 (the "License");
31320 * you may not use this file except in compliance with the License.
31321 * You may obtain a copy of the License at
31322 *
31323 * http://www.apache.org/licenses/LICENSE-2.0
31324 *
31325 * Unless required by applicable law or agreed to in writing, software
31326 * distributed under the License is distributed on an "AS IS" BASIS,
31327 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31328 * See the License for the specific language governing permissions and
31329 * limitations under the License.
31330 * =============================================================================
31331 */
    /**
     * Converts a `tf.Tensor` to a `tf.Tensor2D` via `reshape(this, [rows, columns])`.
     *
     * @param rows Number of rows in `tf.Tensor2D`.
     * @param columns Number of columns in `tf.Tensor2D`.
     * @doc {heading: 'Tensors', subheading: 'Classes'}
     */
    getGlobalTensorClass().prototype.as2D = function (rows, columns) {
        this.throwIfDisposed();
        return reshape(this, [rows, columns]);
    };
31343
31344 /**
31345 * @license
31346 * Copyright 2020 Google LLC. All Rights Reserved.
31347 * Licensed under the Apache License, Version 2.0 (the "License");
31348 * you may not use this file except in compliance with the License.
31349 * You may obtain a copy of the License at
31350 *
31351 * http://www.apache.org/licenses/LICENSE-2.0
31352 *
31353 * Unless required by applicable law or agreed to in writing, software
31354 * distributed under the License is distributed on an "AS IS" BASIS,
31355 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31356 * See the License for the specific language governing permissions and
31357 * limitations under the License.
31358 * =============================================================================
31359 */
    /**
     * Converts a `tf.Tensor` to a `tf.Tensor3D` via
     * `reshape(this, [rows, columns, depth])`.
     *
     * @param rows Number of rows in `tf.Tensor3D`.
     * @param columns Number of columns in `tf.Tensor3D`.
     * @param depth Depth of `tf.Tensor3D`.
     * @doc {heading: 'Tensors', subheading: 'Classes'}
     */
    getGlobalTensorClass().prototype.as3D = function (rows, columns, depth) {
        this.throwIfDisposed();
        return reshape(this, [rows, columns, depth]);
    };
31372
31373 /**
31374 * @license
31375 * Copyright 2020 Google LLC. All Rights Reserved.
31376 * Licensed under the Apache License, Version 2.0 (the "License");
31377 * you may not use this file except in compliance with the License.
31378 * You may obtain a copy of the License at
31379 *
31380 * http://www.apache.org/licenses/LICENSE-2.0
31381 *
31382 * Unless required by applicable law or agreed to in writing, software
31383 * distributed under the License is distributed on an "AS IS" BASIS,
31384 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31385 * See the License for the specific language governing permissions and
31386 * limitations under the License.
31387 * =============================================================================
31388 */
    /**
     * Converts a `tf.Tensor` to a `tf.Tensor4D` via
     * `reshape(this, [rows, columns, depth, depth2])`.
     *
     * @param rows Number of rows in `tf.Tensor4D`.
     * @param columns Number of columns in `tf.Tensor4D`.
     * @param depth Depth of `tf.Tensor4D`.
     * @param depth2 4th dimension of `tf.Tensor4D`.
     * @doc {heading: 'Tensors', subheading: 'Classes'}
     */
    getGlobalTensorClass().prototype.as4D = function (rows, columns, depth, depth2) {
        this.throwIfDisposed();
        return reshape(this, [rows, columns, depth, depth2]);
    };
31402
31403 /**
31404 * @license
31405 * Copyright 2020 Google LLC. All Rights Reserved.
31406 * Licensed under the Apache License, Version 2.0 (the "License");
31407 * you may not use this file except in compliance with the License.
31408 * You may obtain a copy of the License at
31409 *
31410 * http://www.apache.org/licenses/LICENSE-2.0
31411 *
31412 * Unless required by applicable law or agreed to in writing, software
31413 * distributed under the License is distributed on an "AS IS" BASIS,
31414 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31415 * See the License for the specific language governing permissions and
31416 * limitations under the License.
31417 * =============================================================================
31418 */
    /**
     * Converts a `tf.Tensor` to a `tf.Tensor5D` via
     * `reshape(this, [rows, columns, depth, depth2, depth3])`.
     *
     * @param rows Number of rows in `tf.Tensor5D`.
     * @param columns Number of columns in `tf.Tensor5D`.
     * @param depth Depth of `tf.Tensor5D`.
     * @param depth2 4th dimension of `tf.Tensor5D`.
     * @param depth3 5th dimension of 'tf.Tensor5D'
     *
     * @doc {heading: 'Tensors', subheading: 'Classes'}
     */
    getGlobalTensorClass().prototype.as5D = function (rows, columns, depth, depth2, depth3) {
        this.throwIfDisposed();
        return reshape(this, [rows, columns, depth, depth2, depth3]);
    };
31434
31435 /**
31436 * @license
31437 * Copyright 2020 Google LLC. All Rights Reserved.
31438 * Licensed under the Apache License, Version 2.0 (the "License");
31439 * you may not use this file except in compliance with the License.
31440 * You may obtain a copy of the License at
31441 *
31442 * http://www.apache.org/licenses/LICENSE-2.0
31443 *
31444 * Unless required by applicable law or agreed to in writing, software
31445 * distributed under the License is distributed on an "AS IS" BASIS,
31446 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31447 * See the License for the specific language governing permissions and
31448 * limitations under the License.
31449 * =============================================================================
31450 */
    /**
     * Chainable Tensor method: delegates to the free function `asin(this)`.
     * Throws if this tensor has been disposed.
     */
    getGlobalTensorClass().prototype.asin = function () {
        this.throwIfDisposed();
        return asin(this);
    };
31455
31456 /**
31457 * @license
31458 * Copyright 2020 Google LLC. All Rights Reserved.
31459 * Licensed under the Apache License, Version 2.0 (the "License");
31460 * you may not use this file except in compliance with the License.
31461 * You may obtain a copy of the License at
31462 *
31463 * http://www.apache.org/licenses/LICENSE-2.0
31464 *
31465 * Unless required by applicable law or agreed to in writing, software
31466 * distributed under the License is distributed on an "AS IS" BASIS,
31467 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31468 * See the License for the specific language governing permissions and
31469 * limitations under the License.
31470 * =============================================================================
31471 */
    /**
     * Chainable Tensor method: delegates to the free function `asinh(this)`.
     * Throws if this tensor has been disposed.
     */
    getGlobalTensorClass().prototype.asinh = function () {
        this.throwIfDisposed();
        return asinh(this);
    };
31476
31477 /**
31478 * @license
31479 * Copyright 2020 Google LLC. All Rights Reserved.
31480 * Licensed under the Apache License, Version 2.0 (the "License");
31481 * you may not use this file except in compliance with the License.
31482 * You may obtain a copy of the License at
31483 *
31484 * http://www.apache.org/licenses/LICENSE-2.0
31485 *
31486 * Unless required by applicable law or agreed to in writing, software
31487 * distributed under the License is distributed on an "AS IS" BASIS,
31488 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31489 * See the License for the specific language governing permissions and
31490 * limitations under the License.
31491 * =============================================================================
31492 */
    /**
     * Chainable Tensor method: delegates to the free function `atan(this)`.
     * Throws if this tensor has been disposed.
     */
    getGlobalTensorClass().prototype.atan = function () {
        this.throwIfDisposed();
        return atan(this);
    };
31497
31498 /**
31499 * @license
31500 * Copyright 2020 Google LLC. All Rights Reserved.
31501 * Licensed under the Apache License, Version 2.0 (the "License");
31502 * you may not use this file except in compliance with the License.
31503 * You may obtain a copy of the License at
31504 *
31505 * http://www.apache.org/licenses/LICENSE-2.0
31506 *
31507 * Unless required by applicable law or agreed to in writing, software
31508 * distributed under the License is distributed on an "AS IS" BASIS,
31509 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31510 * See the License for the specific language governing permissions and
31511 * limitations under the License.
31512 * =============================================================================
31513 */
    /**
     * Chainable Tensor method: delegates to `atan2(this, b)`.
     * @param b The second operand (forwarded as-is).
     * Throws if this tensor has been disposed.
     */
    getGlobalTensorClass().prototype.atan2 = function (b) {
        this.throwIfDisposed();
        return atan2(this, b);
    };
31518
31519 /**
31520 * @license
31521 * Copyright 2020 Google LLC. All Rights Reserved.
31522 * Licensed under the Apache License, Version 2.0 (the "License");
31523 * you may not use this file except in compliance with the License.
31524 * You may obtain a copy of the License at
31525 *
31526 * http://www.apache.org/licenses/LICENSE-2.0
31527 *
31528 * Unless required by applicable law or agreed to in writing, software
31529 * distributed under the License is distributed on an "AS IS" BASIS,
31530 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31531 * See the License for the specific language governing permissions and
31532 * limitations under the License.
31533 * =============================================================================
31534 */
    /**
     * Chainable Tensor method: delegates to the free function `atanh(this)`.
     * Throws if this tensor has been disposed.
     */
    getGlobalTensorClass().prototype.atanh = function () {
        this.throwIfDisposed();
        return atanh(this);
    };
31539
    /**
     * Chainable Tensor method: delegates to
     * `avgPool(this, filterSize, strides, pad, dimRoundingMode)`.
     * All parameters are forwarded as-is. Throws if this tensor has been disposed.
     */
    getGlobalTensorClass().prototype.avgPool =
        function (filterSize, strides, pad, dimRoundingMode) {
            this.throwIfDisposed();
            return avgPool(this, filterSize, strides, pad, dimRoundingMode);
        };
31545
31546 /**
31547 * @license
31548 * Copyright 2020 Google LLC. All Rights Reserved.
31549 * Licensed under the Apache License, Version 2.0 (the "License");
31550 * you may not use this file except in compliance with the License.
31551 * You may obtain a copy of the License at
31552 *
31553 * http://www.apache.org/licenses/LICENSE-2.0
31554 *
31555 * Unless required by applicable law or agreed to in writing, software
31556 * distributed under the License is distributed on an "AS IS" BASIS,
31557 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31558 * See the License for the specific language governing permissions and
31559 * limitations under the License.
31560 * =============================================================================
31561 */
    /**
     * Chainable Tensor method: delegates to `batchToSpaceND(this, blockShape, crops)`.
     * Parameters are forwarded as-is. Throws if this tensor has been disposed.
     */
    getGlobalTensorClass().prototype.batchToSpaceND = function (blockShape, crops) {
        this.throwIfDisposed();
        return batchToSpaceND(this, blockShape, crops);
    };
31566
31567 /**
31568 * @license
31569 * Copyright 2020 Google LLC. All Rights Reserved.
31570 * Licensed under the Apache License, Version 2.0 (the "License");
31571 * you may not use this file except in compliance with the License.
31572 * You may obtain a copy of the License at
31573 *
31574 * http://www.apache.org/licenses/LICENSE-2.0
31575 *
31576 * Unless required by applicable law or agreed to in writing, software
31577 * distributed under the License is distributed on an "AS IS" BASIS,
31578 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31579 * See the License for the specific language governing permissions and
31580 * limitations under the License.
31581 * =============================================================================
31582 */
    /**
     * Chainable Tensor method: delegates to
     * `batchNorm(this, mean, variance, offset, scale, varianceEpsilon)`.
     * All parameters are forwarded as-is. Throws if this tensor has been disposed.
     */
    getGlobalTensorClass().prototype.batchNorm = function (mean, variance, offset, scale, varianceEpsilon) {
        this.throwIfDisposed();
        return batchNorm(this, mean, variance, offset, scale, varianceEpsilon);
    };
31587
31588 /**
31589 * @license
31590 * Copyright 2020 Google LLC. All Rights Reserved.
31591 * Licensed under the Apache License, Version 2.0 (the "License");
31592 * you may not use this file except in compliance with the License.
31593 * You may obtain a copy of the License at
31594 *
31595 * http://www.apache.org/licenses/LICENSE-2.0
31596 *
31597 * Unless required by applicable law or agreed to in writing, software
31598 * distributed under the License is distributed on an "AS IS" BASIS,
31599 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31600 * See the License for the specific language governing permissions and
31601 * limitations under the License.
31602 * =============================================================================
31603 */
    /**
     * Chainable Tensor method: delegates to `broadcastTo(this, shape)`.
     * @param shape Target shape (forwarded as-is).
     * Throws if this tensor has been disposed.
     */
    getGlobalTensorClass().prototype.broadcastTo = function (shape) {
        this.throwIfDisposed();
        return broadcastTo(this, shape);
    };
31608
31609 /**
31610 * @license
31611 * Copyright 2020 Google LLC. All Rights Reserved.
31612 * Licensed under the Apache License, Version 2.0 (the "License");
31613 * you may not use this file except in compliance with the License.
31614 * You may obtain a copy of the License at
31615 *
31616 * http://www.apache.org/licenses/LICENSE-2.0
31617 *
31618 * Unless required by applicable law or agreed to in writing, software
31619 * distributed under the License is distributed on an "AS IS" BASIS,
31620 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31621 * See the License for the specific language governing permissions and
31622 * limitations under the License.
31623 * =============================================================================
31624 */
    /**
     * Chainable Tensor method: delegates to `cast(this, dtype)`.
     * Same underlying call as the `asType` method defined earlier in this file.
     * @param dtype Target data-type (forwarded as-is).
     * Throws if this tensor has been disposed.
     */
    getGlobalTensorClass().prototype.cast = function (dtype) {
        this.throwIfDisposed();
        return cast(this, dtype);
    };
31629
31630 /**
31631 * @license
31632 * Copyright 2020 Google LLC. All Rights Reserved.
31633 * Licensed under the Apache License, Version 2.0 (the "License");
31634 * you may not use this file except in compliance with the License.
31635 * You may obtain a copy of the License at
31636 *
31637 * http://www.apache.org/licenses/LICENSE-2.0
31638 *
31639 * Unless required by applicable law or agreed to in writing, software
31640 * distributed under the License is distributed on an "AS IS" BASIS,
31641 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31642 * See the License for the specific language governing permissions and
31643 * limitations under the License.
31644 * =============================================================================
31645 */
    /**
     * Chainable Tensor method: delegates to the free function `ceil(this)`.
     * Throws if this tensor has been disposed.
     */
    getGlobalTensorClass().prototype.ceil = function () {
        this.throwIfDisposed();
        return ceil(this);
    };
31650
31651 /**
31652 * @license
31653 * Copyright 2020 Google LLC. All Rights Reserved.
31654 * Licensed under the Apache License, Version 2.0 (the "License");
31655 * you may not use this file except in compliance with the License.
31656 * You may obtain a copy of the License at
31657 *
31658 * http://www.apache.org/licenses/LICENSE-2.0
31659 *
31660 * Unless required by applicable law or agreed to in writing, software
31661 * distributed under the License is distributed on an "AS IS" BASIS,
31662 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31663 * See the License for the specific language governing permissions and
31664 * limitations under the License.
31665 * =============================================================================
31666 */
    /**
     * Chainable Tensor method: delegates to `clipByValue(this, min, max)`.
     * @param min Lower clip bound (forwarded as-is).
     * @param max Upper clip bound (forwarded as-is).
     * Throws if this tensor has been disposed.
     */
    getGlobalTensorClass().prototype.clipByValue = function (min, max) {
        this.throwIfDisposed();
        return clipByValue(this, min, max);
    };
31671
31672 /**
31673 * @license
31674 * Copyright 2020 Google LLC. All Rights Reserved.
31675 * Licensed under the Apache License, Version 2.0 (the "License");
31676 * you may not use this file except in compliance with the License.
31677 * You may obtain a copy of the License at
31678 *
31679 * http://www.apache.org/licenses/LICENSE-2.0
31680 *
31681 * Unless required by applicable law or agreed to in writing, software
31682 * distributed under the License is distributed on an "AS IS" BASIS,
31683 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31684 * See the License for the specific language governing permissions and
31685 * limitations under the License.
31686 * =============================================================================
31687 */
31688 getGlobalTensorClass().prototype.concat = function (x, axis) {
31689 this.throwIfDisposed();
31690 if (x instanceof Tensor) {
31691 x = [x];
31692 }
31693 return concat([this, ...x], axis);
31694 };
31695
31696 /**
31697 * @license
31698 * Copyright 2020 Google LLC. All Rights Reserved.
31699 * Licensed under the Apache License, Version 2.0 (the "License");
31700 * you may not use this file except in compliance with the License.
31701 * You may obtain a copy of the License at
31702 *
31703 * http://www.apache.org/licenses/LICENSE-2.0
31704 *
31705 * Unless required by applicable law or agreed to in writing, software
31706 * distributed under the License is distributed on an "AS IS" BASIS,
31707 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31708 * See the License for the specific language governing permissions and
31709 * limitations under the License.
31710 * =============================================================================
31711 */
    /**
     * Chainable Tensor method: delegates to
     * `conv1d(this, filter, stride, pad, dataFormat, dilation, dimRoundingMode)`.
     * All parameters are forwarded as-is. Throws if this tensor has been disposed.
     */
    getGlobalTensorClass().prototype.conv1d = function (filter, stride, pad, dataFormat, dilation, dimRoundingMode) {
        this.throwIfDisposed();
        return conv1d(this, filter, stride, pad, dataFormat, dilation, dimRoundingMode);
    };
31716
31717 /**
31718 * @license
31719 * Copyright 2020 Google LLC. All Rights Reserved.
31720 * Licensed under the Apache License, Version 2.0 (the "License");
31721 * you may not use this file except in compliance with the License.
31722 * You may obtain a copy of the License at
31723 *
31724 * http://www.apache.org/licenses/LICENSE-2.0
31725 *
31726 * Unless required by applicable law or agreed to in writing, software
31727 * distributed under the License is distributed on an "AS IS" BASIS,
31728 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31729 * See the License for the specific language governing permissions and
31730 * limitations under the License.
31731 * =============================================================================
31732 */
/**
 * Transposed 2D convolution (deconvolution) on this tensor; forwards all
 * arguments to the free `conv2dTranspose` op with this tensor as input.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.conv2dTranspose = function (kernel, outputShape, strides, pad, roundingMode) {
    this.throwIfDisposed();
    return conv2dTranspose(this, kernel, outputShape, strides, pad, roundingMode);
};
31738
31739 /**
31740 * @license
31741 * Copyright 2020 Google LLC. All Rights Reserved.
31742 * Licensed under the Apache License, Version 2.0 (the "License");
31743 * you may not use this file except in compliance with the License.
31744 * You may obtain a copy of the License at
31745 *
31746 * http://www.apache.org/licenses/LICENSE-2.0
31747 *
31748 * Unless required by applicable law or agreed to in writing, software
31749 * distributed under the License is distributed on an "AS IS" BASIS,
31750 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31751 * See the License for the specific language governing permissions and
31752 * limitations under the License.
31753 * =============================================================================
31754 */
/**
 * 2D convolution over this tensor; forwards all arguments to the free
 * `conv2d` op with this tensor as the input.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.conv2d = function (kernel, strides, pad, dataFormat, dilations, roundingMode) {
    this.throwIfDisposed();
    return conv2d(this, kernel, strides, pad, dataFormat, dilations, roundingMode);
};
31759
31760 /**
31761 * @license
31762 * Copyright 2020 Google LLC. All Rights Reserved.
31763 * Licensed under the Apache License, Version 2.0 (the "License");
31764 * you may not use this file except in compliance with the License.
31765 * You may obtain a copy of the License at
31766 *
31767 * http://www.apache.org/licenses/LICENSE-2.0
31768 *
31769 * Unless required by applicable law or agreed to in writing, software
31770 * distributed under the License is distributed on an "AS IS" BASIS,
31771 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31772 * See the License for the specific language governing permissions and
31773 * limitations under the License.
31774 * =============================================================================
31775 */
/**
 * Element-wise cosine of this tensor; delegates to the `cos` op.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.cos = function () {
    this.throwIfDisposed();
    return cos(this);
};
31780
31781 /**
31782 * @license
31783 * Copyright 2020 Google LLC. All Rights Reserved.
31784 * Licensed under the Apache License, Version 2.0 (the "License");
31785 * you may not use this file except in compliance with the License.
31786 * You may obtain a copy of the License at
31787 *
31788 * http://www.apache.org/licenses/LICENSE-2.0
31789 *
31790 * Unless required by applicable law or agreed to in writing, software
31791 * distributed under the License is distributed on an "AS IS" BASIS,
31792 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31793 * See the License for the specific language governing permissions and
31794 * limitations under the License.
31795 * =============================================================================
31796 */
/**
 * Element-wise hyperbolic cosine of this tensor; delegates to the `cosh` op.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.cosh = function () {
    this.throwIfDisposed();
    return cosh(this);
};
31801
31802 /**
31803 * @license
31804 * Copyright 2022 Google LLC. All Rights Reserved.
31805 * Licensed under the Apache License, Version 2.0 (the 'License');
31806 * you may not use this file except in compliance with the License.
31807 * You may obtain a copy of the License at
31808 *
31809 * http://www.apache.org/licenses/LICENSE-2.0
31810 *
31811 * Unless required by applicable law or agreed to in writing, software
31812 * distributed under the License is distributed on an 'AS IS' BASIS,
31813 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31814 * See the License for the specific language governing permissions and
31815 * limitations under the License.
31816 * =============================================================================
31817 */
/**
 * Cumulative product along `axis`; forwards `exclusive` and `reverse`
 * unchanged to the free `cumprod` op with this tensor as input.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.cumprod = function (dim, exclusive, reverse) {
    this.throwIfDisposed();
    return cumprod(this, dim, exclusive, reverse);
};
31822
31823 /**
31824 * @license
31825 * Copyright 2020 Google LLC. All Rights Reserved.
31826 * Licensed under the Apache License, Version 2.0 (the "License");
31827 * you may not use this file except in compliance with the License.
31828 * You may obtain a copy of the License at
31829 *
31830 * http://www.apache.org/licenses/LICENSE-2.0
31831 *
31832 * Unless required by applicable law or agreed to in writing, software
31833 * distributed under the License is distributed on an "AS IS" BASIS,
31834 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31835 * See the License for the specific language governing permissions and
31836 * limitations under the License.
31837 * =============================================================================
31838 */
/**
 * Cumulative sum along `axis`; forwards `exclusive` and `reverse`
 * unchanged to the free `cumsum` op with this tensor as input.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.cumsum = function (dim, exclusive, reverse) {
    this.throwIfDisposed();
    return cumsum(this, dim, exclusive, reverse);
};
31843
31844 /**
31845 * @license
31846 * Copyright 2020 Google LLC. All Rights Reserved.
31847 * Licensed under the Apache License, Version 2.0 (the "License");
31848 * you may not use this file except in compliance with the License.
31849 * You may obtain a copy of the License at
31850 *
31851 * http://www.apache.org/licenses/LICENSE-2.0
31852 *
31853 * Unless required by applicable law or agreed to in writing, software
31854 * distributed under the License is distributed on an "AS IS" BASIS,
31855 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31856 * See the License for the specific language governing permissions and
31857 * limitations under the License.
31858 * =============================================================================
31859 */
/**
 * Rearranges depth-dimension data into spatial blocks; forwards both
 * arguments to the free `depthToSpace` op with this tensor as input.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.depthToSpace = function (blockSize, format) {
    this.throwIfDisposed();
    return depthToSpace(this, blockSize, format);
};
31864
31865 /**
31866 * @license
31867 * Copyright 2020 Google LLC. All Rights Reserved.
31868 * Licensed under the Apache License, Version 2.0 (the "License");
31869 * you may not use this file except in compliance with the License.
31870 * You may obtain a copy of the License at
31871 *
31872 * http://www.apache.org/licenses/LICENSE-2.0
31873 *
31874 * Unless required by applicable law or agreed to in writing, software
31875 * distributed under the License is distributed on an "AS IS" BASIS,
31876 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31877 * See the License for the specific language governing permissions and
31878 * limitations under the License.
31879 * =============================================================================
31880 */
/**
 * Depthwise 2D convolution on this tensor; forwards all arguments to the
 * free `depthwiseConv2d` op with this tensor as the input.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.depthwiseConv2d = function (kernel, strides, pad, dataFormat, dilations, roundingMode) {
    this.throwIfDisposed();
    return depthwiseConv2d(this, kernel, strides, pad, dataFormat, dilations, roundingMode);
};
31886
31887 /**
31888 * @license
31889 * Copyright 2020 Google LLC. All Rights Reserved.
31890 * Licensed under the Apache License, Version 2.0 (the "License");
31891 * you may not use this file except in compliance with the License.
31892 * You may obtain a copy of the License at
31893 *
31894 * http://www.apache.org/licenses/LICENSE-2.0
31895 *
31896 * Unless required by applicable law or agreed to in writing, software
31897 * distributed under the License is distributed on an "AS IS" BASIS,
31898 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31899 * See the License for the specific language governing permissions and
31900 * limitations under the License.
31901 * =============================================================================
31902 */
/**
 * Grayscale morphological 2D dilation; forwards all arguments to the free
 * `dilation2d` op with this tensor as the input.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.dilation2d = function (kernel, strides, pad, dilations, dataFormat) {
    this.throwIfDisposed();
    return dilation2d(this, kernel, strides, pad, dilations, dataFormat);
};
31908
31909 /**
31910 * @license
31911 * Copyright 2020 Google LLC. All Rights Reserved.
31912 * Licensed under the Apache License, Version 2.0 (the "License");
31913 * you may not use this file except in compliance with the License.
31914 * You may obtain a copy of the License at
31915 *
31916 * http://www.apache.org/licenses/LICENSE-2.0
31917 *
31918 * Unless required by applicable law or agreed to in writing, software
31919 * distributed under the License is distributed on an "AS IS" BASIS,
31920 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31921 * See the License for the specific language governing permissions and
31922 * limitations under the License.
31923 * =============================================================================
31924 */
/**
 * Element-wise division of this tensor by `b`; delegates to the
 * `divNoNan` op (name suggests 0-denominator handling — see that op).
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.divNoNan = function (denominator) {
    this.throwIfDisposed();
    return divNoNan(this, denominator);
};
31929
31930 /**
31931 * @license
31932 * Copyright 2020 Google LLC. All Rights Reserved.
31933 * Licensed under the Apache License, Version 2.0 (the "License");
31934 * you may not use this file except in compliance with the License.
31935 * You may obtain a copy of the License at
31936 *
31937 * http://www.apache.org/licenses/LICENSE-2.0
31938 *
31939 * Unless required by applicable law or agreed to in writing, software
31940 * distributed under the License is distributed on an "AS IS" BASIS,
31941 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31942 * See the License for the specific language governing permissions and
31943 * limitations under the License.
31944 * =============================================================================
31945 */
/**
 * Element-wise division of this tensor by `b`; delegates to the `div` op.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.div = function (denominator) {
    this.throwIfDisposed();
    return div(this, denominator);
};
31950
31951 /**
31952 * @license
31953 * Copyright 2020 Google LLC. All Rights Reserved.
31954 * Licensed under the Apache License, Version 2.0 (the "License");
31955 * you may not use this file except in compliance with the License.
31956 * You may obtain a copy of the License at
31957 *
31958 * http://www.apache.org/licenses/LICENSE-2.0
31959 *
31960 * Unless required by applicable law or agreed to in writing, software
31961 * distributed under the License is distributed on an "AS IS" BASIS,
31962 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31963 * See the License for the specific language governing permissions and
31964 * limitations under the License.
31965 * =============================================================================
31966 */
/**
 * Dot product of this tensor with `b`; delegates to the `dot` op.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.dot = function (other) {
    this.throwIfDisposed();
    return dot(this, other);
};
31971
31972 /**
31973 * @license
31974 * Copyright 2020 Google LLC. All Rights Reserved.
31975 * Licensed under the Apache License, Version 2.0 (the "License");
31976 * you may not use this file except in compliance with the License.
31977 * You may obtain a copy of the License at
31978 *
31979 * http://www.apache.org/licenses/LICENSE-2.0
31980 *
31981 * Unless required by applicable law or agreed to in writing, software
31982 * distributed under the License is distributed on an "AS IS" BASIS,
31983 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31984 * See the License for the specific language governing permissions and
31985 * limitations under the License.
31986 * =============================================================================
31987 */
/**
 * Element-wise ELU activation of this tensor; delegates to the `elu` op.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.elu = function () {
    this.throwIfDisposed();
    return elu(this);
};
31992
31993 /**
31994 * @license
31995 * Copyright 2020 Google LLC. All Rights Reserved.
31996 * Licensed under the Apache License, Version 2.0 (the "License");
31997 * you may not use this file except in compliance with the License.
31998 * You may obtain a copy of the License at
31999 *
32000 * http://www.apache.org/licenses/LICENSE-2.0
32001 *
32002 * Unless required by applicable law or agreed to in writing, software
32003 * distributed under the License is distributed on an "AS IS" BASIS,
32004 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32005 * See the License for the specific language governing permissions and
32006 * limitations under the License.
32007 * =============================================================================
32008 */
/**
 * Element-wise equality comparison with `b`; delegates to the `equal` op.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.equal = function (other) {
    this.throwIfDisposed();
    return equal(this, other);
};
32013
32014 /**
32015 * @license
32016 * Copyright 2020 Google LLC. All Rights Reserved.
32017 * Licensed under the Apache License, Version 2.0 (the "License");
32018 * you may not use this file except in compliance with the License.
32019 * You may obtain a copy of the License at
32020 *
32021 * http://www.apache.org/licenses/LICENSE-2.0
32022 *
32023 * Unless required by applicable law or agreed to in writing, software
32024 * distributed under the License is distributed on an "AS IS" BASIS,
32025 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32026 * See the License for the specific language governing permissions and
32027 * limitations under the License.
32028 * =============================================================================
32029 */
/**
 * Element-wise Gauss error function of this tensor; delegates to the `erf` op.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.erf = function () {
    this.throwIfDisposed();
    return erf(this);
};
32034
32035 /**
32036 * @license
32037 * Copyright 2021 Google LLC. All Rights Reserved.
32038 * Licensed under the Apache License, Version 2.0 (the "License");
32039 * you may not use this file except in compliance with the License.
32040 * You may obtain a copy of the License at
32041 *
32042 * http://www.apache.org/licenses/LICENSE-2.0
32043 *
32044 * Unless required by applicable law or agreed to in writing, software
32045 * distributed under the License is distributed on an "AS IS" BASIS,
32046 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32047 * See the License for the specific language governing permissions and
32048 * limitations under the License.
32049 * =============================================================================
32050 */
/**
 * Euclidean norm of this tensor over `axis`; forwards `keepDims`
 * unchanged to the free `euclideanNorm` op.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.euclideanNorm = function (dims, keepDims) {
    this.throwIfDisposed();
    return euclideanNorm(this, dims, keepDims);
};
32055
32056 /**
32057 * @license
32058 * Copyright 2020 Google LLC. All Rights Reserved.
32059 * Licensed under the Apache License, Version 2.0 (the "License");
32060 * you may not use this file except in compliance with the License.
32061 * You may obtain a copy of the License at
32062 *
32063 * http://www.apache.org/licenses/LICENSE-2.0
32064 *
32065 * Unless required by applicable law or agreed to in writing, software
32066 * distributed under the License is distributed on an "AS IS" BASIS,
32067 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32068 * See the License for the specific language governing permissions and
32069 * limitations under the License.
32070 * =============================================================================
32071 */
/**
 * Element-wise exponential of this tensor; delegates to the `exp` op.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.exp = function () {
    this.throwIfDisposed();
    return exp(this);
};
32076
32077 /**
32078 * @license
32079 * Copyright 2020 Google LLC. All Rights Reserved.
32080 * Licensed under the Apache License, Version 2.0 (the "License");
32081 * you may not use this file except in compliance with the License.
32082 * You may obtain a copy of the License at
32083 *
32084 * http://www.apache.org/licenses/LICENSE-2.0
32085 *
32086 * Unless required by applicable law or agreed to in writing, software
32087 * distributed under the License is distributed on an "AS IS" BASIS,
32088 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32089 * See the License for the specific language governing permissions and
32090 * limitations under the License.
32091 * =============================================================================
32092 */
/**
 * Inserts a dimension of size 1 at `axis`; delegates to the
 * `expandDims` op with this tensor as input.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.expandDims = function (dim) {
    this.throwIfDisposed();
    return expandDims(this, dim);
};
32097
32098 /**
32099 * @license
32100 * Copyright 2020 Google LLC. All Rights Reserved.
32101 * Licensed under the Apache License, Version 2.0 (the "License");
32102 * you may not use this file except in compliance with the License.
32103 * You may obtain a copy of the License at
32104 *
32105 * http://www.apache.org/licenses/LICENSE-2.0
32106 *
32107 * Unless required by applicable law or agreed to in writing, software
32108 * distributed under the License is distributed on an "AS IS" BASIS,
32109 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32110 * See the License for the specific language governing permissions and
32111 * limitations under the License.
32112 * =============================================================================
32113 */
/**
 * Element-wise `exp(x) - 1` of this tensor; delegates to the `expm1` op.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.expm1 = function () {
    this.throwIfDisposed();
    return expm1(this);
};
32118
32119 /**
32120 * @license
32121 * Copyright 2020 Google LLC. All Rights Reserved.
32122 * Licensed under the Apache License, Version 2.0 (the "License");
32123 * you may not use this file except in compliance with the License.
32124 * You may obtain a copy of the License at
32125 *
32126 * http://www.apache.org/licenses/LICENSE-2.0
32127 *
32128 * Unless required by applicable law or agreed to in writing, software
32129 * distributed under the License is distributed on an "AS IS" BASIS,
32130 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32131 * See the License for the specific language governing permissions and
32132 * limitations under the License.
32133 * =============================================================================
32134 */
/**
 * Fast Fourier transform of this tensor; delegates to the `fft` op.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.fft = function () {
    this.throwIfDisposed();
    return fft(this);
};
32139
32140 /**
32141 * @license
32142 * Copyright 2020 Google LLC. All Rights Reserved.
32143 * Licensed under the Apache License, Version 2.0 (the "License");
32144 * you may not use this file except in compliance with the License.
32145 * You may obtain a copy of the License at
32146 *
32147 * http://www.apache.org/licenses/LICENSE-2.0
32148 *
32149 * Unless required by applicable law or agreed to in writing, software
32150 * distributed under the License is distributed on an "AS IS" BASIS,
32151 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32152 * See the License for the specific language governing permissions and
32153 * limitations under the License.
32154 * =============================================================================
32155 */
/**
 * Flatten a Tensor to a 1D array.
 *
 * Implemented as a reshape of this tensor to the single-dimension shape
 * `[this.size]`; no data is copied beyond what `reshape` does.
 * Throws if this tensor has been disposed.
 * @doc {heading: 'Tensors', subheading: 'Classes'}
 */
getGlobalTensorClass().prototype.flatten = function () {
    this.throwIfDisposed();
    return reshape(this, [this.size]);
};
32164
32165 /**
32166 * @license
32167 * Copyright 2020 Google LLC. All Rights Reserved.
32168 * Licensed under the Apache License, Version 2.0 (the "License");
32169 * you may not use this file except in compliance with the License.
32170 * You may obtain a copy of the License at
32171 *
32172 * http://www.apache.org/licenses/LICENSE-2.0
32173 *
32174 * Unless required by applicable law or agreed to in writing, software
32175 * distributed under the License is distributed on an "AS IS" BASIS,
32176 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32177 * See the License for the specific language governing permissions and
32178 * limitations under the License.
32179 * =============================================================================
32180 */
/**
 * Element-wise floor of this tensor; delegates to the `floor` op.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.floor = function () {
    this.throwIfDisposed();
    return floor(this);
};
32185
32186 /**
32187 * @license
32188 * Copyright 2020 Google LLC. All Rights Reserved.
32189 * Licensed under the Apache License, Version 2.0 (the "License");
32190 * you may not use this file except in compliance with the License.
32191 * You may obtain a copy of the License at
32192 *
32193 * http://www.apache.org/licenses/LICENSE-2.0
32194 *
32195 * Unless required by applicable law or agreed to in writing, software
32196 * distributed under the License is distributed on an "AS IS" BASIS,
32197 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32198 * See the License for the specific language governing permissions and
32199 * limitations under the License.
32200 * =============================================================================
32201 */
/**
 * Element-wise floored division of this tensor by `b`; delegates to the
 * `floorDiv` op.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.floorDiv = function (denominator) {
    this.throwIfDisposed();
    return floorDiv(this, denominator);
};
32206
32207 /**
32208 * @license
32209 * Copyright 2020 Google LLC. All Rights Reserved.
32210 * Licensed under the Apache License, Version 2.0 (the "License");
32211 * you may not use this file except in compliance with the License.
32212 * You may obtain a copy of the License at
32213 *
32214 * http://www.apache.org/licenses/LICENSE-2.0
32215 *
32216 * Unless required by applicable law or agreed to in writing, software
32217 * distributed under the License is distributed on an "AS IS" BASIS,
32218 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32219 * See the License for the specific language governing permissions and
32220 * limitations under the License.
32221 * =============================================================================
32222 */
/**
 * Gathers slices of this tensor at `indices` along `axis`; delegates to
 * the free `gather` op.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.gather = function (idx, dim) {
    this.throwIfDisposed();
    return gather(this, idx, dim);
};
32227
32228 /**
32229 * @license
32230 * Copyright 2020 Google LLC. All Rights Reserved.
32231 * Licensed under the Apache License, Version 2.0 (the "License");
32232 * you may not use this file except in compliance with the License.
32233 * You may obtain a copy of the License at
32234 *
32235 * http://www.apache.org/licenses/LICENSE-2.0
32236 *
32237 * Unless required by applicable law or agreed to in writing, software
32238 * distributed under the License is distributed on an "AS IS" BASIS,
32239 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32240 * See the License for the specific language governing permissions and
32241 * limitations under the License.
32242 * =============================================================================
32243 */
/**
 * Element-wise `>=` comparison with `b`; delegates to the
 * `greaterEqual` op.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.greaterEqual = function (other) {
    this.throwIfDisposed();
    return greaterEqual(this, other);
};
32248
32249 /**
32250 * @license
32251 * Copyright 2020 Google LLC. All Rights Reserved.
32252 * Licensed under the Apache License, Version 2.0 (the "License");
32253 * you may not use this file except in compliance with the License.
32254 * You may obtain a copy of the License at
32255 *
32256 * http://www.apache.org/licenses/LICENSE-2.0
32257 *
32258 * Unless required by applicable law or agreed to in writing, software
32259 * distributed under the License is distributed on an "AS IS" BASIS,
32260 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32261 * See the License for the specific language governing permissions and
32262 * limitations under the License.
32263 * =============================================================================
32264 */
/**
 * Element-wise `>` comparison with `b`; delegates to the `greater` op.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.greater = function (other) {
    this.throwIfDisposed();
    return greater(this, other);
};
32269
32270 /**
32271 * @license
32272 * Copyright 2020 Google LLC. All Rights Reserved.
32273 * Licensed under the Apache License, Version 2.0 (the "License");
32274 * you may not use this file except in compliance with the License.
32275 * You may obtain a copy of the License at
32276 *
32277 * http://www.apache.org/licenses/LICENSE-2.0
32278 *
32279 * Unless required by applicable law or agreed to in writing, software
32280 * distributed under the License is distributed on an "AS IS" BASIS,
32281 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32282 * See the License for the specific language governing permissions and
32283 * limitations under the License.
32284 * =============================================================================
32285 */
/**
 * Inverse fast Fourier transform of this tensor; delegates to the `ifft` op.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.ifft = function () {
    this.throwIfDisposed();
    return ifft(this);
};
32290
32291 /**
32292 * @license
32293 * Copyright 2020 Google LLC. All Rights Reserved.
32294 * Licensed under the Apache License, Version 2.0 (the "License");
32295 * you may not use this file except in compliance with the License.
32296 * You may obtain a copy of the License at
32297 *
32298 * http://www.apache.org/licenses/LICENSE-2.0
32299 *
32300 * Unless required by applicable law or agreed to in writing, software
32301 * distributed under the License is distributed on an "AS IS" BASIS,
32302 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32303 * See the License for the specific language governing permissions and
32304 * limitations under the License.
32305 * =============================================================================
32306 */
/**
 * Inverse real-valued fast Fourier transform of this tensor; delegates to
 * the `irfft` op.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.irfft = function () {
    this.throwIfDisposed();
    return irfft(this);
};
32311
32312 /**
32313 * @license
32314 * Copyright 2020 Google LLC. All Rights Reserved.
32315 * Licensed under the Apache License, Version 2.0 (the "License");
32316 * you may not use this file except in compliance with the License.
32317 * You may obtain a copy of the License at
32318 *
32319 * http://www.apache.org/licenses/LICENSE-2.0
32320 *
32321 * Unless required by applicable law or agreed to in writing, software
32322 * distributed under the License is distributed on an "AS IS" BASIS,
32323 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32324 * See the License for the specific language governing permissions and
32325 * limitations under the License.
32326 * =============================================================================
32327 */
/**
 * Element-wise finiteness test of this tensor; delegates to the bundled
 * `isFinite$1` op (renamed by the bundler to avoid shadowing the global
 * `isFinite`).
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.isFinite = function () {
    this.throwIfDisposed();
    return isFinite$1(this);
};
32332
32333 /**
32334 * @license
32335 * Copyright 2020 Google LLC. All Rights Reserved.
32336 * Licensed under the Apache License, Version 2.0 (the "License");
32337 * you may not use this file except in compliance with the License.
32338 * You may obtain a copy of the License at
32339 *
32340 * http://www.apache.org/licenses/LICENSE-2.0
32341 *
32342 * Unless required by applicable law or agreed to in writing, software
32343 * distributed under the License is distributed on an "AS IS" BASIS,
32344 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32345 * See the License for the specific language governing permissions and
32346 * limitations under the License.
32347 * =============================================================================
32348 */
/**
 * Element-wise infinity test of this tensor; delegates to the `isInf` op.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.isInf = function () {
    this.throwIfDisposed();
    return isInf(this);
};
32353
32354 /**
32355 * @license
32356 * Copyright 2020 Google LLC. All Rights Reserved.
32357 * Licensed under the Apache License, Version 2.0 (the "License");
32358 * you may not use this file except in compliance with the License.
32359 * You may obtain a copy of the License at
32360 *
32361 * http://www.apache.org/licenses/LICENSE-2.0
32362 *
32363 * Unless required by applicable law or agreed to in writing, software
32364 * distributed under the License is distributed on an "AS IS" BASIS,
32365 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32366 * See the License for the specific language governing permissions and
32367 * limitations under the License.
32368 * =============================================================================
32369 */
/**
 * Element-wise NaN test of this tensor; delegates to the bundled
 * `isNaN$1` op (renamed by the bundler to avoid shadowing the global
 * `isNaN`).
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.isNaN = function () {
    this.throwIfDisposed();
    return isNaN$1(this);
};
32374
32375 /**
32376 * @license
32377 * Copyright 2020 Google LLC. All Rights Reserved.
32378 * Licensed under the Apache License, Version 2.0 (the "License");
32379 * you may not use this file except in compliance with the License.
32380 * You may obtain a copy of the License at
32381 *
32382 * http://www.apache.org/licenses/LICENSE-2.0
32383 *
32384 * Unless required by applicable law or agreed to in writing, software
32385 * distributed under the License is distributed on an "AS IS" BASIS,
32386 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32387 * See the License for the specific language governing permissions and
32388 * limitations under the License.
32389 * =============================================================================
32390 */
/**
 * Element-wise leaky ReLU activation with slope `alpha`; delegates to the
 * `leakyRelu` op with this tensor as input.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.leakyRelu = function (slope) {
    this.throwIfDisposed();
    return leakyRelu(this, slope);
};
32395
32396 /**
32397 * @license
32398 * Copyright 2020 Google LLC. All Rights Reserved.
32399 * Licensed under the Apache License, Version 2.0 (the "License");
32400 * you may not use this file except in compliance with the License.
32401 * You may obtain a copy of the License at
32402 *
32403 * http://www.apache.org/licenses/LICENSE-2.0
32404 *
32405 * Unless required by applicable law or agreed to in writing, software
32406 * distributed under the License is distributed on an "AS IS" BASIS,
32407 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32408 * See the License for the specific language governing permissions and
32409 * limitations under the License.
32410 * =============================================================================
32411 */
/**
 * Element-wise `<=` comparison with `b`; delegates to the `lessEqual` op.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.lessEqual = function (other) {
    this.throwIfDisposed();
    return lessEqual(this, other);
};
32416
32417 /**
32418 * @license
32419 * Copyright 2020 Google LLC. All Rights Reserved.
32420 * Licensed under the Apache License, Version 2.0 (the "License");
32421 * you may not use this file except in compliance with the License.
32422 * You may obtain a copy of the License at
32423 *
32424 * http://www.apache.org/licenses/LICENSE-2.0
32425 *
32426 * Unless required by applicable law or agreed to in writing, software
32427 * distributed under the License is distributed on an "AS IS" BASIS,
32428 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32429 * See the License for the specific language governing permissions and
32430 * limitations under the License.
32431 * =============================================================================
32432 */
/**
 * Element-wise `<` comparison with `b`; delegates to the `less` op.
 * Throws if this tensor has been disposed.
 */
getGlobalTensorClass().prototype.less = function (other) {
    this.throwIfDisposed();
    return less(this, other);
};
32437
32438 /**
32439 * @license
32440 * Copyright 2020 Google LLC. All Rights Reserved.
32441 * Licensed under the Apache License, Version 2.0 (the "License");
32442 * you may not use this file except in compliance with the License.
32443 * You may obtain a copy of the License at
32444 *
32445 * http://www.apache.org/licenses/LICENSE-2.0
32446 *
32447 * Unless required by applicable law or agreed to in writing, software
32448 * distributed under the License is distributed on an "AS IS" BASIS,
32449 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32450 * See the License for the specific language governing permissions and
32451 * limitations under the License.
32452 * =============================================================================
32453 */
/** Chainable form of the `localResponseNormalization` op. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.localResponseNormalization =
    function (depthRadius, bias, alpha, beta) {
        this.throwIfDisposed();
        const result = localResponseNormalization(this, depthRadius, bias, alpha, beta);
        return result;
    };
32459
32460 /**
32461 * @license
32462 * Copyright 2020 Google LLC. All Rights Reserved.
32463 * Licensed under the Apache License, Version 2.0 (the "License");
32464 * you may not use this file except in compliance with the License.
32465 * You may obtain a copy of the License at
32466 *
32467 * http://www.apache.org/licenses/LICENSE-2.0
32468 *
32469 * Unless required by applicable law or agreed to in writing, software
32470 * distributed under the License is distributed on an "AS IS" BASIS,
32471 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32472 * See the License for the specific language governing permissions and
32473 * limitations under the License.
32474 * =============================================================================
32475 */
/** Chainable form of the `logSigmoid` op. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.logSigmoid = function () {
    this.throwIfDisposed();
    const result = logSigmoid(this);
    return result;
};
32480
32481 /**
32482 * @license
32483 * Copyright 2020 Google LLC. All Rights Reserved.
32484 * Licensed under the Apache License, Version 2.0 (the "License");
32485 * you may not use this file except in compliance with the License.
32486 * You may obtain a copy of the License at
32487 *
32488 * http://www.apache.org/licenses/LICENSE-2.0
32489 *
32490 * Unless required by applicable law or agreed to in writing, software
32491 * distributed under the License is distributed on an "AS IS" BASIS,
32492 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32493 * See the License for the specific language governing permissions and
32494 * limitations under the License.
32495 * =============================================================================
32496 */
/** Chainable form of the `logSoftmax` op along `axis`. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.logSoftmax = function (axis) {
    this.throwIfDisposed();
    const result = logSoftmax(this, axis);
    return result;
};
32501
32502 /**
32503 * @license
32504 * Copyright 2020 Google LLC. All Rights Reserved.
32505 * Licensed under the Apache License, Version 2.0 (the "License");
32506 * you may not use this file except in compliance with the License.
32507 * You may obtain a copy of the License at
32508 *
32509 * http://www.apache.org/licenses/LICENSE-2.0
32510 *
32511 * Unless required by applicable law or agreed to in writing, software
32512 * distributed under the License is distributed on an "AS IS" BASIS,
32513 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32514 * See the License for the specific language governing permissions and
32515 * limitations under the License.
32516 * =============================================================================
32517 */
/** Chainable form of the `logSumExp` reduction over `axis`. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.logSumExp = function (axis, keepDims) {
    this.throwIfDisposed();
    const result = logSumExp(this, axis, keepDims);
    return result;
};
32522
32523 /**
32524 * @license
32525 * Copyright 2020 Google LLC. All Rights Reserved.
32526 * Licensed under the Apache License, Version 2.0 (the "License");
32527 * you may not use this file except in compliance with the License.
32528 * You may obtain a copy of the License at
32529 *
32530 * http://www.apache.org/licenses/LICENSE-2.0
32531 *
32532 * Unless required by applicable law or agreed to in writing, software
32533 * distributed under the License is distributed on an "AS IS" BASIS,
32534 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32535 * See the License for the specific language governing permissions and
32536 * limitations under the License.
32537 * =============================================================================
32538 */
/** Chainable form of the natural-log op (bundled as `log$1`). Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.log = function () {
    this.throwIfDisposed();
    const result = log$1(this);
    return result;
};
32543
32544 /**
32545 * @license
32546 * Copyright 2020 Google LLC. All Rights Reserved.
32547 * Licensed under the Apache License, Version 2.0 (the "License");
32548 * you may not use this file except in compliance with the License.
32549 * You may obtain a copy of the License at
32550 *
32551 * http://www.apache.org/licenses/LICENSE-2.0
32552 *
32553 * Unless required by applicable law or agreed to in writing, software
32554 * distributed under the License is distributed on an "AS IS" BASIS,
32555 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32556 * See the License for the specific language governing permissions and
32557 * limitations under the License.
32558 * =============================================================================
32559 */
/** Chainable form of the `log1p` op. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.log1p = function () {
    this.throwIfDisposed();
    const result = log1p(this);
    return result;
};
32564
32565 /**
32566 * @license
32567 * Copyright 2020 Google LLC. All Rights Reserved.
32568 * Licensed under the Apache License, Version 2.0 (the "License");
32569 * you may not use this file except in compliance with the License.
32570 * You may obtain a copy of the License at
32571 *
32572 * http://www.apache.org/licenses/LICENSE-2.0
32573 *
32574 * Unless required by applicable law or agreed to in writing, software
32575 * distributed under the License is distributed on an "AS IS" BASIS,
32576 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32577 * See the License for the specific language governing permissions and
32578 * limitations under the License.
32579 * =============================================================================
32580 */
/** Chainable form of the `logicalAnd` op with `b`. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.logicalAnd = function (b) {
    this.throwIfDisposed();
    const result = logicalAnd(this, b);
    return result;
};
32585
32586 /**
32587 * @license
32588 * Copyright 2020 Google LLC. All Rights Reserved.
32589 * Licensed under the Apache License, Version 2.0 (the "License");
32590 * you may not use this file except in compliance with the License.
32591 * You may obtain a copy of the License at
32592 *
32593 * http://www.apache.org/licenses/LICENSE-2.0
32594 *
32595 * Unless required by applicable law or agreed to in writing, software
32596 * distributed under the License is distributed on an "AS IS" BASIS,
32597 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32598 * See the License for the specific language governing permissions and
32599 * limitations under the License.
32600 * =============================================================================
32601 */
/** Chainable form of the `logicalNot` op. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.logicalNot = function () {
    this.throwIfDisposed();
    const result = logicalNot(this);
    return result;
};
32606
32607 /**
32608 * @license
32609 * Copyright 2020 Google LLC. All Rights Reserved.
32610 * Licensed under the Apache License, Version 2.0 (the "License");
32611 * you may not use this file except in compliance with the License.
32612 * You may obtain a copy of the License at
32613 *
32614 * http://www.apache.org/licenses/LICENSE-2.0
32615 *
32616 * Unless required by applicable law or agreed to in writing, software
32617 * distributed under the License is distributed on an "AS IS" BASIS,
32618 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32619 * See the License for the specific language governing permissions and
32620 * limitations under the License.
32621 * =============================================================================
32622 */
/** Chainable form of the `logicalOr` op with `b`. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.logicalOr = function (b) {
    this.throwIfDisposed();
    const result = logicalOr(this, b);
    return result;
};
32627
32628 /**
32629 * @license
32630 * Copyright 2020 Google LLC. All Rights Reserved.
32631 * Licensed under the Apache License, Version 2.0 (the "License");
32632 * you may not use this file except in compliance with the License.
32633 * You may obtain a copy of the License at
32634 *
32635 * http://www.apache.org/licenses/LICENSE-2.0
32636 *
32637 * Unless required by applicable law or agreed to in writing, software
32638 * distributed under the License is distributed on an "AS IS" BASIS,
32639 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32640 * See the License for the specific language governing permissions and
32641 * limitations under the License.
32642 * =============================================================================
32643 */
/** Chainable form of the `logicalXor` op with `b`. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.logicalXor = function (b) {
    this.throwIfDisposed();
    const result = logicalXor(this, b);
    return result;
};
32648
32649 /**
32650 * @license
32651 * Copyright 2020 Google LLC. All Rights Reserved.
32652 * Licensed under the Apache License, Version 2.0 (the "License");
32653 * you may not use this file except in compliance with the License.
32654 * You may obtain a copy of the License at
32655 *
32656 * http://www.apache.org/licenses/LICENSE-2.0
32657 *
32658 * Unless required by applicable law or agreed to in writing, software
32659 * distributed under the License is distributed on an "AS IS" BASIS,
32660 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32661 * See the License for the specific language governing permissions and
32662 * limitations under the License.
32663 * =============================================================================
32664 */
/** Chainable form of the `matMul` op with optional transposes. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.matMul = function (b, transposeA, transposeB) {
    this.throwIfDisposed();
    const result = matMul(this, b, transposeA, transposeB);
    return result;
};
32669
/** Chainable form of the `maxPool` op. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.maxPool =
    function (filterSize, strides, pad, dimRoundingMode) {
        this.throwIfDisposed();
        const result = maxPool(this, filterSize, strides, pad, dimRoundingMode);
        return result;
    };
32675
32676 /**
32677 * @license
32678 * Copyright 2020 Google LLC. All Rights Reserved.
32679 * Licensed under the Apache License, Version 2.0 (the "License");
32680 * you may not use this file except in compliance with the License.
32681 * You may obtain a copy of the License at
32682 *
32683 * http://www.apache.org/licenses/LICENSE-2.0
32684 *
32685 * Unless required by applicable law or agreed to in writing, software
32686 * distributed under the License is distributed on an "AS IS" BASIS,
32687 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32688 * See the License for the specific language governing permissions and
32689 * limitations under the License.
32690 * =============================================================================
32691 */
/** Chainable form of the `max` reduction over `axis`. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.max = function (axis, keepDims) {
    this.throwIfDisposed();
    const result = max(this, axis, keepDims);
    return result;
};
32696
32697 /**
32698 * @license
32699 * Copyright 2020 Google LLC. All Rights Reserved.
32700 * Licensed under the Apache License, Version 2.0 (the "License");
32701 * you may not use this file except in compliance with the License.
32702 * You may obtain a copy of the License at
32703 *
32704 * http://www.apache.org/licenses/LICENSE-2.0
32705 *
32706 * Unless required by applicable law or agreed to in writing, software
32707 * distributed under the License is distributed on an "AS IS" BASIS,
32708 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32709 * See the License for the specific language governing permissions and
32710 * limitations under the License.
32711 * =============================================================================
32712 */
/** Chainable form of the element-wise `maximum` op with `b`. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.maximum = function (b) {
    this.throwIfDisposed();
    const result = maximum(this, b);
    return result;
};
32717
32718 /**
32719 * @license
32720 * Copyright 2020 Google LLC. All Rights Reserved.
32721 * Licensed under the Apache License, Version 2.0 (the "License");
32722 * you may not use this file except in compliance with the License.
32723 * You may obtain a copy of the License at
32724 *
32725 * http://www.apache.org/licenses/LICENSE-2.0
32726 *
32727 * Unless required by applicable law or agreed to in writing, software
32728 * distributed under the License is distributed on an "AS IS" BASIS,
32729 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32730 * See the License for the specific language governing permissions and
32731 * limitations under the License.
32732 * =============================================================================
32733 */
/** Chainable form of the `mean` reduction over `axis`. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.mean = function (axis, keepDims) {
    this.throwIfDisposed();
    const result = mean(this, axis, keepDims);
    return result;
};
32738
32739 /**
32740 * @license
32741 * Copyright 2020 Google LLC. All Rights Reserved.
32742 * Licensed under the Apache License, Version 2.0 (the "License");
32743 * you may not use this file except in compliance with the License.
32744 * You may obtain a copy of the License at
32745 *
32746 * http://www.apache.org/licenses/LICENSE-2.0
32747 *
32748 * Unless required by applicable law or agreed to in writing, software
32749 * distributed under the License is distributed on an "AS IS" BASIS,
32750 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32751 * See the License for the specific language governing permissions and
32752 * limitations under the License.
32753 * =============================================================================
32754 */
/** Chainable form of the `min` reduction over `axis`. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.min = function (axis, keepDims) {
    this.throwIfDisposed();
    const result = min(this, axis, keepDims);
    return result;
};
32759
32760 /**
32761 * @license
32762 * Copyright 2020 Google LLC. All Rights Reserved.
32763 * Licensed under the Apache License, Version 2.0 (the "License");
32764 * you may not use this file except in compliance with the License.
32765 * You may obtain a copy of the License at
32766 *
32767 * http://www.apache.org/licenses/LICENSE-2.0
32768 *
32769 * Unless required by applicable law or agreed to in writing, software
32770 * distributed under the License is distributed on an "AS IS" BASIS,
32771 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32772 * See the License for the specific language governing permissions and
32773 * limitations under the License.
32774 * =============================================================================
32775 */
/** Chainable form of the element-wise `minimum` op with `b`. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.minimum = function (b) {
    this.throwIfDisposed();
    const result = minimum(this, b);
    return result;
};
32780
32781 /**
32782 * @license
32783 * Copyright 2020 Google LLC. All Rights Reserved.
32784 * Licensed under the Apache License, Version 2.0 (the "License");
32785 * you may not use this file except in compliance with the License.
32786 * You may obtain a copy of the License at
32787 *
32788 * http://www.apache.org/licenses/LICENSE-2.0
32789 *
32790 * Unless required by applicable law or agreed to in writing, software
32791 * distributed under the License is distributed on an "AS IS" BASIS,
32792 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32793 * See the License for the specific language governing permissions and
32794 * limitations under the License.
32795 * =============================================================================
32796 */
/** Chainable form of the `mirrorPad` op. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.mirrorPad = function (paddings, mode) {
    this.throwIfDisposed();
    const result = mirrorPad(this, paddings, mode);
    return result;
};
32801
32802 /**
32803 * @license
32804 * Copyright 2020 Google LLC. All Rights Reserved.
32805 * Licensed under the Apache License, Version 2.0 (the "License");
32806 * you may not use this file except in compliance with the License.
32807 * You may obtain a copy of the License at
32808 *
32809 * http://www.apache.org/licenses/LICENSE-2.0
32810 *
32811 * Unless required by applicable law or agreed to in writing, software
32812 * distributed under the License is distributed on an "AS IS" BASIS,
32813 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32814 * See the License for the specific language governing permissions and
32815 * limitations under the License.
32816 * =============================================================================
32817 */
/** Chainable form of the element-wise `mod` op with `b`. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.mod = function (b) {
    this.throwIfDisposed();
    const result = mod(this, b);
    return result;
};
32822
32823 /**
32824 * @license
32825 * Copyright 2020 Google LLC. All Rights Reserved.
32826 * Licensed under the Apache License, Version 2.0 (the "License");
32827 * you may not use this file except in compliance with the License.
32828 * You may obtain a copy of the License at
32829 *
32830 * http://www.apache.org/licenses/LICENSE-2.0
32831 *
32832 * Unless required by applicable law or agreed to in writing, software
32833 * distributed under the License is distributed on an "AS IS" BASIS,
32834 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32835 * See the License for the specific language governing permissions and
32836 * limitations under the License.
32837 * =============================================================================
32838 */
/** Chainable form of the element-wise `mul` op with `b`. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.mul = function (b) {
    this.throwIfDisposed();
    const result = mul(this, b);
    return result;
};
32843
32844 /**
32845 * @license
32846 * Copyright 2020 Google LLC. All Rights Reserved.
32847 * Licensed under the Apache License, Version 2.0 (the "License");
32848 * you may not use this file except in compliance with the License.
32849 * You may obtain a copy of the License at
32850 *
32851 * http://www.apache.org/licenses/LICENSE-2.0
32852 *
32853 * Unless required by applicable law or agreed to in writing, software
32854 * distributed under the License is distributed on an "AS IS" BASIS,
32855 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32856 * See the License for the specific language governing permissions and
32857 * limitations under the License.
32858 * =============================================================================
32859 */
/** Chainable form of the `neg` op. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.neg = function () {
    this.throwIfDisposed();
    const result = neg(this);
    return result;
};
32864
32865 /**
32866 * @license
32867 * Copyright 2020 Google LLC. All Rights Reserved.
32868 * Licensed under the Apache License, Version 2.0 (the "License");
32869 * you may not use this file except in compliance with the License.
32870 * You may obtain a copy of the License at
32871 *
32872 * http://www.apache.org/licenses/LICENSE-2.0
32873 *
32874 * Unless required by applicable law or agreed to in writing, software
32875 * distributed under the License is distributed on an "AS IS" BASIS,
32876 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32877 * See the License for the specific language governing permissions and
32878 * limitations under the License.
32879 * =============================================================================
32880 */
/** Chainable form of the `norm` op (order `ord`, reduced over `axis`). Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.norm = function (ord, axis, keepDims) {
    this.throwIfDisposed();
    const result = norm(this, ord, axis, keepDims);
    return result;
};
32885
32886 /**
32887 * @license
32888 * Copyright 2020 Google LLC. All Rights Reserved.
32889 * Licensed under the Apache License, Version 2.0 (the "License");
32890 * you may not use this file except in compliance with the License.
32891 * You may obtain a copy of the License at
32892 *
32893 * http://www.apache.org/licenses/LICENSE-2.0
32894 *
32895 * Unless required by applicable law or agreed to in writing, software
32896 * distributed under the License is distributed on an "AS IS" BASIS,
32897 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32898 * See the License for the specific language governing permissions and
32899 * limitations under the License.
32900 * =============================================================================
32901 */
/** Chainable form of the `notEqual` op: element-wise `this != b`. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.notEqual = function (b) {
    this.throwIfDisposed();
    const result = notEqual(this, b);
    return result;
};
32906
32907 /**
32908 * @license
32909 * Copyright 2020 Google LLC. All Rights Reserved.
32910 * Licensed under the Apache License, Version 2.0 (the "License");
32911 * you may not use this file except in compliance with the License.
32912 * You may obtain a copy of the License at
32913 *
32914 * http://www.apache.org/licenses/LICENSE-2.0
32915 *
32916 * Unless required by applicable law or agreed to in writing, software
32917 * distributed under the License is distributed on an "AS IS" BASIS,
32918 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32919 * See the License for the specific language governing permissions and
32920 * limitations under the License.
32921 * =============================================================================
32922 */
/** Chainable form of the `oneHot` op; `onValue`/`offValue` default to 1 and 0. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.oneHot = function (depth, onValue = 1, offValue = 0) {
    this.throwIfDisposed();
    const result = oneHot(this, depth, onValue, offValue);
    return result;
};
32927
32928 /**
32929 * @license
32930 * Copyright 2020 Google LLC. All Rights Reserved.
32931 * Licensed under the Apache License, Version 2.0 (the "License");
32932 * you may not use this file except in compliance with the License.
32933 * You may obtain a copy of the License at
32934 *
32935 * http://www.apache.org/licenses/LICENSE-2.0
32936 *
32937 * Unless required by applicable law or agreed to in writing, software
32938 * distributed under the License is distributed on an "AS IS" BASIS,
32939 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32940 * See the License for the specific language governing permissions and
32941 * limitations under the License.
32942 * =============================================================================
32943 */
/** Chainable form of the `onesLike` op. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.onesLike = function () {
    this.throwIfDisposed();
    const result = onesLike(this);
    return result;
};
32948
32949 /**
32950 * @license
32951 * Copyright 2020 Google LLC. All Rights Reserved.
32952 * Licensed under the Apache License, Version 2.0 (the "License");
32953 * you may not use this file except in compliance with the License.
32954 * You may obtain a copy of the License at
32955 *
32956 * http://www.apache.org/licenses/LICENSE-2.0
32957 *
32958 * Unless required by applicable law or agreed to in writing, software
32959 * distributed under the License is distributed on an "AS IS" BASIS,
32960 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32961 * See the License for the specific language governing permissions and
32962 * limitations under the License.
32963 * =============================================================================
32964 */
/** Chainable form of the `pad` op. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.pad = function (paddings, constantValue) {
    this.throwIfDisposed();
    const result = pad(this, paddings, constantValue);
    return result;
};
32969
/** Chainable form of the generic `pool` op. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.pool = function (windowShape, poolingType, padding, dilationRate, strides, dimRoundingMode) {
    this.throwIfDisposed();
    const result = pool(this, windowShape, poolingType, padding, dilationRate, strides, dimRoundingMode);
    return result;
};
32974
32975 /**
32976 * @license
32977 * Copyright 2020 Google LLC. All Rights Reserved.
32978 * Licensed under the Apache License, Version 2.0 (the "License");
32979 * you may not use this file except in compliance with the License.
32980 * You may obtain a copy of the License at
32981 *
32982 * http://www.apache.org/licenses/LICENSE-2.0
32983 *
32984 * Unless required by applicable law or agreed to in writing, software
32985 * distributed under the License is distributed on an "AS IS" BASIS,
32986 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32987 * See the License for the specific language governing permissions and
32988 * limitations under the License.
32989 * =============================================================================
32990 */
/** Chainable form of the element-wise `pow` op with exponent `exp`. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.pow = function (exp) {
    this.throwIfDisposed();
    const result = pow(this, exp);
    return result;
};
32995
32996 /**
32997 * @license
32998 * Copyright 2020 Google LLC. All Rights Reserved.
32999 * Licensed under the Apache License, Version 2.0 (the "License");
33000 * you may not use this file except in compliance with the License.
33001 * You may obtain a copy of the License at
33002 *
33003 * http://www.apache.org/licenses/LICENSE-2.0
33004 *
33005 * Unless required by applicable law or agreed to in writing, software
33006 * distributed under the License is distributed on an "AS IS" BASIS,
33007 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33008 * See the License for the specific language governing permissions and
33009 * limitations under the License.
33010 * =============================================================================
33011 */
/** Chainable form of the `prelu` op with slope tensor `alpha`. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.prelu = function (alpha) {
    this.throwIfDisposed();
    const result = prelu(this, alpha);
    return result;
};
33016
33017 /**
33018 * @license
33019 * Copyright 2020 Google LLC. All Rights Reserved.
33020 * Licensed under the Apache License, Version 2.0 (the "License");
33021 * you may not use this file except in compliance with the License.
33022 * You may obtain a copy of the License at
33023 *
33024 * http://www.apache.org/licenses/LICENSE-2.0
33025 *
33026 * Unless required by applicable law or agreed to in writing, software
33027 * distributed under the License is distributed on an "AS IS" BASIS,
33028 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33029 * See the License for the specific language governing permissions and
33030 * limitations under the License.
33031 * =============================================================================
33032 */
/** Chainable form of the `prod` reduction over `axis`. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.prod = function (axis, keepDims) {
    this.throwIfDisposed();
    const result = prod(this, axis, keepDims);
    return result;
};
33037
33038 /**
33039 * @license
33040 * Copyright 2020 Google LLC. All Rights Reserved.
33041 * Licensed under the Apache License, Version 2.0 (the "License");
33042 * you may not use this file except in compliance with the License.
33043 * You may obtain a copy of the License at
33044 *
33045 * http://www.apache.org/licenses/LICENSE-2.0
33046 *
33047 * Unless required by applicable law or agreed to in writing, software
33048 * distributed under the License is distributed on an "AS IS" BASIS,
33049 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33050 * See the License for the specific language governing permissions and
33051 * limitations under the License.
33052 * =============================================================================
33053 */
/** Chainable form of the `reciprocal` op. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.reciprocal = function () {
    this.throwIfDisposed();
    const result = reciprocal(this);
    return result;
};
33058
33059 /**
33060 * @license
33061 * Copyright 2020 Google LLC. All Rights Reserved.
33062 * Licensed under the Apache License, Version 2.0 (the "License");
33063 * you may not use this file except in compliance with the License.
33064 * You may obtain a copy of the License at
33065 *
33066 * http://www.apache.org/licenses/LICENSE-2.0
33067 *
33068 * Unless required by applicable law or agreed to in writing, software
33069 * distributed under the License is distributed on an "AS IS" BASIS,
33070 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33071 * See the License for the specific language governing permissions and
33072 * limitations under the License.
33073 * =============================================================================
33074 */
/** Chainable form of the `relu` op. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.relu = function () {
    this.throwIfDisposed();
    const result = relu(this);
    return result;
};
33079
33080 /**
33081 * @license
33082 * Copyright 2020 Google LLC. All Rights Reserved.
33083 * Licensed under the Apache License, Version 2.0 (the "License");
33084 * you may not use this file except in compliance with the License.
33085 * You may obtain a copy of the License at
33086 *
33087 * http://www.apache.org/licenses/LICENSE-2.0
33088 *
33089 * Unless required by applicable law or agreed to in writing, software
33090 * distributed under the License is distributed on an "AS IS" BASIS,
33091 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33092 * See the License for the specific language governing permissions and
33093 * limitations under the License.
33094 * =============================================================================
33095 */
/** Chainable form of the `relu6` op. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.relu6 = function () {
    this.throwIfDisposed();
    const result = relu6(this);
    return result;
};
33100
33101 /**
33102 * @license
33103 * Copyright 2020 Google LLC. All Rights Reserved.
33104 * Licensed under the Apache License, Version 2.0 (the "License");
33105 * you may not use this file except in compliance with the License.
33106 * You may obtain a copy of the License at
33107 *
33108 * http://www.apache.org/licenses/LICENSE-2.0
33109 *
33110 * Unless required by applicable law or agreed to in writing, software
33111 * distributed under the License is distributed on an "AS IS" BASIS,
33112 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33113 * See the License for the specific language governing permissions and
33114 * limitations under the License.
33115 * =============================================================================
33116 */
/**
 * Reshapes this tensor to match the shape of the provided tensor.
 *
 * @param x The tensor whose shape is copied.
 *
 * @doc {heading: 'Tensors', subheading: 'Classes'}
 */
getGlobalTensorClass().prototype.reshapeAs = function (x) {
    this.throwIfDisposed();
    const targetShape = x.shape;
    return reshape(this, targetShape);
};
33128
33129 /**
33130 * @license
33131 * Copyright 2020 Google LLC. All Rights Reserved.
33132 * Licensed under the Apache License, Version 2.0 (the "License");
33133 * you may not use this file except in compliance with the License.
33134 * You may obtain a copy of the License at
33135 *
33136 * http://www.apache.org/licenses/LICENSE-2.0
33137 *
33138 * Unless required by applicable law or agreed to in writing, software
33139 * distributed under the License is distributed on an "AS IS" BASIS,
33140 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33141 * See the License for the specific language governing permissions and
33142 * limitations under the License.
33143 * =============================================================================
33144 */
/** Chainable form of the `reshape` op to the given `shape`. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.reshape = function (shape) {
    this.throwIfDisposed();
    const result = reshape(this, shape);
    return result;
};
33149
33150 /**
33151 * @license
33152 * Copyright 2020 Google LLC. All Rights Reserved.
33153 * Licensed under the Apache License, Version 2.0 (the "License");
33154 * you may not use this file except in compliance with the License.
33155 * You may obtain a copy of the License at
33156 *
33157 * http://www.apache.org/licenses/LICENSE-2.0
33158 *
33159 * Unless required by applicable law or agreed to in writing, software
33160 * distributed under the License is distributed on an "AS IS" BASIS,
33161 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33162 * See the License for the specific language governing permissions and
33163 * limitations under the License.
33164 * =============================================================================
33165 */
/** Chainable form of the `resizeBilinear` op. Throws if this tensor is disposed. */
getGlobalTensorClass().prototype.resizeBilinear =
    function (newShape2D, alignCorners, halfPixelCenters) {
        this.throwIfDisposed();
        const result = resizeBilinear(this, newShape2D, alignCorners, halfPixelCenters);
        return result;
    };
33171
33172 /**
33173 * @license
33174 * Copyright 2020 Google LLC. All Rights Reserved.
33175 * Licensed under the Apache License, Version 2.0 (the "License");
33176 * you may not use this file except in compliance with the License.
33177 * You may obtain a copy of the License at
33178 *
33179 * http://www.apache.org/licenses/LICENSE-2.0
33180 *
33181 * Unless required by applicable law or agreed to in writing, software
33182 * distributed under the License is distributed on an "AS IS" BASIS,
33183 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33184 * See the License for the specific language governing permissions and
33185 * limitations under the License.
33186 * =============================================================================
33187 */
33188 getGlobalTensorClass().prototype.resizeNearestNeighbor =
33189 function (newShape2D, alignCorners, halfFloatCenters) {
33190 this.throwIfDisposed();
33191 return resizeNearestNeighbor(this, newShape2D, alignCorners, halfFloatCenters);
33192 };
33193
33194 /**
33195 * @license
33196 * Copyright 2020 Google LLC. All Rights Reserved.
33197 * Licensed under the Apache License, Version 2.0 (the "License");
33198 * you may not use this file except in compliance with the License.
33199 * You may obtain a copy of the License at
33200 *
33201 * http://www.apache.org/licenses/LICENSE-2.0
33202 *
33203 * Unless required by applicable law or agreed to in writing, software
33204 * distributed under the License is distributed on an "AS IS" BASIS,
33205 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33206 * See the License for the specific language governing permissions and
33207 * limitations under the License.
33208 * =============================================================================
33209 */
33210 getGlobalTensorClass().prototype.reverse = function (axis) {
33211 this.throwIfDisposed();
33212 return reverse(this, axis);
33213 };
33214
33215 /**
33216 * @license
33217 * Copyright 2020 Google LLC. All Rights Reserved.
33218 * Licensed under the Apache License, Version 2.0 (the "License");
33219 * you may not use this file except in compliance with the License.
33220 * You may obtain a copy of the License at
33221 *
33222 * http://www.apache.org/licenses/LICENSE-2.0
33223 *
33224 * Unless required by applicable law or agreed to in writing, software
33225 * distributed under the License is distributed on an "AS IS" BASIS,
33226 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33227 * See the License for the specific language governing permissions and
33228 * limitations under the License.
33229 * =============================================================================
33230 */
33231 getGlobalTensorClass().prototype.rfft = function () {
33232 this.throwIfDisposed();
33233 return rfft(this);
33234 };
33235
33236 /**
33237 * @license
33238 * Copyright 2020 Google LLC. All Rights Reserved.
33239 * Licensed under the Apache License, Version 2.0 (the "License");
33240 * you may not use this file except in compliance with the License.
33241 * You may obtain a copy of the License at
33242 *
33243 * http://www.apache.org/licenses/LICENSE-2.0
33244 *
33245 * Unless required by applicable law or agreed to in writing, software
33246 * distributed under the License is distributed on an "AS IS" BASIS,
33247 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33248 * See the License for the specific language governing permissions and
33249 * limitations under the License.
33250 * =============================================================================
33251 */
33252 getGlobalTensorClass().prototype.round = function () {
33253 this.throwIfDisposed();
33254 return round$1(this);
33255 };
33256
33257 /**
33258 * @license
33259 * Copyright 2020 Google LLC. All Rights Reserved.
33260 * Licensed under the Apache License, Version 2.0 (the "License");
33261 * you may not use this file except in compliance with the License.
33262 * You may obtain a copy of the License at
33263 *
33264 * http://www.apache.org/licenses/LICENSE-2.0
33265 *
33266 * Unless required by applicable law or agreed to in writing, software
33267 * distributed under the License is distributed on an "AS IS" BASIS,
33268 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33269 * See the License for the specific language governing permissions and
33270 * limitations under the License.
33271 * =============================================================================
33272 */
33273 getGlobalTensorClass().prototype.rsqrt = function () {
33274 this.throwIfDisposed();
33275 return rsqrt(this);
33276 };
33277
33278 /**
33279 * @license
33280 * Copyright 2020 Google LLC. All Rights Reserved.
33281 * Licensed under the Apache License, Version 2.0 (the "License");
33282 * you may not use this file except in compliance with the License.
33283 * You may obtain a copy of the License at
33284 *
33285 * http://www.apache.org/licenses/LICENSE-2.0
33286 *
33287 * Unless required by applicable law or agreed to in writing, software
33288 * distributed under the License is distributed on an "AS IS" BASIS,
33289 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33290 * See the License for the specific language governing permissions and
33291 * limitations under the License.
33292 * =============================================================================
33293 */
33294 getGlobalTensorClass().prototype.selu = function () {
33295 this.throwIfDisposed();
33296 return selu(this);
33297 };
33298
33299 /**
33300 * @license
33301 * Copyright 2020 Google LLC. All Rights Reserved.
33302 * Licensed under the Apache License, Version 2.0 (the "License");
33303 * you may not use this file except in compliance with the License.
33304 * You may obtain a copy of the License at
33305 *
33306 * http://www.apache.org/licenses/LICENSE-2.0
33307 *
33308 * Unless required by applicable law or agreed to in writing, software
33309 * distributed under the License is distributed on an "AS IS" BASIS,
33310 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33311 * See the License for the specific language governing permissions and
33312 * limitations under the License.
33313 * =============================================================================
33314 */
33315 getGlobalTensorClass().prototype.separableConv2d =
33316 function (depthwiseFilter, pointwiseFilter, strides, pad, dilation, dataFormat) {
33317 this.throwIfDisposed();
33318 return separableConv2d(this, depthwiseFilter, pointwiseFilter, strides, pad, dilation, dataFormat);
33319 };
33320
33321 /**
33322 * @license
33323 * Copyright 2020 Google LLC. All Rights Reserved.
33324 * Licensed under the Apache License, Version 2.0 (the "License");
33325 * you may not use this file except in compliance with the License.
33326 * You may obtain a copy of the License at
33327 *
33328 * http://www.apache.org/licenses/LICENSE-2.0
33329 *
33330 * Unless required by applicable law or agreed to in writing, software
33331 * distributed under the License is distributed on an "AS IS" BASIS,
33332 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33333 * See the License for the specific language governing permissions and
33334 * limitations under the License.
33335 * =============================================================================
33336 */
33337 getGlobalTensorClass().prototype.sigmoid = function () {
33338 this.throwIfDisposed();
33339 return sigmoid(this);
33340 };
33341
33342 /**
33343 * @license
33344 * Copyright 2020 Google LLC. All Rights Reserved.
33345 * Licensed under the Apache License, Version 2.0 (the "License");
33346 * you may not use this file except in compliance with the License.
33347 * You may obtain a copy of the License at
33348 *
33349 * http://www.apache.org/licenses/LICENSE-2.0
33350 *
33351 * Unless required by applicable law or agreed to in writing, software
33352 * distributed under the License is distributed on an "AS IS" BASIS,
33353 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33354 * See the License for the specific language governing permissions and
33355 * limitations under the License.
33356 * =============================================================================
33357 */
33358 getGlobalTensorClass().prototype.sign = function () {
33359 this.throwIfDisposed();
33360 return sign(this);
33361 };
33362
33363 /**
33364 * @license
33365 * Copyright 2020 Google LLC. All Rights Reserved.
33366 * Licensed under the Apache License, Version 2.0 (the "License");
33367 * you may not use this file except in compliance with the License.
33368 * You may obtain a copy of the License at
33369 *
33370 * http://www.apache.org/licenses/LICENSE-2.0
33371 *
33372 * Unless required by applicable law or agreed to in writing, software
33373 * distributed under the License is distributed on an "AS IS" BASIS,
33374 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33375 * See the License for the specific language governing permissions and
33376 * limitations under the License.
33377 * =============================================================================
33378 */
33379 getGlobalTensorClass().prototype.sin = function () {
33380 this.throwIfDisposed();
33381 return sin(this);
33382 };
33383
33384 /**
33385 * @license
33386 * Copyright 2020 Google LLC. All Rights Reserved.
33387 * Licensed under the Apache License, Version 2.0 (the "License");
33388 * you may not use this file except in compliance with the License.
33389 * You may obtain a copy of the License at
33390 *
33391 * http://www.apache.org/licenses/LICENSE-2.0
33392 *
33393 * Unless required by applicable law or agreed to in writing, software
33394 * distributed under the License is distributed on an "AS IS" BASIS,
33395 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33396 * See the License for the specific language governing permissions and
33397 * limitations under the License.
33398 * =============================================================================
33399 */
33400 getGlobalTensorClass().prototype.sinh = function () {
33401 this.throwIfDisposed();
33402 return sinh(this);
33403 };
33404
33405 /**
33406 * @license
33407 * Copyright 2020 Google LLC. All Rights Reserved.
33408 * Licensed under the Apache License, Version 2.0 (the "License");
33409 * you may not use this file except in compliance with the License.
33410 * You may obtain a copy of the License at
33411 *
33412 * http://www.apache.org/licenses/LICENSE-2.0
33413 *
33414 * Unless required by applicable law or agreed to in writing, software
33415 * distributed under the License is distributed on an "AS IS" BASIS,
33416 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33417 * See the License for the specific language governing permissions and
33418 * limitations under the License.
33419 * =============================================================================
33420 */
33421 getGlobalTensorClass().prototype.slice = function (begin, size) {
33422 this.throwIfDisposed();
33423 return slice(this, begin, size);
33424 };
33425
33426 /**
33427 * @license
33428 * Copyright 2020 Google LLC. All Rights Reserved.
33429 * Licensed under the Apache License, Version 2.0 (the "License");
33430 * you may not use this file except in compliance with the License.
33431 * You may obtain a copy of the License at
33432 *
33433 * http://www.apache.org/licenses/LICENSE-2.0
33434 *
33435 * Unless required by applicable law or agreed to in writing, software
33436 * distributed under the License is distributed on an "AS IS" BASIS,
33437 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33438 * See the License for the specific language governing permissions and
33439 * limitations under the License.
33440 * =============================================================================
33441 */
33442 getGlobalTensorClass().prototype.softmax = function (dim) {
33443 this.throwIfDisposed();
33444 return softmax(this, dim);
33445 };
33446
33447 /**
33448 * @license
33449 * Copyright 2020 Google LLC. All Rights Reserved.
33450 * Licensed under the Apache License, Version 2.0 (the "License");
33451 * you may not use this file except in compliance with the License.
33452 * You may obtain a copy of the License at
33453 *
33454 * http://www.apache.org/licenses/LICENSE-2.0
33455 *
33456 * Unless required by applicable law or agreed to in writing, software
33457 * distributed under the License is distributed on an "AS IS" BASIS,
33458 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33459 * See the License for the specific language governing permissions and
33460 * limitations under the License.
33461 * =============================================================================
33462 */
33463 getGlobalTensorClass().prototype.softplus = function () {
33464 this.throwIfDisposed();
33465 return softplus(this);
33466 };
33467
33468 /**
33469 * @license
33470 * Copyright 2020 Google LLC. All Rights Reserved.
33471 * Licensed under the Apache License, Version 2.0 (the "License");
33472 * you may not use this file except in compliance with the License.
33473 * You may obtain a copy of the License at
33474 *
33475 * http://www.apache.org/licenses/LICENSE-2.0
33476 *
33477 * Unless required by applicable law or agreed to in writing, software
33478 * distributed under the License is distributed on an "AS IS" BASIS,
33479 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33480 * See the License for the specific language governing permissions and
33481 * limitations under the License.
33482 * =============================================================================
33483 */
33484 getGlobalTensorClass().prototype.spaceToBatchND = function (blockShape, paddings) {
33485 this.throwIfDisposed();
33486 return spaceToBatchND(this, blockShape, paddings);
33487 };
33488
33489 /**
33490 * @license
33491 * Copyright 2020 Google LLC. All Rights Reserved.
33492 * Licensed under the Apache License, Version 2.0 (the "License");
33493 * you may not use this file except in compliance with the License.
33494 * You may obtain a copy of the License at
33495 *
33496 * http://www.apache.org/licenses/LICENSE-2.0
33497 *
33498 * Unless required by applicable law or agreed to in writing, software
33499 * distributed under the License is distributed on an "AS IS" BASIS,
33500 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33501 * See the License for the specific language governing permissions and
33502 * limitations under the License.
33503 * =============================================================================
33504 */
33505 getGlobalTensorClass().prototype.split = function (numOrSizeSplits, axis) {
33506 this.throwIfDisposed();
33507 return split(this, numOrSizeSplits, axis);
33508 };
33509
33510 /**
33511 * @license
33512 * Copyright 2020 Google LLC. All Rights Reserved.
33513 * Licensed under the Apache License, Version 2.0 (the "License");
33514 * you may not use this file except in compliance with the License.
33515 * You may obtain a copy of the License at
33516 *
33517 * http://www.apache.org/licenses/LICENSE-2.0
33518 *
33519 * Unless required by applicable law or agreed to in writing, software
33520 * distributed under the License is distributed on an "AS IS" BASIS,
33521 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33522 * See the License for the specific language governing permissions and
33523 * limitations under the License.
33524 * =============================================================================
33525 */
33526 getGlobalTensorClass().prototype.sqrt = function () {
33527 this.throwIfDisposed();
33528 return sqrt(this);
33529 };
33530
33531 /**
33532 * @license
33533 * Copyright 2020 Google LLC. All Rights Reserved.
33534 * Licensed under the Apache License, Version 2.0 (the "License");
33535 * you may not use this file except in compliance with the License.
33536 * You may obtain a copy of the License at
33537 *
33538 * http://www.apache.org/licenses/LICENSE-2.0
33539 *
33540 * Unless required by applicable law or agreed to in writing, software
33541 * distributed under the License is distributed on an "AS IS" BASIS,
33542 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33543 * See the License for the specific language governing permissions and
33544 * limitations under the License.
33545 * =============================================================================
33546 */
33547 getGlobalTensorClass().prototype.square = function () {
33548 this.throwIfDisposed();
33549 return square(this);
33550 };
33551
33552 /**
33553 * @license
33554 * Copyright 2020 Google LLC. All Rights Reserved.
33555 * Licensed under the Apache License, Version 2.0 (the "License");
33556 * you may not use this file except in compliance with the License.
33557 * You may obtain a copy of the License at
33558 *
33559 * http://www.apache.org/licenses/LICENSE-2.0
33560 *
33561 * Unless required by applicable law or agreed to in writing, software
33562 * distributed under the License is distributed on an "AS IS" BASIS,
33563 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33564 * See the License for the specific language governing permissions and
33565 * limitations under the License.
33566 * =============================================================================
33567 */
33568 getGlobalTensorClass().prototype.squaredDifference = function (b) {
33569 this.throwIfDisposed();
33570 return squaredDifference(this, b);
33571 };
33572
33573 /**
33574 * @license
33575 * Copyright 2020 Google LLC. All Rights Reserved.
33576 * Licensed under the Apache License, Version 2.0 (the "License");
33577 * you may not use this file except in compliance with the License.
33578 * You may obtain a copy of the License at
33579 *
33580 * http://www.apache.org/licenses/LICENSE-2.0
33581 *
33582 * Unless required by applicable law or agreed to in writing, software
33583 * distributed under the License is distributed on an "AS IS" BASIS,
33584 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33585 * See the License for the specific language governing permissions and
33586 * limitations under the License.
33587 * =============================================================================
33588 */
33589 getGlobalTensorClass().prototype.squeeze = function (axis) {
33590 this.throwIfDisposed();
33591 return squeeze(this, axis);
33592 };
33593
33594 /**
33595 * @license
33596 * Copyright 2020 Google LLC. All Rights Reserved.
33597 * Licensed under the Apache License, Version 2.0 (the "License");
33598 * you may not use this file except in compliance with the License.
33599 * You may obtain a copy of the License at
33600 *
33601 * http://www.apache.org/licenses/LICENSE-2.0
33602 *
33603 * Unless required by applicable law or agreed to in writing, software
33604 * distributed under the License is distributed on an "AS IS" BASIS,
33605 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33606 * See the License for the specific language governing permissions and
33607 * limitations under the License.
33608 * =============================================================================
33609 */
33610 getGlobalTensorClass().prototype.stack = function (x, axis) {
33611 this.throwIfDisposed();
33612 const tensorsToBeStacked = x instanceof Tensor ? [this, x] : [this, ...x];
33613 return stack(tensorsToBeStacked, axis);
33614 };
33615
33616 /**
33617 * @license
33618 * Copyright 2020 Google LLC. All Rights Reserved.
33619 * Licensed under the Apache License, Version 2.0 (the "License");
33620 * you may not use this file except in compliance with the License.
33621 * You may obtain a copy of the License at
33622 *
33623 * http://www.apache.org/licenses/LICENSE-2.0
33624 *
33625 * Unless required by applicable law or agreed to in writing, software
33626 * distributed under the License is distributed on an "AS IS" BASIS,
33627 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33628 * See the License for the specific language governing permissions and
33629 * limitations under the License.
33630 * =============================================================================
33631 */
33632 getGlobalTensorClass().prototype.step = function (alpha) {
33633 this.throwIfDisposed();
33634 return step(this, alpha);
33635 };
33636
33637 /**
33638 * @license
33639 * Copyright 2020 Google LLC. All Rights Reserved.
33640 * Licensed under the Apache License, Version 2.0 (the "License");
33641 * you may not use this file except in compliance with the License.
33642 * You may obtain a copy of the License at
33643 *
33644 * http://www.apache.org/licenses/LICENSE-2.0
33645 *
33646 * Unless required by applicable law or agreed to in writing, software
33647 * distributed under the License is distributed on an "AS IS" BASIS,
33648 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33649 * See the License for the specific language governing permissions and
33650 * limitations under the License.
33651 * =============================================================================
33652 */
33653 getGlobalTensorClass().prototype.stridedSlice = function (begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) {
33654 this.throwIfDisposed();
33655 return stridedSlice(this, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
33656 };
33657
33658 /**
33659 * @license
33660 * Copyright 2020 Google LLC. All Rights Reserved.
33661 * Licensed under the Apache License, Version 2.0 (the "License");
33662 * you may not use this file except in compliance with the License.
33663 * You may obtain a copy of the License at
33664 *
33665 * http://www.apache.org/licenses/LICENSE-2.0
33666 *
33667 * Unless required by applicable law or agreed to in writing, software
33668 * distributed under the License is distributed on an "AS IS" BASIS,
33669 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33670 * See the License for the specific language governing permissions and
33671 * limitations under the License.
33672 * =============================================================================
33673 */
33674 getGlobalTensorClass().prototype.sub = function (b) {
33675 this.throwIfDisposed();
33676 return sub(this, b);
33677 };
33678
33679 /**
33680 * @license
33681 * Copyright 2020 Google LLC. All Rights Reserved.
33682 * Licensed under the Apache License, Version 2.0 (the "License");
33683 * you may not use this file except in compliance with the License.
33684 * You may obtain a copy of the License at
33685 *
33686 * http://www.apache.org/licenses/LICENSE-2.0
33687 *
33688 * Unless required by applicable law or agreed to in writing, software
33689 * distributed under the License is distributed on an "AS IS" BASIS,
33690 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33691 * See the License for the specific language governing permissions and
33692 * limitations under the License.
33693 * =============================================================================
33694 */
33695 getGlobalTensorClass().prototype.sum = function (axis, keepDims) {
33696 this.throwIfDisposed();
33697 return sum$1(this, axis, keepDims);
33698 };
33699
33700 /**
33701 * @license
33702 * Copyright 2020 Google LLC. All Rights Reserved.
33703 * Licensed under the Apache License, Version 2.0 (the "License");
33704 * you may not use this file except in compliance with the License.
33705 * You may obtain a copy of the License at
33706 *
33707 * http://www.apache.org/licenses/LICENSE-2.0
33708 *
33709 * Unless required by applicable law or agreed to in writing, software
33710 * distributed under the License is distributed on an "AS IS" BASIS,
33711 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33712 * See the License for the specific language governing permissions and
33713 * limitations under the License.
33714 * =============================================================================
33715 */
33716 getGlobalTensorClass().prototype.tan = function () {
33717 this.throwIfDisposed();
33718 return tan(this);
33719 };
33720
33721 /**
33722 * @license
33723 * Copyright 2020 Google LLC. All Rights Reserved.
33724 * Licensed under the Apache License, Version 2.0 (the "License");
33725 * you may not use this file except in compliance with the License.
33726 * You may obtain a copy of the License at
33727 *
33728 * http://www.apache.org/licenses/LICENSE-2.0
33729 *
33730 * Unless required by applicable law or agreed to in writing, software
33731 * distributed under the License is distributed on an "AS IS" BASIS,
33732 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33733 * See the License for the specific language governing permissions and
33734 * limitations under the License.
33735 * =============================================================================
33736 */
33737 getGlobalTensorClass().prototype.tanh = function () {
33738 this.throwIfDisposed();
33739 return tanh$1(this);
33740 };
33741
33742 /**
33743 * @license
33744 * Copyright 2020 Google LLC. All Rights Reserved.
33745 * Licensed under the Apache License, Version 2.0 (the "License");
33746 * you may not use this file except in compliance with the License.
33747 * You may obtain a copy of the License at
33748 *
33749 * http://www.apache.org/licenses/LICENSE-2.0
33750 *
33751 * Unless required by applicable law or agreed to in writing, software
33752 * distributed under the License is distributed on an "AS IS" BASIS,
33753 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33754 * See the License for the specific language governing permissions and
33755 * limitations under the License.
33756 * =============================================================================
33757 */
33758 getGlobalTensorClass().prototype.tile = function (reps) {
33759 this.throwIfDisposed();
33760 return tile(this, reps);
33761 };
33762
33763 /**
33764 * @license
33765 * Copyright 2020 Google LLC. All Rights Reserved.
33766 * Licensed under the Apache License, Version 2.0 (the "License");
33767 * you may not use this file except in compliance with the License.
33768 * You may obtain a copy of the License at
33769 *
33770 * http://www.apache.org/licenses/LICENSE-2.0
33771 *
33772 * Unless required by applicable law or agreed to in writing, software
33773 * distributed under the License is distributed on an "AS IS" BASIS,
33774 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33775 * See the License for the specific language governing permissions and
33776 * limitations under the License.
33777 * =============================================================================
33778 */
33779 /**
33780 * Casts the array to type `bool`
33781 *
33782 * @doc {heading: 'Tensors', subheading: 'Classes'}
33783 */
33784 getGlobalTensorClass().prototype.toBool = function () {
33785 this.throwIfDisposed();
33786 return cast(this, 'bool');
33787 };
33788
33789 /**
33790 * @license
33791 * Copyright 2020 Google LLC. All Rights Reserved.
33792 * Licensed under the Apache License, Version 2.0 (the "License");
33793 * you may not use this file except in compliance with the License.
33794 * You may obtain a copy of the License at
33795 *
33796 * http://www.apache.org/licenses/LICENSE-2.0
33797 *
33798 * Unless required by applicable law or agreed to in writing, software
33799 * distributed under the License is distributed on an "AS IS" BASIS,
33800 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33801 * See the License for the specific language governing permissions and
33802 * limitations under the License.
33803 * =============================================================================
33804 */
33805 /**
33806 * Casts the array to type `float32`
33807 *
33808 * @doc {heading: 'Tensors', subheading: 'Classes'}
33809 */
33810 getGlobalTensorClass().prototype.toFloat = function () {
33811 this.throwIfDisposed();
33812 return cast(this, 'float32');
33813 };
33814
33815 /**
33816 * @license
33817 * Copyright 2020 Google LLC. All Rights Reserved.
33818 * Licensed under the Apache License, Version 2.0 (the "License");
33819 * you may not use this file except in compliance with the License.
33820 * You may obtain a copy of the License at
33821 *
33822 * http://www.apache.org/licenses/LICENSE-2.0
33823 *
33824 * Unless required by applicable law or agreed to in writing, software
33825 * distributed under the License is distributed on an "AS IS" BASIS,
33826 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33827 * See the License for the specific language governing permissions and
33828 * limitations under the License.
33829 * =============================================================================
33830 */
33831 /**
33832 * Casts the array to type `int32`
33833 *
33834 * @doc {heading: 'Tensors', subheading: 'Classes'}
33835 */
33836 getGlobalTensorClass().prototype.toInt = function () {
33837 this.throwIfDisposed();
33838 return cast(this, 'int32');
33839 };
33840
33841 /**
33842 * @license
33843 * Copyright 2020 Google LLC. All Rights Reserved.
33844 * Licensed under the Apache License, Version 2.0 (the "License");
33845 * you may not use this file except in compliance with the License.
33846 * You may obtain a copy of the License at
33847 *
33848 * http://www.apache.org/licenses/LICENSE-2.0
33849 *
33850 * Unless required by applicable law or agreed to in writing, software
33851 * distributed under the License is distributed on an "AS IS" BASIS,
33852 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33853 * See the License for the specific language governing permissions and
33854 * limitations under the License.
33855 * =============================================================================
33856 */
33857 getGlobalTensorClass().prototype.topk = function (k, sorted) {
33858 this.throwIfDisposed();
33859 return topk(this, k, sorted);
33860 };
33861
33862 /**
33863 * @license
33864 * Copyright 2020 Google LLC. All Rights Reserved.
33865 * Licensed under the Apache License, Version 2.0 (the "License");
33866 * you may not use this file except in compliance with the License.
33867 * You may obtain a copy of the License at
33868 *
33869 * http://www.apache.org/licenses/LICENSE-2.0
33870 *
33871 * Unless required by applicable law or agreed to in writing, software
33872 * distributed under the License is distributed on an "AS IS" BASIS,
33873 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33874 * See the License for the specific language governing permissions and
33875 * limitations under the License.
33876 * =============================================================================
33877 */
33878 getGlobalTensorClass().prototype.transpose = function (perm) {
33879 this.throwIfDisposed();
33880 return transpose(this, perm);
33881 };
33882
33883 /**
33884 * @license
33885 * Copyright 2020 Google LLC. All Rights Reserved.
33886 * Licensed under the Apache License, Version 2.0 (the "License");
33887 * you may not use this file except in compliance with the License.
33888 * You may obtain a copy of the License at
33889 *
33890 * http://www.apache.org/licenses/LICENSE-2.0
33891 *
33892 * Unless required by applicable law or agreed to in writing, software
33893 * distributed under the License is distributed on an "AS IS" BASIS,
33894 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33895 * See the License for the specific language governing permissions and
33896 * limitations under the License.
33897 * =============================================================================
33898 */
/**
 * Chainable Tensor method delegating to the functional `unique` op.
 */
getGlobalTensorClass().prototype.unique = function (axis) {
    const self = this;
    self.throwIfDisposed();
    return unique(self, axis);
};
33903
33904 /**
33905 * @license
33906 * Copyright 2020 Google LLC. All Rights Reserved.
33907 * Licensed under the Apache License, Version 2.0 (the "License");
33908 * you may not use this file except in compliance with the License.
33909 * You may obtain a copy of the License at
33910 *
33911 * http://www.apache.org/licenses/LICENSE-2.0
33912 *
33913 * Unless required by applicable law or agreed to in writing, software
33914 * distributed under the License is distributed on an "AS IS" BASIS,
33915 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33916 * See the License for the specific language governing permissions and
33917 * limitations under the License.
33918 * =============================================================================
33919 */
/**
 * Chainable Tensor method delegating to the functional `unsortedSegmentSum`
 * op.
 */
getGlobalTensorClass().prototype.unsortedSegmentSum =
    function (segmentIds, numSegments) {
        const self = this;
        self.throwIfDisposed();
        return unsortedSegmentSum(self, segmentIds, numSegments);
    };
33925
33926 /**
33927 * @license
33928 * Copyright 2020 Google LLC. All Rights Reserved.
33929 * Licensed under the Apache License, Version 2.0 (the "License");
33930 * you may not use this file except in compliance with the License.
33931 * You may obtain a copy of the License at
33932 *
33933 * http://www.apache.org/licenses/LICENSE-2.0
33934 *
33935 * Unless required by applicable law or agreed to in writing, software
33936 * distributed under the License is distributed on an "AS IS" BASIS,
33937 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33938 * See the License for the specific language governing permissions and
33939 * limitations under the License.
33940 * =============================================================================
33941 */
/**
 * Chainable Tensor method delegating to the functional `unstack` op.
 */
getGlobalTensorClass().prototype.unstack = function (axis) {
    const self = this;
    self.throwIfDisposed();
    return unstack(self, axis);
};
33946
33947 /**
33948 * @license
33949 * Copyright 2020 Google LLC. All Rights Reserved.
33950 * Licensed under the Apache License, Version 2.0 (the "License");
33951 * you may not use this file except in compliance with the License.
33952 * You may obtain a copy of the License at
33953 *
33954 * http://www.apache.org/licenses/LICENSE-2.0
33955 *
33956 * Unless required by applicable law or agreed to in writing, software
33957 * distributed under the License is distributed on an "AS IS" BASIS,
33958 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33959 * See the License for the specific language governing permissions and
33960 * limitations under the License.
33961 * =============================================================================
33962 */
/**
 * Chainable Tensor method delegating to the functional `where` op.
 * Note the receiver is passed as the second argument (the "a" branch),
 * matching `where(condition, a, b)`.
 */
getGlobalTensorClass().prototype.where = function (condition, x) {
    const self = this;
    self.throwIfDisposed();
    return where(condition, self, x);
};
33967
33968 /**
33969 * @license
33970 * Copyright 2020 Google LLC. All Rights Reserved.
33971 * Licensed under the Apache License, Version 2.0 (the "License");
33972 * you may not use this file except in compliance with the License.
33973 * You may obtain a copy of the License at
33974 *
33975 * http://www.apache.org/licenses/LICENSE-2.0
33976 *
33977 * Unless required by applicable law or agreed to in writing, software
33978 * distributed under the License is distributed on an "AS IS" BASIS,
33979 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33980 * See the License for the specific language governing permissions and
33981 * limitations under the License.
33982 * =============================================================================
33983 */
/**
 * Chainable Tensor method delegating to the functional `zerosLike` op.
 */
getGlobalTensorClass().prototype.zerosLike = function () {
    const self = this;
    self.throwIfDisposed();
    return zerosLike(self);
};
33988
33989 /**
33990 * @license
33991 * Copyright 2020 Google LLC. All Rights Reserved.
33992 * Licensed under the Apache License, Version 2.0 (the "License");
33993 * you may not use this file except in compliance with the License.
33994 * You may obtain a copy of the License at
33995 *
33996 * http://www.apache.org/licenses/LICENSE-2.0
33997 *
33998 * Unless required by applicable law or agreed to in writing, software
33999 * distributed under the License is distributed on an "AS IS" BASIS,
34000 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34001 * See the License for the specific language governing permissions and
34002 * limitations under the License.
34003 * =============================================================================
34004 */
34005
34006 /**
34007 * @license
34008 * Copyright 2018 Google LLC
34009 *
34010 * Use of this source code is governed by an MIT-style
34011 * license that can be found in the LICENSE file or at
34012 * https://opensource.org/licenses/MIT.
34013 * =============================================================================
34014 */
34015 /**
34016 * Explicit error types.
34017 *
34018 * See the following link for more information about why the code includes
34019 * calls to setPrototypeOf:
34020 *
34021 * https://github.com/Microsoft/TypeScript-wiki/blob/master/Breaking-Changes.md#extending-built-ins-like-error-array-and-map-may-no-longer-work
34022 */
34023 // tslint:enable
/**
 * Equivalent of Python's AttributeError.
 */
class AttributeError extends Error {
    constructor(msg) {
        super(msg);
        // Restore the prototype chain; compiling `extends Error` down to ES5
        // breaks `instanceof` without this (see TypeScript wiki on extending
        // built-ins).
        Object.setPrototypeOf(this, AttributeError.prototype);
    }
}
/**
 * Equivalent of Python's RuntimeError.
 */
class RuntimeError extends Error {
    constructor(msg) {
        super(msg);
        // Restore the prototype chain for `instanceof` after ES5 downleveling.
        Object.setPrototypeOf(this, RuntimeError.prototype);
    }
}
/**
 * Equivalent of Python's ValueError.
 */
class ValueError extends Error {
    constructor(msg) {
        super(msg);
        // Restore the prototype chain for `instanceof` after ES5 downleveling.
        Object.setPrototypeOf(this, ValueError.prototype);
    }
}
/**
 * Equivalent of Python's NotImplementedError.
 */
class NotImplementedError extends Error {
    constructor(msg) {
        super(msg);
        // Restore the prototype chain for `instanceof` after ES5 downleveling.
        Object.setPrototypeOf(this, NotImplementedError.prototype);
    }
}
/**
 * Equivalent of Python's AssertionError.
 */
class AssertionError extends Error {
    constructor(msg) {
        super(msg);
        // Restore the prototype chain for `instanceof` after ES5 downleveling.
        Object.setPrototypeOf(this, AssertionError.prototype);
    }
}
/**
 * Equivalent of Python's IndexError.
 */
class IndexError extends Error {
    constructor(msg) {
        super(msg);
        // Restore the prototype chain for `instanceof` after ES5 downleveling.
        Object.setPrototypeOf(this, IndexError.prototype);
    }
}
34084
34085 /**
34086 * @license
34087 * Copyright 2022 Google LLC
34088 *
34089 * Use of this source code is governed by an MIT-style
34090 * license that can be found in the LICENSE file or at
34091 * https://opensource.org/licenses/MIT.
34092 * =============================================================================
34093 */
/**
 * LruCache: A mapping from the String to T. If the number of the entries is
 * exceeding the `maxEntries`, the LruCache will delete the least recently
 * used entry.
 *
 * Implementation note: a `Map` preserves insertion order, so the first key
 * yielded by `keys()` is always the least recently used entry, given that
 * `get`/`put` re-insert touched keys at the end.
 */
class LruCache {
    constructor(maxEntries) {
        // NOTE(review): `||` maps an explicit 0 to the default of 100, which
        // is inconsistent with setMaxEntries(0); kept for backward
        // compatibility.
        this.maxEntries = maxEntries || 100;
        this.cache = new Map();
    }
    /**
     * Get the entry for the key and mark it as used recently.
     * Returns `undefined` for a missing key.
     */
    get(key) {
        let entry;
        if (this.cache.has(key)) {
            entry = this.cache.get(key);
            // Delete + re-set moves the key to the end of the Map's
            // insertion order, marking it most recently used.
            this.cache.delete(key);
            this.cache.set(key, entry);
        }
        return entry;
    }
    /**
     * Put the entry into the cache. If the key already existed, mark the key
     * as used recently. Evicts the least recently used entry when at
     * capacity.
     */
    put(key, value) {
        if (this.cache.has(key)) {
            this.cache.delete(key);
        }
        else if (this.cache.size >= this.maxEntries) {
            // The first key in iteration order is the least recently used.
            const keyToDelete = this.cache.keys().next().value;
            this.cache.delete(keyToDelete);
        }
        this.cache.set(key, value);
    }
    /**
     * Get the MaxEntries of the cache.
     */
    getMaxEntries() {
        return this.maxEntries;
    }
    /**
     * Set the MaxEntries of the cache. If the maxEntries is decreased, reduce
     * entries in the cache.
     * @throws Error if `maxEntries` is negative.
     */
    setMaxEntries(maxEntries) {
        if (maxEntries < 0) {
            throw new Error(`The maxEntries of LRU caches must be at least 0, but got ${maxEntries}.`);
        }
        // Bug fix: evict based on the actual number of cached entries rather
        // than the previous capacity. The old loop deleted
        // `this.maxEntries - maxEntries` keys unconditionally, discarding
        // live entries even when the cache held fewer than the new capacity.
        while (this.cache.size > maxEntries) {
            const keyToDelete = this.cache.keys().next().value;
            this.cache.delete(keyToDelete);
        }
        this.maxEntries = maxEntries;
    }
}
34153
34154 /**
34155 * @license
34156 * Copyright 2018 Google LLC
34157 *
34158 * Use of this source code is governed by an MIT-style
34159 * license that can be found in the LICENSE file or at
34160 * https://opensource.org/licenses/MIT.
34161 * =============================================================================
34162 */
34163 // tslint:enable
/**
 * If `value` is an Array, equivalent to Python's `value * numValues`.
 * If `value` is not an Array, equivalent to Python's `[value] * numValues`
 */
// tslint:disable-next-line:no-any
function pyListRepeat(value, numValues) {
    if (Array.isArray(value)) {
        // Repeat the array's elements `numValues` times (shallow copies).
        // tslint:disable-next-line:no-any
        const repeated = [];
        for (let rep = 0; rep < numValues; rep++) {
            repeated.push(...value);
        }
        return repeated;
    }
    // Scalar case: an array of `numValues` references to `value`.
    return new Array(numValues).fill(value);
}
/**
 * Throw an `AssertionError` with the given message if `val` is falsy.
 */
function assert$1(val, message) {
    if (val) {
        return;
    }
    throw new AssertionError(message);
}
/**
 * Count the number of elements of the `array` that are equal (===) to
 * `reference`.
 */
function count(array, reference) {
    let total = 0;
    for (const item of array) {
        if (item === reference) {
            total += 1;
        }
    }
    return total;
}
/**
 * If an array is of length 1, just return the first element. Otherwise,
 * return the full array.
 * @param xs The array to unwrap.
 */
function singletonOrArray(xs) {
    return xs.length === 1 ? xs[0] : xs;
}
/**
 * Normalizes a list/tensor into a list.
 *
 * If a non-array value is passed, we return a list of size 1 containing it.
 * An array is returned unchanged (same reference, no copy).
 *
 * @param x target object to be normalized.
 */
// tslint:disable-next-line:no-any
function toList(x) {
    return Array.isArray(x) ? x : [x];
}
/**
 * Generate a UID for a list of objects, joining `Math.abs(obj.id)` for each
 * object with ', '.
 * @throws ValueError if any object lacks an `id`.
 */
// tslint:disable-next-line:no-any
function objectListUid(objs) {
    const parts = [];
    for (const obj of toList(objs)) {
        if (obj.id == null) {
            throw new ValueError(`Object ${obj} passed to objectListUid without an id`);
        }
        parts.push(`${Math.abs(obj.id)}`);
    }
    return parts.join(', ');
}
/**
 * Converts string to snake-case.
 * @param name
 */
function toSnakeCase(name) {
    const withBoundaries = name.replace(/(.)([A-Z][a-z0-9]+)/g, '$1_$2');
    const snake =
        withBoundaries.replace(/([a-z])([A-Z])/g, '$1_$2').toLowerCase();
    /*
     If the class is private the name starts with "_" which is not secure
     for creating scopes. We prefix the name with "private" in this case.
     */
    return snake[0] === '_' ? 'private' + snake : snake;
}
/**
 * Converts a snake_case identifier to camelCase. Identifiers without an
 * underscore (or of length <= 1) are returned unchanged.
 */
function toCamelCase(identifier) {
    // Quick return for trivially-short strings and names without '_'.
    if (identifier.length <= 1 || identifier.indexOf('_') === -1) {
        return identifier;
    }
    return identifier.replace(/[_]+(\w|$)/g, (m, p1) => p1.toUpperCase());
}
// Registry of user-registered custom classes, keyed by class name.
// tslint:disable-next-line:no-any
let _GLOBAL_CUSTOM_OBJECTS = {};
/**
 * Serialize a Keras object into a `{className, config}` dictionary using the
 * object's own `getClassName()` and `getConfig()`. Returns `null` for a
 * `null`/`undefined` instance.
 */
function serializeKerasObject(instance) {
    if (instance == null) {
        return null;
    }
    return {
        className: instance.getClassName(),
        config: instance.getConfig(),
    };
}
34283 /**
34284 * Replace ndarray-style scalar objects in serialization objects with numbers.
34285 *
34286 * Background: In some versions of tf.keras, certain scalar values in the HDF5
34287 * model save file can be serialized as: `{'type': 'ndarray', 'value': num}`,
34288 * where in `num` is a plain number. This method converts such serialization
34289 * to a `number`.
34290 *
34291 * @param config The keras-format serialization object to be processed
34292 * (in place).
34293 */
34294 function convertNDArrayScalarsInConfig(config) {
34295 if (config == null || typeof config !== 'object') {
34296 return;
34297 }
34298 else if (Array.isArray(config)) {
34299 config.forEach(configItem => convertNDArrayScalarsInConfig(configItem));
34300 }
34301 else {
34302 const fields = Object.keys(config);
34303 for (const field of fields) {
34304 const value = config[field];
34305 if (value != null && typeof value === 'object') {
34306 if (!Array.isArray(value) && value['type'] === 'ndarray' &&
34307 typeof value['value'] === 'number') {
34308 config[field] = value['value'];
34309 }
34310 else {
34311 convertNDArrayScalarsInConfig(value);
34312 }
34313 }
34314 }
34315 }
34316 }
/**
 * Deserialize a saved Keras Object
 * @param identifier either a string ID or a saved Keras dictionary of the
 *   form `{className, config}`.
 * @param moduleObjects a list of Python class names to object constructors
 * @param customObjects a list of Python class names to object constructors
 *   (takes precedence over `moduleObjects` and the global registry).
 * @param printableModuleName debug text for the object being reconstituted
 * @param fastWeightInit Optional flag to use fast weight initialization
 *   during deserialization. This is applicable to cases in which
 *   the initialization will be immediately overwritten by loaded weight
 *   values. Default: `false`.
 * @returns a TensorFlow.js Layers object
 * @throws ValueError if the identifier cannot be resolved or the dictionary
 *   is missing `className`/`config`.
 */
// tslint:disable:no-any
function deserializeKerasObject(identifier, moduleObjects = {}, customObjects = {}, printableModuleName = 'object', fastWeightInit = false) {
    // tslint:enable
    if (typeof identifier === 'string') {
        // String identifier: resolve to a registered function/class directly.
        const functionName = identifier;
        let fn;
        if (functionName in customObjects) {
            fn = customObjects[functionName];
        }
        else if (functionName in _GLOBAL_CUSTOM_OBJECTS) {
            fn = _GLOBAL_CUSTOM_OBJECTS[functionName];
        }
        else {
            fn = moduleObjects[functionName];
            if (fn == null) {
                throw new ValueError(`Unknown ${printableModuleName}: ${identifier}. ` +
                    `This may be due to one of the following reasons:\n` +
                    `1. The ${printableModuleName} is defined in Python, in which ` +
                    `case it needs to be ported to TensorFlow.js or your JavaScript ` +
                    `code.\n` +
                    `2. The custom ${printableModuleName} is defined in JavaScript, ` +
                    `but is not registered properly with ` +
                    `tf.serialization.registerClass().`);
                // TODO(cais): Add link to tutorial page on custom layers.
            }
        }
        return fn;
    }
    else {
        // In this case we are dealing with a Keras config dictionary.
        const config = identifier;
        if (config['className'] == null || config['config'] == null) {
            throw new ValueError(`${printableModuleName}: Improper config format: ` +
                `${JSON.stringify(config)}.\n` +
                `'className' and 'config' must set.`);
        }
        const className = config['className'];
        let cls, fromConfig;
        if (className in customObjects) {
            [cls, fromConfig] = customObjects[className];
        }
        else if (className in _GLOBAL_CUSTOM_OBJECTS) {
            // Bug fix: previously this indexed the registry with the literal
            // string 'className' instead of the `className` variable, so
            // globally-registered custom classes were never resolved here.
            [cls, fromConfig] = _GLOBAL_CUSTOM_OBJECTS[className];
        }
        else if (className in moduleObjects) {
            [cls, fromConfig] = moduleObjects[className];
        }
        if (cls == null) {
            throw new ValueError(`Unknown ${printableModuleName}: ${className}. ` +
                `This may be due to one of the following reasons:\n` +
                `1. The ${printableModuleName} is defined in Python, in which ` +
                `case it needs to be ported to TensorFlow.js or your JavaScript ` +
                `code.\n` +
                `2. The custom ${printableModuleName} is defined in JavaScript, ` +
                `but is not registered properly with ` +
                `tf.serialization.registerClass().`);
            // TODO(cais): Add link to tutorial page on custom layers.
        }
        if (fromConfig != null) {
            // Porting notes: Instead of checking to see whether fromConfig accepts
            // customObjects, we create a customObjects dictionary and tack it on to
            // config['config'] as config['config'].customObjects. Objects can use it,
            // if they want.
            // tslint:disable-next-line:no-any
            const customObjectsCombined = {};
            for (const key of Object.keys(_GLOBAL_CUSTOM_OBJECTS)) {
                customObjectsCombined[key] = _GLOBAL_CUSTOM_OBJECTS[key];
            }
            for (const key of Object.keys(customObjects)) {
                customObjectsCombined[key] = customObjects[key];
            }
            // Add the customObjects to config
            const nestedConfig = config['config'];
            nestedConfig['customObjects'] = customObjectsCombined;
            // Temporarily install `customObjects` into the global registry so
            // that nested deserialization calls can see them; restore after.
            const backupCustomObjects = Object.assign({}, _GLOBAL_CUSTOM_OBJECTS);
            for (const key of Object.keys(customObjects)) {
                _GLOBAL_CUSTOM_OBJECTS[key] = customObjects[key];
            }
            convertNDArrayScalarsInConfig(config['config']);
            const returnObj = fromConfig(cls, config['config'], customObjects, fastWeightInit);
            _GLOBAL_CUSTOM_OBJECTS = Object.assign({}, backupCustomObjects);
            return returnObj;
        }
        else {
            // Then `cls` may be a function returning a class.
            // In this case by convention `config` holds
            // the kwargs of the function.
            const backupCustomObjects = Object.assign({}, _GLOBAL_CUSTOM_OBJECTS);
            for (const key of Object.keys(customObjects)) {
                _GLOBAL_CUSTOM_OBJECTS[key] = customObjects[key];
            }
            // In python this is **config['config'], for tfjs-layers we require
            // classes that use this fall-through construction method to take
            // a config interface that mimics the expansion of named parameters.
            const returnObj = new cls(config['config']);
            _GLOBAL_CUSTOM_OBJECTS = Object.assign({}, backupCustomObjects);
            return returnObj;
        }
    }
}
/**
 * Compares two numbers for sorting (ascending): -1, 1, or 0.
 * @param a
 * @param b
 */
function numberCompare(a, b) {
    if (a < b) {
        return -1;
    }
    return a > b ? 1 : 0;
}
/**
 * Comparison of two numbers for reverse (descending) sorting.
 * @param a
 * @param b
 */
function reverseNumberCompare(a, b) {
    // Negate the forward comparison; -1 * keeps the original's exact
    // numeric results (including -0 for equal inputs).
    const forward = numberCompare(a, b);
    return -1 * forward;
}
/**
 * Convert a string into the corresponding DType.
 * @param dtype
 * @returns An instance of DType.
 * @throws ValueError for any dtype other than 'float32'.
 */
function stringToDType(dtype) {
    if (dtype === 'float32') {
        return 'float32';
    }
    throw new ValueError(`Invalid dtype: ${dtype}`);
}
/**
 * Test the element-by-element equality of two Arrays of strings.
 * @param xs First array of strings.
 * @param ys Second array of strings.
 * @returns Whether the two arrays are all equal, element by element.
 */
function stringsEqual(xs, ys) {
    // Nullish inputs are equal only if they are the same value.
    if (xs == null || ys == null) {
        return xs === ys;
    }
    if (xs.length !== ys.length) {
        return false;
    }
    return xs.every((x, i) => x === ys[i]);
}
/**
 * Get the unique elements of an array, preserving first-seen order.
 * @param xs Array.
 * @returns An Array consisting of the unique elements in `xs`; `xs` itself
 *   if it is null/undefined.
 */
function unique$1(xs) {
    if (xs == null) {
        return xs;
    }
    const result = [];
    // TODO(cais): Maybe improve performance by sorting.
    // indexOf (strict equality) is kept deliberately: unlike a Set, it does
    // not deduplicate NaN entries.
    for (let i = 0; i < xs.length; ++i) {
        if (result.indexOf(xs[i]) === -1) {
            result.push(xs[i]);
        }
    }
    return result;
}
/**
 * Determine if an Object is empty (i.e., does not have own properties).
 * @param obj Object
 * @returns Whether the Object is empty.
 * @throws ValueError: If object is `null` or `undefined`.
 */
function isObjectEmpty(obj) {
    if (obj == null) {
        throw new ValueError(`Invalid value in obj: ${JSON.stringify(obj)}`);
    }
    let empty = true;
    for (const key in obj) {
        // Only own properties count; inherited enumerables are ignored.
        if (obj.hasOwnProperty(key)) {
            empty = false;
            break;
        }
    }
    return empty;
}
/**
 * Helper function used to build type union/enum run-time checkers.
 * @param values The list of allowed values.
 * @param label A string name for the type
 * @param value The value to test.
 * @throws ValueError: If the value is not in values nor `undefined`/`null`.
 */
function checkStringTypeUnionValue(values, label, value) {
    // null/undefined are always accepted.
    if (value != null && values.indexOf(value) < 0) {
        throw new ValueError(`${value} is not a valid ${label}. Valid values are ${values} or null/undefined.`);
    }
}
/**
 * Helper function for verifying the types of inputs.
 *
 * Ensures that the elements of `x` are all of type `expectedType`.
 * Also verifies that the length of `x` is within bounds.
 *
 * @param x Object to test.
 * @param expectedType The string expected type of all of the elements in the
 *   Array.
 * @param minLength Return false if x.length is less than this.
 * @param maxLength Return false if x.length is greater than this.
 * @returns true if and only if `x` is an `Array<expectedType>` with
 *   length >= `minLength` and <= `maxLength`.
 */
// tslint:disable:no-any
function checkArrayTypeAndLength(x, expectedType, minLength = 0, maxLength = Infinity) {
    assert$1(minLength >= 0);
    assert$1(maxLength >= minLength);
    if (!Array.isArray(x)) {
        return false;
    }
    if (x.length < minLength || x.length > maxLength) {
        return false;
    }
    return x.every((e) => typeof e === expectedType);
}
// tslint:enable:no-any
/**
 * Assert that a value or an array of value are positive integer.
 *
 * @param value The value being asserted on. May be a single number or an array
 *   of numbers.
 * @param name Name of the value, used to make the error message.
 */
function assertPositiveInteger(value, name) {
    if (!Array.isArray(value)) {
        assert(Number.isInteger(value) && value > 0, () => `Expected ${name} to be a positive integer, but got ` +
            `${formatAsFriendlyString(value)}.`);
        return;
    }
    // Arrays must be non-empty, and every element is checked recursively.
    assert(value.length > 0, () => `${name} is unexpectedly an empty array.`);
    value.forEach((v, i) => assertPositiveInteger(v, `element ${i + 1} of ${name}`));
}
/**
 * Format a value into a display-friendly, human-readable fashion.
 *
 * - `null` is formatted as `'null'`
 * - Strings are formated with flanking pair of quotes.
 * - Arrays are formatted with flanking pair of square brackets.
 *
 * @param value The value to display.
 * @return Formatted string.
 */
// tslint:disable-next-line:no-any
function formatAsFriendlyString(value) {
    if (value === null) {
        return 'null';
    }
    if (Array.isArray(value)) {
        const parts = value.map((v) => formatAsFriendlyString(v));
        return '[' + parts.join(',') + ']';
    }
    if (typeof value === 'string') {
        return `"${value}"`;
    }
    return `${value}`;
}
/**
 * Returns a function `f2` (decorator) which wraps the original function
 * `f`. `f2` guarantees that `f` can be called at most once
 * every `waitMs` ms. If `f2` is called more often, it will return
 * the last returned result of `f`.
 *
 * @param f The original function `f` to wrap.
 * @param waitMs The time between two consecutive calls to `f` in ms.
 * @param nowFunc Optional clock override (for testing); defaults to `now()`.
 */
function debounce(f, waitMs, nowFunc) {
    const timestamp = () => nowFunc != null ? nowFunc() : now();
    let lastTime = timestamp();
    let lastResult;
    return (...args) => {
        const current = timestamp();
        if (current - lastTime < waitMs) {
            // Throttled: reuse the most recent result without calling `f`.
            return lastResult;
        }
        lastTime = current;
        lastResult = f(...args);
        return lastResult;
    };
}
/**
 * Returns the fusable activation given a layers identifier.
 *
 * @param activationName The layers identifier string.
 * @return The name of the fusable activation, or `null` if the activation
 *   cannot be fused.
 */
function mapActivationToFusedKernel(activationName) {
    switch (activationName) {
        case 'relu':
        case 'linear':
        case 'elu':
            return activationName;
        default:
            return null;
    }
}
/**
 * Returns the cartesian product of sets of values.
 * This works the same as itertools.product in Python.
 *
 * Example:
 *
 * filters = [128, 256, 512]
 * paddings = ['same', 'valid']
 *
 * product = [ [128, 'same'], [256, 'same'], [512, 'same'],
 *             [128, 'valid'], [256, 'valid'], [512, 'valid'] ]
 *
 * @param arrayOfValues List/array of values.
 * @return The cartesian product.
 */
function getCartesianProductOfValues(...arrayOfValues) {
    assert$1(arrayOfValues.length > 0, 'arrayOfValues is empty');
    for (const values of arrayOfValues) {
        assert$1(Array.isArray(values), 'one of the values is not an array');
        assert$1(values.length > 0, 'one of the values is empty');
    }
    let products = [];
    for (const values of arrayOfValues) {
        if (products.length === 0) {
            // Seed with singleton tuples from the first value set.
            products = values.map((value) => [value]);
            continue;
        }
        // Extend each existing tuple by each new value (value-major order,
        // matching the original reduce/map/concat implementation).
        const expanded = [];
        for (const value of values) {
            for (const prev of products) {
                expanded.push([...prev, value]);
            }
        }
        products = expanded;
    }
    return products;
}
34667
34668 /**
34669 * @license
34670 * Copyright 2018 Google LLC
34671 *
34672 * Use of this source code is governed by an MIT-style
34673 * license that can be found in the LICENSE file or at
34674 * https://opensource.org/licenses/MIT.
34675 * =============================================================================
34676 */
/**
 * Utilities related to persistent state in the backend.
 */
/**
 * An ID to track `tf.SymbolicTensor`s and derived classes.
 * Required in different places in engine/topology.ts to identify unique
 * tensors.
 */
let _nextUniqueTensorId = 0;
/** Returns the next tensor ID (0, 1, 2, ...), post-incrementing the counter. */
function getNextUniqueTensorId() {
    const id = _nextUniqueTensorId;
    _nextUniqueTensorId += 1;
    return id;
}
// Per-prefix counters backing getUid().
const _uidPrefixes = {};
/**
 * Provides a unique UID given a string prefix: '<prefix>1', '<prefix>2', ...
 *
 * @param prefix
 */
function getUid(prefix = '') {
    // `in` check (not a null check) preserved deliberately so behavior for
    // prototype-inherited keys stays identical.
    if (!(prefix in _uidPrefixes)) {
        _uidPrefixes[prefix] = 0;
    }
    const next = _uidPrefixes[prefix] + 1;
    _uidPrefixes[prefix] = next;
    return `${prefix}${next}`;
}
34702
34703 /**
34704 * @license
34705 * Copyright 2018 Google LLC
34706 *
34707 * Use of this source code is governed by an MIT-style
34708 * license that can be found in the LICENSE file or at
34709 * https://opensource.org/licenses/MIT.
34710 * =============================================================================
34711 */
/** Allowed values for data-format options. */
const VALID_DATA_FORMAT_VALUES = ['channelsFirst', 'channelsLast'];
/** Allowed values for interpolation options. */
const VALID_INTERPOLATION_FORMAT_VALUES = ['nearest', 'bilinear'];
/** Allowed values for padding-mode options. */
const VALID_PADDING_MODE_VALUES = ['valid', 'same', 'causal'];
/** Allowed values for pooling-mode options. */
const VALID_POOL_MODE_VALUES = ['max', 'avg'];
/** Allowed values for merge-mode options. */
const VALID_BIDIRECTIONAL_MERGE_MODES = ['sum', 'mul', 'concat', 'ave'];
/** Allowed values for sample-weight-mode options. */
const VALID_SAMPLE_WEIGHT_MODES = ['temporal'];
34718
34719 /**
34720 * @license
34721 * Copyright 2018 Google LLC
34722 *
34723 * Use of this source code is governed by an MIT-style
34724 * license that can be found in the LICENSE file or at
34725 * https://opensource.org/licenses/MIT.
34726 * =============================================================================
34727 */
// A map from the requested scoped name of a Tensor to the number of Tensors
// wanting that name so far. This allows enforcing name uniqueness by appending
// an incrementing index, e.g. scope/name, scope/name_1, scope/name_2, etc.
const nameMap = new Map();
/** Validate a data-format string; throws ValueError for invalid values. */
function checkDataFormat(value) {
    checkStringTypeUnionValue(VALID_DATA_FORMAT_VALUES, 'DataFormat', value);
}
/** Validate an interpolation-format string; throws ValueError for invalid values. */
function checkInterpolationFormat(value) {
    checkStringTypeUnionValue(VALID_INTERPOLATION_FORMAT_VALUES, 'InterpolationFormat', value);
}
/** Validate a padding-mode string; throws ValueError for invalid values. */
function checkPaddingMode(value) {
    checkStringTypeUnionValue(VALID_PADDING_MODE_VALUES, 'PaddingMode', value);
}
/** Validate a pool-mode string; throws ValueError for invalid values. */
function checkPoolMode(value) {
    checkStringTypeUnionValue(VALID_POOL_MODE_VALUES, 'PoolMode', value);
}
// Stack of currently-open name scopes, innermost last.
const _nameScopeStack = [];
const _nameScopeDivider = '/';
/**
 * Enter namescope, which can be nested. The scope is popped even when `fn`
 * throws, and the exception is propagated unchanged.
 */
function nameScope(name, fn) {
    _nameScopeStack.push(name);
    try {
        return fn();
    }
    finally {
        // Always unwind the scope, on both normal return and throw.
        _nameScopeStack.pop();
    }
}
/**
 * Get the current namescope as a flat, concatenated string
 * (e.g. 'outer/inner/'), or '' when no scope is open.
 */
function currentNameScopePrefix() {
    return _nameScopeStack.length === 0 ?
        '' :
        _nameScopeStack.join(_nameScopeDivider) + _nameScopeDivider;
}
/**
 * Get the name a Tensor (or Variable) would have if not uniqueified.
 * @param tensorName
 * @return Scoped name string.
 * @throws Error if `tensorName` is not a valid tensor name.
 */
function getScopedTensorName(tensorName) {
    if (isValidTensorName(tensorName)) {
        return currentNameScopePrefix() + tensorName;
    }
    throw new Error('Not a valid tensor name: \'' + tensorName + '\'');
}
/**
 * Get unique names for Tensors and Variables.
 * @param scopedName The fully-qualified name of the Tensor, i.e. as produced by
 *  `getScopedTensorName()`.
 * @return A unique version of the given fully scoped name.
 *   If this is the first time that the scoped name is seen in this session,
 *   then the given `scopedName` is returned unaltered. If the same name is
 *   seen again (producing a collision), an incrementing suffix is added to the
 *   end of the name, so it takes the form 'scope/name_1', 'scope/name_2', etc.
 */
function getUniqueTensorName(scopedName) {
    if (!isValidTensorName(scopedName)) {
        throw new Error('Not a valid tensor name: \'' + scopedName + '\'');
    }
    if (!nameMap.has(scopedName)) {
        nameMap.set(scopedName, 0);
    }
    const priorUses = nameMap.get(scopedName);
    nameMap.set(scopedName, priorUses + 1);
    if (priorUses === 0) {
        return scopedName;
    }
    const uniquified = `${scopedName}_${priorUses}`;
    // Mark the composed name as used in case someone wants
    // to call getUniqueTensorName("name_1").
    nameMap.set(uniquified, 1);
    return uniquified;
}
34813 const tensorNameRegex = new RegExp(/^[A-Za-z0-9][-A-Za-z0-9\._\/]*$/);
34814 /**
34815 * Determine whether a string is a valid tensor name.
34816 * @param name
34817 * @returns A Boolean indicating whether `name` is a valid tensor name.
34818 */
34819 function isValidTensorName(name) {
34820 return !!name.match(tensorNameRegex);
34821 }
34822
34823 /**
34824 * @license
34825 * Copyright 2018 Google LLC
34826 *
34827 * Use of this source code is governed by an MIT-style
34828 * license that can be found in the LICENSE file or at
34829 * https://opensource.org/licenses/MIT.
34830 * =============================================================================
34831 */
34832 /**
34833 * Determine if a number is an integer.
34834 */
34835 function isInteger(x) {
34836 return x === parseInt(x.toString(), 10);
34837 }
34838 /**
34839 * Calculate the product of an array of numbers.
34840 * @param array The array to calculate the product over.
34841 * @param begin Beginning index, inclusive.
34842 * @param end Ending index, exclusive.
34843 * @return The product.
34844 */
34845 function arrayProd(array, begin, end) {
34846 if (begin == null) {
34847 begin = 0;
34848 }
34849 if (end == null) {
34850 end = array.length;
34851 }
34852 let prod = 1;
34853 for (let i = begin; i < end; ++i) {
34854 prod *= array[i];
34855 }
34856 return prod;
34857 }
34858 /**
34859 * Compute minimum value.
34860 * @param array
34861 * @return minimum value.
34862 */
34863 function min$1(array) {
34864 // same behavior as tf.min()
34865 if (array.length === 0) {
34866 return Number.NaN;
34867 }
34868 let min = Number.POSITIVE_INFINITY;
34869 for (let i = 0; i < array.length; i++) {
34870 const value = array[i];
34871 if (value < min) {
34872 min = value;
34873 }
34874 }
34875 return min;
34876 }
34877 /**
34878 * Compute maximum value.
34879 * @param array
34880 * @return maximum value
34881 */
34882 function max$1(array) {
34883 // same behavior as tf.max()
34884 if (array.length === 0) {
34885 return Number.NaN;
34886 }
34887 let max = Number.NEGATIVE_INFINITY;
34888 for (let i = 0; i < array.length; i++) {
34889 const value = array[i];
34890 if (value > max) {
34891 max = value;
34892 }
34893 }
34894 return max;
34895 }
34896 /**
34897 * Compute sum of array.
34898 * @param array
34899 * @return The sum.
34900 */
34901 function sum$2(array) {
34902 let sum = 0;
34903 for (let i = 0; i < array.length; i++) {
34904 const value = array[i];
34905 sum += value;
34906 }
34907 return sum;
34908 }
34909 /**
34910 * Compute mean of array.
34911 * @param array
34912 * @return The mean.
34913 */
34914 function mean$2(array) {
34915 return sum$2(array) / array.length;
34916 }
34917 /**
34918 * Compute variance of array.
34919 * @param array
34920 * @return The variance.
34921 */
34922 function variance(array) {
34923 const meanValue = mean$2(array);
34924 const demeaned = array.map((value) => value - meanValue);
34925 let sumSquare = 0;
34926 for (let i = 0; i < demeaned.length; i++) {
34927 const value = demeaned[i];
34928 sumSquare += value * value;
34929 }
34930 return sumSquare / array.length;
34931 }
34932 /**
34933 * Compute median of array.
34934 * @param array
34935 * @return The median value.
34936 */
34937 function median(array) {
34938 const arraySorted = array.slice().sort((a, b) => a - b);
34939 const lowIdx = Math.floor((arraySorted.length - 1) / 2);
34940 const highIdx = Math.ceil((arraySorted.length - 1) / 2);
34941 if (lowIdx === highIdx) {
34942 return arraySorted[lowIdx];
34943 }
34944 return (arraySorted[lowIdx] + arraySorted[highIdx]) / 2;
34945 }
34946 /**
34947 * Generate an array of integers in [begin, end).
34948 * @param begin Beginning integer, inclusive.
34949 * @param end Ending integer, exclusive.
34950 * @returns Range array.
34951 * @throws ValueError, iff `end` < `begin`.
34952 */
34953 function range$1(begin, end) {
34954 if (end < begin) {
34955 throw new ValueError(`end (${end}) < begin (${begin}) is forbidden.`);
34956 }
34957 const out = [];
34958 for (let i = begin; i < end; ++i) {
34959 out.push(i);
34960 }
34961 return out;
34962 }
34963
34964 /**
34965 * @license
34966 * Copyright 2018 Google LLC
34967 *
34968 * Use of this source code is governed by an MIT-style
34969 * license that can be found in the LICENSE file or at
34970 * https://opensource.org/licenses/MIT.
34971 * =============================================================================
34972 */
34973 let _epsilon;
34974 /**
34975 * Returns the value of the fuzz factor used in numeric expressions.
34976 */
34977 function epsilon() {
34978 if (_epsilon == null) {
34979 _epsilon = backend().epsilon();
34980 }
34981 return _epsilon;
34982 }
34983 /**
34984 * Sets the value of the fuzz factor used in numeric expressions.
34985 * @param e New value of epsilon.
34986 */
34987 function setEpsilon(e) {
34988 _epsilon = e;
34989 }
34990 /**
34991 * Returns the default image data format convention.
34992 */
34993 function imageDataFormat() {
34994 return 'channelsLast';
34995 }
34996
34997 /**
34998 * @license
34999 * Copyright 2018 Google LLC
35000 *
35001 * Use of this source code is governed by an MIT-style
35002 * license that can be found in the LICENSE file or at
35003 * https://opensource.org/licenses/MIT.
35004 * =============================================================================
35005 */
35006 // tslint:enable
35007 /* Setting and getting backend from deeplearn.js. */
35008 // Default deeplearn.js backend is WebGL (GPU).
35009 let backend$1 = 'webgl';
35010 function setBackend$1(requestedBackend) {
35011 setBackend(requestedBackend);
35012 backend$1 = requestedBackend;
35013 }
35014 function getBackend$1() {
35015 return backend$1;
35016 }
35017 /**
35018 * Indicates whether the backend is operating symbolically.
35019 *
35020 * This function will be used to determine how to interpret user code. If
35021 * it returns true, calls to the backend construct a symbolic graph; if
35022 * it returns false, calls to the backend execute immediately.
35023 */
35024 function isBackendSymbolic() {
35025 return false;
35026 }
35027 /**
35028 * Get the number of elements in a Tensor.
35029 * @param x The Tensor.
35030 * @return Number of elements in `x`.
35031 */
35032 function countParams(x) {
35033 const shape = x.shape;
35034 if (shape.length > 0) {
35035 return shape.reduce((a, b) => a * b);
35036 }
35037 else {
35038 // Scalar.
35039 return 1;
35040 }
35041 }
35042 /**
35043 * Casts a tensor to a different dtype and returns it.
35044 * @param x Input tensor.
35045 * @param dtype String: 'float32'|'int32'|'bool'.
35046 * @returns Tensor of the specified `dtype`.
35047 */
35048 function cast$1(x, dtype) {
35049 return cast(x, dtype);
35050 }
35051 /**
35052 * Adds a 1-sized dimension at index "axis".
35053 * @param x Input tensor.
35054 * @param axis Position where to add the new axis.
35055 * @returns Result of the dimension expansion.
35056 */
35057 function expandDims$1(x, axis = -1) {
35058 const outShape = x.shape.slice();
35059 if (axis < 0) {
35060 axis = outShape.length + axis + 1;
35061 }
35062 outShape.splice(axis, 0, 1);
35063 return reshape(x, outShape);
35064 }
35065 /**
35066 * Repeats a 2D tensor.
35067 *
35068 * If `x` has shape `[samples, dim]` and `n` is 2, for example, the output
35069 * will have shape `[samples, 2, dim]`.
35070 *
35071 * @param x Input tensor.
35072 * @param n Integer, number of times to repeat.
35073 * @returns The result of the repeat operation.
35074 * @throws ValueError: If input tensor is not 2D.
35075 */
35076 function repeat(x, n) {
35077 return tidy(() => {
35078 if (x.shape.length !== 2) {
35079 throw new ValueError(`repeat() expects a rank-2 tensor, but received a ` +
35080 `rank-${x.shape.length} tensor.`);
35081 }
35082 const y = expandDims$1(x, 1);
35083 return tile$1(y, [1, n, 1]);
35084 });
35085 }
35086 /**
35087 * Flatten a Tensor into 1D.
35088 * @param x Input tensor.
35089 * @return The result of the flattening `x`.
35090 */
35091 function flatten$1(x) {
35092 const newShape = [arrayProd(x.shape)];
35093 return reshape(x, newShape);
35094 }
35095 /**
35096 * Turn a nD tensor into a 2D tensor with same 0th dimension.
35097 * In other words, it flattens each data samples of a batch.
35098 *
35099 * @param x The tensor to flatten. The rank of this tensor is required to be 2
35100 * or higher.
35101 * @return The result of the flattening.
35102 */
35103 function batchFlatten(x) {
35104 if (x.rank <= 1) {
35105 throw new ValueError(`batchFlatten requires a minimum rank of 2. Got rank: ${x.rank}.`);
35106 }
35107 const newShape = [x.shape[0], arrayProd(x.shape, 1)];
35108 return reshape(x, newShape);
35109 }
35110 /**
35111 * Do slicing along the first axis.
35112 * @param array input `tf.Tensor`.
35113 * @param start starting index, inclusive.
35114 * @param size size of the slice along the first axis.
35115 * @returns result of the slicing.
35116 * @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`.
35117 */
35118 function sliceAlongFirstAxis(array, start, size) {
35119 return tidy(() => {
35120 switch (array.rank) {
35121 case 1:
35122 return slice1d(array, start, size);
35123 case 2:
35124 return slice2d(array, [start, 0], [size, array.shape[1]]);
35125 case 3:
35126 return slice3d(array, [start, 0, 0], [size, array.shape[1], array.shape[2]]);
35127 case 4:
35128 return slice4d(array, [start, 0, 0, 0], [size, array.shape[1], array.shape[2], array.shape[3]]);
35129 case 5:
35130 return slice(array, [start, 0, 0, 0, 0], [
35131 size, array.shape[1], array.shape[2], array.shape[3], array.shape[4]
35132 ]);
35133 case 6:
35134 return slice(array, [start, 0, 0, 0, 0, 0], [
35135 size, array.shape[1], array.shape[2], array.shape[3], array.shape[4],
35136 array.shape[5]
35137 ]);
35138 default:
35139 throw new ValueError(`sliceAlongFirstAxis() received an unsupported tensor rank: ` +
35140 `${array.rank}`);
35141 }
35142 });
35143 }
35144 /**
35145 * Do slicing along the last axis.
35146 * @param array input `tf.Tensor`.
35147 * @param start starting index, inclusive.
35148 * @param size size of the slice along the last axis.
35149 * @returns result of the slicing.
35150 * @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`.
35151 */
35152 function sliceAlongLastAxis(array, start, size) {
35153 return tidy(() => {
35154 switch (array.rank) {
35155 case 1:
35156 return slice1d(array, start, size);
35157 case 2:
35158 return slice2d(array, [0, start], [array.shape[0], size]);
35159 case 3:
35160 return slice3d(array, [0, 0, start], [array.shape[0], array.shape[1], size]);
35161 case 4:
35162 return slice4d(array, [0, 0, 0, start], [array.shape[0], array.shape[1], array.shape[2], size]);
35163 default:
35164 throw new ValueError(`sliceAlongLastAxis() received an unsupported tensor rank: ` +
35165 `${array.rank}`);
35166 }
35167 });
35168 }
35169 /**
35170 * Do slicing along the sepcified axis.
35171 * @param array input `tf.Tensor`.
35172 * @param start starting index, inclusive.
35173 * @param size of the slice along the chosen axis.
35174 * @param choose an axis.
35175 * @returns result of the slicing.
35176 * @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`.
35177 */
35178 function sliceAlongAxis(array, start, size, axis) {
35179 return tidy(() => {
35180 switch (array.rank) {
35181 case 1:
35182 return slice1d(array, start, size);
35183 case 2:
35184 switch (axis) {
35185 case 1:
35186 return sliceAlongFirstAxis(array, start, size);
35187 case 2:
35188 return sliceAlongLastAxis(array, start, size);
35189 default:
35190 throw new ValueError(`The axis is not within the rank of the tensor ` +
35191 `${axis}`);
35192 }
35193 case 3:
35194 switch (axis) {
35195 case 1:
35196 return sliceAlongFirstAxis(array, start, size);
35197 case 2:
35198 return slice3d(array, [0, start, 0], [array.shape[0], size, array.shape[2]]);
35199 case 3:
35200 return sliceAlongLastAxis(array, start, size);
35201 default:
35202 throw new ValueError(`The axis is not within the rank of the tensor ` +
35203 `${axis}`);
35204 }
35205 case 4:
35206 switch (axis) {
35207 case 1:
35208 return sliceAlongFirstAxis(array, start, size);
35209 case 2:
35210 return slice4d(array, [0, start, 0, 0], [array.shape[0], size, array.shape[2], array.shape[3]]);
35211 case 3:
35212 return slice4d(array, [0, 0, start, 0], [array.shape[0], array.shape[1], size, array.shape[3]]);
35213 case 4:
35214 return sliceAlongLastAxis(array, start, size);
35215 default:
35216 throw new ValueError(`The axis is not within the rank of the tensor ` +
35217 `${axis}`);
35218 }
35219 default:
35220 throw new ValueError(`sliceAlongLastAxis() received an unsupported tensor rank: ` +
35221 `${array.rank}`);
35222 }
35223 });
35224 }
35225 /**
35226 * Concatenates a list of tensors alongside the specified axis.
35227 * @param tensors `Array` of tensors to concatenate.
35228 * @param axis Concatenation axis.
35229 * @returns The result of the concatenation.
35230 */
35231 function concatenate(tensors, axis = -1) {
35232 let rank;
35233 if (axis < 0) {
35234 rank = tensors[0].rank;
35235 if (rank !== 0) {
35236 axis = rank;
35237 }
35238 else {
35239 axis = 0;
35240 }
35241 }
35242 if (axis === tensors[0].rank) {
35243 // Porting Note: This is necessary because tfc.concat() requires axis to be
35244 // in the interval [-rank, rank).
35245 axis = -1;
35246 }
35247 // Porting Note: Sparse concat is not supported yet.
35248 return concat(tensors, axis);
35249 }
35250 /**
35251 * Concatenate two arrays along the first dimension.
35252 * @param a The 1st `tf.Tensor` to concatenate.
35253 * @param b The 2nd `tf.Tensor` to concatenate.
35254 * @returns Result of the concatenation.
35255 * @throws ValueError: If `a` is of an unsupported subtype of `tf.Tensor`.
35256 */
35257 function concatAlongFirstAxis(a, b) {
35258 switch (a.rank) {
35259 case 1:
35260 return concat1d([a, b]);
35261 case 2:
35262 return concat2d([a, b], 0);
35263 case 3:
35264 return concat3d([a, b], 0);
35265 case 4:
35266 return concat4d([a, b], 0);
35267 default:
35268 throw new ValueError(`concatAlongFirstAxis() received an unsupported ` +
35269 `tensor rank: ${a.rank}`);
35270 }
35271 }
35272 /**
35273 * Creates a tensor by tiling `x` by `n`.
35274 * @param x A tensor.
35275 * @param n An Array of integers or a single integer. If an Array, the length
35276 * must be the same as the number of dimensions in `x`. If a single integer,
35277 * it will be treated as an Array of length 1.
35278 */
35279 function tile$1(x, n) {
35280 if (!Array.isArray(n)) {
35281 n = [n];
35282 }
35283 if (x.rank !== n.length) {
35284 throw new ValueError(`The length of input n (${n.length}) does not match ` +
35285 `the number of dimensions in input x (${x.rank})`);
35286 }
35287 return tile(x, n);
35288 }
35289 /* Creation of random tensors. */
35290 /**
35291 * Get a tensor with normal distribution of values.
35292 *
35293 * @param shape Shape of the tensor.
35294 * @param mean mean value of the normal distribution.
35295 * @param stddev standard deviation of the normal distribution.
35296 * @param dtype
35297 * @param seed
35298 * @return The normal tensor.
35299 */
35300 function randomNormal$1(shape, mean = 0.0, stddev = 1.0, dtype, seed) {
35301 return randomNormal(shape, mean, stddev, dtype, seed);
35302 }
35303 /* Linear Algebra */
    /**
     * Multiply two tensors and returns the result as a tensor.
     *
     * For 2D tensors, this is equivalent to matrix multiplication (matMul).
     * For tensors of higher ranks, it follows the Theano behavior,
     * (e.g. `(2, 3) * (4, 3, 5) -> (2, 4, 5)`). From the Theano documentation:
     *
     * For N dimensions it is a sum product over the last axis of x and the
     * second-to-last of y:
     *
     * @param a A tensor of at least rank 2.
     * @param b A tensor of at least rank 2.
     * @param activation (optional) A string identifying the activation
     *   function, fused into the matMul where supported.
     * @param bias (optional) bias tensor, reshaped for broadcasting via
     *   `reshapeBias` before being fused into the matMul.
     * @return Result of the dot operation.
     * @throws NotImplementedError if either input has rank < 2, or if the
     *   contraction dimensions of `a` and `b` do not match.
     */
    function dot$1(a, b, activation, bias) {
        if ((a.rank < 2) || (b.rank < 2)) {
            throw new NotImplementedError(`dot requires both inputs to be rank >= 2` +
                ` but got x shape = ${a.shape} and y shape = ${b.shape}`);
        }
        if (b.rank >= 3) {
            // The contraction pairs a's last axis with b's second-to-last axis;
            // they must agree for the sum-product to be defined.
            const xLastDim = a.shape.slice(-1)[0];
            const ySecondLastDim = b.shape.slice(-2)[0];
            if (xLastDim !== ySecondLastDim) {
                throw new NotImplementedError(`If rank y >= 3, then the second last dim` +
                    ` of y must equal the last dim of x but got x shape = ${a.shape} and ` +
                    ` y shape = ${b.shape}`);
            }
        }
        // Handle basic 2D x 2D case.
        if ((a.rank === 2) && (b.rank === 2)) {
            const transposeA = false;
            const transposeB = false;
            // tfc.fused.matMul only fuses certain activation functions. Unsupported
            // activation functions are treated as 'linear' activations, which is
            // equivalent to a no-op.
            return matMul$1({
                a,
                b: b,
                transposeA,
                transposeB,
                bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null,
                activation
            });
        }
        else {
            // Reshape x into the analogous 2D Tensor.
            const aFirstDims = a.shape.slice(); // Holds all but the last dim of x.
            const aLastDim = aFirstDims.pop();
            a = reshape(a, [-1, aLastDim]);
            // Reshape y into the analogous 2D Tensor, and keep track of the
            // required dimensions to reproduce the output shape.
            const bShape = b.shape.slice();
            const bLastDim = bShape.pop();
            const ySecondLastDim = bShape.pop();
            // Output axes contributed by y: all of y's dims except the
            // contracted (second-to-last) one.
            const yOtherDims = [...bShape, bLastDim];
            // permutation should be like [r-2, 0, 1, 2, ... r-4, r-3, r-1]
            // where r is the rank of y.
            const perm = Array.from({ length: b.rank }, (_, i) => {
                if (i === 0) {
                    return b.rank - 2;
                }
                else if (i <= b.rank - 2) {
                    return i - 1;
                }
                return i;
            });
            // Bring the contracted axis to the front, then collapse the rest
            // into a single column dimension.
            b = reshape(transpose(b, perm), [ySecondLastDim, -1]);
            // Multiply x and y as 2D Tensors, and then reshape back to original.
            const outputShape = [...aFirstDims, ...yOtherDims];
            const transposeA = false;
            const transposeB = false;
            // NOTE(review): `a.rank` here is the rank of the reshaped (2D) `a`,
            // so the bias is reshaped for a rank-2 input — confirm intended.
            return reshape(matMul$1({
                a,
                b,
                transposeA,
                transposeB,
                bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null,
                activation
            }), outputShape);
        }
    }
35387 /**
35388 * Compute the sign Tensor of an input Tensor.
35389 *
35390 * Elements of the input `tf.Tensor` that are === 0 are mapped to 0.
35391 * Elements of the input `tf.Tensor` that are > 0 are mapped to 1.
35392 * Elements of the input `tf.Tensor` that are < 0 are mapped to -1.
35393 *
35394 * @param x Input `tf.Tensor`.
35395 * @return The sign `tf.Tensor`.
35396 */
35397 function sign$1(x) {
35398 // TODO(cais): Move to the core.
35399 return tidy(() => {
35400 const zerosLikeX = zerosLike(x);
35401 const onesLikeX = onesLike(x);
35402 return where(equal(x, zerosLikeX), zerosLikeX, where(greater(x, zerosLike(x)), onesLikeX, mul(-1, onesLikeX)));
35403 });
35404 }
35405 /**
35406 * Computes the one-hot representation of an integer tensor.
35407 * @param indices nD integer tensor of shape
35408 * `(batch_size, dim1, dim2, ... dim(n-1))`
35409 * @param numClasses Integer, number of classes to consider.
35410 * @returns (n + 1)D one hot representation of the input
35411 * with shape `(batch_size, dim1, dim2, ... dim(n-1), num_classes)`
35412 */
35413 function oneHot$1(indices, numClasses) {
35414 return tidy(() => {
35415 if (indices.rank !== 1) {
35416 throw new Error('Only 1D one-hot tensors are supported in the ' +
35417 'deeplearn backend, at present.');
35418 }
35419 indices = cast(indices, 'int32');
35420 return cast(oneHot(indices, numClasses), 'float32');
35421 });
35422 }
35423 /* Elementary math functions. */
35424 /**
35425 * Retrieves the elements of indices `indices` in the tensor `reference`.
35426 * @param reference A tensor.
35427 * @param indices An integer tensor of indices or an `Array` of integers.
35428 * @param axis Axis along which to perform the gather operation.
35429 * @returns The result of the gathering as a tensor.
35430 */
35431 function gather$1(reference, indices, axis) {
35432 return tidy(() => {
35433 if (Array.isArray(indices)) {
35434 indices = tensor1d(indices, 'int32');
35435 }
35436 else {
35437 indices = cast(indices, 'int32');
35438 }
35439 return gather(reference, indices, axis);
35440 });
35441 }
35442 /**
35443 * Element-wise square.
35444 * @param x Input tensor.
35445 * @return element-wise x^2
35446 */
35447 function square$1(x) {
35448 return mul(x, x);
35449 }
35450 /**
35451 * Element-wise exponentiation.
35452 *
35453 * Porting Note: In PyKeras, `a` (the exponent) is a Python integer, which
35454 * takes advatnage of the backend's (e.g., TensorFlow's) automatic
35455 * conversion to tensor. Here we allow `a` to be either a number or a tensor.
35456 *
35457 * @param x The base tensor.
35458 * @param a The exponent, tensor or number. If a number, it is rounded to the
35459 * nearest integer and converted to a tensor.
35460 * @returns A tensor of the same shape as `x`.
35461 */
35462 function pow$1(x, a) {
35463 return tidy(() => {
35464 if (typeof (a) === 'number') {
35465 a = scalar(Math.round(a), 'int32');
35466 }
35467 if (a.dtype !== 'int32') {
35468 throw new NotImplementedError(`Non-int32 dtype (${a.dtype}) is not supported by pow() yet`);
35469 }
35470 return pow(x, a);
35471 });
35472 }
    /**
     * Reshapes bias tensor according to rank of x.
     *
     * A rank-1 bias of length C is reshaped so it broadcasts along the channel
     * axis of `x` (axis position depends on `dataFormat`); a bias of the same
     * rank as `x` gets a leading batch dimension of 1 (channels-last) or is
     * permuted into channels-first order. For xRank < 3 the bias is returned
     * unchanged.
     *
     * @param xRank Rank of the tensor the bias will be added to.
     * @param bias Bias tensor; must be rank 1 or rank `xRank`.
     * @param dataFormat 'channelsFirst' or 'channelsLast'.
     * @throws ValueError if the bias rank is unexpected, or if `xRank` /
     *   `dataFormat` fall through every branch below.
     */
    function reshapeBias(xRank, bias, dataFormat) {
        const biasShape = bias.shape;
        if (bias.rank !== 1 && bias.rank !== xRank) {
            throw new ValueError(`Unexpected bias dimensions: ${bias.rank}` +
                `; expected it to be 1 or ${xRank}`);
        }
        if (xRank === 5) {
            if (dataFormat === 'channelsFirst') {
                if (biasShape.length === 1) {
                    // [C] -> [1, C, 1, 1, 1]
                    return reshape(bias, [1, biasShape[0], 1, 1, 1]);
                }
                else {
                    // Rank-5 bias assumed channels-last; move channels forward.
                    return reshape(bias, [1, biasShape[3], biasShape[0], biasShape[1], biasShape[2]]);
                }
            }
            else if (dataFormat === 'channelsLast') {
                if (biasShape.length === 1) {
                    // [C] -> [1, 1, 1, 1, C]
                    return reshape(bias, [1, 1, 1, 1, biasShape[0]]);
                }
                else {
                    // Prepend a broadcastable batch dimension.
                    return reshape(bias, [1].concat(biasShape));
                }
            }
        }
        else if (xRank === 4) {
            if (dataFormat === 'channelsFirst') {
                if (biasShape.length === 1) {
                    // [C] -> [1, C, 1, 1]
                    return reshape(bias, [1, biasShape[0], 1, 1]);
                }
                else {
                    // Rank-4 bias assumed channels-last; move channels forward.
                    return reshape(bias, [1, biasShape[2], biasShape[0], biasShape[1]]);
                }
            }
            else if (dataFormat === 'channelsLast') {
                if (biasShape.length === 1) {
                    // [C] -> [1, 1, 1, C]
                    return reshape(bias, [1, 1, 1, biasShape[0]]);
                }
                else {
                    // Prepend a broadcastable batch dimension.
                    return reshape(bias, [1].concat(biasShape));
                }
            }
        }
        else if (xRank === 3) {
            if (dataFormat === 'channelsFirst') {
                if (biasShape.length === 1) {
                    // [C] -> [1, C, 1]
                    return reshape(bias, [1, biasShape[0], 1]);
                }
                else {
                    // Rank-3 bias assumed channels-last; move channels forward.
                    return reshape(bias, [1, biasShape[1], biasShape[0]]);
                }
            }
            else if (dataFormat === 'channelsLast') {
                if (biasShape.length === 1) {
                    // [C] -> [1, 1, C]
                    return reshape(bias, [1, 1, biasShape[0]]);
                }
                else {
                    // Prepend a broadcastable batch dimension.
                    return reshape(bias, [1].concat(biasShape));
                }
            }
        }
        else if (xRank < 3) {
            // Ranks 1 and 2 broadcast without any reshaping.
            return bias;
        }
        // NOTE(review): reached when xRank > 5 or dataFormat matched no branch;
        // the message reports `bias.rank` although the failure is about
        // xRank/dataFormat — confirm intent before changing.
        throw new ValueError(`Unsupported input rank by biasAdd: ${bias.rank}`);
    }
35541 /* Neural-network operations. */
35542 /**
35543 * Add a bias to a tensor.
35544 *
35545 * @param x The tensor to add the bias to.
35546 * @param bias The bias to add to `x`. Must be 1D or the same rank as `x`.
35547 * @return Result of the bias adding.
35548 * @throws ValueError: If the rank of `bias` is incorrect.
35549 */
35550 function biasAdd(x, bias, dataFormat) {
35551 return tidy(() => {
35552 if (dataFormat == null) {
35553 dataFormat = imageDataFormat();
35554 }
35555 checkDataFormat(dataFormat);
35556 return add$1(x, reshapeBias(x.rank, bias, dataFormat));
35557 });
35558 }
35559 /**
35560 * Exponential linear unit (ELU).
35561 * @param x A tensor or variable to compute the activation function for.
35562 * @param alpha: A scalar, a scaling factor for the negative section.
35563 * @return Output of the ELU operation.
35564 */
35565 function elu$1(x, alpha = 1) {
35566 // TODO(cais): Add support for alpha values other than 1.
35567 if (alpha !== 1) {
35568 throw new NotImplementedError(`Support for alpha values other than 1 (${alpha}) is not implemented ` +
35569 `yet.`);
35570 }
35571 return elu(x);
35572 }
35573 /**
35574 * Softsign of a tensor.
35575 *
35576 * Defined as x / (abs(x) + 1), element-wise.
35577 *
35578 * @param x: Input.
35579 * @returns Output.
35580 */
35581 function softsign(x) {
35582 return tidy(() => div(x, add$1(abs(x), 1)));
35583 }
35584 /**
35585 * Sets entries in `x` to zero at random, while scaling the entire tensor.
35586 *
35587 * @param x input tensor.
35588 * @param level fraction of the entries in the tensor that will be set to 0.
35589 * @param noiseShape shape of randomly generated keep/drop flags, must be
35590 * broadcastable to the shape of `x`. Optional.
35591 * @param seed random seed to ensure determinism. Optional.
35592 * @returns Result of the dropout operation.
35593 */
35594 function dropout$1(x, level, noiseShape, seed) {
35595 return tidy(() => dropout(x, level, noiseShape, seed));
35596 }
35597 /**
35598 * Element-wise, segment-wise linear approximation of sigmoid.
35599 *
35600 * Returns `0.` if `x < -2.5`, `1.` if `x > 2.5`.
35601 * In `-2.5 <= x <= 2.5`, returns `0.2 * x + 0.5`.
35602 *
35603 * @param x Input tensor.
35604 * @returns Output tensor.
35605 */
35606 function hardSigmoid(x) {
35607 return tidy(() => {
35608 const y = add$1(.5, mul(.2, x));
35609 return clipByValue(y, 0, 1);
35610 });
35611 }
35612 /**
35613 * Invoke `x` in the training phase, and `alt` otherwise.
35614 *
35615 * Porting Note: We do not create placeholder tensors for the `training`
35616 * boolean flag here, because there is no such thing in the TF.js imperative
35617 * backend.
35618 *
35619 * @param x The function to invoke iff `training` is `true`.
35620 * @param alt The function to invoke iff `training` is `false`.
35621 * @param training Boolean flag for whether training phase is active.
35622 * @returns The return value of `x()` if `training` is `true`, or the return
35623 * value of `alt()` if `training` is `false`.
35624 */
35625 function inTrainPhase(x, alt, training = false) {
35626 return training ? x() : alt();
35627 }
35628
35629 /**
35630 * @license
35631 * Copyright 2018 Google LLC
35632 *
35633 * Use of this source code is governed by an MIT-style
35634 * license that can be found in the LICENSE file or at
35635 * https://opensource.org/licenses/MIT.
35636 * =============================================================================
35637 */
35638 const VALID_FAN_MODE_VALUES = ['fanIn', 'fanOut', 'fanAvg'];
35639 const VALID_DISTRIBUTION_VALUES = ['normal', 'uniform', 'truncatedNormal'];
35640 // We can't easily extract a string[] from the string union type, but we can
35641 // recapitulate the list, enforcing at compile time that the values are valid
35642 // and that we have the right number of them.
35643 /**
35644 * A string array of valid Initializer class names.
35645 *
35646 * This is guaranteed to match the `InitializerClassName` union type.
35647 */
35648 const initializerClassNames = [
35649 'Zeros', 'Ones', 'Constant', 'RandomNormal', 'RandomUniform',
35650 'TruncatedNormal', 'VarianceScaling', 'Orthogonal', 'Identity'
35651 ];
35652
35653 /**
35654 * @license
35655 * Copyright 2018 Google LLC
35656 *
35657 * Use of this source code is governed by an MIT-style
35658 * license that can be found in the LICENSE file or at
35659 * https://opensource.org/licenses/MIT.
35660 * =============================================================================
35661 */
35662 function checkFanMode(value) {
35663 checkStringTypeUnionValue(VALID_FAN_MODE_VALUES, 'FanMode', value);
35664 }
35665 function checkDistribution(value) {
35666 checkStringTypeUnionValue(VALID_DISTRIBUTION_VALUES, 'Distribution', value);
35667 }
35668 /**
35669 * Initializer base class.
35670 *
35671 * @doc {
35672 * heading: 'Initializers', subheading: 'Classes', namespace: 'initializers'}
35673 */
35674 class Initializer extends Serializable {
35675 fromConfigUsesCustomObjects() {
35676 return false;
35677 }
35678 getConfig() {
35679 return {};
35680 }
35681 }
35682 class Zeros extends Initializer {
35683 apply(shape, dtype) {
35684 return zeros(shape, dtype);
35685 }
35686 }
35687 /** @nocollapse */
35688 Zeros.className = 'Zeros';
35689 registerClass(Zeros);
35690 class Ones extends Initializer {
35691 apply(shape, dtype) {
35692 return ones$1(shape, dtype);
35693 }
35694 }
35695 /** @nocollapse */
35696 Ones.className = 'Ones';
35697 registerClass(Ones);
35698 class Constant extends Initializer {
35699 constructor(args) {
35700 super();
35701 if (typeof args !== 'object') {
35702 throw new ValueError(`Expected argument of type ConstantConfig but got ${args}`);
35703 }
35704 if (args.value === undefined) {
35705 throw new ValueError(`config must have value set but got ${args}`);
35706 }
35707 this.value = args.value;
35708 }
35709 apply(shape, dtype) {
35710 return tidy(() => mul(scalar(this.value), ones$1(shape, dtype)));
35711 }
35712 getConfig() {
35713 return {
35714 value: this.value,
35715 };
35716 }
35717 }
35718 /** @nocollapse */
35719 Constant.className = 'Constant';
35720 registerClass(Constant);
35721 class RandomUniform extends Initializer {
35722 constructor(args) {
35723 super();
35724 this.DEFAULT_MINVAL = -0.05;
35725 this.DEFAULT_MAXVAL = 0.05;
35726 this.minval = args.minval || this.DEFAULT_MINVAL;
35727 this.maxval = args.maxval || this.DEFAULT_MAXVAL;
35728 this.seed = args.seed;
35729 }
35730 apply(shape, dtype) {
35731 return randomUniform(shape, this.minval, this.maxval, dtype);
35732 }
35733 getConfig() {
35734 return { minval: this.minval, maxval: this.maxval, seed: this.seed };
35735 }
35736 }
35737 /** @nocollapse */
35738 RandomUniform.className = 'RandomUniform';
35739 registerClass(RandomUniform);
35740 class RandomNormal extends Initializer {
35741 constructor(args) {
35742 super();
35743 this.DEFAULT_MEAN = 0.;
35744 this.DEFAULT_STDDEV = 0.05;
35745 this.mean = args.mean || this.DEFAULT_MEAN;
35746 this.stddev = args.stddev || this.DEFAULT_STDDEV;
35747 this.seed = args.seed;
35748 }
35749 apply(shape, dtype) {
35750 dtype = dtype || 'float32';
35751 if (dtype !== 'float32' && dtype !== 'int32') {
35752 throw new NotImplementedError(`randomNormal does not support dType ${dtype}.`);
35753 }
35754 return randomNormal$1(shape, this.mean, this.stddev, dtype, this.seed);
35755 }
35756 getConfig() {
35757 return { mean: this.mean, stddev: this.stddev, seed: this.seed };
35758 }
35759 }
35760 /** @nocollapse */
35761 RandomNormal.className = 'RandomNormal';
35762 registerClass(RandomNormal);
35763 class TruncatedNormal extends Initializer {
35764 constructor(args) {
35765 super();
35766 this.DEFAULT_MEAN = 0.;
35767 this.DEFAULT_STDDEV = 0.05;
35768 this.mean = args.mean || this.DEFAULT_MEAN;
35769 this.stddev = args.stddev || this.DEFAULT_STDDEV;
35770 this.seed = args.seed;
35771 }
35772 apply(shape, dtype) {
35773 dtype = dtype || 'float32';
35774 if (dtype !== 'float32' && dtype !== 'int32') {
35775 throw new NotImplementedError(`truncatedNormal does not support dType ${dtype}.`);
35776 }
35777 return truncatedNormal(shape, this.mean, this.stddev, dtype, this.seed);
35778 }
35779 getConfig() {
35780 return { mean: this.mean, stddev: this.stddev, seed: this.seed };
35781 }
35782 }
35783 /** @nocollapse */
35784 TruncatedNormal.className = 'TruncatedNormal';
35785 registerClass(TruncatedNormal);
35786 class Identity$1 extends Initializer {
35787 constructor(args) {
35788 super();
35789 this.gain = args.gain != null ? args.gain : 1.0;
35790 }
35791 apply(shape, dtype) {
35792 return tidy(() => {
35793 if (shape.length !== 2 || shape[0] !== shape[1]) {
35794 throw new ValueError('Identity matrix initializer can only be used for' +
35795 ' 2D square matrices.');
35796 }
35797 else {
35798 return mul(this.gain, eye(shape[0]));
35799 }
35800 });
35801 }
35802 getConfig() {
35803 return { gain: this.gain };
35804 }
35805 }
35806 /** @nocollapse */
35807 Identity$1.className = 'Identity';
35808 registerClass(Identity$1);
35809 /**
35810 * Computes the number of input and output units for a weight shape.
35811 * @param shape Shape of weight.
35812 * @param dataFormat data format to use for convolution kernels.
35813 * Note that all kernels in Keras are standardized on the
35814 * CHANNEL_LAST ordering (even when inputs are set to CHANNEL_FIRST).
35815 * @return An length-2 array: fanIn, fanOut.
35816 */
35817 function computeFans(shape, dataFormat = 'channelsLast') {
35818 let fanIn;
35819 let fanOut;
35820 checkDataFormat(dataFormat);
35821 if (shape.length === 2) {
35822 fanIn = shape[0];
35823 fanOut = shape[1];
35824 }
35825 else if ([3, 4, 5].indexOf(shape.length) !== -1) {
35826 if (dataFormat === 'channelsFirst') {
35827 const receptiveFieldSize = arrayProd(shape, 2);
35828 fanIn = shape[1] * receptiveFieldSize;
35829 fanOut = shape[0] * receptiveFieldSize;
35830 }
35831 else if (dataFormat === 'channelsLast') {
35832 const receptiveFieldSize = arrayProd(shape, 0, shape.length - 2);
35833 fanIn = shape[shape.length - 2] * receptiveFieldSize;
35834 fanOut = shape[shape.length - 1] * receptiveFieldSize;
35835 }
35836 }
35837 else {
35838 const shapeProd = arrayProd(shape);
35839 fanIn = Math.sqrt(shapeProd);
35840 fanOut = Math.sqrt(shapeProd);
35841 }
35842 return [fanIn, fanOut];
35843 }
    /**
     * Initializer capable of adapting its scale to the shape of weights.
     * With `distribution: 'normal'`, samples are drawn from a truncated normal
     * centered on zero with stddev sqrt(scale / n); with `'uniform'`, from
     * [-limit, limit] with limit = sqrt(3 * scale / n), where n is determined
     * by `mode` (fanIn / fanOut / fanAvg).
     */
    class VarianceScaling extends Initializer {
        /**
         * Constructor of VarianceScaling.
         * @throws ValueError for invalid value in scale.
         */
        constructor(args) {
            super();
            // NOTE(review): this check runs before the null-default below and
            // admits scale === 0 even though the message says "positive" —
            // confirm intended boundary before tightening.
            if (args.scale < 0.0) {
                throw new ValueError(`scale must be a positive float. Got: ${args.scale}`);
            }
            this.scale = args.scale == null ? 1.0 : args.scale;
            this.mode = args.mode == null ? 'fanIn' : args.mode;
            checkFanMode(this.mode);
            this.distribution =
                args.distribution == null ? 'normal' : args.distribution;
            checkDistribution(this.distribution);
            this.seed = args.seed;
        }
        apply(shape, dtype) {
            // Scale the variance by the fan-in/fan-out of the weight shape.
            const fans = computeFans(shape);
            const fanIn = fans[0];
            const fanOut = fans[1];
            let scale = this.scale;
            if (this.mode === 'fanIn') {
                scale /= Math.max(1, fanIn);
            }
            else if (this.mode === 'fanOut') {
                scale /= Math.max(1, fanOut);
            }
            else {
                // 'fanAvg'.
                scale /= Math.max(1, (fanIn + fanOut) / 2);
            }
            if (this.distribution === 'normal') {
                const stddev = Math.sqrt(scale);
                dtype = dtype || 'float32';
                if (dtype !== 'float32' && dtype !== 'int32') {
                    throw new NotImplementedError(`${this.getClassName()} does not support dType ${dtype}.`);
                }
                return truncatedNormal(shape, 0, stddev, dtype, this.seed);
            }
            else {
                // 'uniform': limit chosen so the variance equals `scale`.
                // NOTE(review): `this.seed` is not forwarded on this branch —
                // confirm whether `randomUniform` supports seeding.
                const limit = Math.sqrt(3 * scale);
                return randomUniform(shape, -limit, limit, dtype);
            }
        }
        getConfig() {
            return {
                scale: this.scale,
                mode: this.mode,
                distribution: this.distribution,
                seed: this.seed
            };
        }
    }
    /** @nocollapse */
    VarianceScaling.className = 'VarianceScaling';
    registerClass(VarianceScaling);
35901 class GlorotUniform extends VarianceScaling {
35902 /**
35903 * Constructor of GlorotUniform
35904 * @param scale
35905 * @param mode
35906 * @param distribution
35907 * @param seed
35908 */
35909 constructor(args) {
35910 super({
35911 scale: 1.0,
35912 mode: 'fanAvg',
35913 distribution: 'uniform',
35914 seed: args == null ? null : args.seed
35915 });
35916 }
35917 getClassName() {
35918 // In Python Keras, GlorotUniform is not a class, but a helper method
35919 // that creates a VarianceScaling object. Use 'VarianceScaling' as
35920 // class name to be compatible with that.
35921 return VarianceScaling.className;
35922 }
35923 }
35924 /** @nocollapse */
35925 GlorotUniform.className = 'GlorotUniform';
35926 registerClass(GlorotUniform);
35927 class GlorotNormal extends VarianceScaling {
35928 /**
35929 * Constructor of GlorotNormal.
35930 * @param scale
35931 * @param mode
35932 * @param distribution
35933 * @param seed
35934 */
35935 constructor(args) {
35936 super({
35937 scale: 1.0,
35938 mode: 'fanAvg',
35939 distribution: 'normal',
35940 seed: args == null ? null : args.seed
35941 });
35942 }
35943 getClassName() {
35944 // In Python Keras, GlorotNormal is not a class, but a helper method
35945 // that creates a VarianceScaling object. Use 'VarianceScaling' as
35946 // class name to be compatible with that.
35947 return VarianceScaling.className;
35948 }
35949 }
35950 /** @nocollapse */
35951 GlorotNormal.className = 'GlorotNormal';
35952 registerClass(GlorotNormal);
35953 class HeNormal extends VarianceScaling {
35954 constructor(args) {
35955 super({
35956 scale: 2.0,
35957 mode: 'fanIn',
35958 distribution: 'normal',
35959 seed: args == null ? null : args.seed
35960 });
35961 }
35962 getClassName() {
35963 // In Python Keras, HeNormal is not a class, but a helper method
35964 // that creates a VarianceScaling object. Use 'VarianceScaling' as
35965 // class name to be compatible with that.
35966 return VarianceScaling.className;
35967 }
35968 }
35969 /** @nocollapse */
35970 HeNormal.className = 'HeNormal';
35971 registerClass(HeNormal);
35972 class HeUniform extends VarianceScaling {
35973 constructor(args) {
35974 super({
35975 scale: 2.0,
35976 mode: 'fanIn',
35977 distribution: 'uniform',
35978 seed: args == null ? null : args.seed
35979 });
35980 }
35981 getClassName() {
35982 // In Python Keras, HeUniform is not a class, but a helper method
35983 // that creates a VarianceScaling object. Use 'VarianceScaling' as
35984 // class name to be compatible with that.
35985 return VarianceScaling.className;
35986 }
35987 }
35988 /** @nocollapse */
35989 HeUniform.className = 'HeUniform';
35990 registerClass(HeUniform);
35991 class LeCunNormal extends VarianceScaling {
35992 constructor(args) {
35993 super({
35994 scale: 1.0,
35995 mode: 'fanIn',
35996 distribution: 'normal',
35997 seed: args == null ? null : args.seed
35998 });
35999 }
36000 getClassName() {
36001 // In Python Keras, LeCunNormal is not a class, but a helper method
36002 // that creates a VarianceScaling object. Use 'VarianceScaling' as
36003 // class name to be compatible with that.
36004 return VarianceScaling.className;
36005 }
36006 }
36007 /** @nocollapse */
36008 LeCunNormal.className = 'LeCunNormal';
36009 registerClass(LeCunNormal);
36010 class LeCunUniform extends VarianceScaling {
36011 constructor(args) {
36012 super({
36013 scale: 1.0,
36014 mode: 'fanIn',
36015 distribution: 'uniform',
36016 seed: args == null ? null : args.seed
36017 });
36018 }
36019 getClassName() {
36020 // In Python Keras, LeCunUniform is not a class, but a helper method
36021 // that creates a VarianceScaling object. Use 'VarianceScaling' as
36022 // class name to be compatible with that.
36023 return VarianceScaling.className;
36024 }
36025 }
36026 /** @nocollapse */
36027 LeCunUniform.className = 'LeCunNormal';
36028 registerClass(LeCunUniform);
    /**
     * Initializer that generates a random orthogonal matrix, scaled by `gain`.
     */
    class Orthogonal extends Initializer {
        constructor(args) {
            super();
            this.DEFAULT_GAIN = 1;
            this.gain = args.gain == null ? this.DEFAULT_GAIN : args.gain;
            this.seed = args.seed;
            // Seeding is rejected up front rather than silently ignored.
            if (this.seed != null) {
                throw new NotImplementedError('Random seed is not implemented for Orthogonal Initializer yet.');
            }
        }
        apply(shape, dtype) {
            return tidy(() => {
                if (shape.length < 2) {
                    throw new NotImplementedError('Shape must be at least 2D.');
                }
                // NOTE(review): only the first two dimensions are considered here
                // and below — confirm behavior for rank > 2 shapes.
                if (shape[0] * shape[1] > 2000) {
                    console.warn(`Orthogonal initializer is being called on a matrix with more ` +
                        `than 2000 (${shape[0] * shape[1]}) elements: ` +
                        `Slowness may result.`);
                }
                // TODO(cais): Add seed support.
                // Orient the matrix so Gram-Schmidt is applied along the smaller
                // dimension, then transpose back if the original was taller.
                const normalizedShape = shape[0] > shape[1] ? [shape[1], shape[0]] : shape;
                const a = randomNormal$1(normalizedShape, 0, 1, 'float32');
                let q = linalg.gramSchmidt(a);
                if (shape[0] > shape[1]) {
                    q = transpose(q);
                }
                return mul(this.gain, q);
            });
        }
        getConfig() {
            return {
                gain: this.gain,
                seed: this.seed,
            };
        }
    }
    /** @nocollapse */
    Orthogonal.className = 'Orthogonal';
    registerClass(Orthogonal);
    // Maps the JavaScript-like identifier keys to the corresponding registry
    // symbols. Keys are the camelCase identifiers accepted by `getInitializer`;
    // values are the registered `className`s of the initializer classes.
    const INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
        'constant': 'Constant',
        'glorotNormal': 'GlorotNormal',
        'glorotUniform': 'GlorotUniform',
        'heNormal': 'HeNormal',
        'heUniform': 'HeUniform',
        'identity': 'Identity',
        'leCunNormal': 'LeCunNormal',
        'leCunUniform': 'LeCunUniform',
        'ones': 'Ones',
        'orthogonal': 'Orthogonal',
        'randomNormal': 'RandomNormal',
        'randomUniform': 'RandomUniform',
        'truncatedNormal': 'TruncatedNormal',
        'varianceScaling': 'VarianceScaling',
        'zeros': 'Zeros'
    };
    /**
     * Deserialize an initializer from its serialized form.
     * @param config Serialized representation ({className, config}).
     * @param customObjects Optional map of custom names to constructors.
     * @return The deserialized Initializer instance.
     */
    function deserializeInitializer(config, customObjects = {}) {
        return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'initializer');
    }
    /**
     * Serialize an initializer into its JSON-compatible representation.
     * @param initializer The Initializer instance to serialize.
     * @return The serialized form (delegates to `serializeKerasObject`).
     */
    function serializeInitializer(initializer) {
        return serializeKerasObject(initializer);
    }
36094 function getInitializer(identifier) {
36095 if (typeof identifier === 'string') {
36096 const className = identifier in INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP ?
36097 INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] :
36098 identifier;
36099 /* We have four 'helper' classes for common initializers that
36100 all get serialized as 'VarianceScaling' and shouldn't go through
36101 the deserializeInitializer pathway. */
36102 if (className === 'GlorotNormal') {
36103 return new GlorotNormal();
36104 }
36105 else if (className === 'GlorotUniform') {
36106 return new GlorotUniform();
36107 }
36108 else if (className === 'HeNormal') {
36109 return new HeNormal();
36110 }
36111 else if (className === 'HeUniform') {
36112 return new HeUniform();
36113 }
36114 else if (className === 'LeCunNormal') {
36115 return new LeCunNormal();
36116 }
36117 else if (className === 'LeCunUniform') {
36118 return new LeCunUniform();
36119 }
36120 else {
36121 const config = {};
36122 config['className'] = className;
36123 config['config'] = {};
36124 return deserializeInitializer(config);
36125 }
36126 }
36127 else if (identifier instanceof Initializer) {
36128 return identifier;
36129 }
36130 else {
36131 return deserializeInitializer(identifier);
36132 }
36133 }
36134
36135 /**
36136 * @license
36137 * Copyright 2018 Google LLC
36138 *
36139 * Use of this source code is governed by an MIT-style
36140 * license that can be found in the LICENSE file or at
36141 * https://opensource.org/licenses/MIT.
36142 * =============================================================================
36143 */
36144 // tslint:enable
36145 /**
36146 * Determine whether the input is an Array of Shapes.
36147 */
36148 function isArrayOfShapes(x) {
36149 return Array.isArray(x) && Array.isArray(x[0]);
36150 }
36151 /**
36152 * Special case of normalizing shapes to lists.
36153 *
36154 * @param x A shape or list of shapes to normalize into a list of Shapes.
36155 * @return A list of Shapes.
36156 */
36157 function normalizeShapeList(x) {
36158 if (x.length === 0) {
36159 return [];
36160 }
36161 if (!Array.isArray(x[0])) {
36162 return [x];
36163 }
36164 return x;
36165 }
36166 /**
36167 * Helper function to obtain exactly one Tensor.
36168 * @param xs: A single `tf.Tensor` or an `Array` of `tf.Tensor`s.
36169 * @return A single `tf.Tensor`. If `xs` is an `Array`, return the first one.
36170 * @throws ValueError: If `xs` is an `Array` and its length is not 1.
36171 */
36172 function getExactlyOneTensor(xs) {
36173 let x;
36174 if (Array.isArray(xs)) {
36175 if (xs.length !== 1) {
36176 throw new ValueError(`Expected Tensor length to be 1; got ${xs.length}`);
36177 }
36178 x = xs[0];
36179 }
36180 else {
36181 x = xs;
36182 }
36183 return x;
36184 }
36185 /**
36186 * Helper function to obtain exactly on instance of Shape.
36187 *
36188 * @param shapes Input single `Shape` or Array of `Shape`s.
36189 * @returns If input is a single `Shape`, return it unchanged. If the input is
36190 * an `Array` containing exactly one instance of `Shape`, return the instance.
36191 * Otherwise, throw a `ValueError`.
36192 * @throws ValueError: If input is an `Array` of `Shape`s, and its length is not
36193 * 1.
36194 */
36195 function getExactlyOneShape(shapes) {
36196 if (Array.isArray(shapes) && Array.isArray(shapes[0])) {
36197 if (shapes.length === 1) {
36198 shapes = shapes;
36199 return shapes[0];
36200 }
36201 else {
36202 throw new ValueError(`Expected exactly 1 Shape; got ${shapes.length}`);
36203 }
36204 }
36205 else {
36206 return shapes;
36207 }
36208 }
36209
36210 /**
36211 * @license
36212 * Copyright 2018 Google LLC
36213 *
36214 * Use of this source code is governed by an MIT-style
36215 * license that can be found in the LICENSE file or at
36216 * https://opensource.org/licenses/MIT.
36217 * =============================================================================
36218 */
36219 /**
36220 * Count the elements in an Array of LayerVariables.
36221 *
36222 * @param weights: The LayerVariables of which the constituent numbers are to
36223 * be counted.
36224 * @returns A count of the elements in all the LayerVariables
36225 */
36226 function countParamsInWeights(weights) {
36227 let count = 0;
36228 for (const weight of weights) {
36229 if (weight.shape.length === 0) {
36230 count += 1;
36231 }
36232 else {
36233 count += weight.shape.reduce((a, b) => a * b);
36234 }
36235 }
36236 return count;
36237 }
36238
36239 /**
36240 * @license
36241 * Copyright 2018 Google LLC
36242 *
36243 * Use of this source code is governed by an MIT-style
36244 * license that can be found in the LICENSE file or at
36245 * https://opensource.org/licenses/MIT.
36246 * =============================================================================
36247 */
    // Prefix used when a variable is created without an explicit name.
    const DEFAULT_VARIABLE_NAME_PREFIX = 'Variable';
    /**
     * A `tf.layers.LayerVariable` is similar to a `tf.Tensor` in that it has a
     * dtype and shape, but its value is mutable. The value is itself represented
     * as a `tf.Tensor`, and can be read with the `read()` method and updated with
     * the `write()` method.
     */
    class LayerVariable {
        /**
         * Construct Variable from a `tf.Tensor`.
         *
         * If not explicitly named, the Variable will be given a name with the
         * prefix 'Variable'. Variable names are unique. In the case of name
         * collision, suffixes '_<num>' will be added to the name.
         *
         * @param val Initial value of the Variable.
         * @param dtype Data type; defaults to 'float32' when null/undefined.
         * @param name Name of the variable. If `null` or `undefined` is provided,
         *   it will default to a name with the prefix 'Variable'.
         * @param trainable Whether the underlying tf.Variable is trainable.
         * @param constraint Optional, projection function to be applied to the
         *   variable after optimizer updates.
         * @throws ValueError if `name` is `null` or `undefined`.
         */
        constructor(val, dtype = 'float32', name = DEFAULT_VARIABLE_NAME_PREFIX, trainable = true, constraint = null) {
            this.dtype = dtype == null ? 'float32' : dtype;
            this.shape = val.shape;
            this.id = getNextUniqueTensorId();
            // Scoping must run before uniquification so collisions are resolved
            // within the current name scope.
            name = name == null ? DEFAULT_VARIABLE_NAME_PREFIX : name;
            this.originalName = getScopedTensorName(name);
            this.name = getUniqueTensorName(this.originalName);
            this.trainable_ = trainable;
            this.constraint = constraint;
            this.val = variable(val, this.trainable_, this.name, this.dtype);
        }
        /**
         * Get a snapshot of the Variable's value.
         *
         * The returned value is a snapshot of the Variable's value at the time of
         * the invocation. Future mutations in the value of the tensor will only
         * be reflected by future calls to this method.
         */
        read() {
            this.assertNotDisposed();
            return this.val;
        }
        /**
         * Update the value of the Variable.
         *
         * @param newVal: The new value to update to. Must be consistent with the
         *   dtype and shape of the Variable.
         * @return This Variable.
         */
        write(newVal) {
            // TODO(cais): Once TF.js Core supports Tensor.dtype, check dtype match.
            this.assertNotDisposed();
            checkShapesMatch(this.val, newVal);
            // Skip updating if this is the exact same tensor.
            if (this.val.id !== newVal.id) {
                this.val.assign(newVal);
                // The constraint (if any) is applied after every assignment.
                if (this.constraint != null) {
                    this.val.assign(this.constraint.apply(this.val));
                }
            }
            return this;
        }
        /**
         * Dispose this LayersVariable instance from memory.
         */
        dispose() {
            this.assertNotDisposed();
            this.val.dispose();
        }
        // Throws if the underlying tensor has already been disposed.
        assertNotDisposed() {
            if (this.val.isDisposed) {
                throw new Error(`LayersVariable ${this.name} is already disposed.`);
            }
        }
        get trainable() {
            return this.trainable_;
        }
        // Keeps the wrapper flag and the underlying tf.Variable in sync.
        set trainable(trainable) {
            this.trainable_ = trainable;
            this.val.trainable = trainable;
        }
    }
36332 function checkShapesMatch(x, y) {
36333 if (x.shape.toString() !== y.shape.toString()) {
36334 throw new Error('Shape mismatch: ' + JSON.stringify(x.shape) + ' vs. ' +
36335 JSON.stringify(y.shape));
36336 }
36337 }
    /**
     * Create a Variable.
     * @param x The initial value of the `Variable`.
     * @param dtype optional, the type of the variable.
     * @param name optional, the name of the variable, default provided by
     *   Variable.
     * @param constraint optional, a constraint to be applied after every update.
     * @return The newly instantiated `Variable`.
     */
    function variable$1(x, dtype, name, constraint) {
        // Always constructed as trainable (4th argument `true`).
        return new LayerVariable(x, dtype, name, true, constraint);
    }
    /**
     * Instantiates an all-zeros Variable and returns it.
     *
     * @param shape Shape of the tensor.
     * @param dtype DType of the tensor. NOTE: passed to the LayerVariable but
     *   the zeros tensor itself is created with the default dtype (see TODO).
     * @param name Name of the tensor.
     * @return An all-zero Variable.
     */
    function zerosVariable(shape, dtype, name) {
        // TODO(cais): Implement logic for dtype.
        return new LayerVariable(zeros(shape), dtype, name);
    }
    /**
     * Instantiates an all-zeros Variable of the same shape as another tensor.
     *
     * @param x The other tensor.
     * @param dtype DType of the tensor.
     * @param name Name of the tensor.
     * @return A newly instantiated Variable.
     */
    function zerosLike$1(x, dtype, name) {
        return new LayerVariable(zerosLike(x), dtype, name);
    }
36373 /**
36374 * Instantiates an all-ones tensor and returns it.
36375 *
36376 * @param shape Shape of the tensor.
36377 * @param dtype DType of the tensor.
36378 * @param name Name of the tensor.
36379 * @return An all-ones Variable.
36380 */
36381 function onesVariable(shape, dtype, name) {
36382 // TODO(cais): Implement logic for dtype.
36383 const allocated = ones$1(shape);
36384 return new LayerVariable(allocated, dtype, name);
36385 }
36386 /**
36387 * Instantiates an all-ones tensor of the same shape as another tensor.
36388 *
36389 * @param x The other tensor.
36390 * @param dtype DType of the tensor.
36391 * @param name Name of the tensor.
36392 * @return A newly instantiated Variable.
36393 */
36394 function onesLike$1(x, dtype, name) {
36395 const allocated = onesLike(x);
36396 return new LayerVariable(allocated, dtype, name);
36397 }
    /**
     * Instantiates an identity matrix and returns it, as a Variable.
     *
     * @param size Number of rows/columns.
     * @param dtype Data type of returned Variable.
     * @param name Name of returned Variable.
     * @return A Variable, an identity matrix.
     */
    function eyeVariable(size, dtype, name) {
        return new LayerVariable(eye(size), dtype, name);
    }
36409 /**
36410 * Get a Variable with uniform distribution of values.
36411 * @param shape Shape of the tensor.
36412 * @param minval Lower bound of the uniform distribution.
36413 * @param maxval Upper bound of the uniform distribution.
36414 * @param dtype
36415 * @param seed
36416 * @param name Optional name.
36417 * @return The uniform-random Variable.
36418 */
36419 function randomUniformVariable(shape, minval, maxval, dtype, seed, name = 'randomUniform') {
36420 return new LayerVariable(randomUniform(shape, minval, maxval, dtype), dtype, name);
36421 }
36422 /**
36423 * Get a Variable with truncated-normal distribution of values.
36424 * @param shape Shape of the tensor.
36425 * @param mean mean value of the normal distribution.
36426 * @param stddev standard deviation of the normal distribution.
36427 * @param dtype
36428 * @param seed
36429 * @param name Optional name.
36430 * @return The truncated-normal-random Variable.
36431 */
36432 function truncatedNormalVariable(shape, mean = 0.0, stddev = 1.0, dtype, seed, name = 'truncatedNormal') {
36433 // TODO(cais): Implement logic for dtype and seed once they are supported
36434 // by deeplearn.js.
36435 dtype = dtype || 'float32';
36436 if (dtype !== 'float32' && dtype !== 'int32') {
36437 throw new NotImplementedError(`randomNormal does not support dType ${dtype}.`);
36438 }
36439 return new LayerVariable(truncatedNormal(shape, mean, stddev, dtype, seed), dtype, name);
36440 }
36441 /**
36442 * Get a Variable with normal distribution of values.
36443 * @param shape Shape of the tensor.
36444 * @param mean mean value of the normal distribution.
36445 * @param stddev standard deviation of the normal distribution.
36446 * @param dtype
36447 * @param seed
36448 * @param name Optional name.
36449 * @return The truncated-normal-random Variable.
36450 */
36451 function randomNormalVariable(shape, mean = 0.0, stddev = 1.0, dtype, seed, name = 'randomNormal') {
36452 dtype = dtype || 'float32';
36453 if (dtype !== 'float32' && dtype !== 'int32') {
36454 throw new NotImplementedError(`randomNormalVariable does not support dType ${dtype}.`);
36455 }
36456 return new LayerVariable(randomNormal(shape, mean, stddev, dtype, seed), dtype, name);
36457 }
36458 /**
36459 * Update the value of a Variable.
36460 * @param x The Variable to be updated.
36461 * @param xNew The new value to update to.
36462 * @return The Variable updated.
36463 */
36464 function update(x, xNew) {
36465 return x.write(xNew);
36466 }
    /**
     * Update the value of a Variable by adding an increment.
     * @param x The Variable to be updated.
     * @param increment The increment to add to `x`.
     * @return The Variable updated.
     */
    function updateAdd(x, increment) {
        return x.write(add$1(x.read(), increment));
    }
    /**
     * Update the value of a Variable by subtracting a decrement.
     * @param x The Variable to be updated.
     * @param decrement The decrement to subtract from `x`.
     * @return The Variable updated.
     */
    function updateSub(x, decrement) {
        return x.write(sub(x.read(), decrement));
    }
36485 /**
36486 * Get the values of an array of Variables.
36487 *
36488 * @param tensors An `Array` of `Variable`s to get the values of.
36489 * @return The values of the inputs, as an `Array` of`tf.Tensor`s.
36490 */
36491 function batchGetValue(xs) {
36492 return xs.map(x => x.read());
36493 }
36494 /**
36495 * Update the value of multiple Variables at once.
36496 *
36497 * @param variablesAndValues An `Array`, each element is of type
36498 * [Variable, Tensor]. The first item is the
36499 * `Variable` of which the value is to be updated. The second item
36500 * carries the new value.
36501 */
36502 function batchSetValue(variablesAndValues) {
36503 variablesAndValues.forEach(variableAndValue => {
36504 const variable = variableAndValue[0];
36505 variable.write(variableAndValue[1]);
36506 });
36507 }
36508 /**
36509 * Returns the gradients of `variables` w.r.t. the return value of `lossFn`.
36510 * @param lossFn A function which returns a Scalar to be used as the function
36511 * value (i.e., numerator) for differentiation.
36512 * @param variables List of variables to be used as the independent variables
36513 * (i.e., denominator) for differentiation.
36514 * @returns An Array of gradients tensors.
36515 */
36516 function gradients(lossFn, variables) {
36517 // TODO(cais): The return type signature can be simplified if deeplearn makes
36518 // the corresponding type public.
36519 const variableList = variables.map(variable => variable.read());
36520 const valudAndGrads = variableGrads(lossFn, variableList);
36521 return variables.map(variable => valudAndGrads.grads[variable.name]);
36522 }
36523
36524 /**
36525 * @license
36526 * Copyright 2018 Google LLC
36527 *
36528 * Use of this source code is governed by an MIT-style
36529 * license that can be found in the LICENSE file or at
36530 * https://opensource.org/licenses/MIT.
36531 * =============================================================================
36532 */
36533 /**
36534 * Specifies the ndim, dtype and shape of every input to a layer.
36535 *
36536 * Every layer should expose (if appropriate) an `inputSpec` attribute:
36537 * a list of instances of InputSpec (one per input tensor).
36538 *
36539 * A null entry in a shape is compatible with any dimension,
36540 * a null shape is compatible with any shape.
36541 */
36542 class InputSpec {
36543 constructor(args) {
36544 this.dtype = args.dtype;
36545 this.shape = args.shape;
36546 /*
36547 TODO(michaelterry): Could throw error if ndim and shape are both defined
36548 (then backport).
36549 */
36550 if (args.shape != null) {
36551 this.ndim = args.shape.length;
36552 }
36553 else {
36554 this.ndim = args.ndim;
36555 }
36556 this.maxNDim = args.maxNDim;
36557 this.minNDim = args.minNDim;
36558 this.axes = args.axes || {};
36559 }
36560 }
    /**
     * `tf.SymbolicTensor` is a placeholder for a Tensor without any concrete value.
     *
     * They are most often encountered when building a graph of `Layer`s for
     * a `tf.LayersModel`, when the input data's shape, but not the values, is
     * known.
     *
     * @doc {heading: 'Models', 'subheading': 'Classes'}
     */
    class SymbolicTensor {
        /**
         *
         * @param dtype
         * @param shape
         * @param sourceLayer The Layer that produced this symbolic tensor.
         * @param inputs The inputs passed to sourceLayer's __call__() method.
         * @param callArgs The keyword arguments passed to the __call__() method.
         * @param name
         * @param outputTensorIndex The index of this tensor in the list of outputs
         *   returned by apply().
         */
        constructor(dtype, shape, sourceLayer, inputs, callArgs, name, outputTensorIndex) {
            this.dtype = dtype;
            this.shape = shape;
            this.sourceLayer = sourceLayer;
            this.inputs = inputs;
            this.callArgs = callArgs;
            this.outputTensorIndex = outputTensorIndex;
            // Ids are drawn from the same counter as concrete tensors.
            this.id = getNextUniqueTensorId();
            if (name != null) {
                // Scope the name first, then de-duplicate it (suffixes appended
                // on collision).
                this.originalName = getScopedTensorName(name);
                this.name = getUniqueTensorName(this.originalName);
            }
            this.rank = shape.length;
        }
    }
36598 let _nextNodeID = 0;
36599 /**
36600 * A `Node` describes the connectivity between two layers.
36601 *
36602 * Each time a layer is connected to some new input,
36603 * a node is added to `layer.inboundNodes`.
36604 *
36605 * Each time the output of a layer is used by another layer,
36606 * a node is added to `layer.outboundNodes`.
36607 *
36608 * `nodeIndices` and `tensorIndices` are basically fine-grained coordinates
36609 * describing the origin of the `inputTensors`, verifying the following:
36610 *
36611 * `inputTensors[i] ==
36612 * inboundLayers[i].inboundNodes[nodeIndices[i]].outputTensors[
36613 * tensorIndices[i]]`
36614 *
36615 * A node from layer A to layer B is added to:
36616 * A.outboundNodes
36617 * B.inboundNodes
36618 */
36619 class Node {
36620 constructor(args,
36621 // TODO(michaelterry): Define actual type for this.
36622 callArgs) {
36623 this.callArgs = callArgs;
36624 this.id = _nextNodeID++;
36625 /*
36626 Layer instance (NOT a list).
36627 this is the layer that takes a list of input tensors
36628 and turns them into a list of output tensors.
36629 the current node will be added to
36630 the inboundNodes of outboundLayer.
36631 */
36632 this.outboundLayer = args.outboundLayer;
36633 /*
36634 The following 3 properties describe where
36635 the input tensors come from: which layers,
36636 and for each layer, which node and which
36637 tensor output of each node.
36638 */
36639 // List of layer instances.
36640 this.inboundLayers = args.inboundLayers;
36641 // List of integers, 1:1 mapping with inboundLayers.
36642 this.nodeIndices = args.nodeIndices;
36643 // List of integers, 1:1 mapping with inboundLayers.
36644 this.tensorIndices = args.tensorIndices;
36645 /*
36646 Following 2 properties:
36647 tensor inputs and outputs of outboundLayer.
36648 */
36649 // List of tensors. 1:1 mapping with inboundLayers.
36650 this.inputTensors = args.inputTensors;
36651 // List of tensors, created by outboundLayer.call().
36652 this.outputTensors = args.outputTensors;
36653 /*
36654 Following 2 properties: input and output masks.
36655 List of tensors, 1:1 mapping with inputTensor.
36656 */
36657 this.inputMasks = args.inputMasks;
36658 // List of tensors, created by outboundLayer.computeMask().
36659 this.outputMasks = args.outputMasks;
36660 // Following 2 properties: input and output shapes.
36661 // List of shape tuples, shapes of inputTensors.
36662 this.inputShapes = args.inputShapes;
36663 // List of shape tuples, shapes of outputTensors.
36664 this.outputShapes = args.outputShapes;
36665 // Add nodes to all layers involved.
36666 for (const layer of args.inboundLayers) {
36667 if (layer != null) {
36668 layer.outboundNodes.push(this);
36669 }
36670 }
36671 args.outboundLayer.inboundNodes.push(this);
36672 }
36673 getConfig() {
36674 const inboundNames = [];
36675 for (const layer of this.inboundLayers) {
36676 if (layer != null) {
36677 inboundNames.push(layer.name);
36678 }
36679 else {
36680 inboundNames.push(null);
36681 }
36682 }
36683 return {
36684 outboundLayer: this.outboundLayer ? this.outboundLayer.name : null,
36685 inboundLayers: inboundNames,
36686 nodeIndices: this.nodeIndices,
36687 tensorIndices: this.tensorIndices
36688 };
36689 }
36690 }
36691 let _nextLayerID = 0;
36692 /**
36693 * A layer is a grouping of operations and weights that can be composed to
36694 * create a `tf.LayersModel`.
36695 *
36696 * Layers are constructed by using the functions under the
36697 * [tf.layers](#Layers-Basic) namespace.
36698 *
36699 * @doc {heading: 'Layers', subheading: 'Classes', namespace: 'layers'}
36700 */
36701 class Layer extends Serializable {
        /**
         * Constructs the base Layer state.
         *
         * @param args Optional configuration. Fields read here: `name`,
         *   `trainable`, `inputShape`, `batchInputShape`, `batchSize`, `dtype`,
         *   `inputDType`, `weights`.
         */
        constructor(args = {}) {
            super();
            this._callHook = null;
            // Names already claimed via addWeight(); used to reject duplicates.
            this._addedWeightNames = [];
            // Porting Notes: PyKeras does not have this property in this base Layer
            // class. Instead lets Layer subclass set it dynamically and checks the
            // value with `hasattr`. In tfjs-layers, we let this be a member of this
            // base class.
            this._stateful = false;
            this.id = _nextLayerID++;
            this.activityRegularizer = null;
            this.inputSpec = null;
            this.supportsMasking = false;
            // These properties will be set upon call of this.build()
            this._trainableWeights = [];
            this._nonTrainableWeights = [];
            this._losses = [];
            this._updates = [];
            this._built = false;
            /*
              These lists will be filled via successive calls
              to this.addInboundNode().
            */
            this.inboundNodes = [];
            this.outboundNodes = [];
            let name = args.name;
            if (!name) {
                // Auto-generate a unique snake_case name, e.g. `dense_1`.
                const prefix = this.getClassName();
                name = toSnakeCase(prefix) + '_' + getUid(prefix);
            }
            this.name = name;
            // Default to trainable unless explicitly disabled.
            this.trainable_ = args.trainable == null ? true : args.trainable;
            if (args.inputShape != null || args.batchInputShape != null) {
                /*
                  In this case we will later create an input layer
                  to insert before the current layer
                */
                let batchInputShape;
                if (args.batchInputShape != null) {
                    batchInputShape = args.batchInputShape;
                }
                else if (args.inputShape != null) {
                    // Prepend the (possibly unknown, i.e. null) batch dimension.
                    let batchSize = null;
                    if (args.batchSize != null) {
                        batchSize = args.batchSize;
                    }
                    batchInputShape = [batchSize].concat(args.inputShape);
                }
                this.batchInputShape = batchInputShape;
                // Set dtype: explicit `dtype` wins, then `inputDType`, then 'float32'.
                let dtype = args.dtype;
                if (dtype == null) {
                    dtype = args.inputDType;
                }
                if (dtype == null) {
                    dtype = 'float32';
                }
                this.dtype = dtype;
            }
            if (args.weights != null) {
                this.initialWeights = args.weights;
            }
            else {
                this.initialWeights = null;
            }
            // The value of `_refCount` is initialized to null. When the layer is used
            // in a symbolic way for the first time, it will be set to 1.
            this._refCount = null;
            this.fastWeightInitDuringBuild = false;
        }
36772 /**
36773 * Converts a layer and its index to a unique (immutable type) name.
36774 * This function is used internally with `this.containerNodes`.
36775 * @param layer The layer.
36776 * @param nodeIndex The layer's position (e.g. via enumerate) in a list of
36777 * nodes.
36778 *
36779 * @returns The unique name.
36780 */
36781 static nodeKey(layer, nodeIndex) {
36782 return layer.name + '_ib-' + nodeIndex.toString();
36783 }
36784 /**
36785 * Returns this.inboundNode at index nodeIndex.
36786 *
36787 * Porting note: This is a replacement for _get_node_attribute_at_index()
36788 * @param nodeIndex
36789 * @param attrName The name of the attribute related to request for this node.
36790 */
36791 getNodeAtIndex(nodeIndex, attrName) {
36792 if (this.inboundNodes.length === 0) {
36793 throw new RuntimeError('The layer has never been called ' +
36794 `and thus has no defined ${attrName}.`);
36795 }
36796 if (this.inboundNodes.length <= nodeIndex) {
36797 throw new ValueError(`Asked to get ${attrName} at node ${nodeIndex}, ` +
36798 `but the layer has only ${this.inboundNodes.length} inbound nodes.`);
36799 }
36800 return this.inboundNodes[nodeIndex];
36801 }
36802 /**
36803 * Retrieves the input tensor(s) of a layer at a given node.
36804 *
36805 * @param nodeIndex Integer, index of the node from which to retrieve the
36806 * attribute. E.g. `nodeIndex=0` will correspond to the first time the layer
36807 * was called.
36808 *
36809 * @return A tensor (or list of tensors if the layer has multiple inputs).
36810 */
36811 getInputAt(nodeIndex) {
36812 return singletonOrArray(this.getNodeAtIndex(nodeIndex, 'input').inputTensors);
36813 }
36814 /**
36815 * Retrieves the output tensor(s) of a layer at a given node.
36816 *
36817 * @param nodeIndex Integer, index of the node from which to retrieve the
36818 * attribute. E.g. `nodeIndex=0` will correspond to the first time the layer
36819 * was called.
36820 *
36821 * @return A tensor (or list of tensors if the layer has multiple outputs).
36822 */
36823 getOutputAt(nodeIndex) {
36824 return singletonOrArray(this.getNodeAtIndex(nodeIndex, 'output').outputTensors);
36825 }
36826 // Properties
36827 /**
36828 * Retrieves the input tensor(s) of a layer.
36829 *
36830 * Only applicable if the layer has exactly one inbound node,
36831 * i.e. if it is connected to one incoming layer.
36832 *
36833 * @return Input tensor or list of input tensors.
36834 *
36835 * @exception AttributeError if the layer is connected to more than one
36836 * incoming layers.
36837 */
36838 get input() {
36839 if (this.inboundNodes.length > 1) {
36840 throw new AttributeError(`Layer ${this.name}` +
36841 ' has multiple inbound nodes, ' +
36842 'hence the notion of "layer input" ' +
36843 'is ill-defined. ' +
36844 'Use `getInputAt(nodeIndex)` instead.');
36845 }
36846 else if (this.inboundNodes.length === 0) {
36847 throw new AttributeError(`Layer ${this.name}` +
36848 ' is not connected, no input to return.');
36849 }
36850 return singletonOrArray(this.getNodeAtIndex(0, 'input').inputTensors);
36851 }
36852 /**
36853 * Retrieves the output tensor(s) of a layer.
36854 *
36855 * Only applicable if the layer has exactly one inbound node,
36856 * i.e. if it is connected to one incoming layer.
36857 *
36858 * @return Output tensor or list of output tensors.
36859 *
36860 * @exception AttributeError if the layer is connected to more than one
36861 * incoming layers.
36862 */
36863 get output() {
36864 if (this.inboundNodes.length === 0) {
36865 throw new AttributeError(`Layer ${this.name}` +
36866 ' has no inbound nodes.');
36867 }
36868 if (this.inboundNodes.length > 1) {
36869 throw new AttributeError(`Layer ${this.name}` +
36870 ' has multiple inbound nodes, ' +
36871 'hence the notion of "layer output" ' +
36872 'is ill-defined. ' +
36873 'Use `getOutputAt(nodeIndex)` instead.');
36874 }
36875 return singletonOrArray(this.getNodeAtIndex(0, 'output').outputTensors);
36876 }
36877 get losses() {
36878 return this._losses;
36879 }
36880 /**
36881 * Retrieves the Layer's current loss values.
36882 *
36883 * Used for regularizers during training.
36884 */
36885 calculateLosses() {
36886 // Porting Node: This is an augmentation to Layer.loss in PyKeras.
36887 // In PyKeras, Layer.loss returns symbolic tensors. Here a concrete
36888 // Tensor (specifically Scalar) values are returned. This is due to the
36889 // imperative backend.
36890 return this.losses.map(lossFn => lossFn());
36891 }
36892 get updates() {
36893 return this._updates;
36894 }
36895 get built() {
36896 return this._built;
36897 }
36898 set built(built) {
36899 this._built = built;
36900 }
36901 get trainable() {
36902 return this.trainable_;
36903 }
36904 set trainable(trainable) {
36905 this._trainableWeights.forEach(w => w.trainable = trainable);
36906 this.trainable_ = trainable;
36907 }
36908 get trainableWeights() {
36909 if (this.trainable_) {
36910 return this._trainableWeights.filter(w => w.trainable);
36911 }
36912 else {
36913 return [];
36914 }
36915 }
36916 set trainableWeights(weights) {
36917 this._trainableWeights = weights;
36918 }
36919 get nonTrainableWeights() {
36920 if (this.trainable) {
36921 return this._trainableWeights.filter(w => !w.trainable)
36922 .concat(this._nonTrainableWeights);
36923 }
36924 else {
36925 return this._trainableWeights.concat(this._nonTrainableWeights);
36926 }
36927 }
36928 set nonTrainableWeights(weights) {
36929 this._nonTrainableWeights = weights;
36930 }
36931 /**
36932 * The concatenation of the lists trainableWeights and nonTrainableWeights
36933 * (in this order).
36934 */
36935 get weights() {
36936 return this.trainableWeights.concat(this.nonTrainableWeights);
36937 }
36938 get stateful() {
36939 return this._stateful;
36940 }
36941 /**
36942 * Reset the states of the layer.
36943 *
36944 * This method of the base Layer class is essentially a no-op.
36945 * Subclasses that are stateful (e.g., stateful RNNs) should override this
36946 * method.
36947 */
36948 resetStates() {
36949 if (!this.stateful) {
36950 throw new Error('Cannot call the resetStates() method of a non-stateful Layer ' +
36951 'object.');
36952 }
36953 }
        /**
         * Checks compatibility between the layer and provided inputs.
         *
         * This checks that the tensor(s) `input`
         * verify the input assumptions of the layer
         * (if any). If not, exceptions are raised.
         *
         * @param inputs Input tensor or list of input tensors.
         *
         * @exception ValueError in case of mismatch between
         *   the provided inputs and the expectations of the layer.
         */
        assertInputCompatibility(inputs) {
            inputs = toList(inputs);
            // No inputSpec means no constraints to check.
            if (this.inputSpec == null || this.inputSpec.length === 0) {
                return;
            }
            const inputSpec = toList(this.inputSpec);
            if (inputs.length !== inputSpec.length) {
                throw new ValueError(`Layer ${this.name} expects ${inputSpec.length} inputs, ` +
                    `but it received ${inputs.length} input tensors. ` +
                    `Input received: ${inputs}`);
            }
            // Validate each input against its corresponding spec.
            for (let inputIndex = 0; inputIndex < inputs.length; inputIndex++) {
                const x = inputs[inputIndex];
                const spec = inputSpec[inputIndex];
                if (spec == null) {
                    continue;
                }
                // Check ndim (exact, max, and min rank constraints).
                const ndim = x.rank;
                if (spec.ndim != null) {
                    if (ndim !== spec.ndim) {
                        throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name}: ` +
                            `expected ndim=${spec.ndim}, found ndim=${ndim}`);
                    }
                }
                if (spec.maxNDim != null) {
                    if (ndim > spec.maxNDim) {
                        throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name}` +
                            `: expected max_ndim=${spec.maxNDim}, found ndim=${ndim}`);
                    }
                }
                if (spec.minNDim != null) {
                    if (ndim < spec.minNDim) {
                        throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name}` +
                            `: expected min_ndim=${spec.minNDim}, found ndim=${ndim}.`);
                    }
                }
                // Check dtype.
                if (spec.dtype != null) {
                    if (x.dtype !== spec.dtype) {
                        throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name} ` +
                            `: expected dtype=${spec.dtype}, found dtype=${x.dtype}.`);
                    }
                }
                // Check specific shape axes (spec.axes maps axis index -> size).
                if (spec.axes) {
                    const xShape = x.shape;
                    for (const key in spec.axes) {
                        const axis = Number(key);
                        const value = spec.axes[key];
                        // Perform Python-style slicing in case axis < 0;
                        // TODO(cais): Use https://github.com/alvivi/typescript-underscore to
                        // ensure type safety through Underscore calls.
                        const xShapeAtAxis = axis >= 0 ? xShape[axis] : xShape[xShape.length + axis];
                        // An unknown (null) dimension is always accepted.
                        if (value != null && [value, null].indexOf(xShapeAtAxis) === -1) {
                            throw new ValueError(`Input ${inputIndex} is incompatible with layer ` +
                                `${this.name}: expected axis ${axis} of input shape to ` +
                                `have value ${value} but got shape ${xShape}.`);
                        }
                    }
                }
                // Check full shape; null in either spec or actual dim means "any".
                if (spec.shape != null) {
                    for (let i = 0; i < spec.shape.length; ++i) {
                        const specDim = spec.shape[i];
                        const dim = x.shape[i];
                        if (specDim != null && dim != null) {
                            if (specDim !== dim) {
                                throw new ValueError(`Input ${inputIndex} is incompatible with layer ` +
                                    `${this.name}: expected shape=${spec.shape}, ` +
                                    `found shape=${x.shape}.`);
                            }
                        }
                    }
                }
            }
        }
        /**
         * This is where the layer's logic lives.
         *
         * @param inputs Input tensor, or list/tuple of input tensors.
         * @param kwargs Additional keyword arguments.
         *
         * @return A tensor or list/tuple of tensors.
         */
        call(inputs, kwargs) {
            // Base-class implementation is the identity function; subclasses
            // override this with their actual computation.
            return inputs;
        }
37054 invokeCallHook(inputs, kwargs) {
37055 if (this._callHook != null) {
37056 this._callHook(inputs, kwargs);
37057 }
37058 }
        /**
         * Set call hook.
         * This is currently used for testing only.
         * @param callHook Function invoked by invokeCallHook() with the call's
         *   inputs and kwargs.
         */
        setCallHook(callHook) {
            this._callHook = callHook;
        }
        /**
         * Clear call hook.
         * This is currently used for testing only.
         */
        clearCallHook() {
            this._callHook = null;
        }
37074 /**
     * Builds or executes a `Layer`'s logic.
37076 *
37077 * When called with `tf.Tensor`(s), execute the `Layer`s computation and
37078 * return Tensor(s). For example:
37079 *
37080 * ```js
37081 * const denseLayer = tf.layers.dense({
37082 * units: 1,
37083 * kernelInitializer: 'zeros',
37084 * useBias: false
37085 * });
37086 *
37087 * // Invoke the layer's apply() method with a `tf.Tensor` (with concrete
37088 * // numeric values).
37089 * const input = tf.ones([2, 2]);
37090 * const output = denseLayer.apply(input);
37091 *
37092 * // The output's value is expected to be [[0], [0]], due to the fact that
37093 * // the dense layer has a kernel initialized to all-zeros and does not have
37094 * // a bias.
37095 * output.print();
37096 * ```
37097 *
37098 * When called with `tf.SymbolicTensor`(s), this will prepare the layer for
37099 * future execution. This entails internal book-keeping on shapes of
37100 * expected Tensors, wiring layers together, and initializing weights.
37101 *
37102 * Calling `apply` with `tf.SymbolicTensor`s are typically used during the
37103 * building of non-`tf.Sequential` models. For example:
37104 *
37105 * ```js
37106 * const flattenLayer = tf.layers.flatten();
37107 * const denseLayer = tf.layers.dense({units: 1});
37108 *
37109 * // Use tf.layers.input() to obtain a SymbolicTensor as input to apply().
37110 * const input = tf.input({shape: [2, 2]});
37111 * const output1 = flattenLayer.apply(input);
37112 *
37113 * // output1.shape is [null, 4]. The first dimension is the undetermined
37114 * // batch size. The second dimension comes from flattening the [2, 2]
37115 * // shape.
37116 * console.log(JSON.stringify(output1.shape));
37117 *
37118 * // The output SymbolicTensor of the flatten layer can be used to call
37119 * // the apply() of the dense layer:
37120 * const output2 = denseLayer.apply(output1);
37121 *
37122 * // output2.shape is [null, 1]. The first dimension is the undetermined
37123 * // batch size. The second dimension matches the number of units of the
37124 * // dense layer.
37125 * console.log(JSON.stringify(output2.shape));
37126 *
     * // The input and output can be used to construct a model that consists
37128 * // of the flatten and dense layers.
37129 * const model = tf.model({inputs: input, outputs: output2});
37130 * ```
37131 *
37132 * @param inputs a `tf.Tensor` or `tf.SymbolicTensor` or an Array of them.
37133 * @param kwargs Additional keyword arguments to be passed to `call()`.
37134 *
37135 * @return Output of the layer's `call` method.
37136 *
37137 * @exception ValueError error in case the layer is missing shape information
37138 * for its `build` call.
37139 *
37140 * @doc {heading: 'Models', 'subheading': 'Classes'}
37141 */
        // Porting Note: This is a replacement for __call__() in Python.
        apply(inputs, kwargs) {
            kwargs = kwargs || {};
            this.assertNotDisposed();
            // Ensure inputs are all the same type: a call must be either fully
            // symbolic (graph building) or fully concrete (eager execution).
            const inputsList = toList(inputs);
            let allAreSymbolic = true;
            for (const input of inputsList) {
                if (!(input instanceof SymbolicTensor)) {
                    allAreSymbolic = false;
                    break;
                }
            }
            let noneAreSymbolic = true;
            for (const input of inputsList) {
                if (input instanceof SymbolicTensor) {
                    noneAreSymbolic = false;
                    break;
                }
            }
            // Both flags equal means the inputs were mixed (or the list empty).
            if (allAreSymbolic === noneAreSymbolic) {
                throw new ValueError('Arguments to apply() must be all ' +
                    'SymbolicTensors or all Tensors');
            }
            // TODO(michaelterry): nameScope() may not be necessary.
            return nameScope(this.name, () => {
                // Handle layer building (weight creating, input spec locking).
                if (!this.built) {
                    /*
                      Throw exceptions in case the input is not compatible
                      with the inputSpec specified in the layer constructor.
                    */
                    this.assertInputCompatibility(inputs);
                    // Collect input shapes to build layer.
                    const inputShapes = [];
                    for (const xElem of toList(inputs)) {
                        inputShapes.push(xElem.shape);
                    }
                    this.build(singletonOrArray(inputShapes));
                    this.built = true;
                    // Load weights that were specified at layer instantiation.
                    if (this.initialWeights) {
                        this.setWeights(this.initialWeights);
                    }
                    if (this._refCount === null && noneAreSymbolic) {
                        // The first use of this layer is a non-symbolic call, set ref count
                        // to 1 so the Layer can be properly disposed if its dispose() method
                        // is called.
                        this._refCount = 1;
                    }
                }
                /*
                  Throw exceptions in case the input is not compatible
                  with the inputSpec set at build time.
                */
                this.assertInputCompatibility(inputs);
                // Handle mask propagation.
                // TODO(michaelterry): Mask propagation not currently implemented.
                // Actually call the layer, collecting output(s), mask(s), and shape(s).
                if (noneAreSymbolic) {
                    // Concrete (eager) execution path.
                    let output = this.call(inputs, kwargs);
                    // TODO(michaelterry): Compute the outputMask
                    // If the layer returns tensors from its inputs, unmodified,
                    // we copy them to avoid loss of tensor metadata.
                    const outputList = toList(output);
                    const outputListCopy = [];
                    // TODO(michaelterry): This copying may not be necessary given our eager
                    // backend.
                    for (let x of outputList) {
                        if (inputsList.indexOf(x) !== -1) {
                            x = x.clone();
                        }
                        outputListCopy.push(x);
                    }
                    output = singletonOrArray(outputListCopy);
                    if (this.activityRegularizer != null) {
                        throw new NotImplementedError('Layer invocation in the presence of activity ' +
                            'regularizer(s) is not supported yet.');
                    }
                    // TODO(michaelterry): Call addInboundNode()?
                    return output;
                }
                else {
                    // Symbolic (graph-building) path: infer output shape/dtype and
                    // wire this layer into the graph via an inbound Node.
                    const inputShape = collectInputShape(inputs);
                    const outputShape = this.computeOutputShape(inputShape);
                    let output;
                    const outputDType = guessOutputDType(inputs);
                    this.warnOnIncompatibleInputShape(Array.isArray(inputs) ? inputShape[0] :
                        inputShape);
                    if (outputShape != null && outputShape.length > 0 &&
                        Array.isArray(outputShape[0])) {
                        // We have multiple output shapes. Create multiple output tensors.
                        output = outputShape
                            .map((shape, index) => new SymbolicTensor(outputDType, shape, this, toList(inputs), kwargs, this.name, index));
                    }
                    else {
                        output = new SymbolicTensor(outputDType, outputShape, this, toList(inputs), kwargs, this.name);
                    }
                    /*
                      Add an inbound node to the layer, so that it keeps track
                      of the call and of all new variables created during the call.
                      This also updates the layer history of the output tensor(s).
                      If the input tensor(s) had no previous history,
                      this does nothing.
                    */
                    this.addInboundNode(inputs, output, null, null, inputShape, outputShape, kwargs);
                    // Note: incrementing null yields 1, so the first symbolic use
                    // initializes the reference count.
                    this._refCount++;
                    if (this.activityRegularizer != null) {
                        throw new NotImplementedError('Layer invocation in the presence of activity ' +
                            'regularizer(s) is not supported yet.');
                    }
                    return output;
                }
            });
        }
37257 /**
37258 * Check compatibility between input shape and this layer's batchInputShape.
37259 *
37260 * Print warning if any incompatibility is found.
37261 *
37262 * @param inputShape Input shape to be checked.
37263 */
37264 warnOnIncompatibleInputShape(inputShape) {
37265 if (this.batchInputShape == null) {
37266 return;
37267 }
37268 else if (inputShape.length !== this.batchInputShape.length) {
37269 console.warn(`The rank of the input tensor provided (shape: ` +
37270 `${JSON.stringify(inputShape)}) does not match that of the ` +
37271 `batchInputShape (${JSON.stringify(this.batchInputShape)}) ` +
37272 `of the layer ${this.name}`);
37273 }
37274 else {
37275 let dimMismatch = false;
37276 this.batchInputShape.forEach((dimension, i) => {
37277 if (dimension != null && inputShape[i] != null &&
37278 inputShape[i] !== dimension) {
37279 dimMismatch = true;
37280 }
37281 });
37282 if (dimMismatch) {
37283 console.warn(`The shape of the input tensor ` +
37284 `(${JSON.stringify(inputShape)}) does not ` +
37285 `match the expectation of layer ${this.name}: ` +
37286 `${JSON.stringify(this.batchInputShape)}`);
37287 }
37288 }
37289 }
37290 /**
37291 * Retrieves the output shape(s) of a layer.
37292 *
37293 * Only applicable if the layer has only one inbound node, or if all inbound
37294 * nodes have the same output shape.
37295 *
37296 * @returns Output shape or shapes.
37297 * @throws AttributeError: if the layer is connected to more than one incoming
37298 * nodes.
37299 *
37300 * @doc {heading: 'Models', 'subheading': 'Classes'}
37301 */
37302 get outputShape() {
37303 if (this.inboundNodes == null || this.inboundNodes.length === 0) {
37304 throw new AttributeError(`The layer ${this.name} has never been called and thus has no ` +
37305 `defined output shape.`);
37306 }
37307 const allOutputShapes = [];
37308 for (const node of this.inboundNodes) {
37309 const shapeString = JSON.stringify(node.outputShapes);
37310 if (allOutputShapes.indexOf(shapeString) === -1) {
37311 allOutputShapes.push(shapeString);
37312 }
37313 }
37314 if (allOutputShapes.length === 1) {
37315 const outputShapes = this.inboundNodes[0].outputShapes;
37316 if (Array.isArray(outputShapes) && Array.isArray(outputShapes[0]) &&
37317 outputShapes.length === 1) {
37318 return outputShapes[0];
37319 }
37320 else {
37321 return outputShapes;
37322 }
37323 }
37324 else {
37325 throw new AttributeError(`The layer ${this.name} has multiple inbound nodes with different ` +
37326 `output shapes. Hence the notion of "output shape" is ill-defined ` +
37327 `for the layer.`);
37328 // TODO(cais): Implement getOutputShapeAt().
37329 }
37330 }
37331 /**
37332 * Counts the total number of numbers (e.g., float32, int32) in the
37333 * weights.
37334 *
37335 * @returns An integer count.
37336 * @throws RuntimeError: If the layer is not built yet (in which case its
37337 * weights are not defined yet.)
37338 *
37339 * @doc {heading: 'Models', 'subheading': 'Classes'}
37340 */
37341 countParams() {
37342 if (!this.built) {
37343 throw new RuntimeError(`You tried to call countParams() on ${this.name}, ` +
37344 `but the layer is not built yet. Build it first by calling ` +
37345 `build(batchInputShape).`);
37346 }
37347 return countParamsInWeights(this.weights);
37348 }
        /**
         * Creates the layer weights.
         *
         * Must be implemented on all layers that have weights.
         *
         * Called when apply() is called to construct the weights.
         *
         * @param inputShape A `Shape` or array of `Shape` (unused).
         *
         * @doc {heading: 'Models', 'subheading': 'Classes'}
         */
        build(inputShape) {
            // Base implementation only marks the layer as built; subclasses that
            // own weights override this to create them.
            this.built = true;
        }
37363 /**
37364 * Returns the current values of the weights of the layer.
37365 *
37366 * @param trainableOnly Whether to get the values of only trainable weights.
37367 * @returns Weight values as an `Array` of `tf.Tensor`s.
37368 *
37369 * @doc {heading: 'Models', 'subheading': 'Classes'}
37370 */
37371 getWeights(trainableOnly = false) {
37372 return batchGetValue(trainableOnly ? this.trainableWeights : this.weights);
37373 }
        /**
         * Sets the weights of the layer, from Tensors.
         *
         * @param weights a list of Tensors. The number of arrays and their shape
         *   must match number of the dimensions of the weights of the layer (i.e.
         *   it should match the output of `getWeights`).
         *
         * @exception ValueError If the provided weights list does not match the
         *   layer's specifications.
         *
         * @doc {heading: 'Models', 'subheading': 'Classes'}
         */
        setWeights(weights) {
            // tidy() disposes intermediate tensors created while validating.
            tidy(() => {
                const params = this.weights;
                if (params.length !== weights.length) {
                    // TODO(cais): Restore the following and use `providedWeights`, instead
                    // of `weights` in the error message, once the deeplearn.js bug is
                    // fixed: https://github.com/PAIR-code/deeplearnjs/issues/498 const
                    // providedWeights = JSON.stringify(weights).slice(0, 50);
                    throw new ValueError(`You called setWeights(weights) on layer "${this.name}" ` +
                        `with a weight list of length ${weights.length}, ` +
                        `but the layer was expecting ${params.length} weights. ` +
                        `Provided weights: ${weights}...`);
                }
                if (params.length === 0) {
                    return;
                }
                const weightValueTuples = [];
                // Read current values so shapes can be validated before any write.
                const paramValues = batchGetValue(params);
                for (let i = 0; i < paramValues.length; ++i) {
                    const pv = paramValues[i];
                    const p = params[i];
                    const w = weights[i];
                    if (!arraysEqual(pv.shape, w.shape)) {
                        throw new ValueError(`Layer weight shape ${pv.shape} ` +
                            `not compatible with provided weight shape ${w.shape}`);
                    }
                    weightValueTuples.push([p, w]);
                }
                // All shapes validated; write the new values in one batch.
                batchSetValue(weightValueTuples);
            });
        }
        /**
         * Adds a weight variable to the layer.
         *
         * @param name Name of the new weight variable.
         * @param shape The shape of the weight.
         * @param dtype The dtype of the weight; defaults to 'float32'.
         * @param initializer An initializer instance.
         * @param regularizer A regularizer instance.
         * @param trainable Whether the weight should be trained via backprop or not
         *   (assuming that the layer itself is also trainable). Defaults to true.
         * @param constraint An optional constraint.
         * @param getInitializerFunc Optional zero-arg function returning an
         *   initializer; used in place of `initializer` when fast weight
         *   initialization is enabled (see setFastWeightInitDuringBuild()).
         * @return The created weight variable.
         *
         * @doc {heading: 'Models', 'subheading': 'Classes'}
         */
        addWeight(name, shape, dtype, initializer, regularizer, trainable, constraint, getInitializerFunc) {
            // Reject duplicate weight names.
            if (this._addedWeightNames.indexOf(name) !== -1) {
                throw new ValueError(`Duplicate weight name ${name} for layer ${this.name}`);
            }
            this._addedWeightNames.push(name);
            if (dtype == null) {
                dtype = 'float32';
            }
            if (this.fastWeightInitDuringBuild) {
                // Values will be overwritten immediately (e.g. during model
                // loading), so skip the potentially expensive initializer.
                initializer = getInitializerFunc != null ? getInitializerFunc() :
                    getInitializer('zeros');
            }
            const initValue = initializer.apply(shape, dtype);
            const weight = new LayerVariable(initValue, dtype, name, trainable, constraint);
            // The variable holds its own copy; free the initial-value tensor.
            initValue.dispose();
            // Request backend not to dispose the weights of the model on scope() exit.
            if (regularizer != null) {
                this.addLoss(() => regularizer.apply(weight.read()));
            }
            if (trainable == null) {
                trainable = true;
            }
            if (trainable) {
                this._trainableWeights.push(weight);
            }
            else {
                this._nonTrainableWeights.push(weight);
            }
            return weight;
        }
        /**
         * Set the fast-weight-initialization flag.
         *
         * In cases where the initialized weight values will be immediately
         * overwritten by loaded weight values during model loading, setting
         * the flag to `true` saves unnecessary calls to potentially expensive
         * initializers and speeds up the loading process.
         *
         * @param value Target value of the flag.
         */
        setFastWeightInitDuringBuild(value) {
            // Consulted by addWeight() when creating weight variables.
            this.fastWeightInitDuringBuild = value;
        }
37476 /**
37477 * Add losses to the layer.
37478 *
37479 * The loss may potentionally be conditional on some inputs tensors,
37480 * for instance activity losses are conditional on the layer's inputs.
37481 *
37482 * @doc {heading: 'Models', 'subheading': 'Classes'}
37483 */
37484 addLoss(losses) {
37485 if (losses == null || Array.isArray(losses) && losses.length === 0) {
37486 return;
37487 }
37488 // Update this.losses
37489 losses = toList(losses);
37490 if (this._losses !== undefined && this._losses !== null) {
37491 this.losses.push(...losses);
37492 }
37493 }
        /**
         * Computes the output shape of the layer.
         *
         * Assumes that the layer will be built to match that input shape provided.
         *
         * @param inputShape A shape (tuple of integers) or a list of shape tuples
         *   (one per output tensor of the layer). Shape tuples can include null for
         *   free dimensions, instead of an integer.
         * @returns The output shape(s). The base implementation assumes a
         *   shape-preserving layer and returns `inputShape` unchanged.
         *
         * @doc {heading: 'Models', 'subheading': 'Classes'}
         */
        computeOutputShape(inputShape) {
            return inputShape;
        }
37508 /**
37509 * Computes an output mask tensor.
37510 *
37511 * @param inputs Tensor or list of tensors.
37512 * @param mask Tensor or list of tensors.
37513 *
37514 * @return null or a tensor (or list of tensors, one per output tensor of the
37515 * layer).
37516 */
37517 computeMask(inputs, mask) {
37518 if (!this.supportsMasking) {
37519 if (mask != null) {
37520 if (Array.isArray(mask)) {
37521 mask.forEach(maskElement => {
37522 if (maskElement != null) {
37523 throw new TypeError(`Layer ${this.name} does not support masking, ` +
37524 'but was passed an inputMask.');
37525 }
37526 });
37527 }
37528 else {
37529 throw new TypeError(`Layer ${this.name} does not support masking, ` +
37530 'but was passed an inputMask.');
37531 }
37532 }
37533 // masking not explicitly supported: return null as mask
37534 return null;
37535 }
37536 // if masking is explictly supported, by default
37537 // carry over the input mask
37538 return mask;
37539 }
37540 /**
37541 * Internal method to create an inbound node for the layer.
37542 *
37543 * @param inputTensors List of input tensors.
37544 * @param outputTensors List of output tensors.
37545 * @param inputMasks List of input masks (a mask can be a tensor, or null).
37546 * @param outputMasks List of output masks (a mask can be a tensor, or null).
37547 * @param inputShapes List of input shape tuples.
37548 * @param outputShapes List of output shape tuples.
37549 * @param kwargs Dictionary of keyword arguments that were passed to the
37550 * `call` method of the layer at the call that created the node.
37551 */
        addInboundNode(inputTensors, outputTensors, inputMasks, outputMasks, inputShapes, outputShapes, kwargs = null) {
            // Normalize all arguments to list/array form.
            const inputTensorList = toList(inputTensors);
            outputTensors = toList(outputTensors);
            inputMasks = toList(inputMasks);
            outputMasks = toList(outputMasks);
            inputShapes = normalizeShapeList(inputShapes);
            outputShapes = normalizeShapeList(outputShapes);
            // Collect input tensor(s) coordinates: the layer, node index, and
            // tensor index that produced each input tensor.
            const inboundLayers = [];
            const nodeIndices = [];
            const tensorIndices = [];
            for (const x of inputTensorList) {
                /*
                 * TODO(michaelterry): Keras adds this value to tensors; it's not
                 * clear whether we'll use this or not.
                 */
                inboundLayers.push(x.sourceLayer);
                nodeIndices.push(x.nodeIndex);
                tensorIndices.push(x.tensorIndex);
            }
            // Create node, add it to inbound nodes.
            // (This call has side effects: the Node constructor is expected to
            // register itself on this layer — the code below reads
            // `this.inboundNodes.length - 1` as the new node's index. TODO
            // confirm against the Node constructor, which is outside this view.)
            // tslint:disable-next-line:no-unused-expression
            new Node({
                outboundLayer: this,
                inboundLayers,
                nodeIndices,
                tensorIndices,
                inputTensors: inputTensorList,
                outputTensors,
                inputMasks,
                outputMasks,
                inputShapes,
                outputShapes
            }, kwargs);
            // Update tensor history: record on each output tensor where it came
            // from, so graph traversals (e.g. getSourceInputs) can walk back.
            for (let i = 0; i < outputTensors.length; i++) {
                // TODO(michaelterry): _uses_learning_phase not tracked.
                outputTensors[i].sourceLayer = this;
                outputTensors[i].nodeIndex = this.inboundNodes.length - 1;
                outputTensors[i].tensorIndex = i;
            }
        }
37595 /**
37596 * Returns the config of the layer.
37597 *
37598 * A layer config is a TS dictionary (serializable)
37599 * containing the configuration of a layer.
37600 * The same layer can be reinstantiated later
37601 * (without its trained weights) from this configuration.
37602 *
37603 * The config of a layer does not include connectivity
37604 * information, nor the layer class name. These are handled
37605 * by 'Container' (one layer of abstraction above).
37606 *
37607 * Porting Note: The TS dictionary follows TS naming standrds for
37608 * keys, and uses tfjs-layers type-safe Enums. Serialization methods
37609 * should use a helper function to convert to the pythonic storage
37610 * standard. (see serialization_utils.convertTsToPythonic)
37611 *
37612 * @returns TS dictionary of configuration.
37613 *
37614 * @doc {heading: 'Models', 'subheading': 'Classes'}
37615 */
37616 getConfig() {
37617 const config = { name: this.name, trainable: this.trainable };
37618 if (this.batchInputShape != null) {
37619 config['batchInputShape'] = this.batchInputShape;
37620 }
37621 if (this.dtype != null) {
37622 config['dtype'] = this.dtype;
37623 }
37624 return config;
37625 }
37626 /**
37627 * Dispose the weight variables that this Layer instance holds.
37628 *
37629 * @returns {number} Number of disposed variables.
37630 */
37631 disposeWeights() {
37632 this.weights.forEach(weight => weight.dispose());
37633 return this.weights.length;
37634 }
37635 assertNotDisposed() {
37636 if (this._refCount === 0) {
37637 throw new Error(`Layer '${this.name}' is already disposed.`);
37638 }
37639 }
37640 /**
37641 * Attempt to dispose layer's weights.
37642 *
37643 * This method decrease the reference count of the Layer object by 1.
37644 *
37645 * A Layer is reference-counted. Its reference count is incremented by 1
37646 * the first item its `apply()` method is called and when it becomes a part
37647 * of a new `Node` (through calling the `apply()`) method on a
37648 * `tf.SymbolicTensor`).
37649 *
37650 * If the reference count of a Layer becomes 0, all the weights will be
37651 * disposed and the underlying memory (e.g., the textures allocated in WebGL)
37652 * will be freed.
37653 *
37654 * Note: If the reference count is greater than 0 after the decrement, the
37655 * weights of the Layer will *not* be disposed.
37656 *
37657 * After a Layer is disposed, it cannot be used in calls such as `apply()`,
37658 * `getWeights()` or `setWeights()` anymore.
37659 *
37660 * @returns A DisposeResult Object with the following fields:
37661 * - refCountAfterDispose: The reference count of the Container after this
37662 * `dispose()` call.
37663 * - numDisposedVariables: Number of `tf.Variable`s (i.e., weights) disposed
37664 * during this `dispose()` call.
37665 * @throws {Error} If the layer is not built yet, or if the layer has already
37666 * been disposed.
37667 *
37668 * @doc {heading: 'Models', 'subheading': 'Classes'}
37669 */
37670 dispose() {
37671 if (!this.built) {
37672 throw new Error(`Cannot dispose Layer ${this.name} because it has not been ` +
37673 `built yet.`);
37674 }
37675 if (this._refCount === null) {
37676 throw new Error(`Cannot dispose Layer ${this.name} because it has not been used ` +
37677 `yet.`);
37678 }
37679 this.assertNotDisposed();
37680 let numDisposedVariables = 0;
37681 if (--this._refCount === 0) {
37682 numDisposedVariables = this.disposeWeights();
37683 }
37684 return { refCountAfterDispose: this._refCount, numDisposedVariables };
37685 }
37686 }
37687 /**
37688 * Collects the input shape(s) of a list of `tf.Tensor`s or
37689 * `tf.SymbolicTensor`s.
37690 *
37691 * TODO(michaelterry): Update PyKeras docs (backport).
37692 *
37693 * @param inputTensors List of input tensors (or single input tensor).
37694 *
37695 * @return List of shape tuples (or single tuple), one tuple per input.
37696 */
37697 function collectInputShape(inputTensors) {
37698 inputTensors =
37699 toList(inputTensors);
37700 const shapes = [];
37701 for (const x of inputTensors) {
37702 shapes.push(x.shape);
37703 }
37704 return singletonOrArray(shapes);
37705 }
37706 /**
37707 * Guesses output dtype based on inputs.
37708 *
37709 * At present, just returns 'float32' for any input.
37710 *
37711 * @param inputTensors List of input tensors (or single input tensor).
37712 *
37713 * @return The guessed DType. At present, always returns 'float32'.
37714 */
37715 function guessOutputDType(inputTensors) {
37716 return 'float32';
37717 }
37718 /**
37719 * Returns the list of input tensors necessary to compute `tensor`.
37720 *
37721 * Output will always be a list of tensors (potentially with 1 element).
37722 *
37723 * @param tensor The tensor to start from.
37724 * @param layer Origin layer of the tensor.
37725 * @param nodeIndex Origin node index of the tensor.
37726 *
37727 * @return Array of input tensors.
37728 */
37729 function getSourceInputs(tensor, layer, nodeIndex) {
37730 if (layer == null || (nodeIndex != null && nodeIndex > 0)) {
37731 layer = tensor.sourceLayer;
37732 nodeIndex = tensor.nodeIndex;
37733 }
37734 if (layer.inboundNodes.length === 0) {
37735 return [tensor];
37736 }
37737 else {
37738 const node = layer.inboundNodes[nodeIndex];
37739 if (node.inboundLayers.length === 0) {
37740 return node.inputTensors;
37741 }
37742 else {
37743 const sourceTensors = [];
37744 for (let i = 0; i < node.inboundLayers.length; i++) {
37745 const x = node.inputTensors[i];
37746 const layer = node.inboundLayers[i];
37747 const nodeIndex = node.nodeIndices[i];
37748 const previousSources = getSourceInputs(x, layer, nodeIndex);
37749 // Avoid input redundancy.
37750 for (const x of previousSources) {
37751 if (sourceTensors.indexOf(x) === -1) {
37752 sourceTensors.push(x);
37753 }
37754 }
37755 }
37756 return sourceTensors;
37757 }
37758 }
37759 }
37760
37761 /**
37762 * @license
37763 * Copyright 2018 Google LLC
37764 *
37765 * Use of this source code is governed by an MIT-style
37766 * license that can be found in the LICENSE file or at
37767 * https://opensource.org/licenses/MIT.
37768 * =============================================================================
37769 */
    /**
     * Layer that represents an entry point into a model's computation graph.
     *
     * Holds no weights; its only job is to own the `SymbolicTensor` that
     * downstream layers consume, together with its shape and dtype.
     */
    class InputLayer extends Layer {
        constructor(args) {
            super({
                dtype: args.dtype,
                name: args.name != null ? args.name : getUid('input').toString()
            });
            // Normalize config.batchSize and config.sparse.
            // NOTE(review): this mutates the caller-provided `args` object in
            // place; callers reusing the same args object will observe the
            // normalized values.
            if (args.batchSize == null) {
                args.batchSize = null;
            }
            if (args.sparse == null) {
                args.sparse = false;
            }
            // An InputLayer is always "built" and never trainable.
            this.trainable = false;
            this.built = true;
            this.sparse = args.sparse;
            if (args.inputShape != null && args.batchInputShape != null) {
                throw new ValueError('Only provide the inputShape OR ' +
                    'batchInputShape argument to inputLayer, not both at the same time.');
            }
            let batchInputShape = args.batchInputShape;
            if (batchInputShape == null) {
                if (args.inputShape == null) {
                    throw new ValueError('An InputLayer should be passed either a ' +
                        '`batchInputShape` or an `inputShape`.');
                }
                else {
                    // Prepend the (possibly null, i.e. variable) batch dimension.
                    batchInputShape = [args.batchSize].concat(args.inputShape);
                }
            }
            else {
                // TODO(michaelterry): Backport to PyKeras
                if (args.batchSize != null) {
                    throw new ValueError('Cannot specify batchSize if batchInputShape is ' +
                        'specified when creating an InputLayer.');
                }
            }
            const dtype = args.dtype || 'float32';
            this.batchInputShape = batchInputShape;
            this.dtype = dtype;
            // TODO(michaelterry): Backport this to PyKeras?
            this.inputSpec = [{ shape: batchInputShape }];
            // The symbolic tensor this input layer produces; coordinates point
            // back to this layer's (single) node.
            const inputTensor = new SymbolicTensor(this.dtype, this.batchInputShape, this, [], {}, this.name);
            inputTensor.nodeIndex = 0;
            inputTensor.tensorIndex = 0;
            // Create an input node to add to this.outboundNode.
            // (This call has side effects.)
            // tslint:disable-next-line:no-unused-expression
            new Node({
                outboundLayer: this,
                inboundLayers: [],
                nodeIndices: [],
                tensorIndices: [],
                inputTensors: [inputTensor],
                outputTensors: [inputTensor],
                inputMasks: [null],
                outputMasks: [null],
                inputShapes: [batchInputShape],
                outputShapes: [batchInputShape]
            });
        }
        /** Always throws: an InputLayer cannot be applied to other tensors. */
        apply(inputs, kwargs) {
            throw new ValueError('Cannot pass any input to an ' +
                `InputLayer's apply() method. InputLayer name: ${this.name}`);
        }
        dispose() {
            // dispose() for InputLayer is overridden as no-op: it owns no weights.
            return { refCountAfterDispose: this._refCount, numDisposedVariables: 0 };
        }
        /** Serializable configuration: shape, dtype, sparseness and name. */
        getConfig() {
            return {
                batchInputShape: this.batchInputShape,
                dtype: this.dtype,
                sparse: this.sparse,
                name: this.name
            };
        }
    }
    /** @nocollapse */
    InputLayer.className = 'InputLayer';
    registerClass(InputLayer);
37851 function Input(config) {
37852 if (config.batchShape == null && config.shape == null) {
37853 throw new Error('Please provide to Input either a `shape`' +
37854 ' or a `batchShape` argument. Note that ' +
37855 '`shape` does not include the batch ' +
37856 'dimension.');
37857 }
37858 if (config.batchShape != null && config.shape != null) {
37859 // TODO(michaelterry): Backport to PyKeras.
37860 throw new ValueError('Please provide either a `shape` or `batchShape` ' +
37861 'argument to Input, but not both.');
37862 }
37863 let batchShape = config.batchShape;
37864 if (config.shape != null && batchShape == null) {
37865 batchShape = [null].concat(config.shape);
37866 }
37867 let dtype = config.dtype;
37868 if (dtype == null) {
37869 dtype = 'float32';
37870 }
37871 const inputLayer = new InputLayer({
37872 batchInputShape: batchShape,
37873 name: config.name,
37874 dtype,
37875 sparse: config.sparse
37876 });
37877 const outputs = inputLayer.inboundNodes[0].outputTensors;
37878 return outputs[0];
37879 }
37880
37881 /**
37882 * @license
37883 * Copyright 2018 Google LLC
37884 *
37885 * Use of this source code is governed by an MIT-style
37886 * license that can be found in the LICENSE file or at
37887 * https://opensource.org/licenses/MIT.
37888 * =============================================================================
37889 */
37890 /**
37891 * Helper function to check the dtype and shape compatibility of a feed value.
37892 */
37893 function assertFeedCompatibility(key, val) {
37894 // Check dtype compatibility.
37895 if (key.dtype == null || key.dtype === val.dtype) {
37896 // a. If types match, return val tensor as is.
37897 return val;
37898 }
37899 try {
37900 // b. Attempt to convert to expected type.
37901 return cast(val, key.dtype);
37902 }
37903 catch (err) {
37904 // c. If conversion fails, return helpful error.
37905 throw new ValueError(`The dtype of the feed (${val.dtype}) can not be cast to the dtype ` +
37906 `of the key '${key.name}' (${key.dtype}).`);
37907 }
37908 }
37909 /**
37910 * FeedDict: A mapping from unique SymbolicTensors to feed values for them.
37911 * A feed value is a concrete value represented as an `Tensor`.
37912 */
37913 class FeedDict {
37914 /**
37915 * Constructor, optionally does copy-construction.
37916 * @param feeds An Array of `Feed`s, or another `FeedDict`, in which case
37917 * copy-construction will be performed.
37918 */
37919 constructor(feeds) {
37920 this.id2Value = {};
37921 this.id2Mask = {};
37922 this.name2Id = {};
37923 if (feeds instanceof FeedDict) {
37924 for (const id in feeds.id2Value) {
37925 this.id2Value[id] = feeds.id2Value[id];
37926 if (id in feeds.id2Mask) {
37927 this.id2Mask[id] = feeds.id2Mask[id];
37928 }
37929 }
37930 }
37931 else {
37932 if (feeds == null) {
37933 return;
37934 }
37935 for (const feed of feeds) {
37936 this.add(feed.key, feed.value);
37937 }
37938 }
37939 }
37940 /**
37941 * Add a key-value pair to the FeedDict.
37942 *
37943 * @param key The key of the feed.
37944 * @param value The value of the tensor feed.
37945 * @param mask The value of the mask feed (optional).
37946 * @returns This `FeedDict`.
37947 * @throws ValueError: If the key `SymbolicTensor` already exists in the
37948 * `FeedDict`.
37949 */
37950 add(key, value, mask) {
37951 if (this.id2Value[key.id] == null) {
37952 this.id2Value[key.id] = assertFeedCompatibility(key, value);
37953 this.name2Id[key.name] = key.id;
37954 if (mask != null) {
37955 this.id2Mask[key.id] = mask;
37956 }
37957 }
37958 else {
37959 throw new ValueError(`Duplicate key: name=${key.name}, id=${key.id}`);
37960 }
37961 return this;
37962 }
37963 /**
37964 * Add a Feed to the FeedDict.
37965 * @param feed The new `Feed` to add.
37966 * @returns This `FeedDict`.
37967 */
37968 addFeed(feed) {
37969 this.add(feed.key, feed.value);
37970 }
37971 /**
37972 * Probe whether a key already exists in the FeedDict.
37973 * @param key
37974 */
37975 hasKey(key) {
37976 return this.id2Value[key.id] != null;
37977 }
37978 /**
37979 * Get all the SymbolicTensor available in this FeedDict.
37980 */
37981 names() {
37982 return Object.keys(this.name2Id);
37983 }
37984 /**
37985 * Get the feed value for given key.
37986 * @param key The SymbolicTensor, or its name (as a string), of which the
37987 * value is sought.
37988 * @returns If `key` exists, the corresponding feed value.
37989 * @throws ValueError: If `key` does not exist in this `FeedDict`.
37990 */
37991 getValue(key) {
37992 if (key instanceof SymbolicTensor) {
37993 if (this.id2Value[key.id] == null) {
37994 throw new ValueError(`Nonexistent key: ${key.name}`);
37995 }
37996 else {
37997 return this.id2Value[key.id];
37998 }
37999 }
38000 else {
38001 const id = this.name2Id[key];
38002 if (id == null) {
38003 throw new ValueError(`Feed dict has no SymbolicTensor name: ${key}`);
38004 }
38005 return this.id2Value[id];
38006 }
38007 }
38008 /**
38009 * Get the feed mask for given key.
38010 * @param key The SymbolicTensor, or its name (as a string), of which the
38011 * value is sought.
38012 * @returns If `key` exists, the corresponding feed mask.
38013 * @throws ValueError: If `key` does not exist in this `FeedDict`.
38014 */
38015 getMask(key) {
38016 if (key instanceof SymbolicTensor) {
38017 if (this.id2Value[key.id] == null) {
38018 throw new ValueError(`Nonexistent key: ${key.name}`);
38019 }
38020 else {
38021 return this.id2Mask[key.id];
38022 }
38023 }
38024 else {
38025 const id = this.name2Id[key];
38026 if (id == null) {
38027 throw new ValueError(`Feed dict has no SymbolicTensor name: ${key}`);
38028 }
38029 return this.id2Mask[id];
38030 }
38031 }
38032 /** Dispose all mask Tensors held by this object. */
38033 disposeMasks() {
38034 if (this.id2Mask != null) {
38035 dispose(this.id2Mask);
38036 }
38037 }
38038 }
38039 // Cache for topologically sorted SymbolicTensors for given execution
38040 // targets (i.e., fetches).
38041 const cachedSorted = new LruCache();
38042 // Cache for recipient count maps for given execution targets (i.e., fetches).
38043 const cachedRecipientCounts = new LruCache();
38044 function updateCacheMaxEntries(maxEntries) {
38045 if (cachedSorted != null) {
38046 cachedSorted.setMaxEntries(maxEntries);
38047 }
38048 if (cachedRecipientCounts != null) {
38049 cachedRecipientCounts.setMaxEntries(maxEntries);
38050 }
38051 }
38052 /**
38053 * Execute a SymbolicTensor by using concrete feed values.
38054 *
38055 * A `SymbolicTensor` object is a node in a computation graph of TF.js
38056 * Layers. The object is backed by a source layer and input
38057 * `SymbolicTensor`s to the source layer. This method evaluates
38058 * the `call()` method of the source layer, using concrete values of the
38059 * inputs obtained from either
38060 * * `feedDict`, if the input key exists in `feedDict`, or else,
38061 * * a recursive call to `execute()` itself.
38062 *
38063 * @param x: The `SymbolicTensor` to execute.
38064 * @param feedDict: The feed values, as base condition of the recursion.
38065 * execution.
38066 * @param kwargs: Optional keyword arguments.
38067 * @param probe: A probe object (of interface `ExecutionProbe`) used for
38068 * testing memory footprint of `execute` calls.
38069 * @returns Result of the execution.
38070 * @throws ValueError: If any `SymbolicTensor`s from `InputLayer`s
38071 * encountered during the execution lacks a feed value in `feedDict`.
38072 */
38073 function execute(fetches, feedDict, kwargs, probe) {
38074 const training = kwargs == null ? false : kwargs['training'];
38075 const arrayFetches = Array.isArray(fetches);
38076 const fetchArray = arrayFetches ? fetches : [fetches];
38077 const outputNames = fetchArray.map(t => t.name);
38078 const finalOutputs = [];
38079 const feedNames = feedDict.names();
38080 for (const outputName of outputNames) {
38081 if (feedNames.indexOf(outputName) !== -1) {
38082 finalOutputs.push(feedDict.getValue(outputName));
38083 }
38084 else {
38085 finalOutputs.push(null);
38086 }
38087 }
38088 if (probe != null) {
38089 // For optional probing of memory footprint during execution.
38090 probe.maxNumTensors = -Infinity;
38091 probe.minNumTensors = Infinity;
38092 }
38093 // Check cache.
38094 const fetchAndFeedKey = outputNames.join(',') + '|' + feedDict.names().sort().join(',');
38095 let sorted = cachedSorted.get(fetchAndFeedKey);
38096 let recipientCounts;
38097 if (sorted == null) {
38098 // Cache doesn't contain the desired combination of fetches. Compute
38099 // topological sort for the combination for the first time.
38100 const out = getTopologicalSortAndRecipientCounts(fetchArray, feedDict);
38101 sorted = out.sorted;
38102 recipientCounts = out.recipientCounts;
38103 // Store results in cache for future use.
38104 cachedSorted.put(fetchAndFeedKey, sorted);
38105 cachedRecipientCounts.put(fetchAndFeedKey, recipientCounts);
38106 }
38107 recipientCounts = {};
38108 if (!training) {
38109 Object.assign(recipientCounts, cachedRecipientCounts.get(fetchAndFeedKey));
38110 }
38111 const internalFeedDict = new FeedDict(feedDict);
38112 // Start iterative execution on the topologically-sorted SymbolicTensors.
38113 for (let i = 0; i < sorted.length; ++i) {
38114 if (probe != null) {
38115 // For optional probing of memory usage during execution.
38116 const numTensors = memory().numTensors;
38117 if (numTensors > probe.maxNumTensors) {
38118 probe.maxNumTensors = numTensors;
38119 }
38120 if (numTensors < probe.minNumTensors) {
38121 probe.minNumTensors = numTensors;
38122 }
38123 }
38124 const symbolic = sorted[i];
38125 const srcLayer = symbolic.sourceLayer;
38126 if (srcLayer instanceof InputLayer) {
38127 continue;
38128 }
38129 const inputValues = [];
38130 const inputMasks = [];
38131 const tensorsToDispose = [];
38132 let maskExists = false;
38133 for (const input of symbolic.inputs) {
38134 const value = internalFeedDict.getValue(input);
38135 const mask = internalFeedDict.getMask(input);
38136 inputValues.push(value);
38137 inputMasks.push(mask);
38138 if (mask != null) {
38139 maskExists = true;
38140 }
38141 if (!training) {
38142 recipientCounts[input.name]--;
38143 if (recipientCounts[input.name] === 0 && !feedDict.hasKey(input) &&
38144 outputNames.indexOf(input.name) === -1 && !value.isDisposed &&
38145 input.sourceLayer.stateful !== true) {
38146 tensorsToDispose.push(value);
38147 }
38148 }
38149 }
38150 if (maskExists) {
38151 kwargs = kwargs || {};
38152 kwargs['mask'] = inputMasks[0];
38153 }
38154 const outputTensors = toList(srcLayer.apply(inputValues, kwargs));
38155 let outputMask = null;
38156 if (srcLayer.supportsMasking) {
38157 outputMask = srcLayer.computeMask(inputValues, inputMasks);
38158 }
38159 const layerOutputs = getNodeOutputs(symbolic);
38160 const outputSymbolicTensors = Array.isArray(layerOutputs) ? layerOutputs : [layerOutputs];
38161 for (let i = 0; i < outputSymbolicTensors.length; ++i) {
38162 if (!internalFeedDict.hasKey(outputSymbolicTensors[i])) {
38163 internalFeedDict.add(outputSymbolicTensors[i], outputTensors[i], Array.isArray(outputMask) ? outputMask[0] : outputMask);
38164 }
38165 const index = outputNames.indexOf(outputSymbolicTensors[i].name);
38166 if (index !== -1) {
38167 finalOutputs[index] = outputTensors[i];
38168 }
38169 }
38170 if (!training) {
38171 // Clean up Tensors that are no longer needed.
38172 dispose(tensorsToDispose);
38173 }
38174 }
38175 // NOTE(cais): Unlike intermediate tensors, we don't discard mask
38176 // tensors as we go, because these tensors are sometimes passed over a
38177 // series of mutliple layers, i.e., not obeying the immediate input
38178 // relations in the graph. If this becomes a memory-usage concern,
38179 // we can improve this in the future.
38180 internalFeedDict.disposeMasks();
38181 return arrayFetches ? finalOutputs : finalOutputs[0];
38182 }
38183 /**
38184 * Sort the `SymbolicTensor`s topologically, for an array of fetches.
38185 *
38186 * This function calls getTopologicalSortAndRecipientCountsForOneFetch and
38187 * merges their results.
38188 *
38189 * @param fetch The array of fetches requested. Must be a non-empty array.
38190 * @param feedDict The dictionary of fed values.
38191 * @returns sorted: Topologically-sorted array of SymbolicTensors.
38192 * recipientCounts: Recipient counts for all SymbolicTensors in `sorted`.
38193 */
38194 function getTopologicalSortAndRecipientCounts(fetches, feedDict) {
38195 assert(fetches != null && fetches.length > 0, () => `Expected at least one fetch, got none`);
38196 let finalSorted = [];
38197 let finalRecipientMap = {};
38198 if (fetches.length === 1) {
38199 // Special-casing 1 fetch for efficiency.
38200 const out = getTopologicalSortAndRecipientCountsForOneFetch(fetches[0], feedDict);
38201 finalSorted = out.sorted;
38202 finalRecipientMap = out.recipientMap;
38203 }
38204 else {
38205 const visited = new Set();
38206 for (const fetch of fetches) {
38207 const { sorted, recipientMap } = getTopologicalSortAndRecipientCountsForOneFetch(fetch, feedDict);
38208 // Merge sorted SymbolicTensor Arrays.
38209 for (const symbolicTensor of sorted) {
38210 if (!visited.has(symbolicTensor.name)) {
38211 finalSorted.push(symbolicTensor);
38212 visited.add(symbolicTensor.name);
38213 }
38214 }
38215 // Merge recipient maps.
38216 for (const name in recipientMap) {
38217 if (finalRecipientMap[name] == null) {
38218 finalRecipientMap[name] = new Set();
38219 }
38220 recipientMap[name].forEach(recipient => finalRecipientMap[name].add(recipient));
38221 }
38222 }
38223 }
38224 return {
38225 sorted: finalSorted,
38226 recipientCounts: recipientMap2Counts(finalRecipientMap)
38227 };
38228 }
38229 function recipientMap2Counts(recipientMap) {
38230 const recipientCounts = {};
38231 for (const name in recipientMap) {
38232 recipientCounts[name] = recipientMap[name].size;
38233 }
38234 return recipientCounts;
38235 }
38236 /**
38237 * Sort the `SymbolicTensor`s topologically, for a single fetch.
38238 *
38239 * This helper function processes the upstream SymbolicTensors of a single
38240 * fetch.
38241 *
38242 * @param fetch The single fetch requested.
38243 * @param feedDict The dictionary of fed values.
38244 * @returns sorted: Topologically-sorted array of SymbolicTensors.
38245 * recipientMap: Recipient names for all SymbolicTensors in `sorted`.
38246 */
38247 function getTopologicalSortAndRecipientCountsForOneFetch(fetch, feedDict) {
38248 const visited = new Set();
38249 const sorted = [];
38250 const recipientMap = {};
38251 // Put keys of the feedDict into visited first, so they don't have to be
38252 // walked. This is needed in case where there are feeds for intermediate
38253 // SymbolicTensors of the graph.
38254 for (const key of feedDict.names()) {
38255 visited.add(key);
38256 }
38257 const stack = [];
38258 const marks = [];
38259 // Initial population of stack and marks.
38260 stack.push(fetch);
38261 while (stack.length > 0) {
38262 const top = stack[stack.length - 1];
38263 if (visited.has(top.name)) {
38264 stack.pop();
38265 continue;
38266 }
38267 const topIsMarked = marks[marks.length - 1] === stack.length - 1;
38268 if (top.inputs.length === 0 || topIsMarked) {
38269 // Input SymbolicTensor or all children have been visited.
38270 stack.pop();
38271 sorted.push(top);
38272 visited.add(top.name);
38273 if (topIsMarked) {
38274 marks.pop();
38275 }
38276 }
38277 else {
38278 // A non-input SymbolicTensor whose upstream SymbolicTensors haven't
38279 // been visited yet. Push them onto the stack.
38280 marks.push(stack.length - 1);
38281 for (const input of top.inputs) {
38282 // Increment the recipient count. Note that this needs to happen
38283 // regardless of whether the SymbolicTensor has been visited before.
38284 if (recipientMap[input.name] == null) {
38285 recipientMap[input.name] = new Set();
38286 }
38287 recipientMap[input.name].add(top.name);
38288 if (visited.has(input.name)) {
38289 continue; // Avoid repeated visits to the same SymbolicTensor.
38290 }
38291 stack.push(input);
38292 }
38293 }
38294 }
38295 return { sorted, recipientMap };
38296 }
38297 /**
38298 * Get the symbolic output tensors of the node to which a given fetch belongs.
38299 * @param fetch The fetched symbolic tensor.
38300 * @returns The Array of symbolic tensors output by the node to which `fetch`
38301 * belongs.
38302 */
38303 function getNodeOutputs(fetch) {
38304 let layerOutputs;
38305 if (fetch.sourceLayer.inboundNodes.length === 1) {
38306 layerOutputs = fetch.sourceLayer.output;
38307 }
38308 else {
38309 let nodeIndex = null;
38310 for (let i = 0; i < fetch.sourceLayer.inboundNodes.length; ++i) {
38311 for (const outputTensor of fetch.sourceLayer.inboundNodes[i]
38312 .outputTensors) {
38313 if (outputTensor.id === fetch.id) {
38314 nodeIndex = i;
38315 break;
38316 }
38317 }
38318 }
38319 layerOutputs = fetch.sourceLayer.getOutputAt(nodeIndex);
38320 }
38321 return layerOutputs;
38322 }
38323
38324 /**
38325 * @license
38326 * Copyright 2022 Google LLC. All Rights Reserved.
38327 * Licensed under the Apache License, Version 2.0 (the "License");
38328 * you may not use this file except in compliance with the License.
38329 * You may obtain a copy of the License at
38330 *
38331 * http://www.apache.org/licenses/LICENSE-2.0
38332 *
38333 * Unless required by applicable law or agreed to in writing, software
38334 * distributed under the License is distributed on an "AS IS" BASIS,
38335 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
38336 * See the License for the specific language governing permissions and
38337 * limitations under the License.
38338 * =============================================================================
38339 */
    // Environment handle used to register layers-specific flags.
    const ENV$1 = env();
    /** The max number of entries for the caches of layers' topological sort. */
    ENV$1.registerFlag('TOPOLOGICAL_SORT_CACHE_MAX_ENTRIES', () => 100, updateCacheMaxEntries);
38343
38344 /**
38345 * @license
38346 * Copyright 2018 Google LLC
38347 *
38348 * Use of this source code is governed by an MIT-style
38349 * license that can be found in the LICENSE file or at
38350 * https://opensource.org/licenses/MIT.
38351 * =============================================================================
38352 */
38353 /**
38354 * Helper function used by many of the Constraints to find the L2Norms.
38355 */
38356 function calcL2Norms(w, axis) {
38357 return tidy(() => sqrt(sum$1(mul(w, w), axis, true)));
38358 }
38359 /**
38360 * Base class for functions that impose constraints on weight values
38361 *
38362 * @doc {
38363 * heading: 'Constraints',
38364 * subheading: 'Classes',
38365 * namespace: 'constraints'
38366 * }
38367 */
38368 class Constraint extends Serializable {
38369 getConfig() {
38370 return {};
38371 }
38372 }
38373 class MaxNorm extends Constraint {
38374 constructor(args) {
38375 super();
38376 this.defaultMaxValue = 2;
38377 this.defaultAxis = 0;
38378 this.maxValue =
38379 args.maxValue != null ? args.maxValue : this.defaultMaxValue;
38380 this.axis = args.axis != null ? args.axis : this.defaultAxis;
38381 }
38382 apply(w) {
38383 return tidy(() => {
38384 const norms = calcL2Norms(w, this.axis);
38385 const desired = clipByValue(norms, 0, this.maxValue);
38386 return mul(w, div(desired, add$1(epsilon(), norms)));
38387 });
38388 }
38389 getConfig() {
38390 return { maxValue: this.maxValue, axis: this.axis };
38391 }
38392 }
38393 /** @nocollapse */
38394 MaxNorm.className = 'MaxNorm';
38395 registerClass(MaxNorm);
38396 class UnitNorm extends Constraint {
38397 constructor(args) {
38398 super();
38399 this.defaultAxis = 0;
38400 this.axis = args.axis != null ? args.axis : this.defaultAxis;
38401 }
38402 apply(w) {
38403 return tidy(() => div(w, add$1(epsilon(), calcL2Norms(w, this.axis))));
38404 }
38405 getConfig() {
38406 return { axis: this.axis };
38407 }
38408 }
38409 /** @nocollapse */
38410 UnitNorm.className = 'UnitNorm';
38411 registerClass(UnitNorm);
    /** NonNeg constraint: clamps all negative weight entries to zero. */
    class NonNeg extends Constraint {
        apply(w) {
            return relu(w);
        }
    }
    /** @nocollapse */
    NonNeg.className = 'NonNeg';
    registerClass(NonNeg);
38420 class MinMaxNorm extends Constraint {
38421 constructor(args) {
38422 super();
38423 this.defaultMinValue = 0.0;
38424 this.defaultMaxValue = 1.0;
38425 this.defaultRate = 1.0;
38426 this.defaultAxis = 0;
38427 this.minValue =
38428 args.minValue != null ? args.minValue : this.defaultMinValue;
38429 this.maxValue =
38430 args.maxValue != null ? args.maxValue : this.defaultMaxValue;
38431 this.rate = args.rate != null ? args.rate : this.defaultRate;
38432 this.axis = args.axis != null ? args.axis : this.defaultAxis;
38433 }
38434 apply(w) {
38435 return tidy(() => {
38436 const norms = calcL2Norms(w, this.axis);
38437 const desired = add$1(mul(this.rate, clipByValue(norms, this.minValue, this.maxValue)), mul(1.0 - this.rate, norms));
38438 return mul(w, div(desired, add$1(epsilon(), norms)));
38439 });
38440 }
38441 getConfig() {
38442 return {
38443 minValue: this.minValue,
38444 maxValue: this.maxValue,
38445 rate: this.rate,
38446 axis: this.axis
38447 };
38448 }
38449 }
38450 /** @nocollapse */
38451 MinMaxNorm.className = 'MinMaxNorm';
38452 registerClass(MinMaxNorm);
    // Maps the JavaScript-like identifier keys to the corresponding registry
    // symbols (registered class names).
    const CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
        'maxNorm': 'MaxNorm',
        'minMaxNorm': 'MinMaxNorm',
        'nonNeg': 'NonNeg',
        'unitNorm': 'UnitNorm'
    };
    /** Serializes a Constraint instance into its config-dictionary form. */
    function serializeConstraint(constraint) {
        return serializeKerasObject(constraint);
    }
    /**
     * Deserializes a constraint config into a Constraint instance, resolving
     * class names against the global serialization registry.
     */
    function deserializeConstraint(config, customObjects = {}) {
        return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'constraint');
    }
38467 function getConstraint(identifier) {
38468 if (identifier == null) {
38469 return null;
38470 }
38471 if (typeof identifier === 'string') {
38472 const className = identifier in CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP ?
38473 CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] :
38474 identifier;
38475 const config = { className, config: {} };
38476 return deserializeConstraint(config);
38477 }
38478 else if (identifier instanceof Constraint) {
38479 return identifier;
38480 }
38481 else {
38482 return deserializeConstraint(identifier);
38483 }
38484 }
38485
38486 /**
38487 * @license
38488 * Copyright 2018 Google LLC
38489 *
38490 * Use of this source code is governed by an MIT-style
38491 * license that can be found in the LICENSE file or at
38492 * https://opensource.org/licenses/MIT.
38493 * =============================================================================
38494 */
    /**
     * MaxNorm weight constraint.
     *
     * Constrains the weights incident to each hidden unit
     * to have a norm less than or equal to a desired value.
     *
     * References
     *       - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting
     * Srivastava, Hinton, et al.
     * 2014](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf)
     *
     * @param args Configuration for the `MaxNorm` constraint.
     * @returns A new `MaxNorm` constraint instance.
     * @doc {heading: 'Constraints', namespace: 'constraints'}
     */
    function maxNorm(args) {
        return new MaxNorm(args);
    }
    /**
     * Constrains the weights incident to each hidden unit to have unit norm.
     *
     * @param args Configuration for the `UnitNorm` constraint.
     * @returns A new `UnitNorm` constraint instance.
     * @doc {heading: 'Constraints', namespace: 'constraints'}
     */
    function unitNorm(args) {
        return new UnitNorm(args);
    }
    /**
     * Constrains the weight to be non-negative.
     *
     * @returns A new `NonNeg` constraint instance.
     * @doc {heading: 'Constraints', namespace: 'constraints'}
     */
    function nonNeg() {
        return new NonNeg();
    }
    /**
     * MinMaxNorm weight constraint.
     *
     * @param config Configuration for the `MinMaxNorm` constraint.
     * @returns A new `MinMaxNorm` constraint instance.
     * @doc {heading: 'Constraints', namespace: 'constraints'}
     */
    function minMaxNorm(config) {
        return new MinMaxNorm(config);
    }
38531
38532 var exports_constraints = /*#__PURE__*/Object.freeze({
38533 __proto__: null,
38534 maxNorm: maxNorm,
38535 unitNorm: unitNorm,
38536 nonNeg: nonNeg,
38537 minMaxNorm: minMaxNorm
38538 });
38539
38540 /**
38541 * @license
38542 * Copyright 2018 Google LLC
38543 *
38544 * Use of this source code is governed by an MIT-style
38545 * license that can be found in the LICENSE file or at
38546 * https://opensource.org/licenses/MIT.
38547 * =============================================================================
38548 */
    /**
     * Initializer that generates tensors initialized to 0.
     *
     * @returns A new `Zeros` initializer instance.
     * @doc {heading: 'Initializers', namespace: 'initializers'}
     */
    function zeros$1() {
        return new Zeros();
    }
    /**
     * Initializer that generates tensors initialized to 1.
     *
     * @returns A new `Ones` initializer instance.
     * @doc {heading: 'Initializers', namespace: 'initializers'}
     */
    function ones$2() {
        return new Ones();
    }
    /**
     * Initializer that generates values initialized to some constant.
     *
     * @param args Configuration for the `Constant` initializer (e.g. the value).
     * @doc {heading: 'Initializers', namespace: 'initializers'}
     */
    function constant(args) {
        return new Constant(args);
    }
    /**
     * Initializer that generates random values initialized to a uniform
     * distribution.
     *
     * Values will be distributed uniformly between the configured minval and
     * maxval.
     *
     * @param args Configuration for the `RandomUniform` initializer.
     * @doc {heading: 'Initializers', namespace: 'initializers'}
     */
    function randomUniform$1(args) {
        return new RandomUniform(args);
    }
    /**
     * Initializer that generates random values initialized to a normal
     * distribution.
     *
     * @param args Configuration for the `RandomNormal` initializer.
     * @doc {heading: 'Initializers', namespace: 'initializers'}
     */
    function randomNormal$2(args) {
        return new RandomNormal(args);
    }
    /**
     * Initializer that generates random values initialized to a truncated
     * normal distribution.
     *
     * These values are similar to values from a `RandomNormal` except that values
     * more than two standard deviations from the mean are discarded and re-drawn.
     * This is the recommended initializer for neural network weights and filters.
     *
     * @param args Configuration for the `TruncatedNormal` initializer.
     * @doc {heading: 'Initializers', namespace: 'initializers'}
     */
    function truncatedNormal$1(args) {
        return new TruncatedNormal(args);
    }
    /**
     * Initializer that generates the identity matrix.
     * Only use for square 2D matrices.
     *
     * @param args Configuration for the `Identity` initializer (e.g. gain).
     * @doc {heading: 'Initializers', namespace: 'initializers'}
     */
    function identity(args) {
        return new Identity$1(args);
    }
    /**
     * Initializer capable of adapting its scale to the shape of weights.
     * With distribution=NORMAL, samples are drawn from a truncated normal
     * distribution centered on zero, with `stddev = sqrt(scale / n)` where n is:
     *   - number of input units in the weight tensor, if mode = FAN_IN.
     *   - number of output units, if mode = FAN_OUT.
     *   - average of the numbers of input and output units, if mode = FAN_AVG.
     * With distribution=UNIFORM,
     * samples are drawn from a uniform distribution
     * within [-limit, limit], with `limit = sqrt(3 * scale / n)`.
     *
     * @param config Configuration for the `VarianceScaling` initializer.
     * @doc {heading: 'Initializers', namespace: 'initializers'}
     */
    function varianceScaling(config) {
        return new VarianceScaling(config);
    }
    /**
     * Glorot uniform initializer, also called Xavier uniform initializer.
     * It draws samples from a uniform distribution within [-limit, limit]
     * where `limit` is `sqrt(6 / (fan_in + fan_out))`
     * where `fan_in` is the number of input units in the weight tensor
     * and `fan_out` is the number of output units in the weight tensor.
     *
     * Reference:
     *   Glorot & Bengio, AISTATS 2010
     *       http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf.
     *
     * @param args Configuration for the `GlorotUniform` initializer.
     * @doc {heading: 'Initializers', namespace: 'initializers'}
     */
    function glorotUniform(args) {
        return new GlorotUniform(args);
    }
    /**
     * Glorot normal initializer, also called Xavier normal initializer.
     * It draws samples from a truncated normal distribution centered on 0
     * with `stddev = sqrt(2 / (fan_in + fan_out))`
     * where `fan_in` is the number of input units in the weight tensor
     * and `fan_out` is the number of output units in the weight tensor.
     *
     * Reference:
     *   Glorot & Bengio, AISTATS 2010
     *       http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
     *
     * @param args Configuration for the `GlorotNormal` initializer.
     * @doc {heading: 'Initializers', namespace: 'initializers'}
     */
    function glorotNormal(args) {
        return new GlorotNormal(args);
    }
    /**
     * He normal initializer.
     *
     * It draws samples from a truncated normal distribution centered on 0
     * with `stddev = sqrt(2 / fanIn)`
     * where `fanIn` is the number of input units in the weight tensor.
     *
     * Reference:
     *     He et al., http://arxiv.org/abs/1502.01852
     *
     * @param args Configuration for the `HeNormal` initializer.
     * @doc {heading: 'Initializers', namespace: 'initializers'}
     */
    function heNormal(args) {
        return new HeNormal(args);
    }
    /**
     * He uniform initializer.
     *
     * It draws samples from a uniform distribution within [-limit, limit]
     * where `limit` is `sqrt(6 / fan_in)`
     * where `fanIn` is the number of input units in the weight tensor.
     *
     * Reference:
     *     He et al., http://arxiv.org/abs/1502.01852
     *
     * @param args Configuration for the `HeUniform` initializer.
     * @doc {heading: 'Initializers', namespace: 'initializers'}
     */
    function heUniform(args) {
        return new HeUniform(args);
    }
    /**
     * LeCun normal initializer.
     *
     * It draws samples from a truncated normal distribution centered on 0
     * with `stddev = sqrt(1 / fanIn)`
     * where `fanIn` is the number of input units in the weight tensor.
     *
     * References:
     *   [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
     *   [Efficient Backprop](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
     *
     * @param args Configuration for the `LeCunNormal` initializer.
     * @doc {heading: 'Initializers', namespace: 'initializers'}
     */
    function leCunNormal(args) {
        return new LeCunNormal(args);
    }
    /**
     * LeCun uniform initializer.
     *
     * It draws samples from a uniform distribution in the interval
     * `[-limit, limit]` with `limit = sqrt(3 / fanIn)`,
     * where `fanIn` is the number of input units in the weight tensor.
     *
     * @param args Configuration for the `LeCunUniform` initializer.
     * @doc {heading: 'Initializers', namespace: 'initializers'}
     */
    function leCunUniform(args) {
        return new LeCunUniform(args);
    }
    /**
     * Initializer that generates a random orthogonal matrix.
     *
     * Reference:
     * [Saxe et al., http://arxiv.org/abs/1312.6120](http://arxiv.org/abs/1312.6120)
     *
     * @param args Configuration for the `Orthogonal` initializer (e.g. gain).
     * @doc {heading: 'Initializers', namespace: 'initializers'}
     */
    function orthogonal(args) {
        return new Orthogonal(args);
    }
38733
38734 var exports_initializers = /*#__PURE__*/Object.freeze({
38735 __proto__: null,
38736 zeros: zeros$1,
38737 ones: ones$2,
38738 constant: constant,
38739 randomUniform: randomUniform$1,
38740 randomNormal: randomNormal$2,
38741 truncatedNormal: truncatedNormal$1,
38742 identity: identity,
38743 varianceScaling: varianceScaling,
38744 glorotUniform: glorotUniform,
38745 glorotNormal: glorotNormal,
38746 heNormal: heNormal,
38747 heUniform: heUniform,
38748 leCunNormal: leCunNormal,
38749 leCunUniform: leCunUniform,
38750 orthogonal: orthogonal
38751 });
38752
38753 /**
38754 * @license
38755 * Copyright 2018 Google LLC
38756 *
38757 * Use of this source code is governed by an MIT-style
38758 * license that can be found in the LICENSE file or at
38759 * https://opensource.org/licenses/MIT.
38760 * =============================================================================
38761 */
38762 /**
38763 * Turn any Scalar values in a Logs object into actual number values.
38764 *
38765 * @param logs The `Logs` object to be resolved in place.
38766 */
38767 async function resolveScalarsInLogs(logs) {
38768 if (logs == null) {
38769 return;
38770 }
38771 const promises = [];
38772 const keys = [];
38773 const scalarsToDispose = [];
38774 for (const key in logs) {
38775 const value = logs[key];
38776 if (typeof value !== 'number') {
38777 const valueScalar = value;
38778 promises.push(valueScalar.data());
38779 keys.push(key);
38780 scalarsToDispose.push(valueScalar);
38781 }
38782 }
38783 if (promises.length > 0) {
38784 const values = await Promise.all(promises);
38785 for (let i = 0; i < values.length; ++i) {
38786 logs[keys[i]] = values[i][0];
38787 }
38788 // Dispose the original scalar tensors.
38789 dispose(scalarsToDispose);
38790 }
38791 }
38792 /**
38793 * Dispose all Tensors in an UnresolvedLogs object.
38794 *
38795 * @param logs An `UnresolvedLogs` object potentially containing `tf.Tensor`s in
38796 * places where the values can be `tf.Tensor` or `number`.
38797 */
38798 function disposeTensorsInLogs(logs) {
38799 if (logs == null) {
38800 return;
38801 }
38802 for (const key in logs) {
38803 const value = logs[key];
38804 if (typeof value !== 'number') {
38805 value.dispose();
38806 }
38807 }
38808 }
38809
38810 /**
38811 * @license
38812 * Copyright 2018 Google LLC
38813 *
38814 * Use of this source code is governed by an MIT-style
38815 * license that can be found in the LICENSE file or at
38816 * https://opensource.org/licenses/MIT.
38817 * =============================================================================
38818 */
38819 /** Verbosity logging level when fitting a model. */
38820 var ModelLoggingVerbosity;
38821 (function (ModelLoggingVerbosity) {
38822 ModelLoggingVerbosity[ModelLoggingVerbosity["SILENT"] = 0] = "SILENT";
38823 ModelLoggingVerbosity[ModelLoggingVerbosity["VERBOSE"] = 1] = "VERBOSE";
38824 })(ModelLoggingVerbosity || (ModelLoggingVerbosity = {}));
    /**
     * How often to yield to the main thread when training (in ms).
     * Used as the concrete interval when `yieldEvery` is 'auto'.
     */
    const DEFAULT_YIELD_EVERY_MS = 125;
    /**
     * Abstract base class used to build new callbacks.
     *
     * The `logs` dictionary that callback methods take as argument will contain
     * keys for quantities relevant to the current batch or epoch.
     *
     * Currently, the `.fit()` method of the `Sequential` model class
     * will include the following quantities in the `logs` that
     * it passes to its callbacks:
     *
     * onEpochEnd: Logs include `acc` and `loss`, and optionally include `valLoss`
     *   (if validation is enabled in `fit`), and `valAcc` (if validation and
     *   accuracy monitoring are enabled).
     * onBatchBegin: Logs include `size`, the number of samples in the current
     *   batch.
     * onBatchEnd: Logs include `loss`, and optionally `acc` (if accuracy monitoring
     *   is enabled).
     *
     * All lifecycle hooks are async no-ops here; subclasses override the ones
     * they care about.
     */
    class BaseCallback {
        constructor() {
            // TODO(michaelterry): This type is a best guess.
            this.validationData = null;
        }
        /** Stores the training parameters (epochs, batchSize, metrics, ...). */
        setParams(params) {
            this.params = params;
        }
        async onEpochBegin(epoch, logs) { }
        async onEpochEnd(epoch, logs) { }
        async onBatchBegin(batch, logs) { }
        async onBatchEnd(batch, logs) { }
        async onTrainBegin(logs) { }
        async onTrainEnd(logs) { }
        // LayersModel needs to call Callback.setModel(), but cannot actually depend
        // on Callback because that creates a cyclic dependency. Providing this no-op
        // method on BaseCallback breaks the cycle: this way LayersModel can depend on
        // BaseCallback but not on Callback. The argument is typed as `Container`
        // (the superclass of LayersModel) to avoid recapitulating the cycle. Callback
        // overrides this method and enforces that the argument is really a
        // LayersModel.
        setModel(model) {
            // Do nothing. Use Callback instead of BaseCallback to track the model.
        }
    }
38870 /**
38871 * Container abstracting a list of callbacks.
38872 */
38873 class CallbackList {
38874 // TODO(cais): When the need arises, uncomment the following lines and
38875 // implement the queue for time values.
38876 // private deltaTBatch: number;
38877 // private deltaTsBatchBegin: Array<number>;
38878 // private deltaTsBatchEnd: Array<number>;
38879 /**
38880 * Constructor of CallbackList.
38881 * @param callbacks Array of `Callback` instances.
38882 * @param queueLength Queue length for keeping running statistics over
38883 * callback execution time.
38884 */
38885 constructor(callbacks, queueLength = 10) {
38886 // TODO(cais): Make use of queueLength when implementing the queue for time
38887 // values.
38888 if (callbacks == null) {
38889 callbacks = [];
38890 }
38891 this.callbacks = callbacks;
38892 this.queueLength = queueLength;
38893 }
38894 append(callback) {
38895 this.callbacks.push(callback);
38896 }
38897 setParams(params) {
38898 for (const callback of this.callbacks) {
38899 callback.setParams(params);
38900 }
38901 }
38902 setModel(model) {
38903 for (const callback of this.callbacks) {
38904 callback.setModel(model);
38905 }
38906 }
38907 /**
38908 * Called at the start of an epoch.
38909 * @param epoch Index of epoch.
38910 * @param logs Dictionary of logs.
38911 */
38912 async onEpochBegin(epoch, logs) {
38913 if (logs == null) {
38914 logs = {};
38915 }
38916 for (const callback of this.callbacks) {
38917 await callback.onEpochBegin(epoch, logs);
38918 }
38919 }
38920 /**
38921 * Called at the end of an epoch.
38922 * @param epoch Index of epoch.
38923 * @param logs Dictionary of logs.
38924 */
38925 async onEpochEnd(epoch, logs) {
38926 if (logs == null) {
38927 logs = {};
38928 }
38929 for (const callback of this.callbacks) {
38930 await callback.onEpochEnd(epoch, logs);
38931 }
38932 }
38933 /**
38934 * Called right before processing a batch.
38935 * @param batch Index of batch within the current epoch.
38936 * @param logs Dictionary of logs.
38937 */
38938 async onBatchBegin(batch, logs) {
38939 if (logs == null) {
38940 logs = {};
38941 }
38942 for (const callback of this.callbacks) {
38943 await callback.onBatchBegin(batch, logs);
38944 }
38945 }
38946 /**
38947 * Called at the end of a batch.
38948 * @param batch Index of batch within the current epoch.
38949 * @param logs Dictionary of logs.
38950 */
38951 async onBatchEnd(batch, logs) {
38952 if (logs == null) {
38953 logs = {};
38954 }
38955 for (const callback of this.callbacks) {
38956 await callback.onBatchEnd(batch, logs);
38957 }
38958 }
38959 /**
38960 * Called at the beginning of training.
38961 * @param logs Dictionary of logs.
38962 */
38963 async onTrainBegin(logs) {
38964 if (logs == null) {
38965 logs = {};
38966 }
38967 for (const callback of this.callbacks) {
38968 await callback.onTrainBegin(logs);
38969 }
38970 }
38971 /**
38972 * Called at the end of training.
38973 * @param logs Dictionary of logs.
38974 */
38975 async onTrainEnd(logs) {
38976 if (logs == null) {
38977 logs = {};
38978 }
38979 for (const callback of this.callbacks) {
38980 await callback.onTrainEnd(logs);
38981 }
38982 }
38983 }
    /**
     * Callback that accumulates epoch averages of metrics.
     *
     * This callback is automatically applied to every LayersModel.
     *
     * Totals are accumulated per metric key across batches, weighted by batch
     * size; numeric values are accumulated as plain numbers, tensor values as
     * tensors (with careful disposal of superseded totals). At epoch end the
     * batch-size-weighted average is written back into `logs`.
     */
    class BaseLogger extends BaseCallback {
        constructor() {
            super();
        }
        async onEpochBegin(epoch) {
            // Reset per-epoch accumulators.
            this.seen = 0;
            this.totals = {};
        }
        async onBatchEnd(batch, logs) {
            if (logs == null) {
                logs = {};
            }
            // `size` is the number of samples in the batch; missing -> 0.
            const batchSize = logs['size'] == null ? 0 : logs['size'];
            this.seen += batchSize;
            for (const key in logs) {
                const value = logs[key];
                if (typeof value === 'number') {
                    if (!this.totals.hasOwnProperty(key)) {
                        this.totals[key] = 0;
                    }
                    // Weight each batch's metric by its batch size.
                    this.totals[key] = this.totals[key] + value * batchSize;
                }
                else {
                    // Tensor-valued metric: accumulate on the backend, disposing
                    // the previous running total (if any) after the new one is
                    // computed.
                    let oldTotalsToDispose;
                    if (key in this.totals) {
                        oldTotalsToDispose = this.totals[key];
                    }
                    else {
                        this.totals[key] = 0;
                    }
                    const total = tidy(() => add$1((this.totals[key]), mul(value, batchSize)));
                    this.totals[key] = total;
                    if (oldTotalsToDispose != null) {
                        oldTotalsToDispose.dispose();
                    }
                }
            }
        }
        async onEpochEnd(epoch, logs) {
            if (logs != null) {
                // Only metrics declared in params are averaged into `logs`.
                for (const key of this.params['metrics']) {
                    if (this.totals[key] == null) {
                        continue;
                    }
                    if (typeof this.totals[key] === 'number') {
                        logs[key] = this.totals[key] / this.seen;
                    }
                    else {
                        // Tensor total: divide by sample count, dispose the
                        // running total, and `keep` the result so it survives
                        // the tidy scope.
                        tidy(() => {
                            const log = mul(div(1, this.seen), this.totals[key]);
                            logs[key] = log;
                            this.totals[key].dispose();
                            keep(logs[key]);
                        });
                    }
                }
            }
        }
    }
39048 /**
39049 * Callback that records events into a `History` object. This callback is
39050 * automatically applied to every TF.js Layers model. The `History` object
39051 * gets returned by the `fit` method of models.
39052 */
39053 class History extends BaseCallback {
39054 async onTrainBegin(logs) {
39055 this.epoch = [];
39056 this.history = {};
39057 }
39058 async onEpochEnd(epoch, logs) {
39059 if (logs == null) {
39060 logs = {};
39061 }
39062 this.epoch.push(epoch);
39063 for (const key in logs) {
39064 if (this.history[key] == null) {
39065 this.history[key] = [];
39066 }
39067 this.history[key].push(logs[key]);
39068 }
39069 }
39070 /**
39071 * Await the values of all losses and metrics.
39072 */
39073 async syncData() {
39074 const promises = [];
39075 const keys = [];
39076 const indices = [];
39077 for (const key in this.history) {
39078 const valueArray = this.history[key];
39079 for (let i = 0; i < valueArray.length; ++i) {
39080 if (typeof valueArray[i] !== 'number') {
39081 const valueScalar = valueArray[i];
39082 promises.push(valueScalar.data());
39083 keys.push(key);
39084 indices.push(i);
39085 }
39086 }
39087 }
39088 const values = await Promise.all(promises);
39089 for (let n = 0; n < values.length; ++n) {
39090 const tensorToDispose = this.history[keys[n]][indices[n]];
39091 tensorToDispose.dispose();
39092 this.history[keys[n]][indices[n]] = values[n][0];
39093 }
39094 }
39095 }
    /**
     * Custom callback for training.
     *
     * Wraps user-provided hook functions (`onTrainBegin`, `onEpochEnd`, ...)
     * and, depending on `yieldEvery`, periodically yields to the main thread
     * via `nextFrameFunc` so the UI stays responsive during training.
     */
    class CustomCallback extends BaseCallback {
        /**
         * @param args Dict of optional hook functions plus `nowFunc` /
         *   `nextFrameFunc` overrides (the latter defaults to `nextFrame`).
         * @param yieldEvery 'auto' | 'batch' | 'epoch' | 'never' | number (ms).
         *   'auto' maps to `DEFAULT_YIELD_EVERY_MS`.
         */
        constructor(args, yieldEvery) {
            super();
            this.currentEpoch = 0;
            this.nowFunc = args.nowFunc;
            this.nextFrameFunc = args.nextFrameFunc || nextFrame;
            this.yieldEvery = yieldEvery || 'auto';
            if (this.yieldEvery === 'auto') {
                this.yieldEvery = DEFAULT_YIELD_EVERY_MS;
            }
            if (this.yieldEvery === 'never' && args.onYield != null) {
                throw new Error('yieldEvery is `never` but you provided an `onYield` callback. ' +
                    'Either change `yieldEvery` or remove the callback');
            }
            if (isNumber(this.yieldEvery)) {
                // Decorate `maybeWait` so it will be called at most once every
                // `yieldEvery` ms.
                this.maybeWait = debounce(this.maybeWait.bind(this), this.yieldEvery, this.nowFunc);
            }
            this.trainBegin = args.onTrainBegin;
            this.trainEnd = args.onTrainEnd;
            this.epochBegin = args.onEpochBegin;
            this.epochEnd = args.onEpochEnd;
            this.batchBegin = args.onBatchBegin;
            this.batchEnd = args.onBatchEnd;
            this.yield = args.onYield;
        }
        /**
         * Invokes the user's `onYield` hook (if any) and yields a frame; both
         * run concurrently. Debounced in the constructor when `yieldEvery` is
         * numeric.
         */
        async maybeWait(epoch, batch, logs) {
            const ps = [];
            if (this.yield != null) {
                await resolveScalarsInLogs(logs);
                ps.push(this.yield(epoch, batch, logs));
            }
            ps.push(this.nextFrameFunc());
            await Promise.all(ps);
        }
        async onEpochBegin(epoch, logs) {
            this.currentEpoch = epoch;
            if (this.epochBegin != null) {
                // Scalars are resolved to numbers before reaching user code.
                await resolveScalarsInLogs(logs);
                await this.epochBegin(epoch, logs);
            }
        }
        async onEpochEnd(epoch, logs) {
            const ps = [];
            if (this.epochEnd != null) {
                await resolveScalarsInLogs(logs);
                ps.push(this.epochEnd(epoch, logs));
            }
            if (this.yieldEvery === 'epoch') {
                ps.push(this.nextFrameFunc());
            }
            await Promise.all(ps);
        }
        async onBatchBegin(batch, logs) {
            if (this.batchBegin != null) {
                await resolveScalarsInLogs(logs);
                await this.batchBegin(batch, logs);
            }
        }
        async onBatchEnd(batch, logs) {
            const ps = [];
            if (this.batchEnd != null) {
                await resolveScalarsInLogs(logs);
                ps.push(this.batchEnd(batch, logs));
            }
            if (this.yieldEvery === 'batch') {
                ps.push(this.nextFrameFunc());
            }
            else if (isNumber(this.yieldEvery)) {
                // Debounced: yields at most once per `yieldEvery` ms.
                ps.push(this.maybeWait(this.currentEpoch, batch, logs));
            }
            await Promise.all(ps);
        }
        async onTrainBegin(logs) {
            if (this.trainBegin != null) {
                await resolveScalarsInLogs(logs);
                await this.trainBegin(logs);
            }
        }
        async onTrainEnd(logs) {
            if (this.trainEnd != null) {
                await resolveScalarsInLogs(logs);
                await this.trainEnd(logs);
            }
        }
    }
39186 /**
39187 * Standardize callbacks or configurations of them to an Array of callbacks.
39188 */
39189 function standardizeCallbacks(callbacks, yieldEvery) {
39190 if (callbacks == null) {
39191 callbacks = {};
39192 }
39193 if (callbacks instanceof BaseCallback) {
39194 return [callbacks];
39195 }
39196 if (Array.isArray(callbacks) && callbacks[0] instanceof BaseCallback) {
39197 return callbacks;
39198 }
39199 // Convert custom callback configs to custom callback objects.
39200 const callbackConfigs = toList(callbacks);
39201 return callbackConfigs.map(callbackConfig => new CustomCallback(callbackConfig, yieldEvery));
39202 }
    /**
     * A global registry for callback constructors to be used during
     * LayersModel.fit().
     *
     * Constructors are stored per verbosity level; `createCallbacks(v)`
     * instantiates every constructor registered at level <= v.
     */
    class CallbackConstructorRegistry {
        /**
         * Blocks public access to constructor.
         */
        constructor() { }
        /**
         * Register a tf.LayersModel.fit() callback constructor.
         *
         * The registered callback constructor will be used to instantiate
         * callbacks for every tf.LayersModel.fit() call afterwards.
         *
         * @param verbosityLevel Level of verbosity at which the `callbackConstructor`
         *   is to be registered.
         * @param callbackConstructor A no-arg constructor for `tf.Callback`.
         * @throws Error, if the same callbackConstructor has been registered before,
         *   either at the same or a different `verbosityLevel`.
         */
        static registerCallbackConstructor(verbosityLevel, callbackConstructor) {
            assert(verbosityLevel >= 0 && Number.isInteger(verbosityLevel), () => `Verbosity level is expected to be an integer >= 0, ` +
                `but got ${verbosityLevel}`);
            CallbackConstructorRegistry.checkForDuplicate(callbackConstructor);
            if (CallbackConstructorRegistry.constructors[verbosityLevel] == null) {
                CallbackConstructorRegistry.constructors[verbosityLevel] = [];
            }
            CallbackConstructorRegistry.constructors[verbosityLevel].push(callbackConstructor);
        }
        // Throws ValueError if the constructor is already registered at ANY
        // verbosity level.
        static checkForDuplicate(callbackConstructor) {
            for (const levelName in CallbackConstructorRegistry.constructors) {
                const constructors = CallbackConstructorRegistry.constructors[+levelName];
                constructors.forEach(ctor => {
                    if (ctor === callbackConstructor) {
                        throw new ValueError('Duplicate callback constructor.');
                    }
                });
            }
        }
        /**
         * Clear all registered callback constructors.
         */
        static clear() {
            CallbackConstructorRegistry.constructors = {};
        }
        /**
         * Create callbacks using the registered callback constructors.
         *
         * Given `verbosityLevel`, all constructors registered at that level or above
         * will be called and the instantiated callbacks will be used.
         *
         * @param verbosityLevel: Level of verbosity.
         */
        static createCallbacks(verbosityLevel) {
            const constructors = [];
            for (const levelName in CallbackConstructorRegistry.constructors) {
                const level = +levelName;
                if (verbosityLevel >= level) {
                    constructors.push(...CallbackConstructorRegistry.constructors[level]);
                }
            }
            return constructors.map(ctor => new ctor());
        }
    }
    // Static registry storage: maps verbosity level -> array of constructors.
    CallbackConstructorRegistry.constructors = {};
    /**
     * Assembles the full callback list used by a `fit()` call.
     *
     * Always prepends a `BaseLogger`, then any registry-created callbacks for
     * the given verbosity, then the user-provided callbacks, and finally a
     * `History` callback (returned separately so `fit()` can return it).
     *
     * @param callbacks User-provided callbacks (may be null).
     * @param verbose Verbosity level, forwarded to the constructor registry.
     * @param epochs Total number of epochs.
     * @param initialEpoch Epoch at which training starts.
     * @param numTrainSamples Number of training samples (or null when using
     *   steps).
     * @param stepsPerEpoch Steps per epoch (or null when using samples).
     * @param batchSize Batch size.
     * @param doValidation Whether validation is performed.
     * @param callbackMetrics Names of metrics reported to the callbacks.
     * @returns `{callbackList, history}`: the configured `CallbackList` and
     *   the `History` callback it contains.
     */
    function configureCallbacks(callbacks, verbose, epochs, initialEpoch, numTrainSamples, stepsPerEpoch, batchSize, doValidation, callbackMetrics) {
        const history = new History();
        const actualCallbacks = [
            new BaseLogger(), ...CallbackConstructorRegistry.createCallbacks(verbose)
        ];
        if (callbacks != null) {
            actualCallbacks.push(...callbacks);
        }
        actualCallbacks.push(history);
        const callbackList = new CallbackList(actualCallbacks);
        // TODO(cais): Figure out when this LayersModel instance can have a
        //   dynamically set property called 'callback_model' as in PyKeras.
        callbackList.setParams({
            epochs,
            initialEpoch,
            samples: numTrainSamples,
            steps: stepsPerEpoch,
            batchSize,
            verbose,
            doValidation,
            metrics: callbackMetrics,
        });
        return { callbackList, history };
    }
39294
39295 /**
39296 * @license
39297 * Copyright 2018 Google LLC
39298 *
39299 * Use of this source code is governed by an MIT-style
39300 * license that can be found in the LICENSE file or at
39301 * https://opensource.org/licenses/MIT.
39302 * =============================================================================
39303 */
    /**
     * Instantiate a layer from a config dictionary.
     * @param config dict of the form {class_name: str, config: dict}
     * @param customObjects dict mapping class names (or function names)
     *   of custom (non-Keras) objects to class/functions
     * @param fastWeightInit Optional flag to use fast weight initialization
     *   during deserialization. This is applicable to cases in which
     *   the initialization will be immediately overwritten by loaded weight
     *   values. Default: `false`.
     * @returns Layer instance (may be LayersModel, Sequential, Layer...)
     */
    function deserialize(config, customObjects = {}, fastWeightInit = false) {
        return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'layer', fastWeightInit);
    }
39318
39319 /**
39320 * @license
39321 * Copyright 2018 Google LLC
39322 *
39323 * Use of this source code is governed by an MIT-style
39324 * license that can be found in the LICENSE file or at
39325 * https://opensource.org/licenses/MIT.
39326 * =============================================================================
39327 */
    /**
     * Normalizes a tensor wrt the L2 norm alongside the specified axis.
     *
     * The squared sum is clamped below by `epsilon()` before the square root,
     * avoiding division by zero for all-zero slices. Non-float32 inputs are
     * cast to float32 first.
     *
     * @param x Input tensor.
     * @param axis Axis along which to perform normalization.
     * @returns The L2-normalized tensor.
     */
    function l2Normalize(x, axis) {
        return tidy(() => {
            if (x.dtype !== 'float32') {
                x = cast(x, 'float32');
            }
            const squareSum = sum$1(square$1(x), axis, true);
            const epsilonTensor = fill(squareSum.shape, epsilon());
            const norm = sqrt(maximum(squareSum, epsilonTensor));
            return div(x, norm);
        });
    }
    /** Mean squared error, reduced over the last axis. */
    function meanSquaredError$1(yTrue, yPred) {
        return tidy(() => mean(square$1(sub(yPred, yTrue)), -1));
    }
    /** Mean absolute error, reduced over the last axis. */
    function meanAbsoluteError(yTrue, yPred) {
        return tidy(() => mean(abs(sub(yPred, yTrue)), -1));
    }
    /**
     * Mean absolute percentage error, reduced over the last axis.
     * `|yTrue|` is clamped below by `epsilon()` to avoid division by zero.
     */
    function meanAbsolutePercentageError(yTrue, yPred) {
        return tidy(() => {
            const diff = sub(yTrue, yPred);
            const clippedTrue = clipByValue(abs(yTrue), epsilon(), Number.MAX_VALUE);
            const absResult = abs(div(diff, clippedTrue));
            return mul(100, mean(absResult, -1));
        });
    }
    /**
     * Mean squared logarithmic error, reduced over the last axis.
     * Both inputs are clamped below by `epsilon()` before `log(1 + x)`.
     */
    function meanSquaredLogarithmicError(yTrue, yPred) {
        return tidy(() => {
            const clippedPred = clipByValue(yPred, epsilon(), Number.MAX_VALUE);
            const firstLog = log$1(add$1(1, clippedPred));
            const clippedTrue = clipByValue(yTrue, epsilon(), Number.MAX_VALUE);
            const secondLog = log$1(add$1(1, clippedTrue));
            return mean(square$1(sub(firstLog, secondLog)), -1);
        });
    }
    /** Squared hinge loss: mean of `max(0, 1 - yTrue * yPred)^2` over the last axis. */
    function squaredHinge(yTrue, yPred) {
        return tidy(() => {
            const maxResult = maximum(0, sub(1, mul(yTrue, yPred)));
            return mean(square$1(maxResult), -1);
        });
    }
    /** Hinge loss: mean of `max(0, 1 - yTrue * yPred)` over the last axis. */
    function hinge(yTrue, yPred) {
        return tidy(() => {
            const maxResult = maximum(0, sub(1, mul(yTrue, yPred)));
            return mean(maxResult, -1);
        });
    }
    /**
     * Categorical hinge loss: `max(0, 1 + max(neg) - pos)`, where `pos` is the
     * score of the true class and `neg` the highest score among other classes.
     */
    function categoricalHinge(yTrue, yPred) {
        return tidy(() => {
            const pos = sum$1(mul(yTrue, yPred), -1);
            const neg = max(mul(sub(1, yTrue), yPred), -1);
            return maximum(0, add$1(1, sub(neg, pos)));
        });
    }
    /**
     * Logarithm of the hyperbolic cosine of the prediction error.
     *
     * `log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small `x` and
     * to `abs(x) - log(2)` for large `x`. This means that 'logcosh' works mostly
     * like the mean squared error, but will not be so strongly affected by the
     * occasional wildly incorrect prediction.
     */
    function logcosh(yTrue, yPred) {
        return tidy(() => {
            const log2 = Math.log(2);
            const predictionDiff = sub(yPred, yTrue);
            // Numerically stable form: log(cosh(x)) = x + softplus(-2x) - log(2).
            const logcoshResult = sub(add$1(predictionDiff, softplus(mul(-2, predictionDiff))), log2);
            return mean(logcoshResult, -1);
        });
    }
    /**
     * Categorical crossentropy between a target distribution and an output
     * distribution (or logits).
     *
     * @param target Target probability tensor.
     * @param output Output of a softmax, or logits when `fromLogits` is true.
     * @param fromLogits Whether `output` is logits (softmax is applied here)
     *   or probabilities (renormalized to sum to 1 along the last axis).
     */
    function categoricalCrossentropy(target, output, fromLogits = false) {
        return tidy(() => {
            if (fromLogits) {
                output = softmax(output);
            }
            else {
                // scale preds so that the class probabilities of each sample sum to 1.
                const outputSum = sum$1(output, output.shape.length - 1, true);
                output = div(output, outputSum);
            }
            // Clip to avoid log(0).
            output = clipByValue(output, epsilon(), 1 - epsilon());
            return neg(sum$1(mul(cast(target, 'float32'), log$1(output)), output.shape.length - 1));
        });
    }
    /**
     * Categorical crossentropy with integer targets.
     *
     * @param target An integer tensor.
     * @param output A tensor resulting from a softmax (unless `fromLogits` is
     *   `true`, in which case `output` is expected to be the logits).
     * @param fromLogits Boolean, whether `output` is the result of a softmax, or is
     *   a tensor of logits.
     */
    function sparseCategoricalCrossentropy(target, output, fromLogits = false) {
        return tidy(() => {
            // Flatten and floor the integer targets, then one-hot encode them to
            // the shape of `output` before delegating to categoricalCrossentropy.
            const flatTarget = cast(floor(flatten$1(target)), 'int32');
            output = clipByValue(output, epsilon(), 1 - epsilon());
            const outputShape = output.shape;
            const oneHotTarget = reshape(oneHot(flatTarget, outputShape[outputShape.length - 1]), outputShape);
            return categoricalCrossentropy(oneHotTarget, output, fromLogits);
        });
    }
39434 /**
39435 * From TensorFlow's implementation in nn_impl.py:
39436 *
39437 * For brevity, let `x = logits`, `z = labels`. The logistic loss is
39438 * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
39439 * = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
39440 * = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
 * = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
39442 * = (1 - z) * x + log(1 + exp(-x))
39443 * = x - x * z + log(1 + exp(-x))
39444 * For x < 0, to avoid overflow in exp(-x), we reformulate the above
39445 * x - x * z + log(1 + exp(-x))
39446 * = log(exp(x)) - x * z + log(1 + exp(-x))
39447 * = - x * z + log(1 + exp(x))
39448 * Hence, to ensure stability and avoid overflow, the implementation uses this
39449 * equivalent formulation
39450 * max(x, 0) - x * z + log(1 + exp(-abs(x)))
39451 *
39452 * @param labels The labels.
39453 * @param logits The logits.
39454 */
39455 function sigmoidCrossEntropyWithLogits(labels, logits) {
39456 if (!arraysEqual(labels.shape, logits.shape)) {
39457 throw new ValueError(`logits and labels must have the same shape, but got shapes ` +
39458 `${JSON.stringify(labels.shape)} and ${JSON.stringify(logits.shape)}`);
39459 }
39460 return tidy(() => {
39461 // The logistic loss formula from above is
39462 // x - x * z + log(1 + exp(-x))
39463 // For x < 0, a more numerically stable formula is
39464 // -x * z + log(1 + exp(x))
39465 // Note that these two expressions can be combined into the following:
39466 // max(x, 0) - x * z + log(1 + exp(-abs(x)))
39467 const reluLogits = relu(logits);
39468 const negAbsLogits = neg(abs(logits));
39469 return add$1(sub(reluLogits, mul(logits, labels)), log1p(exp(negAbsLogits)));
39470 });
39471 }
39472 function binaryCrossentropy(yTrue, yPred) {
39473 return tidy(() => {
39474 let y;
39475 y = clipByValue(yPred, epsilon(), 1 - epsilon());
39476 y = log$1(div(y, sub(1, y)));
39477 return mean(sigmoidCrossEntropyWithLogits(yTrue, y), -1);
39478 });
39479 }
    /**
     * Kullback-Leibler divergence loss: sum over the last axis of
     * yTrue * log(clip(yTrue) / clip(yPred)), with both tensors clipped to
     * [epsilon(), 1] inside the log for numerical stability.
     *
     * NOTE(review): the multiplier outside the log is the *unclipped* yTrue,
     * while the ratio inside the log uses the clipped tensors — TODO confirm
     * this asymmetry is intentional (compare with the Keras reference).
     */
    function kullbackLeiblerDivergence(yTrue, yPred) {
        return tidy(() => {
            const clippedTrue = clipByValue(yTrue, epsilon(), 1);
            const clippedPred = clipByValue(yPred, epsilon(), 1);
            return sum$1(mul(yTrue, log$1(div(clippedTrue, clippedPred))), -1);
        });
    }
39487 function poisson(yTrue, yPred) {
39488 return tidy(() => {
39489 const logPred = log$1(add$1(epsilon(), yPred));
39490 return mean(sub(yPred, mul(yTrue, logPred)), -1);
39491 });
39492 }
39493 function cosineProximity(yTrue, yPred) {
39494 return tidy(() => {
39495 const trueNormalized = l2Normalize(yTrue, -1);
39496 const predNormalized = l2Normalize(yPred, -1);
39497 const trueXPred = mul(trueNormalized, predNormalized);
39498 return neg(sum$1(trueXPred, -1));
39499 });
39500 }
    // Short-form and upper-case loss aliases matching Keras naming conventions.
    const mse = meanSquaredError$1;
    const MSE = meanSquaredError$1;
    const mae = meanAbsoluteError;
    const MAE = meanAbsoluteError;
    const mape = meanAbsolutePercentageError;
    const MAPE = meanAbsolutePercentageError;
    const msle = meanSquaredLogarithmicError;
    const MSLE = meanSquaredLogarithmicError;
    const kld = kullbackLeiblerDivergence;
    const KLD = kullbackLeiblerDivergence;
    const cosine = cosineProximity;
    // TODO(michaelterry): Add deserialize() function.
    // Registry mapping loss names to implementations; used by `get()` to
    // resolve string identifiers (e.g. from serialized model configs).
    const lossesMap = {
        meanSquaredError: meanSquaredError$1,
        meanAbsoluteError,
        meanAbsolutePercentageError,
        meanSquaredLogarithmicError,
        squaredHinge,
        hinge,
        categoricalHinge,
        logcosh,
        categoricalCrossentropy,
        sparseCategoricalCrossentropy,
        binaryCrossentropy,
        kullbackLeiblerDivergence,
        poisson,
        cosineProximity
    };
39529 // Porting note: This diverges from the PyKeras implementation and may need to
39530 // change based on (de)serialization requirements.
39531 function get(identifierOrFn) {
39532 if (typeof identifierOrFn === 'string') {
39533 if (identifierOrFn in lossesMap) {
39534 return lossesMap[identifierOrFn];
39535 }
39536 let errMsg = `Unknown loss ${identifierOrFn}`;
39537 if (identifierOrFn.toLowerCase().includes('softmaxcrossentropy')) {
39538 errMsg = `Unknown loss ${identifierOrFn}. ` +
39539 'Use "categoricalCrossentropy" as the string name for ' +
39540 'tf.losses.softmaxCrossEntropy';
39541 }
39542 throw new ValueError(errMsg);
39543 }
39544 else {
39545 return identifierOrFn;
39546 }
39547 }
39548
39549 /**
39550 * @license
39551 * Copyright 2018 Google LLC
39552 *
39553 * Use of this source code is governed by an MIT-style
39554 * license that can be found in the LICENSE file or at
39555 * https://opensource.org/licenses/MIT.
39556 * =============================================================================
39557 */
39558 function binaryAccuracy(yTrue, yPred) {
39559 return tidy(() => {
39560 const threshold = mul(.5, onesLike(yPred));
39561 const yPredThresholded = cast$1(greater(yPred, threshold), yTrue.dtype);
39562 return mean(equal(yTrue, yPredThresholded), -1);
39563 });
39564 }
39565 function categoricalAccuracy(yTrue, yPred) {
39566 return tidy(() => cast$1(equal(argMax(yTrue, -1), argMax(yPred, -1)), 'float32'));
39567 }
39568 function truePositives(yTrue, yPred) {
39569 return tidy(() => {
39570 return cast(sum$1(logicalAnd(equal(yTrue, 1), equal(yPred, 1))), 'float32');
39571 });
39572 }
39573 function falseNegatives(yTrue, yPred) {
39574 return tidy(() => {
39575 return cast(sum$1(logicalAnd(equal(yTrue, 1), equal(yPred, 0))), 'float32');
39576 });
39577 }
39578 function falsePositives(yTrue, yPred) {
39579 return tidy(() => {
39580 return cast(sum$1(logicalAnd(equal(yTrue, 0), equal(yPred, 1))), 'float32');
39581 });
39582 }
39583 function precision(yTrue, yPred) {
39584 return tidy(() => {
39585 const tp = truePositives(yTrue, yPred);
39586 const fp = falsePositives(yTrue, yPred);
39587 const denominator = add$1(tp, fp);
39588 return cast(where(greater(denominator, 0), div(tp, denominator), 0), 'float32');
39589 });
39590 }
39591 function recall(yTrue, yPred) {
39592 return tidy(() => {
39593 const tp = truePositives(yTrue, yPred);
39594 const fn = falseNegatives(yTrue, yPred);
39595 const denominator = add$1(tp, fn);
39596 return cast(where(greater(denominator, 0), div(tp, denominator), 0), 'float32');
39597 });
39598 }
    /**
     * Binary crossentropy metric; delegates to the loss implementation of the
     * same name.
     */
    function binaryCrossentropy$1(yTrue, yPred) {
        return binaryCrossentropy(yTrue, yPred);
    }
39602 function sparseCategoricalAccuracy(yTrue, yPred) {
39603 if (yTrue.rank === yPred.rank) {
39604 yTrue = squeeze(yTrue, [yTrue.rank - 1]);
39605 }
39606 yPred = argMax(yPred, -1);
39607 if (yPred.dtype !== yTrue.dtype) {
39608 yPred = cast(yPred, yTrue.dtype);
39609 }
39610 return cast(equal(yTrue, yPred), 'float32');
39611 }
    /** Top-k categorical accuracy metric. Not yet implemented. */
    function topKCategoricalAccuracy(yTrue, yPred) {
        throw new NotImplementedError();
    }
    /** Sparse top-k categorical accuracy metric. Not yet implemented. */
    function sparseTopKCategoricalAccuracy(yTrue, yPred) {
        throw new NotImplementedError();
    }
    // Aliases.
    // Short-form and upper-case metric aliases matching Keras conventions.
    const mse$1 = meanSquaredError$1;
    const MSE$1 = meanSquaredError$1;
    const mae$1 = meanAbsoluteError;
    const MAE$1 = meanAbsoluteError;
    const mape$1 = meanAbsolutePercentageError;
    const MAPE$1 = meanAbsolutePercentageError;
    const categoricalCrossentropy$1 = categoricalCrossentropy;
    const cosine$1 = cosineProximity;
    const sparseCategoricalCrossentropy$1 = sparseCategoricalCrossentropy;
    // TODO(cais, nielsene): Add serialize().
    // Registry mapping metric names to implementations; used by `get$1()` to
    // resolve string identifiers.
    const metricsMap = {
        binaryAccuracy,
        categoricalAccuracy,
        precision,
        categoricalCrossentropy: categoricalCrossentropy$1,
        sparseCategoricalCrossentropy: sparseCategoricalCrossentropy$1,
        mse: mse$1,
        MSE: MSE$1,
        mae: mae$1,
        MAE: MAE$1,
        mape: mape$1,
        MAPE: MAPE$1,
        cosine: cosine$1
    };
39643 function get$1(identifier) {
39644 if (typeof identifier === 'string' && identifier in metricsMap) {
39645 return metricsMap[identifier];
39646 }
39647 else if (typeof identifier !== 'string' && identifier != null) {
39648 return identifier;
39649 }
39650 else {
39651 throw new ValueError(`Unknown metric ${identifier}`);
39652 }
39653 }
39654 /**
39655 * Get the shortcut function name.
39656 *
39657 * If the fn name is a string,
39658 * directly return the string name.
39659 * If the function is included in metricsMap or lossesMap,
39660 * return key of the map.
39661 * - If the function relative to multiple keys,
39662 * return the first found key as the function name.
39663 * - If the function exists in both lossesMap and metricsMap,
39664 * search lossesMap first.
39665 * If the function is not included in metricsMap or lossesMap,
39666 * return the function name.
39667 *
39668 * @param fn loss function, metric function, or short cut name.
39669 * @returns Loss or Metric name in string.
39670 */
39671 function getLossOrMetricName(fn) {
39672 assert$1(fn !== null, `Unknown LossOrMetricFn ${fn}`);
39673 if (typeof fn === 'string') {
39674 return fn;
39675 }
39676 else {
39677 let fnName;
39678 for (const key of Object.keys(lossesMap)) {
39679 if (lossesMap[key] === fn) {
39680 fnName = key;
39681 break;
39682 }
39683 }
39684 if (fnName !== undefined) {
39685 return fnName;
39686 }
39687 for (const key of Object.keys(metricsMap)) {
39688 if (metricsMap[key] === fn) {
39689 fnName = key;
39690 break;
39691 }
39692 }
39693 if (fnName !== undefined) {
39694 return fnName;
39695 }
39696 return fn.name;
39697 }
39698 }
39699
39700 /**
39701 * @license
39702 * Copyright 2018 Google LLC
39703 *
39704 * Use of this source code is governed by an MIT-style
39705 * license that can be found in the LICENSE file or at
39706 * https://opensource.org/licenses/MIT.
39707 * =============================================================================
39708 */
39709 // Add (de)serialize()
39710 // Porting note: This diverges from the PyKeras implementation and may need to
39711 // change based on (de)serialization requirements.
39712 function getOptimizer(identifier) {
39713 const optimizerMap = {
39714 'Adagrad': () => train.adagrad(0.01),
39715 'Adadelta': () => train.adadelta(1, 0.95, epsilon()),
39716 'Adam': () => train.adam(0.001, 0.9, 0.999, epsilon()),
39717 'Adamax': () => train.adamax(0.002, 0.9, 0.999, epsilon(), 0),
39718 'RMSProp': () => train.rmsprop(0.001, 0.9, 0, epsilon()),
39719 'SGD': () => train.sgd(0.01)
39720 };
39721 optimizerMap['adagrad'] = optimizerMap['Adagrad'];
39722 optimizerMap['adadelta'] = optimizerMap['Adadelta'];
39723 optimizerMap['adam'] = optimizerMap['Adam'];
39724 optimizerMap['adamax'] = optimizerMap['Adamax'];
39725 optimizerMap['rmsprop'] = optimizerMap['RMSProp'];
39726 optimizerMap['sgd'] = optimizerMap['SGD'];
39727 if (identifier in optimizerMap) {
39728 return optimizerMap[identifier]();
39729 }
39730 throw new ValueError(`Unknown Optimizer ${identifier}`);
39731 }
39732
39733 /**
39734 * @license
39735 * Copyright 2019 Google LLC
39736 *
39737 * Use of this source code is governed by an MIT-style
39738 * license that can be found in the LICENSE file or at
39739 * https://opensource.org/licenses/MIT.
39740 * =============================================================================
39741 */
    /** Utility functions related to user-defined metadata. */
    // Maximum recommended serialized size for user-defined metadata.
    // Beyond this limit, a warning message will be printed during model loading and
    // saving.
    const MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH = 1 * 1024 * 1024; // 1 MiB
39747 /**
39748 * Check validity of user-defined metadata.
39749 *
39750 * @param userDefinedMetadata
39751 * @param modelName Name of the model that the user-defined metadata belongs to.
39752 * Used during construction of error messages.
39753 * @param checkSize Whether to check the size of the metadata is under
39754 * recommended limit. Default: `false`. If `true`, will try stringify the
39755 * JSON object and print a console warning if the serialzied size is above the
39756 * limit.
39757 * @throws Error if `userDefinedMetadata` is not a plain JSON object.
39758 */
39759 function checkUserDefinedMetadata(userDefinedMetadata, modelName, checkSize = false) {
39760 if (userDefinedMetadata == null ||
39761 typeof userDefinedMetadata !== 'object' ||
39762 Object.getPrototypeOf(userDefinedMetadata) !== Object.prototype ||
39763 !plainObjectCheck(userDefinedMetadata)) {
39764 throw new Error('User-defined metadata is expected to be a JSON object, but is not.');
39765 }
39766 if (checkSize) {
39767 const out = JSON.stringify(userDefinedMetadata);
39768 if (out.length > MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH) {
39769 console.warn(`User-defined metadata of model "${modelName}" is too large in ` +
39770 `size (length=${out.length} when serialized). It is not ` +
39771 `recommended to store such large objects in user-defined metadata. ` +
39772 `Please make sure its serialized length is <= ` +
39773 `${MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH}.`);
39774 }
39775 }
39776 }
39777 /**
39778 * Check if an input is plain JSON object or any valid subfield of it.
39779 *
39780 * @param x The input to be checked.
39781 * @param assertObject Whether to assert `x` is a JSON object, i.e., reject
39782 * cases of arrays and primitives.
39783 * @return Returns `true` if and only if `x` is a plain JSON object,
39784 * a JSON-valid primitive including string, number, boolean and null,
39785 * or an array of the said types.
39786 */
39787 // tslint:disable-next-line:no-any
39788 function plainObjectCheck(x) {
39789 if (x === null) {
39790 // Note: typeof `null` is 'object', and `null` is valid in JSON.
39791 return true;
39792 }
39793 else if (typeof x === 'object') {
39794 if (Object.getPrototypeOf(x) === Object.prototype) {
39795 // `x` is a JavaScript object and its prototype is Object.
39796 const keys = Object.keys(x);
39797 for (const key of keys) {
39798 if (typeof key !== 'string') {
39799 // JSON keys must be strings.
39800 return false;
39801 }
39802 if (!plainObjectCheck(x[key])) { // Recursive call.
39803 return false;
39804 }
39805 }
39806 return true;
39807 }
39808 else {
39809 // `x` is a JavaScript object but its prototype is not Object.
39810 if (Array.isArray(x)) {
39811 // `x` is a JavaScript array.
39812 for (const item of x) {
39813 if (!plainObjectCheck(item)) { // Recursive call.
39814 return false;
39815 }
39816 }
39817 return true;
39818 }
39819 else {
39820 // `x` is a JavaScript object and its prototype is not Object,
39821 // and it's not an Array. I.e., it's a complex object such as
39822 // `Error` and `Date`.
39823 return false;
39824 }
39825 }
39826 }
39827 else {
39828 // `x` is not a JavaScript object or `null`.
39829 const xType = typeof x;
39830 return xType === 'string' || xType === 'number' || xType === 'boolean';
39831 }
39832 }
39833
39834 /**
39835 * @license
39836 * Copyright 2018 Google LLC
39837 *
39838 * Use of this source code is governed by an MIT-style
39839 * license that can be found in the LICENSE file or at
39840 * https://opensource.org/licenses/MIT.
39841 * =============================================================================
39842 */
    /**
     * Print the summary of a LayersModel object.
     *
     * @param model tf.LayersModel instance.
     * @param lineLength Total length of printed lines. Set this to adapt to the
     *   display to different terminal or console sizes.
     * @param positions Relative or absolute positions of log elements in each
     *   line. Each number corresponds to right-most (i.e., ending) position of a
     *   column.
     *   If not provided, defaults to `[0.45, 0.85, 1]` for sequential-like
     *   models and `[0.33, 0.55, 0.67, 1]` for non-sequential like models.
     * @param printFn Print function to use.
     *   It will be called on each line of the summary. You can provide a custom
     *   function in order to capture the string summary. Defaults to `console.log`.
     */
    function printSummary(model, lineLength, positions,
    // tslint:disable-next-line:no-any
    printFn = console.log) {
        const sequentialLike = isModelSequentialLike(model);
        // Header names for different log elements.
        const toDisplay = ['Layer (type)', 'Input Shape', 'Output shape', 'Param #'];
        // Pick layout defaults based on model topology.
        // NOTE(review): these defaults ([0.32, ...] / [0.24, ...]) differ from
        // the values quoted in the doc comment above — confirm which is intended.
        if (sequentialLike) {
            lineLength = lineLength || 90;
            positions = positions || [0.32, 0.61, 0.89, 1];
        }
        else {
            lineLength = lineLength || 115;
            positions = positions || [0.24, 0.48, 0.70, 0.80, 1];
            // Header names for different log elements.
        }
        if (positions[positions.length - 1] <= 1) {
            // `positions` is relative. Convert it to absolute positioning.
            positions = positions.map(p => Math.floor(lineLength * p));
        }
        let relevantNodes;
        if (!sequentialLike) {
            // Functional models additionally show each layer's inbound
            // connections, restricted to the nodes belonging to this graph.
            toDisplay.push('Receives inputs');
            relevantNodes = [];
            for (const depth in model.nodesByDepth) {
                relevantNodes.push(...model.nodesByDepth[depth]);
            }
        }
        printFn('_'.repeat(lineLength));
        printRow(toDisplay, positions, printFn);
        printFn('='.repeat(lineLength));
        const layers = model.layers;
        for (let i = 0; i < layers.length; ++i) {
            if (sequentialLike) {
                printLayerSummary(layers[i], positions, printFn);
            }
            else {
                printLayerSummaryWithConnections(layers[i], positions, relevantNodes, printFn);
            }
            // Separator row: '=' after the last layer, '_' between layers.
            printFn((i === layers.length - 1 ? '=' : '_').repeat(lineLength));
        }
        // tslint:disable-next-line:no-any
        model.checkTrainableWeightsConsistency();
        const trainableCount = countTrainableParams(model);
        const nonTrainableCount = countParamsInWeights(model.nonTrainableWeights);
        printFn(`Total params: ${trainableCount + nonTrainableCount}`);
        printFn(`Trainable params: ${trainableCount}`);
        printFn(`Non-trainable params: ${nonTrainableCount}`);
        printFn('_'.repeat(lineLength));
    }
39907 function countTrainableParams(model) {
39908 let trainableCount;
39909 // tslint:disable:no-any
39910 if (model.collectedTrainableWeights != null) {
39911 trainableCount =
39912 countParamsInWeights(model.collectedTrainableWeights);
39913 }
39914 else {
39915 trainableCount = countParamsInWeights(model.trainableWeights);
39916 }
39917 // tslint:enable:no-any
39918 return trainableCount;
39919 }
39920 function isModelSequentialLike(model) {
39921 let sequentialLike = true;
39922 const nodesByDepth = [];
39923 const nodes = [];
39924 for (const depth in model.nodesByDepth) {
39925 nodesByDepth.push(model.nodesByDepth[depth]);
39926 }
39927 for (const depthNodes of nodesByDepth) {
39928 if (depthNodes.length > 1 ||
39929 depthNodes.length === 1 && depthNodes[0].inboundLayers.length > 1) {
39930 sequentialLike = false;
39931 break;
39932 }
39933 nodes.push(...depthNodes);
39934 }
39935 if (sequentialLike) {
39936 // Search for shared layers.
39937 for (const layer of model.layers) {
39938 let flag = false;
39939 for (const node of layer.inboundNodes) {
39940 if (nodes.indexOf(node) !== -1) {
39941 if (flag) {
39942 sequentialLike = false;
39943 break;
39944 }
39945 else {
39946 flag = true;
39947 }
39948 }
39949 }
39950 if (!sequentialLike) {
39951 break;
39952 }
39953 }
39954 }
39955 return sequentialLike;
39956 }
39957 function printRow(fields, positions,
39958 // tslint:disable-next-line:no-any
39959 printFn = console.log) {
39960 let line = '';
39961 for (let i = 0; i < fields.length; ++i) {
39962 if (i > 0) {
39963 line = line.slice(0, line.length - 1) + ' ';
39964 }
39965 line += fields[i];
39966 line = line.slice(0, positions[i]);
39967 line += ' '.repeat(positions[i] - line.length);
39968 }
39969 printFn(line);
39970 }
39971 /**
39972 * Prints a summary for a single Layer, without connectivity information.
39973 *
39974 * @param layer: Layer instance to print.
39975 */
39976 function printLayerSummary(layer, positions,
39977 // tslint:disable-next-line:no-any
39978 printFn) {
39979 let outputShape;
39980 let inputShape;
39981 try {
39982 inputShape = (layer.inboundNodes.map(x => JSON.stringify(x.inputShapes))).join(',');
39983 }
39984 catch (err) {
39985 inputShape = 'multiple';
39986 }
39987 try {
39988 outputShape = JSON.stringify(layer.outputShape);
39989 }
39990 catch (err) {
39991 outputShape = 'multiple';
39992 }
39993 const name = layer.name;
39994 const className = layer.getClassName();
39995 const fields = [`${name} (${className})`, inputShape,
39996 outputShape, layer.countParams().toString()];
39997 printRow(fields, positions, printFn);
39998 }
39999 /**
40000 * Prints a summary for a single Layer, with connectivity information.
40001 */
40002 function printLayerSummaryWithConnections(layer, positions, relevantNodes,
40003 // tslint:disable-next-line:no-any
40004 printFn) {
40005 let outputShape;
40006 let inputShape;
40007 try {
40008 inputShape = (layer.inboundNodes.map(x => JSON.stringify(x.inputShapes))).join(',');
40009 }
40010 catch (err) {
40011 inputShape = 'multiple';
40012 }
40013 try {
40014 outputShape = JSON.stringify(layer.outputShape);
40015 }
40016 catch (err) {
40017 outputShape = 'multiple';
40018 }
40019 const connections = [];
40020 for (const node of layer.inboundNodes) {
40021 if (relevantNodes != null && relevantNodes.length > 0 &&
40022 relevantNodes.indexOf(node) === -1) {
40023 continue;
40024 }
40025 for (let i = 0; i < node.inboundLayers.length; ++i) {
40026 const inboundLayer = node.inboundLayers[i].name;
40027 const inboundLayerIndex = node.nodeIndices[i];
40028 const inboundTensorIndex = node.tensorIndices[i];
40029 connections.push(`${inboundLayer}[${inboundLayerIndex}][${inboundTensorIndex}]`);
40030 }
40031 }
40032 const name = layer.name;
40033 const className = layer.getClassName();
40034 const firstConnection = connections.length === 0 ? '' : connections[0];
40035 const fields = [
40036 `${name} (${className})`, inputShape,
40037 outputShape, layer.countParams().toString(),
40038 firstConnection
40039 ];
40040 printRow(fields, positions, printFn);
40041 for (let i = 1; i < connections.length; ++i) {
40042 printRow(['', '', '', '', connections[i]], positions, printFn);
40043 }
40044 }
40045
40046 /**
40047 * @license
40048 * Copyright 2018 Google LLC
40049 *
40050 * Use of this source code is governed by an MIT-style
40051 * license that can be found in the LICENSE file or at
40052 * https://opensource.org/licenses/MIT.
40053 * =============================================================================
40054 */
40055 // tslint:enable
40056 /**
40057 * Test whether a value in an array is the name of a LayersModel or Layer.
40058 * @param key The key name that the value is found under. Note that the key
40059 * may not be at the level immediately above the value, if the value is in a
40060 * nested array.
40061 * @param index Index of the value in the Array that it is found in.
40062 * @param value The value object.
40063 * @returns A boolean indicating whether value is a name.
40064 */
40065 function isArrayItemInputOrOutputName(key, index, value) {
40066 return (key === 'inboundNodes' || key === 'outputLayers' ||
40067 key === 'inputLayers') &&
40068 index === 0 && typeof value === 'string';
40069 }
40070 /**
40071 * Convert a Pythonic config object to TypeScript config object.
40072 * @param pythonicConfig The config object to convert.
40073 * @param key Optional key name of the object being converted.
40074 * @returns Result of the conversion.
40075 */
40076 function convertPythonicToTs(pythonicConfig, key) {
40077 if (pythonicConfig === null) {
40078 return null;
40079 }
40080 else if (typeof pythonicConfig === 'string') {
40081 return toCamelCase(pythonicConfig);
40082 }
40083 else if ((typeof pythonicConfig === 'number') ||
40084 (typeof pythonicConfig === 'boolean')) {
40085 return pythonicConfig;
40086 }
40087 else if (pythonicConfig instanceof Array) {
40088 const tsArray = [];
40089 const arrayLength = pythonicConfig.length;
40090 for (let i = 0; i < arrayLength; ++i) {
40091 const item = pythonicConfig[i];
40092 if (isArrayItemInputOrOutputName(key, i, item)) {
40093 tsArray.push(item);
40094 }
40095 else {
40096 tsArray.push(convertPythonicToTs(item, key));
40097 }
40098 }
40099 return tsArray;
40100 }
40101 else {
40102 const tsDict = {};
40103 for (const pythonicKey of Object.keys(pythonicConfig)) {
40104 const pythonicValue = pythonicConfig[pythonicKey];
40105 if (pythonicKey === 'name' && typeof pythonicValue === 'string') {
40106 // Special case the 'name' key with a string value. Name values, such as
40107 // the names of LayersModel and Layer instances, should not undergo the
40108 // camel-case conversion.
40109 tsDict[pythonicKey] = pythonicValue;
40110 }
40111 else {
40112 const tsKey = toCamelCase(pythonicKey);
40113 tsDict[tsKey] = convertPythonicToTs(pythonicValue, tsKey);
40114 }
40115 }
40116 return tsDict;
40117 }
40118 }
40119 /**
40120 * Convert a TypeScript config object to Python config object.
40121 * @param tsConfig The config object to convert.
40122 * @param key Optional key name of the object being converted.
40123 * @returns Result of the conversion.
40124 */
40125 function convertTsToPythonic(tsConfig, key) {
40126 if (tsConfig === null || tsConfig === undefined) {
40127 return null;
40128 }
40129 else if (typeof tsConfig === 'string') {
40130 return toSnakeCase(tsConfig);
40131 }
40132 else if ((typeof tsConfig === 'number') || (typeof tsConfig === 'boolean')) {
40133 return tsConfig;
40134 }
40135 else if (tsConfig instanceof Array) {
40136 const pyArray = [];
40137 const arrayLength = tsConfig.length;
40138 for (let i = 0; i < arrayLength; ++i) {
40139 const item = tsConfig[i];
40140 if (isArrayItemInputOrOutputName(key, i, item)) {
40141 pyArray.push(item);
40142 }
40143 else {
40144 pyArray.push(convertTsToPythonic(item, key));
40145 }
40146 }
40147 return pyArray;
40148 }
40149 else {
40150 const pyDict = {};
40151 for (const tsKey of Object.keys(tsConfig)) {
40152 const tsValue = tsConfig[tsKey];
40153 const pyKey = toSnakeCase(tsKey);
40154 if ((tsKey === 'name' || tsKey === 'className') &&
40155 typeof tsValue === 'string') {
40156 // Special case the 'name' key with a string value. Name values, such as
40157 // the names of LayersModel and Layer instances, should not undergo the
40158 // snake-case conversion.
40159 pyDict[pyKey] = tsValue;
40160 }
40161 else {
40162 pyDict[pyKey] = convertTsToPythonic(tsValue, tsKey);
40163 }
40164 }
40165 return pyDict;
40166 }
40167 }
40168
    /** @license See the LICENSE file. */
    // This code is auto-generated, do not modify this file!
    // Version string of the layers sub-package.
    const version$1 = '3.18.0';
40172
40173 /**
40174 * @license
40175 * Copyright 2018 Google LLC
40176 *
40177 * Use of this source code is governed by an MIT-style
40178 * license that can be found in the LICENSE file or at
40179 * https://opensource.org/licenses/MIT.
40180 * =============================================================================
40181 */
40182 /**
40183 * A Container is a directed acyclic graph of layers.
40184 *
40185 * It is the topological form of a "model". A LayersModel
40186 * is simply a Container with added training routines.
40187 *
40188 */
40189 class Container extends Layer {
40190 constructor(args) {
40191 // No args passed to super's constructor.
40192 super({});
40193 this.containerNodes = new Set();
40194 this.name = args.name;
40195 if (this.name == null) {
40196 const prefix = this.getClassName().toLowerCase();
40197 this.name = getUid(prefix);
40198 }
40199 this.supportsMasking = false;
40200 this.trainable_ = true;
40201 // TODO(michaelterry): Initialize perInputLosses/Updates here.
40202 // Container-specific properties.
40203 if (Array.isArray(args.inputs)) {
40204 this.inputs = args.inputs.slice();
40205 }
40206 else {
40207 this.inputs = [args.inputs];
40208 }
40209 if (Array.isArray(args.outputs)) {
40210 this.outputs = args.outputs.slice();
40211 }
40212 else {
40213 this.outputs = [args.outputs];
40214 }
40215 // Check for redundancy in inputs.
40216 if (unique$1(this.inputs).length !== this.inputs.length) {
40217 throw new ValueError('The list of inputs passed to the model is ' +
40218 'redundant. All inputs should only appear once. Found: ' +
40219 `${this.inputs.map(x => x.name)}`);
40220 }
40221 // Check for redundancy in outputs.
40222 if (unique$1(this.outputs).length !== this.outputs.length) {
40223 console.warn('The list of outputs passed to the model is redundant. ' +
40224 'All outputs should only appear once. Found: ' +
40225 `${this.outputs.map(x => x.name)}`);
40226 }
40227 /*
40228 List of initial layers (1 to 1 mapping with this.inputs, hence the same
40229 layer might appear twice)
40230 */
40231 this.inputLayers = [];
40232 this.inputLayersNodeIndices = [];
40233 this.inputLayersTensorIndices = [];
40234 /*
40235 List of layers (1 to 1 mapping with this.outputs, hence the same layer
40236 might appear twice)
40237 */
40238 this.outputLayers = [];
40239 this.outputLayersNodeIndices = [];
40240 this.outputLayersTensorIndices = [];
40241 /*
40242 All layers in order of horizontal graph traversal. Entries are unique.
40243 Includes input and output layers.
40244 */
40245 this.layers = [];
40246 /*
40247 References to container layers that were constructed internally. We need
40248 these to properly dispose of tensors from nested containers.
40249 */
40250 this.internalContainerRefs = [];
40251 // TODO(michaelterry): Determine if caching still needed with eager
40252 // backend.
40253 /*
40254 This is for performance optimization when calling the Container on new
40255 inputs. Every time the Container is called on a set on input tensors,
40256 we compute the output tensors, output masks and output shapes in one pass,
40257 then cache them here. When one of these outputs is queried later,
40258 we retrieve it from there instead of recomputing it.
40259 */
40260 // this.outputTensorCache = {};
40261 // this.outputShapeCache = {};
40262 // Build this.outputLayers:
40263 for (const x of this.outputs) {
40264 const layer = x.sourceLayer;
40265 const nodeIndex = x.nodeIndex;
40266 const tensorIndex = x.tensorIndex;
40267 this.outputLayers.push(layer);
40268 this.outputLayersNodeIndices.push(nodeIndex);
40269 this.outputLayersTensorIndices.push(tensorIndex);
40270 }
40271 // TODO(michaelterry): Add output mask cache code.
40272 // Build this.inputLayers:
40273 for (const x of this.inputs) {
40274 const layer = x.sourceLayer;
40275 const nodeIndex = x.nodeIndex;
40276 const tensorIndex = x.tensorIndex;
40277 /*
40278 It's supposed to be an input layer, so only one node
40279 and one tensor output.
40280 */
40281 assert$1(nodeIndex === 0, 'input layer has >1 nodes');
40282 assert$1(tensorIndex === 0, 'input layer has >1 tensors');
40283 this.inputLayers.push(layer);
40284 this.inputLayersNodeIndices.push(nodeIndex);
40285 this.inputLayersTensorIndices.push(tensorIndex);
40286 }
40287 // Build this.inputNames and this.outputNames.
40288 this.inputNames = [];
40289 this.outputNames = [];
40290 this.feedInputShapes = [];
40291 this.feedInputNames = [];
40292 this.feedOutputNames = [];
40293 for (let i = 0; i < this.inputLayers.length; i++) {
40294 const layer = this.inputLayers[i];
40295 // Check that layer is an InputLayer.
40296 if (!(layer instanceof InputLayer)) {
40297 throw new TypeError('Input layers to a LayersModel must be InputLayer objects. ' +
40298 `Received inputs: ${args.inputs}. ` +
40299 `Input ${i} (0-based) originates ` +
40300 `from layer type ${layer.getClassName()}.`);
40301 }
40302 this.inputNames.push(layer.name);
40303 this.feedInputShapes.push(layer.batchInputShape);
40304 this.feedInputNames.push(layer.name);
40305 }
40306 for (const layer of this.outputLayers) {
40307 this.outputNames.push(layer.name);
40308 }
40309 this.internalInputShapes = this.inputs.map(x => x.shape);
40310 this.internalOutputShapes = this.outputs.map(x => x.shape);
40311 /*
40312 Container_nodes: set of nodes included in the graph (not all nodes
40313 included in the layers are relevant to the current graph).
40314 */
40315 // ids of all nodes relevant to the Container:
40316 const nodesDepths = {};
40317 // To recover nodes from their ID.
40318 const nodeIDToNode = {};
40319 const layersDepths = {};
40320 // To layers from their ID.
40321 const layerIDToLayer = {};
40322 const layerIndices = {};
40323 const nodesInDecreasingDepth = [];
40324 /**
40325 * Builds a map of the graph of layers.
40326 *
40327 * This recursively updates the map `layerIndices`,
40328 * the list `nodesInDecreasingDepth` and the set `containerNodes`.
40329 *
40330 * @param tensor Some tensor in a graph.
40331 * @param finishedNodes Set of nodes whose subgraphs have been traversed
40332 * completely. Useful to prevent duplicated work.
40333 * @param nodesInProgress Set of nodes that are currently active on the
40334 * recursion stack. Useful to detect cycles.
40335 * @param layer Layer from which `tensor` comes from. If not provided,
40336 * will be obtained from tensor.sourceLayer.
40337 * @param nodeIndex Node index from which `tensor` comes from.
40338 * @param tensorIndex TensorIndex from which `tensor` comes from.
40339 *
40340 * @exception RuntimeError if a cycle is detected.
40341 */
40342 const buildMapOfGraph = (tensor, finishedNodes, nodesInProgress, layer, nodeIndex, tensorIndex) => {
40343 if (layer == null || nodeIndex == null || tensorIndex == null) {
40344 layer = tensor.sourceLayer;
40345 nodeIndex = tensor.nodeIndex;
40346 tensorIndex = tensor.tensorIndex;
40347 }
40348 const node = layer.inboundNodes[nodeIndex];
40349 // Prevent cycles.
40350 if (nodesInProgress.indexOf(node) !== -1) {
40351 throw new RuntimeError(`The tensor ${tensor.name} at layer "${layer.name}" ` +
40352 'is part of a cycle.');
40353 }
40354 // Don't repeat work for shared subgraphs
40355 if (finishedNodes.indexOf(node) !== -1) {
40356 return;
40357 }
40358 // Update containerNodes.
40359 this.containerNodes.add(Container.nodeKey(layer, nodeIndex));
40360 // Store the traversal order for layer sorting.
40361 if (!(layer.id in layerIndices)) {
40362 layerIndices[layer.id] = Object.keys(layerIndices).length;
40363 }
40364 if (nodesInProgress.indexOf(node) === -1) {
40365 nodesInProgress.push(node);
40366 }
40367 // Propagate to all previous tensors connected to this node.
40368 const numInboundLayers = node.inboundLayers.length;
40369 for (let i = 0; i < numInboundLayers; i++) {
40370 const x = node.inputTensors[i];
40371 const layer = node.inboundLayers[i];
40372 const nodeIndex = node.nodeIndices[i];
40373 const tensorIndex = node.tensorIndices[i];
40374 buildMapOfGraph(x, finishedNodes, nodesInProgress, layer, nodeIndex, tensorIndex);
40375 }
40376 finishedNodes.push(node);
40377 while (nodesInProgress.indexOf(node) >= 0) {
40378 nodesInProgress.splice(nodesInProgress.indexOf(node), 1);
40379 }
40380 nodesInDecreasingDepth.push(node);
40381 };
40382 const finishedNodes = [];
40383 const nodesInProgress = [];
40384 for (const x of this.outputs) {
40385 buildMapOfGraph(x, finishedNodes, nodesInProgress);
40386 }
40387 const reversedNodesInDecreasingDepth = nodesInDecreasingDepth.slice().reverse();
40388 for (const node of reversedNodesInDecreasingDepth) {
40389 nodeIDToNode[node.id] = node;
40390 // If the depth is not set, the node has no outbound nodes (depth 0).
40391 if (!(node.id in nodesDepths)) {
40392 nodesDepths[node.id] = 0;
40393 }
40394 let depth = nodesDepths[node.id];
40395 // Update the depth of the corresponding layer
40396 const previousDepth = (layersDepths[node.outboundLayer.id] == null ?
40397 0 :
40398 layersDepths[node.outboundLayer.id]);
40399 /*
40400 If we've seen this layer before at a higher depth, we should use that
40401 depth instead of the node depth. This is necessary for shared layers
40402 that have inputs at different depth levels in the graph.
40403 */
40404 depth = Math.max(depth, previousDepth);
40405 layersDepths[node.outboundLayer.id] = depth;
40406 layerIDToLayer[node.outboundLayer.id] = node.outboundLayer;
40407 nodesDepths[node.id] = depth;
40408 // Update the depth of inbound nodes.
40409 for (let i = 0; i < node.inboundLayers.length; i++) {
40410 const inboundLayer = node.inboundLayers[i];
40411 const nodeIndex = node.nodeIndices[i];
40412 const inboundNode = inboundLayer.inboundNodes[nodeIndex];
40413 const previousDepth = (nodesDepths[inboundNode.id] == null ? 0 :
40414 nodesDepths[inboundNode.id]);
40415 nodesDepths[inboundNode.id] = Math.max(depth + 1, previousDepth);
40416 nodeIDToNode[inboundNode.id] = inboundNode;
40417 }
40418 }
40419 // Build a dict {depth: list of nodes with this depth}
40420 const nodesByDepth = {};
40421 for (const nodeID in nodesDepths) {
40422 const depth = nodesDepths[nodeID];
40423 if (!(depth in nodesByDepth)) {
40424 nodesByDepth[depth] = [];
40425 }
40426 nodesByDepth[depth].push(nodeIDToNode[nodeID]);
40427 }
40428 // Build a dict {depth: list of layers with this depth}
40429 const layersByDepth = {};
40430 for (const layerID in layersDepths) {
40431 const depth = layersDepths[layerID];
40432 if (!(depth in layersByDepth)) {
40433 layersByDepth[depth] = [];
40434 }
40435 layersByDepth[depth].push(layerIDToLayer[layerID]);
40436 }
40437 // Get sorted list of layer depths.
40438 let depthKeys = Object.keys(layersByDepth)
40439 .map(x => parseInt(x, 10))
40440 .sort(reverseNumberCompare);
40441 // Set this.layers and this.layersByDepth.
40442 this.layers = [];
40443 for (const depth of depthKeys) {
40444 const layersForDepth = layersByDepth[depth];
40445 // Container.layers needs to have a deterministic order:
40446 // here we order them by traversal order.
40447 layersForDepth.sort((a, b) => {
40448 const aIndex = layerIndices[a.id];
40449 const bIndex = layerIndices[b.id];
40450 if (aIndex < bIndex) {
40451 return -1;
40452 }
40453 if (aIndex > bIndex) {
40454 return 1;
40455 }
40456 return 0;
40457 });
40458 for (const layer of layersForDepth) {
40459 if (layer instanceof Container) {
40460 this.internalContainerRefs.push(layer);
40461 }
40462 this.layers.push(layer);
40463 }
40464 }
40465 this.layersByDepth = layersByDepth;
40466 // Get sorted list of node depths;
40467 depthKeys = Object.keys(nodesByDepth)
40468 .map(x => parseInt(x, 10))
40469 .sort(reverseNumberCompare);
40470 // Check that all tensors required are computable.
40471 // computable_tensors: all tensors in the graph
40472 // that can be computed from the inputs provided.
40473 const computableTensors = this.inputs.slice();
40474 // To provide a better error msg.
40475 const layersWithCompleteInput = [];
40476 for (const depth of depthKeys) {
40477 for (const node of nodesByDepth[depth]) {
40478 const layer = node.outboundLayer;
40479 if (layer != null) {
40480 for (const x of node.inputTensors) {
40481 if (computableTensors.indexOf(x) === -1) {
40482 throw new RuntimeError(`Graph disconnected: cannot obtain value for tensor ${x}` +
40483 ` at layer "${layer.name}". ` +
40484 'The following previous layers were accessed without ' +
40485 `issue: ${layersWithCompleteInput}`);
40486 }
40487 }
40488 for (const x of node.outputTensors) {
40489 computableTensors.push(x);
40490 }
40491 layersWithCompleteInput.push(layer.name);
40492 }
40493 }
40494 }
40495 // Set this.containerNodes and this.nodesByDepth.
40496 this.nodesByDepth = nodesByDepth;
40497 // Ensure name unicity, which will be crucial for serialization
40498 // (since serialized nodes refer to layers by their name).
40499 const allNames = this.layers.map(x => x.name);
40500 for (const name of allNames) {
40501 const numOccurrences = allNames.filter(x => x === name).length;
40502 if (numOccurrences !== 1) {
40503 throw new RuntimeError(`The name "${name}" is used ${numOccurrences} times ` +
40504 'in the model. All layer names should be unique. Layer names: ' +
40505 JSON.stringify(allNames));
40506 }
40507 }
40508 // Layer parameters.
40509 // The new container starts with a single inbound node
40510 // for its inputs, and no outbound nodes.
40511 // Will be appended to by future calls to apply().
40512 this.outboundNodes = [];
40513 // Will be appended to below, and by future calls to apply().
40514 this.inboundNodes = [];
40515 // Create the node linking internal inputs to internal outputs.
40516 // (This call has side effects.)
40517 // tslint:disable-next-line:no-unused-expression
40518 new Node({
40519 outboundLayer: this,
40520 inboundLayers: [],
40521 nodeIndices: [],
40522 tensorIndices: [],
40523 inputTensors: this.inputs,
40524 outputTensors: this.outputs,
40525 inputMasks: this.inputs.map(x => null),
40526 outputMasks: this.outputs.map(x => null),
40527 inputShapes: this.inputs.map(x => x.shape),
40528 outputShapes: this.outputs.map(x => x.shape)
40529 });
40530 this.built = true;
40531 this._refCount = 1; // The ref count of a container always start at 1.
40532 }
40533 assertNotDisposed() {
40534 if (this._refCount === 0) {
40535 throw new Error(`Container '${this.name}' is already disposed.`);
40536 }
40537 }
40538 /**
40539 * Attempt to dispose a LayersModel's weights.
40540 *
40541 * This method decrease the reference count of the LayersModel object by 1.
40542 *
40543 * A LayersModel is reference-counted. Its reference count is incremented by 1
40544 * when it is first constructed and when it is used as a Layer of another
40545 * LayersModel.
40546 *
40547 * If the reference count of a LayersModel becomes 0, the `dispose` method of
40548 * all its constituent `Layer`s will be called.
40549 *
40550 * Note: If the reference count is greater than 0 after the decrement, the
40551 * `dispose` method of its constituent `Layer`s will *not* be called.
40552 *
40553 * After a LayersModel is disposed, it cannot be used in calls such as
40554 * 'predict`, `evaluate` or `fit` anymore.
40555 *
40556 * @returns A DisposeResult Object with the following fields:
40557 * - refCountAfterDispose: The reference count of the LayersModel after this
40558 * `dispose()` call.
40559 * - numDisposedVariables: Number of `tf.Variable`s (i.e., weights) disposed
40560 * during this `dispose()` call.
40561 * @throws {Error} If the layer is not built yet, or if the LayersModel has
40562 * already been disposed.
40563 */
40564 dispose() {
40565 this.assertNotDisposed();
40566 const result = { refCountAfterDispose: null, numDisposedVariables: 0 };
40567 if (--this._refCount === 0) {
40568 for (const layer of this.layers) {
40569 result.numDisposedVariables += layer.dispose().numDisposedVariables;
40570 }
40571 // Call dispose on each internally created container layer again to ensure
40572 // their refCounts hit zero and their tensors are subsequently deleted.
40573 for (const container of this.internalContainerRefs) {
40574 result.numDisposedVariables += container.dispose().numDisposedVariables;
40575 }
40576 }
40577 result.refCountAfterDispose = this._refCount;
40578 return result;
40579 }
        /** Whether the container (and thus its constituent layers' weights) is trainable. */
        get trainable() {
            return this.trainable_;
        }
40583 set trainable(trainable) {
40584 this.layers.forEach(layer => {
40585 // tslint:disable-next-line:no-any
40586 layer._trainableWeights
40587 .forEach(w => w.trainable = trainable);
40588 });
40589 this.trainable_ = trainable;
40590 }
40591 get trainableWeights() {
40592 // Porting Note: This check below is to prevent errors where the
40593 // _trainableWeights inherited from the parent class (Layer) gets
40594 // inadvertently used.
40595 if (this._trainableWeights.length > 0) {
40596 throw new ValueError('Container instance unexpectedly contains _trainableWeights.' +
40597 'The trainable weights of a Container are a union of the ' +
40598 'trainable weights of its consituent Layers. Its own ' +
40599 '_trainableWeights must remain an empty Array.');
40600 }
40601 if (!this.trainable) {
40602 return [];
40603 }
40604 let weights = [];
40605 for (const layer of this.layers) {
40606 weights = weights.concat(layer.trainableWeights);
40607 }
40608 return weights;
40609 }
40610 get nonTrainableWeights() {
40611 const weights = [];
40612 for (const layer of this.layers) {
40613 weights.push(...layer.nonTrainableWeights);
40614 }
40615 if (!this.trainable) {
40616 const trainableWeights = [];
40617 for (const layer of this.layers) {
40618 trainableWeights.push(...layer.trainableWeights);
40619 }
40620 return trainableWeights.concat(weights);
40621 }
40622 return weights;
40623 }
        /** All weights of the model: the trainable ones first, then the non-trainable ones. */
        get weights() {
            return this.trainableWeights.concat(this.nonTrainableWeights);
        }
        /**
         * Loads all layer weights from a JSON object.
         *
         * Porting Note: HDF5 weight files cannot be directly loaded in JavaScript /
         * TypeScript. The utility script at `scripts/pykeras.py` offers means
         * to convert them into JSON strings compatible with this method.
         * Porting Note: TensorFlow.js Layers supports only loading by name currently.
         *
         * @param weights A JSON mapping weight names to weight values as nested
         *   arrays of numbers, or a `NamedTensorMap`, i.e., a JSON mapping weight
         *   names to `tf.Tensor` objects.
         * @param strict Require that the provided weights exactly match those
         *   required by the container. Default: `true`. Passing `false` means that
         *   extra weights and missing weights will be silently ignored.
         * @throws ValueError on duplicate weight names, on a provided weight with
         *   no matching variable (strict mode), or on unset variables (strict
         *   mode).
         */
        loadWeights(weights, strict = true) {
            // Index every variable of every layer by its original name, so
            // provided values can be matched by name.
            const nameToWeight = {};
            let totalWeightsCount = 0;
            for (const layer of this.layers) {
                for (const weight of layer.weights) {
                    if (nameToWeight[weight.originalName] != null) {
                        throw new ValueError(`Duplicate weight name: ${weight.originalName}`);
                    }
                    nameToWeight[weight.originalName] = weight;
                    totalWeightsCount++;
                }
            }
            const weightValueTuples = [];
            for (const name in weights) {
                // TF 2.2.0 added cell name to the weight name in the format of
                // layer_name/cell_name/weight_name, we need to remove
                // the inner cell name.
                let validatedName = name;
                if (nameToWeight[name] == null) {
                    const tokens = name.split('/');
                    // Drop the second-to-last path component (the cell name).
                    const shortenNameArray = tokens.slice(0, -2).concat([tokens[tokens.length - 1]]);
                    validatedName = shortenNameArray.join('/');
                }
                if (nameToWeight[validatedName] != null) {
                    weightValueTuples.push([nameToWeight[validatedName], weights[name]]);
                }
                else if (strict) {
                    throw new ValueError(`Provided weight data has no target variable: ${name}`);
                }
                // Remove matched entries; whatever remains afterwards is unset.
                delete nameToWeight[validatedName];
            }
            if (strict) {
                // Check that all weights are set.
                const unsetNames = [];
                for (const name in nameToWeight) {
                    unsetNames.push(name);
                }
                if (unsetNames.length > 0) {
                    throw new ValueError(`${unsetNames.length} of ${totalWeightsCount} weights are not set: ` +
                        `${unsetNames}`);
                }
            }
            // Assign all values in a single batched operation.
            batchSetValue(weightValueTuples);
        }
40686 /**
40687 * Util shared between different serialization methods.
40688 * @returns LayersModel config with Keras version information added.
40689 */
40690 updatedConfig() {
40691 const theConfig = this.getConfig();
40692 const modelConfig = {};
40693 modelConfig['className'] = this.getClassName();
40694 modelConfig['config'] = theConfig;
40695 modelConfig['kerasVersion'] = `tfjs-layers ${version$1}`;
40696 // TODO(nielsene): Replace something like K.backend() once
40697 // possible.
40698 modelConfig['backend'] = 'TensorFlow.js';
40699 return modelConfig;
40700 }
40701 /**
40702 * Returns a JSON string containing the network configuration.
40703 *
40704 * To load a network from a JSON save file, use
40705 * models.modelFromJSON(jsonString);
40706 * @param extraJsonArgs Unused in tfjs-layers, maintained for PyKeras
40707 * @param returnString Whether the return value should be stringified
40708 * (default: `true`).
40709 * @returns a JSON string if `returnString` (default), or a JSON object if
40710 * `!returnString`.
40711 */
40712 // tslint:disable-next-line:no-any
40713 toJSON(unused, returnString = true) {
40714 const modelConfig = convertTsToPythonic(this.updatedConfig());
40715 return returnString ? JSON.stringify(modelConfig) : modelConfig;
40716 }
40717 /**
40718 * Call the model on new inputs.
40719 *
40720 * In this case `call` just reapplies all ops in the graph to the new inputs
40721 * (e.g. build a new computational graph from the provided inputs).
40722 *
40723 * @param inputs A tensor or list of tensors.
40724 * @param mask A mask or list of masks. A mask can be either a tensor or null
40725 * (no mask).
40726 *
40727 * @return A tensor if there is a single output, or a list of tensors if there
40728 * are more than one outputs.
40729 */
40730 call(inputs, kwargs) {
40731 return tidy(() => {
40732 inputs = toList(inputs);
40733 const feedDict = new FeedDict();
40734 for (let i = 0; i < this.inputs.length; ++i) {
40735 feedDict.add(this.inputs[i], inputs[i]);
40736 }
40737 return execute(this.outputs, feedDict, kwargs);
40738 });
40739 }
40740 /**
40741 * Computes an output mask tensor.
40742 *
40743 * @param inputs Tensor or list of tensors.
40744 * @param mask Tensor or list of tensors.
40745 *
40746 * @return null or a tensor (or list of tensors, one per output tensor of the
40747 * layer).
40748 */
40749 computeMask(inputs, mask) {
40750 return tidy(() => {
40751 inputs = toList(inputs);
40752 let masks;
40753 if (mask == null) {
40754 masks = pyListRepeat(null, inputs.length);
40755 }
40756 else {
40757 masks = toList(mask);
40758 }
40759 // TODO(michaelterry): Add support for mask caching.
40760 return this.runInternalGraph(inputs, masks)[1];
40761 });
40762 }
        /**
         * Computes the output shape of the layer.
         *
         * Assumes that the layer will be built to match that input shape provided.
         *
         * @param inputShape A shape (tuple of integers) or a list of shape tuples
         *   (one per output tensor of the layer). Shape tuples can include null for
         *   free dimensions, instead of an integer.
         * @returns A single shape, or a list of shapes (one per model output).
         * @throws ValueError if the number of provided shapes does not match the
         *   number of model inputs.
         */
        computeOutputShape(inputShape) {
            const inputShapes = normalizeShapeList(inputShape);
            if (inputShapes.length !== this.inputLayers.length) {
                throw new ValueError(`Invalid inputShape argument ${inputShape}: ` +
                    `model has ${this.inputLayers.length} tensor inputs.`);
            }
            // TODO(michaelterry): Add caching
            // Maps keys of the form `${layerName}_${nodeIndex}_${tensorIndex}`
            // to the computed shape of the corresponding output tensor.
            const layersToOutputShapes = {};
            for (let i = 0; i < inputShapes.length; i++) {
                const layer = this.inputLayers[i];
                const inputShape = inputShapes[i];
                // It's an input layer: computeOutputShape is identity,
                // and there is only one node and one tensor output.
                const shapeKey = layer.name + '_0_0';
                layersToOutputShapes[shapeKey] = inputShape;
            }
            // Visit nodes from the deepest level (inputs) towards depth 0
            // (outputs), so inbound shapes are known before each layer is asked
            // for its output shape.
            const depthKeys = Object.keys(this.nodesByDepth)
                .map(x => parseInt(x, 10))
                .sort(reverseNumberCompare);
            // Iterate over nodes, by depth level.
            if (depthKeys.length > 1) {
                for (const depth of depthKeys) {
                    const nodes = this.nodesByDepth[depth];
                    for (const node of nodes) {
                        // This is always a single layer, never a list.
                        const layer = node.outboundLayer;
                        if (this.inputLayers.map(x => x.id).indexOf(layer.id) !== -1) {
                            // We've already covered the input layers a few lines above.
                            continue;
                        }
                        // Potentially redundant list, same size of node.inputTensors.
                        const inputShapes = [];
                        for (let j = 0; j < node.inboundLayers.length; j++) {
                            const inboundLayer = node.inboundLayers[j];
                            const nodeIndex = node.nodeIndices[j];
                            const tensorIndex = node.tensorIndices[j];
                            const shapeKey = `${inboundLayer.name}_${nodeIndex}_${tensorIndex}`;
                            const inputShape = layersToOutputShapes[shapeKey];
                            inputShapes.push(inputShape);
                        }
                        const outputShape = layer.computeOutputShape(singletonOrArray(inputShapes));
                        const outputShapes = normalizeShapeList(outputShape);
                        // Record each of this node's output shapes under its key.
                        const nodeIndex = layer.inboundNodes.indexOf(node);
                        for (let j = 0; j < outputShapes.length; j++) {
                            const shapeKey = `${layer.name}_${nodeIndex}_${j}`;
                            layersToOutputShapes[shapeKey] = outputShapes[j];
                        }
                    }
                }
            }
            // Read final output shapes from layersToOutputShapes.
            const outputShapes = [];
            const outputShapeKeys = [];
            for (let i = 0; i < this.outputLayers.length; i++) {
                const layer = this.outputLayers[i];
                const nodeIndex = this.outputLayersNodeIndices[i];
                const tensorIndex = this.outputLayersTensorIndices[i];
                const shapeKey = `${layer.name}_${nodeIndex}_${tensorIndex}`;
                outputShapeKeys.push(shapeKey);
            }
            for (let i = 0; i < outputShapeKeys.length; i++) {
                const key = outputShapeKeys[i];
                assert$1(key in layersToOutputShapes);
                outputShapes.push(layersToOutputShapes[key]);
            }
            // TODO(michaelterry): Update cache
            return singletonOrArray(outputShapes);
        }
        /**
         * Computes output tensors for new inputs.
         *
         * Note:
         * - Expects `inputs` to be a list (potentially with 1 element).
         *
         * @param inputs List of tensors
         * @param masks List of masks (tensors or null).
         * @return Three lists: outputTensors, outputMasks, outputShapes
         */
        runInternalGraph(inputs, masks) {
            if (masks == null) {
                // No masks provided: treat every input as unmasked.
                masks = pyListRepeat(null, inputs.length);
            }
            // Dictionary mapping reference tensors to tuples
            // (computed tensor, compute mask)
            // we assume a 1:1 mapping from tensor to mask
            // TODO: raise exception when a `.computeMask()` call
            // does not return a list the same size as `call`
            const tensorMap = {};
            // Seed the map with the concrete values fed for the model inputs.
            for (let i = 0; i < this.inputs.length; ++i) {
                const x = this.inputs[i];
                const y = inputs[i];
                const mask = masks[i];
                tensorMap[x.id] = [y, mask];
            }
            // Visit nodes from the deepest level (inputs) towards depth 0
            // (outputs), so each node's inputs are computed before the node.
            const depthKeys = Object.keys(this.nodesByDepth)
                .map(x => parseInt(x, 10))
                .sort(reverseNumberCompare);
            for (const depth of depthKeys) {
                const nodes = this.nodesByDepth[depth];
                for (const node of nodes) {
                    // This is always a single layer, never a list.
                    const layer = node.outboundLayer;
                    const referenceInputTensors = node.inputTensors;
                    const referenceOutputTensors = node.outputTensors;
                    // If all previous input tensors are available in tensorMap,
                    // then call node.inboundLayer on them.
                    // List of tuples [input, mask]:
                    const computedData = new Array();
                    for (const x of referenceInputTensors) {
                        if (x.id in tensorMap) {
                            computedData.push(tensorMap[x.id]);
                        }
                    }
                    if (computedData.length === referenceInputTensors.length) {
                        // TODO(michaelterry): Add K.name_scope here, if we need it.
                        let kwargs = {};
                        let computedTensors;
                        let computedMasks;
                        let outputTensors;
                        let outputMasks;
                        // call layer
                        if (node.callArgs != null) {
                            kwargs = node.callArgs;
                        }
                        if (computedData.length === 1) {
                            // Single-input layer: call with a bare tensor rather
                            // than a one-element list.
                            const [computedTensor, computedMask] = computedData[0];
                            if (kwargs['mask'] == null) {
                                kwargs['mask'] = computedMask;
                            }
                            outputTensors =
                                toList(layer.call(computedTensor, kwargs));
                            outputMasks = toList(layer.computeMask(computedTensor, computedMask));
                            computedTensors = [computedTensor];
                            computedMasks = [computedMask];
                        }
                        else {
                            // Multi-input layer: call with lists of tensors/masks.
                            computedTensors = computedData.map(x => x[0]);
                            computedMasks = computedData.map(x => x[1]);
                            if (kwargs['mask'] == null) {
                                kwargs['mask'] = computedMasks;
                            }
                            outputTensors =
                                toList(layer.call(computedTensors, kwargs));
                            outputMasks = toList(layer.computeMask(computedTensors, computedMasks));
                        }
                        if (layer.activityRegularizer) {
                            throw new NotImplementedError('LayersModel invocation with concrete Tensor value(s) in the ' +
                                'presence of activity regularizer(s) is not supported yet.');
                        }
                        // TODO(michaelterry): Add model updates and losses
                        // Update tensor map.
                        for (let i = 0; i < referenceOutputTensors.length; ++i) {
                            const x = referenceOutputTensors[i];
                            const y = outputTensors[i];
                            const mask = outputMasks[i];
                            tensorMap[x.id] = [y, mask];
                        }
                    }
                }
            }
            // Collect computed values for the model's designated output tensors.
            const outputTensors = [];
            const outputMasks = [];
            const outputShapes = [];
            for (const x of this.outputs) {
                assert$1(x.id in tensorMap, `Could not compute output ${x.name} : ${x.id}`);
                const [tensor, mask] = tensorMap[x.id];
                outputShapes.push(tensor.shape);
                outputTensors.push(tensor);
                outputMasks.push(mask);
            }
            // TODO(michaelterry): Add support for caches.
            return [outputTensors, outputMasks, outputShapes];
        }
        /**
         * Builds a map of internal node keys to node ordering.
         * Used in serialization, as node orderings may change as unused nodes are
         * dropped. Porting Note: This helper method was pulled out of getConfig to
         * improve readability.
         * @param layers An array of Layers in the model.
         *   NOTE(review): this parameter is currently unused — the method iterates
         *   `this.layers` instead; kept for interface compatibility.
         * @returns Map of Node Keys to index order within the layer.
         */
        buildNodeConversionMap(layers) {
            const nodeConversionMap = {};
            let keptNodes;
            for (const layer of this.layers) {
                // Sub-containers start their node numbering at 1 (their internal
                // node 0 is created in the Container constructor); plain layers
                // start at 0.
                keptNodes = layer instanceof Container ? 1 : 0;
                for (let originalNodeIndex = 0; originalNodeIndex < layer.inboundNodes.length; originalNodeIndex++) {
                    const nodeKey = Container.nodeKey(layer, originalNodeIndex);
                    if (this.containerNodes.has(nodeKey)) {
                        // i.e. we mark it to be saved
                        nodeConversionMap[nodeKey] = keptNodes;
                        keptNodes += 1;
                    }
                }
            }
            return nodeConversionMap;
        }
40969 /**
40970 * Retrieves a layer based on either its name (unique) or index.
40971 *
40972 * Indices are based on order of horizontal graph traversal (bottom-up).
40973 *
40974 * If both `name` and `index` are specified, `index` takes precedence.
40975 *
40976 * @param name Name of layer.
40977 * @param index Index of layer.
40978 * @returns A Layer instance.
40979 * @throws ValueError: In case of invalid layer name or index.
40980 *
40981 * @doc {
40982 * heading: 'Layers',
40983 * subheading: 'Classes',
40984 * namespace: 'layers',
40985 * subclasses: ['LayersModel']
40986 * }
40987 */
40988 getLayer(name, index) {
40989 if (index != null) {
40990 if (this.layers.length <= index) {
40991 throw new ValueError(`Was asked to retrieve layer at index ${index}, but model only ` +
40992 `has ${this.layers.length} layer(s).`);
40993 }
40994 else {
40995 return this.layers[index];
40996 }
40997 }
40998 else {
40999 if (name == null) {
41000 throw new ValueError('Provide either a layer name or layer index');
41001 }
41002 }
41003 for (const layer of this.layers) {
41004 if (layer.name === name) {
41005 return layer;
41006 }
41007 }
41008 throw new ValueError(`No such layer: ${name}`);
41009 }
41010 /**
41011 * Retrieves the Container's current loss values.
41012 *
41013 * Used for regularizers during training.
41014 */
41015 calculateLosses() {
41016 // Porting Node: This is an augmentation to Container.loss in PyKeras.
41017 // In PyKeras, Container.loss returns symbolic tensors. Here a concrete
41018 // Tensor (specifically Scalar) values are returned. This is due to the
41019 // imperative backend.
41020 return tidy(() => {
41021 const losses = [];
41022 for (const layer of this.layers) {
41023 for (let nodeIndex = 0; nodeIndex < layer.inboundNodes.length; ++nodeIndex) {
41024 const nodeKey = Container.nodeKey(layer, nodeIndex);
41025 if (this.containerNodes.has(nodeKey)) {
41026 losses.push(...layer.calculateLosses());
41027 }
41028 }
41029 }
41030 // TODO(cais): Add any unconditional model-level losses?
41031 return losses;
41032 });
41033 }
        /**
         * Returns the serializable config of this container: its name, the config
         * of every constituent layer (with filtered inbound-node topology), and
         * the model's input/output layer references.
         */
        getConfig() {
            const config = { name: this.name };
            // Build a map from layer unique name (self._node_key)
            // to the index of the nodes that are saved in the config.
            // Only nodes in container_nodes are saved.
            const nodeConversionMap = this.buildNodeConversionMap(this.layers);
            // Serialize and save the layers in layerConfigs
            const layerConfigs = [];
            for (const layer of this.layers) {
                const layerClassName = layer.getClassName();
                const layerConfig = layer.getConfig();
                const filteredInboundNodes = [];
                for (let originalNodeIndex = 0; originalNodeIndex < layer.inboundNodes.length; originalNodeIndex++) {
                    const node = layer.inboundNodes[originalNodeIndex];
                    const nodeKey = Container.nodeKey(layer, originalNodeIndex);
                    let kwargs = {};
                    if (this.containerNodes.has(nodeKey)) {
                        // The node is relevant to the model:
                        // add to filteredInboundNodes.
                        if (node.callArgs) {
                            // Only keep callArgs that survive JSON serialization;
                            // otherwise warn and drop them.
                            try {
                                JSON.stringify(node.callArgs);
                                kwargs = node.callArgs;
                            }
                            catch (err) {
                                console.warn(`Layer ${layer.name} was passed ` +
                                    `non-serializable keyword arguments: ` +
                                    `${node.callArgs}. They will not be included ` +
                                    `in the serialized model (and thus will be ` +
                                    `missing at deserialization time).`);
                                kwargs = {};
                            }
                        }
                        if (node.inboundLayers.length > 0) {
                            const nodeData = [];
                            for (let i = 0; i < node.inboundLayers.length; i++) {
                                const inboundLayer = node.inboundLayers[i];
                                const nodeIndex = node.nodeIndices[i];
                                const tensorIndex = node.tensorIndices[i];
                                // Note: this `nodeKey` (for the inbound layer)
                                // intentionally shadows the outer one.
                                const nodeKey = Container.nodeKey(inboundLayer, nodeIndex);
                                // Remap original node indices to the post-filter
                                // ordering; default to 0 when not in the map.
                                let newNodeIndex = nodeConversionMap[nodeKey];
                                if (newNodeIndex == null) {
                                    newNodeIndex = 0;
                                }
                                nodeData.push([inboundLayer.name, newNodeIndex, tensorIndex, kwargs]);
                            }
                            filteredInboundNodes.push(nodeData);
                        }
                    }
                }
                const dict = {};
                dict['name'] = layer.name;
                dict['className'] = layerClassName;
                dict['config'] = layerConfig;
                dict['inboundNodes'] = filteredInboundNodes;
                layerConfigs.push(dict);
            }
            config['layers'] = layerConfigs;
            // Gather info about inputs and outputs
            const modelInputs = [];
            for (let i = 0; i < this.inputLayers.length; i++) {
                const layer = this.inputLayers[i];
                const nodeIndex = this.inputLayersNodeIndices[i];
                const nodeKey = Container.nodeKey(layer, nodeIndex);
                // Skip nodes that are not part of this container's graph.
                if (!this.containerNodes.has(nodeKey)) {
                    continue;
                }
                let newNodeIndex = nodeConversionMap[nodeKey];
                if (newNodeIndex === null || newNodeIndex === undefined) {
                    newNodeIndex = 0;
                }
                const tensorIndex = this.inputLayersTensorIndices[i];
                modelInputs.push([layer.name, newNodeIndex, tensorIndex]);
            }
            config['inputLayers'] = modelInputs;
            const modelOutputs = [];
            for (let i = 0; i < this.outputLayers.length; i++) {
                const layer = this.outputLayers[i];
                const nodeIndex = this.outputLayersNodeIndices[i];
                const nodeKey = Container.nodeKey(layer, nodeIndex);
                // Skip nodes that are not part of this container's graph.
                if (!this.containerNodes.has(nodeKey)) {
                    continue;
                }
                let newNodeIndex = nodeConversionMap[nodeKey];
                if (newNodeIndex === null || newNodeIndex === undefined) {
                    newNodeIndex = 0;
                }
                const tensorIndex = this.outputLayersTensorIndices[i];
                modelOutputs.push([layer.name, newNodeIndex, tensorIndex]);
            }
            config['outputLayers'] = modelOutputs;
            return config;
        }
        /**
         * Instantiates a LayersModel from its config (output of `get_config()`).
         *
         * The config describes a graph of layers: each layer's own config plus
         * the inbound nodes that connect it to outputs of other layers.
         * Reconstruction happens in two phases: (1) deserialize every layer and
         * queue its node data; (2) repeatedly replay the queued node data,
         * calling each layer on its input tensors once they exist, until the
         * whole graph has been rebuilt.
         *
         * @param cls the class to create
         * @param config LayersModel config dictionary.
         * @param customObjects An optional dictionary of custom objects.
         * @param fastWeightInit Optional flag to use fast weight initialization
         *   during deserialization. This is applicable to cases in which
         *   the initialization will be immediately overwritten by loaded weight
         *   values. Default: `false`.
         * @returns A LayersModel instance.
         * @throws ValueError: In case of improperly formatted config dict.
         */
        /** @nocollapse */
        static fromConfig(cls, config, customObjects = {}, fastWeightInit = false) {
            // Layer instances created during the graph reconstruction process,
            // keyed by layer name.
            const createdLayers = {};
            // Dictionary mapping layer names to node data that specifies a
            // layer call. It acts as a queue that maintains any unprocessed
            // layer call until it becomes possible to process it
            // (i.e. until the input tensors to the call all exist).
            const unprocessedNodes = {};
            // Enqueue `nodeData` to be (re-)tried later for `layer`.
            function addUnprocessedNode(layer, nodeData) {
                if (!(layer.name in unprocessedNodes)) {
                    unprocessedNodes[layer.name] = [nodeData];
                }
                else {
                    unprocessedNodes[layer.name].push(nodeData);
                }
            }
            // Attempt to call `layer` on the input tensors described by
            // `nodeData`. If any inbound layer or inbound node does not exist
            // yet, re-enqueue the node data and return without calling.
            function processNode(layer, nodeData) {
                const inputTensors = [];
                let kwargs;
                for (const inputData of nodeData) {
                    // Each entry is [layerName, nodeIndex, tensorIndex,
                    // optional kwargs].
                    const inboundLayerName = inputData[0];
                    const inboundNodeIndex = inputData[1];
                    const inboundTensorIndex = inputData[2];
                    kwargs = inputData[3] == null ?
                        {} :
                        inputData[3];
                    if (!(inboundLayerName in createdLayers)) {
                        addUnprocessedNode(layer, nodeData);
                        return;
                    }
                    const inboundLayer = createdLayers[inboundLayerName];
                    if (inboundLayer.inboundNodes.length <= inboundNodeIndex) {
                        addUnprocessedNode(layer, nodeData);
                        return;
                    }
                    const inboundNode = inboundLayer.inboundNodes[inboundNodeIndex];
                    inputTensors.push(inboundNode.outputTensors[inboundTensorIndex]);
                }
                // Call layer on its inputs, thus creating the node
                // and building the layer if needed.
                // Note: This has Eager vs Graph Implications.
                if (inputTensors.length > 0) {
                    layer.apply(singletonOrArray(inputTensors), kwargs); // was ** kwargs
                }
            }
            /**
             * Deserialize a layer, then enqueue its node data for later calls.
             * @param layerData: layer config dict.
             * @throws ValueError: In case of improperly formatted `layer_data`
             *   dict.
             */
            function processLayer(layerData) {
                const layerName = layerData['name'];
                // Instantiate layer.
                const layer = deserialize(layerData, config['customObjects'] != null ?
                    config['customObjects'] :
                    {});
                layer.setFastWeightInitDuringBuild(fastWeightInit);
                createdLayers[layerName] = layer;
                // Gather layer inputs.
                const inboundNodesData = layerData['inboundNodes'];
                inboundNodesData.forEach(nodeData => {
                    if (!(nodeData instanceof Array)) {
                        throw new ValueError(`Corrupted configuration, expected array for nodeData: ${nodeData}`);
                    }
                    // We don't process nodes (i.e. make layer calls)
                    // on the fly because the inbound node may not yet exist,
                    // in case of layer shared at different topological depths
                    // (e.g. a model such as A(B(A(B(x)))))
                    addUnprocessedNode(layer, nodeData);
                });
            }
            // First, we create all layers and enqueue nodes to be processed.
            const name = config['name'];
            const layersFromConfig = config['layers'];
            for (const layerData of layersFromConfig) {
                processLayer(layerData);
            }
            // Then we process nodes in order of layer depth.
            // Nodes that cannot yet be processed (if the inbound node
            // does not yet exist) are re-enqueued, and the process
            // is repeated until all nodes are processed.
            while (!isObjectEmpty(unprocessedNodes)) {
                for (const layerData of layersFromConfig) {
                    const layer = createdLayers[layerData['name']];
                    if (layer.name in unprocessedNodes) {
                        // Detach the pending list before replaying it, so nodes
                        // re-enqueued by processNode() are kept for the next
                        // sweep instead of being mutated mid-iteration.
                        const currentUnprocessedNodesForLayer = unprocessedNodes[layer.name];
                        delete unprocessedNodes[layer.name];
                        for (const nodeData of currentUnprocessedNodesForLayer) {
                            processNode(layer, nodeData);
                        }
                    }
                }
            }
            // Finally, resolve the model's input and output tensors from the
            // recorded [layerName, nodeIndex, tensorIndex] triples.
            const inputTensors = [];
            const outputTensors = [];
            const inputLayersFromConfig = config['inputLayers'];
            for (const layerData of inputLayersFromConfig) {
                const layerName = layerData[0];
                const nodeIndex = layerData[1];
                const tensorIndex = layerData[2];
                assert$1(layerName in createdLayers);
                const layer = createdLayers[layerName];
                const layerOutputTensors = layer.inboundNodes[nodeIndex].outputTensors;
                inputTensors.push(layerOutputTensors[tensorIndex]);
            }
            const outputLayersFromConfig = config['outputLayers'];
            for (const layerData of outputLayersFromConfig) {
                const layerName = layerData[0];
                const nodeIndex = layerData[1];
                const tensorIndex = layerData[2];
                assert$1(layerName in createdLayers);
                const layer = createdLayers[layerName];
                const layerOutputTensors = layer.inboundNodes[nodeIndex].outputTensors;
                outputTensors.push(layerOutputTensors[tensorIndex]);
            }
            return new cls({ inputs: inputTensors, outputs: outputTensors, name });
        }
41260 /**
41261 * Determine whether the container is stateful.
41262 *
41263 * Porting Note: this is the equivalent of the stateful @property of
41264 * the Container class in PyKeras.
41265 */
41266 get stateful() {
41267 // Porting Note: This check is to prevent inadvertent setting of the
41268 // _stateful property of the Container instance.
41269 if (this._stateful) {
41270 throw new ValueError('Container instance unexpectedly has _stateful = true. The ' +
41271 'statefulness of a Container is determined by the Layers it ' +
41272 'contains. Its _stateful property must remain the default false.');
41273 }
41274 for (const layer of this.layers) {
41275 if (layer.stateful) {
41276 return true;
41277 }
41278 }
41279 return false;
41280 }
41281 /**
41282 * Reset the state of all stateful constituent layers (if any).
41283 *
41284 * Examples of stateful layers include RNN layers whose `stateful` property
41285 * is set as `true`.
41286 */
41287 resetStates() {
41288 tidy(() => {
41289 this.layers.forEach(layer => {
41290 // tslint:disable:no-any
41291 if (layer.stateful) {
41292 layer.resetStates();
41293 }
41294 // tslint:enable:no-any
41295 });
41296 });
41297 }
41298 }
41299
41300 /**
41301 * @license
41302 * Copyright 2018 Google LLC
41303 *
41304 * Use of this source code is governed by an MIT-style
41305 * license that can be found in the LICENSE file or at
41306 * https://opensource.org/licenses/MIT.
41307 * =============================================================================
41308 */
41309 function standardizeSampleOrClassWeights(xWeight, outputNames, weightType) {
41310 const numOutputs = outputNames.length;
41311 if (xWeight == null || (Array.isArray(xWeight) && xWeight.length === 0)) {
41312 return outputNames.map(name => null);
41313 }
41314 if (numOutputs === 1) {
41315 if (Array.isArray(xWeight) && xWeight.length === 1) {
41316 return xWeight;
41317 }
41318 else if (typeof xWeight === 'object' && outputNames[0] in xWeight) {
41319 return [xWeight[outputNames[0]]];
41320 }
41321 else {
41322 return [xWeight];
41323 }
41324 }
41325 if (Array.isArray(xWeight)) {
41326 if (xWeight.length !== numOutputs) {
41327 throw new Error(`Provided ${weightType} is an array of ${xWeight.length} ` +
41328 `element(s), but the model has ${numOutputs} outputs. ` +
41329 `Make sure a set of weights is provided for each model output.`);
41330 }
41331 return xWeight;
41332 }
41333 else if (typeof xWeight === 'object' && Object.keys(xWeight).length > 0 &&
41334 typeof xWeight[Object.keys(xWeight)[0]] ===
41335 'object') {
41336 const output = [];
41337 outputNames.forEach(outputName => {
41338 if (outputName in xWeight) {
41339 output.push(xWeight[outputName]);
41340 }
41341 else {
41342 output.push(null);
41343 }
41344 });
41345 return output;
41346 }
41347 else {
41348 throw new Error(`The model has multiple (${numOutputs}) outputs, ` +
41349 `so ${weightType} must be either an array with ` +
41350 `${numOutputs} elements or an object with ${outputNames} keys. ` +
41351 `Provided ${weightType} not understood: ${JSON.stringify(xWeight)}`);
41352 }
41353 }
    /**
     * Standardize class weighting objects.
     *
     * This function takes a single class-weighting object, an array of them,
     * or a map from output name to class-weighting object. It compares it to
     * the output name(s) of the model, based on which it outputs an array of
     * class-weighting objects whose length matches the number of outputs.
     *
     * @param classWeight Input class-weighting object(s).
     * @param outputNames All output name(s) of the model.
     * @return An array of class-weighting objects. The length of the array
     *   matches the model's number of outputs.
     */
    function standardizeClassWeights(classWeight, outputNames) {
        return standardizeSampleOrClassWeights(classWeight, outputNames, 'classWeight');
    }
    /**
     * Standardize sample weighting objects; see
     * `standardizeSampleOrClassWeights` for the accepted forms.
     *
     * NOTE(review): the parameter is named `classWeight` but actually carries
     * the sample-weight specification — presumably a copy-paste from
     * `standardizeClassWeights`; renaming would be cosmetic only.
     *
     * @param classWeight Input sample-weighting object(s).
     * @param outputNames All output name(s) of the model.
     * @return An array of sample-weighting objects, one per model output.
     */
    function standardizeSampleWeights(classWeight, outputNames) {
        return standardizeSampleOrClassWeights(classWeight, outputNames, 'sampleWeight');
    }
41373 /**
41374 * Standardize by-sample and/or by-class weights for training.
41375 *
41376 * Note that this function operates on one model output at a time. For a model
41377 * with multiple outputs, you must call this function multiple times.
41378 *
41379 * @param y The target tensor that the by-sample and/or by-class weight is for.
41380 * The values of y are assumed to encode the classes, either directly
41381 * as an integer index, or as one-hot encoding.
41382 * @param sampleWeight By-sample weights.
41383 * @param classWeight By-class weights: an object mapping class indices
41384 * (integers) to a weight (float) to apply to the model's loss for the
41385 * samples from this class during training. This can be useful to tell the
41386 * model to "pay more attention" to samples from an under-represented class.
41387 * @param sampleWeightMode The mode for the sample weights.
41388 * @return A Promise of weight tensor, of which the size of the first dimension
41389 * matches that of `y`.
41390 */
41391 async function standardizeWeights(y, sampleWeight, classWeight, sampleWeightMode) {
41392 if (sampleWeight != null || sampleWeightMode != null) {
41393 // TODO(cais): Once 'temporal' mode is implemented, document it in the doc
41394 // string.
41395 throw new Error('Support sampleWeight is not implemented yet');
41396 }
41397 if (classWeight != null) {
41398 // Apply class weights per sample.
41399 const yClasses = tidy(() => {
41400 if (y.shape.length === 1) {
41401 // Assume class indices.
41402 return clone(y);
41403 }
41404 else if (y.shape.length === 2) {
41405 if (y.shape[1] > 1) {
41406 // Assume one-hot encoding of classes.
41407 const axis = 1;
41408 return argMax(y, axis);
41409 }
41410 else if (y.shape[1] === 1) {
41411 // Class index.
41412 return reshape(y, [y.shape[0]]);
41413 }
41414 else {
41415 throw new Error(`Encountered unexpected last-dimension size (${y.shape[1]}) ` +
41416 `during handling of class weights. The size is expected to be ` +
41417 `>= 1.`);
41418 }
41419 }
41420 else {
41421 throw new Error(`Unexpected rank of target (y) tensor (${y.rank}) during ` +
41422 `handling of class weights. The rank is expected to be 1 or 2.`);
41423 }
41424 });
41425 const yClassIndices = Array.from(await yClasses.data());
41426 dispose(yClasses);
41427 const classSampleWeight = [];
41428 yClassIndices.forEach(classIndex => {
41429 if (classWeight[classIndex] == null) {
41430 throw new Error(`classWeight must contain all classes in the training data. ` +
41431 `The class ${classIndex} exists in the data but not in ` +
41432 `classWeight`);
41433 }
41434 else {
41435 classSampleWeight.push(classWeight[classIndex]);
41436 }
41437 });
41438 return tensor1d(classSampleWeight, 'float32');
41439 }
41440 else {
41441 return null;
41442 }
41443 }
    /**
     * Apply per-sample weights on the loss values from a number of samples.
     *
     * @param losses Loss tensor of shape `[batchSize]`.
     * @param sampleWeights Per-sample weight tensor of shape `[batchSize]`.
     * @returns Tensor of the same shape as `losses`: the element-wise product
     *   of `losses` and `sampleWeights`.
     */
    function computeWeightedLoss$1(losses, sampleWeights) {
        return mul(losses, sampleWeights);
    }
41454
41455 /**
41456 * @license
41457 * Copyright 2018 Google LLC
41458 *
41459 * Use of this source code is governed by an MIT-style
41460 * license that can be found in the LICENSE file or at
41461 * https://opensource.org/licenses/MIT.
41462 * =============================================================================
41463 */
41464 // Default batch size used during tensor-based validation.
41465 const DEFAULT_VALIDATION_BATCH_SIZE = 32;
41466 /**
41467 * Standardize the output of a dataset iterator for use by
41468 * LayersModel.fitDataset().
41469 *
41470 * @param model: A `tf.LayersModel` object.
41471 * @param iteratorOut The output of a dataset iterator. It is required to be
41472 * an object of the form `{xs: TensorOrArrayOrMap, ys:
41473 * TensorOrArrayOrMap}`, where `TensorOrArrayOrMap` is a single `tf.Tensor`,
41474 * a `tf.Tensor[]`, or a flat map from string names to `tf.Tensor`s.
41475 * @returns A flat array of `tf.Tensor` objects: the input `tf.Tensor`s
41476 * followed by the target `tf.Tensor`s. When `tf.Tensor`s are provided
41477 * as a map, the order in the resulting array is taken from the `inputNames`
41478 * and `outputNames` of the model.
41479 */
41480 function standardizeDataIteratorOutput(
41481 // Type `model` as `any` here to avoid circular dependency w/
41482 // training.ts.
41483 // tslint:disable-next-line:no-any
41484 model, iteratorOut) {
41485 let xs;
41486 let ys;
41487 const iteratorOutObj = iteratorOut;
41488 xs = iteratorOutObj['xs'];
41489 ys = iteratorOutObj['ys'];
41490 assert(xs != null && ys != null, () => 'A Dataset iterator for fitDataset() is expected to generate ' +
41491 'objects of the form `{xs: xVal, ys: yVal}`, where the two ' +
41492 'values may be `tf.Tensor`, an array of Tensors, or a map of ' +
41493 'string to Tensor. The provided Dataset instead generates ' +
41494 `${iteratorOut}`);
41495 const flattenedXs = flattenTensorOrArrayOrMap('input', model.inputNames, xs);
41496 const flattenedYs = flattenTensorOrArrayOrMap('output', model.outputNames, ys);
41497 const batchSize = flattenedXs[0].shape[0];
41498 assert(flattenedXs.length === model.inputs.length, () => `LayersModel has ${model.inputs.length} inputs, but the dataset ` +
41499 `provides ${flattenedXs.length} inputs. (Expected input keys: ` +
41500 `${JSON.stringify(model.inputNames)})`);
41501 assert(flattenedYs.length === model.outputs.length, () => `LayersModel has ${model.outputs.length} outputs, but the dataset ` +
41502 `provides ${flattenedYs.length} outputs. (Expected output keys: ` +
41503 `${JSON.stringify(model.outputNames)})`);
41504 for (let xIndex = 0; xIndex < flattenedXs.length; xIndex++) {
41505 assert(flattenedXs[xIndex].shape[0] === batchSize, () => `Batch size mismatch: input ` +
41506 `${model.inputNames[xIndex]} has ${flattenedXs[xIndex].shape[0]}; ` +
41507 `expected ${batchSize} based on input ${model.inputNames[0]}.`);
41508 }
41509 for (let yIndex = 0; yIndex < flattenedYs.length; yIndex++) {
41510 assert(flattenedYs[yIndex].shape[0] === batchSize, () => `Batch size mismatch: output ` +
41511 `${model.outputNames[yIndex]} has ${flattenedYs[yIndex].shape[0]}; ` +
41512 `expected ${batchSize} based on input ${model.inputNames[0]}.`);
41513 }
41514 return { xs: flattenedXs, ys: flattenedYs };
41515 }
41516 function flattenTensorOrArrayOrMap(inputOrOutput, names, values) {
41517 if (values instanceof Tensor) {
41518 return [values];
41519 }
41520 else if (Array.isArray(values)) {
41521 assert(values.length === names.length, () => `Received an array of ${values.length} Tensors, but expected ${names.length} to match the ${inputOrOutput} keys ${names}.`);
41522 return values;
41523 }
41524 else {
41525 const result = [];
41526 // Check that all the required keys are available.
41527 for (const name of names) {
41528 if (values[name] == null) {
41529 throw new ValueError(`The feature data generated by the dataset lacks the required ` +
41530 `${inputOrOutput} key '${name}'.`);
41531 }
41532 result.push(values[name]);
41533 }
41534 return result;
41535 }
41536 }
41537 function standardizeTensorValidationData(data) {
41538 if (data.length === 3) {
41539 throw new NotImplementedError('Validation with sample weights is not implemented yet.');
41540 }
41541 return { xs: data[0], ys: data[1] };
41542 }
    /**
     * Train a model on data yielded by a dataset (backing implementation of
     * `LayersModel.fitDataset()`).
     *
     * Validates the config, then runs `args.epochs` epochs. Within each epoch
     * it pulls batches from the dataset iterator, invokes the model's train
     * function on each batch, and fires the configured callbacks. When
     * `args.batchesPerEpoch` is given, one iterator is shared across epochs
     * and each epoch consumes exactly that many batches; otherwise a fresh
     * iterator is created per epoch and the epoch ends when it is exhausted.
     *
     * @param model The model to train (typed loosely to avoid a circular
     *   dependency with training.ts).
     * @param dataset A Dataset object yielding `{xs, ys}` batches.
     * @param args Training configuration (epochs, callbacks, validation, ...).
     * @returns The model's `History` object after training completes.
     */
    async function fitDataset(
    // Type `model` as `any` here to avoid circular dependency w/
    // training.ts.
    // tslint:disable-next-line:no-any
    model, dataset, args) {
        const hasBatchesPerEpoch = args.batchesPerEpoch != null;
        // Up-front validation of the model state and config values.
        assert(model.optimizer != null, () => 'You must compile a model before training/testing. Use ' +
            'LayersModel.compile(modelCompileConfig).');
        assert(args != null, () => `For fitDataset(), the 2nd argument (config) is required, ` +
            `but it is not provided in this call.`);
        assert(args.epochs != null && args.epochs > 0 && Number.isInteger(args.epochs), () => `For fitDataset(), config.epochs is expected to be a positive ` +
            `integer, but got ${args.epochs}`);
        assert(!hasBatchesPerEpoch ||
            (args.batchesPerEpoch > 0 && Number.isInteger(args.batchesPerEpoch)), () => `For fitDataset(), config.batchesPerEpoch is expected to be a ` +
            `positive integer if specified, but got ${args.batchesPerEpoch}`);
        assert(
        // tslint:disable-next-line:no-any
        args['validationSplit'] == null, () => '`validationSplit` is not supported by `fitDataset()`. ' +
            'Use validationData instead.');
        // Reentrancy guard: only one fit()/fitDataset() call may run at a time.
        if (model.isTraining) {
            throw new Error('Cannot start training because another fit() call is ongoing.');
        }
        model.isTraining = true;
        try {
            const doValidation = args.validationData != null;
            let valXs;
            let valYs;
            if (doValidation) {
                if (isDatasetObject(args.validationData)) {
                    // Dataset-based validation: only validate the batch count.
                    assert(args.validationBatches == null ||
                        (args.validationBatches > 0 &&
                            Number.isInteger(args.validationBatches)), () => `For fitDataset() with dataset-based validation, ` +
                        `config.validationBatches is expected not to be provided, ` +
                        `or to be a positive integer, ` +
                        `but got ${args.validationBatches}`);
                }
                else {
                    // Tensor-based validation data: `[xs, ys]` (sample weights
                    // are not supported).
                    const validationData = standardizeTensorValidationData(args.validationData);
                    valXs = validationData.xs;
                    valYs = validationData.ys;
                }
            }
            const trainFunction = model.makeTrainFunction();
            const outLabels = model.getDedupedMetricsNames();
            // With validation, callbacks also receive `val_`-prefixed metrics.
            let callbackMetrics;
            if (doValidation) {
                callbackMetrics =
                    outLabels.slice().concat(outLabels.map(n => 'val_' + n));
            }
            else {
                callbackMetrics = outLabels.slice();
            }
            const callbacks = standardizeCallbacks(args.callbacks, args.yieldEvery);
            const verbose = args.verbose == null ? 1 : args.verbose;
            const { callbackList, history } = configureCallbacks(callbacks, verbose, args.epochs, null, null, getStepsPerEpoch(dataset, args), null, // Batch size determined by the dataset itself.
            doValidation, callbackMetrics);
            callbackList.setModel(model);
            model.history = history;
            await callbackList.onTrainBegin();
            model.stopTraining_ = false;
            let epoch = args.initialEpoch == null ? 0 : args.initialEpoch;
            let dataIterator = await dataset.iterator();
            while (epoch < args.epochs) {
                const epochLogs = {};
                await callbackList.onEpochBegin(epoch);
                let stepsDone = 0;
                let batchIndex = 0;
                if (!hasBatchesPerEpoch) {
                    // Without batchesPerEpoch, each epoch consumes a fresh
                    // iterator until the dataset is exhausted.
                    dataIterator = await dataset.iterator();
                }
                while (hasBatchesPerEpoch ? stepsDone < args.batchesPerEpoch : true) {
                    const iteratorOut = await dataIterator.next();
                    // If `batchesPerEpoch` is specified, the dataset should not
                    // be exhausted until all epochs are done.
                    if (hasBatchesPerEpoch && iteratorOut.done) {
                        console.warn('You provided `batchesPerEpoch` as ' +
                            `${args.batchesPerEpoch}, ` +
                            'but your dataset iterator ran out of data after ' +
                            `${stepsDone} batches; ` +
                            'interrupting training. Make sure that your ' +
                            'dataset can generate at least `batchesPerEpoch * epochs` ' +
                            'batches (in this case, ' +
                            `${args.batchesPerEpoch * args.epochs} batches). ` +
                            'You may need to use the repeat() function when building ' +
                            'your dataset.');
                        break;
                    }
                    if (iteratorOut.value != null) {
                        const { xs, ys } = standardizeDataIteratorOutput(model, iteratorOut.value);
                        const batchLogs = {};
                        batchLogs['batch'] = batchIndex;
                        batchLogs['size'] = xs[0].shape[0];
                        await callbackList.onBatchBegin(batchIndex, batchLogs);
                        // Convert per-class weights (if any) to per-sample
                        // weights for each model output.
                        const sampleWeights = [];
                        if (args.classWeight != null) {
                            const standardClassWeights = standardizeClassWeights(args.classWeight, model.outputNames);
                            for (let i = 0; i < standardClassWeights.length; ++i) {
                                sampleWeights.push(await standardizeWeights(ys[i], null, standardClassWeights[i]));
                            }
                        }
                        // Train on batch.
                        const ins = xs.concat(ys).concat(sampleWeights);
                        const outs = trainFunction(ins);
                        dispose(ins);
                        // Record each metric in the batch logs; keep() prevents
                        // the metric tensors from being reclaimed before
                        // onBatchEnd has observed them.
                        for (let i = 0; i < outLabels.length; ++i) {
                            const label = outLabels[i];
                            const out = outs[i];
                            batchLogs[label] = out;
                            keep(out);
                        }
                        await callbackList.onBatchEnd(batchIndex, batchLogs);
                        disposeTensorsInLogs(batchLogs);
                        batchIndex++;
                        stepsDone++;
                    }
                    if (hasBatchesPerEpoch ? stepsDone >= args.batchesPerEpoch :
                        iteratorOut.done) {
                        // Epoch finished. Perform validation.
                        if (doValidation) {
                            let valOuts;
                            if (isDatasetObject(args.validationData)) {
                                valOuts = toList(await model.evaluateDataset(args.validationData, { batches: args.validationBatches }));
                            }
                            else {
                                valOuts = toList(model.evaluate(valXs, valYs, {
                                    batchSize: args.validationBatchSize == null ?
                                        DEFAULT_VALIDATION_BATCH_SIZE :
                                        args.validationBatchSize,
                                    verbose: 0
                                }));
                            }
                            for (let i = 0; i < model.metricsNames.length; ++i) {
                                epochLogs[`val_${model.metricsNames[i]}`] = valOuts[i];
                            }
                        }
                        // Call `break` to exit one epoch loop after validation
                        // is done. If config.batchesPerEpoch is specified, an
                        // epoch while loop will stop when
                        // `stepsDone >= config.batchesPerEpoch`. When
                        // config.batchesPerEpoch is not provided, the following
                        // `break` is required to exit the while loop after the
                        // dataset is exhausted.
                        break;
                    }
                    if (model.stopTraining_) {
                        break;
                    }
                }
                await callbackList.onEpochEnd(epoch, epochLogs);
                epoch++;
                if (model.stopTraining_) {
                    break;
                }
            }
            await callbackList.onTrainEnd();
            await model.history.syncData();
            return model.history;
        }
        finally {
            // Always release the training lock, even if an error was thrown.
            model.isTraining = false;
        }
    }
41703 /** Helper function that determines number of steps (batches) per epoch. */
41704 function getStepsPerEpoch(dataset, args) {
41705 // Attempt to determine # of batches in an epoch.
41706 let stepsPerEpoch = null;
41707 if (args.batchesPerEpoch != null) {
41708 stepsPerEpoch = args.batchesPerEpoch;
41709 }
41710 else if (Number.isFinite(dataset.size)) {
41711 stepsPerEpoch = dataset.size;
41712 }
41713 return stepsPerEpoch;
41714 }
41715 // Check if provided object is a Dataset object by checking its .iterator
41716 // element.
41717 function isDatasetObject(dataset) {
41718 return (typeof dataset.iterator === 'function');
41719 }
41720 // Check if provided object is a LazyIterator object by checking it's .next
41721 // element.
41722 function isLazyIteratorObject(iterator) {
41723 return (typeof iterator.next === 'function');
41724 }
    /**
     * Evaluate a model on a dataset (backing implementation of
     * `LayersModel.evaluateDataset()`).
     *
     * Iterates over the dataset, runs the model's test function on each batch,
     * and accumulates a batch-size-weighted sum of each metric; the sums are
     * divided by the total number of examples at the end, yielding per-example
     * means.
     *
     * @param model The model to evaluate (typed loosely to avoid a circular
     *   dependency with training.ts).
     * @param dataset A Dataset yielding `{xs, ys}` batches, or a LazyIterator
     *   over such batches.
     * @param args Optional config; `args.batches` caps the number of batches.
     * @returns A single metric scalar, or an array of them if the model
     *   reports multiple metrics.
     */
    async function evaluateDataset(
    // Type `model` as `any` here to avoid circular dependency w/
    // training.ts.
    // tslint:disable-next-line:no-any
    model, dataset, args) {
        args = args || {};
        const hasBatches = args.batches != null;
        const f = model.testFunction;
        // Running weighted sums of each metric (as scalar tensors); converted
        // to per-example means after the loop.
        let outs = [];
        if (args.verbose > 0) {
            throw new NotImplementedError('Verbose mode is not implemented yet.');
        }
        assert(!hasBatches || (args.batches > 0 && Number.isInteger(args.batches)), () => 'Test loop expects `batches` to be a positive integer, but ' +
            `received ${JSON.stringify(args.batches)}`);
        // Accept either a LazyIterator directly, or a Dataset to iterate over.
        const dataIterator = isLazyIteratorObject(dataset) ?
            dataset :
            await dataset.iterator();
        // Keeps track of number of examples used in this evaluation.
        let numExamples = 0;
        let batch = 0;
        while (hasBatches ? batch < args.batches : true) {
            const iteratorOut = await dataIterator.next();
            outs = tidy(() => {
                if (iteratorOut.value) {
                    // TODO(cais): Once real dataset is available, use
                    // `map(x => standardizeDataIteratorOutput(model, x).map(f)`.
                    const { xs, ys } = standardizeDataIteratorOutput(model, iteratorOut.value);
                    const xsAndYs = xs.concat(ys);
                    const batchOuts = tidy(() => f(xsAndYs));
                    dispose(xsAndYs);
                    // First batch: initialize one zero-scalar accumulator per
                    // metric output.
                    if (batch === 0) {
                        for (let i = 0; i < batchOuts.length; ++i) {
                            outs.push(scalar(0));
                        }
                    }
                    // NOTE(review): xsAndYs[0] was disposed above; reading its
                    // `shape` relies on shape metadata surviving dispose() —
                    // works in tfjs, but confirm before refactoring.
                    const batchSize = xsAndYs[0].shape[0];
                    for (let i = 0; i < batchOuts.length; ++i) {
                        const batchOut = batchOuts[i];
                        const oldScalar = outs[i];
                        // Weight each batch's metric by its batch size, then
                        // drop the superseded accumulator (except the initial
                        // zero scalars of batch 0, which tidy() reclaims).
                        outs[i] =
                            tidy(() => add$1(outs[i], mul(batchSize, batchOut)));
                        if (batch > 0) {
                            dispose(oldScalar);
                        }
                    }
                    dispose(batchOuts);
                    numExamples += batchSize;
                    ++batch;
                }
                return outs;
            });
            if (iteratorOut.done) {
                if (hasBatches) {
                    console.warn('Your dataset iterator ran out of data during evaluateDataset(). ' +
                        'Interrupting evalution. Make sure that your ' +
                        'dataset can generate at least `batches` ' +
                        `batches (in this case, ${args.batches} batches). ` +
                        'You may need to use the repeat() function when building ' +
                        'your dataset.');
                }
                break;
            }
        }
        // Convert the weighted sums into means over all examples seen.
        for (let i = 0; i < outs.length; ++i) {
            const oldScalar = outs[i];
            outs[i] = div(outs[i], numExamples);
            dispose(oldScalar);
        }
        return singletonOrArray(outs);
    }
41795
41796 /**
41797 * @license
41798 * Copyright 2018 Google LLC
41799 *
41800 * Use of this source code is governed by an MIT-style
41801 * license that can be found in the LICENSE file or at
41802 * https://opensource.org/licenses/MIT.
41803 * =============================================================================
41804 */
41805 function checkBatchSize(batchSize) {
41806 assert(batchSize > 0 && Number.isInteger(batchSize), () => `batchSize is required to be a positive integer, but got ${batchSize}`);
41807 }
41808 /**
41809 * Slice a Tensor or an Array of Tensors, by start and stop indices.
41810 *
41811 * Porting Note: The `_slice_arrays` function in PyKeras is covered by this
41812 * function and `sliceArraysByIndices()` together.
41813 *
41814 * @param arrays: the input.
41815 * @param start: the starting index (inclusive).
41816 * @param stop: the stopping index (exclusive).
41817 * @returns The result of the slicing. If `arrays` is an `Array` of
41818 * `tf.Tensor`s, the slicing will be applied to all elements of the `Array`
41819 * in the same way.
41820 */
41821 function sliceArrays(arrays, start, stop) {
41822 if (arrays == null) {
41823 return [null];
41824 }
41825 else if (Array.isArray(arrays)) {
41826 return arrays.map(array => sliceAlongFirstAxis(array, start, stop - start));
41827 }
41828 else { // Tensor.
41829 return sliceAlongFirstAxis(arrays, start, stop - start);
41830 }
41831 }
41832 /**
41833 * Slice a Tensor or an Array of Tensors, by random-order indices.
41834 *
41835 * Porting Note: The `_slice_arrays` function in PyKeras is covered by this
41836 * function and `sliceArrays()` together.
41837 *
41838 * @param arrays The input `tf.Tensor` or `Array` of `tf.Tensor`s to slice.
41839 * If an `Array` of `tf.Tensor`s, all `tf.Tensor`s will be sliced in the
41840 * same fashion.
41841 * @param indices The indices to use for slicing along the first (batch)
41842 * dimension.
41843 * @returns Result(s) of the slicing.
41844 */
41845 function sliceArraysByIndices(arrays, indices) {
41846 return tidy(() => {
41847 if (arrays == null) {
41848 return null;
41849 }
41850 else if (Array.isArray(arrays)) {
41851 return arrays.map(array => sliceArraysByIndices(array, indices));
41852 }
41853 else {
41854 // TODO(cais): indices should be a pre-constructed Tensor1D to avoid
41855 // tensor1d() calls.
41856 return gather$1(arrays, indices.dtype === 'int32' ? indices : cast(indices, 'int32'));
41857 }
41858 });
41859 }
41860 /**
41861 * Returns a list of batch indices (tuples of indices).
41862 * @param size: Integer, total size of the data to slice into batches.
41863 * @param batchSize: Integer, batch size.
41864 * @returns An Array of [batchStart, batchEnd] tuples. batchStart is
41865 * inclusive; batchEnd is exclusive. I.e., each batch consists of indices x
41866 * that satisfy batchStart <= x < batchEnd.
41867 */
41868 function makeBatches(size, batchSize) {
41869 const output = [];
41870 let batchStart = 0;
41871 let batchEnd = null;
41872 while (batchStart < size) {
41873 batchEnd = batchStart + batchSize;
41874 if (batchEnd >= size) {
41875 batchEnd = size;
41876 }
41877 output.push([batchStart, batchEnd]);
41878 batchStart = batchEnd;
41879 }
41880 return output;
41881 }
41882 /**
41883 * Abstract fit function for `f(ins)`.
41884 * @param f A Function returning a list of tensors. For training, this
41885 * function is expected to perform the updates to the variables.
41886 * @param ins List of tensors to be fed to `f`.
41887 * @param outLabels List of strings, display names of the outputs of `f`.
41888 * @param batchSize Integer batch size or `== null` if unknown. Default : 32.
41889 * @param epochs Number of times to iterate over the data. Default : 1.
41890 * @param verbose Verbosity mode: 0, 1, or 2. Default: 1.
41891 * @param callbacks List of callbacks to be called during training.
41892 * @param valF Function to call for validation.
41893 * @param valIns List of tensors to be fed to `valF`.
41894 * @param shuffle Whether to shuffle the data at the beginning of every
41895 * epoch. Default : true.
41896 * @param callbackMetrics List of strings, the display names of the metrics
41897 * passed to the callbacks. They should be the concatenation of the
41898 * display names of the outputs of `f` and the list of display names
41899 * of the outputs of `valF`.
41900 * @param initialEpoch Epoch at which to start training (useful for
41901 * resuming a previous training run). Default : 0.
41902 * @param stepsPerEpoch Total number of steps (batches on samples) before
41903 * declaring one epoch finished and starting the next epoch. Ignored with
41904 * the default value of `undefined` or `null`.
41905 * @param validationSteps Number of steps to run validation for (only if
41906 * doing validation from data tensors). Not applicable for tfjs-layers.
41907 * @returns A `History` object.
41908 */
async function fitLoop(
// Type `model` as `any` here to avoid circular dependency w/ training.ts.
// tslint:disable-next-line:no-any
model, f, ins, outLabels, batchSize, epochs, verbose, callbacks, valF, valIns, shuffle$1, callbackMetrics, initialEpoch, stepsPerEpoch, validationSteps) {
    // Apply documented defaults for the optional arguments.
    if (batchSize == null) {
        batchSize = 32;
    }
    if (epochs == null) {
        epochs = 1;
    }
    if (shuffle$1 == null) {
        shuffle$1 = true;
    }
    if (initialEpoch == null) {
        initialEpoch = 0;
    }
    // TODO(cais): Change const to let below when implementing validation.
    let doValidation = false;
    if (valF != null && valIns != null) {
        doValidation = true;
        // TODO(cais): verbose message.
    }
    if (validationSteps != null) {
        doValidation = true;
        // `validationSteps` only makes sense in step-wise training mode.
        if (stepsPerEpoch == null) {
            throw new ValueError('Can only use `validationSteps` when doing step-wise training, ' +
                'i.e., `stepsPerEpoch` must be set.');
        }
    }
    const numTrainSamples = model.checkNumSamples(ins, batchSize, stepsPerEpoch, 'steps_per_epoch');
    let indexArray;
    if (numTrainSamples != null) {
        // Sample indices [0, numTrainSamples); optionally shuffled per epoch.
        indexArray = range$1(0, numTrainSamples);
    }
    if (verbose == null) {
        verbose = 1;
    }
    const { callbackList, history } = configureCallbacks(callbacks, verbose, epochs, initialEpoch, numTrainSamples, stepsPerEpoch, batchSize, doValidation, callbackMetrics);
    callbackList.setModel(model);
    model.history = history;
    await callbackList.onTrainBegin();
    model.stopTraining_ = false;
    // TODO(cais): Take care of callbacks.validation_data as in PyKeras.
    // TODO(cais): Pre-convert feeds for performance as in PyKeras.
    for (let epoch = initialEpoch; epoch < epochs; ++epoch) {
        await callbackList.onEpochBegin(epoch);
        const epochLogs = {};
        if (stepsPerEpoch != null) {
            throw new NotImplementedError('stepsPerEpoch mode is not implemented yet.');
        }
        else {
            if (shuffle$1 === 'batch') {
                throw new NotImplementedError('batch shuffling is not implemented yet');
            }
            else if (shuffle$1) {
                shuffle(indexArray);
            }
            // Convert the potentially shuffled indices to Tensor1D, to avoid the
            // cost of repeated creation of Array1Ds later on.
            const epochIndexArray1D = tensor1d(indexArray);
            const batches = makeBatches(numTrainSamples, batchSize);
            for (let batchIndex = 0; batchIndex < batches.length; ++batchIndex) {
                const batchLogs = {};
                await callbackList.onBatchBegin(batchIndex, batchLogs);
                // tidy() disposes intermediates created in this scope; outputs
                // that must outlive it are retained explicitly via keep().
                tidy(() => {
                    const batchStart = batches[batchIndex][0];
                    const batchEnd = batches[batchIndex][1];
                    const batchIds = sliceAlongFirstAxis(epochIndexArray1D, batchStart, batchEnd - batchStart);
                    batchLogs['batch'] = batchIndex;
                    batchLogs['size'] = batchEnd - batchStart;
                    // TODO(cais): In ins, train flag can be a number, instead of an
                    // Tensor? Do we need to handle this in tfjs-layers?
                    const insBatch = sliceArraysByIndices(ins, batchIds);
                    const outs = f(insBatch);
                    for (let i = 0; i < outLabels.length; ++i) {
                        const label = outLabels[i];
                        const out = outs[i];
                        batchLogs[label] = out;
                        keep(out);
                        // TODO(cais): Use scope() to avoid ownership.
                    }
                    if (batchIndex === batches.length - 1) { // Last batch.
                        if (doValidation) {
                            const valOuts = model.testLoop(valF, valIns, batchSize);
                            // Porting Notes: In tfjs-layers, valOuts is always an Array.
                            for (let i = 0; i < outLabels.length; ++i) {
                                const label = outLabels[i];
                                const out = valOuts[i];
                                keep(out);
                                // TODO(cais): Use scope() to avoid ownership.
                                epochLogs['val_' + label] = out;
                            }
                        }
                    }
                });
                await callbackList.onBatchEnd(batchIndex, batchLogs);
                // Per-batch metric tensors are consumed by callbacks above and
                // released here.
                disposeTensorsInLogs(batchLogs);
                if (model.stopTraining_) {
                    break;
                }
                // TODO(cais): return outs as list of Tensor.
            }
            epochIndexArray1D.dispose();
        }
        // TODO(cais): Run validation at the end of the epoch.
        await callbackList.onEpochEnd(epoch, epochLogs);
        if (model.stopTraining_) {
            break;
        }
    }
    await callbackList.onTrainEnd();
    await model.history.syncData();
    return model.history;
}
/**
 * Backing implementation for fitting a model with data tensors.
 *
 * Standardizes the user-provided `x`/`y` data, prepares validation data
 * (from `validationData`, `validationSplit`, or `validationSteps`), then
 * delegates the actual training loop to `fitLoop`. All tensors created here
 * (as opposed to passed in by the caller) are disposed in the `finally`
 * block via `disposeNewTensors`.
 *
 * @param model The model to train (typed loosely; see comment below).
 * @param x Input data.
 * @param y Target data.
 * @param args Fit configuration (batchSize, epochs, validation options, ...).
 * @returns The model's `History` object, as returned by `fitLoop`.
 */
async function fitTensors(
// Type `model` as `any` here to avoid circular dependency w/ training.ts.
// tslint:disable-next-line:no-any
model, x, y, args = {}) {
    // Reject re-entrant calls: fit() is not re-entrant on a single model.
    if (model.isTraining) {
        throw new Error('Cannot start training because another fit() call is ongoing.');
    }
    model.isTraining = true;
    // Declared outside `try` so the `finally` cleanup can see them.
    let inputs;
    let targets;
    let originalInputs;
    let originalTargets;
    let inputValX;
    let inputValY;
    let valX;
    let valY;
    let sampleWeights;
    try {
        const batchSize = args.batchSize == null ? 32 : args.batchSize;
        checkBatchSize(batchSize);
        // Validate user data.
        // TODO(cais): Support sampleWeight.
        const checkBatchAxis = false;
        const standardizedOuts = await model.standardizeUserData(x, y, args.sampleWeight, args.classWeight, checkBatchAxis, batchSize);
        inputs = standardizedOuts[0];
        targets = standardizedOuts[1];
        sampleWeights = standardizedOuts[2];
        // Prepare validation data.
        let doValidation = false;
        let valIns;
        if (args.validationData != null && args.validationData.length > 0) {
            doValidation = true;
            if (args.validationData.length === 2) {
                // config.validationData consists of valX and valY.
                inputValX = args.validationData[0];
                inputValY = args.validationData[1];
            }
            else if (args.validationData.length === 3) {
                throw new NotImplementedError('validationData including sample weights is not supported yet.');
            }
            else {
                throw new ValueError(`When passing validation data, it must contain 2 (valX, valY) ` +
                    `or 3 (valX, valY, valSampleWeight) items; ` +
                    `${args.validationData} is invalid.`);
            }
            const checkBatchAxis = true;
            const valStandardized = await model.standardizeUserData(inputValX, inputValY, null, /** Unused sample weights. */ null, /** Unused class weights. */ checkBatchAxis, batchSize);
            valX = valStandardized[0];
            valY = valStandardized[1];
            valIns = valX.concat(valY);
            // TODO(cais): Add useLearningPhase data properly.
        }
        else if (args.validationSplit != null && args.validationSplit > 0 &&
            args.validationSplit < 1) {
            // Carve the validation set off the tail of the training data.
            doValidation = true;
            // Porting Note: In tfjs-layers, inputs[0] is always a Tensor.
            const splitAt = Math.floor(inputs[0].shape[0] * (1 - args.validationSplit));
            const originalBatchSize = inputs[0].shape[0];
            valX = sliceArrays(inputs, splitAt, originalBatchSize);
            // Keep the originals so the sliced tensors can be disposed later.
            originalInputs = inputs;
            inputs = sliceArrays(inputs, 0, splitAt);
            valY = sliceArrays(targets, splitAt, originalBatchSize);
            originalTargets = targets;
            targets = sliceArrays(targets, 0, splitAt);
            // TODO(cais): Once sampleWeights becomes available, slice it to get
            // valSampleWeights.
            valIns = valX.concat(valY);
            // TODO(cais): Add useLearningPhase data properly.
        }
        else if (args.validationSteps != null) {
            doValidation = true;
            // TODO(cais): Add useLearningPhase.
        }
        const ins = inputs.concat(targets).concat(sampleWeights);
        model.checkTrainableWeightsConsistency();
        // TODO(cais): Handle use_learning_phase and learning_phase?
        // Porting Note: Here we see a key deviation of tfjs-layers from
        // Keras.
        // Due to the imperative nature of tfjs-layers' backend (tfjs-core),
        // we do not construct symbolic computation graphs to embody the
        // training process. Instead, we define a function that performs the
        // training action. In PyKeras, the data (inputs and targets) are fed
        // through graph placeholders. In tfjs-layers, the data are fed as
        // function arguments. Since the functions are defined below in the
        // scope, we don't have equivalents of PyKeras's
        // `_make_train_function`.
        const trainFunction = model.makeTrainFunction();
        const outLabels = model.getDedupedMetricsNames();
        let valFunction;
        let callbackMetrics;
        if (doValidation) {
            model.makeTestFunction();
            valFunction = model.testFunction;
            // Callbacks see both training and `val_`-prefixed metric names.
            callbackMetrics =
                outLabels.slice().concat(outLabels.map(n => 'val_' + n));
        }
        else {
            valFunction = null;
            valIns = [];
            callbackMetrics = outLabels.slice();
        }
        const callbacks = standardizeCallbacks(args.callbacks, args.yieldEvery);
        const out = await fitLoop(model, trainFunction, ins, outLabels, batchSize, args.epochs, args.verbose, callbacks, valFunction, valIns, args.shuffle, callbackMetrics, args.initialEpoch, null, null);
        return out;
    }
    finally {
        model.isTraining = false;
        // Memory clean up: dispose only tensors created here, not the
        // caller-owned originals (disposeNewTensors compares against them).
        disposeNewTensors(inputs, x);
        disposeNewTensors(targets, y);
        disposeNewTensors(originalInputs, x);
        disposeNewTensors(originalTargets, y);
        disposeNewTensors(valX, inputValX);
        disposeNewTensors(valY, inputValY);
        if (sampleWeights != null) {
            dispose(sampleWeights);
        }
    }
    // TODO(cais): Add value to outLabels.
}
42143 /**
42144 * Ensure tensors all have a rank of at least 2.
42145 *
42146 * If a tensor has a rank of 1, it is dimension-expanded to rank 2.
42147 * If any tensor has a rank of 0 (i.e., is a scalar), an error will be thrown.
42148 */
function ensureTensorsRank2OrHigher(tensors) {
    // Accept either a singleton Tensor or an array of Tensors.
    const tensorList = tensors instanceof Tensor ? [tensors] : tensors;
    const outs = [];
    for (const tensor of tensorList) {
        // Scalars cannot be promoted: fail fast.
        if (tensor.rank === 0) {
            throw new Error('Expected tensor to be at least 1D, but received a 0D tensor ' +
                '(scalar).');
        }
        // Promote rank-1 tensors to rank 2; pass higher ranks through as-is.
        outs.push(tensor.rank === 1 ? expandDims$1(tensor, 1) : tensor);
    }
    return outs;
}
42170 /**
42171 * Compare a set of tensors with a reference (old) set, discard the ones
42172 * in the new set that are not present in the reference set.
42173 *
 * This method is used for memory cleanup during calls such as
42175 * LayersModel.fit().
42176 *
42177 * @param tensors New set which may contain Tensors not present in
42178 * `refTensors`.
42179 * @param refTensors Reference Tensor set.
42180 */
42181 // TODO(cais, kangyizhang): Deduplicate with tfjs-data.
function disposeNewTensors(tensors, refTensors) {
    if (tensors == null) {
        return;
    }
    // Collect the ids of the reference tensors into a Set for O(1) lookups;
    // the previous Array + indexOf approach was O(n*m) across the two sets.
    const oldTensorIds = new Set();
    if (refTensors instanceof Tensor) {
        oldTensorIds.add(refTensors.id);
    }
    else if (Array.isArray(refTensors)) {
        refTensors.forEach(t => oldTensorIds.add(t.id));
    }
    else if (refTensors != null) {
        // `refTensors` is a map from string name to Tensor.
        for (const name in refTensors) {
            const oldTensor = refTensors[name];
            oldTensorIds.add(oldTensor.id);
        }
    }
    // Gather the tensors from `tensors` (Tensor | Tensor[] | {name: Tensor})
    // whose ids do not appear in the reference set.
    const tensorsToDispose = [];
    if (tensors instanceof Tensor) {
        if (!oldTensorIds.has(tensors.id)) {
            tensorsToDispose.push(tensors);
        }
    }
    else if (Array.isArray(tensors)) {
        tensors.forEach(t => {
            if (!oldTensorIds.has(t.id)) {
                tensorsToDispose.push(t);
            }
        });
    }
    else if (tensors != null) {
        // `tensors` is a map from string name to Tensor.
        for (const name in tensors) {
            const tensor = tensors[name];
            if (!oldTensorIds.has(tensor.id)) {
                tensorsToDispose.push(tensor);
            }
        }
    }
    // Dispose each candidate at most once; skip already-disposed tensors.
    tensorsToDispose.forEach(t => {
        if (!t.isDisposed) {
            t.dispose();
        }
    });
}
42228
42229 /**
42230 * @license
42231 * Copyright 2018 Google LLC
42232 *
42233 * Use of this source code is governed by an MIT-style
42234 * license that can be found in the LICENSE file or at
42235 * https://opensource.org/licenses/MIT.
42236 * =============================================================================
42237 */
42238 /**
42239 * Helper function for polymorphic input data: 1. singleton Tensor.
42240 */
function isDataTensor(x) {
    // True iff `x` is a single (singleton) Tensor instance.
    return x instanceof Tensor;
}
42244 /**
42245 * Helper function for polymorphic input data: 2. Array of Tensor.
42246 */
function isDataArray(x) {
    // True iff `x` is a plain JavaScript array (of Tensors, per usage here).
    return Array.isArray(x);
}
42250 /**
42251 * Helper function for polymorphic input data: 3. "dict" of Tensor.
42252 */
function isDataDict(x) {
    // Anything that is neither a singleton Tensor nor an Array is treated
    // as a "dict" mapping names to Tensors.
    return !(isDataTensor(x) || isDataArray(x));
}
42256 /**
42257 * Normalizes inputs and targets provided by users.
42258 * @param data User-provided input data (polymorphic).
42259 * @param names An Array of expected Tensor names.
42260 * @param shapes Optional Array of expected Tensor shapes.
42261 * @param checkBatchAxis Whether to check that the batch axis of the arrays
42262 * match the expected value found in `shapes`.
42263 * @param exceptionPrefix String prefix used for exception formatting.
42264 * @returns List of standardized input Tensors (one Tensor per model input).
42265 * @throws ValueError: in case of improperly formatted user data.
42266 */
function standardizeInputData(data, names, shapes, checkBatchAxis = true, exceptionPrefix = '') {
    if (names == null || names.length === 0) {
        // Check for the case where the model expected no data, but some data got
        // sent.
        if (data != null) {
            let gotUnexpectedData = false;
            if (isDataArray(data) && data.length > 0) {
                gotUnexpectedData = true;
            }
            else if (isDataDict(data)) {
                // Any own key means the dict is non-empty.
                for (const key in data) {
                    if (data.hasOwnProperty(key)) {
                        gotUnexpectedData = true;
                        break;
                    }
                }
            }
            else {
                // `data` is a singleton Tensor in this case.
                gotUnexpectedData = true;
            }
            if (gotUnexpectedData) {
                throw new ValueError(`Error when checking model ${exceptionPrefix} expected no data, ` +
                    `but got ${data}`);
            }
        }
        return [];
    }
    if (data == null) {
        // No data at all: one null per expected input.
        return names.map(name => null);
    }
    // Normalize `data` (dict | array | singleton Tensor) into an Array of
    // Tensors, one per name. (The no-op `data = data;` self-assignments that
    // previously marked the TypeScript casts have been removed.)
    let arrays;
    if (isDataDict(data)) {
        arrays = [];
        for (const name of names) {
            if (data[name] == null) {
                throw new ValueError(`No data provided for "${name}". Need data for each key in: ` +
                    `${names}`);
            }
            arrays.push(data[name]);
        }
    }
    else if (isDataArray(data)) {
        if (data.length !== names.length) {
            throw new ValueError(`Error when checking model ${exceptionPrefix}: the Array of ` +
                `Tensors that you are passing to your model is not the size the ` +
                `model expected. Expected to see ${names.length} Tensor(s), but ` +
                `instead got the following list of Tensor(s): ${data}`);
        }
        arrays = data;
    }
    else {
        // Singleton Tensor: only valid for single-input models.
        if (names.length > 1) {
            throw new ValueError(`The model ${exceptionPrefix} expects ${names.length} Tensor(s), ` +
                `but only received one Tensor. Found: Tensor with shape ${data.shape}`);
        }
        arrays = [data];
    }
    arrays = ensureTensorsRank2OrHigher(arrays);
    // Check shape compatibility.
    if (shapes != null) {
        for (let i = 0; i < names.length; ++i) {
            if (shapes[i] == null) {
                // No expected shape recorded for this input: skip.
                continue;
            }
            const array = arrays[i];
            if (array.shape.length !== shapes[i].length) {
                throw new ValueError(`Error when checking ${exceptionPrefix}: expected ${names[i]} ` +
                    `to have ${shapes[i].length} dimension(s). but got array with ` +
                    `shape ${array.shape}`);
            }
            for (let j = 0; j < shapes[i].length; ++j) {
                if (j === 0 && !checkBatchAxis) {
                    // Skip the first (batch) axis.
                    continue;
                }
                const dim = array.shape[j];
                const refDim = shapes[i][j];
                // Negative/null reference dims act as wildcards.
                if (refDim != null && refDim >= 0 && dim !== refDim) {
                    throw new ValueError(`${exceptionPrefix} expected a batch of elements where each ` +
                        `example has shape [${shapes[i].slice(1, shapes[i].length)}] ` +
                        `(i.e.,tensor shape [*,${shapes[i].slice(1, shapes[i].length)}])` +
                        ` but the ${exceptionPrefix} received an input with ${array.shape[0]}` +
                        ` examples, each with shape [${array.shape.slice(1, array.shape.length)}]` +
                        ` (tensor shape [${array.shape}])`);
                }
            }
        }
    }
    return arrays;
}
42361 /**
42362 * User input validation for Tensors.
42363 * @param inputs `Array` of `tf.Tensor`s for inputs.
42364 * @param targets `Array` of `tf.Tensor`s for targets.
42365 * @param weights Optional `Array` of `tf.Tensor`s for sample weights.
42366 * @throws ValueError: in case of incorrectly formatted data.
42367 */
function checkArrayLengths(inputs, targets, weights) {
    // Distinct batch sizes (first-axis lengths) among inputs and targets.
    const setX = unique$1(inputs.map(input => input.shape[0]));
    // Sort numerically: the default sort() is lexicographic on numbers and
    // could report a misleading sample count (e.g. 10 before 2) below.
    setX.sort((a, b) => a - b);
    const setY = unique$1(targets.map(target => target.shape[0]));
    setY.sort((a, b) => a - b);
    // TODO(cais): Check `weights` as well.
    if (setX.length > 1) {
        throw new ValueError(`All input Tensors (x) should have the same number of samples. ` +
            `Got array shapes: ` +
            `${JSON.stringify(inputs.map(input => input.shape))}`);
    }
    if (setY.length > 1) {
        throw new ValueError(`All target Tensors (y) should have the same number of samples. ` +
            `Got array shapes: ` +
            `${JSON.stringify(targets.map(target => target.shape))}`);
    }
    if (setX.length > 0 && setY.length > 0 && !arraysEqual(setX, setY)) {
        throw new ValueError(`Input Tensors should have the same number of samples as target ` +
            `Tensors. Found ${setX[0]} input sample(s) and ${setY[0]} target ` +
            `sample(s).`);
    }
}
42390 /**
 * Validation on the compatibility of targets and loss functions.
42392 *
42393 * This helps prevent users from using loss functions incorrectly.
42394 *
42395 * @param targets `Array` of `tf.Tensor`s of targets.
42396 * @param lossFns `Array` of loss functions.
42397 * @param outputShapes `Array` of shapes of model outputs.
42398 */
function checkLossAndTargetCompatibility(targets, lossFns, outputShapes) {
    // TODO(cais): Dedicated test coverage?
    // Losses whose target shape must match the output shape exactly.
    const keyLosses = [
        meanSquaredError$1, binaryCrossentropy,
        categoricalCrossentropy
    ];
    targets.forEach((y, i) => {
        const loss = lossFns[i];
        const shape = outputShapes[i];
        if (loss == null) {
            // No loss configured for this output: nothing to validate.
            return;
        }
        if (loss === categoricalCrossentropy &&
            y.shape[y.shape.length - 1] === 1) {
            throw new ValueError(`You are passing a target array of shape ${y.shape} while using ` +
                `a loss 'categorical_crossentropy'. 'categorical_crossentropy'` +
                `expects targets to be binary matrices (1s and 0s) of shape ` +
                `[samples, classes].`);
            // TODO(cais): Example code in error message.
        }
        if (keyLosses.indexOf(loss) !== -1) {
            // Compare per-example (non-batch) dimensions of target vs output.
            const slicedYShape = y.shape.slice(1);
            const slicedShape = shape.slice(1);
            slicedYShape.forEach((targetDim, j) => {
                const outDim = slicedShape[j];
                if (outDim != null && targetDim !== outDim) {
                    throw new ValueError(`A target Tensor with shape ${y.shape} was passed for an ` +
                        `output of shape ${shape}, while using a loss function that ` +
                        `expects targets to have the same shape as the output.`);
                }
            });
        }
    });
}
42436 /**
42437 * Check inputs provided by the user.
42438 *
42439 * Porting Note: This corresponds to _standardize_input_data() in Python
42440 * Keras. Because of the strong typing in TF.js, we do not need to convert
42441 * the data. Specifically:
42442 * 1) in PyKeras, `data` can be `DataFrame` instances from pandas, for
42443 * example. We don't need to worry about that here because there is no
 * widely popular javascript/typescript equivalent of pandas (so far).
42445 * If one becomes available in the future, we can add support.
42446 * 2) in PyKeras, inputs can be Python dict. But here we are stipulating
42447 * that the data is either a single `tf.Tensor` or an Array of `tf.Tensor`s. We
42448 * may add support for `Object` data inputs in the future when the need
42449 * arises.
42450 *
42451 * Instead, we perform basic checks for number of parameters and shapes.
42452 *
42453 * @param data: The input data.
42454 * @param names: Name for the inputs, from the model.
42455 * @param shapes: Expected shapes for the input data, from the model.
42456 * @param checkBatchAxis: Whether the size along the batch axis (i.e., the
42457 * first dimension) will be checked for matching.
 * @param exceptionPrefix: Exception prefix message, used in generating error
42459 * messages.
42460 * @throws ValueError: on incorrect number of inputs or mismatches in shapes.
42461 */
function checkInputData(data, names, shapes, checkBatchAxis = true, exceptionPrefix = '') {
    // Normalize `data` into an Array of tensors, one per expected name.
    let arrays;
    if (Array.isArray(data)) {
        if (data.length !== names.length) {
            // Fixed message: previously read "the size the the model expected"
            // and "Tensors(s)".
            throw new ValueError(`Error when checking model ${exceptionPrefix}: the Array of ` +
                `Tensors that you are passing to your model is not the size the ` +
                `model expected. Expected to see ${names.length} Tensor(s),` +
                ` but instead got ${data.length} Tensor(s).`);
        }
        arrays = data;
    }
    else {
        // A single tensor is only valid for single-input checks.
        if (names.length > 1) {
            throw new ValueError(`The model expects ${names.length} ${exceptionPrefix} Tensors, ` +
                `but only received one Tensor. Found: array with shape ` +
                `${JSON.stringify(data.shape)}.`);
        }
        arrays = [data];
    }
    // Validate shapes, if reference shapes are provided.
    if (shapes != null) {
        for (let i = 0; i < names.length; ++i) {
            if (shapes[i] == null) {
                // No expected shape recorded for this entry: skip.
                continue;
            }
            const array = arrays[i];
            if (array.shape.length !== shapes[i].length) {
                throw new ValueError(`Error when checking ${exceptionPrefix}: expected ${names[i]} ` +
                    `to have ${shapes[i].length} dimension(s), but got array with ` +
                    `shape ${JSON.stringify(array.shape)}`);
            }
            for (let j = 0; j < shapes[i].length; ++j) {
                if (j === 0 && !checkBatchAxis) {
                    // Skip the first (batch) axis when requested.
                    continue;
                }
                const dim = array.shape[j];
                const refDim = shapes[i][j];
                if (refDim != null) {
                    if (refDim !== dim) {
                        throw new ValueError(`Error when checking ${exceptionPrefix}: expected ` +
                            `${names[i]} to have shape ${JSON.stringify(shapes[i])} but ` +
                            `got array with shape ${JSON.stringify(array.shape)}.`);
                    }
                }
            }
        }
    }
}
42509 /**
42510 * Maps metric functions to model outputs.
 * @param metrics A shortcut string name, metric function, `Array` or dict
42512 * (`Object`) of metric functions.
42513 * @param outputNames An `Array` of the names of model outputs.
42514 * @returns An `Array` (one entry per model output) of `Array` of metric
42515 * functions. For instance, if the model has 2 outputs, and for the first
42516 * output we want to compute `binaryAccuracy` and `binaryCrossentropy`,
42517 * and just `binaryAccuracy` for the second output, the `Array` would look
42518 * like:
42519 * `[[binaryAccuracy, binaryCrossentropy], [binaryAccuracy]]`
42520 * @throws TypeError: incompatible metrics format.
42521 */
function collectMetrics(metrics, outputNames) {
    // No metrics requested: one empty metric list per output.
    if (metrics == null || Array.isArray(metrics) && metrics.length === 0) {
        return outputNames.map(name => []);
    }
    let wrappedMetrics;
    if (typeof metrics === 'string' || typeof metrics === 'function') {
        // A single metric: wrap in an Array so it applies to every output.
        wrappedMetrics = [metrics];
    }
    else if (Array.isArray(metrics) || typeof metrics === 'object') {
        wrappedMetrics = metrics;
    }
    else {
        // Fixed message: previously read "Expected an string,function, ...".
        throw new TypeError('Type of metrics argument not understood. Expected a string, ' +
            `function, Array, or Object, found: ${metrics}`);
    }
    if (Array.isArray(wrappedMetrics)) {
        // We then apply all metrics to all outputs.
        return outputNames.map(name => wrappedMetrics);
    }
    else {
        // In this case, metrics is a dict mapping output name -> metric(s).
        const nestedMetrics = [];
        for (const name of outputNames) {
            // Outputs absent from the dict get no metrics.
            let outputMetrics = wrappedMetrics.hasOwnProperty(name) ? wrappedMetrics[name] : [];
            if (!Array.isArray(outputMetrics)) {
                outputMetrics = [outputMetrics];
            }
            nestedMetrics.push(outputMetrics);
        }
        return nestedMetrics;
    }
}
// Format identifier string for serialized Layers models; presumably recorded
// in saved model artifacts by the save/load machinery — confirm against save().
const LAYERS_MODEL_FORMAT_NAME = 'layers-model';
42555 /**
42556 * A `tf.LayersModel` is a directed, acyclic graph of `tf.Layer`s plus methods
42557 * for training, evaluation, prediction and saving.
42558 *
42559 * `tf.LayersModel` is the basic unit of training, inference and evaluation in
 * TensorFlow.js. To create a `tf.LayersModel`, use `tf.model`.
42561 *
42562 * See also:
42563 * `tf.Sequential`, `tf.loadLayersModel`.
42564 *
42565 * @doc {heading: 'Models', subheading: 'Classes'}
42566 */
42567 class LayersModel extends Container {
constructor(args) {
    super(args);
    // Guards against overlapping fit() calls; set and cleared by fitTensors().
    this.isTraining = false;
}
42572 /**
42573 * Print a text summary of the model's layers.
42574 *
42575 * The summary includes
42576 * - Name and type of all layers that comprise the model.
42577 * - Output shape(s) of the layers
42578 * - Number of weight parameters of each layer
42579 * - If the model has non-sequential-like topology, the inputs each layer
42580 * receives
42581 * - The total number of trainable and non-trainable parameters of the model.
42582 *
42583 * ```js
42584 * const input1 = tf.input({shape: [10]});
42585 * const input2 = tf.input({shape: [20]});
42586 * const dense1 = tf.layers.dense({units: 4}).apply(input1);
42587 * const dense2 = tf.layers.dense({units: 8}).apply(input2);
42588 * const concat = tf.layers.concatenate().apply([dense1, dense2]);
42589 * const output =
42590 * tf.layers.dense({units: 3, activation: 'softmax'}).apply(concat);
42591 *
42592 * const model = tf.model({inputs: [input1, input2], outputs: output});
42593 * model.summary();
42594 * ```
42595 *
42596 * @param lineLength Custom line length, in number of characters.
42597 * @param positions Custom widths of each of the columns, as either
42598 * fractions of `lineLength` (e.g., `[0.5, 0.75, 1]`) or absolute number
42599 * of characters (e.g., `[30, 50, 65]`). Each number corresponds to
42600 * right-most (i.e., ending) position of a column.
42601 * @param printFn Custom print function. Can be used to replace the default
42602 * `console.log`. For example, you can use `x => {}` to mute the printed
42603 * messages in the console.
42604 *
42605 * @doc {heading: 'Models', subheading: 'Classes'}
42606 */
summary(lineLength, positions, printFn = console.log) {
    // An unbuilt model has no weights or output shapes to report.
    if (!this.built) {
        throw new ValueError(`This model has never been called, thus its weights have not been ` +
            `created yet. So no summary can be displayed. Build the model ` +
            `first (e.g., by calling it on some test data).`);
    }
    printSummary(this, lineLength, positions, printFn);
}
42615 /**
42616 * Configures and prepares the model for training and evaluation. Compiling
42617 * outfits the model with an optimizer, loss, and/or metrics. Calling `fit`
42618 * or `evaluate` on an un-compiled model will throw an error.
42619 *
42620 * @param args a `ModelCompileArgs` specifying the loss, optimizer, and
42621 * metrics to be used for fitting and evaluating this model.
42622 *
42623 * @doc {heading: 'Models', subheading: 'Classes'}
42624 */
42625 compile(args) {
42626 if (args.loss == null) {
42627 args.loss = [];
42628 }
42629 this.loss = args.loss;
42630 if (typeof args.optimizer === 'string') {
42631 this.optimizer_ = getOptimizer(args.optimizer);
42632 this.isOptimizerOwned = true;
42633 }
42634 else {
42635 if (!(args.optimizer instanceof Optimizer)) {
42636 throw new ValueError(`User-defined optimizer must be an instance of tf.Optimizer.`);
42637 }
42638 this.optimizer_ = args.optimizer;
42639 this.isOptimizerOwned = false;
42640 }
42641 // TODO(cais): Add lossWeights.
42642 // TODO(cais): Add sampleWeightMode.
42643 // Prepare loss functions.
42644 let lossFunctions = [];
42645 if (!Array.isArray(args.loss) && typeof args.loss !== 'string' &&
42646 typeof args.loss !== 'function') {
42647 args.loss = args.loss;
42648 for (const name in args.loss) {
42649 if (this.outputNames.indexOf(name) === -1) {
42650 throw new ValueError(`Unknown entry in loss dictionary: "${name}". ` +
42651 `Only expected the following keys: ${this.outputNames}`);
42652 }
42653 }
42654 for (const name of this.outputNames) {
42655 if (args.loss[name] == null) {
42656 console.warn(`Output "${name}" is missing from loss dictionary. We assume ` +
42657 `this was done on purpose, and we will not be expecting data ` +
42658 `to be passed to ${name} during training`);
42659 }
42660 lossFunctions.push(get(args.loss[name]));
42661 }
42662 }
42663 else if (Array.isArray(args.loss)) {
42664 if (args.loss.length !== this.outputs.length) {
42665 throw new ValueError(`When passing an Array as loss, it should have one entry per ` +
42666 `model output. The model has ${this.outputs.length} output(s), ` +
42667 `but you passed loss=${args.loss}.`);
42668 }
42669 const theLosses = args.loss;
42670 lossFunctions = theLosses.map(l => get(l));
42671 }
42672 else {
42673 const lossFunction = get(args.loss);
42674 this.outputs.forEach(_ => {
42675 lossFunctions.push(lossFunction);
42676 });
42677 }
42678 this.lossFunctions = lossFunctions;
42679 this.feedOutputNames = [];
42680 this.feedOutputShapes = [];
42681 this.feedLossFns = [];
42682 for (let i = 0; i < this.outputs.length; ++i) {
42683 // TODO(cais): Logic for skipping target(s).
42684 const shape = this.internalOutputShapes[i];
42685 const name = this.outputNames[i];
42686 this.feedOutputNames.push(name);
42687 this.feedOutputShapes.push(shape);
42688 this.feedLossFns.push(this.lossFunctions[i]);
42689 }
42690 // TODO(cais): Add logic for output masks.
42691 // TODO(cais): Add logic for sample weights.
42692 const skipTargetIndices = [];
42693 // Prepare metrics.
42694 this.metrics = args.metrics;
42695 // TODO(cais): Add weightedMetrics.
42696 this.metricsNames = ['loss'];
42697 this.metricsTensors = [];
42698 // Compute total loss.
42699 // Porting Note: In PyKeras, metrics_tensors are symbolic tensor objects.
42700 // Here, metricsTensors are TypeScript functions. This difference is due
42701 // to the difference in symbolic/imperative property of the backends.
42702 nameScope('loss', () => {
42703 for (let i = 0; i < this.outputs.length; ++i) {
42704 if (skipTargetIndices.indexOf(i) !== -1) {
42705 continue;
42706 }
42707 // TODO(cais): Add weightedLoss, sampleWeight and mask.
42708 // The following line should be weightedLoss
42709 const weightedLoss = this.lossFunctions[i];
42710 if (this.outputs.length > 1) {
42711 this.metricsTensors.push([weightedLoss, i]);
42712 this.metricsNames.push(this.outputNames[i] + '_loss');
42713 }
42714 }
42715 // Porting Note: Due to the imperative nature of the backend, we calculate
42716 // the regularizer penalties in the totalLossFunction, instead of here.
42717 });
42718 const nestedMetrics = collectMetrics(args.metrics, this.outputNames);
42719 // TODO(cais): Add nestedWeightedMetrics.
42720 /**
42721 * Helper function used in loop below.
42722 */
42723 const appendMetric = (outputIndex, metricName, metricTensor) => {
42724 if (this.outputNames.length > 1) {
42725 metricName = this.outputNames[outputIndex] + '_' + metricName;
42726 }
42727 this.metricsNames.push(metricName);
42728 this.metricsTensors.push([metricTensor, outputIndex]);
42729 };
42730 nameScope('metric', () => {
42731 for (let i = 0; i < this.outputs.length; ++i) {
42732 if (skipTargetIndices.indexOf(i) !== -1) {
42733 continue;
42734 }
42735 const outputMetrics = nestedMetrics[i];
42736 // TODO(cais): Add weights and outputWeightedMetrics.
42737 // TODO(cais): Add optional arg `weights` to the following function.
42738 const handleMetrics = (metrics) => {
42739 const metricNamePrefix = '';
42740 let metricName;
42741 let accFn;
42742 let weightedMetricFn;
42743 // TODO(cais): Use 'weights_' for weighted metrics.
42744 for (const metric of metrics) {
42745 if (typeof metric === 'string' &&
42746 ['accuracy', 'acc', 'crossentropy', 'ce'].indexOf(metric) !==
42747 -1) {
42748 const outputShape = this.internalOutputShapes[i];
42749 if (outputShape[outputShape.length - 1] === 1 ||
42750 this.lossFunctions[i] === binaryCrossentropy) {
42751 // case: binary accuracy/crossentropy.
42752 if (['accuracy', 'acc'].indexOf(metric) !== -1) {
42753 accFn = binaryAccuracy;
42754 }
42755 else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
42756 accFn = binaryCrossentropy$1;
42757 }
42758 }
42759 else if (this.lossFunctions[i] ===
42760 sparseCategoricalCrossentropy) {
42761 // case: categorical accuracy / crossentropy with sparse
42762 // targets.
42763 if (['accuracy', 'acc'].indexOf(metric) !== -1) {
42764 accFn = sparseCategoricalAccuracy;
42765 }
42766 else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
42767 accFn = sparseCategoricalCrossentropy$1;
42768 }
42769 }
42770 else {
42771 // case: categorical accuracy / crossentropy.
42772 if (['accuracy', 'acc'].indexOf(metric) !== -1) {
42773 accFn = categoricalAccuracy;
42774 }
42775 else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
42776 accFn = categoricalCrossentropy$1;
42777 }
42778 }
42779 let suffix;
42780 if (['accuracy', 'acc'].indexOf(metric) !== -1) {
42781 suffix = 'acc';
42782 }
42783 else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
42784 suffix = 'ce';
42785 }
42786 // TODO(cais): Add weighting actually.
42787 weightedMetricFn = accFn;
42788 metricName = metricNamePrefix + suffix;
42789 }
42790 else {
42791 const metricFn = get$1(metric);
42792 // TODO(cais): Add weighting actually.
42793 weightedMetricFn = metricFn;
42794 metricName =
42795 metricNamePrefix + getLossOrMetricName(metric);
42796 }
42797 // TODO(cais): Add weighting and masking to metricResult.
42798 let metricResult;
42799 nameScope(metricName, () => {
42800 metricResult = weightedMetricFn;
42801 });
42802 appendMetric(i, metricName, metricResult);
42803 }
42804 };
42805 handleMetrics(outputMetrics);
42806 // TODO(cais): Call handleMetrics with weights.
42807 }
42808 });
42809 // Porting Notes: Given the imperative backend of tfjs-core,
42810 // there is no need for constructing the symbolic graph and placeholders.
42811 this.collectedTrainableWeights = this.trainableWeights;
42812 }
42813 /**
42814 * Check trainable weights count consistency.
42815 *
42816 * This will raise a warning if `this.trainableWeights` and
42817 * `this.collectedTrainableWeights` are inconsistent (i.e., have different
42818 * numbers of parameters).
42819 * Inconsistency will typically arise when one modifies `model.trainable`
42820 * without calling `model.compile()` again.
42821 */
42822 checkTrainableWeightsConsistency() {
42823 if (this.collectedTrainableWeights == null) {
42824 return;
42825 }
42826 if (this.trainableWeights.length !==
42827 this.collectedTrainableWeights.length) {
42828 console.warn('Discrepancy between trainableweights and collected trainable ' +
42829 'weights. Did you set `model.trainable` without calling ' +
42830 '`model.compile()` afterwards?');
42831 }
42832 }
42833 /**
42834 * Returns the loss value & metrics values for the model in test mode.
42835 *
42836 * Loss and metrics are specified during `compile()`, which needs to happen
42837 * before calls to `evaluate()`.
42838 *
42839 * Computation is done in batches.
42840 *
42841 * ```js
42842 * const model = tf.sequential({
42843 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
42844 * });
42845 * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
42846 * const result = model.evaluate(
42847 * tf.ones([8, 10]), tf.ones([8, 1]), {batchSize: 4});
42848 * result.print();
42849 * ```
42850 *
42851 * @param x `tf.Tensor` of test data, or an `Array` of `tf.Tensor`s if the
42852 * model has multiple inputs.
42853 * @param y `tf.Tensor` of target data, or an `Array` of `tf.Tensor`s if the
42854 * model has multiple outputs.
42855 * @param args A `ModelEvaluateArgs`, containing optional fields.
42856 *
42857 * @return `Scalar` test loss (if the model has a single output and no
42858 * metrics) or `Array` of `Scalar`s (if the model has multiple outputs
42859 * and/or metrics). The attribute `model.metricsNames`
42860 * will give you the display labels for the scalar outputs.
42861 *
42862 * @doc {heading: 'Models', subheading: 'Classes'}
42863 */
42864 evaluate(x, y, args = {}) {
42865 const batchSize = args.batchSize == null ? 32 : args.batchSize;
42866 checkBatchSize(batchSize);
42867 // TODO(cais): Standardize `config.sampleWeights` as well.
42868 // Validate user data.
42869 const checkBatchAxis = true;
42870 const standardizedOuts = this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize);
42871 try {
42872 // TODO(cais): If uses `useLearningPhase`, set the corresponding element
42873 // of the input to 0.
42874 const ins = standardizedOuts[0].concat(standardizedOuts[1]);
42875 this.makeTestFunction();
42876 const f = this.testFunction;
42877 const testOuts = this.testLoop(f, ins, batchSize, args.verbose, args.steps);
42878 return singletonOrArray(testOuts);
42879 }
42880 finally {
42881 disposeNewTensors(standardizedOuts[0], x);
42882 disposeNewTensors(standardizedOuts[1], y);
42883 }
42884 }
42885 // TODO(cais): Add code snippet below once real dataset objects are
42886 // available.
42887 /**
42888 * Evaluate model using a dataset object.
42889 *
42890 * Note: Unlike `evaluate()`, this method is asynchronous (`async`);
42891 *
42892 * @param dataset A dataset object. Its `iterator()` method is expected
42893 * to generate a dataset iterator object, the `next()` method of which
42894 * is expected to produce data batches for evaluation. The return value
42895 * of the `next()` call ought to contain a boolean `done` field and a
42896 * `value` field. The `value` field is expected to be an array of two
42897 * `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former
     * case is for models with exactly one input and one output (e.g.,
42899 * a sequential model). The latter case is for models with multiple
42900 * inputs and/or multiple outputs. Of the two items in the array, the
42901 * first is the input feature(s) and the second is the output target(s).
42902 * @param args A configuration object for the dataset-based evaluation.
42903 * @returns Loss and metric values as an Array of `Scalar` objects.
42904 *
42905 * @doc {heading: 'Models', subheading: 'Classes'}
42906 */
42907 async evaluateDataset(dataset, args) {
42908 this.makeTestFunction();
42909 return evaluateDataset(this, dataset, args);
42910 }
42911 /**
42912 * Get number of samples provided for training, evaluation or prediction.
42913 *
42914 * @param ins Input `tf.Tensor`.
42915 * @param batchSize Integer batch size, optional.
42916 * @param steps Total number of steps (batches of samples) before
42917 * declaring loop finished. Optional.
42918 * @param stepsName The public API's parameter name for `steps`.
42919 * @returns Number of samples provided.
42920 */
42921 checkNumSamples(ins, batchSize, steps, stepsName = 'steps') {
42922 let numSamples;
42923 if (steps != null) {
42924 numSamples = null;
42925 if (batchSize != null) {
42926 throw new ValueError(`If ${stepsName} is set, batchSize must be null or undefined.` +
42927 `Got batchSize = ${batchSize}`);
42928 }
42929 }
42930 else if (ins != null) {
42931 if (Array.isArray(ins)) {
42932 numSamples = ins[0].shape[0];
42933 }
42934 else {
42935 numSamples = ins.shape[0];
42936 }
42937 }
42938 else {
42939 throw new ValueError(`Either the input data should have a defined shape, or ` +
42940 `${stepsName} shoud be specified.`);
42941 }
42942 return numSamples;
42943 }
42944 /**
42945 * Execute internal tensors of the model with input data feed.
42946 * @param inputs Input data feed. Must match the inputs of the model.
42947 * @param outputs Names of the output tensors to be fetched. Must match
42948 * names of the SymbolicTensors that belong to the graph.
42949 * @returns Fetched values for `outputs`.
42950 */
42951 execute(inputs, outputs) {
42952 if (Array.isArray(outputs) && outputs.length === 0) {
42953 throw new ValueError('`outputs` is an empty Array, which is not allowed.');
42954 }
42955 const outputsIsArray = Array.isArray(outputs);
42956 const outputNames = (outputsIsArray ? outputs : [outputs]);
42957 const outputSymbolicTensors = this.retrieveSymbolicTensors(outputNames);
42958 // Format the input into a FeedDict.
42959 const feedDict = new FeedDict();
42960 if (inputs instanceof Tensor) {
42961 inputs = [inputs];
42962 }
42963 if (Array.isArray(inputs)) {
42964 if (inputs.length !== this.inputs.length) {
42965 throw new ValueError(`The number of inputs provided (${inputs.length}) ` +
42966 `does not match the number of inputs of this model ` +
42967 `(${this.inputs.length}).`);
42968 }
42969 for (let i = 0; i < this.inputs.length; ++i) {
42970 feedDict.add(this.inputs[i], inputs[i]);
42971 }
42972 }
42973 else {
42974 for (const input of this.inputs) {
42975 const tensorValue = inputs[input.name];
42976 if (tensorValue == null) {
42977 throw new ValueError(`No value is provided for the model's input ${input.name}`);
42978 }
42979 feedDict.add(input, tensorValue);
42980 }
42981 }
42982 // Run execution.
42983 const executeOutputs = execute(outputSymbolicTensors, feedDict);
42984 return outputsIsArray ? executeOutputs : executeOutputs[0];
42985 }
42986 /**
42987 * Retrieve the model's internal symbolic tensors from symbolic-tensor names.
42988 */
42989 retrieveSymbolicTensors(symbolicTensorNames) {
42990 const outputSymbolicTensors = pyListRepeat(null, symbolicTensorNames.length);
42991 let outputsRemaining = symbolicTensorNames.length;
42992 for (const layer of this.layers) {
42993 const layerOutputs = Array.isArray(layer.output) ? layer.output : [layer.output];
42994 const layerOutputNames = layerOutputs.map(output => output.name);
42995 for (let i = 0; i < symbolicTensorNames.length; ++i) {
42996 const index = layerOutputNames.indexOf(symbolicTensorNames[i]);
42997 if (index !== -1) {
42998 outputSymbolicTensors[i] = layerOutputs[index];
42999 outputsRemaining--;
43000 }
43001 if (outputsRemaining === 0) {
43002 break;
43003 }
43004 }
43005 if (outputsRemaining === 0) {
43006 break;
43007 }
43008 }
43009 if (outputsRemaining > 0) {
43010 const remainingNames = [];
43011 outputSymbolicTensors.forEach((tensor, i) => {
43012 if (tensor == null) {
43013 remainingNames.push(symbolicTensorNames[i]);
43014 }
43015 });
43016 throw new ValueError(`Cannot find SymbolicTensors for output name(s): ` +
43017 `${JSON.stringify(remainingNames)}`);
43018 }
43019 return outputSymbolicTensors;
43020 }
43021 /**
43022 * Helper method to loop over some data in batches.
43023 *
43024 * Porting Note: Not using the functional approach in the Python equivalent
43025 * due to the imperative backend.
43026 * Porting Note: Does not support step mode currently.
43027 *
43028 * @param ins: input data
43029 * @param batchSize: integer batch size.
43030 * @param verbose: verbosity model
43031 * @returns: Predictions as `tf.Tensor` (if a single output) or an `Array` of
     *   `tf.Tensor` (if multiple outputs).
43033 */
    predictLoop(ins, batchSize = 32, verbose = false) {
        return tidy(() => {
            const numSamples = this.checkNumSamples(ins);
            if (verbose) {
                throw new NotImplementedError('Verbose predictLoop() is not implemented yet.');
            }
            // Sample-based predictions.
            // Porting Note: Tensor currently does not support sliced assignments as
            // in numpy, e.g., x[1:3] = y. Therefore we use concatenation while
            // iterating over the batches.
            const batches = makeBatches(numSamples, batchSize);
            // One accumulator array of per-batch outputs per model output head.
            const outsBatches = this.outputs.map(output => []);
            // TODO(cais): Can the scope() be pushed down inside the for loop?
            for (let batchIndex = 0; batchIndex < batches.length; ++batchIndex) {
                // Inner tidy() disposes the per-batch intermediates once the
                // batch outputs have been produced.
                const batchOuts = tidy(() => {
                    const batchStart = batches[batchIndex][0];
                    const batchEnd = batches[batchIndex][1];
                    // TODO(cais): Take care of the case of the last element is a flag for
                    // training/test.
                    const insBatch = sliceArrays(ins, batchStart, batchEnd);
                    // Construct the feeds for execute();
                    const feeds = [];
                    if (Array.isArray(insBatch)) {
                        for (let i = 0; i < insBatch.length; ++i) {
                            feeds.push({ key: this.inputs[i], value: insBatch[i] });
                        }
                    }
                    else {
                        feeds.push({ key: this.inputs[0], value: insBatch });
                    }
                    const feedDict = new FeedDict(feeds);
                    return execute(this.outputs, feedDict);
                });
                batchOuts.forEach((batchOut, i) => outsBatches[i].push(batchOut));
            }
            // Stitch the per-batch outputs back together along the batch axis.
            return singletonOrArray(outsBatches.map(batches => concat(batches, 0)));
        });
    }
43072 /**
43073 * Generates output predictions for the input samples.
43074 *
43075 * Computation is done in batches.
43076 *
43077 * Note: the "step" mode of predict() is currently not supported.
43078 * This is because the TensorFlow.js core backend is imperative only.
43079 *
43080 * ```js
43081 * const model = tf.sequential({
43082 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
43083 * });
43084 * model.predict(tf.ones([8, 10]), {batchSize: 4}).print();
43085 * ```
43086 *
43087 * @param x The input data, as a Tensor, or an `Array` of `tf.Tensor`s if
43088 * the model has multiple inputs.
43089 * @param args A `ModelPredictArgs` object containing optional fields.
43090 *
43091 * @return Prediction results as a `tf.Tensor`(s).
43092 *
43093 * @exception ValueError In case of mismatch between the provided input data
43094 * and the model's expectations, or in case a stateful model receives a
43095 * number of samples that is not a multiple of the batch size.
43096 *
43097 * @doc {heading: 'Models', subheading: 'Classes'}
43098 */
43099 predict(x, args = {}) {
43100 const xsRank2OrHigher = ensureTensorsRank2OrHigher(x);
43101 checkInputData(xsRank2OrHigher, this.inputNames, this.feedInputShapes, false);
43102 try {
43103 // TODO(cais): Take care of stateful models.
43104 // if (this.stateful) ...
43105 // TODO(cais): Take care of the learning_phase boolean flag.
43106 // if (this.useLearningPhase) ...
43107 const batchSize = args.batchSize == null ? 32 : args.batchSize;
43108 checkBatchSize(batchSize);
43109 return this.predictLoop(xsRank2OrHigher, batchSize);
43110 }
43111 finally {
43112 disposeNewTensors(xsRank2OrHigher, x);
43113 }
43114 }
43115 /**
43116 * Returns predictions for a single batch of samples.
43117 *
43118 * ```js
43119 * const model = tf.sequential({
43120 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
43121 * });
43122 * model.predictOnBatch(tf.ones([8, 10])).print();
43123 * ```
43124 * @param x: Input samples, as a Tensor (for models with exactly one
43125 * input) or an array of Tensors (for models with more than one input).
43126 * @return Tensor(s) of predictions
43127 *
43128 * @doc {heading: 'Models', subheading: 'Classes'}
43129 */
43130 predictOnBatch(x) {
43131 checkInputData(x, this.inputNames, this.feedInputShapes, true);
43132 // TODO(cais): Take care of the learning_phase boolean flag.
43133 // if (this.useLearningPhase) ...
43134 const batchSize = (Array.isArray(x) ? x[0] : x).shape[0];
43135 return this.predictLoop(x, batchSize);
43136 }
    standardizeUserDataXY(x, y, checkBatchAxis = true, batchSize) {
        // TODO(cais): Add sampleWeight, classWeight
        // NOTE(review): `checkBatchAxis` is accepted but never referenced in
        // this body — presumably reserved for future use; confirm before
        // removing.
        if (this.optimizer_ == null) {
            throw new RuntimeError('You must compile a model before training/testing. Use ' +
                'LayersModel.compile(modelCompileArgs).');
        }
        // Determine the expected target (y) shapes, one per model output.
        const outputShapes = [];
        for (let i = 0; i < this.feedOutputShapes.length; ++i) {
            const outputShape = this.feedOutputShapes[i];
            const lossFn = this.feedLossFns[i];
            if (lossFn === sparseCategoricalCrossentropy) {
                // Sparse targets carry class indices, so the last dimension is
                // 1 rather than the number of classes.
                outputShapes.push(outputShape.slice(0, outputShape.length - 1).concat([1]));
            }
            else {
                // Porting Note: Because of strong typing `lossFn` must be a function.
                outputShapes.push(outputShape);
            }
        }
        // Normalize x and y into arrays of tensors matched to the feed names.
        x = standardizeInputData(x, this.feedInputNames, this.feedInputShapes, false, 'input');
        y = standardizeInputData(y, this.feedOutputNames, outputShapes, false, 'target');
        // TODO(cais): Standardize sampleWeights & classWeights.
        checkArrayLengths(x, y, null);
        // TODO(cais): Check sampleWeights as well.
        checkLossAndTargetCompatibility(y, this.feedLossFns, this.feedOutputShapes);
        // Stateful models carry state across batches, so the sample count must
        // divide evenly into batches.
        if (this.stateful && batchSize != null && batchSize > 0) {
            if (x[0].shape[0] % batchSize !== 0) {
                throw new ValueError(`In a stateful network, you should only pass inputs with a ` +
                    `number of samples that is divisible by the batch size ` +
                    `${batchSize}. Found: ${x[0].shape[0]} sample(s).`);
            }
        }
        return [x, y];
    }
43170 async standardizeUserData(x, y, sampleWeight, classWeight, checkBatchAxis = true, batchSize) {
43171 const [standardXs, standardYs] = this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize);
43172 // TODO(cais): Handle sampleWeights.
43173 if (sampleWeight != null) {
43174 throw new Error('sample weight is not supported yet.');
43175 }
43176 let standardSampleWeights = null;
43177 if (classWeight != null) {
43178 const classWeights = standardizeClassWeights(classWeight, this.outputNames);
43179 standardSampleWeights = [];
43180 for (let i = 0; i < classWeights.length; ++i) {
43181 standardSampleWeights.push(await standardizeWeights(standardYs[i], null, classWeights[i]));
43182 }
43183 }
43184 // TODO(cais): Deal with the case of model.stateful == true.
43185 return [standardXs, standardYs, standardSampleWeights];
43186 }
43187 /**
43188 * Loop over some test data in batches.
43189 * @param f A Function returning a list of tensors.
43190 * @param ins Array of tensors to be fed to `f`.
43191 * @param batchSize Integer batch size or `null` / `undefined`.
43192 * @param verbose verbosity mode.
43193 * @param steps Total number of steps (batches of samples) before
43194 * declaring test finished. Ignored with the default value of `null` /
43195 * `undefined`.
43196 * @returns Array of Scalars.
43197 */
    testLoop(f, ins, batchSize, verbose = 0, steps) {
        return tidy(() => {
            const numSamples = this.checkNumSamples(ins, batchSize, steps, 'steps');
            const outs = [];
            if (verbose > 0) {
                throw new NotImplementedError('Verbose mode is not implemented yet.');
            }
            // TODO(cais): Use `indicesForConversionToDense' to prevent slow down.
            if (steps != null) {
                throw new NotImplementedError('steps mode in testLoop() is not implemented yet');
            }
            else {
                const batches = makeBatches(numSamples, batchSize);
                // Index tensor [0, ..., numSamples - 1] used to gather batches.
                const indexArray = tensor1d(range$1(0, numSamples));
                for (let batchIndex = 0; batchIndex < batches.length; ++batchIndex) {
                    const batchStart = batches[batchIndex][0];
                    const batchEnd = batches[batchIndex][1];
                    const batchIds = sliceAlongFirstAxis(indexArray, batchStart, batchEnd - batchStart);
                    // TODO(cais): In ins, train flag can be a number, instead of an
                    // Tensor? Do we need to handle this in tfjs-layers?
                    const insBatch = sliceArraysByIndices(ins, batchIds);
                    const batchOuts = f(insBatch);
                    if (batchIndex === 0) {
                        // Lazily initialize one zero accumulator per returned value.
                        for (let i = 0; i < batchOuts.length; ++i) {
                            outs.push(scalar(0));
                        }
                    }
                    // Accumulate each output weighted by the batch size so the
                    // division below yields a sample-weighted average.
                    for (let i = 0; i < batchOuts.length; ++i) {
                        const batchOut = batchOuts[i];
                        outs[i] =
                            add$1(outs[i], mul(batchEnd - batchStart, batchOut));
                    }
                }
                // Convert the weighted sums into averages over all samples.
                for (let i = 0; i < outs.length; ++i) {
                    outs[i] = div(outs[i], numSamples);
                }
            }
            return outs;
        });
    }
43238 getDedupedMetricsNames() {
43239 const outLabels = this.metricsNames;
43240 // Rename duplicated metrics names (can happen with an output layer
43241 // shared among multiple dataflows).
43242 const dedupedOutLabels = [];
43243 for (let i = 0; i < outLabels.length; ++i) {
43244 const label = outLabels[i];
43245 let newLabel = label;
43246 if (count(outLabels, label) > 1) {
43247 const dupIndex = count(outLabels.slice(0, i), label);
43248 newLabel += `_${dupIndex}`;
43249 }
43250 dedupedOutLabels.push(newLabel);
43251 }
43252 return dedupedOutLabels;
43253 }
43254 /**
43255 * Creates a function that performs the following actions:
43256 *
43257 * 1. computes the losses
43258 * 2. sums them to get the total loss
43259 * 3. call the optimizer computes the gradients of the LayersModel's
43260 * trainable weights w.r.t. the total loss and update the variables
43261 * 4. calculates the metrics
43262 * 5. returns the values of the losses and metrics.
43263 */
    makeTrainFunction() {
        return (data) => {
            const lossValues = [];
            // `data` layout: [...inputs, ...targets, ...sampleWeights], with
            // one sample-weight slot per model output.
            const inputs = data.slice(0, this.inputs.length);
            const targets = data.slice(this.inputs.length, this.inputs.length + this.outputs.length);
            const sampleWeights = data.slice(this.inputs.length + this.outputs.length, this.inputs.length + this.outputs.length * 2);
            const metricsValues = [];
            // Create a function that computes the total loss based on the
            // inputs. This function is used for obtaining gradients through
            // backprop.
            const totalLossFunction = () => {
                const feeds = [];
                for (let i = 0; i < this.inputs.length; ++i) {
                    feeds.push({ key: this.inputs[i], value: inputs[i] });
                }
                const feedDict = new FeedDict(feeds);
                // Forward pass with the training flag set.
                const outputs = execute(this.outputs, feedDict, { 'training': true });
                // TODO(cais): Take care of the case of multiple outputs from a
                // single layer?
                let totalLoss;
                for (let i = 0; i < this.lossFunctions.length; ++i) {
                    const lossFunction = this.lossFunctions[i];
                    let loss = lossFunction(targets[i], outputs[i]);
                    if (sampleWeights[i] != null) {
                        loss = computeWeightedLoss$1(loss, sampleWeights[i]);
                    }
                    // TODO(cais): push Scalar instead.
                    const meanLoss = mean(loss);
                    // TODO(cais): Use a scope() instead, to avoid ownership.
                    lossValues.push(meanLoss);
                    // Sum the per-output losses into the total loss.
                    if (i === 0) {
                        totalLoss = loss;
                    }
                    else {
                        totalLoss = add$1(totalLoss, loss);
                    }
                }
                // Compute the metrics.
                // TODO(cais): These should probably be calculated outside
                // totalLossFunction to benefit speed?
                for (let i = 0; i < this.metricsTensors.length; ++i) {
                    let weightedMetric;
                    if (this.outputs.length > 1 && i < this.outputs.length) {
                        // Multi-output models: the leading metric entries are the
                        // per-output losses already computed above.
                        weightedMetric = lossValues[i];
                    }
                    else {
                        const metric = this.metricsTensors[i][0];
                        const outputIndex = this.metricsTensors[i][1];
                        weightedMetric =
                            mean(metric(targets[outputIndex], outputs[outputIndex]));
                    }
                    // keep() prevents the enclosing scope from disposing the
                    // metric tensor before it is returned to the caller.
                    keep(weightedMetric);
                    // TODO(cais): Use a scope() instead, to avoid ownership.
                    metricsValues.push(weightedMetric);
                }
                totalLoss = mean(totalLoss);
                // Add regularizer penalties.
                this.calculateLosses().forEach(regularizerLoss => {
                    totalLoss = add$1(totalLoss, regularizerLoss);
                });
                return totalLoss;
            };
            // Gradients are taken w.r.t. the collected trainable weights.
            const variables = this.collectedTrainableWeights.map(param => param.read());
            const returnCost = true;
            // minimize() evaluates totalLossFunction, backpropagates, updates
            // `variables`, and (because returnCost is true) returns the loss.
            const totalLossValue = this.optimizer_.minimize(totalLossFunction, returnCost, variables);
            return [totalLossValue].concat(metricsValues);
        };
    }
43332 /**
43333 * Create a function which, when invoked with an array of `tf.Tensor`s as a
43334 * batch of inputs, returns the prespecified loss and metrics of the model
43335 * under the batch of input data.
43336 */
    makeTestFunction() {
        this.testFunction = (data) => {
            return tidy(() => {
                const valOutputs = [];
                let totalLoss;
                // `data` layout: [...inputs, ...targets].
                const inputs = data.slice(0, this.inputs.length);
                const targets = data.slice(this.inputs.length, this.inputs.length + this.outputs.length);
                const feeds = [];
                for (let i = 0; i < this.inputs.length; ++i) {
                    feeds.push({ key: this.inputs[i], value: inputs[i] });
                }
                const feedDict = new FeedDict(feeds);
                // Forward pass (inference mode: no training flag set).
                const outputs = execute(this.outputs, feedDict);
                // Compute total loss.
                for (let i = 0; i < this.lossFunctions.length; ++i) {
                    const lossFunction = this.lossFunctions[i];
                    // TODO(cais): Add sample weighting and replace the simple
                    // averaging.
                    const loss = mean(lossFunction(targets[i], outputs[i]));
                    if (i === 0) {
                        totalLoss = loss;
                    }
                    else {
                        totalLoss = add$1(totalLoss, loss);
                    }
                    // NOTE(review): this pushes the *running* total, so for
                    // multi-output models later entries include earlier losses —
                    // verify this matches the metricsNames ordering.
                    valOutputs.push(totalLoss);
                }
                // Compute the metrics.
                for (let i = 0; i < this.metricsTensors.length; ++i) {
                    const metric = this.metricsTensors[i][0];
                    const outputIndex = this.metricsTensors[i][1];
                    // TODO(cais): Replace K.mean() with a proper weighting function.
                    const meanMetric = mean(metric(targets[outputIndex], outputs[outputIndex]));
                    valOutputs.push(meanMetric);
                }
                return valOutputs;
            });
        };
    }
43376 /**
43377 * Trains the model for a fixed number of epochs (iterations on a
43378 * dataset).
43379 *
43380 * ```js
43381 * const model = tf.sequential({
43382 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
43383 * });
43384 * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
43385 * for (let i = 1; i < 5 ; ++i) {
43386 * const h = await model.fit(tf.ones([8, 10]), tf.ones([8, 1]), {
43387 * batchSize: 4,
43388 * epochs: 3
43389 * });
43390 * console.log("Loss after Epoch " + i + " : " + h.history.loss[0]);
43391 * }
43392 * ```
43393 *
43394 * @param x `tf.Tensor` of training data, or an array of `tf.Tensor`s if the
43395 * model has multiple inputs. If all inputs in the model are named, you
43396 * can also pass a dictionary mapping input names to `tf.Tensor`s.
43397 * @param y `tf.Tensor` of target (label) data, or an array of `tf.Tensor`s if
43398 * the model has multiple outputs. If all outputs in the model are named,
43399 * you can also pass a dictionary mapping output names to `tf.Tensor`s.
43400 * @param args A `ModelFitArgs`, containing optional fields.
43401 *
43402 * @return A `History` instance. Its `history` attribute contains all
43403 * information collected during training.
43404 *
43405 * @exception ValueError In case of mismatch between the provided input
43406 * data and what the model expects.
43407 *
43408 * @doc {heading: 'Models', subheading: 'Classes'}
43409 */
43410 async fit(x, y, args = {}) {
43411 return fitTensors(this, x, y, args);
43412 }
43413 // TODO(cais): Add code snippet below when it's possible to instantiate
43414 // actual dataset objects.
43415 /**
43416 * Trains the model using a dataset object.
43417 *
43418 * @param dataset A dataset object. Its `iterator()` method is expected
43419 * to generate a dataset iterator object, the `next()` method of which
43420 * is expected to produce data batches for training. The return value
43421 * of the `next()` call ought to contain a boolean `done` field and a
43422 * `value` field. The `value` field is expected to be an array of two
43423 * `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former
     * case is for models with exactly one input and one output (e.g.,
43425 * a sequential model). The latter case is for models with multiple
43426 * inputs and/or multiple outputs.
43427 * Of the two items in the array, the first is the input feature(s) and
43428 * the second is the output target(s).
43429 * @param args A `ModelFitDatasetArgs`, containing optional fields.
43430 *
43431 * @return A `History` instance. Its `history` attribute contains all
43432 * information collected during training.
43433 *
43434 * @doc {heading: 'Models', subheading: 'Classes'}
43435 */
43436 async fitDataset(dataset, args) {
43437 return fitDataset(this, dataset, args);
43438 }
43439 /**
43440 * Runs a single gradient update on a single batch of data.
43441 *
43442 * This method differs from `fit()` and `fitDataset()` in the following
43443 * regards:
43444 * - It operates on exactly one batch of data.
     * - It returns only the loss and metric values, instead of
43446 * returning the batch-by-batch loss and metric values.
43447 * - It doesn't support fine-grained options such as verbosity and
43448 * callbacks.
43449 *
43450 * @param x Input data. It could be one of the following:
43451 * - A `tf.Tensor`, or an Array of `tf.Tensor`s (in case the model has
43452 * multiple inputs).
43453 * - An Object mapping input names to corresponding `tf.Tensor` (if the
43454 * model has named inputs).
     * @param y Target data. It could be either a `tf.Tensor` or multiple
43456 * `tf.Tensor`s. It should be consistent with `x`.
43457 * @returns Training loss or losses (in case the model has
43458 * multiple outputs), along with metrics (if any), as numbers.
43459 *
43460 * @doc {heading: 'Models', subheading: 'Classes'}
43461 */
43462 async trainOnBatch(x, y) {
43463 // TODO(cais): Support sampleWeight and classWeight.
43464 // TODO(cais): Support Dataset objects.
43465 const standardizeOut = await this.standardizeUserData(x, y);
43466 const inputs = standardizeOut[0];
43467 const targets = standardizeOut[1];
43468 const trainFunction = this.makeTrainFunction();
43469 const losses = trainFunction(inputs.concat(targets));
43470 const lossValues = [];
43471 for (const loss of losses) {
43472 const v = await loss.data();
43473 lossValues.push(v[0]);
43474 }
43475 dispose(losses);
43476 disposeNewTensors(standardizeOut[0], x);
43477 disposeNewTensors(standardizeOut[1], y);
43478 return singletonOrArray(lossValues);
43479 }
43480 /**
43481 * Extract weight values of the model.
43482 *
43483 * @param config: An instance of `io.SaveConfig`, which specifies
43484 * model-saving options such as whether only trainable weights are to be
43485 * saved.
43486 * @returns A `NamedTensorMap` mapping original weight names (i.e.,
43487 * non-uniqueified weight names) to their values.
43488 */
43489 getNamedWeights(config) {
43490 const namedWeights = [];
43491 const trainableOnly = config != null && config.trainableOnly;
43492 const weights = trainableOnly ? this.trainableWeights : this.weights;
43493 const weightValues = this.getWeights(trainableOnly);
43494 for (let i = 0; i < weights.length; ++i) {
43495 if (trainableOnly && !weights[i].trainable) {
43496 // Optionally skip non-trainable weights.
43497 continue;
43498 }
43499 namedWeights.push({ name: weights[i].originalName, tensor: weightValues[i] });
43500 }
43501 return namedWeights;
43502 }
43503 /**
43504 * Setter used for force stopping of LayersModel.fit() (i.e., training).
43505 *
43506 * Example:
43507 *
43508 * ```js
43509 * const input = tf.input({shape: [10]});
43510 * const output = tf.layers.dense({units: 1}).apply(input);
43511 * const model = tf.model({inputs: [input], outputs: [output]});
43512 * model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
43513 * const xs = tf.ones([8, 10]);
43514 * const ys = tf.zeros([8, 1]);
43515 *
43516 * const history = await model.fit(xs, ys, {
43517 * epochs: 10,
43518 * callbacks: {
43519 * onEpochEnd: async (epoch, logs) => {
43520 * if (epoch === 2) {
43521 * model.stopTraining = true;
43522 * }
43523 * }
43524 * }
43525 * });
43526 *
     * // There should be only 3 values in the loss array, instead of 10
     * // values, due to the stopping after 3 epochs.
43530 * console.log(history.history.loss);
43531 * ```
43532 */
43533 set stopTraining(stop) {
43534 this.stopTraining_ = stop;
43535 }
    /** Whether a force-stop of training has been requested. */
    get stopTraining() {
        return this.stopTraining_;
    }
    /** The optimizer currently associated with this model, if any. */
    get optimizer() {
        return this.optimizer_;
    }
    set optimizer(optimizer) {
        if (this.optimizer_ !== optimizer) {
            this.optimizer_ = optimizer;
            // An optimizer assigned from outside is not owned by this model,
            // so `dispose()` will not dispose it.
            this.isOptimizerOwned = false;
        }
    }
43548 dispose() {
43549 const result = super.dispose();
43550 if (result.refCountAfterDispose === 0 && this.optimizer != null &&
43551 this.isOptimizerOwned) {
43552 const numTensorsBeforeOptmizerDisposal = memory().numTensors;
43553 this.optimizer_.dispose();
43554 result.numDisposedVariables +=
43555 numTensorsBeforeOptmizerDisposal - memory().numTensors;
43556 }
43557 return result;
43558 }
43559 getLossIdentifiers() {
43560 let lossNames;
43561 if (typeof this.loss === 'string') {
43562 lossNames = toSnakeCase(this.loss);
43563 }
43564 else if (Array.isArray(this.loss)) {
43565 for (const loss of this.loss) {
43566 if (typeof loss !== 'string') {
43567 throw new Error('Serialization of non-string loss is not supported.');
43568 }
43569 }
43570 lossNames = this.loss.map(name => toSnakeCase(name));
43571 }
43572 else {
43573 const outputNames = Object.keys(this.loss);
43574 lossNames = {};
43575 const losses = this.loss;
43576 for (const outputName of outputNames) {
43577 if (typeof losses[outputName] === 'string') {
43578 lossNames[outputName] =
43579 toSnakeCase(losses[outputName]);
43580 }
43581 else {
43582 throw new Error('Serialization of non-string loss is not supported.');
43583 }
43584 }
43585 }
43586 return lossNames;
43587 }
43588 getMetricIdentifiers() {
43589 if (typeof this.metrics === 'string' ||
43590 typeof this.metrics === 'function') {
43591 return [toSnakeCase(getLossOrMetricName(this.metrics))];
43592 }
43593 else if (Array.isArray(this.metrics)) {
43594 return this.metrics.map(metric => toSnakeCase(getLossOrMetricName(metric)));
43595 }
43596 else {
43597 const metricsIdentifiers = {};
43598 for (const key in this.metrics) {
43599 metricsIdentifiers[key] =
43600 toSnakeCase(getLossOrMetricName(this.metrics[key]));
43601 }
43602 return metricsIdentifiers;
43603 }
43604 }
43605 getTrainingConfig() {
43606 return {
43607 loss: this.getLossIdentifiers(),
43608 metrics: this.getMetricIdentifiers(),
43609 optimizer_config: {
43610 class_name: this.optimizer.getClassName(),
43611 config: this.optimizer.getConfig()
43612 }
43613 };
43614 // TODO(cais): Add weight_metrics when they are supported.
43615 // TODO(cais): Add sample_weight_mode when it's supported.
43616 // TODO(cais): Add loss_weights when it's supported.
43617 }
43618 loadTrainingConfig(trainingConfig) {
43619 if (trainingConfig.weighted_metrics != null) {
43620 throw new Error('Loading weight_metrics is not supported yet.');
43621 }
43622 if (trainingConfig.loss_weights != null) {
43623 throw new Error('Loading loss_weights is not supported yet.');
43624 }
43625 if (trainingConfig.sample_weight_mode != null) {
43626 throw new Error('Loading sample_weight_mode is not supported yet.');
43627 }
43628 const tsConfig = convertPythonicToTs(trainingConfig.optimizer_config);
43629 const optimizer = deserialize(tsConfig);
43630 let loss;
43631 if (typeof trainingConfig.loss === 'string') {
43632 loss = toCamelCase(trainingConfig.loss);
43633 }
43634 else if (Array.isArray(trainingConfig.loss)) {
43635 loss = trainingConfig.loss.map(lossEntry => toCamelCase(lossEntry));
43636 }
43637 else if (trainingConfig.loss != null) {
43638 loss = {};
43639 for (const key in trainingConfig.loss) {
43640 loss[key] = toCamelCase(trainingConfig.loss[key]);
43641 }
43642 }
43643 let metrics;
43644 if (Array.isArray(trainingConfig.metrics)) {
43645 metrics = trainingConfig.metrics.map(metric => toCamelCase(metric));
43646 }
43647 else if (trainingConfig.metrics != null) {
43648 metrics = {};
43649 for (const key in trainingConfig.metrics) {
43650 metrics[key] = toCamelCase(trainingConfig.metrics[key]);
43651 }
43652 }
43653 this.compile({ loss, metrics, optimizer });
43654 }
43655 /**
43656 * Save the configuration and/or weights of the LayersModel.
43657 *
43658 * An `IOHandler` is an object that has a `save` method of the proper
43659 * signature defined. The `save` method manages the storing or
43660 * transmission of serialized data ("artifacts") that represent the
43661 * model's topology and weights onto or via a specific medium, such as
43662 * file downloads, local storage, IndexedDB in the web browser and HTTP
43663 * requests to a server. TensorFlow.js provides `IOHandler`
43664 * implementations for a number of frequently used saving mediums, such as
43665 * `tf.io.browserDownloads` and `tf.io.browserLocalStorage`. See `tf.io`
43666 * for more details.
43667 *
43668 * This method also allows you to refer to certain types of `IOHandler`s
43669 * as URL-like string shortcuts, such as 'localstorage://' and
43670 * 'indexeddb://'.
43671 *
43672 * Example 1: Save `model`'s topology and weights to browser [local
43673 * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
43674 * then load it back.
43675 *
43676 * ```js
43677 * const model = tf.sequential(
43678 * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
43679 * console.log('Prediction from original model:');
43680 * model.predict(tf.ones([1, 3])).print();
43681 *
43682 * const saveResults = await model.save('localstorage://my-model-1');
43683 *
43684 * const loadedModel = await tf.loadLayersModel('localstorage://my-model-1');
43685 * console.log('Prediction from loaded model:');
43686 * loadedModel.predict(tf.ones([1, 3])).print();
43687 * ```
43688 *
43689 * Example 2. Saving `model`'s topology and weights to browser
43690 * [IndexedDB](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API);
43691 * then load it back.
43692 *
43693 * ```js
43694 * const model = tf.sequential(
43695 * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
43696 * console.log('Prediction from original model:');
43697 * model.predict(tf.ones([1, 3])).print();
43698 *
43699 * const saveResults = await model.save('indexeddb://my-model-1');
43700 *
43701 * const loadedModel = await tf.loadLayersModel('indexeddb://my-model-1');
43702 * console.log('Prediction from loaded model:');
43703 * loadedModel.predict(tf.ones([1, 3])).print();
43704 * ```
43705 *
43706 * Example 3. Saving `model`'s topology and weights as two files
43707 * (`my-model-1.json` and `my-model-1.weights.bin`) downloaded from
43708 * browser.
43709 *
43710 * ```js
43711 * const model = tf.sequential(
43712 * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
43713 * const saveResults = await model.save('downloads://my-model-1');
43714 * ```
43715 *
43716 * Example 4. Send `model`'s topology and weights to an HTTP server.
43717 * See the documentation of `tf.io.http` for more details
43718 * including specifying request parameters and implementation of the
43719 * server.
43720 *
43721 * ```js
43722 * const model = tf.sequential(
43723 * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
43724 * const saveResults = await model.save('http://my-server/model/upload');
43725 * ```
43726 *
43727 * @param handlerOrURL An instance of `IOHandler` or a URL-like,
43728 * scheme-based string shortcut for `IOHandler`.
43729 * @param config Options for saving the model.
43730 * @returns A `Promise` of `SaveResult`, which summarizes the result of
43731 * the saving, such as byte sizes of the saved artifacts for the model's
43732 * topology and weight values.
43733 *
43734 * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
43735 */
    async save(handlerOrURL, config) {
        if (typeof handlerOrURL === 'string') {
            // Resolve URL-like shortcuts (e.g., 'localstorage://') to a
            // concrete IOHandler via the registered save routers; exactly one
            // handler must match.
            const handlers = getSaveHandlers(handlerOrURL);
            if (handlers.length === 0) {
                throw new ValueError(`Cannot find any save handlers for URL '${handlerOrURL}'`);
            }
            else if (handlers.length > 1) {
                throw new ValueError(`Found more than one (${handlers.length}) save handlers for ` +
                    `URL '${handlerOrURL}'`);
            }
            handlerOrURL = handlers[0];
        }
        if (handlerOrURL.save == null) {
            throw new ValueError('LayersModel.save() cannot proceed because the IOHandler ' +
                'provided does not have the `save` attribute defined.');
        }
        // Serialize weight values (optionally trainable-only, per `config`).
        const weightDataAndSpecs = await encodeWeights(this.getNamedWeights(config));
        const returnString = false;
        const unusedArg = null;
        // Topology as a JSON object (not a string).
        const modelConfig = this.toJSON(unusedArg, returnString);
        const modelArtifacts = {
            modelTopology: modelConfig,
            format: LAYERS_MODEL_FORMAT_NAME,
            generatedBy: `TensorFlow.js tfjs-layers v${version$1}`,
            convertedBy: null,
        };
        const includeOptimizer = config == null ? false : config.includeOptimizer;
        if (includeOptimizer && this.optimizer != null) {
            // Append optimizer weights after the model weights, tagged with
            // the 'optimizer' group so loading can separate them again.
            modelArtifacts.trainingConfig = this.getTrainingConfig();
            const weightType = 'optimizer';
            const { data: optimizerWeightData, specs: optimizerWeightSpecs } = await encodeWeights(await this.optimizer.getWeights(), weightType);
            weightDataAndSpecs.specs.push(...optimizerWeightSpecs);
            weightDataAndSpecs.data = concatenateArrayBuffers([weightDataAndSpecs.data, optimizerWeightData]);
        }
        if (this.userDefinedMetadata != null) {
            // Check serialized size of user-defined metadata.
            const checkSize = true;
            checkUserDefinedMetadata(this.userDefinedMetadata, this.name, checkSize);
            modelArtifacts.userDefinedMetadata = this.userDefinedMetadata;
        }
        modelArtifacts.weightData = weightDataAndSpecs.data;
        modelArtifacts.weightSpecs = weightDataAndSpecs.specs;
        return handlerOrURL.save(modelArtifacts);
    }
    /**
     * Set user-defined metadata.
     *
     * The set metadata will be serialized together with the topology
     * and weights of the model during `save()` calls.
     *
     * @param userDefinedMetadata The metadata object to attach; validated by
     *   `checkUserDefinedMetadata` before being stored.
     */
    setUserDefinedMetadata(userDefinedMetadata) {
        checkUserDefinedMetadata(userDefinedMetadata, this.name);
        this.userDefinedMetadata = userDefinedMetadata;
    }
    /**
     * Get user-defined metadata.
     *
     * The metadata is supplied via one of the two routes:
     * 1. By calling `setUserDefinedMetadata()`.
     * 2. Loaded during model loading (if the model is constructed
     *    via `tf.loadLayersModel()`.)
     *
     * If no user-defined metadata is available from either of the
     * two routes, this function will return `undefined`.
     */
    getUserDefinedMetadata() {
        return this.userDefinedMetadata;
    }
43806 }
    // The class name is 'Model' rather than 'LayersModel' for backwards
    // compatibility since this class name shows up in the serialization format.
    /** @nocollapse */
    LayersModel.className = 'Model';
    // Register the class so it can be looked up by `className` during
    // deserialization.
    registerClass(LayersModel);
    /**
     * A `tf.Functional` is an alias to `tf.LayersModel`.
     *
     * See also:
     * `tf.LayersModel`, `tf.Sequential`, `tf.loadLayersModel`.
     */
    /** @doc {heading: 'Models', subheading: 'Classes'} */
    class Functional extends LayersModel {
    }
    // Empty subclass: registered under the 'Functional' class name so that
    // serialized configs using that name deserialize to a LayersModel.
    Functional.className = 'Functional';
    registerClass(Functional);
43823
43824 /**
43825 * @license
43826 * Copyright 2018 Google LLC
43827 *
43828 * Use of this source code is governed by an MIT-style
43829 * license that can be found in the LICENSE file or at
43830 * https://opensource.org/licenses/MIT.
43831 * =============================================================================
43832 */
43833 /**
43834 * Parses a JSON model configuration file and returns a model instance.
43835 *
43836 * ```js
43837 * // This example shows how to serialize a model using `toJSON()` and
43838 * // deserialize it as another model using `tf.models.modelFromJSON()`.
43839 * // Note: this example serializes and deserializes only the topology
43840 * // of the model; the weights of the loaded model will be different
43841 * // from those of the the original model, due to random weight
43842 * // initialization.
43843 * // To load the topology and weights of a model, use `tf.loadLayersModel()`.
43844 * const model1 = tf.sequential();
43845 * model1.add(tf.layers.repeatVector({inputShape: [2], n: 4}));
43846 * // Serialize `model1` as a JSON object.
43847 * const model1JSON = model1.toJSON(null, false);
43848 * model1.summary();
43849 *
43850 * const model2 = await tf.models.modelFromJSON(model1JSON);
43851 * model2.summary();
43852 * ```
43853 *
43854 * @param modelAndWeightsConfig JSON object or string encoding a model and
43855 * weights configuration. It can also be only the topology JSON of the
43856 * model, in which case the weights will not be loaded.
43857 * @param custom_objects Optional dictionary mapping names
43858 * (strings) to custom classes or functions to be
43859 * considered during deserialization.
43860 * @returns A TensorFlow.js Layers `tf.LayersModel` instance (uncompiled).
43861 */
43862 async function modelFromJSON(modelAndWeightsConfig, customObjects) {
43863 if (!('modelTopology' in modelAndWeightsConfig)) {
43864 modelAndWeightsConfig = { modelTopology: modelAndWeightsConfig };
43865 }
43866 modelAndWeightsConfig = modelAndWeightsConfig;
43867 let modelTopology = modelAndWeightsConfig.modelTopology;
43868 if (modelTopology['model_config'] != null) {
43869 // If the model-topology JSON contains a 'model_config' field, then it is
43870 // a full model JSON (e.g., from `keras.Model.save()`), which contains
43871 // not only the model's architecture in its 'model_config' field, but
43872 // additional information such as the model's optimizer. We use only the
43873 // 'model_config' field currently.
43874 modelTopology = modelTopology['model_config'];
43875 }
43876 const tsConfig = convertPythonicToTs(modelTopology);
43877 const model = deserialize(tsConfig, customObjects);
43878 if (modelAndWeightsConfig.weightsManifest != null) {
43879 // Load the weight values keyed by the original tensor names in the model
43880 // file that was loaded. These should match the keys of the weight
43881 // manifest.
43882 const weightValues = await loadWeights(modelAndWeightsConfig.weightsManifest, modelAndWeightsConfig.pathPrefix, model.weights.map(weight => weight.originalName));
43883 // Map the weights to the unique tensor names generated during model loading
43884 const uniqueWeightValues = {};
43885 for (const weight of model.weights) {
43886 uniqueWeightValues[weight.originalName] =
43887 weightValues[weight.originalName];
43888 }
43889 model.loadWeights(uniqueWeightValues);
43890 // Dispose temporary weight values.
43891 dispose(weightValues);
43892 }
43893 return model;
43894 }
43895 /**
43896 * Load a model, including its topology and optionally weights. See the
43897 * Tutorial named "How to import a Keras Model" for usage examples.
43898 *
43899 * Example 1: Save `model`'s topology and weights to browser [local
43900 * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
43901 * then load it back.
43902 *
43903 * ```js
43904 * const model = tf.sequential(
43905 * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
43906 * console.log('Prediction from original model:');
43907 * model.predict(tf.ones([1, 3])).print();
43908 *
43909 * const saveResults = await model.save('localstorage://my-model-1');
43910 *
43911 * const loadedModel = await tf.loadLayersModel('localstorage://my-model-1');
43912 * console.log('Prediction from loaded model:');
43913 * loadedModel.predict(tf.ones([1, 3])).print();
43914 * ```
43915 *
43916 * Example 2. Saving `model`'s topology and weights to browser
43917 * [IndexedDB](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API);
43918 * then load it back.
43919 *
43920 * ```js
43921 * const model = tf.sequential(
43922 * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
43923 * console.log('Prediction from original model:');
43924 * model.predict(tf.ones([1, 3])).print();
43925 *
43926 * const saveResults = await model.save('indexeddb://my-model-1');
43927 *
43928 * const loadedModel = await tf.loadLayersModel('indexeddb://my-model-1');
43929 * console.log('Prediction from loaded model:');
43930 * loadedModel.predict(tf.ones([1, 3])).print();
43931 * ```
43932 *
43933 * Example 3. Load a model from user-selected files from HTML
43934 * [file input
43935 * elements](https://developer.mozilla.org/en-US/docs/Web/HTML/Element/input/file).
43936 *
43937 * ```js
43938 * // Note: this code snippet will not work without the HTML elements in the
43939 * // page
43940 * const jsonUpload = document.getElementById('json-upload');
43941 * const weightsUpload = document.getElementById('weights-upload');
43942 *
43943 * const model = await tf.loadLayersModel(
43944 * tf.io.browserFiles([jsonUpload.files[0], weightsUpload.files[0]]));
43945 * ```
43946 *
43947 * Example 4. Load a model from an HTTP server.
43948 *
43949 * ```js
43950 * const model = await
43951 * tf.loadLayersModel('https://storage.googleapis.com/tfjs-models/tfjs/iris_v1/model.json');
43952 * model.summary();
43953 * ```
43954 *
43955 * @param pathOrIOHandler Can be either of the two formats
43956 * 1. A string path to the `ModelAndWeightsConfig` JSON describing
43957 * the model in the canonical TensorFlow.js format. This path will be
43958 * interpreted as a relative HTTP path, to which `fetch` will be used to
43959 * request the model topology and weight manifest JSON.
43960 * The content of the JSON file is assumed to be a JSON object with the
43961 * following fields and values:
43962 * - 'modelTopology': A JSON object that can be either of:
43963 * 1. a model architecture JSON consistent with the format of the return
43964 * value of `keras.Model.to_json()`
43965 * 2. a full model JSON in the format of `keras.models.save_model()`.
43966 * - 'weightsManifest': A TensorFlow.js weights manifest.
43967 * See the Python converter function `save_model()` for more details.
43968 * It is also assumed that model weights can be accessed from relative
43969 * paths described by the `paths` fields in weights manifest.
43970 * 2. An `tf.io.IOHandler` object that loads model artifacts with its `load`
43971 * method.
43972 * @param options Optional configuration arguments for the model loading,
43973 * including:
43974 * - `strict`: Require that the provided weights exactly match those required
43975 * by the layers. Default true. Passing false means that both extra
43976 * weights and missing weights will be silently ignored.
43977 * - `onProgress`: A progress callback of the form:
43978 * `(fraction: number) => void`. This callback can be used to monitor the
43979 * model-loading process.
43980 * @returns A `Promise` of `tf.LayersModel`, with the topology and weights
43981 * loaded.
43982 */
43983 async function loadLayersModelInternal(pathOrIOHandler, options) {
43984 if (options == null) {
43985 options = {};
43986 }
43987 if (typeof pathOrIOHandler === 'string') {
43988 const handlers = getLoadHandlers(pathOrIOHandler, options);
43989 if (handlers.length === 0) {
43990 // For backward compatibility: if no load handler can be found,
43991 // assume it is a relative http path.
43992 // TODO(cais): Reformat the args into a single `LoadOptions` once the core
43993 // is refactored.
43994 handlers.push(browserHTTPRequest(pathOrIOHandler, options));
43995 }
43996 else if (handlers.length > 1) {
43997 throw new ValueError(`Found more than one (${handlers.length}) load handlers for ` +
43998 `URL '${pathOrIOHandler}'`);
43999 }
44000 pathOrIOHandler = handlers[0];
44001 }
44002 return loadLayersModelFromIOHandler(pathOrIOHandler, undefined, options);
44003 }
    /**
     * Load a model and optionally its weights, using an IOHandler object.
     *
     * @param handler The instance of `IOHandler` to be used during the model
     *   loading; must implement `load()`.
     * @param customObjects Any optional custom objects to be used during model
     *   loading.
     * @param options Optional load options; `options.strict` (default `true`)
     *   controls whether weight loading is done in strict mode.
     */
    async function loadLayersModelFromIOHandler(handler, customObjects, options) {
        if (options == null) {
            options = {};
        }
        if (handler.load == null) {
            throw new ValueError('Cannot proceed with model loading because the IOHandler provided ' +
                'does not have the `load` method implemented.');
        }
        const artifacts = await handler.load();
        let modelTopology = artifacts.modelTopology;
        // A full Keras model JSON nests the architecture under 'model_config'.
        if (modelTopology['model_config'] != null) {
            modelTopology = modelTopology['model_config'];
        }
        const strict = options.strict == null ? true : options.strict;
        // If weights are provided and the weight-loading mode is strict, use
        // fast weight initialization. This skips costly initializers such as
        // 'orthogonal' and saves unnecessary computation in cases where
        // the initialized weight values will immediately be overwritten by
        // loaded weight values.
        const fastWeightInit = artifacts.weightData != null && artifacts.weightSpecs != null && strict;
        const model = deserialize(convertPythonicToTs(modelTopology), customObjects, fastWeightInit);
        const trainingConfig = artifacts.trainingConfig;
        if (trainingConfig != null) {
            // Restores loss/metrics/optimizer via `compile()`.
            model.loadTrainingConfig(trainingConfig);
        }
        if (artifacts.userDefinedMetadata != null) {
            model.setUserDefinedMetadata(artifacts.userDefinedMetadata);
        }
        // If weightData is present, load the weights into the model.
        if (artifacts.weightData != null) {
            // Loading weights requires weightSpecs.
            if (artifacts.weightSpecs == null) {
                throw new ValueError('LayersModel artifacts contains weight data, but not weight specs. ' +
                    'Therefore loading of weights cannot proceed.');
            }
            // Specs tagged with the 'optimizer' group are routed to the
            // optimizer; all others go to the model itself.
            const { modelWeights, optimizerWeights } = decodeModelAndOptimizerWeights(artifacts.weightData, artifacts.weightSpecs);
            model.loadWeights(modelWeights, strict);
            if (model.optimizer != null && optimizerWeights.length > 0) {
                await model.optimizer.setWeights(optimizerWeights);
            }
            // Dispose temporary weight values.
            dispose(modelWeights);
            dispose(optimizerWeights.map(w => w.tensor));
        }
        return model;
    }
44060 function decodeModelAndOptimizerWeights(buffer, specs) {
44061 const name2Tensor = decodeWeights(buffer, specs);
44062 const modelWeights = {};
44063 const optimizerWeights = [];
44064 specs.forEach(spec => {
44065 if (spec.group === 'optimizer') {
44066 optimizerWeights.push({ name: spec.name, tensor: name2Tensor[spec.name] });
44067 }
44068 else {
44069 modelWeights[spec.name] = name2Tensor[spec.name];
44070 }
44071 });
44072 return { modelWeights, optimizerWeights };
44073 }
44074 /**
44075 * A model with a stack of layers, feeding linearly from one to the next.
44076 *
44077 * `tf.sequential` is a factory function that creates an instance of
44078 * `tf.Sequential`.
44079 *
44080 * ```js
44081 * // Define a model for linear regression.
44082 * const model = tf.sequential();
44083 * model.add(tf.layers.dense({units: 1, inputShape: [1]}));
44084 *
44085 * // Prepare the model for training: Specify the loss and the optimizer.
44086 * model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
44087 *
44088 * // Generate some synthetic data for training.
44089 * const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
44090 * const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);
44091 *
44092 * // Train the model using the data then do inference on a data point the
44093 * // model hasn't seen:
44094 * await model.fit(xs, ys);
44095 * model.predict(tf.tensor2d([5], [1, 1])).print();
44096 * ```
44097 *
44098 * @doc {heading: 'Models', subheading: 'Classes'}
44099 */
44100 class Sequential extends LayersModel {
44101 constructor(args) {
44102 super({ inputs: [], outputs: [] });
44103 args = args || {};
44104 this.trainable = true;
44105 this.built = false;
44106 // Set model name.
44107 this.name = (args.name != null) ? args.name : getUid('sequential_');
44108 // Add to the model any layers passed to the constructor.
44109 if (args.layers != null) {
44110 for (const layer of args.layers) {
44111 this.add(layer);
44112 }
44113 }
44114 }
44115 // Helper function to Sequential.add Throws if the new output shape will be
44116 // invalid.
44117 checkShape(layer) {
44118 const shape = layer.inboundNodes[0].outputTensors[0].shape;
44119 if (shape.some(x => x < 0)) {
44120 throw new ValueError('Negative dimension size caused by adding layer ' +
44121 `${layer.name} with input shape [` +
44122 `${layer.inboundNodes[0].inputTensors[0].shape}]`);
44123 }
44124 }
    /**
     * Adds a layer instance on top of the layer stack.
     *
     * ```js
     * const model = tf.sequential();
     * model.add(tf.layers.dense({units: 8, inputShape: [1]}));
     * model.add(tf.layers.dense({units: 4, activation: 'relu6'}));
     * model.add(tf.layers.dense({units: 1, activation: 'relu6'}));
     * // Note that the untrained model is random at this point.
     * model.predict(tf.randomNormal([10, 1])).print();
     * ```
     * @param layer Layer instance.
     *
     * @exception ValueError In case the `layer` argument does not know its
     * input shape.
     * @exception ValueError In case the `layer` argument has multiple output
     * tensors, or is already connected somewhere else (forbidden in
     * `Sequential` models).
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    add(layer) {
        const isLayerModelInstance = layer instanceof Sequential || layer instanceof LayersModel;
        let modelLayer;
        if (isLayerModelInstance) {
            modelLayer = layer;
            // A nested model must behave like a single-input, single-output
            // layer to fit into the sequential stack.
            if (modelLayer.outputs.length !== 1) {
                throw new ValueError('All layers in a Sequential model ' +
                    'should have a single output tensor. ' +
                    'For multi-output layers, ' +
                    'use the functional API.');
            }
            if (modelLayer.inputs.length !== 1) {
                throw new ValueError('All layers in a Sequential model ' +
                    'should have a single input tensor. ' +
                    'For multi-input layers, ' +
                    'use the functional API.');
            }
        }
        if (this.outputs.length === 0) {
            // first layer in model: check that it is an input layer
            if (layer.inboundNodes.length === 0) {
                // create an input layer
                if (layer.batchInputShape == null) {
                    throw new ValueError('The first layer in a Sequential model must ' +
                        'get an `inputShape` or `batchInputShape` argument.');
                }
                // Instantiate the input layer.
                const x = Input({
                    batchShape: layer.batchInputShape,
                    dtype: layer.dtype,
                    name: layer.name + '_input'
                });
                // This will build the current layer and create the node connecting
                // the current layer to the input layer we just created.
                layer.apply(x);
            }
            if (isLayerModelInstance) {
                // Adopt the nested model's graph endpoints directly.
                this.outputs = modelLayer.outputs;
                this.inputs = modelLayer.inputs;
            }
            else {
                if (layer.inboundNodes.length !== 1) {
                    throw new ValueError('A layer added to a Sequential model must not already be ' +
                        `connected somewhere else. LayersModel received layer ${layer.name} ` +
                        `which has ${layer.inboundNodes.length} pre-existing inbound ` +
                        'connections.');
                }
                if (layer.inboundNodes[0].outputTensors.length !== 1) {
                    throw new ValueError('All layers in a Sequential model ' +
                        'should have a single output tensor. ' +
                        'For multi-output layers, ' +
                        'use the functional API.');
                }
                this.checkShape(layer);
                this.outputs = [layer.inboundNodes[0].outputTensors[0]];
                this.inputs = getSourceInputs(this.outputs[0]);
            }
            this.inboundNodes = [];
            // We create an input node, which we will keep updated
            // as we add more layers.
            // (This call has side effects.)
            // tslint:disable-next-line:no-unused-expression
            new Node({
                outboundLayer: this,
                inboundLayers: [],
                nodeIndices: [],
                tensorIndices: [],
                inputTensors: this.inputs,
                outputTensors: this.outputs,
                // no model-level masking for now
                inputMasks: pyListRepeat(null, this.inputs.length),
                outputMasks: [null],
                inputShapes: this.inputs.map(x => x.shape),
                outputShapes: this.outputs[0].shape
            });
        }
        else {
            // Not the first layer: connect it to the current model output.
            const outputTensor = layer.apply(this.outputs[0]);
            if (Array.isArray(outputTensor)) {
                throw new TypeError('All layers in a Sequential model ' +
                    'should have a single output tensor. ' +
                    'For multi-output layers, ' +
                    'use the functional API.');
            }
            this.checkShape(layer);
            this.outputs = [outputTensor];
            // update self.inbound_nodes
            this.inboundNodes[0].outputTensors = this.outputs;
            this.inboundNodes[0].outputShapes = [this.outputs[0].shape];
        }
        this.layers.push(layer);
        // The internal LayersModel must be (re)built before the next use.
        this.built = false;
    }
44239 /**
44240 * Removes the last layer in the model.
44241 *
44242 * @exception TypeError if there are no layers in the model.
44243 */
44244 pop() {
44245 if (this.layers.length === 0) {
44246 throw new TypeError('There are no layers in the model.');
44247 }
44248 this.layers.pop();
44249 if (this.layers.length === 0) {
44250 this.outputs = [];
44251 this.inboundNodes = [];
44252 this.outboundNodes = [];
44253 }
44254 else {
44255 const lastLayerIndex = this.layers.length - 1;
44256 this.layers[lastLayerIndex].outboundNodes = [];
44257 this.outputs = [this.layers[lastLayerIndex].output];
44258 // update self.inbound_nodes
44259 this.inboundNodes[0].outputTensors = this.outputs;
44260 this.inboundNodes[0].outputShapes = [this.outputs[0].shape];
44261 }
44262 }
    /**
     * Run the model on `inputs`, delegating to the internal `LayersModel`
     * (built lazily on first call).
     */
    call(inputs, kwargs) {
        if (this.model == null) {
            this.build();
        }
        return this.model.call(inputs, kwargs);
    }
    /**
     * Build the internal `LayersModel` from the accumulated layer stack and
     * mirror its graph attributes onto this Sequential instance.
     *
     * @param inputShape Must resolve to exactly one shape (validated only).
     * @throws TypeError if the model has no layers yet.
     */
    build(inputShape) {
        // Call `getExactlyOneShape` without using its return value,
        // to verify that exactly one input shape is provided.
        getExactlyOneShape(inputShape);
        if (this.inputs.length === 0 || this.outputs.length === 0) {
            throw new TypeError('Sequential model cannot be built: model is empty.' +
                ' Add some layers first.');
        }
        // actually create the model
        this.model = new LayersModel({
            inputs: this.inputs,
            outputs: this.outputs[0],
            name: this.name + '_model'
        });
        this.model.trainable = this.trainable;
        // mirror model attributes
        this.supportsMasking = this.model.supportsMasking;
        // TODO(michaelterry): Add caches
        this.inputLayers = this.model.inputLayers;
        this.inputLayersNodeIndices = this.model.inputLayersNodeIndices;
        this.inputLayersTensorIndices = this.model.inputLayersTensorIndices;
        this.outputLayers = this.model.outputLayers;
        this.outputLayersNodeIndices = this.model.outputLayersNodeIndices;
        this.outputLayersTensorIndices = this.model.outputLayersTensorIndices;
        this.nodesByDepth = this.model.nodesByDepth;
        this.containerNodes = this.model.containerNodes;
        this.outputNames = this.model.outputNames;
        this.inputNames = this.model.inputNames;
        // TODO(michaelterry): Add feedInputNames, feedInputs, if needed.
        // TODO(michaelterry): Add callbackModel if needed.
        this.built = true;
    }
    /** Count the model's parameters, building the model first if necessary. */
    countParams() {
        if (!this.built) {
            this.build();
        }
        return super.countParams();
    }
44307 /**
44308 * Print a text summary of the Sequential model's layers.
44309 *
44310 * The summary includes
44311 * - Name and type of all layers that comprise the model.
44312 * - Output shape(s) of the layers
44313 * - Number of weight parameters of each layer
44314 * - The total number of trainable and non-trainable parameters of the
44315 * model.
44316 *
44317 * ```js
44318 * const model = tf.sequential();
44319 * model.add(
44320 * tf.layers.dense({units: 100, inputShape: [10], activation: 'relu'}));
44321 * model.add(tf.layers.dense({units: 1, activation: 'sigmoid'}));
44322 *
44323 * model.summary();
44324 * ```
44325 *
44326 * @param lineLength Custom line length, in number of characters.
44327 * @param positions Custom widths of each of the columns, as either
44328 * fractions of `lineLength` (e.g., `[0.5, 0.75, 1]`) or absolute number
44329 * of characters (e.g., `[30, 50, 65]`). Each number corresponds to
44330 * right-most (i.e., ending) position of a column.
44331 * @param printFn Custom print function. Can be used to replace the default
44332 * `console.log`. For example, you can use `x => {}` to mute the printed
44333 * messages in the console.
44334 *
44335 * @doc {heading: 'Models', subheading: 'Classes'}
44336 */
    summary(lineLength, positions, printFn = console.log) {
        // The underlying model must be assembled before its layers, shapes and
        // parameter counts can be summarized.
        if (!this.built) {
            this.build();
        }
        super.summary(lineLength, positions, printFn);
    }
44343 /**
44344 * Sets the weights of the model.
44345 *
44346 * @param weights Should be a list of Tensors with shapes and types matching
44347 * the output of `model.getWeights()`.
44348 */
    setWeights(weights) {
        // Lazily create the underlying LayersModel before delegating the
        // weight assignment to it.
        if (this.model == null) {
            this.build();
        }
        this.model.setWeights(weights);
    }
44355 /**
44356 * Returns the loss value & metrics values for the model in test mode.
44357 *
44358 * Loss and metrics are specified during `compile()`, which needs to happen
44359 * before calls to `evaluate()`.
44360 *
44361 * Computation is done in batches.
44362 *
44363 * ```js
44364 * const model = tf.sequential({
44365 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
44366 * });
44367 * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
44368 * const result = model.evaluate(tf.ones([8, 10]), tf.ones([8, 1]), {
44369 * batchSize: 4,
44370 * });
44371 * result.print();
44372 * ```
44373 *
44374 * @param x `tf.Tensor` of test data, or an `Array` of `tf.Tensor`s if the
44375 * model has multiple inputs.
44376 * @param y `tf.Tensor` of target data, or an `Array` of `tf.Tensor`s if the
44377 * model has multiple outputs.
44378 * @param args A `ModelEvaluateConfig`, containing optional fields.
44379 *
44380 * @return `Scalar` test loss (if the model has a single output and no
44381 * metrics) or `Array` of `Scalar`s (if the model has multiple outputs
44382 * and/or metrics). The attribute `model.metricsNames`
44383 * will give you the display labels for the scalar outputs.
44384 *
44385 * @doc {heading: 'Models', subheading: 'Classes'}
44386 */
44387 evaluate(x, y, args = {}) {
44388 if (!this.built) {
44389 throw new RuntimeError('The model needs to be compiled before being used.');
44390 }
44391 return this.model.evaluate(x, y, args);
44392 }
44393 // TODO(cais): Add code snippet below once real dataset objects are
44394 // available.
44395 /**
44396 * Evaluate model using a dataset object.
44397 *
44398 * Note: Unlike `evaluate()`, this method is asynchronous (`async`);
44399 *
44400 * @param dataset A dataset object. Its `iterator()` method is expected
44401 * to generate a dataset iterator object, the `next()` method of which
44402 * is expected to produce data batches for evaluation. The return value
44403 * of the `next()` call ought to contain a boolean `done` field and a
44404 * `value` field. The `value` field is expected to be an array of two
44405 * `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former
44406 * case is for models with exactly one input and one output (e.g.,
44407 * a sequential model). The latter case is for models with multiple
44408 * inputs and/or multiple outputs. Of the two items in the array, the
44409 * first is the input feature(s) and the second is the output target(s).
44410 * @param args A configuration object for the dataset-based evaluation.
44411 * @returns Loss and metric values as an Array of `Scalar` objects.
44412 *
44413 * @doc {heading: 'Models', subheading: 'Classes'}
44414 */
44415 async evaluateDataset(dataset, args) {
44416 if (!this.built) {
44417 throw new RuntimeError('The model needs to be compiled before being used.');
44418 }
44419 return this.model.evaluateDataset(dataset, args);
44420 }
44421 /**
44422 * Generates output predictions for the input samples.
44423 *
44424 * Computation is done in batches.
44425 *
44426 * Note: the "step" mode of predict() is currently not supported.
44427 * This is because the TensorFlow.js core backend is imperative only.
44428 *
44429 * ```js
44430 * const model = tf.sequential({
44431 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
44432 * });
44433 * model.predict(tf.ones([2, 10])).print();
44434 * ```
44435 *
44436 * @param x The input data, as a Tensor, or an `Array` of `tf.Tensor`s if
44437 * the model has multiple inputs.
44438 * @param args A `ModelPredictConfig` object containing optional fields.
44439 *
44440 * @return `tf.Tensor`(s) of predictions.
44441 *
44442 * @exception ValueError In case of mismatch between the provided input data
44443 * and the model's expectations, or in case a stateful model receives a
44444 * number of samples that is not a multiple of the batch size.
44445 *
44446 * @doc {heading: 'Models', subheading: 'Classes'}
44447 */
    predict(x, args = {}) {
        // Unlike fit()/evaluate(), prediction does not require compilation --
        // only that the underlying LayersModel has been assembled.
        if (this.model == null) {
            this.build();
        }
        return this.model.predict(x, args);
    }
44454 /**
44455 * Returns predictions for a single batch of samples.
44456 *
44457 * @param x: Input samples, as a Tensor, or list of Tensors (if the model
44458 * has multiple inputs).
44459 * @return Tensor(s) of predictions
44460 */
    predictOnBatch(x) {
        // Build the underlying model lazily; compilation is not required for
        // single-batch prediction.
        if (this.model == null) {
            this.build();
        }
        return this.model.predictOnBatch(x);
    }
44467 /**
44468 * See `LayersModel.compile`.
44469 *
44470 * @param args
44471 */
    compile(args) {
        // Ensure the underlying LayersModel exists, then delegate compilation
        // to it.
        this.build();
        this.model.compile(args);
        // Mirror the training-related attributes of the inner model onto this
        // Sequential instance so they are visible to callers.
        this.optimizer_ = this.model.optimizer;
        // tslint:disable-next-line:no-any
        this.isOptimizerOwned = this.model.isOptimizerOwned;
        this.loss = this.model.loss;
        this.metrics = this.model.metrics;
        // TODO(cais): Add this.lossWeights, this.sampleWeightMode,
        // this.weightedMetrics, this.targets.
        this.metricsTensors = this.model.metricsTensors;
        this.metricsNames = this.model.metricsNames;
        // TODO(cais): Add sampleWeights.
    }
44486 get optimizer() {
44487 return this.model == null ? undefined : this.model.optimizer;
44488 }
    set optimizer(optimizer) {
        // NOTE(review): unlike the stopTraining setter below, this does not
        // guard against `this.model` being null, so assigning an optimizer
        // before compile() throws a raw TypeError -- presumably callers are
        // expected to compile first; confirm before changing.
        this.model.optimizer = optimizer;
    }
44492 /**
44493 * Trains the model for a fixed number of epochs (iterations on a dataset).
44494 *
44495 * ```js
44496 * const model = tf.sequential({
44497 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
44498 * });
44499 * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
44500 * const history = await model.fit(tf.ones([8, 10]), tf.ones([8, 1]), {
44501 * batchSize: 4,
44502 * epochs: 3
44503 * });
44504 * console.log(history.history.loss[0]);
44505 * ```
44506 *
44507 * @param x `tf.Tensor` of training data, or an array of `tf.Tensor`s if the
44508 * model has multiple inputs. If all inputs in the model are named, you can
44509 * also pass a dictionary mapping input names to `tf.Tensor`s.
44510 * @param y `tf.Tensor` of target (label) data, or an array of `tf.Tensor`s if
44511 * the model has multiple outputs. If all outputs in the model are named, you
44512 * can also pass a dictionary mapping output names to `tf.Tensor`s.
44513 * @param args A `ModelFitConfig`, containing optional fields.
44514 *
44515 * @return A `History` instance. Its `history` attribute contains all
44516 * information collected during training.
44517 *
44518 * @exception ValueError In case of mismatch between the provided input data
44519 * and what the model expects.
44520 *
44521 * @doc {heading: 'Models', subheading: 'Classes'}
44522 */
44523 async fit(x, y, args = {}) {
44524 if (!this.built) {
44525 throw new RuntimeError('The model needs to be compiled before ' +
44526 'being used.');
44527 }
44528 return this.model.fit(x, y, args);
44529 }
44530 /**
44531 * Trains the model using a dataset object.
44532 *
44533 * ```js
44534 * const xArray = [
44535 * [1, 1, 1, 1, 1, 1, 1, 1, 1],
44536 * [1, 1, 1, 1, 1, 1, 1, 1, 1],
44537 * [1, 1, 1, 1, 1, 1, 1, 1, 1],
44538 * [1, 1, 1, 1, 1, 1, 1, 1, 1],
44539 * ];
44540 * const yArray = [1, 1, 1, 1];
44541 * // Create a dataset from the JavaScript array.
44542 * const xDataset = tf.data.array(xArray);
44543 * const yDataset = tf.data.array(yArray);
44544 * // Zip combines the `x` and `y` Datasets into a single Dataset, the
44545 * // iterator of which will return an object containing of two tensors,
44546 * // corresponding to `x` and `y`. The call to `batch(4)` will bundle
44547 * // four such samples into a single object, with the same keys now pointing
44548 * // to tensors that hold 4 examples, organized along the batch dimension.
44549 * // The call to `shuffle(4)` causes each iteration through the dataset to
44550 * // happen in a different order. The size of the shuffle window is 4.
44551 * const xyDataset = tf.data.zip({xs: xDataset, ys: yDataset})
44552 * .batch(4)
44553 * .shuffle(4);
44554 * const model = tf.sequential({
44555 * layers: [tf.layers.dense({units: 1, inputShape: [9]})]
44556 * });
44557 * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
44558 * const history = await model.fitDataset(xyDataset, {
44559 * epochs: 4,
44560 * callbacks: {onEpochEnd: (epoch, logs) => console.log(logs.loss)}
44561 * });
44562 * ```
44563 *
44564 * @param dataset A dataset object. Its `iterator()` method is expected to
44565 * generate a dataset iterator object, the `next()` method of which is
44566 * expected to produce data batches for evaluation. The return value of the
44567 * `next()` call ought to contain a boolean `done` field and a `value`
44568 * field.
44569 *
44570 * The `value` field is expected to be an object of with fields
44571 * `xs` and `ys`, which point to the feature tensor and the target tensor,
44572 * respectively. This case is for models with exactly one input and one
44573 * output (e.g., a sequential model). For example:
44574 * ```js
44575 * {value: {xs: xsTensor, ys: ysTensor}, done: false}
44576 * ```
44577 *
44578 * If the model has multiple inputs, the `xs` field of `value` should
44579 * be an object mapping input names to their respective feature tensors.
44580 * For example:
44581 * ```js
44582 * {
44583 * value: {
44584 * xs: {
44585 * input_1: xsTensor1,
44586 * input_2: xsTensor2
44587 * },
44588 * ys: ysTensor
44589 * },
44590 * done: false
44591 * }
44592 * ```
44593 * If the model has multiple outputs, the `ys` field of `value` should
44594 * be an object mapping output names to their respective target tensors.
44595 * For example:
44596 * ```js
44597 * {
44598 * value: {
44599 * xs: xsTensor,
44600 * ys: {
44601 * output_1: ysTensor1,
44602 * output_2: ysTensor2
44603 * },
44604 * },
44605 * done: false
44606 * }
44607 * ```
44608 * @param args A `ModelFitDatasetArgs`, containing optional fields.
44609 *
44610 * @return A `History` instance. Its `history` attribute contains all
44611 * information collected during training.
44612 *
44613 * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
44614 */
44615 async fitDataset(dataset, args) {
44616 if (!this.built) {
44617 throw new RuntimeError('The model needs to be compiled before ' +
44618 'being used.');
44619 }
44620 return this.model.fitDataset(dataset, args);
44621 }
44622 /**
44623 * Runs a single gradient update on a single batch of data.
44624 *
44625 * This method differs from `fit()` and `fitDataset()` in the following
44626 * regards:
44627 * - It operates on exactly one batch of data.
44628 * - It returns only the loss and metric values, instead of
44629 * returning the batch-by-batch loss and metric values.
44630 * - It doesn't support fine-grained options such as verbosity and
44631 * callbacks.
44632 *
44633 * @param x Input data. It could be one of the following:
44634 * - A `tf.Tensor`, or an Array of `tf.Tensor`s (in case the model has
44635 * multiple inputs).
44636 * - An Object mapping input names to corresponding `tf.Tensor` (if the
44637 * model has named inputs).
44638 * @param y Target data. It could be either a `tf.Tensor` or multiple
44639 * `tf.Tensor`s. It should be consistent with `x`.
44640 * @returns Training loss or losses (in case the model has
44641 * multiple outputs), along with metrics (if any), as numbers.
44642 *
44643 * @doc {heading: 'Models', subheading: 'Classes'}
44644 */
44645 async trainOnBatch(x, y) {
44646 return this.model.trainOnBatch(x, y);
44647 }
44648 /* See parent class for JsDoc */
44649 /** @nocollapse */
    static fromConfig(cls, config, customObjects = {}, fastWeightInit = false) {
        // `config` is either an Array of per-layer config dicts (each with
        // `className`/`config` fields) or an Object whose `layers` field holds
        // that Array plus extra model-level options.
        let configArray;
        let extraModelConfig = {};
        if (config instanceof Array) {
            // Reject legacy Keras serialization formats (missing className, or
            // the old 'Merge' layer).
            if (!(config[0].className != null) ||
                config[0]['className'] === 'Merge') {
                throw new ValueError('Legacy serialization format not supported yet.');
            }
            configArray = config;
        }
        else {
            assert(config['layers'] != null, () => `When the config data for a Sequential model is not an Array, ` +
                `it must be an Object that contains the 'layers' field.`);
            configArray = config['layers'];
            // Everything other than `layers` is forwarded to the constructor.
            delete config['layers'];
            extraModelConfig = config;
        }
        const model = new cls(extraModelConfig);
        if (!(model instanceof Sequential)) {
            throw new NotImplementedError(`Sequential.fromConfig called on non-Sequential input: ${model}`);
        }
        for (const conf of configArray) {
            // NOTE(review): this local shadows the `customObjects` parameter,
            // so caller-supplied custom objects are ignored during layer
            // deserialization -- confirm whether that is intentional.
            const customObjects = undefined;
            const layer = deserialize(conf, customObjects, fastWeightInit);
            if (fastWeightInit) {
                layer.setFastWeightInitDuringBuild(true);
            }
            model.add(layer);
        }
        return model;
    }
44681 /**
44682 * Setter used for force stopping of LayersModel.fit() (i.e., training).
44683 *
44684 * Example:
44685 *
44686 * ```js
44687 * const model = tf.sequential();
44688 * model.add(tf.layers.dense({units: 1, inputShape: [10]}));
44689 * model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
44690 * const xs = tf.ones([8, 10]);
44691 * const ys = tf.zeros([8, 1]);
44692 *
44693 * const history = await model.fit(xs, ys, {
44694 * epochs: 10,
44695 * callbacks: {
44696 * onEpochEnd: async (epoch, logs) => {
44697 * if (epoch === 2) {
44698 * model.stopTraining = true;
44699 * }
44700 * }
44701 * }
44702 * });
44703 *
44704 * // There should be only 3 values in the loss array, instead of 10 values,
44705 * // due to the stopping after 3 epochs.
44706 * console.log(history.history.loss);
44707 * ```
44708 */
44709 set stopTraining(stop) {
44710 // TODO(cais): When refactoring to remove the composition pattern happens,
44711 // remove this method overriding.
44712 if (this.model == null) {
44713 throw new ValueError('Cannot set the stopTraining property of a sequential model before ' +
44714 'it is compiled.');
44715 }
44716 this.model.stopTraining = stop;
44717 }
44718 get stopTraining() {
44719 if (this.model == null) {
44720 throw new ValueError('Cannot get the stopTraining property of a sequential model before ' +
44721 'it is compiled.');
44722 }
44723 return this.model.stopTraining;
44724 }
44725 // TODO(cais): Override get trainableWeights() here
44726 // tslint:disable-next-line:no-any
44727 getConfig() {
44728 // NOTE(cais): We override the return type of getConfig() to `any` here,
44729 // because the `Sequential` class is a special case among `Container`
44730 // subtypes in that its getConfig() method returns an Array (not a
44731 // dict).
44732 const layers = [];
44733 for (const layer of this.layers) {
44734 const dict = {};
44735 dict['className'] = layer.getClassName();
44736 dict['config'] = layer.getConfig();
44737 layers.push(dict);
44738 }
44739 return { name: this.name, layers };
44740 }
44741 }
    /** @nocollapse */
    Sequential.className = 'Sequential';
    // Register so that serialized configs with className 'Sequential' can be
    // deserialized back into this class.
    registerClass(Sequential);
44745
44746 /**
44747 * @license
44748 * Copyright 2018 Google LLC
44749 *
44750 * Use of this source code is governed by an MIT-style
44751 * license that can be found in the LICENSE file or at
44752 * https://opensource.org/licenses/MIT.
44753 * =============================================================================
44754 */
44755 // TODO(cais): Add doc string to all the public static functions in this
44756 * class; include executable JavaScript code snippets where applicable
44757 // (b/74074458).
44758 // LayersModel and related factory methods.
44759 /**
44760 * A model is a data structure that consists of `Layers` and defines inputs
44761 * and outputs.
44762 *
44763 * The key difference between `tf.model` and `tf.sequential` is that
44764 * `tf.model` is more generic, supporting an arbitrary graph (without
44765 * cycles) of layers. `tf.sequential` is less generic and supports only a linear
44766 * stack of layers.
44767 *
44768 * When creating a `tf.LayersModel`, specify its input(s) and output(s). Layers
44769 * are used to wire input(s) to output(s).
44770 *
44771 * For example, the following code snippet defines a model consisting of
44772 * two `dense` layers, with 10 and 4 units, respectively.
44773 *
44774 * ```js
44775 * // Define input, which has a size of 5 (not including batch dimension).
44776 * const input = tf.input({shape: [5]});
44777 *
44778 * // First dense layer uses relu activation.
44779 * const denseLayer1 = tf.layers.dense({units: 10, activation: 'relu'});
44780 * // Second dense layer uses softmax activation.
44781 * const denseLayer2 = tf.layers.dense({units: 4, activation: 'softmax'});
44782 *
44783 * // Obtain the output symbolic tensor by applying the layers on the input.
44784 * const output = denseLayer2.apply(denseLayer1.apply(input));
44785 *
44786 * // Create the model based on the inputs.
44787 * const model = tf.model({inputs: input, outputs: output});
44788 *
44789 * // The model can be used for training, evaluation and prediction.
44790 * // For example, the following line runs prediction with the model on
44791 * // some fake data.
44792 * model.predict(tf.ones([2, 5])).print();
44793 * ```
44794 * See also:
44795 * `tf.sequential`, `tf.loadLayersModel`.
44796 *
44797 * @doc {heading: 'Models', subheading: 'Creation'}
44798 */
44799 function model(args) {
44800 return new LayersModel(args);
44801 }
44802 /**
44803 * Creates a `tf.Sequential` model. A sequential model is any model where the
44804 * outputs of one layer are the inputs to the next layer, i.e. the model
44805 * topology is a simple 'stack' of layers, with no branching or skipping.
44806 *
44807 * This means that the first layer passed to a `tf.Sequential` model should have
44808 * a defined input shape. What that means is that it should have received an
44809 * `inputShape` or `batchInputShape` argument, or for some type of layers
44810 * (recurrent, Dense...) an `inputDim` argument.
44811 *
44812 * The key difference between `tf.model` and `tf.sequential` is that
44813 * `tf.sequential` is less generic, supporting only a linear stack of layers.
44814 * `tf.model` is more generic and supports an arbitrary graph (without
44815 * cycles) of layers.
44816 *
44817 * Examples:
44818 *
44819 * ```js
44820 * const model = tf.sequential();
44821 *
44822 * // First layer must have an input shape defined.
44823 * model.add(tf.layers.dense({units: 32, inputShape: [50]}));
44824 * // Afterwards, TF.js does automatic shape inference.
44825 * model.add(tf.layers.dense({units: 4}));
44826 *
44827 * // Inspect the inferred shape of the model's output, which equals
44828 * // `[null, 4]`. The 1st dimension is the undetermined batch dimension; the
44829 * // 2nd is the output size of the model's last layer.
44830 * console.log(JSON.stringify(model.outputs[0].shape));
44831 * ```
44832 *
44833 * It is also possible to specify a batch size (with potentially undetermined
44834 * batch dimension, denoted by "null") for the first layer using the
44835 * `batchInputShape` key. The following example is equivalent to the above:
44836 *
44837 * ```js
44838 * const model = tf.sequential();
44839 *
44840 * // First layer must have a defined input shape
44841 * model.add(tf.layers.dense({units: 32, batchInputShape: [null, 50]}));
44842 * // Afterwards, TF.js does automatic shape inference.
44843 * model.add(tf.layers.dense({units: 4}));
44844 *
44845 * // Inspect the inferred shape of the model's output.
44846 * console.log(JSON.stringify(model.outputs[0].shape));
44847 * ```
44848 *
44849 * You can also use an `Array` of already-constructed `Layer`s to create
44850 * a `tf.Sequential` model:
44851 *
44852 * ```js
44853 * const model = tf.sequential({
44854 * layers: [tf.layers.dense({units: 32, inputShape: [50]}),
44855 * tf.layers.dense({units: 4})]
44856 * });
44857 * console.log(JSON.stringify(model.outputs[0].shape));
44858 * ```
44859 *
44860 * @doc {heading: 'Models', subheading: 'Creation'}
44861 */
44862 function sequential(config) {
44863 return new Sequential(config);
44864 }
44865 /**
44866 * Load a model composed of Layer objects, including its topology and optionally
44867 * weights. See the Tutorial named "How to import a Keras Model" for usage
44868 * examples.
44869 *
44870 * This method is applicable to:
44871 *
44872 * 1. Models created with the `tf.layers.*`, `tf.sequential`, and
44873 * `tf.model` APIs of TensorFlow.js and later saved with the
44874 * `tf.LayersModel.save` method.
44875 * 2. Models converted from Keras or TensorFlow tf.keras using the
44876 * [tensorflowjs_converter](https://github.com/tensorflow/tfjs/tree/master/tfjs-converter).
44877 *
44878 * This mode is *not* applicable to TensorFlow `SavedModel`s or their converted
44879 * forms. For those models, use `tf.loadGraphModel`.
44880 *
44881 * Example 1. Load a model from an HTTP server.
44882 *
44883 * ```js
44884 * const model = await tf.loadLayersModel(
44885 * 'https://storage.googleapis.com/tfjs-models/tfjs/iris_v1/model.json');
44886 * model.summary();
44887 * ```
44888 *
44889 * Example 2: Save `model`'s topology and weights to browser [local
44890 * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
44891 * then load it back.
44892 *
44893 * ```js
44894 * const model = tf.sequential(
44895 * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
44896 * console.log('Prediction from original model:');
44897 * model.predict(tf.ones([1, 3])).print();
44898 *
44899 * const saveResults = await model.save('localstorage://my-model-1');
44900 *
44901 * const loadedModel = await tf.loadLayersModel('localstorage://my-model-1');
44902 * console.log('Prediction from loaded model:');
44903 * loadedModel.predict(tf.ones([1, 3])).print();
44904 * ```
44905 *
44906 * Example 3. Saving `model`'s topology and weights to browser
44907 * [IndexedDB](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API);
44908 * then load it back.
44909 *
44910 * ```js
44911 * const model = tf.sequential(
44912 * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
44913 * console.log('Prediction from original model:');
44914 * model.predict(tf.ones([1, 3])).print();
44915 *
44916 * const saveResults = await model.save('indexeddb://my-model-1');
44917 *
44918 * const loadedModel = await tf.loadLayersModel('indexeddb://my-model-1');
44919 * console.log('Prediction from loaded model:');
44920 * loadedModel.predict(tf.ones([1, 3])).print();
44921 * ```
44922 *
44923 * Example 4. Load a model from user-selected files from HTML
44924 * [file input
44925 * elements](https://developer.mozilla.org/en-US/docs/Web/HTML/Element/input/file).
44926 *
44927 * ```js
44928 * // Note: this code snippet will not work without the HTML elements in the
44929 * // page
44930 * const jsonUpload = document.getElementById('json-upload');
44931 * const weightsUpload = document.getElementById('weights-upload');
44932 *
44933 * const model = await tf.loadLayersModel(
44934 * tf.io.browserFiles([jsonUpload.files[0], weightsUpload.files[0]]));
44935 * ```
44936 *
44937 * @param pathOrIOHandler Can be either of the two formats
44938 * 1. A string path to the `ModelAndWeightsConfig` JSON describing
44939 * the model in the canonical TensorFlow.js format. For file://
44940 * (tfjs-node-only), http:// and https:// schemas, the path can be
44941 * either absolute or relative.
44942 * 2. An `tf.io.IOHandler` object that loads model artifacts with its `load`
44943 * method.
44944 * @param options Optional configuration arguments for the model loading,
44945 * including:
44946 * - `strict`: Require that the provided weights exactly match those required
44947 * by the layers. Default true. Passing false means that both extra
44948 * weights and missing weights will be silently ignored.
44949 * - `onProgress`: A function of the signature `(fraction: number) => void`,
44950 * that can be used as the progress callback for the model loading.
44951 * @returns A `Promise` of `tf.LayersModel`, with the topology and weights
44952 * loaded.
44953 *
44954 * @doc {heading: 'Models', subheading: 'Loading'}
44955 */
44956 function loadLayersModel(pathOrIOHandler, options) {
44957 if (options == null) {
44958 options = {};
44959 }
44960 return loadLayersModelInternal(pathOrIOHandler, options);
44961 }
44962 /**
44963 * Used to instantiate an input to a model as a `tf.SymbolicTensor`.
44964 *
44965 * Users should call the `input` factory function for
44966 * consistency with other generator functions.
44967 *
44968 * Example:
44969 *
44970 * ```js
44971 * // Defines a simple logistic regression model with 32 dimensional input
44972 * // and 3 dimensional output.
44973 * const x = tf.input({shape: [32]});
44974 * const y = tf.layers.dense({units: 3, activation: 'softmax'}).apply(x);
44975 * const model = tf.model({inputs: x, outputs: y});
44976 * model.predict(tf.ones([2, 32])).print();
44977 * ```
44978 *
44979 * Note: `input` is only necessary when using `model`. When using
44980 * `sequential`, specify `inputShape` for the first layer or use `inputLayer`
44981 * as the first layer.
44982 *
44983 * @doc {heading: 'Models', subheading: 'Inputs'}
44984 */
44985 function input(config) {
44986 return Input(config);
44987 }
    /**
     * Registers a callback constructor at the given verbosity level with the
     * global `CallbackConstructorRegistry`.
     */
    function registerCallbackConstructor(verbosityLevel, callbackConstructor) {
        CallbackConstructorRegistry.registerCallbackConstructor(verbosityLevel, callbackConstructor);
    }
44991
44992 /**
44993 * @license
44994 * Copyright 2018 Google LLC
44995 *
44996 * Use of this source code is governed by an MIT-style
44997 * license that can be found in the LICENSE file or at
44998 * https://opensource.org/licenses/MIT.
44999 * =============================================================================
45000 */
45001 /**
45002 * Base class for Activations.
45003 *
45004 * Special note: due to cross-language compatibility reasons, the
45005 * static readonly className field in this family of classes must be set to
45006 * the initialLowerCamelCase name of the activation.
45007 */
    class Activation extends Serializable {
        // Activations are stateless, so there is no configuration to
        // serialize; subclasses inherit this empty config.
        getConfig() {
            return {};
        }
    }
45013 /**
45014 * Exponential linear unit (ELU).
45015 * Reference: https://arxiv.org/abs/1511.07289
45016 */
45017 class Elu$1 extends Activation {
45018 /**
45019 * Calculate the activation function.
45020 *
45021 * @param x: Input.
45022 * @param alpha: Scaling factor the negative section.
45023 * @return Output of the ELU activation.
45024 */
45025 apply(x, alpha = 1) {
45026 return elu$1(x, alpha);
45027 }
45028 }
45029 /** @nocollapse */
45030 Elu$1.className = 'elu';
45031 registerClass(Elu$1);
45032 /**
45033 * Scaled Exponential Linear Unit. (Klambauer et al., 2017).
45034 * Reference: Self-Normalizing Neural Networks, https://arxiv.org/abs/1706.02515
45035 * Notes:
45036 * - To be used together with the initialization "lecunNormal".
45037 * - To be used together with the dropout variant "AlphaDropout".
45038 */
45039 class Selu$1 extends Activation {
45040 apply(x) {
45041 return selu(x);
45042 }
45043 }
45044 /** @nocollapse */
45045 Selu$1.className = 'selu';
45046 registerClass(Selu$1);
45047 /**
45048 * Rectified linear unit
45049 */
45050 class Relu$1 extends Activation {
45051 apply(x) {
45052 return relu(x);
45053 }
45054 }
45055 /** @nocollapse */
45056 Relu$1.className = 'relu';
45057 registerClass(Relu$1);
45058 /**
45059 * Rectified linear unit activation maxing out at 6.0.
45060 */
45061 class Relu6$1 extends Activation {
45062 apply(x) {
45063 return tidy(() => minimum(6.0, relu(x)));
45064 }
45065 }
45066 /** @nocollapse */
45067 Relu6$1.className = 'relu6';
45068 registerClass(Relu6$1);
    /** Linear activation (no-op): returns the input unchanged. */
    class Linear extends Activation {
        // Identity function; the tensor is passed through as-is.
        apply(x) {
            return x;
        }
    }
    /** @nocollapse */
    Linear.className = 'linear';
    registerClass(Linear);
45078 /**
45079 * Sigmoid activation function.
45080 */
45081 class Sigmoid$1 extends Activation {
45082 apply(x) {
45083 return sigmoid(x);
45084 }
45085 }
45086 /** @nocollapse */
45087 Sigmoid$1.className = 'sigmoid';
45088 registerClass(Sigmoid$1);
45089 /**
45090 * Segment-wise linear approximation of sigmoid.
45091 */
45092 class HardSigmoid extends Activation {
45093 apply(x) {
45094 return hardSigmoid(x);
45095 }
45096 }
45097 /** @nocollapse */
45098 HardSigmoid.className = 'hardSigmoid';
45099 registerClass(HardSigmoid);
45100 /**
45101 * Softplus activation function.
45102 */
45103 class Softplus$1 extends Activation {
45104 apply(x) {
45105 return softplus(x);
45106 }
45107 }
45108 /** @nocollapse */
45109 Softplus$1.className = 'softplus';
45110 registerClass(Softplus$1);
45111 /**
45112 * Softsign activation function.
45113 */
45114 class Softsign extends Activation {
45115 apply(x) {
45116 return softsign(x);
45117 }
45118 }
45119 /** @nocollapse */
45120 Softsign.className = 'softsign';
45121 registerClass(Softsign);
45122 /**
45123 * Hyperbolic tangent function.
45124 */
45125 class Tanh$1 extends Activation {
45126 apply(x) {
45127 return tanh$1(x);
45128 }
45129 }
45130 /** @nocollapse */
45131 Tanh$1.className = 'tanh';
45132 registerClass(Tanh$1);
45133 /**
45134 * Softmax activation function
45135 */
45136 class Softmax$1 extends Activation {
45137 /**
45138 * Calculate the activation function.
45139 *
45140 * @param x Tensor.
45141 * @param axis Integer, axis along which the softmax normalization is applied.
45142 * Invalid if < 2, as softmax across 1 (the batch dimension) is assumed to be
45143 * an error.
45144 *
45145 * @returns a Tensor of the same shape as x
45146 *
45147 * @throws ValueError: In case `dim(x) < 2`.
45148 */
45149 apply(x, axis = (-1)) {
45150 return softmax(x, axis);
45151 }
45152 }
45153 /** @nocollapse */
45154 Softmax$1.className = 'softmax';
45155 registerClass(Softmax$1);
45156 /**
45157 * Log softmax activation function
45158 */
45159 class LogSoftmax$1 extends Activation {
45160 /**
45161 * Calculate the activation function of log softmax:
45162 * log( exp(x_i) / sum(exp(x)) )
45163 *
45164 * @param x Tensor.
45165 * @param axis Integer, axis along which the softmax normalization is applied.
45166 * Invalid if < 2, as softmax across 1 (the batch dimension) is assumed to be
45167 * an error.
45168 *
45169 * @returns a Tensor of the same shape as x
45170 *
45171 * @throws ValueError: In case `dim(x) < 2`.
45172 */
45173 apply(x, axis = (-1)) {
45174 return logSoftmax(x, axis);
45175 }
45176 }
45177 /** @nocollapse */
45178 LogSoftmax$1.className = 'logSoftmax';
45179 registerClass(LogSoftmax$1);
/**
 * Swish activation function: x * sigmoid(alpha * x).
 */
class Swish extends Activation {
    /**
     * Calculate the activation function.
     *
     * @param x Tensor.
     * @param alpha Scaling factor applied inside the sigmoid. Defaults to 1.
     * @returns a Tensor of the same shape as `x`.
     */
    apply(x, alpha = 1) {
        return tidy(() => {
            const gate = sigmoid(mul(x, alpha));
            return mul(gate, x);
        });
    }
}
/** @nocollapse */
Swish.className = 'swish';
registerClass(Swish);
/**
 * Mish activation function: x * tanh(softplus(x)).
 */
class Mish extends Activation {
    /**
     * Calculate the activation function.
     *
     * @param x Tensor.
     * @returns a Tensor of the same shape as `x`.
     */
    apply(x) {
        return tidy(() => {
            const inner = tanh$1(softplus(x));
            return mul(x, inner);
        });
    }
}
/** @nocollapse */
Mish.className = 'mish';
registerClass(Mish);
/**
 * Serializes an activation instance into its registered class name.
 *
 * @param activation The `Activation` instance to serialize.
 * @returns The activation's class name string (e.g. 'relu', 'softmax').
 */
function serializeActivation(activation) {
    const className = activation.getClassName();
    return className;
}
/**
 * Deserializes an activation from a serialized config object.
 *
 * @param config Serialized config carrying the activation's `className`.
 * @param customObjects Optional map of user-registered classes consulted
 *   during deserialization, in addition to the global serialization map.
 * @returns The reconstructed `Activation` instance.
 */
function deserializeActivation(config, customObjects = {}) {
    return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'activation');
}
/**
 * Resolves an activation identifier into an `Activation` instance.
 *
 * Accepts: null/undefined (yields the identity 'linear' activation), a string
 * class name, an existing `Activation` instance (returned as-is), or a
 * serialized config object.
 *
 * @param identifier The activation identifier to resolve.
 * @returns An `Activation` instance.
 */
function getActivation(identifier) {
    if (identifier == null) {
        // No activation requested: fall back to the identity function.
        return deserializeActivation({ className: 'linear', config: {} });
    }
    if (typeof identifier === 'string') {
        return deserializeActivation({ className: identifier, config: {} });
    }
    if (identifier instanceof Activation) {
        return identifier;
    }
    // Assume a serialized config object.
    return deserializeActivation(identifier);
}
45241
45242 /**
45243 * @license
45244 * Copyright 2018 Google LLC
45245 *
45246 * Use of this source code is governed by an MIT-style
45247 * license that can be found in the LICENSE file or at
45248 * https://opensource.org/licenses/MIT.
45249 * =============================================================================
45250 */
/**
 * Validates that a regularizer-constructor argument is either absent or a
 * plain object.
 *
 * @param args The constructor argument to check.
 * @throws Error if `args` is non-null and not of type 'object'.
 */
function assertObjectArgs(args) {
    const isInvalid = args != null && typeof args !== 'object';
    if (isInvalid) {
        throw new Error(`Argument to L1L2 regularizer's constructor is expected to be an ` +
            `object, but received: ${args}`);
    }
}
/**
 * Regularizer base class.
 *
 * Concrete subclasses (e.g. `L1L2` below) implement `apply(x)`, mapping a
 * weight tensor to a scalar regularization penalty.
 */
class Regularizer extends Serializable {
}
/**
 * Regularizer combining L1 (absolute-value) and L2 (squared) weight penalties:
 * penalty = l1 * sum(|x|) + l2 * sum(x^2).
 */
class L1L2 extends Regularizer {
    constructor(args) {
        super();
        assertObjectArgs(args);
        // Both rates default to 0.01 when omitted.
        this.l1 = args == null || args.l1 == null ? 0.01 : args.l1;
        this.l2 = args == null || args.l2 == null ? 0.01 : args.l2;
        // Cache which terms are active so apply() can skip zero-rate terms.
        this.hasL1 = this.l1 !== 0;
        this.hasL2 = this.l2 !== 0;
    }
    /**
     * Porting note: Renamed from __call__.
     * @param x Variable of which to calculate the regularization score.
     * @returns A scalar tensor holding the combined penalty.
     */
    apply(x) {
        return tidy(() => {
            let regularization = zeros([1]);
            if (this.hasL1) {
                regularization = add$1(regularization, sum$1(mul(this.l1, abs(x))));
            }
            if (this.hasL2) {
                regularization =
                    add$1(regularization, sum$1(mul(this.l2, square$1(x))));
            }
            // Collapse the shape-[1] accumulator to a scalar (shape []).
            return reshape(regularization, []);
        });
    }
    getConfig() {
        return { 'l1': this.l1, 'l2': this.l2 };
    }
    /** @nocollapse */
    static fromConfig(cls, config) {
        return new cls({ l1: config['l1'], l2: config['l2'] });
    }
}
/** @nocollapse */
L1L2.className = 'L1L2';
registerClass(L1L2);
/**
 * Factory for an L1-only regularizer (the L2 rate is fixed at 0).
 *
 * @param args Optional config object with an `l1` rate; when absent, the
 *   `L1L2` constructor's default l1 rate applies.
 * @returns An `L1L2` instance with l2 disabled.
 */
function l1(args) {
    assertObjectArgs(args);
    const l1Rate = args != null ? args.l1 : null;
    return new L1L2({ l1: l1Rate, l2: 0 });
}
/**
 * Factory for an L2-only regularizer (the L1 rate is fixed at 0).
 *
 * @param args Optional config object with an `l2` rate; when absent, the
 *   `L1L2` constructor's default l2 rate applies.
 * @returns An `L1L2` instance with l1 disabled.
 */
function l2(args) {
    assertObjectArgs(args);
    const l2Rate = args != null ? args.l2 : null;
    return new L1L2({ l2: l2Rate, l1: 0 });
}
// Maps the JavaScript-like identifier keys to the corresponding keras symbols.
const REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
    'l1l2': 'L1L2'
};
/**
 * Serializes a regularizer instance via the generic Keras-object serializer.
 *
 * @param constraint The regularizer to serialize. (NOTE(review): the parameter
 *   is named `constraint` but receives a regularizer — naming carried over
 *   from a copy of the constraint module.)
 * @returns The serialized representation.
 */
function serializeRegularizer(constraint) {
    return serializeKerasObject(constraint);
}
/**
 * Deserializes a regularizer from a serialized config object.
 *
 * @param config Serialized config carrying the regularizer's `className`.
 * @param customObjects Optional map of user-registered classes consulted
 *   during deserialization.
 * @returns The reconstructed `Regularizer` instance.
 */
function deserializeRegularizer(config, customObjects = {}) {
    return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'regularizer');
}
/**
 * Resolves a regularizer identifier into a `Regularizer` instance.
 *
 * Accepts: null/undefined (yields null — no regularization), a string name
 * (JS-style identifiers like 'l1l2' are mapped to registered class names),
 * an existing `Regularizer` instance (returned as-is), or a serialized
 * config object.
 *
 * @param identifier The regularizer identifier to resolve.
 * @returns A `Regularizer` instance, or null.
 */
function getRegularizer(identifier) {
    if (identifier == null) {
        return null;
    }
    if (typeof identifier === 'string') {
        // Translate JS-style identifiers (e.g. 'l1l2') to class names.
        const className = identifier in REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP ?
            REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] :
            identifier;
        return deserializeRegularizer({ className, config: {} });
    }
    if (identifier instanceof Regularizer) {
        return identifier;
    }
    // Assume a serialized config object.
    return deserializeRegularizer(identifier);
}
45335
45336 /**
45337 * @license
45338 * Copyright 2018 Google LLC
45339 *
45340 * Use of this source code is governed by an MIT-style
45341 * license that can be found in the LICENSE file or at
45342 * https://opensource.org/licenses/MIT.
45343 * =============================================================================
45344 */
/**
 * Rectified Linear Unit layer: max(x, 0), optionally clipped at `maxValue`.
 */
class ReLU extends Layer {
    constructor(args) {
        super(args == null ? {} : args);
        this.supportsMasking = true;
        if (args != null) {
            this.maxValue = args.maxValue;
        }
    }
    call(inputs, kwargs) {
        const x = getExactlyOneTensor(inputs);
        const activated = relu(x);
        // Apply the optional upper bound.
        return this.maxValue == null ? activated :
            clipByValue(activated, 0, this.maxValue);
    }
    computeOutputShape(inputShape) {
        // Elementwise op: shape is unchanged.
        return inputShape;
    }
    getConfig() {
        return Object.assign({ maxValue: this.maxValue }, super.getConfig());
    }
}
/** @nocollapse */
ReLU.className = 'ReLU';
registerClass(ReLU);
/**
 * Leaky ReLU layer: x for x > 0, alpha * x otherwise.
 */
class LeakyReLU extends Layer {
    constructor(args) {
        super(args == null ? {} : args);
        this.DEFAULT_ALPHA = 0.3;
        const config = args == null ? {} : args;
        this.alpha = config.alpha == null ? this.DEFAULT_ALPHA : config.alpha;
    }
    call(inputs, kwargs) {
        const x = getExactlyOneTensor(inputs);
        return leakyRelu(x, this.alpha);
    }
    computeOutputShape(inputShape) {
        // Elementwise op: shape is unchanged.
        return inputShape;
    }
    getConfig() {
        return Object.assign({ alpha: this.alpha }, super.getConfig());
    }
}
/** @nocollapse */
LeakyReLU.className = 'LeakyReLU';
registerClass(LeakyReLU);
/**
 * Parametric ReLU layer: like leaky ReLU, but the per-element `alpha` slopes
 * are learned weights (optionally shared across specified axes).
 */
class PReLU extends Layer {
    constructor(args) {
        super(args == null ? {} : args);
        this.DEFAULT_ALPHA_INITIALIZER = 'zeros';
        if (args == null) {
            args = {};
        }
        this.supportsMasking = true;
        this.alphaInitializer =
            getInitializer(args.alphaInitializer || this.DEFAULT_ALPHA_INITIALIZER);
        this.alphaRegularizer = getRegularizer(args.alphaRegularizer);
        this.alphaConstraint = getConstraint(args.alphaConstraint);
        // Normalize sharedAxes to null or an array of numbers.
        if (args.sharedAxes == null) {
            this.sharedAxes = null;
        }
        else if (Array.isArray(args.sharedAxes)) {
            this.sharedAxes = args.sharedAxes;
        }
        else if (typeof args.sharedAxes === 'number') {
            this.sharedAxes = [args.sharedAxes];
        }
        else {
            throw new ValueError(`Expected sharedAxes to be a number or an array of numbers, ` +
                `but got ${args.sharedAxes}`);
        }
    }
    build(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        // Alpha has the input's shape minus the batch dimension...
        const paramShape = inputShape.slice(1);
        if (this.sharedAxes != null) {
            // ...with shared axes collapsed to size 1 (axes are 1-based
            // relative to the full input shape, hence the i - 1).
            for (const i of this.sharedAxes) {
                paramShape[i - 1] = 1;
            }
        }
        this.alpha = this.addWeight('alpha', paramShape, 'float32', this.alphaInitializer, this.alphaRegularizer, true, this.alphaConstraint);
        // Set input spec.
        const axes = {};
        if (this.sharedAxes != null) {
            for (let i = 1; i < inputShape.length; ++i) {
                axes[i] = inputShape[i];
            }
        }
        this.inputSpec = [new InputSpec({
                ndim: inputShape.length,
                axes,
            })];
        this.built = true;
    }
    call(inputs, kwargs) {
        inputs = getExactlyOneTensor(inputs);
        return prelu(inputs, this.alpha.read());
    }
    getConfig() {
        const config = {
            alphaInitializer: serializeInitializer(this.alphaInitializer),
            alphaRegularizer: serializeRegularizer(this.alphaRegularizer),
            alphaConstraint: serializeConstraint(this.alphaConstraint),
            sharedAxes: this.sharedAxes
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}
/** @nocollapse */
PReLU.className = 'PReLU';
registerClass(PReLU);
/**
 * Exponential Linear Unit layer. Only the default alpha of 1.0 is supported;
 * any other value throws NotImplementedError.
 */
class ELU extends Layer {
    constructor(args) {
        super(args == null ? {} : args);
        this.DEFAULT_ALPHA = 1.0;
        const config = args == null ? {} : args;
        // The underlying elu kernel has no alpha parameter, so only the
        // default value is accepted.
        if (config.alpha != null && config.alpha !== this.DEFAULT_ALPHA) {
            throw new NotImplementedError(`Non-default alpha value (${config.alpha}) is not supported by the ` +
                `ELU layer yet.`);
        }
        this.alpha = config.alpha == null ? this.DEFAULT_ALPHA : config.alpha;
    }
    call(inputs, kwargs) {
        const x = getExactlyOneTensor(inputs);
        return elu(x);
    }
    computeOutputShape(inputShape) {
        // Elementwise op: shape is unchanged.
        return inputShape;
    }
    getConfig() {
        return Object.assign({ alpha: this.alpha }, super.getConfig());
    }
}
/** @nocollapse */
ELU.className = 'ELU';
registerClass(ELU);
/**
 * Thresholded ReLU layer: x for x > theta, 0 otherwise.
 */
class ThresholdedReLU extends Layer {
    constructor(args) {
        super(args == null ? {} : args);
        this.DEFAULT_THETA = 1.0;
        const config = args == null ? {} : args;
        this.theta = config.theta == null ? this.DEFAULT_THETA : config.theta;
    }
    call(inputs, kwargs) {
        const x = getExactlyOneTensor(inputs);
        // Multiply by a 0/1 mask of elements strictly above theta.
        const mask = cast(greater(x, this.theta), 'float32');
        return mul(x, mask);
    }
    computeOutputShape(inputShape) {
        // Elementwise op: shape is unchanged.
        return inputShape;
    }
    getConfig() {
        return Object.assign({ theta: this.theta }, super.getConfig());
    }
}
/** @nocollapse */
ThresholdedReLU.className = 'ThresholdedReLU';
registerClass(ThresholdedReLU);
/**
 * Softmax layer: wraps the softmax activation as a standalone layer.
 */
class Softmax$2 extends Layer {
    constructor(args) {
        super(args == null ? {} : args);
        // NOTE(review): Python Keras defaults Softmax's axis to -1 (last
        // axis); the 1.0 here differs — confirm whether this is intentional
        // before relying on the default.
        this.DEFAULT_AXIS = 1.0;
        if (args == null) {
            args = {};
        }
        // Borrow the activation's implementation. Taking the unbound method
        // is safe because Softmax$1.apply does not reference `this`.
        this.softmax = new Softmax$1().apply;
        this.axis = args.axis == null ? this.DEFAULT_AXIS : args.axis;
    }
    call(inputs, kwargs) {
        const x = getExactlyOneTensor(inputs);
        return this.softmax(x, this.axis);
    }
    computeOutputShape(inputShape) {
        // Softmax normalizes along an axis; shape is unchanged.
        return inputShape;
    }
    getConfig() {
        const config = { axis: this.axis };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}
/** @nocollapse */
Softmax$2.className = 'Softmax';
registerClass(Softmax$2);
45550
45551 /**
45552 * @license
45553 * Copyright 2018 Google LLC
45554 *
45555 * Use of this source code is governed by an MIT-style
45556 * license that can be found in the LICENSE file or at
45557 * https://opensource.org/licenses/MIT.
45558 * =============================================================================
45559 */
/**
 * Transforms a single number or an array of numbers into an array of numbers.
 * @param value A number (repeated `n` times) or an array of `n` integers.
 * @param n The size of the tuple to be returned.
 * @param name Name of the parameter, used for generating error messages.
 * @returns An array of `n` integers.
 * @throws ValueError if `value` is an array of the wrong length or contains
 *   a non-integer element.
 */
function normalizeArray(value, n, name) {
    if (typeof value === 'number') {
        // A scalar is broadcast into an n-tuple.
        return pyListRepeat(value, n);
    }
    if (value.length !== n) {
        throw new ValueError(`The ${name} argument must be an integer or tuple of ${n} integers.` +
            ` Received: ${value.length} elements.`);
    }
    // Every element must be an integer.
    for (let i = 0; i < n; ++i) {
        const singleValue = value[i];
        if (!isInteger(singleValue)) {
            throw new ValueError(`The ${name} argument must be an integer or tuple of ${n}` +
                ` integers. Received: ${JSON.stringify(value)} including a` +
                ` non-integer number ${singleValue}`);
        }
    }
    return value;
}
/**
 * Determines output length of a convolution given input length.
 * @param inputLength Input dimension size; passed through if null/undefined.
 * @param filterSize Filter size along this dimension.
 * @param padding 'same' (output spans the input) or any other value,
 *   treated as 'valid'.
 * @param stride Stride along this dimension.
 * @param dilation Dilation rate; defaults to 1 (no dilation).
 * @returns The output dimension length.
 */
function convOutputLength(inputLength, filterSize, padding, stride, dilation = 1) {
    if (inputLength == null) {
        return inputLength;
    }
    // Effective filter size once dilation gaps are inserted.
    const dilatedFilterSize = filterSize + (filterSize - 1) * (dilation - 1);
    const beforeStride = padding === 'same' ?
        inputLength :
        inputLength - dilatedFilterSize + 1; // 'valid'
    // Ceiling division by the stride.
    return Math.floor((beforeStride + stride - 1) / stride);
}
/**
 * Determines the output length of a transposed convolution along one
 * dimension.
 * @param dimSize Input dimension size; null yields null (unknown dim).
 * @param strideSize Stride along this dimension.
 * @param kernelSize Kernel size along this dimension.
 * @param padding 'valid' or 'same'; any other value throws ValueError.
 * @returns The output dimension size, or null if `dimSize` is null.
 */
function deconvLength(dimSize, strideSize, kernelSize, padding) {
    if (dimSize == null) {
        return null;
    }
    if (padding === 'valid') {
        // Stretch by the stride, plus any kernel overhang beyond one stride.
        dimSize = dimSize * strideSize + max$1([kernelSize - strideSize, 0]);
    }
    else if (padding === 'same') {
        dimSize = dimSize * strideSize;
    }
    else {
        throw new ValueError(`Unsupport padding mode: ${padding}.`);
    }
    return dimSize;
}
45624
45625 /**
45626 * @license
45627 * Copyright 2018 Google LLC
45628 *
45629 * Use of this source code is governed by an MIT-style
45630 * license that can be found in the LICENSE file or at
45631 * https://opensource.org/licenses/MIT.
45632 * =============================================================================
45633 */
/**
 * Transpose the input before a conv2d so the op can run in NHWC order.
 * @param x Input image tensor.
 * @param dataFormat 'channelsFirst' (NCHW) or 'channelsLast' (NHWC).
 * @returns The input in NHWC layout ('channelsFirst' inputs are transposed;
 *   others are returned unchanged).
 */
function preprocessConv2DInput(x, dataFormat) {
    // TODO(cais): Cast type to float32 if not.
    return tidy(() => {
        checkDataFormat(dataFormat);
        return dataFormat === 'channelsFirst' ?
            transpose(x, [0, 2, 3, 1]) : // NCHW -> NHWC.
            x;
    });
}
/**
 * Transpose the input before a conv3d so the op can run in NDHWC order.
 * @param x Input image tensor.
 * @param dataFormat 'channelsFirst' (NCDHW) or 'channelsLast' (NDHWC).
 * @returns The input in NDHWC layout ('channelsFirst' inputs are transposed;
 *   others are returned unchanged).
 */
function preprocessConv3DInput(x, dataFormat) {
    return tidy(() => {
        checkDataFormat(dataFormat);
        return dataFormat === 'channelsFirst' ?
            transpose(x, [0, 2, 3, 4, 1]) : // NCDHW -> NDHWC.
            x;
    });
}
/**
 * 1D-convolution with bias added.
 *
 * Porting Note: This function does not exist in the Python Keras backend.
 * It is exactly the same as `conv1d`, except the added `bias`.
 *
 * @param x Input tensor, rank-3, of shape `[batchSize, width, inChannels]`.
 * @param kernel Kernel, rank-3, of shape `[filterWidth, inDepth, outDepth]`.
 * @param bias Bias, rank-1, of shape `[outDepth]`; may be null for no bias.
 * @param strides Stride; defaults to 1.
 * @param padding Padding mode; 'causal' is not supported.
 * @param dataFormat Data format; defaults to the global image data format.
 * @param dilationRate Dilation rate; defaults to 1.
 * @returns The result of the 1D convolution.
 * @throws ValueError, if `x`, `kernel` or `bias` is not of the correct rank.
 */
function conv1dWithBias(x, kernel, bias, strides = 1, padding = 'valid', dataFormat, dilationRate = 1) {
    return tidy(() => {
        if (dataFormat == null) {
            dataFormat = imageDataFormat();
        }
        checkDataFormat(dataFormat);
        // Check the ranks of x, kernel and bias.
        if (x.shape.length !== 3) {
            throw new ValueError(`The input of a conv1dWithBias operation should be 3, but is ` +
                `${x.shape.length} instead.`);
        }
        if (kernel.shape.length !== 3) {
            throw new ValueError(`The kernel for a conv1dWithBias operation should be 3, but is ` +
                `${kernel.shape.length} instead`);
        }
        if (bias != null && bias.shape.length !== 1) {
            // Bug fix: this message previously reported `kernel.shape.length`
            // instead of the offending bias rank.
            throw new ValueError(`The bias for a conv1dWithBias operation should be 1, but is ` +
                `${bias.shape.length} instead`);
        }
        // TODO(cais): Support CAUSAL padding mode.
        if (dataFormat === 'channelsFirst') {
            x = transpose(x, [0, 2, 1]); // NCW -> NWC.
        }
        if (padding === 'causal') {
            throw new NotImplementedError('The support for CAUSAL padding mode in conv1dWithBias is not ' +
                'implemented yet.');
        }
        let y = conv1d(x, kernel, strides, padding === 'same' ? 'same' : 'valid', 'NWC', dilationRate);
        if (bias != null) {
            y = biasAdd(y, bias);
        }
        return y;
    });
}
/**
 * 1D-convolution (no bias).
 *
 * @param x Input tensor, rank-3, of shape `[batchSize, width, inChannels]`.
 * @param kernel Kernel, rank-3, of shape `[filterWidth, inDepth, outDepth]`.
 * @param strides Stride; defaults to 1.
 * @param padding Padding mode; defaults to 'valid'.
 * @param dataFormat Data format.
 * @param dilationRate Dilation rate; defaults to 1.
 * @returns The result of the 1D convolution.
 * @throws ValueError, if `x` or `kernel` is not of the correct rank.
 */
function conv1d$1(x, kernel, strides = 1, padding = 'valid', dataFormat, dilationRate = 1) {
    return tidy(() => {
        checkDataFormat(dataFormat);
        // Delegate to the bias-capable variant with a null bias.
        const result = conv1dWithBias(x, kernel, null, strides, padding, dataFormat, dilationRate);
        return result;
    });
}
/**
 * 2D Convolution (no bias, no fused activation).
 * @param x Input tensor.
 * @param kernel kernel of the convolution.
 * @param strides strides array.
 * @param padding padding mode. Default to 'valid'.
 * @param dataFormat data format. Defaults to 'channelsLast'.
 * @param dilationRate dilation rate array.
 * @returns Result of the 2D convolution.
 */
function conv2d$2(x, kernel, strides = [1, 1], padding = 'valid', dataFormat, dilationRate) {
    return tidy(() => {
        checkDataFormat(dataFormat);
        // Delegates to the fused helper with no bias and no activation.
        return conv2dWithBiasActivation(x, kernel, null, strides, padding, dataFormat, dilationRate);
    });
}
/**
 * 2D Convolution with an added bias and optional fused activation.
 * Note: This function does not exist in the Python Keras Backend. This
 * function is exactly the same as `conv2d`, except the added `bias`.
 *
 * @param x Input tensor, rank 3 or 4.
 * @param kernel Convolution kernel, rank 3 or 4.
 * @param bias Optional bias, applied inside the fused op when non-null.
 * @param strides Strides array.
 * @param padding Padding mode; 'causal' is not supported.
 * @param dataFormat Data format; defaults to the global image data format.
 * @param dilationRate Dilation rate array.
 * @param activation Optional fused activation name.
 * @returns Result of the convolution, in the caller's data format.
 * @throws ValueError if `x` or `kernel` has an invalid rank;
 *   NotImplementedError for 'causal' padding.
 */
function conv2dWithBiasActivation(x, kernel, bias, strides = [1, 1], padding = 'valid', dataFormat, dilationRate, activation = null) {
    return tidy(() => {
        if (dataFormat == null) {
            dataFormat = imageDataFormat();
        }
        checkDataFormat(dataFormat);
        if (x.rank !== 3 && x.rank !== 4) {
            throw new ValueError(`conv2dWithBiasActivation expects input to be of rank 3 or 4, ` +
                `but received ${x.rank}.`);
        }
        if (kernel.rank !== 3 && kernel.rank !== 4) {
            // Bug fix: this message previously interpolated `x.rank` instead
            // of the offending kernel rank.
            throw new ValueError(`conv2dWithBiasActivation expects kernel to be of rank 3 or 4, ` +
                `but received ${kernel.rank}.`);
        }
        // The fused op runs in NHWC; convert channelsFirst input if needed.
        let y = preprocessConv2DInput(x, dataFormat);
        if (padding === 'causal') {
            // Bug fix: the message previously named conv1dWithBias.
            throw new NotImplementedError('The support for CAUSAL padding mode in conv2dWithBiasActivation is not ' +
                'implemented yet.');
        }
        y = conv2d$1({
            x: y,
            filter: kernel,
            strides: strides,
            pad: padding === 'same' ? 'same' : 'valid',
            dilations: dilationRate,
            dataFormat: 'NHWC',
            bias,
            activation
        });
        if (dataFormat === 'channelsFirst') {
            // Convert the result back to NCHW for the caller.
            y = transpose(y, [0, 3, 1, 2]);
        }
        return y;
    });
}
/**
 * 3D Convolution (no bias).
 * @param x Input tensor.
 * @param kernel kernel of the convolution.
 * @param strides strides array.
 * @param padding padding mode. Default to 'valid'.
 * @param dataFormat data format. Defaults to 'channelsLast'.
 * @param dilationRate dilation rate array.
 * @returns Result of the 3D convolution.
 */
function conv3d$1(x, kernel, strides = [1, 1, 1], padding = 'valid', dataFormat, dilationRate) {
    return tidy(() => {
        checkDataFormat(dataFormat);
        // Delegate to the bias-capable variant with a null bias.
        const result = conv3dWithBias(x, kernel, null, strides, padding, dataFormat, dilationRate);
        return result;
    });
}
/**
 * 3D Convolution with an added bias.
 * Note: This function does not exist in the Python Keras Backend. This
 * function is exactly the same as `conv3d`, except the added `bias`.
 *
 * @param x Input tensor, rank 4 or 5.
 * @param kernel Convolution kernel, rank 4 or 5.
 * @param bias Optional bias; added after the convolution when non-null.
 * @param strides Strides array.
 * @param padding Padding mode; 'causal' is not supported.
 * @param dataFormat Data format; defaults to the global image data format.
 * @param dilationRate Dilation rate array.
 * @returns Result of the convolution, in the caller's data format.
 * @throws ValueError if `x` or `kernel` has an invalid rank;
 *   NotImplementedError for 'causal' padding.
 */
function conv3dWithBias(x, kernel, bias, strides = [1, 1, 1], padding = 'valid', dataFormat, dilationRate) {
    return tidy(() => {
        if (dataFormat == null) {
            dataFormat = imageDataFormat();
        }
        checkDataFormat(dataFormat);
        if (x.rank !== 4 && x.rank !== 5) {
            throw new ValueError(`conv3dWithBias expects input to be of rank 4 or 5, but received ` +
                `${x.rank}.`);
        }
        if (kernel.rank !== 4 && kernel.rank !== 5) {
            // Bug fix: this message previously interpolated `x.rank` instead
            // of the offending kernel rank.
            throw new ValueError(`conv3dWithBias expects kernel to be of rank 4 or 5, but received ` +
                `${kernel.rank}.`);
        }
        // The op runs in NDHWC; convert channelsFirst input if needed.
        let y = preprocessConv3DInput(x, dataFormat);
        if (padding === 'causal') {
            throw new NotImplementedError('The support for CAUSAL padding mode in conv3dWithBias is not ' +
                'implemented yet.');
        }
        y = conv3d(y, kernel, strides, padding === 'same' ? 'same' : 'valid', 'NDHWC', dilationRate);
        if (bias != null) {
            y = biasAdd(y, bias);
        }
        if (dataFormat === 'channelsFirst') {
            // Convert the result back to NCDHW for the caller.
            y = transpose(y, [0, 4, 1, 2, 3]);
        }
        return y;
    });
}
/**
 * Abstract convolution layer: shared config parsing/validation and
 * serialization for the concrete Conv-family layers.
 */
class BaseConv extends Layer {
    constructor(rank, args) {
        super(args);
        this.bias = null;
        this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
        this.DEFAULT_BIAS_INITIALIZER = 'zeros';
        BaseConv.verifyArgs(args);
        this.rank = rank;
        assertPositiveInteger(this.rank, 'rank');
        if (this.rank !== 1 && this.rank !== 2 && this.rank !== 3) {
            throw new NotImplementedError(`Convolution layer for rank other than 1, 2, or 3 (${this.rank}) is ` +
                `not implemented yet.`);
        }
        // Scalar kernelSize/strides are broadcast to rank-length arrays.
        this.kernelSize = normalizeArray(args.kernelSize, rank, 'kernelSize');
        this.strides = normalizeArray(args.strides == null ? 1 : args.strides, rank, 'strides');
        this.padding = args.padding == null ? 'valid' : args.padding;
        checkPaddingMode(this.padding);
        this.dataFormat =
            args.dataFormat == null ? 'channelsLast' : args.dataFormat;
        checkDataFormat(this.dataFormat);
        this.activation = getActivation(args.activation);
        this.useBias = args.useBias == null ? true : args.useBias;
        this.biasInitializer =
            getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER);
        this.biasConstraint = getConstraint(args.biasConstraint);
        this.biasRegularizer = getRegularizer(args.biasRegularizer);
        this.activityRegularizer = getRegularizer(args.activityRegularizer);
        this.dilationRate = normalizeArray(args.dilationRate == null ? 1 : args.dilationRate, rank, 'dilationRate');
        // Validate (and for rank 2/3, broadcast) the dilation rate per rank.
        if (this.rank === 1 &&
            (Array.isArray(this.dilationRate) && this.dilationRate.length !== 1)) {
            throw new ValueError(`dilationRate must be a number or an array of a single number ` +
                `for 1D convolution, but received ` +
                `${JSON.stringify(this.dilationRate)}`);
        }
        else if (this.rank === 2) {
            if (typeof this.dilationRate === 'number') {
                this.dilationRate = [this.dilationRate, this.dilationRate];
            }
            else if (this.dilationRate.length !== 2) {
                throw new ValueError(`dilationRate must be a number or array of two numbers for 2D ` +
                    `convolution, but received ${JSON.stringify(this.dilationRate)}`);
            }
        }
        else if (this.rank === 3) {
            if (typeof this.dilationRate === 'number') {
                this.dilationRate =
                    [this.dilationRate, this.dilationRate, this.dilationRate];
            }
            else if (this.dilationRate.length !== 3) {
                throw new ValueError(`dilationRate must be a number or array of three numbers for 3D ` +
                    `convolution, but received ${JSON.stringify(this.dilationRate)}`);
            }
        }
    }
    /**
     * Validates the raw constructor args before any parsing.
     * @throws ValueError if `kernelSize` is missing or malformed.
     */
    static verifyArgs(args) {
        // Check config.kernelSize type and shape.
        assert$1('kernelSize' in args, `required key 'kernelSize' not in config`);
        if (typeof args.kernelSize !== 'number' &&
            !checkArrayTypeAndLength(args.kernelSize, 'number', 1, 3)) {
            throw new ValueError(`BaseConv expects config.kernelSize to be number or number[] with ` +
                `length 1, 2, or 3, but received ${JSON.stringify(args.kernelSize)}.`);
        }
    }
    getConfig() {
        const config = {
            kernelSize: this.kernelSize,
            strides: this.strides,
            padding: this.padding,
            dataFormat: this.dataFormat,
            dilationRate: this.dilationRate,
            activation: serializeActivation(this.activation),
            useBias: this.useBias,
            biasInitializer: serializeInitializer(this.biasInitializer),
            biasRegularizer: serializeRegularizer(this.biasRegularizer),
            activityRegularizer: serializeRegularizer(this.activityRegularizer),
            biasConstraint: serializeConstraint(this.biasConstraint)
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}
/**
 * Abstract nD convolution layer. Ancestor of convolution layers which reduce
 * across channels, i.e., Conv1D and Conv2D, but not DepthwiseConv2D.
 */
class Conv extends BaseConv {
    constructor(rank, args) {
        super(rank, args);
        this.kernel = null;
        Conv.verifyArgs(args);
        this.filters = args.filters;
        assertPositiveInteger(this.filters, 'filters');
        this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER);
        this.kernelConstraint = getConstraint(args.kernelConstraint);
        this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
    }
    build(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        const channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1;
        if (inputShape[channelAxis] == null) {
            throw new ValueError(`The channel dimension of the input should be defined. ` +
                `Found ${inputShape[channelAxis]}`);
        }
        const inputDim = inputShape[channelAxis];
        // Kernel shape: spatial dims ++ [inChannels, outChannels].
        const kernelShape = this.kernelSize.concat([inputDim, this.filters]);
        this.kernel = this.addWeight('kernel', kernelShape, null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
        if (this.useBias) {
            this.bias = this.addWeight('bias', [this.filters], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
        }
        this.inputSpec = [{ ndim: this.rank + 2, axes: { [channelAxis]: inputDim } }];
        this.built = true;
    }
    call(inputs, kwargs) {
        return tidy(() => {
            inputs = getExactlyOneTensor(inputs);
            let outputs;
            const biasValue = this.bias == null ? null : this.bias.read();
            // For 2D convolutions with a supported activation, the bias-add
            // and activation can be fused into the conv kernel itself.
            const fusedActivationName = mapActivationToFusedKernel(this.activation.getClassName());
            if (fusedActivationName != null && this.rank === 2) {
                outputs = conv2dWithBiasActivation(inputs, this.kernel.read(), biasValue, this.strides, this.padding, this.dataFormat, this.dilationRate, fusedActivationName);
            }
            else {
                // Unfused path: convolve (with bias), then apply activation.
                if (this.rank === 1) {
                    outputs = conv1dWithBias(inputs, this.kernel.read(), biasValue, this.strides[0], this.padding, this.dataFormat, this.dilationRate[0]);
                }
                else if (this.rank === 2) {
                    // TODO(cais): Move up to constructor.
                    outputs = conv2dWithBiasActivation(inputs, this.kernel.read(), biasValue, this.strides, this.padding, this.dataFormat, this.dilationRate);
                }
                else if (this.rank === 3) {
                    outputs = conv3dWithBias(inputs, this.kernel.read(), biasValue, this.strides, this.padding, this.dataFormat, this.dilationRate);
                }
                else {
                    throw new NotImplementedError('convolutions greater than 3D are not implemented yet.');
                }
                if (this.activation != null) {
                    outputs = this.activation.apply(outputs);
                }
            }
            return outputs;
        });
    }
    computeOutputShape(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        const newSpace = [];
        // Spatial dims exclude batch and the channel axis.
        const space = (this.dataFormat === 'channelsLast') ?
            inputShape.slice(1, inputShape.length - 1) :
            inputShape.slice(2);
        for (let i = 0; i < space.length; ++i) {
            const newDim = convOutputLength(space[i], this.kernelSize[i], this.padding, this.strides[i], typeof this.dilationRate === 'number' ? this.dilationRate :
                this.dilationRate[i]);
            newSpace.push(newDim);
        }
        // Reassemble in the layer's data format.
        let outputShape = [inputShape[0]];
        if (this.dataFormat === 'channelsLast') {
            outputShape = outputShape.concat(newSpace);
            outputShape.push(this.filters);
        }
        else {
            outputShape.push(this.filters);
            outputShape = outputShape.concat(newSpace);
        }
        return outputShape;
    }
    getConfig() {
        const config = {
            filters: this.filters,
            kernelInitializer: serializeInitializer(this.kernelInitializer),
            kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
            kernelConstraint: serializeConstraint(this.kernelConstraint)
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
    /**
     * Validates the raw constructor args.
     * @throws ValueError if `filters` is missing, non-numeric, or < 1.
     */
    static verifyArgs(args) {
        // Check config.filters type, shape, and value.
        if (!('filters' in args) || typeof args.filters !== 'number' ||
            args.filters < 1) {
            throw new ValueError(`Convolution layer expected config.filters to be a 'number' > 0 ` +
                `but got ${JSON.stringify(args.filters)}`);
        }
    }
}
/**
 * 2D convolution layer.
 */
class Conv2D$1 extends Conv {
    constructor(args) {
        super(2, args);
        Conv2D$1.verifyArgs(args);
    }
    getConfig() {
        const config = super.getConfig();
        // `rank` is implied by the class and must not be serialized.
        delete config['rank'];
        return config;
    }
    /**
     * Validates the raw constructor args.
     * @throws ValueError if `kernelSize` is not a number or a 1- or
     *   2-element number array.
     */
    static verifyArgs(args) {
        const sizeOk = (typeof args.kernelSize === 'number') ||
            checkArrayTypeAndLength(args.kernelSize, 'number', 1, 2);
        if (!sizeOk) {
            throw new ValueError(`Conv2D expects config.kernelSize to be number or number[] with ` +
                `length 1 or 2, but received ${JSON.stringify(args.kernelSize)}.`);
        }
    }
}
/** @nocollapse */
Conv2D$1.className = 'Conv2D';
registerClass(Conv2D$1);
/**
 * 3D convolution layer.
 */
class Conv3D$1 extends Conv {
    constructor(args) {
        super(3, args);
        Conv3D$1.verifyArgs(args);
    }
    getConfig() {
        const config = super.getConfig();
        // `rank` is implied by the class and must not be serialized.
        delete config['rank'];
        return config;
    }
    /**
     * Validates the raw constructor args.
     * @throws ValueError if `kernelSize` is not a number or a 1- or
     *   3-element array.
     */
    static verifyArgs(args) {
        if (typeof args.kernelSize === 'number') {
            return;
        }
        const isValidArray = Array.isArray(args.kernelSize) &&
            (args.kernelSize.length === 1 || args.kernelSize.length === 3);
        if (!isValidArray) {
            throw new ValueError(`Conv3D expects config.kernelSize to be number or` +
                ` [number, number, number], but received ${JSON.stringify(args.kernelSize)}.`);
        }
    }
}
/** @nocollapse */
Conv3D$1.className = 'Conv3D';
registerClass(Conv3D$1);
    /**
     * Transposed 2D convolution layer (sometimes called "deconvolution").
     *
     * Operates on rank-4 inputs. Only 'same' and 'valid' padding modes are
     * supported; the spatial output size is inferred via `deconvLength`.
     */
    class Conv2DTranspose extends Conv2D$1 {
        constructor(args) {
            super(args);
            this.inputSpec = [new InputSpec({ ndim: 4 })];
            // The underlying conv2dTranspose op supports only these two modes;
            // fail fast in the constructor rather than at call time.
            if (this.padding !== 'same' && this.padding !== 'valid') {
                throw new ValueError(`Conv2DTranspose currently supports only padding modes 'same' ` +
                    `and 'valid', but received padding mode ${this.padding}`);
            }
        }
        /**
         * Creates the kernel (and optional bias) weights.
         *
         * @param inputShape Rank-4 input shape; the channel dimension must be
         *   statically known.
         * @throws ValueError If the rank is not 4 or the channel dim is null.
         */
        build(inputShape) {
            inputShape = getExactlyOneShape(inputShape);
            if (inputShape.length !== 4) {
                throw new ValueError('Input should have rank 4; Received input shape: ' +
                    JSON.stringify(inputShape));
            }
            const channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1;
            if (inputShape[channelAxis] == null) {
                throw new ValueError('The channel dimension of the inputs should be defined. ' +
                    'Found `None`.');
            }
            const inputDim = inputShape[channelAxis];
            // Transposed-conv kernels are stored as [...kernelSize, filters, inputDim].
            const kernelShape = this.kernelSize.concat([this.filters, inputDim]);
            this.kernel = this.addWeight('kernel', kernelShape, 'float32', this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
            if (this.useBias) {
                this.bias = this.addWeight('bias', [this.filters], 'float32', this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
            }
            // Set input spec: lock the channel dimension to the built size.
            this.inputSpec =
                [new InputSpec({ ndim: 4, axes: { [channelAxis]: inputDim } })];
            this.built = true;
        }
        call(inputs, kwargs) {
            return tidy(() => {
                let input = getExactlyOneTensor(inputs);
                if (input.shape.length !== 4) {
                    throw new ValueError(`Conv2DTranspose.call() expects input tensor to be rank-4, but ` +
                        `received a tensor of rank-${input.shape.length}`);
                }
                const inputShape = input.shape;
                const batchSize = inputShape[0];
                // Locate the spatial axes for the configured data format.
                let hAxis;
                let wAxis;
                if (this.dataFormat === 'channelsFirst') {
                    hAxis = 2;
                    wAxis = 3;
                }
                else {
                    hAxis = 1;
                    wAxis = 2;
                }
                const height = inputShape[hAxis];
                const width = inputShape[wAxis];
                const kernelH = this.kernelSize[0];
                const kernelW = this.kernelSize[1];
                const strideH = this.strides[0];
                const strideW = this.strides[1];
                // Infer the dynamic output shape.
                const outHeight = deconvLength(height, strideH, kernelH, this.padding);
                const outWidth = deconvLength(width, strideW, kernelW, this.padding);
                // Porting Note: no branching on `this.dataFormat` for the output
                // shape, because the tfjs-core `conv2dTranspose` op called below
                // always assumes channelsLast; channelsFirst inputs are
                // transposed to NHWC first and transposed back afterwards.
                const outputShape = [batchSize, outHeight, outWidth, this.filters];
                if (this.dataFormat !== 'channelsLast') {
                    input = transpose(input, [0, 2, 3, 1]);
                }
                let outputs = conv2dTranspose(input, this.kernel.read(), outputShape, this.strides, this.padding);
                if (this.dataFormat !== 'channelsLast') {
                    outputs = transpose(outputs, [0, 3, 1, 2]);
                }
                if (this.bias != null) {
                    outputs =
                        biasAdd(outputs, this.bias.read(), this.dataFormat);
                }
                if (this.activation != null) {
                    outputs = this.activation.apply(outputs);
                }
                return outputs;
            });
        }
        /**
         * Computes the output shape by applying `deconvLength` to both spatial
         * axes and replacing the channel axis with `filters`.
         */
        computeOutputShape(inputShape) {
            inputShape = getExactlyOneShape(inputShape);
            const outputShape = inputShape.slice();
            let channelAxis;
            let heightAxis;
            let widthAxis;
            if (this.dataFormat === 'channelsFirst') {
                channelAxis = 1;
                heightAxis = 2;
                widthAxis = 3;
            }
            else {
                channelAxis = 3;
                heightAxis = 1;
                widthAxis = 2;
            }
            const kernelH = this.kernelSize[0];
            const kernelW = this.kernelSize[1];
            const strideH = this.strides[0];
            const strideW = this.strides[1];
            outputShape[channelAxis] = this.filters;
            outputShape[heightAxis] =
                deconvLength(outputShape[heightAxis], strideH, kernelH, this.padding);
            outputShape[widthAxis] =
                deconvLength(outputShape[widthAxis], strideW, kernelW, this.padding);
            return outputShape;
        }
        getConfig() {
            // Dilation is not supported by transposed convolution; drop the field.
            const config = super.getConfig();
            delete config['dilationRate'];
            return config;
        }
    }
    /** @nocollapse */
    Conv2DTranspose.className = 'Conv2DTranspose';
    registerClass(Conv2DTranspose);
46192 class Conv3DTranspose extends Conv3D$1 {
46193 constructor(args) {
46194 super(args);
46195 this.inputSpec = [new InputSpec({ ndim: 5 })];
46196 if (this.padding !== 'same' && this.padding !== 'valid') {
46197 throw new ValueError(`Conv3DTranspose currently supports only padding modes 'same' ` +
46198 `and 'valid', but received padding mode ${this.padding}`);
46199 }
46200 }
46201 build(inputShape) {
46202 inputShape = getExactlyOneShape(inputShape);
46203 if (inputShape.length !== 5) {
46204 throw new ValueError('Input should have rank 5; Received input shape: ' +
46205 JSON.stringify(inputShape));
46206 }
46207 const channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1;
46208 if (inputShape[channelAxis] == null) {
46209 throw new ValueError('The channel dimension of the inputs should be defined. ' +
46210 'Found `None`.');
46211 }
46212 const inputDim = inputShape[channelAxis];
46213 const kernelShape = this.kernelSize.concat([this.filters, inputDim]);
46214 this.kernel = this.addWeight('kernel', kernelShape, 'float32', this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
46215 if (this.useBias) {
46216 this.bias = this.addWeight('bias', [this.filters], 'float32', this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
46217 }
46218 // Set input spec.
46219 this.inputSpec =
46220 [new InputSpec({ ndim: 5, axes: { [channelAxis]: inputDim } })];
46221 this.built = true;
46222 }
46223 call(inputs, kwargs) {
46224 return tidy(() => {
46225 let input = getExactlyOneTensor(inputs);
46226 if (input.shape.length !== 5) {
46227 throw new ValueError(`Conv3DTranspose.call() expects input tensor to be rank-4, but ` +
46228 `received a tensor of rank-${input.shape.length}`);
46229 }
46230 const inputShape = input.shape;
46231 const batchSize = inputShape[0];
46232 let hAxis;
46233 let wAxis;
46234 let dAxis;
46235 if (this.dataFormat === 'channelsFirst') {
46236 dAxis = 2;
46237 hAxis = 3;
46238 wAxis = 4;
46239 }
46240 else {
46241 dAxis = 1;
46242 hAxis = 2;
46243 wAxis = 3;
46244 }
46245 const depth = inputShape[dAxis];
46246 const height = inputShape[hAxis];
46247 const width = inputShape[wAxis];
46248 const kernelD = this.kernelSize[0];
46249 const kernelH = this.kernelSize[1];
46250 const kernelW = this.kernelSize[2];
46251 const strideD = this.strides[0];
46252 const strideH = this.strides[1];
46253 const strideW = this.strides[2];
46254 // Infer the dynamic output shape.
46255 const outDepth = deconvLength(depth, strideD, kernelD, this.padding);
46256 const outHeight = deconvLength(height, strideH, kernelH, this.padding);
46257 const outWidth = deconvLength(width, strideW, kernelW, this.padding);
46258 // Same as `conv2dTranspose`. We always assumes channelsLast.
46259 const outputShape = [batchSize, outDepth, outHeight, outWidth, this.filters];
46260 if (this.dataFormat !== 'channelsLast') {
46261 input = transpose(input, [0, 2, 3, 4, 1]);
46262 }
46263 let outputs = conv3dTranspose(input, this.kernel.read(), outputShape, this.strides, this.padding);
46264 if (this.dataFormat !== 'channelsLast') {
46265 outputs = transpose(outputs, [0, 4, 1, 2, 3]);
46266 }
46267 if (this.bias !== null) {
46268 outputs =
46269 biasAdd(outputs, this.bias.read(), this.dataFormat);
46270 }
46271 if (this.activation !== null) {
46272 outputs = this.activation.apply(outputs);
46273 }
46274 return outputs;
46275 });
46276 }
46277 computeOutputShape(inputShape) {
46278 inputShape = getExactlyOneShape(inputShape);
46279 const outputShape = inputShape.slice();
46280 let channelAxis;
46281 let depthAxis;
46282 let heightAxis;
46283 let widthAxis;
46284 if (this.dataFormat === 'channelsFirst') {
46285 channelAxis = 1;
46286 depthAxis = 2;
46287 heightAxis = 3;
46288 widthAxis = 4;
46289 }
46290 else {
46291 channelAxis = 4;
46292 depthAxis = 1;
46293 heightAxis = 2;
46294 widthAxis = 3;
46295 }
46296 const kernelD = this.kernelSize[0];
46297 const kernelH = this.kernelSize[1];
46298 const kernelW = this.kernelSize[2];
46299 const strideD = this.strides[0];
46300 const strideH = this.strides[1];
46301 const strideW = this.strides[2];
46302 outputShape[channelAxis] = this.filters;
46303 outputShape[depthAxis] =
46304 deconvLength(outputShape[depthAxis], strideD, kernelD, this.padding);
46305 outputShape[heightAxis] =
46306 deconvLength(outputShape[heightAxis], strideH, kernelH, this.padding);
46307 outputShape[widthAxis] =
46308 deconvLength(outputShape[widthAxis], strideW, kernelW, this.padding);
46309 return outputShape;
46310 }
46311 getConfig() {
46312 const config = super.getConfig();
46313 delete config['dilationRate'];
46314 return config;
46315 }
46316 }
46317 /** @nocollapse */
46318 Conv3DTranspose.className = 'Conv3DTranspose';
46319 registerClass(Conv3DTranspose);
46320 class SeparableConv extends Conv {
46321 constructor(rank, config) {
46322 super(rank, config);
46323 this.DEFAULT_DEPTHWISE_INITIALIZER = 'glorotUniform';
46324 this.DEFAULT_POINTWISE_INITIALIZER = 'glorotUniform';
46325 this.depthwiseKernel = null;
46326 this.pointwiseKernel = null;
46327 if (config.filters == null) {
46328 throw new ValueError('The `filters` configuration field is required by SeparableConv, ' +
46329 'but is unspecified.');
46330 }
46331 if (config.kernelInitializer != null || config.kernelRegularizer != null ||
46332 config.kernelConstraint != null) {
46333 throw new ValueError('Fields kernelInitializer, kernelRegularizer and kernelConstraint ' +
46334 'are invalid for SeparableConv2D. Use depthwiseInitializer, ' +
46335 'depthwiseRegularizer, depthwiseConstraint, pointwiseInitializer, ' +
46336 'pointwiseRegularizer and pointwiseConstraint instead.');
46337 }
46338 if (config.padding != null && config.padding !== 'same' &&
46339 config.padding !== 'valid') {
46340 throw new ValueError(`SeparableConv${this.rank}D supports only padding modes: ` +
46341 `'same' and 'valid', but received ${JSON.stringify(config.padding)}`);
46342 }
46343 this.depthMultiplier =
46344 config.depthMultiplier == null ? 1 : config.depthMultiplier;
46345 this.depthwiseInitializer = getInitializer(config.depthwiseInitializer || this.DEFAULT_DEPTHWISE_INITIALIZER);
46346 this.depthwiseRegularizer = getRegularizer(config.depthwiseRegularizer);
46347 this.depthwiseConstraint = getConstraint(config.depthwiseConstraint);
46348 this.pointwiseInitializer = getInitializer(config.depthwiseInitializer || this.DEFAULT_POINTWISE_INITIALIZER);
46349 this.pointwiseRegularizer = getRegularizer(config.pointwiseRegularizer);
46350 this.pointwiseConstraint = getConstraint(config.pointwiseConstraint);
46351 }
46352 build(inputShape) {
46353 inputShape = getExactlyOneShape(inputShape);
46354 if (inputShape.length < this.rank + 2) {
46355 throw new ValueError(`Inputs to SeparableConv${this.rank}D should have rank ` +
46356 `${this.rank + 2}, but received input shape: ` +
46357 `${JSON.stringify(inputShape)}`);
46358 }
46359 const channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1;
46360 if (inputShape[channelAxis] == null || inputShape[channelAxis] < 0) {
46361 throw new ValueError(`The channel dimension of the inputs should be defined, ` +
46362 `but found ${JSON.stringify(inputShape[channelAxis])}`);
46363 }
46364 const inputDim = inputShape[channelAxis];
46365 const depthwiseKernelShape = this.kernelSize.concat([inputDim, this.depthMultiplier]);
46366 const pointwiseKernelShape = [];
46367 for (let i = 0; i < this.rank; ++i) {
46368 pointwiseKernelShape.push(1);
46369 }
46370 pointwiseKernelShape.push(inputDim * this.depthMultiplier, this.filters);
46371 const trainable = true;
46372 this.depthwiseKernel = this.addWeight('depthwise_kernel', depthwiseKernelShape, 'float32', this.depthwiseInitializer, this.depthwiseRegularizer, trainable, this.depthwiseConstraint);
46373 this.pointwiseKernel = this.addWeight('pointwise_kernel', pointwiseKernelShape, 'float32', this.pointwiseInitializer, this.pointwiseRegularizer, trainable, this.pointwiseConstraint);
46374 if (this.useBias) {
46375 this.bias = this.addWeight('bias', [this.filters], 'float32', this.biasInitializer, this.biasRegularizer, trainable, this.biasConstraint);
46376 }
46377 else {
46378 this.bias = null;
46379 }
46380 this.inputSpec =
46381 [new InputSpec({ ndim: this.rank + 2, axes: { [channelAxis]: inputDim } })];
46382 this.built = true;
46383 }
46384 call(inputs, kwargs) {
46385 return tidy(() => {
46386 inputs = getExactlyOneTensor(inputs);
46387 let output;
46388 if (this.rank === 1) {
46389 throw new NotImplementedError('1D separable convolution is not implemented yet.');
46390 }
46391 else if (this.rank === 2) {
46392 if (this.dataFormat === 'channelsFirst') {
46393 inputs = transpose(inputs, [0, 2, 3, 1]); // NCHW -> NHWC.
46394 }
46395 output = separableConv2d(inputs, this.depthwiseKernel.read(), this.pointwiseKernel.read(), this.strides, this.padding, this.dilationRate, 'NHWC');
46396 }
46397 if (this.useBias) {
46398 output = biasAdd(output, this.bias.read(), this.dataFormat);
46399 }
46400 if (this.activation != null) {
46401 output = this.activation.apply(output);
46402 }
46403 if (this.dataFormat === 'channelsFirst') {
46404 output = transpose(output, [0, 3, 1, 2]); // NHWC -> NCHW.
46405 }
46406 return output;
46407 });
46408 }
46409 getConfig() {
46410 const config = super.getConfig();
46411 delete config['rank'];
46412 delete config['kernelInitializer'];
46413 delete config['kernelRegularizer'];
46414 delete config['kernelConstraint'];
46415 config['depthwiseInitializer'] =
46416 serializeInitializer(this.depthwiseInitializer);
46417 config['pointwiseInitializer'] =
46418 serializeInitializer(this.pointwiseInitializer);
46419 config['depthwiseRegularizer'] =
46420 serializeRegularizer(this.depthwiseRegularizer);
46421 config['pointwiseRegularizer'] =
46422 serializeRegularizer(this.pointwiseRegularizer);
46423 config['depthwiseConstraint'] =
46424 serializeConstraint(this.depthwiseConstraint);
46425 config['pointwiseConstraint'] =
46426 serializeConstraint(this.pointwiseConstraint);
46427 return config;
46428 }
46429 }
46430 /** @nocollapse */
46431 SeparableConv.className = 'SeparableConv';
46432 class SeparableConv2D extends SeparableConv {
46433 constructor(args) {
46434 super(2, args);
46435 }
46436 }
46437 /** @nocollapse */
46438 SeparableConv2D.className = 'SeparableConv2D';
46439 registerClass(SeparableConv2D);
46440 class Conv1D extends Conv {
46441 constructor(args) {
46442 super(1, args);
46443 Conv1D.verifyArgs(args);
46444 this.inputSpec = [{ ndim: 3 }];
46445 }
46446 getConfig() {
46447 const config = super.getConfig();
46448 delete config['rank'];
46449 delete config['dataFormat'];
46450 return config;
46451 }
46452 static verifyArgs(args) {
46453 // config.kernelSize must be a number or array of numbers.
46454 if (typeof args.kernelSize !== 'number' &&
46455 !checkArrayTypeAndLength(args.kernelSize, 'number', 1, 1)) {
46456 throw new ValueError(`Conv1D expects config.kernelSize to be number or number[] with ` +
46457 `length 1, but received ${JSON.stringify(args.kernelSize)}.`);
46458 }
46459 }
46460 }
46461 /** @nocollapse */
46462 Conv1D.className = 'Conv1D';
46463 registerClass(Conv1D);
    /**
     * Cropping layer for 2D input (e.g., images).
     *
     * Crops along the height and width dimensions. The `cropping` argument is
     * normalized to the form `[[top, bottom], [left, right]]`:
     *   - a single number crops all four sides by that amount;
     *   - `[h, w]` crops `h` on top/bottom and `w` on left/right;
     *   - `[[top, bottom], [left, right]]` is used as-is.
     */
    class Cropping2D extends Layer {
        constructor(args) {
            super(args);
            if (typeof args.cropping === 'number') {
                // Same crop on all four sides.
                this.cropping =
                    [[args.cropping, args.cropping], [args.cropping, args.cropping]];
            }
            else if (typeof args.cropping[0] === 'number') {
                // Symmetric crop per spatial dimension.
                this.cropping = [
                    [args.cropping[0], args.cropping[0]],
                    [args.cropping[1], args.cropping[1]]
                ];
            }
            else {
                // Fully specified [[top, bottom], [left, right]].
                this.cropping = args.cropping;
            }
            this.dataFormat =
                args.dataFormat === undefined ? 'channelsLast' : args.dataFormat;
            this.inputSpec = [{ ndim: 4 }];
        }
        /**
         * Output shape: each spatial dimension shrinks by the sum of its two
         * crop amounts; batch and channel dimensions are unchanged.
         */
        computeOutputShape(inputShape) {
            if (this.dataFormat === 'channelsFirst') {
                return [
                    inputShape[0], inputShape[1],
                    inputShape[2] - this.cropping[0][0] - this.cropping[0][1],
                    inputShape[3] - this.cropping[1][0] - this.cropping[1][1]
                ];
            }
            else {
                return [
                    inputShape[0],
                    inputShape[1] - this.cropping[0][0] - this.cropping[0][1],
                    inputShape[2] - this.cropping[1][0] - this.cropping[1][1], inputShape[3]
                ];
            }
        }
        call(inputs, kwargs) {
            return tidy(() => {
                inputs = getExactlyOneTensor(inputs);
                // Slice height first, then width, for the configured data format.
                // NOTE(review): the axis argument passed to sliceAlongAxis is one
                // greater than the zero-based tensor axis (2/3 for NHWC h/w, 3/4
                // for NCHW), so it appears to be 1-based — confirm against the
                // sliceAlongAxis definition.
                if (this.dataFormat === 'channelsLast') {
                    const hSliced = sliceAlongAxis(inputs, this.cropping[0][0], inputs.shape[1] - this.cropping[0][0] - this.cropping[0][1], 2);
                    return sliceAlongAxis(hSliced, this.cropping[1][0], inputs.shape[2] - this.cropping[1][1] - this.cropping[1][0], 3);
                }
                else {
                    const hSliced = sliceAlongAxis(inputs, this.cropping[0][0], inputs.shape[2] - this.cropping[0][0] - this.cropping[0][1], 3);
                    return sliceAlongAxis(hSliced, this.cropping[1][0], inputs.shape[3] - this.cropping[1][1] - this.cropping[1][0], 4);
                }
            });
        }
        getConfig() {
            // Serialize the normalized cropping spec and the data format.
            const config = { cropping: this.cropping, dataFormat: this.dataFormat };
            const baseConfig = super.getConfig();
            Object.assign(config, baseConfig);
            return config;
        }
    }
    /** @nocollapse */
    Cropping2D.className = 'Cropping2D';
    registerClass(Cropping2D);
46523 class UpSampling2D extends Layer {
46524 constructor(args) {
46525 super(args);
46526 this.DEFAULT_SIZE = [2, 2];
46527 this.inputSpec = [{ ndim: 4 }];
46528 this.size = args.size == null ? this.DEFAULT_SIZE : args.size;
46529 this.dataFormat =
46530 args.dataFormat == null ? 'channelsLast' : args.dataFormat;
46531 checkDataFormat(this.dataFormat);
46532 this.interpolation =
46533 args.interpolation == null ? 'nearest' : args.interpolation;
46534 checkInterpolationFormat(this.interpolation);
46535 }
46536 computeOutputShape(inputShape) {
46537 if (this.dataFormat === 'channelsFirst') {
46538 const height = inputShape[2] == null ? null : this.size[0] * inputShape[2];
46539 const width = inputShape[3] == null ? null : this.size[1] * inputShape[3];
46540 return [inputShape[0], inputShape[1], height, width];
46541 }
46542 else {
46543 const height = inputShape[1] == null ? null : this.size[0] * inputShape[1];
46544 const width = inputShape[2] == null ? null : this.size[1] * inputShape[2];
46545 return [inputShape[0], height, width, inputShape[3]];
46546 }
46547 }
46548 call(inputs, kwargs) {
46549 return tidy(() => {
46550 let input = getExactlyOneTensor(inputs);
46551 const inputShape = input.shape;
46552 if (this.dataFormat === 'channelsFirst') {
46553 input = transpose(input, [0, 2, 3, 1]);
46554 const height = this.size[0] * inputShape[2];
46555 const width = this.size[1] * inputShape[3];
46556 const resized = this.interpolation === 'nearest' ?
46557 image.resizeNearestNeighbor(input, [height, width]) :
46558 image.resizeBilinear(input, [height, width]);
46559 return transpose(resized, [0, 3, 1, 2]);
46560 }
46561 else {
46562 const height = this.size[0] * inputShape[1];
46563 const width = this.size[1] * inputShape[2];
46564 return this.interpolation === 'nearest' ?
46565 image.resizeNearestNeighbor(input, [height, width]) :
46566 image.resizeBilinear(input, [height, width]);
46567 }
46568 });
46569 }
46570 getConfig() {
46571 const config = {
46572 size: this.size,
46573 dataFormat: this.dataFormat,
46574 interpolation: this.interpolation
46575 };
46576 const baseConfig = super.getConfig();
46577 Object.assign(config, baseConfig);
46578 return config;
46579 }
46580 }
46581 /** @nocollapse */
46582 UpSampling2D.className = 'UpSampling2D';
46583 registerClass(UpSampling2D);
46584
46585 /**
46586 * @license
46587 * Copyright 2018 Google LLC
46588 *
46589 * Use of this source code is governed by an MIT-style
46590 * license that can be found in the LICENSE file or at
46591 * https://opensource.org/licenses/MIT.
46592 * =============================================================================
46593 */
46594 /**
46595 * 2D convolution with separable filters.
46596 * @param x Input tensor.
46597 * @param depthwiseKernel Convolution kernel for depthwise convolution.
46598 * @param strides Strides (Array of two integers).
46599 * @param padding Padding model.
46600 * @param dataFormat Data format.
46601 * @param dilationRate Array of two integers, dilation rates for the separable
46602 * convolution.
46603 * @returns Output tensor.
46604 * @throws ValueError If depthwiseKernel is not a 4D array.
46605 */
    function depthwiseConv2d$2(x, depthwiseKernel, strides = [1, 1], padding = 'valid', dataFormat, dilationRate) {
        return tidy(() => {
            // Default to the globally-configured image data format.
            if (dataFormat == null) {
                dataFormat = imageDataFormat();
            }
            checkDataFormat(dataFormat);
            // Normalize the input to NHWC for the core op. Note: this runs
            // before the rank checks below, so a wrong-rank `x` reaches
            // preprocessConv2DInput first.
            let y = preprocessConv2DInput(x, dataFormat);
            if (x.rank !== 4) {
                throw new ValueError(`Input for depthwiseConv2d is required to be 4-D, but is instead ` +
                    `${x.rank}-D`);
            }
            if (depthwiseKernel.rank !== 4) {
                throw new ValueError(`depthwiseKernel is required to be 4-D, but is instead ` +
                    `${depthwiseKernel.rank}-D`);
            }
            // Core op runs in NHWC; any non-'same' padding is coerced to 'valid'.
            y = depthwiseConv2d(y, depthwiseKernel, strides, padding === 'same' ? 'same' : 'valid', 'NHWC', dilationRate);
            // Restore channelsFirst layout if that is what the caller uses.
            if (dataFormat === 'channelsFirst') {
                y = transpose(y, [0, 3, 1, 2]);
            }
            return y;
        });
    }
46628 class DepthwiseConv2D extends BaseConv {
46629 constructor(args) {
46630 super(2, args);
46631 this.depthwiseKernel = null;
46632 this.depthMultiplier =
46633 args.depthMultiplier == null ? 1 : args.depthMultiplier;
46634 this.depthwiseInitializer = getInitializer(args.depthwiseInitializer || this.DEFAULT_KERNEL_INITIALIZER);
46635 this.depthwiseConstraint = getConstraint(args.depthwiseConstraint);
46636 this.depthwiseRegularizer = getRegularizer(args.depthwiseRegularizer);
46637 }
46638 build(inputShape) {
46639 inputShape = getExactlyOneShape(inputShape);
46640 if (inputShape.length < 4) {
46641 throw new ValueError(`Inputs to DepthwiseConv2D should have rank 4. ` +
46642 `Received input shape: ${JSON.stringify(inputShape)}.`);
46643 }
46644 const channelAxis = this.dataFormat === 'channelsFirst' ? 1 : 3;
46645 if (inputShape[channelAxis] == null || inputShape[channelAxis] < 0) {
46646 throw new ValueError('The channel dimension of the inputs to DepthwiseConv2D should ' +
46647 `be defined, but is not (${inputShape[channelAxis]}).`);
46648 }
46649 const inputDim = inputShape[channelAxis];
46650 const depthwiseKernelShape = [
46651 this.kernelSize[0], this.kernelSize[1], inputDim, this.depthMultiplier
46652 ];
46653 this.depthwiseKernel = this.addWeight('depthwise_kernel', depthwiseKernelShape, null, this.depthwiseInitializer, this.depthwiseRegularizer, true, this.depthwiseConstraint);
46654 if (this.useBias) {
46655 this.bias = this.addWeight('bias', [inputDim * this.depthMultiplier], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
46656 }
46657 else {
46658 this.bias = null;
46659 }
46660 this.built = true;
46661 }
46662 call(inputs, kwargs) {
46663 return tidy(() => {
46664 inputs = getExactlyOneTensor(inputs);
46665 let outputs = depthwiseConv2d$2(inputs, this.depthwiseKernel.read(), this.strides, this.padding, this.dataFormat, null);
46666 // TODO(cais): Add support for dilation.
46667 if (this.useBias) {
46668 outputs = biasAdd(outputs, this.bias.read(), this.dataFormat);
46669 }
46670 if (this.activation != null) {
46671 outputs = this.activation.apply(outputs);
46672 }
46673 return outputs;
46674 });
46675 }
46676 computeOutputShape(inputShape) {
46677 inputShape = getExactlyOneShape(inputShape);
46678 const rows = this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1];
46679 const cols = this.dataFormat === 'channelsFirst' ? inputShape[3] : inputShape[2];
46680 const outFilters = this.dataFormat === 'channelsFirst' ?
46681 inputShape[1] * this.depthMultiplier :
46682 inputShape[3] * this.depthMultiplier;
46683 const outRows = convOutputLength(rows, this.kernelSize[0], this.padding, this.strides[0]);
46684 const outCols = convOutputLength(cols, this.kernelSize[1], this.padding, this.strides[1]);
46685 if (this.dataFormat === 'channelsFirst') {
46686 return [inputShape[0], outFilters, outRows, outCols];
46687 }
46688 else {
46689 // In this case, assume 'channelsLast'.
46690 return [inputShape[0], outRows, outCols, outFilters];
46691 }
46692 }
46693 getConfig() {
46694 const config = super.getConfig();
46695 config['depthMultiplier'] = this.depthMultiplier;
46696 config['depthwiseInitializer'] =
46697 serializeInitializer(this.depthwiseInitializer);
46698 config['depthwiseRegularizer'] =
46699 serializeRegularizer(this.depthwiseRegularizer);
46700 config['depthwiseConstraint'] =
46701 serializeConstraint(this.depthwiseRegularizer);
46702 return config;
46703 }
46704 }
46705 /** @nocollapse */
46706 DepthwiseConv2D.className = 'DepthwiseConv2D';
46707 registerClass(DepthwiseConv2D);
46708
46709 /**
46710 * @license
46711 * Copyright 2018 Google LLC
46712 *
46713 * Use of this source code is governed by an MIT-style
46714 * license that can be found in the LICENSE file or at
46715 * https://opensource.org/licenses/MIT.
46716 * =============================================================================
46717 */
46718 /**
46719 * Standardize `apply()` args to a single list of tensor inputs.
46720 *
46721 * When running a model loaded from file, the input tensors `initialState` and
46722 * `constants` are passed to `RNN.apply()` as part of `inputs` instead of the
46723 * dedicated kwargs fields. `inputs` consists of
46724 * `[inputs, initialState0, initialState1, ..., constant0, constant1]` in this
46725 * case.
46726 * This method makes sure that arguments are
46727 * separated and that `initialState` and `constants` are `Array`s of tensors
46728 * (or None).
46729 *
46730 * @param inputs Tensor or `Array` of tensors.
46731 * @param initialState Tensor or `Array` of tensors or `null`/`undefined`.
46732 * @param constants Tensor or `Array` of tensors or `null`/`undefined`.
46733 * @returns An object consisting of
46734 * inputs: A tensor.
46735 * initialState: `Array` of tensors or `null`.
46736 * constants: `Array` of tensors or `null`.
46737 * @throws ValueError, if `inputs` is an `Array` but either `initialState` or
46738 * `constants` is provided.
46739 */
46740 function standardizeArgs(inputs, initialState, constants, numConstants) {
46741 if (Array.isArray(inputs)) {
46742 if (initialState != null || constants != null) {
46743 throw new ValueError('When inputs is an array, neither initialState or constants ' +
46744 'should be provided');
46745 }
46746 if (numConstants != null) {
46747 constants = inputs.slice(inputs.length - numConstants, inputs.length);
46748 inputs = inputs.slice(0, inputs.length - numConstants);
46749 }
46750 if (inputs.length > 1) {
46751 initialState = inputs.slice(1, inputs.length);
46752 }
46753 inputs = inputs[0];
46754 }
46755 function toListOrNull(x) {
46756 if (x == null || Array.isArray(x)) {
46757 return x;
46758 }
46759 else {
46760 return [x];
46761 }
46762 }
46763 initialState = toListOrNull(initialState);
46764 constants = toListOrNull(constants);
46765 return { inputs, initialState, constants };
46766 }
46767 /**
46768 * Iterates over the time dimension of a tensor.
46769 *
46770 * @param stepFunction RNN step function.
46771 * Parameters:
46772 * inputs: tensor with shape `[samples, ...]` (no time dimension),
46773 * representing input for the batch of samples at a certain time step.
46774 * states: an Array of tensors.
46775 * Returns:
46776 * outputs: tensor with shape `[samples, outputDim]` (no time dimension).
46777 * newStates: list of tensors, same length and shapes as `states`. The first
46778 * state in the list must be the output tensor at the previous timestep.
46779 * @param inputs Tensor of temporal data of shape `[samples, time, ...]` (at
46780 * least 3D).
46781 * @param initialStates Tensor with shape `[samples, outputDim]` (no time
46782 * dimension), containing the initial values of the states used in the step
46783 * function.
46784 * @param goBackwards If `true`, do the iteration over the time dimension in
46785 * reverse order and return the reversed sequence.
46786 * @param mask Binary tensor with shape `[sample, time, 1]`, with a zero for
46787 * every element that is masked.
46788 * @param constants An Array of constant values passed at each step.
46789 * @param unroll Whether to unroll the RNN or to use a symbolic loop. *Not*
46790 * applicable to this imperative deeplearn.js backend. Its value is ignored.
46791 * @param needPerStepOutputs Whether the per-step outputs are to be
46792 * concatenated into a single tensor and returned (as the second return
46793 * value). Default: `false`. This arg is included so that the relatively
46794 * expensive concatenation of the stepwise outputs can be omitted unless
46795 * the stepwise outputs need to be kept (e.g., for an LSTM layer of which
46796 * `returnSequence` is `true`.)
46797 * @returns An Array: `[lastOutput, outputs, newStates]`.
     * lastOutput: the latest output of the RNN, of shape `[samples, ...]`.
46799 * outputs: tensor with shape `[samples, time, ...]` where each entry
46800 * `output[s, t]` is the output of the step function at time `t` for sample
46801 * `s`. This return value is provided if and only if the
46802 * `needPerStepOutputs` is set as `true`. If it is set as `false`, this
46803 * return value will be `undefined`.
46804 * newStates: Array of tensors, latest states returned by the step function,
46805 * of shape `(samples, ...)`.
46806 * @throws ValueError If input dimension is less than 3.
46807 *
46808 * TODO(nielsene): This needs to be tidy-ed.
46809 */
46810 function rnn(stepFunction, inputs, initialStates, goBackwards = false, mask, constants, unroll = false, needPerStepOutputs = false) {
46811 return tidy(() => {
46812 const ndim = inputs.shape.length;
46813 if (ndim < 3) {
46814 throw new ValueError(`Input should be at least 3D, but is ${ndim}D.`);
46815 }
46816 // Transpose to time-major, i.e., from [batch, time, ...] to [time, batch,
46817 // ...].
46818 const axes = [1, 0].concat(range$1(2, ndim));
46819 inputs = transpose(inputs, axes);
46820 if (constants != null) {
46821 throw new NotImplementedError('The rnn() functoin of the deeplearn.js backend does not support ' +
46822 'constants yet.');
46823 }
46824 // Porting Note: the unroll option is ignored by the imperative backend.
46825 if (unroll) {
46826 console.warn('Backend rnn(): the unroll = true option is not applicable to the ' +
46827 'imperative deeplearn.js backend.');
46828 }
46829 if (mask != null) {
46830 mask = cast(cast(mask, 'bool'), 'float32');
46831 if (mask.rank === ndim - 1) {
46832 mask = expandDims(mask, -1);
46833 }
46834 mask = transpose(mask, axes);
46835 }
46836 if (goBackwards) {
46837 inputs = reverse(inputs, 0);
46838 if (mask != null) {
46839 mask = reverse(mask, 0);
46840 }
46841 }
46842 // Porting Note: PyKeras with TensorFlow backend uses a symbolic loop
46843 // (tf.while_loop). But for the imperative deeplearn.js backend, we just
46844 // use the usual TypeScript control flow to iterate over the time steps in
46845 // the inputs.
46846 // Porting Note: PyKeras patches a "_use_learning_phase" attribute to
46847 // outputs.
46848 // This is not idiomatic in TypeScript. The info regarding whether we are
46849 // in a learning (i.e., training) phase for RNN is passed in a different
46850 // way.
46851 const perStepOutputs = [];
46852 let lastOutput;
46853 let states = initialStates;
46854 const timeSteps = inputs.shape[0];
46855 const perStepInputs = unstack(inputs);
46856 let perStepMasks;
46857 if (mask != null) {
46858 perStepMasks = unstack(mask);
46859 }
46860 for (let t = 0; t < timeSteps; ++t) {
46861 const currentInput = perStepInputs[t];
46862 const stepOutputs = tidy(() => stepFunction(currentInput, states));
46863 if (mask == null) {
46864 lastOutput = stepOutputs[0];
46865 states = stepOutputs[1];
46866 }
46867 else {
46868 const maskedOutputs = tidy(() => {
46869 const stepMask = perStepMasks[t];
46870 const negStepMask = sub(onesLike(stepMask), stepMask);
46871 // TODO(cais): Would tfc.where() be better for performance?
46872 const output = add$1(mul(stepOutputs[0], stepMask), mul(states[0], negStepMask));
46873 const newStates = states.map((state, i) => {
46874 return add$1(mul(stepOutputs[1][i], stepMask), mul(state, negStepMask));
46875 });
46876 return { output, newStates };
46877 });
46878 lastOutput = maskedOutputs.output;
46879 states = maskedOutputs.newStates;
46880 }
46881 if (needPerStepOutputs) {
46882 perStepOutputs.push(lastOutput);
46883 }
46884 }
46885 let outputs;
46886 if (needPerStepOutputs) {
46887 const axis = 1;
46888 outputs = stack(perStepOutputs, axis);
46889 }
46890 return [lastOutput, outputs, states];
46891 });
46892 }
/**
 * Base class for recurrent layers.
 *
 * Wraps a single `RNNCell` (or an array of cells, which is combined into a
 * `StackedRNNCells`) and iterates it over the time dimension of a 3D input
 * `[batch, time, features]` via the backend `rnn()` function.
 */
class RNN extends Layer {
    /**
     * @param args Layer config. Must contain `cell` (an RNNCell or an array of
     *   RNNCells). Optional flags: `returnSequences`, `returnState`,
     *   `goBackwards`, `stateful`, `unroll` (all default to `false`).
     * @throws ValueError if `cell` is missing or lacks a `stateSize` attribute.
     */
    constructor(args) {
        super(args);
        let cell;
        if (args.cell == null) {
            throw new ValueError('cell property is missing for the constructor of RNN.');
        }
        else if (Array.isArray(args.cell)) {
            // An array of cells is treated as one stacked (deep) RNN cell.
            cell = new StackedRNNCells({ cells: args.cell });
        }
        else {
            cell = args.cell;
        }
        if (cell.stateSize == null) {
            throw new ValueError('The RNN cell should have an attribute `stateSize` (tuple of ' +
                'integers, one integer per RNN state).');
        }
        this.cell = cell;
        this.returnSequences =
            args.returnSequences == null ? false : args.returnSequences;
        this.returnState = args.returnState == null ? false : args.returnState;
        this.goBackwards = args.goBackwards == null ? false : args.goBackwards;
        this._stateful = args.stateful == null ? false : args.stateful;
        this.unroll = args.unroll == null ? false : args.unroll;
        this.supportsMasking = true;
        this.inputSpec = [new InputSpec({ ndim: 3 })];
        this.stateSpec = null;
        // Holds the concrete state tensors of a stateful RNN; `null` until
        // resetStates() or an initial-state-bearing call populates it.
        this.states_ = null;
        // TODO(cais): Add constantsSpec and numConstants.
        this.numConstants = null;
        // TODO(cais): Look into the use of initial_state in the kwargs of the
        // constructor.
        this.keptStates = [];
    }
    // Porting Note: This is the equivalent of `RNN.states` property getter in
    // PyKeras.
    getStates() {
        if (this.states_ == null) {
            // No states set yet: return an array of `null`s whose length
            // matches the number of states the cell declares.
            const numStates = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1;
            return range$1(0, numStates).map(x => null);
        }
        else {
            return this.states_;
        }
    }
    // Porting Note: This is the equivalent of the `RNN.states` property setter in
    // PyKeras.
    setStates(states) {
        this.states_ = states;
    }
    /**
     * Compute the output shape(s).
     *
     * Returns `[batch, time, outputDim]` when `returnSequences` is true,
     * otherwise `[batch, outputDim]`; when `returnState` is true, per-state
     * shapes `[batch, dim]` are appended.
     */
    computeOutputShape(inputShape) {
        if (isArrayOfShapes(inputShape)) {
            inputShape = inputShape[0];
        }
        inputShape = inputShape;
        // TODO(cais): Remove the casting once stacked RNN cells become supported.
        let stateSize = this.cell.stateSize;
        if (!Array.isArray(stateSize)) {
            stateSize = [stateSize];
        }
        // The first state's size is taken as the layer's output dimension.
        const outputDim = stateSize[0];
        let outputShape;
        if (this.returnSequences) {
            outputShape = [inputShape[0], inputShape[1], outputDim];
        }
        else {
            outputShape = [inputShape[0], outputDim];
        }
        if (this.returnState) {
            const stateShape = [];
            for (const dim of stateSize) {
                stateShape.push([inputShape[0], dim]);
            }
            return [outputShape].concat(stateShape);
        }
        else {
            return outputShape;
        }
    }
    /**
     * Propagate the input mask: kept only when sequences are returned; state
     * outputs (if any) carry `null` masks.
     */
    computeMask(inputs, mask) {
        return tidy(() => {
            if (Array.isArray(mask)) {
                mask = mask[0];
            }
            const outputMask = this.returnSequences ? mask : null;
            if (this.returnState) {
                const stateMask = this.states.map(s => null);
                return [outputMask].concat(stateMask);
            }
            else {
                return outputMask;
            }
        });
    }
    /**
     * Get the current state tensors of the RNN.
     *
     * If the state hasn't been set, return an array of `null`s of the correct
     * length.
     */
    get states() {
        if (this.states_ == null) {
            const numStates = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1;
            const output = [];
            for (let i = 0; i < numStates; ++i) {
                output.push(null);
            }
            return output;
        }
        else {
            return this.states_;
        }
    }
    set states(s) {
        this.states_ = s;
    }
    /**
     * Build the layer: builds the wrapped cell on the per-step input shape and
     * sets or validates `stateSpec` against the cell's `stateSize`.
     */
    build(inputShape) {
        // Note inputShape will be an Array of Shapes of initial states and
        // constants if these are passed in apply().
        const constantShape = null;
        if (this.numConstants != null) {
            throw new NotImplementedError('Constants support is not implemented in RNN yet.');
        }
        if (isArrayOfShapes(inputShape)) {
            inputShape = inputShape[0];
        }
        inputShape = inputShape;
        // Stateful RNNs need a fixed batch size for their state tensors.
        const batchSize = this.stateful ? inputShape[0] : null;
        const inputDim = inputShape.slice(2);
        this.inputSpec[0] = new InputSpec({ shape: [batchSize, null, ...inputDim] });
        // Allow cell (if RNNCell Layer) to build before we set or validate
        // stateSpec.
        const stepInputShape = [inputShape[0]].concat(inputShape.slice(2));
        if (constantShape != null) {
            throw new NotImplementedError('Constants support is not implemented in RNN yet.');
        }
        else {
            this.cell.build(stepInputShape);
        }
        // Set or validate stateSpec.
        let stateSize;
        if (Array.isArray(this.cell.stateSize)) {
            stateSize = this.cell.stateSize;
        }
        else {
            stateSize = [this.cell.stateSize];
        }
        if (this.stateSpec != null) {
            // A stateSpec was installed earlier (e.g. from an initialState in
            // apply()); its last dimensions must match the cell's stateSize.
            if (!arraysEqual(this.stateSpec.map(spec => spec.shape[spec.shape.length - 1]), stateSize)) {
                throw new ValueError(`An initialState was passed that is not compatible with ` +
                    `cell.stateSize. Received stateSpec=${this.stateSpec}; ` +
                    `However cell.stateSize is ${this.cell.stateSize}`);
            }
        }
        else {
            this.stateSpec =
                stateSize.map(dim => new InputSpec({ shape: [null, dim] }));
        }
        if (this.stateful) {
            this.resetStates();
        }
    }
    /**
     * Reset the state tensors of the RNN.
     *
     * If the `states` argument is `undefined` or `null`, will set the
     * state tensor(s) of the RNN to all-zero tensors of the appropriate
     * shape(s).
     *
     * If `states` is provided, will set the state tensors of the RNN to its
     * value.
     *
     * @param states Optional externally-provided initial states.
     * @param training Whether this call is done during training. For stateful
     *   RNNs, this affects whether the old states are kept or discarded. In
     *   particular, if `training` is `true`, the old states will be kept so
     *   that subsequent backpropagation through time (BPTT) may work properly.
     *   Else, the old states will be discarded.
     */
    resetStates(states, training = false) {
        tidy(() => {
            if (!this.stateful) {
                throw new AttributeError('Cannot call resetStates() on an RNN Layer that is not stateful.');
            }
            const batchSize = this.inputSpec[0].shape[0];
            if (batchSize == null) {
                throw new ValueError('If an RNN is stateful, it needs to know its batch size. Specify ' +
                    'the batch size of your input tensors: \n' +
                    '- If using a Sequential model, specify the batch size by ' +
                    'passing a `batchInputShape` option to your first layer.\n' +
                    '- If using the functional API, specify the batch size by ' +
                    'passing a `batchShape` option to your Input layer.');
            }
            // Initialize state if null.
            if (this.states_ == null) {
                if (Array.isArray(this.cell.stateSize)) {
                    this.states_ =
                        this.cell.stateSize.map(dim => zeros([batchSize, dim]));
                }
                else {
                    this.states_ = [zeros([batchSize, this.cell.stateSize])];
                }
            }
            else if (states == null) {
                // Dispose old state tensors.
                dispose(this.states_);
                // For stateful RNNs, fully dispose kept old states.
                if (this.keptStates != null) {
                    dispose(this.keptStates);
                    this.keptStates = [];
                }
                if (Array.isArray(this.cell.stateSize)) {
                    this.states_ =
                        this.cell.stateSize.map(dim => zeros([batchSize, dim]));
                }
                else {
                    this.states_[0] = zeros([batchSize, this.cell.stateSize]);
                }
            }
            else {
                if (!Array.isArray(states)) {
                    states = [states];
                }
                if (states.length !== this.states_.length) {
                    throw new ValueError(`Layer ${this.name} expects ${this.states_.length} state(s), ` +
                        `but it received ${states.length} state value(s). Input ` +
                        `received: ${states}`);
                }
                if (training === true) {
                    // Store old state tensors for complete disposal later, i.e., during
                    // the next no-arg call to this method. We do not dispose the old
                    // states immediately because that BPTT (among other things) require
                    // them.
                    this.keptStates.push(this.states_.slice());
                }
                else {
                    dispose(this.states_);
                }
                for (let index = 0; index < this.states_.length; ++index) {
                    const value = states[index];
                    const dim = Array.isArray(this.cell.stateSize) ?
                        this.cell.stateSize[index] :
                        this.cell.stateSize;
                    const expectedShape = [batchSize, dim];
                    if (!arraysEqual(value.shape, expectedShape)) {
                        throw new ValueError(`State ${index} is incompatible with layer ${this.name}: ` +
                            `expected shape=${expectedShape}, received shape=${value.shape}`);
                    }
                    this.states_[index] = value;
                }
            }
            // Clone and `keep` the states so they survive the enclosing tidy().
            this.states_ = this.states_.map(state => keep(state.clone()));
        });
    }
    /**
     * Apply the layer, optionally with `initialState` and/or `constants` in
     * `kwargs`. Symbolic initial states / constants are temporarily appended
     * to the inputs (with a matching, temporarily widened inputSpec).
     */
    apply(inputs, kwargs) {
        // TODO(cais): Figure out whether initialState is in kwargs or inputs.
        let initialState = kwargs == null ? null : kwargs['initialState'];
        let constants = kwargs == null ? null : kwargs['constants'];
        if (kwargs == null) {
            kwargs = {};
        }
        const standardized = standardizeArgs(inputs, initialState, constants, this.numConstants);
        inputs = standardized.inputs;
        initialState = standardized.initialState;
        constants = standardized.constants;
        // If any of `initial_state` or `constants` are specified and are
        // `tf.SymbolicTensor`s, then add them to the inputs and temporarily modify
        // the input_spec to include them.
        let additionalInputs = [];
        let additionalSpecs = [];
        if (initialState != null) {
            kwargs['initialState'] = initialState;
            additionalInputs = additionalInputs.concat(initialState);
            this.stateSpec = [];
            for (const state of initialState) {
                this.stateSpec.push(new InputSpec({ shape: state.shape }));
            }
            // TODO(cais): Use the following instead.
            // this.stateSpec = initialState.map(state => new InputSpec({shape:
            // state.shape}));
            additionalSpecs = additionalSpecs.concat(this.stateSpec);
        }
        if (constants != null) {
            kwargs['constants'] = constants;
            additionalInputs = additionalInputs.concat(constants);
            // TODO(cais): Add this.constantsSpec.
            this.numConstants = constants.length;
        }
        const isTensor = additionalInputs[0] instanceof SymbolicTensor;
        if (isTensor) {
            // Compute full input spec, including state and constants.
            const fullInput = [inputs].concat(additionalInputs);
            const fullInputSpec = this.inputSpec.concat(additionalSpecs);
            // Perform the call with temporarily replaced inputSpec.
            const originalInputSpec = this.inputSpec;
            this.inputSpec = fullInputSpec;
            const output = super.apply(fullInput, kwargs);
            this.inputSpec = originalInputSpec;
            return output;
        }
        else {
            return super.apply(inputs, kwargs);
        }
    }
    // tslint:disable-next-line:no-any
    call(inputs, kwargs) {
        // Input shape: `[samples, time (padded with zeros), input_dim]`.
        // Note that the .build() method of subclasses **must** define
        // this.inputSpec and this.stateSpec with complete input shapes.
        return tidy(() => {
            const mask = kwargs == null ? null : kwargs['mask'];
            const training = kwargs == null ? null : kwargs['training'];
            let initialState = kwargs == null ? null : kwargs['initialState'];
            inputs = getExactlyOneTensor(inputs);
            if (initialState == null) {
                // Fall back to stored states (stateful) or zero states.
                if (this.stateful) {
                    initialState = this.states_;
                }
                else {
                    initialState = this.getInitialState(inputs);
                }
            }
            const numStates = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1;
            if (initialState.length !== numStates) {
                throw new ValueError(`RNN Layer has ${numStates} state(s) but was passed ` +
                    `${initialState.length} initial state(s).`);
            }
            if (this.unroll) {
                console.warn('Ignoring unroll = true for RNN layer, due to imperative backend.');
            }
            const cellCallKwargs = { training };
            // TODO(cais): Add support for constants.
            const step = (inputs, states) => {
                // `inputs` and `states` are concatenated to form a single `Array` of
                // `tf.Tensor`s as the input to `cell.call()`.
                const outputs = this.cell.call([inputs].concat(states), cellCallKwargs);
                // Marshall the return value into output and new states.
                return [outputs[0], outputs.slice(1)];
            };
            // TODO(cais): Add support for constants.
            const rnnOutputs = rnn(step, inputs, initialState, this.goBackwards, mask, null, this.unroll, this.returnSequences);
            const lastOutput = rnnOutputs[0];
            const outputs = rnnOutputs[1];
            const states = rnnOutputs[2];
            if (this.stateful) {
                this.resetStates(states, training);
            }
            const output = this.returnSequences ? outputs : lastOutput;
            // TODO(cais): Properly set learning phase flag.
            if (this.returnState) {
                return [output].concat(states);
            }
            else {
                return output;
            }
        });
    }
    /**
     * Build zero-valued initial state(s) of shape [samples, stateDim] for each
     * of the cell's states, derived from the input tensor's batch dimension.
     */
    getInitialState(inputs) {
        return tidy(() => {
            // Build an all-zero tensor of shape [samples, outputDim].
            // [Samples, timeSteps, inputDim].
            let initialState = zeros(inputs.shape);
            // [Samples].
            initialState = sum$1(initialState, [1, 2]);
            initialState = expandDims$1(initialState); // [Samples, 1].
            if (Array.isArray(this.cell.stateSize)) {
                return this.cell.stateSize.map(dim => dim > 1 ? tile$1(initialState, [1, dim]) : initialState);
            }
            else {
                return this.cell.stateSize > 1 ?
                    [tile$1(initialState, [1, this.cell.stateSize])] :
                    [initialState];
            }
        });
    }
    get trainableWeights() {
        if (!this.trainable) {
            return [];
        }
        // Porting Note: In TypeScript, `this` is always an instance of `Layer`.
        return this.cell.trainableWeights;
    }
    get nonTrainableWeights() {
        // Porting Note: In TypeScript, `this` is always an instance of `Layer`.
        if (!this.trainable) {
            return this.cell.weights;
        }
        return this.cell.nonTrainableWeights;
    }
    setFastWeightInitDuringBuild(value) {
        super.setFastWeightInitDuringBuild(value);
        // Forward the flag to the wrapped cell as well.
        if (this.cell != null) {
            this.cell.setFastWeightInitDuringBuild(value);
        }
    }
    getConfig() {
        const baseConfig = super.getConfig();
        const config = {
            returnSequences: this.returnSequences,
            returnState: this.returnState,
            goBackwards: this.goBackwards,
            stateful: this.stateful,
            unroll: this.unroll,
        };
        if (this.numConstants != null) {
            config['numConstants'] = this.numConstants;
        }
        const cellConfig = this.cell.getConfig();
        // Only the generic RNN class serializes its cell; subclasses (e.g.
        // SimpleRNN, GRU) reconstruct the cell from their own config.
        if (this.getClassName() === RNN.className) {
            config['cell'] = {
                'className': this.cell.getClassName(),
                'config': cellConfig,
            };
        }
        // this order is necessary, to prevent cell name from replacing layer name
        return Object.assign({}, cellConfig, baseConfig, config);
    }
    /** @nocollapse */
    static fromConfig(cls, config, customObjects = {}) {
        const cellConfig = config['cell'];
        const cell = deserialize(cellConfig, customObjects);
        return new cls(Object.assign(config, { cell }));
    }
}
/** @nocollapse */
RNN.className = 'RNN';
registerClass(RNN);
47320 // Porting Note: This is a common parent class for RNN cells. There is no
47321 // equivalent of this in PyKeras. Having a common parent class forgoes the
47322 // need for `has_attr(cell, ...)` checks or its TypeScript equivalent.
47323 /**
47324 * An RNNCell layer.
47325 *
47326 * @doc {heading: 'Layers', subheading: 'Classes'}
47327 */
class RNNCell extends Layer {
    // Intentionally empty: serves only as a common parent type so that
    // `RNN` can rely on cells exposing `stateSize` and `call()` without
    // per-cell attribute checks (see the RNN constructor's stateSize check).
}
/**
 * Cell for a simple (fully-connected) RNN: one step computes
 * `output = activation(inputs · kernel + bias + prevOutput · recurrentKernel)`.
 * The cell has a single state equal to its output (`stateSize === units`).
 */
class SimpleRNNCell extends RNNCell {
    /**
     * @param args Cell config: `units` (positive integer) plus optional
     *   activation, useBias, initializers, regularizers, constraints,
     *   dropout/recurrentDropout (clamped to [0, 1]), and dropoutFunc.
     */
    constructor(args) {
        super(args);
        this.DEFAULT_ACTIVATION = 'tanh';
        this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
        this.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal';
        this.DEFAULT_BIAS_INITIALIZER = 'zeros';
        this.units = args.units;
        assertPositiveInteger(this.units, `units`);
        this.activation = getActivation(args.activation == null ? this.DEFAULT_ACTIVATION : args.activation);
        this.useBias = args.useBias == null ? true : args.useBias;
        this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER);
        this.recurrentInitializer = getInitializer(args.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER);
        this.biasInitializer =
            getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER);
        this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
        this.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
        this.biasRegularizer = getRegularizer(args.biasRegularizer);
        this.kernelConstraint = getConstraint(args.kernelConstraint);
        this.recurrentConstraint = getConstraint(args.recurrentConstraint);
        this.biasConstraint = getConstraint(args.biasConstraint);
        // Clamp both dropout rates into [0, 1].
        this.dropout = min$1([1, max$1([0, args.dropout == null ? 0 : args.dropout])]);
        this.recurrentDropout = min$1([
            1,
            max$1([0, args.recurrentDropout == null ? 0 : args.recurrentDropout])
        ]);
        this.dropoutFunc = args.dropoutFunc;
        this.stateSize = this.units;
        // Cached dropout masks; lazily created in call() and cleared by the
        // owning SimpleRNN layer between invocations.
        this.dropoutMask = null;
        this.recurrentDropoutMask = null;
    }
    /**
     * Create the cell's weights: `kernel` [inputDim, units],
     * `recurrent_kernel` [units, units], and optionally `bias` [units].
     */
    build(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        // TODO(cais): Use regularizer.
        this.kernel = this.addWeight('kernel', [inputShape[inputShape.length - 1], this.units], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
        this.recurrentKernel = this.addWeight('recurrent_kernel', [this.units, this.units], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
        if (this.useBias) {
            this.bias = this.addWeight('bias', [this.units], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
        }
        else {
            this.bias = null;
        }
        this.built = true;
    }
    // Porting Note: PyKeras' equivalent of this method takes two tensor inputs:
    // `inputs` and `states`. Here, the two tensors are combined into an
    // `Tensor[]` Array as the first input argument.
    // Similarly, PyKeras' equivalent of this method returns two values:
    // `output` and `[output]`. Here the two are combined into one length-2
    // `Tensor[]`, consisting of `output` repeated.
    call(inputs, kwargs) {
        return tidy(() => {
            inputs = inputs;
            if (inputs.length !== 2) {
                throw new ValueError(`SimpleRNNCell expects 2 input Tensors, got ${inputs.length}.`);
            }
            let prevOutput = inputs[1];
            inputs = inputs[0];
            const training = kwargs['training'] == null ? false : kwargs['training'];
            // Lazily generate (and cache) the input dropout mask.
            if (0 < this.dropout && this.dropout < 1 && this.dropoutMask == null) {
                this.dropoutMask = generateDropoutMask({
                    ones: () => onesLike(inputs),
                    rate: this.dropout,
                    training,
                    dropoutFunc: this.dropoutFunc,
                });
            }
            // Lazily generate (and cache) the recurrent dropout mask.
            if (0 < this.recurrentDropout && this.recurrentDropout < 1 &&
                this.recurrentDropoutMask == null) {
                this.recurrentDropoutMask = generateDropoutMask({
                    ones: () => onesLike(prevOutput),
                    rate: this.recurrentDropout,
                    training,
                    dropoutFunc: this.dropoutFunc,
                });
            }
            let h;
            const dpMask = this.dropoutMask;
            const recDpMask = this.recurrentDropoutMask;
            // h = (inputs [* dpMask]) · kernel
            if (dpMask != null) {
                h = dot$1(mul(inputs, dpMask), this.kernel.read());
            }
            else {
                h = dot$1(inputs, this.kernel.read());
            }
            if (this.bias != null) {
                h = biasAdd(h, this.bias.read());
            }
            if (recDpMask != null) {
                prevOutput = mul(prevOutput, recDpMask);
            }
            let output = add$1(h, dot$1(prevOutput, this.recurrentKernel.read()));
            if (this.activation != null) {
                output = this.activation.apply(output);
            }
            // TODO(cais): Properly set learning phase on output tensor?
            // Returned as [output, newState]; for this cell the state IS the
            // output, hence the repetition.
            return [output, output];
        });
    }
    /** Serialize the cell's configuration for save/load round-trips. */
    getConfig() {
        const baseConfig = super.getConfig();
        const config = {
            units: this.units,
            activation: serializeActivation(this.activation),
            useBias: this.useBias,
            kernelInitializer: serializeInitializer(this.kernelInitializer),
            recurrentInitializer: serializeInitializer(this.recurrentInitializer),
            biasInitializer: serializeInitializer(this.biasInitializer),
            kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
            recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer),
            biasRegularizer: serializeRegularizer(this.biasRegularizer),
            activityRegularizer: serializeRegularizer(this.activityRegularizer),
            kernelConstraint: serializeConstraint(this.kernelConstraint),
            recurrentConstraint: serializeConstraint(this.recurrentConstraint),
            biasConstraint: serializeConstraint(this.biasConstraint),
            dropout: this.dropout,
            recurrentDropout: this.recurrentDropout,
        };
        return Object.assign({}, baseConfig, config);
    }
}
/** @nocollapse */
SimpleRNNCell.className = 'SimpleRNNCell';
registerClass(SimpleRNNCell);
/**
 * Fully-connected RNN layer: a thin wrapper that installs a
 * `SimpleRNNCell` built from the same args and delegates iteration over
 * time to the base `RNN` class.
 */
class SimpleRNN extends RNN {
    constructor(args) {
        args.cell = new SimpleRNNCell(args);
        super(args);
        // TODO(cais): Add activityRegularizer.
    }
    call(inputs, kwargs) {
        return tidy(() => {
            // Drop any dropout masks cached on the cell by a previous call, so
            // that fresh masks are generated for this invocation.
            for (const maskProperty of ['dropoutMask', 'recurrentDropoutMask']) {
                const cachedMask = this.cell[maskProperty];
                if (cachedMask != null) {
                    dispose(cachedMask);
                    this.cell[maskProperty] = null;
                }
            }
            let mask = null;
            let training = null;
            let initialState = null;
            if (kwargs != null) {
                mask = kwargs['mask'];
                training = kwargs['training'];
                initialState = kwargs['initialState'];
            }
            return super.call(inputs, { mask, training, initialState });
        });
    }
    /** @nocollapse */
    static fromConfig(cls, config) {
        return new cls(config);
    }
}
/** @nocollapse */
SimpleRNN.className = 'SimpleRNN';
registerClass(SimpleRNN);
/**
 * Cell for the Gated Recurrent Unit (GRU) layer.
 *
 * Single state (`stateSize === units`). One step computes update gate `z`,
 * reset gate `r`, and candidate `hh`, then blends:
 * `h = z * hTMinus1 + (1 - z) * hh`. `resetAfter === true` is rejected.
 */
class GRUCell extends RNNCell {
    /**
     * @param args Cell config: `units` (positive integer) plus optional
     *   activations, initializers, regularizers, constraints,
     *   dropout/recurrentDropout (clamped to [0, 1]), `implementation`
     *   (serialized but ignored at runtime), and `resetAfter` (must be falsy).
     * @throws ValueError if `resetAfter` is set to true.
     */
    constructor(args) {
        super(args);
        this.DEFAULT_ACTIVATION = 'tanh';
        this.DEFAULT_RECURRENT_ACTIVATION = 'hardSigmoid';
        this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
        this.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal';
        this.DEFAULT_BIAS_INITIALIZER = 'zeros';
        if (args.resetAfter) {
            throw new ValueError(`GRUCell does not support reset_after parameter set to true.`);
        }
        this.units = args.units;
        assertPositiveInteger(this.units, 'units');
        this.activation = getActivation(args.activation === undefined ? this.DEFAULT_ACTIVATION :
            args.activation);
        this.recurrentActivation = getActivation(args.recurrentActivation === undefined ?
            this.DEFAULT_RECURRENT_ACTIVATION :
            args.recurrentActivation);
        this.useBias = args.useBias == null ? true : args.useBias;
        this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER);
        this.recurrentInitializer = getInitializer(args.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER);
        this.biasInitializer =
            getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER);
        this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
        this.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
        this.biasRegularizer = getRegularizer(args.biasRegularizer);
        this.kernelConstraint = getConstraint(args.kernelConstraint);
        this.recurrentConstraint = getConstraint(args.recurrentConstraint);
        this.biasConstraint = getConstraint(args.biasConstraint);
        // Clamp both dropout rates into [0, 1].
        this.dropout = min$1([1, max$1([0, args.dropout == null ? 0 : args.dropout])]);
        this.recurrentDropout = min$1([
            1,
            max$1([0, args.recurrentDropout == null ? 0 : args.recurrentDropout])
        ]);
        this.dropoutFunc = args.dropoutFunc;
        this.implementation = args.implementation;
        this.stateSize = this.units;
        // Cached dropout masks; lazily created in call() and cleared by the
        // owning GRU layer between invocations.
        this.dropoutMask = null;
        this.recurrentDropoutMask = null;
    }
    /**
     * Create the cell's weights, each holding all three gates (z, r, h)
     * concatenated along the last axis: `kernel` [inputDim, 3*units],
     * `recurrent_kernel` [units, 3*units], and optionally `bias` [3*units].
     */
    build(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        const inputDim = inputShape[inputShape.length - 1];
        this.kernel = this.addWeight('kernel', [inputDim, this.units * 3], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
        this.recurrentKernel = this.addWeight('recurrent_kernel', [this.units, this.units * 3], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
        if (this.useBias) {
            this.bias = this.addWeight('bias', [this.units * 3], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
        }
        else {
            this.bias = null;
        }
        // Porting Notes: Unlike the PyKeras implementation, we perform slicing
        // of the weights and bias in the call() method, at execution time.
        this.built = true;
    }
    call(inputs, kwargs) {
        return tidy(() => {
            inputs = inputs;
            if (inputs.length !== 2) {
                throw new ValueError(`GRUCell expects 2 input Tensors (inputs, h, c), got ` +
                    `${inputs.length}.`);
            }
            const training = kwargs['training'] == null ? false : kwargs['training'];
            let hTMinus1 = inputs[1]; // Previous memory state.
            inputs = inputs[0];
            // Note: For superior performance, TensorFlow.js always uses
            // implementation 2, regardless of the actual value of
            // config.implementation.
            // Lazily generate (and cache) one dropout mask per gate (count: 3).
            if (0 < this.dropout && this.dropout < 1 && this.dropoutMask == null) {
                this.dropoutMask = generateDropoutMask({
                    ones: () => onesLike(inputs),
                    rate: this.dropout,
                    training,
                    count: 3,
                    dropoutFunc: this.dropoutFunc,
                });
            }
            if (0 < this.recurrentDropout && this.recurrentDropout < 1 &&
                this.recurrentDropoutMask == null) {
                this.recurrentDropoutMask = generateDropoutMask({
                    ones: () => onesLike(hTMinus1),
                    rate: this.recurrentDropout,
                    training,
                    count: 3,
                    dropoutFunc: this.dropoutFunc,
                });
            }
            const dpMask = this.dropoutMask;
            const recDpMask = this.recurrentDropoutMask;
            let z;
            let r;
            let hh;
            if (0 < this.dropout && this.dropout < 1) {
                inputs = mul(inputs, dpMask[0]);
            }
            // matrixX holds the input projections of all three gates: [xZ|xR|xH].
            let matrixX = dot$1(inputs, this.kernel.read());
            if (this.useBias) {
                matrixX = biasAdd(matrixX, this.bias.read());
            }
            if (0 < this.recurrentDropout && this.recurrentDropout < 1) {
                hTMinus1 = mul(hTMinus1, recDpMask[0]);
            }
            const recurrentKernelValue = this.recurrentKernel.read();
            // rk1 ([units, 2*units]) drives the z and r gates; rk2
            // ([units, units]) drives the candidate state hh.
            const [rk1, rk2] = split(recurrentKernelValue, [2 * this.units, this.units], recurrentKernelValue.rank - 1);
            const matrixInner = dot$1(hTMinus1, rk1);
            const [xZ, xR, xH] = split(matrixX, 3, matrixX.rank - 1);
            const [recurrentZ, recurrentR] = split(matrixInner, 2, matrixInner.rank - 1);
            z = this.recurrentActivation.apply(add$1(xZ, recurrentZ));
            r = this.recurrentActivation.apply(add$1(xR, recurrentR));
            // Candidate state uses the reset-gated previous state.
            const recurrentH = dot$1(mul(r, hTMinus1), rk2);
            hh = this.activation.apply(add$1(xH, recurrentH));
            // h = z * hTMinus1 + (1 - z) * hh
            const h = add$1(mul(z, hTMinus1), mul(add$1(1, neg(z)), hh));
            // TODO(cais): Add use_learning_phase flag properly.
            return [h, h];
        });
    }
    /** Serialize the cell's configuration for save/load round-trips. */
    getConfig() {
        const baseConfig = super.getConfig();
        const config = {
            units: this.units,
            activation: serializeActivation(this.activation),
            recurrentActivation: serializeActivation(this.recurrentActivation),
            useBias: this.useBias,
            kernelInitializer: serializeInitializer(this.kernelInitializer),
            recurrentInitializer: serializeInitializer(this.recurrentInitializer),
            biasInitializer: serializeInitializer(this.biasInitializer),
            kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
            recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer),
            biasRegularizer: serializeRegularizer(this.biasRegularizer),
            activityRegularizer: serializeRegularizer(this.activityRegularizer),
            kernelConstraint: serializeConstraint(this.kernelConstraint),
            recurrentConstraint: serializeConstraint(this.recurrentConstraint),
            biasConstraint: serializeConstraint(this.biasConstraint),
            dropout: this.dropout,
            recurrentDropout: this.recurrentDropout,
            implementation: this.implementation,
            resetAfter: false
        };
        return Object.assign({}, baseConfig, config);
    }
}
/** @nocollapse */
GRUCell.className = 'GRUCell';
registerClass(GRUCell);
/**
 * Gated Recurrent Unit layer: a thin wrapper that installs a `GRUCell`
 * built from the same args and delegates iteration over time to the base
 * `RNN` class.
 */
class GRU extends RNN {
    constructor(args) {
        // `implementation: 0` is deprecated; runtime treats it like 1.
        if (args.implementation === 0) {
            console.warn('`implementation=0` has been deprecated, and now defaults to ' +
                '`implementation=1`. Please update your layer call.');
        }
        args.cell = new GRUCell(args);
        super(args);
        // TODO(cais): Add activityRegularizer.
    }
    call(inputs, kwargs) {
        return tidy(() => {
            // Dispose dropout masks cached on the cell by a previous call, so
            // that fresh masks are generated for this invocation.
            if (this.cell.dropoutMask != null) {
                dispose(this.cell.dropoutMask);
                this.cell.dropoutMask = null;
            }
            if (this.cell.recurrentDropoutMask != null) {
                dispose(this.cell.recurrentDropoutMask);
                this.cell.recurrentDropoutMask = null;
            }
            const mask = kwargs == null ? null : kwargs['mask'];
            const training = kwargs == null ? null : kwargs['training'];
            const initialState = kwargs == null ? null : kwargs['initialState'];
            return super.call(inputs, { mask, training, initialState });
        });
    }
    /** @nocollapse */
    static fromConfig(cls, config) {
        // Remap the deprecated `implementation: 0` to 1 on deserialization.
        // Bug fix: this previously checked the misspelled key 'implmentation',
        // which never exists in serialized configs (getConfig() writes
        // 'implementation'), so the remapping never took effect.
        if (config['implementation'] === 0) {
            config['implementation'] = 1;
        }
        return new cls(config);
    }
}
/** @nocollapse */
GRU.className = 'GRU';
registerClass(GRU);
47665 class LSTMCell extends RNNCell {
    /**
     * Construct an LSTM cell.
     *
     * Two states of size `units` each (`stateSize === [units, units]` —
     * presumably [h, c]; confirm against the cell's call() implementation).
     *
     * @param args Cell config: `units` (positive integer) plus optional
     *   activations, initializers (incl. `unitForgetBias`), regularizers,
     *   constraints, dropout/recurrentDropout (clamped to [0, 1]),
     *   `implementation`, and `dropoutFunc`.
     */
    constructor(args) {
        super(args);
        this.DEFAULT_ACTIVATION = 'tanh';
        this.DEFAULT_RECURRENT_ACTIVATION = 'hardSigmoid';
        this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
        this.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal';
        this.DEFAULT_BIAS_INITIALIZER = 'zeros';
        this.units = args.units;
        assertPositiveInteger(this.units, 'units');
        this.activation = getActivation(args.activation === undefined ? this.DEFAULT_ACTIVATION :
            args.activation);
        this.recurrentActivation = getActivation(args.recurrentActivation === undefined ?
            this.DEFAULT_RECURRENT_ACTIVATION :
            args.recurrentActivation);
        this.useBias = args.useBias == null ? true : args.useBias;
        this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER);
        this.recurrentInitializer = getInitializer(args.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER);
        this.biasInitializer =
            getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER);
        this.unitForgetBias = args.unitForgetBias;
        this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
        this.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
        this.biasRegularizer = getRegularizer(args.biasRegularizer);
        this.kernelConstraint = getConstraint(args.kernelConstraint);
        this.recurrentConstraint = getConstraint(args.recurrentConstraint);
        this.biasConstraint = getConstraint(args.biasConstraint);
        // Clamp both dropout rates into [0, 1].
        this.dropout = min$1([1, max$1([0, args.dropout == null ? 0 : args.dropout])]);
        this.recurrentDropout = min$1([
            1,
            max$1([0, args.recurrentDropout == null ? 0 : args.recurrentDropout])
        ]);
        this.dropoutFunc = args.dropoutFunc;
        this.implementation = args.implementation;
        // Two states, each of size `units`.
        this.stateSize = [this.units, this.units];
        // Cached dropout masks; lazily created in call().
        this.dropoutMask = null;
        this.recurrentDropoutMask = null;
    }
47703 build(inputShape) {
47704 var _a;
47705 inputShape = getExactlyOneShape(inputShape);
47706 const inputDim = inputShape[inputShape.length - 1];
47707 this.kernel = this.addWeight('kernel', [inputDim, this.units * 4], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
47708 this.recurrentKernel = this.addWeight('recurrent_kernel', [this.units, this.units * 4], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
47709 let biasInitializer;
47710 if (this.useBias) {
47711 if (this.unitForgetBias) {
47712 const capturedBiasInit = this.biasInitializer;
47713 const capturedUnits = this.units;
47714 biasInitializer = new (_a = class CustomInit extends Initializer {
47715 apply(shape, dtype) {
47716 // TODO(cais): More informative variable names?
47717 const bI = capturedBiasInit.apply([capturedUnits]);
47718 const bF = (new Ones()).apply([capturedUnits]);
47719 const bCAndH = capturedBiasInit.apply([capturedUnits * 2]);
47720 return concatAlongFirstAxis(concatAlongFirstAxis(bI, bF), bCAndH);
47721 }
47722 },
47723 /** @nocollapse */
47724 _a.className = 'CustomInit',
47725 _a)();
47726 }
47727 else {
47728 biasInitializer = this.biasInitializer;
47729 }
47730 this.bias = this.addWeight('bias', [this.units * 4], null, biasInitializer, this.biasRegularizer, true, this.biasConstraint);
47731 }
47732 else {
47733 this.bias = null;
47734 }
47735 // Porting Notes: Unlike the PyKeras implementation, we perform slicing
47736 // of the weights and bias in the call() method, at execution time.
47737 this.built = true;
47738 }
47739 call(inputs, kwargs) {
47740 return tidy(() => {
47741 const training = kwargs['training'] == null ? false : kwargs['training'];
47742 inputs = inputs;
47743 if (inputs.length !== 3) {
47744 throw new ValueError(`LSTMCell expects 3 input Tensors (inputs, h, c), got ` +
47745 `${inputs.length}.`);
47746 }
47747 let hTMinus1 = inputs[1]; // Previous memory state.
47748 const cTMinus1 = inputs[2]; // Previous carry state.
47749 inputs = inputs[0];
47750 if (0 < this.dropout && this.dropout < 1 && this.dropoutMask == null) {
47751 this.dropoutMask = generateDropoutMask({
47752 ones: () => onesLike(inputs),
47753 rate: this.dropout,
47754 training,
47755 count: 4,
47756 dropoutFunc: this.dropoutFunc
47757 });
47758 }
47759 if (0 < this.recurrentDropout && this.recurrentDropout < 1 &&
47760 this.recurrentDropoutMask == null) {
47761 this.recurrentDropoutMask = generateDropoutMask({
47762 ones: () => onesLike(hTMinus1),
47763 rate: this.recurrentDropout,
47764 training,
47765 count: 4,
47766 dropoutFunc: this.dropoutFunc
47767 });
47768 }
47769 const dpMask = this.dropoutMask;
47770 const recDpMask = this.recurrentDropoutMask;
47771 // Note: For superior performance, TensorFlow.js always uses
47772 // implementation 2 regardless of the actual value of
47773 // config.implementation.
47774 let i;
47775 let f;
47776 let c;
47777 let o;
47778 if (0 < this.dropout && this.dropout < 1) {
47779 inputs = mul(inputs, dpMask[0]);
47780 }
47781 let z = dot$1(inputs, this.kernel.read());
47782 if (0 < this.recurrentDropout && this.recurrentDropout < 1) {
47783 hTMinus1 = mul(hTMinus1, recDpMask[0]);
47784 }
47785 z = add$1(z, dot$1(hTMinus1, this.recurrentKernel.read()));
47786 if (this.useBias) {
47787 z = biasAdd(z, this.bias.read());
47788 }
47789 const [z0, z1, z2, z3] = split(z, 4, z.rank - 1);
47790 i = this.recurrentActivation.apply(z0);
47791 f = this.recurrentActivation.apply(z1);
47792 c = add$1(mul(f, cTMinus1), mul(i, this.activation.apply(z2)));
47793 o = this.recurrentActivation.apply(z3);
47794 const h = mul(o, this.activation.apply(c));
47795 // TODO(cais): Add use_learning_phase flag properly.
47796 return [h, h, c];
47797 });
47798 }
47799 getConfig() {
47800 const baseConfig = super.getConfig();
47801 const config = {
47802 units: this.units,
47803 activation: serializeActivation(this.activation),
47804 recurrentActivation: serializeActivation(this.recurrentActivation),
47805 useBias: this.useBias,
47806 kernelInitializer: serializeInitializer(this.kernelInitializer),
47807 recurrentInitializer: serializeInitializer(this.recurrentInitializer),
47808 biasInitializer: serializeInitializer(this.biasInitializer),
47809 unitForgetBias: this.unitForgetBias,
47810 kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
47811 recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer),
47812 biasRegularizer: serializeRegularizer(this.biasRegularizer),
47813 activityRegularizer: serializeRegularizer(this.activityRegularizer),
47814 kernelConstraint: serializeConstraint(this.kernelConstraint),
47815 recurrentConstraint: serializeConstraint(this.recurrentConstraint),
47816 biasConstraint: serializeConstraint(this.biasConstraint),
47817 dropout: this.dropout,
47818 recurrentDropout: this.recurrentDropout,
47819 implementation: this.implementation,
47820 };
47821 return Object.assign({}, baseConfig, config);
47822 }
47823 }
47824 /** @nocollapse */
47825 LSTMCell.className = 'LSTMCell';
47826 registerClass(LSTMCell);
47827 class LSTM extends RNN {
47828 constructor(args) {
47829 if (args.implementation === 0) {
47830 console.warn('`implementation=0` has been deprecated, and now defaults to ' +
47831 '`implementation=1`. Please update your layer call.');
47832 }
47833 args.cell = new LSTMCell(args);
47834 super(args);
47835 // TODO(cais): Add activityRegularizer.
47836 }
47837 call(inputs, kwargs) {
47838 return tidy(() => {
47839 if (this.cell.dropoutMask != null) {
47840 dispose(this.cell.dropoutMask);
47841 this.cell.dropoutMask = null;
47842 }
47843 if (this.cell.recurrentDropoutMask != null) {
47844 dispose(this.cell.recurrentDropoutMask);
47845 this.cell.recurrentDropoutMask = null;
47846 }
47847 const mask = kwargs == null ? null : kwargs['mask'];
47848 const training = kwargs == null ? null : kwargs['training'];
47849 const initialState = kwargs == null ? null : kwargs['initialState'];
47850 return super.call(inputs, { mask, training, initialState });
47851 });
47852 }
47853 /** @nocollapse */
47854 static fromConfig(cls, config) {
47855 if (config['implmentation'] === 0) {
47856 config['implementation'] = 1;
47857 }
47858 return new cls(config);
47859 }
47860 }
47861 /** @nocollapse */
47862 LSTM.className = 'LSTM';
47863 registerClass(LSTM);
47864 class StackedRNNCells extends RNNCell {
47865 constructor(args) {
47866 super(args);
47867 this.cells = args.cells;
47868 }
47869 get stateSize() {
47870 // States are a flat list in reverse order of the cell stack.
47871 // This allows perserving the requirement `stack.statesize[0] ===
47872 // outputDim`. E.g., states of a 2-layer LSTM would be `[h2, c2, h1, c1]`,
47873 // assuming one LSTM has states `[h, c]`.
47874 const stateSize = [];
47875 for (const cell of this.cells.slice().reverse()) {
47876 if (Array.isArray(cell.stateSize)) {
47877 stateSize.push(...cell.stateSize);
47878 }
47879 else {
47880 stateSize.push(cell.stateSize);
47881 }
47882 }
47883 return stateSize;
47884 }
47885 call(inputs, kwargs) {
47886 return tidy(() => {
47887 inputs = inputs;
47888 let states = inputs.slice(1);
47889 // Recover per-cell states.
47890 const nestedStates = [];
47891 for (const cell of this.cells.slice().reverse()) {
47892 if (Array.isArray(cell.stateSize)) {
47893 nestedStates.push(states.splice(0, cell.stateSize.length));
47894 }
47895 else {
47896 nestedStates.push(states.splice(0, 1));
47897 }
47898 }
47899 nestedStates.reverse();
47900 // Call the cells in order and store the returned states.
47901 const newNestedStates = [];
47902 let callInputs;
47903 for (let i = 0; i < this.cells.length; ++i) {
47904 const cell = this.cells[i];
47905 states = nestedStates[i];
47906 // TODO(cais): Take care of constants.
47907 if (i === 0) {
47908 callInputs = [inputs[0]].concat(states);
47909 }
47910 else {
47911 callInputs = [callInputs[0]].concat(states);
47912 }
47913 callInputs = cell.call(callInputs, kwargs);
47914 newNestedStates.push(callInputs.slice(1));
47915 }
47916 // Format the new states as a flat list in reverse cell order.
47917 states = [];
47918 for (const cellStates of newNestedStates.slice().reverse()) {
47919 states.push(...cellStates);
47920 }
47921 return [callInputs[0]].concat(states);
47922 });
47923 }
47924 build(inputShape) {
47925 if (isArrayOfShapes(inputShape)) {
47926 // TODO(cais): Take care of input constants.
47927 // const constantShape = inputShape.slice(1);
47928 inputShape = inputShape[0];
47929 }
47930 inputShape = inputShape;
47931 let outputDim;
47932 this.cells.forEach((cell, i) => {
47933 nameScope(`RNNCell_${i}`, () => {
47934 // TODO(cais): Take care of input constants.
47935 cell.build(inputShape);
47936 if (Array.isArray(cell.stateSize)) {
47937 outputDim = cell.stateSize[0];
47938 }
47939 else {
47940 outputDim = cell.stateSize;
47941 }
47942 inputShape = [inputShape[0], outputDim];
47943 });
47944 });
47945 this.built = true;
47946 }
47947 getConfig() {
47948 const baseConfig = super.getConfig();
47949 const getCellConfig = (cell) => {
47950 return {
47951 'className': cell.getClassName(),
47952 'config': cell.getConfig(),
47953 };
47954 };
47955 const cellConfigs = this.cells.map(getCellConfig);
47956 const config = { 'cells': cellConfigs };
47957 return Object.assign({}, baseConfig, config);
47958 }
47959 /** @nocollapse */
47960 static fromConfig(cls, config, customObjects = {}) {
47961 const cells = [];
47962 for (const cellConfig of config['cells']) {
47963 cells.push(deserialize(cellConfig, customObjects));
47964 }
47965 return new cls({ cells });
47966 }
47967 get trainableWeights() {
47968 if (!this.trainable) {
47969 return [];
47970 }
47971 const weights = [];
47972 for (const cell of this.cells) {
47973 weights.push(...cell.trainableWeights);
47974 }
47975 return weights;
47976 }
47977 get nonTrainableWeights() {
47978 const weights = [];
47979 for (const cell of this.cells) {
47980 weights.push(...cell.nonTrainableWeights);
47981 }
47982 if (!this.trainable) {
47983 const trainableWeights = [];
47984 for (const cell of this.cells) {
47985 trainableWeights.push(...cell.trainableWeights);
47986 }
47987 return trainableWeights.concat(weights);
47988 }
47989 return weights;
47990 }
47991 /**
47992 * Retrieve the weights of a the model.
47993 *
47994 * @returns A flat `Array` of `tf.Tensor`s.
47995 */
47996 getWeights() {
47997 const weights = [];
47998 for (const cell of this.cells) {
47999 weights.push(...cell.weights);
48000 }
48001 return batchGetValue(weights);
48002 }
48003 /**
48004 * Set the weights of the model.
48005 *
48006 * @param weights An `Array` of `tf.Tensor`s with shapes and types matching
48007 * the output of `getWeights()`.
48008 */
48009 setWeights(weights) {
48010 const tuples = [];
48011 for (const cell of this.cells) {
48012 const numParams = cell.weights.length;
48013 const inputWeights = weights.splice(numParams);
48014 for (let i = 0; i < cell.weights.length; ++i) {
48015 tuples.push([cell.weights[i], inputWeights[i]]);
48016 }
48017 }
48018 batchSetValue(tuples);
48019 }
48020 }
48021 /** @nocollapse */
48022 StackedRNNCells.className = 'StackedRNNCells';
48023 registerClass(StackedRNNCells);
48024 function generateDropoutMask(args) {
48025 const { ones, rate, training = false, count = 1, dropoutFunc } = args;
48026 const droppedInputs = () => dropoutFunc != null ? dropoutFunc(ones(), rate) : dropout$1(ones(), rate);
48027 const createMask = () => inTrainPhase(droppedInputs, ones, training);
48028 // just in case count is provided with null or undefined
48029 if (!count || count <= 1) {
48030 return keep(createMask().clone());
48031 }
48032 const masks = Array(count).fill(undefined).map(createMask);
48033 return masks.map(m => keep(m.clone()));
48034 }
48035
48036 /**
48037 * @license
48038 * Copyright 2020 Google LLC
48039 *
48040 * Use of this source code is governed by an MIT-style
48041 * license that can be found in the LICENSE file or at
48042 * https://opensource.org/licenses/MIT.
48043 * =============================================================================
48044 */
    // TypeScript-emitted `__rest` helper (object rest destructuring): returns
    // a shallow copy of `s` minus the own enumerable keys listed in `e`.
    var __rest = (undefined && undefined.__rest) || function (s, e) {
        var t = {};
        // Copy own enumerable string-keyed properties not excluded by `e`.
        for (var p in s) if (Object.prototype.hasOwnProperty.call(s, p) && e.indexOf(p) < 0)
            t[p] = s[p];
        // Also copy own enumerable symbol-keyed properties, where supported.
        if (s != null && typeof Object.getOwnPropertySymbols === "function")
            for (var i = 0, p = Object.getOwnPropertySymbols(s); i < p.length; i++) {
                if (e.indexOf(p[i]) < 0 && Object.prototype.propertyIsEnumerable.call(s, p[i]))
                    t[p[i]] = s[p[i]];
            }
        return t;
    };
    /**
     * Abstract marker base class for convolutional-recurrent cells (e.g.
     * `ConvLSTM2DCell`). Adds no behavior of its own; it exists so conv cells
     * share a common type distinct from ordinary `RNNCell`s.
     */
    class ConvRNN2DCell extends RNNCell {
    }
48058 /**
48059 * Base class for convolutional-recurrent layers.
48060 */
48061 class ConvRNN2D extends RNN {
48062 constructor(args) {
48063 if (args.unroll) {
48064 throw new NotImplementedError('Unrolling is not possible with convolutional RNNs.');
48065 }
48066 if (Array.isArray(args.cell)) {
48067 throw new NotImplementedError('It is not possible at the moment to stack convolutional cells.');
48068 }
48069 super(args);
48070 this.inputSpec = [new InputSpec({ ndim: 5 })];
48071 }
48072 call(inputs, kwargs) {
48073 return tidy(() => {
48074 if (this.cell.dropoutMask != null) {
48075 dispose(this.cell.dropoutMask);
48076 this.cell.dropoutMask = null;
48077 }
48078 if (this.cell.recurrentDropoutMask != null) {
48079 dispose(this.cell.recurrentDropoutMask);
48080 this.cell.recurrentDropoutMask = null;
48081 }
48082 if (kwargs && kwargs['constants']) {
48083 throw new ValueError('ConvRNN2D cell does not support constants');
48084 }
48085 const mask = kwargs == null ? null : kwargs['mask'];
48086 const training = kwargs == null ? null : kwargs['training'];
48087 const initialState = kwargs == null ? null : kwargs['initialState'];
48088 return super.call(inputs, { mask, training, initialState });
48089 });
48090 }
48091 computeOutputShape(inputShape) {
48092 let outShape = this.computeSingleOutputShape(inputShape);
48093 if (!this.returnSequences) {
48094 outShape = [outShape[0], ...outShape.slice(2)];
48095 }
48096 if (this.returnState) {
48097 outShape =
48098 [outShape, ...Array(2).fill([inputShape[0], ...outShape.slice(-3)])];
48099 }
48100 return outShape;
48101 }
48102 getInitialState(inputs) {
48103 return tidy(() => {
48104 const { stateSize } = this.cell;
48105 const inputShape = inputs.shape;
48106 const outputShape = this.computeSingleOutputShape(inputShape);
48107 const stateShape = [outputShape[0], ...outputShape.slice(2)];
48108 const initialState = zeros(stateShape);
48109 if (Array.isArray(stateSize)) {
48110 return Array(stateSize.length).fill(initialState);
48111 }
48112 return [initialState];
48113 });
48114 }
48115 resetStates(states, training = false) {
48116 tidy(() => {
48117 if (!this.stateful) {
48118 throw new AttributeError('Cannot call resetStates() on an RNN Layer that is not stateful.');
48119 }
48120 const inputShape = this.inputSpec[0].shape;
48121 const outputShape = this.computeSingleOutputShape(inputShape);
48122 const stateShape = [outputShape[0], ...outputShape.slice(2)];
48123 const batchSize = inputShape[0];
48124 if (batchSize == null) {
48125 throw new ValueError('If an RNN is stateful, it needs to know its batch size. Specify ' +
48126 'the batch size of your input tensors: \n' +
48127 '- If using a Sequential model, specify the batch size by ' +
48128 'passing a `batchInputShape` option to your first layer.\n' +
48129 '- If using the functional API, specify the batch size by ' +
48130 'passing a `batchShape` option to your Input layer.');
48131 }
48132 // Initialize state if null.
48133 if (this.getStates() == null) {
48134 if (Array.isArray(this.cell.stateSize)) {
48135 this.states_ = this.cell.stateSize.map(() => zeros(stateShape));
48136 }
48137 else {
48138 this.states_ = [zeros(stateShape)];
48139 }
48140 }
48141 else if (states == null) {
48142 // Dispose old state tensors.
48143 dispose(this.states_);
48144 // For stateful RNNs, fully dispose kept old states.
48145 if (this.keptStates != null) {
48146 dispose(this.keptStates);
48147 this.keptStates = [];
48148 }
48149 if (Array.isArray(this.cell.stateSize)) {
48150 this.states_ = this.cell.stateSize.map(() => zeros(stateShape));
48151 }
48152 else {
48153 this.states_[0] = zeros(stateShape);
48154 }
48155 }
48156 else {
48157 if (!Array.isArray(states)) {
48158 states = [states];
48159 }
48160 if (states.length !== this.states_.length) {
48161 throw new ValueError(`Layer ${this.name} expects ${this.states_.length} state(s), ` +
48162 `but it received ${states.length} state value(s). Input ` +
48163 `received: ${states}`);
48164 }
48165 if (training) {
48166 // Store old state tensors for complete disposal later, i.e., during
48167 // the next no-arg call to this method. We do not dispose the old
48168 // states immediately because that BPTT (among other things) require
48169 // them.
48170 this.keptStates.push(this.states_.slice());
48171 }
48172 else {
48173 dispose(this.states_);
48174 }
48175 for (let index = 0; index < this.states_.length; ++index) {
48176 const value = states[index];
48177 const expectedShape = stateShape;
48178 if (!arraysEqual(value.shape, expectedShape)) {
48179 throw new ValueError(`State ${index} is incompatible with layer ${this.name}: ` +
48180 `expected shape=${expectedShape}, received shape=${value.shape}`);
48181 }
48182 this.states_[index] = value;
48183 }
48184 }
48185 this.states_ = this.states_.map(state => keep(state.clone()));
48186 });
48187 }
48188 computeSingleOutputShape(inputShape) {
48189 const { dataFormat, filters, kernelSize, padding, strides, dilationRate } = this.cell;
48190 const isChannelsFirst = dataFormat === 'channelsFirst';
48191 const h = inputShape[isChannelsFirst ? 3 : 2];
48192 const w = inputShape[isChannelsFirst ? 4 : 3];
48193 const hOut = convOutputLength(h, kernelSize[0], padding, strides[0], dilationRate[0]);
48194 const wOut = convOutputLength(w, kernelSize[1], padding, strides[1], dilationRate[1]);
48195 const outShape = [
48196 ...inputShape.slice(0, 2),
48197 ...(isChannelsFirst ? [filters, hOut, wOut] : [hOut, wOut, filters])
48198 ];
48199 return outShape;
48200 }
48201 }
48202 /** @nocollapse */
48203 ConvRNN2D.className = 'ConvRNN2D';
    /**
     * Cell for the `ConvLSTM2D` layer: an LSTM cell whose input and recurrent
     * transformations are 2D convolutions instead of matrix multiplications.
     * Reuses `LSTMCell`'s configuration handling by passing `filters` as
     * `units`.
     */
    class ConvLSTM2DCell extends LSTMCell {
        constructor(args) {
            const { filters, kernelSize, strides, padding, dataFormat, dilationRate, } = args;
            // The base LSTMCell treats `filters` as its `units`.
            super(Object.assign({}, args, { units: filters }));
            this.filters = filters;
            assertPositiveInteger(this.filters, 'filters');
            // Normalize scalar-or-pair conv parameters to 2-element arrays.
            this.kernelSize = normalizeArray(kernelSize, 2, 'kernelSize');
            this.kernelSize.forEach(size => assertPositiveInteger(size, 'kernelSize'));
            this.strides = normalizeArray(strides || 1, 2, 'strides');
            this.strides.forEach(stride => assertPositiveInteger(stride, 'strides'));
            this.padding = padding || 'valid';
            checkPaddingMode(this.padding);
            this.dataFormat = dataFormat || 'channelsLast';
            checkDataFormat(this.dataFormat);
            this.dilationRate = normalizeArray(dilationRate || 1, 2, 'dilationRate');
            this.dilationRate.forEach(rate => assertPositiveInteger(rate, 'dilationRate'));
        }
        build(inputShape) {
            // `_a` is a holder for the class-expression alias emitted by the
            // TypeScript downlevel compiler below.
            var _a;
            inputShape = getExactlyOneShape(inputShape);
            const channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1;
            if (inputShape[channelAxis] == null) {
                throw new ValueError(`The channel dimension of the input should be defined. ` +
                    `Found ${inputShape[channelAxis]}`);
            }
            const inputDim = inputShape[channelAxis];
            // The four gates are packed along the output-channel axis of a
            // single kernel / recurrent kernel / bias and split in call().
            const numOfKernels = 4;
            const kernelShape = this.kernelSize.concat([inputDim, this.filters * numOfKernels]);
            this.kernel = this.addWeight('kernel', kernelShape, null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
            const recurrentKernelShape = this.kernelSize.concat([this.filters, this.filters * numOfKernels]);
            this.recurrentKernel = this.addWeight('recurrent_kernel', recurrentKernelShape, null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
            if (this.useBias) {
                let biasInitializer;
                if (this.unitForgetBias) {
                    // Forget-gate bias slice initialized to ones; the other
                    // slices use the configured bias initializer.
                    const init = this.biasInitializer;
                    const filters = this.filters;
                    biasInitializer = new (_a = class CustomInit extends Initializer {
                        apply(shape, dtype) {
                            const biasI = init.apply([filters]);
                            const biasF = ones$1([filters]);
                            const biasCAndO = init.apply([filters * 2]);
                            return concatenate([biasI, biasF, biasCAndO]);
                        }
                    },
                    /** @nocollapse */
                    _a.className = 'CustomInit',
                    _a)();
                }
                else {
                    biasInitializer = this.biasInitializer;
                }
                this.bias = this.addWeight('bias', [this.filters * numOfKernels], null, biasInitializer, this.biasRegularizer, true, this.biasConstraint);
            }
            this.built = true;
        }
        call(inputs, kwargs) {
            return tidy(() => {
                if (inputs.length !== 3) {
                    throw new ValueError(`ConvLSTM2DCell expects 3 input Tensors (inputs, h, c), got ` +
                        `${inputs.length}.`);
                }
                const training = kwargs['training'] || false;
                const x = inputs[0]; // Current input
                const hTMinus1 = inputs[1]; // Previous memory state.
                const cTMinus1 = inputs[2]; // Previous carry state.
                const numOfKernels = 4;
                // Lazily build the per-gate dropout masks; reused until the
                // owning layer clears them.
                if (0 < this.dropout && this.dropout < 1 && this.dropoutMask == null) {
                    this.dropoutMask = generateDropoutMask({
                        ones: () => onesLike(x),
                        rate: this.dropout,
                        training,
                        count: numOfKernels,
                        dropoutFunc: this.dropoutFunc
                    });
                }
                const dropoutMask = this.dropoutMask;
                // Apply the index-th mask to x, or pass x through when no mask
                // exists for that gate.
                const applyDropout = (x, mask, index) => {
                    if (!mask || !mask[index]) {
                        return x;
                    }
                    return mul(mask[index], x);
                };
                let xI = applyDropout(x, dropoutMask, 0);
                let xF = applyDropout(x, dropoutMask, 1);
                let xC = applyDropout(x, dropoutMask, 2);
                let xO = applyDropout(x, dropoutMask, 3);
                if (0 < this.recurrentDropout && this.recurrentDropout < 1 &&
                    this.recurrentDropoutMask == null) {
                    this.recurrentDropoutMask = generateDropoutMask({
                        ones: () => onesLike(hTMinus1),
                        rate: this.recurrentDropout,
                        training,
                        count: numOfKernels,
                        dropoutFunc: this.dropoutFunc
                    });
                }
                const recDropoutMask = this.recurrentDropoutMask;
                let hI = applyDropout(hTMinus1, recDropoutMask, 0);
                let hF = applyDropout(hTMinus1, recDropoutMask, 1);
                let hC = applyDropout(hTMinus1, recDropoutMask, 2);
                let hO = applyDropout(hTMinus1, recDropoutMask, 3);
                // Split the packed kernels/bias into the four gates (i, f, c, o).
                const kernelChannelAxis = 3;
                const [kernelI, kernelF, kernelC, kernelO] = split(this.kernel.read(), numOfKernels, kernelChannelAxis);
                const [biasI, biasF, biasC, biasO] = this.useBias ?
                    split(this.bias.read(), numOfKernels) :
                    [null, null, null, null];
                xI = this.inputConv(xI, kernelI, biasI, this.padding);
                xF = this.inputConv(xF, kernelF, biasF, this.padding);
                xC = this.inputConv(xC, kernelC, biasC, this.padding);
                xO = this.inputConv(xO, kernelO, biasO, this.padding);
                const [recKernelI, recKernelF, recKernelC, recKernelO] = split(this.recurrentKernel.read(), numOfKernels, kernelChannelAxis);
                hI = this.recurrentConv(hI, recKernelI);
                hF = this.recurrentConv(hF, recKernelF);
                hC = this.recurrentConv(hC, recKernelC);
                hO = this.recurrentConv(hO, recKernelO);
                // Standard LSTM gate equations, with conv outputs in place of
                // matmul outputs.
                const i = this.recurrentActivation.apply(add$1(xI, hI));
                const f = this.recurrentActivation.apply(add$1(xF, hF));
                const c = add$1(mul(f, cTMinus1), mul(i, this.activation.apply(add$1(xC, hC))));
                const h = mul(this.recurrentActivation.apply(add$1(xO, hO)), this.activation.apply(c));
                return [h, h, c];
            });
        }
        getConfig() {
            // Drop the base class's synthesized 'units' entry (it mirrors
            // `filters`) and add the conv-specific fields.
            const _a = super.getConfig(), { 'units': _ } = _a, baseConfig = __rest(_a, ['units']);
            const config = {
                filters: this.filters,
                kernelSize: this.kernelSize,
                padding: this.padding,
                dataFormat: this.dataFormat,
                dilationRate: this.dilationRate,
                strides: this.strides,
            };
            return Object.assign({}, baseConfig, config);
        }
        // Convolution applied to the (gated) input, with optional bias.
        inputConv(x, w, b, padding) {
            const out = conv2d(x, w, this.strides, (padding || 'valid'), this.dataFormat === 'channelsFirst' ? 'NCHW' : 'NHWC', this.dilationRate);
            if (b) {
                return biasAdd(out, b, this.dataFormat);
            }
            return out;
        }
        // Convolution applied to the recurrent (hidden) state; stride 1 and
        // 'same' padding so the state's spatial shape is preserved.
        recurrentConv(x, w) {
            const strides = 1;
            return conv2d(x, w, strides, 'same', this.dataFormat === 'channelsFirst' ? 'NCHW' : 'NHWC');
        }
    }
    /** @nocollapse */
    ConvLSTM2DCell.className = 'ConvLSTM2DCell';
    registerClass(ConvLSTM2DCell);
48353 class ConvLSTM2D extends ConvRNN2D {
48354 constructor(args) {
48355 const cell = new ConvLSTM2DCell(args);
48356 super(Object.assign({}, args, { cell }));
48357 }
48358 /** @nocollapse */
48359 static fromConfig(cls, config) {
48360 return new cls(config);
48361 }
48362 }
48363 /** @nocollapse */
48364 ConvLSTM2D.className = 'ConvLSTM2D';
48365 registerClass(ConvLSTM2D);
48366
48367 /**
48368 * @license
48369 * Copyright 2018 Google LLC
48370 *
48371 * Use of this source code is governed by an MIT-style
48372 * license that can be found in the LICENSE file or at
48373 * https://opensource.org/licenses/MIT.
48374 * =============================================================================
48375 */
48376 class Dropout extends Layer {
48377 constructor(args) {
48378 super(args);
48379 this.rate = Math.max(Math.min(args.rate, 1), 0);
48380 // So that the scalar doesn't get tidied up between executions.
48381 this.noiseShape = args.noiseShape;
48382 this.seed = args.seed;
48383 this.supportsMasking = true;
48384 }
48385 getNoiseShape(input) {
48386 if (this.noiseShape == null) {
48387 return this.noiseShape;
48388 }
48389 const inputShape = input.shape;
48390 const noiseShape = [];
48391 for (let i = 0; i < this.noiseShape.length; ++i) {
48392 noiseShape.push(this.noiseShape[i] == null ? inputShape[i] : this.noiseShape[i]);
48393 }
48394 return noiseShape;
48395 }
48396 call(inputs, kwargs) {
48397 return tidy(() => {
48398 this.invokeCallHook(inputs, kwargs);
48399 const input = getExactlyOneTensor(inputs);
48400 if (0 < this.rate && this.rate < 1) {
48401 const training = kwargs['training'] == null ? false : kwargs['training'];
48402 const noiseShape = this.getNoiseShape(input);
48403 const output = inTrainPhase(() => dropout$1(input, this.rate, noiseShape, this.seed), () => input, training);
48404 return output;
48405 }
48406 return inputs;
48407 });
48408 }
48409 getConfig() {
48410 const config = {
48411 rate: this.rate,
48412 noiseShape: this.noiseShape,
48413 seed: this.seed,
48414 };
48415 const baseConfig = super.getConfig();
48416 Object.assign(config, baseConfig);
48417 return config;
48418 }
48419 dispose() {
48420 return super.dispose();
48421 }
48422 }
48423 /** @nocollapse */
48424 Dropout.className = 'Dropout';
48425 registerClass(Dropout);
48426 class SpatialDropout1D extends Dropout {
48427 constructor(args) {
48428 super(args);
48429 this.inputSpec = [{ ndim: 3 }];
48430 }
48431 getNoiseShape(input) {
48432 const inputShape = input.shape;
48433 return [inputShape[0], 1, inputShape[2]];
48434 }
48435 }
48436 /** @nocollapse */
48437 SpatialDropout1D.className = 'SpatialDropout1D';
48438 registerClass(SpatialDropout1D);
48439 class Dense extends Layer {
48440 constructor(args) {
48441 super(args);
48442 // Default activation: Linear (none).
48443 this.activation = null;
48444 this.useBias = true;
48445 this.kernel = null;
48446 this.bias = null;
48447 this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
48448 this.DEFAULT_BIAS_INITIALIZER = 'zeros';
48449 if (args.batchInputShape == null && args.inputShape == null &&
48450 args.inputDim != null) {
48451 // This logic is copied from Layer's constructor, since we can't
48452 // do exactly what the Python constructor does for Dense().
48453 let batchSize = null;
48454 if (args.batchSize != null) {
48455 batchSize = args.batchSize;
48456 }
48457 this.batchInputShape = [batchSize, args.inputDim];
48458 }
48459 this.units = args.units;
48460 assertPositiveInteger(this.units, 'units');
48461 this.activation = getActivation(args.activation);
48462 if (args.useBias != null) {
48463 this.useBias = args.useBias;
48464 }
48465 this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER);
48466 this.biasInitializer =
48467 getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER);
48468 this.kernelConstraint = getConstraint(args.kernelConstraint);
48469 this.biasConstraint = getConstraint(args.biasConstraint);
48470 this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
48471 this.biasRegularizer = getRegularizer(args.biasRegularizer);
48472 this.activityRegularizer = getRegularizer(args.activityRegularizer);
48473 this.supportsMasking = true;
48474 this.inputSpec = [{ minNDim: 2 }];
48475 }
48476 build(inputShape) {
48477 inputShape = getExactlyOneShape(inputShape);
48478 const inputLastDim = inputShape[inputShape.length - 1];
48479 if (this.kernel == null) {
48480 this.kernel = this.addWeight('kernel', [inputLastDim, this.units], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
48481 if (this.useBias) {
48482 this.bias = this.addWeight('bias', [this.units], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
48483 }
48484 }
48485 this.inputSpec = [{ minNDim: 2, axes: { [-1]: inputLastDim } }];
48486 this.built = true;
48487 }
48488 computeOutputShape(inputShape) {
48489 inputShape = getExactlyOneShape(inputShape);
48490 const outputShape = inputShape.slice();
48491 outputShape[outputShape.length - 1] = this.units;
48492 return outputShape;
48493 }
48494 call(inputs, kwargs) {
48495 return tidy(() => {
48496 this.invokeCallHook(inputs, kwargs);
48497 // Dense layer accepts only a single input.
48498 const input = getExactlyOneTensor(inputs);
48499 const fusedActivationName = mapActivationToFusedKernel(this.activation.getClassName());
48500 let output;
48501 if (fusedActivationName != null) {
48502 output = dot$1(input, this.kernel.read(), fusedActivationName, this.bias ? this.bias.read() : null);
48503 }
48504 else {
48505 output = dot$1(input, this.kernel.read());
48506 if (this.bias != null) {
48507 output = biasAdd(output, this.bias.read());
48508 }
48509 if (this.activation != null) {
48510 output = this.activation.apply(output);
48511 }
48512 }
48513 return output;
48514 });
48515 }
48516 getConfig() {
48517 const config = {
48518 units: this.units,
48519 activation: serializeActivation(this.activation),
48520 useBias: this.useBias,
48521 kernelInitializer: serializeInitializer(this.kernelInitializer),
48522 biasInitializer: serializeInitializer(this.biasInitializer),
48523 kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
48524 biasRegularizer: serializeRegularizer(this.biasRegularizer),
48525 activityRegularizer: serializeRegularizer(this.activityRegularizer),
48526 kernelConstraint: serializeConstraint(this.kernelConstraint),
48527 biasConstraint: serializeConstraint(this.biasConstraint)
48528 };
48529 const baseConfig = super.getConfig();
48530 Object.assign(config, baseConfig);
48531 return config;
48532 }
48533 }
    /** @nocollapse */
    Dense.className = 'Dense';
    // Register under the 'Dense' name so serialized models can re-create it.
    registerClass(Dense);
48537 class Flatten extends Layer {
48538 constructor(args) {
48539 args = args || {};
48540 super(args);
48541 this.inputSpec = [{ minNDim: 3 }];
48542 this.dataFormat = args.dataFormat;
48543 }
48544 computeOutputShape(inputShape) {
48545 inputShape = getExactlyOneShape(inputShape);
48546 for (const dim of inputShape.slice(1)) {
48547 if (dim == null) {
48548 throw new ValueError(`The shape of the input to "Flatten" is not fully defined ` +
48549 `(got ${inputShape.slice(1)}). Make sure to pass a complete ` +
48550 `"input_shape" or "batch_input_shape" argument to the first ` +
48551 `layer in your model.`);
48552 }
48553 }
48554 return [inputShape[0], arrayProd(inputShape, 1)];
48555 }
48556 call(inputs, kwargs) {
48557 return tidy(() => {
48558 this.invokeCallHook(inputs, kwargs);
48559 let input = getExactlyOneTensor(inputs);
48560 if (this.dataFormat === 'channelsFirst' && input.rank > 1) {
48561 const permutation = [0];
48562 for (let i = 2; i < input.rank; ++i) {
48563 permutation.push(i);
48564 }
48565 permutation.push(1);
48566 input = transpose(input, permutation);
48567 }
48568 return batchFlatten(input);
48569 });
48570 }
48571 getConfig() {
48572 const config = {};
48573 if (this.dataFormat != null) {
48574 config['dataFormat'] = this.dataFormat;
48575 }
48576 const baseConfig = super.getConfig();
48577 Object.assign(config, baseConfig);
48578 return config;
48579 }
48580 }
48581 /** @nocollapse */
48582 Flatten.className = 'Flatten';
48583 registerClass(Flatten);
48584 class Activation$1 extends Layer {
48585 constructor(args) {
48586 super(args);
48587 this.supportsMasking = true;
48588 this.activation = getActivation(args.activation);
48589 }
48590 call(inputs, kwargs) {
48591 return tidy(() => {
48592 this.invokeCallHook(inputs, kwargs);
48593 const input = getExactlyOneTensor(inputs);
48594 return this.activation.apply(input);
48595 });
48596 }
48597 getConfig() {
48598 const config = { activation: serializeActivation(this.activation) };
48599 const baseConfig = super.getConfig();
48600 Object.assign(config, baseConfig);
48601 return config;
48602 }
48603 }
48604 /** @nocollapse */
48605 Activation$1.className = 'Activation';
48606 registerClass(Activation$1);
48607 class RepeatVector extends Layer {
48608 constructor(args) {
48609 super(args);
48610 this.n = args.n;
48611 this.inputSpec = [{ ndim: 2 }];
48612 }
48613 computeOutputShape(inputShape) {
48614 return [inputShape[0], this.n, inputShape[1]];
48615 }
48616 call(inputs, kwargs) {
48617 return tidy(() => {
48618 inputs = getExactlyOneTensor(inputs);
48619 return repeat(inputs, this.n);
48620 });
48621 }
48622 getConfig() {
48623 const config = {
48624 n: this.n,
48625 };
48626 const baseConfig = super.getConfig();
48627 Object.assign(config, baseConfig);
48628 return config;
48629 }
48630 }
48631 /** @nocollapse */
48632 RepeatVector.className = 'RepeatVector';
48633 registerClass(RepeatVector);
48634 class Reshape$1 extends Layer {
48635 constructor(args) {
48636 super(args);
48637 this.targetShape = args.targetShape;
48638 // Make sure that all unknown dimensions are represented as `null`.
48639 for (let i = 0; i < this.targetShape.length; ++i) {
48640 if (this.isUnknown(this.targetShape[i])) {
48641 this.targetShape[i] = null;
48642 }
48643 }
48644 }
48645 isUnknown(dim) {
48646 return dim < 0 || dim == null;
48647 }
48648 /**
48649 * Finds and replaces a missing dimension in output shape.
48650 *
48651 * This is a near direct port of the internal Numpy function
48652 * `_fix_unknown_dimension` in `numpy/core/src/multiarray/shape.c`.
48653 *
48654 * @param inputShape: Original shape of array begin reshape.
48655 * @param outputShape: Target shape of the array, with at most a single
48656 * `null` or negative number, which indicates an underdetermined dimension
48657 * that should be derived from `inputShape` and the known dimensions of
48658 * `outputShape`.
48659 * @returns: The output shape with `null` replaced with its computed value.
48660 * @throws: ValueError: If `inputShape` and `outputShape` do not match.
48661 */
48662 fixUnknownDimension(inputShape, outputShape) {
48663 const errorMsg = 'Total size of new array must be unchanged.';
48664 const finalShape = outputShape.slice();
48665 let known = 1;
48666 let unknown = null;
48667 for (let i = 0; i < finalShape.length; ++i) {
48668 const dim = finalShape[i];
48669 if (this.isUnknown(dim)) {
48670 if (unknown === null) {
48671 unknown = i;
48672 }
48673 else {
48674 throw new ValueError('Can only specifiy one unknown dimension.');
48675 }
48676 }
48677 else {
48678 known *= dim;
48679 }
48680 }
48681 const originalSize = arrayProd(inputShape);
48682 if (unknown !== null) {
48683 if (known === 0 || originalSize % known !== 0) {
48684 throw new ValueError(errorMsg);
48685 }
48686 finalShape[unknown] = originalSize / known;
48687 }
48688 else if (originalSize !== known) {
48689 throw new ValueError(errorMsg);
48690 }
48691 return finalShape;
48692 }
48693 computeOutputShape(inputShape) {
48694 let anyUnknownDims = false;
48695 for (let i = 0; i < inputShape.length; ++i) {
48696 if (this.isUnknown(inputShape[i])) {
48697 anyUnknownDims = true;
48698 break;
48699 }
48700 }
48701 if (anyUnknownDims) {
48702 return inputShape.slice(0, 1).concat(this.targetShape);
48703 }
48704 else {
48705 return inputShape.slice(0, 1).concat(this.fixUnknownDimension(inputShape.slice(1), this.targetShape));
48706 }
48707 }
48708 call(inputs, kwargs) {
48709 return tidy(() => {
48710 this.invokeCallHook(inputs, kwargs);
48711 const input = getExactlyOneTensor(inputs);
48712 const inputShape = input.shape;
48713 const outputShape = inputShape.slice(0, 1).concat(this.fixUnknownDimension(inputShape.slice(1), this.targetShape));
48714 return reshape(input, outputShape);
48715 });
48716 }
48717 getConfig() {
48718 const config = {
48719 targetShape: this.targetShape,
48720 };
48721 const baseConfig = super.getConfig();
48722 Object.assign(config, baseConfig);
48723 return config;
48724 }
48725 }
48726 /** @nocollapse */
48727 Reshape$1.className = 'Reshape';
48728 registerClass(Reshape$1);
48729 class Permute extends Layer {
48730 constructor(args) {
48731 super(args);
48732 if (args.dims == null) {
48733 throw new Error('Required configuration field `dims` is missing during Permute ' +
48734 'constructor call.');
48735 }
48736 if (!Array.isArray(args.dims)) {
48737 throw new Error('Permute constructor requires `dims` to be an Array, but received ' +
48738 `${args.dims} instead.`);
48739 }
48740 // Check the validity of the permutation indices.
48741 const expectedSortedIndices = range$1(1, args.dims.length + 1);
48742 if (!arraysEqual(args.dims.slice().sort(), expectedSortedIndices)) {
48743 throw new Error('Invalid permutation `dims`: ' + JSON.stringify(args.dims) +
48744 ' `dims` must contain consecutive integers starting from 1.');
48745 }
48746 this.dims = args.dims;
48747 this.dimsIncludingBatch = [0].concat(this.dims);
48748 this.inputSpec = [new InputSpec({ ndim: this.dims.length + 1 })];
48749 }
48750 computeOutputShape(inputShape) {
48751 inputShape = getExactlyOneShape(inputShape);
48752 const outputShape = inputShape.slice();
48753 this.dims.forEach((dim, i) => {
48754 outputShape[i + 1] = inputShape[dim];
48755 });
48756 return outputShape;
48757 }
48758 call(inputs, kwargs) {
48759 return transpose(getExactlyOneTensor(inputs), this.dimsIncludingBatch);
48760 }
48761 getConfig() {
48762 const config = {
48763 dims: this.dims,
48764 };
48765 const baseConfig = super.getConfig();
48766 Object.assign(config, baseConfig);
48767 return config;
48768 }
48769 }
48770 /** @nocollapse */
48771 Permute.className = 'Permute';
48772 registerClass(Permute);
48773 class Masking extends Layer {
48774 constructor(args) {
48775 super(args == null ? {} : args);
48776 this.supportsMasking = true;
48777 if (args != null) {
48778 this.maskValue = args.maskValue == null ? 0 : args.maskValue;
48779 }
48780 else {
48781 this.maskValue = 0;
48782 }
48783 }
48784 computeOutputShape(inputShape) {
48785 return inputShape;
48786 }
48787 getConfig() {
48788 const baseConfig = super.getConfig();
48789 const config = { maskValue: this.maskValue };
48790 Object.assign(config, baseConfig);
48791 return config;
48792 }
48793 computeMask(inputs, mask) {
48794 const input = getExactlyOneTensor(inputs);
48795 const axis = -1;
48796 return any(notEqual(input, this.maskValue), axis);
48797 }
48798 call(inputs, kwargs) {
48799 return tidy(() => {
48800 this.invokeCallHook(inputs, kwargs);
48801 const input = getExactlyOneTensor(inputs);
48802 const axis = -1;
48803 const keepDims = true;
48804 const booleanMask = any(notEqual(input, this.maskValue), axis, keepDims);
48805 const output = mul(input, cast(booleanMask, input.dtype));
48806 return output;
48807 });
48808 }
48809 }
48810 /** @nocollapse */
48811 Masking.className = 'Masking';
48812 registerClass(Masking);
48813
48814 /**
48815 * @license
48816 * Copyright 2018 Google LLC
48817 *
48818 * Use of this source code is governed by an MIT-style
48819 * license that can be found in the LICENSE file or at
48820 * https://opensource.org/licenses/MIT.
48821 * =============================================================================
48822 */
    /**
     * Maps non-negative integer ids to dense vectors via a learned
     * [inputDim, outputDim] lookup table.
     */
    class Embedding extends Layer {
        constructor(args) {
            super(args);
            this.embeddings = null;
            this.DEFAULT_EMBEDDINGS_INITIALIZER = 'randomUniform';
            if (args.batchInputShape == null && args.inputShape == null) {
                // Porting Note: This logic is copied from Layer's constructor, since we
                // can't do exactly what the Python constructor does for Embedding().
                // Specifically, the super constructor can not be called after the
                // mutation of the `config` argument.
                let batchSize = null;
                if (args.batchSize != null) {
                    batchSize = args.batchSize;
                }
                if (args.inputLength == null) {
                    // Fix super-constructor to what it would have done if
                    // 'config.inputShape' were (None, )
                    this.batchInputShape = [batchSize, null];
                }
                else {
                    // Fix super-constructor to what it would have done if
                    // 'config.inputShape' were (config.inputLength, )
                    this.batchInputShape =
                        [batchSize].concat(toList(args.inputLength));
                }
            }
            this.inputDim = args.inputDim;
            assertPositiveInteger(this.inputDim, 'inputDim');
            this.outputDim = args.outputDim;
            assertPositiveInteger(this.outputDim, 'outputDim');
            this.embeddingsInitializer = getInitializer(args.embeddingsInitializer || this.DEFAULT_EMBEDDINGS_INITIALIZER);
            this.embeddingsRegularizer = getRegularizer(args.embeddingsRegularizer);
            this.activityRegularizer = getRegularizer(args.activityRegularizer);
            this.embeddingsConstraint = getConstraint(args.embeddingsConstraint);
            // Masking is only supported when maskZero is enabled.
            this.maskZero = args.maskZero;
            this.supportsMasking = args.maskZero;
            this.inputLength = args.inputLength;
        }
        build(inputShape) {
            // The lookup table is a single [inputDim, outputDim] weight.
            this.embeddings = this.addWeight('embeddings', [this.inputDim, this.outputDim], this.dtype, this.embeddingsInitializer, this.embeddingsRegularizer, true, this.embeddingsConstraint);
            this.built = true;
        }
        // Override warnOnIncompatibleInputShape because an embedding layer allows
        // the input to have varying ranks.
        warnOnIncompatibleInputShape(inputShape) { }
        computeMask(inputs, mask) {
            // When maskZero is set, positions whose input id is 0 are masked out.
            return tidy(() => {
                if (!this.maskZero) {
                    return null;
                }
                else {
                    inputs = getExactlyOneTensor(inputs);
                    return notEqual(inputs, zerosLike(inputs));
                }
            });
        }
        computeOutputShape(inputShape) {
            inputShape = getExactlyOneShape(inputShape);
            if (this.inputLength == null) {
                return [...inputShape, this.outputDim];
            }
            // inputLength can be an array if input is 3D or higher.
            const inLens = toList(this.inputLength);
            if (inLens.length !== inputShape.length - 1) {
                throw new ValueError(`"inputLength" is ${this.inputLength}, but received ` +
                    `input shape has shape ${inputShape}`);
            }
            else {
                // NOTE(review): `i` always mirrors `k` here; kept as-is to avoid
                // behavior changes.
                let i = 0;
                for (let k = 0; k < inLens.length; ++k) {
                    const s1 = inLens[k];
                    const s2 = inputShape[k + 1];
                    if ((s1 != null) && (s2 != null) && (s1 !== s2)) {
                        throw new ValueError(`"inputLength" is ${this.inputLength}, but received ` +
                            `input shape has shape ${inputShape}`);
                    }
                    else if (s1 == null) {
                        // Fill unknown inputLength entries from the actual shape.
                        inLens[i] = s2;
                    }
                    i++;
                }
            }
            return [inputShape[0], ...inLens, this.outputDim];
        }
        call(inputs, kwargs) {
            return tidy(() => {
                this.invokeCallHook(inputs, kwargs);
                // Embedding layer accepts only a single input.
                let input = getExactlyOneTensor(inputs);
                if (input.dtype !== 'int32') {
                    input = cast$1(input, 'int32');
                }
                // Flatten indices, gather the rows, then restore the expected
                // output shape.
                const output = gather$1(this.embeddings.read(), reshape(input, [input.size]));
                return reshape(output, getExactlyOneShape(this.computeOutputShape(input.shape)));
            });
        }
        getConfig() {
            const config = {
                inputDim: this.inputDim,
                outputDim: this.outputDim,
                embeddingsInitializer: serializeInitializer(this.embeddingsInitializer),
                embeddingsRegularizer: serializeRegularizer(this.embeddingsRegularizer),
                activityRegularizer: serializeRegularizer(this.activityRegularizer),
                embeddingsConstraint: serializeConstraint(this.embeddingsConstraint),
                maskZero: this.maskZero,
                inputLength: this.inputLength
            };
            const baseConfig = super.getConfig();
            Object.assign(config, baseConfig);
            return config;
        }
    }
    /** @nocollapse */
    Embedding.className = 'Embedding';
    registerClass(Embedding);
48938
48939 /**
48940 * @license
48941 * Copyright 2018 Google LLC
48942 *
48943 * Use of this source code is governed by an MIT-style
48944 * license that can be found in the LICENSE file or at
48945 * https://opensource.org/licenses/MIT.
48946 * =============================================================================
48947 */
48948 /**
48949 * Generic Merge layer for element-wise merge functions.
48950 *
48951 * Used to implement `Sum`, `Average`, `Concatenate`, etc.
48952 */
48953 class Merge extends Layer {
48954 constructor(args) {
48955 super(args || {});
48956 this.supportsMasking = true;
48957 }
48958 /**
48959 * Logic for merging multiple tensors, to be overridden by subclasses.
48960 * @param inputs
48961 */
48962 mergeFunction(inputs) {
48963 throw new NotImplementedError();
48964 }
48965 /**
48966 * Computes the shape of the result of an elementwise operation.
48967 *
48968 * @param shape1: Shape of the first tensor.
48969 * @param shape2: Shape of the second tensor.
48970 * @returns Expected output shape when an elementwise operation is carried
48971 * out on 2 tensors with shapes `shape1` and `shape2`.
48972 * @throws ValueError: If `shape1` and `shape2` are not compatible for
48973 * element-wise operations.
48974 */
48975 computeElementwiseOpOutputShape(shape1, shape2) {
48976 if (shape1 == null || shape2 == null) {
48977 return null;
48978 }
48979 else if (shape1.length < shape2.length) {
48980 return this.computeElementwiseOpOutputShape(shape2, shape1);
48981 }
48982 else if (shape2.length === 0) {
48983 return shape1;
48984 }
48985 const outputShape = shape1.slice(0, shape1.length - shape2.length);
48986 for (let k = 0; k < shape2.length; ++k) {
48987 const i = shape1[shape1.length - shape2.length + k];
48988 const j = shape2[k];
48989 if (i == null || j == null || i < 0 || j < 0) {
48990 outputShape.push(null);
48991 }
48992 else if (i === 1) {
48993 outputShape.push(j);
48994 }
48995 else if (j === 1) {
48996 outputShape.push(i);
48997 }
48998 else {
48999 if (i !== j) {
49000 throw new ValueError('Operands could not be broadcast together with shapes ' +
49001 JSON.stringify(shape1) + ' ' + JSON.stringify(shape2));
49002 }
49003 outputShape.push(i);
49004 }
49005 }
49006 return outputShape;
49007 }
49008 build(inputShape) {
49009 // Used purely for shape validation.
49010 if (Array.isArray(inputShape) && !Array.isArray(inputShape[0])) {
49011 // Make sure that inputShape is an Array of shape.
49012 inputShape = [getExactlyOneShape(inputShape)];
49013 }
49014 inputShape = inputShape;
49015 if (inputShape.length < 2) {
49016 throw new ValueError('A merge layer should be called on an Array of at least 2 inputs.' +
49017 ` Got ${inputShape.length} input(s).`);
49018 }
49019 // Make sure that there is at most one unique batch size among the input
49020 // shapes.
49021 let batchSizes = [];
49022 for (const shape of inputShape) {
49023 if (shape != null && shape[0] !== null) {
49024 batchSizes.push(shape[0]);
49025 }
49026 }
49027 batchSizes = unique$1(batchSizes);
49028 if (batchSizes.length > 1) {
49029 throw new ValueError(`Can not merge tensors with different batch sizes. ` +
49030 `Got tensors with shapes: ${JSON.stringify(inputShape)}.`);
49031 }
49032 let outputShape = inputShape[0] == null ? null : inputShape[0].slice(1);
49033 for (let i = 1; i < inputShape.length; ++i) {
49034 const shape = inputShape[i] == null ? null : inputShape[i].slice(1);
49035 outputShape = this.computeElementwiseOpOutputShape(outputShape, shape);
49036 }
49037 // If the inputs have different ranks, we have to reshape them to make them
49038 // broadcastable.
49039 const allRanks = inputShape.map(shape => shape.length);
49040 if (inputShape.indexOf(null) === -1 &&
49041 unique$1(allRanks).length === 1) {
49042 this.reshapeRequired = false;
49043 }
49044 else {
49045 this.reshapeRequired = true;
49046 }
49047 }
49048 call(inputs, kwargs) {
49049 return tidy(() => {
49050 inputs = inputs;
49051 if (this.reshapeRequired) {
49052 const reshapedInputs = [];
49053 const inputDims = inputs.map(input => input.rank);
49054 if (inputDims.indexOf(null) === -1) {
49055 // If ranks of all inputs are available, we simply expand each of them
49056 // at axis=1 until all of them have the same rank.
49057 const maxNDim = max$1(inputDims);
49058 for (let x of inputs) {
49059 const xNDim = x.rank;
49060 for (let k = 0; k < maxNDim - xNDim; ++k) {
49061 x = expandDims$1(x, 1);
49062 }
49063 reshapedInputs.push(x);
49064 }
49065 return this.mergeFunction(reshapedInputs);
49066 }
49067 else {
49068 // Transpose all inputs so that batch size is the last dimension.
49069 // [batchSize, dim1, dim2, ...] -> [dim1, dim2, ..., batchSize]
49070 let transposed = false;
49071 for (const x of inputs) {
49072 const xNDim = x.rank;
49073 if (xNDim == null) {
49074 const xShape = x.shape;
49075 const batchSize = xShape[0];
49076 const newShape = xShape.slice(1).concat([batchSize]);
49077 let xTransposed = reshape(x, [batchSize].concat(arrayProd(xShape.slice(1))));
49078 xTransposed = transpose(xTransposed, [1, 0]);
49079 xTransposed = reshape(xTransposed, newShape);
49080 reshapedInputs.push(xTransposed);
49081 transposed = true;
49082 }
49083 else if (xNDim > 1) {
49084 const dims = range$1(1, xNDim).concat([0]);
49085 reshapedInputs.push(transpose(x, dims));
49086 transposed = true;
49087 }
49088 else {
49089 // We don't transpose inputs if they are 1D vectors or scalars.
49090 reshapedInputs.push(x);
49091 }
49092 }
49093 let y = this.mergeFunction(reshapedInputs);
49094 const yNDim = y.rank;
49095 if (transposed) {
49096 // If inputs have been transposed, we have to transpose the output
49097 // too.
49098 if (yNDim == null) {
49099 const yShape = y.shape;
49100 const yNDim = yShape.length;
49101 const batchSize = yShape[yNDim - 1];
49102 const newShape = [batchSize].concat(yShape.slice(0, yShape.length - 1));
49103 y = reshape(transpose(reshape(y, [-1, batchSize]), [1, 0]), newShape);
49104 }
49105 else if (yNDim > 1) {
49106 const dims = [yNDim - 1].concat(range$1(0, yNDim - 1));
49107 y = transpose(y, dims);
49108 }
49109 }
49110 return y;
49111 }
49112 }
49113 else {
49114 return this.mergeFunction(inputs);
49115 }
49116 });
49117 }
49118 computeOutputShape(inputShape) {
49119 inputShape = inputShape;
49120 let outputShape;
49121 if (inputShape[0] == null) {
49122 outputShape = null;
49123 }
49124 else {
49125 outputShape = inputShape[0].slice(1);
49126 }
49127 for (let i = 1; i < inputShape.length; ++i) {
49128 const shape = inputShape[i] == null ? null : inputShape[i].slice(1);
49129 outputShape = this.computeElementwiseOpOutputShape(outputShape, shape);
49130 }
49131 let batchSizes = [];
49132 for (const shape of inputShape) {
49133 if (shape != null && shape[0] !== null) {
49134 batchSizes.push(shape[0]);
49135 }
49136 }
49137 batchSizes = unique$1(batchSizes);
49138 if (batchSizes.length === 1) {
49139 outputShape = batchSizes.concat(outputShape);
49140 }
49141 else {
49142 outputShape = [null].concat(outputShape);
49143 }
49144 return outputShape;
49145 }
49146 computeMask(inputs, mask) {
49147 return tidy(() => {
49148 if (mask == null) {
49149 return null;
49150 }
49151 if (!Array.isArray(mask)) {
49152 throw new ValueError('`mask` should be an Array');
49153 }
49154 if (!Array.isArray(inputs)) {
49155 throw new ValueError('`inputs` should be an Array');
49156 }
49157 if (mask.length !== inputs.length) {
49158 throw new ValueError(`The Array 'inputs' and 'mask' are expected to have the same ` +
49159 `length, but have different lengths ` +
49160 `(${inputs.length} vs ${mask.length})`);
49161 }
49162 if (mask.every(m => m == null)) {
49163 return null;
49164 }
49165 mask = mask.map(m => m == null ? m : expandDims(m, 0));
49166 let output = mask[0];
49167 for (let i = 1; i < mask.length - 1; ++i) {
49168 output = logicalAnd(output, mask[i]);
49169 }
49170 return output;
49171 });
49172 }
49173 }
49174 class Add$1 extends Merge {
49175 constructor(args) {
49176 super(args);
49177 }
49178 mergeFunction(inputs) {
49179 return tidy(() => {
49180 let output = inputs[0].clone();
49181 for (let i = 1; i < inputs.length; ++i) {
49182 output = add$1(output, inputs[i]);
49183 }
49184 return output;
49185 });
49186 }
49187 }
49188 /** @nocollapse */
49189 Add$1.className = 'Add';
49190 registerClass(Add$1);
49191 /**
49192 * Calculate the element-wise sum of inputs, which all have the same shape.
49193 *
49194 * This function can be invoked in three ways.
49195 *
49196 * 1. Construct an instance of `Add` layer, by using no input argument
49197 * or a single configuration argument. The resultant `Add` layer can then
49198 * be used on `tf.SymbolicTensor`s or `tf.Tensor`s. For example:
49199 *
49200 * ```js
49201 * const addLayer = tf.layers.add();
49202 *
49203 * // The layer can be applied to inputs.
49204 * const input1 = tf.input({shape: [2, 2]});
49205 * const input2 = tf.input({shape: [2, 2]});
49206 * const output = addLayer.apply([input1, input2]);
49207 * console.log(output.shape);
49208 * // You get [null, 2, 2], with the first dimension as the undetermined batch
49209 * // dimension.
49210 * ```
49211 *
49212 * 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. This constructs
49213 * an `Layer` object internally and calls its `apply` method on the inputs,
49214 * generating a new `tf.SymbolicTensor`. For example:
49215 *
49216 * ```js
49217 * const input1 = tf.input({shape: [2, 2]});
49218 * const input2 = tf.input({shape: [2, 2]});
49219 * const output = tf.layers.add([input1, input2]);
49220 * console.log(output.shape);
49221 * // You get [null, 2, 2], with the first dimension as the undetermined batch
49222 * // dimension.
49223 * ```
49224 *
49225 * 3. Invoke directly on `tf.Tensor`s, i.e., concrete values. This constructs
49226 * an `Layer` object internally and calls its `apply` method on the inputs,
49227 * generating a new `tf.Tensor` as the result of the computation. For
49228 * example:
49229 *
49230 * ```js
49231 * const input1 = tf.tensor2d([1, 2, 3, 4], [2, 2]);
49232 * const input2 = tf.tensor2d([10, 20, 30, 40], [2, 2]);
49233 * tf.layers.add([input1, input2]).print();
 * // Gives [[11, 22], [33, 44]].
 * ```
 */
49237 function add$2(config) {
49238 if (Array.isArray(config)) {
49239 const layer = new Add$1({});
49240 return layer.apply(config);
49241 }
49242 else {
49243 return new Add$1(config);
49244 }
49245 }
49246 class Multiply$1 extends Merge {
49247 constructor(args) {
49248 super(args);
49249 }
49250 mergeFunction(inputs) {
49251 return tidy(() => {
49252 let output = inputs[0].clone();
49253 for (let i = 1; i < inputs.length; ++i) {
49254 output = mul(output, inputs[i]);
49255 }
49256 return output;
49257 });
49258 }
49259 }
49260 /** @nocollapse */
49261 Multiply$1.className = 'Multiply';
49262 registerClass(Multiply$1);
49263 /**
49264 * Calculate the element-wise product of inputs, which all have the same shape.
49265 *
49266 * This function can be invoked in three ways.
49267 *
49268 * 1. Construct an instance of `Multiply` layer, by using no input argument
49269 * or a single configuration argument. The resultant `Multiply` layer can
49270 * then be used on `tf.SymbolicTensor`s or `tf.Tensor`s. For example:
49271 *
49272 * ```js
49273 * const multiplyLayer = tf.layers.multiply();
49274 *
49275 * // The layer can be applied to inputs.
49276 * const input1 = tf.input({shape: [2, 2]});
49277 * const input2 = tf.input({shape: [2, 2]});
49278 * const output = multiplyLayer.apply([input1, input2]);
49279 * console.log(output.shape);
49280 * // You get [null, 2, 2], with the first dimension as the undetermined batch
49281 * // dimension.
49282 * ```
49283 *
49284 * 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. This constructs
49285 * an `Layer` object internally and calls its `apply` method on the inputs,
49286 * generating a new `tf.SymbolicTensor`. For example:
49287 *
49288 * ```js
49289 * const input1 = tf.input({shape: [2, 2]});
49290 * const input2 = tf.input({shape: [2, 2]});
49291 * const output = tf.layers.multiply([input1, input2]);
49292 * console.log(output.shape);
49293 * // You get [null, 2, 2], with the first dimension as the undetermined batch
49294 * // dimension.
49295 * ```
49296 *
49297 * 3. Invoke directly on `tf.Tensor`s, i.e., concrete values. This constructs
49298 * an `Layer` object internally and calls its `apply` method on the inputs,
49299 * generating a new `tf.Tensor` as the result of the computation. For
49300 * example:
49301 *
49302 * ```js
49303 * const input1 = tf.tensor2d([1, 2, 3, 4], [2, 2]);
49304 * const input2 = tf.tensor2d([10, 20, 30, 40], [2, 2]);
49305 * tf.layers.multiply([input1, input2]).print();
 * // Gives [[10, 40], [90, 160]].
 * ```
 */
49309 function multiply(config) {
49310 if (Array.isArray(config)) {
49311 const layer = new Multiply$1({});
49312 return layer.apply(config);
49313 }
49314 else {
49315 return new Multiply$1(config);
49316 }
49317 }
49318 class Average extends Merge {
49319 constructor(args) {
49320 super(args);
49321 }
49322 mergeFunction(inputs) {
49323 return tidy(() => {
49324 let output = inputs[0].clone();
49325 for (let i = 1; i < inputs.length; ++i) {
49326 output = add$1(output, inputs[i]);
49327 }
49328 return mul(1 / inputs.length, output);
49329 });
49330 }
49331 }
49332 /** @nocollapse */
49333 Average.className = 'Average';
49334 registerClass(Average);
49335 /**
49336 * Calculate the element-wise arithmetic mean of inputs, which all have the same
49337 * shape.
49338 *
49339 * This function can be invoked in three ways.
49340 *
49341 * 1. Construct an instance of `Average` layer, by using no input argument
49342 * or a single configuration argument. The resultant `Average` layer can then
49343 * be used on `tf.SymbolicTensor`s or `tf.Tensor`s. For example:
49344 *
49345 * ```js
49346 * const averageLayer = tf.layers.average();
49347 *
49348 * // The layer can be applied to inputs.
49349 * const input1 = tf.input({shape: [2, 2]});
49350 * const input2 = tf.input({shape: [2, 2]});
49351 * const output = averageLayer.apply([input1, input2]);
49352 * console.log(output.shape);
49353 * // You get [null, 2, 2], with the first dimension as the undetermined batch
49354 * // dimension.
49355 * ```
49356 *
49357 * 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. This constructs
49358 * an `Layer` object internally and calls its `apply` method on the inputs,
49359 * generating a new `tf.SymbolicTensor`. For example:
49360 *
49361 * ```js
49362 * const input1 = tf.input({shape: [2, 2]});
49363 * const input2 = tf.input({shape: [2, 2]});
49364 * const output = tf.layers.average([input1, input2]);
49365 * console.log(output.shape);
49366 * // You get [null, 2, 2], with the first dimension as the undetermined batch
49367 * // dimension.
49368 * ```
49369 *
49370 * 3. Invoke directly on `tf.Tensor`s, i.e., concrete values. This constructs
49371 * an `Layer` object internally and calls its `apply` method on the inputs,
49372 * generating a new `tf.Tensor` as the result of the computation. For
49373 * example:
49374 *
49375 * ```js
49376 * const input1 = tf.tensor2d([1, 2, 3, 4], [2, 2]);
49377 * const input2 = tf.tensor2d([10, 20, 30, 40], [2, 2]);
49378 * tf.layers.average([input1, input2]).print();
 * // Gives [[5.5, 11], [16.5, 22]].
 * ```
 */
49382 function average(config) {
49383 if (Array.isArray(config)) {
49384 const layer = new Average({});
49385 return layer.apply(config);
49386 }
49387 else {
49388 return new Average(config);
49389 }
49390 }
49391 class Maximum$1 extends Merge {
49392 constructor(args) {
49393 super(args);
49394 }
49395 mergeFunction(inputs) {
49396 return tidy(() => {
49397 let output = inputs[0];
49398 for (let i = 1; i < inputs.length; ++i) {
49399 output = maximum(output, inputs[i]);
49400 }
49401 return output;
49402 });
49403 }
49404 }
49405 /** @nocollapse */
49406 Maximum$1.className = 'Maximum';
49407 registerClass(Maximum$1);
/**
 * Calculate the element-wise maximum of inputs, which all have the same
 * shape.
 *
 * This function can be invoked in three ways.
 *
 * 1. Construct an instance of the `Maximum` layer by passing no argument or
 *    a single configuration argument. The layer can then be applied to
 *    `tf.SymbolicTensor`s or `tf.Tensor`s:
 *
 * ```js
 * const maximumLayer = tf.layers.maximum();
 * const input1 = tf.input({shape: [2, 2]});
 * const input2 = tf.input({shape: [2, 2]});
 * const output = maximumLayer.apply([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 2]; the first (batch) dimension is undetermined.
 * ```
 *
 * 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. A `Layer`
 *    object is constructed internally and applied to the inputs, yielding
 *    a new `tf.SymbolicTensor`:
 *
 * ```js
 * const input1 = tf.input({shape: [2, 2]});
 * const input2 = tf.input({shape: [2, 2]});
 * const output = tf.layers.maximum([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 2]; the first (batch) dimension is undetermined.
 * ```
 *
 * 3. Invoke directly on `tf.Tensor`s (concrete values), yielding a new
 *    `tf.Tensor` holding the result of the computation:
 *
 * ```js
 * const input1 = tf.tensor2d([1, 20, 3, 40], [2, 2]);
 * const input2 = tf.tensor2d([10, 2, 30, 4], [2, 2]);
 * tf.layers.maximum([input1, input2]).print();
 * // Gives [[10, 20], [30, 40]].
 * ```
 */
function maximum$1(config) {
    if (!Array.isArray(config)) {
        // Configuration object: return the layer itself.
        return new Maximum$1(config);
    }
    // List of inputs: apply a fresh layer directly.
    return new Maximum$1({}).apply(config);
}
class Minimum$1 extends Merge {
    constructor(args) {
        super(args);
    }
    /** Element-wise minimum across all inputs (which share one shape). */
    mergeFunction(inputs) {
        return tidy(() => inputs
            .slice(1)
            .reduce((acc, tensor) => minimum(acc, tensor), inputs[0]));
    }
}
/** @nocollapse */
Minimum$1.className = 'Minimum';
registerClass(Minimum$1);
/**
 * Calculate the element-wise minimum of inputs, which all have the same
 * shape.
 *
 * This function can be invoked in three ways.
 *
 * 1. Construct an instance of the `Minimum` layer by passing no argument or
 *    a single configuration argument. The layer can then be applied to
 *    `tf.SymbolicTensor`s or `tf.Tensor`s:
 *
 * ```js
 * const minimumLayer = tf.layers.minimum();
 * const input1 = tf.input({shape: [2, 2]});
 * const input2 = tf.input({shape: [2, 2]});
 * const output = minimumLayer.apply([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 2]; the first (batch) dimension is undetermined.
 * ```
 *
 * 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. A `Layer`
 *    object is constructed internally and applied to the inputs, yielding
 *    a new `tf.SymbolicTensor`:
 *
 * ```js
 * const input1 = tf.input({shape: [2, 2]});
 * const input2 = tf.input({shape: [2, 2]});
 * const output = tf.layers.minimum([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 2]; the first (batch) dimension is undetermined.
 * ```
 *
 * 3. Invoke directly on `tf.Tensor`s (concrete values), yielding a new
 *    `tf.Tensor` holding the result of the computation:
 *
 * ```js
 * const input1 = tf.tensor2d([1, 20, 3, 40], [2, 2]);
 * const input2 = tf.tensor2d([10, 2, 30, 4], [2, 2]);
 * tf.layers.minimum([input1, input2]).print();
 * // Gives [[1, 2], [3, 4]].
 * ```
 */
function minimum$1(config) {
    if (!Array.isArray(config)) {
        // Configuration object: return the layer itself.
        return new Minimum$1(config);
    }
    // List of inputs: apply a fresh layer directly.
    return new Minimum$1({}).apply(config);
}
/**
 * Layer that concatenates an `Array` of inputs along `axis`. All inputs
 * must have identical shapes except along the concatenation axis.
 */
class Concatenate extends Merge {
    constructor(args) {
        super(args);
        this.DEFAULT_AXIS = -1;
        if (args == null) {
            args = {};
        }
        this.axis = args.axis == null ? this.DEFAULT_AXIS : args.axis;
        this.supportsMasking = true;
        this.reshapeRequired = false;
    }
    build(inputShape) {
        // Used purely for shape validation.
        if (!(Array.isArray(inputShape) && Array.isArray(inputShape[0])) ||
            inputShape.length === 1) {
            throw new ValueError('A `Concatenate` layer should be called on a list of at least 2 ' +
                'inputs');
        }
        let allNoneShape = true;
        for (const shape of inputShape) {
            if (shape != null) {
                allNoneShape = false;
                break;
            }
        }
        if (allNoneShape) {
            // Nothing to validate when every shape is unknown.
            return;
        }
        // Collect the distinct shapes with the concat axis removed; more
        // than one distinct shape means the inputs are incompatible.
        const shapeSet = [];
        for (let i = 0; i < inputShape.length; ++i) {
            const shapeWithoutConcatAxis = inputShape[i].slice();
            shapeWithoutConcatAxis.splice(this.axis, 1);
            let exists = false;
            for (const shape of shapeSet) {
                if (arraysEqual(shape, shapeWithoutConcatAxis)) {
                    exists = true;
                    break;
                }
            }
            if (!exists) {
                shapeSet.push(shapeWithoutConcatAxis);
            }
        }
        if (shapeSet.length > 1) {
            throw new ValueError('A `Concatenate` layer requires inputs with matching shapes ' +
                'except for the concat axis. Got input shapes: ' +
                JSON.stringify(inputShape));
        }
    }
    mergeFunction(inputs) {
        return tidy(() => {
            return concatenate(inputs, this.axis);
        });
    }
    computeOutputShape(inputShape) {
        if (!(Array.isArray(inputShape) && Array.isArray(inputShape[0]))) {
            throw new ValueError('A `Concatenate` layer should be called on a list of inputs.');
        }
        const inputShapes = inputShape;
        const outputShape = inputShapes[0].slice();
        // Resolve a negative concat axis to a non-negative index.
        const axis = this.axis < 0 ? outputShape.length + this.axis : this.axis;
        for (const shape of inputShapes.slice(1)) {
            // If any input's concat dimension is undetermined, so is the
            // output's.
            if (outputShape[axis] == null || shape[axis] == null) {
                outputShape[axis] = null;
                break;
            }
            outputShape[axis] += shape[axis];
        }
        return outputShape;
    }
    computeMask(inputs, mask) {
        if (mask == null) {
            return null;
        }
        if (!Array.isArray(mask)) {
            throw new ValueError('`mask` should be an array for Concatenate');
        }
        if (!Array.isArray(inputs)) {
            throw new ValueError('`inputs` should be an array for Concatenate');
        }
        if (mask.length !== inputs.length) {
            // Bug fix: the error message previously misspelled "length".
            throw new ValueError(`Mismatch in the length of mask (${mask.length}) ` +
                `and the length of inputs (${inputs.length})`);
        }
        return tidy(() => {
            if (mask.every(m => m == null)) {
                // No input carries a mask, so the output is unmasked.
                return null;
            }
            // Bring every mask up to the rank of its input, then concatenate
            // and reduce with logical AND along the last axis.
            const outputMasks = [];
            for (let i = 0; i < inputs.length; ++i) {
                if (mask[i] == null) {
                    // Input is unmasked. Append all 1's to masks.
                    outputMasks.push(cast(onesLike(inputs[i]), 'bool'));
                }
                else if (mask[i].rank < inputs[i].rank) {
                    // Mask is smaller than the input, expand it.
                    outputMasks.push(expandDims(mask[i], -1));
                }
                else {
                    outputMasks.push(mask[i]);
                }
            }
            const concatenatedMasks = concat(outputMasks, this.axis);
            return all(concatenatedMasks, -1, false);
        });
    }
    getConfig() {
        const config = {
            'axis': this.axis,
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}
/** @nocollapse */
Concatenate.className = 'Concatenate';
registerClass(Concatenate);
/**
 * Concatenate an `Array` of inputs.
 *
 * This function can be invoked in three ways.
 *
 * 1. Construct an instance of the `Concatenate` layer by passing no
 *    argument or a single configuration argument. The layer can then be
 *    applied to `tf.SymbolicTensor`s or `tf.Tensor`s:
 *
 * ```js
 * const concatLayer = tf.layers.concatenate();
 * const input1 = tf.input({shape: [2, 3]});
 * const input2 = tf.input({shape: [2, 4]});
 * const output = concatLayer.apply([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 7]: the first (batch) dimension is undetermined,
 * // and the last dimension is the sum of the inputs' last dimensions.
 * ```
 *
 * 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. A `Layer`
 *    object is constructed internally and applied to the inputs, yielding
 *    a new `tf.SymbolicTensor`:
 *
 * ```js
 * const input1 = tf.input({shape: [2, 3]});
 * const input2 = tf.input({shape: [2, 4]});
 * const output = tf.layers.concatenate([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 7]: the first (batch) dimension is undetermined,
 * // and the last dimension is the sum of the inputs' last dimensions.
 * ```
 *
 * 3. Invoke directly on `tf.Tensor`s (concrete values), yielding a new
 *    `tf.Tensor` holding the result of the computation:
 *
 * ```js
 * const input1 = tf.tensor2d([[1, 2], [3, 4]], [2, 2]);
 * const input2 = tf.tensor2d([[10, 20], [30, 40]], [2, 2]);
 * tf.layers.concatenate([input1, input2]).print();
 * // Gives [[1, 2, 10, 20], [3, 4, 30, 40]].
 * ```
 */
function concatenate$1(config) {
    if (!Array.isArray(config)) {
        // Configuration object: return the layer itself.
        return new Concatenate(config);
    }
    // List of inputs: apply a fresh layer directly.
    return new Concatenate({}).apply(config);
}
/**
 * Interpret a potentially negative axis index.
 *
 * For example, given axis = -1 and dim = 3, this function returns 2.
 *
 * @param axis The axis index; may be a positive, zero, or negative integer.
 * @param dim Total number of dimensions, a positive integer.
 * @returns A non-negative axis index equivalent to the input `axis`.
 */
function interpretAxis(axis, dim) {
    // Repeatedly wrap negative indices around; non-negative values pass
    // through unchanged.
    let resolved = axis;
    while (resolved < 0) {
        resolved += dim;
    }
    return resolved;
}
/**
 * Batchwise dot product of two tensors of rank 2 or 3.
 *
 * @param x Left operand; rank must be >= 2 and <= 3.
 * @param y Right operand; rank must be >= 2 and <= 3.
 * @param axes Axis (or pair of axes, one per operand) to contract over.
 *   Defaults to batch-matmul behavior when unspecified.
 * @returns The batchwise dot product; always at least rank 2.
 * @throws NotImplementedError for rank > 3 or complex64 inputs.
 */
function batchDot(x, y, axes) {
    if (x.shape.length > 3 || y.shape.length > 3) {
        throw new NotImplementedError('batchDot is not implemented for tensors of 4D or higher rank yet');
    }
    assert(x.shape.length >= 2, () => `batchDot requires the rank of x to be >= 2, ` +
        `but got ${x.shape.length}`);
    // Bug fix: this assertion previously checked x's rank while its message
    // reported y's rank, so an invalid `y` was never rejected.
    assert(y.shape.length >= 2, () => `batchDot requires the rank of y to be >= 2, ` +
        `but got ${y.shape.length}`);
    if (typeof axes === 'number') {
        axes = [axes, axes];
    }
    if (x.dtype === 'complex64' || y.dtype === 'complex64') {
        throw new NotImplementedError('batchDot is not implemented for complex64-type Tensors yet.');
    }
    const xNDim = x.shape.length;
    const yNDim = y.shape.length;
    if (axes == null) {
        // Behave like batchMatmul by default.
        axes = [xNDim - 1, yNDim - 2];
    }
    const axesArray = axes;
    return tidy(() => {
        // Pad the lower-rank operand with trailing singleton dims so both
        // operands have equal rank; `diff` records how many were added.
        let diff;
        if (xNDim > yNDim) {
            diff = xNDim - yNDim;
            const diffShape = [];
            for (let i = 0; i < diff; ++i) {
                diffShape.push(1);
            }
            y = reshape(y, y.shape.concat(diffShape));
        }
        else if (yNDim > xNDim) {
            diff = yNDim - xNDim;
            const diffShape = [];
            for (let i = 0; i < diff; ++i) {
                diffShape.push(1);
            }
            x = reshape(x, x.shape.concat(diffShape));
        }
        else {
            diff = 0;
        }
        let out;
        if (x.shape.length === 2 && y.shape.length === 2) {
            // 2D x 2D: contract via element-wise multiply followed by sum.
            if (axesArray[0] === axesArray[1]) {
                out = sum$1(mul(x, y), axesArray[0]);
            }
            else {
                out = sum$1(mul(transpose(x, [1, 0]), y), axesArray[1]);
            }
        }
        else {
            // General case: delegate to matMul, transposing operands so
            // that the requested axes are the contracted ones.
            const adjX = axesArray[0] !== x.shape.length - 1;
            const adjY = axesArray[1] === y.shape.length - 1;
            out = matMul(x, y, adjX, adjY);
        }
        // Squeeze out the singleton dims added above, then guarantee an
        // output rank of at least 2.
        if (diff > 0) {
            let idx;
            if (xNDim > yNDim) {
                idx = xNDim + yNDim - 3;
            }
            else {
                idx = xNDim - 1;
            }
            const squeezeAxes = [];
            for (let i = idx; i < idx + diff; ++i) {
                squeezeAxes.push(i);
            }
            out = squeeze(out, squeezeAxes);
        }
        if (out.shape.length === 1) {
            out = expandDims(out, 1);
        }
        return out;
    });
}
/**
 * Layer that computes a dot product between samples in two input tensors,
 * contracting along the configured `axes`. If `normalize` is set, the
 * inputs are L2-normalized along those axes first.
 */
class Dot extends Merge {
    constructor(args) {
        super(args);
        // `axes` may be a single integer (applied to both inputs) or a
        // two-element array with one axis per input.
        this.axes = args.axes;
        this.normalize = args.normalize == null ? false : args.normalize;
        this.supportsMasking = true;
        this.reshapeRequired = false;
    }
    /**
     * Validates the pair of input shapes: exactly 2 inputs, rank <= 3, and
     * matching dimensions along the contraction axes.
     */
    build(inputShape) {
        assert(Array.isArray(inputShape) && inputShape.length === 2 &&
            Array.isArray(inputShape[0]) && Array.isArray(inputShape[1]), () => 'A `Dot` layer should be called on a list of exactly 2 inputs.');
        const shape1 = inputShape[0];
        const shape2 = inputShape[1];
        if (shape1.length > 3 || shape2.length > 3) {
            throw new NotImplementedError('Dot layer does not support tensors of 4D or higher rank yet.');
        }
        const axes = this.interpretAxes(shape1, shape2);
        if (shape1[axes[0]] !== shape2[axes[1]]) {
            throw new ValueError(`Dimension incompatibility: ` +
                `${shape1[axes[0]]} !== ${shape2[axes[1]]}`);
        }
    }
    /** Computes the (optionally normalized) batch dot of the two inputs. */
    mergeFunction(inputs) {
        if (inputs.length !== 2) {
            throw new ValueError('A `Dot` layer must be called on exactly 2 inputs, ' +
                `but received ${inputs.length} input(s).`);
        }
        let x1 = inputs[0];
        let x2 = inputs[1];
        let axes;
        if (!Array.isArray(this.axes)) {
            // Single integer: resolve it (possibly negative) against each
            // input's rank independently.
            axes = [
                interpretAxis(this.axes, x1.shape.length),
                interpretAxis(this.axes, x2.shape.length)
            ];
        }
        else {
            axes = this.axes.map((axis, i) => interpretAxis(axis, inputs[i].shape.length));
        }
        if (this.normalize) {
            // L2-normalize along the dot axes, making the result a cosine
            // proximity between samples.
            x1 = l2Normalize(x1, axes[0]);
            x2 = l2Normalize(x2, axes[1]);
        }
        return batchDot(x1, x2, axes);
    }
    /**
     * Resolves `this.axes` into a non-negative [axis1, axis2] pair.
     * NOTE(review): unlike `mergeFunction`, the array branch here does NOT
     * resolve negative entries — confirm whether callers rely on that.
     */
    interpretAxes(shape1, shape2) {
        let axes;
        if (!Array.isArray(this.axes)) {
            // `this.axes` is a single integer.
            axes = [
                interpretAxis(this.axes, shape1.length),
                interpretAxis(this.axes, shape2.length)
            ];
        }
        else {
            // `this.axes` is an Array of integers.
            axes = this.axes;
        }
        return axes;
    }
    /**
     * Output shape: shape1 and shape2 with their contraction axes removed,
     * and shape2's batch dimension dropped, concatenated together; a
     * trailing 1 is appended if the result would be rank 1.
     */
    computeOutputShape(inputShape) {
        assert(Array.isArray(inputShape) && inputShape.length === 2 &&
            Array.isArray(inputShape[0]) && Array.isArray(inputShape[1]), () => 'A `Dot` layer should be called on a list of exactly 2 inputs.');
        const shape1 = inputShape[0].slice();
        const shape2 = inputShape[1].slice();
        if (shape1.length > 3 || shape2.length > 3) {
            throw new NotImplementedError('Dot layer does not support tensors of 4D or higher rank yet.');
        }
        const axes = this.interpretAxes(shape1, shape2);
        shape1.splice(axes[0], 1);
        shape2.splice(axes[1], 1);
        // Drop shape2's batch dimension.
        shape2.splice(0, 1);
        const outputShape = shape1.concat(shape2);
        if (outputShape.length === 1) {
            outputShape.push(1);
        }
        return outputShape;
    }
    // The dot product does not propagate masks.
    computeMask(inputs, mask) {
        return null;
    }
    getConfig() {
        const config = {
            'axes': this.axes,
            'normalize': this.normalize
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}
/** @nocollapse */
Dot.className = 'Dot';
registerClass(Dot);
49905 // TODO(cais): Add functional interfaces for the merge layers.
49906
49907 /**
49908 * @license
49909 * Copyright 2018 Google LLC
49910 *
49911 * Use of this source code is governed by an MIT-style
49912 * license that can be found in the LICENSE file or at
49913 * https://opensource.org/licenses/MIT.
49914 * =============================================================================
49915 */
/**
 * Layer that adds zero-mean Gaussian noise (stddev `stddev`) to its input
 * during training, and is the identity at inference time.
 */
class GaussianNoise extends Layer {
    constructor(args) {
        super(args);
        this.supportsMasking = true;
        this.stddev = args.stddev;
    }
    computeOutputShape(inputShape) {
        // Additive noise never changes the shape.
        return inputShape;
    }
    getConfig() {
        const config = { stddev: this.stddev };
        Object.assign(config, super.getConfig());
        return config;
    }
    call(inputs, kwargs) {
        return tidy(() => {
            this.invokeCallHook(inputs, kwargs);
            const input = getExactlyOneTensor(inputs);
            // Training phase: add sampled noise; otherwise pass through.
            const addNoise = () => add$1(randomNormal$1(input.shape, 0, this.stddev), input);
            return inTrainPhase(addNoise, () => input, kwargs['training'] || false);
        });
    }
}
/** @nocollapse */
GaussianNoise.className = 'GaussianNoise';
registerClass(GaussianNoise);
/**
 * Layer that multiplies its input by 1-centered Gaussian noise during
 * training (stddev `sqrt(rate / (1 - rate))`), and is the identity at
 * inference time or when `rate` is outside (0, 1).
 */
class GaussianDropout extends Layer {
    constructor(args) {
        super(args);
        this.supportsMasking = true;
        this.rate = args.rate;
    }
    computeOutputShape(inputShape) {
        // Multiplicative noise never changes the shape.
        return inputShape;
    }
    getConfig() {
        const config = { rate: this.rate };
        Object.assign(config, super.getConfig());
        return config;
    }
    call(inputs, kwargs) {
        return tidy(() => {
            this.invokeCallHook(inputs, kwargs);
            const input = getExactlyOneTensor(inputs);
            if (!(this.rate > 0 && this.rate < 1)) {
                // Degenerate rate: the layer is the identity.
                return input;
            }
            const applyNoise = () => {
                const stddev = Math.sqrt(this.rate / (1 - this.rate));
                return mul(input, randomNormal$1(input.shape, 1, stddev));
            };
            return inTrainPhase(applyNoise, () => input, kwargs['training'] || false);
        });
    }
}
/** @nocollapse */
GaussianDropout.className = 'GaussianDropout';
registerClass(GaussianDropout);
/**
 * Applies Alpha Dropout to the input.
 *
 * As it is a regularization layer, it is only active at training time.
 *
 * Alpha Dropout is a `Dropout` that keeps mean and variance of inputs
 * to their original values, in order to ensure the self-normalizing property
 * even after this dropout.
 * Alpha Dropout fits well to Scaled Exponential Linear Units
 * by randomly setting activations to the negative saturation value.
 *
 * Arguments:
 *   - `rate`: float, drop probability (as with `Dropout`).
 *     The multiplicative noise will have
 *     standard deviation `sqrt(rate / (1 - rate))`.
 *   - `noise_shape`: A 1-D `Tensor` of type `int32`, representing the
 *     shape for randomly generated keep/drop flags.
 *
 * Input shape:
 *   Arbitrary. Use the keyword argument `inputShape`
 *   (tuple of integers, does not include the samples axis)
 *   when using this layer as the first layer in a model.
 *
 * Output shape:
 *   Same shape as input.
 *
 * References:
 *   - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
 */
class AlphaDropout extends Layer {
    constructor(args) {
        super(args);
        this.supportsMasking = true;
        // Drop probability in (0, 1); values outside that range make the
        // layer a no-op (see `call`).
        this.rate = args.rate;
        // Optional explicit shape for the keep/drop flags; defaults to the
        // input's shape (see `_getNoiseShape`).
        this.noiseShape = args.noiseShape;
    }
    // Returns the shape used to sample keep/drop flags.
    _getNoiseShape(inputs) {
        return this.noiseShape || getExactlyOneTensor(inputs).shape;
    }
    // Alpha dropout never changes the shape.
    computeOutputShape(inputShape) {
        return inputShape;
    }
    getConfig() {
        const baseConfig = super.getConfig();
        const config = { rate: this.rate };
        Object.assign(config, baseConfig);
        return config;
    }
    call(inputs, kwargs) {
        return tidy(() => {
            if (this.rate < 1 && this.rate > 0) {
                const noiseShape = this._getNoiseShape(inputs);
                const droppedInputs = () => {
                    const input = getExactlyOneTensor(inputs);
                    // Constants from the SELU activation (see the reference
                    // above); `alphaP` is the negative saturation value that
                    // dropped activations are set to.
                    const alpha = 1.6732632423543772848170429916717;
                    const scale = 1.0507009873554804934193349852946;
                    const alphaP = -alpha * scale;
                    // 1.0 where the activation is kept, 0.0 where dropped.
                    let keptIdx = greaterEqual(randomUniform(noiseShape), this.rate);
                    keptIdx = cast$1(keptIdx, 'float32'); // get default dtype.
                    // Get affine transformation params.
                    const a = ((1 - this.rate) * (1 + this.rate * alphaP ** 2)) ** -0.5;
                    const b = -a * alphaP * this.rate;
                    // Apply mask: kept values pass through, dropped values
                    // become `alphaP`; then rescale so mean/variance are
                    // preserved.
                    const x = add$1(mul(input, keptIdx), mul(add$1(keptIdx, -1), alphaP));
                    return add$1(mul(x, a), b);
                };
                return inTrainPhase(droppedInputs, () => getExactlyOneTensor(inputs), kwargs['training'] || false);
            }
            // Degenerate rate: identity. Note this returns `inputs` as
            // received (possibly still wrapped in an array).
            return inputs;
        });
    }
}
/** @nocollapse */
AlphaDropout.className = 'AlphaDropout';
registerClass(AlphaDropout);
50052
50053 /**
50054 * @license
50055 * Copyright 2018 Google LLC
50056 *
50057 * Use of this source code is governed by an MIT-style
50058 * license that can be found in the LICENSE file or at
50059 * https://opensource.org/licenses/MIT.
50060 * =============================================================================
50061 */
/**
 * Applies batch normalization on x given mean, var, beta and gamma.
 *
 * I.e. returns:
 *   `output = (x - mean) / (sqrt(var) + epsilon) * gamma + beta`
 *
 * @param x Input tensor of rank 2, 3, or 4.
 * @param mean Mean of batch.
 * @param variance Variance of batch.
 * @param beta Tensor with which to center the input.
 * @param gamma Tensor by which to scale the input.
 * @param epsilon Fuzz factor.
 * @returns The result of the batch normalization.
 */
function batchNormalization(x, mean, variance, beta, gamma, epsilon = 1e-3) {
    // Dispatch to the rank-specific kernel.
    switch (x.rank) {
        case 2:
            return batchNorm2d(x, mean, variance, beta, gamma, epsilon);
        case 3:
            return batchNorm3d(x, mean, variance, beta, gamma, epsilon);
        case 4:
            return batchNorm4d(x, mean, variance, beta, gamma, epsilon);
        default:
            throw new NotImplementedError(`batchNormalization is not implemented for array of rank ${x.rank} ` +
                `yet`);
    }
}
/**
 * Non-broadcasting batch normalization for use in training (not inference).
 *
 * Normalizes the input to zero mean and unit variance along
 * `reductionAxes`, then scales by `gamma` and shifts by `beta`.
 *
 * @param x Input tensor to be normalized.
 * @param gamma Tensor by which to scale the input.
 * @param beta Tensor by which to center the input.
 * @param reductionAxes Axes over which to normalize.
 * @param epsilon Fuzz factor.
 * @returns An `Array` of three `Tensors`:
 *   [normalized tensor, mean of input, variance of input].
 */
function regularNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon = 1e-3) {
    return tidy(() => {
        const { mean, variance } = moments(x, reductionAxes);
        const normed = batchNormalization(x, mean, variance, beta, gamma, epsilon);
        return [normed, mean, variance];
    });
}
/**
 * Broadcasting batch normalization for use in training (not inference).
 *
 * Normalizes the input to zero mean and unit variance along
 * `reductionAxes`, then scales by `gamma` and shifts by `beta`, reshaping
 * the statistics and parameters so they broadcast against `x`.
 *
 * @param x Input tensor to be normalized.
 * @param gamma Tensor by which to scale the input.
 * @param beta Tensor by which to center the input.
 * @param reductionAxes Axes over which to normalize.
 * @param epsilon Fuzz factor.
 * @returns An `Array` of three `Tensors`:
 *   [normalized tensor, mean of input, variance of input].
 */
function broadcastNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon = 1e-3) {
    return tidy(() => {
        const { mean, variance } = moments(x, reductionAxes);
        // Target shape has 1 at every reduced axis, so the per-axis
        // statistics and parameters broadcast against `x`.
        const targetShape = range$1(0, x.rank).map(axis => reductionAxes.indexOf(axis) !== -1 ? 1 : x.shape[axis]);
        const broadcastMean = reshape(mean, targetShape);
        const broadcastVariance = reshape(variance, targetShape);
        const broadcastGamma = gamma == null ? null : reshape(gamma, targetShape);
        const broadcastBeta = beta == null ? null : reshape(beta, targetShape);
        const normed = batchNormalization(x, broadcastMean, broadcastVariance, broadcastBeta, broadcastGamma, epsilon);
        return [normed, mean, variance];
    });
}
/**
 * Batch normalization for use in training (not inference).
 *
 * Picks the non-broadcasting fast path when the reduction axes are exactly
 * all axes but the last; otherwise falls back to the broadcasting variant.
 *
 * @param x Input tensor to be normalized.
 * @param gamma Tensor by which to scale the input.
 * @param beta Tensor by which to center the input.
 * @param reductionAxes Axes over which to normalize.
 * @param epsilon Fuzz factor.
 * @returns An `Array` of three `Tensors`:
 *   [normalized tensor, mean of input, variance of input].
 */
function normalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon = 1e-3) {
    const isRegular = arraysEqual(reductionAxes.slice().sort(), range$1(0, x.rank - 1));
    return isRegular ?
        regularNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon) :
        broadcastNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon);
}
50178 class BatchNormalization extends Layer {
50179 constructor(args) {
50180 if (args == null) {
50181 args = {};
50182 }
50183 super(args);
50184 this.supportsMasking = true;
50185 this.axis = args.axis == null ? -1 : args.axis;
50186 this.momentum = args.momentum == null ? 0.99 : args.momentum;
50187 this.epsilon = args.epsilon == null ? 1e-3 : args.epsilon;
50188 this.center = args.center == null ? true : args.center;
50189 this.scale = args.scale == null ? true : args.scale;
50190 this.betaInitializer = getInitializer(args.betaInitializer || 'zeros');
50191 this.gammaInitializer = getInitializer(args.gammaInitializer || 'ones');
50192 this.movingMeanInitializer =
50193 getInitializer(args.movingMeanInitializer || 'zeros');
50194 this.movingVarianceInitializer =
50195 getInitializer(args.movingVarianceInitializer || 'ones');
50196 this.betaConstraint = getConstraint(args.betaConstraint);
50197 this.gammaConstraint = getConstraint(args.gammaConstraint);
50198 this.betaRegularizer = getRegularizer(args.betaRegularizer);
50199 this.gammaRegularizer = getRegularizer(args.gammaRegularizer);
50200 }
50201 build(inputShape) {
50202 inputShape = getExactlyOneShape(inputShape);
50203 const axis = this.axis >= 0 ? this.axis : (this.axis + inputShape.length);
50204 const dim = inputShape[axis];
50205 if (dim == null) {
50206 throw new ValueError(`Axis ${axis} of input tensor should have a defined dimension but ` +
50207 `the layer received an input with shape ` +
50208 `${JSON.stringify(inputShape)}.`);
50209 }
50210 this.inputSpec =
50211 [new InputSpec({ ndim: inputShape.length, axes: { [axis]: dim } })];
50212 const shape = [dim];
50213 if (this.scale) {
50214 this.gamma = this.addWeight('gamma', shape, null, this.gammaInitializer, this.gammaRegularizer, true, this.gammaConstraint);
50215 }
50216 if (this.center) {
50217 this.beta = this.addWeight('beta', shape, null, this.betaInitializer, this.betaRegularizer, true, this.betaConstraint);
50218 }
50219 this.movingMean = this.addWeight('moving_mean', shape, null, this.movingMeanInitializer, null, false);
50220 this.movingVariance = this.addWeight('moving_variance', shape, null, this.movingVarianceInitializer, null, false);
50221 this.built = true;
50222 }
50223 call(inputs, kwargs) {
50224 return tidy(() => {
50225 const training = kwargs['training'] == null ? false : kwargs['training'];
50226 const input = getExactlyOneTensor(inputs);
50227 const inputShape = input.shape;
50228 const ndim = inputShape.length;
50229 const reductionAxes = range$1(0, ndim);
50230 const axis = this.axis >= 0 ? this.axis : (this.axis + ndim);
50231 reductionAxes.splice(axis, 1);
50232 const broadcastShape = pyListRepeat(1, ndim);
50233 broadcastShape[axis] = inputShape[axis];
50234 const sortedReductionAxes = reductionAxes.slice();
50235 sortedReductionAxes.sort();
50236 const needsBroadcasting = !arraysEqual(sortedReductionAxes, range$1(0, ndim).slice(0, ndim - 1));
50237 const normalizeInference = () => {
50238 if (needsBroadcasting) {
50239 const broadcastMovingMean = reshape(this.movingMean.read(), broadcastShape);
50240 const broadcastMovingVariance = reshape(this.movingVariance.read(), broadcastShape);
50241 const broadcastBeta = this.center ? reshape(this.beta.read(), broadcastShape) : null;
50242 const broadcastGamma = this.scale ? reshape(this.gamma.read(), broadcastShape) : null;
50243 return batchNormalization(input, broadcastMovingMean, broadcastMovingVariance, broadcastBeta, broadcastGamma, this.epsilon);
50244 }
50245 else {
50246 return batchNormalization(input, this.movingMean.read(), this.movingVariance.read(), this.beta == null ? null : this.beta.read(), this.gamma == null ? null : this.gamma.read(), this.epsilon);
50247 }
50248 };
50249 if (!training) {
50250 return normalizeInference();
50251 }
50252 const [normedTraining, mean, variance] = normalizeBatchInTraining(input, this.gamma.read(), this.beta.read(), reductionAxes, this.epsilon);
50253 const doMovingAverage = (variable, value, momentum) => {
50254 tidy(() => {
50255 const decay = 1 - momentum;
50256 const origValue = variable.read();
50257 const updateDelta = mul(sub(origValue, value), decay);
50258 variable.write(sub(origValue, updateDelta));
50259 });
50260 };
50261 // Perform updates to moving mean and moving variance for training.
50262 // Porting Note: In PyKeras, these updates to `movingMean` and
50263 // `movingAverage` are done as a deferred Graph, added to the `Layer`'s
50264 // `update`s using the `add_update()` method. Here we do it imperatively
50265 // and encapsulate the updates in a function that is invoked
50266 // immediately.
50267 const updateMovingMeanAndVariance = () => {
50268 doMovingAverage(this.movingMean, mean, this.momentum);
50269 doMovingAverage(this.movingVariance, variance, this.momentum);
50270 };
50271 updateMovingMeanAndVariance();
50272 return normedTraining;
50273 });
50274 }
50275 getConfig() {
50276 const config = {
50277 axis: this.axis,
50278 momentum: this.momentum,
50279 epsilon: this.epsilon,
50280 center: this.center,
50281 scale: this.scale,
50282 betaInitializer: serializeInitializer(this.betaInitializer),
50283 gammaInitializer: serializeInitializer(this.gammaInitializer),
50284 movingMeanInitializer: serializeInitializer(this.movingMeanInitializer),
50285 movingVarianceInitializer: serializeInitializer(this.movingVarianceInitializer),
50286 betaRegularizer: serializeRegularizer(this.betaRegularizer),
50287 gammaRegularizer: serializeRegularizer(this.gammaRegularizer),
50288 betaConstraint: serializeConstraint(this.betaConstraint),
50289 gammaConstraint: serializeConstraint(this.gammaConstraint)
50290 };
50291 const baseConfig = super.getConfig();
50292 Object.assign(config, baseConfig);
50293 return config;
50294 }
50295 }
    // Register under the Keras class name so serialized models containing this
    // layer can be deserialized by name.
    /** @nocollapse */
    BatchNormalization.className = 'BatchNormalization';
    registerClass(BatchNormalization);
50299 class LayerNormalization extends Layer {
50300 constructor(args) {
50301 if (args == null) {
50302 args = {};
50303 }
50304 super(args);
50305 this.axis = args.axis == null ? -1 : args.axis;
50306 if (typeof this.axis === 'number') {
50307 if (!Number.isInteger(this.axis)) {
50308 throw new Error(`Expected axis to be an integer, but received ${this.axis}`);
50309 }
50310 }
50311 else if (Array.isArray(this.axis)) {
50312 for (const axis of this.axis) {
50313 if (!Number.isInteger(axis)) {
50314 throw new Error(`Expected axis to be an array of integers, ` +
50315 `but received ${JSON.stringify(this.axis)}`);
50316 }
50317 }
50318 }
50319 else {
50320 throw new Error(`Expected axis to be an integer or an array of integers, ` +
50321 `but received ${JSON.stringify(this.axis)}`);
50322 }
50323 this.epsilon = args.epsilon == null ? 1e-3 : args.epsilon;
50324 this.center = args.center == null ? true : args.center;
50325 this.scale = args.scale == null ? true : args.scale;
50326 this.betaInitializer = getInitializer(args.betaInitializer || 'zeros');
50327 this.gammaInitializer = getInitializer(args.gammaInitializer || 'ones');
50328 this.betaRegularizer = getRegularizer(args.betaRegularizer);
50329 this.gammaRegularizer = getRegularizer(args.gammaRegularizer);
50330 this.supportsMasking = true;
50331 }
50332 build(inputShape) {
50333 inputShape = getExactlyOneShape(inputShape);
50334 const nDims = inputShape.length;
50335 // Convert axis to array and resolve negatives.
50336 if (typeof this.axis === 'number') {
50337 this.axis = [this.axis];
50338 }
50339 for (let i = 0; i < this.axis.length; ++i) {
50340 if (this.axis[i] < 0) {
50341 this.axis[i] += nDims;
50342 }
50343 }
50344 // Further validate axes.
50345 for (const axis of this.axis) {
50346 if (axis < 0 || axis >= nDims) {
50347 throw new Error(`Invalid axis: ${axis}`);
50348 }
50349 }
50350 if (this.axis.length !== unique$1(this.axis).length) {
50351 throw new Error(`Found duplicate axes in: ${this.axis}`);
50352 }
50353 const paramShape = this.axis.map(axis => inputShape[axis]);
50354 const trainable = true;
50355 if (this.scale) {
50356 this.gamma = this.addWeight('gamma', paramShape, 'float32', this.gammaInitializer, this.gammaRegularizer, trainable);
50357 }
50358 else {
50359 this.gamma = null;
50360 }
50361 if (this.center) {
50362 this.beta = this.addWeight('beta', paramShape, 'float32', this.betaInitializer, this.betaRegularizer, trainable);
50363 }
50364 else {
50365 this.beta = null;
50366 }
50367 this.built = true;
50368 }
50369 call(inputs, kwargs) {
50370 const input = getExactlyOneTensor(inputs);
50371 const inputShape = input.shape;
50372 const nDims = inputShape.length;
50373 return tidy(() => {
50374 const keepDims = true;
50375 let { mean, variance } = moments(input, this.axis, keepDims);
50376 const broadcastShape = pyListRepeat(1, nDims);
50377 for (const dim of this.axis) {
50378 broadcastShape[dim] = inputShape[dim];
50379 }
50380 const broadcast = (v) => {
50381 if (v != null && v.shape.length !== nDims) {
50382 return reshape(v, broadcastShape);
50383 }
50384 else {
50385 return v;
50386 }
50387 };
50388 let scale = this.scale ? broadcast(this.gamma.read()) : null;
50389 let offset = this.center ? broadcast(this.beta.read()) : null;
50390 // TODO(https://github.com/tensorflow/tfjs/issues/2120): The tiling below
50391 // is a workaround for the limitation of core's batchNormalization?d don't
50392 // support broadcasting in their gradients. In addition, the tiling is
50393 // necessary to ensure correctness on the browser CPU backend regardless
50394 // of forward or backward computation. Remove this workaround once the
50395 // limitation is addressed. See .
50396 const momentsTiling = [];
50397 const scaleOffsetTiling = [];
50398 for (let i = 0; i < nDims; ++i) {
50399 if (this.axis.indexOf(i) !== -1) {
50400 momentsTiling.push(inputShape[i]);
50401 scaleOffsetTiling.push(1);
50402 }
50403 else {
50404 momentsTiling.push(1);
50405 scaleOffsetTiling.push(inputShape[i]);
50406 }
50407 }
50408 mean = tile(mean, momentsTiling);
50409 variance = tile(variance, momentsTiling);
50410 if (scale != null) {
50411 scale = tile(scale, scaleOffsetTiling);
50412 }
50413 if (offset != null) {
50414 offset = tile(offset, scaleOffsetTiling);
50415 }
50416 return batchNormalization(input, mean, variance, offset, scale, this.epsilon);
50417 });
50418 }
50419 getConfig() {
50420 const config = {
50421 axis: this.axis,
50422 epsilon: this.epsilon,
50423 center: this.center,
50424 scale: this.scale,
50425 betaInitializer: serializeInitializer(this.betaInitializer),
50426 gammaInitializer: serializeInitializer(this.gammaInitializer),
50427 betaRegularizer: serializeRegularizer(this.betaRegularizer),
50428 gammaRegularizer: serializeRegularizer(this.gammaRegularizer)
50429 };
50430 const baseConfig = super.getConfig();
50431 Object.assign(config, baseConfig);
50432 return config;
50433 }
50434 }
50435 /** @nocollapse */
50436 LayerNormalization.className = 'LayerNormalization';
50437 registerClass(LayerNormalization);
50438
50439 /**
50440 * @license
50441 * Copyright 2018 Google LLC
50442 *
50443 * Use of this source code is governed by an MIT-style
50444 * license that can be found in the LICENSE file or at
50445 * https://opensource.org/licenses/MIT.
50446 * =============================================================================
50447 */
50448 /**
50449 * Pads the middle dimension of a 3D tensor.
50450 *
50451 * @param x Input `tf.Tensor` to be padded.
50452 * @param padding `Array` of 2 integers, how many zeros to add at the start and
50453 * end of the middle dimension (i.e., dimension 1).
50454 * @return A padded 3D `tf.Tensor`.
50455 */
50456 function temporalPadding(x, padding) {
50457 return tidy(() => {
50458 if (x.rank !== 3) {
50459 throw new ValueError(`temporalPadding expects input tensor to be 3-D, but received a ` +
50460 `${x.rank}-D tensor.`);
50461 }
50462 if (padding == null) {
50463 padding = [1, 1];
50464 }
50465 if (padding.length !== 2) {
50466 throw new ValueError(`temporalPadding expects input padding pattern to be a length-2 ` +
50467 `array, but received a length-${padding.length} array.`);
50468 }
50469 const pattern = [[0, 0], padding, [0, 0]];
50470 return pad(x, pattern);
50471 });
50472 }
50473 /**
50474 * Pads the 2nd and 3rd dimensions of a 4D tensor.
50475 *
50476 * @param x Input `tf.Tensor` to be padded.
50477 * @param padding `Array` of two `Array`s, each of which is an `Array` of two
50478 * integers. The amount of padding at the beginning and end of the 2nd and 3rd
50479 * dimensions, respectively.
50480 * @param dataFormat 'channelsLast' (default) or 'channelsFirst'.
50481 * @return Padded 4D `tf.Tensor`.
50482 */
50483 function spatial2dPadding(x, padding, dataFormat) {
50484 return tidy(() => {
50485 if (x.rank !== 4) {
50486 throw new ValueError(`temporalPadding expects input tensor to be 4-D, but received a ` +
50487 `${x.rank}-D tensor.`);
50488 }
50489 if (padding == null) {
50490 padding = [[1, 1], [1, 1]];
50491 }
50492 if (padding.length !== 2 || padding[0].length !== 2 ||
50493 padding[1].length !== 2) {
50494 throw new ValueError('spatial2dPadding expects `padding` to be an Array of two Arrays, ' +
50495 'each of which is an Array of two integers.');
50496 }
50497 if (dataFormat == null) {
50498 dataFormat = imageDataFormat();
50499 }
50500 if (dataFormat !== 'channelsLast' && dataFormat !== 'channelsFirst') {
50501 throw new ValueError(`Unknown data format: ${dataFormat}. ` +
50502 `Supported data formats are 'channelsLast' and 'channelsFirst.`);
50503 }
50504 let pattern;
50505 if (dataFormat === 'channelsFirst') {
50506 pattern = [[0, 0], [0, 0], padding[0], padding[1]];
50507 }
50508 else {
50509 pattern = [[0, 0], padding[0], padding[1], [0, 0]];
50510 }
50511 return pad(x, pattern);
50512 });
50513 }
50514 class ZeroPadding2D extends Layer {
50515 constructor(args) {
50516 if (args == null) {
50517 args = {};
50518 }
50519 super(args);
50520 this.dataFormat =
50521 args.dataFormat == null ? imageDataFormat() : args.dataFormat;
50522 // TODO(cais): Maybe refactor the following logic surrounding `padding`
50523 // into a helper method.
50524 if (args.padding == null) {
50525 this.padding = [[1, 1], [1, 1]];
50526 }
50527 else if (typeof args.padding === 'number') {
50528 this.padding =
50529 [[args.padding, args.padding], [args.padding, args.padding]];
50530 }
50531 else {
50532 args.padding = args.padding;
50533 if (args.padding.length !== 2) {
50534 throw new ValueError(`ZeroPadding2D expects padding to be a length-2 array, but ` +
50535 `received a length-${args.padding.length} array.`);
50536 }
50537 let heightPadding;
50538 let widthPadding;
50539 if (typeof args.padding[0] === 'number') {
50540 heightPadding = [args.padding[0], args.padding[0]];
50541 widthPadding = [args.padding[1], args.padding[1]];
50542 }
50543 else {
50544 args.padding = args.padding;
50545 if (args.padding[0].length !== 2) {
50546 throw new ValueError(`ZeroPadding2D expects height padding to be a length-2 array, ` +
50547 `but received a length-${args.padding[0].length} array.`);
50548 }
50549 heightPadding = args.padding[0];
50550 if (args.padding[1].length !== 2) {
50551 throw new ValueError(`ZeroPadding2D expects width padding to be a length-2 array, ` +
50552 `but received a length-${args.padding[1].length} array.`);
50553 }
50554 widthPadding = args.padding[1];
50555 }
50556 this.padding = [heightPadding, widthPadding];
50557 }
50558 this.inputSpec = [new InputSpec({ ndim: 4 })];
50559 }
50560 computeOutputShape(inputShape) {
50561 inputShape = getExactlyOneShape(inputShape);
50562 let rows;
50563 let cols;
50564 if (this.dataFormat === 'channelsFirst') {
50565 if (inputShape[2] != null && inputShape[2] >= 0) {
50566 rows = inputShape[2] + this.padding[0][0] + this.padding[0][1];
50567 }
50568 else {
50569 rows = null;
50570 }
50571 if (inputShape[3] != null && inputShape[3] >= 0) {
50572 cols = inputShape[3] + this.padding[1][0] + this.padding[1][1];
50573 }
50574 else {
50575 cols = null;
50576 }
50577 return [inputShape[0], inputShape[1], rows, cols];
50578 }
50579 else {
50580 if (inputShape[1] != null && inputShape[1] >= 0) {
50581 rows = inputShape[1] + this.padding[0][0] + this.padding[0][1];
50582 }
50583 else {
50584 rows = null;
50585 }
50586 if (inputShape[2] != null && inputShape[2] >= 0) {
50587 cols = inputShape[2] + this.padding[1][0] + this.padding[1][1];
50588 }
50589 else {
50590 cols = null;
50591 }
50592 return [inputShape[0], rows, cols, inputShape[3]];
50593 }
50594 }
50595 call(inputs, kwargs) {
50596 return tidy(() => spatial2dPadding(getExactlyOneTensor(inputs), this.padding, this.dataFormat));
50597 }
50598 getConfig() {
50599 const config = {
50600 padding: this.padding,
50601 dataFormat: this.dataFormat,
50602 };
50603 const baseConfig = super.getConfig();
50604 Object.assign(config, baseConfig);
50605 return config;
50606 }
50607 }
50608 /** @nocollapse */
50609 ZeroPadding2D.className = 'ZeroPadding2D';
50610 registerClass(ZeroPadding2D);
50611
50612 /**
50613 * @license
50614 * Copyright 2018 Google LLC
50615 *
50616 * Use of this source code is governed by an MIT-style
50617 * license that can be found in the LICENSE file or at
50618 * https://opensource.org/licenses/MIT.
50619 * =============================================================================
50620 */
50621 /**
50622 * 2D pooling.
50623 * @param x
50624 * @param poolSize
50625 * @param stridesdes strides. Defaults to [1, 1].
50626 * @param padding padding. Defaults to 'valid'.
50627 * @param dataFormat data format. Defaults to 'channelsLast'.
50628 * @param poolMode Mode of pooling. Defaults to 'max'.
50629 * @returns Result of the 2D pooling.
50630 */
50631 function pool2d(x, poolSize, strides, padding, dataFormat, poolMode) {
50632 return tidy(() => {
50633 checkDataFormat(dataFormat);
50634 checkPoolMode(poolMode);
50635 checkPaddingMode(padding);
50636 if (strides == null) {
50637 strides = [1, 1];
50638 }
50639 if (padding == null) {
50640 padding = 'valid';
50641 }
50642 if (dataFormat == null) {
50643 dataFormat = imageDataFormat();
50644 }
50645 if (poolMode == null) {
50646 poolMode = 'max';
50647 }
50648 // TODO(cais): Remove the preprocessing step once deeplearn.js supports
50649 // dataFormat as an input argument.
50650 x = preprocessConv2DInput(x, dataFormat); // x is NHWC after preprocessing.
50651 let y;
50652 const paddingString = (padding === 'same') ? 'same' : 'valid';
50653 if (poolMode === 'max') {
50654 // TODO(cais): Rank check?
50655 y = maxPool(x, poolSize, strides, paddingString);
50656 }
50657 else { // 'avg'
50658 // TODO(cais): Check the dtype and rank of x and give clear error message
50659 // if those are incorrect.
50660 y = avgPool(
50661 // TODO(cais): Rank check?
50662 x, poolSize, strides, paddingString);
50663 }
50664 if (dataFormat === 'channelsFirst') {
50665 y = transpose(y, [0, 3, 1, 2]); // NHWC -> NCHW.
50666 }
50667 return y;
50668 });
50669 }
50670 /**
50671 * 3D pooling.
50672 * @param x
50673 * @param poolSize. Default to [1, 1, 1].
50674 * @param strides strides. Defaults to [1, 1, 1].
50675 * @param padding padding. Defaults to 'valid'.
50676 * @param dataFormat data format. Defaults to 'channelsLast'.
50677 * @param poolMode Mode of pooling. Defaults to 'max'.
50678 * @returns Result of the 3D pooling.
50679 */
50680 function pool3d(x, poolSize, strides, padding, dataFormat, poolMode) {
50681 return tidy(() => {
50682 checkDataFormat(dataFormat);
50683 checkPoolMode(poolMode);
50684 checkPaddingMode(padding);
50685 if (strides == null) {
50686 strides = [1, 1, 1];
50687 }
50688 if (padding == null) {
50689 padding = 'valid';
50690 }
50691 if (dataFormat == null) {
50692 dataFormat = imageDataFormat();
50693 }
50694 if (poolMode == null) {
50695 poolMode = 'max';
50696 }
50697 // x is NDHWC after preprocessing.
50698 x = preprocessConv3DInput(x, dataFormat);
50699 let y;
50700 const paddingString = (padding === 'same') ? 'same' : 'valid';
50701 if (poolMode === 'max') {
50702 y = maxPool3d(x, poolSize, strides, paddingString);
50703 }
50704 else { // 'avg'
50705 y = avgPool3d(x, poolSize, strides, paddingString);
50706 }
50707 if (dataFormat === 'channelsFirst') {
50708 y = transpose(y, [0, 4, 1, 2, 3]); // NDHWC -> NCDHW.
50709 }
50710 return y;
50711 });
50712 }
50713 /**
50714 * Abstract class for different pooling 1D layers.
50715 */
50716 class Pooling1D extends Layer {
50717 /**
50718 *
50719 * @param args Parameters for the Pooling layer.
50720 *
50721 * config.poolSize defaults to 2.
50722 */
50723 constructor(args) {
50724 if (args.poolSize == null) {
50725 args.poolSize = 2;
50726 }
50727 super(args);
50728 if (typeof args.poolSize === 'number') {
50729 this.poolSize = [args.poolSize];
50730 }
50731 else if (Array.isArray(args.poolSize) &&
50732 args.poolSize.length === 1 &&
50733 typeof args.poolSize[0] === 'number') {
50734 this.poolSize = args.poolSize;
50735 }
50736 else {
50737 throw new ValueError(`poolSize for 1D convolutional layer must be a number or an ` +
50738 `Array of a single number, but received ` +
50739 `${JSON.stringify(args.poolSize)}`);
50740 }
50741 assertPositiveInteger(this.poolSize, 'poolSize');
50742 if (args.strides == null) {
50743 this.strides = this.poolSize;
50744 }
50745 else {
50746 if (typeof args.strides === 'number') {
50747 this.strides = [args.strides];
50748 }
50749 else if (Array.isArray(args.strides) &&
50750 args.strides.length === 1 &&
50751 typeof args.strides[0] === 'number') {
50752 this.strides = args.strides;
50753 }
50754 else {
50755 throw new ValueError(`strides for 1D convolutional layer must be a number or an ` +
50756 `Array of a single number, but received ` +
50757 `${JSON.stringify(args.strides)}`);
50758 }
50759 }
50760 assertPositiveInteger(this.strides, 'strides');
50761 this.padding = args.padding == null ? 'valid' : args.padding;
50762 checkPaddingMode(this.padding);
50763 this.inputSpec = [new InputSpec({ ndim: 3 })];
50764 }
50765 computeOutputShape(inputShape) {
50766 inputShape = getExactlyOneShape(inputShape);
50767 const length = convOutputLength(inputShape[1], this.poolSize[0], this.padding, this.strides[0]);
50768 return [inputShape[0], length, inputShape[2]];
50769 }
50770 call(inputs, kwargs) {
50771 return tidy(() => {
50772 this.invokeCallHook(inputs, kwargs);
50773 // Add dummy last dimension.
50774 inputs = expandDims$1(getExactlyOneTensor(inputs), 2);
50775 const output = this.poolingFunction(getExactlyOneTensor(inputs), [this.poolSize[0], 1], [this.strides[0], 1], this.padding, 'channelsLast');
50776 // Remove dummy last dimension.
50777 return squeeze(output, [2]);
50778 });
50779 }
50780 getConfig() {
50781 const config = {
50782 poolSize: this.poolSize,
50783 padding: this.padding,
50784 strides: this.strides,
50785 };
50786 const baseConfig = super.getConfig();
50787 Object.assign(config, baseConfig);
50788 return config;
50789 }
50790 }
50791 class MaxPooling1D extends Pooling1D {
50792 constructor(args) {
50793 super(args);
50794 }
50795 poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
50796 checkDataFormat(dataFormat);
50797 checkPaddingMode(padding);
50798 return pool2d(inputs, poolSize, strides, padding, dataFormat, 'max');
50799 }
50800 }
50801 /** @nocollapse */
50802 MaxPooling1D.className = 'MaxPooling1D';
50803 registerClass(MaxPooling1D);
50804 class AveragePooling1D extends Pooling1D {
50805 constructor(args) {
50806 super(args);
50807 }
50808 poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
50809 checkDataFormat(dataFormat);
50810 checkPaddingMode(padding);
50811 return pool2d(inputs, poolSize, strides, padding, dataFormat, 'avg');
50812 }
50813 }
50814 /** @nocollapse */
50815 AveragePooling1D.className = 'AveragePooling1D';
50816 registerClass(AveragePooling1D);
50817 /**
50818 * Abstract class for different pooling 2D layers.
50819 */
50820 class Pooling2D extends Layer {
50821 constructor(args) {
50822 if (args.poolSize == null) {
50823 args.poolSize = [2, 2];
50824 }
50825 super(args);
50826 this.poolSize = Array.isArray(args.poolSize) ?
50827 args.poolSize :
50828 [args.poolSize, args.poolSize];
50829 if (args.strides == null) {
50830 this.strides = this.poolSize;
50831 }
50832 else if (Array.isArray(args.strides)) {
50833 if (args.strides.length !== 2) {
50834 throw new ValueError(`If the strides property of a 2D pooling layer is an Array, ` +
50835 `it is expected to have a length of 2, but received length ` +
50836 `${args.strides.length}.`);
50837 }
50838 this.strides = args.strides;
50839 }
50840 else {
50841 // `config.strides` is a number.
50842 this.strides = [args.strides, args.strides];
50843 }
50844 assertPositiveInteger(this.poolSize, 'poolSize');
50845 assertPositiveInteger(this.strides, 'strides');
50846 this.padding = args.padding == null ? 'valid' : args.padding;
50847 this.dataFormat =
50848 args.dataFormat == null ? 'channelsLast' : args.dataFormat;
50849 checkDataFormat(this.dataFormat);
50850 checkPaddingMode(this.padding);
50851 this.inputSpec = [new InputSpec({ ndim: 4 })];
50852 }
50853 computeOutputShape(inputShape) {
50854 inputShape = getExactlyOneShape(inputShape);
50855 let rows = this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1];
50856 let cols = this.dataFormat === 'channelsFirst' ? inputShape[3] : inputShape[2];
50857 rows =
50858 convOutputLength(rows, this.poolSize[0], this.padding, this.strides[0]);
50859 cols =
50860 convOutputLength(cols, this.poolSize[1], this.padding, this.strides[1]);
50861 if (this.dataFormat === 'channelsFirst') {
50862 return [inputShape[0], inputShape[1], rows, cols];
50863 }
50864 else {
50865 return [inputShape[0], rows, cols, inputShape[3]];
50866 }
50867 }
50868 call(inputs, kwargs) {
50869 return tidy(() => {
50870 this.invokeCallHook(inputs, kwargs);
50871 return this.poolingFunction(getExactlyOneTensor(inputs), this.poolSize, this.strides, this.padding, this.dataFormat);
50872 });
50873 }
50874 getConfig() {
50875 const config = {
50876 poolSize: this.poolSize,
50877 padding: this.padding,
50878 strides: this.strides,
50879 dataFormat: this.dataFormat
50880 };
50881 const baseConfig = super.getConfig();
50882 Object.assign(config, baseConfig);
50883 return config;
50884 }
50885 }
50886 class MaxPooling2D extends Pooling2D {
50887 constructor(args) {
50888 super(args);
50889 }
50890 poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
50891 checkDataFormat(dataFormat);
50892 checkPaddingMode(padding);
50893 return pool2d(inputs, poolSize, strides, padding, dataFormat, 'max');
50894 }
50895 }
50896 /** @nocollapse */
50897 MaxPooling2D.className = 'MaxPooling2D';
50898 registerClass(MaxPooling2D);
50899 class AveragePooling2D extends Pooling2D {
50900 constructor(args) {
50901 super(args);
50902 }
50903 poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
50904 checkDataFormat(dataFormat);
50905 checkPaddingMode(padding);
50906 return pool2d(inputs, poolSize, strides, padding, dataFormat, 'avg');
50907 }
50908 }
50909 /** @nocollapse */
50910 AveragePooling2D.className = 'AveragePooling2D';
50911 registerClass(AveragePooling2D);
50912 /**
50913 * Abstract class for different pooling 3D layers.
50914 */
50915 class Pooling3D extends Layer {
50916 constructor(args) {
50917 if (args.poolSize == null) {
50918 args.poolSize = [2, 2, 2];
50919 }
50920 super(args);
50921 this.poolSize = Array.isArray(args.poolSize) ?
50922 args.poolSize :
50923 [args.poolSize, args.poolSize, args.poolSize];
50924 if (args.strides == null) {
50925 this.strides = this.poolSize;
50926 }
50927 else if (Array.isArray(args.strides)) {
50928 if (args.strides.length !== 3) {
50929 throw new ValueError(`If the strides property of a 3D pooling layer is an Array, ` +
50930 `it is expected to have a length of 3, but received length ` +
50931 `${args.strides.length}.`);
50932 }
50933 this.strides = args.strides;
50934 }
50935 else {
50936 // `config.strides` is a number.
50937 this.strides = [args.strides, args.strides, args.strides];
50938 }
50939 assertPositiveInteger(this.poolSize, 'poolSize');
50940 assertPositiveInteger(this.strides, 'strides');
50941 this.padding = args.padding == null ? 'valid' : args.padding;
50942 this.dataFormat =
50943 args.dataFormat == null ? 'channelsLast' : args.dataFormat;
50944 checkDataFormat(this.dataFormat);
50945 checkPaddingMode(this.padding);
50946 this.inputSpec = [new InputSpec({ ndim: 5 })];
50947 }
50948 computeOutputShape(inputShape) {
50949 inputShape = getExactlyOneShape(inputShape);
50950 let depths = this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1];
50951 let rows = this.dataFormat === 'channelsFirst' ? inputShape[3] : inputShape[2];
50952 let cols = this.dataFormat === 'channelsFirst' ? inputShape[4] : inputShape[3];
50953 depths = convOutputLength(depths, this.poolSize[0], this.padding, this.strides[0]);
50954 rows =
50955 convOutputLength(rows, this.poolSize[1], this.padding, this.strides[1]);
50956 cols =
50957 convOutputLength(cols, this.poolSize[2], this.padding, this.strides[2]);
50958 if (this.dataFormat === 'channelsFirst') {
50959 return [inputShape[0], inputShape[1], depths, rows, cols];
50960 }
50961 else {
50962 return [inputShape[0], depths, rows, cols, inputShape[4]];
50963 }
50964 }
50965 call(inputs, kwargs) {
50966 return tidy(() => {
50967 this.invokeCallHook(inputs, kwargs);
50968 return this.poolingFunction(getExactlyOneTensor(inputs), this.poolSize, this.strides, this.padding, this.dataFormat);
50969 });
50970 }
50971 getConfig() {
50972 const config = {
50973 poolSize: this.poolSize,
50974 padding: this.padding,
50975 strides: this.strides,
50976 dataFormat: this.dataFormat
50977 };
50978 const baseConfig = super.getConfig();
50979 Object.assign(config, baseConfig);
50980 return config;
50981 }
50982 }
50983 class MaxPooling3D extends Pooling3D {
50984 constructor(args) {
50985 super(args);
50986 }
50987 poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
50988 checkDataFormat(dataFormat);
50989 checkPaddingMode(padding);
50990 return pool3d(inputs, poolSize, strides, padding, dataFormat, 'max');
50991 }
50992 }
50993 /** @nocollapse */
50994 MaxPooling3D.className = 'MaxPooling3D';
50995 registerClass(MaxPooling3D);
50996 class AveragePooling3D extends Pooling3D {
50997 constructor(args) {
50998 super(args);
50999 }
51000 poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
51001 checkDataFormat(dataFormat);
51002 checkPaddingMode(padding);
51003 return pool3d(inputs, poolSize, strides, padding, dataFormat, 'avg');
51004 }
51005 }
51006 /** @nocollapse */
51007 AveragePooling3D.className = 'AveragePooling3D';
51008 registerClass(AveragePooling3D);
51009 /**
51010 * Abstract class for different global pooling 1D layers.
51011 */
51012 class GlobalPooling1D extends Layer {
51013 constructor(args) {
51014 super(args);
51015 this.inputSpec = [new InputSpec({ ndim: 3 })];
51016 }
51017 computeOutputShape(inputShape) {
51018 return [inputShape[0], inputShape[2]];
51019 }
51020 call(inputs, kwargs) {
51021 throw new NotImplementedError();
51022 }
51023 }
51024 class GlobalAveragePooling1D extends GlobalPooling1D {
51025 constructor(args) {
51026 super(args || {});
51027 }
51028 call(inputs, kwargs) {
51029 return tidy(() => {
51030 const input = getExactlyOneTensor(inputs);
51031 return mean(input, 1);
51032 });
51033 }
51034 }
51035 /** @nocollapse */
51036 GlobalAveragePooling1D.className = 'GlobalAveragePooling1D';
51037 registerClass(GlobalAveragePooling1D);
51038 class GlobalMaxPooling1D extends GlobalPooling1D {
51039 constructor(args) {
51040 super(args || {});
51041 }
51042 call(inputs, kwargs) {
51043 return tidy(() => {
51044 const input = getExactlyOneTensor(inputs);
51045 return max(input, 1);
51046 });
51047 }
51048 }
51049 /** @nocollapse */
51050 GlobalMaxPooling1D.className = 'GlobalMaxPooling1D';
51051 registerClass(GlobalMaxPooling1D);
51052 /**
51053 * Abstract class for different global pooling 2D layers.
51054 */
51055 class GlobalPooling2D extends Layer {
51056 constructor(args) {
51057 super(args);
51058 this.dataFormat =
51059 args.dataFormat == null ? 'channelsLast' : args.dataFormat;
51060 checkDataFormat(this.dataFormat);
51061 this.inputSpec = [new InputSpec({ ndim: 4 })];
51062 }
51063 computeOutputShape(inputShape) {
51064 inputShape = inputShape;
51065 if (this.dataFormat === 'channelsLast') {
51066 return [inputShape[0], inputShape[3]];
51067 }
51068 else {
51069 return [inputShape[0], inputShape[1]];
51070 }
51071 }
51072 call(inputs, kwargs) {
51073 throw new NotImplementedError();
51074 }
51075 getConfig() {
51076 const config = { dataFormat: this.dataFormat };
51077 const baseConfig = super.getConfig();
51078 Object.assign(config, baseConfig);
51079 return config;
51080 }
51081 }
51082 class GlobalAveragePooling2D extends GlobalPooling2D {
51083 call(inputs, kwargs) {
51084 return tidy(() => {
51085 const input = getExactlyOneTensor(inputs);
51086 if (this.dataFormat === 'channelsLast') {
51087 return mean(input, [1, 2]);
51088 }
51089 else {
51090 return mean(input, [2, 3]);
51091 }
51092 });
51093 }
51094 }
51095 /** @nocollapse */
51096 GlobalAveragePooling2D.className = 'GlobalAveragePooling2D';
51097 registerClass(GlobalAveragePooling2D);
51098 class GlobalMaxPooling2D extends GlobalPooling2D {
51099 call(inputs, kwargs) {
51100 return tidy(() => {
51101 const input = getExactlyOneTensor(inputs);
51102 if (this.dataFormat === 'channelsLast') {
51103 return max(input, [1, 2]);
51104 }
51105 else {
51106 return max(input, [2, 3]);
51107 }
51108 });
51109 }
51110 }
51111 /** @nocollapse */
51112 GlobalMaxPooling2D.className = 'GlobalMaxPooling2D';
51113 registerClass(GlobalMaxPooling2D);
51114
51115 /**
51116 * @license
51117 * Copyright 2018 Google LLC
51118 *
51119 * Use of this source code is governed by an MIT-style
51120 * license that can be found in the LICENSE file or at
51121 * https://opensource.org/licenses/MIT.
51122 * =============================================================================
51123 */
51124 /**
51125 * Abstract wrapper base class.
51126 *
51127 * Wrappers take another layer and augment it in various ways.
51128 * Do not use this class as a layer, it is only an abstract base class.
51129 * Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers.
51130 */
51131 class Wrapper extends Layer {
51132 constructor(args) {
51133 // Porting Note: In PyKeras, `self.layer` is set prior to the calling
51134 // `super()`. But we can't do that here due to TypeScript's restriction.
51135 // See: https://github.com/Microsoft/TypeScript/issues/8277
51136 // As a result, we have to add checks in `get trainable()` and
51137 // `set trainable()` below in order to prevent using `this.layer` when
51138 // its value is `undefined`. The super constructor does use the getter
51139 // and the setter of `this.layer`.
51140 super(args);
51141 this.layer = args.layer;
51142 }
51143 build(inputShape) {
51144 this.built = true;
51145 }
51146 // TODO(cais): Implement activityRegularizer getter.
51147 get trainable() {
51148 // Porting Note: the check of `this.layer` here is necessary due to the
51149 // way the `constructor` of this class is written (see Porting Note
51150 // above).
51151 if (this.layer != null) {
51152 return this.layer.trainable;
51153 }
51154 else {
51155 return false;
51156 }
51157 }
51158 set trainable(value) {
51159 // Porting Note: the check of `this.layer` here is necessary due to the
51160 // way the `constructor` of this class is written (see Porting Note
51161 // above).
51162 if (this.layer != null) {
51163 this.layer.trainable = value;
51164 }
51165 }
51166 get trainableWeights() {
51167 return this.layer.trainableWeights;
51168 }
51169 // TODO(cais): Implement setter for trainableWeights.
51170 get nonTrainableWeights() {
51171 return this.layer.nonTrainableWeights;
51172 }
51173 // TODO(cais): Implement setter for nonTrainableWeights.
51174 get updates() {
51175 // tslint:disable-next-line:no-any
51176 return this.layer._updates;
51177 }
51178 // TODO(cais): Implement getUpdatesFor().
51179 get losses() {
51180 return this.layer.losses;
51181 }
51182 // TODO(cais): Implement getLossesFor().
51183 getWeights() {
51184 return this.layer.getWeights();
51185 }
51186 setWeights(weights) {
51187 this.layer.setWeights(weights);
51188 }
51189 getConfig() {
51190 const config = {
51191 'layer': {
51192 'className': this.layer.getClassName(),
51193 'config': this.layer.getConfig(),
51194 }
51195 };
51196 const baseConfig = super.getConfig();
51197 Object.assign(config, baseConfig);
51198 return config;
51199 }
51200 setFastWeightInitDuringBuild(value) {
51201 super.setFastWeightInitDuringBuild(value);
51202 if (this.layer != null) {
51203 this.layer.setFastWeightInitDuringBuild(value);
51204 }
51205 }
51206 /** @nocollapse */
51207 static fromConfig(cls, config, customObjects = {}) {
51208 const layerConfig = config['layer'];
51209 const layer = deserialize(layerConfig, customObjects);
51210 delete config['layer'];
51211 const newConfig = { layer };
51212 Object.assign(newConfig, config);
51213 return new cls(newConfig);
51214 }
51215 }
51216 class TimeDistributed extends Wrapper {
51217 constructor(args) {
51218 super(args);
51219 this.supportsMasking = true;
51220 }
51221 build(inputShape) {
51222 inputShape = getExactlyOneShape(inputShape);
51223 if (inputShape.length < 3) {
51224 throw new ValueError(`TimeDistributed layer expects an input shape >= 3D, but received ` +
51225 `input shape ${JSON.stringify(inputShape)}`);
51226 }
51227 this.inputSpec = [{ shape: inputShape }];
51228 const childInputShape = [inputShape[0]].concat(inputShape.slice(2));
51229 if (!this.layer.built) {
51230 this.layer.build(childInputShape);
51231 this.layer.built = true;
51232 }
51233 super.build(inputShape);
51234 }
51235 computeOutputShape(inputShape) {
51236 inputShape = getExactlyOneShape(inputShape);
51237 const childInputShape = [inputShape[0]].concat(inputShape.slice(2));
51238 const childOutputShape = this.layer.computeOutputShape(childInputShape);
51239 const timesteps = inputShape[1];
51240 return [childOutputShape[0], timesteps].concat(childOutputShape.slice(1));
51241 }
51242 call(inputs, kwargs) {
51243 return tidy(() => {
51244 // TODO(cais): Add 'training' and 'useLearningPhase' to kwargs.
51245 inputs = getExactlyOneTensor(inputs);
51246 // Porting Note: In tfjs-layers, `inputs` are always concrete tensor
51247 // values. Hence the inputs can't have an undetermined first (batch)
51248 // dimension, which is why we always use the K.rnn approach here.
51249 const step = (inputs, states) => {
51250 // TODO(cais): Add useLearningPhase.
51251 // NOTE(cais): `layer.call` may return a length-1 array of Tensor in
51252 // some cases (e.g., `layer` is a `Sequential` instance), which is
51253 // why `getExactlyOneTensor` is used below.
51254 const output = getExactlyOneTensor(this.layer.call(inputs, kwargs));
51255 return [output, []];
51256 };
51257 const rnnOutputs = rnn(step, inputs, [], false /* goBackwards */, null /* mask */, null /* constants */, false /* unroll */, true /* needPerStepOutputs */);
51258 const y = rnnOutputs[1];
51259 // TODO(cais): Add activity regularization.
51260 // TODO(cais): Add useLearningPhase.
51261 return y;
51262 });
51263 }
51264 }
51265 /** @nocollapse */
51266 TimeDistributed.className = 'TimeDistributed';
51267 registerClass(TimeDistributed);
    /**
     * Throws if `value` is not a valid BidirectionalMergeMode string;
     * validation is delegated to `checkStringTypeUnionValue`.
     */
    function checkBidirectionalMergeMode(value) {
        checkStringTypeUnionValue(VALID_BIDIRECTIONAL_MERGE_MODES, 'BidirectionalMergeMode', value);
    }
    // Default way of combining the forward and backward RNN outputs.
    const DEFAULT_BIDIRECTIONAL_MERGE_MODE = 'concat';
51272 class Bidirectional extends Wrapper {
51273 constructor(args) {
51274 super(args);
51275 // Note: When creating `this.forwardLayer`, the original Layer object
51276 // (`config.layer`) ought to be cloned. This is why we call
51277 // `getConfig()` followed by `deserialize()`. Without this cloning,
51278 // the layer names saved during serialization will incorrectly contain
51279 // the 'forward_' prefix. In Python Keras, this is done using
51280 // `copy.copy` (shallow copy), which does not have a simple equivalent
51281 // in JavaScript. JavaScript's `Object.assign()` does not copy
51282 // methods.
51283 const layerConfig = args.layer.getConfig();
51284 const forwDict = {};
51285 forwDict['className'] = args.layer.getClassName();
51286 forwDict['config'] = layerConfig;
51287 this.forwardLayer = deserialize(forwDict);
51288 layerConfig['goBackwards'] =
51289 layerConfig['goBackwards'] === true ? false : true;
51290 const backDict = {};
51291 backDict['className'] = args.layer.getClassName();
51292 backDict['config'] = layerConfig;
51293 this.backwardLayer = deserialize(backDict);
51294 this.forwardLayer.name = 'forward_' + this.forwardLayer.name;
51295 this.backwardLayer.name = 'backward_' + this.backwardLayer.name;
51296 this.mergeMode = args.mergeMode === undefined ?
51297 DEFAULT_BIDIRECTIONAL_MERGE_MODE :
51298 args.mergeMode;
51299 checkBidirectionalMergeMode(this.mergeMode);
51300 if (args.weights) {
51301 throw new NotImplementedError('weights support is not implemented for Bidirectional layer yet.');
51302 }
51303 this._stateful = args.layer.stateful;
51304 this.returnSequences = args.layer.returnSequences;
51305 this.returnState = args.layer.returnState;
51306 this.supportsMasking = true;
51307 this._trainable = true;
51308 this.inputSpec = args.layer.inputSpec;
51309 this.numConstants = null;
51310 }
51311 get trainable() {
51312 return this._trainable;
51313 }
51314 set trainable(value) {
51315 // Porting Note: the check of `this.layer` here is necessary due to the
51316 // way the `constructor` of this class is written (see Porting Note
51317 // above).
51318 this._trainable = value;
51319 if (this.forwardLayer != null) {
51320 this.forwardLayer.trainable = value;
51321 }
51322 if (this.backwardLayer != null) {
51323 this.backwardLayer.trainable = value;
51324 }
51325 }
51326 getWeights() {
51327 return this.forwardLayer.getWeights().concat(this.backwardLayer.getWeights());
51328 }
51329 setWeights(weights) {
51330 const numWeights = weights.length;
51331 const numeightsOver2 = Math.floor(numWeights / 2);
51332 this.forwardLayer.setWeights(weights.slice(0, numeightsOver2));
51333 this.backwardLayer.setWeights(weights.slice(numeightsOver2));
51334 }
51335 computeOutputShape(inputShape) {
51336 let layerShapes = this.forwardLayer.computeOutputShape(inputShape);
51337 if (!(Array.isArray(layerShapes) && Array.isArray(layerShapes[0]))) {
51338 layerShapes = [layerShapes];
51339 }
51340 layerShapes = layerShapes;
51341 let outputShape;
51342 let outputShapes;
51343 let stateShape;
51344 if (this.returnState) {
51345 stateShape = layerShapes.slice(1);
51346 outputShape = layerShapes[0];
51347 }
51348 else {
51349 outputShape = layerShapes[0];
51350 }
51351 outputShape = outputShape;
51352 if (this.mergeMode === 'concat') {
51353 outputShape[outputShape.length - 1] *= 2;
51354 outputShapes = [outputShape];
51355 }
51356 else if (this.mergeMode == null) {
51357 outputShapes = [outputShape, outputShape.slice()];
51358 }
51359 else {
51360 outputShapes = [outputShape];
51361 }
51362 if (this.returnState) {
51363 if (this.mergeMode == null) {
51364 return outputShapes.concat(stateShape).concat(stateShape.slice());
51365 }
51366 return [outputShape].concat(stateShape).concat(stateShape.slice());
51367 }
51368 return singletonOrArray(outputShapes);
51369 }
51370 apply(inputs, kwargs) {
51371 let initialState = kwargs == null ? null : kwargs['initialState'];
51372 let constants = kwargs == null ? null : kwargs['constants'];
51373 if (kwargs == null) {
51374 kwargs = {};
51375 }
51376 const standardized = standardizeArgs(inputs, initialState, constants, this.numConstants);
51377 inputs = standardized.inputs;
51378 initialState = standardized.initialState;
51379 constants = standardized.constants;
51380 if (Array.isArray(inputs)) {
51381 initialState = inputs.slice(1);
51382 inputs = inputs[0];
51383 }
51384 if ((initialState == null || initialState.length === 0) &&
51385 constants == null) {
51386 return super.apply(inputs, kwargs);
51387 }
51388 const additionalInputs = [];
51389 const additionalSpecs = [];
51390 if (initialState != null) {
51391 const numStates = initialState.length;
51392 if (numStates % 2 > 0) {
51393 throw new ValueError('When passing `initialState` to a Bidrectional RNN, ' +
51394 'the state should be an Array containing the states of ' +
51395 'the underlying RNNs.');
51396 }
51397 kwargs['initialState'] = initialState;
51398 additionalInputs.push(...initialState);
51399 const stateSpecs = initialState
51400 .map(state => new InputSpec({ shape: state.shape }));
51401 this.forwardLayer.stateSpec = stateSpecs.slice(0, numStates / 2);
51402 this.backwardLayer.stateSpec = stateSpecs.slice(numStates / 2);
51403 additionalSpecs.push(...stateSpecs);
51404 }
51405 if (constants != null) {
51406 throw new NotImplementedError('Support for constants in Bidirectional layers is not ' +
51407 'implemented yet.');
51408 }
51409 const isSymbolicTensor = additionalInputs[0] instanceof SymbolicTensor;
51410 for (const tensor of additionalInputs) {
51411 if (tensor instanceof SymbolicTensor !== isSymbolicTensor) {
51412 throw new ValueError('The initial state of a Bidirectional layer cannot be ' +
51413 'specified as a mix of symbolic and non-symbolic tensors');
51414 }
51415 }
51416 if (isSymbolicTensor) {
51417 // Compute the full input and specs, including the states.
51418 const fullInput = [inputs].concat(additionalInputs);
51419 const fullInputSpec = this.inputSpec.concat(additionalSpecs);
51420 // Perform the call temporarily and replace inputSpec.
51421 // Note: with initial states symbolic calls and non-symbolic calls to
51422 // this method differ in how the initial states are passed. For
51423 // symbolic calls, the initial states are passed in the first arg, as
51424 // an Array of SymbolicTensors; for non-symbolic calls, they are
51425 // passed in the second arg as a part of the kwargs. Hence the need to
51426 // temporarily modify inputSpec here.
51427 // TODO(cais): Make refactoring so that this hacky code below is no
51428 // longer needed.
51429 const originalInputSpec = this.inputSpec;
51430 this.inputSpec = fullInputSpec;
51431 const output = super.apply(fullInput, kwargs);
51432 this.inputSpec = originalInputSpec;
51433 return output;
51434 }
51435 else {
51436 return super.apply(inputs, kwargs);
51437 }
51438 }
51439 call(inputs, kwargs) {
51440 return tidy(() => {
51441 const initialState = kwargs['initialState'];
51442 let y;
51443 let yRev;
51444 if (initialState == null) {
51445 y = this.forwardLayer.call(inputs, kwargs);
51446 yRev = this.backwardLayer.call(inputs, kwargs);
51447 }
51448 else {
51449 const forwardState = initialState.slice(0, initialState.length / 2);
51450 const backwardState = initialState.slice(initialState.length / 2);
51451 y = this.forwardLayer.call(inputs, Object.assign(kwargs, { initialState: forwardState }));
51452 yRev = this.backwardLayer.call(inputs, Object.assign(kwargs, { initialState: backwardState }));
51453 }
51454 let states;
51455 if (this.returnState) {
51456 if (Array.isArray(y)) {
51457 states = y.slice(1).concat(yRev.slice(1));
51458 }
51459 else {
51460 }
51461 y = y[0];
51462 yRev = yRev[0];
51463 }
51464 if (this.returnSequences) {
51465 yRev = reverse(yRev, 1);
51466 }
51467 let output;
51468 if (this.mergeMode === 'concat') {
51469 output = concatenate([y, yRev]);
51470 }
51471 else if (this.mergeMode === 'sum') {
51472 output = add$1(y, yRev);
51473 }
51474 else if (this.mergeMode === 'ave') {
51475 output = mul(.5, add$1(y, yRev));
51476 }
51477 else if (this.mergeMode === 'mul') {
51478 output = mul(y, yRev);
51479 }
51480 else if (this.mergeMode == null) {
51481 output = [y, yRev];
51482 }
51483 // TODO(cais): Properly set learning phase.
51484 if (this.returnState) {
51485 if (this.mergeMode == null) {
51486 return output.concat(states);
51487 }
51488 return [output].concat(states);
51489 }
51490 return output;
51491 });
51492 }
51493 resetStates(states) {
51494 this.forwardLayer.resetStates();
51495 this.backwardLayer.resetStates();
51496 }
51497 build(inputShape) {
51498 nameScope(this.forwardLayer.name, () => {
51499 this.forwardLayer.build(inputShape);
51500 });
51501 nameScope(this.backwardLayer.name, () => {
51502 this.backwardLayer.build(inputShape);
51503 });
51504 this.built = true;
51505 }
51506 computeMask(inputs, mask) {
51507 if (Array.isArray(mask)) {
51508 mask = mask[0];
51509 }
51510 let outputMask;
51511 if (this.returnSequences) {
51512 if (this.mergeMode == null) {
51513 outputMask = [mask, mask];
51514 }
51515 else {
51516 outputMask = mask;
51517 }
51518 }
51519 else {
51520 if (this.mergeMode == null) {
51521 outputMask = [null, null];
51522 }
51523 else {
51524 outputMask = null;
51525 }
51526 }
51527 if (this.returnState) {
51528 const states = this.forwardLayer.states;
51529 const stateMask = states.map(state => null);
51530 if (Array.isArray(outputMask)) {
51531 return outputMask.concat(stateMask).concat(stateMask);
51532 }
51533 else {
51534 return [outputMask].concat(stateMask).concat(stateMask);
51535 }
51536 }
51537 else {
51538 return outputMask;
51539 }
51540 }
51541 get trainableWeights() {
51542 return this.forwardLayer.trainableWeights.concat(this.backwardLayer.trainableWeights);
51543 }
51544 get nonTrainableWeights() {
51545 return this.forwardLayer.nonTrainableWeights.concat(this.backwardLayer.nonTrainableWeights);
51546 }
51547 // TODO(cais): Implement constraints().
51548 setFastWeightInitDuringBuild(value) {
51549 super.setFastWeightInitDuringBuild(value);
51550 if (this.forwardLayer != null) {
51551 this.forwardLayer.setFastWeightInitDuringBuild(value);
51552 }
51553 if (this.backwardLayer != null) {
51554 this.backwardLayer.setFastWeightInitDuringBuild(value);
51555 }
51556 }
51557 getConfig() {
51558 const config = {
51559 'mergeMode': this.mergeMode,
51560 };
51561 // TODO(cais): Add logic for `numConstants` once the property is added.
51562 const baseConfig = super.getConfig();
51563 Object.assign(config, baseConfig);
51564 return config;
51565 }
51566 /** @nocollapse */
51567 static fromConfig(cls, config) {
51568 const rnnLayer = deserialize(config['layer']);
51569 delete config['layer'];
51570 // TODO(cais): Add logic for `numConstants` once the property is added.
51571 if (config['numConstants'] != null) {
51572 throw new NotImplementedError(`Deserialization of a Bidirectional layer with numConstants ` +
51573 `present is not supported yet.`);
51574 }
51575 // tslint:disable-next-line:no-any
51576 const newConfig = config;
51577 newConfig['layer'] = rnnLayer;
51578 return new cls(newConfig);
51579 }
51580 }
51581 /** @nocollapse */
51582 Bidirectional.className = 'Bidirectional';
51583 registerClass(Bidirectional);
51584
51585 /**
51586 * @license
51587 * Copyright 2018 Google LLC
51588 *
51589 * Use of this source code is governed by an MIT-style
51590 * license that can be found in the LICENSE file or at
51591 * https://opensource.org/licenses/MIT.
51592 * =============================================================================
51593 */
51594 // TODO(cais): Add doc string to all the public static functions in this
    // class; include executable JavaScript code snippets where applicable
51596 // (b/74074458).
51597 // Input Layer.
51598 /**
51599 * An input layer is an entry point into a `tf.LayersModel`.
51600 *
     * `InputLayer` is generated automatically for `tf.Sequential` models by
     * specifying the `inputShape` or `batchInputShape` for the first layer. It
51603 * should not be specified explicitly. However, it can be useful sometimes,
51604 * e.g., when constructing a sequential model from a subset of another
51605 * sequential model's layers. Like the code snippet below shows.
51606 *
51607 * ```js
51608 * // Define a model which simply adds two inputs.
51609 * const model1 = tf.sequential();
51610 * model1.add(tf.layers.dense({inputShape: [4], units: 3, activation: 'relu'}));
51611 * model1.add(tf.layers.dense({units: 1, activation: 'sigmoid'}));
51612 * model1.summary();
51613 * model1.predict(tf.zeros([1, 4])).print();
51614 *
51615 * // Construct another model, reusing the second layer of `model1` while
51616 * // not using the first layer of `model1`. Note that you cannot add the second
51617 * // layer of `model` directly as the first layer of the new sequential model,
51618 * // because doing so will lead to an error related to the fact that the layer
51619 * // is not an input layer. Instead, you need to create an `inputLayer` and add
51620 * // it to the new sequential model before adding the reused layer.
51621 * const model2 = tf.sequential();
51622 * // Use an inputShape that matches the input shape of `model1`'s second
51623 * // layer.
51624 * model2.add(tf.layers.inputLayer({inputShape: [3]}));
51625 * model2.add(model1.layers[1]);
51626 * model2.summary();
51627 * model2.predict(tf.zeros([1, 3])).print();
51628 * ```
51629 *
51630 * @doc {heading: 'Layers', subheading: 'Inputs', namespace: 'layers'}
51631 */
51632 function inputLayer(args) {
51633 return new InputLayer(args);
51634 }
51635 // Advanced Activation Layers.
51636 /**
     * Exponential Linear Unit (ELU).
51638 *
51639 * It follows:
51640 * `f(x) = alpha * (exp(x) - 1.) for x < 0`,
51641 * `f(x) = x for x >= 0`.
51642 *
51643 * Input shape:
51644 * Arbitrary. Use the configuration `inputShape` when using this layer as the
51645 * first layer in a model.
51646 *
51647 * Output shape:
51648 * Same shape as the input.
51649 *
51650 * References:
51651 * - [Fast and Accurate Deep Network Learning by Exponential Linear Units
51652 * (ELUs)](https://arxiv.org/abs/1511.07289v1)
51653 *
51654 * @doc {
51655 * heading: 'Layers',
51656 * subheading: 'Advanced Activation',
51657 * namespace: 'layers'
51658 * }
51659 */
51660 function elu$2(args) {
51661 return new ELU(args);
51662 }
51663 /**
51664 * Rectified Linear Unit activation function.
51665 *
51666 * Input shape:
51667 * Arbitrary. Use the config field `inputShape` (Array of integers, does
51668 * not include the sample axis) when using this layer as the first layer
51669 * in a model.
51670 *
51671 * Output shape:
51672 * Same shape as the input.
51673 *
51674 * @doc {
51675 * heading: 'Layers',
51676 * subheading: 'Advanced Activation',
51677 * namespace: 'layers'
51678 * }
51679 */
51680 function reLU(args) {
51681 return new ReLU(args);
51682 }
51683 /**
51684 * Leaky version of a rectified linear unit.
51685 *
51686 * It allows a small gradient when the unit is not active:
51687 * `f(x) = alpha * x for x < 0.`
51688 * `f(x) = x for x >= 0.`
51689 *
51690 * Input shape:
51691 * Arbitrary. Use the configuration `inputShape` when using this layer as the
51692 * first layer in a model.
51693 *
51694 * Output shape:
51695 * Same shape as the input.
51696 *
51697 * @doc {
51698 * heading: 'Layers',
51699 * subheading: 'Advanced Activation',
51700 * namespace: 'layers'
51701 * }
51702 */
51703 function leakyReLU(args) {
51704 return new LeakyReLU(args);
51705 }
51706 /**
51707 * Parameterized version of a leaky rectified linear unit.
51708 *
51709 * It follows
51710 * `f(x) = alpha * x for x < 0.`
51711 * `f(x) = x for x >= 0.`
51712 * wherein `alpha` is a trainable weight.
51713 *
51714 * Input shape:
51715 * Arbitrary. Use the configuration `inputShape` when using this layer as the
51716 * first layer in a model.
51717 *
51718 * Output shape:
51719 * Same shape as the input.
51720 *
51721 * @doc {
51722 * heading: 'Layers',
51723 * subheading: 'Advanced Activation',
51724 * namespace: 'layers'
51725 * }
51726 */
51727 function prelu$1(args) {
51728 return new PReLU(args);
51729 }
51730 /**
51731 * Softmax activation layer.
51732 *
51733 * Input shape:
51734 * Arbitrary. Use the configuration `inputShape` when using this layer as the
51735 * first layer in a model.
51736 *
51737 * Output shape:
51738 * Same shape as the input.
51739 *
51740 * @doc {
51741 * heading: 'Layers',
51742 * subheading: 'Advanced Activation',
51743 * namespace: 'layers'
51744 * }
51745 */
51746 function softmax$1(args) {
51747 return new Softmax$2(args);
51748 }
51749 /**
51750 * Thresholded Rectified Linear Unit.
51751 *
51752 * It follows:
51753 * `f(x) = x for x > theta`,
51754 * `f(x) = 0 otherwise`.
51755 *
51756 * Input shape:
51757 * Arbitrary. Use the configuration `inputShape` when using this layer as the
51758 * first layer in a model.
51759 *
51760 * Output shape:
51761 * Same shape as the input.
51762 *
51763 * References:
51764 * - [Zero-Bias Autoencoders and the Benefits of Co-Adapting
51765 * Features](http://arxiv.org/abs/1402.3337)
51766 *
51767 * @doc {
51768 * heading: 'Layers',
51769 * subheading: 'Advanced Activation',
51770 * namespace: 'layers'
51771 * }
51772 */
51773 function thresholdedReLU(args) {
51774 return new ThresholdedReLU(args);
51775 }
51776 // Convolutional Layers.
51777 /**
51778 * 1D convolution layer (e.g., temporal convolution).
51779 *
51780 * This layer creates a convolution kernel that is convolved
51781 * with the layer input over a single spatial (or temporal) dimension
51782 * to produce a tensor of outputs.
51783 *
51784 * If `use_bias` is True, a bias vector is created and added to the outputs.
51785 *
51786 * If `activation` is not `null`, it is applied to the outputs as well.
51787 *
51788 * When using this layer as the first layer in a model, provide an
51789 * `inputShape` argument `Array` or `null`.
51790 *
51791 * For example, `inputShape` would be:
51792 * - `[10, 128]` for sequences of 10 vectors of 128-dimensional vectors
51793 * - `[null, 128]` for variable-length sequences of 128-dimensional vectors.
51794 *
51795 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
51796 */
51797 function conv1d$2(args) {
51798 return new Conv1D(args);
51799 }
51800 /**
51801 * 2D convolution layer (e.g. spatial convolution over images).
51802 *
51803 * This layer creates a convolution kernel that is convolved
51804 * with the layer input to produce a tensor of outputs.
51805 *
51806 * If `useBias` is True, a bias vector is created and added to the outputs.
51807 *
51808 * If `activation` is not `null`, it is applied to the outputs as well.
51809 *
51810 * When using this layer as the first layer in a model,
51811 * provide the keyword argument `inputShape`
51812 * (Array of integers, does not include the sample axis),
51813 * e.g. `inputShape=[128, 128, 3]` for 128x128 RGB pictures
51814 * in `dataFormat='channelsLast'`.
51815 *
51816 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
51817 */
51818 function conv2d$3(args) {
51819 return new Conv2D$1(args);
51820 }
51821 /**
51822 * Transposed convolutional layer (sometimes called Deconvolution).
51823 *
51824 * The need for transposed convolutions generally arises
51825 * from the desire to use a transformation going in the opposite direction of
51826 * a normal convolution, i.e., from something that has the shape of the output
51827 * of some convolution to something that has the shape of its input while
51828 * maintaining a connectivity pattern that is compatible with said
51829 * convolution.
51830 *
51831 * When using this layer as the first layer in a model, provide the
51832 * configuration `inputShape` (`Array` of integers, does not include the
51833 * sample axis), e.g., `inputShape: [128, 128, 3]` for 128x128 RGB pictures in
51834 * `dataFormat: 'channelsLast'`.
51835 *
51836 * Input shape:
51837 * 4D tensor with shape:
51838 * `[batch, channels, rows, cols]` if `dataFormat` is `'channelsFirst'`.
51839 * or 4D tensor with shape
51840 * `[batch, rows, cols, channels]` if `dataFormat` is `'channelsLast`.
51841 *
51842 * Output shape:
51843 * 4D tensor with shape:
51844 * `[batch, filters, newRows, newCols]` if `dataFormat` is
51845 * `'channelsFirst'`. or 4D tensor with shape:
51846 * `[batch, newRows, newCols, filters]` if `dataFormat` is `'channelsLast'`.
51847 *
51848 * References:
51849 * - [A guide to convolution arithmetic for deep
51850 * learning](https://arxiv.org/abs/1603.07285v1)
51851 * - [Deconvolutional
51852 * Networks](http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf)
51853 *
51854 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
51855 */
51856 function conv2dTranspose$1(args) {
51857 return new Conv2DTranspose(args);
51858 }
51859 /**
51860 * 3D convolution layer (e.g. spatial convolution over volumes).
51861 *
51862 * This layer creates a convolution kernel that is convolved
51863 * with the layer input to produce a tensor of outputs.
51864 *
51865 * If `useBias` is True, a bias vector is created and added to the outputs.
51866 *
51867 * If `activation` is not `null`, it is applied to the outputs as well.
51868 *
51869 * When using this layer as the first layer in a model,
51870 * provide the keyword argument `inputShape`
51871 * (Array of integers, does not include the sample axis),
51872 * e.g. `inputShape=[128, 128, 128, 1]` for 128x128x128 grayscale volumes
51873 * in `dataFormat='channelsLast'`.
51874 *
51875 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
51876 */
51877 function conv3d$2(args) {
51878 return new Conv3D$1(args);
51879 }
51880 function conv3dTranspose$1(args) {
51881 return new Conv3DTranspose(args);
51882 }
51883 /**
51884 * Depthwise separable 2D convolution.
51885 *
51886 * Separable convolution consists of first performing
51887 * a depthwise spatial convolution
51888 * (which acts on each input channel separately)
51889 * followed by a pointwise convolution which mixes together the resulting
51890 * output channels. The `depthMultiplier` argument controls how many
51891 * output channels are generated per input channel in the depthwise step.
51892 *
51893 * Intuitively, separable convolutions can be understood as
51894 * a way to factorize a convolution kernel into two smaller kernels,
51895 * or as an extreme version of an Inception block.
51896 *
51897 * Input shape:
51898 * 4D tensor with shape:
51899 * `[batch, channels, rows, cols]` if data_format='channelsFirst'
51900 * or 4D tensor with shape:
51901 * `[batch, rows, cols, channels]` if data_format='channelsLast'.
51902 *
51903 * Output shape:
51904 * 4D tensor with shape:
51905 * `[batch, filters, newRows, newCols]` if data_format='channelsFirst'
51906 * or 4D tensor with shape:
51907 * `[batch, newRows, newCols, filters]` if data_format='channelsLast'.
51908 * `rows` and `cols` values might have changed due to padding.
51909 *
51910 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
51911 */
51912 function separableConv2d$1(args) {
51913 return new SeparableConv2D(args);
51914 }
51915 /**
51916 * Cropping layer for 2D input (e.g., image).
51917 *
51918 * This layer can crop an input
51919 * at the top, bottom, left and right side of an image tensor.
51920 *
51921 * Input shape:
51922 * 4D tensor with shape:
51923 * - If `dataFormat` is `"channelsLast"`:
51924 * `[batch, rows, cols, channels]`
51925 * - If `data_format` is `"channels_first"`:
51926 * `[batch, channels, rows, cols]`.
51927 *
51928 * Output shape:
51929 * 4D with shape:
51930 * - If `dataFormat` is `"channelsLast"`:
51931 * `[batch, croppedRows, croppedCols, channels]`
51932 * - If `dataFormat` is `"channelsFirst"`:
51933 * `[batch, channels, croppedRows, croppedCols]`.
51934 *
51935 * Examples
51936 * ```js
51937 *
51938 * const model = tf.sequential();
51939 * model.add(tf.layers.cropping2D({cropping:[[2, 2], [2, 2]],
51940 * inputShape: [128, 128, 3]}));
51941 * //now output shape is [batch, 124, 124, 3]
51942 * ```
51943 *
51944 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
51945 */
51946 function cropping2D(args) {
51947 return new Cropping2D(args);
51948 }
51949 /**
51950 * Upsampling layer for 2D inputs.
51951 *
51952 * Repeats the rows and columns of the data
51953 * by size[0] and size[1] respectively.
51954 *
51955 *
51956 * Input shape:
51957 * 4D tensor with shape:
51958 * - If `dataFormat` is `"channelsLast"`:
51959 * `[batch, rows, cols, channels]`
51960 * - If `dataFormat` is `"channelsFirst"`:
51961 * `[batch, channels, rows, cols]`
51962 *
51963 * Output shape:
51964 * 4D tensor with shape:
51965 * - If `dataFormat` is `"channelsLast"`:
51966 * `[batch, upsampledRows, upsampledCols, channels]`
51967 * - If `dataFormat` is `"channelsFirst"`:
51968 * `[batch, channels, upsampledRows, upsampledCols]`
51969 *
51970 *
51971 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
51972 */
51973 function upSampling2d(args) {
51974 return new UpSampling2D(args);
51975 }
51976 // Convolutional(depthwise) Layers.
51977 /**
51978 * Depthwise separable 2D convolution.
51979 *
51980 * Depthwise Separable convolutions consists in performing just the first step
51981 * in a depthwise spatial convolution (which acts on each input channel
     * separately). The `depthMultiplier` argument controls how many output channels
51983 * are generated per input channel in the depthwise step.
51984 *
51985 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
51986 */
51987 function depthwiseConv2d$3(args) {
51988 return new DepthwiseConv2D(args);
51989 }
51990 // Basic Layers.
51991 /**
51992 * Applies an activation function to an output.
51993 *
51994 * This layer applies element-wise activation function. Other layers, notably
51995 * `dense` can also apply activation functions. Use this isolated activation
51996 * function to extract the values before and after the
51997 * activation. For instance:
51998 *
51999 * ```js
52000 * const input = tf.input({shape: [5]});
52001 * const denseLayer = tf.layers.dense({units: 1});
52002 * const activationLayer = tf.layers.activation({activation: 'relu6'});
52003 *
52004 * // Obtain the output symbolic tensors by applying the layers in order.
52005 * const denseOutput = denseLayer.apply(input);
52006 * const activationOutput = activationLayer.apply(denseOutput);
52007 *
52008 * // Create the model based on the inputs.
52009 * const model = tf.model({
52010 * inputs: input,
52011 * outputs: [denseOutput, activationOutput]
52012 * });
52013 *
52014 * // Collect both outputs and print separately.
52015 * const [denseOut, activationOut] = model.predict(tf.randomNormal([6, 5]));
52016 * denseOut.print();
52017 * activationOut.print();
52018 * ```
52019 *
52020 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
52021 */
52022 function activation(args) {
52023 return new Activation$1(args);
52024 }
52025 /**
52026 * Creates a dense (fully connected) layer.
52027 *
52028 * This layer implements the operation:
52029 * `output = activation(dot(input, kernel) + bias)`
52030 *
52031 * `activation` is the element-wise activation function
52032 * passed as the `activation` argument.
52033 *
52034 * `kernel` is a weights matrix created by the layer.
52035 *
52036 * `bias` is a bias vector created by the layer (only applicable if `useBias`
52037 * is `true`).
52038 *
52039 * **Input shape:**
52040 *
52041 * nD `tf.Tensor` with shape: `(batchSize, ..., inputDim)`.
52042 *
52043 * The most common situation would be
52044 * a 2D input with shape `(batchSize, inputDim)`.
52045 *
52046 * **Output shape:**
52047 *
52048 * nD tensor with shape: `(batchSize, ..., units)`.
52049 *
52050 * For instance, for a 2D input with shape `(batchSize, inputDim)`,
52051 * the output would have shape `(batchSize, units)`.
52052 *
52053 * Note: if the input to the layer has a rank greater than 2, then it is
52054 * flattened prior to the initial dot product with the kernel.
52055 *
52056 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
52057 */
52058 function dense(args) {
52059 return new Dense(args);
52060 }
52061 /**
52062 * Applies
52063 * [dropout](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf) to
52064 * the input.
52065 *
52066 * Dropout consists in randomly setting a fraction `rate` of input units to 0 at
52067 * each update during training time, which helps prevent overfitting.
52068 *
52069 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
52070 */
52071 function dropout$2(args) {
52072 return new Dropout(args);
52073 }
52074 /**
52075 * Spatial 1D version of Dropout.
52076 *
52077 * This Layer type performs the same function as the Dropout layer, but it drops
52078 * entire 1D feature maps instead of individual elements. For example, if an
52079 * input example consists of 3 timesteps and the feature map for each timestep
52080 * has a size of 4, a `spatialDropout1d` layer may zero out the feature maps
52081 * of the 1st timesteps and 2nd timesteps completely while sparing all feature
52082 * elements of the 3rd timestep.
52083 *
52084 * If adjacent frames (timesteps) are strongly correlated (as is normally the
52085 * case in early convolution layers), regular dropout will not regularize the
52086 * activation and will otherwise just result in merely an effective learning
52087 * rate decrease. In this case, `spatialDropout1d` will help promote
52088 * independence among feature maps and should be used instead.
52089 *
52090 * **Arguments:**
52091 * rate: A floating-point number >=0 and <=1. Fraction of the input elements
52092 * to drop.
52093 *
52094 * **Input shape:**
52095 * 3D tensor with shape `(samples, timesteps, channels)`.
52096 *
52097 * **Output shape:**
52098 * Same as the input shape.
52099 *
52100 * References:
52101 * - [Efficient Object Localization Using Convolutional
52102 * Networks](https://arxiv.org/abs/1411.4280)
52103 *
52104 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
52105 */
52106 function spatialDropout1d(args) {
52107 return new SpatialDropout1D(args);
52108 }
52109 /**
52110 * Flattens the input. Does not affect the batch size.
52111 *
52112 * A `Flatten` layer flattens each batch in its inputs to 1D (making the output
52113 * 2D).
52114 *
52115 * For example:
52116 *
52117 * ```js
52118 * const input = tf.input({shape: [4, 3]});
52119 * const flattenLayer = tf.layers.flatten();
52120 * // Inspect the inferred output shape of the flatten layer, which
52121 * // equals `[null, 12]`. The 2nd dimension is 4 * 3, i.e., the result of the
     * // flattening. (The 1st dimension is the undetermined batch size.)
52123 * console.log(JSON.stringify(flattenLayer.apply(input).shape));
52124 * ```
52125 *
52126 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
52127 */
52128 function flatten$2(args) {
52129 return new Flatten(args);
52130 }
52131 /**
52132 * Repeats the input n times in a new dimension.
52133 *
52134 * ```js
52135 * const model = tf.sequential();
52136 * model.add(tf.layers.repeatVector({n: 4, inputShape: [2]}));
52137 * const x = tf.tensor2d([[10, 20]]);
     * // Use the model to do inference on a data point the model hasn't seen
     * model.predict(x).print();
     * // output shape is now [batch, 4, 2]
52141 * ```
52142 *
52143 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
52144 */
52145 function repeatVector(args) {
52146 return new RepeatVector(args);
52147 }
52148 /**
52149 * Reshapes an input to a certain shape.
52150 *
52151 * ```js
52152 * const input = tf.input({shape: [4, 3]});
52153 * const reshapeLayer = tf.layers.reshape({targetShape: [2, 6]});
52154 * // Inspect the inferred output shape of the Reshape layer, which
     * // equals `[null, 2, 6]`. (The 1st dimension is the undetermined batch size.)
52156 * console.log(JSON.stringify(reshapeLayer.apply(input).shape));
52157 * ```
52158 *
52159 * Input shape:
52160 * Arbitrary, although all dimensions in the input shape must be fixed.
52161 * Use the configuration `inputShape` when using this layer as the
52162 * first layer in a model.
52163 *
52164 *
52165 * Output shape:
52166 * [batchSize, targetShape[0], targetShape[1], ...,
52167 * targetShape[targetShape.length - 1]].
52168 *
52169 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
52170 */
52171 function reshape$1(args) {
52172 return new Reshape$1(args);
52173 }
52174 /**
52175 * Permutes the dimensions of the input according to a given pattern.
52176 *
52177 * Useful for, e.g., connecting RNNs and convnets together.
52178 *
52179 * Example:
52180 *
52181 * ```js
52182 * const model = tf.sequential();
52183 * model.add(tf.layers.permute({
52184 * dims: [2, 1],
52185 * inputShape: [10, 64]
52186 * }));
52187 * console.log(model.outputShape);
52188 * // Now model's output shape is [null, 64, 10], where null is the
52189 * // unpermuted sample (batch) dimension.
52190 * ```
52191 *
52192 * Input shape:
52193 * Arbitrary. Use the configuration field `inputShape` when using this
52194 * layer as the first layer in a model.
52195 *
52196 * Output shape:
52197 * Same rank as the input shape, but with the dimensions re-ordered (i.e.,
52198 * permuted) according to the `dims` configuration of this layer.
52199 *
52200 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
52201 */
52202 function permute(args) {
52203 return new Permute(args);
52204 }
52205 /**
52206 * Maps positive integers (indices) into dense vectors of fixed size.
52207 * eg. [[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]
52208 *
52209 * **Input shape:** 2D tensor with shape: `[batchSize, sequenceLength]`.
52210 *
52211 * **Output shape:** 3D tensor with shape: `[batchSize, sequenceLength,
52212 * outputDim]`.
52213 *
52214 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
52215 */
52216 function embedding(args) {
52217 return new Embedding(args);
52218 }
52219 // Merge Layers.
52220 /**
52221 * Layer that performs element-wise addition on an `Array` of inputs.
52222 *
52223 * It takes as input a list of tensors, all of the same shape, and returns a
52224 * single tensor (also of the same shape). The inputs are specified as an
52225 * `Array` when the `apply` method of the `Add` layer instance is called. For
52226 * example:
52227 *
52228 * ```js
52229 * const input1 = tf.input({shape: [2, 2]});
52230 * const input2 = tf.input({shape: [2, 2]});
52231 * const addLayer = tf.layers.add();
52232 * const sum = addLayer.apply([input1, input2]);
52233 * console.log(JSON.stringify(sum.shape));
52234 * // You get [null, 2, 2], with the first dimension as the undetermined batch
52235 * // dimension.
52236 * ```
52237 *
52238 * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
52239 */
52240 function add$3(args) {
52241 return new Add$1(args);
52242 }
52243 /**
52244 * Layer that performs element-wise averaging on an `Array` of inputs.
52245 *
52246 * It takes as input a list of tensors, all of the same shape, and returns a
52247 * single tensor (also of the same shape). For example:
52248 *
52249 * ```js
52250 * const input1 = tf.input({shape: [2, 2]});
52251 * const input2 = tf.input({shape: [2, 2]});
52252 * const averageLayer = tf.layers.average();
52253 * const average = averageLayer.apply([input1, input2]);
52254 * console.log(JSON.stringify(average.shape));
52255 * // You get [null, 2, 2], with the first dimension as the undetermined batch
52256 * // dimension.
52257 * ```
52258 *
52259 * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
52260 */
52261 function average$1(args) {
52262 return new Average(args);
52263 }
52264 /**
52265 * Layer that concatenates an `Array` of inputs.
52266 *
52267 * It takes a list of tensors, all of the same shape except for the
52268 * concatenation axis, and returns a single tensor, the concatenation
52269 * of all inputs. For example:
52270 *
52271 * ```js
52272 * const input1 = tf.input({shape: [2, 2]});
52273 * const input2 = tf.input({shape: [2, 3]});
52274 * const concatLayer = tf.layers.concatenate();
52275 * const output = concatLayer.apply([input1, input2]);
52276 * console.log(JSON.stringify(output.shape));
52277 * // You get [null, 2, 5], with the first dimension as the undetermined batch
52278 * // dimension. The last dimension (5) is the result of concatenating the
52279 * // last dimensions of the inputs (2 and 3).
52280 * ```
52281 *
52282 * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
52283 */
52284 function concatenate$2(args) {
52285 return new Concatenate(args);
52286 }
52287 /**
52288 * Layer that computes the element-wise maximum an `Array` of inputs.
52289 *
52290 * It takes as input a list of tensors, all of the same shape and returns a
52291 * single tensor (also of the same shape). For example:
52292 *
52293 * ```js
52294 * const input1 = tf.input({shape: [2, 2]});
52295 * const input2 = tf.input({shape: [2, 2]});
52296 * const maxLayer = tf.layers.maximum();
52297 * const max = maxLayer.apply([input1, input2]);
52298 * console.log(JSON.stringify(max.shape));
52299 * // You get [null, 2, 2], with the first dimension as the undetermined batch
52300 * // dimension.
52301 * ```
52302 *
52303 * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
52304 */
52305 function maximum$2(args) {
52306 return new Maximum$1(args);
52307 }
52308 /**
52309 * Layer that computes the element-wise minimum of an `Array` of inputs.
52310 *
52311 * It takes as input a list of tensors, all of the same shape and returns a
52312 * single tensor (also of the same shape). For example:
52313 *
52314 * ```js
52315 * const input1 = tf.input({shape: [2, 2]});
52316 * const input2 = tf.input({shape: [2, 2]});
52317 * const minLayer = tf.layers.minimum();
52318 * const min = minLayer.apply([input1, input2]);
52319 * console.log(JSON.stringify(min.shape));
52320 * // You get [null, 2, 2], with the first dimension as the undetermined batch
52321 * // dimension.
52322 * ```
52323 *
52324 * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
52325 */
52326 function minimum$2(args) {
52327 return new Minimum$1(args);
52328 }
52329 /**
52330 * Layer that multiplies (element-wise) an `Array` of inputs.
52331 *
52332 * It takes as input an Array of tensors, all of the same
52333 * shape, and returns a single tensor (also of the same shape).
52334 * For example:
52335 *
52336 * ```js
52337 * const input1 = tf.input({shape: [2, 2]});
52338 * const input2 = tf.input({shape: [2, 2]});
52339 * const input3 = tf.input({shape: [2, 2]});
52340 * const multiplyLayer = tf.layers.multiply();
52341 * const product = multiplyLayer.apply([input1, input2, input3]);
52342 * console.log(product.shape);
52343 * // You get [null, 2, 2], with the first dimension as the undetermined batch
     * // dimension.
     * ```
     *
52346 * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
52347 */
52348 function multiply$1(args) {
52349 return new Multiply$1(args);
52350 }
52351 /**
52352 * Layer that computes a dot product between samples in two tensors.
52353 *
52354 * E.g., if applied to a list of two tensors `a` and `b` both of shape
52355 * `[batchSize, n]`, the output will be a tensor of shape `[batchSize, 1]`,
52356 * where each entry at index `[i, 0]` will be the dot product between
52357 * `a[i, :]` and `b[i, :]`.
52358 *
52359 * Example:
52360 *
52361 * ```js
52362 * const dotLayer = tf.layers.dot({axes: -1});
52363 * const x1 = tf.tensor2d([[10, 20], [30, 40]]);
52364 * const x2 = tf.tensor2d([[-1, -2], [-3, -4]]);
52365 *
52366 * // Invoke the layer's apply() method in eager (imperative) mode.
52367 * const y = dotLayer.apply([x1, x2]);
52368 * y.print();
52369 * ```
52370 *
52371 * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
52372 */
52373 function dot$2(args) {
52374 return new Dot(args);
52375 }
52376 // Normalization Layers.
52377 /**
52378 * Batch normalization layer (Ioffe and Szegedy, 2014).
52379 *
52380 * Normalize the activations of the previous layer at each batch,
52381 * i.e. applies a transformation that maintains the mean activation
52382 * close to 0 and the activation standard deviation close to 1.
52383 *
52384 * Input shape:
52385 * Arbitrary. Use the keyword argument `inputShape` (Array of integers, does
52386 * not include the sample axis) when calling the constructor of this class,
52387 * if this layer is used as a first layer in a model.
52388 *
52389 * Output shape:
52390 * Same shape as input.
52391 *
52392 * References:
52393 * - [Batch Normalization: Accelerating Deep Network Training by Reducing
52394 * Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
52395 *
52396 * @doc {heading: 'Layers', subheading: 'Normalization', namespace: 'layers'}
52397 */
52398 function batchNormalization$1(args) {
52399 return new BatchNormalization(args);
52400 }
52401 /**
52402 * Layer-normalization layer (Ba et al., 2016).
52403 *
52404 * Normalizes the activations of the previous layer for each given example in a
52405 * batch independently, instead of across a batch like in `batchNormalization`.
     * In other words, this layer applies a transformation that maintains the mean
     * activation within each example close to 0 and activation variance close to 1.
52408 *
52409 * Input shape:
52410 * Arbitrary. Use the argument `inputShape` when using this layer as the first
52411 * layer in a model.
52412 *
52413 * Output shape:
52414 * Same as input.
52415 *
52416 * References:
52417 * - [Layer Normalization](https://arxiv.org/abs/1607.06450)
52418 *
52419 * @doc {heading: 'Layers', subheading: 'Normalization', namespace: 'layers'}
52420 */
52421 function layerNormalization(args) {
52422 return new LayerNormalization(args);
52423 }
52424 // Padding Layers.
52425 /**
52426 * Zero-padding layer for 2D input (e.g., image).
52427 *
52428 * This layer can add rows and columns of zeros
52429 * at the top, bottom, left and right side of an image tensor.
52430 *
52431 * Input shape:
52432 * 4D tensor with shape:
52433 * - If `dataFormat` is `"channelsLast"`:
52434 * `[batch, rows, cols, channels]`
     * - If `dataFormat` is `"channelsFirst"`:
52436 * `[batch, channels, rows, cols]`.
52437 *
52438 * Output shape:
52439 * 4D with shape:
52440 * - If `dataFormat` is `"channelsLast"`:
52441 * `[batch, paddedRows, paddedCols, channels]`
52442 * - If `dataFormat` is `"channelsFirst"`:
52443 * `[batch, channels, paddedRows, paddedCols]`.
52444 *
52445 * @doc {heading: 'Layers', subheading: 'Padding', namespace: 'layers'}
52446 */
52447 function zeroPadding2d(args) {
52448 return new ZeroPadding2D(args);
52449 }
52450 // Pooling Layers.
52451 /**
52452 * Average pooling operation for spatial data.
52453 *
52454 * Input shape: `[batchSize, inLength, channels]`
52455 *
52456 * Output shape: `[batchSize, pooledLength, channels]`
52457 *
52458 * `tf.avgPool1d` is an alias.
52459 *
52460 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
52461 */
52462 function averagePooling1d(args) {
52463 return new AveragePooling1D(args);
52464 }
52465 function avgPool1d(args) {
52466 return averagePooling1d(args);
52467 }
52468 // For backwards compatibility.
52469 // See https://github.com/tensorflow/tfjs/issues/152
52470 function avgPooling1d(args) {
52471 return averagePooling1d(args);
52472 }
52473 /**
52474 * Average pooling operation for spatial data.
52475 *
52476 * Input shape:
52477 * - If `dataFormat === CHANNEL_LAST`:
52478 * 4D tensor with shape:
52479 * `[batchSize, rows, cols, channels]`
52480 * - If `dataFormat === CHANNEL_FIRST`:
52481 * 4D tensor with shape:
52482 * `[batchSize, channels, rows, cols]`
52483 *
52484 * Output shape
52485 * - If `dataFormat === CHANNEL_LAST`:
52486 * 4D tensor with shape:
     * `[batchSize, pooledRows, pooledCols, channels]`
     * - If `dataFormat === CHANNEL_FIRST`:
     * 4D tensor with shape:
     * `[batchSize, channels, pooledRows, pooledCols]`
52491 *
52492 * `tf.avgPool2d` is an alias.
52493 *
52494 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
52495 */
52496 function averagePooling2d(args) {
52497 return new AveragePooling2D(args);
52498 }
52499 function avgPool2d(args) {
52500 return averagePooling2d(args);
52501 }
52502 // For backwards compatibility.
52503 // See https://github.com/tensorflow/tfjs/issues/152
52504 function avgPooling2d(args) {
52505 return averagePooling2d(args);
52506 }
52507 /**
52508 * Average pooling operation for 3D data.
52509 *
52510 * Input shape
52511 * - If `dataFormat === channelsLast`:
52512 * 5D tensor with shape:
52513 * `[batchSize, depths, rows, cols, channels]`
52514 * - If `dataFormat === channelsFirst`:
     * 5D tensor with shape:
52516 * `[batchSize, channels, depths, rows, cols]`
52517 *
52518 * Output shape
52519 * - If `dataFormat=channelsLast`:
52520 * 5D tensor with shape:
52521 * `[batchSize, pooledDepths, pooledRows, pooledCols, channels]`
52522 * - If `dataFormat=channelsFirst`:
52523 * 5D tensor with shape:
52524 * `[batchSize, channels, pooledDepths, pooledRows, pooledCols]`
52525 *
52526 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
52527 */
52528 function averagePooling3d(args) {
52529 return new AveragePooling3D(args);
52530 }
52531 function avgPool3d$1(args) {
52532 return averagePooling3d(args);
52533 }
52534 // For backwards compatibility.
52535 // See https://github.com/tensorflow/tfjs/issues/152
52536 function avgPooling3d(args) {
52537 return averagePooling3d(args);
52538 }
52539 /**
52540 * Global average pooling operation for temporal data.
52541 *
52542 * Input Shape: 3D tensor with shape: `[batchSize, steps, features]`.
52543 *
     * Output Shape: 2D tensor with shape: `[batchSize, features]`.
52545 *
52546 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
52547 */
52548 function globalAveragePooling1d(args) {
52549 return new GlobalAveragePooling1D(args);
52550 }
52551 /**
52552 * Global average pooling operation for spatial data.
52553 *
52554 * Input shape:
52555 * - If `dataFormat` is `CHANNEL_LAST`:
52556 * 4D tensor with shape: `[batchSize, rows, cols, channels]`.
52557 * - If `dataFormat` is `CHANNEL_FIRST`:
52558 * 4D tensor with shape: `[batchSize, channels, rows, cols]`.
52559 *
52560 * Output shape:
52561 * 2D tensor with shape: `[batchSize, channels]`.
52562 *
52563 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
52564 */
52565 function globalAveragePooling2d(args) {
52566 return new GlobalAveragePooling2D(args);
52567 }
52568 /**
52569 * Global max pooling operation for temporal data.
52570 *
52571 * Input Shape: 3D tensor with shape: `[batchSize, steps, features]`.
52572 *
     * Output Shape: 2D tensor with shape: `[batchSize, features]`.
52574 *
52575 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
52576 */
52577 function globalMaxPooling1d(args) {
52578 return new GlobalMaxPooling1D(args);
52579 }
52580 /**
52581 * Global max pooling operation for spatial data.
52582 *
52583 * Input shape:
52584 * - If `dataFormat` is `CHANNEL_LAST`:
52585 * 4D tensor with shape: `[batchSize, rows, cols, channels]`.
52586 * - If `dataFormat` is `CHANNEL_FIRST`:
52587 * 4D tensor with shape: `[batchSize, channels, rows, cols]`.
52588 *
52589 * Output shape:
52590 * 2D tensor with shape: `[batchSize, channels]`.
52591 *
52592 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
52593 */
52594 function globalMaxPooling2d(args) {
52595 return new GlobalMaxPooling2D(args);
52596 }
52597 /**
52598 * Max pooling operation for temporal data.
52599 *
52600 * Input shape: `[batchSize, inLength, channels]`
52601 *
52602 * Output shape: `[batchSize, pooledLength, channels]`
52603 *
52604 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
52605 */
52606 function maxPooling1d(args) {
52607 return new MaxPooling1D(args);
52608 }
52609 /**
52610 * Max pooling operation for spatial data.
52611 *
52612 * Input shape
52613 * - If `dataFormat === CHANNEL_LAST`:
52614 * 4D tensor with shape:
52615 * `[batchSize, rows, cols, channels]`
52616 * - If `dataFormat === CHANNEL_FIRST`:
52617 * 4D tensor with shape:
52618 * `[batchSize, channels, rows, cols]`
52619 *
52620 * Output shape
52621 * - If `dataFormat=CHANNEL_LAST`:
52622 * 4D tensor with shape:
     * `[batchSize, pooledRows, pooledCols, channels]`
     * - If `dataFormat=CHANNEL_FIRST`:
     * 4D tensor with shape:
     * `[batchSize, channels, pooledRows, pooledCols]`
52627 *
52628 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
52629 */
52630 function maxPooling2d(args) {
52631 return new MaxPooling2D(args);
52632 }
52633 /**
52634 * Max pooling operation for 3D data.
52635 *
52636 * Input shape
52637 * - If `dataFormat === channelsLast`:
52638 * 5D tensor with shape:
52639 * `[batchSize, depths, rows, cols, channels]`
52640 * - If `dataFormat === channelsFirst`:
52641 * 5D tensor with shape:
52642 * `[batchSize, channels, depths, rows, cols]`
52643 *
52644 * Output shape
52645 * - If `dataFormat=channelsLast`:
52646 * 5D tensor with shape:
52647 * `[batchSize, pooledDepths, pooledRows, pooledCols, channels]`
52648 * - If `dataFormat=channelsFirst`:
52649 * 5D tensor with shape:
52650 * `[batchSize, channels, pooledDepths, pooledRows, pooledCols]`
52651 *
52652 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
52653 */
52654 function maxPooling3d(args) {
52655 return new MaxPooling3D(args);
52656 }
52657 // Recurrent Layers.
52658 /**
52659 * Gated Recurrent Unit - Cho et al. 2014.
52660 *
52661 * This is an `RNN` layer consisting of one `GRUCell`. However, unlike
     * the underlying `GRUCell`, the `apply` method of `GRU` operates
52663 * on a sequence of inputs. The shape of the input (not including the first,
52664 * batch dimension) needs to be at least 2-D, with the first dimension being
52665 * time steps. For example:
52666 *
52667 * ```js
52668 * const rnn = tf.layers.gru({units: 8, returnSequences: true});
52669 *
52670 * // Create an input with 10 time steps.
52671 * const input = tf.input({shape: [10, 20]});
52672 * const output = rnn.apply(input);
52673 *
52674 * console.log(JSON.stringify(output.shape));
52675 * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
52676 * // same as the sequence length of `input`, due to `returnSequences`: `true`;
52677 * // 3rd dimension is the `GRUCell`'s number of units.
     * ```
     *
52679 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
52680 */
52681 function gru(args) {
52682 return new GRU(args);
52683 }
52684 /**
52685 * Cell class for `GRU`.
52686 *
52687 * `GRUCell` is distinct from the `RNN` subclass `GRU` in that its
52688 * `apply` method takes the input data of only a single time step and returns
52689 * the cell's output at the time step, while `GRU` takes the input data
52690 * over a number of time steps. For example:
52691 *
52692 * ```js
52693 * const cell = tf.layers.gruCell({units: 2});
52694 * const input = tf.input({shape: [10]});
52695 * const output = cell.apply(input);
52696 *
52697 * console.log(JSON.stringify(output.shape));
52698 * // [null, 10]: This is the cell's output at a single time step. The 1st
52699 * // dimension is the unknown batch size.
52700 * ```
52701 *
52702 * Instance(s) of `GRUCell` can be used to construct `RNN` layers. The
52703 * most typical use of this workflow is to combine a number of cells into a
52704 * stacked RNN cell (i.e., `StackedRNNCell` internally) and use it to create an
52705 * RNN. For example:
52706 *
52707 * ```js
52708 * const cells = [
52709 * tf.layers.gruCell({units: 4}),
52710 * tf.layers.gruCell({units: 8}),
52711 * ];
52712 * const rnn = tf.layers.rnn({cell: cells, returnSequences: true});
52713 *
52714 * // Create an input with 10 time steps and a length-20 vector at each step.
52715 * const input = tf.input({shape: [10, 20]});
52716 * const output = rnn.apply(input);
52717 *
52718 * console.log(JSON.stringify(output.shape));
52719 * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
52720 * // same as the sequence length of `input`, due to `returnSequences`: `true`;
52721 * // 3rd dimension is the last `gruCell`'s number of units.
52722 * ```
52723 *
52724 * To create an `RNN` consisting of only *one* `GRUCell`, use the
52725 * `tf.layers.gru`.
52726 *
52727 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
52728 */
52729 function gruCell(args) {
52730 return new GRUCell(args);
52731 }
52732 /**
52733 * Long-Short Term Memory layer - Hochreiter 1997.
52734 *
52735 * This is an `RNN` layer consisting of one `LSTMCell`. However, unlike
52736 * the underlying `LSTMCell`, the `apply` method of `LSTM` operates
52737 * on a sequence of inputs. The shape of the input (not including the first,
52738 * batch dimension) needs to be at least 2-D, with the first dimension being
52739 * time steps. For example:
52740 *
52741 * ```js
52742 * const lstm = tf.layers.lstm({units: 8, returnSequences: true});
52743 *
52744 * // Create an input with 10 time steps.
52745 * const input = tf.input({shape: [10, 20]});
52746 * const output = lstm.apply(input);
52747 *
52748 * console.log(JSON.stringify(output.shape));
52749 * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
52750 * // same as the sequence length of `input`, due to `returnSequences`: `true`;
52751 * // 3rd dimension is the `LSTMCell`'s number of units.
     * ```
     *
52753 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
52754 */
52755 function lstm(args) {
52756 return new LSTM(args);
52757 }
52758 /**
52759 * Cell class for `LSTM`.
52760 *
52761 * `LSTMCell` is distinct from the `RNN` subclass `LSTM` in that its
52762 * `apply` method takes the input data of only a single time step and returns
52763 * the cell's output at the time step, while `LSTM` takes the input data
52764 * over a number of time steps. For example:
52765 *
52766 * ```js
52767 * const cell = tf.layers.lstmCell({units: 2});
52768 * const input = tf.input({shape: [10]});
52769 * const output = cell.apply(input);
52770 *
52771 * console.log(JSON.stringify(output.shape));
52772 * // [null, 10]: This is the cell's output at a single time step. The 1st
52773 * // dimension is the unknown batch size.
52774 * ```
52775 *
52776 * Instance(s) of `LSTMCell` can be used to construct `RNN` layers. The
52777 * most typical use of this workflow is to combine a number of cells into a
52778 * stacked RNN cell (i.e., `StackedRNNCell` internally) and use it to create an
52779 * RNN. For example:
52780 *
52781 * ```js
52782 * const cells = [
52783 * tf.layers.lstmCell({units: 4}),
52784 * tf.layers.lstmCell({units: 8}),
52785 * ];
52786 * const rnn = tf.layers.rnn({cell: cells, returnSequences: true});
52787 *
52788 * // Create an input with 10 time steps and a length-20 vector at each step.
52789 * const input = tf.input({shape: [10, 20]});
52790 * const output = rnn.apply(input);
52791 *
52792 * console.log(JSON.stringify(output.shape));
52793 * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
52794 * // same as the sequence length of `input`, due to `returnSequences`: `true`;
52795 * // 3rd dimension is the last `lstmCell`'s number of units.
52796 * ```
52797 *
52798 * To create an `RNN` consisting of only *one* `LSTMCell`, use the
52799 * `tf.layers.lstm`.
52800 *
52801 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
52802 */
52803 function lstmCell(args) {
52804 return new LSTMCell(args);
52805 }
52806 /**
52807 * Fully-connected RNN where the output is to be fed back to input.
52808 *
52809 * This is an `RNN` layer consisting of one `SimpleRNNCell`. However, unlike
52810 * the underlying `SimpleRNNCell`, the `apply` method of `SimpleRNN` operates
52811 * on a sequence of inputs. The shape of the input (not including the first,
52812 * batch dimension) needs to be at least 2-D, with the first dimension being
52813 * time steps. For example:
52814 *
52815 * ```js
52816 * const rnn = tf.layers.simpleRNN({units: 8, returnSequences: true});
52817 *
52818 * // Create an input with 10 time steps.
52819 * const input = tf.input({shape: [10, 20]});
52820 * const output = rnn.apply(input);
52821 *
52822 * console.log(JSON.stringify(output.shape));
52823 * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
52824 * // same as the sequence length of `input`, due to `returnSequences`: `true`;
52825 * // 3rd dimension is the `SimpleRNNCell`'s number of units.
52826 * ```
52827 *
52828 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
52829 */
52830 function simpleRNN(args) {
52831 return new SimpleRNN(args);
52832 }
52833 /**
52834 * Cell class for `SimpleRNN`.
52835 *
52836 * `SimpleRNNCell` is distinct from the `RNN` subclass `SimpleRNN` in that its
52837 * `apply` method takes the input data of only a single time step and returns
52838 * the cell's output at the time step, while `SimpleRNN` takes the input data
52839 * over a number of time steps. For example:
52840 *
52841 * ```js
52842 * const cell = tf.layers.simpleRNNCell({units: 2});
52843 * const input = tf.input({shape: [10]});
52844 * const output = cell.apply(input);
52845 *
52846 * console.log(JSON.stringify(output.shape));
52847 * // [null, 10]: This is the cell's output at a single time step. The 1st
52848 * // dimension is the unknown batch size.
52849 * ```
52850 *
52851 * Instance(s) of `SimpleRNNCell` can be used to construct `RNN` layers. The
52852 * most typical use of this workflow is to combine a number of cells into a
52853 * stacked RNN cell (i.e., `StackedRNNCell` internally) and use it to create an
52854 * RNN. For example:
52855 *
52856 * ```js
52857 * const cells = [
52858 * tf.layers.simpleRNNCell({units: 4}),
52859 * tf.layers.simpleRNNCell({units: 8}),
52860 * ];
52861 * const rnn = tf.layers.rnn({cell: cells, returnSequences: true});
52862 *
52863 * // Create an input with 10 time steps and a length-20 vector at each step.
52864 * const input = tf.input({shape: [10, 20]});
52865 * const output = rnn.apply(input);
52866 *
52867 * console.log(JSON.stringify(output.shape));
52868 * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
52869 * // same as the sequence length of `input`, due to `returnSequences`: `true`;
52870 * // 3rd dimension is the last `SimpleRNNCell`'s number of units.
52871 * ```
52872 *
52873 * To create an `RNN` consisting of only *one* `SimpleRNNCell`, use the
52874 * `tf.layers.simpleRNN`.
52875 *
52876 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
52877 */
52878 function simpleRNNCell(args) {
52879 return new SimpleRNNCell(args);
52880 }
52881 /**
52882 * Convolutional LSTM layer - Xingjian Shi 2015.
52883 *
52884 * This is an `ConvRNN2D` layer consisting of one `ConvLSTM2DCell`. However,
52885 * unlike the underlying `ConvLSTM2DCell`, the `apply` method of `ConvLSTM2D`
52886 * operates on a sequence of inputs. The shape of the input (not including the
52887 * first, batch dimension) needs to be 4-D, with the first dimension being time
52888 * steps. For example:
52889 *
52890 * ```js
52891 * const filters = 3;
52892 * const kernelSize = 3;
52893 *
52894 * const batchSize = 4;
52895 * const sequenceLength = 2;
52896 * const size = 5;
52897 * const channels = 3;
52898 *
52899 * const inputShape = [batchSize, sequenceLength, size, size, channels];
52900 * const input = tf.ones(inputShape);
52901 *
52902 * const layer = tf.layers.convLstm2d({filters, kernelSize});
52903 *
52904 * const output = layer.apply(input);
52905 * ```
52906 */
52907 /** @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'} */
52908 function convLstm2d(args) {
52909 return new ConvLSTM2D(args);
52910 }
52911 /**
52912 * Cell class for `ConvLSTM2D`.
52913 *
52914 * `ConvLSTM2DCell` is distinct from the `ConvRNN2D` subclass `ConvLSTM2D` in
52915 * that its `call` method takes the input data of only a single time step and
52916 * returns the cell's output at the time step, while `ConvLSTM2D` takes the
52917 * input data over a number of time steps. For example:
52918 *
52919 * ```js
52920 * const filters = 3;
52921 * const kernelSize = 3;
52922 *
52923 * const sequenceLength = 1;
52924 * const size = 5;
52925 * const channels = 3;
52926 *
52927 * const inputShape = [sequenceLength, size, size, channels];
52928 * const input = tf.ones(inputShape);
52929 *
52930 * const cell = tf.layers.convLstm2dCell({filters, kernelSize});
52931 *
52932 * cell.build(input.shape);
52933 *
52934 * const outputSize = size - kernelSize + 1;
52935 * const outShape = [sequenceLength, outputSize, outputSize, filters];
52936 *
52937 * const initialH = tf.zeros(outShape);
52938 * const initialC = tf.zeros(outShape);
52939 *
52940 * const [o, h, c] = cell.call([input, initialH, initialC], {});
52941 * ```
52942 */
52943 /** @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'} */
52944 function convLstm2dCell(args) {
52945 return new ConvLSTM2DCell(args);
52946 }
52947 /**
52948 * Base class for recurrent layers.
52949 *
52950 * Input shape:
52951 * 3D tensor with shape `[batchSize, timeSteps, inputDim]`.
52952 *
52953 * Output shape:
52954 * - if `returnState`, an Array of tensors (i.e., `tf.Tensor`s). The first
52955 * tensor is the output. The remaining tensors are the states at the
52956 * last time step, each with shape `[batchSize, units]`.
52957 * - if `returnSequences`, the output will have shape
52958 * `[batchSize, timeSteps, units]`.
52959 * - else, the output will have shape `[batchSize, units]`.
52960 *
52961 * Masking:
52962 * This layer supports masking for input data with a variable number
52963 * of timesteps. To introduce masks to your data,
52964 * use an embedding layer with the `mask_zero` parameter
52965 * set to `True`.
52966 *
52967 * Notes on using statefulness in RNNs:
52968 * You can set RNN layers to be 'stateful', which means that the states
52969 * computed for the samples in one batch will be reused as initial states
52970 * for the samples in the next batch. This assumes a one-to-one mapping
52971 * between samples in different successive batches.
52972 *
52973 * To enable statefulness:
52974 * - specify `stateful: true` in the layer constructor.
52975 * - specify a fixed batch size for your model, by passing
52976 * if sequential model:
52977 * `batchInputShape=[...]` to the first layer in your model.
52978 * else for functional model with 1 or more Input layers:
52979 * `batchShape=[...]` to all the first layers in your model.
52980 * This is the expected shape of your inputs *including the batch size*.
52981 * It should be a tuple of integers, e.g. `(32, 10, 100)`.
52982 * - specify `shuffle=False` when calling fit().
52983 *
52984 * To reset the states of your model, call `.resetStates()` on either
52985 * a specific layer, or on your entire model.
52986 *
52987 * Note on specifying the initial state of RNNs
52988 * You can specify the initial state of RNN layers symbolically by
52989 * calling them with the option `initialState`. The value of
52990 * `initialState` should be a tensor or list of tensors representing
52991 * the initial state of the RNN layer.
52992 *
52993 * You can specify the initial state of RNN layers numerically by
52994 * calling `resetStates` with the keyword argument `states`. The value of
52995 * `states` should be a numpy array or list of numpy arrays representing
52996 * the initial state of the RNN layer.
52997 *
52998 * Note on passing external constants to RNNs
52999 * You can pass "external" constants to the cell using the `constants`
53000 * keyword argument of `RNN.call` method. This requires that the `cell.call`
53001 * method accepts the same keyword argument `constants`. Such constants
53002 * can be used to conditon the cell transformation on additional static inputs
53003 * (not changing over time), a.k.a an attention mechanism.
53004 *
53005 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
53006 */
53007 function rnn$1(args) {
53008 return new RNN(args);
53009 }
53010 /**
53011 * Wrapper allowing a stack of RNN cells to behave as a single cell.
53012 *
53013 * Used to implement efficient stacked RNNs.
53014 *
53015 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
53016 */
53017 function stackedRNNCells(args) {
53018 return new StackedRNNCells(args);
53019 }
53020 // Wrapper Layers.
53021 /** @doc {heading: 'Layers', subheading: 'Wrapper', namespace: 'layers'} */
53022 function bidirectional(args) {
53023 return new Bidirectional(args);
53024 }
53025 /**
53026 * This wrapper applies a layer to every temporal slice of an input.
53027 *
53028 * The input should be at least 3D, and the dimension of the index `1` will be
53029 * considered to be the temporal dimension.
53030 *
53031 * Consider a batch of 32 samples, where each sample is a sequence of 10 vectors
53032 * of 16 dimensions. The batch input shape of the layer is then `[32, 10,
53033 * 16]`, and the `inputShape`, not including the sample dimension, is
53034 * `[10, 16]`.
53035 *
53036 * You can then use `TimeDistributed` to apply a `Dense` layer to each of the 10
53037 * timesteps, independently:
53038 *
53039 * ```js
53040 * const model = tf.sequential();
53041 * model.add(tf.layers.timeDistributed({
53042 * layer: tf.layers.dense({units: 8}),
53043 * inputShape: [10, 16],
53044 * }));
53045 *
53046 * // Now model.outputShape = [null, 10, 8].
53047 * // The output will then have shape `[32, 10, 8]`.
53048 *
53049 * // In subsequent layers, there is no need for `inputShape`:
53050 * model.add(tf.layers.timeDistributed({layer: tf.layers.dense({units: 32})}));
53051 * console.log(JSON.stringify(model.outputs[0].shape));
53052 * // Now model.outputShape = [null, 10, 32].
53053 * ```
53054 *
53055 * The output will then have shape `[32, 10, 32]`.
53056 *
53057 * `TimeDistributed` can be used with arbitrary layers, not just `Dense`, for
53058 * instance a `Conv2D` layer.
53059 *
53060 * ```js
53061 * const model = tf.sequential();
53062 * model.add(tf.layers.timeDistributed({
53063 * layer: tf.layers.conv2d({filters: 64, kernelSize: [3, 3]}),
53064 * inputShape: [10, 299, 299, 3],
53065 * }));
53066 * console.log(JSON.stringify(model.outputs[0].shape));
53067 * ```
53068 *
53069 * @doc {heading: 'Layers', subheading: 'Wrapper', namespace: 'layers'}
53070 */
53071 function timeDistributed(args) {
53072 return new TimeDistributed(args);
53073 }
    // Aliases for pooling: short factory names that mirror the canonical
    // `*Pooling*` factories defined earlier in this module.
    const globalMaxPool1d = globalMaxPooling1d;
    const globalMaxPool2d = globalMaxPooling2d;
    const maxPool1d = maxPooling1d;
    const maxPool2d = maxPooling2d;
53079 /**
53080 * Apply additive zero-centered Gaussian noise.
53081 *
53082 * As it is a regularization layer, it is only active at training time.
53083 *
53084 * This is useful to mitigate overfitting
53085 * (you could see it as a form of random data augmentation).
53086 * Gaussian Noise (GS) is a natural choice as corruption process
53087 * for real valued inputs.
53088 *
53089 * # Arguments
53090 * stddev: float, standard deviation of the noise distribution.
53091 *
53092 * # Input shape
53093 * Arbitrary. Use the keyword argument `input_shape`
53094 * (tuple of integers, does not include the samples axis)
53095 * when using this layer as the first layer in a model.
53096 *
53097 * # Output shape
53098 * Same shape as input.
53099 *
53100 * @doc {heading: 'Layers', subheading: 'Noise', namespace: 'layers'}
53101 */
53102 function gaussianNoise(args) {
53103 return new GaussianNoise(args);
53104 }
53105 /**
53106 * Apply multiplicative 1-centered Gaussian noise.
53107 *
53108 * As it is a regularization layer, it is only active at training time.
53109 *
53110 * Arguments:
53111 * - `rate`: float, drop probability (as with `Dropout`).
53112 * The multiplicative noise will have
53113 * standard deviation `sqrt(rate / (1 - rate))`.
53114 *
53115 * Input shape:
53116 * Arbitrary. Use the keyword argument `inputShape`
53117 * (tuple of integers, does not include the samples axis)
53118 * when using this layer as the first layer in a model.
53119 *
53120 * Output shape:
53121 * Same shape as input.
53122 *
53123 * References:
53124 * - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting](
53125 * http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf)
53126 *
53127 * @doc {heading: 'Layers', subheading: 'Noise', namespace: 'layers'}
53128 */
53129 function gaussianDropout(args) {
53130 return new GaussianDropout(args);
53131 }
53132 /**
53133 * Applies Alpha Dropout to the input.
53134 *
53135 * As it is a regularization layer, it is only active at training time.
53136 *
53137 * Alpha Dropout is a `Dropout` that keeps mean and variance of inputs
53138 * to their original values, in order to ensure the self-normalizing property
53139 * even after this dropout.
53140 * Alpha Dropout fits well to Scaled Exponential Linear Units
53141 * by randomly setting activations to the negative saturation value.
53142 *
53143 * Arguments:
53144 * - `rate`: float, drop probability (as with `Dropout`).
53145 * The multiplicative noise will have
53146 * standard deviation `sqrt(rate / (1 - rate))`.
53147 * - `noise_shape`: A 1-D `Tensor` of type `int32`, representing the
53148 * shape for randomly generated keep/drop flags.
53149 *
53150 * Input shape:
53151 * Arbitrary. Use the keyword argument `inputShape`
53152 * (tuple of integers, does not include the samples axis)
53153 * when using this layer as the first layer in a model.
53154 *
53155 * Output shape:
53156 * Same shape as input.
53157 *
53158 * References:
53159 * - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
53160 *
53161 * @doc {heading: 'Layers', subheading: 'Noise', namespace: 'layers'}
53162 */
53163 function alphaDropout(args) {
53164 return new AlphaDropout(args);
53165 }
53166 /**
53167 * Masks a sequence by using a mask value to skip timesteps.
53168 *
53169 * If all features for a given sample timestep are equal to `mask_value`,
53170 * then the sample timestep will be masked (skipped) in all downstream layers
53171 * (as long as they support masking).
53172 *
53173 * If any downstream layer does not support masking yet receives such
53174 * an input mask, an exception will be raised.
53175 *
53176 * Arguments:
53177 * - `maskValue`: Either None or mask value to skip.
53178 *
53179 * Input shape:
53180 * Arbitrary. Use the keyword argument `inputShape`
53181 * (tuple of integers, does not include the samples axis)
53182 * when using this layer as the first layer in a model.
53183 *
53184 * Output shape:
53185 * Same shape as input.
53186 *
53187 * @doc {heading: 'Layers', subheading: 'Mask', namespace: 'layers'}
53188 */
53189 function masking(args) {
53190 return new Masking(args);
53191 }
53192
    // Frozen namespace object backing `tf.layers`: maps the public layer
    // factory names to their implementations. Local names carrying `$N`
    // suffixes are artifacts of bundling/deduplication; the exported key is
    // the canonical public name.
    var exports_layers = /*#__PURE__*/Object.freeze({
        __proto__: null,
        inputLayer: inputLayer,
        elu: elu$2,
        reLU: reLU,
        leakyReLU: leakyReLU,
        prelu: prelu$1,
        softmax: softmax$1,
        thresholdedReLU: thresholdedReLU,
        conv1d: conv1d$2,
        conv2d: conv2d$3,
        conv2dTranspose: conv2dTranspose$1,
        conv3d: conv3d$2,
        conv3dTranspose: conv3dTranspose$1,
        separableConv2d: separableConv2d$1,
        cropping2D: cropping2D,
        upSampling2d: upSampling2d,
        depthwiseConv2d: depthwiseConv2d$3,
        activation: activation,
        dense: dense,
        dropout: dropout$2,
        spatialDropout1d: spatialDropout1d,
        flatten: flatten$2,
        repeatVector: repeatVector,
        reshape: reshape$1,
        permute: permute,
        embedding: embedding,
        add: add$3,
        average: average$1,
        concatenate: concatenate$2,
        maximum: maximum$2,
        minimum: minimum$2,
        multiply: multiply$1,
        dot: dot$2,
        batchNormalization: batchNormalization$1,
        layerNormalization: layerNormalization,
        zeroPadding2d: zeroPadding2d,
        averagePooling1d: averagePooling1d,
        avgPool1d: avgPool1d,
        avgPooling1d: avgPooling1d,
        averagePooling2d: averagePooling2d,
        avgPool2d: avgPool2d,
        avgPooling2d: avgPooling2d,
        averagePooling3d: averagePooling3d,
        avgPool3d: avgPool3d$1,
        avgPooling3d: avgPooling3d,
        globalAveragePooling1d: globalAveragePooling1d,
        globalAveragePooling2d: globalAveragePooling2d,
        globalMaxPooling1d: globalMaxPooling1d,
        globalMaxPooling2d: globalMaxPooling2d,
        maxPooling1d: maxPooling1d,
        maxPooling2d: maxPooling2d,
        maxPooling3d: maxPooling3d,
        gru: gru,
        gruCell: gruCell,
        lstm: lstm,
        lstmCell: lstmCell,
        simpleRNN: simpleRNN,
        simpleRNNCell: simpleRNNCell,
        convLstm2d: convLstm2d,
        convLstm2dCell: convLstm2dCell,
        rnn: rnn$1,
        stackedRNNCells: stackedRNNCells,
        bidirectional: bidirectional,
        timeDistributed: timeDistributed,
        globalMaxPool1d: globalMaxPool1d,
        globalMaxPool2d: globalMaxPool2d,
        maxPool1d: maxPool1d,
        maxPool2d: maxPool2d,
        Layer: Layer,
        RNN: RNN,
        RNNCell: RNNCell,
        input: input,
        gaussianNoise: gaussianNoise,
        gaussianDropout: gaussianDropout,
        alphaDropout: alphaDropout,
        masking: masking
    });
53271
53272 /**
53273 * Binary accuracy metric function.
53274 *
53275 * `yTrue` and `yPred` can have 0-1 values. Example:
53276 * ```js
53277 * const x = tf.tensor2d([[1, 1, 1, 1], [0, 0, 0, 0]], [2, 4]);
53278 * const y = tf.tensor2d([[1, 0, 1, 0], [0, 0, 0, 1]], [2, 4]);
53279 * const accuracy = tf.metrics.binaryAccuracy(x, y);
53280 * accuracy.print();
53281 * ```
53282 *
53283 * `yTrue` and `yPred` can also have floating-number values between 0 and 1, in
53284 * which case the values will be thresholded at 0.5 to yield 0-1 values (i.e.,
53285 * a value >= 0.5 and <= 1.0 is interpreted as 1.
53286 * )
53287 * Example:
53288 * ```js
53289 * const x = tf.tensor1d([1, 1, 1, 1, 0, 0, 0, 0]);
53290 * const y = tf.tensor1d([0.2, 0.4, 0.6, 0.8, 0.2, 0.3, 0.4, 0.7]);
53291 * const accuracy = tf.metrics.binaryAccuracy(x, y);
53292 * accuracy.print();
53293 * ```
53294 *
53295 * @param yTrue Binary Tensor of truth.
53296 * @param yPred Binary Tensor of prediction.
53297 * @return Accuracy Tensor.
53298 *
53299 * @doc {heading: 'Metrics', namespace: 'metrics'}
53300 */
53301 function binaryAccuracy$1(yTrue, yPred) {
53302 return binaryAccuracy(yTrue, yPred);
53303 }
53304 /**
53305 * Binary crossentropy metric function.
53306 *
53307 * Example:
53308 * ```js
53309 * const x = tf.tensor2d([[0], [1], [1], [1]]);
53310 * const y = tf.tensor2d([[0], [0], [0.5], [1]]);
53311 * const crossentropy = tf.metrics.binaryCrossentropy(x, y);
53312 * crossentropy.print();
53313 * ```
53314 *
53315 * @param yTrue Binary Tensor of truth.
53316 * @param yPred Binary Tensor of prediction, probabilities for the `1` case.
53317 * @return Accuracy Tensor.
53318 *
53319 * @doc {heading: 'Metrics', namespace: 'metrics'}
53320 */
53321 function binaryCrossentropy$2(yTrue, yPred) {
53322 return binaryCrossentropy$1(yTrue, yPred);
53323 }
53324 /**
53325 * Sparse categorical accuracy metric function.
53326 *
53327 * Example:
53328 * ```js
53329 *
53330 * const yTrue = tf.tensor1d([1, 1, 2, 2, 0]);
53331 * const yPred = tf.tensor2d(
53332 * [[0, 1, 0], [1, 0, 0], [0, 0.4, 0.6], [0, 0.6, 0.4], [0.7, 0.3, 0]]);
53333 * const crossentropy = tf.metrics.sparseCategoricalAccuracy(yTrue, yPred);
53334 * crossentropy.print();
53335 * ```
53336 *
53337 * @param yTrue True labels: indices.
53338 * @param yPred Predicted probabilities or logits.
53339 * @returns Accuracy tensor.
53340 *
53341 * @doc {heading: 'Metrics', namespace: 'metrics'}
53342 */
53343 function sparseCategoricalAccuracy$1(yTrue, yPred) {
53344 return sparseCategoricalAccuracy(yTrue, yPred);
53345 }
53346 /**
53347 * Categorical accuracy metric function.
53348 *
53349 * Example:
53350 * ```js
53351 * const x = tf.tensor2d([[0, 0, 0, 1], [0, 0, 0, 1]]);
53352 * const y = tf.tensor2d([[0.1, 0.8, 0.05, 0.05], [0.1, 0.05, 0.05, 0.8]]);
53353 * const accuracy = tf.metrics.categoricalAccuracy(x, y);
53354 * accuracy.print();
53355 * ```
53356 *
53357 * @param yTrue Binary Tensor of truth: one-hot encoding of categories.
53358 * @param yPred Binary Tensor of prediction: probabilities or logits for the
53359 * same categories as in `yTrue`.
53360 * @return Accuracy Tensor.
53361 *
53362 * @doc {heading: 'Metrics', namespace: 'metrics'}
53363 */
53364 function categoricalAccuracy$1(yTrue, yPred) {
53365 return categoricalAccuracy(yTrue, yPred);
53366 }
53367 /**
53368 * Categorical crossentropy between an output tensor and a target tensor.
53369 *
53370 * @param target A tensor of the same shape as `output`.
53371 * @param output A tensor resulting from a softmax (unless `fromLogits` is
53372 * `true`, in which case `output` is expected to be the logits).
53373 * @param fromLogits Boolean, whether `output` is the result of a softmax, or is
53374 * a tensor of logits.
53375 *
53376 * @doc {heading: 'Metrics', namespace: 'metrics'}
53377 */
53378 function categoricalCrossentropy$2(yTrue, yPred) {
53379 return categoricalCrossentropy$1(yTrue, yPred);
53380 }
53381 /**
53382 * Computes the precision of the predictions with respect to the labels.
53383 *
53384 * Example:
53385 * ```js
53386 * const x = tf.tensor2d(
53387 * [
53388 * [0, 0, 0, 1],
53389 * [0, 1, 0, 0],
53390 * [0, 0, 0, 1],
53391 * [1, 0, 0, 0],
53392 * [0, 0, 1, 0]
53393 * ]
53394 * );
53395 *
53396 * const y = tf.tensor2d(
53397 * [
53398 * [0, 0, 1, 0],
53399 * [0, 1, 0, 0],
53400 * [0, 0, 0, 1],
53401 * [0, 1, 0, 0],
53402 * [0, 1, 0, 0]
53403 * ]
53404 * );
53405 *
53406 * const precision = tf.metrics.precision(x, y);
53407 * precision.print();
53408 * ```
53409 *
53410 * @param yTrue The ground truth values. Expected to be contain only 0-1 values.
53411 * @param yPred The predicted values. Expected to be contain only 0-1 values.
53412 * @return Precision Tensor.
53413 *
53414 * @doc {heading: 'Metrics', namespace: 'metrics'}
53415 */
53416 function precision$1(yTrue, yPred) {
53417 return precision(yTrue, yPred);
53418 }
53419 /**
53420 * Computes the recall of the predictions with respect to the labels.
53421 *
53422 * Example:
53423 * ```js
53424 * const x = tf.tensor2d(
53425 * [
53426 * [0, 0, 0, 1],
53427 * [0, 1, 0, 0],
53428 * [0, 0, 0, 1],
53429 * [1, 0, 0, 0],
53430 * [0, 0, 1, 0]
53431 * ]
53432 * );
53433 *
53434 * const y = tf.tensor2d(
53435 * [
53436 * [0, 0, 1, 0],
53437 * [0, 1, 0, 0],
53438 * [0, 0, 0, 1],
53439 * [0, 1, 0, 0],
53440 * [0, 1, 0, 0]
53441 * ]
53442 * );
53443 *
53444 * const recall = tf.metrics.recall(x, y);
53445 * recall.print();
53446 * ```
53447 *
53448 * @param yTrue The ground truth values. Expected to be contain only 0-1 values.
53449 * @param yPred The predicted values. Expected to be contain only 0-1 values.
53450 * @return Recall Tensor.
53451 *
53452 * @doc {heading: 'Metrics', namespace: 'metrics'}
53453 */
53454 function recall$1(yTrue, yPred) {
53455 return recall(yTrue, yPred);
53456 }
53457 /**
53458 * Loss or metric function: Cosine proximity.
53459 *
53460 * Mathematically, cosine proximity is defined as:
53461 * `-sum(l2Normalize(yTrue) * l2Normalize(yPred))`,
53462 * wherein `l2Normalize()` normalizes the L2 norm of the input to 1 and `*`
53463 * represents element-wise multiplication.
53464 *
53465 * ```js
53466 * const yTrue = tf.tensor2d([[1, 0], [1, 0]]);
53467 * const yPred = tf.tensor2d([[1 / Math.sqrt(2), 1 / Math.sqrt(2)], [0, 1]]);
53468 * const proximity = tf.metrics.cosineProximity(yTrue, yPred);
53469 * proximity.print();
53470 * ```
53471 *
53472 * @param yTrue Truth Tensor.
53473 * @param yPred Prediction Tensor.
53474 * @return Cosine proximity Tensor.
53475 *
53476 * @doc {heading: 'Metrics', namespace: 'metrics'}
53477 */
53478 function cosineProximity$1(yTrue, yPred) {
53479 return cosineProximity(yTrue, yPred);
53480 }
53481 /**
53482 * Loss or metric function: Mean absolute error.
53483 *
53484 * Mathematically, mean absolute error is defined as:
53485 * `mean(abs(yPred - yTrue))`,
53486 * wherein the `mean` is applied over feature dimensions.
53487 *
53488 * ```js
53489 * const yTrue = tf.tensor2d([[0, 1], [0, 0], [2, 3]]);
53490 * const yPred = tf.tensor2d([[0, 1], [0, 1], [-2, -3]]);
53491 * const mse = tf.metrics.meanAbsoluteError(yTrue, yPred);
53492 * mse.print();
53493 * ```
53494 *
53495 * @param yTrue Truth Tensor.
53496 * @param yPred Prediction Tensor.
53497 * @return Mean absolute error Tensor.
53498 *
53499 * @doc {heading: 'Metrics', namespace: 'metrics'}
53500 */
53501 function meanAbsoluteError$1(yTrue, yPred) {
53502 return meanAbsoluteError(yTrue, yPred);
53503 }
53504 /**
53505 * Loss or metric function: Mean absolute percentage error.
53506 *
53507 * ```js
53508 * const yTrue = tf.tensor2d([[0, 1], [10, 20]]);
53509 * const yPred = tf.tensor2d([[0, 1], [11, 24]]);
53510 * const mse = tf.metrics.meanAbsolutePercentageError(yTrue, yPred);
53511 * mse.print();
53512 * ```
53513 *
53514 * Aliases: `tf.metrics.MAPE`, `tf.metrics.mape`.
53515 *
53516 * @param yTrue Truth Tensor.
53517 * @param yPred Prediction Tensor.
53518 * @return Mean absolute percentage error Tensor.
53519 *
53520 * @doc {heading: 'Metrics', namespace: 'metrics'}
53521 */
53522 function meanAbsolutePercentageError$1(yTrue, yPred) {
53523 return meanAbsolutePercentageError(yTrue, yPred);
53524 }
    // Aliases of `meanAbsolutePercentageError`, exported as
    // `tf.metrics.MAPE` and `tf.metrics.mape`.
    function MAPE$2(yTrue, yPred) {
        return meanAbsolutePercentageError(yTrue, yPred);
    }
    function mape$2(yTrue, yPred) {
        return meanAbsolutePercentageError(yTrue, yPred);
    }
53531 /**
53532 * Loss or metric function: Mean squared error.
53533 *
53534 * ```js
53535 * const yTrue = tf.tensor2d([[0, 1], [3, 4]]);
53536 * const yPred = tf.tensor2d([[0, 1], [-3, -4]]);
53537 * const mse = tf.metrics.meanSquaredError(yTrue, yPred);
53538 * mse.print();
53539 * ```
53540 *
53541 * Aliases: `tf.metrics.MSE`, `tf.metrics.mse`.
53542 *
53543 * @param yTrue Truth Tensor.
53544 * @param yPred Prediction Tensor.
53545 * @return Mean squared error Tensor.
53546 *
53547 * @doc {heading: 'Metrics', namespace: 'metrics'}
53548 */
53549 function meanSquaredError$2(yTrue, yPred) {
53550 return meanSquaredError$1(yTrue, yPred);
53551 }
    // Aliases of `meanSquaredError`, exported as `tf.metrics.MSE` and
    // `tf.metrics.mse`.
    function MSE$2(yTrue, yPred) {
        return meanSquaredError$1(yTrue, yPred);
    }
    function mse$2(yTrue, yPred) {
        return meanSquaredError$1(yTrue, yPred);
    }
53558
    // Frozen namespace object backing `tf.metrics`: maps the public metric
    // names to the wrapper functions defined above.
    var exports_metrics = /*#__PURE__*/Object.freeze({
        __proto__: null,
        binaryAccuracy: binaryAccuracy$1,
        binaryCrossentropy: binaryCrossentropy$2,
        sparseCategoricalAccuracy: sparseCategoricalAccuracy$1,
        categoricalAccuracy: categoricalAccuracy$1,
        categoricalCrossentropy: categoricalCrossentropy$2,
        precision: precision$1,
        recall: recall$1,
        cosineProximity: cosineProximity$1,
        meanAbsoluteError: meanAbsoluteError$1,
        meanAbsolutePercentageError: meanAbsolutePercentageError$1,
        MAPE: MAPE$2,
        mape: mape$2,
        meanSquaredError: meanSquaredError$2,
        MSE: MSE$2,
        mse: mse$2
    });
53577
53578 /**
53579 * @license
53580 * Copyright 2018 Google LLC
53581 *
53582 * Use of this source code is governed by an MIT-style
53583 * license that can be found in the LICENSE file or at
53584 * https://opensource.org/licenses/MIT.
53585 * =============================================================================
53586 */
53587
    // Frozen namespace object backing `tf.models`.
    var exports_models = /*#__PURE__*/Object.freeze({
        __proto__: null,
        modelFromJSON: modelFromJSON
    });
53592
53593 /**
53594 * @license
53595 * Copyright 2018 Google LLC
53596 *
53597 * Use of this source code is governed by an MIT-style
53598 * license that can be found in the LICENSE file or at
53599 * https://opensource.org/licenses/MIT.
53600 * =============================================================================
53601 */
53602 /**
53603 * Regularizer for L1 and L2 regularization.
53604 *
53605 * Adds a term to the loss to penalize large weights:
53606 * loss += sum(l1 * abs(x)) + sum(l2 * x^2)
53607 *
53608 * @doc {heading: 'Regularizers', namespace: 'regularizers'}
53609 */
53610 function l1l2(config) {
53611 return new L1L2(config);
53612 }
53613 /**
53614 * Regularizer for L1 regularization.
53615 *
53616 * Adds a term to the loss to penalize large weights:
53617 * loss += sum(l1 * abs(x))
53618 * @param args l1 config.
53619 *
53620 * @doc {heading: 'Regularizers', namespace: 'regularizers'}
53621 */
53622 function l1$1(config) {
53623 return l1(config);
53624 }
53625 /**
53626 * Regularizer for L2 regularization.
53627 *
53628 * Adds a term to the loss to penalize large weights:
53629 * loss += sum(l2 * x^2)
53630 * @param args l2 config.
53631 *
53632 * @doc {heading: 'Regularizers', namespace: 'regularizers'}
53633 */
53634 function l2$1(config) {
53635 return l2(config);
53636 }
53637
    // Frozen namespace object backing `tf.regularizers`.
    var exports_regularizers = /*#__PURE__*/Object.freeze({
        __proto__: null,
        l1l2: l1l2,
        l1: l1$1,
        l2: l2$1
    });
53644
53645 /**
53646 * @license
53647 * Copyright 2018 Google LLC
53648 *
53649 * Use of this source code is governed by an MIT-style
53650 * license that can be found in the LICENSE file or at
53651 * https://opensource.org/licenses/MIT.
53652 * =============================================================================
53653 */
53654 class Callback extends BaseCallback {
53655 constructor() {
53656 super(...arguments);
53657 /** Instance of `keras.models.Model`. Reference of the model being trained. */
53658 this.model = null;
53659 }
53660 setModel(model) {
53661 if (!(model instanceof LayersModel)) {
53662 throw new Error('model must be a LayersModel, not some other Container');
53663 }
53664 this.model = model;
53665 }
53666 }
53667 function less$1(currVal, prevVal) {
53668 return currVal < prevVal;
53669 }
53670 function greater$1(currVal, prevVal) {
53671 return currVal > prevVal;
53672 }
    /**
     * A Callback that stops training when a monitored quantity has stopped
     * improving.
     *
     * Tracks the best value of `this.monitor` seen so far in `this.best` and
     * counts non-improving epochs in `this.wait`; once `wait` reaches
     * `patience`, sets `model.stopTraining = true`.
     */
    class EarlyStopping extends Callback {
        constructor(args) {
            super();
            if (args == null) {
                args = {};
            }
            // restoreBestWeights is accepted by the Keras API but not yet
            // supported here; fail loudly rather than silently ignoring it.
            if (args.restoreBestWeights) {
                throw new NotImplementedError('restoreBestWeights = True is not implemented in EarlyStopping yet.');
            }
            this.monitor = args.monitor || 'val_loss';
            // minDelta is stored as a magnitude here; its sign is adjusted
            // below once the improvement direction is known.
            this.minDelta = Math.abs(args.minDelta || 0);
            this.patience = args.patience || 0;
            this.verbose = args.verbose || 0;
            this.mode = args.mode || 'auto';
            this.baseline = args.baseline;
            if (['auto', 'min', 'max'].indexOf(this.mode) === -1) {
                console.warn(`EarlyStopping mode '${this.mode}' is invalid. ` +
                    `Falling back to mode 'auto'.`);
                this.mode = 'auto';
            }
            if (this.mode === 'min') {
                this.monitorFunc = less$1;
            }
            else if (this.mode === 'max') {
                this.monitorFunc = greater$1;
            }
            else {
                // For mode === 'auto'.
                // Heuristic: metric names containing 'acc' are maximized;
                // everything else (e.g. losses) is minimized.
                if (this.monitor.indexOf('acc') !== -1) {
                    this.monitorFunc = greater$1;
                }
                else {
                    this.monitorFunc = less$1;
                }
            }
            // In 'min' direction, negate minDelta so that
            // `current - minDelta < best` requires at least minDelta of
            // improvement (mirrors the 'max' case).
            if (this.monitorFunc === less$1) {
                this.minDelta *= -1;
            }
        }
        // Resets the improvement tracker at the start of each training run.
        async onTrainBegin(logs) {
            this.wait = 0;
            this.stoppedEpoch = 0;
            if (this.baseline != null) {
                this.best = this.baseline;
            }
            else {
                // No baseline: start from the worst possible value for the
                // chosen direction so the first epoch always counts as best.
                this.best = this.monitorFunc === less$1 ? Infinity : -Infinity;
            }
        }
        // Checks the monitored metric after each epoch and flips
        // `model.stopTraining` once patience is exhausted.
        async onEpochEnd(epoch, logs) {
            // Scalars in logs may still be pending tensors; resolve first.
            await resolveScalarsInLogs(logs);
            const current = this.getMonitorValue(logs);
            if (current == null) {
                // Metric absent this epoch (already warned); skip silently.
                return;
            }
            if (this.monitorFunc(current - this.minDelta, this.best)) {
                this.best = current;
                this.wait = 0;
                // TODO(cais): Logic for restoreBestWeights.
            }
            else {
                this.wait++;
                if (this.wait >= this.patience) {
                    this.stoppedEpoch = epoch;
                    this.model.stopTraining = true;
                }
                // TODO(cais): Logic for restoreBestWeights.
            }
        }
        // Reports the stopping epoch when verbose output is enabled.
        async onTrainEnd(logs) {
            if (this.stoppedEpoch > 0 && this.verbose) {
                console.log(`Epoch ${this.stoppedEpoch}: early stopping.`);
            }
        }
        // Looks up the monitored metric in `logs`, warning (but not
        // throwing) when it is unavailable.
        getMonitorValue(logs) {
            if (logs == null) {
                logs = {};
            }
            const monitorValue = logs[this.monitor];
            if (monitorValue == null) {
                console.warn(`Metric for EarlyStopping ${this.monitor} is not available. ` +
                    `Available metrics are: ${Object.keys(logs)}`);
            }
            return monitorValue;
        }
    }
53763 /**
53764 * Factory function for a Callback that stops training when a monitored
53765 * quantity has stopped improving.
53766 *
53767 * Early stopping is a type of regularization, and protects model against
53768 * overfitting.
53769 *
53770 * The following example based on fake data illustrates how this callback
53771 * can be used during `tf.LayersModel.fit()`:
53772 *
53773 * ```js
53774 * const model = tf.sequential();
53775 * model.add(tf.layers.dense({
53776 * units: 3,
53777 * activation: 'softmax',
53778 * kernelInitializer: 'ones',
53779 * inputShape: [2]
53780 * }));
53781 * const xs = tf.tensor2d([1, 2, 3, 4], [2, 2]);
53782 * const ys = tf.tensor2d([[1, 0, 0], [0, 1, 0]], [2, 3]);
53783 * const xsVal = tf.tensor2d([4, 3, 2, 1], [2, 2]);
53784 * const ysVal = tf.tensor2d([[0, 0, 1], [0, 1, 0]], [2, 3]);
53785 * model.compile(
53786 * {loss: 'categoricalCrossentropy', optimizer: 'sgd', metrics: ['acc']});
53787 *
53788 * // Without the EarlyStopping callback, the val_acc value would be:
53789 * // 0.5, 0.5, 0.5, 0.5, ...
53790 * // With val_acc being monitored, training should stop after the 2nd epoch.
53791 * const history = await model.fit(xs, ys, {
53792 * epochs: 10,
53793 * validationData: [xsVal, ysVal],
53794 * callbacks: tf.callbacks.earlyStopping({monitor: 'val_acc'})
53795 * });
53796 *
53797 * // Expect to see a length-2 array.
53798 * console.log(history.history.val_acc);
53799 * ```
53800 *
53801 * @doc {
53802 * heading: 'Callbacks',
53803 * namespace: 'callbacks'
53804 * }
53805 */
53806 function earlyStopping(args) {
53807 return new EarlyStopping(args);
53808 }
53809 const callbacks = { earlyStopping };
53810
53811 /**
53812 * @license
53813 * Copyright 2018 Google LLC
53814 *
53815 * Use of this source code is governed by an MIT-style
53816 * license that can be found in the LICENSE file or at
53817 * https://opensource.org/licenses/MIT.
53818 * =============================================================================
53819 */
53820
53821 /**
53822 * @license
53823 * Copyright 2021 Google LLC. All Rights Reserved.
53824 * Licensed under the Apache License, Version 2.0 (the "License");
53825 * you may not use this file except in compliance with the License.
53826 * You may obtain a copy of the License at
53827 *
53828 * http://www.apache.org/licenses/LICENSE-2.0
53829 *
53830 * Unless required by applicable law or agreed to in writing, software
53831 * distributed under the License is distributed on an "AS IS" BASIS,
53832 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
53833 * See the License for the specific language governing permissions and
53834 * limitations under the License.
53835 * =============================================================================
53836 */
    // Module-level side effect: registers the KEEP_INTERMEDIATE_TENSORS
    // debug flag (default false) with the global environment. The evaluation
    // callback only emits a one-time warning when the flag is turned on.
    const ENV$2 = env();
    /** Whether to keep intermediate tensors. */
    ENV$2.registerFlag('KEEP_INTERMEDIATE_TENSORS', () => false, debugValue => {
        if (debugValue) {
            console.warn('Keep intermediate tensors is ON. This will print the values of all ' +
                'intermediate tensors during model inference. Not all models ' +
                'support this mode. For details, check e2e/benchmarks/ ' +
                'model_config.js. This significantly impacts performance.');
        }
    });
53847
53848 /**
53849 * @license
53850 * Copyright 2019 Google LLC. All Rights Reserved.
53851 * Licensed under the Apache License, Version 2.0 (the "License");
53852 * you may not use this file except in compliance with the License.
53853 * You may obtain a copy of the License at
53854 *
53855 * http://www.apache.org/licenses/LICENSE-2.0
53856 *
53857 * Unless required by applicable law or agreed to in writing, software
53858 * distributed under the License is distributed on an "AS IS" BASIS,
53859 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
53860 * See the License for the specific language governing permissions and
53861 * limitations under the License.
53862 *
53863 * =============================================================================
53864 */
/**
 * DataType enum mirroring TensorFlow's `types.proto` values.
 * Built as a TypeScript-style enum object: `DataType[name] === value`
 * and `DataType[value] === name` (reverse lookup).
 */
var DataType;
(function (DataType) {
    // [name, value] member table.
    // 0 is DT_INVALID: not a legal DataType value, used to indicate a
    // DataType field that has not been set.
    // 1-23 are dtypes every computation device is expected to support.
    // 101-123 are the "_REF" variants (value + 100), for parameters only;
    // every plain dtype above has a corresponding _REF entry.
    const members = [
        ['DT_INVALID', 0],
        ['DT_FLOAT', 1], ['DT_DOUBLE', 2], ['DT_INT32', 3], ['DT_UINT8', 4],
        ['DT_INT16', 5], ['DT_INT8', 6], ['DT_STRING', 7], ['DT_COMPLEX64', 8],
        ['DT_INT64', 9], ['DT_BOOL', 10], ['DT_QINT8', 11], ['DT_QUINT8', 12],
        ['DT_QINT32', 13], ['DT_BFLOAT16', 14], ['DT_QINT16', 15],
        ['DT_QUINT16', 16], ['DT_UINT16', 17], ['DT_COMPLEX128', 18],
        ['DT_HALF', 19], ['DT_RESOURCE', 20], ['DT_VARIANT', 21],
        ['DT_UINT32', 22], ['DT_UINT64', 23],
        ['DT_FLOAT_REF', 101], ['DT_DOUBLE_REF', 102], ['DT_INT32_REF', 103],
        ['DT_UINT8_REF', 104], ['DT_INT16_REF', 105], ['DT_INT8_REF', 106],
        ['DT_STRING_REF', 107], ['DT_COMPLEX64_REF', 108],
        ['DT_INT64_REF', 109], ['DT_BOOL_REF', 110], ['DT_QINT8_REF', 111],
        ['DT_QUINT8_REF', 112], ['DT_QINT32_REF', 113],
        ['DT_BFLOAT16_REF', 114], ['DT_QINT16_REF', 115],
        ['DT_QUINT16_REF', 116], ['DT_UINT16_REF', 117],
        ['DT_COMPLEX128_REF', 118], ['DT_HALF_REF', 119],
        ['DT_RESOURCE_REF', 120], ['DT_VARIANT_REF', 121],
        ['DT_UINT32_REF', 122], ['DT_UINT64_REF', 123]
    ];
    for (const [name, value] of members) {
        // Install both the forward (name -> value) and the reverse
        // (value -> name) mapping, exactly like a compiled TS enum.
        DataType[DataType[name] = value] = name;
    }
})(DataType || (DataType = {}));
/** SaverDef namespace, mirroring the proto of the same name. */
var SaverDef;
(function (SaverDef) {
    /**
     * CheckpointFormatVersion enum: LEGACY (0), V1 (1) and V2 (2), with
     * reverse value -> name lookup like a compiled TS enum.
     */
    let CheckpointFormatVersion;
    (function (cfv) {
        for (const [name, value] of [['LEGACY', 0], ['V1', 1], ['V2', 2]]) {
            cfv[cfv[name] = value] = name;
        }
    })(CheckpointFormatVersion = SaverDef.CheckpointFormatVersion ||
        (SaverDef.CheckpointFormatVersion = {}));
})(SaverDef || (SaverDef = {}));
53932
53933 /**
53934 * @license
53935 * Copyright 2019 Google LLC. All Rights Reserved.
53936 * Licensed under the Apache License, Version 2.0 (the "License");
53937 * you may not use this file except in compliance with the License.
53938 * You may obtain a copy of the License at
53939 *
53940 * http://www.apache.org/licenses/LICENSE-2.0
53941 *
53942 * Unless required by applicable law or agreed to in writing, software
53943 * distributed under the License is distributed on an "AS IS" BASIS,
53944 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
53945 * See the License for the specific language governing permissions and
53946 * limitations under the License.
53947 * =============================================================================
53948 */
// Registry of user-supplied op mappers, keyed by TensorFlow op name.
const CUSTOM_OPS = {};
/**
 * Register an Op for graph model executor. This allows you to register a
 * TensorFlow custom op or override an existing op.
 *
 * Here is an example of registering a new MatMul Op.
 * ```js
 * const customMatmul = (node) =>
 *    tf.matMul(
 *        node.inputs[0], node.inputs[1],
 *        node.attrs['transpose_a'], node.attrs['transpose_b']);
 *
 * tf.registerOp('MatMul', customMatmul);
 * ```
 * The inputs and attrs of the node object is based on the TensorFlow op
 * registry.
 *
 * @param name The Tensorflow Op name.
 * @param opFunc An op function which is called with the current graph node
 * during execution and needs to return a tensor or a list of tensors. The node
 * has the following attributes:
 *    - attr: A map from attribute name to its value
 *    - inputs: A list of input tensors
 *
 * @doc {heading: 'Models', subheading: 'Op Registry'}
 */
function registerOp(name, opFunc) {
    // Custom ops read inputs/attrs straight off the node, so the mapper's
    // declarative input/attr lists stay empty.
    CUSTOM_OPS[name] = {
        tfOpName: name,
        category: 'custom',
        inputs: [],
        attrs: [],
        customExecutor: opFunc
    };
}
/**
 * Retrieve the OpMapper object for the registered op.
 *
 * @param name The Tensorflow Op name.
 *
 * @doc {heading: 'Models', subheading: 'Op Registry'}
 */
function getRegisteredOp(name) {
    return CUSTOM_OPS[name];
}
/**
 * Deregister the Op for graph model executor.
 *
 * @param name The Tensorflow Op name.
 *
 * @doc {heading: 'Models', subheading: 'Op Registry'}
 */
function deregisterOp(name) {
    delete CUSTOM_OPS[name];
}
54005
54006 /**
54007 * @license
54008 * Copyright 2018 Google LLC. All Rights Reserved.
54009 * Licensed under the Apache License, Version 2.0 (the "License");
54010 * you may not use this file except in compliance with the License.
54011 * You may obtain a copy of the License at
54012 *
54013 * http://www.apache.org/licenses/LICENSE-2.0
54014 *
54015 * Unless required by applicable law or agreed to in writing, software
54016 * distributed under the License is distributed on an "AS IS" BASIS,
54017 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54018 * See the License for the specific language governing permissions and
54019 * limitations under the License.
54020 * =============================================================================
54021 */
/**
 * Resolve the value of a node parameter, either from one of the node's input
 * tensors (inputParams) or from a static attribute (attrParams).
 * @param paramName Name of the parameter to look up.
 * @param node Graph node carrying inputParams/attrParams and inputNames.
 * @param tensorMap Tensors map keyed by the node.
 * @param context Contains tensors and information for running the current
 *     node.
 * @param resourceManager Optional. Contains global resources of the model.
 * @returns A tensor, a list of tensors, a number, a nested number array, or
 *     the attribute value, depending on the declared param type.
 */
function getParamValue(paramName, node, tensorMap, context, resourceManager) {
    const inputParam = node.inputParams[paramName];
    if (inputParam && inputParam.inputIndexStart !== undefined) {
        const start = inputParam.inputIndexStart;
        // inputIndexEnd semantics: 0 means "through the end of inputNames"
        // (slice with undefined end), undefined means exactly one input,
        // anything else is used verbatim.
        let end;
        if (inputParam.inputIndexEnd === 0) {
            end = undefined;
        }
        else if (inputParam.inputIndexEnd === undefined) {
            end = start + 1;
        }
        else {
            end = inputParam.inputIndexEnd;
        }
        if (inputParam.type === 'tensor') {
            return getTensor(node.inputNames[start], tensorMap, context, resourceManager);
        }
        if (inputParam.type === 'tensors') {
            const names = node.inputNames.slice(start, end);
            return names.map((name) => getTensor(name, tensorMap, context, resourceManager));
        }
        // Scalar or array param: read the backing tensor's data synchronously.
        const tensor = getTensor(node.inputNames.slice(start)[0], tensorMap, context, resourceManager);
        const data = tensor.dataSync();
        if (inputParam.type === 'number') {
            return data[0];
        }
        return toNestedArray(tensor.shape, data);
    }
    const attrParam = node.attrParams[paramName];
    return attrParam && attrParam.value;
}
/**
 * Retrieve the tensor from tensorsMap based on input name.
 * @param name Node input name
 * @param tensorsMap Tensors map keyed by the node
 * @param context contains tensors and information for running the current
 *     node.
 * @param resourceManager Optional. Contains global resources of the model;
 *     hash-table handles are resolved through it before the tensor map.
 * @returns The tensor at the requested output index, or undefined when no
 *     active context holds tensors for this node.
 */
function getTensor(name, tensorsMap, context, resourceManager) {
    const [nodeName, index] = parseNodeName(name);
    // Hash-table handles live in the resource manager, not the tensor map.
    if (resourceManager != null) {
        const handle = resourceManager.getHashTableHandleByName(nodeName);
        if (handle != null) {
            return handle;
        }
    }
    // Find an active execution context whose scoped key is present in the map.
    const contextId = context.currentContextIds.find((id) => {
        return !!tensorsMap[getNodeNameWithContextId(nodeName, id)];
    });
    if (contextId === undefined) {
        return undefined;
    }
    return tensorsMap[getNodeNameWithContextId(nodeName, contextId)][index];
}
/**
 * Retrieve the tensors based on input name for current context.
 * (NOTE: the name carries a historical typo — "Contenxt" — kept for API
 * compatibility with existing callers.)
 * @param name Node input name
 * @param tensorsMap Tensors map keyed by the node
 */
function getTensorsForCurrentContenxt(name, tensorsMap, context) {
    const key = getNodeNameWithContextId(name, context.currentContextId);
    return tensorsMap[key];
}
/**
 * Returns the node name, outputName and index from the Node input name.
 * @param inputName The input name of the node, in format of
 * node_name:output_index, i.e. MatMul:0, if the output_index is not set, it is
 * default to 0.
 * If the input name contains output name i.e. StringSplit:indices:0, it will
 * return ['StringSplit', 0, 'indices'].
 * The node name is additionally scoped by the current context id, if any.
 */
function getNodeNameAndIndex(inputName, context) {
    const [nodeName, index, outputName] = parseNodeName(inputName);
    const scopedName = getNodeNameWithContextId(nodeName, context && context.currentContextId);
    return [scopedName, index, outputName];
}
/**
 * Scope a node name by its execution-context id: returns `name-contextId`
 * when a truthy contextId is given, otherwise the bare name.
 */
function getNodeNameWithContextId(name, contextId) {
    if (contextId) {
        return `${name}-${contextId}`;
    }
    return name;
}
/**
 * Split a node input name of the form `node`, `node:index`, or
 * `node:outputName:index` into `[nodeName, index, outputName]`.
 * A bare name defaults the output index to 0.
 */
function parseNodeName(name) {
    const segments = name.split(':');
    if (segments.length === 1) {
        // No output index present; default to output 0.
        return [name, 0, undefined];
    }
    const [nodeName] = segments;
    // The index is always the final segment; a middle segment (3-part form)
    // is the output name.
    const index = Number(segments[segments.length - 1]);
    const outputName = segments.length === 3 ? segments[1] : undefined;
    return [nodeName, index, outputName];
}
/** Partition `arr` into consecutive chunks of at most `size` elements. */
function split$1(arr, size) {
    const chunks = [];
    let offset = 0;
    while (offset < arr.length) {
        chunks.push(arr.slice(offset, offset + size));
        offset += size;
    }
    return chunks;
}
/**
 * Read the 'pad' parameter of a node. When the pad mode is 'explicit', the
 * flat 8-element 'explicitPaddings' list is converted into the 4x2
 * [[before, after], ...] form; otherwise the pad value is returned as-is.
 */
function getPadding(node, tensorMap, context) {
    const pad = getParamValue('pad', node, tensorMap, context);
    if (pad !== 'explicit') {
        return pad;
    }
    // Explicit padding arrives as a flat 1d array; reshape it to 4 pairs.
    const flat = getParamValue('explicitPaddings', node, tensorMap, context);
    const explicitPadding = [[0, 0], [0, 0], [0, 0], [0, 0]];
    for (let dim = 0; dim < 4; dim++) {
        explicitPadding[dim][0] = flat[dim * 2];
        explicitPadding[dim][1] = flat[dim * 2 + 1];
    }
    return explicitPadding;
}
/**
 * Reuse the tensor if it is marked as keep, otherwise clone the tensor to
 * avoid disposal. This is important for TensorArray and TensorList ops, since
 * internally they use a tensor as the id for TensorArray and TensorList, and
 * to simplify lookup, they also use Tensor.id as the key to the internal map.
 * These id tensors have been marked as kept in the backend, we need avoid
 * cloning them in order to create new Tensor.id.
 * @param tensor
 */
function cloneTensor(tensor) {
    if (tensor.kept) {
        return tensor;
    }
    return clone(tensor);
}
54137
54138 /**
54139 * @license
54140 * Copyright 2022 Google LLC. All Rights Reserved.
54141 * Licensed under the Apache License, Version 2.0 (the "License");
54142 * you may not use this file except in compliance with the License.
54143 * You may obtain a copy of the License at
54144 *
54145 * http://www.apache.org/licenses/LICENSE-2.0
54146 *
54147 * Unless required by applicable law or agreed to in writing, software
54148 * distributed under the License is distributed on an "AS IS" BASIS,
54149 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54150 * See the License for the specific language governing permissions and
54151 * limitations under the License.
54152 * =============================================================================
54153 */
/**
 * Op mapper definitions for TensorFlow 'arithmetic' ops. Most entries are
 * elementwise binary ops with identical shape (tensor inputs `a`/`b` plus an
 * ignored dtype attr `T`), so they are produced by a small factory; AddN and
 * BiasAdd differ and are written out explicitly. The resulting data is
 * element-for-element identical to the hand-written table it replaces.
 */
const json = (() => {
    // The op's dtype attribute 'T' (declared but not supported).
    const dtypeAttr = () => ({
        'tfName': 'T',
        'name': 'dtype',
        'type': 'dtype',
        'notSupported': true
    });
    // Standard binary op mapper: tensor inputs `a` (index 0) and `b`
    // (index 1); extraAttrs are appended after the dtype attr.
    const binaryOp = (tfOpName, extraAttrs = []) => ({
        'tfOpName': tfOpName,
        'category': 'arithmetic',
        'inputs': [
            { 'start': 0, 'name': 'a', 'type': 'tensor' },
            { 'start': 1, 'name': 'b', 'type': 'tensor' }
        ],
        'attrs': [dtypeAttr(), ...extraAttrs]
    });
    return [
        binaryOp('Add'),
        binaryOp('AddV2'),
        {
            'tfOpName': 'AddN',
            'category': 'arithmetic',
            // `end: 0` makes getParamValue slice through the end of
            // inputNames, i.e. a variable-length tensor list.
            'inputs': [
                { 'start': 0, 'end': 0, 'name': 'tensors', 'type': 'tensors' }
            ]
        },
        binaryOp('BiasAdd', [{
                'tfName': 'data_format',
                'name': 'dataFormat',
                'type': 'string',
                'notSupported': true
            }]),
        binaryOp('Sub'),
        binaryOp('RealDiv'),
        binaryOp('Div'),
        binaryOp('DivNoNan'),
        binaryOp('FloorDiv'),
        binaryOp('Mul'),
        binaryOp('Maximum'),
        binaryOp('Minimum'),
        binaryOp('Pow'),
        binaryOp('SquaredDifference'),
        binaryOp('Mod'),
        binaryOp('FloorMod')
    ];
})();

var arithmetic = /*#__PURE__*/Object.freeze({
    __proto__: null,
    json: json
});
54539
54540 /**
54541 * @license
54542 * Copyright 2022 Google LLC. All Rights Reserved.
54543 * Licensed under the Apache License, Version 2.0 (the "License");
54544 * you may not use this file except in compliance with the License.
54545 * You may obtain a copy of the License at
54546 *
54547 * http://www.apache.org/licenses/LICENSE-2.0
54548 *
54549 * Unless required by applicable law or agreed to in writing, software
54550 * distributed under the License is distributed on an "AS IS" BASIS,
54551 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54552 * See the License for the specific language governing permissions and
54553 * limitations under the License.
54554 * =============================================================================
54555 */
/**
 * Op mapper definitions for TensorFlow 'basic_math' ops. The bulk are
 * elementwise unary ops over a single tensor input `x` with one ignored
 * dtype attr `T`, generated by a factory; ops with extra inputs or attrs
 * (Atan2, ClipByValue, Complex, Imag, Real, Prelu, Prod, LeakyRelu) are
 * written out or parameterized explicitly. The resulting data is
 * element-for-element identical to the hand-written table it replaces.
 */
const json$1 = (() => {
    // A dtype attribute (declared but not supported); defaults to 'T'.
    const dtypeAttr = (tfName = 'T', name = 'dtype') => ({
        'tfName': tfName,
        'name': name,
        'type': 'dtype',
        'notSupported': true
    });
    // Standard unary op mapper: tensor input `x`; extraAttrs are appended
    // after the 'T' dtype attr.
    const unaryOp = (tfOpName, extraAttrs = []) => ({
        'tfOpName': tfOpName,
        'category': 'basic_math',
        'inputs': [{ 'start': 0, 'name': 'x', 'type': 'tensor' }],
        'attrs': [dtypeAttr(), ...extraAttrs]
    });
    return [
        unaryOp('Abs'),
        unaryOp('Acos'),
        unaryOp('Asin'),
        unaryOp('Atan'),
        {
            'tfOpName': 'Atan2',
            'category': 'basic_math',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' },
                { 'start': 1, 'name': 'y', 'type': 'tensor' }
            ],
            'attrs': [dtypeAttr()]
        },
        unaryOp('Ceil'),
        {
            'tfOpName': 'ClipByValue',
            'category': 'basic_math',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' },
                { 'start': 1, 'name': 'clipValueMin', 'type': 'number' },
                { 'start': 2, 'name': 'clipValueMax', 'type': 'number' }
            ],
            'attrs': [dtypeAttr()]
        },
        {
            'tfOpName': 'Complex',
            'category': 'basic_math',
            'inputs': [
                { 'start': 0, 'name': 'real', 'type': 'tensor' },
                { 'start': 1, 'name': 'imag', 'type': 'tensor' }
            ],
            'attrs': [dtypeAttr()]
        },
        unaryOp('ComplexAbs'),
        unaryOp('Cos'),
        unaryOp('Cosh'),
        unaryOp('Elu'),
        unaryOp('Exp'),
        unaryOp('Floor'),
        unaryOp('Log'),
        // Imag and Real additionally declare the output dtype attr 'Tout'.
        unaryOp('Imag', [dtypeAttr('Tout', 'outputType')]),
        unaryOp('Neg'),
        unaryOp('Real', [dtypeAttr('Tout', 'outputType')]),
        {
            'tfOpName': 'Prelu',
            'category': 'basic_math',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' },
                { 'start': 1, 'name': 'alpha', 'type': 'tensor' }
            ],
            'attrs': [dtypeAttr()]
        },
        unaryOp('Relu'),
        unaryOp('Relu6'),
        unaryOp('Selu'),
        unaryOp('Sigmoid'),
        unaryOp('Sin'),
        unaryOp('Sinh'),
        unaryOp('Sqrt'),
        unaryOp('Rsqrt'),
        unaryOp('Square'),
        unaryOp('Tan'),
        unaryOp('Tanh'),
        unaryOp('Sign'),
        unaryOp('Round'),
        unaryOp('Expm1'),
        unaryOp('Log1p'),
        unaryOp('Reciprocal'),
        unaryOp('Softplus'),
        unaryOp('Asinh'),
        unaryOp('Acosh'),
        unaryOp('Atanh'),
        unaryOp('Erf'),
        {
            'tfOpName': 'Prod',
            'category': 'basic_math',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' },
                { 'start': 1, 'name': 'axes', 'type': 'number[]' }
            ],
            'attrs': [
                {
                    'tfName': 'keep_dims',
                    'name': 'keepDims',
                    'type': 'bool',
                    'notSupported': true
                },
                dtypeAttr()
            ]
        },
        {
            'tfOpName': 'LeakyRelu',
            'category': 'basic_math',
            'inputs': [{ 'start': 0, 'name': 'x', 'type': 'tensor' }],
            'attrs': [
                {
                    'tfName': 'alpha',
                    'name': 'alpha',
                    'type': 'number',
                    'defaultValue': 0.2
                },
                dtypeAttr()
            ]
        },
        unaryOp('IsNan')
    ];
})();

var basicMath = /*#__PURE__*/Object.freeze({
    __proto__: null,
    json: json$1
});
55434
55435 /**
55436 * @license
55437 * Copyright 2022 Google LLC. All Rights Reserved.
55438 * Licensed under the Apache License, Version 2.0 (the "License");
55439 * you may not use this file except in compliance with the License.
55440 * You may obtain a copy of the License at
55441 *
55442 * http://www.apache.org/licenses/LICENSE-2.0
55443 *
55444 * Unless required by applicable law or agreed to in writing, software
55445 * distributed under the License is distributed on an "AS IS" BASIS,
55446 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55447 * See the License for the specific language governing permissions and
55448 * limitations under the License.
55449 * =============================================================================
55450 */
55451 const json$2 = [
55452 {
55453 'tfOpName': 'EmptyTensorList',
55454 'category': 'control',
55455 'inputs': [
55456 {
55457 'start': 0,
55458 'name': 'elementShape',
55459 'type': 'shape'
55460 },
55461 {
55462 'start': 1,
55463 'name': 'maxNumElements',
55464 'type': 'number'
55465 }
55466 ],
55467 'attrs': [
55468 {
55469 'tfName': 'element_dtype',
55470 'name': 'elementDType',
55471 'type': 'dtype'
55472 }
55473 ]
55474 },
55475 {
55476 'tfOpName': 'LoopCond',
55477 'category': 'control',
55478 'inputs': [
55479 {
55480 'start': 0,
55481 'name': 'pred',
55482 'type': 'tensor'
55483 }
55484 ]
55485 },
55486 {
55487 'tfOpName': 'Switch',
55488 'category': 'control',
55489 'inputs': [
55490 {
55491 'start': 0,
55492 'name': 'data',
55493 'type': 'tensor'
55494 },
55495 {
55496 'start': 1,
55497 'name': 'pred',
55498 'type': 'tensor'
55499 }
55500 ]
55501 },
55502 {
55503 'tfOpName': 'Merge',
55504 'category': 'control',
55505 'inputs': [
55506 {
55507 'start': 0,
55508 'end': 0,
55509 'name': 'tensors',
55510 'type': 'tensors'
55511 }
55512 ]
55513 },
55514 {
55515 'tfOpName': 'Enter',
55516 'category': 'control',
55517 'inputs': [
55518 {
55519 'start': 0,
55520 'name': 'tensor',
55521 'type': 'tensor'
55522 }
55523 ],
55524 'attrs': [
55525 {
55526 'tfName': 'T',
55527 'name': 'dtype',
55528 'type': 'dtype',
55529 'notSupported': true
55530 },
55531 {
55532 'tfName': 'frame_name',
55533 'name': 'frameName',
55534 'type': 'string'
55535 },
55536 {
55537 'tfName': 'is_constant',
55538 'name': 'isConstant',
55539 'type': 'bool'
55540 }
55541 ]
55542 },
55543 {
55544 'tfOpName': 'Exit',
55545 'category': 'control',
55546 'inputs': [
55547 {
55548 'start': 0,
55549 'name': 'tensor',
55550 'type': 'tensor'
55551 }
55552 ],
55553 'attrs': [
55554 {
55555 'tfName': 'T',
55556 'name': 'dtype',
55557 'type': 'dtype',
55558 'notSupported': true
55559 }
55560 ]
55561 },
55562 {
55563 'tfOpName': 'NextIteration',
55564 'category': 'control',
55565 'inputs': [
55566 {
55567 'start': 0,
55568 'name': 'tensor',
55569 'type': 'tensor'
55570 }
55571 ],
55572 'attrs': [
55573 {
55574 'tfName': 'T',
55575 'name': 'dtype',
55576 'type': 'dtype',
55577 'notSupported': true
55578 }
55579 ]
55580 },
55581 {
55582 'tfOpName': 'TensorArrayV3',
55583 'category': 'control',
55584 'inputs': [
55585 {
55586 'start': 0,
55587 'name': 'size',
55588 'type': 'number'
55589 }
55590 ],
55591 'attrs': [
55592 {
55593 'tfName': 'dtype',
55594 'name': 'dtype',
55595 'type': 'dtype'
55596 },
55597 {
55598 'tfName': 'element_shape',
55599 'name': 'elementShape',
55600 'type': 'shape'
55601 },
55602 {
55603 'tfName': 'dynamic_size',
55604 'name': 'dynamicSize',
55605 'type': 'bool'
55606 },
55607 {
55608 'tfName': 'clear_after_read',
55609 'name': 'clearAfterRead',
55610 'type': 'bool'
55611 },
55612 {
55613 'tfName': 'identical_element_shapes',
55614 'name': 'identicalElementShapes',
55615 'type': 'bool'
55616 },
55617 {
55618 'tfName': 'tensor_array_name',
55619 'name': 'name',
55620 'type': 'string'
55621 }
55622 ]
55623 },
55624 {
55625 'tfOpName': 'TensorArrayWriteV3',
55626 'category': 'control',
55627 'inputs': [
55628 {
55629 'start': 0,
55630 'name': 'tensorArrayId',
55631 'type': 'tensor'
55632 },
55633 {
55634 'start': 1,
55635 'name': 'index',
55636 'type': 'number'
55637 },
55638 {
55639 'start': 2,
55640 'name': 'tensor',
55641 'type': 'tensor'
55642 },
55643 {
55644 'start': 3,
55645 'name': 'flowIn',
55646 'type': 'number'
55647 }
55648 ],
55649 'attrs': [
55650 {
55651 'tfName': 'T',
55652 'name': 'dtype',
55653 'type': 'dtype',
55654 'notSupported': true
55655 }
55656 ]
55657 },
55658 {
55659 'tfOpName': 'TensorArrayReadV3',
55660 'category': 'control',
55661 'inputs': [
55662 {
55663 'start': 0,
55664 'name': 'tensorArrayId',
55665 'type': 'tensor'
55666 },
55667 {
55668 'start': 1,
55669 'name': 'index',
55670 'type': 'number'
55671 },
55672 {
55673 'start': 2,
55674 'name': 'flowIn',
55675 'type': 'number'
55676 }
55677 ],
55678 'attrs': [
55679 {
55680 'tfName': 'dtype',
55681 'name': 'dtype',
55682 'type': 'dtype',
55683 'notSupported': true
55684 }
55685 ]
55686 },
55687 {
55688 'tfOpName': 'TensorArrayGatherV3',
55689 'category': 'control',
55690 'inputs': [
55691 {
55692 'start': 0,
55693 'name': 'tensorArrayId',
55694 'type': 'tensor'
55695 },
55696 {
55697 'start': 1,
55698 'name': 'indices',
55699 'type': 'number[]'
55700 },
55701 {
55702 'start': 2,
55703 'name': 'flowIn',
55704 'type': 'number'
55705 }
55706 ],
55707 'attrs': [
55708 {
55709 'tfName': 'dtype',
55710 'name': 'dtype',
55711 'type': 'dtype'
55712 },
55713 {
55714 'tfName': 'element_shape',
55715 'name': 'elementShape',
55716 'type': 'shape'
55717 }
55718 ]
55719 },
55720 {
55721 'tfOpName': 'TensorArrayScatterV3',
55722 'category': 'control',
55723 'inputs': [
55724 {
55725 'start': 0,
55726 'name': 'tensorArrayId',
55727 'type': 'tensor'
55728 },
55729 {
55730 'start': 1,
55731 'name': 'indices',
55732 'type': 'number[]'
55733 },
55734 {
55735 'start': 2,
55736 'name': 'tensor',
55737 'type': 'tensor'
55738 },
55739 {
55740 'start': 3,
55741 'name': 'flowIn',
55742 'type': 'number'
55743 }
55744 ],
55745 'attrs': [
55746 {
55747 'tfName': 'T',
55748 'name': 'dtype',
55749 'type': 'dtype'
55750 }
55751 ]
55752 },
55753 {
55754 'tfOpName': 'TensorArrayConcatV3',
55755 'category': 'control',
55756 'inputs': [
55757 {
55758 'start': 0,
55759 'name': 'tensorArrayId',
55760 'type': 'tensor'
55761 },
55762 {
55763 'start': 1,
55764 'name': 'flowIn',
55765 'type': 'number'
55766 }
55767 ],
55768 'attrs': [
55769 {
55770 'tfName': 'dtype',
55771 'name': 'dtype',
55772 'type': 'dtype'
55773 },
55774 {
55775 'tfName': 'element_shape_except0',
55776 'name': 'elementShapeExcept0',
55777 'type': 'shape',
55778 'notSupported': true
55779 }
55780 ]
55781 },
55782 {
55783 'tfOpName': 'TensorArraySplitV3',
55784 'category': 'control',
55785 'inputs': [
55786 {
55787 'start': 0,
55788 'name': 'tensorArrayId',
55789 'type': 'tensor'
55790 },
55791 {
55792 'start': 1,
55793 'name': 'tensor',
55794 'type': 'tensor'
55795 },
55796 {
55797 'start': 2,
55798 'name': 'lengths',
55799 'type': 'number[]'
55800 },
55801 {
55802 'start': 3,
55803 'name': 'flowIn',
55804 'type': 'number'
55805 }
55806 ],
55807 'attrs': [
55808 {
55809 'tfName': 'T',
55810 'name': 'dtype',
55811 'type': 'dtype'
55812 }
55813 ]
55814 },
55815 {
55816 'tfOpName': 'TensorArraySizeV3',
55817 'category': 'control',
55818 'inputs': [
55819 {
55820 'start': 0,
55821 'name': 'tensorArrayId',
55822 'type': 'tensor'
55823 },
55824 {
55825 'start': 1,
55826 'name': 'flowIn',
55827 'type': 'number'
55828 }
55829 ]
55830 },
55831 {
55832 'tfOpName': 'TensorArrayCloseV3',
55833 'category': 'control',
55834 'inputs': [
55835 {
55836 'start': 0,
55837 'name': 'tensorArrayId',
55838 'type': 'tensor'
55839 }
55840 ]
55841 },
55842 {
55843 'tfOpName': 'StatelessIf',
55844 'category': 'control',
55845 'inputs': [
55846 {
55847 'start': 0,
55848 'name': 'cond',
55849 'type': 'tensor'
55850 },
55851 {
55852 'start': 1,
55853 'end': 0,
55854 'name': 'args',
55855 'type': 'tensors'
55856 }
55857 ],
55858 'attrs': [
55859 {
55860 'tfName': 'then_branch',
55861 'name': 'thenBranch',
55862 'type': 'func'
55863 },
55864 {
55865 'tfName': 'else_branch',
55866 'name': 'elseBranch',
55867 'type': 'func'
55868 }
55869 ]
55870 },
55871 {
55872 'tfOpName': 'If',
55873 'category': 'control',
55874 'inputs': [
55875 {
55876 'start': 0,
55877 'name': 'cond',
55878 'type': 'tensor'
55879 },
55880 {
55881 'start': 1,
55882 'end': 0,
55883 'name': 'args',
55884 'type': 'tensors'
55885 }
55886 ],
55887 'attrs': [
55888 {
55889 'tfName': 'then_branch',
55890 'name': 'thenBranch',
55891 'type': 'func'
55892 },
55893 {
55894 'tfName': 'else_branch',
55895 'name': 'elseBranch',
55896 'type': 'func'
55897 }
55898 ]
55899 },
55900 {
55901 'tfOpName': 'StatelessWhile',
55902 'category': 'control',
55903 'inputs': [
55904 {
55905 'start': 0,
55906 'end': 0,
55907 'name': 'args',
55908 'type': 'tensors'
55909 }
55910 ],
55911 'attrs': [
55912 {
55913 'tfName': 'cond',
55914 'name': 'cond',
55915 'type': 'func'
55916 },
55917 {
55918 'tfName': 'body',
55919 'name': 'body',
55920 'type': 'func'
55921 }
55922 ]
55923 },
55924 {
55925 'tfOpName': 'While',
55926 'category': 'control',
55927 'inputs': [
55928 {
55929 'start': 0,
55930 'end': 0,
55931 'name': 'args',
55932 'type': 'tensors'
55933 }
55934 ],
55935 'attrs': [
55936 {
55937 'tfName': 'cond',
55938 'name': 'cond',
55939 'type': 'func'
55940 },
55941 {
55942 'tfName': 'body',
55943 'name': 'body',
55944 'type': 'func'
55945 }
55946 ]
55947 },
55948 {
55949 'tfOpName': 'TensorListScatter',
55950 'category': 'control',
55951 'inputs': [
55952 {
55953 'start': 0,
55954 'name': 'tensor',
55955 'type': 'tensor'
55956 },
55957 {
55958 'start': 1,
55959 'name': 'indices',
55960 'type': 'number[]'
55961 },
55962 {
55963 'start': 2,
55964 'name': 'elementShape',
55965 'type': 'shape'
55966 }
55967 ],
55968 'attrs': [
55969 {
55970 'tfName': 'element_dtype',
55971 'name': 'elementDType',
55972 'type': 'dtype'
55973 }
55974 ]
55975 },
55976 {
55977 'tfOpName': 'TensorListScatterV2',
55978 'category': 'control',
55979 'inputs': [
55980 {
55981 'start': 0,
55982 'name': 'tensor',
55983 'type': 'tensor'
55984 },
55985 {
55986 'start': 1,
55987 'name': 'indices',
55988 'type': 'number[]'
55989 },
55990 {
55991 'start': 2,
55992 'name': 'elementShape',
55993 'type': 'shape'
55994 },
55995 {
55996 'start': 3,
55997 'name': 'numElements',
55998 'type': 'number'
55999 }
56000 ],
56001 'attrs': [
56002 {
56003 'tfName': 'element_dtype',
56004 'name': 'elementDType',
56005 'type': 'dtype'
56006 }
56007 ]
56008 },
56009 {
56010 'tfOpName': 'TensorListGather',
56011 'category': 'control',
56012 'inputs': [
56013 {
56014 'start': 0,
56015 'name': 'tensorListId',
56016 'type': 'tensor'
56017 },
56018 {
56019 'start': 1,
56020 'name': 'indices',
56021 'type': 'number[]'
56022 },
56023 {
56024 'start': 2,
56025 'name': 'elementShape',
56026 'type': 'shape'
56027 }
56028 ],
56029 'attrs': [
56030 {
56031 'tfName': 'element_dtype',
56032 'name': 'elementDType',
56033 'type': 'dtype'
56034 }
56035 ]
56036 },
56037 {
56038 'tfOpName': 'TensorListGetItem',
56039 'category': 'control',
56040 'inputs': [
56041 {
56042 'start': 0,
56043 'name': 'tensorListId',
56044 'type': 'tensor'
56045 },
56046 {
56047 'start': 1,
56048 'name': 'index',
56049 'type': 'number'
56050 },
56051 {
56052 'start': 2,
56053 'name': 'elementShape',
56054 'type': 'shape'
56055 }
56056 ],
56057 'attrs': [
56058 {
56059 'tfName': 'element_dtype',
56060 'name': 'elementDType',
56061 'type': 'dtype'
56062 }
56063 ]
56064 },
56065 {
56066 'tfOpName': 'TensorListSetItem',
56067 'category': 'control',
56068 'inputs': [
56069 {
56070 'start': 0,
56071 'name': 'tensorListId',
56072 'type': 'tensor'
56073 },
56074 {
56075 'start': 1,
56076 'name': 'index',
56077 'type': 'number'
56078 },
56079 {
56080 'start': 2,
56081 'name': 'tensor',
56082 'type': 'tensor'
56083 }
56084 ],
56085 'attrs': [
56086 {
56087 'tfName': 'element_dtype',
56088 'name': 'elementDType',
56089 'type': 'dtype'
56090 }
56091 ]
56092 },
56093 {
56094 'tfOpName': 'TensorListReserve',
56095 'category': 'control',
56096 'inputs': [
56097 {
56098 'start': 0,
56099 'name': 'elementShape',
56100 'type': 'shape'
56101 },
56102 {
56103 'start': 1,
56104 'name': 'numElements',
56105 'type': 'number'
56106 }
56107 ],
56108 'attrs': [
56109 {
56110 'tfName': 'element_dtype',
56111 'name': 'elementDType',
56112 'type': 'dtype'
56113 }
56114 ]
56115 },
56116 {
56117 'tfOpName': 'TensorListFromTensor',
56118 'category': 'control',
56119 'inputs': [
56120 {
56121 'start': 0,
56122 'name': 'tensor',
56123 'type': 'tensor'
56124 },
56125 {
56126 'start': 1,
56127 'name': 'elementShape',
56128 'type': 'shape'
56129 }
56130 ],
56131 'attrs': [
56132 {
56133 'tfName': 'element_dtype',
56134 'name': 'elementDType',
56135 'type': 'dtype'
56136 }
56137 ]
56138 },
56139 {
56140 'tfOpName': 'TensorListStack',
56141 'category': 'control',
56142 'inputs': [
56143 {
56144 'start': 0,
56145 'name': 'tensorListId',
56146 'type': 'tensor'
56147 },
56148 {
56149 'start': 1,
56150 'name': 'elementShape',
56151 'type': 'shape'
56152 }
56153 ],
56154 'attrs': [
56155 {
56156 'tfName': 'element_dtype',
56157 'name': 'elementDType',
56158 'type': 'dtype'
56159 },
56160 {
56161 'tfName': 'num_elements',
56162 'name': 'numElements',
56163 'type': 'dtype'
56164 }
56165 ]
56166 },
56167 {
56168 'tfOpName': 'TensorListSplit',
56169 'category': 'control',
56170 'inputs': [
56171 {
56172 'start': 0,
56173 'name': 'tensor',
56174 'type': 'tensor'
56175 },
56176 {
56177 'start': 1,
56178 'name': 'elementShape',
56179 'type': 'shape'
56180 },
56181 {
56182 'start': 2,
56183 'name': 'lengths',
56184 'type': 'number[]'
56185 }
56186 ],
56187 'attrs': [
56188 {
56189 'tfName': 'element_dtype',
56190 'name': 'elementDType',
56191 'type': 'dtype'
56192 }
56193 ]
56194 },
56195 {
56196 'tfOpName': 'TensorListConcat',
56197 'category': 'control',
56198 'inputs': [
56199 {
56200 'start': 0,
56201 'name': 'tensorListId',
56202 'type': 'tensor'
56203 }
56204 ],
56205 'attrs': [
56206 {
56207 'tfName': 'element_shape',
56208 'name': 'elementShape',
56209 'type': 'shape'
56210 },
56211 {
56212 'tfName': 'element_dtype',
56213 'name': 'elementDType',
56214 'type': 'dtype'
56215 }
56216 ]
56217 },
56218 {
56219 'tfOpName': 'TensorListConcatV2',
56220 'category': 'control',
56221 'inputs': [
56222 {
56223 'start': 0,
56224 'name': 'tensorListId',
56225 'type': 'tensor'
56226 }
56227 ],
56228 'attrs': [
56229 {
56230 'tfName': 'element_shape',
56231 'name': 'elementShape',
56232 'type': 'shape'
56233 },
56234 {
56235 'tfName': 'element_dtype',
56236 'name': 'elementDType',
56237 'type': 'dtype'
56238 }
56239 ]
56240 },
56241 {
56242 'tfOpName': 'TensorListPopBack',
56243 'category': 'control',
56244 'inputs': [
56245 {
56246 'start': 0,
56247 'name': 'tensorListId',
56248 'type': 'tensor'
56249 },
56250 {
56251 'start': 1,
56252 'name': 'elementShape',
56253 'type': 'shape'
56254 }
56255 ],
56256 'attrs': [
56257 {
56258 'tfName': 'element_dtype',
56259 'name': 'elementDType',
56260 'type': 'dtype'
56261 }
56262 ]
56263 },
56264 {
56265 'tfOpName': 'TensorListPushBack',
56266 'category': 'control',
56267 'inputs': [
56268 {
56269 'start': 0,
56270 'name': 'tensorListId',
56271 'type': 'tensor'
56272 },
56273 {
56274 'start': 1,
56275 'name': 'tensor',
56276 'type': 'tensor'
56277 }
56278 ],
56279 'attrs': [
56280 {
56281 'tfName': 'element_dtype',
56282 'name': 'elementDType',
56283 'type': 'dtype'
56284 }
56285 ]
56286 },
56287 {
56288 'tfOpName': 'TensorListLength',
56289 'category': 'control',
56290 'inputs': [
56291 {
56292 'start': 0,
56293 'name': 'tensorListId',
56294 'type': 'tensor'
56295 }
56296 ]
56297 },
56298 {
56299 'tfOpName': 'TensorListResize',
56300 'category': 'control',
56301 'inputs': [
56302 {
56303 'start': 0,
56304 'name': 'tensorListId',
56305 'type': 'tensor'
56306 },
56307 {
56308 'start': 1,
56309 'name': 'size',
56310 'type': 'number'
56311 }
56312 ]
56313 }
56314 ];
56315
56316 var control = /*#__PURE__*/Object.freeze({
56317 __proto__: null,
56318 json: json$2
56319 });
56320
56321 /**
56322 * @license
56323 * Copyright 2022 Google LLC. All Rights Reserved.
56324 * Licensed under the Apache License, Version 2.0 (the "License");
56325 * you may not use this file except in compliance with the License.
56326 * You may obtain a copy of the License at
56327 *
56328 * http://www.apache.org/licenses/LICENSE-2.0
56329 *
56330 * Unless required by applicable law or agreed to in writing, software
56331 * distributed under the License is distributed on an "AS IS" BASIS,
56332 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56333 * See the License for the specific language governing permissions and
56334 * limitations under the License.
56335 * =============================================================================
56336 */
56337 const json$3 = [
56338 {
56339 'tfOpName': 'AvgPool',
56340 'category': 'convolution',
56341 'inputs': [
56342 {
56343 'start': 0,
56344 'name': 'x',
56345 'type': 'tensor'
56346 }
56347 ],
56348 'attrs': [
56349 {
56350 'tfName': 'strides',
56351 'name': 'strides',
56352 'type': 'number[]'
56353 },
56354 {
56355 'tfName': 'padding',
56356 'name': 'pad',
56357 'type': 'string'
56358 },
56359 {
56360 'tfName': 'data_format',
56361 'name': 'dataFormat',
56362 'type': 'string',
56363 'notSupported': true
56364 },
56365 {
56366 'tfName': 'ksize',
56367 'name': 'kernelSize',
56368 'type': 'number[]'
56369 },
56370 {
56371 'tfName': 'T',
56372 'name': 'dtype',
56373 'type': 'dtype',
56374 'notSupported': true
56375 }
56376 ]
56377 },
56378 {
56379 'tfOpName': 'MaxPool',
56380 'category': 'convolution',
56381 'inputs': [
56382 {
56383 'start': 0,
56384 'name': 'x',
56385 'type': 'tensor'
56386 }
56387 ],
56388 'attrs': [
56389 {
56390 'tfName': 'strides',
56391 'name': 'strides',
56392 'type': 'number[]'
56393 },
56394 {
56395 'tfName': 'padding',
56396 'name': 'pad',
56397 'type': 'string'
56398 },
56399 {
56400 'tfName': 'data_format',
56401 'name': 'dataFormat',
56402 'type': 'string',
56403 'notSupported': true
56404 },
56405 {
56406 'tfName': 'ksize',
56407 'name': 'kernelSize',
56408 'type': 'number[]'
56409 },
56410 {
56411 'tfName': 'explicit_paddings',
56412 'name': 'explicitPaddings',
56413 'type': 'number[]',
56414 'defaultValue': [],
56415 'notSupported': true
56416 },
56417 {
56418 'tfName': 'T',
56419 'name': 'dtype',
56420 'type': 'dtype',
56421 'notSupported': true
56422 }
56423 ]
56424 },
56425 {
56426 'tfOpName': 'MaxPoolWithArgmax',
56427 'category': 'convolution',
56428 'inputs': [
56429 {
56430 'start': 0,
56431 'name': 'x',
56432 'type': 'tensor'
56433 }
56434 ],
56435 'attrs': [
56436 {
56437 'tfName': 'strides',
56438 'name': 'strides',
56439 'type': 'number[]'
56440 },
56441 {
56442 'tfName': 'padding',
56443 'name': 'pad',
56444 'type': 'string'
56445 },
56446 {
56447 'tfName': 'ksize',
56448 'name': 'kernelSize',
56449 'type': 'number[]'
56450 },
56451 {
56452 'tfName': 'include_batch_in_index',
56453 'name': 'includeBatchInIndex',
56454 'type': 'bool'
56455 },
56456 {
56457 'tfName': 'T',
56458 'name': 'dtype',
56459 'type': 'dtype',
56460 'notSupported': true
56461 }
56462 ]
56463 },
56464 {
56465 'tfOpName': 'AvgPool3D',
56466 'category': 'convolution',
56467 'inputs': [
56468 {
56469 'start': 0,
56470 'name': 'x',
56471 'type': 'tensor'
56472 }
56473 ],
56474 'attrs': [
56475 {
56476 'tfName': 'strides',
56477 'name': 'strides',
56478 'type': 'number[]'
56479 },
56480 {
56481 'tfName': 'padding',
56482 'name': 'pad',
56483 'type': 'string'
56484 },
56485 {
56486 'tfName': 'data_format',
56487 'name': 'dataFormat',
56488 'type': 'string',
56489 'notSupported': true
56490 },
56491 {
56492 'tfName': 'ksize',
56493 'name': 'kernelSize',
56494 'type': 'number[]'
56495 },
56496 {
56497 'tfName': 'T',
56498 'name': 'dtype',
56499 'type': 'dtype',
56500 'notSupported': true
56501 }
56502 ]
56503 },
56504 {
56505 'tfOpName': 'MaxPool3D',
56506 'category': 'convolution',
56507 'inputs': [
56508 {
56509 'start': 0,
56510 'name': 'x',
56511 'type': 'tensor'
56512 }
56513 ],
56514 'attrs': [
56515 {
56516 'tfName': 'strides',
56517 'name': 'strides',
56518 'type': 'number[]'
56519 },
56520 {
56521 'tfName': 'padding',
56522 'name': 'pad',
56523 'type': 'string'
56524 },
56525 {
56526 'tfName': 'data_format',
56527 'name': 'dataFormat',
56528 'type': 'string',
56529 'notSupported': true
56530 },
56531 {
56532 'tfName': 'ksize',
56533 'name': 'kernelSize',
56534 'type': 'number[]'
56535 },
56536 {
56537 'tfName': 'T',
56538 'name': 'dtype',
56539 'type': 'dtype',
56540 'notSupported': true
56541 }
56542 ]
56543 },
56544 {
56545 'tfOpName': 'Conv1D',
56546 'category': 'convolution',
56547 'inputs': [
56548 {
56549 'start': 0,
56550 'name': 'x',
56551 'type': 'tensor'
56552 },
56553 {
56554 'start': 1,
56555 'name': 'filter',
56556 'type': 'tensor'
56557 }
56558 ],
56559 'attrs': [
56560 {
56561 'tfName': 'stride',
56562 'name': 'stride',
56563 'type': 'number'
56564 },
56565 {
56566 'tfName': 'padding',
56567 'name': 'pad',
56568 'type': 'string'
56569 },
56570 {
56571 'tfName': 'data_format',
56572 'name': 'dataFormat',
56573 'type': 'string',
56574 'defaultValue': 'NWC'
56575 },
56576 {
56577 'tfName': 'T',
56578 'name': 'dtype',
56579 'type': 'dtype',
56580 'notSupported': true
56581 },
56582 {
56583 'tfName': 'dilation',
56584 'name': 'dilation',
56585 'type': 'number',
56586 'defaultValue': 1
56587 }
56588 ]
56589 },
56590 {
56591 'tfOpName': 'Conv2D',
56592 'category': 'convolution',
56593 'inputs': [
56594 {
56595 'start': 0,
56596 'name': 'x',
56597 'type': 'tensor'
56598 },
56599 {
56600 'start': 1,
56601 'name': 'filter',
56602 'type': 'tensor'
56603 }
56604 ],
56605 'attrs': [
56606 {
56607 'tfName': 'T',
56608 'name': 'dtype',
56609 'type': 'dtype',
56610 'notSupported': true
56611 },
56612 {
56613 'tfName': 'strides',
56614 'name': 'strides',
56615 'type': 'number[]'
56616 },
56617 {
56618 'tfName': 'padding',
56619 'name': 'pad',
56620 'type': 'string'
56621 },
56622 {
56623 'tfName': 'useCudnnOnGpu',
56624 'name': 'useCudnnOnGpu',
56625 'type': 'bool'
56626 },
56627 {
56628 'tfName': 'data_format',
56629 'name': 'dataFormat',
56630 'type': 'string',
56631 'defaultValue': 'NHWC'
56632 },
56633 {
56634 'tfName': 'explicit_paddings',
56635 'name': 'explicitPaddings',
56636 'type': 'number[]',
56637 'defaultValue': []
56638 },
56639 {
56640 'tfName': 'dilations',
56641 'name': 'dilations',
56642 'type': 'number[]'
56643 }
56644 ]
56645 },
56646 {
56647 'tfOpName': '_FusedConv2D',
56648 'category': 'convolution',
56649 'inputs': [
56650 {
56651 'start': 0,
56652 'name': 'x',
56653 'type': 'tensor'
56654 },
56655 {
56656 'start': 1,
56657 'name': 'filter',
56658 'type': 'tensor'
56659 },
56660 {
56661 'start': 2,
56662 'end': 0,
56663 'name': 'args',
56664 'type': 'tensors'
56665 }
56666 ],
56667 'attrs': [
56668 {
56669 'tfName': 'num_args',
56670 'name': 'numArgs',
56671 'type': 'number'
56672 },
56673 {
56674 'tfName': 'T',
56675 'name': 'dtype',
56676 'type': 'dtype',
56677 'notSupported': true
56678 },
56679 {
56680 'tfName': 'strides',
56681 'name': 'strides',
56682 'type': 'number[]'
56683 },
56684 {
56685 'tfName': 'padding',
56686 'name': 'pad',
56687 'type': 'string'
56688 },
56689 {
56690 'tfName': 'explicit_paddings',
56691 'name': 'explicitPaddings',
56692 'type': 'number[]',
56693 'defaultValue': []
56694 },
56695 {
56696 'tfName': 'use_cudnn_on_gpu',
56697 'name': 'useCudnnOnGpu',
56698 'type': 'bool',
56699 'defaultValue': true
56700 },
56701 {
56702 'tfName': 'data_format',
56703 'name': 'dataFormat',
56704 'type': 'string',
56705 'defaultValue': 'NHWC'
56706 },
56707 {
56708 'tfName': 'dilations',
56709 'name': 'dilations',
56710 'type': 'number[]',
56711 'defaultValue': [
56712 1,
56713 1,
56714 1,
56715 1
56716 ]
56717 },
56718 {
56719 'tfName': 'fused_ops',
56720 'name': 'fusedOps',
56721 'type': 'string[]',
56722 'defaultValue': []
56723 },
56724 {
56725 'tfName': 'epsilon',
56726 'name': 'epsilon',
56727 'type': 'number',
56728 'defaultValue': 0.0001
56729 },
56730 {
56731 'tfName': 'leakyrelu_alpha',
56732 'name': 'leakyreluAlpha',
56733 'type': 'number'
56734 }
56735 ]
56736 },
56737 {
56738 'tfOpName': 'Conv2DBackpropInput',
56739 'category': 'convolution',
56740 'inputs': [
56741 {
56742 'start': 2,
56743 'name': 'x',
56744 'type': 'tensor'
56745 },
56746 {
56747 'start': 1,
56748 'name': 'filter',
56749 'type': 'tensor'
56750 },
56751 {
56752 'start': 0,
56753 'name': 'outputShape',
56754 'type': 'number[]'
56755 }
56756 ],
56757 'attrs': [
56758 {
56759 'tfName': 'strides',
56760 'name': 'strides',
56761 'type': 'number[]'
56762 },
56763 {
56764 'tfName': 'padding',
56765 'name': 'pad',
56766 'type': 'string'
56767 },
56768 {
56769 'tfName': 'data_format',
56770 'name': 'dataFormat',
56771 'type': 'string',
56772 'notSupported': true
56773 },
56774 {
56775 'tfName': 'explicit_paddings',
56776 'name': 'explicitPaddings',
56777 'type': 'number[]',
56778 'defaultValue': []
56779 },
56780 {
56781 'tfName': 'dilations',
56782 'name': 'dilations',
56783 'type': 'number[]',
56784 'notSupported': true
56785 }
56786 ]
56787 },
56788 {
56789 'tfOpName': 'DepthwiseConv2d',
56790 'category': 'convolution',
56791 'inputs': [
56792 {
56793 'start': 0,
56794 'name': 'input',
56795 'type': 'tensor'
56796 },
56797 {
56798 'start': 1,
56799 'name': 'filter',
56800 'type': 'tensor'
56801 }
56802 ],
56803 'attrs': [
56804 {
56805 'tfName': 'strides',
56806 'name': 'strides',
56807 'type': 'number[]'
56808 },
56809 {
56810 'tfName': 'padding',
56811 'name': 'pad',
56812 'type': 'string'
56813 },
56814 {
56815 'tfName': 'data_format',
56816 'name': 'dataFormat',
56817 'type': 'string',
56818 'defaultValue': 'NHWC'
56819 },
56820 {
56821 'tfName': 'explicit_paddings',
56822 'name': 'explicitPaddings',
56823 'type': 'number[]',
56824 'defaultValue': []
56825 },
56826 {
56827 'tfName': 'dilations',
56828 'name': 'dilations',
56829 'type': 'number[]'
56830 }
56831 ]
56832 },
56833 {
56834 'tfOpName': 'DepthwiseConv2dNative',
56835 'category': 'convolution',
56836 'inputs': [
56837 {
56838 'start': 0,
56839 'name': 'input',
56840 'type': 'tensor'
56841 },
56842 {
56843 'start': 1,
56844 'name': 'filter',
56845 'type': 'tensor'
56846 }
56847 ],
56848 'attrs': [
56849 {
56850 'tfName': 'strides',
56851 'name': 'strides',
56852 'type': 'number[]'
56853 },
56854 {
56855 'tfName': 'padding',
56856 'name': 'pad',
56857 'type': 'string'
56858 },
56859 {
56860 'tfName': 'data_format',
56861 'name': 'dataFormat',
56862 'type': 'string',
56863 'defaultValue': 'NHWC'
56864 },
56865 {
56866 'tfName': 'explicit_paddings',
56867 'name': 'explicitPaddings',
56868 'type': 'number[]',
56869 'defaultValue': []
56870 },
56871 {
56872 'tfName': 'dilations',
56873 'name': 'dilations',
56874 'type': 'number[]'
56875 }
56876 ]
56877 },
56878 {
56879 'tfOpName': 'FusedDepthwiseConv2dNative',
56880 'category': 'convolution',
56881 'inputs': [
56882 {
56883 'start': 0,
56884 'name': 'x',
56885 'type': 'tensor'
56886 },
56887 {
56888 'start': 1,
56889 'name': 'filter',
56890 'type': 'tensor'
56891 },
56892 {
56893 'start': 2,
56894 'end': 0,
56895 'name': 'args',
56896 'type': 'tensors'
56897 }
56898 ],
56899 'attrs': [
56900 {
56901 'tfName': 'num_args',
56902 'name': 'numArgs',
56903 'type': 'number'
56904 },
56905 {
56906 'tfName': 'T',
56907 'name': 'dtype',
56908 'type': 'dtype',
56909 'notSupported': true
56910 },
56911 {
56912 'tfName': 'strides',
56913 'name': 'strides',
56914 'type': 'number[]'
56915 },
56916 {
56917 'tfName': 'padding',
56918 'name': 'pad',
56919 'type': 'string'
56920 },
56921 {
56922 'tfName': 'data_format',
56923 'name': 'dataFormat',
56924 'type': 'string',
56925 'defaultValue': 'NHWC'
56926 },
56927 {
56928 'tfName': 'dilations',
56929 'name': 'dilations',
56930 'type': 'number[]',
56931 'defaultValue': [
56932 1,
56933 1,
56934 1,
56935 1
56936 ]
56937 },
56938 {
56939 'tfName': 'fused_ops',
56940 'name': 'fusedOps',
56941 'type': 'string[]',
56942 'defaultValue': []
56943 },
56944 {
56945 'tfName': 'explicit_paddings',
56946 'name': 'explicitPaddings',
56947 'type': 'number[]',
56948 'defaultValue': []
56949 }
56950 ]
56951 },
56952 {
56953 'tfOpName': 'Conv3D',
56954 'category': 'convolution',
56955 'inputs': [
56956 {
56957 'start': 0,
56958 'name': 'x',
56959 'type': 'tensor'
56960 },
56961 {
56962 'start': 1,
56963 'name': 'filter',
56964 'type': 'tensor'
56965 }
56966 ],
56967 'attrs': [
56968 {
56969 'tfName': 'strides',
56970 'name': 'strides',
56971 'type': 'number[]'
56972 },
56973 {
56974 'tfName': 'padding',
56975 'name': 'pad',
56976 'type': 'string'
56977 },
56978 {
56979 'tfName': 'data_format',
56980 'name': 'dataFormat',
56981 'type': 'string',
56982 'defaultValue': 'NHWC'
56983 },
56984 {
56985 'tfName': 'dilations',
56986 'name': 'dilations',
56987 'type': 'number[]'
56988 }
56989 ]
56990 },
56991 {
56992 'tfOpName': 'Dilation2D',
56993 'category': 'convolution',
56994 'inputs': [
56995 {
56996 'start': 0,
56997 'name': 'x',
56998 'type': 'tensor'
56999 },
57000 {
57001 'start': 1,
57002 'name': 'filter',
57003 'type': 'tensor'
57004 }
57005 ],
57006 'attrs': [
57007 {
57008 'tfName': 'strides',
57009 'name': 'strides',
57010 'type': 'number[]'
57011 },
57012 {
57013 'tfName': 'rates',
57014 'name': 'dilations',
57015 'type': 'number[]'
57016 },
57017 {
57018 'tfName': 'padding',
57019 'name': 'pad',
57020 'type': 'string'
57021 }
57022 ]
57023 }
57024 ];
57025
57026 var convolution = /*#__PURE__*/Object.freeze({
57027 __proto__: null,
57028 json: json$3
57029 });
57030
57031 /**
57032 * @license
57033 * Copyright 2022 Google LLC. All Rights Reserved.
57034 * Licensed under the Apache License, Version 2.0 (the "License");
57035 * you may not use this file except in compliance with the License.
57036 * You may obtain a copy of the License at
57037 *
57038 * http://www.apache.org/licenses/LICENSE-2.0
57039 *
57040 * Unless required by applicable law or agreed to in writing, software
57041 * distributed under the License is distributed on an "AS IS" BASIS,
57042 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
57043 * See the License for the specific language governing permissions and
57044 * limitations under the License.
57045 * =============================================================================
57046 */
/**
 * Op-mapper table for the TensorFlow 'creation' category. Each entry maps a
 * TF op name to the positional inputs and attribute conversions the converter
 * uses; `notSupported` marks attributes that are parsed but ignored.
 */
const json$4 = [
  {
    tfOpName: 'Fill',
    category: 'creation',
    inputs: [
      { start: 0, name: 'shape', type: 'number[]' },
      { start: 1, name: 'value', type: 'number' },
    ],
    attrs: [{ tfName: 'T', name: 'dtype', type: 'dtype' }],
  },
  {
    tfOpName: 'LinSpace',
    category: 'creation',
    inputs: [
      { start: 0, name: 'start', type: 'number' },
      { start: 1, name: 'stop', type: 'number' },
      { start: 2, name: 'num', type: 'number' },
    ],
    attrs: [{ tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true }],
  },
  {
    tfOpName: 'OneHot',
    category: 'creation',
    inputs: [
      { start: 0, name: 'indices', type: 'tensor' },
      { start: 1, name: 'depth', type: 'number' },
      { start: 2, name: 'onValue', type: 'number', defaultValue: 1 },
      { start: 3, name: 'offValue', type: 'number', defaultValue: 0 },
    ],
    attrs: [
      { tfName: 'axis', name: 'axis', type: 'number', notSupported: true },
      { tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true },
    ],
  },
  {
    tfOpName: 'Ones',
    category: 'creation',
    inputs: [{ start: 0, name: 'shape', type: 'number[]' }],
    attrs: [{ tfName: 'T', name: 'dtype', type: 'dtype' }],
  },
  {
    tfOpName: 'OnesLike',
    category: 'creation',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
    attrs: [{ tfName: 'dtype', name: 'dtype', type: 'dtype' }],
  },
  {
    tfOpName: 'RandomUniform',
    category: 'creation',
    inputs: [{ start: 0, name: 'shape', type: 'number[]' }],
    attrs: [
      { tfName: 'minval', name: 'minval', type: 'number', defaultValue: 0 },
      { tfName: 'maxval', name: 'maxval', type: 'number', defaultValue: 1 },
      { tfName: 'dtype', name: 'dtype', type: 'dtype' },
      { tfName: 'seed', name: 'seed', type: 'number', defaultValue: 0 },
      { tfName: 'seed2', name: 'seed2', type: 'number', defaultValue: 0, notSupported: true },
      { tfName: 'T', name: 'T', type: 'number', notSupported: true },
    ],
  },
  {
    tfOpName: 'Range',
    category: 'creation',
    inputs: [
      { start: 0, name: 'start', type: 'number' },
      { start: 1, name: 'stop', type: 'number' },
      { start: 2, name: 'step', type: 'number', defaultValue: 0 },
    ],
    attrs: [{ tfName: 'Tidx', name: 'dtype', type: 'dtype' }],
  },
  {
    tfOpName: 'TruncatedNormal',
    category: 'creation',
    inputs: [{ start: 0, name: 'shape', type: 'number[]' }],
    attrs: [
      { tfName: 'means', name: 'mean', type: 'number', defaultValue: 0 },
      { tfName: 'stddev', name: 'stdDev', type: 'number', defaultValue: 1 },
      { tfName: 'seed', name: 'seed', type: 'number' },
      { tfName: 'seed2', name: 'seed2', type: 'number', defaultValue: 0, notSupported: true },
      { tfName: 'dtype', name: 'dtype', type: 'dtype' },
      { tfName: 'T', name: 'T', type: 'number', notSupported: true },
    ],
  },
  {
    tfOpName: 'Zeros',
    category: 'creation',
    inputs: [{ start: 0, name: 'shape', type: 'number[]' }],
    attrs: [{ tfName: 'T', name: 'dtype', type: 'dtype' }],
  },
  {
    tfOpName: 'ZerosLike',
    category: 'creation',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
    attrs: [{ tfName: 'T', name: 'dtype', type: 'dtype' }],
  },
  {
    tfOpName: 'Multinomial',
    category: 'creation',
    inputs: [
      { start: 0, name: 'logits', type: 'tensor' },
      { start: 1, name: 'numSamples', type: 'number' },
    ],
    attrs: [
      { tfName: 'seed', name: 'seed', type: 'number' },
      { tfName: 'seed2', name: 'seed2', type: 'number' },
      { tfName: 'T', name: 'dtype', type: 'dtype' },
      { tfName: 'output_dtype', name: 'output_dtype', type: 'dtype' },
    ],
  },
];
57379
// Frozen, prototype-less namespace object exposing the creation op mappers.
var creation = /*#__PURE__*/ Object.freeze({ __proto__: null, json: json$4 });
57384
57385 /**
57386 * @license
57387 * Copyright 2022 Google LLC. All Rights Reserved.
57388 * Licensed under the Apache License, Version 2.0 (the "License");
57389 * you may not use this file except in compliance with the License.
57390 * You may obtain a copy of the License at
57391 *
57392 * http://www.apache.org/licenses/LICENSE-2.0
57393 *
57394 * Unless required by applicable law or agreed to in writing, software
57395 * distributed under the License is distributed on an "AS IS" BASIS,
57396 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
57397 * See the License for the specific language governing permissions and
57398 * limitations under the License.
57399 * =============================================================================
57400 */
/**
 * Op-mapper table for the TensorFlow 'dynamic' category (ops whose output
 * shape depends on runtime values). Maps TF op names to input positions and
 * attribute conversions; `notSupported` attributes are parsed but ignored.
 */
const json$5 = [
  {
    tfOpName: 'NonMaxSuppressionV2',
    category: 'dynamic',
    inputs: [
      { start: 0, name: 'boxes', type: 'tensor' },
      { start: 1, name: 'scores', type: 'tensor' },
      { start: 2, name: 'maxOutputSize', type: 'number' },
      { start: 3, name: 'iouThreshold', type: 'number' },
    ],
  },
  {
    tfOpName: 'NonMaxSuppressionV3',
    category: 'dynamic',
    inputs: [
      { start: 0, name: 'boxes', type: 'tensor' },
      { start: 1, name: 'scores', type: 'tensor' },
      { start: 2, name: 'maxOutputSize', type: 'number' },
      { start: 3, name: 'iouThreshold', type: 'number' },
      { start: 4, name: 'scoreThreshold', type: 'number' },
    ],
  },
  {
    tfOpName: 'NonMaxSuppressionV4',
    category: 'dynamic',
    inputs: [
      { start: 0, name: 'boxes', type: 'tensor' },
      { start: 1, name: 'scores', type: 'tensor' },
      { start: 2, name: 'maxOutputSize', type: 'number' },
      { start: 3, name: 'iouThreshold', type: 'number' },
      { start: 4, name: 'scoreThreshold', type: 'number' },
    ],
    attrs: [
      { tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true },
      { tfName: 'T_threshold', name: 'threshold', type: 'dtype', notSupported: true },
      { tfName: 'pad_to_max_output_size', name: 'padToMaxOutputSize', type: 'bool' },
    ],
  },
  {
    tfOpName: 'NonMaxSuppressionV5',
    category: 'dynamic',
    inputs: [
      { start: 0, name: 'boxes', type: 'tensor' },
      { start: 1, name: 'scores', type: 'tensor' },
      { start: 2, name: 'maxOutputSize', type: 'number' },
      { start: 3, name: 'iouThreshold', type: 'number' },
      { start: 4, name: 'scoreThreshold', type: 'number' },
      { start: 5, name: 'softNmsSigma', type: 'number' },
    ],
  },
  {
    tfOpName: 'Where',
    category: 'dynamic',
    inputs: [{ start: 0, name: 'condition', type: 'tensor' }],
    attrs: [{ tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true }],
  },
  {
    tfOpName: 'ListDiff',
    category: 'dynamic',
    inputs: [
      { start: 0, name: 'x', type: 'tensor' },
      { start: 1, name: 'y', type: 'tensor' },
    ],
    attrs: [{ tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true }],
  },
];
57589
// Frozen, prototype-less namespace object exposing the dynamic op mappers.
var dynamic = /*#__PURE__*/ Object.freeze({ __proto__: null, json: json$5 });
57594
57595 /**
57596 * @license
57597 * Copyright 2022 Google LLC. All Rights Reserved.
57598 * Licensed under the Apache License, Version 2.0 (the "License");
57599 * you may not use this file except in compliance with the License.
57600 * You may obtain a copy of the License at
57601 *
57602 * http://www.apache.org/licenses/LICENSE-2.0
57603 *
57604 * Unless required by applicable law or agreed to in writing, software
57605 * distributed under the License is distributed on an "AS IS" BASIS,
57606 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
57607 * See the License for the specific language governing permissions and
57608 * limitations under the License.
57609 * =============================================================================
57610 */
/**
 * Op-mapper table for the TensorFlow 'evaluation' category: search/sort-style
 * ops mapping TF op names to input positions and attribute conversions.
 */
const json$6 = [
  {
    tfOpName: 'LowerBound',
    category: 'evaluation',
    inputs: [
      { start: 0, name: 'sortedSequence', type: 'tensor' },
      { start: 1, name: 'values', type: 'tensor' },
    ],
  },
  {
    tfOpName: 'TopKV2',
    category: 'evaluation',
    inputs: [
      { start: 0, name: 'x', type: 'tensor' },
      { start: 1, name: 'k', type: 'number' },
    ],
    attrs: [{ tfName: 'sorted', name: 'sorted', type: 'bool' }],
  },
  {
    tfOpName: 'UpperBound',
    category: 'evaluation',
    inputs: [
      { start: 0, name: 'sortedSequence', type: 'tensor' },
      { start: 1, name: 'values', type: 'tensor' },
    ],
  },
  {
    tfOpName: 'Unique',
    category: 'evaluation',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
  },
  {
    tfOpName: 'UniqueV2',
    category: 'evaluation',
    inputs: [
      { start: 0, name: 'x', type: 'tensor' },
      { start: 1, name: 'axis', type: 'number' },
    ],
  },
];
57695
// Frozen, prototype-less namespace object exposing the evaluation op mappers.
var evaluation = /*#__PURE__*/ Object.freeze({ __proto__: null, json: json$6 });
57700
57701 /**
57702 * @license
57703 * Copyright 2022 Google LLC. All Rights Reserved.
57704 * Licensed under the Apache License, Version 2.0 (the "License");
57705 * you may not use this file except in compliance with the License.
57706 * You may obtain a copy of the License at
57707 *
57708 * http://www.apache.org/licenses/LICENSE-2.0
57709 *
57710 * Unless required by applicable law or agreed to in writing, software
57711 * distributed under the License is distributed on an "AS IS" BASIS,
57712 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
57713 * See the License for the specific language governing permissions and
57714 * limitations under the License.
57715 * =============================================================================
57716 */
/**
 * Op-mapper table for the TensorFlow 'graph' category: structural/metadata
 * ops (placeholders, identity, shape queries, debug printing). `type:
 * 'tensors'` with `start`/`end` denotes a variadic input span.
 */
const json$7 = [
  {
    tfOpName: 'PlaceholderWithDefault',
    category: 'graph',
    inputs: [{ start: 0, name: 'default', type: 'tensor' }],
    attrs: [
      { tfName: 'shape', name: 'shape', type: 'shape' },
      { tfName: 'dtype', name: 'dtype', type: 'dtype' },
    ],
  },
  {
    tfOpName: 'Placeholder',
    category: 'graph',
    attrs: [
      { tfName: 'shape', name: 'shape', type: 'shape' },
      { tfName: 'dtype', name: 'dtype', type: 'dtype' },
    ],
  },
  { tfOpName: 'Const', category: 'graph' },
  {
    tfOpName: 'Identity',
    category: 'graph',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
  },
  {
    tfOpName: 'IdentityN',
    category: 'graph',
    inputs: [{ start: 0, end: 0, name: 'x', type: 'tensors' }],
  },
  {
    tfOpName: 'Snapshot',
    category: 'graph',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
  },
  {
    tfOpName: 'Rank',
    category: 'graph',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
  },
  {
    tfOpName: 'Size',
    category: 'graph',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
  },
  {
    tfOpName: 'Shape',
    category: 'graph',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
  },
  {
    tfOpName: 'ShapeN',
    category: 'graph',
    inputs: [{ start: 0, end: 0, name: 'x', type: 'tensors' }],
  },
  {
    tfOpName: 'Print',
    category: 'graph',
    inputs: [
      { start: 0, name: 'x', type: 'tensor' },
      { start: 1, name: 'data', type: 'tensors' },
    ],
    attrs: [
      { tfName: 'message', name: 'message', type: 'string' },
      { tfName: 'first_n', name: 'firstN', type: 'number', notSupported: true },
      { tfName: 'summarize', name: 'summarize', type: 'number', defaultValue: 3 },
    ],
  },
  { tfOpName: 'NoOp', category: 'graph', inputs: [] },
  {
    tfOpName: 'StopGradient',
    category: 'graph',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
  },
  {
    tfOpName: 'FakeQuantWithMinMaxVars',
    category: 'graph',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
    attrs: [
      { tfName: 'min', name: 'min', type: 'number' },
      { tfName: 'max', name: 'max', type: 'number' },
    ],
  },
];
57915
// Frozen, prototype-less namespace object exposing the graph op mappers.
var graph = /*#__PURE__*/ Object.freeze({ __proto__: null, json: json$7 });
57920
57921 /**
57922 * @license
57923 * Copyright 2022 Google LLC. All Rights Reserved.
57924 * Licensed under the Apache License, Version 2.0 (the "License");
57925 * you may not use this file except in compliance with the License.
57926 * You may obtain a copy of the License at
57927 *
57928 * http://www.apache.org/licenses/LICENSE-2.0
57929 *
57930 * Unless required by applicable law or agreed to in writing, software
57931 * distributed under the License is distributed on an "AS IS" BASIS,
57932 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
57933 * See the License for the specific language governing permissions and
57934 * limitations under the License.
57935 * =============================================================================
57936 */
/**
 * Op-mapper table for the TensorFlow 'hash_table' category. The V1/V2 op
 * pairs share identical mappings (same inputs and attribute conversions).
 */

// Attrs shared by HashTable / HashTableV2.
const hashTableCreationAttrs = () => [
  { tfName: 'shared_name', name: 'sharedName', type: 'string' },
  { tfName: 'use_node_name_sharing', name: 'useNodeNameSharing', type: 'bool' },
  { tfName: 'key_dtype', name: 'keyDType', type: 'dtype' },
  { tfName: 'value_dtype', name: 'valueDType', type: 'dtype' },
];

// Tin/Tout attrs shared by the lookup import/find ops (parsed but ignored).
const lookupDtypeAttrs = () => [
  { tfName: 'Tin', name: 'tIn', type: 'dtype', notSupported: true },
  { tfName: 'Tout', name: 'tOut', type: 'dtype', notSupported: true },
];

const json$8 = [
  {
    tfOpName: 'HashTable',
    category: 'hash_table',
    inputs: [],
    attrs: hashTableCreationAttrs(),
  },
  {
    tfOpName: 'HashTableV2',
    category: 'hash_table',
    inputs: [],
    attrs: hashTableCreationAttrs(),
  },
  {
    tfOpName: 'LookupTableImport',
    category: 'hash_table',
    inputs: [
      { start: 0, name: 'tableHandle', type: 'tensor' },
      { start: 1, name: 'keys', type: 'tensor' },
      { start: 2, name: 'values', type: 'tensor' },
    ],
    attrs: lookupDtypeAttrs(),
  },
  {
    tfOpName: 'LookupTableImportV2',
    category: 'hash_table',
    inputs: [
      { start: 0, name: 'tableHandle', type: 'tensor' },
      { start: 1, name: 'keys', type: 'tensor' },
      { start: 2, name: 'values', type: 'tensor' },
    ],
    attrs: lookupDtypeAttrs(),
  },
  {
    tfOpName: 'LookupTableFind',
    category: 'hash_table',
    inputs: [
      { start: 0, name: 'tableHandle', type: 'tensor' },
      { start: 1, name: 'keys', type: 'tensor' },
      { start: 2, name: 'defaultValue', type: 'tensor' },
    ],
    attrs: lookupDtypeAttrs(),
  },
  {
    tfOpName: 'LookupTableFindV2',
    category: 'hash_table',
    inputs: [
      { start: 0, name: 'tableHandle', type: 'tensor' },
      { start: 1, name: 'keys', type: 'tensor' },
      { start: 2, name: 'defaultValue', type: 'tensor' },
    ],
    attrs: lookupDtypeAttrs(),
  },
  {
    tfOpName: 'LookupTableSize',
    category: 'hash_table',
    inputs: [{ start: 0, name: 'tableHandle', type: 'tensor' }],
  },
  {
    tfOpName: 'LookupTableSizeV2',
    category: 'hash_table',
    inputs: [{ start: 0, name: 'tableHandle', type: 'tensor' }],
  },
];
58155
// Frozen, prototype-less namespace object exposing the hash_table op mappers.
var hashTable = /*#__PURE__*/ Object.freeze({ __proto__: null, json: json$8 });
58160
58161 /**
58162 * @license
58163 * Copyright 2022 Google LLC. All Rights Reserved.
58164 * Licensed under the Apache License, Version 2.0 (the "License");
58165 * you may not use this file except in compliance with the License.
58166 * You may obtain a copy of the License at
58167 *
58168 * http://www.apache.org/licenses/LICENSE-2.0
58169 *
58170 * Unless required by applicable law or agreed to in writing, software
58171 * distributed under the License is distributed on an "AS IS" BASIS,
58172 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
58173 * See the License for the specific language governing permissions and
58174 * limitations under the License.
58175 * =============================================================================
58176 */
/**
 * Op-mapper table for the TensorFlow 'image' category: resize, crop, and
 * projective-transform ops mapped to input positions and attribute
 * conversions; `notSupported` attributes are parsed but ignored.
 */
const json$9 = [
  {
    tfOpName: 'ResizeBilinear',
    category: 'image',
    inputs: [
      { start: 0, name: 'images', type: 'tensor' },
      { start: 1, name: 'size', type: 'number[]' },
    ],
    attrs: [
      { tfName: 'align_corners', name: 'alignCorners', type: 'bool' },
      { tfName: 'half_pixel_centers', name: 'halfPixelCenters', type: 'bool' },
      { tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true },
    ],
  },
  {
    tfOpName: 'ResizeNearestNeighbor',
    category: 'image',
    inputs: [
      { start: 0, name: 'images', type: 'tensor' },
      { start: 1, name: 'size', type: 'number[]' },
    ],
    attrs: [
      { tfName: 'align_corners', name: 'alignCorners', type: 'bool' },
      { tfName: 'half_pixel_centers', name: 'halfPixelCenters', type: 'bool' },
      { tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true },
    ],
  },
  {
    tfOpName: 'CropAndResize',
    category: 'image',
    inputs: [
      { start: 0, name: 'image', type: 'tensor' },
      { start: 1, name: 'boxes', type: 'tensor' },
      { start: 2, name: 'boxInd', type: 'tensor' },
      { start: 3, name: 'cropSize', type: 'number[]' },
    ],
    attrs: [
      { tfName: 'method', name: 'method', type: 'string' },
      { tfName: 'extrapolation_value', name: 'extrapolationValue', type: 'number' },
    ],
  },
  {
    tfOpName: 'ImageProjectiveTransformV3',
    category: 'image',
    inputs: [
      { start: 0, name: 'images', type: 'tensor' },
      { start: 1, name: 'transforms', type: 'tensor' },
      { start: 2, name: 'outputShape', type: 'number[]' },
      { start: 3, name: 'fillValue', type: 'number' },
    ],
    attrs: [
      { tfName: 'interpolation', name: 'interpolation', type: 'string' },
      { tfName: 'fill_mode', name: 'fillMode', type: 'string' },
    ],
  },
];
58323
// Frozen, prototype-less namespace object exposing the image op mappers.
var image$1 = /*#__PURE__*/ Object.freeze({ __proto__: null, json: json$9 });
58328
58329 /**
58330 * @license
58331 * Copyright 2022 Google LLC. All Rights Reserved.
58332 * Licensed under the Apache License, Version 2.0 (the "License");
58333 * you may not use this file except in compliance with the License.
58334 * You may obtain a copy of the License at
58335 *
58336 * http://www.apache.org/licenses/LICENSE-2.0
58337 *
58338 * Unless required by applicable law or agreed to in writing, software
58339 * distributed under the License is distributed on an "AS IS" BASIS,
58340 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
58341 * See the License for the specific language governing permissions and
58342 * limitations under the License.
58343 * =============================================================================
58344 */
/**
 * Op-mapper table for the TensorFlow 'logical' category: comparison, boolean,
 * and select ops. Every entry carries the same ignored 'T' dtype attribute.
 */

// The 'T' dtype attr common to all logical ops (parsed but ignored).
const logicalDtypeAttrs = () => [
  { tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true },
];

// Binary comparison/boolean ops all take the same (a, b) tensor inputs.
const binaryLogicalOp = (tfOpName) => ({
  tfOpName,
  category: 'logical',
  inputs: [
    { start: 0, name: 'a', type: 'tensor' },
    { start: 1, name: 'b', type: 'tensor' },
  ],
  attrs: logicalDtypeAttrs(),
});

// Select/SelectV2 take (condition, a, b).
const selectOp = (tfOpName) => ({
  tfOpName,
  category: 'logical',
  inputs: [
    { start: 0, name: 'condition', type: 'tensor' },
    { start: 1, name: 'a', type: 'tensor' },
    { start: 2, name: 'b', type: 'tensor' },
  ],
  attrs: logicalDtypeAttrs(),
});

const json$a = [
  binaryLogicalOp('Equal'),
  binaryLogicalOp('NotEqual'),
  binaryLogicalOp('Greater'),
  binaryLogicalOp('GreaterEqual'),
  binaryLogicalOp('Less'),
  binaryLogicalOp('LessEqual'),
  binaryLogicalOp('LogicalAnd'),
  {
    tfOpName: 'LogicalNot',
    category: 'logical',
    inputs: [{ start: 0, name: 'a', type: 'tensor' }],
    attrs: logicalDtypeAttrs(),
  },
  binaryLogicalOp('LogicalOr'),
  selectOp('Select'),
  selectOp('SelectV2'),
];
58616
// Frozen, prototype-less namespace object exposing the logical op mappers.
var logical = /*#__PURE__*/ Object.freeze({ __proto__: null, json: json$a });
58621
58622 /**
58623 * @license
58624 * Copyright 2022 Google LLC. All Rights Reserved.
58625 * Licensed under the Apache License, Version 2.0 (the "License");
58626 * you may not use this file except in compliance with the License.
58627 * You may obtain a copy of the License at
58628 *
58629 * http://www.apache.org/licenses/LICENSE-2.0
58630 *
58631 * Unless required by applicable law or agreed to in writing, software
58632 * distributed under the License is distributed on an "AS IS" BASIS,
58633 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
58634 * See the License for the specific language governing permissions and
58635 * limitations under the License.
58636 * =============================================================================
58637 */
58638 const json$b = [
58639 {
58640 'tfOpName': '_FusedMatMul',
58641 'category': 'matrices',
58642 'inputs': [
58643 {
58644 'start': 0,
58645 'name': 'a',
58646 'type': 'tensor'
58647 },
58648 {
58649 'start': 1,
58650 'name': 'b',
58651 'type': 'tensor'
58652 },
58653 {
58654 'start': 2,
58655 'end': 0,
58656 'name': 'args',
58657 'type': 'tensors'
58658 }
58659 ],
58660 'attrs': [
58661 {
58662 'tfName': 'num_args',
58663 'name': 'numArgs',
58664 'type': 'number'
58665 },
58666 {
58667 'tfName': 'fused_ops',
58668 'name': 'fusedOps',
58669 'type': 'string[]',
58670 'defaultValue': []
58671 },
58672 {
58673 'tfName': 'epsilon',
58674 'name': 'epsilon',
58675 'type': 'number',
58676 'defaultValue': 0.0001
58677 },
58678 {
58679 'tfName': 'transpose_a',
58680 'name': 'transposeA',
58681 'type': 'bool',
58682 'defaultValue': false
58683 },
58684 {
58685 'tfName': 'transpose_b',
58686 'name': 'transposeB',
58687 'type': 'bool',
58688 'defaultValue': false
58689 },
58690 {
58691 'tfName': 'T',
58692 'name': 'dtype',
58693 'type': 'dtype',
58694 'notSupported': true
58695 }
58696 ]
58697 },
58698 {
58699 'tfOpName': 'MatMul',
58700 'category': 'matrices',
58701 'inputs': [
58702 {
58703 'start': 0,
58704 'name': 'a',
58705 'type': 'tensor'
58706 },
58707 {
58708 'start': 1,
58709 'name': 'b',
58710 'type': 'tensor'
58711 }
58712 ],
58713 'attrs': [
58714 {
58715 'tfName': 'transpose_a',
58716 'name': 'transposeA',
58717 'type': 'bool',
58718 'defaultValue': false
58719 },
58720 {
58721 'tfName': 'transpose_b',
58722 'name': 'transposeB',
58723 'type': 'bool',
58724 'defaultValue': false
58725 },
58726 {
58727 'tfName': 'T',
58728 'name': 'dtype',
58729 'type': 'dtype',
58730 'notSupported': true
58731 }
58732 ]
58733 },
58734 {
58735 'tfOpName': 'BatchMatMul',
58736 'category': 'matrices',
58737 'inputs': [
58738 {
58739 'start': 0,
58740 'name': 'a',
58741 'type': 'tensor'
58742 },
58743 {
58744 'start': 1,
58745 'name': 'b',
58746 'type': 'tensor'
58747 }
58748 ],
58749 'attrs': [
58750 {
58751 'tfName': 'adj_x',
58752 'name': 'transposeA',
58753 'type': 'bool',
58754 'defaultValue': false
58755 },
58756 {
58757 'tfName': 'adj_y',
58758 'name': 'transposeB',
58759 'type': 'bool',
58760 'defaultValue': false
58761 },
58762 {
58763 'tfName': 'T',
58764 'name': 'dtype',
58765 'type': 'dtype',
58766 'notSupported': true
58767 }
58768 ]
58769 },
58770 {
58771 'tfOpName': 'BatchMatMulV2',
58772 'category': 'matrices',
58773 'inputs': [
58774 {
58775 'start': 0,
58776 'name': 'a',
58777 'type': 'tensor'
58778 },
58779 {
58780 'start': 1,
58781 'name': 'b',
58782 'type': 'tensor'
58783 }
58784 ],
58785 'attrs': [
58786 {
58787 'tfName': 'adj_x',
58788 'name': 'transposeA',
58789 'type': 'bool',
58790 'defaultValue': false
58791 },
58792 {
58793 'tfName': 'adj_y',
58794 'name': 'transposeB',
58795 'type': 'bool',
58796 'defaultValue': false
58797 },
58798 {
58799 'tfName': 'T',
58800 'name': 'dtype',
58801 'type': 'dtype',
58802 'notSupported': true
58803 }
58804 ]
58805 },
58806 {
58807 'tfOpName': 'Transpose',
58808 'category': 'matrices',
58809 'inputs': [
58810 {
58811 'start': 0,
58812 'name': 'x',
58813 'type': 'tensor'
58814 },
58815 {
58816 'start': 1,
58817 'name': 'perm',
58818 'type': 'number[]'
58819 }
58820 ],
58821 'attrs': [
58822 {
58823 'tfName': 'T',
58824 'name': 'dtype',
58825 'type': 'dtype',
58826 'notSupported': true
58827 }
58828 ]
58829 },
58830 {
58831 'tfOpName': 'Einsum',
58832 'category': 'matrices',
58833 'inputs': [
58834 {
58835 'start': 0,
58836 'end': 0,
58837 'name': 'tensors',
58838 'type': 'tensors'
58839 }
58840 ],
58841 'attrs': [
58842 {
58843 'tfName': 'equation',
58844 'name': 'equation',
58845 'type': 'string'
58846 },
58847 {
58848 'tfName': 'N',
58849 'name': 'n',
58850 'type': 'number',
58851 'defaultValue': 2
58852 },
58853 {
58854 'tfName': 'T',
58855 'name': 'dtype',
58856 'type': 'dtype'
58857 }
58858 ]
58859 }
58860 ];
58861
    // Module-namespace-shaped wrapper for the "matrices" op-mapper table
    // (json$b, defined above); frozen with a null prototype to mimic an
    // ES module namespace object.
    var matrices = /*#__PURE__*/Object.freeze({
        __proto__: null,
        json: json$b
    });
58866
58867 /**
58868 * @license
58869 * Copyright 2022 Google LLC. All Rights Reserved.
58870 * Licensed under the Apache License, Version 2.0 (the "License");
58871 * you may not use this file except in compliance with the License.
58872 * You may obtain a copy of the License at
58873 *
58874 * http://www.apache.org/licenses/LICENSE-2.0
58875 *
58876 * Unless required by applicable law or agreed to in writing, software
58877 * distributed under the License is distributed on an "AS IS" BASIS,
58878 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
58879 * See the License for the specific language governing permissions and
58880 * limitations under the License.
58881 * =============================================================================
58882 */
58883 const json$c = [
58884 {
58885 'tfOpName': 'EuclideanNorm',
58886 'category': 'normalization',
58887 'inputs': [
58888 {
58889 'start': 0,
58890 'name': 'x',
58891 'type': 'tensor'
58892 },
58893 {
58894 'start': 1,
58895 'name': 'axis',
58896 'type': 'number[]'
58897 }
58898 ],
58899 'attrs': [
58900 {
58901 'tfName': 'keep_dims',
58902 'name': 'keepDims',
58903 'type': 'bool',
58904 'defaultValue': false
58905 }
58906 ]
58907 },
58908 {
58909 'tfOpName': 'FusedBatchNorm',
58910 'category': 'normalization',
58911 'inputs': [
58912 {
58913 'start': 0,
58914 'name': 'x',
58915 'type': 'tensor'
58916 },
58917 {
58918 'start': 1,
58919 'name': 'scale',
58920 'type': 'tensor'
58921 },
58922 {
58923 'start': 2,
58924 'name': 'offset',
58925 'type': 'tensor'
58926 },
58927 {
58928 'start': 3,
58929 'name': 'mean',
58930 'type': 'tensor'
58931 },
58932 {
58933 'start': 4,
58934 'name': 'variance',
58935 'type': 'tensor'
58936 }
58937 ],
58938 'attrs': [
58939 {
58940 'tfName': 'epsilon',
58941 'name': 'epsilon',
58942 'type': 'number',
58943 'defaultValue': 0.001
58944 },
58945 {
58946 'tfName': 'data_format',
58947 'name': 'dataFormat',
58948 'type': 'string',
58949 'notSupported': true
58950 }
58951 ]
58952 },
58953 {
58954 'tfOpName': 'FusedBatchNormV2',
58955 'category': 'normalization',
58956 'inputs': [
58957 {
58958 'start': 0,
58959 'name': 'x',
58960 'type': 'tensor'
58961 },
58962 {
58963 'start': 1,
58964 'name': 'scale',
58965 'type': 'tensor'
58966 },
58967 {
58968 'start': 2,
58969 'name': 'offset',
58970 'type': 'tensor'
58971 },
58972 {
58973 'start': 3,
58974 'name': 'mean',
58975 'type': 'tensor'
58976 },
58977 {
58978 'start': 4,
58979 'name': 'variance',
58980 'type': 'tensor'
58981 }
58982 ],
58983 'attrs': [
58984 {
58985 'tfName': 'epsilon',
58986 'name': 'epsilon',
58987 'type': 'number',
58988 'defaultValue': 0.001
58989 },
58990 {
58991 'tfName': 'data_format',
58992 'name': 'dataFormat',
58993 'type': 'string',
58994 'notSupported': true
58995 }
58996 ]
58997 },
58998 {
58999 'tfOpName': 'FusedBatchNormV3',
59000 'category': 'normalization',
59001 'inputs': [
59002 {
59003 'start': 0,
59004 'name': 'x',
59005 'type': 'tensor'
59006 },
59007 {
59008 'start': 1,
59009 'name': 'scale',
59010 'type': 'tensor'
59011 },
59012 {
59013 'start': 2,
59014 'name': 'offset',
59015 'type': 'tensor'
59016 },
59017 {
59018 'start': 3,
59019 'name': 'mean',
59020 'type': 'tensor'
59021 },
59022 {
59023 'start': 4,
59024 'name': 'variance',
59025 'type': 'tensor'
59026 }
59027 ],
59028 'attrs': [
59029 {
59030 'tfName': 'epsilon',
59031 'name': 'epsilon',
59032 'type': 'number',
59033 'defaultValue': 0.001
59034 },
59035 {
59036 'tfName': 'data_format',
59037 'name': 'dataFormat',
59038 'type': 'string',
59039 'notSupported': true
59040 }
59041 ]
59042 },
59043 {
59044 'tfOpName': 'LRN',
59045 'category': 'normalization',
59046 'inputs': [
59047 {
59048 'start': 0,
59049 'name': 'x',
59050 'type': 'tensor'
59051 }
59052 ],
59053 'attrs': [
59054 {
59055 'tfName': 'depth_radius',
59056 'name': 'radius',
59057 'type': 'number',
59058 'defaultValue': 5
59059 },
59060 {
59061 'tfName': 'bias',
59062 'name': 'bias',
59063 'type': 'number',
59064 'defaultValue': 1
59065 },
59066 {
59067 'tfName': 'alpha',
59068 'name': 'alpha',
59069 'type': 'number',
59070 'defaultValue': 1
59071 },
59072 {
59073 'tfName': 'beta',
59074 'name': 'beta',
59075 'type': 'number',
59076 'defaultValue': 0.5
59077 }
59078 ]
59079 },
59080 {
59081 'tfOpName': 'Softmax',
59082 'category': 'normalization',
59083 'inputs': [
59084 {
59085 'start': 0,
59086 'name': 'x',
59087 'type': 'tensor'
59088 }
59089 ]
59090 },
59091 {
59092 'tfOpName': 'LogSoftmax',
59093 'category': 'normalization',
59094 'inputs': [
59095 {
59096 'start': 0,
59097 'name': 'x',
59098 'type': 'tensor'
59099 }
59100 ]
59101 },
59102 {
59103 'tfOpName': 'SparseToDense',
59104 'category': 'normalization',
59105 'inputs': [
59106 {
59107 'start': 0,
59108 'name': 'sparseIndices',
59109 'type': 'tensor'
59110 },
59111 {
59112 'start': 1,
59113 'name': 'outputShape',
59114 'type': 'number[]'
59115 },
59116 {
59117 'start': 2,
59118 'name': 'sparseValues',
59119 'type': 'tensor'
59120 },
59121 {
59122 'start': 3,
59123 'name': 'defaultValue',
59124 'type': 'tensor'
59125 }
59126 ],
59127 'attrs': [
59128 {
59129 'tfName': 'validate_indices',
59130 'name': 'validateIndices',
59131 'type': 'bool',
59132 'defaultValue': true,
59133 'notSupported': true
59134 }
59135 ]
59136 }
59137 ];
59138
    // Module-namespace-shaped wrapper for the "normalization" op-mapper
    // table (json$c, defined above); frozen with a null prototype to mimic
    // an ES module namespace object.
    var normalization = /*#__PURE__*/Object.freeze({
        __proto__: null,
        json: json$c
    });
59143
59144 /**
59145 * @license
59146 * Copyright 2022 Google LLC. All Rights Reserved.
59147 * Licensed under the Apache License, Version 2.0 (the "License");
59148 * you may not use this file except in compliance with the License.
59149 * You may obtain a copy of the License at
59150 *
59151 * http://www.apache.org/licenses/LICENSE-2.0
59152 *
59153 * Unless required by applicable law or agreed to in writing, software
59154 * distributed under the License is distributed on an "AS IS" BASIS,
59155 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
59156 * See the License for the specific language governing permissions and
59157 * limitations under the License.
59158 * =============================================================================
59159 */
59160 const json$d = [
59161 {
59162 'tfOpName': 'Bincount',
59163 'category': 'reduction',
59164 'inputs': [
59165 {
59166 'start': 0,
59167 'name': 'x',
59168 'type': 'tensor'
59169 },
59170 {
59171 'start': 1,
59172 'name': 'size',
59173 'type': 'number'
59174 },
59175 {
59176 'start': 2,
59177 'name': 'weights',
59178 'type': 'tensor'
59179 }
59180 ]
59181 },
59182 {
59183 'tfOpName': 'DenseBincount',
59184 'category': 'reduction',
59185 'inputs': [
59186 {
59187 'start': 0,
59188 'name': 'x',
59189 'type': 'tensor'
59190 },
59191 {
59192 'start': 1,
59193 'name': 'size',
59194 'type': 'number'
59195 },
59196 {
59197 'start': 2,
59198 'name': 'weights',
59199 'type': 'tensor'
59200 }
59201 ],
59202 'attrs': [
59203 {
59204 'tfName': 'binary_output',
59205 'name': 'binaryOutput',
59206 'type': 'bool'
59207 }
59208 ]
59209 },
59210 {
59211 'tfOpName': 'Max',
59212 'category': 'reduction',
59213 'inputs': [
59214 {
59215 'start': 0,
59216 'name': 'x',
59217 'type': 'tensor'
59218 },
59219 {
59220 'start': 1,
59221 'name': 'axis',
59222 'type': 'number[]'
59223 }
59224 ],
59225 'attrs': [
59226 {
59227 'tfName': 'keep_dims',
59228 'name': 'keepDims',
59229 'type': 'bool'
59230 }
59231 ]
59232 },
59233 {
59234 'tfOpName': 'Mean',
59235 'category': 'reduction',
59236 'inputs': [
59237 {
59238 'start': 0,
59239 'name': 'x',
59240 'type': 'tensor'
59241 },
59242 {
59243 'start': 1,
59244 'name': 'axis',
59245 'type': 'number[]'
59246 }
59247 ],
59248 'attrs': [
59249 {
59250 'tfName': 'keep_dims',
59251 'name': 'keepDims',
59252 'type': 'bool'
59253 }
59254 ]
59255 },
59256 {
59257 'tfOpName': 'Min',
59258 'category': 'reduction',
59259 'inputs': [
59260 {
59261 'start': 0,
59262 'name': 'x',
59263 'type': 'tensor'
59264 },
59265 {
59266 'start': 1,
59267 'name': 'axis',
59268 'type': 'number[]'
59269 }
59270 ],
59271 'attrs': [
59272 {
59273 'tfName': 'keep_dims',
59274 'name': 'keepDims',
59275 'type': 'bool'
59276 }
59277 ]
59278 },
59279 {
59280 'tfOpName': 'Sum',
59281 'category': 'reduction',
59282 'inputs': [
59283 {
59284 'start': 0,
59285 'name': 'x',
59286 'type': 'tensor'
59287 },
59288 {
59289 'start': 1,
59290 'name': 'axis',
59291 'type': 'number[]'
59292 }
59293 ],
59294 'attrs': [
59295 {
59296 'tfName': 'keep_dims',
59297 'name': 'keepDims',
59298 'type': 'bool'
59299 }
59300 ]
59301 },
59302 {
59303 'tfOpName': 'All',
59304 'category': 'reduction',
59305 'inputs': [
59306 {
59307 'start': 0,
59308 'name': 'x',
59309 'type': 'tensor'
59310 },
59311 {
59312 'start': 1,
59313 'name': 'axis',
59314 'type': 'number[]'
59315 }
59316 ],
59317 'attrs': [
59318 {
59319 'tfName': 'keep_dims',
59320 'name': 'keepDims',
59321 'type': 'bool'
59322 }
59323 ]
59324 },
59325 {
59326 'tfOpName': 'Any',
59327 'category': 'reduction',
59328 'inputs': [
59329 {
59330 'start': 0,
59331 'name': 'x',
59332 'type': 'tensor'
59333 },
59334 {
59335 'start': 1,
59336 'name': 'axis',
59337 'type': 'number[]'
59338 }
59339 ],
59340 'attrs': [
59341 {
59342 'tfName': 'keep_dims',
59343 'name': 'keepDims',
59344 'type': 'bool'
59345 }
59346 ]
59347 },
59348 {
59349 'tfOpName': 'ArgMax',
59350 'category': 'reduction',
59351 'inputs': [
59352 {
59353 'start': 0,
59354 'name': 'x',
59355 'type': 'tensor'
59356 },
59357 {
59358 'start': 1,
59359 'name': 'axis',
59360 'type': 'number'
59361 }
59362 ]
59363 },
59364 {
59365 'tfOpName': 'ArgMin',
59366 'category': 'reduction',
59367 'inputs': [
59368 {
59369 'start': 0,
59370 'name': 'x',
59371 'type': 'tensor'
59372 },
59373 {
59374 'start': 1,
59375 'name': 'axis',
59376 'type': 'number'
59377 }
59378 ]
59379 },
59380 {
59381 'tfOpName': 'Prod',
59382 'category': 'reduction',
59383 'inputs': [
59384 {
59385 'start': 0,
59386 'name': 'x',
59387 'type': 'tensor'
59388 },
59389 {
59390 'start': 1,
59391 'name': 'axis',
59392 'type': 'number[]'
59393 }
59394 ],
59395 'attrs': [
59396 {
59397 'tfName': 'keep_dims',
59398 'name': 'keepDims',
59399 'type': 'bool'
59400 }
59401 ]
59402 },
59403 {
59404 'tfOpName': 'Cumprod',
59405 'category': 'reduction',
59406 'inputs': [
59407 {
59408 'start': 0,
59409 'name': 'x',
59410 'type': 'tensor'
59411 },
59412 {
59413 'start': 1,
59414 'name': 'axis',
59415 'type': 'number'
59416 }
59417 ],
59418 'attrs': [
59419 {
59420 'tfName': 'exclusive',
59421 'name': 'exclusive',
59422 'type': 'bool'
59423 },
59424 {
59425 'tfName': 'reverse',
59426 'name': 'reverse',
59427 'type': 'bool'
59428 }
59429 ]
59430 },
59431 {
59432 'tfOpName': 'Cumsum',
59433 'category': 'reduction',
59434 'inputs': [
59435 {
59436 'start': 0,
59437 'name': 'x',
59438 'type': 'tensor'
59439 },
59440 {
59441 'start': 1,
59442 'name': 'axis',
59443 'type': 'number'
59444 }
59445 ],
59446 'attrs': [
59447 {
59448 'tfName': 'exclusive',
59449 'name': 'exclusive',
59450 'type': 'bool'
59451 },
59452 {
59453 'tfName': 'reverse',
59454 'name': 'reverse',
59455 'type': 'bool'
59456 }
59457 ]
59458 }
59459 ];
59460
    // Module-namespace-shaped wrapper for the "reduction" op-mapper table
    // (json$d, defined above); frozen with a null prototype to mimic an
    // ES module namespace object.
    var reduction = /*#__PURE__*/Object.freeze({
        __proto__: null,
        json: json$d
    });
59465
59466 /**
59467 * @license
59468 * Copyright 2022 Google LLC. All Rights Reserved.
59469 * Licensed under the Apache License, Version 2.0 (the "License");
59470 * you may not use this file except in compliance with the License.
59471 * You may obtain a copy of the License at
59472 *
59473 * http://www.apache.org/licenses/LICENSE-2.0
59474 *
59475 * Unless required by applicable law or agreed to in writing, software
59476 * distributed under the License is distributed on an "AS IS" BASIS,
59477 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
59478 * See the License for the specific language governing permissions and
59479 * limitations under the License.
59480 * =============================================================================
59481 */
59482 const json$e = [
59483 {
59484 'tfOpName': 'ConcatV2',
59485 'category': 'slice_join',
59486 'inputs': [
59487 {
59488 'start': 0,
59489 'end': -1,
59490 'name': 'tensors',
59491 'type': 'tensors'
59492 },
59493 {
59494 'start': -1,
59495 'name': 'axis',
59496 'type': 'number'
59497 }
59498 ],
59499 'attrs': [
59500 {
59501 'tfName': 'N',
59502 'name': 'n',
59503 'type': 'number',
59504 'defaultValue': 2
59505 }
59506 ]
59507 },
59508 {
59509 'tfOpName': 'Concat',
59510 'category': 'slice_join',
59511 'inputs': [
59512 {
59513 'start': 1,
59514 'end': 0,
59515 'name': 'tensors',
59516 'type': 'tensors'
59517 },
59518 {
59519 'start': 0,
59520 'name': 'axis',
59521 'type': 'number'
59522 }
59523 ],
59524 'attrs': [
59525 {
59526 'tfName': 'N',
59527 'name': 'n',
59528 'type': 'number',
59529 'defaultValue': 2
59530 }
59531 ]
59532 },
59533 {
59534 'tfOpName': 'GatherV2',
59535 'category': 'slice_join',
59536 'inputs': [
59537 {
59538 'start': 0,
59539 'name': 'x',
59540 'type': 'tensor'
59541 },
59542 {
59543 'start': 1,
59544 'name': 'indices',
59545 'type': 'tensor'
59546 },
59547 {
59548 'start': 2,
59549 'name': 'axis',
59550 'type': 'number',
59551 'defaultValue': 0
59552 }
59553 ],
59554 'attrs': [
59555 {
59556 'tfName': 'batch_dims',
59557 'name': 'batchDims',
59558 'type': 'number',
59559 'defaultValue': 0
59560 }
59561 ]
59562 },
59563 {
59564 'tfOpName': 'Gather',
59565 'category': 'slice_join',
59566 'inputs': [
59567 {
59568 'start': 0,
59569 'name': 'x',
59570 'type': 'tensor'
59571 },
59572 {
59573 'start': 1,
59574 'name': 'indices',
59575 'type': 'tensor'
59576 }
59577 ],
59578 'attrs': [
59579 {
59580 'tfName': 'validate_indices',
59581 'name': 'validateIndices',
59582 'type': 'bool',
59583 'notSupported': true
59584 }
59585 ]
59586 },
59587 {
59588 'tfOpName': 'Reverse',
59589 'category': 'slice_join',
59590 'inputs': [
59591 {
59592 'start': 0,
59593 'name': 'x',
59594 'type': 'tensor'
59595 },
59596 {
59597 'start': 1,
59598 'name': 'dims',
59599 'type': 'bool[]'
59600 }
59601 ]
59602 },
59603 {
59604 'tfOpName': 'ReverseV2',
59605 'category': 'slice_join',
59606 'inputs': [
59607 {
59608 'start': 0,
59609 'name': 'x',
59610 'type': 'tensor'
59611 },
59612 {
59613 'start': 1,
59614 'name': 'axis',
59615 'type': 'number[]'
59616 }
59617 ]
59618 },
59619 {
59620 'tfOpName': 'Slice',
59621 'category': 'slice_join',
59622 'inputs': [
59623 {
59624 'start': 0,
59625 'name': 'x',
59626 'type': 'tensor'
59627 },
59628 {
59629 'start': 1,
59630 'name': 'begin',
59631 'type': 'number[]'
59632 },
59633 {
59634 'start': 2,
59635 'name': 'size',
59636 'type': 'number[]'
59637 }
59638 ]
59639 },
59640 {
59641 'tfOpName': 'StridedSlice',
59642 'category': 'slice_join',
59643 'inputs': [
59644 {
59645 'start': 0,
59646 'name': 'x',
59647 'type': 'tensor'
59648 },
59649 {
59650 'start': 1,
59651 'name': 'begin',
59652 'type': 'number[]'
59653 },
59654 {
59655 'start': 2,
59656 'name': 'end',
59657 'type': 'number[]'
59658 },
59659 {
59660 'start': 3,
59661 'name': 'strides',
59662 'type': 'number[]'
59663 }
59664 ],
59665 'attrs': [
59666 {
59667 'tfName': 'begin_mask',
59668 'name': 'beginMask',
59669 'type': 'number',
59670 'defaultValue': 0
59671 },
59672 {
59673 'tfName': 'end_mask',
59674 'name': 'endMask',
59675 'type': 'number',
59676 'defaultValue': 0
59677 },
59678 {
59679 'tfName': 'new_axis_mask',
59680 'name': 'newAxisMask',
59681 'type': 'number',
59682 'defaultValue': 0
59683 },
59684 {
59685 'tfName': 'ellipsis_mask',
59686 'name': 'ellipsisMask',
59687 'type': 'number',
59688 'defaultValue': 0
59689 },
59690 {
59691 'tfName': 'shrink_axis_mask',
59692 'name': 'shrinkAxisMask',
59693 'type': 'number',
59694 'defaultValue': 0
59695 }
59696 ]
59697 },
59698 {
59699 'tfOpName': 'Pack',
59700 'category': 'slice_join',
59701 'inputs': [
59702 {
59703 'start': 0,
59704 'end': 0,
59705 'name': 'tensors',
59706 'type': 'tensors'
59707 }
59708 ],
59709 'attrs': [
59710 {
59711 'tfName': 'axis',
59712 'name': 'axis',
59713 'type': 'number',
59714 'defaultValue': 0
59715 }
59716 ]
59717 },
59718 {
59719 'tfOpName': 'Unpack',
59720 'category': 'slice_join',
59721 'inputs': [
59722 {
59723 'start': 0,
59724 'name': 'tensor',
59725 'type': 'tensor'
59726 }
59727 ],
59728 'attrs': [
59729 {
59730 'tfName': 'axis',
59731 'name': 'axis',
59732 'type': 'number',
59733 'defaultValue': 0
59734 },
59735 {
59736 'tfName': 'num',
59737 'name': 'num',
59738 'type': 'number',
59739 'defaultValue': 0,
59740 'notSupported': true
59741 }
59742 ]
59743 },
59744 {
59745 'tfOpName': 'Tile',
59746 'category': 'slice_join',
59747 'inputs': [
59748 {
59749 'start': 0,
59750 'name': 'x',
59751 'type': 'tensor'
59752 },
59753 {
59754 'start': 1,
59755 'name': 'reps',
59756 'type': 'number[]'
59757 }
59758 ]
59759 },
59760 {
59761 'tfOpName': 'Split',
59762 'category': 'slice_join',
59763 'inputs': [
59764 {
59765 'start': 0,
59766 'name': 'axis',
59767 'type': 'number',
59768 'defaultValue': 0
59769 },
59770 {
59771 'start': 1,
59772 'name': 'x',
59773 'type': 'tensor'
59774 }
59775 ],
59776 'attrs': [
59777 {
59778 'tfName': 'num_split',
59779 'name': 'numOrSizeSplits',
59780 'type': 'number',
59781 'defaultValue': 1
59782 }
59783 ]
59784 },
59785 {
59786 'tfOpName': 'SplitV',
59787 'category': 'slice_join',
59788 'inputs': [
59789 {
59790 'start': 0,
59791 'name': 'x',
59792 'type': 'tensor'
59793 },
59794 {
59795 'start': 1,
59796 'name': 'numOrSizeSplits',
59797 'type': 'number[]'
59798 },
59799 {
59800 'start': 2,
59801 'name': 'axis',
59802 'type': 'number',
59803 'defaultValue': 0
59804 }
59805 ]
59806 },
59807 {
59808 'tfOpName': 'ScatterNd',
59809 'category': 'slice_join',
59810 'inputs': [
59811 {
59812 'start': 0,
59813 'name': 'indices',
59814 'type': 'tensor'
59815 },
59816 {
59817 'start': 1,
59818 'name': 'values',
59819 'type': 'tensor'
59820 },
59821 {
59822 'start': 2,
59823 'name': 'shape',
59824 'type': 'number[]'
59825 }
59826 ]
59827 },
59828 {
59829 'tfOpName': 'GatherNd',
59830 'category': 'slice_join',
59831 'inputs': [
59832 {
59833 'start': 0,
59834 'name': 'x',
59835 'type': 'tensor'
59836 },
59837 {
59838 'start': 1,
59839 'name': 'indices',
59840 'type': 'tensor'
59841 }
59842 ]
59843 },
59844 {
59845 'tfOpName': 'SparseToDense',
59846 'category': 'slice_join',
59847 'inputs': [
59848 {
59849 'start': 0,
59850 'name': 'sparseIndices',
59851 'type': 'tensor'
59852 },
59853 {
59854 'start': 1,
59855 'name': 'outputShape',
59856 'type': 'number[]'
59857 },
59858 {
59859 'start': 2,
59860 'name': 'sparseValues',
59861 'type': 'tensor'
59862 },
59863 {
59864 'start': 3,
59865 'name': 'defaultValue',
59866 'type': 'tensor'
59867 }
59868 ],
59869 'attrs': [
59870 {
59871 'tfName': 'validate_indices',
59872 'name': 'validateIndices',
59873 'type': 'bool',
59874 'defaultValue': false,
59875 'notSupported': true
59876 }
59877 ]
59878 }
59879 ];
59880
    // Module-namespace-shaped wrapper for the "slice_join" op-mapper table
    // (json$e, defined above); frozen with a null prototype to mimic an
    // ES module namespace object.
    var sliceJoin = /*#__PURE__*/Object.freeze({
        __proto__: null,
        json: json$e
    });
59885
59886 /**
59887 * @license
59888 * Copyright 2022 Google LLC. All Rights Reserved.
59889 * Licensed under the Apache License, Version 2.0 (the "License");
59890 * you may not use this file except in compliance with the License.
59891 * You may obtain a copy of the License at
59892 *
59893 * http://www.apache.org/licenses/LICENSE-2.0
59894 *
59895 * Unless required by applicable law or agreed to in writing, software
59896 * distributed under the License is distributed on an "AS IS" BASIS,
59897 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
59898 * See the License for the specific language governing permissions and
59899 * limitations under the License.
59900 * =============================================================================
59901 */
59902 const json$f = [
59903 {
59904 'tfOpName': 'SparseFillEmptyRows',
59905 'category': 'sparse',
59906 'inputs': [
59907 {
59908 'start': 0,
59909 'name': 'indices',
59910 'type': 'tensor'
59911 },
59912 {
59913 'start': 1,
59914 'name': 'values',
59915 'type': 'tensor'
59916 },
59917 {
59918 'start': 2,
59919 'name': 'denseShape',
59920 'type': 'tensor'
59921 },
59922 {
59923 'start': 3,
59924 'name': 'defaultValue',
59925 'type': 'tensor'
59926 }
59927 ]
59928 },
59929 {
59930 'tfOpName': 'SparseReshape',
59931 'category': 'sparse',
59932 'inputs': [
59933 {
59934 'start': 0,
59935 'name': 'inputIndices',
59936 'type': 'tensor'
59937 },
59938 {
59939 'start': 1,
59940 'name': 'inputShape',
59941 'type': 'tensor'
59942 },
59943 {
59944 'start': 2,
59945 'name': 'newShape',
59946 'type': 'tensor'
59947 }
59948 ],
59949 'attrs': [
59950 {
59951 'tfName': 'T',
59952 'name': 'dtype',
59953 'type': 'dtype',
59954 'notSupported': true
59955 }
59956 ]
59957 },
59958 {
59959 'tfOpName': 'SparseSegmentMean',
59960 'category': 'sparse',
59961 'inputs': [
59962 {
59963 'start': 0,
59964 'name': 'data',
59965 'type': 'tensor'
59966 },
59967 {
59968 'start': 1,
59969 'name': 'indices',
59970 'type': 'tensor'
59971 },
59972 {
59973 'start': 2,
59974 'name': 'segmentIds',
59975 'type': 'tensor'
59976 }
59977 ]
59978 },
59979 {
59980 'tfOpName': 'SparseSegmentSum',
59981 'category': 'sparse',
59982 'inputs': [
59983 {
59984 'start': 0,
59985 'name': 'data',
59986 'type': 'tensor'
59987 },
59988 {
59989 'start': 1,
59990 'name': 'indices',
59991 'type': 'tensor'
59992 },
59993 {
59994 'start': 2,
59995 'name': 'segmentIds',
59996 'type': 'tensor'
59997 }
59998 ]
59999 }
60000 ];
60001
    // Module-namespace-shaped wrapper for the "sparse" op-mapper table
    // (json$f, defined above); frozen with a null prototype to mimic an
    // ES module namespace object.
    var sparse$1 = /*#__PURE__*/Object.freeze({
        __proto__: null,
        json: json$f
    });
60006
60007 /**
60008 * @license
60009 * Copyright 2022 Google LLC. All Rights Reserved.
60010 * Licensed under the Apache License, Version 2.0 (the "License");
60011 * you may not use this file except in compliance with the License.
60012 * You may obtain a copy of the License at
60013 *
60014 * http://www.apache.org/licenses/LICENSE-2.0
60015 *
60016 * Unless required by applicable law or agreed to in writing, software
60017 * distributed under the License is distributed on an "AS IS" BASIS,
60018 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
60019 * See the License for the specific language governing permissions and
60020 * limitations under the License.
60021 * =============================================================================
60022 */
60023 const json$g = [
60024 {
60025 'tfOpName': 'FFT',
60026 'category': 'spectral',
60027 'inputs': [
60028 {
60029 'start': 0,
60030 'name': 'x',
60031 'type': 'tensor'
60032 }
60033 ]
60034 },
60035 {
60036 'tfOpName': 'IFFT',
60037 'category': 'spectral',
60038 'inputs': [
60039 {
60040 'start': 0,
60041 'name': 'x',
60042 'type': 'tensor'
60043 }
60044 ]
60045 },
60046 {
60047 'tfOpName': 'RFFT',
60048 'category': 'spectral',
60049 'inputs': [
60050 {
60051 'start': 0,
60052 'name': 'x',
60053 'type': 'tensor'
60054 },
60055 {
60056 'start': 1,
60057 'name': 'fft_length',
60058 'type': 'number',
60059 'notSupported': true
60060 }
60061 ]
60062 },
60063 {
60064 'tfOpName': 'IRFFT',
60065 'category': 'spectral',
60066 'inputs': [
60067 {
60068 'start': 0,
60069 'name': 'x',
60070 'type': 'tensor'
60071 },
60072 {
60073 'start': 1,
60074 'name': 'fft_length',
60075 'type': 'number',
60076 'notSupported': true
60077 }
60078 ]
60079 }
60080 ];
60081
    // Module-namespace-shaped wrapper for the "spectral" op-mapper table
    // (json$g, defined above); frozen with a null prototype to mimic an
    // ES module namespace object.
    var spectral$1 = /*#__PURE__*/Object.freeze({
        __proto__: null,
        json: json$g
    });
60086
60087 /**
60088 * @license
60089 * Copyright 2022 Google LLC. All Rights Reserved.
60090 * Licensed under the Apache License, Version 2.0 (the "License");
60091 * you may not use this file except in compliance with the License.
60092 * You may obtain a copy of the License at
60093 *
60094 * http://www.apache.org/licenses/LICENSE-2.0
60095 *
60096 * Unless required by applicable law or agreed to in writing, software
60097 * distributed under the License is distributed on an "AS IS" BASIS,
60098 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
60099 * See the License for the specific language governing permissions and
60100 * limitations under the License.
60101 * =============================================================================
60102 */
60103 const json$h = [
60104 {
60105 'tfOpName': 'StringNGrams',
60106 'category': 'string',
60107 'inputs': [
60108 {
60109 'start': 0,
60110 'name': 'data',
60111 'type': 'tensor'
60112 },
60113 {
60114 'start': 1,
60115 'name': 'dataSplits',
60116 'type': 'tensor'
60117 }
60118 ],
60119 'attrs': [
60120 {
60121 'tfName': 'separator',
60122 'name': 'separator',
60123 'type': 'string'
60124 },
60125 {
60126 'tfName': 'ngram_widths',
60127 'name': 'nGramWidths',
60128 'type': 'number[]'
60129 },
60130 {
60131 'tfName': 'left_pad',
60132 'name': 'leftPad',
60133 'type': 'string'
60134 },
60135 {
60136 'tfName': 'right_pad',
60137 'name': 'rightPad',
60138 'type': 'string'
60139 },
60140 {
60141 'tfName': 'pad_width',
60142 'name': 'padWidth',
60143 'type': 'number'
60144 },
60145 {
60146 'tfName': 'preserve_short_sequences',
60147 'name': 'preserveShortSequences',
60148 'type': 'bool'
60149 }
60150 ],
60151 'outputs': [
60152 'ngrams',
60153 'ngrams_splits'
60154 ]
60155 },
60156 {
60157 'tfOpName': 'StringSplit',
60158 'category': 'string',
60159 'inputs': [
60160 {
60161 'start': 0,
60162 'name': 'input',
60163 'type': 'tensor'
60164 },
60165 {
60166 'start': 1,
60167 'name': 'delimiter',
60168 'type': 'tensor'
60169 }
60170 ],
60171 'attrs': [
60172 {
60173 'tfName': 'skip_empty',
60174 'name': 'skipEmpty',
60175 'type': 'bool'
60176 }
60177 ],
60178 'outputs': [
60179 'indices',
60180 'values',
60181 'shape'
60182 ]
60183 },
60184 {
60185 'tfOpName': 'StringToHashBucketFast',
60186 'category': 'string',
60187 'inputs': [
60188 {
60189 'start': 0,
60190 'name': 'input',
60191 'type': 'tensor'
60192 }
60193 ],
60194 'attrs': [
60195 {
60196 'tfName': 'num_buckets',
60197 'name': 'numBuckets',
60198 'type': 'number'
60199 }
60200 ]
60201 }
60202 ];
60203
    // Module-namespace-shaped wrapper for the "string" op-mapper table
    // (json$h, defined above); frozen with a null prototype to mimic an
    // ES module namespace object.
    var string$1 = /*#__PURE__*/Object.freeze({
        __proto__: null,
        json: json$h
    });
60208
60209 /**
60210 * @license
60211 * Copyright 2022 Google LLC. All Rights Reserved.
60212 * Licensed under the Apache License, Version 2.0 (the "License");
60213 * you may not use this file except in compliance with the License.
60214 * You may obtain a copy of the License at
60215 *
60216 * http://www.apache.org/licenses/LICENSE-2.0
60217 *
60218 * Unless required by applicable law or agreed to in writing, software
60219 * distributed under the License is distributed on an "AS IS" BASIS,
60220 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
60221 * See the License for the specific language governing permissions and
60222 * limitations under the License.
60223 * =============================================================================
60224 */
/**
 * Op-mapper definitions for the 'transformation' category of TensorFlow ops
 * (Cast, ExpandDims, pad variants, Reshape, Squeeze, space/batch/depth
 * rearrangements, and broadcast ops). Each entry maps a TF GraphDef op name
 * ('tfOpName') to the positional inputs ('inputs') and attribute conversions
 * ('attrs') that OperationMapper.mapNode() uses to build inputParams and
 * attrParams for the node.
 */
const json$i = [
    {
        'tfOpName': 'Cast',
        'category': 'transformation',
        'inputs': [
            {
                'start': 0,
                'name': 'x',
                'type': 'tensor'
            }
        ],
        'attrs': [
            {
                'tfName': 'SrcT',
                'name': 'sdtype',
                'type': 'dtype',
                // Source dtype is implied by the input tensor; not consumed.
                'notSupported': true
            },
            {
                'tfName': 'DstT',
                'name': 'dtype',
                'type': 'dtype'
            }
        ]
    },
    {
        'tfOpName': 'ExpandDims',
        'category': 'transformation',
        'inputs': [
            {
                'start': 0,
                'name': 'x',
                'type': 'tensor'
            },
            {
                'start': 1,
                'name': 'axis',
                'type': 'number'
            }
        ]
    },
    {
        'tfOpName': 'MirrorPad',
        'category': 'transformation',
        'inputs': [
            {
                'start': 0,
                'name': 'x',
                'type': 'tensor'
            },
            {
                'start': 1,
                'name': 'padding',
                'type': 'number[]'
            }
        ],
        'attrs': [
            {
                'tfName': 'mode',
                'name': 'mode',
                'type': 'string'
            }
        ]
    },
    {
        'tfOpName': 'Pad',
        'category': 'transformation',
        'inputs': [
            {
                'start': 0,
                'name': 'x',
                'type': 'tensor'
            },
            {
                'start': 1,
                'name': 'padding',
                'type': 'number[]'
            }
        ],
        'attrs': [
            {
                'tfName': 'constant_value',
                'name': 'constantValue',
                'type': 'number',
                'defaultValue': 0
            }
        ]
    },
    {
        // Like 'Pad' but takes the fill value as a third input instead of an
        // attribute.
        'tfOpName': 'PadV2',
        'category': 'transformation',
        'inputs': [
            {
                'start': 0,
                'name': 'x',
                'type': 'tensor'
            },
            {
                'start': 1,
                'name': 'padding',
                'type': 'number[]'
            },
            {
                'start': 2,
                'name': 'constantValue',
                'type': 'number',
                'defaultValue': 0
            }
        ]
    },
    {
        'tfOpName': 'Reshape',
        'category': 'transformation',
        'inputs': [
            {
                'start': 0,
                'name': 'x',
                'type': 'tensor'
            },
            {
                'start': 1,
                'name': 'shape',
                'type': 'number[]'
            }
        ]
    },
    {
        'tfOpName': 'Squeeze',
        'category': 'transformation',
        'inputs': [
            {
                'start': 0,
                'name': 'x',
                'type': 'tensor'
            }
        ],
        'attrs': [
            {
                'tfName': 'axis',
                // Older graphs emit 'squeeze_dims' for the same attribute.
                'tfDeprecatedName': 'squeeze_dims',
                'name': 'axis',
                'type': 'number[]'
            }
        ]
    },
    {
        'tfOpName': 'SpaceToBatchND',
        'category': 'transformation',
        'inputs': [
            {
                'start': 0,
                'name': 'x',
                'type': 'tensor'
            },
            {
                'start': 1,
                'name': 'blockShape',
                'type': 'number[]'
            },
            {
                'start': 2,
                'name': 'paddings',
                'type': 'number[]'
            }
        ]
    },
    {
        'tfOpName': 'BatchToSpaceND',
        'category': 'transformation',
        'inputs': [
            {
                'start': 0,
                'name': 'x',
                'type': 'tensor'
            },
            {
                'start': 1,
                'name': 'blockShape',
                'type': 'number[]'
            },
            {
                'start': 2,
                'name': 'crops',
                'type': 'number[]'
            }
        ]
    },
    {
        'tfOpName': 'DepthToSpace',
        'category': 'transformation',
        'inputs': [
            {
                'start': 0,
                'name': 'x',
                'type': 'tensor'
            }
        ],
        'attrs': [
            {
                'tfName': 'block_size',
                'name': 'blockSize',
                'type': 'number'
            },
            {
                'tfName': 'data_format',
                'name': 'dataFormat',
                'type': 'string'
            }
        ]
    },
    {
        'tfOpName': 'BroadcastTo',
        'category': 'transformation',
        'inputs': [
            {
                'start': 0,
                'name': 'x',
                'type': 'tensor'
            },
            {
                'start': 1,
                'name': 'shape',
                'type': 'number[]'
            }
        ],
        'attrs': []
    },
    {
        'tfOpName': 'BroadcastArgs',
        'category': 'transformation',
        'inputs': [
            {
                'start': 0,
                'name': 's0',
                'type': 'tensor'
            },
            {
                'start': 1,
                'name': 's1',
                'type': 'tensor'
            }
        ],
        'attrs': []
    }
];

// Frozen module-shaped wrapper so this category matches the other op modules
// (arithmetic, basicMath, ...) consumed by OperationMapper's constructor.
var transformation = /*#__PURE__*/Object.freeze({
    __proto__: null,
    json: json$i
});
60475
60476 /**
60477 * @license
60478 * Copyright 2018 Google LLC. All Rights Reserved.
60479 * Licensed under the Apache License, Version 2.0 (the "License");
60480 * you may not use this file except in compliance with the License.
60481 * You may obtain a copy of the License at
60482 *
60483 * http://www.apache.org/licenses/LICENSE-2.0
60484 *
60485 * Unless required by applicable law or agreed to in writing, software
60486 * distributed under the License is distributed on an "AS IS" BASIS,
60487 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
60488 * See the License for the specific language governing permissions and
60489 * limitations under the License.
60490 * =============================================================================
60491 */
/**
 * Translates a TensorFlow GraphDef (and its function library) into the local
 * graph representation used by the TensorFlow.js graph executor: a map of
 * nodes with resolved input/attribute params plus input, output, weight and
 * placeholder node lists.
 */
class OperationMapper {
    // Singleton instance for the mapper
    static get Instance() {
        return this._instance || (this._instance = new this());
    }
    // Loads the op mapping from the JSON file.
    constructor() {
        const ops = [
            arithmetic, basicMath, control, convolution, creation, dynamic,
            evaluation, graph, hashTable, image$1, logical, matrices, normalization,
            reduction, sliceJoin, sparse$1, spectral$1, string$1, transformation
        ];
        // Flatten all per-category mapper lists and index them by TF op name
        // for O(1) lookup in mapNode().
        const mappersJson = [].concat(...ops.map(op => op.json));
        this.opMappers = mappersJson.reduce((map, mapper) => {
            map[mapper.tfOpName] = mapper;
            return map;
        }, {});
    }
    // Converts the model inference graph from Tensorflow GraphDef to local
    // representation for TensorFlow.js API
    transformGraph(graph, signature = {}) {
        const tfNodes = graph.node;
        const placeholders = [];
        const weights = [];
        const initNodes = [];
        // First pass: map each TF node and bucket it as a placeholder, a
        // weight (Const) or an init node (an op with no inputs).
        const nodes = tfNodes.reduce((map, node) => {
            map[node.name] = this.mapNode(node);
            if (node.op.startsWith('Placeholder')) {
                placeholders.push(map[node.name]);
            }
            else if (node.op === 'Const') {
                weights.push(map[node.name]);
            }
            else if (node.input == null || node.input.length === 0) {
                initNodes.push(map[node.name]);
            }
            return map;
        }, {});
        let inputs = [];
        const outputs = [];
        let inputNodeNameToKey = {};
        let outputNodeNameToKey = {};
        if (signature != null) {
            inputNodeNameToKey = this.mapSignatureEntries(signature.inputs);
            outputNodeNameToKey = this.mapSignatureEntries(signature.outputs);
        }
        const allNodes = Object.keys(nodes);
        // Second pass: wire up parent/child links. Input names that refer to
        // a named output of a multi-output op are rewritten to the canonical
        // `nodeName:outputIndex` form.
        allNodes.forEach(key => {
            const node = nodes[key];
            node.inputNames.forEach((name, index) => {
                const [nodeName, , outputName] = getNodeNameAndIndex(name);
                const inputNode = nodes[nodeName];
                if (inputNode.outputs != null) {
                    const outputIndex = inputNode.outputs.indexOf(outputName);
                    if (outputIndex !== -1) {
                        const inputName = `${nodeName}:${outputIndex}`;
                        // update the input name to use the mapped output index directly.
                        node.inputNames[index] = inputName;
                    }
                }
                node.inputs.push(inputNode);
                inputNode.children.push(node);
            });
        });
        // if signature has not outputs set, add any node that does not have
        // outputs.
        if (Object.keys(outputNodeNameToKey).length === 0) {
            allNodes.forEach(key => {
                const node = nodes[key];
                if (node.children.length === 0) {
                    outputs.push(node);
                }
            });
        }
        else {
            // Otherwise resolve the signature's declared outputs, tagging each
            // node with its signature key.
            Object.keys(outputNodeNameToKey).forEach(name => {
                const [nodeName,] = getNodeNameAndIndex(name);
                const node = nodes[nodeName];
                if (node != null) {
                    node.signatureKey = outputNodeNameToKey[name];
                    outputs.push(node);
                }
            });
        }
        // Inputs come from the signature when present, otherwise fall back to
        // all Placeholder nodes.
        if (Object.keys(inputNodeNameToKey).length > 0) {
            Object.keys(inputNodeNameToKey).forEach(name => {
                const [nodeName,] = getNodeNameAndIndex(name);
                const node = nodes[nodeName];
                if (node) {
                    node.signatureKey = inputNodeNameToKey[name];
                    inputs.push(node);
                }
            });
        }
        else {
            inputs = placeholders;
        }
        // Map the graph's function library (used by control-flow/function ops).
        let functions = {};
        if (graph.library != null && graph.library.function != null) {
            functions = graph.library.function.reduce((functions, func) => {
                functions[func.signature.name] = this.mapFunction(func);
                return functions;
            }, {});
        }
        const result = { nodes, inputs, outputs, weights, placeholders, signature, functions };
        // initNodes is optional in the result shape; only attach when present.
        if (initNodes.length > 0) {
            result.initNodes = initNodes;
        }
        return result;
    }
    // Inverts a signature entry map into: tensor name -> signature key.
    mapSignatureEntries(entries) {
        return Object.keys(entries || {})
            .reduce((prev, curr) => {
            prev[entries[curr].name] = curr;
            return prev;
        }, {});
    }
    // Maps a single TF NodeDef into the local node representation, resolving
    // input positions and converting raw attr protos to typed values.
    mapNode(node) {
        // Unsupported ops will cause an error at run-time (not parse time), since
        // they may not be used by the actual execution subgraph.
        const mapper = getRegisteredOp(node.op) || this.opMappers[node.op] || {};
        if (node.attr == null) {
            node.attr = {};
        }
        const newNode = {
            name: node.name,
            op: node.op,
            category: mapper.category,
            // '^name' denotes a control-dependency input; strip the marker.
            inputNames: (node.input ||
                []).map(input => input.startsWith('^') ? input.slice(1) : input),
            inputs: [],
            children: [],
            inputParams: {},
            attrParams: {},
            rawAttrs: node.attr,
            outputs: mapper.outputs
        };
        if (mapper.inputs != null) {
            newNode.inputParams =
                mapper.inputs.reduce((map, param) => {
                    map[param.name] = {
                        type: param.type,
                        inputIndexStart: param.start,
                        inputIndexEnd: param.end
                    };
                    return map;
                }, {});
        }
        if (mapper.attrs != null) {
            newNode.attrParams =
                mapper.attrs.reduce((map, param) => {
                    const type = param.type;
                    let value = undefined;
                    // For each attr type, try the current TF attr name first,
                    // then the deprecated name used by older graphs.
                    switch (param.type) {
                        case 'string':
                            value = getStringParam(node.attr, param.tfName, param.defaultValue);
                            if (value === undefined && !!param.tfDeprecatedName) {
                                value = getStringParam(node.attr, param.tfDeprecatedName, param.defaultValue);
                            }
                            break;
                        case 'string[]':
                            value = getStringArrayParam(node.attr, param.tfName, param.defaultValue);
                            if (value === undefined && !!param.tfDeprecatedName) {
                                value = getStringArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
                            }
                            break;
                        case 'number':
                            value = getNumberParam(node.attr, param.tfName, (param.defaultValue || 0));
                            if (value === undefined && !!param.tfDeprecatedName) {
                                value = getNumberParam(node.attr, param.tfDeprecatedName, param.defaultValue);
                            }
                            break;
                        case 'number[]':
                            value = getNumericArrayParam(node.attr, param.tfName, param.defaultValue);
                            if (value === undefined && !!param.tfDeprecatedName) {
                                value = getNumericArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
                            }
                            break;
                        case 'bool':
                            value = getBoolParam(node.attr, param.tfName, param.defaultValue);
                            if (value === undefined && !!param.tfDeprecatedName) {
                                value = getBoolParam(node.attr, param.tfDeprecatedName, param.defaultValue);
                            }
                            break;
                        case 'bool[]':
                            value = getBoolArrayParam(node.attr, param.tfName, param.defaultValue);
                            if (value === undefined && !!param.tfDeprecatedName) {
                                value = getBoolArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
                            }
                            break;
                        case 'shape':
                            value = getTensorShapeParam(node.attr, param.tfName, param.defaultValue);
                            if (value === undefined && !!param.tfDeprecatedName) {
                                value = getTensorShapeParam(node.attr, param.tfDeprecatedName, param.defaultValue);
                            }
                            break;
                        case 'shape[]':
                            value = getTensorShapeArrayParam(node.attr, param.tfName, param.defaultValue);
                            if (value === undefined && !!param.tfDeprecatedName) {
                                value = getTensorShapeArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
                            }
                            break;
                        case 'dtype':
                            value = getDtypeParam(node.attr, param.tfName, param.defaultValue);
                            if (value === undefined && !!param.tfDeprecatedName) {
                                value = getDtypeParam(node.attr, param.tfDeprecatedName, param.defaultValue);
                            }
                            break;
                        case 'dtype[]':
                            value = getDtypeArrayParam(node.attr, param.tfName, param.defaultValue);
                            if (value === undefined && !!param.tfDeprecatedName) {
                                value = getDtypeArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
                            }
                            break;
                        case 'func':
                            value = getFuncParam(node.attr, param.tfName, param.defaultValue);
                            if (value === undefined && !!param.tfDeprecatedName) {
                                value = getFuncParam(node.attr, param.tfDeprecatedName, param.defaultValue);
                            }
                            break;
                        case 'tensor':
                        case 'tensors':
                            // Tensor-valued attrs are resolved at execution time.
                            break;
                        default:
                            throw new Error(`Unsupported param type: ${param.type} for op: ${node.op}`);
                    }
                    map[param.name] = { value, type };
                    return map;
                }, {});
        }
        return newNode;
    }
    // map the TFunctionDef to TFJS graph object
    mapFunction(functionDef) {
        const tfNodes = functionDef.nodeDef;
        const placeholders = [];
        const weights = [];
        let nodes = {};
        if (tfNodes != null) {
            nodes = tfNodes.reduce((map, node) => {
                map[node.name] = this.mapNode(node);
                if (node.op === 'Const') {
                    weights.push(map[node.name]);
                }
                return map;
            }, {});
        }
        const inputs = [];
        const outputs = [];
        // Each signature input arg becomes a synthetic Placeholder node.
        functionDef.signature.inputArg.forEach(arg => {
            const [nodeName,] = getNodeNameAndIndex(arg.name);
            const node = {
                name: nodeName,
                op: 'Placeholder',
                inputs: [],
                inputNames: [],
                category: 'graph',
                inputParams: {},
                attrParams: { dtype: { value: parseDtypeParam(arg.type), type: 'dtype' } },
                children: []
            };
            node.signatureKey = arg.name;
            inputs.push(node);
            nodes[nodeName] = node;
        });
        const allNodes = Object.keys(nodes);
        // Wire parent/child links — same logic as in transformGraph().
        allNodes.forEach(key => {
            const node = nodes[key];
            node.inputNames.forEach((name, index) => {
                const [nodeName, , outputName] = getNodeNameAndIndex(name);
                const inputNode = nodes[nodeName];
                if (inputNode.outputs != null) {
                    const outputIndex = inputNode.outputs.indexOf(outputName);
                    if (outputIndex !== -1) {
                        const inputName = `${nodeName}:${outputIndex}`;
                        // update the input name to use the mapped output index directly.
                        node.inputNames[index] = inputName;
                    }
                }
                node.inputs.push(inputNode);
                inputNode.children.push(node);
            });
        });
        // Resolve the function outputs through the ret map (output arg name ->
        // producing node name).
        const returnNodeMap = functionDef.ret;
        functionDef.signature.outputArg.forEach(output => {
            const [nodeName, index] = getNodeNameAndIndex(returnNodeMap[output.name]);
            const node = nodes[nodeName];
            if (node != null) {
                node.defaultOutput = index;
                outputs.push(node);
            }
        });
        const signature = this.mapArgsToSignature(functionDef);
        return { nodes, inputs, outputs, weights, placeholders, signature };
    }
    // Builds a SignatureDef-shaped object from a FunctionDef's args.
    mapArgsToSignature(functionDef) {
        return {
            methodName: functionDef.signature.name,
            inputs: functionDef.signature.inputArg.reduce((map, arg) => {
                map[arg.name] = this.mapArgToTensorInfo(arg);
                return map;
            }, {}),
            outputs: functionDef.signature.outputArg.reduce((map, arg) => {
                map[arg.name] = this.mapArgToTensorInfo(arg, functionDef.ret);
                return map;
            }, {}),
        };
    }
    // Converts an arg def to a tensor info, optionally renaming through the
    // function's ret map.
    mapArgToTensorInfo(arg, nameMap) {
        let name = arg.name;
        if (nameMap != null) {
            name = nameMap[name];
        }
        return { name, dtype: arg.type };
    }
}
/**
 * Decodes a base64 string to its decoded string form.
 *
 * Prefers the environment-provided `atob` (browsers); otherwise falls back to
 * Node's `Buffer`.
 *
 * @param {string} text Base64-encoded input.
 * @returns {string} The decoded string.
 * @throws {Error} If neither `atob` nor `Buffer` exists in this environment.
 */
function decodeBase64(text) {
    const global = env().global;
    if (typeof global.atob !== 'undefined') {
        return global.atob(text);
    }
    else if (typeof Buffer !== 'undefined') {
        // Fix: `new Buffer(...)` is deprecated (Node DEP0005) and emits a
        // runtime warning; Buffer.from is the supported, safe replacement
        // with identical decode behavior.
        return Buffer.from(text, 'base64').toString();
    }
    else {
        throw new Error('Unable to decode base64 in this environment. ' +
            'Missing built-in atob() or Buffer()');
    }
}
/**
 * Converts a raw TF string attribute payload to a JS string.
 * Byte arrays are decoded as char codes; anything else is treated as a
 * base64-encoded string.
 * @param {number[]|string} s Raw attribute payload.
 * @param {boolean} keepCase Whether to preserve the original casing.
 * @returns {string} Decoded (and possibly lower-cased) string.
 */
function parseStringParam(s, keepCase) {
    let decoded;
    if (Array.isArray(s)) {
        decoded = String.fromCharCode.apply(null, s);
    }
    else {
        decoded = decodeBase64(s);
    }
    if (keepCase) {
        return decoded;
    }
    return decoded.toLowerCase();
}
/**
 * Reads the string attribute `name` from `attrs`.
 * @returns {string} The parsed attribute value, or `def` when absent.
 */
function getStringParam(attrs, name, def, keepCase = false) {
    const param = attrs[name];
    if (param == null) {
        return def;
    }
    return parseStringParam(param.s, keepCase);
}
/**
 * Reads the boolean attribute `name` from `attrs`.
 * @returns {boolean} The attribute's `b` field, or `def` when absent.
 */
function getBoolParam(attrs, name, def) {
    const param = attrs[name];
    if (param) {
        return param.b;
    }
    return def;
}
/**
 * Reads the numeric attribute `name` from `attrs`, preferring the int field
 * ('i') over the float field ('f'); falls back to `def` when neither is set.
 * Int64 values arrive as strings and are converted to JS numbers.
 * @returns {number} The attribute value as a number.
 */
function getNumberParam(attrs, name, def) {
    const param = attrs[name] || {};
    let value;
    if (param['i'] != null) {
        value = param['i'];
    }
    else if (param['f'] != null) {
        value = param['f'];
    }
    else {
        value = def;
    }
    if (typeof value === 'number') {
        return value;
    }
    return parseInt(value, 10);
}
/**
 * Maps a TF DataType enum value (or its string name) to the tfjs dtype name.
 * @returns {string|null} 'float32' | 'int32' | 'bool' | 'string', or null
 *     for an unrecognized dtype.
 */
function parseDtypeParam(value) {
    if (typeof (value) === 'string') {
        // tslint:disable-next-line:no-any
        value = DataType[value];
    }
    // Float-like (including half and double) dtypes all map to float32.
    if (value === DataType.DT_FLOAT || value === DataType.DT_HALF ||
        value === DataType.DT_DOUBLE) {
        return 'float32';
    }
    // All integer widths map to int32, the only integer dtype in tfjs.
    if (value === DataType.DT_INT32 || value === DataType.DT_INT64 ||
        value === DataType.DT_INT8 || value === DataType.DT_UINT8) {
        return 'int32';
    }
    if (value === DataType.DT_BOOL) {
        return 'bool';
    }
    if (value === DataType.DT_STRING) {
        return 'string';
    }
    // Unknown dtype error will happen at runtime (instead of parse time),
    // since these nodes might not be used by the actual subgraph execution.
    return null;
}
/**
 * Reads a function-valued attribute `name` from `attrs`.
 * @returns {string} The referenced function's name, or `def` when absent.
 */
function getFuncParam(attrs, name, def) {
    const param = attrs[name];
    return (param && param.func) ? param.func.name : def;
}
/**
 * Reads a dtype attribute `name` from `attrs` and converts it to the tfjs
 * dtype name.
 * @returns {string} The parsed dtype, or `def` when absent.
 */
function getDtypeParam(attrs, name, def) {
    const param = attrs[name];
    return (param && param.type) ? parseDtypeParam(param.type) : def;
}
/**
 * Reads a dtype-list attribute `name` from `attrs` and converts each entry
 * to the tfjs dtype name.
 * @returns {string[]} The parsed dtypes, or `def` when absent.
 */
function getDtypeArrayParam(attrs, name, def) {
    const param = attrs[name];
    if (!(param && param.list && param.list.type)) {
        return def;
    }
    return param.list.type.map(v => parseDtypeParam(v));
}
/**
 * Converts a TensorShapeProto to a plain number array.
 * @returns {number[]|undefined} The dims (int64 string sizes converted to
 *     numbers), [] for a rank-0/dimless shape, or undefined for unknown rank.
 */
function parseTensorShapeParam(shape) {
    if (shape.unknownRank) {
        return undefined;
    }
    if (shape.dim == null) {
        return [];
    }
    return shape.dim.map(dim => typeof dim.size === 'number' ? dim.size : parseInt(dim.size, 10));
}
/**
 * Reads a shape attribute `name` from `attrs`.
 * @returns {number[]|undefined} The parsed shape, or `def` when absent.
 */
function getTensorShapeParam(attrs, name, def) {
    const param = attrs[name];
    return (param && param.shape) ? parseTensorShapeParam(param.shape) : def;
}
/**
 * Reads a numeric-list attribute `name` from `attrs`, preferring the float
 * list ('f') when it is non-empty, otherwise the int list ('i'). Int64
 * entries arriving as strings are converted to numbers.
 * @returns {number[]} The parsed list, or `def` when the attr is absent.
 */
function getNumericArrayParam(attrs, name, def) {
    const param = attrs[name];
    if (!param) {
        return def;
    }
    const useFloats = param.list.f && param.list.f.length;
    const raw = (useFloats ? param.list.f : param.list.i) || [];
    return raw.map(v => typeof v === 'number' ? v : parseInt(v, 10));
}
/**
 * Reads a string-list attribute `name` from `attrs`, parsing each entry with
 * parseStringParam.
 * @returns {string[]} The parsed strings, or `def` when absent.
 */
function getStringArrayParam(attrs, name, def, keepCase = false) {
    const param = attrs[name];
    if (!(param && param.list && param.list.s)) {
        return def;
    }
    return param.list.s.map(v => parseStringParam(v, keepCase));
}
/**
 * Reads a shape-list attribute `name` from `attrs`, parsing each entry with
 * parseTensorShapeParam.
 * @returns {Array} The parsed shapes, or `def` when absent.
 */
function getTensorShapeArrayParam(attrs, name, def) {
    const param = attrs[name];
    if (!(param && param.list && param.list.shape)) {
        return def;
    }
    return param.list.shape.map(v => parseTensorShapeParam(v));
}
/**
 * Reads a boolean-list attribute `name` from `attrs`.
 * @returns {boolean[]} The attribute's `list.b` array, or `def` when absent.
 */
function getBoolArrayParam(attrs, name, def) {
    const param = attrs[name];
    if (!(param && param.list && param.list.b)) {
        return def;
    }
    return param.list.b;
}
60939
60940 /**
60941 * @license
60942 * Copyright 2019 Google LLC. All Rights Reserved.
60943 * Licensed under the Apache License, Version 2.0 (the "License");
60944 * you may not use this file except in compliance with the License.
60945 * You may obtain a copy of the License at
60946 *
60947 * http://www.apache.org/licenses/LICENSE-2.0
60948 *
60949 * Unless required by applicable law or agreed to in writing, software
60950 * distributed under the License is distributed on an "AS IS" BASIS,
60951 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
60952 * See the License for the specific language governing permissions and
60953 * limitations under the License.
60954 * =============================================================================
60955 */
60956 /**
60957 * Helper class for lookup inputs and params for nodes in the model graph.
60958 */
/**
 * Helper class for lookup inputs and params for nodes in the model graph.
 */
class NodeValueImpl {
    /**
     * @param node Graph node whose inputs/attrs are being resolved.
     * @param tensorMap Map of resolved tensors by name.
     * @param context Execution context used for tensor lookups.
     */
    constructor(node, tensorMap, context) {
        this.node = node;
        this.tensorMap = tensorMap;
        this.context = context;
        this.inputs = [];
        this.attrs = {};
        // Eagerly resolve all input tensors and all raw attrs so callers can
        // read them as plain properties.
        this.inputs = node.inputNames.map(name => this.getInput(name));
        if (node.rawAttrs != null) {
            this.attrs = Object.keys(node.rawAttrs)
                .reduce((attrs, key) => {
                attrs[key] = this.getAttr(key);
                return attrs;
            }, {});
        }
    }
    /**
     * Return the value of the attribute or input param.
     * @param name String: name of attribute or input param.
     */
    getInput(name) {
        return getTensor(name, this.tensorMap, this.context);
    }
    /**
     * Return the value of the attribute or input param.
     * @param name String: name of attribute or input param.
     */
    getAttr(name, defaultValue) {
        // Dispatch on which AttrValue proto field is populated; the check
        // order matters since exactly one field is expected to be set.
        const value = this.node.rawAttrs[name];
        if (value.tensor != null) {
            return getTensor(name, this.tensorMap, this.context);
        }
        if (value.i != null || value.f != null) {
            return getNumberParam(this.node.rawAttrs, name, defaultValue);
        }
        if (value.s != null) {
            return getStringParam(this.node.rawAttrs, name, defaultValue);
        }
        if (value.b != null) {
            return getBoolParam(this.node.rawAttrs, name, defaultValue);
        }
        if (value.shape != null) {
            return getTensorShapeParam(this.node.rawAttrs, name, defaultValue);
        }
        if (value.type != null) {
            return getDtypeParam(this.node.rawAttrs, name, defaultValue);
        }
        // List-valued attrs: dispatch again on the populated list field.
        if (value.list != null) {
            if (value.list.i != null || value.list.f != null) {
                return getNumericArrayParam(this.node.rawAttrs, name, defaultValue);
            }
            if (value.list.s != null) {
                return getStringArrayParam(this.node.rawAttrs, name, defaultValue);
            }
            if (value.list.shape != null) {
                return getTensorShapeArrayParam(this.node.rawAttrs, name, defaultValue);
            }
            if (value.list.b != null) {
                return getBoolArrayParam(this.node.rawAttrs, name, defaultValue);
            }
            if (value.list.type != null) {
                return getDtypeArrayParam(this.node.rawAttrs, name, defaultValue);
            }
        }
        return defaultValue;
    }
}
61026
61027 /**
61028 * @license
61029 * Copyright 2020 Google LLC. All Rights Reserved.
61030 * Licensed under the Apache License, Version 2.0 (the "License");
61031 * you may not use this file except in compliance with the License.
61032 * You may obtain a copy of the License at
61033 *
61034 * http://www.apache.org/licenses/LICENSE-2.0
61035 *
61036 * Unless required by applicable law or agreed to in writing, software
61037 * distributed under the License is distributed on an "AS IS" BASIS,
61038 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
61039 * See the License for the specific language governing permissions and
61040 * limitations under the License.
61041 * =============================================================================
61042 */
61043
61044 /**
61045 * @license
61046 * Copyright 2018 Google LLC. All Rights Reserved.
61047 * Licensed under the Apache License, Version 2.0 (the "License");
61048 * you may not use this file except in compliance with the License.
61049 * You may obtain a copy of the License at
61050 *
61051 * http://www.apache.org/licenses/LICENSE-2.0
61052 *
61053 * Unless required by applicable law or agreed to in writing, software
61054 * distributed under the License is distributed on an "AS IS" BASIS,
61055 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
61056 * See the License for the specific language governing permissions and
61057 * limitations under the License.
61058 * =============================================================================
61059 */
/**
 * Executes a single arithmetic-category graph node and returns its result
 * wrapped in a one-element array.
 * @throws {TypeError} For an op name this category does not implement.
 */
const executeOp = (node, tensorMap, context) => {
    // Shorthand: resolve a named input param of this node.
    const param = (name) => getParamValue(name, node, tensorMap, context);
    switch (node.op) {
        case 'BiasAdd':
        case 'AddV2':
        case 'Add':
            return [add$1(param('a'), param('b'))];
        case 'AddN':
            return [addN(param('tensors'))];
        case 'FloorMod':
        case 'Mod':
            return [mod(param('a'), param('b'))];
        case 'Mul':
            return [mul(param('a'), param('b'))];
        case 'RealDiv':
        case 'Div':
            return [div(param('a'), param('b'))];
        case 'DivNoNan':
            return [divNoNan(param('a'), param('b'))];
        case 'FloorDiv':
            return [floorDiv(param('a'), param('b'))];
        case 'Sub':
            return [sub(param('a'), param('b'))];
        case 'Minimum':
            return [minimum(param('a'), param('b'))];
        case 'Maximum':
            return [maximum(param('a'), param('b'))];
        case 'Pow':
            return [pow(param('a'), param('b'))];
        case 'SquaredDifference':
            return [squaredDifference(param('a'), param('b'))];
        default:
            throw TypeError(`Node type ${node.op} is not implemented`);
    }
};
const CATEGORY = 'arithmetic';
61105
61106 /**
61107 * @license
61108 * Copyright 2018 Google LLC. All Rights Reserved.
61109 * Licensed under the Apache License, Version 2.0 (the "License");
61110 * you may not use this file except in compliance with the License.
61111 * You may obtain a copy of the License at
61112 *
61113 * http://www.apache.org/licenses/LICENSE-2.0
61114 *
61115 * Unless required by applicable law or agreed to in writing, software
61116 * distributed under the License is distributed on an "AS IS" BASIS,
61117 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
61118 * See the License for the specific language governing permissions and
61119 * limitations under the License.
61120 * =============================================================================
61121 */
/**
 * Executes a single basic-math-category graph node (mostly element-wise unary
 * ops) and returns its result wrapped in a one-element array.
 * @throws {TypeError} For an op name this category does not implement.
 */
const executeOp$1 = (node, tensorMap, context) => {
    // Shorthand: resolve a named input param of this node.
    const param = (name) => getParamValue(name, node, tensorMap, context);
    // Shorthand: resolve the node's first input tensor directly by name.
    const firstInput = () => getTensor(node.inputNames[0], tensorMap, context);
    switch (node.op) {
        case 'Abs':
        case 'ComplexAbs':
            return [abs(param('x'))];
        case 'Acos':
            return [acos(param('x'))];
        case 'Acosh':
            return [acosh(param('x'))];
        case 'Asin':
            return [asin(param('x'))];
        case 'Asinh':
            return [asinh(param('x'))];
        case 'Atan':
            return [atan(param('x'))];
        case 'Atan2':
            return [atan2(param('x'), param('y'))];
        case 'Atanh':
            return [atanh(param('x'))];
        case 'Ceil':
            return [ceil(param('x'))];
        case 'Complex':
            return [complex(param('real'), param('imag'))];
        case 'Cos':
            return [cos(param('x'))];
        case 'Cosh':
            return [cosh(param('x'))];
        case 'Elu':
            return [elu(param('x'))];
        case 'Erf':
            return [erf(param('x'))];
        case 'Exp':
            return [exp(param('x'))];
        case 'Expm1':
            return [expm1(param('x'))];
        case 'Floor':
            return [floor(param('x'))];
        case 'Log':
            return [log$1(param('x'))];
        case 'Log1p':
            return [log1p(param('x'))];
        case 'Imag':
            return [imag(param('x'))];
        case 'Neg':
            return [neg(param('x'))];
        case 'Reciprocal':
            return [reciprocal(param('x'))];
        case 'Real':
            return [real(param('x'))];
        case 'Relu':
            return [relu(param('x'))];
        case 'Round':
            return [round$1(param('x'))];
        case 'Selu':
            return [selu(param('x'))];
        case 'Sigmoid':
            return [sigmoid(param('x'))];
        case 'Sin':
            return [sin(param('x'))];
        case 'Sign':
            return [sign(param('x'))];
        case 'Sinh':
            return [sinh(param('x'))];
        case 'Softplus':
            return [softplus(param('x'))];
        case 'Sqrt':
            return [sqrt(param('x'))];
        case 'Square':
            return [square(param('x'))];
        case 'Tanh':
            return [tanh$1(param('x'))];
        case 'Tan':
            return [tan(param('x'))];
        case 'ClipByValue':
            return [clipByValue(param('x'), param('clipValueMin'), param('clipValueMax'))];
        case 'Relu6':
            return [relu6(param('x'))];
        case 'Rsqrt':
            return [rsqrt(firstInput())];
        case 'Prod':
            return [prod(param('x'), param('axes'))];
        case 'LeakyRelu':
            return [leakyRelu(param('x'), param('alpha'))];
        case 'Prelu':
            return [prelu(param('x'), param('alpha'))];
        case 'IsNan':
            return [isNaN$1(firstInput())];
        default:
            throw TypeError(`Node type ${node.op} is not implemented`);
    }
};
const CATEGORY$1 = 'basic_math';
61224
61225 /**
61226 * @license
61227 * Copyright 2020 Google LLC. All Rights Reserved.
61228 * Licensed under the Apache License, Version 2.0 (the "License");
61229 * you may not use this file except in compliance with the License.
61230 * You may obtain a copy of the License at
61231 *
61232 * http://www.apache.org/licenses/LICENSE-2.0
61233 *
61234 * Unless required by applicable law or agreed to in writing, software
61235 * distributed under the License is distributed on an "AS IS" BASIS,
61236 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
61237 * See the License for the specific language governing permissions and
61238 * limitations under the License.
61239 * =============================================================================
61240 */
61241 /**
61242 * Used by TensorList and TensorArray to verify if elementShape matches, support
61243 * negative value as the dim shape.
61244 * @param shapeA
61245 * @param shapeB
61246 * @param errorMessagePrefix
61247 */
/**
 * Asserts that two shapes are compatible, treating any negative dimension as
 * a wildcard that matches anything. A plain number argument denotes unknown
 * rank and always passes.
 * @param shapeA First shape (number[] or number for unknown rank).
 * @param shapeB Second shape (number[] or number for unknown rank).
 * @param errorMessagePrefix Optional prefix for the assertion message.
 */
function assertShapesMatchAllowUndefinedSize(shapeA, shapeB, errorMessagePrefix = '') {
    // constant shape means unknown rank
    if (typeof shapeA === 'number' || typeof shapeB === 'number') {
        return;
    }
    const mismatch = () => errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`;
    assert(shapeA.length === shapeB.length, mismatch);
    for (let i = 0; i < shapeA.length; i++) {
        const dimA = shapeA[i];
        const dimB = shapeB[i];
        // A negative dim on either side is a wildcard.
        assert(dimA < 0 || dimB < 0 || dimA === dimB, mismatch);
    }
}
/**
 * Returns true iff `elementShape` is a concrete shape: an array containing no
 * negative (unknown) dimensions. A plain number (unknown rank) is never
 * fully defined.
 */
function fullDefinedShape(elementShape) {
    if (typeof elementShape === 'number') {
        return false;
    }
    return !elementShape.some(dim => dim < 0);
}
61266 /**
61267 * Generate the output element shape from the list elementShape, list tensors
61268 * and input param.
61269 * @param listElementShape
61270 * @param tensors
61271 * @param elementShape
61272 */
61273 function inferElementShape(listElementShape, tensors, elementShape) {
61274 let partialShape = mergeElementShape(listElementShape, elementShape);
61275 const notfullDefinedShape = !fullDefinedShape(partialShape);
61276 if (notfullDefinedShape && tensors.length === 0) {
61277 throw new Error(`Tried to calculate elements of an empty list` +
61278 ` with non-fully-defined elementShape: ${partialShape}`);
61279 }
61280 if (notfullDefinedShape) {
61281 tensors.forEach(tensor => {
61282 partialShape = mergeElementShape(tensor.shape, partialShape);
61283 });
61284 }
61285 if (!fullDefinedShape(partialShape)) {
61286 throw new Error(`Non-fully-defined elementShape: ${partialShape}`);
61287 }
61288 return partialShape;
61289 }
61290 function mergeElementShape(elementShapeA, elementShapeB) {
61291 if (typeof elementShapeA === 'number') {
61292 return elementShapeB;
61293 }
61294 if (typeof elementShapeB === 'number') {
61295 return elementShapeA;
61296 }
61297 if (elementShapeA.length !== elementShapeB.length) {
61298 throw new Error(`Incompatible ranks during merge: ${elementShapeA} vs. ${elementShapeB}`);
61299 }
61300 const result = [];
61301 for (let i = 0; i < elementShapeA.length; ++i) {
61302 const dim0 = elementShapeA[i];
61303 const dim1 = elementShapeB[i];
61304 if (dim0 >= 0 && dim1 >= 0 && dim0 !== dim1) {
61305 throw new Error(`Incompatible shape during merge: ${elementShapeA} vs. ${elementShapeB}`);
61306 }
61307 result[i] = dim0 >= 0 ? dim0 : dim1;
61308 }
61309 return result;
61310 }
61311
61312 /**
61313 * @license
61314 * Copyright 2018 Google LLC. All Rights Reserved.
61315 * Licensed under the Apache License, Version 2.0 (the "License");
61316 * you may not use this file except in compliance with the License.
61317 * You may obtain a copy of the License at
61318 *
61319 * http://www.apache.org/licenses/LICENSE-2.0
61320 *
61321 * Unless required by applicable law or agreed to in writing, software
61322 * distributed under the License is distributed on an "AS IS" BASIS,
61323 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
61324 * See the License for the specific language governing permissions and
61325 * limitations under the License.
61326 * =============================================================================
61327 */
61328 /**
61329 * The TensorArray object keeps an array of Tensors. It
61330 * allows reading from the array and writing to the array.
61331 */
61332 class TensorArray {
61333 constructor(name, dtype, maxSize, elementShape, identicalElementShapes, dynamicSize, clearAfterRead) {
61334 this.name = name;
61335 this.dtype = dtype;
61336 this.maxSize = maxSize;
61337 this.elementShape = elementShape;
61338 this.identicalElementShapes = identicalElementShapes;
61339 this.dynamicSize = dynamicSize;
61340 this.clearAfterRead = clearAfterRead;
61341 this.tensors = [];
61342 this.closed_ = false;
61343 this.idTensor = scalar(0);
61344 keep(this.idTensor);
61345 }
61346 get id() {
61347 return this.idTensor.id;
61348 }
61349 get closed() {
61350 return this.closed_;
61351 }
61352 /**
61353 * Dispose the tensors and idTensor and mark the TensoryArray as closed.
61354 */
61355 clearAndClose(keepIds) {
61356 this.tensors.forEach(tensor => {
61357 if (keepIds == null || !keepIds.has(tensor.tensor.id)) {
61358 tensor.tensor.dispose();
61359 }
61360 });
61361 this.tensors = [];
61362 this.closed_ = true;
61363 this.idTensor.dispose();
61364 }
61365 size() {
61366 return this.tensors.length;
61367 }
61368 /**
61369 * Read the value at location index in the TensorArray.
61370 * @param index Number the index to read from.
61371 */
61372 read(index) {
61373 if (this.closed_) {
61374 throw new Error(`TensorArray ${this.name} has already been closed.`);
61375 }
61376 if (index < 0 || index >= this.size()) {
61377 throw new Error(`Tried to read from index ${index}, but array size is: ${this.size()}`);
61378 }
61379 const tensorWithState = this.tensors[index];
61380 if (tensorWithState.cleared) {
61381 throw new Error(`TensorArray ${this.name}: Could not read index ${index} twice because it was cleared after a previous read ` +
61382 `(perhaps try setting clear_after_read = false?).`);
61383 }
61384 if (this.clearAfterRead) {
61385 tensorWithState.cleared = true;
61386 }
61387 tensorWithState.read = true;
61388 return tensorWithState.tensor;
61389 }
61390 /**
61391 * Helper method to read multiple tensors from the specified indices.
61392 */
61393 readMany(indices) {
61394 return indices.map(index => this.read(index));
61395 }
61396 /**
61397 * Write value into the index of the TensorArray.
61398 * @param index number the index to write to.
61399 * @param tensor
61400 */
61401 write(index, tensor) {
61402 if (this.closed_) {
61403 throw new Error(`TensorArray ${this.name} has already been closed.`);
61404 }
61405 if (index < 0 || !this.dynamicSize && index >= this.maxSize) {
61406 throw new Error(`Tried to write to index ${index}, but array is not resizeable and size is: ${this.maxSize}`);
61407 }
61408 const t = this.tensors[index] || {};
61409 if (tensor.dtype !== this.dtype) {
61410 throw new Error(`TensorArray ${this.name}: Could not write to TensorArray index ${index},
61411 because the value dtype is ${tensor.dtype}, but TensorArray dtype is ${this.dtype}.`);
61412 }
61413 // Set the shape for the first time write to unknow shape tensor array
61414 if (this.size() === 0 &&
61415 (this.elementShape == null || this.elementShape.length === 0)) {
61416 this.elementShape = tensor.shape;
61417 }
61418 assertShapesMatchAllowUndefinedSize(this.elementShape, tensor.shape, `TensorArray ${this.name}: Could not write to TensorArray index ${index}.`);
61419 if (t.read) {
61420 throw new Error(`TensorArray ${this.name}: Could not write to TensorArray index ${index}, because it has already been read.`);
61421 }
61422 if (t.written) {
61423 throw new Error(`TensorArray ${this.name}: Could not write to TensorArray index ${index}, because it has already been written.`);
61424 }
61425 t.tensor = tensor;
61426 keep(tensor);
61427 t.written = true;
61428 this.tensors[index] = t;
61429 }
61430 /**
61431 * Helper method to write multiple tensors to the specified indices.
61432 */
61433 writeMany(indices, tensors) {
61434 if (indices.length !== tensors.length) {
61435 throw new Error(`TensorArray ${this.name}: could not write multiple tensors,` +
61436 `because the index size: ${indices.length} is not the same as tensors size: ${tensors.length}.`);
61437 }
61438 indices.forEach((i, index) => this.write(i, tensors[index]));
61439 }
61440 /**
61441 * Return selected values in the TensorArray as a packed Tensor. All of
61442 * selected values must have been written and their shapes must all match.
61443 * @param [indices] number[] Optional. Taking values in [0, max_value). If the
61444 * TensorArray is not dynamic, max_value=size(). If not specified returns
61445 * all tensors in the original order.
61446 * @param [dtype]
61447 */
61448 gather(indices, dtype) {
61449 if (!!dtype && dtype !== this.dtype) {
61450 throw new Error(`TensorArray dtype is ${this.dtype} but gather requested dtype ${dtype}`);
61451 }
61452 if (!indices) {
61453 indices = [];
61454 for (let i = 0; i < this.size(); i++) {
61455 indices.push(i);
61456 }
61457 }
61458 else {
61459 indices = indices.slice(0, this.size());
61460 }
61461 if (indices.length === 0) {
61462 return tensor([], [0].concat(this.elementShape));
61463 }
61464 // Read all the PersistentTensors into a vector to keep track of
61465 // their memory.
61466 const tensors = this.readMany(indices);
61467 assertShapesMatchAllowUndefinedSize(this.elementShape, tensors[0].shape, 'TensorArray shape mismatch: ');
61468 return stack(tensors, 0);
61469 }
61470 /**
61471 * Return the values in the TensorArray as a concatenated Tensor.
61472 */
61473 concat(dtype) {
61474 if (!!dtype && dtype !== this.dtype) {
61475 throw new Error(`TensorArray dtype is ${this.dtype} but concat requested dtype ${dtype}`);
61476 }
61477 if (this.size() === 0) {
61478 return tensor([], [0].concat(this.elementShape));
61479 }
61480 const indices = [];
61481 for (let i = 0; i < this.size(); i++) {
61482 indices.push(i);
61483 }
61484 // Collect all the tensors from the tensors array.
61485 const tensors = this.readMany(indices);
61486 assertShapesMatchAllowUndefinedSize(this.elementShape, tensors[0].shape, `TensorArray shape mismatch: tensor array shape (${this.elementShape}) vs first tensor shape (${tensors[0].shape})`);
61487 return concat(tensors, 0);
61488 }
61489 /**
61490 * Scatter the values of a Tensor in specific indices of a TensorArray.
61491 * @param indices nummber[] values in [0, max_value). If the
61492 * TensorArray is not dynamic, max_value=size().
61493 * @param tensor Tensor input tensor.
61494 */
61495 scatter(indices, tensor) {
61496 if (tensor.dtype !== this.dtype) {
61497 throw new Error(`TensorArray dtype is ${this.dtype} but tensor has dtype ${tensor.dtype}`);
61498 }
61499 if (indices.length !== tensor.shape[0]) {
61500 throw new Error(`Expected len(indices) == tensor.shape[0], but saw: ${indices.length} vs. ${tensor.shape[0]}`);
61501 }
61502 const maxIndex = Math.max(...indices);
61503 if (!this.dynamicSize && maxIndex >= this.maxSize) {
61504 throw new Error(`Max index must be < array size (${maxIndex} vs. ${this.maxSize})`);
61505 }
61506 this.writeMany(indices, unstack(tensor, 0));
61507 }
61508 /**
61509 * Split the values of a Tensor into the TensorArray.
61510 * @param length number[] with the lengths to use when splitting value along
61511 * its first dimension.
61512 * @param tensor Tensor, the tensor to split.
61513 */
61514 split(length, tensor) {
61515 if (tensor.dtype !== this.dtype) {
61516 throw new Error(`TensorArray dtype is ${this.dtype} but tensor has dtype ${tensor.dtype}`);
61517 }
61518 let totalLength = 0;
61519 const cumulativeLengths = length.map(len => {
61520 totalLength += len;
61521 return totalLength;
61522 });
61523 if (totalLength !== tensor.shape[0]) {
61524 throw new Error(`Expected sum of lengths to be equal to
61525 tensor.shape[0], but sum of lengths is
61526 ${totalLength}, and tensor's shape is: ${tensor.shape}`);
61527 }
61528 if (!this.dynamicSize && length.length !== this.maxSize) {
61529 throw new Error(`TensorArray's size is not equal to the size of lengths (${this.maxSize} vs. ${length.length}), ` +
61530 'and the TensorArray is not marked as dynamically resizeable');
61531 }
61532 const elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength;
61533 const tensors = [];
61534 tidy(() => {
61535 tensor = reshape(tensor, [1, totalLength, elementPerRow]);
61536 for (let i = 0; i < length.length; ++i) {
61537 const previousLength = (i === 0) ? 0 : cumulativeLengths[i - 1];
61538 const indices = [0, previousLength, 0];
61539 const sizes = [1, length[i], elementPerRow];
61540 tensors[i] = reshape(slice(tensor, indices, sizes), this.elementShape);
61541 }
61542 return tensors;
61543 });
61544 const indices = [];
61545 for (let i = 0; i < length.length; i++) {
61546 indices[i] = i;
61547 }
61548 this.writeMany(indices, tensors);
61549 }
61550 }
61551
61552 /**
61553 * @license
61554 * Copyright 2020 Google LLC. All Rights Reserved.
61555 * Licensed under the Apache License, Version 2.0 (the "License");
61556 * you may not use this file except in compliance with the License.
61557 * You may obtain a copy of the License at
61558 *
61559 * http://www.apache.org/licenses/LICENSE-2.0
61560 *
61561 * Unless required by applicable law or agreed to in writing, software
61562 * distributed under the License is distributed on an "AS IS" BASIS,
61563 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
61564 * See the License for the specific language governing permissions and
61565 * limitations under the License.
61566 * =============================================================================
61567 */
61568 /**
61569 * TensorList stores a container of `tf.Tensor` objects, which are accessible
61570 * via tensors field.
61571 *
61572 * In order to get a copy of the underlying list, use the copy method:
61573 * ```
61574 * TensorList b = a.copy();
61575 * b.tensors().pushBack(t); // This does not modify a.tensors().
61576 * ```
61577 *
61578 * Note that this is not a deep copy: the memory locations of the underlying
61579 * tensors will still point to the same locations of the corresponding tensors
61580 * in the original.
61581 */
    class TensorList {
        /**
         *
         * @param tensors list of tensors
         * @param elementShape shape of each tensor, this can be a single number (any
         * shape is allowed) or partial shape (dim = -1).
         * @param elementDtype data type of each tensor
         * @param maxNumElements The maximum allowed size of `tensors`. Defaults to -1
         * meaning that the size of `tensors` is unbounded.
         */
        constructor(tensors, elementShape, elementDtype, maxNumElements = -1) {
            this.tensors = tensors;
            this.elementShape = elementShape;
            this.elementDtype = elementDtype;
            if (tensors != null) {
                tensors.forEach(tensor => {
                    if (elementDtype !== tensor.dtype) {
                        throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${tensor.dtype}`);
                    }
                    assertShapesMatchAllowUndefinedSize(elementShape, tensor.shape, 'TensorList shape mismatch: ');
                    // Pin each tensor so engine-level tidy() cannot dispose it
                    // while the list owns it.
                    keep(tensor);
                });
            }
            // Dummy scalar whose id serves as this list's unique handle; kept
            // alive until clearAndClose().
            this.idTensor = scalar(0);
            this.maxNumElements = maxNumElements;
            keep(this.idTensor);
        }
        get id() {
            return this.idTensor.id;
        }
        /**
         * Get a new TensorList containing a copy of the underlying tensor container.
         * Shallow copy: the tensors themselves are shared with the original.
         */
        copy() {
            return new TensorList([...this.tensors], this.elementShape, this.elementDtype);
        }
        /**
         * Dispose the tensors and idTensor and clear the tensor list.
         * @param keepIds optional set of tensor ids that must NOT be disposed
         */
        clearAndClose(keepIds) {
            this.tensors.forEach(tensor => {
                if (keepIds == null || !keepIds.has(tensor.id)) {
                    tensor.dispose();
                }
            });
            this.tensors.length = 0;
            this.idTensor.dispose();
        }
        /**
         * The size of the tensors in the tensor list.
         */
        size() {
            return this.tensors.length;
        }
        /**
         * Return a tensor that stacks a list of rank-R tf.Tensors into one rank-(R+1)
         * tf.Tensor.
         * @param elementShape shape of each tensor
         * @param elementDtype data type of each tensor
         * @param numElements the number of elements to stack
         */
        stack(elementShape, elementDtype, numElements = -1) {
            if (elementDtype !== this.elementDtype) {
                throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${this.elementDtype}`);
            }
            // numElements === -1 means "accept whatever size the list has".
            if (numElements !== -1 && this.tensors.length !== numElements) {
                throw new Error(`Operation expected a list with ${numElements} elements but got a list with ${this.tensors.length} elements.`);
            }
            assertShapesMatchAllowUndefinedSize(elementShape, this.elementShape, 'TensorList shape mismatch: ');
            const outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
            return tidy(() => {
                const reshapedTensors = this.tensors.map(tensor => reshape(tensor, outputElementShape));
                return stack(reshapedTensors, 0);
            });
        }
        /**
         * Pop a tensor from the end of the list.
         * @param elementShape shape of the tensor
         * @param elementDtype data type of the tensor
         */
        popBack(elementShape, elementDtype) {
            if (elementDtype !== this.elementDtype) {
                throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${this.elementDtype}`);
            }
            if (this.size() === 0) {
                throw new Error('Trying to pop from an empty list.');
            }
            const outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
            // The popped tensor is removed from the list but NOT disposed here;
            // ownership passes to the caller via the reshaped result.
            const tensor = this.tensors.pop();
            assertShapesMatchAllowUndefinedSize(tensor.shape, elementShape, 'TensorList shape mismatch: ');
            return reshape(tensor, outputElementShape);
        }
        /**
         * Push a tensor to the end of the list.
         * @param tensor Tensor to be pushed.
         */
        pushBack(tensor) {
            if (tensor.dtype !== this.elementDtype) {
                throw new Error(`Invalid data types; op elements ${tensor.dtype}, but list elements ${this.elementDtype}`);
            }
            assertShapesMatchAllowUndefinedSize(tensor.shape, this.elementShape, 'TensorList shape mismatch: ');
            if (this.maxNumElements === this.size()) {
                throw new Error(`Trying to push element into a full list.`);
            }
            // Pin the pushed tensor; the list now owns it.
            keep(tensor);
            this.tensors.push(tensor);
        }
        /**
         * Update the size of the list.
         * @param size the new size of the list.
         */
        resize(size) {
            if (size < 0) {
                throw new Error(`TensorListResize expects size to be non-negative. Got: ${size}`);
            }
            if (this.maxNumElements !== -1 && size > this.maxNumElements) {
                throw new Error(`TensorListResize input size ${size} is greater maxNumElement ${this.maxNumElements}.`);
            }
            // Returns a NEW list sharing the first min(size, length) tensors;
            // this list is left unchanged.
            const destTensorList = new TensorList([], this.elementShape, this.elementDtype, this.maxNumElements);
            destTensorList.tensors.length = size;
            for (let i = 0; i < Math.min(this.tensors.length, size); ++i) {
                destTensorList.tensors[i] = this.tensors[i];
            }
            return destTensorList;
        }
        /**
         * Retrieve the element at the provided index
         * @param elementShape shape of the tensor
         * @param elementDtype dtype of the tensor
         * @param elementIndex index of the tensor
         */
        getItem(elementIndex, elementShape, elementDtype) {
            if (elementDtype !== this.elementDtype) {
                throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${this.elementDtype}`);
            }
            // NOTE(review): the bound uses `>` rather than `>=`, so
            // elementIndex === length passes here and is instead caught by the
            // null check below with a different message — confirm intended.
            if (elementIndex < 0 || elementIndex > this.tensors.length) {
                throw new Error(`Trying to access element ${elementIndex} in a list with ${this.tensors.length} elements.`);
            }
            if (this.tensors[elementIndex] == null) {
                throw new Error(`element at index ${elementIndex} is null.`);
            }
            assertShapesMatchAllowUndefinedSize(this.tensors[elementIndex].shape, elementShape, 'TensorList shape mismatch: ');
            const outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
            return reshape(this.tensors[elementIndex], outputElementShape);
        }
        /**
         * Set the tensor at the index
         * @param elementIndex index of the tensor
         * @param tensor the tensor to be inserted into the list
         */
        setItem(elementIndex, tensor) {
            if (tensor.dtype !== this.elementDtype) {
                throw new Error(`Invalid data types; op elements ${tensor.dtype}, but list elements ${this.elementDtype}`);
            }
            if (elementIndex < 0 ||
                this.maxNumElements !== -1 && elementIndex >= this.maxNumElements) {
                throw new Error(`Trying to set element ${elementIndex} in a list with max ${this.maxNumElements} elements.`);
            }
            assertShapesMatchAllowUndefinedSize(this.elementShape, tensor.shape, 'TensorList shape mismatch: ');
            // Pin the tensor; any previously stored tensor at this index is
            // NOT disposed here.
            keep(tensor);
            this.tensors[elementIndex] = tensor;
        }
        /**
         * Return selected values in the TensorList as a stacked Tensor. All of
         * selected values must have been written and their shapes must all match.
         * @param indices indices of tensors to gather
         * @param elementDtype output tensor dtype
         * @param elementShape output tensor element shape
         */
        gather(indices, elementDtype, elementShape) {
            if (elementDtype !== this.elementDtype) {
                throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${this.elementDtype}`);
            }
            assertShapesMatchAllowUndefinedSize(this.elementShape, elementShape, 'TensorList shape mismatch: ');
            // When indices is greater than the size of the list, indices beyond the
            // size of the list are ignored.
            indices = indices.slice(0, this.size());
            const outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
            if (indices.length === 0) {
                return tensor([], [0].concat(outputElementShape));
            }
            return tidy(() => {
                const tensors = indices.map(i => reshape(this.tensors[i], outputElementShape));
                return stack(tensors, 0);
            });
        }
        /**
         * Return the values in the TensorList as a concatenated Tensor.
         * @param elementDtype output tensor dtype
         * @param elementShape output tensor element shape
         */
        concat(elementDtype, elementShape) {
            if (!!elementDtype && elementDtype !== this.elementDtype) {
                throw new Error(`TensorList dtype is ${this.elementDtype} but concat requested dtype ${elementDtype}`);
            }
            assertShapesMatchAllowUndefinedSize(this.elementShape, elementShape, 'TensorList shape mismatch: ');
            const outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
            if (this.size() === 0) {
                return tensor([], [0].concat(outputElementShape));
            }
            return tidy(() => {
                const tensors = this.tensors.map(t => reshape(t, outputElementShape));
                return concat(tensors, 0);
            });
        }
    }
61788 /**
61789 * Creates a TensorList which, when stacked, has the value of tensor.
61790 * @param tensor from tensor
61791 * @param elementShape output tensor element shape
61792 */
61793 function fromTensor(tensor, elementShape, elementDtype) {
61794 const dtype = tensor.dtype;
61795 if (tensor.shape.length < 1) {
61796 throw new Error(`Tensor must be at least a vector, but saw shape: ${tensor.shape}`);
61797 }
61798 if (tensor.dtype !== elementDtype) {
61799 throw new Error(`Invalid data types; op elements ${tensor.dtype}, but list elements ${elementDtype}`);
61800 }
61801 const tensorElementShape = tensor.shape.slice(1);
61802 assertShapesMatchAllowUndefinedSize(tensorElementShape, elementShape, 'TensorList shape mismatch: ');
61803 const tensorList = unstack(tensor);
61804 return new TensorList(tensorList, elementShape, dtype);
61805 }
61806 /**
61807 * Return a TensorList of the given size with empty elements.
61808 * @param elementShape the shape of the future elements of the list
61809 * @param elementDtype the desired type of elements in the list
61810 * @param numElements the number of elements to reserve
61811 */
61812 function reserve(elementShape, elementDtype, numElements) {
61813 return new TensorList([], elementShape, elementDtype, numElements);
61814 }
61815 /**
61816 * Put tensors at specific indices of a stacked tensor into a TensorList.
61817 * @param indices list of indices on how to scatter the tensor.
61818 * @param tensor input tensor.
61819 * @param elementShape the shape of the future elements of the list
61820 * @param numElements the number of elements to scatter
61821 */
61822 function scatter(tensor, indices, elementShape, numElements) {
61823 if (indices.length !== tensor.shape[0]) {
61824 throw new Error(`Expected len(indices) == tensor.shape[0], but saw: ${indices.length} vs. ${tensor.shape[0]}`);
61825 }
61826 const maxIndex = Math.max(...indices);
61827 if (numElements != null && numElements !== -1 && maxIndex >= numElements) {
61828 throw new Error(`Max index must be < array size (${maxIndex} vs. ${numElements})`);
61829 }
61830 const list = new TensorList([], elementShape, tensor.dtype, numElements);
61831 const tensors = unstack(tensor, 0);
61832 indices.forEach((value, index) => {
61833 list.setItem(value, tensors[index]);
61834 });
61835 return list;
61836 }
61837 /**
61838 * Split the values of a Tensor into a TensorList.
61839 * @param length the lengths to use when splitting value along
61840 * its first dimension.
61841 * @param tensor the tensor to split.
61842 * @param elementShape the shape of the future elements of the list
61843 */
61844 function split$2(tensor, length, elementShape) {
61845 let totalLength = 0;
61846 const cumulativeLengths = length.map(len => {
61847 totalLength += len;
61848 return totalLength;
61849 });
61850 if (totalLength !== tensor.shape[0]) {
61851 throw new Error(`Expected sum of lengths to be equal to
61852 tensor.shape[0], but sum of lengths is
61853 ${totalLength}, and tensor's shape is: ${tensor.shape}`);
61854 }
61855 const shapeWithoutFirstDim = tensor.shape.slice(1);
61856 const outputElementShape = mergeElementShape(shapeWithoutFirstDim, elementShape);
61857 const elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength;
61858 const tensors = tidy(() => {
61859 const tensors = [];
61860 tensor = reshape(tensor, [1, totalLength, elementPerRow]);
61861 for (let i = 0; i < length.length; ++i) {
61862 const previousLength = (i === 0) ? 0 : cumulativeLengths[i - 1];
61863 const indices = [0, previousLength, 0];
61864 const sizes = [1, length[i], elementPerRow];
61865 tensors[i] = reshape(slice(tensor, indices, sizes), outputElementShape);
61866 }
61867 tensor.dispose();
61868 return tensors;
61869 });
61870 const list = new TensorList([], elementShape, tensor.dtype, length.length);
61871 for (let i = 0; i < tensors.length; i++) {
61872 list.setItem(i, tensors[i]);
61873 }
61874 return list;
61875 }
61876
61877 /**
61878 * @license
61879 * Copyright 2018 Google LLC. All Rights Reserved.
61880 * Licensed under the Apache License, Version 2.0 (the "License");
61881 * you may not use this file except in compliance with the License.
61882 * You may obtain a copy of the License at
61883 *
61884 * http://www.apache.org/licenses/LICENSE-2.0
61885 *
61886 * Unless required by applicable law or agreed to in writing, software
61887 * distributed under the License is distributed on an "AS IS" BASIS,
61888 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
61889 * See the License for the specific language governing permissions and
61890 * limitations under the License.
61891 * =============================================================================
61892 */
61893 const executeOp$2 = async (node, tensorMap, context) => {
61894 switch (node.op) {
61895 case 'If':
61896 case 'StatelessIf': {
61897 const thenFunc = getParamValue('thenBranch', node, tensorMap, context);
61898 const elseFunc = getParamValue('elseBranch', node, tensorMap, context);
61899 const cond = getParamValue('cond', node, tensorMap, context);
61900 const args = getParamValue('args', node, tensorMap, context);
61901 const condValue = await cond.data();
61902 if (condValue[0]) {
61903 return context.functionMap[thenFunc].executeFunctionAsync(args, context.tensorArrayMap, context.tensorListMap);
61904 }
61905 else {
61906 return context.functionMap[elseFunc].executeFunctionAsync(args, context.tensorArrayMap, context.tensorListMap);
61907 }
61908 }
61909 case 'While':
61910 case 'StatelessWhile': {
61911 const bodyFunc = getParamValue('body', node, tensorMap, context);
61912 const condFunc = getParamValue('cond', node, tensorMap, context);
61913 const args = getParamValue('args', node, tensorMap, context);
61914 // Calculate the condition of the loop
61915 const condResult = (await context.functionMap[condFunc].executeFunctionAsync(args, context.tensorArrayMap, context.tensorListMap));
61916 const argIds = args.map(tensor => tensor.id);
61917 let condValue = await condResult[0].data();
61918 // Dispose the intermediate tensors for condition function
61919 condResult.forEach(tensor => {
61920 if (!tensor.kept && argIds.indexOf(tensor.id) === -1) {
61921 tensor.dispose();
61922 }
61923 });
61924 let result = args;
61925 while (condValue[0]) {
61926 // Record the previous result for intermediate tensor tracking
61927 const origResult = result;
61928 // Execution the body of the loop
61929 result = await context.functionMap[bodyFunc].executeFunctionAsync(result, context.tensorArrayMap, context.tensorListMap);
61930 const resultIds = result.map(tensor => tensor.id);
61931 // Dispose the intermediate tensor for body function that is not global
61932 // kept, not input/output of the body function
61933 origResult.forEach(tensor => {
61934 if (!tensor.kept && argIds.indexOf(tensor.id) === -1 &&
61935 resultIds.indexOf(tensor.id) === -1) {
61936 tensor.dispose();
61937 }
61938 });
61939 // Recalcuate the condition of the loop using the latest results.
61940 const condResult = (await context.functionMap[condFunc].executeFunctionAsync(result, context.tensorArrayMap, context.tensorListMap));
61941 condValue = await condResult[0].data();
61942 // Dispose the intermediate tensors for condition function
61943 condResult.forEach(tensor => {
61944 if (!tensor.kept && argIds.indexOf(tensor.id) === -1 &&
61945 resultIds.indexOf(tensor.id) === -1) {
61946 tensor.dispose();
61947 }
61948 });
61949 }
61950 return result;
61951 }
61952 case 'LoopCond': {
61953 const pred = getParamValue('pred', node, tensorMap, context);
61954 return [cloneTensor(pred)];
61955 }
61956 case 'Switch': {
61957 const pred = getParamValue('pred', node, tensorMap, context);
61958 let data = getParamValue('data', node, tensorMap, context);
61959 if (!data.kept) {
61960 data = cloneTensor(data);
61961 }
61962 // Outputs nodes :0 => false, :1 => true
61963 return (await pred.data())[0] ? [undefined, data] : [data, undefined];
61964 }
61965 case 'Merge': {
61966 const inputName = node.inputNames.find(name => getTensor(name, tensorMap, context) !== undefined);
61967 if (inputName) {
61968 const data = getTensor(inputName, tensorMap, context);
61969 return [cloneTensor(data)];
61970 }
61971 return undefined;
61972 }
61973 case 'Enter': {
61974 const frameId = getParamValue('frameName', node, tensorMap, context);
61975 const data = getParamValue('tensor', node, tensorMap, context);
61976 context.enterFrame(frameId);
61977 return [cloneTensor(data)];
61978 }
61979 case 'Exit': {
61980 const data = getParamValue('tensor', node, tensorMap, context);
61981 context.exitFrame();
61982 return [cloneTensor(data)];
61983 }
61984 case 'NextIteration': {
61985 const data = getParamValue('tensor', node, tensorMap, context);
61986 context.nextIteration();
61987 return [cloneTensor(data)];
61988 }
61989 case 'TensorArrayV3': {
61990 const size = getParamValue('size', node, tensorMap, context);
61991 const dtype = getParamValue('dtype', node, tensorMap, context);
61992 const elementShape = getParamValue('elementShape', node, tensorMap, context);
61993 const dynamicSize = getParamValue('dynamicSize', node, tensorMap, context);
61994 const clearAfterRead = getParamValue('clearAfterRead', node, tensorMap, context);
61995 const identicalElementShapes = getParamValue('identicalElementShapes', node, tensorMap, context);
61996 const name = getParamValue('name', node, tensorMap, context);
61997 const tensorArray = new TensorArray(name, dtype, size, elementShape, identicalElementShapes, dynamicSize, clearAfterRead);
61998 context.addTensorArray(tensorArray);
61999 return [tensorArray.idTensor, scalar(1.0)];
62000 }
62001 case 'TensorArrayWriteV3': {
62002 const id = getParamValue('tensorArrayId', node, tensorMap, context);
62003 const index = getParamValue('index', node, tensorMap, context);
62004 const writeTensor = getParamValue('tensor', node, tensorMap, context);
62005 const writeTensorArray = context.getTensorArray(id.id);
62006 writeTensorArray.write(index, writeTensor);
62007 return [writeTensorArray.idTensor];
62008 }
62009 case 'TensorArrayReadV3': {
62010 const readId = getParamValue('tensorArrayId', node, tensorMap, context);
62011 const readIndex = getParamValue('index', node, tensorMap, context);
62012 const readTensorArray = context.getTensorArray(readId.id);
62013 return [readTensorArray.read(readIndex)];
62014 }
62015 case 'TensorArrayGatherV3': {
62016 const gatherId = getParamValue('tensorArrayId', node, tensorMap, context);
62017 const gatherIndices = getParamValue('indices', node, tensorMap, context);
62018 const gatherDtype = getParamValue('dtype', node, tensorMap, context);
62019 const gatherTensorArray = context.getTensorArray(gatherId.id);
62020 return [gatherTensorArray.gather(gatherIndices, gatherDtype)];
62021 }
62022 case 'TensorArrayScatterV3': {
62023 const scatterId = getParamValue('tensorArrayId', node, tensorMap, context);
62024 const scatterIndices = getParamValue('indices', node, tensorMap, context);
62025 const scatterTensor = getParamValue('tensor', node, tensorMap, context);
62026 const scatterTensorArray = context.getTensorArray(scatterId.id);
62027 scatterTensorArray.scatter(scatterIndices, scatterTensor);
62028 return [scatterTensorArray.idTensor];
62029 }
62030 case 'TensorArrayConcatV3': {
62031 const concatId = getParamValue('tensorArrayId', node, tensorMap, context);
62032 const concatTensorArray = context.getTensorArray(concatId.id);
62033 const concatDtype = getParamValue('dtype', node, tensorMap, context);
62034 return [concatTensorArray.concat(concatDtype)];
62035 }
62036 case 'TensorArraySplitV3': {
62037 const splitId = getParamValue('tensorArrayId', node, tensorMap, context);
62038 const splitTensor = getParamValue('tensor', node, tensorMap, context);
62039 const lengths = getParamValue('lengths', node, tensorMap, context);
62040 const splitTensorArray = context.getTensorArray(splitId.id);
62041 splitTensorArray.split(lengths, splitTensor);
62042 return [splitTensorArray.idTensor];
62043 }
62044 case 'TensorArraySizeV3': {
62045 const sizeId = getParamValue('tensorArrayId', node, tensorMap, context);
62046 const sizeTensorArray = context.getTensorArray(sizeId.id);
62047 return [scalar(sizeTensorArray.size(), 'int32')];
62048 }
62049 case 'TensorArrayCloseV3': {
62050 const closeId = getParamValue('tensorArrayId', node, tensorMap, context);
62051 const closeTensorArray = context.getTensorArray(closeId.id);
62052 closeTensorArray.clearAndClose();
62053 return [closeTensorArray.idTensor];
62054 }
62055 case 'TensorListSetItem': {
62056 const idTensor = getParamValue('tensorListId', node, tensorMap, context);
62057 const index = getParamValue('index', node, tensorMap, context);
62058 const writeTensor = getParamValue('tensor', node, tensorMap, context);
62059 const tensorList = context.getTensorList(idTensor.id);
62060 tensorList.setItem(index, writeTensor);
62061 return [tensorList.idTensor];
62062 }
62063 case 'TensorListGetItem': {
62064 const idTensor = getParamValue('tensorListId', node, tensorMap, context);
62065 const readIndex = getParamValue('index', node, tensorMap, context);
62066 const elementShape = getParamValue('elementShape', node, tensorMap, context);
62067 const elementDType = getParamValue('elementDType', node, tensorMap, context);
62068 const tensorList = context.getTensorList(idTensor.id);
62069 return [tensorList.getItem(readIndex, elementShape, elementDType)];
62070 }
62071 case 'TensorListScatterV2':
62072 case 'TensorListScatter': {
62073 const scatterIndices = getParamValue('indices', node, tensorMap, context);
62074 const scatterTensor = getParamValue('tensor', node, tensorMap, context);
62075 const elementShape = getParamValue('elementShape', node, tensorMap, context);
62076 const numElements = getParamValue('numElements', node, tensorMap, context);
62077 const tensorList = scatter(scatterTensor, scatterIndices, elementShape, numElements);
62078 context.addTensorList(tensorList);
62079 return [tensorList.idTensor];
62080 }
62081 case 'TensorListReserve':
62082 case 'EmptyTensorList': {
62083 const elementShape = getParamValue('elementShape', node, tensorMap, context);
62084 const elementDtype = getParamValue('elementDType', node, tensorMap, context);
62085 let numElementsParam;
62086 if (node.op === 'TensorListReserve') {
62087 numElementsParam = 'numElements';
62088 }
62089 else {
62090 numElementsParam = 'maxNumElements';
62091 }
62092 const numElements = getParamValue(numElementsParam, node, tensorMap, context);
62093 const tensorList = reserve(elementShape, elementDtype, numElements);
62094 context.addTensorList(tensorList);
62095 return [tensorList.idTensor];
62096 }
62097 case 'TensorListGather': {
62098 const gatherId = getParamValue('tensorListId', node, tensorMap, context);
62099 const gatherIndices = getParamValue('indices', node, tensorMap, context);
62100 const elementShape = getParamValue('elementShape', node, tensorMap, context);
62101 const elementDtype = getParamValue('elementDType', node, tensorMap, context);
62102 const tensorList = context.getTensorList(gatherId.id);
62103 return [tensorList.gather(gatherIndices, elementDtype, elementShape)];
62104 }
62105 case 'TensorListStack': {
62106 const idTensor = getParamValue('tensorListId', node, tensorMap, context);
62107 const elementShape = getParamValue('elementShape', node, tensorMap, context);
62108 const elementDtype = getParamValue('elementDType', node, tensorMap, context);
62109 const numElements = getParamValue('numElements', node, tensorMap, context);
62110 const tensorList = context.getTensorList(idTensor.id);
62111 return [tensorList.stack(elementShape, elementDtype, numElements)];
62112 }
62113 case 'TensorListFromTensor': {
62114 const tensor = getParamValue('tensor', node, tensorMap, context);
62115 const elementShape = getParamValue('elementShape', node, tensorMap, context);
62116 const elementDtype = getParamValue('elementDType', node, tensorMap, context);
62117 const tensorList = fromTensor(tensor, elementShape, elementDtype);
62118 context.addTensorList(tensorList);
62119 return [tensorList.idTensor];
62120 }
62121 case 'TensorListConcat':
62122 case 'TensorListConcatV2': {
62123 const concatId = getParamValue('tensorListId', node, tensorMap, context);
62124 const tensorList = context.getTensorList(concatId.id);
62125 const concatDtype = getParamValue('dtype', node, tensorMap, context);
62126 const elementShape = getParamValue('elementShape', node, tensorMap, context);
62127 return [tensorList.concat(concatDtype, elementShape)];
62128 }
62129 case 'TensorListPushBack': {
62130 const idTensor = getParamValue('tensorListId', node, tensorMap, context);
62131 const writeTensor = getParamValue('tensor', node, tensorMap, context);
62132 const tensorList = context.getTensorList(idTensor.id);
62133 tensorList.pushBack(writeTensor);
62134 return [tensorList.idTensor];
62135 }
62136 case 'TensorListPopBack': {
62137 const idTensor = getParamValue('tensorListId', node, tensorMap, context);
62138 const elementShape = getParamValue('elementShape', node, tensorMap, context);
62139 const elementDType = getParamValue('elementDType', node, tensorMap, context);
62140 const tensorList = context.getTensorList(idTensor.id);
62141 return [tensorList.popBack(elementShape, elementDType)];
62142 }
62143 case 'TensorListSplit': {
62144 const splitTensor = getParamValue('tensor', node, tensorMap, context);
62145 const elementShape = getParamValue('elementShape', node, tensorMap, context);
62146 const lengths = getParamValue('lengths', node, tensorMap, context);
62147 const tensorList = split$2(splitTensor, lengths, elementShape);
62148 context.addTensorList(tensorList);
62149 return [tensorList.idTensor];
62150 }
62151 case 'TensorListLength': {
62152 const idTensor = getParamValue('tensorListId', node, tensorMap, context);
62153 const tensorList = context.getTensorList(idTensor.id);
62154 return [scalar(tensorList.size(), 'int32')];
62155 }
62156 case 'TensorListResize': {
62157 const idTensor = getParamValue('tensorListId', node, tensorMap, context);
62158 const size = getParamValue('size', node, tensorMap, context);
62159 const srcTensorList = context.getTensorList(idTensor.id);
62160 const destTensorList = srcTensorList.resize(size);
62161 context.addTensorList(destTensorList);
62162 return [destTensorList.idTensor];
62163 }
62164 default:
62165 throw TypeError(`Node type ${node.op} is not implemented`);
62166 }
62167 };
    // Op-mapper category name for the control-flow executor defined above.
    const CATEGORY$2 = 'control';
62169
62170 /**
62171 * @license
62172 * Copyright 2018 Google LLC. All Rights Reserved.
62173 * Licensed under the Apache License, Version 2.0 (the "License");
62174 * you may not use this file except in compliance with the License.
62175 * You may obtain a copy of the License at
62176 *
62177 * http://www.apache.org/licenses/LICENSE-2.0
62178 *
62179 * Unless required by applicable law or agreed to in writing, software
62180 * distributed under the License is distributed on an "AS IS" BASIS,
62181 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
62182 * See the License for the specific language governing permissions and
62183 * limitations under the License.
62184 * =============================================================================
62185 */
62186 function fusedConvAndDepthWiseParams(node, tensorMap, context) {
62187 const [extraOp, activationFunc] = getParamValue('fusedOps', node, tensorMap, context);
62188 const isBiasAdd = extraOp === 'biasadd';
62189 const noBiasAdd = !isBiasAdd;
62190 const isPrelu = activationFunc === 'prelu';
62191 const isBatchNorm = extraOp === 'fusedbatchnorm';
62192 const numArgs = getParamValue('numArgs', node, tensorMap, context);
62193 if (isBiasAdd) {
62194 if (isPrelu && numArgs !== 2) {
62195 throw new Error('FusedConv2d and DepthwiseConv2d with BiasAdd and Prelu ' +
62196 'must have two extra arguments: bias and alpha.');
62197 }
62198 if (!isPrelu && isBiasAdd && numArgs !== 1) {
62199 throw new Error('FusedConv2d and DepthwiseConv2d with BiasAdd must have ' +
62200 'one extra argument: bias.');
62201 }
62202 }
62203 if (isBatchNorm) {
62204 throw new Error('FusedConv2d and DepthwiseConv2d with FusedBatchNorm is not supported');
62205 }
62206 const stride = getParamValue('strides', node, tensorMap, context);
62207 const pad = getPadding(node, tensorMap, context);
62208 const dataFormat = getParamValue('dataFormat', node, tensorMap, context)
62209 .toUpperCase();
62210 const dilations = getParamValue('dilations', node, tensorMap, context);
62211 let [biasArg, preluArg] = getParamValue('args', node, tensorMap, context);
62212 if (noBiasAdd) {
62213 preluArg = biasArg;
62214 biasArg = undefined;
62215 }
62216 const leakyreluAlpha = getParamValue('leakyreluAlpha', node, tensorMap, context);
62217 return {
62218 stride,
62219 pad,
62220 dataFormat,
62221 dilations,
62222 biasArg,
62223 preluArg,
62224 activationFunc,
62225 leakyreluAlpha
62226 };
62227 }
    /**
     * Executes a 'convolution' category graph node by dispatching on `node.op`
     * to the matching tfjs kernel (conv / fused conv / pooling / dilation).
     *
     * @param node Graph node; `node.op` selects the kernel.
     * @param tensorMap Map of tensor names to tensor values.
     * @param context Execution context used to resolve parameter values.
     * @returns Array of output tensors for the node.
     * @throws TypeError when `node.op` is not a supported convolution op.
     */
    const executeOp$3 = (node, tensorMap, context) => {
        switch (node.op) {
            case 'Conv1D': {
                const stride = getParamValue('stride', node, tensorMap, context);
                const pad = getParamValue('pad', node, tensorMap, context);
                const dataFormat = getParamValue('dataFormat', node, tensorMap, context)
                    .toUpperCase();
                const dilation = getParamValue('dilation', node, tensorMap, context);
                return [conv1d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), stride, pad, dataFormat, dilation)];
            }
            case 'Conv2D': {
                // The 4-D strides/dilations attributes are reduced to their
                // height/width components (indices 1 and 2) for the kernel call.
                const stride = getParamValue('strides', node, tensorMap, context);
                const pad = getPadding(node, tensorMap, context);
                const dataFormat = getParamValue('dataFormat', node, tensorMap, context)
                    .toUpperCase();
                const dilations = getParamValue('dilations', node, tensorMap, context);
                return [conv2d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [stride[1], stride[2]], pad, dataFormat, [dilations[1], dilations[2]])];
            }
            case '_FusedConv2D': {
                // Conv2D fused with bias-add and/or activation; parameters are
                // gathered and validated by fusedConvAndDepthWiseParams.
                const { stride, pad, dataFormat, dilations, biasArg, preluArg, activationFunc, leakyreluAlpha } = fusedConvAndDepthWiseParams(node, tensorMap, context);
                return [conv2d$1({
                    x: getParamValue('x', node, tensorMap, context),
                    filter: getParamValue('filter', node, tensorMap, context),
                    strides: [stride[1], stride[2]],
                    pad: pad,
                    dataFormat: dataFormat,
                    dilations: [dilations[1], dilations[2]],
                    bias: biasArg,
                    activation: activationFunc,
                    preluActivationWeights: preluArg,
                    leakyreluAlpha
                })];
            }
            case 'FusedDepthwiseConv2dNative': {
                const { stride, pad, dataFormat, dilations, biasArg, preluArg, activationFunc, leakyreluAlpha, } = fusedConvAndDepthWiseParams(node, tensorMap, context);
                return [depthwiseConv2d$1({
                    x: getParamValue('x', node, tensorMap, context),
                    filter: getParamValue('filter', node, tensorMap, context),
                    strides: [stride[1], stride[2]],
                    pad: pad,
                    dataFormat: dataFormat,
                    dilations: [dilations[1], dilations[2]],
                    bias: biasArg,
                    activation: activationFunc,
                    preluActivationWeights: preluArg,
                    leakyreluAlpha
                })];
            }
            case 'Conv2DBackpropInput':
            case 'Conv2dTranspose': {
                const shape = getParamValue('outputShape', node, tensorMap, context);
                const stride = getParamValue('strides', node, tensorMap, context);
                const pad = getPadding(node, tensorMap, context);
                return [conv2dTranspose(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), shape, [stride[1], stride[2]], pad)];
            }
            case 'DepthwiseConv2dNative':
            case 'DepthwiseConv2d': {
                const stride = getParamValue('strides', node, tensorMap, context);
                const pad = getPadding(node, tensorMap, context);
                const dilations = getParamValue('dilations', node, tensorMap, context);
                const dataFormat = getParamValue('dataFormat', node, tensorMap, context)
                    .toUpperCase();
                return [depthwiseConv2d(getParamValue('input', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [stride[1], stride[2]], pad, dataFormat, [dilations[1], dilations[2]])];
            }
            case 'Conv3D': {
                // 5-D strides/dilations: indices 1..3 are the spatial components.
                const stride = getParamValue('strides', node, tensorMap, context);
                const pad = getParamValue('pad', node, tensorMap, context);
                const dataFormat = getParamValue('dataFormat', node, tensorMap, context)
                    .toUpperCase();
                const dilations = getParamValue('dilations', node, tensorMap, context);
                return [conv3d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [stride[1], stride[2], stride[3]], pad, dataFormat, [dilations[1], dilations[2], dilations[3]])];
            }
            case 'AvgPool': {
                const stride = getParamValue('strides', node, tensorMap, context);
                const pad = getParamValue('pad', node, tensorMap, context);
                const kernelSize = getParamValue('kernelSize', node, tensorMap, context);
                return [avgPool(getParamValue('x', node, tensorMap, context), [kernelSize[1], kernelSize[2]], [stride[1], stride[2]], pad)];
            }
            case 'MaxPool': {
                const stride = getParamValue('strides', node, tensorMap, context);
                const pad = getParamValue('pad', node, tensorMap, context);
                const kernelSize = getParamValue('kernelSize', node, tensorMap, context);
                return [maxPool(getParamValue('x', node, tensorMap, context), [kernelSize[1], kernelSize[2]], [stride[1], stride[2]], pad)];
            }
            case 'MaxPoolWithArgmax': {
                // Produces two outputs: the pooled values and the argmax indices.
                const stride = getParamValue('strides', node, tensorMap, context);
                const pad = getParamValue('pad', node, tensorMap, context);
                const kernelSize = getParamValue('kernelSize', node, tensorMap, context);
                const includeBatchInIndex = getParamValue('includeBatchInIndex', node, tensorMap, context);
                const { result, indexes } = maxPoolWithArgmax(getParamValue('x', node, tensorMap, context), [kernelSize[1], kernelSize[2]], [stride[1], stride[2]], pad, includeBatchInIndex);
                return [result, indexes];
            }
            case 'AvgPool3D': {
                const stride = getParamValue('strides', node, tensorMap, context);
                const pad = getParamValue('pad', node, tensorMap, context);
                const kernelSize = getParamValue('kernelSize', node, tensorMap, context);
                return [avgPool3d(getParamValue('x', node, tensorMap, context), [kernelSize[1], kernelSize[2], kernelSize[3]], [stride[1], stride[2], stride[3]], pad)];
            }
            case 'MaxPool3D': {
                const stride = getParamValue('strides', node, tensorMap, context);
                const pad = getParamValue('pad', node, tensorMap, context);
                const kernelSize = getParamValue('kernelSize', node, tensorMap, context);
                return [maxPool3d(getParamValue('x', node, tensorMap, context), [kernelSize[1], kernelSize[2], kernelSize[3]], [stride[1], stride[2], stride[3]], pad)];
            }
            case 'Dilation2D': {
                const strides = getParamValue('strides', node, tensorMap, context);
                const pad = getParamValue('pad', node, tensorMap, context);
                const dilations = getParamValue('dilations', node, tensorMap, context);
                // strides: [1, stride_height, stride_width, 1].
                const strideHeight = strides[1];
                const strideWidth = strides[2];
                // dilations: [1, dilation_height, dilation_width, 1].
                const dilationHeight = dilations[1];
                const dilationWidth = dilations[2];
                return [dilation2d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [strideHeight, strideWidth], pad, [dilationHeight, dilationWidth], 'NHWC' /* dataFormat */)];
            }
            default:
                throw TypeError(`Node type ${node.op} is not implemented`);
        }
    };
    // Op-mapper category name for the convolution executor defined above.
    const CATEGORY$3 = 'convolution';
62349
62350 /**
62351 * @license
62352 * Copyright 2018 Google LLC. All Rights Reserved.
62353 * Licensed under the Apache License, Version 2.0 (the "License");
62354 * you may not use this file except in compliance with the License.
62355 * You may obtain a copy of the License at
62356 *
62357 * http://www.apache.org/licenses/LICENSE-2.0
62358 *
62359 * Unless required by applicable law or agreed to in writing, software
62360 * distributed under the License is distributed on an "AS IS" BASIS,
62361 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
62362 * See the License for the specific language governing permissions and
62363 * limitations under the License.
62364 * =============================================================================
62365 */
    /**
     * Executes a 'creation' category graph node (tensor factory ops such as
     * Fill, Range, OneHot, random generators) by dispatching on `node.op`.
     *
     * @param node Graph node; `node.op` selects the factory op.
     * @param tensorMap Map of tensor names to tensor values.
     * @param context Execution context used to resolve parameter values.
     * @returns Array containing the newly created tensor.
     * @throws TypeError when `node.op` is not a supported creation op.
     */
    const executeOp$4 = (node, tensorMap, context) => {
        switch (node.op) {
            case 'Fill': {
                const shape = getParamValue('shape', node, tensorMap, context);
                const dtype = getParamValue('dtype', node, tensorMap, context);
                const value = getParamValue('value', node, tensorMap, context);
                return [fill(shape, value, dtype)];
            }
            case 'LinSpace': {
                const start = getParamValue('start', node, tensorMap, context);
                const stop = getParamValue('stop', node, tensorMap, context);
                const num = getParamValue('num', node, tensorMap, context);
                return [linspace(start, stop, num)];
            }
            case 'Multinomial': {
                const logits = getParamValue('logits', node, tensorMap, context);
                const numSamples = getParamValue('numSamples', node, tensorMap, context);
                const seed = getParamValue('seed', node, tensorMap, context);
                return [multinomial(logits, numSamples, seed)];
            }
            case 'OneHot': {
                const indices = getParamValue('indices', node, tensorMap, context);
                const depth = getParamValue('depth', node, tensorMap, context);
                const onValue = getParamValue('onValue', node, tensorMap, context);
                const offValue = getParamValue('offValue', node, tensorMap, context);
                return [oneHot(indices, depth, onValue, offValue)];
            }
            case 'Ones': {
                return [ones$1(getParamValue('shape', node, tensorMap, context), getParamValue('dtype', node, tensorMap, context))];
            }
            case 'OnesLike': {
                return [onesLike(getParamValue('x', node, tensorMap, context))];
            }
            case 'RandomUniform': {
                return [randomUniform(
                // tslint:disable-next-line:no-any
                getParamValue('shape', node, tensorMap, context), getParamValue('minval', node, tensorMap, context), getParamValue('maxval', node, tensorMap, context), getParamValue('dtype', node, tensorMap, context))];
            }
            case 'Range': {
                const start = getParamValue('start', node, tensorMap, context);
                const stop = getParamValue('stop', node, tensorMap, context);
                const step = getParamValue('step', node, tensorMap, context);
                return [range(start, stop, step, getParamValue('dtype', node, tensorMap, context))];
            }
            case 'TruncatedNormal': {
                const shape = getParamValue('shape', node, tensorMap, context);
                const mean = getParamValue('mean', node, tensorMap, context);
                const stdDev = getParamValue('stdDev', node, tensorMap, context);
                const seed = getParamValue('seed', node, tensorMap, context);
                return [truncatedNormal(shape, mean, stdDev, getParamValue('dtype', node, tensorMap, context), seed)];
            }
            case 'Zeros': {
                return [zeros(getParamValue('shape', node, tensorMap, context), getParamValue('dtype', node, tensorMap, context))];
            }
            case 'ZerosLike': {
                return [zerosLike(getParamValue('x', node, tensorMap, context))];
            }
            default:
                throw TypeError(`Node type ${node.op} is not implemented`);
        }
    };
    // Op-mapper category name for the creation executor defined above.
    const CATEGORY$4 = 'creation';
62428
62429 /**
62430 * @license
62431 * Copyright 2018 Google LLC. All Rights Reserved.
62432 * Licensed under the Apache License, Version 2.0 (the "License");
62433 * you may not use this file except in compliance with the License.
62434 * You may obtain a copy of the License at
62435 *
62436 * http://www.apache.org/licenses/LICENSE-2.0
62437 *
62438 * Unless required by applicable law or agreed to in writing, software
62439 * distributed under the License is distributed on an "AS IS" BASIS,
62440 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
62441 * See the License for the specific language governing permissions and
62442 * limitations under the License.
62443 * =============================================================================
62444 */
62445 function nmsParams(node, tensorMap, context) {
62446 const boxes = getParamValue('boxes', node, tensorMap, context);
62447 const scores = getParamValue('scores', node, tensorMap, context);
62448 const maxOutputSize = getParamValue('maxOutputSize', node, tensorMap, context);
62449 const iouThreshold = getParamValue('iouThreshold', node, tensorMap, context);
62450 const scoreThreshold = getParamValue('scoreThreshold', node, tensorMap, context);
62451 const softNmsSigma = getParamValue('softNmsSigma', node, tensorMap, context);
62452 return {
62453 boxes,
62454 scores,
62455 maxOutputSize,
62456 iouThreshold,
62457 scoreThreshold,
62458 softNmsSigma
62459 };
62460 }
    /**
     * Executes a 'dynamic' category graph node — ops whose output shape is
     * data-dependent and therefore must run asynchronously (NonMaxSuppression
     * variants, Where, ListDiff).
     *
     * @param node Graph node; `node.op` selects the kernel.
     * @param tensorMap Map of tensor names to tensor values.
     * @param context Execution context used to resolve parameter values.
     * @returns Promise resolving to the array of output tensors.
     * @throws TypeError when `node.op` is not a supported dynamic op.
     */
    const executeOp$5 = async (node, tensorMap, context) => {
        switch (node.op) {
            case 'NonMaxSuppressionV5': {
                // V5 adds soft-NMS and returns the selected scores as well.
                const { boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma } = nmsParams(node, tensorMap, context);
                const result = await image.nonMaxSuppressionWithScoreAsync(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
                return [result.selectedIndices, result.selectedScores];
            }
            case 'NonMaxSuppressionV4': {
                // V4 supports padding the output to maxOutputSize.
                const { boxes, scores, maxOutputSize, iouThreshold, scoreThreshold } = nmsParams(node, tensorMap, context);
                const padToMaxOutputSize = getParamValue('padToMaxOutputSize', node, tensorMap, context);
                const result = await image.nonMaxSuppressionPaddedAsync(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize);
                return [result.selectedIndices, result.validOutputs];
            }
            case 'NonMaxSuppressionV3':
            case 'NonMaxSuppressionV2': {
                const { boxes, scores, maxOutputSize, iouThreshold, scoreThreshold } = nmsParams(node, tensorMap, context);
                return [await image.nonMaxSuppressionAsync(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold)];
            }
            case 'Where': {
                // The cast-to-bool tensor is an intermediate owned here, so it is
                // disposed once the async where() has consumed it.
                const condition = cast(getParamValue('condition', node, tensorMap, context), 'bool');
                const result = [await whereAsync(condition)];
                condition.dispose();
                return result;
            }
            case 'ListDiff': {
                return setdiff1dAsync(getParamValue('x', node, tensorMap, context), getParamValue('y', node, tensorMap, context));
            }
            default:
                throw TypeError(`Node type ${node.op} is not implemented`);
        }
    };
    // Op-mapper category name for the dynamic (async) executor defined above.
    const CATEGORY$5 = 'dynamic';
62493
62494 /**
62495 * @license
62496 * Copyright 2018 Google LLC. All Rights Reserved.
62497 * Licensed under the Apache License, Version 2.0 (the "License");
62498 * you may not use this file except in compliance with the License.
62499 * You may obtain a copy of the License at
62500 *
62501 * http://www.apache.org/licenses/LICENSE-2.0
62502 *
62503 * Unless required by applicable law or agreed to in writing, software
62504 * distributed under the License is distributed on an "AS IS" BASIS,
62505 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
62506 * See the License for the specific language governing permissions and
62507 * limitations under the License.
62508 * =============================================================================
62509 */
62510 const executeOp$6 = (node, tensorMap, context) => {
62511 switch (node.op) {
62512 case 'LowerBound': {
62513 const sortedSequence = getParamValue('sortedSequence', node, tensorMap, context);
62514 const values = getParamValue('values', node, tensorMap, context);
62515 return [lowerBound(sortedSequence, values)];
62516 }
62517 case 'TopKV2': {
62518 const x = getParamValue('x', node, tensorMap, context);
62519 const k = getParamValue('k', node, tensorMap, context);
62520 const sorted = getParamValue('sorted', node, tensorMap, context);
62521 const result = topk(x, k, sorted);
62522 return [result.values, result.indices];
62523 }
62524 case 'UpperBound': {
62525 const sortedSequence = getParamValue('sortedSequence', node, tensorMap, context);
62526 const values = getParamValue('values', node, tensorMap, context);
62527 return [upperBound(sortedSequence, values)];
62528 }
62529 case 'Unique': {
62530 const x = getParamValue('x', node, tensorMap, context);
62531 const result = unique(x);
62532 return [result.values, result.indices];
62533 }
62534 case 'UniqueV2': {
62535 const x = getParamValue('x', node, tensorMap, context);
62536 const axis = getParamValue('axis', node, tensorMap, context);
62537 const result = unique(x, axis);
62538 return [result.values, result.indices];
62539 }
62540 default:
62541 throw TypeError(`Node type ${node.op} is not implemented`);
62542 }
62543 };
    // Op-mapper category name for the evaluation executor defined above.
    const CATEGORY$6 = 'evaluation';
62545
62546 /**
62547 * @license
62548 * Copyright 2018 Google LLC. All Rights Reserved.
62549 * Licensed under the Apache License, Version 2.0 (the "License");
62550 * you may not use this file except in compliance with the License.
62551 * You may obtain a copy of the License at
62552 *
62553 * http://www.apache.org/licenses/LICENSE-2.0
62554 *
62555 * Unless required by applicable law or agreed to in writing, software
62556 * distributed under the License is distributed on an "AS IS" BASIS,
62557 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
62558 * See the License for the specific language governing permissions and
62559 * limitations under the License.
62560 * =============================================================================
62561 */
62562 const executeOp$7 = (node, tensorMap, context) => {
62563 switch (node.op) {
62564 case 'Const': {
62565 return tensorMap[node.name];
62566 }
62567 case 'PlaceholderWithDefault':
62568 const def = getParamValue('default', node, tensorMap, context);
62569 return [getTensor(node.name, tensorMap, context) || def];
62570 case 'Placeholder':
62571 return [getTensor(node.name, tensorMap, context)];
62572 case 'Identity':
62573 case 'StopGradient':
62574 case 'FakeQuantWithMinMaxVars': { // This op is currently ignored.
62575 const data = getParamValue('x', node, tensorMap, context);
62576 return [cloneTensor(data)];
62577 }
62578 case 'IdentityN':
62579 return getParamValue('x', node, tensorMap, context)
62580 .map((t) => cloneTensor(t));
62581 case 'Snapshot':
62582 const snapshot = getParamValue('x', node, tensorMap, context);
62583 return [cloneTensor(snapshot)];
62584 case 'Shape':
62585 return [tensor1d(getParamValue('x', node, tensorMap, context).shape, 'int32')];
62586 case 'ShapeN':
62587 return getParamValue('x', node, tensorMap, context)
62588 .map((t) => tensor1d(t.shape));
62589 case 'Size':
62590 return [scalar(getParamValue('x', node, tensorMap, context).size, 'int32')];
62591 case 'Rank':
62592 return [scalar(getParamValue('x', node, tensorMap, context).rank, 'int32')];
62593 case 'NoOp':
62594 return [scalar(1)];
62595 case 'Print':
62596 const input = getParamValue('x', node, tensorMap, context);
62597 const data = getParamValue('data', node, tensorMap, context);
62598 const message = getParamValue('message', node, tensorMap, context);
62599 const summarize = getParamValue('summarize', node, tensorMap, context);
62600 console.warn('The graph has a tf.print() operation,' +
62601 'usually used for debugging, which slows down performance.');
62602 console.log(message);
62603 for (let i = 0; i < data.length; i++) {
62604 console.log(Array.prototype.slice.call(data[i].dataSync())
62605 .slice(0, summarize));
62606 }
62607 return [input];
62608 default:
62609 throw TypeError(`Node type ${node.op} is not implemented`);
62610 }
62611 };
    // Op-mapper category name for the graph executor defined above.
    const CATEGORY$7 = 'graph';
62613
62614 /**
62615 * @license
62616 * Copyright 2020 Google LLC. All Rights Reserved.
62617 * Licensed under the Apache License, Version 2.0 (the "License");
62618 * you may not use this file except in compliance with the License.
62619 * You may obtain a copy of the License at
62620 *
62621 * http://www.apache.org/licenses/LICENSE-2.0
62622 *
62623 * Unless required by applicable law or agreed to in writing, software
62624 * distributed under the License is distributed on an "AS IS" BASIS,
62625 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
62626 * See the License for the specific language governing permissions and
62627 * limitations under the License.
62628 * =============================================================================
62629 */
/**
 * Hashtable contains a set of tensors, which can be accessed by key.
 *
 * Keys are stored as primitive values (the result of `keys.data()`), so
 * lookups are plain O(1) Map accesses rather than tensor comparisons.
 */
class HashTable {
    /**
     * Constructor of HashTable. Creates a hash table.
     *
     * @param keyDType `dtype` of the table keys.
     * @param valueDType `dtype` of the table values.
     */
    constructor(keyDType, valueDType) {
        this.keyDType = keyDType;
        this.valueDType = valueDType;
        // Rank-0 tensor whose id serves as the unique handle for this table.
        this.handle = scalar(0);
        // tslint:disable-next-line: no-any
        this.tensorMap = new Map();
        // Pin the handle so enclosing tidy() scopes cannot dispose it; it is
        // released explicitly in clearAndClose().
        keep(this.handle);
    }
    /** Unique id of this table (the id of its handle tensor). */
    get id() {
        return this.handle.id;
    }
    /**
     * Dispose the tensors and handle and clear the hashtable.
     */
    clearAndClose() {
        this.tensorMap.forEach(value => value.dispose());
        this.tensorMap.clear();
        this.handle.dispose();
    }
    /**
     * The number of items in the hash table.
     */
    size() {
        return this.tensorMap.size;
    }
    /**
     * The number of items in the hash table as a rank-0 tensor.
     */
    tensorSize() {
        return scalar(this.size(), 'int32');
    }
    /**
     * Replaces the contents of the table with the specified keys and values.
     * @param keys Keys to store in the hashtable.
     * @param values Values to store in the hashtable.
     * @returns The handle tensor of this table.
     * @throws Error if the key/value dtypes do not match the table's dtypes,
     *     or if the number of keys and values differ.
     */
    async import(keys, values) {
        this.checkKeyAndValueTensor(keys, values);
        // We only store the primitive values of the keys, this allows lookup
        // to be O(1).
        const $keys = await keys.data();
        // Clear the hashTable before inserting new values.
        this.tensorMap.forEach(value => value.dispose());
        this.tensorMap.clear();
        return tidy(() => {
            const $values = unstack(values);
            const keysLength = $keys.length;
            const valuesLength = $values.length;
            assert(keysLength === valuesLength, () => `The number of elements doesn't match, keys has ` +
                `${keysLength} elements, the values has ${valuesLength} ` +
                `elements.`);
            for (let i = 0; i < keysLength; i++) {
                const key = $keys[i];
                const value = $values[i];
                // keep() prevents tidy() from disposing the stored value when
                // this callback returns.
                keep(value);
                this.tensorMap.set(key, value);
            }
            return this.handle;
        });
    }
    /**
     * Looks up keys in a hash table, outputs the corresponding values.
     *
     * Performs batch lookups, for every element in the key tensor, `find`
     * stacks the corresponding value into the return tensor.
     *
     * If an element is not present in the table, the given `defaultValue` is
     * used.
     *
     * @param keys Keys to look up. Must have the same type as the keys of the
     *     table.
     * @param defaultValue The scalar `defaultValue` is the value output for keys
     *     not present in the table. It must also be of the same type as the
     *     table values.
     */
    async find(keys, defaultValue) {
        this.checkKeyAndValueTensor(keys, defaultValue);
        const $keys = await keys.data();
        return tidy(() => {
            const result = [];
            for (let i = 0; i < $keys.length; i++) {
                const key = $keys[i];
                const value = this.findWithDefault(key, defaultValue);
                result.push(value);
            }
            // One stacked tensor with one row/entry per looked-up key.
            return stack(result);
        });
    }
    // Returns the stored value for `key`, or `defaultValue` when absent.
    // tslint:disable-next-line: no-any
    findWithDefault(key, defaultValue) {
        const result = this.tensorMap.get(key);
        return result != null ? result : defaultValue;
    }
    /**
     * Validates that `key` and `value` dtypes match this table's configured
     * key/value dtypes; throws Error on mismatch.
     */
    checkKeyAndValueTensor(key, value) {
        if (key.dtype !== this.keyDType) {
            throw new Error(`Expect key dtype ${this.keyDType}, but got ` +
                `${key.dtype}`);
        }
        if (value.dtype !== this.valueDType) {
            throw new Error(`Expect value dtype ${this.valueDType}, but got ` +
                `${value.dtype}`);
        }
    }
}
62744
62745 /**
62746 * @license
62747 * Copyright 2020 Google LLC. All Rights Reserved.
62748 * Licensed under the Apache License, Version 2.0 (the "License");
62749 * you may not use this file except in compliance with the License.
62750 * You may obtain a copy of the License at
62751 *
62752 * http://www.apache.org/licenses/LICENSE-2.0
62753 *
62754 * Unless required by applicable law or agreed to in writing, software
62755 * distributed under the License is distributed on an "AS IS" BASIS,
62756 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
62757 * See the License for the specific language governing permissions and
62758 * limitations under the License.
62759 * =============================================================================
62760 */
/**
 * Executes hash-table category ops (table creation, import, find, size).
 *
 * @param node Graph node to execute.
 * @param tensorMap Map of node names to their output tensors.
 * @param context Execution context holding frame/iteration state.
 * @param resourceManager Owns the HashTable instances, addressed by handle id.
 * @returns Array of output tensors produced by the node.
 * @throws TypeError if the node's op is not handled by this executor.
 */
const executeOp$8 = async (node, tensorMap, context, resourceManager) => {
    // Shorthand for reading a named parameter of this node.
    const param = (name) => getParamValue(name, node, tensorMap, context);
    // Resolves the HashTable referenced by the node's `tableHandle` param.
    const lookupTable = () => {
        const handle = getParamValue('tableHandle', node, tensorMap, context, resourceManager);
        return resourceManager.getHashTableById(handle.id);
    };
    const op = node.op;
    if (op === 'HashTable' || op === 'HashTableV2') {
        const table = new HashTable(param('keyDType'), param('valueDType'));
        // Register the table so later Lookup* nodes can resolve it by name.
        resourceManager.addHashTable(node.name, table);
        return [table.handle];
    }
    if (op === 'LookupTableImport' || op === 'LookupTableImportV2') {
        const table = lookupTable();
        return [await table.import(param('keys'), param('values'))];
    }
    if (op === 'LookupTableFind' || op === 'LookupTableFindV2') {
        const table = lookupTable();
        return [await table.find(param('keys'), param('defaultValue'))];
    }
    if (op === 'LookupTableSize' || op === 'LookupTableSizeV2') {
        return [lookupTable().tensorSize()];
    }
    throw TypeError(`Node type ${node.op} is not implemented`);
};
const CATEGORY$8 = 'hash_table';
62798
62799 /**
62800 * @license
62801 * Copyright 2018 Google LLC. All Rights Reserved.
62802 * Licensed under the Apache License, Version 2.0 (the "License");
62803 * you may not use this file except in compliance with the License.
62804 * You may obtain a copy of the License at
62805 *
62806 * http://www.apache.org/licenses/LICENSE-2.0
62807 *
62808 * Unless required by applicable law or agreed to in writing, software
62809 * distributed under the License is distributed on an "AS IS" BASIS,
62810 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
62811 * See the License for the specific language governing permissions and
62812 * limitations under the License.
62813 * =============================================================================
62814 */
/**
 * Executes image-category ops (resize, crop-and-resize, projective transform).
 *
 * @param node Graph node to execute.
 * @param tensorMap Map of node names to their output tensors.
 * @param context Execution context holding frame/iteration state.
 * @returns Array of output tensors produced by the node.
 * @throws TypeError if the node's op is not handled by this executor.
 */
const executeOp$9 = (node, tensorMap, context) => {
    // Shorthand for reading a named parameter of this node.
    const param = (name) => getParamValue(name, node, tensorMap, context);
    switch (node.op) {
        case 'ResizeBilinear': {
            const images = param('images');
            const [height, width] = param('size');
            return [image.resizeBilinear(images, [height, width], param('alignCorners'), param('halfPixelCenters'))];
        }
        case 'ResizeNearestNeighbor': {
            const images = param('images');
            const [height, width] = param('size');
            return [image.resizeNearestNeighbor(images, [height, width], param('alignCorners'), param('halfPixelCenters'))];
        }
        case 'CropAndResize': {
            // `img` avoids shadowing the `image` ops namespace.
            const img = param('image');
            return [image.cropAndResize(img, param('boxes'), param('boxInd'), param('cropSize'), param('method'), param('extrapolationValue'))];
        }
        case 'ImageProjectiveTransformV3': {
            const images = param('images');
            const transforms = param('transforms');
            const outputShape = param('outputShape');
            const fillValue = param('fillValue');
            // The kernel expects lowercase interpolation/fill-mode names.
            const interpolation = param('interpolation').toLowerCase();
            const fillMode = param('fillMode').toLowerCase();
            return [image.transform(images, transforms, interpolation, fillMode, fillValue, outputShape)];
        }
        default:
            throw TypeError(`Node type ${node.op} is not implemented`);
    }
};
const CATEGORY$9 = 'image';
62854
62855 /**
62856 * @license
62857 * Copyright 2018 Google LLC. All Rights Reserved.
62858 * Licensed under the Apache License, Version 2.0 (the "License");
62859 * you may not use this file except in compliance with the License.
62860 * You may obtain a copy of the License at
62861 *
62862 * http://www.apache.org/licenses/LICENSE-2.0
62863 *
62864 * Unless required by applicable law or agreed to in writing, software
62865 * distributed under the License is distributed on an "AS IS" BASIS,
62866 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
62867 * See the License for the specific language governing permissions and
62868 * limitations under the License.
62869 * =============================================================================
62870 */
/**
 * Executes logical/comparison ops (Equal, Greater, LogicalAnd, Select, ...).
 *
 * @param node Graph node to execute.
 * @param tensorMap Map of node names to their output tensors.
 * @param context Execution context holding frame/iteration state.
 * @returns Array of output tensors produced by the node.
 * @throws TypeError if the node's op is not handled by this executor.
 */
const executeOp$a = (node, tensorMap, context) => {
    // Shorthand for reading a named parameter of this node.
    const param = (name) => getParamValue(name, node, tensorMap, context);
    // All binary comparison/logic ops share the (a, b) operand pattern.
    const binary = (fn) => [fn(param('a'), param('b'))];
    switch (node.op) {
        case 'Equal':
            return binary(equal);
        case 'NotEqual':
            return binary(notEqual);
        case 'Greater':
            return binary(greater);
        case 'GreaterEqual':
            return binary(greaterEqual);
        case 'Less':
            return binary(less);
        case 'LessEqual':
            return binary(lessEqual);
        case 'LogicalAnd':
            return binary(logicalAnd);
        case 'LogicalNot':
            // Unary: operates on `a` only.
            return [logicalNot(param('a'))];
        case 'LogicalOr':
            return binary(logicalOr);
        case 'Select':
        case 'SelectV2':
            return [where(param('condition'), param('a'), param('b'))];
        default:
            throw TypeError(`Node type ${node.op} is not implemented`);
    }
};
const CATEGORY$a = 'logical';
62909
62910 /**
62911 * @license
62912 * Copyright 2018 Google LLC. All Rights Reserved.
62913 * Licensed under the Apache License, Version 2.0 (the "License");
62914 * you may not use this file except in compliance with the License.
62915 * You may obtain a copy of the License at
62916 *
62917 * http://www.apache.org/licenses/LICENSE-2.0
62918 *
62919 * Unless required by applicable law or agreed to in writing, software
62920 * distributed under the License is distributed on an "AS IS" BASIS,
62921 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
62922 * See the License for the specific language governing permissions and
62923 * limitations under the License.
62924 * =============================================================================
62925 */
/**
 * Executes matrix-category ops (MatMul variants, Einsum, Transpose, fused
 * matmul).
 *
 * @param node Graph node to execute.
 * @param tensorMap Map of node names to their output tensors.
 * @param context Execution context holding frame/iteration state.
 * @returns Array of output tensors produced by the node.
 * @throws TypeError if the node's op is not handled by this executor.
 */
const executeOp$b = (node, tensorMap, context) => {
    // Shorthand for reading a named parameter of this node.
    const param = (name) => getParamValue(name, node, tensorMap, context);
    switch (node.op) {
        case 'BatchMatMul':
        case 'BatchMatMulV2':
        case 'MatMul':
            return [matMul(param('a'), param('b'), param('transposeA'), param('transposeB'))];
        case 'Einsum':
            return [einsum(param('equation'), ...param('tensors'))];
        case 'Transpose':
            return [transpose(param('x'), param('perm'))];
        case '_FusedMatMul': {
            // `fusedOps` describes the fusion: an optional bias-add followed
            // by an activation function.
            const [extraOp, activationFunc] = param('fusedOps');
            const hasBiasAdd = extraOp === 'biasadd';
            const hasPrelu = activationFunc === 'prelu';
            const numArgs = param('numArgs');
            const leakyreluAlpha = param('leakyreluAlpha');
            // Validate the number of extra tensor args against the fusion.
            if (hasBiasAdd) {
                if (hasPrelu && numArgs !== 2) {
                    throw new Error('Fused MatMul with BiasAdd and Prelu must have two ' +
                        'extra arguments: bias and alpha.');
                }
                if (!hasPrelu && numArgs !== 1) {
                    throw new Error('Fused MatMul with BiasAdd must have one extra argument: bias.');
                }
            }
            const [biasArg, preluArg] = param('args');
            return [matMul$1({
                a: param('a'),
                b: param('b'),
                transposeA: param('transposeA'),
                transposeB: param('transposeB'),
                bias: biasArg,
                activation: activationFunc,
                preluActivationWeights: preluArg,
                leakyreluAlpha
            })];
        }
        default:
            throw TypeError(`Node type ${node.op} is not implemented`);
    }
};
const CATEGORY$b = 'matrices';
62967
62968 /**
62969 * @license
62970 * Copyright 2018 Google LLC. All Rights Reserved.
62971 * Licensed under the Apache License, Version 2.0 (the "License");
62972 * you may not use this file except in compliance with the License.
62973 * You may obtain a copy of the License at
62974 *
62975 * http://www.apache.org/licenses/LICENSE-2.0
62976 *
62977 * Unless required by applicable law or agreed to in writing, software
62978 * distributed under the License is distributed on an "AS IS" BASIS,
62979 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
62980 * See the License for the specific language governing permissions and
62981 * limitations under the License.
62982 * =============================================================================
62983 */
/**
 * Executes normalization-category ops (batch norm, LRN, softmax, ...).
 *
 * @param node Graph node to execute.
 * @param tensorMap Map of node names to their output tensors.
 * @param context Execution context holding frame/iteration state.
 * @returns Array of output tensors produced by the node.
 * @throws TypeError if the node's op is not handled by this executor.
 */
const executeOp$c = (node, tensorMap, context) => {
    switch (node.op) {
        case 'EuclideanNorm':
            return [euclideanNorm(getParamValue('x', node, tensorMap, context), getParamValue('axis', node, tensorMap, context), getParamValue('keepDims', node, tensorMap, context))];
        case 'FusedBatchNorm':
        case 'FusedBatchNormV2': {
            return [batchNorm(getParamValue('x', node, tensorMap, context), getParamValue('mean', node, tensorMap, context), getParamValue('variance', node, tensorMap, context), getParamValue('offset', node, tensorMap, context), getParamValue('scale', node, tensorMap, context), getParamValue('epsilon', node, tensorMap, context))];
        }
        case 'FusedBatchNormV3': {
            return [batchNorm(getParamValue('x', node, tensorMap, context), getParamValue('mean', node, tensorMap, context), getParamValue('variance', node, tensorMap, context), getParamValue('offset', node, tensorMap, context), getParamValue('scale', node, tensorMap, context), getParamValue('epsilon', node, tensorMap, context))];
        }
        case 'LRN': {
            return [localResponseNormalization(getParamValue('x', node, tensorMap, context), getParamValue('radius', node, tensorMap, context), getParamValue('bias', node, tensorMap, context), getParamValue('alpha', node, tensorMap, context), getParamValue('beta', node, tensorMap, context))];
        }
        case 'Softmax': {
            return [softmax(getParamValue('x', node, tensorMap, context))];
        }
        case 'LogSoftmax': {
            return [logSoftmax(getParamValue('x', node, tensorMap, context))];
        }
        case 'SparseToDense': {
            // BUG FIX: sparseToDense expects (sparseIndices, sparseValues,
            // outputShape, defaultValue); the values and shape arguments were
            // previously swapped here. The slice_join executor in this file
            // calls it with the correct ordering.
            return [sparseToDense(getParamValue('sparseIndices', node, tensorMap, context), getParamValue('sparseValues', node, tensorMap, context), getParamValue('outputShape', node, tensorMap, context), getParamValue('defaultValue', node, tensorMap, context))];
        }
        default:
            throw TypeError(`Node type ${node.op} is not implemented`);
    }
};
const CATEGORY$c = 'normalization';
63012
63013 /**
63014 * @license
63015 * Copyright 2018 Google LLC. All Rights Reserved.
63016 * Licensed under the Apache License, Version 2.0 (the "License");
63017 * you may not use this file except in compliance with the License.
63018 * You may obtain a copy of the License at
63019 *
63020 * http://www.apache.org/licenses/LICENSE-2.0
63021 *
63022 * Unless required by applicable law or agreed to in writing, software
63023 * distributed under the License is distributed on an "AS IS" BASIS,
63024 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
63025 * See the License for the specific language governing permissions and
63026 * limitations under the License.
63027 * =============================================================================
63028 */
/**
 * Executes reduction-category ops (Max, Mean, Sum, ArgMax, Cumsum, Bincount,
 * ...).
 *
 * @param node Graph node to execute.
 * @param tensorMap Map of node names to their output tensors.
 * @param context Execution context holding frame/iteration state.
 * @returns Array of output tensors produced by the node.
 * @throws TypeError if the node's op is not handled by this executor.
 */
const executeOp$d = (node, tensorMap, context) => {
    // Shorthand for reading a named parameter of this node.
    const param = (name) => getParamValue(name, node, tensorMap, context);
    // Shared pattern for (x, axis, keepDims) reductions. Parameter read order
    // (axis, keepDims, then x) matches the original implementation.
    const reduceWith = (fn) => {
        const axis = param('axis');
        const keepDims = param('keepDims');
        return [fn(param('x'), axis, keepDims)];
    };
    // Shared pattern for (x, axis, exclusive, reverse) cumulative scans.
    const scanWith = (fn) => {
        const axis = param('axis');
        const exclusive = param('exclusive');
        const reverse = param('reverse');
        return [fn(param('x'), axis, exclusive, reverse)];
    };
    switch (node.op) {
        case 'Max':
            return reduceWith(max);
        case 'Mean':
            return reduceWith(mean);
        case 'Min':
            return reduceWith(min);
        case 'Sum':
            return reduceWith(sum$1);
        case 'All':
            return reduceWith(all);
        case 'Any':
            return reduceWith(any);
        case 'ArgMax': {
            const axis = param('axis');
            return [argMax(param('x'), axis)];
        }
        case 'ArgMin': {
            const axis = param('axis');
            return [argMin(param('x'), axis)];
        }
        case 'Prod':
            return reduceWith(prod);
        case 'Cumprod':
            return scanWith(cumprod);
        case 'Cumsum':
            return scanWith(cumsum);
        case 'Bincount': {
            const values = param('x');
            const weights = param('weights');
            const size = param('size');
            return [bincount(values, weights, size)];
        }
        case 'DenseBincount': {
            const values = param('x');
            const weights = param('weights');
            const size = param('size');
            const binaryOutput = param('binaryOutput');
            return [denseBincount(values, weights, size, binaryOutput)];
        }
        default:
            throw TypeError(`Node type ${node.op} is not implemented`);
    }
};
const CATEGORY$d = 'reduction';
63103
63104 /**
63105 * @license
63106 * Copyright 2018 Google LLC. All Rights Reserved.
63107 * Licensed under the Apache License, Version 2.0 (the "License");
63108 * you may not use this file except in compliance with the License.
63109 * You may obtain a copy of the License at
63110 *
63111 * http://www.apache.org/licenses/LICENSE-2.0
63112 *
63113 * Unless required by applicable law or agreed to in writing, software
63114 * distributed under the License is distributed on an "AS IS" BASIS,
63115 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
63116 * See the License for the specific language governing permissions and
63117 * limitations under the License.
63118 * =============================================================================
63119 */
/**
 * Executes slice/join-category ops (concat, gather, slice, pack, split,
 * scatter, sparse-to-dense, ...).
 *
 * @param node Graph node to execute.
 * @param tensorMap Map of node names to their output tensors.
 * @param context Execution context holding frame/iteration state.
 * @returns Array of output tensors produced by the node.
 * @throws TypeError if the node's op is not handled by this executor.
 */
const executeOp$e = (node, tensorMap, context) => {
    switch (node.op) {
        case 'ConcatV2':
        case 'Concat': {
            // `n` caps how many of the provided tensors participate.
            const n = getParamValue('n', node, tensorMap, context);
            const axis = getParamValue('axis', node, tensorMap, context);
            let inputs = getParamValue('tensors', node, tensorMap, context);
            inputs = inputs.slice(0, n);
            return [concat(inputs, axis)];
        }
        case 'Gather': {
            const input = getParamValue('x', node, tensorMap, context);
            const indices = getParamValue('indices', node, tensorMap, context);
            // Plain Gather always gathers along axis 0; indices must be int32.
            return [gather(input, cast(indices, 'int32'), 0)];
        }
        case 'GatherV2': {
            const axis = getParamValue('axis', node, tensorMap, context);
            const batchDims = getParamValue('batchDims', node, tensorMap, context);
            const input = getParamValue('x', node, tensorMap, context);
            const indices = getParamValue('indices', node, tensorMap, context);
            return [gather(input, cast(indices, 'int32'), axis, batchDims)];
        }
        case 'Reverse': {
            // `dims` is a boolean mask; convert it to a list of axis indices.
            const dims = getParamValue('dims', node, tensorMap, context);
            const axis = [];
            for (let i = 0; i < dims.length; i++) {
                if (dims[i]) {
                    axis.push(i);
                }
            }
            const input = getParamValue('x', node, tensorMap, context);
            return [reverse(input, axis)];
        }
        case 'ReverseV2': {
            const axis = getParamValue('axis', node, tensorMap, context);
            const input = getParamValue('x', node, tensorMap, context);
            return [reverse(input, axis)];
        }
        case 'Slice': {
            // tslint:disable-next-line:no-any
            const begin = getParamValue('begin', node, tensorMap, context);
            // tslint:disable-next-line:no-any
            const size = getParamValue('size', node, tensorMap, context);
            return [slice(getParamValue('x', node, tensorMap, context), begin, size)];
        }
        case 'StridedSlice': {
            const begin = getParamValue('begin', node, tensorMap, context);
            const end = getParamValue('end', node, tensorMap, context);
            const strides = getParamValue('strides', node, tensorMap, context);
            // The mask params are the TensorFlow StridedSlice bitmasks.
            const beginMask = getParamValue('beginMask', node, tensorMap, context);
            const endMask = getParamValue('endMask', node, tensorMap, context);
            const ellipsisMask = getParamValue('ellipsisMask', node, tensorMap, context);
            const newAxisMask = getParamValue('newAxisMask', node, tensorMap, context);
            const shrinkAxisMask = getParamValue('shrinkAxisMask', node, tensorMap, context);
            const tensor = getParamValue('x', node, tensorMap, context);
            return [stridedSlice(tensor, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask)];
        }
        case 'Pack': {
            return tidy(() => {
                const axis = getParamValue('axis', node, tensorMap, context);
                const tensors = getParamValue('tensors', node, tensorMap, context);
                // Reshape the tensors to the first tensor's shape if they don't
                // match.
                const shape = tensors[0].shape;
                const squeezedShape = squeeze(tensors[0]).shape;
                const mapped = tensors.map(tensor => {
                    const sameShape = arraysEqual(tensor.shape, shape);
                    // If even the squeezed shapes differ, a reshape cannot
                    // reconcile the tensors, so reject the input.
                    if (!sameShape &&
                        !arraysEqual(squeeze(tensor).shape, squeezedShape)) {
                        throw new Error('the input tensors shape does not match');
                    }
                    return sameShape ? tensor : reshape(tensor, shape);
                });
                return [stack(mapped, axis)];
            });
        }
        case 'Unpack': {
            const axis = getParamValue('axis', node, tensorMap, context);
            const tensor = getParamValue('tensor', node, tensorMap, context);
            // unstack returns one tensor per slice along `axis`.
            return unstack(tensor, axis);
        }
        case 'Tile': {
            const reps = getParamValue('reps', node, tensorMap, context);
            return [tile(getParamValue('x', node, tensorMap, context), reps)];
        }
        case 'Split':
        case 'SplitV': {
            const axis = getParamValue('axis', node, tensorMap, context);
            const numOrSizeSplits = getParamValue('numOrSizeSplits', node, tensorMap, context);
            const tensor = getParamValue('x', node, tensorMap, context);
            // split already returns an array of output tensors.
            return split(tensor, numOrSizeSplits, axis);
        }
        case 'ScatterNd': {
            const indices = getParamValue('indices', node, tensorMap, context);
            const values = getParamValue('values', node, tensorMap, context);
            const shape = getParamValue('shape', node, tensorMap, context);
            return [scatterND(indices, values, shape)];
        }
        case 'GatherNd': {
            const x = getParamValue('x', node, tensorMap, context);
            const indices = getParamValue('indices', node, tensorMap, context);
            return [gatherND(x, indices)];
        }
        case 'SparseToDense': {
            const indices = getParamValue('sparseIndices', node, tensorMap, context);
            const shape = getParamValue('outputShape', node, tensorMap, context);
            const sparseValues = getParamValue('sparseValues', node, tensorMap, context);
            const defaultValue = getParamValue('defaultValue', node, tensorMap, context);
            // defaultValue must share the values' dtype; cast only if needed.
            return [sparseToDense(indices, sparseValues, shape, sparseValues.dtype === defaultValue.dtype ?
                defaultValue :
                cast(defaultValue, sparseValues.dtype))];
        }
        default:
            throw TypeError(`Node type ${node.op} is not implemented`);
    }
};
const CATEGORY$e = 'slice_join';
63237
63238 /**
63239 * @license
63240 * Copyright 2021 Google LLC. All Rights Reserved.
63241 * Licensed under the Apache License, Version 2.0 (the "License");
63242 * you may not use this file except in compliance with the License.
63243 * You may obtain a copy of the License at
63244 *
63245 * http://www.apache.org/licenses/LICENSE-2.0
63246 *
63247 * Unless required by applicable law or agreed to in writing, software
63248 * distributed under the License is distributed on an "AS IS" BASIS,
63249 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
63250 * See the License for the specific language governing permissions and
63251 * limitations under the License.
63252 * =============================================================================
63253 */
/**
 * Executes sparse-category ops (fill empty rows, reshape, segment mean/sum).
 *
 * @param node Graph node to execute.
 * @param tensorMap Map of node names to their output tensors.
 * @param context Execution context holding frame/iteration state.
 * @returns Array of output tensors produced by the node.
 * @throws TypeError if the node's op is not handled by this executor.
 */
const executeOp$f = (node, tensorMap, context) => {
    // Shorthand for reading a named parameter of this node.
    const param = (name) => getParamValue(name, node, tensorMap, context);
    switch (node.op) {
        case 'SparseFillEmptyRows': {
            const result = sparse.sparseFillEmptyRows(param('indices'), param('values'), param('denseShape'), param('defaultValue'));
            return [
                result.outputIndices, result.outputValues,
                result.emptyRowIndicator, result.reverseIndexMap
            ];
        }
        case 'SparseReshape': {
            const result = sparse.sparseReshape(param('inputIndices'), param('inputShape'), param('newShape'));
            return [result.outputIndices, result.outputShape];
        }
        case 'SparseSegmentMean':
            return [sparse.sparseSegmentMean(param('data'), param('indices'), param('segmentIds'))];
        case 'SparseSegmentSum':
            return [sparse.sparseSegmentSum(param('data'), param('indices'), param('segmentIds'))];
        default:
            throw TypeError(`Node type ${node.op} is not implemented`);
    }
};
const CATEGORY$f = 'sparse';
63279
63280 /**
63281 * @license
63282 * Copyright 2018 Google LLC. All Rights Reserved.
63283 * Licensed under the Apache License, Version 2.0 (the "License");
63284 * you may not use this file except in compliance with the License.
63285 * You may obtain a copy of the License at
63286 *
63287 * http://www.apache.org/licenses/LICENSE-2.0
63288 *
63289 * Unless required by applicable law or agreed to in writing, software
63290 * distributed under the License is distributed on an "AS IS" BASIS,
63291 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
63292 * See the License for the specific language governing permissions and
63293 * limitations under the License.
63294 * =============================================================================
63295 */
/**
 * Executes spectral-category ops (FFT family). All are unary ops on `x`.
 *
 * @param node Graph node to execute.
 * @param tensorMap Map of node names to their output tensors.
 * @param context Execution context holding frame/iteration state.
 * @returns Array with the single output tensor of the node.
 * @throws TypeError if the node's op is not handled by this executor.
 */
const executeOp$g = (node, tensorMap, context) => {
    // Every spectral op takes exactly one tensor input, `x`.
    const input = () => getParamValue('x', node, tensorMap, context);
    switch (node.op) {
        case 'FFT':
            return [fft(input())];
        case 'IFFT':
            return [ifft(input())];
        case 'RFFT':
            return [rfft(input())];
        case 'IRFFT':
            return [irfft(input())];
        default:
            throw TypeError(`Node type ${node.op} is not implemented`);
    }
};
const CATEGORY$g = 'spectral';
63315
63316 /**
63317 * @license
63318 * Copyright 2021 Google LLC. All Rights Reserved.
63319 * Licensed under the Apache License, Version 2.0 (the "License");
63320 * you may not use this file except in compliance with the License.
63321 * You may obtain a copy of the License at
63322 *
63323 * http://www.apache.org/licenses/LICENSE-2.0
63324 *
63325 * Unless required by applicable law or agreed to in writing, software
63326 * distributed under the License is distributed on an "AS IS" BASIS,
63327 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
63328 * See the License for the specific language governing permissions and
63329 * limitations under the License.
63330 * =============================================================================
63331 */
/**
 * Executes string-category ops (n-grams, split, hash bucketing).
 *
 * @param node Graph node to execute.
 * @param tensorMap Map of node names to their output tensors.
 * @param context Execution context holding frame/iteration state.
 * @returns Array of output tensors produced by the node.
 * @throws TypeError if the node's op is not handled by this executor.
 */
const executeOp$h = (node, tensorMap, context) => {
    // Shorthand for reading a named parameter of this node.
    const param = (name) => getParamValue(name, node, tensorMap, context);
    switch (node.op) {
        case 'StringNGrams': {
            const result = string.stringNGrams(param('data'), param('dataSplits'), param('separator'), param('nGramWidths'), param('leftPad'), param('rightPad'), param('padWidth'), param('preserveShortSequences'));
            return [result.nGrams, result.nGramsSplits];
        }
        case 'StringSplit': {
            const result = string.stringSplit(param('input'), param('delimiter'), param('skipEmpty'));
            return [result.indices, result.values, result.shape];
        }
        case 'StringToHashBucketFast':
            return [string.stringToHashBucketFast(param('input'), param('numBuckets'))];
        default:
            throw TypeError(`Node type ${node.op} is not implemented`);
    }
};
const CATEGORY$h = 'string';
63351
63352 /**
63353 * @license
63354 * Copyright 2018 Google LLC. All Rights Reserved.
63355 * Licensed under the Apache License, Version 2.0 (the "License");
63356 * you may not use this file except in compliance with the License.
63357 * You may obtain a copy of the License at
63358 *
63359 * http://www.apache.org/licenses/LICENSE-2.0
63360 *
63361 * Unless required by applicable law or agreed to in writing, software
63362 * distributed under the License is distributed on an "AS IS" BASIS,
63363 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
63364 * See the License for the specific language governing permissions and
63365 * limitations under the License.
63366 * =============================================================================
63367 */
63368 const executeOp$i = (node, tensorMap, context) => {
63369 switch (node.op) {
63370 case 'Cast': {
63371 return [cast(getParamValue('x', node, tensorMap, context), getParamValue('dtype', node, tensorMap, context))];
63372 }
63373 case 'ExpandDims': {
63374 const axis = getParamValue('axis', node, tensorMap, context);
63375 return [expandDims(getParamValue('x', node, tensorMap, context), axis)];
63376 }
63377 case 'Squeeze': {
63378 const axis = getParamValue('axis', node, tensorMap, context);
63379 return [squeeze(getParamValue('x', node, tensorMap, context), axis)];
63380 }
63381 case 'Reshape': {
63382 return [reshape(getParamValue('x', node, tensorMap, context), getParamValue('shape', node, tensorMap, context))];
63383 }
63384 case 'MirrorPad': {
63385 return [mirrorPad(getParamValue('x', node, tensorMap, context), getParamValue('padding', node, tensorMap, context), getParamValue('mode', node, tensorMap, context))];
63386 }
63387 case 'PadV2':
63388 case 'Pad': {
63389 return [pad(getParamValue('x', node, tensorMap, context), getParamValue('padding', node, tensorMap, context), getParamValue('constantValue', node, tensorMap, context))];
63390 }
63391 case 'SpaceToBatchND': {
63392 const blockShape = getParamValue('blockShape', node, tensorMap, context);
63393 const paddings = getParamValue('paddings', node, tensorMap, context);
63394 return [spaceToBatchND(getParamValue('x', node, tensorMap, context), blockShape, paddings)];
63395 }
63396 case 'BatchToSpaceND': {
63397 const blockShape = getParamValue('blockShape', node, tensorMap, context);
63398 const crops = getParamValue('crops', node, tensorMap, context);
63399 return [batchToSpaceND(getParamValue('x', node, tensorMap, context), blockShape, crops)];
63400 }
63401 case 'DepthToSpace': {
63402 const blockSize = getParamValue('blockSize', node, tensorMap, context);
63403 const dataFormat = getParamValue('dataFormat', node, tensorMap, context).toUpperCase();
63404 return [depthToSpace(getParamValue('x', node, tensorMap, context), blockSize, dataFormat)];
63405 }
63406 case 'BroadcastTo': {
63407 return [broadcastTo(getParamValue('x', node, tensorMap, context), getParamValue('shape', node, tensorMap, context))];
63408 }
63409 case 'BroadcastArgs': {
63410 return [broadcastArgs(getParamValue('s0', node, tensorMap, context), getParamValue('s1', node, tensorMap, context))];
63411 }
63412 default:
63413 throw TypeError(`Node type ${node.op} is not implemented`);
63414 }
63415 };
63416 const CATEGORY$i = 'transformation';
63417
63418 /**
63419 * @license
63420 * Copyright 2018 Google LLC. All Rights Reserved.
63421 * Licensed under the Apache License, Version 2.0 (the "License");
63422 * you may not use this file except in compliance with the License.
63423 * You may obtain a copy of the License at
63424 *
63425 * http://www.apache.org/licenses/LICENSE-2.0
63426 *
63427 * Unless required by applicable law or agreed to in writing, software
63428 * distributed under the License is distributed on an "AS IS" BASIS,
63429 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
63430 * See the License for the specific language governing permissions and
63431 * limitations under the License.
63432 * =============================================================================
63433 */
63434 /**
63435 * Executes the op defined by the node object.
63436 * @param node
63437 * @param tensorMap contains tensors for executed nodes and weights
63438 * @param context contains tensors and information for running the current node.
63439 * @param resourceManager Optional. Contains global resources of the model.
63440 */
63441 function executeOp$j(node, tensorMap, context, resourceManager) {
63442 const value = ((node, tensorMap, context) => {
63443 switch (node.category) {
63444 case 'arithmetic':
63445 return tidy(() => executeOp(node, tensorMap, context));
63446 case 'basic_math':
63447 return tidy(() => executeOp$1(node, tensorMap, context));
63448 case 'control':
63449 return executeOp$2(node, tensorMap, context);
63450 case 'convolution':
63451 return tidy(() => executeOp$3(node, tensorMap, context));
63452 case 'creation':
63453 return tidy(() => executeOp$4(node, tensorMap, context));
63454 case 'dynamic':
63455 return executeOp$5(node, tensorMap, context);
63456 case 'evaluation':
63457 return tidy(() => executeOp$6(node, tensorMap, context));
63458 case 'image':
63459 return tidy(() => executeOp$9(node, tensorMap, context));
63460 case 'graph':
63461 return tidy(() => executeOp$7(node, tensorMap, context));
63462 case 'logical':
63463 return tidy(() => executeOp$a(node, tensorMap, context));
63464 case 'matrices':
63465 return tidy(() => executeOp$b(node, tensorMap, context));
63466 case 'normalization':
63467 return tidy(() => executeOp$c(node, tensorMap, context));
63468 case 'reduction':
63469 return tidy(() => executeOp$d(node, tensorMap, context));
63470 case 'slice_join':
63471 return tidy(() => executeOp$e(node, tensorMap, context));
63472 case 'sparse':
63473 return tidy(() => executeOp$f(node, tensorMap, context));
63474 case 'spectral':
63475 return tidy(() => executeOp$g(node, tensorMap, context));
63476 case 'string':
63477 return tidy(() => executeOp$h(node, tensorMap, context));
63478 case 'transformation':
63479 return tidy(() => executeOp$i(node, tensorMap, context));
63480 case 'hash_table':
63481 return executeOp$8(node, tensorMap, context, resourceManager);
63482 case 'custom':
63483 const opMapper = getRegisteredOp(node.op);
63484 if (opMapper && opMapper.customExecutor) {
63485 return opMapper.customExecutor(new NodeValueImpl(node, tensorMap, context));
63486 }
63487 else {
63488 throw TypeError(`Custom op ${node.op} is not registered.`);
63489 }
63490 default:
63491 throw TypeError(`Unknown op '${node.op}'. File an issue at ` +
63492 `https://github.com/tensorflow/tfjs/issues so we can add it` +
63493 `, or register a custom execution with tf.registerOp()`);
63494 }
63495 })(node, tensorMap, context);
63496 if (isPromise(value)) {
63497 return value.then((data) => [].concat(data));
63498 }
63499 return [].concat(value);
63500 }
63501
63502 /**
63503 * ExecutionContext captures the runtime environment of the node. It keeps
63504 * track of the current frame and iteration for the control flow ops.
63505 *
63506 * For example, typical Dynamic RNN model may contain loops, for which
63507 * TensorFlow will generate graphs with Enter/Exit nodes to control the
63508 * current execution frame, and NextIteration Nodes for iteration id increment.
63509 * For model with branch logic, TensorFLow will generate Switch/Merge ops.
63510 */
63511 class ExecutionContext {
63512 constructor(weightMap = {}, tensorArrayMap = {}, tensorListMap = {}, functionMap = {}) {
63513 this.weightMap = weightMap;
63514 this.tensorArrayMap = tensorArrayMap;
63515 this.tensorListMap = tensorListMap;
63516 this.functionMap = functionMap;
63517 this.rootContext = { id: 0, frameName: '', iterationId: 0 };
63518 this.contexts = [this.rootContext];
63519 this.lastId = 0;
63520 this.generateCurrentContextIds();
63521 }
63522 newFrame(id, frameName) {
63523 return { id, frameName, iterationId: 0 };
63524 }
63525 /**
63526 * Set the current context
63527 * @param contexts: ExecutionContextInfo[] the current path of execution
63528 * frames
63529 */
63530 set currentContext(contexts) {
63531 if (this.contexts !== contexts) {
63532 this.contexts = contexts;
63533 this.generateCurrentContextIds();
63534 }
63535 }
63536 get currentContext() {
63537 return this.contexts;
63538 }
63539 /**
63540 * Returns the current context in string format.
63541 */
63542 get currentContextId() {
63543 return this._currentContextIds[0];
63544 }
63545 /**
63546 * Returns the current context and all parent contexts in string format.
63547 * This allow access to the nodes in the current and parent frames.
63548 */
63549 get currentContextIds() {
63550 return this._currentContextIds;
63551 }
63552 generateCurrentContextIds() {
63553 const names = [];
63554 for (let i = 0; i < this.contexts.length - 1; i++) {
63555 const contexts = this.contexts.slice(0, this.contexts.length - i);
63556 names.push(this.contextIdforContexts(contexts));
63557 }
63558 names.push('');
63559 this._currentContextIds = names;
63560 }
63561 contextIdforContexts(contexts) {
63562 return contexts ?
63563 contexts
63564 .map(context => (context.id === 0 && context.iterationId === 0) ?
63565 '' :
63566 `${context.frameName}-${context.iterationId}`)
63567 .join('/') :
63568 '';
63569 }
63570 /**
63571 * Enter a new frame, a new context is pushed on the current context list.
63572 * @param frameId new frame id
63573 */
63574 enterFrame(frameId) {
63575 if (this.contexts) {
63576 this.lastId++;
63577 this.contexts = this.contexts.slice();
63578 this.contexts.push(this.newFrame(this.lastId, frameId));
63579 this._currentContextIds.unshift(this.contextIdforContexts(this.contexts));
63580 }
63581 }
63582 /**
63583 * Exit the current frame, the last context is removed from the current
63584 * context list.
63585 */
63586 exitFrame() {
63587 if (this.contexts && this.contexts.length > 1) {
63588 this.contexts = this.contexts.slice();
63589 this.contexts.splice(-1);
63590 this.currentContextIds.shift();
63591 }
63592 else {
63593 throw new Error('Cannot exit frame, the context is empty');
63594 }
63595 }
63596 /**
63597 * Enter the next iteration of a loop, the iteration id of last context is
63598 * increased.
63599 */
63600 nextIteration() {
63601 if (this.contexts && this.contexts.length > 0) {
63602 this.contexts = this.contexts.slice();
63603 this.lastId++;
63604 const context = Object.assign({}, this.contexts[this.contexts.length - 1]);
63605 context.iterationId += 1;
63606 context.id = this.lastId;
63607 this.contexts.splice(-1, 1, context);
63608 this._currentContextIds.splice(0, 1, this.contextIdforContexts(this.contexts));
63609 }
63610 else {
63611 throw new Error('Cannot increase frame iteration, the context is empty');
63612 }
63613 }
63614 getWeight(name) {
63615 return this.weightMap[name];
63616 }
63617 addTensorArray(tensorArray) {
63618 this.tensorArrayMap[tensorArray.id] = tensorArray;
63619 }
63620 getTensorArray(id) {
63621 return this.tensorArrayMap[id];
63622 }
63623 addTensorList(tensorList) {
63624 this.tensorListMap[tensorList.id] = tensorList;
63625 }
63626 getTensorList(id) {
63627 return this.tensorListMap[id];
63628 }
63629 dispose(keepIds) {
63630 for (const key in this.tensorArrayMap) {
63631 this.tensorArrayMap[key].clearAndClose(keepIds);
63632 }
63633 for (const key in this.tensorListMap) {
63634 this.tensorListMap[key].clearAndClose(keepIds);
63635 }
63636 }
63637 }
63638
63639 /**
63640 * @license
63641 * Copyright 2019 Google LLC. All Rights Reserved.
63642 * Licensed under the Apache License, Version 2.0 (the "License");
63643 * you may not use this file except in compliance with the License.
63644 * You may obtain a copy of the License at
63645 *
63646 * http://www.apache.org/licenses/LICENSE-2.0
63647 *
63648 * Unless required by applicable law or agreed to in writing, software
63649 * distributed under the License is distributed on an "AS IS" BASIS,
63650 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
63651 * See the License for the specific language governing permissions and
63652 * limitations under the License.
63653 * =============================================================================
63654 */
63655 /**
63656 * Given graph inputs and desired outputs, find the minimal set of nodes
63657 * to execute in order to compute the outputs. In addition return other useful
63658 * info such:
63659 * - Missing inputs needed to compute the output.
63660 * - Whether the subgraph contains dynamic ops (control flow, dynamic shape).
63661 * - Alternative inputs in order to avoid async (dynamic op) execution.
63662 */
63663 function getExecutionSubgraph(inputs, outputs, weightMap, initNodes) {
63664 const usedNodes = new Set();
63665 const missingInputs = [];
63666 let dynamicNode = null;
63667 let syncInputs = null;
63668 // Start with the outputs, going backwards and find all the nodes that are
63669 // needed to compute those outputs.
63670 const seen = new Set();
63671 const inputNodeNames = Object.keys(inputs).map(name => parseNodeName(name)[0]);
63672 let initNodeNames = [];
63673 if (initNodes != null) {
63674 initNodeNames = initNodes.map(node => parseNodeName(node.name)[0]);
63675 }
63676 const frontier = [...outputs];
63677 while (frontier.length > 0) {
63678 const node = frontier.pop();
63679 if (isControlFlow(node) || isDynamicShape(node) || isHashTable(node)) {
63680 if (dynamicNode == null) {
63681 dynamicNode = node;
63682 syncInputs = dynamicNode.children.map(child => child.name)
63683 .filter(name => usedNodes.has(name));
63684 }
63685 }
63686 usedNodes.add(node.name);
63687 // Weights are dead end since we already have their values.
63688 if (weightMap[node.name] != null) {
63689 continue;
63690 }
63691 // This node is a dead end since it's one of the user-provided inputs.
63692 if (inputNodeNames.indexOf(node.name) !== -1) {
63693 continue;
63694 }
63695 // This node is a dead end since it doesn't have any inputs.
63696 if (initNodeNames.indexOf(node.name) !== -1) {
63697 continue;
63698 }
63699 if (node.inputs.length === 0) {
63700 missingInputs.push(node.name);
63701 continue;
63702 }
63703 node.inputs.forEach(input => {
63704 // Don't add to the frontier if it is already there.
63705 if (seen.has(input.name)) {
63706 return;
63707 }
63708 seen.add(input.name);
63709 frontier.push(input);
63710 });
63711 }
63712 return { inputs, outputs, usedNodes, missingInputs, dynamicNode, syncInputs };
63713 }
63714 /**
63715 * Given the execution info, return a list of nodes in topological order that
63716 * need to be executed to compute the output.
63717 */
63718 function getNodesInTopologicalOrder(graph, weightMap, executionInfo) {
63719 const { usedNodes, inputs } = executionInfo;
63720 const frontier = [];
63721 const inputNodes = Object.keys(inputs)
63722 .map(name => parseNodeName(name)[0])
63723 .map(name => graph.nodes[name]);
63724 const initNodes = graph.initNodes;
63725 inputNodes.forEach(input => {
63726 if (usedNodes.has(input.name)) {
63727 frontier.push(input);
63728 }
63729 });
63730 graph.weights.forEach(weight => {
63731 if (usedNodes.has(weight.name)) {
63732 frontier.push(weight);
63733 }
63734 });
63735 if (initNodes != null) {
63736 initNodes.forEach(node => {
63737 if (usedNodes.has(node.name)) {
63738 frontier.push(node);
63739 }
63740 });
63741 }
63742 const seen = new Set();
63743 const orderedNodes = [];
63744 while (frontier.length > 0) {
63745 const node = frontier.pop();
63746 seen.add(node.name);
63747 if (!weightMap[node.name]) {
63748 orderedNodes.push(node);
63749 }
63750 node.children.forEach(child => {
63751 if (!seen.has(child.name) && usedNodes.has(child.name) &&
63752 child.inputs.every(input => seen.has(input.name))) {
63753 frontier.push(child);
63754 }
63755 });
63756 }
63757 return orderedNodes;
63758 }
63759 const CONTROL_FLOW_OPS = [
63760 'Switch', 'Merge', 'Enter', 'Exit', 'NextIteration', 'StatelessIf',
63761 'StatelessWhile', 'if', 'While'
63762 ];
63763 const DYNAMIC_SHAPE_OPS = [
63764 'NonMaxSuppressionV2', 'NonMaxSuppressionV3', 'NonMaxSuppressionV5', 'Where'
63765 ];
63766 const HASH_TABLE_OPS = [
63767 'HashTable', 'HashTableV2', 'LookupTableImport', 'LookupTableImportV2',
63768 'LookupTableFind', 'LookupTableFindV2', 'LookupTableSize', 'LookupTableSizeV2'
63769 ];
63770 function isControlFlow(node) {
63771 return CONTROL_FLOW_OPS.indexOf(node.op) >= 0;
63772 }
63773 function isDynamicShape(node) {
63774 return DYNAMIC_SHAPE_OPS.indexOf(node.op) >= 0;
63775 }
63776 function isHashTable(node) {
63777 return HASH_TABLE_OPS.indexOf(node.op) >= 0;
63778 }
63779
63780 /**
63781 * @license
63782 * Copyright 2018 Google LLC. All Rights Reserved.
63783 * Licensed under the Apache License, Version 2.0 (the "License");
63784 * you may not use this file except in compliance with the License.
63785 * You may obtain a copy of the License at
63786 *
63787 * http://www.apache.org/licenses/LICENSE-2.0
63788 *
63789 * Unless required by applicable law or agreed to in writing, software
63790 * distributed under the License is distributed on an "AS IS" BASIS,
63791 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
63792 * See the License for the specific language governing permissions and
63793 * limitations under the License.
63794 * =============================================================================
63795 */
63796 class GraphExecutor {
        /**
         * @param graph Graph the model or function graph to be executed.
         * @param parent When building a function executor you need to set the
         * parent executor. Since the weight map and function executor map are kept
         * at the parent (root) level, a function executor reaches them through
         * the parent (see the weightMap/functionExecutorMap getters).
         */
        constructor(graph, parent) {
            this.graph = graph;
            this.parent = parent;
            // Cache of compiled (topologically ordered) node lists, keyed by the
            // input/output name combination from getCompilationKey().
            this.compiledMap = new Map();
            this._weightMap = {};
            // Separator used when building compilation cache keys.
            this.SEPERATOR = ',';
            this._functions = {};
            this._functionExecutorMap = {};
            // Tensors kept for debugging when KEEP_INTERMEDIATE_TENSORS is set.
            this.intermediateTensors = {};
            this.keepTensorForDebug = false;
            this._outputs = graph.outputs;
            this._inputs = graph.inputs;
            this._initNodes = graph.initNodes;
            this._signature = graph.signature;
            this._functions = graph.functions;
            // create sub-graph executors
            if (graph.functions != null) {
                Object.keys(graph.functions).forEach(name => {
                    this._functionExecutorMap[name] =
                        new GraphExecutor(graph.functions[name], this);
                });
            }
        }
63828 get weightIds() {
63829 return this.parent ? this.parent.weightIds : this._weightIds;
63830 }
63831 get functionExecutorMap() {
63832 return this.parent ? this.parent.functionExecutorMap :
63833 this._functionExecutorMap;
63834 }
63835 get weightMap() {
63836 return this.parent ? this.parent.weightMap : this._weightMap;
63837 }
63838 set weightMap(weightMap) {
63839 const weightIds = Object.keys(weightMap).map(key => weightMap[key].map(tensor => tensor.id));
63840 this._weightIds = [].concat(...weightIds);
63841 this._weightMap = weightMap;
63842 }
63843 /**
63844 * Set `ResourceManager` shared by executors of a model.
63845 * @param resourceManager: `ResourceManager` of the `GraphModel`.
63846 */
63847 set resourceManager(resourceManager) {
63848 this._resourceManager = resourceManager;
63849 }
63850 get inputs() {
63851 return this._inputs.map(node => {
63852 return {
63853 name: node.name,
63854 shape: node.attrParams['shape'] ?
63855 node.attrParams['shape'].value :
63856 undefined,
63857 dtype: node.attrParams['dtype'] ?
63858 node.attrParams['dtype'].value :
63859 undefined
63860 };
63861 });
63862 }
63863 get outputs() {
63864 return this._outputs.map(node => {
63865 return {
63866 name: node.name,
63867 shape: node.attrParams['shape'] ?
63868 node.attrParams['shape'].value :
63869 undefined,
63870 dtype: node.attrParams['dtype'] ?
63871 node.attrParams['dtype'].value :
63872 undefined
63873 };
63874 });
63875 }
63876 get inputNodes() {
63877 return this._inputs.map(node => node.signatureKey || node.name);
63878 }
63879 get outputNodes() {
63880 return this._outputs.map((node) => {
63881 const name = node.signatureKey || node.name;
63882 return node.defaultOutput ? (`${name}:${node.defaultOutput}`) : name;
63883 });
63884 }
63885 get functions() {
63886 return Object.keys(this._functions).reduce((map, key) => {
63887 map[key] = this._functions[key].signature;
63888 return map;
63889 }, {});
63890 }
63891 getCompilationKey(inputs, outputs) {
63892 const sortedInputs = inputs.map(node => node.name).sort();
63893 const sortedOutputs = outputs.map(node => node.name).sort();
63894 return sortedInputs.join(this.SEPERATOR) + '--' +
63895 sortedOutputs.join(this.SEPERATOR);
63896 }
63897 /**
63898 * Compiles the inference graph and returns the minimal set of nodes that are
63899 * required for execution, in the correct execution order.
63900 */
63901 compile(inputs, outputs) {
63902 const executionInfo = getExecutionSubgraph(inputs, outputs, this.weightMap, this._initNodes);
63903 const { missingInputs, dynamicNode, syncInputs } = executionInfo;
63904 if (dynamicNode != null) {
63905 throw new Error(`This execution contains the node '${dynamicNode.name}', which has ` +
63906 `the dynamic op '${dynamicNode.op}'. Please use ` +
63907 `model.executeAsync() instead. Alternatively, to avoid the ` +
63908 `dynamic ops, specify the inputs [${syncInputs}]`);
63909 }
63910 if (missingInputs.length > 0) {
63911 const outNames = outputs.map(n => n.name);
63912 const inNames = Object.keys(inputs);
63913 throw new Error(`Cannot compute the outputs [${outNames}] from the provided inputs ` +
63914 `[${inNames}]. Missing the following inputs: [${missingInputs}]`);
63915 }
63916 return getNodesInTopologicalOrder(this.graph, this.weightMap, executionInfo);
63917 }
        /**
         * Executes the inference for given input tensors.
         * @param inputs Tensor map for the model inputs, keyed by the input node
         * names.
         * @param outputs Optional. output node name from the Tensorflow model, if
         * no outputs are specified, the default outputs of the model would be used.
         * You can inspect intermediate nodes of the model by adding them to the
         * outputs array.
         * @returns Array of output tensors, one per requested output name.
         * @throws Error if the graph requires async execution (an op returns a
         * promise), or when input/output validation fails.
         */
        execute(inputs, outputs) {
            // Normalize names and validate the provided inputs/outputs.
            inputs = this.mapInputs(inputs);
            const names = Object.keys(inputs).sort();
            this.checkInputs(inputs);
            this.checkInputShapeAndType(inputs);
            outputs = this.mapOutputs(outputs);
            this.checkOutputs(outputs);
            const inputNodes = names.map(name => this.graph.nodes[parseNodeName(name)[0]]);
            const outputNodeNames = outputs.map(name => parseNodeName(name)[0]);
            let outputNodes = outputNodeNames.map(name => this.graph.nodes[name]);
            // Drop any debug tensors kept from a previous run.
            this.resetIntermediateTensors();
            // If no outputs are specified, then use the default outputs of the model.
            if (outputNodes.length === 0) {
                outputNodes = this._outputs;
            }
            const compilationKey = this.getCompilationKey(inputNodes, outputNodes);
            // Do nothing if the compiled graph cache contains the input.
            let orderedNodes = this.compiledMap.get(compilationKey);
            if (orderedNodes == null) {
                orderedNodes = this.compile(inputs, outputNodes);
                this.compiledMap.set(compilationKey, orderedNodes);
            }
            const tensorArrayMap = {};
            const tensorListMap = {};
            // tidy() disposes every tensor created below except the returned ones.
            return tidy(() => {
                const context = new ExecutionContext(this.weightMap, tensorArrayMap, tensorListMap, this.functionExecutorMap);
                // Seed the tensor map with the weights; each input tensor is
                // spliced into the [nodeName][outputIndex] slot parsed from its name.
                const tensorsMap = Object.assign({}, this.weightMap);
                Object.keys(inputs).forEach(name => {
                    const [nodeName, index] = parseNodeName(name);
                    const tensors = [];
                    tensors[index] = inputs[name];
                    tensorsMap[nodeName] = tensors;
                });
                // Tensors present before execution (weights + inputs) must never be
                // disposed by the intermediate-tensor bookkeeping below.
                const tensorsToKeep = this.getFrozenTensorIds(tensorsMap);
                const intermediateTensorConsumerCount = {};
                for (let i = 0; i < orderedNodes.length; i++) {
                    const node = orderedNodes[i];
                    // Skip nodes whose outputs are already available (inputs/weights).
                    if (!tensorsMap[node.name]) {
                        const tensors = executeOp$j(node, tensorsMap, context, this._resourceManager);
                        // execute() is synchronous only; async ops require executeAsync().
                        if (isPromise(tensors)) {
                            throw new Error(`The execution of the op '${node.op}' returned a promise. ` +
                                `Please use model.executeAsync() instead.`);
                        }
                        tensorsMap[node.name] = tensors;
                        // Eagerly dispose intermediates whose consumers have all run.
                        this.checkTensorForDisposal(node.name, node, tensorsMap, context, tensorsToKeep, outputNodeNames, intermediateTensorConsumerCount);
                    }
                }
                // dispose the context for the root executor
                if (this.parent == null) {
                    context.dispose(tensorsToKeep);
                }
                return outputs.map(name => getTensor(name, tensorsMap, context));
            });
        }
63981 getFrozenTensorIds(tensorMap) {
63982 const ids = [].concat.apply([], Object.keys(tensorMap)
63983 .map(key => tensorMap[key])
63984 .map(tensors => tensors.map(tensor => tensor.id)));
63985 return new Set(ids);
63986 }
        /**
         * Reference-count bookkeeping for eager disposal of intermediate tensors.
         * After `node` has executed, its outputs gain one pending consumer per
         * child; each of its input tensors has a consumer retired, and a tensor
         * whose last consumer just ran is disposed (or stashed in
         * `intermediateTensors` when keepTensorForDebug is set).
         * @param nodeName name of the node that just executed.
         * @param node the node that just executed.
         * @param tensorMap tensors for executed nodes and weights.
         * @param context current execution context (frame/iteration).
         * @param tensorsToKeep ids of tensors that must never be disposed here
         * (inputs and weights frozen before execution).
         * @param outputNames requested output node names; their tensors are kept.
         * @param intermediateTensorConsumerCount mutable map of tensor id ->
         * remaining consumer count; updated in place.
         */
        checkTensorForDisposal(nodeName, node, tensorMap, context, tensorsToKeep, outputNames, intermediateTensorConsumerCount) {
            // Skip output nodes and any control flow nodes, since its dependency is
            // tricky to track correctly.
            if (node.category === 'control' || outputNames.indexOf(nodeName) !== -1) {
                return;
            }
            // Each output tensor of this node gains one pending consumer per child.
            tensorMap[nodeName].forEach(tensor => {
                if (tensor != null) {
                    intermediateTensorConsumerCount[tensor.id] =
                        (intermediateTensorConsumerCount[tensor.id] || 0) +
                            node.children.length;
                }
            });
            node.inputs.forEach(input => {
                // Skip any control flow nodes, since its dependency is tricky to track
                // correctly.
                if (input.category !== 'control') {
                    const tensors = getTensorsForCurrentContenxt(input.name, tensorMap, context);
                    if (tensors != null) {
                        tensors.forEach(tensor => {
                            // Never dispose kept tensors or the frozen inputs/weights.
                            if (tensor && !tensor.kept && !tensorsToKeep.has(tensor.id)) {
                                const count = intermediateTensorConsumerCount[tensor.id];
                                if (count === 1) {
                                    // This was the last consumer: dispose the tensor,
                                    // or stash it when debugging is enabled.
                                    if (!this.keepTensorForDebug) {
                                        tensor.dispose();
                                    }
                                    else {
                                        const [nodeName, index] = getNodeNameAndIndex(node.name, context);
                                        if (this.intermediateTensors[nodeName]) {
                                            this.intermediateTensors[nodeName][index] = tensor;
                                        }
                                        else {
                                            this.intermediateTensors[nodeName] = [];
                                            this.intermediateTensors[nodeName][index] = tensor;
                                        }
                                    }
                                    delete intermediateTensorConsumerCount[tensor.id];
                                }
                                else if (count != null) {
                                    // only intermediate nodes has count set, inputs and weights are
                                    // not.
                                    intermediateTensorConsumerCount[tensor.id]--;
                                }
                            }
                        });
                    }
                }
            });
        }
64036 /**
64037 * Executes the inference for given input tensors in Async fashion.
64038 * @param inputs Tensor map for the model inputs, keyed by the input node
64039 * names.
64040 * @param outputs output node name from the Tensorflow model, if no outputs
64041 * are specified, the default outputs of the model would be used. You can
64042 * inspect intermediate nodes of the model by adding them to the outputs
64043 * array.
64044 */
64045 async executeAsync(inputs, outputs) {
64046 return this._executeAsync(inputs, outputs);
64047 }
64048 disposeIntermediateTensors() {
64049 if (!this.intermediateTensors) {
64050 return;
64051 }
64052 Object.keys(this.intermediateTensors)
64053 .forEach(key => this.intermediateTensors[key].forEach(tensor => tensor.dispose()));
64054 this.disposeTensorsMap();
64055 }
64056 disposeTensorsMap() {
64057 if (!this.tensorsMap) {
64058 return;
64059 }
64060 Object.keys(this.tensorsMap).forEach(key => {
64061 const tensorArray = this.tensorsMap[key];
64062 tensorArray.forEach(tensor => {
64063 if (tensor && !tensor.kept && !tensor.isDisposed &&
64064 !this.keepIds.has(tensor.id)) {
64065 tensor.dispose();
64066 }
64067 });
64068 });
64069 }
64070 getIntermediateTensors() {
64071 return this.tensorsMap;
64072 }
64073 resetIntermediateTensors() {
64074 for (const key in this.intermediateTensors) {
64075 this.intermediateTensors[key].forEach(tensor => tensor.dispose());
64076 delete this.intermediateTensors[key];
64077 }
64078 }
        /**
         * Executes the inference for given input tensors in Async fashion.
         * @param inputs Tensor map for the model inputs, keyed by the input node
         * names.
         * @param outputs Optional. output node name from the Tensorflow model,
         * if no outputs are specified, the default outputs of the model would be
         * used. You can inspect intermediate nodes of the model by adding them to the
         * outputs array.
         * @param isFunctionExecution Optional. Flag for executing a function.
         * @param tensorArrayMap Optional, global TensorArray map by id. Used for
         * function execution.
         * @param tensorListMap Optional, global TensorList map by id. Used for
         * function execution.
         * @returns Array of output tensors, one per requested output name.
         */
        async _executeAsync(inputs, outputs, isFunctionExecution = false, tensorArrayMap = {}, tensorListMap = {}) {
            // Function bodies receive already-mapped, already-validated names from
            // the caller, so validation only runs for top-level executions.
            if (!isFunctionExecution) {
                inputs = this.mapInputs(inputs);
                this.checkInputs(inputs);
                this.checkInputShapeAndType(inputs);
                outputs = this.mapOutputs(outputs);
                this.checkOutputs(outputs);
            }
            // For model debug.
            try {
                this.keepTensorForDebug = env().getBool('KEEP_INTERMEDIATE_TENSORS');
            }
            catch (e) {
                // If the flag cannot be read, keep the previous setting and warn.
                console.warn(e.message);
            }
            this.resetIntermediateTensors();
            const context = new ExecutionContext(this.weightMap, tensorArrayMap, tensorListMap, this.functionExecutorMap);
            // Graph with control flow op requires runtime evaluation of the execution
            // order, while without control flow the execution order is pre-determined
            // in the compile method.
            this.tensorsMap = await this.executeWithControlFlow(inputs, context, outputs, isFunctionExecution);
            const results = outputs.map(name => getTensor(name, this.tensorsMap, context));
            // dispose all the intermediate tensors: keep only the output tensors,
            // the caller-owned inputs, and the weights.
            const outputIds = results.map(t => t.id);
            const inputIds = Object.keys(inputs).map(name => inputs[name].id);
            this.keepIds =
                new Set([...outputIds, ...inputIds, ...this.weightIds]);
            if (!this.keepTensorForDebug) {
                this.disposeTensorsMap();
            }
            // dispose the context for the root executor
            if (this.parent == null) {
                context.dispose(this.keepIds);
            }
            return results;
        }
64129 async executeFunctionAsync(inputs, tensorArrayMap, tensorListMap) {
64130 const mappedInputs = inputs.reduce((map, tensor, index) => {
64131 map[this.inputs[index].name] = tensor;
64132 return map;
64133 }, {});
64134 return this._executeAsync(mappedInputs, this.outputNodes, true, tensorArrayMap, tensorListMap);
64135 }
/**
 * When there are control flow nodes in the graph, the graph execution use
 * ExecutionContext to keep track of the frames and loop iterators.
 * @param inputs placeholder tensors for the graph.
 * @param context the execution context object for current execution.
 * @param outputNames Optional. output node name from the Tensorflow model,
 * if no outputs are specified, the default outputs of the model would be
 * used. You can inspect intermediate nodes of the model by adding them to the
 * outputs array.
 * @param isFunctionExecution Flag for executing a function.
 * @returns A map from node name to the array of tensors it produced.
 */
async executeWithControlFlow(inputs, context, outputNames, isFunctionExecution) {
    const names = Object.keys(inputs);
    // Resolve the provided input/output names to their graph nodes.
    const inputNodes = names.map(name => this.graph.nodes[parseNodeName(name)[0]]);
    const outputNodeNames = outputNames.map(name => parseNodeName(name)[0]);
    let outputNodes = outputNodeNames.map(name => this.graph.nodes[name]);
    // If no outputs are specified, then use the default outputs of the model.
    if (outputNodes.length === 0) {
        outputNodes = this._outputs;
    }
    const { usedNodes, missingInputs, dynamicNode, syncInputs } = getExecutionSubgraph(inputs, outputNodes, this.weightMap, this._initNodes);
    // First nodes to execute include inputNodes, weights, and initNodes.
    const stack = [
        ...inputNodes, ...this.graph.weights, ...(this._initNodes || [])
    ].map(node => {
        return { node, contexts: context.currentContext };
    });
    // Seed the tensor map with the weights and the user-provided inputs.
    const tensorsMap = Object.assign({}, this.weightMap);
    Object.keys(inputs).forEach(name => {
        const [nodeName, index] = parseNodeName(name);
        const tensors = [];
        tensors[index] = inputs[name];
        tensorsMap[nodeName] = tensors;
    });
    // Remaining-consumer counts let intermediate tensors be disposed as
    // soon as their last consumer has executed.
    const intermediateTensorConsumerCount = {};
    const tensorsToKeep = this.getFrozenTensorIds(tensorsMap);
    const added = {};
    // Drain the stack; each pass may yield async ops whose promises must
    // resolve before newly-unblocked children are scheduled.
    while (stack.length > 0) {
        const promises = this.processStack(inputNodes, stack, context, tensorsMap, added, tensorsToKeep, outputNodeNames, intermediateTensorConsumerCount, usedNodes);
        await Promise.all(promises);
    }
    if (dynamicNode == null && !isFunctionExecution) {
        console.warn(`This model execution did not contain any nodes with control flow ` +
            `or dynamic output shapes. You can use model.execute() instead.`);
    }
    // Identify requested outputs that execution never produced.
    const missingOutputs = outputNodes
        .filter(node => !isControlFlow(node) &&
        !getTensor(node.name, tensorsMap, context))
        .map(node => node.name);
    if (missingOutputs.length > 0) {
        let alternativeMsg = '';
        if (dynamicNode != null) {
            alternativeMsg =
                `Alternatively, to avoid the dynamic ops, use model.execute() ` +
                `and specify the inputs [${syncInputs}]`;
        }
        throw new Error(`Cannot compute the outputs [${missingOutputs}] from the provided ` +
            `inputs [${names}]. Consider providing the following inputs: ` +
            `[${missingInputs}]. ${alternativeMsg}`);
    }
    return tensorsMap;
}
/**
 * Pops nodes off the execution stack, runs each op, and schedules child
 * nodes whose inputs have become available. Async ops contribute their
 * promises to the returned array so the caller can await them before the
 * next scheduling pass.
 */
processStack(inputNodes, stack, context, tensorMap, added, tensorsToKeep, outputNames, intermediateTensorConsumerCount, usedNodes) {
    const promises = [];
    while (stack.length > 0) {
        const item = stack.pop();
        context.currentContext = item.contexts;
        let nodeName = '';
        // The tensor of the Enter op with isConstant set should be set
        // in the parent scope, so it will be available as constant for the
        // whole loop.
        if (item.node.op === 'Enter' &&
            getParamValue('isConstant', item.node, tensorMap, context)) {
            [nodeName] = getNodeNameAndIndex(item.node.name, context);
        }
        // only process nodes that are not in the tensorMap yet, this include
        // inputNodes and internal initNodes.
        if (tensorMap[item.node.name] == null) {
            const tensors = executeOp$j(item.node, tensorMap, context, this._resourceManager);
            if (!nodeName) {
                [nodeName] = getNodeNameAndIndex(item.node.name, context);
            }
            // Snapshot the context so async results are recorded against the
            // context that produced them, not whatever is current later.
            const currentContext = context.currentContext;
            if (isPromise(tensors)) {
                promises.push(tensors.then(t => {
                    tensorMap[nodeName] = t;
                    context.currentContext = currentContext;
                    this.checkTensorForDisposal(nodeName, item.node, tensorMap, context, tensorsToKeep, outputNames, intermediateTensorConsumerCount);
                    this.processChildNodes(item.node, stack, context, tensorMap, added, usedNodes);
                    return t;
                }));
            }
            else {
                tensorMap[nodeName] = tensors;
                this.checkTensorForDisposal(nodeName, item.node, tensorMap, context, tensorsToKeep, outputNames, intermediateTensorConsumerCount);
                this.processChildNodes(item.node, stack, context, tensorMap, added, usedNodes);
            }
        }
        else {
            // Node already has a value (e.g. a fed input); just schedule
            // its children.
            this.processChildNodes(item.node, stack, context, tensorMap, added, usedNodes);
        }
    }
    return promises;
}
64240 processChildNodes(node, stack, context, tensorMap, added, usedNodes) {
64241 node.children.forEach((childNode) => {
64242 const [nodeName,] = getNodeNameAndIndex(childNode.name, context);
64243 if (added[nodeName] || !usedNodes.has(childNode.name)) {
64244 return;
64245 }
64246 // Merge op can be pushed if any of its inputs has value.
64247 if (childNode.op === 'Merge') {
64248 if (childNode.inputNames.some(name => {
64249 return !!getTensor(name, tensorMap, context);
64250 })) {
64251 added[nodeName] = true;
64252 stack.push({ contexts: context.currentContext, node: childNode });
64253 }
64254 }
64255 else // Otherwise all inputs must to have value.
64256 if (childNode.inputNames.every(name => {
64257 return !!getTensor(name, tensorMap, context);
64258 })) {
64259 added[nodeName] = true;
64260 stack.push({ contexts: context.currentContext, node: childNode });
64261 }
64262 });
64263 }
64264 /**
64265 * Releases the memory used by the weight tensors.
64266 */
64267 dispose() {
64268 Object.keys(this.weightMap)
64269 .forEach(key => this.weightMap[key].forEach(tensor => tensor.dispose()));
64270 }
64271 checkInputShapeAndType(inputs) {
64272 Object.keys(inputs).forEach(name => {
64273 const input = inputs[name];
64274 const [nodeName,] = parseNodeName(name);
64275 const node = this.graph.nodes[nodeName];
64276 if (node.attrParams['shape'] && node.attrParams['shape'].value) {
64277 const shape = node.attrParams['shape'].value;
64278 const match = shape.length === input.shape.length &&
64279 input.shape.every((dim, index) => shape[index] === -1 || shape[index] === dim);
64280 assert(match, () => `The shape of dict['${node.name}'] provided in ` +
64281 `model.execute(dict) must be [${shape}], but was ` +
64282 `[${input.shape}]`);
64283 }
64284 if (node.attrParams['dtype'] && node.attrParams['dtype'].value) {
64285 assert(input.dtype === node.attrParams['dtype'].value, () => `The dtype of dict['${node.name}'] provided in ` +
64286 `model.execute(dict) must be ` +
64287 `${node.attrParams['dtype'].value}, but was ${input.dtype}`);
64288 }
64289 });
64290 }
64291 mapInputs(inputs) {
64292 const result = {};
64293 for (const inputName in inputs) {
64294 if (this._signature != null && this._signature.inputs != null &&
64295 this._signature.inputs[inputName] != null) {
64296 const tensor = this._signature.inputs[inputName];
64297 result[tensor.name] = inputs[inputName];
64298 }
64299 else {
64300 result[inputName] = inputs[inputName];
64301 }
64302 }
64303 return result;
64304 }
64305 checkInputs(inputs) {
64306 const notInGraph = Object.keys(inputs).filter(name => {
64307 const [nodeName] = parseNodeName(name);
64308 return this.graph.nodes[nodeName] == null;
64309 });
64310 if (notInGraph.length > 0) {
64311 throw new Error(`The dict provided in model.execute(dict) has ` +
64312 `keys: [${notInGraph}] that are not part of graph`);
64313 }
64314 }
64315 mapOutputs(outputs) {
64316 return outputs.map(name => {
64317 if (this._signature != null && this._signature.outputs != null &&
64318 this._signature.outputs[name] != null) {
64319 const tensor = this._signature.outputs[name];
64320 return tensor.name;
64321 }
64322 return name;
64323 }, {});
64324 }
64325 checkOutputs(outputs) {
64326 outputs.forEach(name => {
64327 const [normalizedName] = parseNodeName(name);
64328 if (!this.graph.nodes[normalizedName]) {
64329 throw new Error(`The output '${name}' is not found in the graph`);
64330 }
64331 });
64332 }
64333 }
64334
/**
 * Contains global resources of a model.
 */
class ResourceManager {
    constructor(hashTableNameToHandle = {}, hashTableMap = {}) {
        this.hashTableNameToHandle = hashTableNameToHandle;
        this.hashTableMap = hashTableMap;
    }
    /**
     * Register a `HashTable` in the resource manager.
     *
     * The table is indexed two ways: by the name of the op node that
     * created it (handle lookup) and by its handle tensor's id (table
     * lookup via `resourceManager.getHashTableById`).
     *
     * @param name Op node name that creates the `HashTable`.
     * @param hashTable The `HashTable` to be added to resource manager.
     */
    addHashTable(name, hashTable) {
        this.hashTableNameToHandle[name] = hashTable.handle;
        this.hashTableMap[hashTable.id] = hashTable;
    }
    /**
     * Returns the handle tensor of the `HashTable` created by the named op
     * node, or undefined when no table was registered under that name.
     * @param name Op node name that creates the `HashTable`. This name is
     * also used in the inputs list of lookup and import `HashTable` ops.
     */
    getHashTableHandleByName(name) {
        return this.hashTableNameToHandle[name];
    }
    /**
     * Returns the `HashTable` whose handle tensor has the given id, or
     * undefined when no such table exists.
     * @param id The id of the handle tensor.
     */
    getHashTableById(id) {
        return this.hashTableMap[id];
    }
    /**
     * Dispose `ResourceManager`, including its hashTables and tensors in them.
     */
    dispose() {
        for (const id of Object.keys(this.hashTableMap)) {
            this.hashTableMap[id].clearAndClose();
            delete this.hashTableMap[id];
        }
        for (const name of Object.keys(this.hashTableNameToHandle)) {
            this.hashTableNameToHandle[name].dispose();
            delete this.hashTableNameToHandle[name];
        }
    }
}
64385
64386 /**
64387 * @license
64388 * Copyright 2018 Google LLC. All Rights Reserved.
64389 * Licensed under the Apache License, Version 2.0 (the "License");
64390 * you may not use this file except in compliance with the License.
64391 * You may obtain a copy of the License at
64392 *
64393 * http://www.apache.org/licenses/LICENSE-2.0
64394 *
64395 * Unless required by applicable law or agreed to in writing, software
64396 * distributed under the License is distributed on an "AS IS" BASIS,
64397 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
64398 * See the License for the specific language governing permissions and
64399 * limitations under the License.
64400 * =============================================================================
64401 */
// Query parameter appended to TFHub module urls to request the raw model file.
const TFHUB_SEARCH_PARAM = '?tfjs-format=file';
// Default model file name appended to a TFHub module url.
const DEFAULT_MODEL_NAME = 'model.json';
/**
 * A `tf.GraphModel` is a directed, acyclic graph built from a
 * SavedModel GraphDef and allows inference execution.
 *
 * A `tf.GraphModel` can only be created by loading from a model converted from
 * a [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model) using
 * the command line converter tool and loaded via `tf.loadGraphModel`.
 *
 * @doc {heading: 'Models', subheading: 'Classes'}
 */
class GraphModel {
    /**
     * @param modelUrl url for the model, or an `io.IOHandler`.
     * @param loadOptions options for the HTTP request, which allows to send
     * credentials, custom headers, and an `onProgress` callback fired
     * periodically before the load is completed.
     */
    constructor(modelUrl, loadOptions = {}) {
        this.modelUrl = modelUrl;
        this.loadOptions = loadOptions;
        this.version = 'n/a';
        // The default parameter only covers `undefined`; guard an explicit null.
        if (loadOptions == null) {
            this.loadOptions = {};
        }
        this.resourceManager = new ResourceManager();
    }
    // Returns the version information for the tensorflow model GraphDef.
    get modelVersion() {
        return this.version;
    }
    // Names of the model's input nodes (delegated to the executor).
    get inputNodes() {
        return this.executor.inputNodes;
    }
    // Names of the model's output nodes (delegated to the executor).
    get outputNodes() {
        return this.executor.outputNodes;
    }
    // TensorInfo descriptors of the model inputs.
    get inputs() {
        return this.executor.inputs;
    }
    // TensorInfo descriptors of the model outputs.
    get outputs() {
        return this.executor.outputs;
    }
    // The in-memory weight map, keyed by node name.
    get weights() {
        return this.executor.weightMap;
    }
    // User-defined metadata carried in the model artifacts, if any.
    get metadata() {
        return this.artifacts.userDefinedMetadata;
    }
    // The SavedModel signature resolved during loadSync(), if any.
    get modelSignature() {
        return this.signature;
    }
    // Resolves `modelUrl` to a concrete IOHandler and stores it on
    // `this.handler`. Accepts an IOHandler directly, builds an HTTP handler
    // when requestInit options are given, and otherwise consults the
    // registered load-handler registry.
    findIOHandler() {
        const path = this.modelUrl;
        if (path.load != null) {
            // Path is an IO Handler.
            this.handler = path;
        }
        else if (this.loadOptions.requestInit != null) {
            this.handler = browserHTTPRequest(path, this.loadOptions);
        }
        else {
            const handlers = getLoadHandlers(path, this.loadOptions);
            if (handlers.length === 0) {
                // For backward compatibility: if no load handler can be found,
                // assume it is a relative http path.
                handlers.push(browserHTTPRequest(path, this.loadOptions));
            }
            else if (handlers.length > 1) {
                throw new Error(`Found more than one (${handlers.length}) load handlers for ` +
                    `URL '${[path]}'`);
            }
            this.handler = handlers[0];
        }
    }
    /**
     * Loads the model and weight files, construct the in memory weight map and
     * compile the inference graph.
     *
     * Returns a Promise<boolean> when the handler loads asynchronously, or a
     * boolean when it loads synchronously.
     */
    load() {
        this.findIOHandler();
        if (this.handler.load == null) {
            throw new Error('Cannot proceed with model loading because the IOHandler provided ' +
                'does not have the `load` method implemented.');
        }
        const loadResult = this.handler.load();
        if (isPromise(loadResult)) {
            return loadResult.then(artifacts => this.loadSync(artifacts));
        }
        return this.loadSync(loadResult);
    }
    /**
     * Synchronously construct the in memory weight map and
     * compile the inference graph. Also initialize hashtable if any.
     *
     * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
     */
    loadSync(artifacts) {
        this.artifacts = artifacts;
        const graph = this.artifacts.modelTopology;
        let signature;
        // Prefer the signature embedded in userDefinedMetadata (older
        // converter output) over the top-level artifacts.signature.
        if (this.artifacts.userDefinedMetadata != null &&
            this.artifacts.userDefinedMetadata.signature != null) {
            signature = // tslint:disable-next-line:no-any
                this.artifacts.userDefinedMetadata.signature;
        }
        else {
            signature = this.artifacts.signature;
        }
        this.signature = signature;
        this.version = `${graph.versions.producer}.${graph.versions.minConsumer}`;
        const weightMap = decodeWeights(this.artifacts.weightData, this.artifacts.weightSpecs);
        this.executor = new GraphExecutor(OperationMapper.Instance.transformGraph(graph, this.signature));
        this.executor.weightMap = this.convertTensorMapToTensorsMap(weightMap);
        // Attach a model-level resourceManager to each executor to share resources,
        // such as `HashTable`.
        this.executor.resourceManager = this.resourceManager;
        if (artifacts.modelInitializer != null &&
            artifacts.modelInitializer.node != null) {
            const initializer = OperationMapper.Instance.transformGraph(artifacts.modelInitializer);
            this.initializer = new GraphExecutor(initializer);
            this.initializer.weightMap = this.executor.weightMap;
            // Attach a model-level resourceManager to the initializer, the
            // hashTables created from when executing the initializer will be stored
            // in the resourceManager.
            this.initializer.resourceManager = this.resourceManager;
            // NOTE(review): this executeAsync promise is not awaited, so
            // loadSync returns before initialization completes — confirm
            // callers tolerate this.
            this.initializer.executeAsync({}, []);
        }
        return true;
    }
    /**
     * Save the configuration and/or weights of the GraphModel.
     *
     * An `IOHandler` is an object that has a `save` method of the proper
     * signature defined. The `save` method manages the storing or
     * transmission of serialized data ("artifacts") that represent the
     * model's topology and weights onto or via a specific medium, such as
     * file downloads, local storage, IndexedDB in the web browser and HTTP
     * requests to a server. TensorFlow.js provides `IOHandler`
     * implementations for a number of frequently used saving mediums, such as
     * `tf.io.browserDownloads` and `tf.io.browserLocalStorage`. See `tf.io`
     * for more details.
     *
     * This method also allows you to refer to certain types of `IOHandler`s
     * as URL-like string shortcuts, such as 'localstorage://' and
     * 'indexeddb://'.
     *
     * Example 1: Save `model`'s topology and weights to browser [local
     * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
     * then load it back.
     *
     * ```js
     * const modelUrl =
     *    'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
     * const model = await tf.loadGraphModel(modelUrl);
     * const zeros = tf.zeros([1, 224, 224, 3]);
     * model.predict(zeros).print();
     *
     * const saveResults = await model.save('localstorage://my-model-1');
     *
     * const loadedModel = await tf.loadGraphModel('localstorage://my-model-1');
     * console.log('Prediction from loaded model:');
     * model.predict(zeros).print();
     * ```
     *
     * @param handlerOrURL An instance of `IOHandler` or a URL-like,
     * scheme-based string shortcut for `IOHandler`.
     * @param config Options for saving the model.
     * @returns A `Promise` of `SaveResult`, which summarizes the result of
     * the saving, such as byte sizes of the saved artifacts for the model's
     * topology and weight values.
     *
     * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
     */
    async save(handlerOrURL, config) {
        if (typeof handlerOrURL === 'string') {
            const handlers = getSaveHandlers(handlerOrURL);
            if (handlers.length === 0) {
                throw new Error(`Cannot find any save handlers for URL '${handlerOrURL}'`);
            }
            else if (handlers.length > 1) {
                throw new Error(`Found more than one (${handlers.length}) save handlers for ` +
                    `URL '${handlerOrURL}'`);
            }
            handlerOrURL = handlers[0];
        }
        if (handlerOrURL.save == null) {
            throw new Error('GraphModel.save() cannot proceed because the IOHandler ' +
                'provided does not have the `save` attribute defined.');
        }
        return handlerOrURL.save(this.artifacts);
    }
    /**
     * Execute the inference for the input tensors.
     *
     * @param input The input tensors, when there is single input for the model,
     * inputs param should be a `tf.Tensor`. For models with mutliple inputs,
     * inputs params should be in either `tf.Tensor`[] if the input order is
     * fixed, or otherwise NamedTensorMap format.
     *
     * For model with multiple inputs, we recommend you use NamedTensorMap as the
     * input type, if you use `tf.Tensor`[], the order of the array needs to
     * follow the
     * order of inputNodes array. @see {@link GraphModel.inputNodes}
     *
     * You can also feed any intermediate nodes using the NamedTensorMap as the
     * input type. For example, given the graph
     *    InputNode => Intermediate => OutputNode,
     * you can execute the subgraph Intermediate => OutputNode by calling
     *    model.execute('IntermediateNode' : tf.tensor(...));
     *
     * This is useful for models that uses tf.dynamic_rnn, where the intermediate
     * state needs to be fed manually.
     *
     * For batch inference execution, the tensors for each input need to be
     * concatenated together. For example with mobilenet, the required input shape
     * is [1, 244, 244, 3], which represents the [batch, height, width, channel].
     * If we are provide a batched data of 100 images, the input tensor should be
     * in the shape of [100, 244, 244, 3].
     *
     * @param config Prediction configuration for specifying the batch size and
     * output node names. Currently the batch size option is ignored for graph
     * model.
     *
     * @returns Inference result tensors. The output would be single `tf.Tensor`
     * if model has single output node, otherwise Tensor[] or NamedTensorMap[]
     * will be returned for model with multiple outputs.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    predict(inputs, config) {
        return this.execute(inputs, this.outputNodes);
    }
    // Converts tensor / tensor-array inputs into a NamedTensorMap keyed by
    // input node name; a NamedTensorMap passes through unchanged. Throws
    // when the tensor count does not match the model's placeholders.
    normalizeInputs(inputs) {
        if (!(inputs instanceof Tensor) && !Array.isArray(inputs)) {
            // The input is already a NamedTensorMap.
            return inputs;
        }
        inputs = Array.isArray(inputs) ? inputs : [inputs];
        if (inputs.length !== this.inputNodes.length) {
            throw new Error('Input tensor count mismatch,' +
                `the graph model has ${this.inputNodes.length} placeholders, ` +
                `while there are ${inputs.length} input tensors.`);
        }
        return this.inputNodes.reduce((map, inputName, i) => {
            map[inputName] = inputs[i];
            return map;
        }, {});
    }
    // Ensures the requested outputs are an array of node names, defaulting
    // to the model's output nodes when none are given.
    normalizeOutputs(outputs) {
        outputs = outputs || this.outputNodes;
        return !Array.isArray(outputs) ? [outputs] : outputs;
    }
    /**
     * Executes inference for the model for given input tensors.
     * @param inputs tensor, tensor array or tensor map of the inputs for the
     * model, keyed by the input node names.
     * @param outputs output node name from the Tensorflow model, if no
     * outputs are specified, the default outputs of the model would be used.
     * You can inspect intermediate nodes of the model by adding them to the
     * outputs array.
     *
     * @returns A single tensor if provided with a single output or no outputs
     * are provided and there is only one default output, otherwise return a
     * tensor array. The order of the tensor array is the same as the outputs
     * if provided, otherwise the order of outputNodes attribute of the model.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    execute(inputs, outputs) {
        inputs = this.normalizeInputs(inputs);
        outputs = this.normalizeOutputs(outputs);
        const result = this.executor.execute(inputs, outputs);
        // Unwrap a single-element result to a bare tensor.
        return result.length > 1 ? result : result[0];
    }
    /**
     * Executes inference for the model for given input tensors in async
     * fashion, use this method when your model contains control flow ops.
     * @param inputs tensor, tensor array or tensor map of the inputs for the
     * model, keyed by the input node names.
     * @param outputs output node name from the Tensorflow model, if no outputs
     * are specified, the default outputs of the model would be used. You can
     * inspect intermediate nodes of the model by adding them to the outputs
     * array.
     *
     * @returns A Promise of single tensor if provided with a single output or
     * no outputs are provided and there is only one default output, otherwise
     * return a tensor map.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    async executeAsync(inputs, outputs) {
        inputs = this.normalizeInputs(inputs);
        outputs = this.normalizeOutputs(outputs);
        const result = await this.executor.executeAsync(inputs, outputs);
        // Unwrap a single-element result to a bare tensor.
        return result.length > 1 ? result : result[0];
    }
    /**
     * Get intermediate tensors for model debugging mode (flag
     * KEEP_INTERMEDIATE_TENSORS is true).
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    getIntermediateTensors() {
        return this.executor.getIntermediateTensors();
    }
    /**
     * Dispose intermediate tensors for model debugging mode (flag
     * KEEP_INTERMEDIATE_TENSORS is true).
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    disposeIntermediateTensors() {
        this.executor.disposeIntermediateTensors();
    }
    // Wraps each tensor of a NamedTensorMap into a single-element array,
    // producing the NamedTensorsMap shape the executor expects.
    convertTensorMapToTensorsMap(map) {
        return Object.keys(map).reduce((newMap, key) => {
            newMap[key] = [map[key]];
            return newMap;
        }, {});
    }
    /**
     * Releases the memory used by the weight tensors and resourceManager.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    dispose() {
        this.executor.dispose();
        if (this.initializer) {
            this.initializer.dispose();
        }
        this.resourceManager.dispose();
    }
}
/**
 * Load a graph model given a URL to the model definition.
 *
 * Example of loading MobileNetV2 from a URL and making a prediction with a
 * zeros input:
 *
 * ```js
 * const modelUrl =
 *    'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
 * const model = await tf.loadGraphModel(modelUrl);
 * const zeros = tf.zeros([1, 224, 224, 3]);
 * model.predict(zeros).print();
 * ```
 *
 * Example of loading MobileNetV2 from a TF Hub URL and making a prediction with
 * a zeros input:
 *
 * ```js
 * const modelUrl =
 *    'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2';
 * const model = await tf.loadGraphModel(modelUrl, {fromTFHub: true});
 * const zeros = tf.zeros([1, 224, 224, 3]);
 * model.predict(zeros).print();
 * ```
 * @param modelUrl The url or an `io.IOHandler` that loads the model.
 * @param options Options for the HTTP request, which allows to send credentials
 * and custom headers.
 *
 * @doc {heading: 'Models', subheading: 'Loading'}
 */
async function loadGraphModel(modelUrl, options = {}) {
    // A missing url/handler cannot be resolved to anything loadable.
    if (modelUrl == null) {
        throw new Error('modelUrl in loadGraphModel() cannot be null. Please provide a url ' +
            'or an IOHandler that loads the model');
    }
    // The default parameter only covers `undefined`; guard an explicit null.
    if (options == null) {
        options = {};
    }
    // TFHub module urls must be rewritten to point at the converted
    // model.json file.
    const isHubUrl = options.fromTFHub && typeof modelUrl === 'string';
    const resolvedUrl = isHubUrl ? getTFHubUrl(modelUrl) : modelUrl;
    const model = new GraphModel(resolvedUrl, options);
    await model.load();
    return model;
}
/**
 * Load a graph model given a synchronous IO handler with a 'load' method.
 *
 * @param modelSource The `io.IOHandlerSync` that loads the model.
 * @returns The loaded `GraphModel`.
 *
 * @doc {heading: 'Models', subheading: 'Loading'}
 */
function loadGraphModelSync(modelSource) {
    // Reject missing sources up front.
    if (modelSource == null) {
        throw new Error('modelUrl in loadGraphModelSync() cannot be null. Please provide a ' +
            'url or an IOHandler that loads the model');
    }
    // A synchronous source must expose a `load` method.
    if (!modelSource.load) {
        throw new Error(`modelUrl IO Handler ${modelSource} has no load function`);
    }
    const graphModel = new GraphModel(modelSource);
    graphModel.load();
    return graphModel;
}
/**
 * Converts a TFHub module url into the url of its converted model.json,
 * ensuring a trailing slash before appending the default model file name
 * and the file-format query parameter.
 */
function getTFHubUrl(modelUrl) {
    const base = modelUrl.endsWith('/') ? modelUrl : `${modelUrl}/`;
    return `${base}${DEFAULT_MODEL_NAME}${TFHUB_SEARCH_PARAM}`;
}
64810
/** @license See the LICENSE file. */
// This code is auto-generated, do not modify this file!
// Version of the tfjs-converter package bundled into this build.
const version$2 = '3.18.0';
64814
64815 /**
64816 * @license
64817 * Copyright 2018 Google LLC. All Rights Reserved.
64818 * Licensed under the Apache License, Version 2.0 (the "License");
64819 * you may not use this file except in compliance with the License.
64820 * You may obtain a copy of the License at
64821 *
64822 * http://www.apache.org/licenses/LICENSE-2.0
64823 *
64824 * Unless required by applicable law or agreed to in writing, software
64825 * distributed under the License is distributed on an "AS IS" BASIS,
64826 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
64827 * See the License for the specific language governing permissions and
64828 * limitations under the License.
64829 * =============================================================================
64830 */
64831
64832 /**
64833 * @license
64834 * Copyright 2018 Google LLC. All Rights Reserved.
64835 * Licensed under the Apache License, Version 2.0 (the "License");
64836 * you may not use this file except in compliance with the License.
64837 * You may obtain a copy of the License at
64838 *
64839 * http://www.apache.org/licenses/LICENSE-2.0
64840 *
64841 * Unless required by applicable law or agreed to in writing, software
64842 * distributed under the License is distributed on an "AS IS" BASIS,
64843 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
64844 * See the License for the specific language governing permissions and
64845 * limitations under the License.
64846 *
64847 * =============================================================================
64848 */
/**
 * Apply a mapping function to a nested structure in a recursive manner.
 *
 * The result of the mapping is an object with the same nested structure (i.e.,
 * of arrays and dicts) as the input, except that some subtrees are replaced,
 * according to the results of the mapping function.
 *
 * Mappings are memoized. Thus, if the nested structure contains the same
 * object in multiple positions, the output will contain the same mapped object
 * in those positions. Cycles are not supported, however.
 *
 * @param input: The object to which to apply the mapping function.
 * @param mapFn: A function that expects a single node of the object tree, and
 * returns a `DeepMapResult`. The `DeepMapResult` either provides a
 * replacement value for that node (i.e., replacing the subtree), or indicates
 * that the node should be processed recursively.
 * @returns A structure mirroring `input` with subtrees replaced per `mapFn`.
 */
function deepMap(input, mapFn) {
    // Delegate to the internal helper with fresh memoization/cycle state.
    return deepMapInternal(input, mapFn);
}
/**
 * Recursive worker for `deepMap`.
 *
 * @param seen: A Map of known object mappings (i.e., memoized results of
 *   `mapFn()`)
 * @param containedIn: A set containing objects on the reference path currently
 *   being processed (used to detect cycles).
 */
function deepMapInternal(input, mapFn, seen = new Map(), containedIn = new Set()) {
    if (input == null) {
        return null;
    }
    // Blobs are treated as leaves: return a copy rather than recursing.
    if (typeof Blob === 'function' && input instanceof Blob) {
        return input.slice();
    }
    if (containedIn.has(input)) {
        throw new Error('Circular references are not supported.');
    }
    // Reuse the memoized mapping if this node was already visited.
    if (seen.has(input)) {
        return seen.get(input);
    }
    const result = mapFn(input);
    if (result.recurse && result.value !== null) {
        throw new Error('A deep map function may not return both a value and recurse=true.');
    }
    if (!result.recurse) {
        seen.set(input, result.value);
        return result.value;
    }
    if (isIterable$1(input)) {
        // tslint:disable-next-line:no-any
        const mapped = Array.isArray(input) ? [] : {};
        containedIn.add(input);
        for (const key in input) {
            mapped[key] = deepMapInternal(input[key], mapFn, seen, containedIn);
        }
        containedIn.delete(input);
        // Preserve the prototype so mapped class instances keep their type.
        if (input.__proto__) {
            mapped.__proto__ = input.__proto__;
        }
        return mapped;
    }
    throw new Error(`Can't recurse into non-iterable type: ${input}`);
}
64915 // TODO(soergel, kangyizhang) Reconsider naming of deepZip() to avoid confusion
64916 // with zip()
64917 /**
64918 * Zip nested structures together in a recursive manner.
64919 *
64920 * This has the effect of transposing or pivoting data, e.g. converting it from
64921 * a row-major representation to a column-major representation.
64922 *
64923 * For example, `deepZip([{a: 1, b: 2}, {a: 3, b: 4}])` returns
64924 * `{a: [1, 3], b: [2, 4]}`.
64925 *
64926 * The inputs should all have the same nested structure (i.e., of arrays and
64927 * dicts). The result is a single object with the same nested structure, where
64928 * the leaves are arrays collecting the values of the inputs at that location
64929 * (or, optionally, the result of a custom function applied to those arrays).
64930 *
64931 * @param inputs: An array of the objects to zip together.
64932 * @param zipFn: (optional) A function that expects an array of elements at a
64933 * single node of the object tree, and returns a `DeepMapResult`. The
64934 * `DeepMapResult` either provides a result value for that node (i.e.,
64935 * representing the subtree), or indicates that the node should be processed
64936 * recursively. The default zipFn recurses as far as possible and places
64937 * arrays at the leaves.
64938 */
function deepZip(inputs, zipFn = zipToList) {
    // Delegate to the recursive implementation, which also performs cycle
    // detection via its internal `containedIn` set.
    return deepZipInternal(inputs, zipFn);
}
64942 /**
 * @param containedIn: A set containing objects on the reference path currently
64944 * being processed (used to detect cycles).
64945 */
function deepZipInternal(inputs, zipFn, containedIn = new Set()) {
    // The recursion follows the structure of input 0; it's assumed that all the
    // other inputs have the same structure.
    const exemplar = inputs[0];
    if (containedIn.has(exemplar)) {
        throw new Error('Circular references are not supported.');
    }
    const zipped = zipFn(inputs);
    if (zipped.recurse && zipped.value !== null) {
        throw new Error('A deep zip function may not return both a value and recurse=true.');
    }
    // A concrete value for this node means zipFn handled the whole subtree.
    if (!zipped.recurse) {
        return zipped.value;
    }
    if (isIterable$1(exemplar)) {
        // tslint:disable-next-line:no-any
        const out = Array.isArray(exemplar) ? [] : {};
        containedIn.add(exemplar);
        for (const key in exemplar) {
            // Gather the k-th "column" across all inputs and zip it recursively.
            const column = inputs.map(input => input[key]);
            out[key] = deepZipInternal(column, zipFn, containedIn);
        }
        containedIn.delete(exemplar);
        return out;
    }
    throw new Error(`Can't recurse into non-iterable type: ${exemplar}`);
}
// Default zip function: recurse until the leaves, then collect the leaf values
// from all inputs into a plain array.
// tslint:disable-next-line:no-any
function zipToList(x) {
    if (x === null) {
        return null;
    }
    // TODO(soergel): validate array type?
    return isIterable$1(x[0]) ? { value: null, recurse: true } :
        { value: x, recurse: false };
}
64989 /**
64990 * Apply an async mapping function to a nested structure in a recursive manner.
64991 *
64992 * This first creates a nested structure of Promises, and then awaits all of
64993 * those, resulting in a single Promise for a resolved nested structure.
64994 *
64995 * The result of the mapping is an object with the same nested structure (i.e.,
64996 * of arrays and dicts) as the input, except that some subtrees are replaced,
64997 * according to the results of the mapping function.
64998 *
64999 * Mappings are memoized. Thus, if the nested structure contains the same
65000 * object in multiple positions, the output will contain the same mapped object
65001 * in those positions. Cycles are not supported, however.
65002 *
65003 * @param input: The object to which to apply the mapping function.
65004 * @param mapFn: A function that expects a single node of the object tree, and
65005 * returns a `DeepMapAsyncResult`. The `DeepMapAsyncResult` either provides
65006 * a `Promise` for a replacement value for that node (i.e., replacing the
65007 * subtree), or indicates that the node should be processed recursively. Note
65008 * that the decision whether or not to recurse must be made immediately; only
65009 * the mapped value may be promised.
65010 */
async function deepMapAndAwaitAll(input, mapFn) {
    const seen = new Map();
    // Pass 1: an ordinary deepMap whose only purpose is to populate `seen`
    // with (node -> Promise) entries as a side effect.
    deepMapInternal(input, mapFn, seen);
    // Resolve every Promise stored in `seen`, replacing each entry in place.
    // Note TypeScript provides no async map iteration, and regular map iteration
    // is broken too, so sadly we have to do Array.from() to make it work.
    // (There's no advantage to Promise.all(), and that would be tricky anyway.)
    for (const key of Array.from(seen.keys())) {
        const value = seen.get(key);
        if (isPromise(value)) {
            seen.set(key, await value);
        }
    }
    // Pass 2: rerun deepMap; the memoized entries now hold resolved values, so
    // the structure is rebuilt with concrete results.
    // It's unfortunate that we have to do two passes.
    // TODO(soergel): test performance and think harder about a fast solution.
    return deepMapInternal(input, mapFn, seen);
}
65032 /**
65033 * Determine whether the argument is iterable.
65034 *
65035 * @returns true if the argument is an array or any non-Tensor object.
65036 */
// tslint:disable-next-line:no-any
function isIterable$1(obj) {
    // Text decoders are objects but must not be recursed into; the concrete
    // decoder type differs between browser and Node environments.
    let decoderLike;
    if (env().get('IS_BROWSER')) {
        decoderLike = obj instanceof TextDecoder;
    }
    else {
        // tslint:disable-next-line:no-require-imports
        const { StringDecoder } = require('string_decoder');
        decoderLike = obj instanceof StringDecoder;
    }
    if (obj == null || ArrayBuffer.isView(obj)) {
        return false;
    }
    if (Array.isArray(obj)) {
        return true;
    }
    return typeof obj === 'object' && !(obj instanceof Tensor) &&
        !(obj instanceof Promise) && !decoderLike;
}
65053 /**
65054 * Determine whether the argument can be converted to Tensor.
65055 *
65056 * Tensors, primitives, arrays, and TypedArrays all qualify; anything else does
65057 * not.
65058 *
65059 * @returns true if the argument can be converted to Tensor.
65060 */
// tslint:disable-next-line:no-any
function canTensorify(obj) {
    // null/undefined, primitives, and TypedArrays are all directly convertible.
    if (obj == null || isPrimitive(obj) || isTypedArray(obj)) {
        return true;
    }
    // Arrays and existing Tensors qualify; any other object does not.
    return Array.isArray(obj) ||
        (typeof obj === 'object' && obj instanceof Tensor);
}
/**
 * Returns true if the given `value` is a JavaScript primitive (null,
 * undefined, boolean, number, string, symbol, or bigint); otherwise returns
 * false. This is equivalent to node util.isPrimitive.
 */
function isPrimitive(value) {
    if (value === null) {
        return true;
    }
    const kind = typeof value;
    return kind !== 'object' && kind !== 'function';
}
65075
65076 /**
65077 * @license
65078 * Copyright 2018 Google LLC. All Rights Reserved.
65079 * Licensed under the Apache License, Version 2.0 (the "License");
65080 * you may not use this file except in compliance with the License.
65081 * You may obtain a copy of the License at
65082 *
65083 * http://www.apache.org/licenses/LICENSE-2.0
65084 *
65085 * Unless required by applicable law or agreed to in writing, software
65086 * distributed under the License is distributed on an "AS IS" BASIS,
65087 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
65088 * See the License for the specific language governing permissions and
65089 * limitations under the License.
65090 *
65091 * =============================================================================
65092 */
function deepClone(container) {
    // Recursively copy `container`; Tensor leaves are cloned (via
    // cloneIfTensor) so the result owns independent tensors.
    return deepMap(container, cloneIfTensor);
}
// Mapping function for deepClone: clones Tensors, recurses into iterables, and
// passes every other value through unchanged.
// tslint:disable-next-line: no-any
function cloneIfTensor(item) {
    if (item instanceof Tensor) {
        return { value: item.clone(), recurse: false };
    }
    if (isIterable$1(item)) {
        return { value: null, recurse: true };
    }
    return { value: item, recurse: false };
}
65108
65109 /**
65110 * @license
65111 * Copyright 2018 Google LLC. All Rights Reserved.
65112 * Licensed under the Apache License, Version 2.0 (the "License");
65113 * you may not use this file except in compliance with the License.
65114 * You may obtain a copy of the License at
65115 *
65116 * http://www.apache.org/licenses/LICENSE-2.0
65117 *
65118 * Unless required by applicable law or agreed to in writing, software
65119 * distributed under the License is distributed on an "AS IS" BASIS,
65120 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
65121 * See the License for the specific language governing permissions and
65122 * limitations under the License.
65123 *
65124 * =============================================================================
65125 */
/**
 * A ring buffer, providing O(1) FIFO, LIFO, and related operations.
 */
class RingBuffer {
    /**
     * Constructs a `RingBuffer`.
     * @param capacity The number of items that the buffer can accommodate.
     * @throws RangeError if capacity is null/undefined or less than 1.
     */
    constructor(capacity) {
        this.capacity = capacity;
        // Note we store the indices in the range 0 <= index < 2*capacity.
        // This allows us to distinguish the full from the empty case.
        // See https://www.snellman.net/blog/archive/2016-12-13-ring-buffers/
        this.begin = 0; // inclusive
        this.end = 0; // exclusive
        if (capacity == null) {
            throw new RangeError('Can\'t create a ring buffer of unknown capacity.');
        }
        if (capacity < 1) {
            throw new RangeError('Can\'t create ring buffer of capacity < 1.');
        }
        this.data = new Array(capacity);
        this.doubledCapacity = 2 * capacity;
    }
    /**
     * Map any index into the range 0 <= index < 2*capacity.
     */
    wrap(index) {
        // don't trust % on negative numbers
        while (index < 0) {
            index += this.doubledCapacity;
        }
        return index % this.doubledCapacity;
    }
    /**
     * Returns the item at a doubled-range index; the physical slot is
     * index % capacity.
     */
    get(index) {
        if (index < 0) {
            throw new RangeError('Can\'t get item at a negative index.');
        }
        return this.data[index % this.capacity];
    }
    /**
     * Stores an item at a doubled-range index; the physical slot is
     * index % capacity.
     */
    set(index, value) {
        if (index < 0) {
            throw new RangeError('Can\'t set item at a negative index.');
        }
        this.data[index % this.capacity] = value;
    }
    /**
     * Returns the current number of items in the buffer.
     */
    length() {
        let length = this.end - this.begin;
        if (length < 0) {
            // begin/end live in [0, 2*capacity), so the difference may wrap.
            length = this.doubledCapacity + length;
        }
        return length;
    }
    /**
     * Reports whether the buffer is full.
     * @returns true if the number of items in the buffer equals its capacity, and
     * false otherwise.
     */
    isFull() {
        return this.length() === this.capacity;
    }
    /**
     * Reports whether the buffer is empty.
     * @returns true if the number of items in the buffer equals zero, and
     * false otherwise.
     */
    isEmpty() {
        return this.length() === 0;
    }
    /**
     * Adds an item to the end of the buffer.
     * @throws RangeError if the buffer is full.
     */
    push(value) {
        if (this.isFull()) {
            throw new RangeError('Ring buffer is full.');
        }
        this.set(this.end, value);
        this.end = this.wrap(this.end + 1);
    }
    /**
     * Adds many items to the end of the buffer, in order.
     */
    pushAll(values) {
        for (const value of values) {
            this.push(value);
        }
    }
    /**
     * Removes and returns the last item in the buffer.
     * @throws RangeError if the buffer is empty.
     */
    pop() {
        if (this.isEmpty()) {
            throw new RangeError('Ring buffer is empty.');
        }
        this.end = this.wrap(this.end - 1);
        const result = this.get(this.end);
        // Clear the vacated slot so the buffer doesn't retain a reference.
        this.set(this.end, undefined);
        return result;
    }
    /**
     * Adds an item to the beginning of the buffer.
     * @throws RangeError if the buffer is full.
     */
    unshift(value) {
        if (this.isFull()) {
            throw new RangeError('Ring buffer is full.');
        }
        this.begin = this.wrap(this.begin - 1);
        this.set(this.begin, value);
    }
    /**
     * Removes and returns the first item in the buffer.
     * @throws RangeError if the buffer is empty.
     */
    shift() {
        if (this.isEmpty()) {
            throw new RangeError('Ring buffer is empty.');
        }
        const result = this.get(this.begin);
        // Clear the vacated slot so the buffer doesn't retain a reference.
        this.set(this.begin, undefined);
        this.begin = this.wrap(this.begin + 1);
        return result;
    }
    /**
     * Removes and returns a specific item in the buffer, and moves the last item
     * to the vacated slot. This is useful for implementing a shuffling stream.
     * Note that this operation necessarily scrambles the original order.
     *
     * @param relativeIndex: the index of the item to remove, relative to the
     *   first item in the buffer (e.g., hiding the ring nature of the underlying
     *   storage).
     * @throws RangeError if the buffer is empty.
     */
    shuffleExcise(relativeIndex) {
        if (this.isEmpty()) {
            throw new RangeError('Ring buffer is empty.');
        }
        const index = this.wrap(this.begin + relativeIndex);
        const result = this.get(index);
        const lastItem = this.pop();
        // If the excised item was itself the last item, pop() has already
        // removed and cleared it; writing `lastItem` back would re-retain the
        // excised reference in a now-vacant slot.
        if (index !== this.end) {
            this.set(index, lastItem);
        }
        return result;
    }
}
65269
65270 /**
65271 * @license
65272 * Copyright 2018 Google LLC. All Rights Reserved.
65273 * Licensed under the Apache License, Version 2.0 (the "License");
65274 * you may not use this file except in compliance with the License.
65275 * You may obtain a copy of the License at
65276 *
65277 * http://www.apache.org/licenses/LICENSE-2.0
65278 *
65279 * Unless required by applicable law or agreed to in writing, software
65280 * distributed under the License is distributed on an "AS IS" BASIS,
65281 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
65282 * See the License for the specific language governing permissions and
65283 * limitations under the License.
65284 *
65285 * =============================================================================
65286 */
class GrowingRingBuffer extends RingBuffer {
    /**
     * Constructs a `GrowingRingBuffer`, starting at a small fixed capacity
     * that doubles whenever more room is needed.
     */
    constructor() {
        super(GrowingRingBuffer.INITIAL_CAPACITY);
    }
    /** A growing buffer is never full: it expands instead of refusing items. */
    isFull() {
        return false;
    }
    push(value) {
        if (super.isFull()) {
            this.expand();
        }
        super.push(value);
    }
    unshift(value) {
        if (super.isFull()) {
            this.expand();
        }
        super.unshift(value);
    }
    /**
     * Doubles the capacity of the buffer, rotating the live contents so that
     * they start at physical index 0 of the new backing array.
     */
    expand() {
        const grownCapacity = this.capacity * 2;
        const grown = new Array(grownCapacity);
        const count = this.length();
        // Copy the live window to the front of the new array; we can't simply
        // allocate more space at the end because the window may straddle the
        // wrap point.
        let src = this.begin;
        for (let dst = 0; dst < count; dst++) {
            grown[dst] = this.get(src);
            src = this.wrap(src + 1);
        }
        this.data = grown;
        this.capacity = grownCapacity;
        this.doubledCapacity = 2 * grownCapacity;
        this.begin = 0;
        this.end = count;
    }
}
GrowingRingBuffer.INITIAL_CAPACITY = 32;
65329
65330 /**
65331 * @license
65332 * Copyright 2018 Google LLC. All Rights Reserved.
65333 * Licensed under the Apache License, Version 2.0 (the "License");
65334 * you may not use this file except in compliance with the License.
65335 * You may obtain a copy of the License at
65336 *
65337 * http://www.apache.org/licenses/LICENSE-2.0
65338 *
65339 * Unless required by applicable law or agreed to in writing, software
65340 * distributed under the License is distributed on an "AS IS" BASIS,
65341 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
65342 * See the License for the specific language governing permissions and
65343 * limitations under the License.
65344 *
65345 * =============================================================================
65346 */
65347 // Here we implement a simple asynchronous iterator.
65348 // This lets us avoid using either third-party stream libraries or
65349 // recent TypeScript language support requiring polyfills.
/**
 * Create a `LazyIterator` from an array of items.
 */
function iteratorFromItems(items) {
    // Each element is deep-cloned as it is emitted (see ArrayIterator.next),
    // so mutating an emitted element does not affect the source array.
    return new ArrayIterator(items);
}
/**
 * Create a `LazyIterator` of incrementing integers, starting at `start`.
 * The stream is unbounded: `done` is never true.
 */
function iteratorFromIncrementing(start) {
    let nextValue = start;
    return iteratorFromFunction(() => {
        const value = nextValue;
        nextValue += 1;
        return { value, done: false };
    });
}
/**
 * Create a `LazyIterator` from a function.
 *
 * ```js
 * let i = -1;
 * const func = () =>
 *    ++i < 5 ? {value: i, done: false} : {value: null, done: true};
 * const iter = tf.data.iteratorFromFunction(func);
 * await iter.forEachAsync(e => console.log(e));
 * ```
 *
 * @param func A function that produces data on each call.
 */
function iteratorFromFunction(func) {
    // Errors thrown by `func` are re-thrown with a dataset-iteration prefix
    // (see FunctionCallIterator.next).
    return new FunctionCallIterator(func);
}
/**
 * Create a `LazyIterator` by concatenating underlying streams, which are
 * themselves provided as a stream.
 *
 * This can also be thought of as a "stream flatten" operation.
 *
 * @param baseIterators A stream of streams to be concatenated.
 * @param baseErrorHandler An optional function that can intercept `Error`s
 *   raised during a `next()` call on the base stream. This function can decide
 *   whether the error should be propagated, whether the error should be
 *   ignored, or whether the base stream should be terminated.
 */
function iteratorFromConcatenated(baseIterators, baseErrorHandler) {
    return new ChainedIterator(baseIterators, baseErrorHandler);
}
/**
 * Create a `LazyIterator` by concatenating streams produced by calling a
 * stream-generating function a given number of times.
 *
 * Since a `LazyIterator` is read-once, it cannot be repeated, but this
 * function can be used to achieve a similar effect:
 *
 *   LazyIterator.ofConcatenatedFunction(() => new MyIterator(), 6);
 *
 * @param iteratorFunc: A function that produces a new stream on each call.
 * @param count: The number of times to call the function.
 * @param baseErrorHandler An optional function that can intercept `Error`s
 *   raised during a `next()` call on the base stream. This function can decide
 *   whether the error should be propagated, whether the error should be
 *   ignored, or whether the base stream should be terminated.
 */
function iteratorFromConcatenatedFunction(iteratorFunc, count, baseErrorHandler) {
    // Build a bounded stream of `count` iterators, then flatten it.
    return iteratorFromConcatenated(iteratorFromFunction(iteratorFunc).take(count), baseErrorHandler);
}
/**
 * Create a `LazyIterator` by zipping together an array, dict, or nested
 * structure of `LazyIterator`s (and perhaps additional constants).
 *
 * The underlying streams must provide elements in a consistent order such
 * that they correspond.
 *
 * Typically, the underlying streams should have the same number of
 * elements. If they do not, the behavior is determined by the
 * `mismatchMode` argument.
 *
 * The nested structure of the `iterators` argument determines the
 * structure of elements in the resulting iterator.
 *
 * @param iterators: An array or object containing LazyIterators at the
 *   leaves.
 * @param mismatchMode: Determines what to do when one underlying iterator
 *   is exhausted before the others. `ZipMismatchMode.FAIL` (the default)
 *   causes an error to be thrown in this case. `ZipMismatchMode.SHORTEST`
 *   causes the zipped iterator to terminate with the first underlying
 *   streams, so elements remaining on the longer streams are ignored.
 *   `ZipMismatchMode.LONGEST` causes the zipped stream to continue, filling
 *   in nulls for the exhausted streams, until all streams are exhausted.
 */
function iteratorFromZipped(iterators, mismatchMode = ZipMismatchMode.FAIL) {
    return new ZipIterator(iterators, mismatchMode);
}
/**
 * An asynchronous iterator, providing lazy access to a potentially
 * unbounded stream of elements.
 *
 * Iterator can be obtained from a dataset:
 * `const iter = await dataset.iterator();`
 */
class LazyIterator {
    /**
     * Collect all remaining elements of a bounded stream into an array.
     * Obviously this will succeed only for small streams that fit in memory.
     * Useful for testing.
     *
     * @returns A Promise for an array of stream elements, which will resolve
     *   when the stream is exhausted.
     */
    async toArray() {
        const result = [];
        let x = await this.next();
        while (!x.done) {
            result.push(x.value);
            x = await this.next();
        }
        return result;
    }
    /**
     * Collect all elements of this dataset into an array with prefetching 100
     * elements. This is useful for testing, because the prefetch changes the
     * order in which the Promises are resolved along the processing pipeline.
     * This may help expose bugs where results are dependent on the order of
     * Promise resolution rather than on the logical order of the stream (i.e.,
     * due to hidden mutable state).
     *
     * @returns A Promise for an array of stream elements, which will resolve
     *   when the stream is exhausted.
     */
    async toArrayForTest() {
        const stream = this.prefetch(100);
        const result = [];
        let x = await stream.next();
        while (!x.done) {
            result.push(x.value);
            x = await stream.next();
        }
        return result;
    }
    /**
     * Draw items from the stream until it is exhausted.
     *
     * This can be useful when the stream has side effects but no output. In
     * that case, calling this function guarantees that the stream will be
     * fully processed.
     */
    async resolveFully() {
        let x = await this.next();
        while (!x.done) {
            x = await this.next();
        }
    }
    /**
     * Draw items from the stream until it is exhausted, or a predicate fails.
     *
     * This can be useful when the stream has side effects but no output. In
     * that case, calling this function guarantees that the stream will be
     * processed until the predicate returns a falsy value or the stream ends.
     */
    async resolveWhile(predicate) {
        let x = await this.next();
        let shouldContinue = predicate(x.value);
        while ((!x.done) && shouldContinue) {
            x = await this.next();
            shouldContinue = predicate(x.value);
        }
    }
    /**
     * Handles errors thrown on this stream using a provided handler function.
     *
     * @param handler A function that handles any `Error` thrown during a `next()`
     *   call and returns true if the stream should continue (dropping the failed
     *   call) or false if the stream should quietly terminate. If the handler
     *   itself throws (or rethrows) an `Error`, that will be propagated.
     *
     * @returns A `LazyIterator` of elements passed through from upstream,
     *   possibly filtering or terminating on upstream `next()` calls that
     *   throw an `Error`.
     */
    handleErrors(handler) {
        return new ErrorHandlingLazyIterator(this, handler);
    }
    // TODO(soergel): Implement reduce() etc.
    /**
     * Filters this stream according to `predicate`.
     *
     * @param predicate A function mapping a stream element to a boolean or a
     * `Promise` for one.
     *
     * @returns A `LazyIterator` of elements for which the predicate was true.
     */
    filter(predicate) {
        return new FilterIterator(this, predicate);
    }
    /**
     * Maps this stream through a 1-to-1 transform.
     *
     * @param transform A function mapping a stream element to a transformed
     *   element.
     *
     * @returns A `LazyIterator` of transformed elements.
     */
    map(transform) {
        return new MapIterator(this, transform);
    }
    /**
     * Maps this stream through an async 1-to-1 transform.
     *
     * @param transform A function mapping a stream element to a `Promise` for a
     *   transformed stream element.
     *
     * @returns A `LazyIterator` of transformed elements.
     */
    mapAsync(transform) {
        return new AsyncMapIterator(this, transform);
    }
    /**
     * Maps this stream through a 1-to-1 async transform, forcing serial
     * execution (each transform awaits the previous one).
     *
     * @param transform A function mapping a stream element to a transformed
     *   element.
     *
     * @returns A `LazyIterator` of transformed elements.
     */
    serialMapAsync(transform) {
        return new AsyncMapIterator(this, transform).serial();
    }
    /**
     * Maps this stream through a 1-to-many transform.
     *
     * @param transform A function mapping a stream element to an array of
     *   transformed elements.
     *
     * @returns A `DataStream` of transformed elements.
     */
    flatmap(transform) {
        return new FlatmapIterator(this, transform);
    }
    /**
     * Apply a function to every element of the stream.
     *
     * @param f A function to apply to each stream element.
     */
    async forEachAsync(f) {
        return this.map(f).resolveFully();
    }
    /**
     * Apply a function to every element of the stream, forcing serial execution.
     *
     * @param f A function to apply to each stream element. Should return 'true'
     *   to indicate that the stream should continue, or 'false' to cause it to
     *   terminate.
     */
    async serialForEach(f) {
        return this.serialMapAsync(f).resolveWhile(x => (x === true));
    }
    /**
     * Groups elements into batches, represented as arrays of elements.
     *
     * We can think of the elements of this iterator as 'rows' (even if they are
     * nested structures). By the same token, consecutive values for a given
     * key within the elements form a 'column'. This matches the usual sense of
     * 'row' and 'column' when processing tabular data (e.g., parsing a CSV).
     *
     * Thus, "Row-major" means that the resulting batch is simply a collection of
     * rows: `[row1, row2, row3, ...]`. This is contrast to the column-major
     * form, which is needed for vectorized computation.
     *
     * @param batchSize The number of elements desired per batch.
     * @param smallLastBatch Whether to emit the final batch when it has fewer
     *   than batchSize elements. Default true.
     * @returns A `LazyIterator` of batches of elements, represented as arrays
     *   of the original element type.
     */
    rowMajorBatch(batchSize, smallLastBatch = true) {
        return new RowMajorBatchIterator(this, batchSize, smallLastBatch);
    }
    /**
     * Groups elements into batches, represented in column-major form.
     *
     * We can think of the elements of this iterator as 'rows' (even if they are
     * nested structures). By the same token, consecutive values for a given
     * key within the elements form a 'column'. This matches the usual sense of
     * 'row' and 'column' when processing tabular data (e.g., parsing a CSV).
     *
     * Thus, "column-major" means that the resulting batch is a (potentially
     * nested) structure representing the columns. Each column entry, then,
     * contains a collection of the values found in that column for a range of
     * input elements. This representation allows for vectorized computation, in
     * contrast to the row-major form.
     *
     * The inputs should all have the same nested structure (i.e., of arrays and
     * dicts). The result is a single object with the same nested structure,
     * where the leaves are arrays collecting the values of the inputs at that
     * location (or, optionally, the result of a custom function applied to those
     * arrays).
     *
     * @param batchSize The number of elements desired per batch.
     * @param smallLastBatch Whether to emit the final batch when it has fewer
     *   than batchSize elements. Default true.
     * @param zipFn: (optional) A function that expects an array of elements at a
     *   single node of the object tree, and returns a `DeepMapResult`. The
     *   `DeepMapResult` either provides a result value for that node (i.e.,
     *   representing the subtree), or indicates that the node should be processed
     *   recursively. The default zipFn recurses as far as possible and places
     *   arrays at the leaves.
     * @returns A `LazyIterator` of batches of elements, represented as an object
     *   with collections at the leaves.
     */
    columnMajorBatch(batchSize, smallLastBatch = true,
    // tslint:disable-next-line:no-any
    zipFn = zipToList) {
        // First collect the desired number of input elements as a row-major batch.
        const rowBatches = this.rowMajorBatch(batchSize, smallLastBatch);
        // Now 'rotate' or 'pivot' the data, collecting all values from each column
        // in the batch (i.e., for each key within the elements) into an array.
        return rowBatches.map(x => deepZip(x, zipFn));
    }
    /**
     * Concatenate this `LazyIterator` with another.
     *
     * @param iterator A `LazyIterator` to be concatenated onto this one.
     * @param baseErrorHandler An optional function that can intercept `Error`s
     *   raised during a `next()` call on the base stream. This function can
     *   decide whether the error should be propagated, whether the error should
     *   be ignored, or whether the base stream should be terminated.
     * @returns A `LazyIterator`.
     */
    concatenate(iterator, baseErrorHandler) {
        return new ChainedIterator(iteratorFromItems([this, iterator]), baseErrorHandler);
    }
    /**
     * Limits this stream to return at most `count` items.
     *
     * @param count The maximum number of items to provide from the stream. If
     *   a negative or undefined value is given, the entire stream is returned
     *   unaltered.
     */
    take(count) {
        // A null/undefined or negative count means "no limit".
        if (count < 0 || count == null) {
            return this;
        }
        return new TakeIterator(this, count);
    }
    /**
     * Skips the first `count` items in this stream.
     *
     * @param count The number of items to skip. If a negative or undefined
     *   value is given, the entire stream is returned unaltered.
     */
    skip(count) {
        // A null/undefined or negative count means "skip nothing".
        if (count < 0 || count == null) {
            return this;
        }
        return new SkipIterator(this, count);
    }
    /**
     * Prefetch the first `bufferSize` items in this stream.
     *
     * Note this prefetches Promises, but makes no guarantees about when those
     * Promises resolve.
     *
     * @param bufferSize: An integer specifying the number of elements to be
     *   prefetched.
     */
    prefetch(bufferSize) {
        return new PrefetchIterator(this, bufferSize);
    }
    // TODO(soergel): deep sharded shuffle, where supported
    /**
     * Randomly shuffles the elements of this stream using a sliding window.
     *
     * @param windowSize: An integer specifying the number of elements from
     *   this stream from which the new stream will sample.
     * @param seed: (Optional.) An integer specifying the random seed that
     *   will be used to create the distribution.
     */
    shuffle(windowSize, seed) {
        return new ShuffleIterator(this, windowSize, seed);
    }
    /**
     * Force an iterator to execute serially: each next() call will await the
     * prior one, so that they cannot execute concurrently.
     */
    serial() {
        return new SerialIterator(this);
    }
}
65735 // ============================================================================
65736 // The following private classes serve to implement the chainable methods
65737 // on LazyIterator. Unfortunately they can't be placed in separate files,
65738 // due to resulting trouble with circular imports.
65739 // ============================================================================
65740 // Iterators that just extend LazyIterator directly
65741 // ============================================================================
// Emits the elements of a fixed array, one per next() call, deep-cloning each
// element so callers cannot mutate the backing array through the results.
class ArrayIterator extends LazyIterator {
    constructor(items) {
        super();
        this.items = items;
        this.trav = 0;
    }
    summary() {
        return `Array of ${this.items.length} items`;
    }
    async next() {
        const exhausted = this.trav >= this.items.length;
        if (exhausted) {
            return { value: null, done: true };
        }
        const element = this.items[this.trav++];
        return { value: deepClone(element), done: false };
    }
}
65760 class FunctionCallIterator extends LazyIterator {
65761 constructor(nextFn) {
65762 super();
65763 this.nextFn = nextFn;
65764 }
65765 summary() {
65766 return `Function call`;
65767 }
65768 async next() {
65769 try {
65770 return this.nextFn();
65771 }
65772 catch (e) {
65773 // Modify the error message but leave the stack trace intact
65774 e.message =
65775 `Error thrown while iterating through a dataset: ${e.message}`;
65776 throw e;
65777 }
65778 }
65779 }
    class SerialIterator extends LazyIterator {
        /**
         * Forces serial execution of an upstream iterator: each next() call
         * awaits the prior one, so upstream reads never run concurrently.
         */
        constructor(upstream) {
            super();
            this.upstream = upstream;
            // Pre-resolved sentinel so the first next() has a Promise to chain on.
            this.lastRead = Promise.resolve({ value: null, done: false });
        }
        summary() {
            return `${this.upstream.summary()} -> Serial`;
        }
        async next() {
            // This sets this.lastRead to a new Promise right away, as opposed to
            // saying `await this.lastRead; this.lastRead = this.serialNext();` which
            // would not work because this.nextRead would be updated only after the
            // promise resolves.
            this.lastRead = this.lastRead.then(() => this.serialNext());
            return this.lastRead;
        }
        async serialNext() {
            return this.upstream.next();
        }
    }
    class SkipIterator extends LazyIterator {
        /**
         * Skips the first `maxCount` items of the upstream iterator, then
         * passes the rest through. Skipped items containing Tensors are
         * disposed.
         */
        constructor(upstream, maxCount) {
            super();
            this.upstream = upstream;
            this.maxCount = maxCount;
            // Local state that should not be clobbered by out-of-order execution.
            this.count = 0;
            this.lastRead = Promise.resolve({ value: null, done: false });
        }
        summary() {
            return `${this.upstream.summary()} -> Skip`;
        }
        async next() {
            // This sets this.lastRead to a new Promise right away, as opposed to
            // saying `await this.lastRead; this.lastRead = this.serialNext();` which
            // would not work because this.nextRead would be updated only after the
            // promise resolves.
            this.lastRead = this.lastRead.then(() => this.serialNext());
            return this.lastRead;
        }
        async serialNext() {
            // TODO(soergel): consider tradeoffs of reading in parallel, eg.
            // collecting next() promises in an Array and then waiting for
            // Promise.all() of those. Benefit: pseudo-parallel execution. Drawback:
            // maybe delayed GC.
            while (this.count++ < this.maxCount) {
                const skipped = await this.upstream.next();
                // short-circuit if upstream is already empty
                if (skipped.done) {
                    return skipped;
                }
                // Dispose any Tensors held in the skipped item to avoid leaks.
                dispose(skipped.value);
            }
            return this.upstream.next();
        }
    }
65837 class TakeIterator extends LazyIterator {
65838 constructor(upstream, maxCount) {
65839 super();
65840 this.upstream = upstream;
65841 this.maxCount = maxCount;
65842 this.count = 0;
65843 }
65844 summary() {
65845 return `${this.upstream.summary()} -> Take`;
65846 }
65847 async next() {
65848 if (this.count++ >= this.maxCount) {
65849 return { value: null, done: true };
65850 }
65851 return this.upstream.next();
65852 }
65853 }
65854 // Note this batch just groups items into row-wise element arrays.
65855 // Rotating these to a column-wise representation happens only at the dataset
65856 // level.
    class RowMajorBatchIterator extends LazyIterator {
        /**
         * Groups upstream items into row-wise arrays of length `batchSize`.
         * @param enableSmallLastBatch Whether a final batch with fewer than
         *     `batchSize` items is emitted (default true) or dropped.
         */
        constructor(upstream, batchSize, enableSmallLastBatch = true) {
            super();
            this.upstream = upstream;
            this.batchSize = batchSize;
            this.enableSmallLastBatch = enableSmallLastBatch;
            this.lastRead = Promise.resolve({ value: null, done: false });
        }
        summary() {
            return `${this.upstream.summary()} -> RowMajorBatch`;
        }
        async next() {
            // This sets this.lastRead to a new Promise right away, as opposed to
            // saying `await this.lastRead; this.lastRead = this.serialNext();` which
            // would not work because this.nextRead would be updated only after the
            // promise resolves.
            this.lastRead = this.lastRead.then(() => this.serialNext());
            return this.lastRead;
        }
        async serialNext() {
            const batch = [];
            while (batch.length < this.batchSize) {
                const item = await this.upstream.next();
                if (item.done) {
                    // Upstream exhausted mid-batch: emit the partial batch if
                    // small last batches are enabled, otherwise discard it.
                    if (this.enableSmallLastBatch && batch.length > 0) {
                        return { value: batch, done: false };
                    }
                    return { value: null, done: true };
                }
                batch.push(item.value);
            }
            return { value: batch, done: false };
        }
    }
    class FilterIterator extends LazyIterator {
        /**
         * Passes through only the upstream items for which `predicate`
         * returns truthy; rejected items containing Tensors are disposed.
         */
        constructor(upstream, predicate) {
            super();
            this.upstream = upstream;
            this.predicate = predicate;
            this.lastRead = Promise.resolve({ value: null, done: false });
        }
        summary() {
            return `${this.upstream.summary()} -> Filter`;
        }
        async next() {
            // This sets this.lastRead to a new Promise right away, as opposed to
            // saying `await this.lastRead; this.lastRead = this.serialNext();` which
            // would not work because this.nextRead would be updated only after the
            // promise resolves.
            this.lastRead = this.lastRead.then(() => this.serialNext());
            return this.lastRead;
        }
        async serialNext() {
            // Pull upstream until an item passes the predicate or the stream ends.
            while (true) {
                const item = await this.upstream.next();
                if (item.done || this.predicate(item.value)) {
                    return item;
                }
                // Rejected: free any Tensors in the discarded item.
                dispose(item.value);
            }
        }
    }
65919 class MapIterator extends LazyIterator {
65920 constructor(upstream, transform) {
65921 super();
65922 this.upstream = upstream;
65923 this.transform = transform;
65924 }
65925 summary() {
65926 return `${this.upstream.summary()} -> Map`;
65927 }
65928 async next() {
65929 const item = await this.upstream.next();
65930 if (item.done) {
65931 return { value: null, done: true };
65932 }
65933 const inputTensors = getTensorsInContainer(item.value);
65934 // Careful: the transform may mutate the item in place.
65935 // That's why we have to remember the input Tensors above, and then
65936 // below dispose only those that were not passed through to the output.
65937 // Note too that the transform function is responsible for tidying
65938 // any intermediate Tensors. Here we are concerned only about the
65939 // inputs.
65940 const mapped = this.transform(item.value);
65941 const outputTensors = getTensorsInContainer(mapped);
65942 // TODO(soergel) faster intersection
65943 // TODO(soergel) move to tf.disposeExcept(in, out)?
65944 for (const t of inputTensors) {
65945 if (!isTensorInList(t, outputTensors)) {
65946 t.dispose();
65947 }
65948 }
65949 return { value: mapped, done: false };
65950 }
65951 }
    class ErrorHandlingLazyIterator extends LazyIterator {
        /**
         * Wraps an upstream iterator so that errors thrown by its next() are
         * routed through `handler`. A truthy handler result means "ignore and
         * keep reading"; falsy means "end the stream".
         */
        constructor(upstream, handler) {
            super();
            this.upstream = upstream;
            this.handler = handler;
            // Note: this.count is not read anywhere in this class.
            this.count = 0;
            this.lastRead = Promise.resolve({ value: null, done: false });
        }
        summary() {
            return `${this.upstream.summary()} -> handleErrors`;
        }
        async next() {
            // This sets this.lastRead to a new Promise right away, as opposed to
            // saying `await this.lastRead; this.lastRead = this.serialNext();` which
            // would not work because this.nextRead would be updated only after the
            // promise resolves.
            this.lastRead = this.lastRead.then(() => this.serialNext());
            return this.lastRead;
        }
        async serialNext() {
            while (true) {
                try {
                    return await this.upstream.next();
                }
                catch (e) {
                    if (!this.handler(e)) {
                        return { value: null, done: true };
                    }
                    // If the handler returns true, loop and fetch the next upstream item.
                    // If the upstream iterator throws an endless stream of errors, and if
                    // the handler says to ignore them, then we loop forever here. That is
                    // the correct behavior-- it's up to the handler to decide when to stop.
                }
            }
        }
    }
65988 class AsyncMapIterator extends LazyIterator {
65989 constructor(upstream, transform) {
65990 super();
65991 this.upstream = upstream;
65992 this.transform = transform;
65993 }
65994 summary() {
65995 return `${this.upstream.summary()} -> AsyncMap`;
65996 }
65997 async next() {
65998 const item = await this.upstream.next();
65999 if (item.done) {
66000 return { value: null, done: true };
66001 }
66002 const inputTensors = getTensorsInContainer(item.value);
66003 // Careful: the transform may mutate the item in place.
66004 // That's why we have to remember the input Tensors above, and then
66005 // below dispose only those that were not passed through to the output.
66006 // Note too that the transform function is responsible for tidying
66007 // any intermediate Tensors. Here we are concerned only about the
66008 // inputs.
66009 const mapped = await this.transform(item.value);
66010 const outputTensors = getTensorsInContainer(mapped);
66011 // TODO(soergel) faster intersection
66012 // TODO(soergel) move to tf.disposeExcept(in, out)?
66013 for (const t of inputTensors) {
66014 if (!isTensorInList(t, outputTensors)) {
66015 t.dispose();
66016 }
66017 }
66018 return { value: mapped, done: false };
66019 }
66020 }
66021 // Iterators that maintain a queue of pending items
66022 // ============================================================================
66023 /**
66024 * A base class for transforming streams that operate by maintaining an
66025 * output queue of elements that are ready to return via next(). This is
66026 * commonly required when the transformation is 1-to-many: A call to next()
66027 * may trigger a call to the underlying stream, which will produce many
66028 * mapped elements of this stream-- of which we need to return only one, so
66029 * we have to queue the rest.
66030 */
    class OneToManyIterator extends LazyIterator {
        constructor() {
            super();
            // Elements already produced by pump() and awaiting return via next().
            this.outputQueue = new GrowingRingBuffer();
            this.lastRead = Promise.resolve({ value: null, done: false });
        }
        async next() {
            // This sets this.lastRead to a new Promise right away, as opposed to
            // saying `await this.lastRead; this.lastRead = this.serialNext();` which
            // would not work because this.nextRead would be updated only after the
            // promise resolves.
            this.lastRead = this.lastRead.then(() => this.serialNext());
            return this.lastRead;
        }
        async serialNext() {
            // Fetch so that the queue contains at least one item if possible.
            // If the upstream source is exhausted, AND there are no items left in
            // the output queue, then this stream is also exhausted.
            while (this.outputQueue.length() === 0) {
                // TODO(soergel): consider parallel reads.
                // pump() (implemented by subclasses) returns false when the
                // upstream is exhausted.
                if (!await this.pump()) {
                    return { value: null, done: true };
                }
            }
            return { value: this.outputQueue.shift(), done: false };
        }
    }
66058 class FlatmapIterator extends OneToManyIterator {
66059 constructor(upstream, transform) {
66060 super();
66061 this.upstream = upstream;
66062 this.transform = transform;
66063 }
66064 summary() {
66065 return `${this.upstream.summary()} -> Flatmap`;
66066 }
66067 async pump() {
66068 const item = await this.upstream.next();
66069 if (item.done) {
66070 return false;
66071 }
66072 const inputTensors = getTensorsInContainer(item.value);
66073 // Careful: the transform may mutate the item in place.
66074 // that's why we have to remember the input Tensors above, and then
66075 // below dispose only those that were not passed through to the output.
66076 // Note too that the transform function is responsible for tidying any
66077 // intermediate Tensors. Here we are concerned only about the inputs.
66078 const mappedArray = this.transform(item.value);
66079 const outputTensors = getTensorsInContainer(mappedArray);
66080 this.outputQueue.pushAll(mappedArray);
66081 // TODO(soergel) faster intersection, and deduplicate outputTensors
66082 // TODO(soergel) move to tf.disposeExcept(in, out)?
66083 for (const t of inputTensors) {
66084 if (!isTensorInList(t, outputTensors)) {
66085 t.dispose();
66086 }
66087 }
66088 return true;
66089 }
66090 }
66091 /**
66092 * Provides a `LazyIterator` that concatenates a stream of underlying
66093 * streams.
66094 *
66095 * Doing this in a concurrency-safe way requires some trickery. In
66096 * particular, we want this stream to return the elements from the
66097 * underlying streams in the correct order according to when next() was
66098 * called, even if the resulting Promises resolve in a different order.
66099 */
    class ChainedIterator extends LazyIterator {
        /**
         * Concatenates a stream of underlying iterators into one stream,
         * preserving next()-call order even when Promises resolve out of order.
         *
         * @param iterators An iterator whose elements are themselves
         *     LazyIterators to be drained in sequence.
         * @param baseErrorHandler Optional error handler applied (via
         *     handleErrors) to each underlying iterator as it is opened.
         */
        constructor(iterators, baseErrorHandler) {
            super();
            this.baseErrorHandler = baseErrorHandler;
            // Strict Promise execution order:
            // a next() call may not even begin until the previous one completes.
            this.lastRead = null;
            // Local state that should not be clobbered by out-of-order execution.
            this.iterator = null;
            this.moreIterators = iterators;
        }
        summary() {
            const upstreamSummaries = 'TODO: fill in upstream of chained summaries';
            return `${upstreamSummaries} -> Chained`;
        }
        async next() {
            this.lastRead = this.readFromChain(this.lastRead);
            return this.lastRead;
        }
        async readFromChain(lastRead) {
            // Must await on the previous read since the previous read may have advanced
            // the stream of streams, from which we need to read.
            // This is unfortunate since we can't parallelize reads. Which means
            // prefetching of chained streams is a no-op.
            // One solution is to prefetch immediately upstream of this.
            await lastRead;
            if (this.iterator == null) {
                // Current sub-iterator is exhausted (or we haven't started):
                // advance to the next one.
                const iteratorResult = await this.moreIterators.next();
                if (iteratorResult.done) {
                    // No more streams to stream from.
                    return { value: null, done: true };
                }
                this.iterator = iteratorResult.value;
                if (this.baseErrorHandler != null) {
                    this.iterator = this.iterator.handleErrors(this.baseErrorHandler);
                }
            }
            const itemResult = await this.iterator.next();
            if (itemResult.done) {
                // Sub-iterator exhausted: recurse to open the next one.
                this.iterator = null;
                return this.readFromChain(lastRead);
            }
            return itemResult;
        }
    }
    // Behavior when zipped streams have unequal lengths (TypeScript enum emit
    // pattern: maps both name -> number and number -> name).
    var ZipMismatchMode;
    (function (ZipMismatchMode) {
        ZipMismatchMode[ZipMismatchMode["FAIL"] = 0] = "FAIL"; // throw an error on length mismatch.
        ZipMismatchMode[ZipMismatchMode["SHORTEST"] = 1] = "SHORTEST"; // terminate with the shortest stream.
        ZipMismatchMode[ZipMismatchMode["LONGEST"] = 2] = "LONGEST"; // use nulls for exhausted streams; use up the longest stream.
    })(ZipMismatchMode || (ZipMismatchMode = {}));
66151 /**
66152 * Provides a `LazyIterator` that zips together an array, dict, or nested
66153 * structure of `LazyIterator`s (and perhaps additional constants).
66154 *
66155 * The underlying streams must provide elements in a consistent order such
66156 * that they correspond.
66157 *
66158 * Typically, the underlying streams should have the same number of
66159 * elements. If they do not, the behavior is determined by the
66160 * `mismatchMode` argument.
66161 *
66162 * The nested structure of the `iterators` argument determines the
66163 * structure of elements in the resulting iterator.
66164 *
66165 * Doing this in a concurrency-safe way requires some trickery. In
66166 * particular, we want this stream to return the elements from the
66167 * underlying streams in the correct order according to when next() was
66168 * called, even if the resulting Promises resolve in a different order.
66169 *
66170 * @param iterators: An array or object containing LazyIterators at the
66171 * leaves.
66172 * @param mismatchMode: Determines what to do when one underlying iterator
66173 * is exhausted before the others. `ZipMismatchMode.FAIL` (the default)
66174 * causes an error to be thrown in this case. `ZipMismatchMode.SHORTEST`
66175 * causes the zipped iterator to terminate with the furst underlying
66176 * streams, so elements remaining on the longer streams are ignored.
66177 * `ZipMismatchMode.LONGEST` causes the zipped stream to continue, filling
66178 * in nulls for the exhausted streams, until all streams are exhausted.
66179 */
    class ZipIterator extends LazyIterator {
        constructor(iterators, mismatchMode = ZipMismatchMode.FAIL) {
            super();
            this.iterators = iterators;
            this.mismatchMode = mismatchMode;
            // Number of zipped elements produced so far; used in the FAIL
            // error message.
            this.count = 0;
            // Serial chain of pending nextState() calls; see next().
            this.currentPromise = null;
        }
        summary() {
            const upstreamSummaries = 'TODO: fill in upstream of zip summaries';
            return `{${upstreamSummaries}} -> Zip`;
        }
        async nextState(afterState) {
            // This chaining ensures that the underlying next() are not even called
            // before the previous ones have resolved.
            await afterState;
            // Collect underlying iterator "done" signals as a side effect in
            // getNext()
            let numIterators = 0;
            let iteratorsDone = 0;
            function getNext(container) {
                if (container instanceof LazyIterator) {
                    const result = container.next();
                    return {
                        // The counters are updated inside then(), so they are only
                        // final once all these Promises have been awaited (which
                        // deepMapAndAwaitAll below guarantees).
                        value: result.then(x => {
                            numIterators++;
                            if (x.done) {
                                iteratorsDone++;
                            }
                            return x.value;
                        }),
                        recurse: false
                    };
                }
                else {
                    // Not an iterator: recurse into the nested container.
                    return { value: null, recurse: true };
                }
            }
            const mapped = await deepMapAndAwaitAll(this.iterators, getNext);
            if (numIterators === iteratorsDone) {
                // The streams have all ended.
                return { value: null, done: true };
            }
            if (iteratorsDone > 0) {
                // Some, but not all, streams are exhausted: resolve per mismatchMode.
                switch (this.mismatchMode) {
                    case ZipMismatchMode.FAIL:
                        throw new Error('Zipped streams should have the same length. ' +
                            `Mismatched at element ${this.count}.`);
                    case ZipMismatchMode.SHORTEST:
                        return { value: null, done: true };
                    case ZipMismatchMode.LONGEST:
                    default:
                    // Continue. The exhausted streams already produced value: null.
                }
            }
            this.count++;
            return { value: mapped, done: false };
        }
        async next() {
            this.currentPromise = this.nextState(this.currentPromise);
            return this.currentPromise;
        }
    }
66243 // Iterators that maintain a ring buffer of pending promises
66244 // ============================================================================
66245 /**
66246 * A stream that prefetches a given number of items from an upstream source,
66247 * returning them in FIFO order.
66248 *
66249 * Note this prefetches Promises, but makes no guarantees about when those
66250 * Promises resolve.
66251 */
66252 class PrefetchIterator extends LazyIterator {
66253 constructor(upstream, bufferSize) {
66254 super();
66255 this.upstream = upstream;
66256 this.bufferSize = bufferSize;
66257 this.buffer = new RingBuffer(bufferSize);
66258 }
66259 summary() {
66260 return `${this.upstream.summary()} -> Prefetch`;
66261 }
66262 /**
66263 * Refill the prefetch buffer. Returns only after the buffer is full, or
66264 * the upstream source is exhausted.
66265 */
66266 refill() {
66267 while (!this.buffer.isFull()) {
66268 const v = this.upstream.next();
66269 this.buffer.push(v);
66270 }
66271 }
66272 next() {
66273 this.refill();
66274 // This shift will never throw an error because the buffer is always
66275 // full after a refill. If the stream is exhausted, the buffer will be
66276 // full of Promises that will resolve to the end-of-stream signal.
66277 return this.buffer.shift();
66278 }
66279 }
66280 /**
66281 * A stream that performs a sliding-window random shuffle on an upstream
66282 * source. This is like a `PrefetchIterator` except that the items are
66283 * returned in randomized order. Mixing naturally improves as the buffer
66284 * size increases.
66285 */
    class ShuffleIterator extends PrefetchIterator {
        /**
         * Sliding-window random shuffle over an upstream iterator: items are
         * drawn uniformly from a prefetch buffer of `windowSize` elements.
         */
        constructor(upstream, windowSize, seed) {
            super(upstream, windowSize);
            this.upstream = upstream;
            this.windowSize = windowSize;
            // Local state that should not be clobbered by out-of-order execution.
            this.upstreamExhausted = false;
            // Seeded PRNG; falls back to a time-derived seed when none is given.
            this.random = seedrandom_1(seed || now().toString());
            this.lastRead = Promise.resolve({ value: null, done: false });
        }
        async next() {
            // This sets this.lastRead to a new Promise right away, as opposed to
            // saying `await this.lastRead; this.lastRead = this.serialNext();` which
            // would not work because this.nextRead would be updated only after the
            // promise resolves.
            this.lastRead = this.lastRead.then(() => this.serialNext());
            return this.lastRead;
        }
        // Uniform random integer in [0, max).
        randomInt(max) {
            return Math.floor(this.random() * max);
        }
        chooseIndex() {
            return this.randomInt(this.buffer.length());
        }
        async serialNext() {
            // TODO(soergel): consider performance
            if (!this.upstreamExhausted) {
                this.refill();
            }
            while (!this.buffer.isEmpty()) {
                // Excise a random buffered Promise and await it (shuffleExcise is
                // defined elsewhere; presumably it removes the chosen slot and
                // backfills it -- confirm against the RingBuffer implementation).
                const chosenIndex = this.chooseIndex();
                const result = await this.buffer.shuffleExcise(chosenIndex);
                if (result.done) {
                    this.upstreamExhausted = true;
                }
                else {
                    this.refill();
                    return result;
                }
            }
            return { value: null, done: true };
        }
    }
66329
66330 /**
66331 * @license
66332 * Copyright 2018 Google LLC. All Rights Reserved.
66333 * Licensed under the Apache License, Version 2.0 (the "License");
66334 * you may not use this file except in compliance with the License.
66335 * You may obtain a copy of the License at
66336 *
66337 * http://www.apache.org/licenses/LICENSE-2.0
66338 *
66339 * Unless required by applicable law or agreed to in writing, software
66340 * distributed under the License is distributed on an "AS IS" BASIS,
66341 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
66342 * See the License for the specific language governing permissions and
66343 * limitations under the License.
66344 *
66345 * =============================================================================
66346 */
66347 // TODO(soergel): consider vectorized operations within the pipeline.
66348 /**
66349 * Represents a potentially large list of independent data elements (typically
66350 * 'samples' or 'examples').
66351 *
66352 * A 'data example' may be a primitive, an array, a map from string keys to
66353 * values, or any nested structure of these.
66354 *
66355 * A `Dataset` represents an ordered collection of elements, together with a
66356 * chain of transformations to be performed on those elements. Each
66357 * transformation is a method of `Dataset` that returns another `Dataset`, so
66358 * these may be chained, e.g.
66359 * `const processedDataset = rawDataset.filter(...).map(...).batch(...)`.
66360 *
66361 * Data loading and transformation is done in a lazy, streaming fashion. The
66362 * dataset may be iterated over multiple times; each iteration starts the data
66363 * loading anew and recapitulates the transformations.
66364 *
66365 * A `Dataset` is typically processed as a stream of unbatched examples --i.e.,
66366 * its transformations are applied one example at a time. Batching produces a
66367 * new `Dataset` where each element is a batch. Batching should usually come
66368 * last in a pipeline, because data transformations are easier to express on a
66369 * per-example basis than on a per-batch basis.
66370 *
66371 * The following code examples are calling `await dataset.forEachAsync(...)` to
66372 * iterate once over the entire dataset in order to print out the data.
66373 *
66374 * @doc {heading: 'Data', subheading: 'Classes', namespace: 'data'}
66375 */
66376 class Dataset {
        constructor() {
            // Element count of this dataset: null when unknown; may be set to a
            // finite number or Infinity by subclasses / factory functions.
            this.size = null;
        }
66380 // TODO(soergel): Make Datasets report whether repeated iterator() calls
66381 // produce the same result (e.g., reading from a file) or different results
66382 // (e.g., from the webcam). Currently we don't make this distinction but it
66383 // could be important for the user to know.
66384 // abstract isDeterministic(): boolean;
66385 /**
66386 * Groups elements into batches.
66387 *
66388 * It is assumed that each of the incoming dataset elements has the same
66389 * structure-- i.e. the same set of keys at each location in an object
66390 * hierarchy. For each key, the resulting `Dataset` provides a batched
66391 * element collecting all of the incoming values for that key.
66392 *
66393 * * Incoming primitives are grouped into a 1-D Tensor.
66394 * * Incoming Tensors are grouped into a new Tensor where the 0'th axis is
66395 * the batch dimension.
66396 * * Incoming arrays are converted to Tensor and then batched.
66397 * * A nested array is interpreted as an n-D Tensor, so the batched result
66398 * has n+1 dimensions.
66399 * * An array that cannot be converted to Tensor produces an error.
66400 *
66401 * If an array should not be batched as a unit, it should first be converted
66402 * to an object with integer keys.
66403 *
66404 * Here are a few examples:
66405 *
66406 * Batch a dataset of numbers:
66407 * ```js
66408 * const a = tf.data.array([1, 2, 3, 4, 5, 6, 7, 8]).batch(4);
66409 * await a.forEachAsync(e => e.print());
66410 * ```
66411 *
66412 * Batch a dataset of arrays:
66413 * ```js
66414 * const b = tf.data.array([[1], [2], [3], [4], [5], [6], [7], [8]]).batch(4);
66415 * await b.forEachAsync(e => e.print());
66416 * ```
66417 *
66418 * Batch a dataset of objects:
66419 * ```js
66420 * const c = tf.data.array([{a: 1, b: 11}, {a: 2, b: 12}, {a: 3, b: 13},
66421 * {a: 4, b: 14}, {a: 5, b: 15}, {a: 6, b: 16}, {a: 7, b: 17},
66422 * {a: 8, b: 18}]).batch(4);
66423 * await c.forEachAsync(e => {
66424 * console.log('{');
66425 * for(var key in e) {
66426 * console.log(key+':');
66427 * e[key].print();
66428 * }
66429 * console.log('}');
66430 * })
66431 * ```
66432 *
66433 * @param batchSize The number of elements desired per batch.
66434 * @param smallLastBatch Whether to emit the final batch when it has fewer
66435 * than batchSize elements. Default true.
66436 * @returns A `Dataset`, from which a stream of batches can be obtained.
66437 *
66438 * @doc {heading: 'Data', subheading: 'Classes'}
66439 */
        batch(batchSize, smallLastBatch = true) {
            const base = this;
            assert(batchSize > 0, () => `batchSize needs to be positive, but it is
      ${batchSize}`);
            let size;
            if (this.size === Infinity || this.size == null) {
                // If the size of this dataset is infinity or null, the new size keeps the
                // same.
                size = this.size;
            }
            else if (smallLastBatch) {
                // If the size of this dataset is known and include small last batch, the
                // new size is full batch count plus last batch.
                size = Math.ceil(this.size / batchSize);
            }
            else {
                // If the size of this dataset is known and not include small last batch,
                // the new size is full batch count.
                size = Math.floor(this.size / batchSize);
            }
            return datasetFromIteratorFn(async () => {
                // Row-wise grouping happens at the iterator level; rotation to the
                // column-major (batched) representation is done by deepBatchConcat
                // inside columnMajorBatch.
                return (await base.iterator())
                    .columnMajorBatch(batchSize, smallLastBatch, deepBatchConcat);
            }, size);
        }
66465 /**
66466 * Concatenates this `Dataset` with another.
66467 *
66468 * ```js
66469 * const a = tf.data.array([1, 2, 3]);
66470 * const b = tf.data.array([4, 5, 6]);
66471 * const c = a.concatenate(b);
66472 * await c.forEachAsync(e => console.log(e));
66473 * ```
66474 *
66475 * @param dataset A `Dataset` to be concatenated onto this one.
66476 * @returns A `Dataset`.
66477 *
66478 * @doc {heading: 'Data', subheading: 'Classes'}
66479 */
66480 concatenate(dataset) {
66481 const base = this;
66482 let size;
66483 if (this.size === Infinity || dataset.size === Infinity) {
66484 // If the size of any of these two dataset is infinity, new size is
66485 // infinity.
66486 size = Infinity;
66487 }
66488 else if (this.size != null && dataset.size != null) {
66489 // If the size of both datasets are known and not infinity, new size is
66490 // sum the size of these two datasets.
66491 size = this.size + dataset.size;
66492 }
66493 else {
66494 // If neither of these two datasets has infinite size and any of these two
66495 // datasets' size is null, the new size is null.
66496 size = null;
66497 }
66498 return datasetFromIteratorFn(async () => (await base.iterator()).concatenate(await dataset.iterator()), size);
66499 }
66500 /**
66501 * Filters this dataset according to `predicate`.
66502 *
66503 * ```js
66504 * const a = tf.data.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
66505 * .filter(x => x%2 === 0);
66506 * await a.forEachAsync(e => console.log(e));
66507 * ```
66508 *
66509 * @param predicate A function mapping a dataset element to a boolean or a
66510 * `Promise` for one.
66511 *
66512 * @returns A `Dataset` of elements for which the predicate was true.
66513 *
66514 * @doc {heading: 'Data', subheading: 'Classes'}
66515 */
66516 filter(predicate) {
66517 const base = this;
66518 let size;
66519 if (this.size === Infinity) {
66520 // If the size of this dataset is infinity, new size is infinity
66521 size = Infinity;
66522 }
66523 else {
66524 // If this dataset has limited elements, new size is null because it might
66525 // exhausted randomly.
66526 size = null;
66527 }
66528 return datasetFromIteratorFn(async () => {
66529 return (await base.iterator()).filter(x => tidy(() => predicate(x)));
66530 }, size);
66531 }
66532 /**
66533 * Apply a function to every element of the dataset.
66534 *
66535 * After the function is applied to a dataset element, any Tensors contained
66536 * within that element are disposed.
66537 *
66538 * ```js
66539 * const a = tf.data.array([1, 2, 3]);
66540 * await a.forEachAsync(e => console.log(e));
66541 * ```
66542 *
66543 * @param f A function to apply to each dataset element.
66544 * @returns A `Promise` that resolves after all elements have been processed.
66545 *
66546 * @doc {heading: 'Data', subheading: 'Classes'}
66547 */
        async forEachAsync(f) {
            // Obtains a fresh iterator, so this consumes one full pass over the
            // dataset and delegates to the iterator-level forEachAsync.
            return (await this.iterator()).forEachAsync(f);
        }
66551 /**
66552 * Maps this dataset through a 1-to-1 transform.
66553 *
66554 * ```js
66555 * const a = tf.data.array([1, 2, 3]).map(x => x*x);
66556 * await a.forEachAsync(e => console.log(e));
66557 * ```
66558 *
66559 * @param transform A function mapping a dataset element to a transformed
66560 * dataset element.
66561 *
66562 * @returns A `Dataset` of transformed elements.
66563 *
66564 * @doc {heading: 'Data', subheading: 'Classes'}
66565 */
66566 map(transform) {
66567 const base = this;
66568 return datasetFromIteratorFn(async () => {
66569 return (await base.iterator()).map(x => tidy(() => transform(x)));
66570 }, this.size);
66571 }
66572 /**
66573 * Maps this dataset through an async 1-to-1 transform.
66574 *
66575 * ```js
66576 * const a =
66577 * tf.data.array([1, 2, 3]).mapAsync(x => new Promise(function(resolve){
66578 * setTimeout(() => {
66579 * resolve(x * x);
66580 * }, Math.random()*1000 + 500);
66581 * }));
66582 * console.log(await a.toArray());
66583 * ```
66584 *
66585 * @param transform A function mapping a dataset element to a `Promise` for a
66586 * transformed dataset element. This transform is responsible for disposing
66587 * any intermediate `Tensor`s, i.e. by wrapping its computation in
66588 * `tf.tidy()`; that cannot be automated here (as it is in the synchronous
66589 * `map()` case).
66590 *
66591 * @returns A `Dataset` of transformed elements.
66592 *
66593 * @doc {heading: 'Data', subheading: 'Classes'}
66594 */
66595 mapAsync(transform) {
66596 const base = this;
66597 return datasetFromIteratorFn(async () => {
66598 return (await base.iterator()).mapAsync(transform);
66599 }, this.size);
66600 }
66601 /**
66602 * Creates a `Dataset` that prefetches elements from this dataset.
66603 *
66604 * @param bufferSize: An integer specifying the number of elements to be
66605 * prefetched.
66606 * @returns A `Dataset`.
66607 *
66608 * @doc {heading: 'Data', subheading: 'Classes'}
66609 */
66610 prefetch(bufferSize) {
66611 if (bufferSize == null) {
66612 throw new RangeError('`Dataset.prefetch()` requires bufferSize to be specified.');
66613 }
66614 const base = this;
66615 return datasetFromIteratorFn(async () => (await base.iterator()).prefetch(bufferSize), this.size);
66616 }
66617 /**
66618 * Repeats this dataset `count` times.
66619 *
66620 * NOTE: If this dataset is a function of global state (e.g. a random number
66621 * generator), then different repetitions may produce different elements.
66622 *
66623 * ```js
66624 * const a = tf.data.array([1, 2, 3]).repeat(3);
66625 * await a.forEachAsync(e => console.log(e));
66626 * ```
66627 *
66628 * @param count: (Optional) An integer, representing the number of times
66629 * the dataset should be repeated. The default behavior (if `count` is
66630 * `undefined` or negative) is for the dataset be repeated indefinitely.
66631 * @returns A `Dataset`.
66632 *
66633 * @doc {heading: 'Data', subheading: 'Classes'}
66634 */
66635 repeat(count) {
66636 const base = this;
66637 let size;
66638 if (this.size != null && count > 0) {
66639 // If this dataset has size and count is positive, new size is current
66640 // size multiply count. This also covers the case that current size is
66641 // infinity.
66642 size = this.size * count;
66643 }
66644 else if (count === 0) {
66645 // If count is 0, new size is 0.
66646 size = 0;
66647 }
66648 else if (this.size != null && (count === undefined || count < 0)) {
66649 // If this dataset has size and count is undefined or negative, the
66650 // dataset will be repeated indefinitely and new size is infinity.
66651 size = Infinity;
66652 }
66653 else {
66654 // If the size of this dataset is null, the new dataset's size is null.
66655 size = null;
66656 }
66657 return datasetFromIteratorFn(async () => {
66658 const iteratorIterator = iteratorFromFunction(async () => ({ value: await base.iterator(), done: false }));
66659 return iteratorFromConcatenated(iteratorIterator.take(count));
66660 }, size);
66661 }
66662 /**
66663 * Creates a `Dataset` that skips `count` initial elements from this dataset.
66664 *
66665 * ```js
66666 * const a = tf.data.array([1, 2, 3, 4, 5, 6]).skip(3);
66667 * await a.forEachAsync(e => console.log(e));
66668 * ```
66669 *
66670 * @param count: The number of elements of this dataset that should be skipped
66671 * to form the new dataset. If `count` is greater than the size of this
66672 * dataset, the new dataset will contain no elements. If `count`
66673 * is `undefined` or negative, skips the entire dataset.
66674 *
66675 * @returns A `Dataset`.
66676 *
66677 * @doc {heading: 'Data', subheading: 'Classes'}
66678 */
66679 skip(count) {
66680 const base = this;
66681 let size;
66682 if (this.size != null && count >= 0 && this.size >= count) {
66683 // If the size of this dataset is greater than count, the new dataset's
66684 // size is current size minus skipped size.This also covers the case that
66685 // current size is infinity.
66686 size = this.size - count;
66687 }
66688 else if (this.size != null &&
66689 (this.size < count || count === undefined || count < 0)) {
66690 // If the size of this dataset is smaller than count, or count is
66691 // undefined or negative, skips the entire dataset and the new size is 0.
66692 size = 0;
66693 }
66694 else {
66695 // If the size of this dataset is null, the new dataset's size is null.
66696 size = null;
66697 }
66698 return datasetFromIteratorFn(async () => (await base.iterator()).skip(count), size);
66699 }
66700 /**
66701 * Pseudorandomly shuffles the elements of this dataset. This is done in a
66702 * streaming manner, by sampling from a given number of prefetched elements.
66703 *
66704 * ```js
66705 * const a = tf.data.array([1, 2, 3, 4, 5, 6]).shuffle(3);
66706 * await a.forEachAsync(e => console.log(e));
66707 * ```
66708 *
66709 * @param bufferSize: An integer specifying the number of elements from this
66710 * dataset from which the new dataset will sample.
66711 * @param seed: (Optional) An integer specifying the random seed that will
66712 * be used to create the distribution.
66713 * @param reshuffleEachIteration: (Optional) A boolean, which if true
66714 * indicates that the dataset should be pseudorandomly reshuffled each time
66715 * it is iterated over. If false, elements will be returned in the same
66716 * shuffled order on each iteration. (Defaults to `true`.)
66717 * @returns A `Dataset`.
66718 *
66719 * @doc {heading: 'Data', subheading: 'Classes'}
66720 */
66721 shuffle(bufferSize, seed, reshuffleEachIteration = true) {
66722 if (bufferSize == null || bufferSize < 0) {
66723 if (this.size == null) {
66724 throw new RangeError('`Dataset.shuffle()` requires bufferSize to be specified.');
66725 }
66726 else {
66727 throw new RangeError('`Dataset.shuffle()` requires bufferSize to be specified. ' +
66728 'If your data fits in main memory (for regular JS objects), ' +
66729 'and/or GPU memory (for `tf.Tensor`s), consider setting ' +
66730 `bufferSize to the dataset size (${this.size} elements)`);
66731 }
66732 }
66733 const base = this;
66734 const random = seedrandom_1(seed || now().toString());
66735 return datasetFromIteratorFn(async () => {
66736 let seed2 = random.int32();
66737 if (reshuffleEachIteration) {
66738 seed2 += random.int32();
66739 }
66740 return (await base.iterator()).shuffle(bufferSize, seed2.toString());
66741 }, this.size);
66742 }
66743 /**
66744 * Creates a `Dataset` with at most `count` initial elements from this
66745 * dataset.
66746 *
66747 * ```js
66748 * const a = tf.data.array([1, 2, 3, 4, 5, 6]).take(3);
66749 * await a.forEachAsync(e => console.log(e));
66750 * ```
66751 *
66752 * @param count: The number of elements of this dataset that should be taken
66753 * to form the new dataset. If `count` is `undefined` or negative, or if
66754 * `count` is greater than the size of this dataset, the new dataset will
66755 * contain all elements of this dataset.
66756 * @returns A `Dataset`.
66757 *
66758 * @doc {heading: 'Data', subheading: 'Classes'}
66759 */
66760 take(count) {
66761 const base = this;
66762 let size;
66763 if (this.size != null && this.size > count) {
66764 // If the size of this dataset is greater than count, the new dataset's
66765 // size is count.
66766 size = count;
66767 }
66768 else if (this.size != null && this.size <= count) {
66769 // If the size of this dataset is equal or smaller than count, the new
66770 // dataset's size is the size of this dataset.
66771 size = this.size;
66772 }
66773 else {
66774 // If the size of this dataset is null, the new dataset's size is null.
66775 size = null;
66776 }
66777 return datasetFromIteratorFn(async () => (await base.iterator()).take(count), size);
66778 }
66779 /**
66780 * Collect all elements of this dataset into an array.
66781 *
66782 * Obviously this will succeed only for small datasets that fit in memory.
66783 * Useful for testing and generally should be avoided if possible.
66784 *
66785 * ```js
66786 * const a = tf.data.array([1, 2, 3, 4, 5, 6]);
66787 * console.log(await a.toArray());
66788 * ```
66789 *
66790 * @returns A Promise for an array of elements, which will resolve
66791 * when a new stream has been obtained and fully consumed.
66792 *
66793 * @doc {heading: 'Data', subheading: 'Classes'}
66794 */
66795 async toArray() {
66796 if (this.size === Infinity) {
66797 throw new Error('Can not convert infinite data stream to array.');
66798 }
66799 return (await this.iterator()).toArray();
66800 }
66801 /**
66802 * Collect all elements of this dataset into an array with prefetching 100
66803 * elements. This is useful for testing, because the prefetch changes the
66804 * order in which the Promises are resolved along the processing pipeline.
66805 * This may help expose bugs where results are dependent on the order of
66806 * Promise resolution rather than on the logical order of the stream (i.e.,
66807 * due to hidden mutable state).
66808 *
66809 * @returns A Promise for an array of elements, which will resolve
66810 * when a new stream has been obtained and fully consumed.
66811 */
66812 async toArrayForTest() {
66813 if (this.size === Infinity) {
66814 throw new Error('Can not convert infinite data stream to array.');
66815 }
66816 return (await this.iterator()).toArrayForTest();
66817 }
66818 }
    // TODO(soergel): deep sharded shuffle, where supported
    // Presumably an upper bound on elements buffered in memory by Dataset
    // operations (e.g. batching/shuffling helpers) — usage not visible in this
    // chunk; TODO confirm against the rest of the file.
    Dataset.MAX_BUFFER_SIZE = 10000;
66821 /**
66822 * Create a `Dataset` defined by a provided iterator() function.
66823 *
66824 * ```js
66825 * let i = -1;
66826 * const func = () =>
66827 * ++i < 5 ? {value: i, done: false} : {value: null, done: true};
66828 * const iter = tf.data.iteratorFromFunction(func);
 * const ds = tf.data.datasetFromIteratorFn(() => iter);
66830 * await ds.forEachAsync(e => console.log(e));
66831 * ```
66832 */
66833 function datasetFromIteratorFn(iteratorFn, size = null) {
66834 return new class extends Dataset {
66835 constructor() {
66836 super(...arguments);
66837 this.size = size;
66838 }
66839 /*
66840 * Provide a new stream of elements. Note this will also start new streams
66841 * from any underlying `Dataset`s.
66842 */
66843 async iterator() {
66844 return iteratorFn();
66845 }
66846 }();
66847 }
66848 /**
66849 * Create a `Dataset` from an array of elements.
66850 *
66851 * Create a Dataset from an array of objects:
66852 * ```js
66853 * const a = tf.data.array([{'item': 1}, {'item': 2}, {'item': 3}]);
66854 * await a.forEachAsync(e => console.log(e));
66855 * ```
66856 *
66857 * Create a Dataset from an array of numbers:
66858 * ```js
66859 * const a = tf.data.array([4, 5, 6]);
66860 * await a.forEachAsync(e => console.log(e));
66861 * ```
66862 * @param items An array of elements that will be parsed as items in a dataset.
66863 *
66864 * @doc {heading: 'Data', subheading: 'Creation', namespace: 'data'}
66865 */
66866 function array(items) {
66867 return datasetFromIteratorFn(async () => iteratorFromItems(items), items.length);
66868 }
66869 /**
66870 * Create a `Dataset` by zipping together an array, dict, or nested
66871 * structure of `Dataset`s (and perhaps additional constants).
66872 * The underlying datasets must provide elements in a consistent order such that
66873 * they correspond.
66874 *
66875 * The number of elements in the resulting dataset is the same as the size of
66876 * the smallest dataset in datasets.
66877 *
66878 * The nested structure of the `datasets` argument determines the
66879 * structure of elements in the resulting iterator.
66880 *
66881 * Note this means that, given an array of two datasets that produce dict
66882 * elements, the result is a dataset that produces elements that are arrays
66883 * of two dicts:
66884 *
66885 * Zip an array of datasets:
66886 * ```js
66887 * console.log('Zip two datasets of objects:');
66888 * const ds1 = tf.data.array([{a: 1}, {a: 2}, {a: 3}]);
66889 * const ds2 = tf.data.array([{b: 4}, {b: 5}, {b: 6}]);
66890 * const ds3 = tf.data.zip([ds1, ds2]);
66891 * await ds3.forEachAsync(e => console.log(JSON.stringify(e)));
66892 *
66893 * // If the goal is to merge the dicts in order to produce elements like
66894 * // {a: ..., b: ...}, this requires a second step such as:
66895 * console.log('Merge the objects:');
66896 * const ds4 = ds3.map(x => {return {a: x[0].a, b: x[1].b}});
66897 * await ds4.forEachAsync(e => console.log(e));
66898 * ```
66899 *
66900 * Zip a dict of datasets:
66901 * ```js
66902 * const a = tf.data.array([{a: 1}, {a: 2}, {a: 3}]);
66903 * const b = tf.data.array([{b: 4}, {b: 5}, {b: 6}]);
66904 * const c = tf.data.zip({c: a, d: b});
66905 * await c.forEachAsync(e => console.log(JSON.stringify(e)));
66906 * ```
66907 *
66908 * @doc {heading: 'Data', subheading: 'Operations', namespace: 'data'}
66909 */
66910 function zip(datasets) {
66911 // manually type-check the argument for JS users
66912 if (!isIterable$1(datasets)) {
66913 throw new Error('The argument to zip() must be an object or array.');
66914 }
66915 let size;
66916 if (Array.isArray(datasets)) {
66917 for (let i = 0; i < datasets.length; i++) {
66918 size = size == null ? datasets[i].size :
66919 Math.min(size, datasets[i].size);
66920 }
66921 }
66922 else if (datasets instanceof Object) {
66923 for (const ds in datasets) {
66924 size = size == null ? datasets[ds].size :
66925 Math.min(size, datasets[ds].size);
66926 }
66927 }
66928 return datasetFromIteratorFn(async () => {
66929 const streams = await deepMapAndAwaitAll(datasets, d => {
66930 if (d instanceof Dataset) {
66931 return { value: d.iterator(), recurse: false };
66932 }
66933 else if (isIterable$1(d)) {
66934 return { value: null, recurse: true };
66935 }
66936 else {
66937 throw new Error('Leaves of the structure passed to zip() must be Datasets, ' +
66938 'not primitives.');
66939 }
66940 });
66941 return iteratorFromZipped(streams, ZipMismatchMode.SHORTEST);
66942 }, size);
66943 }
66944 /**
66945 * A zip function for use with deepZip, passed via the columnMajorBatch call.
66946 *
66947 * Accepts an array of identically-structured nested elements and either batches
66948 * them (if they are primitives, numeric arrays, or Tensors) or requests
66949 * recursion (if not).
66950 */
66951 // tslint:disable-next-line:no-any
66952 function deepBatchConcat(rows) {
66953 if (rows === null) {
66954 return null;
66955 }
66956 // use the first item to decide whether to recurse or batch here.
66957 const exampleRow = rows[0];
66958 if (canTensorify(exampleRow)) {
66959 // rows is an array of primitives, Tensors, or arrays. Batch them.
66960 const value = batchConcat(rows);
66961 return { value, recurse: false };
66962 }
66963 // the example row is an object, so recurse into it.
66964 return { value: null, recurse: true };
66965 }
66966 /**
66967 * Assembles a list of same-shaped numbers, number arrays, or Tensors
66968 * into a single new Tensor where axis 0 is the batch dimension.
66969 */
66970 function batchConcat(arrays) {
66971 if (arrays.length === 0) {
66972 // We can't return an empty Tensor because we don't know the element shape.
66973 throw new Error('Can\'t make a batch of zero elements.');
66974 }
66975 if (arrays[0] instanceof Tensor) {
66976 // Input is an array of Tensors
66977 return stack(arrays);
66978 }
66979 else {
66980 // Input is a possibly-nested array of numbers.
66981 return tensor(arrays);
66982 }
66983 }
66984
66985 /**
66986 * @license
66987 * Copyright 2018 Google LLC. All Rights Reserved.
66988 * Licensed under the Apache License, Version 2.0 (the "License");
66989 * you may not use this file except in compliance with the License.
66990 * You may obtain a copy of the License at
66991 *
66992 * http://www.apache.org/licenses/LICENSE-2.0
66993 *
66994 * Unless required by applicable law or agreed to in writing, software
66995 * distributed under the License is distributed on an "AS IS" BASIS,
66996 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
66997 * See the License for the specific language governing permissions and
66998 * limitations under the License.
66999 *
67000 * =============================================================================
67001 */
67002 /**
67003 * Represents a potentially large collection of text lines.
67004 *
67005 * The results are not batched.
67006 */
67007 class TextLineDataset extends Dataset {
67008 /**
67009 * Create a `TextLineDataset`.
67010 *
67011 * @param input A `DataSource` providing a chunked, UTF8-encoded byte stream.
67012 */
67013 constructor(input) {
67014 super();
67015 this.input = input;
67016 }
67017 async iterator() {
67018 const inputIterator = await this.input.iterator();
67019 const utf8Iterator = inputIterator.decodeUTF8();
67020 const lineIterator = utf8Iterator.split('\n').map(line => {
67021 // Windows/DOS format text file has extra line breaker at the end of line.
67022 if (line.endsWith('\r')) {
67023 line = line.slice(0, -1);
67024 }
67025 return line;
67026 });
67027 return lineIterator;
67028 }
67029 }
67030
67031 /**
67032 * @license
67033 * Copyright 2018 Google LLC. All Rights Reserved.
67034 * Licensed under the Apache License, Version 2.0 (the "License");
67035 * you may not use this file except in compliance with the License.
67036 * You may obtain a copy of the License at
67037 *
67038 * http://www.apache.org/licenses/LICENSE-2.0
67039 *
67040 * Unless required by applicable law or agreed to in writing, software
67041 * distributed under the License is distributed on an "AS IS" BASIS,
67042 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
67043 * See the License for the specific language governing permissions and
67044 * limitations under the License.
67045 *
67046 * =============================================================================
67047 */
    // Double-quote character that delimits quoted CSV fields.
    const CODE_QUOTE = '"';
    // Parser states for CSVDataset.parseRow's per-character state machine:
    // between fields / inside an unquoted field / inside a quoted field /
    // just after a closing quote / inside text following an escaped ("") quote.
    const STATE_OUT = Symbol('out');
    const STATE_FIELD = Symbol('field');
    const STATE_QUOTE = Symbol('quote');
    const STATE_QUOTE_AFTER_QUOTE = Symbol('quoteafterquote');
    const STATE_WITHIN_QUOTE_IN_QUOTE = Symbol('quoteinquote');
67054 /**
67055 * Represents a potentially large collection of delimited text records.
67056 *
67057 * The produced `TensorContainer`s each contain one key-value pair for
67058 * every column of the table. When a field is empty in the incoming data, the
67059 * resulting value is `undefined`, or throw error if it is required. Values
67060 * that can be parsed as numbers are emitted as type `number`, other values
67061 * are parsed as `string`.
67062 *
67063 * The results are not batched.
67064 *
67065 * @doc {heading: 'Data', subheading: 'Classes', namespace: 'data'}
67066 */
    class CSVDataset extends Dataset {
        /**
         * Create a `CSVDataset`.
         *
         * @param input A `DataSource` providing a chunked, UTF8-encoded byte stream.
         * @param csvConfig (Optional) A CSVConfig object that contains configurations
         *   of reading and decoding from CSV file(s).
         *
         *   hasHeader: (Optional) A boolean value that indicates whether the first
         *   row of provided CSV file is a header line with column names, and should
         *   not be included in the data. Defaults to `true`.
         *
         *   columnNames: (Optional) A list of strings that corresponds to
         *   the CSV column names, in order. If provided, it ignores the column
         *   names inferred from the header row. If not provided, infers the column
         *   names from the first row of the records. If hasHeader is false and
         *   columnNames is not provided, this method throws an error.
         *
         *   columnConfigs: (Optional) A dictionary whose key is column names, value
         *   is an object stating if this column is required, column's data type,
         *   default value, and if this column is label. If provided, keys must
         *   correspond to names provided in columnNames or inferred from the file
         *   header lines. If isLabel is true any column, returns an array of two
         *   items: the first item is a dict of features key/value pairs, the second
         *   item is a dict of labels key/value pairs. If no feature is marked as
         *   label, returns a dict of features only.
         *
         *   configuredColumnsOnly (Optional) If true, only columns provided in
         *   columnConfigs will be parsed and provided during iteration.
         *
         *   delimiter (Optional) The string used to parse each line of the input
         *   file. Defaults to `,`.
         */
        constructor(input, csvConfig) {
            super();
            this.input = input;
            // Defaults; overwritten below from csvConfig where provided.
            this.hasHeader = true;
            this.fullColumnNames = null;
            // Set to true once setColumnNames() has validated/inferred the names.
            this.columnNamesValidated = false;
            this.columnConfigs = null;
            this.configuredColumnsOnly = false;
            this.delimiter = ',';
            this.delimWhitespace = false;
            // Underlying line-oriented view of the same input source.
            this.base = new TextLineDataset(input);
            if (!csvConfig) {
                csvConfig = {};
            }
            this.hasHeader = csvConfig.hasHeader === false ? false : true;
            this.fullColumnNames = csvConfig.columnNames;
            this.columnConfigs = csvConfig.columnConfigs;
            this.configuredColumnsOnly = csvConfig.configuredColumnsOnly;
            if (csvConfig.delimWhitespace) {
                // delimWhitespace implies a space delimiter with collapsing of
                // consecutive whitespace; an explicit delimiter would conflict.
                assert(csvConfig.delimiter == null, () => 'Delimiter should not be provided when delimWhitespace is true.');
                this.delimWhitespace = true;
                this.delimiter = ' ';
            }
            else {
                this.delimiter = csvConfig.delimiter ? csvConfig.delimiter : ',';
            }
        }
        /**
         * Returns column names of the csv dataset. If `configuredColumnsOnly` is
         * true, return column names in `columnConfigs`. If `configuredColumnsOnly` is
         * false and `columnNames` is provided, return `columnNames`. If
         * `configuredColumnsOnly` is false and `columnNames` is not provided, return
         * all column names parsed from the csv file. For example usage please go to
         * `tf.data.csv`.
         *
         * @doc {heading: 'Data', subheading: 'Classes'}
         */
        async columnNames() {
            if (!this.columnNamesValidated) {
                await this.setColumnNames();
            }
            return this.configuredColumnsOnly ? Object.keys(this.columnConfigs) :
                this.fullColumnNames;
        }
        /* 1) If `columnNames` is provided as string[], use this string[] as output
         * keys in corresponding order. The length must match the number of inferred
         * columns if `hasHeader` is true.
         * 2) If `columnNames` is not provided, parse header line as `columnNames` if
         * hasHeader is true. If `hasHeader` is false, throw an error.
         * 3) If `columnConfigs` is provided, all the keys in `columnConfigs` must
         * exist in parsed `columnNames`.
         */
        async setColumnNames() {
            const columnNamesFromFile = await this.maybeReadHeaderLine();
            if (!this.fullColumnNames && !columnNamesFromFile) {
                // Throw an error if columnNames is not provided and no header line.
                throw new Error('Column names must be provided if there is no header line.');
            }
            else if (this.fullColumnNames && columnNamesFromFile) {
                // Check provided columnNames match header line.
                assert(columnNamesFromFile.length === this.fullColumnNames.length, () => 'The length of provided columnNames (' +
                    this.fullColumnNames.length.toString() +
                    ') does not match the length of the header line read from ' +
                    'file (' + columnNamesFromFile.length.toString() + ').');
            }
            if (!this.fullColumnNames) {
                this.fullColumnNames = columnNamesFromFile;
            }
            // Check if there are duplicate column names.
            const counts = this.fullColumnNames.reduce((countAcc, name) => {
                countAcc[name] = (countAcc[name] + 1) || 1;
                return countAcc;
            }, {});
            const duplicateNames = Object.keys(counts).filter((name) => (counts[name] > 1));
            assert(duplicateNames.length === 0, () => 'Duplicate column names found: ' + duplicateNames.toString());
            // Check if keys in columnConfigs match columnNames.
            if (this.columnConfigs) {
                for (const key of Object.keys(this.columnConfigs)) {
                    const index = this.fullColumnNames.indexOf(key);
                    if (index === -1) {
                        throw new Error('The key "' + key +
                            '" provided in columnConfigs does not match any of the column ' +
                            'names (' + this.fullColumnNames.toString() + ').');
                    }
                }
            }
            this.columnNamesValidated = true;
        }
        /**
         * When `hasHeader` is true, reads the first line of the file and parses it
         * as column names (without validating the element count, since the column
         * count is not yet known). Returns null when there is no header.
         * Throws if the file is empty.
         */
        async maybeReadHeaderLine() {
            if (this.hasHeader) {
                const iter = await this.base.iterator();
                const firstElement = await iter.next();
                if (firstElement.done) {
                    throw new Error('No data was found for CSV parsing.');
                }
                const firstLine = firstElement.value;
                const headers = this.parseRow(firstLine, false);
                return headers;
            }
            else {
                return null;
            }
        }
        /**
         * Provides a stream of parsed rows. Validates/infers column names on
         * first use, then maps each text line through makeDataElement().
         */
        async iterator() {
            if (!this.columnNamesValidated) {
                await this.setColumnNames();
            }
            let lines = await this.base.iterator();
            if (this.hasHeader) {
                // We previously read the first line to get the columnNames.
                // Now that we're providing data, skip it.
                lines = lines.skip(1);
            }
            return lines.map(x => this.makeDataElement(x));
        }
        /**
         * Parses one CSV line into a data element: a dict of features, or
         * {xs: features, ys: labels} when any column is configured with
         * isLabel. Applies per-column defaults, required checks, and dtype
         * conversions from `columnConfigs`.
         */
        makeDataElement(line) {
            const values = this.parseRow(line);
            const features = {};
            const labels = {};
            for (let i = 0; i < this.fullColumnNames.length; i++) {
                const key = this.fullColumnNames[i];
                const config = this.columnConfigs ? this.columnConfigs[key] : null;
                if (this.configuredColumnsOnly && !config) {
                    // This column is not selected.
                    continue;
                }
                else {
                    const value = values[i];
                    let parsedValue = null;
                    if (value === '') {
                        // If default value is provided, use it. If default value is not
                        // provided, set as undefined.
                        if (config && config.default !== undefined) {
                            parsedValue = config.default;
                        }
                        else if (config && (config.required || config.isLabel)) {
                            throw new Error(`Required column ${key} is empty in this line: ${line}`);
                        }
                        else {
                            parsedValue = undefined;
                        }
                    }
                    else {
                        // A value is present, so parse it based on type
                        const valueAsNum = Number(value);
                        if (isNaN(valueAsNum)) {
                            // The value is a string and this column is declared as boolean
                            // in config, parse it as boolean.
                            if (config && config.dtype === 'bool') {
                                parsedValue = this.getBoolean(value);
                            }
                            else {
                                // Set value as string
                                parsedValue = value;
                            }
                        }
                        else if (!config || !config.dtype) {
                            // If this value is a number and no type config is provided, return
                            // it as number.
                            parsedValue = valueAsNum;
                        }
                        else {
                            // If this value is a number and data type is provided, parse it
                            // according to provided data type.
                            switch (config.dtype) {
                                case 'float32':
                                    parsedValue = valueAsNum;
                                    break;
                                case 'int32':
                                    parsedValue = Math.floor(valueAsNum);
                                    break;
                                case 'bool':
                                    parsedValue = this.getBoolean(value);
                                    break;
                                default:
                                    parsedValue = valueAsNum;
                            }
                        }
                    }
                    // Check if this column is label.
                    (config && config.isLabel) ? labels[key] = parsedValue :
                        features[key] = parsedValue;
                }
            }
            // If label exists, return an object of features and labels as {xs:features,
            // ys:labels}, otherwise return features only.
            if (Object.keys(labels).length === 0) {
                return features;
            }
            else {
                return { xs: features, ys: labels };
            }
        }
        /**
         * Maps '1' or (case-insensitive) 'true' to 1; any other string to 0.
         * Used for columns configured with dtype 'bool'.
         */
        getBoolean(value) {
            if (value === '1' || value.toLowerCase() === 'true') {
                return 1;
            }
            else {
                return 0;
            }
        }
        /**
         * Splits one CSV line into field strings via a per-character state
         * machine that honors quoted fields and escaped ("") quotes inside them.
         * When `validateElementCount` is true, the field count must equal the
         * number of column names.
         */
        // adapted from https://beta.observablehq.com/@mbostock/streaming-csv
        parseRow(line, validateElementCount = true) {
            const result = [];
            let readOffset = 0;
            const readLength = line.length;
            let currentState = STATE_OUT;
            // Goes through the line to parse quote.
            for (let i = 0; i < readLength; i++) {
                switch (currentState) {
                    // Before enter a new field
                    case STATE_OUT:
                        switch (line.charAt(i)) {
                            // Enter a quoted field
                            case CODE_QUOTE:
                                readOffset = i + 1;
                                currentState = STATE_QUOTE;
                                break;
                            // Read an empty field
                            case this.delimiter:
                                readOffset = i + 1;
                                // If delimiter is white space and configured to collapse
                                // multiple white spaces, ignore this white space.
                                if (this.delimiter === ' ' && this.delimWhitespace) {
                                    break;
                                }
                                result.push('');
                                currentState = STATE_OUT;
                                break;
                            // Enter an unquoted field
                            default:
                                currentState = STATE_FIELD;
                                readOffset = i;
                                break;
                        }
                        break;
                    // In an unquoted field
                    case STATE_FIELD:
                        switch (line.charAt(i)) {
                            // Exit an unquoted field, add it to result
                            case this.delimiter:
                                result.push(line.substring(readOffset, i));
                                currentState = STATE_OUT;
                                readOffset = i + 1;
                                break;
                            default:
                        }
                        break;
                    // In a quoted field
                    case STATE_QUOTE:
                        switch (line.charAt(i)) {
                            // Read a quote after a quote
                            case CODE_QUOTE:
                                currentState = STATE_QUOTE_AFTER_QUOTE;
                                break;
                            default:
                        }
                        break;
                    // This state means it's right after a second quote in a field
                    case STATE_QUOTE_AFTER_QUOTE:
                        switch (line.charAt(i)) {
                            // Finished a quoted field
                            case this.delimiter:
                                result.push(line.substring(readOffset, i - 1));
                                currentState = STATE_OUT;
                                readOffset = i + 1;
                                break;
                            // Finished a quoted part in a quoted field
                            case CODE_QUOTE:
                                currentState = STATE_QUOTE;
                                break;
                            // In a quoted part in a quoted field
                            default:
                                currentState = STATE_WITHIN_QUOTE_IN_QUOTE;
                                break;
                        }
                        break;
                    case STATE_WITHIN_QUOTE_IN_QUOTE:
                        switch (line.charAt(i)) {
                            // Exit a quoted part in a quoted field
                            case CODE_QUOTE:
                                currentState = STATE_QUOTE;
                                break;
                            default:
                        }
                        break;
                    default:
                }
            }
            // Adds last item based on if it is quoted.
            if (currentState === STATE_QUOTE_AFTER_QUOTE) {
                result.push(line.substring(readOffset, readLength - 1));
            }
            else {
                result.push(line.substring(readOffset));
            }
            // Check if each row has the same number of elements as column names.
            if (validateElementCount && result.length !== this.fullColumnNames.length) {
                throw new Error(`Invalid row in csv file. Should have ${this.fullColumnNames.length} elements in a row, but got ${result}`);
            }
            return result;
        }
    }
67403 // TODO(soergel): add more basic datasets for parity with tf.data
67404 // tf.data.FixedLengthRecordDataset()
67405 // tf.data.TFRecordDataset()
67406
67407 /**
67408 * @license
67409 * Copyright 2019 Google LLC. All Rights Reserved.
67410 * Licensed under the Apache License, Version 2.0 (the "License");
67411 * you may not use this file except in compliance with the License.
67412 * You may obtain a copy of the License at
67413 *
67414 * http://www.apache.org/licenses/LICENSE-2.0
67415 *
67416 * Unless required by applicable law or agreed to in writing, software
67417 * distributed under the License is distributed on an "AS IS" BASIS,
67418 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
67419 * See the License for the specific language governing permissions and
67420 * limitations under the License.
67421 *
67422 * =============================================================================
67423 */
67424 /**
67425 * Provide a stream of tensors from microphone audio stream. The tensors are
67426 * representing audio data as frequency-domain spectrogram generated with
67427 * browser's native FFT. Tensors representing time-domain waveform is available
67428 * based on configuration. Only works in browser environment.
67429 */
67430 class MicrophoneIterator extends LazyIterator {
67431 constructor(microphoneConfig) {
67432 super();
67433 this.microphoneConfig = microphoneConfig;
67434 this.isClosed = false;
67435 this.fftSize = microphoneConfig.fftSize || 1024;
67436 const fftSizeLog2 = Math.log2(this.fftSize);
67437 if (this.fftSize < 0 || fftSizeLog2 < 4 || fftSizeLog2 > 14 ||
67438 !Number.isInteger(fftSizeLog2)) {
67439 throw new Error(`Invalid fftSize: it must be a power of 2 between ` +
67440 `2 to 4 and 2 to 14, but got ${this.fftSize}`);
67441 }
67442 this.numFrames = microphoneConfig.numFramesPerSpectrogram || 43;
67443 this.sampleRateHz = microphoneConfig.sampleRateHz;
67444 this.columnTruncateLength =
67445 microphoneConfig.columnTruncateLength || this.fftSize;
67446 this.audioTrackConstraints = microphoneConfig.audioTrackConstraints;
67447 this.smoothingTimeConstant = microphoneConfig.smoothingTimeConstant || 0;
67448 this.includeSpectrogram =
67449 microphoneConfig.includeSpectrogram === false ? false : true;
67450 this.includeWaveform =
67451 microphoneConfig.includeWaveform === true ? true : false;
67452 if (!this.includeSpectrogram && !this.includeWaveform) {
67453 throw new Error('Both includeSpectrogram and includeWaveform are false. ' +
67454 'At least one type of data should be returned.');
67455 }
67456 }
67457 summary() {
67458 return `microphone`;
67459 }
67460 // Construct a MicrophoneIterator and start the audio stream.
67461 static async create(microphoneConfig = {}) {
67462 if (!env().get('IS_BROWSER')) {
67463 throw new Error('microphone API is only supported in browser environment.');
67464 }
67465 const microphoneIterator = new MicrophoneIterator(microphoneConfig);
67466 // Call async function start() to initialize the audio stream.
67467 await microphoneIterator.start();
67468 return microphoneIterator;
67469 }
67470 // Start the audio stream and FFT.
67471 async start() {
67472 try {
67473 this.stream = await navigator.mediaDevices.getUserMedia({
67474 audio: this.audioTrackConstraints == null ? true :
67475 this.audioTrackConstraints,
67476 video: false
67477 });
67478 }
67479 catch (e) {
67480 throw new Error(`Error thrown while initializing video stream: ${e.message}`);
67481 }
67482 if (!this.stream) {
67483 throw new Error('Could not obtain audio from microphone.');
67484 }
67485 const ctxConstructor =
67486 // tslint:disable-next-line:no-any
67487 window.AudioContext || window.webkitAudioContext;
67488 this.audioContext = new ctxConstructor();
67489 if (!this.sampleRateHz) {
67490 // If sample rate is not provided, use the available sample rate on
67491 // device.
67492 this.sampleRateHz = this.audioContext.sampleRate;
67493 }
67494 else if (this.audioContext.sampleRate !== this.sampleRateHz) {
67495 throw new Error(`Mismatch in sampling rate: ` +
67496 `Expected: ${this.sampleRateHz}; ` +
67497 `Actual: ${this.audioContext.sampleRate}`);
67498 }
67499 const streamSource = this.audioContext.createMediaStreamSource(this.stream);
67500 this.analyser = this.audioContext.createAnalyser();
67501 this.analyser.fftSize = this.fftSize * 2;
67502 this.analyser.smoothingTimeConstant = this.smoothingTimeConstant;
67503 streamSource.connect(this.analyser);
67504 this.freqData = new Float32Array(this.fftSize);
67505 this.timeData = new Float32Array(this.fftSize);
67506 return;
67507 }
67508 async next() {
67509 if (this.isClosed) {
67510 return { value: null, done: true };
67511 }
67512 let spectrogramTensor;
67513 let waveformTensor;
67514 const audioDataQueue = await this.getAudioData();
67515 if (this.includeSpectrogram) {
67516 const freqData = this.flattenQueue(audioDataQueue.freqDataQueue);
67517 spectrogramTensor = this.getTensorFromAudioDataArray(freqData, [this.numFrames, this.columnTruncateLength, 1]);
67518 }
67519 if (this.includeWaveform) {
67520 const timeData = this.flattenQueue(audioDataQueue.timeDataQueue);
67521 waveformTensor = this.getTensorFromAudioDataArray(timeData, [this.numFrames * this.fftSize, 1]);
67522 }
67523 return {
67524 value: { 'spectrogram': spectrogramTensor, 'waveform': waveformTensor },
67525 done: false
67526 };
67527 }
67528 // Capture one result from the audio stream, and extract the value from
67529 // iterator.next() result.
67530 async capture() {
67531 return (await this.next()).value;
67532 }
67533 async getAudioData() {
67534 const freqDataQueue = [];
67535 const timeDataQueue = [];
67536 let currentFrames = 0;
67537 return new Promise(resolve => {
67538 const intervalID = setInterval(() => {
67539 if (this.includeSpectrogram) {
67540 this.analyser.getFloatFrequencyData(this.freqData);
67541 // If the audio stream is initializing, return empty queue.
67542 if (this.freqData[0] === -Infinity) {
67543 resolve({ freqDataQueue, timeDataQueue });
67544 }
67545 freqDataQueue.push(this.freqData.slice(0, this.columnTruncateLength));
67546 }
67547 if (this.includeWaveform) {
67548 this.analyser.getFloatTimeDomainData(this.timeData);
67549 timeDataQueue.push(this.timeData.slice());
67550 }
67551 // Clean interval and return when all frames have been collected
67552 if (++currentFrames === this.numFrames) {
67553 clearInterval(intervalID);
67554 resolve({ freqDataQueue, timeDataQueue });
67555 }
67556 }, this.fftSize / this.sampleRateHz * 1e3);
67557 });
67558 }
67559 // Stop the audio stream and pause the iterator.
67560 stop() {
67561 if (!this.isClosed) {
67562 this.isClosed = true;
67563 this.analyser.disconnect();
67564 this.audioContext.close();
67565 if (this.stream != null && this.stream.getTracks().length > 0) {
67566 this.stream.getTracks()[0].stop();
67567 }
67568 }
67569 }
67570 // Override toArray() function to prevent collecting.
67571 toArray() {
67572 throw new Error('Can not convert infinite audio stream to array.');
67573 }
67574 // Return audio sampling rate in Hz
67575 getSampleRate() {
67576 return this.sampleRateHz;
67577 }
67578 flattenQueue(queue) {
67579 const frameSize = queue[0].length;
67580 const freqData = new Float32Array(queue.length * frameSize);
67581 queue.forEach((data, i) => freqData.set(data, i * frameSize));
67582 return freqData;
67583 }
67584 getTensorFromAudioDataArray(freqData, shape) {
67585 const vals = new Float32Array(sizeFromShape(shape));
67586 // If the data is less than the output shape, the rest is padded with zeros.
67587 vals.set(freqData, vals.length - freqData.length);
67588 return tensor(vals, shape);
67589 }
67590 }
67591
67592 /**
67593 * @license
67594 * Copyright 2018 Google LLC. All Rights Reserved.
67595 * Licensed under the Apache License, Version 2.0 (the "License");
67596 * you may not use this file except in compliance with the License.
67597 * You may obtain a copy of the License at
67598 *
67599 * http://www.apache.org/licenses/LICENSE-2.0
67600 *
67601 * Unless required by applicable law or agreed to in writing, software
67602 * distributed under the License is distributed on an "AS IS" BASIS,
67603 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
67604 * See the License for the specific language governing permissions and
67605 * limitations under the License.
67606 *
67607 * =============================================================================
67608 */
67609 /**
67610 * Provide a stream of image tensors from webcam video stream. Only works in
67611 * browser environment.
67612 */
    class WebcamIterator extends LazyIterator {
        /**
         * @param webcamVideoElement The HTMLVideoElement the stream plays in.
         * @param webcamConfig Options: resizeWidth, resizeHeight, centerCrop,
         *     facingMode, deviceId.
         */
        constructor(webcamVideoElement, webcamConfig) {
            super();
            this.webcamVideoElement = webcamVideoElement;
            this.webcamConfig = webcamConfig;
            // Remains closed until start() successfully obtains the video stream.
            this.isClosed = true;
            this.resize = false;
            if (this.needToResize()) {
                this.resize = true;
                this.cropSize =
                    [this.webcamConfig.resizeHeight, this.webcamConfig.resizeWidth];
                // Single box index (batch of 1) for image.cropAndResize.
                this.cropBoxInd = tensor1d([0], 'int32');
                if (this.webcamConfig.centerCrop) {
                    // Calculate the box based on resizing shape.
                    const widthCroppingRatio = this.webcamConfig.resizeWidth * 1.0 / this.webcamVideoElement.width;
                    const heightCroppingRatio = this.webcamConfig.resizeHeight * 1.0 /
                        this.webcamVideoElement.height;
                    const widthCropStart = (1 - widthCroppingRatio) / 2;
                    const heightCropStart = (1 - heightCroppingRatio) / 2;
                    const widthCropEnd = widthCropStart + widthCroppingRatio;
                    const heightCropEnd = heightCroppingRatio + heightCropStart;
                    // Normalized [y1, x1, y2, x2] box centered in the frame.
                    this.cropBox = tensor2d([heightCropStart, widthCropStart, heightCropEnd, widthCropEnd], [1, 4]);
                }
                else {
                    // No center-cropping: crop box covers the whole frame.
                    this.cropBox = tensor2d([0, 0, 1, 1], [1, 4]);
                }
            }
        }
        summary() {
            return `webcam`;
        }
        // Construct a WebcamIterator and start its video stream.
        static async create(webcamVideoElement, webcamConfig = {}) {
            if (!env().get('IS_BROWSER')) {
                throw new Error('tf.data.webcam is only supported in browser environment.');
            }
            if (!webcamVideoElement) {
                // If webcam video element is not provided, create a hidden video element
                // with provided width and height.
                webcamVideoElement = document.createElement('video');
                if (!webcamConfig.resizeWidth || !webcamConfig.resizeHeight) {
                    throw new Error('Please provide webcam video element, or resizeWidth and ' +
                        'resizeHeight to create a hidden video element.');
                }
                webcamVideoElement.width = webcamConfig.resizeWidth;
                webcamVideoElement.height = webcamConfig.resizeHeight;
            }
            const webcamIterator = new WebcamIterator(webcamVideoElement, webcamConfig);
            // Call async function to initialize the video stream.
            await webcamIterator.start();
            return webcamIterator;
        }
        /**
         * Start the video stream and attach it to the video element.
         * Resolves once the webcam metadata has loaded (i.e. frames can be read).
         */
        async start() {
            if (this.webcamConfig.facingMode) {
                assert((this.webcamConfig.facingMode === 'user') ||
                    (this.webcamConfig.facingMode === 'environment'), () => `Invalid webcam facing mode: ${this.webcamConfig.facingMode}. ` +
                    `Please provide 'user' or 'environment'`);
            }
            try {
                this.stream = await navigator.mediaDevices.getUserMedia({
                    video: {
                        deviceId: this.webcamConfig.deviceId,
                        facingMode: this.webcamConfig.facingMode ?
                            this.webcamConfig.facingMode :
                            'user',
                        width: this.webcamVideoElement.width,
                        height: this.webcamVideoElement.height
                    }
                });
            }
            catch (e) {
                // Modify the error message but leave the stack trace intact
                e.message = `Error thrown while initializing video stream: ${e.message}`;
                throw e;
            }
            if (!this.stream) {
                throw new Error('Could not obtain video from webcam.');
            }
            // Older browsers may not have srcObject
            try {
                this.webcamVideoElement.srcObject = this.stream;
            }
            catch (error) {
                console.log(error);
                this.webcamVideoElement.src = window.URL.createObjectURL(this.stream);
            }
            // Start the webcam video stream
            this.webcamVideoElement.play();
            this.isClosed = false;
            return new Promise(resolve => {
                // Add event listener to make sure the webcam has been fully initialized.
                this.webcamVideoElement.onloadedmetadata = () => {
                    resolve();
                };
            });
        }
        /**
         * Produce the next frame as an image tensor (resized/cropped if
         * configured). Returns {value: null, done: true} once stopped.
         */
        async next() {
            if (this.isClosed) {
                return { value: null, done: true };
            }
            let img;
            try {
                img = fromPixels(this.webcamVideoElement);
            }
            catch (e) {
                throw new Error(`Error thrown converting video to pixels: ${JSON.stringify(e)}`);
            }
            if (this.resize) {
                try {
                    return { value: this.cropAndResizeFrame(img), done: false };
                }
                catch (e) {
                    throw new Error(`Error thrown cropping the video: ${e.message}`);
                }
                finally {
                    // The intermediate full-size tensor is always released, whether
                    // cropping succeeded or threw.
                    img.dispose();
                }
            }
            else {
                return { value: img, done: false };
            }
        }
        needToResize() {
            // If resizeWidth and resizeHeight are provided, and different from the
            // width and height of original HTMLVideoElement, then resizing and cropping
            // is required.
            if (this.webcamConfig.resizeWidth && this.webcamConfig.resizeHeight &&
                (this.webcamVideoElement.width !== this.webcamConfig.resizeWidth ||
                    this.webcamVideoElement.height !== this.webcamConfig.resizeHeight)) {
                return true;
            }
            return false;
        }
        // Cropping and resizing each frame based on config
        cropAndResizeFrame(img) {
            return tidy(() => {
                // cropAndResize expects a batched float input: add a batch dim.
                const expandedImage = expandDims(cast(img, 'float32'), (0));
                let resizedImage;
                resizedImage = image.cropAndResize(expandedImage, this.cropBox, this.cropBoxInd, this.cropSize, 'bilinear');
                // Extract image from batch cropping.
                const shape = resizedImage.shape;
                return reshape(resizedImage, shape.slice(1));
            });
        }
        // Capture one frame from the video stream, and extract the value from
        // iterator.next() result.
        async capture() {
            return (await this.next()).value;
        }
        // Stop the video stream and pause webcam iterator.
        stop() {
            const tracks = this.stream.getTracks();
            tracks.forEach(track => track.stop());
            try {
                this.webcamVideoElement.srcObject = null;
            }
            catch (error) {
                console.log(error);
                this.webcamVideoElement.src = null;
            }
            this.isClosed = true;
        }
        // Override toArray() function to prevent collecting.
        toArray() {
            throw new Error('Can not convert infinite video stream to array.');
        }
    }
67781
67782 /**
67783 * @license
67784 * Copyright 2018 Google LLC. All Rights Reserved.
67785 * Licensed under the Apache License, Version 2.0 (the "License");
67786 * you may not use this file except in compliance with the License.
67787 * You may obtain a copy of the License at
67788 *
67789 * http://www.apache.org/licenses/LICENSE-2.0
67790 *
67791 * Unless required by applicable law or agreed to in writing, software
67792 * distributed under the License is distributed on an "AS IS" BASIS,
67793 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
67794 * See the License for the specific language governing permissions and
67795 * limitations under the License.
67796 *
67797 * =============================================================================
67798 */
67799 /**
67800 * Represents a data source readable as a stream of binary data chunks.
67801 *
67802 * Because `Dataset`s can be read repeatedly (via `Dataset.iterator()`), this
67803 * provides a means to repeatedly create streams from the underlying data
67804 * sources.
67805 */
    class DataSource {
        // Abstract base: subclasses (e.g. FileDataSource, URLDataSource below)
        // implement `async iterator()` returning a stream of binary chunks.
    }
67808 // TODO(soergel): consider convenience factory functions here
67809 // in combination with chainable source->dataset above, e.g.:
67810 // tf.data.url(...).asCsvDataset().shuffle().batch()
67811
67812 /**
67813 * @license
67814 * Copyright 2018 Google LLC. All Rights Reserved.
67815 * Licensed under the Apache License, Version 2.0 (the "License");
67816 * you may not use this file except in compliance with the License.
67817 * You may obtain a copy of the License at
67818 *
67819 * http://www.apache.org/licenses/LICENSE-2.0
67820 *
67821 * Unless required by applicable law or agreed to in writing, software
67822 * distributed under the License is distributed on an "AS IS" BASIS,
67823 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
67824 * See the License for the specific language governing permissions and
67825 * limitations under the License.
67826 *
67827 * =============================================================================
67828 */
    class StringIterator extends LazyIterator {
        /**
         * Splits a string stream on a given separator.
         *
         * It is assumed that the incoming chunk boundaries have no semantic
         * meaning, so conceptually the incoming stream is treated simply as the
         * concatenation of its elements.
         *
         * The outgoing stream provides chunks corresponding to the results of the
         * standard string split() operation (even if such a chunk spanned incoming
         * chunks). The separators are not included.
         *
         * A typical usage is to split a text file (represented as a stream with
         * arbitrary chunk boundaries) into lines.
         *
         * @param upstream A readable stream of strings that can be treated as
         *   concatenated.
         * @param separator A character to split on.
         * @returns a `SplitIterator` yielding one string per separator-delimited
         *   field.
         */
        split(separator) {
            return new SplitIterator(this, separator);
        }
    }
67852 // ============================================================================
67853 // The following private classes serve to implement the chainable methods
67854 // on StringIterator. Unfortunately they can't be placed in separate files, due
67855 // to resulting trouble with circular imports.
67856 // ============================================================================
67857 // We wanted multiple inheritance, e.g.
67858 // class SplitIterator extends QueueIterator<string>, StringIterator
67859 // but the TypeScript mixin approach is a bit hacky, so we take this adapter
67860 // approach instead.
    class SplitIterator extends StringIterator {
        /**
         * Adapter exposing SplitIteratorImpl (a OneToManyIterator) through the
         * chainable StringIterator interface; all work is delegated to `impl`.
         */
        constructor(upstream, separator) {
            super();
            this.upstream = upstream;
            this.impl = new SplitIteratorImpl(upstream, separator);
        }
        summary() {
            return this.impl.summary();
        }
        async next() {
            return this.impl.next();
        }
    }
67874 class SplitIteratorImpl extends OneToManyIterator {
67875 constructor(upstream, separator) {
67876 super();
67877 this.upstream = upstream;
67878 this.separator = separator;
67879 // A partial string at the end of an upstream chunk
67880 this.carryover = '';
67881 }
67882 summary() {
67883 return `${this.upstream.summary()} -> Split('${this.separator}')`;
67884 }
67885 async pump() {
67886 const chunkResult = await this.upstream.next();
67887 if (chunkResult.done) {
67888 if (this.carryover === '') {
67889 return false;
67890 }
67891 // Pretend that the pump succeeded in order to emit the small last batch.
67892 // The next pump() call will actually fail.
67893 this.outputQueue.push(this.carryover);
67894 this.carryover = '';
67895 return true;
67896 }
67897 const lines = chunkResult.value.split(this.separator);
67898 // Note the behavior: " ab ".split(' ') === ['', 'ab', '']
67899 // Thus the carryover may be '' if the separator falls on a chunk
67900 // boundary; this produces the correct result.
67901 lines[0] = this.carryover + lines[0];
67902 for (const line of lines.slice(0, -1)) {
67903 this.outputQueue.push(line);
67904 }
67905 this.carryover = lines[lines.length - 1];
67906 return true;
67907 }
67908 }
67909
67910 /**
67911 * @license
67912 * Copyright 2018 Google LLC. All Rights Reserved.
67913 * Licensed under the Apache License, Version 2.0 (the "License");
67914 * you may not use this file except in compliance with the License.
67915 * You may obtain a copy of the License at
67916 *
67917 * http://www.apache.org/licenses/LICENSE-2.0
67918 *
67919 * Unless required by applicable law or agreed to in writing, software
67920 * distributed under the License is distributed on an "AS IS" BASIS,
67921 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
67922 * See the License for the specific language governing permissions and
67923 * limitations under the License.
67924 *
67925 * =============================================================================
67926 */
    class ByteChunkIterator extends LazyIterator {
        /**
         * Decode a stream of UTF8-encoded byte arrays to a stream of strings.
         *
         * The byte arrays produced from the ByteChunkIterator on which this is
         * called will be interpreted as concatenated. No assumptions are made about
         * the boundaries of the incoming chunks, so a multi-byte UTF8 encoding of a
         * character may span the boundary between chunks. This naturally happens,
         * for instance, when reading fixed-size byte arrays from a file.
         *
         * @returns a `Utf8Iterator` yielding decoded string chunks.
         */
        decodeUTF8() {
            return new Utf8Iterator(this);
        }
    }
67941 // ============================================================================
67942 // The following private classes serve to implement the chainable methods
67943 // on ByteChunkIterator. Unfortunately they can't be placed in separate files,
67944 // due to resulting trouble with circular imports.
67945 // ============================================================================
67946 // We wanted multiple inheritance, e.g.
67947 // class Utf8Iterator extends QueueIterator<string>, StringIterator
67948 // but the TypeScript mixin approach is a bit hacky, so we take this adapter
67949 // approach instead.
    class Utf8Iterator extends StringIterator {
        /**
         * Adapter exposing Utf8IteratorImpl (a OneToManyIterator) through the
         * chainable StringIterator interface; all work is delegated to `impl`.
         */
        constructor(upstream) {
            super();
            this.upstream = upstream;
            this.impl = new Utf8IteratorImpl(upstream);
        }
        summary() {
            return this.impl.summary();
        }
        async next() {
            return this.impl.next();
        }
    }
67963 /**
67964 * Decode a stream of UTF8-encoded byte arrays to a stream of strings.
67965 *
67966 * This is tricky because the incoming byte array boundaries may disrupt a
67967 * multi-byte UTF8 character. Thus any incomplete character data at the end of
67968 * a chunk must be carried over and prepended to the next chunk before
67969 * decoding. Luckily with native decoder, TextDecoder in browser and
67970 * string_decoder in node, byte array boundaries are handled automatically.
67971 *
67972 * In the context of an input pipeline for machine learning, UTF8 decoding is
67973 * needed to parse text files containing training examples or prediction
67974 * requests (e.g., formatted as CSV or JSON). We cannot use the built-in
67975 * decoding provided by FileReader.readAsText() because here we are in a
67976 * streaming context, which FileReader does not support.
67977 *
67978 * @param upstream A `LazyIterator` of `Uint8Arrays` containing UTF8-encoded
67979 * text, which should be interpreted as concatenated. No assumptions are
67980 * made about the boundaries of the incoming chunks, so a multi-byte UTF8
67981 * encoding of a character may span the boundary between chunks. This
67982 * naturally happens, for instance, when reading fixed-size byte arrays from a
67983 * file.
67984 */
67985 class Utf8IteratorImpl extends OneToManyIterator {
67986 constructor(upstream) {
67987 super();
67988 this.upstream = upstream;
67989 if (env().get('IS_BROWSER')) {
67990 this.decoder = new TextDecoder('utf-8');
67991 }
67992 else {
67993 // tslint:disable-next-line:no-require-imports
67994 const { StringDecoder } = require('string_decoder');
67995 this.decoder = new StringDecoder('utf8');
67996 }
67997 }
67998 summary() {
67999 return `${this.upstream.summary()} -> Utf8`;
68000 }
68001 async pump() {
68002 const chunkResult = await this.upstream.next();
68003 let chunk;
68004 if (chunkResult.done) {
68005 return false;
68006 }
68007 else {
68008 chunk = chunkResult.value;
68009 }
68010 let text;
68011 if (env().get('IS_BROWSER')) {
68012 text = this.decoder.decode(chunk, { stream: true });
68013 }
68014 else {
68015 text = this.decoder.write(Buffer.from(chunk.buffer));
68016 }
68017 this.outputQueue.push(text);
68018 return true;
68019 }
68020 }
68021
68022 /**
68023 * @license
68024 * Copyright 2018 Google LLC. All Rights Reserved.
68025 * Licensed under the Apache License, Version 2.0 (the "License");
68026 * you may not use this file except in compliance with the License.
68027 * You may obtain a copy of the License at
68028 *
68029 * http://www.apache.org/licenses/LICENSE-2.0
68030 *
68031 * Unless required by applicable law or agreed to in writing, software
68032 * distributed under the License is distributed on an "AS IS" BASIS,
68033 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
68034 * See the License for the specific language governing permissions and
68035 * limitations under the License.
68036 *
68037 * =============================================================================
68038 */
68039 /**
68040 * Provide a stream of chunks from a File, Blob, or Uint8Array.
68041 * @param file The source File, Blob or Uint8Array.
68042 * @param options Optional settings controlling file reading.
68043 * @returns a lazy Iterator of Uint8Arrays containing sequential chunks of the
68044 * input File, Blob or Uint8Array.
68045 */
    class FileChunkIterator extends ByteChunkIterator {
        /**
         * @param file The source File, Blob or Uint8Array. File/Blob are only
         *     accepted in the browser environment.
         * @param options Optional {offset, chunkSize} controlling the read.
         */
        constructor(file, options = {}) {
            super();
            this.file = file;
            this.options = options;
            assert((file instanceof Uint8Array) ||
                (env().get('IS_BROWSER') ?
                    (file instanceof File || file instanceof Blob) :
                    false), () => 'FileChunkIterator only supports File, Blob and Uint8Array ' +
                'right now.');
            // Read cursor, in bytes from the start of the source.
            this.offset = options.offset || 0;
            // default 1MB chunk has tolerable perf on large files
            this.chunkSize = options.chunkSize || 1024 * 1024;
        }
        summary() {
            return `FileChunks ${this.file}`;
        }
        /**
         * Produce the next sequential chunk as a Uint8Array, or
         * {value: null, done: true} once the cursor passes the end of the source.
         */
        async next() {
            if (this.offset >= ((this.file instanceof Uint8Array) ?
                this.file.byteLength :
                this.file.size)) {
                return { value: null, done: true };
            }
            const chunk = new Promise((resolve, reject) => {
                const end = this.offset + this.chunkSize;
                if (this.file instanceof Uint8Array) {
                    // Note if end > this.uint8Array.byteLength, we just get a small last
                    // chunk.
                    resolve(new Uint8Array(this.file.slice(this.offset, end)));
                }
                else {
                    // This branch assumes that this.file type is File or Blob, which
                    // means it is in the browser environment.
                    // TODO(soergel): is this a performance issue?
                    const fileReader = new FileReader();
                    fileReader.onload = (event) => {
                        let data = fileReader.result;
                        // Not sure we can trust the return type of
                        // FileReader.readAsArrayBuffer See e.g.
                        // https://github.com/node-file-api/FileReader/issues/2
                        if (data instanceof ArrayBuffer) {
                            data = new Uint8Array(data);
                        }
                        if (!(data instanceof Uint8Array)) {
                            return reject(new TypeError('FileReader returned unknown type.'));
                        }
                        resolve(data);
                    };
                    fileReader.onabort = (event) => {
                        return reject(new Error('Aborted'));
                    };
                    fileReader.onerror = (event) => {
                        return reject(new Error(event.type));
                    };
                    // TODO(soergel): better handle onabort, onerror
                    // Note if end > this.file.size, we just get a small last chunk.
                    const slice = this.file.slice(this.offset, end);
                    // We can't use readAsText here (even if we know the file is text)
                    // because the slice boundary may fall within a multi-byte character.
                    fileReader.readAsArrayBuffer(slice);
                }
                // The Promise executor runs synchronously, so the cursor advances
                // before any subsequent next() call can observe it.
                this.offset = end;
            });
            return { value: (await chunk), done: false };
        }
    }
68112
68113 /**
68114 * @license
68115 * Copyright 2018 Google LLC. All Rights Reserved.
68116 * Licensed under the Apache License, Version 2.0 (the "License");
68117 * you may not use this file except in compliance with the License.
68118 * You may obtain a copy of the License at
68119 *
68120 * http://www.apache.org/licenses/LICENSE-2.0
68121 *
68122 * Unless required by applicable law or agreed to in writing, software
68123 * distributed under the License is distributed on an "AS IS" BASIS,
68124 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
68125 * See the License for the specific language governing permissions and
68126 * limitations under the License.
68127 *
68128 * =============================================================================
68129 */
68130 /**
68131 * Provide a stream of chunks from a URL.
68132 *
68133 * Note this class first downloads the entire file into memory before providing
68134 * the first element from the stream. This is because the Fetch API does not
68135 * yet reliably provide a reader stream for the response body.
68136 */
68137 async function urlChunkIterator(url, options = {}, fetchFunc) {
68138 let urlString;
68139 let requestInit;
68140 if ((typeof url) === 'string') {
68141 urlString = url;
68142 }
68143 else {
68144 urlString = url.url;
68145 requestInit = getRequestInitFromRequest(url);
68146 }
68147 const response = await (fetchFunc || fetch$2)(urlString, requestInit);
68148 if (response.ok) {
68149 const uint8Array = new Uint8Array(await response.arrayBuffer());
68150 return new FileChunkIterator(uint8Array, options);
68151 }
68152 else {
68153 throw new Error(response.statusText);
68154 }
68155 }
68156 // Generate RequestInit from Request to match tf.util.fetch signature.
68157 const getRequestInitFromRequest = (request) => {
68158 const init = {
68159 method: request.method,
68160 headers: request.headers,
68161 body: request.body,
68162 mode: request.mode,
68163 credentials: request.credentials,
68164 cache: request.cache,
68165 redirect: request.redirect,
68166 referrer: request.referrer,
68167 integrity: request.integrity,
68168 };
68169 return init;
68170 };
68171
68172 /**
68173 * @license
68174 * Copyright 2018 Google LLC. All Rights Reserved.
68175 * Licensed under the Apache License, Version 2.0 (the "License");
68176 * you may not use this file except in compliance with the License.
68177 * You may obtain a copy of the License at
68178 *
68179 * http://www.apache.org/licenses/LICENSE-2.0
68180 *
68181 * Unless required by applicable law or agreed to in writing, software
68182 * distributed under the License is distributed on an "AS IS" BASIS,
68183 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
68184 * See the License for the specific language governing permissions and
68185 * limitations under the License.
68186 *
68187 * =============================================================================
68188 */
68189 // Skip tslint any type check cause this method is aiming to check type of
68190 // input.
68191 // tslint:disable-next-line:no-any
68192 function isLocalPath(source) {
68193 return (typeof source === 'string') && source.slice(0, 7) === 'file://';
68194 }
68195
68196 /**
68197 * @license
68198 * Copyright 2018 Google LLC. All Rights Reserved.
68199 * Licensed under the Apache License, Version 2.0 (the "License");
68200 * you may not use this file except in compliance with the License.
68201 * You may obtain a copy of the License at
68202 *
68203 * http://www.apache.org/licenses/LICENSE-2.0
68204 *
68205 * Unless required by applicable law or agreed to in writing, software
68206 * distributed under the License is distributed on an "AS IS" BASIS,
68207 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
68208 * See the License for the specific language governing permissions and
68209 * limitations under the License.
68210 *
68211 * =============================================================================
68212 */
68213 /**
68214 * Represents a file, blob, or Uint8Array readable as a stream of binary data
68215 * chunks.
68216 */
68217 class FileDataSource extends DataSource {
68218 /**
68219 * Create a `FileDataSource`.
68220 *
68221 * @param input Local file path, or `File`/`Blob`/`Uint8Array` object to
68222 * read. Local file only works in node environment.
68223 * @param options Options passed to the underlying `FileChunkIterator`s,
68224 * such as {chunksize: 1024}.
68225 */
68226 constructor(input, options = {}) {
68227 super();
68228 this.input = input;
68229 this.options = options;
68230 }
68231 async iterator() {
68232 if (isLocalPath(this.input) && env().get('IS_NODE')) {
68233 // tslint:disable-next-line:no-require-imports
68234 const fs = require('fs');
68235 this.input = fs.readFileSync(this.input.slice(7));
68236 }
68237 // TODO(kangyizhang): Add LocalFileChunkIterator to split local streaming
68238 // with file in browser.
68239 return new FileChunkIterator(this.input, this.options);
68240 }
68241 }
68242
68243 /**
68244 * @license
68245 * Copyright 2018 Google LLC. All Rights Reserved.
68246 * Licensed under the Apache License, Version 2.0 (the "License");
68247 * you may not use this file except in compliance with the License.
68248 * You may obtain a copy of the License at
68249 *
68250 * http://www.apache.org/licenses/LICENSE-2.0
68251 *
68252 * Unless required by applicable law or agreed to in writing, software
68253 * distributed under the License is distributed on an "AS IS" BASIS,
68254 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
68255 * See the License for the specific language governing permissions and
68256 * limitations under the License.
68257 *
68258 * =============================================================================
68259 */
68260 /*
68261 * Represents a URL readable as a stream of binary data chunks.
68262 */
68263 class URLDataSource extends DataSource {
68264 /**
68265 * Create a `URLDataSource`.
68266 *
68267 * @param url A source URL string, or a `Request` object.
68268 * @param options Options passed to the underlying `FileChunkIterator`s,
68269 * such as {chunksize: 1024}.
68270 */
68271 constructor(url, fileOptions = {}) {
68272 super();
68273 this.url = url;
68274 this.fileOptions = fileOptions;
68275 }
68276 // TODO(soergel): provide appropriate caching options. Currently this
68277 // will download the URL anew for each call to iterator(). Since we have
68278 // to treat the downloaded file as a blob/buffer anyway, we may as well retain
68279 // it-- but that raises GC issues. Also we may want a persistent disk cache.
68280 async iterator() {
68281 if (isLocalPath(this.url)) {
68282 return (new FileDataSource(this.url, this.fileOptions))
68283 .iterator();
68284 }
68285 else {
68286 return urlChunkIterator(this.url, this.fileOptions);
68287 }
68288 }
68289 }
68290
68291 /**
68292 * @license
68293 * Copyright 2018 Google LLC. All Rights Reserved.
68294 * Licensed under the Apache License, Version 2.0 (the "License");
68295 * you may not use this file except in compliance with the License.
68296 * You may obtain a copy of the License at
68297 *
68298 * http://www.apache.org/licenses/LICENSE-2.0
68299 *
68300 * Unless required by applicable law or agreed to in writing, software
68301 * distributed under the License is distributed on an "AS IS" BASIS,
68302 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
68303 * See the License for the specific language governing permissions and
68304 * limitations under the License.
68305 *
68306 * =============================================================================
68307 */
68308 /**
68309 * Create a `CSVDataset` by reading and decoding CSV file(s) from provided URL
68310 * or local path if it's in Node environment.
68311 *
68312 * Note: If isLabel in columnConfigs is `true` for at least one column, the
68313 * element in returned `CSVDataset` will be an object of
68314 * `{xs:features, ys:labels}`: xs is a dict of features key/value pairs, ys
68315 * is a dict of labels key/value pairs. If no column is marked as label,
68316 * returns a dict of features only.
68317 *
68318 * ```js
68319 * const csvUrl =
68320 * 'https://storage.googleapis.com/tfjs-examples/multivariate-linear-regression/data/boston-housing-train.csv';
68321 *
68322 * async function run() {
68323 * // We want to predict the column "medv", which represents a median value of
68324 * // a home (in $1000s), so we mark it as a label.
68325 * const csvDataset = tf.data.csv(
68326 * csvUrl, {
68327 * columnConfigs: {
68328 * medv: {
68329 * isLabel: true
68330 * }
68331 * }
68332 * });
68333 *
68334 * // Number of features is the number of column names minus one for the label
68335 * // column.
68336 * const numOfFeatures = (await csvDataset.columnNames()).length - 1;
68337 *
68338 * // Prepare the Dataset for training.
68339 * const flattenedDataset =
68340 * csvDataset
68341 * .map(({xs, ys}) =>
68342 * {
68343 * // Convert xs(features) and ys(labels) from object form (keyed by
68344 * // column name) to array form.
68345 * return {xs:Object.values(xs), ys:Object.values(ys)};
68346 * })
68347 * .batch(10);
68348 *
68349 * // Define the model.
68350 * const model = tf.sequential();
68351 * model.add(tf.layers.dense({
68352 * inputShape: [numOfFeatures],
68353 * units: 1
68354 * }));
68355 * model.compile({
68356 * optimizer: tf.train.sgd(0.000001),
68357 * loss: 'meanSquaredError'
68358 * });
68359 *
68360 * // Fit the model using the prepared Dataset
68361 * return model.fitDataset(flattenedDataset, {
68362 * epochs: 10,
68363 * callbacks: {
68364 * onEpochEnd: async (epoch, logs) => {
68365 * console.log(epoch + ':' + logs.loss);
68366 * }
68367 * }
68368 * });
68369 * }
68370 *
68371 * await run();
68372 * ```
68373 *
68374 * @param source URL or local path to get CSV file. If it's a local path, it
68375 * must have prefix `file://` and it only works in node environment.
68376 * @param csvConfig (Optional) A CSVConfig object that contains configurations
68377 * of reading and decoding from CSV file(s).
68378 *
68379 * @doc {
68380 * heading: 'Data',
68381 * subheading: 'Creation',
68382 * namespace: 'data',
68383 * configParamIndices: [1]
68384 * }
68385 */
68386 function csv(source, csvConfig = {}) {
68387 return new CSVDataset(new URLDataSource(source), csvConfig);
68388 }
68389 /**
68390 * Create a `Dataset` that produces each element by calling a provided function.
68391 *
68392 * Note that repeated iterations over this `Dataset` may produce different
68393 * results, because the function will be called anew for each element of each
68394 * iteration.
68395 *
68396 * Also, beware that the sequence of calls to this function may be out of order
68397 * in time with respect to the logical order of the Dataset. This is due to the
68398 * asynchronous lazy nature of stream processing, and depends on downstream
68399 * transformations (e.g. .shuffle()). If the provided function is pure, this is
68400 * no problem, but if it is a closure over a mutable state (e.g., a traversal
68401 * pointer), then the order of the produced elements may be scrambled.
68402 *
68403 * ```js
68404 * let i = -1;
68405 * const func = () =>
68406 * ++i < 5 ? {value: i, done: false} : {value: null, done: true};
68407 * const ds = tf.data.func(func);
68408 * await ds.forEachAsync(e => console.log(e));
68409 * ```
68410 *
68411 * @param f A function that produces one data element on each call.
68412 */
    function func(f) {
        // A single iterator is created once, up front, and the same instance is
        // returned by every call to the dataset's iterator factory — so all
        // iterations over the resulting dataset share (and advance) the same
        // underlying function-backed stream.
        const iter = iteratorFromFunction(f);
        return datasetFromIteratorFn(async () => iter);
    }
68417 /**
68418 * Create a `Dataset` that produces each element from provided JavaScript
68419 * generator, which is a function*
68420 * (https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Iterators_and_Generators#Generator_functions),
68421 * or a function that returns an
68422 * iterator
68423 * (https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Iterators_and_Generators#Generator_functions).
68424 *
68425 * The returned iterator should have `.next()` function that returns element in
68426 * format of `{value: TensorContainer, done:boolean}`.
68427 *
68428 * Example of creating a dataset from an iterator factory:
68429 * ```js
68430 * function makeIterator() {
68431 * const numElements = 10;
68432 * let index = 0;
68433 *
68434 * const iterator = {
68435 * next: () => {
68436 * let result;
68437 * if (index < numElements) {
68438 * result = {value: index, done: false};
68439 * index++;
68440 * return result;
68441 * }
68442 * return {value: index, done: true};
68443 * }
68444 * };
68445 * return iterator;
68446 * }
68447 * const ds = tf.data.generator(makeIterator);
68448 * await ds.forEachAsync(e => console.log(e));
68449 * ```
68450 *
68451 * Example of creating a dataset from a generator:
68452 * ```js
68453 * function* dataGenerator() {
68454 * const numElements = 10;
68455 * let index = 0;
68456 * while (index < numElements) {
68457 * const x = index;
68458 * index++;
68459 * yield x;
68460 * }
68461 * }
68462 *
68463 * const ds = tf.data.generator(dataGenerator);
68464 * await ds.forEachAsync(e => console.log(e));
68465 * ```
68466 *
68467 * @param generator A Javascript generator function that returns a JavaScript
68468 * iterator.
68469 *
68470 * @doc {
68471 * heading: 'Data',
68472 * subheading: 'Creation',
68473 * namespace: 'data',
68474 * configParamIndices: [1]
68475 * }
68476 */
68477 function generator(generator) {
68478 return datasetFromIteratorFn(async () => {
68479 const gen = await generator();
68480 return iteratorFromFunction(() => gen.next());
68481 });
68482 }
68483 /**
68484 * Create an iterator that generate `Tensor`s from webcam video stream. This API
68485 * only works in Browser environment when the device has webcam.
68486 *
68487 * Note: this code snippet only works when the device has a webcam. It will
68488 * request permission to open the webcam when running.
68489 * ```js
68490 * const videoElement = document.createElement('video');
68491 * videoElement.width = 100;
68492 * videoElement.height = 100;
68493 * const cam = await tf.data.webcam(videoElement);
68494 * const img = await cam.capture();
68495 * img.print();
68496 * cam.stop();
68497 * ```
68498 *
68499 * @param webcamVideoElement A `HTMLVideoElement` used to play video from
68500 * webcam. If this element is not provided, a hidden `HTMLVideoElement` will
68501 * be created. In that case, `resizeWidth` and `resizeHeight` must be
68502 * provided to set the generated tensor shape.
68503 * @param webcamConfig A `WebcamConfig` object that contains configurations of
68504 * reading and manipulating data from webcam video stream.
68505 *
68506 * @doc {
68507 * heading: 'Data',
68508 * subheading: 'Creation',
68509 * namespace: 'data',
68510 * ignoreCI: true
68511 * }
68512 */
    async function webcam(webcamVideoElement, webcamConfig) {
        // Thin async wrapper: WebcamIterator.create handles permission prompts
        // and stream setup (browser-only, per the doc above).
        return WebcamIterator.create(webcamVideoElement, webcamConfig);
    }
68516 /**
68517 * Create an iterator that generate frequency-domain spectrogram `Tensor`s from
68518 * microphone audio stream with browser's native FFT. This API only works in
68519 * browser environment when the device has microphone.
68520 *
68521 * Note: this code snippet only works when the device has a microphone. It will
68522 * request permission to open the microphone when running.
68523 * ```js
68524 * const mic = await tf.data.microphone({
68525 * fftSize: 1024,
68526 * columnTruncateLength: 232,
68527 * numFramesPerSpectrogram: 43,
68528 * sampleRateHz:44100,
68529 * includeSpectrogram: true,
68530 * includeWaveform: true
68531 * });
68532 * const audioData = await mic.capture();
68533 * const spectrogramTensor = audioData.spectrogram;
68534 * spectrogramTensor.print();
68535 * const waveformTensor = audioData.waveform;
68536 * waveformTensor.print();
68537 * mic.stop();
68538 * ```
68539 *
68540 * @param microphoneConfig A `MicrophoneConfig` object that contains
68541 * configurations of reading audio data from microphone.
68542 *
68543 * @doc {
68544 * heading: 'Data',
68545 * subheading: 'Creation',
68546 * namespace: 'data',
68547 * ignoreCI: true
68548 * }
68549 */
    async function microphone(microphoneConfig) {
        // Thin async wrapper around the MicrophoneIterator factory.
        return MicrophoneIterator.create(microphoneConfig);
    }
68553
    /** @license See the LICENSE file. */
    // This code is auto-generated, do not modify this file!
    // Version of the bundled tfjs-data package; exported as `version_data`.
    const version$3 = '3.18.0';
68557
68558 /**
68559 * @license
68560 * Copyright 2018 Google LLC. All Rights Reserved.
68561 * Licensed under the Apache License, Version 2.0 (the "License");
68562 * you may not use this file except in compliance with the License.
68563 * You may obtain a copy of the License at
68564 *
68565 * http://www.apache.org/licenses/LICENSE-2.0
68566 *
68567 * Unless required by applicable law or agreed to in writing, software
68568 * distributed under the License is distributed on an "AS IS" BASIS,
68569 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
68570 * See the License for the specific language governing permissions and
68571 * limitations under the License.
68572 * =============================================================================
68573 */
68574
    // Public `tf.data` namespace: dataset classes, data sources, and creation
    // helpers assembled into a frozen, module-like object.
    var index = /*#__PURE__*/Object.freeze({
        __proto__: null,
        array: array,
        Dataset: Dataset,
        zip: zip,
        CSVDataset: CSVDataset,
        TextLineDataset: TextLineDataset,
        csv: csv,
        func: func,
        generator: generator,
        microphone: microphone,
        webcam: webcam,
        FileDataSource: FileDataSource,
        URLDataSource: URLDataSource,
        version_data: version$3
    });
68591
68592 /**
68593 * @license
68594 * Copyright 2019 Google LLC. All Rights Reserved.
68595 * Licensed under the Apache License, Version 2.0 (the "License");
68596 * you may not use this file except in compliance with the License.
68597 * You may obtain a copy of the License at
68598 *
68599 * http://www.apache.org/licenses/LICENSE-2.0
68600 *
68601 * Unless required by applicable law or agreed to in writing, software
68602 * distributed under the License is distributed on an "AS IS" BASIS,
68603 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
68604 * See the License for the specific language governing permissions and
68605 * limitations under the License.
68606 * =============================================================================
68607 */
68608 function assertNotComplex(tensor, opName) {
68609 if (!Array.isArray(tensor)) {
68610 tensor = [tensor];
68611 }
68612 tensor.forEach(t => {
68613 if (t != null) {
68614 assert(t.dtype !== 'complex64', () => `${opName} does not support complex64 tensors in the CPU backend.`);
68615 }
68616 });
68617 }
68618
68619 /**
68620 * @license
68621 * Copyright 2021 Google LLC. All Rights Reserved.
68622 * Licensed under the Apache License, Version 2.0 (the "License");
68623 * you may not use this file except in compliance with the License.
68624 * You may obtain a copy of the License at
68625 *
68626 * http://www.apache.org/licenses/LICENSE-2.0
68627 *
68628 * Unless required by applicable law or agreed to in writing, software
68629 * distributed under the License is distributed on an "AS IS" BASIS,
68630 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
68631 * See the License for the specific language governing permissions and
68632 * limitations under the License.
68633 * =============================================================================
68634 */
    const whereImpl$1 = whereImpl;
    /**
     * Vanilla-JavaScript (CPU) backend. Tensor data lives in plain
     * (typed) arrays held in a DataStorage keyed by dataId, with manual
     * reference counting for disposal.
     */
    class MathBackendCPU extends KernelBackend {
        constructor() {
            super();
            // NOTE(review): blockSize appears to be consumed by tiled kernel
            // implementations defined elsewhere in the file; it is not read
            // within this class — confirm before changing.
            this.blockSize = 48;
            this.firstUse = true;
            this.data = new DataStorage(this, engine());
        }
        /** Returns the next unique numeric id for a new data bucket. */
        nextDataId() {
            return MathBackendCPU.nextDataId++;
        }
        /**
         * Stores `values` in a fresh data bucket (refCount 1) and returns its
         * dataId. On the very first write in Node, logs a one-time suggestion
         * to install the native tfjs-node backend.
         */
        write(values, shape, dtype) {
            if (this.firstUse) {
                this.firstUse = false;
                if (env().get('IS_NODE')) {
                    warn('\n============================\n' +
                        'Hi, looks like you are running TensorFlow.js in ' +
                        'Node.js. To speed things up dramatically, install our node ' +
                        'backend, visit https://github.com/tensorflow/tfjs-node for more details. ' +
                        '\n============================');
                }
            }
            const dataId = { id: this.nextDataId() };
            this.data.set(dataId, { values, dtype, refCount: 1 });
            return dataId;
        }
        /**
         * Create a data bucket in cpu backend.
         * @param shape Shape of the `TensorInfo`.
         * @param dtype DType of the `TensorInfo`.
         * @param values The value of the `TensorInfo` stored as a flattened array.
         */
        makeTensorInfo(shape, dtype, values) {
            let outId;
            if (dtype === 'string' && values != null && values.length > 0 &&
                isString(values[0])) {
                // String tensors are stored as encoded bytes, not raw strings.
                const encodedValues = values.map(d => encodeString(d));
                outId = this.write(encodedValues, shape, dtype);
            }
            else {
                outId = this.write(values, shape, dtype);
            }
            return { dataId: outId, shape, dtype };
        }
        /** Return refCount of a `TensorData`, or 0 if the id is not tracked. */
        refCount(dataId) {
            if (this.data.has(dataId)) {
                const tensorData = this.data.get(dataId);
                return tensorData.refCount;
            }
            return 0;
        }
        /** Increase refCount of a `TensorData`. */
        incRef(dataId) {
            const tensorData = this.data.get(dataId);
            tensorData.refCount++;
        }
        /** Decrease refCount of a `TensorData` (no-op for untracked ids). */
        decRef(dataId) {
            if (this.data.has(dataId)) {
                const tensorData = this.data.get(dataId);
                tensorData.refCount--;
            }
        }
        /** Store (or overwrite) a data bucket under an existing dataId. */
        move(dataId, values, shape, dtype, refCount) {
            this.data.set(dataId, { values, dtype, refCount });
        }
        numDataIds() {
            return this.data.numDataIds();
        }
        /** Async read; the CPU backend always resolves synchronously. */
        async read(dataId) {
            return this.readSync(dataId);
        }
        /**
         * Returns the raw values for `dataId`. For complex64, the separately
         * stored real and imaginary parts are merged into one array.
         */
        readSync(dataId) {
            const { dtype, complexTensorInfos } = this.data.get(dataId);
            if (dtype === 'complex64') {
                const realValues = this.readSync(complexTensorInfos.real.dataId);
                const imagValues = this.readSync(complexTensorInfos.imag.dataId);
                return mergeRealAndImagArrays(realValues, imagValues);
            }
            return this.data.get(dataId).values;
        }
        /**
         * Returns a buffer over `t`'s data; string tensors are decoded from
         * their stored encoded bytes first.
         */
        bufferSync(t) {
            const data = this.readSync(t.dataId);
            if (t.dtype === 'string') {
                try {
                    // Decode the bytes into string.
                    const strings = data.map(d => decodeString(d));
                    return buffer(t.shape, t.dtype, strings);
                }
                catch (_a) {
                    throw new Error('Failed to decode encoded string bytes into utf-8');
                }
            }
            return buffer(t.shape, t.dtype, data);
        }
        /** Wraps `values` into a Tensor registered with the global engine. */
        makeOutput(values, shape, dtype) {
            return engine().makeTensorFromTensorInfo(this.makeTensorInfo(shape, dtype, values), this);
        }
        /**
         * Dispose the memory if the dataId has 0 refCount. Return true if the memory
         * is released or memory is not managed in this backend, false if memory is
         * not cleared.
         * @param dataId
         * @param force Optional, remove the data regardless of refCount
         */
        disposeData(dataId, force = false) {
            if (this.data.has(dataId)) {
                this.data.get(dataId).refCount--;
                if (!force && this.data.get(dataId).refCount > 0) {
                    return false;
                }
                // A complex tensor owns its real/imag parts; force-dispose them
                // together with the parent bucket.
                const { complexTensorInfos } = this.data.get(dataId);
                if (complexTensorInfos != null) {
                    this.disposeData(complexTensorInfos.real.dataId, true);
                    this.disposeData(complexTensorInfos.imag.dataId, true);
                }
                this.data.delete(dataId);
            }
            return true;
        }
        disposeIntermediateTensorInfo(tensorInfo) {
            this.disposeData(tensorInfo.dataId);
        }
        /**
         * Times `f` with a wall-clock timer.
         * NOTE(review): `f` is invoked synchronously and not awaited, so any
         * async work inside `f` is not included in `kernelMs`.
         */
        async time(f) {
            const start = now();
            f();
            const kernelMs = now() - start;
            return { kernelMs };
        }
        memory() {
            return {
                // Unreliable due to automatic gc.
                unreliable: true,
                reasons: ['The reported memory is an upper bound. Due to automatic garbage ' +
                        'collection, the true allocated memory may be less.']
            };
        }
        /** Synchronous where(): delegates to the shared whereImpl on the
         * synchronously-read condition values. */
        where(condition) {
            assertNotComplex([condition], 'where');
            const condVals = this.readSync(condition.dataId);
            return whereImpl$1(condition.shape, condVals);
        }
        dispose() { }
        /** Bits of float precision for this backend (full float32). */
        floatPrecision() {
            return 32;
        }
        /** Epsilon for the current float precision; delegates to the base
         * KernelBackend implementation. */
        epsilon() {
            return super.epsilon();
        }
    }
    MathBackendCPU.nextDataId = 0;
68788
68789 /**
68790 * @license
68791 * Copyright 2020 Google LLC. All Rights Reserved.
68792 * Licensed under the Apache License, Version 2.0 (the License);
68793 * you may not use this file except in compliance with the License.
68794 * You may obtain a copy of the License at
68795 *
68796 * http://www.apache.org/licenses/LICENSE-2.0
68797 *
68798 * Unless required by applicable law or agreed to in writing, software
68799 * distributed under the License is distributed on an AS IS BASIS,
68800 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
68801 * See the License for the specific language governing permissions and
68802 * limitations under the License.
68803 * =============================================================================
68804 */
68805 function simpleAbsImpl(vals) {
68806 const resultValues = new Float32Array(vals.length);
68807 for (let i = 0; i < vals.length; ++i) {
68808 resultValues[i] = Math.abs(vals[i]);
68809 }
68810 return resultValues;
68811 }
68812 const abs$1 = (args) => {
68813 const { x } = args.inputs;
68814 const cpuBackend = args.backend;
68815 assertNotComplex(x, 'abs');
68816 let resultValues = new Float32Array(sizeFromShape(x.shape));
68817 const values = cpuBackend.data.get(x.dataId).values;
68818 resultValues = simpleAbsImpl(values);
68819 return cpuBackend.makeOutput(resultValues, x.shape, x.dtype);
68820 };
    /** Kernel registration config binding the `Abs` kernel to the CPU backend. */
    const absConfig = {
        kernelName: Abs,
        backendName: 'cpu',
        kernelFunc: abs$1,
    };
68826
68827 /**
68828 * @license
68829 * Copyright 2020 Google LLC. All Rights Reserved.
68830 * Licensed under the Apache License, Version 2.0 (the "License");
68831 * you may not use this file except in compliance with the License.
68832 * You may obtain a copy of the License at
68833 *
68834 * http://www.apache.org/licenses/LICENSE-2.0
68835 *
68836 * Unless required by applicable law or agreed to in writing, software
68837 * distributed under the License is distributed on an "AS IS" BASIS,
68838 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
68839 * See the License for the specific language governing permissions and
68840 * limitations under the License.
68841 * =============================================================================
68842 */
68843 /**
68844 * Template that creates implementation for binary ops. Supports broadcast.
68845 */
68846 function createSimpleBinaryKernelImpl(op) {
68847 return (aShape, bShape, aVals, bVals, dtype) => {
68848 const newShape = assertAndGetBroadcastShape(aShape, bShape);
68849 const resultRank = newShape.length;
68850 const resultStrides = computeStrides(newShape);
68851 const resultSize = sizeFromShape(newShape);
68852 const result = getTypedArrayFromDType(dtype, resultSize);
68853 const aRank = aShape.length;
68854 const bRank = bShape.length;
68855 const aStrides = computeStrides(aShape);
68856 const bStrides = computeStrides(bShape);
68857 const aBroadcastDims = getBroadcastDims(aShape, newShape);
68858 const bBroadcastDims = getBroadcastDims(bShape, newShape);
68859 if (aBroadcastDims.length + bBroadcastDims.length === 0) {
68860 for (let i = 0; i < result.length; ++i) {
68861 result[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]);
68862 }
68863 }
68864 else {
68865 for (let i = 0; i < result.length; ++i) {
68866 const loc = indexToLoc(i, resultRank, resultStrides);
68867 const aLoc = loc.slice(-aRank);
68868 aBroadcastDims.forEach(d => aLoc[d] = 0);
68869 const aIndex = locToIndex(aLoc, aRank, aStrides);
68870 const bLoc = loc.slice(-bRank);
68871 bBroadcastDims.forEach(d => bLoc[d] = 0);
68872 const bIndex = locToIndex(bLoc, bRank, bStrides);
68873 result[i] = op(aVals[aIndex], bVals[bIndex]);
68874 }
68875 }
68876 return [result, newShape];
68877 };
68878 }
68879
68880 /**
68881 * @license
68882 * Copyright 2020 Google LLC. All Rights Reserved.
68883 * Licensed under the Apache License, Version 2.0 (the "License");
68884 * you may not use this file except in compliance with the License.
68885 * You may obtain a copy of the License at
68886 *
68887 * http://www.apache.org/licenses/LICENSE-2.0
68888 *
68889 * Unless required by applicable law or agreed to in writing, software
68890 * distributed under the License is distributed on an "AS IS" BASIS,
68891 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
68892 * See the License for the specific language governing permissions and
68893 * limitations under the License.
68894 * =============================================================================
68895 */
68896 function complex$1(args) {
68897 const { inputs, backend } = args;
68898 const { real, imag } = inputs;
68899 const realVals = backend.data.get(real.dataId).values;
68900 const imagVals = backend.data.get(imag.dataId).values;
68901 const complexInfo = backend.makeTensorInfo(real.shape, 'complex64');
68902 const complex = backend.data.get(complexInfo.dataId);
68903 // The complex tensor owns the underlying real and imag tensorInfos, only the
68904 // complex tensor tracks refCount, when complexData is disposed the
68905 // underlying tensorData will be disposed.
68906 complex.complexTensorInfos = {
68907 real: backend.makeTensorInfo(real.shape, 'float32', realVals),
68908 imag: backend.makeTensorInfo(imag.shape, 'float32', imagVals)
68909 };
68910 return complexInfo;
68911 }
    /** Kernel registration config binding the `Complex` kernel to the CPU backend. */
    const complexConfig = {
        kernelName: Complex,
        backendName: 'cpu',
        kernelFunc: complex$1
    };
68917
68918 /**
68919 * @license
68920 * Copyright 2020 Google LLC. All Rights Reserved.
68921 * Licensed under the Apache License, Version 2.0 (the "License");
68922 * you may not use this file except in compliance with the License.
68923 * You may obtain a copy of the License at
68924 *
68925 * http://www.apache.org/licenses/LICENSE-2.0
68926 *
68927 * Unless required by applicable law or agreed to in writing, software
68928 * distributed under the License is distributed on an "AS IS" BASIS,
68929 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
68930 * See the License for the specific language governing permissions and
68931 * limitations under the License.
68932 * =============================================================================
68933 */
68934 /**
68935 * Generates a tensorInfo with all zeros value.
68936 * @param backend cpu backend.
68937 * @param shape Shape for the zeros tensor.
68938 * @param dtype Optional. If set, the result has this dtype.
68939 */
68940 function zeros$2(backend, shape, dtype = 'float32') {
68941 if (dtype === 'complex64') {
68942 const real = zeros$2(backend, shape, 'float32');
68943 const imag = zeros$2(backend, shape, 'float32');
68944 return complex$1({ inputs: { real, imag }, backend });
68945 }
68946 const values = makeZerosTypedArray(sizeFromShape(shape), dtype);
68947 return backend.makeTensorInfo(shape, dtype, values);
68948 }
68949
68950 /**
68951 * @license
68952 * Copyright 2020 Google LLC. All Rights Reserved.
68953 * Licensed under the Apache License, Version 2.0 (the "License");
68954 * you may not use this file except in compliance with the License.
68955 * You may obtain a copy of the License at
68956 *
68957 * http://www.apache.org/licenses/LICENSE-2.0
68958 *
68959 * Unless required by applicable law or agreed to in writing, software
68960 * distributed under the License is distributed on an "AS IS" BASIS,
68961 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
68962 * See the License for the specific language governing permissions and
68963 * limitations under the License.
68964 * =============================================================================
68965 */
68966 function identity$1(args) {
68967 const { inputs, backend } = args;
68968 const { x } = inputs;
68969 backend.incRef(x.dataId);
68970 return { dataId: x.dataId, shape: x.shape, dtype: x.dtype };
68971 }
    /** Kernel registration config binding the `Identity` kernel to the CPU backend. */
    const identityConfig = {
        kernelName: Identity,
        backendName: 'cpu',
        kernelFunc: identity$1
    };
68977
68978 /**
68979 * @license
68980 * Copyright 2020 Google LLC. All Rights Reserved.
68981 * Licensed under the Apache License, Version 2.0 (the "License");
68982 * you may not use this file except in compliance with the License.
68983 * You may obtain a copy of the License at
68984 *
68985 * http://www.apache.org/licenses/LICENSE-2.0
68986 *
68987 * Unless required by applicable law or agreed to in writing, software
68988 * distributed under the License is distributed on an "AS IS" BASIS,
68989 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
68990 * See the License for the specific language governing permissions and
68991 * limitations under the License.
68992 * =============================================================================
68993 */
68994 function real$1(args) {
68995 const { inputs, backend } = args;
68996 const { input } = inputs;
68997 const real = backend.data.get(input.dataId).complexTensorInfos.real;
68998 const realVal = backend.data.get(real.dataId).values;
68999 // When complex tensor is disposed, its underlying parts will be disposed too.
69000 // Make new tensor out of the real value of the complex. This makes sure the
69001 // value is still accessible even if complex tensor is disposed.
69002 return backend.makeTensorInfo(real.shape, real.dtype, realVal);
69003 }
    /** Kernel registration config binding the `Real` kernel to the CPU backend. */
    const realConfig = {
        kernelName: Real,
        backendName: 'cpu',
        kernelFunc: real$1
    };
69009
69010 /**
69011 * @license
69012 * Copyright 2020 Google LLC. All Rights Reserved.
69013 * Licensed under the Apache License, Version 2.0 (the "License");
69014 * you may not use this file except in compliance with the License.
69015 * You may obtain a copy of the License at
69016 *
69017 * http://www.apache.org/licenses/LICENSE-2.0
69018 *
69019 * Unless required by applicable law or agreed to in writing, software
69020 * distributed under the License is distributed on an "AS IS" BASIS,
69021 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69022 * See the License for the specific language governing permissions and
69023 * limitations under the License.
69024 * =============================================================================
69025 */
    /**
     * CPU kernel for Cast: returns `x` converted to `dtype`.
     *
     * Dispatch order (first match wins):
     *   1. to complex64: pair a float32 copy of x with a zero imaginary part;
     *   2. from complex64: extract the real part, then cast that;
     *   3. no encoding loss: alias the input data and just relabel the dtype;
     *   4. to int32: truncate via Int32Array.from;
     *   5. to bool: elementwise compare against zero.
     */
    function cast$2(args) {
        const { inputs, backend, attrs } = args;
        const { x } = inputs;
        const { dtype } = attrs;
        // Casting to complex64.
        if (dtype === 'complex64') {
            if (x.dtype === 'complex64') {
                // Already complex: alias the input (refCount bump, no copy).
                return identity$1({ inputs: { x }, backend });
            }
            // NOTE(review): the zero imag part is created with x.dtype rather
            // than 'float32' even though complex parts are stored as float32;
            // harmless for zeros, but confirm upstream intent.
            const zerosTensorInfo = zeros$2(backend, x.shape, x.dtype);
            const floatX = cast$2({ inputs: { x }, backend, attrs: { dtype: 'float32' } });
            const result = complex$1({ inputs: { real: floatX, imag: zerosTensorInfo }, backend });
            // complex$1 copied the values, so the intermediates can be freed.
            backend.disposeIntermediateTensorInfo(zerosTensorInfo);
            backend.disposeIntermediateTensorInfo(floatX);
            return result;
        }
        // Casting from complex64: drop the imaginary part, then cast the real.
        if (x.dtype === 'complex64') {
            const realPart = real$1({ inputs: { input: x }, backend });
            const result = cast$2({ inputs: { x: realPart }, backend, attrs: { dtype } });
            backend.disposeIntermediateTensorInfo(realPart);
            return result;
        }
        if (!hasEncodingLoss(x.dtype, dtype)) {
            // We don't change the underlying data, since we cast to higher
            // precision.
            const result = identity$1({ inputs: { x }, backend });
            return { dataId: result.dataId, shape: result.shape, dtype };
        }
        if (dtype === 'int32') {
            const values = backend.data.get(x.dataId).values;
            // Int32Array.from truncates each value's fractional part.
            const resultValues = Int32Array.from(values);
            return backend.makeTensorInfo(x.shape, 'int32', resultValues);
        }
        if (dtype === 'bool') {
            // This is essentially the result of notEqual(x, 0). We avoid using
            // kernel notEqual to avoid circular dependency, i.e. binary_utils ->
            // cast -> notEqual -> binary_utils.
            const xVals = backend.data.get(x.dataId).values;
            const zero = toTypedArray([0], x.dtype);
            const [resultData, resultShape] = createSimpleBinaryKernelImpl((a, b) => (a !== b) ? 1 : 0)(x.shape, [], xVals, zero, 'bool');
            return backend.makeTensorInfo(resultShape, 'bool', resultData);
        }
        throw new Error(`Error in Cast: failed to cast ${x.dtype} to ${dtype}`);
    }
    /** Kernel registration config binding the `Cast` kernel to the CPU backend. */
    const castConfig = {
        kernelName: Cast,
        backendName: 'cpu',
        kernelFunc: cast$2
    };
69076
69077 /**
69078 * @license
69079 * Copyright 2020 Google LLC. All Rights Reserved.
69080 * Licensed under the Apache License, Version 2.0 (the "License");
69081 * you may not use this file except in compliance with the License.
69082 * You may obtain a copy of the License at
69083 *
69084 * http://www.apache.org/licenses/LICENSE-2.0
69085 *
69086 * Unless required by applicable law or agreed to in writing, software
69087 * distributed under the License is distributed on an "AS IS" BASIS,
69088 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69089 * See the License for the specific language governing permissions and
69090 * limitations under the License.
69091 * =============================================================================
69092 */
69093 /**
69094 * Template that creates a `KernelFunc` for binary ops.
69095 * @param name Kernel name.
69096 * @param binaryKernelImpl A `SimpleBinaryKernelImpl` for the kernel.
69097 * @param binaryKernelComplexImpl Optional. If exists, represents a
69098 * `ComplexBinaryKernelImpl` for the kernel, will be used when input dtype
69099 * is `complex64`.
69100 * @param dtype Optional. If set, the result has this dtype. Otherwise, the
69101 * result has the same dtype as the first input. This is mainly used in
69102 * comparison kernels, such as Equal, Less, Greater, etc.
69103 */
    function binaryKernelFunc(name, simpleImpl, complexImpl, dtype) {
        if (complexImpl == null) {
            // No complex implementation was supplied, so complex64 inputs are
            // rejected outright by assertNotComplex below.
            return ({ inputs, backend }) => {
                const { a, b } = inputs;
                const cpuBackend = backend;
                assertNotComplex([a, b], name);
                const aVals = cpuBackend.data.get(a.dataId).values;
                const bVals = cpuBackend.data.get(b.dataId).values;
                // String tensors store their values as encoded Uint8Arrays and
                // must be decoded before the op runs.
                // NOTE(review): b's decode condition checks a.dtype, not
                // b.dtype — presumably both inputs always share a dtype here;
                // confirm with callers.
                const decodedAVals = a.dtype === 'string' ?
                    // tslint:disable-next-line: no-any
                    fromUint8ToStringArray(aVals) :
                    aVals;
                const decodedBVals = a.dtype === 'string' ?
                    // tslint:disable-next-line: no-any
                    fromUint8ToStringArray(bVals) :
                    bVals;
                // Comparison kernels pass an explicit result dtype ('bool');
                // otherwise the result keeps a's dtype.
                const $dtype = dtype || a.dtype;
                const [resultData, resultShape] = simpleImpl(a.shape, b.shape, decodedAVals, decodedBVals, $dtype);
                return cpuBackend.makeTensorInfo(resultShape, $dtype, resultData);
            };
        }
        return ({ inputs, backend }) => {
            const { a, b } = inputs;
            const cpuBackend = backend;
            if (a.dtype === 'complex64' || b.dtype === 'complex64') {
                // Upcast both inputs to complex64, then run the complex impl on
                // the separated real/imaginary component arrays.
                const $aComplex = cast$2({ inputs: { x: a }, backend: cpuBackend, attrs: { dtype: 'complex64' } });
                const $aComplexVals = cpuBackend.data.get($aComplex.dataId);
                const aReal = $aComplexVals.complexTensorInfos.real;
                const aImag = $aComplexVals.complexTensorInfos.imag;
                const aRealVals = cpuBackend.data.get(aReal.dataId).values;
                const aImagVals = cpuBackend.data.get(aImag.dataId).values;
                const $bComplex = cast$2({ inputs: { x: b }, backend: cpuBackend, attrs: { dtype: 'complex64' } });
                const $bComplexVals = cpuBackend.data.get($bComplex.dataId);
                const bReal = $bComplexVals.complexTensorInfos.real;
                const bImag = $bComplexVals.complexTensorInfos.imag;
                const bRealVals = cpuBackend.data.get(bReal.dataId).values;
                const bImagVals = cpuBackend.data.get(bImag.dataId).values;
                const [resultRealData, resultImagData, resultShape] = complexImpl(a.shape, b.shape, aRealVals, aImagVals, bRealVals, bImagVals);
                // Pack the component results back into one complex tensor.
                const resultReal = cpuBackend.makeTensorInfo(resultShape, 'float32', resultRealData);
                const resultImag = cpuBackend.makeTensorInfo(resultShape, 'float32', resultImagData);
                const result = complex$1({ inputs: { real: resultReal, imag: resultImag }, backend: cpuBackend });
                // The casts and the component tensors are intermediates owned
                // here; release them before returning the packed result.
                cpuBackend.disposeIntermediateTensorInfo($aComplex);
                cpuBackend.disposeIntermediateTensorInfo($bComplex);
                cpuBackend.disposeIntermediateTensorInfo(resultReal);
                cpuBackend.disposeIntermediateTensorInfo(resultImag);
                return result;
            }
            else {
                // Neither input is complex64: same flow as the simple case
                // above, minus the string decoding.
                const aVals = cpuBackend.data.get(a.dataId).values;
                const bVals = cpuBackend.data.get(b.dataId).values;
                const $dtype = dtype || a.dtype;
                const [resultData, resultShape] = simpleImpl(a.shape, b.shape, aVals, bVals, $dtype);
                return cpuBackend.makeTensorInfo(resultShape, $dtype, resultData);
            }
        };
    }
69160 /**
69161 * Template that creates the complex type implementation for binary ops.
69162 * Supports broadcast.
69163 */
    function createComplexBinaryKernelImpl(op) {
        return (aShape, bShape, aRealVals, aImagVals, bRealVals, bImagVals) => {
            // Output shape/strides come from broadcasting the input shapes.
            const resultShape = assertAndGetBroadcastShape(aShape, bShape);
            const resultSize = sizeFromShape(resultShape);
            const resultRank = resultShape.length;
            const resultStrides = computeStrides(resultShape);
            const resultRealVals = getTypedArrayFromDType('float32', resultSize);
            const resultImagVals = getTypedArrayFromDType('float32', resultSize);
            const aBroadcastDims = getBroadcastDims(aShape, resultShape);
            const bBroadcastDims = getBroadcastDims(bShape, resultShape);
            // Interleave components as [re0, im0, re1, im1, ...] so element k's
            // pair is addressed at (2 * k, 2 * k + 1).
            const aVals = mergeRealAndImagArrays(aRealVals, aImagVals);
            const bVals = mergeRealAndImagArrays(bRealVals, bImagVals);
            const aRank = aShape.length;
            const aStrides = computeStrides(aShape);
            const bRank = bShape.length;
            const bStrides = computeStrides(bShape);
            if (aBroadcastDims.length + bBroadcastDims.length === 0) {
                // Fast path: no dimension is broadcast, so the output index maps
                // to each input element directly (the modulo is then identity).
                for (let i = 0; i < resultRealVals.length; i++) {
                    const aIdx = i % aVals.length;
                    const bIdx = i % bVals.length;
                    const result = op(aVals[aIdx * 2], aVals[aIdx * 2 + 1], bVals[bIdx * 2], bVals[bIdx * 2 + 1]);
                    resultRealVals[i] = result.real;
                    resultImagVals[i] = result.imag;
                }
            }
            else {
                // General path: convert each output index to a location, zero
                // the coordinates on broadcast dims, and map back to each
                // input's flat index.
                for (let i = 0; i < resultRealVals.length; i++) {
                    const loc = indexToLoc(i, resultRank, resultStrides);
                    const aLoc = loc.slice(-aRank);
                    aBroadcastDims.forEach(d => aLoc[d] = 0);
                    const aIndex = locToIndex(aLoc, aRank, aStrides);
                    const bLoc = loc.slice(-bRank);
                    bBroadcastDims.forEach(d => bLoc[d] = 0);
                    const bIndex = locToIndex(bLoc, bRank, bStrides);
                    const opResult = op(aVals[aIndex * 2], aVals[aIndex * 2 + 1], bVals[bIndex * 2], bVals[bIndex * 2 + 1]);
                    resultRealVals[i] = opResult.real;
                    resultImagVals[i] = opResult.imag;
                }
            }
            return [resultRealVals, resultImagVals, resultShape];
        };
    }
69206
69207 /**
69208 * @license
69209 * Copyright 2020 Google LLC. All Rights Reserved.
69210 * Licensed under the Apache License, Version 2.0 (the "License");
69211 * you may not use this file except in compliance with the License.
69212 * You may obtain a copy of the License at
69213 *
69214 * http://www.apache.org/licenses/LICENSE-2.0
69215 *
69216 * Unless required by applicable law or agreed to in writing, software
69217 * distributed under the License is distributed on an "AS IS" BASIS,
69218 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69219 * See the License for the specific language governing permissions and
69220 * limitations under the License.
69221 * =============================================================================
69222 */
69223 const addImpl = createSimpleBinaryKernelImpl(((a, b) => a + b));
69224 const addComplexImpl = createComplexBinaryKernelImpl(((aReal, aImag, bReal, bImag) => {
69225 return { real: aReal + bReal, imag: aImag + bImag };
69226 }));
69227 const add$4 = binaryKernelFunc(Add, addImpl, addComplexImpl);
69228 const addConfig = {
69229 kernelName: Add,
69230 backendName: 'cpu',
69231 kernelFunc: add$4
69232 };
69233
69234 /**
69235 * @license
69236 * Copyright 2020 Google LLC. All Rights Reserved.
69237 * Licensed under the Apache License, Version 2.0 (the "License");
69238 * you may not use this file except in compliance with the License.
69239 * You may obtain a copy of the License at
69240 *
69241 * http://www.apache.org/licenses/LICENSE-2.0
69242 *
69243 * Unless required by applicable law or agreed to in writing, software
69244 * distributed under the License is distributed on an "AS IS" BASIS,
69245 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69246 * See the License for the specific language governing permissions and
69247 * limitations under the License.
69248 * =============================================================================
69249 */
69250 function bincountImpl(xVals, weightsVals, weightsDtype, weightsShape, size) {
69251 const weightsSize = sizeFromShape(weightsShape);
69252 const outVals = makeZerosTypedArray(size, weightsDtype);
69253 for (let i = 0; i < xVals.length; i++) {
69254 const value = xVals[i];
69255 if (value < 0) {
69256 throw new Error('Input x must be non-negative!');
69257 }
69258 if (value >= size) {
69259 continue;
69260 }
69261 if (weightsSize > 0) {
69262 outVals[value] += weightsVals[i];
69263 }
69264 else {
69265 outVals[value] += 1;
69266 }
69267 }
69268 return outVals;
69269 }
69270 function bincountReduceImpl(xBuf, weightsBuf, size, binaryOutput = false) {
69271 const numRows = xBuf.shape[0];
69272 const numCols = xBuf.shape[1];
69273 const outBuf = buffer([numRows, size], weightsBuf.dtype);
69274 for (let i = 0; i < numRows; i++) {
69275 for (let j = 0; j < numCols; j++) {
69276 const value = xBuf.get(i, j);
69277 if (value < 0) {
69278 throw new Error('Input x must be non-negative!');
69279 }
69280 if (value >= size) {
69281 continue;
69282 }
69283 if (binaryOutput) {
69284 outBuf.set(1, i, value);
69285 }
69286 else {
69287 if (weightsBuf.size > 0) {
69288 outBuf.set(outBuf.get(i, value) + weightsBuf.get(i, j), i, value);
69289 }
69290 else {
69291 outBuf.set(outBuf.get(i, value) + 1, i, value);
69292 }
69293 }
69294 }
69295 }
69296 return outBuf;
69297 }
69298
69299 /**
69300 * @license
69301 * Copyright 2020 Google LLC. All Rights Reserved.
69302 * Licensed under the Apache License, Version 2.0 (the "License");
69303 * you may not use this file except in compliance with the License.
69304 * You may obtain a copy of the License at
69305 *
69306 * http://www.apache.org/licenses/LICENSE-2.0
69307 *
69308 * Unless required by applicable law or agreed to in writing, software
69309 * distributed under the License is distributed on an "AS IS" BASIS,
69310 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69311 * See the License for the specific language governing permissions and
69312 * limitations under the License.
69313 * =============================================================================
69314 */
69315 /**
69316 * Template that creates implementation for unary op.
69317 */
69318 function createSimpleUnaryImpl(op) {
69319 return (values, dtype, attrs) => {
69320 const newValues = getTypedArrayFromDType(dtype, values.length);
69321 for (let i = 0; i < values.length; ++i) {
69322 newValues[i] = op(values[i], attrs);
69323 }
69324 return newValues;
69325 };
69326 }
69327
69328 /**
69329 * @license
69330 * Copyright 2020 Google LLC. All Rights Reserved.
69331 * Licensed under the Apache License, Version 2.0 (the "License");
69332 * you may not use this file except in compliance with the License.
69333 * You may obtain a copy of the License at
69334 *
69335 * http://www.apache.org/licenses/LICENSE-2.0
69336 *
69337 * Unless required by applicable law or agreed to in writing, software
69338 * distributed under the License is distributed on an "AS IS" BASIS,
69339 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69340 * See the License for the specific language governing permissions and
69341 * limitations under the License.
69342 * =============================================================================
69343 */
69344 /**
69345 * Template that creates a `KernelFunc` for unary ops.
69346 * @param name Kernel name.
69347 * @param op A `SimpleUnaryOperation` for the kernel.
69348 * @param dtype Optional. If set, the result has this dtype. Otherwise, the
69349 * result has the same dtype as the input. This is mainly used in certain
69350 * kernels that return bool type, such as isFinite, isInf, etc.
69351 */
69352 function unaryKernelFunc(name, op, dtype) {
69353 return ({ inputs, attrs, backend }) => {
69354 const { x } = inputs;
69355 assertNotComplex(x, name);
69356 if (x.dtype === 'string' || dtype === 'string') {
69357 throw new Error('unaryKernelFunc does not support string input/output');
69358 }
69359 const cpuBackend = backend;
69360 const values = cpuBackend.data.get(x.dataId).values;
69361 const xSize = sizeFromShape(x.shape);
69362 const $dtype = dtype || x.dtype;
69363 const newValues = getArrayFromDType($dtype, xSize);
69364 for (let i = 0; i < xSize; ++i) {
69365 newValues[i] = op(values[i], attrs);
69366 }
69367 return cpuBackend.makeTensorInfo(x.shape, $dtype, newValues);
69368 };
69369 }
69370 /**
69371 * Template that creates a `KernelFunc` for unary ops from the given
69372 * `SimpleUnaryImpl`..
69373 * @param name Kernel name.
69374 * @param unaryImpl A `SimpleUnaryImpl` that implements the op.
69375 * @param dtype Optional. If set, the result has this dtype. Otherwise, the
69376 * result has the same dtype as the input. This is mainly used in certain
69377 * kernels that return bool type, such as isFinite, isInf, etc.
69378 */
69379 function unaryKernelFuncFromImpl(name, unaryImpl, dtype) {
69380 return ({ inputs, attrs, backend }) => {
69381 const { x } = inputs;
69382 assertNotComplex(x, name);
69383 if (x.dtype === 'string' || dtype === 'string') {
69384 throw new Error('unaryKernelFunc does not support string input/output');
69385 }
69386 const cpuBackend = backend;
69387 const values = cpuBackend.data.get(x.dataId).values;
69388 const $dtype = dtype || x.dtype;
69389 const newValues = unaryImpl(values, $dtype, attrs);
69390 return cpuBackend.makeTensorInfo(x.shape, $dtype, newValues);
69391 };
69392 }
69393
69394 /**
69395 * @license
69396 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
69398 * you may not use this file except in compliance with the License.
69399 * You may obtain a copy of the License at
69400 *
69401 * http://www.apache.org/licenses/LICENSE-2.0
69402 *
69403 * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
69405 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69406 * See the License for the specific language governing permissions and
69407 * limitations under the License.
69408 * =============================================================================
69409 */
69410 const ceilImpl = createSimpleUnaryImpl((xi) => Math.ceil(xi));
69411 const ceil$1 = unaryKernelFuncFromImpl(Ceil, ceilImpl);
69412 const ceilConfig = {
69413 kernelName: Ceil,
69414 backendName: 'cpu',
69415 kernelFunc: ceil$1,
69416 };
69417
69418 /**
69419 * @license
69420 * Copyright 2020 Google LLC. All Rights Reserved.
69421 * Licensed under the Apache License, Version 2.0 (the "License");
69422 * you may not use this file except in compliance with the License.
69423 * You may obtain a copy of the License at
69424 *
69425 * http://www.apache.org/licenses/LICENSE-2.0
69426 *
69427 * Unless required by applicable law or agreed to in writing, software
69428 * distributed under the License is distributed on an "AS IS" BASIS,
69429 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69430 * See the License for the specific language governing permissions and
69431 * limitations under the License.
69432 * =============================================================================
69433 */
69434 function concatImpl(inputs, outShape, dtype, simplyConcat) {
69435 const outVals = getArrayFromDType(dtype, sizeFromShape(outShape));
69436 if (simplyConcat && dtype !== 'string') {
69437 // Use built-in TypedArray.set() method for speed.
69438 let offset = 0;
69439 inputs.forEach(input => {
69440 const size = sizeFromShape(input.shape);
69441 outVals.set(input.vals, offset);
69442 offset += size;
69443 });
69444 }
69445 else {
69446 let colOffset = 0;
69447 inputs.forEach(input => {
69448 const decodedData = dtype === 'string' ?
69449 fromUint8ToStringArray(input.vals) :
69450 input.vals;
69451 let tIdx = 0;
69452 for (let row = 0; row < input.shape[0]; ++row) {
69453 const resIdx = row * outShape[1] + colOffset;
69454 for (let col = 0; col < input.shape[1]; ++col) {
69455 outVals[resIdx + col] = decodedData[tIdx++];
69456 }
69457 }
69458 colOffset += input.shape[1];
69459 });
69460 }
69461 return outVals;
69462 }
69463
69464 /**
69465 * @license
69466 * Copyright 2020 Google LLC. All Rights Reserved.
69467 * Licensed under the Apache License, Version 2.0 (the "License");
69468 * you may not use this file except in compliance with the License.
69469 * You may obtain a copy of the License at
69470 *
69471 * http://www.apache.org/licenses/LICENSE-2.0
69472 *
69473 * Unless required by applicable law or agreed to in writing, software
69474 * distributed under the License is distributed on an "AS IS" BASIS,
69475 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69476 * See the License for the specific language governing permissions and
69477 * limitations under the License.
69478 * =============================================================================
69479 */
69480 const equalImpl = createSimpleBinaryKernelImpl((a, b) => (a === b) ? 1 : 0);
69481 const equal$1 = binaryKernelFunc(Equal, equalImpl, null /* complexImpl */, 'bool');
69482 const equalConfig = {
69483 kernelName: Equal,
69484 backendName: 'cpu',
69485 kernelFunc: equal$1
69486 };
69487
69488 /**
69489 * @license
69490 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
69492 * you may not use this file except in compliance with the License.
69493 * You may obtain a copy of the License at
69494 *
69495 * http://www.apache.org/licenses/LICENSE-2.0
69496 *
69497 * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
69499 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69500 * See the License for the specific language governing permissions and
69501 * limitations under the License.
69502 * =============================================================================
69503 */
69504 const expImpl = createSimpleUnaryImpl((xi) => Math.exp(xi));
69505 const exp$1 = unaryKernelFuncFromImpl(Exp, expImpl, 'float32');
69506 const expConfig = {
69507 kernelName: Exp,
69508 backendName: 'cpu',
69509 kernelFunc: exp$1,
69510 };
69511
69512 /**
69513 * @license
69514 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
69516 * you may not use this file except in compliance with the License.
69517 * You may obtain a copy of the License at
69518 *
69519 * http://www.apache.org/licenses/LICENSE-2.0
69520 *
69521 * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
69523 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69524 * See the License for the specific language governing permissions and
69525 * limitations under the License.
69526 * =============================================================================
69527 */
69528 const expm1Impl = createSimpleUnaryImpl((xi) => Math.expm1(xi));
69529 const expm1$1 = unaryKernelFuncFromImpl(Expm1, expm1Impl);
69530 const expm1Config = {
69531 kernelName: Expm1,
69532 backendName: 'cpu',
69533 kernelFunc: expm1$1,
69534 };
69535
69536 /**
69537 * @license
69538 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
69540 * you may not use this file except in compliance with the License.
69541 * You may obtain a copy of the License at
69542 *
69543 * http://www.apache.org/licenses/LICENSE-2.0
69544 *
69545 * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
69547 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69548 * See the License for the specific language governing permissions and
69549 * limitations under the License.
69550 * =============================================================================
69551 */
69552 const floorImpl = createSimpleUnaryImpl((xi) => Math.floor(xi));
69553 const floor$1 = unaryKernelFuncFromImpl(Floor, floorImpl);
69554 const floorConfig = {
69555 kernelName: Floor,
69556 backendName: 'cpu',
69557 kernelFunc: floor$1,
69558 };
69559
69560 /**
69561 * @license
69562 * Copyright 2021 Google LLC. All Rights Reserved.
69563 * Licensed under the Apache License, Version 2.0 (the "License");
69564 * you may not use this file except in compliance with the License.
69565 * You may obtain a copy of the License at
69566 *
69567 * http://www.apache.org/licenses/LICENSE-2.0
69568 *
69569 * Unless required by applicable law or agreed to in writing, software
69570 * distributed under the License is distributed on an "AS IS" BASIS,
69571 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69572 * See the License for the specific language governing permissions and
69573 * limitations under the License.
69574 * =============================================================================
69575 */
    /**
     * Gathers `numSlices` slices of `sliceSize` elements from `paramsBuf`,
     * selected by the multi-dimensional indices in `indicesData` (flattened as
     * `numSlices` rows of `sliceRank` coordinates).
     * Throws if a computed slice offset falls outside params.
     */
    function gatherNdImpl(indicesData, paramsBuf, dtype, numSlices, sliceRank, sliceSize, strides, paramsShape, paramsSize) {
        const outBuf = buffer([numSlices, sliceSize], dtype);
        for (let i = 0; i < numSlices; i++) {
            const index = [];
            let flattenIndex = 0;
            // Fold the i-th coordinate tuple into a flat slice offset; `index`
            // keeps the raw coordinates for the error message below.
            for (let j = 0; j < sliceRank; j++) {
                const dim = indicesData[i * sliceRank + j];
                flattenIndex += dim * strides[j];
                index.push(dim);
            }
            // Valid slice offsets lie in [0, paramsSize / sliceSize).
            if (flattenIndex < 0 || flattenIndex >= paramsSize / sliceSize) {
                throw new Error(`Invalid indices: ${index} does not index into ${paramsShape}`);
            }
            // Copy the selected slice element by element into the output row.
            for (let k = 0; k < sliceSize; k++) {
                outBuf.values[i * sliceSize + k] =
                    paramsBuf.get(...paramsBuf.indexToLoc(flattenIndex * sliceSize + k));
            }
        }
        return outBuf;
    }
69596
69597 /**
69598 * @license
69599 * Copyright 2020 Google LLC. All Rights Reserved.
69600 * Licensed under the Apache License, Version 2.0 (the "License");
69601 * you may not use this file except in compliance with the License.
69602 * You may obtain a copy of the License at
69603 *
69604 * http://www.apache.org/licenses/LICENSE-2.0
69605 *
69606 * Unless required by applicable law or agreed to in writing, software
69607 * distributed under the License is distributed on an "AS IS" BASIS,
69608 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69609 * See the License for the specific language governing permissions and
69610 * limitations under the License.
69611 * =============================================================================
69612 */
    /**
     * CPU implementation of GatherV2 over pre-flattened buffers.
     * NOTE(review): the code reads dim 0 as the batch and dim 2 as the gathered
     * axis, so it assumes the caller reshaped inputs to a canonical
     * [batch, outer, indices, slice]-style layout — confirm against the
     * GatherV2 kernel that calls this.
     * Out-of-range indices leave the output's default zero values in place.
     */
    function gatherV2Impl(xBuf, indicesBuf, flattenOutputShape) {
        const outBuf = buffer(flattenOutputShape, xBuf.dtype);
        for (let i = 0; i < outBuf.size; ++i) {
            // Location of this element in the flattened output.
            const newLoc = outBuf.indexToLoc(i);
            const originalLoc = newLoc.slice();
            const batchIdx = originalLoc[0];
            const indicesIdx = originalLoc[2];
            // Replace the gathered-axis coordinate with the index value stored
            // for this (batch, position) pair.
            const indicesIndex = indicesBuf.locToIndex([batchIdx, indicesIdx]);
            originalLoc[2] = indicesBuf.values[indicesIndex];
            const originalIndex = xBuf.locToIndex(originalLoc);
            if (0 <= originalIndex && originalIndex < xBuf.values.length) {
                outBuf.values[i] = xBuf.values[originalIndex];
            } // Else, index is out of bounds, so leave the default zero val in outBuf.
        }
        return outBuf;
    }
69629
69630 /**
69631 * @license
69632 * Copyright 2020 Google LLC. All Rights Reserved.
69633 * Licensed under the Apache License, Version 2.0 (the "License");
69634 * you may not use this file except in compliance with the License.
69635 * You may obtain a copy of the License at
69636 *
69637 * http://www.apache.org/licenses/LICENSE-2.0
69638 *
69639 * Unless required by applicable law or agreed to in writing, software
69640 * distributed under the License is distributed on an "AS IS" BASIS,
69641 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69642 * See the License for the specific language governing permissions and
69643 * limitations under the License.
69644 * =============================================================================
69645 */
69646 const greaterImpl = createSimpleBinaryKernelImpl((a, b) => (a > b) ? 1 : 0);
69647 const greater$2 = binaryKernelFunc(Greater, greaterImpl, null /* complexImpl */, 'bool');
69648 const greaterConfig = {
69649 kernelName: Greater,
69650 backendName: 'cpu',
69651 kernelFunc: greater$2
69652 };
69653
69654 /**
69655 * @license
69656 * Copyright 2020 Google LLC. All Rights Reserved.
69657 * Licensed under the Apache License, Version 2.0 (the "License");
69658 * you may not use this file except in compliance with the License.
69659 * You may obtain a copy of the License at
69660 *
69661 * http://www.apache.org/licenses/LICENSE-2.0
69662 *
69663 * Unless required by applicable law or agreed to in writing, software
69664 * distributed under the License is distributed on an "AS IS" BASIS,
69665 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69666 * See the License for the specific language governing permissions and
69667 * limitations under the License.
69668 * =============================================================================
69669 */
69670 const greaterEqualImpl = createSimpleBinaryKernelImpl((a, b) => (a >= b) ? 1 : 0);
69671 const greaterEqual$1 = binaryKernelFunc(GreaterEqual, greaterEqualImpl, null /* complexImpl */, 'bool');
69672 const greaterEqualConfig = {
69673 kernelName: GreaterEqual,
69674 backendName: 'cpu',
69675 kernelFunc: greaterEqual$1
69676 };
69677
69678 /**
69679 * @license
69680 * Copyright 2020 Google LLC. All Rights Reserved.
69681 * Licensed under the Apache License, Version 2.0 (the "License");
69682 * you may not use this file except in compliance with the License.
69683 * You may obtain a copy of the License at
69684 *
69685 * http://www.apache.org/licenses/LICENSE-2.0
69686 *
69687 * Unless required by applicable law or agreed to in writing, software
69688 * distributed under the License is distributed on an "AS IS" BASIS,
69689 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69690 * See the License for the specific language governing permissions and
69691 * limitations under the License.
69692 * =============================================================================
69693 */
69694 const lessImpl = createSimpleBinaryKernelImpl((a, b) => (a < b) ? 1 : 0);
69695 const less$2 = binaryKernelFunc(Less, lessImpl, null /* complexImpl */, 'bool');
69696 const lessConfig = {
69697 kernelName: Less,
69698 backendName: 'cpu',
69699 kernelFunc: less$2
69700 };
69701
69702 /**
69703 * @license
69704 * Copyright 2020 Google LLC. All Rights Reserved.
69705 * Licensed under the Apache License, Version 2.0 (the "License");
69706 * you may not use this file except in compliance with the License.
69707 * You may obtain a copy of the License at
69708 *
69709 * http://www.apache.org/licenses/LICENSE-2.0
69710 *
69711 * Unless required by applicable law or agreed to in writing, software
69712 * distributed under the License is distributed on an "AS IS" BASIS,
69713 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69714 * See the License for the specific language governing permissions and
69715 * limitations under the License.
69716 * =============================================================================
69717 */
69718 const lessEqualImpl = createSimpleBinaryKernelImpl((a, b) => (a <= b) ? 1 : 0);
69719 const lessEqual$1 = binaryKernelFunc(LessEqual, lessEqualImpl, null /* complexImpl */, 'bool');
69720 const lessEqualConfig = {
69721 kernelName: LessEqual,
69722 backendName: 'cpu',
69723 kernelFunc: lessEqual$1
69724 };
69725
69726 /**
69727 * @license
69728 * Copyright 2020 Google LLC. All Rights Reserved.
69729 * Licensed under the Apache License, Version 2.0 (the "License");
69730 * you may not use this file except in compliance with the License.
69731 * You may obtain a copy of the License at
69732 *
69733 * http://www.apache.org/licenses/LICENSE-2.0
69734 *
69735 * Unless required by applicable law or agreed to in writing, software
69736 * distributed under the License is distributed on an "AS IS" BASIS,
69737 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69738 * See the License for the specific language governing permissions and
69739 * limitations under the License.
69740 * =============================================================================
69741 */
69742 function linSpaceImpl(start, stop, num) {
69743 const step = (stop - start) / (num - 1);
69744 const values = makeZerosTypedArray(num, 'float32');
69745 values[0] = start;
69746 for (let i = 1; i < values.length; i++) {
69747 values[i] = values[i - 1] + step;
69748 }
69749 return values;
69750 }
69751
69752 /**
69753 * @license
69754 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
69756 * you may not use this file except in compliance with the License.
69757 * You may obtain a copy of the License at
69758 *
69759 * http://www.apache.org/licenses/LICENSE-2.0
69760 *
69761 * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
69763 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69764 * See the License for the specific language governing permissions and
69765 * limitations under the License.
69766 * =============================================================================
69767 */
69768 const logImpl = createSimpleUnaryImpl((xi) => Math.log(xi));
69769 const log$2 = unaryKernelFuncFromImpl(Log, logImpl);
69770 const logConfig = {
69771 kernelName: Log,
69772 backendName: 'cpu',
69773 kernelFunc: log$2,
69774 };
69775
69776 /**
69777 * @license
69778 * Copyright 2020 Google LLC. All Rights Reserved.
69779 * Licensed under the Apache License, Version 2.0 (the "License");
69780 * you may not use this file except in compliance with the License.
69781 * You may obtain a copy of the License at
69782 *
69783 * http://www.apache.org/licenses/LICENSE-2.0
69784 *
69785 * Unless required by applicable law or agreed to in writing, software
69786 * distributed under the License is distributed on an "AS IS" BASIS,
69787 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69788 * See the License for the specific language governing permissions and
69789 * limitations under the License.
69790 * =============================================================================
69791 */
69792 function maxImpl(aVals, reduceSize, outShape, dtype) {
69793 const vals = getTypedArrayFromDType(dtype, sizeFromShape(outShape));
69794 for (let i = 0; i < vals.length; ++i) {
69795 const offset = i * reduceSize;
69796 let max = aVals[offset];
69797 for (let j = 0; j < reduceSize; ++j) {
69798 const value = aVals[offset + j];
69799 if (Number.isNaN(value) ||
69800 value > max) { // comparison with NaN always return false
69801 max = value;
69802 }
69803 }
69804 vals[i] = max;
69805 }
69806 return vals;
69807 }
69808
69809 /**
69810 * @license
69811 * Copyright 2020 Google LLC. All Rights Reserved.
69812 * Licensed under the Apache License, Version 2.0 (the "License");
69813 * you may not use this file except in compliance with the License.
69814 * You may obtain a copy of the License at
69815 *
69816 * http://www.apache.org/licenses/LICENSE-2.0
69817 *
69818 * Unless required by applicable law or agreed to in writing, software
69819 * distributed under the License is distributed on an "AS IS" BASIS,
69820 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69821 * See the License for the specific language governing permissions and
69822 * limitations under the License.
69823 * =============================================================================
69824 */
69825 const maximumImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => Math.max(aValue, bValue)));
69826 const maximum$3 = binaryKernelFunc(Maximum, maximumImpl);
69827 const maximumConfig = {
69828 kernelName: Maximum,
69829 backendName: 'cpu',
69830 kernelFunc: maximum$3
69831 };
69832
69833 /**
69834 * @license
69835 * Copyright 2020 Google LLC. All Rights Reserved.
69836 * Licensed under the Apache License, Version 2.0 (the "License");
69837 * you may not use this file except in compliance with the License.
69838 * You may obtain a copy of the License at
69839 *
69840 * http://www.apache.org/licenses/LICENSE-2.0
69841 *
69842 * Unless required by applicable law or agreed to in writing, software
69843 * distributed under the License is distributed on an "AS IS" BASIS,
69844 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69845 * See the License for the specific language governing permissions and
69846 * limitations under the License.
69847 * =============================================================================
69848 */
69849 const minimumImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => Math.min(aValue, bValue)));
69850 const minimum$3 = binaryKernelFunc(Minimum, minimumImpl);
69851 const minimumConfig = {
69852 kernelName: Minimum,
69853 backendName: 'cpu',
69854 kernelFunc: minimum$3
69855 };
69856
69857 /**
69858 * @license
69859 * Copyright 2020 Google LLC. All Rights Reserved.
69860 * Licensed under the Apache License, Version 2.0 (the "License");
69861 * you may not use this file except in compliance with the License.
69862 * You may obtain a copy of the License at
69863 *
69864 * http://www.apache.org/licenses/LICENSE-2.0
69865 *
69866 * Unless required by applicable law or agreed to in writing, software
69867 * distributed under the License is distributed on an "AS IS" BASIS,
69868 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69869 * See the License for the specific language governing permissions and
69870 * limitations under the License.
69871 * =============================================================================
69872 */
    /** Element-wise multiply for real-valued tensors. */
    const multiplyImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => aValue * bValue));
    // Complex multiply: (ar + ai·i)(br + bi·i) = (ar·br − ai·bi) + (ar·bi + ai·br)·i.
    const multiplyComplexImpl = createComplexBinaryKernelImpl(((aReal, aImag, bReal, bImag) => {
        return {
            real: aReal * bReal - aImag * bImag,
            imag: aReal * bImag + aImag * bReal
        };
    }));
    const multiply$2 = binaryKernelFunc(Multiply, multiplyImpl, multiplyComplexImpl);
    /** Kernel registration: binds `Multiply` (real and complex) to the CPU backend. */
    const multiplyConfig = {
        kernelName: Multiply,
        backendName: 'cpu',
        kernelFunc: multiply$2
    };
69886
69887 /**
69888 * @license
69889 * Copyright 2020 Google LLC. All Rights Reserved.
69890 * Licensed under the Apache License, Version 2.0 (the "License");
69891 * you may not use this file except in compliance with the License.
69892 * You may obtain a copy of the License at
69893 *
69894 * http://www.apache.org/licenses/LICENSE-2.0
69895 *
69896 * Unless required by applicable law or agreed to in writing, software
69897 * distributed under the License is distributed on an "AS IS" BASIS,
69898 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69899 * See the License for the specific language governing permissions and
69900 * limitations under the License.
69901 * =============================================================================
69902 */
    /**
     * Negates `xVals` by multiplying with a scalar -1 of the same dtype.
     * Returns whatever `multiplyImpl` yields for scalar-shape `[]` broadcast
     * against `xShape` — destructured as `[values, shape]` by the caller.
     */
    function negImpl(xVals, xShape, xDtype) {
        const minusOne = createScalarValue(-1, xDtype);
        return multiplyImpl([], xShape, minusOne, xVals, xDtype);
    }
    /** CPU kernel for `Neg`; complex inputs are rejected up front. */
    function neg$1(args) {
        const { inputs, backend } = args;
        const { x } = inputs;
        assertNotComplex(x, 'neg');
        const xVals = backend.data.get(x.dataId).values;
        const [res, newShape] = negImpl(xVals, x.shape, x.dtype);
        return backend.makeTensorInfo(newShape, x.dtype, res);
    }
    const negConfig = {
        kernelName: Neg,
        backendName: 'cpu',
        kernelFunc: neg$1
    };
69920
69921 /**
69922 * @license
69923 * Copyright 2020 Google LLC. All Rights Reserved.
69924 * Licensed under the Apache License, Version 2.0 (the "License");
69925 * you may not use this file except in compliance with the License.
69926 * You may obtain a copy of the License at
69927 *
69928 * http://www.apache.org/licenses/LICENSE-2.0
69929 *
69930 * Unless required by applicable law or agreed to in writing, software
69931 * distributed under the License is distributed on an "AS IS" BASIS,
69932 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69933 * See the License for the specific language governing permissions and
69934 * limitations under the License.
69935 * =============================================================================
69936 */
69937 const notEqualImpl = createSimpleBinaryKernelImpl(((a, b) => (a !== b) ? 1 : 0));
69938 const notEqual$1 = binaryKernelFunc(NotEqual, notEqualImpl, null /* complexOp */, 'bool');
69939 const notEqualConfig = {
69940 kernelName: NotEqual,
69941 backendName: 'cpu',
69942 kernelFunc: notEqual$1
69943 };
69944
69945 /**
69946 * @license
69947 * Copyright 2020 Google LLC. All Rights Reserved.
69948 * Licensed under the Apache License, Version 2.0 (the "License");
69949 * you may not use this file except in compliance with the License.
69950 * You may obtain a copy of the License at
69951 *
69952 * http://www.apache.org/licenses/LICENSE-2.0
69953 *
69954 * Unless required by applicable law or agreed to in writing, software
69955 * distributed under the License is distributed on an "AS IS" BASIS,
69956 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69957 * See the License for the specific language governing permissions and
69958 * limitations under the License.
69959 * =============================================================================
69960 */
69961 function transposeImpl(xVals, xShape, dtype, perm, newShape) {
69962 const xRank = xShape.length;
69963 const xSize = sizeFromShape(xShape);
69964 const xStrides = computeStrides(xShape);
69965 const newStrides = computeStrides(newShape);
69966 const result = getTypedArrayFromDType(dtype, sizeFromShape(newShape));
69967 for (let i = 0; i < xSize; ++i) {
69968 const loc = indexToLoc(i, xRank, xStrides);
69969 // Permute location.
69970 const newLoc = new Array(loc.length);
69971 for (let i = 0; i < newLoc.length; i++) {
69972 newLoc[i] = loc[perm[i]];
69973 }
69974 const newIndex = locToIndex(newLoc, xRank, newStrides);
69975 result[newIndex] = xVals[i];
69976 }
69977 return result;
69978 }
69979
69980 /**
69981 * @license
69982 * Copyright 2020 Google LLC. All Rights Reserved.
69983 * Licensed under the Apache License, Version 2.0 (the "License");
69984 * you may not use this file except in compliance with the License.
69985 * You may obtain a copy of the License at
69986 *
69987 * http://www.apache.org/licenses/LICENSE-2.0
69988 *
69989 * Unless required by applicable law or agreed to in writing, software
69990 * distributed under the License is distributed on an "AS IS" BASIS,
69991 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69992 * See the License for the specific language governing permissions and
69993 * limitations under the License.
69994 * =============================================================================
69995 */
    /**
     * CPU kernel for `Transpose`: permutes `x`'s dimensions per `attrs.perm`.
     * Complex inputs are rejected up front. Returns a new TensorInfo backed
     * by freshly written values.
     */
    function transpose$1(args) {
        const { inputs, attrs, backend } = args;
        const { x } = inputs;
        const { perm } = attrs;
        assertNotComplex(x, 'transpose');
        const xRank = x.shape.length;
        // Output shape is the input shape permuted: newShape[i] = shape[perm[i]].
        const newShape = new Array(xRank);
        for (let i = 0; i < newShape.length; i++) {
            newShape[i] = x.shape[perm[i]];
        }
        const values = backend.data.get(x.dataId).values;
        const result = transposeImpl(values, x.shape, x.dtype, perm, newShape);
        const dataId = backend.write(result, newShape, x.dtype);
        return { dataId, shape: newShape, dtype: x.dtype };
    }
    const transposeConfig = {
        kernelName: Transpose,
        backendName: 'cpu',
        kernelFunc: transpose$1
    };
70016
70017 /**
70018 * @license
70019 * Copyright 2020 Google LLC. All Rights Reserved.
70020 * Licensed under the Apache License, Version 2.0 (the "License");
70021 * you may not use this file except in compliance with the License.
70022 * You may obtain a copy of the License at
70023 *
70024 * http://www.apache.org/licenses/LICENSE-2.0
70025 *
70026 * Unless required by applicable law or agreed to in writing, software
70027 * distributed under the License is distributed on an "AS IS" BASIS,
70028 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70029 * See the License for the specific language governing permissions and
70030 * limitations under the License.
70031 * =============================================================================
70032 */
70033 function prodImpl(xShape, xDtype, xVals, reductionAxes) {
70034 const [outShape, reduceShape] = computeOutAndReduceShapes(xShape, reductionAxes);
70035 const outDtype = upcastType(xDtype, 'int32');
70036 const outVals = makeZerosTypedArray(sizeFromShape(outShape), outDtype);
70037 const reduceSize = sizeFromShape(reduceShape);
70038 for (let i = 0; i < outVals.length; ++i) {
70039 const offset = i * reduceSize;
70040 let prod = 1;
70041 for (let j = 0; j < reduceSize; ++j) {
70042 prod *= xVals[offset + j];
70043 }
70044 outVals[i] = prod;
70045 }
70046 return { outVals, outShape, outDtype };
70047 }
    /**
     * CPU kernel for `Prod`: reduces `x` by multiplication over `attrs.axis`.
     * Non-inner-most reduction axes are first moved inward via a transpose;
     * any intermediate tensors created that way are disposed before returning.
     */
    function prod$1(args) {
        const { inputs, backend, attrs } = args;
        const { x } = inputs;
        const { axis, keepDims } = attrs;
        assertNotComplex(x, 'prod');
        const xRank = x.shape.length;
        const axes = parseAxisParam(axis, x.shape);
        const permutation = getAxesPermutation(axes, xRank);
        let reductionAxes = axes;
        let permutedX = x;
        const intermediateTensorInfos = [];
        if (permutation != null) {
            // Move the reduced axes to the inner-most positions so prodImpl
            // can reduce over contiguous runs of values.
            permutedX = transpose$1({ inputs: { x }, backend, attrs: { perm: permutation } });
            intermediateTensorInfos.push(permutedX);
            reductionAxes = getInnerMostAxes(reductionAxes.length, xRank);
        }
        const xVals = backend.data.get(permutedX.dataId).values;
        const { outVals, outShape, outDtype } = prodImpl(permutedX.shape, permutedX.dtype, xVals, reductionAxes);
        let resultShape = outShape;
        if (keepDims) {
            // Re-insert the reduced dimensions as size-1 axes.
            resultShape = expandShapeToKeepDim(outShape, axes);
        }
        intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
        return backend.makeTensorInfo(resultShape, outDtype, outVals);
    }
    const prodConfig = {
        kernelName: Prod,
        backendName: 'cpu',
        kernelFunc: prod$1
    };
70078
70079 /**
70080 * @license
70081 * Copyright 2020 Google LLC. All Rights Reserved.
70082 * Licensed under the Apache License, Version 2.0 (the "License");
70083 * you may not use this file except in compliance with the License.
70084 * You may obtain a copy of the License at
70085 *
70086 * http://www.apache.org/licenses/LICENSE-2.0
70087 *
70088 * Unless required by applicable law or agreed to in writing, software
70089 * distributed under the License is distributed on an "AS IS" BASIS,
70090 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70091 * See the License for the specific language governing permissions and
70092 * limitations under the License.
70093 * =============================================================================
70094 */
    /**
     * Produces the values of the range [start, stop) advancing by `step`.
     *
     * Returns an empty array when start === stop or when the step's sign can
     * make no progress toward `stop`. NOTE(review): a decreasing range with
     * step === 1 is treated as "step not explicitly set" and auto-negated
     * below, which is why the guard checks `step > 1` rather than `step > 0`.
     */
    function rangeImpl(start, stop, step, dtype) {
        const sameStartStop = start === stop;
        const increasingRangeNegativeStep = start < stop && step < 0;
        const decreasingRangePositiveStep = stop < start && step > 1;
        if (sameStartStop || increasingRangeNegativeStep ||
            decreasingRangePositiveStep) {
            return makeZerosTypedArray(0, dtype);
        }
        const numElements = Math.abs(Math.ceil((stop - start) / step));
        const values = makeZerosTypedArray(numElements, dtype);
        if (stop < start && step === 1) {
            // Auto adjust the step's sign if it hasn't been set
            // (or was set to 1)
            step = -1;
        }
        // Fill by cumulative addition, starting from `start`.
        values[0] = start;
        for (let i = 1; i < values.length; i++) {
            values[i] = values[i - 1] + step;
        }
        return values;
    }
70116
70117 /**
70118 * @license
70119 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
     * you may not use this file except in compliance with the License.
     * You may obtain a copy of the License at
     *
     * http://www.apache.org/licenses/LICENSE-2.0
     *
     * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
70128 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70129 * See the License for the specific language governing permissions and
70130 * limitations under the License.
70131 * =============================================================================
70132 */
    /** Element-wise reciprocal square root: 1 / sqrt(x). */
    const rsqrtImpl = createSimpleUnaryImpl((xi) => 1 / Math.sqrt(xi));
    const rsqrt$1 = unaryKernelFuncFromImpl(Rsqrt, rsqrtImpl);
    /** Kernel registration: binds the `Rsqrt` op to its CPU implementation. */
    const rsqrtConfig = {
        kernelName: Rsqrt,
        backendName: 'cpu',
        kernelFunc: rsqrt$1,
    };
70140
70141 /**
70142 * @license
70143 * Copyright 2020 Google LLC. All Rights Reserved.
70144 * Licensed under the Apache License, Version 2.0 (the "License");
70145 * you may not use this file except in compliance with the License.
70146 * You may obtain a copy of the License at
70147 *
70148 * http://www.apache.org/licenses/LICENSE-2.0
70149 *
70150 * Unless required by applicable law or agreed to in writing, software
70151 * distributed under the License is distributed on an "AS IS" BASIS,
70152 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70153 * See the License for the specific language governing permissions and
70154 * limitations under the License.
70155 * =============================================================================
70156 */
    /**
     * Scatters `updates` into a fresh buffer at positions given by `indices`.
     *
     * The output is built as a 2-D view `[outputSize / sliceSize, sliceSize]`:
     * each index row (of length `sliceRank`) is folded with `strides` into a
     * flat slice position, then the corresponding `sliceSize` update values
     * are either summed into (`sumDupeIndices`) or assigned over that slice.
     * Out-of-range indices throw. An empty output short-circuits to an empty
     * buffer of the requested `shape`.
     */
    function scatterImpl(indices, updates, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, defaultValue, sumDupeIndices) {
        // 2-D working view: [number of slices, elements per slice].
        const flattenShape = [outputSize / sliceSize, sliceSize];
        const indicesData = indices.values;
        const updatesData = updates.values;
        if (outputSize === 0) {
            return buffer(shape, updates.dtype);
        }
        const outBuf = buffer(flattenShape, updates.dtype);
        // Pre-fill the output with the default value; booleans are coerced
        // to 0/1 so typed arrays can hold them.
        if (typeof defaultValue === 'string') {
            outBuf.values.fill(defaultValue);
        }
        else if (typeof defaultValue === 'number') {
            outBuf.values.fill(defaultValue);
        }
        else if (typeof defaultValue === 'boolean') {
            outBuf.values.fill(+defaultValue);
        }
        for (let i = 0; i < numUpdates; i++) {
            // Fold this index row into a flat slice position via `strides`.
            const index = [];
            let flattenIndex = 0;
            for (let j = 0; j < sliceRank; j++) {
                const dim = indicesData[i * sliceRank + j];
                index.push(dim);
                flattenIndex += dim * strides[j];
            }
            if (flattenIndex < 0 || flattenIndex >= outputSize / sliceSize) {
                throw new Error(`Invalid indices: ${index} does not index into ${shape}`);
            }
            for (let k = 0; k < sliceSize; k++) {
                if (sumDupeIndices) {
                    // Accumulate duplicate indices instead of overwriting.
                    outBuf.values[flattenIndex * sliceSize + k] +=
                        updatesData[i * sliceSize + k];
                }
                else {
                    // Scalar updates broadcast the single value to every slot.
                    outBuf.values[flattenIndex * sliceSize + k] = updates.rank === 0 ?
                        updatesData[0] :
                        updatesData[i * sliceSize + k];
                }
            }
        }
        return outBuf;
    }
70199
70200 /**
70201 * @license
70202 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
     * you may not use this file except in compliance with the License.
     * You may obtain a copy of the License at
     *
     * http://www.apache.org/licenses/LICENSE-2.0
     *
     * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
70211 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70212 * See the License for the specific language governing permissions and
70213 * limitations under the License.
70214 * =============================================================================
70215 */
    /** Element-wise logistic sigmoid: 1 / (1 + e^-x). */
    const sigmoidImpl = createSimpleUnaryImpl((xi) => 1 / (1 + Math.exp(-xi)));
    // NOTE: the kernel below inlines the same formula rather than reusing
    // sigmoidImpl; sigmoidImpl is kept for callers that want the raw impl.
    const sigmoid$1 = unaryKernelFunc(Sigmoid, (xi) => 1 / (1 + Math.exp(-xi)));
    /** Kernel registration: binds the `Sigmoid` op to its CPU implementation. */
    const sigmoidConfig = {
        kernelName: Sigmoid,
        backendName: 'cpu',
        kernelFunc: sigmoid$1,
    };
70223
70224 /**
70225 * @license
70226 * Copyright 2020 Google LLC. All Rights Reserved.
70227 * Licensed under the Apache License, Version 2.0 (the "License");
70228 * you may not use this file except in compliance with the License.
70229 * You may obtain a copy of the License at
70230 *
70231 * http://www.apache.org/licenses/LICENSE-2.0
70232 *
70233 * Unless required by applicable law or agreed to in writing, software
70234 * distributed under the License is distributed on an "AS IS" BASIS,
70235 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70236 * See the License for the specific language governing permissions and
70237 * limitations under the License.
70238 * =============================================================================
70239 */
    /**
     * Extracts the region of `size` starting at `begin` from a dense tensor.
     *
     * When the requested region is contiguous in memory, a direct slice /
     * subarray of the flat values is returned; otherwise elements are
     * gathered one by one through buffer coordinates. String tensors travel
     * as encoded Uint8 data and are decoded/re-encoded around the gather.
     */
    function sliceImpl(vals, begin, size, shape, dtype) {
        const isContinous = isSliceContinous(shape, begin, size);
        const length = sizeFromShape(size);
        const xStrides = computeStrides(shape);
        if (isContinous) {
            const flatOffset = computeFlatOffset(begin, xStrides);
            if (dtype === 'string') {
                // String storage is a plain array; slice() copies references.
                return vals.slice(flatOffset, flatOffset + length);
            }
            // Typed arrays: return a view over the contiguous region.
            return vals.subarray(flatOffset, flatOffset + length);
        }
        const decodedData = dtype === 'string' ?
            fromUint8ToStringArray(vals) :
            vals;
        const inBuf = buffer(shape, dtype, decodedData);
        const outBuf = buffer(size, dtype);
        // Gather: each output location maps to input location + begin offset.
        for (let i = 0; i < outBuf.size; ++i) {
            const outLoc = outBuf.indexToLoc(i);
            const inLoc = outLoc.map((idx, j) => idx + begin[j]);
            outBuf.set(inBuf.get(...inLoc), ...outLoc);
        }
        if (dtype === 'string') {
            return fromStringArrayToUint8(outBuf.values);
        }
        return outBuf.values;
    }
    /**
     * CPU kernel for `Slice`: normalizes/validates `begin`/`size` against
     * `x`, then delegates to sliceImpl. Complex inputs are rejected up front.
     */
    function slice$1(args) {
        const { inputs, backend, attrs } = args;
        const { x } = inputs;
        const { begin, size } = attrs;
        assertNotComplex(x, 'slice');
        // Resolve negative/implicit begin & size values to concrete ones.
        const [$begin, $size] = parseSliceParams(x, begin, size);
        assertParamsValid(x, $begin, $size);
        const vals = backend.data.get(x.dataId).values;
        const outVals = sliceImpl(vals, $begin, $size, x.shape, x.dtype);
        return backend.makeTensorInfo($size, x.dtype, outVals);
    }
    const sliceConfig = {
        kernelName: Slice,
        backendName: 'cpu',
        kernelFunc: slice$1
    };
70282
70283 /**
70284 * @license
70285 * Copyright 2021 Google LLC. All Rights Reserved.
70286 * Licensed under the Apache License, Version 2.0 (the "License");
70287 * you may not use this file except in compliance with the License.
70288 * You may obtain a copy of the License at
70289 *
70290 * http://www.apache.org/licenses/LICENSE-2.0
70291 *
70292 * Unless required by applicable law or agreed to in writing, software
70293 * distributed under the License is distributed on an "AS IS" BASIS,
70294 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70295 * See the License for the specific language governing permissions and
70296 * limitations under the License.
70297 * =============================================================================
70298 */
    /**
     * Fills empty rows of a sparse tensor with `defaultValue`.
     *
     * `indices` is the flattened `[indicesCount, rank]` index tensor, with
     * the row id stored in column 0. Returns a 5-tuple:
     *   [outputIndices, outputIndicesShape, outputValues,
     *    emptyRowIndicator, reverseIndexMap]
     * where `emptyRowIndicator[row]` marks rows that had no elements and
     * `reverseIndexMap[i]` maps input element i to its output position
     * (needed for backprop). Assumes row ids are checked here but may be
     * unsorted; a fast path returns the inputs unchanged when all rows are
     * full and already ordered.
     */
    function sparseFillEmptyRowsImpl(indices, indicesShape, indicesDType, values, valuesDType, denseShape, defaultValue) {
        const indicesCount = indicesShape[0];
        const denseRows = denseShape[0];
        const emptyRowIndicator = new Array(denseRows);
        const reverseIndexMap = new Array(indicesCount);
        const rank = indicesShape[1];
        if (denseRows === 0) {
            // Degenerate dense shape: only valid if there are no elements.
            if (indicesCount !== 0) {
                throw new Error(getSparseFillEmptyRowsIndicesDenseShapeMismatch(indicesCount));
            }
            const outputIndices = getArrayFromDType(indicesDType, 0);
            const outputValues = getArrayFromDType(valuesDType, 0);
            return [
                outputIndices, [0, rank], outputValues, emptyRowIndicator, reverseIndexMap
            ];
        }
        let rowsAreOrdered = true;
        let lastIndicesRow = 0;
        // Per-row element counts, later converted to CSR-style offsets.
        const csrOffset = new Array(denseRows).fill(0);
        for (let i = 0; i < indicesCount; ++i) {
            // indices is a 2d tensor with shape of [N, rank]
            const row = indices[i * rank];
            if (row < 0) {
                throw new Error(getSparseFillEmptyRowsNegativeIndexErrorMessage(i, row));
            }
            if (row >= denseRows) {
                throw new Error(getSparseFillEmptyRowsOutOfRangeIndexErrorMessage(i, row, denseRows));
            }
            ++csrOffset[row];
            rowsAreOrdered = rowsAreOrdered && (row >= lastIndicesRow);
            lastIndicesRow = row;
        }
        let allRowsFull = true;
        for (let row = 0; row < denseRows; ++row) {
            // csrOffset here describes the number of elements in this dense row
            const rowEmpty = (csrOffset[row] === 0);
            emptyRowIndicator[row] = rowEmpty;
            allRowsFull = allRowsFull && !rowEmpty;
            // In filled version, each row has at least one element.
            csrOffset[row] = Math.max(csrOffset[row], 1);
            // Update csrOffset to represent the number of elements up to and
            // including denseRows + 1:
            // csrOffset[0] == #{elements of row 0}
            // csrOffset[1] == #{elements of row 1} + #{elements of row 0}
            // ..
            // csrOffset[i] == starting index for elements in row i + 1.
            if (row > 0) {
                csrOffset[row] += csrOffset[row - 1];
            }
        }
        if (allRowsFull && rowsAreOrdered) {
            // Fast path: nothing to fill, nothing to reorder — identity map.
            const outputIndices = indices;
            const outputValues = values;
            for (let i = 0; i < indicesCount; ++i) {
                reverseIndexMap[i] = i;
            }
            return [
                outputIndices, [indicesCount, rank], outputValues, emptyRowIndicator,
                reverseIndexMap
            ];
        }
        else {
            const fullIndicesCount = csrOffset[denseRows - 1];
            const outputIndices = getArrayFromDType(indicesDType, fullIndicesCount * rank);
            const outputValues = getArrayFromDType(valuesDType, fullIndicesCount);
            const filledCount = new Array(denseRows).fill(0);
            // Fill in values for rows that are not missing
            for (let i = 0; i < indicesCount; ++i) {
                // indices is a 2d tensor with shape of [N, rank]
                const row = indices[i * rank];
                const offset = filledCount[row];
                const outputI = ((row === 0) ? 0 : csrOffset[row - 1]) + offset;
                filledCount[row]++; // Increment the filled count for this row.
                for (let j = 0; j < rank; ++j) {
                    // indices and outputIndices are 2d tensors with shape of [N, rank]
                    outputIndices[outputI * rank + j] = indices[i * rank + j];
                }
                outputValues[outputI] = values[i];
                // We'll need this reverse index map to backprop correctly.
                reverseIndexMap[i] = outputI;
            }
            // Fill in values for rows that are missing
            for (let row = 0; row < denseRows; ++row) {
                const rowCount = filledCount[row];
                if (rowCount === 0) { // We haven't filled this row
                    const startingIndex = (row === 0) ? 0 : csrOffset[row - 1];
                    // Remaining index values were set to zero already.
                    // Just need to set the row index in the right location.
                    // outputIndices is a 2d tensor with shape of [N, rank]
                    outputIndices[startingIndex * rank + 0] = row;
                    for (let col = 1; col < rank; ++col) {
                        outputIndices[startingIndex * rank + col] = 0;
                    }
                    outputValues[startingIndex] = defaultValue;
                }
            }
            return [
                outputIndices, [fullIndicesCount, rank], outputValues, emptyRowIndicator,
                reverseIndexMap
            ];
        }
    }
70401
70402 /**
70403 * @license
70404 * Copyright 2021 Google LLC. All Rights Reserved.
70405 * Licensed under the Apache License, Version 2.0 (the "License");
70406 * you may not use this file except in compliance with the License.
70407 * You may obtain a copy of the License at
70408 *
70409 * http://www.apache.org/licenses/LICENSE-2.0
70410 *
70411 * Unless required by applicable law or agreed to in writing, software
70412 * distributed under the License is distributed on an "AS IS" BASIS,
70413 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70414 * See the License for the specific language governing permissions and
70415 * limitations under the License.
70416 * =============================================================================
70417 */
    /**
     * Reshapes a sparse tensor's indices from `inputShape` to `targetShape`.
     *
     * `targetShape` may contain at most one -1, which is inferred from the
     * total dense size. Each nnz index row is linearised with the input
     * strides and then decomposed against the output strides. Returns
     * [newIndices, [nnz, outputRank], resolvedOutputShape].
     */
    function sparseReshapeImpl(inputIndices, inputIndicesShape, inputDType, inputShape, targetShape) {
        const denseSize = sizeFromShape(inputShape);
        const nnz = inputIndicesShape[0];
        const outputRank = targetShape.length;
        // Compute the output shape. Determine product of specified dimensions, and
        // find the index of the unspecified one.
        const outputShape = [];
        let product = 1;
        let unknownIndex = -1;
        for (let d = 0; d < outputRank; ++d) {
            const size = targetShape[d];
            if (size === -1) {
                // Only one dimension may be left for inference.
                if (unknownIndex !== -1) {
                    throw new Error(getSparseReshapeMultipleNegativeOneOutputDimErrorMessage(unknownIndex, d));
                }
                unknownIndex = d;
                outputShape.push(1);
            }
            else {
                if (size < 0) {
                    throw new Error(getSparseReshapeNegativeOutputDimErrorMessage(d, size));
                }
                product *= size;
                outputShape.push(size);
            }
        }
        if (unknownIndex !== -1) {
            if (product <= 0) {
                throw new Error(getSparseReshapeEmptyTensorZeroOutputDimErrorMessage());
            }
            // The inferred dimension must divide the dense size exactly.
            const missing = Math.trunc(denseSize / product);
            if (product * missing !== denseSize) {
                throw new Error(getSparseReshapeInputOutputMultipleErrorMessage(inputShape, outputShape));
            }
            outputShape[unknownIndex] = missing;
        }
        const outputSize = sizeFromShape(outputShape);
        if (outputSize !== denseSize) {
            throw new Error(getSparseReshapeInputOutputMismatchErrorMessage(inputShape, outputShape));
        }
        // Row-major strides for the input shape.
        const inputRank = inputShape.length;
        const inputStrides = [];
        if (inputRank > 0) {
            inputStrides[inputRank - 1] = 1;
            for (let d = inputRank - 2; d >= 0; --d) {
                inputStrides[d] = inputStrides[d + 1] * inputShape[d + 1];
            }
        }
        // Row-major strides for the resolved output shape.
        const outputStrides = [];
        if (outputRank > 0) {
            outputStrides[outputRank - 1] = 1;
            for (let d = outputRank - 2; d >= 0; --d) {
                outputStrides[d] = outputStrides[d + 1] * outputShape[d + 1];
            }
        }
        const newIndices = getArrayFromDType(inputDType, nnz * outputRank);
        for (let i = 0; i < nnz; ++i) {
            // Linearise the input coordinates, then decompose into output
            // coordinates by successive division/modulo with output strides.
            let id = 0;
            for (let j = 0; j < inputRank; ++j) {
                // inputIndices is a 2d tensor with shape of [nnz, inputRank]
                id += inputIndices[i * inputRank + j] * inputStrides[j];
            }
            for (let j = 0; j < outputRank; ++j) {
                // newIndices is a 2d tensor with shape of [nnz, outputRank]
                newIndices[i * outputRank + j] = Math.trunc(id / outputStrides[j]);
                id %= outputStrides[j];
            }
        }
        return [newIndices, [nnz, outputRank], outputShape];
    }
70488
70489 /**
70490 * @license
70491 * Copyright 2021 Google LLC. All Rights Reserved.
70492 * Licensed under the Apache License, Version 2.0 (the "License");
70493 * you may not use this file except in compliance with the License.
70494 * You may obtain a copy of the License at
70495 *
70496 * http://www.apache.org/licenses/LICENSE-2.0
70497 *
70498 * Unless required by applicable law or agreed to in writing, software
70499 * distributed under the License is distributed on an "AS IS" BASIS,
70500 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70501 * See the License for the specific language governing permissions and
70502 * limitations under the License.
70503 * =============================================================================
70504 */
    /**
     * Sparse segment reduction (sum, or mean when `isMean` is true).
     *
     * Gathers rows of `input` (viewed as `[inputShape[0], numCol]`) selected
     * by `indices`, and sums them per segment id. Segment ids must be
     * non-negative and non-decreasing; gaps between segment ids — and any
     * trailing rows — are filled with `defaultValue`. Returns
     * [output, outputShape] where outputShape[0] is lastSegmentId + 1.
     */
    function sparseSegmentReductionImpl(input, inputShape, inputDType, indices, segmentIds, isMean = false, defaultValue = 0) {
        const numIndices = indices.length;
        // Flatten the array to two dimensions
        const inputFlat = [inputShape[0], input.length / inputShape[0]];
        const numCol = inputFlat[1];
        // Note that the current implementation assumes that segmentIds values are
        // sorted.
        const lastSegmentIdPlusOne = numIndices > 0 ? segmentIds[numIndices - 1] + 1 : 0;
        const outputRows = lastSegmentIdPlusOne;
        if (outputRows < 0) {
            throw new Error(getSparseSegmentReductionNegativeSegmentIdsErrorMessage());
        }
        const outputShape = inputShape.slice();
        outputShape[0] = outputRows;
        const outputLength = outputShape.reduce((product, value) => product * value, 1);
        // getArrayFromDType presumably yields a zero-filled array — TODO confirm;
        // gaps still need an explicit fill when defaultValue is non-zero, which
        // is what the fill() calls below handle.
        const output = getArrayFromDType(inputDType, outputLength);
        if (numIndices === 0) {
            if (outputRows > 0) {
                output.fill(defaultValue);
            }
            return [output, outputShape];
        }
        if (outputRows <= 0) {
            throw new Error(getSparseSegmentReductionNegativeSegmentIdsErrorMessage());
        }
        // [start, end) delimits the run of indices sharing one segment id.
        let start = 0, end = 1;
        // Index from which the output is not initialized.
        let uninitializedIndex = 0;
        let outIndex = segmentIds[start];
        while (true) {
            // We initialize nextIndex to 0 to avoid may be uninitialized warning
            let nextIndex = 0;
            if (end < numIndices) {
                nextIndex = segmentIds[end];
                if (outIndex === nextIndex) {
                    ++end;
                    continue;
                }
                // We have a new segment here. Verify that the segment ids are growing.
                if (outIndex >= nextIndex) {
                    throw new Error(getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage());
                }
            }
            if (outIndex < 0 || outIndex >= outputRows) {
                throw new Error(getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage(outIndex, outputRows));
            }
            // If there is a gap between two indices, we need to set that gap to the
            // default value.
            if (outIndex > uninitializedIndex) {
                output.fill(defaultValue, uninitializedIndex * numCol, outIndex * numCol);
            }
            // Accumulate all selected rows of this segment into the output row.
            for (let i = start; i < end; ++i) {
                const index = indices[i];
                if (index < 0 || index >= inputFlat[0]) {
                    throw new Error(getSparseSegmentReductionIndicesOutOfRangeErrorMessage(i, indices[i], inputFlat[0]));
                }
                for (let j = 0; j < numCol; j++) {
                    output[outIndex * numCol + j] += input[index * numCol + j];
                }
            }
            if (isMean) {
                // Divide the accumulated sum by the segment's element count.
                for (let j = 0; j < numCol; j++) {
                    output[outIndex * numCol + j] /= end - start;
                }
            }
            start = end;
            ++end;
            uninitializedIndex = outIndex + 1;
            outIndex = nextIndex;
            if (end > numIndices) {
                break;
            }
        }
        // Fill the gap at the end with the default value.
        if (uninitializedIndex < outputRows) {
            output.fill(defaultValue, uninitializedIndex * numCol, outputRows * numCol);
        }
        return [output, outputShape];
    }
70587
70588 /**
70589 * @license
70590 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
70592 * you may not use this file except in compliance with the License.
70593 * You may obtain a copy of the License at
70594 *
70595 * http://www.apache.org/licenses/LICENSE-2.0
70596 *
70597 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
70599 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70600 * See the License for the specific language governing permissions and
70601 * limitations under the License.
70602 * =============================================================================
70603 */
    // Element-wise square root. `sqrtImpl` is the shared buffer-level
    // implementation; `sqrt$1` wraps it as a cpu kernel function.
    const sqrtImpl = createSimpleUnaryImpl((xi) => Math.sqrt(xi));
    const sqrt$1 = unaryKernelFunc(Sqrt, (xi) => Math.sqrt(xi));
    // Kernel config registering `sqrt$1` under the Sqrt op for the cpu backend.
    const sqrtConfig = {
        kernelName: Sqrt,
        backendName: 'cpu',
        kernelFunc: sqrt$1,
    };
70611
70612 /**
70613 * @license
70614 * Copyright 2020 Google LLC. All Rights Reserved.
70615 * Licensed under the Apache License, Version 2.0 (the "License");
70616 * you may not use this file except in compliance with the License.
70617 * You may obtain a copy of the License at
70618 *
70619 * http://www.apache.org/licenses/LICENSE-2.0
70620 *
70621 * Unless required by applicable law or agreed to in writing, software
70622 * distributed under the License is distributed on an "AS IS" BASIS,
70623 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70624 * See the License for the specific language governing permissions and
70625 * limitations under the License.
70626 * =============================================================================
70627 */
    // Element-wise (a - b)^2, shared buffer-level implementation.
    const squaredDifferenceImpl = createSimpleBinaryKernelImpl(((a, b) => {
        const diff = a - b;
        return diff * diff;
    }));
    const squaredDifference$1 = binaryKernelFunc(SquaredDifference, squaredDifferenceImpl);
    // Kernel config registering the op for the cpu backend.
    const squaredDifferenceConfig = {
        kernelName: SquaredDifference,
        backendName: 'cpu',
        kernelFunc: squaredDifference$1
    };
70638
70639 /**
70640 * @license
70641 * Copyright 2020 Google LLC. All Rights Reserved.
70642 * Licensed under the Apache License, Version 2.0 (the "License");
70643 * you may not use this file except in compliance with the License.
70644 * You may obtain a copy of the License at
70645 *
70646 * http://www.apache.org/licenses/LICENSE-2.0
70647 *
70648 * Unless required by applicable law or agreed to in writing, software
70649 * distributed under the License is distributed on an "AS IS" BASIS,
70650 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70651 * See the License for the specific language governing permissions and
70652 * limitations under the License.
70653 * =============================================================================
70654 */
70655 function stridedSliceImpl(outShape, xBuf, strides, begin) {
70656 const outBuf = buffer(outShape, xBuf.dtype);
70657 for (let i = 0; i < outBuf.size; i++) {
70658 const loc = outBuf.indexToLoc(i);
70659 const newLoc = new Array(loc.length);
70660 for (let j = 0; j < newLoc.length; j++) {
70661 newLoc[j] = loc[j] * strides[j] + begin[j];
70662 }
70663 outBuf.set(xBuf.get(...newLoc), ...loc);
70664 }
70665 return outBuf;
70666 }
70667
70668 /**
70669 * @license
70670 * Copyright 2021 Google LLC. All Rights Reserved.
70671 * Licensed under the Apache License, Version 2.0 (the "License");
70672 * you may not use this file except in compliance with the License.
70673 * You may obtain a copy of the License at
70674 *
70675 * http://www.apache.org/licenses/LICENSE-2.0
70676 *
70677 * Unless required by applicable law or agreed to in writing, software
70678 * distributed under the License is distributed on an "AS IS" BASIS,
70679 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70680 * See the License for the specific language governing permissions and
70681 * limitations under the License.
70682 * =============================================================================
70683 */
70684 /**
70685 * The StringNGramsOp class creates ngrams from ragged string data.
70686 * The constructor contains all attributes related to the operation such as
70687 * padding widths and strings, and the compute function can be used to
70688 * compute the ngrams for different ragged tensor inputs.
70689 */
    class StringNGramsOp {
        /**
         * @param separator String joined between elements inside an ngram.
         * @param nGramWidths Widths of the ngrams to generate per row.
         * @param leftPad Padding string prepended to a sequence.
         * @param rightPad Padding string appended to a sequence.
         * @param padWidth Number of pad elements on each side; a negative value
         *     requests the maximum useful padding (nGramWidth - 1).
         * @param preserveShortSequences Whether a non-empty row too short to
         *     produce any ngram should still emit one fully padded ngram.
         */
        constructor(separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) {
            // Strings are handled as encoded byte arrays throughout this class.
            this.separator = encodeString(separator);
            this.nGramWidths = nGramWidths;
            this.leftPad = encodeString(leftPad);
            this.rightPad = encodeString(rightPad);
            this.padWidth = padWidth;
            this.preserveShort = preserveShortSequences;
        }
        /** Effective pad width for a given ngram width, capped at nGramWidth - 1. */
        getPadWidth(nGramWidth) {
            // Ngrams can be padded with either a fixed pad width or a dynamic pad
            // width depending on the 'padWidth' arg, but in no case should the padding
            // ever be wider than 'nGramWidth' - 1.
            return Math.min(this.padWidth < 0 ? nGramWidth - 1 : this.padWidth, nGramWidth - 1);
        }
        /** Number of ngrams a row of `length` tokens yields at `nGramWidth`. */
        getNumNGrams(length, nGramWidth) {
            const padWidth = this.getPadWidth(nGramWidth);
            return Math.max(0, ((length + 2 * padWidth) - nGramWidth) + 1);
        }
        /**
         * Builds `numNGrams` ngrams of width `nGramWidth` from the row of `data`
         * beginning at `splitIndex`, writing them into `output` starting at
         * `outputStartIndex`. Each ngram is assembled as a byte array of exactly
         * the precomputed size.
         */
        createNGrams(data, splitIndex, output, outputStartIndex, numNGrams, nGramWidth) {
            for (let nGramIndex = 0; nGramIndex < numNGrams; ++nGramIndex) {
                const padWidth = this.getPadWidth(nGramWidth);
                // Pad counts shrink as the window moves away from the row edges.
                const leftPadding = Math.max(0, padWidth - nGramIndex);
                const rightPadding = Math.max(0, padWidth - (numNGrams - (nGramIndex + 1)));
                const numTokens = nGramWidth - (leftPadding + rightPadding);
                const dataStartIndex = splitIndex + (leftPadding > 0 ? 0 : nGramIndex - padWidth);
                // Calculate the total expected size of the nGram so we can reserve the
                // correct amount of space in the string.
                let nGramSize = 0;
                // Size of the left padding.
                nGramSize += leftPadding * this.leftPad.length;
                // Size of the tokens.
                for (let n = 0; n < numTokens; ++n) {
                    nGramSize += data[dataStartIndex + n].length;
                }
                // Size of the right padding.
                nGramSize += rightPadding * this.rightPad.length;
                // Size of the separators.
                const numSeparators = leftPadding + rightPadding + numTokens - 1;
                nGramSize += numSeparators * this.separator.length;
                // Build the nGram.
                output[outputStartIndex + nGramIndex] = new Uint8Array(nGramSize);
                const nGram = output[outputStartIndex + nGramIndex];
                let nextNGramIndex = 0;
                const appendToNGram = (str) => str.forEach((value) => nGram[nextNGramIndex++] = value);
                for (let n = 0; n < leftPadding; ++n) {
                    appendToNGram(this.leftPad);
                    appendToNGram(this.separator);
                }
                // Only output first numTokens - 1 pairs of data and separator
                for (let n = 0; n < numTokens - 1; ++n) {
                    appendToNGram(data[dataStartIndex + n]);
                    appendToNGram(this.separator);
                }
                // Handle case when there are no tokens or no right padding as these
                // can result in consecutive separators.
                if (numTokens > 0) {
                    // If we have tokens, then output last and then pair each separator
                    // with the right padding that follows, to ensure nGram ends either with
                    // the token or with the right pad.
                    appendToNGram(data[dataStartIndex + numTokens - 1]);
                    for (let n = 0; n < rightPadding; ++n) {
                        appendToNGram(this.separator);
                        appendToNGram(this.rightPad);
                    }
                }
                else {
                    // If we don't have tokens, then the last item inserted into the nGram
                    // has been the separator from the left padding loop above. Hence,
                    // output right pad and separator and make sure to finish with a
                    // padding, not a separator.
                    for (let n = 0; n < rightPadding - 1; ++n) {
                        appendToNGram(this.rightPad);
                        appendToNGram(this.separator);
                    }
                    appendToNGram(this.rightPad);
                }
            }
        }
        // Data and splits together form the definition of the ragged tensor,
        // where data is 1 dimensional and contains the values of the tensor
        // and splits denotes the indices at which each row starts.
        /**
         * Computes the ngrams for a ragged tensor (`data`, `splits`).
         * @returns A pair [nGrams, nGramsSplits] forming the output ragged tensor.
         * @throws If `splits` does not start at 0, is not monotonically
         *     non-decreasing, or does not end at `data.length`.
         */
        compute(data, splits) {
            // Validate that the splits are valid indices into data, only if there are
            // splits specified.
            const inputDataSize = data.length;
            const splitsSize = splits.length;
            if (splitsSize > 0) {
                let prevSplit = splits[0];
                if (prevSplit !== 0) {
                    throw new Error(`First split value must be 0, got ${prevSplit}`);
                }
                for (let i = 1; i < splitsSize; ++i) {
                    let validSplits = splits[i] >= prevSplit;
                    validSplits = validSplits && (splits[i] <= inputDataSize);
                    if (!validSplits) {
                        throw new Error(`Invalid split value ${splits[i]}, must be in [${prevSplit}, ${inputDataSize}]`);
                    }
                    prevSplit = splits[i];
                }
                if (prevSplit !== inputDataSize) {
                    throw new Error(`Last split value must be data size. Expected ${inputDataSize}, got ${prevSplit}`);
                }
            }
            const numBatchItems = splitsSize - 1;
            const nGramsSplits = getArrayFromDType('int32', splitsSize);
            // If there is no data or size, return an empty ragged tensor.
            if (inputDataSize === 0 || splitsSize === 0) {
                const empty = new Array(inputDataSize);
                for (let i = 0; i <= numBatchItems; ++i) {
                    nGramsSplits[i] = 0;
                }
                return [empty, nGramsSplits];
            }
            nGramsSplits[0] = 0;
            // First pass: compute the output split offsets (ngram counts per row).
            for (let i = 1; i <= numBatchItems; ++i) {
                const length = splits[i] - splits[i - 1];
                let numNGrams = 0;
                this.nGramWidths.forEach((nGramWidth) => {
                    numNGrams += this.getNumNGrams(length, nGramWidth);
                });
                if (this.preserveShort && length > 0 && numNGrams === 0) {
                    numNGrams = 1;
                }
                nGramsSplits[i] = nGramsSplits[i - 1] + numNGrams;
            }
            const nGrams = new Array(nGramsSplits[numBatchItems]);
            // Second pass: materialize the ngrams for every row and width.
            for (let i = 0; i < numBatchItems; ++i) {
                const splitIndex = splits[i];
                let outputStartIdx = nGramsSplits[i];
                this.nGramWidths.forEach((nGramWidth) => {
                    const length = splits[i + 1] - splits[i];
                    const numNGrams = this.getNumNGrams(length, nGramWidth);
                    this.createNGrams(data, splitIndex, nGrams, outputStartIdx, numNGrams, nGramWidth);
                    outputStartIdx += numNGrams;
                });
                // If we're preserving short sequences, check to see if no sequence was
                // generated by comparing the current output start idx to the original
                // one (nGramSplitsdata). If no ngrams were generated, then they will
                // be equal (since we increment outputStartIdx by numNGrams every
                // time we create a set of ngrams.)
                if (this.preserveShort && outputStartIdx === nGramsSplits[i]) {
                    const dataLength = splits[i + 1] - splits[i];
                    // One legitimate reason to not have any ngrams when this.preserveShort
                    // is true is if the sequence itself is empty. In that case, move on.
                    if (dataLength === 0) {
                        continue;
                    }
                    // We don't have to worry about dynamic padding sizes here: if padding
                    // was dynamic, every sequence would have had sufficient padding to
                    // generate at least one nGram.
                    const nGramWidth = dataLength + 2 * this.padWidth;
                    const numNGrams = 1;
                    this.createNGrams(data, splitIndex, nGrams, outputStartIdx, numNGrams, nGramWidth);
                }
            }
            return [nGrams, nGramsSplits];
        }
    }
70849 function stringNGramsImpl(data, dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) {
70850 return new StringNGramsOp(separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences)
70851 .compute(data, dataSplits);
70852 }
70853
70854 /**
70855 * @license
70856 * Copyright 2021 Google LLC. All Rights Reserved.
70857 * Licensed under the Apache License, Version 2.0 (the "License");
70858 * you may not use this file except in compliance with the License.
70859 * You may obtain a copy of the License at
70860 *
70861 * http://www.apache.org/licenses/LICENSE-2.0
70862 *
70863 * Unless required by applicable law or agreed to in writing, software
70864 * distributed under the License is distributed on an "AS IS" BASIS,
70865 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70866 * See the License for the specific language governing permissions and
70867 * limitations under the License.
70868 * =============================================================================
70869 */
70870 function split$3(str, delimiters, skipEmpty, result) {
70871 if (!str.length) {
70872 return;
70873 }
70874 // When the delimiter is empty, the input is split into individual characters.
70875 if (delimiters.length === 0) {
70876 for (let i = 0; i < str.length; ++i) {
70877 result.push(str.subarray(i, i + 1));
70878 }
70879 return;
70880 }
70881 // When there is one delimiter, the input is split only at that delimiter.
70882 if (delimiters.length === 1) {
70883 const delimiter = delimiters[0];
70884 let f = str.indexOf(delimiter);
70885 while (f !== -1) {
70886 const token = str.subarray(0, f);
70887 if (!skipEmpty || token.length !== 0) {
70888 result.push(token);
70889 }
70890 str = str.subarray(f + 1);
70891 f = str.indexOf(delimiter);
70892 }
70893 if (!skipEmpty || str.length !== 0) {
70894 result.push(str);
70895 }
70896 return;
70897 }
70898 // When there are multiple delimiters, the input is split at every instance
70899 // one of the delimiters appears.
70900 let tokenStart = 0;
70901 for (let i = 0; i < str.length + 1; i++) {
70902 if ((i === str.length) || (delimiters.indexOf(str[i]) !== -1)) {
70903 const token = str.subarray(tokenStart, i);
70904 if (!skipEmpty || token.length !== 0) {
70905 result.push(token);
70906 }
70907 tokenStart = i + 1;
70908 }
70909 }
70910 }
70911 function stringSplitImpl(input, delimiter, skipEmpty) {
70912 const batchSize = input.length;
70913 // Empty delimiter means split the input character by character.
70914 const tokens = [];
70915 let outputSize = 0;
70916 let maxNumEntries = 0;
70917 const numIndices = new Array(batchSize);
70918 for (let i = 0; i < batchSize; ++i) {
70919 const prevTokensLength = tokens.length;
70920 split$3(input[i], delimiter, skipEmpty, tokens);
70921 const nEntries = tokens.length - prevTokensLength;
70922 numIndices[i] = nEntries;
70923 outputSize += nEntries;
70924 maxNumEntries = Math.max(maxNumEntries, nEntries);
70925 }
70926 const indices = getArrayFromDType('int32', outputSize * 2);
70927 const values = new Array(outputSize);
70928 const shape = [batchSize, maxNumEntries];
70929 let c = 0;
70930 for (let i = 0; i < batchSize; ++i) {
70931 for (let j = 0; j < numIndices[i]; ++j) {
70932 // indices is a 2d tensor with shape of [outputSize, 2]
70933 indices[c * 2] = i;
70934 indices[c * 2 + 1] = j;
70935 values[c] = tokens[c];
70936 ++c;
70937 }
70938 }
70939 return [indices, values, shape];
70940 }
70941
70942 /**
70943 * @license
70944 * Copyright 2021 Google LLC. All Rights Reserved.
70945 * Licensed under the Apache License, Version 2.0 (the "License");
70946 * you may not use this file except in compliance with the License.
70947 * You may obtain a copy of the License at
70948 *
70949 * http://www.apache.org/licenses/LICENSE-2.0
70950 *
70951 * Unless required by applicable law or agreed to in writing, software
70952 * distributed under the License is distributed on an "AS IS" BASIS,
70953 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70954 * See the License for the specific language governing permissions and
70955 * limitations under the License.
70956 * =============================================================================
70957 */
70958 function stringToHashBucketFastImpl(input, numBuckets) {
70959 const output = getArrayFromDType('int32', input.length);
70960 for (let i = 0; i < input.length; ++i) {
70961 output[i] =
70962 fingerPrint64(input[i]).modulo(numBuckets).getLowBitsUnsigned();
70963 }
70964 return output;
70965 }
70966
70967 /**
70968 * @license
70969 * Copyright 2020 Google LLC. All Rights Reserved.
70970 * Licensed under the Apache License, Version 2.0 (the "License");
70971 * you may not use this file except in compliance with the License.
70972 * You may obtain a copy of the License at
70973 *
70974 * http://www.apache.org/licenses/LICENSE-2.0
70975 *
70976 * Unless required by applicable law or agreed to in writing, software
70977 * distributed under the License is distributed on an "AS IS" BASIS,
70978 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70979 * See the License for the specific language governing permissions and
70980 * limitations under the License.
70981 * =============================================================================
70982 */
    // Element-wise subtraction. The complex variant subtracts the real and
    // imaginary components independently.
    const subImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => aValue - bValue));
    const subComplexImpl = createComplexBinaryKernelImpl(((aReal, aImag, bReal, bImag) => {
        return { real: aReal - bReal, imag: aImag - bImag };
    }));
    const sub$1 = binaryKernelFunc(Sub, subImpl, subComplexImpl);
    // Kernel config registering `sub$1` for the cpu backend.
    const subConfig = {
        kernelName: Sub,
        backendName: 'cpu',
        kernelFunc: sub$1
    };
70993
70994 /**
70995 * @license
70996 * Copyright 2019 Google LLC. All Rights Reserved.
70997 * Licensed under the Apache License, Version 2.0 (the "License");
70998 * you may not use this file except in compliance with the License.
70999 * You may obtain a copy of the License at
71000 *
71001 * http://www.apache.org/licenses/LICENSE-2.0
71002 *
71003 * Unless required by applicable law or agreed to in writing, software
71004 * distributed under the License is distributed on an "AS IS" BASIS,
71005 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71006 * See the License for the specific language governing permissions and
71007 * limitations under the License.
71008 * =============================================================================
71009 */
71010 /**
71011 * An implementation of the tile kernel shared between webgl and cpu for string
71012 * tensors only.
71013 */
71014 function tileImpl(xBuf, reps) {
71015 const newShape = new Array(xBuf.rank);
71016 for (let i = 0; i < newShape.length; i++) {
71017 newShape[i] = xBuf.shape[i] * reps[i];
71018 }
71019 const result = buffer(newShape, xBuf.dtype);
71020 for (let i = 0; i < result.values.length; ++i) {
71021 const newLoc = result.indexToLoc(i);
71022 const originalLoc = new Array(xBuf.rank);
71023 for (let j = 0; j < originalLoc.length; j++) {
71024 originalLoc[j] = newLoc[j] % xBuf.shape[j];
71025 }
71026 const originalIndex = xBuf.locToIndex(originalLoc);
71027 result.values[i] = xBuf.values[originalIndex];
71028 }
71029 return result;
71030 }
71031
71032 /**
71033 * @license
71034 * Copyright 2020 Google LLC. All Rights Reserved.
71035 * Licensed under the Apache License, Version 2.0 (the "License");
71036 * you may not use this file except in compliance with the License.
71037 * You may obtain a copy of the License at
71038 *
71039 * http://www.apache.org/licenses/LICENSE-2.0
71040 *
71041 * Unless required by applicable law or agreed to in writing, software
71042 * distributed under the License is distributed on an "AS IS" BASIS,
71043 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71044 * See the License for the specific language governing permissions and
71045 * limitations under the License.
71046 * =============================================================================
71047 */
71048 const comparePair = (a, b) => {
71049 const valueDiff = b.value - a.value;
71050 return valueDiff === 0 ? a.index - b.index : valueDiff;
71051 };
71052 /**
71053 * Partitions array where all elements smaller than the (k+1) smallest element
71054 * are found to the left of it, and all larger to the right of it.
71055 * Based on the Floyd-Rivest Algorithm, ref:
71056 * https://en.wikipedia.org/wiki/Floyd%E2%80%93Rivest_algorithm
71057 * @param array: Array to partition
71058 * @param left: Left index for the interval
71059 * @param right: Right index for the interval
71060 * @param k: Desired index value, where array[k] is the (k+1)th smallest element
71061 * when left = 0
71062 */
    // In-place Floyd-Rivest selection over `array` using `comparePair` order
    // (descending by value): after this call array[k] holds the element that
    // would be at position k in a full sort, with all "smaller" (per
    // comparePair) elements to its left and "larger" to its right.
    function select(array, k, left = 0, right = array.length - 1) {
        while (right > left) {
            // Use select recursively to sample a smaller set of size s
            // the arbitrary constants 600 and 0.5 are used in the original
            // version to minimize execution time.
            if (right - left > 600) {
                const n = right - left + 1;
                const i = k - left + 1;
                const z = Math.log(n);
                const s = 0.5 * Math.exp(2 * z / 3);
                const sd = 0.5 * Math.sqrt(z * s * (n - s) / n) * Math.sign(i - n / 2);
                // Narrow the working interval around k before partitioning.
                const newLeft = Math.max(left, Math.floor(k - i * s / n + sd));
                const newRight = Math.min(right, Math.floor(k + (n - i) * s / n + sd));
                select(array, k, newLeft, newRight);
            }
            // partition the elements between left and right around t
            const t = array[k];
            let i = left;
            let j = right;
            swap(array, left, k);
            if (comparePair(array[right], t) > 0) {
                swap(array, left, right);
            }
            while (i < j) {
                swap(array, i, j);
                i++;
                j--;
                while (comparePair(array[i], t) < 0) {
                    i = i + 1;
                }
                while (comparePair(array[j], t) > 0) {
                    j = j - 1;
                }
            }
            if (comparePair(array[left], t) === 0) {
                swap(array, left, j);
            }
            else {
                j = j + 1;
                swap(array, j, right);
            }
            // Adjust left and right towards the boundaries of the subset
            // containing the (k - left + 1)th smallest element.
            if (j <= k) {
                left = j + 1;
            }
            if (k <= j) {
                right = j - 1;
            }
        }
    }
71114 function topKImpl(x, xShape, xDtype, k, sorted) {
71115 // Reshape into a 2d tensor [batch, lastDim] and compute topk along lastDim.
71116 const lastDim = xShape[xShape.length - 1];
71117 const [batch, size] = [x.length / lastDim, lastDim];
71118 const allTopKVals = getTypedArrayFromDType(xDtype, batch * k);
71119 const allTopKIndices = getTypedArrayFromDType('int32', batch * k);
71120 for (let b = 0; b < batch; b++) {
71121 const offset = b * size;
71122 const vals = x.subarray(offset, offset + size);
71123 let valAndInd = new Array(vals.length);
71124 vals.forEach((value, index) => valAndInd[index] = { value, index });
71125 if (k < valAndInd.length) {
71126 select(valAndInd, k);
71127 valAndInd = valAndInd.slice(0, k);
71128 }
71129 if (sorted) {
71130 valAndInd.sort(comparePair);
71131 }
71132 const outOffset = b * k;
71133 const topKVals = allTopKVals.subarray(outOffset, outOffset + k);
71134 const topKIndices = allTopKIndices.subarray(outOffset, outOffset + k);
71135 for (let i = 0; i < k; i++) {
71136 topKVals[i] = valAndInd[i].value;
71137 topKIndices[i] = valAndInd[i].index;
71138 }
71139 }
71140 // Reshape back to the original input shape, except that the last
71141 // dimension is k.
71142 const outputShape = xShape.slice();
71143 outputShape[outputShape.length - 1] = k;
71144 return [
71145 buffer(outputShape, xDtype, allTopKVals),
71146 buffer(outputShape, 'int32', allTopKIndices)
71147 ];
71148 }
71149
71150 /**
71151 * @license
71152 * Copyright 2020 Google LLC. All Rights Reserved.
71153 * Licensed under the Apache License, Version 2.0 (the "License");
71154 * you may not use this file except in compliance with the License.
71155 * You may obtain a copy of the License at
71156 *
71157 * http://www.apache.org/licenses/LICENSE-2.0
71158 *
71159 * Unless required by applicable law or agreed to in writing, software
71160 * distributed under the License is distributed on an "AS IS" BASIS,
71161 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71162 * See the License for the specific language governing permissions and
71163 * limitations under the License.
71164 * =============================================================================
71165 */
71166 function uniqueImpl(values, axis, shape, dtype) {
71167 // Normalize and validate axis.
71168 const $axis = parseAxisParam(axis, shape)[0];
71169 // Calculate the new shape that is suitable for extracting data along the
71170 // given axis.
71171 //
71172 // The rank is 3.
71173 // The size of the 1st dimension is the size of all the axes < the given axis.
71174 // The size of the 2nd dimension is the same as the size of the given axis.
71175 // The size of the 3rd dimension is the size of all the axes > the given axis.
71176 //
71177 // For example, for a 4D tensor with shape=[2, 3, 5, 4] and axis=2, the
71178 // newShape would be: [2*3, 5, 4].
71179 //
71180 // Note that this is not the final output shape. This will be the shape for an
71181 // intermediate TensorBuffer (see inputBuffer below) to allow us to extract
71182 // values along the given axis. To demonstrate how it works, consider the
71183 // following example:
71184 //
71185 // Input: a 3D tensor, with shape [1, 2, 3]
71186 // [
71187 // [
71188 // [1,2,3],
71189 // [4,5,6]
71190 // ]
71191 // ]
71192 // Axis: 2 (the last axis).
71193 // Along axis 2, we expect to extract 3 tensors: [1,4], [2,5], [3,6].
71194 //
71195 // For this example, newShape would be: [2, 3, 1], where 2 is calculated from
71196 // 1*2. The re-shaped data would look like:
71197 //
71198 // [
71199 // [
71200 // [1], [2], [3]
71201 // ],
71202 // [
71203 // [4], [5], [6]
71204 // ]
71205 // ]
71206 //
71207 // Then, we can construct a 3-level nested loop by the following dimension
71208 // order to extract the values along the axis (dimension1):
71209 // i: dimension1 // 0,1,2 (newShape[1])
71210 // m: dimension0 // 0,1 (newShape[0])
71211 // n: dimension2 // 0 (newShape[2])
71212 //
71213 // m, i, n
71214 // ---------
71215 // Iteration 0: data at [0, 0, 0] => "1"
71216 // Iteration 1: data at [1, 0, 0] => "4"
71217 // We got [1,4].
71218 // Iteration 2: data at [0, 1, 0] => "2"
71219 // Iteration 3: data at [1, 1, 0] => "5"
71220 // We got [2,5].
71221 // Iteration 4: data at [0, 2, 0] => "3"
71222 // Iteration 5: data at [1, 2, 0] => "6"
71223 // We got [3,6].
71224 const newShape = [1, shape[0], 1];
71225 for (let i = 0; i < $axis; i++) {
71226 newShape[0] *= shape[i];
71227 }
71228 newShape[1] = shape[$axis];
71229 for (let i = $axis + 1; i < shape.length; i++) {
71230 newShape[2] *= shape[i];
71231 }
71232 // A map from unique elements (their string representations) to their values
71233 // in "indices" (below).
71234 const uniqueElements = {};
71235 // The indices of each unique element in the original tensor along the given
71236 // axis. It is 1D and has the same size as the given axis.
71237 const indices = new Int32Array(shape[$axis]);
71238 // Create a buffer so we can easily extract value at a given location.
71239 const inputBuffer = new TensorBuffer(newShape, dtype, values);
71240 // The indices along the given axis that have unique elements. This is a
71241 // de-duped version of "indices" above.
71242 const uniqueIndices = [];
71243 const is1DTensor = newShape[0] === 1 && newShape[2] === 1;
71244 for (let i = 0; i < shape[$axis]; i++) {
71245 // Extract values along the axis.
71246 let element;
71247 if (is1DTensor) {
71248 // Fast path for 1D tensor input.
71249 element = values[i].toString();
71250 }
71251 else {
71252 const axisValues = [];
71253 for (let m = 0; m < newShape[0]; m++) {
71254 for (let n = 0; n < newShape[2]; n++) {
71255 axisValues.push(inputBuffer.get(m, i, n));
71256 }
71257 }
71258 element = axisValues.join(',');
71259 }
71260 // Dedup and update various indices.
71261 if (uniqueElements[element] !== undefined) {
71262 indices[i] = uniqueElements[element];
71263 }
71264 else {
71265 const uniqueIndex = Object.keys(uniqueElements).length;
71266 uniqueElements[element] = uniqueIndex;
71267 indices[i] = uniqueIndex;
71268 uniqueIndices.push(i);
71269 }
71270 }
71271 // Now we know where each of the unique elements are located along the axis
71272 // (uniqueIndices). Extract them from input buffer and store them in the
71273 // output buffer.
71274 const outputTmpShape = newShape.slice();
71275 outputTmpShape[1] = Object.keys(uniqueElements).length;
71276 const outputBuffer = new TensorBuffer(outputTmpShape, dtype);
71277 uniqueIndices.forEach((uniqueElementIndex, i) => {
71278 for (let m = 0; m < newShape[0]; m++) {
71279 for (let n = 0; n < newShape[2]; n++) {
71280 outputBuffer.set(inputBuffer.get(m, uniqueElementIndex, n), m, i, n);
71281 }
71282 }
71283 });
71284 // The output shape can be calculated from the input shape with the size of
71285 // the given axis replaced by the number of unique elements along that axis.
71286 const outputShape = shape.slice();
71287 outputShape[$axis] = outputTmpShape[1];
71288 return {
71289 outputValues: outputBuffer.values,
71290 outputShape,
71291 indices,
71292 };
71293 }
71294
71295 /**
71296 * @license
71297 * Copyright 2020 Google LLC. All Rights Reserved.
71298 * Licensed under the Apache License, Version 2.0 (the "License");
71299 * you may not use this file except in compliance with the License.
71300 * You may obtain a copy of the License at
71301 *
71302 * http://www.apache.org/licenses/LICENSE-2.0
71303 *
71304 * Unless required by applicable law or agreed to in writing, software
71305 * distributed under the License is distributed on an "AS IS" BASIS,
71306 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71307 * See the License for the specific language governing permissions and
71308 * limitations under the License.
71309 * =============================================================================
71310 */
71311
    // Frozen namespace object collecting the shared buffer-level kernel
    // implementations defined above, exported as a unit.
    var shared = /*#__PURE__*/Object.freeze({
        __proto__: null,
        simpleAbsImpl: simpleAbsImpl,
        addImpl: addImpl,
        bincountImpl: bincountImpl,
        bincountReduceImpl: bincountReduceImpl,
        ceilImpl: ceilImpl,
        concatImpl: concatImpl,
        equalImpl: equalImpl,
        expImpl: expImpl,
        expm1Impl: expm1Impl,
        floorImpl: floorImpl,
        gatherNdImpl: gatherNdImpl,
        gatherV2Impl: gatherV2Impl,
        greaterImpl: greaterImpl,
        greaterEqualImpl: greaterEqualImpl,
        lessImpl: lessImpl,
        lessEqualImpl: lessEqualImpl,
        linSpaceImpl: linSpaceImpl,
        logImpl: logImpl,
        maxImpl: maxImpl,
        maximumImpl: maximumImpl,
        minimumImpl: minimumImpl,
        multiplyImpl: multiplyImpl,
        negImpl: negImpl,
        notEqualImpl: notEqualImpl,
        prodImpl: prodImpl,
        rangeImpl: rangeImpl,
        rsqrtImpl: rsqrtImpl,
        scatterImpl: scatterImpl,
        sigmoidImpl: sigmoidImpl,
        sliceImpl: sliceImpl,
        sparseFillEmptyRowsImpl: sparseFillEmptyRowsImpl,
        sparseReshapeImpl: sparseReshapeImpl,
        sparseSegmentReductionImpl: sparseSegmentReductionImpl,
        sqrtImpl: sqrtImpl,
        squaredDifferenceImpl: squaredDifferenceImpl,
        stridedSliceImpl: stridedSliceImpl,
        stringNGramsImpl: stringNGramsImpl,
        stringSplitImpl: stringSplitImpl,
        stringToHashBucketFastImpl: stringToHashBucketFastImpl,
        subImpl: subImpl,
        tileImpl: tileImpl,
        topKImpl: topKImpl,
        transposeImpl: transposeImpl,
        uniqueImpl: uniqueImpl
    });
71359
    /** @license See the LICENSE file. */
    // This code is auto-generated, do not modify this file!
    // Version string of this generated cpu-backend bundle.
    const version$4 = '3.18.0';
71363
71364 /**
71365 * @license
71366 * Copyright 2020 Google LLC. All Rights Reserved.
71367 * Licensed under the Apache License, Version 2.0 (the "License");
71368 * you may not use this file except in compliance with the License.
71369 * You may obtain a copy of the License at
71370 *
71371 * http://www.apache.org/licenses/LICENSE-2.0
71372 *
71373 * Unless required by applicable law or agreed to in writing, software
71374 * distributed under the License is distributed on an "AS IS" BASIS,
71375 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71376 * See the License for the specific language governing permissions and
71377 * limitations under the License.
71378 * =============================================================================
71379 */
    // Side effects for default initialization of MathBackendCPU
    // Registers the default CPU backend under the name 'cpu' with priority 1.
    registerBackend('cpu', () => new MathBackendCPU(), 1 /* priority */);
71382
71383 /**
71384 * @license
71385 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
     * you may not use this file except in compliance with the License.
     * You may obtain a copy of the License at
     *
     * http://www.apache.org/licenses/LICENSE-2.0
     *
     * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71395 * See the License for the specific language governing permissions and
71396 * limitations under the License.
71397 * =============================================================================
71398 */
71399 const elu$3 = unaryKernelFunc(Elu, (xi) => xi >= 0 ? xi : (Math.exp(xi) - 1));
71400 const eluConfig = {
71401 kernelName: Elu,
71402 backendName: 'cpu',
71403 kernelFunc: elu$3,
71404 };
71405
71406 /**
71407 * @license
71408 * Copyright 2020 Google LLC. All Rights Reserved.
71409 * Licensed under the Apache License, Version 2.0 (the "License");
71410 * you may not use this file except in compliance with the License.
71411 * You may obtain a copy of the License at
71412 *
71413 * http://www.apache.org/licenses/LICENSE-2.0
71414 *
71415 * Unless required by applicable law or agreed to in writing, software
71416 * distributed under the License is distributed on an "AS IS" BASIS,
71417 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71418 * See the License for the specific language governing permissions and
71419 * limitations under the License.
71420 * =============================================================================
71421 */
71422 function leakyRelu$1(args) {
71423 const { inputs, backend, attrs } = args;
71424 const { x } = inputs;
71425 const { alpha } = attrs;
71426 assertNotComplex([x], 'leakyRelu');
71427 const xSize = sizeFromShape(x.shape);
71428 const xVals = backend.data.get(x.dataId).values;
71429 const outVals = getTypedArrayFromDType('float32', xSize);
71430 for (let i = 0; i < xVals.length; i++) {
71431 outVals[i] = xVals[i] < 0 ? alpha * xVals[i] : xVals[i];
71432 }
71433 return backend.makeTensorInfo(x.shape, 'float32', outVals);
71434 }
71435 const leakyReluConfig = {
71436 kernelName: LeakyRelu,
71437 backendName: 'cpu',
71438 kernelFunc: leakyRelu$1
71439 };
71440
71441 /**
71442 * @license
71443 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
     * you may not use this file except in compliance with the License.
     * You may obtain a copy of the License at
     *
     * http://www.apache.org/licenses/LICENSE-2.0
     *
     * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71453 * See the License for the specific language governing permissions and
71454 * limitations under the License.
71455 * =============================================================================
71456 */
71457 const preluImpl = createSimpleBinaryKernelImpl((xValue, aValue) => xValue < 0 ? aValue * xValue : xValue);
71458 function prelu$2(args) {
71459 const { inputs, backend } = args;
71460 const { x, alpha } = inputs;
71461 assertNotComplex([x, alpha], 'prelu');
71462 const aVals = backend.data.get(x.dataId).values;
71463 const bVals = backend.data.get(alpha.dataId).values;
71464 const [resultData, resultShape] = preluImpl(x.shape, alpha.shape, aVals, bVals, 'float32');
71465 return backend.makeTensorInfo(resultShape, 'float32', resultData);
71466 }
71467 const preluConfig = {
71468 kernelName: Prelu,
71469 backendName: 'cpu',
71470 kernelFunc: prelu$2,
71471 };
71472
71473 /**
71474 * @license
71475 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
     * you may not use this file except in compliance with the License.
     * You may obtain a copy of the License at
     *
     * http://www.apache.org/licenses/LICENSE-2.0
     *
     * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71485 * See the License for the specific language governing permissions and
71486 * limitations under the License.
71487 * =============================================================================
71488 */
71489 const relu$1 = unaryKernelFunc(Relu, (xi) => Math.max(0, xi));
71490 const reluConfig = {
71491 kernelName: Relu,
71492 backendName: 'cpu',
71493 kernelFunc: relu$1,
71494 };
71495
71496 /**
71497 * @license
71498 * Copyright 2020 Google LLC. All Rights Reserved.
71499 * Licensed under the Apache License, Version 2.0 (the License);
71500 * you may not use this file except in compliance with the License.
71501 * You may obtain a copy of the License at
71502 *
71503 * http://www.apache.org/licenses/LICENSE-2.0
71504 *
71505 * Unless required by applicable law or agreed to in writing, software
71506 * distributed under the License is distributed on an AS IS BASIS,
71507 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71508 * See the License for the specific language governing permissions and
71509 * limitations under the License.
71510 * =============================================================================
71511 */
71512 const relu6$1 = unaryKernelFunc(Relu6, (xi) => Math.min(Math.max(0, xi), 6));
71513 const relu6Config = {
71514 kernelName: Relu6,
71515 backendName: 'cpu',
71516 kernelFunc: relu6$1,
71517 };
71518
71519 /**
71520 * @license
71521 * Copyright 2020 Google LLC. All Rights Reserved.
71522 * Licensed under the Apache License, Version 2.0 (the "License");
71523 * you may not use this file except in compliance with the License.
71524 * You may obtain a copy of the License at
71525 *
71526 * http://www.apache.org/licenses/LICENSE-2.0
71527 *
71528 * Unless required by applicable law or agreed to in writing, software
71529 * distributed under the License is distributed on an "AS IS" BASIS,
71530 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71531 * See the License for the specific language governing permissions and
71532 * limitations under the License.
71533 * =============================================================================
71534 */
71535 function applyActivation$1(backend, x, activation, preluActivationWeights, leakyreluAlpha) {
71536 if (activation === 'linear') {
71537 return identity$1({ inputs: { x }, backend });
71538 }
71539 else if (activation === 'relu') {
71540 return relu$1({ inputs: { x }, backend });
71541 }
71542 else if (activation === 'elu') {
71543 return elu$3({ inputs: { x }, backend });
71544 }
71545 else if (activation === 'relu6') {
71546 return relu6$1({ inputs: { x }, backend });
71547 }
71548 else if (activation === 'prelu') {
71549 return prelu$2({ inputs: { x, alpha: preluActivationWeights }, backend });
71550 }
71551 else if (activation === 'leakyrelu') {
71552 return leakyRelu$1({ inputs: { x }, backend, attrs: { alpha: leakyreluAlpha } });
71553 }
71554 else if (activation === 'sigmoid') {
71555 return sigmoid$1({ inputs: { x }, backend });
71556 }
71557 throw new Error(`Activation ${activation} has not been implemented for the CPU backend.`);
71558 }
71559
71560 /**
71561 * @license
71562 * Copyright 2020 Google LLC. All Rights Reserved.
71563 * Licensed under the Apache License, Version 2.0 (the "License");
71564 * you may not use this file except in compliance with the License.
71565 * You may obtain a copy of the License at
71566 *
71567 * http://www.apache.org/licenses/LICENSE-2.0
71568 *
71569 * Unless required by applicable law or agreed to in writing, software
71570 * distributed under the License is distributed on an "AS IS" BASIS,
71571 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71572 * See the License for the specific language governing permissions and
71573 * limitations under the License.
71574 * =============================================================================
71575 */
71576 function reshape$2(args) {
71577 const { inputs, backend, attrs } = args;
71578 const { x } = inputs;
71579 const { shape } = attrs;
71580 const xSize = sizeFromShape(x.shape);
71581 const $shape = inferFromImplicitShape(shape, xSize);
71582 const $xSize = sizeFromShape($shape);
71583 assert(xSize === $xSize, () => `The new shape (${$shape}) has ${$xSize} elements and the old ` +
71584 `shape (${x.shape}) has ${xSize} elements. The new shape and old ` +
71585 `shape must have the same number of elements.`);
71586 backend.incRef(x.dataId);
71587 const xData = backend.data.get(x.dataId);
71588 if (xData.complexTensorInfos != null) {
71589 const real = xData.complexTensorInfos.real;
71590 const imag = xData.complexTensorInfos.imag;
71591 real.shape = $shape;
71592 imag.shape = $shape;
71593 }
71594 return { dataId: x.dataId, shape: $shape, dtype: x.dtype };
71595 }
71596 const reshapeConfig = {
71597 kernelName: Reshape,
71598 backendName: 'cpu',
71599 kernelFunc: reshape$2
71600 };
71601
71602 /**
71603 * @license
71604 * Copyright 2020 Google LLC. All Rights Reserved.
71605 * Licensed under the Apache License, Version 2.0 (the License);
71606 * you may not use this file except in compliance with the License.
71607 * You may obtain a copy of the License at
71608 *
71609 * http://www.apache.org/licenses/LICENSE-2.0
71610 *
71611 * Unless required by applicable law or agreed to in writing, software
71612 * distributed under the License is distributed on an AS IS BASIS,
71613 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71614 * See the License for the specific language governing permissions and
71615 * limitations under the License.
71616 * =============================================================================
71617 */
    /**
     * BatchMatMul kernel for the CPU backend.
     *
     * Computes a (possibly transposed) matrix multiply over the trailing two
     * dimensions of `a` and `b`; all leading dimensions are treated as batch
     * dimensions and broadcast against each other. Uses a cache-blocked
     * triple loop sized by `backend.blockSize`.
     */
    function batchMatMul(args) {
        const { inputs, backend, attrs } = args;
        const { a, b } = inputs;
        const { transposeA, transposeB } = attrs;
        assertNotComplex([a, b], 'matMul');
        const aRank = a.shape.length;
        const bRank = b.shape.length;
        // "Inner" dims are the contracted dims; "outer" dims become the output
        // rows/columns. Transposition swaps which trailing dim plays which role.
        const innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1];
        const innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2];
        const outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2];
        const outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1];
        const outerDimsA = a.shape.slice(0, -2);
        const outerDimsB = b.shape.slice(0, -2);
        const batchDimA = sizeFromShape(outerDimsA);
        const batchDimB = sizeFromShape(outerDimsB);
        // Batch dims must be broadcast-compatible; this throws otherwise.
        const outShapeOuterDims = assertAndGetBroadcastShape(a.shape.slice(0, -2), b.shape.slice(0, -2));
        const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
        assert(innerShapeA === innerShapeB, () => `Error in matMul: inner shapes (${innerShapeA}) and (` +
            `${innerShapeB}) of Tensors with shapes ${a.shape} and ` +
            `${b.shape} and transposeA=${transposeA}` +
            ` and transposeB=${transposeB} must match.`);
        const a3dShape = transposeA ? [batchDimA, innerShapeA, outerShapeA] :
            [batchDimA, outerShapeA, innerShapeA];
        const b3dShape = transposeB ? [batchDimB, outerShapeB, innerShapeB] :
            [batchDimB, innerShapeB, outerShapeB];
        // The rest of the implementation is designed to operate on rank-3 tensors
        const a3d = reshape$2({ inputs: { x: a }, backend, attrs: { shape: a3dShape } });
        const b3d = reshape$2({ inputs: { x: b }, backend, attrs: { shape: b3dShape } });
        const sharedDim = transposeA ? a3d.shape[1] : a3d.shape[2];
        const leftDim = transposeA ? a3d.shape[2] : a3d.shape[1];
        const rightDim = transposeB ? b3d.shape[1] : b3d.shape[2];
        const batchDim = Math.max(batchDimA, batchDimB);
        const a3dValues = backend.data.get(a3d.dataId).values;
        const b3dValues = backend.data.get(b3d.dataId).values;
        const a3dStrides = computeStrides(a3d.shape);
        const b3dStrides = computeStrides(b3d.shape);
        // Flat-index step sizes for walking a3d/b3d as (batch, outer, inner),
        // chosen so the transposed and untransposed cases share one loop body.
        const [aBatch, aOuterStep, aInnerStep] = transposeA ?
            [a3dStrides[0], 1, a3dStrides[1]] :
            [a3dStrides[0], a3dStrides[1], 1];
        const [bInnerStep, bOuterStep, bBatch] = transposeB ?
            [1, b3dStrides[1], b3dStrides[0]] :
            [b3dStrides[1], 1, b3dStrides[0]];
        const size = leftDim * rightDim;
        // `buffer` yields zero-initialized storage, which the `+=` below relies on.
        const result = buffer([batchDim, leftDim, rightDim], a3d.dtype);
        const resVals = result.values;
        const blockSize = backend.blockSize;
        // Cache-blocked matmul: iterate over (i, j, k) tiles of `blockSize`.
        for (let bi = 0; bi < batchDim; bi++) {
            for (let i0 = 0; i0 < leftDim; i0 += blockSize) {
                for (let j0 = 0; j0 < rightDim; j0 += blockSize) {
                    for (let k0 = 0; k0 < sharedDim; k0 += blockSize) {
                        // for when blockSize doesn't evenly divide the input
                        const iBlock = Math.min(i0 + blockSize, leftDim);
                        const jBlock = Math.min(j0 + blockSize, rightDim);
                        const kBlock = Math.min(k0 + blockSize, sharedDim);
                        for (let i = i0; i < iBlock; i++) {
                            for (let j = j0; j < jBlock; j++) {
                                let sum = 0.0;
                                for (let k = k0; k < kBlock; k++) {
                                    // Clamping by (batchDim - 1) broadcasts a
                                    // size-1 batch over the larger one.
                                    const batchOffsetA = Math.min(bi, batchDimA - 1) * aBatch;
                                    const batchOffsetB = Math.min(bi, batchDimB - 1) * bBatch;
                                    const aVal = a3dValues[batchOffsetA + i * aOuterStep + k * aInnerStep];
                                    const bVal = b3dValues[k * bInnerStep + j * bOuterStep + batchOffsetB];
                                    sum += aVal * bVal;
                                }
                                // Accumulate: partial sums from each k-tile add up here.
                                resVals[bi * size + (i * rightDim + j)] += sum;
                            }
                        }
                    }
                }
            }
        }
        backend.disposeIntermediateTensorInfo(a3d);
        backend.disposeIntermediateTensorInfo(b3d);
        // set correct shape on output.
        return backend.makeTensorInfo(outShape, result.dtype, result.values);
    }
    const batchMatMulConfig = {
        kernelName: BatchMatMul,
        backendName: 'cpu',
        kernelFunc: batchMatMul,
    };
71699
71700 /**
71701 * @license
71702 * Copyright 2020 Google LLC. All Rights Reserved.
71703 * Licensed under the Apache License, Version 2.0 (the License);
71704 * you may not use this file except in compliance with the License.
71705 * You may obtain a copy of the License at
71706 *
71707 * http://www.apache.org/licenses/LICENSE-2.0
71708 *
71709 * Unless required by applicable law or agreed to in writing, software
71710 * distributed under the License is distributed on an AS IS BASIS,
71711 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71712 * See the License for the specific language governing permissions and
71713 * limitations under the License.
71714 * =============================================================================
71715 */
71716 function _fusedMatMul(args) {
71717 const { inputs, backend, attrs } = args;
71718 const { a, b, bias, preluActivationWeights } = inputs;
71719 const { transposeA, transposeB, activation, leakyreluAlpha } = attrs;
71720 let current;
71721 let addRes;
71722 let activationRes;
71723 const intermediates = [];
71724 const matMulRes = batchMatMul({ inputs: { a, b }, attrs: { transposeA, transposeB }, backend });
71725 current = matMulRes;
71726 if (bias) {
71727 addRes = add$4({ inputs: { a: current, b: bias }, backend });
71728 intermediates.push(current);
71729 current = addRes;
71730 }
71731 if (activation) {
71732 activationRes = applyActivation$1(backend, current, activation, preluActivationWeights, leakyreluAlpha);
71733 intermediates.push(current);
71734 current = activationRes;
71735 }
71736 for (const i of intermediates) {
71737 backend.disposeIntermediateTensorInfo(i);
71738 }
71739 return current;
71740 }
71741 const _fusedMatMulConfig = {
71742 kernelName: _FusedMatMul,
71743 backendName: 'cpu',
71744 kernelFunc: _fusedMatMul,
71745 };
71746
71747 /**
71748 * @license
71749 * Copyright 2020 Google LLC. All Rights Reserved.
71750 * Licensed under the Apache License, Version 2.0 (the License);
71751 * you may not use this file except in compliance with the License.
71752 * You may obtain a copy of the License at
71753 *
71754 * http://www.apache.org/licenses/LICENSE-2.0
71755 *
71756 * Unless required by applicable law or agreed to in writing, software
71757 * distributed under the License is distributed on an AS IS BASIS,
71758 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71759 * See the License for the specific language governing permissions and
71760 * limitations under the License.
71761 * =============================================================================
71762 */
71763 const acos$1 = unaryKernelFunc(Acos, (xi) => Math.acos(xi));
71764 const acosConfig = {
71765 kernelName: Acos,
71766 backendName: 'cpu',
71767 kernelFunc: acos$1,
71768 };
71769
71770 /**
71771 * @license
71772 * Copyright 2020 Google LLC. All Rights Reserved.
71773 * Licensed under the Apache License, Version 2.0 (the License);
71774 * you may not use this file except in compliance with the License.
71775 * You may obtain a copy of the License at
71776 *
71777 * http://www.apache.org/licenses/LICENSE-2.0
71778 *
71779 * Unless required by applicable law or agreed to in writing, software
71780 * distributed under the License is distributed on an AS IS BASIS,
71781 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71782 * See the License for the specific language governing permissions and
71783 * limitations under the License.
71784 * =============================================================================
71785 */
71786 const acosh$1 = unaryKernelFunc(Acosh, (xi) => Math.acosh(xi));
71787 const acoshConfig = {
71788 kernelName: Acosh,
71789 backendName: 'cpu',
71790 kernelFunc: acosh$1,
71791 };
71792
71793 /**
71794 * @license
71795 * Copyright 2020 Google LLC. All Rights Reserved.
71796 * Licensed under the Apache License, Version 2.0 (the "License");
71797 * you may not use this file except in compliance with the License.
71798 * You may obtain a copy of the License at
71799 *
71800 * http://www.apache.org/licenses/LICENSE-2.0
71801 *
71802 * Unless required by applicable law or agreed to in writing, software
71803 * distributed under the License is distributed on an "AS IS" BASIS,
71804 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71805 * See the License for the specific language governing permissions and
71806 * limitations under the License.
71807 * =============================================================================
71808 */
71809 function addN$1(args) {
71810 const { inputs, backend } = args;
71811 const tensors = inputs;
71812 assertNotComplex(inputs, 'addN');
71813 const vals = tensors.map(t => backend.data.get(t.dataId).values);
71814 const outBuf = buffer(tensors[0].shape, tensors[0].dtype);
71815 const outVals = outBuf.values;
71816 for (let i = 0; i < tensors.length; i++) {
71817 const currVals = vals[i];
71818 for (let j = 0; j < outVals.length; j++) {
71819 outVals[j] += currVals[j];
71820 }
71821 }
71822 return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
71823 }
71824 const addNConfig = {
71825 kernelName: AddN,
71826 backendName: 'cpu',
71827 kernelFunc: addN$1
71828 };
71829
71830 /**
71831 * @license
71832 * Copyright 2020 Google LLC. All Rights Reserved.
71833 * Licensed under the Apache License, Version 2.0 (the "License");
71834 * you may not use this file except in compliance with the License.
71835 * You may obtain a copy of the License at
71836 *
71837 * http://www.apache.org/licenses/LICENSE-2.0
71838 *
71839 * Unless required by applicable law or agreed to in writing, software
71840 * distributed under the License is distributed on an "AS IS" BASIS,
71841 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71842 * See the License for the specific language governing permissions and
71843 * limitations under the License.
71844 * =============================================================================
71845 */
71846 function all$1(args) {
71847 const { inputs, backend, attrs } = args;
71848 const { x } = inputs;
71849 const { axis, keepDims } = attrs;
71850 assertNotComplex(x, 'all');
71851 const origAxes = parseAxisParam(axis, x.shape);
71852 let axes = origAxes;
71853 const permutedAxes = getAxesPermutation(axes, x.shape.length);
71854 let $x = x;
71855 if (permutedAxes != null) {
71856 $x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
71857 axes = getInnerMostAxes(axes.length, x.shape.length);
71858 }
71859 assertAxesAreInnerMostDims('all', axes, $x.shape.length);
71860 const [outShape, reduceShape] = computeOutAndReduceShapes($x.shape, axes);
71861 const reduceSize = sizeFromShape(reduceShape);
71862 const vals = makeZerosTypedArray(sizeFromShape(outShape), $x.dtype);
71863 const aVals = backend.data.get($x.dataId).values;
71864 for (let i = 0; i < vals.length; ++i) {
71865 const offset = i * reduceSize;
71866 let all = aVals[offset];
71867 for (let j = 0; j < reduceSize; ++j) {
71868 const value = aVals[offset + j];
71869 all = all && value;
71870 }
71871 vals[i] = all;
71872 }
71873 if (permutedAxes != null) {
71874 backend.disposeIntermediateTensorInfo($x);
71875 }
71876 const result = backend.makeTensorInfo(outShape, $x.dtype, vals);
71877 if (keepDims) {
71878 const expandedShape = expandShapeToKeepDim(outShape, origAxes);
71879 const reshapedResult = reshape$2({ inputs: { x: result }, backend, attrs: { shape: expandedShape } });
71880 backend.disposeIntermediateTensorInfo(result);
71881 return reshapedResult;
71882 }
71883 return result;
71884 }
71885 const allConfig = {
71886 kernelName: All,
71887 backendName: 'cpu',
71888 kernelFunc: all$1
71889 };
71890
71891 /**
71892 * @license
71893 * Copyright 2020 Google LLC. All Rights Reserved.
71894 * Licensed under the Apache License, Version 2.0 (the "License");
71895 * you may not use this file except in compliance with the License.
71896 * You may obtain a copy of the License at
71897 *
71898 * http://www.apache.org/licenses/LICENSE-2.0
71899 *
71900 * Unless required by applicable law or agreed to in writing, software
71901 * distributed under the License is distributed on an "AS IS" BASIS,
71902 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71903 * See the License for the specific language governing permissions and
71904 * limitations under the License.
71905 * =============================================================================
71906 */
71907 function any$1(args) {
71908 const { inputs, backend, attrs } = args;
71909 const { x } = inputs;
71910 const { axis, keepDims } = attrs;
71911 assertNotComplex(x, 'any');
71912 const origAxes = parseAxisParam(axis, x.shape);
71913 let axes = origAxes;
71914 const permutedAxes = getAxesPermutation(axes, x.shape.length);
71915 let $x = x;
71916 if (permutedAxes != null) {
71917 $x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
71918 axes = getInnerMostAxes(axes.length, x.shape.length);
71919 }
71920 assertAxesAreInnerMostDims('any', axes, $x.shape.length);
71921 const [outShape, reduceShape] = computeOutAndReduceShapes($x.shape, axes);
71922 const reduceSize = sizeFromShape(reduceShape);
71923 const vals = makeZerosTypedArray(sizeFromShape(outShape), $x.dtype);
71924 const aVals = backend.data.get($x.dataId).values;
71925 for (let i = 0; i < vals.length; ++i) {
71926 const offset = i * reduceSize;
71927 let anyVal = aVals[offset];
71928 for (let j = 0; j < reduceSize; ++j) {
71929 const value = aVals[offset + j];
71930 anyVal = anyVal || value;
71931 }
71932 vals[i] = anyVal;
71933 }
71934 if (permutedAxes != null) {
71935 backend.disposeIntermediateTensorInfo($x);
71936 }
71937 const result = backend.makeTensorInfo(outShape, $x.dtype, vals);
71938 if (keepDims) {
71939 const expandedShape = expandShapeToKeepDim(outShape, origAxes);
71940 const reshapedResult = reshape$2({ inputs: { x: result }, backend, attrs: { shape: expandedShape } });
71941 backend.disposeIntermediateTensorInfo(result);
71942 return reshapedResult;
71943 }
71944 return result;
71945 }
71946 const anyConfig = {
71947 kernelName: Any,
71948 backendName: 'cpu',
71949 kernelFunc: any$1
71950 };
71951
71952 /**
71953 * @license
71954 * Copyright 2020 Google LLC. All Rights Reserved.
71955 * Licensed under the Apache License, Version 2.0 (the "License");
71956 * you may not use this file except in compliance with the License.
71957 * You may obtain a copy of the License at
71958 *
71959 * http://www.apache.org/licenses/LICENSE-2.0
71960 *
71961 * Unless required by applicable law or agreed to in writing, software
71962 * distributed under the License is distributed on an "AS IS" BASIS,
71963 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71964 * See the License for the specific language governing permissions and
71965 * limitations under the License.
71966 * =============================================================================
71967 */
71968 function argMax$1(args) {
71969 const { inputs, backend, attrs } = args;
71970 const { x } = inputs;
71971 const { axis } = attrs;
71972 assertNotComplex(x, 'argMax');
71973 let axes = parseAxisParam(axis, x.shape);
71974 const permutedAxes = getAxesPermutation(axes, x.shape.length);
71975 let $x = x;
71976 const intermediateTensorInfos = [];
71977 if (permutedAxes != null) {
71978 $x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
71979 intermediateTensorInfos.push($x);
71980 axes = getInnerMostAxes(axes.length, $x.shape.length);
71981 }
71982 axes = [axes[0]];
71983 assertAxesAreInnerMostDims('argMax', axes, $x.shape.length);
71984 const [outShape, reduceShape] = computeOutAndReduceShapes($x.shape, axes);
71985 const outSize = sizeFromShape(outShape);
71986 const vals = makeZerosTypedArray(outSize, 'int32');
71987 const reduceSize = sizeFromShape(reduceShape);
71988 const aVals = backend.data.get($x.dataId).values;
71989 for (let i = 0; i < vals.length; ++i) {
71990 const offset = i * reduceSize;
71991 let max = aVals[offset];
71992 let maxIndex = 0;
71993 for (let j = 0; j < reduceSize; ++j) {
71994 const value = aVals[offset + j];
71995 if (value > max) {
71996 max = value;
71997 maxIndex = j;
71998 }
71999 }
72000 vals[i] = maxIndex;
72001 }
72002 intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
72003 return backend.makeTensorInfo(outShape, 'int32', vals);
72004 }
72005 const argMaxConfig = {
72006 kernelName: ArgMax,
72007 backendName: 'cpu',
72008 kernelFunc: argMax$1
72009 };
72010
72011 /**
72012 * @license
72013 * Copyright 2020 Google LLC. All Rights Reserved.
72014 * Licensed under the Apache License, Version 2.0 (the "License");
72015 * you may not use this file except in compliance with the License.
72016 * You may obtain a copy of the License at
72017 *
72018 * http://www.apache.org/licenses/LICENSE-2.0
72019 *
72020 * Unless required by applicable law or agreed to in writing, software
72021 * distributed under the License is distributed on an "AS IS" BASIS,
72022 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72023 * See the License for the specific language governing permissions and
72024 * limitations under the License.
72025 * =============================================================================
72026 */
72027 function argMin$1(args) {
72028 const { inputs, backend, attrs } = args;
72029 const { x } = inputs;
72030 const { axis } = attrs;
72031 assertNotComplex(x, 'argMin');
72032 let axes = parseAxisParam(axis, x.shape);
72033 const permutedAxes = getAxesPermutation(axes, x.shape.length);
72034 let $x = x;
72035 const intermediateTensorInfos = [];
72036 if (permutedAxes != null) {
72037 $x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
72038 intermediateTensorInfos.push($x);
72039 axes = getInnerMostAxes(axes.length, $x.shape.length);
72040 }
72041 axes = [axes[0]];
72042 assertAxesAreInnerMostDims('argMin', axes, $x.shape.length);
72043 const [outShape, reduceShape] = computeOutAndReduceShapes($x.shape, axes);
72044 const outSize = sizeFromShape(outShape);
72045 const vals = makeZerosTypedArray(outSize, 'int32');
72046 const reduceSize = sizeFromShape(reduceShape);
72047 const aVals = backend.data.get($x.dataId).values;
72048 for (let i = 0; i < vals.length; ++i) {
72049 const offset = i * reduceSize;
72050 let min = aVals[offset];
72051 let minIndex = 0;
72052 for (let j = 0; j < reduceSize; ++j) {
72053 const value = aVals[offset + j];
72054 if (value < min) {
72055 min = value;
72056 minIndex = j;
72057 }
72058 }
72059 vals[i] = minIndex;
72060 }
72061 intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
72062 return backend.makeTensorInfo(outShape, 'int32', vals);
72063 }
72064 const argMinConfig = {
72065 kernelName: ArgMin,
72066 backendName: 'cpu',
72067 kernelFunc: argMin$1
72068 };
72069
72070 /**
72071 * @license
72072 * Copyright 2020 Google LLC. All Rights Reserved.
72073 * Licensed under the Apache License, Version 2.0 (the License);
72074 * you may not use this file except in compliance with the License.
72075 * You may obtain a copy of the License at
72076 *
72077 * http://www.apache.org/licenses/LICENSE-2.0
72078 *
72079 * Unless required by applicable law or agreed to in writing, software
72080 * distributed under the License is distributed on an AS IS BASIS,
72081 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72082 * See the License for the specific language governing permissions and
72083 * limitations under the License.
72084 * =============================================================================
72085 */
72086 const asin$1 = unaryKernelFunc(Asin, (xi) => Math.asin(xi));
72087 const asinConfig = {
72088 kernelName: Asin,
72089 backendName: 'cpu',
72090 kernelFunc: asin$1,
72091 };
72092
72093 /**
72094 * @license
72095 * Copyright 2020 Google LLC. All Rights Reserved.
72096 * Licensed under the Apache License, Version 2.0 (the License);
72097 * you may not use this file except in compliance with the License.
72098 * You may obtain a copy of the License at
72099 *
72100 * http://www.apache.org/licenses/LICENSE-2.0
72101 *
72102 * Unless required by applicable law or agreed to in writing, software
72103 * distributed under the License is distributed on an AS IS BASIS,
72104 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72105 * See the License for the specific language governing permissions and
72106 * limitations under the License.
72107 * =============================================================================
72108 */
72109 const asinh$1 = unaryKernelFunc(Asinh, (xi) => Math.asinh(xi));
72110 const asinhConfig = {
72111 kernelName: Asinh,
72112 backendName: 'cpu',
72113 kernelFunc: asinh$1,
72114 };
72115
72116 /**
72117 * @license
72118 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
72127 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72128 * See the License for the specific language governing permissions and
72129 * limitations under the License.
72130 * =============================================================================
72131 */
/**
 * CPU forward kernel for the element-wise arctangent op.
 *
 * Built on the shared unary-kernel helper; applies Math.atan to every
 * element of the input tensor.
 */
const atan$1 = unaryKernelFunc(Atan, (value) => Math.atan(value));
const atanConfig = {
    kernelName: Atan,
    backendName: 'cpu',
    kernelFunc: atan$1,
};
72138
72139 /**
72140 * @license
72141 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
72150 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72151 * See the License for the specific language governing permissions and
72152 * limitations under the License.
72153 * =============================================================================
72154 */
/**
 * CPU forward kernel for the element-wise two-argument arctangent op.
 *
 * The broadcasting binary-kernel helper applies Math.atan2 pairwise to the
 * two input tensors.
 */
const atan2Impl = createSimpleBinaryKernelImpl((y, x) => Math.atan2(y, x));
const atan2$1 = binaryKernelFunc(Atan2, atan2Impl);
const atan2Config = {
    kernelName: Atan2,
    backendName: 'cpu',
    kernelFunc: atan2$1,
};
72162
72163 /**
72164 * @license
72165 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
72174 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72175 * See the License for the specific language governing permissions and
72176 * limitations under the License.
72177 * =============================================================================
72178 */
/**
 * CPU forward kernel for the element-wise inverse hyperbolic tangent op.
 *
 * Built on the shared unary-kernel helper; applies Math.atanh to every
 * element of the input tensor.
 */
const atanh$1 = unaryKernelFunc(Atanh, (value) => Math.atanh(value));
const atanhConfig = {
    kernelName: Atanh,
    backendName: 'cpu',
    kernelFunc: atanh$1,
};
72185
72186 /**
72187 * @license
72188 * Copyright 2020 Google LLC. All Rights Reserved.
72189 * Licensed under the Apache License, Version 2.0 (the "License");
72190 * you may not use this file except in compliance with the License.
72191 * You may obtain a copy of the License at
72192 *
72193 * http://www.apache.org/licenses/LICENSE-2.0
72194 *
72195 * Unless required by applicable law or agreed to in writing, software
72196 * distributed under the License is distributed on an "AS IS" BASIS,
72197 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72198 * See the License for the specific language governing permissions and
72199 * limitations under the License.
72200 * =============================================================================
72201 */
/**
 * Shared CPU implementation of 2D max/avg pooling over an NHWC input.
 *
 * @param xValues Flat values of the input tensor.
 * @param xShape Shape of the input (indexed as batch, height, width, channel
 *     by the stride arithmetic below — assumes NHWC; TODO confirm at callers).
 * @param dtype Output dtype.
 * @param strides Flat strides of the input (strides[0]=batch, [1]=row,
 *     [2]=col).
 * @param convInfo Precomputed pooling geometry (window, stride, dilation,
 *     padding, output shape).
 * @param poolType 'max' or 'avg'.
 * @returns A TensorBuffer of shape convInfo.outShape with the pooled values.
 */
function pool$1(xValues, xShape, dtype, strides, convInfo, poolType) {
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    const padTop = convInfo.padInfo.top;
    const padLeft = convInfo.padInfo.left;
    // Reduction seed: -Inf so any real pixel wins for 'max'; +Inf otherwise.
    const initialValue = (poolType === 'max' ? Number.NEGATIVE_INFINITY :
        Number.POSITIVE_INFINITY);
    const output = buffer(convInfo.outShape, dtype);
    const outputVals = output.values;
    // Flat strides of the output buffer, derived from its NHWC shape.
    const outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] * convInfo.outShape[3];
    const outputRowStrides = convInfo.outShape[2] * convInfo.outShape[3];
    const outputColStrides = convInfo.outShape[3];
    for (let b = 0; b < convInfo.batchSize; ++b) {
        const outputBatchOffset = b * outputBatchStrides;
        const inputBatchOffset = b * strides[0];
        for (let d = 0; d < convInfo.inChannels; ++d) {
            for (let yR = 0; yR < convInfo.outHeight; ++yR) {
                // Top edge of the (possibly padded) window in input coords.
                const xRCorner = yR * strideHeight - padTop;
                const xRMin = Math.max(0, xRCorner);
                const xRMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner);
                const outputRowOffset = outputBatchOffset + yR * outputRowStrides;
                for (let yC = 0; yC < convInfo.outWidth; ++yC) {
                    const xCCorner = yC * strideWidth - padLeft;
                    const xCMin = Math.max(0, xCCorner);
                    const xCMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner);
                    let minMaxValue = initialValue;
                    let avgValue = 0;
                    let count = 0; // number of in-bounds pixels (avg divisor)
                    for (let xR = xRMin; xR < xRMax; xR += dilationHeight) {
                        const xROffset = inputBatchOffset + xR * strides[1];
                        for (let xC = xCMin; xC < xCMax; xC += dilationWidth) {
                            const xCOffset = xROffset + xC * strides[2];
                            const pixel = xValues[xCOffset + d];
                            if ((poolType === 'max' && pixel > minMaxValue)) {
                                minMaxValue = pixel;
                            }
                            else if (poolType === 'avg') {
                                avgValue += pixel;
                                count++;
                            }
                        }
                        // Once NaN is latched the window result is fixed;
                        // stop scanning remaining rows early.
                        if (isNaN(minMaxValue)) {
                            break;
                        }
                    }
                    const outputOffset = outputRowOffset + yC * outputColStrides + d;
                    outputVals[outputOffset] =
                        poolType === 'avg' ? avgValue / count : minMaxValue;
                }
            }
        }
    }
    return output;
}
/**
 * Computes, for each 2D max-pool output element, the position of the winning
 * input pixel (used by the MaxPool gradient / MaxPoolWithArgmax kernels).
 *
 * @param xValues Flat values of the input tensor.
 * @param xShape Input shape (accessed as batch, row, col, channel below).
 * @param dtype Input dtype, used to wrap xValues in a TensorBuffer.
 * @param convInfo Precomputed pooling geometry.
 * @param flattenPositions When true, record the flat index into the input
 *     tensor; when false, record the position within the effective filter
 *     window (wR * effectiveFilterWidth + wC).
 * @param includeBatchInIndex Only meaningful with flattenPositions: include
 *     the batch dimension in the flat index.
 * @returns An int32 TensorBuffer of shape convInfo.outShape holding positions
 *     (-1 where the window contained no pixel greater than -Infinity).
 */
function maxPoolPositions(xValues, xShape, dtype, convInfo, flattenPositions = false, includeBatchInIndex = false) {
    const maxPositions = buffer(convInfo.outShape, 'int32');
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    const padTop = convInfo.padInfo.top;
    const padLeft = convInfo.padInfo.left;
    const xBuf = buffer(xShape, dtype, xValues);
    for (let b = 0; b < convInfo.batchSize; ++b) {
        for (let d = 0; d < convInfo.inChannels; ++d) {
            for (let yR = 0; yR < convInfo.outHeight; ++yR) {
                const xRCorner = yR * strideHeight - padTop;
                // Step into bounds by whole dilation steps (not a simple
                // clamp) so that xR - xRCorner stays a multiple of the
                // dilation, keeping the in-window offset wR consistent.
                let xRMin = xRCorner;
                while (xRMin < 0) {
                    xRMin += dilationHeight;
                }
                // const xRMin = Math.max(0, xRCorner);
                const xRMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner);
                for (let yC = 0; yC < convInfo.outWidth; ++yC) {
                    const xCCorner = yC * strideWidth - padLeft;
                    let xCMin = xCCorner;
                    while (xCMin < 0) {
                        xCMin += dilationWidth;
                    }
                    const xCMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner);
                    let maxValue = Number.NEGATIVE_INFINITY;
                    let maxPosition = -1;
                    for (let xR = xRMin; xR < xRMax; xR += dilationHeight) {
                        const wR = xR - xRCorner; // row offset within window
                        for (let xC = xCMin; xC < xCMax; xC += dilationWidth) {
                            const wC = xC - xCCorner; // col offset within window
                            const pixel = xBuf.get(b, xR, xC, d);
                            // Strict '>' keeps the FIRST occurrence on ties.
                            if (pixel > maxValue) {
                                maxValue = pixel;
                                if (flattenPositions) {
                                    maxPosition = includeBatchInIndex ?
                                        ((b * convInfo.inHeight + xR) * convInfo.inWidth + xC) *
                                            convInfo.inChannels +
                                            d :
                                        (xR * convInfo.inWidth + xC) * convInfo.inChannels + d;
                                }
                                else {
                                    maxPosition = wR * effectiveFilterWidth + wC;
                                }
                            }
                        }
                    }
                    maxPositions.set(maxPosition, b, yR, yC, d);
                }
            }
        }
    }
    return maxPositions;
}
/**
 * Shared CPU implementation of 3D max/avg pooling over an NDHWC input.
 *
 * @param xValues Flat values of the input tensor.
 * @param xShape Input shape (indexed as batch, depth, row, col, channel by
 *     the stride arithmetic below — assumes NDHWC; TODO confirm at callers).
 * @param dtype Output dtype.
 * @param strides Flat input strides (strides[0]=batch, [1]=depth, [2]=row,
 *     [3]=col).
 * @param convInfo Precomputed 3D pooling geometry.
 * @param poolType 'max' or 'avg'.
 * @returns A TensorBuffer of shape convInfo.outShape with the pooled values.
 */
function pool3d$1(xValues, xShape, dtype, strides, convInfo, poolType) {
    const strideDepth = convInfo.strideDepth;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationDepth = convInfo.dilationDepth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterDepth = convInfo.effectiveFilterDepth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    const padFront = convInfo.padInfo.front;
    const padTop = convInfo.padInfo.top;
    const padLeft = convInfo.padInfo.left;
    // Reduction seed: -Inf so any real pixel wins for 'max'; +Inf otherwise.
    const initialValue = (poolType === 'max' ? Number.NEGATIVE_INFINITY :
        Number.POSITIVE_INFINITY);
    const output = buffer(convInfo.outShape, dtype);
    const outputVals = output.values;
    // Flat strides of the output buffer, derived from its NDHWC shape.
    const outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] *
        convInfo.outShape[3] * convInfo.outShape[4];
    const outputDepthStrides = convInfo.outShape[2] * convInfo.outShape[3] * convInfo.outShape[4];
    const outputRowStrides = convInfo.outShape[3] * convInfo.outShape[4];
    const outputColStrides = convInfo.outShape[4];
    for (let batch = 0; batch < convInfo.batchSize; ++batch) {
        const outputBatchOffset = batch * outputBatchStrides;
        const inputBatchOffset = batch * strides[0];
        for (let channel = 0; channel < convInfo.inChannels; ++channel) {
            for (let yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) {
                const xDepthCorner = yDepth * strideDepth - padFront;
                // Step into bounds by whole dilation steps so the sampled
                // positions stay aligned with the dilated filter grid.
                let xDepthMin = xDepthCorner;
                while (xDepthMin < 0) {
                    xDepthMin += dilationDepth;
                }
                const xDepthMax = Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner);
                const outputDepthOffset = outputBatchOffset + yDepth * outputDepthStrides;
                for (let yRow = 0; yRow < convInfo.outHeight; ++yRow) {
                    const xRowCorner = yRow * strideHeight - padTop;
                    let xRowMin = xRowCorner;
                    while (xRowMin < 0) {
                        xRowMin += dilationHeight;
                    }
                    const xRowMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner);
                    const outputRowOffset = outputDepthOffset + yRow * outputRowStrides;
                    for (let yCol = 0; yCol < convInfo.outWidth; ++yCol) {
                        const xColCorner = yCol * strideWidth - padLeft;
                        let xColMin = xColCorner;
                        while (xColMin < 0) {
                            xColMin += dilationWidth;
                        }
                        const xColMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner);
                        // Shader code begins
                        const outputColOffset = outputRowOffset + yCol * outputColStrides;
                        let minMaxValue = initialValue;
                        let avgValue = 0;
                        let count = 0; // in-bounds pixel count (avg divisor)
                        for (let xDepth = xDepthMin; xDepth < xDepthMax; xDepth += dilationDepth) {
                            const xDepthOffset = inputBatchOffset + xDepth * strides[1];
                            for (let xRow = xRowMin; xRow < xRowMax; xRow += dilationHeight) {
                                const xRowOffset = xDepthOffset + xRow * strides[2];
                                for (let xCol = xColMin; xCol < xColMax; xCol += dilationWidth) {
                                    const xColOffset = xRowOffset + xCol * strides[3];
                                    const pixel = xValues[xColOffset + channel];
                                    if ((poolType === 'max' && pixel > minMaxValue)) {
                                        minMaxValue = pixel;
                                    }
                                    else if (poolType === 'avg') {
                                        avgValue += pixel;
                                        count++;
                                    }
                                    // NaN latched: the window result cannot
                                    // change, so unwind all three loops early.
                                    if (isNaN(minMaxValue)) {
                                        break;
                                    }
                                }
                                if (isNaN(minMaxValue)) {
                                    break;
                                }
                            }
                            if (isNaN(minMaxValue)) {
                                break;
                            }
                        }
                        const outputOffset = outputColOffset + channel;
                        outputVals[outputOffset] =
                            poolType === 'avg' ? avgValue / count : minMaxValue;
                    }
                }
            }
        }
    }
    return output;
}
/**
 * Computes, for each 3D max-pool output element, the winning input pixel's
 * position encoded within the effective filter window (consumed by the
 * MaxPool3D gradient kernel).
 *
 * @param xBuf TensorBuffer over the input, indexed as
 *     (batch, depth, row, col, channel).
 * @param convInfo Precomputed 3D pooling geometry.
 * @returns An int32 TensorBuffer of shape convInfo.outShape holding encoded
 *     window positions (-1 if the window contained no pixel >= -Infinity).
 */
function maxPool3dPositions(xBuf, convInfo) {
    const maxPositions = buffer(convInfo.outShape, 'int32');
    const strideDepth = convInfo.strideDepth;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationDepth = convInfo.dilationDepth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterDepth = convInfo.effectiveFilterDepth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    const padFront = convInfo.padInfo.front;
    const padTop = convInfo.padInfo.top;
    const padLeft = convInfo.padInfo.left;
    for (let batch = 0; batch < convInfo.batchSize; ++batch) {
        for (let channel = 0; channel < convInfo.inChannels; ++channel) {
            for (let yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) {
                const xDepthCorner = yDepth * strideDepth - padFront;
                // Step into bounds by whole dilation steps so in-window
                // offsets (wDepth/wRow/wCol) remain dilation-aligned.
                let xDepthMin = xDepthCorner;
                while (xDepthMin < 0) {
                    xDepthMin += dilationDepth;
                }
                const xDepthMax = Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner);
                for (let yRow = 0; yRow < convInfo.outHeight; ++yRow) {
                    const xRowCorner = yRow * strideHeight - padTop;
                    let xRowMin = xRowCorner;
                    while (xRowMin < 0) {
                        xRowMin += dilationHeight;
                    }
                    const xRowMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner);
                    for (let yCol = 0; yCol < convInfo.outWidth; ++yCol) {
                        const xColCorner = yCol * strideWidth - padLeft;
                        let xColMin = xColCorner;
                        while (xColMin < 0) {
                            xColMin += dilationWidth;
                        }
                        const xColMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner);
                        // Shader code begins
                        let maxValue = Number.NEGATIVE_INFINITY;
                        let maxPosition = -1;
                        for (let xDepth = xDepthMin; xDepth < xDepthMax; xDepth += dilationDepth) {
                            const wDepth = xDepth - xDepthCorner;
                            for (let xRow = xRowMin; xRow < xRowMax; xRow += dilationHeight) {
                                const wRow = xRow - xRowCorner;
                                for (let xCol = xColMin; xCol < xColMax; xCol += dilationWidth) {
                                    const wCol = xCol - xColCorner;
                                    const pixel = xBuf.get(batch, xDepth, xRow, xCol, channel);
                                    // '>=' keeps the LAST occurrence on ties
                                    // (the 2D variant above uses strict '>').
                                    if (pixel >= maxValue) {
                                        maxValue = pixel;
                                        // NOTE(review): the row term strides by
                                        // effectiveFilterHeight, not
                                        // effectiveFilterWidth, which is unusual
                                        // for a depth-major linearization. The
                                        // gradient kernel that decodes this value
                                        // is not visible here — confirm it uses
                                        // the same formula before changing this.
                                        maxPosition =
                                            wDepth * effectiveFilterHeight * effectiveFilterWidth +
                                                wRow * effectiveFilterHeight + wCol;
                                    }
                                }
                            }
                        }
                        maxPositions.set(maxPosition, batch, yDepth, yRow, yCol, channel);
                    }
                }
            }
        }
    }
    return maxPositions;
}
72471
72472 /**
72473 * @license
72474 * Copyright 2020 Google LLC. All Rights Reserved.
72475 * Licensed under the Apache License, Version 2.0 (the "License");
72476 * you may not use this file except in compliance with the License.
72477 * You may obtain a copy of the License at
72478 *
72479 * http://www.apache.org/licenses/LICENSE-2.0
72480 *
72481 * Unless required by applicable law or agreed to in writing, software
72482 * distributed under the License is distributed on an "AS IS" BASIS,
72483 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72484 * See the License for the specific language governing permissions and
72485 * limitations under the License.
72486 * =============================================================================
72487 */
/**
 * CPU kernel for 2D average pooling.
 *
 * Validates the stride/dilation constraint, computes the pooling geometry,
 * then either short-circuits to identity (1x1 window, unchanged shape) or
 * runs the shared pool$1 implementation in 'avg' mode.
 */
function avgPool$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    assertNotComplex(x, 'avgPool');
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const dilations = 1;
    assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in avgPool: Either strides or dilations must be 1. ' +
        `Got strides ${strides} and dilations '${dilations}'`);
    const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
    if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
        arraysEqual(convInfo.inShape, convInfo.outShape)) {
        // A 1x1 window over an unchanged shape is the identity.
        return identity$1({ inputs: { x }, backend });
    }
    const xValues = backend.data.get(x.dataId).values;
    const xStrides = computeStrides(x.shape);
    const pooled = pool$1(xValues, x.shape, x.dtype, xStrides, convInfo, 'avg');
    return backend.makeTensorInfo(convInfo.outShape, x.dtype, pooled.values);
}
const avgPoolConfig = {
    kernelName: AvgPool,
    backendName: 'cpu',
    kernelFunc: avgPool$1
};
72515
72516 /**
72517 * @license
72518 * Copyright 2020 Google LLC. All Rights Reserved.
72519 * Licensed under the Apache License, Version 2.0 (the "License");
72520 * you may not use this file except in compliance with the License.
72521 * You may obtain a copy of the License at
72522 *
72523 * http://www.apache.org/licenses/LICENSE-2.0
72524 *
72525 * Unless required by applicable law or agreed to in writing, software
72526 * distributed under the License is distributed on an "AS IS" BASIS,
72527 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72528 * See the License for the specific language governing permissions and
72529 * limitations under the License.
72530 * =============================================================================
72531 */
/**
 * CPU kernel for 3D average pooling.
 *
 * Computes the 3D pooling geometry (dilations fixed at 1) and delegates to
 * the shared pool3d$1 implementation in 'avg' mode.
 */
function avgPool3D(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { filterSize, strides, pad, dimRoundingMode, dataFormat } = attrs;
    assertNotComplex(x, 'avgPool3d');
    const convInfo = computePool3DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode, dataFormat);
    const inputValues = backend.data.get(x.dataId).values;
    const pooled = pool3d$1(inputValues, x.shape, x.dtype, computeStrides(x.shape), convInfo, 'avg');
    return backend.makeTensorInfo(pooled.shape, 'float32', pooled.values);
}
const avgPool3DConfig = {
    kernelName: AvgPool3D,
    backendName: 'cpu',
    kernelFunc: avgPool3D
};
72547
72548 /**
72549 * @license
72550 * Copyright 2020 Google LLC. All Rights Reserved.
72551 * Licensed under the Apache License, Version 2.0 (the "License");
72552 * you may not use this file except in compliance with the License.
72553 * You may obtain a copy of the License at
72554 *
72555 * http://www.apache.org/licenses/LICENSE-2.0
72556 *
72557 * Unless required by applicable law or agreed to in writing, software
72558 * distributed under the License is distributed on an "AS IS" BASIS,
72559 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72560 * See the License for the specific language governing permissions and
72561 * limitations under the License.
72562 * =============================================================================
72563 */
/**
 * CPU kernel for the gradient of 3D average pooling.
 *
 * For each input position, sums the upstream gradient values of every output
 * window that covered it and scales by 1/(filter volume), i.e. each output
 * distributed its gradient uniformly over its window.
 *
 * Inputs: dy (upstream gradient), input (original pool input).
 * Attrs: filterSize, strides, pad, dimRoundingMode.
 * Returns a float32 tensor with the same shape as `input`.
 */
function avgPool3DGrad(args) {
    const { inputs, backend, attrs } = args;
    const { dy, input } = inputs;
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    assertNotComplex([dy, input], 'avgPool3DGrad');
    const convInfo = computePool3DInfo(input.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
    const strideDepth = convInfo.strideDepth;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const filterDepth = convInfo.filterDepth;
    const filterHeight = convInfo.filterHeight;
    const filterWidth = convInfo.filterWidth;
    const dilationDepth = convInfo.dilationDepth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterDepth = convInfo.effectiveFilterDepth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    // "Mirrored" padding for the transposed (gradient) view of the pooling.
    const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
    const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
    const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
    const dx = buffer(input.shape, 'float32');
    // Each output averaged over the full filter volume, so each contributing
    // gradient is scaled by the reciprocal of that volume.
    const avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth);
    const dyBuf = backend.bufferSync(dy);
    for (let batch = 0; batch < convInfo.batchSize; ++batch) {
        for (let channel = 0; channel < convInfo.inChannels; ++channel) {
            for (let dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) {
                for (let dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) {
                    for (let dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) {
                        // Shader code begins.
                        const dyDepthCorner = dxDepth - padFront;
                        const dyRowCorner = dxRow - padTop;
                        const dyColCorner = dxCol - padLeft;
                        let dotProd = 0;
                        for (let wDepth = 0; wDepth < effectiveFilterDepth; wDepth += dilationDepth) {
                            const dyDepth = (dyDepthCorner + wDepth) / strideDepth;
                            // Skip positions that don't map onto an integer
                            // output coordinate (stride divisibility check).
                            if (dyDepth < 0 || dyDepth >= convInfo.outDepth ||
                                Math.floor(dyDepth) !== dyDepth) {
                                continue;
                            }
                            for (let wRow = 0; wRow < effectiveFilterHeight; wRow += dilationHeight) {
                                const dyRow = (dyRowCorner + wRow) / strideHeight;
                                if (dyRow < 0 || dyRow >= convInfo.outHeight ||
                                    Math.floor(dyRow) !== dyRow) {
                                    continue;
                                }
                                for (let wCol = 0; wCol < effectiveFilterWidth; wCol += dilationWidth) {
                                    const dyCol = (dyColCorner + wCol) / strideWidth;
                                    if (dyCol < 0 || dyCol >= convInfo.outWidth ||
                                        Math.floor(dyCol) !== dyCol) {
                                        continue;
                                    }
                                    const pixel = dyBuf.get(batch, dyDepth, dyRow, dyCol, channel);
                                    dotProd += pixel;
                                }
                            }
                        }
                        dx.set(dotProd * avgMultiplier, batch, dxDepth, dxRow, dxCol, channel);
                    }
                }
            }
        }
    }
    return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
const avgPool3DGradConfig$1 = {
    kernelName: AvgPool3DGrad,
    backendName: 'cpu',
    kernelFunc: avgPool3DGrad
};
72634
72635 /**
72636 * @license
72637 * Copyright 2020 Google LLC. All Rights Reserved.
72638 * Licensed under the Apache License, Version 2.0 (the "License");
72639 * you may not use this file except in compliance with the License.
72640 * You may obtain a copy of the License at
72641 *
72642 * http://www.apache.org/licenses/LICENSE-2.0
72643 *
72644 * Unless required by applicable law or agreed to in writing, software
72645 * distributed under the License is distributed on an "AS IS" BASIS,
72646 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72647 * See the License for the specific language governing permissions and
72648 * limitations under the License.
72649 * =============================================================================
72650 */
/**
 * CPU kernel for the gradient of 2D average pooling.
 *
 * For each input position, sums the upstream gradient values of every output
 * window that covered it and scales by 1/(filterHeight * filterWidth).
 *
 * Inputs: dy (upstream gradient), input (original pool input).
 * Attrs: filterSize, strides, pad.
 * Returns a float32 tensor with the same shape as `input`.
 */
function avgPoolGrad$1(args) {
    const { inputs, backend, attrs } = args;
    const { dy, input } = inputs;
    const x = input;
    assertNotComplex([dy, input], 'avgPoolGrad');
    const { filterSize, strides, pad } = attrs;
    const convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad);
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const filterHeight = convInfo.filterHeight;
    const filterWidth = convInfo.filterWidth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    // "Mirrored" padding for the transposed (gradient) view of the pooling.
    const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
    const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
    const dx = buffer(x.shape, 'float32');
    // Each output averaged over the full window, so each contributing
    // gradient is scaled by the reciprocal of the window area.
    const avgMultiplier = 1 / (filterHeight * filterWidth);
    const dyData = backend.data.get(dy.dataId).values;
    const dyBuf = buffer(dy.shape, 'float32', dyData);
    for (let b = 0; b < convInfo.batchSize; ++b) {
        for (let d = 0; d < convInfo.inChannels; ++d) {
            for (let dxR = 0; dxR < convInfo.inHeight; ++dxR) {
                for (let dxC = 0; dxC < convInfo.inWidth; ++dxC) {
                    // Shader code begins.
                    const dyRCorner = dxR - padTop;
                    const dyCCorner = dxC - padLeft;
                    let dotProd = 0;
                    for (let wR = 0; wR < effectiveFilterHeight; wR += dilationHeight) {
                        const dyR = (dyRCorner + wR) / strideHeight;
                        // Skip positions that don't map onto an integer
                        // output coordinate (stride divisibility check).
                        if (dyR < 0 || dyR >= convInfo.outHeight ||
                            Math.floor(dyR) !== dyR) {
                            continue;
                        }
                        for (let wC = 0; wC < effectiveFilterWidth; wC += dilationWidth) {
                            const dyC = (dyCCorner + wC) / strideWidth;
                            if (dyC < 0 || dyC >= convInfo.outWidth ||
                                Math.floor(dyC) !== dyC) {
                                continue;
                            }
                            const pixel = dyBuf.get(b, dyR, dyC, d);
                            dotProd += pixel;
                        }
                    }
                    dx.set(dotProd * avgMultiplier, b, dxR, dxC, d);
                }
            }
        }
    }
    return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
const avgPoolGradConfig$1 = {
    kernelName: AvgPoolGrad,
    backendName: 'cpu',
    kernelFunc: avgPoolGrad$1
};
72708
72709 /**
72710 * @license
72711 * Copyright 2020 Google LLC. All Rights Reserved.
72712 * Licensed under the Apache License, Version 2.0 (the "License");
72713 * you may not use this file except in compliance with the License.
72714 * You may obtain a copy of the License at
72715 *
72716 * http://www.apache.org/licenses/LICENSE-2.0
72717 *
72718 * Unless required by applicable law or agreed to in writing, software
72719 * distributed under the License is distributed on an "AS IS" BASIS,
72720 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72721 * See the License for the specific language governing permissions and
72722 * limitations under the License.
72723 * =============================================================================
72724 */
/**
 * CPU kernel for fused batch normalization.
 *
 * Computes offset + (x - mean) * scale / sqrt(variance + epsilon) element-wise.
 * The mean/variance/scale/offset tensors may be smaller than x; their values
 * are reused cyclically (index modulo length), matching broadcasting over the
 * flattened input. Missing scale defaults to 1, missing offset to 0, and a
 * missing varianceEpsilon attr defaults to 0.001.
 */
function batchNorm$1(args) {
    const { inputs, backend, attrs } = args;
    const { x, scale, offset, mean, variance } = inputs;
    assert(mean.shape.length === variance.shape.length, () => 'Batch normalization gradient requires mean and variance to have ' +
        'equal ranks.');
    assert(offset == null || mean.shape.length === offset.shape.length, () => 'Batch normalization gradient requires mean and offset to have ' +
        'equal ranks.');
    assert(scale == null || mean.shape.length === scale.shape.length, () => 'Batch normalization gradient requires mean and scale to have ' +
        'equal ranks.');
    assertNotComplex([x, mean, variance, scale, offset], 'batchNorm');
    let { varianceEpsilon } = attrs;
    if (varianceEpsilon == null) {
        varianceEpsilon = 0.001;
    }
    const xVals = backend.data.get(x.dataId).values;
    const mVals = backend.data.get(mean.dataId).values;
    const varVals = backend.data.get(variance.dataId).values;
    const sVals = scale ? backend.data.get(scale.dataId).values :
        new Float32Array([1]);
    const offVals = offset ?
        backend.data.get(offset.dataId).values :
        new Float32Array([0]);
    const outVals = new Float32Array(xVals.length);
    const offLen = offVals.length;
    const sLen = sVals.length;
    const varLen = varVals.length;
    const mLen = mVals.length;
    for (let i = 0; i < xVals.length; ++i) {
        // Cyclic (modulo) indexing replicates the smaller parameter tensors
        // across the flattened input.
        outVals[i] = offVals[i % offLen] +
            (xVals[i] - mVals[i % mLen]) * sVals[i % sLen] /
                Math.sqrt(varVals[i % varLen] + varianceEpsilon);
    }
    return backend.makeTensorInfo(x.shape, x.dtype, outVals);
}
const batchNormConfig = {
    kernelName: FusedBatchNorm,
    backendName: 'cpu',
    kernelFunc: batchNorm$1,
};
72780
72781 /**
72782 * @license
72783 * Copyright 2020 Google LLC. All Rights Reserved.
72784 * Licensed under the Apache License, Version 2.0 (the "License");
72785 * you may not use this file except in compliance with the License.
72786 * You may obtain a copy of the License at
72787 *
72788 * http://www.apache.org/licenses/LICENSE-2.0
72789 *
72790 * Unless required by applicable law or agreed to in writing, software
72791 * distributed under the License is distributed on an "AS IS" BASIS,
72792 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72793 * See the License for the specific language governing permissions and
72794 * limitations under the License.
72795 * =============================================================================
72796 */
/**
 * CPU kernel for BatchToSpaceND.
 *
 * Implemented as the standard composite: reshape -> transpose -> reshape ->
 * slice (to apply crops). All intermediate tensors are disposed before the
 * result is returned.
 */
function batchToSpaceND$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { blockShape, crops } = attrs;
    assertNotComplex([x], 'batchToSpaceND');
    const blockSize = blockShape.reduce((product, value) => product * value);
    const reshapedShape = getReshaped(x.shape, blockShape, blockSize);
    const permutation = getPermuted(reshapedShape.length, blockShape.length);
    const permutedShape = getReshapedPermuted(x.shape, blockShape, blockSize);
    const cropBegin = getSliceBeginCoords(crops, blockShape.length);
    const cropSize = getSliceSize(permutedShape, crops, blockShape.length);
    const reshapedX = reshape$2({ inputs: { x }, backend, attrs: { shape: reshapedShape } });
    const transposedX = transpose$1({ inputs: { x: reshapedX }, backend, attrs: { perm: permutation } });
    const flattenedX = reshape$2({ inputs: { x: transposedX }, backend, attrs: { shape: permutedShape } });
    const cropped = slice$1({
        inputs: { x: flattenedX },
        backend,
        attrs: { begin: cropBegin, size: cropSize }
    });
    backend.disposeIntermediateTensorInfo(reshapedX);
    backend.disposeIntermediateTensorInfo(transposedX);
    backend.disposeIntermediateTensorInfo(flattenedX);
    return cropped;
}
const batchToSpaceNDConfig = {
    kernelName: BatchToSpaceND,
    backendName: 'cpu',
    kernelFunc: batchToSpaceND$1
};
72826
72827 /**
72828 * @license
72829 * Copyright 2020 Google LLC. All Rights Reserved.
72830 * Licensed under the Apache License, Version 2.0 (the "License");
72831 * you may not use this file except in compliance with the License.
72832 * You may obtain a copy of the License at
72833 *
72834 * http://www.apache.org/licenses/LICENSE-2.0
72835 *
72836 * Unless required by applicable law or agreed to in writing, software
72837 * distributed under the License is distributed on an "AS IS" BASIS,
72838 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72839 * See the License for the specific language governing permissions and
72840 * limitations under the License.
72841 * =============================================================================
72842 */
/**
 * CPU kernel for Bincount.
 *
 * Reads the raw values of `x` and `weights` and delegates counting to the
 * shared bincountImpl, producing a 1D tensor of length `size`.
 */
function bincount$1(args) {
    const { inputs, backend, attrs } = args;
    const { x, weights } = inputs;
    const { size } = attrs;
    const xData = backend.data.get(x.dataId).values;
    const weightsData = backend.data.get(weights.dataId).values;
    const counts = bincountImpl(xData, weightsData, weights.dtype, weights.shape, size);
    return backend.makeTensorInfo([size], weights.dtype, counts);
}
const bincountConfig = {
    kernelName: Bincount,
    backendName: 'cpu',
    kernelFunc: bincount$1
};
72857
72858 /**
72859 * @license
72860 * Copyright 2021 Google LLC. All Rights Reserved.
72861 * Licensed under the Apache License, Version 2.0 (the "License");
72862 * you may not use this file except in compliance with the License.
72863 * You may obtain a copy of the License at
72864 *
72865 * http://www.apache.org/licenses/LICENSE-2.0
72866 *
72867 * Unless required by applicable law or agreed to in writing, software
72868 * distributed under the License is distributed on an "AS IS" BASIS,
72869 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72870 * See the License for the specific language governing permissions and
72871 * limitations under the License.
72872 * =============================================================================
72873 */
72874 function broadcastArgs$1(args) {
72875 const { inputs, backend } = args;
72876 const { s0, s1 } = inputs;
72877 const s0Vals = backend.data.get(s0.dataId).values;
72878 const s1Vals = backend.data.get(s1.dataId).values;
72879 const broadcastShape = assertAndGetBroadcastShape(Array.from(s0Vals), Array.from(s1Vals));
72880 return backend.makeTensorInfo([broadcastShape.length], 'int32', Int32Array.from(broadcastShape));
72881 }
72882 const broadcastArgsConfig = {
72883 kernelName: BroadcastArgs,
72884 backendName: 'cpu',
72885 kernelFunc: broadcastArgs$1
72886 };
72887
72888 /**
72889 * @license
72890 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
72892 * you may not use this file except in compliance with the License.
72893 * You may obtain a copy of the License at
72894 *
72895 * http://www.apache.org/licenses/LICENSE-2.0
72896 *
72897 * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
72899 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72900 * See the License for the specific language governing permissions and
72901 * limitations under the License.
72902 * =============================================================================
72903 */
72904 const clipByValue$1 = unaryKernelFunc(ClipByValue, (xi, attrs) => {
72905 const clipAttrs = attrs;
72906 if (xi > clipAttrs.clipValueMax) {
72907 return clipAttrs.clipValueMax;
72908 }
72909 return xi < clipAttrs.clipValueMin ? clipAttrs.clipValueMin : xi;
72910 });
72911 const clipByValueConfig = {
72912 kernelName: ClipByValue,
72913 backendName: 'cpu',
72914 kernelFunc: clipByValue$1,
72915 };
72916
72917 /**
72918 * @license
72919 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
72921 * you may not use this file except in compliance with the License.
72922 * You may obtain a copy of the License at
72923 *
72924 * http://www.apache.org/licenses/LICENSE-2.0
72925 *
72926 * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
72928 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72929 * See the License for the specific language governing permissions and
72930 * limitations under the License.
72931 * =============================================================================
72932 */
72933 const complexAbs = (args) => {
72934 const { x } = args.inputs;
72935 const cpuBackend = args.backend;
72936 const resultValues = new Float32Array(sizeFromShape(x.shape));
72937 const complexVals = cpuBackend.data.get(x.dataId);
72938 const real = complexVals.complexTensorInfos.real;
72939 const imag = complexVals.complexTensorInfos.imag;
72940 const realVals = cpuBackend.data.get(real.dataId).values;
72941 const imagVals = cpuBackend.data.get(imag.dataId).values;
72942 for (let i = 0; i < realVals.length; i++) {
72943 const real = realVals[i];
72944 const imag = imagVals[i];
72945 resultValues[i] = Math.hypot(real, imag);
72946 }
72947 return cpuBackend.makeOutput(resultValues, x.shape, 'float32');
72948 };
72949 const complexAbsConfig = {
72950 kernelName: ComplexAbs,
72951 backendName: 'cpu',
72952 kernelFunc: complexAbs,
72953 };
72954
72955 /**
72956 * @license
72957 * Copyright 2020 Google LLC. All Rights Reserved.
72958 * Licensed under the Apache License, Version 2.0 (the "License");
72959 * you may not use this file except in compliance with the License.
72960 * You may obtain a copy of the License at
72961 *
72962 * http://www.apache.org/licenses/LICENSE-2.0
72963 *
72964 * Unless required by applicable law or agreed to in writing, software
72965 * distributed under the License is distributed on an "AS IS" BASIS,
72966 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72967 * See the License for the specific language governing permissions and
72968 * limitations under the License.
72969 * =============================================================================
72970 */
72971 function imag$1(args) {
72972 const { inputs, backend } = args;
72973 const { input } = inputs;
72974 const imag = backend.data.get(input.dataId).complexTensorInfos.imag;
72975 const imagVal = backend.data.get(imag.dataId).values;
72976 // When complex tensor is disposed, its underlying parts will be disposed too.
72977 // Make new tensor out of the imag value of the complex. This makes sure the
72978 // value is still accessible even if complex tensor is disposed.
72979 return backend.makeTensorInfo(imag.shape, imag.dtype, imagVal);
72980 }
72981 const imagConfig = {
72982 kernelName: Imag,
72983 backendName: 'cpu',
72984 kernelFunc: imag$1
72985 };
72986
72987 /**
72988 * @license
72989 * Copyright 2020 Google LLC. All Rights Reserved.
72990 * Licensed under the Apache License, Version 2.0 (the "License");
72991 * you may not use this file except in compliance with the License.
72992 * You may obtain a copy of the License at
72993 *
72994 * http://www.apache.org/licenses/LICENSE-2.0
72995 *
72996 * Unless required by applicable law or agreed to in writing, software
72997 * distributed under the License is distributed on an "AS IS" BASIS,
72998 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72999 * See the License for the specific language governing permissions and
73000 * limitations under the License.
73001 * =============================================================================
73002 */
    /**
     * CPU kernel for Concat: joins a list of input tensors along `axis`.
     *
     * Handles the degenerate cases (zero-sized output, a single non-empty
     * input), recurses on the real/imag parts for complex64 inputs, and
     * otherwise reduces the problem to concatenating 2-D tensors along
     * axis 1 (see the comment block below).
     */
    function concat$1(args) {
        const { inputs, backend, attrs } = args;
        const { axis } = attrs;
        // Normalize a possibly-negative axis against the first input's rank.
        const $axis = parseAxisParam(axis, inputs[0].shape)[0];
        let outShape = computeOutShape$1(inputs.map(t => t.shape), $axis);
        if (sizeFromShape(outShape) === 0) {
            return backend.makeTensorInfo(outShape, inputs[0].dtype, []);
        }
        // Keep only non-empty tensors (ignore tensors with 0 in their shape).
        const $inputs = inputs.filter(t => sizeFromShape(t.shape) > 0);
        if ($inputs.length === 1) {
            // Nothing to concatenate; return a copy of the single input.
            return identity$1({ inputs: { x: $inputs[0] }, backend });
        }
        const shapes = $inputs.map(t => t.shape);
        assertParamsConsistent(shapes, $axis);
        if ($inputs[0].dtype === 'complex64') {
            // Concat the real and imaginary parts separately, then recombine.
            const reals = $inputs.map((t) => real$1({ inputs: { input: t }, backend }));
            const imags = $inputs.map((t) => imag$1({ inputs: { input: t }, backend }));
            const realConcated = concat$1({ inputs: reals, backend, attrs: { axis: $axis } });
            const imagConcated = concat$1({ inputs: imags, backend, attrs: { axis: $axis } });
            const result = complex$1({ inputs: { real: realConcated, imag: imagConcated }, backend });
            // Release all intermediates; `result` holds its own copies.
            reals.forEach(r => backend.disposeIntermediateTensorInfo(r));
            imags.forEach(i => backend.disposeIntermediateTensorInfo(i));
            backend.disposeIntermediateTensorInfo(realConcated);
            backend.disposeIntermediateTensorInfo(imagConcated);
            return result;
        }
        // Any concat of n-dimensional tensors across any axis can be reduced to
        // a concatenation of two-dimensional tensors across the axis 1 by first
        // partitioning the axes of the original tensors into those less than the
        // axis to be concatenated and the rest. Then reshape the tensors
        // into a two-dimensional tensor by collapsing these two sets of axes and
        // concatenate the resulting matrices across the axis 1, finally reshaping
        // the result to have the proper shape.
        const inputs2D = $inputs.map(t => {
            const innerSize = sizeFromShape(t.shape.slice($axis));
            const shape = [-1, innerSize];
            return reshape$2({ inputs: { x: t }, backend, attrs: { shape } });
        });
        const inputsValShapes = inputs2D.map(t => {
            return { vals: backend.data.get(t.dataId).values, shape: t.shape };
        });
        // Concats 2d tensors along axis=1.
        outShape =
            computeOutShape$1(inputs2D.map(t => t.shape), 1 /* axis */);
        const simplyConcat = inputs2D[0].shape[0] === 1;
        const outVals = concatImpl(inputsValShapes, outShape, inputs[0].dtype, simplyConcat);
        const finalOutShape = computeOutShape$1($inputs.map(t => t.shape), $axis);
        const outInfo = backend.makeTensorInfo(finalOutShape, inputs[0].dtype, outVals);
        // The reshaped 2-D views were intermediates; release them.
        inputs2D.forEach(t => backend.disposeIntermediateTensorInfo(t));
        return outInfo;
    }
    /** Registration of the Concat kernel for the CPU backend. */
    const concatConfig = {
        kernelName: Concat,
        backendName: 'cpu',
        kernelFunc: concat$1
    };
73060
73061 /**
73062 * @license
73063 * Copyright 2020 Google LLC. All Rights Reserved.
73064 * Licensed under the Apache License, Version 2.0 (the "License");
73065 * you may not use this file except in compliance with the License.
73066 * You may obtain a copy of the License at
73067 *
73068 * http://www.apache.org/licenses/LICENSE-2.0
73069 *
73070 * Unless required by applicable law or agreed to in writing, software
73071 * distributed under the License is distributed on an "AS IS" BASIS,
73072 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73073 * See the License for the specific language governing permissions and
73074 * limitations under the License.
73075 * =============================================================================
73076 */
    /**
     * CPU kernel for Conv2D: direct (non-im2col) 2-D convolution.
     *
     * Supports both 'channelsLast' (NHWC) and 'channelsFirst' (NCHW) layouts
     * by selecting layout-dependent strides up front. Filter taps that fall
     * outside the (implicitly zero-padded) input are skipped via `continue`.
     */
    function conv2D(args) {
        const { inputs, backend, attrs } = args;
        const { x, filter } = inputs;
        const { strides, pad, dataFormat, dilations, dimRoundingMode } = attrs;
        // Complex inputs are not supported by this kernel.
        assertNotComplex([x, filter], 'conv2d');
        const $dataFormat = convertConv2DDataFormat(dataFormat);
        const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
        const filterHeight = convInfo.filterHeight;
        const filterWidth = convInfo.filterWidth;
        const dilationHeight = convInfo.dilationHeight;
        const dilationWidth = convInfo.dilationWidth;
        const padLeft = convInfo.padInfo.left;
        const padTop = convInfo.padInfo.top;
        const isChannelsLast = convInfo.dataFormat === 'channelsLast';
        // Output buffer; its values start zeroed, so `+=` below accumulates
        // correctly.
        const y = new TensorBuffer(convInfo.outShape, x.dtype);
        const xStrides = computeStrides(x.shape);
        const filterStrides = computeStrides(filter.shape);
        // Layout-dependent flat-index strides for input and output.
        const xBatchStride = xStrides[0];
        const xRowStride = isChannelsLast ? xStrides[1] : xStrides[2];
        const xColStride = isChannelsLast ? xStrides[2] : 1;
        const xChannelStride = isChannelsLast ? 1 : xStrides[1];
        const yBatchStride = y.strides[0];
        const yRowStride = isChannelsLast ? y.strides[1] : y.strides[2];
        const yColStride = isChannelsLast ? y.strides[2] : 1;
        const yChannelStride = isChannelsLast ? 1 : y.strides[1];
        const xVals = backend.data.get(x.dataId).values;
        const wVals = backend.data.get(filter.dataId).values;
        const yVals = y.values;
        // Loop order: batch -> out row -> filter row -> out col -> filter col
        // -> input channel -> output channel.
        for (let b = 0; b < convInfo.batchSize; ++b) {
            const xOffset1 = b * xBatchStride;
            const yOffset1 = b * yBatchStride;
            for (let yR = 0; yR < convInfo.outHeight; ++yR) {
                const yOffset2 = yOffset1 + yR * yRowStride;
                const xRCorner = yR * convInfo.strideHeight - padTop;
                for (let wR = 0; wR < filterHeight; ++wR) {
                    const xR = xRCorner + wR * dilationHeight;
                    if (xR < 0 || xR >= convInfo.inHeight) {
                        // Tap lies in the zero padding; contributes nothing.
                        continue;
                    }
                    const wOffset1 = wR * filterStrides[0];
                    const xOffset2 = xOffset1 + xR * xRowStride;
                    for (let yC = 0; yC < convInfo.outWidth; ++yC) {
                        const yOffset3 = yOffset2 + yC * yColStride;
                        const xCCorner = yC * convInfo.strideWidth - padLeft;
                        for (let wC = 0; wC < filterWidth; ++wC) {
                            const xC = xCCorner + wC * dilationWidth;
                            if (xC < 0 || xC >= convInfo.inWidth) {
                                continue;
                            }
                            const wOffset2 = wOffset1 + wC * filterStrides[1];
                            const xOffset3 = xOffset2 + xC * xColStride;
                            let wOffset3 = wOffset2;
                            for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
                                const xVal = xVals[xOffset3 + d1 * xChannelStride];
                                for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
                                    yVals[yOffset3 + d2 * yChannelStride] +=
                                        xVal * wVals[wOffset3 + d2];
                                }
                                // Advance to the next input channel's weights.
                                wOffset3 += convInfo.outChannels;
                            }
                        }
                    }
                }
            }
        }
        return backend.makeTensorInfo(y.shape, y.dtype, yVals);
    }
    /** Registration of the Conv2D kernel for the CPU backend. */
    const conv2DConfig = {
        kernelName: Conv2D,
        backendName: 'cpu',
        kernelFunc: conv2D
    };
73149
73150 /**
73151 * @license
73152 * Copyright 2020 Google LLC. All Rights Reserved.
73153 * Licensed under the Apache License, Version 2.0 (the "License");
73154 * you may not use this file except in compliance with the License.
73155 * You may obtain a copy of the License at
73156 *
73157 * http://www.apache.org/licenses/LICENSE-2.0
73158 *
73159 * Unless required by applicable law or agreed to in writing, software
73160 * distributed under the License is distributed on an "AS IS" BASIS,
73161 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73162 * See the License for the specific language governing permissions and
73163 * limitations under the License.
73164 * =============================================================================
73165 */
    /**
     * CPU kernel for Conv2DBackpropFilter: gradient of Conv2D with respect to
     * the filter. For each filter tap (wR, wC) and channel pair (d1, d2) it
     * accumulates x * dy over all batches and all output positions whose
     * receptive field covers that tap.
     */
    function conv2DBackpropFilter$1(args) {
        const { inputs, backend, attrs } = args;
        const { x, dy } = inputs;
        const { strides, pad, dataFormat, dimRoundingMode, filterShape } = attrs;
        assertNotComplex([x, dy], 'conv2dBackpropFilter');
        const $dataFormat = convertConv2DDataFormat(dataFormat);
        const convInfo = computeConv2DInfo(x.shape, filterShape, strides, 1 /* dilations */, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
        const { strideHeight, strideWidth, filterHeight, filterWidth } = convInfo;
        const isChannelsLast = convInfo.dataFormat === 'channelsLast';
        // Filter gradient buffer; always float32.
        const dW = new TensorBuffer(convInfo.filterShape, 'float32');
        const leftPad = convInfo.padInfo.left;
        const topPad = convInfo.padInfo.top;
        const xVals = backend.data.get(x.dataId).values;
        const dyVals = backend.data.get(dy.dataId).values;
        const xBuf = new TensorBuffer(x.shape, x.dtype, xVals);
        const dyBuf = new TensorBuffer(dy.shape, dy.dtype, dyVals);
        for (let wR = 0; wR < filterHeight; ++wR) {
            // Valid output-row range for this filter row (keeps xR in bounds).
            const yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
            const yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);
            for (let wC = 0; wC < filterWidth; ++wC) {
                // Valid output-column range for this filter column.
                const yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));
                const yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);
                for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
                    for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
                        let dotProd = 0;
                        for (let b = 0; b < convInfo.batchSize; ++b) {
                            for (let yR = yRMin; yR < yRMax; ++yR) {
                                const xR = wR + yR * strideHeight - topPad;
                                for (let yC = yCMin; yC < yCMax; ++yC) {
                                    const xC = wC + yC * strideWidth - leftPad;
                                    if (isChannelsLast) {
                                        dotProd += xBuf.get(b, xR, xC, d1) *
                                            dyBuf.get(b, yR, yC, d2);
                                    }
                                    else {
                                        dotProd += xBuf.get(b, d1, xR, xC) *
                                            dyBuf.get(b, d2, yR, yC);
                                    }
                                }
                            }
                        }
                        dW.set(dotProd, wR, wC, d1, d2);
                    }
                }
            }
        }
        return backend.makeTensorInfo(dW.shape, dW.dtype, dW.values);
    }
    /** Registration of the Conv2DBackpropFilter kernel for the CPU backend. */
    const conv2DBackpropFilterConfig = {
        kernelName: Conv2DBackpropFilter,
        backendName: 'cpu',
        kernelFunc: conv2DBackpropFilter$1
    };
73219
73220 /**
73221 * @license
73222 * Copyright 2020 Google LLC. All Rights Reserved.
73223 * Licensed under the Apache License, Version 2.0 (the "License");
73224 * you may not use this file except in compliance with the License.
73225 * You may obtain a copy of the License at
73226 *
73227 * http://www.apache.org/licenses/LICENSE-2.0
73228 *
73229 * Unless required by applicable law or agreed to in writing, software
73230 * distributed under the License is distributed on an "AS IS" BASIS,
73231 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73232 * See the License for the specific language governing permissions and
73233 * limitations under the License.
73234 * =============================================================================
73235 */
    /**
     * CPU kernel for Conv2DBackpropInput: gradient of Conv2D with respect to
     * the input. Computed as a correlation of dy with the filter flipped in
     * both spatial dimensions (note the `filterHeight - 1 - wR` /
     * `filterWidth - 1 - wC` indexing below).
     */
    function conv2DBackpropInput$1(args) {
        const { inputs, backend, attrs } = args;
        const { dy, filter } = inputs;
        const { inputShape, strides, pad, dataFormat, dimRoundingMode } = attrs;
        assertNotComplex([dy, filter], 'conv2dBackpropInput');
        const filterStrides = computeStrides(filter.shape);
        const dyStrides = computeStrides(dy.shape);
        let $dataFormat = convertConv2DDataFormat(dataFormat);
        const convInfo = computeConv2DInfo(inputShape, filter.shape, strides, 1 /* dilations */, pad, dimRoundingMode, false, $dataFormat);
        // Input-gradient buffer; always float32.
        const dx = new TensorBuffer(convInfo.inShape, 'float32');
        const dxValues = dx.values;
        const dyValues = backend.data.get(dy.dataId).values;
        const fltValues = backend.data.get(filter.dataId).values;
        const [fltS0, fltS1, fltS2] = filterStrides;
        const { batchSize, filterHeight, filterWidth, inChannels, inHeight, inWidth, outChannels, outHeight, outWidth, strideHeight, strideWidth } = convInfo;
        $dataFormat = convInfo.dataFormat;
        // "Transposed" padding for the gradient pass.
        const topPad = filterHeight - 1 - convInfo.padInfo.top;
        const leftPad = filterWidth - 1 - convInfo.padInfo.left;
        const isChannelsLast = $dataFormat === 'channelsLast';
        // Layout-dependent flat-index strides for dx and dy.
        const xBatchStride = dx.strides[0];
        const xRowStride = isChannelsLast ? dx.strides[1] : dx.strides[2];
        const xColStride = isChannelsLast ? dx.strides[2] : 1;
        const xChannelStride = isChannelsLast ? 1 : dx.strides[1];
        const yBatchStride = dyStrides[0];
        const yRowStride = isChannelsLast ? dyStrides[1] : dyStrides[2];
        const yColStride = isChannelsLast ? dyStrides[2] : 1;
        const yChannelStride = isChannelsLast ? 1 : dyStrides[1];
        for (let b = 0; b < batchSize; ++b) {
            for (let d1 = 0; d1 < inChannels; ++d1) {
                for (let xR = 0; xR < inHeight; ++xR) {
                    const xRCorner = xR - topPad;
                    // Output rows whose receptive field touches input row xR.
                    const xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));
                    const yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);
                    for (let xC = 0; xC < inWidth; ++xC) {
                        const xCCorner = xC - leftPad;
                        const xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));
                        const yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);
                        let dotProd = 0;
                        for (let yR = xRMin; yR < yRMax; ++yR) {
                            const wR = yR * strideHeight - xRCorner;
                            for (let yC = xCMin; yC < yCMax; ++yC) {
                                const wC = yC * strideWidth - xCCorner;
                                const dyOffset = yBatchStride * b + yRowStride * yR + yColStride * yC;
                                // 180-degree spatial flip of the filter.
                                const fltOffset = fltS0 * (filterHeight - 1 - wR) +
                                    fltS1 * (filterWidth - 1 - wC) + fltS2 * d1;
                                for (let d2 = 0; d2 < outChannels; ++d2) {
                                    const pixel = dyValues[dyOffset + yChannelStride * d2];
                                    const weight = fltValues[fltOffset + d2];
                                    dotProd += pixel * weight;
                                }
                            }
                        }
                        const dxOffset = xBatchStride * b + xRowStride * xR +
                            xColStride * xC + xChannelStride * d1;
                        dxValues[dxOffset] = dotProd;
                    }
                }
            }
        }
        return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
    }
    /** Registration of the Conv2DBackpropInput kernel for the CPU backend. */
    const conv2DBackpropInputConfig = {
        kernelName: Conv2DBackpropInput,
        backendName: 'cpu',
        kernelFunc: conv2DBackpropInput$1
    };
73302
73303 /**
73304 * @license
73305 * Copyright 2020 Google LLC. All Rights Reserved.
73306 * Licensed under the Apache License, Version 2.0 (the "License");
73307 * you may not use this file except in compliance with the License.
73308 * You may obtain a copy of the License at
73309 *
73310 * http://www.apache.org/licenses/LICENSE-2.0
73311 *
73312 * Unless required by applicable law or agreed to in writing, software
73313 * distributed under the License is distributed on an "AS IS" BASIS,
73314 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73315 * See the License for the specific language governing permissions and
73316 * limitations under the License.
73317 * =============================================================================
73318 */
    /**
     * CPU kernel for Conv3D: direct 3-D convolution over
     * (depth, height, width) with per-axis strides and dilations. Filter taps
     * falling outside the (implicitly zero-padded) input are skipped.
     * Indexing assumes the channels-last memory layout produced by
     * computeConv3DInfo.
     */
    function conv3D(args) {
        const { inputs, backend, attrs } = args;
        const { x, filter } = inputs;
        const { strides, pad, dilations } = attrs;
        assertNotComplex([x, filter], 'conv3d');
        const convInfo = computeConv3DInfo(x.shape, filter.shape, strides, dilations, pad);
        const { filterDepth, filterHeight, filterWidth, dilationDepth, dilationHeight, dilationWidth, padInfo } = convInfo;
        const padFront = padInfo.front;
        const padLeft = padInfo.left;
        const padTop = padInfo.top;
        // Output buffer; starts zeroed so `+=` accumulates correctly.
        const y = new TensorBuffer(convInfo.outShape, x.dtype);
        const xVals = backend.data.get(x.dataId).values;
        const wVals = backend.data.get(filter.dataId).values;
        const yVals = y.values;
        const xStrides = computeStrides(x.shape);
        const filterStrides = computeStrides(filter.shape);
        // Loop order: batch -> out depth -> filter depth -> out row ->
        // filter row -> out col -> filter col -> in channel -> out channel.
        for (let b = 0; b < convInfo.batchSize; ++b) {
            const xOffset1 = b * xStrides[0];
            const yOffset1 = b * y.strides[0];
            for (let yF = 0; yF < convInfo.outDepth; ++yF) {
                const yOffset2 = yOffset1 + yF * y.strides[1];
                const xFCorner = yF * convInfo.strideDepth - padFront;
                for (let wF = 0; wF < filterDepth; ++wF) {
                    const xF = xFCorner + wF * dilationDepth;
                    if (xF < 0 || xF >= convInfo.inDepth) {
                        // Tap lies in the zero padding; contributes nothing.
                        continue;
                    }
                    const wOffset1 = wF * filterStrides[0];
                    const xOffset2 = xOffset1 + xF * xStrides[1];
                    for (let yR = 0; yR < convInfo.outHeight; ++yR) {
                        const yOffset3 = yOffset2 + yR * y.strides[2];
                        const xRCorner = yR * convInfo.strideHeight - padTop;
                        for (let wR = 0; wR < filterHeight; ++wR) {
                            const xR = xRCorner + wR * dilationHeight;
                            if (xR < 0 || xR >= convInfo.inHeight) {
                                continue;
                            }
                            const wOffset2 = wOffset1 + wR * filterStrides[1];
                            const xOffset3 = xOffset2 + xR * xStrides[2];
                            for (let yC = 0; yC < convInfo.outWidth; ++yC) {
                                const yOffset4 = yOffset3 + yC * convInfo.outChannels;
                                const xCCorner = yC * convInfo.strideWidth - padLeft;
                                for (let wC = 0; wC < filterWidth; ++wC) {
                                    const xC = xCCorner + wC * dilationWidth;
                                    if (xC < 0 || xC >= convInfo.inWidth) {
                                        continue;
                                    }
                                    const wOffset3 = wOffset2 + wC * filterStrides[2];
                                    const xOffset4 = xOffset3 + xC * convInfo.inChannels;
                                    let wOffset4 = wOffset3;
                                    for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
                                        const xVal = xVals[xOffset4 + d1];
                                        for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
                                            yVals[yOffset4 + d2] += xVal * wVals[wOffset4 + d2];
                                        }
                                        // Next input channel's weight row.
                                        wOffset4 += convInfo.outChannels;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
        return backend.makeTensorInfo(y.shape, y.dtype, y.values);
    }
    /** Registration of the Conv3D kernel for the CPU backend. */
    const conv3DConfig = {
        kernelName: Conv3D,
        backendName: 'cpu',
        kernelFunc: conv3D
    };
73390
73391 /**
73392 * @license
73393 * Copyright 2020 Google LLC. All Rights Reserved.
73394 * Licensed under the Apache License, Version 2.0 (the "License");
73395 * you may not use this file except in compliance with the License.
73396 * You may obtain a copy of the License at
73397 *
73398 * http://www.apache.org/licenses/LICENSE-2.0
73399 *
73400 * Unless required by applicable law or agreed to in writing, software
73401 * distributed under the License is distributed on an "AS IS" BASIS,
73402 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73403 * See the License for the specific language governing permissions and
73404 * limitations under the License.
73405 * =============================================================================
73406 */
    /**
     * CPU kernel for Conv3DBackpropFilterV2: gradient of Conv3D with respect
     * to the filter. For each filter tap (wF, wR, wC) and channel pair
     * (d1, d2) it accumulates x * dy over all batches and all valid output
     * positions.
     */
    function conv3DBackpropFilterV2(args) {
        const { inputs, backend, attrs } = args;
        const { x, dy } = inputs;
        const { strides, pad, filterShape } = attrs;
        assertNotComplex([x, dy], 'conv3dBackpropFilterV2');
        const xStrides = computeStrides(x.shape);
        const dyStrides = computeStrides(dy.shape);
        const convInfo = computeConv3DInfo(x.shape, filterShape, strides, 1 /* dilations */, pad);
        const strideDepth = convInfo.strideDepth;
        const strideHeight = convInfo.strideHeight;
        const strideWidth = convInfo.strideWidth;
        const filterDepth = convInfo.filterDepth;
        const filterHeight = convInfo.filterHeight;
        const filterWidth = convInfo.filterWidth;
        // Filter gradient buffer; always float32.
        const dw = new TensorBuffer(convInfo.filterShape, 'float32');
        const dwValues = dw.values;
        const [dwS0, dwS1, dwS2, dwS3] = dw.strides;
        const dyValues = backend.data.get(dy.dataId).values;
        const [dyS0, dyS1, dyS2, dyS3] = dyStrides;
        const xValues = backend.data.get(x.dataId).values;
        const [xS0, xS1, xS2, xS3] = xStrides;
        const frontPad = convInfo.padInfo.front;
        const leftPad = convInfo.padInfo.left;
        const topPad = convInfo.padInfo.top;
        for (let wF = 0; wF < filterDepth; ++wF) {
            // Valid output-frame range for this filter frame (keeps xF in bounds).
            const yFMin = Math.max(0, Math.ceil((frontPad - wF) / strideDepth));
            const yFMax = Math.min(convInfo.outDepth, (convInfo.inDepth + frontPad - wF) / strideDepth);
            const wOffset1 = wF * dwS0;
            for (let wR = 0; wR < filterHeight; ++wR) {
                const yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
                const yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);
                const wOffset2 = wR * dwS1 + wOffset1;
                for (let wC = 0; wC < filterWidth; ++wC) {
                    const yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));
                    const yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);
                    const wOffset3 = wC * dwS2 + wOffset2;
                    for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
                        const wOffset4 = d1 * dwS3 + wOffset3;
                        for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
                            let dotProd = 0;
                            for (let b = 0; b < convInfo.batchSize; ++b) {
                                const xOffset1 = b * xS0;
                                const yOffset1 = b * dyS0;
                                for (let yF = yFMin; yF < yFMax; ++yF) {
                                    const xF = wF + yF * strideDepth - frontPad;
                                    const xOffset2 = xF * xS1 + xOffset1;
                                    const yOffset2 = yF * dyS1 + yOffset1;
                                    for (let yR = yRMin; yR < yRMax; ++yR) {
                                        const xR = wR + yR * strideHeight - topPad;
                                        const xOffset3 = xR * xS2 + xOffset2;
                                        const yOffset3 = yR * dyS2 + yOffset2;
                                        for (let yC = yCMin; yC < yCMax; ++yC) {
                                            const xC = wC + yC * strideWidth - leftPad;
                                            const xOffset4 = xC * xS3 + xOffset3;
                                            const yOffset4 = yC * dyS3 + yOffset3;
                                            dotProd += xValues[xOffset4 + d1] * dyValues[yOffset4 + d2];
                                        }
                                    }
                                }
                            }
                            dwValues[wOffset4 + d2] = dotProd;
                        }
                    }
                }
            }
        }
        return backend.makeTensorInfo(dw.shape, dw.dtype, dw.values);
    }
    /** Registration of the Conv3DBackpropFilterV2 kernel for the CPU backend. */
    const conv3DBackpropFilterV2Config = {
        kernelName: Conv3DBackpropFilterV2,
        backendName: 'cpu',
        kernelFunc: conv3DBackpropFilterV2
    };
73480
73481 /**
73482 * @license
73483 * Copyright 2020 Google LLC. All Rights Reserved.
73484 * Licensed under the Apache License, Version 2.0 (the "License");
73485 * you may not use this file except in compliance with the License.
73486 * You may obtain a copy of the License at
73487 *
73488 * http://www.apache.org/licenses/LICENSE-2.0
73489 *
73490 * Unless required by applicable law or agreed to in writing, software
73491 * distributed under the License is distributed on an "AS IS" BASIS,
73492 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73493 * See the License for the specific language governing permissions and
73494 * limitations under the License.
73495 * =============================================================================
73496 */
73497 function conv3DBackpropInputV2(args) {
73498 const { inputs, backend, attrs } = args;
73499 const { dy, filter } = inputs;
73500 const { pad, strides, inputShape } = attrs;
73501 assertNotComplex([dy], 'conv3dBackpropInputV2');
73502 const dyStrides = computeStrides(dy.shape);
73503 const filterStrides = computeStrides(filter.shape);
73504 const convInfo = computeConv3DInfo(inputShape, filter.shape, strides, 1 /* dilations */, pad);
73505 const dx = new TensorBuffer(convInfo.inShape, 'float32');
73506 const dxValues = dx.values;
73507 const [dxS0, dxS1, dxS2, dxS3] = dx.strides;
73508 const dyValues = backend.data.get(dy.dataId).values;
73509 const [dyS0, dyS1, dyS2, dyS3] = dyStrides;
73510 const fltValues = backend.data.get(filter.dataId).values;
73511 const [fltS0, fltS1, fltS2, fltS3] = filterStrides;
73512 const { batchSize, filterDepth, filterHeight, filterWidth, inChannels, inDepth, inHeight, inWidth, outChannels, outDepth, outHeight, outWidth, strideDepth, strideHeight, strideWidth } = convInfo;
73513 const frontPad = filterDepth - 1 - convInfo.padInfo.front;
73514 const topPad = filterHeight - 1 - convInfo.padInfo.top;
73515 const leftPad = filterWidth - 1 - convInfo.padInfo.left;
73516 for (let b = 0; b < batchSize; ++b) {
73517 for (let d1 = 0; d1 < inChannels; ++d1) {
73518 // Frames of depth
73519 for (let xF = 0; xF < inDepth; ++xF) {
73520 const xFCorner = xF - frontPad;
73521 const xFMin = Math.max(0, Math.ceil(xFCorner / strideDepth));
73522 const yFMax = Math.min(outDepth, (filterDepth + xFCorner) / strideDepth);
73523 // Rows as per standard 2d matrix notation
73524 for (let xR = 0; xR < inHeight; ++xR) {
73525 const xRCorner = xR - topPad;
73526 const xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));
73527 const yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);
73528 // Columns as per standard 2d matrix notation
73529 for (let xC = 0; xC < inWidth; ++xC) {
73530 const xCCorner = xC - leftPad;
73531 const xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));
73532 const yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);
73533 let dotProd = 0;
73534 for (let yF = xFMin; yF < yFMax; ++yF) {
73535 const wF = yF * strideDepth - xFCorner;
73536 for (let yR = xRMin; yR < yRMax; ++yR) {
73537 const wR = yR * strideHeight - xRCorner;
73538 for (let yC = xCMin; yC < yCMax; ++yC) {
73539 const wC = yC * strideWidth - xCCorner;
73540 const dyOffset = dyS0 * b + dyS1 * yF + dyS2 * yR + dyS3 * yC;
73541 const fltOffset = fltS0 * (filterDepth - 1 - wF) +
73542 fltS1 * (filterHeight - 1 - wR) +
73543 fltS2 * (filterWidth - 1 - wC) + fltS3 * d1;
73544 for (let d2 = 0; d2 < outChannels; ++d2) {
73545 const pixel = dyValues[dyOffset + d2];
73546 const weight = fltValues[fltOffset + d2];
73547 dotProd += pixel * weight;
73548 }
73549 }
73550 }
73551 }
73552 dxValues[dxS0 * b + dxS1 * xF + dxS2 * xR + dxS3 * xC + d1] =
73553 dotProd;
73554 }
73555 }
73556 }
73557 }
73558 }
73559 return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
73560 }
    // Kernel registration: binds the Conv3DBackpropInputV2 kernel (input
    // gradient of 3D convolution) to the CPU backend.
    const conv3DBackpropInputV2Config = {
        kernelName: Conv3DBackpropInputV2,
        backendName: 'cpu',
        kernelFunc: conv3DBackpropInputV2
    };
73566
73567 /**
73568 * @license
73569 * Copyright 2020 Google LLC. All Rights Reserved.
73570 * Licensed under the Apache License, Version 2.0 (the "License");
73571 * you may not use this file except in compliance with the License.
73572 * You may obtain a copy of the License at
73573 *
73574 * http://www.apache.org/licenses/LICENSE-2.0
73575 *
73576 * Unless required by applicable law or agreed to in writing, software
73577 * distributed under the License is distributed on an "AS IS" BASIS,
73578 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73579 * See the License for the specific language governing permissions and
73580 * limitations under the License.
73581 * =============================================================================
73582 */
73583 const cos$1 = unaryKernelFunc(Cos, (xi) => Math.cos(xi));
73584 const cosConfig = {
73585 kernelName: Cos,
73586 backendName: 'cpu',
73587 kernelFunc: cos$1,
73588 };
73589
73590 /**
73591 * @license
73592 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
73594 * you may not use this file except in compliance with the License.
73595 * You may obtain a copy of the License at
73596 *
73597 * http://www.apache.org/licenses/LICENSE-2.0
73598 *
73599 * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
73601 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73602 * See the License for the specific language governing permissions and
73603 * limitations under the License.
73604 * =============================================================================
73605 */
73606 const cosh$1 = unaryKernelFunc(Cosh, (xi) => Math.cosh(xi));
73607 const coshConfig = {
73608 kernelName: Cosh,
73609 backendName: 'cpu',
73610 kernelFunc: cosh$1,
73611 };
73612
73613 /**
73614 * @license
73615 * Copyright 2020 Google LLC. All Rights Reserved.
73616 * Licensed under the Apache License, Version 2.0 (the "License");
73617 * you may not use this file except in compliance with the License.
73618 * You may obtain a copy of the License at
73619 *
73620 * http://www.apache.org/licenses/LICENSE-2.0
73621 *
73622 * Unless required by applicable law or agreed to in writing, software
73623 * distributed under the License is distributed on an "AS IS" BASIS,
73624 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73625 * See the License for the specific language governing permissions and
73626 * limitations under the License.
73627 * =============================================================================
73628 */
73629 function cropAndResize$1(args) {
73630 const { inputs, backend, attrs } = args;
73631 const { image, boxes, boxInd } = inputs;
73632 const { cropSize, method, extrapolationValue } = attrs;
73633 const [batch, imageHeight, imageWidth, numChannels] = image.shape;
73634 const numBoxes = boxes.shape[0];
73635 const [cropHeight, cropWidth] = cropSize;
73636 const output = buffer([numBoxes, cropHeight, cropWidth, numChannels], 'float32');
73637 const boxVals = backend.data.get(boxes.dataId).values;
73638 const boxIndVals = backend.data.get(boxInd.dataId).values;
73639 const imageVals = backend.data.get(image.dataId).values;
73640 const inStride = computeStrides(image.shape); // to calculate flat indexes into image
73641 const outStride = computeStrides(output.shape); // to calculate flat indexes into output
73642 // Reference implementation
73643 // tslint:disable-next-line:max-line-length
73644 // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/crop_and_resize_op.cc
73645 for (let b = 0; b < numBoxes; b++) {
73646 const startInd = b * 4;
73647 const y1 = boxVals[startInd];
73648 const x1 = boxVals[startInd + 1];
73649 const y2 = boxVals[startInd + 2];
73650 const x2 = boxVals[startInd + 3];
73651 const bInd = boxIndVals[b];
73652 if (bInd >= batch) {
73653 continue;
73654 }
73655 const heightScale = (cropHeight > 1) ? (y2 - y1) * (imageHeight - 1) / (cropHeight - 1) : 0;
73656 const widthScale = (cropWidth > 1) ? (x2 - x1) * (imageWidth - 1) / (cropWidth - 1) : 0;
73657 for (let y = 0; y < cropHeight; y++) {
73658 const yInd = (cropHeight > 1) ?
73659 y1 * (imageHeight - 1) + y * (heightScale) :
73660 0.5 * (y1 + y2) * (imageHeight - 1);
73661 if (yInd < 0 || yInd > imageHeight - 1) {
73662 for (let x = 0; x < cropWidth; x++) {
73663 for (let c = 0; c < numChannels; c++) {
73664 const ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
73665 output.values[ind] = extrapolationValue;
73666 }
73667 }
73668 continue;
73669 }
73670 if (method === 'bilinear') {
73671 const topInd = Math.floor(yInd);
73672 const bottomInd = Math.ceil(yInd);
73673 const yLerp = yInd - topInd;
73674 for (let x = 0; x < cropWidth; x++) {
73675 const xInd = (cropWidth > 1) ?
73676 x1 * (imageWidth - 1) + x * widthScale :
73677 0.5 * (x1 + x2) * (imageWidth - 1);
73678 if (xInd < 0 || xInd > imageWidth - 1) {
73679 for (let c = 0; c < numChannels; c++) {
73680 const ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
73681 output.values[ind] = extrapolationValue;
73682 }
73683 continue;
73684 }
73685 const leftInd = Math.floor(xInd);
73686 const rightInd = Math.ceil(xInd);
73687 const xLerp = xInd - leftInd;
73688 for (let c = 0; c < numChannels; c++) {
73689 let ind = c + leftInd * inStride[2] + topInd * inStride[1] +
73690 bInd * inStride[0];
73691 const topLeft = imageVals[ind];
73692 ind = c + rightInd * inStride[2] + topInd * inStride[1] +
73693 bInd * inStride[0];
73694 const topRight = imageVals[ind];
73695 ind = c + leftInd * inStride[2] + bottomInd * inStride[1] +
73696 bInd * inStride[0];
73697 const bottomLeft = imageVals[ind];
73698 ind = c + rightInd * inStride[2] + bottomInd * inStride[1] +
73699 bInd * inStride[0];
73700 const bottomRight = imageVals[ind];
73701 const top = topLeft + (topRight - topLeft) * xLerp;
73702 const bottom = bottomLeft + (bottomRight - bottomLeft) * xLerp;
73703 ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
73704 output.values[ind] = top + ((bottom - top) * yLerp);
73705 }
73706 }
73707 }
73708 else { // method == "nearest"
73709 for (let x = 0; x < cropWidth; ++x) {
73710 const xInd = (cropWidth > 1) ?
73711 x1 * (imageWidth - 1) + x * widthScale :
73712 0.5 * (x1 + x2) * (imageWidth - 1);
73713 if (xInd < 0 || xInd > imageWidth - 1) {
73714 for (let c = 0; c < numChannels; c++) {
73715 const ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
73716 output.values[ind] = extrapolationValue;
73717 }
73718 continue;
73719 }
73720 const closestX = Math.round(xInd);
73721 const closestY = Math.round(yInd);
73722 for (let c = 0; c < numChannels; c++) {
73723 const inInd = c + closestX * inStride[2] + closestY * inStride[1] +
73724 bInd * inStride[0];
73725 const outInd = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
73726 output.values[outInd] = imageVals[inInd];
73727 }
73728 }
73729 }
73730 }
73731 }
73732 return backend.makeTensorInfo(output.shape, output.dtype, output.values);
73733 }
    // Kernel registration: binds the CropAndResize kernel to the CPU backend.
    const cropAndResizeConfig = {
        kernelName: CropAndResize,
        backendName: 'cpu',
        kernelFunc: cropAndResize$1
    };
73739
73740 /**
73741 * @license
73742 * Copyright 2022 Google LLC. All Rights Reserved.
73743 * Licensed under the Apache License, Version 2.0 (the "License");
73744 * you may not use this file except in compliance with the License.
73745 * You may obtain a copy of the License at
73746 *
73747 * http://www.apache.org/licenses/LICENSE-2.0
73748 *
73749 * Unless required by applicable law or agreed to in writing, software
73750 * distributed under the License is distributed on an "AS IS" BASIS,
73751 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73752 * See the License for the specific language governing permissions and
73753 * limitations under the License.
73754 * =============================================================================
73755 */
73756 function cumprod$1(args) {
73757 const { inputs, backend, attrs } = args;
73758 const { x } = inputs;
73759 const { axis, exclusive, reverse } = attrs;
73760 assertNotComplex(x, 'cumprod');
73761 const permutation = getAxesPermutation([axis], x.shape.length);
73762 let $x = x;
73763 if (permutation != null) {
73764 $x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutation } });
73765 }
73766 const permutedAxis = getInnerMostAxes(1, x.shape.length)[0];
73767 if (permutedAxis !== $x.shape.length - 1) {
73768 throw new Error(`backend.cumprod in CPU expects an inner-most ` +
73769 `axis=${$x.shape.length - 1} but got axis=${permutedAxis}`);
73770 }
73771 const resultDtype = upcastType($x.dtype, 'int32');
73772 const vals = makeOnesTypedArray(sizeFromShape($x.shape), resultDtype);
73773 const aVals = backend.data.get($x.dataId).values;
73774 const finalDim = $x.shape[$x.shape.length - 1];
73775 const indexAdjuster = reverse ?
73776 (i, j) => i + finalDim - j - 1 :
73777 (i, j) => i + j;
73778 for (let i = 0; i < aVals.length; i += finalDim) {
73779 for (let j = 0; j < finalDim; j++) {
73780 const idx = indexAdjuster(i, j);
73781 if (j === 0) {
73782 vals[idx] = exclusive ? 1 : aVals[idx];
73783 }
73784 else {
73785 const prevIdx = indexAdjuster(i, j - 1);
73786 vals[idx] = exclusive ? aVals[prevIdx] * vals[prevIdx] :
73787 aVals[idx] * vals[prevIdx];
73788 }
73789 }
73790 }
73791 const result = backend.makeTensorInfo($x.shape, resultDtype, vals);
73792 if (permutation != null) {
73793 const reversePermutation = getUndoAxesPermutation(permutation);
73794 const reverseTransposedResult = transpose$1({ inputs: { x: result }, backend, attrs: { perm: reversePermutation } });
73795 backend.disposeIntermediateTensorInfo(result);
73796 backend.disposeIntermediateTensorInfo($x);
73797 return reverseTransposedResult;
73798 }
73799 return result;
73800 }
    // Kernel registration: binds the Cumprod kernel to the CPU backend.
    const cumprodConfig = {
        kernelName: Cumprod,
        backendName: 'cpu',
        kernelFunc: cumprod$1
    };
73806
73807 /**
73808 * @license
73809 * Copyright 2020 Google LLC. All Rights Reserved.
73810 * Licensed under the Apache License, Version 2.0 (the "License");
73811 * you may not use this file except in compliance with the License.
73812 * You may obtain a copy of the License at
73813 *
73814 * http://www.apache.org/licenses/LICENSE-2.0
73815 *
73816 * Unless required by applicable law or agreed to in writing, software
73817 * distributed under the License is distributed on an "AS IS" BASIS,
73818 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73819 * See the License for the specific language governing permissions and
73820 * limitations under the License.
73821 * =============================================================================
73822 */
73823 function cumsum$1(args) {
73824 const { inputs, backend, attrs } = args;
73825 const { x } = inputs;
73826 const { axis, exclusive, reverse } = attrs;
73827 assertNotComplex(x, 'cumsum');
73828 const permutation = getAxesPermutation([axis], x.shape.length);
73829 let $x = x;
73830 if (permutation != null) {
73831 $x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutation } });
73832 }
73833 const permutedAxis = getInnerMostAxes(1, x.shape.length)[0];
73834 if (permutedAxis !== $x.shape.length - 1) {
73835 throw new Error(`backend.cumsum in CPU expects an inner-most ` +
73836 `axis=${$x.shape.length - 1} but got axis=${permutedAxis}`);
73837 }
73838 const resultDtype = upcastType($x.dtype, 'int32');
73839 const vals = makeZerosTypedArray(sizeFromShape($x.shape), resultDtype);
73840 const aVals = backend.data.get($x.dataId).values;
73841 const finalDim = $x.shape[$x.shape.length - 1];
73842 const indexAdjuster = reverse ?
73843 (i, j) => i + finalDim - j - 1 :
73844 (i, j) => i + j;
73845 for (let i = 0; i < aVals.length; i += finalDim) {
73846 for (let j = 0; j < finalDim; j++) {
73847 const idx = indexAdjuster(i, j);
73848 if (j === 0) {
73849 vals[idx] = exclusive ? 0 : aVals[idx];
73850 }
73851 else {
73852 const prevIdx = indexAdjuster(i, j - 1);
73853 vals[idx] = exclusive ? aVals[prevIdx] + vals[prevIdx] :
73854 aVals[idx] + vals[prevIdx];
73855 }
73856 }
73857 }
73858 const result = backend.makeTensorInfo($x.shape, resultDtype, vals);
73859 if (permutation != null) {
73860 const reversePermutation = getUndoAxesPermutation(permutation);
73861 const reverseTransposedResult = transpose$1({ inputs: { x: result }, backend, attrs: { perm: reversePermutation } });
73862 backend.disposeIntermediateTensorInfo(result);
73863 backend.disposeIntermediateTensorInfo($x);
73864 return reverseTransposedResult;
73865 }
73866 return result;
73867 }
    // Kernel registration: binds the Cumsum kernel to the CPU backend.
    const cumsumConfig = {
        kernelName: Cumsum,
        backendName: 'cpu',
        kernelFunc: cumsum$1
    };
73873
73874 /**
73875 * @license
73876 * Copyright 2020 Google LLC. All Rights Reserved.
73877 * Licensed under the Apache License, Version 2.0 (the "License");
73878 * you may not use this file except in compliance with the License.
73879 * You may obtain a copy of the License at
73880 *
73881 * http://www.apache.org/licenses/LICENSE-2.0
73882 *
73883 * Unless required by applicable law or agreed to in writing, software
73884 * distributed under the License is distributed on an "AS IS" BASIS,
73885 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73886 * See the License for the specific language governing permissions and
73887 * limitations under the License.
73888 * =============================================================================
73889 */
73890 function denseBincount$1(args) {
73891 const { inputs, backend, attrs } = args;
73892 const { x, weights } = inputs;
73893 const { size, binaryOutput } = attrs;
73894 if (x.shape.length === 1) {
73895 const xVals = backend.data.get(x.dataId).values;
73896 const weightsVals = backend.data.get(weights.dataId).values;
73897 const outVals = bincountImpl(xVals, weightsVals, weights.dtype, weights.shape, size);
73898 return backend.makeTensorInfo([size], weights.dtype, outVals);
73899 }
73900 else if (x.shape.length === 2) {
73901 const xBuf = backend.bufferSync(x);
73902 const weightsBuf = backend.bufferSync(weights);
73903 const outBuf = bincountReduceImpl(xBuf, weightsBuf, size, binaryOutput);
73904 return backend.makeTensorInfo(outBuf.shape, weights.dtype, outBuf.values);
73905 }
73906 throw new Error(`Error in denseBincount: input must be at most rank 2, but got rank` +
73907 `${x.shape.length}.`);
73908 }
    // Kernel registration: binds the DenseBincount kernel to the CPU backend.
    const denseBincountConfig = {
        kernelName: DenseBincount,
        backendName: 'cpu',
        kernelFunc: denseBincount$1
    };
73914
73915 /**
73916 * @license
73917 * Copyright 2020 Google LLC. All Rights Reserved.
73918 * Licensed under the Apache License, Version 2.0 (the "License");
73919 * you may not use this file except in compliance with the License.
73920 * You may obtain a copy of the License at
73921 *
73922 * http://www.apache.org/licenses/LICENSE-2.0
73923 *
73924 * Unless required by applicable law or agreed to in writing, software
73925 * distributed under the License is distributed on an "AS IS" BASIS,
73926 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73927 * See the License for the specific language governing permissions and
73928 * limitations under the License.
73929 * =============================================================================
73930 */
73931 function depthToSpace$1(args) {
73932 const { inputs, backend, attrs } = args;
73933 const { x } = inputs;
73934 const { blockSize, dataFormat } = attrs;
73935 assert(dataFormat === 'NHWC', () => `Only NHWC dataFormat supported on CPU for depthToSpace. Got ${dataFormat}`);
73936 const batchSize = x.shape[0];
73937 const inputHeight = x.shape[1];
73938 const inputWidth = x.shape[2];
73939 const inputDepth = x.shape[3];
73940 const outputHeight = inputHeight * blockSize;
73941 const outputWidth = inputWidth * blockSize;
73942 const outputDepth = inputDepth / (blockSize * blockSize);
73943 const xValues = backend.data.get(x.dataId).values;
73944 const result = new Float32Array(batchSize * outputHeight * outputWidth * outputDepth);
73945 let outputIdx = 0;
73946 for (let b = 0; b < batchSize; ++b) {
73947 for (let h = 0; h < outputHeight; ++h) {
73948 const inH = Math.floor(h / blockSize);
73949 const offsetH = (h % blockSize);
73950 for (let w = 0; w < outputWidth; ++w) {
73951 const inW = Math.floor(w / blockSize);
73952 const offsetW = (w % blockSize);
73953 const offsetD = (offsetH * blockSize + offsetW) * outputDepth;
73954 for (let d = 0; d < outputDepth; ++d) {
73955 const inD = d + offsetD;
73956 const inputIdx = inD + inputDepth * (inW + inputWidth * (inH + inputHeight * b));
73957 result[outputIdx++] = xValues[inputIdx];
73958 }
73959 }
73960 }
73961 }
73962 return backend.makeTensorInfo([batchSize, outputHeight, outputWidth, outputDepth], x.dtype, result);
73963 }
    // Kernel registration: binds the DepthToSpace kernel to the CPU backend.
    const depthToSpaceConfig = {
        kernelName: DepthToSpace,
        backendName: 'cpu',
        kernelFunc: depthToSpace$1
    };
73969
73970 /**
73971 * @license
73972 * Copyright 2020 Google LLC. All Rights Reserved.
73973 * Licensed under the Apache License, Version 2.0 (the "License");
73974 * you may not use this file except in compliance with the License.
73975 * You may obtain a copy of the License at
73976 *
73977 * http://www.apache.org/licenses/LICENSE-2.0
73978 *
73979 * Unless required by applicable law or agreed to in writing, software
73980 * distributed under the License is distributed on an "AS IS" BASIS,
73981 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73982 * See the License for the specific language governing permissions and
73983 * limitations under the License.
73984 * =============================================================================
73985 */
    /**
     * CPU kernel for DepthwiseConv2dNative (NHWC layout).
     *
     * Each input channel d1 is convolved with its own set of
     * chMul = outChannels / inChannels filters, writing output channel
     * d1 * chMul + q.
     *
     * inputs: x [batch, inH, inW, inChannels],
     *         filter (HWC-multiplier layout, per computeConv2DInfo).
     * attrs:  strides, pad, dilations (defaults to [1, 1] when null),
     *         dimRoundingMode. Either strides or dilations must be 1.
     * Returns a tensor of shape convInfo.outShape.
     */
    function depthwiseConv2dNative(args) {
        const { inputs, backend, attrs } = args;
        const { x, filter } = inputs;
        const { strides, pad, dilations, dimRoundingMode } = attrs;
        // Complex dtypes are not supported by this kernel.
        assertNotComplex([x, filter], 'depthwiseConv2DNative');
        const xStrides = computeStrides(x.shape);
        const filterStrides = computeStrides(filter.shape);
        let $dilations = dilations;
        if ($dilations == null) {
            $dilations = [1, 1];
        }
        assert(eitherStridesOrDilationsAreOne(strides, $dilations), () => 'Error in depthwiseConv2d: Either strides or dilations must be ' +
            `1. Got strides ${strides} and dilations '${$dilations}'`);
        const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true /* depthwise */);
        const { filterHeight, filterWidth, dilationHeight, dilationWidth, padInfo } = convInfo;
        const padLeft = padInfo.left;
        const padTop = padInfo.top;
        // Channel multiplier: output channels produced per input channel.
        const chMul = convInfo.outChannels / convInfo.inChannels;
        const y = new TensorBuffer(convInfo.outShape, x.dtype);
        const xVals = backend.data.get(x.dataId).values;
        const wVals = backend.data.get(filter.dataId).values;
        const yVals = y.values;
        for (let b = 0; b < convInfo.batchSize; ++b) {
            const xOffset1 = b * xStrides[0];
            const yOffset1 = b * y.strides[0];
            for (let yR = 0; yR < convInfo.outHeight; ++yR) {
                const yOffset2 = yOffset1 + yR * y.strides[1];
                // Top input row of this output row's receptive field (may be
                // negative inside the padding region).
                const xRCorner = yR * convInfo.strideHeight - padTop;
                for (let wR = 0; wR < filterHeight; ++wR) {
                    const xR = xRCorner + wR * dilationHeight;
                    if (xR < 0 || xR >= convInfo.inHeight) {
                        // Filter row falls in padding; contributes nothing.
                        continue;
                    }
                    const wOffset1 = wR * filterStrides[0];
                    const xOffset2 = xOffset1 + xR * xStrides[1];
                    for (let yC = 0; yC < convInfo.outWidth; ++yC) {
                        const yOffset3 = yOffset2 + yC * y.strides[2];
                        const xCCorner = yC * convInfo.strideWidth - padLeft;
                        for (let wC = 0; wC < filterWidth; ++wC) {
                            const xC = xCCorner + wC * dilationWidth;
                            if (xC < 0 || xC >= convInfo.inWidth) {
                                // Filter column falls in padding; skip.
                                continue;
                            }
                            const wOffset2 = wOffset1 + wC * filterStrides[1];
                            const xOffset3 = xOffset2 + xC * convInfo.inChannels;
                            let yOffset4 = yOffset3;
                            let wOffset3 = wOffset2;
                            // Accumulate x[b, xR, xC, d1] * w[wR, wC, d1, q]
                            // into y[b, yR, yC, d1 * chMul + q]. Relies on the
                            // output buffer starting zeroed (presumably
                            // TensorBuffer zero-initializes -- confirm).
                            for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
                                const xVal = xVals[xOffset3 + d1];
                                for (let q = 0; q < chMul; ++q) {
                                    yVals[yOffset4 + q] += xVal * wVals[wOffset3 + q];
                                }
                                yOffset4 += chMul;
                                wOffset3 += chMul;
                            }
                        }
                    }
                }
            }
        }
        return backend.makeTensorInfo(y.shape, y.dtype, y.values);
    }
    // Kernel registration: binds DepthwiseConv2dNative to the CPU backend.
    const depthwiseConv2dNativeConfig = {
        kernelName: DepthwiseConv2dNative,
        backendName: 'cpu',
        kernelFunc: depthwiseConv2dNative
    };
74053
74054 /**
74055 * @license
74056 * Copyright 2020 Google LLC. All Rights Reserved.
74057 * Licensed under the Apache License, Version 2.0 (the "License");
74058 * you may not use this file except in compliance with the License.
74059 * You may obtain a copy of the License at
74060 *
74061 * http://www.apache.org/licenses/LICENSE-2.0
74062 *
74063 * Unless required by applicable law or agreed to in writing, software
74064 * distributed under the License is distributed on an "AS IS" BASIS,
74065 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74066 * See the License for the specific language governing permissions and
74067 * limitations under the License.
74068 * =============================================================================
74069 */
    /**
     * CPU kernel for DepthwiseConv2dNativeBackpropFilter: gradient of the
     * depthwise convolution with respect to the filter.
     *
     * inputs: x (forward input), dy (gradient w.r.t. the forward output).
     * attrs:  strides, dilations, pad, dimRoundingMode, filterShape.
     * Returns a float32 tensor of shape convInfo.filterShape.
     */
    function depthwiseConv2dNativeBackpropFilter$1(args) {
        const { inputs, backend, attrs } = args;
        const { x, dy } = inputs;
        const { strides, dilations, pad, dimRoundingMode, filterShape } = attrs;
        // Complex dtypes are not supported by this kernel.
        assertNotComplex([x, dy], 'depthwiseConv2dNativeBackpropFilter');
        const convInfo = computeConv2DInfo(x.shape, filterShape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
        const { strideHeight, strideWidth, filterHeight, filterWidth } = convInfo;
        const dW = new TensorBuffer(convInfo.filterShape, 'float32');
        const leftPad = convInfo.padInfo.left;
        const topPad = convInfo.padInfo.top;
        // Channel multiplier: output channels per input channel.
        const chMul = convInfo.outChannels / convInfo.inChannels;
        const xVals = backend.data.get(x.dataId).values;
        const xBuf = new TensorBuffer(x.shape, x.dtype, xVals);
        const dyVals = backend.data.get(dy.dataId).values;
        const dyBuf = new TensorBuffer(dy.shape, dy.dtype, dyVals);
        // For each filter tap (wR, wC) and each output channel d2, sum
        // x * dy over every output position where that tap touched a valid
        // (non-padding) input pixel.
        for (let wR = 0; wR < filterHeight; ++wR) {
            // Output-row range for which (wR, yR) maps inside the input.
            const yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
            const yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);
            for (let wC = 0; wC < filterWidth; ++wC) {
                const yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));
                const yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);
                for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
                    // Decompose output channel into input channel d1 and
                    // multiplier index dm.
                    const d1 = Math.trunc(d2 / chMul);
                    const dm = d2 % chMul;
                    let dotProd = 0;
                    for (let b = 0; b < convInfo.batchSize; ++b) {
                        for (let yR = yRMin; yR < yRMax; ++yR) {
                            const xR = wR + yR * strideHeight - topPad;
                            for (let yC = yCMin; yC < yCMax; ++yC) {
                                const xC = wC + yC * strideWidth - leftPad;
                                dotProd += xBuf.get(b, xR, xC, d1) *
                                    dyBuf.get(b, yR, yC, d2);
                            }
                        }
                    }
                    dW.set(dotProd, wR, wC, d1, dm);
                }
            }
        }
        return backend.makeTensorInfo(dW.shape, dW.dtype, dW.values);
    }
    // Kernel registration: binds DepthwiseConv2dNativeBackpropFilter to the
    // CPU backend.
    const depthwiseConv2dNativeBackpropFilterConfig = {
        kernelName: DepthwiseConv2dNativeBackpropFilter,
        backendName: 'cpu',
        kernelFunc: depthwiseConv2dNativeBackpropFilter$1
    };
74116
74117 /**
74118 * @license
74119 * Copyright 2020 Google LLC. All Rights Reserved.
74120 * Licensed under the Apache License, Version 2.0 (the "License");
74121 * you may not use this file except in compliance with the License.
74122 * You may obtain a copy of the License at
74123 *
74124 * http://www.apache.org/licenses/LICENSE-2.0
74125 *
74126 * Unless required by applicable law or agreed to in writing, software
74127 * distributed under the License is distributed on an "AS IS" BASIS,
74128 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74129 * See the License for the specific language governing permissions and
74130 * limitations under the License.
74131 * =============================================================================
74132 */
    /**
     * CPU kernel for DepthwiseConv2dNativeBackpropInput: gradient of the
     * depthwise convolution with respect to its input.
     *
     * inputs: dy (gradient w.r.t. the forward output), filter.
     * attrs:  strides, dilations, pad, dimRoundingMode, inputShape.
     * Returns a float32 tensor of shape convInfo.inShape.
     */
    function depthwiseConv2dNativeBackpropInput$1(args) {
        const { inputs, backend, attrs } = args;
        const { dy, filter } = inputs;
        const { strides, dilations, pad, dimRoundingMode, inputShape } = attrs;
        // Complex dtypes are not supported by this kernel.
        assertNotComplex([dy, filter], 'depthwiseConv2DNativeBackpropInput');
        const dyStrides = computeStrides(dy.shape);
        const filterStrides = computeStrides(filter.shape);
        const convInfo = computeConv2DInfo(inputShape, filter.shape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
        const dx = new TensorBuffer(convInfo.inShape, 'float32');
        const dxValues = dx.values;
        const [dxS0, dxS1, dxS2] = dx.strides;
        const dyValues = backend.data.get(dy.dataId).values;
        const [dyS0, dyS1, dyS2] = dyStrides;
        const fltValues = backend.data.get(filter.dataId).values;
        const [fltS0, fltS1, fltS2] = filterStrides;
        const { batchSize, filterHeight, filterWidth, inChannels, inHeight, inWidth, outChannels, outHeight, outWidth, strideHeight, strideWidth } = convInfo;
        // Effective padding of the equivalent "full" gradient convolution.
        const topPad = filterHeight - 1 - convInfo.padInfo.top;
        const leftPad = filterWidth - 1 - convInfo.padInfo.left;
        // Channel multiplier: output channels per input channel.
        const chMul = outChannels / inChannels;
        for (let b = 0; b < batchSize; ++b) {
            for (let d1 = 0; d1 < inChannels; ++d1) {
                for (let xR = 0; xR < inHeight; ++xR) {
                    const xRCorner = xR - topPad;
                    // Output-row range whose receptive field covers input row xR.
                    const xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));
                    const yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);
                    for (let xC = 0; xC < inWidth; ++xC) {
                        const xCCorner = xC - leftPad;
                        const xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));
                        const yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);
                        let dotProd = 0;
                        for (let yR = xRMin; yR < yRMax; ++yR) {
                            const wR = yR * strideHeight - xRCorner;
                            for (let yC = xCMin; yC < yCMax; ++yC) {
                                const wC = yC * strideWidth - xCCorner;
                                const dyOffset = dyS0 * b + dyS1 * yR + dyS2 * yC;
                                // The filter is indexed rotated by 180 degrees,
                                // as required for the input gradient.
                                const fltOffset = fltS0 * (filterHeight - 1 - wR) +
                                    fltS1 * (filterWidth - 1 - wC) + fltS2 * d1;
                                // Sum over the chMul output channels fed by d1.
                                for (let dm = 0; dm < chMul; ++dm) {
                                    const d2 = d1 * chMul + dm;
                                    const pixel = dyValues[dyOffset + d2];
                                    const weight = fltValues[fltOffset + dm];
                                    dotProd += pixel * weight;
                                }
                            }
                        }
                        dxValues[dxS0 * b + dxS1 * xR + dxS2 * xC + d1] = dotProd;
                    }
                }
            }
        }
        return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
    }
    // Kernel registration: binds DepthwiseConv2dNativeBackpropInput to the
    // CPU backend.
    const depthwiseConv2dNativeBackpropInputConfig = {
        kernelName: DepthwiseConv2dNativeBackpropInput,
        backendName: 'cpu',
        kernelFunc: depthwiseConv2dNativeBackpropInput$1
    };
74190
74191 /**
74192 * @license
74193 * Copyright 2020 Google LLC. All Rights Reserved.
74194 * Licensed under the Apache License, Version 2.0 (the "License");
74195 * you may not use this file except in compliance with the License.
74196 * You may obtain a copy of the License at
74197 *
74198 * http://www.apache.org/licenses/LICENSE-2.0
74199 *
74200 * Unless required by applicable law or agreed to in writing, software
74201 * distributed under the License is distributed on an "AS IS" BASIS,
74202 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74203 * See the License for the specific language governing permissions and
74204 * limitations under the License.
74205 * =============================================================================
74206 */
74207 function diag$1(args) {
74208 const { inputs, backend } = args;
74209 const { x } = inputs;
74210 const xSize = sizeFromShape(x.shape);
74211 const xVals = backend.data.get(x.dataId).values;
74212 const outBuf = buffer([xSize, xSize], x.dtype);
74213 const vals = outBuf.values;
74214 for (let i = 0; i < xVals.length; i++) {
74215 vals[i * xSize + i] = xVals[i];
74216 }
74217 const outShape = [...x.shape, ...x.shape];
74218 return backend.makeTensorInfo(outShape, outBuf.dtype, outBuf.values);
74219 }
    // Kernel registration: binds the Diag kernel to the CPU backend.
    const diagConfig = {
        kernelName: Diag,
        backendName: 'cpu',
        kernelFunc: diag$1
    };
74225
74226 /**
74227 * @license
74228 * Copyright 2020 Google LLC. All Rights Reserved.
74229 * Licensed under the Apache License, Version 2.0 (the "License");
74230 * you may not use this file except in compliance with the License.
74231 * You may obtain a copy of the License at
74232 *
74233 * http://www.apache.org/licenses/LICENSE-2.0
74234 *
74235 * Unless required by applicable law or agreed to in writing, software
74236 * distributed under the License is distributed on an "AS IS" BASIS,
74237 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74238 * See the License for the specific language governing permissions and
74239 * limitations under the License.
74240 * =============================================================================
74241 */
    /**
     * CPU kernel for Dilation2D (grayscale morphological dilation, NHWC):
     * for each output position, takes the max of (input + filter) over the
     * dilated filter window, per channel.
     */
    const dilation2DConfig = {
        kernelName: Dilation2D,
        backendName: 'cpu',
        kernelFunc: ({ inputs, backend, attrs }) => {
            const { x, filter } = inputs;
            const { strides, pad, dilations } = attrs;
            const cpuBackend = backend;
            const xVals = cpuBackend.data.get(x.dataId).values;
            const xRank = x.shape.length;
            const filterVals = cpuBackend.data.get(filter.dataId).values;
            const filterRank = filter.shape.length;
            const { batchSize, inHeight, inWidth, inChannels, outHeight, outWidth, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth, outShape } = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations);
            const outSize = sizeFromShape(outShape);
            const outRank = outShape.length;
            const outputVals = getArrayFromDType(x.dtype, outSize);
            // Upsampling the input by fill in `dilation size - 1` values between each
            // input value.
            // This implementation follows the TF c++ implementation:
            // https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc
            for (let b = 0; b < batchSize; ++b) {
                for (let hOut = 0; hOut < outHeight; ++hOut) {
                    // Top-left corner of the window in input space (may be
                    // negative inside the padding region).
                    const hBeg = hOut * strideHeight - padInfo.top;
                    for (let wOut = 0; wOut < outWidth; ++wOut) {
                        const wBeg = wOut * strideWidth - padInfo.left;
                        for (let d = 0; d < inChannels; ++d) {
                            // Running maximum over the window, seeded with the
                            // smallest safe integer.
                            let curVal = Number.MIN_SAFE_INTEGER;
                            for (let h = 0; h < filterHeight; ++h) {
                                const hIn = hBeg + h * dilationHeight;
                                if (hIn >= 0 && hIn < inHeight) {
                                    for (let w = 0; w < filterWidth; ++w) {
                                        const wIn = wBeg + w * dilationWidth;
                                        if (wIn >= 0 && wIn < inWidth) {
                                            const xIndex = locToIndex([b, hIn, wIn, d], xRank, computeStrides(x.shape));
                                            const filterIndex = locToIndex([h, w, d], filterRank, computeStrides(filter.shape));
                                            // Dilation candidate: input + filter tap.
                                            const val = xVals[xIndex] + filterVals[filterIndex];
                                            if (val > curVal) {
                                                curVal = val;
                                            }
                                        }
                                    }
                                }
                            }
                            const outputIndex = locToIndex([b, hOut, wOut, d], outRank, computeStrides(outShape));
                            outputVals[outputIndex] = curVal;
                        }
                    }
                }
            }
            const dataId = cpuBackend.write(toTypedArray(outputVals, x.dtype), outShape, x.dtype);
            return { dataId, shape: outShape, dtype: x.dtype };
        }
    };
74294
74295 /**
74296 * @license
74297 * Copyright 2020 Google LLC. All Rights Reserved.
74298 * Licensed under the Apache License, Version 2.0 (the "License");
74299 * you may not use this file except in compliance with the License.
74300 * You may obtain a copy of the License at
74301 *
74302 * http://www.apache.org/licenses/LICENSE-2.0
74303 *
74304 * Unless required by applicable law or agreed to in writing, software
74305 * distributed under the License is distributed on an "AS IS" BASIS,
74306 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74307 * See the License for the specific language governing permissions and
74308 * limitations under the License.
74309 * =============================================================================
74310 */
/**
 * CPU kernel for Dilation2DBackpropFilter: the gradient of morphological
 * dilation with respect to the filter.
 *
 * Inputs: `x` (NHWC input), `filter` ([fH, fW, depth]) and `dy` (upstream
 * gradient, shaped like the forward output). For every output element the
 * kernel re-runs the forward argmax search and routes that element's `dy`
 * to the single winning filter tap.
 */
const dilation2DBackpropFilterConfig = {
    kernelName: Dilation2DBackpropFilter,
    backendName: 'cpu',
    kernelFunc: ({ inputs, backend, attrs }) => {
        const { x, filter, dy } = inputs;
        const { strides, pad, dilations } = attrs;
        const cpuBackend = backend;
        // Materialize flat buffers as nested arrays for direct [b][h][w][d] indexing.
        const $x = toNestedArray(x.shape, cpuBackend.data.get(x.dataId).values);
        const $filter = toNestedArray(filter.shape, cpuBackend.data.get(filter.dataId).values);
        const { batchSize, inHeight, inWidth, inChannels, outHeight, outWidth, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth, outShape } = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations);
        assert(dy.rank === outShape.length, () => `Error in ${Dilation2DBackpropFilter}, dy ` +
            `must have the same rank as output ${outShape.length}, but got ` +
            `${dy.rank}`);
        const $dy = toNestedArray(outShape, cpuBackend.data.get(dy.dataId).values);
        // The computed filter gradients has the same dimensions as the filter:
        // [filterHeight, filterWidth, depth]
        const gradients = makeZerosNestedTypedArray(filter.shape, filter.dtype);
        // In the case of multiple argmax branches, we only back-propagate along the
        // last branch, i.e., the one with largest value of `h * filter_cols + w`,
        // similarly to the max-pooling backward routines.
        // This implementation follows the TF c++ implementation:
        // https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc
        for (let b = 0; b < batchSize; ++b) {
            for (let hOut = 0; hOut < outHeight; ++hOut) {
                // Top-left corner of the filter window in input coordinates;
                // may be negative inside the padded region.
                const hBeg = hOut * strideHeight - padInfo.top;
                for (let wOut = 0; wOut < outWidth; ++wOut) {
                    const wBeg = wOut * strideWidth - padInfo.left;
                    for (let d = 0; d < inChannels; ++d) {
                        // Re-run the forward argmax search for this output element.
                        let curVal = Number.MIN_SAFE_INTEGER;
                        let hMax = 0;
                        let wMax = 0;
                        for (let h = 0; h < filterHeight; ++h) {
                            const hIn = hBeg + h * dilationHeight;
                            if (hIn >= 0 && hIn < inHeight) {
                                for (let w = 0; w < filterWidth; ++w) {
                                    const wIn = wBeg + w * dilationWidth;
                                    if (wIn >= 0 && wIn < inWidth) {
                                        const val = $x[b][hIn][wIn][d] + $filter[h][w][d];
                                        // NOTE(review): with strict `>` a tie keeps the
                                        // FIRST maximal tap in scan order, which looks
                                        // inconsistent with the "last branch" comment
                                        // above — confirm against the TF C++ source.
                                        if (val > curVal) {
                                            curVal = val;
                                            hMax = h;
                                            wMax = w;
                                        }
                                    }
                                }
                            }
                        }
                        // Route this output element's gradient to the winning tap.
                        gradients[hMax][wMax][d] += $dy[b][hOut][wOut][d];
                    }
                }
            }
        }
        const dataId = cpuBackend.write(toTypedArray(gradients, x.dtype), filter.shape, filter.dtype);
        return { dataId, shape: filter.shape, dtype: filter.dtype };
    }
};
74367
74368 /**
74369 * @license
74370 * Copyright 2020 Google LLC. All Rights Reserved.
74371 * Licensed under the Apache License, Version 2.0 (the "License");
74372 * you may not use this file except in compliance with the License.
74373 * You may obtain a copy of the License at
74374 *
74375 * http://www.apache.org/licenses/LICENSE-2.0
74376 *
74377 * Unless required by applicable law or agreed to in writing, software
74378 * distributed under the License is distributed on an "AS IS" BASIS,
74379 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74380 * See the License for the specific language governing permissions and
74381 * limitations under the License.
74382 * =============================================================================
74383 */
/**
 * CPU kernel for Dilation2DBackpropInput: the gradient of morphological
 * dilation with respect to the input `x`.
 *
 * For every output element the kernel re-runs the forward argmax search and
 * routes that element's `dy` to the single winning input position.
 */
const dilation2DBackpropInputConfig = {
    kernelName: Dilation2DBackpropInput,
    backendName: 'cpu',
    kernelFunc: ({ inputs, backend, attrs }) => {
        const { x, filter, dy } = inputs;
        const { strides, pad, dilations } = attrs;
        const cpuBackend = backend;
        // Materialize flat buffers as nested arrays for direct [b][h][w][d] indexing.
        const $x = toNestedArray(x.shape, cpuBackend.data.get(x.dataId).values);
        const $filter = toNestedArray(filter.shape, cpuBackend.data.get(filter.dataId).values);
        const { batchSize, inHeight, inWidth, inChannels, outHeight, outWidth, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth, outShape } = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations);
        assert(dy.rank === outShape.length, () => `Error in ${Dilation2DBackpropInput}, dy ` +
            `must have the same rank as output ${outShape.length}, but got ` +
            `${dy.rank}`);
        const $dy = toNestedArray(outShape, cpuBackend.data.get(dy.dataId).values);
        // The computed gradients has the same dimensions as the input:
        // [batch, inputHeight, inputCols, inChannel]
        const gradients = makeZerosNestedTypedArray(x.shape, x.dtype);
        // In the case of multiple argmax branches, we only back-propagate along the
        // last branch, i.e., the one with largest value of `h * filter_cols + w`,
        // similarly to the max-pooling backward routines.
        // This implementation follows the TF c++ implementation:
        // https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc
        for (let b = 0; b < batchSize; ++b) {
            for (let hOut = 0; hOut < outHeight; ++hOut) {
                // Top-left corner of the filter window in input coordinates;
                // may be negative inside the padded region.
                const hBeg = hOut * strideHeight - padInfo.top;
                for (let wOut = 0; wOut < outWidth; ++wOut) {
                    const wBeg = wOut * strideWidth - padInfo.left;
                    for (let d = 0; d < inChannels; ++d) {
                        let curVal = Number.MIN_SAFE_INTEGER;
                        // Defaults clamp the window origin into the valid input range,
                        // so a fully out-of-range window still writes somewhere legal.
                        let hInMax = (hBeg < 0) ? 0 : hBeg;
                        let wInMax = (wBeg < 0) ? 0 : wBeg;
                        for (let h = 0; h < filterHeight; ++h) {
                            const hIn = hBeg + h * dilationHeight;
                            if (hIn >= 0 && hIn < inHeight) {
                                for (let w = 0; w < filterWidth; ++w) {
                                    const wIn = wBeg + w * dilationWidth;
                                    if (wIn >= 0 && wIn < inWidth) {
                                        const val = $x[b][hIn][wIn][d] + $filter[h][w][d];
                                        // Strict `>`: ties keep the earliest candidate
                                        // in (h, w) scan order.
                                        if (val > curVal) {
                                            curVal = val;
                                            hInMax = hIn;
                                            wInMax = wIn;
                                        }
                                    }
                                }
                            }
                        }
                        // Route this output element's gradient to the winning input cell.
                        gradients[b][hInMax][wInMax][d] += $dy[b][hOut][wOut][d];
                    }
                }
            }
        }
        const dataId = cpuBackend.write(toTypedArray(gradients, x.dtype), x.shape, x.dtype);
        return { dataId, shape: x.shape, dtype: x.dtype };
    }
};
74440
74441 /**
74442 * @license
74443 * Copyright 2020 Google LLC. All Rights Reserved.
74444 * Licensed under the Apache License, Version 2.0 (the "License");
74445 * you may not use this file except in compliance with the License.
74446 * You may obtain a copy of the License at
74447 *
74448 * http://www.apache.org/licenses/LICENSE-2.0
74449 *
74450 * Unless required by applicable law or agreed to in writing, software
74451 * distributed under the License is distributed on an "AS IS" BASIS,
74452 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74453 * See the License for the specific language governing permissions and
74454 * limitations under the License.
74455 * =============================================================================
74456 */
/**
 * CPU kernel for Sum: reduces `x` over `axis`, optionally keeping the
 * reduced dimensions as size-1 entries when `keepDims` is set.
 *
 * Bool inputs are first cast to int32; the result dtype is the input dtype
 * upcast with int32. Non-innermost axes are handled by transposing the
 * reduction axes to the innermost positions first.
 */
function sum$3(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    assertNotComplex(x, 'sum');
    // Booleans are summed as int32; everything else passes through unchanged.
    const castedX = x.dtype === 'bool' ?
        cast$2({ inputs: { x }, backend, attrs: { dtype: 'int32' } }) :
        identity$1({ inputs: { x }, backend });
    const rank = castedX.shape.length;
    const axes = parseAxisParam(axis, castedX.shape);
    const permutation = getAxesPermutation(axes, rank);
    let reductionAxes = axes;
    let reduceInput = castedX;
    if (permutation != null) {
        // Move the reduction axes to the innermost positions.
        reduceInput =
            transpose$1({ inputs: { x: castedX }, backend, attrs: { perm: permutation } });
        reductionAxes = getInnerMostAxes(reductionAxes.length, rank);
    }
    assertAxesAreInnerMostDims('sum', reductionAxes, reduceInput.shape.length);
    const [outShape, reduceShape] = computeOutAndReduceShapes(reduceInput.shape, reductionAxes);
    const resultDtype = upcastType(reduceInput.dtype, 'int32');
    let result = zeros$2(backend, outShape, resultDtype);
    const reduceSize = sizeFromShape(reduceShape);
    const outVals = backend.data.get(result.dataId).values;
    const srcVals = backend.data.get(reduceInput.dataId).values;
    // Each output element is the sum of a contiguous run of `reduceSize` values.
    for (let outIdx = 0; outIdx < outVals.length; ++outIdx) {
        const base = outIdx * reduceSize;
        let acc = 0;
        for (let k = 0; k < reduceSize; ++k) {
            acc += srcVals[base + k];
        }
        outVals[outIdx] = acc;
    }
    if (keepDims) {
        // Re-insert the reduced axes as size-1 dimensions.
        const expandedShape = expandShapeToKeepDim(result.shape, axes);
        const squeezedResult = result;
        result = reshape$2({ inputs: { x: result }, backend, attrs: { shape: expandedShape } });
        backend.disposeIntermediateTensorInfo(squeezedResult);
    }
    // Release intermediates: the cast/identity copy always, the transposed
    // copy only when one was created.
    backend.disposeIntermediateTensorInfo(castedX);
    if (permutation != null) {
        backend.disposeIntermediateTensorInfo(reduceInput);
    }
    return result;
}
// Registration descriptor binding the Sum kernel to the CPU backend.
const sumConfig = {
    backendName: 'cpu',
    kernelFunc: sum$3,
    kernelName: Sum
};
74511
74512 /**
74513 * @license
74514 * Copyright 2021 Google LLC. All Rights Reserved.
74515 * Licensed under the Apache License, Version 2.0 (the "License");
74516 * you may not use this file except in compliance with the License.
74517 * You may obtain a copy of the License at
74518 *
74519 * http://www.apache.org/licenses/LICENSE-2.0
74520 *
74521 * Unless required by applicable law or agreed to in writing, software
74522 * distributed under the License is distributed on an "AS IS" BASIS,
74523 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74524 * See the License for the specific language governing permissions and
74525 * limitations under the License.
74526 * =============================================================================
74527 */
/**
 * CPU kernel for Einsum. Evaluates `equation` over the input tensors by
 * following the compute path from getEinsumComputePath: operands are
 * aligned (transposed/expanded), multiplied together, and summed-out
 * dimensions are reduced between steps.
 */
function einsum$1(args) {
    const { inputs, backend, attrs } = args;
    const { equation } = attrs;
    const tensors = inputs;
    // allDims: every distinct index label; summedDims: labels reduced away;
    // idDims: the labels carried by each input tensor.
    const { allDims, summedDims, idDims } = decodeEinsumEquation(equation, tensors.length);
    checkEinsumDimSizes(allDims.length, idDims, tensors);
    const { path, steps } = getEinsumComputePath(summedDims, idDims);
    const nSteps = steps.length;
    let out = null;
    let numDimsRemaining = allDims.length;
    // Intermediates created along the way; disposed at the end.
    const tensorsToDispose = [];
    for (let i = 0; i < nSteps; ++i) {
        for (const idTerm of steps[i]) {
            // Permutation/expansion needed to align this operand with the
            // current working dimension order.
            const { permutationIndices: perm, expandDims: dimsToExpand } = getEinsumPermutation(numDimsRemaining, idDims[idTerm]);
            let x;
            if (isIdentityPermutation(perm)) {
                x = tensors[idTerm];
            }
            else {
                x = transpose$1({ inputs: { x: tensors[idTerm] }, backend, attrs: { perm } });
                tensorsToDispose.push(x);
            }
            // Insert size-1 dims for labels this operand does not carry.
            const targetShape = x.shape.slice();
            for (let k = 0; k < dimsToExpand.length; ++k) {
                targetShape.splice(dimsToExpand[k], 0, 1);
            }
            if (!arraysEqual(x.shape, targetShape)) {
                x = reshape$2({ inputs: { x }, backend, attrs: { shape: targetShape } });
                tensorsToDispose.push(x);
            }
            // Accumulate the element-wise (broadcasting) product of operands.
            if (out === null) {
                out = x;
            }
            else {
                // tslint:disable-next-line: no-unnecessary-type-assertion
                out = multiply$2({ inputs: { a: x, b: out }, backend });
                tensorsToDispose.push(out);
            }
        }
        if (i < nSteps - 1) {
            // path[i] >= 0 marks a dimension to sum out after this step; the
            // offset corrects for dims already removed in earlier steps.
            if (path[i] >= 0) {
                out = sum$3({
                    inputs: { x: out },
                    backend,
                    attrs: {
                        axis: path[i] - (allDims.length - numDimsRemaining),
                        keepDims: false
                    }
                });
                tensorsToDispose.push(out);
            }
            numDimsRemaining--;
        }
    }
    // Clean up intermediate tensors.
    for (const tensorInfo of tensorsToDispose) {
        if (tensorInfo === out) {
            continue;
        }
        backend.disposeIntermediateTensorInfo(tensorInfo);
    }
    return out;
}
// Registration descriptor binding the Einsum kernel to the CPU backend.
const einsumConfig = {
    backendName: 'cpu',
    kernelFunc: einsum$1,
    kernelName: Einsum
};
74596
74597 /**
74598 * @license
74599 * Copyright 2020 Google LLC. All Rights Reserved.
74600 * Licensed under the Apache License, Version 2.0 (the "License");
74601 * you may not use this file except in compliance with the License.
74602 * You may obtain a copy of the License at
74603 *
74604 * http://www.apache.org/licenses/LICENSE-2.0
74605 *
74606 * Unless required by applicable law or agreed to in writing, software
74607 * distributed under the License is distributed on an "AS IS" BASIS,
74608 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74609 * See the License for the specific language governing permissions and
74610 * limitations under the License.
74611 * =============================================================================
74612 */
/**
 * CPU kernel for EluGrad: the gradient of ELU expressed in terms of the
 * forward output `y`.
 *
 * For x > 0 (equivalently y > 0) the forward pass is the identity, so the
 * local gradient is 1. For x <= 0, y = e^x - 1, so the local gradient is
 * e^x = y + 1.
 *
 * Inputs: `dy` (upstream gradient) and `y` (forward ELU output), same
 * shape. Returns a float32 tensor holding dy * dElu/dx.
 */
function eluGrad(args) {
    const { inputs, backend } = args;
    const { dy, y } = inputs;
    assertNotComplex([dy, y], 'eluGrad');
    const resultValues = new Float32Array(sizeFromShape(y.shape));
    const values = backend.data.get(y.dataId).values;
    const dyValues = backend.data.get(dy.dataId).values;
    for (let i = 0; i < values.length; ++i) {
        const v = values[i];
        // Fix: the threshold here was previously `v >= 1`, which wrongly
        // scaled gradients by (v + 1) for outputs in (0, 1). TF defines
        // elu_grad as dy for y >= 0 and dy * (y + 1) for y < 0; at v === 0
        // both branches agree, so this change only corrects the (0, 1) range.
        if (v >= 0) {
            resultValues[i] = dyValues[i];
        }
        else {
            resultValues[i] = dyValues[i] * (v + 1);
        }
    }
    return backend.makeTensorInfo(y.shape, 'float32', resultValues);
}
// Registration descriptor binding the EluGrad kernel to the CPU backend.
const eluGradConfig$1 = {
    backendName: 'cpu',
    kernelFunc: eluGrad,
    kernelName: EluGrad
};
74636
74637 /**
74638 * @license
74639 * Copyright 2020 Google LLC. All Rights Reserved.
74640 * Licensed under the Apache License, Version 2.0 (the License);
74641 * you may not use this file except in compliance with the License.
74642 * You may obtain a copy of the License at
74643 *
74644 * http://www.apache.org/licenses/LICENSE-2.0
74645 *
74646 * Unless required by applicable law or agreed to in writing, software
74647 * distributed under the License is distributed on an AS IS BASIS,
74648 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74649 * See the License for the specific language governing permissions and
74650 * limitations under the License.
74651 * =============================================================================
74652 */
// Local aliases for the erf approximation coefficients used by erf$1 below.
// The t = 1/(1 + p*x) substitution plus five polynomial coefficients matches
// the Abramowitz–Stegun 7.1.26 form — presumably the ERF_* constants defined
// elsewhere in this bundle hold those values; confirm at their definition.
const p = ERF_P;
const a1 = ERF_A1;
const a2 = ERF_A2;
const a3 = ERF_A3;
const a4 = ERF_A4;
const a5 = ERF_A5;
// Element-wise error function via a rational polynomial approximation
// (cf. Abramowitz & Stegun 7.1.26). Odd symmetry is restored through the
// sign of the input; the polynomial is evaluated on |x|.
const erf$1 = unaryKernelFunc(Erf, (xi) => {
    const sign = Math.sign(xi);
    const absX = Math.abs(xi);
    const t = 1.0 / (1.0 + p * absX);
    // Horner evaluation of the degree-5 polynomial in t (kept in the exact
    // association order of the reference implementation).
    const poly = (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t;
    return sign * (1.0 - poly * Math.exp(-absX * absX));
});
// Registration descriptor binding the Erf kernel to the CPU backend.
const erfConfig = {
    backendName: 'cpu',
    kernelFunc: erf$1,
    kernelName: Erf,
};
74673
74674 /**
74675 * @license
74676 * Copyright 2020 Google LLC. All Rights Reserved.
74677 * Licensed under the Apache License, Version 2.0 (the "License");
74678 * you may not use this file except in compliance with the License.
74679 * You may obtain a copy of the License at
74680 *
74681 * http://www.apache.org/licenses/LICENSE-2.0
74682 *
74683 * Unless required by applicable law or agreed to in writing, software
74684 * distributed under the License is distributed on an "AS IS" BASIS,
74685 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74686 * See the License for the specific language governing permissions and
74687 * limitations under the License.
74688 * =============================================================================
74689 */
/**
 * CPU kernel for ExpandDims: inserts a size-1 dimension at position `dim`
 * and returns a reshaped view of `input`.
 *
 * `dim` may be negative (counted from the tail), with the valid interval
 * being [-(rank + 1), rank]. Previously only the lower bound was validated,
 * so e.g. dim = rank + 5 silently spliced at the end instead of raising the
 * error the message promises; the upper bound is now checked as well.
 */
function expandDims$2(args) {
    const { inputs, backend, attrs } = args;
    const { input } = inputs;
    const { dim } = attrs;
    const inputRank = input.shape.length;
    const newShape = input.shape.slice();
    let $dim = dim;
    if (dim < 0) {
        // Negative value is counted from the tail of rank.
        assert(-(inputRank + 1) <= dim, () => `Axis must be in the interval [${-(inputRank + 1)}, ${inputRank}]`);
        $dim = inputRank + dim + 1;
    }
    else {
        // Validate the upper bound so out-of-range axes fail loudly instead
        // of silently appending at the end.
        assert(dim <= inputRank, () => `Axis must be in the interval [${-(inputRank + 1)}, ${inputRank}]`);
    }
    newShape.splice($dim, 0, 1);
    return reshape$2({ inputs: { x: input }, backend, attrs: { shape: newShape } });
}
// Registration descriptor binding the ExpandDims kernel to the CPU backend.
const expandDimsConfig = {
    backendName: 'cpu',
    kernelFunc: expandDims$2,
    kernelName: ExpandDims
};
74710
74711 /**
74712 * @license
74713 * Copyright 2020 Google LLC. All Rights Reserved.
74714 * Licensed under the Apache License, Version 2.0 (the "License");
74715 * you may not use this file except in compliance with the License.
74716 * You may obtain a copy of the License at
74717 *
74718 * http://www.apache.org/licenses/LICENSE-2.0
74719 *
74720 * Unless required by applicable law or agreed to in writing, software
74721 * distributed under the License is distributed on an "AS IS" BASIS,
74722 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74723 * See the License for the specific language governing permissions and
74724 * limitations under the License.
74725 * =============================================================================
74726 */
// Element-wise real division built on the shared binary-kernel helpers.
const realDivImpl = createSimpleBinaryKernelImpl((a, b) => a / b);
// Kernel function wrapping the impl for the RealDiv op.
const div$1 = binaryKernelFunc(RealDiv, realDivImpl);
// Registration descriptor binding RealDiv to the CPU backend. Note:
// `realDivConfig.kernelFunc` is also invoked directly by fftImpl below for
// the inverse-FFT normalization.
const realDivConfig = {
    kernelName: RealDiv,
    backendName: 'cpu',
    kernelFunc: div$1
};
74734
74735 /**
74736 * @license
74737 * Copyright 2020 Google LLC. All Rights Reserved.
74738 * Licensed under the Apache License, Version 2.0 (the "License");
74739 * you may not use this file except in compliance with the License.
74740 * You may obtain a copy of the License at
74741 *
74742 * http://www.apache.org/licenses/LICENSE-2.0
74743 *
74744 * Unless required by applicable law or agreed to in writing, software
74745 * distributed under the License is distributed on an "AS IS" BASIS,
74746 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74747 * See the License for the specific language governing permissions and
74748 * limitations under the License.
74749 * =============================================================================
74750 */
74751 /**
74752 * Calculate FFT of inner most elements of batch tensor.
74753 */
/**
 * Runs an FFT (or inverse FFT when `inverse` is true) independently over
 * each row of a [batch, innerDim] complex tensor, collecting the per-row
 * results into one complex output tensor of the same shape.
 */
function fftBatch(input, inverse, cpuBackend) {
    const [batch, innerDim] = input.shape;
    const complexData = cpuBackend.data.get(input.dataId);
    const real2D = complexData.complexTensorInfos.real;
    const imag2D = complexData.complexTensorInfos.imag;
    // Collects real and imaginary values separately.
    const resultShape = [batch, innerDim];
    const resultSize = sizeFromShape(resultShape);
    const resultReal = getTypedArrayFromDType('float32', resultSize);
    const resultImag = getTypedArrayFromDType('float32', resultSize);
    for (let b = 0; b < batch; b++) {
        // TODO: Support slice ops for complex type.
        const rowReal = slice$1({
            inputs: { x: real2D },
            backend: cpuBackend,
            attrs: { begin: [b, 0], size: [1, innerDim] }
        });
        const rowImag = slice$1({
            inputs: { x: imag2D },
            backend: cpuBackend,
            attrs: { begin: [b, 0], size: [1, innerDim] }
        });
        // Named `row` rather than `input` to avoid shadowing the parameter.
        const row = complex$1({ inputs: { real: rowReal, imag: rowImag }, backend: cpuBackend });
        // Run FFT by batch element.
        const { real, imag } = fftImpl(row, inverse, cpuBackend);
        const merged = mergeRealAndImagArrays(real, imag);
        for (let d = 0; d < innerDim; d++) {
            const c = getComplexWithIndex(merged, d);
            resultReal[b * innerDim + d] = c.real;
            resultImag[b * innerDim + d] = c.imag;
        }
        cpuBackend.disposeIntermediateTensorInfo(rowReal);
        cpuBackend.disposeIntermediateTensorInfo(rowImag);
        cpuBackend.disposeIntermediateTensorInfo(row);
    }
    // Assemble the collected components into a single complex tensor.
    const realInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', resultReal);
    const imagInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', resultImag);
    const result = complex$1({ inputs: { real: realInfo, imag: imagInfo }, backend: cpuBackend });
    cpuBackend.disposeIntermediateTensorInfo(realInfo);
    cpuBackend.disposeIntermediateTensorInfo(imagInfo);
    return result;
}
/**
 * Core FFT dispatch for one complex tensor: uses the radix-2 Cooley-Tukey
 * path when the element count is a power of two, and a direct O(n^2)
 * matrix DFT otherwise. Returns raw `{real, imag}` value arrays (not a
 * tensor info).
 */
function fftImpl(input, inverse, cpuBackend) {
    const inputSize = sizeFromShape(input.shape);
    const inputVals = cpuBackend.data.get(input.dataId);
    const realVals = cpuBackend.data.get(inputVals.complexTensorInfos.real.dataId).values;
    const imagVals = cpuBackend.data.get(inputVals.complexTensorInfos.imag.dataId).values;
    if (isExponentOf2(inputSize)) {
        const result = fftRadix2(realVals, imagVals, inputSize, inverse, cpuBackend);
        const resultShape = [input.shape[0], input.shape[1]];
        if (inverse) {
            // Inverse FFT: normalize both components by 1/inputSize using the
            // RealDiv kernel on temporary tensor infos.
            const realInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', result.real);
            const imagInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', result.imag);
            const sizeInfo = cpuBackend.makeTensorInfo([], 'float32', createScalarValue(inputSize, 'float32'));
            // A second size scalar is made via identity so each division gets
            // its own operand — presumably to keep tensor-info lifetimes
            // independent; confirm whether sharing `sizeInfo` would be safe.
            const sizeInfoCopy = identity$1({ inputs: { x: sizeInfo }, backend: cpuBackend });
            const divRealInfo = realDivConfig.kernelFunc({ inputs: { a: realInfo, b: sizeInfo }, backend: cpuBackend });
            const divImagInfo = realDivConfig.kernelFunc({ inputs: { a: imagInfo, b: sizeInfoCopy }, backend: cpuBackend });
            const divRealVals = cpuBackend.data.get(divRealInfo.dataId).values;
            const divImagVals = cpuBackend.data.get(divImagInfo.dataId).values;
            // Release every temporary created for the normalization.
            cpuBackend.disposeIntermediateTensorInfo(realInfo);
            cpuBackend.disposeIntermediateTensorInfo(imagInfo);
            cpuBackend.disposeIntermediateTensorInfo(sizeInfo);
            cpuBackend.disposeIntermediateTensorInfo(sizeInfoCopy);
            cpuBackend.disposeIntermediateTensorInfo(divRealInfo);
            cpuBackend.disposeIntermediateTensorInfo(divImagInfo);
            return { real: divRealVals, imag: divImagVals };
        }
        return result;
    }
    else {
        // Non power-of-two sizes fall back to the direct matrix DFT.
        const data = mergeRealAndImagArrays(realVals, imagVals);
        const rawOutput = fourierTransformByMatmul(data, inputSize, inverse);
        return splitRealAndImagArrays(rawOutput);
    }
}
/**
 * Returns true iff `size` is a positive power of two.
 *
 * The previous check `(size & size - 1) === 0` incorrectly returned true
 * for 0 (and for -2147483648, whose two's-complement bit pattern also
 * passes), which would send a zero-sized input into the radix-2 recursion
 * in fftRadix2. The explicit `size > 0` guard closes both holes.
 */
function isExponentOf2(size) {
    return size > 0 && (size & (size - 1)) === 0;
}
74834 // FFT using Cooley-Tukey algorithm on radix 2 dimensional input.
// FFT using Cooley-Tukey algorithm on radix 2 dimensional input.
//
// Recursively splits the signal into even- and odd-indexed halves, combines
// the half-size transforms with the twiddle factors from `exponents`, and
// returns raw {real, imag} value arrays. All tensor infos created for the
// intermediate complex arithmetic are disposed before returning.
function fftRadix2(realVals, imagVals, size, inverse, cpuBackend) {
    // Base case: the FFT of a single sample is the sample itself.
    if (size === 1) {
        return { real: realVals, imag: imagVals };
    }
    // Interleaved [re0, im0, re1, im1, ...] view of the signal.
    const data = mergeRealAndImagArrays(realVals, imagVals);
    const half = size / 2;
    // Even-indexed sub-sequence, as raw arrays and as a complex tensor info.
    const evenComplex = complexWithEvenIndex(data);
    const evenRealVals = evenComplex.real;
    const evenImagVals = evenComplex.imag;
    const evenShape = [evenRealVals.length];
    const evenRealInfo = cpuBackend.makeTensorInfo(evenShape, 'float32', evenRealVals);
    const evenImagInfo = cpuBackend.makeTensorInfo(evenShape, 'float32', evenImagVals);
    const evenTensorInfo = complex$1({ inputs: { real: evenRealInfo, imag: evenImagInfo }, backend: cpuBackend });
    // Odd-indexed sub-sequence, likewise.
    const oddComplex = complexWithOddIndex(data);
    const oddRealVals = oddComplex.real;
    const oddImagVals = oddComplex.imag;
    const oddShape = [oddRealVals.length];
    const oddRealInfo = cpuBackend.makeTensorInfo(oddShape, 'float32', oddRealVals);
    const oddImagInfo = cpuBackend.makeTensorInfo(oddShape, 'float32', oddImagVals);
    const oddTensorInfo = complex$1({ inputs: { real: oddRealInfo, imag: oddImagInfo }, backend: cpuBackend });
    // Recursive call for half part of original input.
    const $evenComplex = fftRadix2(evenRealVals, evenImagVals, half, inverse, cpuBackend);
    const $evenRealVals = $evenComplex.real;
    const $evenImagVals = $evenComplex.imag;
    const $evenShape = [$evenRealVals.length];
    const $evenRealInfo = cpuBackend.makeTensorInfo($evenShape, 'float32', $evenRealVals);
    const $evenImagInfo = cpuBackend.makeTensorInfo($evenShape, 'float32', $evenImagVals);
    const $evenTensorInfo = complex$1({
        inputs: { real: $evenRealInfo, imag: $evenImagInfo },
        backend: cpuBackend
    });
    const $oddComplex = fftRadix2(oddRealVals, oddImagVals, half, inverse, cpuBackend);
    const $oddRealVals = $oddComplex.real;
    const $oddImagVals = $oddComplex.imag;
    const $oddShape = [$oddRealVals.length];
    const $oddRealInfo = cpuBackend.makeTensorInfo($oddShape, 'float32', $oddRealVals);
    const $oddImagInfo = cpuBackend.makeTensorInfo($oddShape, 'float32', $oddImagVals);
    const $oddTensorInfo = complex$1({ inputs: { real: $oddRealInfo, imag: $oddImagInfo }, backend: cpuBackend });
    // Twiddle factors for this stage, multiplied into the odd half.
    const e = exponents(size, inverse);
    const eShape = [e.real.length];
    const eRealInfo = cpuBackend.makeTensorInfo(eShape, 'float32', e.real);
    const eImagInfo = cpuBackend.makeTensorInfo(eShape, 'float32', e.imag);
    const complexInfo = complex$1({ inputs: { real: eRealInfo, imag: eImagInfo }, backend: cpuBackend });
    const exponentInfo = multiply$2({ inputs: { a: complexInfo, b: $oddTensorInfo }, backend: cpuBackend });
    // Butterfly combine: first half = even + twiddle*odd,
    // second half = even - twiddle*odd.
    const addPart = add$4({
        inputs: { a: $evenTensorInfo, b: exponentInfo },
        backend: cpuBackend
    });
    const subPart = sub$1({
        inputs: { a: $evenTensorInfo, b: exponentInfo },
        backend: cpuBackend
    });
    const addPartReal = real$1({ inputs: { input: addPart }, backend: cpuBackend });
    const subPartReal = real$1({ inputs: { input: subPart }, backend: cpuBackend });
    const addPartImag = imag$1({ inputs: { input: addPart }, backend: cpuBackend });
    const subPartImag = imag$1({ inputs: { input: subPart }, backend: cpuBackend });
    const $real = concat$1({
        inputs: [addPartReal, subPartReal],
        backend: cpuBackend,
        attrs: { axis: 0 }
    });
    const $imag = concat$1({
        inputs: [addPartImag, subPartImag],
        backend: cpuBackend,
        attrs: { axis: 0 }
    });
    const $realVals = cpuBackend.data.get($real.dataId).values;
    const $imagVals = cpuBackend.data.get($imag.dataId).values;
    // Release every intermediate tensor info created above.
    cpuBackend.disposeIntermediateTensorInfo(evenRealInfo);
    cpuBackend.disposeIntermediateTensorInfo(evenImagInfo);
    cpuBackend.disposeIntermediateTensorInfo(evenTensorInfo);
    cpuBackend.disposeIntermediateTensorInfo(oddRealInfo);
    cpuBackend.disposeIntermediateTensorInfo(oddImagInfo);
    cpuBackend.disposeIntermediateTensorInfo(oddTensorInfo);
    cpuBackend.disposeIntermediateTensorInfo($evenRealInfo);
    cpuBackend.disposeIntermediateTensorInfo($evenImagInfo);
    cpuBackend.disposeIntermediateTensorInfo($evenTensorInfo);
    cpuBackend.disposeIntermediateTensorInfo($oddRealInfo);
    cpuBackend.disposeIntermediateTensorInfo($oddImagInfo);
    cpuBackend.disposeIntermediateTensorInfo($oddTensorInfo);
    cpuBackend.disposeIntermediateTensorInfo(eRealInfo);
    cpuBackend.disposeIntermediateTensorInfo(eImagInfo);
    cpuBackend.disposeIntermediateTensorInfo(complexInfo);
    cpuBackend.disposeIntermediateTensorInfo(exponentInfo);
    cpuBackend.disposeIntermediateTensorInfo(addPart);
    cpuBackend.disposeIntermediateTensorInfo(subPart);
    cpuBackend.disposeIntermediateTensorInfo(addPartReal);
    cpuBackend.disposeIntermediateTensorInfo(addPartImag);
    cpuBackend.disposeIntermediateTensorInfo(subPartReal);
    cpuBackend.disposeIntermediateTensorInfo(subPartImag);
    cpuBackend.disposeIntermediateTensorInfo($real);
    cpuBackend.disposeIntermediateTensorInfo($imag);
    return { real: $realVals, imag: $imagVals };
}
74929 // Calculate fourier transform by multplying sinusoid matrix.
// Calculate fourier transform by multplying sinusoid matrix.
// Direct O(n^2) DFT over an interleaved [re, im, re, im, ...] buffer;
// used for sizes that are not powers of two. When `inverse` is set, each
// output is additionally divided by `size`.
function fourierTransformByMatmul(data, size, inverse) {
    const ret = new Float32Array(size * 2);
    // TODO: Use matmul instead once it supports complex64 type.
    for (let row = 0; row < size; row++) {
        let realSum = 0.0;
        let imagSum = 0.0;
        for (let col = 0; col < size; col++) {
            const e = exponent(row * col, size, inverse);
            const term = getComplexWithIndex(data, col);
            // Complex multiply-accumulate of term * e.
            realSum += term.real * e.real - term.imag * e.imag;
            imagSum += term.real * e.imag + term.imag * e.real;
        }
        if (inverse) {
            realSum /= size;
            imagSum /= size;
        }
        assignToTypedArray(ret, realSum, imagSum, row);
    }
    return ret;
}
74950
74951 /**
74952 * @license
74953 * Copyright 2020 Google LLC. All Rights Reserved.
74954 * Licensed under the Apache License, Version 2.0 (the "License");
74955 * you may not use this file except in compliance with the License.
74956 * You may obtain a copy of the License at
74957 *
74958 * http://www.apache.org/licenses/LICENSE-2.0
74959 *
74960 * Unless required by applicable law or agreed to in writing, software
74961 * distributed under the License is distributed on an "AS IS" BASIS,
74962 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74963 * See the License for the specific language governing permissions and
74964 * limitations under the License.
74965 * =============================================================================
74966 */
/**
 * CPU kernel for FFT. Collapses all outer dimensions of `input` into one
 * batch dimension, runs a forward FFT over the innermost dimension per
 * batch row, and restores the original shape.
 */
function fft$1(args) {
    const { inputs, backend } = args;
    const { input } = inputs;
    const innerDimensionSize = input.shape[input.shape.length - 1];
    const batch = sizeFromShape(input.shape) / innerDimensionSize;
    const input2D = reshape$2({
        inputs: { x: input },
        backend,
        attrs: { shape: [batch, innerDimensionSize] }
    });
    const batched = fftBatch(input2D, false, backend);
    const resultReshaped = reshape$2({ inputs: { x: batched }, backend, attrs: { shape: input.shape } });
    backend.disposeIntermediateTensorInfo(input2D);
    backend.disposeIntermediateTensorInfo(batched);
    return resultReshaped;
}
// Registration descriptor binding the FFT kernel to the CPU backend.
const fftConfig = {
    backendName: 'cpu',
    kernelFunc: fft$1,
    kernelName: FFT
};
74990
74991 /**
74992 * @license
74993 * Copyright 2020 Google LLC. All Rights Reserved.
74994 * Licensed under the Apache License, Version 2.0 (the "License");
74995 * you may not use this file except in compliance with the License.
74996 * You may obtain a copy of the License at
74997 *
74998 * http://www.apache.org/licenses/LICENSE-2.0
74999 *
75000 * Unless required by applicable law or agreed to in writing, software
75001 * distributed under the License is distributed on an "AS IS" BASIS,
75002 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75003 * See the License for the specific language governing permissions and
75004 * limitations under the License.
75005 * =============================================================================
75006 */
/**
 * Fill kernel for the CPU backend: creates a tensor of `shape` where every
 * element equals `value`. When `dtype` is not supplied it is inferred from
 * `value`.
 */
function fill$1(args) {
    const { backend, attrs } = args;
    const { shape, value, dtype } = attrs;
    const resolvedDtype = dtype || inferDtype(value);
    const values = getArrayFromDType(resolvedDtype, sizeFromShape(shape));
    fillValues(values, value, resolvedDtype);
    return backend.makeTensorInfo(shape, resolvedDtype, values);
}
const fillConfig = {
    kernelName: Fill,
    backendName: 'cpu',
    kernelFunc: fill$1
};
/**
 * Fills `values` in place so every element equals `value`.
 *
 * The original TypeScript had separate string/numeric branches that differed
 * only in compile-time type casts; after compilation both branches were the
 * identical statement `values.fill(value)`, so the dead duplicate branch is
 * collapsed into a single call.
 *
 * @param values Target array (string[] or a TypedArray).
 * @param value The scalar broadcast into every slot.
 * @param dtype The tensor dtype; unused at runtime, kept so the signature
 *     stays compatible with existing callers.
 */
function fillValues(values, value, dtype) {
    values.fill(value);
}
75028
75029 /**
75030 * @license
75031 * Copyright 2020 Google LLC. All Rights Reserved.
75032 * Licensed under the Apache License, Version 2.0 (the "License");
75033 * you may not use this file except in compliance with the License.
75034 * You may obtain a copy of the License at
75035 *
75036 * http://www.apache.org/licenses/LICENSE-2.0
75037 *
75038 * Unless required by applicable law or agreed to in writing, software
75039 * distributed under the License is distributed on an "AS IS" BASIS,
75040 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75041 * See the License for the specific language governing permissions and
75042 * limitations under the License.
75043 * =============================================================================
75044 */
/**
 * FlipLeftRight kernel for the CPU backend.
 *
 * Mirrors each image in the batch across its vertical axis: output column
 * `col` reads from input column `imageWidth - col - 1`. Shape and dtype are
 * preserved. Input layout is NHWC: [batch, height, width, channels].
 */
const flipLeftRightConfig = {
    kernelName: FlipLeftRight,
    backendName: 'cpu',
    kernelFunc: ({ inputs, attrs, backend }) => {
        const { image } = inputs;
        const cpuBackend = backend;
        const output = getTypedArrayFromDType(image.dtype, sizeFromShape(image.shape));
        const [batch, imageHeight, imageWidth, numChannels] = image.shape;
        const imageVals = cpuBackend.data.get(image.dataId).values;
        for (let batchIdx = 0; batchIdx < batch; batchIdx++) {
            const batchOffset = batchIdx * imageWidth * imageHeight * numChannels;
            for (let row = 0; row < imageHeight; row++) {
                const rowOffset = row * (imageWidth * numChannels);
                for (let col = 0; col < imageWidth; col++) {
                    const colOffset = col * numChannels;
                    // The mirrored column is always in [0, imageWidth), so the
                    // original bounds check (inherited from the rotate kernel) and
                    // the Math.round on an integer expression were dead code.
                    const mirroredColOffset = (imageWidth - col - 1) * numChannels;
                    for (let channel = 0; channel < numChannels; channel++) {
                        const outIdx = batchOffset + rowOffset + colOffset + channel;
                        const imageIdx = batchOffset + rowOffset + mirroredColOffset + channel;
                        output[outIdx] = imageVals[imageIdx];
                    }
                }
            }
        }
        const dataId = cpuBackend.write(output, image.shape, image.dtype);
        return { dataId, shape: image.shape, dtype: image.dtype };
    }
};
75080
75081 /**
75082 * @license
75083 * Copyright 2020 Google LLC. All Rights Reserved.
75084 * Licensed under the Apache License, Version 2.0 (the "License");
75085 * you may not use this file except in compliance with the License.
75086 * You may obtain a copy of the License at
75087 *
75088 * http://www.apache.org/licenses/LICENSE-2.0
75089 *
75090 * Unless required by applicable law or agreed to in writing, software
75091 * distributed under the License is distributed on an "AS IS" BASIS,
75092 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75093 * See the License for the specific language governing permissions and
75094 * limitations under the License.
75095 * =============================================================================
75096 */
/**
 * FloorDiv kernel for the CPU backend: element-wise floor(a / b).
 * The output dtype is forced to 'int32'; no complex implementation is
 * provided.
 */
const floorDivImpl = createSimpleBinaryKernelImpl((a, b) => {
    return Math.floor(a / b);
});
const floorDiv$1 = binaryKernelFunc(FloorDiv, floorDivImpl, null /* complexImpl */, 'int32');
const floorDivConfig = {
    kernelName: FloorDiv,
    backendName: 'cpu',
    kernelFunc: floorDiv$1
};
75104
75105 /**
75106 * @license
75107 * Copyright 2020 Google LLC. All Rights Reserved.
75108 * Licensed under the Apache License, Version 2.0 (the "License");
75109 * you may not use this file except in compliance with the License.
75110 * You may obtain a copy of the License at
75111 *
75112 * http://www.apache.org/licenses/LICENSE-2.0
75113 *
75114 * Unless required by applicable law or agreed to in writing, software
75115 * distributed under the License is distributed on an "AS IS" BASIS,
75116 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75117 * See the License for the specific language governing permissions and
75118 * limitations under the License.
75119 * =============================================================================
75120 */
/**
 * FusedConv2D kernel for the CPU backend.
 *
 * Runs a plain conv2d, then optionally adds `bias`, then optionally applies
 * `activation` (parameterized by `preluActivationWeights` for 'prelu' and
 * `leakyreluAlpha` for 'leakyrelu'). Each stage disposes the intermediate
 * tensor produced by the previous stage before returning.
 */
function fusedConv2D(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter, bias, preluActivationWeights } = inputs;
    const { strides, pad, dataFormat, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;
    let result = conv2D({
        inputs: { x, filter },
        backend,
        attrs: { strides, pad, dataFormat, dilations, dimRoundingMode }
    });
    if (bias) {
        const resultOld = result;
        // For NCHW format, if bias is a 1-D tensor, it is supposed to be aligned
        // to the channel of the conv2d's result; if the bias is a scalar, the
        // bias_add is computed as if the bias was broadcasted to the shape of the
        // conv2d's result.
        if (dataFormat === 'NCHW' && bias.shape.length === 1 &&
            bias.shape[0] !== 1) {
            // Reshape [C] -> [C, 1, 1] so broadcasting lines up with the channel
            // axis of an NCHW result.
            const reshapedBias = reshape$2({ inputs: { x: bias }, backend, attrs: { shape: [bias.shape[0], 1, 1] } });
            result =
                add$4({ inputs: { a: result, b: reshapedBias }, backend });
            backend.disposeIntermediateTensorInfo(reshapedBias);
        }
        else {
            // This condition handles NHWC and NCHW (scalar case). The only other case
            // for NCHW (1D case) is handled above.
            result = add$4({ inputs: { a: result, b: bias }, backend });
        }
        backend.disposeIntermediateTensorInfo(resultOld);
    }
    if (activation) {
        const resultOld = result;
        // For NCHW format, if PReLu activation weights is a 1-D tensor, it is
        // supposed to be aligned with the channel of the conv2d's result. For other
        // cases, whether NCHW or NHWC data format, the conv2d result is
        // already aligned with the activation weights.
        if (dataFormat === 'NCHW' && activation === 'prelu' &&
            preluActivationWeights.shape.length === 1 &&
            preluActivationWeights.shape[0] !== 1) {
            const reshapedAlpha = reshape$2({
                inputs: { x: preluActivationWeights },
                backend,
                attrs: { shape: [preluActivationWeights.shape[0], 1, 1] }
            });
            result = applyActivation$1(backend, result, activation, reshapedAlpha, leakyreluAlpha);
            backend.disposeIntermediateTensorInfo(reshapedAlpha);
        }
        else {
            result = applyActivation$1(backend, result, activation, preluActivationWeights, leakyreluAlpha);
        }
        backend.disposeIntermediateTensorInfo(resultOld);
    }
    return result;
}
const fusedConv2DConfig = {
    kernelName: FusedConv2D,
    backendName: 'cpu',
    kernelFunc: fusedConv2D
};
75179
75180 /**
75181 * @license
75182 * Copyright 2020 Google LLC. All Rights Reserved.
75183 * Licensed under the Apache License, Version 2.0 (the "License");
75184 * you may not use this file except in compliance with the License.
75185 * You may obtain a copy of the License at
75186 *
75187 * http://www.apache.org/licenses/LICENSE-2.0
75188 *
75189 * Unless required by applicable law or agreed to in writing, software
75190 * distributed under the License is distributed on an "AS IS" BASIS,
75191 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75192 * See the License for the specific language governing permissions and
75193 * limitations under the License.
75194 * =============================================================================
75195 */
/**
 * FusedDepthwiseConv2D kernel for the CPU backend: a depthwise conv followed
 * by an optional bias add and an optional activation. Each stage disposes
 * the intermediate result it replaces.
 */
function fusedDepthwiseConv2D(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter, bias, preluActivationWeights } = inputs;
    const { strides, pad, dataFormat, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;
    let result = depthwiseConv2dNative({
        inputs: { x, filter },
        backend,
        attrs: { strides, pad, dataFormat, dilations, dimRoundingMode }
    });
    if (bias) {
        const withBias = add$4({ inputs: { a: result, b: bias }, backend });
        backend.disposeIntermediateTensorInfo(result);
        result = withBias;
    }
    if (activation) {
        const activated = applyActivation$1(backend, result, activation, preluActivationWeights, leakyreluAlpha);
        backend.disposeIntermediateTensorInfo(result);
        result = activated;
    }
    return result;
}
const fusedDepthwiseConv2DConfig = {
    kernelName: FusedDepthwiseConv2D,
    backendName: 'cpu',
    kernelFunc: fusedDepthwiseConv2D
};
75222
75223 /**
75224 * @license
75225 * Copyright 2020 Google LLC. All Rights Reserved.
75226 * Licensed under the Apache License, Version 2.0 (the "License");
75227 * you may not use this file except in compliance with the License.
75228 * You may obtain a copy of the License at
75229 *
75230 * http://www.apache.org/licenses/LICENSE-2.0
75231 *
75232 * Unless required by applicable law or agreed to in writing, software
75233 * distributed under the License is distributed on an "AS IS" BASIS,
75234 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75235 * See the License for the specific language governing permissions and
75236 * limitations under the License.
75237 * =============================================================================
75238 */
/**
 * GatherNd kernel for the CPU backend: gathers slices of `params` addressed
 * by the trailing dimension of `indices`.
 */
function gatherNd(args) {
    const { inputs, backend } = args;
    const { params, indices } = inputs;
    const paramsSize = sizeFromShape(params.shape);
    // The last dimension of `indices` is the rank of each index tuple.
    const sliceRank = indices.shape[indices.shape.length - 1];
    const [resultShape, numSlices, sliceSize, strides] = prepareAndValidate(params, indices);
    // Nothing to gather: return an empty tensor of the computed shape.
    if (numSlices === 0) {
        return backend.makeTensorInfo(resultShape, params.dtype, []);
    }
    const indicesData = backend.data.get(indices.dataId).values;
    const paramsBuf = backend.bufferSync(params);
    const outBuf = gatherNdImpl(indicesData, paramsBuf, params.dtype, numSlices, sliceRank, sliceSize, strides, params.shape, paramsSize);
    return backend.makeTensorInfo(resultShape, params.dtype, outBuf.values);
}
const gatherNdConfig = {
    kernelName: GatherNd,
    backendName: 'cpu',
    kernelFunc: gatherNd
};
75259
75260 /**
75261 * @license
75262 * Copyright 2020 Google LLC. All Rights Reserved.
75263 * Licensed under the Apache License, Version 2.0 (the "License");
75264 * you may not use this file except in compliance with the License.
75265 * You may obtain a copy of the License at
75266 *
75267 * http://www.apache.org/licenses/LICENSE-2.0
75268 *
75269 * Unless required by applicable law or agreed to in writing, software
75270 * distributed under the License is distributed on an "AS IS" BASIS,
75271 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75272 * See the License for the specific language governing permissions and
75273 * limitations under the License.
75274 * =============================================================================
75275 */
/**
 * GatherV2 kernel for the CPU backend.
 *
 * Validates that every index lies within the gathered axis, flattens `x`
 * into a canonical [batch, outer, dim, slice] layout and `indices` into
 * [batch, n], gathers, and reshapes the result to the expected output shape.
 */
function gatherV2(args) {
    const { inputs, backend, attrs } = args;
    const { x, indices } = inputs;
    const { axis, batchDims } = attrs;
    assertNotComplex([x, indices], 'gatherV2');
    const parsedAxis = parseAxisParam(axis, x.shape)[0];
    const indicesVals = backend.data.get(indices.dataId).values;
    const axisDim = x.shape[parsedAxis];
    // Throw error when any index is out of bound.
    for (const index of indicesVals) {
        assert(index <= axisDim - 1 && index >= 0, () => `GatherV2: the index value ${index} is not in [0, ${axisDim - 1}]`);
    }
    const $batchDims = batchDims == null ? 0 : batchDims;
    const indicesSize = sizeFromShape(indices.shape);
    const shapeInfo = collectGatherOpShapeInfo(x, indices, parsedAxis, $batchDims);
    const flattenX = reshape$2({
        inputs: { x },
        backend,
        attrs: {
            shape: [
                shapeInfo.batchSize, shapeInfo.outerSize, shapeInfo.dimSize,
                shapeInfo.sliceSize
            ]
        }
    });
    const flattenIndex = reshape$2({
        inputs: { x: indices },
        backend,
        attrs: { shape: [shapeInfo.batchSize, indicesSize / shapeInfo.batchSize] }
    });
    const flattenOutputShape = [
        shapeInfo.batchSize, shapeInfo.outerSize, indicesSize / shapeInfo.batchSize,
        shapeInfo.sliceSize
    ];
    const indicesBuf = backend.bufferSync(flattenIndex);
    const xBuf = backend.bufferSync(flattenX);
    const outBuf = gatherV2Impl(xBuf, indicesBuf, flattenOutputShape);
    // Dispose the flattened temporaries before materializing the output.
    backend.disposeIntermediateTensorInfo(flattenX);
    backend.disposeIntermediateTensorInfo(flattenIndex);
    return backend.makeTensorInfo(shapeInfo.outputShape, outBuf.dtype, outBuf.values);
}
const gatherV2Config = {
    kernelName: GatherV2,
    backendName: 'cpu',
    kernelFunc: gatherV2
};
75326
75327 /**
75328 * @license
75329 * Copyright 2020 Google LLC. All Rights Reserved.
75330 * Licensed under the Apache License, Version 2.0 (the "License");
75331 * you may not use this file except in compliance with the License.
75332 * You may obtain a copy of the License at
75333 *
75334 * http://www.apache.org/licenses/LICENSE-2.0
75335 *
75336 * Unless required by applicable law or agreed to in writing, software
75337 * distributed under the License is distributed on an "AS IS" BASIS,
75338 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75339 * See the License for the specific language governing permissions and
75340 * limitations under the License.
75341 * =============================================================================
75342 */
/**
 * Inverse FFT kernel for the CPU backend. Identical to the forward FFT
 * kernel except that fftBatch is invoked with inverse=true.
 */
function ifft$1(args) {
    const { inputs, backend } = args;
    const { input } = inputs;
    // Everything but the innermost dimension is treated as batch.
    const innerDimensionSize = input.shape[input.shape.length - 1];
    const batch = sizeFromShape(input.shape) / innerDimensionSize;
    const input2D = reshape$2({
        inputs: { x: input },
        backend,
        attrs: { shape: [batch, innerDimensionSize] }
    });
    const result = fftBatch(input2D, true, backend);
    const resultReshaped = reshape$2({ inputs: { x: result }, backend, attrs: { shape: input.shape } });
    backend.disposeIntermediateTensorInfo(input2D);
    backend.disposeIntermediateTensorInfo(result);
    return resultReshaped;
}
const ifftConfig = {
    kernelName: IFFT,
    backendName: 'cpu',
    kernelFunc: ifft$1
};
75366
75367 /**
75368 * @license
75369 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
75371 * you may not use this file except in compliance with the License.
75372 * You may obtain a copy of the License at
75373 *
75374 * http://www.apache.org/licenses/LICENSE-2.0
75375 *
75376 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
75378 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75379 * See the License for the specific language governing permissions and
75380 * limitations under the License.
75381 * =============================================================================
75382 */
// IsFinite kernel: 1 where the element is a finite number, else 0 ('bool').
const isFinite$2 = unaryKernelFunc(IsFinite, (xi) => (Number.isFinite(xi) ? 1 : 0), 'bool');
const isFiniteConfig = {
    kernelName: IsFinite,
    backendName: 'cpu',
    kernelFunc: isFinite$2,
};
75389
75390 /**
75391 * @license
75392 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
75394 * you may not use this file except in compliance with the License.
75395 * You may obtain a copy of the License at
75396 *
75397 * http://www.apache.org/licenses/LICENSE-2.0
75398 *
75399 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
75401 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75402 * See the License for the specific language governing permissions and
75403 * limitations under the License.
75404 * =============================================================================
75405 */
// IsInf kernel: 1 where the element is +Infinity or -Infinity, else 0
// ('bool'). Equivalent to the Math.abs(xi) === Infinity formulation.
const isInf$1 = unaryKernelFunc(IsInf, (xi) => (xi === Infinity || xi === -Infinity ? 1 : 0), 'bool');
const isInfConfig = {
    kernelName: IsInf,
    backendName: 'cpu',
    kernelFunc: isInf$1,
};
75412
75413 /**
75414 * @license
75415 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
75417 * you may not use this file except in compliance with the License.
75418 * You may obtain a copy of the License at
75419 *
75420 * http://www.apache.org/licenses/LICENSE-2.0
75421 *
75422 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
75424 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75425 * See the License for the specific language governing permissions and
75426 * limitations under the License.
75427 * =============================================================================
75428 */
// IsNan kernel: 1 where the element is NaN, else 0 ('bool').
const isNaN$2 = unaryKernelFunc(IsNan, (xi) => (Number.isNaN(xi) ? 1 : 0), 'bool');
const isNaNConfig = {
    kernelName: IsNan,
    backendName: 'cpu',
    kernelFunc: isNaN$2,
};
75435
75436 /**
75437 * @license
75438 * Copyright 2020 Google LLC. All Rights Reserved.
75439 * Licensed under the Apache License, Version 2.0 (the "License");
75440 * you may not use this file except in compliance with the License.
75441 * You may obtain a copy of the License at
75442 *
75443 * http://www.apache.org/licenses/LICENSE-2.0
75444 *
75445 * Unless required by applicable law or agreed to in writing, software
75446 * distributed under the License is distributed on an "AS IS" BASIS,
75447 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75448 * See the License for the specific language governing permissions and
75449 * limitations under the License.
75450 * =============================================================================
75451 */
/**
 * LinSpace kernel for the CPU backend: produces `num` evenly spaced values
 * from `start` to `stop` as a 1-D float32 tensor.
 */
function linSpace(args) {
    const { backend, attrs } = args;
    const { start, stop, num } = attrs;
    const outVals = linSpaceImpl(start, stop, num);
    return backend.makeTensorInfo([outVals.length], 'float32', outVals);
}
const linSpaceConfig = {
    kernelName: LinSpace,
    backendName: 'cpu',
    kernelFunc: linSpace
};
75463
75464 /**
75465 * @license
75466 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
75468 * you may not use this file except in compliance with the License.
75469 * You may obtain a copy of the License at
75470 *
75471 * http://www.apache.org/licenses/LICENSE-2.0
75472 *
75473 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
75475 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75476 * See the License for the specific language governing permissions and
75477 * limitations under the License.
75478 * =============================================================================
75479 */
// Log1p kernel: element-wise natural log of (1 + x).
const log1p$1 = unaryKernelFunc(Log1p, (xi) => Math.log1p(xi));
const log1pConfig = {
    kernelName: Log1p,
    backendName: 'cpu',
    kernelFunc: log1p$1,
};
75486
75487 /**
75488 * @license
75489 * Copyright 2020 Google LLC. All Rights Reserved.
75490 * Licensed under the Apache License, Version 2.0 (the "License");
75491 * you may not use this file except in compliance with the License.
75492 * You may obtain a copy of the License at
75493 *
75494 * http://www.apache.org/licenses/LICENSE-2.0
75495 *
75496 * Unless required by applicable law or agreed to in writing, software
75497 * distributed under the License is distributed on an "AS IS" BASIS,
75498 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75499 * See the License for the specific language governing permissions and
75500 * limitations under the License.
75501 * =============================================================================
75502 */
// LogicalAnd kernel: element-wise `a && b`, producing a 'bool' tensor.
const logicalAndImpl = createSimpleBinaryKernelImpl((aVal, bVal) => aVal && bVal);
const logicalAnd$1 = binaryKernelFunc(LogicalAnd, logicalAndImpl, null /* complexImpl */, 'bool');
const logicalAndConfig = {
    kernelName: LogicalAnd,
    backendName: 'cpu',
    kernelFunc: logicalAnd$1
};
75510
75511 /**
75512 * @license
75513 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
75515 * you may not use this file except in compliance with the License.
75516 * You may obtain a copy of the License at
75517 *
75518 * http://www.apache.org/licenses/LICENSE-2.0
75519 *
75520 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
75522 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75523 * See the License for the specific language governing permissions and
75524 * limitations under the License.
75525 * =============================================================================
75526 */
// LogicalNot kernel: element-wise negation — 1 for falsy input, 0 otherwise.
const logicalNot$1 = unaryKernelFunc(LogicalNot, (xi) => (xi ? 0 : 1), 'bool');
const logicalNotConfig = {
    kernelName: LogicalNot,
    backendName: 'cpu',
    kernelFunc: logicalNot$1,
};
75533
75534 /**
75535 * @license
75536 * Copyright 2020 Google LLC. All Rights Reserved.
75537 * Licensed under the Apache License, Version 2.0 (the "License");
75538 * you may not use this file except in compliance with the License.
75539 * You may obtain a copy of the License at
75540 *
75541 * http://www.apache.org/licenses/LICENSE-2.0
75542 *
75543 * Unless required by applicable law or agreed to in writing, software
75544 * distributed under the License is distributed on an "AS IS" BASIS,
75545 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75546 * See the License for the specific language governing permissions and
75547 * limitations under the License.
75548 * =============================================================================
75549 */
// LogicalOr kernel: element-wise `a || b`, producing a 'bool' tensor.
const logicalOrImpl = createSimpleBinaryKernelImpl((aVal, bVal) => aVal || bVal);
const logicalOr$1 = binaryKernelFunc(LogicalOr, logicalOrImpl, null /* complexImpl */, 'bool');
const logicalOrConfig = {
    kernelName: LogicalOr,
    backendName: 'cpu',
    kernelFunc: logicalOr$1
};
75557
75558 /**
75559 * @license
75560 * Copyright 2020 Google LLC. All Rights Reserved.
75561 * Licensed under the Apache License, Version 2.0 (the "License");
75562 * you may not use this file except in compliance with the License.
75563 * You may obtain a copy of the License at
75564 *
75565 * http://www.apache.org/licenses/LICENSE-2.0
75566 *
75567 * Unless required by applicable law or agreed to in writing, software
75568 * distributed under the License is distributed on an "AS IS" BASIS,
75569 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75570 * See the License for the specific language governing permissions and
75571 * limitations under the License.
75572 * =============================================================================
75573 */
/**
 * LRN (local response normalization) kernel for the CPU backend.
 *
 * For each element, sums the squares of its neighbors within `depthRadius`
 * along the channel (last) axis, then scales the element by
 * (bias + alpha * sum)^(-beta). The input is indexed as NHWC with channels
 * in x.shape[3].
 */
function lRN(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { depthRadius, bias, alpha, beta } = attrs;
    assertNotComplex(x, 'LRN');
    const channels = x.shape[3];
    const maxD = channels - 1;
    const xValues = backend.data.get(x.dataId).values;
    const size = sizeFromShape(x.shape);
    const result = new Float32Array(size);
    for (let offset = 0; offset < size; offset++) {
        const currentChannel = offset % channels;
        const channelStart = offset - currentChannel;
        // Clamp the squared-sum window to the valid channel range.
        const from = channelStart + Math.max(0, currentChannel - depthRadius);
        const to = channelStart + Math.min(currentChannel + depthRadius, maxD);
        let sumSquares = 0.0;
        for (let pos = from; pos <= to; pos++) {
            const v = xValues[pos];
            sumSquares += v * v;
        }
        result[offset] = xValues[offset] * Math.pow(bias + alpha * sumSquares, -beta);
    }
    return backend.makeTensorInfo(x.shape, x.dtype, result);
}
// tslint:disable-next-line: variable-name
const LRNConfig = {
    kernelName: LRN,
    backendName: 'cpu',
    kernelFunc: lRN
};
75608
75609 /**
75610 * @license
75611 * Copyright 2020 Google LLC. All Rights Reserved.
75612 * Licensed under the Apache License, Version 2.0 (the "License");
75613 * you may not use this file except in compliance with the License.
75614 * You may obtain a copy of the License at
75615 *
75616 * http://www.apache.org/licenses/LICENSE-2.0
75617 *
75618 * Unless required by applicable law or agreed to in writing, software
75619 * distributed under the License is distributed on an "AS IS" BASIS,
75620 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75621 * See the License for the specific language governing permissions and
75622 * limitations under the License.
75623 * =============================================================================
75624 */
/**
 * Gradient kernel for LRN on the CPU backend.
 *
 * For each position `offset`, recomputes the normalization term
 * norm = bias + alpha * sum(x[k]^2) over the same channel window used by
 * the forward pass (channels indexed from dy.shape[3], NHWC layout), then
 * distributes dy[offset] across every contributing neighbor k: each k gets
 * the cross-term -2*alpha*beta*x[k]*y[offset]/norm, and the center
 * (k === offset) additionally gets norm^(-beta). Gradients are accumulated
 * into `result` with += because each k participates in several overlapping
 * windows.
 */
function lRNGrad(args) {
    const { inputs, backend, attrs } = args;
    const { x, y, dy } = inputs;
    const { depthRadius, bias, alpha, beta } = attrs;
    assertNotComplex(dy, 'LRNGrad');
    const dySize = sizeFromShape(dy.shape);
    const channels = dy.shape[3];
    const dyValues = backend.data.get(dy.dataId).values;
    const xValues = backend.data.get(x.dataId).values;
    const yValues = backend.data.get(y.dataId).values;
    const result = new Float32Array(dySize);
    const size = dySize;
    for (let offset = 0; offset < size; offset++) {
        const currentChannel = offset % channels;
        // Window [depthBegin, depthEnd) clamped to this element's channel run.
        const depthBegin = (offset - currentChannel) + Math.max(0, currentChannel - depthRadius);
        const depthEnd = (offset - currentChannel) +
            Math.min(channels, currentChannel + depthRadius + 1);
        let norm = 0;
        for (let k = depthBegin; k < depthEnd; k++) {
            norm += Math.pow(xValues[k], 2);
        }
        norm = alpha * norm + bias;
        for (let k = depthBegin; k < depthEnd; k++) {
            let dyi = -2 * alpha * beta * xValues[k] * yValues[offset] / norm;
            if (offset === k) {
                dyi += Math.pow(norm, -beta);
            }
            dyi *= dyValues[offset];
            // Accumulate: k receives contributions from every window it is in.
            result[k] += dyi;
        }
    }
    return backend.makeTensorInfo(dy.shape, x.dtype, result);
}
// tslint:disable-next-line: variable-name
const LRNGradConfig = {
    kernelName: LRNGrad,
    backendName: 'cpu',
    kernelFunc: lRNGrad
};
75664
75665 /**
75666 * @license
75667 * Copyright 2020 Google LLC. All Rights Reserved.
75668 * Licensed under the Apache License, Version 2.0 (the "License");
75669 * you may not use this file except in compliance with the License.
75670 * You may obtain a copy of the License at
75671 *
75672 * http://www.apache.org/licenses/LICENSE-2.0
75673 *
75674 * Unless required by applicable law or agreed to in writing, software
75675 * distributed under the License is distributed on an "AS IS" BASIS,
75676 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75677 * See the License for the specific language governing permissions and
75678 * limitations under the License.
75679 * =============================================================================
75680 */
/**
 * Max reduction kernel for the CPU backend.
 *
 * If the reduction axes are not already innermost, the values are first
 * transposed (via transposeImpl) so they become innermost; the reduction
 * then runs over contiguous runs of `reduceSize` elements. With
 * keepDims=true the reduced axes are restored as size-1 dimensions in the
 * reported output shape (the written data is unchanged).
 */
function max$2(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { reductionIndices, keepDims } = attrs;
    const cpuBackend = backend;
    let xShape = x.shape;
    const xRank = xShape.length;
    const origAxes = parseAxisParam(reductionIndices, xShape);
    let axes = origAxes;
    const permutedAxes = getAxesPermutation(axes, xRank);
    let xVals = cpuBackend.data.get(x.dataId).values;
    if (permutedAxes != null) {
        // Transpose so the reduction axes are innermost; xVals/xShape/axes are
        // rebound to the transposed view from here on.
        const newShape = new Array(xRank);
        for (let i = 0; i < newShape.length; i++) {
            newShape[i] = xShape[permutedAxes[i]];
        }
        xVals = transposeImpl(xVals, xShape, x.dtype, permutedAxes, newShape);
        axes = getInnerMostAxes(axes.length, xRank);
        xShape = newShape;
    }
    assertNotComplex(x, 'max');
    assertAxesAreInnerMostDims('max', axes, xRank);
    const [maxOutShape, reduceShape] = computeOutAndReduceShapes(xShape, axes);
    const reduceSize = sizeFromShape(reduceShape);
    const result = maxImpl(xVals, reduceSize, maxOutShape, x.dtype);
    const dataId = cpuBackend.write(result, maxOutShape, x.dtype);
    let outShape = maxOutShape;
    if (keepDims) {
        // reshape
        const newShape = expandShapeToKeepDim(maxOutShape, origAxes);
        outShape = newShape;
    }
    return { dataId, shape: outShape, dtype: x.dtype };
}
const maxConfig = {
    kernelName: Max,
    backendName: 'cpu',
    kernelFunc: max$2
};
75720
75721 /**
75722 * @license
75723 * Copyright 2020 Google LLC. All Rights Reserved.
75724 * Licensed under the Apache License, Version 2.0 (the "License");
75725 * you may not use this file except in compliance with the License.
75726 * You may obtain a copy of the License at
75727 *
75728 * http://www.apache.org/licenses/LICENSE-2.0
75729 *
75730 * Unless required by applicable law or agreed to in writing, software
75731 * distributed under the License is distributed on an "AS IS" BASIS,
75732 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75733 * See the License for the specific language governing permissions and
75734 * limitations under the License.
75735 * =============================================================================
75736 */
/**
 * CPU kernel for 2-D max pooling.
 *
 * When the pooling window is 1x1 and the output shape equals the input shape
 * the pool is a no-op, so the kernel short-circuits to `identity`; otherwise
 * it runs the generic CPU pooling routine in 'max' mode.
 */
function maxPool$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    assertNotComplex(x, 'maxPool');
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    // MaxPool does not support dilations; they are pinned to 1.
    const dilations = 1;
    assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. ' +
        `Got strides ${strides} and dilations '${dilations}'`);
    const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
    const isNoopPool = convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
        arraysEqual(convInfo.inShape, convInfo.outShape);
    if (isNoopPool) {
        return identity$1({ inputs: { x }, backend });
    }
    const xValues = backend.data.get(x.dataId).values;
    // Note: these are memory strides of `x`, distinct from the pooling
    // `strides` attribute destructured above.
    const xStrides = computeStrides(x.shape);
    const outBuf = pool$1(xValues, x.shape, x.dtype, xStrides, convInfo, 'max');
    return backend.makeTensorInfo(convInfo.outShape, x.dtype, outBuf.values);
}
/** Registration config binding the `MaxPool` kernel to the CPU backend. */
const maxPoolConfig = {
    kernelName: MaxPool,
    backendName: 'cpu',
    kernelFunc: maxPool$1
};
75764
75765 /**
75766 * @license
75767 * Copyright 2020 Google LLC. All Rights Reserved.
75768 * Licensed under the Apache License, Version 2.0 (the "License");
75769 * you may not use this file except in compliance with the License.
75770 * You may obtain a copy of the License at
75771 *
75772 * http://www.apache.org/licenses/LICENSE-2.0
75773 *
75774 * Unless required by applicable law or agreed to in writing, software
75775 * distributed under the License is distributed on an "AS IS" BASIS,
75776 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75777 * See the License for the specific language governing permissions and
75778 * limitations under the License.
75779 * =============================================================================
75780 */
/**
 * CPU kernel for 3-D max pooling.
 *
 * Computes pooling geometry from the attrs (dilations fixed at 1), runs the
 * generic 3-D pooling routine in 'max' mode, and returns a float32 output.
 */
function maxPool3D(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { filterSize, strides, pad, dimRoundingMode, dataFormat } = attrs;
    assertNotComplex(x, 'maxPool3d');
    // Dilations are fixed at 1 for MaxPool3D.
    const convInfo = computePool3DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode, dataFormat);
    const xValues = backend.data.get(x.dataId).values;
    const xStrides = computeStrides(x.shape);
    const pooled = pool3d$1(xValues, x.shape, x.dtype, xStrides, convInfo, 'max');
    return backend.makeTensorInfo(pooled.shape, 'float32', pooled.values);
}
/** Registration config binding the `MaxPool3D` kernel to the CPU backend. */
const maxPool3DConfig = {
    kernelName: MaxPool3D,
    backendName: 'cpu',
    kernelFunc: maxPool3D
};
75796
75797 /**
75798 * @license
75799 * Copyright 2020 Google LLC. All Rights Reserved.
75800 * Licensed under the Apache License, Version 2.0 (the "License");
75801 * you may not use this file except in compliance with the License.
75802 * You may obtain a copy of the License at
75803 *
75804 * http://www.apache.org/licenses/LICENSE-2.0
75805 *
75806 * Unless required by applicable law or agreed to in writing, software
75807 * distributed under the License is distributed on an "AS IS" BASIS,
75808 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75809 * See the License for the specific language governing permissions and
75810 * limitations under the License.
75811 * =============================================================================
75812 */
/**
 * CPU kernel for `MaxPool3DGrad`: gradient of 3-D max pooling.
 *
 * For each input position, sums the incoming gradient `dy` over every output
 * window in which that position was the argmax, as recorded by
 * `maxPool3dPositions`. All other positions receive zero gradient.
 */
function maxPool3DGrad(args) {
    const { inputs, backend, attrs } = args;
    const { dy, input } = inputs;
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    assertNotComplex([dy, input], 'maxPool3DGrad');
    const convInfo = computePool3DInfo(input.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
    const inputBuf = backend.bufferSync(input);
    // Per-output-element flattened position of the max within its window.
    const maxPosBuf = maxPool3dPositions(inputBuf, convInfo);
    const strideDepth = convInfo.strideDepth;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationDepth = convInfo.dilationDepth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterDepth = convInfo.effectiveFilterDepth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    // Padding seen from the gradient's side (transposed-convolution style).
    const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
    const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
    const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
    const dx = buffer(input.shape, 'float32');
    const dyBuf = backend.bufferSync(dy);
    for (let batch = 0; batch < convInfo.batchSize; ++batch) {
        for (let channel = 0; channel < convInfo.inChannels; ++channel) {
            for (let dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) {
                for (let dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) {
                    for (let dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) {
                        // Shader code begins
                        const dyDepthCorner = dxDepth - padFront;
                        const dyRowCorner = dxRow - padTop;
                        const dyColCorner = dxCol - padLeft;
                        let dotProd = 0;
                        for (let wDepth = 0; wDepth < effectiveFilterDepth; wDepth += dilationDepth) {
                            // Skip output coordinates that don't land on the stride grid
                            // (fractional dyDepth) or fall outside the output volume.
                            const dyDepth = (dyDepthCorner + wDepth) / strideDepth;
                            if (dyDepth < 0 || dyDepth >= convInfo.outDepth ||
                                Math.floor(dyDepth) !== dyDepth) {
                                continue;
                            }
                            for (let wRow = 0; wRow < effectiveFilterHeight; wRow += dilationHeight) {
                                const dyRow = (dyRowCorner + wRow) / strideHeight;
                                if (dyRow < 0 || dyRow >= convInfo.outHeight ||
                                    Math.floor(dyRow) !== dyRow) {
                                    continue;
                                }
                                for (let wCol = 0; wCol < effectiveFilterWidth; wCol += dilationWidth) {
                                    const dyCol = (dyColCorner + wCol) / strideWidth;
                                    if (dyCol < 0 || dyCol >= convInfo.outWidth ||
                                        Math.floor(dyCol) !== dyCol) {
                                        continue;
                                    }
                                    // The stored position appears to be counted from the far
                                    // corner of the window, hence the inversion before
                                    // comparing with curPos — see maxPool3dPositions.
                                    const maxPos = effectiveFilterDepth * effectiveFilterHeight *
                                        effectiveFilterWidth -
                                        1 -
                                        maxPosBuf.get(batch, dyDepth, dyRow, dyCol, channel);
                                    const curPos = wDepth * effectiveFilterHeight * effectiveFilterWidth +
                                        wRow * effectiveFilterWidth + wCol;
                                    // Only the argmax position of each window receives gradient.
                                    const mask = maxPos === curPos ? 1 : 0;
                                    if (mask === 0) {
                                        continue;
                                    }
                                    const pixel = dyBuf.get(batch, dyDepth, dyRow, dyCol, channel);
                                    dotProd += pixel * mask;
                                }
                            }
                        }
                        dx.set(dotProd, batch, dxDepth, dxRow, dxCol, channel);
                    }
                }
            }
        }
    }
    return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
/** Registration config binding the `MaxPool3DGrad` kernel to the CPU backend. */
const maxPool3DGradConfig$1 = {
    kernelName: MaxPool3DGrad,
    backendName: 'cpu',
    kernelFunc: maxPool3DGrad
};
75891
75892 /**
75893 * @license
75894 * Copyright 2020 Google LLC. All Rights Reserved.
75895 * Licensed under the Apache License, Version 2.0 (the "License");
75896 * you may not use this file except in compliance with the License.
75897 * You may obtain a copy of the License at
75898 *
75899 * http://www.apache.org/licenses/LICENSE-2.0
75900 *
75901 * Unless required by applicable law or agreed to in writing, software
75902 * distributed under the License is distributed on an "AS IS" BASIS,
75903 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75904 * See the License for the specific language governing permissions and
75905 * limitations under the License.
75906 * =============================================================================
75907 */
/**
 * CPU kernel for `MaxPoolGrad`: gradient of 2-D max pooling.
 *
 * Recomputes the argmax positions from the forward input, then for each
 * input position accumulates `dy` over every output window in which that
 * position was the maximum. Non-argmax positions get zero gradient.
 */
function maxPoolGrad$1(args) {
    const { inputs, backend, attrs } = args;
    const { dy, input, output } = inputs;
    const x = input;
    assertNotComplex([input, output], 'maxPoolGrad');
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
    const xValues = backend.data.get(x.dataId).values;
    // Per-output-element flattened position of the max within its window.
    const maxPosBuf = buffer(convInfo.outShape, x.dtype, maxPoolPositions(xValues, x.shape, x.dtype, convInfo).values);
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    // Padding seen from the gradient's side (transposed-convolution style).
    const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
    const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
    const dx = buffer(x.shape, 'float32');
    const dyData = backend.data.get(dy.dataId).values;
    const dyBuf = buffer(dy.shape, 'float32', dyData);
    for (let b = 0; b < convInfo.batchSize; ++b) {
        for (let d = 0; d < convInfo.inChannels; ++d) {
            for (let dxR = 0; dxR < convInfo.inHeight; ++dxR) {
                for (let dxC = 0; dxC < convInfo.inWidth; ++dxC) {
                    // Shader code begins.
                    const dyRCorner = dxR - padTop;
                    const dyCCorner = dxC - padLeft;
                    let dotProd = 0;
                    for (let wR = 0; wR < effectiveFilterHeight; wR += dilationHeight) {
                        // Skip output rows that don't land on the stride grid
                        // (fractional dyR) or fall outside the output.
                        const dyR = (dyRCorner + wR) / strideHeight;
                        if (dyR < 0 || dyR >= convInfo.outHeight ||
                            Math.floor(dyR) !== dyR) {
                            continue;
                        }
                        for (let wC = 0; wC < effectiveFilterWidth; wC += dilationWidth) {
                            const dyC = (dyCCorner + wC) / strideWidth;
                            if (dyC < 0 || dyC >= convInfo.outWidth ||
                                Math.floor(dyC) !== dyC) {
                                continue;
                            }
                            // The stored position appears to be counted from the far
                            // corner of the window, hence the inversion — see
                            // maxPoolPositions.
                            const maxPos = effectiveFilterHeight * effectiveFilterWidth - 1 -
                                maxPosBuf.get(b, dyR, dyC, d);
                            const curPos = wR * effectiveFilterWidth + wC;
                            // Only the argmax position of each window receives gradient.
                            const mask = maxPos === curPos ? 1 : 0;
                            if (mask === 0) {
                                continue;
                            }
                            const pixel = dyBuf.get(b, dyR, dyC, d);
                            dotProd += pixel * mask;
                        }
                    }
                    dx.set(dotProd, b, dxR, dxC, d);
                }
            }
        }
    }
    return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
/** Registration config binding the `MaxPoolGrad` kernel to the CPU backend. */
const maxPoolGradConfig$1 = {
    kernelName: MaxPoolGrad,
    backendName: 'cpu',
    kernelFunc: maxPoolGrad$1
};
75971
75972 /**
75973 * @license
75974 * Copyright 2020 Google LLC. All Rights Reserved.
75975 * Licensed under the Apache License, Version 2.0 (the "License");
75976 * you may not use this file except in compliance with the License.
75977 * You may obtain a copy of the License at
75978 *
75979 * http://www.apache.org/licenses/LICENSE-2.0
75980 *
75981 * Unless required by applicable law or agreed to in writing, software
75982 * distributed under the License is distributed on an "AS IS" BASIS,
75983 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75984 * See the License for the specific language governing permissions and
75985 * limitations under the License.
75986 * =============================================================================
75987 */
/**
 * Computes both halves of MaxPoolWithArgmax in one pass: the max-pooled
 * values and the flattened positions of each window's maximum.
 *
 * Returns a two-element array: [pooledValues, argmaxPositions].
 */
function maxPoolWithArgmaxImpl(xValues, xShape, dtype, includeBatchInIndex, convInfo) {
    const inputStrides = computeStrides(xShape);
    const pooled = pool$1(xValues, xShape, dtype, inputStrides, convInfo, 'max');
    // `true` requests flattened positions; includeBatchInIndex controls
    // whether the batch offset is folded into each index.
    const positions = maxPoolPositions(xValues, xShape, dtype, convInfo, true, includeBatchInIndex);
    return [pooled.values, positions.values];
}
75994
75995 /**
75996 * @license
75997 * Copyright 2020 Google LLC. All Rights Reserved.
75998 * Licensed under the Apache License, Version 2.0 (the "License");
75999 * you may not use this file except in compliance with the License.
76000 * You may obtain a copy of the License at
76001 *
76002 * http://www.apache.org/licenses/LICENSE-2.0
76003 *
76004 * Unless required by applicable law or agreed to in writing, software
76005 * distributed under the License is distributed on an "AS IS" BASIS,
76006 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76007 * See the License for the specific language governing permissions and
76008 * limitations under the License.
76009 * =============================================================================
76010 */
/**
 * Registration config binding `MaxPoolWithArgmax` to the CPU backend.
 *
 * The kernel returns two tensors: the pooled values (same dtype as `x`) and
 * the argmax indices (int32).
 */
const maxPoolWithArgmaxConfig = {
    kernelName: MaxPoolWithArgmax,
    backendName: 'cpu',
    kernelFunc: ({ inputs, attrs, backend }) => {
        const { x } = inputs;
        const { filterSize, strides, pad, includeBatchInIndex } = attrs;
        const cpuBackend = backend;
        assertNotComplex(x, 'MaxPoolWithArgmax');
        const values = cpuBackend.data.get(x.dataId).values;
        // Dilations are fixed at [1, 1] for MaxPoolWithArgmax.
        const convInfo = computePool2DInfo(x.shape, filterSize, strides, [1, 1], pad);
        const [pooled, indexes] = maxPoolWithArgmaxImpl(values, x.shape, x.dtype, includeBatchInIndex, convInfo);
        const pooledDataId = cpuBackend.write(pooled, convInfo.outShape, x.dtype);
        // Fix: the index tensor is declared 'int32' in the returned TensorInfo,
        // so store its backing data as 'int32' as well. It was previously
        // written with `x.dtype`, leaving the stored dtype inconsistent with
        // the declared output dtype.
        const indexesDataId = cpuBackend.write(indexes, convInfo.outShape, 'int32');
        return [
            { dataId: pooledDataId, shape: convInfo.outShape, dtype: x.dtype },
            { dataId: indexesDataId, shape: convInfo.outShape, dtype: 'int32' }
        ];
    }
};
76030
76031 /**
76032 * @license
76033 * Copyright 2020 Google LLC. All Rights Reserved.
76034 * Licensed under the Apache License, Version 2.0 (the "License");
76035 * you may not use this file except in compliance with the License.
76036 * You may obtain a copy of the License at
76037 *
76038 * http://www.apache.org/licenses/LICENSE-2.0
76039 *
76040 * Unless required by applicable law or agreed to in writing, software
76041 * distributed under the License is distributed on an "AS IS" BASIS,
76042 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76043 * See the License for the specific language governing permissions and
76044 * limitations under the License.
76045 * =============================================================================
76046 */
/**
 * CPU kernel for `Mean`: arithmetic mean of `x` over `axis`.
 *
 * Implemented as sum(x / reduceSize) rather than sum(x) / reduceSize — each
 * element is scaled by the reciprocal of the reduction size before summing.
 */
function mean$3(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    const axes = parseAxisParam(axis, x.shape);
    const [, reduceShape] = computeOutAndReduceShapes(x.shape, axes);
    const reduceSize = sizeFromShape(reduceShape);
    // Scalar divisor: the number of elements being reduced over.
    const reduceSizeScalar = backend.makeTensorInfo([], 'float32', new Float32Array([reduceSize]));
    const xFloat = cast$2({ inputs: { x }, backend, attrs: { dtype: 'float32' } });
    const scaled = div$1({ inputs: { a: xFloat, b: reduceSizeScalar }, backend });
    const result = sum$3({ inputs: { x: scaled }, backend, attrs: { axis, keepDims } });
    for (const t of [reduceSizeScalar, xFloat, scaled]) {
        backend.disposeIntermediateTensorInfo(t);
    }
    return result;
}
/** Registration config binding the `Mean` kernel to the CPU backend. */
const meanConfig = {
    kernelName: Mean,
    backendName: 'cpu',
    kernelFunc: mean$3
};
76071
76072 /**
76073 * @license
76074 * Copyright 2020 Google LLC. All Rights Reserved.
76075 * Licensed under the Apache License, Version 2.0 (the "License");
76076 * you may not use this file except in compliance with the License.
76077 * You may obtain a copy of the License at
76078 *
76079 * http://www.apache.org/licenses/LICENSE-2.0
76080 *
76081 * Unless required by applicable law or agreed to in writing, software
76082 * distributed under the License is distributed on an "AS IS" BASIS,
76083 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76084 * See the License for the specific language governing permissions and
76085 * limitations under the License.
76086 * =============================================================================
76087 */
/**
 * CPU kernel for `Min`: reduces `x` by minimum over `axis`.
 *
 * NaN values propagate: any NaN inside a reduction window wins. When the
 * reduced axes are not inner-most, the input is transposed first so the
 * reduction can run over contiguous slices.
 */
function min$2(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    assertNotComplex(x, 'min');
    const origAxes = parseAxisParam(axis, x.shape);
    let axes = origAxes;
    const permutedAxes = getAxesPermutation(axes, x.shape.length);
    let $x = x;
    if (permutedAxes != null) {
        // Move the reduced axes to the innermost positions.
        $x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
        axes = getInnerMostAxes(axes.length, x.shape.length);
    }
    assertAxesAreInnerMostDims('min', axes, $x.shape.length);
    const [outShape, reduceShape] = computeOutAndReduceShapes($x.shape, axes);
    const reduceSize = sizeFromShape(reduceShape);
    const outVals = makeZerosTypedArray(sizeFromShape(outShape), $x.dtype);
    const srcVals = backend.data.get($x.dataId).values;
    for (let outIdx = 0; outIdx < outVals.length; ++outIdx) {
        const base = outIdx * reduceSize;
        let running = srcVals[base];
        for (let j = 1; j < reduceSize; ++j) {
            const candidate = srcVals[base + j];
            // `<` is always false for NaN, so test NaN explicitly to propagate it.
            if (Number.isNaN(candidate) || candidate < running) {
                running = candidate;
            }
        }
        outVals[outIdx] = running;
    }
    if (permutedAxes != null) {
        backend.disposeIntermediateTensorInfo($x);
    }
    const result = backend.makeTensorInfo(outShape, $x.dtype, outVals);
    if (!keepDims) {
        return result;
    }
    // Re-insert the reduced dimensions as size-1 axes.
    const expandedShape = expandShapeToKeepDim(outShape, origAxes);
    const reshapedResult = reshape$2({ inputs: { x: result }, backend, attrs: { shape: expandedShape } });
    backend.disposeIntermediateTensorInfo(result);
    return reshapedResult;
}
/** Registration config binding the `Min` kernel to the CPU backend. */
const minConfig = {
    kernelName: Min,
    backendName: 'cpu',
    kernelFunc: min$2
};
76135
76136 /**
76137 * @license
76138 * Copyright 2020 Google LLC. All Rights Reserved.
76139 * Licensed under the Apache License, Version 2.0 (the "License");
76140 * you may not use this file except in compliance with the License.
76141 * You may obtain a copy of the License at
76142 *
76143 * http://www.apache.org/licenses/LICENSE-2.0
76144 *
76145 * Unless required by applicable law or agreed to in writing, software
76146 * distributed under the License is distributed on an "AS IS" BASIS,
76147 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76148 * See the License for the specific language governing permissions and
76149 * limitations under the License.
76150 * =============================================================================
76151 */
/**
 * CPU kernel for `MirrorPad`: pads `x` by mirroring values across each
 * dimension's edges.
 *
 * 'reflect' mode excludes the border element from the mirror image;
 * 'symmetric' mode includes it (controlled by `offset` below).
 */
function mirrorPad$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { paddings, mode } = attrs;
    assertNotComplex(x, 'mirrorPad');
    // Each output dimension grows by its before/after padding amounts.
    const outShape = paddings.map((p, i) => p[0] /* beforePad */ + x.shape[i] + p[1] /* afterPad */);
    const start = paddings.map(p => p[0]);
    const end = paddings.map((p, i) => p[0] + x.shape[i]);
    const offset = mode === 'reflect' ? 0 : 1;
    const xVals = backend.data.get(x.dataId).values;
    const xRank = x.shape.length;
    const xStrides = computeStrides(x.shape);
    const resultSize = sizeFromShape(outShape);
    const resultRank = outShape.length;
    const resultStrides = computeStrides(outShape);
    const resVals = getTypedArrayFromDType(x.dtype, resultSize);
    for (let i = 0; i < resultSize; i++) {
        let coords = indexToLoc(i, resultRank, resultStrides);
        // Fold out-of-range coordinates back into the source range by
        // mirroring around the padded region's edges.
        // (Fix: the dimension index is named `d` so it no longer shadows the
        // outer flat-index `i` — the previous shadowing was bug-prone.)
        for (let d = 0; d < resultRank; d++) {
            if (coords[d] < start[d]) {
                coords[d] = start[d] * 2 - coords[d] - offset;
            }
            else if (coords[d] >= end[d]) {
                coords[d] = (end[d] - 1) * 2 - coords[d] + offset;
            }
        }
        // Shift from output coordinates to input coordinates.
        coords = coords.map((c, d) => c - start[d]);
        const inIndex = locToIndex(coords, xRank, xStrides);
        resVals[i] = xVals[inIndex];
    }
    const outId = backend.write(resVals, outShape, x.dtype);
    return { dataId: outId, shape: outShape, dtype: x.dtype };
}
/** Registration config binding the `MirrorPad` kernel to the CPU backend. */
const mirrorPadConfig = {
    kernelName: MirrorPad,
    backendName: 'cpu',
    kernelFunc: mirrorPad$1
};
76190
76191 /**
76192 * @license
76193 * Copyright 2020 Google LLC. All Rights Reserved.
76194 * Licensed under the Apache License, Version 2.0 (the "License");
76195 * you may not use this file except in compliance with the License.
76196 * You may obtain a copy of the License at
76197 *
76198 * http://www.apache.org/licenses/LICENSE-2.0
76199 *
76200 * Unless required by applicable law or agreed to in writing, software
76201 * distributed under the License is distributed on an "AS IS" BASIS,
76202 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76203 * See the License for the specific language governing permissions and
76204 * limitations under the License.
76205 * =============================================================================
76206 */
/**
 * Element-wise implementation of `Mod` with Python-style sign semantics:
 * the result takes the sign of the divisor, not the dividend.
 */
const modImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => {
    const rem = aValue % bValue;
    // Same-sign operands (zero counts as non-negative): JS `%` already
    // matches the desired mod semantics.
    if ((aValue < 0) === (bValue < 0)) {
        return rem;
    }
    // Opposite signs: shift the remainder so it takes the sign of bValue.
    return (rem + bValue) % bValue;
}));
// Binary kernel wrapper for `Mod` built from the shared element-wise impl.
const mod$1 = binaryKernelFunc(Mod, modImpl);
/** Registration config binding the `Mod` kernel to the CPU backend. */
const modConfig = {
    kernelName: Mod,
    backendName: 'cpu',
    kernelFunc: mod$1
};
76222
76223 /**
76224 * @license
76225 * Copyright 2020 Google LLC. All Rights Reserved.
76226 * Licensed under the Apache License, Version 2.0 (the "License");
76227 * you may not use this file except in compliance with the License.
76228 * You may obtain a copy of the License at
76229 *
76230 * http://www.apache.org/licenses/LICENSE-2.0
76231 *
76232 * Unless required by applicable law or agreed to in writing, software
76233 * distributed under the License is distributed on an "AS IS" BASIS,
76234 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76235 * See the License for the specific language governing permissions and
76236 * limitations under the License.
76237 * =============================================================================
76238 */
/**
 * CPU kernel for `Softmax` along the last dimension of `logits`.
 *
 * Uses the numerically stable formulation: subtracts the per-row max before
 * exponentiating, then normalizes by the row sum of exponentials. Only the
 * last dimension is supported; other dims throw.
 */
function softmax$2(args) {
    const { inputs, backend, attrs } = args;
    const { logits } = inputs;
    const { dim } = attrs;
    const logitsRank = logits.shape.length;
    let $dim = dim;
    if ($dim === -1) {
        $dim = logitsRank - 1;
    }
    if ($dim !== logitsRank - 1) {
        throw Error('Softmax along a non-last dimension is not yet supported. ' +
            `Logits was rank ${logitsRank} and dim was ${$dim}`);
    }
    const axes = parseAxisParam([$dim], logits.shape);
    // Per-row max, for numerical stability.
    const maxLogit = max$2({
        inputs: { x: logits },
        backend,
        attrs: { reductionIndices: axes, keepDims: false }
    });
    const keepDimsShape = expandShapeToKeepDim(maxLogit.shape, axes);
    const maxLogitReshaped = reshape$2({ inputs: { x: maxLogit }, backend, attrs: { shape: keepDimsShape } });
    const shifted = sub$1({ inputs: { a: logits, b: maxLogitReshaped }, backend });
    const exponentiated = exp$1({ inputs: { x: shifted }, backend });
    const sumExp = sum$3({ inputs: { x: exponentiated }, backend, attrs: { axis: axes, keepDims: false } });
    const sumExpReshaped = reshape$2({ inputs: { x: sumExp }, backend, attrs: { shape: keepDimsShape } });
    const result = div$1({ inputs: { a: exponentiated, b: sumExpReshaped }, backend });
    // Release every intermediate tensor produced above.
    for (const t of [maxLogit, maxLogitReshaped, shifted, exponentiated, sumExp, sumExpReshaped]) {
        backend.disposeIntermediateTensorInfo(t);
    }
    return result;
}
/** Registration config binding the `Softmax` kernel to the CPU backend. */
const softmaxConfig = {
    kernelName: Softmax,
    backendName: 'cpu',
    kernelFunc: softmax$2
};
76278
76279 /**
76280 * @license
76281 * Copyright 2020 Google LLC. All Rights Reserved.
76282 * Licensed under the Apache License, Version 2.0 (the "License");
76283 * you may not use this file except in compliance with the License.
76284 * You may obtain a copy of the License at
76285 *
76286 * http://www.apache.org/licenses/LICENSE-2.0
76287 *
76288 * Unless required by applicable law or agreed to in writing, software
76289 * distributed under the License is distributed on an "AS IS" BASIS,
76290 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76291 * See the License for the specific language governing permissions and
76292 * limitations under the License.
76293 * =============================================================================
76294 */
/**
 * CPU kernel for `Multinomial`: draws `numSamples` class indices per batch
 * row from a 2-D `[batchSize, numEvents]` distribution via inverse-CDF
 * sampling. Output shape is `[batchSize, numSamples]`, dtype int32.
 */
function multinomial$1(args) {
    const { inputs, backend, attrs } = args;
    const { logits } = inputs;
    const { numSamples, seed, normalized } = attrs;
    assertNotComplex(logits, 'multinomial');
    // If the input is not already a probability distribution, apply softmax
    // over the last dimension first.
    const probabilities = normalized ?
        logits :
        softmax$2({ inputs: { logits }, backend, attrs: { dim: -1 } });
    const batchSize = probabilities.shape[0];
    const numEvents = probabilities.shape[1];
    const probVals = backend.data.get(probabilities.dataId).values;
    const resShape = [batchSize, numSamples];
    const resVals = makeZerosTypedArray(sizeFromShape(resShape), 'int32');
    for (let b = 0; b < batchSize; ++b) {
        const offset = b * numEvents;
        // The cdf won't include the last event. It will be implicit if no other
        // event happened.
        const cdf = new Float32Array(numEvents - 1);
        cdf[0] = probVals[offset];
        for (let event = 1; event < cdf.length; ++event) {
            cdf[event] = cdf[event - 1] + probVals[offset + event];
        }
        // NOTE(review): the PRNG is re-seeded with the same `seed` for every
        // batch row, so all rows draw from identical random sequences —
        // confirm this matches the intended sampling behavior.
        const random = seedrandom_1(seed.toString());
        const outOffset = b * numSamples;
        for (let sampleId = 0; sampleId < numSamples; ++sampleId) {
            const r = random();
            // Assume last event happened by default.
            resVals[outOffset + sampleId] = cdf.length;
            // Pick the first event whose cumulative probability exceeds r.
            for (let event = 0; event < cdf.length; event++) {
                if (r < cdf[event]) {
                    resVals[outOffset + sampleId] = event;
                    break;
                }
            }
        }
    }
    if (!normalized) {
        // The softmax result was an intermediate tensor owned by this kernel.
        backend.disposeIntermediateTensorInfo(probabilities);
    }
    return backend.makeTensorInfo(resShape, 'int32', resVals);
}
/** Registration config binding the `Multinomial` kernel to the CPU backend. */
const multinomialConfig = {
    kernelName: Multinomial,
    backendName: 'cpu',
    kernelFunc: multinomial$1
};
76341
76342 /**
76343 * @license
76344 * Copyright 2020 Google LLC. All Rights Reserved.
76345 * Licensed under the Apache License, Version 2.0 (the "License");
76346 * you may not use this file except in compliance with the License.
76347 * You may obtain a copy of the License at
76348 *
76349 * http://www.apache.org/licenses/LICENSE-2.0
76350 *
76351 * Unless required by applicable law or agreed to in writing, software
76352 * distributed under the License is distributed on an "AS IS" BASIS,
76353 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76354 * See the License for the specific language governing permissions and
76355 * limitations under the License.
76356 * =============================================================================
76357 */
// Backend-local alias for the shared NMS v3 implementation.
const nonMaxSuppressionV3Impl$1 = nonMaxSuppressionV3Impl;
/**
 * CPU kernel for `NonMaxSuppressionV3`: selects up to `maxOutputSize` box
 * indices, greedily by score, suppressing boxes whose IoU with an already
 * selected box exceeds `iouThreshold`.
 */
function nonMaxSuppressionV3(args) {
    const { inputs, backend, attrs } = args;
    const { boxes, scores } = inputs;
    const { maxOutputSize, iouThreshold, scoreThreshold } = attrs;
    assertNotComplex(boxes, 'NonMaxSuppression');
    const boxesData = backend.data.get(boxes.dataId).values;
    const scoresData = backend.data.get(scores.dataId).values;
    const { selectedIndices } = nonMaxSuppressionV3Impl$1(boxesData, scoresData, maxOutputSize, iouThreshold, scoreThreshold);
    return backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices));
}
/** Registration config binding `NonMaxSuppressionV3` to the CPU backend. */
const nonMaxSuppressionV3Config = {
    kernelName: NonMaxSuppressionV3,
    backendName: 'cpu',
    kernelFunc: nonMaxSuppressionV3
};
76374
76375 /**
76376 * @license
76377 * Copyright 2020 Google LLC. All Rights Reserved.
76378 * Licensed under the Apache License, Version 2.0 (the "License");
76379 * you may not use this file except in compliance with the License.
76380 * You may obtain a copy of the License at
76381 *
76382 * http://www.apache.org/licenses/LICENSE-2.0
76383 *
76384 * Unless required by applicable law or agreed to in writing, software
76385 * distributed under the License is distributed on an "AS IS" BASIS,
76386 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76387 * See the License for the specific language governing permissions and
76388 * limitations under the License.
76389 * =============================================================================
76390 */
// Backend-local alias for the shared NMS v4 implementation.
const nonMaxSuppressionV4Impl$1 = nonMaxSuppressionV4Impl;
/**
 * CPU kernel for `NonMaxSuppressionV4` (padded NMS).
 *
 * Returns two tensors: the selected box indices (optionally padded to
 * `maxOutputSize` when `padToMaxOutputSize` is set) and a scalar holding the
 * number of valid entries.
 */
function nonMaxSuppressionV4(args) {
    const { inputs, backend, attrs } = args;
    const { boxes, scores } = inputs;
    const { maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize } = attrs;
    assertNotComplex(boxes, 'NonMaxSuppressionPadded');
    const boxesData = backend.data.get(boxes.dataId).values;
    const scoresData = backend.data.get(scores.dataId).values;
    const { selectedIndices, validOutputs } = nonMaxSuppressionV4Impl$1(boxesData, scoresData, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize);
    const indicesInfo = backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices));
    const validInfo = backend.makeTensorInfo([], 'int32', new Int32Array([validOutputs]));
    return [indicesInfo, validInfo];
}
/** Registration config binding `NonMaxSuppressionV4` to the CPU backend. */
const nonMaxSuppressionV4Config = {
    kernelName: NonMaxSuppressionV4,
    backendName: 'cpu',
    kernelFunc: nonMaxSuppressionV4
};
76410
76411 /**
76412 * @license
76413 * Copyright 2019 Google LLC. All Rights Reserved.
76414 * Licensed under the Apache License, Version 2.0 (the "License");
76415 * you may not use this file except in compliance with the License.
76416 * You may obtain a copy of the License at
76417 *
76418 * http://www.apache.org/licenses/LICENSE-2.0
76419 *
76420 * Unless required by applicable law or agreed to in writing, software
76421 * distributed under the License is distributed on an "AS IS" BASIS,
76422 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76423 * See the License for the specific language governing permissions and
76424 * limitations under the License.
76425 * =============================================================================
76426 */
// Backend-local alias for the shared NMS v5 (soft-NMS) implementation.
const nonMaxSuppressionV5Impl$1 = nonMaxSuppressionV5Impl;
/**
 * CPU kernel for `NonMaxSuppressionV5` (NMS with soft-NMS score decay).
 *
 * Returns two tensors: the selected box indices (int32) and the
 * corresponding (possibly decayed) scores (float32).
 */
function nonMaxSuppressionV5(args) {
    const { inputs, backend, attrs } = args;
    const { boxes, scores } = inputs;
    const { maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma } = attrs;
    assertNotComplex(boxes, 'NonMaxSuppressionWithScore');
    const boxesData = backend.data.get(boxes.dataId).values;
    const scoresData = backend.data.get(scores.dataId).values;
    const { selectedIndices, selectedScores } = nonMaxSuppressionV5Impl$1(boxesData, scoresData, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
    const indicesInfo = backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices));
    const scoresInfo = backend.makeTensorInfo([selectedScores.length], 'float32', new Float32Array(selectedScores));
    return [indicesInfo, scoresInfo];
}
/** Registration config binding `NonMaxSuppressionV5` to the CPU backend. */
const nonMaxSuppressionV5Config = {
    kernelName: NonMaxSuppressionV5,
    backendName: 'cpu',
    kernelFunc: nonMaxSuppressionV5
};
76450
76451 /**
76452 * @license
76453 * Copyright 2020 Google LLC. All Rights Reserved.
76454 * Licensed under the Apache License, Version 2.0 (the "License");
76455 * you may not use this file except in compliance with the License.
76456 * You may obtain a copy of the License at
76457 *
76458 * http://www.apache.org/licenses/LICENSE-2.0
76459 *
76460 * Unless required by applicable law or agreed to in writing, software
76461 * distributed under the License is distributed on an "AS IS" BASIS,
76462 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76463 * See the License for the specific language governing permissions and
76464 * limitations under the License.
76465 * =============================================================================
76466 */
/**
 * CPU kernel for `OneHot`: expands each index in `indices` into a
 * length-`depth` vector that is `offValue` everywhere except `onValue` at
 * the index position. Output shape is `[...indices.shape, depth]`.
 */
function oneHot$2(args) {
    const { inputs, backend, attrs } = args;
    const { indices } = inputs;
    const { depth, onValue, offValue } = attrs;
    assertNotComplex(indices, 'oneHot');
    const indicesSize = sizeFromShape(indices.shape);
    // NOTE(review): values are built in a Float32Array but the output tensor
    // below is declared 'int32' — confirm this dtype mismatch is intended.
    const res = new Float32Array(indicesSize * depth);
    res.fill(offValue);
    const indicesVal = backend.data.get(indices.dataId).values;
    for (let event = 0; event < indicesSize; ++event) {
        // Indices outside [0, depth) are silently skipped: the row stays offValue.
        if (indicesVal[event] >= 0 && indicesVal[event] < depth) {
            res[event * depth + indicesVal[event]] = onValue;
        }
    }
    return backend.makeTensorInfo([...indices.shape, depth], 'int32', res);
}
    // Kernel registration: maps the OneHot kernel name to its CPU implementation.
    const oneHotConfig = {
        kernelName: OneHot,
        backendName: 'cpu',
        kernelFunc: oneHot$2
    };
76488
76489 /**
76490 * @license
76491 * Copyright 2020 Google LLC. All Rights Reserved.
76492 * Licensed under the Apache License, Version 2.0 (the "License");
76493 * you may not use this file except in compliance with the License.
76494 * You may obtain a copy of the License at
76495 *
76496 * http://www.apache.org/licenses/LICENSE-2.0
76497 *
76498 * Unless required by applicable law or agreed to in writing, software
76499 * distributed under the License is distributed on an "AS IS" BASIS,
76500 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76501 * See the License for the specific language governing permissions and
76502 * limitations under the License.
76503 * =============================================================================
76504 */
76505 function zerosLike$2(args) {
76506 const { inputs, backend } = args;
76507 const { x } = inputs;
76508 if (x.dtype === 'string') {
76509 throw new Error('zerosLike is not supported for string tensors');
76510 }
76511 else if (x.dtype === 'complex64') {
76512 const realPart = real$1({ inputs: { input: x }, backend });
76513 const r = zerosLike$2({ inputs: { x: realPart }, backend });
76514 const imagPart = imag$1({ inputs: { input: x }, backend });
76515 const i = zerosLike$2({ inputs: { x: imagPart }, backend });
76516 const result = complex$1({ inputs: { real: r, imag: i }, backend });
76517 backend.disposeIntermediateTensorInfo(realPart);
76518 backend.disposeIntermediateTensorInfo(r);
76519 backend.disposeIntermediateTensorInfo(imagPart);
76520 backend.disposeIntermediateTensorInfo(i);
76521 return result;
76522 }
76523 else {
76524 return fill$1({ backend, attrs: { shape: x.shape, value: 0, dtype: x.dtype } });
76525 }
76526 }
    // Kernel registration: maps the ZerosLike kernel name to its CPU implementation.
    const zerosLikeConfig = {
        kernelName: ZerosLike,
        backendName: 'cpu',
        kernelFunc: zerosLike$2
    };
76532
76533 /**
76534 * @license
76535 * Copyright 2020 Google LLC. All Rights Reserved.
76536 * Licensed under the Apache License, Version 2.0 (the "License");
76537 * you may not use this file except in compliance with the License.
76538 * You may obtain a copy of the License at
76539 *
76540 * http://www.apache.org/licenses/LICENSE-2.0
76541 *
76542 * Unless required by applicable law or agreed to in writing, software
76543 * distributed under the License is distributed on an "AS IS" BASIS,
76544 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76545 * See the License for the specific language governing permissions and
76546 * limitations under the License.
76547 * =============================================================================
76548 */
76549 function onesLike$2(args) {
76550 const { inputs, backend } = args;
76551 const { x } = inputs;
76552 if (x.dtype === 'string') {
76553 throw new Error('onesLike is not supported for string tensors');
76554 }
76555 else if (x.dtype === 'complex64') {
76556 const realPart = real$1({ inputs: { input: x }, backend });
76557 const r = onesLike$2({ inputs: { x: realPart }, backend });
76558 const imagPart = imag$1({ inputs: { input: x }, backend });
76559 const i = zerosLike$2({ inputs: { x: imagPart }, backend });
76560 const result = complex$1({ inputs: { real: r, imag: i }, backend });
76561 backend.disposeIntermediateTensorInfo(realPart);
76562 backend.disposeIntermediateTensorInfo(r);
76563 backend.disposeIntermediateTensorInfo(imagPart);
76564 backend.disposeIntermediateTensorInfo(i);
76565 return result;
76566 }
76567 else {
76568 return fill$1({ backend, attrs: { shape: x.shape, value: 1, dtype: x.dtype } });
76569 }
76570 }
    // Kernel registration: maps the OnesLike kernel name to its CPU implementation.
    const onesLikeConfig = {
        kernelName: OnesLike,
        backendName: 'cpu',
        kernelFunc: onesLike$2
    };
76576
76577 /**
76578 * @license
76579 * Copyright 2020 Google LLC. All Rights Reserved.
76580 * Licensed under the Apache License, Version 2.0 (the "License");
76581 * you may not use this file except in compliance with the License.
76582 * You may obtain a copy of the License at
76583 *
76584 * http://www.apache.org/licenses/LICENSE-2.0
76585 *
76586 * Unless required by applicable law or agreed to in writing, software
76587 * distributed under the License is distributed on an "AS IS" BASIS,
76588 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76589 * See the License for the specific language governing permissions and
76590 * limitations under the License.
76591 * =============================================================================
76592 */
76593 function pack(args) {
76594 const { inputs, backend, attrs } = args;
76595 const { axis } = attrs;
76596 if (inputs.length === 1) {
76597 return expandDims$2({ inputs: { input: inputs[0] }, backend, attrs: { dim: axis } });
76598 }
76599 const shape = inputs[0].shape;
76600 const dtype = inputs[0].dtype;
76601 inputs.forEach(t => {
76602 assertShapesMatch(shape, t.shape, 'All tensors passed to stack must have matching shapes');
76603 assert(dtype === t.dtype, () => 'All tensors passed to stack must have matching dtypes');
76604 });
76605 const intermediateTensorInfos = [];
76606 const expandedTensors = inputs.map(t => {
76607 const expandedT = expandDims$2({ inputs: { input: t }, backend, attrs: { dim: axis } });
76608 intermediateTensorInfos.push(expandedT);
76609 return expandedT;
76610 });
76611 const result = concat$1({ inputs: expandedTensors, backend, attrs: { axis } });
76612 intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
76613 return result;
76614 }
    // Kernel registration: maps the Pack kernel name to its CPU implementation.
    const packConfig = {
        kernelName: Pack,
        backendName: 'cpu',
        kernelFunc: pack
    };
76620
76621 /**
76622 * @license
76623 * Copyright 2020 Google LLC. All Rights Reserved.
76624 * Licensed under the Apache License, Version 2.0 (the "License");
76625 * you may not use this file except in compliance with the License.
76626 * You may obtain a copy of the License at
76627 *
76628 * http://www.apache.org/licenses/LICENSE-2.0
76629 *
76630 * Unless required by applicable law or agreed to in writing, software
76631 * distributed under the License is distributed on an "AS IS" BASIS,
76632 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76633 * See the License for the specific language governing permissions and
76634 * limitations under the License.
76635 * =============================================================================
76636 */
76637 function padV2(args) {
76638 const { inputs, backend, attrs } = args;
76639 const { x } = inputs;
76640 const { paddings, constantValue } = attrs;
76641 assertNotComplex(x, 'pad');
76642 const outShape = paddings.map((p, i) => p[0] /* beforePad */ + x.shape[i] + p[1] /* afterPad */);
76643 const start = paddings.map(p => p[0]);
76644 const xVals = backend.data.get(x.dataId).values;
76645 const xSize = sizeFromShape(x.shape);
76646 const xRank = x.shape.length;
76647 const xStrides = computeStrides(x.shape);
76648 const resultSize = sizeFromShape(outShape);
76649 const resultRank = outShape.length;
76650 const resultStrides = computeStrides(outShape);
76651 const resVals = getTypedArrayFromDType(x.dtype, resultSize);
76652 if (constantValue !== 0) {
76653 resVals.fill(constantValue);
76654 }
76655 for (let i = 0; i < xSize; i++) {
76656 const coords = indexToLoc(i, xRank, xStrides);
76657 const outCoords = coords.map((c, i) => c + start[i]);
76658 const outIndex = locToIndex(outCoords, resultRank, resultStrides);
76659 resVals[outIndex] = xVals[i];
76660 }
76661 const outId = backend.write(resVals, outShape, x.dtype);
76662 return { dataId: outId, shape: outShape, dtype: x.dtype };
76663 }
    // Kernel registration: maps the PadV2 kernel name to its CPU implementation.
    const padV2Config = {
        kernelName: PadV2,
        backendName: 'cpu',
        kernelFunc: padV2
    };
76669
76670 /**
76671 * @license
76672 * Copyright 2020 Google LLC. All Rights Reserved.
76673 * Licensed under the Apache License, Version 2.0 (the "License");
76674 * you may not use this file except in compliance with the License.
76675 * You may obtain a copy of the License at
76676 *
76677 * http://www.apache.org/licenses/LICENSE-2.0
76678 *
76679 * Unless required by applicable law or agreed to in writing, software
76680 * distributed under the License is distributed on an "AS IS" BASIS,
76681 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76682 * See the License for the specific language governing permissions and
76683 * limitations under the License.
76684 * =============================================================================
76685 */
76686 const powImpl = createSimpleBinaryKernelImpl((a, b) => Math.pow(a, b));
76687 const pow$2 = binaryKernelFunc(Pow, powImpl);
76688 const powConfig = {
76689 kernelName: Pow,
76690 backendName: 'cpu',
76691 kernelFunc: pow$2
76692 };
76693
76694 /**
76695 * @license
76696 * Copyright 2020 Google LLC. All Rights Reserved.
76697 * Licensed under the Apache License, Version 2.0 (the "License");
76698 * you may not use this file except in compliance with the License.
76699 * You may obtain a copy of the License at
76700 *
76701 * http://www.apache.org/licenses/LICENSE-2.0
76702 *
76703 * Unless required by applicable law or agreed to in writing, software
76704 * distributed under the License is distributed on an "AS IS" BASIS,
76705 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76706 * See the License for the specific language governing permissions and
76707 * limitations under the License.
76708 * =============================================================================
76709 */
76710 function range$2(args) {
76711 const { backend, attrs } = args;
76712 const { start, stop, dtype, step } = attrs;
76713 const values = rangeImpl(start, stop, step, dtype);
76714 return backend.makeTensorInfo([values.length], dtype, values);
76715 }
    // Kernel registration: maps the Range kernel name to its CPU implementation.
    const rangeConfig = {
        kernelName: Range,
        backendName: 'cpu',
        kernelFunc: range$2
    };
76721
76722 /**
76723 * @license
76724 * Copyright 2020 Google LLC. All Rights Reserved.
76725 * Licensed under the Apache License, Version 2.0 (the License);
76726 * you may not use this file except in compliance with the License.
76727 * You may obtain a copy of the License at
76728 *
76729 * http://www.apache.org/licenses/LICENSE-2.0
76730 *
76731 * Unless required by applicable law or agreed to in writing, software
76732 * distributed under the License is distributed on an AS IS BASIS,
76733 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76734 * See the License for the specific language governing permissions and
76735 * limitations under the License.
76736 * =============================================================================
76737 */
    // Element-wise reciprocal 1/x, via the shared unary kernel helper.
    const reciprocal$1 = unaryKernelFunc(Reciprocal, (xi) => 1 / xi);
    // Kernel registration: maps the Reciprocal kernel name to its CPU implementation.
    const reciprocalConfig = {
        kernelName: Reciprocal,
        backendName: 'cpu',
        kernelFunc: reciprocal$1,
    };
76744
76745 /**
76746 * @license
76747 * Copyright 2020 Google LLC. All Rights Reserved.
76748 * Licensed under the Apache License, Version 2.0 (the "License");
76749 * you may not use this file except in compliance with the License.
76750 * You may obtain a copy of the License at
76751 *
76752 * http://www.apache.org/licenses/LICENSE-2.0
76753 *
76754 * Unless required by applicable law or agreed to in writing, software
76755 * distributed under the License is distributed on an "AS IS" BASIS,
76756 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76757 * See the License for the specific language governing permissions and
76758 * limitations under the License.
76759 * =============================================================================
76760 */
    /**
     * CPU ResizeBilinear kernel: resizes NHWC `images` to `size`
     * ([newHeight, newWidth]) by bilinear interpolation — each output pixel
     * blends the four surrounding source pixels. `alignCorners` makes the
     * corner pixels of input and output coincide; `halfPixelCenters` samples
     * at pixel centers (the +0.5/-0.5 offsets below). Always returns float32.
     */
    function resizeBilinear$1(args) {
        const { inputs, backend, attrs } = args;
        const { images } = inputs;
        const { alignCorners, halfPixelCenters, size } = attrs;
        assertNotComplex(images, 'resizeBilinear');
        const imagesStrides = computeStrides(images.shape);
        const [newHeight, newWidth] = size;
        const [batch, oldHeight, oldWidth, numChannels] = images.shape;
        const xValues = backend.data.get(images.dataId).values;
        const result = new Float32Array(sizeFromShape([batch, newHeight, newWidth, numChannels]));
        // With alignCorners, effective sizes shrink by one so that the first
        // and last pixels of input and output map exactly onto each other.
        const effectiveInputSize = [
            (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
            (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
        ];
        const effectiveOutputSize = [
            (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
            (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
        ];
        let outputIdx = 0;
        const effectiveRowSizeRatio = effectiveInputSize[0] / effectiveOutputSize[0];
        const effectiveColSizeRatio = effectiveInputSize[1] / effectiveOutputSize[1];
        for (let b = 0; b < batch; b++) {
            for (let r = 0; r < newHeight; r++) {
                // Fractional source row this output row samples from.
                let sourceFracRow;
                if (halfPixelCenters) {
                    sourceFracRow = effectiveRowSizeRatio * (r + 0.5) - 0.5;
                }
                else {
                    sourceFracRow = effectiveRowSizeRatio * r;
                }
                const sourceRowFloor = Math.max(0, Math.floor(sourceFracRow));
                const rowFrac = sourceFracRow - sourceRowFloor;
                const sourceRowCeil = Math.min(oldHeight - 1, Math.ceil(sourceFracRow));
                const topRowOffset = b * imagesStrides[0] + sourceRowFloor * imagesStrides[1];
                const botRowOffset = b * imagesStrides[0] + sourceRowCeil * imagesStrides[1];
                for (let c = 0; c < newWidth; c++) {
                    // Fractional source column this output column samples from.
                    let sourceFracCol;
                    if (halfPixelCenters) {
                        sourceFracCol = effectiveColSizeRatio * (c + 0.5) - 0.5;
                    }
                    else {
                        sourceFracCol = effectiveColSizeRatio * c;
                    }
                    const sourceColFloor = Math.max(0, Math.floor(sourceFracCol));
                    const colFrac = sourceFracCol - sourceColFloor;
                    const sourceColCeil = Math.min(oldWidth - 1, Math.ceil(sourceFracCol));
                    // Offsets of the four neighboring source pixels.
                    // (Identifier spelling "Offest" kept from the original.)
                    const topLeftOffest = topRowOffset + sourceColFloor * imagesStrides[2];
                    const botLeftOffset = botRowOffset + sourceColFloor * imagesStrides[2];
                    const topRightOffset = topRowOffset + sourceColCeil * imagesStrides[2];
                    const botRightOffest = botRowOffset + sourceColCeil * imagesStrides[2];
                    for (let d = 0; d < numChannels; d++) {
                        // Bilinear blend: first across columns (colFrac),
                        // then across rows (rowFrac).
                        const topLeft = xValues[topLeftOffest + d];
                        const bottomLeft = xValues[botLeftOffset + d];
                        const topRight = xValues[topRightOffset + d];
                        const bottomRight = xValues[botRightOffest + d];
                        const top = topLeft + (topRight - topLeft) * colFrac;
                        const bottom = bottomLeft + (bottomRight - bottomLeft) * colFrac;
                        const newValue = top + (bottom - top) * rowFrac;
                        result[outputIdx++] = newValue;
                    }
                }
            }
        }
        return backend.makeTensorInfo([batch, newHeight, newWidth, numChannels], 'float32', result);
    }
    // Kernel registration: maps the ResizeBilinear kernel name to its CPU implementation.
    const resizeBilinearConfig = {
        kernelName: ResizeBilinear,
        backendName: 'cpu',
        kernelFunc: resizeBilinear$1
    };
76833
76834 /**
76835 * @license
76836 * Copyright 2020 Google LLC. All Rights Reserved.
76837 * Licensed under the Apache License, Version 2.0 (the "License");
76838 * you may not use this file except in compliance with the License.
76839 * You may obtain a copy of the License at
76840 *
76841 * http://www.apache.org/licenses/LICENSE-2.0
76842 *
76843 * Unless required by applicable law or agreed to in writing, software
76844 * distributed under the License is distributed on an "AS IS" BASIS,
76845 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76846 * See the License for the specific language governing permissions and
76847 * limitations under the License.
76848 * =============================================================================
76849 */
76850 function resizeBilinearGrad(args) {
76851 const { inputs, backend, attrs } = args;
76852 const { images, dy } = inputs;
76853 const { alignCorners } = attrs;
76854 assertNotComplex([dy, images], 'resizeBilinearGrad');
76855 const imagesStrides = computeStrides(images.shape);
76856 const [batch, xHeight, xWidth, depth] = images.shape;
76857 const [, yHeight, yWidth] = dy.shape;
76858 const output = new Float32Array(batch * xHeight * xWidth * depth);
76859 // In the backwards pass, we want to find the pixels that were generated
76860 // for each pixel in the input image the forward pass and add the
76861 // corresponding coefficient from dy to the gradient (with some
76862 // interpolation).
76863 const effectiveXSize = [
76864 (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
76865 (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
76866 ];
76867 const effectiveYSize = [
76868 (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
76869 (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
76870 ];
76871 const heightScale = effectiveXSize[0] / effectiveYSize[0];
76872 const widthScale = effectiveXSize[1] / effectiveYSize[1];
76873 // Reference implementation
76874 // tslint:disable-next-line:max-line-length
76875 // https://github.com/tensorflow/tensorflow/blob/3039375c86a5bbc9610c7725dcaa95d635f87ba2/tensorflow/core/kernels/resize_bilinear_op.cc#L275
76876 const dyValues = backend.data.get(dy.dataId).values;
76877 let offset = 0;
76878 for (let b = 0; b < batch; b++) {
76879 const bOffset = b * imagesStrides[0];
76880 for (let r = 0; r < yHeight; r++) {
76881 const dxR = r * heightScale;
76882 const topDxRIndex = Math.floor(dxR);
76883 const bottomDxRIndex = Math.min(Math.ceil(dxR), xHeight - 1);
76884 const topDxROffset = bOffset + topDxRIndex * imagesStrides[1];
76885 const bottomDxROffset = bOffset + bottomDxRIndex * imagesStrides[1];
76886 const dxRLerp = dxR - topDxRIndex;
76887 const inverseDxRLerp = 1.0 - dxRLerp;
76888 for (let c = 0; c < yWidth; c++) {
76889 const dxC = c * widthScale;
76890 const leftDxCIndex = Math.floor(dxC);
76891 const rightDxCIndex = Math.min(Math.ceil(dxC), xWidth - 1);
76892 const dxCLerp = dxC - leftDxCIndex;
76893 const inverseDxCLerp = 1.0 - dxCLerp;
76894 const topLeftRCOffset = topDxROffset + leftDxCIndex * imagesStrides[2];
76895 const topRightRCOffset = topDxROffset + rightDxCIndex * imagesStrides[2];
76896 const bottomLeftRCOffset = bottomDxROffset + leftDxCIndex * imagesStrides[2];
76897 const bottomRightRCOffset = bottomDxROffset + rightDxCIndex * imagesStrides[2];
76898 const inverseDxRLerpTimesInverseDxCLerp = inverseDxRLerp * inverseDxCLerp;
76899 const inverseDxRLerpTimesDxCLerp = inverseDxRLerp * dxCLerp;
76900 const dxRLerpTimesInverseDxCLerp = dxRLerp * inverseDxCLerp;
76901 const dxRLerpTimesDxCLerp = dxRLerp * dxCLerp;
76902 for (let d = 0; d < depth; d++) {
76903 const dyVal = dyValues[offset++];
76904 output[topLeftRCOffset + d] +=
76905 dyVal * inverseDxRLerpTimesInverseDxCLerp;
76906 output[topRightRCOffset + d] += dyVal * inverseDxRLerpTimesDxCLerp;
76907 output[bottomLeftRCOffset + d] += dyVal * dxRLerpTimesInverseDxCLerp;
76908 output[bottomRightRCOffset + d] += dyVal * dxRLerpTimesDxCLerp;
76909 }
76910 }
76911 }
76912 }
76913 return backend.makeTensorInfo([batch, xWidth, xHeight, depth], 'float32', output);
76914 }
    // Kernel registration: maps the ResizeBilinearGrad kernel name to its CPU implementation.
    const resizeBilinearGradConfig$1 = {
        kernelName: ResizeBilinearGrad,
        backendName: 'cpu',
        kernelFunc: resizeBilinearGrad
    };
76920
76921 /**
76922 * @license
76923 * Copyright 2020 Google LLC. All Rights Reserved.
76924 * Licensed under the Apache License, Version 2.0 (the "License");
76925 * you may not use this file except in compliance with the License.
76926 * You may obtain a copy of the License at
76927 *
76928 * http://www.apache.org/licenses/LICENSE-2.0
76929 *
76930 * Unless required by applicable law or agreed to in writing, software
76931 * distributed under the License is distributed on an "AS IS" BASIS,
76932 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76933 * See the License for the specific language governing permissions and
76934 * limitations under the License.
76935 * =============================================================================
76936 */
76937 function resizeNearestNeighbor$1(args) {
76938 const { inputs, backend, attrs } = args;
76939 const { images } = inputs;
76940 const { alignCorners, halfPixelCenters, size } = attrs;
76941 assertNotComplex(images, 'resizeNearestNeighbor');
76942 const imagesStrides = computeStrides(images.shape);
76943 const [newHeight, newWidth] = size;
76944 const [batch, oldHeight, oldWidth, numChannels] = images.shape;
76945 const xValues = backend.data.get(images.dataId).values;
76946 const output = new Float32Array(batch * newHeight * newWidth * numChannels);
76947 const effectiveInputSize = [
76948 (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
76949 (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
76950 ];
76951 const effectiveOutputSize = [
76952 (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
76953 (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
76954 ];
76955 const effectiveRowSizeRatio = effectiveInputSize[0] / effectiveOutputSize[0];
76956 const effectiveColSizeRatio = effectiveInputSize[1] / effectiveOutputSize[1];
76957 let outputOffset = 0;
76958 for (let b = 0; b < batch; b++) {
76959 const batchOffset = b * imagesStrides[0];
76960 for (let r = 0; r < newHeight; r++) {
76961 const sourceFracRow = halfPixelCenters ?
76962 effectiveRowSizeRatio * (r + 0.5) :
76963 effectiveRowSizeRatio * r;
76964 let sourceNearestRow = Math.min(oldHeight - 1, alignCorners ? Math.round(sourceFracRow) : Math.floor(sourceFracRow));
76965 if (halfPixelCenters) {
76966 sourceNearestRow = Math.max(0, sourceNearestRow);
76967 }
76968 const rowOffset = batchOffset + sourceNearestRow * imagesStrides[1];
76969 for (let c = 0; c < newWidth; c++) {
76970 const sourceFracCol = halfPixelCenters ?
76971 effectiveColSizeRatio * (c + 0.5) :
76972 effectiveColSizeRatio * c;
76973 let sourceNearestCol = Math.min(oldWidth - 1, alignCorners ? Math.round(sourceFracCol) :
76974 Math.floor(sourceFracCol));
76975 if (halfPixelCenters) {
76976 sourceNearestCol = Math.max(0, sourceNearestCol);
76977 }
76978 const colOffset = rowOffset + sourceNearestCol * imagesStrides[2];
76979 for (let d = 0; d < numChannels; d++) {
76980 // Begin shader.
76981 // Compute the fractional index of the source.
76982 const newVal = xValues[colOffset + d];
76983 output[outputOffset++] = newVal;
76984 }
76985 }
76986 }
76987 }
76988 return backend.makeTensorInfo([batch, newHeight, newWidth, numChannels], images.dtype, output);
76989 }
    // Kernel registration: maps the ResizeNearestNeighbor kernel name to its CPU implementation.
    const resizeNearestNeighborConfig = {
        kernelName: ResizeNearestNeighbor,
        backendName: 'cpu',
        kernelFunc: resizeNearestNeighbor$1
    };
76995
76996 /**
76997 * @license
76998 * Copyright 2020 Google LLC. All Rights Reserved.
76999 * Licensed under the Apache License, Version 2.0 (the "License");
77000 * you may not use this file except in compliance with the License.
77001 * You may obtain a copy of the License at
77002 *
77003 * http://www.apache.org/licenses/LICENSE-2.0
77004 *
77005 * Unless required by applicable law or agreed to in writing, software
77006 * distributed under the License is distributed on an "AS IS" BASIS,
77007 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77008 * See the License for the specific language governing permissions and
77009 * limitations under the License.
77010 * =============================================================================
77011 */
77012 function resizeNearestNeighborGrad(args) {
77013 const { inputs, backend, attrs } = args;
77014 const { images, dy } = inputs;
77015 const { alignCorners } = attrs;
77016 assertNotComplex([dy, images], 'resizeNearestNeighborGrad');
77017 const imagesStrides = computeStrides(images.shape);
77018 const dyStrides = computeStrides(dy.shape);
77019 const [batch, xHeight, xWidth, depth] = images.shape;
77020 const [, yHeight, yWidth] = dy.shape;
77021 const output = new Float32Array(batch * xHeight * xWidth * depth);
77022 const dyValues = backend.data.get(dy.dataId).values;
77023 // In the backwards pass, we want to find the pixels that were generated
77024 // for each pixel in the input image the forward pass
77025 const effectiveXSize = [
77026 (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
77027 (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
77028 ];
77029 const effectiveYSize = [
77030 (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
77031 (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
77032 ];
77033 const heightScale = effectiveXSize[0] / effectiveYSize[0];
77034 const widthScale = effectiveXSize[1] / effectiveYSize[1];
77035 const invHeightScale = 1 / heightScale;
77036 const invWidthScale = 1 / widthScale;
77037 // This defines the size of the window of values around a particular
77038 // index in dy that we want to search for contributions to dx.
77039 const winHeight = (Math.ceil(invHeightScale) * 2) + 2;
77040 const winWidth = (Math.ceil(invWidthScale) * 2) + 2;
77041 // Loop over the output space.
77042 for (let b = 0; b < batch; b++) {
77043 const batchOffset = b * imagesStrides[0];
77044 for (let r = 0; r < xHeight; r++) {
77045 const rowOffset = batchOffset + r * imagesStrides[1];
77046 // Compute bounds for where in dy we will look
77047 const startRLerp = Math.floor(r * invHeightScale);
77048 const startDyR = Math.floor(startRLerp - (winHeight / 2));
77049 for (let c = 0; c < xWidth; c++) {
77050 const colOffset = rowOffset + c * imagesStrides[2];
77051 // Compute bounds for where in dy we will look
77052 const startCLerp = Math.floor(c * invWidthScale);
77053 const startDyC = Math.floor(startCLerp - (winWidth / 2));
77054 for (let d = 0; d < depth; d++) {
77055 let accum = 0;
77056 // loop over dy
77057 for (let dyRIndex = 0; dyRIndex < winHeight; dyRIndex++) {
77058 const dyR = dyRIndex + startDyR;
77059 // Guard against the window exceeding the bounds of dy
77060 if (dyR < 0 || dyR >= yHeight) {
77061 continue;
77062 }
77063 const dyROffset = batchOffset + dyR * dyStrides[1];
77064 const sourceFracRow = dyR * heightScale;
77065 const sourceNearestRow = Math.min(xHeight - 1, alignCorners ? Math.round(sourceFracRow) :
77066 Math.floor(sourceFracRow));
77067 if (r !== sourceNearestRow) {
77068 continue;
77069 }
77070 for (let dyCIndex = 0; dyCIndex < winWidth; dyCIndex++) {
77071 const dyC = dyCIndex + startDyC;
77072 // Guard against the window exceeding the bounds of dy
77073 if (dyC < 0 || dyC >= yWidth) {
77074 continue;
77075 }
77076 const dyCOffset = dyROffset + dyC * dyStrides[2];
77077 const sourceFracCol = dyC * widthScale;
77078 const sourceNearestCol = Math.min(xWidth - 1, alignCorners ? Math.round(sourceFracCol) :
77079 Math.floor(sourceFracCol));
77080 if (c === sourceNearestCol) {
77081 accum += dyValues[dyCOffset + d];
77082 }
77083 }
77084 }
77085 output[colOffset + d] = accum;
77086 }
77087 }
77088 }
77089 }
77090 return backend.makeTensorInfo(images.shape, images.dtype, output);
77091 }
    // Kernel registration: maps the ResizeNearestNeighborGrad kernel name to its CPU implementation.
    const resizeNearestNeighborGradConfig$1 = {
        kernelName: ResizeNearestNeighborGrad,
        backendName: 'cpu',
        kernelFunc: resizeNearestNeighborGrad
    };
77097
77098 /**
77099 * @license
77100 * Copyright 2020 Google LLC. All Rights Reserved.
77101 * Licensed under the Apache License, Version 2.0 (the "License");
77102 * you may not use this file except in compliance with the License.
77103 * You may obtain a copy of the License at
77104 *
77105 * http://www.apache.org/licenses/LICENSE-2.0
77106 *
77107 * Unless required by applicable law or agreed to in writing, software
77108 * distributed under the License is distributed on an "AS IS" BASIS,
77109 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77110 * See the License for the specific language governing permissions and
77111 * limitations under the License.
77112 * =============================================================================
77113 */
77114 function reverse$1(args) {
77115 const { inputs, backend, attrs } = args;
77116 const { x } = inputs;
77117 const { dims } = attrs;
77118 assertNotComplex(x, 'reverse');
77119 const xRank = x.shape.length;
77120 const $dims = parseAxisParam(dims, x.shape);
77121 if (xRank === 0) {
77122 return identity$1({ inputs: { x }, backend });
77123 }
77124 const outBuf = new TensorBuffer(x.shape, x.dtype);
77125 const xBuf = backend.bufferSync(x);
77126 for (let i = 0; i < outBuf.size; i++) {
77127 const outLoc = outBuf.indexToLoc(i);
77128 const inLoc = outLoc.slice();
77129 $dims.forEach(d => inLoc[d] = x.shape[d] - 1 - inLoc[d]);
77130 outBuf.set(xBuf.get(...inLoc), ...outLoc);
77131 }
77132 return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
77133 }
    // Kernel registration: maps the Reverse kernel name to its CPU implementation.
    const reverseConfig = {
        kernelName: Reverse,
        backendName: 'cpu',
        kernelFunc: reverse$1
    };
77139
77140 /**
77141 * @license
77142 * Copyright 2020 Google LLC. All Rights Reserved.
77143 * Licensed under the Apache License, Version 2.0 (the "License");
77144 * you may not use this file except in compliance with the License.
77145 * You may obtain a copy of the License at
77146 *
77147 * http://www.apache.org/licenses/LICENSE-2.0
77148 *
77149 * Unless required by applicable law or agreed to in writing, software
77150 * distributed under the License is distributed on an "AS IS" BASIS,
77151 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77152 * See the License for the specific language governing permissions and
77153 * limitations under the License.
77154 * =============================================================================
77155 */
77156 const rotateWithOffsetConfig = {
77157 kernelName: RotateWithOffset,
77158 backendName: 'cpu',
77159 kernelFunc: ({ inputs, attrs, backend }) => {
77160 const { image } = inputs;
77161 const { radians, fillValue, center } = attrs;
77162 const cpuBackend = backend;
77163 const output = getTypedArrayFromDType(image.dtype, sizeFromShape(image.shape));
77164 const [batch, imageHeight, imageWidth, numChannels] = image.shape;
77165 const [centerX, centerY] = getImageCenter(center, imageHeight, imageWidth);
77166 const fullOpacityValue = 255;
77167 const sinFactor = Math.sin(radians);
77168 const cosFactor = Math.cos(radians);
77169 const imageVals = cpuBackend.data.get(image.dataId).values;
77170 for (let batchIdx = 0; batchIdx < batch; batchIdx++) {
77171 const batchOffset = batchIdx * imageWidth * imageHeight * numChannels;
77172 for (let row = 0; row < imageHeight; row++) {
77173 const rowOffset = row * (imageWidth * numChannels);
77174 for (let col = 0; col < imageWidth; col++) {
77175 const colOffset = col * numChannels;
77176 for (let channel = 0; channel < numChannels; channel++) {
77177 const coords = [batch, row, col, channel];
77178 const x = coords[2];
77179 const y = coords[1];
77180 // coordX/coordY are the result of rotating and translating x/y.
77181 let coordX = (x - centerX) * cosFactor - (y - centerY) * sinFactor;
77182 let coordY = (x - centerX) * sinFactor + (y - centerY) * cosFactor;
77183 coordX = Math.round(coordX + centerX);
77184 coordY = Math.round(coordY + centerY);
77185 let outputValue = fillValue;
77186 if (typeof fillValue !== 'number') {
77187 if (channel === 3) {
77188 outputValue = fullOpacityValue;
77189 }
77190 else {
77191 outputValue = fillValue[channel];
77192 }
77193 }
77194 // If the coordinate position falls within the image boundaries...
77195 if (coordX >= 0 && coordX < imageWidth && coordY >= 0 &&
77196 coordY < imageHeight) {
77197 // set the output to the image value at the coordinate position.
77198 const rotatedRowOffset = coordY * (imageWidth * numChannels);
77199 const rotatedColOffset = coordX * numChannels;
77200 const imageIdx = batchOffset + rotatedRowOffset + rotatedColOffset + channel;
77201 outputValue = imageVals[imageIdx];
77202 }
77203 const outIdx = batchOffset + rowOffset + colOffset + channel;
77204 output[outIdx] = outputValue;
77205 }
77206 }
77207 }
77208 }
77209 const dataId = cpuBackend.write(output, image.shape, image.dtype);
77210 return { dataId, shape: image.shape, dtype: image.dtype };
77211 }
77212 };
77213
77214 /**
77215 * @license
77216 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
     * you may not use this file except in compliance with the License.
     * You may obtain a copy of the License at
     *
     * http://www.apache.org/licenses/LICENSE-2.0
     *
     * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
77225 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77226 * See the License for the specific language governing permissions and
77227 * limitations under the License.
77228 * =============================================================================
77229 */
77230 const round$2 = unaryKernelFunc(Round, (xi) => {
77231 // The algorithm is based on banker's rounding.
77232 const base = Math.floor(xi);
77233 if (xi - base < 0.5) {
77234 return Math.floor(xi);
77235 }
77236 else if (xi - base > 0.5) {
77237 return Math.ceil(xi);
77238 }
77239 else {
77240 if (base % 2.0 === 0.0) {
77241 return base;
77242 }
77243 else {
77244 return base + 1.0;
77245 }
77246 }
77247 });
77248 const roundConfig = {
77249 kernelName: Round,
77250 backendName: 'cpu',
77251 kernelFunc: round$2,
77252 };
77253
77254 /**
77255 * @license
77256 * Copyright 2020 Google LLC. All Rights Reserved.
77257 * Licensed under the Apache License, Version 2.0 (the "License");
77258 * you may not use this file except in compliance with the License.
77259 * You may obtain a copy of the License at
77260 *
77261 * http://www.apache.org/licenses/LICENSE-2.0
77262 *
77263 * Unless required by applicable law or agreed to in writing, software
77264 * distributed under the License is distributed on an "AS IS" BASIS,
77265 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77266 * See the License for the specific language governing permissions and
77267 * limitations under the License.
77268 * =============================================================================
77269 */
77270 function scatterNd(args) {
77271 const { inputs, backend, attrs } = args;
77272 const { indices, updates } = inputs;
77273 const { shape } = attrs;
77274 const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(updates, indices, shape);
77275 const sumDupeIndices = true;
77276 const indicesBuf = backend.bufferSync(indices);
77277 const updatesBuf = backend.bufferSync(updates);
77278 const outBuf = scatterImpl(indicesBuf, updatesBuf, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, 0 /* defaultValue */, sumDupeIndices);
77279 return backend.makeTensorInfo(shape, outBuf.dtype, outBuf.values);
77280 }
77281 const scatterNdConfig = {
77282 kernelName: ScatterNd,
77283 backendName: 'cpu',
77284 kernelFunc: scatterNd
77285 };
77286
77287 /**
77288 * @license
77289 * Copyright 2022 Google LLC. All Rights Reserved.
77290 * Licensed under the Apache License, Version 2.0 (the "License");
77291 * you may not use this file except in compliance with the License.
77292 * You may obtain a copy of the License at
77293 *
77294 * http://www.apache.org/licenses/LICENSE-2.0
77295 *
77296 * Unless required by applicable law or agreed to in writing, software
77297 * distributed under the License is distributed on an "AS IS" BASIS,
77298 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77299 * See the License for the specific language governing permissions and
77300 * limitations under the License.
77301 * =============================================================================
77302 */
77303 function lowerBound$1(array, value) {
77304 let left = 0;
77305 let right = array.length;
77306 let mid = 0;
77307 while (left < right) {
77308 mid = Math.floor((left + right) / 2);
77309 if (array[mid] < value) {
77310 left = mid + 1;
77311 }
77312 else {
77313 right = mid;
77314 }
77315 }
77316 return right;
77317 }
77318 function upperBound$1(array, value) {
77319 let left = 0;
77320 let right = array.length;
77321 let mid = 0;
77322 while (left < right) {
77323 mid = Math.floor((left + right) / 2);
77324 if (array[mid] <= value) {
77325 left = mid + 1;
77326 }
77327 else {
77328 right = mid;
77329 }
77330 }
77331 return right;
77332 }
77333 function searchSortedImpl(sortedInputs, values, batchSize, numInputs, numValues, side) {
77334 const output = getArrayFromDType('int32', batchSize * numValues);
77335 for (let b = 0; b < batchSize; ++b) {
77336 const sortedInputsSlice = sortedInputs.slice(b * numInputs, (b + 1) * numInputs);
77337 const outputOffset = b * numValues;
77338 for (let i = 0; i < numValues; ++i) {
77339 output[outputOffset + i] = side === 'left' ?
77340 lowerBound$1(sortedInputsSlice, values[i + outputOffset]) :
77341 upperBound$1(sortedInputsSlice, values[i + outputOffset]);
77342 }
77343 }
77344 return output;
77345 }
77346
77347 /**
77348 * @license
77349 * Copyright 2022 Google LLC. All Rights Reserved.
77350 * Licensed under the Apache License, Version 2.0 (the "License");
77351 * you may not use this file except in compliance with the License.
77352 * You may obtain a copy of the License at
77353 *
77354 * http://www.apache.org/licenses/LICENSE-2.0
77355 *
77356 * Unless required by applicable law or agreed to in writing, software
77357 * distributed under the License is distributed on an "AS IS" BASIS,
77358 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77359 * See the License for the specific language governing permissions and
77360 * limitations under the License.
77361 * =============================================================================
77362 */
77363 function searchSorted$1(args) {
77364 const { inputs, backend, attrs } = args;
77365 const { sortedSequence, values } = inputs;
77366 const { side } = attrs;
77367 const $sortedSequence = backend.data.get(sortedSequence.dataId).values;
77368 const $values = backend.data.get(values.dataId).values;
77369 const output = searchSortedImpl($sortedSequence, $values, sortedSequence.shape[0], sortedSequence.shape[1], values.shape[1], side);
77370 return backend.makeTensorInfo(values.shape, 'int32', output);
77371 }
77372 const searchSortedConfig = {
77373 kernelName: SearchSorted,
77374 backendName: 'cpu',
77375 kernelFunc: searchSorted$1,
77376 };
77377
77378 /**
77379 * @license
77380 * Copyright 2020 Google LLC. All Rights Reserved.
77381 * Licensed under the Apache License, Version 2.0 (the "License");
77382 * you may not use this file except in compliance with the License.
77383 * You may obtain a copy of the License at
77384 *
77385 * http://www.apache.org/licenses/LICENSE-2.0
77386 *
77387 * Unless required by applicable law or agreed to in writing, software
77388 * distributed under the License is distributed on an "AS IS" BASIS,
77389 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77390 * See the License for the specific language governing permissions and
77391 * limitations under the License.
77392 * =============================================================================
77393 */
77394 function select$1(args) {
77395 const { inputs, backend } = args;
77396 const { condition, t, e } = inputs;
77397 assertNotComplex([condition, t, e], 'select');
77398 const conditionRank = condition.shape.length;
77399 const values = backend.data.get(condition.dataId).values;
77400 const tValues = backend.data.get(t.dataId).values;
77401 const eValues = backend.data.get(e.dataId).values;
77402 const resultDtype = upcastType(t.dtype, e.dtype);
77403 const newValues = makeZerosTypedArray(sizeFromShape(t.shape), resultDtype);
77404 let index = 0;
77405 const offset = conditionRank === 0 || conditionRank > 1 || t.shape.length === 1 ?
77406 1 :
77407 sizeFromShape(t.shape.slice(1));
77408 for (let i = 0; i < values.length; i++) {
77409 for (let j = 0; j < offset; j++) {
77410 if (values[i] === 1) {
77411 newValues[index++] = tValues[i];
77412 }
77413 else {
77414 newValues[index++] = eValues[i];
77415 }
77416 }
77417 }
77418 return backend.makeTensorInfo(t.shape, resultDtype, newValues);
77419 }
77420 const selectConfig = {
77421 kernelName: Select,
77422 backendName: 'cpu',
77423 kernelFunc: select$1
77424 };
77425
77426 /**
77427 * @license
77428 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
     * you may not use this file except in compliance with the License.
     * You may obtain a copy of the License at
     *
     * http://www.apache.org/licenses/LICENSE-2.0
     *
     * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
77437 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77438 * See the License for the specific language governing permissions and
77439 * limitations under the License.
77440 * =============================================================================
77441 */
77442 const scaleAlpha = SELU_SCALEALPHA;
77443 const scale = SELU_SCALE;
77444 const selu$1 = unaryKernelFunc(Selu, (xi) => {
77445 if (xi >= 0) {
77446 return scale * xi;
77447 }
77448 else {
77449 return scaleAlpha * (Math.exp(xi) - 1);
77450 }
77451 });
77452 const seluConfig = {
77453 kernelName: Selu,
77454 backendName: 'cpu',
77455 kernelFunc: selu$1,
77456 };
77457
77458 /**
77459 * @license
77460 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
     * you may not use this file except in compliance with the License.
     * You may obtain a copy of the License at
     *
     * http://www.apache.org/licenses/LICENSE-2.0
     *
     * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
77469 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77470 * See the License for the specific language governing permissions and
77471 * limitations under the License.
77472 * =============================================================================
77473 */
77474 const sign$2 = unaryKernelFunc(Sign, (xi) => {
77475 if (xi < 0) {
77476 return -1;
77477 }
77478 else if (xi > 0) {
77479 return 1;
77480 }
77481 else {
77482 return 0;
77483 }
77484 });
77485 const signConfig = {
77486 kernelName: Sign,
77487 backendName: 'cpu',
77488 kernelFunc: sign$2,
77489 };
77490
77491 /**
77492 * @license
77493 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
     * you may not use this file except in compliance with the License.
     * You may obtain a copy of the License at
     *
     * http://www.apache.org/licenses/LICENSE-2.0
     *
     * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
77502 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77503 * See the License for the specific language governing permissions and
77504 * limitations under the License.
77505 * =============================================================================
77506 */
77507 const sin$1 = unaryKernelFunc(Sin, (xi) => Math.sin(xi));
77508 const sinConfig = {
77509 kernelName: Sin,
77510 backendName: 'cpu',
77511 kernelFunc: sin$1,
77512 };
77513
77514 /**
77515 * @license
77516 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
     * you may not use this file except in compliance with the License.
     * You may obtain a copy of the License at
     *
     * http://www.apache.org/licenses/LICENSE-2.0
     *
     * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
77525 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77526 * See the License for the specific language governing permissions and
77527 * limitations under the License.
77528 * =============================================================================
77529 */
77530 const sinh$1 = unaryKernelFunc(Sinh, (xi) => Math.sinh(xi));
77531 const sinhConfig = {
77532 kernelName: Sinh,
77533 backendName: 'cpu',
77534 kernelFunc: sinh$1,
77535 };
77536
77537 /**
77538 * @license
77539 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
     * you may not use this file except in compliance with the License.
     * You may obtain a copy of the License at
     *
     * http://www.apache.org/licenses/LICENSE-2.0
     *
     * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
77548 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77549 * See the License for the specific language governing permissions and
77550 * limitations under the License.
77551 * =============================================================================
77552 */
77553 // mirrors the implementation of tf.nn.softplus: https://goo.gl/vkcvwX
77554 // epsilon is the difference between 1.0 and the next representable float.
77555 // For a single precision 32 bit float this should be 2^-23, see:
77556 // https://math.byu.edu/~schow/work/IEEEFloatingPoint.htm
77557 const epsilon$1 = 1.1920928955078125e-7;
77558 const threshold$1 = Math.log(epsilon$1) + 2.0;
77559 const softplus$1 = unaryKernelFunc(Softplus, (xi) => {
77560 // Value above which exp(x) may overflow, but softplus(x) == x
77561 // is within machine epsilon.
77562 const tooLarge = xi > -threshold$1;
77563 // Value below which exp(x) may underflow, but softplus(x) == exp(x)
77564 // is within machine epsilon.
77565 const tooSmall = xi < threshold$1;
77566 const expX = Math.exp(xi);
77567 let result;
77568 if (tooSmall) {
77569 result = expX;
77570 }
77571 else if (tooLarge) {
77572 result = xi;
77573 }
77574 else {
77575 result = Math.log(1.0 + expX);
77576 }
77577 return result;
77578 });
77579 const softplusConfig = {
77580 kernelName: Softplus,
77581 backendName: 'cpu',
77582 kernelFunc: softplus$1,
77583 };
77584
77585 /**
77586 * @license
77587 * Copyright 2020 Google LLC. All Rights Reserved.
77588 * Licensed under the Apache License, Version 2.0 (the "License");
77589 * you may not use this file except in compliance with the License.
77590 * You may obtain a copy of the License at
77591 *
77592 * http://www.apache.org/licenses/LICENSE-2.0
77593 *
77594 * Unless required by applicable law or agreed to in writing, software
77595 * distributed under the License is distributed on an "AS IS" BASIS,
77596 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77597 * See the License for the specific language governing permissions and
77598 * limitations under the License.
77599 * =============================================================================
77600 */
77601 function spaceToBatchND$1(args) {
77602 const { inputs, backend, attrs } = args;
77603 const { x } = inputs;
77604 const { blockShape, paddings } = attrs;
77605 assertNotComplex([x], 'spaceToBatchND');
77606 const prod = sizeFromShape(blockShape);
77607 const completePaddings = [[0, 0]];
77608 completePaddings.push(...paddings);
77609 for (let i = 1 + blockShape.length; i < x.shape.length; ++i) {
77610 completePaddings.push([0, 0]);
77611 }
77612 const paddedX = padV2Config.kernelFunc({
77613 inputs: { x },
77614 backend,
77615 attrs: { paddings: completePaddings, constantValue: 0 }
77616 });
77617 const reshapedPaddedShape = getReshaped(paddedX.shape, blockShape, prod, false);
77618 const permutedReshapedPaddedPermutation = getPermuted(reshapedPaddedShape.length, blockShape.length, false);
77619 const flattenShape = getReshapedPermuted(paddedX.shape, blockShape, prod, false);
77620 const reshapeInputs = { x: paddedX };
77621 const reshapeAttrs = { shape: reshapedPaddedShape };
77622 const paddedXReshaped = reshape$2({ inputs: reshapeInputs, backend, attrs: reshapeAttrs });
77623 const transposeInputs = { x: paddedXReshaped };
77624 const transposeAttrs = { perm: permutedReshapedPaddedPermutation };
77625 const paddedXT = transpose$1({ inputs: transposeInputs, backend, attrs: transposeAttrs });
77626 const resultReshapeInputs = { x: paddedXT };
77627 const resultReshapeAttrs = { shape: flattenShape };
77628 const result = reshape$2({ inputs: resultReshapeInputs, backend, attrs: resultReshapeAttrs });
77629 backend.disposeIntermediateTensorInfo(paddedX);
77630 backend.disposeIntermediateTensorInfo(paddedXReshaped);
77631 backend.disposeIntermediateTensorInfo(paddedXT);
77632 return result;
77633 }
77634 const spaceToBatchNDConfig = {
77635 kernelName: SpaceToBatchND,
77636 backendName: 'cpu',
77637 kernelFunc: spaceToBatchND$1
77638 };
77639
77640 /**
77641 * @license
77642 * Copyright 2021 Google LLC. All Rights Reserved.
77643 * Licensed under the Apache License, Version 2.0 (the "License");
77644 * you may not use this file except in compliance with the License.
77645 * You may obtain a copy of the License at
77646 *
77647 * http://www.apache.org/licenses/LICENSE-2.0
77648 *
77649 * Unless required by applicable law or agreed to in writing, software
77650 * distributed under the License is distributed on an "AS IS" BASIS,
77651 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77652 * See the License for the specific language governing permissions and
77653 * limitations under the License.
77654 * =============================================================================
77655 */
    /**
     * SparseFillEmptyRows CPU kernel.
     *
     * Fills empty rows of a COO sparse tensor (indices/values/denseShape) with
     * `defaultValue`, returning four tensor infos: [outputIndices,
     * outputValues, emptyRowIndicator (bool), reverseIndexMap].
     */
    function sparseFillEmptyRows$1(args) {
        const { inputs, backend } = args;
        const { indices, values, denseShape, defaultValue } = inputs;
        // Shape validation mirroring TF's SparseFillEmptyRows requirements.
        if (denseShape.shape.length !== 1) {
            throw new Error(`Dense shape must be a vector, saw:
        ${denseShape.shape}`);
        }
        if (indices.shape.length !== 2) {
            throw new Error(`Indices must be a matrix, saw:
        ${indices.shape}`);
        }
        if (values.shape.length !== 1) {
            throw new Error(`Values must be a vector, saw:
        ${values.shape}`);
        }
        if (defaultValue.shape.length !== 0) {
            throw new Error(`Default value must be a scalar, saw:
        ${defaultValue.shape}`);
        }
        // Raw backing arrays of each input tensor.
        const $indices = backend.data.get(indices.dataId).values;
        const $values = backend.data.get(values.dataId).values;
        const $denseShape = backend.data.get(denseShape.dataId).values;
        // defaultValue is a scalar, so take its single element.
        const $defaultValue = backend.data.get(defaultValue.dataId).values[0];
        const [outputIndices, outputIndicesShape, outputValues, emptyRowIndicator, reverseIndexMap] = sparseFillEmptyRowsImpl($indices, indices.shape, indices.dtype, $values, values.dtype, $denseShape, $defaultValue);
        return [
            backend.makeTensorInfo(outputIndicesShape, indices.dtype, outputIndices),
            backend.makeTensorInfo([outputIndicesShape[0]], values.dtype, outputValues),
            // 'bool' tensors are backed by uint8 storage, hence the 0/1 map.
            backend.makeTensorInfo([emptyRowIndicator.length], 'bool', new Uint8Array(emptyRowIndicator.map((value) => Number(value)))),
            backend.makeTensorInfo([reverseIndexMap.length], indices.dtype, new Int32Array(reverseIndexMap)),
        ];
    }
    const sparseFillEmptyRowsConfig = {
        kernelName: SparseFillEmptyRows,
        backendName: 'cpu',
        kernelFunc: sparseFillEmptyRows$1,
    };
77692
77693 /**
77694 * @license
77695 * Copyright 2021 Google LLC. All Rights Reserved.
77696 * Licensed under the Apache License, Version 2.0 (the "License");
77697 * you may not use this file except in compliance with the License.
77698 * You may obtain a copy of the License at
77699 *
77700 * http://www.apache.org/licenses/LICENSE-2.0
77701 *
77702 * Unless required by applicable law or agreed to in writing, software
77703 * distributed under the License is distributed on an "AS IS" BASIS,
77704 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77705 * See the License for the specific language governing permissions and
77706 * limitations under the License.
77707 * =============================================================================
77708 */
    /**
     * SparseReshape CPU kernel.
     *
     * Rewrites the COO `inputIndices` of a sparse tensor with dense shape
     * `inputShape` so they address the same elements under `newShape`.
     * Returns [newIndices, resolvedOutputShape] as tensor infos.
     */
    function sparseReshape$1(args) {
        const { inputs, backend } = args;
        const { inputIndices, inputShape, newShape } = inputs;
        // Shape validation mirroring TF's SparseReshape requirements.
        if (inputIndices.shape.length !== 2) {
            throw new Error(`Input indices should be a matrix but received shape
        ${inputIndices.shape}`);
        }
        if (inputShape.shape.length !== 1) {
            throw new Error(`Input shape should be a vector but received shape
        ${inputShape.shape}`);
        }
        if (newShape.shape.length !== 1) {
            throw new Error(`Target shape should be a vector but received shape ${newShape.shape}`);
        }
        // Plain JS arrays for the shape vectors; raw typed array for indices.
        const $inputShape = Array.from(backend.data.get(inputShape.dataId).values);
        const $inputIndices = backend.data.get(inputIndices.dataId).values;
        const targetShape = Array.from(backend.data.get(newShape.dataId).values);
        // NOTE(review): the impl presumably also resolves a -1 wildcard dim in
        // targetShape (outputShape is returned separately) — confirm in
        // sparseReshapeImpl.
        const [newIndices, indicesShape, outputShape] = sparseReshapeImpl($inputIndices, inputIndices.shape, inputIndices.dtype, $inputShape, targetShape);
        return [
            backend.makeTensorInfo(indicesShape, inputIndices.dtype, newIndices),
            backend.makeTensorInfo([outputShape.length], newShape.dtype, new Int32Array(outputShape)),
        ];
    }
    const sparseReshapeConfig = {
        kernelName: SparseReshape,
        backendName: 'cpu',
        kernelFunc: sparseReshape$1,
    };
77737
77738 /**
77739 * @license
77740 * Copyright 2021 Google LLC. All Rights Reserved.
77741 * Licensed under the Apache License, Version 2.0 (the "License");
77742 * you may not use this file except in compliance with the License.
77743 * You may obtain a copy of the License at
77744 *
77745 * http://www.apache.org/licenses/LICENSE-2.0
77746 *
77747 * Unless required by applicable law or agreed to in writing, software
77748 * distributed under the License is distributed on an "AS IS" BASIS,
77749 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77750 * See the License for the specific language governing permissions and
77751 * limitations under the License.
77752 * =============================================================================
77753 */
    /**
     * SparseSegmentMean CPU kernel.
     *
     * Gathers rows of `data` at `indices` and reduces them per segment id,
     * delegating to sparseSegmentReductionImpl with its final flag set — the
     * flag presumably selects the mean (vs. sum) variant; compare
     * sparseSegmentSum$1, which omits it. TODO(review): confirm against the
     * impl's signature.
     */
    function sparseSegmentMean$1(args) {
        const { inputs, backend } = args;
        const { data, indices, segmentIds } = inputs;
        // Shape validation mirroring TF's SparseSegmentMean requirements.
        if (data.shape.length < 1) {
            throw new Error(`Data should be at least 1 dimensional but received scalar`);
        }
        if (indices.shape.length !== 1) {
            throw new Error(`Indices should be a vector but received shape
        ${indices.shape}`);
        }
        if (segmentIds.shape.length !== 1) {
            throw new Error(`Segment ids should be a vector but received shape
        ${segmentIds.shape}`);
        }
        if (indices.shape[0] !== segmentIds.shape[0]) {
            throw new Error(`segmentIds and indices should have same size.`);
        }
        // Raw backing arrays of each input tensor.
        const $data = backend.data.get(data.dataId).values;
        const $indices = backend.data.get(indices.dataId).values;
        const $segmentIds = backend.data.get(segmentIds.dataId).values;
        const [outputData, outputDataShape] = sparseSegmentReductionImpl($data, data.shape, data.dtype, $indices, $segmentIds, true);
        return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
    }
    const sparseSegmentMeanConfig = {
        kernelName: SparseSegmentMean,
        backendName: 'cpu',
        kernelFunc: sparseSegmentMean$1,
    };
77782
77783 /**
77784 * @license
77785 * Copyright 2021 Google LLC. All Rights Reserved.
77786 * Licensed under the Apache License, Version 2.0 (the "License");
77787 * you may not use this file except in compliance with the License.
77788 * You may obtain a copy of the License at
77789 *
77790 * http://www.apache.org/licenses/LICENSE-2.0
77791 *
77792 * Unless required by applicable law or agreed to in writing, software
77793 * distributed under the License is distributed on an "AS IS" BASIS,
77794 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77795 * See the License for the specific language governing permissions and
77796 * limitations under the License.
77797 * =============================================================================
77798 */
    /**
     * SparseSegmentSum CPU kernel.
     *
     * Gathers rows of `data` at `indices` and reduces them per segment id via
     * sparseSegmentReductionImpl with its default (non-mean) reduction.
     */
    function sparseSegmentSum$1(args) {
        const { inputs, backend } = args;
        const { data, indices, segmentIds } = inputs;
        // Shape validation mirroring TF's SparseSegmentSum requirements.
        if (data.shape.length < 1) {
            throw new Error(`Data should be at least 1 dimensional but received scalar`);
        }
        if (indices.shape.length !== 1) {
            throw new Error(`Indices should be a vector but received shape
        ${indices.shape}`);
        }
        if (segmentIds.shape.length !== 1) {
            throw new Error(`Segment ids should be a vector but received shape
        ${segmentIds.shape}`);
        }
        if (indices.shape[0] !== segmentIds.shape[0]) {
            throw new Error(`segmentIds and indices should have same size.`);
        }
        // Raw backing arrays of each input tensor.
        const $data = backend.data.get(data.dataId).values;
        const $indices = backend.data.get(indices.dataId).values;
        const $segmentIds = backend.data.get(segmentIds.dataId).values;
        const [outputData, outputDataShape] = sparseSegmentReductionImpl($data, data.shape, data.dtype, $indices, $segmentIds);
        return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
    }
    const sparseSegmentSumConfig = {
        kernelName: SparseSegmentSum,
        backendName: 'cpu',
        kernelFunc: sparseSegmentSum$1,
    };
77827
77828 /**
77829 * @license
77830 * Copyright 2020 Google LLC. All Rights Reserved.
77831 * Licensed under the Apache License, Version 2.0 (the "License");
77832 * you may not use this file except in compliance with the License.
77833 * You may obtain a copy of the License at
77834 *
77835 * http://www.apache.org/licenses/LICENSE-2.0
77836 *
77837 * Unless required by applicable law or agreed to in writing, software
77838 * distributed under the License is distributed on an "AS IS" BASIS,
77839 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77840 * See the License for the specific language governing permissions and
77841 * limitations under the License.
77842 * =============================================================================
77843 */
77844 function sparseToDense$1(args) {
77845 const { inputs, backend, attrs } = args;
77846 const { sparseIndices, sparseValues, defaultValue } = inputs;
77847 const { outputShape } = attrs;
77848 const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(sparseValues, sparseIndices, outputShape);
77849 const sumDupeIndices = false;
77850 const indicesBuf = backend.bufferSync(sparseIndices);
77851 let outBuf;
77852 switch (sparseValues.dtype) {
77853 case 'bool': {
77854 const updatesBuf = backend.bufferSync(sparseValues);
77855 const $defaultValue = Boolean(backend.data.get(defaultValue.dataId).values[0]);
77856 outBuf = scatterImpl(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices);
77857 break;
77858 }
77859 case 'float32': {
77860 const updatesBuf = backend.bufferSync(sparseValues);
77861 const $defaultValue = backend.data.get(defaultValue.dataId).values[0];
77862 outBuf = scatterImpl(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices);
77863 break;
77864 }
77865 case 'int32': {
77866 const updatesBuf = backend.bufferSync(sparseValues);
77867 const $defaultValue = backend.data.get(defaultValue.dataId).values[0];
77868 outBuf = scatterImpl(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices);
77869 break;
77870 }
77871 case 'string': {
77872 const updatesBuf = backend.bufferSync(sparseValues);
77873 const $defaultValue = decodeString(backend.data.get(defaultValue.dataId).values[0]);
77874 outBuf = scatterImpl(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices);
77875 break;
77876 }
77877 default:
77878 throw new Error(`Unsupported type ${sparseValues.dtype}`);
77879 }
77880 return backend.makeTensorInfo(outputShape, outBuf.dtype, outBuf.values);
77881 }
77882 const sparseToDenseConfig = {
77883 kernelName: SparseToDense,
77884 backendName: 'cpu',
77885 kernelFunc: sparseToDense$1
77886 };
77887
77888 /**
77889 * @license
77890 * Copyright 2020 Google LLC. All Rights Reserved.
77891 * Licensed under the Apache License, Version 2.0 (the "License");
77892 * you may not use this file except in compliance with the License.
77893 * You may obtain a copy of the License at
77894 *
77895 * http://www.apache.org/licenses/LICENSE-2.0
77896 *
77897 * Unless required by applicable law or agreed to in writing, software
77898 * distributed under the License is distributed on an "AS IS" BASIS,
77899 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77900 * See the License for the specific language governing permissions and
77901 * limitations under the License.
77902 * =============================================================================
77903 */
77904 function splitV(args) {
77905 const { inputs, backend, attrs } = args;
77906 const { x } = inputs;
77907 const { numOrSizeSplits, axis } = attrs;
77908 const $axis = parseAxisParam(axis, x.shape)[0];
77909 const splitSizes = prepareSplitSize(x, numOrSizeSplits, $axis);
77910 const begin = new Array(x.shape.length).fill(0);
77911 const size = x.shape.slice();
77912 return splitSizes.map(s => {
77913 const sliceSize = [...size];
77914 sliceSize[$axis] = s;
77915 const sliceT = slice$1({ inputs: { x }, backend, attrs: { begin, size: sliceSize } });
77916 begin[$axis] += s;
77917 return sliceT;
77918 });
77919 }
77920 const splitVConfig = {
77921 kernelName: SplitV,
77922 backendName: 'cpu',
77923 kernelFunc: splitV
77924 };
77925
77926 /**
77927 * @license
77928 * Copyright 2019 Google LLC. All Rights Reserved.
77929 * Licensed under the Apache License, Version 2.0 (the "License");
77930 * you may not use this file except in compliance with the License.
77931 * You may obtain a copy of the License at
77932 *
77933 * http://www.apache.org/licenses/LICENSE-2.0
77934 *
77935 * Unless required by applicable law or agreed to in writing, software
77936 * distributed under the License is distributed on an "AS IS" BASIS,
77937 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77938 * See the License for the specific language governing permissions and
77939 * limitations under the License.
77940 * =============================================================================
77941 */
77942 const squareConfig = {
77943 kernelName: Square,
77944 backendName: 'cpu',
77945 kernelFunc: ({ inputs, backend }) => {
77946 const { x } = inputs;
77947 const cpuBackend = backend;
77948 assertNotComplex(x, 'square');
77949 const values = cpuBackend.data.get(x.dataId).values;
77950 const newValues = new Float32Array(values.length);
77951 for (let i = 0; i < values.length; ++i) {
77952 const value = values[i];
77953 newValues[i] = value * value;
77954 }
77955 const dataId = cpuBackend.write(newValues, x.shape, x.dtype);
77956 return { dataId, shape: x.shape, dtype: x.dtype };
77957 }
77958 };
77959
77960 /**
77961 * @license
77962 * Copyright 2020 Google LLC. All Rights Reserved.
77963 * Licensed under the Apache License, Version 2.0 (the License);
77964 * you may not use this file except in compliance with the License.
77965 * You may obtain a copy of the License at
77966 *
77967 * http://www.apache.org/licenses/LICENSE-2.0
77968 *
77969 * Unless required by applicable law or agreed to in writing, software
77970 * distributed under the License is distributed on an AS IS BASIS,
77971 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77972 * See the License for the specific language governing permissions and
77973 * limitations under the License.
77974 * =============================================================================
77975 */
77976 const step$1 = unaryKernelFunc(Step, (xi, attrs) => {
77977 const stepAttrs = attrs;
77978 if (isNaN(xi)) {
77979 return NaN;
77980 }
77981 else {
77982 return xi > 0 ? 1 : stepAttrs.alpha;
77983 }
77984 });
77985 const stepConfig = {
77986 kernelName: Step,
77987 backendName: 'cpu',
77988 kernelFunc: step$1,
77989 };
77990
77991 /**
77992 * @license
77993 * Copyright 2020 Google LLC. All Rights Reserved.
77994 * Licensed under the Apache License, Version 2.0 (the "License");
77995 * you may not use this file except in compliance with the License.
77996 * You may obtain a copy of the License at
77997 *
77998 * http://www.apache.org/licenses/LICENSE-2.0
77999 *
78000 * Unless required by applicable law or agreed to in writing, software
78001 * distributed under the License is distributed on an "AS IS" BASIS,
78002 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78003 * See the License for the specific language governing permissions and
78004 * limitations under the License.
78005 * =============================================================================
78006 */
78007 function stridedSlice$1(args) {
78008 const { inputs, backend, attrs } = args;
78009 const { x } = inputs;
78010 const { begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask } = attrs;
78011 assertNotComplex(x, 'stridedSlice');
78012 const { finalShapeSparse, finalShape, isIdentity, sliceDim0, isSimpleSlice, begin: $begin, end: $end, strides: $strides } = sliceInfo(x.shape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
78013 let result;
78014 // ref:
78015 // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/strided_slice_op.cc
78016 if (isIdentity) {
78017 // Optimization #1, slice is a no-op plus reshape
78018 result = reshape$2({ inputs: { x }, backend, attrs: { shape: finalShape } });
78019 }
78020 else if (sliceDim0 || isSimpleSlice) {
78021 // Optimization #2, slice is memory contiguous (only occurs in dim 0)
78022 assert(x.shape.length >= 1, () => `Input must have rank at least 1, got: ${x.shape.length}`);
78023 const size = computeOutShape($begin, $end, $strides);
78024 // To tolerate begin[0] > end[0] (a 0-output slice), we min(begin, end).
78025 const sliced = slice$1({ inputs: { x }, backend, attrs: { begin: $begin, size } });
78026 result =
78027 reshape$2({ inputs: { x: sliced }, backend, attrs: { shape: finalShape } });
78028 backend.disposeIntermediateTensorInfo(sliced);
78029 }
78030 else {
78031 const xBuf = backend.bufferSync(x);
78032 const outBuf = stridedSliceImpl(finalShapeSparse, xBuf, $strides, $begin);
78033 result = backend.makeTensorInfo(finalShape, outBuf.dtype, outBuf.values);
78034 }
78035 return result;
78036 }
78037 const stridedSliceConfig = {
78038 kernelName: StridedSlice,
78039 backendName: 'cpu',
78040 kernelFunc: stridedSlice$1
78041 };
78042
78043 /**
78044 * @license
78045 * Copyright 2021 Google LLC. All Rights Reserved.
78046 * Licensed under the Apache License, Version 2.0 (the "License");
78047 * you may not use this file except in compliance with the License.
78048 * You may obtain a copy of the License at
78049 *
78050 * http://www.apache.org/licenses/LICENSE-2.0
78051 *
78052 * Unless required by applicable law or agreed to in writing, software
78053 * distributed under the License is distributed on an "AS IS" BASIS,
78054 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78055 * See the License for the specific language governing permissions and
78056 * limitations under the License.
78057 * =============================================================================
78058 */
78059 function stringNGrams$1(args) {
78060 const { inputs, backend, attrs } = args;
78061 const { separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences } = attrs;
78062 const { data, dataSplits } = inputs;
78063 const $data = backend.data.get(data.dataId).values;
78064 const $dataSplits = backend.data.get(dataSplits.dataId).values;
78065 const [nGrams, nGramsSplits] = stringNGramsImpl($data, $dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences);
78066 return [
78067 backend.makeTensorInfo([nGrams.length], 'string', nGrams),
78068 backend.makeTensorInfo(dataSplits.shape, 'int32', nGramsSplits),
78069 ];
78070 }
78071 const stringNGramsConfig = {
78072 kernelName: StringNGrams,
78073 backendName: 'cpu',
78074 kernelFunc: stringNGrams$1,
78075 };
78076
78077 /**
78078 * @license
78079 * Copyright 2021 Google LLC. All Rights Reserved.
78080 * Licensed under the Apache License, Version 2.0 (the "License");
78081 * you may not use this file except in compliance with the License.
78082 * You may obtain a copy of the License at
78083 *
78084 * http://www.apache.org/licenses/LICENSE-2.0
78085 *
78086 * Unless required by applicable law or agreed to in writing, software
78087 * distributed under the License is distributed on an "AS IS" BASIS,
78088 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78089 * See the License for the specific language governing permissions and
78090 * limitations under the License.
78091 * =============================================================================
78092 */
78093 function stringSplit$1(args) {
78094 const { inputs, backend, attrs } = args;
78095 const { skipEmpty } = attrs;
78096 const { input, delimiter } = inputs;
78097 if (input.dtype !== 'string') {
78098 throw new Error('Input must be of datatype string');
78099 }
78100 if (input.shape.length !== 1) {
78101 throw new Error(`Input must be a vector, got shape: ${input.shape}`);
78102 }
78103 if (delimiter.shape.length !== 0) {
78104 throw new Error(`Delimiter must be a scalar, got shape: ${delimiter.shape}`);
78105 }
78106 const $input = backend.data.get(input.dataId).values;
78107 const $delimiter = backend.data.get(delimiter.dataId).values[0];
78108 const [indices, values, shape] = stringSplitImpl($input, $delimiter, skipEmpty);
78109 const outputSize = values.length;
78110 return [
78111 backend.makeTensorInfo([outputSize, 2], 'int32', indices),
78112 backend.makeTensorInfo([outputSize], 'string', values),
78113 backend.makeTensorInfo([2], 'int32', new Int32Array(shape))
78114 ];
78115 }
78116 const stringSplitConfig = {
78117 kernelName: StringSplit,
78118 backendName: 'cpu',
78119 kernelFunc: stringSplit$1,
78120 };
78121
78122 /**
78123 * @license
78124 * Copyright 2021 Google LLC. All Rights Reserved.
78125 * Licensed under the Apache License, Version 2.0 (the "License");
78126 * you may not use this file except in compliance with the License.
78127 * You may obtain a copy of the License at
78128 *
78129 * http://www.apache.org/licenses/LICENSE-2.0
78130 *
78131 * Unless required by applicable law or agreed to in writing, software
78132 * distributed under the License is distributed on an "AS IS" BASIS,
78133 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78134 * See the License for the specific language governing permissions and
78135 * limitations under the License.
78136 * =============================================================================
78137 */
78138 function stringToHashBucketFast$1(args) {
78139 const { inputs, backend, attrs } = args;
78140 const { numBuckets } = attrs;
78141 const { input } = inputs;
78142 if (input.dtype !== 'string') {
78143 throw new Error('Input must be of datatype string');
78144 }
78145 if (numBuckets <= 0) {
78146 throw new Error(`Number of buckets must be at least 1`);
78147 }
78148 const $input = backend.data.get(input.dataId).values;
78149 const output = stringToHashBucketFastImpl($input, numBuckets);
78150 return backend.makeTensorInfo(input.shape, 'int32', output);
78151 }
78152 const stringToHashBucketFastConfig = {
78153 kernelName: StringToHashBucketFast,
78154 backendName: 'cpu',
78155 kernelFunc: stringToHashBucketFast$1,
78156 };
78157
78158 /**
78159 * @license
78160 * Copyright 2020 Google LLC. All Rights Reserved.
78161 * Licensed under the Apache License, Version 2.0 (the License);
78162 * you may not use this file except in compliance with the License.
78163 * You may obtain a copy of the License at
78164 *
78165 * http://www.apache.org/licenses/LICENSE-2.0
78166 *
78167 * Unless required by applicable law or agreed to in writing, software
78168 * distributed under the License is distributed on an AS IS BASIS,
78169 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78170 * See the License for the specific language governing permissions and
78171 * limitations under the License.
78172 * =============================================================================
78173 */
78174 const tan$1 = unaryKernelFunc(Tan, (xi) => Math.tan(xi));
78175 const tanConfig = {
78176 kernelName: Tan,
78177 backendName: 'cpu',
78178 kernelFunc: tan$1,
78179 };
78180
78181 /**
78182 * @license
78183 * Copyright 2020 Google LLC. All Rights Reserved.
78184 * Licensed under the Apache License, Version 2.0 (the License);
78185 * you may not use this file except in compliance with the License.
78186 * You may obtain a copy of the License at
78187 *
78188 * http://www.apache.org/licenses/LICENSE-2.0
78189 *
78190 * Unless required by applicable law or agreed to in writing, software
78191 * distributed under the License is distributed on an AS IS BASIS,
78192 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78193 * See the License for the specific language governing permissions and
78194 * limitations under the License.
78195 * =============================================================================
78196 */
78197 const tanh$2 = unaryKernelFunc(Tanh, (xi) => Math.tanh(xi));
78198 const tanhConfig = {
78199 kernelName: Tanh,
78200 backendName: 'cpu',
78201 kernelFunc: tanh$2,
78202 };
78203
78204 /**
78205 * @license
78206 * Copyright 2020 Google LLC. All Rights Reserved.
78207 * Licensed under the Apache License, Version 2.0 (the "License");
78208 * you may not use this file except in compliance with the License.
78209 * You may obtain a copy of the License at
78210 *
78211 * http://www.apache.org/licenses/LICENSE-2.0
78212 *
78213 * Unless required by applicable law or agreed to in writing, software
78214 * distributed under the License is distributed on an "AS IS" BASIS,
78215 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78216 * See the License for the specific language governing permissions and
78217 * limitations under the License.
78218 * =============================================================================
78219 */
78220 function tile$2(args) {
78221 const { inputs, backend, attrs } = args;
78222 const { x } = inputs;
78223 const { reps } = attrs;
78224 assertNotComplex(x, 'tile');
78225 const outBuf = tileImpl(backend.bufferSync(x), reps);
78226 return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
78227 }
78228 const tileConfig = {
78229 kernelName: Tile,
78230 backendName: 'cpu',
78231 kernelFunc: tile$2
78232 };
78233
78234 /**
78235 * @license
78236 * Copyright 2020 Google LLC. All Rights Reserved.
78237 * Licensed under the Apache License, Version 2.0 (the "License");
78238 * you may not use this file except in compliance with the License.
78239 * You may obtain a copy of the License at
78240 *
78241 * http://www.apache.org/licenses/LICENSE-2.0
78242 *
78243 * Unless required by applicable law or agreed to in writing, software
78244 * distributed under the License is distributed on an "AS IS" BASIS,
78245 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78246 * See the License for the specific language governing permissions and
78247 * limitations under the License.
78248 * =============================================================================
78249 */
78250 function topK(args) {
78251 const { inputs, backend, attrs } = args;
78252 const { x } = inputs;
78253 const { k, sorted } = attrs;
78254 assertNotComplex(x, 'topk');
78255 const xVals = backend.data.get(x.dataId).values;
78256 const [allTopKVals, allTopKIndices] = topKImpl(xVals, x.shape, x.dtype, k, sorted);
78257 return [
78258 backend.makeTensorInfo(allTopKVals.shape, allTopKVals.dtype, allTopKVals.values),
78259 backend.makeTensorInfo(allTopKIndices.shape, allTopKIndices.dtype, allTopKIndices.values)
78260 ];
78261 }
78262 const topKConfig = {
78263 kernelName: TopK,
78264 backendName: 'cpu',
78265 kernelFunc: topK
78266 };
78267
78268 /**
78269 * @license
78270 * Copyright 2021 Google LLC. All Rights Reserved.
78271 * Licensed under the Apache License, Version 2.0 (the "License");
78272 * you may not use this file except in compliance with the License.
78273 * You may obtain a copy of the License at
78274 *
78275 * http://www.apache.org/licenses/LICENSE-2.0
78276 *
78277 * Unless required by applicable law or agreed to in writing, software
78278 * distributed under the License is distributed on an "AS IS" BASIS,
78279 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78280 * See the License for the specific language governing permissions and
78281 * limitations under the License.
78282 * =============================================================================
78283 */
78284 function transform$1(args) {
78285 const { inputs, attrs, backend } = args;
78286 const { image, transforms } = inputs;
78287 const { interpolation, fillMode, fillValue, outputShape } = attrs;
78288 const [batch, imageHeight, imageWidth, numChannels] = image.shape;
78289 const [outHeight, outWidth] = outputShape != null ? outputShape : [imageHeight, imageWidth];
78290 const outShape = [batch, outHeight, outWidth, numChannels];
78291 const strides = computeStrides(image.shape);
78292 const batchStride = strides[0];
78293 const rowStride = strides[1];
78294 const colStride = strides[2];
78295 const outVals = getTypedArrayFromDType(image.dtype, sizeFromShape(outShape));
78296 outVals.fill(fillValue);
78297 const imageVals = backend.data.get(image.dataId).values;
78298 const transformVals = backend.data.get(transforms.dataId).values;
78299 // Ref TF implementation:
78300 // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/image/image_ops.h
78301 for (let b = 0; b < batch; ++b) {
78302 const transform = transforms.shape[0] === 1 ?
78303 transformVals :
78304 transformVals.subarray(b * 8, b * 8 + 8);
78305 for (let outY = 0; outY < outHeight; ++outY) {
78306 for (let outX = 0; outX < outWidth; ++outX) {
78307 for (let channel = 0; channel < numChannels; ++channel) {
78308 let val;
78309 const projection = transform[6] * outX + transform[7] * outY + 1;
78310 if (projection === 0) {
78311 // Return the fill value for infinite coordinates,
78312 // which are outside the input image
78313 continue;
78314 }
78315 const inX = (transform[0] * outX + transform[1] * outY + transform[2]) /
78316 projection;
78317 const inY = (transform[3] * outX + transform[4] * outY + transform[5]) /
78318 projection;
78319 const x = mapCoord(inX, imageWidth, fillMode);
78320 const y = mapCoord(inY, imageHeight, fillMode);
78321 switch (interpolation) {
78322 case 'nearest':
78323 val = nearestInterpolation(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, b, y, x, channel, fillValue);
78324 break;
78325 case 'bilinear':
78326 val = bilinearInterpolation(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, b, y, x, channel, fillValue);
78327 break;
78328 default:
78329 throw new Error(`Error in Transform: Expect 'nearest' or ` +
78330 `'bilinear', but got ${interpolation}`);
78331 }
78332 const ind = b * batchStride + outY * rowStride + outX * colStride + channel;
78333 outVals[ind] = val;
78334 }
78335 }
78336 }
78337 return backend.makeTensorInfo(outShape, image.dtype, outVals);
78338 }
78339 const dataId = backend.write(outVals, outShape, image.dtype);
78340 return { dataId, shape: image.shape, dtype: image.dtype };
78341 }
78342 const transformConfig = {
78343 kernelName: Transform,
78344 backendName: 'cpu',
78345 kernelFunc: transform$1
78346 };
78347 function mapCoord(outCoord, len, mode) {
78348 switch (mode) {
78349 case 'reflect':
78350 return mapCoordReflect(outCoord, len);
78351 case 'wrap':
78352 return mapCoordWrap(outCoord, len);
78353 case 'nearest':
78354 return mapCoordNearest(outCoord, len);
78355 case 'constant':
78356 default:
78357 return mapCoordConstant(outCoord, len);
78358 }
78359 }
    // Maps a coordinate into [0, len - 1] by mirroring at both borders.
    function mapCoordReflect(outCoord, len) {
        // Reflect [abcd] to [dcba|abcd|dcba].
        let inCoord = outCoord;
        if (inCoord < 0) {
            if (len <= 1) {
                // Degenerate axis: everything maps to the only valid index.
                inCoord = 0;
            }
            else {
                const sz2 = 2 * len;
                if (inCoord < sz2) {
                    // Shift by whole periods (2 * len) toward the origin
                    // before reflecting.
                    inCoord = sz2 * Math.trunc(-inCoord / sz2) + inCoord;
                }
                // Reflect across the left border (or wrap one more period).
                inCoord = inCoord < -len ? inCoord + sz2 : -inCoord - 1;
            }
        }
        else if (inCoord > len - 1) {
            if (len <= 1) {
                // Degenerate axis: everything maps to the only valid index.
                inCoord = 0;
            }
            else {
                const sz2 = 2 * len;
                // Reduce modulo the period, then reflect if still past the end.
                inCoord -= sz2 * Math.trunc(inCoord / sz2);
                if (inCoord >= len) {
                    inCoord = sz2 - inCoord - 1;
                }
            }
        }
        // clamp is necessary because when outCoord = 3.5 and len = 4,
        // inCoord = 3.5 and will be rounded to 4 in nearest interpolation.
        return clamp(0, inCoord, len - 1);
    }
    // Maps a coordinate into [0, len - 1] by wrapping around (tiling).
    function mapCoordWrap(outCoord, len) {
        // Wrap [abcd] to [abcd|abcd|abcd].
        let inCoord = outCoord;
        if (inCoord < 0) {
            if (len <= 1) {
                // Degenerate axis: everything maps to the only valid index.
                inCoord = 0;
            }
            else {
                // Add enough whole lengths to bring the coordinate into range.
                const sz = len - 1;
                inCoord += len * (Math.trunc(-inCoord / sz) + 1);
            }
        }
        else if (inCoord > len - 1) {
            if (len <= 1) {
                // Degenerate axis: everything maps to the only valid index.
                inCoord = 0;
            }
            else {
                // Subtract enough whole lengths to bring the coordinate into range.
                const sz = len - 1;
                inCoord -= len * Math.trunc(inCoord / sz);
            }
        }
        // clamp is necessary because when outCoord = -0.5 and len = 4,
        // inCoord = 3.5 and will be rounded to 4 in nearest interpolation.
        return clamp(0, inCoord, len - 1);
    }
    // 'constant' fill mode: coordinates pass through unchanged; out-of-range
    // reads are resolved later by readWithFillValue.
    function mapCoordConstant(outCoord, len) {
        return outCoord;
    }
    // 'nearest' fill mode: clamp to the closest valid coordinate.
    function mapCoordNearest(outCoord, len) {
        return clamp(0, outCoord, len - 1);
    }
78422 function readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, y, x, channel, fillValue) {
78423 const ind = batch * batchStride + y * rowStride + x * colStride + channel;
78424 if (0 <= y && y < imageHeight && 0 <= x && x < imageWidth) {
78425 return imageVals[ind];
78426 }
78427 else {
78428 return fillValue;
78429 }
78430 }
78431 function nearestInterpolation(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, y, x, channel, fillValue) {
78432 const $y = Math.round(y);
78433 const $x = Math.round(x);
78434 return readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, $y, $x, channel, fillValue);
78435 }
78436 function bilinearInterpolation(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, y, x, channel, fillValue) {
78437 const yFloor = Math.floor(y);
78438 const xFloor = Math.floor(x);
78439 const yCeil = yFloor + 1;
78440 const xCeil = xFloor + 1;
78441 // f(x, yFloor) = (xCeil - x) / (xCeil - xFloor) * f(xFloor, yFloor)
78442 // + (x - xFloor) / (xCeil - xFloor) * f(xCeil, yFloor)
78443 const valueYFloor = (xCeil - x) *
78444 readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yFloor, xFloor, channel, fillValue) +
78445 (x - xFloor) *
78446 readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yFloor, xCeil, channel, fillValue);
78447 // f(x, yCeil) = (xCeil - x) / (xCeil - xFloor) * f(xFloor, yCeil)
78448 // + (x - xFloor) / (xCeil - xFloor) * f(xCeil, yCeil)
78449 const valueYCeil = (xCeil - x) *
78450 readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yCeil, xFloor, channel, fillValue) +
78451 (x - xFloor) *
78452 readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yCeil, xCeil, channel, fillValue);
78453 // f(x, y) = (yCeil - y) / (yCeil - yFloor) * f(x, yFloor)
78454 // + (y - yFloor) / (yCeil - yFloor) * f(x, yCeil)
78455 return (yCeil - y) * valueYFloor + (y - yFloor) * valueYCeil;
78456 }
78457
78458 /**
78459 * @license
78460 * Copyright 2020 Google LLC. All Rights Reserved.
78461 * Licensed under the Apache License, Version 2.0 (the License);
78462 * you may not use this file except in compliance with the License.
78463 * You may obtain a copy of the License at
78464 *
78465 * http://www.apache.org/licenses/LICENSE-2.0
78466 *
78467 * Unless required by applicable law or agreed to in writing, software
78468 * distributed under the License is distributed on an AS IS BASIS,
78469 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78470 * See the License for the specific language governing permissions and
78471 * limitations under the License.
78472 * =============================================================================
78473 */
78474 function unique$2(args) {
78475 const { inputs, attrs, backend } = args;
78476 const { axis } = attrs;
78477 const { x } = inputs;
78478 assertNotComplex(x, 'unique');
78479 const values = backend.data.get(x.dataId).values;
78480 const { outputValues, outputShape, indices } = uniqueImpl(values, axis, x.shape, x.dtype);
78481 return [
78482 backend.makeTensorInfo(outputShape, x.dtype, outputValues),
78483 backend.makeTensorInfo([indices.length], 'int32', indices),
78484 ];
78485 }
78486 const uniqueConfig = {
78487 kernelName: Unique,
78488 backendName: 'cpu',
78489 kernelFunc: unique$2,
78490 };
78491
78492 /**
78493 * @license
78494 * Copyright 2020 Google LLC. All Rights Reserved.
78495 * Licensed under the Apache License, Version 2.0 (the "License");
78496 * you may not use this file except in compliance with the License.
78497 * You may obtain a copy of the License at
78498 *
78499 * http://www.apache.org/licenses/LICENSE-2.0
78500 *
78501 * Unless required by applicable law or agreed to in writing, software
78502 * distributed under the License is distributed on an "AS IS" BASIS,
78503 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78504 * See the License for the specific language governing permissions and
78505 * limitations under the License.
78506 * =============================================================================
78507 */
78508 function unpack(args) {
78509 const { inputs, backend, attrs } = args;
78510 const { value } = inputs;
78511 let { axis } = attrs;
78512 if (axis < 0) {
78513 axis += value.shape.length;
78514 }
78515 const valueRank = value.shape.length;
78516 const num = value.shape[axis];
78517 const outShape = new Array(valueRank - 1);
78518 let outIndex = 0;
78519 for (let i = 0; i < valueRank; i++) {
78520 if (i !== axis) {
78521 outShape[outIndex++] = value.shape[i];
78522 }
78523 }
78524 const begin = new Array(valueRank).fill(0);
78525 const size = value.shape.slice();
78526 size[axis] = 1;
78527 const res = new Array(num);
78528 for (let i = 0; i < res.length; i++) {
78529 begin[axis] = i;
78530 const tempRes = slice$1({ inputs: { x: value }, backend, attrs: { begin, size } });
78531 res[i] = reshape$2({ inputs: { x: tempRes }, backend, attrs: { shape: outShape } });
78532 backend.disposeIntermediateTensorInfo(tempRes);
78533 }
78534 return res;
78535 }
78536 const unpackConfig = {
78537 kernelName: Unpack,
78538 backendName: 'cpu',
78539 kernelFunc: unpack
78540 };
78541
78542 /**
78543 * @license
78544 * Copyright 2020 Google LLC. All Rights Reserved.
78545 * Licensed under the Apache License, Version 2.0 (the "License");
78546 * you may not use this file except in compliance with the License.
78547 * You may obtain a copy of the License at
78548 *
78549 * http://www.apache.org/licenses/LICENSE-2.0
78550 *
78551 * Unless required by applicable law or agreed to in writing, software
78552 * distributed under the License is distributed on an "AS IS" BASIS,
78553 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78554 * See the License for the specific language governing permissions and
78555 * limitations under the License.
78556 * =============================================================================
78557 */
    /**
     * CPU kernel for UnsortedSegmentSum, composed from other CPU kernels:
     * for each segment id s in [0, numSegments), sum the rows of `x` whose
     * segment id equals s (via equal -> cast -> multiply -> sum), then pack
     * the per-segment sums along axis 0. All intermediate TensorInfos are
     * disposed before returning.
     */
    function unsortedSegmentSum$1(args) {
        const { inputs, backend, attrs } = args;
        const { x, segmentIds } = inputs;
        const { numSegments } = attrs;
        assertNotComplex(x, 'unsortedSegmentSum');
        const xRank = x.shape.length;
        const segmentIdsRank = segmentIds.shape.length;
        const res = [];
        const intermediates = [];
        // Reshape the segment id's so that they can be broadcast with
        // x. The new shape should be [segmentIds.shape, 1, ..., 1]
        const numIters = xRank - segmentIdsRank;
        let $segmentIds = segmentIds;
        for (let i = 0; i < numIters; ++i) {
            const expanded = expandDims$2({ inputs: { input: $segmentIds }, backend, attrs: { dim: i + 1 } });
            $segmentIds = expanded;
            intermediates.push(expanded);
        }
        for (let i = 0; i < numSegments; ++i) {
            // Build a mask selecting the elements belonging to segment i,
            // multiply it into x, and reduce over axis 0.
            const scalarValue = createScalarValue(i, 'int32');
            const segmentId = backend.makeTensorInfo([], 'int32', scalarValue);
            const mask = equal$1({ inputs: { a: segmentId, b: $segmentIds }, backend });
            const maskCasted = cast$2({ inputs: { x: mask }, backend, attrs: { dtype: 'float32' } });
            const mul = multiply$2({ inputs: { a: maskCasted, b: x }, backend });
            const sumTensorInfo = sum$3({ inputs: { x: mul }, backend, attrs: { axis: 0, keepDims: false } });
            res.push(sumTensorInfo);
            intermediates.push(segmentId);
            intermediates.push(mask);
            intermediates.push(maskCasted);
            intermediates.push(mul);
            intermediates.push(sumTensorInfo);
        }
        // Stack the per-segment sums into the final result.
        const result = pack({ inputs: res, backend, attrs: { axis: 0 } });
        intermediates.forEach(t => backend.disposeIntermediateTensorInfo(t));
        return result;
    }
    const unsortedSegmentSumConfig = {
        kernelName: UnsortedSegmentSum,
        backendName: 'cpu',
        kernelFunc: unsortedSegmentSum$1
    };
78599
78600 /**
78601 * @license
78602 * Copyright 2020 Google LLC. All Rights Reserved.
78603 * Licensed under the Apache License, Version 2.0 (the "License");
78604 * you may not use this file except in compliance with the License.
78605 * You may obtain a copy of the License at
78606 *
78607 * http://www.apache.org/licenses/LICENSE-2.0
78608 *
78609 * Unless required by applicable law or agreed to in writing, software
78610 * distributed under the License is distributed on an "AS IS" BASIS,
78611 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78612 * See the License for the specific language governing permissions and
78613 * limitations under the License.
78614 * =============================================================================
78615 */
78616 // List all kernel configs here
78617 const kernelConfigs = [
78618 _fusedMatMulConfig,
78619 absConfig,
78620 acosConfig,
78621 acoshConfig,
78622 addConfig,
78623 addNConfig,
78624 allConfig,
78625 anyConfig,
78626 argMaxConfig,
78627 argMinConfig,
78628 asinConfig,
78629 asinhConfig,
78630 atanConfig,
78631 atan2Config,
78632 atanhConfig,
78633 avgPoolConfig,
78634 avgPool3DConfig,
78635 avgPool3DGradConfig$1,
78636 avgPoolGradConfig$1,
78637 batchMatMulConfig,
78638 batchNormConfig,
78639 batchToSpaceNDConfig,
78640 bincountConfig,
78641 broadcastArgsConfig,
78642 castConfig,
78643 ceilConfig,
78644 clipByValueConfig,
78645 complexConfig,
78646 complexAbsConfig,
78647 concatConfig,
78648 conv2DConfig,
78649 conv2DBackpropFilterConfig,
78650 conv2DBackpropInputConfig,
78651 conv3DConfig,
78652 conv3DBackpropFilterV2Config,
78653 conv3DBackpropInputV2Config,
78654 cosConfig,
78655 coshConfig,
78656 cropAndResizeConfig,
78657 cumprodConfig,
78658 cumsumConfig,
78659 denseBincountConfig,
78660 depthToSpaceConfig,
78661 depthwiseConv2dNativeConfig,
78662 depthwiseConv2dNativeBackpropFilterConfig,
78663 depthwiseConv2dNativeBackpropInputConfig,
78664 diagConfig,
78665 dilation2DConfig,
78666 dilation2DBackpropFilterConfig,
78667 dilation2DBackpropInputConfig,
78668 einsumConfig,
78669 eluConfig,
78670 eluGradConfig$1,
78671 equalConfig,
78672 erfConfig,
78673 expConfig,
78674 expandDimsConfig,
78675 expm1Config,
78676 fftConfig,
78677 fillConfig,
78678 flipLeftRightConfig,
78679 floorConfig,
78680 floorDivConfig,
78681 fusedConv2DConfig,
78682 fusedDepthwiseConv2DConfig,
78683 gatherNdConfig,
78684 gatherV2Config,
78685 greaterConfig,
78686 greaterEqualConfig,
78687 identityConfig,
78688 ifftConfig,
78689 imagConfig,
78690 isFiniteConfig,
78691 isInfConfig,
78692 isNaNConfig,
78693 leakyReluConfig,
78694 lessConfig,
78695 lessEqualConfig,
78696 linSpaceConfig,
78697 logConfig,
78698 log1pConfig,
78699 logicalAndConfig,
78700 logicalNotConfig,
78701 logicalOrConfig,
78702 LRNConfig,
78703 LRNGradConfig,
78704 maxConfig,
78705 maximumConfig,
78706 maxPoolConfig,
78707 maxPool3DConfig,
78708 maxPool3DGradConfig$1,
78709 maxPoolGradConfig$1,
78710 maxPoolWithArgmaxConfig,
78711 meanConfig,
78712 minConfig,
78713 minimumConfig,
78714 mirrorPadConfig,
78715 modConfig,
78716 multinomialConfig,
78717 multiplyConfig,
78718 negConfig,
78719 nonMaxSuppressionV3Config,
78720 nonMaxSuppressionV4Config,
78721 nonMaxSuppressionV5Config,
78722 notEqualConfig,
78723 oneHotConfig,
78724 onesLikeConfig,
78725 packConfig,
78726 padV2Config,
78727 powConfig,
78728 preluConfig,
78729 prodConfig,
78730 rangeConfig,
78731 realConfig,
78732 realDivConfig,
78733 reciprocalConfig,
78734 reluConfig,
78735 relu6Config,
78736 reshapeConfig,
78737 resizeBilinearConfig,
78738 resizeBilinearGradConfig$1,
78739 resizeNearestNeighborConfig,
78740 resizeNearestNeighborGradConfig$1,
78741 reverseConfig,
78742 rotateWithOffsetConfig,
78743 roundConfig,
78744 rsqrtConfig,
78745 scatterNdConfig,
78746 searchSortedConfig,
78747 selectConfig,
78748 seluConfig,
78749 sigmoidConfig,
78750 signConfig,
78751 sinConfig,
78752 sinhConfig,
78753 sliceConfig,
78754 softmaxConfig,
78755 softplusConfig,
78756 spaceToBatchNDConfig,
78757 sparseFillEmptyRowsConfig,
78758 sparseReshapeConfig,
78759 sparseSegmentMeanConfig,
78760 sparseSegmentSumConfig,
78761 sparseToDenseConfig,
78762 splitVConfig,
78763 sqrtConfig,
78764 squareConfig,
78765 squaredDifferenceConfig,
78766 stepConfig,
78767 stridedSliceConfig,
78768 stringNGramsConfig,
78769 stringSplitConfig,
78770 stringToHashBucketFastConfig,
78771 subConfig,
78772 sumConfig,
78773 tanConfig,
78774 tanhConfig,
78775 tileConfig,
78776 topKConfig,
78777 transformConfig,
78778 transposeConfig,
78779 uniqueConfig,
78780 unpackConfig,
78781 unsortedSegmentSumConfig,
78782 zerosLikeConfig
78783 ];
    // Register every CPU-backend kernel config listed above with the global
    // kernel registry so the engine can dispatch ops to this backend.
    for (const kernelConfig of kernelConfigs) {
        registerKernel(kernelConfig);
    }
78787
78788 /**
78789 * @license
78790 * Copyright 2020 Google LLC. All Rights Reserved.
78791 * Licensed under the Apache License, Version 2.0 (the "License");
78792 * you may not use this file except in compliance with the License.
78793 * You may obtain a copy of the License at
78794 *
78795 * http://www.apache.org/licenses/LICENSE-2.0
78796 *
78797 * Unless required by applicable law or agreed to in writing, software
78798 * distributed under the License is distributed on an "AS IS" BASIS,
78799 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78800 * See the License for the specific language governing permissions and
78801 * limitations under the License.
78802 * =============================================================================
78803 */
78804
78805 /**
78806 * @license
78807 * Copyright 2018 Google LLC. All Rights Reserved.
78808 * Licensed under the Apache License, Version 2.0 (the "License");
78809 * you may not use this file except in compliance with the License.
78810 * You may obtain a copy of the License at
78811 *
78812 * http://www.apache.org/licenses/LICENSE-2.0
78813 *
78814 * Unless required by applicable law or agreed to in writing, software
78815 * distributed under the License is distributed on an "AS IS" BASIS,
78816 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78817 * See the License for the specific language governing permissions and
78818 * limitations under the License.
78819 * =============================================================================
78820 */
78821 const contexts = {};
78822 const WEBGL_ATTRIBUTES = {
78823 alpha: false,
78824 antialias: false,
78825 premultipliedAlpha: false,
78826 preserveDrawingBuffer: false,
78827 depth: false,
78828 stencil: false,
78829 failIfMajorPerformanceCaveat: true
78830 };
78831 function clearWebGLContext(webGLVersion) {
78832 delete contexts[webGLVersion];
78833 }
78834 function setWebGLContext(webGLVersion, gl) {
78835 contexts[webGLVersion] = gl;
78836 }
    /**
     * Returns a WebGL context for the given version, creating and caching one
     * when none is cached or when `customCanvas` is supplied. Returns null if
     * a context could not be created. The returned context has fixed-function
     * state configured for GPGPU use.
     */
    function getWebGLContext(webGLVersion, customCanvas) {
        // A caller-supplied canvas always replaces the cached context.
        if (!(webGLVersion in contexts) || customCanvas != null) {
            const newCtx = getWebGLRenderingContext(webGLVersion, customCanvas);
            if (newCtx !== null) {
                contexts[webGLVersion] = newCtx;
            }
            else {
                console.log('Could not get context for WebGL version', webGLVersion);
                return null;
            }
        }
        const gl = contexts[webGLVersion];
        // A lost context is unusable: drop it and retry once with a default
        // canvas (the recursive call takes the creation path above).
        if (gl == null || gl.isContextLost()) {
            delete contexts[webGLVersion];
            return getWebGLContext(webGLVersion);
        }
        // GPGPU passes need no depth/stencil/blend/dither stages; scissored
        // draws and back-face culling are always on.
        gl.disable(gl.DEPTH_TEST);
        gl.disable(gl.STENCIL_TEST);
        gl.disable(gl.BLEND);
        gl.disable(gl.DITHER);
        gl.disable(gl.POLYGON_OFFSET_FILL);
        gl.disable(gl.SAMPLE_COVERAGE);
        gl.enable(gl.SCISSOR_TEST);
        gl.enable(gl.CULL_FACE);
        gl.cullFace(gl.BACK);
        return contexts[webGLVersion];
    }
78864 function createCanvas(webGLVersion) {
78865 if (typeof OffscreenCanvas !== 'undefined' && webGLVersion === 2) {
78866 return new OffscreenCanvas(300, 150);
78867 }
78868 else if (typeof document !== 'undefined') {
78869 return document.createElement('canvas');
78870 }
78871 else {
78872 throw new Error('Cannot create a canvas in this context');
78873 }
78874 }
78875 function getWebGLRenderingContext(webGLVersion, customCanvas) {
78876 if (webGLVersion !== 1 && webGLVersion !== 2) {
78877 throw new Error('Cannot get WebGL rendering context, WebGL is disabled.');
78878 }
78879 const canvas = customCanvas == null ? createCanvas(webGLVersion) : customCanvas;
78880 canvas.addEventListener('webglcontextlost', (ev) => {
78881 ev.preventDefault();
78882 delete contexts[webGLVersion];
78883 }, false);
78884 if (webGLVersion === 1) {
78885 return (canvas.getContext('webgl', WEBGL_ATTRIBUTES) ||
78886 canvas.getContext('experimental-webgl', WEBGL_ATTRIBUTES));
78887 }
78888 return canvas.getContext('webgl2', WEBGL_ATTRIBUTES);
78889 }
78890
78891 /**
78892 * @license
78893 * Copyright 2017 Google LLC. All Rights Reserved.
78894 * Licensed under the Apache License, Version 2.0 (the "License");
78895 * you may not use this file except in compliance with the License.
78896 * You may obtain a copy of the License at
78897 *
78898 * http://www.apache.org/licenses/LICENSE-2.0
78899 *
78900 * Unless required by applicable law or agreed to in writing, software
78901 * distributed under the License is distributed on an "AS IS" BASIS,
78902 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78903 * See the License for the specific language governing permissions and
78904 * limitations under the License.
78905 * =============================================================================
78906 */
78907 var PackingScheme;
78908 (function (PackingScheme) {
78909 /**
78910 * All values in a single texel are densely packed without any constraints.
78911 *
78912 * This is how the shader encodes a tensor with shape = [2, 3, 4]
78913 * (indices are [batch, row, col]).
78914 *
78915 * 000|001 010|011 020|021
78916 * ------- ------- -------
78917 * 002|003 012|013 022|023
78918 *
78919 * 100|101 110|111 120|121
78920 * ------- ------- -------
78921 * 102|103 112|113 122|123
78922 *
78923 */
78924 PackingScheme[PackingScheme["DENSE"] = 0] = "DENSE";
78925 /**
78926 * Single texels contain only values from the same batch, and from adjacent
78927 * rows and columns.
78928 *
78929 * This is how the shader encodes a tensor with shape = [2, 3, 5]
78930 * (indices are [batch, row, col]).
78931 *
78932 * 000|001 002|003 004|xxx 020|021 022|023 024|xxx
78933 * ------- ------- ------- ------- ------- -------
78934 * 010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx
78935 *
78936 * 100|101 102|103 104|xxx 120|121 122|123 124|xxx
78937 * ------- ------- ------- ------- ------- -------
78938 * 110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx
78939 *
78940 */
78941 PackingScheme[PackingScheme["SHARED_BATCH"] = 1] = "SHARED_BATCH";
78942 })(PackingScheme || (PackingScheme = {}));
78943 var TextureUsage;
78944 (function (TextureUsage) {
78945 TextureUsage[TextureUsage["RENDER"] = 0] = "RENDER";
78946 TextureUsage[TextureUsage["UPLOAD"] = 1] = "UPLOAD";
78947 TextureUsage[TextureUsage["PIXELS"] = 2] = "PIXELS";
78948 TextureUsage[TextureUsage["DOWNLOAD"] = 3] = "DOWNLOAD";
78949 })(TextureUsage || (TextureUsage = {}));
78950 var PhysicalTextureType;
78951 (function (PhysicalTextureType) {
78952 PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT16"] = 0] = "UNPACKED_FLOAT16";
78953 PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT32"] = 1] = "UNPACKED_FLOAT32";
78954 PhysicalTextureType[PhysicalTextureType["PACKED_4X1_UNSIGNED_BYTE"] = 2] = "PACKED_4X1_UNSIGNED_BYTE";
78955 PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT32"] = 3] = "PACKED_2X2_FLOAT32";
78956 PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT16"] = 4] = "PACKED_2X2_FLOAT16";
78957 })(PhysicalTextureType || (PhysicalTextureType = {}));
78958 function getUnpackedMatrixTextureShapeWidthHeight(rows, columns) {
78959 return [columns, rows];
78960 }
78961 function getUnpackedArraySizeFromMatrixSize(matrixSize, channelsPerTexture) {
78962 return matrixSize * channelsPerTexture;
78963 }
78964 function getColorMatrixTextureShapeWidthHeight(rows, columns) {
78965 return [columns * 4, rows];
78966 }
78967 /**
78968 * Get shape for densely packed RGBA texture.
78969 */
78970 function getDenseTexShape(shape) {
78971 const size = sizeFromShape(shape);
78972 const texelsNeeded = Math.ceil(size / 4);
78973 return sizeToSquarishShape(texelsNeeded);
78974 }
78975 function getMatrixSizeFromUnpackedArraySize(unpackedSize, channelsPerTexture) {
78976 if (unpackedSize % channelsPerTexture !== 0) {
78977 throw new Error(`unpackedSize (${unpackedSize}) must be a multiple of ` +
78978 `${channelsPerTexture}`);
78979 }
78980 return unpackedSize / channelsPerTexture;
78981 }
78982 function decodeMatrixFromUnpackedColorRGBAArray(unpackedArray, matrix, channels) {
78983 const requiredSize = unpackedArray.length * channels / 4;
78984 if (matrix.length < requiredSize) {
78985 throw new Error(`matrix length (${matrix.length}) must be >= ${requiredSize}`);
78986 }
78987 let dst = 0;
78988 for (let src = 0; src < unpackedArray.length; src += 4) {
78989 for (let c = 0; c < channels; c++) {
78990 matrix[dst++] = unpackedArray[src + c];
78991 }
78992 }
78993 }
78994 function getPackedMatrixTextureShapeWidthHeight(rows, columns) {
78995 return [
78996 Math.max(1, Math.ceil(columns / 2)), Math.max(1, Math.ceil(rows / 2))
78997 ];
78998 }
78999 function getPackedRGBAArraySizeFromMatrixShape(rows, columns) {
79000 const [w, h] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
79001 return w * h * 4;
79002 }
79003 function getTextureConfig(
79004 // tslint:disable-next-line:no-any
79005 gl, textureHalfFloatExtension) {
79006 // tslint:disable-next-line:no-any
79007 const glany = gl;
79008 let internalFormatFloat;
79009 let internalFormatHalfFloat;
79010 let internalFormatPackedHalfFloat;
79011 let internalFormatPackedFloat;
79012 let textureFormatFloat;
79013 let downloadTextureFormat;
79014 let downloadUnpackNumChannels;
79015 let defaultNumChannels;
79016 let textureTypeHalfFloat;
79017 let textureTypeFloat;
79018 if (env().getNumber('WEBGL_VERSION') === 2) {
79019 internalFormatFloat = glany.R32F;
79020 internalFormatHalfFloat = glany.R16F;
79021 internalFormatPackedHalfFloat = glany.RGBA16F;
79022 internalFormatPackedFloat = glany.RGBA32F;
79023 textureFormatFloat = glany.RED;
79024 downloadUnpackNumChannels = 4;
79025 defaultNumChannels = 1;
79026 textureTypeHalfFloat = glany.HALF_FLOAT;
79027 textureTypeFloat = glany.FLOAT;
79028 downloadTextureFormat = glany.RGBA8;
79029 }
79030 else {
79031 internalFormatFloat = gl.RGBA;
79032 internalFormatHalfFloat = gl.RGBA;
79033 internalFormatPackedHalfFloat = gl.RGBA;
79034 internalFormatPackedFloat = glany.RGBA;
79035 textureFormatFloat = gl.RGBA;
79036 downloadUnpackNumChannels = 4;
79037 defaultNumChannels = 4;
79038 textureTypeHalfFloat = textureHalfFloatExtension != null ?
79039 textureHalfFloatExtension.HALF_FLOAT_OES :
79040 null;
79041 textureTypeFloat = gl.FLOAT;
79042 downloadTextureFormat = gl.RGBA;
79043 }
79044 return {
79045 internalFormatFloat,
79046 internalFormatHalfFloat,
79047 internalFormatPackedHalfFloat,
79048 internalFormatPackedFloat,
79049 textureFormatFloat,
79050 downloadTextureFormat,
79051 downloadUnpackNumChannels,
79052 defaultNumChannels,
79053 textureTypeHalfFloat,
79054 textureTypeFloat
79055 };
79056 }
79057
79058 /**
79059 * @license
79060 * Copyright 2017 Google LLC. All Rights Reserved.
79061 * Licensed under the Apache License, Version 2.0 (the "License");
79062 * you may not use this file except in compliance with the License.
79063 * You may obtain a copy of the License at
79064 *
79065 * http://www.apache.org/licenses/LICENSE-2.0
79066 *
79067 * Unless required by applicable law or agreed to in writing, software
79068 * distributed under the License is distributed on an "AS IS" BASIS,
79069 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
79070 * See the License for the specific language governing permissions and
79071 * limitations under the License.
79072 * =============================================================================
79073 */
79074 function callAndCheck(gl, func) {
79075 const returnValue = func();
79076 if (env().getBool('DEBUG')) {
79077 checkWebGLError(gl);
79078 }
79079 return returnValue;
79080 }
79081 function checkWebGLError(gl) {
79082 const error = gl.getError();
79083 if (error !== gl.NO_ERROR) {
79084 throw new Error('WebGL Error: ' + getWebGLErrorMessage(gl, error));
79085 }
79086 }
79087 // https://en.wikipedia.org/wiki/Half-precision_floating-point_format
79088 const MIN_FLOAT16 = 5.96e-8;
79089 const MAX_FLOAT16 = 65504;
79090 function canBeRepresented(num) {
79091 if (env().getBool('WEBGL_RENDER_FLOAT32_ENABLED') || num === 0 ||
79092 (MIN_FLOAT16 < Math.abs(num) && Math.abs(num) < MAX_FLOAT16)) {
79093 return true;
79094 }
79095 return false;
79096 }
79097 function getWebGLErrorMessage(gl, status) {
79098 switch (status) {
79099 case gl.NO_ERROR:
79100 return 'NO_ERROR';
79101 case gl.INVALID_ENUM:
79102 return 'INVALID_ENUM';
79103 case gl.INVALID_VALUE:
79104 return 'INVALID_VALUE';
79105 case gl.INVALID_OPERATION:
79106 return 'INVALID_OPERATION';
79107 case gl.INVALID_FRAMEBUFFER_OPERATION:
79108 return 'INVALID_FRAMEBUFFER_OPERATION';
79109 case gl.OUT_OF_MEMORY:
79110 return 'OUT_OF_MEMORY';
79111 case gl.CONTEXT_LOST_WEBGL:
79112 return 'CONTEXT_LOST_WEBGL';
79113 default:
79114 return `Unknown error code ${status}`;
79115 }
79116 }
79117 function getExtensionOrThrow(gl, extensionName) {
79118 return throwIfNull(gl, () => gl.getExtension(extensionName), 'Extension "' + extensionName + '" not supported on this browser.');
79119 }
79120 function createVertexShader(gl, vertexShaderSource) {
79121 const vertexShader = throwIfNull(gl, () => gl.createShader(gl.VERTEX_SHADER), 'Unable to create vertex WebGLShader.');
79122 callAndCheck(gl, () => gl.shaderSource(vertexShader, vertexShaderSource));
79123 callAndCheck(gl, () => gl.compileShader(vertexShader));
79124 if (gl.getShaderParameter(vertexShader, gl.COMPILE_STATUS) === false) {
79125 console.log(gl.getShaderInfoLog(vertexShader));
79126 throw new Error('Failed to compile vertex shader.');
79127 }
79128 return vertexShader;
79129 }
79130 function createFragmentShader(gl, fragmentShaderSource) {
79131 const fragmentShader = throwIfNull(gl, () => gl.createShader(gl.FRAGMENT_SHADER), 'Unable to create fragment WebGLShader.');
79132 callAndCheck(gl, () => gl.shaderSource(fragmentShader, fragmentShaderSource));
79133 callAndCheck(gl, () => gl.compileShader(fragmentShader));
79134 if (env().get('ENGINE_COMPILE_ONLY')) {
79135 return fragmentShader;
79136 }
79137 if (gl.getShaderParameter(fragmentShader, gl.COMPILE_STATUS) === false) {
79138 logShaderSourceAndInfoLog(fragmentShaderSource, gl.getShaderInfoLog(fragmentShader));
79139 throw new Error('Failed to compile fragment shader.');
79140 }
79141 return fragmentShader;
79142 }
79143 const lineNumberRegex = /ERROR: [0-9]+:([0-9]+):/g;
79144 function logShaderSourceAndInfoLog(shaderSource, shaderInfoLog) {
79145 const lineNumberRegexResult = lineNumberRegex.exec(shaderInfoLog);
79146 if (lineNumberRegexResult == null) {
79147 console.log(`Couldn't parse line number in error: ${shaderInfoLog}`);
79148 console.log(shaderSource);
79149 return;
79150 }
79151 const lineNumber = +lineNumberRegexResult[1];
79152 const shaderLines = shaderSource.split('\n');
79153 const pad = shaderLines.length.toString().length + 2;
79154 const linesWithLineNumbers = shaderLines.map((line, lineNumber) => rightPad((lineNumber + 1).toString(), pad) + line);
79155 let maxLineLength = 0;
79156 for (let i = 0; i < linesWithLineNumbers.length; i++) {
79157 maxLineLength = Math.max(linesWithLineNumbers[i].length, maxLineLength);
79158 }
79159 const beforeErrorLines = linesWithLineNumbers.slice(0, lineNumber - 1);
79160 const errorLine = linesWithLineNumbers.slice(lineNumber - 1, lineNumber);
79161 const afterErrorLines = linesWithLineNumbers.slice(lineNumber);
79162 console.log(beforeErrorLines.join('\n'));
79163 console.log(shaderInfoLog.split('\n')[0]);
79164 console.log(`%c ${rightPad(errorLine[0], maxLineLength)}`, 'border:1px solid red; background-color:#e3d2d2; color:#a61717');
79165 console.log(afterErrorLines.join('\n'));
79166 }
79167 function createProgram(gl) {
79168 return throwIfNull(gl, () => gl.createProgram(), 'Unable to create WebGLProgram.');
79169 }
79170 function linkProgram(gl, program) {
79171 callAndCheck(gl, () => gl.linkProgram(program));
79172 if (env().get('ENGINE_COMPILE_ONLY')) {
79173 return;
79174 }
79175 if (gl.getProgramParameter(program, gl.LINK_STATUS) === false) {
79176 console.log(gl.getProgramInfoLog(program));
79177 throw new Error('Failed to link vertex and fragment shaders.');
79178 }
79179 }
79180 function validateProgram(gl, program) {
79181 callAndCheck(gl, () => gl.validateProgram(program));
79182 if (gl.getProgramParameter(program, gl.VALIDATE_STATUS) === false) {
79183 console.log(gl.getProgramInfoLog(program));
79184 throw new Error('Shader program validation failed.');
79185 }
79186 }
79187 function createStaticVertexBuffer(gl, data) {
79188 const buffer = throwIfNull(gl, () => gl.createBuffer(), 'Unable to create WebGLBuffer');
79189 callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, buffer));
79190 callAndCheck(gl, () => gl.bufferData(gl.ARRAY_BUFFER, data, gl.STATIC_DRAW));
79191 return buffer;
79192 }
79193 function createStaticIndexBuffer(gl, data) {
79194 const buffer = throwIfNull(gl, () => gl.createBuffer(), 'Unable to create WebGLBuffer');
79195 callAndCheck(gl, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, buffer));
79196 callAndCheck(gl, () => gl.bufferData(gl.ELEMENT_ARRAY_BUFFER, data, gl.STATIC_DRAW));
79197 return buffer;
79198 }
79199 function getNumChannels() {
79200 if (env().getNumber('WEBGL_VERSION') === 2) {
79201 return 1;
79202 }
79203 return 4;
79204 }
79205 function createTexture(gl) {
79206 return throwIfNull(gl, () => gl.createTexture(), 'Unable to create WebGLTexture.');
79207 }
79208 function validateTextureSize(width, height) {
79209 const maxTextureSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
79210 if ((width <= 0) || (height <= 0)) {
79211 const requested = `[${width}x${height}]`;
79212 throw new Error('Requested texture size ' + requested + ' is invalid.');
79213 }
79214 if ((width > maxTextureSize) || (height > maxTextureSize)) {
79215 const requested = `[${width}x${height}]`;
79216 const max = `[${maxTextureSize}x${maxTextureSize}]`;
79217 throw new Error('Requested texture size ' + requested +
79218 ' greater than WebGL maximum on this browser / GPU ' + max + '.');
79219 }
79220 }
79221 function createFramebuffer(gl) {
79222 return throwIfNull(gl, () => gl.createFramebuffer(), 'Unable to create WebGLFramebuffer.');
79223 }
79224 function bindVertexBufferToProgramAttribute(gl, program, attribute, buffer, arrayEntriesPerItem, itemStrideInBytes, itemOffsetInBytes) {
79225 const loc = gl.getAttribLocation(program, attribute);
79226 if (loc === -1) {
79227 // The GPU compiler decided to strip out this attribute because it's unused,
79228 // thus no need to bind.
79229 return false;
79230 }
79231 callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, buffer));
79232 callAndCheck(gl, () => gl.vertexAttribPointer(loc, arrayEntriesPerItem, gl.FLOAT, false, itemStrideInBytes, itemOffsetInBytes));
79233 callAndCheck(gl, () => gl.enableVertexAttribArray(loc));
79234 return true;
79235 }
79236 function bindTextureUnit(gl, texture, textureUnit) {
79237 validateTextureUnit(gl, textureUnit);
79238 callAndCheck(gl, () => gl.activeTexture(gl.TEXTURE0 + textureUnit));
79239 callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture));
79240 }
79241 function unbindTextureUnit(gl, textureUnit) {
79242 validateTextureUnit(gl, textureUnit);
79243 callAndCheck(gl, () => gl.activeTexture(gl.TEXTURE0 + textureUnit));
79244 callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
79245 }
79246 function getProgramUniformLocationOrThrow(gl, program, uniformName) {
79247 return throwIfNull(gl, () => gl.getUniformLocation(program, uniformName), 'uniform "' + uniformName + '" not present in program.');
79248 }
79249 function getProgramUniformLocation(gl, program, uniformName) {
79250 return gl.getUniformLocation(program, uniformName);
79251 }
79252 function bindTextureToProgramUniformSampler(gl, texture, uniformSamplerLocation, textureUnit) {
79253 callAndCheck(gl, () => bindTextureUnit(gl, texture, textureUnit));
79254 callAndCheck(gl, () => gl.uniform1i(uniformSamplerLocation, textureUnit));
79255 }
79256 function bindCanvasToFramebuffer(gl) {
79257 callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, null));
79258 callAndCheck(gl, () => gl.viewport(0, 0, gl.canvas.width, gl.canvas.height));
79259 callAndCheck(gl, () => gl.scissor(0, 0, gl.canvas.width, gl.canvas.height));
79260 }
79261 function bindColorTextureToFramebuffer(gl, texture, framebuffer) {
79262 callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer));
79263 callAndCheck(gl, () => gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0));
79264 }
79265 function unbindColorTextureFromFramebuffer(gl, framebuffer) {
79266 callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer));
79267 callAndCheck(gl, () => gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, null, 0));
79268 }
79269 function validateFramebuffer(gl) {
79270 const status = gl.checkFramebufferStatus(gl.FRAMEBUFFER);
79271 if (status !== gl.FRAMEBUFFER_COMPLETE) {
79272 throw new Error('Error binding framebuffer: ' + getFramebufferErrorMessage(gl, status));
79273 }
79274 }
79275 function getFramebufferErrorMessage(gl, status) {
79276 switch (status) {
79277 case gl.FRAMEBUFFER_INCOMPLETE_ATTACHMENT:
79278 return 'FRAMEBUFFER_INCOMPLETE_ATTACHMENT';
79279 case gl.FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT:
79280 return 'FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT';
79281 case gl.FRAMEBUFFER_INCOMPLETE_DIMENSIONS:
79282 return 'FRAMEBUFFER_INCOMPLETE_DIMENSIONS';
79283 case gl.FRAMEBUFFER_UNSUPPORTED:
79284 return 'FRAMEBUFFER_UNSUPPORTED';
79285 default:
79286 return `unknown error ${status}`;
79287 }
79288 }
79289 function throwIfNull(gl, returnTOrNull, failureMessage) {
79290 const tOrNull = callAndCheck(gl, () => returnTOrNull());
79291 if (tOrNull == null) {
79292 throw new Error(failureMessage);
79293 }
79294 return tOrNull;
79295 }
    /**
     * Throws when `textureUnit` (a 0-based unit index) falls outside the
     * accepted range for this context.
     *
     * NOTE(review): this reads the GLenum constant
     * gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS directly rather than
     * gl.getParameter(gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS), so the upper
     * bound is the enum's numeric value, not the device's actual unit count
     * — presumably an intentionally lax sanity check; confirm upstream.
     */
    function validateTextureUnit(gl, textureUnit) {
        const maxTextureUnit = gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS - 1;
        const glTextureUnit = textureUnit + gl.TEXTURE0;
        if (glTextureUnit < gl.TEXTURE0 || glTextureUnit > maxTextureUnit) {
            const textureUnitRange = `[gl.TEXTURE0, gl.TEXTURE${maxTextureUnit}]`;
            throw new Error(`textureUnit must be in ${textureUnitRange}.`);
        }
    }
79304 function getBatchDim(shape, dimsToSkip = 2) {
79305 return sizeFromShape(shape.slice(0, shape.length - dimsToSkip));
79306 }
79307 function getRowsCols(shape) {
79308 if (shape.length === 0) {
79309 throw Error('Cannot get rows and columns of an empty shape array.');
79310 }
79311 return [
79312 shape.length > 1 ? shape[shape.length - 2] : 1, shape[shape.length - 1]
79313 ];
79314 }
79315 function getShapeAs3D(shape) {
79316 let shapeAs3D = [1, 1, 1];
79317 const isScalar = shape.length === 0 || (shape.length === 1 && shape[0] === 1);
79318 if (!isScalar) {
79319 shapeAs3D =
79320 [getBatchDim(shape), ...getRowsCols(shape)];
79321 }
79322 return shapeAs3D;
79323 }
    /**
     * Picks a physical 2D texture shape [rows, cols] for a tensor with
     * logical shape `logShape`, preferring layouts that keep whole logical
     * dimensions contiguous and falling back to a squarish shape when no
     * dimension grouping fits within the max texture size.
     */
    function getTextureShapeFromLogicalShape(logShape, isPacked = false) {
        let maxTexSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
        if (isPacked) {
            // Each packed texel covers 2x2 values, so twice the logical
            // extent fits per physical dimension.
            maxTexSize = maxTexSize * 2;
            // This logic ensures we accurately count the number of packed texels needed
            // to accommodate the tensor. We can only pack values in the same texel if
            // they are from adjacent pairs of rows/cols within the same batch. So if a
            // tensor has 3 rows, we pretend it has 4 rows in order to account for the
            // fact that the texels containing the third row are half empty.
            logShape = logShape.map((d, i) => i >= logShape.length - 2 ?
                nearestLargerEven(logShape[i]) :
                logShape[i]);
            // Packed texture height is at least 2 (the channel height of a single
            // texel).
            if (logShape.length === 1) {
                logShape = [2, logShape[0]];
            }
        }
        // If logical shape is 2, we don't squeeze, since we want to match physical.
        if (logShape.length !== 2) {
            const squeezeResult = squeezeShape(logShape);
            logShape = squeezeResult.newShape;
        }
        let size = sizeFromShape(logShape);
        // Try groupings of the logical dims into two axes, preferring the
        // first that fits; ranks 0/1 become a single row.
        if (logShape.length <= 1 && size <= maxTexSize) {
            return [1, size];
        }
        else if (logShape.length === 2 && logShape[0] <= maxTexSize &&
            logShape[1] <= maxTexSize) {
            return logShape;
        }
        else if (logShape.length === 3 && logShape[0] * logShape[1] <= maxTexSize &&
            logShape[2] <= maxTexSize) {
            return [logShape[0] * logShape[1], logShape[2]];
        }
        else if (logShape.length === 3 && logShape[0] <= maxTexSize &&
            logShape[1] * logShape[2] <= maxTexSize) {
            return [logShape[0], logShape[1] * logShape[2]];
        }
        else if (logShape.length === 4 &&
            logShape[0] * logShape[1] * logShape[2] <= maxTexSize &&
            logShape[3] <= maxTexSize) {
            return [logShape[0] * logShape[1] * logShape[2], logShape[3]];
        }
        else if (logShape.length === 4 && logShape[0] <= maxTexSize &&
            logShape[1] * logShape[2] * logShape[3] <= maxTexSize) {
            return [logShape[0], logShape[1] * logShape[2] * logShape[3]];
        }
        else {
            if (isPacked) {
                // For packed textures size equals the number of channels required to
                // accommodate the texture data. However in order to squarify such that
                // inner dimensions stay even, we rewrite size to equal the number of
                // texels. Then in the return statement we rehydrate the squarified
                // dimensions to channel units.
                const batchDim = getBatchDim(logShape);
                let rows = 2, cols = 2;
                if (logShape.length) {
                    [rows, cols] = getRowsCols(logShape);
                }
                size = batchDim * (rows / 2) * (cols / 2);
                return sizeToSquarishShape(size).map(d => d * 2);
            }
            return sizeToSquarishShape(size);
        }
    }
79390 function isEven(n) {
79391 return n % 2 === 0;
79392 }
    /**
     * This determines whether reshaping a packed texture requires rearranging
     * the data within the texture, assuming 2x2 packing.
     *
     * Only the trailing two dims matter, since a 2x2-packed texel spans the
     * last two logical dims. Returns true when the packed texel layout of
     * `shape1` can be reinterpreted as `shape2` without moving data.
     */
    function isReshapeFree(shape1, shape2) {
        shape1 = shape1.slice(-2);
        shape2 = shape2.slice(-2);
        // Identical inner shapes trivially share a layout.
        if (arraysEqual(shape1, shape2)) {
            return true;
        }
        if (!shape1.length || !shape2.length) { // One of the shapes is a scalar.
            return true;
        }
        // A zero-sized dim means there is no data to rearrange at all.
        if (shape1[0] === 0 || shape1[1] === 0 || shape2[0] === 0 ||
            shape2[1] === 0) {
            return true;
        }
        if (shape1.length !== shape2.length) { // One of the shapes is a vector.
            // A vector and a matrix pack identically when they agree on the
            // column count, or when both column counts are even and the
            // other dim is a single (half-filled) texel row.
            const shape1Cols = shape1.slice(-1)[0];
            const shape2Cols = shape2.slice(-1)[0];
            if (shape1Cols === shape2Cols) {
                return true;
            }
            if (isEven(shape1Cols) && isEven(shape2Cols) &&
                (shape1[0] === 1 || shape2[0] === 1)) {
                return true;
            }
        }
        // Two matrices: same column count and fully-filled texel rows on
        // both sides means row blocks line up texel-for-texel.
        return shape1[1] === shape2[1] && isEven(shape1[0]) && isEven(shape2[0]);
    }
    // We cache webgl params because the environment gets reset between
    // unit tests and we don't want to constantly query the WebGLContext for
    // MAX_TEXTURE_SIZE.
    let MAX_TEXTURE_SIZE;
    let MAX_TEXTURES_IN_SHADER;
    /** Queries (and caches) the device's MAX_TEXTURE_SIZE. */
    function getWebGLMaxTextureSize(webGLVersion) {
        if (MAX_TEXTURE_SIZE == null) {
            const gl = getWebGLContext(webGLVersion);
            MAX_TEXTURE_SIZE = gl.getParameter(gl.MAX_TEXTURE_SIZE);
        }
        return MAX_TEXTURE_SIZE;
    }
    /** Clears the cached MAX_TEXTURE_SIZE (used when the env is reset). */
    function resetMaxTextureSize() {
        MAX_TEXTURE_SIZE = null;
    }
    /** Clears the cached MAX_TEXTURE_IMAGE_UNITS (used when the env is reset). */
    function resetMaxTexturesInShader() {
        MAX_TEXTURES_IN_SHADER = null;
    }
    /** Queries (and caches) the texture-unit count available to a shader. */
    function getMaxTexturesInShader(webGLVersion) {
        if (MAX_TEXTURES_IN_SHADER == null) {
            const gl = getWebGLContext(webGLVersion);
            MAX_TEXTURES_IN_SHADER = gl.getParameter(gl.MAX_TEXTURE_IMAGE_UNITS);
        }
        // We cap at 16 to avoid spurious runtime "memory exhausted" error.
        return Math.min(16, MAX_TEXTURES_IN_SHADER);
    }
79449 function getWebGLDisjointQueryTimerVersion(webGLVersion) {
79450 if (webGLVersion === 0) {
79451 return 0;
79452 }
79453 let queryTimerVersion;
79454 const gl = getWebGLContext(webGLVersion);
79455 if (hasExtension(gl, 'EXT_disjoint_timer_query_webgl2') &&
79456 webGLVersion === 2) {
79457 queryTimerVersion = 2;
79458 }
79459 else if (hasExtension(gl, 'EXT_disjoint_timer_query')) {
79460 queryTimerVersion = 1;
79461 }
79462 else {
79463 queryTimerVersion = 0;
79464 }
79465 return queryTimerVersion;
79466 }
79467 function hasExtension(gl, extensionName) {
79468 const ext = gl.getExtension(extensionName);
79469 return ext != null;
79470 }
79471 function isWebGLVersionEnabled(webGLVersion) {
79472 try {
79473 const gl = getWebGLContext(webGLVersion);
79474 if (gl != null) {
79475 return true;
79476 }
79477 }
79478 catch (e) {
79479 console.log('Error when getting WebGL context: ', e);
79480 return false;
79481 }
79482 return false;
79483 }
79484 function isCapableOfRenderingToFloatTexture(webGLVersion) {
79485 if (webGLVersion === 0) {
79486 return false;
79487 }
79488 const gl = getWebGLContext(webGLVersion);
79489 if (webGLVersion === 1) {
79490 if (!hasExtension(gl, 'OES_texture_float')) {
79491 return false;
79492 }
79493 }
79494 else {
79495 if (!hasExtension(gl, 'EXT_color_buffer_float')) {
79496 return false;
79497 }
79498 }
79499 const isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl);
79500 return isFrameBufferComplete;
79501 }
79502 /**
79503 * Check if we can download values from a float/half-float texture.
79504 *
79505 * Note that for performance reasons we use binding a texture to a framebuffer
79506 * as a proxy for ability to download float values later using readPixels. The
79507 * texture params of this texture will not match those in readPixels exactly
79508 * but if we are unable to bind some kind of float texture to the frameBuffer
79509 * then we definitely will not be able to read float values from it.
79510 */
    function isDownloadFloatTextureEnabled(webGLVersion) {
        if (webGLVersion === 0) {
            return false;
        }
        const gl = getWebGLContext(webGLVersion);
        if (webGLVersion === 1) {
            // WebGL1 needs both the float-texture and float-render-target
            // extensions before the framebuffer probe is worth running.
            if (!hasExtension(gl, 'OES_texture_float')) {
                return false;
            }
            if (!hasExtension(gl, 'WEBGL_color_buffer_float')) {
                return false;
            }
        }
        else {
            // WebGL2: prefer full float render targets, fall back to
            // half-float render targets, otherwise give up.
            if (hasExtension(gl, 'EXT_color_buffer_float')) {
                return createFloatTextureAndBindToFramebuffer(gl);
            }
            const COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';
            if (hasExtension(gl, COLOR_BUFFER_HALF_FLOAT)) {
                const textureHalfFloatExtension = gl.getExtension(COLOR_BUFFER_HALF_FLOAT);
                return createHalfFloatTextureAndBindToFramebuffer(gl, textureHalfFloatExtension);
            }
            return false;
        }
        // WebGL1 path: verify a float texture actually completes an FBO.
        const isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl);
        return isFrameBufferComplete;
    }
    /**
     * Probe: allocates a 1x1 float texture, attaches it to a framebuffer,
     * and reports whether the framebuffer is complete. All GL objects are
     * unbound and deleted before returning, so context state is unchanged.
     */
    function createFloatTextureAndBindToFramebuffer(gl) {
        const texConfig = getTextureConfig(gl);
        const texture = gl.createTexture();
        gl.bindTexture(gl.TEXTURE_2D, texture);
        const width = 1;
        const height = 1;
        gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeFloat, null);
        const frameBuffer = gl.createFramebuffer();
        gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
        gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
        const isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;
        // Restore bindings and free the probe objects.
        gl.bindTexture(gl.TEXTURE_2D, null);
        gl.bindFramebuffer(gl.FRAMEBUFFER, null);
        gl.deleteTexture(texture);
        gl.deleteFramebuffer(frameBuffer);
        return isFrameBufferComplete;
    }
    /**
     * Probe: same as createFloatTextureAndBindToFramebuffer but using the
     * half-float internal format/type resolved from the provided extension.
     * Cleans up all GL objects before returning.
     */
    function createHalfFloatTextureAndBindToFramebuffer(
    // tslint:disable-next-line:no-any
    gl, textureHalfFloatExtension) {
        const texConfig = getTextureConfig(gl, textureHalfFloatExtension);
        const texture = gl.createTexture();
        gl.bindTexture(gl.TEXTURE_2D, texture);
        const width = 1;
        const height = 1;
        gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatHalfFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeHalfFloat, null);
        const frameBuffer = gl.createFramebuffer();
        gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
        gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
        const isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;
        // Restore bindings and free the probe objects.
        gl.bindTexture(gl.TEXTURE_2D, null);
        gl.bindFramebuffer(gl.FRAMEBUFFER, null);
        gl.deleteTexture(texture);
        gl.deleteFramebuffer(frameBuffer);
        return isFrameBufferComplete;
    }
/**
 * Returns whether the fence API (`gl.fenceSync`) is available. fenceSync is
 * only ever present on WebGL 2 contexts, so other versions short-circuit.
 */
function isWebGLFenceEnabled(webGLVersion) {
    if (webGLVersion !== 2) {
        return false;
    }
    const gl = getWebGLContext(webGLVersion);
    // tslint:disable-next-line:no-any
    return gl.fenceSync != null;
}
/**
 * Asserts that none of the given tensors (a single tensor or an array; null
 * entries are skipped) has dtype complex64, since the WebGL backend op named
 * by `opName` cannot handle complex inputs.
 */
function assertNotComplex$1(tensor, opName) {
    const tensors = Array.isArray(tensor) ? tensor : [tensor];
    for (const t of tensors) {
        if (t != null) {
            assert(t.dtype !== 'complex64', () => `${opName} does not support complex64 tensors ` +
                'in the WebGL backend.');
        }
    }
}
79594
79595 /**
79596 * @license
79597 * Copyright 2019 Google LLC. All Rights Reserved.
79598 * Licensed under the Apache License, Version 2.0 (the "License");
79599 * you may not use this file except in compliance with the License.
79600 * You may obtain a copy of the License at
79601 *
79602 * http://www.apache.org/licenses/LICENSE-2.0
79603 *
79604 * Unless required by applicable law or agreed to in writing, software
79605 * distributed under the License is distributed on an "AS IS" BASIS,
79606 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
79607 * See the License for the specific language governing permissions and
79608 * limitations under the License.
79609 * =============================================================================
79610 */
const ENV$3 = env();
/**
 * This section contains the WebGL-specific flag registrations. Each flag has
 * a lazy evaluation function (run once, on first read) and optionally a
 * validator run when the flag is set explicitly.
 */
/**
 * True if WebGL is supported.
 */
ENV$3.registerFlag('HAS_WEBGL', () => ENV$3.getNumber('WEBGL_VERSION') > 0);
/** 0: No WebGL, 1: WebGL 1.0, 2: WebGL 2.0. */
ENV$3.registerFlag('WEBGL_VERSION', () => {
    // Prefer the highest WebGL version for which a context can be created.
    if (isWebGLVersionEnabled(2)) {
        return 2;
    }
    else if (isWebGLVersionEnabled(1)) {
        return 1;
    }
    return 0;
});
/** Whether to check for numerical representation problems. */
ENV$3.registerFlag('WEBGL_CHECK_NUMERICAL_PROBLEMS', () => false);
/** Whether WebGL buffer support is available (requires a WebGL 2 context). */
ENV$3.registerFlag('WEBGL_BUFFER_SUPPORTED', () => ENV$3.get('WEBGL_VERSION') === 2);
/** Whether the WebGL backend will sometimes forward ops to the CPU. */
ENV$3.registerFlag('WEBGL_CPU_FORWARD', () => true);
/** Whether the WebGL backend will always use f16 textures for rendering. */
ENV$3.registerFlag('WEBGL_FORCE_F16_TEXTURES', () => false);
/** Whether to turn all packing related flags on. */
ENV$3.registerFlag('WEBGL_PACK', () => ENV$3.getBool('HAS_WEBGL'));
/** Whether we will pack the batchnormalization op. */
ENV$3.registerFlag('WEBGL_PACK_NORMALIZATION', () => ENV$3.getBool('WEBGL_PACK'));
/** Whether we will pack the clip op. */
ENV$3.registerFlag('WEBGL_PACK_CLIP', () => ENV$3.getBool('WEBGL_PACK'));
/** Whether we will pack the depthwise conv op. */
ENV$3.registerFlag('WEBGL_PACK_DEPTHWISECONV', () => ENV$3.getBool('WEBGL_PACK'));
/** Whether we will pack binary ops. */
ENV$3.registerFlag('WEBGL_PACK_BINARY_OPERATIONS', () => ENV$3.getBool('WEBGL_PACK'));
/** Whether we will pack unary ops. */
ENV$3.registerFlag('WEBGL_PACK_UNARY_OPERATIONS', () => ENV$3.getBool('WEBGL_PACK'));
/** Whether we will pack array ops. */
ENV$3.registerFlag('WEBGL_PACK_ARRAY_OPERATIONS', () => ENV$3.getBool('WEBGL_PACK'));
/** Whether we will pack image ops. */
ENV$3.registerFlag('WEBGL_PACK_IMAGE_OPERATIONS', () => ENV$3.getBool('WEBGL_PACK'));
/** Whether we will pack reduce ops. */
ENV$3.registerFlag('WEBGL_PACK_REDUCE', () => ENV$3.getBool('WEBGL_PACK'));
/** Whether packed WebGL kernels lazily unpack their outputs. */
ENV$3.registerFlag('WEBGL_LAZILY_UNPACK', () => ENV$3.getBool('WEBGL_PACK'));
/** Whether we will use the im2col algorithm to speed up convolutions. */
ENV$3.registerFlag('WEBGL_CONV_IM2COL', () => ENV$3.getBool('WEBGL_PACK'));
/** The maximum texture dimension. */
ENV$3.registerFlag('WEBGL_MAX_TEXTURE_SIZE', () => getWebGLMaxTextureSize(ENV$3.getNumber('WEBGL_VERSION')));
/** The maximum number of textures usable in a single shader. */
ENV$3.registerFlag('WEBGL_MAX_TEXTURES_IN_SHADER', () => getMaxTexturesInShader(ENV$3.getNumber('WEBGL_VERSION')));
/**
 * The disjoint_query_timer extension version.
 * 0: disabled, 1: EXT_disjoint_timer_query, 2:
 * EXT_disjoint_timer_query_webgl2.
 * In Firefox with WebGL 2.0,
 * EXT_disjoint_timer_query_webgl2 is not available, so we must use the
 * WebGL 1.0 extension.
 */
ENV$3.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', () => {
    const webGLVersion = ENV$3.getNumber('WEBGL_VERSION');
    if (webGLVersion === 0) {
        return 0;
    }
    return getWebGLDisjointQueryTimerVersion(webGLVersion);
});
/**
 * Whether the timer object from the disjoint_query_timer extension gives
 * timing information that is reliable.
 */
ENV$3.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', () => ENV$3.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0 &&
    !isMobile());
/**
 * Whether the device is physically capable of rendering to float32 textures.
 */
ENV$3.registerFlag('WEBGL_RENDER_FLOAT32_CAPABLE', () => isCapableOfRenderingToFloatTexture(ENV$3.getNumber('WEBGL_VERSION')));
/**
 * Whether rendering to float32 textures is enabled. If disabled, renders to
 * float16 textures.
 */
ENV$3.registerFlag('WEBGL_RENDER_FLOAT32_ENABLED', () => {
    // Forcing f16 textures overrides the device capability check.
    return ENV$3.getBool('WEBGL_FORCE_F16_TEXTURES') ?
        false :
        ENV$3.getBool('WEBGL_RENDER_FLOAT32_CAPABLE');
});
/**
 * Whether downloading float textures is enabled (16 or 32 bit). If disabled,
 * uses IEEE 754 encoding of the float32 values to 4 uint8 when downloading.
 */
ENV$3.registerFlag('WEBGL_DOWNLOAD_FLOAT_ENABLED', () => isDownloadFloatTextureEnabled(ENV$3.getNumber('WEBGL_VERSION')));
/** Whether the fence API is available. */
ENV$3.registerFlag('WEBGL_FENCE_API_ENABLED', () => isWebGLFenceEnabled(ENV$3.getNumber('WEBGL_VERSION')));
/**
 * Tensors with size <= than this will be uploaded as uniforms, not textures.
 */
ENV$3.registerFlag('WEBGL_SIZE_UPLOAD_UNIFORM', () => {
    // Use uniform uploads only when 32bit floats are supported. In 16bit
    // environments there are problems with comparing a 16bit texture value
    // with a 32bit uniform value.
    const useUniforms = ENV$3.getBool('WEBGL_RENDER_FLOAT32_ENABLED');
    return useUniforms ? 4 : 0;
});
/**
 * If the total number of bytes allocated on the GPU is greater than this
 * number, we will aggressively delete textures upon disposal, rather than
 * making them available for reuse.
 *
 * Default value -1 indicates that we will never aggressively delete textures.
 */
ENV$3.registerFlag('WEBGL_DELETE_TEXTURE_THRESHOLD', () => {
    return -1;
}, threshold => {
    // Validator: only -1 (never delete) or a non-negative byte count is legal.
    if (threshold < 0 && threshold !== -1) {
        throw new Error(`WEBGL_DELETE_TEXTURE_THRESHOLD must be -1 (indicating never ` +
            `delete) or at least 0, but got ${threshold}.`);
    }
});
/**
 * Trigger a manual GL command flush if the threshold of time has passed since
 * the previous kernel execution. This can be useful for Android devices where
 * GL command flushes are delayed until the end of the javascript task. This
 * value is measured in milliseconds. Typically you want to set this value
 * close to 1.
 *
 * Default value 1 for mobile chrome, and -1 for all other cases. -1 indicates
 * that we will not enforce manual flush and depend on the system default
 * flush schedule.
 */
ENV$3.registerFlag('WEBGL_FLUSH_THRESHOLD', () => {
    return isMobile() ? 1 : -1;
}, threshold => {
    // Validator: only -1 (never flush manually) or a non-negative time is legal.
    if (threshold < 0 && threshold !== -1) {
        throw new Error(`WEBGL_FLUSH_THRESHOLD must be -1 (indicating never ` +
            `manual flush) or at least 0, but got ${threshold}.`);
    }
});
/**
 * Threshold for input tensor size that determines whether WebGL backend will
 * delegate computation to CPU.
 *
 * Default value is 128.
 */
ENV$3.registerFlag('CPU_HANDOFF_SIZE_THRESHOLD', () => 128);
/** Whether we will use shapes uniforms. */
ENV$3.registerFlag('WEBGL_USE_SHAPES_UNIFORMS', () => false);
/**
 * Threshold for last dimension of input tensor that determines whether
 * WebGL backend for the Top K op will delegate computation to CPU. If input
 * is smaller than threshold then CPU will be used
 *
 * Default value is 100000.
 */
ENV$3.registerFlag('TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD', () => 100000);
/**
 * Threshold for K that determines whether
 * WebGL backend for the Top K op will delegate computation to CPU. If k
 * is larger than threshold then CPU will be used
 *
 * Default value is 128.
 */
ENV$3.registerFlag('TOPK_K_CPU_HANDOFF_THRESHOLD', () => 128);
79771
79772 /**
79773 * @license
79774 * Copyright 2018 Google LLC. All Rights Reserved.
79775 * Licensed under the Apache License, Version 2.0 (the "License");
79776 * you may not use this file except in compliance with the License.
79777 * You may obtain a copy of the License at
79778 *
79779 * http://www.apache.org/licenses/LICENSE-2.0
79780 *
79781 * Unless required by applicable law or agreed to in writing, software
79782 * distributed under the License is distributed on an "AS IS" BASIS,
79783 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
79784 * See the License for the specific language governing permissions and
79785 * limitations under the License.
79786 * =============================================================================
79787 */
/**
 * Returns the GLSL fragments that differ between WebGL 1 (GLSL ES 1.0) and
 * WebGL 2 (GLSL ES 3.0): the version pragma, attribute/varying keywords, the
 * texture sampling builtin, the fragment output variable/declaration, and
 * polyfills for isnan / isinf / round.
 */
function getGlslDifferences() {
    if (env().getNumber('WEBGL_VERSION') === 2) {
        return {
            version: '#version 300 es',
            attribute: 'in',
            varyingVs: 'out',
            varyingFs: 'in',
            texture2D: 'texture',
            output: 'outputColor',
            defineOutput: 'out vec4 outputColor;',
            // Use a custom isnan definition to work across differences between
            // implementations on various platforms. While this should happen in
            // ANGLE we still see differences between android and windows (on
            // chrome) when using isnan directly. Since WebGL2 supports the uint
            // type and the floatBitsToUint built-in, isnan can follow the
            // IEEE 754 rules directly:
            // - sign = either 0 or 1.
            // - biased exponent = all 1 bits.
            // - fraction = anything except all 0 bits (all 0 bits is infinity).
            // https://en.wikipedia.org/wiki/IEEE_754-1985#Representation_of_non-numbers
            defineSpecialNaN: `
      bool isnan_custom(float val) {
        uint floatToUint = floatBitsToUint(val);
        return (floatToUint & 0x7fffffffu) > 0x7f800000u;
      }

      bvec4 isnan_custom(vec4 val) {
        return bvec4(isnan_custom(val.x),
          isnan_custom(val.y), isnan_custom(val.z), isnan_custom(val.w));
      }

      #define isnan(value) isnan_custom(value)
    `,
            // In webgl 2 we do not need to specify a custom isinf so there is
            // no need for a special INFINITY constant.
            defineSpecialInf: ``,
            defineRound: `
      #define round(value) newRound(value)
      int newRound(float value) {
        return int(floor(value + 0.5));
      }

      ivec4 newRound(vec4 value) {
        return ivec4(floor(value + vec4(0.5)));
      }
    `
        };
    }
    return {
        version: '',
        attribute: 'attribute',
        varyingVs: 'varying',
        varyingFs: 'varying',
        texture2D: 'texture2D',
        output: 'gl_FragColor',
        defineOutput: '',
        // WebGL1 has no built in isnan so we define one here.
        defineSpecialNaN: `
      #define isnan(value) isnan_custom(value)
      bool isnan_custom(float val) {
        return (val > 0. || val < 1. || val == 0.) ? false : true;
      }
      bvec4 isnan_custom(vec4 val) {
        return bvec4(isnan(val.x), isnan(val.y), isnan(val.z), isnan(val.w));
      }
    `,
        defineSpecialInf: `
      uniform float INFINITY;

      bool isinf(float val) {
        return abs(val) == INFINITY;
      }
      bvec4 isinf(vec4 val) {
        return equal(abs(val), vec4(INFINITY));
      }
    `,
        defineRound: `
      int round(float value) {
        return int(floor(value + 0.5));
      }

      ivec4 round(vec4 value) {
        return ivec4(floor(value + vec4(0.5)));
      }
    `
    };
}
79897
79898 /**
79899 * @license
79900 * Copyright 2018 Google LLC. All Rights Reserved.
79901 * Licensed under the Apache License, Version 2.0 (the "License");
79902 * you may not use this file except in compliance with the License.
79903 * You may obtain a copy of the License at
79904 *
79905 * http://www.apache.org/licenses/LICENSE-2.0
79906 *
79907 * Unless required by applicable law or agreed to in writing, software
79908 * distributed under the License is distributed on an "AS IS" BASIS,
79909 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
79910 * See the License for the specific language governing permissions and
79911 * limitations under the License.
79912 * =============================================================================
79913 */
79914 /**
79915 * Produces GLSL code that derives logical coordinates from a flat
79916 * index. The code performs integer division with each stride and decrements
79917 * the index until the index equals the final dimension coordinate.
79918 */
/**
 * Produces GLSL code that derives logical coordinates from a flat index. The
 * code performs integer division with each stride and decrements the index
 * until the index equals the final dimension coordinate.
 */
function getLogicalCoordinatesFromFlatIndex(coords, shape, index = 'index') {
    const strides = computeStrides(shape);
    const snippets = [];
    for (let i = 0; i < strides.length; i++) {
        const divide = `int ${coords[i]} = ${index} / ${strides[i]}`;
        // The last stride yields two coordinates; earlier ones reduce `index`.
        const reduce = i === strides.length - 1 ?
            `int ${coords[i + 1]} = ${index} - ${coords[i]} * ${strides[i]}` :
            `index -= ${coords[i]} * ${strides[i]}`;
        snippets.push(`${divide}; ${reduce};`);
    }
    return snippets.join('');
}
/**
 * Like getLogicalCoordinatesFromFlatIndex, but reads the strides from the
 * `outShapeStrides` uniform instead of baking literal stride values into the
 * generated GLSL.
 */
function getOutputLogicalCoordinatesFromFlatIndexByUniform(coords, shape, index = 'index') {
    const strides = computeStrides(shape);
    const snippets = [];
    for (let i = 0; i < strides.length; i++) {
        const divide = `int ${coords[i]} = ${index} / outShapeStrides[${i}]`;
        const reduce = i === strides.length - 1 ?
            `int ${coords[i + 1]} = ${index} - ${coords[i]} * outShapeStrides[${i}]` :
            `index -= ${coords[i]} * outShapeStrides[${i}]`;
        snippets.push(`${divide}; ${reduce};`);
    }
    return snippets.join('');
}
/**
 * Produces GLSL expressions that compute strides symbolically from a shape
 * uniform named `variableName`, indexed by the entries of `indicesArr`.
 * Returns one expression per stride (rank - 1 entries), outermost first.
 */
function symbolicallyComputeStrides(indicesArr, variableName) {
    const rank = indicesArr.length;
    const dims = indicesArr.map(d => `${variableName}[${d}]`);
    const strides = new Array(rank - 1);
    // The innermost stride is just the size of the last dimension; each outer
    // stride is the product of the next stride and the next dimension.
    strides[rank - 2] = dims[rank - 1];
    for (let i = rank - 3; i >= 0; --i) {
        strides[i] = `(${strides[i + 1]} * ${dims[i + 1]})`;
    }
    return strides;
}
/**
 * Derives logical coordinates from a flat index using strides computed
 * symbolically (in GLSL) from the shape uniform `variableName`.
 */
function getLogicalCoordinatesFromFlatIndexByUniform(coords, variableName, index = 'index') {
    const strides = symbolicallyComputeStrides(coords.map((_, i) => i), variableName);
    const snippets = [];
    for (let i = 0; i < strides.length; i++) {
        const divide = `int ${coords[i]} = ${index} / ${strides[i]}`;
        const reduce = i === strides.length - 1 ?
            `int ${coords[i + 1]} = ${index} - ${coords[i]} * ${strides[i]}` :
            `index -= ${coords[i]} * ${strides[i]}`;
        snippets.push(`${divide}; ${reduce};`);
    }
    return snippets.join('');
}
/**
 * Formats a list of GLSL expressions as a single expression: the bare value
 * for one element, otherwise a `vecN(...)` constructor.
 */
function buildVec(x) {
    return x.length === 1 ? `${x[0]}` : `vec${x.length}(${x.join(',')})`;
}
/**
 * Produces GLSL code that computes the dot product of the input x and y
 * vectors. Handles splitting inputs into increments of vec4s when necessary.
 * @param {string[]} x GLSL expressions for the first vector's components.
 * @param {string[]} y GLSL expressions for the second vector's components.
 * @returns {string} A GLSL expression summing `dot(...)` calls.
 * @throws {Error} When x and y have different lengths.
 */
function dotify(x, y) {
    if (x.length !== y.length) {
        // Fixed message: previously the two fragments concatenated to
        // "...same length -got N and M" (missing space).
        throw new Error(`Vectors to be dotted must be of the same length - ` +
            `got ${x.length} and ${y.length}`);
    }
    const slices = [];
    const nearestVec4 = Math.floor(x.length / 4);
    const nearestVec4Remainder = x.length % 4;
    // Full vec4 chunks.
    for (let i = 0; i < nearestVec4; i++) {
        const xSlice = x.slice(i * 4, i * 4 + 4);
        const ySlice = y.slice(i * 4, i * 4 + 4);
        slices.push(`${buildVec(xSlice)}, ${buildVec(ySlice)}`);
    }
    // Trailing chunk of 1-3 components.
    if (nearestVec4Remainder !== 0) {
        let xSlice = x.slice(nearestVec4 * 4);
        let ySlice = y.slice(nearestVec4 * 4);
        if (xSlice.length === 1) {
            // dot() needs at least float operands; wrap scalars explicitly.
            xSlice = xSlice.map(d => `float(${d})`);
            ySlice = ySlice.map(d => `float(${d})`);
        }
        slices.push(`${buildVec(xSlice)}, ${buildVec(ySlice)}`);
    }
    return slices.map(d => `dot(${d})`).join('+');
}
/**
 * Produces GLSL that computes the flat index from 3D coordinates, with the
 * strides of `shape` baked in as integer literals.
 */
function getFlatIndexFrom3D(shape) {
    const [strideRow, strideCol] = computeStrides(shape).map(d => d.toString());
    return `
  int getFlatIndex(ivec3 coords) {
    return coords.x * ${strideRow} + coords.y * ${strideCol} + coords.z;
  }
`;
}
// Produces GLSL that computes the flat index from 3D coordinates, reading the
// strides from the `outShapeStrides` uniform instead of baking them in.
function getFlatIndexFrom3DOutput() {
    return `
  int getFlatIndex(ivec3 coords) {
    return coords.x * outShapeStrides[0] + coords.y * outShapeStrides[1] + coords.z;
  }
`;
}
// GLSL that encodes a highp float into four uint8 channels following the
// IEEE 754 single-precision layout (mantissa in c[0..2], biased exponent
// split across c[2]/c[3], sign in c[3]'s top bit). Used to download float
// data on contexts where float readback is unavailable; NaN, +/-Infinity and
// denormals (flushed to zero) are special-cased up front.
const ENCODE_FLOAT_SNIPPET = `
  const float FLOAT_MAX = 1.70141184e38;
  const float FLOAT_MIN = 1.17549435e-38;

  lowp vec4 encode_float(highp float v) {
    if (isnan(v)) {
      return vec4(255, 255, 255, 255);
    }

    highp float av = abs(v);

    if(av < FLOAT_MIN) {
      return vec4(0.0, 0.0, 0.0, 0.0);
    } else if(v > FLOAT_MAX) {
      return vec4(0.0, 0.0, 128.0, 127.0) / 255.0;
    } else if(v < -FLOAT_MAX) {
      return vec4(0.0, 0.0, 128.0, 255.0) / 255.0;
    }

    highp vec4 c = vec4(0,0,0,0);

    highp float e = floor(log2(av));
    highp float m = exp2(fract(log2(av))) - 1.0;

    c[2] = floor(128.0 * m);
    m -= c[2] / 128.0;
    c[1] = floor(32768.0 * m);
    m -= c[1] / 32768.0;
    c[0] = floor(8388608.0 * m);

    highp float ebias = e + 127.0;
    c[3] = floor(ebias / 2.0);
    ebias -= c[3] * 2.0;
    c[2] += floor(ebias) * 128.0;

    c[3] += 128.0 * step(0.0, -v);

    return c / 255.0;
  }
`;
80059
80060 /**
80061 * @license
80062 * Copyright 2017 Google LLC. All Rights Reserved.
80063 * Licensed under the Apache License, Version 2.0 (the "License");
80064 * you may not use this file except in compliance with the License.
80065 * You may obtain a copy of the License at
80066 *
80067 * http://www.apache.org/licenses/LICENSE-2.0
80068 *
80069 * Unless required by applicable law or agreed to in writing, software
80070 * distributed under the License is distributed on an "AS IS" BASIS,
80071 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
80072 * See the License for the specific language governing permissions and
80073 * limitations under the License.
80074 * =============================================================================
80075 */
80076 const { getBroadcastDims: getBroadcastDims$1 } = backend_util;
/**
 * Assembles the complete fragment-shader source for a program: uniform
 * declarations for every input (plus optional shape/texture-shape uniforms),
 * the shared shader prefix, output-coordinate helpers, per-input samplers and
 * finally the program's own user code, joined in that order.
 */
function makeShader(inputsInfo, outputShape, program) {
    // GLSL integer vector type per rank; index 0 is unused.
    const INT_TYPES = ['', 'int', 'ivec2', 'ivec3', 'ivec4'];
    const prefixSnippets = [];
    for (const x of inputsInfo) {
        const size = sizeFromShape(x.shapeInfo.logicalShape);
        if (x.shapeInfo.isUniform) {
            // The input's values were uploaded as a uniform, not a texture.
            prefixSnippets.push(`uniform float ${x.name}${size > 1 ? `[${size}]` : ''};`);
        }
        else {
            prefixSnippets.push(`uniform sampler2D ${x.name};`);
            prefixSnippets.push(`uniform int offset${x.name};`);
        }
        if (program.enableShapeUniforms) {
            const { uniformShape } = getUniformInfoFromShape(program.packedInputs, x.shapeInfo.logicalShape, x.shapeInfo.texShape);
            // Ranks outside 1-4 get no shape uniform (matches original switch).
            if (uniformShape.length >= 1 && uniformShape.length <= 4) {
                prefixSnippets.push(`uniform ${INT_TYPES[uniformShape.length]} ${x.name}Shape;`);
            }
            prefixSnippets.push(`uniform ivec2 ${x.name}TexShape;`);
        }
    }
    if (program.enableShapeUniforms) {
        const outRank = outputShape.logicalShape.length;
        if (outRank >= 1 && outRank <= 4) {
            prefixSnippets.push(`uniform ${INT_TYPES[outRank]} outShape;`);
            // Rank-1 outputs need no stride uniform.
            if (outRank >= 2) {
                prefixSnippets.push(`uniform ${INT_TYPES[outRank - 1]} outShapeStrides;`);
            }
        }
        prefixSnippets.push(`uniform ivec2 outTexShape;`);
    }
    if (program.customUniforms) {
        for (const d of program.customUniforms) {
            prefixSnippets.push(`uniform ${d.type} ${d.name}${d.arrayIndex ? `[${d.arrayIndex}]` : ''};`);
        }
    }
    const inputPrefixSnippet = prefixSnippets.join('\n');
    const inputSamplingSnippet = inputsInfo
        .map(x => getInputSamplingSnippet(x, outputShape, program.packedInputs, program.enableShapeUniforms))
        .join('\n');
    const glsl = getGlslDifferences();
    const floatTextureSampleSnippet = getFloatTextureSampleSnippet(glsl);
    let shaderPrefix = getShaderPrefix(glsl);
    let outputSamplingSnippet;
    let floatTextureSetOutputSnippet;
    if (outputShape.isPacked) {
        outputSamplingSnippet = getPackedOutputSamplingSnippet(outputShape.logicalShape, outputShape.texShape, program.enableShapeUniforms);
        floatTextureSetOutputSnippet = getFloatTextureSetRGBASnippet(glsl);
    }
    else {
        outputSamplingSnippet = getOutputSamplingSnippet(outputShape.logicalShape, outputShape.texShape, program.enableShapeUniforms);
        floatTextureSetOutputSnippet = getFloatTextureSetRSnippet(glsl);
    }
    if (program.packedInputs) {
        shaderPrefix += SHADER_PACKED_PREFIX;
    }
    return [
        shaderPrefix, floatTextureSampleSnippet, floatTextureSetOutputSnippet,
        inputPrefixSnippet, outputSamplingSnippet, inputSamplingSnippet,
        program.userCode
    ].join('\n');
}
/**
 * Returns the GLSL sampler function for an (unpacked) input, dispatching on
 * the input's logical rank. Ranks 5 and 6 do not support shape uniforms.
 */
function getSamplerFromInInfo(inInfo, enableShapeUniforms = false) {
    const rank = inInfo.shapeInfo.logicalShape.length;
    if (rank === 0) {
        return getSamplerScalar(inInfo, enableShapeUniforms);
    }
    if (rank === 1) {
        return getSampler1D(inInfo, enableShapeUniforms);
    }
    if (rank === 2) {
        return getSampler2D(inInfo, enableShapeUniforms);
    }
    if (rank === 3) {
        return getSampler3D(inInfo, enableShapeUniforms);
    }
    if (rank === 4) {
        return getSampler4D(inInfo, enableShapeUniforms);
    }
    if (rank === 5) {
        return getSampler5D(inInfo);
    }
    if (rank === 6) {
        return getSampler6D(inInfo);
    }
    throw new Error(`${rank}-D input sampling` +
        ` is not yet supported`);
}
/**
 * Returns the GLSL sampler function for a packed input, dispatching on the
 * input's logical rank; ranks above 3 share the generic N-D sampler.
 */
function getPackedSamplerFromInInfo(inInfo, enableShapeUniforms) {
    const rank = inInfo.shapeInfo.logicalShape.length;
    if (rank === 0) {
        return getPackedSamplerScalar(inInfo);
    }
    if (rank === 1) {
        return getPackedSampler1D(inInfo, enableShapeUniforms);
    }
    if (rank === 2) {
        return getPackedSampler2D(inInfo, enableShapeUniforms);
    }
    if (rank === 3) {
        return getPackedSampler3D(inInfo, enableShapeUniforms);
    }
    return getPackedSamplerND(inInfo, enableShapeUniforms);
}
/**
 * Builds the sampling GLSL for one input: its rank-specific sampler plus,
 * when the input's rank does not exceed the output's (i.e. it may broadcast
 * onto the output), a sampler that reads directly at the output coordinates.
 */
function getInputSamplingSnippet(inInfo, outShapeInfo, usesPackedTextures = false, enableShapeUniforms) {
    let res = usesPackedTextures ?
        getPackedSamplerFromInInfo(inInfo, enableShapeUniforms) :
        getSamplerFromInInfo(inInfo, enableShapeUniforms);
    const inRank = inInfo.shapeInfo.logicalShape.length;
    const outRank = outShapeInfo.logicalShape.length;
    if (inRank <= outRank) {
        res += usesPackedTextures ?
            getPackedSamplerAtOutputCoords(inInfo, outShapeInfo) :
            getSamplerAtOutputCoords(inInfo, outShapeInfo);
    }
    return res;
}
/**
 * Returns the getOutputCoords() GLSL for a packed output, dispatching on the
 * output rank; ranks above 3 share the generic N-D variant.
 */
function getPackedOutputSamplingSnippet(outShape, outTexShape, enableShapeUniforms) {
    const rank = outShape.length;
    if (rank === 0) {
        return getOutputScalarCoords();
    }
    if (rank === 1) {
        return getOutputPacked1DCoords(outShape, outTexShape, enableShapeUniforms);
    }
    if (rank === 2) {
        return getOutputPacked2DCoords(outShape, outTexShape, enableShapeUniforms);
    }
    if (rank === 3) {
        return getOutputPacked3DCoords(outShape, outTexShape, enableShapeUniforms);
    }
    return getOutputPackedNDCoords(outShape, outTexShape, enableShapeUniforms);
}
/**
 * Returns the getOutputCoords() GLSL for an unpacked output, dispatching on
 * the output rank. Ranks 5 and 6 do not support shape uniforms; higher ranks
 * are unsupported.
 */
function getOutputSamplingSnippet(outShape, outTexShape, enableShapeUniforms) {
    const rank = outShape.length;
    if (rank === 0) {
        return getOutputScalarCoords();
    }
    if (rank === 1) {
        return getOutput1DCoords(outShape, outTexShape, enableShapeUniforms);
    }
    if (rank === 2) {
        return getOutput2DCoords(outShape, outTexShape, enableShapeUniforms);
    }
    if (rank === 3) {
        return getOutput3DCoords(outShape, outTexShape, enableShapeUniforms);
    }
    if (rank === 4) {
        return getOutput4DCoords(outShape, outTexShape, enableShapeUniforms);
    }
    if (rank === 5) {
        return getOutput5DCoords(outShape, outTexShape);
    }
    if (rank === 6) {
        return getOutput6DCoords(outShape, outTexShape);
    }
    throw new Error(`${rank}-D output sampling is not yet supported`);
}
/**
 * GLSL helper that reads the red channel of a texture at the given UV, using
 * the version-appropriate sampling builtin from `glsl`.
 */
function getFloatTextureSampleSnippet(glsl) {
    const { texture2D } = glsl;
    return `
  float sampleTexture(sampler2D textureSampler, vec2 uv) {
    return ${texture2D}(textureSampler, uv).r;
  }
`;
}
/**
 * GLSL setOutput for unpacked outputs: writes the scalar into the red channel
 * of the version-appropriate output variable.
 */
function getFloatTextureSetRSnippet(glsl) {
    const { output } = glsl;
    return `
  void setOutput(float val) {
    ${output} = vec4(val, 0, 0, 0);
  }
`;
}
/**
 * GLSL setOutput for packed outputs: writes all four channels at once to the
 * version-appropriate output variable.
 */
function getFloatTextureSetRGBASnippet(glsl) {
    const { output } = glsl;
    return `
  void setOutput(vec4 val) {
    ${output} = val;
  }
`;
}
/**
 * Builds the common prologue shared by every generated fragment shader:
 * version pragma and precision qualifiers, the resultUV varying and output
 * declaration, ivec5/ivec6 struct types (GLSL has no built-in 5/6-component
 * integer vectors), the NAN uniform plus isnan/isinf/round polyfills from
 * `glsl`, integer mod/div helpers, a hash-based pseudo-random generator and
 * the shared flat-index-to-UV lookup snippets.
 */
function getShaderPrefix(glsl) {
    const SHADER_PREFIX = `${glsl.version}
    precision highp float;
    precision highp int;
    precision highp sampler2D;
    ${glsl.varyingFs} vec2 resultUV;
    ${glsl.defineOutput}
    const vec2 halfCR = vec2(0.5, 0.5);

    struct ivec5
    {
      int x;
      int y;
      int z;
      int w;
      int u;
    };

    struct ivec6
    {
      int x;
      int y;
      int z;
      int w;
      int u;
      int v;
    };

    uniform float NAN;
    ${glsl.defineSpecialNaN}
    ${glsl.defineSpecialInf}
    ${glsl.defineRound}

    int imod(int x, int y) {
      return x - y * (x / y);
    }

    int idiv(int a, int b, float sign) {
      int res = a / b;
      int mod = imod(a, b);
      if (sign < 0. && mod != 0) {
        res -= 1;
      }
      return res;
    }

    //Based on the work of Dave Hoskins
    //https://www.shadertoy.com/view/4djSRW
    #define HASHSCALE1 443.8975
    float random(float seed){
      vec2 p = resultUV * seed;
      vec3 p3 = fract(vec3(p.xyx) * HASHSCALE1);
      p3 += dot(p3, p3.yzx + 19.19);
      return fract((p3.x + p3.y) * p3.z);
    }

    ${SAMPLE_1D_SNIPPET}
    ${SAMPLE_2D_SNIPPET}
    ${SAMPLE_3D_SNIPPET}
  `;
    return SHADER_PREFIX;
}
// GLSL helpers mapping a flat element index to (u, v) texture coordinates,
// sampling at texel centers via the halfCR offset. The packed variant divides
// the index by 2 because each packed texel holds two logical columns.
const SAMPLE_1D_SNIPPET = `
vec2 uvFromFlat(int texNumR, int texNumC, int index) {
  int texR = index / texNumC;
  int texC = index - texR * texNumC;
  return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
}
vec2 packedUVfrom1D(int texNumR, int texNumC, int index) {
  int texelIndex = index / 2;
  int texR = texelIndex / texNumC;
  int texC = texelIndex - texR * texNumC;
  return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
}
`;
// UV lookup for a packed 2D tensor: each texel covers a 2x2 logical block, so
// row and col are halved before computing the texel index.
const SAMPLE_2D_SNIPPET = `
vec2 packedUVfrom2D(int texelsInLogicalRow, int texNumR,
  int texNumC, int row, int col) {
  int texelIndex = (row / 2) * texelsInLogicalRow + (col / 2);
  int texR = texelIndex / texNumC;
  int texC = texelIndex - texR * texNumC;
  return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
}
`;
// UV lookup for a packed 3D tensor: batch offset plus the 2x2-block layout of
// the inner two dimensions.
const SAMPLE_3D_SNIPPET = `
vec2 packedUVfrom3D(int texNumR, int texNumC,
  int texelsInBatch, int texelsInLogicalRow, int b,
  int row, int col) {
  int index = b * texelsInBatch + (row / 2) * texelsInLogicalRow + (col / 2);
  int texR = index / texNumC;
  int texC = index - texR * texNumC;
  return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
}
`;
// Helpers for extracting one logical value from a packed RGBA texel: the
// parity of the inner coordinates selects among the r/g/b/a channels.
const SHADER_PACKED_PREFIX = `
  float getChannel(vec4 frag, vec2 innerDims) {
    vec2 modCoord = mod(innerDims, 2.);
    return modCoord.x == 0. ?
      (modCoord.y == 0. ? frag.r : frag.g) :
      (modCoord.y == 0. ? frag.b : frag.a);
  }
  float getChannel(vec4 frag, int dim) {
    float modCoord = mod(float(dim), 2.);
    return modCoord == 0. ? frag.r : frag.g;
  }
`;
// getOutputCoords() for a rank-0 (scalar) output: there is a single element,
// so the flat coordinate is always 0.
function getOutputScalarCoords() {
    return `
  int getOutputCoords() {
    return 0;
  }
`;
}
/**
 * getOutputCoords() GLSL for a packed 1D output. The packed texture halves
 * each texture dimension (two logical values per texel), so the logical index
 * is 2x the texel index. When shape uniforms are enabled the texel counts are
 * computed in-shader from `outTexShape`; otherwise they are baked in.
 */
function getOutputPacked1DCoords(shape, texShape, enableShapeUniforms) {
    const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
    // Single packed row: the coordinate depends only on resultUV.x.
    if (packedTexShape[0] === 1) {
        const texelsPerRow = enableShapeUniforms ?
            'ceil(float(outTexShape[1]) / 2.0)' :
            `${packedTexShape[1]}.0`;
        return `
    int getOutputCoords() {
      return 2 * int(resultUV.x * ${texelsPerRow});
    }
  `;
    }
    // Single packed column: the coordinate depends only on resultUV.y.
    if (packedTexShape[1] === 1) {
        const texelsPerCol = enableShapeUniforms ?
            'ceil(float(outTexShape[0]) / 2.0)' :
            `${packedTexShape[0]}.0`;
        return `
    int getOutputCoords() {
      return 2 * int(resultUV.y * ${texelsPerCol});
    }
  `;
    }
    // General case: recover the texel row/col, then flatten.
    if (enableShapeUniforms) {
        return `
    int getOutputCoords() {
      ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));
      ivec2 resTexRC = ivec2(resultUV.yx *
                             vec2(packedTexShape[0], packedTexShape[1]));
      return 2 * (resTexRC.x * packedTexShape[1] + resTexRC.y);
    }
  `;
    }
    return `
    int getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
                             vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
      return 2 * (resTexRC.x * ${packedTexShape[1]} + resTexRC.y);
    }
  `;
}
/**
 * Emits GLSL `getOutputCoords()` for a rank-1 output in an unpacked texture:
 * one float per texel, coordinate = row-major flat index of the texel.
 *
 * @param shape Logical output shape (rank 1); unused — only texShape matters.
 * @param texShape [rows, cols] of the physical texture.
 * @param enableShapeUniforms Use the `outTexShape` uniform instead of baking
 *     texture dims into the shader source.
 */
function getOutput1DCoords(shape, texShape, enableShapeUniforms) {
  if (texShape[0] === 1) {
    // Single-row texture: x alone indexes the element.
    if (enableShapeUniforms) {
      return `
      int getOutputCoords() {
        return int(resultUV.x * float(outTexShape[1]));
      }
      `;
    }
    return `
      int getOutputCoords() {
        return int(resultUV.x * ${texShape[1]}.0);
      }
    `;
  }
  if (texShape[1] === 1) {
    // Single-column texture: y alone indexes the element.
    if (enableShapeUniforms) {
      return `
      int getOutputCoords() {
        return int(resultUV.y * float(outTexShape[0]));
      }
      `;
    }
    return `
      int getOutputCoords() {
        return int(resultUV.y * ${texShape[0]}.0);
      }
    `;
  }
  // General case: UV -> (texR, texC) -> row-major flat index.
  if (enableShapeUniforms) {
    return `
    int getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
        vec2(outTexShape[0], outTexShape[1]));
      return resTexRC.x * outTexShape[1] + resTexRC.y;
    }
    `;
  }
  return `
    int getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
                             vec2(${texShape[0]}, ${texShape[1]}));
      return resTexRC.x * ${texShape[1]} + resTexRC.y;
    }
  `;
}
/**
 * Emits GLSL `getOutputCoords()` for a rank-3 output in a packed texture.
 * Returns ivec3(b, r, c) where r and c are the even-aligned top-left corner
 * of the 2x2 block held by the current texel (batch dim is not packed).
 *
 * @param shape Logical [batch, rows, cols] output shape.
 * @param texShape [rows, cols] of the physical texture.
 * @param enableShapeUniforms Derive all sizes from the outShape/outTexShape
 *     uniforms instead of baking constants into the source.
 */
function getOutputPacked3DCoords(shape, texShape, enableShapeUniforms) {
  if (enableShapeUniforms) {
    return `
    ivec3 getOutputCoords() {
      ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));
      int texelsInLogicalRow = int(ceil(float(outShape[2]) / 2.0));
      int texelsInBatch = texelsInLogicalRow * int(ceil(float(outShape[1]) / 2.0));
      ivec2 resTexRC = ivec2(resultUV.yx *
                             vec2(packedTexShape[0], packedTexShape[1]));
      int index = resTexRC.x * packedTexShape[1] + resTexRC.y;

      int b = index / texelsInBatch;
      index -= b * texelsInBatch;

      int r = 2 * (index / texelsInLogicalRow);
      int c = imod(index, texelsInLogicalRow) * 2;

      return ivec3(b, r, c);
    }
    `;
  }
  // Static-shape variant: pre-compute the packed geometry in JS and inline it.
  const packedTexShape =
      [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
  const texelsInLogicalRow = Math.ceil(shape[2] / 2);
  const texelsInBatch = texelsInLogicalRow * Math.ceil(shape[1] / 2);
  return `
    ivec3 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
                             vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
      int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y;

      int b = index / ${texelsInBatch};
      index -= b * ${texelsInBatch};

      int r = 2 * (index / ${texelsInLogicalRow});
      int c = imod(index, ${texelsInLogicalRow}) * 2;

      return ivec3(b, r, c);
    }
  `;
}
/**
 * Emits GLSL `getOutputCoords()` for a rank-3 output in an unpacked texture:
 * flatten UV to a row-major index, then decompose it into (r, c, d) via the
 * snippet produced by the flat-index helper.
 *
 * @param shape Logical [r, c, d] output shape.
 * @param texShape [rows, cols] of the physical texture (unused in the
 *     uniform branch, where outTexShape is read instead).
 * @param enableShapeUniforms Use shape/texShape uniforms rather than baked
 *     constants.
 */
function getOutput3DCoords(shape, texShape, enableShapeUniforms) {
  if (enableShapeUniforms) {
    // Snippet declares ints r, c, d computed from `index` and the outShape
    // uniform.
    const coordsFromIndexSnippet =
        getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], shape);
    return `
  ivec3 getOutputCoords() {
    ivec2 resTexRC = ivec2(resultUV.yx *
      vec2(outTexShape[0], outTexShape[1]));
    int index = resTexRC.x * outTexShape[1] + resTexRC.y;
    ${coordsFromIndexSnippet}
    return ivec3(r, c, d);
  }
`;
  }
  // Static variant: strides are baked into the decomposition snippet.
  const coordsFromIndexSnippet =
      getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape);
  return `
    ivec3 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
                             vec2(${texShape[0]}, ${texShape[1]}));
      int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
      ${coordsFromIndexSnippet}
      return ivec3(r, c, d);
    }
  `;
}
/**
 * Emits GLSL `getOutputCoords()` for a rank>=4 output in a packed texture.
 * The two innermost dims are packed 2x2; every outer dim is an unpacked batch
 * dimension peeled off the flat texel index one at a time.
 *
 * @param shape Logical output shape, rank 4, 5 or 6 (uniform branch handles
 *     rank 4 only — see TODO).
 * @param texShape [rows, cols] of the physical texture.
 * @param enableShapeUniforms Use outShape/outTexShape uniforms.
 */
function getOutputPackedNDCoords(shape, texShape, enableShapeUniforms) {
  if (enableShapeUniforms) {
    // TODO: support 5d and 6d
    return `
    ivec4 getOutputCoords() {
      ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));
      ivec2 resTexRC = ivec2(resultUV.yx *
                             vec2(packedTexShape[0], packedTexShape[1]));
      int index = resTexRC.x * packedTexShape[1] + resTexRC.y;

      int texelsInLogicalRow = int(ceil(float(outShape[3]) / 2.0));
      int texelsInBatch = texelsInLogicalRow * int(ceil(float(outShape[2]) / 2.0));
      int texelsInBatchN = texelsInBatch * outShape[1];

      int b2 = index / texelsInBatchN;
      index -= b2 * texelsInBatchN;

      int b = index / texelsInBatch;
      index -= b * texelsInBatch;

      int r = 2 * (index / texelsInLogicalRow);
      int c = imod(index, texelsInLogicalRow) * 2;

      return ivec4(b2, b, r, c);
    }
    `;
  }
  const packedTexShape =
      [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
  const texelsInLogicalRow = Math.ceil(shape[shape.length - 1] / 2);
  const texelsInBatch =
      texelsInLogicalRow * Math.ceil(shape[shape.length - 2] / 2);
  let texelsInBatchN = texelsInBatch;
  let batches = ``;
  let coords = 'b, r, c';
  // Peel outer batch dims from the flat index, innermost-first; each new
  // extraction is prepended so the generated code runs outermost-first.
  for (let b = 2; b < shape.length - 1; b++) {
    texelsInBatchN *= shape[shape.length - b - 1];
    batches = `
      int b${b} = index / ${texelsInBatchN};
      index -= b${b} * ${texelsInBatchN};
    ` + batches;
    coords = `b${b}, ` + coords;
  }
  return `
    ivec${shape.length} getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
                             vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
      int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y;

      ${batches}

      int b = index / ${texelsInBatch};
      index -= b * ${texelsInBatch};

      int r = 2 * (index / ${texelsInLogicalRow});
      int c = imod(index, ${texelsInLogicalRow}) * 2;

      return ivec${shape.length}(${coords});
    }
  `;
}
/**
 * Emits GLSL `getOutputCoords()` for a rank-4 output in an unpacked texture:
 * flat row-major index decomposed into (r, c, d, d2).
 *
 * @param shape Logical [r, c, d, d2] output shape.
 * @param texShape [rows, cols] of the physical texture.
 * @param enableShapeUniforms Use shape/texShape uniforms rather than baked
 *     constants.
 */
function getOutput4DCoords(shape, texShape, enableShapeUniforms) {
  if (enableShapeUniforms) {
    const coordsFromIndexSnippet =
        getOutputLogicalCoordinatesFromFlatIndexByUniform(
            ['r', 'c', 'd', 'd2'], shape);
    return `
    ivec4 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
        vec2(outTexShape[0], outTexShape[1]));
      int index = resTexRC.x * outTexShape[1] + resTexRC.y;
      ${coordsFromIndexSnippet}
      return ivec4(r, c, d, d2);
    }
    `;
  }
  const coordsFromIndexSnippet =
      getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2'], shape);
  return `
    ivec4 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
        vec2(${texShape[0]}, ${texShape[1]}));
      int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
      ${coordsFromIndexSnippet}
      return ivec4(r, c, d, d2);
    }
  `;
}
/**
 * Emits GLSL `getOutputCoords()` for a rank-5 output in an unpacked texture.
 * No shape-uniform variant exists for rank 5 (cf. getOutput4DCoords).
 * NOTE(review): `ivec5` is presumably a struct declared in the shader
 * prelude elsewhere in this file — not a builtin GLSL type; confirm.
 */
function getOutput5DCoords(shape, texShape) {
  const coordsFromIndexSnippet =
      getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3'], shape);
  return `
    ivec5 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx * vec2(${texShape[0]},
                             ${texShape[1]}));

      int index = resTexRC.x * ${texShape[1]} + resTexRC.y;

      ${coordsFromIndexSnippet}

      ivec5 outShape = ivec5(r, c, d, d2, d3);
      return outShape;
    }
  `;
}
/**
 * Emits GLSL `getOutputCoords()` for a rank-6 output in an unpacked texture.
 * No shape-uniform variant exists for rank 6 (cf. getOutput4DCoords).
 * NOTE(review): `ivec6` is presumably a struct declared in the shader
 * prelude elsewhere in this file — not a builtin GLSL type; confirm.
 */
function getOutput6DCoords(shape, texShape) {
  const coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(
      ['r', 'c', 'd', 'd2', 'd3', 'd4'], shape);
  return `
    ivec6 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
        vec2(${texShape[0]}, ${texShape[1]}));
      int index = resTexRC.x * ${texShape[1]} + resTexRC.y;

      ${coordsFromIndexSnippet}

      ivec6 result = ivec6(r, c, d, d2, d3, d4);
      return result;
    }
  `;
}
/**
 * Emits GLSL `getOutputCoords()` for a rank-2 output in a packed texture.
 * Returns ivec2(r, c) — the even-aligned top-left corner of the 2x2 block
 * held by the current texel.
 *
 * @param shape Logical [rows, cols] output shape.
 * @param texShape [rows, cols] of the physical texture.
 * @param enableShapeUniforms Use outShape/outTexShape uniforms.
 */
function getOutputPacked2DCoords(shape, texShape, enableShapeUniforms) {
  const packedTexShape =
      [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
  if (arraysEqual(shape, texShape)) {
    // Logical shape matches the texture shape exactly, so the texel grid maps
    // directly onto the (halved) logical grid.
    if (enableShapeUniforms) {
      return `
      ivec2 getOutputCoords() {
        ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));
        return 2 * ivec2(resultUV.yx * vec2(packedTexShape[0], packedTexShape[1]));
      }
      `;
    }
    return `
      ivec2 getOutputCoords() {
        return 2 * ivec2(resultUV.yx * vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
      }
    `;
  }
  // texels needed to accommodate a logical row
  const texelsInLogicalRow = Math.ceil(shape[1] / 2);
  /**
   * getOutputCoords
   *
   * resTexRC: The rows and columns of the texels. If you move over one
   * texel to the right in the packed texture, you are moving over one column
   * (not two).
   *
   * index: The texel index
   */
  if (enableShapeUniforms) {
    return `
    ivec2 getOutputCoords() {
      ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));
      int texelsInLogicalRow = int(ceil(float(outShape[1]) / 2.0));
      ivec2 resTexRC = ivec2(resultUV.yx *
                             vec2(packedTexShape[0], packedTexShape[1]));

      int index = resTexRC.x * packedTexShape[1] + resTexRC.y;
      int r = 2 * (index / texelsInLogicalRow);
      int c = imod(index, texelsInLogicalRow) * 2;

      return ivec2(r, c);
    }
    `;
  }
  return `
    ivec2 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
                             vec2(${packedTexShape[0]}, ${packedTexShape[1]}));

      int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y;
      int r = 2 * (index / ${texelsInLogicalRow});
      int c = imod(index, ${texelsInLogicalRow}) * 2;

      return ivec2(r, c);
    }
  `;
}
/**
 * Emits GLSL `getOutputCoords()` for a rank-2 output in an unpacked texture.
 * Four specializations, cheapest first: shape === texShape (identity map),
 * single-column output, single-row output, then the general flat-index
 * decomposition.
 *
 * @param shape Logical [rows, cols] output shape.
 * @param texShape [rows, cols] of the physical texture.
 * @param enableShapeUniforms Use outShape/outTexShape uniforms.
 */
function getOutput2DCoords(shape, texShape, enableShapeUniforms) {
  if (arraysEqual(shape, texShape)) {
    // Texture layout equals the logical layout: UV maps directly to (r, c).
    if (enableShapeUniforms) {
      return `
      ivec2 getOutputCoords() {
        return ivec2(resultUV.yx * vec2(outTexShape[0], outTexShape[1]));
      }
      `;
    }
    return `
      ivec2 getOutputCoords() {
        return ivec2(resultUV.yx * vec2(${texShape[0]}, ${texShape[1]}));
      }
    `;
  }
  if (shape[1] === 1) {
    // One logical column: flat index is the row coordinate.
    if (enableShapeUniforms) {
      return `
      ivec2 getOutputCoords() {
        ivec2 resTexRC = ivec2(resultUV.yx *
                               vec2(outTexShape[0], outTexShape[1]));
        int index = resTexRC.x * outTexShape[1] + resTexRC.y;
        return ivec2(index, 0);
      }
      `;
    }
    return `
      ivec2 getOutputCoords() {
        ivec2 resTexRC = ivec2(resultUV.yx *
                               vec2(${texShape[0]}, ${texShape[1]}));
        int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
        return ivec2(index, 0);
      }
    `;
  }
  if (shape[0] === 1) {
    // One logical row: flat index is the column coordinate.
    if (enableShapeUniforms) {
      return `
      ivec2 getOutputCoords() {
        ivec2 resTexRC = ivec2(resultUV.yx *
                               vec2(outTexShape[0], outTexShape[1]));
        int index = resTexRC.x * outTexShape[1] + resTexRC.y;
        return ivec2(0, index);
      }
      `;
    }
    return `
      ivec2 getOutputCoords() {
        ivec2 resTexRC = ivec2(resultUV.yx *
                               vec2(${texShape[0]}, ${texShape[1]}));
        int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
        return ivec2(0, index);
      }
    `;
  }
  // General case: flat index divided by the logical row width.
  if (enableShapeUniforms) {
    return `
    ivec2 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
                             vec2(outTexShape[0], outTexShape[1]));
      int index = resTexRC.x * outTexShape[1] + resTexRC.y;
      int r = index / outShape[1];
      int c = index - r * outShape[1];
      return ivec2(r, c);
    }
    `;
  }
  return `
    ivec2 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
                             vec2(${texShape[0]}, ${texShape[1]}));
      int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
      int r = index / ${shape[1]};
      int c = index - r * ${shape[1]};
      return ivec2(r, c);
    }
  `;
}
/**
 * Name of the shader uniform that carries the flat offset of tensor
 * `texName` inside its backing texture.
 */
function getFlatOffsetUniformName(texName) {
  return 'offset' + texName;
}
/**
 * Emits the packed GLSL sampler for a scalar input: a single texel read at
 * the texture center (halfCR), returned as the full vec4.
 */
function getPackedSamplerScalar(inputInfo) {
  const {name: texName} = inputInfo;
  // Sampler functions follow the `getFoo` naming convention.
  const capitalized = texName.charAt(0).toUpperCase() + texName.slice(1);
  const {texture2D} = getGlslDifferences();
  return `
    vec4 get${capitalized}() {
      return ${texture2D}(${texName}, halfCR);
    }
  `;
}
/**
 * Emits the unpacked GLSL sampler for a scalar input.
 *
 * Cases, in order: the value is passed as a plain uniform (return it
 * directly); the texture is a single texel (read at halfCR); otherwise the
 * scalar lives at a flat offset inside a shared texture, located via
 * uvFromFlat with dimensions taken from uniforms or baked-in constants.
 *
 * Fix: the original destructured `inputInfo.shapeInfo.texShape` a second
 * time into a redundant `[tNumR, tNumC]` pair for the final branch; the
 * first destructuring is reused instead. Generated GLSL is unchanged.
 */
function getSamplerScalar(inputInfo, enableShapeUniforms) {
  const texName = inputInfo.name;
  const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
  if (inputInfo.shapeInfo.isUniform) {
    // Value arrives as a float uniform named after the tensor.
    return `float ${funcName}() {return ${texName};}`;
  }
  const [texNumR, texNumC] = inputInfo.shapeInfo.texShape;
  if (texNumR === 1 && texNumC === 1) {
    // 1x1 texture: the only texel sits at the center UV.
    return `
      float ${funcName}() {
        return sampleTexture(${texName}, halfCR);
      }
    `;
  }
  const offset = getFlatOffsetUniformName(texName);
  if (enableShapeUniforms) {
    return `
    float ${funcName}() {
      vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], ${offset});
      return sampleTexture(${texName}, uv);
    }
    `;
  }
  return `
    float ${funcName}() {
      vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, ${offset});
      return sampleTexture(${texName}, uv);
    }
  `;
}
/**
 * Emits the packed GLSL sampler for a rank-1 input: `getFoo(int index)`
 * returns the vec4 texel containing logical element `index`, located via the
 * packedUVfrom1D helper.
 */
function getPackedSampler1D(inputInfo, enableShapeUniforms) {
  const texName = inputInfo.name;
  const funcName = `get${texName.charAt(0).toUpperCase()}${texName.slice(1)}`;
  const glsl = getGlslDifferences();
  if (enableShapeUniforms) {
    // Texture dims come from the `<name>TexShape` uniform; halve at runtime.
    return `
    vec4 ${funcName}(int index) {
      ivec2 packedTexShape = ivec2(ceil(float(${texName}TexShape[0]) / 2.0), ceil(float(${texName}TexShape[1]) / 2.0));
      vec2 uv = packedUVfrom1D(
        packedTexShape[0], packedTexShape[1], index);
      return ${glsl.texture2D}(${texName}, uv);
    }
    `;
  }
  // Static shapes: halve the texture dims here and bake them into the source.
  const [texRows, texCols] = inputInfo.shapeInfo.texShape;
  const packedRows = Math.ceil(texRows / 2);
  const packedCols = Math.ceil(texCols / 2);
  return `
    vec4 ${funcName}(int index) {
      vec2 uv = packedUVfrom1D(
        ${packedRows}, ${packedCols}, index);
      return ${glsl.texture2D}(${texName}, uv);
    }
  `;
}
/**
 * Emits the unpacked GLSL sampler for a rank-1 input: `getFoo(int index)`.
 * Specializations, in order: uniform-array input, 1x1 texture,
 * single-column texture, single-row texture, then the general uvFromFlat
 * lookup. All texture branches add the tensor's flat offset uniform.
 */
function getSampler1D(inputInfo, enableShapeUniforms) {
  const texName = inputInfo.name;
  const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
  if (inputInfo.shapeInfo.isUniform) {
    // Uniform arrays will be less than 65505 (no risk of float16 overflow).
    return `
      float ${funcName}(int index) {
        ${getUniformSampler(inputInfo)}
      }
    `;
  }
  const texShape = inputInfo.shapeInfo.texShape;
  const tNumR = texShape[0];
  const tNumC = texShape[1];
  if (tNumC === 1 && tNumR === 1) {
    // 1x1 texture: single texel at center UV.
    return `
      float ${funcName}(int index) {
        return sampleTexture(${texName}, halfCR);
      }
    `;
  }
  const offset = getFlatOffsetUniformName(texName);
  if (tNumC === 1) {
    // Single-column texture: the flat index maps to the y axis.
    if (enableShapeUniforms) {
      return `
      float ${funcName}(int index) {
        vec2 uv = vec2(0.5, (float(index + ${offset}) + 0.5) / float(${texName}TexShape[0]));
        return sampleTexture(${texName}, uv);
      }
      `;
    }
    return `
      float ${funcName}(int index) {
        vec2 uv = vec2(0.5, (float(index + ${offset}) + 0.5) / ${tNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
  }
  if (tNumR === 1) {
    // Single-row texture: the flat index maps to the x axis.
    if (enableShapeUniforms) {
      return `
      float ${funcName}(int index) {
        vec2 uv = vec2((float(index + ${offset}) + 0.5) / float(${texName}TexShape[1]), 0.5);
        return sampleTexture(${texName}, uv);
      }
      `;
    }
    return `
      float ${funcName}(int index) {
        vec2 uv = vec2((float(index + ${offset}) + 0.5) / ${tNumC}.0, 0.5);
        return sampleTexture(${texName}, uv);
      }
    `;
  }
  // General case: convert the flat index to a 2D UV.
  if (enableShapeUniforms) {
    return `
    float ${funcName}(int index) {
      vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], index + ${offset});
      return sampleTexture(${texName}, uv);
    }
    `;
  }
  return `
    float ${funcName}(int index) {
      vec2 uv = uvFromFlat(${tNumR}, ${tNumC}, index + ${offset});
      return sampleTexture(${texName}, uv);
    }
  `;
}
/**
 * Emits the packed GLSL sampler for a rank-2 input: `getFoo(int row, int col)`
 * returns the vec4 texel covering the 2x2 block that contains (row, col).
 * Fast path when the logical shape equals the texture shape (direct UV);
 * otherwise the packedUVfrom2D helper is used.
 */
function getPackedSampler2D(inputInfo, enableShapeUniforms) {
  const shape = inputInfo.shapeInfo.logicalShape;
  const texName = inputInfo.name;
  const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
  const texShape = inputInfo.shapeInfo.texShape;
  const texNumR = texShape[0];
  const texNumC = texShape[1];
  const glsl = getGlslDifferences();
  if (texShape != null && arraysEqual(shape, texShape)) {
    // Logical layout equals texture layout: (col, row) maps directly to UV.
    if (enableShapeUniforms) {
      return `
      vec4 ${funcName}(int row, int col) {
        vec2 uv = (vec2(col, row) + halfCR) / vec2(${texName}TexShape[1], ${texName}TexShape[0]);

        return ${glsl.texture2D}(${texName}, uv);
      }
      `;
    }
    return `
      vec4 ${funcName}(int row, int col) {
        vec2 uv = (vec2(col, row) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0);

        return ${glsl.texture2D}(${texName}, uv);
      }
    `;
  }
  if (enableShapeUniforms) {
    return `
    vec4 ${funcName}(int row, int col) {
      ivec2 packedTexShape = ivec2(ceil(float(${texName}TexShape[0]) / 2.0), ceil(float(${texName}TexShape[1]) / 2.0));
      int valuesPerRow = int(ceil(float(${texName}Shape[1]) / 2.0));
      vec2 uv = packedUVfrom2D(valuesPerRow, packedTexShape[0], packedTexShape[1], row, col);
      return ${glsl.texture2D}(${texName}, uv);
    }
    `;
  }
  // Static variant: packed geometry baked into the source.
  const packedTexShape =
      [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
  const valuesPerRow = Math.ceil(shape[1] / 2);
  return `
    vec4 ${funcName}(int row, int col) {
      vec2 uv = packedUVfrom2D(${valuesPerRow}, ${packedTexShape[0]}, ${packedTexShape[1]}, row, col);
      return ${glsl.texture2D}(${texName}, uv);
    }
  `;
}
/**
 * Emits the unpacked GLSL sampler for a rank-2 input:
 * `getFoo(int row, int col)`.
 *
 * Specializations, in order: logical shape equals texture shape (direct UV);
 * squeezable size-1 dims (delegate to a lower-rank sampler); uniform-array
 * input; single-column texture; single-row texture; general uvFromFlat.
 */
function getSampler2D(inputInfo, enableShapeUniforms) {
  const shape = inputInfo.shapeInfo.logicalShape;
  const texName = inputInfo.name;
  const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
  const texShape = inputInfo.shapeInfo.texShape;
  if (texShape != null && arraysEqual(shape, texShape)) {
    if (enableShapeUniforms) {
      return `
      float ${funcName}(int row, int col) {
        vec2 uv = (vec2(col, row) + halfCR) / vec2(${texName}TexShape[1], ${texName}TexShape[0]);
        return sampleTexture(${texName}, uv);
      }
      `;
    }
    const texNumR = texShape[0];
    const texNumC = texShape[1];
    return `
    float ${funcName}(int row, int col) {
      vec2 uv = (vec2(col, row) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0);
      return sampleTexture(${texName}, uv);
    }
  `;
  }
  // Drop size-1 dims and delegate; the wrapper keeps the rank-2 signature and
  // forwards only the kept coordinates.
  const {newShape, keptDims} = squeezeShape(shape);
  const squeezedShape = newShape;
  if (squeezedShape.length < shape.length) {
    const newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
    const params = ['row', 'col'];
    return `
      ${getSamplerFromInInfo(newInputInfo, enableShapeUniforms)}
      float ${funcName}(int row, int col) {
        return ${funcName}(${getSqueezedParams(params, keptDims)});
      }
    `;
  }
  if (inputInfo.shapeInfo.isUniform) {
    // Uniform arrays will be less than 65505 (no risk of float16 overflow).
    return `
      float ${funcName}(int row, int col) {
        int index = round(dot(vec2(row, col), vec2(${shape[1]}, 1)));
        ${getUniformSampler(inputInfo)}
      }
    `;
  }
  const texNumR = texShape[0];
  const texNumC = texShape[1];
  const offset = getFlatOffsetUniformName(texName);
  if (texNumC === 1) {
    // index is used directly as physical (no risk of float16 overflow).
    if (enableShapeUniforms) {
      return `
      float ${funcName}(int row, int col) {
        float index = dot(vec3(row, col, ${offset}), vec3(${texName}Shape[1], 1, 1));
        vec2 uv = vec2(0.5, (index + 0.5) / float(${texName}TexShape[0]));
        return sampleTexture(${texName}, uv);
      }
      `;
    }
    return `
      float ${funcName}(int row, int col) {
        float index = dot(vec3(row, col, ${offset}), vec3(${shape[1]}, 1, 1));
        vec2 uv = vec2(0.5, (index + 0.5) / ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
  }
  if (texNumR === 1) {
    // index is used directly as physical (no risk of float16 overflow).
    if (enableShapeUniforms) {
      return `
      float ${funcName}(int row, int col) {
        float index = dot(vec3(row, col, ${offset}), vec3(${texName}Shape[1], 1, 1));
        vec2 uv = vec2((index + 0.5) / float(${texName}TexShape[1]), 0.5);
        return sampleTexture(${texName}, uv);
      }
      `;
    }
    return `
      float ${funcName}(int row, int col) {
        float index = dot(vec3(row, col, ${offset}), vec3(${shape[1]}, 1, 1));
        vec2 uv = vec2((index + 0.5) / ${texNumC}.0, 0.5);
        return sampleTexture(${texName}, uv);
      }
    `;
  }
  if (enableShapeUniforms) {
    return `
    float ${funcName}(int row, int col) {
      // Explicitly use integer operations as dot() only works on floats.
      int index = row * ${texName}Shape[1] + col + ${offset};
      vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], index);
      return sampleTexture(${texName}, uv);
    }
    `;
  }
  return `
  float ${funcName}(int row, int col) {
    // Explicitly use integer operations as dot() only works on floats.
    int index = row * ${shape[1]} + col + ${offset};
    vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
    return sampleTexture(${texName}, uv);
  }
`;
}
/**
 * Emits the packed GLSL sampler for a rank-3 input:
 * `getFoo(int b, int row, int col)`. A leading batch of size 1 is squeezed
 * into a rank-2 sampler; otherwise the packedUVfrom3D helper is used with
 * geometry from uniforms or baked-in constants.
 */
function getPackedSampler3D(inputInfo, enableShapeUniforms) {
  const shape = inputInfo.shapeInfo.logicalShape;
  const texName = inputInfo.name;
  const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
  const texShape = inputInfo.shapeInfo.texShape;
  const packedTexShape =
      [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
  if (shape[0] === 1) {
    // Single batch: delegate to the rank-2 sampler, dropping `b`.
    const squeezedShape = shape.slice(1);
    const keptDims = [1, 2];
    const newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
    const params = ['b', 'row', 'col'];
    return `
        ${getPackedSamplerFromInInfo(newInputInfo, enableShapeUniforms)}
        vec4 ${funcName}(int b, int row, int col) {
          return ${funcName}(${getSqueezedParams(params, keptDims)});
        }
      `;
  }
  const glsl = getGlslDifferences();
  if (enableShapeUniforms) {
    return `
    vec4 ${funcName}(int b, int row, int col) {
      ivec2 packedTexShape = ivec2(ceil(float(${texName}TexShape[0]) / 2.0), ceil(float(${texName}TexShape[1]) / 2.0));
      int valuesPerRow = int(ceil(float(${texName}Shape[2]) / 2.0));
      int texelsInBatch = valuesPerRow * int(ceil(float(${texName}Shape[1]) / 2.0));
      vec2 uv = packedUVfrom3D(
        packedTexShape[0], packedTexShape[1], texelsInBatch, valuesPerRow, b, row, col);
      return ${glsl.texture2D}(${texName}, uv);
    }
    `;
  }
  // Static variant: packed geometry baked into the source.
  const texNumR = packedTexShape[0];
  const texNumC = packedTexShape[1];
  const valuesPerRow = Math.ceil(shape[2] / 2);
  const texelsInBatch = valuesPerRow * Math.ceil(shape[1] / 2);
  return `
    vec4 ${funcName}(int b, int row, int col) {
      vec2 uv = packedUVfrom3D(
        ${texNumR}, ${texNumC}, ${texelsInBatch}, ${valuesPerRow}, b, row, col);
      return ${glsl.texture2D}(${texName}, uv);
    }
  `;
}
/**
 * Emits the unpacked GLSL sampler for a rank-3 input:
 * `getFoo(int row, int col, int depth)`.
 *
 * Specializations, in order: squeezable size-1 dims (delegate to a
 * lower-rank sampler); uniform-array input; texture row == one logical row
 * (texNumC === stride0); texture row == one logical (row, col) slice
 * (texNumC === stride1); general uvFromFlat. The two direct-mapping branches
 * are only valid without a flat offset.
 */
function getSampler3D(inputInfo, enableShapeUniforms) {
  const shape = inputInfo.shapeInfo.logicalShape;
  const texName = inputInfo.name;
  const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
  // Row-major strides of the logical shape.
  const stride0 = shape[1] * shape[2];
  const stride1 = shape[2];
  const {newShape, keptDims} = squeezeShape(shape);
  const squeezedShape = newShape;
  if (squeezedShape.length < shape.length) {
    const newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
    const params = ['row', 'col', 'depth'];
    return `
        ${getSamplerFromInInfo(newInputInfo, enableShapeUniforms)}
        float ${funcName}(int row, int col, int depth) {
          return ${funcName}(${getSqueezedParams(params, keptDims)});
        }
      `;
  }
  if (inputInfo.shapeInfo.isUniform) {
    // Uniform arrays will be less than 65505 (no risk of float16 overflow).
    return `
      float ${funcName}(int row, int col, int depth) {
        int index = round(dot(vec3(row, col, depth),
                          vec3(${stride0}, ${stride1}, 1)));
        ${getUniformSampler(inputInfo)}
      }
    `;
  }
  const texShape = inputInfo.shapeInfo.texShape;
  const texNumR = texShape[0];
  const texNumC = texShape[1];
  const flatOffset = inputInfo.shapeInfo.flatOffset;
  if (texNumC === stride0 && flatOffset == null) {
    // texC is used directly as physical (no risk of float16 overflow).
    if (enableShapeUniforms) {
      return `
      float ${funcName}(int row, int col, int depth) {
        int stride1 = ${texName}Shape[2];
        float texR = float(row);
        float texC = dot(vec2(col, depth), vec2(stride1, 1));
        vec2 uv = (vec2(texC, texR) + halfCR) /
                   vec2(${texName}TexShape[1], ${texName}TexShape[0]);
        return sampleTexture(${texName}, uv);
      }
      `;
    }
    return `
        float ${funcName}(int row, int col, int depth) {
          float texR = float(row);
          float texC = dot(vec2(col, depth), vec2(${stride1}, 1));
          vec2 uv = (vec2(texC, texR) + halfCR) /
                     vec2(${texNumC}.0, ${texNumR}.0);
          return sampleTexture(${texName}, uv);
        }
      `;
  }
  if (texNumC === stride1 && flatOffset == null) {
    // texR is used directly as physical (no risk of float16 overflow).
    if (enableShapeUniforms) {
      return `
      float ${funcName}(int row, int col, int depth) {
        float texR = dot(vec2(row, col), vec2(${texName}Shape[1], 1));
        float texC = float(depth);
        vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texName}TexShape[1], ${texName}TexShape[0]);
        return sampleTexture(${texName}, uv);
      }
      `;
    }
    return `
    float ${funcName}(int row, int col, int depth) {
      float texR = dot(vec2(row, col), vec2(${shape[1]}, 1));
      float texC = float(depth);
      vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0);
      return sampleTexture(${texName}, uv);
    }
  `;
  }
  // General case: flat index (plus offset uniform) through uvFromFlat.
  const offset = getFlatOffsetUniformName(texName);
  if (enableShapeUniforms) {
    return `
    float ${funcName}(int row, int col, int depth) {
      // Explicitly use integer operations as dot() only works on floats.
      int stride0 = ${texName}Shape[1] * ${texName}Shape[2];
      int stride1 = ${texName}Shape[2];
      int index = row * ${stride0} + col * ${stride1} + depth + ${offset};
      vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], index);
      return sampleTexture(${texName}, uv);
    }
    `;
  }
  return `
      float ${funcName}(int row, int col, int depth) {
        // Explicitly use integer operations as dot() only works on floats.
        int index = row * ${stride0} + col * ${stride1} + depth + ${offset};
        vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
        return sampleTexture(${texName}, uv);
      }
  `;
}
/**
 * Emits the packed GLSL sampler for a rank>=4 input. The two innermost dims
 * are packed 2x2; outer dims are unpacked batches folded into the flat texel
 * index. The static branch builds the index expression string incrementally,
 * prepending one `b{i} * texels` term per extra batch dim.
 */
function getPackedSamplerND(inputInfo, enableShapeUniforms) {
  const texName = inputInfo.name;
  const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
  const glsl = getGlslDifferences();
  if (enableShapeUniforms) {
    // TODO: support 5d and 6d
    return `
    vec4 ${funcName}(int b2, int b, int row, int col) {
      int valuesPerRow = int(ceil(float(${texName}Shape[3]) / 2.0));
      int texelsInBatch = valuesPerRow * int(ceil(float(${texName}Shape[2]) / 2.0));
      int index = b * texelsInBatch + (row / 2) * valuesPerRow + (col / 2);
      texelsInBatch *= ${texName}Shape[1];
      index = b2 * texelsInBatch + index;
      ivec2 packedTexShape = ivec2(ceil(float(${texName}TexShape[0]) / 2.0), ceil(float(${texName}TexShape[1]) / 2.0));
      int texR = index / packedTexShape[1];
      int texC = index - texR * packedTexShape[1];
      vec2 uv = (vec2(texC, texR) + halfCR) / vec2(packedTexShape[1], packedTexShape[0]); return ${glsl.texture2D}(${texName}, uv);
    }
    `;
  }
  const shape = inputInfo.shapeInfo.logicalShape;
  const rank = shape.length;
  const texShape = inputInfo.shapeInfo.texShape;
  const packedTexShape =
      [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
  const texNumR = packedTexShape[0];
  const texNumC = packedTexShape[1];
  const valuesPerRow = Math.ceil(shape[rank - 1] / 2);
  let texelsInBatch = valuesPerRow * Math.ceil(shape[rank - 2] / 2);
  let params = `int b, int row, int col`;
  let index = `b * ${texelsInBatch} + (row / 2) * ${valuesPerRow} + (col / 2)`;
  // Add one batch parameter and index term per dim beyond rank 3,
  // innermost-first; parameters are prepended so the outermost dim comes
  // first in the signature.
  for (let b = 2; b < rank - 1; b++) {
    params = `int b${b}, ` + params;
    texelsInBatch *= shape[rank - b - 1];
    index = `b${b} * ${texelsInBatch} + ` + index;
  }
  return `
    vec4 ${funcName}(${params}) {
      int index = ${index};
      int texR = index / ${texNumC};
      int texC = index - texR * ${texNumC};
      vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}, ${texNumR});
      return ${glsl.texture2D}(${texName}, uv);
    }
  `;
}
/**
 * Generates GLSL source for `get<TexName>(row, col, depth, depth2)`: a
 * sampler that reads one float of a 4-D tensor out of its backing 2-D
 * texture.
 *
 * @param inputInfo Input name plus shape/texture layout
 *     (`shapeInfo.logicalShape`, `texShape`, `isUniform`, `flatOffset`).
 * @param enableShapeUniforms When true, shapes come from shader uniforms
 *     (`<texName>Shape`, `<texName>TexShape`) instead of being baked into
 *     the source, so one compiled program can serve multiple shapes.
 * @returns A GLSL string declaring the sampler function.
 */
function getSampler4D(inputInfo, enableShapeUniforms) {
    const shape = inputInfo.shapeInfo.logicalShape;
    const texName = inputInfo.name;
    const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    // Row-major strides of the logical 4-D shape.
    const stride2 = shape[3];
    const stride1 = shape[2] * stride2;
    const stride0 = shape[1] * stride1;
    const { newShape, keptDims } = squeezeShape(shape);
    if (newShape.length < shape.length) {
        // Size-1 dims can be squeezed away: emit the lower-rank sampler plus a
        // 4-argument wrapper that forwards only the surviving coordinates.
        const newInputInfo = squeezeInputInfo(inputInfo, newShape);
        const params = ['row', 'col', 'depth', 'depth2'];
        return `
      ${getSamplerFromInInfo(newInputInfo, enableShapeUniforms)}
      float ${funcName}(int row, int col, int depth, int depth2) {
        return ${funcName}(${getSqueezedParams(params, keptDims)});
      }
    `;
    }
    if (inputInfo.shapeInfo.isUniform) {
        // Tensor lives in a uniform float array: compute the flat index and
        // delegate the lookup to the uniform-scan snippet.
        // Uniform arrays will be less than 65505 (no risk of float16 overflow).
        return `
      float ${funcName}(int row, int col, int depth, int depth2) {
        int index = round(dot(vec4(row, col, depth, depth2),
                          vec4(${stride0}, ${stride1}, ${stride2}, 1)));
        ${getUniformSampler(inputInfo)}
      }
    `;
    }
    const flatOffset = inputInfo.shapeInfo.flatOffset;
    const texShape = inputInfo.shapeInfo.texShape;
    const texNumR = texShape[0];
    const texNumC = texShape[1];
    // GLSL statements that recompute the strides from the shape uniform at
    // shader run time (only used when enableShapeUniforms is set).
    const stride2Str = `int stride2 = ${texName}Shape[3];`;
    const stride1Str = `int stride1 = ${texName}Shape[2] * stride2;`;
    const stride0Str = `int stride0 = ${texName}Shape[1] * stride1;`;
    if (texNumC === stride0 && flatOffset == null) {
        // One logical row per texture row:
        // texC is used directly as physical (no risk of float16 overflow).
        if (enableShapeUniforms) {
            return `
      float ${funcName}(int row, int col, int depth, int depth2) {
        ${stride2Str}
        ${stride1Str}
        float texR = float(row);
        float texC =
            dot(vec3(col, depth, depth2),
                vec3(stride1, stride2, 1));
        vec2 uv = (vec2(texC, texR) + halfCR) /
                   vec2(${texName}TexShape[1], ${texName}TexShape[0]);
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        return `
      float ${funcName}(int row, int col, int depth, int depth2) {
        float texR = float(row);
        float texC =
            dot(vec3(col, depth, depth2),
                vec3(${stride1}, ${stride2}, 1));
        vec2 uv = (vec2(texC, texR) + halfCR) /
                   vec2(${texNumC}.0, ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
    }
    if (texNumC === stride2 && flatOffset == null) {
        // Innermost dim spans a texture row:
        // texR is used directly as physical (no risk of float16 overflow).
        if (enableShapeUniforms) {
            return `
      float ${funcName}(int row, int col, int depth, int depth2) {
        float texR = dot(vec3(row, col, depth),
          vec3(${texName}Shape[1] * ${texName}Shape[2], ${texName}Shape[2], 1));
        float texC = float(depth2);
        vec2 uv = (vec2(texC, texR) + halfCR) /
                  vec2(${texName}TexShape[1], ${texName}TexShape[0]);
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        return `
      float ${funcName}(int row, int col, int depth, int depth2) {
        float texR = dot(vec3(row, col, depth),
                         vec3(${shape[1] * shape[2]}, ${shape[2]}, 1));
        float texC = float(depth2);
        vec2 uv = (vec2(texC, texR) + halfCR) /
                  vec2(${texNumC}.0, ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
    }
    // General case: compute a flat index (plus the slice offset uniform, if
    // any) and convert it to texture UV coordinates.
    const offset = getFlatOffsetUniformName(texName);
    if (enableShapeUniforms) {
        return `
    float ${funcName}(int row, int col, int depth, int depth2) {
      // Explicitly use integer operations as dot() only works on floats.
      ${stride2Str}
      ${stride1Str}
      ${stride0Str}
      int index = row * stride0 + col * stride1 +
          depth * stride2 + depth2;
      vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], index + ${offset});
      return sampleTexture(${texName}, uv);
    }
  `;
    }
    return `
    float ${funcName}(int row, int col, int depth, int depth2) {
      // Explicitly use integer operations as dot() only works on floats.
      int index = row * ${stride0} + col * ${stride1} +
          depth * ${stride2} + depth2;
      vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index + ${offset});
      return sampleTexture(${texName}, uv);
    }
  `;
}
/**
 * Generates GLSL source for `get<TexName>(row, col, depth, depth2, depth3)`:
 * a sampler reading one float of a 5-D tensor from its backing 2-D texture.
 * Same branch structure as getSampler4D, but without shape-uniform support.
 *
 * @param inputInfo Input name plus shape/texture layout info.
 * @returns A GLSL string declaring the sampler function.
 */
function getSampler5D(inputInfo) {
    const shape = inputInfo.shapeInfo.logicalShape;
    const texName = inputInfo.name;
    const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    // Row-major strides of the logical 5-D shape.
    const stride3 = shape[4];
    const stride2 = shape[3] * stride3;
    const stride1 = shape[2] * stride2;
    const stride0 = shape[1] * stride1;
    const { newShape, keptDims } = squeezeShape(shape);
    if (newShape.length < shape.length) {
        // Squeeze size-1 dims: lower-rank sampler + forwarding wrapper.
        const newInputInfo = squeezeInputInfo(inputInfo, newShape);
        const params = ['row', 'col', 'depth', 'depth2', 'depth3'];
        return `
      ${getSamplerFromInInfo(newInputInfo)}
      float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
        return ${funcName}(${getSqueezedParams(params, keptDims)});
      }
    `;
    }
    if (inputInfo.shapeInfo.isUniform) {
        // Uniform arrays will be less than 65505 (no risk of float16 overflow).
        // NOTE(review): this emits `float index`, but the scan emitted by
        // getUniformSampler compares it against an int loop variable, and the
        // 4D/6D counterparts round() to int here — verify this path compiles
        // on strict GLSL drivers.
        return `
      float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
        float index = dot(
          vec4(row, col, depth, depth2),
          vec4(${stride0}, ${stride1}, ${stride2}, ${stride3})) +
          depth3;
        ${getUniformSampler(inputInfo)}
      }
    `;
    }
    const flatOffset = inputInfo.shapeInfo.flatOffset;
    const texShape = inputInfo.shapeInfo.texShape;
    const texNumR = texShape[0];
    const texNumC = texShape[1];
    if (texNumC === stride0 && flatOffset == null) {
        // One logical row per texture row:
        // texC is used directly as physical (no risk of float16 overflow).
        return `
      float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
        int texR = row;
        float texC = dot(vec4(col, depth, depth2, depth3),
                         vec4(${stride1}, ${stride2}, ${stride3}, 1));
        vec2 uv = (vec2(texC, texR) + halfCR) /
                  vec2(${texNumC}.0, ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
    }
    if (texNumC === stride3 && flatOffset == null) {
        // Innermost dim spans a texture row:
        // texR is used directly as physical (no risk of float16 overflow).
        return `
      float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
        float texR = dot(
          vec4(row, col, depth, depth2),
          vec4(${shape[1] * shape[2] * shape[3]},
               ${shape[2] * shape[3]}, ${shape[3]}, 1));
        int texC = depth3;
        vec2 uv = (vec2(texC, texR) + halfCR) /
                  vec2(${texNumC}.0, ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
    }
    // General case: flat index (plus slice offset uniform) -> UV.
    const offset = getFlatOffsetUniformName(texName);
    return `
    float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
      // Explicitly use integer operations as dot() only works on floats.
      int index = row * ${stride0} + col * ${stride1} + depth * ${stride2} +
          depth2 * ${stride3} + depth3 + ${offset};
      vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
      return sampleTexture(${texName}, uv);
    }
  `;
}
/**
 * Generates GLSL source for
 * `get<TexName>(row, col, depth, depth2, depth3, depth4)`: a sampler reading
 * one float of a 6-D tensor from its backing 2-D texture. Mirrors
 * getSampler5D with one extra dimension.
 *
 * @param inputInfo Input name plus shape/texture layout info.
 * @returns A GLSL string declaring the sampler function.
 */
function getSampler6D(inputInfo) {
    const shape = inputInfo.shapeInfo.logicalShape;
    const texName = inputInfo.name;
    const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    const { newShape, keptDims } = squeezeShape(shape);
    if (newShape.length < shape.length) {
        // Squeeze size-1 dims: lower-rank sampler + forwarding wrapper.
        const newInputInfo = squeezeInputInfo(inputInfo, newShape);
        const params = ['row', 'col', 'depth', 'depth2', 'depth3', 'depth4'];
        return `
      ${getSamplerFromInInfo(newInputInfo)}
      float ${funcName}(int row, int col, int depth,
        int depth2, int depth3, int depth4) {
        return ${funcName}(${getSqueezedParams(params, keptDims)});
      }
    `;
    }
    // Row-major strides of the logical 6-D shape.
    const stride4 = shape[5];
    const stride3 = shape[4] * stride4;
    const stride2 = shape[3] * stride3;
    const stride1 = shape[2] * stride2;
    const stride0 = shape[1] * stride1;
    if (inputInfo.shapeInfo.isUniform) {
        // Uniform arrays will be less than 65505 (no risk of float16 overflow).
        return `
      float ${funcName}(int row, int col, int depth,
        int depth2, int depth3, int depth4) {
        int index = round(dot(
          vec4(row, col, depth, depth2),
          vec4(${stride0}, ${stride1}, ${stride2}, ${stride3})) +
          dot(
            vec2(depth3, depth4),
            vec2(${stride4}, 1)));
        ${getUniformSampler(inputInfo)}
      }
    `;
    }
    const flatOffset = inputInfo.shapeInfo.flatOffset;
    const texShape = inputInfo.shapeInfo.texShape;
    const texNumR = texShape[0];
    const texNumC = texShape[1];
    if (texNumC === stride0 && flatOffset == null) {
        // One logical row per texture row:
        // texC is used directly as physical (no risk of float16 overflow).
        return `
      float ${funcName}(int row, int col, int depth,
        int depth2, int depth3, int depth4) {
        int texR = row;
        float texC = dot(vec4(col, depth, depth2, depth3),
          vec4(${stride1}, ${stride2}, ${stride3}, ${stride4})) +
               float(depth4);
        vec2 uv = (vec2(texC, texR) + halfCR) /
                   vec2(${texNumC}.0, ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
    }
    if (texNumC === stride4 && flatOffset == null) {
        // Innermost dim spans a texture row:
        // texR is used directly as physical (no risk of float16 overflow).
        return `
      float ${funcName}(int row, int col, int depth,
        int depth2, int depth3, int depth4) {
        float texR = dot(vec4(row, col, depth, depth2),
          vec4(${shape[1] * shape[2] * shape[3] * shape[4]},
               ${shape[2] * shape[3] * shape[4]},
               ${shape[3] * shape[4]},
               ${shape[4]})) + float(depth3);
        int texC = depth4;
        vec2 uv = (vec2(texC, texR) + halfCR) /
                  vec2(${texNumC}.0, ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
    }
    // General case: flat index (plus slice offset uniform) -> UV.
    const offset = getFlatOffsetUniformName(texName);
    return `
    float ${funcName}(int row, int col, int depth,
                  int depth2, int depth3, int depth4) {
      // Explicitly use integer operations as dot() only works on floats.
      int index = row * ${stride0} + col * ${stride1} + depth * ${stride2} +
          depth2 * ${stride3} + depth3 * ${stride4} + depth4 + ${offset};
      vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
      return sampleTexture(${texName}, uv);
    }
  `;
}
/**
 * Emits the GLSL statements that read entry `index` of a tensor uploaded as
 * a uniform float array (instead of a texture).
 *
 * A size-0/1 tensor is a plain float uniform and is returned directly.
 * Larger tensors are scanned with a constant-bound loop — presumably because
 * dynamically indexing a uniform array is restricted in the targeted GLSL
 * versions (TODO confirm). The emitted code expects a variable named `index`
 * to already be in scope at the insertion point.
 */
function getUniformSampler(inputInfo) {
    const texName = inputInfo.name;
    const inSize = sizeFromShape(inputInfo.shapeInfo.logicalShape);
    if (inSize < 2) {
        return `return ${texName};`;
    }
    return `
    for (int i = 0; i < ${inSize}; i++) {
      if (i == index) {
        return ${texName}[i];
      }
    }
  `;
}
/**
 * Generates GLSL source for `get<TexName>AtOutCoords()`: reads the packed
 * (vec4) input texel that corresponds to the current output coordinates,
 * applying implicit broadcasting by zeroing broadcast dimensions and
 * rearranging the texel's channels where a broadcast input texel must be
 * replicated across output lanes.
 *
 * @param inputInfo Input name plus shape/texture layout info.
 * @param outShapeInfo Shape/texture layout of the program output.
 * @returns A GLSL string declaring the sampler function.
 */
function getPackedSamplerAtOutputCoords(inputInfo, outShapeInfo) {
    const texName = inputInfo.name;
    const texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
    const funcName = 'get' + texFuncSnippet + 'AtOutCoords';
    const inRank = inputInfo.shapeInfo.logicalShape.length;
    const outRank = outShapeInfo.logicalShape.length;
    const broadcastDims = getBroadcastDims$1(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape);
    const type = getCoordsDataType(outRank);
    const rankDiff = outRank - inRank;
    let coordsSnippet;
    // Component names of the output-coords vector, indexed by dimension.
    const fields = ['x', 'y', 'z', 'w', 'u', 'v'];
    if (inRank === 0) {
        // Scalar input: no coordinates to adjust.
        coordsSnippet = '';
    }
    else if (outRank < 2 && broadcastDims.length >= 1) {
        // 1-D output coords are a plain int, not a vector.
        coordsSnippet = 'coords = 0;';
    }
    else {
        // Zero out each broadcast dimension (offset by the rank difference).
        coordsSnippet =
            broadcastDims.map(d => `coords.${fields[d + rankDiff]} = 0;`)
                .join('\n');
    }
    // Argument list passed to the plain get<TexName>(...) sampler.
    let unpackedCoordsSnippet = '';
    if (outRank < 2 && inRank > 0) {
        unpackedCoordsSnippet = 'coords';
    }
    else {
        unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape
            .map((s, i) => `coords.${fields[i + rankDiff]}`)
            .join(', ');
    }
    // How the fetched vec4 is mapped onto the output texel's four lanes.
    let output = `return outputValue;`;
    const inSize = sizeFromShape(inputInfo.shapeInfo.logicalShape);
    const isInputScalar = inSize === 1;
    const outSize = sizeFromShape(outShapeInfo.logicalShape);
    const isOutputScalar = outSize === 1;
    if (inRank === 1 && !isInputScalar && !isOutputScalar) {
        // 1-D input broadcast across rows: duplicate the .xy pair.
        output = `
        return vec4(outputValue.xy, outputValue.xy);
      `;
    }
    else if (isInputScalar && !isOutputScalar) {
        if (outRank === 1) {
            output = `
          return vec4(outputValue.x, outputValue.x, 0., 0.);
        `;
        }
        else {
            output = `
          return vec4(outputValue.x);
        `;
        }
    }
    else if (broadcastDims.length) {
        // Replicate lanes when the inner (row/col) dims are broadcast.
        const rows = inRank - 2;
        const cols = inRank - 1;
        if (broadcastDims.indexOf(rows) > -1 && broadcastDims.indexOf(cols) > -1) {
            output = `return vec4(outputValue.x);`;
        }
        else if (broadcastDims.indexOf(rows) > -1) {
            output = `return vec4(outputValue.x, outputValue.y, ` +
                `outputValue.x, outputValue.y);`;
        }
        else if (broadcastDims.indexOf(cols) > -1) {
            output = `return vec4(outputValue.xx, outputValue.zz);`;
        }
    }
    return `
    vec4 ${funcName}() {
      ${type} coords = getOutputCoords();
      ${coordsSnippet}
      vec4 outputValue = get${texFuncSnippet}(${unpackedCoordsSnippet});
      ${output}
    }
  `;
}
/**
 * Generates GLSL source for `get<TexName>AtOutCoords()`: reads the (unpacked)
 * input value corresponding to the current output coordinates, with implicit
 * broadcasting handled by zeroing broadcast dimensions.
 *
 * When input and output have identical rank and texture shape (and the input
 * is neither a uniform nor a sliced tensor), the lookup collapses to a direct
 * texture read at `resultUV`.
 *
 * @param inputInfo Input name plus shape/texture layout info.
 * @param outShapeInfo Shape/texture layout of the program output.
 * @returns A GLSL string declaring the sampler function.
 */
function getSamplerAtOutputCoords(inputInfo, outShapeInfo) {
    const texName = inputInfo.name;
    const texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
    const funcName = 'get' + texFuncSnippet + 'AtOutCoords';
    const outTexShape = outShapeInfo.texShape;
    const inTexShape = inputInfo.shapeInfo.texShape;
    const inRank = inputInfo.shapeInfo.logicalShape.length;
    const outRank = outShapeInfo.logicalShape.length;
    // Fast path: same layout as the output, so sample at the output's UV.
    if (!inputInfo.shapeInfo.isUniform && inRank === outRank &&
        inputInfo.shapeInfo.flatOffset == null &&
        arraysEqual(inTexShape, outTexShape)) {
        return `
      float ${funcName}() {
        return sampleTexture(${texName}, resultUV);
      }
    `;
    }
    const type = getCoordsDataType(outRank);
    const broadcastDims = getBroadcastDims$1(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape);
    const rankDiff = outRank - inRank;
    let coordsSnippet;
    // Component names of the output-coords vector, indexed by dimension.
    const fields = ['x', 'y', 'z', 'w', 'u', 'v'];
    if (inRank === 0) {
        coordsSnippet = '';
    }
    else if (outRank < 2 && broadcastDims.length >= 1) {
        // 1-D output coords are a plain int, not a vector.
        coordsSnippet = 'coords = 0;';
    }
    else {
        // Zero out each broadcast dimension (offset by the rank difference).
        coordsSnippet =
            broadcastDims.map(d => `coords.${fields[d + rankDiff]} = 0;`)
                .join('\n');
    }
    // Argument list passed to the plain get<TexName>(...) sampler.
    let unpackedCoordsSnippet = '';
    if (outRank < 2 && inRank > 0) {
        unpackedCoordsSnippet = 'coords';
    }
    else {
        unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape
            .map((s, i) => `coords.${fields[i + rankDiff]}`)
            .join(', ');
    }
    return `
    float ${funcName}() {
      ${type} coords = getOutputCoords();
      ${coordsSnippet}
      return get${texFuncSnippet}(${unpackedCoordsSnippet});
    }
  `;
}
/**
 * Maps a tensor rank to the GLSL type used for its coordinate vector.
 * Ranks 0 and 1 use a plain `int`; ranks 2-6 use `ivec<rank>` names
 * (`ivec5`/`ivec6` are names used by this codegen, not built-in GLSL types).
 *
 * @param rank The tensor rank.
 * @returns The coordinate type name.
 * @throws Error for ranks above 6, which the GPU path does not support.
 */
function getCoordsDataType(rank) {
    if (rank <= 1) {
        return 'int';
    }
    const vectorTypeByRank = {
        2: 'ivec2',
        3: 'ivec3',
        4: 'ivec4',
        5: 'ivec5',
        6: 'ivec6'
    };
    const coordsType = vectorTypeByRank[rank];
    if (coordsType == null) {
        throw Error(`GPU for rank ${rank} is not yet supported`);
    }
    return coordsType;
}
/**
 * Decides which shape should be uploaded to the `<name>Shape` uniform for an
 * input: the original logical shape, or a squeezed variant that matches what
 * the sampler codegen baked into the shader.
 *
 * @param isPacked Whether the input uses the packed (vec4) layout.
 * @param shape Logical shape of the input tensor.
 * @param texShape Physical [rows, cols] shape of its backing texture.
 * @returns `useSqueezeShape` (whether a squeezed shape is used),
 *     `uniformShape` (the shape to upload) and `keptDims` (dims surviving
 *     the squeeze).
 */
function getUniformInfoFromShape(isPacked, shape, texShape) {
    const { newShape, keptDims } = squeezeShape(shape);
    const rank = shape.length;
    // Packed rank-3 inputs with a leading size-1 dim drop that dim.
    const useSqueezePackedShape = isPacked && rank === 3 && shape[0] === 1;
    let useSqueezeShape;
    if (useSqueezePackedShape) {
        useSqueezeShape = true;
    }
    else {
        // Unpacked inputs squeeze only when the squeeze actually removes dims
        // and the logical shape differs from the physical texture shape.
        useSqueezeShape = !isPacked && rank > 1 &&
            !arraysEqual(shape, texShape) && newShape.length < rank;
    }
    const squeezedShape = useSqueezePackedShape ? shape.slice(1) : newShape;
    const uniformShape = useSqueezeShape ? squeezedShape : shape;
    return { useSqueezeShape, uniformShape, keptDims };
}
/** Returns a new input info (a copy) that has a squeezed logical shape. */
function squeezeInputInfo(inInfo, squeezedShape) {
    // JSON round-trip: inInfo is plain serializable data, so this produces a
    // deep copy and leaves the caller's object untouched.
    const copy = JSON.parse(JSON.stringify(inInfo));
    copy.shapeInfo.logicalShape = squeezedShape;
    return copy;
}
/**
 * Builds the comma-separated GLSL argument list for a squeezed sampler call:
 * only the coordinate names whose dimensions survived the squeeze, in order.
 */
function getSqueezedParams(params, keptDims) {
    const keptParams = [];
    for (const dim of keptDims) {
        keptParams.push(params[dim]);
    }
    return keptParams.join(', ');
}
81726
81727 /**
81728 * @license
81729 * Copyright 2017 Google LLC. All Rights Reserved.
81730 * Licensed under the Apache License, Version 2.0 (the "License");
81731 * you may not use this file except in compliance with the License.
81732 * You may obtain a copy of the License at
81733 *
81734 * http://www.apache.org/licenses/LICENSE-2.0
81735 *
81736 * Unless required by applicable law or agreed to in writing, software
81737 * distributed under the License is distributed on an "AS IS" BASIS,
81738 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81739 * See the License for the specific language governing permissions and
81740 * limitations under the License.
81741 * =============================================================================
81742 */
/**
 * Compiles a GPGPU program's fragment shader for the given input/output
 * layouts and returns the binary record (program, shader, source, shape
 * infos, and — unless ENGINE_COMPILE_ONLY is set — the resolved uniform
 * locations).
 */
function compileProgram(gpgpu, program, inputs, output) {
    // Capture the layout facts the shader codegen needs for one input.
    const buildShapeInfo = (input) => {
        const slice = input.texData != null ? input.texData.slice : null;
        const flatOffset =
            slice != null && slice.flatOffset > 0 ? slice.flatOffset : null;
        return {
            logicalShape: input.shape,
            texShape: input.isUniform ? null : input.texData.texShape,
            isUniform: input.isUniform,
            isPacked: input.isUniform ? false : input.texData.isPacked,
            flatOffset
        };
    };
    const inputInfos = inputs.map((input, i) => ({
        name: program.variableNames[i],
        shapeInfo: buildShapeInfo(input)
    }));
    const inShapeInfos = inputInfos.map((info) => info.shapeInfo);
    const outShapeInfo = {
        logicalShape: output.shape,
        texShape: output.texData.texShape,
        isUniform: false,
        isPacked: output.texData.isPacked,
        flatOffset: null
    };
    const source = makeShader(inputInfos, outShapeInfo, program);
    const fragmentShader = createFragmentShader(gpgpu.gl, source);
    const webGLProgram = gpgpu.createProgram(fragmentShader);
    const binary = {
        program,
        fragmentShader,
        source,
        webGLProgram,
        inShapeInfos,
        outShapeInfo
    };
    if (env().get('ENGINE_COMPILE_ONLY')) {
        // Compile-only mode: skip uniform-location lookups entirely.
        return {
            ...binary,
            uniformLocations: null,
            customUniformLocations: null,
            infLoc: null,
            nanLoc: null,
            inShapesLocations: null,
            inTexShapesLocations: null,
            outShapeLocation: null,
            outShapeStridesLocation: null,
            outTexShapeLocation: null
        };
    }
    return { ...binary, ...getUniformLocations(gpgpu, program, webGLProgram) };
}
/**
 * Resolves all uniform locations a compiled program may use: NAN/INFINITY,
 * per-variable value and offset uniforms, optional shape/texShape uniforms,
 * output-shape uniforms, and any program-declared custom uniforms.
 * Lookups are non-throwing; unused uniforms simply resolve to null.
 */
function getUniformLocations(gpgpu, program, webGLProgram) {
    const uniformLocations = {};
    const inShapesLocations = {};
    const inTexShapesLocations = {};
    const customUniformLocations = [];
    let outShapeLocation;
    let outTexShapeLocation;
    let outShapeStridesLocation;
    let infLoc = null;
    // NAN is always looked up; INFINITY only exists in WebGL1 shaders.
    const nanLoc = gpgpu.getUniformLocation(webGLProgram, 'NAN', false);
    if (env().getNumber('WEBGL_VERSION') === 1) {
        infLoc = gpgpu.getUniformLocation(webGLProgram, 'INFINITY', false);
    }
    const shouldThrow = false;
    for (const varName of program.variableNames) {
        uniformLocations[varName] =
            gpgpu.getUniformLocation(webGLProgram, varName, shouldThrow);
        uniformLocations[`offset${varName}`] =
            gpgpu.getUniformLocation(webGLProgram, `offset${varName}`, shouldThrow);
        if (program.enableShapeUniforms) {
            inShapesLocations[`${varName}Shape`] = gpgpu.getUniformLocation(webGLProgram, `${varName}Shape`, shouldThrow);
            inTexShapesLocations[`${varName}TexShape`] = gpgpu.getUniformLocation(webGLProgram, `${varName}TexShape`, shouldThrow);
        }
    }
    if (program.enableShapeUniforms) {
        outShapeLocation =
            gpgpu.getUniformLocation(webGLProgram, 'outShape', shouldThrow);
        outShapeStridesLocation =
            gpgpu.getUniformLocation(webGLProgram, 'outShapeStrides', shouldThrow);
        outTexShapeLocation =
            gpgpu.getUniformLocation(webGLProgram, 'outTexShape', shouldThrow);
    }
    if (program.customUniforms) {
        program.customUniforms.forEach((d, i) => {
            customUniformLocations[i] =
                gpgpu.getUniformLocation(webGLProgram, d.name, shouldThrow);
        });
    }
    return {
        uniformLocations,
        customUniformLocations,
        infLoc,
        nanLoc,
        inShapesLocations,
        inTexShapesLocations,
        outShapeLocation,
        outShapeStridesLocation,
        outTexShapeLocation
    };
}
/**
 * Verifies that the tensors a binary is being run with match the shapes it
 * was compiled for (both logical and texture shapes). Throws on mismatch.
 */
function validateBinaryAndProgram(shapeInfos, inputs) {
    if (shapeInfos.length !== inputs.length) {
        throw Error(`Binary was compiled with ${shapeInfos.length} inputs, but ` +
            `was executed with ${inputs.length} inputs`);
    }
    for (let i = 0; i < shapeInfos.length; i++) {
        const compiledInfo = shapeInfos[i];
        const input = inputs[i];
        const shapeA = compiledInfo.logicalShape;
        const shapeB = input.shape;
        if (!arraysEqual(shapeA, shapeB)) {
            throw Error(`Binary was compiled with different shapes than ` +
                `the current args. Shapes ${shapeA} and ${shapeB} must match`);
        }
        // The input is uploaded as uniform: no texture shape to compare.
        if (compiledInfo.isUniform && input.isUniform) {
            continue;
        }
        const texShapeA = compiledInfo.texShape;
        const texShapeB = input.isUniform ? null : input.texData.texShape;
        if (!arraysEqual(texShapeA, texShapeB)) {
            throw Error(`Binary was compiled with different texture shapes than the` +
                ` current args. Shape ${texShapeA} and ${texShapeB} must match`);
        }
    }
}
/**
 * Executes a compiled GPGPU binary: binds the output texture, activates the
 * program, uploads every uniform (special NAN/INFINITY, per-input values or
 * textures, shape uniforms, custom uniforms) and issues the draw call.
 *
 * The sequence is order-sensitive GL state manipulation: output texture and
 * program must be set before any uniform upload, and uniforms before
 * executeProgram().
 *
 * @param gpgpu The GPGPU context wrapper.
 * @param binary Result of compileProgram (program + uniform locations).
 * @param inputs Input tensors, ordered to match program.variableNames.
 * @param output Output tensor whose texture receives the result.
 * @param customUniformValues Values for program.customUniforms, same order.
 */
function runProgram(gpgpu, binary, inputs, output, customUniformValues) {
    // With shape uniforms one binary serves many shapes, so the baked-shape
    // validation only applies when shape uniforms are disabled.
    if (!binary.program.enableShapeUniforms) {
        validateBinaryAndProgram(binary.inShapeInfos, inputs);
        validateBinaryAndProgram([binary.outShapeInfo], [output]);
    }
    const outTex = output.texData.texture;
    const outTexShape = output.texData.texShape;
    if (output.texData.isPacked) {
        gpgpu.setOutputPackedMatrixTexture(outTex.texture, outTexShape[0], outTexShape[1]);
    }
    else {
        gpgpu.setOutputMatrixTexture(outTex.texture, outTexShape[0], outTexShape[1]);
    }
    gpgpu.setProgram(binary.webGLProgram);
    // Set special uniforms (NAN, INFINITY)
    if (env().getNumber('WEBGL_VERSION') === 1) {
        if (binary.infLoc !== null) {
            gpgpu.gl.uniform1f(binary.infLoc, Infinity);
        }
    }
    if (binary.nanLoc !== null) {
        gpgpu.gl.uniform1f(binary.nanLoc, NaN);
    }
    // Set user-defined inputs
    inputs.forEach((input, i) => {
        const varName = binary.program.variableNames[i];
        const varLoc = binary.uniformLocations[varName];
        const varOffsetLoc = binary.uniformLocations[`offset${varName}`];
        const varShapeLoc = binary.inShapesLocations[`${varName}Shape`];
        const varTexShapeLoc = binary.inTexShapesLocations[`${varName}TexShape`];
        if (varShapeLoc) {
            // Upload the (possibly squeezed) logical shape; must match the
            // shape variant the shader codegen decided to use.
            const { uniformShape } = getUniformInfoFromShape(binary.program.packedInputs, input.shape, input.texData.texShape);
            switch (uniformShape.length) {
                case 1:
                    gpgpu.gl.uniform1iv(varShapeLoc, new Int32Array(uniformShape));
                    break;
                case 2:
                    gpgpu.gl.uniform2iv(varShapeLoc, new Int32Array(uniformShape));
                    break;
                case 3:
                    gpgpu.gl.uniform3iv(varShapeLoc, new Int32Array(uniformShape));
                    break;
                case 4:
                    gpgpu.gl.uniform4iv(varShapeLoc, new Int32Array(uniformShape));
                    break;
                default:
                    break;
            }
        }
        if (varTexShapeLoc) {
            gpgpu.gl.uniform2i(varTexShapeLoc, input.texData.texShape[0], input.texData.texShape[1]);
        }
        if (varLoc == null) {
            // The compiler inferred that this variable is not used in this shader.
            return;
        }
        if (input.isUniform) {
            // Upload the values of the tensor as uniform.
            if (sizeFromShape(input.shape) < 2) {
                gpgpu.gl.uniform1f(varLoc, input.uniformValues[0]);
            }
            else {
                let vals = input.uniformValues;
                if (!(vals instanceof Float32Array)) {
                    vals = new Float32Array(vals);
                }
                gpgpu.gl.uniform1fv(varLoc, vals);
            }
            return;
        }
        // If the input was sliced, upload the flat offset index.
        if (input.texData.slice != null && varOffsetLoc != null) {
            gpgpu.gl.uniform1i(varOffsetLoc, input.texData.slice.flatOffset);
        }
        gpgpu.setInputMatrixTexture(input.texData.texture.texture, varLoc, i);
    });
    // Output-shape uniforms (only resolved when shape uniforms are enabled).
    const outShapeLoc = binary.outShapeLocation;
    if (outShapeLoc) {
        switch (output.shape.length) {
            case 1:
                gpgpu.gl.uniform1iv(outShapeLoc, new Int32Array(output.shape));
                break;
            case 2:
                gpgpu.gl.uniform2iv(outShapeLoc, new Int32Array(output.shape));
                break;
            case 3:
                gpgpu.gl.uniform3iv(outShapeLoc, new Int32Array(output.shape));
                break;
            case 4:
                gpgpu.gl.uniform4iv(outShapeLoc, new Int32Array(output.shape));
                break;
            default:
                break;
        }
    }
    if (binary.outShapeStridesLocation) {
        // A rank-N shape has N-1 strides, hence uniform(N-1)iv per rank.
        const strides = computeStrides(output.shape);
        switch (output.shape.length) {
            case 2:
                gpgpu.gl.uniform1iv(binary.outShapeStridesLocation, new Int32Array(strides));
                break;
            case 3:
                gpgpu.gl.uniform2iv(binary.outShapeStridesLocation, new Int32Array(strides));
                break;
            case 4:
                gpgpu.gl.uniform3iv(binary.outShapeStridesLocation, new Int32Array(strides));
                break;
            default:
                break;
        }
    }
    if (binary.outTexShapeLocation) {
        gpgpu.gl.uniform2i(binary.outTexShapeLocation, output.texData.texShape[0], output.texData.texShape[1]);
    }
    // Program-specific custom uniforms, dispatched by declared GLSL type.
    if (binary.program.customUniforms && customUniformValues) {
        binary.program.customUniforms.forEach((d, i) => {
            const customLoc = binary.customUniformLocations[i];
            const customValue = customUniformValues[i];
            if (d.type === 'float') {
                gpgpu.gl.uniform1fv(customLoc, customValue);
            }
            else if (d.type === 'vec2') {
                gpgpu.gl.uniform2fv(customLoc, customValue);
            }
            else if (d.type === 'vec3') {
                gpgpu.gl.uniform3fv(customLoc, customValue);
            }
            else if (d.type === 'vec4') {
                gpgpu.gl.uniform4fv(customLoc, customValue);
            }
            else if (d.type === 'int') {
                gpgpu.gl.uniform1iv(customLoc, customValue);
            }
            else if (d.type === 'ivec2') {
                gpgpu.gl.uniform2iv(customLoc, customValue);
            }
            else if (d.type === 'ivec3') {
                gpgpu.gl.uniform3iv(customLoc, customValue);
            }
            else if (d.type === 'ivec4') {
                gpgpu.gl.uniform4iv(customLoc, customValue);
            }
            else {
                throw Error(`uniform type ${d.type} is not supported yet.`);
            }
        });
    }
    gpgpu.executeProgram();
}
/**
 * Builds the cache key under which a compiled shader binary is stored.
 *
 * The key must capture every shape-dependent decision the shader codegen
 * bakes into the GLSL source: two invocations that produce the same key MUST
 * produce byte-identical shaders, otherwise a wrong binary is reused. With
 * shape uniforms enabled, only the codegen-branch decisions (not the raw
 * shapes) enter the key, which is what lets one binary serve many shapes.
 */
function makeShaderKey(program, inputs, output) {
    let keyInputs = '';
    inputs.concat(output).forEach(x => {
        const hasOffset = x.texData != null && x.texData.slice != null &&
            x.texData.slice.flatOffset > 0;
        // TODO: Remove the condition of !x.isUniform.
        if (program.enableShapeUniforms && !x.isUniform) {
            const xTexShape = x.texData.texShape;
            const { useSqueezeShape, uniformShape, keptDims } = getUniformInfoFromShape(program.packedInputs, x.shape, xTexShape);
            let rank1 = '', rank2 = '', rank34 = '';
            if (uniformShape.length === 1 && program.packedInputs) {
                const packedTexShape = [Math.ceil(xTexShape[0] / 2), Math.ceil(xTexShape[1] / 2)];
                rank1 = `${packedTexShape[0] > 1}_${packedTexShape[1] > 1}`;
            }
            else if (uniformShape.length === 2 && !program.packedInputs) {
                rank2 = `${uniformShape[0] > 1}_${uniformShape[1] > 1}`;
            }
            else if (uniformShape.length > 2 && !program.packedInputs) {
                const strides = computeStrides(uniformShape);
                rank34 = `${strides[0] === xTexShape[1]}_${strides[strides.length - 1] === xTexShape[1]}`;
            }
            const xRank = x.shape.length;
            const isLogicalShapTexShapeEqual = uniformShape.length === 2 && arraysEqual(x.shape, xTexShape);
            const isScalar = sizeFromShape(x.shape) === 1;
            const broadcastDims = getBroadcastDims(x.shape, output.shape);
            const isInOutTexShapeEqual = !program.packedInputs &&
                xRank === output.shape.length &&
                arraysEqual(xTexShape, output.texData.texShape);
            const isTexShapeGreaterThanOne = program.packedInputs || uniformShape.length > 2 ?
                '' :
                `${xTexShape[0] > 1}_${xTexShape[1] > 1}`;
            // These key components are needed due to shader_compiler is embedding
            // them in the shader.
            // |xRank| is used to determine the coords length. See
            // get[Packed]SamplerAtOutputCoords.
            // |isInOutTexShapeEqual| is used to determine whether going to an
            // optimization path in getSamplerAtOutputCoords.
            // |useSqueezeShape| is extracted from squeezeInputInfo of
            // getSampler[2|3|4]D/getPackedSampler3D.
            // |isScalar| is extracted from isInputScalar/isOutputScalar in
            // getPackedSamplerAtOutputCoords.
            // |broadcastDims| is extracted from get[Packed]SamplerAtOutputCoords.
            // |isLogicalShapTexShapeEqual| is used in
            // getOutput[Packed]2DCoords/get[Packed]Sampler2D.
            // |rank1| is used in getOutputPacked1DCoords.
            // |rank2| is used in getOutput2DCoords.
            // |rank34| is used in getSampler3D/getSampler4D.
            // |isTexShapeGreaterThanOne| are used in
            // getSampler[Scalar|1D|2D]/getOutput1DCoords.
            keyInputs += `${xRank}_${isInOutTexShapeEqual}_${useSqueezeShape ? keptDims : ''}_${uniformShape.length}_${isScalar}_${broadcastDims}_${isLogicalShapTexShapeEqual}_${rank1}_${rank2}_${rank34}_${isTexShapeGreaterThanOne}_${hasOffset}`;
        }
        else {
            // Without shape uniforms the shapes themselves are baked into the
            // shader, so they go into the key verbatim.
            const texShape = x.isUniform ? 'uniform' : x.texData.texShape;
            keyInputs += `${x.shape}_${texShape}_${hasOffset}`;
        }
    });
    const keyUserCode = program.userCode;
    let key = program.constructor.name;
    // Fast string concat. See https://jsperf.com/string-concatenation/14.
    key += '_' + keyInputs + '_' + keyUserCode +
        `${env().getNumber('WEBGL_VERSION')}`;
    return key;
}
/**
 * Whether a program of the given output rank should read shapes from shader
 * uniforms (allowing one compiled binary to serve multiple shapes).
 * Controlled by the WEBGL_USE_SHAPES_UNIFORMS flag.
 */
function useShapeUniforms(rank) {
    // TODO: Remove the limitation of rank <= 4.
    const MAX_SHAPE_UNIFORM_RANK = 4;
    const flagEnabled = env().getBool('WEBGL_USE_SHAPES_UNIFORMS');
    return flagEnabled && rank <= MAX_SHAPE_UNIFORM_RANK;
}
82092
82093 /**
82094 * @license
82095 * Copyright 2019 Google LLC. All Rights Reserved.
82096 * Licensed under the Apache License, Version 2.0 (the "License");
82097 * you may not use this file except in compliance with the License.
82098 * You may obtain a copy of the License at
82099 *
82100 * http://www.apache.org/licenses/LICENSE-2.0
82101 *
82102 * Unless required by applicable law or agreed to in writing, software
82103 * distributed under the License is distributed on an "AS IS" BASIS,
82104 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82105 * See the License for the specific language governing permissions and
82106 * limitations under the License.
82107 * =============================================================================
82108 */
/**
 * GPGPU program that reads an unpacked input ('A') through its logical
 * sampler and writes the values into a densely packed output texture
 * (4 consecutive flat-index values per texel; PackingScheme.DENSE).
 * The physical output texture shape is supplied via the `texShape` custom
 * uniform rather than baked into the shader.
 */
class DecodeMatrixProgram {
    constructor(outputShape) {
        this.variableNames = ['A'];
        this.packedInputs = false;
        this.packedOutput = true;
        this.outPackingScheme = PackingScheme.DENSE;
        this.customUniforms = [{ name: 'texShape', type: 'ivec2' }];
        const glsl = getGlslDifferences();
        this.outputShape = outputShape;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        // outCoordsFromFlatIndex maps a flat element index back to (r, c, d)
        // logical coordinates; the flat->logical math is emitted either against
        // shape uniforms or against the baked outputShape.
        this.userCode = `
      ivec3 outCoordsFromFlatIndex(int index) {
        ${this.enableShapeUniforms ?
            getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], outputShape) :
            getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape)}
        return ivec3(r, c, d);
      }

      void main() {
        ivec2 resTexRC = ivec2(resultUV.yx * vec2(texShape[0], texShape[1]));
        int index = 4 * (resTexRC.x * texShape[1] + resTexRC.y);

        vec4 result = vec4(0.);

        for (int i=0; i<4; i++) {
          int flatIndex = index + i;
          ivec3 rc = outCoordsFromFlatIndex(flatIndex);
          result[i] = getA(rc.x, rc.y, rc.z);
        }

        ${glsl.output} = result;
      }
    `;
    }
}
82144
82145 /**
82146 * @license
82147 * Copyright 2019 Google LLC. All Rights Reserved.
82148 * Licensed under the Apache License, Version 2.0 (the "License");
82149 * you may not use this file except in compliance with the License.
82150 * You may obtain a copy of the License at
82151 *
82152 * http://www.apache.org/licenses/LICENSE-2.0
82153 *
82154 * Unless required by applicable law or agreed to in writing, software
82155 * distributed under the License is distributed on an "AS IS" BASIS,
82156 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82157 * See the License for the specific language governing permissions and
82158 * limitations under the License.
82159 * =============================================================================
82160 */
    /**
     * Same as DecodeMatrixProgram, but for a packed input texture: each source
     * value is extracted from the appropriate channel of a packed texel via
     * getChannel before being written to the dense output.
     */
    class DecodeMatrixPackedProgram {
        constructor(outputShape) {
            this.variableNames = ['A'];
            this.packedInputs = true;
            this.packedOutput = true;
            // Output texels hold 4 consecutive flat values (dense packing).
            this.outPackingScheme = PackingScheme.DENSE;
            // The physical texture shape is supplied at dispatch time.
            this.customUniforms = [{ name: 'texShape', type: 'ivec2' }];
            const glsl = getGlslDifferences();
            this.outputShape = outputShape;
            this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
            this.userCode = `
    ivec3 outCoordsFromFlatIndex(int index) {
      ${this.enableShapeUniforms ?
                getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], outputShape) :
                getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape)}
      return ivec3(r, c, d);
    }

    void main() {
      ivec2 resTexRC = ivec2(resultUV.yx * vec2(texShape[0], texShape[1]));
      int index = 4 * (resTexRC.x * texShape[1] + resTexRC.y);

      vec4 result = vec4(0.);

      for (int i=0; i<4; i++) {
        int flatIndex = index + i;
        ivec3 rc = outCoordsFromFlatIndex(flatIndex);
        result[i] = getChannel(getA(rc.x, rc.y, rc.z), vec2(rc.y, rc.z));
      }

      ${glsl.output} = result;
    }
  `;
        }
    }
82196
82197 /**
82198 * @license
82199 * Copyright 2018 Google LLC. All Rights Reserved.
82200 * Licensed under the Apache License, Version 2.0 (the "License");
82201 * you may not use this file except in compliance with the License.
82202 * You may obtain a copy of the License at
82203 *
82204 * http://www.apache.org/licenses/LICENSE-2.0
82205 *
82206 * Unless required by applicable law or agreed to in writing, software
82207 * distributed under the License is distributed on an "AS IS" BASIS,
82208 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82209 * See the License for the specific language governing permissions and
82210 * limitations under the License.
82211 * =============================================================================
82212 */
    /**
     * GPGPU program that encodes each float output value into four unsigned
     * bytes (via ENCODE_FLOAT_SNIPPET) so the result can be read back with
     * readPixels on contexts that cannot read float pixels directly. The
     * output texture is marked for DOWNLOAD usage.
     */
    class EncodeFloatProgram {
        constructor(outputShape) {
            this.variableNames = ['A'];
            this.outTexUsage = TextureUsage.DOWNLOAD;
            const glsl = getGlslDifferences();
            this.outputShape = outputShape;
            this.userCode = `
    ${ENCODE_FLOAT_SNIPPET}

    void main() {
      float x = getAAtOutCoords();
      ${glsl.output} = encode_float(x);
    }
  `;
        }
    }
82229
82230 /**
82231 * @license
82232 * Copyright 2018 Google LLC. All Rights Reserved.
82233 * Licensed under the Apache License, Version 2.0 (the "License");
82234 * you may not use this file except in compliance with the License.
82235 * You may obtain a copy of the License at
82236 *
82237 * http://www.apache.org/licenses/LICENSE-2.0
82238 *
82239 * Unless required by applicable law or agreed to in writing, software
82240 * distributed under the License is distributed on an "AS IS" BASIS,
82241 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82242 * See the License for the specific language governing permissions and
82243 * limitations under the License.
82244 * =============================================================================
82245 */
    /**
     * Packed-input variant of EncodeFloatProgram: selects the channel for the
     * current output coordinate from the packed texel, then byte-encodes it
     * for readPixels download. Output is unpacked (one encoded value/texel).
     */
    class EncodeFloatPackedProgram {
        constructor(outputShape) {
            this.variableNames = ['A'];
            this.packedInputs = true;
            this.packedOutput = false;
            this.outTexUsage = TextureUsage.DOWNLOAD;
            const glsl = getGlslDifferences();
            this.outputShape = outputShape;
            this.userCode = `
    ${ENCODE_FLOAT_SNIPPET}

    void main() {
      ivec3 coords = getOutputCoords();
      float x = getChannel(getAAtOutCoords(), vec2(coords.y, coords.z));
      ${glsl.output} = encode_float(x);
    }
  `;
        }
    }
82265
82266 /**
82267 * @license
82268 * Copyright 2018 Google LLC. All Rights Reserved.
82269 * Licensed under the Apache License, Version 2.0 (the "License");
82270 * you may not use this file except in compliance with the License.
82271 * You may obtain a copy of the License at
82272 *
82273 * http://www.apache.org/licenses/LICENSE-2.0
82274 *
82275 * Unless required by applicable law or agreed to in writing, software
82276 * distributed under the License is distributed on an "AS IS" BASIS,
82277 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82278 * See the License for the specific language governing permissions and
82279 * limitations under the License.
82280 * =============================================================================
82281 */
    /**
     * GPGPU program that converts a flat, 4-values-per-texel input texture
     * into the backend's standard unpacked layout: for each logical output
     * coordinate it computes the flat index, fetches the texel that holds it,
     * and selects the component at (flatIndex mod 4).
     *
     * When `inputIsUnsignedByte` is true the sampled value is in [0, 1]
     * (normalized bytes) and is rescaled back to its 0-255 byte value.
     */
    class EncodeMatrixProgram {
        constructor(outputShape, inputIsUnsignedByte = false) {
            this.variableNames = ['A'];
            // The physical input texture shape is supplied at dispatch time.
            this.customUniforms = [{ name: 'texShape', type: 'ivec2' }];
            const glsl = getGlslDifferences();
            this.outputShape = outputShape;
            this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
            let output = `result`;
            if (inputIsUnsignedByte) {
                // Undo the [0,1] normalization applied to UNSIGNED_BYTE textures.
                output = `floor(result * 255. + 0.5)`;
            }
            this.userCode = `
    ${this.enableShapeUniforms ? getFlatIndexFrom3DOutput() :
                getFlatIndexFrom3D(outputShape)}

    void main() {
      ivec3 coords = getOutputCoords();

      int flatIndex = getFlatIndex(coords);
      int offset = imod(flatIndex, 4);

      flatIndex = idiv(flatIndex, 4, 1.);

      int r = flatIndex / texShape[1];
      int c = imod(flatIndex, texShape[1]);
      vec2 uv = (vec2(c, r) + halfCR) / vec2(texShape[1], texShape[0]);
      vec4 values = ${glsl.texture2D}(A, uv);

      float result;

      if(offset == 0) {
        result = values[0];
      } else if(offset == 1) {
        result = values[1];
      } else if(offset == 2) {
        result = values[2];
      } else {
        result = values[3];
      }

      ${glsl.output} = vec4(${output}, 0., 0., 0.);
    }
  `;
        }
    }
82327
82328 /**
82329 * @license
82330 * Copyright 2018 Google LLC. All Rights Reserved.
82331 * Licensed under the Apache License, Version 2.0 (the "License");
82332 * you may not use this file except in compliance with the License.
82333 * You may obtain a copy of the License at
82334 *
82335 * http://www.apache.org/licenses/LICENSE-2.0
82336 *
82337 * Unless required by applicable law or agreed to in writing, software
82338 * distributed under the License is distributed on an "AS IS" BASIS,
82339 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82340 * See the License for the specific language governing permissions and
82341 * limitations under the License.
82342 * =============================================================================
82343 */
82344 /*
82345 This is how the shader encodes a tensor with shape = [2, 3, 5]
82346 (indices are [batch, row, col]).
82347
82348 000|001 002|003 004|xxx 020|021 022|023 024|xxx
82349 ------- ------- ------- ------- ------- -------
82350 010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx
82351
82352 100|101 102|103 104|xxx 120|121 122|123 124|xxx
82353 ------- ------- ------- ------- ------- -------
82354 110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx
82355
82356 Single texels contain only values from the same batch, and from adjacent rows
82357 and columns.
82358 */
    /**
     * Packed-output variant of EncodeMatrixProgram (see the layout diagram
     * above). Each output texel holds a 2x2 neighborhood of the logical
     * matrix, so the generated shader contains one unrolled fetch per channel
     * (row/col offsets 0 and 1), each guarded against the shape boundary.
     *
     * When `inputIsUnsignedByte` is true the sampled values are normalized
     * bytes and are rescaled back to 0-255 before output.
     */
    class EncodeMatrixPackedProgram {
        constructor(outputShape, inputIsUnsignedByte = false) {
            this.variableNames = ['A'];
            this.packedInputs = false;
            this.packedOutput = true;
            // The physical input texture shape is supplied at dispatch time.
            this.customUniforms = [{ name: 'texShape', type: 'ivec2' }];
            const glsl = getGlslDifferences();
            this.outputShape = outputShape;
            this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
            let mainLoop = '';
            let output = 'result';
            if (inputIsUnsignedByte) {
                // Undo the [0,1] normalization applied to UNSIGNED_BYTE textures.
                output = 'floor(result * 255. + 0.5)';
            }
            // Unroll the 2x2 neighborhood: channel = row * 2 + col picks the
            // component of the packed output texel being filled.
            for (let row = 0; row <= 1; row++) {
                for (let col = 0; col <= 1; col++) {
                    const channel = row * 2 + col;
                    mainLoop += `
          localCoords = coords;
          if(localCoords[2] + ${col} < ${this.enableShapeUniforms ? 'outShape[2]' : `${outputShape[2]}`}) {
            localCoords[2] += ${col};
            if (localCoords[1] + ${row} < ${this.enableShapeUniforms ? 'outShape[1]' : `${outputShape[1]}`}) {
              localCoords[1] += ${row};

              flatIndex = getFlatIndex(localCoords);
              offset = imod(flatIndex, 4);

              flatIndex = idiv(flatIndex, 4, 1.);

              int r = flatIndex / texShape[1];
              int c = imod(flatIndex, texShape[1]);
              vec2 uv = (vec2(c, r) + halfCR) / vec2(texShape[1], texShape[0]);
              values = ${glsl.texture2D}(A, uv);

              if (offset == 0) {
                result[${channel}] = values[0];
              } else if (offset == 1) {
                result[${channel}] = values[1];
              } else if (offset == 2) {
                result[${channel}] = values[2];
              } else {
                result[${channel}] = values[3];
              }
            }
          }
        `;
                }
            }
            this.userCode = `
    ${this.enableShapeUniforms ? getFlatIndexFrom3DOutput() :
                getFlatIndexFrom3D(outputShape)}

    void main() {
      ivec3 coords = getOutputCoords();

      vec4 result = vec4(0.);
      int flatIndex, r, c, offset;
      ivec3 localCoords;
      vec2 uv;
      vec4 values;

      ${mainLoop}

      ${glsl.output} = ${output};
    }
  `;
        }
    }
82427
82428 /**
82429 * @license
82430 * Copyright 2017 Google LLC. All Rights Reserved.
82431 * Licensed under the Apache License, Version 2.0 (the "License");
82432 * you may not use this file except in compliance with the License.
82433 * You may obtain a copy of the License at
82434 *
82435 * http://www.apache.org/licenses/LICENSE-2.0
82436 *
82437 * Unless required by applicable law or agreed to in writing, software
82438 * distributed under the License is distributed on an "AS IS" BASIS,
82439 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82440 * See the License for the specific language governing permissions and
82441 * limitations under the License.
82442 * =============================================================================
82443 */
    /**
     * Creates the pass-through vertex shader shared by all GPGPU programs:
     * emits the quad's clip-space position unchanged and forwards the uv
     * attribute to the fragment stage as `resultUV`.
     */
    function createVertexShader$1(gl) {
        const glsl = getGlslDifferences();
        const vertexShaderSource = `${glsl.version}
    precision highp float;
    ${glsl.attribute} vec3 clipSpacePos;
    ${glsl.attribute} vec2 uv;
    ${glsl.varyingVs} vec2 resultUV;

    void main() {
      gl_Position = vec4(clipSpacePos, 1);
      resultUV = uv;
    }`;
        return createVertexShader(gl, vertexShaderSource);
    }
82458 function createVertexBuffer(gl) {
82459 // [x y z u v] * [upper-left, lower-left, upper-right, lower-right]
82460 const vertexArray = new Float32Array([-1, 1, 0, 0, 1, -1, -1, 0, 0, 0, 1, 1, 0, 1, 1, 1, -1, 0, 1, 0]);
82461 return createStaticVertexBuffer(gl, vertexArray);
82462 }
82463 function createIndexBuffer(gl) {
82464 // OpenGL (and WebGL) have "CCW == front" winding
82465 const triangleVertexIndices = new Uint16Array([0, 1, 2, 2, 1, 3]);
82466 return createStaticIndexBuffer(gl, triangleVertexIndices);
82467 }
    /**
     * Allocates a (width x height) 2D texture configured for exact texel
     * fetches: NEAREST filtering and CLAMP_TO_EDGE wrapping. WebGL1 defines
     * the image with a null payload via texImage2D; WebGL2 allocates
     * immutable storage via texStorage2D (contents uploaded later).
     * Returns the texture together with its shape as [height, width].
     */
    function createAndConfigureTexture(gl, width, height, internalFormat, textureFormat, textureType) {
        validateTextureSize(width, height);
        const texture = createTexture(gl);
        const tex2d = gl.TEXTURE_2D;
        callAndCheck(gl, () => gl.bindTexture(tex2d, texture));
        callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE));
        callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE));
        callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_MIN_FILTER, gl.NEAREST));
        callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_MAG_FILTER, gl.NEAREST));
        if (env().getNumber('WEBGL_VERSION') === 1) {
            callAndCheck(gl, () => gl.texImage2D(tex2d, 0, internalFormat, width, height, 0, textureFormat, textureType, null));
        }
        else {
            // WebGL2: single-level immutable storage.
            callAndCheck(gl, () => gl
                .texStorage2D(tex2d, 1, internalFormat, width, height));
        }
        callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
        return { texture, texShape: [height, width] };
    }
82487 function getInternalFormatForFloat32MatrixTexture(textureConfig) {
82488 return textureConfig.internalFormatFloat;
82489 }
82490 function createFloat32MatrixTexture(gl, rows, columns, textureConfig) {
82491 const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
82492 return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat32MatrixTexture(textureConfig), textureConfig.textureFormatFloat, gl.FLOAT);
82493 }
82494 function getInternalFormatForFloat16MatrixTexture(textureConfig) {
82495 return textureConfig.internalFormatHalfFloat;
82496 }
82497 function createFloat16MatrixTexture(gl, rows, columns, textureConfig) {
82498 const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
82499 return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat16MatrixTexture(textureConfig), textureConfig.textureFormatFloat, textureConfig.textureTypeHalfFloat);
82500 }
82501 function getInternalFormatForUnsignedBytesMatrixTexture(textureConfig) {
82502 return textureConfig.downloadTextureFormat;
82503 }
82504 function createUnsignedBytesMatrixTexture(gl, rows, columns, textureConfig) {
82505 const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
82506 return createAndConfigureTexture(gl, width, height, getInternalFormatForUnsignedBytesMatrixTexture(textureConfig), gl.RGBA, gl.UNSIGNED_BYTE);
82507 }
82508 function getInternalFormatForPackedMatrixTexture(textureConfig) {
82509 return textureConfig.internalFormatPackedFloat;
82510 }
82511 function createPackedMatrixTexture(gl, rows, columns, textureConfig) {
82512 const [width, height] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
82513 return createAndConfigureTexture(gl, width, height, getInternalFormatForPackedMatrixTexture(textureConfig), gl.RGBA, gl.FLOAT);
82514 }
82515 function getInternalFormatForFloat16PackedMatrixTexture(textureConfig) {
82516 return textureConfig.internalFormatPackedHalfFloat;
82517 }
82518 function createFloat16PackedMatrixTexture(gl, rows, columns, textureConfig) {
82519 const [width, height] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
82520 return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat16PackedMatrixTexture(textureConfig), gl.RGBA, textureConfig.textureTypeHalfFloat);
82521 }
82522 function bindVertexProgramAttributeStreams(gl, program, vertexBuffer) {
82523 const posOffset = 0; // x is the first buffer element
82524 const uvOffset = 3 * 4; // uv comes after [x y z]
82525 const stride = (3 * 4) + (2 * 4); // xyz + uv, each entry is 4-byte float.
82526 callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer));
82527 const success = bindVertexBufferToProgramAttribute(gl, program, 'clipSpacePos', vertexBuffer, 3, stride, posOffset);
82528 return success &&
82529 bindVertexBufferToProgramAttribute(gl, program, 'uv', vertexBuffer, 2, stride, uvOffset);
82530 }
    /**
     * Uploads `data` into a dense texture (4 values per texel). Uint8Array
     * data is staged as RGBA bytes; any other array is staged as float32.
     * The staging array is zero-padded out to width*height*4 entries.
     * WebGL2 writes into existing (immutable) storage with texSubImage2D;
     * WebGL1 (re)defines the image with texImage2D.
     */
    function uploadDenseMatrixToTexture(gl, texture, width, height, data, textureConfig) {
        callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture));
        let dataForUpload, texelDataType, internalFormat;
        if (data instanceof Uint8Array) {
            dataForUpload = new Uint8Array(width * height * 4);
            texelDataType = gl.UNSIGNED_BYTE;
            internalFormat = gl.RGBA;
        }
        else {
            dataForUpload = new Float32Array(width * height * 4);
            texelDataType = gl.FLOAT;
            internalFormat = textureConfig.internalFormatPackedFloat;
        }
        // Copy the (possibly shorter) payload into the padded staging array.
        dataForUpload.set(data);
        if (env().getNumber('WEBGL_VERSION') === 2) {
            callAndCheck(gl, () => gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, width, height, gl.RGBA, texelDataType, dataForUpload));
        }
        else {
            callAndCheck(gl, () => gl.texImage2D(gl.TEXTURE_2D, 0, internalFormat, width, height, 0, gl.RGBA, texelDataType, dataForUpload));
        }
        callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
    }
    /**
     * Uploads pixel data into `texture`. Accepts either an ImageData-like
     * object (with a Uint8Array `data` field plus width/height) or a
     * TexImageSource (image/video/canvas element), using the matching
     * texImage2D/texSubImage2D overload. WebGL2 writes into existing
     * immutable storage; WebGL1 (re)defines the image.
     */
    function uploadPixelDataToTexture(gl, texture, pixels) {
        callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture));
        if (pixels.data instanceof Uint8Array) {
            // Raw RGBA bytes with explicit dimensions.
            if (env().getNumber('WEBGL_VERSION') === 2) {
                callAndCheck(gl, () => gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, pixels.width, pixels.height, gl.RGBA, gl.UNSIGNED_BYTE, pixels.data));
            }
            else {
                callAndCheck(gl, () => gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, pixels.width, pixels.height, 0, gl.RGBA, gl.UNSIGNED_BYTE, pixels.data));
            }
        }
        else {
            // TexImageSource overload: dimensions come from the source itself.
            if (env().getNumber('WEBGL_VERSION') === 2) {
                callAndCheck(gl, () => gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, gl.RGBA, gl.UNSIGNED_BYTE, pixels));
            }
            else {
                callAndCheck(gl, () => gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, pixels));
            }
        }
        callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
    }
    /**
     * Allocates a WebGL2 pixel-pack buffer sized for the bound output texture
     * and enqueues an async readPixels into it (offset 0 targets the bound
     * PIXEL_PACK_BUFFER rather than CPU memory). The caller later drains the
     * buffer with getBufferSubData once the GPU work completes.
     */
    function createBufferFromOutputTexture(gl2, rows, columns, textureConfig) {
        // Create and bind the buffer.
        const buffer = gl2.createBuffer();
        callAndCheck(gl2, () => gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer));
        // Initialize the buffer to the size of the texture in bytes.
        const bytesPerFloat = 4;
        const valuesPerTexel = 4;
        const bufferSizeBytes = bytesPerFloat * valuesPerTexel * rows * columns;
        callAndCheck(gl2, () => gl2.bufferData(gl2.PIXEL_PACK_BUFFER, bufferSizeBytes, gl2.STREAM_READ));
        // Enqueue a command on the GPU command queue to copy the texture into
        // the buffer.
        callAndCheck(gl2, () => gl2.readPixels(0, 0, columns, rows, gl2.RGBA, gl2.FLOAT, 0));
        callAndCheck(gl2, () => gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null));
        return buffer;
    }
82588 function downloadFloat32MatrixFromBuffer(gl, buffer, size) {
82589 const gl2 = gl;
82590 const downloadTarget = new Float32Array(size);
82591 gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer);
82592 gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget);
82593 gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null);
82594 return downloadTarget;
82595 }
82596 function downloadByteEncodedFloatMatrixFromOutputTexture(gl, rows, columns, textureConfig) {
82597 const [w, h] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
82598 const numChannels = 4;
82599 const downloadTarget = new Uint8Array(getUnpackedArraySizeFromMatrixSize(rows * columns, numChannels));
82600 callAndCheck(gl, () => gl.readPixels(0, 0, w, h, textureConfig.downloadTextureFormat, gl.UNSIGNED_BYTE, downloadTarget));
82601 // By wrapping the buffer in a Float32Array, we use native browser IEEE 754
82602 // decoding of the 4 bytes that back each 32 bit float.
82603 return new Float32Array(downloadTarget.buffer);
82604 }
82605 function downloadPackedMatrixFromBuffer(gl, buffer, batch, rows, cols, physicalRows, physicalCols, textureConfig) {
82606 const gl2 = gl;
82607 const downloadTarget = new Float32Array(getPackedRGBAArraySizeFromMatrixShape(physicalRows, physicalCols));
82608 gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer);
82609 gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget);
82610 gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null);
82611 return downloadTarget;
82612 }
82613 function downloadMatrixFromPackedOutputTexture(gl, physicalRows, physicalCols) {
82614 const packedRGBA = new Float32Array(physicalRows * physicalCols * 4);
82615 callAndCheck(gl, () => gl.readPixels(0, 0, physicalCols, physicalRows, gl.RGBA, gl.FLOAT, packedRGBA));
82616 return packedRGBA;
82617 }
82618
82619 /**
82620 * @license
82621 * Copyright 2017 Google LLC. All Rights Reserved.
82622 * Licensed under the Apache License, Version 2.0 (the "License");
82623 * you may not use this file except in compliance with the License.
82624 * You may obtain a copy of the License at
82625 *
82626 * http://www.apache.org/licenses/LICENSE-2.0
82627 *
82628 * Unless required by applicable law or agreed to in writing, software
82629 * distributed under the License is distributed on an "AS IS" BASIS,
82630 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82631 * See the License for the specific language governing permissions and
82632 * limitations under the License.
82633 * =============================================================================
82634 */
82635 class GPGPUContext {
        /**
         * Wraps (or lazily fetches) a WebGL context and resolves the
         * extensions required for float-texture rendering, then creates the
         * shared vertex/index buffers and framebuffer used by all programs.
         *
         * WebGL1 requires OES_texture_float (mandatory) plus optional
         * half-float texture/render extensions; WebGL2 supports float
         * textures natively but needs EXT_color_buffer_(half_)float to
         * render to them. Throws when neither float nor half-float
         * rendering is available, or when WEBGL_FORCE_F16_TEXTURES is set
         * but half floats are unsupported.
         *
         * @param gl Optional existing context; when omitted, one is obtained
         *     for the configured WEBGL_VERSION.
         */
        constructor(gl) {
            this.outputTexture = null;
            this.program = null;
            this.disposed = false;
            this.vertexAttrsAreBound = false;
            this.itemsToPoll = [];
            const glVersion = env().getNumber('WEBGL_VERSION');
            if (gl != null) {
                this.gl = gl;
                setWebGLContext(glVersion, gl);
            }
            else {
                this.gl = getWebGLContext(glVersion);
            }
            // WebGL 2.0 enables texture floats without an extension.
            let COLOR_BUFFER_FLOAT = 'WEBGL_color_buffer_float';
            const COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';
            // Optional: enables parallel (non-blocking) shader compilation.
            this.parallelCompilationExtension =
                this.gl.getExtension('KHR_parallel_shader_compile');
            if (env().getNumber('WEBGL_VERSION') === 1) {
                const TEXTURE_FLOAT = 'OES_texture_float';
                const TEXTURE_HALF_FLOAT = 'OES_texture_half_float';
                // Float textures are mandatory on WebGL1; throws if missing.
                this.textureFloatExtension =
                    getExtensionOrThrow(this.gl, TEXTURE_FLOAT);
                if (hasExtension(this.gl, TEXTURE_HALF_FLOAT)) {
                    this.textureHalfFloatExtension =
                        getExtensionOrThrow(this.gl, TEXTURE_HALF_FLOAT);
                }
                else if (env().get('WEBGL_FORCE_F16_TEXTURES')) {
                    throw new Error('GL context does not support half float textures, yet the ' +
                        'environment flag WEBGL_FORCE_F16_TEXTURES is set to true.');
                }
                this.colorBufferFloatExtension = this.gl.getExtension(COLOR_BUFFER_FLOAT);
                if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {
                    this.colorBufferHalfFloatExtension =
                        getExtensionOrThrow(this.gl, COLOR_BUFFER_HALF_FLOAT);
                }
                else if (env().get('WEBGL_FORCE_F16_TEXTURES')) {
                    throw new Error('GL context does not support color renderable half floats, yet ' +
                        'the environment flag WEBGL_FORCE_F16_TEXTURES is set to true.');
                }
            }
            else {
                // WebGL2 uses a differently named extension for float rendering.
                COLOR_BUFFER_FLOAT = 'EXT_color_buffer_float';
                if (hasExtension(this.gl, COLOR_BUFFER_FLOAT)) {
                    this.colorBufferFloatExtension =
                        this.gl.getExtension(COLOR_BUFFER_FLOAT);
                }
                else if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {
                    this.colorBufferHalfFloatExtension =
                        this.gl.getExtension(COLOR_BUFFER_HALF_FLOAT);
                }
                else {
                    throw new Error('GL context does not support color renderable floats');
                }
            }
            // Shared quad geometry and render target used by every program.
            this.vertexBuffer = createVertexBuffer(this.gl);
            this.indexBuffer = createIndexBuffer(this.gl);
            this.framebuffer = createFramebuffer(this.gl);
            this.textureConfig =
                getTextureConfig(this.gl, this.textureHalfFloatExtension);
        }
        /** Reflects the global DEBUG flag; re-queried on every access. */
        get debug() {
            return env().getBool('DEBUG');
        }
82701 dispose() {
82702 if (this.disposed) {
82703 return;
82704 }
82705 if (this.program != null) {
82706 console.warn('Disposing a GPGPUContext that still has a bound WebGLProgram.' +
82707 ' This is probably a resource leak, delete the program with ' +
82708 'GPGPUContext.deleteProgram before disposing.');
82709 }
82710 if (this.outputTexture != null) {
82711 console.warn('Disposing a GPGPUContext that still has a bound output matrix ' +
82712 'texture. This is probably a resource leak, delete the output ' +
82713 'matrix texture with GPGPUContext.deleteMatrixTexture before ' +
82714 'disposing.');
82715 }
82716 const gl = this.gl;
82717 callAndCheck(gl, () => gl.finish());
82718 callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, null));
82719 callAndCheck(gl, () => gl.deleteFramebuffer(this.framebuffer));
82720 callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, null));
82721 callAndCheck(gl, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, null));
82722 callAndCheck(gl, () => gl.deleteBuffer(this.indexBuffer));
82723 this.disposed = true;
82724 }
82725 createFloat32MatrixTexture(rows, columns) {
82726 this.throwIfDisposed();
82727 return createFloat32MatrixTexture(this.gl, rows, columns, this.textureConfig);
82728 }
82729 createFloat16MatrixTexture(rows, columns) {
82730 this.throwIfDisposed();
82731 return createFloat16MatrixTexture(this.gl, rows, columns, this.textureConfig);
82732 }
82733 createUnsignedBytesMatrixTexture(rows, columns) {
82734 this.throwIfDisposed();
82735 return createUnsignedBytesMatrixTexture(this.gl, rows, columns, this.textureConfig);
82736 }
82737 uploadPixelDataToTexture(texture, pixels) {
82738 this.throwIfDisposed();
82739 uploadPixelDataToTexture(this.gl, texture, pixels);
82740 }
82741 uploadDenseMatrixToTexture(texture, width, height, data) {
82742 this.throwIfDisposed();
82743 uploadDenseMatrixToTexture(this.gl, texture, width, height, data, this.textureConfig);
82744 }
82745 createFloat16PackedMatrixTexture(rows, columns) {
82746 this.throwIfDisposed();
82747 return createFloat16PackedMatrixTexture(this.gl, rows, columns, this.textureConfig);
82748 }
82749 createPackedMatrixTexture(rows, columns) {
82750 this.throwIfDisposed();
82751 return createPackedMatrixTexture(this.gl, rows, columns, this.textureConfig);
82752 }
82753 deleteMatrixTexture(texture) {
82754 this.throwIfDisposed();
82755 if (this.outputTexture === texture) {
82756 unbindColorTextureFromFramebuffer(this.gl, this.framebuffer);
82757 this.outputTexture = null;
82758 }
82759 callAndCheck(this.gl, () => this.gl.deleteTexture(texture));
82760 }
82761 downloadByteEncodedFloatMatrixFromOutputTexture(texture, rows, columns) {
82762 return this.downloadMatrixDriver(texture, () => downloadByteEncodedFloatMatrixFromOutputTexture(this.gl, rows, columns, this.textureConfig));
82763 }
82764 downloadPackedMatrixFromBuffer(buffer, batch, rows, columns, physicalRows, physicalCols) {
82765 return downloadPackedMatrixFromBuffer(this.gl, buffer, batch, rows, columns, physicalRows, physicalCols, this.textureConfig);
82766 }
82767 downloadFloat32MatrixFromBuffer(buffer, size) {
82768 return downloadFloat32MatrixFromBuffer(this.gl, buffer, size);
82769 }
82770 createBufferFromTexture(texture, rows, columns) {
82771 this.bindTextureToFrameBuffer(texture);
82772 const result = createBufferFromOutputTexture(this.gl, rows, columns, this.textureConfig);
82773 this.unbindTextureToFrameBuffer();
82774 return result;
82775 }
82776 createAndWaitForFence() {
82777 const fenceContext = this.createFence(this.gl);
82778 return this.pollFence(fenceContext);
82779 }
        /**
         * Creates a fence-like object `{ query, isFencePassed }` using the
         * best available mechanism, in order of preference:
         * 1. WebGL2 sync objects (fenceSync/clientWaitSync),
         * 2. disjoint timer queries (an empty begin/end pair whose
         *    availability signals completion),
         * 3. an always-true fallback (no fencing possible).
         */
        createFence(gl) {
            let query;
            let isFencePassed;
            if (env().getBool('WEBGL_FENCE_API_ENABLED')) {
                const gl2 = gl;
                const sync = gl2.fenceSync(gl2.SYNC_GPU_COMMANDS_COMPLETE, 0);
                // Flush so the fence command actually reaches the GPU queue.
                gl.flush();
                isFencePassed = () => {
                    // Non-blocking poll (timeout 0).
                    const status = gl2.clientWaitSync(sync, 0, 0);
                    return status === gl2.ALREADY_SIGNALED ||
                        status === gl2.CONDITION_SATISFIED;
                };
                query = sync;
            }
            else if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) {
                query = this.beginQuery();
                this.endQuery();
                isFencePassed = () => this.isQueryAvailable(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'));
            }
            else {
                // If we have no way to fence, return true immediately. This will fire in
                // WebGL 1.0 when there is no disjoint query timer. In this case, because
                // the fence passes immediately, we'll immediately ask for a download of
                // the texture, which will cause the UI thread to hang.
                isFencePassed = () => true;
            }
            return { query, isFencePassed };
        }
82808 downloadMatrixFromPackedTexture(texture, physicalRows, physicalCols) {
82809 return this.downloadMatrixDriver(texture, () => downloadMatrixFromPackedOutputTexture(this.gl, physicalRows, physicalCols));
82810 }
        /**
         * Links a new program from the given fragment shader and the shared
         * pass-through vertex shader (created lazily on first use). The
         * vertex attribute streams are bound only once, for the first program
         * created (guarded by `vertexAttrsAreBound`), which also makes that
         * program the active one as a side effect.
         */
        createProgram(fragmentShader) {
            this.throwIfDisposed();
            const gl = this.gl;
            if (this.vertexShader == null) {
                this.vertexShader = createVertexShader$1(gl);
            }
            const program = createProgram(gl);
            callAndCheck(gl, () => gl.attachShader(program, this.vertexShader));
            callAndCheck(gl, () => gl.attachShader(program, fragmentShader));
            linkProgram(gl, program);
            if (this.debug) {
                validateProgram(gl, program);
            }
            if (!this.vertexAttrsAreBound) {
                // Binding attributes requires an active program.
                this.setProgram(program);
                this.vertexAttrsAreBound = bindVertexProgramAttributeStreams(gl, this.program, this.vertexBuffer);
            }
            return program;
        }
82830 deleteProgram(program) {
82831 this.throwIfDisposed();
82832 if (program === this.program) {
82833 this.program = null;
82834 }
82835 if (program != null) {
82836 callAndCheck(this.gl, () => this.gl.deleteProgram(program));
82837 }
82838 }
82839 setProgram(program) {
82840 this.throwIfDisposed();
82841 this.program = program;
82842 if ((this.program != null) && this.debug) {
82843 validateProgram(this.gl, this.program);
82844 }
82845 callAndCheck(this.gl, () => this.gl.useProgram(program));
82846 }
82847 getUniformLocation(program, uniformName, shouldThrow = true) {
82848 this.throwIfDisposed();
82849 if (shouldThrow) {
82850 return getProgramUniformLocationOrThrow(this.gl, program, uniformName);
82851 }
82852 else {
82853 return getProgramUniformLocation(this.gl, program, uniformName);
82854 }
82855 }
82856 getAttributeLocation(program, attribute) {
82857 this.throwIfDisposed();
82858 return callAndCheck(this.gl, () => this.gl.getAttribLocation(program, attribute));
82859 }
82860 getUniformLocationNoThrow(program, uniformName) {
82861 this.throwIfDisposed();
82862 return this.gl.getUniformLocation(program, uniformName);
82863 }
82864 setInputMatrixTexture(inputMatrixTexture, uniformLocation, textureUnit) {
82865 this.throwIfDisposed();
82866 this.throwIfNoProgram();
82867 bindTextureToProgramUniformSampler(this.gl, inputMatrixTexture, uniformLocation, textureUnit);
82868 }
82869 setOutputMatrixTexture(outputMatrixTexture, rows, columns) {
82870 this.setOutputMatrixTextureDriver(outputMatrixTexture, columns, rows);
82871 }
82872 setOutputPackedMatrixTexture(outputPackedMatrixTexture, rows, columns) {
82873 this.throwIfDisposed();
82874 const [width, height] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
82875 this.setOutputMatrixTextureDriver(outputPackedMatrixTexture, width, height);
82876 }
82877 setOutputMatrixWriteRegion(startRow, numRows, startColumn, numColumns) {
82878 this.setOutputMatrixWriteRegionDriver(startColumn, startRow, numColumns, numRows);
82879 }
82880 setOutputPackedMatrixWriteRegion(startRow, numRows, startColumn, numColumns) {
82881 throw new Error('setOutputPackedMatrixWriteRegion not implemented.');
82882 }
        // Validates the current program (if any) and the framebuffer; used on
        // debug paths before issuing draw calls.
        debugValidate() {
            if (this.program != null) {
                validateProgram(this.gl, this.program);
            }
            validateFramebuffer(this.gl);
        }
82889 executeProgram() {
82890 this.throwIfDisposed();
82891 this.throwIfNoProgram();
82892 const gl = this.gl;
82893 if (this.debug) {
82894 this.debugValidate();
82895 }
82896 callAndCheck(gl, () => gl.drawElements(gl.TRIANGLES, 6, gl.UNSIGNED_SHORT, 0));
82897 }
        // Synchronously blocks until all submitted GL commands complete
        // (gl.finish); expensive, intended for tests/debugging.
        blockUntilAllProgramsCompleted() {
            this.throwIfDisposed();
            callAndCheck(this.gl, () => this.gl.finish());
        }
82902 getQueryTimerExtension() {
82903 if (this.disjointQueryTimerExtension == null) {
82904 this.disjointQueryTimerExtension =
82905 getExtensionOrThrow(this.gl, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2 ?
82906 'EXT_disjoint_timer_query_webgl2' :
82907 'EXT_disjoint_timer_query');
82908 }
82909 return this.disjointQueryTimerExtension;
82910 }
        // WebGL2 flavor accessor; same cached extension object underneath.
        getQueryTimerExtensionWebGL2() {
            return this.getQueryTimerExtension();
        }
        // WebGL1 flavor accessor; same cached extension object underneath.
        getQueryTimerExtensionWebGL1() {
            return this.getQueryTimerExtension();
        }
82917 beginQuery() {
82918 if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) {
82919 const gl2 = this.gl;
82920 const ext = this.getQueryTimerExtensionWebGL2();
82921 const query = gl2.createQuery();
82922 gl2.beginQuery(ext.TIME_ELAPSED_EXT, query);
82923 return query;
82924 }
82925 const ext = this.getQueryTimerExtensionWebGL1();
82926 const query = ext.createQueryEXT();
82927 ext.beginQueryEXT(ext.TIME_ELAPSED_EXT, query);
82928 return query;
82929 }
82930 endQuery() {
82931 if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) {
82932 const gl2 = this.gl;
82933 const ext = this.getQueryTimerExtensionWebGL2();
82934 gl2.endQuery(ext.TIME_ELAPSED_EXT);
82935 return;
82936 }
82937 const ext = this.getQueryTimerExtensionWebGL1();
82938 ext.endQueryEXT(ext.TIME_ELAPSED_EXT);
82939 }
82940 async waitForQueryAndGetTime(query) {
82941 await repeatedTry(() => this.disposed || // while testing contexts are created / disposed
82942 // in rapid succession, so without this check we
82943 // may poll for the query timer indefinitely
82944 this.isQueryAvailable(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')));
82945 return this.getQueryTime(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'));
82946 }
82947 getQueryTime(query, queryTimerVersion) {
82948 if (queryTimerVersion === 0) {
82949 return null;
82950 }
82951 if (queryTimerVersion === 2) {
82952 const gl2 = this.gl;
82953 const timeElapsedNanos = gl2.getQueryParameter(query, gl2.QUERY_RESULT);
82954 // Return milliseconds.
82955 return timeElapsedNanos / 1000000;
82956 }
82957 else {
82958 const ext = this.getQueryTimerExtensionWebGL1();
82959 const timeElapsedNanos = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_EXT);
82960 // Return milliseconds.
82961 return timeElapsedNanos / 1000000;
82962 }
82963 }
82964 isQueryAvailable(query, queryTimerVersion) {
82965 if (queryTimerVersion === 0) {
82966 return true;
82967 }
82968 if (queryTimerVersion === 2) {
82969 const gl2 = this.gl;
82970 const ext = this.getQueryTimerExtensionWebGL2();
82971 const available = gl2.getQueryParameter(query, gl2.QUERY_RESULT_AVAILABLE);
82972 if (this.disjoint == null) {
82973 this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT);
82974 }
82975 return available && !this.disjoint;
82976 }
82977 else {
82978 const ext = this.getQueryTimerExtensionWebGL1();
82979 const available = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_AVAILABLE_EXT);
82980 if (this.disjoint == null) {
82981 this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT);
82982 }
82983 return available && !this.disjoint;
82984 }
82985 }
        // Resolves once `fenceContext.isFencePassed()` reports true; the
        // check runs on the shared polling loop (see addItemToPoll).
        pollFence(fenceContext) {
            return new Promise(resolve => {
                this.addItemToPoll(() => fenceContext.isFencePassed(), () => resolve());
            });
        }
82991 pollItems() {
82992 // Find the last query that has finished.
82993 const index = linearSearchLastTrue(this.itemsToPoll.map(x => x.isDoneFn));
82994 for (let i = 0; i <= index; ++i) {
82995 const { resolveFn } = this.itemsToPoll[i];
82996 resolveFn();
82997 }
82998 this.itemsToPoll = this.itemsToPoll.slice(index + 1);
82999 }
83000 addItemToPoll(isDoneFn, resolveFn) {
83001 this.itemsToPoll.push({ isDoneFn, resolveFn });
83002 if (this.itemsToPoll.length > 1) {
83003 // We already have a running loop that polls.
83004 return;
83005 }
83006 // Start a new loop that polls.
83007 repeatedTry(() => {
83008 this.pollItems();
83009 // End the loop if no more items to poll.
83010 return this.itemsToPoll.length === 0;
83011 });
83012 }
83013 bindTextureToFrameBuffer(texture) {
83014 this.throwIfDisposed();
83015 bindColorTextureToFramebuffer(this.gl, texture, this.framebuffer);
83016 if (this.debug) {
83017 validateFramebuffer(this.gl);
83018 }
83019 }
83020 unbindTextureToFrameBuffer() {
83021 if (this.outputTexture != null) {
83022 bindColorTextureToFramebuffer(this.gl, this.outputTexture, this.framebuffer);
83023 if (this.debug) {
83024 validateFramebuffer(this.gl);
83025 }
83026 }
83027 else {
83028 unbindColorTextureFromFramebuffer(this.gl, this.framebuffer);
83029 }
83030 }
        // Runs `downloadAndDecode` with `texture` bound to the framebuffer,
        // then restores the previous binding, returning the decoded result.
        downloadMatrixDriver(texture, downloadAndDecode) {
            this.bindTextureToFrameBuffer(texture);
            const result = downloadAndDecode();
            this.unbindTextureToFrameBuffer();
            return result;
        }
83037 setOutputMatrixTextureDriver(outputMatrixTextureMaybePacked, width, height) {
83038 this.throwIfDisposed();
83039 const gl = this.gl;
83040 bindColorTextureToFramebuffer(gl, outputMatrixTextureMaybePacked, this.framebuffer);
83041 if (this.debug) {
83042 validateFramebuffer(gl);
83043 }
83044 this.outputTexture = outputMatrixTextureMaybePacked;
83045 callAndCheck(gl, () => gl.viewport(0, 0, width, height));
83046 callAndCheck(gl, () => gl.scissor(0, 0, width, height));
83047 }
        // Restricts subsequent draws to the given scissor rectangle, in texel
        // coordinates of the current output texture.
        setOutputMatrixWriteRegionDriver(x, y, width, height) {
            this.throwIfDisposed();
            callAndCheck(this.gl, () => this.gl.scissor(x, y, width, height));
        }
        // Guard used at the top of every public method.
        throwIfDisposed() {
            if (this.disposed) {
                throw new Error('Attempted to use disposed GPGPUContext.');
            }
        }
        // Guard for operations that require a program set via setProgram().
        throwIfNoProgram() {
            if (this.program == null) {
                throw new Error('No GPU program is currently set.');
            }
        }
83062 }
83063 /**
83064 * Finds the index of the last true element using linear search.
83065 * Note: We can't do binary search because Chrome expects us to explicitly
83066 * test all fences before download:
83067 * https://github.com/tensorflow/tfjs/issues/1145
83068 */
83069 function linearSearchLastTrue(arr) {
83070 let i = 0;
83071 for (; i < arr.length; ++i) {
83072 const isDone = arr[i]();
83073 if (!isDone) {
83074 break;
83075 }
83076 }
83077 return i - 1;
83078 }
83079
83080 /**
83081 * @license
83082 * Copyright 2020 Google LLC. All Rights Reserved.
83083 * Licensed under the Apache License, Version 2.0 (the "License");
83084 * you may not use this file except in compliance with the License.
83085 * You may obtain a copy of the License at
83086 *
83087 * http://www.apache.org/licenses/LICENSE-2.0
83088 *
83089 * Unless required by applicable law or agreed to in writing, software
83090 * distributed under the License is distributed on an "AS IS" BASIS,
83091 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
83092 * See the License for the specific language governing permissions and
83093 * limitations under the License.
83094 * =============================================================================
83095 */
    // Pull the kernel implementations shared by the CPU backend under
    // `*ImplCPU` aliases, used by the WebGL backend's CPU-forwarding paths.
    const { addImpl: addImplCPU, bincountImpl: bincountImplCPU, bincountReduceImpl: bincountReduceImplCPU, ceilImpl: ceilImplCPU, concatImpl: concatImplCPU, equalImpl: equalImplCPU, expImpl: expImplCPU, expm1Impl: expm1ImplCPU, floorImpl: floorImplCPU, gatherNdImpl: gatherNdImplCPU, gatherV2Impl: gatherV2ImplCPU, greaterImpl: greaterImplCPU, greaterEqualImpl: greaterEqualImplCPU, lessImpl: lessImplCPU, lessEqualImpl: lessEqualImplCPU, linSpaceImpl: linSpaceImplCPU, logImpl: logImplCPU, maxImpl: maxImplCPU, maximumImpl: maximumImplCPU, minimumImpl: minimumImplCPU, multiplyImpl: multiplyImplCPU, negImpl: negImplCPU, notEqualImpl: notEqualImplCPU, prodImpl: prodImplCPU, rangeImpl: rangeImplCPU, rsqrtImpl: rsqrtImplCPU, scatterImpl: scatterImplCPU, sigmoidImpl: sigmoidImplCPU, simpleAbsImpl: simpleAbsImplCPU, sliceImpl: sliceImplCPU, sparseFillEmptyRowsImpl: sparseFillEmptyRowsImplCPU, sparseReshapeImpl: sparseReshapeImplCPU, sparseSegmentReductionImpl: sparseSegmentReductionImplCPU, sqrtImpl: sqrtImplCPU, stridedSliceImpl: stridedSliceImplCPU, stringNGramsImpl: stringNGramsImplCPU, stringSplitImpl: stringSplitImplCPU, stringToHashBucketFastImpl: stringToHashBucketFastImplCPU, subImpl: subImplCPU, tileImpl: tileImplCPU, topKImpl: topKImplCPU, transposeImpl: transposeImplCPU, uniqueImpl: uniqueImplCPU, } = shared;
83097
83098 /**
83099 * @license
83100 * Copyright 2018 Google LLC. All Rights Reserved.
83101 * Licensed under the Apache License, Version 2.0 (the "License");
83102 * you may not use this file except in compliance with the License.
83103 * You may obtain a copy of the License at
83104 *
83105 * http://www.apache.org/licenses/LICENSE-2.0
83106 *
83107 * Unless required by applicable law or agreed to in writing, software
83108 * distributed under the License is distributed on an "AS IS" BASIS,
83109 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
83110 * See the License for the specific language governing permissions and
83111 * limitations under the License.
83112 * =============================================================================
83113 */
83114 function getVecChannels(name, rank) {
83115 return ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank).map(d => `${name}.${d}`);
83116 }
83117 function getChannels(name, rank) {
83118 if (rank === 1) {
83119 return [name];
83120 }
83121 return getVecChannels(name, rank);
83122 }
83123 function getSourceCoords(rank, dims) {
83124 if (rank === 1) {
83125 return 'rc';
83126 }
83127 let coords = '';
83128 for (let i = 0; i < rank; i++) {
83129 coords += dims[i];
83130 if (i < rank - 1) {
83131 coords += ',';
83132 }
83133 }
83134 return coords;
83135 }
83136
83137 /**
83138 * @license
83139 * Copyright 2018 Google LLC. All Rights Reserved.
83140 * Licensed under the Apache License, Version 2.0 (the "License");
83141 * you may not use this file except in compliance with the License.
83142 * You may obtain a copy of the License at
83143 *
83144 * http://www.apache.org/licenses/LICENSE-2.0
83145 *
83146 * Unless required by applicable law or agreed to in writing, software
83147 * distributed under the License is distributed on an "AS IS" BASIS,
83148 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
83149 * See the License for the specific language governing permissions and
83150 * limitations under the License.
83151 * =============================================================================
83152 */
    /**
     * GPGPU program that converts an unpacked texture (one value per texel)
     * into a packed texture (a 2x2 block of values per RGBA texel).
     * Inputs/outputs are limited to tensors reshaped to at most 3D.
     */
    class PackProgram {
        constructor(outputShape) {
            this.variableNames = ['A'];
            this.packedInputs = false;
            this.packedOutput = true;
            // Only input / output 3D tensors.
            this.outputShape = outputShape;
            this.rank = outputShape.length;
            this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
            if (this.rank === 0) {
                // Scalar case: pack the single value into the .x channel.
                this.userCode = `
        void main() {
          setOutput(vec4(getA(), 0., 0., 0.));
        }
      `;
            }
            else {
                const channels = getChannels('rc', this.rank);
                const dtype = getCoordsDataType(this.rank);
                const outOfBoundsCondition = this.getOutOfBoundsCondition(channels);
                const setup = this.getSetup(channels);
                const output = this.getOutput(channels);
                this.userCode = `
        void main() {
          ${dtype} rc = getOutputCoords();

          if(${outOfBoundsCondition}) {
            setOutput(vec4(0));
          } else {
            ${setup}

            setOutput(vec4(${output}));
          }
        }
      `;
            }
        }
        // Builds the four source-coordinate strings for the (r,c), (r,c+1),
        // (r+1,c), (r+1,c+1) corners of the packed block, prefixing any
        // leading (outer) dimensions for rank > 2.
        getSourceCoordsArr(dims) {
            const coords = [];
            for (let row = 0; row <= 1; row++) {
                for (let col = 0; col <= 1; col++) {
                    let coord = `${row === 0 ? 'r' : 'rp1'}, ${col === 0 ? 'c' : 'cp1'}`;
                    for (let d = 2; d < this.rank; d++) {
                        coord = `${dims[dims.length - 1 - d]},` + coord;
                    }
                    coords.push(coord);
                }
            }
            return coords;
        }
        // GLSL condition that is true when `rc` falls outside the logical
        // output shape (only the two innermost dims are checked for rank > 1).
        getOutOfBoundsCondition(dims) {
            if (this.rank === 1) {
                return `rc > ${this.enableShapeUniforms ? 'outShape' : this.outputShape[0]}`;
            }
            let cond = '';
            for (let i = this.rank - 2; i < this.rank; i++) {
                cond += `${dims[i]} >= ${this.enableShapeUniforms ? `outShape[${i}]` : this.outputShape[i]}`;
                if (i < this.rank - 1) {
                    cond += '||';
                }
            }
            return cond;
        }
        // GLSL that names the block's row/col (r, c), the neighboring row/col
        // (rp1, cp1), and edge flags used to zero-fill past the boundary.
        getSetup(dims) {
            if (this.rank === 1) {
                return '';
            }
            const innerDims = dims.slice(-2);
            const col = this.enableShapeUniforms ? `outShape[${this.rank} - 1]` :
                this.outputShape[this.rank - 1];
            const row = this.enableShapeUniforms ? `outShape[${this.rank} - 2]` :
                this.outputShape[this.rank - 2];
            return `
            int r = ${innerDims[0]};
            int c = ${innerDims[1]};
            int rp1 = r + 1;
            int cp1 = c + 1;

            bool cEdge = cp1 >= ${col};
            bool rEdge = rp1 >= ${row};
            `;
        }
        // The four vec4 components; corners past the tensor edge become 0.
        getOutput(dims) {
            const sourceCoords = this.getSourceCoordsArr(dims);
            if (this.rank === 1) {
                const outShape = this.enableShapeUniforms ? 'outShape' : this.outputShape[0];
                return `getA(rc), (rc + 1 >= ${outShape} ? 0. : getA(rc + 1)), 0, 0`;
            }
            return `getA(${sourceCoords[0]}),
            cEdge ? 0. : getA(${sourceCoords[1]}),
            rEdge ? 0. : getA(${sourceCoords[2]}),
            rEdge || cEdge ? 0. : getA(${sourceCoords[3]})`;
        }
    }
83247
83248 /**
83249 * @license
83250 * Copyright 2018 Google LLC. All Rights Reserved.
83251 * Licensed under the Apache License, Version 2.0 (the "License");
83252 * you may not use this file except in compliance with the License.
83253 * You may obtain a copy of the License at
83254 *
83255 * http://www.apache.org/licenses/LICENSE-2.0
83256 *
83257 * Unless required by applicable law or agreed to in writing, software
83258 * distributed under the License is distributed on an "AS IS" BASIS,
83259 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
83260 * See the License for the specific language governing permissions and
83261 * limitations under the License.
83262 * =============================================================================
83263 */
    /**
     * GPGPU program that reshapes a packed 3D tensor into another packed 3D
     * shape without unpacking: each of the four output channels is gathered
     * individually from the input's packed layout via the flat index.
     */
    class ReshapePackedProgram {
        constructor(outputShape, inputShape) {
            this.variableNames = ['A'];
            this.packedInputs = true;
            this.packedOutput = true;
            this.customUniforms = [{ name: 'inputShape', type: 'ivec3' }];
            this.outputShape = outputShape;
            this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
            let mainLoop = ``;
            // One iteration per output channel, covering the 2x2 block at
            // (y, z), (y, z+1), (y+1, z), (y+1, z+1) relative to `rc`.
            for (let i = 0; i < 4; i++) {
                let thisRC = `thisRC = rc;`;
                if (i % 2 === 1) {
                    thisRC += `thisRC.z += 1;`;
                }
                if (i > 1) {
                    thisRC += `thisRC.y += 1;`;
                }
                // Channels other than the first must be bounds-checked since
                // the block may straddle the tensor edge.
                mainLoop += `
        ${thisRC}
        ${i > 0 ? `if(thisRC.y < rows && thisRC.z < cols){` : ''}
          int flatIndex = getFlatIndex(thisRC);

          ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);
          vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));

          result[${i}] =
            getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);
        ${i > 0 ? '}' : ''}
      `;
            }
            this.userCode = `
      ${getReshapedInputCoords(inputShape, this.enableShapeUniforms)}
      ${this.enableShapeUniforms ? getFlatIndexFrom3DOutput() :
                getFlatIndexFrom3D(outputShape)}

      void main() {
        ivec3 rc = getOutputCoords();

        vec4 result = vec4(0.);

        ivec3 thisRC;
        int rows = ${this.enableShapeUniforms ? 'outShape[1]' : outputShape[1]};
        int cols = ${this.enableShapeUniforms ? 'outShape[2]' : outputShape[2]};

        ${mainLoop}

        setOutput(result);
      }
    `;
        }
    }
    /**
     * Emits a GLSL helper that converts a flat element index into (r, c, d)
     * coordinates of the 3D input, sourcing the shape either from the
     * `inputShape` uniform or from the compile-time `shape`.
     */
    function getReshapedInputCoords(shape, enableShapeUniforms) {
        const coordsFromIndexSnippet = enableShapeUniforms ?
            getLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], 'inputShape') :
            getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape);
        return `
    ivec3 inputCoordsFromReshapedOutCoords(int index) {
      ${coordsFromIndexSnippet}
      return ivec3(r, c, d);
    }
  `;
    }
83326
83327 /**
83328 * @license
83329 * Copyright 2017 Google LLC. All Rights Reserved.
83330 * Licensed under the Apache License, Version 2.0 (the "License");
83331 * you may not use this file except in compliance with the License.
83332 * You may obtain a copy of the License at
83333 *
83334 * http://www.apache.org/licenses/LICENSE-2.0
83335 *
83336 * Unless required by applicable law or agreed to in writing, software
83337 * distributed under the License is distributed on an "AS IS" BASIS,
83338 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
83339 * See the License for the specific language governing permissions and
83340 * limitations under the License.
83341 * =============================================================================
83342 */
    /**
     * Pools WebGL textures keyed by (shape, physical type, packing) so they
     * can be reused instead of re-allocated. Tracks used/free counts and
     * byte totals for logging and memory reporting.
     */
    class TextureManager {
        constructor(gpgpu) {
            this.gpgpu = gpgpu;
            this.numUsedTextures = 0;
            this.numFreeTextures = 0;
            this._numBytesAllocated = 0;
            this._numBytesFree = 0; // How many bytes that have been allocated
            // are available for reuse.
            this.freeTextures = {};
            this.logEnabled = false;
            this.usedTextures = {};
        }
        // Returns a texture suitable for `shapeRC`/`usage`, reusing a pooled
        // one when available, otherwise allocating a new one.
        acquireTexture(shapeRC, usage, isPacked) {
            const physicalTexType = getPhysicalFromLogicalTextureType(usage, isPacked);
            const shapeKey = getKeyFromTextureShape(shapeRC, physicalTexType, isPacked);
            if (!(shapeKey in this.freeTextures)) {
                this.freeTextures[shapeKey] = [];
            }
            if (!(shapeKey in this.usedTextures)) {
                this.usedTextures[shapeKey] = [];
            }
            const texBytes = computeBytes(shapeRC, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked);
            if (this.freeTextures[shapeKey].length > 0) {
                // Reuse: move a texture from the free pool to the used pool.
                this.numFreeTextures--;
                this.numUsedTextures++;
                this._numBytesFree -= texBytes;
                this.log();
                const newTexture = this.freeTextures[shapeKey].shift();
                this.usedTextures[shapeKey].push(newTexture);
                return newTexture;
            }
            // Nothing pooled for this key: allocate a fresh texture of the
            // requested physical type.
            let newTexture;
            if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT32) {
                newTexture = this.gpgpu.createPackedMatrixTexture(shapeRC[0], shapeRC[1]);
            }
            else if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT16) {
                newTexture =
                    this.gpgpu.createFloat16PackedMatrixTexture(shapeRC[0], shapeRC[1]);
            }
            else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT32) {
                newTexture =
                    this.gpgpu.createFloat32MatrixTexture(shapeRC[0], shapeRC[1]);
            }
            else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT16) {
                newTexture =
                    this.gpgpu.createFloat16MatrixTexture(shapeRC[0], shapeRC[1]);
            }
            else if (physicalTexType === PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE) {
                newTexture =
                    this.gpgpu.createUnsignedBytesMatrixTexture(shapeRC[0], shapeRC[1]);
            }
            this.usedTextures[shapeKey].push(newTexture);
            this.numUsedTextures++;
            this._numBytesAllocated += texBytes;
            this.log();
            return newTexture;
        }
        // Returns `texture` to the free pool, or deletes it outright once
        // total allocation exceeds WEBGL_DELETE_TEXTURE_THRESHOLD (a
        // threshold of -1 disables deletion). Throws if `texture` was not
        // acquired from this manager.
        releaseTexture(texture, shape, logicalTexType, isPacked) {
            if (this.freeTextures == null) {
                // Already disposed.
                return;
            }
            const physicalTexType = getPhysicalFromLogicalTextureType(logicalTexType, isPacked);
            const shapeKey = getKeyFromTextureShape(shape, physicalTexType, isPacked);
            if (!(shapeKey in this.freeTextures)) {
                this.freeTextures[shapeKey] = [];
            }
            const texBytes = computeBytes(shape, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked);
            const deleteTexThreshold = env().get('WEBGL_DELETE_TEXTURE_THRESHOLD');
            if (deleteTexThreshold !== -1 &&
                this._numBytesAllocated > deleteTexThreshold) {
                this.gpgpu.deleteMatrixTexture(texture.texture);
                this._numBytesAllocated -= texBytes;
            }
            else {
                this.freeTextures[shapeKey].push(texture);
                this.numFreeTextures++;
                this._numBytesFree += texBytes;
            }
            this.numUsedTextures--;
            const texList = this.usedTextures[shapeKey];
            const texIndex = texList.indexOf(texture);
            if (texIndex < 0) {
                throw new Error('Cannot release a texture that was never provided by this ' +
                    'texture manager');
            }
            texList.splice(texIndex, 1);
            this.log();
        }
        // Prints pool statistics when `logEnabled` is set.
        log() {
            if (!this.logEnabled) {
                return;
            }
            const total = this.numFreeTextures + this.numUsedTextures;
            console.log('Free/Used', `${this.numFreeTextures} / ${this.numUsedTextures}`, `(${total})`);
            const freeRatio = this._numBytesFree / this._numBytesAllocated;
            console.log(`Bytes allocated: ${this._numBytesAllocated}`);
            console.log(`Bytes unused: ${this._numBytesFree} (${Math.round(100 * freeRatio)}%)`);
        }
        get numBytesAllocated() {
            return this._numBytesAllocated;
        }
        get numBytesFree() {
            return this._numBytesFree;
        }
        getNumUsedTextures() {
            return this.numUsedTextures;
        }
        getNumFreeTextures() {
            return this.numFreeTextures;
        }
        // Deletes every pooled texture (used and free) and marks the manager
        // disposed by nulling the pools; safe to call more than once.
        dispose() {
            if (this.freeTextures == null) {
                // Already disposed.
                return;
            }
            for (const texShape in this.freeTextures) {
                this.freeTextures[texShape].forEach(tex => {
                    this.gpgpu.deleteMatrixTexture(tex.texture);
                });
            }
            for (const texShape in this.usedTextures) {
                this.usedTextures[texShape].forEach(tex => {
                    this.gpgpu.deleteMatrixTexture(tex.texture);
                });
            }
            this.freeTextures = null;
            this.usedTextures = null;
            this.numUsedTextures = 0;
            this.numFreeTextures = 0;
            this._numBytesAllocated = 0;
            this._numBytesFree = 0;
        }
    }
83477 function numBytesForInternalFormat(gl, internalFormat) {
83478 // tslint:disable-next-line:no-any
83479 const glany = gl;
83480 if (internalFormat === glany.R32F) {
83481 return 4;
83482 }
83483 else if (internalFormat === glany.R16F) {
83484 return 2;
83485 }
83486 else if (internalFormat === glany.RGBA32F) {
83487 return 16;
83488 }
83489 else if (internalFormat === gl.RGBA) {
83490 return 16;
83491 }
83492 else if (internalFormat === glany.RGBA16F) {
83493 return 8;
83494 }
83495 else if (internalFormat === glany.RGBA8) {
83496 return 4;
83497 }
83498 throw new Error(`Unknown internal format ${internalFormat}`);
83499 }
83500 function computeBytes(shape, physicalTexType, gl, textureConfig, isPacked) {
83501 // It is not possible to infer packed status from the texture type because
83502 // depending on the textureConfig, different texture types may resolve to the
83503 // same internal format (e.g. in WebGL1, the internal format for
83504 // UNPACKED_FLOAT16 textures is gl.RGBA). Therefore we pass in `isPacked`
83505 // explicitly.
83506 const internalFormat = internalFormatForPhysicalTexType(physicalTexType, textureConfig);
83507 let numElements;
83508 if (isPacked) {
83509 const [packedWidth, packedHeight] = getPackedMatrixTextureShapeWidthHeight(shape[0], shape[1]);
83510 numElements = packedWidth * packedHeight;
83511 }
83512 else {
83513 const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(shape[0], shape[1]);
83514 numElements = width * height;
83515 }
83516 const bytesPerElement = numBytesForInternalFormat(gl, internalFormat);
83517 return numElements * bytesPerElement;
83518 }
83519 function internalFormatForPhysicalTexType(physicalTexType, textureConfig) {
83520 switch (physicalTexType) {
83521 case PhysicalTextureType.PACKED_2X2_FLOAT32:
83522 return getInternalFormatForPackedMatrixTexture(textureConfig);
83523 case PhysicalTextureType.PACKED_2X2_FLOAT16:
83524 return getInternalFormatForFloat16PackedMatrixTexture(textureConfig);
83525 case PhysicalTextureType.UNPACKED_FLOAT32:
83526 return getInternalFormatForFloat32MatrixTexture(textureConfig);
83527 case PhysicalTextureType.UNPACKED_FLOAT16:
83528 return getInternalFormatForFloat16MatrixTexture(textureConfig);
83529 case PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE:
83530 return getInternalFormatForUnsignedBytesMatrixTexture(textureConfig);
83531 default:
83532 throw new Error(`Unknown physical texture type ${physicalTexType}`);
83533 }
83534 }
83535 function getPhysicalTextureForRendering(isPacked) {
83536 if (env().getBool('WEBGL_RENDER_FLOAT32_ENABLED')) {
83537 if (isPacked) {
83538 return PhysicalTextureType.PACKED_2X2_FLOAT32;
83539 }
83540 return PhysicalTextureType.UNPACKED_FLOAT32;
83541 }
83542 if (isPacked) {
83543 return PhysicalTextureType.PACKED_2X2_FLOAT16;
83544 }
83545 return PhysicalTextureType.UNPACKED_FLOAT16;
83546 }
83547 function getPhysicalFromLogicalTextureType(logicalTexType, isPacked) {
83548 if (logicalTexType === TextureUsage.UPLOAD) {
83549 return PhysicalTextureType.PACKED_2X2_FLOAT32;
83550 }
83551 else if (logicalTexType === TextureUsage.RENDER || logicalTexType == null) {
83552 return getPhysicalTextureForRendering(isPacked);
83553 }
83554 else if (logicalTexType === TextureUsage.DOWNLOAD ||
83555 logicalTexType === TextureUsage.PIXELS) {
83556 return PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE;
83557 }
83558 throw new Error(`Unknown logical texture type ${logicalTexType}`);
83559 }
83560 function getKeyFromTextureShape(shapeRowsCol, physicalTexType, isPacked) {
83561 return `${shapeRowsCol[0]}_${shapeRowsCol[1]}_${physicalTexType}_${isPacked}`;
83562 }
83563
83564 /**
83565 * @license
83566 * Copyright 2017 Google LLC. All Rights Reserved.
83567 * Licensed under the Apache License, Version 2.0 (the "License");
83568 * you may not use this file except in compliance with the License.
83569 * You may obtain a copy of the License at
83570 *
83571 * http://www.apache.org/licenses/LICENSE-2.0
83572 *
83573 * Unless required by applicable law or agreed to in writing, software
83574 * distributed under the License is distributed on an "AS IS" BASIS,
83575 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
83576 * See the License for the specific language governing permissions and
83577 * limitations under the License.
83578 * =============================================================================
83579 */
    /**
     * Generic unpacked element-wise unary op. `opSnippet` is the body of a
     * GLSL function mapping one float input to one float output.
     */
    class UnaryOpProgram {
        constructor(aShape, opSnippet) {
            this.variableNames = ['A'];
            this.outputShape = aShape;
            this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
            this.userCode = `
      float unaryOperation(float x) {
        ${opSnippet}
      }

      void main() {
        float x = getAAtOutCoords();
        float y = unaryOperation(x);

        setOutput(y);
      }
    `;
        }
    }
    // GLSL snippets used as `opSnippet` bodies for UnaryOpProgram.
    // The NaN guard propagates NaN inputs unchanged.
    const CHECK_NAN_SNIPPET = `if (isnan(x)) return x;`;
    const LINEAR = `return x;`;
    const ABS = `return abs(x);`;
    // Heaviside step; `alpha` is the value produced for x <= 0.
    function STEP(alpha = 0.0) {
        return CHECK_NAN_SNIPPET + `
    return x > 0.0 ? 1.0 : float(${alpha});
  `;
    }
    const ELU$1 = `return (x >= 0.0) ? x : (exp(x) - 1.0);`;
    const RELU = CHECK_NAN_SNIPPET + `
  return (x < 0.0) ? 0.0 : x;
`;
    const RELU6 = CHECK_NAN_SNIPPET + `
  return (x < 0.0) ? 0.0 : min(6.0, x);
`;
    const CLONE = 'return x;';
    const SIGMOID = `return 1.0 / (1.0 + exp(-1.0 * x));`;
83616
83617 /**
83618 * @license
83619 * Copyright 2018 Google LLC. All Rights Reserved.
83620 * Licensed under the Apache License, Version 2.0 (the "License");
83621 * you may not use this file except in compliance with the License.
83622 * You may obtain a copy of the License at
83623 *
83624 * http://www.apache.org/licenses/LICENSE-2.0
83625 *
83626 * Unless required by applicable law or agreed to in writing, software
83627 * distributed under the License is distributed on an "AS IS" BASIS,
83628 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
83629 * See the License for the specific language governing permissions and
83630 * limitations under the License.
83631 * =============================================================================
83632 */
    // GLSL snippets used as `opSnippet` bodies for UnaryOpPackedProgram;
    // each operates component-wise on a vec4 (a packed 2x2 block).
    const LINEAR$1 = `return x;`;
    const ELU$2 = `
  vec4 result;

  result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);
  result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);
  result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);
  result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);

  return result;
`;
    // RELU / RELU6 preserve NaN components instead of clamping them.
    const RELU$1 = `
  vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));
  bvec4 isNaN = isnan(x);

  result.r = isNaN.r ? x.r : result.r;
  result.g = isNaN.g ? x.g : result.g;
  result.b = isNaN.b ? x.b : result.b;
  result.a = isNaN.a ? x.a : result.a;

  return result;
`;
    const RELU6$1 = `
  vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0)));
  bvec4 isNaN = isnan(x);

  result.r = isNaN.r ? x.r : result.r;
  result.g = isNaN.g ? x.g : result.g;
  result.b = isNaN.b ? x.b : result.b;
  result.a = isNaN.a ? x.a : result.a;

  return result;
`;
    const SIGMOID$1 = `return 1.0 / (1.0 + exp(-1.0 * x));`;
    /**
     * Generic packed element-wise unary op. `opSnippet` is the body of a
     * GLSL function mapping one vec4 (2x2 block of values) to one vec4.
     */
    class UnaryOpPackedProgram {
        constructor(aShape, opSnippet) {
            this.variableNames = ['A'];
            this.packedInputs = true;
            this.packedOutput = true;
            this.outputShape = aShape;
            this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
            this.userCode = `
      vec4 unaryOperation(vec4 x) {
        ${opSnippet}
      }

      void main() {
        vec4 x = getAAtOutCoords();
        vec4 y = unaryOperation(x);

        setOutput(y);
      }
    `;
        }
    }
83688
83689 /**
83690 * @license
83691 * Copyright 2018 Google LLC. All Rights Reserved.
83692 * Licensed under the Apache License, Version 2.0 (the "License");
83693 * you may not use this file except in compliance with the License.
83694 * You may obtain a copy of the License at
83695 *
83696 * http://www.apache.org/licenses/LICENSE-2.0
83697 *
83698 * Unless required by applicable law or agreed to in writing, software
83699 * distributed under the License is distributed on an "AS IS" BASIS,
83700 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
83701 * See the License for the specific language governing permissions and
83702 * limitations under the License.
83703 * =============================================================================
83704 */
    /**
     * GPGPU program that converts a packed texture (2x2 values per RGBA
     * texel) back into an unpacked texture (one value per texel).
     */
    class UnpackProgram {
        constructor(outputShape) {
            this.variableNames = ['A'];
            this.packedInputs = true;
            this.packedOutput = false;
            this.outputShape = outputShape;
            this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
            const rank = outputShape.length;
            const channels = getChannels('rc', rank);
            const dtype = getCoordsDataType(rank);
            const sourceCoords = getSourceCoords(rank, channels);
            // The two innermost coordinates select which of the four packed
            // channels holds the requested element.
            const innerDims = channels.slice(-2);
            const coords = rank <= 1 ? 'rc' : `vec2(${innerDims.join(',')})`;
            this.userCode = `
      void main() {
        ${dtype} rc = getOutputCoords();
        vec4 packedInput = getA(${sourceCoords});

        setOutput(getChannel(packedInput, ${coords}));
      }
    `;
        }
    }
83728
83729 /**
83730 * @license
83731 * Copyright 2017 Google LLC. All Rights Reserved.
83732 * Licensed under the Apache License, Version 2.0 (the "License");
83733 * you may not use this file except in compliance with the License.
83734 * You may obtain a copy of the License at
83735 *
83736 * http://www.apache.org/licenses/LICENSE-2.0
83737 *
83738 * Unless required by applicable law or agreed to in writing, software
83739 * distributed under the License is distributed on an "AS IS" BASIS,
83740 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
83741 * See the License for the specific language governing permissions and
83742 * limitations under the License.
83743 * =============================================================================
83744 */
    // Local alias for the shared CPU `where` implementation (used by the
    // synchronous `where()` forwarding below).
    const whereImpl$2 = whereImpl;
    // Smallest representable positive values for each float precision;
    // returned by `MathBackendWebGL.epsilon()`.
    const EPSILON_FLOAT32$1 = 1e-7;
    const EPSILON_FLOAT16$1 = 1e-4;
83748 const binaryCaches = {};
83749 function getBinaryCache(webGLVersion) {
83750 if (webGLVersion in binaryCaches) {
83751 return binaryCaches[webGLVersion];
83752 }
83753 binaryCaches[webGLVersion] = {};
83754 return binaryCaches[webGLVersion];
83755 }
    // Empirically determined constant used to determine size threshold for handing
    // off execution to the CPU (see `shouldExecuteOnCPU`).
    const CPU_HANDOFF_SIZE_THRESHOLD = env().getNumber('CPU_HANDOFF_SIZE_THRESHOLD');
    // Empirically determined constant used to decide the number of MB on GPU
    // before we warn about high memory use. The MB are this constant * screen area
    // * dpi / 1024 / 1024.
    const BEFORE_PAGING_CONSTANT = 600;
83763 function numMBBeforeWarning() {
83764 if (env().global.screen == null) {
83765 return 1024; // 1 GB.
83766 }
83767 return (env().global.screen.height * env().global.screen.width *
83768 window.devicePixelRatio) *
83769 BEFORE_PAGING_CONSTANT / 1024 / 1024;
83770 }
83771 class MathBackendWebGL extends KernelBackend {
        /**
         * Creates the WebGL backend.
         * @param gpuResource Optional. Either an existing GPGPUContext to reuse,
         *     or a resource (presumably a canvas/context descriptor — confirm at
         *     `getWebGLContext` call sites) from which a GL context is created.
         *     When omitted, a context is created locally and the shader binary
         *     cache is shared per WebGL version.
         * @throws Error when WebGL is unavailable on this device.
         */
        constructor(gpuResource) {
            super();
            // Maps data ids that have a pending read operation, to list of subscribers.
            this.pendingRead = new WeakMap();
            // List of data ids that are scheduled for disposal, but are waiting on a
            // pending read operation.
            this.pendingDisposal = new WeakSet();
            // Used to count the number of 'shallow' sliced tensors that point to the
            // same data id.
            this.dataRefCount = new WeakMap();
            this.numBytesInGPU = 0;
            // Accumulated time spent (including blocking) in uploading data to webgl.
            this.uploadWaitMs = 0;
            // Accumulated time spent (including blocking) in downloading data from webgl.
            this.downloadWaitMs = 0;
            // Record the last manual GL flush time.
            this.lastGlFlushTime = 0;
            this.warnedAboutMemory = false;
            this.pendingDeletes = 0;
            this.disposed = false;
            if (!env().getBool('HAS_WEBGL')) {
                throw new Error('WebGL is not supported on this device');
            }
            let newGPGPU;
            if (gpuResource != null) {
                if (gpuResource instanceof GPGPUContext) {
                    newGPGPU = gpuResource;
                }
                else {
                    const gl = getWebGLContext(env().getNumber('WEBGL_VERSION'), gpuResource);
                    newGPGPU = new GPGPUContext(gl);
                }
                // Caller-provided GPU resources get a private binary cache and are
                // not disposed by this backend (see `dispose`).
                this.binaryCache = {};
                this.gpgpuCreatedLocally = false;
            }
            else {
                const gl = getWebGLContext(env().getNumber('WEBGL_VERSION'));
                newGPGPU = new GPGPUContext(gl);
                // Locally created contexts share the per-version binary cache.
                this.binaryCache = getBinaryCache(env().getNumber('WEBGL_VERSION'));
                this.gpgpuCreatedLocally = true;
            }
            this.gpgpu = newGPGPU;
            this.canvas = this.gpgpu.gl.canvas;
            this.textureManager = new TextureManager(this.gpgpu);
            this.numMBBeforeWarning = numMBBeforeWarning();
            this.texData = new DataStorage(this, engine());
        }
        /** Returns the next unique data id from the class-level counter. */
        nextDataId() {
            return MathBackendWebGL.nextDataId++;
        }
83822 numDataIds() {
83823 return this.texData.numDataIds() - this.pendingDeletes;
83824 }
83825 write(values, shape, dtype) {
83826 if (env().getBool('WEBGL_CHECK_NUMERICAL_PROBLEMS') ||
83827 env().getBool('DEBUG')) {
83828 this.checkNumericalProblems(values);
83829 }
83830 if (dtype === 'complex64' && values != null) {
83831 throw new Error(`Cannot write to a complex64 dtype. ` +
83832 `Please use tf.complex(real, imag).`);
83833 }
83834 const dataId = { id: this.nextDataId() };
83835 this.texData.set(dataId, { shape, dtype, values, usage: TextureUsage.UPLOAD, refCount: 1 });
83836 return dataId;
83837 }
83838 /** Return refCount of a `TensorData`. */
83839 refCount(dataId) {
83840 if (this.texData.has(dataId)) {
83841 const tensorData = this.texData.get(dataId);
83842 return tensorData.refCount;
83843 }
83844 return 0;
83845 }
83846 /** Increase refCount of a `TextureData`. */
83847 incRef(dataId) {
83848 const texData = this.texData.get(dataId);
83849 texData.refCount++;
83850 }
83851 /** Decrease refCount of a `TextureData`. */
83852 decRef(dataId) {
83853 if (this.texData.has(dataId)) {
83854 const texData = this.texData.get(dataId);
83855 texData.refCount--;
83856 }
83857 }
83858 move(dataId, values, shape, dtype, refCount) {
83859 if (env().getBool('DEBUG')) {
83860 this.checkNumericalProblems(values);
83861 }
83862 if (dtype === 'complex64') {
83863 throw new Error(`Cannot write to a complex64 dtype. ` +
83864 `Please use tf.complex(real, imag).`);
83865 }
83866 this.texData.set(dataId, { shape, dtype, values, usage: TextureUsage.UPLOAD, refCount });
83867 }
        /** Disposes the data behind a backend-internal intermediate TensorInfo. */
        disposeIntermediateTensorInfo(tensorInfo) {
            this.disposeData(tensorInfo.dataId);
        }
        /**
         * Synchronously downloads the values backing `dataId`, blocking until the
         * GPU pipeline flushes. Results are cached on the CPU afterwards.
         */
        readSync(dataId) {
            const texData = this.texData.get(dataId);
            const { values, dtype, complexTensorInfos, slice, shape, isPacked } = texData;
            // The presence of `slice` indicates this tensor is a shallow slice of a
            // different tensor, and is using that original tensor's texture. Run
            // `clone` in order to copy that texture and read from it.
            if (slice != null) {
                let program;
                if (isPacked) {
                    program = new UnaryOpPackedProgram(shape, CLONE);
                }
                else {
                    program = new UnaryOpProgram(shape, CLONE);
                }
                const res = this.runWebGLProgram(program, [{ dataId, shape, dtype }], dtype);
                const data = this.readSync(res.dataId);
                this.disposeIntermediateTensorInfo(res);
                return data;
            }
            // Values already cached on the CPU: no GPU download needed.
            if (values != null) {
                return this.convertAndCacheOnCPU(dataId);
            }
            // NOTE(review): `values` is null at this point, so a string tensor with
            // no CPU values returns null here — presumably strings never live on
            // the GPU; confirm against write paths.
            if (dtype === 'string') {
                return values;
            }
            const shouldTimeProgram = this.activeTimers != null;
            let start;
            if (shouldTimeProgram) {
                start = now();
            }
            let result;
            if (dtype === 'complex64') {
                // Complex tensors store real/imag parts as two separate tensors.
                const realValues = this.readSync(complexTensorInfos.real.dataId);
                const imagValues = this.readSync(complexTensorInfos.imag.dataId);
                result = mergeRealAndImagArrays(realValues, imagValues);
            }
            else {
                result = this.getValuesFromTexture(dataId);
            }
            if (shouldTimeProgram) {
                this.downloadWaitMs += now() - start;
            }
            return this.convertAndCacheOnCPU(dataId, result);
        }
        /**
         * Asynchronously downloads the values backing `dataId`.
         * Concurrent reads of the same dataId share one download: later callers
         * are queued in `pendingRead` and resolved when the first completes.
         * Disposals requested while a read is pending are deferred and executed
         * here once the download finishes.
         */
        async read(dataId) {
            if (this.pendingRead.has(dataId)) {
                // A download for this dataId is already in flight; piggyback on it.
                const subscribers = this.pendingRead.get(dataId);
                return new Promise(resolve => subscribers.push(resolve));
            }
            const texData = this.texData.get(dataId);
            const { values, shape, slice, dtype, complexTensorInfos, isPacked } = texData;
            // The presence of `slice` indicates this tensor is a shallow slice of a
            // different tensor, and is using that original tensor's texture. Run
            // `clone` in order to copy that texture and read from it.
            if (slice != null) {
                let program;
                if (isPacked) {
                    program = new UnaryOpPackedProgram(shape, CLONE);
                }
                else {
                    program = new UnaryOpProgram(shape, CLONE);
                }
                const res = this.runWebGLProgram(program, [{ dataId, shape, dtype }], dtype);
                const data = this.read(res.dataId);
                this.disposeIntermediateTensorInfo(res);
                return data;
            }
            // Values already cached on the CPU: no GPU download needed.
            if (values != null) {
                return this.convertAndCacheOnCPU(dataId);
            }
            if (env().getBool('DEBUG')) {
                // getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') caused a blocking GPU call.
                // For performance reason, only check it for debugging. In production,
                // it doesn't handle this use case anyway, so behavior is not changed.
                if (!env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') &&
                    env().getNumber('WEBGL_VERSION') === 2) {
                    throw new Error(`tensor.data() with WEBGL_DOWNLOAD_FLOAT_ENABLED=false and ` +
                        `WEBGL_VERSION=2 not yet supported.`);
                }
            }
            let buffer = null;
            let tmpDownloadTarget;
            if (dtype !== 'complex64' && env().get('WEBGL_BUFFER_SUPPORTED')) {
                // Possibly copy the texture into a buffer before inserting a fence.
                tmpDownloadTarget = this.decode(dataId);
                const tmpData = this.texData.get(tmpDownloadTarget.dataId);
                buffer = this.gpgpu.createBufferFromTexture(tmpData.texture.texture, ...getDenseTexShape(shape));
            }
            // From here on, concurrent read() calls for this dataId queue up.
            this.pendingRead.set(dataId, []);
            if (dtype !== 'complex64') {
                // Create a fence and wait for it to resolve.
                await this.gpgpu.createAndWaitForFence();
            }
            // Download the values from the GPU.
            let vals;
            if (dtype === 'complex64') {
                const ps = await Promise.all([
                    this.read(complexTensorInfos.real.dataId),
                    this.read(complexTensorInfos.imag.dataId)
                ]);
                const realValues = ps[0];
                const imagValues = ps[1];
                vals = mergeRealAndImagArrays(realValues, imagValues);
            }
            else if (buffer == null) {
                vals = this.getValuesFromTexture(dataId);
            }
            else {
                const size = sizeFromShape(shape);
                vals = this.gpgpu.downloadFloat32MatrixFromBuffer(buffer, size);
            }
            if (tmpDownloadTarget != null) {
                this.disposeIntermediateTensorInfo(tmpDownloadTarget);
            }
            if (buffer != null) {
                const gl = this.gpgpu.gl;
                callAndCheck(gl, () => gl.deleteBuffer(buffer));
            }
            const dTypeVals = this.convertAndCacheOnCPU(dataId, vals);
            const subscribers = this.pendingRead.get(dataId);
            this.pendingRead.delete(dataId);
            // Notify all pending reads.
            subscribers.forEach(resolve => resolve(dTypeVals));
            // Perform any disposal that was deferred while this read was pending.
            if (this.pendingDisposal.has(dataId)) {
                this.pendingDisposal.delete(dataId);
                if (this.disposeData(dataId)) {
                    engine().removeDataId(dataId, this);
                }
                this.pendingDeletes--;
            }
            return dTypeVals;
        }
84003 /**
84004 * Read tensor to a new texture that is densely packed for ease of use.
84005 * @param dataId The source tensor.
84006 * @param options
84007 * customTexShape: Optional. If set, will use the user defined texture
84008 * shape to create the texture.
84009 */
84010 readToGPU(dataId, options = {}) {
84011 const texData = this.texData.get(dataId);
84012 const { values, shape, slice, dtype, isPacked, texture } = texData;
84013 if (dtype === 'complex64') {
84014 throw new Error('Does not support reading texture for complex64 dtype.');
84015 }
84016 // The presence of `slice` indicates this tensor is a shallow slice of a
84017 // different tensor, and is using that original tensor's texture. Run
84018 // `clone` in order to copy that texture and read from it.
84019 if (slice != null) {
84020 let program;
84021 if (isPacked) {
84022 program = new UnaryOpPackedProgram(shape, CLONE);
84023 }
84024 else {
84025 program = new UnaryOpProgram(shape, CLONE);
84026 }
84027 const res = this.runWebGLProgram(program, [{ dataId, shape, dtype }], dtype);
84028 const gpuResouorce = this.readToGPU(res, options);
84029 this.disposeIntermediateTensorInfo(res);
84030 return gpuResouorce;
84031 }
84032 if (texture == null) {
84033 if (values != null) {
84034 throw new Error('Data is not on GPU but on CPU.');
84035 }
84036 else {
84037 throw new Error('There is no data on GPU or CPU.');
84038 }
84039 }
84040 // Decode the texture so that it is stored densely (using four channels).
84041 const tmpTarget = this.decode(dataId, options.customTexShape);
84042 // Make engine track this tensor, so that we can dispose it later.
84043 const tensorRef = engine().makeTensorFromTensorInfo(tmpTarget);
84044 const tmpData = this.texData.get(tmpTarget.dataId);
84045 return Object.assign({ tensorRef }, tmpData.texture);
84046 }
84047 bufferSync(t) {
84048 const data = this.readSync(t.dataId);
84049 if (t.dtype === 'string') {
84050 try {
84051 // Decode the bytes into string.
84052 const strings = data.map(d => decodeString(d));
84053 return buffer(t.shape, t.dtype, strings);
84054 }
84055 catch (_a) {
84056 throw new Error('Failed to decode encoded string bytes into utf-8');
84057 }
84058 }
84059 return buffer(t.shape, t.dtype, data);
84060 }
84061 checkNumericalProblems(values) {
84062 if (values == null) {
84063 return;
84064 }
84065 for (let i = 0; i < values.length; i++) {
84066 const num = values[i];
84067 if (!canBeRepresented(num)) {
84068 if (env().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')) {
84069 throw Error(`The value ${num} cannot be represented with your ` +
84070 `current settings. Consider enabling float32 rendering: ` +
84071 `'tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', true);'`);
84072 }
84073 throw Error(`The value ${num} cannot be represented on this device.`);
84074 }
84075 }
84076 }
        /**
         * Downloads the values behind `dataId` from its texture.
         * Uses a direct float download when the device supports it; otherwise
         * round-trips through a byte-encoding shader so the values can be read
         * as bytes and reinterpreted as floats.
         */
        getValuesFromTexture(dataId) {
            const { shape, dtype, isPacked } = this.texData.get(dataId);
            const size = sizeFromShape(shape);
            if (env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED')) {
                // Fast path: decode to a dense texture and download floats directly.
                const tmpTarget = this.decode(dataId);
                const tmpData = this.texData.get(tmpTarget.dataId);
                const vals = this.gpgpu
                    .downloadMatrixFromPackedTexture(tmpData.texture.texture, ...getDenseTexShape(shape))
                    .subarray(0, size);
                this.disposeIntermediateTensorInfo(tmpTarget);
                return vals;
            }
            // Slow path: run an encode-float shader and download byte-encoded data.
            const shouldUsePackedProgram = env().getBool('WEBGL_PACK') && isPacked === true;
            const outputShape = shouldUsePackedProgram ? getShapeAs3D(shape) : shape;
            const program = shouldUsePackedProgram ?
                new EncodeFloatPackedProgram(outputShape) :
                new EncodeFloatProgram(outputShape);
            const output = this.runWebGLProgram(program, [{ shape: outputShape, dtype, dataId }], 'float32');
            const tmpData = this.texData.get(output.dataId);
            const vals = this.gpgpu
                .downloadByteEncodedFloatMatrixFromOutputTexture(tmpData.texture.texture, tmpData.texShape[0], tmpData.texShape[1])
                .subarray(0, size);
            this.disposeIntermediateTensorInfo(output);
            return vals;
        }
84102 timerAvailable() {
84103 return env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0;
84104 }
        /**
         * Profiles `f`, collecting per-kernel timer queries into a nested stack
         * so nested `time()` calls attribute work correctly. Returns a promise
         * of upload/download wait times plus kernel timings (or an error object
         * when query timers are unsupported).
         */
        time(f) {
            const oldActiveTimers = this.activeTimers;
            const newActiveTimers = [];
            let outerMostTime = false;
            if (this.programTimersStack == null) {
                // First (outermost) time() call: own the timer stack.
                this.programTimersStack = newActiveTimers;
                outerMostTime = true;
            }
            else {
                // Nested call: attach our timers under the enclosing scope.
                this.activeTimers.push(newActiveTimers);
            }
            this.activeTimers = newActiveTimers;
            f();
            // needing to split these up because util.flatten only accepts certain types
            const flattenedActiveTimerQueries = flatten(this.activeTimers.map((d) => d.query))
                .filter(d => d != null);
            const flattenedActiveTimerNames = flatten(this.activeTimers.map((d) => d.name))
                .filter(d => d != null);
            this.activeTimers = oldActiveTimers;
            if (outerMostTime) {
                this.programTimersStack = null;
            }
            const res = {
                uploadWaitMs: this.uploadWaitMs,
                downloadWaitMs: this.downloadWaitMs,
                kernelMs: null,
                wallMs: null // will be filled by the engine
            };
            return (async () => {
                if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') >
                    0) {
                    const kernelMs = await Promise.all(flattenedActiveTimerQueries);
                    res['kernelMs'] = sum(kernelMs);
                    res['getExtraProfileInfo'] = () => kernelMs
                        .map((d, i) => ({ name: flattenedActiveTimerNames[i], ms: d }))
                        .map(d => `${d.name}: ${d.ms}`)
                        .join(', ');
                }
                else {
                    res['kernelMs'] = {
                        error: 'WebGL query timers are not supported in this environment.'
                    };
                }
                // Wait counters are cumulative; reset them after reporting.
                this.uploadWaitMs = 0;
                this.downloadWaitMs = 0;
                return res;
            })();
        }
84153 memory() {
84154 return {
84155 unreliable: false,
84156 numBytesInGPU: this.numBytesInGPU,
84157 numBytesInGPUAllocated: this.textureManager.numBytesAllocated,
84158 numBytesInGPUFree: this.textureManager.numBytesFree
84159 };
84160 }
84161 startTimer() {
84162 if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
84163 return this.gpgpu.beginQuery();
84164 }
84165 return { startMs: now(), endMs: null };
84166 }
84167 endTimer(query) {
84168 if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
84169 this.gpgpu.endQuery();
84170 return query;
84171 }
84172 query.endMs = now();
84173 return query;
84174 }
84175 async getQueryTime(query) {
84176 if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
84177 return this.gpgpu.waitForQueryAndGetTime(query);
84178 }
84179 const timerQuery = query;
84180 return timerQuery.endMs - timerQuery.startMs;
84181 }
        /**
         * Decrease the RefCount on the dataId and dispose the memory if the dataId
         * has 0 refCount. If there are pending read on the data, the disposal would
         * added to the pending delete queue. Return true if the dataId is removed
         * from backend or the backend does not contain the dataId, false if the
         * dataId is not removed. Memory may or may not be released even when dataId
         * is removed, which also depends on dataRefCount, see `releaseGPU`.
         * @param dataId
         * @param force Optional, remove the data regardless of refCount
         */
        disposeData(dataId, force = false) {
            // Already queued for disposal behind a pending read: nothing to do.
            if (this.pendingDisposal.has(dataId)) {
                return false;
            }
            // No-op if already disposed.
            if (!this.texData.has(dataId)) {
                return true;
            }
            // if force flag is set, change refCount to 0, this would ensure disposal
            // when added to the pendingDisposal queue. Memory may or may not be
            // released, which also depends on dataRefCount, see `releaseGPU`.
            if (force) {
                this.texData.get(dataId).refCount = 0;
            }
            else {
                this.texData.get(dataId).refCount--;
            }
            if (!force && this.texData.get(dataId).refCount > 0) {
                return false;
            }
            // A read is in flight for this data: defer disposal until it completes
            // (see the tail of `read()`).
            if (this.pendingRead.has(dataId)) {
                this.pendingDisposal.add(dataId);
                this.pendingDeletes++;
                return false;
            }
            this.releaseGPUData(dataId);
            // Complex tensors own their real/imag component tensors; dispose them too.
            const { complexTensorInfos } = this.texData.get(dataId);
            if (complexTensorInfos != null) {
                this.disposeData(complexTensorInfos.real.dataId, force);
                this.disposeData(complexTensorInfos.imag.dataId, force);
            }
            this.texData.delete(dataId);
            return true;
        }
        /**
         * Releases the texture behind `dataId`, honoring the shared
         * `dataRefCount` used by shallow slices: the texture is only returned to
         * the TextureManager when the last reference to the underlying data is
         * gone. CPU-side `values` (if any) are left intact.
         */
        releaseGPUData(dataId) {
            const { texture, dtype, texShape, usage, isPacked, slice } = this.texData.get(dataId);
            // Shallow slices share their source's texture; refcount on the original.
            const key = slice && slice.origDataId || dataId;
            const refCount = this.dataRefCount.get(key);
            if (refCount > 1) {
                this.dataRefCount.set(key, refCount - 1);
            }
            else {
                this.dataRefCount.delete(key);
                if (texture != null) {
                    this.numBytesInGPU -= this.computeBytes(texShape, dtype);
                    this.textureManager.releaseTexture(texture, texShape, usage, isPacked);
                }
            }
            // Clear GPU-related bookkeeping on the bucket.
            const texData = this.texData.get(dataId);
            texData.texture = null;
            texData.texShape = null;
            texData.isPacked = false;
            texData.slice = null;
        }
84246 getTexture(dataId) {
84247 this.uploadToGPU(dataId);
84248 return this.texData.get(dataId).texture.texture;
84249 }
        /**
         * Returns internal information for the specific data bucket. Used in unit
         * tests.
         */
        getDataInfo(dataId) {
            return this.texData.get(dataId);
        }
84257 /*
84258 Tests whether all the inputs to an op are small and on the CPU. This heuristic
84259 determines when it would be faster to execute a kernel on the CPU. WebGL
84260 kernels opt into running this check and forwarding when appropriate.
84261 TODO(https://github.com/tensorflow/tfjs/issues/872): Develop a more
84262 sustainable strategy for optimizing backend execution of ops.
84263 */
84264 shouldExecuteOnCPU(inputs, sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD) {
84265 return env().getBool('WEBGL_CPU_FORWARD') &&
84266 inputs.every(input => this.texData.get(input.dataId).texture == null &&
84267 sizeFromShape(input.shape) < sizeThreshold);
84268 }
        /** Returns the backend's GPGPUContext (GL state + helper routines). */
        getGPGPUContext() {
            return this.gpgpu;
        }
84272 where(condition) {
84273 warn('tf.where() in webgl locks the UI thread. ' +
84274 'Call tf.whereAsync() instead');
84275 const condVals = condition.dataSync();
84276 return whereImpl$2(condition.shape, condVals);
84277 }
84278 packedUnaryOp(x, op, dtype) {
84279 const program = new UnaryOpPackedProgram(x.shape, op);
84280 const outInfo = this.compileAndRun(program, [x], dtype);
84281 return engine().makeTensorFromTensorInfo(outInfo);
84282 }
84283 // TODO(msoulanille) remove this once the backend has been modularized
84284 // a copy is needed here to break a circular dependency.
84285 // Also remove the op from unary_op.
84286 abs(x) {
84287 // TODO: handle cases when x is complex.
84288 if (this.shouldExecuteOnCPU([x]) && x.dtype !== 'complex64') {
84289 const outValues = simpleAbsImplCPU(this.texData.get(x.dataId).values);
84290 return this.makeOutput(x.shape, x.dtype, outValues);
84291 }
84292 if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
84293 return this.packedUnaryOp(x, ABS, x.dtype);
84294 }
84295 const program = new UnaryOpProgram(x.shape, ABS);
84296 const outInfo = this.compileAndRun(program, [x]);
84297 return engine().makeTensorFromTensorInfo(outInfo);
84298 }
84299 makeTensorInfo(shape, dtype, values) {
84300 let dataId;
84301 if (dtype === 'string' && values != null && values.length > 0 &&
84302 isString(values[0])) {
84303 const encodedValues = values.map(d => encodeString(d));
84304 dataId = this.write(encodedValues, shape, dtype);
84305 }
84306 else {
84307 dataId = this.write(values, shape, dtype);
84308 }
84309 this.texData.get(dataId).usage = null;
84310 return { dataId, shape, dtype };
84311 }
84312 makeOutput(shape, dtype, values) {
84313 return engine().makeTensorFromTensorInfo(this.makeTensorInfo(shape, dtype, values), this);
84314 }
84315 unpackTensor(input) {
84316 const program = new UnpackProgram(input.shape);
84317 return this.runWebGLProgram(program, [input], input.dtype);
84318 }
84319 packTensor(input) {
84320 const program = new PackProgram(input.shape);
84321 const preventEagerUnpackingOutput = true;
84322 return this.runWebGLProgram(program, [input], input.dtype, null /* customUniformValues */, preventEagerUnpackingOutput);
84323 }
84324 packedReshape(input, afterShape) {
84325 const input3DShape = [
84326 getBatchDim(input.shape),
84327 ...getRowsCols(input.shape)
84328 ];
84329 const input3D = {
84330 dtype: input.dtype,
84331 shape: input3DShape,
84332 dataId: input.dataId
84333 };
84334 const afterShapeAs3D = [
84335 getBatchDim(afterShape), ...getRowsCols(afterShape)
84336 ];
84337 const program = new ReshapePackedProgram(afterShapeAs3D, input3DShape);
84338 const preventEagerUnpackingOfOutput = true;
84339 const customValues = [input3DShape];
84340 const output = this.runWebGLProgram(program, [input3D], input.dtype, customValues, preventEagerUnpackingOfOutput);
84341 return { dataId: output.dataId, shape: afterShape, dtype: output.dtype };
84342 }
        /**
         * Re-encodes the tensor's texture into a densely packed layout (four
         * values per texel) and returns a TensorInfo for the dense copy.
         * @param customTexShape Optional [rows, cols] for the dense texture;
         *     must be large enough to hold all values.
         */
        decode(dataId, customTexShape) {
            const texData = this.texData.get(dataId);
            const { isPacked, shape, dtype } = texData;
            if (customTexShape != null) {
                const size = sizeFromShape(shape);
                const texSize = customTexShape[0] * customTexShape[1] * 4;
                assert(size <= texSize, () => 'customTexShape is too small. ' +
                    'Row * Column * 4 should be equal or larger than the ' +
                    'size of the tensor data.');
            }
            const shapeAs3D = getShapeAs3D(shape);
            // Pick the decode shader matching the source texture's packedness.
            let program;
            if (isPacked) {
                program = new DecodeMatrixPackedProgram(shapeAs3D);
            }
            else {
                program = new DecodeMatrixProgram(shapeAs3D);
            }
            const preventEagerUnpackingOfOutput = true;
            const customValues = [customTexShape != null ? customTexShape :
                    getDenseTexShape(shapeAs3D)];
            const out = this.runWebGLProgram(program, [{ shape: shapeAs3D, dtype, dataId }], dtype, customValues, preventEagerUnpackingOfOutput, customTexShape);
            // Report the original (non-3D) logical shape to the caller.
            return { dtype, shape, dataId: out.dataId };
        }
        /**
         * Core execution path: compiles (with caching) and runs `program` over
         * `inputs`, returning a TensorInfo for the output.
         * @param customUniformValues Optional per-run uniform values.
         * @param preventEagerUnpackingOfOutput When false (default), a packed
         *     output is eagerly unpacked unless WEBGL_LAZILY_UNPACK is set.
         * @param customTexShape Optional texture shape for DENSE-packed outputs.
         */
        runWebGLProgram(program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput = false, customTexShape) {
            const output = this.makeTensorInfo(program.outputShape, outputDtype);
            const outData = this.texData.get(output.dataId);
            if (program.packedOutput) {
                outData.isPacked = true;
            }
            if (program.outPackingScheme === PackingScheme.DENSE) {
                const texelShape = customTexShape != null ?
                    customTexShape :
                    getDenseTexShape(program.outputShape);
                // For a densely packed output, we explicitly set texShape
                // so it doesn't get assigned later according to our typical packing
                // scheme wherein a single texel can only contain values from adjacent
                // rows/cols.
                outData.texShape = texelShape.map(d => d * 2);
            }
            if (program.outTexUsage != null) {
                outData.usage = program.outTexUsage;
            }
            if (sizeFromShape(output.shape) === 0) {
                // Short-circuit the computation since the result is empty (has 0 in its
                // shape).
                outData.values =
                    getTypedArrayFromDType(output.dtype, 0);
                return output;
            }
            // Intermediate tensors created below (pack/unpack/reshape conversions)
            // are disposed after the program runs.
            const dataToDispose = [];
            const inputsData = inputs.map(input => {
                if (input.dtype === 'complex64') {
                    throw new Error(`GPGPUProgram does not support complex64 input. For complex64 ` +
                        `dtypes, please separate the program into real and imaginary ` +
                        `parts.`);
                }
                let texData = this.texData.get(input.dataId);
                if (texData.texture == null) {
                    if (!program.packedInputs &&
                        sizeFromShape(input.shape) <=
                            env().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')) {
                        // Upload small tensors that live on the CPU as uniforms, not as
                        // textures. Do this only when the environment supports 32bit floats
                        // due to problems when comparing 16bit floats with 32bit floats.
                        // TODO(https://github.com/tensorflow/tfjs/issues/821): Make it
                        // possible for packed shaders to sample from uniforms.
                        return {
                            shape: input.shape,
                            texData: null,
                            isUniform: true,
                            uniformValues: texData.values
                        };
                    }
                    // This ensures that if a packed program's inputs have not yet been
                    // uploaded to the GPU, they get uploaded as packed right off the bat.
                    if (program.packedInputs) {
                        texData.isPacked = true;
                        texData.shape = input.shape;
                    }
                }
                this.uploadToGPU(input.dataId);
                if (!!texData.isPacked !== !!program.packedInputs) {
                    // Packedness of the stored texture disagrees with what the shader
                    // samples; convert with a pack/unpack program.
                    input = texData.isPacked ? this.unpackTensor(input) :
                        this.packTensor(input);
                    dataToDispose.push(input);
                    texData = this.texData.get(input.dataId);
                }
                else if (texData.isPacked &&
                    !isReshapeFree(texData.shape, input.shape)) {
                    // This is a special case where a texture exists for a tensor
                    // but the shapes are incompatible (due to packing constraints) because
                    // the tensor did not have a chance to go through the packed reshape
                    // shader. This only happens when we reshape the *same* tensor to form
                    // *distinct* inputs to an op, e.g. dotting a vector with itself. This
                    // case will disappear once packed uploading is the default.
                    const savedInput = input;
                    const targetShape = input.shape;
                    input.shape = texData.shape;
                    input = this.packedReshape(input, targetShape);
                    dataToDispose.push(input);
                    texData = this.texData.get(input.dataId);
                    savedInput.shape = targetShape;
                }
                return { shape: input.shape, texData, isUniform: false };
            });
            this.uploadToGPU(output.dataId);
            const outputData = { shape: output.shape, texData: outData, isUniform: false };
            // Compiled binaries are cached per shader key.
            const key = makeShaderKey(program, inputsData, outputData);
            const binary = this.getAndSaveBinary(key, () => {
                return compileProgram(this.gpgpu, program, inputsData, outputData);
            });
            const shouldTimeProgram = this.activeTimers != null;
            let query;
            if (shouldTimeProgram) {
                query = this.startTimer();
            }
            if (!env().get('ENGINE_COMPILE_ONLY')) {
                runProgram(this.gpgpu, binary, inputsData, outputData, customUniformValues);
            }
            dataToDispose.forEach(info => this.disposeIntermediateTensorInfo(info));
            if (shouldTimeProgram) {
                query = this.endTimer(query);
                this.activeTimers.push({ name: program.constructor.name, query: this.getQueryTime(query) });
            }
            const glFlushThreshold = env().get('WEBGL_FLUSH_THRESHOLD');
            // Manually GL flush requested
            if (glFlushThreshold > 0) {
                const time = now();
                if ((time - this.lastGlFlushTime) > glFlushThreshold) {
                    this.gpgpu.gl.flush();
                    this.lastGlFlushTime = time;
                }
            }
            if (!env().getBool('WEBGL_LAZILY_UNPACK') && outData.isPacked &&
                preventEagerUnpackingOfOutput === false) {
                // Eagerly unpack the packed output unless the caller opted out.
                const unpacked = this.unpackTensor(output);
                this.disposeIntermediateTensorInfo(output);
                return unpacked;
            }
            return output;
        }
84485 compileAndRun(program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput = false) {
84486 outputDtype = outputDtype || inputs[0].dtype;
84487 const outInfo = this.runWebGLProgram(program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput);
84488 return outInfo;
84489 }
84490 getAndSaveBinary(key, getBinary) {
84491 if (!(key in this.binaryCache)) {
84492 this.binaryCache[key] = getBinary();
84493 }
84494 return this.binaryCache[key];
84495 }
        /** Returns the texture manager that pools this backend's textures. */
        getTextureManager() {
            return this.textureManager;
        }
        /**
         * Releases all GPU resources held by this backend: compiled programs
         * (outside tests), pooled textures, the canvas, and — only when this
         * backend created it — the GPGPU context itself. Idempotent.
         */
        dispose() {
            if (this.disposed) {
                return;
            }
            // Avoid disposing the compiled webgl programs during unit testing because
            // it slows down test execution.
            if (!env().getBool('IS_TEST')) {
                const allKeys = Object.keys(this.binaryCache);
                allKeys.forEach(key => {
                    this.gpgpu.deleteProgram(this.binaryCache[key].webGLProgram);
                    delete this.binaryCache[key];
                });
            }
            this.textureManager.dispose();
            if (this.canvas != null &&
                (typeof (HTMLCanvasElement) !== 'undefined' &&
                    this.canvas instanceof HTMLCanvasElement)) {
                this.canvas.remove();
            }
            else {
                this.canvas = null;
            }
            // Only dispose the GL context if this backend owns it (see constructor).
            if (this.gpgpuCreatedLocally) {
                this.gpgpu.program = null;
                this.gpgpu.dispose();
            }
            this.disposed = true;
        }
        /**
         * Returns the effective float precision (32 or 16 bits), probing once
         * and caching the result. When float32 rendering is disabled, a tiny
         * value (1e-8) is round-tripped through the GPU: surviving the trip
         * (non-zero) indicates 32-bit precision, underflow indicates 16-bit.
         */
        floatPrecision() {
            if (this.floatPrecisionValue == null) {
                this.floatPrecisionValue = tidy(() => {
                    if (!env().get('WEBGL_RENDER_FLOAT32_ENABLED')) {
                        // Momentarily switching DEBUG flag to false so we don't throw an
                        // error trying to upload a small value.
                        const debugFlag = env().getBool('DEBUG');
                        env().set('DEBUG', false);
                        const underflowCheckValue = this.abs(scalar(1e-8)).dataSync()[0];
                        env().set('DEBUG', debugFlag);
                        if (underflowCheckValue > 0) {
                            return 32;
                        }
                    }
                    return 16;
                });
            }
            return this.floatPrecisionValue;
        }
84546 /** Returns the smallest representable number. */
84547 epsilon() {
84548 return this.floatPrecision() === 32 ? EPSILON_FLOAT32$1 : EPSILON_FLOAT16$1;
84549 }
/**
 * Ensures the tensor identified by `dataId` is backed by a WebGL texture.
 * No-op when a texture already exists. If CPU values are present they are
 * uploaded through a temporary dense texture and re-encoded on the GPU via
 * an Encode(Matrix|MatrixPacked) program; otherwise an empty texture of
 * the appropriate shape is acquired.
 */
uploadToGPU(dataId) {
    const texData = this.texData.get(dataId);
    const { shape, dtype, values, texture, usage, isPacked } = texData;
    if (texture != null) {
        // Array is already on GPU. No-op.
        return;
    }
    const shouldTimeProgram = this.activeTimers != null;
    let start;
    if (shouldTimeProgram) {
        start = now();
    }
    let texShape = texData.texShape;
    if (texShape == null) {
        // This texShape may not be the final texture shape. For packed or dense
        // textures, the texShape will be changed when textures are created.
        texShape = getTextureShapeFromLogicalShape(shape, isPacked);
        texData.texShape = texShape;
    }
    if (values != null) {
        const shapeAs3D = getShapeAs3D(shape);
        let program;
        let width = texShape[1], height = texShape[0];
        const isByteArray = values instanceof Uint8Array || values instanceof Uint8ClampedArray;
        // texture for float array is PhysicalTextureType.PACKED_2X2_FLOAT32, we
        // need to make sure the upload uses the same packed size
        if (isPacked || !isByteArray) {
            [width, height] = getPackedMatrixTextureShapeWidthHeight(texShape[0], texShape[1]);
        }
        if (isPacked) {
            program = new EncodeMatrixPackedProgram(shapeAs3D, isByteArray);
        }
        else {
            program = new EncodeMatrixProgram(shapeAs3D, isByteArray);
        }
        // TexShape for float array needs to be the original shape, which byte
        // array needs to be packed size. This allow the data upload shape to be
        // matched with texture creation logic.
        const tempDenseInputTexShape = isByteArray ? [height, width] : texShape;
        const tempDenseInputHandle = this.makeTensorInfo(tempDenseInputTexShape, dtype);
        const tempDenseInputTexData = this.texData.get(tempDenseInputHandle.dataId);
        // Byte arrays upload as PIXELS (image-style data); floats as UPLOAD.
        if (isByteArray) {
            tempDenseInputTexData.usage = TextureUsage.PIXELS;
        }
        else {
            tempDenseInputTexData.usage = TextureUsage.UPLOAD;
        }
        tempDenseInputTexData.texShape = tempDenseInputTexShape;
        this.gpgpu.uploadDenseMatrixToTexture(this.getTexture(tempDenseInputHandle.dataId), width, height, values);
        const customValues = [[height, width]];
        // We want the output to remain packed regardless of the value of
        // WEBGL_PACK.
        const preventEagerUnpacking = true;
        const encodedOutputTarget = this.runWebGLProgram(program, [tempDenseInputHandle], dtype, customValues, preventEagerUnpacking);
        // Have the original texture assume the identity of the encoded output.
        const outputTexData = this.texData.get(encodedOutputTarget.dataId);
        texData.texShape = outputTexData.texShape;
        texData.isPacked = outputTexData.isPacked;
        texData.usage = outputTexData.usage;
        if (!env().get('ENGINE_COMPILE_ONLY')) {
            texData.texture = outputTexData.texture;
            // Once uploaded, don't store the values on cpu.
            texData.values = null;
            this.texData.delete(encodedOutputTarget.dataId);
        }
        else {
            // Compile-only mode: no real texture was produced, discard output.
            this.disposeData(encodedOutputTarget.dataId);
        }
        this.disposeIntermediateTensorInfo(tempDenseInputHandle);
        if (shouldTimeProgram) {
            this.uploadWaitMs += now() - start;
        }
    }
    else {
        // No CPU values: just reserve an empty texture of the right shape.
        const newTexture = this.acquireTexture(texShape, usage, dtype, isPacked);
        texData.texture = newTexture;
    }
}
84628 convertAndCacheOnCPU(dataId, float32Values) {
84629 const texData = this.texData.get(dataId);
84630 const { dtype } = texData;
84631 this.releaseGPUData(dataId);
84632 if (float32Values != null) {
84633 texData.values = float32ToTypedArray(float32Values, dtype);
84634 }
84635 return texData.values;
84636 }
84637 acquireTexture(texShape, texType, dtype, isPacked) {
84638 this.numBytesInGPU += this.computeBytes(texShape, dtype);
84639 if (!this.warnedAboutMemory &&
84640 this.numBytesInGPU > this.numMBBeforeWarning * 1024 * 1024) {
84641 const mb = (this.numBytesInGPU / 1024 / 1024).toFixed(2);
84642 this.warnedAboutMemory = true;
84643 console.warn(`High memory usage in GPU: ${mb} MB, ` +
84644 `most likely due to a memory leak`);
84645 }
84646 return this.textureManager.acquireTexture(texShape, texType, isPacked);
84647 }
84648 computeBytes(shape, dtype) {
84649 return shape[0] * shape[1] * bytesPerElement(dtype);
84650 }
84651 checkCompileCompletion() {
84652 for (const [, binary] of Object.entries(this.binaryCache)) {
84653 this.checkCompletion_(binary);
84654 }
84655 }
84656 async checkCompileCompletionAsync() {
84657 const ps = [];
84658 if (this.gpgpu.parallelCompilationExtension) {
84659 for (const [, binary] of Object.entries(this.binaryCache)) {
84660 ps.push(this.checkCompletionAsync_(binary));
84661 }
84662 return Promise.all(ps);
84663 }
84664 else {
84665 for (const [, binary] of Object.entries(this.binaryCache)) {
84666 const p = new Promise((resolve) => {
84667 try {
84668 this.checkCompletion_(binary);
84669 resolve(true);
84670 }
84671 catch (error) {
84672 throw error;
84673 }
84674 });
84675 ps.push(p);
84676 }
84677 return Promise.all(ps);
84678 }
84679 }
84680 async checkCompletionAsync_(binary) {
84681 if (this.gpgpu.gl.getProgramParameter(binary.webGLProgram, this.gpgpu.parallelCompilationExtension.COMPLETION_STATUS_KHR)) {
84682 return this.checkCompletion_(binary);
84683 }
84684 else {
84685 await nextFrame();
84686 return this.checkCompletionAsync_(binary);
84687 }
84688 }
/**
 * Verifies `binary`'s WebGL program linked successfully. On failure, logs
 * the program info log; if the fragment shader itself failed to compile,
 * also logs its source + info log and throws a compile error, otherwise
 * throws a link error. Returns true on success.
 */
checkCompletion_(binary) {
    if (this.gpgpu.gl.getProgramParameter(binary.webGLProgram, this.gpgpu.gl.LINK_STATUS) === false) {
        console.log(this.gpgpu.gl.getProgramInfoLog(binary.webGLProgram));
        if (this.gpgpu.gl.getShaderParameter(binary.fragmentShader, this.gpgpu.gl.COMPILE_STATUS) === false) {
            logShaderSourceAndInfoLog(binary.source, this.gpgpu.gl.getShaderInfoLog(binary.fragmentShader));
            throw new Error('Failed to compile fragment shader.');
        }
        throw new Error('Failed to link vertex and fragment shaders.');
    }
    return true;
}
/**
 * Re-queries and caches uniform locations for every program in the binary
 * cache (e.g. after deferred/parallel compilation completes). Note this
 * method shadows, and delegates to, the module-level `getUniformLocations`
 * helper, copying exactly the nine location fields onto the binary.
 */
getUniformLocations() {
    for (const [, binary] of Object.entries(this.binaryCache)) {
        const { uniformLocations, customUniformLocations, infLoc, nanLoc, inShapesLocations, inTexShapesLocations, outShapeLocation, outShapeStridesLocation, outTexShapeLocation } = getUniformLocations(this.gpgpu, binary.program, binary.webGLProgram);
        binary.uniformLocations = uniformLocations;
        binary.customUniformLocations = customUniformLocations;
        binary.infLoc = infLoc;
        binary.nanLoc = nanLoc;
        binary.inShapesLocations = inShapesLocations;
        binary.inTexShapesLocations = inTexShapesLocations;
        binary.outShapeLocation = outShapeLocation;
        binary.outShapeStridesLocation = outShapeStridesLocation;
        binary.outTexShapeLocation = outTexShapeLocation;
    }
}
84714 }
// Static seed for backend data ids; presumably incremented wherever new
// data is registered (the increment site is outside this chunk — confirm).
MathBackendWebGL.nextDataId = 0;
/**
 * Converts an array of float32 values to the typed array matching `dtype`.
 * 'float32' and 'complex64' are returned as-is; 'int32' and 'bool' are
 * rounded element-wise into Int32Array / Uint8Array respectively.
 * @throws Error for any other dtype.
 */
function float32ToTypedArray(a, dtype) {
    switch (dtype) {
        case 'float32':
        case 'complex64':
            return a;
        case 'int32':
        case 'bool': {
            const out = dtype === 'int32' ? new Int32Array(a.length) :
                new Uint8Array(a.length);
            for (let i = 0; i < out.length; ++i) {
                out[i] = Math.round(a[i]);
            }
            return out;
        }
        default:
            throw new Error(`Unknown dtype ${dtype}`);
    }
}
84732
/** @license See the LICENSE file. */
// This code is auto-generated, do not modify this file!
// Build-time version string of the tfjs-backend-webgl package.
const version$5 = '3.18.0';
84736
84737 /**
84738 * @license
84739 * Copyright 2019 Google LLC. All Rights Reserved.
84740 * Licensed under the Apache License, Version 2.0 (the "License");
84741 * you may not use this file except in compliance with the License.
84742 * You may obtain a copy of the License at
84743 *
84744 * http://www.apache.org/licenses/LICENSE-2.0
84745 *
84746 * Unless required by applicable law or agreed to in writing, software
84747 * distributed under the License is distributed on an "AS IS" BASIS,
84748 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
84749 * See the License for the specific language governing permissions and
84750 * limitations under the License.
84751 * =============================================================================
84752 */
/**
 * Enforce use of half precision textures if available on the platform.
 *
 * Sets the `WEBGL_FORCE_F16_TEXTURES` environment flag to true.
 *
 * @doc {heading: 'Environment', namespace: 'webgl'}
 */
function forceHalfFloat() {
    env().set('WEBGL_FORCE_F16_TEXTURES', true);
}
84761
84762 /**
84763 * @license
84764 * Copyright 2020 Google Inc. All Rights Reserved.
84765 * Licensed under the Apache License, Version 2.0 (the "License");
84766 * you may not use this file except in compliance with the License.
84767 * You may obtain a copy of the License at
84768 *
84769 * http://www.apache.org/licenses/LICENSE-2.0
84770 *
84771 * Unless required by applicable law or agreed to in writing, software
84772 * distributed under the License is distributed on an "AS IS" BASIS,
84773 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
84774 * See the License for the specific language governing permissions and
84775 * limitations under the License.
84776 * =============================================================================
84777 */
// Auto-register the WebGL backend (browser environments only) at priority 2.
if (isBrowser()) {
    registerBackend('webgl', () => new MathBackendWebGL(), 2 /* priority */);
}
// Public `tf.webgl` namespace exposing user-facing helpers.
const webgl = { forceHalfFloat };
84782
84783 /**
84784 * @license
84785 * Copyright 2017 Google LLC. All Rights Reserved.
84786 * Licensed under the Apache License, Version 2.0 (the "License");
84787 * you may not use this file except in compliance with the License.
84788 * You may obtain a copy of the License at
84789 *
84790 * http://www.apache.org/licenses/LICENSE-2.0
84791 *
84792 * Unless required by applicable law or agreed to in writing, software
84793 * distributed under the License is distributed on an "AS IS" BASIS,
84794 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
84795 * See the License for the specific language governing permissions and
84796 * limitations under the License.
84797 * =============================================================================
84798 */
// GLSL snippet for unpacked binary ops: propagate NaN from either operand
// before evaluating the op body.
const CHECK_NAN_SNIPPET$1 = `
  if (isnan(a)) return a;
  if (isnan(b)) return b;
`;
// GLSL body for the SquaredDifference kernel: (a - b)^2.
const SQUARED_DIFFERENCE = 'return (a - b) * (a - b);';
/**
 * GPGPU program computing an element-wise binary op over two unpacked
 * inputs broadcast to a common output shape. `op` is a GLSL statement body
 * whose operands are bound to floats `a` and `b`.
 */
class BinaryOpProgram {
    constructor(op, aShape, bShape) {
        this.variableNames = ['A', 'B'];
        // Throws if the shapes are not broadcast-compatible.
        this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        this.userCode = `
      float binaryOperation(float a, float b) {
        ${op}
      }

      void main() {
        float a = getAAtOutCoords();
        float b = getBAtOutCoords();
        setOutput(binaryOperation(a, b));
      }
    `;
    }
}
84822
84823 /**
84824 * @license
84825 * Copyright 2018 Google LLC. All Rights Reserved.
84826 * Licensed under the Apache License, Version 2.0 (the "License");
84827 * you may not use this file except in compliance with the License.
84828 * You may obtain a copy of the License at
84829 *
84830 * http://www.apache.org/licenses/LICENSE-2.0
84831 *
84832 * Unless required by applicable law or agreed to in writing, software
84833 * distributed under the License is distributed on an "AS IS" BASIS,
84834 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
84835 * See the License for the specific language governing permissions and
84836 * limitations under the License.
84837 * =============================================================================
84838 */
// Packed (vec4) NaN propagation: `isNaN` is a per-lane mask computed by the
// caller's snippet; flagged lanes are overwritten with NAN.
const CHECK_NAN_SNIPPET$2 = `
  result.r = isNaN.r > 0. ? NAN : result.r;
  result.g = isNaN.g > 0. ? NAN : result.g;
  result.b = isNaN.b > 0. ? NAN : result.b;
  result.a = isNaN.a > 0. ? NAN : result.a;
`;
// Packed ELU derivative: a where b >= 0, otherwise a * (b + 1).
const ELU_DER = `
  vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.)));
  return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))));
`;
// Packed element-wise not-equal producing a 0/1 mask per lane.
const NOT_EQUAL = `
  return vec4(notEqual(a, b));
`;
/**
 * Packed (vec4 texel) variant of BinaryOpProgram. When `checkOutOfBounds`
 * is true, texel lanes whose coordinates fall outside the logical output
 * shape are zeroed so stale lane values cannot leak into results.
 */
class BinaryOpPackedProgram {
    constructor(op, aShape, bShape, checkOutOfBounds = false) {
        this.variableNames = ['A', 'B'];
        this.supportsBroadcasting = true;
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
        const rank = this.outputShape.length;
        this.enableShapeUniforms = useShapeUniforms(rank);
        let checkOutOfBoundsString = '';
        if (checkOutOfBounds) {
            // Scalar/size-1 output: only lane x is meaningful.
            if (rank === 0 || sizeFromShape(this.outputShape) === 1) {
                checkOutOfBoundsString = `
          result.y = 0.;
          result.z = 0.;
          result.w = 0.;
        `;
            }
            else {
                const dtype = getCoordsDataType(rank);
                checkOutOfBoundsString = `
          ${dtype} coords = getOutputCoords();
        `;
                if (rank === 1) {
                    // 1-D: zero lane y when the next element is past the end
                    // (lanes z/w are never used for rank-1 outputs).
                    if (this.enableShapeUniforms) {
                        checkOutOfBoundsString += `
            result.y = (coords + 1) >= outShape ? 0. : result.y;
            result.z = 0.;
            result.w = 0.;
          `;
                    }
                    else {
                        checkOutOfBoundsString += `
            result.y = (coords + 1) >= ${this.outputShape[0]} ? 0. : result.y;
            result.z = 0.;
            result.w = 0.;
          `;
                    }
                }
                else {
                    // Rank >= 2: zero lanes whose neighboring row and/or column
                    // (the +1 offsets in the last two dims) land out of bounds.
                    const channels = getChannels('coords', rank);
                    if (this.enableShapeUniforms) {
                        checkOutOfBoundsString += `
            bool nextRowOutOfBounds =
              (${channels[rank - 2]} + 1) >= outShape[${rank} - 2];
            bool nextColOutOfBounds =
              (${channels[rank - 1]} + 1) >= outShape[${rank} - 1];
            result.y = nextColOutOfBounds ? 0. : result.y;
            result.z = nextRowOutOfBounds ? 0. : result.z;
            result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;
          `;
                    }
                    else {
                        checkOutOfBoundsString += `
            bool nextRowOutOfBounds =
              (${channels[rank - 2]} + 1) >= ${this.outputShape[rank - 2]};
            bool nextColOutOfBounds =
              (${channels[rank - 1]} + 1) >= ${this.outputShape[rank - 1]};
            result.y = nextColOutOfBounds ? 0. : result.y;
            result.z = nextRowOutOfBounds ? 0. : result.z;
            result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;
          `;
                    }
                }
            }
        }
        this.userCode = `
      vec4 binaryOperation(vec4 a, vec4 b) {
        ${op}
      }

      void main() {
        vec4 a = getAAtOutCoords();
        vec4 b = getBAtOutCoords();

        vec4 result = binaryOperation(a, b);
        ${checkOutOfBoundsString}

        setOutput(result);
      }
    `;
    }
}
84935
84936 /**
84937 * @license
84938 * Copyright 2020 Google LLC. All Rights Reserved.
84939 * Licensed under the Apache License, Version 2.0 (the "License");
84940 * you may not use this file except in compliance with the License.
84941 * You may obtain a copy of the License at
84942 *
84943 * http://www.apache.org/licenses/LICENSE-2.0
84944 *
84945 * Unless required by applicable law or agreed to in writing, software
84946 * distributed under the License is distributed on an "AS IS" BASIS,
84947 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
84948 * See the License for the specific language governing permissions and
84949 * limitations under the License.
84950 * =============================================================================
84951 */
/**
 * WebGL Identity kernel: shares the input's data (bumping its ref count)
 * and returns a new tensor info pointing at the same dataId.
 */
function identity$2(args) {
    const { x } = args.inputs;
    args.backend.incRef(x.dataId);
    const { dataId, shape, dtype } = x;
    return { dataId, shape, dtype };
}
// Registration descriptor wiring the Identity kernel to the WebGL backend.
const identityConfig$1 = {
    kernelName: Identity,
    backendName: 'webgl',
    kernelFunc: identity$2
};
84963
84964 /**
84965 * @license
84966 * Copyright 2020 Google LLC. All Rights Reserved.
84967 * Licensed under the Apache License, Version 2.0 (the "License");
84968 * you may not use this file except in compliance with the License.
84969 * You may obtain a copy of the License at
84970 *
84971 * http://www.apache.org/licenses/LICENSE-2.0
84972 *
84973 * Unless required by applicable law or agreed to in writing, software
84974 * distributed under the License is distributed on an "AS IS" BASIS,
84975 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
84976 * See the License for the specific language governing permissions and
84977 * limitations under the License.
84978 * =============================================================================
84979 */
/**
 * In WebGL data is stored in GPU textures which can't be efficiently copied, so
 * complex tensors share data with their real and imaginary components. Complex
 * tensors' reference to the components is tracked by refCount on the individual
 * component. The refCounts are increased by the identity call.
 *
 * When a complex tensor is disposed, it will reduce the refCount on the
 * components by calling disposeData on each.
 */
function complex$2(args) {
    const { inputs, backend } = args;
    const { real, imag } = inputs;
    const complexInfo = backend.makeTensorInfo(real.shape, 'complex64');
    const complexData = backend.texData.get(complexInfo.dataId);
    // Identity bumps each component's refCount so the complex tensor
    // co-owns them.
    complexData.complexTensorInfos = {
        real: identity$2({ inputs: { x: real }, backend }),
        imag: identity$2({ inputs: { x: imag }, backend })
    };
    return complexInfo;
}
// Registration descriptor wiring the Complex kernel to the WebGL backend.
const complexConfig$1 = {
    kernelName: Complex,
    backendName: 'webgl',
    kernelFunc: complex$2
};
85004
85005 /**
85006 * @license
85007 * Copyright 2020 Google LLC. All Rights Reserved.
85008 * Licensed under the Apache License, Version 2.0 (the "License");
85009 * you may not use this file except in compliance with the License.
85010 * You may obtain a copy of the License at
85011 *
85012 * http://www.apache.org/licenses/LICENSE-2.0
85013 *
85014 * Unless required by applicable law or agreed to in writing, software
85015 * distributed under the License is distributed on an "AS IS" BASIS,
85016 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
85017 * See the License for the specific language governing permissions and
85018 * limitations under the License.
85019 * =============================================================================
85020 */
// Unpacked LeakyRelu body: a for a >= 0, otherwise b * a (b = alpha).
const LEAKYRELU = `return (a < 0.) ? b * a : a;`;
// Packed (vec4) LeakyRelu via a 0/1 mask, avoiding per-lane branching.
const LEAKYRELU_PACKED = `
  vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));
  return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);
`;
/**
 * WebGL LeakyRelu kernel: x for x >= 0, alpha * x otherwise. The scalar
 * `alpha` attr is materialized as a temporary tensor, passed as the second
 * binary-op input, and disposed after the program runs.
 */
function leakyRelu$2(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { alpha } = attrs;
    const $alpha = backend.makeTensorInfo([], 'float32', createScalarValue(alpha, 'float32'));
    let program;
    if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
        program = new BinaryOpPackedProgram(LEAKYRELU_PACKED, x.shape, $alpha.shape);
    }
    else {
        program = new BinaryOpProgram(LEAKYRELU, x.shape, $alpha.shape);
    }
    const result = backend.runWebGLProgram(program, [x, $alpha], 'float32');
    backend.disposeIntermediateTensorInfo($alpha);
    return result;
}
// Registration descriptor wiring the LeakyRelu kernel to the WebGL backend.
const leakyReluConfig$1 = {
    kernelName: LeakyRelu,
    backendName: 'webgl',
    kernelFunc: leakyRelu$2
};
85043
85044 /**
85045 * @license
85046 * Copyright 2020 Google LLC. All Rights Reserved.
85047 * Licensed under the Apache License, Version 2.0 (the "License");
85048 * you may not use this file except in compliance with the License.
85049 * You may obtain a copy of the License at
85050 *
85051 * http://www.apache.org/licenses/LICENSE-2.0
85052 *
85053 * Unless required by applicable law or agreed to in writing, software
85054 * distributed under the License is distributed on an "AS IS" BASIS,
85055 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
85056 * See the License for the specific language governing permissions and
85057 * limitations under the License.
85058 * =============================================================================
85059 */
// Unpacked Prelu body: a for a >= 0, otherwise b * a (b = per-element alpha).
const PRELU = `return (a < 0.) ? b * a : a;`;
// Packed (vec4) Prelu via a 0/1 mask, avoiding per-lane branching.
const PRELU_PACKED = `
  vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));
  return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);
`;
/**
 * WebGL Prelu kernel: x for x >= 0, alpha * x otherwise, where `alpha` is
 * a tensor input broadcast against x.
 */
function prelu$3(args) {
    const { x, alpha } = args.inputs;
    const usePacked = env().getBool('WEBGL_PACK_BINARY_OPERATIONS');
    const program = usePacked ?
        new BinaryOpPackedProgram(PRELU_PACKED, x.shape, alpha.shape) :
        new BinaryOpProgram(PRELU, x.shape, alpha.shape);
    return args.backend.runWebGLProgram(program, [x, alpha], 'float32');
}
// Registration descriptor wiring the Prelu kernel to the WebGL backend.
const preluConfig$1 = {
    kernelName: Prelu,
    backendName: 'webgl',
    kernelFunc: prelu$3
};
85078
85079 /**
85080 * @license
85081 * Copyright 2020 Google LLC. All Rights Reserved.
85082 * Licensed under the Apache License, Version 2.0 (the "License");
85083 * you may not use this file except in compliance with the License.
85084 * You may obtain a copy of the License at
85085 *
85086 * http://www.apache.org/licenses/LICENSE-2.0
85087 *
85088 * Unless required by applicable law or agreed to in writing, software
85089 * distributed under the License is distributed on an "AS IS" BASIS,
85090 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
85091 * See the License for the specific language governing permissions and
85092 * limitations under the License.
85093 * =============================================================================
85094 */
// Unary ops: return early so NaN inputs propagate unchanged.
const CHECK_NAN_SNIPPET_UNARY = `if (isnan(x)) return x;`;
// Unpacked binary ops: propagate NaN from either operand.
const CHECK_NAN_SNIPPET_BINARY = `
  if (isnan(a)) return a;
  if (isnan(b)) return b;
`;
// Packed binary ops: `isNaN` is a per-lane vec4 mask computed by the
// caller's snippet; flagged lanes are overwritten with NAN.
const CHECK_NAN_SNIPPET_BINARY_PACKED = `
  result.r = isNaN.r > 0. ? NAN : result.r;
  result.g = isNaN.g > 0. ? NAN : result.g;
  result.b = isNaN.b > 0. ? NAN : result.b;
  result.a = isNaN.a > 0. ? NAN : result.a;
`;
/**
 * Template that creates a `KernelFunc` for unary ops.
 * @param opSnippet Op snippet to create `UnaryOpProgram`.
 * @param packedOpSnippet Op snippet to create `UnaryOpPackedProgram`.
 * @param cpuKernelImpl Optional CPU implementation, used when the backend
 *     decides the input is best handled on CPU.
 * @param dtype Optional. If set, the result has this dtype. Otherwise, the
 *     result has the same dtype as the first input. This is mainly used in
 *     comparison kernels, such as Equal, Less, Greater, etc.
 */
function unaryKernelFunc$1({ opSnippet, packedOpSnippet, cpuKernelImpl, dtype }) {
    return ({ inputs, backend }) => {
        const { x } = inputs;
        const webglBackend = backend;
        const outDtype = dtype || x.dtype;
        // CPU forwarding: run the JS implementation when the backend deems
        // the tensor cheap enough and an implementation exists.
        if (webglBackend.shouldExecuteOnCPU([x]) && cpuKernelImpl != null) {
            const xData = webglBackend.texData.get(x.dataId);
            const outValues = cpuKernelImpl(xData.values, outDtype);
            return webglBackend.makeTensorInfo(x.shape, outDtype, outValues);
        }
        const usePacked = env().getBool('WEBGL_PACK_UNARY_OPERATIONS') &&
            packedOpSnippet != null;
        const program = usePacked ?
            new UnaryOpPackedProgram(x.shape, packedOpSnippet) :
            new UnaryOpProgram(x.shape, opSnippet);
        return webglBackend.runWebGLProgram(program, [x], outDtype);
    };
}
/**
 * Template that creates a `KernelFunc` for binary ops.
 * @param opSnippet Op snippet to create `BinaryOpProgram`.
 * @param packedOpSnippet Op snippet to create `BinaryOpPackedProgram`.
 * @param checkOutOfBounds Whether to set checkOutOfBounds=true
 *     when creating BinaryOpPackedProgram.
 * @param supportsComplex Whether complex64 inputs are handled by running
 *     the op independently on the real and imaginary components.
 * @param cpuKernelImpl Optional CPU implementation, used for string inputs
 *     or when the backend decides the tensors should run on CPU.
 * @param dtype Optional. If set, the result has this dtype. Otherwise, the
 *     result has the same dtype as the first input. This is mainly used in
 *     comparison kernels, such as Equal, Less, Greater, etc.
 */
function binaryKernelFunc$1({ opSnippet, packedOpSnippet, checkOutOfBounds = false, supportsComplex = false, cpuKernelImpl, dtype }) {
    return ({ inputs, backend }) => {
        const { a, b } = inputs;
        const webglBackend = backend;
        if (supportsComplex && a.dtype === 'complex64') {
            const aData = webglBackend.texData.get(a.dataId);
            const bData = webglBackend.texData.get(b.dataId);
            // Run the op on (aReal, bReal) and (aImag, bImag) separately.
            const [real, imag] = [
                [aData.complexTensorInfos.real, bData.complexTensorInfos.real],
                [aData.complexTensorInfos.imag, bData.complexTensorInfos.imag]
            ].map(complexParts => {
                const [aPart, bPart] = complexParts;
                const aHandle = {
                    dataId: aPart.dataId,
                    dtype: aPart.dtype,
                    shape: a.shape
                };
                const bHandle = {
                    dataId: bPart.dataId,
                    dtype: bPart.dtype,
                    shape: b.shape
                };
                const program = new BinaryOpProgram(opSnippet, a.shape, b.shape);
                return webglBackend.runWebGLProgram(program, [aHandle, bHandle], upcastType(aPart.dtype, bPart.dtype));
            });
            const complexOutput = complex$2({ inputs: { real, imag }, backend: webglBackend });
            webglBackend.disposeIntermediateTensorInfo(real);
            webglBackend.disposeIntermediateTensorInfo(imag);
            // TODO(annxingyuan): Implement CPU forwarding for complex inputs.
            return complexOutput;
        }
        const $dtype = dtype || upcastType(a.dtype, b.dtype);
        if ((a.dtype === 'string' || b.dtype === 'string' ||
            webglBackend.shouldExecuteOnCPU([a, b])) &&
            cpuKernelImpl != null) {
            const aVals = webglBackend.texData.get(a.dataId).values;
            const bVals = webglBackend.texData.get(b.dataId).values;
            const decodedAVals = a.dtype === 'string' ?
                // tslint:disable-next-line: no-any
                fromUint8ToStringArray(aVals) :
                aVals;
            // Bug fix: decode b's values based on *b's* dtype. The original
            // checked `a.dtype` here, mis-decoding b whenever the two inputs
            // disagreed on being strings.
            const decodedBVals = b.dtype === 'string' ?
                // tslint:disable-next-line: no-any
                fromUint8ToStringArray(bVals) :
                bVals;
            const [outValues, outShape] = cpuKernelImpl(a.shape, b.shape, decodedAVals, decodedBVals, $dtype);
            const out = webglBackend.makeTensorInfo(outShape, $dtype);
            const outData = webglBackend.texData.get(out.dataId);
            outData.values = outValues;
            return out;
        }
        const shouldUsePackedProgram = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') &&
            packedOpSnippet != null;
        let program;
        if (shouldUsePackedProgram) {
            program = new BinaryOpPackedProgram(packedOpSnippet, a.shape, b.shape, checkOutOfBounds);
        }
        else {
            program = new BinaryOpProgram(opSnippet, a.shape, b.shape);
        }
        return webglBackend.runWebGLProgram(program, [a, b], $dtype);
    };
}
/**
 * Maps a fused-activation name to its GLSL snippet, choosing the packed
 * (vec4) or unpacked variant. Throws for activations the WebGL backend
 * does not implement.
 */
function mapActivationToShaderProgram(activation, packed = false) {
    switch (activation) {
        case 'linear':
            return packed ? LINEAR$1 : LINEAR;
        case 'relu':
            return packed ? RELU$1 : RELU;
        case 'elu':
            return packed ? ELU$2 : ELU$1;
        case 'relu6':
            return packed ? RELU6$1 : RELU6;
        case 'prelu':
            return packed ? PRELU_PACKED : PRELU;
        case 'leakyrelu':
            return packed ? LEAKYRELU_PACKED : LEAKYRELU;
        case 'sigmoid':
            return packed ? SIGMOID$1 : SIGMOID;
        default:
            throw new Error(`Activation ${activation} has not been implemented for the WebGL backend.`);
    }
}
85253
85254 /**
85255 * @license
85256 * Copyright 2018 Google LLC. All Rights Reserved.
85257 * Licensed under the Apache License, Version 2.0 (the "License");
85258 * you may not use this file except in compliance with the License.
85259 * You may obtain a copy of the License at
85260 *
85261 * http://www.apache.org/licenses/LICENSE-2.0
85262 *
85263 * Unless required by applicable law or agreed to in writing, software
85264 * distributed under the License is distributed on an "AS IS" BASIS,
85265 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
85266 * See the License for the specific language governing permissions and
85267 * limitations under the License.
85268 * =============================================================================
85269 */
/**
 * Packed batched matmul program with optional transposes, bias add, and a
 * fused activation (including Prelu / LeakyRelu variants that read extra
 * per-output inputs). Each texel holds a vec4 of neighboring values, so
 * the shared dimension is traversed two elements at a time.
 */
class MatMulPackedProgram {
    constructor(aShape, bShape, outputShape, transposeA = false, transposeB = false, addBias = false, activation = null, hasPreluActivation = false, hasLeakyreluActivation = false) {
        this.variableNames = ['matrixA', 'matrixB'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = outputShape;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        // Inner-product length, halved (ceil) because texels pack 2 values.
        const sharedDim = transposeA ? aShape[1] : aShape[2];
        const sharedDimensionPacked = Math.ceil(sharedDim / 2);
        // Sampling coordinates and swizzle pairs flip when an input is
        // transposed, so the same loop body covers all four combinations.
        const aSample = transposeA ? 'i * 2, rc.y' : 'rc.y, i * 2';
        const bSample = transposeB ? 'rc.z, i * 2' : 'i * 2, rc.z';
        const aSwizzle = transposeA ? ['a.xxyy', 'a.zzww'] : ['a.xxzz', 'a.yyww'];
        const bSwizzle = transposeB ? ['b.xzxz', 'b.ywyw'] : ['b.xyxy', 'b.zwzw'];
        let activationSnippet = '', applyActivationSnippet = '';
        if (activation) {
            if (hasPreluActivation) {
                activationSnippet = `vec4 activation(vec4 a) {
          vec4 b = getPreluActivationWeightsAtOutCoords();
          ${activation}
        }`;
            }
            else if (hasLeakyreluActivation) {
                activationSnippet = `vec4 activation(vec4 a) {
          vec4 b = getLeakyreluAlphaAtOutCoords();
          ${activation}
        }`;
            }
            else {
                activationSnippet = `vec4 activation(vec4 x) {
          ${activation}
        }`;
            }
            applyActivationSnippet = `result = activation(result);`;
        }
        const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
        if (addBias) {
            this.variableNames.push('bias');
        }
        if (hasPreluActivation) {
            this.variableNames.push('preluActivationWeights');
        }
        if (hasLeakyreluActivation) {
            this.variableNames.push('leakyreluAlpha');
        }
        // Broadcast a size-1 batch dimension by clamping the batch index.
        let batchASnippet = 'rc.x';
        let batchBSnippet = 'rc.x';
        if (aShape[0] < bShape[0]) {
            batchASnippet = `int(min(float(rc.x), ${aShape[0] - 1}.))`;
        }
        else if (bShape[0] < aShape[0]) {
            batchBSnippet = `int(min(float(rc.x), ${bShape[0] - 1}.))`;
        }
        this.userCode = `
      ${activationSnippet}
      // Don't use uniform for sharedDimensionPacked for performance.
      const float sharedDimension = ${sharedDimensionPacked}.0;

      vec4 dot2x2ARowBCol(ivec3 rc) {
        vec4 result = vec4(0);
        for (int i = 0; i < ${sharedDimensionPacked}; i++) {
          int batchA = ${batchASnippet};
          int batchB = ${batchBSnippet};
          vec4 a = getMatrixA(batchA, ${aSample});
          vec4 b = getMatrixB(batchB, ${bSample});

          // These swizzled products need to be separately added.
          // See: https://github.com/tensorflow/tfjs/issues/1735
          result += (${aSwizzle[0]} * ${bSwizzle[0]});
          result += (${aSwizzle[1]} * ${bSwizzle[1]});
        }
        return result;
      }

      void main() {
        ivec3 rc = getOutputCoords();
        vec4 result = dot2x2ARowBCol(rc);

        ${addBiasSnippet}

        ${applyActivationSnippet}

        setOutput(result);
      }
    `;
    }
}
85356
85357 /**
85358 * @license
85359 * Copyright 2018 Google LLC. All Rights Reserved.
85360 * Licensed under the Apache License, Version 2.0 (the "License");
85361 * you may not use this file except in compliance with the License.
85362 * You may obtain a copy of the License at
85363 *
85364 * http://www.apache.org/licenses/LICENSE-2.0
85365 *
85366 * Unless required by applicable law or agreed to in writing, software
85367 * distributed under the License is distributed on an "AS IS" BASIS,
85368 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
85369 * See the License for the specific language governing permissions and
85370 * limitations under the License.
85371 * =============================================================================
85372 */
    // Complex multiplication identity:
    // (Ar + Ai)(Br + Bi) =
    // ArBr + ArBi + AiBr + AiBi = ArBr - AiBi + ArBi + AiBr
    // Yr = ArBr - AiBi
    // Yi = ArBi + AiBr
    // GLSL statement bodies implementing the real and imaginary parts above;
    // each is spliced into binaryOpComplex() by BinaryOpComplexProgram.
    const COMPLEX_MULTIPLY = {
        REAL: 'return areal * breal - aimag * bimag;',
        IMAG: 'return areal * bimag + aimag * breal;'
    };
    /**
     * WebGL program that evaluates a scalar binary op over the four float
     * component textures (real/imag of A and B) of two broadcast-compatible
     * complex tensors, producing one float output texture.
     *
     * @param op GLSL statement body for binaryOpComplex(), e.g. one of the
     *     COMPLEX_MULTIPLY snippets above.
     * @param aShape Shape of operand A (broadcast against bShape).
     * @param bShape Shape of operand B.
     */
    class BinaryOpComplexProgram {
        constructor(op, aShape, bShape) {
            // One input texture per complex component, in this fixed order.
            this.variableNames = ['AReal', 'AImag', 'BReal', 'BImag'];
            // Throws if the two shapes are not broadcast-compatible.
            this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
            this.userCode = `
      float binaryOpComplex(
          float areal, float aimag, float breal, float bimag) {
        ${op}
      }

      void main() {
        float areal = getARealAtOutCoords();
        float aimag = getAImagAtOutCoords();
        float breal = getBRealAtOutCoords();
        float bimag = getBImagAtOutCoords();
        setOutput(binaryOpComplex(areal, aimag, breal, bimag));
      }
    `;
        }
    }
85401
85402 /**
85403 * @license
85404 * Copyright 2020 Google LLC. All Rights Reserved.
85405 * Licensed under the Apache License, Version 2.0 (the "License");
85406 * you may not use this file except in compliance with the License.
85407 * You may obtain a copy of the License at
85408 *
85409 * http://www.apache.org/licenses/LICENSE-2.0
85410 *
85411 * Unless required by applicable law or agreed to in writing, software
85412 * distributed under the License is distributed on an "AS IS" BASIS,
85413 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
85414 * See the License for the specific language governing permissions and
85415 * limitations under the License.
85416 * =============================================================================
85417 */
85418 const MUL = 'return a * b;';
85419 function multiply$3(args) {
85420 const { inputs, backend } = args;
85421 const { a, b } = inputs;
85422 const dtype = upcastType(a.dtype, b.dtype);
85423 if (a.dtype === 'complex64') {
85424 const aData = backend.texData.get(a.dataId);
85425 const bData = backend.texData.get(b.dataId);
85426 const realProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.REAL, a.shape, b.shape);
85427 const imagProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.IMAG, a.shape, b.shape);
85428 const inputs = [
85429 {
85430 dataId: aData.complexTensorInfos.real.dataId,
85431 dtype: aData.complexTensorInfos.real.dtype,
85432 shape: a.shape
85433 },
85434 {
85435 dataId: aData.complexTensorInfos.imag.dataId,
85436 dtype: aData.complexTensorInfos.imag.dtype,
85437 shape: a.shape
85438 },
85439 {
85440 dataId: bData.complexTensorInfos.real.dataId,
85441 dtype: bData.complexTensorInfos.real.dtype,
85442 shape: b.shape
85443 },
85444 {
85445 dataId: bData.complexTensorInfos.imag.dataId,
85446 dtype: bData.complexTensorInfos.imag.dtype,
85447 shape: b.shape
85448 }
85449 ];
85450 const realPart = backend.runWebGLProgram(realProgram, inputs, 'float32');
85451 const imagPart = backend.runWebGLProgram(imagProgram, inputs, 'float32');
85452 const complexOutput = complex$2({ inputs: { real: realPart, imag: imagPart }, backend });
85453 backend.disposeIntermediateTensorInfo(realPart);
85454 backend.disposeIntermediateTensorInfo(imagPart);
85455 // TODO(annxingyuan): CPU forwarding for complex inputs.
85456 return complexOutput;
85457 }
85458 if (backend.shouldExecuteOnCPU([a, b])) {
85459 const aData = backend.texData.get(a.dataId);
85460 const bData = backend.texData.get(b.dataId);
85461 const [outValues, outShape] = multiplyImplCPU(a.shape, b.shape, aData.values, bData.values, dtype);
85462 const out = backend.makeTensorInfo(outShape, dtype);
85463 const outData = backend.texData.get(out.dataId);
85464 outData.values = outValues;
85465 return out;
85466 }
85467 let program;
85468 if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
85469 program = new BinaryOpPackedProgram(MUL, a.shape, b.shape);
85470 }
85471 else {
85472 program = new BinaryOpProgram(MUL, a.shape, b.shape);
85473 }
85474 return backend.runWebGLProgram(program, [a, b], dtype);
85475 }
85476 const multiplyConfig$1 = {
85477 kernelName: Multiply,
85478 backendName: 'webgl',
85479 kernelFunc: multiply$3
85480 };
85481
85482 /**
85483 * @license
85484 * Copyright 2020 Google LLC. All Rights Reserved.
85485 * Licensed under the Apache License, Version 2.0 (the "License");
85486 * you may not use this file except in compliance with the License.
85487 * You may obtain a copy of the License at
85488 *
85489 * http://www.apache.org/licenses/LICENSE-2.0
85490 *
85491 * Unless required by applicable law or agreed to in writing, software
85492 * distributed under the License is distributed on an "AS IS" BASIS,
85493 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
85494 * See the License for the specific language governing permissions and
85495 * limitations under the License.
85496 * =============================================================================
85497 */
85498 function packedReshape(input, afterShape, backend) {
85499 const input3DShape = [getBatchDim(input.shape),
85500 ...getRowsCols(input.shape)];
85501 const input3D = {
85502 dtype: input.dtype,
85503 shape: input3DShape,
85504 dataId: input.dataId
85505 };
85506 const afterShapeAs3D = [getBatchDim(afterShape),
85507 ...getRowsCols(afterShape)];
85508 const program = new ReshapePackedProgram(afterShapeAs3D, input3DShape);
85509 const preventEagerUnpackingOfOutput = true;
85510 const customValues = [input3DShape];
85511 const output = backend.runWebGLProgram(program, [input3D], input.dtype, customValues, preventEagerUnpackingOfOutput);
85512 return { dataId: output.dataId, shape: afterShape, dtype: output.dtype };
85513 }
85514
85515 /**
85516 * @license
85517 * Copyright 2020 Google LLC. All Rights Reserved.
85518 * Licensed under the Apache License, Version 2.0 (the "License");
85519 * you may not use this file except in compliance with the License.
85520 * You may obtain a copy of the License at
85521 *
85522 * http://www.apache.org/licenses/LICENSE-2.0
85523 *
85524 * Unless required by applicable law or agreed to in writing, software
85525 * distributed under the License is distributed on an "AS IS" BASIS,
85526 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
85527 * See the License for the specific language governing permissions and
85528 * limitations under the License.
85529 * =============================================================================
85530 */
85531 function reshape$3(args) {
85532 const { inputs, backend, attrs } = args;
85533 const { x } = inputs;
85534 const { shape } = attrs;
85535 const webglBackend = backend;
85536 const xSize = sizeFromShape(x.shape);
85537 const $shape = inferFromImplicitShape(shape, xSize);
85538 const $xSize = sizeFromShape($shape);
85539 assert(xSize === $xSize, () => `The new shape (${$shape}) has ${$xSize} elements and the old ` +
85540 `shape (${x.shape}) has ${xSize} elements. The new shape and old ` +
85541 `shape must have the same number of elements.`);
85542 const xTexData = webglBackend.texData.get(x.dataId);
85543 if (xTexData.isPacked && !isReshapeFree(x.shape, $shape) &&
85544 !(xTexData.texture !== null && isReshapeFree(xTexData.shape, $shape))) {
85545 return packedReshape(x, $shape, webglBackend);
85546 }
85547 webglBackend.incRef(x.dataId);
85548 return { dataId: x.dataId, shape: $shape, dtype: x.dtype };
85549 }
85550 const reshapeConfig$1 = {
85551 kernelName: Reshape,
85552 backendName: 'webgl',
85553 kernelFunc: reshape$3
85554 };
85555
85556 /**
85557 * @license
85558 * Copyright 2020 Google LLC. All Rights Reserved.
85559 * Licensed under the Apache License, Version 2.0 (the "License");
85560 * you may not use this file except in compliance with the License.
85561 * You may obtain a copy of the License at
85562 *
85563 * http://www.apache.org/licenses/LICENSE-2.0
85564 *
85565 * Unless required by applicable law or agreed to in writing, software
85566 * distributed under the License is distributed on an "AS IS" BASIS,
85567 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
85568 * See the License for the specific language governing permissions and
85569 * limitations under the License.
85570 * =============================================================================
85571 */
    /**
     * WebGL program computing one stage of a mean reduction over the inner
     * dimension of a [batchSize, inSize] tensor, `windowSize` elements per
     * output value, producing [batchSize, outSize].
     *
     * @param reduceInfo {windowSize, batchSize, inSize, outSize} stage config.
     * @param divisor When given, each element is pre-scaled by 1/divisor so
     *     the multi-stage pipeline yields a mean; when absent the stage is a
     *     plain sum of already-scaled partials.
     */
    class MeanProgram {
        constructor(reduceInfo, divisor) {
            this.variableNames = ['x'];
            const { windowSize, batchSize, inSize, outSize } = reduceInfo;
            this.outputShape = [batchSize, outSize];
            // Consume the window in vec4 chunks; the 0-3 leftover elements are
            // handled by the remainder branches below.
            const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
            const windowSizeVec4Remainder = windowSize % 4;
            let updateSnippet = `sumValue += dot(values, ones);`;
            if (divisor != null) {
                const denominator = 1 / divisor;
                // toPrecision(2) forces a decimal point on integral values so
                // the spliced literal is a valid GLSL float (e.g. '1.0' not '1').
                updateSnippet = `sumValue += dot(values * ${isInt(denominator) ? denominator.toPrecision(2) :
                denominator}, ones);`;
            }
            let checkOutOfBounds = '';
            // When inSize is not a multiple of windowSize the last window reads
            // past the row; out-of-bounds lanes must contribute 0.
            if (inSize % windowSize > 0) {
                checkOutOfBounds = `
        if (inIdx < 0 || inIdx >= ${inSize}) {
          return 0.0;
        }
      `;
            }
            this.userCode = `
      const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);

      float getValue(int batch, int inIdx) {
        ${checkOutOfBounds}
        return getX(batch, inIdx);
      }

      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];
        int outIdx = coords[1];
        int inOffset = outIdx * ${windowSize};

        float sumValue = 0.0;

        for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) {
          int inIdx = inOffset + i;
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            getValue(batch, inIdx + 3)
          );

          ${updateSnippet}
        }

        int inIdx = inOffset + ${windowSizeNearestVec4};
        if (${windowSizeVec4Remainder === 1}) {
          vec4 values = vec4(getValue(batch, inIdx), 0.0, 0.0, 0.0);

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 2}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1), 0.0, 0.0);

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 3}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2), 0.0);

          ${updateSnippet}
        }
        setOutput(sumValue);
      }
    `;
        }
    }
85645
85646 /**
85647 * @license
85648 * Copyright 2017 Google LLC. All Rights Reserved.
85649 * Licensed under the Apache License, Version 2.0 (the "License");
85650 * you may not use this file except in compliance with the License.
85651 * You may obtain a copy of the License at
85652 *
85653 * http://www.apache.org/licenses/LICENSE-2.0
85654 *
85655 * Unless required by applicable law or agreed to in writing, software
85656 * distributed under the License is distributed on an "AS IS" BASIS,
85657 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
85658 * See the License for the specific language governing permissions and
85659 * limitations under the License.
85660 * =============================================================================
85661 */
    /**
     * WebGL program performing one stage of a reduction (sum / prod / min /
     * max / all / any) over the inner dimension of a [batchSize, inSize]
     * tensor, `windowSize` elements per output, producing
     * [batchSize, outSize].
     *
     * @param reduceInfo {windowSize, batchSize, inSize, outSize} stage config.
     * @param reduceType One of 'sum' | 'prod' | 'min' | 'max' | 'all' | 'any'.
     */
    class ReduceProgram {
        constructor(reduceInfo, reduceType) {
            this.variableNames = ['x'];
            const { windowSize, batchSize, inSize, outSize } = reduceInfo;
            this.outputShape = [batchSize, outSize];
            // Identity element of the reduction; pads out-of-bounds reads and
            // unused remainder lanes so they do not affect the result.
            let initializationValue = '0.0';
            let compareOp = ``;
            if (reduceType === 'prod') {
                initializationValue = '1.0';
            }
            else if (reduceType === 'min') {
                // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
                initializationValue = '1.0 / 1e-20';
                compareOp = `min`;
            }
            else if (reduceType === 'max') {
                // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
                initializationValue = '-1.0 / 1e-20';
                compareOp = `max`;
            }
            // min/max track a running vec4; collapse it to a scalar at the end.
            let returnValue = `${reduceType}(${reduceType}(${reduceType}(` +
                'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
            if (reduceType === 'sum') {
                returnValue = `sumValue`;
            }
            else if (reduceType === 'prod') {
                returnValue = `prodValue`;
            }
            else if (reduceType === 'all') {
                returnValue = `allValue`;
            }
            else if (reduceType === 'any') {
                returnValue = `anyValue`;
            }
            // Consume the window in vec4 chunks; 0-3 leftovers handled below.
            const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
            const windowSizeVec4Remainder = windowSize % 4;
            // The `${reduceType === ...}` conditions are compile-time `true` /
            // `false` literals spliced into the shader, so the GLSL compiler
            // can eliminate the dead branches.
            let updateSnippet = `
        if (${reduceType === 'sum'}) {
          sumValue += dot(values, ones);
        } else if (${reduceType === 'prod'}) {
          vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]);
          prodValue *= tmp[0] * tmp[1];
        } else {
          minMaxValue = ${compareOp}(values, minMaxValue);
          if (${reduceType === 'min'} || ${reduceType === 'max'}) {
            minMaxValue = ${compareOp}(values, minMaxValue);
            bvec4 isNaN = isnan(values);
            if (isNaN.r || isNaN.g || isNaN.b || isNaN.a) {
              minMaxValue = vec4(NAN);
            }
          }
        }
      `;
            let vecType = `vec4`;
            if (reduceType === 'all') {
                initializationValue = '1.0';
                updateSnippet = `
        bool reducedAllValue = all(values);
        float floatedReducedAllValue = float(reducedAllValue);
        allValue = float(allValue >= 1.0 && floatedReducedAllValue >= 1.0);
      `;
                vecType = `bvec4`;
            }
            else if (reduceType === 'any') {
                initializationValue = '0.0';
                updateSnippet = `
        bool reducedAnyValue = any(values);
        float floatedReducedAnyValue = float(reducedAnyValue);
        anyValue = float(anyValue >= 1.0 || floatedReducedAnyValue >= 1.0);
      `;
                vecType = `bvec4`;
            }
            let checkOutOfBounds = '';
            // When inSize is not a multiple of windowSize the last window reads
            // past the row; return the identity element for those lanes.
            if (inSize % windowSize > 0) {
                checkOutOfBounds = `
          if (inIdx < 0 || inIdx >= ${inSize}) {
            return initializationValue;
          }
        `;
            }
            this.userCode = `
      const float initializationValue = ${initializationValue};
      const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);

      float getValue(int batch, int inIdx) {
        ${checkOutOfBounds}
        return getX(batch, inIdx);
      }

      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];
        int outIdx = coords[1];
        int inOffset = outIdx * ${windowSize};

        vec4 minMaxValue = vec4(${initializationValue});
        float prodValue = 1.0;
        float sumValue = 0.0;
        float allValue = 1.0;
        float anyValue = 0.0;

        for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) {
          int inIdx = inOffset + i;
          ${vecType} values = ${vecType}(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            getValue(batch, inIdx + 3)
          );

          ${updateSnippet}
        }

        int inIdx = inOffset + ${windowSizeNearestVec4};
        if (${windowSizeVec4Remainder === 1}) {
          ${vecType} values = ${vecType}(
            getValue(batch, inIdx),
            initializationValue,
            initializationValue,
            initializationValue
          );

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 2}) {
          ${vecType} values = ${vecType}(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            initializationValue,
            initializationValue
          );

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 3}) {
          ${vecType} values = ${vecType}(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            initializationValue
          );

          ${updateSnippet}
        }
        setOutput(${returnValue});
      }
    `;
        }
    }
85809
85810 /**
85811 * @license
85812 * Copyright 2020 Google LLC. All Rights Reserved.
85813 * Licensed under the Apache License, Version 2.0 (the "License");
85814 * you may not use this file except in compliance with the License.
85815 * You may obtain a copy of the License at
85816 *
85817 * http://www.apache.org/licenses/LICENSE-2.0
85818 *
85819 * Unless required by applicable law or agreed to in writing, software
85820 * distributed under the License is distributed on an "AS IS" BASIS,
85821 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
85822 * See the License for the specific language governing permissions and
85823 * limitations under the License.
85824 * =============================================================================
85825 */
85826 // Returns an array of configuration objects that describe each stage of the
85827 // reduction.
85828 function getReductionStages(inShape) {
85829 const stages = [];
85830 while (stages.length === 0 || stages[stages.length - 1].outSize !== 1) {
85831 const outSize = stages.length ? stages[stages.length - 1].outSize : inShape[1];
85832 const windowSize = computeOptimalWindowSize(outSize);
85833 stages.push({
85834 inSize: outSize,
85835 windowSize,
85836 outSize: Math.ceil(outSize / windowSize)
85837 });
85838 }
85839 return stages;
85840 }
85841 function reduce(x, dtype, reductionType, backend) {
85842 const reductionStages = getReductionStages(x.shape);
85843 let result = x;
85844 for (let i = 0; i < reductionStages.length; i++) {
85845 const { inSize, windowSize, outSize } = reductionStages[i];
85846 let program;
85847 let previousResult;
85848 if (reductionType === 'mean') {
85849 program = i === 0 ?
85850 new MeanProgram({ windowSize, inSize, batchSize: x.shape[0], outSize }, inSize) :
85851 new MeanProgram({ windowSize, inSize, batchSize: x.shape[0], outSize });
85852 }
85853 else {
85854 program = new ReduceProgram({ windowSize, inSize, batchSize: x.shape[0], outSize }, reductionType);
85855 }
85856 previousResult = result;
85857 result = backend.runWebGLProgram(program, [result], dtype);
85858 if (previousResult.dataId !== x.dataId) {
85859 backend.disposeIntermediateTensorInfo(previousResult);
85860 }
85861 }
85862 return result;
85863 }
85864
85865 /**
85866 * @license
85867 * Copyright 2017 Google LLC. All Rights Reserved.
85868 * Licensed under the Apache License, Version 2.0 (the "License");
85869 * you may not use this file except in compliance with the License.
85870 * You may obtain a copy of the License at
85871 *
85872 * http://www.apache.org/licenses/LICENSE-2.0
85873 *
85874 * Unless required by applicable law or agreed to in writing, software
85875 * distributed under the License is distributed on an "AS IS" BASIS,
85876 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
85877 * See the License for the specific language governing permissions and
85878 * limitations under the License.
85879 * =============================================================================
85880 */
    /**
     * WebGL program transposing tensor A by the permutation `newDim`:
     * output dimension i takes its size (and data) from input dimension
     * newDim[i]. Supports rank up to 6 (limit enforced by
     * getSwitchedCoords).
     */
    class TransposeProgram {
        constructor(aShape, newDim) {
            this.variableNames = ['A'];
            // Output dimension i has the size of input dimension newDim[i].
            const outputShape = new Array(aShape.length);
            for (let i = 0; i < outputShape.length; i++) {
                outputShape[i] = aShape[newDim[i]];
            }
            this.outputShape = outputShape;
            this.rank = outputShape.length;
            const dtype = getCoordsDataType(this.rank);
            // Comma-separated input coordinates expressed in terms of the
            // output coordinate channels (resRC.*).
            const switched = getSwitchedCoords(newDim);
            this.userCode = `
    void main() {
      ${dtype} resRC = getOutputCoords();
      setOutput(getA(${switched}));
    }
    `;
        }
    }
85900 function getSwitchedCoords(newDim) {
85901 const rank = newDim.length;
85902 if (rank > 6) {
85903 throw Error(`Transpose for rank ${rank} is not yet supported`);
85904 }
85905 const originalOrder = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u', 'resRC.v'];
85906 const switchedCoords = new Array(rank);
85907 for (let i = 0; i < newDim.length; i++) {
85908 switchedCoords[newDim[i]] = originalOrder[i];
85909 }
85910 return switchedCoords.join();
85911 }
85912
85913 /**
85914 * @license
85915 * Copyright 2019 Google LLC. All Rights Reserved.
85916 * Licensed under the Apache License, Version 2.0 (the "License");
85917 * you may not use this file except in compliance with the License.
85918 * You may obtain a copy of the License at
85919 *
85920 * http://www.apache.org/licenses/LICENSE-2.0
85921 *
85922 * Unless required by applicable law or agreed to in writing, software
85923 * distributed under the License is distributed on an "AS IS" BASIS,
85924 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
85925 * See the License for the specific language governing permissions and
85926 * limitations under the License.
85927 * =============================================================================
85928 */
    /**
     * Packed-texture variant of the transpose program. Each invocation fills
     * one vec4 texel (a 2x2 tile of the last two output dimensions) by
     * fetching the four corresponding input channels individually.
     * Supports rank up to 6.
     */
    class TransposePackedProgram {
        constructor(aShape, newDim) {
            this.variableNames = ['A'];
            this.packedInputs = true;
            this.packedOutput = true;
            // Output dimension i has the size of input dimension newDim[i].
            const outputShape = new Array(aShape.length);
            for (let i = 0; i < outputShape.length; i++) {
                outputShape[i] = aShape[newDim[i]];
            }
            this.outputShape = outputShape;
            this.rank = outputShape.length;
            if (this.rank > 6) {
                throw Error(`Packed transpose for rank ${this.rank} is not yet supported.`);
            }
            const dtype = getCoordsDataType(this.rank);
            const outputOrder = getVecChannels('rc', this.rank);
            // Input coordinates expressed in output channels: input axis
            // newDim[i] is addressed by output channel i.
            const switchedOrder = new Array(this.rank);
            for (let i = 0; i < newDim.length; i++) {
                switchedOrder[newDim[i]] = outputOrder[i];
            }
            // The last two input coordinates select the channel within the
            // fetched input texel.
            const innerDims = `vec2(${switchedOrder.slice(-2).join()})`;
            // Bounds-checked advance of the innermost output coordinate; the
            // shader mutates rc in place to visit the 2x2 tile corners.
            const nextColumn = `++${outputOrder[this.rank - 1]} < ${outputShape[this.rank - 1]}`;
            const getc = `getChannel(getA(${switchedOrder.join()}), ${innerDims})`;
            this.userCode = `
    void main() {
      ${dtype} rc = getOutputCoords();
      vec4 result = vec4(0.);
      result[0] = ${getc};
      if(${nextColumn}) {
        result[1] = ${getc};
      }
      --${outputOrder[this.rank - 1]};
      if(++${outputOrder[this.rank - 2]} < ${outputShape[this.rank - 2]}) {
        result[2] = ${getc};
        if(${nextColumn}) {
          result[3] = ${getc};
        }
      }
      setOutput(result);
    }
    `;
        }
    }
85972
85973 /**
85974 * @license
85975 * Copyright 2020 Google LLC. All Rights Reserved.
85976 * Licensed under the Apache License, Version 2.0 (the "License");
85977 * you may not use this file except in compliance with the License.
85978 * You may obtain a copy of the License at
85979 *
85980 * http://www.apache.org/licenses/LICENSE-2.0
85981 *
85982 * Unless required by applicable law or agreed to in writing, software
85983 * distributed under the License is distributed on an "AS IS" BASIS,
85984 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
85985 * See the License for the specific language governing permissions and
85986 * limitations under the License.
85987 * =============================================================================
85988 */
85989 function transposeImpl$1(x, perm, backend) {
85990 const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
85991 new TransposePackedProgram(x.shape, perm) :
85992 new TransposeProgram(x.shape, perm);
85993 return backend.runWebGLProgram(program, [x], x.dtype);
85994 }
85995
85996 /**
85997 * @license
85998 * Copyright 2020 Google LLC. All Rights Reserved.
85999 * Licensed under the Apache License, Version 2.0 (the "License");
86000 * you may not use this file except in compliance with the License.
86001 * You may obtain a copy of the License at
86002 *
86003 * http://www.apache.org/licenses/LICENSE-2.0
86004 *
86005 * Unless required by applicable law or agreed to in writing, software
86006 * distributed under the License is distributed on an "AS IS" BASIS,
86007 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86008 * See the License for the specific language governing permissions and
86009 * limitations under the License.
86010 * =============================================================================
86011 */
86012 function sumImpl(x, axis, keepDims, backend) {
86013 const reductionIndices = axis;
86014 const xRank = x.shape.length;
86015 const origAxes = parseAxisParam(reductionIndices, x.shape);
86016 let axes = origAxes;
86017 const permutedAxes = getAxesPermutation(axes, xRank);
86018 const sumInputIsTransposed = permutedAxes != null;
86019 let sumInput = x;
86020 if (sumInputIsTransposed) {
86021 sumInput = transposeImpl$1(x, permutedAxes, backend);
86022 axes = getInnerMostAxes(axes.length, xRank);
86023 }
86024 assertAxesAreInnerMostDims('sum', axes, xRank);
86025 const [sumOutShape, reduceShape] = computeOutAndReduceShapes(sumInput.shape, axes);
86026 let outShape = sumOutShape;
86027 if (keepDims) {
86028 // rather than reshape at the end, set the target shape here.
86029 outShape = expandShapeToKeepDim(sumOutShape, origAxes);
86030 }
86031 const inSize = sizeFromShape(reduceShape);
86032 const xSize = sizeFromShape(x.shape);
86033 const batchSize = xSize / inSize;
86034 const reshapedInput = reshape$3({ inputs: { x: sumInput }, attrs: { shape: [batchSize, inSize] }, backend });
86035 const outType = sumOutType(x.dtype);
86036 const reduced = reduce(reshapedInput, outType, 'sum', backend);
86037 const out = reshape$3({ inputs: { x: reduced }, attrs: { shape: outShape }, backend });
86038 backend.disposeIntermediateTensorInfo(reshapedInput);
86039 backend.disposeIntermediateTensorInfo(reduced);
86040 if (sumInputIsTransposed) {
86041 backend.disposeIntermediateTensorInfo(sumInput);
86042 }
86043 return out;
86044 }
86045
86046 /**
86047 * @license
86048 * Copyright 2020 Google LLC. All Rights Reserved.
86049 * Licensed under the Apache License, Version 2.0 (the "License");
86050 * you may not use this file except in compliance with the License.
86051 * You may obtain a copy of the License at
86052 *
86053 * http://www.apache.org/licenses/LICENSE-2.0
86054 *
86055 * Unless required by applicable law or agreed to in writing, software
86056 * distributed under the License is distributed on an "AS IS" BASIS,
86057 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86058 * See the License for the specific language governing permissions and
86059 * limitations under the License.
86060 * =============================================================================
86061 */
86062 function sum$4(args) {
86063 const { inputs, backend, attrs } = args;
86064 const { x } = inputs;
86065 const { axis, keepDims } = attrs;
86066 return sumImpl(x, axis, keepDims, backend);
86067 }
86068 const sumConfig$1 = {
86069 kernelName: Sum,
86070 backendName: 'webgl',
86071 kernelFunc: sum$4
86072 };
86073
86074 /**
86075 * @license
86076 * Copyright 2020 Google LLC. All Rights Reserved.
86077 * Licensed under the Apache License, Version 2.0 (the "License");
86078 * you may not use this file except in compliance with the License.
86079 * You may obtain a copy of the License at
86080 *
86081 * http://www.apache.org/licenses/LICENSE-2.0
86082 *
86083 * Unless required by applicable law or agreed to in writing, software
86084 * distributed under the License is distributed on an "AS IS" BASIS,
86085 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86086 * See the License for the specific language governing permissions and
86087 * limitations under the License.
86088 * =============================================================================
86089 */
86090 function transpose$2(args) {
86091 const { inputs, backend, attrs } = args;
86092 const { x } = inputs;
86093 const { perm } = attrs;
86094 const webglBackend = backend;
86095 const xRank = x.shape.length;
86096 const newShape = new Array(xRank);
86097 for (let i = 0; i < newShape.length; i++) {
86098 newShape[i] = x.shape[perm[i]];
86099 }
86100 let out;
86101 if (webglBackend.shouldExecuteOnCPU([x])) {
86102 const xTexData = webglBackend.texData.get(x.dataId);
86103 const values = xTexData.values;
86104 const outValues = transposeImplCPU(values, x.shape, x.dtype, perm, newShape);
86105 out = webglBackend.makeTensorInfo(newShape, x.dtype);
86106 const outData = webglBackend.texData.get(out.dataId);
86107 outData.values = outValues;
86108 }
86109 else {
86110 out = transposeImpl$1(x, perm, webglBackend);
86111 }
86112 return out;
86113 }
86114 const transposeConfig$1 = {
86115 kernelName: Transpose,
86116 backendName: 'webgl',
86117 kernelFunc: transpose$2
86118 };
86119
86120 /**
86121 * @license
86122 * Copyright 2020 Google LLC. All Rights Reserved.
86123 * Licensed under the Apache License, Version 2.0 (the "License");
86124 * you may not use this file except in compliance with the License.
86125 * You may obtain a copy of the License at
86126 *
86127 * http://www.apache.org/licenses/LICENSE-2.0
86128 *
86129 * Unless required by applicable law or agreed to in writing, software
86130 * distributed under the License is distributed on an "AS IS" BASIS,
86131 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86132 * See the License for the specific language governing permissions and
86133 * limitations under the License.
86134 * =============================================================================
86135 */
86136 // Empirically determined minimal shared dimension in matmul before we forward
86137 // to a.mul(b).sum() in order to take advantage of GPU parallelism. See
86138 // https://github.com/tensorflow/tfjs-core/pull/1379 for benchmarks.
86139 const MATMUL_SHARED_DIM_THRESHOLD = 1000;
    /**
     * Shared implementation of (optionally fused) batched matrix multiplication
     * for the WebGL backend: out = op(A [xT]) · op(B [xT]) (+ bias, activation).
     *
     * Inputs `a` and `b` are tensors of rank >= 2 whose two trailing dimensions
     * are the matrix dimensions; all leading dimensions are broadcastable batch
     * dimensions. Optional fusion inputs:
     *   - bias: added to the product (broadcast).
     *   - preluActivationWeights: alpha tensor when activation === 'prelu'.
     *   - leakyreluAlpha: scalar slope when activation === 'leakyrelu'.
     *   - activation: activation name mapped to a shader snippet, or null.
     * Returns a new TensorInfo of shape broadcastBatchDims + [outerA, outerB];
     * all internally created tensors are disposed before returning.
     */
    function batchMatMulImpl({ a, b, transposeA, transposeB, backend, bias = null, preluActivationWeights = null, leakyreluAlpha = 0, activation = null }) {
        const aRank = a.shape.length;
        const bRank = b.shape.length;
        // Contraction (inner) and result (outer) dims, accounting for transpose.
        const innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1];
        const innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2];
        const outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2];
        const outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1];
        // Leading (batch) dims of each operand, flattened to a single count.
        const outerDimsA = a.shape.slice(0, -2);
        const outerDimsB = b.shape.slice(0, -2);
        const batchDimA = sizeFromShape(outerDimsA);
        const batchDimB = sizeFromShape(outerDimsB);
        // Broadcast the batch dims of a and b; throws if incompatible.
        const outShapeOuterDims = assertAndGetBroadcastShape(a.shape.slice(0, -2), b.shape.slice(0, -2));
        const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
        assert(innerShapeA === innerShapeB, () => `Error in matMul: inner shapes (${innerShapeA}) and (` +
            `${innerShapeB}) of Tensors with shapes ${a.shape} and ` +
            `${b.shape} and transposeA=${transposeA}` +
            ` and transposeB=${transposeB} must match.`);
        const a3dShape = transposeA ?
            [batchDimA, innerShapeA, outerShapeA] :
            [batchDimA, outerShapeA, innerShapeA];
        const b3dShape = transposeB ?
            [batchDimB, outerShapeB, innerShapeB] :
            [batchDimB, innerShapeB, outerShapeB];
        // The rest of the implementation is designed to operate on rank-3 tensors
        const a3d = reshape$3({ inputs: { x: a }, backend, attrs: { shape: a3dShape } });
        const b3d = reshape$3({ inputs: { x: b }, backend, attrs: { shape: b3dShape } });
        // Every tensor pushed here is disposed before returning.
        const intermediates = [a3d, b3d];
        const batchDim = Math.max(batchDimA, batchDimB);
        const sharedDim = transposeA ? a3d.shape[1] : a3d.shape[2];
        const hasBias = bias != null;
        const hasPreluActivationWeights = preluActivationWeights != null;
        const hasLeakyreluAlpha = activation === 'leakyrelu';
        const fusedActivation = activation != null ?
            mapActivationToShaderProgram(activation, true) :
            null;
        const containsFusedOps = hasBias || hasPreluActivationWeights ||
            hasLeakyreluAlpha || fusedActivation != null;
        let out;
        // Since the matrices are vectors, it is faster to call mul().sum()
        // because sum() is O(sqrt(N)) due to divide-and-conquer.
        // (Fast path is only taken when there is nothing to fuse.)
        if ((outerShapeA === 1 || outerShapeB === 1) &&
            sharedDim > MATMUL_SHARED_DIM_THRESHOLD && containsFusedOps === false) {
            let aVec = a3d;
            let bVec = b3d;
            // Materialize the transposes so the reshapes below see plain layout.
            if (transposeA) {
                aVec = transpose$2({ inputs: { x: a3d }, backend, attrs: { perm: [0, 2, 1] } });
                intermediates.push(aVec);
            }
            if (transposeB) {
                bVec = transpose$2({ inputs: { x: b3d }, backend, attrs: { perm: [0, 2, 1] } });
                intermediates.push(bVec);
            }
            // Exactly one operand is reshaped so that elementwise multiply
            // broadcasts along the vector operand's singleton dimension.
            const shouldReshapeA = outerShapeB !== 1;
            const shouldReshapeB = outerShapeB === 1;
            let aVec3d = aVec;
            if (shouldReshapeA) {
                aVec3d = reshape$3({
                    inputs: { x: aVec },
                    backend,
                    attrs: { shape: [batchDim, sharedDim, 1] }
                });
                intermediates.push(aVec3d);
            }
            // The reduction axis is the one holding sharedDim after the reshape.
            const axis = outerShapeB === 1 ? 2 : 1;
            let bVec3d = bVec;
            if (shouldReshapeB) {
                bVec3d = reshape$3({
                    inputs: { x: bVec },
                    backend,
                    attrs: { shape: [batchDim, 1, sharedDim] }
                });
                intermediates.push(bVec3d);
            }
            // matvec as broadcasted multiply followed by a sum-reduction.
            const product = multiply$3({ inputs: { a: aVec3d, b: bVec3d }, backend });
            out = sum$4({ inputs: { x: product }, backend, attrs: { axis, keepDims: true } });
            intermediates.push(product);
        }
        else {
            // General path: single packed-matmul shader, with fusion baked in.
            const dtype = upcastType(a.dtype, b.dtype);
            const program = new MatMulPackedProgram(a3dShape, b3dShape, [batchDim, outerShapeA, outerShapeB], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
            const inputs = [a3d, b3d];
            if (bias != null) {
                inputs.push(bias);
            }
            if (hasPreluActivationWeights) {
                inputs.push(preluActivationWeights);
            }
            if (hasLeakyreluAlpha) {
                // The scalar alpha must be uploaded as a tensor for the shader.
                const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
                inputs.push($leakyreluAlpha);
                intermediates.push($leakyreluAlpha);
            }
            out = backend.runWebGLProgram(program, inputs, dtype);
        }
        // Restore the broadcasted batch dims; `out` itself is an intermediate.
        const outReshaped = reshape$3({ inputs: { x: out }, backend, attrs: { shape: outShape } });
        intermediates.push(out);
        for (const i of intermediates) {
            backend.disposeIntermediateTensorInfo(i);
        }
        return outReshaped;
    }
86241
86242 /**
86243 * @license
86244 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
86246 * you may not use this file except in compliance with the License.
86247 * You may obtain a copy of the License at
86248 *
86249 * http://www.apache.org/licenses/LICENSE-2.0
86250 *
86251 * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
86253 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86254 * See the License for the specific language governing permissions and
86255 * limitations under the License.
86256 * =============================================================================
86257 */
86258 function _fusedMatMul$1(args) {
86259 const { inputs, backend, attrs } = args;
86260 const { a, b, bias, preluActivationWeights } = inputs;
86261 const { transposeA, transposeB, activation, leakyreluAlpha } = attrs;
86262 return batchMatMulImpl({
86263 a,
86264 b,
86265 transposeA,
86266 transposeB,
86267 backend,
86268 bias,
86269 preluActivationWeights,
86270 leakyreluAlpha,
86271 activation
86272 });
86273 }
86274 const _fusedMatMulConfig$1 = {
86275 kernelName: _FusedMatMul,
86276 backendName: 'webgl',
86277 kernelFunc: _fusedMatMul$1,
86278 };
86279
86280 /**
86281 * @license
86282 * Copyright 2020 Google LLC. All Rights Reserved.
86283 * Licensed under the Apache License, Version 2.0 (the "License");
86284 * you may not use this file except in compliance with the License.
86285 * You may obtain a copy of the License at
86286 *
86287 * http://www.apache.org/licenses/LICENSE-2.0
86288 *
86289 * Unless required by applicable law or agreed to in writing, software
86290 * distributed under the License is distributed on an "AS IS" BASIS,
86291 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86292 * See the License for the specific language governing permissions and
86293 * limitations under the License.
86294 * =============================================================================
86295 */
86296 const ABS$1 = `return abs(x);`;
86297 function abs$2(args) {
86298 const { inputs, backend } = args;
86299 const { x } = inputs;
86300 // TODO: handle cases when x is complex. Once the cpu implementation
86301 // can handle complex values, refactor to use unaryKernelFunc.
86302 if (backend.shouldExecuteOnCPU([x]) && x.dtype !== 'complex64') {
86303 const xData = backend.texData.get(x.dataId);
86304 const outValues = simpleAbsImplCPU(xData.values);
86305 return backend.makeTensorInfo(x.shape, x.dtype, outValues);
86306 }
86307 let program;
86308 if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
86309 program = new UnaryOpPackedProgram(x.shape, ABS$1);
86310 }
86311 else {
86312 program = new UnaryOpProgram(x.shape, ABS$1);
86313 }
86314 return backend.runWebGLProgram(program, [x], x.dtype);
86315 }
86316 const absConfig$1 = {
86317 kernelName: Abs,
86318 backendName: 'webgl',
86319 kernelFunc: abs$2
86320 };
86321
86322 /**
86323 * @license
86324 * Copyright 2020 Google LLC. All Rights Reserved.
86325 * Licensed under the Apache License, Version 2.0 (the "License");
86326 * you may not use this file except in compliance with the License.
86327 * You may obtain a copy of the License at
86328 *
86329 * http://www.apache.org/licenses/LICENSE-2.0
86330 *
86331 * Unless required by applicable law or agreed to in writing, software
86332 * distributed under the License is distributed on an "AS IS" BASIS,
86333 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86334 * See the License for the specific language governing permissions and
86335 * limitations under the License.
86336 * =============================================================================
86337 */
    // GLSL snippet for acos. CHECK_NAN_SNIPPET propagates NaN inputs; inputs
    // outside [-1, 1] are mapped to NaN explicitly, since GLSL's acos is
    // undefined there.
    const ACOS = CHECK_NAN_SNIPPET + `
  if (abs(x) > 1.) {
    return NAN;
  }
  return acos(x);
`;
    const acos$2 = unaryKernelFunc$1({ opSnippet: ACOS });
    // Registration entry for the Acos kernel on the WebGL backend.
    const acosConfig$1 = {
        kernelName: Acos,
        backendName: 'webgl',
        kernelFunc: acos$2,
    };
86350
86351 /**
86352 * @license
86353 * Copyright 2020 Google LLC. All Rights Reserved.
86354 * Licensed under the Apache License, Version 2.0 (the "License");
86355 * you may not use this file except in compliance with the License.
86356 * You may obtain a copy of the License at
86357 *
86358 * http://www.apache.org/licenses/LICENSE-2.0
86359 *
86360 * Unless required by applicable law or agreed to in writing, software
86361 * distributed under the License is distributed on an "AS IS" BASIS,
86362 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86363 * See the License for the specific language governing permissions and
86364 * limitations under the License.
86365 * =============================================================================
86366 */
    // GLSL snippet for acosh = log(x + sqrt(x^2 - 1)). The domain is [1, inf);
    // values below 1 yield NaN explicitly. CHECK_NAN_SNIPPET propagates NaNs.
    const ACOSH = CHECK_NAN_SNIPPET + `
  if (x < 1.0) return NAN;
return log(x + sqrt(x * x - 1.0));`;
    const acosh$2 = unaryKernelFunc$1({ opSnippet: ACOSH });
    // Registration entry for the Acosh kernel on the WebGL backend.
    const acoshConfig$1 = {
        kernelName: Acosh,
        backendName: 'webgl',
        kernelFunc: acosh$2,
    };
86376
86377 /**
86378 * @license
86379 * Copyright 2020 Google LLC. All Rights Reserved.
86380 * Licensed under the Apache License, Version 2.0 (the "License");
86381 * you may not use this file except in compliance with the License.
86382 * You may obtain a copy of the License at
86383 *
86384 * http://www.apache.org/licenses/LICENSE-2.0
86385 *
86386 * Unless required by applicable law or agreed to in writing, software
86387 * distributed under the License is distributed on an "AS IS" BASIS,
86388 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86389 * See the License for the specific language governing permissions and
86390 * limitations under the License.
86391 * =============================================================================
86392 */
86393 const ADD = 'return a + b;';
86394 const addKernelFunc = binaryKernelFunc$1({
86395 opSnippet: ADD,
86396 packedOpSnippet: ADD,
86397 supportsComplex: true,
86398 cpuKernelImpl: addImplCPU
86399 });
86400 const addConfig$1 = {
86401 kernelName: Add,
86402 backendName: 'webgl',
86403 kernelFunc: addKernelFunc
86404 };
86405
86406 /**
86407 * @license
86408 * Copyright 2019 Google LLC. All Rights Reserved.
86409 * Licensed under the Apache License, Version 2.0 (the "License");
86410 * you may not use this file except in compliance with the License.
86411 * You may obtain a copy of the License at
86412 *
86413 * http://www.apache.org/licenses/LICENSE-2.0
86414 *
86415 * Unless required by applicable law or agreed to in writing, software
86416 * distributed under the License is distributed on an "AS IS" BASIS,
86417 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86418 * See the License for the specific language governing permissions and
86419 * limitations under the License.
86420 * =============================================================================
86421 */
86422 class AddNProgram {
86423 constructor(outputShape, shapes) {
86424 this.outputShape = [];
86425 this.outputShape = outputShape;
86426 this.variableNames = shapes.map((_, i) => `T${i}`);
86427 const snippets = [];
86428 // Get target elements from every input tensor.
86429 this.variableNames.forEach(variable => {
86430 snippets.push(`float v${variable} = get${variable}AtOutCoords();`);
86431 });
86432 // Calculate the sum of all elements.
86433 const operation = this.variableNames
86434 .map(variable => {
86435 return `v${variable}`;
86436 })
86437 .join(' + ');
86438 this.userCode = `
86439 void main() {
86440 ${snippets.join('\n ')}
86441
86442 float result = ${operation};
86443 setOutput(result);
86444 }
86445 `;
86446 }
86447 }
86448
86449 /**
86450 * @license
86451 * Copyright 2019 Google LLC. All Rights Reserved.
86452 * Licensed under the Apache License, Version 2.0 (the "License");
86453 * you may not use this file except in compliance with the License.
86454 * You may obtain a copy of the License at
86455 *
86456 * http://www.apache.org/licenses/LICENSE-2.0
86457 *
86458 * Unless required by applicable law or agreed to in writing, software
86459 * distributed under the License is distributed on an "AS IS" BASIS,
86460 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86461 * See the License for the specific language governing permissions and
86462 * limitations under the License.
86463 * =============================================================================
86464 */
86465 class AddNPackedProgram {
86466 constructor(outputShape, shapes) {
86467 this.outputShape = [];
86468 this.packedInputs = true;
86469 this.packedOutput = true;
86470 this.outputShape = outputShape;
86471 this.variableNames = shapes.map((_, i) => `T${i}`);
86472 const snippets = [];
86473 // Get target elements from every input tensor.
86474 this.variableNames.forEach(variable => {
86475 snippets.push(`vec4 v${variable} = get${variable}AtOutCoords();`);
86476 });
86477 // Calculate the sum of all elements.
86478 const operation = this.variableNames
86479 .map(variable => {
86480 return `v${variable}`;
86481 })
86482 .join(' + ');
86483 this.userCode = `
86484 void main() {
86485 ${snippets.join('\n ')}
86486
86487 vec4 result = ${operation};
86488 setOutput(result);
86489 }
86490 `;
86491 }
86492 }
86493
86494 /**
86495 * @license
86496 * Copyright 2020 Google LLC. All Rights Reserved.
86497 * Licensed under the Apache License, Version 2.0 (the "License");
86498 * you may not use this file except in compliance with the License.
86499 * You may obtain a copy of the License at
86500 *
86501 * http://www.apache.org/licenses/LICENSE-2.0
86502 *
86503 * Unless required by applicable law or agreed to in writing, software
86504 * distributed under the License is distributed on an "AS IS" BASIS,
86505 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86506 * See the License for the specific language governing permissions and
86507 * limitations under the License.
86508 * =============================================================================
86509 */
86510 function addN$2(args) {
86511 const { inputs, backend } = args;
86512 const tensors = inputs;
86513 if (tensors.length === 1) {
86514 return identity$2({ inputs: { x: tensors[0] }, backend });
86515 }
86516 // Limit the number of uploaded textures for optimization.
86517 if (tensors.length > env().get('WEBGL_MAX_TEXTURES_IN_SHADER')) {
86518 const midIndex = Math.floor(tensors.length / 2);
86519 const leftSide = addN$2({ inputs: tensors.slice(0, midIndex), backend });
86520 const rightSide = addN$2({ inputs: tensors.slice(midIndex), backend });
86521 return addN$2({ inputs: [leftSide, rightSide], backend });
86522 }
86523 const dtype = tensors.map(t => t.dtype).reduce((d1, d2) => upcastType(d1, d2));
86524 const shapes = tensors.map(t => t.shape);
86525 // We can make sure shapes are identical in op level.
86526 const usePackedOp = env().getBool('WEBGL_PACK');
86527 const program = usePackedOp ?
86528 new AddNPackedProgram(tensors[0].shape, shapes) :
86529 new AddNProgram(tensors[0].shape, shapes);
86530 return backend.runWebGLProgram(program, tensors, dtype);
86531 }
86532 const addNConfig$1 = {
86533 kernelName: AddN,
86534 backendName: 'webgl',
86535 kernelFunc: addN$2
86536 };
86537
86538 /**
86539 * @license
86540 * Copyright 2020 Google LLC. All Rights Reserved.
86541 * Licensed under the Apache License, Version 2.0 (the "License");
86542 * you may not use this file except in compliance with the License.
86543 * You may obtain a copy of the License at
86544 *
86545 * http://www.apache.org/licenses/LICENSE-2.0
86546 *
86547 * Unless required by applicable law or agreed to in writing, software
86548 * distributed under the License is distributed on an "AS IS" BASIS,
86549 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86550 * See the License for the specific language governing permissions and
86551 * limitations under the License.
86552 * =============================================================================
86553 */
86554 function all$2(args) {
86555 const { inputs, backend, attrs } = args;
86556 const { x } = inputs;
86557 const { axis, keepDims } = attrs;
86558 const xRank = x.shape.length;
86559 const origAxes = parseAxisParam(axis, x.shape);
86560 let axes = origAxes;
86561 const permutedAxes = getAxesPermutation(axes, xRank);
86562 let permutedX = x;
86563 if (permutedAxes != null) {
86564 permutedX = transpose$2({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
86565 axes = getInnerMostAxes(axes.length, xRank);
86566 }
86567 assertAxesAreInnerMostDims('all', axes, xRank);
86568 const [outShape, reduceShape] = computeOutAndReduceShapes(permutedX.shape, axes);
86569 const inSize = sizeFromShape(reduceShape);
86570 const a2D = reshape$3({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } });
86571 const reduced = reduce(a2D, a2D.dtype, 'all', backend);
86572 let res;
86573 if (keepDims) {
86574 const newShape = expandShapeToKeepDim(outShape, origAxes);
86575 res = reshape$3({ inputs: { x: reduced }, backend, attrs: { shape: newShape } });
86576 }
86577 else {
86578 res = reshape$3({ inputs: { x: reduced }, backend, attrs: { shape: outShape } });
86579 }
86580 backend.disposeIntermediateTensorInfo(a2D);
86581 backend.disposeIntermediateTensorInfo(reduced);
86582 if (permutedAxes != null) {
86583 backend.disposeIntermediateTensorInfo(permutedX);
86584 }
86585 return res;
86586 }
86587 const allConfig$1 = {
86588 kernelName: All,
86589 backendName: 'webgl',
86590 kernelFunc: all$2
86591 };
86592
86593 /**
86594 * @license
86595 * Copyright 2020 Google LLC. All Rights Reserved.
86596 * Licensed under the Apache License, Version 2.0 (the "License");
86597 * you may not use this file except in compliance with the License.
86598 * You may obtain a copy of the License at
86599 *
86600 * http://www.apache.org/licenses/LICENSE-2.0
86601 *
86602 * Unless required by applicable law or agreed to in writing, software
86603 * distributed under the License is distributed on an "AS IS" BASIS,
86604 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86605 * See the License for the specific language governing permissions and
86606 * limitations under the License.
86607 * =============================================================================
86608 */
86609 function any$2(args) {
86610 const { inputs, backend, attrs } = args;
86611 const { x } = inputs;
86612 const { axis, keepDims } = attrs;
86613 const xRank = x.shape.length;
86614 const origAxes = parseAxisParam(axis, x.shape);
86615 let axes = origAxes;
86616 const permutedAxes = getAxesPermutation(axes, xRank);
86617 let permutedX = x;
86618 if (permutedAxes != null) {
86619 permutedX = transpose$2({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
86620 axes = getInnerMostAxes(axes.length, xRank);
86621 }
86622 assertAxesAreInnerMostDims('any', axes, xRank);
86623 const [outShape, reduceShape] = computeOutAndReduceShapes(permutedX.shape, axes);
86624 const inSize = sizeFromShape(reduceShape);
86625 const a2D = reshape$3({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } });
86626 const reduced = reduce(a2D, a2D.dtype, 'any', backend);
86627 let res;
86628 if (keepDims) {
86629 const newShape = expandShapeToKeepDim(outShape, origAxes);
86630 res = reshape$3({ inputs: { x: reduced }, backend, attrs: { shape: newShape } });
86631 }
86632 else {
86633 res = reshape$3({ inputs: { x: reduced }, backend, attrs: { shape: outShape } });
86634 }
86635 backend.disposeIntermediateTensorInfo(a2D);
86636 backend.disposeIntermediateTensorInfo(reduced);
86637 if (permutedAxes != null) {
86638 backend.disposeIntermediateTensorInfo(permutedX);
86639 }
86640 return res;
86641 }
86642 const anyConfig$1 = {
86643 kernelName: Any,
86644 backendName: 'webgl',
86645 kernelFunc: any$2
86646 };
86647
86648 /**
86649 * @license
86650 * Copyright 2017 Google LLC. All Rights Reserved.
86651 * Licensed under the Apache License, Version 2.0 (the "License");
86652 * you may not use this file except in compliance with the License.
86653 * You may obtain a copy of the License at
86654 *
86655 * http://www.apache.org/licenses/LICENSE-2.0
86656 *
86657 * Unless required by applicable law or agreed to in writing, software
86658 * distributed under the License is distributed on an "AS IS" BASIS,
86659 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86660 * See the License for the specific language governing permissions and
86661 * limitations under the License.
86662 * =============================================================================
86663 */
86664 class ArgMinMaxProgram {
86665 constructor(reduceInfo, op, firstPass) {
86666 this.variableNames = ['A'];
86667 const { windowSize, batchSize, outSize } = reduceInfo;
86668 if (!firstPass) {
86669 this.variableNames.push('bestIndicesA');
86670 }
86671 this.outputShape = [batchSize, outSize];
86672 const compOp = (op === 'max') ? '>' : '<';
86673 const indexSnippet = firstPass ?
86674 'inOffset + i;' :
86675 'round(getBestIndicesA(batch, inOffset + i));';
86676 this.userCode = `
86677 void main() {
86678 ivec2 coords = getOutputCoords();
86679 int batch = coords[0];
86680 int outIdx = coords[1];
86681 int inOffset = outIdx * ${windowSize};
86682
86683 int bestIndex = inOffset;
86684 float bestValue = getA(batch, bestIndex);
86685
86686 for (int i = 0; i < ${windowSize}; i++) {
86687 int inIdx = ${indexSnippet};
86688 float candidate = getA(batch, inIdx);
86689 if (candidate ${compOp} bestValue) {
86690 bestValue = candidate;
86691 bestIndex = inIdx;
86692 }
86693 }
86694 setOutput(float(bestIndex));
86695 }
86696 `;
86697 }
86698 }
86699
86700 /**
86701 * @license
86702 * Copyright 2019 Google LLC. All Rights Reserved.
86703 * Licensed under the Apache License, Version 2.0 (the "License");
86704 * you may not use this file except in compliance with the License.
86705 * You may obtain a copy of the License at
86706 *
86707 * http://www.apache.org/licenses/LICENSE-2.0
86708 *
86709 * Unless required by applicable law or agreed to in writing, software
86710 * distributed under the License is distributed on an "AS IS" BASIS,
86711 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86712 * See the License for the specific language governing permissions and
86713 * limitations under the License.
86714 * =============================================================================
86715 */
    /**
     * Packed (vec4 texel) shader for one pass of a windowed argmin/argmax
     * reduction over the innermost dimension of a rank > 2 tensor. Each texel
     * covers a 2x2 block (R/G/B/A channels); the shader tracks best value and
     * best index per channel. On non-first passes candidate indices come from
     * the previous pass's 'bestIndicesA' tensor. When the window covers the
     * whole innermost dimension (outSize === 1), that dimension is dropped
     * from the output shape.
     */
    class ArgMinMaxPackedProgram {
        constructor(shape, windowSize, op, firstPass) {
            this.variableNames = ['A'];
            this.packedInputs = true;
            this.packedOutput = true;
            assert(shape.length > 2, () => `Packed arg${op.charAt(0).toUpperCase() +
                op.slice(1)} supports only inputs with rank above 2.`);
            const inSize = shape[shape.length - 1];
            // Number of windows the innermost dimension splits into.
            const outSize = Math.ceil(inSize / windowSize);
            this.outputShape = shape.slice(0, -1);
            if (outSize > 1) {
                this.outputShape.push(outSize);
            }
            if (!firstPass) {
                this.variableNames.push('bestIndicesA');
            }
            const outShape = this.outputShape;
            const rank = outShape.length;
            const dtype = getCoordsDataType(rank);
            const coords = getChannels('coords', rank);
            let sourceLocSetup;
            let sourceRank;
            if (outSize === 1) {
                // Output dropped the reduced dim: source coords gain a trailing 0.
                sourceRank = rank + 1;
                const sourceLocDType = getCoordsDataType(sourceRank);
                // R/G/B/A are the four corners of the 2x2 texel block; the ++/--
                // pairs move to the neighboring column/row and back.
                sourceLocSetup = `
        ${sourceLocDType} sourceLocR = ${sourceLocDType}(${coords.join()}, 0);
        ++${coords[rank - 1]};
        ${sourceLocDType} sourceLocG = ${sourceLocDType}(${coords.join()}, 0);
        ++${coords[rank - 2]};
        ${sourceLocDType} sourceLocA = ${sourceLocDType}(${coords.join()}, 0);
        --${coords[rank - 1]};
        ${sourceLocDType} sourceLocB = ${sourceLocDType}(${coords.join()}, 0);
        --${coords[rank - 2]};`;
            }
            else {
                sourceRank = rank;
                sourceLocSetup = `
        ${dtype} sourceLocR = coords;
        ++${coords[rank - 1]};
        ${dtype} sourceLocG = coords;
        ++${coords[rank - 2]};
        ${dtype} sourceLocA = coords;
        --${coords[rank - 1]};
        ${dtype} sourceLocB = coords;
        --${coords[rank - 2]};`;
            }
            const channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, sourceRank);
            const inChannel = '.' + channels[sourceRank - 1]; // e.g. ".b" for rank 3.
            const intChannels = channels.map(x => 'int ' + x);
            // Per-channel source coordinates, with the innermost index taken
            // from the corresponding channel of the running inIdx vector.
            const srcRCoords = getChannels('sourceLocR', sourceRank - 1).concat('inIdx.r');
            const srcGCoords = getChannels('sourceLocG', sourceRank - 1).concat('inIdx.g');
            const srcBCoords = getChannels('sourceLocB', sourceRank - 1).concat('inIdx.b');
            const srcACoords = getChannels('sourceLocA', sourceRank - 1).concat('inIdx.a');
            const compOp = (op === 'max') ? 'greaterThan' : 'lessThan';
            // On later passes, redirect the candidate index through the previous
            // pass's best-index tensor.
            const fetchCandidateIdx = firstPass ? '' : `
        inIdx = round(vec4(getBestIndicesAChannel(${srcRCoords.join()}),
                           getBestIndicesAChannel(${srcGCoords.join()}),
                           getBestIndicesAChannel(${srcBCoords.join()}),
                           getBestIndicesAChannel(${srcACoords.join()})));`;
            // Out-of-bounds lanes (no next row/col) read 0 instead of sampling.
            const fetchValue = `vec4(
          getAChannel(${srcRCoords.join()}),
          hasNextCol ? getAChannel(${srcGCoords.join()}) : 0.,
          hasNextRow ? getAChannel(${srcBCoords.join()}) : 0.,
          hasNextRow && hasNextCol ? getAChannel(${srcACoords.join()}) : 0.)`;
            const getBestIndicesAChannelSnippet = firstPass ? '' : `
      float getBestIndicesAChannel(${intChannels.join()}) {
        return getChannel(getBestIndicesA(${channels.join()}),
                          vec2(${channels.slice(-2).join()}));
      }`;
            this.userCode = `
      float getAChannel(${intChannels.join()}) {
        return getChannel(getA(${channels.join()}),
                          vec2(${channels.slice(-2).join()}));
      }
      ${getBestIndicesAChannelSnippet}
      void main() {
        ${dtype} coords = getOutputCoords();
        bool hasNextCol = ${coords[rank - 1]} < ${outShape[rank - 1] - 1};
        bool hasNextRow = ${coords[rank - 2]} < ${outShape[rank - 2] - 1};
        ${sourceLocSetup}
        ivec4 srcIdx = ivec4(sourceLocR${inChannel}, sourceLocG${inChannel},
          sourceLocB${inChannel}, sourceLocA${inChannel}) * ${windowSize};
        ivec4 inIdx = srcIdx;
        vec4 bestIndex = vec4(inIdx);
        vec4 bestValue = ${fetchValue};

        for (int i = 0; i < ${windowSize}; i++) {
          inIdx = srcIdx;
          ${fetchCandidateIdx}
          vec4 candidate = ${fetchValue};
          bvec4 nan = isnan(candidate);
          bvec4 replace = bvec4(
            vec4(${compOp}(candidate, bestValue)) * (vec4(1.0) - vec4(nan)));

          bestValue = vec4(replace.x ? candidate.x : bestValue.x,
                           replace.y ? candidate.y : bestValue.y,
                           replace.z ? candidate.z : bestValue.z,
                           replace.w ? candidate.w : bestValue.w);
          bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace));
          srcIdx++;
        }
        setOutput(bestIndex);
      }
    `;
        }
    }
86823
86824 /**
86825 * @license
86826 * Copyright 2020 Google LLC. All Rights Reserved.
86827 * Licensed under the Apache License, Version 2.0 (the "License");
86828 * you may not use this file except in compliance with the License.
86829 * You may obtain a copy of the License at
86830 *
86831 * http://www.apache.org/licenses/LICENSE-2.0
86832 *
86833 * Unless required by applicable law or agreed to in writing, software
86834 * distributed under the License is distributed on an "AS IS" BASIS,
86835 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86836 * See the License for the specific language governing permissions and
86837 * limitations under the License.
86838 * =============================================================================
86839 */
86840 function argReduce(backend, x, reduceType, bestIndicesA = null) {
86841 let batchSize = x.shape[0];
86842 let inSize = x.shape[1];
86843 if (bestIndicesA != null) {
86844 batchSize = bestIndicesA.shape[0];
86845 inSize = bestIndicesA.shape[1];
86846 }
86847 const windowSize = computeOptimalWindowSize(inSize);
86848 const reduceInfo = { windowSize, inSize, batchSize, outSize: Math.ceil(inSize / windowSize) };
86849 const program = new ArgMinMaxProgram(reduceInfo, reduceType, bestIndicesA == null);
86850 const inputs = [x];
86851 if (bestIndicesA != null) {
86852 inputs.push(bestIndicesA);
86853 }
86854 const output = backend.runWebGLProgram(program, inputs, 'int32');
86855 // No need to run another GPGPU program.
86856 if (output.shape[1] === 1) {
86857 return output;
86858 }
86859 const result = argReduce(backend, x, reduceType, output);
86860 backend.disposeIntermediateTensorInfo(output);
86861 return result;
86862 }
86863 function argReducePacked(backend, x, reduceType, bestIndicesA = null) {
86864 const inShape = bestIndicesA != null ? bestIndicesA.shape : x.shape;
86865 const inSize = inShape[inShape.length - 1];
86866 const windowSize = computeOptimalWindowSize(inSize);
86867 const program = new ArgMinMaxPackedProgram(inShape, windowSize, reduceType, bestIndicesA == null);
86868 const inputs = bestIndicesA == null ? [x] : [x, bestIndicesA];
86869 const output = backend.runWebGLProgram(program, inputs, 'int32');
86870 if (output.shape.length === x.shape.length) {
86871 const result = argReducePacked(backend, x, reduceType, output);
86872 backend.disposeIntermediateTensorInfo(output);
86873 return result;
86874 }
86875 return output;
86876 }
86877 function argMinMaxReduce(backend, x, axis, reduceType) {
86878 const axes = [axis];
86879 assertAxesAreInnerMostDims('arg' + reduceType.charAt(0).toUpperCase() + reduceType.slice(1), axes, x.shape.length);
86880 if (!env().getBool('WEBGL_PACK_REDUCE') || x.shape.length <= 2) {
86881 const intermediateTensorInfos = [];
86882 // Eagerly unpack x input since it is passed in to all the shaders which
86883 // require unpacked inputs.
86884 const xtexData = backend.texData.get(x.dataId);
86885 const xIsPacked = xtexData !== null && xtexData.isPacked;
86886 let xUnPacked = x;
86887 if (xIsPacked) {
86888 xUnPacked = backend.unpackTensor(x);
86889 intermediateTensorInfos.push(xUnPacked);
86890 }
86891 const [outShape, reduceShape] = computeOutAndReduceShapes(xUnPacked.shape, axes);
86892 const inSize = sizeFromShape(reduceShape);
86893 const a2D = reshape$3({ inputs: { x: xUnPacked }, backend, attrs: { shape: [-1, inSize] } });
86894 intermediateTensorInfos.push(a2D);
86895 const reduced = argReduce(backend, a2D, reduceType);
86896 intermediateTensorInfos.push(reduced);
86897 const reshaped = reshape$3({ inputs: { x: reduced }, backend, attrs: { shape: outShape } });
86898 intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
86899 return reshaped;
86900 }
86901 return argReducePacked(backend, x, reduceType);
86902 }
86903
86904 /**
86905 * @license
86906 * Copyright 2020 Google LLC. All Rights Reserved.
86907 * Licensed under the Apache License, Version 2.0 (the "License");
86908 * you may not use this file except in compliance with the License.
86909 * You may obtain a copy of the License at
86910 *
86911 * http://www.apache.org/licenses/LICENSE-2.0
86912 *
86913 * Unless required by applicable law or agreed to in writing, software
86914 * distributed under the License is distributed on an "AS IS" BASIS,
86915 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86916 * See the License for the specific language governing permissions and
86917 * limitations under the License.
86918 * =============================================================================
86919 */
86920 function argMax$2(args) {
86921 const { inputs, backend, attrs } = args;
86922 const { x } = inputs;
86923 const { axis } = attrs;
86924 let axes = parseAxisParam(axis, x.shape);
86925 const permutedAxes = getAxesPermutation(axes, x.shape.length);
86926 let $x = x;
86927 const intermediateTensorInfos = [];
86928 if (permutedAxes != null) {
86929 $x = transpose$2({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
86930 intermediateTensorInfos.push($x);
86931 axes = getInnerMostAxes(axes.length, $x.shape.length);
86932 }
86933 assertAxesAreInnerMostDims('argMax', [axes[0]], $x.shape.length);
86934 const out = argMinMaxReduce(backend, $x, axes[0], 'max');
86935 intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
86936 return out;
86937 }
86938 const argMaxConfig$1 = {
86939 kernelName: ArgMax,
86940 backendName: 'webgl',
86941 kernelFunc: argMax$2
86942 };
86943
86944 /**
86945 * @license
86946 * Copyright 2020 Google LLC. All Rights Reserved.
86947 * Licensed under the Apache License, Version 2.0 (the "License");
86948 * you may not use this file except in compliance with the License.
86949 * You may obtain a copy of the License at
86950 *
86951 * http://www.apache.org/licenses/LICENSE-2.0
86952 *
86953 * Unless required by applicable law or agreed to in writing, software
86954 * distributed under the License is distributed on an "AS IS" BASIS,
86955 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86956 * See the License for the specific language governing permissions and
86957 * limitations under the License.
86958 * =============================================================================
86959 */
/**
 * WebGL kernel for ArgMin: index of the smallest value along `attrs.axis`.
 *
 * Mirrors the ArgMax kernel: a transpose makes the reduction axis innermost
 * when needed, and any intermediate tensor is disposed before returning.
 */
function argMin$2(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis } = attrs;
    let axes = parseAxisParam(axis, x.shape);
    const permutation = getAxesPermutation(axes, x.shape.length);
    let reduceInput = x;
    let transposed = null;
    if (permutation != null) {
        // Bring the reduction axis innermost before reducing.
        transposed = transpose$2({ inputs: { x }, backend, attrs: { perm: permutation } });
        reduceInput = transposed;
        axes = getInnerMostAxes(axes.length, reduceInput.shape.length);
    }
    assertAxesAreInnerMostDims('argMin', [axes[0]], reduceInput.shape.length);
    const out = argMinMaxReduce(backend, reduceInput, axes[0], 'min');
    if (transposed != null) {
        backend.disposeIntermediateTensorInfo(transposed);
    }
    return out;
}
/** Kernel registration for ArgMin on the WebGL backend. */
const argMinConfig$1 = {
    kernelName: ArgMin,
    backendName: 'webgl',
    kernelFunc: argMin$2
};
86983
86984 /**
86985 * @license
86986 * Copyright 2020 Google LLC. All Rights Reserved.
86987 * Licensed under the Apache License, Version 2.0 (the "License");
86988 * you may not use this file except in compliance with the License.
86989 * You may obtain a copy of the License at
86990 *
86991 * http://www.apache.org/licenses/LICENSE-2.0
86992 *
86993 * Unless required by applicable law or agreed to in writing, software
86994 * distributed under the License is distributed on an "AS IS" BASIS,
86995 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86996 * See the License for the specific language governing permissions and
86997 * limitations under the License.
86998 * =============================================================================
86999 */
// GLSL snippet for element-wise asin. CHECK_NAN_SNIPPET propagates NaN
// inputs; inputs outside asin's domain [-1, 1] explicitly return NAN
// (native GLSL asin is undefined there).
const ASIN = CHECK_NAN_SNIPPET + `
  if (abs(x) > 1.) {
    return NAN;
  }
  return asin(x);
`;
const asin$2 = unaryKernelFunc$1({ opSnippet: ASIN });
// Kernel registration for Asin on the WebGL backend.
const asinConfig$1 = {
    kernelName: Asin,
    backendName: 'webgl',
    kernelFunc: asin$2,
};
87012
87013 /**
87014 * @license
87015 * Copyright 2020 Google LLC. All Rights Reserved.
87016 * Licensed under the Apache License, Version 2.0 (the "License");
87017 * you may not use this file except in compliance with the License.
87018 * You may obtain a copy of the License at
87019 *
87020 * http://www.apache.org/licenses/LICENSE-2.0
87021 *
87022 * Unless required by applicable law or agreed to in writing, software
87023 * distributed under the License is distributed on an "AS IS" BASIS,
87024 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87025 * See the License for the specific language governing permissions and
87026 * limitations under the License.
87027 * =============================================================================
87028 */
// GLSL snippet for element-wise asinh via the identity
// asinh(x) = log(x + sqrt(x^2 + 1)); CHECK_NAN_SNIPPET propagates NaNs.
const ASINH = CHECK_NAN_SNIPPET + `return log(x + sqrt(x * x + 1.0));`;
const asinh$2 = unaryKernelFunc$1({ opSnippet: ASINH });
// Kernel registration for Asinh on the WebGL backend.
const asinhConfig$1 = {
    kernelName: Asinh,
    backendName: 'webgl',
    kernelFunc: asinh$2,
};
87036
87037 /**
87038 * @license
87039 * Copyright 2020 Google LLC. All Rights Reserved.
87040 * Licensed under the Apache License, Version 2.0 (the "License");
87041 * you may not use this file except in compliance with the License.
87042 * You may obtain a copy of the License at
87043 *
87044 * http://www.apache.org/licenses/LICENSE-2.0
87045 *
87046 * Unless required by applicable law or agreed to in writing, software
87047 * distributed under the License is distributed on an "AS IS" BASIS,
87048 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87049 * See the License for the specific language governing permissions and
87050 * limitations under the License.
87051 * =============================================================================
87052 */
// GLSL snippet for element-wise atan; delegates to the native GLSL atan
// after the shared NaN check.
const ATAN = CHECK_NAN_SNIPPET + `
  return atan(x);
`;
const atan$2 = unaryKernelFunc$1({ opSnippet: ATAN });
// Kernel registration for Atan on the WebGL backend.
const atanConfig$1 = {
    kernelName: Atan,
    backendName: 'webgl',
    kernelFunc: atan$2,
};
87062
87063 /**
87064 * @license
87065 * Copyright 2020 Google LLC. All Rights Reserved.
87066 * Licensed under the Apache License, Version 2.0 (the "License");
87067 * you may not use this file except in compliance with the License.
87068 * You may obtain a copy of the License at
87069 *
87070 * http://www.apache.org/licenses/LICENSE-2.0
87071 *
87072 * Unless required by applicable law or agreed to in writing, software
87073 * distributed under the License is distributed on an "AS IS" BASIS,
87074 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87075 * See the License for the specific language governing permissions and
87076 * limitations under the License.
87077 * =============================================================================
87078 */
// GLSL snippet for binary atan2(a, b); GLSL's two-argument atan is atan2.
const ATAN2 = CHECK_NAN_SNIPPET_BINARY + `
  return atan(a, b);
`;
// Packed (vec4, 4-values-per-texel) variant: computes all four lanes at
// once, then masks lanes where either operand was NaN via the shared
// packed NaN snippet before returning.
const ATAN2_PACKED = `
  vec4 result = atan(a, b);
  vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));
  ` +
    CHECK_NAN_SNIPPET_BINARY_PACKED + `
  return result;
`;
const atan2$2 = binaryKernelFunc$1({ opSnippet: ATAN2, packedOpSnippet: ATAN2_PACKED });
// Kernel registration for Atan2 on the WebGL backend.
const atan2Config$1 = {
    kernelName: Atan2,
    backendName: 'webgl',
    kernelFunc: atan2$2,
};
87095
87096 /**
87097 * @license
87098 * Copyright 2020 Google LLC. All Rights Reserved.
87099 * Licensed under the Apache License, Version 2.0 (the "License");
87100 * you may not use this file except in compliance with the License.
87101 * You may obtain a copy of the License at
87102 *
87103 * http://www.apache.org/licenses/LICENSE-2.0
87104 *
87105 * Unless required by applicable law or agreed to in writing, software
87106 * distributed under the License is distributed on an "AS IS" BASIS,
87107 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87108 * See the License for the specific language governing permissions and
87109 * limitations under the License.
87110 * =============================================================================
87111 */
// GLSL snippet for element-wise atanh via the identity
// atanh(x) = (log(1+x) - log(1-x)) / 2; inputs outside (-1, 1) return NAN.
const ATANH = CHECK_NAN_SNIPPET + `
  if ((x < -1.0) || (x > 1.0)) return NAN;
  return (log(1.0 + x) - log(1.0 - x)) / 2.0;`;
const atanh$2 = unaryKernelFunc$1({ opSnippet: ATANH });
// Kernel registration for Atanh on the WebGL backend.
const atanhConfig$1 = {
    kernelName: Atanh,
    backendName: 'webgl',
    kernelFunc: atanh$2,
};
87121
87122 /**
87123 * @license
87124 * Copyright 2017 Google LLC. All Rights Reserved.
87125 * Licensed under the Apache License, Version 2.0 (the "License");
87126 * you may not use this file except in compliance with the License.
87127 * You may obtain a copy of the License at
87128 *
87129 * http://www.apache.org/licenses/LICENSE-2.0
87130 *
87131 * Unless required by applicable law or agreed to in writing, software
87132 * distributed under the License is distributed on an "AS IS" BASIS,
87133 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87134 * See the License for the specific language governing permissions and
87135 * limitations under the License.
87136 * =============================================================================
87137 */
/**
 * GPGPU program that generates a fragment shader for 2D max/avg pooling
 * over an NHWC tensor (output shape = convInfo.outShape).
 *
 * When `computePositions` is true the shader emits, for each output cell,
 * the position of the max value instead of the pooled value — either an
 * offset within the filter window, or (when `flattenPositions` is set) a
 * flattened index into the input, optionally including the batch dimension
 * (`includeBatchInIndex`). Positions are unsupported for 'avg'.
 */
class Pool2DProgram {
    constructor(convInfo, poolType, computePositions, flattenPositions = false, includeBatchInIndex = false) {
        this.variableNames = ['x'];
        if (poolType === 'avg' && computePositions) {
            throw new Error('Cannot compute positions for average pool.');
        }
        const filterWidth = convInfo.filterWidth;
        const strideHeight = convInfo.strideHeight;
        const strideWidth = convInfo.strideWidth;
        const dilationHeight = convInfo.dilationHeight;
        const dilationWidth = convInfo.dilationWidth;
        const effectiveFilterHeight = convInfo.effectiveFilterHeight;
        const effectiveFilterWidth = convInfo.effectiveFilterWidth;
        const padTop = convInfo.padInfo.top;
        const padLeft = convInfo.padInfo.left;
        this.outputShape = convInfo.outShape;
        const isAvgPool = poolType === 'avg';
        // Flattened index of input element (xR, xC, d), with and without the
        // batch dimension folded in — used for the position output modes.
        const batchFlattenPositionStr = `((batch * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + d`;
        const flattenPositionStr = `(xR * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + d`;
        let initializationValue = '0.0';
        if (!isAvgPool) {
            // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
            initializationValue = '-1.0 / 1e-20';
        }
        if (computePositions) {
            // Position mode: scalar scan over the window; >= means the last
            // occurrence of the max wins ties.
            const compareOp = '>=';
            this.userCode = `
        const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
        const ivec2 pads = ivec2(${padTop}, ${padLeft});

        void main() {
          ivec4 coords = getOutputCoords();
          int batch = coords[0];
          int d = coords[3];

          ivec2 xRCCorner = coords.yz * strides - pads;
          int xRCorner = xRCCorner.x;
          int xCCorner = xRCCorner.y;

          // max/min x(?, ?, d) to get y(yR, yC, d).
          // ? = to be determined
          float minMaxValue = 0.0;
          float minMaxValueFound = 0.0;
          int minMaxPosition = 0;
          float avgValue = 0.0;

          for (int wR = 0; wR < ${effectiveFilterHeight};
              wR += ${dilationHeight}) {
            int xR = xRCorner + wR;

            if (xR < 0 || xR >= ${convInfo.inHeight}) {
              continue;
            }

            for (int wC = 0; wC < ${effectiveFilterWidth};
                wC += ${dilationWidth}) {
              int xC = xCCorner + wC;

              if (xC < 0 || xC >= ${convInfo.inWidth}) {
                continue;
              }

              float value = getX(batch, xR, xC, d);

              // If a min / max value has already been found, use it. If not,
              // use the current value.
              float currMinMaxValue = mix(
                  value, minMaxValue, minMaxValueFound);
              if (value ${compareOp} currMinMaxValue) {
                minMaxValue = value;
                minMaxValueFound = 1.0;
                minMaxPosition = ${flattenPositions ? (includeBatchInIndex ? batchFlattenPositionStr :
                flattenPositionStr) :
                `wR * ${effectiveFilterWidth} + wC`};
              }
            }
          }
          setOutput(float(minMaxPosition));
        }
      `;
            return;
        }
        // Value mode: vectorized scan 4 columns at a time. poolType here is
        // 'max' or 'avg'; compareOp is always 'max' since 'avg' takes the
        // other branch of updateSnippet.
        const compareOp = 'max';
        let returnValue = `${poolType}(${poolType}(${poolType}(` +
            'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
        if (poolType === 'avg') {
            returnValue = `avgValue / count`;
        }
        const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
        const filterWidthVec4Remainder = filterWidth % 4;
        // Note: ${isAvgPool} interpolates the JS boolean as a GLSL
        // compile-time constant `true`/`false`, so the dead branch folds away.
        const updateSnippet = `
        if (${isAvgPool}) {
          avgValue += dot(values, ones);
        } else {
          minMaxValue = ${compareOp}(values, minMaxValue);
        }
      `;
        this.userCode = `
        const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
        const ivec2 pads = ivec2(${padTop}, ${padLeft});
        const float initializationValue = ${initializationValue};
        const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);

        float count = 0.0;

        float getValue(int batch, int xR, int xC, int d) {
          if (xC < 0 || xC >= ${convInfo.inWidth}) {
            return initializationValue;
          }
          count += 1.0;
          return getX(batch, xR, xC, d);
        }

        void main() {
          ivec4 coords = getOutputCoords();
          int batch = coords[0];
          int d = coords[3];

          ivec2 xRCCorner = coords.yz * strides - pads;
          int xRCorner = xRCCorner.x;
          int xCCorner = xRCCorner.y;

          // max/min x(?, ?, d) to get y(yR, yC, d).
          // ? = to be determined
          vec4 minMaxValue = vec4(${initializationValue});
          float avgValue = 0.0;
          count = 0.0;

          for (int wR = 0; wR < ${effectiveFilterHeight};
              wR += ${dilationHeight}) {
            int xR = xRCorner + wR;

            if (xR < 0 || xR >= ${convInfo.inHeight}) {
              continue;
            }

            for (int wC = 0; wC < ${filterWidthNearestVec4}; wC += 4) {
              int xC = xCCorner + wC * ${dilationWidth};

              vec4 values = vec4(
                getValue(batch, xR, xC, d),
                getValue(batch, xR, xC + ${dilationWidth}, d),
                getValue(batch, xR, xC + 2 * ${dilationWidth}, d),
                getValue(batch, xR, xC + 3 * ${dilationWidth}, d)
              );

              ${updateSnippet}
            }

            int xC = xCCorner + ${filterWidthNearestVec4};
            if (${filterWidthVec4Remainder === 1}) {
              vec4 values = vec4(
                getValue(batch, xR, xC, d),
                initializationValue,
                initializationValue,
                initializationValue
              );

              ${updateSnippet}
            } else if (${filterWidthVec4Remainder === 2}) {
              vec4 values = vec4(
                getValue(batch, xR, xC, d),
                getValue(batch, xR, xC + ${dilationWidth}, d),
                initializationValue,
                initializationValue
              );

              ${updateSnippet}
            } else if (${filterWidthVec4Remainder === 3}) {
              vec4 values = vec4(
                getValue(batch, xR, xC, d),
                getValue(batch, xR, xC + ${dilationWidth}, d),
                getValue(batch, xR, xC + 2 * ${dilationWidth}, d),
                initializationValue
              );

              ${updateSnippet}
            }
          }
          setOutput(${returnValue});
        }
      `;
    }
}
/**
 * GPGPU program generating a fragment shader for 3D max/avg pooling over a
 * rank-5 (batch, depth, height, width, channel) tensor.
 *
 * Structure parallels Pool2DProgram with an extra depth loop. With
 * `computePositions` the shader outputs the max value's position — a
 * window-relative offset, or a flattened input index when
 * `flattenPositions`/`includeBatchInIndex` are set. Positions are
 * unsupported for 'avg'.
 *
 * NOTE(review): `ivec5` / `coords.u` rely on tfjs's shader preamble that
 * defines a 5-component integer coordinate type — confirm against the
 * shader compiler helpers elsewhere in this bundle.
 */
class Pool3DProgram {
    constructor(convInfo, poolType, computePositions, flattenPositions = false, includeBatchInIndex = false) {
        this.variableNames = ['x'];
        if (poolType === 'avg' && computePositions) {
            throw new Error('Cannot compute positions for average pool.');
        }
        const filterWidth = convInfo.filterWidth;
        const strideDepth = convInfo.strideDepth;
        const strideHeight = convInfo.strideHeight;
        const strideWidth = convInfo.strideWidth;
        const dilationDepth = convInfo.dilationDepth;
        const dilationHeight = convInfo.dilationHeight;
        const dilationWidth = convInfo.dilationWidth;
        const effectiveFilterDepth = convInfo.effectiveFilterDepth;
        const effectiveFilterHeight = convInfo.effectiveFilterHeight;
        const effectiveFilterWidth = convInfo.effectiveFilterWidth;
        const padFront = convInfo.padInfo.front;
        const padTop = convInfo.padInfo.top;
        const padLeft = convInfo.padInfo.left;
        this.outputShape = convInfo.outShape;
        const isAvgPool = poolType === 'avg';
        let initializationValue = '0.0';
        if (!isAvgPool) {
            // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
            initializationValue = '-1.0 / 1e-20';
        }
        if (computePositions) {
            // Position mode: scalar scan; >= means the last max wins ties.
            const compareOp = '>=';
            this.userCode = `
        const ivec3 strides =
            ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});
        const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});

        void main() {
          ivec5 coords = getOutputCoords();
          int batch = coords.x;
          int ch = coords.u;

          ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;
          int xDCorner = xCorner.x;
          int xRCorner = xCorner.y;
          int xCCorner = xCorner.z;

          // max/min x(?, ?, ?, ch) to get y(yD, yR, yC, ch).
          // ? = to be determined
          float minMaxValue = 0.0;
          float minMaxValueFound = 0.0;
          int minMaxPosition = 0;

          for (int wD = 0; wD < ${effectiveFilterDepth};
              wD += ${dilationDepth}) {
            int xD = xDCorner + wD;

            if (xD < 0 || xD >= ${convInfo.inDepth}) {
              continue;
            }

            for (int wR = 0; wR < ${effectiveFilterHeight};
                wR += ${dilationHeight}) {
              int xR = xRCorner + wR;

              if (xR < 0 || xR >= ${convInfo.inHeight}) {
                continue;
              }

              for (int wC = 0; wC < ${effectiveFilterWidth};
                  wC += ${dilationWidth}) {
                int xC = xCCorner + wC;

                if (xC < 0 || xC >= ${convInfo.inWidth}) {
                  continue;
                }

                float value = getX(batch, xD, xR, xC, ch);

                // If a min / max value has already been found, use it. If not,
                // use the current value.
                float currMinMaxValue = mix(
                    value, minMaxValue, minMaxValueFound);
                if (value ${compareOp} currMinMaxValue) {
                  minMaxValue = value;
                  minMaxValueFound = 1.0;
                  minMaxPosition = ${flattenPositions ?
                (includeBatchInIndex ?
                    `(((batch * ${convInfo.inDepth} + xD) * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + ch` :
                    `((xD * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + ch`) :
                `wD * ${effectiveFilterHeight} * ${effectiveFilterWidth} +
                      wR * ${effectiveFilterWidth} + wC`};
                }
              }
            }
          }
          setOutput(float(minMaxPosition));
        }
      `;
            return;
        }
        // Value mode: vectorized along the width axis, 4 columns per step.
        // compareOp is always 'max'; the 'avg' case takes the other branch
        // of updateSnippet.
        const compareOp = 'max';
        let returnValue = `${poolType}(${poolType}(${poolType}(` +
            'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
        if (poolType === 'avg') {
            returnValue = `avgValue / count`;
        }
        const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
        const filterWidthVec4Remainder = filterWidth % 4;
        // ${isAvgPool} becomes a GLSL compile-time `true`/`false` constant.
        const updateSnippet = `
        if (${isAvgPool}) {
          avgValue += dot(values, ones);
        } else {
          minMaxValue = ${compareOp}(values, minMaxValue);
        }
      `;
        this.userCode = `
        const ivec3 strides =
          ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});
        const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
        const float initializationValue = ${initializationValue};
        const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);

        float count = 0.0;

        float getValue(int batch, int xD, int xR, int xC, int ch) {
          if (xC < 0 || xC >= ${convInfo.inWidth}) {
            return initializationValue;
          }
          count += 1.0;
          return getX(batch, xD, xR, xC, ch);
        }

        void main() {
          ivec5 coords = getOutputCoords();
          int batch = coords.x;
          int ch = coords.u;

          ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;
          int xDCorner = xCorner.x;
          int xRCorner = xCorner.y;
          int xCCorner = xCorner.z;

          // max/min x(?, ?, ?, d) to get y(yD, yR, yC, ch).
          // ? = to be determined
          vec4 minMaxValue = vec4(${initializationValue});
          float avgValue = 0.0;
          count = 0.0;

          for (int wD = 0; wD < ${effectiveFilterDepth};
              wD += ${dilationDepth}) {
            int xD = xDCorner + wD;

            if (xD < 0 || xD >= ${convInfo.inDepth}) {
              continue;
            }

            for (int wR = 0; wR < ${effectiveFilterHeight};
                wR += ${dilationHeight}) {
              int xR = xRCorner + wR;

              if (xR < 0 || xR >= ${convInfo.inHeight}) {
                continue;
              }

              for (int wC = 0; wC < ${filterWidthNearestVec4}; wC += 4) {
                int xC = xCCorner + wC * ${dilationWidth};

                vec4 values = vec4(
                  getValue(batch, xD, xR, xC, ch),
                  getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
                  getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch),
                  getValue(batch, xD, xR, xC + 3 * ${dilationWidth}, ch)
                );

                ${updateSnippet}
              }

              int xC = xCCorner + ${filterWidthNearestVec4};
              if (${filterWidthVec4Remainder === 1}) {
                vec4 values = vec4(
                  getValue(batch, xD, xR, xC, ch),
                  initializationValue,
                  initializationValue,
                  initializationValue
                );

                ${updateSnippet}
              } else if (${filterWidthVec4Remainder === 2}) {
                vec4 values = vec4(
                  getValue(batch, xD, xR, xC, ch),
                  getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
                  initializationValue,
                  initializationValue
                );

                ${updateSnippet}
              } else if (${filterWidthVec4Remainder === 3}) {
                vec4 values = vec4(
                  getValue(batch, xD, xR, xC, ch),
                  getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
                  getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch),
                  initializationValue
                );

                ${updateSnippet}
              }
            }
          }
          setOutput(${returnValue});
        }
      `;
    }
}
87532
87533 /**
87534 * @license
87535 * Copyright 2020 Google LLC. All Rights Reserved.
87536 * Licensed under the Apache License, Version 2.0 (the "License");
87537 * you may not use this file except in compliance with the License.
87538 * You may obtain a copy of the License at
87539 *
87540 * http://www.apache.org/licenses/LICENSE-2.0
87541 *
87542 * Unless required by applicable law or agreed to in writing, software
87543 * distributed under the License is distributed on an "AS IS" BASIS,
87544 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87545 * See the License for the specific language governing permissions and
87546 * limitations under the License.
87547 * =============================================================================
87548 */
/**
 * WebGL kernel for AvgPool (2D average pooling over NHWC input).
 *
 * Short-circuits to Identity when the 1x1 filter with matching in/out
 * shapes makes pooling a no-op; otherwise runs a Pool2DProgram shader.
 * Output dtype is always float32.
 */
function avgPool$2(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    assertNotComplex$1(x, 'avgPool');
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const dilations = 1;
    assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in avgPool: Either strides or dilations must be 1. ' +
        `Got strides ${strides} and dilations '${dilations}'`);
    const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
    // A 1x1 window over an unchanged shape is the identity transform.
    const isNoOp = convInfo.filterWidth === 1 &&
        convInfo.filterHeight === 1 &&
        arraysEqual(convInfo.inShape, convInfo.outShape);
    if (isNoOp) {
        return identity$2({ inputs: { x }, backend });
    }
    const program = new Pool2DProgram(convInfo, 'avg', false);
    return backend.runWebGLProgram(program, [x], 'float32');
}
/** Kernel registration for AvgPool on the WebGL backend. */
const avgPoolConfig$1 = {
    kernelName: AvgPool,
    backendName: 'webgl',
    kernelFunc: avgPool$2
};
87570
87571 /**
87572 * @license
87573 * Copyright 2020 Google LLC. All Rights Reserved.
87574 * Licensed under the Apache License, Version 2.0 (the "License");
87575 * you may not use this file except in compliance with the License.
87576 * You may obtain a copy of the License at
87577 *
87578 * http://www.apache.org/licenses/LICENSE-2.0
87579 *
87580 * Unless required by applicable law or agreed to in writing, software
87581 * distributed under the License is distributed on an "AS IS" BASIS,
87582 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87583 * See the License for the specific language governing permissions and
87584 * limitations under the License.
87585 * =============================================================================
87586 */
/**
 * WebGL kernel for AvgPool3D (3D average pooling).
 *
 * Dilations are fixed at [1, 1, 1]; the pooled output is always float32.
 */
function avgPool3D$1({ inputs, backend, attrs }) {
    const { x } = inputs;
    const { filterSize, strides, pad, dimRoundingMode, dataFormat } = attrs;
    const dilations = [1, 1, 1];
    const convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat);
    const program = new Pool3DProgram(convInfo, 'avg', false);
    return backend.runWebGLProgram(program, [x], 'float32');
}
/** Kernel registration for AvgPool3D on the WebGL backend. */
const avgPool3DConfig$1 = {
    kernelName: AvgPool3D,
    backendName: 'webgl',
    kernelFunc: avgPool3D$1
};
87601
87602 /**
87603 * @license
87604 * Copyright 2017 Google LLC. All Rights Reserved.
87605 * Licensed under the Apache License, Version 2.0 (the "License");
87606 * you may not use this file except in compliance with the License.
87607 * You may obtain a copy of the License at
87608 *
87609 * http://www.apache.org/licenses/LICENSE-2.0
87610 *
87611 * Unless required by applicable law or agreed to in writing, software
87612 * distributed under the License is distributed on an "AS IS" BASIS,
87613 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87614 * See the License for the specific language governing permissions and
87615 * limitations under the License.
87616 * =============================================================================
87617 */
/**
 * GPGPU program for the gradient of 2D average pooling: convolves the
 * upstream gradient `dy` back over the input positions (output shape =
 * convInfo.inShape), distributing each dy value uniformly across the
 * filter window via `avgMultiplier` = 1 / (filterHeight * filterWidth).
 */
class AvgPool2DBackpropProgram {
    constructor(convInfo) {
        this.variableNames = ['dy'];
        this.outputShape = convInfo.inShape;
        const filterHeight = convInfo.filterHeight;
        const filterWidth = convInfo.filterWidth;
        const strideHeight = convInfo.strideHeight;
        const strideWidth = convInfo.strideWidth;
        const dilationHeight = convInfo.dilationHeight;
        const dilationWidth = convInfo.dilationWidth;
        const effectiveFilterHeight = convInfo.effectiveFilterHeight;
        const effectiveFilterWidth = convInfo.effectiveFilterWidth;
        // Backprop pads are mirrored relative to the forward pass.
        const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
        const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
        const avgMultiplier = 1 / (filterHeight * filterWidth);
        this.userCode = `
      const ivec2 pads = ivec2(${padTop}, ${padLeft});
      const float avgMultiplier = float(${avgMultiplier});

      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int d = coords[3];

        ivec2 dyRCCorner = coords.yz - pads;
        int dyRCorner = dyRCCorner.x;
        int dyCCorner = dyRCCorner.y;

        // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).
        // ? = to be determined. : = across all values in that axis.
        float dotProd = 0.0;
        for (int wR = 0; wR < ${effectiveFilterHeight};
            wR += ${dilationHeight}) {
          float dyR = float(dyRCorner + wR) / ${strideHeight}.0;

          if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
            continue;
          }
          int idyR = int(dyR);

          for (int wC = 0; wC < ${effectiveFilterWidth};
              wC+= ${dilationWidth}) {
            float dyC = float(dyCCorner + wC) / ${strideWidth}.0;

            if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
                fract(dyC) > 0.0) {
              continue;
            }
            int idyC = int(dyC);

            float dyValue = getDy(b, idyR, idyC, d);

            dotProd += dyValue * avgMultiplier;
          }
        }
        setOutput(dotProd);
      }
    `;
    }
}
/**
 * GPGPU program for the gradient of 3D average pooling: the rank-5
 * analogue of AvgPool2DBackpropProgram, with an extra depth loop and
 * `avgMultiplier` = 1 / (filterDepth * filterHeight * filterWidth).
 * Output shape = convInfo.inShape.
 */
class AvgPool3DBackpropProgram {
    constructor(convInfo) {
        this.variableNames = ['dy'];
        this.outputShape = convInfo.inShape;
        const filterDepth = convInfo.filterDepth;
        const filterHeight = convInfo.filterHeight;
        const filterWidth = convInfo.filterWidth;
        const strideDepth = convInfo.strideDepth;
        const strideHeight = convInfo.strideHeight;
        const strideWidth = convInfo.strideWidth;
        const dilationDepth = convInfo.dilationDepth;
        const dilationHeight = convInfo.dilationHeight;
        const dilationWidth = convInfo.dilationWidth;
        const effectiveFilterDepth = convInfo.effectiveFilterDepth;
        const effectiveFilterHeight = convInfo.effectiveFilterHeight;
        const effectiveFilterWidth = convInfo.effectiveFilterWidth;
        // Backprop pads are mirrored relative to the forward pass.
        const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
        const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
        const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
        const avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth);
        this.userCode = `
      const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
      const float avgMultiplier = float(${avgMultiplier});

      void main() {
        ivec5 coords = getOutputCoords();
        int batch = coords.x;
        int ch = coords.u;

        ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;
        int dyDCorner = dyCorner.x;
        int dyRCorner = dyCorner.y;
        int dyCCorner = dyCorner.z;

        // Convolve dy(?, ?, ?, d) with pos mask(:, :, :, ch) to get
        // dx(xD, xR, xC, ch).
        // ? = to be determined. : = across all values in that axis.
        float dotProd = 0.0;

        for (int wD = 0; wD < ${effectiveFilterDepth};
            wD += ${dilationDepth}) {
          float dyD = float(dyDCorner + wD) / ${strideDepth}.0;

          if (dyD < 0.0 || dyD >= ${convInfo.outDepth}.0 || fract(dyD) > 0.0) {
            continue;
          }
          int idyD = int(dyD);

          for (int wR = 0; wR < ${effectiveFilterHeight};
              wR += ${dilationHeight}) {
            float dyR = float(dyRCorner + wR) / ${strideHeight}.0;

            if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 ||
                fract(dyR) > 0.0) {
              continue;
            }
            int idyR = int(dyR);

            for (int wC = 0; wC < ${effectiveFilterWidth};
                wC += ${dilationWidth}) {
              float dyC = float(dyCCorner + wC) / ${strideWidth}.0;

              if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
                  fract(dyC) > 0.0) {
                continue;
              }
              int idyC = int(dyC);

              float dyValue = getDy(batch, idyD, idyR, idyC, ch);

              dotProd += dyValue * avgMultiplier;
            }
          }
        }
        setOutput(dotProd);
      }
    `;
    }
}
87757
87758 /**
87759 * @license
87760 * Copyright 2020 Google LLC. All Rights Reserved.
87761 * Licensed under the Apache License, Version 2.0 (the "License");
87762 * you may not use this file except in compliance with the License.
87763 * You may obtain a copy of the License at
87764 *
87765 * http://www.apache.org/licenses/LICENSE-2.0
87766 *
87767 * Unless required by applicable law or agreed to in writing, software
87768 * distributed under the License is distributed on an "AS IS" BASIS,
87769 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87770 * See the License for the specific language governing permissions and
87771 * limitations under the License.
87772 * =============================================================================
87773 */
/**
 * WebGL kernel for AvgPool3DGrad: gradient of AvgPool3D with respect to
 * its input. Runs AvgPool3DBackpropProgram over `dy`; the result has the
 * original input's shape and dtype.
 */
function avgPool3DGrad$1({ inputs, backend, attrs }) {
    const { dy, input } = inputs;
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const dilations = [1, 1, 1];
    const convInfo = computePool3DInfo(input.shape, filterSize, strides, dilations, pad, dimRoundingMode);
    const program = new AvgPool3DBackpropProgram(convInfo);
    return backend.runWebGLProgram(program, [dy], input.dtype);
}
/** Kernel registration for AvgPool3DGrad on the WebGL backend. */
const avgPool3DGradConfig$2 = {
    kernelName: AvgPool3DGrad,
    backendName: 'webgl',
    kernelFunc: avgPool3DGrad$1
};
87789
87790 /**
87791 * @license
87792 * Copyright 2020 Google LLC. All Rights Reserved.
87793 * Licensed under the Apache License, Version 2.0 (the "License");
87794 * you may not use this file except in compliance with the License.
87795 * You may obtain a copy of the License at
87796 *
87797 * http://www.apache.org/licenses/LICENSE-2.0
87798 *
87799 * Unless required by applicable law or agreed to in writing, software
87800 * distributed under the License is distributed on an "AS IS" BASIS,
87801 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87802 * See the License for the specific language governing permissions and
87803 * limitations under the License.
87804 * =============================================================================
87805 */
/**
 * WebGL kernel for AvgPoolGrad: gradient of AvgPool with respect to its
 * input. Runs AvgPool2DBackpropProgram over `dy`; the result has the
 * original input's shape and dtype.
 */
function avgPoolGrad$2({ inputs, backend, attrs }) {
    const { dy, input } = inputs;
    assertNotComplex$1([dy, input], 'avgPoolGrad');
    const { filterSize, strides, pad } = attrs;
    const convInfo = computePool2DInfo(input.shape, filterSize, strides, 1 /* dilations */, pad);
    const program = new AvgPool2DBackpropProgram(convInfo);
    return backend.runWebGLProgram(program, [dy], input.dtype);
}
/** Kernel registration for AvgPoolGrad on the WebGL backend. */
const avgPoolGradConfig$2 = {
    kernelName: AvgPoolGrad,
    backendName: 'webgl',
    kernelFunc: avgPoolGrad$2
};
87821
87822 /**
87823 * @license
87824 * Copyright 2020 Google LLC. All Rights Reserved.
87825 * Licensed under the Apache License, Version 2.0 (the "License");
87826 * you may not use this file except in compliance with the License.
87827 * You may obtain a copy of the License at
87828 *
87829 * http://www.apache.org/licenses/LICENSE-2.0
87830 *
87831 * Unless required by applicable law or agreed to in writing, software
87832 * distributed under the License is distributed on an "AS IS" BASIS,
87833 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87834 * See the License for the specific language governing permissions and
87835 * limitations under the License.
87836 * =============================================================================
87837 */
87838 function batchMatMul$1(args) {
87839 const { inputs, backend, attrs } = args;
87840 const { a, b } = inputs;
87841 const { transposeA, transposeB } = attrs;
87842 return batchMatMulImpl({ a, b, transposeA, transposeB, backend });
87843 }
87844 const batchMatMulConfig$1 = {
87845 kernelName: BatchMatMul,
87846 backendName: 'webgl',
87847 kernelFunc: batchMatMul$1,
87848 };
87849
87850 /**
87851 * @license
87852 * Copyright 2017 Google LLC. All Rights Reserved.
87853 * Licensed under the Apache License, Version 2.0 (the "License");
87854 * you may not use this file except in compliance with the License.
87855 * You may obtain a copy of the License at
87856 *
87857 * http://www.apache.org/licenses/LICENSE-2.0
87858 *
87859 * Unless required by applicable law or agreed to in writing, software
87860 * distributed under the License is distributed on an "AS IS" BASIS,
87861 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87862 * See the License for the specific language governing permissions and
87863 * limitations under the License.
87864 * =============================================================================
87865 */
    /**
     * GPGPU program evaluating fused batch normalization, one scalar per
     * output element:
     *   inv = scale * inversesqrt(variance + epsilon)
     *   out = x*inv - mean*inv + offset   (i.e. (x - mean) * inv + offset)
     * `offset` defaults to 0.0 and `scale` to 1.0 when their shapes are null;
     * otherwise they become extra shader inputs.
     */
    class BatchNormProgram {
        constructor(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
            this.outputShape = [];
            this.variableNames = ['x', 'mean', 'variance'];
            // mean/variance (and optional offset/scale) must broadcast to x's shape.
            assertAndGetBroadcastShape(xShape, meanShape);
            assertAndGetBroadcastShape(xShape, varianceShape);
            let offsetSnippet = '0.0';
            if (offsetShape != null) {
                assertAndGetBroadcastShape(xShape, offsetShape);
                this.variableNames.push('offset');
                offsetSnippet = 'getOffsetAtOutCoords()';
            }
            let scaleSnippet = '1.0';
            if (scaleShape != null) {
                assertAndGetBroadcastShape(xShape, scaleShape);
                this.variableNames.push('scale');
                scaleSnippet = 'getScaleAtOutCoords()';
            }
            this.outputShape = xShape;
            // varianceEpsilon is baked into the shader source as a literal, so a
            // different epsilon compiles a different program.
            this.userCode = `
      void main() {
        float x = getXAtOutCoords();
        float mean = getMeanAtOutCoords();
        float variance = getVarianceAtOutCoords();
        float offset = ${offsetSnippet};
        float scale = ${scaleSnippet};
        float inv = scale * inversesqrt(variance + float(${varianceEpsilon}));
        setOutput(dot(vec3(x, -mean, offset), vec3(inv, inv, 1)));
      }
    `;
        }
    }
87898
87899 /**
87900 * @license
87901 * Copyright 2018 Google LLC. All Rights Reserved.
87902 * Licensed under the Apache License, Version 2.0 (the "License");
87903 * you may not use this file except in compliance with the License.
87904 * You may obtain a copy of the License at
87905 *
87906 * http://www.apache.org/licenses/LICENSE-2.0
87907 *
87908 * Unless required by applicable law or agreed to in writing, software
87909 * distributed under the License is distributed on an "AS IS" BASIS,
87910 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87911 * See the License for the specific language governing permissions and
87912 * limitations under the License.
87913 * =============================================================================
87914 */
    /**
     * Packed-layout (vec4-per-texel) variant of BatchNormProgram. Computes
     *   out = (x - mean) * scale * inversesqrt(variance + epsilon) + offset
     * on four elements at a time. `offset` defaults to vec4(0.0) and `scale`
     * to vec4(1.0) when their shapes are null; otherwise they become extra
     * shader inputs.
     */
    class BatchNormPackedProgram {
        constructor(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
            this.packedInputs = true;
            this.packedOutput = true;
            this.variableNames = ['x', 'mean', 'variance'];
            // mean/variance (and optional offset/scale) must broadcast to x's shape.
            assertAndGetBroadcastShape(xShape, meanShape);
            assertAndGetBroadcastShape(xShape, varianceShape);
            let offsetSnippet = 'vec4(0.0)';
            if (offsetShape != null) {
                assertAndGetBroadcastShape(xShape, offsetShape);
                this.variableNames.push('offset');
                offsetSnippet = 'getOffsetAtOutCoords()';
            }
            let scaleSnippet = 'vec4(1.0)';
            if (scaleShape != null) {
                assertAndGetBroadcastShape(xShape, scaleShape);
                this.variableNames.push('scale');
                scaleSnippet = 'getScaleAtOutCoords()';
            }
            this.outputShape = xShape;
            // varianceEpsilon is baked into the shader source as a literal.
            this.userCode = `
      void main() {
        vec4 offset = ${offsetSnippet};
        vec4 scale = ${scaleSnippet};

        vec4 x = getXAtOutCoords();
        vec4 mean = getMeanAtOutCoords();
        vec4 variance = getVarianceAtOutCoords();

        vec4 inv = scale * inversesqrt(variance + vec4(${varianceEpsilon}));

        setOutput((x - mean) * inv + offset);
      }
    `;
        }
    }
87951
87952 /**
87953 * @license
87954 * Copyright 2020 Google LLC. All Rights Reserved.
87955 * Licensed under the Apache License, Version 2.0 (the "License");
87956 * you may not use this file except in compliance with the License.
87957 * You may obtain a copy of the License at
87958 *
87959 * http://www.apache.org/licenses/LICENSE-2.0
87960 *
87961 * Unless required by applicable law or agreed to in writing, software
87962 * distributed under the License is distributed on an "AS IS" BASIS,
87963 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87964 * See the License for the specific language governing permissions and
87965 * limitations under the License.
87966 * =============================================================================
87967 */
87968 const batchNorm$2 = ({ inputs, backend, attrs }) => {
87969 const { x, mean, variance, offset, scale } = inputs;
87970 assert(mean.shape.length === variance.shape.length, () => 'Batch normalization gradient requires mean and variance to have ' +
87971 'equal ranks.');
87972 assert(offset == null || mean.shape.length === offset.shape.length, () => 'Batch normalization gradient requires mean and offset to have ' +
87973 'equal ranks.');
87974 assert(scale == null || mean.shape.length === scale.shape.length, () => 'Batch normalization gradient requires mean and scale to have ' +
87975 'equal ranks.');
87976 let { varianceEpsilon } = attrs;
87977 if (varianceEpsilon == null) {
87978 varianceEpsilon = 0.001;
87979 }
87980 const finalInputs = [x, mean, variance];
87981 let offsetShape = null;
87982 if (offset != null) {
87983 offsetShape = offset.shape;
87984 finalInputs.push(offset);
87985 }
87986 let scaleShape = null;
87987 if (scale != null) {
87988 scaleShape = scale.shape;
87989 finalInputs.push(scale);
87990 }
87991 const program = env().getBool('WEBGL_PACK_NORMALIZATION') ?
87992 new BatchNormPackedProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon) :
87993 new BatchNormProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon);
87994 const output = backend.runWebGLProgram(program, finalInputs, finalInputs[0].dtype);
87995 return output;
87996 };
87997 const batchNormConfig$1 = {
87998 kernelName: FusedBatchNorm,
87999 backendName: 'webgl',
88000 kernelFunc: batchNorm$2,
88001 };
88002
88003 /**
88004 * @license
88005 * Copyright 2017 Google LLC. All Rights Reserved.
88006 * Licensed under the Apache License, Version 2.0 (the "License");
88007 * you may not use this file except in compliance with the License.
88008 * You may obtain a copy of the License at
88009 *
88010 * http://www.apache.org/licenses/LICENSE-2.0
88011 *
88012 * Unless required by applicable law or agreed to in writing, software
88013 * distributed under the License is distributed on an "AS IS" BASIS,
88014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88015 * See the License for the specific language governing permissions and
88016 * limitations under the License.
88017 * =============================================================================
88018 */
    /**
     * GPGPU program that copies a sub-region of a tensor (unpacked layout).
     * The slice begin coordinates are supplied at run time through the
     * `start` int-array uniform, so one compiled program can serve any slice
     * offset of a given output size. Supports rank 1-6 (see getCoords).
     */
    class SliceProgram {
        constructor(destSize) {
            this.variableNames = ['source'];
            this.outputShape = destSize;
            this.rank = destSize.length;
            const dtype = getCoordsDataType(this.rank);
            this.customUniforms = [{ name: 'start', arrayIndex: this.rank, type: 'int' }];
            const sourceCoords = getCoords(this.rank);
            let body;
            // Per dimension i: sourceLoc[i] = start[i] + output coordinate i.
            const coordSum = destSize.map((_, i) => {
                return `sourceLoc.${coords[i]} = start[${i}] + coords.${coords[i]};`;
            });
            body = `
        ${dtype} sourceLoc;
        ${dtype} coords = getOutputCoords();
        ${coordSum.join('\n')}
      `;
            this.userCode = `
      void main() {
        ${body}
        setOutput(getSource(${sourceCoords}));
      }
    `;
        }
    }
88044 const coords = ['x', 'y', 'z', 'w', 'u', 'v'];
88045 function getCoords(rank) {
88046 if (rank === 1) {
88047 return 'sourceLoc';
88048 }
88049 else if (rank <= 6) {
88050 return coords.slice(0, rank).map(x => 'sourceLoc.' + x).join(',');
88051 }
88052 else {
88053 throw Error(`Slicing for rank ${rank} is not yet supported`);
88054 }
88055 }
88056
88057 /**
88058 * @license
88059 * Copyright 2019 Google LLC. All Rights Reserved.
88060 * Licensed under the Apache License, Version 2.0 (the "License");
88061 * you may not use this file except in compliance with the License.
88062 * You may obtain a copy of the License at
88063 *
88064 * http://www.apache.org/licenses/LICENSE-2.0
88065 *
88066 * Unless required by applicable law or agreed to in writing, software
88067 * distributed under the License is distributed on an "AS IS" BASIS,
88068 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88069 * See the License for the specific language governing permissions and
88070 * limitations under the License.
88071 * =============================================================================
88072 */
    /**
     * GPGPU program that slices a tensor stored in packed (2x2-per-texel)
     * layout. Each invocation fills one output texel by gathering up to four
     * source values channel-by-channel, because after applying the slice
     * offset an output texel may straddle texel boundaries of the source.
     * The slice begin coordinates arrive at run time via the `start` uniform.
     */
    class SlicePackedProgram {
        constructor(destSize) {
            this.variableNames = ['source'];
            this.packedInputs = true;
            this.packedOutput = true;
            this.outputShape = destSize;
            this.rank = destSize.length;
            this.customUniforms = [{ name: 'start', arrayIndex: this.rank, type: 'int' }];
            const dtype = getCoordsDataType(this.rank);
            const coords = getChannels('coords', this.rank);
            const sourceLoc = getChannels('sourceLoc', this.rank);
            // The last two dims pick the channel inside a packed source texel.
            const innerDims = this.rank === 1 ? 'sourceLoc' : `vec2(${sourceLoc.slice(-2).join()})`;
            const getChannel = `getChannel(getSource(${sourceLoc.join()}), ${innerDims})`;
            // result.x/.y: current row; the ++/-- guards keep reads inside the
            // sliced extent along the innermost dimension.
            const upperRow = `
        result.x = ${getChannel};
        if (++${coords[this.rank - 1]} < ${destSize[this.rank - 1]}) {
          ++${sourceLoc[this.rank - 1]};
          result.y = ${getChannel};
          --${sourceLoc[this.rank - 1]};
        }
      `;
            // result.z/.w: next row (only exists for rank >= 2), with the same
            // bounds guards on the last two dimensions.
            const lowerRow = this.rank === 1 ? '' : `
        --${coords[this.rank - 1]};
        if (++${coords[this.rank - 2]} < ${destSize[this.rank - 2]}) {
          ++${sourceLoc[this.rank - 2]};
          result.z = ${getChannel};
          if (++${coords[this.rank - 1]} < ${destSize[this.rank - 1]}) {
            ++${sourceLoc[this.rank - 1]};
            result.w = ${getChannel};
          }
        }
      `;
            // Rank <= 4 coords fit one GLSL vector, so the start offset is added
            // vector-wise; higher ranks assign component by component.
            const sourceLocSetup = this.rank <= 4 ?
                `sourceLoc = coords +
            ${dtype}(${destSize.map((_, i) => `start[${i}]`).join()});` :
                destSize.map((_, i) => `${sourceLoc[i]} = ${coords[i]} + start[${i}];`)
                    .join('\n');
            this.userCode = `
      void main() {
        ${dtype} coords = getOutputCoords();
        ${dtype} sourceLoc;
        ${sourceLocSetup}
        vec4 result = vec4(0.);
        ${upperRow}
        ${lowerRow}
        setOutput(result);
      }
    `;
        }
    }
88123
88124 /**
88125 * @license
88126 * Copyright 2020 Google LLC. All Rights Reserved.
88127 * Licensed under the Apache License, Version 2.0 (the "License");
88128 * you may not use this file except in compliance with the License.
88129 * You may obtain a copy of the License at
88130 *
88131 * http://www.apache.org/licenses/LICENSE-2.0
88132 *
88133 * Unless required by applicable law or agreed to in writing, software
88134 * distributed under the License is distributed on an "AS IS" BASIS,
88135 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88136 * See the License for the specific language governing permissions and
88137 * limitations under the License.
88138 * =============================================================================
88139 */
    /**
     * Creates a tensor that aliases `x`'s texture instead of copying it.
     * Only valid for slices that are contiguous in flat memory (the caller,
     * slice$2, checks this): the new tensor shares the texture and records a
     * flat element offset into it. Ref counting on the original data bucket
     * keeps the texture alive while any alias exists.
     */
    function shallowSlice(x, begin, size, backend) {
        const xTexData = backend.texData.get(x.dataId);
        const t = backend.makeTensorInfo(size, x.dtype);
        const newTexData = backend.texData.get(t.dataId);
        // Copy texture data from the original tensor.
        Object.assign(newTexData, xTexData);
        // The alias starts with a single reference of its own, and its
        // shape/dtype describe the slice, not the source.
        newTexData.refCount = 1;
        newTexData.shape = size;
        newTexData.dtype = x.dtype;
        let flatOffset = computeFlatOffset(begin, computeStrides(x.shape));
        if (xTexData.slice) {
            // We are slicing an already sliced tensor, so we have to accumulate
            // the offset.
            flatOffset += xTexData.slice.flatOffset;
        }
        newTexData.slice = {
            flatOffset,
            // Point to the original dataId, which is used to do ref counting.
            origDataId: xTexData.slice && xTexData.slice.origDataId || x.dataId
        };
        // Increase the ref count for that data bucket.
        const refCount = backend.dataRefCount.get(newTexData.slice.origDataId) || 1;
        backend.dataRefCount.set(newTexData.slice.origDataId, refCount + 1);
        return t;
    }
88165 function slice$2(args) {
88166 const { inputs, backend, attrs } = args;
88167 const { x } = inputs;
88168 const { begin, size } = attrs;
88169 const [$begin, $size] = parseSliceParams(x, begin, size);
88170 assertParamsValid(x, $begin, $size);
88171 if (sizeFromShape($size) === 0) {
88172 return backend.makeTensorInfo($size, x.dtype, []);
88173 }
88174 // Run on cpu if dtype is string. For string, the backend represents it
88175 // as Uint8Array[], where each Uint8Array is a character. Given that the
88176 // computation is only on the outer array, uploading the whole data onto
88177 // gpu is wasteful. Also, currently webgl doesn't have a design to
88178 // upload and retrieve Uint8Array[] between cpu and gpu. Therefore, we
88179 // just run the kernel on cpu if dtype is string.
88180 if (backend.shouldExecuteOnCPU([x]) || x.dtype === 'string') {
88181 const xTexData = backend.texData.get(x.dataId);
88182 const outValues = sliceImplCPU(xTexData.values, $begin, $size, x.shape, x.dtype);
88183 return backend.makeTensorInfo($size, x.dtype, outValues);
88184 }
88185 const { isPacked } = backend.texData.get(x.dataId);
88186 const isContinous = isSliceContinous(x.shape, $begin, $size);
88187 if (isPacked || !isContinous) {
88188 const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
88189 new SlicePackedProgram($size) :
88190 new SliceProgram($size);
88191 const customValues = [$begin];
88192 return backend.runWebGLProgram(program, [x], x.dtype, customValues);
88193 }
88194 backend.uploadToGPU(x.dataId);
88195 return shallowSlice(x, $begin, $size, backend);
88196 }
88197 const sliceConfig$1 = {
88198 kernelName: Slice,
88199 backendName: 'webgl',
88200 kernelFunc: slice$2
88201 };
88202
88203 /**
88204 * @license
88205 * Copyright 2020 Google LLC. All Rights Reserved.
88206 * Licensed under the Apache License, Version 2.0 (the "License");
88207 * you may not use this file except in compliance with the License.
88208 * You may obtain a copy of the License at
88209 *
88210 * http://www.apache.org/licenses/LICENSE-2.0
88211 *
88212 * Unless required by applicable law or agreed to in writing, software
88213 * distributed under the License is distributed on an "AS IS" BASIS,
88214 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88215 * See the License for the specific language governing permissions and
88216 * limitations under the License.
88217 * =============================================================================
88218 */
88219 const batchToSpaceND$2 = (args) => {
88220 const { inputs, backend, attrs } = args;
88221 const { x } = inputs;
88222 const { blockShape, crops } = attrs;
88223 assert(x.shape.length <= 4, () => 'batchToSpaceND for rank > 4 with a WebGL backend not ' +
88224 'implemented yet');
88225 const prod = blockShape.reduce((a, b) => a * b);
88226 const reshaped = getReshaped(x.shape, blockShape, prod);
88227 const permuted = getPermuted(reshaped.length, blockShape.length);
88228 const reshapedPermuted = getReshapedPermuted(x.shape, blockShape, prod);
88229 const sliceBeginCoords = getSliceBeginCoords(crops, blockShape.length);
88230 const sliceSize = getSliceSize(reshapedPermuted, crops, blockShape.length);
88231 const toDispose = [];
88232 const reshapedIntermediate = reshape$3({ inputs: { x }, backend, attrs: { shape: reshaped } });
88233 const transposedIntermediate = transpose$2({ inputs: { x: reshapedIntermediate }, backend, attrs: { perm: permuted } });
88234 const reshapedIntermediate2 = reshape$3({
88235 inputs: { x: transposedIntermediate },
88236 backend,
88237 attrs: { shape: reshapedPermuted }
88238 });
88239 const sliced = slice$2({
88240 inputs: { x: reshapedIntermediate2 },
88241 backend,
88242 attrs: { begin: sliceBeginCoords, size: sliceSize }
88243 });
88244 toDispose.push(reshapedIntermediate);
88245 toDispose.push(transposedIntermediate);
88246 toDispose.push(reshapedIntermediate2);
88247 toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
88248 return sliced;
88249 };
88250 const batchToSpaceNDConfig$1 = {
88251 kernelName: BatchToSpaceND,
88252 backendName: 'webgl',
88253 kernelFunc: batchToSpaceND$2
88254 };
88255
88256 /**
88257 * @license
88258 * Copyright 2020 Google LLC. All Rights Reserved.
88259 * Licensed under the Apache License, Version 2.0 (the "License");
88260 * you may not use this file except in compliance with the License.
88261 * You may obtain a copy of the License at
88262 *
88263 * http://www.apache.org/licenses/LICENSE-2.0
88264 *
88265 * Unless required by applicable law or agreed to in writing, software
88266 * distributed under the License is distributed on an "AS IS" BASIS,
88267 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88268 * See the License for the specific language governing permissions and
88269 * limitations under the License.
88270 * =============================================================================
88271 */
88272 function bincount$2(args) {
88273 const { inputs, backend, attrs } = args;
88274 const { x, weights } = inputs;
88275 const { size } = attrs;
88276 const xVals = backend.readSync(x.dataId);
88277 const weightsVals = backend.readSync(weights.dataId);
88278 const outVals = bincountImplCPU(xVals, weightsVals, weights.dtype, weights.shape, size);
88279 return backend.makeTensorInfo([size], weights.dtype, outVals);
88280 }
88281 const bincountConfig$1 = {
88282 kernelName: Bincount,
88283 backendName: 'webgl',
88284 kernelFunc: bincount$2
88285 };
88286
88287 /**
88288 * @license
88289 * Copyright 2021 Google LLC. All Rights Reserved.
88290 * Licensed under the Apache License, Version 2.0 (the "License");
88291 * you may not use this file except in compliance with the License.
88292 * You may obtain a copy of the License at
88293 *
88294 * http://www.apache.org/licenses/LICENSE-2.0
88295 *
88296 * Unless required by applicable law or agreed to in writing, software
88297 * distributed under the License is distributed on an "AS IS" BASIS,
88298 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88299 * See the License for the specific language governing permissions and
88300 * limitations under the License.
88301 * =============================================================================
88302 */
88303 function broadcastArgs$2(args) {
88304 const { inputs, backend } = args;
88305 const { s0, s1 } = inputs;
88306 const s0Vals = backend.readSync(s0.dataId);
88307 const s1Vals = backend.readSync(s1.dataId);
88308 const broadcastShape = assertAndGetBroadcastShape(Array.from(s0Vals), Array.from(s1Vals));
88309 return backend.makeTensorInfo([broadcastShape.length], 'int32', Int32Array.from(broadcastShape));
88310 }
88311 const broadcastArgsConfig$1 = {
88312 kernelName: BroadcastArgs,
88313 backendName: 'webgl',
88314 kernelFunc: broadcastArgs$2
88315 };
88316
88317 /**
88318 * @license
88319 * Copyright 2020 Google LLC. All Rights Reserved.
88320 * Licensed under the Apache License, Version 2.0 (the "License");
88321 * you may not use this file except in compliance with the License.
88322 * You may obtain a copy of the License at
88323 *
88324 * http://www.apache.org/licenses/LICENSE-2.0
88325 *
88326 * Unless required by applicable law or agreed to in writing, software
88327 * distributed under the License is distributed on an "AS IS" BASIS,
88328 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88329 * See the License for the specific language governing permissions and
88330 * limitations under the License.
88331 * =============================================================================
88332 */
    // GLSL snippet for element-wise a != b, emitted as 0.0 / 1.0.
    const NOT_EQUAL$1 = `return float(a != b);`;
    // Binary kernel built from the snippet; always produces a bool tensor and
    // falls back to the shared cpu implementation when appropriate.
    const notEqual$2 = binaryKernelFunc$1({ opSnippet: NOT_EQUAL$1, cpuKernelImpl: notEqualImplCPU, dtype: 'bool' });
    /** Kernel registration config for NotEqual on the WebGL backend. */
    const notEqualConfig$1 = {
        kernelName: NotEqual,
        backendName: 'webgl',
        kernelFunc: notEqual$2,
    };
88340
88341 /**
88342 * @license
88343 * Copyright 2020 Google LLC. All Rights Reserved.
88344 * Licensed under the Apache License, Version 2.0 (the "License");
88345 * you may not use this file except in compliance with the License.
88346 * You may obtain a copy of the License at
88347 *
88348 * http://www.apache.org/licenses/LICENSE-2.0
88349 *
88350 * Unless required by applicable law or agreed to in writing, software
88351 * distributed under the License is distributed on an "AS IS" BASIS,
88352 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88353 * See the License for the specific language governing permissions and
88354 * limitations under the License.
88355 * =============================================================================
88356 */
88357 function real$2(args) {
88358 const { inputs, backend } = args;
88359 const { input } = inputs;
88360 const inputData = backend.texData.get(input.dataId);
88361 return identity$2({ inputs: { x: inputData.complexTensorInfos.real }, backend });
88362 }
88363 const realConfig$1 = {
88364 kernelName: Real,
88365 backendName: 'webgl',
88366 kernelFunc: real$2
88367 };
88368
88369 /**
88370 * @license
88371 * Copyright 2020 Google LLC. All Rights Reserved.
88372 * Licensed under the Apache License, Version 2.0 (the "License");
88373 * you may not use this file except in compliance with the License.
88374 * You may obtain a copy of the License at
88375 *
88376 * http://www.apache.org/licenses/LICENSE-2.0
88377 *
88378 * Unless required by applicable law or agreed to in writing, software
88379 * distributed under the License is distributed on an "AS IS" BASIS,
88380 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88381 * See the License for the specific language governing permissions and
88382 * limitations under the License.
88383 * =============================================================================
88384 */
88385 const TO_INT = `return float(int(x));`;
88386 function int(input, backend) {
88387 const program = new UnaryOpProgram(input.shape, TO_INT);
88388 const output = backend.runWebGLProgram(program, [input], 'int32');
88389 return { dataId: output.dataId, shape: output.shape, dtype: output.dtype };
88390 }
88391
88392 /**
88393 * @license
88394 * Copyright 2020 Google LLC. All Rights Reserved.
88395 * Licensed under the Apache License, Version 2.0 (the "License");
88396 * you may not use this file except in compliance with the License.
88397 * You may obtain a copy of the License at
88398 *
88399 * http://www.apache.org/licenses/LICENSE-2.0
88400 *
88401 * Unless required by applicable law or agreed to in writing, software
88402 * distributed under the License is distributed on an "AS IS" BASIS,
88403 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88404 * See the License for the specific language governing permissions and
88405 * limitations under the License.
88406 * =============================================================================
88407 */
88408 function cast$3(args) {
88409 const { inputs, backend, attrs } = args;
88410 const { x } = inputs;
88411 const { dtype } = attrs;
88412 // Casting to complex64.
88413 if (dtype === 'complex64') {
88414 if (x.dtype === 'complex64') {
88415 return identity$2({ inputs: { x }, backend });
88416 }
88417 // TODO(annxingyuan): Import kernel function once zeros is modularized.
88418 const zerosTensor = zeros(x.shape);
88419 const floatX = cast$3({ inputs: { x }, backend, attrs: { dtype: 'float32' } });
88420 const result = complex$2({ inputs: { real: floatX, imag: zerosTensor }, backend });
88421 zerosTensor.dispose();
88422 backend.disposeIntermediateTensorInfo(floatX);
88423 return result;
88424 }
88425 // Casting from complex64
88426 if (x.dtype === 'complex64') {
88427 const realPart = real$2({ inputs: { input: x }, backend });
88428 const result = cast$3({ inputs: { x: realPart }, backend, attrs: { dtype } });
88429 backend.disposeIntermediateTensorInfo(realPart);
88430 return result;
88431 }
88432 if (!hasEncodingLoss(x.dtype, dtype)) {
88433 // We don't change the underlying data, since we cast to higher
88434 // precision.
88435 const result = identity$2({ inputs: { x }, backend });
88436 return { dataId: result.dataId, shape: result.shape, dtype };
88437 }
88438 if (dtype === 'int32') {
88439 return int(x, backend);
88440 }
88441 if (dtype === 'bool') {
88442 const zerosTensorInfo = backend.makeTensorInfo([], 'bool', getTypedArrayFromDType('bool', 1));
88443 const binaryInputs = { a: x, b: zerosTensorInfo };
88444 const result = notEqual$2({ inputs: binaryInputs, backend });
88445 backend.disposeIntermediateTensorInfo(zerosTensorInfo);
88446 return result;
88447 }
88448 throw new Error(`Error in Cast: failed to cast ${x.dtype} to ${dtype}`);
88449 }
88450 const castConfig$1 = {
88451 kernelName: Cast,
88452 backendName: 'webgl',
88453 kernelFunc: cast$3
88454 };
88455
88456 /**
88457 * @license
88458 * Copyright 2020 Google LLC. All Rights Reserved.
88459 * Licensed under the Apache License, Version 2.0 (the "License");
88460 * you may not use this file except in compliance with the License.
88461 * You may obtain a copy of the License at
88462 *
88463 * http://www.apache.org/licenses/LICENSE-2.0
88464 *
88465 * Unless required by applicable law or agreed to in writing, software
88466 * distributed under the License is distributed on an "AS IS" BASIS,
88467 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88468 * See the License for the specific language governing permissions and
88469 * limitations under the License.
88470 * =============================================================================
88471 */
    // GLSL snippet for element-wise ceil; used for both packed and unpacked
    // layouts since ceil() is componentwise in GLSL.
    const CEIL = `return ceil(x);`;
    const ceil$2 = unaryKernelFunc$1({ opSnippet: CEIL, packedOpSnippet: CEIL, cpuKernelImpl: ceilImplCPU });
    /** Kernel registration config for Ceil on the WebGL backend. */
    const ceilConfig$1 = {
        kernelName: Ceil,
        backendName: 'webgl',
        kernelFunc: ceil$2
    };
88479
88480 /**
88481 * @license
88482 * Copyright 2017 Google LLC. All Rights Reserved.
88483 * Licensed under the Apache License, Version 2.0 (the "License");
88484 * you may not use this file except in compliance with the License.
88485 * You may obtain a copy of the License at
88486 *
88487 * http://www.apache.org/licenses/LICENSE-2.0
88488 *
88489 * Unless required by applicable law or agreed to in writing, software
88490 * distributed under the License is distributed on an "AS IS" BASIS,
88491 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88492 * See the License for the specific language governing permissions and
88493 * limitations under the License.
88494 * =============================================================================
88495 */
    /**
     * GPGPU program that clamps each element of A into [minVal, maxVal].
     * The bounds arrive as float uniforms so one compiled program serves any
     * clip range. NaN inputs are passed through unchanged rather than
     * clamped.
     */
    class ClipProgram {
        constructor(aShape) {
            this.variableNames = ['A'];
            this.customUniforms = [
                { name: 'minVal', type: 'float' },
                { name: 'maxVal', type: 'float' }
            ];
            this.outputShape = aShape;
            this.userCode = `

      void main() {
        float value = getAAtOutCoords();
        if (isnan(value)) {
          setOutput(value);
          return;
        }

        setOutput(clamp(value, minVal, maxVal));
      }
    `;
        }
    }
88518
88519 /**
88520 * @license
88521 * Copyright 2018 Google LLC. All Rights Reserved.
88522 * Licensed under the Apache License, Version 2.0 (the "License");
88523 * you may not use this file except in compliance with the License.
88524 * You may obtain a copy of the License at
88525 *
88526 * http://www.apache.org/licenses/LICENSE-2.0
88527 *
88528 * Unless required by applicable law or agreed to in writing, software
88529 * distributed under the License is distributed on an "AS IS" BASIS,
88530 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88531 * See the License for the specific language governing permissions and
88532 * limitations under the License.
88533 * =============================================================================
88534 */
    /**
     * Packed-layout (vec4-per-texel) variant of ClipProgram. Clamps four
     * elements at a time into [minVal, maxVal], passed as float uniforms.
     * If any lane of the texel is NaN, the whole texel is passed through
     * unclamped.
     */
    class ClipPackedProgram {
        constructor(aShape) {
            this.variableNames = ['A'];
            this.packedInputs = true;
            this.packedOutput = true;
            this.customUniforms = [
                { name: 'minVal', type: 'float' },
                { name: 'maxVal', type: 'float' }
            ];
            this.outputShape = aShape;
            this.userCode = `
      void main() {
        vec4 value = getAAtOutCoords();

        if (any(isnan(value))) {
          setOutput(value);
          return;
        }

        setOutput(clamp(value, vec4(minVal), vec4(maxVal)));
      }
    `;
        }
    }
88559
88560 /**
88561 * @license
88562 * Copyright 2020 Google LLC. All Rights Reserved.
88563 * Licensed under the Apache License, Version 2.0 (the "License");
88564 * you may not use this file except in compliance with the License.
88565 * You may obtain a copy of the License at
88566 *
88567 * http://www.apache.org/licenses/LICENSE-2.0
88568 *
88569 * Unless required by applicable law or agreed to in writing, software
88570 * distributed under the License is distributed on an "AS IS" BASIS,
88571 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88572 * See the License for the specific language governing permissions and
88573 * limitations under the License.
88574 * =============================================================================
88575 */
88576 function clipByValue$2(args) {
88577 const { inputs, backend, attrs } = args;
88578 const { x } = inputs;
88579 const { clipValueMin, clipValueMax } = attrs;
88580 let program;
88581 if (env().getBool('WEBGL_PACK_CLIP')) {
88582 program = new ClipPackedProgram(x.shape);
88583 }
88584 else {
88585 program = new ClipProgram(x.shape);
88586 }
88587 const customValues = [[clipValueMin], [clipValueMax]];
88588 return backend.runWebGLProgram(program, [x], x.dtype, customValues);
88589 }
88590 const clipByValueConfig$1 = {
88591 kernelName: ClipByValue,
88592 backendName: 'webgl',
88593 kernelFunc: clipByValue$2
88594 };
88595
88596 /**
88597 * @license
88598 * Copyright 2018 Google LLC. All Rights Reserved.
88599 * Licensed under the Apache License, Version 2.0 (the "License");
88600 * you may not use this file except in compliance with the License.
88601 * You may obtain a copy of the License at
88602 *
88603 * http://www.apache.org/licenses/LICENSE-2.0
88604 *
88605 * Unless required by applicable law or agreed to in writing, software
88606 * distributed under the License is distributed on an "AS IS" BASIS,
88607 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88608 * See the License for the specific language governing permissions and
88609 * limitations under the License.
88610 * =============================================================================
88611 */
88612 class ComplexAbsProgram {
88613 constructor(shape) {
88614 this.variableNames = ['real', 'imag'];
88615 this.outputShape = shape;
88616 this.userCode = `
88617 void main() {
88618 float re = abs(getRealAtOutCoords());
88619 float im = abs(getImagAtOutCoords());
88620 float mx = max(re, im);
88621
88622 // sadly the length function in glsl is not underflow-safe
88623 // (at least not on Intel GPUs). So the safe solution is
88624 // to ensure underflow-safety in all cases.
88625 setOutput(
88626 mx == 0.0 ? 0.0 : mx * length(vec2(1, min(re, im)/mx))
88627 );
88628 }
88629 `;
88630 }
88631 }
88632
88633 /**
88634 * @license
88635 * Copyright 2020 Google LLC. All Rights Reserved.
88636 * Licensed under the Apache License, Version 2.0 (the "License");
88637 * you may not use this file except in compliance with the License.
88638 * You may obtain a copy of the License at
88639 *
88640 * http://www.apache.org/licenses/LICENSE-2.0
88641 *
88642 * Unless required by applicable law or agreed to in writing, software
88643 * distributed under the License is distributed on an "AS IS" BASIS,
88644 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88645 * See the License for the specific language governing permissions and
88646 * limitations under the License.
88647 * =============================================================================
88648 */
88649 // Returns a TensorInfo with the complex shape and the dataId of the
88650 // underlying part. We need to do this because a reshaped complex tensor is
88651 // not reflected in its parts.
88652 function makeComplexComponentTensorInfo(complexTensor, complexPart) {
88653 return {
88654 dataId: complexPart.dataId,
88655 dtype: complexPart.dtype,
88656 shape: complexTensor.shape
88657 };
88658 }
88659 function complexAbs$1(args) {
88660 const { inputs, backend } = args;
88661 const { x } = inputs;
88662 const xData = backend.texData.get(x.dataId);
88663 const program = new ComplexAbsProgram(x.shape);
88664 const programInputs = [
88665 makeComplexComponentTensorInfo(x, xData.complexTensorInfos.real),
88666 makeComplexComponentTensorInfo(x, xData.complexTensorInfos.imag),
88667 ];
88668 return backend.runWebGLProgram(program, programInputs, programInputs[0].dtype);
88669 }
88670 const complexAbsConfig$1 = {
88671 kernelName: ComplexAbs,
88672 backendName: 'webgl',
88673 kernelFunc: complexAbs$1
88674 };
88675
88676 /**
88677 * @license
88678 * Copyright 2017 Google LLC. All Rights Reserved.
88679 * Licensed under the Apache License, Version 2.0 (the "License");
88680 * you may not use this file except in compliance with the License.
88681 * You may obtain a copy of the License at
88682 *
88683 * http://www.apache.org/licenses/LICENSE-2.0
88684 *
88685 * Unless required by applicable law or agreed to in writing, software
88686 * distributed under the License is distributed on an "AS IS" BASIS,
88687 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88688 * See the License for the specific language governing permissions and
88689 * limitations under the License.
88690 * =============================================================================
88691 */
88692 class ConcatProgram {
88693 // Concats 2d tensors along axis=1. See comments in MathBackendWebGL.concat().
88694 constructor(shapes) {
88695 this.outputShape = [];
88696 this.outputShape = computeOutShape$1(shapes, 1 /* axis */);
88697 this.variableNames = shapes.map((_, i) => `T${i}`);
88698 const offsets = new Array(shapes.length - 1);
88699 offsets[0] = shapes[0][1];
88700 for (let i = 1; i < offsets.length; i++) {
88701 offsets[i] = offsets[i - 1] + shapes[i][1];
88702 }
88703 const snippets = [`if (yC < ${offsets[0]}) setOutput(getT0(yR, yC));`];
88704 for (let i = 1; i < offsets.length; i++) {
88705 const shift = offsets[i - 1];
88706 snippets.push(`else if (yC < ${offsets[i]}) ` +
88707 `setOutput(getT${i}(yR, yC-${shift}));`);
88708 }
88709 const lastIndex = offsets.length;
88710 const lastShift = offsets[offsets.length - 1];
88711 snippets.push(`else setOutput(getT${lastIndex}(yR, yC-${lastShift}));`);
88712 this.userCode = `
88713 void main() {
88714 ivec2 coords = getOutputCoords();
88715 int yR = coords.x;
88716 int yC = coords.y;
88717
88718 ${snippets.join('\n ')}
88719 }
88720 `;
88721 }
88722 }
88723
88724 /**
88725 * @license
88726 * Copyright 2019 Google LLC. All Rights Reserved.
88727 * Licensed under the Apache License, Version 2.0 (the "License");
88728 * you may not use this file except in compliance with the License.
88729 * You may obtain a copy of the License at
88730 *
88731 * http://www.apache.org/licenses/LICENSE-2.0
88732 *
88733 * Unless required by applicable law or agreed to in writing, software
88734 * distributed under the License is distributed on an "AS IS" BASIS,
88735 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88736 * See the License for the specific language governing permissions and
88737 * limitations under the License.
88738 * =============================================================================
88739 */
    /**
     * Packed WebGL program that concatenates n-D tensors along `axis`.
     * The generated getValue() maps an output coordinate to the input tensor
     * whose span along `axis` contains it, shifting the coordinate by the
     * cumulative extent of the preceding inputs. main() then fills the four
     * lanes of a 2x2-packed texel by probing the neighboring coordinates.
     */
    class ConcatPackedProgram {
        constructor(shapes, axis) {
            this.packedInputs = true;
            this.packedOutput = true;
            this.outputShape = [];
            this.outputShape = computeOutShape$1(shapes, axis);
            const shape = this.outputShape;
            const rank = shape.length;
            // GLSL coordinate type (int/ivec2/...) and channel names for `rank`.
            const dtype = getCoordsDataType(rank);
            const coords = getChannels('coords', rank);
            const channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank);
            this.variableNames = shapes.map((_, i) => `T${i}`);
            // offsets[i] = first coordinate along `axis` past the end of input i.
            const offsets = new Array(shapes.length - 1);
            offsets[0] = shapes[0][axis];
            for (let i = 1; i < offsets.length; i++) {
                offsets[i] = offsets[i - 1] + shapes[i][axis];
            }
            // Channel indexing the concat axis, and the trailing two channels
            // used to pick a lane inside a packed texel.
            const channel = channels[axis];
            const lastChannels = channels.slice(-2);
            const allChannels = channels.join();
            let getValueSnippet = `if (${channel} < ${offsets[0]}) {
        return getChannel(
            getT0(${allChannels}), vec2(${lastChannels.join()}));
        }`;
            for (let i = 1; i < offsets.length; i++) {
                const shift = offsets[i - 1];
                // Note: the >= comparison below may seem unnecessary given the check
                // above but is needed to workaround branch execution issues on some
                // devices. It makes all the conditions exclusive without relying on
                // execution order.
                getValueSnippet += `
        if (${channel} < ${offsets[i]} && ${channel} >= ${offsets[i - 1]}) {
          return getChannel(
            getT${i}(${shiftedChannels(channels, channel, shift)}),
            vec2(${shiftedChannels(lastChannels, channel, shift)}));
        }`;
            }
            const lastIndex = offsets.length;
            const shift = offsets[offsets.length - 1];
            // Fallthrough: any remaining coordinate belongs to the last input.
            getValueSnippet += `
        return getChannel(
          getT${lastIndex}(${shiftedChannels(channels, channel, shift)}),
          vec2(${shiftedChannels(lastChannels, channel, shift)}));`;
            this.userCode = `
      float getValue(${channels.map(x => 'int ' + x)}) {
        ${getValueSnippet}
      }

      void main() {
        ${dtype} coords = getOutputCoords();
        vec4 result = vec4(getValue(${coords}), 0., 0., 0.);

        ${coords[rank - 1]} = ${coords[rank - 1]} + 1;
        if (${coords[rank - 1]} < ${shape[rank - 1]}) {
          result.g = getValue(${coords});
        }

        ${coords[rank - 2]} = ${coords[rank - 2]} + 1;
        if (${coords[rank - 2]} < ${shape[rank - 2]}) {
          result.a = getValue(${coords});
        }

        ${coords[rank - 1]} = ${coords[rank - 1]} - 1;
        if (${coords[rank - 2]} < ${shape[rank - 2]} &&
            ${coords[rank - 1]} < ${shape[rank - 1]}) {
          result.b = getValue(${coords});
        }
        setOutput(result);
      }
    `;
        }
    }
88812 /**
88813 * Return an expression for coordinates into a vector where a given channel
88814 * will be offset by [shift].
88815 *
88816 * @param channels the channels to consider
88817 * @param channel the channel we want shifted
88818 * @param shift the amount to subtract from the channel.
88819 *
88820 * @returns a string of the form 'x, y-[shift], z' where any one channel can
88821 * have the shift applied.
88822 */
88823 function shiftedChannels(channels, channel, shift) {
88824 const channelIdx = channels.indexOf(channel);
88825 const res = channels.map((c, idx) => {
88826 if (idx === channelIdx) {
88827 return `${c} - ${shift}`;
88828 }
88829 else {
88830 return c;
88831 }
88832 });
88833 return res.join();
88834 }
88835
88836 /**
88837 * @license
88838 * Copyright 2020 Google LLC. All Rights Reserved.
88839 * Licensed under the Apache License, Version 2.0 (the "License");
88840 * you may not use this file except in compliance with the License.
88841 * You may obtain a copy of the License at
88842 *
88843 * http://www.apache.org/licenses/LICENSE-2.0
88844 *
88845 * Unless required by applicable law or agreed to in writing, software
88846 * distributed under the License is distributed on an "AS IS" BASIS,
88847 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88848 * See the License for the specific language governing permissions and
88849 * limitations under the License.
88850 * =============================================================================
88851 */
88852 function imag$2(args) {
88853 const { inputs, backend } = args;
88854 const { input } = inputs;
88855 const inputData = backend.texData.get(input.dataId);
88856 return identity$2({ inputs: { x: inputData.complexTensorInfos.imag }, backend });
88857 }
88858 const imagConfig$1 = {
88859 kernelName: Imag,
88860 backendName: 'webgl',
88861 kernelFunc: imag$2
88862 };
88863
88864 /**
88865 * @license
88866 * Copyright 2020 Google LLC. All Rights Reserved.
88867 * Licensed under the Apache License, Version 2.0 (the "License");
88868 * you may not use this file except in compliance with the License.
88869 * You may obtain a copy of the License at
88870 *
88871 * http://www.apache.org/licenses/LICENSE-2.0
88872 *
88873 * Unless required by applicable law or agreed to in writing, software
88874 * distributed under the License is distributed on an "AS IS" BASIS,
88875 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88876 * See the License for the specific language governing permissions and
88877 * limitations under the License.
88878 * =============================================================================
88879 */
    /**
     * Concatenates `inputs` along `axis` on the WebGL backend, choosing among
     * several strategies: complex64 tensors recurse on their real/imag parts;
     * string or CPU-resident tensors are concatenated on the CPU; an input
     * list exceeding the shader texture limit is halved and recursed; otherwise
     * a packed or unpacked GPU program performs the concat. Returns a new
     * TensorInfo; all intermediates created here are disposed before returning.
     */
    function concatImpl$1(inputs, axis, backend) {
        const dtype = inputs[0].dtype;
        if (dtype === 'complex64') {
            // Concat real and imaginary components separately, then recombine.
            const reals = inputs.map((t) => real$2({ inputs: { input: t }, backend }));
            const imags = inputs.map((t) => imag$2({ inputs: { input: t }, backend }));
            const realConcated = concatImpl$1(reals, axis, backend);
            const imagConcated = concatImpl$1(imags, axis, backend);
            const result = complex$2({ inputs: { real: realConcated, imag: imagConcated }, backend });
            // Component extractions and their concats are intermediates.
            reals.forEach(r => backend.disposeIntermediateTensorInfo(r));
            imags.forEach(i => backend.disposeIntermediateTensorInfo(i));
            backend.disposeIntermediateTensorInfo(realConcated);
            backend.disposeIntermediateTensorInfo(imagConcated);
            return result;
        }
        let runOnCpu = backend.shouldExecuteOnCPU(inputs);
        // Run on cpu if dtype is string. For string, the backend represents it
        // as Uint8Array[], where each Uint8Array is a character. Given that the
        // computation is only on the outer array, uploading the whole data onto
        // gpu is wasteful. Also, currently webgl doesn't have a design to
        // upload and retrieve Uint8Array[] between cpu and gpu. Therefore, we
        // just run the kernel on cpu if dtype is string.
        if (dtype === 'string') {
            runOnCpu = true;
        }
        if (runOnCpu) {
            // Any concat of n-dimensional tensors across any axis can be reduced to
            // a concatenation of two-dimensional tensors across the axis 1 by first
            // partitioning the axes of the original tensors into those less than the
            // axis to be concatenated and the rest. Then reshape the tensors
            // into a two-dimensional tensor by collapsing these two sets of axes and
            // concatenate the resulting matrices across the axis 1, finally reshaping
            // the result to have the proper shape.
            const tensors2D = inputs.map(t => {
                const innerSize = sizeFromShape(t.shape.slice(axis));
                const shape = [-1, innerSize];
                return reshape$3({ inputs: { x: t }, backend, attrs: { shape } });
            });
            const inputsValShapes = tensors2D.map(t => {
                return { vals: backend.readSync(t.dataId), shape: t.shape };
            });
            // Concats 2d tensors along axis=1.
            const outShape = computeOutShape$1(tensors2D.map(t => t.shape), 1 /* axis */);
            // One row per tensor lets the CPU impl take its fast path.
            const simplyConcat = tensors2D[0].shape[0] === 1;
            const outVals = concatImplCPU(inputsValShapes, outShape, dtype, simplyConcat);
            const finalOutShape = computeOutShape$1(inputs.map(t => t.shape), axis);
            const outInfo = backend.makeTensorInfo(finalOutShape, dtype, outVals);
            tensors2D.forEach(t => backend.disposeIntermediateTensorInfo(t));
            return outInfo;
        }
        // Too many inputs to bind in one shader: divide and conquer.
        if (inputs.length > env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')) {
            const midIndex = Math.floor(inputs.length / 2);
            const leftSide = concatImpl$1(inputs.slice(0, midIndex), axis, backend);
            const rightSide = concatImpl$1(inputs.slice(midIndex), axis, backend);
            const result = concatImpl$1([leftSide, rightSide], axis, backend);
            backend.disposeIntermediateTensorInfo(leftSide);
            backend.disposeIntermediateTensorInfo(rightSide);
            return result;
        }
        // Packed path handles n-D directly; requires rank > 1.
        if (env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') &&
            inputs[0].shape.length > 1) {
            const program = new ConcatPackedProgram(inputs.map(t => t.shape), axis);
            return backend.runWebGLProgram(program, inputs, dtype);
        }
        // Unpacked path: collapse to 2D, concat along axis=1, reshape back.
        const { tensors2D, outShape } = computeTensors2D(inputs, axis, backend);
        const program = new ConcatProgram(tensors2D.map(t => t.shape));
        const result = backend.runWebGLProgram(program, tensors2D, dtype);
        tensors2D.forEach(r => backend.disposeIntermediateTensorInfo(r));
        const reshapedResult = reshape$3({ inputs: { x: result }, attrs: { shape: outShape }, backend });
        backend.disposeIntermediateTensorInfo(result);
        return reshapedResult;
    }
88951 function computeTensors2D(inputs, axis, backend) {
88952 // Any concat of n-dimensional tensors across any axis can be reduced to
88953 // a concatenation of two-dimensional tensors across the axis 1 by first
88954 // partitioning the axes of the original tensors into those less than the
88955 // axis to be concatenated and the rest. Then reshape the tensors
88956 // into a two-dimensional tensor by collapsing these two sets of axes and
88957 // concatenate the resulting matrices across the axis 1, finally reshaping
88958 // the result to have the proper shape.
88959 const outShape = computeOutShape$1(inputs.map(t => t.shape), axis);
88960 const tensors2D = inputs.map(x => reshape$3({
88961 inputs: { x },
88962 attrs: { shape: [-1, sizeFromShape(x.shape.slice(axis))] },
88963 backend
88964 }));
88965 return { tensors2D, outShape };
88966 }
88967
88968 /**
88969 * @license
88970 * Copyright 2020 Google LLC. All Rights Reserved.
88971 * Licensed under the Apache License, Version 2.0 (the "License");
88972 * you may not use this file except in compliance with the License.
88973 * You may obtain a copy of the License at
88974 *
88975 * http://www.apache.org/licenses/LICENSE-2.0
88976 *
88977 * Unless required by applicable law or agreed to in writing, software
88978 * distributed under the License is distributed on an "AS IS" BASIS,
88979 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88980 * See the License for the specific language governing permissions and
88981 * limitations under the License.
88982 * =============================================================================
88983 */
88984 function concat$2(args) {
88985 const { inputs, backend, attrs } = args;
88986 const { axis } = attrs;
88987 const $axis = parseAxisParam(axis, inputs[0].shape)[0];
88988 const outShape = computeOutShape$1(inputs.map(t => t.shape), $axis);
88989 if (sizeFromShape(outShape) === 0) {
88990 return backend.makeTensorInfo(outShape, inputs[0].dtype, []);
88991 }
88992 // Keep only non-empty tensors (ignore tensors with 0 in their shape).
88993 const $inputs = inputs.filter(t => sizeFromShape(t.shape) > 0);
88994 if ($inputs.length === 1) {
88995 return identity$2({ inputs: { x: $inputs[0] }, backend });
88996 }
88997 const shapes = $inputs.map(t => t.shape);
88998 assertParamsConsistent(shapes, $axis);
88999 return concatImpl$1($inputs, $axis, backend);
89000 }
89001 const concatConfig$1 = {
89002 kernelName: Concat,
89003 backendName: 'webgl',
89004 kernelFunc: concat$2
89005 };
89006
89007 /**
89008 * @license
89009 * Copyright 2017 Google LLC. All Rights Reserved.
89010 * Licensed under the Apache License, Version 2.0 (the "License");
89011 * you may not use this file except in compliance with the License.
89012 * You may obtain a copy of the License at
89013 *
89014 * http://www.apache.org/licenses/LICENSE-2.0
89015 *
89016 * Unless required by applicable law or agreed to in writing, software
89017 * distributed under the License is distributed on an "AS IS" BASIS,
89018 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89019 * See the License for the specific language governing permissions and
89020 * limitations under the License.
89021 * =============================================================================
89022 */
    /**
     * WebGL program computing a 2D convolution y = conv2d(x, W), with optional
     * fused bias add and fused activation (plain, PReLU, or leaky-ReLU).
     * Input channels are consumed in vec4 groups with a scalar/vec2/vec3 tail
     * for the remainder; both channelsLast and channelsFirst layouts are
     * supported via layout-dependent coordinate indices.
     */
    class Conv2DProgram {
        constructor(convInfo, addBias = false, activation = null, hasPreluActivationWeights = false, hasLeakyreluAlpha = false) {
            this.variableNames = ['x', 'W'];
            this.outputShape = convInfo.outShape;
            const padTop = convInfo.padInfo.top;
            const padLeft = convInfo.padInfo.left;
            const strideHeight = convInfo.strideHeight;
            const strideWidth = convInfo.strideWidth;
            const dilationHeight = convInfo.dilationHeight;
            const dilationWidth = convInfo.dilationWidth;
            const filterHeight = convInfo.filterHeight;
            const filterWidth = convInfo.filterWidth;
            // Largest multiple of 4 <= inChannels; the remainder (0..3 channels)
            // is handled by the dedicated tail branches below.
            const inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
            const inputDepthVec4Remainder = convInfo.inChannels % 4;
            // Coordinate positions of row/col/channel depend on the data layout.
            const isChannelsLast = convInfo.dataFormat === 'channelsLast';
            const rowDim = isChannelsLast ? 1 : 2;
            const colDim = isChannelsLast ? 2 : 3;
            const channelDim = isChannelsLast ? 3 : 1;
            let activationSnippet = '', applyActivationSnippet = '';
            // When fused, `activation` is a GLSL statement body producing the
            // activated value; PReLU/leaky-ReLU variants also read a weight `b`.
            if (activation) {
                if (hasPreluActivationWeights) {
                    activationSnippet = `float activation(float a) {
          float b = getPreluActivationWeightsAtOutCoords();
          ${activation}
        }`;
                }
                else if (hasLeakyreluAlpha) {
                    activationSnippet = `float activation(float a) {
          float b = getLeakyreluAlphaAtOutCoords();
          ${activation}
        }`;
                }
                else {
                    activationSnippet = `
          float activation(float x) {
            ${activation}
          }
        `;
                }
                applyActivationSnippet = `result = activation(result);`;
            }
            const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
            // Extra sampler variables are only declared when actually used.
            if (addBias) {
                this.variableNames.push('bias');
            }
            if (hasPreluActivationWeights) {
                this.variableNames.push('preluActivationWeights');
            }
            if (hasLeakyreluAlpha) {
                this.variableNames.push('leakyreluAlpha');
            }
            this.userCode = `
      ${activationSnippet}

      const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
      const ivec2 pads = ivec2(${padTop}, ${padLeft});

      void main() {
        ivec4 coords = getOutputCoords();
        int batch = coords[0];
        int d2 = coords[${channelDim}];

        ivec2 xRCCorner =
            ivec2(coords[${rowDim}], coords[${colDim}]) * strides - pads;
        int xRCorner = xRCCorner.x;
        int xCCorner = xRCCorner.y;

        // Convolve x(?, ?, d1) with w(:, :, d1, d2) to get y(yR, yC, d2).
        // ? = to be determined. : = across all values in that axis.
        float dotProd = 0.0;
        for (int wR = 0; wR < ${filterHeight}; wR++) {
          int xR = xRCorner + wR * ${dilationHeight};

          if (xR < 0 || xR >= ${convInfo.inHeight}) {
            continue;
          }

          for (int wC = 0; wC < ${filterWidth}; wC++) {
            int xC = xCCorner + wC * ${dilationWidth};

            if (xC < 0 || xC >= ${convInfo.inWidth}) {
              continue;
            }

            for (int d1 = 0; d1 < ${inputDepthNearestVec4}; d1 += 4) {
              vec4 wValues = vec4(
                getW(wR, wC, d1, d2),
                getW(wR, wC, d1 + 1, d2),
                getW(wR, wC, d1 + 2, d2),
                getW(wR, wC, d1 + 3, d2)
              );

              if (${isChannelsLast}) {
                vec4 xValues = vec4(
                  getX(batch, xR, xC, d1),
                  getX(batch, xR, xC, d1 + 1),
                  getX(batch, xR, xC, d1 + 2),
                  getX(batch, xR, xC, d1 + 3)
                );
                dotProd += dot(xValues, wValues);
              } else {
                vec4 xValues = vec4(
                  getX(batch, d1, xR, xC),
                  getX(batch, d1 + 1, xR, xC),
                  getX(batch, d1 + 2, xR, xC),
                  getX(batch, d1 + 3, xR, xC)
                );
                dotProd += dot(xValues, wValues);
              }
            }

            if (${inputDepthVec4Remainder === 1}) {

              if (${isChannelsLast}) {
                dotProd +=
                    getX(batch, xR, xC, ${inputDepthNearestVec4}) *
                    getW(wR, wC, ${inputDepthNearestVec4}, d2);
              } else {
                dotProd +=
                    getX(batch, ${inputDepthNearestVec4}, xR, xC) *
                    getW(wR, wC, ${inputDepthNearestVec4}, d2);
              }

            } else if (${inputDepthVec4Remainder === 2}) {
              vec2 wValues = vec2(
                getW(wR, wC, ${inputDepthNearestVec4}, d2),
                getW(wR, wC, ${inputDepthNearestVec4} + 1, d2)
              );

              if (${isChannelsLast}) {
                vec2 xValues = vec2(
                  getX(batch, xR, xC, ${inputDepthNearestVec4}),
                  getX(batch, xR, xC, ${inputDepthNearestVec4} + 1)
                );
                dotProd += dot(xValues, wValues);
              } else {
                vec2 xValues = vec2(
                  getX(batch, ${inputDepthNearestVec4}, xR, xC),
                  getX(batch, ${inputDepthNearestVec4} + 1, xR, xC)
                );
                dotProd += dot(xValues, wValues);
              }

            } else if (${inputDepthVec4Remainder === 3}) {
              vec3 wValues = vec3(
                getW(wR, wC, ${inputDepthNearestVec4}, d2),
                getW(wR, wC, ${inputDepthNearestVec4} + 1, d2),
                getW(wR, wC, ${inputDepthNearestVec4} + 2, d2)
              );

              if (${isChannelsLast}) {
                vec3 xValues = vec3(
                  getX(batch, xR, xC, ${inputDepthNearestVec4}),
                  getX(batch, xR, xC, ${inputDepthNearestVec4} + 1),
                  getX(batch, xR, xC, ${inputDepthNearestVec4} + 2)
                );
                dotProd += dot(xValues, wValues);
              } else {
                vec3 xValues = vec3(
                  getX(batch, ${inputDepthNearestVec4}, xR, xC),
                  getX(batch, ${inputDepthNearestVec4} + 1, xR, xC),
                  getX(batch, ${inputDepthNearestVec4} + 2, xR, xC)
                );
                dotProd += dot(xValues, wValues);
              }

            }
          }
        }

        float result = dotProd;
        ${addBiasSnippet}
        ${applyActivationSnippet}
        setOutput(result);
      }
    `;
        }
    }
    /**
     * WebGL program computing a 3D convolution y = conv3d(x, W) over
     * (depth, height, width). Input channels are consumed in vec4 groups with
     * a scalar/vec2/vec3 tail for the remainder. No fused bias/activation.
     */
    class Conv3DProgram {
        constructor(convInfo) {
            this.variableNames = ['x', 'W'];
            this.outputShape = convInfo.outShape;
            const padFront = convInfo.padInfo.front;
            const padTop = convInfo.padInfo.top;
            const padLeft = convInfo.padInfo.left;
            const strideDepth = convInfo.strideDepth;
            const strideHeight = convInfo.strideHeight;
            const strideWidth = convInfo.strideWidth;
            const dilationDepth = convInfo.dilationDepth;
            const dilationHeight = convInfo.dilationHeight;
            const dilationWidth = convInfo.dilationWidth;
            const filterDepth = convInfo.filterDepth;
            const filterHeight = convInfo.filterHeight;
            const filterWidth = convInfo.filterWidth;
            // Largest multiple of 4 <= inChannels; remainder handled by the
            // tail branches below.
            const inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
            const inputDepthVec4Remainder = convInfo.inChannels % 4;
            this.userCode = `
      const ivec3 strides = ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});
      const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});

      void main() {
        ivec5 coords = getOutputCoords();
        int batch = coords.x;
        int d2 = coords.u;

        ivec3 xFRCCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;
        int xFCorner = xFRCCorner.x;
        int xRCorner = xFRCCorner.y;
        int xCCorner = xFRCCorner.z;

        // Convolve x(?, ?, ?, d1) with w(:, :, :, d1, d2) to get
        // y(yF, yR, yC, d2). ? = to be determined. : = across all
        // values in that axis.
        float dotProd = 0.0;
        for (int wF = 0; wF < ${filterDepth}; wF++) {
          int xF = xFCorner + wF * ${dilationDepth};

          if (xF < 0 || xF >= ${convInfo.inDepth}) {
            continue;
          }

          for (int wR = 0; wR < ${filterHeight}; wR++) {
            int xR = xRCorner + wR * ${dilationHeight};

            if (xR < 0 || xR >= ${convInfo.inHeight}) {
              continue;
            }

            for (int wC = 0; wC < ${filterWidth}; wC++) {
              int xC = xCCorner + wC * ${dilationWidth};

              if (xC < 0 || xC >= ${convInfo.inWidth}) {
                continue;
              }

              for (int d1 = 0; d1 < ${inputDepthNearestVec4}; d1 += 4) {
                vec4 xValues = vec4(
                  getX(batch, xF, xR, xC, d1),
                  getX(batch, xF, xR, xC, d1 + 1),
                  getX(batch, xF, xR, xC, d1 + 2),
                  getX(batch, xF, xR, xC, d1 + 3)
                );
                vec4 wValues = vec4(
                  getW(wF, wR, wC, d1, d2),
                  getW(wF, wR, wC, d1 + 1, d2),
                  getW(wF, wR, wC, d1 + 2, d2),
                  getW(wF, wR, wC, d1 + 3, d2)
                );

                dotProd += dot(xValues, wValues);
              }

              if (${inputDepthVec4Remainder === 1}) {
                dotProd +=
                  getX(batch, xF, xR, xC, ${inputDepthNearestVec4}) *
                  getW(wF, wR, wC, ${inputDepthNearestVec4}, d2);
              } else if (${inputDepthVec4Remainder === 2}) {
                vec2 xValues = vec2(
                  getX(batch, xF, xR, xC, ${inputDepthNearestVec4}),
                  getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 1)
                );
                vec2 wValues = vec2(
                  getW(wF, wR, wC, ${inputDepthNearestVec4}, d2),
                  getW(wF, wR, wC, ${inputDepthNearestVec4} + 1, d2)
                );
                dotProd += dot(xValues, wValues);
              } else if (${inputDepthVec4Remainder === 3}) {
                vec3 xValues = vec3(
                  getX(batch, xF, xR, xC, ${inputDepthNearestVec4}),
                  getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 1),
                  getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 2)
                );
                vec3 wValues = vec3(
                  getW(wF, wR, wC, ${inputDepthNearestVec4}, d2),
                  getW(wF, wR, wC, ${inputDepthNearestVec4} + 1, d2),
                  getW(wF, wR, wC, ${inputDepthNearestVec4} + 2, d2)
                );
                dotProd += dot(xValues, wValues);
              }
            }
          }
        }
        setOutput(dotProd);
      }
    `;
        }
    }
89310
89311 /**
89312 * @license
89313 * Copyright 2019 Google LLC. All Rights Reserved.
89314 * Licensed under the Apache License, Version 2.0 (the "License");
89315 * you may not use this file except in compliance with the License.
89316 * You may obtain a copy of the License at
89317 *
89318 * http://www.apache.org/licenses/LICENSE-2.0
89319 *
89320 * Unless required by applicable law or agreed to in writing, software
89321 * distributed under the License is distributed on an "AS IS" BASIS,
89322 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89323 * See the License for the specific language governing permissions and
89324 * limitations under the License.
89325 * =============================================================================
89326 */
    /**
     * Packed WebGL program that rewrites convolution input patches into matrix
     * columns (im2col), so conv2d can then run as a matmul. The four lanes of
     * each 2x2-packed output texel are produced by a 4-way unrolled lookup.
     * Geometry (pad/stride/dilation/etc.) is passed via custom uniforms so the
     * compiled shader can be reused across calls.
     */
    class Im2ColPackedProgram {
        constructor(outputShape, convInfo) {
            this.variableNames = ['A'];
            this.packedInputs = true;
            this.packedOutput = true;
            this.customUniforms = [
                { name: 'inputShape', type: 'ivec3' },
                { name: 'pad', type: 'ivec2' },
                { name: 'stride', type: 'ivec2' },
                { name: 'dilation', type: 'ivec2' },
                { name: 'inChannels', type: 'int' },
                { name: 'itemsPerBlockRow', type: 'int' },
                { name: 'outWidth', type: 'int' },
            ];
            this.outputShape = outputShape;
            this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
            const { dataFormat } = convInfo;
            const glsl = getGlslDifferences();
            // Row/col positions within inputShape depend on the data layout.
            const isChannelsLast = dataFormat === 'channelsLast';
            const rowDim = isChannelsLast ? 0 : 1;
            const colDim = isChannelsLast ? 1 : 2;
            // With shape uniforms enabled, bounds come from the outShape uniform
            // instead of being baked into the shader source.
            const boundsCheckingSnippet = this.enableShapeUniforms ?
                'if(blockIndex < outShape[1] && pos < outShape[0]) {' :
                `if(blockIndex < ${outputShape[1]} && pos < ${outputShape[0]}) {`;
            let unrolled = ``;
            // Unroll the 2x2 packed texel: (row, col) in {0,1}^2 selects the
            // result lane r/g/b/a via index row * 2 + col.
            for (let row = 0; row <= 1; row++) {
                for (let col = 0; col <= 1; col++) {
                    unrolled += `
          blockIndex = rc.y + ${col};
          pos = rc.x + ${row};

          ${boundsCheckingSnippet}
            offsetY = int(blockIndex / outWidth) * stride[0] - pad[0];
            d0 = offsetY + dilation[0] * (pos / itemsPerBlockRow);

            if(d0 < inputShape[${rowDim}] && d0 >= 0) {
              // Use custom imod instead mod. On Intel GPU, mod may generate
              // unexpected value.
              // https://github.com/tensorflow/tfjs/issues/5447
              offsetX = imod(blockIndex, outWidth) * stride[1] - pad[1];
              d1 = offsetX + dilation[1] * (imod(pos, itemsPerBlockRow) /
                  inChannels);

              if(d1 < inputShape[${colDim}] && d1 >= 0) {

                ch = imod(pos, inChannels);

                if (${isChannelsLast}) {
                  innerDims = vec2(d1, ch);
                  result[${row * 2 + col}] = getChannel(
                    getA(d0, int(innerDims.x),
                    int(innerDims.y)), innerDims);
                } else {
                  innerDims = vec2(d0, d1);
                  result[${row * 2 + col}] = getChannel(
                    getA(ch, int(innerDims.x),
                    int(innerDims.y)), innerDims);
                }
              }
            }
          }
        `;
                }
            }
            this.userCode = `
      void main() {
        ivec2 rc = getOutputCoords();

        vec4 result = vec4(0);

        int blockIndex, pos, offsetY, d0, offsetX, d1, ch;
        vec2 innerDims;

        ${unrolled}

        ${glsl.output} = result;
      }
    `;
        }
    }
89407
89408 /**
89409 * @license
89410 * Copyright 2020 Google LLC. All Rights Reserved.
89411 * Licensed under the Apache License, Version 2.0 (the "License");
89412 * you may not use this file except in compliance with the License.
89413 * You may obtain a copy of the License at
89414 *
89415 * http://www.apache.org/licenses/LICENSE-2.0
89416 *
89417 * Unless required by applicable law or agreed to in writing, software
89418 * distributed under the License is distributed on an "AS IS" BASIS,
89419 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89420 * See the License for the specific language governing permissions and
89421 * limitations under the License.
89422 * =============================================================================
89423 */
89424 // For 1x1 kernels that iterate through every point in the input, convolution
89425 // can be expressed as matrix multiplication (without need for memory
89426 // remapping).
    /**
     * Implements a pointwise (1x1 kernel, stride 1, dilation 1) Conv2D as a
     * batched matrix multiply, optionally fused with bias / activation.
     *
     * @param x Rank-4 input tensor info (NHWC or NCHW per convInfo.dataFormat).
     * @param filter Filter tensor info holding inChannels * outChannels weights.
     * @param convInfo Convolution metadata (shapes, data format, padding).
     * @param backend The WebGL backend used to run the programs.
     * @param bias Optional bias tensor info fused into the matmul.
     * @param preluActivationWeights Optional PReLU weights; converted from NCHW
     *     to NHWC here when given as a rank-3 NCHW tensor.
     * @param leakyreluAlpha Alpha used when activation === 'leakyrelu'.
     * @param activation Optional fused activation name.
     * @returns TensorInfo with shape convInfo.outShape.
     */
    function conv2dByMatMul({ x, filter, convInfo, backend, bias = null, preluActivationWeights = null, leakyreluAlpha = 0, activation = null }) {
        // Reshapes conv2D input to 2D tensors, uses matMul and then reshape the
        // result from 2D to 4D.
        const xShape = x.shape;
        const xTexData = backend.texData.get(x.dataId);
        const sharedMatMulDim = convInfo.inChannels;
        const outerShapeX = xShape[0] * xShape[1] * xShape[2];
        const outerShapeFilter = convInfo.outChannels;
        const isChannelsLast = convInfo.dataFormat === 'channelsLast';
        const transposeA = false;
        const transposeB = false;
        let out;
        // Tensors created along the way; all disposed before returning.
        const intermediates = [];
        if (preluActivationWeights != null && !isChannelsLast &&
            preluActivationWeights.shape.length === 3) {
            // If PReLU's activation weights is NCHW format, then convert it to NHWC for
            // the following computation.
            const preluActivationWeightsInNhwcFormat = transpose$2({
                inputs: { x: preluActivationWeights },
                backend,
                attrs: { perm: [1, 2, 0] }
            });
            intermediates.push(preluActivationWeightsInNhwcFormat);
            preluActivationWeights = preluActivationWeightsInNhwcFormat;
        }
        // TODO: Once reduction ops are packed, batchMatMul will always be packed
        // and we can remove this condition.
        const batchMatMulWillBeUnpacked = (outerShapeX === 1 || outerShapeFilter === 1) &&
            sharedMatMulDim > MATMUL_SHARED_DIM_THRESHOLD;
        // The algorithm in the if condition assumes (1) the output will be packed,
        // (2) x is packed, (3) x isChannelsLast, (4) x's packed texture is already
        // on GPU, (5) col is odd, (6) the width, height and inChannels are the same
        // for xTexData.shape and xShape.
        const canOptimize = !batchMatMulWillBeUnpacked && xTexData.isPacked &&
            isChannelsLast && xTexData.texture != null && xShape[2] % 2 !== 0 &&
            arraysEqual(xTexData.shape.slice(-3), xShape.slice(-3));
        if (canOptimize) {
            // We avoid expensive packed 2x2 reshape by padding col count to next,
            // even number. When col is odd, the result of packed batchMatMul is
            // the same (has the same texture layout and and values in the texture) as
            // it is for next even col. We make the odd-cols tensor to look like
            // even-cols tensor before the operation and, after the batchMatMul,
            // fix the even-cols result to have odd number of cols.
            const targetShape = xShape[0] * xShape[1] * (xShape[2] + 1);
            // Alias x's existing GPU data under a padded 2D shape; no copy is made.
            const xReshaped = {
                dataId: x.dataId,
                shape: [1, targetShape, convInfo.inChannels],
                dtype: x.dtype
            };
            // xTexData.shape gets referenced from GPGPUBinary.inShapeInfos.
            // Decrementing col count, after batchMatMul->...->compileProgram leads to
            // invalid col count within the reference in GPGPUBinary.inShapeInfos.
            // Alternative fix would be to provide a copy to GPGPUBinary.inShapeInfos
            // in compileProgram method, but that would affect compilation of all
            // programs - instead, provide a copy here, with even col count, before
            // calling batchMatMul->...->compileProgram and after that, the original
            // xTexData.shape is restored.
            const originalXTexDataShape = xTexData.shape;
            xTexData.shape = xTexData.shape.slice();
            xTexData.shape[xTexData.shape.length - 2]++;
            assert(isReshapeFree(xTexData.shape, xReshaped.shape), () => `packed reshape ${xTexData.shape} to ${xReshaped.shape} isn't free`);
            const filterReshaped = reshape$3({
                inputs: { x: filter },
                backend,
                attrs: { shape: [1, convInfo.inChannels, convInfo.outChannels] }
            });
            intermediates.push(filterReshaped);
            const pointwiseConv = batchMatMulImpl({
                a: xReshaped,
                b: filterReshaped,
                backend,
                transposeA,
                transposeB,
                bias,
                activation,
                preluActivationWeights,
                leakyreluAlpha
            });
            const pointwiseConvTexData = backend.texData.get(pointwiseConv.dataId);
            assert(pointwiseConvTexData.isPacked, () => 'batchMatMul result is expected to be packed');
            // Restore the input shape to original.
            xTexData.shape = originalXTexDataShape;
            // Set the output shape - there is no need for expensive reshape as data
            // layout is already correct.
            pointwiseConvTexData.shape = convInfo.outShape;
            out = identity$2({ inputs: { x: pointwiseConv }, backend });
            out.shape = convInfo.outShape;
            intermediates.push(pointwiseConv);
        }
        else {
            // General path: transpose to NHWC if needed, flatten the spatial dims,
            // run a regular batched matmul, then reshape (and transpose back).
            const xInNhwcFormat = isChannelsLast ?
                x :
                transpose$2({ inputs: { x }, backend, attrs: { perm: [0, 2, 3, 1] } });
            const xInNhwcFormatShape = xInNhwcFormat.shape;
            const targetShape = xInNhwcFormatShape[0] * xInNhwcFormatShape[1] * xInNhwcFormatShape[2];
            const xReshaped = reshape$3({
                inputs: { x: xInNhwcFormat },
                backend,
                attrs: { shape: [1, targetShape, convInfo.inChannels] }
            });
            const filterReshaped = reshape$3({
                inputs: { x: filter },
                backend,
                attrs: { shape: [1, convInfo.inChannels, convInfo.outChannels] }
            });
            const result = batchMatMulImpl({
                a: xReshaped,
                b: filterReshaped,
                transposeA,
                transposeB,
                backend,
                bias,
                activation,
                preluActivationWeights,
                leakyreluAlpha
            });
            const outInNHWCFormatShape = [
                convInfo.batchSize, convInfo.outHeight, convInfo.outWidth,
                convInfo.outChannels
            ];
            const outInNHWCFormat = reshape$3({ inputs: { x: result }, backend, attrs: { shape: outInNHWCFormatShape } });
            // If the data format is NCHW, then convert the output to be NCHW format.
            out = isChannelsLast ? outInNHWCFormat : transpose$2({
                inputs: { x: outInNHWCFormat },
                backend,
                attrs: { perm: [0, 3, 1, 2] }
            });
            if (!isChannelsLast) {
                intermediates.push(xInNhwcFormat);
                intermediates.push(outInNHWCFormat);
            }
            intermediates.push(xReshaped);
            intermediates.push(filterReshaped);
            intermediates.push(result);
        }
        for (const i of intermediates) {
            backend.disposeIntermediateTensorInfo(i);
        }
        return out;
    }
89567 // Implements the im2row algorithm as outlined in "High Performance
89568 // Convolutional Neural Networks for Document Processing" (Suvisoft, 2006)
    /**
     * Conv2D via im2col: unrolls input patches into columns of a matrix so the
     * convolution becomes one packed matrix multiply, optionally fused with
     * bias / activation. Only used for batch size 1 (see the caller's
     * `x.shape[0] === 1` guard and the hard-coded leading 1 in the output shape).
     *
     * @param x Rank-4 input tensor info (batch dimension assumed to be 1).
     * @param filter Filter tensor info.
     * @param convInfo Convolution metadata (shapes, data format, padding).
     * @param backend The WebGL backend used to run the programs.
     * @param bias Optional bias tensor info fused into the matmul.
     * @param preluActivationWeights Optional PReLU weights; converted from NCHW
     *     to NHWC here when given as a rank-3 NCHW tensor.
     * @param leakyreluAlpha Alpha used when activation === 'leakyrelu'.
     * @param activation Optional fused activation name.
     * @returns TensorInfo in NHWC or NCHW layout matching convInfo.dataFormat.
     */
    function conv2dWithIm2Row({ x, filter, convInfo, backend, bias = null, preluActivationWeights = null, leakyreluAlpha = 0, activation = null }) {
        // Rearranges conv2d input so each block to be convolved over forms the
        // column of a new matrix with shape [filterWidth * filterHeight *
        // inChannels, outHeight * outWidth]. The filter is also rearranged so each
        // output channel forms a row of a new matrix with shape [outChannels,
        // filterWidth * filterHeight * inChannels]. The convolution is then
        // computed by multiplying these matrices and reshaping the result.
        const { filterWidth, filterHeight, inChannels, outWidth, outHeight, dataFormat } = convInfo;
        const isChannelsLast = dataFormat === 'channelsLast';
        const sharedDim = filterWidth * filterHeight * inChannels;
        const numCols = outHeight * outWidth;
        const x2ColShape = [sharedDim, numCols];
        // The im2col matrix is laid out [sharedDim, numCols], so A is transposed
        // in the matmul to produce [numCols, outChannels].
        const transposeA = true;
        const transposeB = false;
        // Tensors created along the way; all disposed before returning.
        const intermediates = [];
        if (preluActivationWeights != null && !isChannelsLast &&
            preluActivationWeights.shape.length === 3) {
            // If PReLU's activation weights is NCHW format, then convert it to NHWC for
            // the following computation.
            const preluActivationWeightsInNhwcFormat = transpose$2({
                inputs: { x: preluActivationWeights },
                backend,
                attrs: { perm: [1, 2, 0] }
            });
            intermediates.push(preluActivationWeightsInNhwcFormat);
            preluActivationWeights = preluActivationWeightsInNhwcFormat;
        }
        // Drop the leading batch dimension (assumed to be 1 on this path).
        const xSqueezed = reshape$3({ inputs: { x }, backend, attrs: { shape: x.shape.slice(1) } });
        const w2Row = reshape$3({
            inputs: { x: filter },
            backend,
            attrs: { shape: [1, sharedDim, sizeFromShape(filter.shape) / sharedDim] }
        });
        intermediates.push(xSqueezed);
        intermediates.push(w2Row);
        const im2ColProgram = new Im2ColPackedProgram(x2ColShape, convInfo);
        // Uniform values consumed by Im2ColPackedProgram; order must match the
        // program's uniform declarations.
        const customValues = [
            xSqueezed.shape, [convInfo.padInfo.top, convInfo.padInfo.left],
            [convInfo.strideHeight, convInfo.strideWidth],
            [convInfo.dilationHeight, convInfo.dilationWidth], [convInfo.inChannels],
            [convInfo.filterWidth * convInfo.inChannels], [convInfo.outWidth]
        ];
        const im2Col = backend.runWebGLProgram(im2ColProgram, [xSqueezed], 'float32', customValues);
        const im2ColReshaped = reshape$3({
            inputs: { x: im2Col },
            backend,
            attrs: { shape: [1, x2ColShape[0], x2ColShape[1]] }
        });
        intermediates.push(im2Col);
        intermediates.push(im2ColReshaped);
        const hasBias = bias != null;
        const hasPreluActivationWeights = preluActivationWeights != null;
        const hasLeakyreluAlpha = activation === 'leakyrelu';
        const fusedActivation = activation ? mapActivationToShaderProgram(activation, true) : null;
        const matmulProgram = new MatMulPackedProgram(im2ColReshaped.shape, w2Row.shape, [1, numCols, convInfo.outChannels], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
        const inputs = [im2ColReshaped, w2Row];
        if (bias) {
            inputs.push(bias);
        }
        if (hasPreluActivationWeights) {
            inputs.push(preluActivationWeights);
        }
        if (hasLeakyreluAlpha) {
            // leakyrelu's alpha is passed to the program as an extra scalar input.
            const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
            inputs.push($leakyreluAlpha);
            intermediates.push($leakyreluAlpha);
        }
        const product = backend.runWebGLProgram(matmulProgram, inputs, 'float32');
        const outInNHWCFormatShape = [1, outHeight, outWidth, convInfo.outChannels];
        const outInNHWCFormat = reshape$3({ inputs: { x: product }, backend, attrs: { shape: outInNHWCFormatShape } });
        // If the data format is NCHW, then convert the output to be NCHW format.
        const out = isChannelsLast ?
            outInNHWCFormat :
            transpose$2({ inputs: { x: outInNHWCFormat }, backend, attrs: { perm: [0, 3, 1, 2] } });
        if (!isChannelsLast) {
            intermediates.push(outInNHWCFormat);
        }
        intermediates.push(product);
        for (const i of intermediates) {
            backend.disposeIntermediateTensorInfo(i);
        }
        return out;
    }
89652
89653 /**
89654 * @license
89655 * Copyright 2020 Google LLC. All Rights Reserved.
89656 * Licensed under the Apache License, Version 2.0 (the "License");
89657 * you may not use this file except in compliance with the License.
89658 * You may obtain a copy of the License at
89659 *
89660 * http://www.apache.org/licenses/LICENSE-2.0
89661 *
89662 * Unless required by applicable law or agreed to in writing, software
89663 * distributed under the License is distributed on an "AS IS" BASIS,
89664 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89665 * See the License for the specific language governing permissions and
89666 * limitations under the License.
89667 * =============================================================================
89668 */
89669 function conv2d$4(args) {
89670 const { inputs, backend, attrs } = args;
89671 const { x, filter } = inputs;
89672 const { strides, pad, dataFormat, dilations, dimRoundingMode } = attrs;
89673 const $dataFormat = convertConv2DDataFormat(dataFormat);
89674 const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
89675 let out;
89676 if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
89677 convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
89678 convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
89679 (convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID')) {
89680 out = conv2dByMatMul({ x, filter, convInfo, backend });
89681 }
89682 else if (env().getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) {
89683 out = conv2dWithIm2Row({ x, filter, convInfo, backend });
89684 }
89685 else {
89686 const program = new Conv2DProgram(convInfo);
89687 out = backend.runWebGLProgram(program, [x, filter], 'float32');
89688 }
89689 const outReshaped = reshape$3({ inputs: { x: out }, backend, attrs: { shape: convInfo.outShape } });
89690 backend.disposeIntermediateTensorInfo(out);
89691 return outReshaped;
89692 }
    /** Kernel registration entry binding the Conv2D op to the 'webgl' backend. */
    const conv2DConfig$1 = {
        kernelName: Conv2D,
        backendName: 'webgl',
        kernelFunc: conv2d$4,
    };
89698
89699 /**
89700 * @license
89701 * Copyright 2017 Google LLC. All Rights Reserved.
89702 * Licensed under the Apache License, Version 2.0 (the "License");
89703 * you may not use this file except in compliance with the License.
89704 * You may obtain a copy of the License at
89705 *
89706 * http://www.apache.org/licenses/LICENSE-2.0
89707 *
89708 * Unless required by applicable law or agreed to in writing, software
89709 * distributed under the License is distributed on an "AS IS" BASIS,
89710 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89711 * See the License for the specific language governing permissions and
89712 * limitations under the License.
89713 * =============================================================================
89714 */
    /**
     * GPGPU program computing the gradient of Conv2D with respect to the
     * filter (dW). Output shape is convInfo.filterShape; each element
     * dw(wR, wC, d1, d2) accumulates x * dy products over the batch and all
     * output spatial positions, honoring stride and padding. Supports both
     * NHWC and NCHW via the compile-time `isChannelsLast` branch baked into
     * the shader source.
     */
    class Conv2DDerFilterProgram {
        constructor(convInfo) {
            this.variableNames = ['x', 'dy'];
            this.outputShape = convInfo.filterShape;
            const strideHeight = convInfo.strideHeight;
            const strideWidth = convInfo.strideWidth;
            const padTop = convInfo.padInfo.top;
            const padLeft = convInfo.padInfo.left;
            const isChannelsLast = convInfo.dataFormat === 'channelsLast';
            // All loop bounds and offsets are interpolated as compile-time
            // constants, so a program is specialized per convolution config.
            this.userCode = `
      void main() {
        ivec4 coords = getOutputCoords();
        int wR = coords.x;
        int wC = coords.y;
        int d1 = coords.z;
        int d2 = coords.w;

        // Convolve x(?, ?, d1) with dy(:, :, d2) to get dw(wR, wC, d1, d2).
        // ? = to be determined. : = across all values in that axis.
        float dotProd = 0.0;

        for (int b = 0; b < ${convInfo.batchSize}; b++) {
          for (int yR = 0; yR < ${convInfo.outHeight}; yR++) {
            int xR = wR + yR * ${strideHeight} - ${padTop};

            if (xR < 0 || xR >= ${convInfo.inHeight}) {
              continue;
            }

            for (int yC = 0; yC < ${convInfo.outWidth}; yC++) {
              int xC = wC + yC * ${strideWidth} - ${padLeft};

              if (xC < 0 || xC >= ${convInfo.inWidth}) {
                continue;
              }

              if (${isChannelsLast}) {
                float dyValue = getDy(b, yR, yC, d2);
                float xValue = getX(b, xR, xC, d1);
                dotProd += (xValue * dyValue);
              } else {
                float dyValue = getDy(b, d2, yR, yC);
                float xValue = getX(b, d1, xR, xC);
                dotProd += (xValue * dyValue);
              }

            }
          }
        }
        setOutput(dotProd);
      }
    `;
        }
    }
    /**
     * GPGPU program computing the gradient of Conv2D with respect to the
     * input (dx). Output shape is convInfo.inShape. Implemented as a
     * transposed convolution: dy is convolved with the spatially flipped
     * filter (wRPerm/wCPerm below), using "flipped" pads (filter - 1 - pad).
     * The fract() checks skip dy positions that do not align with the stride.
     */
    class Conv2DDerInputProgram {
        constructor(convInfo) {
            this.variableNames = ['dy', 'W'];
            this.outputShape = convInfo.inShape;
            const filterHeight = convInfo.filterHeight;
            const filterWidth = convInfo.filterWidth;
            const strideHeight = convInfo.strideHeight;
            const strideWidth = convInfo.strideWidth;
            const isChannelsLast = convInfo.dataFormat === 'channelsLast';
            // Transposed-convolution padding: effective pad = filter - 1 - pad.
            const padTop = filterHeight - 1 - convInfo.padInfo.top;
            const padLeft = filterWidth - 1 - convInfo.padInfo.left;
            // Coordinate positions of the row/col/channel dims for NHWC vs NCHW.
            const rowDim = isChannelsLast ? 1 : 2;
            const colDim = isChannelsLast ? 2 : 3;
            const channelDim = isChannelsLast ? 3 : 1;
            this.userCode = `
      const ivec2 pads = ivec2(${padTop}, ${padLeft});

      void main() {
        ivec4 coords = getOutputCoords();
        int batch = coords[0];
        int d1 = coords[${channelDim}];

        ivec2 dyCorner = ivec2(coords[${rowDim}], coords[${colDim}]) - pads;
        int dyRCorner = dyCorner.x;
        int dyCCorner = dyCorner.y;

        // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
        // ? = to be determined. : = across all values in that axis.
        float dotProd = 0.0;
        for (int wR = 0; wR < ${filterHeight}; wR++) {
          float dyR = float(dyRCorner + wR) / ${strideHeight}.0;

          if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
            continue;
          }
          int idyR = int(dyR);

          int wRPerm = ${filterHeight} - 1 - wR;

          for (int wC = 0; wC < ${filterWidth}; wC++) {
            float dyC = float(dyCCorner + wC) / ${strideWidth}.0;

            if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
                fract(dyC) > 0.0) {
              continue;
            }
            int idyC = int(dyC);

            int wCPerm = ${filterWidth} - 1 - wC;

            for (int d2 = 0; d2 < ${convInfo.outChannels}; d2++) {

              if (${isChannelsLast}) {
                float xValue = getDy(batch, idyR, idyC, d2);
                float wValue = getW(wRPerm, wCPerm, d1, d2);
                dotProd += xValue * wValue;
              } else {
                float xValue = getDy(batch, d2, idyR, idyC);
                float wValue = getW(wRPerm, wCPerm, d1, d2);
                dotProd += xValue * wValue;
              }

            }
          }
        }
        setOutput(dotProd);
      }
    `;
        }
    }
    /**
     * GPGPU program computing the gradient of Conv3D with respect to the
     * filter (dW). Output shape is convInfo.filterShape. Analogous to
     * Conv2DDerFilterProgram but with an extra depth dimension; the 5D output
     * coordinate is read via `ivec5`/`.u`, which are tfjs shader-compiler
     * extensions for rank-5 coords (NDHWC layout only on this path).
     */
    class Conv3DDerFilterProgram {
        constructor(convInfo) {
            this.variableNames = ['x', 'dy'];
            this.outputShape = convInfo.filterShape;
            const strideDepth = convInfo.strideDepth;
            const strideHeight = convInfo.strideHeight;
            const strideWidth = convInfo.strideWidth;
            const padFront = convInfo.padInfo.front;
            const padTop = convInfo.padInfo.top;
            const padLeft = convInfo.padInfo.left;
            this.userCode = `
      void main() {
        ivec5 coords = getOutputCoords();
        int wF = coords.x;
        int wR = coords.y;
        int wC = coords.z;
        int d1 = coords.w;
        int d2 = coords.u;

        float dotProd = 0.0;

        for (int b = 0; b < ${convInfo.batchSize}; b++) {
          for (int yF = 0; yF < ${convInfo.outDepth}; yF++) {
            int xF = wF + yF * ${strideDepth} - ${padFront};

            if (xF < 0 || xF >= ${convInfo.inDepth}) {
              continue;
            }

            for (int yR = 0; yR < ${convInfo.outHeight}; yR++) {
              int xR = wR + yR * ${strideHeight} - ${padTop};

              if (xR < 0 || xR >= ${convInfo.inHeight}) {
                continue;
              }

              for (int yC = 0; yC < ${convInfo.outWidth}; yC++) {
                int xC = wC + yC * ${strideWidth} - ${padLeft};

                if (xC < 0 || xC >= ${convInfo.inWidth}) {
                  continue;
                }

                float dyValue = getDy(b, yF, yR, yC, d2);
                float xValue = getX(b, xF, xR, xC, d1);
                dotProd += (xValue * dyValue);
              }
            }
          }
        }
        setOutput(dotProd);
      }
    `;
        }
    }
    /**
     * GPGPU program computing the gradient of Conv3D with respect to the
     * input (dx). Output shape is convInfo.inShape. Implemented as a
     * transposed 3D convolution of dy with the flipped filter
     * (wFPerm/wRPerm/wCPerm), using "flipped" pads (filter - 1 - pad).
     * The fract() checks skip dy positions that do not align with the stride.
     */
    class Conv3DDerInputProgram {
        constructor(convInfo) {
            this.variableNames = ['dy', 'W'];
            this.outputShape = convInfo.inShape;
            const filterDepth = convInfo.filterDepth;
            const filterHeight = convInfo.filterHeight;
            const filterWidth = convInfo.filterWidth;
            const strideDepth = convInfo.strideDepth;
            const strideHeight = convInfo.strideHeight;
            const strideWidth = convInfo.strideWidth;
            // Transposed-convolution padding: effective pad = filter - 1 - pad.
            const padFront = filterDepth - 1 - convInfo.padInfo.front;
            const padTop = filterHeight - 1 - convInfo.padInfo.top;
            const padLeft = filterWidth - 1 - convInfo.padInfo.left;
            this.userCode = `
      const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});

      void main() {
        ivec5 coords = getOutputCoords();
        int batch = coords.x;
        int d1 = coords.u;


        ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;
        int dyFCorner = dyCorner.x;
        int dyRCorner = dyCorner.y;
        int dyCCorner = dyCorner.z;

        float dotProd = 0.0;
        for (int wF = 0; wF < ${filterDepth}; wF++) {
          float dyF = float(dyFCorner + wF) / ${strideDepth}.0;

          if (dyF < 0.0 || dyF >= ${convInfo.outDepth}.0 || fract(dyF) > 0.0) {
            continue;
          }
          int idyF = int(dyF);

          int wFPerm = ${filterDepth} - 1 - wF;

          for (int wR = 0; wR < ${filterHeight}; wR++) {
            float dyR = float(dyRCorner + wR) / ${strideHeight}.0;

            if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 ||
              fract(dyR) > 0.0) {
              continue;
            }
            int idyR = int(dyR);

            int wRPerm = ${filterHeight} - 1 - wR;

            for (int wC = 0; wC < ${filterWidth}; wC++) {
              float dyC = float(dyCCorner + wC) / ${strideWidth}.0;

              if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
                  fract(dyC) > 0.0) {
                continue;
              }
              int idyC = int(dyC);

              int wCPerm = ${filterWidth} - 1 - wC;

              for (int d2 = 0; d2 < ${convInfo.outChannels}; d2++) {
                float xValue = getDy(batch, idyF, idyR, idyC, d2);
                float wValue = getW(wFPerm, wRPerm, wCPerm, d1, d2);
                dotProd += xValue * wValue;
              }
            }
          }
        }
        setOutput(dotProd);
      }
    `;
        }
    }
89967
89968 /**
89969 * @license
89970 * Copyright 2020 Google LLC. All Rights Reserved.
89971 * Licensed under the Apache License, Version 2.0 (the "License");
89972 * you may not use this file except in compliance with the License.
89973 * You may obtain a copy of the License at
89974 *
89975 * http://www.apache.org/licenses/LICENSE-2.0
89976 *
89977 * Unless required by applicable law or agreed to in writing, software
89978 * distributed under the License is distributed on an "AS IS" BASIS,
89979 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89980 * See the License for the specific language governing permissions and
89981 * limitations under the License.
89982 * =============================================================================
89983 */
89984 function conv2DBackpropFilter$2(args) {
89985 const { inputs, backend, attrs } = args;
89986 const { x, dy } = inputs;
89987 const { strides, pad, dataFormat, dimRoundingMode, filterShape } = attrs;
89988 const $dataFormat = convertConv2DDataFormat(dataFormat);
89989 const convInfo = computeConv2DInfo(x.shape, filterShape, strides, 1 /* dilations */, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
89990 const program = new Conv2DDerFilterProgram(convInfo);
89991 return backend.runWebGLProgram(program, [x, dy], 'float32');
89992 }
    /** Kernel registration entry binding Conv2DBackpropFilter to the 'webgl' backend. */
    const conv2DBackpropFilterConfig$1 = {
        kernelName: Conv2DBackpropFilter,
        backendName: 'webgl',
        kernelFunc: conv2DBackpropFilter$2,
    };
89998
89999 /**
90000 * @license
90001 * Copyright 2020 Google LLC. All Rights Reserved.
90002 * Licensed under the Apache License, Version 2.0 (the "License");
90003 * you may not use this file except in compliance with the License.
90004 * You may obtain a copy of the License at
90005 *
90006 * http://www.apache.org/licenses/LICENSE-2.0
90007 *
90008 * Unless required by applicable law or agreed to in writing, software
90009 * distributed under the License is distributed on an "AS IS" BASIS,
90010 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90011 * See the License for the specific language governing permissions and
90012 * limitations under the License.
90013 * =============================================================================
90014 */
90015 function conv2DBackpropInput$2(args) {
90016 const { inputs, backend, attrs } = args;
90017 const { dy, filter } = inputs;
90018 const { inputShape, strides, pad, dataFormat, dimRoundingMode } = attrs;
90019 const $dataFormat = convertConv2DDataFormat(dataFormat);
90020 const convInfo = computeConv2DInfo(inputShape, filter.shape, strides, 1 /* dilations */, pad, dimRoundingMode, false, $dataFormat);
90021 const program = new Conv2DDerInputProgram(convInfo);
90022 return backend.runWebGLProgram(program, [dy, filter], 'float32');
90023 }
    /** Kernel registration entry binding Conv2DBackpropInput to the 'webgl' backend. */
    const conv2DBackpropInputConfig$1 = {
        kernelName: Conv2DBackpropInput,
        backendName: 'webgl',
        kernelFunc: conv2DBackpropInput$2,
    };
90029
90030 /**
90031 * @license
90032 * Copyright 2020 Google LLC. All Rights Reserved.
90033 * Licensed under the Apache License, Version 2.0 (the "License");
90034 * you may not use this file except in compliance with the License.
90035 * You may obtain a copy of the License at
90036 *
90037 * http://www.apache.org/licenses/LICENSE-2.0
90038 *
90039 * Unless required by applicable law or agreed to in writing, software
90040 * distributed under the License is distributed on an "AS IS" BASIS,
90041 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90042 * See the License for the specific language governing permissions and
90043 * limitations under the License.
90044 * =============================================================================
90045 */
90046 function conv3D$1(args) {
90047 const { inputs, backend, attrs } = args;
90048 const { x, filter } = inputs;
90049 const { strides, pad, dilations } = attrs;
90050 const convInfo = computeConv3DInfo(x.shape, filter.shape, strides, dilations, pad);
90051 const program = new Conv3DProgram(convInfo);
90052 return backend.runWebGLProgram(program, [x, filter], 'float32');
90053 }
    /** Kernel registration entry binding Conv3D to the 'webgl' backend. */
    const conv3DConfig$1 = {
        kernelName: Conv3D,
        backendName: 'webgl',
        kernelFunc: conv3D$1,
    };
90059
90060 /**
90061 * @license
90062 * Copyright 2020 Google LLC. All Rights Reserved.
90063 * Licensed under the Apache License, Version 2.0 (the "License");
90064 * you may not use this file except in compliance with the License.
90065 * You may obtain a copy of the License at
90066 *
90067 * http://www.apache.org/licenses/LICENSE-2.0
90068 *
90069 * Unless required by applicable law or agreed to in writing, software
90070 * distributed under the License is distributed on an "AS IS" BASIS,
90071 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90072 * See the License for the specific language governing permissions and
90073 * limitations under the License.
90074 * =============================================================================
90075 */
90076 function conv3DBackpropFilterV2$1(args) {
90077 const { inputs, backend, attrs } = args;
90078 const { x, dy } = inputs;
90079 const { strides, pad, filterShape } = attrs;
90080 const convInfo = computeConv3DInfo(x.shape, filterShape, strides, 1 /* dilations */, pad);
90081 const program = new Conv3DDerFilterProgram(convInfo);
90082 return backend.runWebGLProgram(program, [x, dy], 'float32');
90083 }
    /** Kernel registration entry binding Conv3DBackpropFilterV2 to the 'webgl' backend. */
    const conv3DBackpropFilterV2Config$1 = {
        kernelName: Conv3DBackpropFilterV2,
        backendName: 'webgl',
        kernelFunc: conv3DBackpropFilterV2$1
    };
90089
90090 /**
90091 * @license
90092 * Copyright 2020 Google LLC. All Rights Reserved.
90093 * Licensed under the Apache License, Version 2.0 (the "License");
90094 * you may not use this file except in compliance with the License.
90095 * You may obtain a copy of the License at
90096 *
90097 * http://www.apache.org/licenses/LICENSE-2.0
90098 *
90099 * Unless required by applicable law or agreed to in writing, software
90100 * distributed under the License is distributed on an "AS IS" BASIS,
90101 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90102 * See the License for the specific language governing permissions and
90103 * limitations under the License.
90104 * =============================================================================
90105 */
90106 function conv3DBackpropInput$1(args) {
90107 const { inputs, backend, attrs } = args;
90108 const { dy, filter } = inputs;
90109 const { pad, strides, inputShape } = attrs;
90110 const convInfo = computeConv3DInfo(inputShape, filter.shape, strides, 1 /* dilations */, pad);
90111 const program = new Conv3DDerInputProgram(convInfo);
90112 return backend.runWebGLProgram(program, [dy, filter], 'float32');
90113 }
    /** Kernel registration entry binding Conv3DBackpropInputV2 to the 'webgl' backend. */
    const conv3DBackpropInputConfig = {
        kernelName: Conv3DBackpropInputV2,
        backendName: 'webgl',
        kernelFunc: conv3DBackpropInput$1,
    };
90119
90120 /**
90121 * @license
90122 * Copyright 2020 Google LLC. All Rights Reserved.
90123 * Licensed under the Apache License, Version 2.0 (the "License");
90124 * you may not use this file except in compliance with the License.
90125 * You may obtain a copy of the License at
90126 *
90127 * http://www.apache.org/licenses/LICENSE-2.0
90128 *
90129 * Unless required by applicable law or agreed to in writing, software
90130 * distributed under the License is distributed on an "AS IS" BASIS,
90131 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90132 * See the License for the specific language governing permissions and
90133 * limitations under the License.
90134 * =============================================================================
90135 */
    // GLSL body for elementwise cosine. CHECK_NAN_SNIPPET_UNARY is prepended so
    // NaN inputs are propagated before cos(x) is evaluated.
    const COS = CHECK_NAN_SNIPPET_UNARY + `
  return cos(x);
`;
    const cos$2 = unaryKernelFunc$1({ opSnippet: COS });
    /** Kernel registration entry binding Cos to the 'webgl' backend. */
    const cosConfig$1 = {
        kernelName: Cos,
        backendName: 'webgl',
        kernelFunc: cos$2,
    };
90145
90146 /**
90147 * @license
90148 * Copyright 2020 Google LLC. All Rights Reserved.
90149 * Licensed under the Apache License, Version 2.0 (the "License");
90150 * you may not use this file except in compliance with the License.
90151 * You may obtain a copy of the License at
90152 *
90153 * http://www.apache.org/licenses/LICENSE-2.0
90154 *
90155 * Unless required by applicable law or agreed to in writing, software
90156 * distributed under the License is distributed on an "AS IS" BASIS,
90157 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90158 * See the License for the specific language governing permissions and
90159 * limitations under the License.
90160 * =============================================================================
90161 */
    // GLSL body for elementwise hyperbolic cosine. Note: despite its name,
    // `e2x` holds exp(-x), so (e2x + 1.0 / e2x) / 2.0 equals
    // (exp(-x) + exp(x)) / 2 == cosh(x).
    const COSH = `
  float e2x = exp(-x);
  return (e2x + 1.0 / e2x) / 2.0;
`;
    const cosh$2 = unaryKernelFunc$1({ opSnippet: COSH });
    /** Kernel registration entry binding Cosh to the 'webgl' backend. */
    const coshConfig$1 = {
        kernelName: Cosh,
        backendName: 'webgl',
        kernelFunc: cosh$2,
    };
90172
90173 /**
90174 * @license
90175 * Copyright 2017 Google LLC. All Rights Reserved.
90176 * Licensed under the Apache License, Version 2.0 (the "License");
90177 * you may not use this file except in compliance with the License.
90178 * You may obtain a copy of the License at
90179 *
90180 * http://www.apache.org/licenses/LICENSE-2.0
90181 *
90182 * Unless required by applicable law or agreed to in writing, software
90183 * distributed under the License is distributed on an "AS IS" BASIS,
90184 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90185 * See the License for the specific language governing permissions and
90186 * limitations under the License.
90187 * =============================================================================
90188 */
    /**
     * GPGPU program implementing CropAndResize: for each box, samples a
     * cropHeight x cropWidth x depth crop from the batch image selected by
     * BoxInd, using bilinear interpolation or nearest-neighbor sampling
     * depending on `method`. Box coordinates [y1, x1, y2, x2] are treated as
     * normalized (they are scaled by imageHeight-1 / imageWidth-1 below);
     * samples that fall outside the image emit `extrapolationValue`.
     *
     * Output shape: [numBoxes, cropHeight, cropWidth, depth].
     */
    class CropAndResizeProgram {
        constructor(imageShape, boxShape, cropSize, method, extrapolationValue) {
            this.variableNames = ['Image', 'Boxes', 'BoxInd'];
            this.outputShape = [];
            const [batch, imageHeight, imageWidth, depth] = imageShape;
            const [numBoxes,] = boxShape;
            const [cropHeight, cropWidth] = cropSize;
            this.outputShape = [numBoxes, cropHeight, cropWidth, depth];
            // Baked into the shader as a compile-time constant: 1 = bilinear,
            // 0 = nearest neighbor.
            const methodId = method === 'bilinear' ? 1 : 0;
            // Last valid row/col index of the source image, as GLSL float
            // literals.
            const [inputHeightFloat, inputWidthFloat] = [`${imageHeight - 1}.0`, `${imageWidth - 1}.0`];
            // When the crop has more than one row, output rows are mapped
            // linearly across the box height; a single-row crop samples the
            // vertical center of the box. Same logic for columns below.
            const [heightRatio, heightScale, inY] = cropHeight > 1 ?
                [
                    `${(imageHeight - 1) / (cropHeight - 1)}`,
                    '(y2-y1) * height_ratio',
                    `y1*${inputHeightFloat} + float(y)*(height_scale)`,
                ] :
                [
                    '0.0',
                    '0.0',
                    `0.5 * (y1+y2) * ${inputHeightFloat}`,
                ];
            const [widthRatio, widthScale, inX] = cropWidth > 1 ?
                [
                    `${(imageWidth - 1) / (cropWidth - 1)}`,
                    '(x2-x1) * width_ratio',
                    `x1*${inputWidthFloat} + float(x)*(width_scale)`,
                ] :
                [
                    '0.0',
                    '0.0',
                    `0.5 * (x1+x2) * ${inputWidthFloat}`,
                ];
            // Reference implementation
            // tslint:disable-next-line:max-line-length
            // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
            this.userCode = `
      const float height_ratio = float(${heightRatio});
      const float width_ratio = float(${widthRatio});
      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int y = coords[1];
        int x = coords[2];
        int d = coords[3];

        // get box vals
        float y1 = getBoxes(b,0);
        float x1 = getBoxes(b,1);
        float y2 = getBoxes(b,2);
        float x2 = getBoxes(b,3);

        // get image in batch index
        int bInd = round(getBoxInd(b));
        if(bInd < 0 || bInd >= ${batch}) {
          return;
        }

        float height_scale = ${heightScale};
        float width_scale = ${widthScale};

        float in_y = ${inY};
        if( in_y < 0.0 || in_y > ${inputHeightFloat} ) {
          setOutput(float(${extrapolationValue}));
          return;
        }
        float in_x = ${inX};
        if( in_x < 0.0 || in_x > ${inputWidthFloat} ) {
          setOutput(float(${extrapolationValue}));
          return;
        }

        vec2 sourceFracIndexCR = vec2(in_x,in_y);
        if(${methodId} == 1) {
          // Compute the four integer indices.
          ivec2 sourceFloorCR = ivec2(sourceFracIndexCR);
          ivec2 sourceCeilCR = ivec2(ceil(sourceFracIndexCR));

          float topLeft = getImage(b, sourceFloorCR.y, sourceFloorCR.x, d);
          float bottomLeft = getImage(b, sourceCeilCR.y, sourceFloorCR.x, d);
          float topRight = getImage(b, sourceFloorCR.y, sourceCeilCR.x, d);
          float bottomRight = getImage(b, sourceCeilCR.y, sourceCeilCR.x, d);

          vec2 fracCR = sourceFracIndexCR - vec2(sourceFloorCR);

          float top = topLeft + (topRight - topLeft) * fracCR.x;
          float bottom = bottomLeft + (bottomRight - bottomLeft) * fracCR.x;
          float newValue = top + (bottom - top) * fracCR.y;
          setOutput(newValue);
        } else {
          // Compute the coordinators of nearest neighbor point.
          ivec2 sourceNearestCR = ivec2(floor(
            sourceFracIndexCR + vec2(0.5,0.5)));
          float newValue = getImage(b, sourceNearestCR.y, sourceNearestCR.x, d);
          setOutput(newValue);
        }
      }
    `;
        }
    }
90288
90289 /**
90290 * @license
90291 * Copyright 2020 Google LLC. All Rights Reserved.
90292 * Licensed under the Apache License, Version 2.0 (the "License");
90293 * you may not use this file except in compliance with the License.
90294 * You may obtain a copy of the License at
90295 *
90296 * http://www.apache.org/licenses/LICENSE-2.0
90297 *
90298 * Unless required by applicable law or agreed to in writing, software
90299 * distributed under the License is distributed on an "AS IS" BASIS,
90300 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90301 * See the License for the specific language governing permissions and
90302 * limitations under the License.
90303 * =============================================================================
90304 */
90305 const cropAndResize$2 = (args) => {
90306 const { inputs, backend, attrs } = args;
90307 const { image, boxes, boxInd } = inputs;
90308 const { cropSize, method, extrapolationValue } = attrs;
90309 const program = new CropAndResizeProgram(image.shape, boxes.shape, cropSize, method, extrapolationValue);
90310 return backend.runWebGLProgram(program, [image, boxes, boxInd], 'float32');
90311 };
90312 const cropAndResizeConfig$1 = {
90313 kernelName: CropAndResize,
90314 backendName: 'webgl',
90315 kernelFunc: cropAndResize$2
90316 };
90317
90318 var CumOpType;
90319 (function (CumOpType) {
90320 CumOpType["Prod"] = "*";
90321 CumOpType["Sum"] = "+";
90322 })(CumOpType || (CumOpType = {}));
    /**
     * GPGPU program implementing one pass of a cumulative sum/product scan
     * along the inner-most axis.
     *
     * Non-exclusive mode: one step of a parallel prefix scan — each element
     * combines itself with the element `pow2 = 2^index` positions away
     * (`index` is a float uniform set per pass by the caller); the caller is
     * expected to run log2(length) passes.
     * Exclusive mode: a single shift/roll pass that copies the neighboring
     * value in the direction given by `reverse`, using the op's identity
     * (1.0 for Prod, 0.0 for Sum) at the boundary.
     */
    class CumProgram {
        constructor(op, outputShape, exclusive, reverse) {
            this.op = op;
            this.outputShape = outputShape;
            this.variableNames = ['x'];
            // `index` holds the scan step number; pow2 = 2^index below.
            this.customUniforms = [{ name: 'index', type: 'float' }];
            const rank = this.outputShape.length;
            // Identity element of the operator, used when `exclusive` shifts
            // past the boundary.
            const initVal = this.op === CumOpType.Prod ? '1.0' : '0.0';
            const val = exclusive ? initVal : `getX(${getCoords$1(rank, 'coords', this.op)})`;
            // The scan runs along the inner-most dimension.
            const length = this.outputShape[this.outputShape.length - 1];
            let condition = '';
            let idxString = '';
            // When exclusive is set, the cum op becomes roll op that copies the
            // value from the previous index based on the direction specified by the
            // reverse flag.
            if (exclusive) {
                condition = reverse ? `end != ${length - 1}` : 'end != 0';
                idxString = reverse ? 'end + 1' : 'end - 1';
            }
            else {
                condition = reverse ? `end + pow2 < ${length}` : 'end >= pow2';
                idxString = (reverse ? 'end + pow2' : 'end - pow2');
            }
            this.userCode = `
      void main() {
        ${getCoordsDataType(rank)} coords = getOutputCoords();
        int end = ${getFinalCoord(rank, 'coords', this.op)};
        float val = ${val};
        int pow2 = int(pow(2.0, index));
        if (${condition}) {
          int idx = ${idxString};
          ${getFinalCoord(rank, 'coords', this.op)} = idx;
          val ${this.op}= getX(${getCoords$1(rank, 'coords', this.op)});
        }
        setOutput(val);
      }
    `;
        }
    }
90362 function getCoords$1(rank, name, op) {
90363 if (rank === 1) {
90364 return `${name}`;
90365 }
90366 else if (rank === 2) {
90367 return `${name}.x, ${name}.y`;
90368 }
90369 else if (rank === 3) {
90370 return `${name}.x, ${name}.y, ${name}.z`;
90371 }
90372 else if (rank === 4) {
90373 return `${name}.x, ${name}.y, ${name}.z, ${name}.w`;
90374 }
90375 else {
90376 throw new Error(`Cumulative ${op} for rank ${rank} is not yet supported`);
90377 }
90378 }
90379 function getFinalCoord(rank, name, op) {
90380 if (rank === 1) {
90381 return `${name}`;
90382 }
90383 else if (rank === 2) {
90384 return `${name}.y`;
90385 }
90386 else if (rank === 3) {
90387 return `${name}.z`;
90388 }
90389 else if (rank === 4) {
90390 return `${name}.w`;
90391 }
90392 else {
90393 throw new Error(`Cumulative ${op} for rank ${rank} is not yet supported`);
90394 }
90395 }
90396
90397 /**
90398 * @license
90399 * Copyright 2022 Google LLC. All Rights Reserved.
90400 * Licensed under the Apache License, Version 2.0 (the "License");
90401 * you may not use this file except in compliance with the License.
90402 * You may obtain a copy of the License at
90403 *
90404 * http://www.apache.org/licenses/LICENSE-2.0
90405 *
90406 * Unless required by applicable law or agreed to in writing, software
90407 * distributed under the License is distributed on an "AS IS" BASIS,
90408 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90409 * See the License for the specific language governing permissions and
90410 * limitations under the License.
90411 * =============================================================================
90412 */
90413 function cumImpl(op, x, backend, axis, exclusive, reverse) {
90414 const xRank = x.shape.length;
90415 const permutation = getAxesPermutation([axis], xRank);
90416 let permutedX = x;
90417 if (permutation != null) {
90418 permutedX = transpose$2({ inputs: { x }, backend, attrs: { perm: permutation } });
90419 }
90420 const permutedAxis = getInnerMostAxes(1, xRank)[0];
90421 if (permutedAxis !== xRank - 1) {
90422 throw new Error(`WebGL cumprod shader expects an inner-most axis=${x.shape.length - 1} ` +
90423 `but got axis=${axis}`);
90424 }
90425 const size = permutedX.shape[permutedAxis];
90426 let result = identity$2({ inputs: { x: permutedX }, backend });
90427 // Use cum parallel algorithm, inspired by:
90428 // https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
90429 // Note: although the algorithm is called sum, it works for any associtative
90430 // operator with an identity.
90431 for (let i = 0; i <= Math.ceil(Math.log2(size)) - 1; i++) {
90432 const program = new CumProgram(op, permutedX.shape, false, reverse);
90433 const customValues = [[i]];
90434 const prevResult = result;
90435 result =
90436 backend.runWebGLProgram(program, [result], result.dtype, customValues);
90437 backend.disposeIntermediateTensorInfo(prevResult);
90438 }
90439 // For exclusive cum, shift the end result in the direction of product or sum
90440 // and add 1 for product or 0 for sum to the front index.
90441 if (exclusive) {
90442 const program = new CumProgram(op, permutedX.shape, exclusive, reverse);
90443 const prevResult = result;
90444 result = backend.runWebGLProgram(program, [result], result.dtype);
90445 backend.disposeIntermediateTensorInfo(prevResult);
90446 }
90447 if (permutation != null) {
90448 const reversePermutation = getUndoAxesPermutation(permutation);
90449 const reverseTransposedResult = transpose$2({ inputs: { x: result }, backend, attrs: { perm: reversePermutation } });
90450 backend.disposeIntermediateTensorInfo(result);
90451 backend.disposeIntermediateTensorInfo(permutedX);
90452 return reverseTransposedResult;
90453 }
90454 return result;
90455 }
90456
90457 /**
90458 * @license
90459 * Copyright 2022 Google LLC. All Rights Reserved.
90460 * Licensed under the Apache License, Version 2.0 (the "License");
90461 * you may not use this file except in compliance with the License.
90462 * You may obtain a copy of the License at
90463 *
90464 * http://www.apache.org/licenses/LICENSE-2.0
90465 *
90466 * Unless required by applicable law or agreed to in writing, software
90467 * distributed under the License is distributed on an "AS IS" BASIS,
90468 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90469 * See the License for the specific language governing permissions and
90470 * limitations under the License.
90471 * =============================================================================
90472 */
90473 function cumprod$2(args) {
90474 const { inputs, backend, attrs } = args;
90475 const { x } = inputs;
90476 const { axis, exclusive, reverse } = attrs;
90477 return cumImpl(CumOpType.Prod, x, backend, axis, exclusive, reverse);
90478 }
90479 const cumprodConfig$1 = {
90480 kernelName: Cumprod,
90481 backendName: 'webgl',
90482 kernelFunc: cumprod$2
90483 };
90484
90485 /**
90486 * @license
90487 * Copyright 2022 Google LLC. All Rights Reserved.
90488 * Licensed under the Apache License, Version 2.0 (the "License");
90489 * you may not use this file except in compliance with the License.
90490 * You may obtain a copy of the License at
90491 *
90492 * http://www.apache.org/licenses/LICENSE-2.0
90493 *
90494 * Unless required by applicable law or agreed to in writing, software
90495 * distributed under the License is distributed on an "AS IS" BASIS,
90496 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90497 * See the License for the specific language governing permissions and
90498 * limitations under the License.
90499 * =============================================================================
90500 */
90501 function cumsum$2(args) {
90502 const { inputs, backend, attrs } = args;
90503 const { x } = inputs;
90504 const { axis, exclusive, reverse } = attrs;
90505 return cumImpl(CumOpType.Sum, x, backend, axis, exclusive, reverse);
90506 }
90507 const cumsumConfig$1 = {
90508 kernelName: Cumsum,
90509 backendName: 'webgl',
90510 kernelFunc: cumsum$2
90511 };
90512
90513 /**
90514 * @license
90515 * Copyright 2020 Google LLC. All Rights Reserved.
90516 * Licensed under the Apache License, Version 2.0 (the "License");
90517 * you may not use this file except in compliance with the License.
90518 * You may obtain a copy of the License at
90519 *
90520 * http://www.apache.org/licenses/LICENSE-2.0
90521 *
90522 * Unless required by applicable law or agreed to in writing, software
90523 * distributed under the License is distributed on an "AS IS" BASIS,
90524 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90525 * See the License for the specific language governing permissions and
90526 * limitations under the License.
90527 * =============================================================================
90528 */
90529 function denseBincount$2(args) {
90530 const { inputs, backend, attrs } = args;
90531 const { x, weights } = inputs;
90532 const { size, binaryOutput } = attrs;
90533 if (x.shape.length === 1) {
90534 const xVals = backend.readSync(x.dataId);
90535 const weightsVals = backend.readSync(weights.dataId);
90536 const outVals = bincountImplCPU(xVals, weightsVals, weights.dtype, weights.shape, size);
90537 return backend.makeTensorInfo([size], weights.dtype, outVals);
90538 }
90539 else if (x.shape.length === 2) {
90540 const xBuf = backend.bufferSync(x);
90541 const weightsBuf = backend.bufferSync(weights);
90542 const outBuf = bincountReduceImplCPU(xBuf, weightsBuf, size, binaryOutput);
90543 return backend.makeTensorInfo(outBuf.shape, weights.dtype, outBuf.values);
90544 }
90545 throw new Error(`Error in denseBincount: input must be at most rank 2, but got rank` +
90546 `${x.shape.length}.`);
90547 }
90548 const denseBincountConfig$1 = {
90549 kernelName: DenseBincount,
90550 backendName: 'webgl',
90551 kernelFunc: denseBincount$2
90552 };
90553
90554 /**
90555 * @license
90556 * Copyright 2018 Google LLC. All Rights Reserved.
90557 * Licensed under the Apache License, Version 2.0 (the "License");
90558 * you may not use this file except in compliance with the License.
90559 * You may obtain a copy of the License at
90560 *
90561 * http://www.apache.org/licenses/LICENSE-2.0
90562 *
90563 * Unless required by applicable law or agreed to in writing, software
90564 * distributed under the License is distributed on an "AS IS" BASIS,
90565 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90566 * See the License for the specific language governing permissions and
90567 * limitations under the License.
90568 * =============================================================================
90569 */
90570 class DepthToSpaceProgram {
90571 constructor(outputShape, blockSize, dataFormat) {
90572 this.variableNames = ['x'];
90573 this.outputShape = [];
90574 this.outputShape = outputShape;
90575 this.blockSize = blockSize;
90576 this.dataFormat = dataFormat;
90577 this.userCode = `
90578 void main() {
90579 ivec4 coords = getOutputCoords();
90580 int b = coords[0];
90581 int h = ${this.getHeightCoordString()};
90582 int w = ${this.getWidthCoordString()};
90583 int d = ${this.getDepthCoordString()};
90584
90585 int in_h = h / ${blockSize};
90586 int offset_h = imod(h, ${blockSize});
90587 int in_w = w / ${blockSize};
90588 int offset_w = imod(w, ${blockSize});
90589 int offset_d = (offset_h * ${blockSize} + offset_w) *
90590 ${this.getOutputDepthSize()};
90591 int in_d = d + offset_d;
90592
90593 float result = ${this.getInputSamplingString()};
90594 setOutput(result);
90595 }
90596 `;
90597 }
90598 getHeightCoordString() {
90599 if (this.dataFormat === 'NHWC') {
90600 return `coords[1]`;
90601 }
90602 else {
90603 return `coords[2]`;
90604 }
90605 }
90606 getWidthCoordString() {
90607 if (this.dataFormat === 'NHWC') {
90608 return `coords[2]`;
90609 }
90610 else {
90611 return `coords[3]`;
90612 }
90613 }
90614 getDepthCoordString() {
90615 if (this.dataFormat === 'NHWC') {
90616 return `coords[3]`;
90617 }
90618 else {
90619 return `coords[1]`;
90620 }
90621 }
90622 getOutputDepthSize() {
90623 if (this.dataFormat === 'NHWC') {
90624 return this.outputShape[3];
90625 }
90626 else {
90627 return this.outputShape[1];
90628 }
90629 }
90630 getInputSamplingString() {
90631 if (this.dataFormat === 'NHWC') {
90632 return `getX(b, in_h, in_w, in_d)`;
90633 }
90634 else {
90635 return `getX(b, in_d, in_h, in_w)`;
90636 }
90637 }
90638 }
90639
90640 /**
90641 * @license
90642 * Copyright 2020 Google LLC. All Rights Reserved.
90643 * Licensed under the Apache License, Version 2.0 (the "License");
90644 * you may not use this file except in compliance with the License.
90645 * You may obtain a copy of the License at
90646 *
90647 * http://www.apache.org/licenses/LICENSE-2.0
90648 *
90649 * Unless required by applicable law or agreed to in writing, software
90650 * distributed under the License is distributed on an "AS IS" BASIS,
90651 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90652 * See the License for the specific language governing permissions and
90653 * limitations under the License.
90654 * =============================================================================
90655 */
90656 function depthToSpace$2(args) {
90657 const { inputs, backend, attrs } = args;
90658 const { x } = inputs;
90659 const { blockSize, dataFormat } = attrs;
90660 const batchSize = x.shape[0];
90661 const inputHeight = (dataFormat === 'NHWC') ? x.shape[1] : x.shape[2];
90662 const inputWidth = (dataFormat === 'NHWC') ? x.shape[2] : x.shape[3];
90663 const inputDepth = (dataFormat === 'NHWC') ? x.shape[3] : x.shape[1];
90664 const outputHeight = inputHeight * blockSize;
90665 const outputWidth = inputWidth * blockSize;
90666 const outputDepth = inputDepth / (blockSize * blockSize);
90667 const outputShape = (dataFormat === 'NHWC') ?
90668 [batchSize, outputHeight, outputWidth, outputDepth] :
90669 [batchSize, outputDepth, outputHeight, outputWidth];
90670 const program = new DepthToSpaceProgram(outputShape, blockSize, dataFormat);
90671 return backend.runWebGLProgram(program, [x], x.dtype);
90672 }
90673 const depthToSpaceConfig$1 = {
90674 kernelName: DepthToSpace,
90675 backendName: 'webgl',
90676 kernelFunc: depthToSpace$2
90677 };
90678
90679 /**
90680 * @license
90681 * Copyright 2017 Google LLC. All Rights Reserved.
90682 * Licensed under the Apache License, Version 2.0 (the "License");
90683 * you may not use this file except in compliance with the License.
90684 * You may obtain a copy of the License at
90685 *
90686 * http://www.apache.org/licenses/LICENSE-2.0
90687 *
90688 * Unless required by applicable law or agreed to in writing, software
90689 * distributed under the License is distributed on an "AS IS" BASIS,
90690 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90691 * See the License for the specific language governing permissions and
90692 * limitations under the License.
90693 * =============================================================================
90694 */
    /**
     * GPGPU program for depthwise conv2d. The input `x` is sampled as
     * (batch, row, col, channel), i.e. NHWC. Each input channel d1 produces
     * `channelMul = outChannels / inChannels` output channels
     * d2 = d1 * channelMul + q, each convolved with filter W(:, :, d1, q).
     * Pads, strides, dilations and input spatial dims are runtime uniforms,
     * so a compiled shader can be reused across tensors with the same output
     * shape; filter dims and channelMul are baked in as constants.
     *
     * Optional fused pieces: bias add (`addBias`) and an activation GLSL
     * snippet (`activation`), whose prelu/leaky-relu variants read a second
     * operand `b` from an extra input texture.
     */
    class DepthwiseConv2DProgram {
        constructor(convInfo, addBias = false, activation = null, hasPreluActivation = false, hasLeakyReluAlpha = false) {
            this.variableNames = ['x', 'W'];
            // Runtime uniforms consumed by the shader body below.
            this.customUniforms = [
                { name: 'pads', type: 'ivec2' },
                { name: 'strides', type: 'ivec2' },
                { name: 'dilations', type: 'ivec2' },
                { name: 'inDims', type: 'ivec2' },
            ];
            this.outputShape = convInfo.outShape;
            this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
            const filterHeight = convInfo.filterHeight;
            const filterWidth = convInfo.filterWidth;
            // Output channels produced per input channel.
            const channelMul = convInfo.outChannels / convInfo.inChannels;
            let activationSnippet = '', applyActivationSnippet = '';
            if (activation) {
                if (hasPreluActivation) {
                    activationSnippet = `float activation(float a) {
          float b = getPreluActivationWeightsAtOutCoords();
          ${activation}
        }`;
                }
                else if (hasLeakyReluAlpha) {
                    activationSnippet = `float activation(float a) {
          float b = getLeakyreluAlphaAtOutCoords();
          ${activation}
        }`;
                }
                else {
                    activationSnippet = `
          float activation(float x) {
            ${activation}
          }
        `;
                }
                applyActivationSnippet = `result = activation(result);`;
            }
            const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
            // Extra input textures must be declared in the same order the
            // backend binds them.
            if (addBias) {
                this.variableNames.push('bias');
            }
            if (hasPreluActivation) {
                this.variableNames.push('preluActivationWeights');
            }
            if (hasLeakyReluAlpha) {
                this.variableNames.push('leakyreluAlpha');
            }
            this.userCode = `
      ${activationSnippet}

      void main() {
        ivec4 coords = getOutputCoords();
        int batch = coords.x;
        ivec2 xRCCorner = coords.yz * strides - pads;
        int d2 = coords.w;
        int d1 = d2 / ${channelMul};
        int q = d2 - d1 * ${channelMul};

        int xRCorner = xRCCorner.x;
        int xCCorner = xRCCorner.y;

        // Convolve x(?, ?, d1) with w(:, :, d1, q) to get y(yR, yC, d2).
        // ? = to be determined. : = across all values in that axis.
        float dotProd = 0.0;
        // TO DO(dsmilkov): Flatten the two for loops and vec4 the operations.
        for (int wR = 0; wR < ${filterHeight}; wR++) {
          int xR = xRCorner + wR * dilations[0];

          if (xR < 0 || xR >= inDims[0]) {
            continue;
          }

          for (int wC = 0; wC < ${filterWidth}; wC++) {
            int xC = xCCorner + wC * dilations[1];

            if (xC < 0 || xC >= inDims[1]) {
              continue;
            }

            float xVal = getX(batch, xR, xC, d1);
            float wVal = getW(wR, wC, d1, q);
            dotProd += xVal * wVal;
          }
        }

        float result = dotProd;
        ${addBiasSnippet}
        ${applyActivationSnippet}
        setOutput(result);
      }
    `;
        }
    }
90788
90789 /**
90790 * @license
90791 * Copyright 2018 Google LLC. All Rights Reserved.
90792 * Licensed under the Apache License, Version 2.0 (the "License");
90793 * you may not use this file except in compliance with the License.
90794 * You may obtain a copy of the License at
90795 *
90796 * http://www.apache.org/licenses/LICENSE-2.0
90797 *
90798 * Unless required by applicable law or agreed to in writing, software
90799 * distributed under the License is distributed on an "AS IS" BASIS,
90800 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90801 * See the License for the specific language governing permissions and
90802 * limitations under the License.
90803 * =============================================================================
90804 */
90805 class DepthwiseConvPacked2DProgram {
90806 constructor(convInfo, addBias = false, activation = null, hasPreluActivation = false, hasLeakyReluAlpha = false) {
90807 this.variableNames = ['x', 'W'];
90808 this.packedInputs = true;
90809 this.packedOutput = true;
90810 this.customUniforms = [
90811 { name: 'pads', type: 'ivec2' },
90812 { name: 'strides', type: 'ivec2' },
90813 { name: 'dilations', type: 'ivec2' },
90814 { name: 'inDims', type: 'ivec2' },
90815 ];
90816 this.outputShape = convInfo.outShape;
90817 this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
90818 const channelMul = convInfo.outChannels / convInfo.inChannels;
90819 const padLeft = convInfo.padInfo.left;
90820 const strideWidth = convInfo.strideWidth;
90821 const dilationWidth = convInfo.dilationWidth;
90822 const filterHeight = convInfo.filterHeight;
90823 const filterWidth = convInfo.filterWidth;
90824 const texelsAcross = filterWidth;
90825 let mainLoop = `
90826 int xR; int xC; int xCOffset;
90827 vec4 wTexel; vec4 previous; vec4 final;`;
90828 for (let c = 0; c < filterWidth; c++) {
90829 mainLoop += `
90830 vec4 xTexelC${c * 2};
90831 int xTexelC${c * 2}Ready;
90832 vec4 xTexelC${c * 2 + 1};
90833 int xTexelC${c * 2 + 1}Ready;
90834 vec4 xC${c};`;
90835 }
90836 /**
90837 * This vectorized implementation works by gathering the values needed for
90838 * each output channel's dot product into vec4's and then multiplying them
90839 * all together (this happens in the final double for-loop below). Most of
90840 * the main loop consists of constructing these vec4's with the minimum
90841 * number of texture2D calls, which means making use of all four returned
90842 * values from a texture2D call at once.
90843 */
90844 mainLoop += `
90845 for (int r = 0; r < ${filterHeight}; r++) {
90846 `;
90847 for (let c = 0; c < filterWidth; c++) {
90848 mainLoop += `
90849 xTexelC${c * 2} = vec4(0.0);
90850 xTexelC${c * 2}Ready = 0;
90851 xTexelC${c * 2 + 1} = vec4(0.0);
90852 xTexelC${c * 2 + 1}Ready = 0;
90853 xC${c} = vec4(0.0);`;
90854 }
90855 mainLoop += `
90856 xR = xRCorner + r * dilations[0];
90857 if (xR >=0 && xR < inDims[0]) {
90858 `;
90859 for (let texelC = 0; texelC < (texelsAcross + 1) / 2; texelC++) {
90860 const colIndex = texelC * 2;
90861 mainLoop += `
90862 xC = xCCorner + ${colIndex * dilationWidth};
90863 `;
90864 if (strideWidth === 1) {
90865 if (colIndex < filterWidth) {
90866 // If padding is odd, the outer texels have to be composed.
90867 if (padLeft % 2 === 1) {
90868 // TODO: Ensure vec4 previous does not result in redundant sample,
90869 // and avoid setting xTexelRC's that exceed the boundary in the
90870 // first place rather than resetting them to vec4(0)).
90871 // To compute xCOffset:
90872 // - If padding is odd, we must add 1 to ensure we ask for an
90873 // even-numbered row.
90874 // - We subtract 2 to access the previous texel.
90875 mainLoop += `
90876 xCOffset = xC + 1;
90877 if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex}Ready == 0) {
90878 xTexelC${colIndex} = getX(batch, xR, xCOffset, d1);
90879
90880 // Need to manually clear unused channels in case
90881 // we're reading from recycled texture.
90882 if (xCOffset + 1 >= inDims[1]) {
90883 xTexelC${colIndex}.zw = vec2(0.0);
90884 }
90885 xTexelC${colIndex}Ready = 1;
90886 }
90887 `;
90888 // This texel has been read in previous iteration if the dilation
90889 // is 1.
90890 if (dilationWidth === 1 && colIndex > 0) {
90891 mainLoop += `
90892 xC${colIndex} = vec4(xTexelC${colIndex - 2}.zw, xTexelC${colIndex}.xy);
90893 `;
90894 }
90895 else {
90896 mainLoop += `
90897 xCOffset = xC + 1 - 2;
90898
90899 if (xCOffset >= 0 && xCOffset < inDims[1]) {
90900 previous = getX(batch, xR, xCOffset, d1);
90901
90902 // Need to manually clear unused channels in case
90903 // we're reading from recycled texture.
90904 if (xCOffset + 1 >= inDims[1]) {
90905 previous.zw = vec2(0.0);
90906 }
90907
90908 xC${colIndex} = vec4(previous.zw, xTexelC${colIndex}.xy);
90909 } else {
90910 xC${colIndex} = vec4(0.0, 0.0, xTexelC${colIndex}.xy);
90911 }
90912 `;
90913 }
90914 }
90915 else {
90916 // Padding is even, so xRC corresponds to a single texel.
90917 mainLoop += `
90918 if (xC >= 0 && xC < inDims[1] && xTexelC${colIndex}Ready == 0) {
90919 xTexelC${colIndex} = getX(batch, xR, xC, d1);
90920 if (xC + 1 >= inDims[1]) {
90921 xTexelC${colIndex}.zw = vec2(0.0);
90922 }
90923 xTexelC${colIndex}Ready = 1;
90924 }
90925
90926 xC${colIndex} = xTexelC${colIndex};
90927 `;
90928 }
90929 if (colIndex + 1 < filterWidth) {
90930 // If dilation is even, the second entry should match the first
90931 // (either both are composed or both are single samples). But if
90932 // dilation is odd, then the second entry should be the opposite
90933 // of the first (if the first is composed, the second is a single
90934 // sample, and vice versa.)
90935 const nextTexelOffset = padLeft % 2 === 0 ?
90936 nearestLargerEven(dilationWidth) :
90937 dilationWidth;
90938 if ((dilationWidth % 2 === 0 && padLeft % 2 === 1) ||
90939 (dilationWidth % 2 !== 0 && padLeft % 2 !== 1)) {
90940 mainLoop += `
90941 xCOffset = xC + imod(pads[1], 2) + ${nextTexelOffset};
90942
90943 if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
90944 xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);
90945
90946 // Need to manually clear unused channels in case
90947 // we're reading from recycled texture.
90948 if (xCOffset + 1 >= inDims[1]) {
90949 xTexelC${colIndex + 1}.zw = vec2(0.0);
90950 }
90951 xTexelC${colIndex + 1}Ready = 1;
90952 }
90953 `;
90954 // If dilation > 1 then the xRC's will not be able to share any
90955 // values, so each xRC will require two unique calls to getX.
90956 if (dilationWidth > 1) {
90957 mainLoop += `
90958 xCOffset -= 2;
90959 if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex}Ready == 0) {
90960 xTexelC${colIndex} = getX(batch, xR, xCOffset, d1);
90961 xTexelC${colIndex}Ready = 1;
90962 }
90963 `;
90964 }
90965 mainLoop += `
90966 xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.xy);
90967 `;
90968 }
90969 else {
90970 // If dilation is 1 and padding is odd, we have already read the
90971 // texel when constructing the previous x value. Here we can
90972 // simply skip the texture read.
90973 if (nextTexelOffset === 1) {
90974 mainLoop += `
90975 xC${colIndex + 1} = xTexelC${colIndex};
90976 `;
90977 }
90978 else {
90979 mainLoop += `
90980 xCOffset = xC + ${nextTexelOffset};
90981
90982 if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
90983 xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);
90984 if (xCOffset + 1 >= inDims[1]) {
90985 xTexelC${colIndex + 1}.zw = vec2(0.0);
90986 }
90987 xTexelC${colIndex + 1}Ready = 1;
90988 }
90989
90990 xC${colIndex + 1} = xTexelC${colIndex + 1};
90991 `;
90992 }
90993 }
90994 }
90995 }
90996 }
90997 else { // stride === 2
90998 if (colIndex < filterWidth) {
90999 // Depending on whether padLeft is even or odd, we want either the
91000 // xy or zw channels from X texels for xC${colIndex}. If padLeft is
91001 // even, xC${colIndex +1} is simply the zw channels of texels we've
91002 // already sampled. But if padLeft is odd, xC{$c + 1}.zw will
91003 // need to come from the xy channels of a new texel, hence the `
91004 // vec4
91005 // final` initialized below.
91006 if (padLeft % 2 === 1) {
91007 mainLoop += `
91008 xCOffset = xC + 1 - strides[1];
91009 if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex}Ready == 0) {
91010 xTexelC${colIndex} = getX(batch, xR, xCOffset, d1);
91011 // Need to manually clear unused channels in case
91012 // we're reading from recycled texture.
91013 if (xCOffset + 1 >= inDims[1]) {
91014 xTexelC${colIndex}.zw = vec2(0.0);
91015 }
91016 xTexelC${colIndex}Ready = 1;
91017 }
91018
91019 if(xC + 1 >= 0 && xC + 1 < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
91020 xTexelC${colIndex + 1} = getX(batch, xR, xC + 1, d1);
91021 // Need to manually clear unused channels in case
91022 // we're reading from recycled texture.
91023 if (xC + 2 >= inDims[1]) {
91024 xTexelC${colIndex + 1}.zw = vec2(0.0);
91025 }
91026 xTexelC${colIndex + 1}Ready = 1;
91027 }
91028
91029 xC${colIndex} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.zw);
91030 `;
91031 if (colIndex + 1 < filterWidth) {
91032 mainLoop += `
91033 final = vec4(0.0);
91034 xCOffset = xC + 1 + strides[1];
91035 if(xCOffset >= 0 && xCOffset < inDims[1]) {
91036 final = getX(batch, xR, xCOffset, d1);
91037 }
91038 xC${colIndex + 1} = vec4(xTexelC${colIndex + 1}.xy, final.xy);
91039 `;
91040 }
91041 }
91042 else {
91043 mainLoop += `
91044 if(xC >= 0 && xC < inDims[1] && xTexelC${colIndex}Ready == 0) {
91045 xTexelC${colIndex} = getX(batch, xR, xC, d1);
91046 if (xC + 1 >= inDims[1]) {
91047 xTexelC${colIndex}.zw = vec2(0.0);
91048 }
91049 xTexelC${colIndex}Ready = 1;
91050 }
91051
91052 xCOffset = xC + strides[1];
91053 if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
91054 xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);
91055 if (xCOffset + 1 >= inDims[1]) {
91056 xTexelC${colIndex + 1}.zw = vec2(0.);
91057 }
91058 xTexelC${colIndex + 1}Ready = 1;
91059 }
91060
91061 xC${colIndex} = vec4(
91062 xTexelC${colIndex}.xy, xTexelC${colIndex + 1}.xy);
91063 `;
91064 if (colIndex + 1 < filterWidth) {
91065 mainLoop += `
91066 xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.zw);
91067 `;
91068 }
91069 }
91070 }
91071 }
91072 // localize the dotProd accumulation within the loop, the theory is for
91073 // GPU with limited cache, accumulate sum across large amount of
91074 // veriables will cause lots of cache misses. (i.e. 5x5 filter will have
91075 // 50 variables)
91076 if (colIndex < filterWidth) {
91077 mainLoop += `
91078 wTexel = getW(r, ${colIndex}, d1, q);
91079 dotProd += xC${colIndex} * vec4(wTexel.xz, wTexel.xz);
91080 `;
91081 if (colIndex + 1 < filterWidth) {
91082 mainLoop += `
91083 wTexel = getW(r, ${colIndex + 1}, d1, q);
91084 dotProd += xC${colIndex + 1} * vec4(wTexel.xz, wTexel.xz);
91085 `;
91086 }
91087 }
91088 }
91089 mainLoop += `
91090 }
91091 `;
91092 mainLoop += `
91093 }
91094 `;
91095 let activationSnippet = '', applyActivationSnippet = '';
91096 if (activation) {
91097 if (hasPreluActivation) {
91098 activationSnippet = `vec4 activation(vec4 a) {
91099 vec4 b = getPreluActivationWeightsAtOutCoords();
91100 ${activation}
91101 }`;
91102 }
91103 else if (hasLeakyReluAlpha) {
91104 activationSnippet = `vec4 activation(vec4 a) {
91105 vec4 b = getLeakyreluAlphaAtOutCoords();
91106 ${activation}
91107 }`;
91108 }
91109 else {
91110 activationSnippet = `vec4 activation(vec4 x) {
91111 ${activation}
91112 }`;
91113 }
91114 applyActivationSnippet = `result = activation(result);`;
91115 }
91116 const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
91117 if (addBias) {
91118 this.variableNames.push('bias');
91119 }
91120 if (hasPreluActivation) {
91121 this.variableNames.push('preluActivationWeights');
91122 }
91123 if (hasLeakyReluAlpha) {
91124 this.variableNames.push('leakyreluAlpha');
91125 }
91126 this.userCode = `
91127 ${activationSnippet}
91128
91129 void main() {
91130 ivec4 coords = getOutputCoords();
91131 int batch = coords.x;
91132 ivec2 xRCCorner = coords.yz * strides - pads;
91133 int d2 = coords.w;
91134 int d1 = d2 / ${channelMul};
91135 int q = d2 - d1 * ${channelMul};
91136 int xRCorner = xRCCorner.x;
91137 int xCCorner = xRCCorner.y;
91138
91139 //intialize dotProd with a small epsilon seems to reduce GPU accuracy loss.
91140 vec4 dotProd = vec4(0.000000000000001);
91141
91142 ${mainLoop}
91143
91144 vec4 result = dotProd - vec4(0.000000000000001);
91145 ${addBiasSnippet}
91146 ${applyActivationSnippet}
91147 setOutput(result);
91148 }
91149 `;
91150 }
91151 }
91152
91153 /**
91154 * @license
91155 * Copyright 2020 Google LLC. All Rights Reserved.
91156 * Licensed under the Apache License, Version 2.0 (the "License");
91157 * you may not use this file except in compliance with the License.
91158 * You may obtain a copy of the License at
91159 *
91160 * http://www.apache.org/licenses/LICENSE-2.0
91161 *
91162 * Unless required by applicable law or agreed to in writing, software
91163 * distributed under the License is distributed on an "AS IS" BASIS,
91164 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91165 * See the License for the specific language governing permissions and
91166 * limitations under the License.
91167 * =============================================================================
91168 */
/**
 * WebGL kernel for DepthwiseConv2dNative.
 *
 * Chooses between the packed and unpacked depthwise-conv programs: the
 * packed variant is used only when WEBGL_PACK_DEPTHWISECONV is enabled,
 * the horizontal stride is at most 2, and the channel multiplier is 1.
 *
 * @param args.inputs {x, filter} input activation and filter tensors.
 * @param args.attrs  {strides, pad, dilations, dimRoundingMode}.
 * @returns A float32 TensorInfo holding the convolution output.
 */
function depthwiseConv2dNative$1(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter } = inputs;
    const { strides, pad, dilations, dimRoundingMode } = attrs;
    // Dilations default to 1 along both spatial dimensions when unspecified.
    const $dilations = dilations == null ? [1, 1] : dilations;
    assert(eitherStridesOrDilationsAreOne(strides, $dilations), () => 'Error in depthwiseConv2d: Either strides or dilations must be ' +
        `1. Got strides ${strides} and dilations '${$dilations}'`);
    const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true /* depthwise */);
    const usePackedProgram = env().getBool('WEBGL_PACK_DEPTHWISECONV') &&
        convInfo.strideWidth <= 2 &&
        convInfo.outChannels / convInfo.inChannels === 1;
    const program = usePackedProgram ?
        new DepthwiseConvPacked2DProgram(convInfo) :
        new DepthwiseConv2DProgram(convInfo);
    // Uniform values consumed by the shader (pads, strides, dilations, input dims).
    const customValues = [
        [convInfo.padInfo.top, convInfo.padInfo.left],
        [convInfo.strideHeight, convInfo.strideWidth],
        [convInfo.dilationHeight, convInfo.dilationWidth],
        [convInfo.inHeight, convInfo.inWidth]
    ];
    return backend.runWebGLProgram(program, [x, filter], 'float32', customValues);
}
/** Kernel registration binding DepthwiseConv2dNative to the WebGL backend. */
const depthwiseConv2dNativeConfig$1 = {
    kernelName: DepthwiseConv2dNative,
    backendName: 'webgl',
    kernelFunc: depthwiseConv2dNative$1,
};
91201
91202 /**
91203 * @license
91204 * Copyright 2018 Google LLC. All Rights Reserved.
91205 * Licensed under the Apache License, Version 2.0 (the "License");
91206 * you may not use this file except in compliance with the License.
91207 * You may obtain a copy of the License at
91208 *
91209 * http://www.apache.org/licenses/LICENSE-2.0
91210 *
91211 * Unless required by applicable law or agreed to in writing, software
91212 * distributed under the License is distributed on an "AS IS" BASIS,
91213 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91214 * See the License for the specific language governing permissions and
91215 * limitations under the License.
91216 * =============================================================================
91217 */
/**
 * WebGL program computing the gradient of a depthwise conv2d with respect
 * to the filter. Each output element (wR, wC, d1, dm) accumulates
 * x(b, xR, xC, d1) * dy(b, yR, yC, d2) over the batch and output positions.
 */
class DepthwiseConv2DDerFilterProgram {
    /** @param convInfo Convolution geometry (strides, padding, shapes, channels). */
    constructor(convInfo) {
        this.variableNames = ['x', 'dy'];
        this.outputShape = convInfo.filterShape;
        const { strideHeight, strideWidth, batchSize, outHeight, outWidth, inHeight, inWidth } = convInfo;
        const { top: padTop, left: padLeft } = convInfo.padInfo;
        // Channel multiplier: output channels per input channel.
        const channelMul = convInfo.outChannels / convInfo.inChannels;
        this.userCode = `
      void main() {
        ivec4 coords = getOutputCoords();
        int wR = coords.x;
        int wC = coords.y;
        int d1 = coords.z;
        int dm = coords.w;
        int d2 = d1 * ${channelMul} + dm;

        float dotProd = 0.0;

        // TO DO: Vec4 over the batch size
        for (int b = 0; b < ${batchSize}; b++) {
          for (int yR = 0; yR < ${outHeight}; yR++) {
            int xR = wR + yR * ${strideHeight} - ${padTop};

            if (xR < 0 || xR >= ${inHeight}) {
              continue;
            }

            for (int yC = 0; yC < ${outWidth}; yC++) {
              int xC = wC + yC * ${strideWidth} - ${padLeft};

              if (xC < 0 || xC >= ${inWidth}) {
                continue;
              }

              float dyValue = getDy(b, yR, yC, d2);
              float xValue = getX(b, xR, xC, d1);
              dotProd += (xValue * dyValue);
            }
          }
        }
        setOutput(dotProd);
      }
    `;
    }
}
/**
 * WebGL program computing the gradient of a depthwise conv2d with respect
 * to the input. Implements the transposed convolution: the filter is
 * traversed in reverse (wRPerm/wCPerm) and dy positions that do not land on
 * an integer stride multiple are skipped via the fract() check.
 */
class DepthwiseConv2DDerInputProgram {
    /** @param convInfo Convolution geometry (strides, padding, shapes, channels). */
    constructor(convInfo) {
        this.variableNames = ['dy', 'W'];
        this.outputShape = convInfo.inShape;
        const { filterHeight, filterWidth, strideHeight, strideWidth, outHeight, outWidth } = convInfo;
        // Padding is mirrored for the input gradient ("full" correlation).
        const padTop = filterHeight - 1 - convInfo.padInfo.top;
        const padLeft = filterWidth - 1 - convInfo.padInfo.left;
        // Channel multiplier: output channels per input channel.
        const channelMul = convInfo.outChannels / convInfo.inChannels;
        this.userCode = `
      const ivec2 pads = ivec2(${padTop}, ${padLeft});

      void main() {
        ivec4 coords = getOutputCoords();
        int batch = coords[0];
        int d1 = coords[3];
        ivec2 dyCorner = coords.yz - pads;
        int dyRCorner = dyCorner.x;
        int dyCCorner = dyCorner.y;

        float dotProd = 0.0;

        for (int wR = 0; wR < ${filterHeight}; wR++) {
          float dyR = float(dyRCorner + wR) / ${strideHeight}.0;

          if (dyR < 0.0 || dyR >= ${outHeight}.0 || fract(dyR) > 0.0) {
            continue;
          }
          int idyR = int(dyR);

          int wRPerm = ${filterHeight} - 1 - wR;

          for (int wC = 0; wC < ${filterWidth}; wC++) {
            float dyC = float(dyCCorner + wC) / ${strideWidth}.0;

            if (dyC < 0.0 || dyC >= ${outWidth}.0 ||
                fract(dyC) > 0.0) {
              continue;
            }
            int idyC = int(dyC);

            int wCPerm = ${filterWidth} - 1 - wC;

            // TO DO: Vec4 over the channelMul
            for (int dm = 0; dm < ${channelMul}; dm++) {
              int d2 = d1 * ${channelMul} + dm;
              float xValue = getDy(batch, idyR, idyC, d2);
              float wValue = getW(wRPerm, wCPerm, d1, dm);
              dotProd += xValue * wValue;
            }
          }
        }
        setOutput(dotProd);
      }
    `;
    }
}
91324
91325 /**
91326 * @license
91327 * Copyright 2020 Google LLC. All Rights Reserved.
91328 * Licensed under the Apache License, Version 2.0 (the "License");
91329 * you may not use this file except in compliance with the License.
91330 * You may obtain a copy of the License at
91331 *
91332 * http://www.apache.org/licenses/LICENSE-2.0
91333 *
91334 * Unless required by applicable law or agreed to in writing, software
91335 * distributed under the License is distributed on an "AS IS" BASIS,
91336 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91337 * See the License for the specific language governing permissions and
91338 * limitations under the License.
91339 * =============================================================================
91340 */
/**
 * WebGL kernel for DepthwiseConv2dNativeBackpropFilter: computes the
 * filter gradient from the input x and the upstream gradient dy.
 *
 * @param args.inputs {x, dy}.
 * @param args.attrs  {strides, dilations, pad, dimRoundingMode, filterShape}.
 * @returns A float32 TensorInfo shaped like the filter.
 */
function depthwiseConv2dNativeBackpropFilter$2(args) {
    const { inputs, backend, attrs } = args;
    const { x, dy } = inputs;
    const { strides, dilations, pad, dimRoundingMode, filterShape } = attrs;
    // Recompute the forward-pass geometry; the gradient program needs the
    // same strides/padding that produced dy.
    const convInfo = computeConv2DInfo(x.shape, filterShape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
    const derFilterProgram = new DepthwiseConv2DDerFilterProgram(convInfo);
    return backend.runWebGLProgram(derFilterProgram, [x, dy], 'float32');
}
/** Kernel registration for the filter-gradient kernel on the WebGL backend. */
const depthwiseConv2dNativeBackpropFilterConfig$1 = {
    kernelName: DepthwiseConv2dNativeBackpropFilter,
    backendName: 'webgl',
    kernelFunc: depthwiseConv2dNativeBackpropFilter$2
};
91354
91355 /**
91356 * @license
91357 * Copyright 2020 Google LLC. All Rights Reserved.
91358 * Licensed under the Apache License, Version 2.0 (the "License");
91359 * you may not use this file except in compliance with the License.
91360 * You may obtain a copy of the License at
91361 *
91362 * http://www.apache.org/licenses/LICENSE-2.0
91363 *
91364 * Unless required by applicable law or agreed to in writing, software
91365 * distributed under the License is distributed on an "AS IS" BASIS,
91366 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91367 * See the License for the specific language governing permissions and
91368 * limitations under the License.
91369 * =============================================================================
91370 */
/**
 * WebGL kernel for DepthwiseConv2dNativeBackpropInput: computes the
 * input gradient from the upstream gradient dy and the filter.
 *
 * @param args.inputs {dy, filter}.
 * @param args.attrs  {strides, dilations, pad, dimRoundingMode, inputShape}.
 * @returns A float32 TensorInfo shaped like the original input.
 */
function depthwiseConv2dNativeBackpropInput$2(args) {
    const { inputs, backend, attrs } = args;
    const { dy, filter } = inputs;
    const { strides, dilations, pad, dimRoundingMode, inputShape } = attrs;
    // Recompute the forward-pass geometry for the transposed-conv program.
    const convInfo = computeConv2DInfo(inputShape, filter.shape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
    const derInputProgram = new DepthwiseConv2DDerInputProgram(convInfo);
    return backend.runWebGLProgram(derInputProgram, [dy, filter], 'float32');
}
/** Kernel registration for the input-gradient kernel on the WebGL backend. */
const depthwiseConv2dNativeBackpropInputConfig$1 = {
    kernelName: DepthwiseConv2dNativeBackpropInput,
    backendName: 'webgl',
    kernelFunc: depthwiseConv2dNativeBackpropInput$2
};
91384
91385 /**
91386 * @license
91387 * Copyright 2019 Google LLC. All Rights Reserved.
91388 * Licensed under the Apache License, Version 2.0 (the "License");
91389 * you may not use this file except in compliance with the License.
91390 * You may obtain a copy of the License at
91391 *
91392 * http://www.apache.org/licenses/LICENSE-2.0
91393 *
91394 * Unless required by applicable law or agreed to in writing, software
91395 * distributed under the License is distributed on an "AS IS" BASIS,
91396 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91397 * See the License for the specific language governing permissions and
91398 * limitations under the License.
91399 * =============================================================================
91400 */
/**
 * WebGL program producing a [size, size] matrix whose main diagonal holds
 * the elements of a length-`size` input vector and whose off-diagonal
 * entries are zero.
 */
class DiagProgram {
    /** @param size Length of the flat input vector (and of each output side). */
    constructor(size) {
        this.variableNames = ['X'];
        this.outputShape = [size, size];
        this.userCode = `
      void main() {
        ivec2 coords = getOutputCoords();
        float val = coords[0] == coords[1] ? getX(coords[0]) : 0.0;
        setOutput(val);
      }
    `;
    }
}
91414
91415 /**
91416 * @license
91417 * Copyright 2020 Google LLC. All Rights Reserved.
91418 * Licensed under the Apache License, Version 2.0 (the "License");
91419 * you may not use this file except in compliance with the License.
91420 * You may obtain a copy of the License at
91421 *
91422 * http://www.apache.org/licenses/LICENSE-2.0
91423 *
91424 * Unless required by applicable law or agreed to in writing, software
91425 * distributed under the License is distributed on an "AS IS" BASIS,
91426 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91427 * See the License for the specific language governing permissions and
91428 * limitations under the License.
91429 * =============================================================================
91430 */
/**
 * WebGL kernel for Diag: given a tensor x, returns a tensor of rank
 * 2 * rank(x) and shape [...x.shape, ...x.shape] holding x's values on the
 * generalized diagonal and zeros elsewhere.
 *
 * Implementation: flatten x to 1-D, build a [n, n] diagonal matrix on the
 * GPU, then reshape to the doubled shape.
 */
function diag$2(args) {
    const { inputs, backend } = args;
    const { x } = inputs;
    const outShape = x.shape.concat(x.shape);
    const xSize = sizeFromShape(x.shape);
    const flattened = reshape$3({ inputs: { x }, backend, attrs: { shape: [xSize] } });
    const diagProgram = new DiagProgram(xSize);
    const diag2d = backend.runWebGLProgram(diagProgram, [flattened], flattened.dtype);
    const out = reshape$3({ inputs: { x: diag2d }, backend, attrs: { shape: outShape } });
    // Release the flattened input and the intermediate [n, n] matrix.
    backend.disposeIntermediateTensorInfo(flattened);
    backend.disposeIntermediateTensorInfo(diag2d);
    return out;
}
/** Kernel registration binding Diag to the WebGL backend. */
const diagConfig$1 = {
    kernelName: Diag,
    backendName: 'webgl',
    kernelFunc: diag$2
};
91449
91450 /**
91451 * @license
91452 * Copyright 2017 Google LLC. All Rights Reserved.
91453 * Licensed under the Apache License, Version 2.0 (the "License");
91454 * you may not use this file except in compliance with the License.
91455 * You may obtain a copy of the License at
91456 *
91457 * http://www.apache.org/licenses/LICENSE-2.0
91458 *
91459 * Unless required by applicable law or agreed to in writing, software
91460 * distributed under the License is distributed on an "AS IS" BASIS,
91461 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91462 * See the License for the specific language governing permissions and
91463 * limitations under the License.
91464 * =============================================================================
91465 */
/**
 * WebGL program for morphological dilation (Dilation2D): each output
 * element is the maximum of x + W over the (possibly dilated) filter
 * window, with out-of-bounds positions skipped. The running maximum starts
 * at a large negative value standing in for -infinity.
 */
class Dilation2DProgram {
    /** @param convInfo Dilation geometry (strides, padding, filter, dilation, input dims). */
    constructor(convInfo) {
        this.variableNames = ['x', 'W'];
        this.outputShape = convInfo.outShape;
        const padTop = convInfo.padInfo.top;
        const padLeft = convInfo.padInfo.left;
        const { strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth, inHeight, inWidth } = convInfo;
        this.userCode = `
      const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
      const ivec2 pads = ivec2(${padTop}, ${padLeft});
      const float neg_infinity = -3.4e38;

      void main() {
        ivec4 coords = getOutputCoords();
        int batch = coords.x;
        int d1 = coords.w;
        ivec2 outTopLeftCorner =
            coords.yz * strides - pads;
        int hBeg = outTopLeftCorner.x;
        int wBeg = outTopLeftCorner.y;

        float curVal = neg_infinity;
        for (int h = 0; h < ${filterHeight}; h++) {
          int hIn = hBeg + h * ${dilationHeight};

          if (hIn >= 0 && hIn < ${inHeight}) {
            for (int w = 0; w < ${filterWidth}; w++) {
              int wIn = wBeg + w * ${dilationWidth};

              if (wIn >= 0 && wIn < ${inWidth}) {
                float xVal = getX(batch, hIn, wIn, d1);
                float wVal = getW(h, w, d1);

                float val = xVal + wVal;
                if (val > curVal) {
                  curVal = val;
                }
              }
            }
          }
        }

        float result = curVal;
        setOutput(result);
      }
    `;
    }
}
91513
91514 /**
91515 * @license
91516 * Copyright 2020 Google LLC. All Rights Reserved.
91517 * Licensed under the Apache License, Version 2.0 (the "License");
91518 * you may not use this file except in compliance with the License.
91519 * You may obtain a copy of the License at
91520 *
91521 * http://www.apache.org/licenses/LICENSE-2.0
91522 *
91523 * Unless required by applicable law or agreed to in writing, software
91524 * distributed under the License is distributed on an "AS IS" BASIS,
91525 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91526 * See the License for the specific language governing permissions and
91527 * limitations under the License.
91528 * =============================================================================
91529 */
/**
 * WebGL kernel for Dilation2D (NHWC only): runs the dilation program and
 * reshapes the raw program output to convInfo.outShape.
 *
 * @param args.inputs {x, filter}.
 * @param args.attrs  {strides, pad, dilations}.
 * @returns A float32 TensorInfo with shape convInfo.outShape.
 */
function dilation2D(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter } = inputs;
    const { strides, pad, dilations } = attrs;
    const convInfo = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations);
    const rawOut = backend.runWebGLProgram(new Dilation2DProgram(convInfo), [x, filter], 'float32');
    const outReshaped = reshape$3({ inputs: { x: rawOut }, backend, attrs: { shape: convInfo.outShape } });
    backend.disposeIntermediateTensorInfo(rawOut);
    return outReshaped;
}
/** Kernel registration binding Dilation2D to the WebGL backend. */
const dilation2DConfig$1 = {
    kernelName: Dilation2D,
    backendName: 'webgl',
    kernelFunc: dilation2D,
};
91547
91548 /**
91549 * @license
91550 * Copyright 2021 Google LLC. All Rights Reserved.
91551 * Licensed under the Apache License, Version 2.0 (the "License");
91552 * you may not use this file except in compliance with the License.
91553 * You may obtain a copy of the License at
91554 *
91555 * http://www.apache.org/licenses/LICENSE-2.0
91556 *
91557 * Unless required by applicable law or agreed to in writing, software
91558 * distributed under the License is distributed on an "AS IS" BASIS,
91559 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91560 * See the License for the specific language governing permissions and
91561 * limitations under the License.
91562 * =============================================================================
91563 */
/**
 * WebGL kernel for Einsum: evaluates the equation by permuting/expanding
 * each operand to a common dimension order, multiplying operands together,
 * and summing out contracted axes between steps, per the compute path from
 * getEinsumComputePath. Intermediate tensors are tracked and disposed at
 * the end (except the one that becomes the result).
 *
 * @param args.inputs Array of operand tensors.
 * @param args.attrs  {equation} the einsum equation string.
 * @returns The TensorInfo for the einsum result.
 */
function einsum$2(args) {
    const { inputs, backend, attrs } = args;
    const { equation } = attrs;
    const tensors = inputs;
    const { allDims, summedDims, idDims } = decodeEinsumEquation(equation, tensors.length);
    checkEinsumDimSizes(allDims.length, idDims, tensors);
    const { path, steps } = getEinsumComputePath(summedDims, idDims);
    const stepCount = steps.length;
    let product = null;
    let dimsLeft = allDims.length;
    const scratch = [];
    for (let step = 0; step < stepCount; ++step) {
        for (const termId of steps[step]) {
            const { permutationIndices: perm, expandDims: expandAxes } = getEinsumPermutation(dimsLeft, idDims[termId]);
            let operand;
            if (isIdentityPermutation(perm)) {
                // No transpose needed; use the input tensor directly.
                operand = tensors[termId];
            }
            else {
                operand = transpose$2({ inputs: { x: tensors[termId] }, backend, attrs: { perm } });
                scratch.push(operand);
            }
            // Insert singleton axes so every operand has the same rank.
            const expandedShape = operand.shape.slice();
            for (const axis of expandAxes) {
                expandedShape.splice(axis, 0, 1);
            }
            if (!arraysEqual(operand.shape, expandedShape)) {
                operand = reshape$3({ inputs: { x: operand }, backend, attrs: { shape: expandedShape } });
                scratch.push(operand);
            }
            if (product === null) {
                product = operand;
            }
            else {
                // tslint:disable-next-line: no-unnecessary-type-assertion
                product = multiply$3({ inputs: { a: operand, b: product }, backend });
                scratch.push(product);
            }
        }
        if (step < stepCount - 1) {
            if (path[step] >= 0) {
                // Sum out the contracted axis; the axis index is adjusted for
                // dimensions already removed in earlier steps.
                product = sum$4({
                    inputs: { x: product },
                    backend,
                    attrs: {
                        axis: path[step] - (allDims.length - dimsLeft),
                        keepDims: false
                    }
                });
                scratch.push(product);
            }
            dimsLeft--;
        }
    }
    // Clean up intermediate tensors, keeping the final result alive.
    for (const tensorInfo of scratch) {
        if (tensorInfo === product) {
            continue;
        }
        backend.disposeIntermediateTensorInfo(tensorInfo);
    }
    return product;
}
/** Kernel registration binding Einsum to the WebGL backend. */
const einsumConfig$1 = {
    kernelName: Einsum,
    backendName: 'webgl',
    kernelFunc: einsum$2
};
91632
91633 /**
91634 * @license
91635 * Copyright 2020 Google LLC. All Rights Reserved.
91636 * Licensed under the Apache License, Version 2.0 (the "License");
91637 * you may not use this file except in compliance with the License.
91638 * You may obtain a copy of the License at
91639 *
91640 * http://www.apache.org/licenses/LICENSE-2.0
91641 *
91642 * Unless required by applicable law or agreed to in writing, software
91643 * distributed under the License is distributed on an "AS IS" BASIS,
91644 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91645 * See the License for the specific language governing permissions and
91646 * limitations under the License.
91647 * =============================================================================
91648 */
// Unpacked ELU GLSL snippet: f(x) = x for x >= 0, exp(x) - 1 otherwise.
const ELU$3 = `return (x >= 0.0) ? x : (exp(x) - 1.0);`;
// Packed (vec4) ELU snippet: applies the same function to each of the four
// values carried by one texel.
const ELU_PACKED = `
  vec4 result;

  result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);
  result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);
  result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);
  result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);

  return result;
`;
// Elementwise Elu kernel built via the shared unary-kernel helper; the helper
// selects between the packed and unpacked snippets (selection logic lives in
// unaryKernelFunc$1, defined elsewhere in this bundle).
const elu$4 = unaryKernelFunc$1({ opSnippet: ELU$3, packedOpSnippet: ELU_PACKED });
// Kernel registration binding Elu to the WebGL backend.
const eluConfig$1 = {
    kernelName: Elu,
    backendName: 'webgl',
    kernelFunc: elu$4
};
91666
91667 /**
91668 * @license
91669 * Copyright 2020 Google LLC. All Rights Reserved.
91670 * Licensed under the Apache License, Version 2.0 (the "License");
91671 * you may not use this file except in compliance with the License.
91672 * You may obtain a copy of the License at
91673 *
91674 * http://www.apache.org/licenses/LICENSE-2.0
91675 *
91676 * Unless required by applicable law or agreed to in writing, software
91677 * distributed under the License is distributed on an "AS IS" BASIS,
91678 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91679 * See the License for the specific language governing permissions and
91680 * limitations under the License.
91681 * =============================================================================
91682 */
// ELU gradient GLSL snippets. Inputs: a = upstream gradient dy, b = the
// saved forward output y. d(elu)/dx = 1 when y >= 0 and y + 1 when y < 0
// (since y = exp(x) - 1 there), so grad = a for y >= 0, a * (b + 1) otherwise.
//
// BUGFIX: the unpacked snippet previously branched on `b >= 1.0`, which
// returned a * (b + 1.0) for y in [0, 1) instead of a; the packed variant
// below correctly branches on b >= 0. Boundary changed to 0.0 so the two
// code paths agree.
const ELU_DER$1 = `return (b >= 0.0) ? a : a * (b + 1.0);`;
// Packed (vec4) variant: bGTEZero is a per-lane 0/1 mask of b >= 0, used to
// blend between the two gradient formulas without branching.
const ELU_DER_PACKED = `
  vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.)));
  return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))));
`;
/**
 * WebGL kernel for EluGrad: computes the ELU input gradient from the
 * upstream gradient dy and the saved forward output y, choosing the packed
 * binary program when WEBGL_PACK_BINARY_OPERATIONS is enabled.
 */
const eluGrad$1 = (args) => {
    const { inputs, backend } = args;
    const { dy, y } = inputs;
    const usePackedOp = env().getBool('WEBGL_PACK_BINARY_OPERATIONS');
    const program = usePackedOp ?
        new BinaryOpPackedProgram(ELU_DER_PACKED, dy.shape, y.shape) :
        new BinaryOpProgram(ELU_DER$1, dy.shape, y.shape);
    return backend.runWebGLProgram(program, [dy, y], dy.dtype);
};
/** Kernel registration binding EluGrad to the WebGL backend. */
const eluGradConfig$2 = {
    kernelName: EluGrad,
    backendName: 'webgl',
    kernelFunc: eluGrad$1
};
91701
91702 /**
91703 * @license
91704 * Copyright 2020 Google LLC. All Rights Reserved.
91705 * Licensed under the Apache License, Version 2.0 (the "License");
91706 * you may not use this file except in compliance with the License.
91707 * You may obtain a copy of the License at
91708 *
91709 * http://www.apache.org/licenses/LICENSE-2.0
91710 *
91711 * Unless required by applicable law or agreed to in writing, software
91712 * distributed under the License is distributed on an "AS IS" BASIS,
91713 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91714 * See the License for the specific language governing permissions and
91715 * limitations under the License.
91716 * =============================================================================
91717 */
// Packed (vec4) equality snippet: componentwise a == b, yielding 1.0 where
// equal and 0.0 elsewhere.
const PACKED_EQUAL = `
  return vec4(equal(a, b));
`;
// Unpacked equality snippet for single floats.
const EQUAL = `return float(a == b);`;
// Binary Equal kernel producing a bool-dtype tensor; equalImplCPU supplies a
// CPU implementation for the shared helper (when that path is taken is decided
// inside binaryKernelFunc$1, defined elsewhere in this bundle).
const equal$2 = binaryKernelFunc$1({
    opSnippet: EQUAL,
    packedOpSnippet: PACKED_EQUAL,
    dtype: 'bool',
    cpuKernelImpl: equalImplCPU,
});
// Kernel registration binding Equal to the WebGL backend.
const equalConfig$1 = {
    kernelName: Equal,
    backendName: 'webgl',
    kernelFunc: equal$2
};
91733
91734 /**
91735 * @license
91736 * Copyright 2020 Google LLC. All Rights Reserved.
91737 * Licensed under the Apache License, Version 2.0 (the "License");
91738 * you may not use this file except in compliance with the License.
91739 * You may obtain a copy of the License at
91740 *
91741 * http://www.apache.org/licenses/LICENSE-2.0
91742 *
91743 * Unless required by applicable law or agreed to in writing, software
91744 * distributed under the License is distributed on an "AS IS" BASIS,
91745 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91746 * See the License for the specific language governing permissions and
91747 * limitations under the License.
91748 * =============================================================================
91749 */
// GLSL snippet approximating the error function erf(x) with the elementary
// rational/exponential approximation from Abramowitz & Stegun (as cited in
// the shader comment below). The coefficients ERF_P and ERF_A1..ERF_A5 are
// module-level constants defined elsewhere in this bundle. Note erf is odd,
// so the sign is extracted and reapplied around the |x| approximation.
const ERF = `
  // Error function is calculated approximately with elementary function.
  // See "Handbook of Mathematical Functions with Formulas,
  // Graphs, and Mathematical Tables", Abramowitz and Stegun.
  float p = ${ERF_P};
  float a1 = ${ERF_A1};
  float a2 = ${ERF_A2};
  float a3 = ${ERF_A3};
  float a4 = ${ERF_A4};
  float a5 = ${ERF_A5};

  float sign = sign(x);
  x = abs(x);
  float t = 1.0 / (1.0 + p * x);
  return sign * (1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x));
`;
// Elementwise Erf kernel (unpacked snippet only).
const erf$2 = unaryKernelFunc$1({ opSnippet: ERF });
// Kernel registration binding Erf to the WebGL backend.
const erfConfig$1 = {
    kernelName: Erf,
    backendName: 'webgl',
    kernelFunc: erf$2,
};
91772
91773 /**
91774 * @license
91775 * Copyright 2020 Google LLC. All Rights Reserved.
91776 * Licensed under the Apache License, Version 2.0 (the "License");
91777 * you may not use this file except in compliance with the License.
91778 * You may obtain a copy of the License at
91779 *
91780 * http://www.apache.org/licenses/LICENSE-2.0
91781 *
91782 * Unless required by applicable law or agreed to in writing, software
91783 * distributed under the License is distributed on an "AS IS" BASIS,
91784 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91785 * See the License for the specific language governing permissions and
91786 * limitations under the License.
91787 * =============================================================================
91788 */
// Unpacked exp snippet; CHECK_NAN_SNIPPET_UNARY (defined elsewhere in this
// bundle) is prepended so NaN inputs are handled before exp is applied.
const EXP = CHECK_NAN_SNIPPET_UNARY + `
  return exp(x);
`;
// Packed (vec4) exp snippet: computes exp per lane, then restores the
// original NaN input in any lane where x was NaN so NaNs propagate.
const EXP_PACKED = `
  vec4 result = exp(x);
  bvec4 isNaN = isnan(x);
  result.r = isNaN.r ? x.r : result.r;
  result.g = isNaN.g ? x.g : result.g;
  result.b = isNaN.b ? x.b : result.b;
  result.a = isNaN.a ? x.a : result.a;

  return result;
`;
// Elementwise Exp kernel with a float32 output dtype; expImplCPU supplies a
// CPU implementation for the shared helper.
const exp$2 = unaryKernelFunc$1({
    opSnippet: EXP,
    packedOpSnippet: EXP_PACKED,
    cpuKernelImpl: expImplCPU,
    dtype: 'float32',
});
// Kernel registration binding Exp to the WebGL backend.
const expConfig$1 = {
    kernelName: Exp,
    backendName: 'webgl',
    kernelFunc: exp$2
};
91813
91814 /**
91815 * @license
91816 * Copyright 2020 Google LLC. All Rights Reserved.
91817 * Licensed under the Apache License, Version 2.0 (the License);
91818 * you may not use this file except in compliance with the License.
91819 * You may obtain a copy of the License at
91820 *
91821 * http://www.apache.org/licenses/LICENSE-2.0
91822 *
91823 * Unless required by applicable law or agreed to in writing, software
91824 * distributed under the License is distributed on an AS IS BASIS,
91825 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91826 * See the License for the specific language governing permissions and
91827 * limitations under the License.
91828 * =============================================================================
91829 */
/**
 * WebGL kernel for ExpandDims: inserts a singleton dimension at `dim`
 * (negative values count from the end) and returns the reshaped tensor.
 * Implemented purely as a reshape; no data movement on the GPU.
 *
 * @param args.inputs {input} the tensor to expand.
 * @param args.attrs  {dim} insertion axis; may be negative.
 * @returns The reshaped TensorInfo.
 */
function expandDims$3(args) {
    const { inputs, attrs, backend } = args;
    const { dim } = attrs;
    const { input } = inputs;
    const inputRank = input.shape.length;
    let axis = dim;
    if (dim < 0) {
        // Negative value is counted from the tail of rank; validate the
        // lower bound before normalizing.
        assert(-(inputRank + 1) <= dim, () => `Axis must be in the interval [${-(inputRank + 1)}, ${inputRank}]`);
        axis = inputRank + dim + 1;
    }
    const newShape = input.shape.slice();
    newShape.splice(axis, 0, 1);
    return reshape$3({ inputs: { x: input }, backend, attrs: { shape: newShape } });
}
/** Kernel registration binding ExpandDims to the WebGL backend. */
const expandDimsConfig$1 = {
    kernelName: ExpandDims,
    backendName: 'webgl',
    kernelFunc: expandDims$3,
};
91850
91851 /**
91852 * @license
91853 * Copyright 2020 Google LLC. All Rights Reserved.
91854 * Licensed under the Apache License, Version 2.0 (the "License");
91855 * you may not use this file except in compliance with the License.
91856 * You may obtain a copy of the License at
91857 *
91858 * http://www.apache.org/licenses/LICENSE-2.0
91859 *
91860 * Unless required by applicable law or agreed to in writing, software
91861 * distributed under the License is distributed on an "AS IS" BASIS,
91862 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91863 * See the License for the specific language governing permissions and
91864 * limitations under the License.
91865 * =============================================================================
91866 */
// expm1(x) = exp(x) - 1; the same snippet is valid for both the packed and
// unpacked shader paths, so it is passed as both.
const EXPM1 = `return exp(x) - 1.0;`;
// Elementwise Expm1 kernel; expm1ImplCPU supplies a CPU implementation for
// the shared helper.
const expm1$2 = unaryKernelFunc$1({ opSnippet: EXPM1, packedOpSnippet: EXPM1, cpuKernelImpl: expm1ImplCPU });
// Kernel registration binding Expm1 to the WebGL backend.
const expm1Config$1 = {
    kernelName: Expm1,
    backendName: 'webgl',
    kernelFunc: expm1$2
};
91874
91875 /**
91876 * @license
91877 * Copyright 2018 Google LLC. All Rights Reserved.
91878 * Licensed under the Apache License, Version 2.0 (the "License");
91879 * you may not use this file except in compliance with the License.
91880 * You may obtain a copy of the License at
91881 *
91882 * http://www.apache.org/licenses/LICENSE-2.0
91883 *
91884 * Unless required by applicable law or agreed to in writing, software
91885 * distributed under the License is distributed on an "AS IS" BASIS,
91886 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91887 * See the License for the specific language governing permissions and
91888 * limitations under the License.
91889 * =============================================================================
91890 */
class FFTProgram {
    /**
     * Builds a shader that evaluates one component (real or imaginary) of a
     * naive O(n^2) discrete Fourier transform along the inner dimension.
     *
     * @param component Which output component to compute: 'real' or 'imag'.
     * @param inputShape [batch, innerDim] shape of the 2-D complex input.
     * @param inverse Whether to build the inverse transform (positive
     *     exponent; result scaled by 1/innerDim).
     */
    constructor(component, inputShape, inverse) {
        this.variableNames = ['real', 'imag'];
        this.outputShape = inputShape;
        const innerDim = inputShape[1];
        // Forward DFT uses a -2*PI exponent; the inverse uses +2*PI and
        // divides each accumulated term by N.
        const exponentMultiplierSnippet = inverse ? `2.0 * ${Math.PI}` : `-2.0 * ${Math.PI}`;
        const resultDenominator = inverse ? `${innerDim}.0` : '1.0';
        // Complex-multiply combiner for the requested output component.
        const combineSnippets = {
            real: 'return real * expR - imag * expI;',
            imag: 'return real * expI + imag * expR;',
        };
        const opString = combineSnippets[component];
        if (opString === undefined) {
            throw new Error(`FFT component must be either "real" or "imag", got ${component}.`);
        }
        this.userCode = `
      const float exponentMultiplier = ${exponentMultiplierSnippet};

      float unaryOpComplex(float real, float expR, float imag, float expI) {
        ${opString}
      }

      float mulMatDFT(int batch, int index) {
        float indexRatio = float(index) / float(${innerDim});
        float exponentMultiplierTimesIndexRatio =
            exponentMultiplier * indexRatio;

        float result = 0.0;

        for (int i = 0; i < ${innerDim}; i++) {
          // x = (-2|2 * PI / N) * index * i;
          float x = exponentMultiplierTimesIndexRatio * float(i);
          float expR = cos(x);
          float expI = sin(x);
          float real = getReal(batch, i);
          float imag = getImag(batch, i);

          result +=
              unaryOpComplex(real, expR, imag, expI) / ${resultDenominator};
        }

        return result;
      }

      void main() {
        ivec2 coords = getOutputCoords();
        setOutput(mulMatDFT(coords[0], coords[1]));
      }
    `;
    }
}
91944
91945 /**
91946 * @license
91947 * Copyright 2020 Google LLC. All Rights Reserved.
91948 * Licensed under the Apache License, Version 2.0 (the "License");
91949 * you may not use this file except in compliance with the License.
91950 * You may obtain a copy of the License at
91951 *
91952 * http://www.apache.org/licenses/LICENSE-2.0
91953 *
91954 * Unless required by applicable law or agreed to in writing, software
91955 * distributed under the License is distributed on an "AS IS" BASIS,
91956 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91957 * See the License for the specific language governing permissions and
91958 * limitations under the License.
91959 * =============================================================================
91960 */
/**
 * Shared implementation for the FFT/IFFT WebGL kernels.
 *
 * Collapses all outer dimensions of the complex input into one batch
 * dimension, runs one DFT program per output component (real and imag),
 * recombines them into a complex tensor, and reshapes back to the caller's
 * original shape. All intermediate tensors are disposed before returning.
 */
function fftImpl$1(x, inverse, backend) {
    // The complex input's component tensors live on the backend's texData.
    const texData = backend.texData.get(x.dataId);
    const totalSize = sizeFromShape(x.shape);
    // Collapse all outer dimensions to a single batch dimension.
    const innerSize = x.shape[x.shape.length - 1];
    const batch = totalSize / innerSize;
    const input2D = reshape$3({ inputs: { x }, backend, attrs: { shape: [batch, innerSize] } });
    const shape2D = input2D.shape;
    const realProgram = new FFTProgram('real', shape2D, inverse);
    const imagProgram = new FFTProgram('imag', shape2D, inverse);
    // Both programs read the same pair of component tensors.
    const programInputs = [
        {
            dataId: texData.complexTensorInfos.real.dataId,
            dtype: texData.complexTensorInfos.real.dtype,
            shape: shape2D
        },
        {
            dataId: texData.complexTensorInfos.imag.dataId,
            dtype: texData.complexTensorInfos.imag.dtype,
            shape: shape2D
        }
    ];
    const realPart = backend.runWebGLProgram(realProgram, programInputs, 'float32');
    const imagPart = backend.runWebGLProgram(imagProgram, programInputs, 'float32');
    const complexOutput = complex$2({ inputs: { real: realPart, imag: imagPart }, backend });
    backend.disposeIntermediateTensorInfo(realPart);
    backend.disposeIntermediateTensorInfo(imagPart);
    // Restore the caller's original shape before handing the result back.
    const complexOutputReshaped = reshape$3({ inputs: { x: complexOutput }, backend, attrs: { shape: x.shape } });
    backend.disposeIntermediateTensorInfo(input2D);
    backend.disposeIntermediateTensorInfo(complexOutput);
    return complexOutputReshaped;
}
91993
91994 /**
91995 * @license
91996 * Copyright 2020 Google LLC. All Rights Reserved.
91997 * Licensed under the Apache License, Version 2.0 (the "License");
91998 * you may not use this file except in compliance with the License.
91999 * You may obtain a copy of the License at
92000 *
92001 * http://www.apache.org/licenses/LICENSE-2.0
92002 *
92003 * Unless required by applicable law or agreed to in writing, software
92004 * distributed under the License is distributed on an "AS IS" BASIS,
92005 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92006 * See the License for the specific language governing permissions and
92007 * limitations under the License.
92008 * =============================================================================
92009 */
/**
 * Forward FFT kernel for the WebGL backend; delegates to the shared
 * `fftImpl$1` with `inverse` fixed to false.
 */
function fft$2(args) {
    const { input } = args.inputs;
    return fftImpl$1(input, false /* inverse */, args.backend);
}
// WebGL kernel registration for FFT.
const fftConfig$1 = {
    kernelName: FFT,
    backendName: 'webgl',
    kernelFunc: fft$2
};
92020
92021 /**
92022 * @license
92023 * Copyright 2019 Google LLC. All Rights Reserved.
92024 * Licensed under the Apache License, Version 2.0 (the "License");
92025 * you may not use this file except in compliance with the License.
92026 * You may obtain a copy of the License at
92027 *
92028 * http://www.apache.org/licenses/LICENSE-2.0
92029 *
92030 * Unless required by applicable law or agreed to in writing, software
92031 * distributed under the License is distributed on an "AS IS" BASIS,
92032 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92033 * See the License for the specific language governing permissions and
92034 * limitations under the License.
92035 * =============================================================================
92036 */
class FillProgram {
    /**
     * Shader that writes a single scalar to every element of the output.
     *
     * @param shape Output tensor shape.
     * @param value Unused in the shader source; the fill value is supplied
     *     at run time through the `value` custom uniform instead of being
     *     baked into the program (so one compiled program serves all values).
     */
    constructor(shape, value) {
        this.variableNames = ['x'];
        this.customUniforms = [{ name: 'value', type: 'float' }];
        this.outputShape = shape;
        this.userCode = `
      void main() {
        // Input can be obtained from uniform value.
        setOutput(value);
      }
    `;
    }
}
92051
92052 /**
92053 * @license
92054 * Copyright 2020 Google LLC. All Rights Reserved.
92055 * Licensed under the Apache License, Version 2.0 (the "License");
92056 * you may not use this file except in compliance with the License.
92057 * You may obtain a copy of the License at
92058 *
92059 * http://www.apache.org/licenses/LICENSE-2.0
92060 *
92061 * Unless required by applicable law or agreed to in writing, software
92062 * distributed under the License is distributed on an "AS IS" BASIS,
92063 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92064 * See the License for the specific language governing permissions and
92065 * limitations under the License.
92066 * =============================================================================
92067 */
/**
 * Fill kernel: creates a tensor of `attrs.shape` where every element equals
 * `attrs.value`. String fills are materialized on the CPU (strings cannot
 * live in GPU textures); everything else runs a FillProgram with the value
 * passed as a custom uniform.
 */
function fill$2(args) {
    const { backend, attrs } = args;
    const { shape, value } = attrs;
    // Infer the dtype from the value when the caller did not specify one.
    const dtype = attrs.dtype || inferDtype(value);
    if (dtype === 'string') {
        // String type should be handled in CPU memory.
        const values = getArrayFromDType(dtype, sizeFromShape(shape));
        values.fill(value);
        return backend.makeTensorInfo(shape, dtype, values);
    }
    const program = new FillProgram(shape, value);
    // The fill value travels through the program's custom uniform.
    const customValues = [[value]];
    return backend.runWebGLProgram(program, [], dtype, customValues);
}
// WebGL kernel registration for Fill.
const fillConfig$1 = {
    kernelName: Fill,
    backendName: 'webgl',
    kernelFunc: fill$2
};
92090
92091 /**
92092 * @license
92093 * Copyright 2020 Google LLC. All Rights Reserved.
92094 * Licensed under the Apache License, Version 2.0 (the "License");
92095 * you may not use this file except in compliance with the License.
92096 * You may obtain a copy of the License at
92097 *
92098 * http://www.apache.org/licenses/LICENSE-2.0
92099 *
92100 * Unless required by applicable law or agreed to in writing, software
92101 * distributed under the License is distributed on an "AS IS" BASIS,
92102 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92103 * See the License for the specific language governing permissions and
92104 * limitations under the License.
92105 * =============================================================================
92106 */
class FlipLeftRightProgram {
    /**
     * Shader that mirrors a batch of images along the width axis.
     *
     * @param imageShape [batch, height, width, channels] shape; the width
     *     (index 2) is baked into the shader for coordinate reflection.
     */
    constructor(imageShape) {
        this.variableNames = ['Image'];
        this.outputShape = imageShape;
        const imageWidth = imageShape[2];
        this.userCode = `
        void main() {
          ivec4 coords = getOutputCoords();
          int x = coords[2];

          int coordX = ${imageWidth} - x - 1;
          float outputValue;
          if(coordX >= 0 && coordX < ${imageWidth}) {
            outputValue = getImage(coords[0], coords[1], coordX, coords[3]);
          } else {
            outputValue = getImage(coords[0], coords[1], coords[2], coords[3]);
          }
          setOutput(outputValue);
        }
    `;
    }
}
92130
92131 /**
92132 * @license
92133 * Copyright 2020 Google LLC. All Rights Reserved.
92134 * Licensed under the Apache License, Version 2.0 (the "License");
92135 * you may not use this file except in compliance with the License.
92136 * You may obtain a copy of the License at
92137 *
92138 * http://www.apache.org/licenses/LICENSE-2.0
92139 *
92140 * Unless required by applicable law or agreed to in writing, software
92141 * distributed under the License is distributed on an "AS IS" BASIS,
92142 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92143 * See the License for the specific language governing permissions and
92144 * limitations under the License.
92145 * =============================================================================
92146 */
// WebGL kernel registration for FlipLeftRight. The kernel mirrors the input
// image tensor along its width axis via FlipLeftRightProgram.
const flipLeftRightConfig$1 = {
    kernelName: FlipLeftRight,
    backendName: 'webgl',
    kernelFunc: ({ inputs, backend }) => {
        const { image } = inputs;
        const program = new FlipLeftRightProgram(image.shape);
        return backend.runWebGLProgram(program, [image], image.dtype);
    }
};
92158
92159 /**
92160 * @license
92161 * Copyright 2020 Google LLC. All Rights Reserved.
92162 * Licensed under the Apache License, Version 2.0 (the "License");
92163 * you may not use this file except in compliance with the License.
92164 * You may obtain a copy of the License at
92165 *
92166 * http://www.apache.org/licenses/LICENSE-2.0
92167 *
92168 * Unless required by applicable law or agreed to in writing, software
92169 * distributed under the License is distributed on an "AS IS" BASIS,
92170 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92171 * See the License for the specific language governing permissions and
92172 * limitations under the License.
92173 * =============================================================================
92174 */
// GLSL snippet computing floor(x) elementwise; identical for the packed and
// unpacked shader paths.
const FLOOR = `return floor(x);`;
// Unary kernel built from the snippet above, with a CPU fallback impl.
const floor$2 = unaryKernelFunc$1({
    opSnippet: FLOOR,
    packedOpSnippet: FLOOR,
    cpuKernelImpl: floorImplCPU,
});
// WebGL kernel registration for Floor.
const floorConfig$1 = {
    kernelName: Floor,
    backendName: 'webgl',
    kernelFunc: floor$2,
};
92182
92183 /**
92184 * @license
92185 * Copyright 2020 Google LLC. All Rights Reserved.
92186 * Licensed under the Apache License, Version 2.0 (the "License");
92187 * you may not use this file except in compliance with the License.
92188 * You may obtain a copy of the License at
92189 *
92190 * http://www.apache.org/licenses/LICENSE-2.0
92191 *
92192 * Unless required by applicable law or agreed to in writing, software
92193 * distributed under the License is distributed on an "AS IS" BASIS,
92194 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92195 * See the License for the specific language governing permissions and
92196 * limitations under the License.
92197 * =============================================================================
92198 */
// We use native integer division to deal with floating point imprecision. Since
// we implement floor division and glsl implements truncated division, we
// correct for this by subtracting 1 from result when the result is negative and
// there is a remainder.
// Unpacked (one element per texel) snippet. `idiv` is a shared shader helper
// (injected by the shader compiler) that applies the floor correction; `s`
// carries the sign of the quotient so idiv knows when to correct.
const INT_DIV = `
  float s = sign(a) * sign(b);
  int ia = round(a);
  int ib = round(b);
  if (ib != 0) {
    // Windows (D3D) wants guaranteed non-zero int division at compile-time.
    return float(idiv(ia, ib, s));
  } else {
    return NAN;
  }
`;
// Packed (vec4) variant: each of the four lanes is divided independently,
// with a per-lane zero-divisor guard. Lanes with a zero divisor yield 0
// rather than NAN here (the result vector is initialized to 0).
const INT_DIV_PACKED = `
  ivec4 ia = round(a);
  ivec4 ib = round(b);
  bvec4 cond = notEqual(ib, ivec4(0));
  ivec4 result = ivec4(0);
  vec4 s = sign(a) * sign(b);

  // Windows (D3D) wants guaranteed non-zero int division at compile-time.
  if (cond[0]) {
    result[0] = idiv(ia[0], ib[0], s[0]);
  }
  if (cond[1]) {
    result[1] = idiv(ia[1], ib[1], s[1]);
  }
  if (cond[2]) {
    result[2] = idiv(ia[2], ib[2], s[2]);
  }
  if (cond[3]) {
    result[3] = idiv(ia[3], ib[3], s[3]);
  }
  return vec4(result);
`;
// Binary floor-division kernel; output dtype is forced to int32.
const floorDiv$2 = binaryKernelFunc$1({ opSnippet: INT_DIV, packedOpSnippet: INT_DIV_PACKED, dtype: 'int32' });
// WebGL kernel registration for FloorDiv.
const floorDivConfig$1 = {
    kernelName: FloorDiv,
    backendName: 'webgl',
    kernelFunc: floorDiv$2
};
92242
92243 /**
92244 * @license
92245 * Copyright 2018 Google LLC. All Rights Reserved.
92246 * Licensed under the Apache License, Version 2.0 (the "License");
92247 * you may not use this file except in compliance with the License.
92248 * You may obtain a copy of the License at
92249 *
92250 * http://www.apache.org/licenses/LICENSE-2.0
92251 *
92252 * Unless required by applicable law or agreed to in writing, software
92253 * distributed under the License is distributed on an "AS IS" BASIS,
92254 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92255 * See the License for the specific language governing permissions and
92256 * limitations under the License.
92257 * =============================================================================
92258 */
class FromPixelsProgram {
    /**
     * Shader that reads an uploaded RGBA pixel texture and writes one float
     * per output element, selecting the r/g/b/a channel by the depth coord.
     * Pixel values are rescaled from [0, 1] texture space back to [0, 255]
     * integers via floor(v * 255 + 0.5).
     *
     * @param outputShape [height, width, numChannels] of the output tensor.
     */
    constructor(outputShape) {
        this.variableNames = ['A'];
        this.outputShape = outputShape;
        const height = outputShape[0];
        const width = outputShape[1];
        // texture2D differs between WebGL1 and WebGL2; resolve it here.
        const glsl = getGlslDifferences();
        this.userCode = `
      void main() {
        ivec3 coords = getOutputCoords();
        int texR = coords[0];
        int texC = coords[1];
        int depth = coords[2];
        vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${width}.0, ${height}.0);

        vec4 values = ${glsl.texture2D}(A, uv);
        float value;
        if (depth == 0) {
          value = values.r;
        } else if (depth == 1) {
          value = values.g;
        } else if (depth == 2) {
          value = values.b;
        } else if (depth == 3) {
          value = values.a;
        }

        setOutput(floor(value * 255.0 + 0.5));
      }
    `;
    }
}
92290
92291 /**
92292 * @license
92293 * Copyright 2018 Google LLC. All Rights Reserved.
92294 * Licensed under the Apache License, Version 2.0 (the "License");
92295 * you may not use this file except in compliance with the License.
92296 * You may obtain a copy of the License at
92297 *
92298 * http://www.apache.org/licenses/LICENSE-2.0
92299 *
92300 * Unless required by applicable law or agreed to in writing, software
92301 * distributed under the License is distributed on an "AS IS" BASIS,
92302 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92303 * See the License for the specific language governing permissions and
92304 * limitations under the License.
92305 * =============================================================================
92306 */
class FromPixelsPackedProgram {
    /**
     * Packed variant of FromPixelsProgram: each output texel stores a 2x2
     * block of (column, channel) neighbors, so the shader samples the pixel
     * texture up to four times per invocation. Input is an unpacked RGBA
     * texture; output is packed.
     *
     * @param outputShape [height, width, numChannels] of the output tensor.
     */
    constructor(outputShape) {
        this.variableNames = ['A'];
        this.packedInputs = false;
        this.packedOutput = true;
        this.outputShape = outputShape;
        const height = outputShape[0];
        const width = outputShape[1];
        // texture2D differs between WebGL1 and WebGL2; resolve it here.
        const glsl = getGlslDifferences();
        this.userCode = `
      void main() {
        ivec3 coords = getOutputCoords();
        int texR = coords[0];
        int texC = coords[1];
        int depth = coords[2];

        vec4 result = vec4(0.);

        for(int row=0; row<=1; row++) {
          for(int col=0; col<=1; col++) {
            texC = coords[1] + row;
            depth = coords[2] + col;

            vec2 uv = (vec2(texC, texR) + halfCR) /
                       vec2(${width}.0, ${height}.0);
            vec4 values = ${glsl.texture2D}(A, uv);
            float value;
            if (depth == 0) {
              value = values.r;
            } else if (depth == 1) {
              value = values.g;
            } else if (depth == 2) {
              value = values.b;
            } else if (depth == 3) {
              value = values.a;
            }

            result[row * 2 + col] = floor(value * 255.0 + 0.5);
          }
        }

        ${glsl.output} = result;
      }
    `;
    }
}
92352
92353 /**
92354 * @license
92355 * Copyright 2019 Google LLC. All Rights Reserved.
92356 * Licensed under the Apache License, Version 2.0 (the "License");
92357 * you may not use this file except in compliance with the License.
92358 * You may obtain a copy of the License at
92359 *
92360 * http://www.apache.org/licenses/LICENSE-2.0
92361 *
92362 * Unless required by applicable law or agreed to in writing, software
92363 * distributed under the License is distributed on an "AS IS" BASIS,
92364 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92365 * See the License for the specific language governing permissions and
92366 * limitations under the License.
92367 * =============================================================================
92368 */
/**
 * WebGL FromPixels kernel: converts a browser pixel source (image, video,
 * canvas, or ImageData-like object) into an int32 tensor of shape
 * [height, width, numChannels].
 */
const fromPixelsConfig = {
    kernelName: FromPixels,
    backendName: 'webgl',
    kernelFunc: fromPixels$1,
};
// Lazily-created, module-level shared 2D canvas context used to rasterize
// HTMLImageElement / HTMLVideoElement sources before texture upload. Reused
// across calls; its canvas is resized on every use.
let fromPixels2DContext$1;
function fromPixels$1(args) {
    const { inputs, backend, attrs } = args;
    let { pixels } = inputs;
    const { numChannels } = attrs;
    // Guard the instanceof checks so this also loads in environments where
    // the DOM element constructors do not exist (e.g. workers, Node).
    const isVideo = typeof (HTMLVideoElement) !== 'undefined' &&
        pixels instanceof HTMLVideoElement;
    const isImage = typeof (HTMLImageElement) !== 'undefined' &&
        pixels instanceof HTMLImageElement;
    // Videos report their intrinsic size via videoWidth/videoHeight; other
    // sources expose plain width/height properties.
    const [width, height] = isVideo ?
        [
            pixels.videoWidth,
            pixels.videoHeight
        ] :
        [pixels.width, pixels.height];
    const texShape = [height, width];
    const outShape = [height, width, numChannels];
    if (isVideo || isImage) {
        // Draw the element onto the shared canvas so its pixel data can be
        // uploaded to a texture uniformly, regardless of source type.
        if (fromPixels2DContext$1 == null) {
            fromPixels2DContext$1 = document.createElement('canvas').getContext('2d');
        }
        fromPixels2DContext$1.canvas.width = width;
        fromPixels2DContext$1.canvas.height = height;
        fromPixels2DContext$1.drawImage(pixels, 0, 0, width, height);
        pixels = fromPixels2DContext$1.canvas;
    }
    const tempPixelHandle = backend.makeTensorInfo(texShape, 'int32');
    // This is a byte texture with pixels.
    backend.texData.get(tempPixelHandle.dataId).usage = TextureUsage.PIXELS;
    backend.gpgpu.uploadPixelDataToTexture(backend.getTexture(tempPixelHandle.dataId), pixels);
    // NOTE(review): the shaders index channels 0..3, so numChannels is
    // presumably <= 4 here — confirm against the op-level validation.
    const program = env().getBool('WEBGL_PACK') ?
        new FromPixelsPackedProgram(outShape) :
        new FromPixelsProgram(outShape);
    const res = backend.runWebGLProgram(program, [tempPixelHandle], 'int32');
    // The temporary pixel texture is only needed as program input.
    backend.disposeData(tempPixelHandle.dataId);
    return res;
}
92411
92412 /**
92413 * @license
92414 * Copyright 2020 Google LLC. All Rights Reserved.
92415 * Licensed under the Apache License, Version 2.0 (the "License");
92416 * you may not use this file except in compliance with the License.
92417 * You may obtain a copy of the License at
92418 *
92419 * http://www.apache.org/licenses/LICENSE-2.0
92420 *
92421 * Unless required by applicable law or agreed to in writing, software
92422 * distributed under the License is distributed on an "AS IS" BASIS,
92423 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92424 * See the License for the specific language governing permissions and
92425 * limitations under the License.
92426 * =============================================================================
92427 */
/**
 * Fused Conv2D kernel: convolution plus optional bias add and optional
 * activation (including PReLU weights and leakyrelu alpha) executed in as
 * few GPU programs as possible. Chooses between three strategies:
 * a matmul rewrite for 1x1 convolutions, an im2col path for batch-1 inputs
 * (behind the WEBGL_CONV_IM2COL flag), and a direct Conv2DProgram otherwise.
 */
function fusedConv2d(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter, bias, preluActivationWeights } = inputs;
    const { strides, pad, dataFormat, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;
    const $dataFormat = convertConv2DDataFormat(dataFormat);
    const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
    let out;
    // Tensors created along the way; all are disposed before returning.
    const intermediates = [];
    // A 1x1 kernel with unit stride/dilation is exactly a matrix multiply.
    if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
        convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
        convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
        (convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID')) {
        out = conv2dByMatMul({
            x,
            filter,
            convInfo,
            backend,
            bias,
            activation,
            preluActivationWeights,
            leakyreluAlpha
        });
    }
    // im2col path: only enabled by flag, and only for a single-image batch.
    else if (env().getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) {
        out = conv2dWithIm2Row({
            x,
            filter,
            convInfo,
            backend,
            bias,
            activation,
            preluActivationWeights,
            leakyreluAlpha
        });
    }
    else {
        const hasBias = bias != null;
        const hasPreluActivationWeights = preluActivationWeights != null;
        const hasLeakyreluAlpha = activation === 'leakyrelu';
        const fusedActivation = activation ? mapActivationToShaderProgram(activation, false) : null;
        const program = new Conv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
        // NOTE: shadows the outer `inputs` destructured from args.
        const inputs = [x, filter];
        // If the input is a 1-D tensor, align it with the channels.
        //
        // For fusedConv2d, the inputs (x, W, bias, preluActivationWeights) are
        // supposed to be aligned with the dataFormat. The 4-D tensor inputs or
        // scalar inputs are originally aligned, but the 1-D tensor inputs are
        // supposed to be aligned with the channels (only bias and PReLU activation
        // weights could be a 1-D tensor).
        const alignInputWithDataFormat = (input, dataFormat) => {
            if (dataFormat === 'NCHW' && input.shape.length === 1 &&
                input.shape[0] !== 1) {
                // Reshape [C] -> [C, 1, 1] so it broadcasts over H and W in NCHW.
                const alignedInput = reshape$3({
                    inputs: { x: input },
                    backend,
                    attrs: { shape: [input.shape[0], 1, 1] }
                });
                intermediates.push(alignedInput);
                return alignedInput;
            }
            return input;
        };
        if (hasBias) {
            inputs.push(alignInputWithDataFormat(bias, dataFormat));
        }
        if (hasPreluActivationWeights) {
            inputs.push(alignInputWithDataFormat(preluActivationWeights, dataFormat));
        }
        if (hasLeakyreluAlpha) {
            // The leakyrelu alpha scalar is fed to the program as a tensor.
            const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
            inputs.push($leakyreluAlpha);
            intermediates.push($leakyreluAlpha);
        }
        out = backend.runWebGLProgram(program, inputs, 'float32');
    }
    // All three paths produce a conv result; reshape to the canonical
    // output shape, then release every intermediate (including `out`).
    const outReshaped = reshape$3({ inputs: { x: out }, backend, attrs: { shape: convInfo.outShape } });
    intermediates.push(out);
    intermediates.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return outReshaped;
}
const fusedConv2DConfig$1 = {
    kernelName: FusedConv2D,
    backendName: 'webgl',
    kernelFunc: fusedConv2d,
};
92513
92514 /**
92515 * @license
92516 * Copyright 2020 Google LLC. All Rights Reserved.
92517 * Licensed under the Apache License, Version 2.0 (the "License");
92518 * you may not use this file except in compliance with the License.
92519 * You may obtain a copy of the License at
92520 *
92521 * http://www.apache.org/licenses/LICENSE-2.0
92522 *
92523 * Unless required by applicable law or agreed to in writing, software
92524 * distributed under the License is distributed on an "AS IS" BASIS,
92525 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92526 * See the License for the specific language governing permissions and
92527 * limitations under the License.
92528 * =============================================================================
92529 */
/**
 * Fused depthwise Conv2D kernel: depthwise convolution plus optional bias,
 * PReLU weights, and leakyrelu alpha, executed in a single GPU program.
 * Picks the packed program variant when the WEBGL_PACK_DEPTHWISECONV flag
 * is set and the conv shape allows it.
 */
function fusedDepthwiseConv2D$1(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter, bias, preluActivationWeights } = inputs;
    const { strides, pad, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;
    const intermediates = [];
    let $dilations = dilations;
    if ($dilations == null) {
        $dilations = [1, 1];
    }
    assert(eitherStridesOrDilationsAreOne(strides, $dilations), () => 'Error in depthwiseConv2d: Either strides or dilations must be ' +
        `1. Got strides ${strides} and dilations '${$dilations}'`);
    const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true /* depthwise */);
    // The packed program only supports small strides and channel
    // multiplier 1 (outChannels === inChannels).
    const shouldPackDepthwiseConv = env().getBool('WEBGL_PACK_DEPTHWISECONV') &&
        convInfo.strideWidth <= 2 &&
        convInfo.outChannels / convInfo.inChannels === 1;
    // The activation snippet must match the packedness of the program.
    const fusedActivation = activation ?
        mapActivationToShaderProgram(activation, shouldPackDepthwiseConv) :
        null;
    const programInputs = [x, filter];
    const hasBias = bias != null;
    const hasPreluActivationWeights = preluActivationWeights != null;
    const hasLeakyreluAlpha = activation === 'leakyrelu';
    if (hasBias) {
        programInputs.push(bias);
    }
    if (hasPreluActivationWeights) {
        programInputs.push(preluActivationWeights);
    }
    if (hasLeakyreluAlpha) {
        // The leakyrelu alpha scalar is fed to the program as a tensor and
        // disposed after the program runs.
        const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
        programInputs.push($leakyreluAlpha);
        intermediates.push($leakyreluAlpha);
    }
    let program;
    if (shouldPackDepthwiseConv) {
        program = new DepthwiseConvPacked2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
    }
    else {
        program = new DepthwiseConv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
    }
    // Runtime uniforms: padding, strides, dilations and input spatial size.
    const customValues = [
        [convInfo.padInfo.top, convInfo.padInfo.left],
        [convInfo.strideHeight, convInfo.strideWidth],
        [convInfo.dilationHeight, convInfo.dilationWidth],
        [convInfo.inHeight, convInfo.inWidth]
    ];
    const result = backend.runWebGLProgram(program, programInputs, 'float32', customValues);
    intermediates.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return result;
}
const fusedDepthwiseConv2DConfig$1 = {
    kernelName: FusedDepthwiseConv2D,
    backendName: 'webgl',
    kernelFunc: fusedDepthwiseConv2D$1,
};
92585
class GatherNDProgram {
    /**
     * Shader for GatherND over a 2-D view of the data: for each output row,
     * the first `sliceDim` indices are combined with `strides` into a flat
     * row index into `x`, and the addressed slice element is copied out.
     *
     * @param sliceDim Number of index components per slice.
     * @param strides Row strides used to flatten the multi-dim index.
     * @param shape [numSlices, sliceSize] output shape.
     */
    constructor(sliceDim, strides, shape) {
        this.sliceDim = sliceDim;
        this.strides = strides;
        this.variableNames = ['x', 'indices'];
        this.outputShape = shape;
        const stridesType = getCoordsDataType(strides.length);
        const dtype = getCoordsDataType(shape.length);
        // A single stride is a scalar in GLSL; more than one is a vector.
        const strideString = this.sliceDim > 1 ? 'strides[j]' : 'strides';
        this.userCode = `
        ${stridesType} strides = ${stridesType}(${this.strides});
         void main() {
          ${dtype} coords = getOutputCoords();
          int flattenIndex = 0;
          for (int j = 0; j < ${this.sliceDim}; j++) {
            int index = round(getIndices(coords[0], j));
            flattenIndex += index * ${strideString};
          }
          setOutput(getX(flattenIndex, coords[1]));
        }
      `;
    }
}
92609
92610 /**
92611 * @license
92612 * Copyright 2020 Google LLC. All Rights Reserved.
92613 * Licensed under the Apache License, Version 2.0 (the "License");
92614 * you may not use this file except in compliance with the License.
92615 * You may obtain a copy of the License at
92616 *
92617 * http://www.apache.org/licenses/LICENSE-2.0
92618 *
92619 * Unless required by applicable law or agreed to in writing, software
92620 * distributed under the License is distributed on an "AS IS" BASIS,
92621 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92622 * See the License for the specific language governing permissions and
92623 * limitations under the License.
92624 * =============================================================================
92625 */
/**
 * GatherND kernel for the WebGL backend.
 *
 * Gathers slices of `params` addressed by the innermost dimension of
 * `indices`. Falls back to the CPU implementation for string dtypes or when
 * the backend prefers CPU execution.
 *
 * Fix: the CPU-fallback check now runs BEFORE the two GPU reshape
 * intermediates are created. Previously `flattenIndices` and `flattenX`
 * were allocated first and the CPU branch returned early without disposing
 * them, leaking both intermediate TensorInfos on every CPU-path call.
 */
function gatherNd$1(args) {
    const { inputs, backend } = args;
    const { params, indices } = inputs;
    const indicesShape = indices.shape;
    // Innermost indices dimension = number of index components per slice.
    const sliceRank = indicesShape[indicesShape.length - 1];
    const paramsSize = sizeFromShape(params.shape);
    const [resultShape, numSlices, sliceSize, strides] = prepareAndValidate(params, indices);
    if (backend.shouldExecuteOnCPU([params, indices]) ||
        params.dtype === 'string') {
        // CPU path: no GPU intermediates are allocated on this branch.
        const indicesData = backend.readSync(indices.dataId);
        const paramsBuf = backend.bufferSync(params);
        const outValue = gatherNdImplCPU(indicesData, paramsBuf, params.dtype, numSlices, sliceRank, sliceSize, strides, params.shape, paramsSize);
        return backend.makeTensorInfo(resultShape, params.dtype, outValue.values);
    }
    // GPU path: flatten indices to [numSlices, sliceRank] and params to
    // [numRows, sliceSize] so the shader can address rows directly.
    const flattenIndices = reshape$3({ inputs: { x: indices }, backend, attrs: { shape: [numSlices, sliceRank] } });
    const flattenX = reshape$3({
        inputs: { x: params },
        backend,
        attrs: { shape: [(paramsSize / sliceSize), sliceSize] }
    });
    const program = new GatherNDProgram(sliceRank, strides, [numSlices, sliceSize]);
    const res = backend.runWebGLProgram(program, [flattenX, flattenIndices], flattenX.dtype);
    const reshaped = reshape$3({ inputs: { x: res }, backend, attrs: { shape: resultShape } });
    backend.disposeIntermediateTensorInfo(flattenIndices);
    backend.disposeIntermediateTensorInfo(flattenX);
    backend.disposeIntermediateTensorInfo(res);
    return reshaped;
}
const gatherNdConfig$1 = {
    kernelName: GatherNd,
    backendName: 'webgl',
    kernelFunc: gatherNd$1
};
92659
92660 /**
92661 * @license
92662 * Copyright 2017 Google LLC. All Rights Reserved.
92663 * Licensed under the Apache License, Version 2.0 (the "License");
92664 * you may not use this file except in compliance with the License.
92665 * You may obtain a copy of the License at
92666 *
92667 * http://www.apache.org/licenses/LICENSE-2.0
92668 *
92669 * Unless required by applicable law or agreed to in writing, software
92670 * distributed under the License is distributed on an "AS IS" BASIS,
92671 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92672 * See the License for the specific language governing permissions and
92673 * limitations under the License.
92674 * =============================================================================
92675 */
class GatherProgram {
    /**
     * Shader for GatherV2 over inputs flattened to rank 4, gathering along
     * axis 2. Out-of-bounds indices produce 0 (the `inBounds` mask zeroes
     * the gathered value instead of reading out of range).
     *
     * @param aShape Rank-4 shape of the flattened source tensor.
     * @param outputShape Shape of the gathered result.
     */
    constructor(aShape, outputShape) {
        this.variableNames = ['A', 'indices'];
        this.outputShape = outputShape;
        this.rank = outputShape.length;
        const dtype = getCoordsDataType(this.rank);
        // Source coordinates with the axis-2 coordinate replaced by `index`.
        const sourceCoords = getSourceCoords$1(aShape, 2);
        this.userCode = `
      void main() {
        ${dtype} resRC = getOutputCoords();
        int index = int(getIndices(resRC.x, resRC.z));
        float inBounds = (index >= 0) && (index < ${aShape[2]}) ? 1.0 : 0.0;
        setOutput(inBounds * getA(${sourceCoords}));
      }
    `;
    }
}
92693 // The input and output are always flattened into rank 4 tensors.
92694 function getSourceCoords$1(aShape, axis) {
92695 const currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
92696 const sourceCoords = [];
92697 for (let i = 0; i < aShape.length; i++) {
92698 if (i === 2) {
92699 sourceCoords.push('index');
92700 }
92701 else {
92702 sourceCoords.push(`${currentCoords[i]}`);
92703 }
92704 }
92705 return sourceCoords.join();
92706 }
92707
92708 /**
92709 * @license
92710 * Copyright 2020 Google LLC. All Rights Reserved.
92711 * Licensed under the Apache License, Version 2.0 (the "License");
92712 * you may not use this file except in compliance with the License.
92713 * You may obtain a copy of the License at
92714 *
92715 * http://www.apache.org/licenses/LICENSE-2.0
92716 *
92717 * Unless required by applicable law or agreed to in writing, software
92718 * distributed under the License is distributed on an "AS IS" BASIS,
92719 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92720 * See the License for the specific language governing permissions and
92721 * limitations under the License.
92722 * =============================================================================
92723 */
92724 function gatherV2$1(args) {
92725 const { inputs, backend, attrs } = args;
92726 const { x, indices } = inputs;
92727 const { axis, batchDims } = attrs;
92728 const parsedAxis = parseAxisParam(axis, x.shape)[0];
92729 if (env().get('DEBUG')) {
92730 // In debug mode, throw error when any index is out of bound.
92731 // Otherwise, just fill out of bounds with zeroes.
92732 const indicesVals = backend.readSync(indices.dataId);
92733 const axisDim = x.shape[parsedAxis];
92734 for (let i = 0; i < indicesVals.length; ++i) {
92735 const index = indicesVals[i];
92736 assert(index <= axisDim - 1 && index >= 0, () => `GatherV2: the index value ${index} is not in [0, ${axisDim - 1}]`);
92737 }
92738 }
92739 const shapeInfo = collectGatherOpShapeInfo(x, indices, parsedAxis, batchDims);
92740 const indicesSize = sizeFromShape(indices.shape);
92741 const toDispose = [];
92742 const flattenX = reshape$3({
92743 inputs: { x },
92744 backend,
92745 attrs: {
92746 shape: [
92747 shapeInfo.batchSize, shapeInfo.outerSize, shapeInfo.dimSize,
92748 shapeInfo.sliceSize
92749 ]
92750 }
92751 });
92752 const flattenIndex = reshape$3({
92753 inputs: { x: indices },
92754 backend,
92755 attrs: { shape: [shapeInfo.batchSize, indicesSize / shapeInfo.batchSize] }
92756 });
92757 toDispose.push(flattenX);
92758 toDispose.push(flattenIndex);
92759 const flattenOutputShape = [
92760 shapeInfo.batchSize, shapeInfo.outerSize, indicesSize / shapeInfo.batchSize,
92761 shapeInfo.sliceSize
92762 ];
92763 if (backend.shouldExecuteOnCPU([x, indices]) || x.dtype === 'string') {
92764 const indicesBuf = backend.bufferSync(flattenIndex);
92765 const xBuf = backend.bufferSync(flattenX);
92766 const outBuf = gatherV2ImplCPU(xBuf, indicesBuf, flattenOutputShape);
92767 toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
92768 return backend.makeTensorInfo(shapeInfo.outputShape, outBuf.dtype, outBuf.values);
92769 }
92770 const program = new GatherProgram(flattenX.shape, flattenOutputShape);
92771 const res = backend.runWebGLProgram(program, [flattenX, flattenIndex], flattenX.dtype);
92772 toDispose.push(res);
92773 const reshaped = reshape$3({ inputs: { x: res }, backend, attrs: { shape: shapeInfo.outputShape } });
92774 toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
92775 return reshaped;
92776 }
    /** Kernel registration for GatherV2 on the WebGL backend. */
    const gatherV2Config$1 = {
        kernelName: GatherV2,
        backendName: 'webgl',
        kernelFunc: gatherV2$1
    };
92782
92783 /**
92784 * @license
92785 * Copyright 2020 Google LLC. All Rights Reserved.
92786 * Licensed under the Apache License, Version 2.0 (the "License");
92787 * you may not use this file except in compliance with the License.
92788 * You may obtain a copy of the License at
92789 *
92790 * http://www.apache.org/licenses/LICENSE-2.0
92791 *
92792 * Unless required by applicable law or agreed to in writing, software
92793 * distributed under the License is distributed on an "AS IS" BASIS,
92794 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92795 * See the License for the specific language governing permissions and
92796 * limitations under the License.
92797 * =============================================================================
92798 */
    // GLSL snippet for the unpacked Greater op: per-element a > b as 0/1 float.
    const GREATER = `return float(a > b);`;
    // Packed variant: compares four lanes of a texel at once.
    const GREATER_PACKED = `
  return vec4(greaterThan(a, b));
`;
    // Kernel built by the shared binary-kernel helper; falls back to the CPU
    // implementation when appropriate and always yields a bool tensor.
    const greater$3 = binaryKernelFunc$1({
        opSnippet: GREATER,
        packedOpSnippet: GREATER_PACKED,
        cpuKernelImpl: greaterImplCPU,
        dtype: 'bool'
    });
    /** Kernel registration for Greater on the WebGL backend. */
    const greaterConfig$1 = {
        kernelName: Greater,
        backendName: 'webgl',
        kernelFunc: greater$3
    };
92814
92815 /**
92816 * @license
92817 * Copyright 2020 Google LLC. All Rights Reserved.
92818 * Licensed under the Apache License, Version 2.0 (the "License");
92819 * you may not use this file except in compliance with the License.
92820 * You may obtain a copy of the License at
92821 *
92822 * http://www.apache.org/licenses/LICENSE-2.0
92823 *
92824 * Unless required by applicable law or agreed to in writing, software
92825 * distributed under the License is distributed on an "AS IS" BASIS,
92826 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92827 * See the License for the specific language governing permissions and
92828 * limitations under the License.
92829 * =============================================================================
92830 */
    // GLSL snippet for the unpacked GreaterEqual op: a >= b as 0/1 float.
    const GREATER_EQUAL = `return float(a >= b);`;
    // Packed variant: compares four lanes of a texel at once.
    const GREATER_EQUAL_PACKED = `
  return vec4(greaterThanEqual(a, b));
`;
    // Kernel built by the shared binary-kernel helper; bool output dtype.
    const greaterEqual$2 = binaryKernelFunc$1({
        opSnippet: GREATER_EQUAL,
        packedOpSnippet: GREATER_EQUAL_PACKED,
        dtype: 'bool',
        cpuKernelImpl: greaterEqualImplCPU
    });
    /** Kernel registration for GreaterEqual on the WebGL backend. */
    const greaterEqualConfig$1 = {
        kernelName: GreaterEqual,
        backendName: 'webgl',
        kernelFunc: greaterEqual$2
    };
92846
92847 /**
92848 * @license
92849 * Copyright 2020 Google LLC. All Rights Reserved.
92850 * Licensed under the Apache License, Version 2.0 (the "License");
92851 * you may not use this file except in compliance with the License.
92852 * You may obtain a copy of the License at
92853 *
92854 * http://www.apache.org/licenses/LICENSE-2.0
92855 *
92856 * Unless required by applicable law or agreed to in writing, software
92857 * distributed under the License is distributed on an "AS IS" BASIS,
92858 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92859 * See the License for the specific language governing permissions and
92860 * limitations under the License.
92861 * =============================================================================
92862 */
92863 function ifft$2(args) {
92864 const { inputs, backend } = args;
92865 const { input } = inputs;
92866 return fftImpl$1(input, true /* inverse */, backend);
92867 }
92868 const ifftConfig$1 = {
92869 kernelName: IFFT,
92870 backendName: 'webgl',
92871 kernelFunc: ifft$2
92872 };
92873
92874 /**
92875 * @license
92876 * Copyright 2020 Google LLC. All Rights Reserved.
92877 * Licensed under the Apache License, Version 2.0 (the "License");
92878 * you may not use this file except in compliance with the License.
92879 * You may obtain a copy of the License at
92880 *
92881 * http://www.apache.org/licenses/LICENSE-2.0
92882 *
92883 * Unless required by applicable law or agreed to in writing, software
92884 * distributed under the License is distributed on an "AS IS" BASIS,
92885 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92886 * See the License for the specific language governing permissions and
92887 * limitations under the License.
92888 * =============================================================================
92889 */
    // GLSL snippet: 1.0 when x is neither NaN nor infinite, else 0.0.
    const IS_FINITE = `return float(!isnan(x) && !isinf(x));`;
    // Unary kernel built by the shared helper; bool output dtype.
    const isFinite$3 = unaryKernelFunc$1({ opSnippet: IS_FINITE, dtype: 'bool' });
    /** Kernel registration for IsFinite on the WebGL backend. */
    const isFiniteConfig$1 = {
        kernelName: IsFinite,
        backendName: 'webgl',
        kernelFunc: isFinite$3,
    };
92897
92898 /**
92899 * @license
92900 * Copyright 2020 Google LLC. All Rights Reserved.
92901 * Licensed under the Apache License, Version 2.0 (the "License");
92902 * you may not use this file except in compliance with the License.
92903 * You may obtain a copy of the License at
92904 *
92905 * http://www.apache.org/licenses/LICENSE-2.0
92906 *
92907 * Unless required by applicable law or agreed to in writing, software
92908 * distributed under the License is distributed on an "AS IS" BASIS,
92909 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92910 * See the License for the specific language governing permissions and
92911 * limitations under the License.
92912 * =============================================================================
92913 */
    // GLSL snippet: 1.0 when x is +/- infinity, else 0.0.
    const IS_INF = `return float(isinf(x));`;
    // Unary kernel built by the shared helper; bool output dtype.
    const isInf$2 = unaryKernelFunc$1({ opSnippet: IS_INF, dtype: 'bool' });
    /** Kernel registration for IsInf on the WebGL backend. */
    const isInfConfig$1 = {
        kernelName: IsInf,
        backendName: 'webgl',
        kernelFunc: isInf$2,
    };
92921
92922 /**
92923 * @license
92924 * Copyright 2020 Google LLC. All Rights Reserved.
92925 * Licensed under the Apache License, Version 2.0 (the "License");
92926 * you may not use this file except in compliance with the License.
92927 * You may obtain a copy of the License at
92928 *
92929 * http://www.apache.org/licenses/LICENSE-2.0
92930 *
92931 * Unless required by applicable law or agreed to in writing, software
92932 * distributed under the License is distributed on an "AS IS" BASIS,
92933 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92934 * See the License for the specific language governing permissions and
92935 * limitations under the License.
92936 * =============================================================================
92937 */
    // GLSL snippet: 1.0 when x is NaN, else 0.0.
    const IS_NAN = `return float(isnan(x));`;
    // Unary kernel built by the shared helper; bool output dtype.
    const isNaN$3 = unaryKernelFunc$1({ opSnippet: IS_NAN, dtype: 'bool' });
    /** Kernel registration for IsNan on the WebGL backend. */
    const isNaNConfig$1 = {
        kernelName: IsNan,
        backendName: 'webgl',
        kernelFunc: isNaN$3,
    };
92945
92946 /**
92947 * @license
92948 * Copyright 2020 Google LLC. All Rights Reserved.
92949 * Licensed under the Apache License, Version 2.0 (the "License");
92950 * you may not use this file except in compliance with the License.
92951 * You may obtain a copy of the License at
92952 *
92953 * http://www.apache.org/licenses/LICENSE-2.0
92954 *
92955 * Unless required by applicable law or agreed to in writing, software
92956 * distributed under the License is distributed on an "AS IS" BASIS,
92957 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92958 * See the License for the specific language governing permissions and
92959 * limitations under the License.
92960 * =============================================================================
92961 */
    // GLSL snippet for the unpacked Less op: per-element a < b as 0/1 float.
    const LESS = `return float(a < b);`;
    // Packed variant: compares four lanes of a texel at once.
    const LESS_PACKED = `
  return vec4(lessThan(a, b));
`;
    // Kernel built by the shared binary-kernel helper; bool output dtype.
    const less$3 = binaryKernelFunc$1({
        opSnippet: LESS,
        packedOpSnippet: LESS_PACKED,
        cpuKernelImpl: lessImplCPU,
        dtype: 'bool'
    });
    /** Kernel registration for Less on the WebGL backend. */
    const lessConfig$1 = {
        kernelName: Less,
        backendName: 'webgl',
        kernelFunc: less$3
    };
92977
92978 /**
92979 * @license
92980 * Copyright 2020 Google LLC. All Rights Reserved.
92981 * Licensed under the Apache License, Version 2.0 (the "License");
92982 * you may not use this file except in compliance with the License.
92983 * You may obtain a copy of the License at
92984 *
92985 * http://www.apache.org/licenses/LICENSE-2.0
92986 *
92987 * Unless required by applicable law or agreed to in writing, software
92988 * distributed under the License is distributed on an "AS IS" BASIS,
92989 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92990 * See the License for the specific language governing permissions and
92991 * limitations under the License.
92992 * =============================================================================
92993 */
    // GLSL snippet for the unpacked LessEqual op: a <= b as 0/1 float.
    const LESS_EQUAL = `return float(a <= b);`;
    // Packed variant: compares four lanes of a texel at once.
    const LESS_EQUAL_PACKED = `
  return vec4(lessThanEqual(a, b));
`;
    // Kernel built by the shared binary-kernel helper; bool output dtype.
    const lessEqual$2 = binaryKernelFunc$1({
        opSnippet: LESS_EQUAL,
        packedOpSnippet: LESS_EQUAL_PACKED,
        cpuKernelImpl: lessEqualImplCPU,
        dtype: 'bool'
    });
    /** Kernel registration for LessEqual on the WebGL backend. */
    const lessEqualConfig$1 = {
        kernelName: LessEqual,
        backendName: 'webgl',
        kernelFunc: lessEqual$2
    };
93009
93010 /**
93011 * @license
93012 * Copyright 2020 Google LLC. All Rights Reserved.
93013 * Licensed under the Apache License, Version 2.0 (the "License");
93014 * you may not use this file except in compliance with the License.
93015 * You may obtain a copy of the License at
93016 *
93017 * http://www.apache.org/licenses/LICENSE-2.0
93018 *
93019 * Unless required by applicable law or agreed to in writing, software
93020 * distributed under the License is distributed on an "AS IS" BASIS,
93021 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93022 * See the License for the specific language governing permissions and
93023 * limitations under the License.
93024 * =============================================================================
93025 */
93026 function linSpace$1(args) {
93027 const { backend, attrs } = args;
93028 const { start, stop, num } = attrs;
93029 // TODO: Use CPU implementation due to the precision problem in Safari.
93030 const outVals = linSpaceImplCPU(start, stop, num);
93031 return backend.makeTensorInfo([outVals.length], 'float32', outVals);
93032 }
93033 const linSpaceConfig$1 = {
93034 kernelName: LinSpace,
93035 backendName: 'webgl',
93036 kernelFunc: linSpace$1
93037 };
93038
93039 /**
93040 * @license
93041 * Copyright 2020 Google LLC. All Rights Reserved.
93042 * Licensed under the Apache License, Version 2.0 (the "License");
93043 * you may not use this file except in compliance with the License.
93044 * You may obtain a copy of the License at
93045 *
93046 * http://www.apache.org/licenses/LICENSE-2.0
93047 *
93048 * Unless required by applicable law or agreed to in writing, software
93049 * distributed under the License is distributed on an "AS IS" BASIS,
93050 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93051 * See the License for the specific language governing permissions and
93052 * limitations under the License.
93053 * =============================================================================
93054 */
    // Windows Chrome returns 0 for log() of a negative value; explicitly
    // return NaN (0./0.) for negative inputs to keep results consistent
    // across platforms. NaN inputs are handled by CHECK_NAN_SNIPPET_UNARY.
    const LOG = CHECK_NAN_SNIPPET_UNARY + `
  return x < 0.0 ? 0./0. : log(x);
`;
    // Packed variant: log() is applied lane-wise, then each lane is patched
    // to propagate NaN inputs and to yield NaN for negative inputs.
    const LOG_PACKED = `
  vec4 result = log(x);
  bvec4 isNaN = isnan(x);
  result.r = isNaN.r ? x.r : (x.r < 0.0 ? 0./0. : result.r);
  result.g = isNaN.g ? x.g : (x.g < 0.0 ? 0./0. : result.g);
  result.b = isNaN.b ? x.b : (x.b < 0.0 ? 0./0. : result.b);
  result.a = isNaN.a ? x.a : (x.a < 0.0 ? 0./0. : result.a);
  return result;
`;
    // Unary kernel built by the shared helper, with a CPU fallback.
    const log$3 = unaryKernelFunc$1({ opSnippet: LOG, packedOpSnippet: LOG_PACKED, cpuKernelImpl: logImplCPU });
    /** Kernel registration for Log on the WebGL backend. */
    const logConfig$1 = {
        kernelName: Log,
        backendName: 'webgl',
        kernelFunc: log$3
    };
93075
93076 /**
93077 * @license
93078 * Copyright 2020 Google LLC. All Rights Reserved.
93079 * Licensed under the Apache License, Version 2.0 (the "License");
93080 * you may not use this file except in compliance with the License.
93081 * You may obtain a copy of the License at
93082 *
93083 * http://www.apache.org/licenses/LICENSE-2.0
93084 *
93085 * Unless required by applicable law or agreed to in writing, software
93086 * distributed under the License is distributed on an "AS IS" BASIS,
93087 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93088 * See the License for the specific language governing permissions and
93089 * limitations under the License.
93090 * =============================================================================
93091 */
    // GLSL snippet for Log1p: log(1 + x). NaN inputs are handled by the
    // prepended CHECK_NAN_SNIPPET_UNARY.
    const LOG1P = CHECK_NAN_SNIPPET_UNARY + `
  return log(1.0 + x);
`;
    // Unary kernel built by the shared helper (no packed variant, no CPU impl).
    const log1p$2 = unaryKernelFunc$1({ opSnippet: LOG1P });
    /** Kernel registration for Log1p on the WebGL backend. */
    const log1pConfig$1 = {
        kernelName: Log1p,
        backendName: 'webgl',
        kernelFunc: log1p$2,
    };
93101
93102 /**
93103 * @license
93104 * Copyright 2020 Google LLC. All Rights Reserved.
93105 * Licensed under the Apache License, Version 2.0 (the "License");
93106 * you may not use this file except in compliance with the License.
93107 * You may obtain a copy of the License at
93108 *
93109 * http://www.apache.org/licenses/LICENSE-2.0
93110 *
93111 * Unless required by applicable law or agreed to in writing, software
93112 * distributed under the License is distributed on an "AS IS" BASIS,
93113 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93114 * See the License for the specific language governing permissions and
93115 * limitations under the License.
93116 * =============================================================================
93117 */
    // GLSL snippet for LogicalAnd: both operands are "true" when >= 1.0.
    const LOGICAL_AND = `return float(a >= 1.0 && b >= 1.0);`;
    // Packed variant: per-lane AND via multiplication of 0/1 comparison masks.
    const LOGICAL_AND_PACKED = `
  return vec4(
    vec4(greaterThanEqual(a, vec4(1.0))) *
    vec4(greaterThanEqual(b, vec4(1.0))));
`;
    // Kernel built by the shared binary-kernel helper; bool output dtype.
    const logicalAnd$2 = binaryKernelFunc$1({
        opSnippet: LOGICAL_AND,
        packedOpSnippet: LOGICAL_AND_PACKED,
        dtype: 'bool'
    });
    /** Kernel registration for LogicalAnd on the WebGL backend. */
    const logicalAndConfig$1 = {
        kernelName: LogicalAnd,
        backendName: 'webgl',
        kernelFunc: logicalAnd$2
    };
93134
93135 /**
93136 * @license
93137 * Copyright 2020 Google LLC. All Rights Reserved.
93138 * Licensed under the Apache License, Version 2.0 (the "License");
93139 * you may not use this file except in compliance with the License.
93140 * You may obtain a copy of the License at
93141 *
93142 * http://www.apache.org/licenses/LICENSE-2.0
93143 *
93144 * Unless required by applicable law or agreed to in writing, software
93145 * distributed under the License is distributed on an "AS IS" BASIS,
93146 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93147 * See the License for the specific language governing permissions and
93148 * limitations under the License.
93149 * =============================================================================
93150 */
    // GLSL snippet for LogicalNot: input is "true" when >= 1.0; negate it.
    const LOGICAL_NOT = `return float(!(x >= 1.0));`;
    // Unary kernel built by the shared helper.
    const logicalNot$2 = unaryKernelFunc$1({ opSnippet: LOGICAL_NOT });
    /** Kernel registration for LogicalNot on the WebGL backend. */
    const logicalNotConfig$1 = {
        kernelName: LogicalNot,
        backendName: 'webgl',
        kernelFunc: logicalNot$2,
    };
93158
93159 /**
93160 * @license
93161 * Copyright 2020 Google LLC. All Rights Reserved.
93162 * Licensed under the Apache License, Version 2.0 (the "License");
93163 * you may not use this file except in compliance with the License.
93164 * You may obtain a copy of the License at
93165 *
93166 * http://www.apache.org/licenses/LICENSE-2.0
93167 *
93168 * Unless required by applicable law or agreed to in writing, software
93169 * distributed under the License is distributed on an "AS IS" BASIS,
93170 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93171 * See the License for the specific language governing permissions and
93172 * limitations under the License.
93173 * =============================================================================
93174 */
    // GLSL snippet for LogicalOr: either operand is "true" when >= 1.0.
    const LOGICAL_OR = `return float(a >= 1.0 || b >= 1.0);`;
    // Packed variant: per-lane OR via adding 0/1 masks, clamped to 1.0.
    const LOGICAL_OR_PACKED = `
  return min(
    vec4(greaterThanEqual(a, vec4(1.0))) +
    vec4(greaterThanEqual(b, vec4(1.0))),
    vec4(1.0));
`;
    // Kernel built by the shared binary-kernel helper; bool output dtype.
    const logicalOr$2 = binaryKernelFunc$1({ opSnippet: LOGICAL_OR, packedOpSnippet: LOGICAL_OR_PACKED, dtype: 'bool' });
    /** Kernel registration for LogicalOr on the WebGL backend. */
    const logicalOrConfig$1 = {
        kernelName: LogicalOr,
        backendName: 'webgl',
        kernelFunc: logicalOr$2
    };
93188
93189 /**
93190 * @license
93191 * Copyright 2017 Google LLC. All Rights Reserved.
93192 * Licensed under the Apache License, Version 2.0 (the "License");
93193 * you may not use this file except in compliance with the License.
93194 * You may obtain a copy of the License at
93195 *
93196 * http://www.apache.org/licenses/LICENSE-2.0
93197 *
93198 * Unless required by applicable law or agreed to in writing, software
93199 * distributed under the License is distributed on an "AS IS" BASIS,
93200 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93201 * See the License for the specific language governing permissions and
93202 * limitations under the License.
93203 * =============================================================================
93204 */
    /**
     * WebGL program for LRN (local response normalization) over the last
     * dimension of a rank-4 input. For each element, sums the squares of the
     * values within `radius` positions along the last axis, then scales the
     * input by pow(bias + alpha * sum, -beta). The pow is specialized for the
     * common beta values 0.5 (inversesqrt) and 1.0 (reciprocal).
     */
    class LRNProgram {
        constructor(xShape, radius, bias, alpha, beta) {
            this.variableNames = ['x'];
            this.outputShape = [];
            const rad = radius;
            // Last valid index along the normalized (depth) axis.
            const maxD = xShape[3] - 1;
            this.outputShape = xShape;
            // optimize pow(bias + alpha * sum, -beta)
            // src: https://github.com/tensorflow/tensorflow/..
            // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/..
            // tensorflow/core/kernels/mkl_lrn_op.cc#L320
            let powOperator;
            const basis = `float(${bias}) + float(${alpha}) * sum`;
            if (beta === 0.5) {
                powOperator = `inversesqrt(${basis})`;
            }
            else if (beta === 1.0) {
                powOperator = `1.0/(${basis})`;
            }
            else {
                // NOTE(review): the trailing ';' inside this snippet produces a
                // double semicolon ("...;;") when interpolated into userCode
                // below — an empty GLSL statement, harmless but worth confirming.
                powOperator = `exp(log(${basis}) * float(-${beta}));`;
            }
            this.userCode = `
      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int r = coords[1];
        int c = coords[2];
        int d = coords[3];
        float x = getX(b, r, c, d);
        float sum = 0.0;
        for (int j = -${rad}; j <= ${rad}; j++) {
          int idx = d + j;
          if (idx >= 0 && idx <= ${maxD}) {
            float z = getX(b, r, c, idx);
            sum += z * z;
          }
        }
        float val = x * ${powOperator};
        setOutput(val);
      }
    `;
        }
    }
93249
93250 /**
93251 * @license
93252 * Copyright 2019 Google LLC. All Rights Reserved.
93253 * Licensed under the Apache License, Version 2.0 (the "License");
93254 * you may not use this file except in compliance with the License.
93255 * You may obtain a copy of the License at
93256 *
93257 * http://www.apache.org/licenses/LICENSE-2.0
93258 *
93259 * Unless required by applicable law or agreed to in writing, software
93260 * distributed under the License is distributed on an "AS IS" BASIS,
93261 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93262 * See the License for the specific language governing permissions and
93263 * limitations under the License.
93264 * =============================================================================
93265 */
    /**
     * Packed-texture variant of the LRN WebGL program: inputs and outputs are
     * packed four values per texel, so each invocation normalizes a 2x2 patch
     * (two spatial columns x two depth channels). The per-lane squared sums
     * are accumulated with a small rolling cache to avoid re-reading texels.
     * Scaling is the same pow(bias + alpha * sum, -beta) as the unpacked
     * program, specialized for beta == 0.5 and beta == 1.0.
     */
    class LRNPackedProgram {
        constructor(xShape, radius, bias, alpha, beta) {
            this.variableNames = ['x'];
            this.outputShape = [];
            this.packedInputs = true;
            this.packedOutput = true;
            const rad = radius;
            // Last valid index along the normalized (depth) axis.
            const maxD = xShape[3] - 1;
            this.outputShape = xShape;
            // optimize pow(bias + alpha * sum, -beta)
            // src: https://github.com/tensorflow/tensorflow/..
            // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/..
            // tensorflow/core/kernels/mkl_lrn_op.cc#L320
            let powOperator;
            const basis = `float(${bias}) + float(${alpha}) * sum`;
            if (beta === 0.5) {
                powOperator = `inversesqrt(${basis})`;
            }
            else if (beta === 1.0) {
                powOperator = `1.0/(${basis})`;
            }
            else {
                // NOTE(review): trailing ';' yields a double semicolon when
                // interpolated below — an empty GLSL statement, harmless.
                powOperator = `exp(log(${basis}) * float(-${beta}));`;
            }
            this.userCode = `
      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords.x;
        int r = coords.y;
        int c = coords.z;
        int d = coords.w;

        bool hasNextCol = d < ${this.outputShape[3]};
        bool hasNextRow = c < ${this.outputShape[2]};

        vec4 sum = vec4(0.);
        vec4 xFragAtOutputCoords = getX(b, r, c, d);

        vec4 xAtOutputCoords = vec4(
          getChannel(xFragAtOutputCoords, vec2(c, d)),
          hasNextCol ?
            getChannel(xFragAtOutputCoords, vec2(c, d + 1)) : 0.0,
          hasNextRow ?
            getChannel(xFragAtOutputCoords , vec2(c + 1, d)) : 0.0,
          (hasNextRow && hasNextCol) ?
            getChannel(xFragAtOutputCoords, vec2(c + 1, d + 1)) : 0.0
        );

        int firstChannel = d - ${rad};
        vec2 cache = vec2(0.);
        if(firstChannel >= 0){
          vec4 firstChannelFrag = getX(b, r, c, firstChannel);
          cache.x = getChannel(firstChannelFrag, vec2(c, firstChannel));
          if(hasNextRow){
            cache.y = getChannel(firstChannelFrag, vec2(c + 1, firstChannel));
          }
        }

        ivec2 depth = ivec2(d, d + 1);
        for (int j = - ${rad}; j <= ${rad}; j++) {
          ivec2 idx = depth + j;
          bvec2 aboveLowerBound = greaterThanEqual(idx, ivec2(0));
          bvec2 belowUpperBound = lessThanEqual(idx, ivec2(${maxD}));

          bool depthInRange = aboveLowerBound.x && belowUpperBound.x;
          bool depthPlusOneInRange = aboveLowerBound.y && belowUpperBound.y;

          if(depthInRange || depthPlusOneInRange){
            vec4 z = vec4(0.);
            vec4 xFragAtCurrentDepth;
            z.xz = cache.xy;
            if(depthPlusOneInRange && hasNextCol){
              xFragAtCurrentDepth = idx.y != d ?
                getX(b, r, c, idx.y) : xFragAtOutputCoords;
              z.y = getChannel(xFragAtCurrentDepth, vec2(c, idx.y));
              if(hasNextRow){
                z.w = getChannel(xFragAtCurrentDepth, vec2(c + 1, idx.y));
              }
            }
            cache.xy = z.yw;
            sum += z * z;
          }
        }
        vec4 result = xAtOutputCoords * ${powOperator};
        setOutput(result);
      }
    `;
        }
    }
93355
93356 /**
93357 * @license
93358 * Copyright 2020 Google LLC. All Rights Reserved.
93359 * Licensed under the Apache License, Version 2.0 (the "License");
93360 * you may not use this file except in compliance with the License.
93361 * You may obtain a copy of the License at
93362 *
93363 * http://www.apache.org/licenses/LICENSE-2.0
93364 *
93365 * Unless required by applicable law or agreed to in writing, software
93366 * distributed under the License is distributed on an "AS IS" BASIS,
93367 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93368 * See the License for the specific language governing permissions and
93369 * limitations under the License.
93370 * =============================================================================
93371 */
93372 const lrn = (args) => {
93373 const { inputs, backend, attrs } = args;
93374 const { x } = inputs;
93375 const { depthRadius, bias, alpha, beta } = attrs;
93376 const program = env().getBool('WEBGL_PACK_NORMALIZATION') ?
93377 new LRNPackedProgram(x.shape, depthRadius, bias, alpha, beta) :
93378 new LRNProgram(x.shape, depthRadius, bias, alpha, beta);
93379 return backend.runWebGLProgram(program, [x], x.dtype);
93380 };
93381 // tslint:disable-next-line: variable-name
93382 const LRNConfig$1 = {
93383 kernelName: LRN,
93384 backendName: 'webgl',
93385 kernelFunc: lrn
93386 };
93387
93388 /**
93389 * @license
93390 * Copyright 2018 Google LLC. All Rights Reserved.
93391 * Licensed under the Apache License, Version 2.0 (the "License");
93392 * you may not use this file except in compliance with the License.
93393 * You may obtain a copy of the License at
93394 *
93395 * http://www.apache.org/licenses/LICENSE-2.0
93396 *
93397 * Unless required by applicable law or agreed to in writing, software
93398 * distributed under the License is distributed on an "AS IS" BASIS,
93399 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93400 * See the License for the specific language governing permissions and
93401 * limitations under the License.
93402 * =============================================================================
93403 */
    /**
     * WebGL program computing the gradient of LRN with respect to its input.
     * Shader inputs are the forward input (`inputImage`), the forward output
     * (`outputImage`), and the upstream gradient (`dy`); the output has the
     * same shape as the input. For each output position, the shader iterates
     * over all depth channels, recomputes the normalization window sum for
     * each, and accumulates the chain-rule contribution. The continue/break
     * structure emulates a dynamic loop range, which GLSL loop bounds must
     * be constant to express directly.
     */
    class LRNGradProgram {
        constructor(inputShape, depthRadius, bias, alpha, beta) {
            this.variableNames = ['inputImage', 'outputImage', 'dy'];
            this.outputShape = [];
            this.outputShape = inputShape;
            // Size of the depth (last) dimension being normalized over.
            this.depth = inputShape[3];
            this.depthRadius = depthRadius;
            this.bias = bias;
            this.alpha = alpha;
            this.beta = beta;
            this.userCode = `
      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int r = coords[1];
        int c = coords[2];

        float result = 0.0;
        for (int d = 0; d < ${this.depth}; ++d) {
          int depthBegin = int(max(0.0, float(d - ${depthRadius})));
          int depthEnd = int(min(float(${this.depth}),
              float(d + ${depthRadius} + 1)));

          const int MIN_DEPTH_BEGIN = 0;
          const int MAX_DEPTH_END = ${this.depth};

          float norm = 0.0;
          for (int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k) {
            if (k < depthBegin){
              continue;
            }
            else if (k >= depthBegin && k < depthEnd) {
              norm += getInputImage(b, r, c, k) * getInputImage(b, r, c, k);
            }
            else {
              break;
            }
          }

          norm = float(${alpha}) * norm + float(${bias});

          for(int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k){
            if (k < depthBegin){
              continue;
            }
            else if (k >= depthBegin && k < depthEnd){
              float dyi = -2.0 * float(${alpha})
                * float(${beta})
                * getInputImage(b ,r ,c, k) * getOutputImage(b, r, c, d)
                / norm;
              if (k == d) {
                dyi += pow(norm, -1.0 * ${beta});
              }
              if (k == coords[3]) {
                dyi *= getDy(b, r, c, d);
                result += dyi;
              }
            }
            else {
              break;
            }
          }
        }
        setOutput(result);
      }
    `;
        }
    }
93472
93473 /**
93474 * @license
93475 * Copyright 2020 Google LLC. All Rights Reserved.
93476 * Licensed under the Apache License, Version 2.0 (the "License");
93477 * you may not use this file except in compliance with the License.
93478 * You may obtain a copy of the License at
93479 *
93480 * http://www.apache.org/licenses/LICENSE-2.0
93481 *
93482 * Unless required by applicable law or agreed to in writing, software
93483 * distributed under the License is distributed on an "AS IS" BASIS,
93484 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93485 * See the License for the specific language governing permissions and
93486 * limitations under the License.
93487 * =============================================================================
93488 */
93489 const lrnGrad = (args) => {
93490 const { inputs, backend, attrs } = args;
93491 const { x, y, dy } = inputs;
93492 const { depthRadius, bias, alpha, beta } = attrs;
93493 const program = new LRNGradProgram(x.shape, depthRadius, bias, alpha, beta);
93494 return backend.runWebGLProgram(program, [x, y, dy], x.dtype);
93495 };
93496 // tslint:disable-next-line: variable-name
93497 const LRNGradConfig$1 = {
93498 kernelName: LRNGrad,
93499 backendName: 'webgl',
93500 kernelFunc: lrnGrad
93501 };
93502
93503 /**
93504 * @license
93505 * Copyright 2020 Google LLC. All Rights Reserved.
93506 * Licensed under the Apache License, Version 2.0 (the "License");
93507 * you may not use this file except in compliance with the License.
93508 * You may obtain a copy of the License at
93509 *
93510 * http://www.apache.org/licenses/LICENSE-2.0
93511 *
93512 * Unless required by applicable law or agreed to in writing, software
93513 * distributed under the License is distributed on an "AS IS" BASIS,
93514 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93515 * See the License for the specific language governing permissions and
93516 * limitations under the License.
93517 * =============================================================================
93518 */
93519 function maxImpl$1(x, reduceShape, outShape, backend) {
93520 const inSize = sizeFromShape(reduceShape);
93521 const xSize = sizeFromShape(x.shape);
93522 const batchSize = xSize / inSize;
93523 const reshapedInput = reshape$3({ inputs: { x }, attrs: { shape: [batchSize, inSize] }, backend });
93524 const reduced = reduce(reshapedInput, x.dtype, 'max', backend);
93525 const reshapedOutput = reshape$3({ inputs: { x: reduced }, attrs: { shape: outShape }, backend });
93526 backend.disposeIntermediateTensorInfo(reshapedInput);
93527 backend.disposeIntermediateTensorInfo(reduced);
93528 return reshapedOutput;
93529 }
93530
93531 /**
93532 * @license
93533 * Copyright 2020 Google LLC. All Rights Reserved.
93534 * Licensed under the Apache License, Version 2.0 (the "License");
93535 * you may not use this file except in compliance with the License.
93536 * You may obtain a copy of the License at
93537 *
93538 * http://www.apache.org/licenses/LICENSE-2.0
93539 *
93540 * Unless required by applicable law or agreed to in writing, software
93541 * distributed under the License is distributed on an "AS IS" BASIS,
93542 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93543 * See the License for the specific language governing permissions and
93544 * limitations under the License.
93545 * =============================================================================
93546 */
    /**
     * WebGL kernel for Max: reduces `x` by taking the maximum over the axes
     * in `reductionIndices`; `keepDims` keeps reduced axes as size 1.
     *
     * When the reduction axes are not innermost, the input is transposed
     * first so they are. For tensors the backend decides to run on CPU
     * (`shouldExecuteOnCPU`), both the transpose and the reduction operate
     * directly on the raw values instead of dispatching WebGL programs.
     */
    function max$3(args) {
        const { inputs, backend, attrs } = args;
        const { x } = inputs;
        const { reductionIndices, keepDims } = attrs;
        const xRank = x.shape.length;
        const origAxes = parseAxisParam(reductionIndices, x.shape);
        let axes = origAxes;
        const permutedAxes = getAxesPermutation(axes, xRank);
        const maxInputIsTransposed = permutedAxes != null;
        const shouldExecuteOnCPU = backend.shouldExecuteOnCPU([x]);
        let maxInput = x;
        if (maxInputIsTransposed) {
            if (shouldExecuteOnCPU) {
                // CPU path: transpose the raw values without a GPU upload.
                const xTexData = backend.texData.get(maxInput.dataId);
                const values = xTexData.values;
                const newShape = new Array(xRank);
                for (let i = 0; i < newShape.length; i++) {
                    newShape[i] = x.shape[permutedAxes[i]];
                }
                const maxInputValues = transposeImplCPU(values, x.shape, x.dtype, permutedAxes, newShape);
                maxInput = backend.makeTensorInfo(newShape, x.dtype);
                const maxInputData = backend.texData.get(maxInput.dataId);
                maxInputData.values = maxInputValues;
            }
            else {
                maxInput = transposeImpl$1(x, permutedAxes, backend);
            }
            // After the transpose the reduction axes are the innermost ones.
            axes = getInnerMostAxes(axes.length, xRank);
        }
        assertAxesAreInnerMostDims('max', axes, xRank);
        const [maxOutShape, reduceShape] = computeOutAndReduceShapes(maxInput.shape, axes);
        let outShape = maxOutShape;
        if (keepDims) {
            // rather than reshape at the end, set the target shape here.
            outShape = expandShapeToKeepDim(maxOutShape, origAxes);
        }
        let out;
        if (shouldExecuteOnCPU) {
            // CPU path: reduce the raw values and wrap them in a TensorInfo.
            const xTexData = backend.texData.get(maxInput.dataId);
            const values = xTexData.values;
            const outValues = maxImplCPU(values, sizeFromShape(reduceShape), outShape, x.dtype);
            out = backend.makeTensorInfo(outShape, x.dtype);
            const outData = backend.texData.get(out.dataId);
            outData.values = outValues;
        }
        else {
            out = maxImpl$1(maxInput, reduceShape, outShape, backend);
        }
        if (maxInputIsTransposed) {
            // The transposed copy is an intermediate; only `out` is returned.
            backend.disposeIntermediateTensorInfo(maxInput);
        }
        return out;
    }
    /** Registration for the WebGL Max kernel. */
    const maxConfig$1 = {
        kernelName: Max,
        backendName: 'webgl',
        kernelFunc: max$3
    };
93605
93606 /**
93607 * @license
93608 * Copyright 2020 Google LLC. All Rights Reserved.
93609 * Licensed under the Apache License, Version 2.0 (the "License");
93610 * you may not use this file except in compliance with the License.
93611 * You may obtain a copy of the License at
93612 *
93613 * http://www.apache.org/licenses/LICENSE-2.0
93614 *
93615 * Unless required by applicable law or agreed to in writing, software
93616 * distributed under the License is distributed on an "AS IS" BASIS,
93617 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93618 * See the License for the specific language governing permissions and
93619 * limitations under the License.
93620 * =============================================================================
93621 */
    // GLSL snippet for the unpacked Maximum kernel. CHECK_NAN_SNIPPET$1
    // returns NaN early when either operand is NaN (GLSL max() does not
    // guarantee NaN propagation across drivers).
    const MAXIMUM = CHECK_NAN_SNIPPET$1 + `
  return max(a, b);
`;
    // Packed (vec4) variant: compute max lane-wise, build a per-lane NaN
    // mask, then let CHECK_NAN_SNIPPET$2 overwrite NaN lanes in `result`.
    const MAXIMUM_PACKED = `
  vec4 result = vec4(max(a, b));
  vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));
  ` +
        CHECK_NAN_SNIPPET$2 + `
  return result;
`;
    // Binary kernel with a CPU fast path for small tensors.
    const maximum$4 = binaryKernelFunc$1({
        opSnippet: MAXIMUM,
        packedOpSnippet: MAXIMUM_PACKED,
        cpuKernelImpl: maximumImplCPU
    });
    /** Registration for the WebGL Maximum kernel. */
    const maximumConfig$1 = {
        kernelName: Maximum,
        backendName: 'webgl',
        kernelFunc: maximum$4
    };
93642
93643 /**
93644 * @license
93645 * Copyright 2020 Google LLC. All Rights Reserved.
93646 * Licensed under the Apache License, Version 2.0 (the "License");
93647 * you may not use this file except in compliance with the License.
93648 * You may obtain a copy of the License at
93649 *
93650 * http://www.apache.org/licenses/LICENSE-2.0
93651 *
93652 * Unless required by applicable law or agreed to in writing, software
93653 * distributed under the License is distributed on an "AS IS" BASIS,
93654 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93655 * See the License for the specific language governing permissions and
93656 * limitations under the License.
93657 * =============================================================================
93658 */
93659 function maxPool$2(args) {
93660 const { inputs, backend, attrs } = args;
93661 const { x } = inputs;
93662 assertNotComplex$1(x, 'maxPool');
93663 const { filterSize, strides, pad, dimRoundingMode } = attrs;
93664 const dilations = 1;
93665 assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. ' +
93666 `Got strides ${strides} and dilations '${dilations}'`);
93667 const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
93668 if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
93669 arraysEqual(convInfo.inShape, convInfo.outShape)) {
93670 return identity$2({ inputs: { x }, backend });
93671 }
93672 const maxPoolProgram = new Pool2DProgram(convInfo, 'max', false);
93673 return backend.runWebGLProgram(maxPoolProgram, [x], x.dtype);
93674 }
93675 const maxPoolConfig$1 = {
93676 kernelName: MaxPool,
93677 backendName: 'webgl',
93678 kernelFunc: maxPool$2
93679 };
93680
93681 /**
93682 * @license
93683 * Copyright 2020 Google LLC. All Rights Reserved.
93684 * Licensed under the Apache License, Version 2.0 (the "License");
93685 * you may not use this file except in compliance with the License.
93686 * You may obtain a copy of the License at
93687 *
93688 * http://www.apache.org/licenses/LICENSE-2.0
93689 *
93690 * Unless required by applicable law or agreed to in writing, software
93691 * distributed under the License is distributed on an "AS IS" BASIS,
93692 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93693 * See the License for the specific language governing permissions and
93694 * limitations under the License.
93695 * =============================================================================
93696 */
93697 function maxPool3d$1(args) {
93698 const { inputs, backend, attrs } = args;
93699 const { x } = inputs;
93700 const { filterSize, strides, pad, dataFormat, dimRoundingMode } = attrs;
93701 const dilations = [1, 1, 1];
93702 const convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat);
93703 const maxPoolProgram = new Pool3DProgram(convInfo, 'max', false);
93704 return backend.runWebGLProgram(maxPoolProgram, [x], x.dtype);
93705 }
93706 const maxPool3DConfig$1 = {
93707 kernelName: MaxPool3D,
93708 backendName: 'webgl',
93709 kernelFunc: maxPool3d$1
93710 };
93711
93712 /**
93713 * @license
93714 * Copyright 2017 Google LLC. All Rights Reserved.
93715 * Licensed under the Apache License, Version 2.0 (the "License");
93716 * you may not use this file except in compliance with the License.
93717 * You may obtain a copy of the License at
93718 *
93719 * http://www.apache.org/licenses/LICENSE-2.0
93720 *
93721 * Unless required by applicable law or agreed to in writing, software
93722 * distributed under the License is distributed on an "AS IS" BASIS,
93723 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93724 * See the License for the specific language governing permissions and
93725 * limitations under the License.
93726 * =============================================================================
93727 */
    /**
     * GPGPU program computing the gradient of 2D max pooling.
     *
     * Consumes the upstream gradient `dy` and the per-window argmax
     * positions `maxPos` (produced by Pool2DProgram with position tracking)
     * and writes dx with the shape of the forward-pass input: each dy value
     * is routed only to the input location that won its pooling window.
     * The shader body must stay byte-identical to the generated GLSL; all
     * documentation lives out here in the JS.
     */
    class MaxPool2DBackpropProgram {
        constructor(convInfo) {
            this.variableNames = ['dy', 'maxPos'];
            // dx has the shape of the forward-pass input.
            this.outputShape = convInfo.inShape;
            const strideHeight = convInfo.strideHeight;
            const strideWidth = convInfo.strideWidth;
            const dilationHeight = convInfo.dilationHeight;
            const effectiveFilterHeight = convInfo.effectiveFilterHeight;
            const effectiveFilterWidth = convInfo.effectiveFilterWidth;
            // Padding is mirrored for the backward pass (full correlation).
            const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
            const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
            // Highest flattened position inside one pooling window; used to
            // flip maxPos, which indexes the window from the opposite end.
            const lastIndex = effectiveFilterHeight * effectiveFilterWidth - 1;
            this.userCode = `
      const ivec2 pads = ivec2(${padTop}, ${padLeft});

      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int d = coords[3];

        ivec2 dyRCCorner = coords.yz - pads;
        int dyRCorner = dyRCCorner.x;
        int dyCCorner = dyRCCorner.y;

        // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).
        // ? = to be determined. : = across all values in that axis.
        float dotProd = 0.0;
        for (int wR = 0; wR < ${effectiveFilterHeight};
            wR += ${dilationHeight}) {
          float dyR = float(dyRCorner + wR) / ${strideHeight}.0;

          if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
            continue;
          }
          int idyR = int(dyR);

          for (int wC = 0; wC < ${effectiveFilterWidth}; wC++) {
            float dyC = float(dyCCorner + wC) / ${strideWidth}.0;

            if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
                fract(dyC) > 0.0) {
              continue;
            }
            int idyC = int(dyC);

            float dyValue = getDy(b, idyR, idyC, d);
            int maxPosValue = ${lastIndex} - int(getMaxPos(b, idyR, idyC, d));

            // Get the current value, check it against the value from the
            // position matrix.
            int curPosValue = wR * ${effectiveFilterWidth} + wC;
            float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);

            dotProd += dyValue * mask;
          }
        }
        setOutput(dotProd);
      }
    `;
        }
    }
    /**
     * GPGPU program computing the gradient of 3D max pooling.
     *
     * Same scheme as MaxPool2DBackpropProgram, extended with a depth axis:
     * `dy` is routed through the argmax positions `maxPos` so each gradient
     * value lands only on the input element that won its 3D pooling window.
     * The shader body must stay byte-identical to the generated GLSL.
     */
    class MaxPool3DBackpropProgram {
        constructor(convInfo) {
            this.variableNames = ['dy', 'maxPos'];
            // dx has the shape of the forward-pass input.
            this.outputShape = convInfo.inShape;
            const strideDepth = convInfo.strideDepth;
            const strideHeight = convInfo.strideHeight;
            const strideWidth = convInfo.strideWidth;
            const dilationDepth = convInfo.dilationDepth;
            const dilationHeight = convInfo.dilationHeight;
            const dilationWidth = convInfo.dilationWidth;
            const effectiveFilterDepth = convInfo.effectiveFilterDepth;
            const effectiveFilterHeight = convInfo.effectiveFilterHeight;
            const effectiveFilterWidth = convInfo.effectiveFilterWidth;
            // Padding is mirrored for the backward pass (full correlation).
            const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
            const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
            const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
            // Highest flattened position inside one 3D pooling window; used
            // to flip maxPos, which indexes the window from the other end.
            const lastIndex = effectiveFilterDepth * effectiveFilterHeight * effectiveFilterWidth - 1;
            this.userCode = `
      const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});

      void main() {
        ivec5 coords = getOutputCoords();
        int batch = coords.x;
        int ch = coords.u;

        ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;
        int dyDCorner = dyCorner.x;
        int dyRCorner = dyCorner.y;
        int dyCCorner = dyCorner.z;

        // Convolve dy(?, ?, ?, ch) with pos mask(:, :, :, d) to get
        // dx(xD, xR, xC, ch).
        // ? = to be determined. : = across all values in that axis.
        float dotProd = 0.0;

        for (int wD = 0; wD < ${effectiveFilterDepth};
           wD += ${dilationDepth}) {
          float dyD = float(dyDCorner + wD) / ${strideDepth}.0;

          if (dyD < 0.0 || dyD >= ${convInfo.outDepth}.0 || fract(dyD) > 0.0) {
            continue;
          }
          int idyD = int(dyD);

          for (int wR = 0; wR < ${effectiveFilterHeight};
              wR += ${dilationHeight}) {
            float dyR = float(dyRCorner + wR) / ${strideHeight}.0;

            if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 ||
                fract(dyR) > 0.0) {
              continue;
            }
            int idyR = int(dyR);

            for (int wC = 0; wC < ${effectiveFilterWidth};
                wC += ${dilationWidth}) {
              float dyC = float(dyCCorner + wC) / ${strideWidth}.0;

              if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
                  fract(dyC) > 0.0) {
                continue;
              }
              int idyC = int(dyC);

              float dyValue = getDy(batch, idyD, idyR, idyC, ch);
              int maxPosValue = ${lastIndex} -
                  int(getMaxPos(batch, idyD, idyR, idyC, ch));

              // Get the current value, check it against the value from the
              // position matrix.
              int curPosValue =
                  wD * ${effectiveFilterHeight} * ${effectiveFilterWidth} +
                  wR * ${effectiveFilterWidth} + wC;
              float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);

              dotProd += dyValue * mask;
            }
          }
        }
        setOutput(dotProd);
      }
    `;
        }
    }
93873
93874 /**
93875 * @license
93876 * Copyright 2020 Google LLC. All Rights Reserved.
93877 * Licensed under the Apache License, Version 2.0 (the "License");
93878 * you may not use this file except in compliance with the License.
93879 * You may obtain a copy of the License at
93880 *
93881 * http://www.apache.org/licenses/LICENSE-2.0
93882 *
93883 * Unless required by applicable law or agreed to in writing, software
93884 * distributed under the License is distributed on an "AS IS" BASIS,
93885 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93886 * See the License for the specific language governing permissions and
93887 * limitations under the License.
93888 * =============================================================================
93889 */
93890 function maxPool3DGrad$1(args) {
93891 const { inputs, backend, attrs } = args;
93892 const { dy, input } = inputs;
93893 const x = input;
93894 const { filterSize, strides, pad, dimRoundingMode } = attrs;
93895 const dilations = [1, 1, 1];
93896 const convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
93897 const maxPool3dPositionsProgram = new Pool3DProgram(convInfo, 'max', true /* get positions */);
93898 const maxPool3dPositions = backend.runWebGLProgram(maxPool3dPositionsProgram, [x], x.dtype);
93899 const maxPoolBackpropProgram = new MaxPool3DBackpropProgram(convInfo);
93900 const result = backend.runWebGLProgram(maxPoolBackpropProgram, [dy, maxPool3dPositions], x.dtype);
93901 backend.disposeIntermediateTensorInfo(maxPool3dPositions);
93902 return result;
93903 }
93904 const maxPool3DGradConfig$2 = {
93905 kernelName: MaxPool3DGrad,
93906 backendName: 'webgl',
93907 kernelFunc: maxPool3DGrad$1
93908 };
93909
93910 /**
93911 * @license
93912 * Copyright 2020 Google LLC. All Rights Reserved.
93913 * Licensed under the Apache License, Version 2.0 (the "License");
93914 * you may not use this file except in compliance with the License.
93915 * You may obtain a copy of the License at
93916 *
93917 * http://www.apache.org/licenses/LICENSE-2.0
93918 *
93919 * Unless required by applicable law or agreed to in writing, software
93920 * distributed under the License is distributed on an "AS IS" BASIS,
93921 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93922 * See the License for the specific language governing permissions and
93923 * limitations under the License.
93924 * =============================================================================
93925 */
93926 function maxPoolGrad$2(args) {
93927 const { inputs, backend, attrs } = args;
93928 const { dy, input, output } = inputs;
93929 const x = input;
93930 assertNotComplex$1([input, output], 'maxPoolGrad');
93931 const { filterSize, strides, pad, dimRoundingMode } = attrs;
93932 const convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
93933 const getPositions = true;
93934 const maxPoolPositionsProgram = new Pool2DProgram(convInfo, 'max', getPositions);
93935 const maxPoolPositions = backend.runWebGLProgram(maxPoolPositionsProgram, [x], x.dtype);
93936 const maxPoolBackPropProgram = new MaxPool2DBackpropProgram(convInfo);
93937 const result = backend.runWebGLProgram(maxPoolBackPropProgram, [dy, maxPoolPositions], x.dtype);
93938 backend.disposeIntermediateTensorInfo(maxPoolPositions);
93939 return result;
93940 }
93941 const maxPoolGradConfig$2 = {
93942 kernelName: MaxPoolGrad,
93943 backendName: 'webgl',
93944 kernelFunc: maxPoolGrad$2
93945 };
93946
93947 /**
93948 * @license
93949 * Copyright 2020 Google LLC. All Rights Reserved.
93950 * Licensed under the Apache License, Version 2.0 (the "License");
93951 * you may not use this file except in compliance with the License.
93952 * You may obtain a copy of the License at
93953 *
93954 * http://www.apache.org/licenses/LICENSE-2.0
93955 *
93956 * Unless required by applicable law or agreed to in writing, software
93957 * distributed under the License is distributed on an "AS IS" BASIS,
93958 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93959 * See the License for the specific language governing permissions and
93960 * limitations under the License.
93961 * =============================================================================
93962 */
93963 function maxPoolWithArgmaxImpl$1(x, includeBatchInIndex, convInfo, backend) {
93964 let program = new Pool2DProgram(convInfo, 'max', false);
93965 const poolOutput = backend.runWebGLProgram(program, [x], 'float32');
93966 program = new Pool2DProgram(convInfo, 'max', true, true, includeBatchInIndex);
93967 const indexOutput = backend.runWebGLProgram(program, [x], 'float32');
93968 return [poolOutput, indexOutput];
93969 }
93970
93971 /**
93972 * @license
93973 * Copyright 2020 Google LLC. All Rights Reserved.
93974 * Licensed under the Apache License, Version 2.0 (the "License");
93975 * you may not use this file except in compliance with the License.
93976 * You may obtain a copy of the License at
93977 *
93978 * http://www.apache.org/licenses/LICENSE-2.0
93979 *
93980 * Unless required by applicable law or agreed to in writing, software
93981 * distributed under the License is distributed on an "AS IS" BASIS,
93982 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93983 * See the License for the specific language governing permissions and
93984 * limitations under the License.
93985 * =============================================================================
93986 */
93987 const maxPoolWithArgmaxConfig$1 = {
93988 kernelName: MaxPoolWithArgmax,
93989 backendName: 'webgl',
93990 kernelFunc: ({ inputs, attrs, backend }) => {
93991 const { x } = inputs;
93992 const { filterSize, strides, pad, includeBatchInIndex } = attrs;
93993 const webglBackend = backend;
93994 assert(x.shape.length === 4, () => `Error in maxPool: input must be rank 4 but got rank ${x.shape.length}.`);
93995 const dilations = [1, 1];
93996 assert(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. ' +
93997 `Got strides ${strides} and dilations '${dilations}'`);
93998 const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad);
93999 const [result, indexes] = maxPoolWithArgmaxImpl$1(x, includeBatchInIndex, convInfo, webglBackend);
94000 return [result, indexes];
94001 }
94002 };
94003
94004 /**
94005 * @license
94006 * Copyright 2020 Google LLC. All Rights Reserved.
94007 * Licensed under the Apache License, Version 2.0 (the "License");
94008 * you may not use this file except in compliance with the License.
94009 * You may obtain a copy of the License at
94010 *
94011 * http://www.apache.org/licenses/LICENSE-2.0
94012 *
94013 * Unless required by applicable law or agreed to in writing, software
94014 * distributed under the License is distributed on an "AS IS" BASIS,
94015 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94016 * See the License for the specific language governing permissions and
94017 * limitations under the License.
94018 * =============================================================================
94019 */
94020 function meanImpl(x, reduceShape, outShape, backend) {
94021 const inSize = sizeFromShape(reduceShape);
94022 const xSize = sizeFromShape(x.shape);
94023 const batchSize = xSize / inSize;
94024 const reshapedInput = reshape$3({ inputs: { x }, attrs: { shape: [batchSize, inSize] }, backend });
94025 const reduced = reduce(reshapedInput, 'float32', 'mean', backend);
94026 const reshapedOutput = reshape$3({ inputs: { x: reduced }, attrs: { shape: outShape }, backend });
94027 backend.disposeIntermediateTensorInfo(reshapedInput);
94028 backend.disposeIntermediateTensorInfo(reduced);
94029 return reshapedOutput;
94030 }
94031
94032 /**
94033 * @license
94034 * Copyright 2020 Google LLC. All Rights Reserved.
94035 * Licensed under the Apache License, Version 2.0 (the "License");
94036 * you may not use this file except in compliance with the License.
94037 * You may obtain a copy of the License at
94038 *
94039 * http://www.apache.org/licenses/LICENSE-2.0
94040 *
94041 * Unless required by applicable law or agreed to in writing, software
94042 * distributed under the License is distributed on an "AS IS" BASIS,
94043 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94044 * See the License for the specific language governing permissions and
94045 * limitations under the License.
94046 * =============================================================================
94047 */
    /**
     * Registration for the WebGL Mean kernel: reduces `x` by averaging over
     * `axis`; `keepDims` keeps reduced axes as size 1.
     *
     * If the reduction axes are not innermost, the input is transposed
     * first (on the CPU for tensors the backend decides to keep there,
     * otherwise on the GPU). Intermediates are collected and disposed
     * before the result is returned.
     */
    const meanConfig$1 = {
        kernelName: Mean,
        backendName: 'webgl',
        kernelFunc: ({ inputs, attrs, backend }) => {
            const { x } = inputs;
            const { keepDims, axis } = attrs;
            const webglBackend = backend;
            const xRank = x.shape.length;
            const origAxes = parseAxisParam(axis, x.shape);
            let axes = origAxes;
            const permutedAxes = getAxesPermutation(axes, xRank);
            const meanInputIsTransposed = permutedAxes != null;
            const shouldExecuteOnCPU = webglBackend.shouldExecuteOnCPU([x]);
            const intermediates = [];
            let meanInput = x;
            if (meanInputIsTransposed) {
                if (shouldExecuteOnCPU) {
                    // CPU path: transpose the raw values without a GPU upload.
                    const xTexData = webglBackend.texData.get(meanInput.dataId);
                    const values = xTexData.values;
                    const newShape = new Array(xRank);
                    for (let i = 0; i < newShape.length; i++) {
                        newShape[i] = x.shape[permutedAxes[i]];
                    }
                    const meanInputValues = transposeImplCPU(values, x.shape, x.dtype, permutedAxes, newShape);
                    meanInput = webglBackend.makeTensorInfo(newShape, x.dtype);
                    const meanInputData = webglBackend.texData.get(meanInput.dataId);
                    meanInputData.values = meanInputValues;
                }
                else {
                    meanInput = transposeImpl$1(x, permutedAxes, webglBackend);
                }
                intermediates.push(meanInput);
                // After the transpose the reduction axes are the innermost ones.
                axes = getInnerMostAxes(axes.length, xRank);
            }
            // NOTE(review): the 'sum' label looks like a copy from the Sum
            // kernel; it only affects the assertion's error-message text.
            assertAxesAreInnerMostDims('sum', axes, xRank);
            const [meanOutShape, reduceShape] = computeOutAndReduceShapes(meanInput.shape, axes);
            let outShape = meanOutShape;
            if (keepDims) {
                // rather than reshape at the end, set the target shape here.
                outShape = expandShapeToKeepDim(meanOutShape, origAxes);
            }
            const out = meanImpl(meanInput, reduceShape, outShape, webglBackend);
            for (const i of intermediates) {
                webglBackend.disposeIntermediateTensorInfo(i);
            }
            return out;
        }
    };
94096
94097 /**
94098 * @license
94099 * Copyright 2020 Google LLC. All Rights Reserved.
94100 * Licensed under the Apache License, Version 2.0 (the "License");
94101 * you may not use this file except in compliance with the License.
94102 * You may obtain a copy of the License at
94103 *
94104 * http://www.apache.org/licenses/LICENSE-2.0
94105 *
94106 * Unless required by applicable law or agreed to in writing, software
94107 * distributed under the License is distributed on an "AS IS" BASIS,
94108 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94109 * See the License for the specific language governing permissions and
94110 * limitations under the License.
94111 * =============================================================================
94112 */
94113 function min$3(args) {
94114 const { inputs, backend, attrs } = args;
94115 const { x } = inputs;
94116 const { axis, keepDims } = attrs;
94117 const xRank = x.shape.length;
94118 const origAxes = parseAxisParam(axis, x.shape);
94119 let axes = origAxes;
94120 const permutedAxes = getAxesPermutation(axes, xRank);
94121 let permutedX = x;
94122 if (permutedAxes != null) {
94123 permutedX = transpose$2({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
94124 axes = getInnerMostAxes(axes.length, x.shape.length);
94125 }
94126 assertAxesAreInnerMostDims('min', axes, xRank);
94127 const [outShape, reduceShape] = computeOutAndReduceShapes(permutedX.shape, axes);
94128 const inSize = sizeFromShape(reduceShape);
94129 const a2D = reshape$3({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } });
94130 const reduced = reduce(a2D, a2D.dtype, 'min', backend);
94131 let res;
94132 if (keepDims) {
94133 const newShape = expandShapeToKeepDim(outShape, origAxes);
94134 res = reshape$3({ inputs: { x: reduced }, backend, attrs: { shape: newShape } });
94135 }
94136 else {
94137 res = reshape$3({ inputs: { x: reduced }, backend, attrs: { shape: outShape } });
94138 }
94139 backend.disposeIntermediateTensorInfo(a2D);
94140 backend.disposeIntermediateTensorInfo(reduced);
94141 if (permutedAxes != null) {
94142 backend.disposeIntermediateTensorInfo(permutedX);
94143 }
94144 return res;
94145 }
94146 const minConfig$1 = {
94147 kernelName: Min,
94148 backendName: 'webgl',
94149 kernelFunc: min$3
94150 };
94151
94152 /**
94153 * @license
94154 * Copyright 2020 Google LLC. All Rights Reserved.
94155 * Licensed under the Apache License, Version 2.0 (the "License");
94156 * you may not use this file except in compliance with the License.
94157 * You may obtain a copy of the License at
94158 *
94159 * http://www.apache.org/licenses/LICENSE-2.0
94160 *
94161 * Unless required by applicable law or agreed to in writing, software
94162 * distributed under the License is distributed on an "AS IS" BASIS,
94163 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94164 * See the License for the specific language governing permissions and
94165 * limitations under the License.
94166 * =============================================================================
94167 */
    // GLSL snippet for the unpacked Minimum kernel. CHECK_NAN_SNIPPET$1
    // returns NaN early when either operand is NaN (GLSL min() does not
    // guarantee NaN propagation across drivers).
    const MINIMUM = CHECK_NAN_SNIPPET$1 + `
  return min(a, b);
`;
    // Packed (vec4) variant: compute min lane-wise, build a per-lane NaN
    // mask, then let CHECK_NAN_SNIPPET$2 overwrite NaN lanes in `result`.
    const MINIMUM_PACKED = `
  vec4 result = vec4(min(a, b));
  vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));
  ` +
        CHECK_NAN_SNIPPET$2 + `
  return result;
`;
    // Binary kernel with a CPU fast path for small tensors.
    const minimum$4 = binaryKernelFunc$1({
        opSnippet: MINIMUM,
        packedOpSnippet: MINIMUM_PACKED,
        cpuKernelImpl: minimumImplCPU
    });
    /** Registration for the WebGL Minimum kernel. */
    const minimumConfig$1 = {
        kernelName: Minimum,
        backendName: 'webgl',
        kernelFunc: minimum$4
    };
94188
94189 /**
94190 * @license
94191 * Copyright 2020 Google LLC. All Rights Reserved.
94192 * Licensed under the Apache License, Version 2.0 (the "License");
94193 * you may not use this file except in compliance with the License.
94194 * You may obtain a copy of the License at
94195 *
94196 * http://www.apache.org/licenses/LICENSE-2.0
94197 *
94198 * Unless required by applicable law or agreed to in writing, software
94199 * distributed under the License is distributed on an "AS IS" BASIS,
94200 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94201 * See the License for the specific language governing permissions and
94202 * limitations under the License.
94203 * =============================================================================
94204 */
    /**
     * GPGPU program implementing MirrorPad (unpacked): pads each dimension
     * with a mirror image of the input. `mode === 'reflect'` excludes the
     * border element from the mirror (offset 0); `'symmetric'` includes it
     * (offset 1). The rank-1 case gets a simpler scalar shader; higher
     * ranks loop over the coordinate vector. Shader bodies must remain
     * byte-identical to the generated GLSL.
     */
    class MirrorPadProgram {
        constructor(xShape, paddings, mode) {
            this.variableNames = ['x'];
            // Output grows by the before/after padding on every dimension.
            this.outputShape = paddings.map((p, i) => p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */);
            const rank = xShape.length;
            const dtype = getCoordsDataType(rank);
            // start/end delimit the un-padded region in output coordinates.
            const start = paddings.map(p => p[0]).join(',');
            const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
            const unpackedCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank);
            // reflect skips the edge element when mirroring; symmetric repeats it.
            const offset = mode === 'reflect' ? 0 : 1;
            if (rank === 1) {
                this.userCode = `
        int start = ${start};
        int end = ${end};

        void main() {
          int outC = getOutputCoords();
          if (outC < start) {
            outC = start * 2 - outC - ${offset};
          } else if(outC >= end) {
            outC = (end - 1) * 2 - outC + ${offset};
          }
          setOutput(getX(outC - start));
        }
      `;
                return;
            }
            this.userCode = `
      ${dtype} start = ${dtype}(${start});
      ${dtype} end = ${dtype}(${end});

      void main() {
        ${dtype} outC = getOutputCoords();
        for (int i = 0; i < ${rank}; i++) {
          if (outC[i] < start[i]) {
            outC[i] = start[i] * 2 - outC[i] - ${offset};
          } else if(outC[i] >= end[i]) {
            outC[i] = (end[i] - 1) * 2 - outC[i] + ${offset};
          }
        }
        ${dtype} coords = outC - start;
        setOutput(getX(${unpackedCoords}));
      }
    `;
        }
    }
94251
94252 /**
94253 * @license
94254 * Copyright 2020 Google LLC. All Rights Reserved.
94255 * Licensed under the Apache License, Version 2.0 (the "License");
94256 * you may not use this file except in compliance with the License.
94257 * You may obtain a copy of the License at
94258 *
94259 * http://www.apache.org/licenses/LICENSE-2.0
94260 *
94261 * Unless required by applicable law or agreed to in writing, software
94262 * distributed under the License is distributed on an "AS IS" BASIS,
94263 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94264 * See the License for the specific language governing permissions and
94265 * limitations under the License.
94266 * =============================================================================
94267 */
94268 /**
94269 * Example shader code for
94270 * `mirrorPad(tf.tensor1d([1, 2, 3], 'int32'), [[2, 2]], 'reflect')`
94271 * ```
94272 * const int start = int(2);
94273 * const int end = int(5);
94274 *
94275 * void main() {
94276 * int outputLoc = getOutputCoords();
94277 * vec4 result = vec4(0.);
94278 *
94279 * int rc = outputLoc;
94280 *
94281 * int source = rc;
94282 * if (source < start) {
94283 * source = start * 2 - source - 0;
94284 * } else if (source >= end) {
94285 * source = (end - 1) * 2 - source + 0;
94286 * }
94287 * source -= start;
94288 *
94289 * result[0] = getChannel(getX(source), source);
94290 * rc += 1;
94291 * if(rc < 6) {
94292 * int source = rc;
94293 * if (source < start) {
94294 * source = start * 2 - source - 0;
94295 * } else if (source >= end) {
94296 * source = (end - 1) * 2 - source + 0;
94297 * }
94298 * source -= start;
94299 *
94300 * result[1] = getChannel(getX(source), source);
94301 * }
94302 *
94303 * setOutput(result);
94304 * }
94305 * ```
94306 */
    /**
     * Packed (vec4, 2x2 texel lanes) WebGL program for MirrorPad.
     * Each output texel computes up to four neighboring output elements; every
     * element maps its (possibly out-of-range) output coordinate back into the
     * source tensor by mirroring around the padded region's edges.
     */
    class MirrorPadPackedProgram {
        constructor(xShape, paddings, mode) {
            this.variableNames = ['x'];
            this.packedInputs = true;
            this.packedOutput = true;
            // Each dimension grows by its before-pad and after-pad amounts.
            this.outputShape = paddings.map((p, i) => p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */);
            const rank = xShape.length;
            const dtype = getCoordsDataType(rank);
            // Per-dimension start of the un-padded region and its exclusive end,
            // joined into GLSL constructor argument lists (e.g. "2,0").
            const start = paddings.map(p => p[0]).join(',');
            const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
            const coords = getChannels('rc', rank);
            const source = getChannels('source', rank);
            // Guard used before filling the next column lane of the texel.
            const cLimit = `${coords[rank - 1]} < ${this.outputShape[rank - 1]}`;
            const innerDims = rank === 1 ? 'source' : `vec2(${source.slice(-2).join()})`;
            // 'reflect' mirrors without repeating the border element (offset 0);
            // the other mode (symmetric) repeats it (offset 1).
            const offset = mode === 'reflect' ? 0 : 1;
            let mainLoop = '';
            if (rank === 1) {
                // Scalar form: branchy mirroring of a single int coordinate.
                const padSetup = `
        ${dtype} source = rc;
        if (source < start) {
          source = start * 2 - source - ${offset};
        } else if (source >= end) {
          source = (end - 1) * 2 - source + ${offset};
        }
        source -= start;
      `;
                // Only two lanes are meaningful for a 1-D packed output.
                mainLoop = `
        ${dtype} rc = outputLoc;
        ${padSetup}
        result[0] = getChannel(getX(${source.join()}), ${innerDims});
        ${coords[rank - 1]} += 1;
        if(${cLimit}) {
          ${padSetup}
          result[1] = getChannel(getX(${source.join()}), ${innerDims});
        }
      `;
            }
            else {
                // Vector form: branchless mirroring via lessThan/greaterThanEqual
                // masks, blending the original and the two mirrored coordinates.
                const padSetup = `
        ${dtype} source = rc;
        ${dtype} lt = ${dtype}(lessThan(source, start));
        ${dtype} gte = ${dtype}(greaterThanEqual(source, end));
        ${dtype} orig = 1 - (lt + gte);
        source = orig * source +
                lt * (start * 2 - source - ${offset}) +
                gte * ((end - 1) * 2 - source + ${offset});
        source -= start;
      `;
                // Fill all four lanes: advance the innermost coord for lanes 0->1
                // and 2->3, and the next-inner coord for lanes 0/1 -> 2/3.
                mainLoop = `
        ${dtype} rc = outputLoc;
        ${padSetup}
        result[0] = getChannel(getX(${source.join()}), ${innerDims});
        ${coords[rank - 1]} += 1;
        if(${cLimit}) {
          ${padSetup}
          result[1] = getChannel(getX(${source.join()}), ${innerDims});
        }
        rc = outputLoc;
        ${coords[rank - 2]} += 1;
        if(${coords[rank - 2]} < ${this.outputShape[rank - 2]}) {
          ${padSetup}
          result[2] = getChannel(getX(${source.join()}), ${innerDims});
          ${coords[rank - 1]} += 1;
          if(${cLimit}) {
            ${padSetup}
            result[3] = getChannel(getX(${source.join()}), ${innerDims});
          }
        }
      `;
            }
            this.userCode = `
      const ${dtype} start = ${dtype}(${start});
      const ${dtype} end = ${dtype}(${end});

      void main() {
        ${dtype} outputLoc = getOutputCoords();
        vec4 result = vec4(0.);
        ${mainLoop}
        setOutput(result);
      }
    `;
        }
    }
94390
94391 /**
94392 * @license
94393 * Copyright 2020 Google LLC. All Rights Reserved.
94394 * Licensed under the Apache License, Version 2.0 (the "License");
94395 * you may not use this file except in compliance with the License.
94396 * You may obtain a copy of the License at
94397 *
94398 * http://www.apache.org/licenses/LICENSE-2.0
94399 *
94400 * Unless required by applicable law or agreed to in writing, software
94401 * distributed under the License is distributed on an "AS IS" BASIS,
94402 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94403 * See the License for the specific language governing permissions and
94404 * limitations under the License.
94405 * =============================================================================
94406 */
94407 const mirrorPadKernelFunc = ({ inputs, backend, attrs }) => {
94408 const { x } = inputs;
94409 const { paddings, mode } = attrs;
94410 const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
94411 new MirrorPadPackedProgram(x.shape, paddings, mode) :
94412 new MirrorPadProgram(x.shape, paddings, mode);
94413 const output = backend.runWebGLProgram(program, [x], x.dtype);
94414 return output;
94415 };
94416 const mirrorPadConfig$1 = {
94417 kernelName: MirrorPad,
94418 backendName: 'webgl',
94419 kernelFunc: mirrorPadKernelFunc,
94420 };
94421
94422 /**
94423 * @license
94424 * Copyright 2020 Google LLC. All Rights Reserved.
94425 * Licensed under the Apache License, Version 2.0 (the "License");
94426 * you may not use this file except in compliance with the License.
94427 * You may obtain a copy of the License at
94428 *
94429 * http://www.apache.org/licenses/LICENSE-2.0
94430 *
94431 * Unless required by applicable law or agreed to in writing, software
94432 * distributed under the License is distributed on an "AS IS" BASIS,
94433 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94434 * See the License for the specific language governing permissions and
94435 * limitations under the License.
94436 * =============================================================================
94437 */
94438 const MOD = `if (b == 0.0) return NAN;
94439 return mod(a, b);`;
94440 const MOD_PACKED = `
94441 vec4 result = mod(a, b);
94442 vec4 isNaN = vec4(equal(b, vec4(0.0)));
94443 ` +
94444 CHECK_NAN_SNIPPET$2 + `
94445 return result;
94446`;
94447 const mod$2 = binaryKernelFunc$1({
94448 opSnippet: MOD,
94449 packedOpSnippet: MOD_PACKED,
94450 });
94451 const modConfig$1 = {
94452 kernelName: Mod,
94453 backendName: 'webgl',
94454 kernelFunc: mod$2
94455 };
94456
94457 /**
94458 * @license
94459 * Copyright 2017 Google LLC. All Rights Reserved.
94460 * Licensed under the Apache License, Version 2.0 (the "License");
94461 * you may not use this file except in compliance with the License.
94462 * You may obtain a copy of the License at
94463 *
94464 * http://www.apache.org/licenses/LICENSE-2.0
94465 *
94466 * Unless required by applicable law or agreed to in writing, software
94467 * distributed under the License is distributed on an "AS IS" BASIS,
94468 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94469 * See the License for the specific language governing permissions and
94470 * limitations under the License.
94471 * =============================================================================
94472 */
94473 class MultinomialProgram {
94474 constructor(batchSize, numOutcomes, numSamples) {
94475 this.variableNames = ['probs'];
94476 this.customUniforms = [{ name: 'seed', type: 'float' }];
94477 this.outputShape = [batchSize, numSamples];
94478 this.userCode = `
94479 void main() {
94480 ivec2 coords = getOutputCoords();
94481 int batch = coords[0];
94482
94483 float r = random(seed);
94484 float cdf = 0.0;
94485
94486 for (int i = 0; i < ${numOutcomes - 1}; i++) {
94487 cdf += getProbs(batch, i);
94488
94489 if (r < cdf) {
94490 setOutput(float(i));
94491 return;
94492 }
94493 }
94494
94495 // If no other event happened, last event happened.
94496 setOutput(float(${numOutcomes - 1}));
94497 }
94498 `;
94499 }
94500 }
94501
94502 /**
94503 * @license
94504 * Copyright 2020 Google LLC. All Rights Reserved.
94505 * Licensed under the Apache License, Version 2.0 (the "License");
94506 * you may not use this file except in compliance with the License.
94507 * You may obtain a copy of the License at
94508 *
94509 * http://www.apache.org/licenses/LICENSE-2.0
94510 *
94511 * Unless required by applicable law or agreed to in writing, software
94512 * distributed under the License is distributed on an "AS IS" BASIS,
94513 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94514 * See the License for the specific language governing permissions and
94515 * limitations under the License.
94516 * =============================================================================
94517 */
    // Without the equality check div produces 0.9999 for a = b, which when
    // floored can cause errors.
    const DIV = `
if (a == b) {
  return 1.0;
};
return a / b;`;
    // We do the same as in ./binaryop_gpu, with vec4 and ivec4.
    // On Linux, the vectorized implementation produces NaNs when a and b are 0.
    // Each lane is patched to exactly 1.0 when its operands are equal, mirroring
    // the scalar snippet above.
    const DIV_PACKED = `
  // vec4 one = vec4(equal(a, b));
  // return one + (vec4(1.0) - one) * a / b;
  vec4 result = a / b;
  if(a.x == b.x) {
    result.x = 1.;
  }
  if(a.y == b.y) {
    result.y = 1.;
  }
  if(a.z == b.z) {
    result.z = 1.;
  }
  if(a.w == b.w) {
    result.w = 1.;
  }

  return result;
`;
    // checkOutOfBounds: presumably guards packed lanes that fall outside the
    // logical tensor shape — TODO confirm against binaryKernelFunc$1.
    const realDiv = binaryKernelFunc$1({ opSnippet: DIV, packedOpSnippet: DIV_PACKED, checkOutOfBounds: true });
    // Kernel registration record for the WebGL backend.
    const realDivConfig$1 = {
        kernelName: RealDiv,
        backendName: 'webgl',
        kernelFunc: realDiv,
    };
94552
94553 /**
94554 * @license
94555 * Copyright 2020 Google LLC. All Rights Reserved.
94556 * Licensed under the Apache License, Version 2.0 (the "License");
94557 * you may not use this file except in compliance with the License.
94558 * You may obtain a copy of the License at
94559 *
94560 * http://www.apache.org/licenses/LICENSE-2.0
94561 *
94562 * Unless required by applicable law or agreed to in writing, software
94563 * distributed under the License is distributed on an "AS IS" BASIS,
94564 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94565 * See the License for the specific language governing permissions and
94566 * limitations under the License.
94567 * =============================================================================
94568 */
    // Elementwise subtraction; the same snippet serves both the unpacked
    // (scalar) and packed (vec4) shader paths.
    const SUB = 'return a - b;';
    const sub$2 = binaryKernelFunc$1({
        opSnippet: SUB,
        packedOpSnippet: SUB,
        supportsComplex: true, // complex64 inputs are handled as well
        cpuKernelImpl: subImplCPU // small tensors may be forwarded to the CPU impl
    });
    // Kernel registration record for the WebGL backend.
    const subConfig$1 = {
        kernelName: Sub,
        backendName: 'webgl',
        kernelFunc: sub$2
    };
94581
94582 /**
94583 * @license
94584 * Copyright 2020 Google LLC. All Rights Reserved.
94585 * Licensed under the Apache License, Version 2.0 (the "License");
94586 * you may not use this file except in compliance with the License.
94587 * You may obtain a copy of the License at
94588 *
94589 * http://www.apache.org/licenses/LICENSE-2.0
94590 *
94591 * Unless required by applicable law or agreed to in writing, software
94592 * distributed under the License is distributed on an "AS IS" BASIS,
94593 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94594 * See the License for the specific language governing permissions and
94595 * limitations under the License.
94596 * =============================================================================
94597 */
94598 function softmax$3(args) {
94599 const { inputs, backend, attrs } = args;
94600 const { logits } = inputs;
94601 const { dim } = attrs;
94602 const axes = parseAxisParam([dim], logits.shape);
94603 const maxLogit = max$3({
94604 inputs: { x: logits },
94605 backend,
94606 attrs: { reductionIndices: axes, keepDims: false }
94607 });
94608 const expandedShape = expandShapeToKeepDim(maxLogit.shape, axes);
94609 const maxLogitsReshaped = reshape$3({ inputs: { x: maxLogit }, backend, attrs: { shape: expandedShape } });
94610 const a = sub$2({ inputs: { a: logits, b: maxLogitsReshaped }, backend });
94611 const b = exp$2({ inputs: { x: a }, backend });
94612 const sumExp = sum$4({ inputs: { x: b }, backend, attrs: { axis: axes, keepDims: false } });
94613 const sumExpReshaped = reshape$3({ inputs: { x: sumExp }, backend, attrs: { shape: expandedShape } });
94614 const res = realDiv({ inputs: { a: b, b: sumExpReshaped }, backend });
94615 backend.disposeIntermediateTensorInfo(maxLogit);
94616 backend.disposeIntermediateTensorInfo(maxLogitsReshaped);
94617 backend.disposeIntermediateTensorInfo(a);
94618 backend.disposeIntermediateTensorInfo(b);
94619 backend.disposeIntermediateTensorInfo(sumExp);
94620 backend.disposeIntermediateTensorInfo(sumExpReshaped);
94621 return res;
94622 }
94623 const softmaxConfig$1 = {
94624 kernelName: Softmax,
94625 backendName: 'webgl',
94626 kernelFunc: softmax$3
94627 };
94628
94629 /**
94630 * @license
94631 * Copyright 2020 Google LLC. All Rights Reserved.
94632 * Licensed under the Apache License, Version 2.0 (the "License");
94633 * you may not use this file except in compliance with the License.
94634 * You may obtain a copy of the License at
94635 *
94636 * http://www.apache.org/licenses/LICENSE-2.0
94637 *
94638 * Unless required by applicable law or agreed to in writing, software
94639 * distributed under the License is distributed on an "AS IS" BASIS,
94640 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94641 * See the License for the specific language governing permissions and
94642 * limitations under the License.
94643 * =============================================================================
94644 */
94645 function multinomial$2(args) {
94646 const { inputs, backend, attrs } = args;
94647 const { logits } = inputs;
94648 const { numSamples, seed, normalized } = attrs;
94649 const probs = normalized ?
94650 logits :
94651 softmax$3({ inputs: { logits }, backend, attrs: { dim: logits.shape.length - 1 } });
94652 const batchSize = probs.shape[0];
94653 const numOutcomes = probs.shape[1];
94654 const program = new MultinomialProgram(batchSize, numOutcomes, numSamples);
94655 const customValues = [[seed]];
94656 const res = backend.runWebGLProgram(program, [probs], 'int32', customValues);
94657 if (!normalized) {
94658 backend.disposeIntermediateTensorInfo(probs);
94659 }
94660 return res;
94661 }
94662 const multinomialConfig$1 = {
94663 kernelName: Multinomial,
94664 backendName: 'webgl',
94665 kernelFunc: multinomial$2
94666 };
94667
94668 /**
94669 * @license
94670 * Copyright 2020 Google LLC. All Rights Reserved.
94671 * Licensed under the Apache License, Version 2.0 (the "License");
94672 * you may not use this file except in compliance with the License.
94673 * You may obtain a copy of the License at
94674 *
94675 * http://www.apache.org/licenses/LICENSE-2.0
94676 *
94677 * Unless required by applicable law or agreed to in writing, software
94678 * distributed under the License is distributed on an "AS IS" BASIS,
94679 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94680 * See the License for the specific language governing permissions and
94681 * limitations under the License.
94682 * =============================================================================
94683 */
94684 const NEG = CHECK_NAN_SNIPPET + `
94685 return -x;
94686`;
94687 const NEG_PACKED = `
94688 vec4 result = -x;
94689 bvec4 isNaN = isnan(x);
94690
94691 result.r = isNaN.r ? x.r : result.r;
94692 result.g = isNaN.g ? x.g : result.g;
94693 result.b = isNaN.b ? x.b : result.b;
94694 result.a = isNaN.a ? x.a : result.a;
94695
94696 return result;
94697`;
94698 // This doesn't use unaryKernelFunc because negImplCPU is not of type
94699 // SimpleUnaryKernelImplCPU.
94700 function neg$2(args) {
94701 const { inputs, backend } = args;
94702 const { x } = inputs;
94703 if (backend.shouldExecuteOnCPU([x])) {
94704 const xData = backend.texData.get(x.dataId);
94705 const [outValues, newShape] = negImplCPU(xData.values, x.shape, x.dtype);
94706 return backend.makeTensorInfo(newShape, x.dtype, outValues);
94707 }
94708 let program;
94709 if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
94710 program = new UnaryOpPackedProgram(x.shape, NEG_PACKED);
94711 }
94712 else {
94713 program = new UnaryOpProgram(x.shape, NEG);
94714 }
94715 return backend.runWebGLProgram(program, [x], x.dtype);
94716 }
94717 const negConfig$1 = {
94718 kernelName: Neg,
94719 backendName: 'webgl',
94720 kernelFunc: neg$2
94721 };
94722
94723 /**
94724 * @license
94725 * Copyright 2020 Google LLC. All Rights Reserved.
94726 * Licensed under the Apache License, Version 2.0 (the "License");
94727 * you may not use this file except in compliance with the License.
94728 * You may obtain a copy of the License at
94729 *
94730 * http://www.apache.org/licenses/LICENSE-2.0
94731 *
94732 * Unless required by applicable law or agreed to in writing, software
94733 * distributed under the License is distributed on an "AS IS" BASIS,
94734 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94735 * See the License for the specific language governing permissions and
94736 * limitations under the License.
94737 * =============================================================================
94738 */
94739 const nonMaxSuppressionV3Impl$2 = nonMaxSuppressionV3Impl;
94740 function nonMaxSuppressionV3$1(args) {
94741 warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' +
94742 'Call tf.nonMaxSuppressionAsync() instead');
94743 const { inputs, backend, attrs } = args;
94744 const { boxes, scores } = inputs;
94745 const { maxOutputSize, iouThreshold, scoreThreshold } = attrs;
94746 const boxesVals = backend.readSync(boxes.dataId);
94747 const scoresVals = backend.readSync(scores.dataId);
94748 const { selectedIndices } = nonMaxSuppressionV3Impl$2(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold);
94749 return backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices));
94750 }
94751 const nonMaxSuppressionV3Config$1 = {
94752 kernelName: NonMaxSuppressionV3,
94753 backendName: 'webgl',
94754 kernelFunc: nonMaxSuppressionV3$1
94755 };
94756
94757 /**
94758 * @license
94759 * Copyright 2020 Google LLC. All Rights Reserved.
94760 * Licensed under the Apache License, Version 2.0 (the "License");
94761 * you may not use this file except in compliance with the License.
94762 * You may obtain a copy of the License at
94763 *
94764 * http://www.apache.org/licenses/LICENSE-2.0
94765 *
94766 * Unless required by applicable law or agreed to in writing, software
94767 * distributed under the License is distributed on an "AS IS" BASIS,
94768 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94769 * See the License for the specific language governing permissions and
94770 * limitations under the License.
94771 * =============================================================================
94772 */
94773 const nonMaxSuppressionV4Impl$2 = nonMaxSuppressionV4Impl;
94774 function nonMaxSuppressionV4$1(args) {
94775 warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' +
94776 'Call tf.nonMaxSuppressionAsync() instead');
94777 const { inputs, backend, attrs } = args;
94778 const { boxes, scores } = inputs;
94779 const { maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize } = attrs;
94780 const boxesVals = backend.readSync(boxes.dataId);
94781 const scoresVals = backend.readSync(scores.dataId);
94782 const { selectedIndices, validOutputs } = nonMaxSuppressionV4Impl$2(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize);
94783 return [
94784 backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)),
94785 backend.makeTensorInfo([], 'int32', new Int32Array([validOutputs]))
94786 ];
94787 }
94788 const nonMaxSuppressionV4Config$1 = {
94789 kernelName: NonMaxSuppressionV4,
94790 backendName: 'webgl',
94791 kernelFunc: nonMaxSuppressionV4$1
94792 };
94793
94794 /**
94795 * @license
94796 * Copyright 2020 Google LLC. All Rights Reserved.
94797 * Licensed under the Apache License, Version 2.0 (the "License");
94798 * you may not use this file except in compliance with the License.
94799 * You may obtain a copy of the License at
94800 *
94801 * http://www.apache.org/licenses/LICENSE-2.0
94802 *
94803 * Unless required by applicable law or agreed to in writing, software
94804 * distributed under the License is distributed on an "AS IS" BASIS,
94805 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94806 * See the License for the specific language governing permissions and
94807 * limitations under the License.
94808 * =============================================================================
94809 */
94810 const nonMaxSuppressionV5Impl$2 = nonMaxSuppressionV5Impl;
94811 function nonMaxSuppressionV5$1(args) {
94812 warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' +
94813 'Call tf.nonMaxSuppressionAsync() instead');
94814 const { inputs, backend, attrs } = args;
94815 const { boxes, scores } = inputs;
94816 const { maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma } = attrs;
94817 const boxesVals = backend.readSync(boxes.dataId);
94818 const scoresVals = backend.readSync(scores.dataId);
94819 const maxOutputSizeVal = maxOutputSize;
94820 const iouThresholdVal = iouThreshold;
94821 const scoreThresholdVal = scoreThreshold;
94822 const softNmsSigmaVal = softNmsSigma;
94823 const { selectedIndices, selectedScores } = nonMaxSuppressionV5Impl$2(boxesVals, scoresVals, maxOutputSizeVal, iouThresholdVal, scoreThresholdVal, softNmsSigmaVal);
94824 return [
94825 backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)),
94826 backend.makeTensorInfo([selectedScores.length], 'float32', new Float32Array(selectedScores))
94827 ];
94828 }
94829 const nonMaxSuppressionV5Config$1 = {
94830 kernelName: NonMaxSuppressionV5,
94831 backendName: 'webgl',
94832 kernelFunc: nonMaxSuppressionV5$1
94833 };
94834
94835 /**
94836 * @license
94837 * Copyright 2017 Google LLC. All Rights Reserved.
94838 * Licensed under the Apache License, Version 2.0 (the "License");
94839 * you may not use this file except in compliance with the License.
94840 * You may obtain a copy of the License at
94841 *
94842 * http://www.apache.org/licenses/LICENSE-2.0
94843 *
94844 * Unless required by applicable law or agreed to in writing, software
94845 * distributed under the License is distributed on an "AS IS" BASIS,
94846 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94847 * See the License for the specific language governing permissions and
94848 * limitations under the License.
94849 * =============================================================================
94850 */
94851 class OneHotProgram {
94852 constructor(numIndices, depth, onValue, offValue) {
94853 this.variableNames = ['indices'];
94854 this.outputShape = [numIndices, depth];
94855 this.userCode = `
94856 void main() {
94857 ivec2 coords = getOutputCoords();
94858 int index = round(getIndices(coords.x));
94859 setOutput(mix(float(${offValue}), float(${onValue}),
94860 float(index == coords.y)));
94861 }
94862 `;
94863 }
94864 }
94865
94866 /**
94867 * @license
94868 * Copyright 2020 Google LLC. All Rights Reserved.
94869 * Licensed under the Apache License, Version 2.0 (the "License");
94870 * you may not use this file except in compliance with the License.
94871 * You may obtain a copy of the License at
94872 *
94873 * http://www.apache.org/licenses/LICENSE-2.0
94874 *
94875 * Unless required by applicable law or agreed to in writing, software
94876 * distributed under the License is distributed on an "AS IS" BASIS,
94877 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94878 * See the License for the specific language governing permissions and
94879 * limitations under the License.
94880 * =============================================================================
94881 */
94882 const oneHot$3 = (args) => {
94883 const { inputs, backend, attrs } = args;
94884 const { indices } = inputs;
94885 const { depth, onValue, offValue } = attrs;
94886 const indicesSize = sizeFromShape(indices.shape);
94887 const program = new OneHotProgram(indicesSize, depth, onValue, offValue);
94888 const reshaped = reshape$3({ inputs: { x: indices }, backend, attrs: { shape: [indicesSize] } });
94889 const result = backend.runWebGLProgram(program, [reshaped], indices.dtype);
94890 backend.disposeIntermediateTensorInfo(reshaped);
94891 const outShape = [...indices.shape, depth];
94892 const out = reshape$3({ inputs: { x: result }, backend, attrs: { shape: outShape } });
94893 backend.disposeIntermediateTensorInfo(result);
94894 return out;
94895 };
94896 const oneHotConfig$1 = {
94897 kernelName: OneHot,
94898 backendName: 'webgl',
94899 kernelFunc: oneHot$3
94900 };
94901
94902 /**
94903 * @license
94904 * Copyright 2020 Google LLC. All Rights Reserved.
94905 * Licensed under the Apache License, Version 2.0 (the "License");
94906 * you may not use this file except in compliance with the License.
94907 * You may obtain a copy of the License at
94908 *
94909 * http://www.apache.org/licenses/LICENSE-2.0
94910 *
94911 * Unless required by applicable law or agreed to in writing, software
94912 * distributed under the License is distributed on an "AS IS" BASIS,
94913 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94914 * See the License for the specific language governing permissions and
94915 * limitations under the License.
94916 * =============================================================================
94917 */
94918 function zerosLike$3(args) {
94919 const { inputs, backend } = args;
94920 const { x } = inputs;
94921 if (x.dtype === 'complex64') {
94922 const realPart = real$2({ inputs: { input: x }, backend });
94923 const r = zerosLike$3({ inputs: { x: realPart }, backend });
94924 const imagPart = imag$2({ inputs: { input: x }, backend });
94925 const i = zerosLike$3({ inputs: { x: imagPart }, backend });
94926 const result = complex$2({ inputs: { real: r, imag: i }, backend });
94927 backend.disposeIntermediateTensorInfo(realPart);
94928 backend.disposeIntermediateTensorInfo(r);
94929 backend.disposeIntermediateTensorInfo(imagPart);
94930 backend.disposeIntermediateTensorInfo(i);
94931 return result;
94932 }
94933 else {
94934 return fill$2({
94935 attrs: {
94936 shape: x.shape,
94937 dtype: x.dtype,
94938 value: x.dtype === 'string' ? '' : 0
94939 },
94940 backend
94941 });
94942 }
94943 }
94944 const zerosLikeConfig$1 = {
94945 kernelName: ZerosLike,
94946 backendName: 'webgl',
94947 kernelFunc: zerosLike$3
94948 };
94949
94950 /**
94951 * @license
94952 * Copyright 2020 Google LLC. All Rights Reserved.
94953 * Licensed under the Apache License, Version 2.0 (the "License");
94954 * you may not use this file except in compliance with the License.
94955 * You may obtain a copy of the License at
94956 *
94957 * http://www.apache.org/licenses/LICENSE-2.0
94958 *
94959 * Unless required by applicable law or agreed to in writing, software
94960 * distributed under the License is distributed on an "AS IS" BASIS,
94961 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94962 * See the License for the specific language governing permissions and
94963 * limitations under the License.
94964 * =============================================================================
94965 */
94966 function onesLike$3(args) {
94967 const { inputs, backend } = args;
94968 const { x } = inputs;
94969 if (x.dtype === 'string') {
94970 throw new Error('onesLike is not supported under string dtype');
94971 }
94972 else if (x.dtype === 'complex64') {
94973 const realPart = real$2({ inputs: { input: x }, backend });
94974 const r = onesLike$3({ inputs: { x: realPart }, backend });
94975 const imagPart = imag$2({ inputs: { input: x }, backend });
94976 const i = zerosLike$3({ inputs: { x: imagPart }, backend });
94977 const result = complex$2({ inputs: { real: r, imag: i }, backend });
94978 backend.disposeIntermediateTensorInfo(realPart);
94979 backend.disposeIntermediateTensorInfo(r);
94980 backend.disposeIntermediateTensorInfo(imagPart);
94981 backend.disposeIntermediateTensorInfo(i);
94982 return result;
94983 }
94984 else {
94985 // TODO(cais, smilkov): Add WebGL shader for onesLike:
94986 // https://github.com/tensorflow/tfjs/issues/1293
94987 return fill$2({ attrs: { shape: x.shape, dtype: x.dtype, value: 1 }, backend });
94988 }
94989 }
94990 const onesLikeConfig$1 = {
94991 kernelName: OnesLike,
94992 backendName: 'webgl',
94993 kernelFunc: onesLike$3
94994 };
94995
94996 /**
94997 * @license
94998 * Copyright 2020 Google LLC. All Rights Reserved.
94999 * Licensed under the Apache License, Version 2.0 (the "License");
95000 * you may not use this file except in compliance with the License.
95001 * You may obtain a copy of the License at
95002 *
95003 * http://www.apache.org/licenses/LICENSE-2.0
95004 *
95005 * Unless required by applicable law or agreed to in writing, software
95006 * distributed under the License is distributed on an "AS IS" BASIS,
95007 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95008 * See the License for the specific language governing permissions and
95009 * limitations under the License.
95010 * =============================================================================
95011 */
95012 function pack$1(args) {
95013 const { inputs, backend, attrs } = args;
95014 const { axis } = attrs;
95015 if (inputs.length === 1) {
95016 return expandDims$3({ inputs: { input: inputs[0] }, backend, attrs: { dim: axis } });
95017 }
95018 const shape = inputs[0].shape;
95019 const dtype = inputs[0].dtype;
95020 inputs.forEach(t => {
95021 assertShapesMatch(shape, t.shape, 'All tensors passed to stack must have matching shapes');
95022 assert(dtype === t.dtype, () => 'All tensors passed to stack must have matching dtypes');
95023 });
95024 const intermediateTensorInfos = [];
95025 const expandedTensors = inputs.map(t => {
95026 const expandedT = expandDims$3({ inputs: { input: t }, backend, attrs: { dim: axis } });
95027 intermediateTensorInfos.push(expandedT);
95028 return expandedT;
95029 });
95030 const result = concat$2({ inputs: expandedTensors, backend, attrs: { axis } });
95031 intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
95032 return result;
95033 }
95034 const packConfig$1 = {
95035 kernelName: Pack,
95036 backendName: 'webgl',
95037 kernelFunc: pack$1
95038 };
95039
95040 /**
95041 * @license
95042 * Copyright 2017 Google LLC. All Rights Reserved.
95043 * Licensed under the Apache License, Version 2.0 (the "License");
95044 * you may not use this file except in compliance with the License.
95045 * You may obtain a copy of the License at
95046 *
95047 * http://www.apache.org/licenses/LICENSE-2.0
95048 *
95049 * Unless required by applicable law or agreed to in writing, software
95050 * distributed under the License is distributed on an "AS IS" BASIS,
95051 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95052 * See the License for the specific language governing permissions and
95053 * limitations under the License.
95054 * =============================================================================
95055 */
95056 class PadProgram {
95057 constructor(xShape, paddings, constantValue) {
95058 this.variableNames = ['x'];
95059 this.customUniforms = [{ name: 'value', type: 'float' }];
95060 this.outputShape = paddings.map((p, i) => p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */);
95061 const rank = xShape.length;
95062 const type = getCoordsDataType(rank);
95063 const start = paddings.map(p => p[0]).join(',');
95064 const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
95065 const unpackedCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank);
95066 if (rank === 1) {
95067 this.userCode = `
95068 int start = ${start};
95069 int end = ${end};
95070
95071 void main() {
95072 int outC = getOutputCoords();
95073 if (outC < start || outC >= end) {
95074 setOutput(value);
95075 } else {
95076 setOutput(getX(outC - start));
95077 }
95078 }
95079 `;
95080 return;
95081 }
95082 this.userCode = `
95083 ${type} start = ${type}(${start});
95084 ${type} end = ${type}(${end});
95085
95086 void main() {
95087 ${type} outC = getOutputCoords();
95088 if (any(lessThan(outC, start)) || any(greaterThanEqual(outC, end))) {
95089 setOutput(value);
95090 } else {
95091 ${type} coords = outC - start;
95092 setOutput(getX(${unpackedCoords}));
95093 }
95094 }
95095 `;
95096 }
95097 }
95098
95099 /**
95100 * @license
95101 * Copyright 2019 Google LLC. All Rights Reserved.
95102 * Licensed under the Apache License, Version 2.0 (the "License");
95103 * you may not use this file except in compliance with the License.
95104 * You may obtain a copy of the License at
95105 *
95106 * http://www.apache.org/licenses/LICENSE-2.0
95107 *
95108 * Unless required by applicable law or agreed to in writing, software
95109 * distributed under the License is distributed on an "AS IS" BASIS,
95110 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95111 * See the License for the specific language governing permissions and
95112 * limitations under the License.
95113 * =============================================================================
95114 */
    /**
     * Packed (2x2 texel) variant of the pad program: each invocation fills a
     * vec4 covering up to four adjacent output positions. `paddings[i] =
     * [before, after]` grows output dim i; positions outside the original
     * tensor are set to the run-time `value` uniform.
     */
    class PadPackedProgram {
        constructor(xShape, paddings, constantValue) {
            this.variableNames = ['x'];
            this.packedInputs = true;
            this.packedOutput = true;
            this.customUniforms = [{ name: 'value', type: 'float' }];
            this.outputShape = paddings.map((p, i) => p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */);
            const rank = xShape.length;
            const dtype = getCoordsDataType(rank);
            // Per-axis start of the un-padded region and its (exclusive) end.
            const start = paddings.map(p => p[0]).join(',');
            const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
            const coords = getChannels('rc', rank);
            const source = getChannels('source', rank);
            // Guard used when stepping to the next column of the packed cell.
            const cLimit = `${coords[rank - 1]} < ${this.outputShape[rank - 1]}`;
            const innerDims = rank === 1 ? 'source' : `vec2(${source.slice(-2).join()})`;
            // GLSL fragments that advance `rc` to each of the (up to) four
            // components of the packed output cell before the shared body runs.
            const componentSetup = [
                `${dtype} rc = outputLoc;`, `${coords[rank - 1]} += 1;
       if(${cLimit}) {
      `,
                rank === 1 ? '' : `}
       rc = outputLoc;
       ${coords[rank - 2]} += 1;
       if(${coords[rank - 2]} < ${this.outputShape[rank - 2]}) {`,
                rank === 1 ? '' : ` ${coords[rank - 1]} += 1;
        if(${cLimit}) {`
            ];
            const paddingArea = rank === 1 ?
                'rc < start || rc >= end' :
                'any(lessThan(rc, start)) || any(greaterThanEqual(rc, end))';
            let mainLoop = '';
            // 1D cells hold 2 components; higher ranks hold 4.
            for (let i = 0, j = rank === 1 ? 2 : 4; i < j; i++) {
                mainLoop += `
        ${componentSetup[i]}
        if (${paddingArea}) {
          result[${i}] = float(value);
        } else {
          ${dtype} source = rc - start;
          result[${i}] = getChannel(getX(${source.join()}), ${innerDims});
        }
      `;
            }
            // Close the bounds-check braces opened inside componentSetup.
            mainLoop += (rank === 1 ? `} ` : `}}`);
            this.userCode = `
      const ${dtype} start = ${dtype}(${start});
      const ${dtype} end = ${dtype}(${end});

      void main() {
        ${dtype} outputLoc = getOutputCoords();
        vec4 result = vec4(0.);
        ${mainLoop}
        setOutput(result);
      }
    `;
        }
    }
95170
95171 /**
95172 * @license
95173 * Copyright 2020 Google LLC. All Rights Reserved.
95174 * Licensed under the Apache License, Version 2.0 (the "License");
95175 * you may not use this file except in compliance with the License.
95176 * You may obtain a copy of the License at
95177 *
95178 * http://www.apache.org/licenses/LICENSE-2.0
95179 *
95180 * Unless required by applicable law or agreed to in writing, software
95181 * distributed under the License is distributed on an "AS IS" BASIS,
95182 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95183 * See the License for the specific language governing permissions and
95184 * limitations under the License.
95185 * =============================================================================
95186 */
95187 const padV2$1 = (args) => {
95188 const { inputs, backend, attrs } = args;
95189 const { x } = inputs;
95190 const { paddings, constantValue } = attrs;
95191 if (sizeFromShape(x.shape) === 0) {
95192 // Short-circuit the computation, since x doesn't have value, only
95193 // the shape is used to compute output shape to pad.
95194 const outputShape = paddings.map((p, i) => p[0] /* beforePad */ + x.shape[i] + p[1] /* afterPad */);
95195 return fill$2({
95196 backend,
95197 attrs: { shape: outputShape, value: constantValue, dtype: x.dtype }
95198 });
95199 }
95200 const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
95201 new PadPackedProgram(x.shape, paddings, constantValue) :
95202 new PadProgram(x.shape, paddings, constantValue);
95203 const customValues = [[constantValue]];
95204 return backend.runWebGLProgram(program, [x], x.dtype, customValues);
95205 };
95206 const padV2Config$1 = {
95207 kernelName: PadV2,
95208 backendName: 'webgl',
95209 kernelFunc: padV2$1
95210 };
95211
95212 /**
95213 * @license
95214 * Copyright 2020 Google LLC. All Rights Reserved.
95215 * Licensed under the Apache License, Version 2.0 (the "License");
95216 * you may not use this file except in compliance with the License.
95217 * You may obtain a copy of the License at
95218 *
95219 * http://www.apache.org/licenses/LICENSE-2.0
95220 *
95221 * Unless required by applicable law or agreed to in writing, software
95222 * distributed under the License is distributed on an "AS IS" BASIS,
95223 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95224 * See the License for the specific language governing permissions and
95225 * limitations under the License.
95226 * =============================================================================
95227 */
    // GLSL for unpacked pow(a, b). Matches JS Math.pow semantics: negative
    // base with fractional exponent -> NaN; b == 0 -> 1 (including 0^0); odd
    // integer exponents keep the sign of `a`.
    const POW = `
if(a < 0.0 && floor(b) < b){
  return NAN;
}
if (b == 0.0) {
  return 1.0;
}
return (round(mod(b, 2.0)) != 1) ?
    pow(abs(a), b) : sign(a) * pow(abs(a), b);
`;
    // Packed (vec4) variant of POW; CHECK_NAN_SNIPPET$2 is appended so lanes
    // where a < 0 and b is fractional come out as NaN.
    const POW_PACKED = `
  // isModRound1 has 1 for components with round(mod(b, 2.0)) == 1, 0 otherwise.
  vec4 isModRound1 = vec4(equal(round(mod(b, 2.0)), ivec4(1)));
  vec4 multiplier = sign(a) * isModRound1 + (vec4(1.0) - isModRound1);
  vec4 result = multiplier * pow(abs(a), b);

  // Ensure that a^0 = 1, including 0^0 = 1 as this correspond to TF and JS
  bvec4 isExpZero = equal(b, vec4(0.0));
  result.r = isExpZero.r ? 1.0 : result.r;
  result.g = isExpZero.g ? 1.0 : result.g;
  result.b = isExpZero.b ? 1.0 : result.b;
  result.a = isExpZero.a ? 1.0 : result.a;

  vec4 isNaN = vec4(lessThan(a, vec4(0.0))) * vec4(lessThan(floor(b), b));
  ` +
        CHECK_NAN_SNIPPET$2 + `
  return result;
`;
    const pow$3 = binaryKernelFunc$1({ opSnippet: POW, packedOpSnippet: POW_PACKED });
    /** WebGL kernel registration for the `Pow` op. */
    const powConfig$1 = {
        kernelName: Pow,
        backendName: 'webgl',
        kernelFunc: pow$3
    };
95262
95263 /**
95264 * @license
95265 * Copyright 2020 Google LLC. All Rights Reserved.
95266 * Licensed under the Apache License, Version 2.0 (the "License");
95267 * you may not use this file except in compliance with the License.
95268 * You may obtain a copy of the License at
95269 *
95270 * http://www.apache.org/licenses/LICENSE-2.0
95271 *
95272 * Unless required by applicable law or agreed to in writing, software
95273 * distributed under the License is distributed on an "AS IS" BASIS,
95274 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95275 * See the License for the specific language governing permissions and
95276 * limitations under the License.
95277 * =============================================================================
95278 */
    /**
     * WebGL kernel for `Prod`: multiplies the elements of `x` along `axis`.
     *
     * The reduction axes are first moved innermost (via transpose) if needed,
     * then the reduction runs either on the CPU (for tensors the backend deems
     * cheap to read back) or on the GPU after flattening to [batch, reduceSize].
     * All intermediate tensors are collected in `toDispose` and released before
     * returning; the returned tensor is owned by the caller.
     */
    function prod$2(args) {
        const { inputs, backend, attrs } = args;
        const { x } = inputs;
        const { axis, keepDims } = attrs;
        const xRank = x.shape.length;
        const toDispose = [];
        const origAxes = parseAxisParam(axis, x.shape);
        let axes = origAxes;
        // If the reduction axes are not already innermost, transpose so they are.
        const permutedAxes = getAxesPermutation(axes, xRank);
        let permutedX = x;
        if (permutedAxes != null) {
            permutedX = transpose$2({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
            axes = getInnerMostAxes(axes.length, xRank);
            toDispose.push(permutedX);
        }
        assertAxesAreInnerMostDims('prod', axes, xRank);
        let res;
        if (backend.shouldExecuteOnCPU([permutedX])) {
            // CPU path: read the values back and reduce with the shared CPU impl.
            const xVals = backend.texData.get(permutedX.dataId).values;
            const { outVals, outShape, outDtype } = prodImplCPU(permutedX.shape, permutedX.dtype, xVals, axes);
            res = backend.makeTensorInfo(outShape, outDtype, outVals);
        }
        else {
            // GPU path: flatten to 2D [batch, reduceSize] and run the reduce program.
            const [outShape, reduceShape] = computeOutAndReduceShapes(permutedX.shape, axes);
            const inSize = sizeFromShape(reduceShape);
            const a2D = reshape$3({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } });
            const outputDType = sumOutType(x.dtype);
            const reduced = reduce(a2D, outputDType, 'prod', backend);
            res = reshape$3({ inputs: { x: reduced }, backend, attrs: { shape: outShape } });
            toDispose.push(a2D);
            toDispose.push(reduced);
        }
        if (keepDims) {
            // Re-insert the reduced dimensions as size-1 axes.
            toDispose.push(res);
            const newShape = expandShapeToKeepDim(res.shape, origAxes);
            res = reshape$3({ inputs: { x: res }, backend, attrs: { shape: newShape } });
        }
        toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
        return res;
    }
    /** WebGL kernel registration for the `Prod` op. */
    const prodConfig$1 = {
        kernelName: Prod,
        backendName: 'webgl',
        kernelFunc: prod$2
    };
95324
95325 /**
95326 * @license
95327 * Copyright 2020 Google LLC. All Rights Reserved.
95328 * Licensed under the Apache License, Version 2.0 (the "License");
95329 * you may not use this file except in compliance with the License.
95330 * You may obtain a copy of the License at
95331 *
95332 * http://www.apache.org/licenses/LICENSE-2.0
95333 *
95334 * Unless required by applicable law or agreed to in writing, software
95335 * distributed under the License is distributed on an "AS IS" BASIS,
95336 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95337 * See the License for the specific language governing permissions and
95338 * limitations under the License.
95339 * =============================================================================
95340 */
95341 const range$3 = (args) => {
95342 const { backend, attrs } = args;
95343 const { start, stop, step, dtype } = attrs;
95344 const values = rangeImplCPU(start, stop, step, dtype);
95345 return backend.makeTensorInfo([values.length], dtype, values);
95346 };
95347 const rangeConfig$1 = {
95348 kernelName: Range,
95349 backendName: 'webgl',
95350 kernelFunc: range$3
95351 };
95352
95353 /**
95354 * @license
95355 * Copyright 2020 Google LLC. All Rights Reserved.
95356 * Licensed under the Apache License, Version 2.0 (the "License");
95357 * you may not use this file except in compliance with the License.
95358 * You may obtain a copy of the License at
95359 *
95360 * http://www.apache.org/licenses/LICENSE-2.0
95361 *
95362 * Unless required by applicable law or agreed to in writing, software
95363 * distributed under the License is distributed on an "AS IS" BASIS,
95364 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95365 * See the License for the specific language governing permissions and
95366 * limitations under the License.
95367 * =============================================================================
95368 */
    // GLSL snippet for the elementwise reciprocal (1 / x).
    const RECIPROCAL = `return 1.0 / x;`;
    const reciprocal$2 = unaryKernelFunc$1({ opSnippet: RECIPROCAL });
    /** WebGL kernel registration for the `Reciprocal` op. */
    const reciprocalConfig$1 = {
        kernelName: Reciprocal,
        backendName: 'webgl',
        kernelFunc: reciprocal$2,
    };
95376
95377 /**
95378 * @license
95379 * Copyright 2020 Google LLC. All Rights Reserved.
95380 * Licensed under the Apache License, Version 2.0 (the "License");
95381 * you may not use this file except in compliance with the License.
95382 * You may obtain a copy of the License at
95383 *
95384 * http://www.apache.org/licenses/LICENSE-2.0
95385 *
95386 * Unless required by applicable law or agreed to in writing, software
95387 * distributed under the License is distributed on an "AS IS" BASIS,
95388 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95389 * See the License for the specific language governing permissions and
95390 * limitations under the License.
95391 * =============================================================================
95392 */
    // GLSL for unpacked relu; CHECK_NAN_SNIPPET makes NaN inputs propagate.
    const RELU$2 = CHECK_NAN_SNIPPET + `
  return (x < 0.0) ? 0.0 : x;
`;
    // Packed relu: zero out negative lanes, then restore any NaN lanes so NaN
    // propagates per component.
    const RELU_PACKED = `
  vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));
  bvec4 isNaN = isnan(x);

  result.r = isNaN.r ? x.r : result.r;
  result.g = isNaN.g ? x.g : result.g;
  result.b = isNaN.b ? x.b : result.b;
  result.a = isNaN.a ? x.a : result.a;

  return result;
`;
    const relu$2 = unaryKernelFunc$1({ opSnippet: RELU$2, packedOpSnippet: RELU_PACKED });
    /** WebGL kernel registration for the `Relu` op. */
    const reluConfig$1 = {
        kernelName: Relu,
        backendName: 'webgl',
        kernelFunc: relu$2
    };
95413
95414 /**
95415 * @license
95416 * Copyright 2020 Google LLC. All Rights Reserved.
95417 * Licensed under the Apache License, Version 2.0 (the "License");
95418 * you may not use this file except in compliance with the License.
95419 * You may obtain a copy of the License at
95420 *
95421 * http://www.apache.org/licenses/LICENSE-2.0
95422 *
95423 * Unless required by applicable law or agreed to in writing, software
95424 * distributed under the License is distributed on an "AS IS" BASIS,
95425 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95426 * See the License for the specific language governing permissions and
95427 * limitations under the License.
95428 * =============================================================================
95429 */
    // GLSL for unpacked relu6 (clamp to [0, 6]); CHECK_NAN_SNIPPET makes NaN
    // inputs propagate.
    const RELU6$2 = CHECK_NAN_SNIPPET + `
  return (x < 0.0) ? 0.0 : min(6.0, x);
`;
    // Packed relu6: clamp each lane to [0, 6], then restore NaN lanes so NaN
    // propagates per component.
    const RELU6_PACKED = `
  vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0)));
  bvec4 isNaN = isnan(x);

  result.r = isNaN.r ? x.r : result.r;
  result.g = isNaN.g ? x.g : result.g;
  result.b = isNaN.b ? x.b : result.b;
  result.a = isNaN.a ? x.a : result.a;

  return result;
`;
    const relu6$2 = unaryKernelFunc$1({ opSnippet: RELU6$2, packedOpSnippet: RELU6_PACKED });
    /** WebGL kernel registration for the `Relu6` op. */
    const relu6Config$1 = {
        kernelName: Relu6,
        backendName: 'webgl',
        kernelFunc: relu6$2
    };
95450
95451 /**
95452 * @license
95453 * Copyright 2017 Google LLC. All Rights Reserved.
95454 * Licensed under the Apache License, Version 2.0 (the "License");
95455 * you may not use this file except in compliance with the License.
95456 * You may obtain a copy of the License at
95457 *
95458 * http://www.apache.org/licenses/LICENSE-2.0
95459 *
95460 * Unless required by applicable law or agreed to in writing, software
95461 * distributed under the License is distributed on an "AS IS" BASIS,
95462 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95463 * See the License for the specific language governing permissions and
95464 * limitations under the License.
95465 * =============================================================================
95466 */
95467 class ResizeBilinearProgram {
95468 constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
95469 this.variableNames = ['A'];
95470 this.outputShape = [];
95471 const [batch, oldHeight, oldWidth, depth] = inputShape;
95472 this.outputShape = [batch, newHeight, newWidth, depth];
95473 const effectiveInSize = [
95474 (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
95475 (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
95476 ];
95477 const effectiveOutSize = [
95478 (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
95479 (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
95480 ];
95481 let sourceFracIndexRC;
95482 if (halfPixelCenters) {
95483 sourceFracIndexRC =
95484 `(vec2(yRC) + vec2(0.5)) * effectiveInputOverOutputRatioRC` +
95485 ` - vec2(0.5)`;
95486 }
95487 else {
95488 sourceFracIndexRC = `vec2(yRC) * effectiveInputOverOutputRatioRC`;
95489 }
95490 this.userCode = `
95491 const vec2 effectiveInputOverOutputRatioRC = vec2(
95492 ${effectiveInSize[0] / effectiveOutSize[0]},
95493 ${effectiveInSize[1] / effectiveOutSize[1]});
95494 const vec2 inputShapeRC = vec2(${oldHeight}.0, ${oldWidth}.0);
95495
95496 void main() {
95497 ivec4 coords = getOutputCoords();
95498 int b = coords[0];
95499 int d = coords[3];
95500 ivec2 yRC = coords.yz;
95501
95502 // Fractional source index.
95503 vec2 sourceFracIndexRC = ${sourceFracIndexRC};
95504
95505 // Compute the four integer indices.
95506 ivec2 sourceFloorRC = ivec2(max(sourceFracIndexRC, vec2(0.0)));
95507 ivec2 sourceCeilRC = ivec2(
95508 min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));
95509
95510 float topLeft = getA(b, sourceFloorRC.x, sourceFloorRC.y, d);
95511 float bottomLeft = getA(b, sourceCeilRC.x, sourceFloorRC.y, d);
95512 float topRight = getA(b, sourceFloorRC.x, sourceCeilRC.y, d);
95513 float bottomRight = getA(b, sourceCeilRC.x, sourceCeilRC.y, d);
95514
95515 vec2 fracRC = sourceFracIndexRC - vec2(sourceFloorRC);
95516
95517 float top = topLeft + (topRight - topLeft) * fracRC.y;
95518 float bottom = bottomLeft + (bottomRight - bottomLeft) * fracRC.y;
95519 float newValue = top + (bottom - top) * fracRC.x;
95520
95521 setOutput(newValue);
95522 }
95523 `;
95524 }
95525 }
95526
95527 /**
95528 * @license
95529 * Copyright 2019 Google LLC. All Rights Reserved.
95530 * Licensed under the Apache License, Version 2.0 (the "License");
95531 * you may not use this file except in compliance with the License.
95532 * You may obtain a copy of the License at
95533 *
95534 * http://www.apache.org/licenses/LICENSE-2.0
95535 *
95536 * Unless required by applicable law or agreed to in writing, software
95537 * distributed under the License is distributed on an "AS IS" BASIS,
95538 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95539 * See the License for the specific language governing permissions and
95540 * limitations under the License.
95541 * =============================================================================
95542 */
    /**
     * Packed (2x2 texel) bilinear resize: each invocation produces a vec4
     * covering the current output pixel plus its neighbor column and the next
     * depth channel, interpolating the four surrounding input pixels for each
     * component in parallel.
     */
    class ResizeBilinearPackedProgram {
        constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
            this.variableNames = ['A'];
            this.packedInputs = true;
            this.packedOutput = true;
            this.outputShape = [];
            const [batch, oldHeight, oldWidth, depth] = inputShape;
            this.outputShape = [batch, newHeight, newWidth, depth];
            // With alignCorners, corner samples map onto each other, so the
            // effective extent shrinks by one on axes longer than one sample.
            const effectiveInSize = [
                (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
                (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
            ];
            const effectiveOutSize = [
                (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
                (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
            ];
            let sourceFracIndexRC;
            if (halfPixelCenters) {
                // Half-pixel mode samples at pixel centers.
                sourceFracIndexRC = `(vec3(yRC) + vec3(0.5)) * ` +
                    `effectiveInputOverOutputRatioRC - vec3(0.5)`;
            }
            else {
                sourceFracIndexRC = `vec3(yRC) * effectiveInputOverOutputRatioRC`;
            }
            this.userCode = `
      const vec3 effectiveInputOverOutputRatioRC = vec3(
          ${effectiveInSize[0] / effectiveOutSize[0]},
          ${effectiveInSize[1] / effectiveOutSize[1]},
          ${effectiveInSize[1] / effectiveOutSize[1]});
      const vec3 inputShapeRC = vec3(${oldHeight}.0, ${oldWidth}.0,
                                     ${oldWidth}.0);

      float getAValue(int b, int r, int c, int d) {
        return getChannel(getA(b, r, c, d), vec2(c, d));
      }

      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int d = coords[3];
        // Calculate values for next column in yRC.z.
        ivec3 yRC = coords.yzz + ivec3(0, 0, 1);

        // Fractional source index.
        vec3 sourceFracIndexRC = ${sourceFracIndexRC};

        // Compute the four integer indices.
        ivec3 sourceFloorRC = ivec3(max(sourceFracIndexRC, vec3(0.0)));
        ivec3 sourceCeilRC = ivec3(
          min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));

        // Should we calculate next column and row elements in 2x2 packed cell.
        bool hasNextCol = d < ${depth - 1};
        bool hasNextRow = coords.z < ${newWidth - 1};

        // In parallel, construct four corners for all four components in
        // packed 2x2 cell.
        vec4 topLeft = vec4(
          getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d),
          hasNextCol ? getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d + 1)
                     : 0.0,
          hasNextRow ? getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d)
                     : 0.0,
          (hasNextRow && hasNextCol) ?
            getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d + 1) : 0.0);

        vec4 bottomLeft = vec4(
          getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d),
          hasNextCol ? getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d + 1)
                     : 0.0,
          hasNextRow ? getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d)
                     : 0.0,
          (hasNextRow && hasNextCol) ?
            getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d + 1) : 0.0);

        vec4 topRight = vec4(
          getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d),
          hasNextCol ? getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d + 1)
                     : 0.0,
          hasNextRow ? getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d)
                     : 0.0,
          (hasNextRow && hasNextCol) ?
            getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d + 1) : 0.0);

        vec4 bottomRight = vec4(
          getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d),
          hasNextCol ? getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d + 1)
                     : 0.0,
          hasNextRow ? getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d)
                     : 0.0,
          (hasNextRow && hasNextCol) ?
            getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d + 1) : 0.0);

        vec3 fracRC = sourceFracIndexRC - vec3(sourceFloorRC);

        vec4 top = mix(topLeft, topRight, fracRC.yyzz);
        vec4 bottom = mix(bottomLeft, bottomRight, fracRC.yyzz);
        vec4 newValue = mix(top, bottom, fracRC.x);

        setOutput(newValue);
      }
    `;
        }
    }
95647
95648 /**
95649 * @license
95650 * Copyright 2020 Google LLC. All Rights Reserved.
95651 * Licensed under the Apache License, Version 2.0 (the "License");
95652 * you may not use this file except in compliance with the License.
95653 * You may obtain a copy of the License at
95654 *
95655 * http://www.apache.org/licenses/LICENSE-2.0
95656 *
95657 * Unless required by applicable law or agreed to in writing, software
95658 * distributed under the License is distributed on an "AS IS" BASIS,
95659 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95660 * See the License for the specific language governing permissions and
95661 * limitations under the License.
95662 * =============================================================================
95663 */
95664 function resizeBilinear$2(args) {
95665 const { inputs, backend, attrs } = args;
95666 const { images } = inputs;
95667 const { alignCorners, halfPixelCenters, size } = attrs;
95668 const [newHeight, newWidth] = size;
95669 const program = env().getBool('WEBGL_PACK_IMAGE_OPERATIONS') ?
95670 new ResizeBilinearPackedProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters) :
95671 new ResizeBilinearProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters);
95672 return backend.runWebGLProgram(program, [images], 'float32');
95673 }
95674 const resizeBilinearConfig$1 = {
95675 kernelName: ResizeBilinear,
95676 backendName: 'webgl',
95677 kernelFunc: resizeBilinear$2
95678 };
95679
95680 /**
95681 * @license
95682 * Copyright 2018 Google LLC. All Rights Reserved.
95683 * Licensed under the Apache License, Version 2.0 (the "License");
95684 * you may not use this file except in compliance with the License.
95685 * You may obtain a copy of the License at
95686 *
95687 * http://www.apache.org/licenses/LICENSE-2.0
95688 *
95689 * Unless required by applicable law or agreed to in writing, software
95690 * distributed under the License is distributed on an "AS IS" BASIS,
95691 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95692 * See the License for the specific language governing permissions and
95693 * limitations under the License.
95694 * =============================================================================
95695 */
95696 class ResizeBilinearBackpropProgram {
95697 constructor(dyShape, inputShape, alignCorners) {
95698 this.variableNames = ['dy'];
95699 this.outputShape = [];
95700 this.outputShape = inputShape;
95701 const [, xHeight, xWidth,] = inputShape;
95702 const [, yHeight, yWidth] = dyShape;
95703 // In the backwards pass, we want to find the pixels that were generated for
95704 // each pixel in the input image the forward pass and add the corresponding
95705 // coefficient from dy to the gradient (with some interpolation).
95706 const effectiveXSize = [
95707 (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
95708 (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
95709 ];
95710 const effectiveYSize = [
95711 (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
95712 (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
95713 ];
95714 const heightScale = effectiveXSize[0] / effectiveYSize[0];
95715 const widthScale = effectiveXSize[1] / effectiveYSize[1];
95716 const invHeightScale = 1 / heightScale;
95717 const invWidthScale = 1 / widthScale;
95718 // This defines the size of the window of values around a particular
95719 // index in dy that we want to search for contributions to dx.
95720 const winHeight = (Math.ceil(invHeightScale) * 2) + 2;
95721 const winWidth = (Math.ceil(invWidthScale) * 2) + 2;
95722 this.userCode = `
95723 void main() {
95724 ivec4 coords = getOutputCoords();
95725 int b = coords[0];
95726 int d = coords[3];
95727 int r = coords[1];
95728 int c = coords[2];
95729
95730 float accumulator = 0.0;
95731
95732 const float heightScale = float(${heightScale});
95733 const float widthScale = float(${widthScale});
95734
95735 const float invHeightScale = float(${invHeightScale});
95736 const float invWidthScale = float(${invWidthScale});
95737
95738 const int winHeight = int(${winHeight});
95739 const int winWidth = int(${winWidth});
95740
95741 // Compute bounds for where in dy we will look
95742 float startRLerp = floor(float(r) * invHeightScale);
95743 int startDyR = int(startRLerp - float(winHeight / 2));
95744
95745 float startCLerp = floor(float(c) * invWidthScale);
95746 int startDyC = int(startCLerp - float(winWidth / 2));
95747
95748 // Loop over dy
95749 for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {
95750 int dyR = dyROffset + startDyR;
95751
95752 // Guard against the window exceeding the bounds of dy
95753 if (dyR < 0 || dyR >= ${yHeight}) {
95754 continue;
95755 }
95756
95757 for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {
95758 int dyC = dyCOffset + startDyC;
95759
95760 // Guard against the window exceeding the bounds of dy
95761 if (dyC < 0 || dyC >= ${yWidth}) {
95762 continue;
95763 }
95764
95765 float dxR = float(dyR) * heightScale;
95766 int topDxRIndex = int(floor(dxR));
95767 int bottomDxRIndex = int(min(ceil(dxR), ${xHeight - 1}.0));
95768 float dxRLerp = dxR - float(topDxRIndex);
95769 float inverseDxRLerp = 1.0 - dxRLerp;
95770
95771 float dxC = float(dyC) * widthScale;
95772 int leftDxCIndex = int(floor(dxC));
95773 int rightDxCIndex = int(min(ceil(dxC), ${xWidth - 1}.0));
95774 float dxCLerp = dxC - float(leftDxCIndex);
95775 float inverseDxCLerp = 1.0 - dxCLerp;
95776
95777 if (r == topDxRIndex && c == leftDxCIndex) {
95778 // topLeft
95779 accumulator +=
95780 getDy(b, dyR, dyC, d) * inverseDxRLerp * inverseDxCLerp;
95781 }
95782
95783 if (r == topDxRIndex && c == rightDxCIndex) {
95784 // topRight
95785 accumulator += getDy(b, dyR, dyC, d) * inverseDxRLerp * dxCLerp;
95786 }
95787
95788 if (r == bottomDxRIndex && c == leftDxCIndex) {
95789 // bottomLeft
95790 accumulator += getDy(b, dyR, dyC, d) * dxRLerp * inverseDxCLerp;
95791 }
95792
95793 if (r == bottomDxRIndex && c == rightDxCIndex) {
95794 // bottomRight
95795 accumulator += getDy(b, dyR, dyC, d) * dxRLerp * dxCLerp;
95796 }
95797 }
95798 }
95799 // End loop over dy
95800
95801 setOutput(accumulator);
95802 }
95803 `;
95804 }
95805 }
95806
95807 /**
95808 * @license
95809 * Copyright 2020 Google LLC. All Rights Reserved.
95810 * Licensed under the Apache License, Version 2.0 (the "License");
95811 * you may not use this file except in compliance with the License.
95812 * You may obtain a copy of the License at
95813 *
95814 * http://www.apache.org/licenses/LICENSE-2.0
95815 *
95816 * Unless required by applicable law or agreed to in writing, software
95817 * distributed under the License is distributed on an "AS IS" BASIS,
95818 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95819 * See the License for the specific language governing permissions and
95820 * limitations under the License.
95821 * =============================================================================
95822 */
95823 function resizeBilinearGrad$1(args) {
95824 const { inputs, backend, attrs } = args;
95825 const { images, dy } = inputs;
95826 const { alignCorners } = attrs;
95827 const program = new ResizeBilinearBackpropProgram(dy.shape, images.shape, alignCorners);
95828 return backend.runWebGLProgram(program, [dy], dy.dtype);
95829 }
95830 const resizeBilinearGradConfig$2 = {
95831 kernelName: ResizeBilinearGrad,
95832 backendName: 'webgl',
95833 kernelFunc: resizeBilinearGrad$1
95834 };
95835
95836 /**
95837 * @license
95838 * Copyright 2018 Google LLC. All Rights Reserved.
95839 * Licensed under the Apache License, Version 2.0 (the "License");
95840 * you may not use this file except in compliance with the License.
95841 * You may obtain a copy of the License at
95842 *
95843 * http://www.apache.org/licenses/LICENSE-2.0
95844 *
95845 * Unless required by applicable law or agreed to in writing, software
95846 * distributed under the License is distributed on an "AS IS" BASIS,
95847 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95848 * See the License for the specific language governing permissions and
95849 * limitations under the License.
95850 * =============================================================================
95851 */
    /**
     * GPGPU program that resizes a 4D image tensor
     * [batch, oldHeight, oldWidth, depth] to [batch, newHeight, newWidth, depth]
     * using nearest-neighbor sampling. One output texel is computed per call to
     * the generated shader's main().
     */
    class ResizeNearestNeighborProgram {
        /**
         * @param inputShape [batch, oldHeight, oldWidth, depth] of the input.
         * @param newHeight Output height in pixels.
         * @param newWidth Output width in pixels.
         * @param alignCorners When true, scales are computed over (size - 1)
         *     intervals so the corner pixels of input and output coincide, and
         *     the source index is rounded (floor(x + 0.5)) instead of floored.
         * @param halfPixelCenters When true, sampling uses half-pixel-center
         *     offsets (+0.5 before scaling), clamped below at 0.
         */
        constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
            this.variableNames = ['A'];
            this.outputShape = [];
            const [batch, oldHeight, oldWidth, depth] = inputShape;
            this.outputShape = [batch, newHeight, newWidth, depth];
            // With alignCorners, use (size - 1) on both sides so the scaling ratio
            // maps the first/last pixel centers onto each other. Guarded by
            // size > 1 to avoid a zero denominator.
            const effectiveInSize = [
                (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
                (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
            ];
            const effectiveOutSize = [
                (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
                (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
            ];
            // When alignCorners is false the source index is rounded down with
            // floor (roundBase 0.0); otherwise floor(x + 0.5) rounds to nearest.
            const roundBase = alignCorners ? '0.5' : '0.0';
            // GLSL expression for the fractional (row, col) source index.
            let sourceFracIndexRC;
            if (halfPixelCenters) {
                sourceFracIndexRC =
                    `max((vec2(yRC) + vec2(0.5)) * effectiveInputOverOutputRatioRC` +
                        `, vec2(0.0))`;
            }
            else {
                sourceFracIndexRC = `vec2(yRC) * effectiveInputOverOutputRatioRC`;
            }
            this.userCode = `
      const vec2 effectiveInputOverOutputRatioRC = vec2(
          ${effectiveInSize[0] / effectiveOutSize[0]},
          ${effectiveInSize[1] / effectiveOutSize[1]});
      const vec2 inputShapeRC = vec2(${oldHeight}.0, ${oldWidth}.0);

      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int d = coords[3];
        ivec2 yRC = coords.yz;

        // Fractional source index.
        vec2 sourceFracIndexRC = ${sourceFracIndexRC};

        // Compute the coordinators of nearest neighbor point.
        ivec2 sourceNearestRC = ivec2(
          min(inputShapeRC - 1.0, floor(sourceFracIndexRC + ${roundBase})));
        float newValue = getA(b, sourceNearestRC.x, sourceNearestRC.y, d);

        setOutput(newValue);
      }
    `;
        }
    }
95902
95903 /**
95904 * @license
95905 * Copyright 2019 Google LLC. All Rights Reserved.
95906 * Licensed under the Apache License, Version 2.0 (the "License");
95907 * you may not use this file except in compliance with the License.
95908 * You may obtain a copy of the License at
95909 *
95910 * http://www.apache.org/licenses/LICENSE-2.0
95911 *
95912 * Unless required by applicable law or agreed to in writing, software
95913 * distributed under the License is distributed on an "AS IS" BASIS,
95914 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95915 * See the License for the specific language governing permissions and
95916 * limitations under the License.
95917 * =============================================================================
95918 */
    /**
     * Packed (2x2 texel) variant of the nearest-neighbor resize program. Each
     * shader invocation fills a vec4 covering the current (row, col) output
     * position plus its next column/row neighbors within the packed cell.
     */
    class ResizeNearestNeighborPackedProgram {
        /**
         * @param inputShape [batch, oldHeight, oldWidth, depth] of the input.
         * @param newHeight Output height in pixels.
         * @param newWidth Output width in pixels.
         * @param alignCorners When true, scales are computed over (size - 1)
         *     intervals and the source index is rounded instead of floored.
         * @param halfPixelCenters When true, sampling uses half-pixel-center
         *     offsets, clamped below at 0.
         */
        constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
            this.variableNames = ['A'];
            this.packedInputs = true;
            this.packedOutput = true;
            this.outputShape = [];
            const [batch, oldHeight, oldWidth, depth] = inputShape;
            this.outputShape = [batch, newHeight, newWidth, depth];
            // With alignCorners, use (size - 1) so corner pixels map onto each
            // other; guarded by size > 1 to avoid dividing by zero.
            const effectiveInSize = [
                (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
                (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
            ];
            const effectiveOutSize = [
                (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
                (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
            ];
            // When alignCorners is false the source index is floored (roundBase
            // 0.0); otherwise floor(x + 0.5) rounds to nearest.
            const roundBase = alignCorners ? '0.5' : '0.0';
            // GLSL expression for the fractional (row, col, nextCol) source index.
            let sourceFracIndexRC;
            if (halfPixelCenters) {
                sourceFracIndexRC = `max((vec3(yRC) + vec3(0.5)) * ` +
                    `effectiveInputOverOutputRatioRC, vec3(0.0))`;
            }
            else {
                sourceFracIndexRC = `vec3(yRC) * effectiveInputOverOutputRatioRC`;
            }
            // Note: the z component duplicates the width ratio because yRC.z is
            // the next column (same width scale as yRC.y).
            this.userCode = `
      const vec3 effectiveInputOverOutputRatioRC = vec3(
          ${effectiveInSize[0] / effectiveOutSize[0]},
          ${effectiveInSize[1] / effectiveOutSize[1]},
          ${effectiveInSize[1] / effectiveOutSize[1]});
      const vec3 inputShapeRC = vec3(${oldHeight}.0, ${oldWidth}.0,
                                     ${oldWidth}.0);

      float getAValue(int b, int r, int c, int d) {
        return getChannel(getA(b, r, c, d), vec2(c, d));
      }

      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int d = coords[3];
        // Calculate values for next column in yRC.z.
        ivec3 yRC = coords.yzz + ivec3(0, 0, 1);

        // Fractional source index.
        vec3 sourceFracIndexRC = ${sourceFracIndexRC};

        // Compute the coordinators of nearest neighbor point.
        ivec3 sourceNearestRC = ivec3(
          min(inputShapeRC - 1.0, floor(sourceFracIndexRC + ${roundBase})));

        // Should we calculate next column and row elements in 2x2 packed cell.
        bool hasNextCol = d < ${depth - 1};
        bool hasNextRow = coords.z < ${newWidth - 1};

        vec4 newValue = vec4(
          getAValue(b, sourceNearestRC.x, sourceNearestRC.y, d),
          hasNextCol ? getAValue(b, sourceNearestRC.x, sourceNearestRC.y, d + 1)
                     : 0.0,
          hasNextRow ? getAValue(b, sourceNearestRC.x, sourceNearestRC.z, d)
                     : 0.0,
          (hasNextRow && hasNextCol) ?
            getAValue(b, sourceNearestRC.x, sourceNearestRC.z, d + 1) : 0.0);

        setOutput(newValue);
      }
    `;
        }
    }
95989
95990 /**
95991 * @license
95992 * Copyright 2020 Google LLC. All Rights Reserved.
95993 * Licensed under the Apache License, Version 2.0 (the "License");
95994 * you may not use this file except in compliance with the License.
95995 * You may obtain a copy of the License at
95996 *
95997 * http://www.apache.org/licenses/LICENSE-2.0
95998 *
95999 * Unless required by applicable law or agreed to in writing, software
96000 * distributed under the License is distributed on an "AS IS" BASIS,
96001 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96002 * See the License for the specific language governing permissions and
96003 * limitations under the License.
96004 * =============================================================================
96005 */
96006 function resizeNearestNeighbor$2(args) {
96007 const { inputs, backend, attrs } = args;
96008 const { images } = inputs;
96009 const { alignCorners, halfPixelCenters, size } = attrs;
96010 const [newHeight, newWidth] = size;
96011 const program = env().getBool('WEBGL_PACK_IMAGE_OPERATIONS') ?
96012 new ResizeNearestNeighborPackedProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters) :
96013 new ResizeNearestNeighborProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters);
96014 return backend.runWebGLProgram(program, [images], images.dtype);
96015 }
    /** WebGL kernel registration for the ResizeNearestNeighbor kernel. */
    const resizeNearestNeighborConfig$1 = {
        kernelName: ResizeNearestNeighbor,
        backendName: 'webgl',
        kernelFunc: resizeNearestNeighbor$2
    };
96021
96022 /**
96023 * @license
96024 * Copyright 2018 Google LLC. All Rights Reserved.
96025 * Licensed under the Apache License, Version 2.0 (the "License");
96026 * you may not use this file except in compliance with the License.
96027 * You may obtain a copy of the License at
96028 *
96029 * http://www.apache.org/licenses/LICENSE-2.0
96030 *
96031 * Unless required by applicable law or agreed to in writing, software
96032 * distributed under the License is distributed on an "AS IS" BASIS,
96033 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96034 * See the License for the specific language governing permissions and
96035 * limitations under the License.
96036 * =============================================================================
96037 */
    /**
     * GPGPU program computing the gradient of nearest-neighbor resize.
     *
     * For each input pixel (r, c) it scans a window of dy positions that could
     * have sampled that pixel in the forward pass, and accumulates the dy
     * values whose nearest-neighbor source is exactly (r, c).
     *
     * NOTE: the class name's "Neigbor" spelling is historical; it is referenced
     * elsewhere and must not be changed.
     */
    class ResizeNearestNeigborBackpropProgram {
        /**
         * @param dyShape Shape of the incoming gradient [batch, yH, yW, depth].
         * @param inputShape Shape of the forward-pass input (= output shape of
         *     this gradient) [batch, xH, xW, depth].
         * @param alignCorners Must match the forward pass's alignCorners.
         */
        constructor(dyShape, inputShape, alignCorners) {
            this.variableNames = ['dy'];
            this.outputShape = [];
            this.outputShape = inputShape;
            const [, xHeight, xWidth,] = inputShape;
            const [, yHeight, yWidth] = dyShape;
            // In the backwards pass, we want to find the pixels that were generated for
            // each pixel in the input image the forward pass and add the corresponding
            // coefficient from dy to the gradient (with some interpolation).
            const effectiveXSize = [
                (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
                (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
            ];
            const effectiveYSize = [
                (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
                (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
            ];
            const heightScale = effectiveXSize[0] / effectiveYSize[0];
            const widthScale = effectiveXSize[1] / effectiveYSize[1];
            const invHeightScale = 1 / heightScale;
            const invWidthScale = 1 / widthScale;
            // This defines the size of the window of values around a particular
            // index in dy that we want to search for contributions to dx.
            const winHeight = (Math.ceil(invHeightScale) * 2) + 2;
            const winWidth = (Math.ceil(invWidthScale) * 2) + 2;
            this.userCode = `
      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int d = coords[3];
        int r = coords[1];
        int c = coords[2];

        float accumulator = 0.0;

        const float heightScale = float(${heightScale});
        const float widthScale = float(${widthScale});

        const float invHeightScale = float(${invHeightScale});
        const float invWidthScale = float(${invWidthScale});

        const int winHeight = int(${winHeight});
        const int winWidth = int(${winWidth});

        // Compute bounds for where in dy we will look
        float startRLerp = floor(float(r) * invHeightScale);
        int startDyR = int(floor(startRLerp - float(winHeight / 2)));

        float startCLerp = floor(float(c) * invWidthScale);
        int startDyC = int(floor(startCLerp - float(winWidth / 2)));

        // Loop over dy
        for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {
          int dyR = dyROffset + startDyR;

          // Guard against the window exceeding the bounds of dy
          if (dyR < 0 || dyR >= ${yHeight}) {
            continue;
          }

          for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {
            int dyC = dyCOffset + startDyC;

            // Guard against the window exceeding the bounds of dy
            if (dyC < 0 || dyC >= ${yWidth}) {
              continue;
            }

            float sourceFracRow =
              float(${effectiveXSize[0]}) *
                (float(dyR) / float(${effectiveYSize[0]}));

            float sourceFracCol =
                float(${effectiveXSize[1]}) *
                  (float(dyC) / float(${effectiveYSize[1]}));

            int sourceNearestRow = int(min(
                float(int(${xHeight}) - 1),
                ${alignCorners} ? float(round(sourceFracRow)) :
                                  float(floor(sourceFracRow))));

            int sourceNearestCol = int(min(
                float(int(${xWidth}) - 1),
                ${alignCorners} ? float(round(sourceFracCol)) :
                                  float(floor(sourceFracCol))));

            if (r == sourceNearestRow && c == sourceNearestCol) {
              accumulator += getDy(b, dyR, dyC, d);
            }
          }
        }
        // End loop over dy

        setOutput(accumulator);
      }
    `;
        }
    }
96137
96138 /**
96139 * @license
96140 * Copyright 2020 Google LLC. All Rights Reserved.
96141 * Licensed under the Apache License, Version 2.0 (the "License");
96142 * you may not use this file except in compliance with the License.
96143 * You may obtain a copy of the License at
96144 *
96145 * http://www.apache.org/licenses/LICENSE-2.0
96146 *
96147 * Unless required by applicable law or agreed to in writing, software
96148 * distributed under the License is distributed on an "AS IS" BASIS,
96149 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96150 * See the License for the specific language governing permissions and
96151 * limitations under the License.
96152 * =============================================================================
96153 */
96154 function resizeNearestNeighborGrad$1(args) {
96155 const { inputs, backend, attrs } = args;
96156 const { images, dy } = inputs;
96157 const { alignCorners } = attrs;
96158 const program = new ResizeNearestNeigborBackpropProgram(dy.shape, images.shape, alignCorners);
96159 return backend.runWebGLProgram(program, [dy], dy.dtype);
96160 }
    /** WebGL kernel registration for the ResizeNearestNeighborGrad kernel. */
    const resizeNearestNeighborGradConfig$2 = {
        kernelName: ResizeNearestNeighborGrad,
        backendName: 'webgl',
        kernelFunc: resizeNearestNeighborGrad$1
    };
96166
96167 /**
96168 * @license
96169 * Copyright 2017 Google LLC. All Rights Reserved.
96170 * Licensed under the Apache License, Version 2.0 (the "License");
96171 * you may not use this file except in compliance with the License.
96172 * You may obtain a copy of the License at
96173 *
96174 * http://www.apache.org/licenses/LICENSE-2.0
96175 *
96176 * Unless required by applicable law or agreed to in writing, software
96177 * distributed under the License is distributed on an "AS IS" BASIS,
96178 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96179 * See the License for the specific language governing permissions and
96180 * limitations under the License.
96181 * =============================================================================
96182 */
    /**
     * GPGPU program that reverses a tensor (rank 1-4) along the given axes.
     * Axes not listed in `axis`, and axes of size 1, are passed through
     * unchanged; listed axes read from the mirrored coordinate.
     */
    class ReverseProgram {
        /**
         * @param xShape Shape of the input tensor (rank <= 4).
         * @param axis Array of dimension indices to reverse.
         */
        constructor(xShape, axis) {
            this.variableNames = ['x'];
            const rank = xShape.length;
            if (rank > 4) {
                throw new Error(`WebGL backend: Reverse of rank-${rank} tensor is not yet supported`);
            }
            this.outputShape = xShape;
            // Rank-1 uses a scalar int coordinate rather than a vector type.
            if (rank === 1) {
                this.userCode = `
        void main() {
          int coord = getOutputCoords();
          setOutput(getX(${xShape[0]} - coord - 1));
        }
      `;
                return;
            }
            // For each dimension, mirror the coordinate only if that dimension
            // is being reversed and has more than one element (reversing a
            // size-1 dim is a no-op).
            const getInCoord = (i) => {
                if (axis.indexOf(i) !== -1 && xShape[i] !== 1) {
                    return `${xShape[i]} - coords[${i}] - 1`;
                }
                return `coords[${i}]`;
            };
            const inCoords = xShape.map((_, i) => getInCoord(i)).join(',');
            const type = getCoordsDataType(rank);
            this.userCode = `
        void main() {
          ${type} coords = getOutputCoords();
          setOutput(getX(${inCoords}));
        }
      `;
        }
    }
96216
96217 /**
96218 * @license
96219 * Copyright 2019 Google LLC. All Rights Reserved.
96220 * Licensed under the Apache License, Version 2.0 (the "License");
96221 * you may not use this file except in compliance with the License.
96222 * You may obtain a copy of the License at
96223 *
96224 * http://www.apache.org/licenses/LICENSE-2.0
96225 *
96226 * Unless required by applicable law or agreed to in writing, software
96227 * distributed under the License is distributed on an "AS IS" BASIS,
96228 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96229 * See the License for the specific language governing permissions and
96230 * limitations under the License.
96231 * =============================================================================
96232 */
    /**
     * Packed (2x2 texel) variant of the reverse program. Each shader
     * invocation produces a vec4 holding the current element (r), its next
     * column (g), next row (b), and next row+column (a) neighbors, each read
     * from the mirrored input coordinate.
     */
    class ReversePackedProgram {
        /**
         * @param xShape Shape of the input tensor (rank <= 4).
         * @param axis Array of dimension indices to reverse.
         */
        constructor(xShape, axis) {
            this.variableNames = ['x'];
            this.packedInputs = true;
            this.packedOutput = true;
            const rank = xShape.length;
            if (rank > 4) {
                throw new Error(`WebGL backend: Reverse of rank-${rank} tensor is not yet supported`);
            }
            this.outputShape = xShape;
            const channels = getChannels('rc', rank);
            // Guards: does the packed cell's next column / next row lie inside
            // the output bounds?
            const nextColumn = `${channels[rank - 1]} + 1 < ${this.outputShape[rank - 1]}`;
            const nextRow = `${channels[rank - 2]} + 1 < ${this.outputShape[rank - 2]}`;
            const type = getCoordsDataType(rank);
            if (rank === 1) {
                // Rank-1: scalar coordinate; only r and g lanes can be filled.
                this.userCode = `
        void main(){
          int rc = getOutputCoords();
          vec4 result = vec4(0.);
          result.r = getChannel(getX(${xShape[0]} - rc - 1),
            ${xShape[0]} - rc - 1);
          if(${nextColumn}){
              result.g = getChannel(getX(${xShape[0]} - (rc + 1) - 1),
                ${xShape[0]} - (rc + 1) - 1);
          }
          setOutput(result);
        }
      `;
            }
            else {
                this.userCode = `
        void main() {
          ${type} rc = getOutputCoords();
          vec4 result = vec4(0.);
          result.r = ${getR(channels.slice())};
          if(${nextColumn}){
            result.g = ${getG(channels.slice())};
          }
          if(${nextRow}) {
            result.b = ${getB(channels.slice())};
            if(${nextColumn}) {
              result.a = ${getA(channels.slice())};
            }
          }
          setOutput(result);
        }
      `;
            }
            // r lane: the element at the current coordinates.
            function getR(channels) {
                return getChannel(channels);
            }
            // g lane: next column (last dim + 1).
            function getG(channels) {
                channels[rank - 1] = '(' + channels[rank - 1] + ` + 1)`;
                return getChannel(channels);
            }
            // b lane: next row (second-to-last dim + 1).
            function getB(channels) {
                channels[rank - 2] = '(' + channels[rank - 2] + ` + 1)`;
                return getChannel(channels);
            }
            // a lane: next row and next column.
            function getA(channels) {
                channels[rank - 1] = '(' + channels[rank - 1] + ` + 1)`;
                channels[rank - 2] = '(' + channels[rank - 2] + ` + 1)`;
                return getChannel(channels);
            }
            // Builds the GLSL expression that samples the (possibly mirrored)
            // input coordinate and extracts the right channel of the packed texel.
            function getChannel(channels) {
                const inCoordsArray = xShape.map((_, i) => getInCoord(i, channels));
                const inCoords = inCoordsArray.join(',');
                const innerDims = inCoordsArray.slice(-2).join(',');
                return `getChannel(getX(${inCoords}), vec2(${innerDims}))`;
            }
            // Mirrors coordinate i when that axis is reversed and larger than 1.
            function getInCoord(i, channels1) {
                if (axis.indexOf(i) !== -1 && xShape[i] !== 1) {
                    return `${xShape[i]} - ${channels1[i]} - 1`;
                }
                else {
                    return `${channels1[i]}`;
                }
            }
        }
    }
96313
96314 /**
96315 * @license
96316 * Copyright 2020 Google LLC. All Rights Reserved.
96317 * Licensed under the Apache License, Version 2.0 (the "License");
96318 * you may not use this file except in compliance with the License.
96319 * You may obtain a copy of the License at
96320 *
96321 * http://www.apache.org/licenses/LICENSE-2.0
96322 *
96323 * Unless required by applicable law or agreed to in writing, software
96324 * distributed under the License is distributed on an "AS IS" BASIS,
96325 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96326 * See the License for the specific language governing permissions and
96327 * limitations under the License.
96328 * =============================================================================
96329 */
96330 function reverse$2(args) {
96331 const { inputs, backend, attrs } = args;
96332 const { x } = inputs;
96333 const { dims } = attrs;
96334 const xRank = x.shape.length;
96335 const $dims = parseAxisParam(dims, x.shape);
96336 if (xRank === 0) {
96337 return identity$2({ inputs: { x }, backend });
96338 }
96339 const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
96340 new ReversePackedProgram(x.shape, $dims) :
96341 new ReverseProgram(x.shape, $dims);
96342 return backend.runWebGLProgram(program, [x], x.dtype);
96343 }
    /** WebGL kernel registration for the Reverse kernel. */
    const reverseConfig$1 = {
        kernelName: Reverse,
        backendName: 'webgl',
        kernelFunc: reverse$2
    };
96349
96350 /**
96351 * @license
96352 * Copyright 2020 Google LLC. All Rights Reserved.
96353 * Licensed under the Apache License, Version 2.0 (the "License");
96354 * you may not use this file except in compliance with the License.
96355 * You may obtain a copy of the License at
96356 *
96357 * http://www.apache.org/licenses/LICENSE-2.0
96358 *
96359 * Unless required by applicable law or agreed to in writing, software
96360 * distributed under the License is distributed on an "AS IS" BASIS,
96361 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96362 * See the License for the specific language governing permissions and
96363 * limitations under the License.
96364 * =============================================================================
96365 */
    /**
     * GPGPU program that rotates a 4D image tensor about a given center.
     * The rotation parameters arrive at run time via the `params` uniform:
     * [centerX, centerY, sin(radians), cos(radians)]. Output pixels whose
     * rotated source falls outside the image get `fillValue`.
     */
    class RotateProgram {
        /**
         * @param imageShape [batch, height, width, channels] of the input.
         * @param fillValue Either a single number applied to every channel, or
         *     a per-channel array of 3 numbers (indexed by coords[3]).
         */
        constructor(imageShape, fillValue) {
            this.variableNames = ['Image'];
            this.outputShape = [];
            this.customUniforms = [{ name: 'params', type: 'vec4' }];
            const imageHeight = imageShape[1];
            const imageWidth = imageShape[2];
            this.outputShape = imageShape;
            // GLSL snippet declaring the default output value for out-of-bounds
            // source coordinates.
            let fillSnippet = '';
            if (typeof fillValue === 'number') {
                fillSnippet = `float outputValue = ${fillValue.toFixed(2)};`;
            }
            else {
                fillSnippet = `
          vec3 fill = vec3(${fillValue.join(',')});
          float outputValue = fill[coords[3]];`;
            }
            // Inverse rotation: for each output pixel, find the source pixel by
            // rotating (x, y) around the center by -radians.
            this.userCode = `
        void main() {
          ivec4 coords = getOutputCoords();
          int x = coords[2];
          int y = coords[1];
          float coordXFloat = (float(x) - params[0]) * params[3] -
            (float(y) - params[1]) * params[2];
          float coordYFloat = (float(x) - params[0]) * params[2] +
            (float(y) - params[1]) * params[3];
          int coordX = int(round(coordXFloat + params[0]));
          int coordY = int(round(coordYFloat + params[1]));
          ${fillSnippet}
          if(coordX >= 0 && coordX < ${imageWidth} && coordY >= 0 && coordY < ${imageHeight}) {
            outputValue = getImage(coords[0], coordY, coordX, coords[3]);
          }
          setOutput(outputValue);
        }
      `;
        }
    }
96403
96404 /**
96405 * @license
96406 * Copyright 2020 Google LLC. All Rights Reserved.
96407 * Licensed under the Apache License, Version 2.0 (the "License");
96408 * you may not use this file except in compliance with the License.
96409 * You may obtain a copy of the License at
96410 *
96411 * http://www.apache.org/licenses/LICENSE-2.0
96412 *
96413 * Unless required by applicable law or agreed to in writing, software
96414 * distributed under the License is distributed on an "AS IS" BASIS,
96415 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96416 * See the License for the specific language governing permissions and
96417 * limitations under the License.
96418 * =============================================================================
96419 */
96420 const rotateWithOffsetConfig$1 = {
96421 kernelName: RotateWithOffset,
96422 backendName: 'webgl',
96423 kernelFunc: ({ inputs, attrs, backend }) => {
96424 const { image } = inputs;
96425 const { radians, fillValue, center } = attrs;
96426 const webglBackend = backend;
96427 const program = new RotateProgram(image.shape, fillValue);
96428 const [centerX, centerY] = getImageCenter(center, image.shape[1], image.shape[2]);
96429 const customValues = [[centerX, centerY, Math.sin(radians), Math.cos(radians)]];
96430 const output = webglBackend.runWebGLProgram(program, [image], image.dtype, customValues);
96431 return output;
96432 }
96433 };
96434
96435 /**
96436 * @license
96437 * Copyright 2020 Google LLC. All Rights Reserved.
96438 * Licensed under the Apache License, Version 2.0 (the "License");
96439 * you may not use this file except in compliance with the License.
96440 * You may obtain a copy of the License at
96441 *
96442 * http://www.apache.org/licenses/LICENSE-2.0
96443 *
96444 * Unless required by applicable law or agreed to in writing, software
96445 * distributed under the License is distributed on an "AS IS" BASIS,
96446 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96447 * See the License for the specific language governing permissions and
96448 * limitations under the License.
96449 * =============================================================================
96450 */
    // GLSL snippet implementing round-half-to-even (banker's rounding);
    // needed because OpenGL ES 1.0 has no built-in round().
    const ROUND = `
  // OpenGL ES does not support round function.
  // The algorithm is based on banker's rounding.
  float base = floor(x);
  if ((x - base) < 0.5) {
    return floor(x);
  } else if ((x - base) > 0.5) {
    return ceil(x);
  } else {
    if (mod(base, 2.0) == 0.0) {
      return base;
    } else {
      return base + 1.0;
    }
  }
`;
    // Element-wise unary kernel built from the ROUND shader snippet.
    const round$3 = unaryKernelFunc$1({ opSnippet: ROUND });
    /** WebGL kernel registration for the Round kernel. */
    const roundConfig$1 = {
        kernelName: Round,
        backendName: 'webgl',
        kernelFunc: round$3,
    };
96473
96474 /**
96475 * @license
96476 * Copyright 2020 Google LLC. All Rights Reserved.
96477 * Licensed under the Apache License, Version 2.0 (the "License");
96478 * you may not use this file except in compliance with the License.
96479 * You may obtain a copy of the License at
96480 *
96481 * http://www.apache.org/licenses/LICENSE-2.0
96482 *
96483 * Unless required by applicable law or agreed to in writing, software
96484 * distributed under the License is distributed on an "AS IS" BASIS,
96485 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96486 * See the License for the specific language governing permissions and
96487 * limitations under the License.
96488 * =============================================================================
96489 */
    // GLSL snippet for reciprocal square root: 1 / sqrt(x).
    const RSQRT = `return inversesqrt(x);`;
    // Element-wise unary kernel, with a CPU fallback implementation for
    // tensors resident on the CPU.
    const rsqrt$2 = unaryKernelFunc$1({ opSnippet: RSQRT, cpuKernelImpl: rsqrtImplCPU });
    /** WebGL kernel registration for the Rsqrt kernel. */
    const rsqrtConfig$1 = {
        kernelName: Rsqrt,
        backendName: 'webgl',
        kernelFunc: rsqrt$2
    };
96497
96498 /**
96499 * @license
96500 * Copyright 2018 Google LLC. All Rights Reserved.
96501 * Licensed under the Apache License, Version 2.0 (the "License");
96502 * you may not use this file except in compliance with the License.
96503 * You may obtain a copy of the License at
96504 *
96505 * http://www.apache.org/licenses/LICENSE-2.0
96506 *
96507 * Unless required by applicable law or agreed to in writing, software
96508 * distributed under the License is distributed on an "AS IS" BASIS,
96509 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96510 * See the License for the specific language governing permissions and
96511 * limitations under the License.
96512 * =============================================================================
96513 */
    /**
     * GPGPU program implementing scatter into a flattened [numSlices,
     * sliceSize] output. For each output position it linearly scans all
     * updates, sums those whose flattened index matches the output row
     * (duplicate indices accumulate), and falls back to `defaultValue` when
     * no update targets that row.
     */
    class ScatterProgram {
        /**
         * @param updateSize Number of update rows to scan.
         * @param sliceDim Number of index components per update (rank of each
         *     index tuple).
         * @param indicesRank Rank of the (flattened) indices tensor (1 or 2).
         * @param updatesRank Rank of the (flattened) updates tensor (1 or 2).
         * @param strides Row strides used to flatten an index tuple.
         * @param shape Output shape [numSlices, sliceSize].
         * @param summingDupeIndex Kept for API compatibility; the generated
         *     shader always sums duplicate indices.
         */
        constructor(updateSize, sliceDim, indicesRank, updatesRank, strides, shape, summingDupeIndex = true) {
            this.variableNames = ['updates', 'indices', 'defaultValue'];
            this.outputShape = shape;
            const stridesType = getCoordsDataType(strides.length);
            const dtype = getCoordsDataType(shape.length);
            // How to read the j-th component of the i-th index tuple.
            let indicesString = '';
            if (indicesRank === 1) {
                indicesString = 'i';
            }
            else if (indicesRank === 2) {
                indicesString = 'i, j';
            }
            const indicesSnippet = `getIndices(${indicesString})`;
            // How to read the update value for row i at the current column.
            let updatesString = '';
            if (updatesRank === 1) {
                updatesString = 'i';
            }
            else if (updatesRank === 2) {
                updatesString = 'i, coords[1]';
            }
            const updatesSnippet = `getUpdates(${updatesString})`;
            // With a single slice dimension, strides is a scalar in GLSL.
            const strideString = sliceDim > 1 ? 'strides[j]' : 'strides';
            this.userCode = `
        ${stridesType} strides = ${stridesType}(${strides});

        void main() {
          ${dtype} coords = getOutputCoords();
          float sum = 0.0;
          bool found = false;
          for (int i = 0; i < ${updateSize}; i++) {
            int flattenedIndex = 0;
            for (int j = 0; j < ${sliceDim}; j++) {
              int index = round(${indicesSnippet});
              flattenedIndex += index * ${strideString};
            }
            if (flattenedIndex == coords[0]) {
              sum += ${updatesSnippet};
              found = true;
            }
          }
          setOutput(mix(getDefaultValue(), sum, float(found)));
        }
      `;
        }
    }
96560
96561 /**
96562 * @license
96563 * Copyright 2020 Google LLC. All Rights Reserved.
96564 * Licensed under the Apache License, Version 2.0 (the "License");
96565 * you may not use this file except in compliance with the License.
96566 * You may obtain a copy of the License at
96567 *
96568 * http://www.apache.org/licenses/LICENSE-2.0
96569 *
96570 * Unless required by applicable law or agreed to in writing, software
96571 * distributed under the License is distributed on an "AS IS" BASIS,
96572 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96573 * See the License for the specific language governing permissions and
96574 * limitations under the License.
96575 * =============================================================================
96576 */
96577 function scatterNd$1(args) {
96578 const { inputs, backend, attrs } = args;
96579 const { indices, updates } = inputs;
96580 const { shape } = attrs;
96581 const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(updates, indices, shape);
96582 const flattenShape = [outputSize / sliceSize, sliceSize];
96583 if (outputSize === 0) {
96584 return backend.makeTensorInfo(shape, indices.dtype);
96585 }
96586 const flattenIndices = reshape$3({ inputs: { x: indices }, backend, attrs: { shape: [numUpdates, sliceRank] } });
96587 const flattenX = reshape$3({ inputs: { x: updates }, backend, attrs: { shape: [numUpdates, sliceSize] } });
96588 const defaultValue = backend.makeTensorInfo([], 'float32', new Float32Array([0])); // scalar(0)
96589 const program = new ScatterProgram(numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape);
96590 const res = backend.runWebGLProgram(program, [flattenX, flattenIndices, defaultValue], flattenX.dtype);
96591 const reshaped = reshape$3({ inputs: { x: res }, backend, attrs: { shape } });
96592 backend.disposeIntermediateTensorInfo(flattenIndices);
96593 backend.disposeIntermediateTensorInfo(flattenX);
96594 backend.disposeIntermediateTensorInfo(res);
96595 backend.disposeIntermediateTensorInfo(defaultValue);
96596 return reshaped;
96597 }
    /** WebGL kernel registration for the ScatterNd kernel. */
    const scatterNdConfig$1 = {
        kernelName: ScatterNd,
        backendName: 'webgl',
        kernelFunc: scatterNd$1
    };
96603
96604 /**
96605 * @license
96606 * Copyright 2022 Google LLC. All Rights Reserved.
96607 * Licensed under the Apache License, Version 2.0 (the "License");
96608 * you may not use this file except in compliance with the License.
96609 * You may obtain a copy of the License at
96610 *
96611 * http://www.apache.org/licenses/LICENSE-2.0
96612 *
96613 * Unless required by applicable law or agreed to in writing, software
96614 * distributed under the License is distributed on an "AS IS" BASIS,
96615 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96616 * See the License for the specific language governing permissions and
96617 * limitations under the License.
96618 * =============================================================================
96619 */
    /**
     * GPGPU program implementing batched binary search (searchsorted). For
     * each value it returns the insertion index into the corresponding sorted
     * sequence row that keeps the row sorted. The sequence length is passed at
     * run time via the `numInputs` uniform.
     */
    class SearchSortedProgram {
        /**
         * @param batchSize Number of independent sorted rows.
         * @param numInputs Length of each sorted row (used only to bound the
         *     WebGL1 loop; the runtime value comes from the uniform).
         * @param numValues Number of query values per row.
         * @param side 'left' returns the lower bound; 'right' the upper bound.
         */
        constructor(batchSize, numInputs, numValues, side) {
            this.variableNames = ['sortedSequence', 'values'];
            this.customUniforms = [{ name: 'numInputs', type: 'int' }];
            this.outputShape = [batchSize, numValues];
            const webGL2LoopHead = 'while (left < right) {';
            // WebGL1 doesn't accept non constant loop conditions, so upper bound loop
            // iterations at ceil(log2(numInputs + 1)) — enough for the binary
            // search to converge — and break early once it has.
            const webGL1LoopHead = `for (int i = 0; i < ${Math.ceil(Math.log2(numInputs + 1))}; ++i) { if (left >= right) break;`;
            const loopHead = env().getNumber('WEBGL_VERSION') === 2 ? webGL2LoopHead :
                webGL1LoopHead;
            // left corresponds to lower bound and right to upper bound.
            const boundComparator = side === 'left' ? '<' : '<=';
            this.userCode = `
       int findBound(int batch, float value) {
         int left = 0;
         int right = numInputs;
         int mid;
         ${loopHead}
           mid = (left + right) / 2;
           if (getSortedSequence(batch, mid) ${boundComparator} value) {
             left = mid + 1;
           } else {
             right = mid;
           }
         }
         return right;
       }

       void main() {
         ivec2 coords = getOutputCoords();
         int batch = coords[0];
         int valueIndex = coords[1];

         float value = getValues(batch, valueIndex);

         setOutput(float(findBound(batch, value)));
       }
     `;
        }
    }
96661
96662 /**
96663 * @license
96664 * Copyright 2022 Google LLC. All Rights Reserved.
96665 * Licensed under the Apache License, Version 2.0 (the "License");
96666 * you may not use this file except in compliance with the License.
96667 * You may obtain a copy of the License at
96668 *
96669 * http://www.apache.org/licenses/LICENSE-2.0
96670 *
96671 * Unless required by applicable law or agreed to in writing, software
96672 * distributed under the License is distributed on an "AS IS" BASIS,
96673 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96674 * See the License for the specific language governing permissions and
96675 * limitations under the License.
96676 * =============================================================================
96677 */
96678 function searchSorted$2(args) {
96679 const { inputs, backend, attrs } = args;
96680 const { sortedSequence, values } = inputs;
96681 const { side } = attrs;
96682 const program = new SearchSortedProgram(sortedSequence.shape[0], sortedSequence.shape[1], values.shape[1], side);
96683 const customValues = [[sortedSequence.shape[1]]];
96684 return backend.runWebGLProgram(program, [sortedSequence, values], 'int32', customValues);
96685 }
    // Registers the SearchSorted kernel for the WebGL backend.
    const searchSortedConfig$1 = {
        kernelName: SearchSorted,
        backendName: 'webgl',
        kernelFunc: searchSorted$2,
    };
96691
96692 /**
96693 * @license
96694 * Copyright 2017 Google LLC. All Rights Reserved.
96695 * Licensed under the Apache License, Version 2.0 (the "License");
96696 * you may not use this file except in compliance with the License.
96697 * You may obtain a copy of the License at
96698 *
96699 * http://www.apache.org/licenses/LICENSE-2.0
96700 *
96701 * Unless required by applicable law or agreed to in writing, software
96702 * distributed under the License is distributed on an "AS IS" BASIS,
96703 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96704 * See the License for the specific language governing permissions and
96705 * limitations under the License.
96706 * =============================================================================
96707 */
96708 class SelectProgram {
96709 constructor(cRank, shape, rank) {
96710 this.variableNames = ['c', 'a', 'b'];
96711 this.outputShape = shape;
96712 let cCoords;
96713 let abCoords;
96714 if (rank > 4) {
96715 throw Error(`Where for rank ${rank} is not yet supported`);
96716 }
96717 if (rank === 1) {
96718 abCoords = `resRC`;
96719 cCoords = `resRC`;
96720 }
96721 else {
96722 const currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
96723 const cCoordVars = [];
96724 const abCoordVars = [];
96725 for (let i = 0; i < shape.length; i++) {
96726 abCoordVars.push(`${currentCoords[i]}`);
96727 if (i < cRank) {
96728 cCoordVars.push(`${currentCoords[i]}`);
96729 }
96730 }
96731 cCoords = cCoordVars.join();
96732 abCoords = abCoordVars.join();
96733 }
96734 const dtype = getCoordsDataType(rank);
96735 this.userCode = `
96736 void main() {
96737 ${dtype} resRC = getOutputCoords();
96738 float cVal = getC(${cCoords});
96739 if (cVal >= 1.0) {
96740 setOutput(getA(${abCoords}));
96741 } else {
96742 setOutput(getB(${abCoords}));
96743 }
96744 }
96745 `;
96746 }
96747 }
96748
96749 /**
96750 * @license
96751 * Copyright 2020 Google LLC. All Rights Reserved.
96752 * Licensed under the Apache License, Version 2.0 (the "License");
96753 * you may not use this file except in compliance with the License.
96754 * You may obtain a copy of the License at
96755 *
96756 * http://www.apache.org/licenses/LICENSE-2.0
96757 *
96758 * Unless required by applicable law or agreed to in writing, software
96759 * distributed under the License is distributed on an "AS IS" BASIS,
96760 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96761 * See the License for the specific language governing permissions and
96762 * limitations under the License.
96763 * =============================================================================
96764 */
96765 function select$2(args) {
96766 const { inputs, backend } = args;
96767 const { condition, t, e } = inputs;
96768 const program = new SelectProgram(condition.shape.length, t.shape, t.shape.length);
96769 return backend.runWebGLProgram(program, [condition, t, e], upcastType(t.dtype, e.dtype));
96770 }
    // Registers the Select kernel for the WebGL backend.
    const selectConfig$1 = {
        kernelName: Select,
        backendName: 'webgl',
        kernelFunc: select$2
    };
96776
96777 /**
96778 * @license
96779 * Copyright 2020 Google LLC. All Rights Reserved.
96780 * Licensed under the Apache License, Version 2.0 (the "License");
96781 * you may not use this file except in compliance with the License.
96782 * You may obtain a copy of the License at
96783 *
96784 * http://www.apache.org/licenses/LICENSE-2.0
96785 *
96786 * Unless required by applicable law or agreed to in writing, software
96787 * distributed under the License is distributed on an "AS IS" BASIS,
96788 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96789 * See the License for the specific language governing permissions and
96790 * limitations under the License.
96791 * =============================================================================
96792 */
    // GLSL snippet for SELU: `scale * x` for non-negative x, otherwise
    // `scaleAlpha * (exp(x) - 1)`. The constants are interpolated from
    // SELU_SCALEALPHA / SELU_SCALE when this module is evaluated.
    const SELU = `
  // Stable and Attracting Fixed Point (0, 1) for Normalized Weights.
  // see: https://arxiv.org/abs/1706.02515
  float scaleAlpha = ${SELU_SCALEALPHA};
  float scale = ${SELU_SCALE};
  return (x >= 0.0) ? scale * x : scaleAlpha * (exp(x) - 1.0);
`;
    // Element-wise kernel built from the shared unary-op helper.
    const selu$2 = unaryKernelFunc$1({ opSnippet: SELU });
    // Registers the Selu kernel for the WebGL backend.
    const seluConfig$1 = {
        kernelName: Selu,
        backendName: 'webgl',
        kernelFunc: selu$2,
    };
96806
96807 /**
96808 * @license
96809 * Copyright 2020 Google LLC. All Rights Reserved.
96810 * Licensed under the Apache License, Version 2.0 (the "License");
96811 * you may not use this file except in compliance with the License.
96812 * You may obtain a copy of the License at
96813 *
96814 * http://www.apache.org/licenses/LICENSE-2.0
96815 *
96816 * Unless required by applicable law or agreed to in writing, software
96817 * distributed under the License is distributed on an "AS IS" BASIS,
96818 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96819 * See the License for the specific language governing permissions and
96820 * limitations under the License.
96821 * =============================================================================
96822 */
    // Scalar sigmoid snippet; the CHECK_NAN_SNIPPET_UNARY prefix makes NaN
    // inputs produce NaN outputs.
    const SIGMOID$2 = CHECK_NAN_SNIPPET_UNARY + `
  return 1.0 / (1.0 + exp(-1.0 * x));
`;
    // Packed (vec4) variant: evaluates all four channels at once and then
    // explicitly copies NaN inputs through, channel by channel.
    const SIGMOID_PACKED = `
  vec4 result = 1.0 / (1.0 + exp(-1.0 * x));
  bvec4 isNaN = isnan(x);

  result.r = isNaN.r ? x.r : result.r;
  result.g = isNaN.g ? x.g : result.g;
  result.b = isNaN.b ? x.b : result.b;
  result.a = isNaN.a ? x.a : result.a;

  return result;
`;
    // Element-wise kernel; supplies both packed and unpacked shaders plus a
    // CPU implementation (sigmoidImplCPU) for the CPU-forwarding path.
    const sigmoid$2 = unaryKernelFunc$1({
        opSnippet: SIGMOID$2,
        packedOpSnippet: SIGMOID_PACKED,
        cpuKernelImpl: sigmoidImplCPU
    });
    // Registers the Sigmoid kernel for the WebGL backend.
    const sigmoidConfig$1 = {
        kernelName: Sigmoid,
        backendName: 'webgl',
        kernelFunc: sigmoid$2,
    };
96847
96848 /**
96849 * @license
96850 * Copyright 2020 Google LLC. All Rights Reserved.
96851 * Licensed under the Apache License, Version 2.0 (the "License");
96852 * you may not use this file except in compliance with the License.
96853 * You may obtain a copy of the License at
96854 *
96855 * http://www.apache.org/licenses/LICENSE-2.0
96856 *
96857 * Unless required by applicable law or agreed to in writing, software
96858 * distributed under the License is distributed on an "AS IS" BASIS,
96859 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96860 * See the License for the specific language governing permissions and
96861 * limitations under the License.
96862 * =============================================================================
96863 */
    // Sign does not propagate NANs.
    // GLSL snippet for Sign: NaN inputs map to 0.0 instead of propagating.
    const SIGN = `
  if (isnan(x)) { return 0.0; }
  return sign(x);
`;
    const sign$3 = unaryKernelFunc$1({ opSnippet: SIGN });
    // Registers the Sign kernel for the WebGL backend.
    const signConfig$1 = {
        kernelName: Sign,
        backendName: 'webgl',
        kernelFunc: sign$3,
    };
96875
96876 /**
96877 * @license
96878 * Copyright 2020 Google LLC. All Rights Reserved.
96879 * Licensed under the Apache License, Version 2.0 (the "License");
96880 * you may not use this file except in compliance with the License.
96881 * You may obtain a copy of the License at
96882 *
96883 * http://www.apache.org/licenses/LICENSE-2.0
96884 *
96885 * Unless required by applicable law or agreed to in writing, software
96886 * distributed under the License is distributed on an "AS IS" BASIS,
96887 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96888 * See the License for the specific language governing permissions and
96889 * limitations under the License.
96890 * =============================================================================
96891 */
    // GLSL sine; the CHECK_NAN_SNIPPET_UNARY prefix makes NaN inputs return NaN.
    const SIN = CHECK_NAN_SNIPPET_UNARY + `
  return sin(x);
`;
    const sin$2 = unaryKernelFunc$1({ opSnippet: SIN });
    // Registers the Sin kernel for the WebGL backend.
    const sinConfig$1 = {
        kernelName: Sin,
        backendName: 'webgl',
        kernelFunc: sin$2,
    };
96901
96902 /**
96903 * @license
96904 * Copyright 2020 Google LLC. All Rights Reserved.
96905 * Licensed under the Apache License, Version 2.0 (the "License");
96906 * you may not use this file except in compliance with the License.
96907 * You may obtain a copy of the License at
96908 *
96909 * http://www.apache.org/licenses/LICENSE-2.0
96910 *
96911 * Unless required by applicable law or agreed to in writing, software
96912 * distributed under the License is distributed on an "AS IS" BASIS,
96913 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96914 * See the License for the specific language governing permissions and
96915 * limitations under the License.
96916 * =============================================================================
96917 */
    // GLSL hyperbolic sine via the identity sinh(x) = (e^x - e^-x) / 2, with
    // e^-x computed as 1/e^x so only a single exp() call is needed.
    const SINH = `
  float e2x = exp(x);
  return (e2x - 1.0 / e2x) / 2.0;
`;
    const sinh$2 = unaryKernelFunc$1({ opSnippet: SINH });
    // Registers the Sinh kernel for the WebGL backend.
    const sinhConfig$1 = {
        kernelName: Sinh,
        backendName: 'webgl',
        kernelFunc: sinh$2,
    };
96928
96929 /**
96930 * @license
96931 * Copyright 2020 Google LLC. All Rights Reserved.
96932 * Licensed under the Apache License, Version 2.0 (the "License");
96933 * you may not use this file except in compliance with the License.
96934 * You may obtain a copy of the License at
96935 *
96936 * http://www.apache.org/licenses/LICENSE-2.0
96937 *
96938 * Unless required by applicable law or agreed to in writing, software
96939 * distributed under the License is distributed on an "AS IS" BASIS,
96940 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96941 * See the License for the specific language governing permissions and
96942 * limitations under the License.
96943 * =============================================================================
96944 */
    // Numerically-stable softplus log(1 + e^x), branching on x:
    //   x >> 0 -> x          (log(1 + e^x) ~= x; avoids exp overflow)
    //   x << 0 -> e^x        (log(1 + e^x) ~= e^x)
    //   else   -> log(e^x + 1) directly.
    // `epsilon` is float32 machine epsilon; the threshold derives from it.
    const SOFTPLUS = `
  float epsilon = 1.1920928955078125e-7;
  float threshold = log(epsilon) + 2.0;

  bool too_large = x > -threshold;
  bool too_small = x < threshold;

  float result;
  float exp_x = exp(x);

  if (too_large){
    result = x;
  }
  else if (too_small){
    result = exp_x;
  }
  else{
    result = log(exp_x + 1.0);
  }
  return result;
`;
    const softplus$2 = unaryKernelFunc$1({ opSnippet: SOFTPLUS });
    // Registers the Softplus kernel for the WebGL backend.
    const softplusConfig$1 = {
        kernelName: Softplus,
        backendName: 'webgl',
        kernelFunc: softplus$2,
    };
96972
96973 /**
96974 * @license
96975 * Copyright 2020 Google LLC. All Rights Reserved.
96976 * Licensed under the Apache License, Version 2.0 (the "License");
96977 * you may not use this file except in compliance with the License.
96978 * You may obtain a copy of the License at
96979 *
96980 * http://www.apache.org/licenses/LICENSE-2.0
96981 *
96982 * Unless required by applicable law or agreed to in writing, software
96983 * distributed under the License is distributed on an "AS IS" BASIS,
96984 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96985 * See the License for the specific language governing permissions and
96986 * limitations under the License.
96987 * =============================================================================
96988 */
96989 const spaceToBatchND$2 = (args) => {
96990 const { inputs, backend, attrs } = args;
96991 const { x } = inputs;
96992 const { blockShape, paddings } = attrs;
96993 assert(x.shape.length <= 4, () => 'spaceToBatchND for rank > 4 with a WebGL backend not ' +
96994 'implemented yet');
96995 const prod = blockShape.reduce((a, b) => a * b);
96996 const completePaddings = [[0, 0]];
96997 completePaddings.push(...paddings);
96998 for (let i = 1 + blockShape.length; i < x.shape.length; ++i) {
96999 completePaddings.push([0, 0]);
97000 }
97001 const toDispose = [];
97002 const paddedX = padV2$1({
97003 inputs: { x },
97004 backend,
97005 attrs: { paddings: completePaddings, constantValue: 0 }
97006 });
97007 const reshapedPaddedShape = getReshaped(paddedX.shape, blockShape, prod, false);
97008 const permutedReshapedPaddedPermutation = getPermuted(reshapedPaddedShape.length, blockShape.length, false);
97009 const flattenShape = getReshapedPermuted(paddedX.shape, blockShape, prod, false);
97010 const reshapedPaddedX = reshape$3({ inputs: { x: paddedX }, backend, attrs: { shape: reshapedPaddedShape } });
97011 const paddedXT = transpose$2({
97012 inputs: { x: reshapedPaddedX },
97013 backend,
97014 attrs: { perm: permutedReshapedPaddedPermutation }
97015 });
97016 const result = reshape$3({ inputs: { x: paddedXT }, backend, attrs: { shape: flattenShape } });
97017 toDispose.push(paddedX);
97018 toDispose.push(reshapedPaddedX);
97019 toDispose.push(paddedXT);
97020 toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
97021 return result;
97022 };
    // Registers the SpaceToBatchND kernel for the WebGL backend.
    const spaceToBatchNDConfig$1 = {
        kernelName: SpaceToBatchND,
        backendName: 'webgl',
        kernelFunc: spaceToBatchND$2
    };
97028
97029 /**
97030 * @license
97031 * Copyright 2021 Google LLC. All Rights Reserved.
97032 * Licensed under the Apache License, Version 2.0 (the "License");
97033 * you may not use this file except in compliance with the License.
97034 * You may obtain a copy of the License at
97035 *
97036 * http://www.apache.org/licenses/LICENSE-2.0
97037 *
97038 * Unless required by applicable law or agreed to in writing, software
97039 * distributed under the License is distributed on an "AS IS" BASIS,
97040 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97041 * See the License for the specific language governing permissions and
97042 * limitations under the License.
97043 * =============================================================================
97044 */
97045 function sparseFillEmptyRows$2(args) {
97046 const { inputs, backend } = args;
97047 const { indices, values, denseShape, defaultValue } = inputs;
97048 if (denseShape.shape.length !== 1) {
97049 throw new Error(`Dense shape must be a vector, saw:
97050 ${denseShape.shape}`);
97051 }
97052 if (indices.shape.length !== 2) {
97053 throw new Error(`Indices must be a matrix, saw:
97054 ${indices.shape}`);
97055 }
97056 if (values.shape.length !== 1) {
97057 throw new Error(`Values must be a vector, saw:
97058 ${values.shape}`);
97059 }
97060 if (defaultValue.shape.length !== 0) {
97061 throw new Error(`Default value must be a scalar, saw:
97062 ${defaultValue.shape}`);
97063 }
97064 const $indices = backend.readSync(indices.dataId);
97065 const $values = backend.readSync(values.dataId);
97066 const $denseShape = backend.readSync(denseShape.dataId);
97067 const $defaultValue = backend.readSync(defaultValue.dataId)[0];
97068 const [outputIndices, outputIndicesShape, outputValues, emptyRowIndicator, reverseIndexMap] = sparseFillEmptyRowsImplCPU($indices, indices.shape, indices.dtype, $values, values.dtype, $denseShape, $defaultValue);
97069 return [
97070 backend.makeTensorInfo(outputIndicesShape, indices.dtype, outputIndices),
97071 backend.makeTensorInfo([outputIndicesShape[0]], values.dtype, outputValues),
97072 backend.makeTensorInfo([emptyRowIndicator.length], 'bool', new Uint8Array(emptyRowIndicator.map((value) => Number(value)))),
97073 backend.makeTensorInfo([reverseIndexMap.length], indices.dtype, new Int32Array(reverseIndexMap)),
97074 ];
97075 }
    // Registers the SparseFillEmptyRows kernel for the WebGL backend.
    const sparseFillEmptyRowsConfig$1 = {
        kernelName: SparseFillEmptyRows,
        backendName: 'webgl',
        kernelFunc: sparseFillEmptyRows$2,
    };
97081
97082 /**
97083 * @license
97084 * Copyright 2021 Google LLC. All Rights Reserved.
97085 * Licensed under the Apache License, Version 2.0 (the "License");
97086 * you may not use this file except in compliance with the License.
97087 * You may obtain a copy of the License at
97088 *
97089 * http://www.apache.org/licenses/LICENSE-2.0
97090 *
97091 * Unless required by applicable law or agreed to in writing, software
97092 * distributed under the License is distributed on an "AS IS" BASIS,
97093 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97094 * See the License for the specific language governing permissions and
97095 * limitations under the License.
97096 * =============================================================================
97097 */
97098 function sparseReshape$2(args) {
97099 const { inputs, backend } = args;
97100 const { inputIndices, inputShape, newShape } = inputs;
97101 if (inputIndices.shape.length !== 2) {
97102 throw new Error(`Input indices should be a matrix but received shape ${inputIndices.shape}`);
97103 }
97104 if (inputShape.shape.length !== 1) {
97105 throw new Error(`Input shape should be a vector but received shape ${inputShape.shape}`);
97106 }
97107 if (newShape.shape.length !== 1) {
97108 throw new Error(`Target shape should be a vector but received shape ${newShape.shape}`);
97109 }
97110 const $inputShape = Array.from(backend.readSync(inputShape.dataId));
97111 const $inputIndices = backend.readSync(inputIndices.dataId);
97112 const targetShape = Array.from(backend.readSync(newShape.dataId));
97113 const [newIndices, indicesShape, outputShape] = sparseReshapeImplCPU($inputIndices, inputIndices.shape, inputIndices.dtype, $inputShape, targetShape);
97114 return [
97115 backend.makeTensorInfo(indicesShape, inputIndices.dtype, newIndices),
97116 backend.makeTensorInfo([outputShape.length], newShape.dtype, new Int32Array(outputShape)),
97117 ];
97118 }
    // Registers the SparseReshape kernel for the WebGL backend.
    const sparseReshapeConfig$1 = {
        kernelName: SparseReshape,
        backendName: 'webgl',
        kernelFunc: sparseReshape$2,
    };
97124
97125 /**
97126 * @license
97127 * Copyright 2021 Google LLC. All Rights Reserved.
97128 * Licensed under the Apache License, Version 2.0 (the "License");
97129 * you may not use this file except in compliance with the License.
97130 * You may obtain a copy of the License at
97131 *
97132 * http://www.apache.org/licenses/LICENSE-2.0
97133 *
97134 * Unless required by applicable law or agreed to in writing, software
97135 * distributed under the License is distributed on an "AS IS" BASIS,
97136 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97137 * See the License for the specific language governing permissions and
97138 * limitations under the License.
97139 * =============================================================================
97140 */
97141 function sparseSegmentMean$2(args) {
97142 const { inputs, backend } = args;
97143 const { data, indices, segmentIds } = inputs;
97144 if (data.shape.length < 1) {
97145 throw new Error(`Data should be at least 1 dimensional but received scalar`);
97146 }
97147 if (indices.shape.length !== 1) {
97148 throw new Error(`Indices should be a vector but received shape
97149 ${indices.shape}`);
97150 }
97151 if (segmentIds.shape.length !== 1) {
97152 throw new Error(`Segment ids should be a vector but received shape
97153 ${segmentIds.shape}`);
97154 }
97155 const $data = backend.readSync(data.dataId);
97156 const $indices = backend.readSync(indices.dataId);
97157 const $segmentIds = backend.readSync(segmentIds.dataId);
97158 const [outputData, outputDataShape] = sparseSegmentReductionImplCPU($data, data.shape, data.dtype, $indices, $segmentIds, true);
97159 return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
97160 }
    // Registers the SparseSegmentMean kernel for the WebGL backend.
    const sparseSegmentMeanConfig$1 = {
        kernelName: SparseSegmentMean,
        backendName: 'webgl',
        kernelFunc: sparseSegmentMean$2,
    };
97166
97167 /**
97168 * @license
97169 * Copyright 2021 Google LLC. All Rights Reserved.
97170 * Licensed under the Apache License, Version 2.0 (the "License");
97171 * you may not use this file except in compliance with the License.
97172 * You may obtain a copy of the License at
97173 *
97174 * http://www.apache.org/licenses/LICENSE-2.0
97175 *
97176 * Unless required by applicable law or agreed to in writing, software
97177 * distributed under the License is distributed on an "AS IS" BASIS,
97178 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97179 * See the License for the specific language governing permissions and
97180 * limitations under the License.
97181 * =============================================================================
97182 */
97183 function sparseSegmentSum$2(args) {
97184 const { inputs, backend } = args;
97185 const { data, indices, segmentIds } = inputs;
97186 if (data.shape.length < 1) {
97187 throw new Error(`Data should be at least 1 dimensional but received scalar`);
97188 }
97189 if (indices.shape.length !== 1) {
97190 throw new Error(`Indices should be a vector but received shape
97191 ${indices.shape}`);
97192 }
97193 if (segmentIds.shape.length !== 1) {
97194 throw new Error(`Segment ids should be a vector but received shape
97195 ${segmentIds.shape}`);
97196 }
97197 const $data = backend.readSync(data.dataId);
97198 const $indices = backend.readSync(indices.dataId);
97199 const $segmentIds = backend.readSync(segmentIds.dataId);
97200 const [outputData, outputDataShape] = sparseSegmentReductionImplCPU($data, data.shape, data.dtype, $indices, $segmentIds);
97201 return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
97202 }
    // Registers the SparseSegmentSum kernel for the WebGL backend.
    const sparseSegmentSumConfig$1 = {
        kernelName: SparseSegmentSum,
        backendName: 'webgl',
        kernelFunc: sparseSegmentSum$2,
    };
97208
97209 /**
97210 * @license
97211 * Copyright 2020 Google LLC. All Rights Reserved.
97212 * Licensed under the Apache License, Version 2.0 (the "License");
97213 * you may not use this file except in compliance with the License.
97214 * You may obtain a copy of the License at
97215 *
97216 * http://www.apache.org/licenses/LICENSE-2.0
97217 *
97218 * Unless required by applicable law or agreed to in writing, software
97219 * distributed under the License is distributed on an "AS IS" BASIS,
97220 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97221 * See the License for the specific language governing permissions and
97222 * limitations under the License.
97223 * =============================================================================
97224 */
97225 function sparseToDense$2(args) {
97226 const { inputs, backend, attrs } = args;
97227 const { sparseIndices, sparseValues, defaultValue } = inputs;
97228 const { outputShape } = attrs;
97229 const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(sparseValues, sparseIndices, outputShape);
97230 const sumDupeIndices = false;
97231 if (sparseValues.dtype === 'string') {
97232 const indicesBuf = backend.bufferSync(sparseIndices);
97233 const updatesBuf = backend.bufferSync(sparseValues);
97234 const $defaultValue = decodeString(backend.readSync(defaultValue.dataId)[0]);
97235 const outBuf = scatterImplCPU(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices);
97236 return backend.makeTensorInfo(outputShape, outBuf.dtype, outBuf.values);
97237 }
97238 const program = new ScatterProgram(numUpdates, sliceRank, sparseIndices.shape.length, sparseValues.shape.length, strides, [outputSize, 1], sumDupeIndices);
97239 const res = backend.runWebGLProgram(program, [sparseValues, sparseIndices, defaultValue], sparseValues.dtype);
97240 const reshaped = reshape$3({ inputs: { x: res }, backend, attrs: { shape: outputShape } });
97241 backend.disposeIntermediateTensorInfo(res);
97242 return reshaped;
97243 }
    // Registers the SparseToDense kernel for the WebGL backend.
    const sparseToDenseConfig$1 = {
        kernelName: SparseToDense,
        backendName: 'webgl',
        kernelFunc: sparseToDense$2
    };
97249
97250 /**
97251 * @license
97252 * Copyright 2020 Google LLC. All Rights Reserved.
97253 * Licensed under the Apache License, Version 2.0 (the "License");
97254 * you may not use this file except in compliance with the License.
97255 * You may obtain a copy of the License at
97256 *
97257 * http://www.apache.org/licenses/LICENSE-2.0
97258 *
97259 * Unless required by applicable law or agreed to in writing, software
97260 * distributed under the License is distributed on an "AS IS" BASIS,
97261 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97262 * See the License for the specific language governing permissions and
97263 * limitations under the License.
97264 * =============================================================================
97265 */
97266 function splitV$1(args) {
97267 const { inputs, backend, attrs } = args;
97268 const { x } = inputs;
97269 const { numOrSizeSplits, axis } = attrs;
97270 const $axis = parseAxisParam(axis, x.shape)[0];
97271 const splitSizes = prepareSplitSize(x, numOrSizeSplits, $axis);
97272 const xRank = x.shape.length;
97273 const begin = new Array(xRank).fill(0);
97274 const size = x.shape.slice();
97275 return splitSizes.map(s => {
97276 const sliceSize = [...size];
97277 sliceSize[$axis] = s;
97278 const sliceT = slice$2({ inputs: { x }, backend, attrs: { begin, size: sliceSize } });
97279 begin[$axis] += s;
97280 return sliceT;
97281 });
97282 }
    // Registers the SplitV kernel for the WebGL backend.
    const splitVConfig$1 = {
        kernelName: SplitV,
        backendName: 'webgl',
        kernelFunc: splitV$1
    };
97288
97289 /**
97290 * @license
97291 * Copyright 2020 Google LLC. All Rights Reserved.
97292 * Licensed under the Apache License, Version 2.0 (the "License");
97293 * you may not use this file except in compliance with the License.
97294 * You may obtain a copy of the License at
97295 *
97296 * http://www.apache.org/licenses/LICENSE-2.0
97297 *
97298 * Unless required by applicable law or agreed to in writing, software
97299 * distributed under the License is distributed on an "AS IS" BASIS,
97300 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97301 * See the License for the specific language governing permissions and
97302 * limitations under the License.
97303 * =============================================================================
97304 */
    // `sqrt` maps directly onto the GLSL builtin; the same snippet serves the
    // packed and unpacked paths, with sqrtImplCPU as the CPU implementation.
    const SQRT = `return sqrt(x);`;
    const sqrt$2 = unaryKernelFunc$1({ opSnippet: SQRT, packedOpSnippet: SQRT, cpuKernelImpl: sqrtImplCPU });
    // Registers the Sqrt kernel for the WebGL backend.
    const sqrtConfig$1 = {
        kernelName: Sqrt,
        backendName: 'webgl',
        kernelFunc: sqrt$2
    };
97312
97313 /**
97314 * @license
97315 * Copyright 2019 Google LLC. All Rights Reserved.
97316 * Licensed under the Apache License, Version 2.0 (the "License");
97317 * you may not use this file except in compliance with the License.
97318 * You may obtain a copy of the License at
97319 *
97320 * http://www.apache.org/licenses/LICENSE-2.0
97321 *
97322 * Unless required by applicable law or agreed to in writing, software
97323 * distributed under the License is distributed on an "AS IS" BASIS,
97324 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97325 * See the License for the specific language governing permissions and
97326 * limitations under the License.
97327 * =============================================================================
97328 */
    // Element-wise square as a plain multiply.
    const SQUARE = `return x * x;`;
    const square$2 = unaryKernelFunc$1({ opSnippet: SQUARE });
    // Registers the Square kernel for the WebGL backend.
    const squareConfig$1 = {
        kernelName: Square,
        backendName: 'webgl',
        kernelFunc: square$2,
    };
97336
97337 /**
97338 * @license
97339 * Copyright 2020 Google LLC. All Rights Reserved.
97340 * Licensed under the Apache License, Version 2.0 (the "License");
97341 * you may not use this file except in compliance with the License.
97342 * You may obtain a copy of the License at
97343 *
97344 * http://www.apache.org/licenses/LICENSE-2.0
97345 *
97346 * Unless required by applicable law or agreed to in writing, software
97347 * distributed under the License is distributed on an "AS IS" BASIS,
97348 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97349 * See the License for the specific language governing permissions and
97350 * limitations under the License.
97351 * =============================================================================
97352 */
    // (a - b)^2, usable verbatim for both packed and unpacked binary shaders.
    const SQUARED_DIFFERENCE$1 = 'return (a - b) * (a - b);';
    const squaredDifference$2 = binaryKernelFunc$1({ opSnippet: SQUARED_DIFFERENCE$1, packedOpSnippet: SQUARED_DIFFERENCE$1 });
    // Registers the SquaredDifference kernel for the WebGL backend.
    const squaredDifferenceConfig$1 = {
        kernelName: SquaredDifference,
        backendName: 'webgl',
        kernelFunc: squaredDifference$2,
    };
97360
97361 /**
97362 * @license
97363 * Copyright 2020 Google LLC. All Rights Reserved.
97364 * Licensed under the Apache License, Version 2.0 (the "License");
97365 * you may not use this file except in compliance with the License.
97366 * You may obtain a copy of the License at
97367 *
97368 * http://www.apache.org/licenses/LICENSE-2.0
97369 *
97370 * Unless required by applicable law or agreed to in writing, software
97371 * distributed under the License is distributed on an "AS IS" BASIS,
97372 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97373 * See the License for the specific language governing permissions and
97374 * limitations under the License.
97375 * =============================================================================
97376 */
97377 function step$2({ inputs, attrs, backend }) {
97378 const { x } = inputs;
97379 const opSnippet = CHECK_NAN_SNIPPET + `
97380 return x > 0.0 ? 1.0 : float(${attrs.alpha});
97381 `;
97382 const program = new UnaryOpProgram(x.shape, opSnippet);
97383 return backend.runWebGLProgram(program, [x], x.dtype);
97384 }
    // Registers the Step kernel for the WebGL backend.
    const stepConfig$1 = {
        kernelName: Step,
        backendName: 'webgl',
        kernelFunc: step$2,
    };
97390
97391 /**
97392 * @license
97393 * Copyright 2017 Google LLC. All Rights Reserved.
97394 * Licensed under the Apache License, Version 2.0 (the "License");
97395 * you may not use this file except in compliance with the License.
97396 * You may obtain a copy of the License at
97397 *
97398 * http://www.apache.org/licenses/LICENSE-2.0
97399 *
97400 * Unless required by applicable law or agreed to in writing, software
97401 * distributed under the License is distributed on an "AS IS" BASIS,
97402 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97403 * See the License for the specific language governing permissions and
97404 * limitations under the License.
97405 * =============================================================================
97406 */
97407 class StridedSliceProgram {
97408 constructor(begin, strides, size) {
97409 this.variableNames = ['x'];
97410 this.outputShape = size;
97411 const rank = size.length;
97412 const inputDtype = getCoordsDataType(size.length);
97413 const dtype = getCoordsDataType(size.length);
97414 let newCoords = '';
97415 if (rank === 1) {
97416 newCoords = 'coords * strides + begin';
97417 }
97418 else {
97419 let outputAxis = 0;
97420 newCoords =
97421 size.map((_, i) => {
97422 outputAxis++;
97423 return size.length === 1 ?
97424 `coords * strides[${i}] + begin[${i}]` :
97425 `coords[${outputAxis - 1}] * strides[${i}] + begin[${i}]`;
97426 })
97427 .join(',');
97428 }
97429 this.userCode = `
97430 ${inputDtype} begin = ${inputDtype}(${begin});
97431 ${inputDtype} strides = ${inputDtype}(${strides});
97432
97433 void main() {
97434 ${dtype} coords = getOutputCoords();
97435 setOutput(getX(${newCoords}));
97436 }
97437 `;
97438 }
97439 }
97440
97441 /**
97442 * @license
97443 * Copyright 2020 Google LLC. All Rights Reserved.
97444 * Licensed under the Apache License, Version 2.0 (the "License");
97445 * you may not use this file except in compliance with the License.
97446 * You may obtain a copy of the License at
97447 *
97448 * http://www.apache.org/licenses/LICENSE-2.0
97449 *
97450 * Unless required by applicable law or agreed to in writing, software
97451 * distributed under the License is distributed on an "AS IS" BASIS,
97452 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97453 * See the License for the specific language governing permissions and
97454 * limitations under the License.
97455 * =============================================================================
97456 */
    /**
     * WebGL StridedSlice kernel.
     *
     * Resolves the TF-style mask attributes via `sliceInfo`, then picks the
     * cheapest execution path:
     *   1. identity slice          -> a single reshape;
     *   2. memory-contiguous slice -> slice kernel + reshape;
     *   3. general case            -> CPU impl when data is already cheap to
     *      read back, otherwise the StridedSliceProgram shader.
     * All paths are normalized to `finalShape` by a trailing reshape, and the
     * intermediate tensor is disposed.
     *
     * @param args.inputs.x Input tensor info.
     * @param args.backend  The WebGL backend.
     * @param args.attrs    begin/end/strides plus beginMask/endMask/
     *                      ellipsisMask/newAxisMask/shrinkAxisMask.
     * @returns TensorInfo of shape `finalShape`.
     */
    function stridedSlice$2(args) {
        const { inputs, backend, attrs } = args;
        const { x } = inputs;
        const { begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask } = attrs;
        // Canonicalize masked begin/end/strides into dense per-axis values and
        // precompute both the sparse and the final (squeezed) output shapes.
        const { finalShapeSparse, finalShape, isIdentity, sliceDim0, isSimpleSlice, begin: $begin, end: $end, strides: $strides } = sliceInfo(x.shape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
        let result;
        if (isIdentity) {
            // Optimization #1, slice is a no-op plus reshape
            result = reshape$3({ inputs: { x }, backend, attrs: { shape: finalShape } });
        }
        else if (sliceDim0 || isSimpleSlice) {
            // Optimization #2, slice is memory contiguous (only occurs in dim 0)
            assert(x.shape.length >= 1, () => `Input must have rank at least 1, got: ${x.shape.length}`);
            const size = computeOutShape($begin, $end, $strides);
            // To tolerate begin[0] > end[0] (a 0-output slice), we min(begin, end).
            const sliced = slice$2({ inputs: { x }, backend, attrs: { begin: $begin, size } });
            result =
                reshape$3({ inputs: { x: sliced }, backend, attrs: { shape: finalShape } });
            backend.disposeIntermediateTensorInfo(sliced);
        }
        else {
            const shouldExecuteOnCPU = backend.shouldExecuteOnCPU([x]);
            if (shouldExecuteOnCPU) {
                // CPU path: read values back synchronously and run the shared
                // CPU strided-slice implementation on a buffer view.
                // tslint:disable-next-line: no-unnecessary-type-assertion
                const values = backend.readSync(x.dataId);
                // tslint:disable-next-line: no-unnecessary-type-assertion
                const xBuf = buffer(x.shape, x.dtype, values);
                const resultValues = stridedSliceImplCPU(finalShapeSparse, xBuf, $strides, $begin);
                result = backend.makeTensorInfo(finalShape, x.dtype, resultValues.values);
            }
            else {
                // GPU path: the shader produces the sparse (un-squeezed) shape.
                const program = new StridedSliceProgram($begin, $strides, finalShapeSparse);
                result = backend.runWebGLProgram(program, [x], x.dtype);
            }
        }
        // Normalize every path to `finalShape` and release the intermediate.
        const resultReshaped = reshape$3({ inputs: { x: result }, backend, attrs: { shape: finalShape } });
        backend.disposeIntermediateTensorInfo(result);
        return resultReshaped;
    }
    /** Kernel registration config for StridedSlice on the WebGL backend. */
    const stridedSliceConfig$1 = {
        kernelName: StridedSlice,
        backendName: 'webgl',
        kernelFunc: stridedSlice$2
    };
97501
97502 /**
97503 * @license
97504 * Copyright 2021 Google LLC. All Rights Reserved.
97505 * Licensed under the Apache License, Version 2.0 (the "License");
97506 * you may not use this file except in compliance with the License.
97507 * You may obtain a copy of the License at
97508 *
97509 * http://www.apache.org/licenses/LICENSE-2.0
97510 *
97511 * Unless required by applicable law or agreed to in writing, software
97512 * distributed under the License is distributed on an "AS IS" BASIS,
97513 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97514 * See the License for the specific language governing permissions and
97515 * limitations under the License.
97516 * =============================================================================
97517 */
97518 function stringNGrams$2(args) {
97519 const { inputs, backend, attrs } = args;
97520 const { separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences } = attrs;
97521 const { data, dataSplits } = inputs;
97522 const $data = backend.readSync(data.dataId);
97523 const $dataSplits = backend.readSync(dataSplits.dataId);
97524 const [nGrams, nGramsSplits] = stringNGramsImplCPU($data, $dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences);
97525 return [
97526 backend.makeTensorInfo([nGrams.length], 'string', nGrams),
97527 backend.makeTensorInfo(dataSplits.shape, 'int32', nGramsSplits),
97528 ];
97529 }
97530 const stringNGramsConfig$1 = {
97531 kernelName: StringNGrams,
97532 backendName: 'webgl',
97533 kernelFunc: stringNGrams$2,
97534 };
97535
97536 /**
97537 * @license
97538 * Copyright 2021 Google LLC. All Rights Reserved.
97539 * Licensed under the Apache License, Version 2.0 (the "License");
97540 * you may not use this file except in compliance with the License.
97541 * You may obtain a copy of the License at
97542 *
97543 * http://www.apache.org/licenses/LICENSE-2.0
97544 *
97545 * Unless required by applicable law or agreed to in writing, software
97546 * distributed under the License is distributed on an "AS IS" BASIS,
97547 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97548 * See the License for the specific language governing permissions and
97549 * limitations under the License.
97550 * =============================================================================
97551 */
97552 function stringSplit$2(args) {
97553 const { inputs, backend, attrs } = args;
97554 const { skipEmpty } = attrs;
97555 const { input, delimiter } = inputs;
97556 if (input.dtype !== 'string') {
97557 throw new Error('Input must be of datatype string');
97558 }
97559 if (input.shape.length !== 1) {
97560 throw new Error(`Input must be a vector, got shape: ${input.shape}`);
97561 }
97562 if (delimiter.shape.length !== 0) {
97563 throw new Error(`Delimiter must be a scalar, got shape: ${delimiter.shape}`);
97564 }
97565 const $input = backend.readSync(input.dataId);
97566 const $delimiter = backend.readSync(delimiter.dataId)[0];
97567 const [indices, values, shape] = stringSplitImplCPU($input, $delimiter, skipEmpty);
97568 const outputSize = values.length;
97569 return [
97570 backend.makeTensorInfo([outputSize, 2], 'int32', indices),
97571 backend.makeTensorInfo([outputSize], 'string', values),
97572 backend.makeTensorInfo([2], 'int32', new Int32Array(shape))
97573 ];
97574 }
97575 const stringSplitConfig$1 = {
97576 kernelName: StringSplit,
97577 backendName: 'webgl',
97578 kernelFunc: stringSplit$2,
97579 };
97580
97581 /**
97582 * @license
97583 * Copyright 2021 Google LLC. All Rights Reserved.
97584 * Licensed under the Apache License, Version 2.0 (the "License");
97585 * you may not use this file except in compliance with the License.
97586 * You may obtain a copy of the License at
97587 *
97588 * http://www.apache.org/licenses/LICENSE-2.0
97589 *
97590 * Unless required by applicable law or agreed to in writing, software
97591 * distributed under the License is distributed on an "AS IS" BASIS,
97592 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97593 * See the License for the specific language governing permissions and
97594 * limitations under the License.
97595 * =============================================================================
97596 */
97597 function stringToHashBucketFast$2(args) {
97598 const { inputs, backend, attrs } = args;
97599 const { numBuckets } = attrs;
97600 const { input } = inputs;
97601 if (input.dtype !== 'string') {
97602 throw new Error('Input must be of datatype string');
97603 }
97604 if (numBuckets <= 0) {
97605 throw new Error(`Number of buckets must be at least 1`);
97606 }
97607 const $input = backend.readSync(input.dataId);
97608 const output = stringToHashBucketFastImplCPU($input, numBuckets);
97609 return backend.makeTensorInfo(input.shape, 'int32', output);
97610 }
97611 const stringToHashBucketFastConfig$1 = {
97612 kernelName: StringToHashBucketFast,
97613 backendName: 'webgl',
97614 kernelFunc: stringToHashBucketFast$2,
97615 };
97616
97617 /**
97618 * @license
97619 * Copyright 2020 Google LLC. All Rights Reserved.
97620 * Licensed under the Apache License, Version 2.0 (the "License");
97621 * you may not use this file except in compliance with the License.
97622 * You may obtain a copy of the License at
97623 *
97624 * http://www.apache.org/licenses/LICENSE-2.0
97625 *
97626 * Unless required by applicable law or agreed to in writing, software
97627 * distributed under the License is distributed on an "AS IS" BASIS,
97628 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97629 * See the License for the specific language governing permissions and
97630 * limitations under the License.
97631 * =============================================================================
97632 */
    // GLSL snippet for elementwise tan; `x` is the current input element.
    const TAN = `return tan(x);`;
    // Elementwise Tan kernel built from the shared unary-kernel helper.
    const tan$2 = unaryKernelFunc$1({ opSnippet: TAN });
    /** Kernel registration config for Tan on the WebGL backend. */
    const tanConfig$1 = {
        kernelName: Tan,
        backendName: 'webgl',
        kernelFunc: tan$2,
    };
97640
97641 /**
97642 * @license
97643 * Copyright 2020 Google LLC. All Rights Reserved.
97644 * Licensed under the Apache License, Version 2.0 (the "License");
97645 * you may not use this file except in compliance with the License.
97646 * You may obtain a copy of the License at
97647 *
97648 * http://www.apache.org/licenses/LICENSE-2.0
97649 *
97650 * Unless required by applicable law or agreed to in writing, software
97651 * distributed under the License is distributed on an "AS IS" BASIS,
97652 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97653 * See the License for the specific language governing permissions and
97654 * limitations under the License.
97655 * =============================================================================
97656 */
    // GLSL snippet for elementwise tanh. Computed as
    // sign(x) * (1 - e^(-2|x|)) / (1 + e^(-2|x|)): using exp of a
    // non-positive argument keeps the exponential from overflowing for
    // large |x|.
    const TANH = `
  float e2x = exp(-2.0 * abs(x));
  return sign(x) * (1.0 - e2x) / (1.0 + e2x);
`;
    // Elementwise Tanh kernel built from the shared unary-kernel helper.
    const tanh$3 = unaryKernelFunc$1({ opSnippet: TANH });
    /** Kernel registration config for Tanh on the WebGL backend. */
    const tanhConfig$1 = {
        kernelName: Tanh,
        backendName: 'webgl',
        kernelFunc: tanh$3,
    };
97667
97668 /**
97669 * @license
97670 * Copyright 2017 Google LLC. All Rights Reserved.
97671 * Licensed under the Apache License, Version 2.0 (the "License");
97672 * you may not use this file except in compliance with the License.
97673 * You may obtain a copy of the License at
97674 *
97675 * http://www.apache.org/licenses/LICENSE-2.0
97676 *
97677 * Unless required by applicable law or agreed to in writing, software
97678 * distributed under the License is distributed on an "AS IS" BASIS,
97679 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97680 * See the License for the specific language governing permissions and
97681 * limitations under the License.
97682 * =============================================================================
97683 */
97684 class TileProgram {
97685 constructor(aShape, reps) {
97686 this.variableNames = ['A'];
97687 const outputShape = new Array(aShape.length);
97688 for (let i = 0; i < outputShape.length; i++) {
97689 outputShape[i] = aShape[i] * reps[i];
97690 }
97691 this.outputShape = outputShape;
97692 this.rank = outputShape.length;
97693 const dtype = getCoordsDataType(this.rank);
97694 const sourceCoords = getSourceCoords$2(aShape);
97695 this.userCode = `
97696 void main() {
97697 ${dtype} resRC = getOutputCoords();
97698 setOutput(getA(${sourceCoords}));
97699 }
97700 `;
97701 }
97702 }
97703 function getSourceCoords$2(aShape) {
97704 const rank = aShape.length;
97705 if (rank > 5) {
97706 throw Error(`Tile for rank ${rank} is not yet supported`);
97707 }
97708 if (rank === 1) {
97709 return `imod(resRC, ${aShape[0]})`;
97710 }
97711 const currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u'];
97712 const sourceCoords = [];
97713 for (let i = 0; i < aShape.length; i++) {
97714 sourceCoords.push(`imod(${currentCoords[i]}, ${aShape[i]})`);
97715 }
97716 return sourceCoords.join();
97717 }
97718
97719 /**
97720 * @license
97721 * Copyright 2020 Google LLC. All Rights Reserved.
97722 * Licensed under the Apache License, Version 2.0 (the "License");
97723 * you may not use this file except in compliance with the License.
97724 * You may obtain a copy of the License at
97725 *
97726 * http://www.apache.org/licenses/LICENSE-2.0
97727 *
97728 * Unless required by applicable law or agreed to in writing, software
97729 * distributed under the License is distributed on an "AS IS" BASIS,
97730 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97731 * See the License for the specific language governing permissions and
97732 * limitations under the License.
97733 * =============================================================================
97734 */
97735 function tile$3(params) {
97736 const { inputs, backend, attrs } = params;
97737 const { x } = inputs;
97738 const { reps } = attrs;
97739 // tile gpu program cannot handle rank > 5 case.
97740 if (x.dtype === 'string' || x.shape.length > 5) {
97741 // Even thought string tensor is always on CPU, just to be consistent on how
97742 // to access tensor data.
97743 const data = backend.readSync(x.dataId);
97744 const value = x.dtype === 'string' ?
97745 data.map(d => decodeString(d)) :
97746 data;
97747 const buf = buffer(x.shape, x.dtype, value);
97748 const outBuf = tileImplCPU(buf, reps);
97749 return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
97750 }
97751 const program = new TileProgram(x.shape, reps);
97752 const output = backend.runWebGLProgram(program, [x], x.dtype);
97753 return output;
97754 }
97755 const tileConfig$1 = {
97756 kernelName: Tile,
97757 backendName: 'webgl',
97758 kernelFunc: tile$3,
97759 };
97760
97761 // Based on Algorithm 2 of Bitonic Top K, ref:
97762 // https://anilshanbhag.in/static/papers/gputopk_sigmod18.pdf
97763 // The original algorithm is based on computing the top K only, however
97764 // since for TFJS we require the indices of the top K values as well then the
97765 // algorithm found here is a bit modified. Rather than producing the values
97766 // at each step, the indices containing the top K are generated instead.
97767 // The output values are not generated to reduce the number of outputs in the
97768 // GPU, the values can easily be retrieved from the indices using a gather
97769 // op.
    /**
     * Bitonic "swap" pass for the WebGL TopK implementation (Algorithm 2 of
     * the Bitonic Top K paper referenced above). Compares index pairs at
     * distance `inc` and writes the index order implied by direction `dir`;
     * only indices are produced — values are gathered at the end of topK.
     */
    class SwapProgram {
        /**
         * @param shape desired output shape (can be larger than input shape, output
         * will be padded with -Infinity)
         */
        constructor(shape) {
            this.variableNames = ['x', 'indices'];
            // |n| Size of the original input of TopK.
            // |firstPass|indicates if this is the first time swap is being used which
            // means no indices input containing the top K is present yet.
            // |inc| Swaps pairs of indices (0, inc), (1, inc + 1), (2, inc + 2) ...
            this.customUniforms = [
                { name: 'n', type: 'int' },
                { name: 'firstPass', type: 'int' },
                { name: 'negativeInf', type: 'float' },
                { name: 'dir', type: 'int' },
                { name: 'inc', type: 'int' }
            ];
            this.outputShape = shape;
            // Fragment shader: out-of-range indices (>= n, i.e. the pow2
            // padding) read as negativeInf so padding never wins a comparison.
            this.userCode = `
       void main() {
         ivec2 coords = getOutputCoords();
         int batch = coords[0];
         int elemIdx = coords[1];

         // We compare elements pair-wise within a group of size 2 * inc.
         // The comparing rule for each group alternates between ascending
         // and descending. Within each group, we compare each pair at
         // positions i and i+inc. To decide whether an element at position i
         // is x0 or x1, we mod it by 2 * inc, if the result is smaller than
         // inc, it is in the first half of the group, we denote it as x0,
         // otherwise we denote it as x1.
         // For example, as shown in the Bitonic top K paper referenced above,
         // Figure5(a) shows that element[1] is in the
         // second half of the group when group size is 2, but it is in the
         // first half of the group when group size is 4.

         bool isFirstInPair = imod(elemIdx, 2 * inc) < inc;
         int i = isFirstInPair ? elemIdx : elemIdx - inc;

         int i0 = firstPass == 1 ? i : int(getIndices(batch, i));
         int i1 = firstPass == 1 ? i + inc : int(getIndices(batch, i + inc));
         float x0 = i0 < n ? getX(batch, i0) : negativeInf;
         float x1 = i1 < n ? getX(batch, i1) : negativeInf;

         // Denotes which direction indices are in (ascending or descending).
         bool reverse = imod(elemIdx, 2 * dir) >= dir;
         bool isGreater = x0 > x1 || (x0 == x1 && i1 > i0);
         if (reverse == isGreater) { // Elements in opposite order of direction
           int iTemp = i0;
           i0 = i1;
           i1 = iTemp;
         }
         if (isFirstInPair) {
            setOutput(float(i0));
         } else {
            setOutput(float(i1));
         }
       }
     `;
        }
    }
    /**
     * Bitonic "merge" pass for the WebGL TopK implementation: halves the
     * candidate-index sequence by keeping, for each pair at distance `k`,
     * the index whose value is larger. Like SwapProgram, it emits indices
     * only; values are gathered once the final top-k indices are known.
     */
    class MergeProgram {
        /**
         * @param shape desired output shape (must be half of the input size)
         */
        constructor(shape) {
            this.variableNames = ['x', 'indices'];
            // |n| Size of the original input of TopK
            // |firstPass| indicates if this is the first time swap is being used which
            // means no indices input containing the top K is present yet.
            // |k| Top k elements desired
            this.customUniforms = [
                { name: 'n', type: 'int' },
                { name: 'firstPass', type: 'int' },
                { name: 'k', type: 'int' }
            ];
            this.outputShape = shape;
            // Fragment shader: i1 >= n (pow2 padding) falls back to x0 so the
            // padded half can never displace a real candidate.
            this.userCode = `
       void main() {
         // Takes max of indices (0, k), (1, k + 1), (2, k + 2) ...
         ivec2 coords = getOutputCoords();
         int batch = coords[0];
         int elemIdx = coords[1];

         // The output size is half of the previous size.
         // If the previous sequence is | | | | _ _ _ _  | | | | _ _ _ _ (k=4),
         // we only need to output the indices at positions |, the indices at
         // positions _ can be thrown away, see Figure5(b) After Phase 2
         // (Merge phase) in the Bitonic Top K paper referenced above.
         // For example, the paper shows we only need to output the orange bars.
         // The output sequence should look like this | | | | | | | |.
         // Because the sequence is halved, to map the output index back
         // to the previous sequence to find the corresponding value,
         // we need to double the index. When we double the index,
         // we basically interpolate a position, so 2i looks like
         // | _ | _ | _ | _ | _ | _ | _. We move the | to the first k position
         // of each 2k positions by - elemIdx % k. E.g. for output at
         // index 4,5,6,7, we want to get the corresponding element at
         // original index 8,9,10,11, for output at index 8,9,10,11,
         // we want to get the corresponding element at original index
         // 16,17,18,19, so on and so forth.

         int i = elemIdx < k ? elemIdx : (elemIdx * 2 - imod(elemIdx, k));
         int i0 = firstPass == 1 ? i : int(getIndices(batch, i));
         int i1 = firstPass == 1 ? i + k : int(getIndices(batch, i + k));

         float x0 = getX(batch, i0);
         float x1 = i1 < n ? getX(batch, i1) : x0;

         setOutput(x0 >= x1 ? float(i0) : float(i1));
       }
     `;
        }
    }
97885
97886 /**
97887 * @license
97888 * Copyright 2020 Google LLC. All Rights Reserved.
97889 * Licensed under the Apache License, Version 2.0 (the "License");
97890 * you may not use this file except in compliance with the License.
97891 * You may obtain a copy of the License at
97892 *
97893 * http://www.apache.org/licenses/LICENSE-2.0
97894 *
97895 * Unless required by applicable law or agreed to in writing, software
97896 * distributed under the License is distributed on an "AS IS" BASIS,
97897 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97898 * See the License for the specific language governing permissions and
97899 * limitations under the License.
97900 * =============================================================================
97901 */
97902 function disposeIntermediateTensorInfoOrNull(backend, tensorInfo) {
97903 if (tensorInfo !== null) {
97904 backend.disposeIntermediateTensorInfo(tensorInfo);
97905 }
97906 }
97907 function roundUpToPow2(num) {
97908 let pow2 = 1;
97909 while (pow2 < num) {
97910 pow2 *= 2;
97911 }
97912 return pow2;
97913 }
97914 // Based on Algorithm 2 of Bitonic Top K, ref:
97915 // https://anilshanbhag.in/static/papers/gputopk_sigmod18.pdf
    /**
     * WebGL TopK kernel, based on Algorithm 2 of the Bitonic Top K paper
     * (https://anilshanbhag.in/static/papers/gputopk_sigmod18.pdf).
     *
     * Falls back to the CPU implementation for small last dims / large k
     * (per env thresholds) or when the data is already CPU-resident.
     * Otherwise runs bitonic local sort + merge passes on the GPU, tracking
     * only indices at each step, then gathers values at the end.
     *
     * @param args.inputs.x Input tensor; topk runs along the last dimension.
     * @param args.attrs    {k, sorted}.
     * @returns [values, indices], each shaped like x but with last dim k.
     */
    function topK$1(args) {
        const { inputs, backend, attrs } = args;
        const { x } = inputs;
        const { k, sorted } = attrs;
        // Empirically determined constant used to determine last dim threshold for
        // handing off execution to the CPU.
        const TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD = env().getNumber('TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD');
        // Empirically determined constant used to determine k threshold for handing
        // off execution to the CPU.
        const TOPK_K_CPU_HANDOFF_THRESHOLD = env().getNumber('TOPK_K_CPU_HANDOFF_THRESHOLD');
        // NOTE(review): xShape aliases x.shape (no copy) and is mutated in the
        // k === 0 branch below — presumably acceptable upstream; verify.
        const xShape = x.shape;
        const lastDim = xShape[xShape.length - 1];
        if (backend.shouldExecuteOnCPU([x]) ||
            lastDim < TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD ||
            k > TOPK_K_CPU_HANDOFF_THRESHOLD) {
            // CPU handoff: read values back and use the shared CPU topk.
            const xVals = backend.readSync(x.dataId);
            const [allTopKVals, allTopKIndices] = topKImplCPU(xVals, xShape, x.dtype, k, sorted);
            return [
                backend.makeTensorInfo(allTopKVals.shape, allTopKVals.dtype, allTopKVals.values),
                backend.makeTensorInfo(allTopKIndices.shape, allTopKIndices.dtype, allTopKIndices.values)
            ];
        }
        if (k === 0) {
            // k = 0: empty outputs with last dim zeroed out.
            xShape[xShape.length - 1] = 0;
            return [
                backend.makeTensorInfo(xShape, x.dtype, []),
                backend.makeTensorInfo(xShape, 'int32', [])
            ];
        }
        if (lastDim === 1 /* firstPass */) {
            // Single-element last dim: values are x itself, indices all zero.
            return [
                x, fill$2({ attrs: { shape: xShape, dtype: 'int32', value: 0 }, backend })
            ];
        }
        // Eagerly unpack x input since it is passed in to all the shaders which
        // require unpacked inputs.
        const xtexData = backend.texData.get(x.dataId);
        const xIsPacked = xtexData !== null && xtexData.isPacked;
        const xUnPacked = xIsPacked ? backend.unpackTensor(x) : x;
        // Reshape into a 2d tensor [batch, lastDim] and compute topk along lastDim.
        const xSize = sizeFromShape(xShape);
        const batch = xSize / lastDim;
        const x2D = reshape$3({ inputs: { x: xUnPacked }, attrs: { shape: [batch, lastDim] }, backend });
        if (xIsPacked) {
            disposeIntermediateTensorInfoOrNull(backend, xUnPacked);
        }
        // Bitonic passes need power-of-two sizes; shaders pad with -Infinity.
        const kPow2 = roundUpToPow2(k);
        const lastDimPow2 = roundUpToPow2(lastDim);
        // Only the indices containing the top K are kept at every step to reduce
        // number of outputs in the GPU algorithms, so once the final set of indices
        // is computed then gather is used to grab the corresponding values
        // from the original input.
        let indices = null;
        // GPU algorithm always takes in an indices input but this input is not used
        // on the first run of a GPU algorithm, therefore if indices is null we simply
        // pass in x2D instead of it but the value will not actually be used
        const getInputs = () => indices === null ? [x2D, x2D] : [x2D, indices];
        // Runs one SwapProgram pass and swaps it in as the current indices,
        // disposing the previous indices tensor.
        const runSwap = (dir, inc, shape) => {
            const inputs = getInputs();
            const program = new SwapProgram(shape);
            // NOTE(review): "fistPass" is a typo for "firstPass" (local only,
            // harmless).
            const fistPass = indices === null ? 1 : 0;
            const customValues = [[lastDim], [fistPass], [Number.NEGATIVE_INFINITY], [dir], [inc]];
            const prevIndices = indices;
            indices = backend.runWebGLProgram(program, inputs, 'int32', customValues);
            disposeIntermediateTensorInfoOrNull(backend, prevIndices);
        };
        // Step 1: local sort
        for (let len = 1; len < kPow2; len *= 2) {
            const dir = len * 2;
            for (let inc = len; inc >= 1; inc /= 2) {
                runSwap(dir, inc, [batch, lastDimPow2]);
            }
        }
        // Step 2: merge
        for (let indicesSize = lastDimPow2; indicesSize > kPow2; indicesSize /= 2) {
            const inputs = getInputs();
            const mergeProgram = new MergeProgram([batch, indicesSize / 2]);
            const firstPass = indices === null ? 1 : 0;
            const customValues = [[lastDim], [firstPass], [kPow2]];
            const prevIndices = indices;
            indices =
                backend.runWebGLProgram(mergeProgram, inputs, 'int32', customValues);
            disposeIntermediateTensorInfoOrNull(backend, prevIndices);
            // Step 3: rebuild
            const len = kPow2 / 2;
            const dir = len * 2;
            for (let inc = len; inc >= 1; inc /= 2) {
                runSwap(dir, inc, indices.shape);
            }
        }
        // Keep only the requested top K results instead of kPow2
        let prevIndices = indices;
        indices = slice$2({ inputs: { x: indices }, backend, attrs: { begin: 0, size: [batch, k] } });
        disposeIntermediateTensorInfoOrNull(backend, prevIndices);
        // Gather values on last dimension
        let values = gatherV2$1({ inputs: { x: x2D, indices }, backend, attrs: { axis: 1, batchDims: 1 } });
        disposeIntermediateTensorInfoOrNull(backend, x2D);
        // Reshape back to the original input shape, except that the last
        // dimension is k.
        const newShape = xShape.slice(0, -1);
        newShape.push(k);
        prevIndices = indices;
        indices = reshape$3({ inputs: { x: indices }, attrs: { shape: newShape }, backend });
        disposeIntermediateTensorInfoOrNull(backend, prevIndices);
        const prevValues = values;
        values = reshape$3({ inputs: { x: values }, attrs: { shape: newShape }, backend });
        disposeIntermediateTensorInfoOrNull(backend, prevValues);
        return [values, indices];
    }
    /** Kernel registration config for TopK on the WebGL backend. */
    const topKConfig$1 = {
        kernelName: TopK,
        backendName: 'webgl',
        kernelFunc: topK$1
    };
98030
98031 /**
98032 * @license
98033 * Copyright 2021 Google LLC. All Rights Reserved.
98034 * Licensed under the Apache License, Version 2.0 (the "License");
98035 * you may not use this file except in compliance with the License.
98036 * You may obtain a copy of the License at
98037 *
98038 * http://www.apache.org/licenses/LICENSE-2.0
98039 *
98040 * Unless required by applicable law or agreed to in writing, software
98041 * distributed under the License is distributed on an "AS IS" BASIS,
98042 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98043 * See the License for the specific language governing permissions and
98044 * limitations under the License.
98045 * =============================================================================
98046 */
    /**
     * WebGL program implementing image projective Transform: applies a
     * per-batch 8-parameter projective transform to each output pixel,
     * with configurable interpolation ('nearest' or bilinear) and fill mode.
     *
     * Interpolation/fill enums are baked into the shader as integer ids:
     * interpolation 1 = nearest, 2 = bilinear;
     * fill 1 = constant, 2 = reflect, 3 = wrap, 4 = nearest (unknown modes
     * default to constant).
     */
    class TransformProgram {
        constructor(imageHeight, imageWidth, interpolation, fillMode, fillValue, outShape) {
            this.variableNames = ['Image', 'Transforms'];
            this.outputShape = outShape;
            const interpolationModeId = interpolation === 'nearest' ? 1 : 2;
            let fillModeId;
            switch (fillMode) {
                case 'constant':
                    fillModeId = 1;
                    break;
                case 'reflect':
                    fillModeId = 2;
                    break;
                case 'wrap':
                    fillModeId = 3;
                    break;
                case 'nearest':
                    fillModeId = 4;
                    break;
                default:
                    // Unknown fill modes fall back to 'constant'.
                    fillModeId = 1;
                    break;
            }
            // mapCoord applies the fill-mode boundary rule to a single input
            // coordinate; readWithFillValue substitutes fillValue for reads
            // outside the image bounds (used by the 'constant' path).
            this.userCode = `
            float mapCoord(float outCoord, float len) {
              float inCoord = outCoord;
              if(${fillModeId} == 2) {
                if (inCoord < 0.0) {
                  if (len <= 1.0) {
                    inCoord = 0.0;
                  } else {
                    float sz2 = 2.0 * len;
                    if (inCoord < sz2) {
                      inCoord = sz2 * float(int(float(-inCoord / sz2))) +
                      inCoord;
                    }
                    inCoord = inCoord < -len ? inCoord + sz2 : -inCoord - 1.0;
                  }
                } else if (inCoord > len - 1.0) {
                  if (len <= 1.0) {
                    inCoord = 0.0;
                  } else {
                    float sz2 = 2.0 * len;
                    inCoord -= sz2 * float(int(float(inCoord / sz2)));
                    if (inCoord >= len) {
                      inCoord = sz2 - inCoord - 1.0;
                    }
                  }
                }
                return clamp(inCoord, 0.0, len - 1.0);
              } else if (${fillModeId} == 3) {
                if (inCoord < 0.0) {
                  if (len <= 1.0) {
                    inCoord = 0.0;
                  } else {
                    float sz = len - 1.0;
                    inCoord += len * (float(int(float(-inCoord / sz))) + 1.0);
                  }
                } else if (inCoord > len - 1.0) {
                  if (len <= 1.0) {
                    inCoord = 0.0;
                  } else {
                    float sz = len - 1.0;
                    inCoord -= len * float(int(float(inCoord / sz)));
                  }
                }
                return clamp(inCoord, 0.0, len - 1.0);
              } else if (${fillModeId} == 4) {
                return clamp(outCoord, 0.0, len - 1.0);
              } else {
                return outCoord;
              }
            }

            float readWithFillValue(int batch, int coordY, int coordX,
              int channel) {
              float outputValue;
              if (0 <= coordY && coordY < ${imageHeight} && 0 <= coordX && coordX < ${imageWidth}) {
                outputValue = getImage(batch, coordY, coordX, channel);
              } else {
                outputValue = float(${fillValue});
              }
              return outputValue;
            }

            void main() {
              ivec4 coords = getOutputCoords();
              float outputValue;
              int batch = coords[0];
              int x = coords[2];
              int y = coords[1];
              int channel = coords[3];
              float xf = float(x);
              float yf = float(y);
              float a1 = getTransforms(batch, 0);
              float a2 = getTransforms(batch, 1);
              float a3 = getTransforms(batch, 2);
              float b1 = getTransforms(batch, 3);
              float b2 = getTransforms(batch, 4);
              float b3 = getTransforms(batch, 5);
              float c1 = getTransforms(batch, 6);
              float c2 = getTransforms(batch, 7);
              float projection = c1 * xf + c2 * yf + 1.0;
              if (projection == 0.0) {
                outputValue = float(${fillValue});
              } else {
                float inX = (a1 * xf + a2 * yf + a3) / projection;
                float inY = (b1 * xf + b2 * yf + b3) / projection;
                float mapX = mapCoord(inX, float(${imageWidth}));
                float mapY = mapCoord(inY, float(${imageHeight}));

                if (${interpolationModeId} == 1) {
                  int coordY = int(round(mapY));
                  int coordX = int(round(mapX));
                  outputValue = readWithFillValue(batch, coordY, coordX,
                    channel);
                } else {
                  float yFloor = floor(mapY);
                  float xFloor = floor(mapX);
                  float yCeil = yFloor + 1.0;
                  float xCeil = xFloor + 1.0;
                  float valueYFloor = (xCeil - mapX) *
                  readWithFillValue(batch, int(yFloor), int(xFloor), channel) +
                  (mapX - xFloor) *
                  readWithFillValue(batch, int(yFloor), int(xCeil), channel);
                  float valueYCeil = (xCeil - mapX) *
                  readWithFillValue(batch, int(yCeil), int(xFloor), channel) +
                  (mapX - xFloor) *
                  readWithFillValue(batch, int(yCeil), int(xCeil), channel);
                  outputValue = (yCeil - mapY) * valueYFloor +
                  (mapY - yFloor) * valueYCeil;
                }
              }
              setOutput(outputValue);
            }
        `;
        }
    }
98185
98186 /**
98187 * @license
98188 * Copyright 2021 Google LLC. All Rights Reserved.
98189 * Licensed under the Apache License, Version 2.0 (the "License");
98190 * you may not use this file except in compliance with the License.
98191 * You may obtain a copy of the License at
98192 *
98193 * http://www.apache.org/licenses/LICENSE-2.0
98194 *
98195 * Unless required by applicable law or agreed to in writing, software
98196 * distributed under the License is distributed on an "AS IS" BASIS,
98197 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98198 * See the License for the specific language governing permissions and
98199 * limitations under the License.
98200 * =============================================================================
98201 */
98202 function transform$2(args) {
98203 const { inputs, backend, attrs } = args;
98204 const { image, transforms } = inputs;
98205 const { interpolation, fillMode, fillValue, outputShape } = attrs;
98206 const [batch, imageHeight, imageWidth, numChannels] = image.shape;
98207 const [outHeight, outWidth] = outputShape != null ? outputShape : [imageHeight, imageWidth];
98208 const outShape = [batch, outHeight, outWidth,
98209 numChannels];
98210 const program = new TransformProgram(imageHeight, imageWidth, interpolation, fillMode, fillValue, outShape);
98211 return backend.runWebGLProgram(program, [image, transforms], 'float32');
98212 }
    /** Registration config binding the Transform kernel to the WebGL backend. */
    const transformConfig$1 = {
        kernelName: Transform,
        backendName: 'webgl',
        kernelFunc: transform$2
    };
98218
98219 /**
98220 * @license
98221 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
98223 * you may not use this file except in compliance with the License.
98224 * You may obtain a copy of the License at
98225 *
98226 * http://www.apache.org/licenses/LICENSE-2.0
98227 *
98228 * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
98230 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98231 * See the License for the specific language governing permissions and
98232 * limitations under the License.
98233 * =============================================================================
98234 */
98235 function unique$3(args) {
98236 const { inputs, attrs, backend } = args;
98237 const { axis } = attrs;
98238 const { x } = inputs;
98239 assertNotComplex$1(x, 'unique');
98240 // For now, always forward calculation to the CPU backend.
98241 console.warn('WARNING: ', 'UI might be locked temporarily as data is being downloaded');
98242 const values = backend.readSync(x.dataId);
98243 const { outputValues, outputShape, indices } = uniqueImplCPU(values, axis, x.shape, x.dtype);
98244 return [
98245 backend.makeTensorInfo(outputShape, x.dtype, outputValues),
98246 backend.makeTensorInfo([indices.length], 'int32', indices),
98247 ];
98248 }
    /** Registration config binding the Unique kernel to the WebGL backend. */
    const uniqueConfig$1 = {
        kernelName: Unique,
        backendName: 'webgl',
        kernelFunc: unique$3,
    };
98254
98255 /**
98256 * @license
98257 * Copyright 2020 Google LLC. All Rights Reserved.
98258 * Licensed under the Apache License, Version 2.0 (the "License");
98259 * you may not use this file except in compliance with the License.
98260 * You may obtain a copy of the License at
98261 *
98262 * http://www.apache.org/licenses/LICENSE-2.0
98263 *
98264 * Unless required by applicable law or agreed to in writing, software
98265 * distributed under the License is distributed on an "AS IS" BASIS,
98266 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98267 * See the License for the specific language governing permissions and
98268 * limitations under the License.
98269 * =============================================================================
98270 */
98271 function unpack$1(args) {
98272 const { inputs, backend, attrs } = args;
98273 const { value } = inputs;
98274 let { axis } = attrs;
98275 if (axis < 0) {
98276 axis += value.shape.length;
98277 }
98278 const x = value;
98279 const xRank = x.shape.length;
98280 const num = value.shape[axis];
98281 const outShape = new Array(xRank - 1);
98282 let outIndex = 0;
98283 for (let i = 0; i < xRank; i++) {
98284 if (i !== axis) {
98285 outShape[outIndex++] = x.shape[i];
98286 }
98287 }
98288 const toDispose = [];
98289 const begin = new Array(xRank).fill(0);
98290 const size = x.shape.slice();
98291 size[axis] = 1;
98292 const res = new Array(num);
98293 for (let i = 0; i < res.length; i++) {
98294 begin[axis] = i;
98295 const sliced = slice$2({ inputs: { x }, backend, attrs: { begin, size } });
98296 const reshaped = reshape$3({ inputs: { x: sliced }, backend, attrs: { shape: outShape } });
98297 res[i] = reshaped;
98298 toDispose.push(sliced);
98299 }
98300 toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
98301 return res;
98302 }
    /** Registration config binding the Unpack kernel to the WebGL backend. */
    const unpackConfig$1 = {
        kernelName: Unpack,
        backendName: 'webgl',
        kernelFunc: unpack$1
    };
98308
98309 /**
98310 * @license
98311 * Copyright 2018 Google LLC. All Rights Reserved.
98312 * Licensed under the Apache License, Version 2.0 (the "License");
98313 * you may not use this file except in compliance with the License.
98314 * You may obtain a copy of the License at
98315 *
98316 * http://www.apache.org/licenses/LICENSE-2.0
98317 *
98318 * Unless required by applicable law or agreed to in writing, software
98319 * distributed under the License is distributed on an "AS IS" BASIS,
98320 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98321 * See the License for the specific language governing permissions and
98322 * limitations under the License.
98323 * =============================================================================
98324 */
    /**
     * GPGPU program that performs one pass of a windowed segment reduction.
     *
     * The input is viewed as a [batchSize, inSize] matrix; each output cell
     * sums a window of `windowSize` input values whose segment id matches the
     * cell's segment, producing a [batchSize, numSegments * ceil(inSize /
     * windowSize)] output. Callers re-run the program until only `numSegments`
     * columns remain (see unsortedSegmentSum$2).
     *
     * NOTE(review): `segOpType` is accepted but never referenced in this body;
     * the generated shader always computes a sum — confirm against callers.
     */
    class SegmentOpProgram {
        constructor(segOpInfo, segOpType) {
            this.variableNames = ['x', 'segmentIds'];
            const windowSize = segOpInfo.windowSize;
            const batchSize = segOpInfo.batchSize;
            const inSize = segOpInfo.inSize;
            const numSegments = segOpInfo.numSegments;
            // One partial result per (segment, window) pair.
            const outSize = numSegments * Math.ceil(inSize / windowSize);
            this.outputShape = [batchSize, outSize];
            const initializationValue = '0.0';
            const returnValue = `sumValue`;
            // The window is consumed four values at a time via vec4 dot products;
            // the remainder (0-3 values) is handled by the trailing branches.
            const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
            const windowSizeVec4Remainder = windowSize % 4;
            const updateSnippet = `
        sumValue += dot(values, segFilter);
      `;
            // Out-of-bounds guards are only emitted when the last window is
            // partial (inSize not divisible by windowSize).
            let checkValueOutOfBounds = '';
            if (inSize % windowSize > 0) {
                checkValueOutOfBounds = `
          if (inIdx < 0 || inIdx >= ${inSize}) {
            return initializationValue;
          }
        `;
            }
            let checkSegmentIdOutOfBounds = '';
            if (inSize % windowSize > 0) {
                checkSegmentIdOutOfBounds = `
          if (inIdx < 0 || inIdx >= ${inSize}) {
            return -1.0;
          }
        `;
            }
            this.userCode = `
      const float initializationValue = ${initializationValue};

      float getValue(int batch, int inIdx) {
        ${checkValueOutOfBounds}
        return getX(batch, inIdx);
      }

      float getSegmentIdAtIndex(int inIdx) {
        ${checkSegmentIdOutOfBounds}
        return getSegmentIds(inIdx);
      }

      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];
        int outIdx = coords[1];
        int inOffset = int(floor(float(outIdx) / float(
          ${numSegments})) * float(${windowSize}));
        int currentSeg = int(mod(float(outIdx), float(${numSegments})));

        float sumValue = 0.0;

        for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) {
          int inIdx = inOffset + i;
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            getValue(batch, inIdx + 3)
          );

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 3)) == currentSeg ? 1 : 0
          );

          ${updateSnippet}
        }

        int inIdx = inOffset + ${windowSizeNearestVec4};
        if (${windowSizeVec4Remainder === 1}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            initializationValue,
            initializationValue,
            initializationValue
          );

          int inIdxSeg = int(getSegmentIdAtIndex(inIdx));

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            0,
            0,
            0
          );

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 2}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            initializationValue,
            initializationValue
          );

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
            0,
            0
          );

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 3}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            initializationValue
          );

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,
            0
          );

          ${updateSnippet}
        }
        setOutput(${returnValue});
      }
    `;
        }
    }
98456
98457 /**
98458 * @license
98459 * Copyright 2020 Google LLC. All Rights Reserved.
98460 * Licensed under the Apache License, Version 2.0 (the "License");
98461 * you may not use this file except in compliance with the License.
98462 * You may obtain a copy of the License at
98463 *
98464 * http://www.apache.org/licenses/LICENSE-2.0
98465 *
98466 * Unless required by applicable law or agreed to in writing, software
98467 * distributed under the License is distributed on an "AS IS" BASIS,
98468 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98469 * See the License for the specific language governing permissions and
98470 * limitations under the License.
98471 * =============================================================================
98472 */
    /**
     * WebGL kernel for UnsortedSegmentSum: sums slices of `x` into
     * `numSegments` buckets keyed by `segmentIds` along axis 0.
     *
     * The reduction runs as repeated windowed passes on the GPU
     * (SegmentOpProgram), recursing until the segment dimension collapses to
     * exactly `numSegments`. Intermediates are collected in `toDispose` and
     * released before returning.
     */
    function unsortedSegmentSum$2(args) {
        const { inputs, backend, attrs } = args;
        const { x, segmentIds } = inputs;
        const { numSegments } = attrs;
        const xRank = x.shape.length;
        const toDispose = [];
        let axis = 0;
        // Move the segment axis (0) to the innermost position if needed, so
        // the data can then be viewed as a [batch, inSize] matrix.
        const permutation = getAxesPermutation([axis], xRank);
        let permutedX = x;
        if (permutation != null) {
            permutedX = transpose$2({ inputs: { x }, backend, attrs: { perm: permutation } });
            toDispose.push(permutedX);
            axis = getInnerMostAxes(1, xRank)[0];
        }
        const outShape = computeOutShape$2(permutedX.shape, axis, numSegments);
        const inSize = sizeFromShape([permutedX.shape[axis]]);
        // Flatten to 2D: segment axis as columns, everything else as rows.
        const a2D = reshape$3({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } });
        toDispose.push(a2D);
        const outputDType = sumOutType(x.dtype);
        // One windowed reduction pass; recurses until only `numSegments`
        // columns remain.
        const segOpCompute = (x, segOpType, segmentIds, dtype, numSegments) => {
            const batchSize = x.shape[0];
            const inSize = x.shape[1];
            const windowSize = segOpComputeOptimalWindowSize(inSize, numSegments);
            const segOpInfo = { windowSize, inSize, batchSize, numSegments };
            const program = new SegmentOpProgram(segOpInfo, segOpType);
            const output = backend.compileAndRun(program, [x, segmentIds], dtype);
            toDispose.push(output);
            // No need to run another GPGPU program.
            if (output.shape[1] === numSegments) {
                return output;
            }
            // Build segment ids [0, numSegments) tiled across the remaining
            // windows to drive the next reduction pass.
            const rangeInfo = range$3({
                backend,
                attrs: { start: 0, stop: numSegments, step: 1, dtype: 'float32' }
            });
            const tileInfo = tile$3({
                inputs: { x: rangeInfo },
                backend,
                attrs: { reps: [inSize / windowSize] }
            });
            toDispose.push(rangeInfo);
            toDispose.push(tileInfo);
            const result = segOpCompute(output, segOpType, tileInfo, dtype, numSegments);
            return result;
        };
        const segOpResult = segOpCompute(a2D, 'unsortedSegmentSum', segmentIds, outputDType, numSegments);
        const reshaped = reshape$3({ inputs: { x: segOpResult }, backend, attrs: { shape: outShape } });
        let result = reshaped;
        // Undo the initial transpose, if one was applied.
        if (permutation != null) {
            toDispose.push(reshaped);
            const perm = getUndoAxesPermutation(permutation);
            result = transpose$2({ inputs: { x: result }, backend, attrs: { perm } });
        }
        toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
        return result;
    }
    /** Registration config binding UnsortedSegmentSum to the WebGL backend. */
    const unsortedSegmentSumConfig$1 = {
        kernelName: UnsortedSegmentSum,
        backendName: 'webgl',
        kernelFunc: unsortedSegmentSum$2
    };
98534
98535 /**
98536 * @license
98537 * Copyright 2020 Google LLC. All Rights Reserved.
98538 * Licensed under the Apache License, Version 2.0 (the "License");
98539 * you may not use this file except in compliance with the License.
98540 * You may obtain a copy of the License at
98541 *
98542 * http://www.apache.org/licenses/LICENSE-2.0
98543 *
98544 * Unless required by applicable law or agreed to in writing, software
98545 * distributed under the License is distributed on an "AS IS" BASIS,
98546 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98547 * See the License for the specific language governing permissions and
98548 * limitations under the License.
98549 * =============================================================================
98550 */
    // List all kernel configs here. Each entry is a KernelConfig registered
    // for the 'webgl' backend below; this table is generated, keep it sorted.
    const kernelConfigs$1 = [
        _fusedMatMulConfig$1,
        absConfig$1,
        acosConfig$1,
        acoshConfig$1,
        addConfig$1,
        addNConfig$1,
        allConfig$1,
        anyConfig$1,
        argMaxConfig$1,
        argMinConfig$1,
        asinConfig$1,
        asinhConfig$1,
        atanConfig$1,
        atan2Config$1,
        atanhConfig$1,
        avgPoolConfig$1,
        avgPool3DConfig$1,
        avgPool3DGradConfig$2,
        avgPoolGradConfig$2,
        batchMatMulConfig$1,
        batchNormConfig$1,
        batchToSpaceNDConfig$1,
        bincountConfig$1,
        broadcastArgsConfig$1,
        castConfig$1,
        ceilConfig$1,
        clipByValueConfig$1,
        complexConfig$1,
        complexAbsConfig$1,
        concatConfig$1,
        conv2DConfig$1,
        conv2DBackpropFilterConfig$1,
        conv2DBackpropInputConfig$1,
        conv3DConfig$1,
        conv3DBackpropFilterV2Config$1,
        conv3DBackpropInputConfig,
        cosConfig$1,
        coshConfig$1,
        cropAndResizeConfig$1,
        cumprodConfig$1,
        cumsumConfig$1,
        denseBincountConfig$1,
        depthToSpaceConfig$1,
        depthwiseConv2dNativeConfig$1,
        depthwiseConv2dNativeBackpropFilterConfig$1,
        depthwiseConv2dNativeBackpropInputConfig$1,
        diagConfig$1,
        dilation2DConfig$1,
        einsumConfig$1,
        eluConfig$1,
        eluGradConfig$2,
        equalConfig$1,
        erfConfig$1,
        expConfig$1,
        expandDimsConfig$1,
        expm1Config$1,
        fftConfig$1,
        fillConfig$1,
        flipLeftRightConfig$1,
        floorConfig$1,
        floorDivConfig$1,
        fromPixelsConfig,
        fusedConv2DConfig$1,
        fusedDepthwiseConv2DConfig$1,
        gatherNdConfig$1,
        gatherV2Config$1,
        greaterConfig$1,
        greaterEqualConfig$1,
        identityConfig$1,
        ifftConfig$1,
        imagConfig$1,
        isFiniteConfig$1,
        isInfConfig$1,
        isNaNConfig$1,
        leakyReluConfig$1,
        lessConfig$1,
        lessEqualConfig$1,
        linSpaceConfig$1,
        logConfig$1,
        log1pConfig$1,
        logicalAndConfig$1,
        logicalNotConfig$1,
        logicalOrConfig$1,
        LRNConfig$1,
        LRNGradConfig$1,
        maxConfig$1,
        maximumConfig$1,
        maxPoolConfig$1,
        maxPool3DConfig$1,
        maxPool3DGradConfig$2,
        maxPoolGradConfig$2,
        maxPoolWithArgmaxConfig$1,
        meanConfig$1,
        minConfig$1,
        minimumConfig$1,
        mirrorPadConfig$1,
        modConfig$1,
        multinomialConfig$1,
        multiplyConfig$1,
        negConfig$1,
        nonMaxSuppressionV3Config$1,
        nonMaxSuppressionV4Config$1,
        nonMaxSuppressionV5Config$1,
        notEqualConfig$1,
        oneHotConfig$1,
        onesLikeConfig$1,
        packConfig$1,
        padV2Config$1,
        powConfig$1,
        preluConfig$1,
        prodConfig$1,
        rangeConfig$1,
        realConfig$1,
        realDivConfig$1,
        reciprocalConfig$1,
        reluConfig$1,
        relu6Config$1,
        reshapeConfig$1,
        resizeBilinearConfig$1,
        resizeBilinearGradConfig$2,
        resizeNearestNeighborConfig$1,
        resizeNearestNeighborGradConfig$2,
        reverseConfig$1,
        rotateWithOffsetConfig$1,
        roundConfig$1,
        rsqrtConfig$1,
        scatterNdConfig$1,
        searchSortedConfig$1,
        selectConfig$1,
        seluConfig$1,
        sigmoidConfig$1,
        signConfig$1,
        sinConfig$1,
        sinhConfig$1,
        sliceConfig$1,
        softmaxConfig$1,
        softplusConfig$1,
        spaceToBatchNDConfig$1,
        sparseFillEmptyRowsConfig$1,
        sparseReshapeConfig$1,
        sparseSegmentMeanConfig$1,
        sparseSegmentSumConfig$1,
        sparseToDenseConfig$1,
        splitVConfig$1,
        sqrtConfig$1,
        squareConfig$1,
        squaredDifferenceConfig$1,
        stepConfig$1,
        stridedSliceConfig$1,
        stringNGramsConfig$1,
        stringSplitConfig$1,
        stringToHashBucketFastConfig$1,
        subConfig$1,
        sumConfig$1,
        tanConfig$1,
        tanhConfig$1,
        tileConfig$1,
        topKConfig$1,
        transformConfig$1,
        transposeConfig$1,
        uniqueConfig$1,
        unpackConfig$1,
        unsortedSegmentSumConfig$1,
        zerosLikeConfig$1
    ];
    // Register every WebGL kernel with the global kernel registry.
    for (const kernelConfig of kernelConfigs$1) {
        registerKernel(kernelConfig);
    }
98721
98722 /**
98723 * @license
98724 * Copyright 2020 Google LLC. All Rights Reserved.
98725 * Licensed under the Apache License, Version 2.0 (the "License");
98726 * you may not use this file except in compliance with the License.
98727 * You may obtain a copy of the License at
98728 *
98729 * http://www.apache.org/licenses/LICENSE-2.0
98730 *
98731 * Unless required by applicable law or agreed to in writing, software
98732 * distributed under the License is distributed on an "AS IS" BASIS,
98733 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98734 * See the License for the specific language governing permissions and
98735 * limitations under the License.
98736 * =============================================================================
98737 */
98738
98739 /** @license See the LICENSE file. */
98740 // This code is auto-generated, do not modify this file!
    // Version string of the tfjs union package (auto-generated at build time).
    const version$6 = '3.18.0';
98742
98743 /**
98744 * @license
98745 * Copyright 2018 Google LLC. All Rights Reserved.
98746 * Licensed under the Apache License, Version 2.0 (the "License");
98747 * you may not use this file except in compliance with the License.
98748 * You may obtain a copy of the License at
98749 *
98750 * http://www.apache.org/licenses/LICENSE-2.0
98751 *
98752 * Unless required by applicable law or agreed to in writing, software
98753 * distributed under the License is distributed on an "AS IS" BASIS,
98754 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98755 * See the License for the specific language governing permissions and
98756 * limitations under the License.
98757 * =============================================================================
98758 */
    /** Map of constituent tfjs package names to their bundled version strings. */
    const version$7 = {
        'tfjs-core': version,
        'tfjs-backend-cpu': version$4,
        'tfjs-backend-webgl': version$5,
        'tfjs-data': version$3,
        'tfjs-layers': version$1,
        'tfjs-converter': version$2,
        'tfjs': version$6
    };
98768
98769 exports.Abs = Abs;
98770 exports.Acos = Acos;
98771 exports.Acosh = Acosh;
98772 exports.AdadeltaOptimizer = AdadeltaOptimizer;
98773 exports.AdagradOptimizer = AdagradOptimizer;
98774 exports.AdamOptimizer = AdamOptimizer;
98775 exports.AdamaxOptimizer = AdamaxOptimizer;
98776 exports.Add = Add;
98777 exports.AddN = AddN;
98778 exports.All = All;
98779 exports.Any = Any;
98780 exports.ArgMax = ArgMax;
98781 exports.ArgMin = ArgMin;
98782 exports.Asin = Asin;
98783 exports.Asinh = Asinh;
98784 exports.Atan = Atan;
98785 exports.Atan2 = Atan2;
98786 exports.Atanh = Atanh;
98787 exports.AvgPool = AvgPool;
98788 exports.AvgPool3D = AvgPool3D;
98789 exports.AvgPool3DGrad = AvgPool3DGrad;
98790 exports.AvgPoolGrad = AvgPoolGrad;
98791 exports.BatchMatMul = BatchMatMul;
98792 exports.BatchToSpaceND = BatchToSpaceND;
98793 exports.Bincount = Bincount;
98794 exports.BroadcastArgs = BroadcastArgs;
98795 exports.BroadcastTo = BroadcastTo;
98796 exports.Callback = Callback;
98797 exports.CallbackList = CallbackList;
98798 exports.Cast = Cast;
98799 exports.Ceil = Ceil;
98800 exports.ClipByValue = ClipByValue;
98801 exports.Complex = Complex;
98802 exports.ComplexAbs = ComplexAbs;
98803 exports.Concat = Concat;
98804 exports.Conv2D = Conv2D;
98805 exports.Conv2DBackpropFilter = Conv2DBackpropFilter;
98806 exports.Conv2DBackpropInput = Conv2DBackpropInput;
98807 exports.Conv3D = Conv3D;
98808 exports.Conv3DBackpropFilterV2 = Conv3DBackpropFilterV2;
98809 exports.Conv3DBackpropInputV2 = Conv3DBackpropInputV2;
98810 exports.Cos = Cos;
98811 exports.Cosh = Cosh;
98812 exports.CropAndResize = CropAndResize;
98813 exports.Cumprod = Cumprod;
98814 exports.Cumsum = Cumsum;
98815 exports.CustomCallback = CustomCallback;
98816 exports.DataStorage = DataStorage;
98817 exports.DenseBincount = DenseBincount;
98818 exports.DepthToSpace = DepthToSpace;
98819 exports.DepthwiseConv2dNative = DepthwiseConv2dNative;
98820 exports.DepthwiseConv2dNativeBackpropFilter = DepthwiseConv2dNativeBackpropFilter;
98821 exports.DepthwiseConv2dNativeBackpropInput = DepthwiseConv2dNativeBackpropInput;
98822 exports.Diag = Diag;
98823 exports.Dilation2D = Dilation2D;
98824 exports.Dilation2DBackpropFilter = Dilation2DBackpropFilter;
98825 exports.Dilation2DBackpropInput = Dilation2DBackpropInput;
98826 exports.EarlyStopping = EarlyStopping;
98827 exports.Einsum = Einsum;
98828 exports.Elu = Elu;
98829 exports.EluGrad = EluGrad;
98830 exports.Environment = Environment;
98831 exports.Equal = Equal;
98832 exports.Erf = Erf;
98833 exports.Exp = Exp;
98834 exports.ExpandDims = ExpandDims;
98835 exports.Expm1 = Expm1;
98836 exports.FFT = FFT;
98837 exports.Fill = Fill;
98838 exports.FlipLeftRight = FlipLeftRight;
98839 exports.Floor = Floor;
98840 exports.FloorDiv = FloorDiv;
98841 exports.FromPixels = FromPixels;
98842 exports.FusedBatchNorm = FusedBatchNorm;
98843 exports.FusedConv2D = FusedConv2D;
98844 exports.FusedDepthwiseConv2D = FusedDepthwiseConv2D;
98845 exports.GatherNd = GatherNd;
98846 exports.GatherV2 = GatherV2;
98847 exports.GraphModel = GraphModel;
98848 exports.Greater = Greater;
98849 exports.GreaterEqual = GreaterEqual;
98850 exports.History = History;
98851 exports.IFFT = IFFT;
98852 exports.Identity = Identity;
98853 exports.Imag = Imag;
98854 exports.InputSpec = InputSpec;
98855 exports.IsFinite = IsFinite;
98856 exports.IsInf = IsInf;
98857 exports.IsNan = IsNan;
98858 exports.KernelBackend = KernelBackend;
98859 exports.LRN = LRN;
98860 exports.LRNGrad = LRNGrad;
98861 exports.LayerVariable = LayerVariable;
98862 exports.LayersModel = LayersModel;
98863 exports.LeakyRelu = LeakyRelu;
98864 exports.Less = Less;
98865 exports.LessEqual = LessEqual;
98866 exports.LinSpace = LinSpace;
98867 exports.Log = Log;
98868 exports.Log1p = Log1p;
98869 exports.LogSoftmax = LogSoftmax;
98870 exports.LogicalAnd = LogicalAnd;
98871 exports.LogicalNot = LogicalNot;
98872 exports.LogicalOr = LogicalOr;
98873 exports.LowerBound = LowerBound;
98874 exports.Max = Max;
98875 exports.MaxPool = MaxPool;
98876 exports.MaxPool3D = MaxPool3D;
98877 exports.MaxPool3DGrad = MaxPool3DGrad;
98878 exports.MaxPoolGrad = MaxPoolGrad;
98879 exports.MaxPoolWithArgmax = MaxPoolWithArgmax;
98880 exports.Maximum = Maximum;
98881 exports.Mean = Mean;
98882 exports.Min = Min;
98883 exports.Minimum = Minimum;
98884 exports.MirrorPad = MirrorPad;
98885 exports.Mod = Mod;
98886 exports.MomentumOptimizer = MomentumOptimizer;
98887 exports.Multinomial = Multinomial;
98888 exports.Multiply = Multiply;
98889 exports.Neg = Neg;
98890 exports.NonMaxSuppressionV3 = NonMaxSuppressionV3;
98891 exports.NonMaxSuppressionV4 = NonMaxSuppressionV4;
98892 exports.NonMaxSuppressionV5 = NonMaxSuppressionV5;
98893 exports.NotEqual = NotEqual;
98894 exports.OP_SCOPE_SUFFIX = OP_SCOPE_SUFFIX;
98895 exports.OneHot = OneHot;
98896 exports.OnesLike = OnesLike;
98897 exports.Optimizer = Optimizer;
98898 exports.OptimizerConstructors = OptimizerConstructors;
98899 exports.Pack = Pack;
98900 exports.PadV2 = PadV2;
98901 exports.Pool = Pool;
98902 exports.Pow = Pow;
98903 exports.Prelu = Prelu;
98904 exports.Prod = Prod;
98905 exports.RMSPropOptimizer = RMSPropOptimizer;
98906 exports.RNN = RNN;
98907 exports.Range = Range;
98908 exports.Real = Real;
98909 exports.RealDiv = RealDiv;
98910 exports.Reciprocal = Reciprocal;
98911 exports.Relu = Relu;
98912 exports.Relu6 = Relu6;
98913 exports.Reshape = Reshape;
98914 exports.ResizeBilinear = ResizeBilinear;
98915 exports.ResizeBilinearGrad = ResizeBilinearGrad;
98916 exports.ResizeNearestNeighbor = ResizeNearestNeighbor;
98917 exports.ResizeNearestNeighborGrad = ResizeNearestNeighborGrad;
98918 exports.Reverse = Reverse;
98919 exports.RotateWithOffset = RotateWithOffset;
98920 exports.Round = Round;
98921 exports.Rsqrt = Rsqrt;
98922 exports.SGDOptimizer = SGDOptimizer;
98923 exports.ScatterNd = ScatterNd;
98924 exports.SearchSorted = SearchSorted;
98925 exports.Select = Select;
98926 exports.Selu = Selu;
98927 exports.Sequential = Sequential;
98928 exports.Sigmoid = Sigmoid;
98929 exports.Sign = Sign;
98930 exports.Sin = Sin;
98931 exports.Sinh = Sinh;
98932 exports.Slice = Slice;
98933 exports.Softmax = Softmax;
98934 exports.Softplus = Softplus;
98935 exports.SpaceToBatchND = SpaceToBatchND;
98936 exports.SparseFillEmptyRows = SparseFillEmptyRows;
98937 exports.SparseReshape = SparseReshape;
98938 exports.SparseSegmentMean = SparseSegmentMean;
98939 exports.SparseSegmentSum = SparseSegmentSum;
98940 exports.SparseToDense = SparseToDense;
98941 exports.SplitV = SplitV;
98942 exports.Sqrt = Sqrt;
98943 exports.Square = Square;
98944 exports.SquaredDifference = SquaredDifference;
98945 exports.Step = Step;
98946 exports.StridedSlice = StridedSlice;
98947 exports.StringNGrams = StringNGrams;
98948 exports.StringSplit = StringSplit;
98949 exports.StringToHashBucketFast = StringToHashBucketFast;
98950 exports.Sub = Sub;
98951 exports.Sum = Sum;
98952 exports.SymbolicTensor = SymbolicTensor;
98953 exports.Tan = Tan;
98954 exports.Tanh = Tanh;
98955 exports.Tensor = Tensor;
98956 exports.TensorBuffer = TensorBuffer;
98957 exports.Tile = Tile;
98958 exports.TopK = TopK;
98959 exports.Transform = Transform;
98960 exports.Transpose = Transpose;
98961 exports.Unique = Unique;
98962 exports.Unpack = Unpack;
98963 exports.UnsortedSegmentSum = UnsortedSegmentSum;
98964 exports.UpperBound = UpperBound;
98965 exports.Variable = Variable;
98966 exports.ZerosLike = ZerosLike;
98967 exports._FusedMatMul = _FusedMatMul;
98968 exports.abs = abs;
98969 exports.acos = acos;
98970 exports.acosh = acosh;
98971 exports.add = add$1;
98972 exports.addN = addN;
98973 exports.all = all;
98974 exports.any = any;
98975 exports.argMax = argMax;
98976 exports.argMin = argMin;
98977 exports.asin = asin;
98978 exports.asinh = asinh;
98979 exports.atan = atan;
98980 exports.atan2 = atan2;
98981 exports.atanh = atanh;
98982 exports.avgPool = avgPool;
98983 exports.avgPool3d = avgPool3d;
98984 exports.backend = backend;
98985 exports.backend_util = backend_util;
98986 exports.basicLSTMCell = basicLSTMCell;
98987 exports.batchNorm = batchNorm;
98988 exports.batchNorm2d = batchNorm2d;
98989 exports.batchNorm3d = batchNorm3d;
98990 exports.batchNorm4d = batchNorm4d;
98991 exports.batchToSpaceND = batchToSpaceND;
98992 exports.bincount = bincount;
98993 exports.booleanMaskAsync = booleanMaskAsync;
98994 exports.broadcastArgs = broadcastArgs;
98995 exports.broadcastTo = broadcastTo;
98996 exports.broadcast_util = broadcast_util;
98997 exports.browser = browser;
98998 exports.buffer = buffer;
98999 exports.callbacks = callbacks;
99000 exports.cast = cast;
99001 exports.ceil = ceil;
99002 exports.clipByValue = clipByValue;
99003 exports.clone = clone;
99004 exports.complex = complex;
99005 exports.concat = concat;
99006 exports.concat1d = concat1d;
99007 exports.concat2d = concat2d;
99008 exports.concat3d = concat3d;
99009 exports.concat4d = concat4d;
99010 exports.constraints = exports_constraints;
99011 exports.conv1d = conv1d;
99012 exports.conv2d = conv2d;
99013 exports.conv2dTranspose = conv2dTranspose;
99014 exports.conv3d = conv3d;
99015 exports.conv3dTranspose = conv3dTranspose;
99016 exports.copyRegisteredKernels = copyRegisteredKernels;
99017 exports.cos = cos;
99018 exports.cosh = cosh;
99019 exports.cosineWindow = cosineWindow;
99020 exports.cumprod = cumprod;
99021 exports.cumsum = cumsum;
99022 exports.customGrad = customGrad;
99023 exports.data = index;
99024 exports.denseBincount = denseBincount;
99025 exports.deprecationWarn = deprecationWarn;
99026 exports.depthToSpace = depthToSpace;
99027 exports.depthwiseConv2d = depthwiseConv2d;
99028 exports.deregisterOp = deregisterOp;
99029 exports.device_util = device_util;
99030 exports.diag = diag;
99031 exports.dilation2d = dilation2d;
99032 exports.disableDeprecationWarnings = disableDeprecationWarnings;
99033 exports.dispose = dispose;
99034 exports.disposeVariables = disposeVariables;
99035 exports.div = div;
99036 exports.divNoNan = divNoNan;
99037 exports.dot = dot;
99038 exports.dropout = dropout;
99039 exports.einsum = einsum;
99040 exports.elu = elu;
99041 exports.enableDebugMode = enableDebugMode;
99042 exports.enableProdMode = enableProdMode;
99043 exports.enclosingPowerOfTwo = enclosingPowerOfTwo;
99044 exports.engine = engine;
99045 exports.env = env;
99046 exports.equal = equal;
99047 exports.erf = erf;
99048 exports.euclideanNorm = euclideanNorm;
99049 exports.exp = exp;
99050 exports.expandDims = expandDims;
99051 exports.expm1 = expm1;
99052 exports.eye = eye;
99053 exports.fft = fft;
99054 exports.fill = fill;
99055 exports.findBackend = findBackend;
99056 exports.findBackendFactory = findBackendFactory;
99057 exports.floor = floor;
99058 exports.floorDiv = floorDiv;
99059 exports.fused = fused_ops;
99060 exports.gather = gather;
99061 exports.gatherND = gatherND;
99062 exports.gather_util = gather_nd_util;
99063 exports.getBackend = getBackend;
99064 exports.getGradient = getGradient;
99065 exports.getKernel = getKernel;
99066 exports.getKernelsForBackend = getKernelsForBackend;
99067 exports.grad = grad;
99068 exports.grads = grads;
99069 exports.greater = greater;
99070 exports.greaterEqual = greaterEqual;
99071 exports.ifft = ifft;
99072 exports.imag = imag;
99073 exports.image = image;
99074 exports.inTopKAsync = inTopKAsync;
99075 exports.initializers = exports_initializers;
99076 exports.input = input;
99077 exports.io = io;
99078 exports.irfft = irfft;
99079 exports.isFinite = isFinite$1;
99080 exports.isInf = isInf;
99081 exports.isNaN = isNaN$1;
99082 exports.keep = keep;
99083 exports.kernel_impls = kernel_impls;
99084 exports.layers = exports_layers;
99085 exports.leakyRelu = leakyRelu;
99086 exports.less = less;
99087 exports.lessEqual = lessEqual;
99088 exports.linalg = linalg;
99089 exports.linspace = linspace;
99090 exports.loadGraphModel = loadGraphModel;
99091 exports.loadGraphModelSync = loadGraphModelSync;
99092 exports.loadLayersModel = loadLayersModel;
99093 exports.localResponseNormalization = localResponseNormalization;
99094 exports.log = log$1;
99095 exports.log1p = log1p;
99096 exports.logSigmoid = logSigmoid;
99097 exports.logSoftmax = logSoftmax;
99098 exports.logSumExp = logSumExp;
99099 exports.logicalAnd = logicalAnd;
99100 exports.logicalNot = logicalNot;
99101 exports.logicalOr = logicalOr;
99102 exports.logicalXor = logicalXor;
99103 exports.losses = losses;
99104 exports.lowerBound = lowerBound;
99105 exports.matMul = matMul;
99106 exports.math = math;
99107 exports.max = max;
99108 exports.maxPool = maxPool;
99109 exports.maxPool3d = maxPool3d;
99110 exports.maxPoolWithArgmax = maxPoolWithArgmax;
99111 exports.maximum = maximum;
99112 exports.mean = mean;
99113 exports.memory = memory;
99114 exports.meshgrid = meshgrid;
99115 exports.metrics = exports_metrics;
99116 exports.min = min;
99117 exports.minimum = minimum;
99118 exports.mirrorPad = mirrorPad;
99119 exports.mod = mod;
99120 exports.model = model;
99121 exports.models = exports_models;
99122 exports.moments = moments;
99123 exports.movingAverage = movingAverage;
99124 exports.mul = mul;
99125 exports.multiRNNCell = multiRNNCell;
99126 exports.multinomial = multinomial;
99127 exports.neg = neg;
99128 exports.nextFrame = nextFrame;
99129 exports.norm = norm;
99130 exports.notEqual = notEqual;
99131 exports.oneHot = oneHot;
99132 exports.ones = ones$1;
99133 exports.onesLike = onesLike;
99134 exports.op = op;
99135 exports.outerProduct = outerProduct;
99136 exports.pad = pad;
99137 exports.pad1d = pad1d;
99138 exports.pad2d = pad2d;
99139 exports.pad3d = pad3d;
99140 exports.pad4d = pad4d;
99141 exports.pool = pool;
99142 exports.pow = pow;
99143 exports.prelu = prelu;
99144 exports.print = print;
99145 exports.prod = prod;
99146 exports.profile = profile;
99147 exports.rand = rand;
99148 exports.randomGamma = randomGamma;
99149 exports.randomNormal = randomNormal;
99150 exports.randomUniform = randomUniform;
99151 exports.range = range;
99152 exports.ready = ready;
99153 exports.real = real;
99154 exports.reciprocal = reciprocal;
99155 exports.registerBackend = registerBackend;
99156 exports.registerCallbackConstructor = registerCallbackConstructor;
99157 exports.registerGradient = registerGradient;
99158 exports.registerKernel = registerKernel;
99159 exports.registerOp = registerOp;
99160 exports.regularizers = exports_regularizers;
99161 exports.relu = relu;
99162 exports.relu6 = relu6;
99163 exports.removeBackend = removeBackend;
99164 exports.reshape = reshape;
99165 exports.reverse = reverse;
99166 exports.reverse1d = reverse1d;
99167 exports.reverse2d = reverse2d;
99168 exports.reverse3d = reverse3d;
99169 exports.reverse4d = reverse4d;
99170 exports.rfft = rfft;
99171 exports.round = round$1;
99172 exports.rsqrt = rsqrt;
99173 exports.scalar = scalar;
99174 exports.scatterND = scatterND;
99175 exports.scatter_util = scatter_nd_util;
99176 exports.searchSorted = searchSorted;
99177 exports.selu = selu;
99178 exports.separableConv2d = separableConv2d;
99179 exports.sequential = sequential;
99180 exports.serialization = serialization;
99181 exports.setBackend = setBackend;
99182 exports.setPlatform = setPlatform;
99183 exports.setdiff1dAsync = setdiff1dAsync;
99184 exports.sigmoid = sigmoid;
99185 exports.sign = sign;
99186 exports.signal = signal;
99187 exports.sin = sin;
99188 exports.sinh = sinh;
99189 exports.slice = slice;
99190 exports.slice1d = slice1d;
99191 exports.slice2d = slice2d;
99192 exports.slice3d = slice3d;
99193 exports.slice4d = slice4d;
99194 exports.slice_util = slice_util;
99195 exports.softmax = softmax;
99196 exports.softplus = softplus;
99197 exports.spaceToBatchND = spaceToBatchND;
99198 exports.sparse = sparse;
99199 exports.sparseToDense = sparseToDense;
99200 exports.spectral = spectral;
99201 exports.split = split;
99202 exports.sqrt = sqrt;
99203 exports.square = square;
99204 exports.squaredDifference = squaredDifference;
99205 exports.squeeze = squeeze;
99206 exports.stack = stack;
99207 exports.step = step;
99208 exports.stridedSlice = stridedSlice;
99209 exports.string = string;
99210 exports.sub = sub;
99211 exports.sum = sum$1;
99212 exports.sumOutType = sumOutType;
99213 exports.tan = tan;
99214 exports.tanh = tanh$1;
99215 exports.tensor = tensor;
99216 exports.tensor1d = tensor1d;
99217 exports.tensor2d = tensor2d;
99218 exports.tensor3d = tensor3d;
99219 exports.tensor4d = tensor4d;
99220 exports.tensor5d = tensor5d;
99221 exports.tensor6d = tensor6d;
99222 exports.tensor_util = tensor_util;
99223 exports.test_util = test_util;
99224 exports.tidy = tidy;
99225 exports.tile = tile;
99226 exports.time = time;
99227 exports.topk = topk;
99228 exports.train = train;
99229 exports.transpose = transpose;
99230 exports.truncatedNormal = truncatedNormal;
99231 exports.unique = unique;
99232 exports.unregisterGradient = unregisterGradient;
99233 exports.unregisterKernel = unregisterKernel;
99234 exports.unsortedSegmentSum = unsortedSegmentSum;
99235 exports.unstack = unstack;
99236 exports.upcastType = upcastType;
99237 exports.upperBound = upperBound;
99238 exports.util = util;
99239 exports.valueAndGrad = valueAndGrad;
99240 exports.valueAndGrads = valueAndGrads;
99241 exports.variable = variable;
99242 exports.variableGrads = variableGrads;
99243 exports.version = version$7;
99244 exports.version_converter = version$2;
99245 exports.version_core = version;
99246 exports.version_layers = version$1;
99247 exports.where = where;
99248 exports.whereAsync = whereAsync;
99249 exports.zeros = zeros;
99250 exports.zerosLike = zerosLike;
99251
99252 Object.defineProperty(exports, '__esModule', { value: true });
99253
99254})));
99255//# sourceMappingURL=tf.es2017.js.map