UNPKG

4.42 MBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2024 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17(function (global, factory) {
18 typeof exports === 'object' && typeof module !== 'undefined' ? factory(exports) :
19 typeof define === 'function' && define.amd ? define(['exports'], factory) :
20 (global = typeof globalThis !== 'undefined' ? globalThis : global || self, factory(global.tf = global.tf || {}));
21})(this, (function (exports) { 'use strict';
22
23 function _mergeNamespaces(n, m) {
24 m.forEach(function (e) {
25 e && typeof e !== 'string' && !Array.isArray(e) && Object.keys(e).forEach(function (k) {
26 if (k !== 'default' && !(k in n)) {
27 var d = Object.getOwnPropertyDescriptor(e, k);
28 Object.defineProperty(n, k, d.get ? d : {
29 enumerable: true,
30 get: function () { return e[k]; }
31 });
32 }
33 });
34 });
35 return Object.freeze(n);
36 }
37
38 /**
39 * @license
40 * Copyright 2020 Google LLC. All Rights Reserved.
41 * Licensed under the Apache License, Version 2.0 (the "License");
42 * you may not use this file except in compliance with the License.
43 * You may obtain a copy of the License at
44 *
45 * http://www.apache.org/licenses/LICENSE-2.0
46 *
47 * Unless required by applicable law or agreed to in writing, software
48 * distributed under the License is distributed on an "AS IS" BASIS,
49 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
50 * See the License for the specific language governing permissions and
51 * limitations under the License.
52 * =============================================================================
53 */
  // Float-comparison tolerances used by KernelBackend.epsilon() below:
  // 1e-7 for float32-precision backends, the looser 1e-4 for float16.
  const EPSILON_FLOAT32$1 = 1e-7;
  const EPSILON_FLOAT16$1 = 1e-4;
56 /** Convenient class for storing tensor-related data. */
57 class DataStorage {
58 constructor(backend, dataMover) {
59 this.backend = backend;
60 this.dataMover = dataMover;
61 this.data = new WeakMap();
62 this.dataIdsCount = 0;
63 }
64 get(dataId) {
65 if (!this.data.has(dataId)) {
66 this.dataMover.moveData(this.backend, dataId);
67 }
68 return this.data.get(dataId);
69 }
70 set(dataId, value) {
71 this.dataIdsCount++;
72 this.data.set(dataId, value);
73 }
74 has(dataId) {
75 return this.data.has(dataId);
76 }
77 delete(dataId) {
78 this.dataIdsCount--;
79 return this.data.delete(dataId);
80 }
81 numDataIds() {
82 return this.dataIdsCount;
83 }
84 }
  /**
   * The interface that defines the kernels that should be implemented when
   * adding a new backend. New backends don't need to implement every one of the
   * methods, this can be done gradually (throw an error for unimplemented
   * methods).
   */
  class KernelBackend {
    // --- Reference counting for shared tensor data ---
    refCount(dataId) {
      return notYetImplemented('refCount');
    }
    incRef(dataId) {
      return notYetImplemented('incRef');
    }
    // Whether time() is supported; backends may override to return false.
    timerAvailable() {
      return true;
    }
    // Times the kernel function `f`.
    time(f) {
      return notYetImplemented('time');
    }
    // --- Data access (async and sync reads, plus GPU-resident reads) ---
    read(dataId) {
      return notYetImplemented('read');
    }
    readSync(dataId) {
      return notYetImplemented('readSync');
    }
    readToGPU(dataId, options) {
      return notYetImplemented('readToGPU');
    }
    numDataIds() {
      return notYetImplemented('numDataIds');
    }
    disposeData(dataId, force) {
      return notYetImplemented('disposeData');
    }
    // --- Data creation / movement ---
    write(values, shape, dtype) {
      return notYetImplemented('write');
    }
    move(dataId, values, shape, dtype, refCount) {
      return notYetImplemented('move');
    }
    createTensorFromGPUData(values, shape, dtype) {
      return notYetImplemented('createTensorFromGPUData');
    }
    memory() {
      return notYetImplemented('memory');
    }
    /** Returns the highest precision for floats in bits (e.g. 16 or 32) */
    floatPrecision() {
      return notYetImplemented('floatPrecision');
    }
    /** Returns the smallest representable number. */
    epsilon() {
      return this.floatPrecision() === 32 ? EPSILON_FLOAT32$1 : EPSILON_FLOAT16$1;
    }
    dispose() {
      return notYetImplemented('dispose');
    }
  }
143 function notYetImplemented(kernelName) {
144 throw new Error(`'${kernelName}' not yet implemented or not found in the registry. ` +
145 `This kernel may not be supported by the tfjs backend you have chosen`);
146 }
147
148 /**
149 * @license
150 * Copyright 2020 Google LLC. All Rights Reserved.
151 * Licensed under the Apache License, Version 2.0 (the "License");
152 * you may not use this file except in compliance with the License.
153 * You may obtain a copy of the License at
154 *
155 * http://www.apache.org/licenses/LICENSE-2.0
156 *
157 * Unless required by applicable law or agreed to in writing, software
158 * distributed under the License is distributed on an "AS IS" BASIS,
159 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
160 * See the License for the specific language governing permissions and
161 * limitations under the License.
162 * =============================================================================
163 */
164 /**
165 * Shuffles the array in-place using Fisher-Yates algorithm.
166 *
167 * ```js
168 * const a = [1, 2, 3, 4, 5];
169 * tf.util.shuffle(a);
170 * console.log(a);
171 * ```
172 *
173 * @param array The array to shuffle in-place.
174 *
175 * @doc {heading: 'Util', namespace: 'util'}
176 */
177 // tslint:disable-next-line:no-any
178 function shuffle(array) {
179 let counter = array.length;
180 let index = 0;
181 // While there are elements in the array
182 while (counter > 0) {
183 // Pick a random index
184 index = (Math.random() * counter) | 0;
185 // Decrease counter by 1
186 counter--;
187 // And swap the last element with it
188 swap(array, counter, index);
189 }
190 }
191 /**
192 * Shuffles two arrays in-place the same way using Fisher-Yates algorithm.
193 *
194 * ```js
195 * const a = [1,2,3,4,5];
196 * const b = [11,22,33,44,55];
197 * tf.util.shuffleCombo(a, b);
198 * console.log(a, b);
199 * ```
200 *
201 * @param array The first array to shuffle in-place.
202 * @param array2 The second array to shuffle in-place with the same permutation
203 * as the first array.
204 *
205 * @doc {heading: 'Util', namespace: 'util'}
206 */
207 function shuffleCombo(
208 // tslint:disable-next-line:no-any
209 array,
210 // tslint:disable-next-line:no-any
211 array2) {
212 if (array.length !== array2.length) {
213 throw new Error(`Array sizes must match to be shuffled together ` +
214 `First array length was ${array.length}` +
215 `Second array length was ${array2.length}`);
216 }
217 let counter = array.length;
218 let index = 0;
219 // While there are elements in the array
220 while (counter > 0) {
221 // Pick a random index
222 index = (Math.random() * counter) | 0;
223 // Decrease counter by 1
224 counter--;
225 // And swap the last element of each array with it
226 swap(array, counter, index);
227 swap(array2, counter, index);
228 }
229 }
230 /** Clamps a value to a specified range. */
231 function clamp(min, x, max) {
232 return Math.max(min, Math.min(x, max));
233 }
234 function nearestLargerEven(val) {
235 return val % 2 === 0 ? val : val + 1;
236 }
237 function swap(object, left, right) {
238 const temp = object[left];
239 object[left] = object[right];
240 object[right] = temp;
241 }
242 function sum$4(arr) {
243 let sum = 0;
244 for (let i = 0; i < arr.length; i++) {
245 sum += arr[i];
246 }
247 return sum;
248 }
249 /**
250 * Returns a sample from a uniform [a, b) distribution.
251 *
252 * @param a The minimum support (inclusive).
253 * @param b The maximum support (exclusive).
254 * @return A pseudorandom number on the half-open interval [a,b).
255 */
256 function randUniform(a, b) {
257 const r = Math.random();
258 return (b * r) + (1 - r) * a;
259 }
260 /** Returns the squared Euclidean distance between two vectors. */
261 function distSquared(a, b) {
262 let result = 0;
263 for (let i = 0; i < a.length; i++) {
264 const diff = Number(a[i]) - Number(b[i]);
265 result += diff * diff;
266 }
267 return result;
268 }
269 /**
270 * Asserts that the expression is true. Otherwise throws an error with the
271 * provided message.
272 *
273 * ```js
274 * const x = 2;
275 * tf.util.assert(x === 2, 'x is not 2');
276 * ```
277 *
278 * @param expr The expression to assert (as a boolean).
279 * @param msg A function that returns the message to report when throwing an
280 * error. We use a function for performance reasons.
281 *
282 * @doc {heading: 'Util', namespace: 'util'}
283 */
284 function assert$1(expr, msg) {
285 if (!expr) {
286 throw new Error(typeof msg === 'string' ? msg : msg());
287 }
288 }
289 function assertShapesMatch(shapeA, shapeB, errorMessagePrefix = '') {
290 assert$1(arraysEqual(shapeA, shapeB), () => errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`);
291 }
292 function assertNonNull(a) {
293 assert$1(a != null, () => `The input to the tensor constructor must be a non-null value.`);
294 }
295 /**
296 * Returns the size (number of elements) of the tensor given its shape.
297 *
298 * ```js
299 * const shape = [3, 4, 2];
300 * const size = tf.util.sizeFromShape(shape);
301 * console.log(size);
302 * ```
303 *
304 * @doc {heading: 'Util', namespace: 'util'}
305 */
306 function sizeFromShape(shape) {
307 if (shape.length === 0) {
308 // Scalar.
309 return 1;
310 }
311 let size = shape[0];
312 for (let i = 1; i < shape.length; i++) {
313 size *= shape[i];
314 }
315 return size;
316 }
317 function isScalarShape(shape) {
318 return shape.length === 0;
319 }
320 function arraysEqualWithNull(n1, n2) {
321 if (n1 === n2) {
322 return true;
323 }
324 if (n1 == null || n2 == null) {
325 return false;
326 }
327 if (n1.length !== n2.length) {
328 return false;
329 }
330 for (let i = 0; i < n1.length; i++) {
331 if (n1[i] !== null && n2[i] !== null && n1[i] !== n2[i]) {
332 return false;
333 }
334 }
335 return true;
336 }
337 function arraysEqual(n1, n2) {
338 if (n1 === n2) {
339 return true;
340 }
341 if (n1 == null || n2 == null) {
342 return false;
343 }
344 if (n1.length !== n2.length) {
345 return false;
346 }
347 for (let i = 0; i < n1.length; i++) {
348 if (n1[i] !== n2[i]) {
349 return false;
350 }
351 }
352 return true;
353 }
354 function isInt(a) {
355 return a % 1 === 0;
356 }
357 function tanh$3(x) {
358 // tslint:disable-next-line:no-any
359 if (Math.tanh != null) {
360 // tslint:disable-next-line:no-any
361 return Math.tanh(x);
362 }
363 if (x === Infinity) {
364 return 1;
365 }
366 else if (x === -Infinity) {
367 return -1;
368 }
369 else {
370 const e2x = Math.exp(2 * x);
371 return (e2x - 1) / (e2x + 1);
372 }
373 }
374 function sizeToSquarishShape(size) {
375 const width = Math.ceil(Math.sqrt(size));
376 return [width, Math.ceil(size / width)];
377 }
378 /**
379 * Creates a new array with randomized indices to a given quantity.
380 *
381 * ```js
382 * const randomTen = tf.util.createShuffledIndices(10);
383 * console.log(randomTen);
384 * ```
385 *
386 * @param number Quantity of how many shuffled indices to create.
387 *
388 * @doc {heading: 'Util', namespace: 'util'}
389 */
390 function createShuffledIndices(n) {
391 const shuffledIndices = new Uint32Array(n);
392 for (let i = 0; i < n; ++i) {
393 shuffledIndices[i] = i;
394 }
395 shuffle(shuffledIndices);
396 return shuffledIndices;
397 }
398 function rightPad(a, size) {
399 if (size <= a.length) {
400 return a;
401 }
402 return a + ' '.repeat(size - a.length);
403 }
  /**
   * Polls `checkFn` until it returns true, then resolves the returned promise.
   *
   * @param checkFn Condition to poll; truthy result resolves the promise.
   * @param delayFn Maps the 1-based attempt count to a delay (ms) before the
   *     next attempt. Defaults to 0 (retry immediately).
   * @param maxCounter Optional attempt limit; when reached the promise rejects
   *     (with no reason).
   * @param scheduleFn Optional custom scheduler, called as
   *     scheduleFn(retry, delayMs); falls back to setTimeout when absent.
   */
  function repeatedTry(checkFn, delayFn = (counter) => 0, maxCounter, scheduleFn) {
    return new Promise((resolve, reject) => {
      let tryCount = 0;
      const tryFn = () => {
        // Check first, so an immediately-true condition resolves synchronously
        // within the promise executor with no scheduling at all.
        if (checkFn()) {
          resolve();
          return;
        }
        tryCount++;
        const nextBackoff = delayFn(tryCount);
        if (maxCounter != null && tryCount >= maxCounter) {
          reject();
          return;
        }
        if (scheduleFn != null) {
          scheduleFn(tryFn, nextBackoff);
        }
        else {
          // google3 does not allow assigning another variable to setTimeout.
          // Don't refactor this so scheduleFn has a default value of setTimeout.
          setTimeout(tryFn, nextBackoff);
        }
      };
      tryFn();
    });
  }
430 /**
431 * Given the full size of the array and a shape that may contain -1 as the
432 * implicit dimension, returns the inferred shape where -1 is replaced.
433 * E.g. For shape=[2, -1, 3] and size=24, it will return [2, 4, 3].
434 *
435 * @param shape The shape, which may contain -1 in some dimension.
436 * @param size The full size (number of elements) of the array.
437 * @return The inferred shape where -1 is replaced with the inferred size.
438 */
439 function inferFromImplicitShape(shape, size) {
440 let shapeProd = 1;
441 let implicitIdx = -1;
442 for (let i = 0; i < shape.length; ++i) {
443 if (shape[i] >= 0) {
444 shapeProd *= shape[i];
445 }
446 else if (shape[i] === -1) {
447 if (implicitIdx !== -1) {
448 throw Error(`Shapes can only have 1 implicit size. ` +
449 `Found -1 at dim ${implicitIdx} and dim ${i}`);
450 }
451 implicitIdx = i;
452 }
453 else if (shape[i] < 0) {
454 throw Error(`Shapes can not be < 0. Found ${shape[i]} at dim ${i}`);
455 }
456 }
457 if (implicitIdx === -1) {
458 if (size > 0 && size !== shapeProd) {
459 throw Error(`Size(${size}) must match the product of shape ${shape}`);
460 }
461 return shape;
462 }
463 if (shapeProd === 0) {
464 throw Error(`Cannot infer the missing size in [${shape}] when ` +
465 `there are 0 elements`);
466 }
467 if (size % shapeProd !== 0) {
468 throw Error(`The implicit shape can't be a fractional number. ` +
469 `Got ${size} / ${shapeProd}`);
470 }
471 const newShape = shape.slice();
472 newShape[implicitIdx] = size / shapeProd;
473 return newShape;
474 }
  /**
   * Normalizes an axis argument into an array of non-negative axes.
   *
   * @param axis null (meaning all axes), a single axis, or an array of axes.
   *     Negative values count from the end, numpy-style.
   * @param shape Shape of the tensor the axes refer to.
   * @returns The axes mapped into [0, rank).
   */
  function parseAxisParam(axis, shape) {
    const rank = shape.length;
    // Normalize input
    axis = axis == null ? shape.map((s, i) => i) : [].concat(axis);
    // Check for valid range
    assert$1(axis.every(ax => ax >= -rank && ax < rank), () => `All values in axis param must be in range [-${rank}, ${rank}) but ` +
        `got axis ${axis}`);
    // Check for only integers
    assert$1(axis.every(ax => isInt(ax)), () => `All values in axis param must be integers but ` +
        `got axis ${axis}`);
    // Handle negative axis.
    return axis.map(a => a < 0 ? rank + a : a);
  }
488 /** Reduces the shape by removing all dimensions of shape 1. */
489 function squeezeShape(shape, axis) {
490 const newShape = [];
491 const keptDims = [];
492 const isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;
493 const axes = (axis == null || isEmptyArray) ?
494 null :
495 parseAxisParam(axis, shape).sort();
496 let j = 0;
497 for (let i = 0; i < shape.length; ++i) {
498 if (axes != null) {
499 if (axes[j] === i && shape[i] !== 1) {
500 throw new Error(`Can't squeeze axis ${i} since its dim '${shape[i]}' is not 1`);
501 }
502 if ((axes[j] == null || axes[j] > i) && shape[i] === 1) {
503 newShape.push(shape[i]);
504 keptDims.push(i);
505 }
506 if (axes[j] <= i) {
507 j++;
508 }
509 }
510 if (shape[i] !== 1) {
511 newShape.push(shape[i]);
512 keptDims.push(i);
513 }
514 }
515 return { newShape, keptDims };
516 }
  // Delegates to getArrayFromDType; kept as a separate entry point (in the
  // original TypeScript the two differ only in their type parameters).
  function getTypedArrayFromDType(dtype, size) {
    return getArrayFromDType(dtype, size);
  }
520 function getArrayFromDType(dtype, size) {
521 let values = null;
522 if (dtype == null || dtype === 'float32') {
523 values = new Float32Array(size);
524 }
525 else if (dtype === 'int32') {
526 values = new Int32Array(size);
527 }
528 else if (dtype === 'bool') {
529 values = new Uint8Array(size);
530 }
531 else if (dtype === 'string') {
532 values = new Array(size);
533 }
534 else {
535 throw new Error(`Unknown data type ${dtype}`);
536 }
537 return values;
538 }
539 function checkConversionForErrors(vals, dtype) {
540 for (let i = 0; i < vals.length; i++) {
541 const num = vals[i];
542 if (isNaN(num) || !isFinite(num)) {
543 throw Error(`A tensor of type ${dtype} being uploaded contains ${num}.`);
544 }
545 }
546 }
547 /** Returns true if the dtype is valid. */
548 function isValidDtype(dtype) {
549 return dtype === 'bool' || dtype === 'complex64' || dtype === 'float32' ||
550 dtype === 'int32' || dtype === 'string';
551 }
552 /**
553 * Returns true if the new type can't encode the old type without loss of
554 * precision.
555 */
556 function hasEncodingLoss(oldType, newType) {
557 if (newType === 'complex64') {
558 return false;
559 }
560 if (newType === 'float32' && oldType !== 'complex64') {
561 return false;
562 }
563 if (newType === 'int32' && oldType !== 'float32' && oldType !== 'complex64') {
564 return false;
565 }
566 if (newType === 'bool' && oldType === 'bool') {
567 return false;
568 }
569 return true;
570 }
571 function bytesPerElement(dtype) {
572 if (dtype === 'float32' || dtype === 'int32') {
573 return 4;
574 }
575 else if (dtype === 'complex64') {
576 return 8;
577 }
578 else if (dtype === 'bool') {
579 return 1;
580 }
581 else {
582 throw new Error(`Unknown dtype ${dtype}`);
583 }
584 }
585 /**
586 * Returns the approximate number of bytes allocated in the string array - 2
587 * bytes per character. Computing the exact bytes for a native string in JS
588 * is not possible since it depends on the encoding of the html page that
589 * serves the website.
590 */
591 function bytesFromStringArray(arr) {
592 if (arr == null) {
593 return 0;
594 }
595 let bytes = 0;
596 arr.forEach(x => bytes += x.length);
597 return bytes;
598 }
599 /** Returns true if the value is a string. */
600 function isString(value) {
601 return typeof value === 'string' || value instanceof String;
602 }
603 function isBoolean(value) {
604 return typeof value === 'boolean';
605 }
606 function isNumber(value) {
607 return typeof value === 'number';
608 }
609 function inferDtype(values) {
610 if (Array.isArray(values)) {
611 return inferDtype(values[0]);
612 }
613 if (values instanceof Float32Array) {
614 return 'float32';
615 }
616 else if (values instanceof Int32Array || values instanceof Uint8Array ||
617 values instanceof Uint8ClampedArray) {
618 return 'int32';
619 }
620 else if (isNumber(values)) {
621 return 'float32';
622 }
623 else if (isString(values)) {
624 return 'string';
625 }
626 else if (isBoolean(values)) {
627 return 'bool';
628 }
629 return 'float32';
630 }
631 function isFunction(f) {
632 return !!(f && f.constructor && f.call && f.apply);
633 }
634 function nearestDivisor(size, start) {
635 for (let i = start; i < size; ++i) {
636 if (size % i === 0) {
637 return i;
638 }
639 }
640 return size;
641 }
642 function computeStrides(shape) {
643 const rank = shape.length;
644 if (rank < 2) {
645 return [];
646 }
647 // Last dimension has implicit stride of 1, thus having D-1 (instead of D)
648 // strides.
649 const strides = new Array(rank - 1);
650 strides[rank - 2] = shape[rank - 1];
651 for (let i = rank - 3; i >= 0; --i) {
652 strides[i] = strides[i + 1] * shape[i + 1];
653 }
654 return strides;
655 }
656 function createNestedArray(offset, shape, a, isComplex = false) {
657 const ret = new Array();
658 if (shape.length === 1) {
659 const d = shape[0] * (isComplex ? 2 : 1);
660 for (let i = 0; i < d; i++) {
661 ret[i] = a[offset + i];
662 }
663 }
664 else {
665 const d = shape[0];
666 const rest = shape.slice(1);
667 const len = rest.reduce((acc, c) => acc * c) * (isComplex ? 2 : 1);
668 for (let i = 0; i < d; i++) {
669 ret[i] = createNestedArray(offset + i * len, rest, a, isComplex);
670 }
671 }
672 return ret;
673 }
674 // Provide a nested array of TypedArray in given shape.
675 function toNestedArray(shape, a, isComplex = false) {
676 if (shape.length === 0) {
677 // Scalar type should return a single number.
678 return a[0];
679 }
680 const size = shape.reduce((acc, c) => acc * c) * (isComplex ? 2 : 1);
681 if (size === 0) {
682 // A tensor with shape zero should be turned into empty list.
683 return [];
684 }
685 if (size !== a.length) {
686 throw new Error(`[${shape}] does not match the input size ${a.length}${isComplex ? ' for a complex tensor' : ''}.`);
687 }
688 return createNestedArray(0, shape, a, isComplex);
689 }
690 function convertBackendValuesAndArrayBuffer(data, dtype) {
691 // If is type Uint8Array[], return it directly.
692 if (Array.isArray(data)) {
693 return data;
694 }
695 if (dtype === 'float32') {
696 return data instanceof Float32Array ? data : new Float32Array(data);
697 }
698 else if (dtype === 'int32') {
699 return data instanceof Int32Array ? data : new Int32Array(data);
700 }
701 else if (dtype === 'bool' || dtype === 'string') {
702 return Uint8Array.from(new Int32Array(data));
703 }
704 else {
705 throw new Error(`Unknown dtype ${dtype}`);
706 }
707 }
708 function makeOnesTypedArray(size, dtype) {
709 const array = makeZerosTypedArray(size, dtype);
710 for (let i = 0; i < array.length; i++) {
711 array[i] = 1;
712 }
713 return array;
714 }
715 function makeZerosTypedArray(size, dtype) {
716 if (dtype == null || dtype === 'float32' || dtype === 'complex64') {
717 return new Float32Array(size);
718 }
719 else if (dtype === 'int32') {
720 return new Int32Array(size);
721 }
722 else if (dtype === 'bool') {
723 return new Uint8Array(size);
724 }
725 else {
726 throw new Error(`Unknown data type ${dtype}`);
727 }
728 }
729 /**
730 * Make nested `TypedArray` filled with zeros.
731 * @param shape The shape information for the nested array.
732 * @param dtype dtype of the array element.
733 */
734 function makeZerosNestedTypedArray(shape, dtype) {
735 const size = shape.reduce((prev, curr) => prev * curr, 1);
736 if (dtype == null || dtype === 'float32') {
737 return toNestedArray(shape, new Float32Array(size));
738 }
739 else if (dtype === 'int32') {
740 return toNestedArray(shape, new Int32Array(size));
741 }
742 else if (dtype === 'bool') {
743 return toNestedArray(shape, new Uint8Array(size));
744 }
745 else {
746 throw new Error(`Unknown data type ${dtype}`);
747 }
748 }
749 function assertNonNegativeIntegerDimensions(shape) {
750 shape.forEach(dimSize => {
751 assert$1(Number.isInteger(dimSize) && dimSize >= 0, () => `Tensor must have a shape comprised of positive integers but got ` +
752 `shape [${shape}].`);
753 });
754 }
755 /**
756 * Computes flat index for a given location (multidimentionsal index) in a
757 * Tensor/multidimensional array.
758 *
759 * @param locs Location in the tensor.
760 * @param rank Rank of the tensor.
761 * @param strides Tensor strides.
762 */
763 function locToIndex(locs, rank, strides) {
764 if (rank === 0) {
765 return 0;
766 }
767 else if (rank === 1) {
768 return locs[0];
769 }
770 let index = locs[locs.length - 1];
771 for (let i = 0; i < locs.length - 1; ++i) {
772 index += strides[i] * locs[i];
773 }
774 return index;
775 }
776 /**
777 * Computes the location (multidimensional index) in a
778 * tensor/multidimentional array for a given flat index.
779 *
780 * @param index Index in flat array.
781 * @param rank Rank of tensor.
782 * @param strides Strides of tensor.
783 */
784 function indexToLoc(index, rank, strides) {
785 if (rank === 0) {
786 return [];
787 }
788 else if (rank === 1) {
789 return [index];
790 }
791 const locs = new Array(rank);
792 for (let i = 0; i < locs.length - 1; ++i) {
793 locs[i] = Math.floor(index / strides[i]);
794 index -= locs[i] * strides[i];
795 }
796 locs[locs.length - 1] = index;
797 return locs;
798 }
  /**
   * This method asserts whether an object is a Promise instance.
   * @param object
   * @returns Truthy when `object` has a callable `then`; note the raw falsy
   *     value (e.g. null/undefined) is passed through rather than coerced to
   *     a boolean.
   */
  // tslint:disable-next-line: no-any
  function isPromise(object) {
    // We chose to not use 'obj instanceOf Promise' for two reasons:
    // 1. It only reliably works for es6 Promise, not other Promise
    // implementations.
    // 2. It doesn't work with framework that uses zone.js. zone.js monkey
    // patch the async calls, so it is possible the obj (patched) is
    // comparing to a pre-patched Promise.
    return object && object.then && typeof object.then === 'function';
  }
813
814 /**
815 * @license
816 * Copyright 2017 Google LLC. All Rights Reserved.
817 * Licensed under the Apache License, Version 2.0 (the "License");
818 * you may not use this file except in compliance with the License.
819 * You may obtain a copy of the License at
820 *
821 * http://www.apache.org/licenses/LICENSE-2.0
822 *
823 * Unless required by applicable law or agreed to in writing, software
824 * distributed under the License is distributed on an "AS IS" BASIS,
825 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
826 * See the License for the specific language governing permissions and
827 * limitations under the License.
828 * =============================================================================
829 */
  // Expects flags from URL in the format ?tfjsflags=FLAG1:1,FLAG2:true.
  // Consumed by Environment.populateURLFlags() below.
  const TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags';
832 /**
833 * The environment contains evaluated flags as well as the registered platform.
834 * This is always used as a global singleton and can be retrieved with
835 * `tf.env()`.
836 *
837 * @doc {heading: 'Environment'}
838 */
  class Environment {
    // tslint:disable-next-line: no-any
    constructor(global) {
      this.global = global;
      this.flags = {};
      this.flagRegistry = {};
      this.urlFlags = {};
      // Jasmine spies on this in 'environment_test.ts'
      this.getQueryParams = getQueryParams;
      this.populateURLFlags();
    }
    /** Registers the platform implementation, warning if one was already set. */
    setPlatform(platformName, platform) {
      if (this.platform != null) {
        if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
          console.warn(`Platform ${this.platformName} has already been set. ` +
              `Overwriting the platform with ${platformName}.`);
        }
      }
      this.platformName = platformName;
      this.platform = platform;
    }
    /** Registers a flag's evaluation function and optional set-hook. */
    registerFlag(flagName, evaluationFn, setHook) {
      this.flagRegistry[flagName] = { evaluationFn, setHook };
      // Override the flag value from the URL. This has to happen here because
      // the environment is initialized before flags get registered.
      if (this.urlFlags[flagName] != null) {
        const flagValue = this.urlFlags[flagName];
        if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
          console.warn(`Setting feature override from URL ${flagName}: ${flagValue}.`);
        }
        this.set(flagName, flagValue);
      }
    }
    /** Async flag read; evaluates and caches the flag on first access. */
    async getAsync(flagName) {
      if (flagName in this.flags) {
        return this.flags[flagName];
      }
      this.flags[flagName] = await this.evaluateFlag(flagName);
      return this.flags[flagName];
    }
    /** Sync flag read; throws if the flag's evaluator returns a Promise. */
    get(flagName) {
      if (flagName in this.flags) {
        return this.flags[flagName];
      }
      const flagValue = this.evaluateFlag(flagName);
      if (isPromise(flagValue)) {
        throw new Error(`Flag ${flagName} cannot be synchronously evaluated. ` +
            `Please use getAsync() instead.`);
      }
      this.flags[flagName] = flagValue;
      return this.flags[flagName];
    }
    // Typed convenience wrappers; all delegate to get().
    getNumber(flagName) {
      return this.get(flagName);
    }
    getBool(flagName) {
      return this.get(flagName);
    }
    getString(flagName) {
      return this.get(flagName);
    }
    getFlags() {
      return this.flags;
    }
    // For backwards compatibility.
    get features() {
      return this.flags;
    }
    /** Sets a registered flag's value and fires its setHook, if any. */
    set(flagName, value) {
      if (this.flagRegistry[flagName] == null) {
        throw new Error(`Cannot set flag ${flagName} as it has not been registered.`);
      }
      this.flags[flagName] = value;
      if (this.flagRegistry[flagName].setHook != null) {
        this.flagRegistry[flagName].setHook(value);
      }
    }
    /** Runs the registered evaluation function (may return a Promise). */
    evaluateFlag(flagName) {
      if (this.flagRegistry[flagName] == null) {
        throw new Error(`Cannot evaluate flag '${flagName}': no evaluation function found.`);
      }
      return this.flagRegistry[flagName].evaluationFn();
    }
    /** Replaces all evaluated flags with a shallow copy of `flags`. */
    setFlags(flags) {
      this.flags = Object.assign({}, flags);
    }
    /** Clears evaluated flags and re-reads overrides from the URL. */
    reset() {
      this.flags = {};
      this.urlFlags = {};
      this.populateURLFlags();
    }
    /** Parses ?tfjsflags=... from global.location.search into urlFlags. */
    populateURLFlags() {
      // No-op outside environments that expose a location (e.g. Node).
      if (typeof this.global === 'undefined' ||
          typeof this.global.location === 'undefined' ||
          typeof this.global.location.search === 'undefined') {
        return;
      }
      const urlParams = this.getQueryParams(this.global.location.search);
      if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) {
        const keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(',');
        keyValues.forEach(keyValue => {
          const [key, value] = keyValue.split(':');
          this.urlFlags[key] = parseValue(key, value);
        });
      }
    }
  }
946 function getQueryParams(queryString) {
947 const params = {};
948 queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, (s, ...t) => {
949 decodeParam(params, t[0], t[1]);
950 return t.join('=');
951 });
952 return params;
953 }
954 function decodeParam(params, name, value) {
955 params[decodeURIComponent(name)] = decodeURIComponent(value || '');
956 }
957 function parseValue(flagName, value) {
958 const lowerCaseValue = value.toLowerCase();
959 if (lowerCaseValue === 'true' || lowerCaseValue === 'false') {
960 return lowerCaseValue === 'true';
961 }
962 else if (`${+lowerCaseValue}` === lowerCaseValue) {
963 return +lowerCaseValue;
964 }
965 else {
966 return value;
967 }
968 }
  /**
   * Returns the current environment (a global singleton).
   *
   * The environment object contains the evaluated feature values as well as the
   * active platform.
   *
   * @doc {heading: 'Environment'}
   */
  function env() {
    return exports.ENV;
  }
  // Starts unset; assigned later via setEnvironmentGlobal().
  exports.ENV = null;
  // Installs `environment` as the singleton returned by env().
  function setEnvironmentGlobal(environment) {
    exports.ENV = environment;
  }
984
985 /**
986 * @license
987 * Copyright 2020 Google LLC. All Rights Reserved.
988 * Licensed under the Apache License, Version 2.0 (the "License");
989 * you may not use this file except in compliance with the License.
990 * You may obtain a copy of the License at
991 *
992 * http://www.apache.org/licenses/LICENSE-2.0
993 *
994 * Unless required by applicable law or agreed to in writing, software
995 * distributed under the License is distributed on an "AS IS" BASIS,
996 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
997 * See the License for the specific language governing permissions and
998 * limitations under the License.
999 * =============================================================================
1000 */
1001 // Note that the identifier globalNameSpace is scoped to this module, but will
1002 // always resolve to the same global object regardless of how the module is
1003 // resolved.
1004 // tslint:disable-next-line:no-any
1005 let globalNameSpace;
1006 // tslint:disable-next-line:no-any
1007 function getGlobalNamespace() {
1008 if (globalNameSpace == null) {
1009 // tslint:disable-next-line:no-any
1010 let ns;
1011 if (typeof (window) !== 'undefined') {
1012 ns = window;
1013 }
1014 else if (typeof (global) !== 'undefined') {
1015 ns = global;
1016 }
1017 else if (typeof (process) !== 'undefined') {
1018 ns = process;
1019 }
1020 else if (typeof (self) !== 'undefined') {
1021 ns = self;
1022 }
1023 else {
1024 throw new Error('Could not find a global object');
1025 }
1026 globalNameSpace = ns;
1027 }
1028 return globalNameSpace;
1029 }
1030 // tslint:disable-next-line:no-any
1031 function getGlobalMap() {
1032 const ns = getGlobalNamespace();
1033 if (ns._tfGlobals == null) {
1034 ns._tfGlobals = new Map();
1035 }
1036 return ns._tfGlobals;
1037 }
/**
 * Returns a globally accessible 'singleton' object.
 *
 * @param key the name of the object
 * @param init a function that builds the object the first time it is
 *     fetched; subsequent fetches for the same key return the cached value.
 */
function getGlobal(key, init) {
    const globalMap = getGlobalMap();
    if (!globalMap.has(key)) {
        globalMap.set(key, init());
    }
    return globalMap.get(key);
}
1056
// Official TF.js kernel name constants. Each string is the key under which a
// kernel (and its gradient) is registered in the global registries. The
// `$1`/`$2`-suffixed identifiers were renamed by the bundler to avoid
// collisions with same-named top-level symbols elsewhere in this bundle.
const Abs = 'Abs';
const Acos = 'Acos';
const Acosh = 'Acosh';
const Add$1 = 'Add';
const AddN = 'AddN';
const All = 'All';
const Any = 'Any';
const ArgMax = 'ArgMax';
const ArgMin = 'ArgMin';
const Asin = 'Asin';
const Asinh = 'Asinh';
const Atan = 'Atan';
const Atanh = 'Atanh';
const Atan2 = 'Atan2';
const AvgPool = 'AvgPool';
const AvgPoolGrad = 'AvgPoolGrad';
const AvgPool3D = 'AvgPool3D';
const AvgPool3DGrad = 'AvgPool3DGrad';
const BatchMatMul = 'BatchMatMul';
const BatchToSpaceND = 'BatchToSpaceND';
const Bincount = 'Bincount';
const BitwiseAnd = 'BitwiseAnd';
const BroadcastTo = 'BroadcastTo';
const BroadcastArgs = 'BroadcastArgs';
const Cast = 'Cast';
const Ceil = 'Ceil';
const ClipByValue = 'ClipByValue';
const Complex = 'Complex';
const ComplexAbs = 'ComplexAbs';
const Concat = 'Concat';
const Conv2D$1 = 'Conv2D';
const Conv2DBackpropFilter = 'Conv2DBackpropFilter';
const Conv2DBackpropInput = 'Conv2DBackpropInput';
const Conv3D$1 = 'Conv3D';
const Conv3DBackpropFilterV2 = 'Conv3DBackpropFilterV2';
const Conv3DBackpropInputV2 = 'Conv3DBackpropInputV2';
const Cos = 'Cos';
const Cosh = 'Cosh';
const Cumprod = 'Cumprod';
const Cumsum = 'Cumsum';
const CropAndResize = 'CropAndResize';
const DenseBincount = 'DenseBincount';
const DepthToSpace = 'DepthToSpace';
const DepthwiseConv2dNative = 'DepthwiseConv2dNative';
const DepthwiseConv2dNativeBackpropFilter = 'DepthwiseConv2dNativeBackpropFilter';
const DepthwiseConv2dNativeBackpropInput = 'DepthwiseConv2dNativeBackpropInput';
const Diag = 'Diag';
const Dilation2D = 'Dilation2D';
const Dilation2DBackpropInput = 'Dilation2DBackpropInput';
const Dilation2DBackpropFilter = 'Dilation2DBackpropFilter';
const Draw = 'Draw';
const RealDiv = 'RealDiv';
const Einsum = 'Einsum';
const Elu$1 = 'Elu';
const EluGrad = 'EluGrad';
const Erf = 'Erf';
const Equal = 'Equal';
const Exp = 'Exp';
const ExpandDims = 'ExpandDims';
const Expm1 = 'Expm1';
const FFT = 'FFT';
const Fill = 'Fill';
const FlipLeftRight = 'FlipLeftRight';
const Floor = 'Floor';
const FloorDiv = 'FloorDiv';
const FusedBatchNorm = 'FusedBatchNorm';
const GatherV2 = 'GatherV2';
const GatherNd = 'GatherNd';
const Greater = 'Greater';
const GreaterEqual = 'GreaterEqual';
const Identity$1 = 'Identity';
const IFFT = 'IFFT';
const Imag = 'Imag';
const IsFinite = 'IsFinite';
const IsInf = 'IsInf';
const IsNan = 'IsNan';
const LeakyRelu = 'LeakyRelu';
const Less = 'Less';
const LessEqual = 'LessEqual';
const LinSpace = 'LinSpace';
const Log = 'Log';
const Log1p = 'Log1p';
const LogicalAnd = 'LogicalAnd';
const LogicalNot = 'LogicalNot';
const LogicalOr = 'LogicalOr';
const LogicalXor = 'LogicalXor';
const LogSoftmax$1 = 'LogSoftmax';
const LowerBound = 'LowerBound';
const LRN = 'LRN';
const LRNGrad = 'LRNGrad';
const MatrixBandPart = 'MatrixBandPart';
const Max = 'Max';
const Maximum$1 = 'Maximum';
const MaxPool = 'MaxPool';
const MaxPoolGrad = 'MaxPoolGrad';
const MaxPool3D = 'MaxPool3D';
const MaxPool3DGrad = 'MaxPool3DGrad';
const MaxPoolWithArgmax = 'MaxPoolWithArgmax';
const Mean = 'Mean';
const Min = 'Min';
const Minimum$1 = 'Minimum';
const MirrorPad = 'MirrorPad';
const Mod = 'Mod';
const Multinomial = 'Multinomial';
const Multiply$1 = 'Multiply';
const Neg = 'Neg';
const NotEqual = 'NotEqual';
const NonMaxSuppressionV3 = 'NonMaxSuppressionV3';
const NonMaxSuppressionV4 = 'NonMaxSuppressionV4';
const NonMaxSuppressionV5 = 'NonMaxSuppressionV5';
const OnesLike = 'OnesLike';
const OneHot = 'OneHot';
const Pack = 'Pack';
const PadV2 = 'PadV2';
const Pool = 'Pool';
const Pow = 'Pow';
const Prelu = 'Prelu';
const Prod = 'Prod';
const RaggedGather = 'RaggedGather';
const RaggedRange = 'RaggedRange';
const RaggedTensorToTensor = 'RaggedTensorToTensor';
const Range = 'Range';
const Real = 'Real';
const Reciprocal = 'Reciprocal';
const Relu$1 = 'Relu';
const Reshape$1 = 'Reshape';
const ResizeNearestNeighbor = 'ResizeNearestNeighbor';
const ResizeNearestNeighborGrad = 'ResizeNearestNeighborGrad';
const ResizeBilinear = 'ResizeBilinear';
const ResizeBilinearGrad = 'ResizeBilinearGrad';
const Relu6$1 = 'Relu6';
const Reverse = 'Reverse';
const Round = 'Round';
const Rsqrt = 'Rsqrt';
const ScatterNd = 'ScatterNd';
const TensorScatterUpdate = 'TensorScatterUpdate';
const SearchSorted = 'SearchSorted';
const Select = 'Select';
const Selu$1 = 'Selu';
const Slice = 'Slice';
const Sin = 'Sin';
const Sinh = 'Sinh';
const Sign = 'Sign';
const Sigmoid$1 = 'Sigmoid';
const Softplus$1 = 'Softplus';
const Sqrt = 'Sqrt';
const Sum = 'Sum';
const SpaceToBatchND = 'SpaceToBatchND';
const SplitV = 'SplitV';
const Softmax$2 = 'Softmax';
const SparseFillEmptyRows = 'SparseFillEmptyRows';
const SparseReshape = 'SparseReshape';
const SparseSegmentMean = 'SparseSegmentMean';
const SparseSegmentSum = 'SparseSegmentSum';
const SparseToDense = 'SparseToDense';
const SquaredDifference = 'SquaredDifference';
const Square = 'Square';
const StaticRegexReplace = 'StaticRegexReplace';
const StridedSlice = 'StridedSlice';
const StringNGrams = 'StringNGrams';
const StringSplit = 'StringSplit';
const StringToHashBucketFast = 'StringToHashBucketFast';
const Sub = 'Sub';
const Tan = 'Tan';
const Tanh$1 = 'Tanh';
const Tile = 'Tile';
const TopK = 'TopK';
const Transform = 'Transform';
const Transpose = 'Transpose';
const Unique = 'Unique';
const Unpack = 'Unpack';
const UnsortedSegmentSum = 'UnsortedSegmentSum';
const UpperBound = 'UpperBound';
const ZerosLike = 'ZerosLike';
/**
 * TensorFlow.js-only kernels (no corresponding TF kernel exists).
 */
const Step = 'Step';
const FromPixels = 'FromPixels';
const RotateWithOffset = 'RotateWithOffset';
const _FusedMatMul = '_FusedMatMul';
const FusedConv2D = 'FusedConv2D';
const FusedDepthwiseConv2D = 'FusedDepthwiseConv2D';
1240
1241 /**
1242 * @license
1243 * Copyright 2018 Google LLC. All Rights Reserved.
1244 * Licensed under the Apache License, Version 2.0 (the "License");
1245 * you may not use this file except in compliance with the License.
1246 * You may obtain a copy of the License at
1247 *
1248 * http://www.apache.org/licenses/LICENSE-2.0
1249 *
1250 * Unless required by applicable law or agreed to in writing, software
1251 * distributed under the License is distributed on an "AS IS" BASIS,
1252 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1253 * See the License for the specific language governing permissions and
1254 * limitations under the License.
1255 * =============================================================================
1256 */
// Logs a warning unless logging is suppressed (test or production mode).
function warn(...msg) {
    if (env().getBool('IS_TEST')) {
        return;
    }
    if (env().getBool('PROD')) {
        return;
    }
    console.warn(...msg);
}
// Logs an informational message unless logging is suppressed (test or
// production mode).
function log$3(...msg) {
    if (env().getBool('IS_TEST')) {
        return;
    }
    if (env().getBool('PROD')) {
        return;
    }
    console.log(...msg);
}
1267
1268 /**
1269 * @license
1270 * Copyright 2019 Google LLC. All Rights Reserved.
1271 * Licensed under the Apache License, Version 2.0 (the "License");
1272 * you may not use this file except in compliance with the License.
1273 * You may obtain a copy of the License at
1274 *
1275 * http://www.apache.org/licenses/LICENSE-2.0
1276 *
1277 * Unless required by applicable law or agreed to in writing, software
1278 * distributed under the License is distributed on an "AS IS" BASIS,
1279 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1280 * See the License for the specific language governing permissions and
1281 * limitations under the License.
1282 * =============================================================================
1283 */
// Global singleton registries (shared across any duplicate copies of
// tfjs-core in the page): kernelRegistry maps `${backendName}_${kernelName}`
// keys to kernel configs; gradRegistry maps kernel names to gradient configs.
const kernelRegistry = getGlobal('kernelRegistry', () => new Map());
const gradRegistry = getGlobal('gradRegistry', () => new Map());
/**
 * Returns the kernel function (code) associated with the provided names, or
 * undefined when no such kernel is registered.
 *
 * @param kernelName The official name of the kernel.
 * @param backendName The official name of the backend.
 */
function getKernel(kernelName, backendName) {
    return kernelRegistry.get(makeKey(kernelName, backendName));
}
/**
 * Returns the registered gradient info associated with the provided kernel.
 * Returns undefined when no gradient has been registered for the kernel.
 * @param kernelName The official TF kernel name.
 */
function getGradient(kernelName) {
    return gradRegistry.get(kernelName);
}
/**
 * Returns all kernel configs registered under the given backend.
 *
 * Registry keys have the form `${backendName}_${kernelName}` (see makeKey),
 * so the text before the first underscore identifies the backend.
 *
 * @param backendName The official name of the backend.
 * @returns Array of kernel configs registered for that backend.
 */
function getKernelsForBackend(backendName) {
    const result = [];
    // Idiomatic for...of over the Map replaces the original manual
    // iterator-protocol loop (`while (true) { it.next() ... }`); iteration
    // order and the collected configs are unchanged.
    for (const [key, config] of kernelRegistry) {
        const [backend,] = key.split('_');
        if (backend === backendName) {
            result.push(config);
        }
    }
    return result;
}
/**
 * Registers the function (forward pass) for the kernel in a global registry.
 *
 * @param config A config object with the following properties:
 * - `kernelName` The official name of the kernel.
 * - `backendName` The official name of the backend.
 * - `kernelFunc` The function to run during the forward pass of the kernel.
 * - `setupFunc` Optional. Gets called once, after the backend initializes.
 * - `disposeFunc` Optional. Gets called once, right before the backend is
 *   disposed.
 */
function registerKernel(config) {
    const { kernelName, backendName } = config;
    const key = makeKey(kernelName, backendName);
    // Re-registration is allowed (the new config wins) but warned about.
    if (kernelRegistry.has(key)) {
        warn(`The kernel '${kernelName}' for backend '${backendName}' is already registered`);
    }
    kernelRegistry.set(key, config);
}
/**
 * Registers a gradient function for a given kernel in the global registry,
 * to be used during the back-propagation of that kernel.
 *
 * @param config An object with the following properties:
 * - `kernelName` The name of the kernel that the gradient function is for.
 * - `gradFunc` The function to run during back-propagation.
 */
function registerGradient(config) {
    const { kernelName } = config;
    // TODO (yassogba) after 3.0 assess whether we need to keep this gated
    // to debug mode.
    if (gradRegistry.has(kernelName) && env().getBool('DEBUG')) {
        warn(`Overriding the gradient for '${kernelName}'`);
    }
    gradRegistry.set(kernelName, config);
}
/**
 * Removes the kernel function from the registry.
 *
 * @param kernelName The official name of the kernel.
 * @param backendName The official name of the backend.
 * @throws Error when no such kernel is registered.
 */
function unregisterKernel(kernelName, backendName) {
    const key = makeKey(kernelName, backendName);
    if (!kernelRegistry.has(key)) {
        throw new Error(`The kernel '${kernelName}' for backend '${backendName}' is not registered`);
    }
    kernelRegistry.delete(key);
}
/**
 * Removes the registered gradient from the global registry.
 * @throws Error when no gradient is registered under `kernelName`.
 */
function unregisterGradient(kernelName) {
    // Map.delete returns false when the key was absent.
    const removed = gradRegistry.delete(kernelName);
    if (!removed) {
        throw new Error(`The gradient '${kernelName}' for backend is not registered`);
    }
}
/**
 * Finds kernels that have already been registered to a backend and
 * re-registers them for a new backend. Useful for registering custom
 * backends.
 * @param registeredBackendName Already registered backend.
 * @param newBackendName New backend.
 */
function copyRegisteredKernels(registeredBackendName, newBackendName) {
    for (const kernelConfig of getKernelsForBackend(registeredBackendName)) {
        // Shallow-copy the config, retargeting only the backend name.
        const retargeted = Object.assign({}, kernelConfig, { backendName: newBackendName });
        registerKernel(retargeted);
    }
}
// Builds the registry key for a (kernel, backend) pair: backend name first,
// separated from the kernel name by an underscore.
function makeKey(kernelName, backendName) {
    return [backendName, kernelName].join('_');
}
1396
1397 /**
1398 * @license
1399 * Copyright 2023 Google LLC.
1400 * Licensed under the Apache License, Version 2.0 (the "License");
1401 * you may not use this file except in compliance with the License.
1402 * You may obtain a copy of the License at
1403 *
1404 * http://www.apache.org/licenses/LICENSE-2.0
1405 *
1406 * Unless required by applicable law or agreed to in writing, software
1407 * distributed under the License is distributed on an "AS IS" BASIS,
1408 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1409 * See the License for the specific language governing permissions and
1410 * limitations under the License.
1411 * =============================================================================
1412 */
// Returns true iff `a` is one of the typed-array types tfjs-core accepts as
// raw tensor data in the browser (Float32/Int32/Uint8/Uint8Clamped).
function isTypedArrayBrowser(a) {
    const accepted = [Float32Array, Int32Array, Uint8Array, Uint8ClampedArray];
    return accepted.some((ctor) => a instanceof ctor);
}
1417
// Rollup CommonJS interop: best-effort reference to the host global object
// (globalThis, window, global, or self — first one defined wins).
var commonjsGlobal = typeof globalThis !== 'undefined' ? globalThis : typeof window !== 'undefined' ? window : typeof global !== 'undefined' ? global : typeof self !== 'undefined' ? self : {};
1419
// Rollup CommonJS interop: unwraps the `default` export from a transpiled
// ES module namespace; returns the value unchanged otherwise.
function getDefaultExportFromCjs (x) {
    const hasEsDefault = x && x.__esModule && Object.prototype.hasOwnProperty.call(x, 'default');
    return hasEsDefault ? x['default'] : x;
}
1423
// Rollup CommonJS interop: returns the namespace's own `default` export when
// one exists, otherwise the namespace itself.
function getDefaultExportFromNamespaceIfPresent (n) {
    if (n && Object.prototype.hasOwnProperty.call(n, 'default')) {
        return n['default'];
    }
    return n;
}
1427
// Rollup CommonJS interop: unwraps `default` only when it is the sole
// (enumerable own) export of the namespace.
function getDefaultExportFromNamespaceIfNotNamed (n) {
    if (n && Object.prototype.hasOwnProperty.call(n, 'default') && Object.keys(n).length === 1) {
        return n['default'];
    }
    return n;
}
1431
// Rollup CommonJS interop: wraps a CJS module object `n` so it behaves like
// an ES module namespace. If `n.default` is a function, the wrapper is a
// callable/constructable proxy for it; otherwise the wrapper is a plain
// object. All keys of `n` are mirrored onto the wrapper via live getters.
function getAugmentedNamespace(n) {
	// Already an ES module namespace — nothing to do.
	if (n.__esModule) return n;
	var f = n.default;
	if (typeof f == "function") {
		// Callable wrapper: forwards plain calls to f, and `new` calls to f as
		// a constructor (via Function.bind with a null-padded argument list).
		var a = function a () {
			if (this instanceof a) {
				var args = [null];
				args.push.apply(args, arguments);
				var Ctor = Function.bind.apply(f, args);
				return new Ctor();
			}
			return f.apply(this, arguments);
		};
		a.prototype = f.prototype;
	} else a = {};
	// NOTE: `a` in the else branch relies on the hoisted `var a` declaration
	// from the if branch — do not convert to let/const.
	Object.defineProperty(a, '__esModule', {value: true});
	// Mirror every export; preserve existing getters, otherwise create a
	// live getter so later mutations of `n` remain visible.
	Object.keys(n).forEach(function (k) {
		var d = Object.getOwnPropertyDescriptor(n, k);
		Object.defineProperty(a, k, d.get ? d : {
			enumerable: true,
			get: function () {
				return n[k];
			}
		});
	});
	return a;
}
1459
// Bundled long.js library entry point (Long$1 is the hoisted constructor
// defined below).
var long = Long$1;

/**
 * wasm optimizations, to do native i64 multiplication and divide
 */
// Precompiled WebAssembly module (byte literal) exporting mul/div_s/div_u/
// rem_s/rem_u/get_high; left null when the host has no wasm support so the
// pure-JS fallbacks are used instead.
var wasm = null;

try {
  wasm = new WebAssembly.Instance(new WebAssembly.Module(new Uint8Array([
    0, 97, 115, 109, 1, 0, 0, 0, 1, 13, 2, 96, 0, 1, 127, 96, 4, 127, 127, 127, 127, 1, 127, 3, 7, 6, 0, 1, 1, 1, 1, 1, 6, 6, 1, 127, 1, 65, 0, 11, 7, 50, 6, 3, 109, 117, 108, 0, 1, 5, 100, 105, 118, 95, 115, 0, 2, 5, 100, 105, 118, 95, 117, 0, 3, 5, 114, 101, 109, 95, 115, 0, 4, 5, 114, 101, 109, 95, 117, 0, 5, 8, 103, 101, 116, 95, 104, 105, 103, 104, 0, 0, 10, 191, 1, 6, 4, 0, 35, 0, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 126, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 127, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 128, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 129, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 130, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11
  ])), {}).exports;
} catch (e) {
  // no wasm support :( — wasm stays null and JS fallbacks are used.
}
1474
/**
 * Constructs a 64 bit two's-complement integer, given its low and high 32 bit
 * values as *signed* integers. See the from* functions below for more
 * convenient ways of constructing Longs.
 * @exports Long
 * @class A Long class for representing a 64 bit two's-complement integer value.
 * @param {number} low The low (signed) 32 bits of the long
 * @param {number} high The high (signed) 32 bits of the long
 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
 * @constructor
 */
function Long$1(low, high, unsigned) {
    // `| 0` truncates each half to a signed 32-bit integer; the pair is the
    // canonical internal representation regardless of signedness.
    /** @type {number} The low 32 bits as a signed value. */
    this.low = low | 0;
    /** @type {number} The high 32 bits as a signed value. */
    this.high = high | 0;
    /** @type {boolean} Whether this Long is interpreted as unsigned. */
    this.unsigned = Boolean(unsigned);
}
1505
// The internal representation of a long is the two given signed, 32-bit values.
// We use 32-bit pieces because these are the size of integers on which
// Javascript performs bit-operations. For operations like addition and
// multiplication, we split each number into 16 bit pieces, which can easily be
// multiplied within Javascript's floating-point representation without overflow
// or change in sign.
//
// In the algorithms below, we frequently reduce the negative case to the
// positive case by negating the input(s) and then post-processing the result.
// Note that we must ALWAYS check specially whether those values are MIN_VALUE
// (-2^63) because -MIN_VALUE == MIN_VALUE (since 2^63 cannot be represented as
// a positive number, it overflows back into a negative). Not handling this
// case would often result in infinite recursion.
//
// Common constant values ZERO, ONE, NEG_ONE, etc. are defined below the from*
// methods on which they depend.

/**
 * An indicator used to reliably determine if an object is a Long or not.
 * Declared here for documentation only; the real (non-enumerable,
 * non-writable) property is installed by the defineProperty call below.
 * @type {boolean}
 * @const
 * @private
 */
Long$1.prototype.__isLong__;

Object.defineProperty(Long$1.prototype, "__isLong__", { value: true });
1532
/**
 * Tests whether `obj` carries the `__isLong__` brand installed on
 * Long.prototype — the library's duck-typed instanceof replacement.
 * @function
 * @param {*} obj Object
 * @returns {boolean}
 * @inner
 */
function isLong(obj) {
    return obj != null && obj["__isLong__"] === true;
}
1542
/**
 * Tests if the specified object is a Long.
 * @function
 * @param {*} obj Object
 * @returns {boolean}
 */
Long$1.isLong = isLong;
1550
/**
 * A cache of the Long representations of small integer values
 * (range [-128, 128), populated lazily by fromInt).
 * @type {!Object}
 * @inner
 */
var INT_CACHE = {};

/**
 * A cache of the Long representations of small unsigned integer values
 * (range [0, 256), populated lazily by fromInt).
 * @type {!Object}
 * @inner
 */
var UINT_CACHE = {};
1564
/**
 * Builds a Long from a 32 bit integer, consulting the small-value caches
 * ([-128, 128) signed, [0, 256) unsigned) before allocating.
 * @param {number} value
 * @param {boolean=} unsigned
 * @returns {!Long}
 * @inner
 */
function fromInt(value, unsigned) {
    if (unsigned) {
        value >>>= 0;
        var cacheable = 0 <= value && value < 256;
        if (cacheable) {
            var cached = UINT_CACHE[value];
            if (cached)
                return cached;
        }
        // High word is all ones when bit 31 is set (value >= 2^31).
        var obj = fromBits(value, (value | 0) < 0 ? -1 : 0, true);
        if (cacheable)
            UINT_CACHE[value] = obj;
        return obj;
    }
    value |= 0;
    var cacheable = -128 <= value && value < 128;
    if (cacheable) {
        var cached = INT_CACHE[value];
        if (cached)
            return cached;
    }
    // Sign-extend into the high word.
    var obj = fromBits(value, value < 0 ? -1 : 0, false);
    if (cacheable)
        INT_CACHE[value] = obj;
    return obj;
}

/**
 * Returns a Long representing the given 32 bit integer value.
 * @function
 * @param {number} value The 32 bit integer in question
 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
 * @returns {!Long} The corresponding Long value
 */
Long$1.fromInt = fromInt;
1606
/**
 * Builds a Long from an arbitrary JS number: NaN maps to zero, out-of-range
 * values clamp to MIN/MAX, negatives recurse through negation, and in-range
 * values are split into 32-bit halves.
 * @param {number} value
 * @param {boolean=} unsigned
 * @returns {!Long}
 * @inner
 */
function fromNumber(value, unsigned) {
    if (isNaN(value))
        return unsigned ? UZERO : ZERO;
    // Clamp to the representable range for the requested signedness.
    if (unsigned) {
        if (value < 0)
            return UZERO;
        if (value >= TWO_PWR_64_DBL)
            return MAX_UNSIGNED_VALUE;
    }
    else {
        if (value <= -TWO_PWR_63_DBL)
            return MIN_VALUE;
        if (value + 1 >= TWO_PWR_63_DBL)
            return MAX_VALUE;
    }
    // Reduce the negative case to the positive one.
    if (value < 0)
        return fromNumber(-value, unsigned).neg();
    var low = (value % TWO_PWR_32_DBL) | 0;
    var high = (value / TWO_PWR_32_DBL) | 0;
    return fromBits(low, high, unsigned);
}

/**
 * Returns a Long representing the given value, provided that it is a finite number. Otherwise, zero is returned.
 * @function
 * @param {number} value The number in question
 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
 * @returns {!Long} The corresponding Long value
 */
Long$1.fromNumber = fromNumber;
1640
/**
 * Thin constructor wrapper: no caching, no normalization beyond what the
 * Long constructor itself applies (`| 0` truncation of each half).
 * @param {number} lowBits
 * @param {number} highBits
 * @param {boolean=} unsigned
 * @returns {!Long}
 * @inner
 */
function fromBits(lowBits, highBits, unsigned) {
    return new Long$1(lowBits, highBits, unsigned);
}

/**
 * Returns a Long representing the 64 bit integer that comes by concatenating the given low and high bits. Each is
 * assumed to use 32 bits.
 * @function
 * @param {number} lowBits The low 32 bits
 * @param {number} highBits The high 32 bits
 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
 * @returns {!Long} The corresponding Long value
 */
Long$1.fromBits = fromBits;
1662
/**
 * Local alias for Math.pow (kept to minimize minified size).
 * @function
 * @param {number} base
 * @param {number} exponent
 * @returns {number}
 * @inner
 */
var pow_dbl = Math.pow; // Used 4 times (4*8 to 15+4)
1671
/**
 * Parses a textual Long in the given radix. Accepts an optional leading '-'
 * (negatives recurse through negation); "NaN"/"±Infinity" parse as zero.
 * @param {string} str
 * @param {(boolean|number)=} unsigned
 * @param {number=} radix
 * @returns {!Long}
 * @inner
 */
function fromString(str, unsigned, radix) {
    if (str.length === 0)
        throw Error('empty string');
    if (str === "NaN" || str === "Infinity" || str === "+Infinity" || str === "-Infinity")
        return ZERO;
    if (typeof unsigned === 'number') {
        // goog.math.Long compatibility: (str, radix) call shape.
        radix = unsigned;
        unsigned = false;
    }
    else {
        unsigned = Boolean(unsigned);
    }
    radix = radix || 10;
    if (radix < 2 || 36 < radix)
        throw RangeError('radix');

    var hyphen = str.indexOf('-');
    if (hyphen > 0)
        throw Error('interior hyphen');
    if (hyphen === 0)
        return fromString(str.substring(1), unsigned, radix).neg();

    // Do several (8) digits each time through the loop, so as to
    // minimize the calls to the very expensive emulated div.
    var radixToPower = fromNumber(pow_dbl(radix, 8));

    var result = ZERO;
    for (var i = 0; i < str.length; i += 8) {
        var size = Math.min(8, str.length - i);
        var value = parseInt(str.substring(i, i + size), radix);
        if (size < 8) {
            // Final partial chunk: scale by radix^size instead.
            var power = fromNumber(pow_dbl(radix, size));
            result = result.mul(power).add(fromNumber(value));
        }
        else {
            result = result.mul(radixToPower).add(fromNumber(value));
        }
    }
    result.unsigned = unsigned;
    return result;
}

/**
 * Returns a Long representation of the given string, written using the specified radix.
 * @function
 * @param {string} str The textual representation of the Long
 * @param {(boolean|number)=} unsigned Whether unsigned or not, defaults to signed
 * @param {number=} radix The radix in which the text is written (2-36), defaults to 10
 * @returns {!Long} The corresponding Long value
 */
Long$1.fromString = fromString;
1731
/**
 * Dispatches to fromNumber/fromString/fromBits based on the runtime type of
 * `val`; Long-like objects are rebuilt from their low/high/unsigned fields.
 * @param {!Long|number|string|!{low: number, high: number, unsigned: boolean}} val
 * @param {boolean=} unsigned
 * @returns {!Long}
 * @inner
 */
function fromValue(val, unsigned) {
    if (typeof val === 'number')
        return fromNumber(val, unsigned);
    if (typeof val === 'string')
        return fromString(val, unsigned);
    // Throws for non-objects, converts non-instanceof Long:
    var useUnsigned = typeof unsigned === 'boolean' ? unsigned : val.unsigned;
    return fromBits(val.low, val.high, useUnsigned);
}

/**
 * Converts the specified value to a Long using the appropriate from* function for its type.
 * @function
 * @param {!Long|number|string|!{low: number, high: number, unsigned: boolean}} val Value
 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
 * @returns {!Long}
 */
Long$1.fromValue = fromValue;
1756
// NOTE: the compiler should inline these constant values below and then remove
// these variables, so there should be no runtime penalty for these.

/** @type {number} 2^16 @const @inner */
var TWO_PWR_16_DBL = 1 << 16;

/** @type {number} 2^24 @const @inner */
var TWO_PWR_24_DBL = 1 << 24;

/** @type {number} 2^32 @const @inner */
var TWO_PWR_32_DBL = TWO_PWR_16_DBL * TWO_PWR_16_DBL;

/** @type {number} 2^64 @const @inner */
var TWO_PWR_64_DBL = TWO_PWR_32_DBL * TWO_PWR_32_DBL;

/** @type {number} 2^63 @const @inner */
var TWO_PWR_63_DBL = TWO_PWR_64_DBL / 2;

/**
 * 2^24 as a Long (used by float conversions further down the library).
 * @type {!Long}
 * @const
 * @inner
 */
var TWO_PWR_24 = fromInt(TWO_PWR_24_DBL);

/**
 * @type {!Long}
 * @inner
 */
var ZERO = fromInt(0);

/**
 * Signed zero.
 * @type {!Long}
 */
Long$1.ZERO = ZERO;

/**
 * @type {!Long}
 * @inner
 */
var UZERO = fromInt(0, true);

/**
 * Unsigned zero.
 * @type {!Long}
 */
Long$1.UZERO = UZERO;

/**
 * @type {!Long}
 * @inner
 */
var ONE = fromInt(1);

/**
 * Signed one.
 * @type {!Long}
 */
Long$1.ONE = ONE;

/**
 * @type {!Long}
 * @inner
 */
var UONE = fromInt(1, true);

/**
 * Unsigned one.
 * @type {!Long}
 */
Long$1.UONE = UONE;

/**
 * @type {!Long}
 * @inner
 */
var NEG_ONE = fromInt(-1);

/**
 * Signed negative one.
 * @type {!Long}
 */
Long$1.NEG_ONE = NEG_ONE;

/**
 * 2^63 - 1 (low word all ones, high word 0x7FFFFFFF).
 * @type {!Long}
 * @inner
 */
var MAX_VALUE = fromBits(0xFFFFFFFF|0, 0x7FFFFFFF|0, false);

/**
 * Maximum signed value.
 * @type {!Long}
 */
Long$1.MAX_VALUE = MAX_VALUE;

/**
 * 2^64 - 1 (all 64 bits set, unsigned).
 * @type {!Long}
 * @inner
 */
var MAX_UNSIGNED_VALUE = fromBits(0xFFFFFFFF|0, 0xFFFFFFFF|0, true);

/**
 * Maximum unsigned value.
 * @type {!Long}
 */
Long$1.MAX_UNSIGNED_VALUE = MAX_UNSIGNED_VALUE;

/**
 * -2^63 (only the sign bit of the high word set).
 * @type {!Long}
 * @inner
 */
var MIN_VALUE = fromBits(0, 0x80000000|0, false);

/**
 * Minimum signed value.
 * @type {!Long}
 */
Long$1.MIN_VALUE = MIN_VALUE;

/**
 * Shorthand for the prototype the instance methods below are attached to.
 * @alias Long.prototype
 * @inner
 */
var LongPrototype = Long$1.prototype;
1903
/**
 * Converts the Long to a 32 bit integer, assuming it is a 32 bit integer.
 * Unsigned Longs reinterpret the low word as an unsigned 32-bit value.
 * @returns {number}
 */
LongPrototype.toInt = function toInt() {
    if (this.unsigned)
        return this.low >>> 0;
    return this.low;
};
1911
/**
 * Converts the Long to the nearest floating-point representation of this
 * value (double, 53 bit mantissa); values beyond 2^53 lose precision.
 * @returns {number}
 */
LongPrototype.toNumber = function toNumber() {
    var lowUnsigned = this.low >>> 0;
    if (this.unsigned)
        return (this.high >>> 0) * TWO_PWR_32_DBL + lowUnsigned;
    return this.high * TWO_PWR_32_DBL + lowUnsigned;
};
1921
/**
 * Converts the Long to a string written in the specified radix.
 * @param {number=} radix Radix (2-36), defaults to 10
 * @returns {string}
 * @override
 * @throws {RangeError} If `radix` is out of range
 */
LongPrototype.toString = function toString(radix) {
    radix = radix || 10;
    if (radix < 2 || 36 < radix)
        throw RangeError('radix');
    if (this.isZero())
        return '0';
    if (this.isNegative()) { // Unsigned Longs are never negative
        if (this.eq(MIN_VALUE)) {
            // We need to change the Long value before it can be negated, so we remove
            // the bottom-most digit in this base and then recurse to do the rest.
            // (MIN_VALUE cannot be negated directly: -MIN_VALUE == MIN_VALUE.)
            var radixLong = fromNumber(radix),
                div = this.div(radixLong),
                rem1 = div.mul(radixLong).sub(this);
            return div.toString(radix) + rem1.toInt().toString(radix);
        } else
            return '-' + this.neg().toString(radix);
    }

    // Do several (6) digits each time through the loop, so as to
    // minimize the calls to the very expensive emulated div.
    var radixToPower = fromNumber(pow_dbl(radix, 6), this.unsigned),
        rem = this;
    var result = '';
    while (true) {
        var remDiv = rem.div(radixToPower),
            // Chunk of up to 6 digits: remainder of dividing by radix^6.
            intval = rem.sub(remDiv.mul(radixToPower)).toInt() >>> 0,
            digits = intval.toString(radix);
        rem = remDiv;
        if (rem.isZero())
            return digits + result;
        else {
            // Interior chunks must be zero-padded to exactly 6 digits.
            while (digits.length < 6)
                digits = '0' + digits;
            result = '' + digits + result;
        }
    }
};
1966
/**
 * Gets the high 32 bits as a signed integer.
 * @returns {number} Signed high bits
 */
LongPrototype.getHighBits = function getHighBits() {
    return this.high;
};

/**
 * Gets the high 32 bits as an unsigned integer.
 * @returns {number} Unsigned high bits (the `>>> 0` reinterprets the word)
 */
LongPrototype.getHighBitsUnsigned = function getHighBitsUnsigned() {
    return this.high >>> 0;
};

/**
 * Gets the low 32 bits as a signed integer.
 * @returns {number} Signed low bits
 */
LongPrototype.getLowBits = function getLowBits() {
    return this.low;
};

/**
 * Gets the low 32 bits as an unsigned integer.
 * @returns {number} Unsigned low bits (the `>>> 0` reinterprets the word)
 */
LongPrototype.getLowBitsUnsigned = function getLowBitsUnsigned() {
    return this.low >>> 0;
};
1998
/**
 * Gets the number of bits needed to represent the absolute value of this
 * Long (negatives recurse through negation; MIN_VALUE is 64 by definition).
 * @returns {number}
 */
LongPrototype.getNumBitsAbs = function getNumBitsAbs() {
    if (this.isNegative()) // Unsigned Longs are never negative
        return this.eq(MIN_VALUE) ? 64 : this.neg().getNumBitsAbs();
    // Scan the most significant non-zero word for its highest set bit.
    var val = this.high !== 0 ? this.high : this.low;
    var bit = 31;
    while (bit > 0 && (val & (1 << bit)) === 0)
        bit--;
    return this.high !== 0 ? bit + 33 : bit + 1;
};
2012
/**
 * Tests if this Long's value equals zero (both 32-bit words are zero).
 * @returns {boolean}
 */
LongPrototype.isZero = function isZero() {
    return this.low === 0 && this.high === 0;
};

/**
 * Tests if this Long's value equals zero. This is an alias of {@link Long#isZero}.
 * @returns {boolean}
 */
LongPrototype.eqz = LongPrototype.isZero;
2026
2027 /**
2028 * Tests if this Long's value is negative.
2029 * @returns {boolean}
2030 */
2031 LongPrototype.isNegative = function isNegative() {
2032 return !this.unsigned && this.high < 0;
2033 };
2034
2035 /**
2036 * Tests if this Long's value is positive.
2037 * @returns {boolean}
2038 */
2039 LongPrototype.isPositive = function isPositive() {
2040 return this.unsigned || this.high >= 0;
2041 };
2042
2043 /**
2044 * Tests if this Long's value is odd.
2045 * @returns {boolean}
2046 */
2047 LongPrototype.isOdd = function isOdd() {
2048 return (this.low & 1) === 1;
2049 };
2050
2051 /**
2052 * Tests if this Long's value is even.
2053 * @returns {boolean}
2054 */
2055 LongPrototype.isEven = function isEven() {
2056 return (this.low & 1) === 0;
2057 };
2058
    /**
     * Tests if this Long's value equals the specified's.
     * @param {!Long|number|string} other Other value
     * @returns {boolean}
     */
    LongPrototype.equals = function equals(other) {
        if (!isLong(other))
            other = fromValue(other);
        // If signedness differs and bit 63 is set on both, the values cannot be
        // equal: one represents a negative number, the other a value >= 2^63.
        if (this.unsigned !== other.unsigned && (this.high >>> 31) === 1 && (other.high >>> 31) === 1)
            return false;
        return this.high === other.high && this.low === other.low;
    };

    /**
     * Tests if this Long's value equals the specified's. This is an alias of {@link Long#equals}.
     * @function
     * @param {!Long|number|string} other Other value
     * @returns {boolean}
     */
    LongPrototype.eq = LongPrototype.equals;
2079
2080 /**
2081 * Tests if this Long's value differs from the specified's.
2082 * @param {!Long|number|string} other Other value
2083 * @returns {boolean}
2084 */
2085 LongPrototype.notEquals = function notEquals(other) {
2086 return !this.eq(/* validates */ other);
2087 };
2088
2089 /**
2090 * Tests if this Long's value differs from the specified's. This is an alias of {@link Long#notEquals}.
2091 * @function
2092 * @param {!Long|number|string} other Other value
2093 * @returns {boolean}
2094 */
2095 LongPrototype.neq = LongPrototype.notEquals;
2096
2097 /**
2098 * Tests if this Long's value differs from the specified's. This is an alias of {@link Long#notEquals}.
2099 * @function
2100 * @param {!Long|number|string} other Other value
2101 * @returns {boolean}
2102 */
2103 LongPrototype.ne = LongPrototype.notEquals;
2104
2105 /**
2106 * Tests if this Long's value is less than the specified's.
2107 * @param {!Long|number|string} other Other value
2108 * @returns {boolean}
2109 */
2110 LongPrototype.lessThan = function lessThan(other) {
2111 return this.comp(/* validates */ other) < 0;
2112 };
2113
2114 /**
2115 * Tests if this Long's value is less than the specified's. This is an alias of {@link Long#lessThan}.
2116 * @function
2117 * @param {!Long|number|string} other Other value
2118 * @returns {boolean}
2119 */
2120 LongPrototype.lt = LongPrototype.lessThan;
2121
2122 /**
2123 * Tests if this Long's value is less than or equal the specified's.
2124 * @param {!Long|number|string} other Other value
2125 * @returns {boolean}
2126 */
2127 LongPrototype.lessThanOrEqual = function lessThanOrEqual(other) {
2128 return this.comp(/* validates */ other) <= 0;
2129 };
2130
2131 /**
2132 * Tests if this Long's value is less than or equal the specified's. This is an alias of {@link Long#lessThanOrEqual}.
2133 * @function
2134 * @param {!Long|number|string} other Other value
2135 * @returns {boolean}
2136 */
2137 LongPrototype.lte = LongPrototype.lessThanOrEqual;
2138
2139 /**
2140 * Tests if this Long's value is less than or equal the specified's. This is an alias of {@link Long#lessThanOrEqual}.
2141 * @function
2142 * @param {!Long|number|string} other Other value
2143 * @returns {boolean}
2144 */
2145 LongPrototype.le = LongPrototype.lessThanOrEqual;
2146
2147 /**
2148 * Tests if this Long's value is greater than the specified's.
2149 * @param {!Long|number|string} other Other value
2150 * @returns {boolean}
2151 */
2152 LongPrototype.greaterThan = function greaterThan(other) {
2153 return this.comp(/* validates */ other) > 0;
2154 };
2155
2156 /**
2157 * Tests if this Long's value is greater than the specified's. This is an alias of {@link Long#greaterThan}.
2158 * @function
2159 * @param {!Long|number|string} other Other value
2160 * @returns {boolean}
2161 */
2162 LongPrototype.gt = LongPrototype.greaterThan;
2163
2164 /**
2165 * Tests if this Long's value is greater than or equal the specified's.
2166 * @param {!Long|number|string} other Other value
2167 * @returns {boolean}
2168 */
2169 LongPrototype.greaterThanOrEqual = function greaterThanOrEqual(other) {
2170 return this.comp(/* validates */ other) >= 0;
2171 };
2172
2173 /**
2174 * Tests if this Long's value is greater than or equal the specified's. This is an alias of {@link Long#greaterThanOrEqual}.
2175 * @function
2176 * @param {!Long|number|string} other Other value
2177 * @returns {boolean}
2178 */
2179 LongPrototype.gte = LongPrototype.greaterThanOrEqual;
2180
2181 /**
2182 * Tests if this Long's value is greater than or equal the specified's. This is an alias of {@link Long#greaterThanOrEqual}.
2183 * @function
2184 * @param {!Long|number|string} other Other value
2185 * @returns {boolean}
2186 */
2187 LongPrototype.ge = LongPrototype.greaterThanOrEqual;
2188
    /**
     * Compares this Long's value with the specified's.
     * @param {!Long|number|string} other Other value
     * @returns {number} 0 if they are the same, 1 if the this is greater and -1
     *  if the given one is greater
     */
    LongPrototype.compare = function compare(other) {
        if (!isLong(other))
            other = fromValue(other);
        if (this.eq(other))
            return 0;
        var thisNeg = this.isNegative(),
            otherNeg = other.isNegative();
        // A negative value is always smaller than a non-negative one.
        if (thisNeg && !otherNeg)
            return -1;
        if (!thisNeg && otherNeg)
            return 1;
        // At this point the sign bits are the same
        if (!this.unsigned)
            return this.sub(other).isNegative() ? -1 : 1;
        // Both are positive if at least one is unsigned
        // Compare as unsigned 32-bit words: high word first, then low word.
        return (other.high >>> 0) > (this.high >>> 0) || (other.high === this.high && (other.low >>> 0) > (this.low >>> 0)) ? -1 : 1;
    };

    /**
     * Compares this Long's value with the specified's. This is an alias of {@link Long#compare}.
     * @function
     * @param {!Long|number|string} other Other value
     * @returns {number} 0 if they are the same, 1 if the this is greater and -1
     *  if the given one is greater
     */
    LongPrototype.comp = LongPrototype.compare;
2221
2222 /**
2223 * Negates this Long's value.
2224 * @returns {!Long} Negated Long
2225 */
2226 LongPrototype.negate = function negate() {
2227 if (!this.unsigned && this.eq(MIN_VALUE))
2228 return MIN_VALUE;
2229 return this.not().add(ONE);
2230 };
2231
2232 /**
2233 * Negates this Long's value. This is an alias of {@link Long#negate}.
2234 * @function
2235 * @returns {!Long} Negated Long
2236 */
2237 LongPrototype.neg = LongPrototype.negate;
2238
    /**
     * Returns the sum of this and the specified Long.
     * @param {!Long|number|string} addend Addend
     * @returns {!Long} Sum
     */
    LongPrototype.add = function add(addend) {
        if (!isLong(addend))
            addend = fromValue(addend);

        // Divide each number into 4 chunks of 16 bits, and then sum the chunks.
        // 16-bit limbs keep every intermediate sum exactly representable as a
        // JS double, so no precision is lost while propagating carries.

        var a48 = this.high >>> 16;
        var a32 = this.high & 0xFFFF;
        var a16 = this.low >>> 16;
        var a00 = this.low & 0xFFFF;

        var b48 = addend.high >>> 16;
        var b32 = addend.high & 0xFFFF;
        var b16 = addend.low >>> 16;
        var b00 = addend.low & 0xFFFF;

        // Sum limb by limb from least to most significant; each `>>> 16` step
        // carries the overflow into the next limb.
        var c48 = 0, c32 = 0, c16 = 0, c00 = 0;
        c00 += a00 + b00;
        c16 += c00 >>> 16;
        c00 &= 0xFFFF;
        c16 += a16 + b16;
        c32 += c16 >>> 16;
        c16 &= 0xFFFF;
        c32 += a32 + b32;
        c48 += c32 >>> 16;
        c32 &= 0xFFFF;
        c48 += a48 + b48;
        c48 &= 0xFFFF; // any carry out of bit 63 is discarded (64-bit wraparound)
        return fromBits((c16 << 16) | c00, (c48 << 16) | c32, this.unsigned);
    };
2274
2275 /**
2276 * Returns the difference of this and the specified Long.
2277 * @param {!Long|number|string} subtrahend Subtrahend
2278 * @returns {!Long} Difference
2279 */
2280 LongPrototype.subtract = function subtract(subtrahend) {
2281 if (!isLong(subtrahend))
2282 subtrahend = fromValue(subtrahend);
2283 return this.add(subtrahend.neg());
2284 };
2285
2286 /**
2287 * Returns the difference of this and the specified Long. This is an alias of {@link Long#subtract}.
2288 * @function
2289 * @param {!Long|number|string} subtrahend Subtrahend
2290 * @returns {!Long} Difference
2291 */
2292 LongPrototype.sub = LongPrototype.subtract;
2293
    /**
     * Returns the product of this and the specified Long.
     * @param {!Long|number|string} multiplier Multiplier
     * @returns {!Long} Product
     */
    LongPrototype.multiply = function multiply(multiplier) {
        if (this.isZero())
            return ZERO;
        if (!isLong(multiplier))
            multiplier = fromValue(multiplier);

        // use wasm support if present
        if (wasm) {
            var low = wasm.mul(this.low,
                               this.high,
                               multiplier.low,
                               multiplier.high);
            return fromBits(low, wasm.get_high(), this.unsigned);
        }

        if (multiplier.isZero())
            return ZERO;
        // Modulo 2^64: MIN_VALUE * odd === MIN_VALUE, MIN_VALUE * even === 0.
        if (this.eq(MIN_VALUE))
            return multiplier.isOdd() ? MIN_VALUE : ZERO;
        if (multiplier.eq(MIN_VALUE))
            return this.isOdd() ? MIN_VALUE : ZERO;

        // Normalize both operands to non-negative, fixing the sign afterwards.
        if (this.isNegative()) {
            if (multiplier.isNegative())
                return this.neg().mul(multiplier.neg());
            else
                return this.neg().mul(multiplier).neg();
        } else if (multiplier.isNegative())
            return this.mul(multiplier.neg()).neg();

        // If both longs are small, use float multiplication
        // (both < 2^24, so the product fits exactly in a double).
        if (this.lt(TWO_PWR_24) && multiplier.lt(TWO_PWR_24))
            return fromNumber(this.toNumber() * multiplier.toNumber(), this.unsigned);

        // Divide each long into 4 chunks of 16 bits, and then add up 4x4 products.
        // We can skip products that would overflow.

        var a48 = this.high >>> 16;
        var a32 = this.high & 0xFFFF;
        var a16 = this.low >>> 16;
        var a00 = this.low & 0xFFFF;

        var b48 = multiplier.high >>> 16;
        var b32 = multiplier.high & 0xFFFF;
        var b16 = multiplier.low >>> 16;
        var b00 = multiplier.low & 0xFFFF;

        // Accumulate partial products limb by limb, carrying after each one so
        // no intermediate exceeds 32 bits (exact in double arithmetic).
        var c48 = 0, c32 = 0, c16 = 0, c00 = 0;
        c00 += a00 * b00;
        c16 += c00 >>> 16;
        c00 &= 0xFFFF;
        c16 += a16 * b00;
        c32 += c16 >>> 16;
        c16 &= 0xFFFF;
        c16 += a00 * b16;
        c32 += c16 >>> 16;
        c16 &= 0xFFFF;
        c32 += a32 * b00;
        c48 += c32 >>> 16;
        c32 &= 0xFFFF;
        c32 += a16 * b16;
        c48 += c32 >>> 16;
        c32 &= 0xFFFF;
        c32 += a00 * b32;
        c48 += c32 >>> 16;
        c32 &= 0xFFFF;
        c48 += a48 * b00 + a32 * b16 + a16 * b32 + a00 * b48;
        c48 &= 0xFFFF; // bits above 63 are discarded (64-bit wraparound)
        return fromBits((c16 << 16) | c00, (c48 << 16) | c32, this.unsigned);
    };

    /**
     * Returns the product of this and the specified Long. This is an alias of {@link Long#multiply}.
     * @function
     * @param {!Long|number|string} multiplier Multiplier
     * @returns {!Long} Product
     */
    LongPrototype.mul = LongPrototype.multiply;
2377
    /**
     * Returns this Long divided by the specified. The result is signed if this Long is signed or
     * unsigned if this Long is unsigned.
     * @param {!Long|number|string} divisor Divisor
     * @returns {!Long} Quotient (truncated toward zero)
     * @throws {Error} If `divisor` is zero
     */
    LongPrototype.divide = function divide(divisor) {
        if (!isLong(divisor))
            divisor = fromValue(divisor);
        if (divisor.isZero())
            throw Error('division by zero');

        // use wasm support if present
        if (wasm) {
            // guard against signed division overflow: the largest
            // negative number / -1 would be 1 larger than the largest
            // positive number, due to two's complement.
            if (!this.unsigned &&
                this.high === -0x80000000 &&
                divisor.low === -1 && divisor.high === -1) {
                // be consistent with non-wasm code path
                return this;
            }
            var low = (this.unsigned ? wasm.div_u : wasm.div_s)(
                this.low,
                this.high,
                divisor.low,
                divisor.high
            );
            return fromBits(low, wasm.get_high(), this.unsigned);
        }

        if (this.isZero())
            return this.unsigned ? UZERO : ZERO;
        var approx, rem, res;
        if (!this.unsigned) {
            // This section is only relevant for signed longs and is derived from the
            // closure library as a whole.
            if (this.eq(MIN_VALUE)) {
                if (divisor.eq(ONE) || divisor.eq(NEG_ONE))
                    return MIN_VALUE; // recall that -MIN_VALUE == MIN_VALUE
                else if (divisor.eq(MIN_VALUE))
                    return ONE;
                else {
                    // At this point, we have |other| >= 2, so |this/other| < |MIN_VALUE|.
                    var halfThis = this.shr(1);
                    approx = halfThis.div(divisor).shl(1);
                    if (approx.eq(ZERO)) {
                        return divisor.isNegative() ? ONE : NEG_ONE;
                    } else {
                        rem = this.sub(divisor.mul(approx));
                        res = approx.add(rem.div(divisor));
                        return res;
                    }
                }
            } else if (divisor.eq(MIN_VALUE))
                return this.unsigned ? UZERO : ZERO;
            // Normalize to non-negative operands, fixing the sign afterwards.
            if (this.isNegative()) {
                if (divisor.isNegative())
                    return this.neg().div(divisor.neg());
                return this.neg().div(divisor).neg();
            } else if (divisor.isNegative())
                return this.div(divisor.neg()).neg();
            res = ZERO;
        } else {
            // The algorithm below has not been made for unsigned longs. It's therefore
            // required to take special care of the MSB prior to running it.
            if (!divisor.unsigned)
                divisor = divisor.toUnsigned();
            if (divisor.gt(this))
                return UZERO;
            if (divisor.gt(this.shru(1))) // 15 >>> 1 = 7 ; with divisor = 8 ; true
                return UONE;
            res = UZERO;
        }

        // Repeat the following until the remainder is less than other: find a
        // floating-point that approximates remainder / other *from below*, add this
        // into the result, and subtract it from the remainder. It is critical that
        // the approximate value is less than or equal to the real value so that the
        // remainder never becomes negative.
        rem = this;
        while (rem.gte(divisor)) {
            // Approximate the result of division. This may be a little greater or
            // smaller than the actual value.
            approx = Math.max(1, Math.floor(rem.toNumber() / divisor.toNumber()));

            // We will tweak the approximate result by changing it in the 48-th digit or
            // the smallest non-fractional digit, whichever is larger.
            var log2 = Math.ceil(Math.log(approx) / Math.LN2),
                delta = (log2 <= 48) ? 1 : pow_dbl(2, log2 - 48),

            // Decrease the approximation until it is smaller than the remainder. Note
            // that if it is too large, the product overflows and is negative.
                approxRes = fromNumber(approx),
                approxRem = approxRes.mul(divisor);
            while (approxRem.isNegative() || approxRem.gt(rem)) {
                approx -= delta;
                approxRes = fromNumber(approx, this.unsigned);
                approxRem = approxRes.mul(divisor);
            }

            // We know the answer can't be zero... and actually, zero would cause
            // infinite recursion since we would make no progress.
            if (approxRes.isZero())
                approxRes = ONE;

            res = res.add(approxRes);
            rem = rem.sub(approxRem);
        }
        return res;
    };

    /**
     * Returns this Long divided by the specified. This is an alias of {@link Long#divide}.
     * @function
     * @param {!Long|number|string} divisor Divisor
     * @returns {!Long} Quotient
     */
    LongPrototype.div = LongPrototype.divide;
2498
2499 /**
2500 * Returns this Long modulo the specified.
2501 * @param {!Long|number|string} divisor Divisor
2502 * @returns {!Long} Remainder
2503 */
2504 LongPrototype.modulo = function modulo(divisor) {
2505 if (!isLong(divisor))
2506 divisor = fromValue(divisor);
2507
2508 // use wasm support if present
2509 if (wasm) {
2510 var low = (this.unsigned ? wasm.rem_u : wasm.rem_s)(
2511 this.low,
2512 this.high,
2513 divisor.low,
2514 divisor.high
2515 );
2516 return fromBits(low, wasm.get_high(), this.unsigned);
2517 }
2518
2519 return this.sub(this.div(divisor).mul(divisor));
2520 };
2521
2522 /**
2523 * Returns this Long modulo the specified. This is an alias of {@link Long#modulo}.
2524 * @function
2525 * @param {!Long|number|string} divisor Divisor
2526 * @returns {!Long} Remainder
2527 */
2528 LongPrototype.mod = LongPrototype.modulo;
2529
2530 /**
2531 * Returns this Long modulo the specified. This is an alias of {@link Long#modulo}.
2532 * @function
2533 * @param {!Long|number|string} divisor Divisor
2534 * @returns {!Long} Remainder
2535 */
2536 LongPrototype.rem = LongPrototype.modulo;
2537
2538 /**
2539 * Returns the bitwise NOT of this Long.
2540 * @returns {!Long}
2541 */
2542 LongPrototype.not = function not() {
2543 return fromBits(~this.low, ~this.high, this.unsigned);
2544 };
2545
2546 /**
2547 * Returns the bitwise AND of this Long and the specified.
2548 * @param {!Long|number|string} other Other Long
2549 * @returns {!Long}
2550 */
2551 LongPrototype.and = function and(other) {
2552 if (!isLong(other))
2553 other = fromValue(other);
2554 return fromBits(this.low & other.low, this.high & other.high, this.unsigned);
2555 };
2556
2557 /**
2558 * Returns the bitwise OR of this Long and the specified.
2559 * @param {!Long|number|string} other Other Long
2560 * @returns {!Long}
2561 */
2562 LongPrototype.or = function or(other) {
2563 if (!isLong(other))
2564 other = fromValue(other);
2565 return fromBits(this.low | other.low, this.high | other.high, this.unsigned);
2566 };
2567
2568 /**
2569 * Returns the bitwise XOR of this Long and the given one.
2570 * @param {!Long|number|string} other Other Long
2571 * @returns {!Long}
2572 */
2573 LongPrototype.xor = function xor(other) {
2574 if (!isLong(other))
2575 other = fromValue(other);
2576 return fromBits(this.low ^ other.low, this.high ^ other.high, this.unsigned);
2577 };
2578
2579 /**
2580 * Returns this Long with bits shifted to the left by the given amount.
2581 * @param {number|!Long} numBits Number of bits
2582 * @returns {!Long} Shifted Long
2583 */
2584 LongPrototype.shiftLeft = function shiftLeft(numBits) {
2585 if (isLong(numBits))
2586 numBits = numBits.toInt();
2587 if ((numBits &= 63) === 0)
2588 return this;
2589 else if (numBits < 32)
2590 return fromBits(this.low << numBits, (this.high << numBits) | (this.low >>> (32 - numBits)), this.unsigned);
2591 else
2592 return fromBits(0, this.low << (numBits - 32), this.unsigned);
2593 };
2594
2595 /**
2596 * Returns this Long with bits shifted to the left by the given amount. This is an alias of {@link Long#shiftLeft}.
2597 * @function
2598 * @param {number|!Long} numBits Number of bits
2599 * @returns {!Long} Shifted Long
2600 */
2601 LongPrototype.shl = LongPrototype.shiftLeft;
2602
2603 /**
2604 * Returns this Long with bits arithmetically shifted to the right by the given amount.
2605 * @param {number|!Long} numBits Number of bits
2606 * @returns {!Long} Shifted Long
2607 */
2608 LongPrototype.shiftRight = function shiftRight(numBits) {
2609 if (isLong(numBits))
2610 numBits = numBits.toInt();
2611 if ((numBits &= 63) === 0)
2612 return this;
2613 else if (numBits < 32)
2614 return fromBits((this.low >>> numBits) | (this.high << (32 - numBits)), this.high >> numBits, this.unsigned);
2615 else
2616 return fromBits(this.high >> (numBits - 32), this.high >= 0 ? 0 : -1, this.unsigned);
2617 };
2618
2619 /**
2620 * Returns this Long with bits arithmetically shifted to the right by the given amount. This is an alias of {@link Long#shiftRight}.
2621 * @function
2622 * @param {number|!Long} numBits Number of bits
2623 * @returns {!Long} Shifted Long
2624 */
2625 LongPrototype.shr = LongPrototype.shiftRight;
2626
2627 /**
2628 * Returns this Long with bits logically shifted to the right by the given amount.
2629 * @param {number|!Long} numBits Number of bits
2630 * @returns {!Long} Shifted Long
2631 */
2632 LongPrototype.shiftRightUnsigned = function shiftRightUnsigned(numBits) {
2633 if (isLong(numBits))
2634 numBits = numBits.toInt();
2635 numBits &= 63;
2636 if (numBits === 0)
2637 return this;
2638 else {
2639 var high = this.high;
2640 if (numBits < 32) {
2641 var low = this.low;
2642 return fromBits((low >>> numBits) | (high << (32 - numBits)), high >>> numBits, this.unsigned);
2643 } else if (numBits === 32)
2644 return fromBits(high, 0, this.unsigned);
2645 else
2646 return fromBits(high >>> (numBits - 32), 0, this.unsigned);
2647 }
2648 };
2649
2650 /**
2651 * Returns this Long with bits logically shifted to the right by the given amount. This is an alias of {@link Long#shiftRightUnsigned}.
2652 * @function
2653 * @param {number|!Long} numBits Number of bits
2654 * @returns {!Long} Shifted Long
2655 */
2656 LongPrototype.shru = LongPrototype.shiftRightUnsigned;
2657
2658 /**
2659 * Returns this Long with bits logically shifted to the right by the given amount. This is an alias of {@link Long#shiftRightUnsigned}.
2660 * @function
2661 * @param {number|!Long} numBits Number of bits
2662 * @returns {!Long} Shifted Long
2663 */
2664 LongPrototype.shr_u = LongPrototype.shiftRightUnsigned;
2665
2666 /**
2667 * Converts this Long to signed.
2668 * @returns {!Long} Signed long
2669 */
2670 LongPrototype.toSigned = function toSigned() {
2671 if (!this.unsigned)
2672 return this;
2673 return fromBits(this.low, this.high, false);
2674 };
2675
2676 /**
2677 * Converts this Long to unsigned.
2678 * @returns {!Long} Unsigned long
2679 */
2680 LongPrototype.toUnsigned = function toUnsigned() {
2681 if (this.unsigned)
2682 return this;
2683 return fromBits(this.low, this.high, true);
2684 };
2685
2686 /**
2687 * Converts this Long to its byte representation.
2688 * @param {boolean=} le Whether little or big endian, defaults to big endian
2689 * @returns {!Array.<number>} Byte representation
2690 */
2691 LongPrototype.toBytes = function toBytes(le) {
2692 return le ? this.toBytesLE() : this.toBytesBE();
2693 };
2694
2695 /**
2696 * Converts this Long to its little endian byte representation.
2697 * @returns {!Array.<number>} Little endian byte representation
2698 */
2699 LongPrototype.toBytesLE = function toBytesLE() {
2700 var hi = this.high,
2701 lo = this.low;
2702 return [
2703 lo & 0xff,
2704 lo >>> 8 & 0xff,
2705 lo >>> 16 & 0xff,
2706 lo >>> 24 ,
2707 hi & 0xff,
2708 hi >>> 8 & 0xff,
2709 hi >>> 16 & 0xff,
2710 hi >>> 24
2711 ];
2712 };
2713
2714 /**
2715 * Converts this Long to its big endian byte representation.
2716 * @returns {!Array.<number>} Big endian byte representation
2717 */
2718 LongPrototype.toBytesBE = function toBytesBE() {
2719 var hi = this.high,
2720 lo = this.low;
2721 return [
2722 hi >>> 24 ,
2723 hi >>> 16 & 0xff,
2724 hi >>> 8 & 0xff,
2725 hi & 0xff,
2726 lo >>> 24 ,
2727 lo >>> 16 & 0xff,
2728 lo >>> 8 & 0xff,
2729 lo & 0xff
2730 ];
2731 };
2732
2733 /**
2734 * Creates a Long from its byte representation.
2735 * @param {!Array.<number>} bytes Byte representation
2736 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
2737 * @param {boolean=} le Whether little or big endian, defaults to big endian
2738 * @returns {Long} The corresponding Long value
2739 */
2740 Long$1.fromBytes = function fromBytes(bytes, unsigned, le) {
2741 return le ? Long$1.fromBytesLE(bytes, unsigned) : Long$1.fromBytesBE(bytes, unsigned);
2742 };
2743
2744 /**
2745 * Creates a Long from its little endian byte representation.
2746 * @param {!Array.<number>} bytes Little endian byte representation
2747 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
2748 * @returns {Long} The corresponding Long value
2749 */
2750 Long$1.fromBytesLE = function fromBytesLE(bytes, unsigned) {
2751 return new Long$1(
2752 bytes[0] |
2753 bytes[1] << 8 |
2754 bytes[2] << 16 |
2755 bytes[3] << 24,
2756 bytes[4] |
2757 bytes[5] << 8 |
2758 bytes[6] << 16 |
2759 bytes[7] << 24,
2760 unsigned
2761 );
2762 };
2763
2764 /**
2765 * Creates a Long from its big endian byte representation.
2766 * @param {!Array.<number>} bytes Big endian byte representation
2767 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
2768 * @returns {Long} The corresponding Long value
2769 */
2770 Long$1.fromBytesBE = function fromBytesBE(bytes, unsigned) {
2771 return new Long$1(
2772 bytes[4] << 24 |
2773 bytes[5] << 16 |
2774 bytes[6] << 8 |
2775 bytes[7],
2776 bytes[0] << 24 |
2777 bytes[1] << 16 |
2778 bytes[2] << 8 |
2779 bytes[3],
2780 unsigned
2781 );
2782 };
2783
    // Default export of the bundled CommonJS long.js module.
    var long$1 = /*@__PURE__*/getDefaultExportFromCjs(long);

    // Namespace object exposing the default export plus any named members of
    // the long module, frozen by _mergeNamespaces.
    var LongExports = /*#__PURE__*/_mergeNamespaces({
        __proto__: null,
        default: long$1
    }, [long]);
2790
2791 /**
2792 * @license
2793 * Copyright 2021 Google LLC. All Rights Reserved.
2794 * Licensed under the Apache License, Version 2.0 (the "License");
2795 * you may not use this file except in compliance with the License.
2796 * You may obtain a copy of the License at
2797 *
2798 * http://www.apache.org/licenses/LICENSE-2.0
2799 *
2800 * Unless required by applicable law or agreed to in writing, software
2801 * distributed under the License is distributed on an "AS IS" BASIS,
2802 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2803 * See the License for the specific language governing permissions and
2804 * limitations under the License.
2805 * =============================================================================
2806 */
    // tslint:disable-next-line
    const Long =
    // tslint:disable-next-line
    long$1 || LongExports;
    // Parses a hexadecimal string into an unsigned 64-bit Long.
    function hexToLong(hex) {
        return Long.fromString(hex, true, 16);
    }
    // Some primes between 2^63 and 2^64 for various uses.
    // Hex 0xc3a5c85c97cb3127
    const k0 = hexToLong('c3a5c85c97cb3127');
    // Hex 0xb492b66fbe98f273
    const k1 = hexToLong('b492b66fbe98f273');
    // Hex 0x9ae16a3b2f90404f
    const k2 = hexToLong('9ae16a3b2f90404f');
    // Folds the top 17 bits of `val` into its lower bits (val XOR (val >>> 47)).
    function shiftMix(val) {
        return val.xor(val.shru(47));
    }
2824 function fetch$2(s, offset, numBytes) {
2825 const bytes = s.slice(offset, offset + numBytes);
2826 return Long.fromBytes(Array.from(bytes), true, true);
2827 }
2828 function fetch64(s, offset) {
2829 return fetch$2(s, offset, 8);
2830 }
2831 function fetch32(s, offset) {
2832 return fetch$2(s, offset, 4);
2833 }
2834 function rotate64(val, shift) {
2835 // Avoid shifting by 64: doing so yields an undefined result.
2836 return shift === 0 ? val : val.shru(shift).or(val.shl(64 - shift));
2837 }
2838 function hashLen16(u, v, mul = hexToLong('9ddfea08eb382d69')) {
2839 // Murmur-inspired hashing.
2840 let a = u.xor(v).mul(mul);
2841 a = a.xor(a.shru(47));
2842 let b = v.xor(a).mul(mul);
2843 b = b.xor(b.shru(47));
2844 b = b.mul(mul);
2845 return b;
2846 }
    // Return a 16-byte hash for 48 bytes. Quick and dirty.
    // Callers do best to use "random-looking" values for a and b.
    function weakHashLen32WithSeeds(w, x, y, z, a, b) {
        a = a.add(w);
        b = rotate64(b.add(a).add(z), 21);
        const c = a; // snapshot of `a` after mixing in w, re-added at the end
        a = a.add(x);
        a = a.add(y);
        b = b.add(rotate64(a, 44));
        return [a.add(z), b.add(c)];
    }
    // Same as above, but reads the four 64-bit lanes from `s` at `offset`.
    function weakHashLen32WithSeedsStr(s, offset, a, b) {
        return weakHashLen32WithSeeds(fetch64(s, offset), fetch64(s, offset + 8), fetch64(s, offset + 16), fetch64(s, offset + 24), a, b);
    }
    // Hashes byte arrays of length 0..16.
    function hashLen0to16(s, len = s.length) {
        if (len >= 8) {
            // Mix two (possibly overlapping) 64-bit reads from both ends.
            const mul = k2.add(len * 2);
            const a = fetch64(s, 0).add(k2);
            const b = fetch64(s, len - 8);
            const c = rotate64(b, 37).mul(mul).add(a);
            const d = rotate64(a, 25).add(b).mul(mul);
            return hashLen16(c, d, mul);
        }
        if (len >= 4) {
            // Mix two (possibly overlapping) 32-bit reads from both ends.
            const mul = k2.add(len * 2);
            const a = fetch32(s, 0);
            return hashLen16(a.shl(3).add(len), fetch32(s, len - 4), mul);
        }
        if (len > 0) {
            // 1..3 bytes: combine the first, middle and last byte with the length.
            const a = s[0];
            const b = s[len >> 1];
            const c = s[len - 1];
            const y = a + (b << 8);
            const z = len + (c << 2);
            return shiftMix(k2.mul(y).xor(k0.mul(z))).mul(k2);
        }
        // The empty input hashes to the constant k2.
        return k2;
    }
    // Hashes byte arrays of length 17..32 using four overlapping 64-bit reads.
    function hashLen17to32(s, len = s.length) {
        const mul = k2.add(len * 2);
        const a = fetch64(s, 0).mul(k1);
        const b = fetch64(s, 8);
        const c = fetch64(s, len - 8).mul(mul);
        const d = fetch64(s, len - 16).mul(k2);
        return hashLen16(rotate64(a.add(b), 43).add(rotate64(c, 30)).add(d), a.add(rotate64(b.add(k2), 18)).add(c), mul);
    }
    // Hashes byte arrays of length 33..64 using eight overlapping 64-bit reads.
    function hashLen33to64(s, len = s.length) {
        const mul = k2.add(len * 2);
        const a = fetch64(s, 0).mul(k2);
        const b = fetch64(s, 8);
        const c = fetch64(s, len - 8).mul(mul);
        const d = fetch64(s, len - 16).mul(k2);
        // y/z summarize the first and last 16 bytes, as in hashLen17to32.
        const y = rotate64(a.add(b), 43).add(rotate64(c, 30)).add(d);
        const z = hashLen16(y, a.add(rotate64(b.add(k2), 18)).add(c), mul);
        // e..h mix in bytes 16..31 and the 32 bytes before the end.
        const e = fetch64(s, 16).mul(mul);
        const f = fetch64(s, 24);
        const g = y.add(fetch64(s, len - 32)).mul(mul);
        const h = z.add(fetch64(s, len - 24)).mul(mul);
        return hashLen16(rotate64(e.add(f), 43).add(rotate64(g, 30)).add(h), e.add(rotate64(f.add(a), 18)).add(g), mul);
    }
    // Computes a 64-bit fingerprint of the byte array `s`. Short inputs are
    // dispatched to the specialized mixers above; longer inputs run a chunked
    // 64-bytes-per-iteration loop. NOTE(review): constants and structure match
    // a FarmHash/CityHash-style Fingerprint64 — confirm against upstream.
    function fingerPrint64(s, len = s.length) {
        const seed = Long.fromNumber(81, true);
        if (len <= 32) {
            if (len <= 16) {
                return hashLen0to16(s, len);
            }
            else {
                return hashLen17to32(s, len);
            }
        }
        else if (len <= 64) {
            return hashLen33to64(s, len);
        }
        // For strings over 64 bytes we loop. Internal state consists of
        // 56 bytes: v, w, x, y, and z.
        let x = seed;
        let y = seed.mul(k1).add(113);
        let z = shiftMix(y.mul(k2).add(113)).mul(k2);
        let v = [Long.UZERO, Long.UZERO];
        let w = [Long.UZERO, Long.UZERO];
        x = x.mul(k2).add(fetch64(s, 0));
        let offset = 0;
        // Set end so that after the loop we have 1 to 64 bytes left to process.
        const end = ((len - 1) >> 6) * 64;
        const last64 = end + ((len - 1) & 63) - 63;
        do {
            // Mix one 64-byte chunk into the running state.
            x = rotate64(x.add(y).add(v[0]).add(fetch64(s, offset + 8)), 37).mul(k1);
            y = rotate64(y.add(v[1]).add(fetch64(s, offset + 48)), 42).mul(k1);
            x = x.xor(w[1]);
            y = y.add(v[0]).add(fetch64(s, offset + 40));
            z = rotate64(z.add(w[0]), 33).mul(k1);
            v = weakHashLen32WithSeedsStr(s, offset, v[1].mul(k1), x.add(w[0]));
            w = weakHashLen32WithSeedsStr(s, offset + 32, z.add(w[1]), y.add(fetch64(s, offset + 16)));
            [z, x] = [x, z];
            offset += 64;
        } while (offset !== end);
        const mul = k1.add(z.and(0xff).shl(1));
        // Point to the last 64 bytes of input.
        offset = last64;
        w[0] = w[0].add((len - 1) & 63);
        v[0] = v[0].add(w[0]);
        w[0] = w[0].add(v[0]);
        // One final round over the (overlapping) last 64 bytes, using `mul`.
        x = rotate64(x.add(y).add(v[0]).add(fetch64(s, offset + 8)), 37).mul(mul);
        y = rotate64(y.add(v[1]).add(fetch64(s, offset + 48)), 42).mul(mul);
        x = x.xor(w[1].mul(9));
        y = y.add(v[0].mul(9).add(fetch64(s, offset + 40)));
        z = rotate64(z.add(w[0]), 33).mul(mul);
        v = weakHashLen32WithSeedsStr(s, offset, v[1].mul(mul), x.add(w[0]));
        w = weakHashLen32WithSeedsStr(s, offset + 32, z.add(w[1]), y.add(fetch64(s, offset + 16)));
        [z, x] = [x, z];
        // Collapse the 56-byte state into the final 64-bit fingerprint.
        return hashLen16(hashLen16(v[0], w[0], mul).add(shiftMix(y).mul(k0)).add(z), hashLen16(v[1], w[1], mul).add(x), mul);
    }
2959
2960 /**
2961 * @license
2962 * Copyright 2017 Google LLC. All Rights Reserved.
2963 * Licensed under the Apache License, Version 2.0 (the "License");
2964 * you may not use this file except in compliance with the License.
2965 * You may obtain a copy of the License at
2966 *
2967 * http://www.apache.org/licenses/LICENSE-2.0
2968 *
2969 * Unless required by applicable law or agreed to in writing, software
2970 * distributed under the License is distributed on an "AS IS" BASIS,
2971 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2972 * See the License for the specific language governing permissions and
2973 * limitations under the License.
2974 * =============================================================================
2975 */
2976 /**
2977 * Create typed array for scalar value. Used for storing in `DataStorage`.
2978 */
2979 function createScalarValue(value, dtype) {
2980 if (dtype === 'string') {
2981 return encodeString(value);
2982 }
2983 return toTypedArray([value], dtype);
2984 }
2985 function noConversionNeeded(a, dtype) {
2986 return (a instanceof Float32Array && dtype === 'float32') ||
2987 (a instanceof Int32Array && dtype === 'int32') ||
2988 (a instanceof Uint8Array && dtype === 'bool');
2989 }
2990 function toTypedArray(a, dtype) {
2991 if (dtype === 'string') {
2992 throw new Error('Cannot convert a string[] to a TypedArray');
2993 }
2994 if (Array.isArray(a)) {
2995 a = flatten$2(a);
2996 }
2997 if (env().getBool('DEBUG')) {
2998 checkConversionForErrors(a, dtype);
2999 }
3000 if (noConversionNeeded(a, dtype)) {
3001 return a;
3002 }
3003 if (dtype == null || dtype === 'float32' || dtype === 'complex64') {
3004 return new Float32Array(a);
3005 }
3006 else if (dtype === 'int32') {
3007 return new Int32Array(a);
3008 }
3009 else if (dtype === 'bool') {
3010 const bool = new Uint8Array(a.length);
3011 for (let i = 0; i < bool.length; ++i) {
3012 if (Math.round(a[i]) !== 0) {
3013 bool[i] = 1;
3014 }
3015 }
3016 return bool;
3017 }
3018 else {
3019 throw new Error(`Unknown data type ${dtype}`);
3020 }
3021 }
3022 /**
3023 * Returns the current high-resolution time in milliseconds relative to an
3024 * arbitrary time in the past. It works across different platforms (node.js,
3025 * browsers).
3026 *
3027 * ```js
3028 * console.log(tf.util.now());
3029 * ```
3030 *
3031 * @doc {heading: 'Util', namespace: 'util'}
3032 */
3033 function now() {
3034 return env().platform.now();
3035 }
3036 /**
3037 * Returns a platform-specific implementation of
3038 * [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API).
3039 *
3040 * If `fetch` is defined on the global object (`window`, `process`, etc.),
3041 * `tf.util.fetch` returns that function.
3042 *
3043 * If not, `tf.util.fetch` returns a platform-specific solution.
3044 *
3045 * ```js
3046 * const resource = await tf.util.fetch('https://cdn.jsdelivr.net/npm/@tensorflow/tfjs');
3047 * // handle response
3048 * ```
3049 *
3050 * @doc {heading: 'Util'}
3051 */
3052 function fetch$1(path, requestInits) {
3053 return env().platform.fetch(path, requestInits);
3054 }
3055 /**
3056 * Encodes the provided string into bytes using the provided encoding scheme.
3057 *
3058 * @param s The string to encode.
3059 * @param encoding The encoding scheme. Defaults to utf-8.
3060 *
3061 * @doc {heading: 'Util'}
3062 */
3063 function encodeString(s, encoding = 'utf-8') {
3064 encoding = encoding || 'utf-8';
3065 return env().platform.encode(s, encoding);
3066 }
3067 /**
3068 * Decodes the provided bytes into a string using the provided encoding scheme.
3069 * @param bytes The bytes to decode.
3070 *
3071 * @param encoding The encoding scheme. Defaults to utf-8.
3072 *
3073 * @doc {heading: 'Util'}
3074 */
3075 function decodeString(bytes, encoding = 'utf-8') {
3076 encoding = encoding || 'utf-8';
3077 return env().platform.decode(bytes, encoding);
3078 }
3079 function isTypedArray(a) {
3080 // TODO(mattsoulanille): Remove this fallback in 5.0.0
3081 if (env().platform.isTypedArray != null) {
3082 return env().platform.isTypedArray(a);
3083 }
3084 else {
3085 return isTypedArrayBrowser(a);
3086 }
3087 }
3088 // NOTE: We explicitly type out what T extends instead of any so that
3089 // util.flatten on a nested array of number doesn't try to infer T as a
3090 // number[][], causing us to explicitly type util.flatten<number>().
3091 /**
3092 * Flattens an arbitrarily nested array.
3093 *
3094 * ```js
3095 * const a = [[1, 2], [3, 4], [5, [6, [7]]]];
3096 * const flat = tf.util.flatten(a);
3097 * console.log(flat);
3098 * ```
3099 *
3100 * @param arr The nested array to flatten.
3101 * @param result The destination array which holds the elements.
3102 * @param skipTypedArray If true, avoids flattening the typed arrays. Defaults
3103 * to false.
3104 *
3105 * @doc {heading: 'Util', namespace: 'util'}
3106 */
3107 function flatten$2(arr, result = [], skipTypedArray = false) {
3108 if (result == null) {
3109 result = [];
3110 }
3111 if (typeof arr === 'boolean' || typeof arr === 'number' ||
3112 typeof arr === 'string' || isPromise(arr) || arr == null ||
3113 isTypedArray(arr) && skipTypedArray) {
3114 result.push(arr);
3115 }
3116 else if (Array.isArray(arr) || isTypedArray(arr)) {
3117 for (let i = 0; i < arr.length; ++i) {
3118 flatten$2(arr[i], result, skipTypedArray);
3119 }
3120 }
3121 else {
3122 let maxIndex = -1;
3123 for (const key of Object.keys(arr)) {
3124 // 0 or positive integer.
3125 if (/^([1-9]+[0-9]*|0)$/.test(key)) {
3126 maxIndex = Math.max(maxIndex, Number(key));
3127 }
3128 }
3129 for (let i = 0; i <= maxIndex; i++) {
3130 // tslint:disable-next-line: no-unnecessary-type-assertion
3131 flatten$2(arr[i], result, skipTypedArray);
3132 }
3133 }
3134 return result;
3135 }
3136
3137 var util = /*#__PURE__*/Object.freeze({
3138 __proto__: null,
3139 arraysEqual: arraysEqual,
3140 arraysEqualWithNull: arraysEqualWithNull,
3141 assert: assert$1,
3142 assertNonNegativeIntegerDimensions: assertNonNegativeIntegerDimensions,
3143 assertNonNull: assertNonNull,
3144 assertShapesMatch: assertShapesMatch,
3145 bytesFromStringArray: bytesFromStringArray,
3146 bytesPerElement: bytesPerElement,
3147 checkConversionForErrors: checkConversionForErrors,
3148 clamp: clamp,
3149 computeStrides: computeStrides,
3150 convertBackendValuesAndArrayBuffer: convertBackendValuesAndArrayBuffer,
3151 createScalarValue: createScalarValue,
3152 createShuffledIndices: createShuffledIndices,
3153 decodeString: decodeString,
3154 distSquared: distSquared,
3155 encodeString: encodeString,
3156 fetch: fetch$1,
3157 fingerPrint64: fingerPrint64,
3158 flatten: flatten$2,
3159 getArrayFromDType: getArrayFromDType,
3160 getTypedArrayFromDType: getTypedArrayFromDType,
3161 hasEncodingLoss: hasEncodingLoss,
3162 hexToLong: hexToLong,
3163 indexToLoc: indexToLoc,
3164 inferDtype: inferDtype,
3165 inferFromImplicitShape: inferFromImplicitShape,
3166 isBoolean: isBoolean,
3167 isFunction: isFunction,
3168 isInt: isInt,
3169 isNumber: isNumber,
3170 isPromise: isPromise,
3171 isScalarShape: isScalarShape,
3172 isString: isString,
3173 isTypedArray: isTypedArray,
3174 isValidDtype: isValidDtype,
3175 locToIndex: locToIndex,
3176 makeOnesTypedArray: makeOnesTypedArray,
3177 makeZerosNestedTypedArray: makeZerosNestedTypedArray,
3178 makeZerosTypedArray: makeZerosTypedArray,
3179 nearestDivisor: nearestDivisor,
3180 nearestLargerEven: nearestLargerEven,
3181 now: now,
3182 parseAxisParam: parseAxisParam,
3183 randUniform: randUniform,
3184 repeatedTry: repeatedTry,
3185 rightPad: rightPad,
3186 shuffle: shuffle,
3187 shuffleCombo: shuffleCombo,
3188 sizeFromShape: sizeFromShape,
3189 sizeToSquarishShape: sizeToSquarishShape,
3190 squeezeShape: squeezeShape,
3191 sum: sum$4,
3192 swap: swap,
3193 tanh: tanh$3,
3194 toNestedArray: toNestedArray,
3195 toTypedArray: toTypedArray
3196 });
3197
3198 /**
3199 * @license
3200 * Copyright 2018 Google LLC. All Rights Reserved.
3201 * Licensed under the Apache License, Version 2.0 (the "License");
3202 * you may not use this file except in compliance with the License.
3203 * You may obtain a copy of the License at
3204 *
3205 * http://www.apache.org/licenses/LICENSE-2.0
3206 *
3207 * Unless required by applicable law or agreed to in writing, software
3208 * distributed under the License is distributed on an "AS IS" BASIS,
3209 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3210 * See the License for the specific language governing permissions and
3211 * limitations under the License.
3212 * =============================================================================
3213 */
    /**
     * Wraps kernel execution with backend timing and (optionally) NaN/Infinity
     * checking, reporting results through the injected `Logger`.
     */
    class Profiler {
        constructor(backendTimer, logger) {
            this.backendTimer = backendTimer;
            this.logger = logger;
            if (logger == null) {
                // Default to the console logger defined below.
                this.logger = new Logger();
            }
        }
        /**
         * Runs the kernel function `f` and returns a profile record. `timeMs` and
         * `extraInfo` are promises that resolve once timing completes.
         */
        profileKernel(kernelName, inputs, f) {
            let outputs;
            // Closure captures the outputs so the timer can invoke `f` itself.
            const holdResultWrapperFn = () => {
                outputs = f();
            };
            let timer;
            const start = now();
            if (this.backendTimer.timerAvailable()) {
                timer = this.backendTimer.time(holdResultWrapperFn);
            }
            else {
                // No backend timer: run the kernel, force completion by
                // synchronously downloading each output, then measure wall time.
                holdResultWrapperFn();
                for (const output of outputs) {
                    output.dataSync();
                }
                timer = Promise.resolve({ kernelMs: now() - start });
            }
            if (env().getBool('CHECK_COMPUTATION_FOR_ERRORS')) {
                for (let i = 0; i < outputs.length; i++) {
                    const output = outputs[i];
                    // Dangling promise here because we don't want to propagate up
                    // asynchronicity.
                    output.data().then(tensorVals => {
                        checkComputationForErrors(tensorVals, output.dtype, kernelName);
                    });
                }
            }
            const kernelProfile = {
                kernelName,
                outputs,
                inputs,
                timeMs: timer.then(timing => timing.kernelMs),
                // Backends may attach extra profile info; fall back to ''.
                extraInfo: timer.then(timing => timing.getExtraProfileInfo != null ?
                    timing.getExtraProfileInfo() :
                    '')
            };
            return kernelProfile;
        }
        /** Logs each output once its values and the kernel timing are ready. */
        logKernelProfile(kernelProfile) {
            const { kernelName, outputs, timeMs, inputs, extraInfo } = kernelProfile;
            outputs.forEach(result => {
                Promise.all([result.data(), timeMs, extraInfo]).then(valueContainer => {
                    this.logger.logKernelProfile(kernelName, result, valueContainer[0], valueContainer[1], inputs, valueContainer[2]);
                });
            });
        }
    }
3269 function checkComputationForErrors(vals, dtype, kernelName) {
3270 if (dtype !== 'float32') {
3271 // Only floating point computations will generate NaN values
3272 return false;
3273 }
3274 for (let i = 0; i < vals.length; i++) {
3275 const num = vals[i];
3276 if (isNaN(num) || !isFinite(num)) {
3277 // Throwing custom exception so behavior is testable.
3278 console.warn(`Found ${num} in the result of '${kernelName}'`);
3279 return true;
3280 }
3281 }
3282 return false;
3283 }
3284 class Logger {
3285 logKernelProfile(name, result, vals, timeMs, inputs, extraInfo) {
3286 const time = typeof timeMs === 'number' ? rightPad(`${timeMs}ms`, 9) :
3287 timeMs['error'];
3288 const paddedName = rightPad(name, 25);
3289 const rank = result.rank;
3290 const size = result.size;
3291 const shape = rightPad(result.shape.toString(), 14);
3292 let inputShapesDescription = '';
3293 for (const name in inputs) {
3294 const input = inputs[name];
3295 if (input != null) {
3296 // The input might be a non-tensor (e.g HTMLImageElement), in which case
3297 // we claim the output shape as input shape.
3298 const inputShape = input.shape || result.shape;
3299 const inputRank = inputShape.length;
3300 inputShapesDescription +=
3301 `${name}: ${inputRank}D ${inputRank > 0 ? inputShape : ''} `;
3302 }
3303 }
3304 console.log(`%c${paddedName}\t%c${time}\t%c${rank}D ${shape}\t%c${size}\t%c${inputShapesDescription}\t%c${extraInfo}`, 'font-weight:bold', 'color:red', 'color:blue', 'color: orange', 'color: green', 'color: steelblue');
3305 }
3306 }
3307
3308 /**
3309 * @license
3310 * Copyright 2017 Google LLC. All Rights Reserved.
3311 * Licensed under the Apache License, Version 2.0 (the "License");
3312 * you may not use this file except in compliance with the License.
3313 * You may obtain a copy of the License at
3314 *
3315 * http://www.apache.org/licenses/LICENSE-2.0
3316 *
3317 * Unless required by applicable law or agreed to in writing, software
3318 * distributed under the License is distributed on an "AS IS" BASIS,
3319 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3320 * See the License for the specific language governing permissions and
3321 * limitations under the License.
3322 * =============================================================================
3323 */
3324 /**
3325 * Computes a list of TapeNodes that connect x to y, filtering everything else
3326 * out and preserving the order of the original tape elements.
3327 *
3328 * @param tape The tape elements to filter.
3329 * @param xs The input Tensors.
3330 * @param y The output Tensor.
3331 */
3332 function getFilteredNodesXToY(tape, xs, y) {
3333 // Forward pass to compute all the nodes and Tensors that are transitively a
3334 // function of x.
3335 const tensorsFromX = {};
3336 const nodesFromX = {};
3337 for (let i = 0; i < xs.length; i++) {
3338 tensorsFromX[xs[i].id] = true;
3339 }
3340 for (let i = 0; i < tape.length; i++) {
3341 const node = tape[i];
3342 const nodeInputs = node.inputs;
3343 for (const inputName in nodeInputs) {
3344 const input = nodeInputs[inputName];
3345 let anyInputFromX = false;
3346 for (let j = 0; j < xs.length; j++) {
3347 if (tensorsFromX[input.id]) {
3348 node.outputs.forEach(output => tensorsFromX[output.id] = true);
3349 anyInputFromX = true;
3350 nodesFromX[node.id] = true;
3351 break;
3352 }
3353 }
3354 if (anyInputFromX) {
3355 break;
3356 }
3357 }
3358 }
3359 // Backward pass to find all of the nodes and Tensors that lead to y.
3360 const tensorsLeadToY = {};
3361 tensorsLeadToY[y.id] = true;
3362 const nodesToY = {};
3363 for (let i = tape.length - 1; i >= 0; i--) {
3364 const node = tape[i];
3365 const nodeInputs = node.inputs;
3366 // If any of the outputs lead to y, mark all of the inputs as leading to y.
3367 for (let j = 0; j < node.outputs.length; j++) {
3368 if (tensorsLeadToY[node.outputs[j].id]) {
3369 for (const inputName in nodeInputs) {
3370 tensorsLeadToY[nodeInputs[inputName].id] = true;
3371 nodesToY[node.id] = true;
3372 }
3373 break;
3374 }
3375 }
3376 }
3377 // Return the paths that come from x and lead to y.
3378 const filteredTape = [];
3379 for (let i = 0; i < tape.length; i++) {
3380 const node = tape[i];
3381 if (nodesFromX[node.id] && nodesToY[node.id]) {
3382 // Prune the inputs from the node that aren't a function of x.
3383 const prunedInputs = {};
3384 for (const inputName in node.inputs) {
3385 const nodeInput = node.inputs[inputName];
3386 if (tensorsFromX[nodeInput.id]) {
3387 prunedInputs[inputName] = nodeInput;
3388 }
3389 }
3390 // Copy the node and overwrite inputsAndArgs to the pruned version.
3391 const prunedNode = Object.assign({}, node);
3392 prunedNode.inputs = prunedInputs;
3393 prunedNode.outputs = node.outputs;
3394 filteredTape.push(prunedNode);
3395 }
3396 }
3397 return filteredTape;
3398 }
3399 /**
3400 * Backpropagate gradients through the filtered TapeNodes.
3401 *
3402 * @param tensorAccumulatedGradientMap A map of Tensor to its gradient. This map
3403 * is mutated by this method.
3404 * @param filteredTape The filtered TapeNodes to backprop through.
3405 */
3406 function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy, add) {
3407 // Walk the tape backward and keep a map of Tensor to its gradient.
3408 for (let i = filteredTape.length - 1; i >= 0; i--) {
3409 const node = filteredTape[i];
3410 const dys = [];
3411 node.outputs.forEach(o => {
3412 const gradTensor = tensorAccumulatedGradientMap[o.id];
3413 if (gradTensor != null) {
3414 dys.push(gradTensor);
3415 }
3416 else {
3417 // This particular output is not in the back-propagation subgraph, so it
3418 // does not affect the final output, thus we put null for its dy.
3419 dys.push(null);
3420 }
3421 });
3422 if (node.gradient == null) {
3423 throw new Error(`Cannot compute gradient: gradient function not found ` +
3424 `for ${node.kernelName}.`);
3425 }
3426 // Backprop dy through this node and accumulate gradients over the inputs.
3427 const inputGradients = node.gradient(dys);
3428 for (const inputName in node.inputs) {
3429 if (!(inputName in inputGradients)) {
3430 throw new Error(`Cannot backprop through input ${inputName}. ` +
3431 `Available gradients found: ${Object.keys(inputGradients)}.`);
3432 }
3433 // Call the gradient function.
3434 const dx = tidy(() => inputGradients[inputName]());
3435 if (dx.dtype !== 'float32') {
3436 throw new Error(`Error in gradient for op ${node.kernelName}. The gradient of input ` +
3437 `${inputName} must have 'float32' dtype, but has '${dx.dtype}'`);
3438 }
3439 const x = node.inputs[inputName];
3440 if (!arraysEqual(dx.shape, x.shape)) {
3441 throw new Error(`Error in gradient for op ${node.kernelName}. The gradient of input ` +
3442 `'${inputName}' has shape '${dx.shape}', which does not match ` +
3443 `the shape of the input '${x.shape}'`);
3444 }
3445 if (tensorAccumulatedGradientMap[x.id] == null) {
3446 tensorAccumulatedGradientMap[x.id] = dx;
3447 }
3448 else {
3449 const curGradient = tensorAccumulatedGradientMap[x.id];
3450 tensorAccumulatedGradientMap[x.id] = add(curGradient, dx);
3451 curGradient.dispose();
3452 }
3453 }
3454 }
3455 }
3456
3457 /**
3458 * @license
3459 * Copyright 2018 Google LLC. All Rights Reserved.
3460 * Licensed under the Apache License, Version 2.0 (the "License");
3461 * you may not use this file except in compliance with the License.
3462 * You may obtain a copy of the License at
3463 *
3464 * http://www.apache.org/licenses/LICENSE-2.0
3465 *
3466 * Unless required by applicable law or agreed to in writing, software
3467 * distributed under the License is distributed on an "AS IS" BASIS,
3468 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3469 * See the License for the specific language governing permissions and
3470 * limitations under the License.
3471 * =============================================================================
3472 */
    // Maximum number of values in one dimension before we decide to show
    // ellipsis when printing tensor values.
    const FORMAT_LIMIT_NUM_VALS = 20;
    // Number of first and last values to show when displaying a, b,...,y, z.
    const FORMAT_NUM_FIRST_LAST_VALS = 3;
    // Number of significant digits to show when formatting numeric values.
    const FORMAT_NUM_SIG_DIGITS = 7;
3479 function tensorToString(vals, shape, dtype, verbose) {
3480 const strides = computeStrides(shape);
3481 const padPerCol = computeMaxSizePerColumn(vals, shape, dtype, strides);
3482 const rank = shape.length;
3483 const valsLines = subTensorToString(vals, shape, dtype, strides, padPerCol);
3484 const lines = ['Tensor'];
3485 if (verbose) {
3486 lines.push(` dtype: ${dtype}`);
3487 lines.push(` rank: ${rank}`);
3488 lines.push(` shape: [${shape}]`);
3489 lines.push(` values:`);
3490 }
3491 lines.push(valsLines.map(l => ' ' + l).join('\n'));
3492 return lines.join('\n');
3493 }
3494 function computeMaxSizePerColumn(vals, shape, dtype, strides) {
3495 const n = sizeFromShape(shape);
3496 const numCols = strides[strides.length - 1];
3497 const padPerCol = new Array(numCols).fill(0);
3498 const rank = shape.length;
3499 const valuesOrTuples = dtype === 'complex64' ? createComplexTuples(vals) : vals;
3500 if (rank > 1) {
3501 for (let row = 0; row < n / numCols; row++) {
3502 const offset = row * numCols;
3503 for (let j = 0; j < numCols; j++) {
3504 padPerCol[j] = Math.max(padPerCol[j], valToString(valuesOrTuples[offset + j], 0, dtype).length);
3505 }
3506 }
3507 }
3508 return padPerCol;
3509 }
3510 function valToString(val, pad, dtype) {
3511 let valStr;
3512 if (Array.isArray(val)) {
3513 valStr = `${parseFloat(val[0].toFixed(FORMAT_NUM_SIG_DIGITS))} + ` +
3514 `${parseFloat(val[1].toFixed(FORMAT_NUM_SIG_DIGITS))}j`;
3515 }
3516 else if (isString(val)) {
3517 valStr = `'${val}'`;
3518 }
3519 else if (dtype === 'bool') {
3520 valStr = boolNumToString(val);
3521 }
3522 else {
3523 valStr = parseFloat(val.toFixed(FORMAT_NUM_SIG_DIGITS)).toString();
3524 }
3525 return rightPad(valStr, pad);
3526 }
3527 function boolNumToString(v) {
3528 return v === 0 ? 'false' : 'true';
3529 }
    /**
     * Recursively renders one (sub-)tensor as an array of display lines.
     *
     * @param vals Flat values of this subtensor (complex64 stores re/im pairs).
     * @param shape Shape of this subtensor.
     * @param dtype Dtype, used to pick the value formatter.
     * @param strides Strides matching `shape`.
     * @param padPerCol Per-column pad widths from computeMaxSizePerColumn.
     * @param isLast Whether this is the final sibling at its nesting level
     *     (controls the trailing separator/newlines appended to the last line).
     */
    function subTensorToString(vals, shape, dtype, strides, padPerCol, isLast = true) {
        // complex64 occupies 2 flat slots (re, im) per logical element.
        const storagePerElement = dtype === 'complex64' ? 2 : 1;
        const size = shape[0];
        const rank = shape.length;
        if (rank === 0) {
            // Scalar: one formatted value, no brackets.
            if (dtype === 'complex64') {
                const complexTuple = createComplexTuples(vals);
                return [valToString(complexTuple[0], 0, dtype)];
            }
            if (dtype === 'bool') {
                return [boolNumToString(vals[0])];
            }
            return [vals[0].toString()];
        }
        if (rank === 1) {
            if (size > FORMAT_LIMIT_NUM_VALS) {
                // Long vector: show the first/last few values around an ellipsis.
                const firstValsSize = FORMAT_NUM_FIRST_LAST_VALS * storagePerElement;
                let firstVals = Array.from(vals.slice(0, firstValsSize));
                let lastVals = Array.from(vals.slice((size - FORMAT_NUM_FIRST_LAST_VALS) * storagePerElement, size * storagePerElement));
                if (dtype === 'complex64') {
                    firstVals = createComplexTuples(firstVals);
                    lastVals = createComplexTuples(lastVals);
                }
                return [
                    '[' +
                        firstVals.map((x, i) => valToString(x, padPerCol[i], dtype))
                            .join(', ') +
                        ', ..., ' +
                        lastVals
                            .map((x, i) => valToString(x, padPerCol[size - FORMAT_NUM_FIRST_LAST_VALS + i], dtype))
                            .join(', ') +
                        ']'
                ];
            }
            const displayVals = dtype === 'complex64' ? createComplexTuples(vals) :
                Array.from(vals);
            return [
                '[' +
                    displayVals.map((x, i) => valToString(x, padPerCol[i], dtype))
                        .join(', ') +
                    ']'
            ];
        }
        // The array is rank 2 or more.
        const subshape = shape.slice(1);
        const substrides = strides.slice(1);
        // Flat slots consumed by one slice of the outermost dimension.
        const stride = strides[0] * storagePerElement;
        const lines = [];
        if (size > FORMAT_LIMIT_NUM_VALS) {
            // Too many rows: render the first/last few around a '...' row.
            for (let i = 0; i < FORMAT_NUM_FIRST_LAST_VALS; i++) {
                const start = i * stride;
                const end = start + stride;
                lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, false /* isLast */));
            }
            lines.push('...');
            for (let i = size - FORMAT_NUM_FIRST_LAST_VALS; i < size; i++) {
                const start = i * stride;
                const end = start + stride;
                lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */));
            }
        }
        else {
            for (let i = 0; i < size; i++) {
                const start = i * stride;
                const end = start + stride;
                lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */));
            }
        }
        // Wrap the child lines in brackets; rank-2 rows are comma-separated,
        // higher-rank slabs get extra blank lines between them.
        const sep = rank === 2 ? ',' : '';
        lines[0] = '[' + (size > 0 ? lines[0] + sep : '');
        for (let i = 1; i < lines.length - 1; i++) {
            lines[i] = ' ' + lines[i] + sep;
        }
        let newLineSep = ',\n';
        for (let i = 2; i < rank; i++) {
            newLineSep += '\n';
        }
        lines[lines.length - 1] =
            ' ' + lines[lines.length - 1] + ']' + (isLast ? '' : newLineSep);
        return lines;
    }
3611 function createComplexTuples(vals) {
3612 const complexTuples = [];
3613 for (let i = 0; i < vals.length; i += 2) {
3614 complexTuples.push([vals[i], vals[i + 1]]);
3615 }
3616 return complexTuples;
3617 }
3618
3619 /**
3620 * @license
3621 * Copyright 2017 Google LLC. All Rights Reserved.
3622 * Licensed under the Apache License, Version 2.0 (the "License");
3623 * you may not use this file except in compliance with the License.
3624 * You may obtain a copy of the License at
3625 *
3626 * http://www.apache.org/licenses/LICENSE-2.0
3627 *
3628 * Unless required by applicable law or agreed to in writing, software
3629 * distributed under the License is distributed on an "AS IS" BASIS,
3630 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3631 * See the License for the specific language governing permissions and
3632 * limitations under the License.
3633 * =============================================================================
3634 */
3635 /**
3636 * A mutable object, similar to `tf.Tensor`, that allows users to set values
3637 * at locations before converting to an immutable `tf.Tensor`.
3638 *
3639 * See `tf.buffer` for creating a tensor buffer.
3640 *
3641 * @doc {heading: 'Tensors', subheading: 'Classes'}
3642 */
3643 class TensorBuffer {
3644 constructor(shape, dtype, values) {
3645 this.dtype = dtype;
3646 this.shape = shape.slice();
3647 this.size = sizeFromShape(shape);
3648 if (values != null) {
3649 const n = values.length;
3650 assert$1(n === this.size, () => `Length of values '${n}' does not match the size ` +
3651 `inferred by the shape '${this.size}'.`);
3652 }
3653 if (dtype === 'complex64') {
3654 throw new Error(`complex64 dtype TensorBuffers are not supported. Please create ` +
3655 `a TensorBuffer for the real and imaginary parts separately and ` +
3656 `call tf.complex(real, imag).`);
3657 }
3658 this.values = values || getArrayFromDType(dtype, this.size);
3659 this.strides = computeStrides(shape);
3660 }
3661 /**
3662 * Sets a value in the buffer at a given location.
3663 *
3664 * @param value The value to set.
3665 * @param locs The location indices.
3666 *
3667 * @doc {heading: 'Tensors', subheading: 'Creation'}
3668 */
3669 set(value, ...locs) {
3670 if (locs.length === 0) {
3671 locs = [0];
3672 }
3673 assert$1(locs.length === this.rank, () => `The number of provided coordinates (${locs.length}) must ` +
3674 `match the rank (${this.rank})`);
3675 const index = this.locToIndex(locs);
3676 this.values[index] = value;
3677 }
3678 /**
3679 * Returns the value in the buffer at the provided location.
3680 *
3681 * @param locs The location indices.
3682 *
3683 * @doc {heading: 'Tensors', subheading: 'Creation'}
3684 */
3685 get(...locs) {
3686 if (locs.length === 0) {
3687 locs = [0];
3688 }
3689 let i = 0;
3690 for (const loc of locs) {
3691 if (loc < 0 || loc >= this.shape[i]) {
3692 const msg = `Requested out of range element at ${locs}. ` +
3693 ` Buffer shape=${this.shape}`;
3694 throw new Error(msg);
3695 }
3696 i++;
3697 }
3698 let index = locs[locs.length - 1];
3699 for (let i = 0; i < locs.length - 1; ++i) {
3700 index += this.strides[i] * locs[i];
3701 }
3702 return this.values[index];
3703 }
3704 locToIndex(locs) {
3705 if (this.rank === 0) {
3706 return 0;
3707 }
3708 else if (this.rank === 1) {
3709 return locs[0];
3710 }
3711 let index = locs[locs.length - 1];
3712 for (let i = 0; i < locs.length - 1; ++i) {
3713 index += this.strides[i] * locs[i];
3714 }
3715 return index;
3716 }
3717 indexToLoc(index) {
3718 if (this.rank === 0) {
3719 return [];
3720 }
3721 else if (this.rank === 1) {
3722 return [index];
3723 }
3724 const locs = new Array(this.shape.length);
3725 for (let i = 0; i < locs.length - 1; ++i) {
3726 locs[i] = Math.floor(index / this.strides[i]);
3727 index -= locs[i] * this.strides[i];
3728 }
3729 locs[locs.length - 1] = index;
3730 return locs;
3731 }
3732 get rank() {
3733 return this.shape.length;
3734 }
3735 /**
3736 * Creates an immutable `tf.Tensor` object from the buffer.
3737 *
3738 * @doc {heading: 'Tensors', subheading: 'Creation'}
3739 */
3740 toTensor() {
3741 return trackerFn().makeTensor(this.values, this.shape, this.dtype);
3742 }
3743 }
    // For tracking tensor creation and disposal. Injected via setTensorTracker()
    // below so this module stays a dependency-free leaf.
    let trackerFn = null;
    // Used by chaining methods to call into ops. Injected via setOpHandler().
    let opHandler$1 = null;
    // Used to warn about deprecated methods. Injected via
    // setDeprecationWarningFn(); not invoked anywhere in this chunk.
    let deprecationWarningFn = null;
    // This here so that we can use this method on dev branches and keep the
    // functionality at master. (The bare expression only silences the
    // unused-variable lint.)
    // tslint:disable-next-line:no-unused-expression
    [deprecationWarningFn];
3754 /**
3755 * An external consumer can register itself as the tensor tracker. This way
3756 * the Tensor class can notify the tracker for every tensor created and
3757 * disposed.
3758 */
3759 function setTensorTracker(fn) {
3760 trackerFn = fn;
3761 }
3762 /**
3763 * An external consumer can register itself as the op handler. This way the
3764 * Tensor class can have chaining methods that call into ops via the op
3765 * handler.
3766 */
3767 function setOpHandler(handler) {
3768 opHandler$1 = handler;
3769 }
3770 /**
3771 * Sets the deprecation warning function to be used by this file. This way the
3772 * Tensor class can be a leaf but still use the environment.
3773 */
3774 function setDeprecationWarningFn(fn) {
3775 deprecationWarningFn = fn;
3776 }
3777 /**
3778 * A `tf.Tensor` object represents an immutable, multidimensional array of
3779 * numbers that has a shape and a data type.
3780 *
3781 * For performance reasons, functions that create tensors do not necessarily
3782 * perform a copy of the data passed to them (e.g. if the data is passed as a
3783 * `Float32Array`), and changes to the data will change the tensor. This is not
3784 * a feature and is not supported. To avoid this behavior, use the tensor before
3785 * changing the input data or create a copy with `copy = tf.add(yourTensor, 0)`.
3786 *
3787 * See `tf.tensor` for details on how to create a `tf.Tensor`.
3788 *
3789 * @doc {heading: 'Tensors', subheading: 'Classes'}
3790 */
3791 class Tensor {
3792 constructor(shape, dtype, dataId, id) {
3793 /** Whether this tensor has been globally kept. */
3794 this.kept = false;
3795 this.isDisposedInternal = false;
3796 this.shape = shape.slice();
3797 this.dtype = dtype || 'float32';
3798 this.size = sizeFromShape(shape);
3799 this.strides = computeStrides(shape);
3800 this.dataId = dataId;
3801 this.id = id;
3802 this.rankType = (this.rank < 5 ? this.rank.toString() : 'higher');
3803 }
3804 get rank() {
3805 return this.shape.length;
3806 }
3807 /**
3808 * Returns a promise of `tf.TensorBuffer` that holds the underlying data.
3809 *
3810 * @doc {heading: 'Tensors', subheading: 'Classes'}
3811 */
3812 async buffer() {
3813 const vals = await this.data();
3814 return opHandler$1.buffer(this.shape, this.dtype, vals);
3815 }
3816 /**
3817 * Returns a `tf.TensorBuffer` that holds the underlying data.
3818 * @doc {heading: 'Tensors', subheading: 'Classes'}
3819 */
3820 bufferSync() {
3821 return opHandler$1.buffer(this.shape, this.dtype, this.dataSync());
3822 }
3823 /**
3824 * Returns the tensor data as a nested array. The transfer of data is done
3825 * asynchronously.
3826 *
3827 * @doc {heading: 'Tensors', subheading: 'Classes'}
3828 */
3829 async array() {
3830 const vals = await this.data();
3831 return toNestedArray(this.shape, vals, this.dtype === 'complex64');
3832 }
3833 /**
3834 * Returns the tensor data as a nested array. The transfer of data is done
3835 * synchronously.
3836 *
3837 * @doc {heading: 'Tensors', subheading: 'Classes'}
3838 */
3839 arraySync() {
3840 return toNestedArray(this.shape, this.dataSync(), this.dtype === 'complex64');
3841 }
3842 /**
3843 * Asynchronously downloads the values from the `tf.Tensor`. Returns a
3844 * promise of `TypedArray` that resolves when the computation has finished.
3845 *
3846 * @doc {heading: 'Tensors', subheading: 'Classes'}
3847 */
3848 async data() {
3849 this.throwIfDisposed();
3850 const data = trackerFn().read(this.dataId);
3851 if (this.dtype === 'string') {
3852 const bytes = await data;
3853 try {
3854 return bytes.map(b => decodeString(b));
3855 }
3856 catch (_a) {
3857 throw new Error('Failed to decode the string bytes into utf-8. ' +
3858 'To get the original bytes, call tensor.bytes().');
3859 }
3860 }
3861 return data;
3862 }
        /**
         * Copy the tensor's data to a new GPU resource. Comparing to the `dataSync()`
         * and `data()`, this method prevents data from being downloaded to CPU.
         *
         * For WebGL backend, the data will be stored on a densely packed texture.
         * This means that the texture will use the RGBA channels to store value.
         *
         * For WebGPU backend, the data will be stored on a buffer. There is no
         * parameter, so can not use a user-defined size to create the buffer.
         *
         * @param options:
         *     For WebGL,
         *         - customTexShape: Optional. If set, will use the user defined
         *     texture shape to create the texture.
         *
         * @returns For WebGL backend, a GPUData contains the new texture and
         *     its information.
         *     {
         *        tensorRef: The tensor that is associated with this texture,
         *        texture: WebGLTexture,
         *        texShape: [number, number] // [height, width]
         *     }
         *
         *     For WebGPU backend, a GPUData contains the new buffer.
         *     {
         *        tensorRef: The tensor that is associated with this buffer,
         *        buffer: GPUBuffer,
         *     }
         *
         * Remember to dispose the GPUData after it is used by
         * `res.tensorRef.dispose()`.
         *
         * @doc {heading: 'Tensors', subheading: 'Classes'}
         */
        dataToGPU(options) {
            this.throwIfDisposed();
            // Delegates to the engine, which forwards to the active backend's
            // readToGPU implementation.
            return trackerFn().readToGPU(this.dataId, options);
        }
3901 /**
3902 * Synchronously downloads the values from the `tf.Tensor`. This blocks the
3903 * UI thread until the values are ready, which can cause performance issues.
3904 *
3905 * @doc {heading: 'Tensors', subheading: 'Classes'}
3906 */
3907 dataSync() {
3908 this.throwIfDisposed();
3909 const data = trackerFn().readSync(this.dataId);
3910 if (this.dtype === 'string') {
3911 try {
3912 return data.map(b => decodeString(b));
3913 }
3914 catch (_a) {
3915 throw new Error('Failed to decode the string bytes into utf-8. ' +
3916 'To get the original bytes, call tensor.bytes().');
3917 }
3918 }
3919 return data;
3920 }
3921 /** Returns the underlying bytes of the tensor's data. */
3922 async bytes() {
3923 this.throwIfDisposed();
3924 const data = await trackerFn().read(this.dataId);
3925 if (this.dtype === 'string') {
3926 return data;
3927 }
3928 else {
3929 return new Uint8Array(data.buffer);
3930 }
3931 }
3932 /**
3933 * Disposes `tf.Tensor` from memory.
3934 *
3935 * @doc {heading: 'Tensors', subheading: 'Classes'}
3936 */
3937 dispose() {
3938 if (this.isDisposed) {
3939 return;
3940 }
3941 if (this.kerasMask) {
3942 this.kerasMask.dispose();
3943 }
3944 trackerFn().disposeTensor(this);
3945 this.isDisposedInternal = true;
3946 }
        /** Whether this tensor's memory has been released (see `dispose()`). */
        get isDisposed() {
            return this.isDisposedInternal;
        }
        /** Guard used by data-access methods; throws if the tensor was disposed. */
        throwIfDisposed() {
            if (this.isDisposed) {
                throw new Error(`Tensor is disposed.`);
            }
        }
        /**
         * Prints the `tf.Tensor`. See `tf.print` for details.
         *
         * @param verbose Whether to print verbose information about the tensor,
         *    including dtype and size.
         *
         * @doc {heading: 'Tensors', subheading: 'Classes'}
         */
        print(verbose = false) {
            // Delegates to the op handler to avoid a circular dependency on ops.
            return opHandler$1.print(this, verbose);
        }
        /**
         * Returns a copy of the tensor. See `tf.clone` for details.
         * @doc {heading: 'Tensors', subheading: 'Classes'}
         */
        clone() {
            this.throwIfDisposed();
            // Delegates to the op handler to avoid a circular dependency on ops.
            return opHandler$1.clone(this);
        }
3974 /**
3975 * Returns a human-readable description of the tensor. Useful for logging.
3976 *
3977 * @doc {heading: 'Tensors', subheading: 'Classes'}
3978 */
3979 toString(verbose = false) {
3980 const vals = this.dataSync();
3981 return tensorToString(vals, this.shape, this.dtype, verbose);
3982 }
        /** Casts this tensor to a new dtype. See `tf.cast` for details. */
        cast(dtype) {
            this.throwIfDisposed();
            return opHandler$1.cast(this, dtype);
        }
        /** Wraps this tensor in a `tf.Variable`. See `tf.variable` for details. */
        variable(trainable = true, name, dtype) {
            this.throwIfDisposed();
            return trackerFn().makeVariable(this, trainable, name, dtype);
        }
3991 }
    // Customizes `x instanceof Tensor` with duck typing so that tensors created
    // by a different copy of the library (or a transpiled one) still match.
    Object.defineProperty(Tensor, Symbol.hasInstance, {
        value: (instance) => {
            // Implementation note: we should use properties of the object that will be
            // defined before the constructor body has finished executing (methods).
            // This is because when this code is transpiled by babel, babel will call
            // classCallCheck before the constructor body is run.
            // See https://github.com/tensorflow/tfjs/issues/3384 for backstory.
            return !!instance && instance.data != null && instance.dataSync != null &&
                instance.throwIfDisposed != null;
        }
    });
    /** Returns the shared, cross-package `Tensor` class (see comment below). */
    function getGlobalTensorClass() {
        // Use getGlobal so that we can augment the Tensor class across package
        // boundaries because the node resolution alg may result in different modules
        // being returned for this file depending on the path they are loaded from.
        return getGlobal('Tensor', () => {
            return Tensor;
        });
    }
    // Global side effect. Cache global reference to Tensor class
    getGlobalTensorClass();
    /**
     * A mutable `tf.Tensor`, useful for persisting state, e.g. for training.
     *
     * @doc {heading: 'Tensors', subheading: 'Classes'}
     */
    class Variable extends Tensor {
        constructor(initialValue, trainable, name, tensorId) {
            // A variable aliases its initial value's dataId; no copy is made here.
            super(initialValue.shape, initialValue.dtype, initialValue.dataId, tensorId);
            this.trainable = trainable;
            this.name = name;
        }
        /**
         * Assign a new `tf.Tensor` to this variable. The new `tf.Tensor` must have
         * the same shape and dtype as the old `tf.Tensor`.
         *
         * @param newValue New tensor to be assigned to this variable.
         *
         * @doc {heading: 'Tensors', subheading: 'Classes'}
         */
        assign(newValue) {
            if (newValue.dtype !== this.dtype) {
                throw new Error(`dtype of the new value (${newValue.dtype}) and ` +
                    `previous value (${this.dtype}) must match`);
            }
            if (!arraysEqual(newValue.shape, this.shape)) {
                throw new Error(`shape of the new value (${newValue.shape}) and ` +
                    `previous value (${this.shape}) must match`);
            }
            // Order matters: release the old data first, then alias the new
            // tensor's data and bump its reference count in the engine.
            trackerFn().disposeTensor(this);
            this.dataId = newValue.dataId;
            trackerFn().incRef(this, null /* backend */);
        }
        dispose() {
            // Variables are tracked separately from plain tensors by the engine.
            trackerFn().disposeVariable(this);
            this.isDisposedInternal = true;
        }
    }
    // Duck-typed `instanceof Variable`: any Tensor with an `assign` method
    // qualifies, so variables survive crossing package/copy boundaries.
    Object.defineProperty(Variable, Symbol.hasInstance, {
        value: (instance) => {
            return instance instanceof Tensor && instance.assign != null &&
                instance.assign instanceof Function;
        }
    });
4056
4057 /**
4058 * @license
4059 * Copyright 2017 Google LLC. All Rights Reserved.
4060 * Licensed under the Apache License, Version 2.0 (the "License");
4061 * you may not use this file except in compliance with the License.
4062 * You may obtain a copy of the License at
4063 *
4064 * http://www.apache.org/licenses/LICENSE-2.0
4065 *
4066 * Unless required by applicable law or agreed to in writing, software
4067 * distributed under the License is distributed on an "AS IS" BASIS,
4068 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
4069 * See the License for the specific language governing permissions and
4070 * limitations under the License.
4071 * =============================================================================
4072 */
    // String-valued rank enum (TypeScript enum output): Rank.R0 === 'R0', ...,
    // Rank.R6 === 'R6'. Exposed on the public API as `tf.Rank`.
    exports.Rank = void 0;
    (function (Rank) {
        Rank["R0"] = "R0";
        Rank["R1"] = "R1";
        Rank["R2"] = "R2";
        Rank["R3"] = "R3";
        Rank["R4"] = "R4";
        Rank["R5"] = "R5";
        Rank["R6"] = "R6";
    })(exports.Rank || (exports.Rank = {}));
4083 // Looks for upcasting types. Used, for example, in operations with mixed dtype
4084 // inputs.
4085 var UpcastInt32AndMap;
4086 (function (UpcastInt32AndMap) {
4087 UpcastInt32AndMap["float32"] = "float32";
4088 UpcastInt32AndMap["int32"] = "int32";
4089 UpcastInt32AndMap["bool"] = "int32";
4090 UpcastInt32AndMap["complex64"] = "complex64";
4091 })(UpcastInt32AndMap || (UpcastInt32AndMap = {}));
4092 var UpcastBoolAndMap;
4093 (function (UpcastBoolAndMap) {
4094 UpcastBoolAndMap["float32"] = "float32";
4095 UpcastBoolAndMap["int32"] = "int32";
4096 UpcastBoolAndMap["bool"] = "bool";
4097 UpcastBoolAndMap["complex64"] = "complex64";
4098 })(UpcastBoolAndMap || (UpcastBoolAndMap = {}));
4099 var UpcastFloat32AndMap;
4100 (function (UpcastFloat32AndMap) {
4101 UpcastFloat32AndMap["float32"] = "float32";
4102 UpcastFloat32AndMap["int32"] = "float32";
4103 UpcastFloat32AndMap["bool"] = "float32";
4104 UpcastFloat32AndMap["complex64"] = "complex64";
4105 })(UpcastFloat32AndMap || (UpcastFloat32AndMap = {}));
4106 var UpcastComplex64AndMap;
4107 (function (UpcastComplex64AndMap) {
4108 UpcastComplex64AndMap["float32"] = "complex64";
4109 UpcastComplex64AndMap["int32"] = "complex64";
4110 UpcastComplex64AndMap["bool"] = "complex64";
4111 UpcastComplex64AndMap["complex64"] = "complex64";
4112 })(UpcastComplex64AndMap || (UpcastComplex64AndMap = {}));
4113 const upcastTypeMap = {
4114 'float32': UpcastFloat32AndMap,
4115 'int32': UpcastInt32AndMap,
4116 'bool': UpcastBoolAndMap,
4117 'complex64': UpcastComplex64AndMap
4118 };
4119 function upcastType(typeA, typeB) {
4120 if (typeA === 'string' || typeB === 'string') {
4121 if (typeA === 'string' && typeB === 'string') {
4122 return 'string';
4123 }
4124 throw new Error(`Can not upcast ${typeA} with ${typeB}`);
4125 }
4126 return upcastTypeMap[typeA][typeB];
4127 }
4128 /** Returns the output type after summation. */
4129 function sumOutType(type) {
4130 return upcastType(type, 'int32');
4131 }
4132 function isWebGLData(values) {
4133 return values != null && typeof values === 'object' && 'texture' in values &&
4134 values.texture instanceof WebGLTexture;
4135 }
4136 function isWebGPUData(values) {
4137 return typeof GPUBuffer !== 'undefined' && values != null &&
4138 typeof values === 'object' && 'buffer' in values &&
4139 values.buffer instanceof GPUBuffer;
4140 }
4141
4142 /**
4143 * @license
4144 * Copyright 2018 Google LLC. All Rights Reserved.
4145 * Licensed under the Apache License, Version 2.0 (the "License");
4146 * you may not use this file except in compliance with the License.
4147 * You may obtain a copy of the License at
4148 *
4149 * http://www.apache.org/licenses/LICENSE-2.0
4150 *
4151 * Unless required by applicable law or agreed to in writing, software
4152 * distributed under the License is distributed on an "AS IS" BASIS,
4153 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
4154 * See the License for the specific language governing permissions and
4155 * limitations under the License.
4156 * =============================================================================
4157 */
4158 function makeTypesMatch(a, b) {
4159 if (a.dtype === b.dtype) {
4160 return [a, b];
4161 }
4162 const dtype = upcastType(a.dtype, b.dtype);
4163 return [a.cast(dtype), b.cast(dtype)];
4164 }
    /** Asserts that `a` and `b` have identical dtypes; throws otherwise. */
    function assertTypesMatch(a, b) {
        assert$1(a.dtype === b.dtype, () => `The dtypes of the first(${a.dtype}) and` +
            ` second(${b.dtype}) input must match`);
    }
4169 function isTensorInList(tensor, tensorList) {
4170 return tensorList.some(x => x.id === tensor.id);
4171 }
4172 /**
4173 * Extracts any `Tensor`s found within the provided object.
4174 *
4175 * @param container an object that may be a `Tensor` or may directly contain
4176 * `Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. In general it
4177 * is safe to pass any object here, except that `Promise`s are not
4178 * supported.
4179 * @returns An array of `Tensors` found within the passed object. If the
4180 * argument is simply a `Tensor', a list containing that `Tensor` is
4181 * returned. If the object is not a `Tensor` or does not
4182 * contain `Tensors`, an empty list is returned.
4183 */
4184 function getTensorsInContainer(result) {
4185 const list = [];
4186 const seen = new Set();
4187 walkTensorContainer(result, list, seen);
4188 return list;
4189 }
4190 function walkTensorContainer(container, list, seen) {
4191 if (container == null) {
4192 return;
4193 }
4194 if (container instanceof Tensor) {
4195 list.push(container);
4196 return;
4197 }
4198 if (!isIterable$1(container)) {
4199 return;
4200 }
4201 // Iteration over keys works also for arrays.
4202 const iterable = container;
4203 for (const k in iterable) {
4204 const val = iterable[k];
4205 if (!seen.has(val)) {
4206 seen.add(val);
4207 walkTensorContainer(val, list, seen);
4208 }
4209 }
4210 }
4211 // tslint:disable-next-line:no-any
4212 function isIterable$1(obj) {
4213 return Array.isArray(obj) || typeof obj === 'object';
4214 }
4215
    // Frozen namespace object, re-exported on the public API as `tf.tensor_util`.
    var tensor_util = /*#__PURE__*/Object.freeze({
        __proto__: null,
        assertTypesMatch: assertTypesMatch,
        getTensorsInContainer: getTensorsInContainer,
        isTensorInList: isTensorInList,
        makeTypesMatch: makeTypesMatch
    });
4223
4224 /**
4225 * @license
4226 * Copyright 2018 Google LLC. All Rights Reserved.
4227 * Licensed under the Apache License, Version 2.0 (the "License");
4228 * you may not use this file except in compliance with the License.
4229 * You may obtain a copy of the License at
4230 *
4231 * http://www.apache.org/licenses/LICENSE-2.0
4232 *
4233 * Unless required by applicable law or agreed to in writing, software
4234 * distributed under the License is distributed on an "AS IS" BASIS,
4235 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
4236 * See the License for the specific language governing permissions and
4237 * limitations under the License.
4238 * =============================================================================
4239 */
4240 function isRegisteredKernelInvocation(kernelInvocation) {
4241 return kernelInvocation.kernelName != null;
4242 }
    /** Mutable bookkeeping state owned by the `Engine` (memory, tape, scopes). */
    class EngineState {
        constructor() {
            // Public since optimizers will use it.
            this.registeredVariables = {};
            this.nextTapeNodeId = 0;
            // Running totals used by tf.memory().
            this.numBytes = 0;
            this.numTensors = 0;
            this.numStringTensors = 0;
            this.numDataBuffers = 0;
            // Number of nested tf.grad() statements when computing higher-order
            // gradients. E.g. `1` for first-order gradients and `2` for second-order
            // gradients. Used to track if the tape should be removed after a backprop.
            this.gradientDepth = 0;
            // Number of nested kernel calls. When kernel depth is greater than 1, we turn
            // off the tape.
            this.kernelDepth = 0;
            // Stack of tidy() scopes; the innermost is the active scope.
            this.scopeStack = [];
            /**
             * Keeps track of the number of data moves during a kernel execution. We
             * maintain a stack since kernels can call other kernels, recursively.
             */
            this.numDataMovesStack = [];
            this.nextScopeId = 0;
            // dataId -> {backend, shape, dtype}; WeakMap so entries die with dataIds.
            this.tensorInfo = new WeakMap();
            this.profiling = false;
            this.activeProfile = {
                newBytes: 0,
                newTensors: 0,
                peakBytes: 0,
                kernels: [],
                result: null,
                // Derived view: unique kernel names seen during this profile.
                get kernelNames() {
                    return Array.from(new Set(this.kernels.map(k => k.name)));
                }
            };
        }
        dispose() {
            // Releases every registered variable; other state is plain data.
            for (const variableName in this.registeredVariables) {
                this.registeredVariables[variableName].dispose();
            }
        }
    }
4285 class Engine {
        constructor(ENV) {
            this.ENV = ENV;
            // backendName -> initialized backend instance.
            this.registry = {};
            // backendName -> {factory, priority}; factories run lazily.
            this.registryFactory = {};
            // Monotonic id used to invalidate stale async backend initializations.
            this.pendingBackendInitId = 0;
            this.state = new EngineState();
        }
4293 async ready() {
4294 if (this.pendingBackendInit != null) {
4295 return this.pendingBackendInit.then(() => { });
4296 }
4297 if (this.backendInstance != null) {
4298 return;
4299 }
4300 const sortedBackends = this.getSortedBackends();
4301 for (let i = 0; i < sortedBackends.length; i++) {
4302 const backendName = sortedBackends[i];
4303 const success = await this.initializeBackend(backendName).success;
4304 if (success) {
4305 await this.setBackend(backendName);
4306 return;
4307 }
4308 }
4309 throw new Error(`Could not initialize any backends, all backend initializations ` +
4310 `failed.`);
4311 }
        /**
         * The active backend instance. NOTE: this getter has side effects — if no
         * backend is set it synchronously initializes the best available one, and
         * it throws when initialization is still pending (async).
         */
        get backend() {
            if (this.pendingBackendInit != null) {
                throw new Error(`Backend '${this.backendName}' has not yet been initialized. Make ` +
                    `sure to await tf.ready() or await tf.setBackend() before calling ` +
                    `other methods`);
            }
            if (this.backendInstance == null) {
                const { name, asyncInit } = this.initializeBackendsAndReturnBest();
                if (asyncInit) {
                    throw new Error(`The highest priority backend '${name}' has not yet been ` +
                        `initialized. Make sure to await tf.ready() or ` +
                        `await tf.setBackend() before calling other methods`);
                }
                // Sync-initialized backend: safe to set without awaiting.
                this.setBackend(name);
            }
            return this.backendInstance;
        }
        /** Names of all registered backend factories (initialized or not). */
        backendNames() {
            return Object.keys(this.registryFactory);
        }
4332 findBackend(backendName) {
4333 if (!(backendName in this.registry)) {
4334 // If the backend hasn't been initialized but we have a registry entry for
4335 // it, initialize it and return it.
4336 if (backendName in this.registryFactory) {
4337 const { asyncInit } = this.initializeBackend(backendName);
4338 if (asyncInit) {
4339 // Backend is not ready yet.
4340 return null;
4341 }
4342 }
4343 else {
4344 return null;
4345 }
4346 }
4347 return this.registry[backendName];
4348 }
        /** Returns the registered factory for `backendName`, or null if absent. */
        findBackendFactory(backendName) {
            if (!(backendName in this.registryFactory)) {
                return null;
            }
            return this.registryFactory[backendName].factory;
        }
4355 registerBackend(backendName, factory, priority = 1) {
4356 if (backendName in this.registryFactory) {
4357 warn(`${backendName} backend was already registered. ` +
4358 `Reusing existing backend factory.`);
4359 return false;
4360 }
4361 this.registryFactory[backendName] = { factory, priority };
4362 return true;
4363 }
        /**
         * Makes `backendName` the active backend, initializing it (possibly
         * asynchronously) if needed. Resolves to false when initialization fails.
         */
        async setBackend(backendName) {
            if (this.registryFactory[backendName] == null) {
                throw new Error(`Backend name '${backendName}' not found in registry`);
            }
            // Record the requested name before init so error messages and
            // kernel lookups reference it even while initialization is pending.
            this.backendName = backendName;
            if (this.registry[backendName] == null) {
                this.backendInstance = null;
                const { success, asyncInit } = this.initializeBackend(backendName);
                const result = asyncInit ? await success : success;
                if (!result) {
                    return false;
                }
            }
            this.backendInstance = this.registry[backendName];
            this.setupRegisteredKernels();
            // Reset the profiler.
            this.profiler = new Profiler(this.backendInstance);
            return true;
        }
4383 setupRegisteredKernels() {
4384 const kernels = getKernelsForBackend(this.backendName);
4385 kernels.forEach(kernel => {
4386 if (kernel.setupFunc != null) {
4387 kernel.setupFunc(this.backendInstance);
4388 }
4389 });
4390 }
4391 disposeRegisteredKernels(backendName) {
4392 const kernels = getKernelsForBackend(backendName);
4393 kernels.forEach(kernel => {
4394 if (kernel.disposeFunc != null) {
4395 kernel.disposeFunc(this.registry[backendName]);
4396 }
4397 });
4398 }
        /**
         * Initializes a backend by looking up the backend name in the factory
         * registry and calling the factory method. Returns a boolean representing
         * whether the initialization of the backend succeeded. Throws an error if
         * there is no backend in the factory registry.
         *
         * For async factories, `success` is a Promise<boolean> and
         * `asyncInit` is true; `pendingBackendInitId` is used to invalidate this
         * init if another backend is set (or this one removed) in the meantime.
         */
        initializeBackend(backendName) {
            const registryFactoryEntry = this.registryFactory[backendName];
            if (registryFactoryEntry == null) {
                throw new Error(`Cannot initialize backend ${backendName}, no registration found.`);
            }
            try {
                const backend = registryFactoryEntry.factory();
                /* Test if the factory returns a promise.
                   Done in a more liberal way than
                   previous 'Promise.resolve(backend)===backend'
                   as we needed to account for custom Promise
                   implementations (e.g. Angular) */
                if (backend && !(backend instanceof KernelBackend) &&
                    typeof backend.then === 'function') {
                    const promiseId = ++this.pendingBackendInitId;
                    const success = backend
                        .then(backendInstance => {
                        // Outdated promise. Another backend was set in the meantime.
                        if (promiseId < this.pendingBackendInitId) {
                            return false;
                        }
                        this.registry[backendName] = backendInstance;
                        this.pendingBackendInit = null;
                        return true;
                    })
                        .catch(err => {
                        // Outdated promise. Another backend was set in the meantime.
                        if (promiseId < this.pendingBackendInitId) {
                            return false;
                        }
                        this.pendingBackendInit = null;
                        warn(`Initialization of backend ${backendName} failed`);
                        warn(err.stack || err.message);
                        return false;
                    });
                    this.pendingBackendInit = success;
                    return { success, asyncInit: true };
                }
                else {
                    // Synchronous factory: the backend is ready immediately.
                    this.registry[backendName] = backend;
                    return { success: true, asyncInit: false };
                }
            }
            catch (err) {
                // Synchronous factory failure: report and signal failure.
                warn(`Initialization of backend ${backendName} failed`);
                warn(err.stack || err.message);
                return { success: false, asyncInit: false };
            }
        }
        /**
         * Removes a backend and its factory from the engine, disposing the backend
         * instance if it was initialized and unsetting it if it was active.
         */
        removeBackend(backendName) {
            if (!(backendName in this.registryFactory)) {
                throw new Error(`${backendName} backend not found in registry`);
            }
            if (this.backendName === backendName && this.pendingBackendInit != null) {
                // There is a pending promise of the backend we want to remove. Make it
                // obsolete.
                this.pendingBackendInitId++;
            }
            if (backendName in this.registry) {
                this.disposeRegisteredKernels(backendName);
                this.registry[backendName].dispose();
                delete this.registry[backendName];
            }
            delete this.registryFactory[backendName];
            // Unset the backend if it is active.
            if (this.backendName === backendName) {
                this.pendingBackendInit = null;
                this.backendName = null;
                this.backendInstance = null;
            }
        }
4476 getSortedBackends() {
4477 if (Object.keys(this.registryFactory).length === 0) {
4478 throw new Error('No backend found in registry.');
4479 }
4480 return Object.keys(this.registryFactory).sort((a, b) => {
4481 // Highest priority comes first.
4482 return this.registryFactory[b].priority -
4483 this.registryFactory[a].priority;
4484 });
4485 }
4486 initializeBackendsAndReturnBest() {
4487 const sortedBackends = this.getSortedBackends();
4488 for (let i = 0; i < sortedBackends.length; i++) {
4489 const backendName = sortedBackends[i];
4490 const { success, asyncInit } = this.initializeBackend(backendName);
4491 if (asyncInit || success) {
4492 return { name: backendName, asyncInit };
4493 }
4494 }
4495 throw new Error(`Could not initialize any backends, all backend initializations ` +
4496 `failed.`);
4497 }
        /**
         * Moves a dataId's values from its current backend to `backend`,
         * preserving the reference count. Used when a kernel runs on a backend
         * other than the one holding an input's data.
         */
        moveData(backend, dataId) {
            const info = this.state.tensorInfo.get(dataId);
            const srcBackend = info.backend;
            // Synchronous read must happen before the source copy is disposed.
            const values = this.readSync(dataId);
            const refCount = srcBackend.refCount(dataId);
            // Delete the tensor from the old backend and move it to the new
            // backend.
            srcBackend.disposeData(dataId, true);
            info.backend = backend;
            backend.move(dataId, values, info.shape, info.dtype, refCount);
            if (this.shouldCheckForMemLeaks()) {
                // Track the number of moves during a kernel execution to correctly
                // detect memory leaks.
                this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
            }
        }
4514 tidy(nameOrFn, fn) {
4515 let name = null;
4516 if (fn == null) {
4517 // Called with only 1 argument.
4518 if (typeof nameOrFn !== 'function') {
4519 throw new Error('Please provide a function to tidy()');
4520 }
4521 fn = nameOrFn;
4522 }
4523 else {
4524 // Called with 2 arguments.
4525 if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) {
4526 throw new Error('When calling with two arguments, the first argument ' +
4527 'to tidy() must be a string');
4528 }
4529 if (typeof fn !== 'function') {
4530 throw new Error('When calling with two arguments, the 2nd argument ' +
4531 'to tidy() must be a function');
4532 }
4533 name = nameOrFn;
4534 // TODO(nsthorat,smilkov): Do operation logging and performance
4535 // profiling.
4536 }
4537 let result;
4538 return this.scopedRun(() => this.startScope(name), () => this.endScope(result), () => {
4539 result = fn();
4540 if (result instanceof Promise) {
4541 console.error('Cannot return a Promise inside of tidy.');
4542 }
4543 return result;
4544 });
4545 }
4546 scopedRun(start, end, f) {
4547 start();
4548 try {
4549 const res = f();
4550 end();
4551 return res;
4552 }
4553 catch (ex) {
4554 end();
4555 throw ex;
4556 }
4557 }
        /** Returns a process-unique id for a new tensor (static counter). */
        nextTensorId() {
            return Engine.nextTensorId++;
        }
        /** Returns a process-unique id for a new variable (static counter). */
        nextVariableId() {
            return Engine.nextVariableId++;
        }
        /**
         * This method is called instead of the public-facing tensor.clone() when
         * saving a tensor for backwards pass. It makes sure to add the clone
         * operation to the tape regardless of being called inside a kernel
         * execution.
         */
        clone(x) {
            const y = ENGINE.runKernel(Identity$1, { x });
            const inputs = { x };
            // Gradient of identity: pass dy through, cast to float32.
            const grad = (dy) => ({
                x: () => {
                    const dtype = 'float32';
                    const gradInputs = { x: dy };
                    const attrs = { dtype };
                    return ENGINE.runKernel(Cast, gradInputs,
                    // tslint:disable-next-line: no-unnecessary-type-assertion
                    attrs);
                }
            });
            const saved = [];
            // Manually record the node so the clone is on the tape even when the
            // tape is off for nested kernel calls.
            this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {});
            return y;
        }
        /**
         * Execute a kernel with the given name and return the output tensor.
         *
         * @param kernelName The name of the kernel to execute.
         * @param inputs A map of input names to tensors.
         * @param attrs A map of attribute names to their values. An attribute is a
         *     primitive (non-tensor) input to the kernel.
         * @param inputsToSave A list of tensors, inputs to save for the backprop
         *     computation.
         * @param outputsToSave A list of booleans, specifying which output to save
         *     for the backprop computation. These are booleans since the output
         * tensors are not visible to the user.
         */
        runKernel(kernelName, inputs, attrs) {
            if (this.backendName == null) {
                // backend has not been initialized yet (backend initialization is lazy
                // can be deferred until an op/ kernel is run).
                // The below getter has side effects that will try to initialize the
                // backend and set properties like this.backendName
                // tslint:disable-next-line: no-unused-expression
                this.backend;
            }
            // Fail fast with a clear message before entering runKernelFunc.
            const hasKernel = getKernel(kernelName, this.backendName) != null;
            if (!hasKernel) {
                throw new Error(`Kernel '${kernelName}' not registered for backend '${this.backendName}'`);
            }
            return this.runKernelFunc({ kernelName, inputs, attrs });
        }
        /** Memory-leak checks are only enabled in test environments. */
        shouldCheckForMemLeaks() {
            return this.ENV.getBool('IS_TEST');
        }
4618 checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos) {
4619 const numDataIdsAfter = this.backend.numDataIds();
4620 // Count the number of data ids associated with the result of the kernel.
4621 let numOutputDataIds = 0;
4622 outInfos.forEach(info => {
4623 // Complex numbers allocate 3 data ids, one for 'real', one for
4624 // 'imaginary', and one for the container that holds the former two.
4625 numOutputDataIds += (info.dtype === 'complex64' ? 3 : 1);
4626 });
4627 // Account for the number of moves during kernel execution. A "data move"
4628 // can happen in the middle of a kernel execution, placing a new (key,value)
4629 // pair in the data storage. Since data moves have net zero effect (we
4630 // always remove the data from the old backend), we have to cancel them out
4631 // when detecting memory leaks.
4632 const numMoves = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1];
4633 const dataIdsLeaked = numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves;
4634 if (dataIdsLeaked > 0) {
4635 throw new Error(`Backend '${this.backendName}' has an internal memory leak ` +
4636 `(${dataIdsLeaked} data ids) after running '${kernelName}'`);
4637 }
4638 }
        /**
         * Internal helper method to execute a kernel Func
         *
         * Use `runKernel` to execute kernels from outside of engine.
         */
        runKernelFunc(kernelParams) {
            let outputs;
            let saved = [];
            const isTapeOn = this.isTapeOn();
            // Snapshot memory counters so the profiler can report deltas.
            const startingBytecount = this.state.numBytes;
            const startingNumTensors = this.state.numTensors;
            if (this.shouldCheckForMemLeaks()) {
                this.state.numDataMovesStack.push(0);
            }
            let kernelFunc;
            if (this.backendName == null) {
                // backend has not been initialized yet (backend initialization is lazy
                // can be deferred until an op/ kernel is run).
                // The below getter has side effects that will try to initialize the
                // backend and set properties like this.backendName
                // tslint:disable-next-line: no-unused-expression
                this.backend;
            }
            let out;
            const kernelOrScopeName = isRegisteredKernelInvocation(kernelParams) ?
                kernelParams.kernelName :
                this.state.activeScope != null ? this.state.activeScope.name : '';
            // Create the kernelFunc from either a registered kernel OR passed in
            // forward/backward functions (used by custom grad). In this context a
            // kernelFunc wraps a kernel implementation with some bookkeeping.
            if (isRegisteredKernelInvocation(kernelParams)) {
                const { kernelName, inputs, attrs } = kernelParams;
                if (this.backendName == null) {
                    // backend has not been initialized yet (backend initialization is lazy
                    // can be deferred until an op/ kernel is run).
                    // The below getter has side effects that will try to initialize the
                    // backend and set properties like this.backendName
                    // tslint:disable-next-line: no-unused-expression
                    this.backend;
                }
                const kernel = getKernel(kernelName, this.backendName);
                assert$1(kernel != null, () => `Cannot find registered kernel '${kernelName}' for backend '${this.backendName}'`);
                kernelFunc = () => {
                    const numDataIdsBefore = this.backend.numDataIds();
                    out = kernel.kernelFunc({ inputs, attrs, backend: this.backend });
                    const outInfos = Array.isArray(out) ? out : [out];
                    if (this.shouldCheckForMemLeaks()) {
                        this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos);
                    }
                    const outTensors = outInfos.map((outInfo) => {
                        // todo (yassogba) remove this option (Tensor) when node backend
                        // methods have been modularized and they all return tensorInfo.
                        // TensorInfos do not have a rank attribute.
                        if (outInfo.rank != null) {
                            return outInfo;
                        }
                        return this.makeTensorFromTensorInfo(outInfo);
                    });
                    // Save any required inputs and outputs.
                    // Do not save unless we are recording to the tape. Otherwise it would
                    // cause a mem leak since there would be no backprop for these tensors
                    // (which would otherwise dispose them).
                    if (isTapeOn) {
                        const tensorsToSave = this.getTensorsForGradient(kernelName, inputs, outTensors);
                        saved = this.saveTensorsForBackwardMode(tensorsToSave);
                    }
                    return outTensors;
                };
            }
            else {
                const { forwardFunc } = kernelParams;
                // Running a customGrad op.
                const saveFunc = (tensors) => {
                    // Do not save unless we are recording to the tape. Otherwise it would
                    // cause a mem leak since we would never run backprop, which disposes
                    // the kept tensors.
                    if (!isTapeOn) {
                        return;
                    }
                    saved = tensors.map(tensor => this.keep(this.clone(tensor)));
                };
                kernelFunc = () => {
                    const numDataIdsBefore = this.backend.numDataIds();
                    // Run the forward func inside tidy so intermediates are cleaned up.
                    out = this.tidy(() => forwardFunc(this.backend, saveFunc));
                    const outs = (Array.isArray(out) ? out : [out]);
                    if (this.shouldCheckForMemLeaks()) {
                        // Scope name is used to print a more helpful error message if needed.
                        this.checkKernelForMemLeak(kernelOrScopeName, numDataIdsBefore, outs);
                    }
                    return outs;
                };
            }
            //
            // Run the kernelFunc. Optionally profiling it.
            //
            const { inputs, attrs } = kernelParams;
            const backwardsFunc = isRegisteredKernelInvocation(kernelParams) ?
                null :
                kernelParams.backwardsFunc;
            let kernelProfile;
            this.scopedRun(
            // Stop recording to a tape when running a kernel.
            () => this.state.kernelDepth++, () => this.state.kernelDepth--, () => {
                if (!this.ENV.getBool('DEBUG') && !this.state.profiling) {
                    outputs = kernelFunc();
                }
                else {
                    // DEBUG or profiling: wrap the kernel to capture timing info.
                    kernelProfile = this.profiler.profileKernel(kernelOrScopeName, inputs, () => kernelFunc());
                    if (this.ENV.getBool('DEBUG')) {
                        this.profiler.logKernelProfile(kernelProfile);
                    }
                    outputs = kernelProfile.outputs;
                }
            });
            if (isTapeOn) {
                this.addTapeNode(kernelOrScopeName, inputs, outputs, backwardsFunc, saved, attrs);
            }
            if (this.state.profiling) {
                this.state.activeProfile.kernels.push({
                    name: kernelOrScopeName,
                    bytesAdded: this.state.numBytes - startingBytecount,
                    totalBytesSnapshot: this.state.numBytes,
                    tensorsAdded: this.state.numTensors - startingNumTensors,
                    totalTensorsSnapshot: this.state.numTensors,
                    inputShapes: Object.keys(inputs).map(key => inputs[key] != null ? inputs[key].shape : null),
                    outputShapes: outputs.map(item => item.shape),
                    kernelTimeMs: kernelProfile.timeMs,
                    extraInfo: kernelProfile.extraInfo
                });
            }
            // Preserve the kernel's output arity: array in, array out.
            return (Array.isArray(out) ? outputs : outputs[0]);
        }
4771 /**
4772 * Saves tensors used in forward mode for use in backward mode.
4773 *
4774 * @param tensors the list of tensors to save.
4775 */
4776 saveTensorsForBackwardMode(tensors) {
4777 const saved = tensors.map(tensor => this.keep(this.clone(tensor)));
4778 return saved;
4779 }
    /**
     * Returns a list of tensors to save for a given gradient calculation.
     *
     * @param kernelName name of kernel to look up gradient for.
     * @param inputs a map of input tensors.
     * @param outputs an array of output tensors from forward mode of kernel.
     */
    getTensorsForGradient(kernelName, inputs, outputs) {
        const gradConfig = getGradient(kernelName);
        if (gradConfig != null) {
            const inputsToSave = gradConfig.inputsToSave || [];
            const outputsToSave = gradConfig.outputsToSave || [];
            // If saveAllInputs is true, all inputs will be saved. Otherwise, inputs
            // specified in inputsToSave will be saved.
            let inputTensorsToSave;
            if (gradConfig.saveAllInputs) {
                // NOTE(review): the assert requires an array but the values are then
                // gathered via Object.keys — presumably saveAllInputs kernels pass an
                // array-shaped input map; confirm against callers.
                assert$1(Array.isArray(inputs), () => 'saveAllInputs is true, expected inputs to be an array.');
                inputTensorsToSave = Object.keys(inputs).map((key) => inputs[key]);
            }
            else {
                inputTensorsToSave = inputsToSave.map((inputName) => inputs[inputName]);
            }
            // Keep only outputs whose index is flagged truthy in outputsToSave.
            const outputTensorsToSave = outputs.filter((_, i) => outputsToSave[i]);
            return inputTensorsToSave.concat(outputTensorsToSave);
        }
        // We return an empty list rather than throw an error because the kernel we
        // are looking up may not actually be relevant to backproping through the
        // overall function
        //
        // See 'does not error if irrelevant (pruned) ops are missing grads' test
        // in gradients_test.ts for an example.
        return [];
    }
    /**
     * Internal method used by public APIs for tensor creation. Makes a new
     * tensor with the provided shape, dtype and values. It always
     * creates a new data id and writes the values to the underlying backend.
     */
    makeTensor(values, shape, dtype, backend) {
        if (values == null) {
            throw new Error('Values passed to engine.makeTensor() are null');
        }
        // Default dtype/backend when the caller did not supply them.
        dtype = dtype || 'float32';
        backend = backend || this.backend;
        let backendVals = values;
        // String values are stored on the backend as encoded bytes.
        if (dtype === 'string' && isString(values[0])) {
            backendVals = values.map(d => encodeString(d));
        }
        const dataId = backend.write(backendVals, shape, dtype);
        const t = new Tensor(shape, dtype, dataId, this.nextTensorId());
        this.trackTensor(t, backend);
        // Count bytes for string tensors. trackTensor recorded 0 bytes for
        // 'string' dtype, so adjust totals from the encoded byte length here.
        if (dtype === 'string') {
            const info = this.state.tensorInfo.get(dataId);
            const newBytes = bytesFromStringArray(backendVals);
            this.state.numBytes += newBytes - info.bytes;
            info.bytes = newBytes;
        }
        return t;
    }
4840 /**
4841 * Internal method used by backends. Makes a new tensor
4842 * that is a wrapper around an existing data id. It doesn't create
4843 * a new data id, only increments the ref count used in memory tracking.
4844 * @deprecated
4845 */
4846 makeTensorFromDataId(dataId, shape, dtype, backend) {
4847 dtype = dtype || 'float32';
4848 const tensorInfo = { dataId, shape, dtype };
4849 return this.makeTensorFromTensorInfo(tensorInfo, backend);
4850 }
4851 /**
4852 * Internal method used by backends. Makes a new tensor that is a wrapper
4853 * around an existing data id in TensorInfo. It doesn't create a new data id,
4854 * only increments the ref count used in memory tracking.
4855 */
4856 makeTensorFromTensorInfo(tensorInfo, backend) {
4857 const { dataId, shape, dtype } = tensorInfo;
4858 const t = new Tensor(shape, dtype, dataId, this.nextTensorId());
4859 this.trackTensor(t, backend);
4860 return t;
4861 }
    /**
     * Creates and registers a Variable wrapping `initialValue`.
     *
     * @param initialValue tensor providing the initial data (cast when `dtype`
     *     differs from its dtype).
     * @param trainable whether the variable participates in training.
     * @param name optional unique name; auto-generated when omitted.
     * @param dtype optional dtype to cast the initial value to.
     * @throws Error when a variable with the same name is already registered.
     */
    makeVariable(initialValue, trainable = true, name, dtype) {
        name = name || this.nextVariableId().toString();
        if (dtype != null && dtype !== initialValue.dtype) {
            initialValue = initialValue.cast(dtype);
        }
        const v = new Variable(initialValue, trainable, name, this.nextTensorId());
        if (this.state.registeredVariables[v.name] != null) {
            throw new Error(`Variable with name ${v.name} was already registered`);
        }
        this.state.registeredVariables[v.name] = v;
        // Variables outlive scopes; bump the backend refcount for their data.
        this.incRef(v, this.backend);
        return v;
    }
    /**
     * Registers tensor `a` in engine bookkeeping (tensor/byte counters and
     * per-dataId info) and, unless it is a Variable, tracks it in the active
     * scope for automatic disposal.
     */
    trackTensor(a, backend) {
        this.state.numTensors++;
        if (a.dtype === 'string') {
            this.state.numStringTensors++;
        }
        // Bytes for complex numbers are counted by their components. Bytes for
        // string tensors are counted when writing values.
        let bytes = 0;
        if (a.dtype !== 'complex64' && a.dtype !== 'string') {
            bytes = a.size * bytesPerElement(a.dtype);
        }
        this.state.numBytes += bytes;
        // First sighting of this data id: record the owning backend and shape.
        if (!this.state.tensorInfo.has(a.dataId)) {
            this.state.numDataBuffers++;
            this.state.tensorInfo.set(a.dataId, {
                backend: backend || this.backend,
                dtype: a.dtype,
                shape: a.shape,
                bytes
            });
        }
        if (!(a instanceof Variable)) {
            this.track(a);
        }
    }
    // Track the tensor by dataId and increase the refCount for the dataId in the
    // backend.
    // TODO(pyu10055): This is currently used by makeVariable method, to increase
    // refCount on the backend for the dataId. It can potentially be replaced with
    // Identity op indead of calling backend directly.
    incRef(a, backend) {
        this.trackTensor(a, backend);
        // NOTE(review): the refcount bump goes to `this.backend`, while
        // trackTensor may record the passed-in `backend` as owner — confirm
        // these are always the same backend at the call sites.
        this.backend.incRef(a.dataId);
    }
4909 removeDataId(dataId, backend) {
4910 if (this.state.tensorInfo.has(dataId) &&
4911 this.state.tensorInfo.get(dataId).backend === backend) {
4912 this.state.tensorInfo.delete(dataId);
4913 this.state.numDataBuffers--;
4914 }
4915 }
    /**
     * Releases engine bookkeeping for `a` and asks the owning backend to
     * dispose its data. No-op when the data id is unknown (already disposed).
     */
    disposeTensor(a) {
        if (!this.state.tensorInfo.has(a.dataId)) {
            return;
        }
        const info = this.state.tensorInfo.get(a.dataId);
        this.state.numTensors--;
        if (a.dtype === 'string') {
            this.state.numStringTensors--;
            // String bytes were recorded at write time; undo them from info.
            this.state.numBytes -= info.bytes;
        }
        // Don't count bytes for complex numbers as they are counted by their
        // components.
        if (a.dtype !== 'complex64' && a.dtype !== 'string') {
            const bytes = a.size * bytesPerElement(a.dtype);
            this.state.numBytes -= bytes;
        }
        // Remove the reference to dataId if backend dispose the data successfully
        if (info.backend.disposeData(a.dataId)) {
            this.removeDataId(a.dataId, info.backend);
        }
        // TODO(nsthorat): Construct an error and save the stack trace for
        // debugging when in debug mode. Creating a stack trace is too expensive
        // to do unconditionally.
    }
4940 disposeVariables() {
4941 for (const varName in this.state.registeredVariables) {
4942 const v = this.state.registeredVariables[varName];
4943 this.disposeVariable(v);
4944 }
4945 }
4946 disposeVariable(v) {
4947 this.disposeTensor(v);
4948 if (this.state.registeredVariables[v.name] != null) {
4949 delete this.state.registeredVariables[v.name];
4950 }
4951 }
4952 memory() {
4953 const info = this.backend.memory();
4954 info.numTensors = this.state.numTensors;
4955 info.numDataBuffers = this.state.numDataBuffers;
4956 info.numBytes = this.state.numBytes;
4957 if (this.state.numStringTensors > 0) {
4958 info.unreliable = true;
4959 if (info.reasons == null) {
4960 info.reasons = [];
4961 }
4962 info.reasons.push('Memory usage by string tensors is approximate ' +
4963 '(2 bytes per character)');
4964 }
4965 return info;
4966 }
    /**
     * Executes `query` with profiling enabled and returns the resulting
     * activeProfile. Per-kernel entries are pushed elsewhere while
     * `state.profiling` is true.
     */
    async profile(query) {
        this.state.profiling = true;
        const startBytes = this.state.numBytes;
        const startNumTensors = this.state.numTensors;
        this.state.activeProfile.kernels = [];
        this.state.activeProfile.result = await query();
        this.state.profiling = false;
        // Peak is the max over the per-kernel byte snapshots taken during the run.
        this.state.activeProfile.peakBytes = Math.max(...this.state.activeProfile.kernels.map(d => d.totalBytesSnapshot));
        this.state.activeProfile.newBytes = this.state.numBytes - startBytes;
        this.state.activeProfile.newTensors =
            this.state.numTensors - startNumTensors;
        // Kernel timings/extra info are promises; resolve before returning.
        for (const kernel of this.state.activeProfile.kernels) {
            kernel.kernelTimeMs = await kernel.kernelTimeMs;
            kernel.extraInfo = await kernel.extraInfo;
        }
        return this.state.activeProfile;
    }
4984 isTapeOn() {
4985 return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
4986 }
    /**
     * Records one executed kernel on the active tape so backprop can replay
     * it. A gradient registered for `kernelName` takes precedence over the
     * passed-in `gradientsFunc`.
     */
    addTapeNode(kernelName, inputs, outputs, gradientsFunc, saved, attrs) {
        const tapeNode = { id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved };
        const gradConfig = getGradient(kernelName);
        if (gradConfig != null) {
            gradientsFunc = gradConfig.gradFunc;
        }
        if (gradientsFunc != null) {
            tapeNode.gradient = (dys) => {
                // TODO(smilkov): To optimize back-prop, pass dys that are not used in
                // the backprop graph to the user as null instead of zeros
                dys = dys.map((dy, i) => {
                    if (dy == null) {
                        // Missing upstream gradient: substitute zeros with the
                        // matching output's shape and dtype.
                        const output = outputs[i];
                        const vals = makeZerosTypedArray(output.size, output.dtype);
                        return this.makeTensor(vals, output.shape, output.dtype);
                    }
                    return dy;
                });
                // Grad functions of ops with single outputs expect a dy, while ops
                // with multiple outputs expect dys (array of dy).
                return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs);
            };
        }
        this.state.activeTape.push(tapeNode);
    }
    /**
     * Marks `result` as kept so scope and tape cleanup skip disposing it,
     * then returns it.
     */
    keep(result) {
        result.kept = true;
        return result;
    }
5016 startTape() {
5017 if (this.state.gradientDepth === 0) {
5018 this.state.activeTape = [];
5019 }
5020 this.state.gradientDepth++;
5021 }
    /** Closes the innermost gradient-recording scope opened by startTape(). */
    endTape() {
        this.state.gradientDepth--;
    }
5025 /**
5026 * Start a scope. Use this with endScope() to achieve the same functionality
5027 * as scope() without the need for a function closure.
5028 */
5029 startScope(name) {
5030 const scopeInfo = {
5031 track: [],
5032 name: 'unnamed scope',
5033 id: this.state.nextScopeId++
5034 };
5035 if (name) {
5036 scopeInfo.name = name;
5037 }
5038 this.state.scopeStack.push(scopeInfo);
5039 this.state.activeScope = scopeInfo;
5040 }
    /**
     * End a scope. Use this with startScope() to achieve the same functionality
     * as scope() without the need for a function closure.
     */
    endScope(result) {
        // Tensors contained in `result` must survive this scope.
        const tensorsToTrackInParent = getTensorsInContainer(result);
        const tensorsToTrackInParentSet = new Set(tensorsToTrackInParent.map(t => t.id));
        // Dispose the arrays tracked in this scope.
        for (let i = 0; i < this.state.activeScope.track.length; i++) {
            const tensor = this.state.activeScope.track[i];
            if (!tensor.kept && !tensorsToTrackInParentSet.has(tensor.id)) {
                tensor.dispose();
            }
        }
        const oldScope = this.state.scopeStack.pop();
        // The new active scope is the parent, or null at the top level.
        this.state.activeScope = this.state.scopeStack.length === 0 ?
            null :
            this.state.scopeStack[this.state.scopeStack.length - 1];
        // Track the current result in the parent scope.
        tensorsToTrackInParent.forEach(tensor => {
            // Only track the tensor if was allocated in the inner scope and is not
            // globally kept.
            if (!tensor.kept && tensor.scopeId === oldScope.id) {
                this.track(tensor);
            }
        });
    }
    /**
     * Returns gradients of `f` with respect to each of the `xs`. The gradients
     * returned are of the same length as `xs`, but some might be null if `f`
     * was not a function of that `x`. It also takes optional dy to multiply the
     * gradient, which defaults to `1`.
     */
    gradients(f, xs, dy, allowNoGradients = false) {
        assert$1(xs.length > 0, () => 'gradients() received an empty list of xs.');
        if (dy != null && dy.dtype !== 'float32') {
            throw new Error(`dy must have 'float32' dtype, but has '${dy.dtype}'`);
        }
        // Run the forward pass with the tape recording.
        const y = this.scopedRun(() => this.startTape(), () => this.endTape(), () => this.tidy('forward', f));
        assert$1(y instanceof Tensor, () => 'The result y returned by f() must be a tensor.');
        // Filter out the nodes that don't connect x => y.
        const filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y);
        if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) {
            throw new Error('Cannot compute gradient of y=f(x) with respect to x. Make sure ' +
                'that the f you passed encloses all operations that lead from x ' +
                'to y.');
        }
        return this.tidy('backward', () => {
            const accumulatedGradientMap = {};
            // Seed the output gradient with dy, or ones when dy was omitted.
            accumulatedGradientMap[y.id] = (dy == null) ? ones$2(y.shape) : dy;
            // Backprop gradients through the filtered nodes.
            backpropagateGradients(accumulatedGradientMap, filteredTape,
                // Pass the tidy function to avoid circular dep with `tape.ts`.
                f => this.tidy(f),
                // Pass an add function to avoide a circular dep with `tape.ts`.
                add$4);
            const grads = xs.map(x => accumulatedGradientMap[x.id]);
            if (this.state.gradientDepth === 0) {
                // This means that we are not computing higher-order gradients
                // and can clean up the tape.
                this.state.activeTape.forEach(node => {
                    for (const tensor of node.saved) {
                        tensor.dispose();
                    }
                });
                this.state.activeTape = null;
            }
            return { value: y, grads };
        });
    }
    /**
     * Wraps `f` — which must return `{value, gradFunc}` — into a tensor
     * function whose gradient comes from gradFunc rather than the tape.
     */
    customGrad(f) {
        assert$1(isFunction(f), () => 'The f passed in customGrad(f) must be a function.');
        return (...inputs) => {
            assert$1(inputs.every(t => t instanceof Tensor), () => 'The args passed in customGrad(f)(x1, x2,...) must all be ' +
                'tensors');
            // `res` is captured by both closures below; forwardFunc fills it in.
            let res;
            const inputMap = {};
            inputs.forEach((input, i) => {
                inputMap[i] = input;
            });
            const forwardFunc = (_, save) => {
                res = f(...[...inputs, save]);
                assert$1(res.value instanceof Tensor, () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.value` is a tensor');
                assert$1(isFunction(res.gradFunc), () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.gradFunc` is a function.');
                return res.value;
            };
            const backwardsFunc = (dy, saved) => {
                const gradRes = res.gradFunc(dy, saved);
                const grads = Array.isArray(gradRes) ? gradRes : [gradRes];
                assert$1(grads.length === inputs.length, () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.gradFunc` is a function that returns ' +
                    'the same number of tensors as inputs passed to f(...).');
                assert$1(grads.every(t => t instanceof Tensor), () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.gradFunc` is a function that returns ' +
                    'a list of only tensors.');
                // Map each gradient to its positional input key.
                const gradMap = {};
                grads.forEach((grad, i) => {
                    gradMap[i] = () => grad;
                });
                return gradMap;
            };
            return this.runKernelFunc({
                forwardFunc,
                backwardsFunc,
                inputs: inputMap,
            });
        };
    }
5151 readSync(dataId) {
5152 // Route the read to the correct backend.
5153 const info = this.state.tensorInfo.get(dataId);
5154 return info.backend.readSync(dataId);
5155 }
5156 read(dataId) {
5157 // Route the read to the correct backend.
5158 const info = this.state.tensorInfo.get(dataId);
5159 return info.backend.read(dataId);
5160 }
5161 readToGPU(dataId, options) {
5162 // Route the read to the correct backend.
5163 const info = this.state.tensorInfo.get(dataId);
5164 return info.backend.readToGPU(dataId, options);
5165 }
5166 async time(query) {
5167 const start = now();
5168 const timingInfo = await this.backend.time(query);
5169 timingInfo.wallMs = now() - start;
5170 return timingInfo;
5171 }
5172 /**
5173 * Tracks a Tensor in the current scope to be automatically cleaned up
5174 * when the current scope ends, and returns the value.
5175 *
5176 * @param result The Tensor to track in the current scope.
5177 */
5178 track(result) {
5179 if (this.state.activeScope != null) {
5180 result.scopeId = this.state.activeScope.id;
5181 this.state.activeScope.track.push(result);
5182 }
5183 return result;
5184 }
    /** Map of variable name -> Variable for all registered variables. */
    get registeredVariables() {
        return this.state.registeredVariables;
    }
    /**
     * Resets the engine state. Removes all backends but does not remove
     * registered backend factories.
     */
    reset() {
        // Make any pending promise obsolete.
        this.pendingBackendInitId++;
        this.state.dispose();
        this.ENV.reset();
        this.state = new EngineState();
        // Dispose every registered backend plus its kernels and forget it.
        for (const backendName in this.registry) {
            this.disposeRegisteredKernels(backendName);
            this.registry[backendName].dispose();
            delete this.registry[backendName];
        }
        this.backendName = null;
        this.backendInstance = null;
        this.pendingBackendInit = null;
    }
5207 }
// Monotonic counters shared across Engine instances, used to hand out
// unique tensor and variable ids.
Engine.nextTensorId = 0;
Engine.nextVariableId = 0;
/** Returns a float32 tensor of ones with the given shape (engine-internal). */
function ones$2(shape) {
    const numElements = sizeFromShape(shape);
    return ENGINE.makeTensor(makeOnesTypedArray(numElements, 'float32'), shape, 'float32');
}
/**
 * Returns the process-wide Engine singleton, creating it and installing it on
 * the global namespace (`_tfengine`) on first use so that multiple copies of
 * the bundle share one engine.
 */
function getOrMakeEngine() {
    const ns = getGlobalNamespace();
    if (ns._tfengine == null) {
        const environment = new Environment(ns);
        ns._tfengine = new Engine(environment);
    }
    setEnvironmentGlobal(ns._tfengine.ENV);
    // Tell the current tensor interface that the global engine is responsible
    // for tracking.
    setTensorTracker(() => ns._tfengine);
    return ns._tfengine;
}
// The singleton engine used by all public APIs in this bundle.
const ENGINE = getOrMakeEngine();
/**
 * A implementation of the add op for use within engine and tape.
 *
 * This allows us to avoid a circular dependency between add.ts and engine.
 * It is exported to be available in tape tests.
 */
function add$4(a, b) {
    // We duplicate Add here to avoid a circular dependency with add.ts.
    return ENGINE.runKernel(Add$1, { a, b });
}
5238
5239 /**
5240 * @license
5241 * Copyright 2017 Google LLC. All Rights Reserved.
5242 * Licensed under the Apache License, Version 2.0 (the "License");
5243 * you may not use this file except in compliance with the License.
5244 * You may obtain a copy of the License at
5245 *
5246 * http://www.apache.org/licenses/LICENSE-2.0
5247 *
5248 * Unless required by applicable law or agreed to in writing, software
5249 * distributed under the License is distributed on an "AS IS" BASIS,
5250 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5251 * See the License for the specific language governing permissions and
5252 * limitations under the License.
5253 * =============================================================================
5254 */
// tslint:disable-next-line:no-any
function _isNavigatorDefined() {
    // True when a usable `navigator` global exists (browsers, workers, and
    // recent Node versions all define one).
    if (typeof navigator === 'undefined') {
        return false;
    }
    return navigator != null;
}
// Test-only override for isMobile(); undefined means "no mock installed".
let isMobileMockValue;
// Installs (or clears, by passing undefined) the isMobile() mock value.
function mockIsMobile(value) {
    isMobileMockValue = value;
}
/**
 * Heuristically detects whether the current (or given) navigator belongs to
 * a mobile device. A mock installed via mockIsMobile() takes precedence.
 */
function isMobile(nav) {
    if (isMobileMockValue !== undefined) {
        return isMobileMockValue;
    }
    if (nav || _isNavigatorDefined()) {
        if (!nav) {
            nav = navigator;
        }
        if (nav.product === 'ReactNative') {
            return true;
        }
        // Prefer the UA string; fall back to vendor, then the legacy Opera
        // global in browsers.
        const a = nav.userAgent || nav.vendor ||
            // tslint:disable-next-line:no-any
            (typeof window !== 'undefined' ? window.opera : '');
        // Use `navigator.userAgentData.mobile` as fallback.
        if (!a) {
            // tslint:disable-next-line:no-any
            const navAny = nav;
            return navAny.userAgentData && navAny.userAgentData.mobile;
        }
        // Classic detectmobilebrowsers.com-style UA sniffing: the first regex
        // matches anywhere in the UA, the second only its first 4 characters.
        // tslint:disable-next-line:max-line-length
        return /(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i
            .test(a) ||
            // tslint:disable-next-line:max-line-length
            /1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v )|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i
                .test(a.substr(0, 4));
    }
    return false;
}
/** True in a browser page (DOM present) or inside a web worker. */
function isBrowser() {
    if (typeof window !== 'undefined' && window.document != null) {
        return true;
    }
    //@ts-ignore
    return typeof WorkerGlobalScope !== 'undefined';
}
5297
// Frozen namespace object mirroring the device_util module of the original
// TypeScript sources.
var device_util = /*#__PURE__*/Object.freeze({
    __proto__: null,
    isBrowser: isBrowser,
    isMobile: isMobile,
    mockIsMobile: mockIsMobile
});
5304
5305 /**
5306 * @license
5307 * Copyright 2019 Google LLC. All Rights Reserved.
5308 * Licensed under the Apache License, Version 2.0 (the "License");
5309 * you may not use this file except in compliance with the License.
5310 * You may obtain a copy of the License at
5311 *
5312 * http://www.apache.org/licenses/LICENSE-2.0
5313 *
5314 * Unless required by applicable law or agreed to in writing, software
5315 * distributed under the License is distributed on an "AS IS" BASIS,
5316 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5317 * See the License for the specific language governing permissions and
5318 * limitations under the License.
5319 * =============================================================================
5320 */
// Global environment used to register the flags below.
const ENV$3 = env();
/**
 * This file contains environment-related flag registrations.
 */
/** Whether to enable debug mode. */
ENV$3.registerFlag('DEBUG', () => false, debugValue => {
    if (debugValue) {
        console.warn('Debugging mode is ON. The output of every math call will ' +
            'be downloaded to CPU and checked for NaNs. ' +
            'This significantly impacts performance.');
    }
});
/** Whether we are in a browser (as versus, say, node.js) environment. */
ENV$3.registerFlag('IS_BROWSER', () => isBrowser());
/** Whether we are in a browser (as versus, say, node.js) environment. */
ENV$3.registerFlag('IS_NODE', () => (typeof process !== 'undefined') &&
    (typeof process.versions !== 'undefined') &&
    (typeof process.versions.node !== 'undefined'));
/** Whether this browser is Chrome. */
ENV$3.registerFlag('IS_CHROME', () => typeof navigator !== 'undefined' && navigator != null &&
    navigator.userAgent != null && /Chrome/.test(navigator.userAgent) &&
    /Google Inc/.test(navigator.vendor));
/** Whether this browser is Safari. */
ENV$3.registerFlag('IS_SAFARI', () => typeof navigator !== 'undefined' && navigator != null &&
    navigator.userAgent != null && /Safari/.test(navigator.userAgent) &&
    /Apple/.test(navigator.vendor));
/**
 * True when the environment is "production" where we disable safety checks
 * to gain performance.
 */
ENV$3.registerFlag('PROD', () => false);
/**
 * Whether to do sanity checks when inferring a shape from user-provided
 * values, used when creating a new tensor.
 */
ENV$3.registerFlag('TENSORLIKE_CHECK_SHAPE_CONSISTENCY', () => ENV$3.getBool('DEBUG'));
/** Whether deprecation warnings are enabled. */
ENV$3.registerFlag('DEPRECATION_WARNINGS_ENABLED', () => true);
/** True if running unit tests. */
ENV$3.registerFlag('IS_TEST', () => false);
/** Whether to check computation result for errors. */
ENV$3.registerFlag('CHECK_COMPUTATION_FOR_ERRORS', () => ENV$3.getBool('DEBUG'));
/** Whether the backend needs to wrap input to imageBitmap. */
ENV$3.registerFlag('WRAP_TO_IMAGEBITMAP', () => false);
/** Whether to enable canvas2d willReadFrequently for GPU backends */
ENV$3.registerFlag('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU', () => false);
/** Whether to use setTimeoutCustom */
ENV$3.registerFlag('USE_SETTIMEOUTCUSTOM', () => false);
5369
5370 /**
5371 * @license
5372 * Copyright 2018 Google LLC. All Rights Reserved.
5373 * Licensed under the Apache License, Version 2.0 (the "License");
5374 * you may not use this file except in compliance with the License.
5375 * You may obtain a copy of the License at
5376 *
5377 * http://www.apache.org/licenses/LICENSE-2.0
5378 *
5379 * Unless required by applicable law or agreed to in writing, software
5380 * distributed under the License is distributed on an "AS IS" BASIS,
5381 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5382 * See the License for the specific language governing permissions and
5383 * limitations under the License.
5384 * =============================================================================
5385 */
/**
 * Infers the shape of a TensorLike `val` (scalar, nested array, TypedArray,
 * or WebGL/WebGPU data) as an array of dimension sizes.
 */
function inferShape(val, dtype) {
    let firstElem = val;
    if (isTypedArray(val)) {
        // A TypedArray is flat (rank 1), except for string dtype where the
        // bytes represent a single scalar.
        return dtype === 'string' ? [] : [val.length];
    }
    if (isWebGLData(val)) {
        const usedChannels = val.channels || 'RGBA';
        return [val.height, val.width * usedChannels.length];
    }
    else if (isWebGPUData(val)) {
        return [val.buffer.size / (dtype == null ? 4 : bytesPerElement(dtype))];
    }
    if (!Array.isArray(val)) {
        return []; // Scalar.
    }
    const shape = [];
    // Walk the first element of each nesting level. Note `&&` binds tighter
    // than `||`: arrays always descend; TypedArrays only for non-string dtype.
    while (Array.isArray(firstElem) ||
        isTypedArray(firstElem) && dtype !== 'string') {
        shape.push(firstElem.length);
        firstElem = firstElem[0];
    }
    if (Array.isArray(val) &&
        env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) {
        // Optionally verify every branch of the nested array matches `shape`.
        deepAssertShapeConsistency(val, shape, []);
    }
    return shape;
}
/**
 * Recursively asserts that nested array `val` matches `shape` at every
 * branch; `indices` tracks the path for error messages.
 */
function deepAssertShapeConsistency(val, shape, indices) {
    indices = indices || [];
    if (!(Array.isArray(val)) && !isTypedArray(val)) {
        // Leaf value: the remaining shape must be exhausted.
        assert$1(shape.length === 0, () => `Element arr[${indices.join('][')}] is a primitive, ` +
            `but should be an array/TypedArray of ${shape[0]} elements`);
        return;
    }
    assert$1(shape.length > 0, () => `Element arr[${indices.join('][')}] should be a primitive, ` +
        `but is an array of ${val.length} elements`);
    assert$1(val.length === shape[0], () => `Element arr[${indices.join('][')}] should have ${shape[0]} ` +
        `elements, but has ${val.length} elements`);
    // Recurse one nesting level deeper along every element.
    const subShape = shape.slice(1);
    for (let i = 0; i < val.length; ++i) {
        deepAssertShapeConsistency(val[i], subShape, indices.concat(i));
    }
}
/**
 * Validates `actualDType` against `expectedDtype` for argument `argName` of
 * `functionName`. 'string_or_numeric' accepts anything; 'numeric' accepts
 * everything except 'string'; any concrete dtype must match exactly.
 * @throws Error on a mismatch or a null expected dtype.
 */
function assertDtype(expectedDtype, actualDType, argName, functionName) {
    if (expectedDtype === 'string_or_numeric') {
        return;
    }
    if (expectedDtype == null) {
        throw new Error(`Expected dtype cannot be null.`);
    }
    const mismatch = expectedDtype === 'numeric' ?
        actualDType === 'string' :
        expectedDtype !== actualDType;
    if (mismatch) {
        throw new Error(`Argument '${argName}' passed to '${functionName}' must ` +
            `be ${expectedDtype} tensor, but got ${actualDType} tensor`);
    }
}
/**
 * Converts `x` (Tensor or TensorLike) into a Tensor, validating its dtype
 * against `parseAsDtype`. Existing tensors are returned unchanged.
 */
function convertToTensor(x, argName, functionName, parseAsDtype = 'numeric') {
    if (x instanceof getGlobalTensorClass()) {
        assertDtype(parseAsDtype, x.dtype, argName, functionName);
        return x;
    }
    let inferredDtype = inferDtype(x);
    // If the user expects a bool/int/float, use that info to update the
    // inferredDtype when it is not a string.
    if (inferredDtype !== 'string' &&
        ['bool', 'int32', 'float32'].indexOf(parseAsDtype) >= 0) {
        inferredDtype = parseAsDtype;
    }
    assertDtype(parseAsDtype, inferredDtype, argName, functionName);
    // Reject values that are neither tensors nor recognized TensorLike kinds.
    if ((x == null) ||
        (!isTypedArray(x) && !Array.isArray(x) && typeof x !== 'number' &&
            typeof x !== 'boolean' && typeof x !== 'string')) {
        const type = x == null ? 'null' : x.constructor.name;
        throw new Error(`Argument '${argName}' passed to '${functionName}' must be a ` +
            `Tensor or TensorLike, but got '${type}'`);
    }
    const inferredShape = inferShape(x, inferredDtype);
    // Wrap bare scalars so the flattening/typed-array paths below see arrays.
    if (!isTypedArray(x) && !Array.isArray(x)) {
        x = [x];
    }
    const skipTypedArray = true;
    const values = inferredDtype !== 'string' ?
        toTypedArray(x, inferredDtype) :
        flatten$2(x, [], skipTypedArray);
    return ENGINE.makeTensor(values, inferredShape, inferredDtype);
}
/**
 * Converts every element of `arg` into a Tensor via convertToTensor,
 * labeling each error message with its index.
 * @throws Error when `arg` is not an array.
 */
function convertToTensorArray(arg, argName, functionName, parseAsDtype = 'numeric') {
    if (!Array.isArray(arg)) {
        throw new Error(`Argument ${argName} passed to ${functionName} must be a ` +
            '`Tensor[]` or `TensorLike[]`');
    }
    return arg.map((t, i) => convertToTensor(t, `${argName}[${i}]`, functionName, parseAsDtype));
}
5480
5481 /**
5482 * @license
5483 * Copyright 2018 Google LLC. All Rights Reserved.
5484 * Licensed under the Apache License, Version 2.0 (the "License");
5485 * you may not use this file except in compliance with the License.
5486 * You may obtain a copy of the License at
5487 *
5488 * http://www.apache.org/licenses/LICENSE-2.0
5489 *
5490 * Unless required by applicable law or agreed to in writing, software
5491 * distributed under the License is distributed on an "AS IS" BASIS,
5492 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5493 * See the License for the specific language governing permissions and
5494 * limitations under the License.
5495 * =============================================================================
5496 */
// Suffix appended to op names so they can be told apart from kernels in
// tf.profile output.
const OP_SCOPE_SUFFIX = '__op';
/**
 * Used for wrapping functions that perform math operations on
 * Tensors. The function will be wrapped in a named scope that cleans all
 * memory usage after the function is done.
 */
function op(f) {
    const keys = Object.keys(f);
    if (keys.length !== 1) {
        throw new Error(`Please provide an object with a single key ` +
            `(operation name) mapping to a function. Got an object with ` +
            `${keys.length} keys.`);
    }
    let opName = keys[0];
    const fn = f[opName];
    // Strip the underscore from the end of the function name.
    if (opName.endsWith('_')) {
        opName = opName.substring(0, opName.length - 1);
    }
    // add an __op suffix to distinguish ops from kernels in tf.profile
    opName = opName + OP_SCOPE_SUFFIX;
    // tslint:disable-next-line:no-any
    const f2 = (...args) => {
        ENGINE.startScope(opName);
        try {
            const result = fn(...args);
            if (isPromise(result)) {
                console.error('Cannot return a Promise inside of tidy.');
            }
            ENGINE.endScope(result);
            return result;
        }
        catch (ex) {
            // Close the scope (disposing its tensors) before re-throwing.
            ENGINE.endScope(null);
            throw ex;
        }
    };
    // Name the wrapper after the op for nicer stack traces and profiles.
    Object.defineProperty(f2, 'name', { value: opName, configurable: true });
    // tslint:disable-next-line:no-any
    return f2;
}
5538
5539 /**
5540 * @license
5541 * Copyright 2020 Google LLC. All Rights Reserved.
5542 * Licensed under the Apache License, Version 2.0 (the "License");
5543 * you may not use this file except in compliance with the License.
5544 * You may obtain a copy of the License at
5545 *
5546 * http://www.apache.org/licenses/LICENSE-2.0
5547 *
5548 * Unless required by applicable law or agreed to in writing, software
5549 * distributed under the License is distributed on an "AS IS" BASIS,
5550 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5551 * See the License for the specific language governing permissions and
5552 * limitations under the License.
5553 * =============================================================================
5554 */
5555 /**
5556 * Converts two real numbers to a complex number.
5557 *
5558 * Given a tensor `real` representing the real part of a complex number, and a
5559 * tensor `imag` representing the imaginary part of a complex number, this
5560 * operation returns complex numbers elementwise of the form [r0, i0, r1, i1],
5561 * where r represents the real part and i represents the imag part.
5562 *
5563 * The input tensors real and imag must have the same shape.
5564 *
5565 * ```js
5566 * const real = tf.tensor1d([2.25, 3.25]);
5567 * const imag = tf.tensor1d([4.75, 5.75]);
5568 * const complex = tf.complex(real, imag);
5569 *
5570 * complex.print();
5571 * ```
5572 *
5573 * @doc {heading: 'Tensors', subheading: 'Creation'}
5574 */
function complex_(real, imag) {
    const realTensor = convertToTensor(real, 'real', 'complex');
    const imagTensor = convertToTensor(imag, 'imag', 'complex');
    // The real and imaginary parts pair up element-for-element, so their
    // shapes must be identical.
    assertShapesMatch(realTensor.shape, imagTensor.shape, `real and imag shapes, ${realTensor.shape} and ${imagTensor.shape}, must match in call to tf.complex().`);
    return ENGINE.runKernel(Complex, { real: realTensor, imag: imagTensor });
}
const complex$2 = /* @__PURE__ */ op({ complex_ });
5584
5585 /**
5586 * @license
5587 * Copyright 2018 Google LLC. All Rights Reserved.
5588 * Licensed under the Apache License, Version 2.0 (the "License");
5589 * you may not use this file except in compliance with the License.
5590 * You may obtain a copy of the License at
5591 *
5592 * http://www.apache.org/licenses/LICENSE-2.0
5593 *
5594 * Unless required by applicable law or agreed to in writing, software
5595 * distributed under the License is distributed on an "AS IS" BASIS,
5596 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5597 * See the License for the specific language governing permissions and
5598 * limitations under the License.
5599 * =============================================================================
5600 */
/**
 * This is shared code across all tensor creation methods.
 *
 * Validates `values` against the optional user-provided `shape` and the shape
 * inferred from the nested structure of `values`, then hands the flattened
 * data to the engine. GPU-resident inputs (WebGL/WebGPU data objects) bypass
 * CPU flattening entirely and are uploaded by the backend.
 */
function makeTensor(values, shape, inferredShape, dtype) {
    if (dtype == null) {
        // No dtype given: derive it from the values themselves.
        dtype = inferDtype(values);
    }
    else if (dtype === 'complex64') {
        // complex64 tensors are only constructible via tf.complex(real, imag).
        throw new Error(`Cannot construct a complex64 tensor directly. ` +
            `Please use tf.complex(real, imag).`);
    }
    if (isWebGPUData(values) || isWebGLData(values)) {
        // GPU-backed inputs: no CPU-side copy; the backend adopts the data.
        if (dtype !== 'float32' && dtype !== 'int32') {
            throw new Error(`Creating tensor from GPU data only supports ` +
                `'float32'|'int32' dtype, while the dtype is ${dtype}.`);
        }
        return ENGINE.backend.createTensorFromGPUData(values, shape || inferredShape, dtype);
    }
    if (!isTypedArray(values) && !Array.isArray(values) &&
        typeof values !== 'number' && typeof values !== 'boolean' &&
        typeof values !== 'string') {
        throw new Error('values passed to tensor(values) must be a number/boolean/string or ' +
            'an array of numbers/booleans/strings, or a TypedArray');
    }
    // Verify that the shape matches the inferred shape.
    if (shape != null) {
        assertNonNegativeIntegerDimensions(shape);
        const providedSize = sizeFromShape(shape);
        const inferredSize = sizeFromShape(inferredShape);
        // Total element counts must agree before comparing per-dimension.
        assert$1(providedSize === inferredSize, () => `Based on the provided shape, [${shape}], the tensor should have ` +
            `${providedSize} values but has ${inferredSize}`);
        for (let i = 0; i < inferredShape.length; ++i) {
            const inferred = inferredShape[i];
            // The last inferred dimension is allowed to absorb the remaining
            // flat elements (e.g. flat values with an explicit 2-D shape);
            // earlier dimensions must match exactly.
            const flatDimsDontMatch = i === inferredShape.length - 1 ?
                inferred !== sizeFromShape(shape.slice(i)) :
                true;
            assert$1(inferredShape[i] === shape[i] || !flatDimsDontMatch, () => `Error creating a new Tensor. Inferred shape ` +
                `(${inferredShape}) does not match the provided ` +
                `shape (${shape}). `);
        }
    }
    if (!isTypedArray(values) && !Array.isArray(values)) {
        // Promote a scalar to a single-element array for uniform handling.
        values = [values];
    }
    shape = shape || inferredShape;
    // Strings stay as a flat JS array (skipping typed-array conversion);
    // everything else is packed into the typed array matching `dtype`.
    values = dtype !== 'string' ?
        toTypedArray(values, dtype) :
        flatten$2(values, [], true);
    return ENGINE.makeTensor(values, shape, dtype);
}
5649
5650 /**
5651 * @license
5652 * Copyright 2018 Google LLC. All Rights Reserved.
5653 * Licensed under the Apache License, Version 2.0 (the "License");
5654 * you may not use this file except in compliance with the License.
5655 * You may obtain a copy of the License at
5656 *
5657 * http://www.apache.org/licenses/LICENSE-2.0
5658 *
5659 * Unless required by applicable law or agreed to in writing, software
5660 * distributed under the License is distributed on an "AS IS" BASIS,
5661 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5662 * See the License for the specific language governing permissions and
5663 * limitations under the License.
5664 * =============================================================================
5665 */
5666 /**
5667 * Creates a `tf.Tensor` with the provided values, shape and dtype.
5668 *
5669 * ```js
5670 * // Pass an array of values to create a vector.
5671 * tf.tensor([1, 2, 3, 4]).print();
5672 * ```
5673 *
5674 * ```js
5675 * // Pass a nested array of values to make a matrix or a higher
5676 * // dimensional tensor.
5677 * tf.tensor([[1, 2], [3, 4]]).print();
5678 * ```
5679 *
5680 * ```js
5681 * // Pass a flat array and specify a shape yourself.
5682 * tf.tensor([1, 2, 3, 4], [2, 2]).print();
5683 * ```
5684 *
5685 * ```js
5686 * // Pass a `WebGLData` object and specify a shape yourself.
5687 *
5688 * // This makes it possible for TF.js applications to avoid GPU / CPU sync.
5689 * // For example, if your application includes a preprocessing step on the GPU,
5690 * // you could upload the GPU output directly to TF.js, rather than first
5691 * // downloading the values.
5692 *
5693 * // Example for WebGL2:
5694 * if (tf.findBackend('custom-webgl') == null) {
5695 * const customCanvas = document.createElement('canvas');
5696 * const customBackend = new tf.MathBackendWebGL(customCanvas);
5697 * tf.registerBackend('custom-webgl', () => customBackend);
5698 * }
5699 * const savedBackend = tf.getBackend();
5700 * await tf.setBackend('custom-webgl');
5701 * const gl = tf.backend().gpgpu.gl;
5702 * const texture = gl.createTexture();
5703 * const tex2d = gl.TEXTURE_2D;
5704 * const width = 2;
5705 * const height = 2;
5706 *
5707 * gl.bindTexture(tex2d, texture);
5708 * gl.texParameteri(tex2d, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
5709 * gl.texParameteri(tex2d, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
5710 * gl.texParameteri(tex2d, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
5711 * gl.texParameteri(tex2d, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
5712 * gl.texImage2D(
5713 * tex2d, 0, gl.RGBA32F, // internalFormat
5714 * width, height, 0,
5715 * gl.RGBA, // textureFormat
5716 * gl.FLOAT, // textureType
5717 * new Float32Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
5718 * );
5719 *
5720 * // Currently, the `texture` has 4 pixels:
5721 * // Pixel0 is {R:0, G:1, B:2, A:3}
5722 * // Pixel1 is {R:4, G:5, B:6, A:7}
5723 * // Pixel2 is {R:8, G:9, B:10, A:11}
5724 * // Pixel3 is {R:12, G:13, B:14, A:15}
5725 *
5726 * const logicalShape = [height * width * 2];
5727 * const a = tf.tensor({texture, height, width, channels: 'BR'}, logicalShape);
5728 * a.print();
5729 * // Tensor value will be [2, 0, 6, 4, 10, 8, 14, 12], since [2, 0] is the
 * // values of 'B' and 'R' channels of Pixel0, [6, 4] is the values of 'B'
 * // and 'R' channels of Pixel1...
5733 *
5734 * // For postprocessing on the GPU, it's possible to retrieve the texture
5735 * // backing any tensor by calling the tensor's `dataToGPU` method like
5736 * // so:
5737 *
5738 * const tex = a.dataToGPU();
5739 * await tf.setBackend(savedBackend);
5740 * ```
5741 *
5742 * ```js
5743 * // Pass a `WebGPUData` object and specify a shape yourself.
5744 *
5745 * // This makes it possible for TF.js applications to avoid GPU / CPU sync.
5746 * // For example, if your application includes a preprocessing step on the GPU,
5747 * // you could upload the GPU output directly to TF.js, rather than first
5748 * // downloading the values. Unlike WebGL, this optionally supports zero copy
5749 * // by WebGPUData.zeroCopy. When zeroCopy is false or undefined(default), this
5750 * // passing GPUBuffer can be destroyed after tensor is created. When zeroCopy
5751 * // is true, this GPUBuffer is bound directly by the tensor, so do not destroy
5752 * // this GPUBuffer until all access is done.
5753 *
5754 * // Example for WebGPU:
5755 * function createGPUBufferFromData(device, data, dtype) {
5756 * const bytesPerElement = 4;
5757 * const sizeInBytes = data.length * bytesPerElement;
5758 *
5759 * const gpuWriteBuffer = device.createBuffer({
5760 * mappedAtCreation: true,
5761 * size: sizeInBytes,
5762 * usage: GPUBufferUsage.MAP_WRITE | GPUBufferUsage.COPY_SRC
5763 * });
5764 * const arrayBuffer = gpuWriteBuffer.getMappedRange();
5765 * if (dtype === 'float32') {
5766 * new Float32Array(arrayBuffer).set(data);
5767 * } else if (dtype === 'int32') {
5768 * new Int32Array(arrayBuffer).set(data);
5769 * } else {
5770 * throw new Error(
5771 * `Creating tensor from GPUBuffer only supports` +
5772 * `'float32'|'int32' dtype, while the dtype is ${dtype}.`);
5773 * }
5774 * gpuWriteBuffer.unmap();
5775 *
5776 * const gpuReadBuffer = device.createBuffer({
5777 * mappedAtCreation: false,
5778 * size: sizeInBytes,
5779 * usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE |
5780 * GPUBufferUsage.COPY_SRC
5781 * });
5782 *
5783 * const copyEncoder = device.createCommandEncoder();
5784 * copyEncoder.copyBufferToBuffer(
5785 * gpuWriteBuffer, 0, gpuReadBuffer, 0, sizeInBytes);
5786 * const copyCommands = copyEncoder.finish();
5787 * device.queue.submit([copyCommands]);
5788 * gpuWriteBuffer.destroy();
5789 * return gpuReadBuffer;
5790 * }
5791 *
5792 * const savedBackend = tf.getBackend();
5793 * await tf.setBackend('webgpu').catch(
5794 * () => {throw new Error(
5795 * 'Failed to use WebGPU backend. Please use Chrome Canary to run.')});
5796 * const dtype = 'float32';
5797 * const device = tf.backend().device;
5798 * const aData = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
5799 * const bData = [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4];
5800 * const expected = [2, 4, 6, 8, 6, 8, 10, 12, 10, 12, 14, 16, 14, 16, 18, 20];
5801 * const aBuffer = createGPUBufferFromData(device, aData, dtype);
5802 * const shape = [aData.length];
 * // To use zeroCopy, use {buffer: aBuffer, zeroCopy: true} instead, and only
 * // destroy aBuffer after all access is done.
5805 * const a = tf.tensor({buffer: aBuffer}, shape, dtype);
5806 * const b = tf.tensor(bData, shape, dtype);
5807 * const result = tf.add(a, b);
5808 * result.print();
5809 * a.dispose();
5810 * b.dispose();
5811 * result.dispose();
5812 * aBuffer.destroy();
5813 * await tf.setBackend(savedBackend);
5814 * ```
5815 * @param values The values of the tensor. Can be nested array of numbers,
5816 * or a flat array, or a `TypedArray`(At the moment it supports Uint8Array,
5817 * Uint8ClampedArray, Int32Array, Float32Array) data types, or a `WebGLData`
5818 * object, or a `WebGPUData` object. If the values are strings, they will be
5819 * encoded as utf-8 and kept as `Uint8Array[]`. If the values is a `WebGLData`
5820 * object, the dtype could only be 'float32' or 'int32' and the object has to
5821 * have: 1. texture, a `WebGLTexture`, the texture must share the same
5822 * `WebGLRenderingContext` with TFJS's WebGL backend (you could create a custom
5823 * WebGL backend from your texture's canvas) and the internal texture format
5824 * for the input texture must be floating point or normalized integer; 2.
5825 * height, the height of the texture; 3. width, the width of the texture; 4.
5826 * channels, a non-empty subset of 'RGBA', indicating the values of which
5827 * channels will be passed to the tensor, such as 'R' or 'BR' (The order of the
5828 * channels affect the order of tensor values. ). (If the values passed from
5829 * texture is less than the tensor size, zeros will be padded at the rear.). If
5830 * the values is a `WebGPUData` object, the dtype could only be 'float32' or
 * 'int32' and the object has to have: buffer, a `GPUBuffer`. The buffer must:
5832 * 1. share the same `GPUDevice` with TFJS's WebGPU backend; 2. buffer.usage
5833 * should at least support GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC; 3.
5834 * buffer.size should not be smaller than the byte size of tensor shape.
5835 * WebGPUData optionally supports zero copy by flag zeroCopy. When zeroCopy is
5836 * false or undefined(default),this passing GPUBuffer can be destroyed after
5837 * tensor is created. When zeroCopy is true, this GPUBuffer is bound directly
5838 * by the tensor, so do not destroy this GPUBuffer until all access is done.
5839 * @param shape The shape of the tensor. Optional. If not provided,
5840 * it is inferred from `values`.
5841 * @param dtype The data type.
5842 *
5843 * @doc {heading: 'Tensors', subheading: 'Creation'}
5844 */
function tensor(values, shape, dtype) {
    // Infer the shape from the (possibly nested) values, then defer to the
    // shared creation helper, which reconciles it with the requested `shape`.
    return makeTensor(values, shape, inferShape(values, dtype), dtype);
}
5849
5850 /**
5851 * @license
5852 * Copyright 2018 Google LLC. All Rights Reserved.
5853 * Licensed under the Apache License, Version 2.0 (the "License");
5854 * you may not use this file except in compliance with the License.
5855 * You may obtain a copy of the License at
5856 *
5857 * http://www.apache.org/licenses/LICENSE-2.0
5858 *
5859 * Unless required by applicable law or agreed to in writing, software
5860 * distributed under the License is distributed on an "AS IS" BASIS,
5861 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5862 * See the License for the specific language governing permissions and
5863 * limitations under the License.
5864 * =============================================================================
5865 */
5866 /* Type definitions for exporting and importing of models. */
5867 /**
5868 * A map from Tensor dtype to number of bytes per element of the Tensor.
5869 */
const DTYPE_VALUE_SIZE_MAP = {
    'float32': 4,
    'float16': 2,
    'int32': 4,
    'uint16': 2,
    'uint8': 1,
    'bool': 1,
    'complex64': 8 // two float32 components (real + imaginary)
};
5879
5880 /**
5881 * Wraps a list of ArrayBuffers into a `slice()`-able object without allocating
5882 * a large ArrayBuffer.
5883 *
5884 * Allocating large ArrayBuffers (~2GB) can be unstable on Chrome. TFJS loads
5885 * its weights as a list of (usually) 4MB ArrayBuffers and then slices the
5886 * weight tensors out of them. For small models, it's safe to concatenate all
5887 * the weight buffers into a single ArrayBuffer and then slice the weight
5888 * tensors out of it, but for large models, a different approach is needed.
5889 */
class CompositeArrayBuffer {
    /**
     * Concatenate a number of ArrayBuffers into one.
     *
     * @param buffers An array of ArrayBuffers to concatenate, or a single
     *     ArrayBuffer.
     * @returns Result of concatenating `buffers` in order.
     */
    static join(buffers) {
        return new CompositeArrayBuffer(buffers).slice();
    }
    constructor(buffers) {
        this.shards = [];
        this.previousShardIndex = 0;
        // Initialize up front so `byteLength` is always a number, even when no
        // buffers are provided (previously it was left undefined in that case).
        this.byteLength = 0;
        if (buffers == null) {
            return;
        }
        // Normalize the `buffers` input to be `ArrayBuffer[]`.
        if (!(buffers instanceof Array)) {
            buffers = [buffers];
        }
        buffers = buffers.map((bufferOrTypedArray) => {
            if (isTypedArray(bufferOrTypedArray)) {
                return bufferOrTypedArray.buffer;
            }
            return bufferOrTypedArray;
        });
        // Skip setting up shards if there are no buffers.
        if (buffers.length === 0) {
            return;
        }
        // If every buffer except possibly the last shares one byte length,
        // shard lookup can be a division instead of a search.
        this.bufferUniformSize = buffers[0].byteLength;
        let start = 0;
        for (let i = 0; i < buffers.length; i++) {
            const buffer = buffers[i];
            // Check that all buffers except the last one have the same length.
            if (i !== buffers.length - 1 &&
                buffer.byteLength !== this.bufferUniformSize) {
                // Unset the buffer uniform size, since the buffer sizes are not
                // uniform.
                this.bufferUniformSize = undefined;
            }
            // Create the shards, including their start and end points.
            const end = start + buffer.byteLength;
            this.shards.push({ buffer, start, end });
            start = end;
        }
        // Set the byteLength to the end of the final shard.
        this.byteLength = this.shards[this.shards.length - 1].end;
    }
    /**
     * Copy bytes [start, end) into a freshly allocated ArrayBuffer, mirroring
     * `ArrayBuffer.prototype.slice` semantics (bounds clamped, NaN -> 0).
     */
    slice(start = 0, end = this.byteLength) {
        // If there are no shards, then the CompositeArrayBuffer was initialized
        // with no data.
        if (this.shards.length === 0) {
            return new ArrayBuffer(0);
        }
        // NaN is treated as zero for slicing. This matches ArrayBuffer's behavior.
        start = isNaN(Number(start)) ? 0 : start;
        end = isNaN(Number(end)) ? 0 : end;
        // Fix the bounds to within the array.
        start = Math.max(0, start);
        end = Math.min(this.byteLength, end);
        if (end <= start) {
            return new ArrayBuffer(0);
        }
        const startShardIndex = this.findShardForByte(start);
        if (startShardIndex === -1) {
            // This should not happen since the start and end indices are always
            // within 0 and the composite array's length.
            throw new Error(`Could not find start shard for byte ${start}`);
        }
        const size = end - start;
        const outputBuffer = new ArrayBuffer(size);
        const outputArray = new Uint8Array(outputBuffer);
        let sliced = 0;
        // Walk consecutive shards, copying the overlap of each with [start, end).
        for (let i = startShardIndex; i < this.shards.length; i++) {
            const shard = this.shards[i];
            const globalStart = start + sliced;
            const localStart = globalStart - shard.start;
            const outputStart = sliced;
            const globalEnd = Math.min(end, shard.end);
            const localEnd = globalEnd - shard.start;
            const outputSlice = new Uint8Array(shard.buffer, localStart, localEnd - localStart);
            outputArray.set(outputSlice, outputStart);
            sliced += outputSlice.length;
            if (end < shard.end) {
                break;
            }
        }
        return outputBuffer;
    }
    /**
     * Get the index of the shard that contains the byte at `byteIndex`,
     * or -1 when `byteIndex` is out of range.
     */
    findShardForByte(byteIndex) {
        if (this.shards.length === 0 || byteIndex < 0 ||
            byteIndex >= this.byteLength) {
            return -1;
        }
        // If the buffers have a uniform size, compute the shard directly. Note
        // that only non-last buffers are required to match the uniform size, so
        // the final shard may be larger; clamp the computed index to the valid
        // range, otherwise a byte inside an oversized final shard would map
        // past the end of `this.shards`.
        if (this.bufferUniformSize != null) {
            this.previousShardIndex = Math.min(Math.floor(byteIndex / this.bufferUniformSize), this.shards.length - 1);
            return this.previousShardIndex;
        }
        // If the buffers don't have a uniform size, we need to search for the
        // shard. That means we need a function to check where the byteIndex lies
        // relative to a given shard.
        function check(shard) {
            if (byteIndex < shard.start) {
                return -1;
            }
            if (byteIndex >= shard.end) {
                return 1;
            }
            return 0;
        }
        // For efficiency, try the previous shard first.
        if (check(this.shards[this.previousShardIndex]) === 0) {
            return this.previousShardIndex;
        }
        // Otherwise, use a generic search function.
        // This should almost never end up being used in practice since the weight
        // entries should always be in order.
        const index = search(this.shards, check);
        if (index === -1) {
            return -1;
        }
        this.previousShardIndex = index;
        return this.previousShardIndex;
    }
}
6024 /**
6025 * Search for an element of a sorted array.
6026 *
6027 * @param sortedArray The sorted array to search
6028 * @param compare A function to compare the current value against the searched
6029 * value. Return 0 on a match, negative if the searched value is less than
6030 * the value passed to the function, and positive if the searched value is
6031 * greater than the value passed to the function.
6032 * @returns The index of the element, or -1 if it's not in the array.
6033 */
function search(sortedArray, compare) {
    // Binary search over the inclusive index range [min, max].
    //
    // The previous implementation used `max = sortedArray.length` with
    // `max = middle` on the low side, which (a) could call `compare` with
    // `sortedArray[length]` (undefined) and (b) looped forever when the
    // searched value was absent but compared below an element (min === max ===
    // middle never shrinks). Using an inclusive upper bound with
    // `max = middle - 1` fixes both.
    let min = 0;
    let max = sortedArray.length - 1;
    while (min <= max) {
        const middle = min + Math.floor((max - min) / 2);
        const side = compare(sortedArray[middle]);
        if (side === 0) {
            return middle;
        }
        else if (side < 0) {
            max = middle - 1;
        }
        else {
            min = middle + 1;
        }
    }
    return -1;
}
6053
6054 /**
6055 * @license
6056 * Copyright 2018 Google LLC. All Rights Reserved.
6057 * Licensed under the Apache License, Version 2.0 (the "License");
6058 * you may not use this file except in compliance with the License.
6059 * You may obtain a copy of the License at
6060 *
6061 * http://www.apache.org/licenses/LICENSE-2.0
6062 *
6063 * Unless required by applicable law or agreed to in writing, software
6064 * distributed under the License is distributed on an "AS IS" BASIS,
6065 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
6066 * See the License for the specific language governing permissions and
6067 * limitations under the License.
6068 * =============================================================================
6069 */
6070 /**
6071 * Enables production mode which disables correctness checks in favor of
6072 * performance.
6073 *
6074 * @doc {heading: 'Environment'}
6075 */
function enableProdMode() {
    // PROD trades correctness checks for performance (see JSDoc above).
    env().set('PROD', true);
}
6079 /**
6080 * Enables debug mode which will log information about all executed kernels:
6081 * the elapsed time of the kernel execution, as well as the rank, shape, and
6082 * size of the output tensor.
6083 *
6084 * Debug mode will significantly slow down your application as it will
6085 * download the result of every operation to the CPU. This should not be used in
6086 * production. Debug mode does not affect the timing information of the kernel
6087 * execution as we do not measure download time in the kernel execution time.
6088 *
6089 * See also: `tf.profile`, `tf.memory`.
6090 *
6091 * @doc {heading: 'Environment'}
6092 */
function enableDebugMode() {
    // DEBUG enables per-kernel logging; expensive, hence opt-in (see JSDoc
    // above for the performance caveats).
    env().set('DEBUG', true);
}
/** Globally disables deprecation warnings */
function disableDeprecationWarnings() {
    const environment = env();
    environment.set('DEPRECATION_WARNINGS_ENABLED', false);
    console.warn('TensorFlow.js deprecation warnings have been disabled.');
}
/** Warn users about deprecated functionality. */
function deprecationWarn(msg) {
    // Honor the global opt-out flag before emitting anything.
    if (!env().getBool('DEPRECATION_WARNINGS_ENABLED')) {
        return;
    }
    console.warn(`${msg} You can disable deprecation warnings with tf.disableDeprecationWarnings().`);
}
setDeprecationWarningFn(deprecationWarn);
6109 /**
6110 * Dispose all variables kept in backend engine.
6111 *
6112 * @doc {heading: 'Environment'}
6113 */
function disposeVariables() {
    // Variables are not cleaned up by tf.tidy (see tidy's JSDoc), so this is
    // the global cleanup entry point for them.
    ENGINE.disposeVariables();
}
6117 /**
6118 * It returns the global engine that keeps track of all tensors and backends.
6119 *
6120 * @doc {heading: 'Environment'}
6121 */
function engine() {
    // Exposes the singleton engine that tracks all tensors and backends.
    return ENGINE;
}
6125 /**
6126 * Returns memory info at the current time in the program. The result is an
6127 * object with the following properties:
6128 *
6129 * - `numBytes`: Number of bytes allocated (undisposed) at this time.
6130 * - `numTensors`: Number of unique tensors allocated.
6131 * - `numDataBuffers`: Number of unique data buffers allocated
6132 * (undisposed) at this time, which is ≤ the number of tensors
6133 * (e.g. `a.reshape(newShape)` makes a new Tensor that shares the same
6134 * data buffer with `a`).
6135 * - `unreliable`: True if the memory usage is unreliable. See `reasons` when
6136 * `unreliable` is true.
6137 * - `reasons`: `string[]`, reasons why the memory is unreliable, present if
6138 * `unreliable` is true.
6139 *
6140 * WebGL Properties:
6141 * - `numBytesInGPU`: Number of bytes allocated (undisposed) in the GPU only at
6142 * this time.
6143 *
6144 * @doc {heading: 'Performance', subheading: 'Memory'}
6145 */
function memory() {
    // Delegates to the global engine; the JSDoc above documents the fields of
    // the returned object.
    return ENGINE.memory();
}
6149 /**
6150 * Executes the provided function `f()` and returns a promise that resolves
6151 * with information about the function's memory use:
6152 * - `newBytes`: the number of new bytes allocated
6153 * - `newTensors`: the number of new tensors created
6154 * - `peakBytes`: the peak number of bytes allocated
6155 * - `kernels`: an array of objects for each kernel involved that reports
6156 * their input and output shapes, number of bytes used, and number of new
6157 * tensors created.
6158 * - `kernelNames`: an array of unique strings with just the names of the
6159 * kernels in the `kernels` array.
6160 *
6161 * ```js
6162 * const profile = await tf.profile(() => {
6163 * const x = tf.tensor1d([1, 2, 3]);
6164 * let x2 = x.square();
6165 * x2.dispose();
6166 * x2 = x.square();
6167 * x2.dispose();
6168 * return x;
6169 * });
6170 *
6171 * console.log(`newBytes: ${profile.newBytes}`);
6172 * console.log(`newTensors: ${profile.newTensors}`);
6173 * console.log(`byte usage over all kernels: ${profile.kernels.map(k =>
6174 * k.totalBytesSnapshot)}`);
6175 * ```
6176 *
6177 *
6178 * @doc {heading: 'Performance', subheading: 'Profile'}
6179 */
function profile(f) {
    // The engine runs `f` and collects the per-kernel statistics described in
    // the JSDoc above.
    return ENGINE.profile(f);
}
6183 /**
6184 * Executes the provided function `fn` and after it is executed, cleans up all
6185 * intermediate tensors allocated by `fn` except those returned by `fn`.
6186 * `fn` must not return a Promise (async functions not allowed). The returned
6187 * result can be a complex object.
6188 *
6189 * Using this method helps avoid memory leaks. In general, wrap calls to
6190 * operations in `tf.tidy` for automatic memory cleanup.
6191 *
6192 * NOTE: Variables do *not* get cleaned up when inside a tidy(). If you want to
6193 * dispose variables, please use `tf.disposeVariables` or call dispose()
6194 * directly on variables.
6195 *
6196 * ```js
6197 * // y = 2 ^ 2 + 1
6198 * const y = tf.tidy(() => {
6199 * // a, b, and one will be cleaned up when the tidy ends.
6200 * const one = tf.scalar(1);
6201 * const a = tf.scalar(2);
6202 * const b = a.square();
6203 *
6204 * console.log('numTensors (in tidy): ' + tf.memory().numTensors);
6205 *
6206 * // The value returned inside the tidy function will return
6207 * // through the tidy, in this case to the variable y.
6208 * return b.add(one);
6209 * });
6210 *
6211 * console.log('numTensors (outside tidy): ' + tf.memory().numTensors);
6212 * y.print();
6213 * ```
6214 *
6215 * @param nameOrFn The name of the closure, or the function to execute.
6216 * If a name is provided, the 2nd argument should be the function.
6217 * If debug mode is on, the timing and the memory usage of the function
6218 * will be tracked and displayed on the console using the provided name.
6219 * @param fn The function to execute.
6220 *
6221 * @doc {heading: 'Performance', subheading: 'Memory'}
6222 */
function tidy(nameOrFn, fn) {
    // Scope bookkeeping (tracking and disposal of intermediates) lives in the
    // engine; this is a thin public wrapper.
    return ENGINE.tidy(nameOrFn, fn);
}
6226 /**
6227 * Disposes any `tf.Tensor`s found within the provided object.
6228 *
6229 * @param container an object that may be a `tf.Tensor` or may directly
6230 * contain `tf.Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. If
6231 * the object is not a `tf.Tensor` or does not contain `Tensors`, nothing
6232 * happens. In general it is safe to pass any object here, except that
6233 * `Promise`s are not supported.
6234 *
6235 * @doc {heading: 'Performance', subheading: 'Memory'}
6236 */
function dispose(container) {
    // Collect every tensor nested inside `container`, then release each one.
    for (const tensorToDispose of getTensorsInContainer(container)) {
        tensorToDispose.dispose();
    }
}
6241 /**
6242 * Keeps a `tf.Tensor` generated inside a `tf.tidy` from being disposed
6243 * automatically.
6244 *
6245 * ```js
6246 * let b;
6247 * const y = tf.tidy(() => {
6248 * const one = tf.scalar(1);
6249 * const a = tf.scalar(2);
6250 *
6251 * // b will not be cleaned up by the tidy. a and one will be cleaned up
6252 * // when the tidy ends.
6253 * b = tf.keep(a.square());
6254 *
6255 * console.log('numTensors (in tidy): ' + tf.memory().numTensors);
6256 *
6257 * // The value returned inside the tidy function will return
6258 * // through the tidy, in this case to the variable y.
6259 * return b.add(one);
6260 * });
6261 *
6262 * console.log('numTensors (outside tidy): ' + tf.memory().numTensors);
6263 * console.log('y:');
6264 * y.print();
6265 * console.log('b:');
6266 * b.print();
6267 * ```
6268 *
6269 * @param result The tensor to keep from being disposed.
6270 *
6271 * @doc {heading: 'Performance', subheading: 'Memory'}
6272 */
function keep(result) {
    // Marks `result` in the engine so the enclosing tidy() will not dispose it.
    return ENGINE.keep(result);
}
6276 /**
6277 * Executes `f()` and returns a promise that resolves with timing
6278 * information.
6279 *
6280 * The result is an object with the following properties:
6281 *
6282 * - `wallMs`: Wall execution time.
6283 * - `kernelMs`: Kernel execution time, ignoring data transfer. If using the
6284 * WebGL backend and the query timer extension is not available, this will
6285 * return an error object.
6286 * - On `WebGL` The following additional properties exist:
6287 * - `uploadWaitMs`: CPU blocking time on texture uploads.
6288 * - `downloadWaitMs`: CPU blocking time on texture downloads (readPixels).
6289 *
6290 * ```js
6291 * const x = tf.randomNormal([20, 20]);
6292 * const time = await tf.time(() => x.matMul(x));
6293 *
6294 * console.log(`kernelMs: ${time.kernelMs}, wallTimeMs: ${time.wallMs}`);
6295 * ```
6296 *
6297 * @param f The function to execute and time.
6298 *
6299 * @doc {heading: 'Performance', subheading: 'Timing'}
6300 */
function time(f) {
    // Measurement is delegated to the engine; the returned promise resolves
    // with the timing fields described in the JSDoc above.
    return ENGINE.time(f);
}
6304 /**
6305 * Sets the backend (cpu, webgl, wasm, etc) responsible for creating tensors and
6306 * executing operations on those tensors. Returns a promise that resolves
6307 * to a boolean if the backend initialization was successful.
6308 *
6309 * Note this disposes the current backend, if any, as well as any tensors
6310 * associated with it. A new backend is initialized, even if it is of the
6311 * same type as the previous one.
6312 *
6313 * @param backendName The name of the backend. Currently supports
6314 * `'webgl'|'cpu'` in the browser, `'tensorflow'` under node.js
6315 * (requires tfjs-node), and `'wasm'` (requires tfjs-backend-wasm).
6316 *
6317 * @doc {heading: 'Backends'}
6318 */
function setBackend$1(backendName) {
    // Returns a promise resolving to whether backend initialization succeeded.
    return ENGINE.setBackend(backendName);
}
6322 /**
6323 * Returns a promise that resolves when the currently selected backend (or the
6324 * highest priority one) has initialized. Await this promise when you are using
6325 * a backend that has async initialization.
6326 *
6327 * @doc {heading: 'Backends'}
6328 */
    function ready() {
        // Resolves once the selected (or highest-priority) backend has finished
        // any async initialization.
        return ENGINE.ready();
    }
6332 /**
6333 * Returns the current backend name (cpu, webgl, etc). The backend is
6334 * responsible for creating tensors and executing operations on those tensors.
6335 *
6336 * @doc {heading: 'Backends'}
6337 */
    function getBackend$1() {
        // Returns the name string of the currently active backend.
        return ENGINE.backendName;
    }
6341 /**
6342 * Removes a backend and the registered factory.
6343 *
6344 * @doc {heading: 'Backends'}
6345 */
    function removeBackend(name) {
        // Removes both the backend instance and its registered factory.
        ENGINE.removeBackend(name);
    }
6349 /**
6350 * Finds the backend registered under the provided name. Returns null if the
6351 * name is not in the registry, or the registration hasn't finished yet.
6352 */
    function findBackend(name) {
        // Returns the registered backend instance, or null if not (yet) registered.
        return ENGINE.findBackend(name);
    }
6356 /**
6357 * Finds the backend factory registered under the provided name. Returns a
6358 * function that produces a new backend when called. Returns null if the name
6359 * is not in the registry.
6360 */
    function findBackendFactory(name) {
        // Returns the factory function registered under `name`, or null.
        return ENGINE.findBackendFactory(name);
    }
6364 /**
6365 * Registers a global backend. The registration should happen when importing
6366 * a module file (e.g. when importing `backend_webgl.ts`), and is used for
6367 * modular builds (e.g. custom tfjs bundle with only webgl support).
6368 *
6369 * @param factory The backend factory function. When called, it should
6370 * return a backend instance, or a promise of an instance.
6371 * @param priority The priority of the backend (higher = more important).
6372 * In case multiple backends are registered, the priority is used to find
6373 * the best backend. Defaults to 1.
6374 * @return False if there is already a registered backend under this name, true
6375 * if not.
6376 *
6377 * @doc {heading: 'Backends'}
6378 */
    function registerBackend(name, factory, priority = 1) {
        // Registers the factory with the global engine; returns false when a
        // backend is already registered under `name`.
        return ENGINE.registerBackend(name, factory, priority);
    }
6382 /**
6383 * Gets the current backend. If no backends have been initialized, this will
6384 * attempt to initialize the best backend. Will throw an error if the highest
6385 * priority backend has async initialization, in which case you should call
6386 * 'await tf.ready()' before running other code.
6387 *
6388 * @doc {heading: 'Backends'}
6389 */
    function backend$1() {
        // Returns the active backend instance; may trigger (sync) initialization
        // of the best-priority backend — see the JSDoc above regarding tf.ready().
        return ENGINE.backend;
    }
6393 /**
6394 * Sets the global platform.
6395 *
6396 * @param platformName The name of this platform.
6397 * @param platform A platform implementation.
6398 */
    function setPlatform(platformName, platform) {
        // Installs the platform implementation on the global environment object.
        env().setPlatform(platformName, platform);
    }
6402
6403 /**
6404 * @license
6405 * Copyright 2018 Google LLC. All Rights Reserved.
6406 * Licensed under the Apache License, Version 2.0 (the "License");
6407 * you may not use this file except in compliance with the License.
6408 * You may obtain a copy of the License at
6409 *
6410 * http://www.apache.org/licenses/LICENSE-2.0
6411 *
6412 * Unless required by applicable law or agreed to in writing, software
6413 * distributed under the License is distributed on an "AS IS" BASIS,
6414 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
6415 * See the License for the specific language governing permissions and
6416 * limitations under the License.
6417 * =============================================================================
6418 */
6419 /** Number of bytes reserved for the length of the string. (32bit integer). */
6420 const NUM_BYTES_STRING_LENGTH = 4;
6421 /**
6422 * Encode a map from names to weight values as an ArrayBuffer, along with an
6423 * `Array` of `WeightsManifestEntry` as specification of the encoded weights.
6424 *
6425 * This function does not perform sharding.
6426 *
6427 * This function is the reverse of `decodeWeights`.
6428 *
6429 * @param tensors A map ("dict") from names to tensors.
6430 * @param group Group to which the weights belong (optional).
6431 * @returns A `Promise` of
6432 * - A flat `ArrayBuffer` with all the binary values of the `Tensor`s
6433 * concatenated.
6434 * - An `Array` of `WeightManifestEntry`s, carrying information including
6435 * tensor names, `dtype`s and shapes.
6436 * @throws Error: on unsupported tensor `dtype`.
6437 */
6438 async function encodeWeights(tensors, group) {
6439 // TODO(adarob, cais): Support quantization.
6440 const specs = [];
6441 const dataPromises = [];
6442 const names = Array.isArray(tensors) ?
6443 tensors.map(tensor => tensor.name) :
6444 Object.keys(tensors);
6445 for (let i = 0; i < names.length; ++i) {
6446 const name = names[i];
6447 const t = Array.isArray(tensors) ? tensors[i].tensor : tensors[name];
6448 if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool' &&
6449 t.dtype !== 'string' && t.dtype !== 'complex64') {
6450 throw new Error(`Unsupported dtype in weight '${name}': ${t.dtype}`);
6451 }
6452 const spec = { name, shape: t.shape, dtype: t.dtype };
6453 if (t.dtype === 'string') {
6454 const utf8bytes = new Promise(async (resolve) => {
6455 const vals = await t.bytes();
6456 const totalNumBytes = vals.reduce((p, c) => p + c.length, 0) +
6457 NUM_BYTES_STRING_LENGTH * vals.length;
6458 const bytes = new Uint8Array(totalNumBytes);
6459 let offset = 0;
6460 for (let i = 0; i < vals.length; i++) {
6461 const val = vals[i];
6462 const bytesOfLength = new Uint8Array(new Uint32Array([val.length]).buffer);
6463 bytes.set(bytesOfLength, offset);
6464 offset += NUM_BYTES_STRING_LENGTH;
6465 bytes.set(val, offset);
6466 offset += val.length;
6467 }
6468 resolve(bytes);
6469 });
6470 dataPromises.push(utf8bytes);
6471 }
6472 else {
6473 dataPromises.push(t.data());
6474 }
6475 if (group != null) {
6476 spec.group = group;
6477 }
6478 specs.push(spec);
6479 }
6480 const tensorValues = await Promise.all(dataPromises);
6481 return { data: concatenateTypedArrays(tensorValues), specs };
6482 }
6483 /**
6484 * Decode flat ArrayBuffer as weights.
6485 *
6486 * This function does not handle sharding.
6487 *
6488 * This function is the reverse of `encodeWeights`.
6489 *
6490 * @param weightData A flat ArrayBuffer or an array of ArrayBuffers carrying the
6491 * binary values of the tensors concatenated in the order specified in
6492 * `specs`.
6493 * @param specs Specifications of the names, dtypes and shapes of the tensors
6494 * whose value are encoded by `buffer`.
6495 * @return A map from tensor name to tensor value, with the names corresponding
6496 * to names in `specs`.
6497 * @throws Error, if any of the tensors has unsupported dtype.
6498 */
6499 function decodeWeights(weightData, specs) {
6500 // TODO(adarob, cais): Support quantization.
6501 const compositeBuffer = new CompositeArrayBuffer(weightData);
6502 const out = {};
6503 let offset = 0;
6504 for (const spec of specs) {
6505 const byteLength = getWeightBytelength(spec, (start, end) => {
6506 return compositeBuffer.slice(offset + start, offset + end);
6507 });
6508 out[spec.name] = decodeWeight(spec, compositeBuffer
6509 .slice(offset, offset + byteLength));
6510 offset += byteLength;
6511 }
6512 return out;
6513 }
6514 function getWeightBytelength(spec, slice) {
6515 const size = sizeFromShape(spec.shape);
6516 let bytesPerValue;
6517 if ('quantization' in spec) {
6518 const quantization = spec.quantization;
6519 bytesPerValue = DTYPE_VALUE_SIZE_MAP[quantization.dtype];
6520 }
6521 else if (spec.dtype === 'string') {
6522 // Can not statically determine string length.
6523 let byteLength = 0;
6524 for (let i = 0; i < size; i++) {
6525 byteLength += NUM_BYTES_STRING_LENGTH + new Uint32Array(slice(byteLength, byteLength + NUM_BYTES_STRING_LENGTH))[0];
6526 }
6527 return byteLength;
6528 }
6529 else {
6530 bytesPerValue = DTYPE_VALUE_SIZE_MAP[spec.dtype];
6531 }
6532 return size * bytesPerValue;
6533 }
6534 async function getWeightBytelengthAsync(spec, slice) {
6535 const size = sizeFromShape(spec.shape);
6536 let bytesPerValue;
6537 if ('quantization' in spec) {
6538 const quantization = spec.quantization;
6539 bytesPerValue = DTYPE_VALUE_SIZE_MAP[quantization.dtype];
6540 }
6541 else if (spec.dtype === 'string') {
6542 // Can not statically determine string length.
6543 let byteLength = 0;
6544 for (let i = 0; i < size; i++) {
6545 byteLength += NUM_BYTES_STRING_LENGTH + new Uint32Array(await slice(byteLength, byteLength + NUM_BYTES_STRING_LENGTH))[0];
6546 }
6547 return byteLength;
6548 }
6549 else {
6550 bytesPerValue = DTYPE_VALUE_SIZE_MAP[spec.dtype];
6551 }
6552 return size * bytesPerValue;
6553 }
    // Decodes a single weight from `byteBuffer` according to `spec`, returning a
    // tensor. Handles uint8/uint16/float16 quantization, string weights (4-byte
    // length-prefixed entries), complex64, and the plain numeric dtypes.
    function decodeWeight(spec, byteBuffer) {
        const name = spec.name;
        const dtype = spec.dtype;
        const shape = spec.shape;
        const size = sizeFromShape(shape);
        let values;
        let offset = 0;
        if ('quantization' in spec) {
            const quantization = spec.quantization;
            // Validate that the quantization metadata is complete for its dtype
            // before attempting to dequantize.
            if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {
                if (!('min' in quantization && 'scale' in quantization)) {
                    throw new Error(`Weight ${spec.name} with quantization ${quantization.dtype} ` +
                        `doesn't have corresponding metadata min and scale.`);
                }
            }
            else if (quantization.dtype === 'float16') {
                // float16 quantization is only defined for float32 weights.
                if (dtype !== 'float32') {
                    throw new Error(`Weight ${spec.name} is quantized with ${quantization.dtype} ` +
                        `which only supports weights of type float32 not ${dtype}.`);
                }
            }
            else {
                throw new Error(`Weight ${spec.name} has unknown ` +
                    `quantization dtype ${quantization.dtype}. ` +
                    `Supported quantization dtypes are: ` +
                    `'uint8', 'uint16', and 'float16'.`);
            }
            const quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype];
            // float16 values are read as uint16 bit patterns and decoded below.
            const quantizedArray = (quantization.dtype === 'uint8') ?
                new Uint8Array(byteBuffer) :
                new Uint16Array(byteBuffer);
            if (dtype === 'float32') {
                if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {
                    // Affine dequantization: value = q * scale + min.
                    values = new Float32Array(quantizedArray.length);
                    for (let i = 0; i < quantizedArray.length; i++) {
                        const v = quantizedArray[i];
                        values[i] = v * quantization.scale + quantization.min;
                    }
                }
                else if (quantization.dtype === 'float16') {
                    // TODO: This is inefficient. Make getFloat16Decoder efficient.
                    const float16Decode = getFloat16Decoder();
                    values = float16Decode(quantizedArray);
                }
                else {
                    throw new Error(`Unsupported quantization type ${quantization.dtype} ` +
                        `for weight type float32.`);
                }
            }
            else if (dtype === 'int32') {
                // int32 weights only support the affine (uint8/uint16) schemes;
                // round after dequantizing to stay integral.
                if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') {
                    throw new Error(`Unsupported quantization type ${quantization.dtype} ` +
                        `for weight type int32.`);
                }
                values = new Int32Array(quantizedArray.length);
                for (let i = 0; i < quantizedArray.length; i++) {
                    const v = quantizedArray[i];
                    values[i] = Math.round(v * quantization.scale + quantization.min);
                }
            }
            else {
                throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);
            }
            offset += size * quantizationSizeFactor;
        }
        else if (dtype === 'string') {
            // Each string entry: uint32 byte length followed by that many bytes.
            const size = sizeFromShape(spec.shape);
            values = [];
            for (let i = 0; i < size; i++) {
                const byteLength = new Uint32Array(byteBuffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0];
                offset += NUM_BYTES_STRING_LENGTH;
                const bytes = new Uint8Array(byteBuffer.slice(offset, offset + byteLength));
                values.push(bytes);
                offset += byteLength;
            }
        }
        else {
            const dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype];
            if (dtype === 'float32') {
                values = new Float32Array(byteBuffer);
            }
            else if (dtype === 'int32') {
                values = new Int32Array(byteBuffer);
            }
            else if (dtype === 'bool') {
                values = new Uint8Array(byteBuffer);
            }
            else if (dtype === 'complex64') {
                // complex64 is stored as interleaved [real, imag] float32 pairs;
                // split into two tensors and combine, disposing the temporaries.
                values = new Float32Array(byteBuffer);
                const real = new Float32Array(values.length / 2);
                const image = new Float32Array(values.length / 2);
                for (let i = 0; i < real.length; i++) {
                    real[i] = values[i * 2];
                    image[i] = values[i * 2 + 1];
                }
                const realTensor = tensor(real, shape, 'float32');
                const imageTensor = tensor(image, shape, 'float32');
                const complexTensor = complex$2(realTensor, imageTensor);
                realTensor.dispose();
                imageTensor.dispose();
                return complexTensor;
            }
            else {
                throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);
            }
            offset += size * dtypeFactor;
        }
        return tensor(values, shape, dtype);
    }
6663 async function readToLength(reader, initialData, length) {
6664 let data = new Uint8Array(initialData);
6665 while (data.byteLength < length) {
6666 const { done, value } = await reader.read();
6667 if (done && value == null) {
6668 const missing = length - data.byteLength;
6669 throw new Error(`Reader is done but ${missing} bytes are still expected`);
6670 }
6671 // TODO: Don't create a new array every loop.
6672 const newData = new Uint8Array(data.length + value.byteLength);
6673 newData.set(data, 0);
6674 newData.set(new Uint8Array(value), data.length);
6675 data = newData;
6676 }
6677 return data.buffer;
6678 }
    // Decodes weights incrementally from a ReadableStream of bytes, producing a
    // name->tensor map. Bytes are pulled only as far as each spec requires, so
    // the full weight blob never needs to be resident at once.
    async function decodeWeightsStream(weightStream, specs) {
        const tensors = {};
        const reader = weightStream.getReader();
        // Rolling buffer of bytes read from the stream but not yet consumed.
        let data = new ArrayBuffer(0);
        for (const spec of specs) {
            // For string weights the byte length is data-dependent; the async
            // slice callback pulls more bytes from the stream on demand.
            const byteLength = await getWeightBytelengthAsync(spec, async (start, end) => {
                data = await readToLength(reader, data, end);
                return data.slice(start, end);
            });
            data = await readToLength(reader, data, byteLength);
            // Slice the tensor out
            const tensorData = data.slice(0, byteLength);
            data = data.slice(byteLength);
            const weightTensor = decodeWeight(spec, tensorData);
            tensors[spec.name] = weightTensor;
            // TODO(mattsoulanille): Better way to call uploadToGPU.
            // TODO(mattsoulanille): Make this work for webgl too.
            // Eagerly push large weights to the GPU on the webgpu backend so the
            // CPU copy can be released sooner.
            if (getBackend$1() === 'webgpu') {
                const b = backend$1();
                if ('uploadToGPU' in b &&
                    sizeFromShape(weightTensor.shape) >= env()
                        .get('WEBGPU_CPU_HANDOFF_SIZE_THRESHOLD')) {
                    b.uploadToGPU(weightTensor.dataId);
                }
            }
        }
        return tensors;
    }
6707 /**
6708 * Concatenate TypedArrays into an ArrayBuffer.
6709 */
6710 function concatenateTypedArrays(xs) {
6711 // TODO(adarob, cais): Support quantization.
6712 if (xs === null) {
6713 throw new Error(`Invalid input value: ${JSON.stringify(xs)}`);
6714 }
6715 let totalByteLength = 0;
6716 // `normalizedXs` is here for this reason: a `TypedArray`'s `buffer'
6717 // can have a different byte length from that of the `TypedArray` itself,
6718 // for example, when the `TypedArray` is created from an offset in an
6719 // `ArrayBuffer`. `normliazedXs` holds `TypedArray`s whose `buffer`s match
6720 // the `TypedArray` in byte length. If an element of `xs` does not show
6721 // this property, a new `TypedArray` that satisfy this property will be
6722 // constructed and pushed into `normalizedXs`.
6723 const normalizedXs = [];
6724 xs.forEach((x) => {
6725 totalByteLength += x.byteLength;
6726 // tslint:disable:no-any
6727 normalizedXs.push(x.byteLength === x.buffer.byteLength ? x :
6728 new x.constructor(x));
6729 if (!(x instanceof Float32Array || x instanceof Int32Array ||
6730 x instanceof Uint8Array)) {
6731 throw new Error(`Unsupported TypedArray subtype: ${x.constructor.name}`);
6732 }
6733 // tslint:enable:no-any
6734 });
6735 const y = new Uint8Array(totalByteLength);
6736 let offset = 0;
6737 normalizedXs.forEach((x) => {
6738 y.set(new Uint8Array(x.buffer), offset);
6739 offset += x.byteLength;
6740 });
6741 return y.buffer;
6742 }
    // Use Buffer on Node.js instead of Blob/atob/btoa
    // (true when `Buffer` exists and at least one browser API is missing; the
    // string/base64 helpers below branch on this flag).
    const useNodeBuffer = typeof Buffer !== 'undefined' &&
        (typeof Blob === 'undefined' || typeof atob === 'undefined' ||
            typeof btoa === 'undefined');
6747 /**
6748 * Calculate the byte length of a JavaScript string.
6749 *
6750 * Note that a JavaScript string can contain wide characters, therefore the
6751 * length of the string is not necessarily equal to the byte length.
6752 *
6753 * @param str Input string.
6754 * @returns Byte length.
6755 */
6756 function stringByteLength(str) {
6757 if (useNodeBuffer) {
6758 return Buffer.byteLength(str, 'utf8');
6759 }
6760 return new Blob([str]).size;
6761 }
6762 /**
6763 * Encode an ArrayBuffer as a base64 encoded string.
6764 *
6765 * @param buffer `ArrayBuffer` to be converted.
6766 * @returns A string that base64-encodes `buffer`.
6767 */
6768 function arrayBufferToBase64String(buffer) {
6769 if (useNodeBuffer) {
6770 return Buffer.from(buffer).toString('base64');
6771 }
6772 const buf = new Uint8Array(buffer);
6773 let s = '';
6774 for (let i = 0, l = buf.length; i < l; i++) {
6775 s += String.fromCharCode(buf[i]);
6776 }
6777 return btoa(s);
6778 }
6779 /**
6780 * Decode a base64 string as an ArrayBuffer.
6781 *
6782 * @param str Base64 string.
6783 * @returns Decoded `ArrayBuffer`.
6784 */
6785 function base64StringToArrayBuffer(str) {
6786 if (useNodeBuffer) {
6787 const buf = Buffer.from(str, 'base64');
6788 return buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength);
6789 }
6790 const s = atob(str);
6791 const buffer = new Uint8Array(s.length);
6792 for (let i = 0; i < s.length; ++i) {
6793 buffer.set([s.charCodeAt(i)], i);
6794 }
6795 return buffer.buffer;
6796 }
6797 /**
6798 * Concatenate a number of ArrayBuffers into one.
6799 *
6800 * @param buffers An array of ArrayBuffers to concatenate, or a single
6801 * ArrayBuffer.
6802 * @returns Result of concatenating `buffers` in order.
6803 *
6804 * @deprecated Use tf.io.CompositeArrayBuffer.join() instead.
6805 */
    function concatenateArrayBuffers(buffers) {
        // Thin wrapper kept for backward compatibility; see the @deprecated
        // note above pointing at CompositeArrayBuffer.join().
        return CompositeArrayBuffer.join(buffers);
    }
6809 /**
6810 * Get the basename of a path.
6811 *
6812 * Behaves in a way analogous to Linux's basename command.
6813 *
6814 * @param path
6815 */
6816 function basename(path) {
6817 const SEPARATOR = '/';
6818 path = path.trim();
6819 while (path.endsWith(SEPARATOR)) {
6820 path = path.slice(0, path.length - 1);
6821 }
6822 const items = path.split(SEPARATOR);
6823 return items[items.length - 1];
6824 }
6825 /**
6826 * Create `ModelJSON` from `ModelArtifacts`.
6827 *
6828 * @param artifacts Model artifacts, describing the model and its weights.
6829 * @param manifest Weight manifest, describing where the weights of the
6830 * `ModelArtifacts` are stored, and some metadata about them.
6831 * @returns Object representing the `model.json` file describing the model
6832 * artifacts and weights
6833 */
6834 function getModelJSONForModelArtifacts(artifacts, manifest) {
6835 const result = {
6836 modelTopology: artifacts.modelTopology,
6837 format: artifacts.format,
6838 generatedBy: artifacts.generatedBy,
6839 convertedBy: artifacts.convertedBy,
6840 weightsManifest: manifest
6841 };
6842 if (artifacts.signature != null) {
6843 result.signature = artifacts.signature;
6844 }
6845 if (artifacts.userDefinedMetadata != null) {
6846 result.userDefinedMetadata = artifacts.userDefinedMetadata;
6847 }
6848 if (artifacts.modelInitializer != null) {
6849 result.modelInitializer = artifacts.modelInitializer;
6850 }
6851 if (artifacts.initializerSignature != null) {
6852 result.initializerSignature = artifacts.initializerSignature;
6853 }
6854 if (artifacts.trainingConfig != null) {
6855 result.trainingConfig = artifacts.trainingConfig;
6856 }
6857 return result;
6858 }
6859 /**
6860 * Create `ModelArtifacts` from a JSON file and weights.
6861 *
6862 * @param modelJSON Object containing the parsed JSON of `model.json`
6863 * @param weightSpecs The list of WeightsManifestEntry for the model. Must be
6864 * passed if the modelJSON has a weightsManifest.
6865 * @param weightData An ArrayBuffer or array of ArrayBuffers of weight data for
6866 * the model corresponding to the weights in weightSpecs. Must be passed if
6867 * the modelJSON has a weightsManifest.
6868 * @returns A Promise of the `ModelArtifacts`, as described by the JSON file.
6869 */
6870 function getModelArtifactsForJSONSync(modelJSON, weightSpecs, weightData) {
6871 const modelArtifacts = {
6872 modelTopology: modelJSON.modelTopology,
6873 format: modelJSON.format,
6874 generatedBy: modelJSON.generatedBy,
6875 convertedBy: modelJSON.convertedBy
6876 };
6877 if (modelJSON.trainingConfig != null) {
6878 modelArtifacts.trainingConfig = modelJSON.trainingConfig;
6879 }
6880 if (modelJSON.weightsManifest != null) {
6881 if (!weightSpecs) {
6882 throw new Error('modelJSON has weightsManifest but weightSpecs is null');
6883 }
6884 if (!weightData) {
6885 throw new Error('modelJSON has weightsManifest but weightData is null');
6886 }
6887 modelArtifacts.weightSpecs = weightSpecs;
6888 modelArtifacts.weightData = weightData;
6889 }
6890 if (modelJSON.signature != null) {
6891 modelArtifacts.signature = modelJSON.signature;
6892 }
6893 if (modelJSON.userDefinedMetadata != null) {
6894 modelArtifacts.userDefinedMetadata = modelJSON.userDefinedMetadata;
6895 }
6896 if (modelJSON.modelInitializer != null) {
6897 modelArtifacts.modelInitializer = modelJSON.modelInitializer;
6898 }
6899 if (modelJSON.initializerSignature != null) {
6900 modelArtifacts.initializerSignature = modelJSON.initializerSignature;
6901 }
6902 return modelArtifacts;
6903 }
6904 /**
6905 * Create `ModelArtifacts` from a JSON file.
6906 *
6907 * @param modelJSON Object containing the parsed JSON of `model.json`
6908 * @param loadWeights Function that takes the JSON file's weights manifest,
6909 * reads weights from the listed path(s), and returns a Promise of the
6910 * weight manifest entries along with the weights data.
6911 * @returns A Promise of the `ModelArtifacts`, as described by the JSON file.
6912 */
6913 async function getModelArtifactsForJSON(modelJSON, loadWeights) {
6914 let weightSpecs;
6915 let weightData;
6916 if (modelJSON.weightsManifest != null) {
6917 [weightSpecs, weightData] = await loadWeights(modelJSON.weightsManifest);
6918 }
6919 return getModelArtifactsForJSONSync(modelJSON, weightSpecs, weightData);
6920 }
6921 /**
6922 * Populate ModelArtifactsInfo fields for a model with JSON topology.
6923 * @param modelArtifacts
6924 * @returns A ModelArtifactsInfo object.
6925 */
6926 function getModelArtifactsInfoForJSON(modelArtifacts) {
6927 if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
6928 throw new Error('Expected JSON model topology, received ArrayBuffer.');
6929 }
6930 return {
6931 dateSaved: new Date(),
6932 modelTopologyType: 'JSON',
6933 modelTopologyBytes: modelArtifacts.modelTopology == null ?
6934 0 :
6935 stringByteLength(JSON.stringify(modelArtifacts.modelTopology)),
6936 weightSpecsBytes: modelArtifacts.weightSpecs == null ?
6937 0 :
6938 stringByteLength(JSON.stringify(modelArtifacts.weightSpecs)),
6939 weightDataBytes: modelArtifacts.weightData == null ?
6940 0 :
6941 new CompositeArrayBuffer(modelArtifacts.weightData).byteLength,
6942 };
6943 }
6944 /**
6945 * Concatenate the weights stored in a WeightsManifestConfig into a list of
6946 * WeightsManifestEntry
6947 *
6948 * @param weightsManifest The WeightsManifestConfig to extract weights from.
6949 * @returns A list of WeightsManifestEntry of the weights in the weightsManifest
6950 */
6951 function getWeightSpecs(weightsManifest) {
6952 const weightSpecs = [];
6953 for (const entry of weightsManifest) {
6954 weightSpecs.push(...entry.weights);
6955 }
6956 return weightSpecs;
6957 }
6958 /**
6959 * Computes mantisa table for casting Float16 to Float32
6960 * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
6961 *
6962 * @returns Uint32Array, 2048 mantissa lookup values.
6963 */
6964 function computeFloat16MantisaTable() {
6965 const convertMantissa = (i) => {
6966 let m = i << 13;
6967 let e = 0;
6968 while ((m & 0x00800000) === 0) {
6969 e -= 0x00800000;
6970 m <<= 1;
6971 }
6972 m &= ~0x00800000;
6973 e += 0x38800000;
6974 return m | e;
6975 };
6976 const mantisaTable = new Uint32Array(2048);
6977 mantisaTable[0] = 0;
6978 for (let i = 1; i < 1024; i++) {
6979 mantisaTable[i] = convertMantissa(i);
6980 }
6981 for (let i = 1024; i < 2048; i++) {
6982 mantisaTable[i] = 0x38000000 + ((i - 1024) << 13);
6983 }
6984 return mantisaTable;
6985 }
6986 /**
6987 * Computes exponent table for casting Float16 to Float32
6988 * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
6989 *
6990 * @returns Uint32Array, 64 exponent lookup values.
6991 */
6992 function computeFloat16ExponentTable() {
6993 const exponentTable = new Uint32Array(64);
6994 exponentTable[0] = 0;
6995 exponentTable[31] = 0x47800000;
6996 exponentTable[32] = 0x80000000;
6997 exponentTable[63] = 0xc7800000;
6998 for (let i = 1; i < 31; i++) {
6999 exponentTable[i] = i << 23;
7000 }
7001 for (let i = 33; i < 63; i++) {
7002 exponentTable[i] = 0x80000000 + ((i - 32) << 23);
7003 }
7004 return exponentTable;
7005 }
7006 /**
7007 * Computes offset table for casting Float16 to Float32
7008 * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
7009 *
7010 * @returns Uint32Array, 6d offset values.
7011 */
7012 function computeFloat16OffsetTable() {
7013 const offsetTable = new Uint32Array(64);
7014 for (let i = 0; i < 64; i++) {
7015 offsetTable[i] = 1024;
7016 }
7017 offsetTable[0] = offsetTable[32] = 0;
7018 return offsetTable;
7019 }
7020 /**
7021 * Retrieve a Float16 decoder which will decode a ByteArray of Float16 values
7022 * to a Float32Array.
7023 *
7024 * @returns Function (buffer: Uint16Array) => Float32Array which decodes
7025 * the Uint16Array of Float16 bytes to a Float32Array.
7026 */
7027 function getFloat16Decoder() {
7028 // Algorithm is based off of
7029 // http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
7030 // Cache lookup tables
7031 const mantisaTable = computeFloat16MantisaTable();
7032 const exponentTable = computeFloat16ExponentTable();
7033 const offsetTable = computeFloat16OffsetTable();
7034 return (quantizedArray) => {
7035 const buffer = new ArrayBuffer(4 * quantizedArray.length);
7036 const bufferUint32View = new Uint32Array(buffer);
7037 for (let index = 0; index < quantizedArray.length; index++) {
7038 const float16Bits = quantizedArray[index];
7039 const float32Bits = mantisaTable[offsetTable[float16Bits >> 10] + (float16Bits & 0x3ff)] +
7040 exponentTable[float16Bits >> 10];
7041 bufferUint32View[index] = float32Bits;
7042 }
7043 return new Float32Array(buffer);
7044 };
7045 }
7046
7047 /**
7048 * @license
7049 * Copyright 2018 Google LLC. All Rights Reserved.
7050 * Licensed under the Apache License, Version 2.0 (the "License");
7051 * you may not use this file except in compliance with the License.
7052 * You may obtain a copy of the License at
7053 *
7054 * http://www.apache.org/licenses/LICENSE-2.0
7055 *
7056 * Unless required by applicable law or agreed to in writing, software
7057 * distributed under the License is distributed on an "AS IS" BASIS,
7058 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
7059 * See the License for the specific language governing permissions and
7060 * limitations under the License.
7061 * =============================================================================
7062 */
7063 class IORouterRegistry {
7064 constructor() {
7065 this.saveRouters = [];
7066 this.loadRouters = [];
7067 }
7068 static getInstance() {
7069 if (IORouterRegistry.instance == null) {
7070 IORouterRegistry.instance = new IORouterRegistry();
7071 }
7072 return IORouterRegistry.instance;
7073 }
7074 /**
7075 * Register a save-handler router.
7076 *
7077 * @param saveRouter A function that maps a URL-like string onto an instance
7078 * of `IOHandler` with the `save` method defined or `null`.
7079 */
7080 static registerSaveRouter(saveRouter) {
7081 IORouterRegistry.getInstance().saveRouters.push(saveRouter);
7082 }
7083 /**
7084 * Register a load-handler router.
7085 *
7086 * @param loadRouter A function that maps a URL-like string onto an instance
7087 * of `IOHandler` with the `load` method defined or `null`.
7088 */
7089 static registerLoadRouter(loadRouter) {
7090 IORouterRegistry.getInstance().loadRouters.push(loadRouter);
7091 }
7092 /**
7093 * Look up IOHandler for saving, given a URL-like string.
7094 *
7095 * @param url
7096 * @returns If only one match is found, an instance of IOHandler with the
7097 * `save` method defined. If no match is found, `null`.
7098 * @throws Error, if more than one match is found.
7099 */
7100 static getSaveHandlers(url) {
7101 return IORouterRegistry.getHandlers(url, 'save');
7102 }
7103 /**
7104 * Look up IOHandler for loading, given a URL-like string.
7105 *
7106 * @param url
7107 * @param loadOptions Optional, custom load options.
7108 * @returns All valid handlers for `url`, given the currently registered
7109 * handler routers.
7110 */
7111 static getLoadHandlers(url, loadOptions) {
7112 return IORouterRegistry.getHandlers(url, 'load', loadOptions);
7113 }
7114 static getHandlers(url, handlerType, loadOptions) {
7115 const validHandlers = [];
7116 const routers = handlerType === 'load' ?
7117 IORouterRegistry.getInstance().loadRouters :
7118 IORouterRegistry.getInstance().saveRouters;
7119 routers.forEach(router => {
7120 const handler = router(url, loadOptions);
7121 if (handler !== null) {
7122 validHandlers.push(handler);
7123 }
7124 });
7125 return validHandlers;
7126 }
7127 }
7128 const registerSaveRouter = (loudRouter) => IORouterRegistry.registerSaveRouter(loudRouter);
7129 const registerLoadRouter = (loudRouter) => IORouterRegistry.registerLoadRouter(loudRouter);
7130 const getSaveHandlers = (url) => IORouterRegistry.getSaveHandlers(url);
7131 const getLoadHandlers = (url, loadOptions) => IORouterRegistry.getLoadHandlers(url, loadOptions);
7132
7133 /**
7134 * @license
7135 * Copyright 2018 Google LLC. All Rights Reserved.
7136 * Licensed under the Apache License, Version 2.0 (the "License");
7137 * you may not use this file except in compliance with the License.
7138 * You may obtain a copy of the License at
7139 *
7140 * http://www.apache.org/licenses/LICENSE-2.0
7141 *
7142 * Unless required by applicable law or agreed to in writing, software
7143 * distributed under the License is distributed on an "AS IS" BASIS,
7144 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
7145 * See the License for the specific language governing permissions and
7146 * limitations under the License.
7147 * =============================================================================
7148 */
    // IndexedDB database name and schema version used by the browser IO handler.
    const DATABASE_NAME = 'tensorflowjs';
    const DATABASE_VERSION = 1;
    // Model data and ModelArtifactsInfo (metadata) are stored in two separate
    // stores for efficient access of the list of stored models and their metadata.
    // 1. The object store for model data: topology, weights and weight manifests.
    const MODEL_STORE_NAME = 'models_store';
    // 2. The object store for ModelArtifactsInfo, including meta-information such
    // as the type of topology (JSON vs binary), byte size of the topology, byte
    // size of the weights, etc.
    const INFO_STORE_NAME = 'model_info_store';
7159 /**
7160 * Delete the entire database for tensorflow.js, including the models store.
7161 */
7162 async function deleteDatabase() {
7163 const idbFactory = getIndexedDBFactory();
7164 return new Promise((resolve, reject) => {
7165 const deleteRequest = idbFactory.deleteDatabase(DATABASE_NAME);
7166 deleteRequest.onsuccess = () => resolve();
7167 deleteRequest.onerror = error => reject(error);
7168 });
7169 }
7170 function getIndexedDBFactory() {
7171 if (!env().getBool('IS_BROWSER')) {
7172 // TODO(cais): Add more info about what IOHandler subtypes are available.
7173 // Maybe point to a doc page on the web and/or automatically determine
7174 // the available IOHandlers and print them in the error message.
7175 throw new Error('Failed to obtain IndexedDB factory because the current environment' +
7176 'is not a web browser.');
7177 }
7178 // tslint:disable-next-line:no-any
7179 const theWindow = typeof window === 'undefined' ? self : window;
7180 const factory = theWindow.indexedDB || theWindow.mozIndexedDB ||
7181 theWindow.webkitIndexedDB || theWindow.msIndexedDB ||
7182 theWindow.shimIndexedDB;
7183 if (factory == null) {
7184 throw new Error('The current browser does not appear to support IndexedDB.');
7185 }
7186 return factory;
7187 }
7188 function setUpDatabase(openRequest) {
7189 const db = openRequest.result;
7190 db.createObjectStore(MODEL_STORE_NAME, { keyPath: 'modelPath' });
7191 db.createObjectStore(INFO_STORE_NAME, { keyPath: 'modelPath' });
7192 }
7193 /**
7194 * IOHandler subclass: Browser IndexedDB.
7195 *
7196 * See the doc string of `browserIndexedDB` for more details.
7197 */
7198 class BrowserIndexedDB {
7199 constructor(modelPath) {
7200 this.indexedDB = getIndexedDBFactory();
7201 if (modelPath == null || !modelPath) {
7202 throw new Error('For IndexedDB, modelPath must not be null, undefined or empty.');
7203 }
7204 this.modelPath = modelPath;
7205 }
7206 async save(modelArtifacts) {
7207 // TODO(cais): Support saving GraphDef models.
7208 if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
7209 throw new Error('BrowserLocalStorage.save() does not support saving model topology ' +
7210 'in binary formats yet.');
7211 }
7212 return this.databaseAction(this.modelPath, modelArtifacts);
7213 }
7214 async load() {
7215 return this.databaseAction(this.modelPath);
7216 }
7217 /**
7218 * Perform database action to put model artifacts into or read model artifacts
7219 * from IndexedDB object store.
7220 *
7221 * Whether the action is put or get depends on whether `modelArtifacts` is
7222 * specified. If it is specified, the action will be put; otherwise the action
7223 * will be get.
7224 *
7225 * @param modelPath A unique string path for the model.
7226 * @param modelArtifacts If specified, it will be the model artifacts to be
7227 * stored in IndexedDB.
7228 * @returns A `Promise` of `SaveResult`, if the action is put, or a `Promise`
7229 * of `ModelArtifacts`, if the action is get.
7230 */
7231 databaseAction(modelPath, modelArtifacts) {
7232 return new Promise((resolve, reject) => {
7233 const openRequest = this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
7234 openRequest.onupgradeneeded = () => setUpDatabase(openRequest);
7235 openRequest.onsuccess = () => {
7236 const db = openRequest.result;
7237 if (modelArtifacts == null) {
7238 // Read model out from object store.
7239 const modelTx = db.transaction(MODEL_STORE_NAME, 'readonly');
7240 const modelStore = modelTx.objectStore(MODEL_STORE_NAME);
7241 const getRequest = modelStore.get(this.modelPath);
7242 getRequest.onsuccess = () => {
7243 if (getRequest.result == null) {
7244 db.close();
7245 return reject(new Error(`Cannot find model with path '${this.modelPath}' ` +
7246 `in IndexedDB.`));
7247 }
7248 else {
7249 resolve(getRequest.result.modelArtifacts);
7250 }
7251 };
7252 getRequest.onerror = error => {
7253 db.close();
7254 return reject(getRequest.error);
7255 };
7256 modelTx.oncomplete = () => db.close();
7257 }
7258 else {
7259 // Put model into object store.
7260 // Concatenate all the model weights into a single ArrayBuffer. Large
7261 // models (~1GB) have problems saving if they are not concatenated.
7262 // TODO(mattSoulanille): Save large models to multiple indexeddb
7263 // records.
7264 modelArtifacts.weightData = CompositeArrayBuffer.join(modelArtifacts.weightData);
7265 const modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts);
7266 // First, put ModelArtifactsInfo into info store.
7267 const infoTx = db.transaction(INFO_STORE_NAME, 'readwrite');
7268 let infoStore = infoTx.objectStore(INFO_STORE_NAME);
7269 let putInfoRequest;
7270 try {
7271 putInfoRequest =
7272 infoStore.put({ modelPath: this.modelPath, modelArtifactsInfo });
7273 }
7274 catch (error) {
7275 return reject(error);
7276 }
7277 let modelTx;
7278 putInfoRequest.onsuccess = () => {
7279 // Second, put model data into model store.
7280 modelTx = db.transaction(MODEL_STORE_NAME, 'readwrite');
7281 const modelStore = modelTx.objectStore(MODEL_STORE_NAME);
7282 let putModelRequest;
7283 try {
7284 putModelRequest = modelStore.put({
7285 modelPath: this.modelPath,
7286 modelArtifacts,
7287 modelArtifactsInfo
7288 });
7289 }
7290 catch (error) {
7291 // Sometimes, the serialized value is too large to store.
7292 return reject(error);
7293 }
7294 putModelRequest.onsuccess = () => resolve({ modelArtifactsInfo });
7295 putModelRequest.onerror = error => {
7296 // If the put-model request fails, roll back the info entry as
7297 // well.
7298 infoStore = infoTx.objectStore(INFO_STORE_NAME);
7299 const deleteInfoRequest = infoStore.delete(this.modelPath);
7300 deleteInfoRequest.onsuccess = () => {
7301 db.close();
7302 return reject(putModelRequest.error);
7303 };
7304 deleteInfoRequest.onerror = error => {
7305 db.close();
7306 return reject(putModelRequest.error);
7307 };
7308 };
7309 };
7310 putInfoRequest.onerror = error => {
7311 db.close();
7312 return reject(putInfoRequest.error);
7313 };
7314 infoTx.oncomplete = () => {
7315 if (modelTx == null) {
7316 db.close();
7317 }
7318 else {
7319 modelTx.oncomplete = () => db.close();
7320 }
7321 };
7322 }
7323 };
7324 openRequest.onerror = error => reject(openRequest.error);
7325 });
7326 }
7327 }
7328 BrowserIndexedDB.URL_SCHEME = 'indexeddb://';
7329 const indexedDBRouter = (url) => {
7330 if (!env().getBool('IS_BROWSER')) {
7331 return null;
7332 }
7333 else {
7334 if (!Array.isArray(url) && url.startsWith(BrowserIndexedDB.URL_SCHEME)) {
7335 return browserIndexedDB(url.slice(BrowserIndexedDB.URL_SCHEME.length));
7336 }
7337 else {
7338 return null;
7339 }
7340 }
7341 };
7342 IORouterRegistry.registerSaveRouter(indexedDBRouter);
7343 IORouterRegistry.registerLoadRouter(indexedDBRouter);
7344 /**
7345 * Creates a browser IndexedDB IOHandler for saving and loading models.
7346 *
7347 * ```js
7348 * const model = tf.sequential();
7349 * model.add(
7350 * tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'}));
7351 *
 * const saveResult = await model.save('indexeddb://MyModel');
7353 * console.log(saveResult);
7354 * ```
7355 *
7356 * @param modelPath A unique identifier for the model to be saved. Must be a
7357 * non-empty string.
7358 * @returns An instance of `BrowserIndexedDB` (subclass of `IOHandler`),
7359 * which can be used with, e.g., `tf.Model.save`.
7360 */
7361 function browserIndexedDB(modelPath) {
7362 return new BrowserIndexedDB(modelPath);
7363 }
7364 function maybeStripScheme$1(key) {
7365 return key.startsWith(BrowserIndexedDB.URL_SCHEME) ?
7366 key.slice(BrowserIndexedDB.URL_SCHEME.length) :
7367 key;
7368 }
7369 class BrowserIndexedDBManager {
7370 constructor() {
7371 this.indexedDB = getIndexedDBFactory();
7372 }
7373 async listModels() {
7374 return new Promise((resolve, reject) => {
7375 const openRequest = this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
7376 openRequest.onupgradeneeded = () => setUpDatabase(openRequest);
7377 openRequest.onsuccess = () => {
7378 const db = openRequest.result;
7379 const tx = db.transaction(INFO_STORE_NAME, 'readonly');
7380 const store = tx.objectStore(INFO_STORE_NAME);
7381 // tslint:disable:max-line-length
7382 // Need to cast `store` as `any` here because TypeScript's DOM
7383 // library does not have the `getAll()` method even though the
7384 // method is supported in the latest version of most mainstream
7385 // browsers:
7386 // https://developer.mozilla.org/en-US/docs/Web/API/IDBObjectStore/getAll
7387 // tslint:enable:max-line-length
7388 // tslint:disable-next-line:no-any
7389 const getAllInfoRequest = store.getAll();
7390 getAllInfoRequest.onsuccess = () => {
7391 const out = {};
7392 for (const item of getAllInfoRequest.result) {
7393 out[item.modelPath] = item.modelArtifactsInfo;
7394 }
7395 resolve(out);
7396 };
7397 getAllInfoRequest.onerror = error => {
7398 db.close();
7399 return reject(getAllInfoRequest.error);
7400 };
7401 tx.oncomplete = () => db.close();
7402 };
7403 openRequest.onerror = error => reject(openRequest.error);
7404 });
7405 }
7406 async removeModel(path) {
7407 path = maybeStripScheme$1(path);
7408 return new Promise((resolve, reject) => {
7409 const openRequest = this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
7410 openRequest.onupgradeneeded = () => setUpDatabase(openRequest);
7411 openRequest.onsuccess = () => {
7412 const db = openRequest.result;
7413 const infoTx = db.transaction(INFO_STORE_NAME, 'readwrite');
7414 const infoStore = infoTx.objectStore(INFO_STORE_NAME);
7415 const getInfoRequest = infoStore.get(path);
7416 let modelTx;
7417 getInfoRequest.onsuccess = () => {
7418 if (getInfoRequest.result == null) {
7419 db.close();
7420 return reject(new Error(`Cannot find model with path '${path}' ` +
7421 `in IndexedDB.`));
7422 }
7423 else {
7424 // First, delete the entry in the info store.
7425 const deleteInfoRequest = infoStore.delete(path);
7426 const deleteModelData = () => {
7427 // Second, delete the entry in the model store.
7428 modelTx = db.transaction(MODEL_STORE_NAME, 'readwrite');
7429 const modelStore = modelTx.objectStore(MODEL_STORE_NAME);
7430 const deleteModelRequest = modelStore.delete(path);
7431 deleteModelRequest.onsuccess = () => resolve(getInfoRequest.result.modelArtifactsInfo);
7432 deleteModelRequest.onerror = error => reject(getInfoRequest.error);
7433 };
7434 // Proceed with deleting model data regardless of whether deletion
7435 // of info data succeeds or not.
7436 deleteInfoRequest.onsuccess = deleteModelData;
7437 deleteInfoRequest.onerror = error => {
7438 deleteModelData();
7439 db.close();
7440 return reject(getInfoRequest.error);
7441 };
7442 }
7443 };
7444 getInfoRequest.onerror = error => {
7445 db.close();
7446 return reject(getInfoRequest.error);
7447 };
7448 infoTx.oncomplete = () => {
7449 if (modelTx == null) {
7450 db.close();
7451 }
7452 else {
7453 modelTx.oncomplete = () => db.close();
7454 }
7455 };
7456 };
7457 openRequest.onerror = error => reject(openRequest.error);
7458 });
7459 }
7460 }
7461
7462 /**
7463 * @license
7464 * Copyright 2018 Google LLC. All Rights Reserved.
7465 * Licensed under the Apache License, Version 2.0 (the "License");
7466 * you may not use this file except in compliance with the License.
7467 * You may obtain a copy of the License at
7468 *
7469 * http://www.apache.org/licenses/LICENSE-2.0
7470 *
7471 * Unless required by applicable law or agreed to in writing, software
7472 * distributed under the License is distributed on an "AS IS" BASIS,
7473 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
7474 * See the License for the specific language governing permissions and
7475 * limitations under the License.
7476 * =============================================================================
7477 */
    // Local-storage keys are built as `${PATH_PREFIX}/${modelPath}/${suffix}`,
    // with one suffix per stored artifact kind (see `getModelKeys`).
    const PATH_SEPARATOR = '/';
    const PATH_PREFIX = 'tensorflowjs_models';
    const INFO_SUFFIX = 'info';
    const MODEL_TOPOLOGY_SUFFIX = 'model_topology';
    const WEIGHT_SPECS_SUFFIX = 'weight_specs';
    const WEIGHT_DATA_SUFFIX = 'weight_data';
    const MODEL_METADATA_SUFFIX = 'model_metadata';
7485 /**
7486 * Purge all tensorflow.js-saved model artifacts from local storage.
7487 *
7488 * @returns Paths of the models purged.
7489 */
7490 function purgeLocalStorageArtifacts() {
7491 if (!env().getBool('IS_BROWSER') || typeof window === 'undefined' ||
7492 typeof window.localStorage === 'undefined') {
7493 throw new Error('purgeLocalStorageModels() cannot proceed because local storage is ' +
7494 'unavailable in the current environment.');
7495 }
7496 const LS = window.localStorage;
7497 const purgedModelPaths = [];
7498 for (let i = 0; i < LS.length; ++i) {
7499 const key = LS.key(i);
7500 const prefix = PATH_PREFIX + PATH_SEPARATOR;
7501 if (key.startsWith(prefix) && key.length > prefix.length) {
7502 LS.removeItem(key);
7503 const modelName = getModelPathFromKey(key);
7504 if (purgedModelPaths.indexOf(modelName) === -1) {
7505 purgedModelPaths.push(modelName);
7506 }
7507 }
7508 }
7509 return purgedModelPaths;
7510 }
7511 function getModelKeys(path) {
7512 return {
7513 info: [PATH_PREFIX, path, INFO_SUFFIX].join(PATH_SEPARATOR),
7514 topology: [PATH_PREFIX, path, MODEL_TOPOLOGY_SUFFIX].join(PATH_SEPARATOR),
7515 weightSpecs: [PATH_PREFIX, path, WEIGHT_SPECS_SUFFIX].join(PATH_SEPARATOR),
7516 weightData: [PATH_PREFIX, path, WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR),
7517 modelMetadata: [PATH_PREFIX, path, MODEL_METADATA_SUFFIX].join(PATH_SEPARATOR)
7518 };
7519 }
7520 function removeItems(keys) {
7521 for (const key of Object.values(keys)) {
7522 window.localStorage.removeItem(key);
7523 }
7524 }
7525 /**
7526 * Get model path from a local-storage key.
7527 *
7528 * E.g., 'tensorflowjs_models/my/model/1/info' --> 'my/model/1'
7529 *
7530 * @param key
7531 */
7532 function getModelPathFromKey(key) {
7533 const items = key.split(PATH_SEPARATOR);
7534 if (items.length < 3) {
7535 throw new Error(`Invalid key format: ${key}`);
7536 }
7537 return items.slice(1, items.length - 1).join(PATH_SEPARATOR);
7538 }
7539 function maybeStripScheme(key) {
7540 return key.startsWith(BrowserLocalStorage.URL_SCHEME) ?
7541 key.slice(BrowserLocalStorage.URL_SCHEME.length) :
7542 key;
7543 }
    /**
     * IOHandler subclass: Browser Local Storage.
     *
     * Persists one model (identified by `modelPath`) across the five
     * local-storage keys produced by `getModelKeys`. See the doc string to
     * `browserLocalStorage` for more details.
     */
    class BrowserLocalStorage {
        constructor(modelPath) {
            if (!env().getBool('IS_BROWSER') || typeof window === 'undefined' ||
                typeof window.localStorage === 'undefined') {
                // TODO(cais): Add more info about what IOHandler subtypes are
                // available.
                // Maybe point to a doc page on the web and/or automatically determine
                // the available IOHandlers and print them in the error message.
                throw new Error('The current environment does not support local storage.');
            }
            this.LS = window.localStorage;
            if (modelPath == null || !modelPath) {
                throw new Error('For local storage, modelPath must not be null, undefined or empty.');
            }
            this.modelPath = modelPath;
            // Precompute the storage keys for every artifact of this model.
            this.keys = getModelKeys(this.modelPath);
        }
        /**
         * Save model artifacts to browser local storage.
         *
         * See the documentation to `browserLocalStorage` for details on the saved
         * artifacts.
         *
         * @param modelArtifacts The model artifacts to be stored.
         * @returns An instance of SaveResult.
         */
        async save(modelArtifacts) {
            if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
                throw new Error('BrowserLocalStorage.save() does not support saving model topology ' +
                    'in binary formats yet.');
            }
            else {
                const topology = JSON.stringify(modelArtifacts.modelTopology);
                const weightSpecs = JSON.stringify(modelArtifacts.weightSpecs);
                const modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts);
                // TODO(mattsoulanille): Support saving models over 2GB that exceed
                // Chrome's ArrayBuffer size limit.
                const weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData);
                try {
                    // Local storage only holds strings: topology/specs/metadata are
                    // stored as JSON, binary weights as base64.
                    this.LS.setItem(this.keys.info, JSON.stringify(modelArtifactsInfo));
                    this.LS.setItem(this.keys.topology, topology);
                    this.LS.setItem(this.keys.weightSpecs, weightSpecs);
                    this.LS.setItem(this.keys.weightData, arrayBufferToBase64String(weightBuffer));
                    // Note that JSON.stringify doesn't write out keys that have undefined
                    // values, so for some keys, we set undefined instead of a null-ish
                    // value.
                    const metadata = {
                        format: modelArtifacts.format,
                        generatedBy: modelArtifacts.generatedBy,
                        convertedBy: modelArtifacts.convertedBy,
                        signature: modelArtifacts.signature != null ?
                            modelArtifacts.signature :
                            undefined,
                        userDefinedMetadata: modelArtifacts.userDefinedMetadata != null ?
                            modelArtifacts.userDefinedMetadata :
                            undefined,
                        modelInitializer: modelArtifacts.modelInitializer != null ?
                            modelArtifacts.modelInitializer :
                            undefined,
                        initializerSignature: modelArtifacts.initializerSignature != null ?
                            modelArtifacts.initializerSignature :
                            undefined,
                        trainingConfig: modelArtifacts.trainingConfig != null ?
                            modelArtifacts.trainingConfig :
                            undefined
                    };
                    this.LS.setItem(this.keys.modelMetadata, JSON.stringify(metadata));
                    return { modelArtifactsInfo };
                }
                catch (err) {
                    // If saving failed, clean up all items saved so far.
                    removeItems(this.keys);
                    throw new Error(`Failed to save model '${this.modelPath}' to local storage: ` +
                        `size quota being exceeded is a possible cause of this failure: ` +
                        `modelTopologyBytes=${modelArtifactsInfo.modelTopologyBytes}, ` +
                        `weightSpecsBytes=${modelArtifactsInfo.weightSpecsBytes}, ` +
                        `weightDataBytes=${modelArtifactsInfo.weightDataBytes}.`);
                }
            }
        }
        /**
         * Load a model from local storage.
         *
         * See the documentation to `browserLocalStorage` for details on the saved
         * artifacts.
         *
         * @returns The loaded model (if loading succeeds).
         */
        async load() {
            const info = JSON.parse(this.LS.getItem(this.keys.info));
            if (info == null) {
                throw new Error(`In local storage, there is no model with name '${this.modelPath}'`);
            }
            if (info.modelTopologyType !== 'JSON') {
                throw new Error('BrowserLocalStorage does not support loading non-JSON model ' +
                    'topology yet.');
            }
            const out = {};
            // Load topology.
            const topology = JSON.parse(this.LS.getItem(this.keys.topology));
            if (topology == null) {
                throw new Error(`In local storage, the topology of model '${this.modelPath}' ` +
                    `is missing.`);
            }
            out.modelTopology = topology;
            // Load weight specs.
            const weightSpecs = JSON.parse(this.LS.getItem(this.keys.weightSpecs));
            if (weightSpecs == null) {
                throw new Error(`In local storage, the weight specs of model '${this.modelPath}' ` +
                    `are missing.`);
            }
            out.weightSpecs = weightSpecs;
            // Load meta-data fields. The metadata item may be absent for models
            // saved by older versions; each field is copied only when present.
            const metadataString = this.LS.getItem(this.keys.modelMetadata);
            if (metadataString != null) {
                const metadata = JSON.parse(metadataString);
                out.format = metadata.format;
                out.generatedBy = metadata.generatedBy;
                out.convertedBy = metadata.convertedBy;
                if (metadata.signature != null) {
                    out.signature = metadata.signature;
                }
                if (metadata.userDefinedMetadata != null) {
                    out.userDefinedMetadata = metadata.userDefinedMetadata;
                }
                if (metadata.modelInitializer != null) {
                    out.modelInitializer = metadata.modelInitializer;
                }
                if (metadata.initializerSignature != null) {
                    out.initializerSignature = metadata.initializerSignature;
                }
                if (metadata.trainingConfig != null) {
                    out.trainingConfig = metadata.trainingConfig;
                }
            }
            // Load weight data (stored as base64; decoded back to an ArrayBuffer).
            const weightDataBase64 = this.LS.getItem(this.keys.weightData);
            if (weightDataBase64 == null) {
                throw new Error(`In local storage, the binary weight values of model ` +
                    `'${this.modelPath}' are missing.`);
            }
            out.weightData = base64StringToArrayBuffer(weightDataBase64);
            return out;
        }
    }
    BrowserLocalStorage.URL_SCHEME = 'localstorage://';
7695 const localStorageRouter = (url) => {
7696 if (!env().getBool('IS_BROWSER')) {
7697 return null;
7698 }
7699 else {
7700 if (!Array.isArray(url) && url.startsWith(BrowserLocalStorage.URL_SCHEME)) {
7701 return browserLocalStorage(url.slice(BrowserLocalStorage.URL_SCHEME.length));
7702 }
7703 else {
7704 return null;
7705 }
7706 }
7707 };
7708 IORouterRegistry.registerSaveRouter(localStorageRouter);
7709 IORouterRegistry.registerLoadRouter(localStorageRouter);
7710 /**
7711 * Factory function for local storage IOHandler.
7712 *
7713 * This `IOHandler` supports both `save` and `load`.
7714 *
 * For each model's saved artifacts, five items are saved to local storage.
 * - `${PATH_PREFIX}/${modelPath}/info`: Contains meta-info about the
 *   model, such as date saved, type of the topology, size in bytes, etc.
 * - `${PATH_PREFIX}/${modelPath}/model_topology`: Model topology. For Keras-
 *   style models, this is a stringized JSON.
 * - `${PATH_PREFIX}/${modelPath}/weight_specs`: Weight specs of the
 *   model, can be used to decode the saved binary weight values (see
 *   item below).
 * - `${PATH_PREFIX}/${modelPath}/weight_data`: Concatenated binary
 *   weight values, stored as a base64-encoded string.
 * - `${PATH_PREFIX}/${modelPath}/model_metadata`: Model format, training
 *   config and other metadata, stored as stringized JSON.
7725 *
7726 * Saving may throw an `Error` if the total size of the artifacts exceed the
7727 * browser-specific quota.
7728 *
7729 * @param modelPath A unique identifier for the model to be saved. Must be a
7730 * non-empty string.
7731 * @returns An instance of `IOHandler`, which can be used with, e.g.,
7732 * `tf.Model.save`.
7733 */
7734 function browserLocalStorage(modelPath) {
7735 return new BrowserLocalStorage(modelPath);
7736 }
7737 class BrowserLocalStorageManager {
7738 constructor() {
7739 assert$1(env().getBool('IS_BROWSER'), () => 'Current environment is not a web browser');
7740 assert$1(typeof window === 'undefined' ||
7741 typeof window.localStorage !== 'undefined', () => 'Current browser does not appear to support localStorage');
7742 this.LS = window.localStorage;
7743 }
7744 async listModels() {
7745 const out = {};
7746 const prefix = PATH_PREFIX + PATH_SEPARATOR;
7747 const suffix = PATH_SEPARATOR + INFO_SUFFIX;
7748 for (let i = 0; i < this.LS.length; ++i) {
7749 const key = this.LS.key(i);
7750 if (key.startsWith(prefix) && key.endsWith(suffix)) {
7751 const modelPath = getModelPathFromKey(key);
7752 out[modelPath] = JSON.parse(this.LS.getItem(key));
7753 }
7754 }
7755 return out;
7756 }
7757 async removeModel(path) {
7758 path = maybeStripScheme(path);
7759 const keys = getModelKeys(path);
7760 if (this.LS.getItem(keys.info) == null) {
7761 throw new Error(`Cannot find model at path '${path}'`);
7762 }
7763 const info = JSON.parse(this.LS.getItem(keys.info));
7764 removeItems(keys);
7765 return info;
7766 }
7767 }
7768
7769 /**
7770 * @license
7771 * Copyright 2018 Google LLC. All Rights Reserved.
7772 * Licensed under the Apache License, Version 2.0 (the "License");
7773 * you may not use this file except in compliance with the License.
7774 * You may obtain a copy of the License at
7775 *
7776 * http://www.apache.org/licenses/LICENSE-2.0
7777 *
7778 * Unless required by applicable law or agreed to in writing, software
7779 * distributed under the License is distributed on an "AS IS" BASIS,
7780 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
7781 * See the License for the specific language governing permissions and
7782 * limitations under the License.
7783 * =============================================================================
7784 */
7785 const URL_SCHEME_SUFFIX = '://';
    /**
     * Singleton registry mapping storage schemes (e.g. 'localstorage',
     * 'indexeddb') to their ModelStoreManager instances.
     */
    class ModelStoreManagerRegistry {
        constructor() {
            // Scheme string -> ModelStoreManager.
            this.managers = {};
        }
        static getInstance() {
            // Lazily create the singleton on first access.
            if (ModelStoreManagerRegistry.instance == null) {
                ModelStoreManagerRegistry.instance = new ModelStoreManagerRegistry();
            }
            return ModelStoreManagerRegistry.instance;
        }
        /**
         * Register a model-store manager for a storage scheme.
         *
         * @param scheme The scheme, with or without the trailing '://' (which is
         *   stripped before registration). Must be non-empty and not yet taken.
         * @param manager The ModelStoreManager to associate with `scheme`.
         */
        static registerManager(scheme, manager) {
            assert$1(scheme != null, () => 'scheme must not be undefined or null.');
            if (scheme.endsWith(URL_SCHEME_SUFFIX)) {
                scheme = scheme.slice(0, scheme.indexOf(URL_SCHEME_SUFFIX));
            }
            assert$1(scheme.length > 0, () => 'scheme must not be an empty string.');
            const registry = ModelStoreManagerRegistry.getInstance();
            assert$1(registry.managers[scheme] == null, () => `A model store manager is already registered for scheme '${scheme}'.`);
            registry.managers[scheme] = manager;
        }
        /**
         * Look up the manager registered for `scheme`.
         * @throws Error if no manager is registered for the scheme.
         */
        static getManager(scheme) {
            const manager = ModelStoreManagerRegistry.getInstance().managers[scheme];
            if (manager == null) {
                throw new Error(`Cannot find model manager for scheme '${scheme}'`);
            }
            return manager;
        }
        /** Return all registered scheme strings. */
        static getSchemes() {
            return Object.keys(ModelStoreManagerRegistry.getInstance().managers);
        }
    }
7823 /**
7824 * Helper method for parsing a URL string into a scheme and a path.
7825 *
7826 * @param url E.g., 'localstorage://my-model'
7827 * @returns A dictionary with two fields: scheme and path.
7828 * Scheme: e.g., 'localstorage' in the example above.
7829 * Path: e.g., 'my-model' in the example above.
7830 */
7831 function parseURL(url) {
7832 if (url.indexOf(URL_SCHEME_SUFFIX) === -1) {
7833 throw new Error(`The url string provided does not contain a scheme. ` +
7834 `Supported schemes are: ` +
7835 `${ModelStoreManagerRegistry.getSchemes().join(',')}`);
7836 }
7837 return {
7838 scheme: url.split(URL_SCHEME_SUFFIX)[0],
7839 path: url.split(URL_SCHEME_SUFFIX)[1],
7840 };
7841 }
7842 async function cloneModelInternal(sourceURL, destURL, deleteSource = false) {
7843 assert$1(sourceURL !== destURL, () => `Old path and new path are the same: '${sourceURL}'`);
7844 const loadHandlers = IORouterRegistry.getLoadHandlers(sourceURL);
7845 assert$1(loadHandlers.length > 0, () => `Copying failed because no load handler is found for source URL ${sourceURL}.`);
7846 assert$1(loadHandlers.length < 2, () => `Copying failed because more than one (${loadHandlers.length}) ` +
7847 `load handlers for source URL ${sourceURL}.`);
7848 const loadHandler = loadHandlers[0];
7849 const saveHandlers = IORouterRegistry.getSaveHandlers(destURL);
7850 assert$1(saveHandlers.length > 0, () => `Copying failed because no save handler is found for destination ` +
7851 `URL ${destURL}.`);
7852 assert$1(saveHandlers.length < 2, () => `Copying failed because more than one (${loadHandlers.length}) ` +
7853 `save handlers for destination URL ${destURL}.`);
7854 const saveHandler = saveHandlers[0];
7855 const sourceScheme = parseURL(sourceURL).scheme;
7856 const sourcePath = parseURL(sourceURL).path;
7857 const sameMedium = sourceScheme === parseURL(sourceURL).scheme;
7858 const modelArtifacts = await loadHandler.load();
7859 // If moving within the same storage medium, remove the old model as soon as
7860 // the loading is done. Without doing this, it is possible that the combined
7861 // size of the two models will cause the cloning to fail.
7862 if (deleteSource && sameMedium) {
7863 await ModelStoreManagerRegistry.getManager(sourceScheme)
7864 .removeModel(sourcePath);
7865 }
7866 const saveResult = await saveHandler.save(modelArtifacts);
7867 // If moving between mediums, the deletion is done after the save succeeds.
7868 // This guards against the case in which saving to the destination medium
7869 // fails.
7870 if (deleteSource && !sameMedium) {
7871 await ModelStoreManagerRegistry.getManager(sourceScheme)
7872 .removeModel(sourcePath);
7873 }
7874 return saveResult.modelArtifactsInfo;
7875 }
7876 /**
7877 * List all models stored in registered storage mediums.
7878 *
7879 * For a web browser environment, the registered mediums are Local Storage and
7880 * IndexedDB.
7881 *
7882 * ```js
7883 * // First create and save a model.
7884 * const model = tf.sequential();
7885 * model.add(tf.layers.dense(
7886 * {units: 1, inputShape: [10], activation: 'sigmoid'}));
7887 * await model.save('localstorage://demo/management/model1');
7888 *
7889 * // Then list existing models.
7890 * console.log(JSON.stringify(await tf.io.listModels()));
7891 *
7892 * // Delete the model.
7893 * await tf.io.removeModel('localstorage://demo/management/model1');
7894 *
7895 * // List models again.
7896 * console.log(JSON.stringify(await tf.io.listModels()));
7897 * ```
7898 *
7899 * @returns A `Promise` of a dictionary mapping URLs of existing models to
7900 * their model artifacts info. URLs include medium-specific schemes, e.g.,
7901 * 'indexeddb://my/model/1'. Model artifacts info include type of the
7902 * model's topology, byte sizes of the topology, weights, etc.
7903 *
7904 * @doc {
7905 * heading: 'Models',
7906 * subheading: 'Management',
7907 * namespace: 'io',
7908 * ignoreCI: true
7909 * }
7910 */
7911 async function listModels() {
7912 const schemes = ModelStoreManagerRegistry.getSchemes();
7913 const out = {};
7914 for (const scheme of schemes) {
7915 const schemeOut = await ModelStoreManagerRegistry.getManager(scheme).listModels();
7916 for (const path in schemeOut) {
7917 const url = scheme + URL_SCHEME_SUFFIX + path;
7918 out[url] = schemeOut[path];
7919 }
7920 }
7921 return out;
7922 }
7923 /**
7924 * Remove a model specified by URL from a registered storage medium.
7925 *
7926 * ```js
7927 * // First create and save a model.
7928 * const model = tf.sequential();
7929 * model.add(tf.layers.dense(
7930 * {units: 1, inputShape: [10], activation: 'sigmoid'}));
7931 * await model.save('localstorage://demo/management/model1');
7932 *
7933 * // Then list existing models.
7934 * console.log(JSON.stringify(await tf.io.listModels()));
7935 *
7936 * // Delete the model.
7937 * await tf.io.removeModel('localstorage://demo/management/model1');
7938 *
7939 * // List models again.
7940 * console.log(JSON.stringify(await tf.io.listModels()));
7941 * ```
7942 *
7943 * @param url A URL to a stored model, with a scheme prefix, e.g.,
7944 * 'localstorage://my-model-1', 'indexeddb://my/model/2'.
7945 * @returns ModelArtifactsInfo of the deleted model (if and only if deletion
7946 * is successful).
7947 * @throws Error if deletion fails, e.g., if no model exists at `path`.
7948 *
7949 * @doc {
7950 * heading: 'Models',
7951 * subheading: 'Management',
7952 * namespace: 'io',
7953 * ignoreCI: true
7954 * }
7955 */
7956 async function removeModel(url) {
7957 const schemeAndPath = parseURL(url);
7958 const manager = ModelStoreManagerRegistry.getManager(schemeAndPath.scheme);
7959 return manager.removeModel(schemeAndPath.path);
7960 }
7961 /**
7962 * Copy a model from one URL to another.
7963 *
7964 * This function supports:
7965 *
7966 * 1. Copying within a storage medium, e.g.,
7967 * `tf.io.copyModel('localstorage://model-1', 'localstorage://model-2')`
7968 * 2. Copying between two storage mediums, e.g.,
7969 * `tf.io.copyModel('localstorage://model-1', 'indexeddb://model-1')`
7970 *
7971 * ```js
7972 * // First create and save a model.
7973 * const model = tf.sequential();
7974 * model.add(tf.layers.dense(
7975 * {units: 1, inputShape: [10], activation: 'sigmoid'}));
7976 * await model.save('localstorage://demo/management/model1');
7977 *
7978 * // Then list existing models.
7979 * console.log(JSON.stringify(await tf.io.listModels()));
7980 *
7981 * // Copy the model, from Local Storage to IndexedDB.
7982 * await tf.io.copyModel(
7983 * 'localstorage://demo/management/model1',
7984 * 'indexeddb://demo/management/model1');
7985 *
7986 * // List models again.
7987 * console.log(JSON.stringify(await tf.io.listModels()));
7988 *
7989 * // Remove both models.
7990 * await tf.io.removeModel('localstorage://demo/management/model1');
7991 * await tf.io.removeModel('indexeddb://demo/management/model1');
7992 * ```
7993 *
7994 * @param sourceURL Source URL of copying.
7995 * @param destURL Destination URL of copying.
7996 * @returns ModelArtifactsInfo of the copied model (if and only if copying
7997 * is successful).
7998 * @throws Error if copying fails, e.g., if no model exists at `sourceURL`, or
7999 * if `oldPath` and `newPath` are identical.
8000 *
8001 * @doc {
8002 * heading: 'Models',
8003 * subheading: 'Management',
8004 * namespace: 'io',
8005 * ignoreCI: true
8006 * }
8007 */
8008 async function copyModel(sourceURL, destURL) {
8009 const deleteSource = false;
8010 return cloneModelInternal(sourceURL, destURL, deleteSource);
8011 }
8012 /**
8013 * Move a model from one URL to another.
8014 *
8015 * This function supports:
8016 *
8017 * 1. Moving within a storage medium, e.g.,
8018 * `tf.io.moveModel('localstorage://model-1', 'localstorage://model-2')`
8019 * 2. Moving between two storage mediums, e.g.,
8020 * `tf.io.moveModel('localstorage://model-1', 'indexeddb://model-1')`
8021 *
8022 * ```js
8023 * // First create and save a model.
8024 * const model = tf.sequential();
8025 * model.add(tf.layers.dense(
8026 * {units: 1, inputShape: [10], activation: 'sigmoid'}));
8027 * await model.save('localstorage://demo/management/model1');
8028 *
8029 * // Then list existing models.
8030 * console.log(JSON.stringify(await tf.io.listModels()));
8031 *
8032 * // Move the model, from Local Storage to IndexedDB.
8033 * await tf.io.moveModel(
8034 * 'localstorage://demo/management/model1',
8035 * 'indexeddb://demo/management/model1');
8036 *
8037 * // List models again.
8038 * console.log(JSON.stringify(await tf.io.listModels()));
8039 *
8040 * // Remove the moved model.
8041 * await tf.io.removeModel('indexeddb://demo/management/model1');
8042 * ```
8043 *
8044 * @param sourceURL Source URL of moving.
8045 * @param destURL Destination URL of moving.
8046 * @returns ModelArtifactsInfo of the copied model (if and only if copying
8047 * is successful).
8048 * @throws Error if moving fails, e.g., if no model exists at `sourceURL`, or
8049 * if `oldPath` and `newPath` are identical.
8050 *
8051 * @doc {
8052 * heading: 'Models',
8053 * subheading: 'Management',
8054 * namespace: 'io',
8055 * ignoreCI: true
8056 * }
8057 */
8058 async function moveModel(sourceURL, destURL) {
8059 const deleteSource = true;
8060 return cloneModelInternal(sourceURL, destURL, deleteSource);
8061 }
8062
8063 /**
8064 * @license
8065 * Copyright 2019 Google LLC. All Rights Reserved.
8066 * Licensed under the Apache License, Version 2.0 (the "License");
8067 * you may not use this file except in compliance with the License.
8068 * You may obtain a copy of the License at
8069 *
8070 * http://www.apache.org/licenses/LICENSE-2.0
8071 *
8072 * Unless required by applicable law or agreed to in writing, software
8073 * distributed under the License is distributed on an "AS IS" BASIS,
8074 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8075 * See the License for the specific language governing permissions and
8076 * limitations under the License.
8077 * =============================================================================
8078 */
    /**
     * Browser implementation of the platform abstraction: fetch, timing,
     * text encoding/decoding, typed-array detection, and a postMessage-based
     * setTimeout replacement that sidesteps the browser's 4ms nested-timeout
     * clamp.
     */
    class PlatformBrowser {
        constructor() {
            // For setTimeoutCustom
            this.messageName = 'setTimeoutCustom';
            this.functionRefs = [];
            this.handledMessageCount = 0;
            this.hasEventListener = false;
        }
        /** Delegates to the global `fetch` with the given path and options. */
        fetch(path, init) {
            return fetch(path, init);
        }
        /** High-resolution timestamp in milliseconds (performance.now). */
        now() {
            return performance.now();
        }
        /**
         * Encodes `text` to a Uint8Array. Only utf-8 is supported; the
         * TextEncoder instance is created lazily and reused across calls.
         * @throws Error for any encoding other than 'utf-8'/'utf8'.
         */
        encode(text, encoding) {
            if (encoding !== 'utf-8' && encoding !== 'utf8') {
                throw new Error(`Browser's encoder only supports utf-8, but got ${encoding}`);
            }
            if (this.textEncoder == null) {
                this.textEncoder = new TextEncoder();
            }
            return this.textEncoder.encode(text);
        }
        /** Decodes `bytes` to a string using a TextDecoder for `encoding`. */
        decode(bytes, encoding) {
            return new TextDecoder(encoding).decode(bytes);
        }
        // If the setTimeout nesting level is greater than 5 and timeout is less
        // than 4ms, timeout will be clamped to 4ms, which hurts the perf.
        // Interleaving window.postMessage and setTimeout will trick the browser and
        // avoid the clamp.
        setTimeoutCustom(functionRef, delay) {
            // Fall back to plain setTimeout outside a window context or when the
            // feature flag is off.
            if (typeof window === 'undefined' ||
                !env().getBool('USE_SETTIMEOUTCUSTOM')) {
                setTimeout(functionRef, delay);
                return;
            }
            // The callback is stored by index; the timeout only posts a message
            // carrying that index, and the message listener invokes the callback.
            this.functionRefs.push(functionRef);
            setTimeout(() => {
                window.postMessage({ name: this.messageName, index: this.functionRefs.length - 1 }, '*');
            }, delay);
            // Install a single capture-phase message listener on first use.
            if (!this.hasEventListener) {
                this.hasEventListener = true;
                window.addEventListener('message', (event) => {
                    if (event.source === window && event.data.name === this.messageName) {
                        event.stopPropagation();
                        const functionRef = this.functionRefs[event.data.index];
                        functionRef();
                        this.handledMessageCount++;
                        // Once every stored callback has run, reset the bookkeeping
                        // so the refs array does not grow without bound.
                        if (this.handledMessageCount === this.functionRefs.length) {
                            this.functionRefs = [];
                            this.handledMessageCount = 0;
                        }
                    }
                }, true);
            }
        }
        /** Delegates typed-array detection to the browser-specific helper. */
        isTypedArray(a) {
            return isTypedArrayBrowser(a);
        }
    }
    // When running in a browser, install the browser platform and register the
    // Local Storage and IndexedDB model-store managers. Registration is
    // best-effort: failures (presumably e.g. a scheme already registered or the
    // storage API being unavailable — NOTE(review): confirm) are ignored.
    if (env().get('IS_BROWSER')) {
        env().setPlatform('browser', new PlatformBrowser());
        // Register LocalStorage IOHandler
        try {
            ModelStoreManagerRegistry.registerManager(BrowserLocalStorage.URL_SCHEME, new BrowserLocalStorageManager());
        }
        catch (err) {
        }
        // Register IndexedDB IOHandler
        try {
            ModelStoreManagerRegistry.registerManager(BrowserIndexedDB.URL_SCHEME, new BrowserIndexedDBManager());
        }
        catch (err) {
        }
    }
8154
8155 /**
8156 * @license
8157 * Copyright 2019 Google LLC. All Rights Reserved.
8158 * Licensed under the Apache License, Version 2.0 (the "License");
8159 * you may not use this file except in compliance with the License.
8160 * You may obtain a copy of the License at
8161 *
8162 * http://www.apache.org/licenses/LICENSE-2.0
8163 *
8164 * Unless required by applicable law or agreed to in writing, software
8165 * distributed under the License is distributed on an "AS IS" BASIS,
8166 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8167 * See the License for the specific language governing permissions and
8168 * limitations under the License.
8169 * =============================================================================
8170 */
8171 // We are wrapping this within an object so it can be stubbed by Jasmine.
    // We are wrapping this within an object so it can be stubbed by Jasmine.
    const getNodeFetch = {
        // tslint:disable-next-line:no-require-imports
        importFetch: () => require('node-fetch')
    };
    // Cached fetch implementation used by PlatformNode when no global fetch
    // exists; populated lazily from `getNodeFetch.importFetch()`.
    let systemFetch;
    // These getters and setters are for testing so we don't export a mutable
    // variable.
    /** Clears the cached fetch so the next use re-imports node-fetch. */
    function resetSystemFetch() {
        systemFetch = null;
    }
    /** Overrides the cached fetch implementation (test hook). */
    function setSystemFetch(fetchFn) {
        systemFetch = fetchFn;
    }
    /** Returns the currently cached fetch implementation (may be unset). */
    function getSystemFetch() {
        return systemFetch;
    }
    /**
     * Node.js implementation of the platform abstraction: fetch (global fetch
     * or lazily-required node-fetch), process.hrtime-based timing, util-based
     * text encoding/decoding, and typed-array detection via util.types.
     */
    class PlatformNode {
        constructor() {
            // tslint:disable-next-line:no-require-imports
            this.util = require('util');
            // According to the spec, the built-in encoder can do only UTF-8 encoding.
            // https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/TextEncoder
            this.textEncoder = new this.util.TextEncoder();
        }
        /**
         * Fetches `path`. Prefers a globally-installed fetch; otherwise lazily
         * imports and caches node-fetch via the stubbable `getNodeFetch` hook.
         */
        fetch(path, requestInits) {
            if (env().global.fetch != null) {
                return env().global.fetch(path, requestInits);
            }
            if (systemFetch == null) {
                systemFetch = getNodeFetch.importFetch();
            }
            return systemFetch(path, requestInits);
        }
        /** Milliseconds with sub-ms precision, from process.hrtime(). */
        now() {
            const time = process.hrtime();
            // [seconds, nanoseconds] -> milliseconds.
            return time[0] * 1000 + time[1] / 1000000;
        }
        /**
         * Encodes `text` to bytes. Only utf-8 is supported by the built-in
         * encoder.
         * @throws Error for any encoding other than 'utf-8'/'utf8'.
         */
        encode(text, encoding) {
            if (encoding !== 'utf-8' && encoding !== 'utf8') {
                throw new Error(`Node built-in encoder only supports utf-8, but got ${encoding}`);
            }
            return this.textEncoder.encode(text);
        }
        /** Decodes `bytes` with the given encoding; empty input yields ''. */
        decode(bytes, encoding) {
            if (bytes.length === 0) {
                return '';
            }
            return new this.util.TextDecoder(encoding).decode(bytes);
        }
        /**
         * True only for Float32Array, Int32Array, Uint8Array, and
         * Uint8ClampedArray — the typed-array kinds this platform recognizes.
         */
        isTypedArray(a) {
            return this.util.types.isFloat32Array(a)
                || this.util.types.isInt32Array(a)
                || this.util.types.isUint8Array(a)
                || this.util.types.isUint8ClampedArray(a);
        }
    }
    // Install the Node platform only in a pure Node environment (not in a
    // browser-like context such as Electron's renderer, where IS_BROWSER wins).
    if (env().get('IS_NODE') && !env().get('IS_BROWSER')) {
        env().setPlatform('node', new PlatformNode());
    }
8231
8232 /**
8233 * @license
8234 * Copyright 2020 Google Inc. All Rights Reserved.
8235 * Licensed under the Apache License, Version 2.0 (the "License");
8236 * you may not use this file except in compliance with the License.
8237 * You may obtain a copy of the License at
8238 *
8239 * http://www.apache.org/licenses/LICENSE-2.0
8240 *
8241 * Unless required by applicable law or agreed to in writing, software
8242 * distributed under the License is distributed on an "AS IS" BASIS,
8243 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8244 * See the License for the specific language governing permissions and
8245 * limitations under the License.
8246 * =============================================================================
8247 */
8248 /**
8249 * Creates an empty `tf.TensorBuffer` with the specified `shape` and `dtype`.
8250 *
8251 * The values are stored in CPU as `TypedArray`. Fill the buffer using
8252 * `buffer.set()`, or by modifying directly `buffer.values`.
8253 *
8254 * When done, call `buffer.toTensor()` to get an immutable `tf.Tensor` with
8255 * those values.
8256 *
8257 * ```js
8258 * // Create a buffer and set values at particular indices.
8259 * const buffer = tf.buffer([2, 2]);
8260 * buffer.set(3, 0, 0);
8261 * buffer.set(5, 1, 0);
8262 *
8263 * // Convert the buffer back to a tensor.
8264 * buffer.toTensor().print();
8265 * ```
8266 *
8267 * @param shape An array of integers defining the output tensor shape.
8268 * @param dtype The dtype of the buffer. Defaults to 'float32'.
8269 * @param values The values of the buffer as `TypedArray`. Defaults to
8270 * zeros.
8271 *
8272 * @doc {heading: 'Tensors', subheading: 'Creation'}
8273 */
8274 function buffer(shape, dtype = 'float32', values) {
8275 dtype = dtype || 'float32';
8276 assertNonNegativeIntegerDimensions(shape);
8277 return new TensorBuffer(shape, dtype, values);
8278 }
8279
8280 /**
8281 * @license
8282 * Copyright 2020 Google Inc. All Rights Reserved.
8283 * Licensed under the Apache License, Version 2.0 (the "License");
8284 * you may not use this file except in compliance with the License.
8285 * You may obtain a copy of the License at
8286 *
8287 * http://www.apache.org/licenses/LICENSE-2.0
8288 *
8289 * Unless required by applicable law or agreed to in writing, software
8290 * distributed under the License is distributed on an "AS IS" BASIS,
8291 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8292 * See the License for the specific language governing permissions and
8293 * limitations under the License.
8294 * =============================================================================
8295 */
8296 /**
8297 * Casts a `tf.Tensor` to a new dtype.
8298 *
8299 * ```js
8300 * const x = tf.tensor1d([1.5, 2.5, 3]);
8301 * tf.cast(x, 'int32').print();
8302 * ```
8303 * @param x The input tensor to be casted.
8304 * @param dtype The dtype to cast the input tensor to.
8305 *
8306 * @doc {heading: 'Tensors', subheading: 'Transformations'}
8307 */
8308 function cast_(x, dtype) {
8309 const $x = convertToTensor(x, 'x', 'cast');
8310 // Sanity checks.
8311 if (!isValidDtype(dtype)) {
8312 throw new Error(`Failed to cast to unknown dtype ${dtype}`);
8313 }
8314 if (dtype === 'string' && $x.dtype !== 'string' ||
8315 dtype !== 'string' && $x.dtype === 'string') {
8316 throw new Error('Only strings can be casted to strings');
8317 }
8318 const inputs = { x: $x };
8319 const attrs = { dtype };
8320 return ENGINE.runKernel(Cast, inputs, attrs);
8321 }
8322 const cast$3 = /* @__PURE__ */ op({ cast_ });
8323
8324 /**
8325 * @license
8326 * Copyright 2020 Google LLC. All Rights Reserved.
8327 * Licensed under the Apache License, Version 2.0 (the "License");
8328 * you may not use this file except in compliance with the License.
8329 * You may obtain a copy of the License at
8330 *
8331 * http://www.apache.org/licenses/LICENSE-2.0
8332 *
8333 * Unless required by applicable law or agreed to in writing, software
8334 * distributed under the License is distributed on an "AS IS" BASIS,
8335 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8336 * See the License for the specific language governing permissions and
8337 * limitations under the License.
8338 * =============================================================================
8339 */
8340 /**
8341 * Creates a new tensor with the same values and shape as the specified
8342 * tensor.
8343 *
8344 * ```js
8345 * const x = tf.tensor([1, 2]);
8346 *
8347 * x.clone().print();
8348 * ```
8349 *
8350 * @param x The tensor to clone.
8351 *
8352 * @doc {heading: 'Tensors', subheading: 'Creation'}
8353 */
8354 function clone_(x) {
8355 const $x = convertToTensor(x, 'x', 'clone', 'string_or_numeric');
8356 const inputs = { x: $x };
8357 // Note this op is called tf.identity in python. Hence the kernel name used
8358 // here.
8359 return ENGINE.runKernel(Identity$1, inputs);
8360 }
8361 const clone = /* @__PURE__ */ op({ clone_ });
8362
8363 /**
8364 * @license
8365 * Copyright 2020 Google Inc. All Rights Reserved.
8366 * Licensed under the Apache License, Version 2.0 (the "License");
8367 * you may not use this file except in compliance with the License.
8368 * You may obtain a copy of the License at
8369 *
8370 * http://www.apache.org/licenses/LICENSE-2.0
8371 *
8372 * Unless required by applicable law or agreed to in writing, software
8373 * distributed under the License is distributed on an "AS IS" BASIS,
8374 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8375 * See the License for the specific language governing permissions and
8376 * limitations under the License.
8377 * =============================================================================
8378 */
8379 /**
8380 * Prints information about the `tf.Tensor` including its data.
8381 *
8382 * ```js
8383 * const verbose = true;
8384 * tf.tensor2d([1, 2, 3, 4], [2, 2]).print(verbose);
8385 * ```
8386 * @param x The tensor to be printed.
8387 * @param verbose Whether to print verbose information about the ` Tensor`,
8388 * including dtype and size.
8389 *
8390 * @doc {heading: 'Tensors', subheading: 'Creation'}
8391 */
8392 function print(x, verbose = false) {
8393 console.log(x.toString(verbose));
8394 }
8395
8396 /**
8397 * @license
8398 * Copyright 2020 Google Inc. All Rights Reserved.
8399 * Licensed under the Apache License, Version 2.0 (the "License");
8400 * you may not use this file except in compliance with the License.
8401 * You may obtain a copy of the License at
8402 *
8403 * http://www.apache.org/licenses/LICENSE-2.0
8404 *
8405 * Unless required by applicable law or agreed to in writing, software
8406 * distributed under the License is distributed on an "AS IS" BASIS,
8407 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8408 * See the License for the specific language governing permissions and
8409 * limitations under the License.
8410 * =============================================================================
8411 */
    // Ensure the global engine exists, then install the minimal op set
    // (buffer/cast/clone/print) via setOpHandler — presumably so Tensor's
    // chained convenience methods can delegate to these ops without a direct
    // import cycle (NOTE(review): confirm against setOpHandler's consumers).
    getOrMakeEngine();
    const opHandler = {
        buffer,
        cast: cast$3,
        clone,
        print
    };
    setOpHandler(opHandler);
8420
8421 /**
8422 * @license
8423 * Copyright 2020 Google LLC. All Rights Reserved.
8424 * Licensed under the Apache License, Version 2.0 (the "License");
8425 * you may not use this file except in compliance with the License.
8426 * You may obtain a copy of the License at
8427 *
8428 * http://www.apache.org/licenses/LICENSE-2.0
8429 *
8430 * Unless required by applicable law or agreed to in writing, software
8431 * distributed under the License is distributed on an "AS IS" BASIS,
8432 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8433 * See the License for the specific language governing permissions and
8434 * limitations under the License.
8435 * =============================================================================
8436 */
8437 /**
8438 * Adds two `tf.Tensor`s element-wise, A + B. Supports broadcasting.
8439 *
8440 *
8441 * ```js
8442 * const a = tf.tensor1d([1, 2, 3, 4]);
8443 * const b = tf.tensor1d([10, 20, 30, 40]);
8444 *
8445 * a.add(b).print(); // or tf.add(a, b)
8446 * ```
8447 *
8448 * ```js
8449 * // Broadcast add a with b.
8450 * const a = tf.scalar(5);
8451 * const b = tf.tensor1d([10, 20, 30, 40]);
8452 *
8453 * a.add(b).print(); // or tf.add(a, b)
8454 * ```
8455 * @param a The first `tf.Tensor` to add.
8456 * @param b The second `tf.Tensor` to add. Must have the same type as `a`.
8457 *
8458 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
8459 */
8460 function add_(a, b) {
8461 let $a = convertToTensor(a, 'a', 'add');
8462 let $b = convertToTensor(b, 'b', 'add');
8463 [$a, $b] = makeTypesMatch($a, $b);
8464 const inputs = { a: $a, b: $b };
8465 return ENGINE.runKernel(Add$1, inputs);
8466 }
8467 const add$3 = /* @__PURE__ */ op({ add_ });
8468
8469 /**
8470 * @license
8471 * Copyright 2020 Google LLC. All Rights Reserved.
8472 * Licensed under the Apache License, Version 2.0 (the "License");
8473 * you may not use this file except in compliance with the License.
8474 * You may obtain a copy of the License at
8475 *
8476 * http://www.apache.org/licenses/LICENSE-2.0
8477 *
8478 * Unless required by applicable law or agreed to in writing, software
8479 * distributed under the License is distributed on an "AS IS" BASIS,
8480 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8481 * See the License for the specific language governing permissions and
8482 * limitations under the License.
8483 * =============================================================================
8484 */
8485 /**
8486 * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.
8487 * The result is rounded with floor function.
8488 *
8489 *
8490 * ```js
8491 * const a = tf.tensor1d([1, 4, 9, 16]);
8492 * const b = tf.tensor1d([1, 2, 3, 4]);
8493 *
8494 * a.floorDiv(b).print(); // or tf.div(a, b)
8495 * ```
8496 *
8497 * ```js
8498 * // Broadcast div a with b.
8499 * const a = tf.tensor1d([2, 4, 6, 8]);
8500 * const b = tf.scalar(2);
8501 *
8502 * a.floorDiv(b).print(); // or tf.floorDiv(a, b)
8503 * ```
8504 *
8505 * @param a The first tensor as the numerator.
8506 * @param b The second tensor as the denominator. Must have the same dtype as
8507 * `a`.
8508 *
8509 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
8510 */
8511 function floorDiv_(a, b) {
8512 let $a = convertToTensor(a, 'a', 'floorDiv');
8513 let $b = convertToTensor(b, 'b', 'floorDiv');
8514 [$a, $b] = makeTypesMatch($a, $b);
8515 const inputs = { a: $a, b: $b };
8516 return ENGINE.runKernel(FloorDiv, inputs);
8517 }
8518 const floorDiv$2 = /* @__PURE__ */ op({ floorDiv_ });
8519
8520 /**
8521 * @license
8522 * Copyright 2020 Google LLC. All Rights Reserved.
8523 * Licensed under the Apache License, Version 2.0 (the "License");
8524 * you may not use this file except in compliance with the License.
8525 * You may obtain a copy of the License at
8526 *
8527 * http://www.apache.org/licenses/LICENSE-2.0
8528 *
8529 * Unless required by applicable law or agreed to in writing, software
8530 * distributed under the License is distributed on an "AS IS" BASIS,
8531 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8532 * See the License for the specific language governing permissions and
8533 * limitations under the License.
8534 * =============================================================================
8535 */
8536 /**
8537 * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.
8538 *
8539 * ```js
8540 * const a = tf.tensor1d([1, 4, 9, 16]);
8541 * const b = tf.tensor1d([1, 2, 3, 4]);
8542 *
8543 * a.div(b).print(); // or tf.div(a, b)
8544 * ```
8545 *
8546 * ```js
8547 * // Broadcast div a with b.
8548 * const a = tf.tensor1d([2, 4, 6, 8]);
8549 * const b = tf.scalar(2);
8550 *
8551 * a.div(b).print(); // or tf.div(a, b)
8552 * ```
8553 *
8554 * @param a The first tensor as the numerator.
8555 * @param b The second tensor as the denominator. Must have the same dtype as
8556 * `a`.
8557 *
8558 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
8559 */
8560 function div_(a, b) {
8561 let $a = convertToTensor(a, 'a', 'div');
8562 let $b = convertToTensor(b, 'b', 'div');
8563 [$a, $b] = makeTypesMatch($a, $b);
8564 if ($a.dtype === 'int32' && $b.dtype === 'int32') {
8565 return floorDiv$2($a, $b);
8566 }
8567 const inputs = { a: $a, b: $b };
8568 const attrs = {};
8569 // tslint:disable-next-line: no-unnecessary-type-assertion
8570 return ENGINE.runKernel(RealDiv, inputs, attrs);
8571 }
8572 const div$1 = /* @__PURE__ */ op({ div_ });
8573
8574 /**
8575 * @license
8576 * Copyright 2020 Google LLC. All Rights Reserved.
8577 * Licensed under the Apache License, Version 2.0 (the "License");
8578 * you may not use this file except in compliance with the License.
8579 * You may obtain a copy of the License at
8580 *
8581 * http://www.apache.org/licenses/LICENSE-2.0
8582 *
8583 * Unless required by applicable law or agreed to in writing, software
8584 * distributed under the License is distributed on an "AS IS" BASIS,
8585 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8586 * See the License for the specific language governing permissions and
8587 * limitations under the License.
8588 * =============================================================================
8589 */
8590 /**
8591 * Multiplies two `tf.Tensor`s element-wise, A * B. Supports broadcasting.
8592 *
8593 * We also expose `tf.mulStrict` which has the same signature as this op and
8594 * asserts that `a` and `b` are the same shape (does not broadcast).
8595 *
8596 * ```js
8597 * const a = tf.tensor1d([1, 2, 3, 4]);
8598 * const b = tf.tensor1d([2, 3, 4, 5]);
8599 *
8600 * a.mul(b).print(); // or tf.mul(a, b)
8601 * ```
8602 *
8603 * ```js
8604 * // Broadcast mul a with b.
8605 * const a = tf.tensor1d([1, 2, 3, 4]);
8606 * const b = tf.scalar(5);
8607 *
8608 * a.mul(b).print(); // or tf.mul(a, b)
8609 * ```
8610 * @param a The first tensor to multiply.
8611 * @param b The second tensor to multiply. Must have the same dtype as `a`.
8612 *
8613 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
8614 */
8615 function mul_(a, b) {
8616 let $a = convertToTensor(a, 'a', 'mul');
8617 let $b = convertToTensor(b, 'b', 'mul');
8618 [$a, $b] = makeTypesMatch($a, $b);
8619 const inputs = { a: $a, b: $b };
8620 return ENGINE.runKernel(Multiply$1, inputs);
8621 }
8622 const mul = /* @__PURE__ */ op({ mul_ });
8623
8624 /**
8625 * @license
8626 * Copyright 2018 Google LLC. All Rights Reserved.
8627 * Licensed under the Apache License, Version 2.0 (the "License");
8628 * you may not use this file except in compliance with the License.
8629 * You may obtain a copy of the License at
8630 *
8631 * http://www.apache.org/licenses/LICENSE-2.0
8632 *
8633 * Unless required by applicable law or agreed to in writing, software
8634 * distributed under the License is distributed on an "AS IS" BASIS,
8635 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8636 * See the License for the specific language governing permissions and
8637 * limitations under the License.
8638 * =============================================================================
8639 */
8640 /**
8641 * Computes absolute value element-wise: `abs(x)`
8642 *
8643 * ```js
8644 * const x = tf.tensor1d([-1, 2, -3, 4]);
8645 *
8646 * x.abs().print(); // or tf.abs(x)
8647 * ```
8648 * @param x The input `tf.Tensor`.
8649 *
8650 * @doc {heading: 'Operations', subheading: 'Basic math'}
8651 */
8652 function abs_(x) {
8653 const $x = convertToTensor(x, 'x', 'abs');
8654 if ($x.dtype === 'complex64') {
8655 const inputs = { x: $x };
8656 return ENGINE.runKernel(ComplexAbs, inputs);
8657 }
8658 else {
8659 const inputs = { x: $x };
8660 return ENGINE.runKernel(Abs, inputs);
8661 }
8662 }
8663 const abs$2 = /* @__PURE__ */ op({ abs_ });
8664
8665 /**
8666 * @license
8667 * Copyright 2018 Google LLC. All Rights Reserved.
8668 * Licensed under the Apache License, Version 2.0 (the "License");
8669 * you may not use this file except in compliance with the License.
8670 * You may obtain a copy of the License at
8671 *
8672 * http://www.apache.org/licenses/LICENSE-2.0
8673 *
8674 * Unless required by applicable law or agreed to in writing, software
8675 * distributed under the License is distributed on an "AS IS" BASIS,
8676 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8677 * See the License for the specific language governing permissions and
8678 * limitations under the License.
8679 * =============================================================================
8680 */
8681 /**
8682 * Computes acos of the input `tf.Tensor` element-wise: `acos(x)`
8683 *
8684 * ```js
8685 * const x = tf.tensor1d([0, 1, -1, .7]);
8686 *
8687 * x.acos().print(); // or tf.acos(x)
8688 * ```
8689 * @param x The input tensor.
8690 * @doc {heading: 'Operations', subheading: 'Basic math'}
8691 */
8692 function acos_(x) {
8693 const $x = convertToTensor(x, 'x', 'acos');
8694 const inputs = { x: $x };
8695 return ENGINE.runKernel(Acos, inputs);
8696 }
8697 const acos$2 = /* @__PURE__ */ op({ acos_ });
8698
8699 /**
8700 * @license
8701 * Copyright 2018 Google LLC. All Rights Reserved.
8702 * Licensed under the Apache License, Version 2.0 (the "License");
8703 * you may not use this file except in compliance with the License.
8704 * You may obtain a copy of the License at
8705 *
8706 * http://www.apache.org/licenses/LICENSE-2.0
8707 *
8708 * Unless required by applicable law or agreed to in writing, software
8709 * distributed under the License is distributed on an "AS IS" BASIS,
8710 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8711 * See the License for the specific language governing permissions and
8712 * limitations under the License.
8713 * =============================================================================
8714 */
8715 /**
8716 * Computes the inverse hyperbolic cos of the input `tf.Tensor` element-wise:
8717 * `acosh(x)`
8718 *
8719 * ```js
8720 * const x = tf.tensor1d([10, 1, 3, 5.7]);
8721 *
8722 * x.acosh().print(); // or tf.acosh(x)
8723 * ```
8724 * @param x The input tensor.
8725 *
8726 * @doc {heading: 'Operations', subheading: 'Basic math'}
8727 */
8728 function acosh_(x) {
8729 const $x = convertToTensor(x, 'x', 'acosh');
8730 const inputs = { x: $x };
8731 return ENGINE.runKernel(Acosh, inputs);
8732 }
8733 const acosh$2 = /* @__PURE__ */ op({ acosh_ });
8734
8735 /**
8736 * @license
8737 * Copyright 2020 Google LLC. All Rights Reserved.
8738 * Licensed under the Apache License, Version 2.0 (the "License");
8739 * you may not use this file except in compliance with the License.
8740 * You may obtain a copy of the License at
8741 *
8742 * http://www.apache.org/licenses/LICENSE-2.0
8743 *
8744 * Unless required by applicable law or agreed to in writing, software
8745 * distributed under the License is distributed on an "AS IS" BASIS,
8746 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8747 * See the License for the specific language governing permissions and
8748 * limitations under the License.
8749 * =============================================================================
8750 */
8751 /**
8752 * Adds a list of `tf.Tensor`s element-wise, each with the same shape and dtype.
8753 *
8754 * ```js
8755 * const a = tf.tensor1d([1, 2]);
8756 * const b = tf.tensor1d([3, 4]);
8757 * const c = tf.tensor1d([5, 6]);
8758 *
8759 * tf.addN([a, b, c]).print();
8760 * ```
8761 * @param tensors A list of tensors with the same shape and dtype.
8762 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
8763 */
8764 function addN_(tensors) {
8765 assert$1(Array.isArray(tensors), () => 'The argument passed to tf.addN() must be a list of tensors');
8766 assert$1(tensors.length >= 1, () => `Must pass at least one tensor to tf.addN(), but got ` +
8767 `${tensors.length}`);
8768 const $tensors = tensors.map((t, i) => convertToTensor(t, `tensors${i}`, 'addN'));
8769 const firstTensor = $tensors[0];
8770 $tensors.forEach(t => {
8771 if (t.dtype !== firstTensor.dtype) {
8772 throw new Error('All tensors passed to tf.addN() must have the same dtype');
8773 }
8774 });
8775 $tensors.forEach(t => {
8776 if (!arraysEqual(t.shape, firstTensor.shape)) {
8777 throw new Error('All tensors passed to tf.addN() must have the same shape');
8778 }
8779 });
8780 const inputs = $tensors;
8781 return ENGINE.runKernel(AddN, inputs);
8782 }
8783 const addN$2 = /* @__PURE__ */ op({ addN_ });
8784
8785 /**
8786 * @license
8787 * Copyright 2020 Google LLC. All Rights Reserved.
8788 * Licensed under the Apache License, Version 2.0 (the "License");
8789 * you may not use this file except in compliance with the License.
8790 * You may obtain a copy of the License at
8791 *
8792 * http://www.apache.org/licenses/LICENSE-2.0
8793 *
8794 * Unless required by applicable law or agreed to in writing, software
8795 * distributed under the License is distributed on an "AS IS" BASIS,
8796 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8797 * See the License for the specific language governing permissions and
8798 * limitations under the License.
8799 * =============================================================================
8800 */
8801 /**
8802 * Computes the logical and of elements across dimensions of a `tf.Tensor`.
8803 *
8804 * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
8805 * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
8806 * `axes`. If `keepDims` is true, the reduced dimensions are retained with
8807 * length 1. If `axes` has no entries, all dimensions are reduced, and a
8808 * `tf.Tensor` with a single element is returned.
8809 *
8810 * ```js
8811 * const x = tf.tensor1d([1, 1, 1], 'bool');
8812 *
8813 * x.all().print(); // or tf.all(x)
8814 * ```
8815 *
8816 * ```js
8817 * const x = tf.tensor2d([1, 1, 0, 0], [2, 2], 'bool');
8818 *
8819 * const axis = 1;
8820 * x.all(axis).print(); // or tf.all(x, axis)
8821 * ```
8822 *
8823 * @param x The input tensor. Must be of dtype bool.
8824 * @param axis The dimension(s) to reduce. By default it reduces
8825 * all dimensions.
8826 * @param keepDims If true, retains reduced dimensions with size 1.
8827 *
8828 * @doc {heading: 'Operations', subheading: 'Reduction'}
8829 */
8830 function all_(x, axis = null, keepDims = false) {
8831 const $x = convertToTensor(x, 'x', 'all', 'bool');
8832 const inputs = { x: $x };
8833 const attrs = { axis, keepDims };
8834 return ENGINE.runKernel(All, inputs, attrs);
8835 }
8836 const all$2 = /* @__PURE__ */ op({ all_ });
8837
8838 /**
8839 * @license
8840 * Copyright 2020 Google LLC. All Rights Reserved.
8841 * Licensed under the Apache License, Version 2.0 (the "License");
8842 * you may not use this file except in compliance with the License.
8843 * You may obtain a copy of the License at
8844 *
8845 * http://www.apache.org/licenses/LICENSE-2.0
8846 *
8847 * Unless required by applicable law or agreed to in writing, software
8848 * distributed under the License is distributed on an "AS IS" BASIS,
8849 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8850 * See the License for the specific language governing permissions and
8851 * limitations under the License.
8852 * =============================================================================
8853 */
/**
 * Computes the logical or of elements across dimensions of a `tf.Tensor`.
 *
 * The input is reduced along the dimensions listed in `axes`. The rank of
 * the result drops by 1 for every reduced dimension unless `keepDims` is
 * true, in which case each reduced dimension is kept with length 1. With no
 * `axes`, every dimension is reduced and a single-element `tf.Tensor` is
 * returned.
 *
 * ```js
 * const x = tf.tensor1d([1, 1, 1], 'bool');
 *
 * x.any().print(); // or tf.any(x)
 * ```
 *
 * ```js
 * const x = tf.tensor2d([1, 1, 0, 0], [2, 2], 'bool');
 *
 * const axis = 1;
 * x.any(axis).print(); // or tf.any(x, axis)
 * ```
 *
 * @param x The input tensor. Must be of dtype bool.
 * @param axis The dimension(s) to reduce. By default it reduces
 *     all dimensions.
 * @param keepDims If true, retains reduced dimensions with size 1.
 *
 * @doc {heading: 'Operations', subheading: 'Reduction'}
 */
function any_(x, axis = null, keepDims = false) {
  // `any` is only defined over boolean tensors; conversion enforces that.
  const $x = convertToTensor(x, 'x', 'any', 'bool');
  return ENGINE.runKernel(Any, { x: $x }, { axis, keepDims });
}
// tslint:disable-next-line:variable-name
const any$2 = /* @__PURE__ */ op({ any_ });
8891
8892 /**
8893 * @license
8894 * Copyright 2020 Google Inc. All Rights Reserved.
8895 * Licensed under the Apache License, Version 2.0 (the "License");
8896 * you may not use this file except in compliance with the License.
8897 * You may obtain a copy of the License at
8898 *
8899 * http://www.apache.org/licenses/LICENSE-2.0
8900 *
8901 * Unless required by applicable law or agreed to in writing, software
8902 * distributed under the License is distributed on an "AS IS" BASIS,
8903 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8904 * See the License for the specific language governing permissions and
8905 * limitations under the License.
8906 * =============================================================================
8907 */
/**
 * Returns the indices of the maximum values along an `axis`.
 *
 * The output has the shape of `input` with the `axis` dimension removed.
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 3]);
 *
 * x.argMax().print(); // or tf.argMax(x)
 * ```
 *
 * ```js
 * const x = tf.tensor2d([1, 2, 4, 3], [2, 2]);
 *
 * const axis = 1;
 * x.argMax(axis).print(); // or tf.argMax(x, axis)
 * ```
 *
 * @param x The input tensor.
 * @param axis The dimension to reduce. Defaults to 0 (outer-most dimension).
 *
 * @doc {heading: 'Operations', subheading: 'Reduction'}
 */
function argMax_(x, axis = 0) {
  const $x = convertToTensor(x, 'x', 'argMax');
  return ENGINE.runKernel(ArgMax, { x: $x }, { axis });
}
const argMax$2 = /* @__PURE__ */ op({ argMax_ });
8939
8940 /**
8941 * @license
8942 * Copyright 2020 Google Inc. All Rights Reserved.
8943 * Licensed under the Apache License, Version 2.0 (the "License");
8944 * you may not use this file except in compliance with the License.
8945 * You may obtain a copy of the License at
8946 *
8947 * http://www.apache.org/licenses/LICENSE-2.0
8948 *
8949 * Unless required by applicable law or agreed to in writing, software
8950 * distributed under the License is distributed on an "AS IS" BASIS,
8951 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8952 * See the License for the specific language governing permissions and
8953 * limitations under the License.
8954 * =============================================================================
8955 */
/**
 * Returns the indices of the minimum values along an `axis`.
 *
 * The output has the shape of `input` with the `axis` dimension removed.
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 3]);
 *
 * x.argMin().print(); // or tf.argMin(x)
 * ```
 *
 * ```js
 * const x = tf.tensor2d([1, 2, 4, 3], [2, 2]);
 *
 * const axis = 1;
 * x.argMin(axis).print(); // or tf.argMin(x, axis)
 * ```
 *
 * @param x The input tensor.
 * @param axis The dimension to reduce. Defaults to 0 (outer-most dimension).
 *
 * @doc {heading: 'Operations', subheading: 'Reduction'}
 */
function argMin_(x, axis = 0) {
  const $x = convertToTensor(x, 'x', 'argMin');
  return ENGINE.runKernel(ArgMin, { x: $x }, { axis });
}
const argMin$2 = /* @__PURE__ */ op({ argMin_ });
8987
8988 /**
8989 * @license
8990 * Copyright 2018 Google LLC. All Rights Reserved.
8991 * Licensed under the Apache License, Version 2.0 (the "License");
8992 * you may not use this file except in compliance with the License.
8993 * You may obtain a copy of the License at
8994 *
8995 * http://www.apache.org/licenses/LICENSE-2.0
8996 *
8997 * Unless required by applicable law or agreed to in writing, software
8998 * distributed under the License is distributed on an "AS IS" BASIS,
8999 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9000 * See the License for the specific language governing permissions and
9001 * limitations under the License.
9002 * =============================================================================
9003 */
/**
 * Computes asin of the input `tf.Tensor` element-wise: `asin(x)`
 *
 * ```js
 * const x = tf.tensor1d([0, 1, -1, .7]);
 *
 * x.asin().print(); // or tf.asin(x)
 * ```
 * @param x The input tensor.
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function asin_(x) {
  const $x = convertToTensor(x, 'x', 'asin');
  return ENGINE.runKernel(Asin, { x: $x });
}
const asin$2 = /* @__PURE__ */ op({ asin_ });
9021
9022 /**
9023 * @license
9024 * Copyright 2018 Google LLC. All Rights Reserved.
9025 * Licensed under the Apache License, Version 2.0 (the "License");
9026 * you may not use this file except in compliance with the License.
9027 * You may obtain a copy of the License at
9028 *
9029 * http://www.apache.org/licenses/LICENSE-2.0
9030 *
9031 * Unless required by applicable law or agreed to in writing, software
9032 * distributed under the License is distributed on an "AS IS" BASIS,
9033 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9034 * See the License for the specific language governing permissions and
9035 * limitations under the License.
9036 * =============================================================================
9037 */
/**
 * Computes inverse hyperbolic sin of the input `tf.Tensor` element-wise:
 * `asinh(x)`
 *
 * ```js
 * const x = tf.tensor1d([0, 1, -1, .7]);
 *
 * x.asinh().print(); // or tf.asinh(x)
 * ```
 * @param x The input tensor.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function asinh_(x) {
  const $x = convertToTensor(x, 'x', 'asinh');
  return ENGINE.runKernel(Asinh, { x: $x });
}
const asinh$2 = /* @__PURE__ */ op({ asinh_ });
9057
9058 /**
9059 * @license
9060 * Copyright 2018 Google LLC. All Rights Reserved.
9061 * Licensed under the Apache License, Version 2.0 (the "License");
9062 * you may not use this file except in compliance with the License.
9063 * You may obtain a copy of the License at
9064 *
9065 * http://www.apache.org/licenses/LICENSE-2.0
9066 *
9067 * Unless required by applicable law or agreed to in writing, software
9068 * distributed under the License is distributed on an "AS IS" BASIS,
9069 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9070 * See the License for the specific language governing permissions and
9071 * limitations under the License.
9072 * =============================================================================
9073 */
/**
 * Computes atan of the input `tf.Tensor` element-wise: `atan(x)`
 *
 * ```js
 * const x = tf.tensor1d([0, 1, -1, .7]);
 *
 * x.atan().print(); // or tf.atan(x)
 * ```
 * @param x The input tensor.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function atan_(x) {
  const $x = convertToTensor(x, 'x', 'atan');
  return ENGINE.runKernel(Atan, { x: $x });
}
const atan$2 = /* @__PURE__ */ op({ atan_ });
9092
9093 /**
9094 * @license
9095 * Copyright 2020 Google LLC. All Rights Reserved.
9096 * Licensed under the Apache License, Version 2.0 (the "License");
9097 * you may not use this file except in compliance with the License.
9098 * You may obtain a copy of the License at
9099 *
9100 * http://www.apache.org/licenses/LICENSE-2.0
9101 *
9102 * Unless required by applicable law or agreed to in writing, software
9103 * distributed under the License is distributed on an "AS IS" BASIS,
9104 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9105 * See the License for the specific language governing permissions and
9106 * limitations under the License.
9107 * =============================================================================
9108 */
/**
 * Computes arctangent of `tf.Tensor`s a / b element-wise: `atan2(a, b)`.
 * Supports broadcasting.
 *
 * ```js
 * const a = tf.tensor1d([1.0, 1.0, -1.0, .7]);
 * const b = tf.tensor1d([2.0, 13.0, 3.5, .21]);
 *
 * tf.atan2(a, b).print()
 * ```
 *
 * @param a The first tensor.
 * @param b The second tensor. Must have the same dtype as `a`.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function atan2_(a, b) {
  // Upcast both operands to a common dtype before dispatching the kernel.
  const [$a, $b] = makeTypesMatch(
      convertToTensor(a, 'a', 'atan2'), convertToTensor(b, 'b', 'atan2'));
  return ENGINE.runKernel(Atan2, { a: $a, b: $b });
}
const atan2$2 = /* @__PURE__ */ op({ atan2_ });
9133
9134 /**
9135 * @license
9136 * Copyright 2018 Google LLC. All Rights Reserved.
9137 * Licensed under the Apache License, Version 2.0 (the "License");
9138 * you may not use this file except in compliance with the License.
9139 * You may obtain a copy of the License at
9140 *
9141 * http://www.apache.org/licenses/LICENSE-2.0
9142 *
9143 * Unless required by applicable law or agreed to in writing, software
9144 * distributed under the License is distributed on an "AS IS" BASIS,
9145 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9146 * See the License for the specific language governing permissions and
9147 * limitations under the License.
9148 * =============================================================================
9149 */
/**
 * Computes inverse hyperbolic tan of the input `tf.Tensor` element-wise:
 * `atanh(x)`
 *
 * ```js
 * const x = tf.tensor1d([0, .1, -.1, .7]);
 *
 * x.atanh().print(); // or tf.atanh(x)
 * ```
 * @param x The input tensor.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function atanh_(x) {
  const $x = convertToTensor(x, 'x', 'atanh');
  return ENGINE.runKernel(Atanh, { x: $x });
}
const atanh$2 = /* @__PURE__ */ op({ atanh_ });
9169
9170 /**
9171 * @license
9172 * Copyright 2020 Google LLC. All Rights Reserved.
9173 * Licensed under the Apache License, Version 2.0 (the "License");
9174 * you may not use this file except in compliance with the License.
9175 * You may obtain a copy of the License at
9176 *
9177 * http://www.apache.org/licenses/LICENSE-2.0
9178 *
9179 * Unless required by applicable law or agreed to in writing, software
9180 * distributed under the License is distributed on an "AS IS" BASIS,
9181 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9182 * See the License for the specific language governing permissions and
9183 * limitations under the License.
9184 * =============================================================================
9185 */
/**
 * Computes the conv info for a dilation2d operation.
 *
 * @param inputShape Input tensor shape, of dimensions
 *     `[batch, height, width, inChannels]`.
 * @param filterShape The filter shape, of dimensions
 *     `[filterHeight, filterWidth, depth]`.
 * @param strides The strides of the sliding window per input dimension:
 *     `[strideHeight, strideWidth]`. A single number means
 *     `strideHeight == strideWidth`.
 * @param pad The type of padding algorithm.
 *    - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *    - `valid`: output will be smaller than input if filter is larger
 *       than 1*1x1.
 *    - For more info, see this guide:
 *     [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
 *          https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 * @param dataFormat The data format of the input and output data.
 *     Defaults to 'NHWC'.
 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`.
 *     Defaults to `[1, 1]`. A single number means
 *     `dilationHeight == dilationWidth`.
 */
function computeDilation2DInfo(inputShape, filterShape, strides, pad, dataFormat = 'NHWC', dilations) {
  // `computeConv2DInfo` expects a filter shape of
  // [filterHeight, filterWidth, depth, outDepth]. Dilation2d filters have no
  // outDepth, so append the input's channel count (inputShape[3]) to keep the
  // output depth equal to the input depth.
  const filterShapeWithOutDepth = [...filterShape, inputShape[3]];
  return computeConv2DInfo(inputShape, filterShapeWithOutDepth, strides, dilations, pad, null /* roundingMode */, null /* depthWise */, convertConv2DDataFormat(dataFormat));
}
/**
 * Computes the information for a forward pass of a 2D pooling operation.
 * Pooling is modeled as a depth-preserving convolution, so the synthesized
 * filter shape uses the input channel count for both depth entries.
 */
function computePool2DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat = 'channelsLast') {
  const [filterHeight, filterWidth] = parseTupleParam(filterSize);
  let channels;
  if (dataFormat === 'channelsLast') {
    channels = inShape[3];
  }
  else if (dataFormat === 'channelsFirst') {
    channels = inShape[1];
  }
  else {
    throw new Error(`Unknown dataFormat ${dataFormat}`);
  }
  return computeConv2DInfo(inShape, [filterHeight, filterWidth, channels, channels], strides, dilations, pad, roundingMode, false, dataFormat);
}
/**
 * Computes the information for a forward pass of a pooling3D operation.
 * As with 2D pooling, the op is modeled as a depth-preserving convolution.
 */
function computePool3DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat = 'NDHWC') {
  const [filterDepth, filterHeight, filterWidth] = parse3TupleParam(filterSize);
  let channels;
  let $dataFormat;
  if (dataFormat === 'NDHWC') {
    $dataFormat = 'channelsLast';
    channels = inShape[4];
  }
  else if (dataFormat === 'NCDHW') {
    $dataFormat = 'channelsFirst';
    channels = inShape[1];
  }
  else {
    throw new Error(`Unknown dataFormat ${dataFormat}`);
  }
  const filterShape = [filterDepth, filterHeight, filterWidth, channels, channels];
  return computeConv3DInfo(inShape, filterShape, strides, dilations, pad, false, $dataFormat, roundingMode);
}
/**
 * Computes the information for a forward pass of a convolution/pooling
 * operation: resolved input/output dimensions, strides, dilations,
 * effective (dilated) filter sizes and the padding breakdown.
 */
function computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, depthwise = false, dataFormat = 'channelsLast') {
  let batchSize;
  let inHeight;
  let inWidth;
  let inChannels;
  if (dataFormat === 'channelsLast') {
    [batchSize, inHeight, inWidth, inChannels] = inShape;
  }
  else if (dataFormat === 'channelsFirst') {
    [batchSize, inChannels, inHeight, inWidth] = inShape;
  }
  else {
    throw new Error(`Unknown dataFormat ${dataFormat}`);
  }
  // filterShape is [height, width, inChannels, outChannels]; the filter's
  // own inChannels entry is unused here.
  const [filterHeight, filterWidth, , filterChannels] = filterShape;
  const [strideHeight, strideWidth] = parseTupleParam(strides);
  const [dilationHeight, dilationWidth] = parseTupleParam(dilations);
  // Dilation enlarges the footprint the filter covers on the input.
  const effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
  const effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth);
  const { padInfo, outHeight, outWidth } = getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, effectiveFilterHeight, effectiveFilterWidth, roundingMode, dataFormat);
  // A depthwise conv multiplies the channel count; a regular conv replaces it.
  const outChannels = depthwise ? filterChannels * inChannels : filterChannels;
  const outShape = dataFormat === 'channelsFirst' ?
      [batchSize, outChannels, outHeight, outWidth] :
      [batchSize, outHeight, outWidth, outChannels];
  return {
    batchSize,
    dataFormat,
    inHeight,
    inWidth,
    inChannels,
    outHeight,
    outWidth,
    outChannels,
    padInfo,
    strideHeight,
    strideWidth,
    filterHeight,
    filterWidth,
    effectiveFilterHeight,
    effectiveFilterWidth,
    dilationHeight,
    dilationWidth,
    inShape,
    outShape,
    filterShape
  };
}
/**
 * Computes the information for a forward pass of a 3D convolution/pooling
 * operation: resolved input/output dimensions, strides, dilations,
 * effective (dilated) filter sizes and the padding breakdown.
 */
function computeConv3DInfo(inShape, filterShape, strides, dilations, pad, depthwise = false, dataFormat = 'channelsLast', roundingMode) {
  let batchSize;
  let inDepth;
  let inHeight;
  let inWidth;
  let inChannels;
  if (dataFormat === 'channelsLast') {
    [batchSize, inDepth, inHeight, inWidth, inChannels] = inShape;
  }
  else if (dataFormat === 'channelsFirst') {
    [batchSize, inChannels, inDepth, inHeight, inWidth] = inShape;
  }
  else {
    throw new Error(`Unknown dataFormat ${dataFormat}`);
  }
  // filterShape is [depth, height, width, inChannels, outChannels]; the
  // filter's own inChannels entry is unused here.
  const [filterDepth, filterHeight, filterWidth, , filterChannels] = filterShape;
  const [strideDepth, strideHeight, strideWidth] = parse3TupleParam(strides);
  const [dilationDepth, dilationHeight, dilationWidth] = parse3TupleParam(dilations);
  // Dilation enlarges the footprint the filter covers on the input.
  const effectiveFilterDepth = getEffectiveFilterSize(filterDepth, dilationDepth);
  const effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
  const effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth);
  const { padInfo, outDepth, outHeight, outWidth } = get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, effectiveFilterDepth, effectiveFilterHeight, effectiveFilterWidth, roundingMode);
  // A depthwise conv multiplies the channel count; a regular conv replaces it.
  const outChannels = depthwise ? filterChannels * inChannels : filterChannels;
  const outShape = dataFormat === 'channelsFirst' ?
      [batchSize, outChannels, outDepth, outHeight, outWidth] :
      [batchSize, outDepth, outHeight, outWidth, outChannels];
  return {
    batchSize,
    dataFormat,
    inDepth,
    inHeight,
    inWidth,
    inChannels,
    outDepth,
    outHeight,
    outWidth,
    outChannels,
    padInfo,
    strideDepth,
    strideHeight,
    strideWidth,
    filterDepth,
    filterHeight,
    filterWidth,
    effectiveFilterDepth,
    effectiveFilterHeight,
    effectiveFilterWidth,
    dilationDepth,
    dilationHeight,
    dilationWidth,
    inShape,
    outShape,
    filterShape
  };
}
/**
 * Computes the [rows, cols] output shape of a 2D window op given a single
 * square field size, stride and zero padding. When `zeroPad` is null a
 * default padding is derived from the input shape.
 */
function computeOutputShape2D(inShape, fieldSize, stride, zeroPad, roundingMode) {
  const pad = zeroPad == null ? computeDefaultPad(inShape, fieldSize, stride) : zeroPad;
  const [inputRows, inputCols] = inShape;
  const outputRows = round$3((inputRows - fieldSize + 2 * pad) / stride + 1, roundingMode);
  const outputCols = round$3((inputCols - fieldSize + 2 * pad) / stride + 1, roundingMode);
  return [outputRows, outputCols];
}
/**
 * Computes the 4D output shape of a 3D window op; the last entry is the
 * caller-supplied channel count. Spatial dimensions where the padded input
 * is smaller than the filter stay 0.
 */
function computeOutputShape4D(inShape, filterShape, outChannels, strides, zeroPad, roundingMode) {
  const pad = zeroPad == null ?
      computeDefaultPad(inShape, filterShape[0], strides[0]) :
      zeroPad;
  const outShape = [0, 0, 0, outChannels];
  for (let i = 0; i < 3; i++) {
    if (inShape[i] + 2 * pad >= filterShape[i]) {
      outShape[i] = round$3((inShape[i] - filterShape[i] + 2 * pad) / strides[i] + 1, roundingMode);
    }
  }
  return outShape;
}
/**
 * Default zero padding derived from the first input dimension, the stride
 * and the (dilation-adjusted) field size.
 */
function computeDefaultPad(inputShape, fieldSize, stride, dilation = 1) {
  const effectiveFieldSize = getEffectiveFilterSize(fieldSize, dilation);
  const total = inputShape[0] * (stride - 1) - stride + effectiveFieldSize;
  return Math.floor(total / 2);
}
/**
 * Normalizes a number-or-tuple parameter to a 3-tuple: a scalar is
 * broadcast to all three entries, a pair gets a trailing 1, and a triple
 * is returned unchanged.
 */
function parseTupleParam(param) {
  if (typeof param === 'number') {
    return [param, param, param];
  }
  return param.length === 2 ? [param[0], param[1], 1] : param;
}
/** Broadcasts a scalar to a 3-tuple; arrays pass through unchanged. */
function parse3TupleParam(param) {
  if (typeof param === 'number') {
    return [param, param, param];
  }
  return param;
}
/* See https://www.tensorflow.org/api_docs/python/tf/nn/atrous_conv2d
 * Atrous (dilated) convolution is equivalent to a standard convolution with
 * an upsampled filter:
 *   effective_size = filter_size + (filter_size - 1) * (dilation - 1)
 * i.e. dilation - 1 zeros are inserted between consecutive filter elements
 * along each spatial dimension. This converts a filter dimension to its
 * effective size so it can be used in a standard convolution.
 */
function getEffectiveFilterSize(filterSize, dilation) {
  return dilation <= 1 ?
      filterSize :
      filterSize + (filterSize - 1) * (dilation - 1);
}
/**
 * Resolves the `pad` parameter of a 2D conv/pool op into explicit per-edge
 * padding (`padInfo`) plus the resulting output height and width.
 *
 * Accepted forms of `pad`: a number (uniform padding on all four edges),
 * 'same', 'valid', or an explicit [[_, _], [top, bottom], [left, right],
 * [_, _]]-style array (batch/channel positions depend on `dataFormat`).
 * Anything else throws.
 */
function getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, filterHeight, filterWidth, roundingMode, dataFormat) {
    let padInfo;
    let outHeight;
    let outWidth;
    if (typeof pad === 'number') {
        // Uniform numeric padding; 0 is reported as VALID.
        const padType = (pad === 0) ? 'VALID' : 'NUMBER';
        padInfo = { top: pad, bottom: pad, left: pad, right: pad, type: padType };
        // NOTE(review): computeOutputShape2D takes a single fieldSize/stride
        // and applies them to BOTH dimensions, so only filterHeight and
        // strideHeight are passed here — this assumes square filters/strides
        // for numeric pad. Confirm against callers.
        const outShape = computeOutputShape2D([inHeight, inWidth], filterHeight, strideHeight, pad, roundingMode);
        outHeight = outShape[0];
        outWidth = outShape[1];
    }
    else if (pad === 'same') {
        // TF-style SAME: output is ceil(in / stride); any odd total padding
        // puts the extra row/column on the bottom/right.
        outHeight = Math.ceil(inHeight / strideHeight);
        outWidth = Math.ceil(inWidth / strideWidth);
        const padAlongHeight = Math.max(0, (outHeight - 1) * strideHeight + filterHeight - inHeight);
        const padAlongWidth = Math.max(0, (outWidth - 1) * strideWidth + filterWidth - inWidth);
        const top = Math.floor(padAlongHeight / 2);
        const bottom = padAlongHeight - top;
        const left = Math.floor(padAlongWidth / 2);
        const right = padAlongWidth - left;
        padInfo = { top, bottom, left, right, type: 'SAME' };
    }
    else if (pad === 'valid') {
        // No padding; the filter must fit entirely inside the input.
        padInfo = { top: 0, bottom: 0, left: 0, right: 0, type: 'VALID' };
        outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
        outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
    }
    else if (typeof pad === 'object') {
        // Explicit per-edge padding. The spatial entries sit at indices
        // [1]/[2] for channelsLast and [2]/[3] for channelsFirst.
        const top = dataFormat === 'channelsLast' ? pad[1][0] : pad[2][0];
        const bottom = dataFormat === 'channelsLast' ? pad[1][1] : pad[2][1];
        const left = dataFormat === 'channelsLast' ? pad[2][0] : pad[3][0];
        const right = dataFormat === 'channelsLast' ? pad[2][1] : pad[3][1];
        // All-zero explicit padding is equivalent to VALID.
        const padType = (top === 0 && bottom === 0 && left === 0 && right === 0) ?
            'VALID' :
            'EXPLICIT';
        padInfo = { top, bottom, left, right, type: padType };
        outHeight = round$3((inHeight - filterHeight + top + bottom) / strideHeight + 1, roundingMode);
        outWidth = round$3((inWidth - filterWidth + left + right) / strideWidth + 1, roundingMode);
    }
    else {
        throw Error(`Unknown padding parameter: ${pad}`);
    }
    return { padInfo, outHeight, outWidth };
}
/**
 * Resolves the `pad` parameter of a 3D conv/pool op into explicit per-face
 * padding (`padInfo`) plus the resulting output depth/height/width.
 *
 * Accepted forms of `pad`: 'valid' (coerced to 0), a number (uniform on all
 * six faces), or 'same'. Unlike the 2D variant, explicit array padding is
 * not supported here and throws.
 */
function get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, filterDepth, filterHeight, filterWidth, roundingMode) {
    let padInfo;
    let outDepth;
    let outHeight;
    let outWidth;
    // 'valid' is just zero padding; fold it into the numeric branch.
    if (pad === 'valid') {
        pad = 0;
    }
    if (typeof pad === 'number') {
        const padType = (pad === 0) ? 'VALID' : 'NUMBER';
        padInfo = {
            top: pad,
            bottom: pad,
            left: pad,
            right: pad,
            front: pad,
            back: pad,
            type: padType
        };
        // NOTE(review): a dummy channel count of 1 is passed; only the three
        // spatial entries of the returned shape are used.
        const outShape = computeOutputShape4D([inDepth, inHeight, inWidth, 1], [filterDepth, filterHeight, filterWidth], 1, [strideDepth, strideHeight, strideWidth], pad, roundingMode);
        outDepth = outShape[0];
        outHeight = outShape[1];
        outWidth = outShape[2];
    }
    else if (pad === 'same') {
        // TF-style SAME: output is ceil(in / stride); any odd total padding
        // goes on the back/bottom/right face.
        // NOTE(review): unlike the 2D getPadAndOutInfo, the padAlong* values
        // are not clamped with Math.max(0, ...) here — confirm negative
        // padding cannot occur for the stride/filter combinations used.
        outDepth = Math.ceil(inDepth / strideDepth);
        outHeight = Math.ceil(inHeight / strideHeight);
        outWidth = Math.ceil(inWidth / strideWidth);
        const padAlongDepth = (outDepth - 1) * strideDepth + filterDepth - inDepth;
        const padAlongHeight = (outHeight - 1) * strideHeight + filterHeight - inHeight;
        const padAlongWidth = (outWidth - 1) * strideWidth + filterWidth - inWidth;
        const front = Math.floor(padAlongDepth / 2);
        const back = padAlongDepth - front;
        const top = Math.floor(padAlongHeight / 2);
        const bottom = padAlongHeight - top;
        const left = Math.floor(padAlongWidth / 2);
        const right = padAlongWidth - left;
        padInfo = { top, bottom, left, right, front, back, type: 'SAME' };
    }
    else {
        throw Error(`Unknown padding parameter: ${pad}`);
    }
    return { padInfo, outDepth, outHeight, outWidth };
}
/**
 * Rounds a value depending on the rounding mode.
 * @param value
 * @param roundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 */
function round$3(value, roundingMode) {
  if (!roundingMode) {
    return Math.trunc(value);
  }
  // 'round' is used for Caffe Conv, 'ceil' for Caffe Pool.
  const roundFn = { round: Math.round, ceil: Math.ceil, floor: Math.floor }[roundingMode];
  if (roundFn == null) {
    throw new Error(`Unknown roundingMode ${roundingMode}`);
  }
  return roundFn(value);
}
/** True when every entry of the normalized 3-tuple equals 1. */
function tupleValuesAreOne(param) {
  return parseTupleParam(param).every(dim => dim === 1);
}
/** True when at least one of strides/dilations is all ones. */
function eitherStridesOrDilationsAreOne(strides, dilations) {
  if (tupleValuesAreOne(strides)) {
    return true;
  }
  return tupleValuesAreOne(dilations);
}
/** True when every entry of the normalized 3-tuple is strictly positive. */
function stridesOrDilationsArePositive(values) {
  for (const value of parseTupleParam(values)) {
    if (value <= 0) {
      return false;
    }
  }
  return true;
}
/**
 * Convert Conv2D dataFormat from 'NHWC'|'NCHW' to
 * 'channelsLast'|'channelsFirst'
 * @param dataFormat in 'NHWC'|'NCHW' mode
 * @return dataFormat in 'channelsLast'|'channelsFirst' mode
 * @throws unknown dataFormat
 */
function convertConv2DDataFormat(dataFormat) {
  switch (dataFormat) {
    case 'NHWC':
      return 'channelsLast';
    case 'NCHW':
      return 'channelsFirst';
    default:
      throw new Error(`Unknown dataFormat ${dataFormat}`);
  }
}
/**
 * Check validity of pad when using dimRoundingMode.
 * @param opDesc A string of op description
 * @param pad The type of padding algorithm.
 *   - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *   - `valid` output will be smaller than input if filter is larger
 *       than 1x1.
 *   - For more info, see this guide:
 *     [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
 *          https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 * @throws unknown padding parameter
 */
function checkPadOnDimRoundingMode(opDesc, pad, dimRoundingMode) {
  // Without a rounding mode there is nothing to validate.
  if (dimRoundingMode == null) {
    return;
  }
  if (typeof pad === 'string') {
    throw Error(`Error in ${opDesc}: pad must be an integer when using ` +
        `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
  }
  if (typeof pad === 'number') {
    assert$1(isInt(pad), () => `Error in ${opDesc}: pad must be an integer when using ` +
        `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
    return;
  }
  if (typeof pad === 'object') {
    // Explicit padding: every per-edge value must be an integer.
    for (const pair of pad) {
      for (const v of pair) {
        assert$1(isInt(v), () => `Error in ${opDesc}: pad must be an integer when using ` +
            `dimRoundingMode ${dimRoundingMode} but got pad ${v}.`);
      }
    }
    return;
  }
  throw Error(`Error in ${opDesc}: Unknown padding parameter: ${pad}`);
}
9600
9601 /**
9602 * @license
9603 * Copyright 2020 Google LLC. All Rights Reserved.
9604 * Licensed under the Apache License, Version 2.0 (the "License");
9605 * you may not use this file except in compliance with the License.
9606 * You may obtain a copy of the License at
9607 *
9608 * http://www.apache.org/licenses/LICENSE-2.0
9609 *
9610 * Unless required by applicable law or agreed to in writing, software
9611 * distributed under the License is distributed on an "AS IS" BASIS,
9612 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9613 * See the License for the specific language governing permissions and
9614 * limitations under the License.
9615 * =============================================================================
9616 */
/**
 * Reshapes a `tf.Tensor` to a given shape.
 *
 * Returns a new tensor holding the same values as the input, laid out with
 * shape `shape`.
 *
 * At most one entry of `shape` may be the special value -1; that dimension
 * is inferred so the total element count is preserved (so `[-1]` flattens
 * to 1-D). Otherwise the number of elements implied by `shape` must equal
 * the number of elements in the input tensor.
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 3, 4]);
 * x.reshape([2, 2]).print();
 * ```
 *
 * @param x The input tensor to be reshaped.
 * @param shape An array of integers defining the output tensor shape.
 *
 * @doc {heading: 'Tensors', subheading: 'Transformations'}
 */
function reshape_(x, shape) {
  const $x = convertToTensor(x, 'x', 'reshape', 'string_or_numeric');
  return ENGINE.runKernel(Reshape$1, { x: $x }, { shape });
}
const reshape$3 = /* @__PURE__ */ op({ reshape_ });
9650
9651 /**
9652 * @license
9653 * Copyright 2020 Google LLC. All Rights Reserved.
9654 * Licensed under the Apache License, Version 2.0 (the "License");
9655 * you may not use this file except in compliance with the License.
9656 * You may obtain a copy of the License at
9657 *
9658 * http://www.apache.org/licenses/LICENSE-2.0
9659 *
9660 * Unless required by applicable law or agreed to in writing, software
9661 * distributed under the License is distributed on an "AS IS" BASIS,
9662 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9663 * See the License for the specific language governing permissions and
9664 * limitations under the License.
9665 * =============================================================================
9666 */
9667 /**
9668 * Computes the 2D average pooling of an image.
9669 *
9670 * @param x The input tensor, of rank 4 or rank 3 of shape
9671 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
9672 * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
9673 * `filterSize` is a single number, then `filterHeight == filterWidth`.
9674 * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
9675 * `strides` is a single number, then `strideHeight == strideWidth`.
9676 * @param pad The type of padding algorithm:
9677 * - `same` and stride 1: output will be of same size as input,
9678 * regardless of filter size.
9679 * - `valid`: output will be smaller than input if filter is larger
9680 * than 1x1.
9681 * - For more info, see this guide:
9682 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
9683 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
9684 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
9685 * provided, it will default to truncate.
9686 *
9687 * @doc {heading: 'Operations', subheading: 'Convolution'}
9688 */
/**
 * Runs the AvgPool kernel, transparently promoting a rank-3 input to
 * rank 4 (batch of 1) and restoring the original rank on the way out.
 * The result is cast back to the input's dtype.
 */
function avgPool_(x, filterSize, strides, pad, dimRoundingMode) {
    const $x = convertToTensor(x, 'x', 'avgPool', 'float32');
    // avgPool has no dilation parameter; it is fixed at 1 here so the
    // shared strides/dilations validation helper can be reused.
    const dilations = 1;
    assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in avgPool: Either strides or dilations must be 1. ' +
        `Got strides ${strides} and dilations '${dilations}'`);
    let input4D = $x;
    let addedBatchDim = false;
    if ($x.rank === 3) {
        addedBatchDim = true;
        input4D = reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
    }
    assert$1(input4D.rank === 4, () => `Error in avgPool: x must be rank 4 but got rank ${input4D.rank}.`);
    checkPadOnDimRoundingMode('avgPool', pad, dimRoundingMode);
    // tslint:disable-next-line: no-unnecessary-type-assertion
    let out = ENGINE.runKernel(AvgPool, { x: input4D }, { filterSize, strides, pad, dimRoundingMode });
    out = cast$3(out, $x.dtype);
    // Drop the synthetic batch dimension if one was added above.
    return addedBatchDim ?
        reshape$3(out, [out.shape[1], out.shape[2], out.shape[3]]) :
        out;
}
const avgPool$2 = /* @__PURE__ */ op({ avgPool_ });
9713
9714 /**
9715 * @license
9716 * Copyright 2020 Google LLC. All Rights Reserved.
9717 * Licensed under the Apache License, Version 2.0 (the "License");
9718 * you may not use this file except in compliance with the License.
9719 * You may obtain a copy of the License at
9720 *
9721 * http://www.apache.org/licenses/LICENSE-2.0
9722 *
9723 * Unless required by applicable law or agreed to in writing, software
9724 * distributed under the License is distributed on an "AS IS" BASIS,
9725 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9726 * See the License for the specific language governing permissions and
9727 * limitations under the License.
9728 * =============================================================================
9729 */
9730 /**
9731 * Computes the 3D average pooling.
9732 *
9733 * ```js
9734 * const x = tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]);
9735 * const result = tf.avgPool3d(x, 2, 1, 'valid');
9736 * result.print();
9737 * ```
9738 *
9739 * @param x The input tensor, of rank 5 or rank 4 of shape
9740 * `[batch, depth, height, width, inChannels]`.
9741 * @param filterSize The filter size:
9742 * `[filterDepth, filterHeight, filterWidth]`.
9743 * If `filterSize` is a single number,
9744 * then `filterDepth == filterHeight == filterWidth`.
9745 * @param strides The strides of the pooling:
9746 * `[strideDepth, strideHeight, strideWidth]`.
9747 * If `strides` is a single number,
9748 * then `strideDepth == strideHeight == strideWidth`.
9749 * @param pad The type of padding algorithm.
9750 * - `same` and stride 1: output will be of same size as input,
9751 * regardless of filter size.
9752 * - `valid`: output will be smaller than input if filter is larger
 * than 1x1x1.
9754 * - For more info, see this guide:
9755 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
9756 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
9757 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
9758 * provided, it will default to truncate.
9759 * @param dataFormat An optional string from: "NDHWC", "NCDHW". Defaults to
9760 * "NDHWC". Specify the data format of the input and output data. With the
9761 * default format "NDHWC", the data is stored in the order of: [batch,
9762 * depth, height, width, channels]. Only "NDHWC" is currently supported.
9763 *
9764 * @doc {heading: 'Operations', subheading: 'Convolution'}
9765 */
/**
 * Runs the AvgPool3D kernel, transparently promoting a rank-4 input to
 * rank 5 (batch of 1) and restoring the original rank on the way out.
 * Only the NDHWC data format is supported.
 */
function avgPool3d_(x, filterSize, strides, pad, dimRoundingMode, dataFormat = 'NDHWC') {
    const $x = convertToTensor(x, 'x', 'avgPool3d', 'float32');
    let input5D = $x;
    let addedBatchDim = false;
    if ($x.rank === 4) {
        addedBatchDim = true;
        input5D = reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]]);
    }
    assert$1(input5D.rank === 5, () => `Error in avgPool3d: x must be rank 5 but got rank ${input5D.rank}.`);
    assert$1(dataFormat === 'NDHWC', () => `Error in avgPool3d: Only NDHWC is currently supported, ` +
        `but got dataFormat of ${dataFormat}`);
    // Strides may be a single positive number or a [d, h, w] triple of
    // positive numbers.
    const stridesArePositive = (typeof strides === 'number' && strides > 0) ||
        (Array.isArray(strides) && strides[0] > 0 && strides[1] > 0 &&
            strides[2] > 0);
    assert$1(stridesArePositive, () => `Error in avgPool3d: Stride must be > 0, but got '${strides}'`);
    checkPadOnDimRoundingMode('avgPool3d', pad, dimRoundingMode);
    // tslint:disable-next-line: no-unnecessary-type-assertion
    let out = ENGINE.runKernel(AvgPool3D, { x: input5D }, { filterSize, strides, pad, dimRoundingMode, dataFormat });
    out = cast$3(out, input5D.dtype);
    if (!addedBatchDim) {
        return out;
    }
    // Drop the synthetic batch dimension added above.
    return reshape$3(out, [out.shape[1], out.shape[2], out.shape[3], out.shape[4]]);
}
const avgPool3d$1 = /* @__PURE__ */ op({ avgPool3d_ });
9792
9793 /**
9794 * @license
9795 * Copyright 2020 Google LLC. All Rights Reserved.
9796 * Licensed under the Apache License, Version 2.0 (the "License");
9797 * you may not use this file except in compliance with the License.
9798 * You may obtain a copy of the License at
9799 *
9800 * http://www.apache.org/licenses/LICENSE-2.0
9801 *
9802 * Unless required by applicable law or agreed to in writing, software
9803 * distributed under the License is distributed on an "AS IS" BASIS,
9804 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9805 * See the License for the specific language governing permissions and
9806 * limitations under the License.
9807 * =============================================================================
9808 */
9809 /**
9810 * Concatenates a list of `tf.Tensor`s along a given axis.
9811 *
9812 * The tensors ranks and types must match, and their sizes must match in all
9813 * dimensions except `axis`.
9814 *
9815 * Also available are stricter rank-specific methods that assert that
9816 * `tensors` are of the given rank:
9817 * - `tf.concat1d`
9818 * - `tf.concat2d`
9819 * - `tf.concat3d`
9820 * - `tf.concat4d`
9821 *
9822 * Except `tf.concat1d` (which does not have axis param), all methods have
9823 * same signature as this method.
9824 *
9825 * ```js
9826 * const a = tf.tensor1d([1, 2]);
9827 * const b = tf.tensor1d([3, 4]);
9828 * a.concat(b).print(); // or a.concat(b)
9829 * ```
9830 *
9831 * ```js
9832 * const a = tf.tensor1d([1, 2]);
9833 * const b = tf.tensor1d([3, 4]);
9834 * const c = tf.tensor1d([5, 6]);
9835 * tf.concat([a, b, c]).print();
9836 * ```
9837 *
9838 * ```js
9839 * const a = tf.tensor2d([[1, 2], [10, 20]]);
9840 * const b = tf.tensor2d([[3, 4], [30, 40]]);
9841 * const axis = 1;
9842 * tf.concat([a, b], axis).print();
9843 * ```
9844 * @param tensors A list of tensors to concatenate.
9845 * @param axis The axis to concatenate along. Defaults to 0 (the first dim).
9846 *
9847 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
9848 */
function concat_(tensors, axis = 0) {
    assert$1(tensors.length >= 1, () => 'Pass at least one tensor to concat');
    const $tensors = convertToTensorArray(tensors, 'tensors', 'concat', 'string_or_numeric');
    // complex64 tensors may only be concatenated with other complex64
    // tensors; mixing with any real dtype is rejected up front.
    if ($tensors[0].dtype === 'complex64') {
        $tensors.forEach(tensor => {
            if (tensor.dtype !== 'complex64') {
                throw new Error(`Cannot concatenate complex64 tensors with a tensor
    with dtype ${tensor.dtype}. `);
            }
        });
    }
    // Single-tensor input: no kernel call needed; return a clone rather
    // than the input itself.
    if ($tensors.length === 1) {
        return clone($tensors[0]);
    }
    const inputs = $tensors;
    const attr = { axis };
    // Shape/dtype compatibility across tensors is validated by the kernel.
    return ENGINE.runKernel(Concat, inputs, attr);
}
const concat$2 = /* @__PURE__ */ op({ concat_ });
9868
9869 /**
9870 * @license
9871 * Copyright 2020 Google LLC. All Rights Reserved.
9872 * Licensed under the Apache License, Version 2.0 (the "License");
9873 * you may not use this file except in compliance with the License.
9874 * You may obtain a copy of the License at
9875 *
9876 * http://www.apache.org/licenses/LICENSE-2.0
9877 *
9878 * Unless required by applicable law or agreed to in writing, software
9879 * distributed under the License is distributed on an "AS IS" BASIS,
9880 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9881 * See the License for the specific language governing permissions and
9882 * limitations under the License.
9883 * =============================================================================
9884 */
9885 /**
9886 * Computes the dot product of two matrices, A * B. These must be matrices.
9887 *
9888 * ```js
9889 * const a = tf.tensor2d([1, 2], [1, 2]);
9890 * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]);
9891 *
9892 * a.matMul(b).print(); // or tf.matMul(a, b)
9893 * ```
9894 * @param a First matrix in dot product operation.
9895 * @param b Second matrix in dot product operation.
9896 * @param transposeA If true, `a` is transposed before multiplication.
9897 * @param transposeB If true, `b` is transposed before multiplication.
9898 *
9899 * @doc {heading: 'Operations', subheading: 'Matrices'}
9900 */
function matMul_(a, b, transposeA = false, transposeB = false) {
    let $a = convertToTensor(a, 'a', 'matMul');
    let $b = convertToTensor(b, 'b', 'matMul');
    // Upcast so both operands share a dtype before hitting the kernel.
    [$a, $b] = makeTypesMatch($a, $b);
    return ENGINE.runKernel(BatchMatMul, { a: $a, b: $b }, { transposeA, transposeB });
}
const matMul$1 = /* @__PURE__ */ op({ matMul_ });
9910
9911 /**
9912 * @license
9913 * Copyright 2018 Google LLC. All Rights Reserved.
9914 * Licensed under the Apache License, Version 2.0 (the "License");
9915 * you may not use this file except in compliance with the License.
9916 * You may obtain a copy of the License at
9917 *
9918 * http://www.apache.org/licenses/LICENSE-2.0
9919 *
9920 * Unless required by applicable law or agreed to in writing, software
9921 * distributed under the License is distributed on an "AS IS" BASIS,
9922 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9923 * See the License for the specific language governing permissions and
9924 * limitations under the License.
9925 * =============================================================================
9926 */
9927 /**
9928 * Computes sigmoid element-wise, `1 / (1 + exp(-x))`
9929 *
9930 * ```js
9931 * const x = tf.tensor1d([0, -1, 2, -3]);
9932 *
9933 * x.sigmoid().print(); // or tf.sigmoid(x)
9934 * ```
9935 * @param x The input tensor.
9936 *
9937 * @doc {heading: 'Operations', subheading: 'Basic math'}
9938 */
function sigmoid_(x) {
    // Sigmoid is only defined for float inputs; coerce accordingly.
    const $x = convertToTensor(x, 'x', 'sigmoid', 'float32');
    return ENGINE.runKernel(Sigmoid$1, { x: $x });
}
const sigmoid$2 = /* @__PURE__ */ op({ sigmoid_ });
9945
9946 /**
9947 * @license
9948 * Copyright 2018 Google LLC. All Rights Reserved.
9949 * Licensed under the Apache License, Version 2.0 (the "License");
9950 * you may not use this file except in compliance with the License.
9951 * You may obtain a copy of the License at
9952 *
9953 * http://www.apache.org/licenses/LICENSE-2.0
9954 *
9955 * Unless required by applicable law or agreed to in writing, software
9956 * distributed under the License is distributed on an "AS IS" BASIS,
9957 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9958 * See the License for the specific language governing permissions and
9959 * limitations under the License.
9960 * =============================================================================
9961 */
9962 /**
9963 * Extracts a slice from a `tf.Tensor` starting at coordinates `begin`
9964 * and is of size `size`.
9965 *
9966 * Also available are stricter rank-specific methods with the same signature
9967 * as this method that assert that `x` is of the given rank:
9968 * - `tf.slice1d`
9969 * - `tf.slice2d`
9970 * - `tf.slice3d`
9971 * - `tf.slice4d`
9972 *
9973 * ```js
9974 * const x = tf.tensor1d([1, 2, 3, 4]);
9975 *
9976 * x.slice([1], [2]).print();
9977 * ```
9978 *
9979 * ```js
9980 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
9981 *
9982 * x.slice([1, 0], [1, 2]).print();
9983 * ```
9984 * @param x The input `tf.Tensor` to slice from.
9985 * @param begin The coordinates to start the slice from. The length can be
9986 * less than the rank of x - the rest of the axes will have implicit 0 as
9987 * start. Can also be a single number, in which case it specifies the
9988 * first axis.
9989 * @param size The size of the slice. The length can be less than the rank of
9990 * x - the rest of the axes will have implicit -1. A value of -1 requests
9991 * the rest of the dimensions in the axis. Can also be a single number,
9992 * in which case it specifies the size of the first axis.
9993 *
9994 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
9995 */
function slice_(x, begin, size) {
    const $x = convertToTensor(x, 'x', 'slice', 'string_or_numeric');
    // A rank-0 tensor has no axes to slice along.
    if ($x.rank === 0) {
        throw new Error('Slicing scalar is not possible');
    }
    // begin/size normalization (single numbers, implicit axes, -1 sizes)
    // is handled downstream by the kernel.
    return ENGINE.runKernel(Slice, { x: $x }, { begin, size });
}
const slice$2 = /* @__PURE__ */ op({ slice_ });
10006
10007 /**
10008 * @license
10009 * Copyright 2018 Google LLC. All Rights Reserved.
10010 * Licensed under the Apache License, Version 2.0 (the "License");
10011 * you may not use this file except in compliance with the License.
10012 * You may obtain a copy of the License at
10013 *
10014 * http://www.apache.org/licenses/LICENSE-2.0
10015 *
10016 * Unless required by applicable law or agreed to in writing, software
10017 * distributed under the License is distributed on an "AS IS" BASIS,
10018 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10019 * See the License for the specific language governing permissions and
10020 * limitations under the License.
10021 * =============================================================================
10022 */
10023 /**
10024 * Computes hyperbolic tangent of the input `tf.Tensor` element-wise: `tanh(x)`
10025 *
10026 * ```js
10027 * const x = tf.tensor1d([0, 1, -1, 70]);
10028 *
10029 * x.tanh().print(); // or tf.tanh(x)
10030 * ```
10031 * @param x The input tensor.
10032 *
10033 * @doc {heading: 'Operations', subheading: 'Basic math'}
10034 */
function tanh_(x) {
    // tanh is only defined for float inputs; coerce accordingly.
    const $x = convertToTensor(x, 'x', 'tanh', 'float32');
    return ENGINE.runKernel(Tanh$1, { x: $x });
}
const tanh$2 = /* @__PURE__ */ op({ tanh_ });
10041
10042 /**
10043 * @license
10044 * Copyright 2020 Google LLC. All Rights Reserved.
10045 * Licensed under the Apache License, Version 2.0 (the "License");
10046 * you may not use this file except in compliance with the License.
10047 * You may obtain a copy of the License at
10048 *
10049 * http://www.apache.org/licenses/LICENSE-2.0
10050 *
10051 * Unless required by applicable law or agreed to in writing, software
10052 * distributed under the License is distributed on an "AS IS" BASIS,
10053 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10054 * See the License for the specific language governing permissions and
10055 * limitations under the License.
10056 * =============================================================================
10057 */
10058 /**
10059 * Computes the next state and output of a BasicLSTMCell.
10060 *
10061 * Returns `[newC, newH]`.
10062 *
10063 * Derived from tf.contrib.rnn.BasicLSTMCell.
10064 *
10065 * @param forgetBias Forget bias for the cell.
10066 * @param lstmKernel The weights for the cell.
10067 * @param lstmBias The bias for the cell.
10068 * @param data The input to the cell.
10069 * @param c Previous cell state.
10070 * @param h Previous cell output.
10071 *
10072 * @doc {heading: 'Operations', subheading: 'RNN'}
10073 */
function basicLSTMCell_(forgetBias, lstmKernel, lstmBias, data, c, h) {
    const $forgetBias = convertToTensor(forgetBias, 'forgetBias', 'basicLSTMCell');
    const $lstmKernel = convertToTensor(lstmKernel, 'lstmKernel', 'basicLSTMCell');
    const $lstmBias = convertToTensor(lstmBias, 'lstmBias', 'basicLSTMCell');
    const $data = convertToTensor(data, 'data', 'basicLSTMCell');
    const $c = convertToTensor(c, 'c', 'basicLSTMCell');
    const $h = convertToTensor(h, 'h', 'basicLSTMCell');
    // One fused matmul computes all four gates: [data, h] x kernel + bias.
    const combined = concat$2([$data, $h], 1);
    const gates = add$3(matMul$1(combined, $lstmKernel), $lstmBias);
    const batchSize = gates.shape[0];
    const sliceCols = gates.shape[1] / 4;
    // Gate layout along columns: input (i), new input (j), forget (f),
    // output (o) — each a [batchSize, sliceCols] slab.
    const [i, j, f, o] = [0, 1, 2, 3].map(k => slice$2(gates, [0, sliceCols * k], [batchSize, sliceCols]));
    // newC = sigmoid(i) * tanh(j) + c * sigmoid(f + forgetBias)
    const newC = add$3(mul(sigmoid$2(i), tanh$2(j)), mul($c, sigmoid$2(add$3($forgetBias, f))));
    // newH = tanh(newC) * sigmoid(o)
    const newH = mul(tanh$2(newC), sigmoid$2(o));
    return [newC, newH];
}
const basicLSTMCell = /* @__PURE__ */ op({ basicLSTMCell_ });
10097
10098 /**
10099 * @license
10100 * Copyright 2020 Google LLC. All Rights Reserved.
10101 * Licensed under the Apache License, Version 2.0 (the "License");
10102 * you may not use this file except in compliance with the License.
10103 * You may obtain a copy of the License at
10104 *
10105 * http://www.apache.org/licenses/LICENSE-2.0
10106 *
10107 * Unless required by applicable law or agreed to in writing, software
10108 * distributed under the License is distributed on an "AS IS" BASIS,
10109 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10110 * See the License for the specific language governing permissions and
10111 * limitations under the License.
10112 * =============================================================================
10113 */
10114 /**
10115 * This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of
10116 * shape `blockShape + [batch]`, interleaves these blocks back into the grid
10117 * defined by the spatial dimensions `[1, ..., M]`, to obtain a result with
10118 * the same rank as the input. The spatial dimensions of this intermediate
10119 * result are then optionally cropped according to `crops` to produce the
10120 * output. This is the reverse of `tf.spaceToBatchND`. See below for a precise
10121 * description.
10122 *
10123 * ```js
10124 * const x = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]);
10125 * const blockShape = [2, 2];
10126 * const crops = [[0, 0], [0, 0]];
10127 *
10128 * x.batchToSpaceND(blockShape, crops).print();
10129 * ```
10130 *
10131 * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape +
10132 * remainingShape`, where spatialShape has `M` dimensions.
10133 * @param blockShape A 1-D array. Must have shape `[M]`, all values must
10134 * be >= 1.
10135 * @param crops A 2-D array. Must have shape `[M, 2]`, all values must be >= 0.
10136 * `crops[i] = [cropStart, cropEnd]` specifies the amount to crop from input
10137 * dimension `i + 1`, which corresponds to spatial dimension `i`. It is required
10138 * that `cropStart[i] + cropEnd[i] <= blockShape[i] * inputShape[i + 1]`
10139 *
10140 * This operation is equivalent to the following steps:
10141 *
10142 * 1. Reshape `x` to `reshaped` of shape: `[blockShape[0], ...,
10143 * blockShape[M-1], batch / prod(blockShape), x.shape[1], ...,
10144 * x.shape[N-1]]`
10145 *
10146 * 2. Permute dimensions of `reshaped` to produce `permuted` of shape `[batch /
10147 * prod(blockShape),x.shape[1], blockShape[0], ..., x.shape[M],
10148 * blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]`
10149 *
10150 * 3. Reshape `permuted` to produce `reshapedPermuted` of shape `[batch /
10151 * prod(blockShape),x.shape[1] * blockShape[0], ..., x.shape[M] *
10152 * blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]`
10153 *
10154 * 4. Crop the start and end of dimensions `[1, ..., M]` of `reshapedPermuted`
10155 * according to `crops` to produce the output of shape: `[batch /
10156 * prod(blockShape),x.shape[1] * blockShape[0] - crops[0,0] - crops[0,1],
10157 * ..., x.shape[M] * blockShape[M-1] - crops[M-1,0] -
10158 * crops[M-1,1],x.shape[M+1], ..., x.shape[N-1]]`
10159 *
10160 * @doc {heading: 'Tensors', subheading: 'Transformations'}
10161 */
function batchToSpaceND_(x, blockShape, crops) {
    const $x = convertToTensor(x, 'x', 'batchToSpaceND');
    // Total number of blocks interleaved back into the batch dimension.
    const prod = blockShape.reduce((product, dim) => product * dim);
    assert$1($x.rank >= 1 + blockShape.length, () => `input rank is ${$x.rank} but should be > than blockShape.length ${blockShape.length}`);
    assert$1(crops.length === blockShape.length, () => `crops.length is ${crops.length} but should be equal to blockShape.length ${blockShape.length}`);
    assert$1($x.shape[0] % prod === 0, () => `input tensor batch is ${$x.shape[0]} but is not divisible by the product of ` +
        `the elements of blockShape ${blockShape.join(' * ')} === ${prod}`);
    return ENGINE.runKernel(BatchToSpaceND, { x: $x }, { blockShape, crops });
}
const batchToSpaceND$2 = /* @__PURE__ */ op({ batchToSpaceND_ });
10174
/**
 * Coerces a tensor of rank <= 4 to rank 4 by left-padding its shape
 * with 1s; rank-4 (and higher) inputs are returned unchanged.
 */
function xAs4D(x) {
    switch (x.rank) {
        case 0:
        case 1:
            // Scalars and vectors collapse into the innermost dimension.
            return reshape$3(x, [1, 1, 1, x.size]);
        case 2:
            return reshape$3(x, [1, 1, x.shape[0], x.shape[1]]);
        case 3:
            return reshape$3(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
        default:
            return x;
    }
}
10191
10192 /**
10193 * @license
10194 * Copyright 2020 Google LLC. All Rights Reserved.
10195 * Licensed under the Apache License, Version 2.0 (the "License");
10196 * you may not use this file except in compliance with the License.
10197 * You may obtain a copy of the License at
10198 *
10199 * http://www.apache.org/licenses/LICENSE-2.0
10200 *
10201 * Unless required by applicable law or agreed to in writing, software
10202 * distributed under the License is distributed on an "AS IS" BASIS,
10203 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10204 * See the License for the specific language governing permissions and
10205 * limitations under the License.
10206 * =============================================================================
10207 */
10208 /**
10209 * Batch normalization.
10210 *
10211 * As described in
10212 * [http://arxiv.org/abs/1502.03167](http://arxiv.org/abs/1502.03167).
10213 *
10214 * Mean, variance, scale, and offset can be of two shapes:
10215 * - The same shape as the input.
10216 * - In the common case, the depth dimension is the last dimension of x, so
10217 * the values would be a `tf.Tensor1D` of shape [depth].
10218 *
10219 * Also available are stricter rank-specific methods with the same signature
10220 * as this method that assert that parameters passed are of given rank
10221 * - `tf.batchNorm2d`
10222 * - `tf.batchNorm3d`
10223 * - `tf.batchNorm4d`
10224 *
10225 * @param x The input Tensor.
10226 * @param mean A mean Tensor.
10227 * @param variance A variance Tensor.
10228 * @param offset An offset Tensor.
10229 * @param scale A scale Tensor.
10230 * @param varianceEpsilon A small float number to avoid dividing by 0.
10231 *
10232 * @doc {heading: 'Operations', subheading: 'Normalization'}
10233 */
function batchNorm_(x, mean, variance, offset, scale, varianceEpsilon) {
    // Default epsilon guards against division by zero in the kernel.
    if (varianceEpsilon == null) {
        varianceEpsilon = 0.001;
    }
    const $x = convertToTensor(x, 'x', 'batchNorm');
    const $mean = convertToTensor(mean, 'mean', 'batchNorm');
    const $variance = convertToTensor(variance, 'variance', 'batchNorm');
    // scale and offset are optional; pass undefined through when absent.
    const $scale = scale != null ? convertToTensor(scale, 'scale', 'batchNorm') : undefined;
    const $offset = offset != null ? convertToTensor(offset, 'offset', 'batchNorm') : undefined;
    assert$1($mean.rank === $variance.rank, () => 'Batch normalization gradient requires mean and variance to have ' +
        'equal ranks.');
    assert$1($offset == null || $mean.rank === $offset.rank, () => 'Batch normalization gradient requires mean and offset to have ' +
        'equal ranks.');
    assert$1($scale == null || $mean.rank === $scale.rank, () => 'Batch normalization gradient requires mean and scale to have ' +
        'equal ranks.');
    // The kernel operates on rank-4 inputs; promote and restore around it.
    const x4D = xAs4D($x);
    const inputs = {
        x: x4D,
        scale: $scale,
        offset: $offset,
        mean: $mean,
        variance: $variance
    };
    const attrs = { varianceEpsilon };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    const res = ENGINE.runKernel(FusedBatchNorm, inputs, attrs);
    return reshape$3(res, $x.shape);
}
const batchNorm$2 = /* @__PURE__ */ op({ batchNorm_ });
10269
10270 /**
10271 * Batch normalization, strictly for 2D. For the more relaxed version, see
10272 * `tf.batchNorm`.
10273 *
10274 * @param x The input Tensor.
10275 * @param mean A mean Tensor.
10276 * @param variance A variance Tensor.
10277 * @param offset An offset Tensor.
10278 * @param scale A scale Tensor.
10279 * @param varianceEpsilon A small float number to avoid dividing by 0.
10280 */
function batchNorm2d_(x, mean, variance, offset, scale, varianceEpsilon) {
    const $x = convertToTensor(x, 'x', 'batchNorm');
    const $mean = convertToTensor(mean, 'mean', 'batchNorm');
    const $variance = convertToTensor(variance, 'variance', 'batchNorm');
    let $scale;
    if (scale != null) {
        $scale = convertToTensor(scale, 'scale', 'batchNorm');
    }
    let $offset;
    if (offset != null) {
        $offset = convertToTensor(offset, 'offset', 'batchNorm');
    }
    assert$1($x.rank === 2, () => `Error in batchNorm2D: x must be rank 2 but got rank ` +
        `${$x.rank}.`);
    // Statistics and affine parameters must either match x's rank or be
    // rank-1 (per-channel) vectors.
    for (const [name, t] of [['mean', $mean], ['variance', $variance], ['scale', $scale], ['offset', $offset]]) {
        if (t != null) {
            assert$1(t.rank === 2 || t.rank === 1, () => `Error in batchNorm2D: ${name} must be rank 2 or rank 1 ` +
                `but got rank ${t.rank}.`);
        }
    }
    return batchNorm$2($x, $mean, $variance, $offset, $scale, varianceEpsilon);
}
const batchNorm2d = /* @__PURE__ */ op({ batchNorm2d_ });
10310
10311 /**
10312 * Batch normalization, strictly for 3D. For the more relaxed version, see
10313 * `tf.batchNorm`.
10314 *
10315 * @param x The input Tensor.
10316 * @param mean A mean Tensor.
10317 * @param variance A variance Tensor.
10318 * @param offset An offset Tensor.
10319 * @param scale A scale Tensor.
10320 * @param varianceEpsilon A small float number to avoid dividing by 0.
10321 */
function batchNorm3d_(x, mean, variance, offset, scale, varianceEpsilon) {
    const $x = convertToTensor(x, 'x', 'batchNorm');
    const $mean = convertToTensor(mean, 'mean', 'batchNorm');
    const $variance = convertToTensor(variance, 'variance', 'batchNorm');
    let $scale;
    if (scale != null) {
        $scale = convertToTensor(scale, 'scale', 'batchNorm');
    }
    let $offset;
    if (offset != null) {
        $offset = convertToTensor(offset, 'offset', 'batchNorm');
    }
    assert$1($x.rank === 3, () => `Error in batchNorm3D: x must be rank 3 but got rank ` +
        `${$x.rank}.`);
    // Statistics and affine parameters must either match x's rank or be
    // rank-1 (per-channel) vectors.
    for (const [name, t] of [['mean', $mean], ['variance', $variance], ['scale', $scale], ['offset', $offset]]) {
        if (t != null) {
            assert$1(t.rank === 3 || t.rank === 1, () => `Error in batchNorm3D: ${name} must be rank 3 or rank 1 ` +
                `but got rank ${t.rank}.`);
        }
    }
    return batchNorm$2($x, $mean, $variance, $offset, $scale, varianceEpsilon);
}
const batchNorm3d = /* @__PURE__ */ op({ batchNorm3d_ });
10351
10352 /**
10353 * Batch normalization, strictly for 4D. For the more relaxed version, see
10354 * `tf.batchNorm`.
10355 *
10356 * @param x The input Tensor.
10357 * @param mean A mean Tensor.
10358 * @param variance A variance Tensor.
10359 * @param offset An offset Tensor.
10360 * @param scale A scale Tensor.
10361 * @param varianceEpsilon A small float number to avoid dividing by 0.
10362 */
function batchNorm4d_(x, mean, variance, offset, scale, varianceEpsilon) {
    const $x = convertToTensor(x, 'x', 'batchNorm');
    const $mean = convertToTensor(mean, 'mean', 'batchNorm');
    const $variance = convertToTensor(variance, 'variance', 'batchNorm');
    let $scale;
    if (scale != null) {
        $scale = convertToTensor(scale, 'scale', 'batchNorm');
    }
    let $offset;
    if (offset != null) {
        $offset = convertToTensor(offset, 'offset', 'batchNorm');
    }
    assert$1($x.rank === 4, () => `Error in batchNorm4D: x must be rank 4 but got rank ` +
        `${$x.rank}.`);
    // Statistics and affine parameters must either match x's rank or be
    // rank-1 (per-channel) vectors.
    for (const [name, t] of [['mean', $mean], ['variance', $variance], ['scale', $scale], ['offset', $offset]]) {
        if (t != null) {
            assert$1(t.rank === 4 || t.rank === 1, () => `Error in batchNorm4D: ${name} must be rank 4 or rank 1 ` +
                `but got rank ${t.rank}.`);
        }
    }
    return batchNorm$2($x, $mean, $variance, $offset, $scale, varianceEpsilon);
}
const batchNorm4d = /* @__PURE__ */ op({ batchNorm4d_ });
10392
10393 /**
10394 * @license
10395 * Copyright 2020 Google LLC. All Rights Reserved.
10396 * Licensed under the Apache License, Version 2.0 (the "License");
10397 * you may not use this file except in compliance with the License.
10398 * You may obtain a copy of the License at
10399 *
10400 * http://www.apache.org/licenses/LICENSE-2.0
10401 *
10402 * Unless required by applicable law or agreed to in writing, software
10403 * distributed under the License is distributed on an "AS IS" BASIS,
10404 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10405 * See the License for the specific language governing permissions and
10406 * limitations under the License.
10407 * =============================================================================
10408 */
/**
 * Outputs a vector with length `size` and the same dtype as `weights`.
 *
 * If `weights` are empty, then index `i` stores the number of times the value
 * `i` is counted in `x`. If `weights` are non-empty, then index `i` stores the
 * sum of the value in `weights` at each index where the corresponding value in
 * `x` is `i`.
 *
 * Values in `x` outside of the range [0, size) are ignored.
 *
 * @param x The input int tensor, rank 1.
 * @param weights The weights tensor, must have the same shape as x, or a
 *     length-0 Tensor, in which case it acts as all weights equal to 1.
 * @param size Non-negative integer.
 *
 * @doc {heading: 'Operations', subheading: 'Reduction'}
 */
function bincount_(x, weights, size) {
    const $x = convertToTensor(x, 'x', 'bincount');
    const $weights = convertToTensor(weights, 'weights', 'bincount');
    // Bin indices must be integers; float inputs would silently mis-bin.
    assert$1($x.dtype === 'int32', () => `Error in bincount: input ` +
        `dtype must be int32, but got ${$x.dtype}`);
    assert$1(size >= 0, () => `size must be non-negative, but got ${size}.`);
    // A 0-length weights tensor means "all weights are 1" (plain counting).
    // Fixed: the original message concatenation was missing a space, producing
    // "...same size as input or0-length...".
    assert$1($weights.size === $x.size || $weights.size === 0, () => `Error in bincount: weights must have the same size as input or ` +
        `0-length, but got input shape: ${$x.shape}, weights shape: ` +
        `${$weights.shape}.`);
    const inputs = { x: $x, weights: $weights };
    const attrs = { size };
    return ENGINE.runKernel(Bincount, inputs, attrs);
}
const bincount$2 = /* @__PURE__ */ op({ bincount_ });
10440
10441 /**
10442 * @license
10443 * Copyright 2023 Google LLC.
10444 * Licensed under the Apache License, Version 2.0 (the "License");
10445 * you may not use this file except in compliance with the License.
10446 * You may obtain a copy of the License at
10447 *
10448 * http://www.apache.org/licenses/LICENSE-2.0
10449 *
10450 * Unless required by applicable law or agreed to in writing, software
10451 * distributed under the License is distributed on an "AS IS" BASIS,
10452 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10453 * See the License for the specific language governing permissions and
10454 * limitations under the License.
10455 * =============================================================================
10456 */
/**
 * Bitwise `AND` operation for input tensors.
 *
 * Given two input tensors, returns a new tensor with the element-wise `AND`
 * of their values. Only `int32` tensors of identical shape are supported
 * (no broadcasting).
 *
 * ```js
 * const x = tf.tensor1d([0, 5, 3, 14], 'int32');
 * const y = tf.tensor1d([5, 0, 7, 11], 'int32');
 * tf.bitwiseAnd(x, y).print();
 * ```
 *
 * @param x The input tensor to be calculated.
 * @param y The input tensor to be calculated.
 *
 * @doc {heading: 'Operations', subheading: 'Logical'}
 */
function bitwiseAnd_(x, y) {
    const $x = convertToTensor(x, 'x', 'bitwiseAnd');
    const $y = convertToTensor(y, 'y', 'bitwiseAnd');
    const shapesMatch = arraysEqual($x.shape, $y.shape);
    if (!shapesMatch) {
        throw new Error(`BitwiseAnd: Tensors must have the same shape. x: ${$x.shape}, y: ${$y.shape}`);
    }
    const bothInt32 = $x.dtype === 'int32' && $y.dtype === 'int32';
    if (!bothInt32) {
        throw new Error(`BitwiseAnd: Only supports 'int32' values in tensor, found type of x: ${$x.dtype} and type of y: ${$y.dtype}`);
    }
    return ENGINE.runKernel(BitwiseAnd, { a: $x, b: $y });
}
const bitwiseAnd$2 = /* @__PURE__ */ op({ bitwiseAnd_ });
10490
10491 /**
10492 * @license
10493 * Copyright 2021 Google LLC. All Rights Reserved.
10494 * Licensed under the Apache License, Version 2.0 (the "License");
10495 * you may not use this file except in compliance with the License.
10496 * You may obtain a copy of the License at
10497 *
10498 * http://www.apache.org/licenses/LICENSE-2.0
10499 *
10500 * Unless required by applicable law or agreed to in writing, software
10501 * distributed under the License is distributed on an "AS IS" BASIS,
10502 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10503 * See the License for the specific language governing permissions and
10504 * limitations under the License.
10505 * =============================================================================
10506 */
/**
 * Return the shape of s0 op s1 with broadcast.
 *
 * Computes r0, the broadcasted shape, as a tensor. s0, s1 and r0 are all
 * integer vectors.
 *
 * This function returns the shape of the result of an operation between
 * two tensors of size s0 and s1 performed with broadcast.
 *
 * @param s0 A tensor representing a shape
 * @param s1 A tensor representing a shape
 *
 * @doc {heading: 'Tensors', subheading: 'Transformations'}
 */
function broadcastArgs_(s0, s1) {
    const shape1Input = convertToTensor(s0, 's0', 'broadcastArgs', 'int32');
    const shape2Input = convertToTensor(s1, 's1', 'broadcastArgs', 'int32');
    // Both inputs must be shape vectors (rank 1).
    const assertVector = (t, which) => {
        if (t.rank !== 1) {
            throw new Error(`broadcastArgs(): ${which} input must be a vector (rank=1). ` +
                `Has rank ${t.rank}`);
        }
    };
    assertVector(shape1Input, 'first');
    assertVector(shape2Input, 'second');
    return ENGINE.runKernel(BroadcastArgs, { s0: shape1Input, s1: shape2Input });
}
const broadcastArgs$2 = /* @__PURE__ */ op({ broadcastArgs_ });
10536
10537 /**
10538 * @license
10539 * Copyright 2020 Google LLC. All Rights Reserved.
10540 * Licensed under the Apache License, Version 2.0 (the "License");
10541 * you may not use this file except in compliance with the License.
10542 * You may obtain a copy of the License at
10543 *
10544 * http://www.apache.org/licenses/LICENSE-2.0
10545 *
10546 * Unless required by applicable law or agreed to in writing, software
10547 * distributed under the License is distributed on an "AS IS" BASIS,
10548 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10549 * See the License for the specific language governing permissions and
10550 * limitations under the License.
10551 * =============================================================================
10552 */
/**
 * Broadcast an array to a compatible shape NumPy-style.
 *
 * The tensor's shape is compared to the broadcast shape from end to beginning.
 * Ones are prepended to the tensor's shape until it has the same length as
 * the broadcast shape. If input.shape[i]==shape[i], the (i+1)-th axis is
 * already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then
 * the input tensor is tiled N times along that axis (using tf.tile).
 *
 * @param input The tensor that is to be broadcasted.
 * @param shape The input is to be broadcast to this shape.
 *
 * @doc {heading: 'Tensors', subheading: 'Transformations'}
 */
function broadcastTo_(x, shape) {
    let input = convertToTensor(x, 'broadcastTo', 'x');
    const xShape = input.shape;
    assertNonNegativeIntegerDimensions(shape);
    if (shape.length < input.rank) {
        throw new Error(`broadcastTo(): shape.length=${shape.length} < input.rank=${input.rank}.`);
    }
    // Left-pad the input shape with 1s until the ranks agree.
    if (shape.length > input.rank) {
        const padded = input.shape.slice();
        while (padded.length < shape.length) {
            padded.unshift(1);
        }
        input = reshape$3(input, padded);
    }
    const inputShape = input.shape;
    // reps[i] is how many times axis i must be tiled (1 = leave alone).
    const reps = Array.from(shape);
    for (let i = shape.length - 1; i >= 0; i--) {
        if (inputShape[i] === shape[i]) {
            reps[i] = 1;
        }
        else if (input.shape[i] !== 1) {
            // Only size-1 axes are expandable under broadcasting rules.
            throw new Error(`broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].`);
        }
    }
    const axes = reps.map((n, i) => (n > 1 ? i : -1)).filter((i) => i >= 0);
    if (axes.length === 0) {
        // Nothing to tile; the shapes already agree.
        return clone(input);
    }
    // TODO call broadcastTo kernel directly once backends implement broadcastTo
    return ENGINE.runKernel(Tile, { x: input }, { reps });
}
const broadcastTo = /* @__PURE__ */ op({ broadcastTo_ });
10601
10602 /**
10603 * @license
10604 * Copyright 2018 Google LLC. All Rights Reserved.
10605 * Licensed under the Apache License, Version 2.0 (the "License");
10606 * you may not use this file except in compliance with the License.
10607 * You may obtain a copy of the License at
10608 *
10609 * http://www.apache.org/licenses/LICENSE-2.0
10610 *
10611 * Unless required by applicable law or agreed to in writing, software
10612 * distributed under the License is distributed on an "AS IS" BASIS,
10613 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10614 * See the License for the specific language governing permissions and
10615 * limitations under the License.
10616 * =============================================================================
10617 */
/**
 * Computes ceiling of input `tf.Tensor` element-wise: `ceil(x)`
 *
 * ```js
 * const x = tf.tensor1d([.6, 1.1, -3.3]);
 *
 * x.ceil().print(); // or tf.ceil(x)
 * ```
 * @param x The input Tensor.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function ceil_(x) {
    // Ceil is only meaningful for floats; coerce the input accordingly.
    const $x = convertToTensor(x, 'x', 'ceil', 'float32');
    return ENGINE.runKernel(Ceil, { x: $x });
}
const ceil$2 = /* @__PURE__ */ op({ ceil_ });
10636
10637 /**
10638 * @license
10639 * Copyright 2020 Google LLC. All Rights Reserved.
10640 * Licensed under the Apache License, Version 2.0 (the "License");
10641 * you may not use this file except in compliance with the License.
10642 * You may obtain a copy of the License at
10643 *
10644 * http://www.apache.org/licenses/LICENSE-2.0
10645 *
10646 * Unless required by applicable law or agreed to in writing, software
10647 * distributed under the License is distributed on an "AS IS" BASIS,
10648 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10649 * See the License for the specific language governing permissions and
10650 * limitations under the License.
10651 * =============================================================================
10652 */
/**
 * Creates a `tf.Tensor` filled with a scalar value.
 *
 * ```js
 * tf.fill([2, 2], 4).print();
 * ```
 *
 * @param shape An array of integers defining the output tensor shape.
 * @param value The scalar value to fill the tensor with.
 * @param dtype The type of an element in the resulting tensor. Defaults to
 *     'float32' if the given param value is a number, otherwise 'string'.
 */
function fill$2(shape, value, dtype) {
    assertNonNegativeIntegerDimensions(shape);
    // When no dtype is given, infer it from the fill value.
    const resolvedDtype = dtype || inferDtype(value);
    return ENGINE.runKernel(Fill, {}, { shape, value, dtype: resolvedDtype });
}
10673
10674 /**
10675 * @license
10676 * Copyright 2018 Google LLC. All Rights Reserved.
10677 * Licensed under the Apache License, Version 2.0 (the "License");
10678 * you may not use this file except in compliance with the License.
10679 * You may obtain a copy of the License at
10680 *
10681 * http://www.apache.org/licenses/LICENSE-2.0
10682 *
10683 * Unless required by applicable law or agreed to in writing, software
10684 * distributed under the License is distributed on an "AS IS" BASIS,
10685 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10686 * See the License for the specific language governing permissions and
10687 * limitations under the License.
10688 * =============================================================================
10689 */
/**
 * Clips values element-wise. `max(min(x, clipValueMax), clipValueMin)`
 *
 * ```js
 * const x = tf.tensor1d([-1, 2, -3, 4]);
 *
 * x.clipByValue(-2, 3).print(); // or tf.clipByValue(x, -2, 3)
 * ```
 * @param x The input tensor.
 * @param clipValueMin Lower bound of range to be clipped to.
 * @param clipValueMax Upper bound of range to be clipped to.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function clipByValue_(x, clipValueMin, clipValueMax) {
    const $x = convertToTensor(x, 'x', 'clipByValue');
    assert$1(clipValueMin <= clipValueMax, () => `Error in clip: min (${clipValueMin}) must be ` +
        `less than or equal to max (${clipValueMax}).`);
    // Degenerate range: every element collapses to the single allowed value.
    if (clipValueMin === clipValueMax) {
        return fill$2($x.shape, clipValueMin, $x.dtype);
    }
    return ENGINE.runKernel(ClipByValue, { x: $x }, { clipValueMin, clipValueMax });
}
const clipByValue$2 = /* @__PURE__ */ op({ clipByValue_ });
10716
/**
 * Concatenates a list of `tf.Tensor1D`s along an axis. See `concat` for
 * details.
 *
 * For example, if:
 * A: shape(3) = |r1, g1, b1|
 * B: shape(2) = |r2, g2|
 * C = tf.concat1d([A, B]) == |r1, g1, b1, r2, g2|
 *
 * @param tensors A list of `tf.Tensor`s to concatenate.
 * @return The concatenated array.
 */
function concat1d_(tensors) {
    // Rank-1 tensors only have a single axis to join along.
    const axis = 0;
    return concat$2(tensors, axis);
}
const concat1d = /* @__PURE__ */ op({ concat1d_ });
10732
/**
 * Concatenates a list of `tf.Tensor2D`s along an axis. See `concat` for
 * details.
 *
 * For example, if:
 * A: shape(2, 3) = | r1, g1, b1 |
 *                  | r2, g2, b2 |
 *
 * B: shape(2, 3) = | r3, g3, b3 |
 *                  | r4, g4, b4 |
 *
 * C = tf.concat2d([A, B], axis)
 *
 * if axis = 0, C has shape (4, 3): the rows of B are appended below A.
 * if axis = 1, C has shape (2, 6): the columns of B are appended after A.
 *
 * @param tensors A list of `tf.Tensor`s to concatenate.
 * @param axis The axis to concatenate along.
 * @return The concatenated array.
 */
function concat2d_(tensors, axis) {
    // Rank-specific alias; all validation lives in the generic concat.
    return concat$2(tensors, axis);
}
const concat2d = /* @__PURE__ */ op({ concat2d_ });
10764
/**
 * Concatenates a list of `tf.Tensor3D`s along an axis.
 * See `concat` for details.
 *
 * For example, if:
 * A: shape(2, 1, 3) = | r1, g1, b1 |
 *                     | r2, g2, b2 |
 *
 * B: shape(2, 1, 3) = | r3, g3, b3 |
 *                     | r4, g4, b4 |
 *
 * C = tf.concat3d([A, B], axis)
 *
 * if axis = 0, C has shape (4, 1, 3);
 * if axis = 1, C has shape (2, 2, 3);
 * if axis = 2, C has shape (2, 1, 6).
 *
 * @param tensors A list of `tf.Tensor`s to concatenate.
 * @param axis The axis to concatenate along.
 * @return The concatenated array.
 */
function concat3d_(tensors, axis) {
    // Rank-specific alias; all validation lives in the generic concat.
    return concat$2(tensors, axis);
}
const concat3d = /* @__PURE__ */ op({ concat3d_ });
10800
/**
 * Concatenates a list of `tf.Tensor4D`s along an axis.
 * See `concat` for details.
 *
 * @param tensors A list of `tf.Tensor`s to concatenate.
 * @param axis The axis to concatenate along.
 * @return The concatenated array.
 */
function concat4d_(tensors, axis) {
    // Rank-specific alias; all validation lives in the generic concat.
    return concat$2(tensors, axis);
}
const concat4d = /* @__PURE__ */ op({ concat4d_ });
10813
10814 /**
10815 * @license
10816 * Copyright 2020 Google LLC. All Rights Reserved.
10817 * Licensed under the Apache License, Version 2.0 (the "License");
10818 * you may not use this file except in compliance with the License.
10819 * You may obtain a copy of the License at
10820 *
10821 * http://www.apache.org/licenses/LICENSE-2.0
10822 *
10823 * Unless required by applicable law or agreed to in writing, software
10824 * distributed under the License is distributed on an "AS IS" BASIS,
10825 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10826 * See the License for the specific language governing permissions and
10827 * limitations under the License.
10828 * =============================================================================
10829 */
/**
 * Computes a 2D convolution over the input x.
 *
 * @param x The input tensor, of rank 4 or rank 3, of shape
 *     `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
 *     assumed.
 * @param filter The filter, rank 4, of shape
 *     `[filterHeight, filterWidth, inDepth, outDepth]`.
 * @param strides The strides of the convolution: `[strideHeight,
 *     strideWidth]`.
 * @param pad The type of padding algorithm.
 *    - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *    - `valid`: output will be smaller than input if filter is larger
 *       than 1x1.
 *    - For more info, see this guide:
 *     [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
 *          https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to
 *     "NHWC". Specify the data format of the input and output data. With the
 *     default format "NHWC", the data is stored in the order of: [batch,
 *     height, width, channels].
 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
 *     in which we sample input values across the height and width dimensions
 *     in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single
 *     number, then `dilationHeight == dilationWidth`. If it is greater than
 *     1, then all values of `strides` must be 1.
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 * @returns The convolution result, rank 4 — or rank 3 if `x` was rank 3.
 *
 * @doc {heading: 'Operations', subheading: 'Convolution'}
 */
function conv2d_(x, filter, strides, pad, dataFormat = 'NHWC', dilations = [1, 1], dimRoundingMode) {
    const $x = convertToTensor(x, 'x', 'conv2d', 'float32');
    const $filter = convertToTensor(filter, 'filter', 'conv2d', 'float32');
    // A rank-3 input is treated as a single-image batch; remember this so the
    // batch dimension can be stripped from the result before returning.
    let x4D = $x;
    let reshapedTo4D = false;
    if ($x.rank === 3) {
        reshapedTo4D = true;
        x4D = reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
    }
    assert$1(x4D.rank === 4, () => `Error in conv2d: input must be rank 4, but got rank ${x4D.rank}.`);
    assert$1($filter.rank === 4, () => `Error in conv2d: filter must be rank 4, but got rank ` +
        `${$filter.rank}.`);
    checkPadOnDimRoundingMode('conv2d', pad, dimRoundingMode);
    // Channel axis position depends on the data format.
    const inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
    assert$1(inDepth === $filter.shape[2], () => `Error in conv2d: depth of input (${inDepth}) must match ` +
        `input depth for filter ${$filter.shape[2]}.`);
    // Strided atrous convolution is not supported: at most one of
    // strides/dilations may differ from 1.
    assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in conv2D: Either strides or dilations must be 1. ' +
        `Got strides ${strides} and dilations '${dilations}'`);
    assert$1(stridesOrDilationsArePositive(dilations), () => 'Error in conv2D: Dilated rates should be larger than 0.');
    assert$1(stridesOrDilationsArePositive(strides), () => 'Error in conv2D: Strides should be larger than 0.');
    const inputs = { x: x4D, filter: $filter };
    const attrs = { strides, pad, dataFormat, dilations, dimRoundingMode };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    const res = ENGINE.runKernel(Conv2D$1, inputs, attrs);
    // Drop the synthetic batch dimension if one was added above.
    if (reshapedTo4D) {
        return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
    }
    return res;
}
const conv2d$4 = /* @__PURE__ */ op({ conv2d_ });
10892
/**
 * Computes a 1D convolution over the input x.
 *
 * Implemented by lifting the rank-3 problem to a rank-4 one (height 1) and
 * delegating to `conv2d`.
 *
 * @param x The input tensor, of rank 3 or rank 2, of shape
 *     `[batch, width, inChannels]`. If rank 2, batch of 1 is assumed.
 * @param filter The filter, rank 3, of shape
 *     `[filterWidth, inDepth, outDepth]`.
 * @param stride The number of entries by which the filter is moved right at
 *     each step.
 * @param pad The type of padding algorithm.
 *    - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *    - `valid`: output will be smaller than input if filter is larger
 *       than 1x1.
 *    - For more info, see this guide:
 *     [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
 *          https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 * @param dataFormat An optional string from "NWC", "NCW". Defaults to "NWC",
 *     the data is stored in the order of [batch, in_width, in_channels]. Only
 *     "NWC" is currently supported.
 * @param dilation The dilation rate in which we sample input values in
 *     atrous convolution. Defaults to `1`. If it is greater than 1, then
 *     stride must be `1`.
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 * @returns The convolution result, rank 3 — or rank 2 if `x` was rank 2.
 *
 * @doc {heading: 'Operations', subheading: 'Convolution'}
 */
function conv1d_(x, filter, stride, pad, dataFormat = 'NWC', dilation = 1, dimRoundingMode) {
    const $x = convertToTensor(x, 'x', 'conv1d');
    const $filter = convertToTensor(filter, 'filter', 'conv1d');
    // A rank-2 input is treated as a single-example batch; remember this so
    // the batch dimension can be stripped from the result before returning.
    let x3D = $x;
    let reshapedTo3D = false;
    if ($x.rank === 2) {
        reshapedTo3D = true;
        x3D = reshape$3($x, [1, $x.shape[0], $x.shape[1]]);
    }
    assert$1(x3D.rank === 3, () => `Error in conv1d: input must be rank 3, but got rank ${x3D.rank}.`);
    assert$1($filter.rank === 3, () => `Error in conv1d: filter must be rank 3, but got rank ` +
        `${$filter.rank}.`);
    checkPadOnDimRoundingMode('conv1d', pad, dimRoundingMode);
    assert$1(x3D.shape[2] === $filter.shape[1], () => `Error in conv1d: depth of input (${x3D.shape[2]}) must match ` +
        `input depth for filter ${$filter.shape[1]}.`);
    // Strided atrous convolution is not supported: at most one of
    // stride/dilation may differ from 1.
    assert$1(eitherStridesOrDilationsAreOne(stride, dilation), () => 'Error in conv1D: Either stride or dilation must be 1. ' +
        `Got stride ${stride} and dilation '${dilation}'`);
    assert$1(stridesOrDilationsArePositive(dilation), () => 'Error in conv1D: Dilated rates should be larger than 0.');
    assert$1(stridesOrDilationsArePositive(stride), () => 'Error in conv1D: Stride should be larger than 0.');
    assert$1(dataFormat === 'NWC', () => `Error in conv1d: got dataFormat of ${dataFormat} but only NWC is currently supported.`);
    // Lift to 4D with a dummy height-1 dimension so conv2d can do the work:
    // filter [W, in, out] -> [1, W, in, out]; input [B, W, C] -> [B, 1, W, C].
    const filter4D = reshape$3($filter, [1, $filter.shape[0], $filter.shape[1], $filter.shape[2]]);
    const input4D = reshape$3(x3D, [x3D.shape[0], 1, x3D.shape[1], x3D.shape[2]]);
    const strides = [1, stride];
    const dilations = [1, dilation];
    const conv2dDataFormat = 'NHWC';
    const res = conv2d$4(input4D, filter4D, strides, pad, conv2dDataFormat, dilations, dimRoundingMode);
    // Squeeze out the dummy height dimension (and the batch one, if added).
    if (reshapedTo3D) {
        return reshape$3(res, [res.shape[2], res.shape[3]]);
    }
    return reshape$3(res, [res.shape[0], res.shape[2], res.shape[3]]);
}
const conv1d$2 = /* @__PURE__ */ op({ conv1d_ });
10953
10954 /**
10955 * @license
10956 * Copyright 2020 Google LLC. All Rights Reserved.
10957 * Licensed under the Apache License, Version 2.0 (the "License");
10958 * you may not use this file except in compliance with the License.
10959 * You may obtain a copy of the License at
10960 *
10961 * http://www.apache.org/licenses/LICENSE-2.0
10962 *
10963 * Unless required by applicable law or agreed to in writing, software
10964 * distributed under the License is distributed on an "AS IS" BASIS,
10965 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10966 * See the License for the specific language governing permissions and
10967 * limitations under the License.
10968 * =============================================================================
10969 */
/**
 * Computes the derivative of the input of a 2D convolution.
 *
 * Also the workhorse behind `conv2dTranspose` (deconvolution).
 *
 * @param xShape The shape of the input: [batch, height, width, inDepth].
 *     If length of 3, batch of 1 is assumed.
 * @param dy The derivative of the output, of rank 4 or rank 3 of shape
 *     `[batch, outHeight, outWidth, outDepth]`. If rank 3, batch of 1 is
 *     assumed.
 * @param filter The filter, rank 4, of shape
 *     `[filterHeight, filterWidth, inDepth, outDepth]`.
 * @param strides The strides of the convolution: `[strideHeight,
 *     strideWidth]`.
 * @param pad The type of padding algorithm used:
 *    - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *    - `valid`: output will be smaller than input if filter is larger
 *       than 1x1.
 * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to
 *     "NHWC". Specify the data format of the input and output data. With the
 *     default format "NHWC", the data is stored in the order of: [batch,
 *     height, width, channels].
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 * @returns The input gradient, rank 4 — or rank 3 if `dy` was rank 3.
 */
function conv2DBackpropInput_(xShape, dy, filter, strides, pad, dataFormat = 'NHWC', dimRoundingMode) {
    assert$1(xShape.length === dy.rank, () => `Length of inShape ` +
        `(${xShape.length}) and rank of dy (${dy.rank}) must match`);
    // A rank-3 dy (and length-3 xShape) is treated as a single-image batch;
    // remember this so the batch dimension can be stripped before returning.
    let xShape4D = xShape;
    let dy4D = dy;
    let reshapedTo4D = false;
    if (dy.rank === 3) {
        reshapedTo4D = true;
        dy4D = reshape$3(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
        xShape4D = [1, xShape[0], xShape[1], xShape[2]];
    }
    assert$1(xShape4D.length === 4, () => `Error in conv2dDerInput: inShape must be length 4, but got length ` +
        `${xShape4D.length}.`);
    assert$1(dy4D.rank === 4, () => `Error in conv2dDerInput: dy must be rank 4, but got ` +
        `rank ${dy4D.rank}`);
    assert$1(filter.rank === 4, () => `Error in conv2dDerInput: filter must be rank 4, but got ` +
        `rank ${filter.rank}`);
    // Channel axis position depends on the data format; validate that the
    // filter's in/out depths line up with the input shape and dy.
    const inDepth = dataFormat === 'NHWC' ? xShape4D[3] : xShape4D[1];
    const outDepth = dataFormat === 'NHWC' ? dy4D.shape[3] : dy4D.shape[1];
    assert$1(inDepth === filter.shape[2], () => `Error in conv2dDerInput: depth of input (${inDepth}) must ` +
        `match input depth for filter ${filter.shape[2]}.`);
    assert$1(outDepth === filter.shape[3], () => `Error in conv2dDerInput: depth of output (${outDepth}) must ` +
        `match output depth for filter ${filter.shape[3]}.`);
    checkPadOnDimRoundingMode('conv2dDerInput', pad, dimRoundingMode);
    const inputs = { dy: dy4D, filter };
    const attrs = { strides, pad, dataFormat, dimRoundingMode, inputShape: xShape4D };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    const res = ENGINE.runKernel(Conv2DBackpropInput, inputs, attrs);
    // Drop the synthetic batch dimension if one was added above.
    if (reshapedTo4D) {
        return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
    }
    return res;
}
const conv2DBackpropInput$2 = /* @__PURE__ */ op({ conv2DBackpropInput_ });
11028
/**
 * Computes the transposed 2D convolution of an image, also known as a
 * deconvolution.
 *
 * @param x The input image, of rank 4 or rank 3, of shape
 *   `[batch, height, width, inDepth]`. If rank 3, batch of 1 is assumed.
 * @param filter The filter, rank 4, of shape
 *     `[filterHeight, filterWidth, outDepth, inDepth]`.
 *     `inDepth` must match `inDepth` in `x`.
 * @param outputShape Output shape, of rank 4 or rank 3:
 *     `[batch, height, width, outDepth]`. If rank 3, batch of 1 is assumed.
 * @param strides The strides of the original convolution:
 *     `[strideHeight, strideWidth]`.
 * @param pad The type of padding algorithm used in the non-transpose version
 *    of the op.
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 *
 * @doc {heading: 'Operations', subheading: 'Convolution'}
 */
function conv2dTranspose_(x, filter, outputShape, strides, pad, dimRoundingMode) {
    const $x = convertToTensor(x, 'x', 'conv2dTranspose');
    const $filter = convertToTensor(filter, 'filter', 'conv2dTranspose');
    // A transposed convolution is the gradient of a forward convolution with
    // respect to its input, so delegate to the input-backprop op (NHWC only).
    return conv2DBackpropInput$2(outputShape, $x, $filter, strides, pad, 'NHWC', dimRoundingMode);
}
const conv2dTranspose$1 = /* @__PURE__ */ op({ conv2dTranspose_ });
11055
11056 /**
11057 * @license
11058 * Copyright 2020 Google LLC. All Rights Reserved.
11059 * Licensed under the Apache License, Version 2.0 (the "License");
11060 * you may not use this file except in compliance with the License.
11061 * You may obtain a copy of the License at
11062 *
11063 * http://www.apache.org/licenses/LICENSE-2.0
11064 *
11065 * Unless required by applicable law or agreed to in writing, software
11066 * distributed under the License is distributed on an "AS IS" BASIS,
11067 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11068 * See the License for the specific language governing permissions and
11069 * limitations under the License.
11070 * =============================================================================
11071 */
11072 /**
11073 * Computes a 3D convolution over the input x.
11074 *
11075 * @param x The input tensor, of rank 5 or rank 4, of shape
11076 * `[batch, depth, height, width, channels]`. If rank 4,
11077 * batch of 1 is assumed.
11078 * @param filter The filter, rank 5, of shape
11079 * `[filterDepth, filterHeight, filterWidth, inChannels, outChannels]`.
11080 * inChannels must match between input and filter.
11081 * @param strides The strides of the convolution: `[strideDepth, strideHeight,
11082 * strideWidth]`.
11083 * @param pad The type of padding algorithm.
11084 * - `same` and stride 1: output will be of same size as input,
11085 * regardless of filter size.
11086 * - `valid`: output will be smaller than input if filter is larger
11087 * than 1x1.
11088 * - For more info, see this guide:
11089 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
11090 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
11091 * @param dataFormat: An optional string from: "NDHWC", "NCDHW". Defaults to
11092 * "NDHWC". Specify the data format of the input and output data. With the
11093 * default format "NDHWC", the data is stored in the order of: [batch,
11094 * depth, height, width, channels]. Only "NDHWC" is currently supported.
11095 * @param dilations The dilation rates: `[dilationDepth, dilationHeight,
11096 * dilationWidth]` in which we sample input values across the height
11097 * and width dimensions in atrous convolution. Defaults to `[1, 1, 1]`.
11098 * If `dilations` is a single number, then
11099 * `dilationDepth == dilationHeight == dilationWidth`. If it is greater
11100 * than 1, then all values of `strides` must be 1.
11101 *
11102 * @doc {heading: 'Operations', subheading: 'Convolution'}
11103 */
11104 function conv3d_(x, filter, strides, pad, dataFormat = 'NDHWC', dilations = [1, 1, 1]) {
11105 const $x = convertToTensor(x, 'x', 'conv3d');
11106 const $filter = convertToTensor(filter, 'filter', 'conv3d');
11107 let x5D = $x;
11108 let reshapedTo5D = false;
11109 if ($x.rank === 4) {
11110 reshapedTo5D = true;
11111 x5D = reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]]);
11112 }
11113 assert$1(x5D.rank === 5, () => `Error in conv3d: input must be rank 5, but got rank ${x5D.rank}.`);
11114 assert$1($filter.rank === 5, () => `Error in conv3d: filter must be rank 5, but got rank ` +
11115 `${$filter.rank}.`);
11116 assert$1(x5D.shape[4] === $filter.shape[3], () => `Error in conv3d: depth of input (${x5D.shape[4]}) must match ` +
11117 `input depth for filter ${$filter.shape[3]}.`);
11118 assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in conv3D: Either strides or dilations must be 1. ' +
11119 `Got strides ${strides} and dilations '${dilations}'`);
11120 assert$1(dataFormat === 'NDHWC', () => `Error in conv3d: got dataFormat of ${dataFormat} but only NDHWC is currently supported.`);
11121 assert$1(stridesOrDilationsArePositive(dilations), () => 'Error in conv3D: Dilated rates should be larger than 0.');
11122 assert$1(stridesOrDilationsArePositive(strides), () => 'Error in conv3D: Strides should be larger than 0.');
11123 const inputs = { x: x5D, filter: $filter };
11124 const attrs = { strides, pad, dataFormat, dilations };
11125 // tslint:disable-next-line: no-unnecessary-type-assertion
11126 const res = ENGINE.runKernel(Conv3D$1, inputs, attrs);
11127 if (reshapedTo5D) {
11128 return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
11129 }
11130 return res;
11131 }
11132 const conv3d$2 = /* @__PURE__ */ op({ conv3d_ });
11133
11134 /**
11135 * @license
11136 * Copyright 2020 Google LLC. All Rights Reserved.
11137 * Licensed under the Apache License, Version 2.0 (the "License");
11138 * you may not use this file except in compliance with the License.
11139 * You may obtain a copy of the License at
11140 *
11141 * http://www.apache.org/licenses/LICENSE-2.0
11142 *
11143 * Unless required by applicable law or agreed to in writing, software
11144 * distributed under the License is distributed on an "AS IS" BASIS,
11145 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11146 * See the License for the specific language governing permissions and
11147 * limitations under the License.
11148 * =============================================================================
11149 */
11150 /**
11151 * Computes the derivative of the input of a 3D convolution.
11152 *
11153 * @param xShape The shape of the input: [batch, depth, height, width,
11154 * in_channels]. If length of 4, batch of 1 is assumed.
11155 * @param dy The derivative of the output, of rank 5 or rank 4 of shape
11156 * `[batch, outDepth, outHeight, outWidth, in_channels]`.
11157 * If rank 4, batch of 1 is assumed.
11158 * @param filter The filter, rank 5, of shape
11159 * `[filterDepth, filterHeight, filterWidth, inDepth, outDepth]`.
11160 * @param strides The strides of the convolution: `[strideDepth, strideHeight,
11161 * strideWidth]`.
11162 * @param pad The type of padding algorithm used:
11163 * - `same` and stride 1: output will be of same size as input,
11164 * regardless of filter size.
11165 * - `valid`: output will be smaller than input if filter is larger
11166 * than 1x1.
11167 */
11168 function conv3DBackpropInput_(xShape, dy, filter, strides, pad) {
11169 assert$1(xShape.length === dy.rank, () => `Length of inShape ` +
11170 `(${xShape.length}) and rank of dy (${dy.rank}) must match`);
11171 let xShape5D = xShape;
11172 let dy5D = dy;
11173 let reshapedTo5D = false;
11174 if (dy.rank === 4) {
11175 reshapedTo5D = true;
11176 dy5D = reshape$3(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]);
11177 xShape5D = [1, xShape[0], xShape[1], xShape[2], xShape[3]];
11178 }
11179 const inDepth = xShape5D[4];
11180 const outDepth = dy5D.shape[4];
11181 assert$1(xShape5D.length === 5, () => `Error in conv3dDerInput: inShape must be length 5, but got length ` +
11182 `${xShape5D.length}.`);
11183 assert$1(dy5D.rank === 5, () => `Error in conv3dDerInput: dy must be rank 5, but got ` +
11184 `rank ${dy5D.rank}`);
11185 assert$1(filter.rank === 5, () => `Error in conv3dDerInput: filter must be rank 5, but got ` +
11186 `rank ${filter.rank}`);
11187 assert$1(inDepth === filter.shape[3], () => `Error in conv3dDerInput: depth of input (${inDepth}) must ` +
11188 `match input depth for filter ${filter.shape[3]}.`);
11189 assert$1(outDepth === filter.shape[4], () => `Error in conv3dDerInput: depth of output (${outDepth}) must ` +
11190 `match output depth for filter ${filter.shape[4]}.`);
11191 const inputs = { dy: dy5D, filter };
11192 const attrs = { pad, strides, inputShape: xShape5D };
11193 // tslint:disable-next-line: no-unnecessary-type-assertion
11194 const res = ENGINE.runKernel(Conv3DBackpropInputV2, inputs, attrs);
11195 if (reshapedTo5D) {
11196 return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
11197 }
11198 return res;
11199 }
11200 const conv3DBackpropInput$1 = /* @__PURE__ */ op({ conv3DBackpropInput_ });
11201
11202 /**
11203 * Computes the transposed 3D convolution of a volume, also known as a
11204 * deconvolution.
11205 *
11206 * @param x The input image, of rank 5 or rank 4, of shape
11207 * `[batch, depth, height, width, inDepth]`. If rank 4, batch of 1 is assumed.
 * @param filter The filter, rank 5, of shape
 *     `[depth, filterHeight, filterWidth, outDepth, inDepth]`.
11210 * `inDepth` must match `inDepth` in `x`.
11211 * @param outputShape Output shape, of rank 5 or rank 4:
 *     `[batch, depth, height, width, outDepth]`. If rank 4, batch of 1 is
 *     assumed.
11214 * @param strides The strides of the original convolution:
11215 * `[strideDepth, strideHeight, strideWidth]`.
11216 * @param pad The type of padding algorithm used in the non-transpose version
11217 * of the op.
11218 *
11219 * @doc {heading: 'Operations', subheading: 'Convolution'}
11220 */
11221 function conv3dTranspose_(x, filter, outputShape, strides, pad) {
11222 const $x = convertToTensor(x, 'x', 'conv3dTranspose');
11223 const $filter = convertToTensor(filter, 'filter', 'conv3dTranspose');
11224 return conv3DBackpropInput$1(outputShape, $x, $filter, strides, pad);
11225 }
11226 const conv3dTranspose$1 = /* @__PURE__ */ op({ conv3dTranspose_ });
11227
11228 /**
11229 * @license
11230 * Copyright 2018 Google LLC. All Rights Reserved.
11231 * Licensed under the Apache License, Version 2.0 (the "License");
11232 * you may not use this file except in compliance with the License.
11233 * You may obtain a copy of the License at
11234 *
11235 * http://www.apache.org/licenses/LICENSE-2.0
11236 *
11237 * Unless required by applicable law or agreed to in writing, software
11238 * distributed under the License is distributed on an "AS IS" BASIS,
11239 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11240 * See the License for the specific language governing permissions and
11241 * limitations under the License.
11242 * =============================================================================
11243 */
11244 /**
11245 * Computes cos of the input `tf.Tensor` element-wise: `cos(x)`
11246 *
11247 * ```js
11248 * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);
11249 *
11250 * x.cos().print(); // or tf.cos(x)
11251 * ```
11252 * @param x The input tensor. Must be float32 type.
11253 *
11254 * @doc {heading: 'Operations', subheading: 'Basic math'}
11255 */
11256 function cos_(x) {
11257 const $x = convertToTensor(x, 'x', 'cos', 'float32');
11258 const inputs = { x: $x };
11259 return ENGINE.runKernel(Cos, inputs);
11260 }
11261 const cos$2 = /* @__PURE__ */ op({ cos_ });
11262
11263 /**
11264 * @license
11265 * Copyright 2018 Google LLC. All Rights Reserved.
11266 * Licensed under the Apache License, Version 2.0 (the "License");
11267 * you may not use this file except in compliance with the License.
11268 * You may obtain a copy of the License at
11269 *
11270 * http://www.apache.org/licenses/LICENSE-2.0
11271 *
11272 * Unless required by applicable law or agreed to in writing, software
11273 * distributed under the License is distributed on an "AS IS" BASIS,
11274 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11275 * See the License for the specific language governing permissions and
11276 * limitations under the License.
11277 * =============================================================================
11278 */
11279 /**
11280 * Computes hyperbolic cos of the input `tf.Tensor` element-wise: `cosh(x)`
11281 *
11282 * ```js
11283 * const x = tf.tensor1d([0, 1, -1, .7]);
11284 *
11285 * x.cosh().print(); // or tf.cosh(x)
11286 * ```
11287 * @param x The input tensor. Must be float32 type.
11288 *
11289 * @doc {heading: 'Operations', subheading: 'Basic math'}
11290 */
11291 function cosh_(x) {
11292 const $x = convertToTensor(x, 'x', 'cosh', 'float32');
11293 const inputs = { x: $x };
11294 return ENGINE.runKernel(Cosh, inputs);
11295 }
11296 const cosh$2 = /* @__PURE__ */ op({ cosh_ });
11297
11298 /**
11299 * @license
11300 * Copyright 2022 Google LLC. All Rights Reserved.
11301 * Licensed under the Apache License, Version 2.0 (the 'License');
11302 * you may not use this file except in compliance with the License.
11303 * You may obtain a copy of the License at
11304 *
11305 * http://www.apache.org/licenses/LICENSE-2.0
11306 *
11307 * Unless required by applicable law or agreed to in writing, software
11308 * distributed under the License is distributed on an 'AS IS' BASIS,
11309 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11310 * See the License for the specific language governing permissions and
11311 * limitations under the License.
11312 * =============================================================================
11313 */
11314 /**
11315 * Computes the cumulative product of a `tf.Tensor` along `axis`.
11316 *
11317 * ```js
11318 * const x = tf.tensor([1, 2, 3, 4]);
11319 * x.cumprod().print();
11320 * ```
11321 * ```js
11322 * const x = tf.tensor([[1, 2], [3, 4]]);
11323 * x.cumprod().print();
11324 * ```
11325 *
11326 * @param x The input tensor to cumulatively multiply.
11327 * @param axis The axis along which to multiply. Optional. Defaults to 0.
11328 * @param exclusive Whether to perform exclusive cumulative product. Optional.
11329 * Defaults to false. If set to true then the product of each tensor entry
11330 * does not include its own value, but only the values previous to it
11331 * along the specified axis.
11332 * @param reverse Whether to multiply in the opposite direction. Optional.
11333 * Defaults to false.
11334 *
11335 * @doc {heading: 'Operations', subheading: 'Scan'}
11336 */
11337 function cumprod_(x, axis = 0, exclusive = false, reverse = false) {
11338 const $x = convertToTensor(x, 'x', 'cumprod');
11339 const inputs = { x: $x };
11340 const attrs = { axis, exclusive, reverse };
11341 return ENGINE.runKernel(Cumprod, inputs, attrs);
11342 }
11343 const cumprod$2 = /* @__PURE__ */ op({ cumprod_ });
11344
11345 /**
11346 * @license
11347 * Copyright 2018 Google LLC. All Rights Reserved.
11348 * Licensed under the Apache License, Version 2.0 (the "License");
11349 * you may not use this file except in compliance with the License.
11350 * You may obtain a copy of the License at
11351 *
11352 * http://www.apache.org/licenses/LICENSE-2.0
11353 *
11354 * Unless required by applicable law or agreed to in writing, software
11355 * distributed under the License is distributed on an "AS IS" BASIS,
11356 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11357 * See the License for the specific language governing permissions and
11358 * limitations under the License.
11359 * =============================================================================
11360 */
11361 /**
11362 * Computes the cumulative sum of a `tf.Tensor` along `axis`.
11363 *
11364 * ```js
11365 * const x = tf.tensor([1, 2, 3, 4]);
11366 * x.cumsum().print();
11367 * ```
11368 * ```js
11369 * const x = tf.tensor([[1, 2], [3, 4]]);
11370 * x.cumsum().print();
11371 * ```
11372 *
11373 * @param x The input tensor to be summed.
11374 * @param axis The axis along which to sum. Optional. Defaults to 0.
11375 * @param exclusive Whether to perform exclusive cumulative sum. Optional.
11376 * Defaults to false. If set to true then the sum of each tensor entry
11377 * does not include its own value, but only the values previous to it
11378 * along the specified axis.
11379 * @param reverse Whether to sum in the opposite direction. Optional.
11380 * Defaults to false.
11381 *
11382 * @doc {heading: 'Operations', subheading: 'Scan'}
11383 */
11384 function cumsum_(x, axis = 0, exclusive = false, reverse = false) {
11385 const $x = convertToTensor(x, 'x', 'cumsum');
11386 const inputs = { x: $x };
11387 const attrs = { axis, exclusive, reverse };
11388 return ENGINE.runKernel(Cumsum, inputs, attrs);
11389 }
11390 const cumsum$2 = /* @__PURE__ */ op({ cumsum_ });
11391
11392 /**
11393 * @license
11394 * Copyright 2020 Google LLC. All Rights Reserved.
11395 * Licensed under the Apache License, Version 2.0 (the "License");
11396 * you may not use this file except in compliance with the License.
11397 * You may obtain a copy of the License at
11398 *
11399 * http://www.apache.org/licenses/LICENSE-2.0
11400 *
11401 * Unless required by applicable law or agreed to in writing, software
11402 * distributed under the License is distributed on an "AS IS" BASIS,
11403 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11404 * See the License for the specific language governing permissions and
11405 * limitations under the License.
11406 * =============================================================================
11407 */
11408 /**
11409 * Outputs a vector with length `size` and the same dtype as `weights`.
11410 *
11411 * If `weights` are empty, then index `i` stores the number of times the value
11412 * `i` is counted in `x`. If `weights` are non-empty, then index `i` stores the
11413 * sum of the value in `weights` at each index where the corresponding value in
11414 * `x` is `i`.
11415 *
11416 * Values in `x` outside of the range [0, size) are ignored.
11417 *
11418 * @param x The input int tensor, rank 1 or rank 2.
11419 * @param weights The weights tensor, must have the same shape as x, or a
11420 * length-0 Tensor, in which case it acts as all weights equal to 1.
11421 * @param size Non-negative integer.
11422 * @param binaryOutput Optional. Whether the kernel should count the appearance
11423 * or number of occurrences. Defaults to False.
11424 *
11425 * @doc {heading: 'Operations', subheading: 'Reduction'}
11426 */
11427 function denseBincount_(x, weights, size, binaryOutput = false) {
11428 const $x = convertToTensor(x, 'x', 'denseBincount');
11429 const $weights = convertToTensor(weights, 'weights', 'denseBincount');
11430 assert$1($x.dtype === 'int32', () => `Error in denseBincount: input ` +
11431 `dtype must be int32, but got ${$x.dtype}`);
11432 assert$1($x.rank <= 2, () => `Error in denseBincount: input must be at most rank 2, but got ` +
11433 `rank ${$x.rank}.`);
11434 assert$1(size >= 0, () => `size must be non-negative, but got ${size}.`);
11435 assert$1($weights.size === $x.size || $weights.size === 0, () => `Error in denseBincount: weights must have the same shape as x or ` +
11436 `0-length, but got x shape: ${$x.shape}, weights shape: ` +
11437 `${$weights.shape}.`);
11438 const inputs = { x: $x, weights: $weights };
11439 const attrs = { size, binaryOutput };
11440 return ENGINE.runKernel(DenseBincount, inputs, attrs);
11441 }
11442 const denseBincount$2 = /* @__PURE__ */ op({ denseBincount_ });
11443
11444 /**
11445 * @license
11446 * Copyright 2020 Google LLC. All Rights Reserved.
11447 * Licensed under the Apache License, Version 2.0 (the "License");
11448 * you may not use this file except in compliance with the License.
11449 * You may obtain a copy of the License at
11450 *
11451 * http://www.apache.org/licenses/LICENSE-2.0
11452 *
11453 * Unless required by applicable law or agreed to in writing, software
11454 * distributed under the License is distributed on an "AS IS" BASIS,
11455 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11456 * See the License for the specific language governing permissions and
11457 * limitations under the License.
11458 * =============================================================================
11459 */
11460 /**
11461 * Rearranges data from depth into blocks of spatial data. More specifically,
11462 * this op outputs a copy of the input tensor where values from the `depth`
11463 * dimension are moved in spatial blocks to the `height` and `width` dimensions.
11464 * The attr `blockSize` indicates the input block size and how the data is
11465 * moved.
11466 *
11467 * - Chunks of data of size `blockSize * blockSize` from depth are rearranged
11468 * into non-overlapping blocks of size `blockSize x blockSize`
11469 *
11470 * - The width the output tensor is `inputWidth * blockSize`, whereas the
11471 * height is `inputHeight * blockSize`
11472 *
11473 * - The Y, X coordinates within each block of the output image are determined
11474 * by the high order component of the input channel index
11475 *
11476 * - The depth of the input tensor must be divisible by `blockSize *
11477 * blockSize`
11478 *
11479 * The `dataFormat` attr specifies the layout of the input and output tensors
11480 * with the following options: "NHWC": [ `batch, height, width, channels` ]
11481 * "NCHW": [ `batch, channels, height, width` ]
11482 *
11483 * ```js
11484 * const x = tf.tensor4d([1, 2, 3, 4], [1, 1, 1, 4]);
11485 * const blockSize = 2;
11486 * const dataFormat = "NHWC";
11487 *
11488 * tf.depthToSpace(x, blockSize, dataFormat).print();
11489 * ```
11490 *
11491 * @param x The input tensor of rank 4
 * @param blockSize An `int` that is `>= 2`. The size of the spatial block
11493 * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to "NHWC"
11494 *
11495 * @doc {heading: 'Tensors', subheading: 'Transformations'}
11496 */
    function depthToSpace_(x, blockSize, dataFormat = 'NHWC') {
        const $x = convertToTensor(x, 'x', 'depthToSpace', 'float32');
        // Axis positions of height/width/channels depend on the layout:
        // NHWC = [batch, h, w, c], NCHW = [batch, c, h, w].
        const inputHeight = (dataFormat === 'NHWC') ? $x.shape[1] : $x.shape[2];
        const inputWidth = (dataFormat === 'NHWC') ? $x.shape[2] : $x.shape[3];
        const inputDepth = (dataFormat === 'NHWC') ? $x.shape[3] : $x.shape[1];
        assert$1(blockSize > 1, () => `blockSize should be > 1 for depthToSpace, but was: ${blockSize}`);
        // Overflow guards: the spatial products must not wrap negative.
        assert$1(inputHeight * blockSize >= 0, () => `Negative dimension size caused by overflow when multiplying
    ${inputHeight} and ${blockSize} for depthToSpace with input shape
    ${$x.shape}`);
        assert$1(inputWidth * blockSize >= 0, () => `Negative dimension size caused by overflow when multiplying
    ${inputWidth} and ${blockSize} for depthToSpace with input shape
    ${$x.shape}`);
        // Each output pixel pulls blockSize*blockSize values out of the
        // channel dimension, so depth must divide evenly.
        assert$1((inputDepth % (blockSize * blockSize) === 0), () => `Dimension size must be evenly divisible by ${blockSize * blockSize} but is ${inputDepth} for depthToSpace with input shape ${$x.shape}`);
        const inputs = { x: $x };
        const attrs = { blockSize, dataFormat };
        return ENGINE.runKernel(DepthToSpace, inputs, attrs);
    }
    const depthToSpace$2 = /* @__PURE__ */ op({ depthToSpace_ });
11515
11516 /**
11517 * @license
11518 * Copyright 2020 Google LLC. All Rights Reserved.
11519 * Licensed under the Apache License, Version 2.0 (the "License");
11520 * you may not use this file except in compliance with the License.
11521 * You may obtain a copy of the License at
11522 *
11523 * http://www.apache.org/licenses/LICENSE-2.0
11524 *
11525 * Unless required by applicable law or agreed to in writing, software
11526 * distributed under the License is distributed on an "AS IS" BASIS,
11527 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11528 * See the License for the specific language governing permissions and
11529 * limitations under the License.
11530 * =============================================================================
11531 */
11532 /**
11533 * Depthwise 2D convolution.
11534 *
11535 * Given a 4D `input` array and a `filter` array of shape
11536 * `[filterHeight, filterWidth, inChannels, channelMultiplier]` containing
11537 * `inChannels` convolutional filters of depth 1, this op applies a
11538 * different filter to each input channel (expanding from 1 channel to
11539 * `channelMultiplier` channels for each), then concatenates the results
11540 * together. The output has `inChannels * channelMultiplier` channels.
11541 *
11542 * See
11543 * [https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d](
11544 * https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d)
11545 * for more details.
11546 *
11547 * @param x The input tensor, of rank 4 or rank 3, of shape
11548 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
11549 * assumed.
11550 * @param filter The filter tensor, rank 4, of shape
11551 * `[filterHeight, filterWidth, inChannels, channelMultiplier]`.
11552 * @param strides The strides of the convolution: `[strideHeight,
11553 * strideWidth]`. If strides is a single number, then `strideHeight ==
11554 * strideWidth`.
11555 * @param pad The type of padding algorithm.
11556 * - `same` and stride 1: output will be of same size as input,
11557 * regardless of filter size.
11558 * - `valid`: output will be smaller than input if filter is larger
11559 * than 1x1.
11560 * - For more info, see this guide:
11561 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
11562 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
11563 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
11564 * in which we sample input values across the height and width dimensions
11565 * in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single
11566 * number, then `dilationHeight == dilationWidth`. If it is greater than
11567 * 1, then all values of `strides` must be 1.
11568 * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
11569 * "NHWC". Specify the data format of the input and output data. With the
11570 * default format "NHWC", the data is stored in the order of: [batch,
11571 * height, width, channels]. Only "NHWC" is currently supported.
11572 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
11573 * provided, it will default to truncate.
11574 *
11575 * @doc {heading: 'Operations', subheading: 'Convolution'}
11576 */
11577 function depthwiseConv2d_(x, filter, strides, pad, dataFormat = 'NHWC', dilations = [1, 1], dimRoundingMode) {
11578 const $x = convertToTensor(x, 'x', 'depthwiseConv2d', 'float32');
11579 const $filter = convertToTensor(filter, 'filter', 'depthwiseConv2d', 'float32');
11580 let x4D = $x;
11581 let reshapedTo4D = false;
11582 if ($x.rank === 3) {
11583 reshapedTo4D = true;
11584 x4D = reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
11585 }
11586 assert$1(x4D.rank === 4, () => `Error in depthwiseConv2d: input must be rank 4, but got ` +
11587 `rank ${x4D.rank}.`);
11588 assert$1($filter.rank === 4, () => `Error in depthwiseConv2d: filter must be rank 4, but got rank ` +
11589 `${$filter.rank}.`);
11590 const inChannels = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
11591 assert$1(inChannels === $filter.shape[2], () => `Error in depthwiseConv2d: number of input channels ` +
11592 `(${inChannels}) must match the inChannels dimension in ` +
11593 `filter ${$filter.shape[2]}.`);
11594 checkPadOnDimRoundingMode('depthwiseConv2d', pad, dimRoundingMode);
11595 const inputs = { x: x4D, filter: $filter };
11596 const attrs = { strides, pad, dataFormat, dilations, dimRoundingMode };
11597 // tslint:disable-next-line: no-unnecessary-type-assertion
11598 const res = ENGINE.runKernel(DepthwiseConv2dNative, inputs, attrs);
11599 if (reshapedTo4D) {
11600 return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
11601 }
11602 return res;
11603 }
11604 const depthwiseConv2d$3 = /* @__PURE__ */ op({ depthwiseConv2d_ });
11605
11606 /**
11607 * @license
11608 * Copyright 2020 Google LLC. All Rights Reserved.
11609 * Licensed under the Apache License, Version 2.0 (the "License");
11610 * you may not use this file except in compliance with the License.
11611 * You may obtain a copy of the License at
11612 *
11613 * http://www.apache.org/licenses/LICENSE-2.0
11614 *
11615 * Unless required by applicable law or agreed to in writing, software
11616 * distributed under the License is distributed on an "AS IS" BASIS,
11617 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11618 * See the License for the specific language governing permissions and
11619 * limitations under the License.
11620 * =============================================================================
11621 */
11622 /**
11623 * Returns a diagonal tensor with given diagonal values.
11624 *
11625 * Given a diagonal, this operation returns a tensor with the diagonal and
11626 * everything else padded with zeros.
11627 *
11628 * Assume the input has dimensions `[D1,..., Dk]`, then the output is a tensor
11629 * of rank 2k with dimensions `[D1,..., Dk, D1,..., Dk]`
11630 *
11631 * ```js
11632 * const x = tf.tensor1d([1, 2, 3, 4]);
11633 *
11634 * tf.diag(x).print()
11635 * ```
11636 * ```js
11637 * const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [4, 2])
11638 *
11639 * tf.diag(x).print()
11640 * ```
11641 * @param x The input tensor.
11642 *
11643 * @doc {heading: 'Tensors', subheading: 'Creation'}
11644 */
11645 function diag_(x) {
11646 const $x = convertToTensor(x, 'x', 'diag');
11647 const inputs = { x: $x };
11648 return ENGINE.runKernel(Diag, inputs);
11649 }
11650 const diag$2 = /* @__PURE__ */ op({ diag_ });
11651
11652 /**
11653 * @license
11654 * Copyright 2020 Google LLC. All Rights Reserved.
11655 * Licensed under the Apache License, Version 2.0 (the "License");
11656 * you may not use this file except in compliance with the License.
11657 * You may obtain a copy of the License at
11658 *
11659 * http://www.apache.org/licenses/LICENSE-2.0
11660 *
11661 * Unless required by applicable law or agreed to in writing, software
11662 * distributed under the License is distributed on an "AS IS" BASIS,
11663 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11664 * See the License for the specific language governing permissions and
11665 * limitations under the License.
11666 * =============================================================================
11667 */
11668 /**
11669 * Computes the grayscale dilation over the input `x`.
11670 *
11671 * @param x The input tensor, rank 3 or rank 4 of shape
11672 * `[batch, height, width, depth]`. If rank 3, batch of 1 is assumed.
11673 * @param filter The filter tensor, rank 3, of shape
11674 * `[filterHeight, filterWidth, depth]`.
11675 * @param strides The strides of the sliding window for each dimension of the
11676 * input tensor: `[strideHeight, strideWidth]`.
11677 * If `strides` is a single number,
11678 * then `strideHeight == strideWidth`.
11679 * @param pad The type of padding algorithm.
11680 * - `same` and stride 1: output will be of same size as input,
11681 * regardless of filter size.
11682 * - `valid`: output will be smaller than input if filter is larger
 *     than 1x1.
11684 * - For more info, see this guide:
11685 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
11686 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
11687 * @param dataFormat Specify the data format of the input and output data.
11688 * Defaults to 'NHWC'. Only 'NHWC' is currently supported. With the
11689 * default format "NHWC", the data is stored in the order of: [batch,
11690 * height, width, channels].
11691 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
11692 * in which we sample input values across the height and width dimensions
11693 * for atrous morphological dilation. Defaults to `[1, 1]`. If `dilations`
11694 * is a single number, then `dilationHeight == dilationWidth`. If it is
11695 * greater than 1, then all values of `strides` must be 1.
11696 *
11697 * @doc {heading: 'Operations', subheading: 'Convolution'}
11698 */
11699 function dilation2d_(x, filter, strides, pad, dilations = [1, 1], dataFormat = 'NHWC') {
11700 const $x = convertToTensor(x, 'x', 'dilation2d');
11701 const $filter = convertToTensor(filter, 'filter', 'dilation2d');
11702 assert$1($x.rank === 3 || $x.rank === 4, () => `Error in dilation2d: input must be rank 3 or 4, but got rank ` +
11703 `${$x.rank}.`);
11704 assert$1($filter.rank === 3, () => `Error in dilation2d: filter must be rank 3, but got rank ` +
11705 `${$filter.rank}.`);
11706 assert$1(dataFormat === 'NHWC', () => `Error in dilation2d: Only NHWC is currently supported, ` +
11707 `but got dataFormat of ${dataFormat}`);
11708 let x4D = $x;
11709 let reshapedTo4D = false;
11710 if ($x.rank === 3) {
11711 x4D = reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
11712 reshapedTo4D = true;
11713 }
11714 assert$1(x4D.shape[3] === $filter.shape[2], () => `Error in dilation2d: input and filter must have the same depth: ${x4D.shape[3]} vs ${$filter.shape[2]}`);
11715 const inputs = { x: x4D, filter: $filter };
11716 const attrs = { strides, pad, dilations };
11717 // tslint:disable-next-line: no-unnecessary-type-assertion
11718 const res = ENGINE.runKernel(Dilation2D, inputs, attrs);
11719 if (reshapedTo4D) {
11720 return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
11721 }
11722 return res;
11723 }
11724 const dilation2d = /* @__PURE__ */ op({ dilation2d_ });
11725
11726 /**
11727 * @license
11728 * Copyright 2017 Google LLC. All Rights Reserved.
11729 * Licensed under the Apache License, Version 2.0 (the "License");
11730 * you may not use this file except in compliance with the License.
11731 * You may obtain a copy of the License at
11732 *
11733 * http://www.apache.org/licenses/LICENSE-2.0
11734 *
11735 * Unless required by applicable law or agreed to in writing, software
11736 * distributed under the License is distributed on an "AS IS" BASIS,
11737 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11738 * See the License for the specific language governing permissions and
11739 * limitations under the License.
11740 * =============================================================================
11741 */
11742 /**
11743 * Returns the dimensions in the input shape that are broadcasted to
11744 * produce the provided output shape.
11745 *
11746 * The returned dimensions are 0-indexed and sorted. An example:
11747 * inShape = [4, 1, 3]
11748 * outShape = [5, 4, 3, 3]
11749 * result = [1]. Dimension 1 (2nd dimension of input) gets broadcasted 1 => 3.
11750 */
11751 function getBroadcastDims$1(inShape, outShape) {
11752 const inRank = inShape.length;
11753 const dims = [];
11754 for (let i = 0; i < inRank; i++) {
11755 const dim = inRank - 1 - i;
11756 const a = inShape[dim] || 1;
11757 const b = outShape[outShape.length - 1 - i] || 1;
11758 if (b > 1 && a === 1) {
11759 dims.unshift(dim);
11760 }
11761 }
11762 return dims;
11763 }
11764 /**
11765 * Returns the axes in the output space that should be reduced to produce
11766 * the input space.
11767 */
11768 function getReductionAxes(inShape, outShape) {
11769 const result = [];
11770 for (let i = 0; i < outShape.length; i++) {
11771 const inDim = inShape[inShape.length - i - 1];
11772 const outAxis = outShape.length - i - 1;
11773 const outDim = outShape[outAxis];
11774 if (inDim == null || (inDim === 1 && outDim > 1)) {
11775 result.unshift(outAxis);
11776 }
11777 }
11778 return result;
11779 }
11780 function assertAndGetBroadcastShape(shapeA, shapeB) {
11781 const l = Math.max(shapeA.length, shapeB.length);
11782 const result = new Array(l);
11783 for (let i = 0; i < l; i++) {
11784 let a = shapeA[shapeA.length - i - 1];
11785 if (a == null) {
11786 a = 1;
11787 }
11788 let b = shapeB[shapeB.length - i - 1];
11789 if (b == null) {
11790 b = 1;
11791 }
11792 if (a === 1) {
11793 result[l - i - 1] = b;
11794 }
11795 else if (b === 1) {
11796 result[l - i - 1] = a;
11797 }
11798 else if (a !== b) {
11799 const errMsg = `Operands could not be broadcast together with shapes ` +
11800 `${shapeA} and ${shapeB}.`;
11801 throw Error(errMsg);
11802 }
11803 else {
11804 result[l - i - 1] = a;
11805 }
11806 }
11807 return result;
11808 }
11809
    // Frozen namespace object bundling the broadcast helpers defined above;
    // mirrors the `broadcast_util` module of the un-bundled source so callers
    // can use `broadcast_util.getBroadcastDims(...)` etc.
    var broadcast_util = /*#__PURE__*/Object.freeze({
        __proto__: null,
        assertAndGetBroadcastShape: assertAndGetBroadcastShape,
        getBroadcastDims: getBroadcastDims$1,
        getReductionAxes: getReductionAxes
    });
11816
11817 /**
11818 * @license
11819 * Copyright 2020 Google LLC. All Rights Reserved.
11820 * Licensed under the Apache License, Version 2.0 (the "License");
11821 * you may not use this file except in compliance with the License.
11822 * You may obtain a copy of the License at
11823 *
11824 * http://www.apache.org/licenses/LICENSE-2.0
11825 *
11826 * Unless required by applicable law or agreed to in writing, software
11827 * distributed under the License is distributed on an "AS IS" BASIS,
11828 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11829 * See the License for the specific language governing permissions and
11830 * limitations under the License.
11831 * =============================================================================
11832 */
11833 /**
11834 * Returns the truth value of (a == b) element-wise. Supports broadcasting.
11835 *
11836 * ```js
11837 * const a = tf.tensor1d([1, 2, 3]);
11838 * const b = tf.tensor1d([2, 2, 2]);
11839 *
11840 * a.equal(b).print();
11841 * ```
11842 *
11843 * @param a The first input tensor.
11844 * @param b The second input tensor. Must have the same dtype as `a`.
11845 *
11846 * @doc {heading: 'Operations', subheading: 'Logical'}
11847 */
11848 function equal_(a, b) {
11849 let $a = convertToTensor(a, 'a', 'equal', 'string_or_numeric');
11850 let $b = convertToTensor(b, 'b', 'equal', 'string_or_numeric');
11851 [$a, $b] = makeTypesMatch($a, $b);
11852 assertAndGetBroadcastShape($a.shape, $b.shape);
11853 const inputs = { a: $a, b: $b };
11854 return ENGINE.runKernel(Equal, inputs);
11855 }
11856 const equal$2 = /* @__PURE__ */ op({ equal_ });
11857
11858 /**
11859 * @license
11860 * Copyright 2020 Google LLC. All Rights Reserved.
11861 * Licensed under the Apache License, Version 2.0 (the "License");
11862 * you may not use this file except in compliance with the License.
11863 * You may obtain a copy of the License at
11864 *
11865 * http://www.apache.org/licenses/LICENSE-2.0
11866 *
11867 * Unless required by applicable law or agreed to in writing, software
11868 * distributed under the License is distributed on an "AS IS" BASIS,
11869 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11870 * See the License for the specific language governing permissions and
11871 * limitations under the License.
11872 * =============================================================================
11873 */
11874 /**
11875 * Returns the elements, either `a` or `b` depending on the `condition`.
11876 *
11877 * If the condition is true, select from `a`, otherwise select from `b`.
11878 *
11879 * ```js
11880 * const cond = tf.tensor1d([false, false, true], 'bool');
11881 * const a = tf.tensor1d([1 , 2, 3]);
11882 * const b = tf.tensor1d([-1, -2, -3]);
11883 *
11884 * a.where(cond, b).print();
11885 * ```
11886 *
11887 * @param condition The input condition. Must be of dtype bool.
11888 * @param a If `condition` is rank 1, `a` may have a higher rank but
11889 * its first dimension must match the size of `condition`.
11890 * @param b A tensor with the same dtype as `a` and with shape that is
11891 * compatible with `a`.
11892 * @return A tensor with same dtype as `a` and `b`, and shape that is
11893 * broadcastable from `a` and `b`.
11894 *
11895 * @doc {heading: 'Operations', subheading: 'Logical'}
11896 */
11897 function where_(condition, a, b) {
11898 const $a = convertToTensor(a, 'a', 'where');
11899 const $b = convertToTensor(b, 'b', 'where');
11900 const $condition = convertToTensor(condition, 'condition', 'where', 'bool');
11901 // TODO: move this logic to forward function when the broadcastTo op is
11902 // implemented in WASM.
11903 // Find the broadcastable shape for $condition, $a, and $b.
11904 const broadcastShape = assertAndGetBroadcastShape(assertAndGetBroadcastShape($condition.shape, $a.shape), $b.shape);
11905 const $broadcastedCondition = broadcastTo($condition, broadcastShape);
11906 const $broadcastedA = broadcastTo($a, broadcastShape);
11907 const $broadcastedB = broadcastTo($b, broadcastShape);
11908 const inputs = {
11909 condition: $broadcastedCondition,
11910 t: $broadcastedA,
11911 e: $broadcastedB
11912 };
11913 return ENGINE.runKernel(Select, inputs);
11914 }
11915 const where = /* @__PURE__ */ op({ where_ });
11916
11917 /**
11918 * @license
11919 * Copyright 2018 Google LLC. All Rights Reserved.
11920 * Licensed under the Apache License, Version 2.0 (the "License");
11921 * you may not use this file except in compliance with the License.
11922 * You may obtain a copy of the License at
11923 *
11924 * http://www.apache.org/licenses/LICENSE-2.0
11925 *
11926 * Unless required by applicable law or agreed to in writing, software
11927 * distributed under the License is distributed on an "AS IS" BASIS,
11928 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11929 * See the License for the specific language governing permissions and
11930 * limitations under the License.
11931 * =============================================================================
11932 */
11933 /**
11934 * Creates a `tf.Tensor` with all elements set to 0 with the same shape as the
11935 * given tensor.
11936 *
11937 * ```js
11938 * const x = tf.tensor([1, 2]);
11939 * tf.zerosLike(x).print();
11940 * ```
11941 *
11942 * @param x The tensor of required shape.
11943 *
11944 * @doc {heading: 'Tensors', subheading: 'Creation'}
11945 */
11946 function zerosLike_(x) {
11947 const $x = convertToTensor(x, 'x', 'zerosLike');
11948 const inputs = { x: $x };
11949 return ENGINE.runKernel(ZerosLike, inputs);
11950 }
11951 const zerosLike$3 = /* @__PURE__ */ op({ zerosLike_ });
11952
11953 /**
11954 * @license
11955 * Copyright 2020 Google LLC. All Rights Reserved.
11956 * Licensed under the Apache License, Version 2.0 (the "License");
11957 * you may not use this file except in compliance with the License.
11958 * You may obtain a copy of the License at
11959 *
11960 * http://www.apache.org/licenses/LICENSE-2.0
11961 *
11962 * Unless required by applicable law or agreed to in writing, software
11963 * distributed under the License is distributed on an "AS IS" BASIS,
11964 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11965 * See the License for the specific language governing permissions and
11966 * limitations under the License.
11967 * =============================================================================
11968 */
11969 /**
11970 * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting. Return 0
11971 * if denominator is 0.
11972 *
11973 *
11974 * ```js
11975 * const a = tf.tensor1d([1, 4, 9, 16]);
11976 * const b = tf.tensor1d([1, 2, 3, 4]);
11977 * const c = tf.tensor1d([0, 0, 0, 0]);
11978 *
11979 * a.divNoNan(b).print(); // or tf.divNoNan(a, b)
11980 * a.divNoNan(c).print(); // or tf.divNoNan(a, c)
11981 * ```
11982 *
11983 * ```js
11984 * // Broadcast div a with b.
11985 * const a = tf.tensor1d([2, 4, 6, 8]);
11986 * const b = tf.scalar(2);
11987 * const c = tf.scalar(0);
11988 *
11989 * a.divNoNan(b).print(); // or tf.divNoNan(a, b)
11990 * a.divNoNan(c).print(); // or tf.divNoNan(a, c)
11991 * ```
11992 *
11993 * @param a The first tensor as the numerator.
11994 * @param b The second tensor as the denominator. Must have the same dtype as
11995 * `a`.
11996 *
11997 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
11998 */
11999 function divNoNan_(a, b) {
12000 // TODO: Make this into its own kernel.
12001 let $a = convertToTensor(a, 'a', 'div');
12002 let $b = convertToTensor(b, 'b', 'div');
12003 [$a, $b] = makeTypesMatch($a, $b);
12004 const divResult = div$1($a, $b);
12005 const zeros = zerosLike$3(divResult);
12006 const bEqualsZero = equal$2($b, zeros);
12007 return where(bEqualsZero, zeros, divResult);
12008 }
12009 const divNoNan = /* @__PURE__ */ op({ divNoNan_ });
12010
12011 /**
12012 * @license
12013 * Copyright 2020 Google LLC. All Rights Reserved.
12014 * Licensed under the Apache License, Version 2.0 (the "License");
12015 * you may not use this file except in compliance with the License.
12016 * You may obtain a copy of the License at
12017 *
12018 * http://www.apache.org/licenses/LICENSE-2.0
12019 *
12020 * Unless required by applicable law or agreed to in writing, software
12021 * distributed under the License is distributed on an "AS IS" BASIS,
12022 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12023 * See the License for the specific language governing permissions and
12024 * limitations under the License.
12025 * =============================================================================
12026 */
12027 /**
12028 * Computes the dot product of two matrices and/or vectors, `t1` and `t2`.
12029 *
12030 * ```js
12031 * const a = tf.tensor1d([1, 2]);
12032 * const b = tf.tensor2d([[1, 2], [3, 4]]);
12033 * const c = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
12034 *
12035 * a.dot(b).print(); // or tf.dot(a, b)
12036 * b.dot(a).print();
12037 * b.dot(c).print();
12038 * ```
12039 * @param t1 The first tensor in the dot operation.
12040 * @param t2 The second tensor in the dot operation.
12041 *
12042 * @doc {heading: 'Operations', subheading: 'Matrices'}
12043 */
12044 function dot_(t1, t2) {
12045 const $t1 = convertToTensor(t1, 't1', 'dot');
12046 const $t2 = convertToTensor(t2, 't2', 'dot');
12047 assert$1(($t1.rank === 1 || $t1.rank === 2) && ($t2.rank === 1 || $t2.rank === 2), () => `Error in dot: inputs must all be rank 1 or 2, but got ranks ` +
12048 `${$t1.rank} and ${$t2.rank}.`);
12049 const t1Inner = ($t1.rank === 1 ? $t1.size : $t1.shape[1]);
12050 const t2Inner = ($t2.rank === 1 ? $t2.size : $t2.shape[0]);
12051 assert$1(t1Inner === t2Inner, () => `Error in dot: inner dimensions of inputs must match, but got ` +
12052 `${t1Inner} and ${t2Inner}.`);
12053 if ($t1.rank === 1 && $t2.rank === 1) {
12054 const t12D = reshape$3($t1, [1, -1]);
12055 const t22D = reshape$3($t2, [-1, 1]);
12056 const t1t2 = matMul$1(t12D, t22D);
12057 return reshape$3(t1t2, []);
12058 }
12059 else if ($t1.rank === 1 && $t2.rank === 2) {
12060 const t12D = reshape$3($t1, [1, -1]);
12061 const t22D = reshape$3($t2, [$t2.shape[0], $t2.shape[1]]);
12062 const t1t2 = matMul$1(t12D, t22D);
12063 return reshape$3(t1t2, [t1t2.size]);
12064 }
12065 else if ($t1.rank === 2 && $t2.rank === 1) {
12066 const t22D = reshape$3($t2, [-1, 1]);
12067 const t1t2 = matMul$1($t1, t22D);
12068 return reshape$3(t1t2, [t1t2.size]);
12069 }
12070 else {
12071 const t22D = reshape$3($t2, [$t2.shape[0], $t2.shape[1]]);
12072 const t1t2 = matMul$1($t1, t22D);
12073 return t1t2;
12074 }
12075 }
12076 const dot$2 = /* @__PURE__ */ op({ dot_ });
12077
12078 /**
12079 * @license
12080 * Copyright 2021 Google LLC. All Rights Reserved.
12081 * Licensed under the Apache License, Version 2.0 (the "License");
12082 * you may not use this file except in compliance with the License.
12083 * You may obtain a copy of the License at
12084 *
12085 * http://www.apache.org/licenses/LICENSE-2.0
12086 *
12087 * Unless required by applicable law or agreed to in writing, software
12088 * distributed under the License is distributed on an "AS IS" BASIS,
12089 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12090 * See the License for the specific language governing permissions and
12091 * limitations under the License.
12092 * =============================================================================
12093 */
12094 /**
12095 * Tensor contraction over specified indices and outer product.
12096 *
12097 * `einsum` allows defining Tensors by defining their element-wise computation.
12098 * This computation is based on
12099 * [Einstein summation](https://en.wikipedia.org/wiki/Einstein_notation).
12100 *
12101 * Some special cases include:
12102 *
12103 * Matrix multiplication:
12104 * ```js
12105 * const x = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
12106 * const y = tf.tensor2d([[0, 1], [2, 3], [4, 5]]);
12107 * x.print();
12108 * y.print();
12109 * tf.einsum('ij,jk->ik', x, y).print();
12110 * ```
12111 *
12112 * Dot product:
12113 * ```js
12114 * const x = tf.tensor1d([1, 2, 3]);
12115 * const y = tf.tensor1d([0, 1, 2]);
12116 * x.print();
12117 * y.print();
12118 * tf.einsum('i,i->', x, y).print();
12119 * ```
12120 *
12121 * Batch dot product:
12122 * ```js
12123 * const x = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
12124 * const y = tf.tensor2d([[0, 1, 2], [3, 4, 5]]);
12125 * x.print();
12126 * y.print();
12127 * tf.einsum('bi,bi->b', x, y).print();
12128 * ```
12129 *
12130 * Outer prouduct:
12131 * ```js
12132 * const x = tf.tensor1d([1, 3, 5]);
12133 * const y = tf.tensor1d([2, 4, 6]);
12134 * x.print();
12135 * y.print();
12136 * tf.einsum('i,j->ij', x, y).print();
12137 * ```
12138 *
12139 * Matrix transpose:
12140 * ```js
12141 * const x = tf.tensor2d([[1, 2], [3, 4]]);
12142 * x.print();
12143 * tf.einsum('ij->ji', x).print();
12144 * ```
12145 *
12146 * Batch matrix transpose:
12147 * ```js
12148 * const x = tf.tensor3d([[[1, 2], [3, 4]], [[-1, -2], [-3, -4]]]);
12149 * x.print();
12150 * tf.einsum('bij->bji', x).print();
12151 * ```
12152 *
12153 * Limitations:
12154 *
12155 * This implementation of einsum has the following limitations:
12156 *
12157 * - Does not support >2 input tensors.
12158 * - Does not support duplicate axes for any given input tensor. E.g., equation
12159 * 'ii->' is not supported.
12160 * - The `...` notation is not supported.
12161 *
12162 * @param equation a string describing the contraction, in the same format as
12163 * [numpy.einsum](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html).
12164 * @param tensors the input(s) to contract (each one a Tensor), whose shapes
12165 * should be consistent with equation.
12166 * @returns The output tensor.
12167 *
12168 * @doc {heading: 'Tensors', subheading: 'Matrices'}
12169 */
12170 function einsum_(equation, ...tensors) {
12171 const $tensors = tensors.map((t, i) => convertToTensor(t, `tensors${i}`, 'einsum'));
12172 const attrs = { equation };
12173 return ENGINE.runKernel(Einsum, $tensors, attrs);
12174 }
12175 const einsum$2 = /* @__PURE__ */ op({ einsum_ });
12176
12177 /**
12178 * @license
12179 * Copyright 2020 Google LLC. All Rights Reserved.
12180 * Licensed under the Apache License, Version 2.0 (the "License");
12181 * you may not use this file except in compliance with the License.
12182 * You may obtain a copy of the License at
12183 *
12184 * http://www.apache.org/licenses/LICENSE-2.0
12185 *
12186 * Unless required by applicable law or agreed to in writing, software
12187 * distributed under the License is distributed on an "AS IS" BASIS,
12188 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12189 * See the License for the specific language governing permissions and
12190 * limitations under the License.
12191 * =============================================================================
12192 */
12193 /**
12194 * Computes exponential linear element-wise: `x > 0 ? x : (e ^ x) - 1`.
12195 *
12196 * ```js
12197 * const x = tf.tensor1d([-1, 1, -3, 2]);
12198 *
12199 * x.elu().print(); // or tf.elu(x)
12200 * ```
12201 * @param x The input tensor.
12202 *
12203 * @doc {heading: 'Operations', subheading: 'Basic math'}
12204 */
12205 function elu_(x) {
12206 const $x = convertToTensor(x, 'x', 'elu', 'float32');
12207 const inputs = { x: $x };
12208 return ENGINE.runKernel(Elu$1, inputs);
12209 }
12210 const elu$4 = /* @__PURE__ */ op({ elu_ });
12211
12212 /**
12213 * @license
12214 * Copyright 2023 Google LLC.
12215 * Licensed under the Apache License, Version 2.0 (the "License");
12216 * you may not use this file except in compliance with the License.
12217 * You may obtain a copy of the License at
12218 *
12219 * http://www.apache.org/licenses/LICENSE-2.0
12220 *
12221 * Unless required by applicable law or agreed to in writing, software
12222 * distributed under the License is distributed on an "AS IS" BASIS,
12223 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12224 * See the License for the specific language governing permissions and
12225 * limitations under the License.
12226 * =============================================================================
12227 */
12228 /**
 * Checks that the input tensor matches the given shape.
12230 *
12231 * Given an input tensor, returns a new tensor with the same values as the
12232 * input tensor with shape `shape`.
12233 *
12234 * The method supports the null value in tensor. It will still check the shapes,
12235 * and null is a placeholder.
12236 *
12237 *
12238 * ```js
12239 * const x = tf.tensor1d([1, 2, 3, 4]);
12240 * const y = tf.tensor1d([1, null, 3, 4]);
12241 * const z = tf.tensor2d([1, 2, 3, 4], [2,2]);
12242 * tf.ensureShape(x, [4]).print();
12243 * tf.ensureShape(y, [4]).print();
12244 * tf.ensureShape(z, [null, 2]).print();
12245 * ```
12246 *
12247 * @param x The input tensor to be ensured.
12248 * @param shape A TensorShape representing the shape of this tensor, an array
12249 * or null.
12250 *
12251 * @doc {heading: 'Tensors', subheading: 'Transformations'}
12252 */
12253 function ensureShape_(x, shape) {
12254 const $x = convertToTensor(x, 'x', 'ensureShape', 'string_or_numeric');
12255 if (!arraysEqualWithNull($x.shape, shape)) {
12256 throw new Error(`EnsureShape: Shape of tensor ${$x.shape} is not compatible with expected shape ${shape}`);
12257 }
12258 return x;
12259 }
12260 const ensureShape = /* @__PURE__ */ op({ ensureShape_ });
12261
12262 /**
12263 * @license
12264 * Copyright 2018 Google LLC. All Rights Reserved.
12265 * Licensed under the Apache License, Version 2.0 (the "License");
12266 * you may not use this file except in compliance with the License.
12267 * You may obtain a copy of the License at
12268 *
12269 * http://www.apache.org/licenses/LICENSE-2.0
12270 *
12271 * Unless required by applicable law or agreed to in writing, software
12272 * distributed under the License is distributed on an "AS IS" BASIS,
12273 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12274 * See the License for the specific language governing permissions and
12275 * limitations under the License.
12276 * =============================================================================
12277 */
12278 /**
12279 * Computes Gauss error function of the input `tf.Tensor` element-wise:
12280 * `erf(x)`
12281 *
12282 * ```js
12283 * const x = tf.tensor1d([0, .1, -.1, .7]);
12284 *
12285 * x.erf().print(); // or tf.erf(x);
12286 * ```
12287 * @param x The input tensor.
12288 *
12289 * @doc {heading: 'Operations', subheading: 'Basic math'}
12290 */
12291 function erf_(x) {
12292 let $x = convertToTensor(x, 'x', 'erf');
12293 assert$1($x.dtype === 'int32' || $x.dtype === 'float32', () => 'Input dtype must be `int32` or `float32`.');
12294 if ($x.dtype === 'int32') {
12295 $x = cast$3($x, 'float32');
12296 }
12297 const inputs = { x: $x };
12298 return ENGINE.runKernel(Erf, inputs);
12299 }
12300 const erf$2 = /* @__PURE__ */ op({ erf_ });
12301
12302 /**
12303 * @license
12304 * Copyright 2017 Google LLC. All Rights Reserved.
12305 * Licensed under the Apache License, Version 2.0 (the "License");
12306 * you may not use this file except in compliance with the License.
12307 * You may obtain a copy of the License at
12308 *
12309 * http://www.apache.org/licenses/LICENSE-2.0
12310 *
12311 * Unless required by applicable law or agreed to in writing, software
12312 * distributed under the License is distributed on an "AS IS" BASIS,
12313 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12314 * See the License for the specific language governing permissions and
12315 * limitations under the License.
12316 * =============================================================================
12317 */
12318 /**
12319 * Returns true if the axis specifies the inner most dimensions of the
12320 * array.
12321 */
12322 function axesAreInnerMostDims(axes, rank) {
12323 for (let i = 0; i < axes.length; ++i) {
12324 if (axes[axes.length - i - 1] !== rank - 1 - i) {
12325 return false;
12326 }
12327 }
12328 return true;
12329 }
12330 function combineLocations(outputLoc, reduceLoc, axes) {
12331 const rank = outputLoc.length + reduceLoc.length;
12332 const loc = [];
12333 let outIdx = 0;
12334 let reduceIdx = 0;
12335 for (let dim = 0; dim < rank; dim++) {
12336 if (axes.indexOf(dim) === -1) {
12337 loc.push(outputLoc[outIdx++]);
12338 }
12339 else {
12340 loc.push(reduceLoc[reduceIdx++]);
12341 }
12342 }
12343 return loc;
12344 }
12345 function computeOutAndReduceShapes(aShape, axes) {
12346 const outShape = [];
12347 const rank = aShape.length;
12348 for (let dim = 0; dim < rank; dim++) {
12349 if (axes.indexOf(dim) === -1) {
12350 outShape.push(aShape[dim]);
12351 }
12352 }
12353 const reduceShape = axes.map(dim => aShape[dim]);
12354 return [outShape, reduceShape];
12355 }
12356 function expandShapeToKeepDim(shape, axes) {
12357 const reduceSubShape = axes.map(x => 1);
12358 return combineLocations(shape, reduceSubShape, axes);
12359 }
12360 function assertAxesAreInnerMostDims(msg, axes, rank) {
12361 assert$1(axesAreInnerMostDims(axes, rank), () => `${msg} supports only inner-most axes for now. ` +
12362 `Got axes ${axes} and rank-${rank} input.`);
12363 }
12364 /**
12365 * Returns the axes permutation to be used with `tf.transpose`, if such
12366 * permutation is necessary. Otherwise it returns null. This method is used by
12367 * operations that operate only on inner-most axes.
12368 */
12369 function getAxesPermutation(axes, rank) {
12370 if (axesAreInnerMostDims(axes, rank)) {
12371 return null;
12372 }
12373 const result = [];
12374 for (let i = 0; i < rank; ++i) {
12375 if (axes.indexOf(i) === -1) {
12376 result.push(i);
12377 }
12378 }
12379 axes.forEach(axis => result.push(axis));
12380 return result;
12381 }
12382 /** Returns the axes permutation that undoes the original permutation. */
12383 function getUndoAxesPermutation(axes) {
12384 return axes.map((axis, i) => [i, axis])
12385 .sort((a, b) => a[1] - b[1])
12386 .map(x => x[0]);
12387 }
12388 function getInnerMostAxes(numAxes, rank) {
12389 const res = [];
12390 for (let i = rank - numAxes; i < rank; ++i) {
12391 res.push(i);
12392 }
12393 return res;
12394 }
12395
12396 /**
12397 * @license
12398 * Copyright 2020 Google LLC. All Rights Reserved.
12399 * Licensed under the Apache License, Version 2.0 (the "License");
12400 * you may not use this file except in compliance with the License.
12401 * You may obtain a copy of the License at
12402 *
12403 * http://www.apache.org/licenses/LICENSE-2.0
12404 *
12405 * Unless required by applicable law or agreed to in writing, software
12406 * distributed under the License is distributed on an "AS IS" BASIS,
12407 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12408 * See the License for the specific language governing permissions and
12409 * limitations under the License.
12410 * =============================================================================
12411 */
12412 /**
12413 * Computes the maximum of elements across dimensions of a `tf.Tensor`.
12414 *
12415 * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
12416 * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
12417 * `axes`. If `keepDims` is true, the reduced dimensions are retained with
12418 * length 1. If `axes` has no entries, all dimensions are reduced, and a
12419 * `tf.Tensor` with a single element is returned.
12420 *
12421 * ```js
12422 * const x = tf.tensor1d([1, 2, 3]);
12423 *
12424 * x.max().print(); // or tf.max(x)
12425 * ```
12426 *
12427 * ```js
12428 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
12429 *
12430 * const axis = 1;
12431 * x.max(axis).print(); // or tf.max(x, axis)
12432 * ```
12433 *
12434 * @param x The input tensor.
12435 * @param axis The dimension(s) to reduce. By default it reduces
12436 * all dimensions.
12437 * @param keepDims If true, retains reduced dimensions with size 1.
12438 *
12439 * @doc {heading: 'Operations', subheading: 'Reduction'}
12440 */
12441 function max_(x, axis = null, keepDims = false) {
12442 const $x = convertToTensor(x, 'x', 'max');
12443 const inputs = { x: $x };
12444 const attrs = { reductionIndices: axis, keepDims };
12445 return ENGINE.runKernel(Max, inputs, attrs);
12446 }
12447 const max$3 = /* @__PURE__ */ op({ max_ });
12448
12449 /**
12450 * @license
12451 * Copyright 2020 Google Inc. All Rights Reserved.
12452 * Licensed under the Apache License, Version 2.0 (the "License");
12453 * you may not use this file except in compliance with the License.
12454 * You may obtain a copy of the License at
12455 *
12456 * http://www.apache.org/licenses/LICENSE-2.0
12457 *
12458 * Unless required by applicable law or agreed to in writing, software
12459 * distributed under the License is distributed on an "AS IS" BASIS,
12460 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12461 * See the License for the specific language governing permissions and
12462 * limitations under the License.
12463 * =============================================================================
12464 */
12465 /**
12466 * Computes the minimum value from the input.
12467 *
12468 * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
12469 * is true, the rank of the array is reduced by 1 for each entry in `axes`.
12470 * If `keepDims` is true, the reduced dimensions are retained with length 1.
12471 * If `axes` has no entries, all dimensions are reduced, and an array with a
12472 * single element is returned.
12473 *
12474 * ```js
12475 * const x = tf.tensor1d([1, 2, 3]);
12476 *
12477 * x.min().print(); // or tf.min(x)
12478 * ```
12479 *
12480 * ```js
12481 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
12482 *
12483 * const axis = 1;
12484 * x.min(axis).print(); // or tf.min(x, axis)
12485 * ```
12486 *
12487 * @param x The input Tensor.
12488 * @param axis The dimension(s) to reduce. By default it reduces
12489 * all dimensions.
12490 * @param keepDims If true, retains reduced dimensions with size 1.
12491 *
12492 * @doc {heading: 'Operations', subheading: 'Reduction'}
12493 */
12494 function min_(x, axis = null, keepDims = false) {
12495 const $x = convertToTensor(x, 'x', 'min');
12496 const inputs = { x: $x };
12497 const attrs = { axis, keepDims };
12498 // tslint:disable-next-line: no-unnecessary-type-assertion
12499 return ENGINE.runKernel(Min, inputs, attrs);
12500 }
12501 const min$3 = /* @__PURE__ */ op({ min_ });
12502
12503 /**
12504 * @license
12505 * Copyright 2020 Google LLC. All Rights Reserved.
12506 * Licensed under the Apache License, Version 2.0 (the "License");
12507 * you may not use this file except in compliance with the License.
12508 * You may obtain a copy of the License at
12509 *
12510 * http://www.apache.org/licenses/LICENSE-2.0
12511 *
12512 * Unless required by applicable law or agreed to in writing, software
12513 * distributed under the License is distributed on an "AS IS" BASIS,
12514 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12515 * See the License for the specific language governing permissions and
12516 * limitations under the License.
12517 * =============================================================================
12518 */
12519 /**
12520 * Computes the power of one `tf.Tensor` to another. Supports broadcasting.
12521 *
12522 * Given a `tf.Tensor` x and a `tf.Tensor` y, this operation computes x^y for
12523 * corresponding elements in x and y. The result's dtype will be the upcasted
12524 * type of the `base` and `exp` dtypes.
12525 *
12526 * ```js
12527 * const a = tf.tensor([[2, 3], [4, 5]])
12528 * const b = tf.tensor([[1, 2], [3, 0]]).toInt();
12529 *
12530 * a.pow(b).print(); // or tf.pow(a, b)
12531 * ```
12532 *
12533 * ```js
12534 * const a = tf.tensor([[1, 2], [3, 4]])
12535 * const b = tf.tensor(2).toInt();
12536 *
12537 * a.pow(b).print(); // or tf.pow(a, b)
12538 * ```
12539 * We also expose `powStrict` which has the same signature as this op and
12540 * asserts that `base` and `exp` are the same shape (does not broadcast).
12541 *
12542 * @param base The base `tf.Tensor` to pow element-wise.
12543 * @param exp The exponent `tf.Tensor` to pow element-wise.
12544 *
12545 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
12546 */
12547 function pow_(base, exp) {
12548 let $base = convertToTensor(base, 'base', 'pow');
12549 let $exp = convertToTensor(exp, 'exp', 'pow');
12550 [$base, $exp] = makeTypesMatch($base, $exp);
12551 const inputs = { a: $base, b: $exp };
12552 return ENGINE.runKernel(Pow, inputs);
12553 }
12554 const pow$3 = /* @__PURE__ */ op({ pow_ });
12555
12556 /**
12557 * @license
12558 * Copyright 2018 Google LLC. All Rights Reserved.
12559 * Licensed under the Apache License, Version 2.0 (the "License");
12560 * you may not use this file except in compliance with the License.
12561 * You may obtain a copy of the License at
12562 *
12563 * http://www.apache.org/licenses/LICENSE-2.0
12564 *
12565 * Unless required by applicable law or agreed to in writing, software
12566 * distributed under the License is distributed on an "AS IS" BASIS,
12567 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12568 * See the License for the specific language governing permissions and
12569 * limitations under the License.
12570 * =============================================================================
12571 */
12572 /**
12573 * Creates rank-0 `tf.Tensor` (scalar) with the provided value and dtype.
12574 *
12575 * The same functionality can be achieved with `tf.tensor`, but in general
12576 * we recommend using `tf.scalar` as it makes the code more readable.
12577 *
12578 * ```js
12579 * tf.scalar(3.14).print();
12580 * ```
12581 *
12582 * @param value The value of the scalar.
12583 * @param dtype The data type.
12584 *
12585 * @doc {heading: 'Tensors', subheading: 'Creation'}
12586 */
/**
 * Builds a rank-0 `tf.Tensor` (scalar) from a primitive value.
 *
 * @param value The scalar value (number|boolean|string, or a Uint8Array
 *     holding an encoded string, or array-likes only for complex64).
 * @param dtype Optional data type.
 * @returns A rank-0 tensor wrapping `value`.
 */
function scalar(value, dtype) {
    const isArrayLike = (isTypedArray(value) && dtype !== 'string') || Array.isArray(value);
    // Array-like values are only legal for complex64 (real/imag pair).
    if (isArrayLike && dtype !== 'complex64') {
        throw new Error('Error creating a new Scalar: value must be a primitive ' +
            '(number|boolean|string)');
    }
    // Encoded strings must arrive as raw bytes.
    if (dtype === 'string' && isTypedArray(value) &&
        !(value instanceof Uint8Array)) {
        throw new Error('When making a scalar from encoded string, ' +
            'the value must be `Uint8Array`.');
    }
    // A scalar has the empty shape; nothing is inferred from the value.
    return makeTensor(value, [], [], dtype);
}
12602
12603 /**
12604 * @license
12605 * Copyright 2018 Google LLC. All Rights Reserved.
12606 * Licensed under the Apache License, Version 2.0 (the "License");
12607 * you may not use this file except in compliance with the License.
12608 * You may obtain a copy of the License at
12609 *
12610 * http://www.apache.org/licenses/LICENSE-2.0
12611 *
12612 * Unless required by applicable law or agreed to in writing, software
12613 * distributed under the License is distributed on an "AS IS" BASIS,
12614 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12615 * See the License for the specific language governing permissions and
12616 * limitations under the License.
12617 * =============================================================================
12618 */
12619 /**
12620 * Computes square root of the input `tf.Tensor` element-wise: `y = sqrt(x)`
12621 *
12622 * ```js
12623 * const x = tf.tensor1d([1, 2, 4, -1]);
12624 *
12625 * x.sqrt().print(); // or tf.sqrt(x)
12626 * ```
12627 * @param x The input tensor.
12628 *
12629 * @doc {heading: 'Operations', subheading: 'Basic math'}
12630 */
/**
 * Element-wise square root: `y = sqrt(x)`.
 *
 * @param x Input tensor (coerced to float32).
 * @returns Tensor of element-wise square roots, via the Sqrt kernel.
 */
function sqrt_(x) {
    const input = convertToTensor(x, 'x', 'sqrt', 'float32');
    return ENGINE.runKernel(Sqrt, { x: input });
}
const sqrt$2 = /* @__PURE__ */ op({ sqrt_ });
12637
12638 /**
12639 * @license
12640 * Copyright 2019 Google LLC. All Rights Reserved.
12641 * Licensed under the Apache License, Version 2.0 (the "License");
12642 * you may not use this file except in compliance with the License.
12643 * You may obtain a copy of the License at
12644 *
12645 * http://www.apache.org/licenses/LICENSE-2.0
12646 *
12647 * Unless required by applicable law or agreed to in writing, software
12648 * distributed under the License is distributed on an "AS IS" BASIS,
12649 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12650 * See the License for the specific language governing permissions and
12651 * limitations under the License.
12652 * =============================================================================
12653 */
12654 /**
12655 * Computes square of `x` element-wise: `x ^ 2`
12656 *
12657 * ```js
12658 * const x = tf.tensor1d([1, 2, Math.sqrt(2), -1]);
12659 *
12660 * x.square().print(); // or tf.square(x)
12661 * ```
12662 * @param x The input Tensor.
12663 *
12664 * @doc {heading: 'Operations', subheading: 'Basic math'}
12665 */
/**
 * Element-wise square: `y = x ^ 2`.
 *
 * @param x The input tensor.
 * @returns Tensor of element-wise squares.
 */
function square_(x) {
    const input = convertToTensor(x, 'x', 'square');
    // NOTE: this op dispatches by kernel-name string rather than an imported
    // kernel symbol; no attributes are needed.
    return ENGINE.runKernel('Square', { x: input }, {});
}
const square$2 = /* @__PURE__ */ op({ square_ });
12672
12673 /**
12674 * @license
12675 * Copyright 2018 Google LLC. All Rights Reserved.
12676 * Licensed under the Apache License, Version 2.0 (the "License");
12677 * you may not use this file except in compliance with the License.
12678 * You may obtain a copy of the License at
12679 *
12680 * http://www.apache.org/licenses/LICENSE-2.0
12681 *
12682 * Unless required by applicable law or agreed to in writing, software
12683 * distributed under the License is distributed on an "AS IS" BASIS,
12684 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12685 * See the License for the specific language governing permissions and
12686 * limitations under the License.
12687 * =============================================================================
12688 */
12689 /**
12690 * Computes the sum of elements across dimensions of a `tf.Tensor`.
12691 *
12692 * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
12693 * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
12694 * `axes`. If `keepDims` is true, the reduced dimensions are retained with
12695 * length 1. If axes has no entries, all dimensions are reduced, and a
12696 * `tf.Tensor` with a single element is returned.
12697 *
12698 * ```js
12699 * const x = tf.tensor1d([1, 2, 3]);
12700 *
12701 * x.sum().print(); // or tf.sum(x)
12702 * ```
12703 *
12704 * ```js
12705 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
12706 *
12707 * const axis = 1;
12708 * x.sum(axis).print(); // or tf.sum(x, axis)
12709 * ```
12710 *
12711 * @param x The input tensor to compute the sum over. If the dtype is `bool`
12712 * it will be converted to `int32` and the output dtype will be `int32`.
12713 * @param axis The dimension(s) to reduce. By default it reduces
12714 * all dimensions.
12715 * @param keepDims If true, retains reduced dimensions with size 1.
12716 *
12717 * @doc {heading: 'Operations', subheading: 'Reduction'}
12718 */
/**
 * Sums elements of a tensor across the given dimensions.
 *
 * @param x Input tensor; `bool` inputs are upcast to `int32` first.
 * @param axis Dimension(s) to reduce; `null` reduces all dimensions.
 * @param keepDims If true, reduced dimensions are kept with length 1.
 * @returns The reduced tensor, via the Sum kernel.
 */
function sum_(x, axis = null, keepDims = false) {
    let input = convertToTensor(x, 'x', 'sum');
    // Booleans cannot be summed directly; promote to int32.
    if (input.dtype === 'bool') {
        input = cast$3(input, 'int32');
    }
    return ENGINE.runKernel(Sum, { x: input }, { axis, keepDims });
}
const sum$3 = /* @__PURE__ */ op({ sum_ });
12729
12730 /**
12731 * @license
12732 * Copyright 2018 Google LLC. All Rights Reserved.
12733 * Licensed under the Apache License, Version 2.0 (the "License");
12734 * you may not use this file except in compliance with the License.
12735 * You may obtain a copy of the License at
12736 *
12737 * http://www.apache.org/licenses/LICENSE-2.0
12738 *
12739 * Unless required by applicable law or agreed to in writing, software
12740 * distributed under the License is distributed on an "AS IS" BASIS,
12741 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12742 * See the License for the specific language governing permissions and
12743 * limitations under the License.
12744 * =============================================================================
12745 */
12746 /**
12747 * Computes the norm of scalar, vectors, and matrices.
12748 * This function can compute several different vector norms (the 1-norm, the
12749 * Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0)
12750 * and matrix norms (Frobenius, 1-norm, and inf-norm).
12751 *
12752 * ```js
12753 * const x = tf.tensor1d([1, 2, 3, 4]);
12754 *
12755 * x.norm().print(); // or tf.norm(x)
12756 * ```
12757 *
12758 * @param x The input array.
12759 * @param ord Optional. Order of the norm. Supported norm types are
12760 * following:
12761 *
12762 * | ord | norm for matrices | norm for vectors
12763 * |------------|---------------------------|---------------------
12764 * |'euclidean' |Frobenius norm |2-norm
12765 * |'fro' |Frobenius norm |
12766 * |Infinity |max(sum(abs(x), axis=1)) |max(abs(x))
12767 * |-Infinity |min(sum(abs(x), axis=1)) |min(abs(x))
12768 * |1 |max(sum(abs(x), axis=0)) |sum(abs(x))
12769 * |2 | |sum(abs(x)^2)^(1/2)
12770 *
12771 * @param axis Optional. If axis is null (the default), the input is
12772 * considered a vector and a single vector norm is computed over the entire
12773 * set of values in the Tensor, i.e. norm(x, ord) is equivalent
12774 * to norm(x.reshape([-1]), ord). If axis is an integer, the input
12775 * is considered a batch of vectors, and axis determines the axis in x
12776 * over which to compute vector norms. If axis is a 2-tuple of integer it is
12777 * considered a batch of matrices and axis determines the axes in NDArray
12778 * over which to compute a matrix norm.
12779 * @param keepDims Optional. If true, the norm has the same dimensionality
12780 * as the input.
12781 *
12782 * @doc {heading: 'Operations', subheading: 'Matrices'}
12783 */
// Entry point: computes the requested norm, then (optionally) restores the
// reduced axes as length-1 dimensions so the output keeps the input's rank.
function norm_(x, ord = 'euclidean', axis = null, keepDims = false) {
    x = convertToTensor(x, 'x', 'norm');
    const norm = normImpl(x, ord, axis);
    let keepDimsShape = norm.shape;
    if (keepDims) {
        // Re-insert each reduced axis with size 1 into the result shape.
        const axes = parseAxisParam(axis, x.shape);
        keepDimsShape = expandShapeToKeepDim(norm.shape, axes);
    }
    return reshape$3(norm, keepDimsShape);
}
// Dispatches between scalar, vector, and matrix norms based on rank and axis.
// `p` is the ord value ('euclidean', 'fro', 1, 2, Infinity, -Infinity).
function normImpl(x, p, axis = null) {
    // Scalar: any norm of a scalar is its absolute value.
    if (x.rank === 0) {
        return abs$2(x);
    }
    // consider vector when no axis is specified
    if (x.rank !== 1 && axis === null) {
        // Flatten to 1D and recurse so the vector branch below applies.
        return normImpl(reshape$3(x, [-1]), p, axis);
    }
    // vector
    if (x.rank === 1 || typeof axis === 'number' ||
        Array.isArray(axis) && axis.length === 1) {
        if (p === 1) {
            return sum$3(abs$2(x), axis);
        }
        if (p === Infinity) {
            return max$3(abs$2(x), axis);
        }
        if (p === -Infinity) {
            return min$3(abs$2(x), axis);
        }
        if (p === 'euclidean' || p === 2) {
            // norm(x, 2) = sum(abs(xi) ^ 2) ^ 1/2
            return sqrt$2(sum$3(pow$3(abs$2(x), scalar(2, 'int32')), axis));
        }
        throw new Error(`Error in norm: invalid ord value: ${p}`);
    }
    // matrix (assumption axis[0] < axis[1])
    if (Array.isArray(axis) && axis.length === 2) {
        if (p === 1) {
            // Reducing axis[0] first removes a dimension, so the remaining
            // axis formerly at index axis[1] is now at axis[1] - 1.
            return max$3(sum$3(abs$2(x), axis[0]), axis[1] - 1);
        }
        if (p === Infinity) {
            // axis[0] < axis[1], so axis[0]'s index is unchanged after
            // reducing axis[1].
            return max$3(sum$3(abs$2(x), axis[1]), axis[0]);
        }
        if (p === -Infinity) {
            return min$3(sum$3(abs$2(x), axis[1]), axis[0]);
        }
        if (p === 'fro' || p === 'euclidean') {
            // norm(x) = sqrt(sum(pow(x, 2)))
            return sqrt$2(sum$3(square$2(x), axis));
        }
        throw new Error(`Error in norm: invalid ord value: ${p}`);
    }
    throw new Error(`Error in norm: invalid axis: ${axis}`);
}
const norm = /* @__PURE__ */ op({ norm_ });
12840
12841 /**
12842 * @license
12843 * Copyright 2022 Google LLC. All Rights Reserved.
12844 * Licensed under the Apache License, Version 2.0 (the "License");
12845 * you may not use this file except in compliance with the License.
12846 * You may obtain a copy of the License at
12847 *
12848 * http://www.apache.org/licenses/LICENSE-2.0
12849 *
12850 * Unless required by applicable law or agreed to in writing, software
12851 * distributed under the License is distributed on an "AS IS" BASIS,
12852 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12853 * See the License for the specific language governing permissions and
12854 * limitations under the License.
12855 * =============================================================================
12856 */
12857 /**
12858 * Computes the Euclidean norm of scalar, vectors, and matrices.
12859 *
12860 * ```js
12861 * const x = tf.tensor1d([1, 2, 3, 4]);
12862 *
12863 * x.euclideanNorm().print(); // or tf.euclideanNorm(x)
12864 * ```
12865 *
12866 * @param x The input array.
12867 * @param axis Optional. If axis is null (the default), the input is
12868 * considered a vector and a single vector norm is computed over the entire
12869 * set of values in the Tensor, i.e. euclideanNorm(x) is equivalent
12870 * to euclideanNorm(x.reshape([-1])). If axis is an integer, the input
12871 * is considered a batch of vectors, and axis determines the axis in x
12872 * over which to compute vector norms. If axis is a 2-tuple of integer it is
12873 * considered a batch of matrices and axis determines the axes in NDArray
12874 * over which to compute a matrix norm.
12875 * @param keepDims Optional. If true, the norm has the same dimensionality
12876 * as the input.
12877 *
12878 * @doc {heading: 'Operations', subheading: 'Matrices'}
12879 */
/**
 * Euclidean norm of scalars, vectors, and matrices.
 *
 * Thin wrapper over `norm` with `ord` fixed to 'euclidean'.
 *
 * @param x The input tensor.
 * @param axis Axis/axes over which to compute the norm; null means the
 *     whole tensor is treated as one vector.
 * @param keepDims If true, keeps reduced dimensions with length 1.
 */
function euclideanNorm_(x, axis = null, keepDims = false) {
    return norm(x, 'euclidean', axis, keepDims);
}
const euclideanNorm = /* @__PURE__ */ op({ euclideanNorm_ });
12884
12885 /**
12886 * @license
12887 * Copyright 2018 Google LLC. All Rights Reserved.
12888 * Licensed under the Apache License, Version 2.0 (the "License");
12889 * you may not use this file except in compliance with the License.
12890 * You may obtain a copy of the License at
12891 *
12892 * http://www.apache.org/licenses/LICENSE-2.0
12893 *
12894 * Unless required by applicable law or agreed to in writing, software
12895 * distributed under the License is distributed on an "AS IS" BASIS,
12896 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12897 * See the License for the specific language governing permissions and
12898 * limitations under the License.
12899 * =============================================================================
12900 */
12901 /**
12902 * Computes exponential of the input `tf.Tensor` element-wise. `e ^ x`
12903 *
12904 * ```js
12905 * const x = tf.tensor1d([1, 2, -3]);
12906 *
12907 * x.exp().print(); // or tf.exp(x)
12908 * ```
12909 * @param x The input tensor.
12910 *
12911 * @doc {heading: 'Operations', subheading: 'Basic math'}
12912 */
/**
 * Element-wise exponential: `y = e ^ x`.
 *
 * @param x The input tensor.
 * @returns Tensor of element-wise exponentials, via the Exp kernel.
 */
function exp_(x) {
    const input = convertToTensor(x, 'x', 'exp');
    return ENGINE.runKernel(Exp, { x: input });
}
const exp$2 = /* @__PURE__ */ op({ exp_ });
12919
12920 /**
12921 * @license
12922 * Copyright 2020 Google LLC. All Rights Reserved.
12923 * Licensed under the Apache License, Version 2.0 (the "License");
12924 * you may not use this file except in compliance with the License.
12925 * You may obtain a copy of the License at
12926 *
12927 * http://www.apache.org/licenses/LICENSE-2.0
12928 *
12929 * Unless required by applicable law or agreed to in writing, software
12930 * distributed under the License is distributed on an "AS IS" BASIS,
12931 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12932 * See the License for the specific language governing permissions and
12933 * limitations under the License.
12934 * =============================================================================
12935 */
12936 /**
12937 * Returns a `tf.Tensor` that has expanded rank, by inserting a dimension
12938 * into the tensor's shape.
12939 *
12940 * ```js
12941 * const x = tf.tensor1d([1, 2, 3, 4]);
12942 * const axis = 1;
12943 * x.expandDims(axis).print();
12944 * ```
12945 *
12946 * @param x The input tensor whose dimensions are to be expanded.
12947 * @param axis The dimension index at which to insert shape of `1`. Defaults
12948 * to 0 (the first dimension).
12949 *
12950 * @doc {heading: 'Tensors', subheading: 'Transformations'}
12951 */
/**
 * Inserts a size-1 dimension into the tensor's shape at `axis`.
 *
 * @param x The tensor to expand.
 * @param axis Index at which to insert the new dimension (default 0);
 *     must not exceed the tensor's rank.
 * @returns A tensor with rank increased by one, via the ExpandDims kernel.
 */
function expandDims_(x, axis = 0) {
    const input = convertToTensor(x, 'x', 'expandDims', 'string_or_numeric');
    assert$1(axis <= input.rank, () => 'Axis must be <= rank of the tensor');
    return ENGINE.runKernel(ExpandDims, { input }, { dim: axis });
}
const expandDims$3 = /* @__PURE__ */ op({ expandDims_ });
12960
12961 /**
12962 * @license
12963 * Copyright 2018 Google LLC. All Rights Reserved.
12964 * Licensed under the Apache License, Version 2.0 (the "License");
12965 * you may not use this file except in compliance with the License.
12966 * You may obtain a copy of the License at
12967 *
12968 * http://www.apache.org/licenses/LICENSE-2.0
12969 *
12970 * Unless required by applicable law or agreed to in writing, software
12971 * distributed under the License is distributed on an "AS IS" BASIS,
12972 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12973 * See the License for the specific language governing permissions and
12974 * limitations under the License.
12975 * =============================================================================
12976 */
12977 /**
12978 * Computes exponential of the input `tf.Tensor` minus one element-wise.
12979 * `e ^ x - 1`
12980 *
12981 * ```js
12982 * const x = tf.tensor1d([1, 2, -3]);
12983 *
12984 * x.expm1().print(); // or tf.expm1(x)
12985 * ```
12986 * @param x The input tensor.
12987 *
12988 * @doc {heading: 'Operations', subheading: 'Basic math'}
12989 */
/**
 * Element-wise `e ^ x - 1`.
 *
 * @param x The input tensor.
 * @returns Tensor of element-wise expm1 values, via the Expm1 kernel.
 */
function expm1_(x) {
    const input = convertToTensor(x, 'x', 'expm1');
    return ENGINE.runKernel(Expm1, { x: input });
}
const expm1$2 = /* @__PURE__ */ op({ expm1_ });
12996
12997 /**
12998 * @license
12999 * Copyright 2020 Google LLC. All Rights Reserved.
13000 * Licensed under the Apache License, Version 2.0 (the "License");
13001 * you may not use this file except in compliance with the License.
13002 * You may obtain a copy of the License at
13003 *
13004 * http://www.apache.org/licenses/LICENSE-2.0
13005 *
13006 * Unless required by applicable law or agreed to in writing, software
13007 * distributed under the License is distributed on an "AS IS" BASIS,
13008 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13009 * See the License for the specific language governing permissions and
13010 * limitations under the License.
13011 * =============================================================================
13012 */
13013 /**
13014 * Construct a tensor by repeating it the number of times given by reps.
13015 *
13016 * This operation creates a new tensor by replicating `input` `reps`
13017 * times. The output tensor's `i`th dimension has `input.shape[i] *
13018 * reps[i]` elements, and the values of `input` are replicated
13019 * `reps[i]` times along the `i`th dimension. For example, tiling
13020 * `[a, b, c, d]` by `[2]` produces `[a, b, c, d, a, b, c, d]`.
13021 *
13022 * ```js
13023 * const a = tf.tensor1d([1, 2]);
13024 *
13025 * a.tile([2]).print(); // or tf.tile(a, [2])
13026 * ```
13027 *
13028 * ```js
13029 * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
13030 *
13031 * a.tile([1, 2]).print(); // or tf.tile(a, [1,2])
13032 * ```
13033 * @param x The tensor to tile.
13034 * @param reps Determines the number of replications per dimension.
13035 *
13036 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
13037 */
/**
 * Replicates tensor `x` according to `reps`: output dimension `i` has
 * `x.shape[i] * reps[i]` elements.
 *
 * @param x The tensor to tile.
 * @param reps Replication count per dimension; must have one entry per
 *     dimension of `x`.
 * @returns The tiled tensor, via the Tile kernel.
 */
function tile_(x, reps) {
    const $x = convertToTensor(x, 'x', 'tile', 'string_or_numeric');
    // Fixed: the message previously said "Error in transpose" — a
    // copy-paste slip; this assertion belongs to tile.
    assert$1($x.rank === reps.length, () => `Error in tile: rank of input ${$x.rank} ` +
        `must match length of reps ${reps}.`);
    const inputs = { x: $x };
    const attrs = { reps };
    return ENGINE.runKernel(Tile, inputs, attrs);
}
const tile$3 = /* @__PURE__ */ op({ tile_ });
13047
13048 /**
13049 * @license
13050 * Copyright 2020 Google LLC. All Rights Reserved.
13051 * Licensed under the Apache License, Version 2.0 (the "License");
13052 * you may not use this file except in compliance with the License.
13053 * You may obtain a copy of the License at
13054 *
13055 * http://www.apache.org/licenses/LICENSE-2.0
13056 *
13057 * Unless required by applicable law or agreed to in writing, software
13058 * distributed under the License is distributed on an "AS IS" BASIS,
13059 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13060 * See the License for the specific language governing permissions and
13061 * limitations under the License.
13062 * =============================================================================
13063 */
13064 /**
13065 * Create an identity matrix.
13066 *
13067 * @param numRows Number of rows.
13068 * @param numColumns Number of columns. Defaults to `numRows`.
13069 * @param batchShape If provided, will add the batch shape to the beginning
13070 * of the shape of the returned `tf.Tensor` by repeating the identity
13071 * matrix.
13072 * @param dtype Data type.
13073 * @returns Identity matrix of the specified size and data type, possibly
13074 * with batch repetition if `batchShape` is specified.
13075 *
13076 * @doc {heading: 'Tensors', subheading: 'Creation'}
13077 */
/**
 * Creates an identity matrix, optionally repeated along a batch shape.
 *
 * @param numRows Number of rows.
 * @param numColumns Number of columns; defaults to `numRows`.
 * @param batchShape Optional 1D-3D batch shape prepended to the result by
 *     repeating the identity matrix.
 * @param dtype Data type (default 'float32').
 * @returns Identity matrix of the requested size/dtype, batched if asked.
 */
function eye_(numRows, numColumns, batchShape, dtype = 'float32') {
    if (numColumns == null) {
        numColumns = numRows;
    }
    // Write 1s down the main diagonal of a zero buffer.
    const buff = buffer([numRows, numColumns], dtype);
    const diagLen = numRows <= numColumns ? numRows : numColumns;
    for (let d = 0; d < diagLen; ++d) {
        buff.set(1, d, d);
    }
    const identity = reshape$3(buff.toTensor(), [numRows, numColumns]);
    if (batchShape == null) {
        return identity;
    }
    if (batchShape.length < 1 || batchShape.length > 3) {
        throw new Error(`eye() currently supports only 1D and 2D ` +
            `batchShapes, but received ${batchShape.length}D.`);
    }
    // Prepend one size-1 axis per batch dimension, then tile along them.
    let expanded = identity;
    for (let i = 0; i < batchShape.length; ++i) {
        expanded = expandDims$3(expanded, 0);
    }
    return tile$3(expanded, [...batchShape, 1, 1]);
}
const eye = /* @__PURE__ */ op({ eye_ });
13113
13114 /**
13115 * @license
13116 * Copyright 2018 Google LLC. All Rights Reserved.
13117 * Licensed under the Apache License, Version 2.0 (the "License");
13118 * you may not use this file except in compliance with the License.
13119 * You may obtain a copy of the License at
13120 *
13121 * http://www.apache.org/licenses/LICENSE-2.0
13122 *
13123 * Unless required by applicable law or agreed to in writing, software
13124 * distributed under the License is distributed on an "AS IS" BASIS,
13125 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13126 * See the License for the specific language governing permissions and
13127 * limitations under the License.
13128 * =============================================================================
13129 */
13130 /**
13131 * Computes floor of input `tf.Tensor` element-wise: `floor(x)`.
13132 *
13133 * ```js
13134 * const x = tf.tensor1d([.6, 1.1, -3.3]);
13135 *
13136 * x.floor().print(); // or tf.floor(x)
13137 * ```
13138 * @param x The input tensor.
13139 *
13140 * @doc {heading: 'Operations', subheading: 'Basic math'}
13141 */
/**
 * Element-wise floor: `y = floor(x)`.
 *
 * @param x Input tensor (coerced to float32).
 * @returns Tensor of element-wise floors, via the Floor kernel.
 */
function floor_(x) {
    const input = convertToTensor(x, 'x', 'floor', 'float32');
    return ENGINE.runKernel(Floor, { x: input });
}
const floor$2 = /* @__PURE__ */ op({ floor_ });
13148
13149 /**
13150 * @license
13151 * Copyright 2018 Google LLC. All Rights Reserved.
13152 * Licensed under the Apache License, Version 2.0 (the "License");
13153 * you may not use this file except in compliance with the License.
13154 * You may obtain a copy of the License at
13155 *
13156 * http://www.apache.org/licenses/LICENSE-2.0
13157 *
13158 * Unless required by applicable law or agreed to in writing, software
13159 * distributed under the License is distributed on an "AS IS" BASIS,
13160 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13161 * See the License for the specific language governing permissions and
13162 * limitations under the License.
13163 * =============================================================================
13164 */
13165 /**
13166 * Gather slices from tensor `x`'s axis `axis` according to `indices`.
13167 *
13168 * ```js
13169 * const x = tf.tensor1d([1, 2, 3, 4]);
13170 * const indices = tf.tensor1d([1, 3, 3], 'int32');
13171 *
13172 * x.gather(indices).print();
13173 * ```
13174 *
13175 * ```js
13176 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
13177 * const indices = tf.tensor1d([1, 1, 0], 'int32');
13178 *
13179 * x.gather(indices).print();
13180 * ```
13181 * @param x The input tensor whose slices are to be gathered.
13182 * @param indices The indices of the values to extract.
13183 * @param axis The axis over which to select values. Defaults to 0.
13184 * @param batchDims Optional. The number of batch dimensions. It must be less
13185 * than or equal to rank(indices). Defaults to 0.
13186 * The output tensor will have shape of
13187 * `x.shape[:axis] + indices.shape[batchDims:] + x.shape[axis + 1:]`
13188 *
13189 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
13190 */
/**
 * Gathers slices from `x` along `axis` at the positions in `indices`.
 *
 * @param x The tensor to gather from.
 * @param indices Indices of the slices to extract (int32).
 * @param axis Axis to select along (default 0).
 * @param batchDims Number of leading batch dimensions (default 0); must be
 *     <= rank(indices).
 * @returns The gathered tensor, via the GatherV2 kernel.
 */
function gather_(x, indices, axis = 0, batchDims = 0) {
    const inputs = {
        x: convertToTensor(x, 'x', 'gather'),
        indices: convertToTensor(indices, 'indices', 'gather', 'int32'),
    };
    return ENGINE.runKernel(GatherV2, inputs, { axis, batchDims });
}
const gather$1 = /* @__PURE__ */ op({ gather_ });
13199
13200 /**
13201 * @license
13202 * Copyright 2020 Google LLC. All Rights Reserved.
13203 * Licensed under the Apache License, Version 2.0 (the "License");
13204 * you may not use this file except in compliance with the License.
13205 * You may obtain a copy of the License at
13206 *
13207 * http://www.apache.org/licenses/LICENSE-2.0
13208 *
13209 * Unless required by applicable law or agreed to in writing, software
13210 * distributed under the License is distributed on an "AS IS" BASIS,
13211 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13212 * See the License for the specific language governing permissions and
13213 * limitations under the License.
13214 * =============================================================================
13215 */
13216 /**
13217 * Returns the truth value of (a > b) element-wise. Supports broadcasting.
13218 *
13219 * ```js
13220 * const a = tf.tensor1d([1, 2, 3]);
13221 * const b = tf.tensor1d([2, 2, 2]);
13222 *
13223 * a.greater(b).print();
13224 * ```
13225 *
13226 * @param a The first input tensor.
13227 * @param b The second input tensor. Must have the same dtype as `a`.
13228 *
13229 * @doc {heading: 'Operations', subheading: 'Logical'}
13230 */
/**
 * Element-wise truth value of `a > b`, with broadcasting.
 *
 * @param a The first input tensor.
 * @param b The second input tensor; dtype is unified with `a`.
 * @returns Boolean tensor of the comparison, via the Greater kernel.
 */
function greater_(a, b) {
    const aTensor = convertToTensor(a, 'a', 'greater', 'string_or_numeric');
    const bTensor = convertToTensor(b, 'b', 'greater', 'string_or_numeric');
    const [lhs, rhs] = makeTypesMatch(aTensor, bTensor);
    // Validates that the two shapes are broadcast-compatible.
    assertAndGetBroadcastShape(lhs.shape, rhs.shape);
    return ENGINE.runKernel(Greater, { a: lhs, b: rhs });
}
const greater$3 = /* @__PURE__ */ op({ greater_ });
13240
13241 /**
13242 * @license
13243 * Copyright 2020 Google LLC. All Rights Reserved.
13244 * Licensed under the Apache License, Version 2.0 (the "License");
13245 * you may not use this file except in compliance with the License.
13246 * You may obtain a copy of the License at
13247 *
13248 * http://www.apache.org/licenses/LICENSE-2.0
13249 *
13250 * Unless required by applicable law or agreed to in writing, software
13251 * distributed under the License is distributed on an "AS IS" BASIS,
13252 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13253 * See the License for the specific language governing permissions and
13254 * limitations under the License.
13255 * =============================================================================
13256 */
13257 /**
13258 * Returns the truth value of (a >= b) element-wise. Supports broadcasting.
13259 *
13260 * ```js
13261 * const a = tf.tensor1d([1, 2, 3]);
13262 * const b = tf.tensor1d([2, 2, 2]);
13263 *
13264 * a.greaterEqual(b).print();
13265 * ```
13266 *
13267 * @param a The first input tensor.
13268 * @param b The second input tensor. Must have the same dtype as `a`.
13269 *
13270 * @doc {heading: 'Operations', subheading: 'Logical'}
13271 */
/**
 * Element-wise truth value of `a >= b`, with broadcasting.
 *
 * @param a The first input tensor.
 * @param b The second input tensor; dtype is unified with `a`.
 * @returns Boolean tensor of the comparison, via the GreaterEqual kernel.
 */
function greaterEqual_(a, b) {
    const aTensor = convertToTensor(a, 'a', 'greaterEqual', 'string_or_numeric');
    const bTensor = convertToTensor(b, 'b', 'greaterEqual', 'string_or_numeric');
    const [lhs, rhs] = makeTypesMatch(aTensor, bTensor);
    // Validates that the two shapes are broadcast-compatible.
    assertAndGetBroadcastShape(lhs.shape, rhs.shape);
    return ENGINE.runKernel(GreaterEqual, { a: lhs, b: rhs });
}
const greaterEqual$2 = /* @__PURE__ */ op({ greaterEqual_ });
13281
13282 /**
13283 * @license
13284 * Copyright 2020 Google LLC. All Rights Reserved.
13285 * Licensed under the Apache License, Version 2.0 (the "License");
13286 * you may not use this file except in compliance with the License.
13287 * You may obtain a copy of the License at
13288 *
13289 * http://www.apache.org/licenses/LICENSE-2.0
13290 *
13291 * Unless required by applicable law or agreed to in writing, software
13292 * distributed under the License is distributed on an "AS IS" BASIS,
13293 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13294 * See the License for the specific language governing permissions and
13295 * limitations under the License.
13296 * =============================================================================
13297 */
13298 /**
13299 * Returns the imaginary part of a complex (or real) tensor.
13300 *
13301 * Given a tensor input, this operation returns a tensor of type float that is
13302 * the imaginary part of each element in input considered as a complex number.
13303 * If input is real, a tensor of all zeros is returned.
13304 *
13305 * ```js
13306 * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]);
13307 * tf.imag(x).print();
13308 * ```
13309 *
13310 * @doc {heading: 'Tensors', subheading: 'Creation'}
13311 */
/**
 * Returns the imaginary part of a complex (or real) tensor; real inputs
 * yield all zeros.
 *
 * @param input The input tensor.
 * @returns Float tensor of imaginary parts, via the Imag kernel.
 */
function imag_(input) {
    const tensor = convertToTensor(input, 'input', 'imag');
    return ENGINE.runKernel(Imag, { input: tensor });
}
const imag$2 = /* @__PURE__ */ op({ imag_ });
13318
13319 /**
13320 * @license
13321 * Copyright 2018 Google LLC. All Rights Reserved.
13322 * Licensed under the Apache License, Version 2.0 (the "License");
13323 * you may not use this file except in compliance with the License.
13324 * You may obtain a copy of the License at
13325 *
13326 * http://www.apache.org/licenses/LICENSE-2.0
13327 *
13328 * Unless required by applicable law or agreed to in writing, software
13329 * distributed under the License is distributed on an "AS IS" BASIS,
13330 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13331 * See the License for the specific language governing permissions and
13332 * limitations under the License.
13333 * =============================================================================
13334 */
13335 /**
13336 * Returns which elements of x are finite.
13337 *
13338 * ```js
13339 * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
13340 *
13341 * x.isFinite().print(); // or tf.isFinite(x)
13342 * ```
13343 * @param x The input Tensor.
13344 *
13345 * @doc {heading: 'Operations', subheading: 'Basic math'}
13346 */
/**
 * Element-wise test for finiteness.
 *
 * @param x The input tensor.
 * @returns Boolean tensor marking finite elements, via the IsFinite kernel.
 */
function isFinite_(x) {
    const input = convertToTensor(x, 'x', 'isFinite');
    return ENGINE.runKernel(IsFinite, { x: input });
}
const isFinite$3 = /* @__PURE__ */ op({ isFinite_ });
13353
13354 /**
13355 * @license
13356 * Copyright 2018 Google LLC. All Rights Reserved.
13357 * Licensed under the Apache License, Version 2.0 (the "License");
13358 * you may not use this file except in compliance with the License.
13359 * You may obtain a copy of the License at
13360 *
13361 * http://www.apache.org/licenses/LICENSE-2.0
13362 *
13363 * Unless required by applicable law or agreed to in writing, software
13364 * distributed under the License is distributed on an "AS IS" BASIS,
13365 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13366 * See the License for the specific language governing permissions and
13367 * limitations under the License.
13368 * =============================================================================
13369 */
13370 /**
13371 * Returns which elements of x are Infinity or -Infinity.
13372 *
13373 * ```js
13374 * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
13375 *
13376 * x.isInf().print(); // or tf.isInf(x)
13377 * ```
13378 * @param x The input Tensor.
13379 *
13380 * @doc {heading: 'Operations', subheading: 'Basic math'}
13381 */
/**
 * Element-wise test for Infinity or -Infinity.
 *
 * @param x The input tensor.
 * @returns Boolean tensor marking infinite elements, via the IsInf kernel.
 */
function isInf_(x) {
    const input = convertToTensor(x, 'x', 'isInf');
    return ENGINE.runKernel(IsInf, { x: input });
}
const isInf$2 = /* @__PURE__ */ op({ isInf_ });
13388
13389 /**
13390 * @license
13391 * Copyright 2018 Google LLC. All Rights Reserved.
13392 * Licensed under the Apache License, Version 2.0 (the "License");
13393 * you may not use this file except in compliance with the License.
13394 * You may obtain a copy of the License at
13395 *
13396 * http://www.apache.org/licenses/LICENSE-2.0
13397 *
13398 * Unless required by applicable law or agreed to in writing, software
13399 * distributed under the License is distributed on an "AS IS" BASIS,
13400 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13401 * See the License for the specific language governing permissions and
13402 * limitations under the License.
13403 * =============================================================================
13404 */
/**
 * Returns which elements of x are NaN.
 *
 * ```js
 * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
 *
 * x.isNaN().print(); // or tf.isNaN(x)
 * ```
 * @param x The input Tensor.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function isNaN_(x) {
    // Normalize TensorLike input to a Tensor, then hand off to the kernel.
    const input = convertToTensor(x, 'x', 'isNaN');
    return ENGINE.runKernel(IsNan, { x: input });
}
const isNaN$3 = /* @__PURE__ */ op({ isNaN_ });
13423
13424 /**
13425 * @license
13426 * Copyright 2020 Google LLC. All Rights Reserved.
13427 * Licensed under the Apache License, Version 2.0 (the "License");
13428 * you may not use this file except in compliance with the License.
13429 * You may obtain a copy of the License at
13430 *
13431 * http://www.apache.org/licenses/LICENSE-2.0
13432 *
13433 * Unless required by applicable law or agreed to in writing, software
13434 * distributed under the License is distributed on an "AS IS" BASIS,
13435 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13436 * See the License for the specific language governing permissions and
13437 * limitations under the License.
13438 * =============================================================================
13439 */
/**
 * Computes leaky rectified linear element-wise.
 *
 * See
 * [http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf](
 * http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf)
 *
 * ```js
 * const x = tf.tensor1d([-1, 2, -3, 4]);
 *
 * x.leakyRelu(0.1).print(); // or tf.leakyRelu(x, 0.1)
 * ```
 * @param x The input tensor.
 * @param alpha The scaling factor for negative values, defaults to 0.2.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function leakyRelu_(x, alpha = 0.2) {
    // Normalize the input, then dispatch with `alpha` as a kernel attribute.
    const input = convertToTensor(x, 'x', 'leakyRelu');
    return ENGINE.runKernel(LeakyRelu, { x: input }, { alpha });
}
const leakyRelu$2 = /* @__PURE__ */ op({ leakyRelu_ });
13464
13465 /**
13466 * @license
13467 * Copyright 2020 Google LLC. All Rights Reserved.
13468 * Licensed under the Apache License, Version 2.0 (the "License");
13469 * you may not use this file except in compliance with the License.
13470 * You may obtain a copy of the License at
13471 *
13472 * http://www.apache.org/licenses/LICENSE-2.0
13473 *
13474 * Unless required by applicable law or agreed to in writing, software
13475 * distributed under the License is distributed on an "AS IS" BASIS,
13476 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13477 * See the License for the specific language governing permissions and
13478 * limitations under the License.
13479 * =============================================================================
13480 */
/**
 * Returns the truth value of (a < b) element-wise. Supports broadcasting.
 *
 * ```js
 * const a = tf.tensor1d([1, 2, 3]);
 * const b = tf.tensor1d([2, 2, 2]);
 *
 * a.less(b).print();
 * ```
 * @param a The first input tensor.
 * @param b The second input tensor. Must have the same dtype as `a`.
 *
 * @doc {heading: 'Operations', subheading: 'Logical'}
 */
function less_(a, b) {
    let lhs = convertToTensor(a, 'a', 'less', 'string_or_numeric');
    let rhs = convertToTensor(b, 'b', 'less', 'string_or_numeric');
    // Upcast operands to a common dtype before comparing.
    [lhs, rhs] = makeTypesMatch(lhs, rhs);
    // Throws when the operand shapes are not broadcast-compatible.
    assertAndGetBroadcastShape(lhs.shape, rhs.shape);
    return ENGINE.runKernel(Less, { a: lhs, b: rhs });
}
const less$3 = /* @__PURE__ */ op({ less_ });
13504
13505 /**
13506 * @license
13507 * Copyright 2020 Google LLC. All Rights Reserved.
13508 * Licensed under the Apache License, Version 2.0 (the "License");
13509 * you may not use this file except in compliance with the License.
13510 * You may obtain a copy of the License at
13511 *
13512 * http://www.apache.org/licenses/LICENSE-2.0
13513 *
13514 * Unless required by applicable law or agreed to in writing, software
13515 * distributed under the License is distributed on an "AS IS" BASIS,
13516 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13517 * See the License for the specific language governing permissions and
13518 * limitations under the License.
13519 * =============================================================================
13520 */
/**
 * Returns the truth value of (a <= b) element-wise. Supports broadcasting.
 *
 * ```js
 * const a = tf.tensor1d([1, 2, 3]);
 * const b = tf.tensor1d([2, 2, 2]);
 *
 * a.lessEqual(b).print();
 * ```
 *
 * @param a The first input tensor.
 * @param b The second input tensor. Must have the same dtype as `a`.
 *
 * @doc {heading: 'Operations', subheading: 'Logical'}
 */
function lessEqual_(a, b) {
    let lhs = convertToTensor(a, 'a', 'lessEqual', 'string_or_numeric');
    let rhs = convertToTensor(b, 'b', 'lessEqual', 'string_or_numeric');
    // Upcast operands to a common dtype before comparing.
    [lhs, rhs] = makeTypesMatch(lhs, rhs);
    // Throws when the operand shapes are not broadcast-compatible.
    assertAndGetBroadcastShape(lhs.shape, rhs.shape);
    return ENGINE.runKernel(LessEqual, { a: lhs, b: rhs });
}
const lessEqual$2 = /* @__PURE__ */ op({ lessEqual_ });
13545
13546 /**
13547 * @license
13548 * Copyright 2018 Google LLC. All Rights Reserved.
13549 * Licensed under the Apache License, Version 2.0 (the "License");
13550 * you may not use this file except in compliance with the License.
13551 * You may obtain a copy of the License at
13552 *
13553 * http://www.apache.org/licenses/LICENSE-2.0
13554 *
13555 * Unless required by applicable law or agreed to in writing, software
13556 * distributed under the License is distributed on an "AS IS" BASIS,
13557 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13558 * See the License for the specific language governing permissions and
13559 * limitations under the License.
13560 * =============================================================================
13561 */
/**
 * Return an evenly spaced sequence of numbers over the given interval.
 *
 * ```js
 * tf.linspace(0, 9, 10).print();
 * ```
 * @param start The start value of the sequence.
 * @param stop The end value of the sequence.
 * @param num The number of values to generate.
 *
 * @doc {heading: 'Tensors', subheading: 'Creation'}
 */
function linspace(start, stop, num) {
    // Guard clause: a non-positive count cannot produce a sequence.
    if (num <= 0) {
        throw new Error('The number of values should be positive.');
    }
    // The kernel takes no tensor inputs; everything travels as attributes.
    return ENGINE.runKernel(LinSpace, {}, { start, stop, num });
}
13581
13582 /**
13583 * @license
13584 * Copyright 2020 Google LLC. All Rights Reserved.
13585 * Licensed under the Apache License, Version 2.0 (the "License");
13586 * you may not use this file except in compliance with the License.
13587 * You may obtain a copy of the License at
13588 *
13589 * http://www.apache.org/licenses/LICENSE-2.0
13590 *
13591 * Unless required by applicable law or agreed to in writing, software
13592 * distributed under the License is distributed on an "AS IS" BASIS,
13593 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13594 * See the License for the specific language governing permissions and
13595 * limitations under the License.
13596 * =============================================================================
13597 */
/**
 * Normalizes the activation of a local neighborhood across or within
 * channels.
 *
 * @param x The input tensor. The 4-D input tensor is treated as a 3-D array
 *     of 1D vectors (along the last dimension), and each vector is
 *     normalized independently.
 * @param depthRadius The number of adjacent channels in the 1D normalization
 *     window.
 * @param bias A constant bias term for the basis.
 * @param alpha A scale factor, usually positive.
 * @param beta An exponent.
 *
 * @doc {heading: 'Operations', subheading: 'Normalization'}
 */
function localResponseNormalization_(x, depthRadius = 5, bias = 1, alpha = 1, beta = 0.5) {
    const $x = convertToTensor(x, 'x', 'localResponseNormalization');
    assert$1($x.rank === 4 || $x.rank === 3, () => `Error in localResponseNormalization: x must be rank 3 or 4 but got
    rank ${$x.rank}.`);
    assert$1(isInt(depthRadius), () => `Error in localResponseNormalization: depthRadius must be an ` +
        `integer but got depthRadius ${depthRadius}.`);
    // The LRN kernel expects a 4-D input; prepend a size-1 batch dimension
    // for rank-3 inputs and strip it again after the kernel runs.
    let x4D = $x;
    let reshapedTo4D = false;
    if ($x.rank === 3) {
        reshapedTo4D = true;
        x4D = reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
    }
    const inputs = { x: x4D };
    const attrs = { depthRadius, bias, alpha, beta };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    const res = ENGINE.runKernel(LRN, inputs, attrs);
    if (reshapedTo4D) {
        // Drop the synthetic batch dimension added above.
        return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
    }
    else {
        return res;
    }
}
const localResponseNormalization = /* @__PURE__ */ op({ localResponseNormalization_ });
13637
13638 /**
13639 * @license
13640 * Copyright 2018 Google LLC. All Rights Reserved.
13641 * Licensed under the Apache License, Version 2.0 (the "License");
13642 * you may not use this file except in compliance with the License.
13643 * You may obtain a copy of the License at
13644 *
13645 * http://www.apache.org/licenses/LICENSE-2.0
13646 *
13647 * Unless required by applicable law or agreed to in writing, software
13648 * distributed under the License is distributed on an "AS IS" BASIS,
13649 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13650 * See the License for the specific language governing permissions and
13651 * limitations under the License.
13652 * =============================================================================
13653 */
/**
 * Computes natural logarithm of the input `tf.Tensor` element-wise: `ln(x)`
 *
 * ```js
 * const x = tf.tensor1d([1, 2, Math.E]);
 *
 * x.log().print(); // or tf.log(x)
 * ```
 * @param x The input tensor.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function log_(x) {
    // Convert input to a float32 tensor, then dispatch to the Log kernel.
    const input = convertToTensor(x, 'x', 'log', 'float32');
    return ENGINE.runKernel(Log, { x: input });
}
const log$2 = /* @__PURE__ */ op({ log_ });
13672
13673 /**
13674 * @license
13675 * Copyright 2018 Google LLC. All Rights Reserved.
13676 * Licensed under the Apache License, Version 2.0 (the "License");
13677 * you may not use this file except in compliance with the License.
13678 * You may obtain a copy of the License at
13679 *
13680 * http://www.apache.org/licenses/LICENSE-2.0
13681 *
13682 * Unless required by applicable law or agreed to in writing, software
13683 * distributed under the License is distributed on an "AS IS" BASIS,
13684 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13685 * See the License for the specific language governing permissions and
13686 * limitations under the License.
13687 * =============================================================================
13688 */
/**
 * Computes natural logarithm of the input `tf.Tensor` plus one
 * element-wise: `ln(1 + x)`
 *
 * ```js
 * const x = tf.tensor1d([1, 2, Math.E - 1]);
 *
 * x.log1p().print(); // or tf.log1p(x)
 * ```
 * @param x The input tensor.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function log1p_(x) {
    // Normalize TensorLike input, then dispatch to the Log1p kernel.
    const input = convertToTensor(x, 'x', 'log1p');
    return ENGINE.runKernel(Log1p, { x: input });
}
const log1p$2 = /* @__PURE__ */ op({ log1p_ });
13708
13709 /**
13710 * @license
13711 * Copyright 2018 Google LLC. All Rights Reserved.
13712 * Licensed under the Apache License, Version 2.0 (the "License");
13713 * you may not use this file except in compliance with the License.
13714 * You may obtain a copy of the License at
13715 *
13716 * http://www.apache.org/licenses/LICENSE-2.0
13717 *
13718 * Unless required by applicable law or agreed to in writing, software
13719 * distributed under the License is distributed on an "AS IS" BASIS,
13720 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13721 * See the License for the specific language governing permissions and
13722 * limitations under the License.
13723 * =============================================================================
13724 */
/**
 * Provided `f(x)`, returns another function `g(x, dy?)`, which gives the
 * gradient of `f(x)` with respect to `x`.
 *
 * If `dy` is provided, the gradient of `f(x).mul(dy).sum()` with respect to
 * `x` is computed instead. `f(x)` must take a single tensor `x` and return a
 * single tensor `y`. If `f()` takes multiple inputs, use `tf.grads` instead.
 *
 * ```js
 * // f(x) = x ^ 2
 * const f = x => x.square();
 * // f'(x) = 2x
 * const g = tf.grad(f);
 *
 * const x = tf.tensor1d([2, 3]);
 * g(x).print();
 * ```
 *
 * ```js
 * // f(x) = x ^ 3
 * const f = x => x.pow(tf.scalar(3, 'int32'));
 * // f'(x) = 3x ^ 2
 * const g = tf.grad(f);
 * // f''(x) = 6x
 * const gg = tf.grad(g);
 *
 * const x = tf.tensor1d([2, 3]);
 * gg(x).print();
 * ```
 *
 * @param f The function f(x), to compute gradient for.
 *
 * @doc {heading: 'Training', subheading: 'Gradients'}
 */
function grad(f) {
    assert$1(isFunction(f), () => 'The f passed in grad(f) must be a function');
    return (x, dy) => {
        // `x` may be of any dtype; `dy`, when supplied, is converted normally.
        const xTensor = convertToTensor(x, 'x', 'tf.grad', 'string_or_numeric');
        const dyTensor = dy == null ? null : convertToTensor(dy, 'dy', 'tf.grad');
        // tidy() disposes intermediates produced during differentiation.
        return ENGINE.tidy(() => {
            const result = ENGINE.gradients(() => f(xTensor), [xTensor], dyTensor);
            if (dyTensor != null) {
                assertShapesMatch(result.value.shape, dyTensor.shape, 'The shape of dy passed in grad(f)(x, dy) must match the shape ' +
                    'returned by f(x)');
            }
            checkGrads(result.grads);
            return result.grads[0];
        });
    };
}
/**
 * Provided `f(x1, x2,...)`, returns another function `g([x1, x2,...], dy?)`,
 * which gives an array of gradients of `f()` with respect to each input
 * [`x1`,`x2`,...].
 *
 * If `dy` is passed when calling `g()`, the gradient of
 * `f(x1,...).mul(dy).sum()` with respect to each input is computed instead.
 * The provided `f` must take one or more tensors and return a single tensor
 * `y`. If `f()` takes a single input, we recommend using `tf.grad` instead.
 *
 * ```js
 * // f(a, b) = a * b
 * const f = (a, b) => a.mul(b);
 * // df / da = b, df / db = a
 * const g = tf.grads(f);
 *
 * const a = tf.tensor1d([2, 3]);
 * const b = tf.tensor1d([-2, -3]);
 * const [da, db] = g([a, b]);
 * console.log('da');
 * da.print();
 * console.log('db');
 * db.print();
 * ```
 *
 * @param f The function `f(x1, x2,...)` to compute gradients for.
 *
 * @doc {heading: 'Training', subheading: 'Gradients'}
 */
function grads(f) {
    assert$1(isFunction(f), () => 'The f passed in grads(f) must be a function');
    return (args, dy) => {
        assert$1(Array.isArray(args), () => 'The args passed in grads(f)(args) must be an array ' +
            'of `Tensor`s or `TensorLike`s');
        // Inputs may be of any dtype; `dy`, when supplied, is converted normally.
        const argTensors = convertToTensorArray(args, 'args', 'tf.grads', 'string_or_numeric');
        const dyTensor = dy == null ? null : convertToTensor(dy, 'dy', 'tf.grads');
        // tidy() disposes intermediates produced during differentiation.
        return ENGINE.tidy(() => {
            const result = ENGINE.gradients(() => f(...argTensors), argTensors, dyTensor);
            if (dyTensor != null) {
                assertShapesMatch(result.value.shape, dyTensor.shape, 'The shape of dy passed in grads(f)([x1,...], dy) must ' +
                    'match the shape returned by f([x1,...])');
            }
            checkGrads(result.grads);
            return result.grads;
        });
    };
}
/**
 * Like `tf.grad`, but also returns the value of `f()`. Useful when `f()`
 * returns a metric you want to show.
 *
 * The result is a rich object with the following properties:
 * - grad: The gradient of `f(x)` w.r.t. `x` (result of `tf.grad`).
 * - value: The value returned by `f(x)`.
 *
 * ```js
 * // f(x) = x ^ 2
 * const f = x => x.square();
 * // f'(x) = 2x
 * const g = tf.valueAndGrad(f);
 *
 * const x = tf.tensor1d([2, 3]);
 * const {value, grad} = g(x);
 *
 * console.log('value');
 * value.print();
 * console.log('grad');
 * grad.print();
 * ```
 *
 * @doc {heading: 'Training', subheading: 'Gradients'}
 */
function valueAndGrad(f) {
    assert$1(isFunction(f), () => 'The f passed in valueAndGrad(f) must be a function');
    return (x, dy) => {
        // Unlike tf.grad, inputs here must already be tensors.
        assert$1(x instanceof Tensor, () => 'The x passed in valueAndGrad(f)(x) must be a tensor');
        assert$1(dy == null || dy instanceof Tensor, () => 'The dy passed in valueAndGrad(f)(x, dy) must be a tensor');
        const result = ENGINE.gradients(() => f(x), [x], dy);
        checkGrads(result.grads);
        return { grad: result.grads[0], value: result.value };
    };
}
/**
 * Like `tf.grads`, but returns also the value of `f()`. Useful when `f()`
 * returns a metric you want to show.
 *
 * The result is a rich object with the following properties:
 * - grads: The gradients of `f()` w.r.t. each input (result of `tf.grads`).
 * - value: The value returned by `f(x)`.
 *
 * ```js
 * // f(a, b) = a * b
 * const f = (a, b) => a.mul(b);
 * // df/da = b, df/db = a
 * const g = tf.valueAndGrads(f);
 *
 * const a = tf.tensor1d([2, 3]);
 * const b = tf.tensor1d([-2, -3]);
 * const {value, grads} = g([a, b]);
 *
 * const [da, db] = grads;
 *
 * console.log('value');
 * value.print();
 *
 * console.log('da');
 * da.print();
 * console.log('db');
 * db.print();
 * ```
 *
 * @doc {heading: 'Training', subheading: 'Gradients'}
 */
function valueAndGrads(f) {
    assert$1(isFunction(f), () => 'The f passed in valueAndGrads(f) must be a function');
    return (args, dy) => {
        // Unlike tf.grads, inputs here must already be tensors.
        assert$1(Array.isArray(args) && args.every(arg => arg instanceof Tensor), () => 'The args passed in valueAndGrads(f)(args) must be array of ' +
            'tensors');
        assert$1(dy == null || dy instanceof Tensor, () => 'The dy passed in valueAndGrads(f)(args, dy) must be a tensor');
        const outcome = ENGINE.gradients(() => f(...args), args, dy);
        if (dy != null) {
            assertShapesMatch(outcome.value.shape, dy.shape, 'The shape of dy passed in valueAndGrads(f)([x1,...], dy) must ' +
                'match the shape returned by f([x1,...])');
        }
        checkGrads(outcome.grads);
        return outcome;
    };
}
/**
 * Computes and returns the gradient of f(x) with respect to the list of
 * trainable variables provided by `varList`. If no list is provided, it
 * defaults to all trainable variables.
 *
 * ```js
 * const a = tf.variable(tf.tensor1d([3, 4]));
 * const b = tf.variable(tf.tensor1d([5, 6]));
 * const x = tf.tensor1d([1, 2]);
 *
 * // f(a, b) = a * x ^ 2 + b * x
 * const f = () => a.mul(x.square()).add(b.mul(x)).sum();
 * // df/da = x ^ 2, df/db = x
 * const {value, grads} = tf.variableGrads(f);
 *
 * Object.keys(grads).forEach(varName => grads[varName].print());
 * ```
 *
 * @param f The function to execute. f() should return a scalar.
 * @param varList The list of variables to compute the gradients with respect
 * to. Defaults to all trainable variables.
 * @returns An object with the following keys and values:
 * - `value`: The value of the function `f`.
 * - `grads`: A map from the names of the variables to the gradients.
 * If the `varList` argument is provided explicitly and contains a subset of
 * non-trainable variables, this map in the return value will contain keys
 * that map the names of the non-trainable variables to `null`.
 *
 * @doc {heading: 'Training', subheading: 'Gradients'}
 */
function variableGrads(f, varList) {
    assert$1(isFunction(f), () => 'The f passed in variableGrads(f) must be a function');
    assert$1(varList == null ||
        Array.isArray(varList) && varList.every(v => v instanceof Variable), () => 'The varList passed in variableGrads(f, varList) must be an array ' +
        'of variables');
    // Remember whether the caller supplied a list: it changes how
    // non-trainable variables are reported in the result (see below).
    const specifiedVarList = varList != null;
    if (!specifiedVarList) {
        // Get all of the trainable variables.
        varList = [];
        for (const varName in ENGINE.registeredVariables) {
            varList.push(ENGINE.registeredVariables[varName]);
        }
    }
    // Only an explicitly-passed list reports non-trainables (as null grads).
    const specifiedNonTrainable = specifiedVarList ? varList.filter(variable => !variable.trainable) : null;
    // Prune non-trainable variables.
    const originalVarCount = varList.length;
    varList = varList.filter(variable => variable.trainable);
    assert$1(varList.length > 0, () => `variableGrads() expects at least one of the input variables to ` +
        `be trainable, but none of the ${originalVarCount} variables is ` +
        `trainable.`);
    // Allow null gradients here so we can produce the tailored error below.
    const allowNoGradients = true;
    const { value, grads } = ENGINE.gradients(f, varList, null, allowNoGradients);
    assert$1(grads.some(g => g != null), () => 'Cannot find a connection between any variable and the result of ' +
        'the loss function y=f(x). Please make sure the operations that ' +
        'use variables are inside the function f passed to minimize().');
    assert$1(value.rank === 0, () => `The f passed in variableGrads(f) must return a scalar, but it ` +
        `returned a rank-${value.rank} tensor`);
    // Re-key the positional gradients by variable name, skipping nulls.
    const namedGrads = {};
    varList.forEach((v, i) => {
        if (grads[i] != null) {
            namedGrads[v.name] = grads[i];
        }
    });
    if (specifiedNonTrainable != null) {
        // If varList is explicitly provided and contains non-trainable values,
        // add them to the returned gradients with `null` values.
        specifiedNonTrainable.forEach(v => namedGrads[v.name] = null);
    }
    return { value, grads: namedGrads };
}
/**
 * Overrides the gradient computation of a function `f`.
 *
 * Takes a function
 * `f(...inputs, save) => {value: Tensor, gradFunc: (dy, saved) => Tensor[]}`
 * and returns another function `g(...inputs)` which takes the same inputs as
 * `f`. When called, `g` returns `f().value`. In backward mode, custom gradients
 * with respect to each input of `f` are computed using `f().gradFunc`.
 *
 * The `save` function passed to `f` should be used for saving tensors needed
 * in the gradient. And the `saved` passed to the `gradFunc` is a
 * `NamedTensorMap`, which contains those saved tensors.
 *
 * ```js
 * const customOp = tf.customGrad((x, save) => {
 *   // Save x to make sure it's available later for the gradient.
 *   save([x]);
 *   // Override gradient of our custom x ^ 2 op to be dy * abs(x);
 *   return {
 *     value: x.square(),
 *     // Note `saved.x` which points to the `x` we saved earlier.
 *     gradFunc: (dy, saved) => [dy.mul(saved[0].abs())]
 *   };
 * });
 *
 * const x = tf.tensor1d([-1, -2, 3]);
 * const dx = tf.grad(x => customOp(x));
 *
 * console.log(`f(x):`);
 * customOp(x).print();
 * console.log(`f'(x):`);
 * dx(x).print();
 * ```
 *
 * @param f The function to evaluate in forward mode, which should return
 *     `{value: Tensor, gradFunc: (dy, saved) => Tensor[]}`, where `gradFunc`
 *     returns the custom gradients of `f` with respect to its inputs.
 *
 * @doc {heading: 'Training', subheading: 'Gradients'}
 */
function customGrad(f) {
    // Thin public wrapper: the engine owns the custom-gradient machinery.
    return ENGINE.customGrad(f);
}
/**
 * Throws when any entry in `grads` is null/undefined, which means some
 * input was not connected to the output of the differentiated function.
 * @param grads Array of gradient tensors, one per input; entries may be
 *     null when no gradient path exists for that input.
 */
function checkGrads(grads) {
    // `some` short-circuits on the first missing gradient; the original
    // counted every null via filter().length, which does needless work.
    if (grads.some(g => g == null)) {
        throw new Error(`Cannot compute gradient of y=f(x) with respect to x. Make sure that
    the f you passed encloses all operations that lead from x to y.`);
    }
}
14025
14026 /**
14027 * @license
14028 * Copyright 2018 Google LLC. All Rights Reserved.
14029 * Licensed under the Apache License, Version 2.0 (the "License");
14030 * you may not use this file except in compliance with the License.
14031 * You may obtain a copy of the License at
14032 *
14033 * http://www.apache.org/licenses/LICENSE-2.0
14034 *
14035 * Unless required by applicable law or agreed to in writing, software
14036 * distributed under the License is distributed on an "AS IS" BASIS,
14037 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14038 * See the License for the specific language governing permissions and
14039 * limitations under the License.
14040 * =============================================================================
14041 */
/**
 * Computes `-1 * x` element-wise.
 *
 * ```js
 * const x = tf.tensor2d([1, 2, -2, 0], [2, 2]);
 *
 * x.neg().print(); // or tf.neg(x)
 * ```
 *
 * @param x The input tensor.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function neg_(x) {
    // Normalize TensorLike input, then dispatch to the Neg kernel.
    const input = convertToTensor(x, 'x', 'neg');
    return ENGINE.runKernel(Neg, { x: input });
}
const neg$2 = /* @__PURE__ */ op({ neg_ });
14061
14062 /**
14063 * @license
14064 * Copyright 2018 Google LLC. All Rights Reserved.
14065 * Licensed under the Apache License, Version 2.0 (the "License");
14066 * you may not use this file except in compliance with the License.
14067 * You may obtain a copy of the License at
14068 *
14069 * http://www.apache.org/licenses/LICENSE-2.0
14070 *
14071 * Unless required by applicable law or agreed to in writing, software
14072 * distributed under the License is distributed on an "AS IS" BASIS,
14073 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14074 * See the License for the specific language governing permissions and
14075 * limitations under the License.
14076 * =============================================================================
14077 */
/**
 * Computes softplus of the input `tf.Tensor` element-wise: `log(exp(x) + 1)`
 *
 * ```js
 * const x = tf.tensor1d([0, 1, -1, .7]);
 *
 * x.softplus().print(); // or tf.softplus(x)
 * ```
 * @param x The input tensor.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function softplus_(x) {
    // Normalize TensorLike input, then dispatch to the Softplus kernel.
    const input = convertToTensor(x, 'x', 'softplus');
    return ENGINE.runKernel(Softplus$1, { x: input });
}
const softplus$2 = /* @__PURE__ */ op({ softplus_ });
14096
14097 /**
14098 * @license
14099 * Copyright 2018 Google LLC. All Rights Reserved.
14100 * Licensed under the Apache License, Version 2.0 (the "License");
14101 * you may not use this file except in compliance with the License.
14102 * You may obtain a copy of the License at
14103 *
14104 * http://www.apache.org/licenses/LICENSE-2.0
14105 *
14106 * Unless required by applicable law or agreed to in writing, software
14107 * distributed under the License is distributed on an "AS IS" BASIS,
14108 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14109 * See the License for the specific language governing permissions and
14110 * limitations under the License.
14111 * =============================================================================
14112 */
/**
 * Computes log sigmoid of the input `tf.Tensor` element-wise:
 * `logSigmoid(x)`. For numerical stability, we use `-tf.softplus(-x)`.
 *
 * ```js
 * const x = tf.tensor1d([0, 1, -1, .7]);
 *
 * x.logSigmoid().print(); // or tf.logSigmoid(x)
 * ```
 * @param x The input tensor.
 *
 * @doc {heading: 'Operations', subheading: 'Basic math'}
 */
function logSigmoid_(x) {
    const $x = convertToTensor(x, 'x', 'logSigmoid');
    // Use a custom gradient to maintain previous implementation.
    // There is no LogSigmoid kernel in TF so we can't use engine.runKernel
    // directly
    const customOp = customGrad((x) => {
        // TODO(yassogba) we can remove the chained softplus call here only
        // after backends have modularized softplus, at which point we can call
        // engine runKernel(..., Softplus, ...) directly.
        // Forward value: logSigmoid(x) = -softplus(-x).
        const value = neg$2(softplus$2(neg$2(x)));
        const gradFunc = (dy) => {
            // d/dx[-softplus(-x)] = sigmoid(-x); scale by upstream dy.
            const derX = mul(dy, sigmoid$2(neg$2(x)));
            return derX;
        };
        return { value, gradFunc };
    });
    return customOp($x);
}
const logSigmoid = /* @__PURE__ */ op({ logSigmoid_ });
14145
14146 /**
14147 * @license
14148 * Copyright 2020 Google LLC. All Rights Reserved.
14149 * Licensed under the Apache License, Version 2.0 (the "License");
14150 * you may not use this file except in compliance with the License.
14151 * You may obtain a copy of the License at
14152 *
14153 * http://www.apache.org/licenses/LICENSE-2.0
14154 *
14155 * Unless required by applicable law or agreed to in writing, software
14156 * distributed under the License is distributed on an "AS IS" BASIS,
14157 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14158 * See the License for the specific language governing permissions and
14159 * limitations under the License.
14160 * =============================================================================
14161 */
/**
 * Subtracts two `tf.Tensor`s element-wise, A - B. Supports broadcasting.
 *
 * ```js
 * const a = tf.tensor1d([10, 20, 30, 40]);
 * const b = tf.tensor1d([1, 2, 3, 4]);
 *
 * a.sub(b).print(); // or tf.sub(a, b)
 * ```
 *
 * ```js
 * // Broadcast subtract a with b.
 * const a = tf.tensor1d([10, 20, 30, 40]);
 * const b = tf.scalar(5);
 *
 * a.sub(b).print(); // or tf.sub(a, b)
 * ```
 * @param a The first `tf.Tensor` to subtract from.
 * @param b The second `tf.Tensor` to be subtracted. Must have the same dtype as
 * `a`.
 *
 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
 */
function sub_(a, b) {
    let minuend = convertToTensor(a, 'a', 'sub');
    let subtrahend = convertToTensor(b, 'b', 'sub');
    // Upcast operands to a common dtype before subtracting.
    [minuend, subtrahend] = makeTypesMatch(minuend, subtrahend);
    return ENGINE.runKernel(Sub, { a: minuend, b: subtrahend });
}
const sub$2 = /* @__PURE__ */ op({ sub_ });
14193
14194 /**
14195 * @license
14196 * Copyright 2020 Google Inc. All Rights Reserved.
14197 * Licensed under the Apache License, Version 2.0 (the "License");
14198 * you may not use this file except in compliance with the License.
14199 * You may obtain a copy of the License at
14200 *
14201 * http://www.apache.org/licenses/LICENSE-2.0
14202 *
14203 * Unless required by applicable law or agreed to in writing, software
14204 * distributed under the License is distributed on an "AS IS" BASIS,
14205 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14206 * See the License for the specific language governing permissions and
14207 * limitations under the License.
14208 * =============================================================================
14209 */
14210 /**
14211 * Computes the log softmax.
14212 *
14213 * ```js
14214 * const a = tf.tensor1d([1, 2, 3]);
14215 *
14216 * a.logSoftmax().print(); // or tf.logSoftmax(a)
14217 * ```
14218 *
14219 * ```js
14220 * const a = tf.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]);
14221 *
14222 * a.logSoftmax().print(); // or tf.logSoftmax(a)
14223 * ```
14224 *
14225 * @param logits The logits array.
14226 * @param axis The dimension softmax would be performed on. Defaults to `-1`
14227 * which indicates the last dimension.
14228 *
14229 * @doc {heading: 'Operations', subheading: 'Normalization'}
14230 */
    function logSoftmax_(logits, axis = -1) {
        const $logits = convertToTensor(logits, 'logits', 'logSoftmax');
        // Resolve the default axis (-1) to the concrete last-dimension index.
        if (axis === -1) {
            axis = $logits.rank - 1;
        }
        // The implementation below reduces over the innermost dimension only.
        if (axis !== $logits.rank - 1) {
            throw Error('Log Softmax along a non-last dimension is not yet supported. ' +
                `Logits was rank ${$logits.rank} and axis was ${axis}`);
        }
        // Reference forward implementation, kept for when backends expose a
        // LogSoftmax kernel:
        // const forward: ForwardFunc<Tensor> = (backend, save) => {
        //   const keepDims = true;
        //   const xMax = max(logits, axis, true);
        //   const shifted = sub(logits, xMax);
        //   const value =
        //       sub(cast(shifted, 'float32'), log(sum(exp(shifted), axis,
        //       keepDims)));
        //   save([value]);
        //   return value;
        // };
        // Use a custom gradient for numerical stability.
        const customOp = customGrad((logits, save) => {
            const keepDims = true;
            // Shift by the per-row max before exponentiating so exp() cannot
            // overflow; the shift cancels out in the log-softmax result.
            const xMax = max$3(logits, axis, true);
            const shifted = sub$2(logits, xMax);
            // log softmax(x) = (x - max) - log(sum(exp(x - max)))
            const value = sub$2(cast$3(shifted, 'float32'), log$2(sum$3(exp$2(shifted), axis, keepDims)));
            // Save the forward value: the gradient recovers the softmax as
            // exp(value) instead of recomputing it from the logits.
            save([value]);
            const gradFunc = (dy, saved) => {
                const [value] = saved;
                const keepDims = true;
                const softmax = exp$2(value);
                // d(log softmax)/dx applied to dy: dy - sum(dy) * softmax.
                return sub$2(dy, mul(sum$3(dy, axis, keepDims), softmax));
            };
            return { value, gradFunc };
        });
        return customOp($logits);
        // TODO Use Engine.runKernel when CPU/WebGL/WASM backends implement this.
        // const inputs: LogSoftmaxInputs = {logits: $logits};
        // const attrs: LogSoftmaxAttrs = {axis};
        // return ENGINE.runKernel(
        //     LogSoftmax, inputs as unknown as NamedTensorMap,
        //     attrs as unknown as NamedAttrMap);
    }
    const logSoftmax = /* @__PURE__ */ op({ logSoftmax_ });
14274
14275 /**
14276 * @license
14277 * Copyright 2020 Google LLC. All Rights Reserved.
14278 * Licensed under the Apache License, Version 2.0 (the "License");
14279 * you may not use this file except in compliance with the License.
14280 * You may obtain a copy of the License at
14281 *
14282 * http://www.apache.org/licenses/LICENSE-2.0
14283 *
14284 * Unless required by applicable law or agreed to in writing, software
14285 * distributed under the License is distributed on an "AS IS" BASIS,
14286 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14287 * See the License for the specific language governing permissions and
14288 * limitations under the License.
14289 * =============================================================================
14290 */
14291 /**
14292 * Computes the log(sum(exp(elements across the reduction dimensions))).
14293 *
14294 * Reduces the input along the dimensions given in `axis`. Unless `keepDims`
14295 * is true, the rank of the array is reduced by 1 for each entry in `axis`.
14296 * If `keepDims` is true, the reduced dimensions are retained with length 1.
14297 * If `axis` has no entries, all dimensions are reduced, and an array with a
14298 * single element is returned.
14299 *
14300 * ```js
14301 * const x = tf.tensor1d([1, 2, 3]);
14302 *
14303 * x.logSumExp().print(); // or tf.logSumExp(x)
14304 * ```
14305 *
14306 * ```js
14307 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
14308 *
14309 * const axis = 1;
14310 * x.logSumExp(axis).print(); // or tf.logSumExp(a, axis)
14311 * ```
14312 * @param x The input tensor.
14313 * @param axis The dimension(s) to reduce. If null (the default),
14314 * reduces all dimensions.
14315 * @param keepDims If true, retains reduced dimensions with length
14316 * of 1. Defaults to false.
14317 *
14318 * @doc {heading: 'Operations', subheading: 'Reduction'}
14319 */
14320 function logSumExp_(x, axis = null, keepDims = false) {
14321 const $x = convertToTensor(x, 'x', 'logSumExp');
14322 const axes = parseAxisParam(axis, $x.shape);
14323 const xMax = max$3($x, axes, true /* keepDims */);
14324 const a = sub$2($x, xMax);
14325 const b = exp$2(a);
14326 const c = sum$3(b, axes);
14327 const d = log$2(c);
14328 const res = add$3(reshape$3(xMax, d.shape), d);
14329 if (keepDims) {
14330 const newShape = expandShapeToKeepDim(res.shape, axes);
14331 return reshape$3(res, newShape);
14332 }
14333 return res;
14334 }
14335 const logSumExp = /* @__PURE__ */ op({ logSumExp_ });
14336
14337 /**
14338 * @license
14339 * Copyright 2020 Google LLC. All Rights Reserved.
14340 * Licensed under the Apache License, Version 2.0 (the "License");
14341 * you may not use this file except in compliance with the License.
14342 * You may obtain a copy of the License at
14343 *
14344 * http://www.apache.org/licenses/LICENSE-2.0
14345 *
14346 * Unless required by applicable law or agreed to in writing, software
14347 * distributed under the License is distributed on an "AS IS" BASIS,
14348 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14349 * See the License for the specific language governing permissions and
14350 * limitations under the License.
14351 * =============================================================================
14352 */
14353 /**
14354 * Returns the truth value of `a AND b` element-wise. Supports broadcasting.
14355 *
14356 * ```js
14357 * const a = tf.tensor1d([false, false, true, true], 'bool');
14358 * const b = tf.tensor1d([false, true, false, true], 'bool');
14359 *
14360 * a.logicalAnd(b).print();
14361 * ```
14362 *
14363 * @param a The first input tensor. Must be of dtype bool.
14364 * @param b The second input tensor. Must be of dtype bool.
14365 *
14366 * @doc {heading: 'Operations', subheading: 'Logical'}
14367 */
14368 function logicalAnd_(a, b) {
14369 const $a = convertToTensor(a, 'a', 'logicalAnd', 'bool');
14370 const $b = convertToTensor(b, 'b', 'logicalAnd', 'bool');
14371 assertAndGetBroadcastShape($a.shape, $b.shape);
14372 const inputs = { a: $a, b: $b };
14373 return ENGINE.runKernel(LogicalAnd, inputs);
14374 }
14375 const logicalAnd$2 = /* @__PURE__ */ op({ logicalAnd_ });
14376
14377 /**
14378 * @license
14379 * Copyright 2020 Google LLC. All Rights Reserved.
14380 * Licensed under the Apache License, Version 2.0 (the "License");
14381 * you may not use this file except in compliance with the License.
14382 * You may obtain a copy of the License at
14383 *
14384 * http://www.apache.org/licenses/LICENSE-2.0
14385 *
14386 * Unless required by applicable law or agreed to in writing, software
14387 * distributed under the License is distributed on an "AS IS" BASIS,
14388 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14389 * See the License for the specific language governing permissions and
14390 * limitations under the License.
14391 * =============================================================================
14392 */
14393 /**
14394 * Returns the truth value of `NOT x` element-wise.
14395 *
14396 * ```js
14397 * const a = tf.tensor1d([false, true], 'bool');
14398 *
14399 * a.logicalNot().print();
14400 * ```
14401 *
14402 * @param x The input tensor. Must be of dtype 'bool'.
14403 *
14404 * @doc {heading: 'Operations', subheading: 'Logical'}
14405 */
14406 function logicalNot_(x) {
14407 const $x = convertToTensor(x, 'x', 'logicalNot', 'bool');
14408 const inputs = { x: $x };
14409 return ENGINE.runKernel(LogicalNot, inputs);
14410 }
14411 const logicalNot$2 = /* @__PURE__ */ op({ logicalNot_ });
14412
14413 /**
14414 * @license
14415 * Copyright 2020 Google LLC. All Rights Reserved.
14416 * Licensed under the Apache License, Version 2.0 (the "License");
14417 * you may not use this file except in compliance with the License.
14418 * You may obtain a copy of the License at
14419 *
14420 * http://www.apache.org/licenses/LICENSE-2.0
14421 *
14422 * Unless required by applicable law or agreed to in writing, software
14423 * distributed under the License is distributed on an "AS IS" BASIS,
14424 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14425 * See the License for the specific language governing permissions and
14426 * limitations under the License.
14427 * =============================================================================
14428 */
14429 /**
14430 * Returns the truth value of `a OR b` element-wise. Supports broadcasting.
14431 *
14432 * ```js
14433 * const a = tf.tensor1d([false, false, true, true], 'bool');
14434 * const b = tf.tensor1d([false, true, false, true], 'bool');
14435 *
14436 * a.logicalOr(b).print();
14437 * ```
14438 * @param a The first input tensor. Must be of dtype bool.
14439 * @param b The second input tensor. Must be of dtype bool.
14440 *
14441 * @doc {heading: 'Operations', subheading: 'Logical'}
14442 */
14443 function logicalOr_(a, b) {
14444 const $a = convertToTensor(a, 'a', 'logicalOr', 'bool');
14445 const $b = convertToTensor(b, 'b', 'logicalOr', 'bool');
14446 assertAndGetBroadcastShape($a.shape, $b.shape);
14447 const inputs = { a: $a, b: $b };
14448 return ENGINE.runKernel(LogicalOr, inputs);
14449 }
14450 const logicalOr$2 = /* @__PURE__ */ op({ logicalOr_ });
14451
14452 /**
14453 * @license
14454 * Copyright 2020 Google LLC. All Rights Reserved.
14455 * Licensed under the Apache License, Version 2.0 (the "License");
14456 * you may not use this file except in compliance with the License.
14457 * You may obtain a copy of the License at
14458 *
14459 * http://www.apache.org/licenses/LICENSE-2.0
14460 *
14461 * Unless required by applicable law or agreed to in writing, software
14462 * distributed under the License is distributed on an "AS IS" BASIS,
14463 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14464 * See the License for the specific language governing permissions and
14465 * limitations under the License.
14466 * =============================================================================
14467 */
14468 /**
14469 * Returns the truth value of `a XOR b` element-wise. Supports broadcasting.
14470 *
14471 * ```js
14472 * const a = tf.tensor1d([false, false, true, true], 'bool');
14473 * const b = tf.tensor1d([false, true, false, true], 'bool');
14474 *
14475 * a.logicalXor(b).print();
14476 * ```
14477 *
14478 * @param a The first input tensor. Must be of dtype bool.
14479 * @param b The second input tensor. Must be of dtype bool.
14480 *
14481 * @doc {heading: 'Operations', subheading: 'Logical'}
14482 */
14483 function logicalXor_(a, b) {
14484 const $a = convertToTensor(a, 'a', 'logicalXor', 'bool');
14485 const $b = convertToTensor(b, 'b', 'logicalXor', 'bool');
14486 assertAndGetBroadcastShape($a.shape, $b.shape);
14487 // x ^ y = (x | y) & ~(x & y)
14488 return logicalAnd$2(logicalOr$2(a, b), logicalNot$2(logicalAnd$2(a, b)));
14489 }
14490 const logicalXor = /* @__PURE__ */ op({ logicalXor_ });
14491
14492 /**
14493 * @license
14494 * Copyright 2022 Google LLC. All Rights Reserved.
14495 * Licensed under the Apache License, Version 2.0 (the "License");
14496 * you may not use this file except in compliance with the License.
14497 * You may obtain a copy of the License at
14498 *
14499 * http://www.apache.org/licenses/LICENSE-2.0
14500 *
14501 * Unless required by applicable law or agreed to in writing, software
14502 * distributed under the License is distributed on an "AS IS" BASIS,
14503 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14504 * See the License for the specific language governing permissions and
14505 * limitations under the License.
14506 * =============================================================================
14507 */
    // 2**31 — exclusive upper bound for sizes that must be indexable by int32.
    // NOTE(review): numerically this is INT32_MAX + 1, not INT32_MAX; the
    // size checks that use it compare with `>=`, treating it as an exclusive
    // limit, so the bound itself is consistent with that usage.
    const INT32_MAX$1 = 2147483648;
14509 /**
14510 * Searches for where a value would go in a sorted sequence.
14511 *
14512 * This is not a method for checking containment (like javascript in).
14513 *
14514 * The typical use case for this operation is "binning", "bucketing", or
14515 * "discretizing". The values are assigned to bucket-indices based on the edges
14516 * listed in 'sortedSequence'. This operation returns the bucket-index for each
14517 * value.
14518 *
14519 * The side argument controls which index is returned if a value lands exactly
14520 * on an edge.
14521 *
14522 * The axis is not settable for this operation. It always operates on the
14523 * innermost dimension (axis=-1). The operation will accept any number of outer
14524 * dimensions.
14525 *
14526 * Note: This operation assumes that 'sortedSequence' is sorted along the
14527 * innermost axis, maybe using 'sort(..., axis=-1)'. If the sequence is not
14528 * sorted no error is raised and the content of the returned tensor is not well
14529 * defined.
14530 *
14531 * ```js
14532 * const edges = tf.tensor1d([-1, 3.3, 9.1, 10.0]);
14533 * let values = tf.tensor1d([0.0, 4.1, 12.0]);
14534 * const result1 = tf.searchSorted(edges, values, 'left');
14535 * result1.print(); // [1, 2, 4]
14536 *
14537 * const seq = tf.tensor1d([0, 3, 9, 10, 10]);
14538 * values = tf.tensor1d([0, 4, 10]);
14539 * const result2 = tf.searchSorted(seq, values, 'left');
14540 * result2.print(); // [0, 2, 3]
14541 * const result3 = tf.searchSorted(seq, values, 'right');
14542 * result3.print(); // [1, 2, 5]
14543 *
14544 * const sortedSequence = tf.tensor2d([[0., 3., 8., 9., 10.],
14545 * [1., 2., 3., 4., 5.]]);
14546 * values = tf.tensor2d([[9.8, 2.1, 4.3],
14547 * [0.1, 6.6, 4.5, ]]);
14548 * const result4 = tf.searchSorted(sortedSequence, values, 'left');
14549 * result4.print(); // [[4, 1, 2], [0, 5, 4]]
14550 * ```
14551 * @param sortedSequence: N-D. Sorted sequence.
14552 * @param values: N-D. Search values.
14553 * @param side: 'left'|'right'. Defaults to 'left'. 'left' corresponds to lower
14554 * bound and 'right' to upper bound.
14555 * @return An N-D int32 tensor the size of values containing the result of
14556 * applying either lower bound or upper bound (depending on side) to each
14557 * value. The result is not a global index to the entire Tensor, but the
14558 * index in the last dimension.
14559 * @doc {heading: 'Operations', subheading: 'Evaluation'}
14560 */
14561 function searchSorted_(sortedSequence, values, side = 'left') {
14562 const $sortedSequence = convertToTensor(sortedSequence, 'sortedSequence', 'searchSorted');
14563 const $values = convertToTensor(values, 'values', 'searchSorted');
14564 const sequenceSize = $sortedSequence.shape[$sortedSequence.shape.length - 1];
14565 const valuesSize = $values.shape[$values.shape.length - 1];
14566 const $sortedSequence2D = reshape$3($sortedSequence, [-1, sequenceSize]);
14567 const $values2D = reshape$3($values, [-1, valuesSize]);
14568 if ($sortedSequence2D.rank < 2) {
14569 throw new Error(`Sorted input argument must be at least 2-dimensional`);
14570 }
14571 if ($sortedSequence2D.shape[0] !== $values2D.shape[0]) {
14572 throw new Error(`Leading dimension of 'sortedSequence' and 'values' must match.`);
14573 }
14574 if (sizeFromShape($values2D.shape) >= INT32_MAX$1) {
14575 throw new Error(`values tensor size must less than ${INT32_MAX$1}`);
14576 }
14577 if ($sortedSequence2D.shape[1] >= INT32_MAX$1) {
14578 throw new Error(`trailing dim_size must less than ${INT32_MAX$1} for int32 output type, was ${$sortedSequence2D.shape[1]}`);
14579 }
14580 const inputs = {
14581 sortedSequence: $sortedSequence2D,
14582 values: $values2D,
14583 };
14584 const attrs = { side };
14585 return ENGINE.runKernel(SearchSorted, inputs, attrs);
14586 }
14587 const searchSorted$2 = /* @__PURE__ */ op({ searchSorted_ });
14588
14589 /**
14590 * @license
14591 * Copyright 2022 Google LLC. All Rights Reserved.
14592 * Licensed under the Apache License, Version 2.0 (the "License");
14593 * you may not use this file except in compliance with the License.
14594 * You may obtain a copy of the License at
14595 *
14596 * http://www.apache.org/licenses/LICENSE-2.0
14597 *
14598 * Unless required by applicable law or agreed to in writing, software
14599 * distributed under the License is distributed on an "AS IS" BASIS,
14600 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14601 * See the License for the specific language governing permissions and
14602 * limitations under the License.
14603 * =============================================================================
14604 */
14605 /**
14606 * Searches for where a value would go in a sorted sequence.
14607 *
14608 * This is not a method for checking containment (like javascript in).
14609 *
14610 * The typical use case for this operation is "binning", "bucketing", or
14611 * "discretizing". The values are assigned to bucket-indices based on the edges
14612 * listed in 'sortedSequence'. This operation returns the bucket-index for each
14613 * value.
14614 *
14615 * The index returned corresponds to the first edge greater than or equal to the
14616 * value.
14617 *
14618 * The axis is not settable for this operation. It always operates on the
14619 * innermost dimension (axis=-1). The operation will accept any number of outer
14620 * dimensions.
14621 *
14622 * Note: This operation assumes that 'lowerBound' is sorted along the
14623 * innermost axis, maybe using 'sort(..., axis=-1)'. If the sequence is not
14624 * sorted no error is raised and the content of the returned tensor is not well
14625 * defined.
14626 *
14627 * ```js
14628 * const edges = tf.tensor1d([-1, 3.3, 9.1, 10.0]);
14629 * let values = tf.tensor1d([0.0, 4.1, 12.0]);
14630 * const result1 = tf.lowerBound(edges, values);
14631 * result1.print(); // [1, 2, 4]
14632 *
14633 * const seq = tf.tensor1d([0, 3, 9, 10, 10]);
14634 * values = tf.tensor1d([0, 4, 10]);
14635 * const result2 = tf.lowerBound(seq, values);
14636 * result2.print(); // [0, 2, 3]
14637 *
14638 * const sortedSequence = tf.tensor2d([[0., 3., 8., 9., 10.],
14639 * [1., 2., 3., 4., 5.]]);
14640 * values = tf.tensor2d([[9.8, 2.1, 4.3],
14641 * [0.1, 6.6, 4.5, ]]);
14642 * const result3 = tf.lowerBound(sortedSequence, values);
14643 * result3.print(); // [[4, 1, 2], [0, 5, 4]]
14644 * ```
14645 * @param sortedSequence: N-D. Sorted sequence.
14646 * @param values: N-D. Search values.
14647 * @return An N-D int32 tensor the size of values containing the result of
14648 * applying lower bound to each value. The result is not a global index to
14649 * the entire Tensor, but the index in the last dimension.
14650 * @doc {heading: 'Operations', subheading: 'Evaluation'}
14651 */
    function lowerBound$1(sortedSequence, values) {
        // Lower bound is searchSorted with side 'left': per the searchSorted
        // contract, this yields, for each value, an index along the innermost
        // axis such that values landing exactly on an edge take the left index.
        return searchSorted$2(sortedSequence, values, 'left');
    }
14655
14656 /**
14657 * @license
14658 * Copyright 2020 Google LLC. All Rights Reserved.
14659 * Licensed under the Apache License, Version 2.0 (the "License");
14660 * you may not use this file except in compliance with the License.
14661 * You may obtain a copy of the License at
14662 *
14663 * http://www.apache.org/licenses/LICENSE-2.0
14664 *
14665 * Unless required by applicable law or agreed to in writing, software
14666 * distributed under the License is distributed on an "AS IS" BASIS,
14667 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14668 * See the License for the specific language governing permissions and
14669 * limitations under the License.
14670 * =============================================================================
14671 */
14672 /**
14673 * Computes the 2D max pooling of an image.
14674 *
14675 * @param x The input tensor, of rank 4 or rank 3 of shape
14676 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
14677 * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
14678 * `filterSize` is a single number, then `filterHeight == filterWidth`.
14679 * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
14680 * `strides` is a single number, then `strideHeight == strideWidth`.
14681 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
14682 * in which we sample input values across the height and width dimensions
14683 * in dilated pooling. Defaults to `[1, 1]`. If `dilations` is a single
14684 * number, then `dilationHeight == dilationWidth`. If it is greater than
14685 * 1, then all values of `strides` must be 1.
14686 * @param pad The type of padding algorithm.
14687 * - `same` and stride 1: output will be of same size as input,
14688 * regardless of filter size.
14689 * - `valid`: output will be smaller than input if filter is larger
14690 * than 1x1.
14691 * - For more info, see this guide:
14692 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
14693 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
14694 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
14695 * provided, it will default to truncate.
14696 */
14697 function maxPool_(x, filterSize, strides, pad, dimRoundingMode) {
14698 const $x = convertToTensor(x, 'x', 'maxPool');
14699 const dilations = 1;
14700 let x4D = $x;
14701 let reshapedTo4D = false;
14702 if ($x.rank === 3) {
14703 reshapedTo4D = true;
14704 x4D = reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
14705 }
14706 assert$1(x4D.rank === 4, () => `Error in maxPool: input must be rank 4 but got rank ${x4D.rank}.`);
14707 assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. ' +
14708 `Got strides ${strides} and dilations '${dilations}'`);
14709 checkPadOnDimRoundingMode('maxPool', pad, dimRoundingMode);
14710 const inputs = { x: x4D };
14711 const attrs = { filterSize, strides, pad, dimRoundingMode };
14712 // tslint:disable-next-line: no-unnecessary-type-assertion
14713 const res = ENGINE.runKernel(MaxPool, inputs, attrs);
14714 if (reshapedTo4D) {
14715 return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
14716 }
14717 return res;
14718 }
14719 const maxPool$2 = /* @__PURE__ */ op({ maxPool_ });
14720
14721 /**
14722 * @license
14723 * Copyright 2020 Google LLC. All Rights Reserved.
14724 * Licensed under the Apache License, Version 2.0 (the "License");
14725 * you may not use this file except in compliance with the License.
14726 * You may obtain a copy of the License at
14727 *
14728 * http://www.apache.org/licenses/LICENSE-2.0
14729 *
14730 * Unless required by applicable law or agreed to in writing, software
14731 * distributed under the License is distributed on an "AS IS" BASIS,
14732 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14733 * See the License for the specific language governing permissions and
14734 * limitations under the License.
14735 * =============================================================================
14736 */
14737 /**
14738 * Computes the 3D max pooling.
14739 *
14740 * ```js
14741 * const x = tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]);
14742 * const result = tf.maxPool3d(x, 2, 1, 'valid');
14743 * result.print();
14744 * ```
14745 *
14746 * @param x The input tensor, of rank 5 or rank 4 of shape
14747 * `[batch, depth, height, width, inChannels]`.
14748 * @param filterSize The filter size:
14749 * `[filterDepth, filterHeight, filterWidth]`.
14750 * If `filterSize` is a single number,
14751 * then `filterDepth == filterHeight == filterWidth`.
14752 * @param strides The strides of the pooling:
14753 * `[strideDepth, strideHeight, strideWidth]`.
14754 * If `strides` is a single number,
14755 * then `strideDepth == strideHeight == strideWidth`.
14756 * @param pad The type of padding algorithm.
14757 * - `same` and stride 1: output will be of same size as input,
14758 * regardless of filter size.
14759 * - `valid`: output will be smaller than input if filter is larger
14760 * than 1*1x1.
14761 * - For more info, see this guide:
14762 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
14763 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
14764 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
14765 * provided, it will default to truncate.
14766 * @param dataFormat An optional string from: "NDHWC", "NCDHW". Defaults to
14767 * "NDHWC". Specify the data format of the input and output data. With the
14768 * default format "NDHWC", the data is stored in the order of: [batch,
14769 * depth, height, width, channels]. Only "NDHWC" is currently supported.
14770 * @doc {heading: 'Operations', subheading: 'Convolution'}
14771 */
14772 function maxPool3d_(x, filterSize = [1, 1, 1], strides, pad, dimRoundingMode, dataFormat = 'NDHWC') {
14773 const $x = convertToTensor(x, 'x', 'maxPool3d');
14774 let x5D = $x;
14775 let reshapedTo5D = false;
14776 if ($x.rank === 4) {
14777 reshapedTo5D = true;
14778 x5D = reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]]);
14779 }
14780 assert$1(x5D.rank === 5, () => `Error in maxPool3d: x must be rank 5 but got rank ${x5D.rank}.`);
14781 assert$1(dataFormat === 'NDHWC', () => `Error in maxPool3d: Only NDHWC is currently supported, ` +
14782 `but got dataFormat of ${dataFormat}`);
14783 checkPadOnDimRoundingMode('maxPool3d', pad, dimRoundingMode);
14784 const inputs = { x: x5D };
14785 const attrs = { filterSize, strides, pad, dimRoundingMode, dataFormat };
14786 // tslint:disable-next-line: no-unnecessary-type-assertion
14787 const res = ENGINE.runKernel(MaxPool3D, inputs, attrs);
14788 if (reshapedTo5D) {
14789 return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
14790 }
14791 return res;
14792 }
14793 const maxPool3d$1 = /* @__PURE__ */ op({ maxPool3d_ });
14794
14795 /**
14796 * @license
14797 * Copyright 2018 Google LLC. All Rights Reserved.
14798 * Licensed under the Apache License, Version 2.0 (the "License");
14799 * you may not use this file except in compliance with the License.
14800 * You may obtain a copy of the License at
14801 *
14802 * http://www.apache.org/licenses/LICENSE-2.0
14803 *
14804 * Unless required by applicable law or agreed to in writing, software
14805 * distributed under the License is distributed on an "AS IS" BASIS,
14806 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14807 * See the License for the specific language governing permissions and
14808 * limitations under the License.
14809 * =============================================================================
14810 */
14811 /**
14812 * Computes the 2D max pooling of an image with Argmax index.
14813 * The indices in argmax are flattened, so that a maximum value at position `[b,
14814 * y, x, c]` becomes flattened index: `(y * width + x) * channels + c` if
14815 * include_batch_in_index is False; `((b * height + y) * width + x) * channels
14816 * +c` if include_batch_in_index is True.
14817 *
14818 * The indices returned are always in `[0, height) x [0, width)` before
14819 * flattening.
14820 *
14821 * @param x The input tensor, of rank 4 or rank 3 of shape
14822 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
14823 * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
14824 * `filterSize` is a single number, then `filterHeight == filterWidth`.
14825 * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
14826 * `strides` is a single number, then `strideHeight == strideWidth`.
14827 * @param dataFormat An optional string from: "NDHWC", "NCDHW". Defaults to
14828 * "NDHWC". Specify the data format of the input and output data. With the
14829 * default format "NDHWC", the data is stored in the order of: [batch,
14830 * depth, height, width, channels]. Only "NDHWC" is currently supported.
14831 * @param pad The type of padding algorithm.
14832 * - `same` and stride 1: output will be of same size as input,
14833 * regardless of filter size.
14834 * - `valid`: output will be smaller than input if filter is larger
14835 * than 1x1.
14836 * - For more info, see this guide:
14837 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
14838 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
14839 * @param includeBatchIndex Defaults to False. Whether to include batch
14840 * dimension in flattened index of argmax.
14841 *
14842 * @doc {heading: 'Operations', subheading: 'Convolution'}
14843 */
14844 function maxPoolWithArgmax_(x, filterSize, strides, pad, includeBatchInIndex = false) {
14845 const $x = convertToTensor(x, 'x', 'maxPoolWithArgmax');
14846 const inputs = { x: $x };
14847 const attrs = { filterSize, strides, pad, includeBatchInIndex };
14848 // tslint:disable-next-line: no-unnecessary-type-assertion
14849 const result = ENGINE.runKernel(MaxPoolWithArgmax, inputs, attrs);
14850 return { result: result[0], indexes: result[1] };
14851 }
14852 const maxPoolWithArgmax = /* @__PURE__ */ op({ maxPoolWithArgmax_ });
14853
14854 /**
14855 * @license
14856 * Copyright 2020 Google LLC. All Rights Reserved.
14857 * Licensed under the Apache License, Version 2.0 (the "License");
14858 * you may not use this file except in compliance with the License.
14859 * You may obtain a copy of the License at
14860 *
14861 * http://www.apache.org/licenses/LICENSE-2.0
14862 *
14863 * Unless required by applicable law or agreed to in writing, software
14864 * distributed under the License is distributed on an "AS IS" BASIS,
14865 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14866 * See the License for the specific language governing permissions and
14867 * limitations under the License.
14868 * =============================================================================
14869 */
14870 /**
14871 * Returns the max of a and b (`a > b ? a : b`) element-wise.
14872 * Supports broadcasting.
14873 *
14874 * We also expose `tf.maximumStrict` which has the same signature as this op and
14875 * asserts that `a` and `b` are the same shape (does not broadcast).
14876 *
14877 * ```js
14878 * const a = tf.tensor1d([1, 4, 3, 16]);
14879 * const b = tf.tensor1d([1, 2, 9, 4]);
14880 *
14881 * a.maximum(b).print(); // or tf.maximum(a, b)
14882 * ```
14883 *
14884 * ```js
14885 * // Broadcast maximum a with b.
14886 * const a = tf.tensor1d([2, 4, 6, 8]);
14887 * const b = tf.scalar(5);
14888 *
14889 * a.maximum(b).print(); // or tf.maximum(a, b)
14890 * ```
14891 *
14892 * @param a The first tensor.
14893 * @param b The second tensor. Must have the same type as `a`.
14894 *
14895 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
14896 */
function maximum_(a, b) {
    // Convert inputs to tensors and reconcile their dtypes before comparing.
    let $a = convertToTensor(a, 'a', 'maximum');
    let $b = convertToTensor(b, 'b', 'maximum');
    [$a, $b] = makeTypesMatch($a, $b);
    // Boolean tensors are widened to int32 because the kernel compares numbers.
    const isBool = $a.dtype === 'bool';
    if (isBool) {
        $a = cast$3($a, 'int32');
        $b = cast$3($b, 'int32');
    }
    // Validates that the two shapes are broadcast-compatible (throws if not).
    assertAndGetBroadcastShape($a.shape, $b.shape);
    return ENGINE.runKernel(Maximum$1, { a: $a, b: $b });
}
const maximum$4 = /* @__PURE__ */ op({ maximum_ });
14910
14911 /**
14912 * @license
14913 * Copyright 2020 Google Inc. All Rights Reserved.
14914 * Licensed under the Apache License, Version 2.0 (the "License");
14915 * you may not use this file except in compliance with the License.
14916 * You may obtain a copy of the License at
14917 *
14918 * http://www.apache.org/licenses/LICENSE-2.0
14919 *
14920 * Unless required by applicable law or agreed to in writing, software
14921 * distributed under the License is distributed on an "AS IS" BASIS,
14922 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14923 * See the License for the specific language governing permissions and
14924 * limitations under the License.
14925 * =============================================================================
14926 */
14927 /**
14928 * Computes the mean of elements across dimensions of a `tf.Tensor`.
14929 *
14930 * Reduces `x` along the dimensions given in `axis`. Unless `keepDims` is
14931 * true, the rank of the `tf.Tensor` is reduced by 1 for each entry in `axis`.
14932 * If `keepDims` is true, the reduced dimensions are retained with length 1.
14933 * If `axis` has no entries, all dimensions are reduced, and a `tf.Tensor` with
14934 * a single element is returned.
14935 *
14936 * ```js
14937 * const x = tf.tensor1d([1, 2, 3]);
14938 *
14939 * x.mean().print(); // or tf.mean(a)
14940 * ```
14941 *
14942 * ```js
14943 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
14944 *
14945 * const axis = 1;
14946 * x.mean(axis).print(); // or tf.mean(x, axis)
14947 * ```
14948 *
14949 * @param x The input tensor.
14950 * @param axis The dimension(s) to reduce. By default it reduces
14951 * all dimensions.
14952 * @param keepDims If true, retains reduced dimensions with size 1.
14953 *
14954 * @doc {heading: 'Operations', subheading: 'Reduction'}
14955 */
function mean_(x, axis = null, keepDims = false) {
    // Reduce `x` by averaging over `axis` (all axes when null); the kernel
    // handles axis normalization and the keepDims shape logic.
    const $x = convertToTensor(x, 'x', 'mean');
    return ENGINE.runKernel(Mean, { x: $x }, { axis, keepDims });
}
const mean$3 = /* @__PURE__ */ op({ mean_ });
14963
14964 /**
14965 * @license
14966 * Copyright 2018 Google LLC. All Rights Reserved.
14967 * Licensed under the Apache License, Version 2.0 (the "License");
14968 * you may not use this file except in compliance with the License.
14969 * You may obtain a copy of the License at
14970 *
14971 * http://www.apache.org/licenses/LICENSE-2.0
14972 *
14973 * Unless required by applicable law or agreed to in writing, software
14974 * distributed under the License is distributed on an "AS IS" BASIS,
14975 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14976 * See the License for the specific language governing permissions and
14977 * limitations under the License.
14978 * =============================================================================
14979 */
14980 /**
14981 * Creates a `tf.Tensor` with all elements set to 0.
14982 *
14983 * ```js
14984 * tf.zeros([2, 2]).print();
14985 * ```
14986 *
14987 * @param shape An array of integers defining the output tensor shape.
14988 * @param dtype The type of an element in the resulting tensor. Can
14989 * be 'float32', 'int32' or 'bool'. Defaults to 'float'.
14990 *
14991 * @doc {heading: 'Tensors', subheading: 'Creation'}
14992 */
function zeros$2(shape, dtype = 'float32') {
    assertNonNegativeIntegerDimensions(shape);
    // A complex64 tensor of zeros is built from two float32 zero planes.
    if (dtype === 'complex64') {
        return complex$2(zeros$2(shape, 'float32'), zeros$2(shape, 'float32'));
    }
    const backing = makeZerosTypedArray(sizeFromShape(shape), dtype);
    return ENGINE.makeTensor(backing, shape, dtype);
}
15003
15004 /**
15005 * @license
15006 * Copyright 2018 Google LLC. All Rights Reserved.
15007 * Licensed under the Apache License, Version 2.0 (the "License");
15008 * you may not use this file except in compliance with the License.
15009 * You may obtain a copy of the License at
15010 *
15011 * http://www.apache.org/licenses/LICENSE-2.0
15012 *
15013 * Unless required by applicable law or agreed to in writing, software
15014 * distributed under the License is distributed on an "AS IS" BASIS,
15015 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15016 * See the License for the specific language governing permissions and
15017 * limitations under the License.
15018 * =============================================================================
15019 */
15020 /**
15021 * Creates a `tf.Tensor` with all elements set to 1.
15022 *
15023 * ```js
15024 * tf.ones([2, 2]).print();
15025 * ```
15026 *
15027 * @param shape An array of integers defining the output tensor shape.
15028 * @param dtype The type of an element in the resulting tensor. Defaults to
15029 * 'float'.
15030 *
15031 * @doc {heading: 'Tensors', subheading: 'Creation'}
15032 */
function ones$1(shape, dtype = 'float32') {
    assertNonNegativeIntegerDimensions(shape);
    // complex64 "ones" means real part all 1s and imaginary part all 0s.
    if (dtype === 'complex64') {
        return complex$2(ones$1(shape, 'float32'), zeros$2(shape, 'float32'));
    }
    const backing = makeOnesTypedArray(sizeFromShape(shape), dtype);
    return ENGINE.makeTensor(backing, shape, dtype);
}
15043
15044 /**
15045 * @license
15046 * Copyright 2021 Google LLC. All Rights Reserved.
15047 * Licensed under the Apache License, Version 2.0 (the "License");
15048 * you may not use this file except in compliance with the License.
15049 * You may obtain a copy of the License at
15050 *
15051 * http://www.apache.org/licenses/LICENSE-2.0
15052 *
15053 * Unless required by applicable law or agreed to in writing, software
15054 * distributed under the License is distributed on an "AS IS" BASIS,
15055 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15056 * See the License for the specific language governing permissions and
15057 * limitations under the License.
15058 * =============================================================================
15059 */
15060 /**
15061 * Broadcasts parameters for evaluation on an N-D grid.
15062 *
15063 * Given N one-dimensional coordinate arrays `*args`, returns a list `outputs`
15064 * of N-D coordinate arrays for evaluating expressions on an N-D grid.
15065 *
15066 * Notes:
15067 * `meshgrid` supports cartesian ('xy') and matrix ('ij') indexing conventions.
15068 * When the `indexing` argument is set to 'xy' (the default), the broadcasting
15069 * instructions for the first two dimensions are swapped.
15070 * Examples:
15071 * Calling `const [X, Y] = meshgrid(x, y)` with the tensors
15072 *
15073 * ```javascript
15074 * const x = [1, 2, 3];
15075 * const y = [4, 5, 6];
15076 * const [X, Y] = tf.meshgrid(x, y);
15077 * // X = [[1, 2, 3],
15078 * // [1, 2, 3],
15079 * // [1, 2, 3]]
15080 * // Y = [[4, 4, 4],
15081 * // [5, 5, 5],
15082 * // [6, 6, 6]]
15083 * ```
15084 *
15085 * @param x Tensor with rank geq 1.
15086 * @param y Tensor with rank geq 1.
15087 * @param indexing
15088 *
15089 * @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
15090 */
function meshgrid(x, y, { indexing = 'xy' } = {}) {
    // Validate the indexing convention first, even before checking the inputs.
    if (indexing !== 'xy' && indexing !== 'ij') {
        throw new TypeError(`${indexing} is not a valid third argument to meshgrid`);
    }
    // Zero or one coordinate array short-circuits to a trivial result.
    if (x === undefined) {
        return [];
    }
    let $x = convertToTensor(x, 'x', 'meshgrid', x instanceof Tensor ? x.dtype : 'float32');
    if (y === undefined) {
        return [$x];
    }
    let $y = convertToTensor(y, 'y', 'meshgrid', y instanceof Tensor ? y.dtype : 'float32');
    const xSize = sizeFromShape($x.shape);
    const ySize = sizeFromShape($y.shape);
    // The grids are produced by multiplying a row/column vector by a
    // matching vector of ones, which tiles it along the other dimension.
    if (indexing === 'xy') {
        const xRow = reshape$3($x, [1, -1]);
        const yCol = reshape$3($y, [-1, 1]);
        return [
            matMul$1(ones$1([ySize, 1], xRow.dtype), xRow),
            matMul$1(yCol, ones$1([1, xSize], yCol.dtype)),
        ];
    }
    // 'ij' (matrix) indexing: roles of rows and columns are swapped.
    const xCol = reshape$3($x, [-1, 1]);
    const yRow = reshape$3($y, [1, -1]);
    return [
        matMul$1(xCol, ones$1([1, ySize], xCol.dtype)),
        matMul$1(ones$1([xSize, 1], yRow.dtype), yRow),
    ];
}
15120
15121 /**
15122 * @license
15123 * Copyright 2020 Google LLC. All Rights Reserved.
15124 * Licensed under the Apache License, Version 2.0 (the "License");
15125 * you may not use this file except in compliance with the License.
15126 * You may obtain a copy of the License at
15127 *
15128 * http://www.apache.org/licenses/LICENSE-2.0
15129 *
15130 * Unless required by applicable law or agreed to in writing, software
15131 * distributed under the License is distributed on an "AS IS" BASIS,
15132 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15133 * See the License for the specific language governing permissions and
15134 * limitations under the License.
15135 * =============================================================================
15136 */
15137 /**
15138 * Returns the min of a and b (`a < b ? a : b`) element-wise.
15139 * Supports broadcasting.
15140 *
15141 * We also expose `minimumStrict` which has the same signature as this op and
15142 * asserts that `a` and `b` are the same shape (does not broadcast).
15143 *
15144 * ```js
15145 * const a = tf.tensor1d([1, 4, 3, 16]);
15146 * const b = tf.tensor1d([1, 2, 9, 4]);
15147 *
15148 * a.minimum(b).print(); // or tf.minimum(a, b)
15149 * ```
15150 *
15151 * ```js
15152 * // Broadcast minimum a with b.
15153 * const a = tf.tensor1d([2, 4, 6, 8]);
15154 * const b = tf.scalar(5);
15155 *
15156 * a.minimum(b).print(); // or tf.minimum(a, b)
15157 * ```
15158 *
15159 * @param a The first tensor.
15160 * @param b The second tensor. Must have the same type as `a`.
15161 *
15162 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
15163 */
function minimum_(a, b) {
    // Convert inputs to tensors and reconcile their dtypes before comparing.
    let $a = convertToTensor(a, 'a', 'minimum');
    let $b = convertToTensor(b, 'b', 'minimum');
    [$a, $b] = makeTypesMatch($a, $b);
    // Boolean tensors are widened to int32 because the kernel compares numbers.
    const isBool = $a.dtype === 'bool';
    if (isBool) {
        $a = cast$3($a, 'int32');
        $b = cast$3($b, 'int32');
    }
    // Validates that the two shapes are broadcast-compatible (throws if not).
    assertAndGetBroadcastShape($a.shape, $b.shape);
    return ENGINE.runKernel(Minimum$1, { a: $a, b: $b });
}
const minimum$4 = /* @__PURE__ */ op({ minimum_ });
15177
15178 /**
15179 * @license
15180 * Copyright 2020 Google LLC. All Rights Reserved.
15181 * Licensed under the Apache License, Version 2.0 (the "License");
15182 * you may not use this file except in compliance with the License.
15183 * You may obtain a copy of the License at
15184 *
15185 * http://www.apache.org/licenses/LICENSE-2.0
15186 *
15187 * Unless required by applicable law or agreed to in writing, software
15188 * distributed under the License is distributed on an "AS IS" BASIS,
15189 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15190 * See the License for the specific language governing permissions and
15191 * limitations under the License.
15192 * =============================================================================
15193 */
15194 /**
15195 * Pads a `tf.Tensor` using mirror padding.
15196 *
15197 * This operation implements the `REFLECT` and `SYMMETRIC` modes of pad.
15198 *
15199 * ```js
15200 * const x = tf.range(0, 9).reshape([1, 1, 3, 3]);
15201 * x.mirrorPad([[0, 0], [0, 0], [2, 2], [2, 2]], 'reflect').print();
15202 * ```
15203 * @param x The tensor to pad.
15204 * @param paddings An array of length `R` (the rank of the tensor), where
15205 * each element is a length-2 tuple of ints `[padBefore, padAfter]`,
15206 * specifying how much to pad along each dimension of the tensor.
15207 * In "reflect" mode, the padded regions do not include the borders,
15208 * while in "symmetric" mode the padded regions do include the borders.
15209 * For example, if the input is `[1, 2, 3]` and paddings is `[0, 2]`,
15210 * then the output is `[1, 2, 3, 2, 1]` in "reflect" mode, and
15211 * `[1, 2, 3, 3, 2]` in "symmetric" mode.
15212 * If `mode` is "reflect" then both `paddings[D, 0]` and `paddings[D, 1]`
15213 * must be no greater than `x.shape[D] - 1`. If mode is "symmetric"
15214 * then both `paddings[D, 0]` and `paddings[D, 1]` must be no greater than
15215 * `x.shape[D]`
15216 * @param mode String to specify padding mode. Can be `'reflect' | 'symmetric'`
15217 */
15218 /** @doc {heading: 'Tensors', subheading: 'Transformations'} */
function mirrorPad_(x, paddings, mode) {
    assert$1(mode === 'reflect' || mode === 'symmetric', () => `Invalid mode. Mode must be either reflect or symmetric. Got ${mode}.`);
    const $x = convertToTensor(x, 'x', 'mirrorPad');
    if ($x.rank === 0) {
        throw new Error('mirrorPad(scalar) is not defined. Pass non-scalar to mirrorPad');
    }
    assert$1(paddings.length === $x.rank, () => `Padding doesn't match input. Must be ${$x.rank}. Got ${paddings.length}.`);
    // 'reflect' excludes the border element, so each pad amount may be at most
    // shape[d] - 1; 'symmetric' includes it, allowing up to shape[d].
    const shapeOffset = mode === 'reflect' ? 1 : 0;
    for (let dim = 0; dim < $x.rank; dim++) {
        assert$1(paddings[dim].length === 2, () => `Invalid number of paddings. Must be length of 2 each.`);
        const limit = $x.shape[dim] - shapeOffset;
        const [padBefore, padAfter] = paddings[dim];
        assert$1(padBefore >= 0 && padBefore <= limit &&
            padAfter >= 0 && padAfter <= limit, () => `Padding in dimension ${dim} cannot be greater than or equal to ${limit} or less than 0 for input of shape ${$x.shape}`);
    }
    return ENGINE.runKernel(MirrorPad, { x: $x }, { paddings, mode });
}
const mirrorPad$1 = /* @__PURE__ */ op({ mirrorPad_ });
15242
15243 /**
15244 * @license
15245 * Copyright 2020 Google LLC. All Rights Reserved.
15246 * Licensed under the Apache License, Version 2.0 (the "License");
15247 * you may not use this file except in compliance with the License.
15248 * You may obtain a copy of the License at
15249 *
15250 * http://www.apache.org/licenses/LICENSE-2.0
15251 *
15252 * Unless required by applicable law or agreed to in writing, software
15253 * distributed under the License is distributed on an "AS IS" BASIS,
15254 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15255 * See the License for the specific language governing permissions and
15256 * limitations under the License.
15257 * =============================================================================
15258 */
15259 /**
15260 * Returns the mod of a and b element-wise.
15261 * `floor(x / y) * y + mod(x, y) = x`
15262 * Supports broadcasting.
15263 *
15264 * We also expose `tf.modStrict` which has the same signature as this op and
15265 * asserts that `a` and `b` are the same shape (does not broadcast).
15266 *
15267 * ```js
15268 * const a = tf.tensor1d([1, 4, 3, 16]);
15269 * const b = tf.tensor1d([1, 2, 9, 4]);
15270 *
15271 * a.mod(b).print(); // or tf.mod(a, b)
15272 * ```
15273 *
15274 * ```js
15275 * // Broadcast a mod b.
15276 * const a = tf.tensor1d([2, 4, 6, 8]);
15277 * const b = tf.scalar(5);
15278 *
15279 * a.mod(b).print(); // or tf.mod(a, b)
15280 * ```
15281 *
15282 * @param a The first tensor.
15283 * @param b The second tensor. Must have the same type as `a`.
15284 *
15285 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
15286 */
function mod_(a, b) {
    // Convert both operands and unify their dtypes; broadcasting is handled
    // by the kernel itself.
    let $a = convertToTensor(a, 'a', 'mod');
    let $b = convertToTensor(b, 'b', 'mod');
    [$a, $b] = makeTypesMatch($a, $b);
    return ENGINE.runKernel(Mod, { a: $a, b: $b });
}
const mod$2 = /* @__PURE__ */ op({ mod_ });
15295
15296 /**
15297 * @license
15298 * Copyright 2020 Google LLC. All Rights Reserved.
15299 * Licensed under the Apache License, Version 2.0 (the "License");
15300 * you may not use this file except in compliance with the License.
15301 * You may obtain a copy of the License at
15302 *
15303 * http://www.apache.org/licenses/LICENSE-2.0
15304 *
15305 * Unless required by applicable law or agreed to in writing, software
15306 * distributed under the License is distributed on an "AS IS" BASIS,
15307 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15308 * See the License for the specific language governing permissions and
15309 * limitations under the License.
15310 * =============================================================================
15311 */
15312 /**
15313 * Calculates the mean and variance of `x`. The mean and variance are
15314 * calculated by aggregating the contents of `x` across `axes`. If `x` is
15315 * 1-D and `axes = [0]` this is just the mean and variance of a vector.
15316 *
15317 * @param x The input tensor.
15318 * @param axis The dimension(s) along with to compute mean and
15319 * variance. By default it reduces all dimensions.
15320 * @param keepDims If true, the moments have the same dimensionality as the
15321 * input.
15322 * @return An object with two keys: `mean` and `variance`.
15323 *
15324 * @doc {heading: 'Operations', subheading: 'Normalization'}
15325 */
function moments_(x, axis = null, keepDims = false) {
    const $x = convertToTensor(x, 'x', 'moments');
    const axes = parseAxisParam(axis, $x.shape);
    const xMean = mean$3($x, axes, keepDims);
    // The mean must be broadcast back against `x` when the reduced axes were
    // dropped, so re-insert them as size-1 dimensions in that case.
    const broadcastShape = keepDims
        ? xMean.shape
        : expandShapeToKeepDim(xMean.shape, axes);
    const centered = sub$2(cast$3($x, 'float32'), reshape$3(xMean, broadcastShape));
    const variance = mean$3(square$2(centered), axes, keepDims);
    return { mean: xMean, variance };
}
const moments = /* @__PURE__ */ op({ moments_ });
15339
15340 /**
15341 * Computes the next states and outputs of a stack of LSTMCells.
15342 *
15343 * Each cell output is used as input to the next cell.
15344 *
15345 * Returns `[cellState, cellOutput]`.
15346 *
15347 * Derived from tf.contrib.rn.MultiRNNCell.
15348 *
15349 * @param lstmCells Array of LSTMCell functions.
15350 * @param data The input to the cell.
15351 * @param c Array of previous cell states.
15352 * @param h Array of previous cell outputs.
15353 *
15354 * @doc {heading: 'Operations', subheading: 'RNN'}
15355 */
function multiRNNCell_(lstmCells, data, c, h) {
    // Feed `data` through the stack of cells, each cell's output becoming the
    // next cell's input; collect the new states and outputs per layer.
    const $data = convertToTensor(data, 'data', 'multiRNNCell');
    const $c = convertToTensorArray(c, 'c', 'multiRNNCell');
    const $h = convertToTensorArray(h, 'h', 'multiRNNCell');
    const newC = [];
    const newH = [];
    let input = $data;
    for (let layer = 0; layer < lstmCells.length; layer++) {
        const [cellState, cellOutput] = lstmCells[layer](input, $c[layer], $h[layer]);
        newC.push(cellState);
        newH.push(cellOutput);
        input = cellOutput;
    }
    return [newC, newH];
}
const multiRNNCell = /* @__PURE__ */ op({ multiRNNCell_ });
15377
15378 /**
15379 * @license
15380 * Copyright 2020 Google LLC. All Rights Reserved.
15381 * Licensed under the Apache License, Version 2.0 (the "License");
15382 * you may not use this file except in compliance with the License.
15383 * You may obtain a copy of the License at
15384 *
15385 * http://www.apache.org/licenses/LICENSE-2.0
15386 *
15387 * Unless required by applicable law or agreed to in writing, software
15388 * distributed under the License is distributed on an "AS IS" BASIS,
15389 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15390 * See the License for the specific language governing permissions and
15391 * limitations under the License.
15392 * =============================================================================
15393 */
15394 /**
15395 * Creates a `tf.Tensor` with values drawn from a multinomial distribution.
15396 *
15397 * ```js
15398 * const probs = tf.tensor([.75, .25]);
15399 * tf.multinomial(probs, 3).print();
15400 * ```
15401 *
15402 * @param logits 1D array with unnormalized log-probabilities, or
15403 * 2D array of shape `[batchSize, numOutcomes]`. See the `normalized`
15404 * parameter.
15405 * @param numSamples Number of samples to draw for each row slice.
15406 * @param seed The seed number.
15407 * @param normalized Whether the provided `logits` are normalized true
15408 * probabilities (sum to 1). Defaults to false.
15409 * @return 1D array of shape `[numSamples]`, or 2D array of shape
15410 * `[batchSize, numSamples]`, depending on the rank of the input.
15411 *
15412 * @doc {heading: 'Tensors', subheading: 'Random'}
15413 */
function multinomial_(logits, numSamples, seed, normalized = false) {
    const $logits = convertToTensor(logits, 'logits', 'multinomial');
    const numOutcomes = $logits.size;
    const inputRank = $logits.rank;
    if (numOutcomes < 2) {
        throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${numOutcomes}.`);
    }
    if (inputRank > 2) {
        throw new Error(`Rank of probabilities must be 1 or 2, but is ${inputRank}`);
    }
    // TODO(lina128): Investigate correct seed behavior. The code seems not allow
    // setting see to 0. (A falsy seed — including 0 — falls back to random.)
    seed = seed || Math.random();
    // The kernel only accepts (and returns) rank-2 tensors, so promote a
    // rank-1 input to a single-row batch and undo it on the way out.
    const logits2D = inputRank === 1 ? reshape$3($logits, [1, -1]) : $logits;
    // tslint:disable-next-line: no-unnecessary-type-assertion
    const samples = ENGINE.runKernel(Multinomial, { logits: logits2D }, { numSamples, seed, normalized });
    // tslint:disable-next-line:no-unnecessary-type-assertion
    return inputRank === 1 ? reshape$3(samples, [samples.size]) : samples;
}
const multinomial$2 = /* @__PURE__ */ op({ multinomial_ });
15438
15439 /**
15440 * @license
15441 * Copyright 2020 Google LLC. All Rights Reserved.
15442 * Licensed under the Apache License, Version 2.0 (the "License");
15443 * you may not use this file except in compliance with the License.
15444 * You may obtain a copy of the License at
15445 *
15446 * http://www.apache.org/licenses/LICENSE-2.0
15447 *
15448 * Unless required by applicable law or agreed to in writing, software
15449 * distributed under the License is distributed on an "AS IS" BASIS,
15450 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15451 * See the License for the specific language governing permissions and
15452 * limitations under the License.
15453 * =============================================================================
15454 */
15455 /**
15456 * Returns the truth value of (a != b) element-wise. Supports broadcasting.
15457 *
15458 * ```js
15459 * const a = tf.tensor1d([1, 2, 3]);
15460 * const b = tf.tensor1d([0, 2, 3]);
15461 *
15462 * a.notEqual(b).print();
15463 * ```
15464 * @param a The first input tensor.
15465 * @param b The second input tensor. Must have the same dtype as `a`.
15466 *
15467 * @doc {heading: 'Operations', subheading: 'Logical'}
15468 */
function notEqual_(a, b) {
    // Accept both string and numeric tensors; unify the dtypes and confirm
    // the shapes can broadcast before dispatching to the kernel.
    let $a = convertToTensor(a, 'a', 'notEqual', 'string_or_numeric');
    let $b = convertToTensor(b, 'b', 'notEqual', 'string_or_numeric');
    [$a, $b] = makeTypesMatch($a, $b);
    assertAndGetBroadcastShape($a.shape, $b.shape);
    return ENGINE.runKernel(NotEqual, { a: $a, b: $b });
}
const notEqual$2 = /* @__PURE__ */ op({ notEqual_ });
15478
15479 /**
15480 * @license
15481 * Copyright 2020 Google LLC. All Rights Reserved.
15482 * Licensed under the Apache License, Version 2.0 (the "License");
15483 * you may not use this file except in compliance with the License.
15484 * You may obtain a copy of the License at
15485 *
15486 * http://www.apache.org/licenses/LICENSE-2.0
15487 *
15488 * Unless required by applicable law or agreed to in writing, software
15489 * distributed under the License is distributed on an "AS IS" BASIS,
15490 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15491 * See the License for the specific language governing permissions and
15492 * limitations under the License.
15493 * =============================================================================
15494 */
15495 /**
15496 * Creates a one-hot `tf.Tensor`. The locations represented by `indices` take
15497 * value `onValue` (defaults to 1), while all other locations take value
15498 * `offValue` (defaults to 0). If `indices` is rank `R`, the output has rank
15499 * `R+1` with the last axis of size `depth`.
15500 * `indices` used to encode prediction class must start from 0. For example,
15501 * if you have 3 classes of data, class 1 should be encoded as 0, class 2
15502 * should be 1, and class 3 should be 2.
15503 *
15504 * ```js
15505 * tf.oneHot(tf.tensor1d([0, 1], 'int32'), 3).print();
15506 * ```
15507 *
15508 * @param indices `tf.Tensor` of indices with dtype `int32`. Indices must
15509 * start from 0.
15510 * @param depth The depth of the one hot dimension.
15511 * @param onValue A number used to fill in the output when the index matches
15512 * the location.
15513 * @param offValue A number used to fill in the output when the index does
15514 * not match the location.
15515 * @param dtype The dtype of the output tensor, default to 'int32'.
15516 *
15517 * @doc {heading: 'Tensors', subheading: 'Creation'}
15518 */
function oneHot_(indices, depth, onValue = 1, offValue = 0, dtype = 'int32') {
    // A one-hot axis of fewer than two classes is rejected as meaningless.
    if (depth < 2) {
        throw new Error(`Error in oneHot: depth must be >=2, but it is ${depth}`);
    }
    const $indices = convertToTensor(indices, 'indices', 'oneHot', 'int32');
    const attrs = { dtype, depth, onValue, offValue };
    return ENGINE.runKernel(OneHot, { indices: $indices }, attrs);
}
const oneHot$3 = /* @__PURE__ */ op({ oneHot_ });
15529
15530 /**
15531 * @license
15532 * Copyright 2018 Google LLC. All Rights Reserved.
15533 * Licensed under the Apache License, Version 2.0 (the "License");
15534 * you may not use this file except in compliance with the License.
15535 * You may obtain a copy of the License at
15536 *
15537 * http://www.apache.org/licenses/LICENSE-2.0
15538 *
15539 * Unless required by applicable law or agreed to in writing, software
15540 * distributed under the License is distributed on an "AS IS" BASIS,
15541 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15542 * See the License for the specific language governing permissions and
15543 * limitations under the License.
15544 * =============================================================================
15545 */
15546 /**
15547 * Creates a `tf.Tensor` with all elements set to 1 with the same shape as the
15548 * given tensor.
15549 *
15550 * ```js
15551 * const x = tf.tensor([1, 2]);
15552 * tf.onesLike(x).print();
15553 * ```
15554 * @param x A tensor.
15555 *
15556 * @doc {heading: 'Tensors', subheading: 'Creation'}
15557 */
function onesLike_(x) {
    // Delegates to the OnesLike kernel, which mirrors x's shape and dtype.
    const $x = convertToTensor(x, 'x', 'onesLike');
    return ENGINE.runKernel(OnesLike, { x: $x });
}
const onesLike$3 = /* @__PURE__ */ op({ onesLike_ });
15564
15565 /**
15566 * Computes the outer product of two vectors, `v1` and `v2`.
15567 *
15568 * ```js
15569 * const a = tf.tensor1d([1, 2, 3]);
15570 * const b = tf.tensor1d([3, 4, 5]);
15571 *
15572 * tf.outerProduct(a, b).print();
15573 * ```
15574 * @param v1 The first vector in the outer product operation.
15575 * @param v2 The second vector in the outer product operation.
15576 *
15577 * @doc {heading: 'Operations', subheading: 'Matrices'}
15578 */
function outerProduct_(v1, v2) {
    const $v1 = convertToTensor(v1, 'v1', 'outerProduct');
    const $v2 = convertToTensor(v2, 'v2', 'outerProduct');
    assert$1($v1.rank === 1 && $v2.rank === 1, () => `Error in outerProduct: inputs must be rank 1, but got ranks ${$v1.rank} and ${$v2.rank}.`);
    // outer(v1, v2) == column(v1) x row(v2), expressed as a matrix multiply.
    const column = reshape$3($v1, [-1, 1]);
    const row = reshape$3($v2, [1, -1]);
    return matMul$1(column, row);
}
const outerProduct = /* @__PURE__ */ op({ outerProduct_ });
15589
15590 /**
15591 * @license
15592 * Copyright 2020 Google LLC. All Rights Reserved.
15593 * Licensed under the Apache License, Version 2.0 (the "License");
15594 * you may not use this file except in compliance with the License.
15595 * You may obtain a copy of the License at
15596 *
15597 * http://www.apache.org/licenses/LICENSE-2.0
15598 *
15599 * Unless required by applicable law or agreed to in writing, software
15600 * distributed under the License is distributed on an "AS IS" BASIS,
15601 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15602 * See the License for the specific language governing permissions and
15603 * limitations under the License.
15604 * =============================================================================
15605 */
15606 /**
15607 * Pads a `tf.Tensor` with a given value and paddings.
15608 *
15609 * This operation implements `CONSTANT` mode. For `REFLECT` and `SYMMETRIC`,
15610 * refer to `tf.mirrorPad`.
15611 *
15612 * Also available are stricter rank-specific methods with the same signature
15613 * as this method that assert that `paddings` is of given length.
15614 * - `tf.pad1d`
15615 * - `tf.pad2d`
15616 * - `tf.pad3d`
15617 * - `tf.pad4d`
15618 *
15619 * ```js
15620 * const x = tf.tensor1d([1, 2, 3, 4]);
15621 * x.pad([[1, 2]]).print();
15622 * ```
15623 * @param x The tensor to pad.
15624 * @param paddings An array of length `R` (the rank of the tensor), where
15625 * each element is a length-2 tuple of ints `[padBefore, padAfter]`,
15626 * specifying how much to pad along each dimension of the tensor.
15627 * @param constantValue The pad value to use. Defaults to 0.
15628 *
15629 * @doc {heading: 'Tensors', subheading: 'Transformations'}
15630 */
function pad_(x, paddings, constantValue = 0) {
    const $x = convertToTensor(x, 'x', 'pad');
    // Padding a scalar has no defined dimension to pad along.
    if ($x.rank === 0) {
        throw new Error('pad(scalar) is not defined. Pass non-scalar to pad');
    }
    return ENGINE.runKernel(PadV2, { x: $x }, { paddings, constantValue });
}
const pad = /* @__PURE__ */ op({ pad_ });
15641
15642 /**
15643 * Pads a `tf.Tensor1D` with a given value and paddings. See `pad` for details.
15644 */
function pad1d_(x, paddings, constantValue = 0) {
    // For rank 1, `paddings` is a single [before, after] pair; wrap it for `pad`.
    assert$1(paddings.length === 2, () => 'Invalid number of paddings. Must be length of 2.');
    return pad(x, [paddings], constantValue);
}
const pad1d = /* @__PURE__ */ op({ pad1d_ });
15650
15651 /**
15652 * Pads a `tf.Tensor2D` with a given value and paddings. See `pad` for details.
15653 */
function pad2d_(x, paddings, constantValue = 0) {
    // Rank-2 paddings: exactly two [before, after] pairs.
    const valid = paddings.length === 2 &&
        paddings.every((pair) => pair.length === 2);
    assert$1(valid, () => 'Invalid number of paddings. Must be length of 2 each.');
    return pad(x, paddings, constantValue);
}
const pad2d = /* @__PURE__ */ op({ pad2d_ });
15660
15661 /**
15662 * Pads a `tf.Tensor3D` with a given value and paddings. See `pad` for details.
15663 */
15664 function pad3d_(x, paddings, constantValue = 0) {
15665 assert$1(paddings.length === 3 && paddings[0].length === 2 &&
15666 paddings[1].length === 2 && paddings[2].length === 2, () => 'Invalid number of paddings. Must be length of 2 each.');
15667 return pad(x, paddings, constantValue);
15668 }
15669 const pad3d = /* @__PURE__ */ op({ pad3d_ });
15670
15671 /**
15672 * Pads a `tf.Tensor4D` with a given value and paddings. See `pad` for details.
15673 */
15674 function pad4d_(x, paddings, constantValue = 0) {
15675 assert$1(paddings.length === 4 && paddings[0].length === 2 &&
15676 paddings[1].length === 2 && paddings[2].length === 2 &&
15677 paddings[3].length === 2, () => 'Invalid number of paddings. Must be length of 2 each.');
15678 return pad(x, paddings, constantValue);
15679 }
15680 const pad4d = /* @__PURE__ */ op({ pad4d_ });
15681
15682 /**
15683 * @license
15684 * Copyright 2020 Google LLC. All Rights Reserved.
15685 * Licensed under the Apache License, Version 2.0 (the "License");
15686 * you may not use this file except in compliance with the License.
15687 * You may obtain a copy of the License at
15688 *
15689 * http://www.apache.org/licenses/LICENSE-2.0
15690 *
15691 * Unless required by applicable law or agreed to in writing, software
15692 * distributed under the License is distributed on an "AS IS" BASIS,
15693 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15694 * See the License for the specific language governing permissions and
15695 * limitations under the License.
15696 * =============================================================================
15697 */
15698 /**
15699 * This operation divides "spatial" dimensions `[1, ..., M]` of the input into
15700 * a grid of blocks of shape `blockShape`, and interleaves these blocks with
15701 * the "batch" dimension (0) such that in the output, the spatial
15702 * dimensions `[1, ..., M]` correspond to the position within the grid,
15703 * and the batch dimension combines both the position within a spatial block
15704 * and the original batch position. Prior to division into blocks,
15705 * the spatial dimensions of the input are optionally zero padded
15706 * according to `paddings`. See below for a precise description.
15707 *
15708 * ```js
15709 * const x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
15710 * const blockShape = [2, 2];
15711 * const paddings = [[0, 0], [0, 0]];
15712 *
15713 * x.spaceToBatchND(blockShape, paddings).print();
15714 * ```
15715 *
15716 * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape +
15717 * remainingShape`, where spatialShape has `M` dimensions.
15718 * @param blockShape A 1-D array. Must have shape `[M]`, all values must
15719 * be >= 1.
15720 * @param paddings A 2-D array. Must have shape `[M, 2]`, all values must be >=
15721 * 0. `paddings[i] = [padStart, padEnd]` specifies the amount to zero-pad
15722 * from input dimension `i + 1`, which corresponds to spatial dimension `i`. It
15723 * is required that
15724 * `(inputShape[i + 1] + padStart + padEnd) % blockShape[i] === 0`
15725 *
15726 * This operation is equivalent to the following steps:
15727 *
15728 * 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the input
15729 * according to `paddings` to produce `padded` of shape paddedShape.
15730 *
15731 * 2. Reshape `padded` to `reshapedPadded` of shape:
15732 * `[batch] + [paddedShape[1] / blockShape[0], blockShape[0], ...,
15733 * paddedShape[M] / blockShape[M-1], blockShape[M-1]] + remainingShape`
15734 *
15735 * 3. Permute dimensions of `reshapedPadded` to produce `permutedReshapedPadded`
15736 * of shape: `blockShape + [batch] + [paddedShape[1] / blockShape[0], ...,
15737 * paddedShape[M] / blockShape[M-1]] + remainingShape`
15738 *
15739 * 4. Reshape `permutedReshapedPadded` to flatten `blockShape` into the
15740 * batch dimension, producing an output tensor of shape:
15741 * `[batch * prod(blockShape)] + [paddedShape[1] / blockShape[0], ...,
15742 * paddedShape[M] / blockShape[M-1]] + remainingShape`
15743 *
15744 * @doc {heading: 'Tensors', subheading: 'Transformations'}
15745 */
15746 function spaceToBatchND_(x, blockShape, paddings) {
15747 const $x = convertToTensor(x, 'x', 'spaceToBatchND');
15748 assert$1($x.rank >= 1 + blockShape.length, () => `input rank ${$x.rank} should be > than [blockShape] ${blockShape.length}`);
15749 assert$1(paddings.length === blockShape.length, () => `paddings.shape[0] ${paddings.length} must be equal to [blockShape] ${blockShape.length}`);
15750 assert$1($x.shape.reduce((a, b, i) => {
15751 if (i > 0 && i <= blockShape.length) {
15752 return a &&
15753 ((b + paddings[i - 1][0] + paddings[i - 1][1]) %
15754 blockShape[i - 1] ===
15755 0);
15756 }
15757 return a;
15758 }, true), () => `input spatial dimensions ${$x.shape.slice(1)} with paddings ${paddings.toString()} must be divisible by blockShapes ${blockShape.toString()}`);
15759 const inputs = { x: $x };
15760 const attrs = { blockShape, paddings };
15761 return ENGINE.runKernel(SpaceToBatchND, inputs, attrs);
15762 }
15763 const spaceToBatchND$2 = /* @__PURE__ */ op({ spaceToBatchND_ });
15764
15765 /**
15766 * @license
15767 * Copyright 2018 Google LLC. All Rights Reserved.
15768 * Licensed under the Apache License, Version 2.0 (the "License");
15769 * you may not use this file except in compliance with the License.
15770 * You may obtain a copy of the License at
15771 *
15772 * http://www.apache.org/licenses/LICENSE-2.0
15773 *
15774 * Unless required by applicable law or agreed to in writing, software
15775 * distributed under the License is distributed on an "AS IS" BASIS,
15776 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15777 * See the License for the specific language governing permissions and
15778 * limitations under the License.
15779 * =============================================================================
15780 */
15781 /**
15782 * Performs an N-D pooling operation
15783 *
15784 * @param input The input tensor, of rank 4 or rank 3 of shape
15785 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
15786 * @param windowShape The filter size: `[filterHeight, filterWidth]`. If
15787 * `filterSize` is a single number, then `filterHeight == filterWidth`.
15788 * @param poolingType The type of pooling, either 'max' or 'avg'.
15789 * @param pad The type of padding algorithm:
15790 * - `same` and stride 1: output will be of same size as input,
15791 * regardless of filter size.
15792 * - `valid`: output will be smaller than input if filter is larger
15793 * than 1x1.
15794 * - For more info, see this guide:
15795 * [https://www.tensorflow.org/api_guides/python/nn#Convolution](
15796 * https://www.tensorflow.org/api_guides/python/nn#Convolution)
15797 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
15798 * in which we sample input values across the height and width dimensions
15799 * in dilated pooling. Defaults to `[1, 1]`. If `dilationRate` is a single
15800 * number, then `dilationHeight == dilationWidth`. If it is greater than
15801 * 1, then all values of `strides` must be 1.
15802 * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
15803 * `strides` is a single number, then `strideHeight == strideWidth`.
15804 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
15805 * provided, it will default to truncate.
15806 *
15807 * @doc {heading: 'Operations', subheading: 'Convolution'}
15808 */
    function pool_(input, windowShape, poolingType, pad, dilations, strides, dimRoundingMode) {
        // Normalize optional arguments: default dilation 1x1, stride 1, and a
        // numeric pad of 0 is treated as 'valid' padding.
        if (dilations == null) {
            dilations = [1, 1];
        }
        if (strides == null) {
            strides = 1;
        }
        if (pad === 0) {
            pad = 'valid';
        }
        const $x = convertToTensor(input, 'x', 'maxPool');
        // Rank-3 inputs are treated as a single-image batch; remember to undo
        // the reshape before returning.
        let x4D = $x;
        let reshapedTo4D = false;
        if ($x.rank === 3) {
            reshapedTo4D = true;
            x4D = reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
        }
        assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in pool: Either strides or dilations must be 1. ' +
            `Got strides ${strides} and dilations '${dilations}'`);
        const convInfo = computePool2DInfo(x4D.shape, windowShape, strides, dilations, pad);
        const dilation = [convInfo.dilationHeight, convInfo.dilationWidth];
        // The following implementation does batchToSpace(pool(spaceToBatch(x)))
        // whenever dilation > 1 since the TF kernels do not support dilation > 1.
        // tslint:disable-next-line:max-line-length
        // https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/nn_ops.py#L1037
        let basePadding;
        if (pad === 'same') {
            // 'same' needs explicit per-side zero padding computed from the
            // dilated filter extent; 'valid' needs none.
            basePadding = withSpaceToBatchBasePaddings([convInfo.filterHeight, convInfo.filterWidth], dilation);
        }
        else {
            basePadding = [[0, 0], [0, 0]];
        }
        const isDilationOne = dilation[0] === 1 && dilation[1] === 1;
        const [adjustedPadding, adjustedCrops] = requiredSpaceToBatchPaddings([convInfo.inHeight, convInfo.inWidth], dilation, basePadding);
        // With dilation 1x1 the kernel is called directly; otherwise padding is
        // already applied by spaceToBatch, so the pool itself runs 'valid'.
        const convertedPad = isDilationOne ? pad : 'valid';
        const convertedX = isDilationOne ? x4D : spaceToBatchND$2(x4D, dilation, adjustedPadding);
        const forwardOp = poolingType === 'avg' ?
            () => avgPool$2(convertedX, windowShape, strides, convertedPad, dimRoundingMode) :
            () => maxPool$2(convertedX, windowShape, strides, convertedPad, dimRoundingMode);
        const y = forwardOp();
        // Undo the spaceToBatch transform, cropping the extra padding added to
        // make the spatial dims divisible by the block (dilation) shape.
        const res = isDilationOne ? y : batchToSpaceND$2(y, dilation, adjustedCrops);
        if (reshapedTo4D) {
            // Drop the synthetic batch dimension added for rank-3 inputs.
            return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
        }
        return res;
    }
15855 // Helper function to compute crops and paddings for pool with dilation > 1.
15856 // tslint:disable-next-line:max-line-length
15857 // https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/array_ops.py#L2184
15858 function requiredSpaceToBatchPaddings(inputShape, blockShape, basePadding) {
15859 const padStart = basePadding.map(b => b[0]);
15860 const origPadEnd = basePadding.map(b => b[1]);
15861 const fullInputShape = inputShape.concat(padStart, origPadEnd);
15862 const padEndExtra = blockShape.map((b, i) => (b - fullInputShape[i] % b) % b);
15863 const padEnd = origPadEnd.map((s, i) => s + padEndExtra[i]);
15864 const paddings = blockShape.map((_, i) => [padStart[i], padEnd[i]]);
15865 const crops = blockShape.map((_, i) => [0, padEndExtra[i]]);
15866 return [paddings, crops];
15867 }
15868 // Helper function to compute base paddings for pool with dilation > 1.
15869 // tslint:disable-next-line:max-line-length
15870 // https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/nn_ops.py#L524
15871 function withSpaceToBatchBasePaddings(filterShape, dilation) {
15872 // Spatial dimensions of the filters and the upsampled filters in which we
15873 // introduce (rate - 1) zeros between consecutive filter values.
15874 const dilatedFilterShape = filterShape.map((s, i) => {
15875 return s + (s - 1) * (dilation[i] - 1);
15876 });
15877 const padExtraShape = dilatedFilterShape.map(s => s - 1);
15878 // When padding is odd, we pad more at end, following the same
15879 // convention as conv2d.
15880 const padExtraStart = padExtraShape.map(s => Math.floor(s / 2));
15881 const padExtraEnd = padExtraShape.map((s, i) => s - padExtraStart[i]);
15882 return padExtraShape.map((_, i) => {
15883 return [padExtraStart[i], padExtraEnd[i]];
15884 });
15885 }
    // Public entry point: `pool_` wrapped with op(), matching the pattern used
    // for every other exported op in this bundle.
    const pool$1 = /* @__PURE__ */ op({ pool_ });
15887
15888 /**
15889 * @license
15890 * Copyright 2020 Google LLC. All Rights Reserved.
15891 * Licensed under the Apache License, Version 2.0 (the "License");
15892 * you may not use this file except in compliance with the License.
15893 * You may obtain a copy of the License at
15894 *
15895 * http://www.apache.org/licenses/LICENSE-2.0
15896 *
15897 * Unless required by applicable law or agreed to in writing, software
15898 * distributed under the License is distributed on an "AS IS" BASIS,
15899 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15900 * See the License for the specific language governing permissions and
15901 * limitations under the License.
15902 * =============================================================================
15903 */
15904 /**
15905 * Computes leaky rectified linear element-wise with parametric alphas.
15906 *
15907 * `x < 0 ? alpha * x : f(x) = x`
15908 *
15909 * ```js
15910 * const x = tf.tensor1d([-1, 2, -3, 4]);
15911 * const alpha = tf.scalar(0.1);
15912 *
15913 * x.prelu(alpha).print(); // or tf.prelu(x, alpha)
15914 * ```
15915 * @param x The input tensor.
15916 * @param alpha Scaling factor for negative values.
15917 *
15918 * @doc {heading: 'Operations', subheading: 'Basic math'}
15919 */
15920 function prelu_(x, alpha) {
15921 const $x = convertToTensor(x, 'x', 'prelu');
15922 const $alpha = convertToTensor(alpha, 'alpha', 'prelu');
15923 const inputs = { x: $x, alpha: $alpha };
15924 return ENGINE.runKernel(Prelu, inputs);
15925 }
15926 const prelu$3 = /* @__PURE__ */ op({ prelu_ });
15927
15928 /**
15929 * @license
15930 * Copyright 2020 Google LLC. All Rights Reserved.
15931 * Licensed under the Apache License, Version 2.0 (the "License");
15932 * you may not use this file except in compliance with the License.
15933 * You may obtain a copy of the License at
15934 *
15935 * http://www.apache.org/licenses/LICENSE-2.0
15936 *
15937 * Unless required by applicable law or agreed to in writing, software
15938 * distributed under the License is distributed on an "AS IS" BASIS,
15939 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15940 * See the License for the specific language governing permissions and
15941 * limitations under the License.
15942 * =============================================================================
15943 */
15944 /**
15945 * Computes the product of elements across dimensions of a `tf.Tensor`.
15946 *
15947 * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
15948 * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
15949 * `axes`. If `keepDims` is true, the reduced dimensions are retained with
15950 * length 1. If `axes` has no entries, all dimensions are reduced, and a
15951 * `tf.Tensor` with a single element is returned.
15952 *
15953 * ```js
15954 * const x = tf.tensor1d([1, 2, 3]);
15955 *
15956 * x.prod().print(); // or tf.prod(x)
15957 * ```
15958 *
15959 * ```js
15960 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
15961 *
15962 * const axis = 1;
15963 * x.prod(axis).print(); // or tf.prod(x, axis)
15964 * ```
15965 *
15966 * @param x The input tensor to compute the product over. If the dtype is `bool`
15967 * it will be converted to `int32` and the output dtype will be `int32`.
15968 * @param axis The dimension(s) to reduce. By default it reduces
15969 * all dimensions.
15970 * @param keepDims If true, retains reduced dimensions with size 1.
15971 *
15972 * @doc {heading: 'Operations', subheading: 'Reduction'}
15973 */
15974 function prod_(x, axis = null, keepDims = false) {
15975 let $x = convertToTensor(x, 'x', 'prod');
15976 if ($x.dtype === 'bool') {
15977 // bool is not an allowed type for the underlying kernel.
15978 $x = cast$3($x, 'int32');
15979 }
15980 const inputs = { x: $x };
15981 const attrs = { axis, keepDims };
15982 return ENGINE.runKernel(Prod, inputs, attrs);
15983 }
15984 const prod$2 = /* @__PURE__ */ op({ prod_ });
15985
15986 /**
15987 * @license
15988 * Copyright 2022 Google LLC. All Rights Reserved.
15989 * Licensed under the Apache License, Version 2.0 (the "License");
15990 * you may not use this file except in compliance with the License.
15991 * You may obtain a copy of the License at
15992 *
15993 * http://www.apache.org/licenses/LICENSE-2.0
15994 *
15995 * Unless required by applicable law or agreed to in writing, software
15996 * distributed under the License is distributed on an "AS IS" BASIS,
15997 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15998 * See the License for the specific language governing permissions and
15999 * limitations under the License.
16000 * =============================================================================
16001 */
16002 function raggedGather_(paramsNestedSplits, paramsDenseValues, indices, outputRaggedRank) {
16003 const $paramsNestedSplits = paramsNestedSplits.map((t, i) => convertToTensor(t, `tensors${i}`, 'raggedGather', 'int32'));
16004 const $paramsDenseValues = convertToTensor(paramsDenseValues, 'paramsDenseValues', 'raggedGather');
16005 const $indices = convertToTensor(indices, 'indices', 'raggedGather', 'int32');
16006 const inputs = {
16007 paramsNestedSplits: $paramsNestedSplits,
16008 paramsDenseValues: $paramsDenseValues,
16009 indices: $indices,
16010 };
16011 const attrs = { outputRaggedRank };
16012 const result = ENGINE.runKernel(RaggedGather, inputs, attrs);
16013 return {
16014 outputNestedSplits: result.slice(0, result.length - 1),
16015 outputDenseValues: result[result.length - 1],
16016 };
16017 }
16018 const raggedGather$2 = /* @__PURE__ */ op({ raggedGather_ });
16019
16020 /**
16021 * @license
16022 * Copyright 2022 Google LLC.
16023 * Licensed under the Apache License, Version 2.0 (the "License");
16024 * you may not use this file except in compliance with the License.
16025 * You may obtain a copy of the License at
16026 *
16027 * http://www.apache.org/licenses/LICENSE-2.0
16028 *
16029 * Unless required by applicable law or agreed to in writing, software
16030 * distributed under the License is distributed on an "AS IS" BASIS,
16031 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16032 * See the License for the specific language governing permissions and
16033 * limitations under the License.
16034 * =============================================================================
16035 */
16036 /**
16037 * Returns a RaggedTensor result composed from rtDenseValues and rtNestedSplits,
16038 * such that result[i] = [starts[i], starts[i] + deltas[i], ..., limits[i]]).
16039 *
16040 * @param starts: A Tensor. Must be one of the following types:
16041 * 'float32', 'int32'. The starts of each range.
16042 * @param limits: A Tensor. Must have the same type as starts. The limits of
16043 * each range.
16044 * @param deltas: A Tensor. Must have the same type as starts. The deltas of
16045 * each range.
16046 * @return A map with the following properties:
16047 * - rtNestedSplits: A Tensor of type 'int32'.
16048 * - rtDenseValues: A Tensor. Has the same type as starts.
16049 */
16050 function raggedRange_(starts, limits, deltas) {
16051 const $starts = convertToTensor(starts, 'starts', 'raggedRange');
16052 const $limits = convertToTensor(limits, 'limits', 'raggedRange', $starts.dtype);
16053 const $deltas = convertToTensor(deltas, 'deltas', 'raggedRange', $starts.dtype);
16054 const inputs = {
16055 starts: $starts,
16056 limits: $limits,
16057 deltas: $deltas,
16058 };
16059 const result = ENGINE.runKernel(RaggedRange, inputs);
16060 return {
16061 rtNestedSplits: result[0],
16062 rtDenseValues: result[1],
16063 };
16064 }
16065 const raggedRange$2 = /* @__PURE__ */ op({ raggedRange_ });
16066
16067 /**
16068 * @license
16069 * Copyright 2022 Google LLC. All Rights Reserved.
16070 * Licensed under the Apache License, Version 2.0 (the "License");
16071 * you may not use this file except in compliance with the License.
16072 * You may obtain a copy of the License at
16073 *
16074 * http://www.apache.org/licenses/LICENSE-2.0
16075 *
16076 * Unless required by applicable law or agreed to in writing, software
16077 * distributed under the License is distributed on an "AS IS" BASIS,
16078 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16079 * See the License for the specific language governing permissions and
16080 * limitations under the License.
16081 * =============================================================================
16082 */
16083 /**
16084 * Create a dense tensor from a ragged tensor, possibly altering its shape.
16085 *
16086 * The raggedTensorToTensor op creates a dense tensor from am array of row
16087 * partition tensors, a value vector, and default values. If the shape is
16088 * unspecified, the minimal shape required to contain all the elements in the
16089 * ragged tensor (the natural shape) will be used. If some dimensions are left
16090 * unspecified, then the size of the natural shape is used in that dimension.
16091 *
16092 * The defaultValue will be broadcast to the output shape. After that, the
16093 * values from the ragged tensor overwrite the default values. Note that the
16094 * defaultValue must have less dimensions than the value.
16095 *
16096 * The row partition tensors are in the order of the dimensions. At present, the
16097 * types can be: "ROW_SPLITS": the row_splits tensor from the ragged tensor.
16098 * "VALUE_ROWIDS": the value_rowids tensor from the ragged tensor.
16099 * "FIRST_DIM_SIZE": if value_rowids is used for the first dimension, then it
16100 * is preceded by "FIRST_DIM_SIZE".
16101 * ```
16102 * @param shape: A Tensor. Must be one of the following types: 'int32'. The
16103 * desired shape of the output tensor. If left unspecified (empty), the
16104 * minimal shape required to contain all the elements in the ragged tensor
16105 * (the natural shape) will be used. If some dimensions are left
16106 * unspecified, then the size of the natural shape is used in that
16107 * dimension.
16108 *
16109 * Note that dense dimensions cannot be modified by the shape argument.
16110 * Trying to change the size of a dense dimension will cause the op to fail.
16111 * Examples: natural shape: [4, 5, 6] shape: -1 output shape: [4, 5, 6]
16112 *
16113 * natural shape: [4, 5, 6] shape: [3, -1, 2] output shape: [3, 5, 2]
16114 *
16115 * natural shape: [4, 5, 6] shape: [3, 7, 2] output shape: [3, 7, 2]
16116 * @param values: A Tensor. A 1D tensor representing the values of the ragged
16117 * tensor.
16118 * @param defaultValue: A Tensor. Must have the same type as values. The
16119 * defaultValue when the shape is larger than the ragged tensor. The
16120 * defaultValue is broadcast until it is the shape of the output tensor,
16121 * and then overwritten by values in the ragged tensor. The default value
16122 * must be compatible with this broadcast operation, and must have fewer
16123 * dimensions than the value tensor.
16124 * @param rowPartitionTensors: A list of at least 1 Tensor objects with the same
16125 * type in: 'int32'.
16126 * @param rowPartitionTypes: A list of strings. The types of the row partition
16127 * tensors. At present, these can be:
16128 * "ROW_SPLITS": the row_splits tensor from the ragged tensor.
16129 * "VALUE_ROWIDS": the value_rowids tensor from the ragged tensor.
16130 * "FIRST_DIM_SIZE": if value_rowids is used for the first dimension, then
16131 * it is preceded by "FIRST_DIM_SIZE". The tensors are in the order of
16132 * the dimensions.
16133 * @return A Tensor. Has the same type as values.
16134 * @doc {heading: 'Operations', subheading: 'Ragged'}
16135 */
16136 function raggedTensorToTensor_(shape, values, defaultValue, rowPartitionTensors, rowPartitionTypes) {
16137 const $shape = convertToTensor(shape, 'shape', 'raggedTensorToTensor', 'int32');
16138 const $values = convertToTensor(values, 'values', 'raggedTensorToTensor');
16139 const $defaultValue = convertToTensor(defaultValue, 'defaultValue', 'raggedTensorToTensor', $values.dtype);
16140 const $rowPartitionTensors = rowPartitionTensors.map((t, i) => convertToTensor(t, `tensors${i}`, 'raggedTensorToTensor', 'int32'));
16141 const inputs = {
16142 shape: $shape,
16143 values: $values,
16144 defaultValue: $defaultValue,
16145 rowPartitionTensors: $rowPartitionTensors
16146 };
16147 const attrs = { rowPartitionTypes };
16148 return ENGINE.runKernel(RaggedTensorToTensor, inputs, attrs);
16149 }
16150 const raggedTensorToTensor$2 = /* @__PURE__ */ op({ raggedTensorToTensor_ });
16151
16152 /**
16153 * @license
16154 * Copyright 2020 Google LLC. All Rights Reserved.
16155 * Licensed under the Apache License, Version 2.0 (the "License");
16156 * you may not use this file except in compliance with the License.
16157 * You may obtain a copy of the License at
16158 *
16159 * http://www.apache.org/licenses/LICENSE-2.0
16160 *
16161 * Unless required by applicable law or agreed to in writing, software
16162 * distributed under the License is distributed on an "AS IS" BASIS,
16163 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16164 * See the License for the specific language governing permissions and
16165 * limitations under the License.
16166 * =============================================================================
16167 */
16168 /**
16169 * Creates a `tf.Tensor` with values sampled from a random number generator
16170 * function defined by the user.
16171 *
16172 * @param shape An array of integers defining the output tensor shape.
16173 * @param randFunction A random number generator function which is called
16174 * for each element in the output tensor.
16175 * @param dtype The data type of the output tensor. Defaults to 'float32'.
16176 *
16177 * @doc {heading: 'Tensors', subheading: 'Random'}
16178 */
16179 function rand_(shape, randFunction, dtype) {
16180 assertNonNegativeIntegerDimensions(shape);
16181 const size = sizeFromShape(shape);
16182 let values = null;
16183 if (dtype == null || dtype === 'float32') {
16184 values = new Float32Array(size);
16185 }
16186 else if (dtype === 'int32') {
16187 values = new Int32Array(size);
16188 }
16189 else if (dtype === 'bool') {
16190 values = new Uint8Array(size);
16191 }
16192 else {
16193 throw new Error(`Unknown data type ${dtype}`);
16194 }
16195 for (let i = 0; i < size; i++) {
16196 values[i] = randFunction();
16197 }
16198 return ENGINE.makeTensor(values, shape, dtype);
16199 }
16200 const rand = /* @__PURE__ */ op({ rand_ });
16201
    // Rollup-wrapped CommonJS module: the vendored "alea" seedable PRNG.
    var alea$3 = {exports: {}};

    // NOTE(review): captured before the module body runs; appears unused here.
    var alea$1 = alea$3.exports;

    (function (module) {
    // A port of an algorithm by Johannes Baagøe <baagoe@baagoe.com>, 2010
    // http://baagoe.com/en/RandomMusings/javascript/
    // https://github.com/nquinlan/better-random-numbers-for-javascript-mirror
    // Original work is under MIT license -

    // Copyright (C) 2010 by Johannes Baagøe <baagoe@baagoe.org>
    //
    // Permission is hereby granted, free of charge, to any person obtaining a copy
    // of this software and associated documentation files (the "Software"), to deal
    // in the Software without restriction, including without limitation the rights
    // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    // copies of the Software, and to permit persons to whom the Software is
    // furnished to do so, subject to the following conditions:
    //
    // The above copyright notice and this permission notice shall be included in
    // all copies or substantial portions of the Software.
    //
    // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    // THE SOFTWARE.



    (function(global, module, define) {

    // Generator state: three fractional values s0/s1/s2 plus an integer
    // carry c; next() returns the new s2, a value in [0, 1).
    function Alea(seed) {
      var me = this, mash = Mash();

      me.next = function() {
        var t = 2091639 * me.s0 + me.c * 2.3283064365386963e-10; // 2^-32
        me.s0 = me.s1;
        me.s1 = me.s2;
        return me.s2 = t - (me.c = t | 0);
      };

      // Apply the seeding algorithm from Baagoe.
      me.c = 1;
      me.s0 = mash(' ');
      me.s1 = mash(' ');
      me.s2 = mash(' ');
      me.s0 -= mash(seed);
      if (me.s0 < 0) { me.s0 += 1; }
      me.s1 -= mash(seed);
      if (me.s1 < 0) { me.s1 += 1; }
      me.s2 -= mash(seed);
      if (me.s2 < 0) { me.s2 += 1; }
      // Release the seeding hash; it is not needed after construction.
      mash = null;
    }

    // Copies all four state fields from f onto t; used by prng.state() and
    // for restoring a saved state.
    function copy(f, t) {
      t.c = f.c;
      t.s0 = f.s0;
      t.s1 = f.s1;
      t.s2 = f.s2;
      return t;
    }

    // Factory returned by the module: builds a prng() in [0, 1) with
    // int32/double/quick helpers and optional state save/restore via opts.
    function impl(seed, opts) {
      var xg = new Alea(seed),
          state = opts && opts.state,
          prng = xg.next;
      prng.int32 = function() { return (xg.next() * 0x100000000) | 0; };
      prng.double = function() {
        return prng() + (prng() * 0x200000 | 0) * 1.1102230246251565e-16; // 2^-53
      };
      prng.quick = prng;
      if (state) {
        if (typeof(state) == 'object') copy(state, xg);
        prng.state = function() { return copy(xg, {}); };
      }
      return prng;
    }

    // String hasher used only for seeding: folds each character into n and
    // returns a float in [0, 1).
    function Mash() {
      var n = 0xefc8249d;

      var mash = function(data) {
        data = String(data);
        for (var i = 0; i < data.length; i++) {
          n += data.charCodeAt(i);
          var h = 0.02519603282416938 * n;
          n = h >>> 0;
          h -= n;
          h *= n;
          n = h >>> 0;
          h -= n;
          n += h * 0x100000000; // 2^32
        }
        return (n >>> 0) * 2.3283064365386963e-10; // 2^-32
      };

      return mash;
    }


    // UMD tail: in this rolled-up bundle the CommonJS branch is taken; the
    // AMD expression below was rewritten by the bundler and is always false.
    if (module && module.exports) {
      module.exports = impl;
    } else if (define && define.amd) {
      define(function() { return impl; });
    } else {
      this.alea = impl;
    }

    })(
      commonjsGlobal,
      ('object') == 'object' && module, // present in node.js
      (typeof undefined) == 'function' && undefined // present with an AMD loader
    );
    } (alea$3));

    var aleaExports = alea$3.exports;
    var alea$2 = /*@__PURE__*/getDefaultExportFromCjs(aleaExports);
16323
    // Rollup-wrapped CommonJS module: the vendored "xor128" seedable PRNG.
    var xor128$3 = {exports: {}};

    // NOTE(review): captured before the module body runs; appears unused here.
    var xor128$1 = xor128$3.exports;

    (function (module) {
    // A Javascript implementaion of the "xor128" prng algorithm by
    // George Marsaglia. See http://www.jstatsoft.org/v08/i14/paper

    (function(global, module, define) {

    // 128-bit xorshift generator: four 32-bit words of state (x, y, z, w);
    // next() returns a signed 32-bit integer.
    function XorGen(seed) {
      var me = this, strseed = '';

      me.x = 0;
      me.y = 0;
      me.z = 0;
      me.w = 0;

      // Set up generator function.
      me.next = function() {
        var t = me.x ^ (me.x << 11);
        me.x = me.y;
        me.y = me.z;
        me.z = me.w;
        return me.w ^= (me.w >>> 19) ^ t ^ (t >>> 8);
      };

      if (seed === (seed | 0)) {
        // Integer seed.
        me.x = seed;
      } else {
        // String seed.
        strseed += seed;
      }

      // Mix in string seed, then discard an initial batch of 64 values.
      for (var k = 0; k < strseed.length + 64; k++) {
        me.x ^= strseed.charCodeAt(k) | 0;
        me.next();
      }
    }

    // Copies the four state words from f onto t; used by prng.state() and
    // for restoring a saved state.
    function copy(f, t) {
      t.x = f.x;
      t.y = f.y;
      t.z = f.z;
      t.w = f.w;
      return t;
    }

    // Factory returned by the module: prng() yields floats in [0, 1);
    // int32/double/quick helpers attached, plus optional state save/restore.
    function impl(seed, opts) {
      var xg = new XorGen(seed),
          state = opts && opts.state,
          prng = function() { return (xg.next() >>> 0) / 0x100000000; };
      prng.double = function() {
        // Combine two draws into 53 bits of mantissa; loop rejects an exact
        // zero so the result is in (0, 1).
        do {
          var top = xg.next() >>> 11,
              bot = (xg.next() >>> 0) / 0x100000000,
              result = (top + bot) / (1 << 21);
        } while (result === 0);
        return result;
      };
      prng.int32 = xg.next;
      prng.quick = prng;
      if (state) {
        if (typeof(state) == 'object') copy(state, xg);
        prng.state = function() { return copy(xg, {}); };
      }
      return prng;
    }

    // UMD tail: in this rolled-up bundle the CommonJS branch is taken; the
    // AMD expression below was rewritten by the bundler and is always false.
    if (module && module.exports) {
      module.exports = impl;
    } else if (define && define.amd) {
      define(function() { return impl; });
    } else {
      this.xor128 = impl;
    }

    })(
      commonjsGlobal,
      ('object') == 'object' && module, // present in node.js
      (typeof undefined) == 'function' && undefined // present with an AMD loader
    );
    } (xor128$3));

    var xor128Exports = xor128$3.exports;
    var xor128$2 = /*@__PURE__*/getDefaultExportFromCjs(xor128Exports);
16412
// Bundled CommonJS shim for the vendored "xorwow" generator from the
// `seedrandom` package (bundler-generated vendor code; do not edit by hand).
var xorwow$3 = {exports: {}};

var xorwow$1 = xorwow$3.exports;

(function (module) {
// A Javascript implementaion of the "xorwow" prng algorithm by
// George Marsaglia. See http://www.jstatsoft.org/v08/i14/paper

(function(global, module, define) {

// 160-bit xorshift state (x, y, z, w, v) combined with a Weyl sequence
// counter (d). The exact operation order defines the output sequence.
function XorGen(seed) {
  var me = this, strseed = '';

  // Set up generator function.
  me.next = function() {
    var t = (me.x ^ (me.x >>> 2));
    me.x = me.y; me.y = me.z; me.z = me.w; me.w = me.v;
    return (me.d = (me.d + 362437 | 0)) +
       (me.v = (me.v ^ (me.v << 4)) ^ (t ^ (t << 1))) | 0;
  };

  me.x = 0;
  me.y = 0;
  me.z = 0;
  me.w = 0;
  me.v = 0;

  if (seed === (seed | 0)) {
    // Integer seed.
    me.x = seed;
  } else {
    // String seed.
    strseed += seed;
  }

  // Mix in string seed, then discard an initial batch of 64 values.
  for (var k = 0; k < strseed.length + 64; k++) {
    me.x ^= strseed.charCodeAt(k) | 0;
    if (k == strseed.length) {
      // Initialize the Weyl counter from the mixed x once the seed is consumed.
      me.d = me.x << 10 ^ me.x >>> 4;
    }
    me.next();
  }
}

// Copies generator state from f onto t; used for state save/restore.
function copy(f, t) {
  t.x = f.x;
  t.y = f.y;
  t.z = f.z;
  t.w = f.w;
  t.v = f.v;
  t.d = f.d;
  return t;
}

// Public factory: returns prng() in [0, 1) with .double/.int32/.quick
// helpers and optional state save/restore via opts.state.
function impl(seed, opts) {
  var xg = new XorGen(seed),
      state = opts && opts.state,
      prng = function() { return (xg.next() >>> 0) / 0x100000000; };
  prng.double = function() {
    do {
      var top = xg.next() >>> 11,
          bot = (xg.next() >>> 0) / 0x100000000,
          result = (top + bot) / (1 << 21);
    } while (result === 0);
    return result;
  };
  prng.int32 = xg.next;
  prng.quick = prng;
  if (state) {
    if (typeof(state) == 'object') copy(state, xg);
    prng.state = function() { return copy(xg, {}); };
  }
  return prng;
}

if (module && module.exports) {
  module.exports = impl;
} else if (define && define.amd) {
  define(function() { return impl; });
} else {
  this.xorwow = impl;
}

})(
  commonjsGlobal,
  // NOTE: bundler-constant-folded UMD feature tests (CommonJS branch taken).
  ('object') == 'object' && module, // present in node.js
  (typeof undefined) == 'function' && undefined // present with an AMD loader
);
} (xorwow$3));

var xorwowExports = xorwow$3.exports;
var xorwow$2 = /*@__PURE__*/getDefaultExportFromCjs(xorwowExports);
16506
// Bundled CommonJS shim for the vendored "xorshift7" generator from the
// `seedrandom` package (bundler-generated vendor code; do not edit by hand).
var xorshift7$3 = {exports: {}};

var xorshift7$1 = xorshift7$3.exports;

(function (module) {
// A Javascript implementaion of the "xorshift7" algorithm by
// François Panneton and Pierre L'ecuyer:
// "On the Xorgshift Random Number Generators"
// http://saluc.engr.uconn.edu/refs/crypto/rng/panneton05onthexorshift.pdf

(function(global, module, define) {

// 256-bit state held as an 8-element circular array X with index i.
function XorGen(seed) {
  var me = this;

  // Set up generator function.
  me.next = function() {
    // Update xor generator.
    var X = me.x, i = me.i, t, v, w;
    t = X[i]; t ^= (t >>> 7); v = t ^ (t << 24);
    t = X[(i + 1) & 7]; v ^= t ^ (t >>> 10);
    t = X[(i + 3) & 7]; v ^= t ^ (t >>> 3);
    t = X[(i + 4) & 7]; v ^= t ^ (t << 7);
    t = X[(i + 7) & 7]; t = t ^ (t << 13); v ^= t ^ (t << 9);
    X[i] = v;
    me.i = (i + 1) & 7;
    return v;
  };

  // Seeds the state array from an int32 or a string, then warms up.
  function init(me, seed) {
    var j, w, X = [];

    if (seed === (seed | 0)) {
      // Seed state array using a 32-bit integer.
      w = X[0] = seed;
    } else {
      // Seed state using a string.
      seed = '' + seed;
      for (j = 0; j < seed.length; ++j) {
        X[j & 7] = (X[j & 7] << 15) ^
            (seed.charCodeAt(j) + X[(j + 1) & 7] << 13);
      }
    }
    // Enforce an array length of 8, not all zeroes.
    while (X.length < 8) X.push(0);
    for (j = 0; j < 8 && X[j] === 0; ++j);
    if (j == 8) w = X[7] = -1; else w = X[j];

    me.x = X;
    me.i = 0;

    // Discard an initial 256 values.
    for (j = 256; j > 0; --j) {
      me.next();
    }
  }

  init(me, seed);
}

// Copies generator state from f onto t (deep-copies the state array).
function copy(f, t) {
  t.x = f.x.slice();
  t.i = f.i;
  return t;
}

// Public factory: returns prng() in [0, 1) with .double/.int32/.quick
// helpers and optional state save/restore via opts.state.
function impl(seed, opts) {
  if (seed == null) seed = +(new Date);
  var xg = new XorGen(seed),
      state = opts && opts.state,
      prng = function() { return (xg.next() >>> 0) / 0x100000000; };
  prng.double = function() {
    do {
      var top = xg.next() >>> 11,
          bot = (xg.next() >>> 0) / 0x100000000,
          result = (top + bot) / (1 << 21);
    } while (result === 0);
    return result;
  };
  prng.int32 = xg.next;
  prng.quick = prng;
  if (state) {
    if (state.x) copy(state, xg);
    prng.state = function() { return copy(xg, {}); };
  }
  return prng;
}

if (module && module.exports) {
  module.exports = impl;
} else if (define && define.amd) {
  define(function() { return impl; });
} else {
  this.xorshift7 = impl;
}

})(
  commonjsGlobal,
  // NOTE: bundler-constant-folded UMD feature tests (CommonJS branch taken).
  ('object') == 'object' && module, // present in node.js
  (typeof undefined) == 'function' && undefined // present with an AMD loader
);
} (xorshift7$3));

var xorshift7Exports = xorshift7$3.exports;
var xorshift7$2 = /*@__PURE__*/getDefaultExportFromCjs(xorshift7Exports);
16612
// Bundled CommonJS shim for the vendored "xor4096" generator from the
// `seedrandom` package (bundler-generated vendor code; do not edit by hand).
var xor4096$3 = {exports: {}};

var xor4096$1 = xor4096$3.exports;

(function (module) {
// A Javascript implementaion of Richard Brent's Xorgens xor4096 algorithm.
//
// This fast non-cryptographic random number generator is designed for
// use in Monte-Carlo algorithms. It combines a long-period xorshift
// generator with a Weyl generator, and it passes all common batteries
// of stasticial tests for randomness while consuming only a few nanoseconds
// for each prng generated. For background on the generator, see Brent's
// paper: "Some long-period random number generators using shifts and xors."
// http://arxiv.org/pdf/1004.3115v1.pdf
//
// Usage:
//
// var xor4096 = require('xor4096');
// random = xor4096(1);                        // Seed with int32 or string.
// assert.equal(random(), 0.1520436450538547); // (0, 1) range, 53 bits.
// assert.equal(random.int32(), 1806534897);   // signed int32, 32 bits.
//
// For nonzero numeric keys, this impelementation provides a sequence
// identical to that by Brent's xorgens 3 implementaion in C. This
// implementation also provides for initalizing the generator with
// string seeds, or for saving and restoring the state of the generator.
//
// On Chrome, this prng benchmarks about 2.1 times slower than
// Javascript's built-in Math.random().

(function(global, module, define) {

// 4096-bit xorshift state held in a 128-entry circular array X, combined
// with a Weyl counter w.
function XorGen(seed) {
  var me = this;

  // Set up generator function.
  me.next = function() {
    var w = me.w,
        X = me.X, i = me.i, t, v;
    // Update Weyl generator.
    me.w = w = (w + 0x61c88647) | 0;
    // Update xor generator.
    v = X[(i + 34) & 127];
    t = X[i = ((i + 1) & 127)];
    v ^= v << 13;
    t ^= t << 17;
    v ^= v >>> 15;
    t ^= t >>> 12;
    // Update Xor generator array state.
    v = X[i] = v ^ t;
    me.i = i;
    // Result is the combination.
    return (v + (w ^ (w >>> 16))) | 0;
  };

  // Seeds the 128-entry state array from an int32 or string, guards
  // against an all-zero state, then warms the generator up.
  function init(me, seed) {
    var t, v, i, j, w, X = [], limit = 128;
    if (seed === (seed | 0)) {
      // Numeric seeds initialize v, which is used to generates X.
      v = seed;
      seed = null;
    } else {
      // String seeds are mixed into v and X one character at a time.
      seed = seed + '\0';
      v = 0;
      limit = Math.max(limit, seed.length);
    }
    // Initialize circular array and weyl value.
    for (i = 0, j = -32; j < limit; ++j) {
      // Put the unicode characters into the array, and shuffle them.
      if (seed) v ^= seed.charCodeAt((j + 32) % seed.length);
      // After 32 shuffles, take v as the starting w value.
      if (j === 0) w = v;
      v ^= v << 10;
      v ^= v >>> 15;
      v ^= v << 4;
      v ^= v >>> 13;
      if (j >= 0) {
        w = (w + 0x61c88647) | 0;     // Weyl.
        t = (X[j & 127] ^= (v + w));  // Combine xor and weyl to init array.
        i = (0 == t) ? i + 1 : 0;     // Count zeroes.
      }
    }
    // We have detected all zeroes; make the key nonzero.
    if (i >= 128) {
      X[(seed && seed.length || 0) & 127] = -1;
    }
    // Run the generator 512 times to further mix the state before using it.
    // Factoring this as a function slows the main generator, so it is just
    // unrolled here. The weyl generator is not advanced while warming up.
    i = 127;
    for (j = 4 * 128; j > 0; --j) {
      v = X[(i + 34) & 127];
      t = X[i = ((i + 1) & 127)];
      v ^= v << 13;
      t ^= t << 17;
      v ^= v >>> 15;
      t ^= t >>> 12;
      X[i] = v ^ t;
    }
    // Storing state as object members is faster than using closure variables.
    me.w = w;
    me.X = X;
    me.i = i;
  }

  init(me, seed);
}

// Copies generator state from f onto t (deep-copies the state array).
function copy(f, t) {
  t.i = f.i;
  t.w = f.w;
  t.X = f.X.slice();
  return t;
};

// Public factory: returns prng() in [0, 1) with .double/.int32/.quick
// helpers and optional state save/restore via opts.state.
function impl(seed, opts) {
  if (seed == null) seed = +(new Date);
  var xg = new XorGen(seed),
      state = opts && opts.state,
      prng = function() { return (xg.next() >>> 0) / 0x100000000; };
  prng.double = function() {
    do {
      var top = xg.next() >>> 11,
          bot = (xg.next() >>> 0) / 0x100000000,
          result = (top + bot) / (1 << 21);
    } while (result === 0);
    return result;
  };
  prng.int32 = xg.next;
  prng.quick = prng;
  if (state) {
    if (state.X) copy(state, xg);
    prng.state = function() { return copy(xg, {}); };
  }
  return prng;
}

if (module && module.exports) {
  module.exports = impl;
} else if (define && define.amd) {
  define(function() { return impl; });
} else {
  this.xor4096 = impl;
}

})(
  commonjsGlobal, // window object or global
  // NOTE: bundler-constant-folded UMD feature tests (CommonJS branch taken).
  ('object') == 'object' && module, // present in node.js
  (typeof undefined) == 'function' && undefined // present with an AMD loader
);
} (xor4096$3));

var xor4096Exports = xor4096$3.exports;
var xor4096$2 = /*@__PURE__*/getDefaultExportFromCjs(xor4096Exports);
16768
// Bundled CommonJS shim for the vendored "tychei" (Tyche-i) generator from
// the `seedrandom` package (bundler-generated vendor code; do not edit by hand).
var tychei$3 = {exports: {}};

var tychei$1 = tychei$3.exports;

(function (module) {
// A Javascript implementaion of the "Tyche-i" prng algorithm by
// Samuel Neves and Filipe Araujo.
// See https://eden.dei.uc.pt/~sneves/pubs/2011-snfa2.pdf

(function(global, module, define) {

// 128-bit state in four 32-bit words (a, b, c, d). The mixing below
// matches upstream seedrandom exactly (including the `(c >>> 16)` term
// in me.d) — do not "correct" it, or seeded sequences change.
function XorGen(seed) {
  var me = this, strseed = '';

  // Set up generator function.
  me.next = function() {
    var b = me.b, c = me.c, d = me.d, a = me.a;
    b = (b << 25) ^ (b >>> 7) ^ c;
    c = (c - d) | 0;
    d = (d << 24) ^ (d >>> 8) ^ a;
    a = (a - b) | 0;
    me.b = b = (b << 20) ^ (b >>> 12) ^ c;
    me.c = c = (c - d) | 0;
    me.d = (d << 16) ^ (c >>> 16) ^ a;
    return me.a = (a - b) | 0;
  };

  /* The following is non-inverted tyche, which has better internal
   * bit diffusion, but which is about 25% slower than tyche-i in JS.
  me.next = function() {
    var a = me.a, b = me.b, c = me.c, d = me.d;
    a = (me.a + me.b | 0) >>> 0;
    d = me.d ^ a; d = d << 16 ^ d >>> 16;
    c = me.c + d | 0;
    b = me.b ^ c; b = b << 12 ^ d >>> 20;
    me.a = a = a + b | 0;
    d = d ^ a; me.d = d = d << 8 ^ d >>> 24;
    me.c = c = c + d | 0;
    b = b ^ c;
    return me.b = (b << 7 ^ b >>> 25);
  }
  */

  me.a = 0;
  me.b = 0;
  me.c = 2654435769 | 0;
  me.d = 1367130551;

  if (seed === Math.floor(seed)) {
    // Integer seed.
    me.a = (seed / 0x100000000) | 0;
    me.b = seed | 0;
  } else {
    // String seed.
    strseed += seed;
  }

  // Mix in string seed, then discard an initial batch of 64 values.
  for (var k = 0; k < strseed.length + 20; k++) {
    me.b ^= strseed.charCodeAt(k) | 0;
    me.next();
  }
}

// Copies generator state from f onto t; used for state save/restore.
function copy(f, t) {
  t.a = f.a;
  t.b = f.b;
  t.c = f.c;
  t.d = f.d;
  return t;
};

// Public factory: returns prng() in [0, 1) with .double/.int32/.quick
// helpers and optional state save/restore via opts.state.
function impl(seed, opts) {
  var xg = new XorGen(seed),
      state = opts && opts.state,
      prng = function() { return (xg.next() >>> 0) / 0x100000000; };
  prng.double = function() {
    do {
      var top = xg.next() >>> 11,
          bot = (xg.next() >>> 0) / 0x100000000,
          result = (top + bot) / (1 << 21);
    } while (result === 0);
    return result;
  };
  prng.int32 = xg.next;
  prng.quick = prng;
  if (state) {
    if (typeof(state) == 'object') copy(state, xg);
    prng.state = function() { return copy(xg, {}); };
  }
  return prng;
}

if (module && module.exports) {
  module.exports = impl;
} else if (define && define.amd) {
  define(function() { return impl; });
} else {
  this.tychei = impl;
}

})(
  commonjsGlobal,
  // NOTE: bundler-constant-folded UMD feature tests (CommonJS branch taken).
  ('object') == 'object' && module, // present in node.js
  (typeof undefined) == 'function' && undefined // present with an AMD loader
);
} (tychei$3));

var tycheiExports = tychei$3.exports;
var tychei$2 = /*@__PURE__*/getDefaultExportFromCjs(tycheiExports);
16879
// Bundled CommonJS shim for the vendored core `seedrandom` (ARC4-based)
// generator (bundler-generated vendor code; do not edit by hand).
var seedrandom$3 = {exports: {}};

/*
Copyright 2019 David Bau.

Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:

The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

*/
var seedrandom$1 = seedrandom$3.exports;

(function (module) {
(function (global, pool, math) {
//
// The following constants are related to IEEE 754 limits.
//

var width = 256,        // each RC4 output is 0 <= x < 256
    chunks = 6,         // at least six RC4 outputs for each double
    digits = 52,        // there are 52 significant digits in a double
    rngname = 'random', // rngname: name for Math.random and Math.seedrandom
    startdenom = math.pow(width, chunks),
    significance = math.pow(2, digits),
    overflow = significance * 2,
    mask = width - 1,
    nodecrypto;         // node.js crypto module, initialized at the bottom.

//
// seedrandom()
// This is the seedrandom function described above.
//
function seedrandom(seed, options, callback) {
  var key = [];
  options = (options == true) ? { entropy: true } : (options || {});

  // Flatten the seed string or build one from local entropy if needed.
  var shortseed = mixkey(flatten(
    options.entropy ? [seed, tostring(pool)] :
    (seed == null) ? autoseed() : seed, 3), key);

  // Use the seed to initialize an ARC4 generator.
  var arc4 = new ARC4(key);

  // This function returns a random double in [0, 1) that contains
  // randomness in every bit of the mantissa of the IEEE 754 value.
  var prng = function() {
    var n = arc4.g(chunks),             // Start with a numerator n < 2 ^ 48
        d = startdenom,                 //   and denominator d = 2 ^ 48.
        x = 0;                          //   and no 'extra last byte'.
    while (n < significance) {          // Fill up all significant digits by
      n = (n + x) * width;              //   shifting numerator and
      d *= width;                       //   denominator and generating a
      x = arc4.g(1);                    //   new least-significant-byte.
    }
    while (n >= overflow) {             // To avoid rounding up, before adding
      n /= 2;                           //   last byte, shift everything
      d /= 2;                           //   right using integer math until
      x >>>= 1;                         //   we have exactly the desired bits.
    }
    return (n + x) / d;                 // Form the number within [0, 1).
  };

  prng.int32 = function() { return arc4.g(4) | 0; };
  prng.quick = function() { return arc4.g(4) / 0x100000000; };
  prng.double = prng;

  // Mix the randomness into accumulated entropy.
  mixkey(tostring(arc4.S), pool);

  // Calling convention: what to return as a function of prng, seed, is_math.
  return (options.pass || callback ||
      function(prng, seed, is_math_call, state) {
        if (state) {
          // Load the arc4 state from the given state if it has an S array.
          if (state.S) { copy(state, arc4); }
          // Only provide the .state method if requested via options.state.
          prng.state = function() { return copy(arc4, {}); };
        }

        // If called as a method of Math (Math.seedrandom()), mutate
        // Math.random because that is how seedrandom.js has worked since v1.0.
        if (is_math_call) { math[rngname] = prng; return seed; }

        // Otherwise, it is a newer calling convention, so return the
        // prng directly.
        else return prng;
      })(
  prng,
  shortseed,
  'global' in options ? options.global : (this == math),
  options.state);
}

//
// ARC4
//
// An ARC4 implementation. The constructor takes a key in the form of
// an array of at most (width) integers that should be 0 <= x < (width).
//
// The g(count) method returns a pseudorandom integer that concatenates
// the next (count) outputs from ARC4. Its return value is a number x
// that is in the range 0 <= x < (width ^ count).
//
function ARC4(key) {
  var t, keylen = key.length,
      me = this, i = 0, j = me.i = me.j = 0, s = me.S = [];

  // The empty key [] is treated as [0].
  if (!keylen) { key = [keylen++]; }

  // Set up S using the standard key scheduling algorithm.
  while (i < width) {
    s[i] = i++;
  }
  for (i = 0; i < width; i++) {
    s[i] = s[j = mask & (j + key[i % keylen] + (t = s[i]))];
    s[j] = t;
  }

  // The "g" method returns the next (count) outputs as one number.
  (me.g = function(count) {
    // Using instance members instead of closure state nearly doubles speed.
    var t, r = 0,
        i = me.i, j = me.j, s = me.S;
    while (count--) {
      t = s[i = mask & (i + 1)];
      r = r * width + s[mask & ((s[i] = s[j = mask & (j + t)]) + (s[j] = t))];
    }
    me.i = i; me.j = j;
    return r;
  // For robust unpredictability, the function call below automatically
  // discards an initial batch of values. This is called RC4-drop[256].
  // See http://google.com/search?q=rsa+fluhrer+response&btnI
  })(width);
}

//
// copy()
// Copies internal state of ARC4 to or from a plain object.
//
function copy(f, t) {
  t.i = f.i;
  t.j = f.j;
  t.S = f.S.slice();
  return t;
};

//
// flatten()
// Converts an object tree to nested arrays of strings.
//
function flatten(obj, depth) {
  var result = [], typ = (typeof obj), prop;
  if (depth && typ == 'object') {
    for (prop in obj) {
      try { result.push(flatten(obj[prop], depth - 1)); } catch (e) {}
    }
  }
  return (result.length ? result : typ == 'string' ? obj : obj + '\0');
}

//
// mixkey()
// Mixes a string seed into a key that is an array of integers, and
// returns a shortened string seed that is equivalent to the result key.
//
function mixkey(seed, key) {
  var stringseed = seed + '', smear, j = 0;
  while (j < stringseed.length) {
    key[mask & j] =
      mask & ((smear ^= key[mask & j] * 19) + stringseed.charCodeAt(j++));
  }
  return tostring(key);
}

//
// autoseed()
// Returns an object for autoseeding, using window.crypto and Node crypto
// module if available.
//
function autoseed() {
  try {
    var out;
    if (nodecrypto && (out = nodecrypto.randomBytes)) {
      // The use of 'out' to remember randomBytes makes tight minified code.
      out = out(width);
    } else {
      out = new Uint8Array(width);
      (global.crypto || global.msCrypto).getRandomValues(out);
    }
    return tostring(out);
  } catch (e) {
    // Fall back to weak browser entropy when no crypto API is available.
    var browser = global.navigator,
        plugins = browser && browser.plugins;
    return [+new Date, global, plugins, global.screen, tostring(pool)];
  }
}

//
// tostring()
// Converts an array of charcodes to a string
//
function tostring(a) {
  return String.fromCharCode.apply(0, a);
}

//
// When seedrandom.js is loaded, we immediately mix a few bits
// from the built-in RNG into the entropy pool. Because we do
// not want to interfere with deterministic PRNG state later,
// seedrandom will not call math.random on its own again after
// initialization.
//
mixkey(math.random(), pool);

//
// Nodejs and AMD support: export the implementation as a module using
// either convention.
//
// NOTE: the `('object') == 'object'` / `(typeof undefined)` expressions are
// bundler rewrites of the original UMD feature tests; the CommonJS branch
// is always taken after bundling and the AMD branch is dead code.
if (('object') == 'object' && module.exports) {
  module.exports = seedrandom;
  // When in node.js, try using crypto package for autoseeding.
  try {
    nodecrypto = require('crypto');
  } catch (ex) {}
} else if ((typeof undefined) == 'function' && undefined.amd) {
  undefined(function() { return seedrandom; });
} else {
  // When included as a plain script, set up Math.seedrandom global.
  math['seed' + rngname] = seedrandom;
}


// End anonymous scope, and pass initial values.
})(
  // global: `self` in browsers (including strict mode and web workers),
  // otherwise `this` in Node and other environments
  (typeof self !== 'undefined') ? self : commonjsGlobal,
  [],   // pool: entropy pool starts empty
  Math  // math: package containing random, pow, and seedrandom
);
} (seedrandom$3));

var seedrandomExports = seedrandom$3.exports;
var seedrandom$2 = /*@__PURE__*/getDefaultExportFromCjs(seedrandomExports);
17141
// Index of the vendored `seedrandom` package: attaches every alternate
// generator onto the core ARC4-based export, mirroring seedrandom's
// own index.js.
//
// A library of seedable RNGs implemented in Javascript.
//
// Usage:
//
// var seedrandom = require('seedrandom');
// var random = seedrandom(1); // or any seed.
// var x = random();       // 0 <= x < 1.  Every bit is random.
// var x = random.quick(); // 0 <= x < 1.  32 bits of randomness.

// alea, a 53-bit multiply-with-carry generator by Johannes Baagøe.
// Period: ~2^116
// Reported to pass all BigCrush tests.
var alea = aleaExports;

// xor128, a pure xor-shift generator by George Marsaglia.
// Period: 2^128-1.
// Reported to fail: MatrixRank and LinearComp.
var xor128 = xor128Exports;

// xorwow, George Marsaglia's 160-bit xor-shift combined plus weyl.
// Period: 2^192-2^32
// Reported to fail: CollisionOver, SimpPoker, and LinearComp.
var xorwow = xorwowExports;

// xorshift7, by François Panneton and Pierre L'ecuyer, takes
// a different approach: it adds robustness by allowing more shifts
// than Marsaglia's original three.  It is a 7-shift generator
// with 256 bits, that passes BigCrush with no systmatic failures.
// Period 2^256-1.
// No systematic BigCrush failures reported.
var xorshift7 = xorshift7Exports;

// xor4096, by Richard Brent, is a 4096-bit xor-shift with a
// very long period that also adds a Weyl generator. It also passes
// BigCrush with no systematic failures.  Its long period may
// be useful if you have many generators and need to avoid
// collisions.
// Period: 2^4128-2^32.
// No systematic BigCrush failures reported.
var xor4096 = xor4096Exports;

// Tyche-i, by Samuel Neves and Filipe Araujo, is a bit-shifting random
// number generator derived from ChaCha, a modern stream cipher.
// https://eden.dei.uc.pt/~sneves/pubs/2011-snfa2.pdf
// Period: ~2^127
// No systematic BigCrush failures reported.
var tychei = tycheiExports;

// The original ARC4-based prng included in this library.
// Period: ~2^1600
var sr = seedrandomExports;

sr.alea = alea;
sr.xor128 = xor128;
sr.xorwow = xorwow;
sr.xorshift7 = xorshift7;
sr.xor4096 = xor4096;
sr.tychei = tychei;

var seedrandom = sr;

var index$1 = /*@__PURE__*/getDefaultExportFromCjs(seedrandom);
17204
17205 /**
17206 * @license
17207 * Copyright 2017 Google LLC. All Rights Reserved.
17208 * Licensed under the Apache License, Version 2.0 (the "License");
17209 * you may not use this file except in compliance with the License.
17210 * You may obtain a copy of the License at
17211 *
17212 * http://www.apache.org/licenses/LICENSE-2.0
17213 *
17214 * Unless required by applicable law or agreed to in writing, software
17215 * distributed under the License is distributed on an "AS IS" BASIS,
17216 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17217 * See the License for the specific language governing permissions and
17218 * limitations under the License.
17219 * =============================================================================
17220 */
// Default element-comparison tolerances for float32 / float16 backends.
const TEST_EPSILON_FLOAT32 = 1e-3;
const TEST_EPSILON_FLOAT16 = 1e-1;
/**
 * Asserts that `actual` and `expected` are element-wise close within
 * `epsilon`. When `epsilon` is omitted, the backend-appropriate test
 * epsilon is used. Throws a descriptive Error on the first mismatch.
 */
function expectArraysClose(actual, expected, epsilon) {
    const eps = epsilon == null ? testEpsilon() : epsilon;
    return expectArraysPredicate(actual, expected, (a, b) => areClose(a, b, eps));
}
/**
 * Returns the element-comparison tolerance appropriate for the currently
 * active backend's float precision (32-bit vs 16-bit).
 */
function testEpsilon() {
    if (ENGINE.backend.floatPrecision() === 32) {
        return TEST_EPSILON_FLOAT32;
    }
    return TEST_EPSILON_FLOAT16;
}
/**
 * Core comparison helper: validates that `actual` and `expected` have
 * compatible class types, shapes and lengths, then applies `predicate`
 * element-wise over the flattened values, throwing a descriptive Error on
 * the first mismatch.
 */
function expectArraysPredicate(actual, expected, predicate) {
    // Constructor names must match when both sides are typed arrays, or
    // when neither is; a mixed typed/plain pair skips the class check.
    const bothTyped = isTypedArray(actual) && isTypedArray(expected);
    const eitherTyped = isTypedArray(actual) || isTypedArray(expected);
    if (bothTyped || !eitherTyped) {
        const aType = actual.constructor.name;
        const bType = expected.constructor.name;
        if (aType !== bType) {
            throw new Error(`Arrays are of different type. Actual: ${aType}. ` +
                `Expected: ${bType}`);
        }
    }
    if (Array.isArray(actual) && Array.isArray(expected)) {
        const actualShape = inferShape(actual);
        const expectedShape = inferShape(expected);
        if (!arraysEqual(actualShape, expectedShape)) {
            throw new Error(`Arrays have different shapes. ` +
                `Actual: [${actualShape}]. Expected: [${expectedShape}]`);
        }
    }
    const actualFlat = isTypedArray(actual) ? actual : flatten$2(actual);
    const expectedFlat = isTypedArray(expected) ? expected : flatten$2(expected);
    if (actualFlat.length !== expectedFlat.length) {
        throw new Error(`Arrays have different lengths actual: ${actualFlat.length} vs ` +
            `expected: ${expectedFlat.length}.\n` +
            `Actual: ${actualFlat}.\n` +
            `Expected: ${expectedFlat}.`);
    }
    for (let i = 0; i < expectedFlat.length; ++i) {
        const a = actualFlat[i];
        const e = expectedFlat[i];
        if (!predicate(a, e)) {
            throw new Error(`Arrays differ: actual[${i}] = ${a}, expected[${i}] = ${e}.\n` +
                `Actual: ${actualFlat}.\n` +
                `Expected: ${expectedFlat}.`);
        }
    }
    // Tell jasmine (when present) that this spec made an expectation.
    if (typeof expect !== 'undefined') {
        expect().nothing();
    }
}
/**
 * Asserts that the promise returned by `fn` rejects: `done` is invoked on
 * rejection and `done.fail` on unexpected fulfillment.
 */
function expectPromiseToFail(fn, done) {
    const onFulfilled = () => done.fail();
    const onRejected = () => done();
    fn().then(onFulfilled, onRejected);
    // Tell jasmine (when present) that this spec made an expectation.
    if (typeof expect !== 'undefined') {
        expect().nothing();
    }
}
/**
 * Asserts exact element-wise equality. String-valued inputs are compared
 * with loose equality; numeric inputs with a zero-epsilon closeness check
 * (so NaN/Infinity pairs still match).
 */
function expectArraysEqual(actual, expected) {
    const expectedIsScalar = typeof expected === 'string' ||
        typeof expected === 'number' || typeof expected === 'boolean';
    const exp = expectedIsScalar ? [expected] : expected;
    if (isString(actual) || isString(actual[0]) ||
        isString(expected) || isString(expected[0])) {
        // tslint:disable-next-line: triple-equals
        return expectArraysPredicate(actual, exp, (a, b) => a == b);
    }
    return expectArraysPredicate(actual, expected, (a, b) => areClose(a, b, 0));
}
/**
 * Asserts that two scalars differ by at most `epsilon` (defaults to the
 * backend-appropriate test epsilon). Throws a descriptive Error otherwise.
 */
function expectNumbersClose(a, e, epsilon) {
    const eps = epsilon == null ? testEpsilon() : epsilon;
    if (!areClose(a, e, eps)) {
        throw new Error(`Numbers differ: actual === ${a}, expected === ${e}`);
    }
    // Tell jasmine (when present) that this spec made an expectation.
    if (typeof expect !== 'undefined') {
        expect().nothing();
    }
}
/**
 * Returns true when `a` and `e` are within `epsilon` of each other.
 * Any pair of non-finite values (NaN/±Infinity in any combination) is
 * treated as equal; a single NaN never matches a finite value.
 */
function areClose(a, e, epsilon) {
    const bothNonFinite = !isFinite(a) && !isFinite(e);
    if (bothNonFinite) {
        return true;
    }
    const mismatch = isNaN(a) || isNaN(e) || Math.abs(a - e) > epsilon;
    return !mismatch;
}
/**
 * Asserts that every element of `actual` lies within the closed range
 * [low, high]; throws on the first out-of-range value.
 */
function expectValuesInRange(actual, low, high) {
    for (let i = 0; i < actual.length; i++) {
        const v = actual[i];
        const outOfRange = v < low || v > high;
        if (outOfRange) {
            throw new Error(`Value out of range:${v} low: ${low}, high: ${high}`);
        }
    }
}
/**
 * Asserts that two ArrayBuffers hold identical contents, comparing them
 * through Float32Array views (Safari does not like comparing ArrayBuffers
 * directly, so wrapping solves that issue).
 */
function expectArrayBuffersEqual(actual, expected) {
    const actualArray = new Float32Array(actual);
    const expectedArray = new Float32Array(expected);
    if (actualArray.length !== expectedArray.length) {
        throw new Error('Expected ArrayBuffer to be of length ' +
            `${expectedArray.length}, but it was ${actualArray.length}`);
    }
    expectedArray.forEach((want, i) => {
        if (actualArray[i] !== want) {
            throw new Error(`Expected ArrayBuffer value at ${i} to be ` +
                `${want} but got ${actualArray[i]} instead`);
        }
    });
}
/**
 * Encodes strings into utf-8 bytes, in place: recurses into nested arrays
 * and replaces each leaf string with its encoded form. Returns `a`.
 */
function encodeStrings(a) {
    a.forEach((val, i) => {
        if (Array.isArray(val)) {
            encodeStrings(val);
        } else {
            a[i] = encodeString(val);
        }
    });
    return a;
}
/**
 * Creates an HTMLVideoElement with autoplay-friendly default settings.
 * Resolves once the first frame of data has loaded.
 */
function createVideoElement(source) {
    const video = document.createElement('video');
    // Inline playback is required for autoplay on iOS Safari.
    if ('playsInline' in video) {
        video.playsInline = true;
    }
    Object.assign(video, { muted: true, loop: true, preload: 'auto' });
    Object.assign(video.style, { position: 'fixed', left: '0px', top: '0px' });
    video.appendChild(source);
    return new Promise((resolve) => {
        video.addEventListener('loadeddata', () => resolve(video));
        video.load();
    });
}
/**
 * Starts playback and, where the browser supports it, waits until a video
 * frame has actually been presented.
 */
async function play(video) {
    await video.play();
    if (!('requestVideoFrameCallback' in video)) {
        return;
    }
    await new Promise((resolve) => video.requestVideoFrameCallback(resolve));
}
17382
// Public test-utility namespace (exported as `tf.test_util`): a frozen,
// tree-shake-friendly re-export of the assertion and media helpers above.
var test_util = /*#__PURE__*/Object.freeze({
    __proto__: null,
    TEST_EPSILON_FLOAT16: TEST_EPSILON_FLOAT16,
    createVideoElement: createVideoElement,
    encodeStrings: encodeStrings,
    expectArrayBuffersEqual: expectArrayBuffersEqual,
    expectArraysClose: expectArraysClose,
    expectArraysEqual: expectArraysEqual,
    expectNumbersClose: expectNumbersClose,
    expectPromiseToFail: expectPromiseToFail,
    expectValuesInRange: expectValuesInRange,
    play: play,
    testEpsilon: testEpsilon
});
17397
17398 /**
17399 * @license
17400 * Copyright 2018 Google LLC. All Rights Reserved.
17401 * Licensed under the Apache License, Version 2.0 (the "License");
17402 * you may not use this file except in compliance with the License.
17403 * You may obtain a copy of the License at
17404 *
17405 * http://www.apache.org/licenses/LICENSE-2.0
17406 *
17407 * Unless required by applicable law or agreed to in writing, software
17408 * distributed under the License is distributed on an "AS IS" BASIS,
17409 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17410 * See the License for the specific language governing permissions and
17411 * limitations under the License.
17412 * =============================================================================
17413 */
// https://en.wikipedia.org/wiki/Marsaglia_polar_method
class MPRandGauss {
    /**
     * Seeded Gaussian sampler using the Marsaglia polar method.
     * When `truncated` is set, samples beyond two standard deviations from
     * the mean are rejected.
     */
    constructor(mean, stdDeviation, dtype, truncated, seed) {
        this.mean = mean;
        this.stdDev = stdDeviation;
        this.dtype = dtype;
        // NaN marks "no cached second sample" for the polar-method pair.
        this.nextVal = NaN;
        this.truncated = truncated;
        if (this.truncated) {
            this.upper = this.mean + this.stdDev * 2;
            this.lower = this.mean - this.stdDev * 2;
        }
        const seedValue = seed ? seed : Math.random();
        this.random = seedrandom.alea(seedValue.toString());
    }
    /** Returns next sample from a Gaussian distribution. */
    nextValue() {
        if (!isNaN(this.nextVal)) {
            const cached = this.nextVal;
            this.nextVal = NaN;
            return cached;
        }
        for (;;) {
            let u1;
            let u2;
            let s;
            // Draw points in the unit square until one lands strictly inside
            // the unit circle (excluding the origin).
            do {
                u1 = 2 * this.random() - 1;
                u2 = 2 * this.random() - 1;
                s = u1 * u1 + u2 * u2;
            } while (s >= 1 || s === 0);
            const mul = Math.sqrt(-2.0 * Math.log(s) / s);
            const first = this.mean + this.stdDev * u1 * mul;
            const second = this.mean + this.stdDev * u2 * mul;
            if (this.truncated && !this.isValidTruncated(first)) {
                continue;
            }
            // Cache the paired sample for the next call when it is usable.
            if (!this.truncated || this.isValidTruncated(second)) {
                this.nextVal = this.convertValue(second);
            }
            return this.convertValue(first);
        }
    }
    /** Handles proper rounding for non-floating-point numbers. */
    convertValue(value) {
        return (this.dtype == null || this.dtype === 'float32') ? value : Math.round(value);
    }
    /** Returns true if less than 2-standard-deviations from the mean. */
    isValidTruncated(value) {
        return value >= this.lower && value <= this.upper;
    }
}
// Marsaglia, George, and Wai Wan Tsang. 2000. "A Simple Method for Generating
// Gamma Variables."
class RandGamma {
    // alpha: shape parameter; beta: rate parameter (stored as its reciprocal);
    // dtype: 'float32' or 'int32'; seed: optional RNG seed.
    constructor(alpha, beta, dtype, seed) {
        this.alpha = alpha;
        this.beta = 1 / beta; // convert rate to scale parameter
        this.dtype = dtype;
        const seedValue = seed ? seed : Math.random();
        // Uniform stream drives the acceptance tests and the alpha < 1 boost.
        this.randu = seedrandom.alea(seedValue.toString());
        // Normal stream for the x draws, seeded from the uniform stream.
        this.randn = new MPRandGauss(0, 1, dtype, false, this.randu());
        // For alpha < 1 the sampler effectively draws from Gamma(alpha + 1)
        // (d = (alpha + 1) - 1/3) and rescales in nextValue(); otherwise
        // d = alpha - 1/3 per Marsaglia & Tsang.
        if (alpha < 1) {
            this.d = alpha + (2 / 3);
        }
        else {
            this.d = alpha - (1 / 3);
        }
        this.c = 1 / Math.sqrt(9 * this.d);
    }
    /** Returns next sample from a gamma distribution. */
    nextValue() {
        let x2, v0, v1, x, u, v;
        while (true) {
            // Draw x ~ N(0, 1) until v = (1 + c*x)^3 is strictly positive.
            do {
                x = this.randn.nextValue();
                v = 1 + (this.c * x);
            } while (v <= 0);
            v *= v * v;
            x2 = x * x;
            // Quick squeeze test (v0), then the full log acceptance test (v1).
            // NOTE(review): the paper's squeeze constant is 0.0331; this code
            // uses 0.331 — confirm against upstream before changing.
            v0 = 1 - (0.331 * x2 * x2);
            v1 = (0.5 * x2) + (this.d * (1 - v + Math.log(v)));
            u = this.randu();
            if (u < v0 || Math.log(u) < v1) {
                break;
            }
        }
        // NOTE(review): this.beta holds 1/beta, so (1 / this.beta) re-applies
        // the original rate rather than the scale — verify intent upstream.
        v = (1 / this.beta) * this.d * v;
        if (this.alpha < 1) {
            // Boost step: converts a Gamma(alpha + 1) draw into Gamma(alpha).
            v *= Math.pow(this.randu(), 1 / this.alpha);
        }
        return this.convertValue(v);
    }
    /** Handles proper rounding for non-floating-point numbers. */
    convertValue(value) {
        if (this.dtype === 'float32') {
            return value;
        }
        return Math.round(value);
    }
}
/** Seeded sampler for the uniform distribution over [min, max). */
class UniformRandom {
    constructor(min = 0, max = 1, dtype, seed) {
        /** Handles proper rounding for non floating point numbers. */
        this.canReturnFloat = () => this.dtype == null || this.dtype === 'float32';
        this.min = min;
        this.range = max - min;
        this.dtype = dtype;
        // Normalize the seed to a string, generating one when absent.
        let seedString = seed;
        if (seedString == null) {
            seedString = Math.random();
        }
        if (typeof seedString === 'number') {
            seedString = seedString.toString();
        }
        // Integer outputs need a range wide enough to round into.
        if (!this.canReturnFloat() && this.range <= 1) {
            throw new Error(`The difference between ${min} - ${max} <= 1 and dtype is not float`);
        }
        this.random = seedrandom.alea(seedString);
    }
    convertValue(value) {
        return this.canReturnFloat() ? value : Math.round(value);
    }
    nextValue() {
        return this.convertValue(this.min + this.range * this.random());
    }
}
/** Throws when the Jarque-Bera statistic rejects normality of `values`. */
function jarqueBeraNormalityTest(values) {
    // https://en.wikipedia.org/wiki/Jarque%E2%80%93Bera_test
    const n = values.length;
    const skew = skewness(values);
    const kurt = kurtosis(values);
    const jb = (n / 6) * (Math.pow(skew, 2) + 0.25 * Math.pow(kurt - 3, 2));
    // JB test requires 2-degress of freedom from Chi-Square @ 0.95:
    // http://www.itl.nist.gov/div898/handbook/eda/section3/eda3674.htm
    const CHI_SQUARE_2DEG = 5.991;
    if (jb > CHI_SQUARE_2DEG) {
        throw new Error(`Invalid p-value for JB: ${jb}`);
    }
}
/**
 * Asserts the sample mean and standard deviation of `actual` are within
 * `epsilon` of the expected values.
 */
function expectArrayInMeanStdRange(actual, expectedMean, expectedStdDev, epsilon) {
    const eps = epsilon == null ? testEpsilon() : epsilon;
    const actualMean = mean$2(actual);
    expectNumbersClose(actualMean, expectedMean, eps);
    expectNumbersClose(standardDeviation(actual, actualMean), expectedStdDev, eps);
}
/** Returns the arithmetic mean of `values` (NaN for an empty input). */
function mean$2(values) {
    let total = 0;
    for (const v of values) {
        total += v;
    }
    return total / values.length;
}
/** Returns the population standard deviation of `values` about `mean`. */
function standardDeviation(values, mean) {
    let sumSq = 0;
    for (const v of values) {
        const diff = v - mean;
        sumSq += diff * diff;
    }
    return Math.sqrt(sumSq / values.length);
}
/** Returns the sample kurtosis of `values`. */
function kurtosis(values) {
    // https://en.wikipedia.org/wiki/Kurtosis
    const centre = mean$2(values);
    const n = values.length;
    let sum2 = 0;
    let sum4 = 0;
    for (const value of values) {
        const dev = value - centre;
        sum2 += Math.pow(dev, 2);
        sum4 += Math.pow(dev, 4);
    }
    return (1 / n) * sum4 / Math.pow((1 / n) * sum2, 2);
}
/** Returns the sample skewness of `values`. */
function skewness(values) {
    // https://en.wikipedia.org/wiki/Skewness
    const centre = mean$2(values);
    const n = values.length;
    let sum2 = 0;
    let sum3 = 0;
    for (const value of values) {
        const dev = value - centre;
        sum2 += Math.pow(dev, 2);
        sum3 += Math.pow(dev, 3);
    }
    return (1 / n) * sum3 / Math.pow((1 / (n - 1)) * sum2, 3 / 2);
}
17608
17609 /**
17610 * @license
17611 * Copyright 2020 Google LLC. All Rights Reserved.
17612 * Licensed under the Apache License, Version 2.0 (the "License");
17613 * you may not use this file except in compliance with the License.
17614 * You may obtain a copy of the License at
17615 *
17616 * http://www.apache.org/licenses/LICENSE-2.0
17617 *
17618 * Unless required by applicable law or agreed to in writing, software
17619 * distributed under the License is distributed on an "AS IS" BASIS,
17620 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17621 * See the License for the specific language governing permissions and
17622 * limitations under the License.
17623 * =============================================================================
17624 */
17625 /**
17626 * Creates a `tf.Tensor` with values sampled from a gamma distribution.
17627 *
17628 * ```js
17629 * tf.randomGamma([2, 2], 1).print();
17630 * ```
17631 *
17632 * @param shape An array of integers defining the output tensor shape.
17633 * @param alpha The shape parameter of the gamma distribution.
17634 * @param beta The inverse scale parameter of the gamma distribution. Defaults
17635 * to 1.
17636 * @param dtype The data type of the output. Defaults to float32.
17637 * @param seed The seed for the random number generator.
17638 *
17639 * @doc {heading: 'Tensors', subheading: 'Random'}
17640 */
/**
 * Creates a `tf.Tensor` with values sampled from a gamma distribution.
 *
 * ```js
 * tf.randomGamma([2, 2], 1).print();
 * ```
 *
 * @param shape An array of integers defining the output tensor shape.
 * @param alpha The shape parameter of the gamma distribution.
 * @param beta The inverse scale parameter of the gamma distribution. Defaults
 *     to 1.
 * @param dtype The data type of the output. Defaults to float32.
 * @param seed The seed for the random number generator.
 *
 * @doc {heading: 'Tensors', subheading: 'Random'}
 */
function randomGamma_(shape, alpha, beta = 1, dtype = 'float32', seed) {
    assertNonNegativeIntegerDimensions(shape);
    // Defaults are re-applied here because callers may pass explicit nulls.
    const rate = beta == null ? 1 : beta;
    const outDtype = dtype == null ? 'float32' : dtype;
    if (outDtype !== 'float32' && outDtype !== 'int32') {
        throw new Error(`Unsupported data type ${outDtype}`);
    }
    const sampler = new RandGamma(alpha, rate, outDtype, seed);
    const res = buffer(shape, outDtype);
    const out = res.values;
    for (let i = 0; i < out.length; i++) {
        out[i] = sampler.nextValue();
    }
    return res.toTensor();
}
const randomGamma = /* @__PURE__ */ op({ randomGamma_ });
17660
17661 /**
17662 * @license
17663 * Copyright 2020 Google LLC. All Rights Reserved.
17664 * Licensed under the Apache License, Version 2.0 (the "License");
17665 * you may not use this file except in compliance with the License.
17666 * You may obtain a copy of the License at
17667 *
17668 * http://www.apache.org/licenses/LICENSE-2.0
17669 *
17670 * Unless required by applicable law or agreed to in writing, software
17671 * distributed under the License is distributed on an "AS IS" BASIS,
17672 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17673 * See the License for the specific language governing permissions and
17674 * limitations under the License.
17675 * =============================================================================
17676 */
17677 /**
17678 * Creates a `tf.Tensor` with values sampled from a normal distribution.
17679 *
17680 * ```js
17681 * tf.randomNormal([2, 2]).print();
17682 * ```
17683 *
17684 * @param shape An array of integers defining the output tensor shape.
17685 * @param mean The mean of the normal distribution.
17686 * @param stdDev The standard deviation of the normal distribution.
17687 * @param dtype The data type of the output.
17688 * @param seed The seed for the random number generator.
17689 *
17690 * @doc {heading: 'Tensors', subheading: 'Random'}
17691 */
/**
 * Creates a `tf.Tensor` with values sampled from a normal distribution.
 *
 * ```js
 * tf.randomNormal([2, 2]).print();
 * ```
 *
 * @param shape An array of integers defining the output tensor shape.
 * @param mean The mean of the normal distribution.
 * @param stdDev The standard deviation of the normal distribution.
 * @param dtype The data type of the output.
 * @param seed The seed for the random number generator.
 *
 * @doc {heading: 'Tensors', subheading: 'Random'}
 */
function randomNormal_(shape, mean = 0, stdDev = 1, dtype, seed) {
    assertNonNegativeIntegerDimensions(shape);
    if (dtype != null && dtype === 'bool') {
        throw new Error(`Unsupported data type ${dtype}`);
    }
    const sampler = new MPRandGauss(mean, stdDev, dtype, false /* truncated */, seed);
    const res = buffer(shape, dtype);
    const out = res.values;
    for (let i = 0; i < out.length; i++) {
        out[i] = sampler.nextValue();
    }
    return res.toTensor();
}
const randomNormal$2 = /* @__PURE__ */ op({ randomNormal_ });
17705
17706 /**
17707 * @license
17708 * Copyright 2022 Google LLC. All Rights Reserved.
17709 * Licensed under the Apache License, Version 2.0 (the "License");
17710 * you may not use this file except in compliance with the License.
17711 * You may obtain a copy of the License at
17712 *
17713 * http://www.apache.org/licenses/LICENSE-2.0
17714 *
17715 * Unless required by applicable law or agreed to in writing, software
17716 * distributed under the License is distributed on an "AS IS" BASIS,
17717 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17718 * See the License for the specific language governing permissions and
17719 * limitations under the License.
17720 * =============================================================================
17721 */
17722 /**
17723 * Creates a `tf.Tensor` with values sampled from a normal distribution.
17724 *
17725 * The generated values will have mean 0 and standard deviation 1.
17726 *
17727 * ```js
17728 * tf.randomStandardNormal([2, 2]).print();
17729 * ```
17730 *
17731 * @param shape An array of integers defining the output tensor shape.
17732 * @param dtype The data type of the output.
17733 * @param seed The seed for the random number generator.
17734 *
17735 * @doc {heading: 'Tensors', subheading: 'Random'}
17736 */
/**
 * Creates a `tf.Tensor` with values sampled from the standard normal
 * distribution (mean 0, standard deviation 1) by delegating to
 * `randomNormal$2`.
 *
 * @param shape An array of integers defining the output tensor shape.
 * @param dtype The data type of the output.
 * @param seed The seed for the random number generator.
 */
function randomStandardNormal_(shape, dtype, seed) {
    if (dtype === 'bool') {
        throw new Error(`Unsupported data type ${dtype}`);
    }
    return randomNormal$2(shape, 0, 1, dtype, seed);
}
const randomStandardNormal = /* @__PURE__ */ op({ randomStandardNormal_ });
17744
17745 /**
17746 * @license
17747 * Copyright 2020 Google LLC. All Rights Reserved.
17748 * Licensed under the Apache License, Version 2.0 (the "License");
17749 * you may not use this file except in compliance with the License.
17750 * You may obtain a copy of the License at
17751 *
17752 * http://www.apache.org/licenses/LICENSE-2.0
17753 *
17754 * Unless required by applicable law or agreed to in writing, software
17755 * distributed under the License is distributed on an "AS IS" BASIS,
17756 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17757 * See the License for the specific language governing permissions and
17758 * limitations under the License.
17759 * =============================================================================
17760 */
17761 /**
17762 * Creates a `tf.Tensor` with values sampled from a uniform distribution.
17763 *
17764 * The generated values follow a uniform distribution in the range [minval,
17765 * maxval). The lower bound minval is included in the range, while the upper
17766 * bound maxval is excluded.
17767 *
17768 * ```js
17769 * tf.randomUniform([2, 2]).print();
17770 * ```
17771 *
17772 * @param shape An array of integers defining the output tensor shape.
17773 * @param minval The lower bound on the range of random values to generate.
17774 * Defaults to 0.
17775 * @param maxval The upper bound on the range of random values to generate.
17776 * Defaults to 1.
17777 * @param dtype The data type of the output tensor. Defaults to 'float32'.
17778 * @param seed An optional int. Defaults to 0. If seed is set to be non-zero,
17779 * the random number generator is seeded by the given seed. Otherwise, it is
17780 * seeded by a random seed.
17781 *
17782 * @doc {heading: 'Tensors', subheading: 'Random'}
17783 */
/**
 * Creates a `tf.Tensor` with values sampled uniformly from [minval, maxval).
 *
 * ```js
 * tf.randomUniform([2, 2]).print();
 * ```
 *
 * @param shape An array of integers defining the output tensor shape.
 * @param minval Inclusive lower bound. Defaults to 0.
 * @param maxval Exclusive upper bound. Defaults to 1.
 * @param dtype The data type of the output tensor. Defaults to 'float32'.
 * @param seed Optional seed for the random number generator.
 */
function randomUniform_(shape, minval = 0, maxval = 1, dtype = 'float32', seed) {
    assertNonNegativeIntegerDimensions(shape);
    const res = buffer(shape, dtype);
    // The sampler's own dtype is null so raw floats fill the buffer;
    // the buffer's dtype performs any final conversion.
    const sampler = new UniformRandom(minval, maxval, null, seed);
    const out = res.values;
    for (let i = 0; i < out.length; i++) {
        out[i] = sampler.nextValue();
    }
    return res.toTensor();
}
const randomUniform$1 = /* @__PURE__ */ op({ randomUniform_ });
17794
17795 /**
17796 * @license
17797 * Copyright 2023 Google LLC.
17798 * Licensed under the Apache License, Version 2.0 (the "License");
17799 * you may not use this file except in compliance with the License.
17800 * You may obtain a copy of the License at
17801 *
17802 * http://www.apache.org/licenses/LICENSE-2.0
17803 *
17804 * Unless required by applicable law or agreed to in writing, software
17805 * distributed under the License is distributed on an "AS IS" BASIS,
17806 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17807 * See the License for the specific language governing permissions and
17808 * limitations under the License.
17809 * =============================================================================
17810 */
17811 /**
17812 * Creates a `tf.Tensor` with integers sampled from a uniform distribution.
17813 *
17814 * The generated values are uniform integers in the range [minval, maxval). The
17815 * lower bound minval is included in the range, while the upper bound maxval is
17816 * excluded.
17817 *
17818 * ```js
17819 * tf.randomUniformInt([2, 2], 0, 10).print();
17820 * ```
17821 *
17822 * @param shape An array of integers defining the output tensor shape.
17823 * @param minval Inclusive lower bound on the generated integers.
17824 * @param maxval Exclusive upper bound on the generated integers.
17825 * @param seed An optional int. Defaults to 0. If seed is set to be non-zero,
17826 * the random number generator is seeded by the given seed. Otherwise, it is
17827 * seeded by a random seed.
17828 *
17829 * @doc {heading: 'Tensors', subheading: 'Random'}
17830 */
/**
 * Creates a `tf.Tensor` with integers sampled uniformly from
 * [minval, maxval); thin wrapper over `randomUniform$1` with an 'int32'
 * output dtype.
 *
 * @param shape An array of integers defining the output tensor shape.
 * @param minval Inclusive lower bound on the generated integers.
 * @param maxval Exclusive upper bound on the generated integers.
 * @param seed Optional seed for the random number generator.
 */
function randomUniformInt_(shape, minval, maxval, seed) {
    // TODO(mattsoulanille): Handle optional seed2 input.
    const dtype = 'int32';
    return randomUniform$1(shape, minval, maxval, dtype, seed);
}
const randomUniformInt = /* @__PURE__ */ op({ randomUniformInt_ });
17836
17837 /**
17838 * @license
17839 * Copyright 2018 Google LLC. All Rights Reserved.
17840 * Licensed under the Apache License, Version 2.0 (the "License");
17841 * you may not use this file except in compliance with the License.
17842 * You may obtain a copy of the License at
17843 *
17844 * http://www.apache.org/licenses/LICENSE-2.0
17845 *
17846 * Unless required by applicable law or agreed to in writing, software
17847 * distributed under the License is distributed on an "AS IS" BASIS,
17848 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17849 * See the License for the specific language governing permissions and
17850 * limitations under the License.
17851 * =============================================================================
17852 */
17853 /**
17854 * Creates a new `tf.Tensor1D` filled with the numbers in the range provided.
17855 *
17856 * The tensor is a half-open interval meaning it includes start, but
17857 * excludes stop. Decrementing ranges and negative step values are also
17858 * supported.
17859 *
17860 *
17861 * ```js
17862 * tf.range(0, 9, 2).print();
17863 * ```
17864 *
17865 * @param start An integer start value
17866 * @param stop An integer stop value
17867 * @param step An integer increment (will default to 1 or -1)
17868 * @param dtype The data type of the output tensor. Defaults to 'float32'.
17869 *
17870 * @doc {heading: 'Tensors', subheading: 'Creation'}
17871 */
/**
 * Creates a new `tf.Tensor1D` with numbers in the half-open interval
 * [start, stop), advancing by `step`. A step of zero is rejected.
 *
 * @param start An integer start value.
 * @param stop An integer stop value.
 * @param step An integer increment (will default to 1 or -1).
 * @param dtype The data type of the output tensor. Defaults to 'float32'.
 */
function range$3(start, stop, step = 1, dtype = 'float32') {
    if (step === 0) {
        throw new Error('Cannot have a step of zero');
    }
    return ENGINE.runKernel(Range, {} /* inputs */, { start, stop, step, dtype });
}
17879
17880 /**
17881 * @license
17882 * Copyright 2020 Google LLC. All Rights Reserved.
17883 * Licensed under the Apache License, Version 2.0 (the "License");
17884 * you may not use this file except in compliance with the License.
17885 * You may obtain a copy of the License at
17886 *
17887 * http://www.apache.org/licenses/LICENSE-2.0
17888 *
17889 * Unless required by applicable law or agreed to in writing, software
17890 * distributed under the License is distributed on an "AS IS" BASIS,
17891 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17892 * See the License for the specific language governing permissions and
17893 * limitations under the License.
17894 * =============================================================================
17895 */
17896 /**
17897 * Returns the real part of a complex (or real) tensor.
17898 *
17899 * Given a tensor input, this operation returns a tensor of type float that is
17900 * the real part of each element in input considered as a complex number.
17901 *
17902 * If the input is real, it simply makes a clone.
17903 *
17904 * ```js
17905 * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]);
17906 * tf.real(x).print();
17907 * ```
17908 *
17909 * @doc {heading: 'Tensors', subheading: 'Creation'}
17910 */
/**
 * Returns the real part of a complex (or real) tensor; a real input
 * simply yields a clone.
 */
function real_(input) {
    const $input = convertToTensor(input, 'input', 'real');
    return ENGINE.runKernel(Real, { input: $input });
}
const real$2 = /* @__PURE__ */ op({ real_ });
17917
17918 /**
17919 * @license
17920 * Copyright 2018 Google LLC. All Rights Reserved.
17921 * Licensed under the Apache License, Version 2.0 (the "License");
17922 * you may not use this file except in compliance with the License.
17923 * You may obtain a copy of the License at
17924 *
17925 * http://www.apache.org/licenses/LICENSE-2.0
17926 *
17927 * Unless required by applicable law or agreed to in writing, software
17928 * distributed under the License is distributed on an "AS IS" BASIS,
17929 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17930 * See the License for the specific language governing permissions and
17931 * limitations under the License.
17932 * =============================================================================
17933 */
17934 /**
17935 * Computes reciprocal of x element-wise: `1 / x`
17936 *
17937 * ```js
17938 * const x = tf.tensor1d([0, 1, 2]);
17939 *
17940 * x.reciprocal().print(); // or tf.reciprocal(x)
17941 * ```
17942 * @param x The input tensor.
17943 *
17944 * @doc {heading: 'Operations', subheading: 'Basic math'}
17945 */
/** Computes the reciprocal of `x` element-wise: `1 / x`. */
function reciprocal_(x) {
    const $x = convertToTensor(x, 'x', 'reciprocal');
    return ENGINE.runKernel(Reciprocal, { x: $x });
}
const reciprocal$2 = /* @__PURE__ */ op({ reciprocal_ });
17952
17953 /**
17954 * @license
17955 * Copyright 2020 Google LLC. All Rights Reserved.
17956 * Licensed under the Apache License, Version 2.0 (the "License");
17957 * you may not use this file except in compliance with the License.
17958 * You may obtain a copy of the License at
17959 *
17960 * http://www.apache.org/licenses/LICENSE-2.0
17961 *
17962 * Unless required by applicable law or agreed to in writing, software
17963 * distributed under the License is distributed on an "AS IS" BASIS,
17964 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17965 * See the License for the specific language governing permissions and
17966 * limitations under the License.
17967 * =============================================================================
17968 */
17969 /**
17970 * Computes rectified linear element-wise: `max(x, 0)`.
17971 *
17972 * ```js
17973 * const x = tf.tensor1d([-1, 2, -3, 4]);
17974 *
17975 * x.relu().print(); // or tf.relu(x)
17976 * ```
17977 * @param x The input tensor. If the dtype is `bool`, the output dtype will be
17978 * `int32`.
17979 *
17980 * @doc {heading: 'Operations', subheading: 'Basic math'}
17981 */
/** Computes rectified linear element-wise: `max(x, 0)`. */
function relu_(x) {
    const $x = convertToTensor(x, 'x', 'relu');
    return ENGINE.runKernel(Relu$1, { x: $x });
}
const relu$2 = /* @__PURE__ */ op({ relu_ });
17988
17989 /**
17990 * @license
17991 * Copyright 2020 Google LLC. All Rights Reserved.
17992 * Licensed under the Apache License, Version 2.0 (the "License");
17993 * you may not use this file except in compliance with the License.
17994 * You may obtain a copy of the License at
17995 *
17996 * http://www.apache.org/licenses/LICENSE-2.0
17997 *
17998 * Unless required by applicable law or agreed to in writing, software
17999 * distributed under the License is distributed on an "AS IS" BASIS,
18000 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18001 * See the License for the specific language governing permissions and
18002 * limitations under the License.
18003 * =============================================================================
18004 */
18005 /**
18006 * Computes rectified linear 6 element-wise: `min(max(x, 0), 6)`.
18007 *
18008 * ```js
18009 * const x = tf.tensor1d([-1, 2, -3, 8]);
18010 *
18011 * x.relu6().print(); // or tf.relu6(x)
18012 * ```
18013 * @param x The input tensor. If the dtype is `bool`, the output dtype will be
18014 * `int32`.
18015 *
18016 * @doc {heading: 'Operations', subheading: 'Basic math'}
18017 */
/** Computes rectified linear 6 element-wise: `min(max(x, 0), 6)`. */
function relu6_(x) {
    const $x = convertToTensor(x, 'x', 'relu6');
    return ENGINE.runKernel(Relu6$1, { x: $x });
}
const relu6$2 = /* @__PURE__ */ op({ relu6_ });
18024
18025 /**
18026 * @license
18027 * Copyright 2018 Google LLC. All Rights Reserved.
18028 * Licensed under the Apache License, Version 2.0 (the "License");
18029 * you may not use this file except in compliance with the License.
18030 * You may obtain a copy of the License at
18031 *
18032 * http://www.apache.org/licenses/LICENSE-2.0
18033 *
18034 * Unless required by applicable law or agreed to in writing, software
18035 * distributed under the License is distributed on an "AS IS" BASIS,
18036 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18037 * See the License for the specific language governing permissions and
18038 * limitations under the License.
18039 * =============================================================================
18040 */
18041 /**
18042 * Reverses a `tf.Tensor` along a specified axis.
18043 *
18044 * Also available are stricter rank-specific methods that assert that `x` is
18045 * of the given rank:
18046 * - `tf.reverse1d`
18047 * - `tf.reverse2d`
18048 * - `tf.reverse3d`
18049 * - `tf.reverse4d`
18050 *
18051 * Except `tf.reverse1d` (which does not have axis param), all methods have
18052 * same signature as this method.
18053 *
18054 * ```js
18055 * const x = tf.tensor1d([1, 2, 3, 4]);
18056 *
18057 * x.reverse().print();
18058 * ```
18059 *
18060 * ```js
18061 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
18062 *
18063 * const axis = 1;
18064 * x.reverse(axis).print();
18065 * ```
18066 * @param x The input tensor to be reversed.
18067 * @param axis The set of dimensions to reverse. Must be in the
18068 * range [-rank(x), rank(x)). Defaults to all axes.
18069 *
18070 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
18071 */
/**
 * Reverses a `tf.Tensor` along the given axis (or all axes when `axis`
 * is omitted).
 *
 * @param x The input tensor to be reversed.
 * @param axis The set of dimensions to reverse. Must be in the
 *     range [-rank(x), rank(x)). Defaults to all axes.
 */
function reverse_(x, axis) {
    const $x = convertToTensor(x, 'x', 'reverse');
    return ENGINE.runKernel(Reverse, { x: $x }, { dims: axis });
}
const reverse$2 = /* @__PURE__ */ op({ reverse_ });
18079
18080 /**
18081 * @license
18082 * Copyright 2020 Google LLC. All Rights Reserved.
18083 * Licensed under the Apache License, Version 2.0 (the "License");
18084 * you may not use this file except in compliance with the License.
18085 * You may obtain a copy of the License at
18086 *
18087 * http://www.apache.org/licenses/LICENSE-2.0
18088 *
18089 * Unless required by applicable law or agreed to in writing, software
18090 * distributed under the License is distributed on an "AS IS" BASIS,
18091 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18092 * See the License for the specific language governing permissions and
18093 * limitations under the License.
18094 * =============================================================================
18095 */
18096 /**
18097 * Reverses a `tf.Tensor1D`.
18098 *
18099 * @param x The input tensor.
18100 */
/** Reverses a `tf.Tensor1D`; asserts the input is rank 1. */
function reverse1d_(x) {
    const $input = convertToTensor(x, 'x', 'reverse');
    assert$1($input.rank === 1, () => `Error in reverse1D: x must be rank 1 but got rank ${$input.rank}.`);
    return reverse$2($input, 0);
}
const reverse1d = /* @__PURE__ */ op({ reverse1d_ });
18107
18108 /**
18109 * @license
18110 * Copyright 2020 Google LLC. All Rights Reserved.
18111 * Licensed under the Apache License, Version 2.0 (the "License");
18112 * you may not use this file except in compliance with the License.
18113 * You may obtain a copy of the License at
18114 *
18115 * http://www.apache.org/licenses/LICENSE-2.0
18116 *
18117 * Unless required by applicable law or agreed to in writing, software
18118 * distributed under the License is distributed on an "AS IS" BASIS,
18119 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18120 * See the License for the specific language governing permissions and
18121 * limitations under the License.
18122 * =============================================================================
18123 */
18124 /**
18125 * Reverses a `tf.Tensor2D` along a specified axis.
18126 *
18127 * @param x The input tensor.
18128 * @param axis The set of dimensions to reverse. Must be in the
18129 * range [-rank(x), rank(x)). Defaults to all axes.
18130 */
/** Reverses a `tf.Tensor2D` along `axis`; asserts the input is rank 2. */
function reverse2d_(x, axis) {
    const $input = convertToTensor(x, 'x', 'reverse');
    assert$1($input.rank === 2, () => `Error in reverse2D: x must be rank 2 but got rank ${$input.rank}.`);
    return reverse$2($input, axis);
}
const reverse2d = /* @__PURE__ */ op({ reverse2d_ });
18137
18138 /**
18139 * @license
18140 * Copyright 2020 Google LLC. All Rights Reserved.
18141 * Licensed under the Apache License, Version 2.0 (the "License");
18142 * you may not use this file except in compliance with the License.
18143 * You may obtain a copy of the License at
18144 *
18145 * http://www.apache.org/licenses/LICENSE-2.0
18146 *
18147 * Unless required by applicable law or agreed to in writing, software
18148 * distributed under the License is distributed on an "AS IS" BASIS,
18149 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18150 * See the License for the specific language governing permissions and
18151 * limitations under the License.
18152 * =============================================================================
18153 */
18154 /**
18155 * Reverses a `tf.Tensor3D` along a specified axis.
18156 *
18157 * @param x The input tensor.
18158 * @param axis The set of dimensions to reverse. Must be in the
18159 * range [-rank(x), rank(x)). Defaults to all axes.
18160 */
/** Reverses a `tf.Tensor3D` along `axis`; asserts the input is rank 3. */
function reverse3d_(x, axis) {
    const $input = convertToTensor(x, 'x', 'reverse');
    assert$1($input.rank === 3, () => `Error in reverse3D: x must be rank 3 but got rank ${$input.rank}.`);
    return reverse$2($input, axis);
}
const reverse3d = /* @__PURE__ */ op({ reverse3d_ });
18167
18168 /**
18169 * @license
18170 * Copyright 2020 Google LLC. All Rights Reserved.
18171 * Licensed under the Apache License, Version 2.0 (the "License");
18172 * you may not use this file except in compliance with the License.
18173 * You may obtain a copy of the License at
18174 *
18175 * http://www.apache.org/licenses/LICENSE-2.0
18176 *
18177 * Unless required by applicable law or agreed to in writing, software
18178 * distributed under the License is distributed on an "AS IS" BASIS,
18179 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18180 * See the License for the specific language governing permissions and
18181 * limitations under the License.
18182 * =============================================================================
18183 */
18184 /**
18185 * Reverses a `tf.Tensor4D` along a specified axis.
18186 *
18187 * @param x The input tensor.
18188 * @param axis The set of dimensions to reverse. Must be in the
18189 * range [-rank(x), rank(x)). Defaults to all axes.
18190 */
// Reverses a rank-4 tensor along the given axes; delegates to the generic
// `reverse` op after validating the rank.
function reverse4d_(x, axis) {
    const $tensor = convertToTensor(x, 'x', 'reverse');
    const rank = $tensor.rank;
    assert$1(rank === 4, () => `Error in reverse4D: x must be rank 4 but got rank ${rank}.`);
    return reverse$2($tensor, axis);
}
const reverse4d = /* @__PURE__ */ op({ reverse4d_ });
18197
18198 /**
18199 * @license
18200 * Copyright 2018 Google LLC. All Rights Reserved.
18201 * Licensed under the Apache License, Version 2.0 (the "License");
18202 * you may not use this file except in compliance with the License.
18203 * You may obtain a copy of the License at
18204 *
18205 * http://www.apache.org/licenses/LICENSE-2.0
18206 *
18207 * Unless required by applicable law or agreed to in writing, software
18208 * distributed under the License is distributed on an "AS IS" BASIS,
18209 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18210 * See the License for the specific language governing permissions and
18211 * limitations under the License.
18212 * =============================================================================
18213 */
18214 /**
18215 * Computes round of input `tf.Tensor` element-wise: `round(x)`.
18216 * It implements banker's rounding.
18217 *
18218 * ```js
18219 * const x = tf.tensor1d([.6, 1.1, -3.3]);
18220 *
18221 * x.round().print(); // or tf.round(x)
18222 * ```
18223 * @param x The input tensor.
18224 *
18225 * @doc {heading: 'Operations', subheading: 'Basic math'}
18226 */
// Element-wise banker's rounding: coerces the input to a tensor and
// dispatches the `Round` kernel on the engine.
function round_(x) {
    const $x = convertToTensor(x, 'x', 'round');
    return ENGINE.runKernel(Round, { x: $x });
}
const round$2 = /* @__PURE__ */ op({ round_ });
18233
18234 /**
18235 * @license
18236 * Copyright 2018 Google LLC. All Rights Reserved.
18237 * Licensed under the Apache License, Version 2.0 (the "License");
18238 * you may not use this file except in compliance with the License.
18239 * You may obtain a copy of the License at
18240 *
18241 * http://www.apache.org/licenses/LICENSE-2.0
18242 *
18243 * Unless required by applicable law or agreed to in writing, software
18244 * distributed under the License is distributed on an "AS IS" BASIS,
18245 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18246 * See the License for the specific language governing permissions and
18247 * limitations under the License.
18248 * =============================================================================
18249 */
18250 /**
18251 * Computes reciprocal of square root of the input `tf.Tensor` element-wise:
18252 * `y = 1 / sqrt(x)`
18253 *
18254 * ```js
18255 * const x = tf.tensor1d([1, 2, 4, -1]);
18256 *
18257 * x.rsqrt().print(); // or tf.rsqrt(x)
18258 * ```
18259 * @param x The input tensor.
18260 *
18261 * @doc {heading: 'Operations', subheading: 'Basic math'}
18262 */
// Element-wise reciprocal square root (`1 / sqrt(x)`); input is upcast to
// float32 before the `Rsqrt` kernel is dispatched.
function rsqrt_(x) {
    const $x = convertToTensor(x, 'x', 'rsqrt', 'float32');
    return ENGINE.runKernel(Rsqrt, { x: $x });
}
const rsqrt$2 = /* @__PURE__ */ op({ rsqrt_ });
18269
18270 /**
18271 * @license
18272 * Copyright 2020 Google LLC. All Rights Reserved.
18273 * Licensed under the Apache License, Version 2.0 (the "License");
18274 * you may not use this file except in compliance with the License.
18275 * You may obtain a copy of the License at
18276 *
18277 * http://www.apache.org/licenses/LICENSE-2.0
18278 *
18279 * Unless required by applicable law or agreed to in writing, software
18280 * distributed under the License is distributed on an "AS IS" BASIS,
18281 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18282 * See the License for the specific language governing permissions and
18283 * limitations under the License.
18284 * =============================================================================
18285 */
18286 /**
18287 * Computes scaled exponential linear element-wise.
18288 *
18289 * `x < 0 ? scale * alpha * (exp(x) - 1) : scale * x`
18290 *
18291 * ```js
18292 * const x = tf.tensor1d([-1, 2, -3, 4]);
18293 *
18294 * x.selu().print(); // or tf.selu(x)
18295 * ```
18296 * @param x The input tensor.
18297 *
18298 * @doc {heading: 'Operations', subheading: 'Basic math'}
18299 */
// Scaled exponential linear unit, element-wise; dispatches the `Selu`
// kernel on the engine.
function selu_(x) {
    const $x = convertToTensor(x, 'x', 'selu');
    return ENGINE.runKernel(Selu$1, { x: $x });
}
const selu$2 = /* @__PURE__ */ op({ selu_ });
18306
18307 /**
18308 * 2-D convolution with separable filters.
18309 *
18310 * Performs a depthwise convolution that acts separately on channels followed
18311 * by a pointwise convolution that mixes channels. Note that this is
18312 * separability between dimensions [1, 2] and 3, not spatial separability
18313 * between dimensions 1 and 2.
18314 *
18315 * See
18316 * [https://www.tensorflow.org/api_docs/python/tf/nn/separable_conv2d](
18317 * https://www.tensorflow.org/api_docs/python/tf/nn/separable_conv2d)
18318 * for more details.
18319 *
18320 * @param x The input tensor, of rank 4 or rank 3, of shape
18321 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
18322 * assumed.
18323 * @param depthwiseFilter The depthwise filter tensor, rank 4, of shape
18324 * `[filterHeight, filterWidth, inChannels, channelMultiplier]`. This is
18325 * the filter used in the first step.
18326 * @param pointwiseFilter The pointwise filter tensor, rank 4, of shape
18327 * `[1, 1, inChannels * channelMultiplier, outChannels]`. This is
18328 * the filter used in the second step.
18329 * @param strides The strides of the convolution: `[strideHeight,
18330 * strideWidth]`. If strides is a single number, then `strideHeight ==
18331 * strideWidth`.
18332 * @param pad The type of padding algorithm.
18333 * - `same` and stride 1: output will be of same size as input,
18334 * regardless of filter size.
18335 * - `valid`: output will be smaller than input if filter is larger
18336 * than 1x1.
18337 * - For more info, see this guide:
18338 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
18339 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
18340 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
18341 * in which we sample input values across the height and width dimensions
18342 * in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single
18343 * number, then `dilationHeight == dilationWidth`. If it is greater than
18344 * 1, then all values of `strides` must be 1.
18345 * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
18346 * "NHWC". Specify the data format of the input and output data. With the
18347 * default format "NHWC", the data is stored in the order of: [batch,
18348 * height, width, channels]. Only "NHWC" is currently supported.
18349 *
18350 * @doc {heading: 'Operations', subheading: 'Convolution'}
18351 */
function separableConv2d_(x, depthwiseFilter, pointwiseFilter, strides, pad, dilation = [1, 1], dataFormat = 'NHWC') {
    const $x = convertToTensor(x, 'x', 'separableConv2d');
    const $depthwiseFilter = convertToTensor(depthwiseFilter, 'depthwiseFilter', 'separableConv2d');
    const $pointwiseFilter = convertToTensor(pointwiseFilter, 'pointwiseFilter', 'separableConv2d');
    // A rank-3 input is treated as a batch of one; the reshape is undone
    // before returning so the caller gets the rank it passed in.
    let x4D = $x;
    let reshapedTo4D = false;
    if ($x.rank === 3) {
        reshapedTo4D = true;
        x4D = reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
    }
    if (dataFormat === 'NCHW') {
        throw new Error('separableConv2d currently does not support dataFormat NCHW; only ' +
            'NHWC is supported');
    }
    assert$1(x4D.rank === 4, () => `Error in separableConv2d: input must be rank 4, but got ` +
        `rank ${x4D.rank}.`);
    assert$1($depthwiseFilter.rank === 4, () => `Error in separableConv2d: depthwise filter must be rank 4, but ` +
        `got rank ${$depthwiseFilter.rank}.`);
    // Fixed: this message previously interpolated `$depthwiseFilter.rank`,
    // reporting the wrong tensor's rank when the pointwise filter is invalid.
    assert$1($pointwiseFilter.rank === 4, () => `Error in separableConv2d: pointwise filter must be rank 4, but ` +
        `got rank ${$pointwiseFilter.rank}.`);
    // The pointwise step is a 1x1 convolution, so its spatial dims must be 1.
    assert$1($pointwiseFilter.shape[0] === 1, () => `Error in separableConv2d: the first dimension of pointwise filter ` +
        ` must be 1, but got ${$pointwiseFilter.shape[0]}.`);
    assert$1($pointwiseFilter.shape[1] === 1, () => `Error in separableConv2d: the second dimension of pointwise ` +
        `filter must be 1, but got ${$pointwiseFilter.shape[1]}.`);
    // The depthwise step outputs inChannels * channelMultiplier channels,
    // which the pointwise filter must accept as its input-channel dim.
    const inChannels = $depthwiseFilter.shape[2];
    const channelMultiplier = $depthwiseFilter.shape[3];
    assert$1($pointwiseFilter.shape[2] === inChannels * channelMultiplier, () => `Error in separableConv2d: the third dimension of pointwise filter ` +
        `must be ${inChannels * channelMultiplier}, ` +
        `but got ${$pointwiseFilter.shape[2]}.`);
    // Step 1: depthwise convolution (per-channel spatial filtering).
    const depthwise = depthwiseConv2d$3(x4D, $depthwiseFilter, strides, pad, dataFormat, dilation);
    // Step 2: pointwise 1x1 convolution (mixes channels), always stride 1
    // with 'valid' padding since the filter is 1x1.
    const pointwiseStride = 1;
    const res = conv2d$4(depthwise, $pointwiseFilter, pointwiseStride, 'valid', dataFormat);
    if (reshapedTo4D) {
        return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
    }
    return res;
}
const separableConv2d$1 = /* @__PURE__ */ op({ separableConv2d_ });
18390
18391 /**
18392 * @license
18393 * Copyright 2020 Google Inc. All Rights Reserved.
18394 * Licensed under the Apache License, Version 2.0 (the "License");
18395 * you may not use this file except in compliance with the License.
18396 * You may obtain a copy of the License at
18397 *
18398 * http://www.apache.org/licenses/LICENSE-2.0
18399 *
18400 * Unless required by applicable law or agreed to in writing, software
18401 * distributed under the License is distributed on an "AS IS" BASIS,
18402 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18403 * See the License for the specific language governing permissions and
18404 * limitations under the License.
18405 * =============================================================================
18406 */
18407 /**
18408 * Computes the difference between two lists of numbers.
18409 *
18410 * Given a Tensor `x` and a Tensor `y`, this operation returns a Tensor `out`
18411 * that represents all values that are in `x` but not in `y`. The returned
18412 * Tensor `out` is sorted in the same order that the numbers appear in `x`
18413 * (duplicates are preserved). This operation also returns a Tensor indices that
18414 * represents the position of each out element in `x`. In other words:
18415 *
18416 * `out[i] = x[idx[i]] for i in [0, 1, ..., out.length - 1]`
18417 *
18418 * ```js
18419 * const x = [1, 2, 3, 4, 5, 6];
18420 * const y = [1, 3, 5];
18421 *
18422 * const [out, indices] = await tf.setdiff1dAsync(x, y);
18423 * out.print(); // [2, 4, 6]
18424 * indices.print(); // [1, 3, 5]
18425 * ```
18426 *
18427 * @param x 1-D Tensor. Values to keep.
18428 * @param y 1-D Tensor. Must have the same type as x. Values to exclude in the
18429 * output.
18430 * @returns Promise of Tensor tuple [out, indices].
18431 * out: Tensor with the same type as x.
18432 * indices: A Tensor of type int32.
18433 *
18434 * @doc {heading: 'Tensors', subheading: 'Transformations'}
18435 */
// Computes the list difference x \ y: values of `x` absent from `y`, in
// `x`'s order (duplicates preserved), plus their positions in `x`.
async function setdiff1dAsync_(x, y) {
    const $x = convertToTensor(x, 'x', 'setdiff1d');
    const $y = convertToTensor(y, 'y', 'setdiff1d');
    assert$1($x.dtype === $y.dtype, () => `x and y should have the same dtype, but got x (${$x.dtype}) and y (${$y.dtype}).`);
    assert$1($x.rank === 1, () => `x should be 1D tensor, but got x (${$x.shape}).`);
    assert$1($y.rank === 1, () => `y should be 1D tensor, but got y (${$y.shape}).`);
    const xVals = await $x.data();
    const yVals = await $y.data();
    const exclude = new Set(yVals);
    // Single pass: record the positions in x whose values survive the diff.
    const keptIndices = [];
    for (let i = 0; i < xVals.length; i++) {
        if (!exclude.has(xVals[i])) {
            keptIndices.push(i);
        }
    }
    const outBuf = new TensorBuffer([keptIndices.length], $x.dtype);
    const idxBuf = new TensorBuffer([keptIndices.length], 'int32');
    keptIndices.forEach((srcIdx, outIdx) => {
        outBuf.values[outIdx] = xVals[srcIdx];
        idxBuf.values[outIdx] = srcIdx;
    });
    return [outBuf.toTensor(), idxBuf.toTensor()];
}
const setdiff1dAsync = setdiff1dAsync_;
18463
18464 /**
18465 * @license
18466 * Copyright 2018 Google LLC. All Rights Reserved.
18467 * Licensed under the Apache License, Version 2.0 (the "License");
18468 * you may not use this file except in compliance with the License.
18469 * You may obtain a copy of the License at
18470 *
18471 * http://www.apache.org/licenses/LICENSE-2.0
18472 *
18473 * Unless required by applicable law or agreed to in writing, software
18474 * distributed under the License is distributed on an "AS IS" BASIS,
18475 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18476 * See the License for the specific language governing permissions and
18477 * limitations under the License.
18478 * =============================================================================
18479 */
18480 /**
18481 * Returns an element-wise indication of the sign of a number.
18482 *
18483 * ```js
18484 * const x = tf.tensor1d([.6, 1.1, -3.3, NaN, 0]);
18485 *
18486 * x.sign().print(); // or tf.sign(x)
18487 * ```
18488 * @param x The input Tensor.
18489 *
18490 * @doc {heading: 'Operations', subheading: 'Basic math'}
18491 */
// Element-wise sign indicator; dispatches the `Sign` kernel on the engine.
function sign_(x) {
    const $x = convertToTensor(x, 'x', 'sign');
    return ENGINE.runKernel(Sign, { x: $x });
}
const sign$3 = /* @__PURE__ */ op({ sign_ });
18498
18499 /**
18500 * @license
18501 * Copyright 2018 Google LLC. All Rights Reserved.
18502 * Licensed under the Apache License, Version 2.0 (the "License");
18503 * you may not use this file except in compliance with the License.
18504 * You may obtain a copy of the License at
18505 *
18506 * http://www.apache.org/licenses/LICENSE-2.0
18507 *
18508 * Unless required by applicable law or agreed to in writing, software
18509 * distributed under the License is distributed on an "AS IS" BASIS,
18510 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18511 * See the License for the specific language governing permissions and
18512 * limitations under the License.
18513 * =============================================================================
18514 */
18515 /**
18516 * Computes sin of the input Tensor element-wise: `sin(x)`
18517 *
18518 * ```js
18519 * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);
18520 *
18521 * x.sin().print(); // or tf.sin(x)
18522 * ```
18523 * @param x The input tensor.
18524 *
18525 * @doc {heading: 'Operations', subheading: 'Basic math'}
18526 */
// Element-wise sine; input is upcast to float32 before the `Sin` kernel
// is dispatched.
function sin_(x) {
    const $x = convertToTensor(x, 'x', 'sin', 'float32');
    return ENGINE.runKernel(Sin, { x: $x });
}
const sin$2 = /* @__PURE__ */ op({ sin_ });
18533
18534 /**
18535 * @license
18536 * Copyright 2018 Google LLC. All Rights Reserved.
18537 * Licensed under the Apache License, Version 2.0 (the "License");
18538 * you may not use this file except in compliance with the License.
18539 * You may obtain a copy of the License at
18540 *
18541 * http://www.apache.org/licenses/LICENSE-2.0
18542 *
18543 * Unless required by applicable law or agreed to in writing, software
18544 * distributed under the License is distributed on an "AS IS" BASIS,
18545 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18546 * See the License for the specific language governing permissions and
18547 * limitations under the License.
18548 * =============================================================================
18549 */
18550 /**
18551 * Computes hyperbolic sin of the input `tf.Tensor` element-wise: `sinh(x)`
18552 *
18553 * ```js
18554 * const x = tf.tensor1d([0, 1, -1, .7]);
18555 *
18556 * x.sinh().print(); // or tf.sinh(x)
18557 * ```
18558 * @param x The input tensor.
18559 *
18560 * @doc {heading: 'Operations', subheading: 'Basic math'}
18561 */
// Element-wise hyperbolic sine; dispatches the `Sinh` kernel on the engine.
function sinh_(x) {
    const $x = convertToTensor(x, 'x', 'sinh');
    return ENGINE.runKernel(Sinh, { x: $x });
}
const sinh$2 = /* @__PURE__ */ op({ sinh_ });
18568
18569 /**
18570 * @license
18571 * Copyright 2018 Google LLC. All Rights Reserved.
18572 * Licensed under the Apache License, Version 2.0 (the "License");
18573 * you may not use this file except in compliance with the License.
18574 * You may obtain a copy of the License at
18575 *
18576 * http://www.apache.org/licenses/LICENSE-2.0
18577 *
18578 * Unless required by applicable law or agreed to in writing, software
18579 * distributed under the License is distributed on an "AS IS" BASIS,
18580 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18581 * See the License for the specific language governing permissions and
18582 * limitations under the License.
18583 * =============================================================================
18584 */
18585 /**
18586 * Extracts a 1D slice from 1D array starting at coordinates `begin` and is
18587 * of length `size`. See `slice` for details.
18588 */
// Rank-1 convenience wrapper over `slice`: scalar begin/size are wrapped
// into the one-element arrays the generic op expects.
function slice1d_(x, begin, size) {
    const $tensor = convertToTensor(x, 'x', 'slice1d');
    const rank = $tensor.rank;
    assert$1(rank === 1, () => `slice1d expects a rank-1 tensor, but got a rank-${rank} tensor`);
    return slice$2($tensor, [begin], [size]);
}
const slice1d = /* @__PURE__ */ op({ slice1d_ });
18595
18596 /**
18597 * @license
18598 * Copyright 2018 Google LLC. All Rights Reserved.
18599 * Licensed under the Apache License, Version 2.0 (the "License");
18600 * you may not use this file except in compliance with the License.
18601 * You may obtain a copy of the License at
18602 *
18603 * http://www.apache.org/licenses/LICENSE-2.0
18604 *
18605 * Unless required by applicable law or agreed to in writing, software
18606 * distributed under the License is distributed on an "AS IS" BASIS,
18607 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18608 * See the License for the specific language governing permissions and
18609 * limitations under the License.
18610 * =============================================================================
18611 */
18612 /**
18613 * Extracts a 2D slice from a 2D array starting at coordinates `begin` and
18614 * is of size `size`. See `slice` for details.
18615 */
// Rank-2 convenience wrapper over `slice` with a rank check.
function slice2d_(x, begin, size) {
    const $tensor = convertToTensor(x, 'x', 'slice2d');
    const rank = $tensor.rank;
    assert$1(rank === 2, () => `slice2d expects a rank-2 tensor, but got a rank-${rank} tensor`);
    return slice$2($tensor, begin, size);
}
const slice2d = /* @__PURE__ */ op({ slice2d_ });
18622
18623 /**
18624 * @license
18625 * Copyright 2018 Google LLC. All Rights Reserved.
18626 * Licensed under the Apache License, Version 2.0 (the "License");
18627 * you may not use this file except in compliance with the License.
18628 * You may obtain a copy of the License at
18629 *
18630 * http://www.apache.org/licenses/LICENSE-2.0
18631 *
18632 * Unless required by applicable law or agreed to in writing, software
18633 * distributed under the License is distributed on an "AS IS" BASIS,
18634 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18635 * See the License for the specific language governing permissions and
18636 * limitations under the License.
18637 * =============================================================================
18638 */
18639 /**
18640 * Extracts a 3D slice from a 3D array starting at coordinates `begin` and
18641 * is of size `size`. See `slice` for details.
18642 */
// Rank-3 convenience wrapper over `slice` with a rank check.
function slice3d_(x, begin, size) {
    const $tensor = convertToTensor(x, 'x', 'slice3d');
    const rank = $tensor.rank;
    assert$1(rank === 3, () => `slice3d expects a rank-3 tensor, but got a rank-${rank} tensor`);
    return slice$2($tensor, begin, size);
}
const slice3d = /* @__PURE__ */ op({ slice3d_ });
18649
18650 /**
18651 * @license
18652 * Copyright 2018 Google LLC. All Rights Reserved.
18653 * Licensed under the Apache License, Version 2.0 (the "License");
18654 * you may not use this file except in compliance with the License.
18655 * You may obtain a copy of the License at
18656 *
18657 * http://www.apache.org/licenses/LICENSE-2.0
18658 *
18659 * Unless required by applicable law or agreed to in writing, software
18660 * distributed under the License is distributed on an "AS IS" BASIS,
18661 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18662 * See the License for the specific language governing permissions and
18663 * limitations under the License.
18664 * =============================================================================
18665 */
18666 /**
18667 * Extracts a 4D slice from a 4D array starting at coordinates `begin` and
18668 * is of size `size`. See `slice` for details.
18669 */
// Rank-4 convenience wrapper over `slice` with a rank check.
function slice4d_(x, begin, size) {
    const $tensor = convertToTensor(x, 'x', 'slice4d');
    const rank = $tensor.rank;
    assert$1(rank === 4, () => `slice4d expects a rank-4 tensor, but got a rank-${rank} tensor`);
    return slice$2($tensor, begin, size);
}
const slice4d = /* @__PURE__ */ op({ slice4d_ });
18676
18677 /**
18678 * @license
18679 * Copyright 2018 Google LLC. All Rights Reserved.
18680 * Licensed under the Apache License, Version 2.0 (the "License");
18681 * you may not use this file except in compliance with the License.
18682 * You may obtain a copy of the License at
18683 *
18684 * http://www.apache.org/licenses/LICENSE-2.0
18685 *
18686 * Unless required by applicable law or agreed to in writing, software
18687 * distributed under the License is distributed on an "AS IS" BASIS,
18688 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18689 * See the License for the specific language governing permissions and
18690 * limitations under the License.
18691 * =============================================================================
18692 */
18693 /**
18694 * Computes the softmax normalized vector given the logits.
18695 *
18696 * ```js
18697 * const a = tf.tensor1d([1, 2, 3]);
18698 *
18699 * a.softmax().print(); // or tf.softmax(a)
18700 * ```
18701 *
18702 * ```js
18703 * const a = tf.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]);
18704 *
18705 * a.softmax().print(); // or tf.softmax(a)
18706 * ```
18707 *
18708 * @param logits The logits array.
18709 * @param dim The dimension softmax would be performed on. Defaults to `-1`
18710 * which indicates the last dimension.
18711 *
18712 * @doc {heading: 'Operations', subheading: 'Normalization'}
18713 */
// Softmax over `dim` of `logits`. Only the last dimension is supported by
// the kernel; `dim = -1` is normalized to the last axis before the check.
function softmax_(logits, dim = -1) {
    const $logits = convertToTensor(logits, 'logits', 'softmax', 'float32');
    const lastAxis = $logits.rank - 1;
    const axis = dim === -1 ? lastAxis : dim;
    if (axis !== lastAxis) {
        throw Error('Softmax along a non-last dimension is not yet supported. ' +
            `Logits was rank ${$logits.rank} and dim was ${axis}`);
    }
    return ENGINE.runKernel(Softmax$2, { logits: $logits }, { dim: axis });
}
const softmax$3 = /* @__PURE__ */ op({ softmax_ });
18728
18729 /**
18730 * @license
18731 * Copyright 2020 Google LLC. All Rights Reserved.
18732 * Licensed under the Apache License, Version 2.0 (the "License");
18733 * you may not use this file except in compliance with the License.
18734 * You may obtain a copy of the License at
18735 *
18736 * http://www.apache.org/licenses/LICENSE-2.0
18737 *
18738 * Unless required by applicable law or agreed to in writing, software
18739 * distributed under the License is distributed on an "AS IS" BASIS,
18740 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18741 * See the License for the specific language governing permissions and
18742 * limitations under the License.
18743 * =============================================================================
18744 */
18745 /**
18746 * Fast Fourier transform.
18747 *
18748 * Computes the 1-dimensional discrete Fourier transform over the inner-most
18749 * dimension of input.
18750 *
18751 * ```js
18752 * const real = tf.tensor1d([1, 2, 3]);
18753 * const imag = tf.tensor1d([1, 2, 3]);
18754 * const x = tf.complex(real, imag);
18755 *
18756 * x.fft().print(); // tf.spectral.fft(x).print();
18757 * ```
18758 * @param input The complex input to compute an fft over.
18759 *
18760 * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
18761 */
// 1-D FFT over the inner-most dimension; the kernel only accepts
// complex64 input, so the dtype is validated up front.
function fft_(input) {
    const { dtype } = input;
    assert$1(dtype === 'complex64', () => `The dtype for tf.spectral.fft() must be complex64 ` +
        `but got ${dtype}.`);
    return ENGINE.runKernel(FFT, { input });
}
const fft$2 = /* @__PURE__ */ op({ fft_ });
18769
18770 /**
18771 * @license
18772 * Copyright 2020 Google LLC. All Rights Reserved.
18773 * Licensed under the Apache License, Version 2.0 (the "License");
18774 * you may not use this file except in compliance with the License.
18775 * You may obtain a copy of the License at
18776 *
18777 * http://www.apache.org/licenses/LICENSE-2.0
18778 *
18779 * Unless required by applicable law or agreed to in writing, software
18780 * distributed under the License is distributed on an "AS IS" BASIS,
18781 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18782 * See the License for the specific language governing permissions and
18783 * limitations under the License.
18784 * =============================================================================
18785 */
18786 /**
18787 * Inverse fast Fourier transform.
18788 *
18789 * Computes the inverse 1-dimensional discrete Fourier transform over the
18790 * inner-most dimension of input.
18791 *
18792 * ```js
18793 * const real = tf.tensor1d([1, 2, 3]);
18794 * const imag = tf.tensor1d([1, 2, 3]);
18795 * const x = tf.complex(real, imag);
18796 *
18797 * x.ifft().print(); // tf.spectral.ifft(x).print();
18798 * ```
18799 * @param input The complex input to compute an ifft over.
18800 *
18801 * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
18802 */
// Inverse 1-D FFT over the inner-most dimension; the kernel only accepts
// complex64 input, so the dtype is validated up front.
function ifft_(input) {
    const { dtype } = input;
    assert$1(dtype === 'complex64', () => `The dtype for tf.spectral.ifft() must be complex64 ` +
        `but got ${dtype}.`);
    return ENGINE.runKernel(IFFT, { input });
}
const ifft$2 = /* @__PURE__ */ op({ ifft_ });
18810
18811 /**
18812 * @license
18813 * Copyright 2018 Google LLC. All Rights Reserved.
18814 * Licensed under the Apache License, Version 2.0 (the "License");
18815 * you may not use this file except in compliance with the License.
18816 * You may obtain a copy of the License at
18817 *
18818 * http://www.apache.org/licenses/LICENSE-2.0
18819 *
18820 * Unless required by applicable law or agreed to in writing, software
18821 * distributed under the License is distributed on an "AS IS" BASIS,
18822 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18823 * See the License for the specific language governing permissions and
18824 * limitations under the License.
18825 * =============================================================================
18826 */
18827 /**
18828 * Inversed real value input fast Fourier transform.
18829 *
18830 * Computes the 1-dimensional inversed discrete Fourier transform over the
18831 * inner-most dimension of the real input.
18832 *
18833 * ```js
18834 * const real = tf.tensor1d([1, 2, 3]);
18835 * const imag = tf.tensor1d([0, 0, 0]);
18836 * const x = tf.complex(real, imag);
18837 *
18838 * x.irfft().print();
18839 * ```
18840 * @param input The real value input to compute an irfft over.
18841 *
18842 * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
18843 */
// Inverse real FFT: reconstructs the full Hermitian-symmetric spectrum from
// the half-spectrum input, runs an inverse FFT, and keeps the real part.
function irfft_(input) {
    // Treat all leading dims as a flat batch over the inner-most dim.
    const innerDimensionSize = input.shape[input.shape.length - 1];
    const batch = input.size / innerDimensionSize;
    let ret;
    if (innerDimensionSize <= 2) {
        // With <= 2 frequency bins there are no redundant conjugate
        // components to reconstruct; ifft the input directly.
        const complexInput = reshape$3(input, [batch, innerDimensionSize]);
        ret = ifft$2(complexInput);
    }
    else {
        // The length of unique components of the DFT of a real-valued signal
        // is 2 * (input_len - 1)
        const outputShape = [batch, 2 * (innerDimensionSize - 1)];
        const realInput = reshape$3(real$2(input), [batch, innerDimensionSize]);
        const imagInput = reshape$3(imag$2(input), [batch, innerDimensionSize]);
        // Mirror the interior bins (excluding DC and Nyquist) in reversed
        // order; the imaginary part is negated to form the complex conjugate.
        const realConjugate = reverse$2(slice$2(realInput, [0, 1], [batch, innerDimensionSize - 2]), 1);
        const imagConjugate = mul(reverse$2(slice$2(imagInput, [0, 1], [batch, innerDimensionSize - 2]), 1), scalar(-1));
        const r = concat$2([realInput, realConjugate], 1);
        const i = concat$2([imagInput, imagConjugate], 1);
        const complexInput = reshape$3(complex$2(r, i), [outputShape[0], outputShape[1]]);
        ret = ifft$2(complexInput);
    }
    // The ifft of a Hermitian spectrum is real-valued; drop the imaginary part.
    ret = real$2(ret);
    // reshape the result if the input is 3D tensor.
    if (input.rank === 3 && input.shape[0] !== 0) {
        const temp = ret;
        // NOTE: shadows the outer flat-batch `batch` with the leading dim size.
        const batch = input.shape[0];
        ret = reshape$3(ret, [batch, ret.shape[0] / batch, ret.shape[1]]);
        // Dispose the intermediate 2-D result now that it has been reshaped.
        temp.dispose();
    }
    return ret;
}
const irfft = /* @__PURE__ */ op({ irfft_ });
18876
18877 /**
18878 * @license
18879 * Copyright 2020 Google LLC. All Rights Reserved.
18880 * Licensed under the Apache License, Version 2.0 (the "License");
18881 * you may not use this file except in compliance with the License.
18882 * You may obtain a copy of the License at
18883 *
18884 * http://www.apache.org/licenses/LICENSE-2.0
18885 *
18886 * Unless required by applicable law or agreed to in writing, software
18887 * distributed under the License is distributed on an "AS IS" BASIS,
18888 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18889 * See the License for the specific language governing permissions and
18890 * limitations under the License.
18891 * =============================================================================
18892 */
18893 /**
18894 * Splits a `tf.Tensor` into sub tensors.
18895 *
18896 * If `numOrSizeSplits` is a number, splits `x` along dimension `axis`
18897 * into `numOrSizeSplits` smaller tensors.
18898 * Requires that `numOrSizeSplits` evenly divides `x.shape[axis]`.
18899 *
18900 * If `numOrSizeSplits` is a number array, splits `x` into
18901 * `numOrSizeSplits.length` pieces. The shape of the `i`-th piece has the
18902 * same size as `x` except along dimension `axis` where the size is
18903 * `numOrSizeSplits[i]`.
18904 *
18905 * ```js
18906 * const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]);
18907 * const [a, b] = tf.split(x, 2, 1);
18908 * a.print();
18909 * b.print();
18910 *
18911 * const [c, d, e] = tf.split(x, [1, 2, 1], 1);
18912 * c.print();
18913 * d.print();
18914 * e.print();
18915 * ```
18916 *
18917 * @param x The input tensor to split.
18918 * @param numOrSizeSplits Either an integer indicating the number of
18919 * splits along the axis or an array of integers containing the sizes of
18920 * each output tensor along the axis. If a number then it must evenly divide
18921 * `x.shape[axis]`; otherwise the sum of sizes must match `x.shape[axis]`.
18922 * Can contain one -1 indicating that dimension is to be inferred.
18923 * @param axis The dimension along which to split. Defaults to 0 (the first
18924 * dim).
18925 *
18926 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
18927 */
// Splits `x` into sub-tensors along `axis` by dispatching the SplitV kernel.
// `numOrSizeSplits` is either a count (must evenly divide the axis) or an
// explicit per-piece size list; validation happens inside the kernel.
function split_(x, numOrSizeSplits, axis = 0) {
    const $x = convertToTensor(x, 'x', 'split');
    const kernelInputs = { x: $x };
    const kernelAttrs = { numOrSizeSplits, axis };
    return ENGINE.runKernel(SplitV, kernelInputs, kernelAttrs);
}
const split$3 = /* @__PURE__ */ op({ split_ });
18935
18936 /**
18937 * @license
18938 * Copyright 2018 Google LLC. All Rights Reserved.
18939 * Licensed under the Apache License, Version 2.0 (the "License");
18940 * you may not use this file except in compliance with the License.
18941 * You may obtain a copy of the License at
18942 *
18943 * http://www.apache.org/licenses/LICENSE-2.0
18944 *
18945 * Unless required by applicable law or agreed to in writing, software
18946 * distributed under the License is distributed on an "AS IS" BASIS,
18947 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18948 * See the License for the specific language governing permissions and
18949 * limitations under the License.
18950 * =============================================================================
18951 */
18952 /**
18953 * Real value input fast Fourier transform.
18954 *
18955 * Computes the 1-dimensional discrete Fourier transform over the
18956 * inner-most dimension of the real input.
18957 *
18958 * ```js
18959 * const real = tf.tensor1d([1, 2, 3]);
18960 *
18961 * real.rfft().print();
18962 * ```
18963 * @param input The real value input to compute an rfft over.
18964 *
18965 * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
18966 */
function rfft_(input, fftLength) {
    // rfft is only defined for real-valued input.
    assert$1(input.dtype === 'float32', () => `The dtype for rfft() must be real value but got ${input.dtype}`);
    // Size of the innermost (transformed) dimension; `batch` is the number of
    // independent 1-D signals (product of all other dimensions). Cropping or
    // padding below only changes the last dim, so `batch` stays valid.
    let innerDimensionSize = input.shape[input.shape.length - 1];
    const batch = input.size / innerDimensionSize;
    let adjustedInput;
    if (fftLength != null && fftLength < innerDimensionSize) {
        // Need to crop: keep only the first `fftLength` samples of each signal.
        const begin = input.shape.map(v => 0);
        const size = input.shape.map(v => v);
        size[input.shape.length - 1] = fftLength;
        adjustedInput = slice$2(input, begin, size);
        innerDimensionSize = fftLength;
    }
    else if (fftLength != null && fftLength > innerDimensionSize) {
        // Need to pad with zeros up to `fftLength` along the last dimension.
        const zerosShape = input.shape.map(v => v);
        zerosShape[input.shape.length - 1] = fftLength - innerDimensionSize;
        adjustedInput = concat$2([input, zeros$2(zerosShape)], input.shape.length - 1);
        innerDimensionSize = fftLength;
    }
    else {
        // fftLength absent or equal to the inner dimension: use input as-is.
        adjustedInput = input;
    }
    // Complement the input with zero imaginary numbers, then run the complex
    // FFT on the flattened [batch, innerDimensionSize] view.
    const zerosInput = zerosLike$3(adjustedInput);
    const complexInput = reshape$3(complex$2(adjustedInput, zerosInput), [batch, innerDimensionSize]);
    const ret = fft$2(complexInput);
    // Exclude complex conjugations. These conjugations are put symmetrically:
    // for real input only the first floor(n/2)+1 frequency bins are unique.
    const half = Math.floor(innerDimensionSize / 2) + 1;
    const realValues = real$2(ret);
    const imagValues = imag$2(ret);
    // Split off the redundant conjugate halves; only index [0] is kept.
    const realComplexConjugate = split$3(realValues, [half, innerDimensionSize - half], realValues.shape.length - 1);
    const imagComplexConjugate = split$3(imagValues, [half, innerDimensionSize - half], imagValues.shape.length - 1);
    // Restore the original batch dimensions with the truncated last dim.
    const outputShape = adjustedInput.shape.slice();
    outputShape[adjustedInput.shape.length - 1] = half;
    return reshape$3(complex$2(realComplexConjugate[0], imagComplexConjugate[0]), outputShape);
}
const rfft = /* @__PURE__ */ op({ rfft_ });
19005
19006 /**
19007 * @license
19008 * Copyright 2020 Google LLC. All Rights Reserved.
19009 * Licensed under the Apache License, Version 2.0 (the "License");
19010 * you may not use this file except in compliance with the License.
19011 * You may obtain a copy of the License at
19012 *
19013 * http://www.apache.org/licenses/LICENSE-2.0
19014 *
19015 * Unless required by applicable law or agreed to in writing, software
19016 * distributed under the License is distributed on an "AS IS" BASIS,
19017 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19018 * See the License for the specific language governing permissions and
19019 * limitations under the License.
19020 * =============================================================================
19021 */
19022 /**
19023 * Returns (a - b) * (a - b) element-wise.
19024 * Supports broadcasting.
19025 *
19026 * ```js
19027 * const a = tf.tensor1d([1, 4, 3, 16]);
19028 * const b = tf.tensor1d([1, 2, 9, 4]);
19029 *
19030 * a.squaredDifference(b).print(); // or tf.squaredDifference(a, b)
19031 * ```
19032 *
19033 * ```js
19034 * // Broadcast squared difference a with b.
19035 * const a = tf.tensor1d([2, 4, 6, 8]);
19036 * const b = tf.scalar(5);
19037 *
19038 * a.squaredDifference(b).print(); // or tf.squaredDifference(a, b)
19039 * ```
19040 *
19041 * @param a The first tensor.
19042 * @param b The second tensor. Must have the same type as `a`.
19043 *
19044 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
19045 */
// Computes (a - b)^2 element-wise with broadcasting via the
// SquaredDifference kernel.
function squaredDifference_(a, b) {
    // Upcast both operands to a common dtype before dispatching.
    const [lhs, rhs] = makeTypesMatch(convertToTensor(a, 'a', 'squaredDifference'), convertToTensor(b, 'b', 'squaredDifference'));
    // Throws if the two shapes are not broadcast-compatible.
    assertAndGetBroadcastShape(lhs.shape, rhs.shape);
    return ENGINE.runKernel(SquaredDifference, { a: lhs, b: rhs }, {});
}
const squaredDifference$2 = /* @__PURE__ */ op({ squaredDifference_ });
19056
19057 /**
19058 * @license
19059 * Copyright 2020 Google LLC. All Rights Reserved.
19060 * Licensed under the Apache License, Version 2.0 (the "License");
19061 * you may not use this file except in compliance with the License.
19062 * You may obtain a copy of the License at
19063 *
19064 * http://www.apache.org/licenses/LICENSE-2.0
19065 *
19066 * Unless required by applicable law or agreed to in writing, software
19067 * distributed under the License is distributed on an "AS IS" BASIS,
19068 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19069 * See the License for the specific language governing permissions and
19070 * limitations under the License.
19071 * =============================================================================
19072 */
19073 /**
19074 * Removes dimensions of size 1 from the shape of a `tf.Tensor`.
19075 *
19076 * ```js
19077 * const x = tf.tensor([1, 2, 3, 4], [1, 1, 4]);
19078 * x.squeeze().print();
19079 * ```
19080 *
19081 * @param x The input tensor to be squeezed.
19082 * @param axis An optional list of numbers. If specified, only
19083 * squeezes the dimensions listed. The dimension index starts at 0. It
19084 * is an error to squeeze a dimension that is not 1.
19085 *
19086 * @doc {heading: 'Tensors', subheading: 'Transformations'}
19087 */
// Drops size-1 dimensions from `x` (optionally only those listed in `axis`)
// by reshaping to the squeezed shape computed by squeezeShape.
function squeeze_(x, axis) {
    const $x = convertToTensor(x, 'x', 'squeeze', 'string_or_numeric');
    const { newShape } = squeezeShape($x.shape, axis);
    return reshape$3($x, newShape);
}
const squeeze = /* @__PURE__ */ op({ squeeze_ });
19093
19094 /**
19095 * @license
19096 * Copyright 2020 Google LLC. All Rights Reserved.
19097 * Licensed under the Apache License, Version 2.0 (the "License");
19098 * you may not use this file except in compliance with the License.
19099 * You may obtain a copy of the License at
19100 *
19101 * http://www.apache.org/licenses/LICENSE-2.0
19102 *
19103 * Unless required by applicable law or agreed to in writing, software
19104 * distributed under the License is distributed on an "AS IS" BASIS,
19105 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19106 * See the License for the specific language governing permissions and
19107 * limitations under the License.
19108 * =============================================================================
19109 */
19110 /**
19111 * Stacks a list of rank-`R` `tf.Tensor`s into one rank-`(R+1)` `tf.Tensor`.
19112 *
19113 * ```js
19114 * const a = tf.tensor1d([1, 2]);
19115 * const b = tf.tensor1d([3, 4]);
19116 * const c = tf.tensor1d([5, 6]);
19117 * tf.stack([a, b, c]).print();
19118 * ```
19119 *
19120 * @param tensors A list of tensor objects with the same shape and dtype.
19121 * @param axis The axis to stack along. Defaults to 0 (the first dim).
19122 *
19123 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
19124 */
// Stacks rank-R tensors into one rank-(R+1) tensor along `axis` via the
// Pack kernel. All inputs are expected to share shape and dtype.
function stack_(tensors, axis = 0) {
    const $tensors = convertToTensorArray(tensors, 'tensors', 'stack', 'string_or_numeric');
    assert$1($tensors.length >= 1, () => 'Pass at least one tensor to tf.stack');
    if ($tensors.length > 0) {
        // NOTE(review): this guard is redundant if assert$1 always throws on a
        // false condition — kept as-is to preserve behavior exactly.
        assert$1(axis <= $tensors[0].rank, () => 'Axis must be <= rank of the tensor');
    }
    return ENGINE.runKernel(Pack, $tensors, { axis });
}
const stack = /* @__PURE__ */ op({ stack_ });
19136
19137 /**
19138 * @license
19139 * Copyright 2018 Google LLC. All Rights Reserved.
19140 * Licensed under the Apache License, Version 2.0 (the "License");
19141 * you may not use this file except in compliance with the License.
19142 * You may obtain a copy of the License at
19143 *
19144 * http://www.apache.org/licenses/LICENSE-2.0
19145 *
19146 * Unless required by applicable law or agreed to in writing, software
19147 * distributed under the License is distributed on an "AS IS" BASIS,
19148 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19149 * See the License for the specific language governing permissions and
19150 * limitations under the License.
19151 * =============================================================================
19152 */
19153 /**
19154 * Computes step of the input `tf.Tensor` element-wise: `x > 0 ? 1 : alpha`
19155 *
19156 * ```js
19157 * const x = tf.tensor1d([0, 2, -1, -3]);
19158 *
19159 * x.step(.5).print(); // or tf.step(x, .5)
19160 * ```
19161 * @param x The input tensor.
19162 * @param alpha The gradient when input is negative. Defaults to 0.
19163 *
19164 * @doc {heading: 'Operations', subheading: 'Basic math'}
19165 */
// Element-wise step function: dispatches the Step kernel, which yields 1 for
// positive inputs and `alpha` otherwise (per the op's documented contract).
function step_(x, alpha = 0.0) {
    const $x = convertToTensor(x, 'x', 'step');
    return ENGINE.runKernel(Step, { x: $x }, { alpha });
}
const step$2 = /* @__PURE__ */ op({ step_ });
19173
19174 /**
19175 * @license
19176 * Copyright 2018 Google LLC. All Rights Reserved.
19177 * Licensed under the Apache License, Version 2.0 (the "License");
19178 * you may not use this file except in compliance with the License.
19179 * You may obtain a copy of the License at
19180 *
19181 * http://www.apache.org/licenses/LICENSE-2.0
19182 *
19183 * Unless required by applicable law or agreed to in writing, software
19184 * distributed under the License is distributed on an "AS IS" BASIS,
19185 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19186 * See the License for the specific language governing permissions and
19187 * limitations under the License.
19188 * =============================================================================
19189 */
19190 /**
19191 * Extracts a strided slice of a tensor.
19192 *
19193 * Roughly speaking, this op extracts a slice of size (end-begin)/stride from
19194 * the given input tensor (x). Starting at the location specified by begin the
19195 * slice continues by adding stride to the index until all dimensions are not
19196 * less than end. Note that a stride can be negative, which causes a reverse
19197 * slice.
19198 *
19199 * ```js
19200 * const t = tf.tensor3d([1, 1, 1 ,2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6],
19201 * [3, 2, 3]);
19202 * t.stridedSlice([1, 0, 0], [2, 1, 3], [1, 1, 1]).print() // [[[3, 3, 3]]]
19203 * t.stridedSlice([1, 0, 0], [2, 2, 3], [1, 1, 1]).print() // [[[3, 3, 3],
19204 * // [4, 4, 4]]]
19205 * t.stridedSlice([1, -1, 0], [2, -3, 3], [1, -1, 1]).print() // [[[4, 4, 4],
19206 * // [3, 3, 3]]]
19207 * ```
19208 *
19209 * @param x The tensor to stride slice.
19210 * @param begin The coordinates to start the slice from.
19211 * @param end: The coordinates to end the slice at.
19212 * @param strides: The size of the slice.
19213 * @param beginMask: If the ith bit of beginMask is set, begin[i] is ignored
19214 * and the fullest possible range in that dimension is used instead.
19215 * @param endMask: If the ith bit of endMask is set, end[i] is ignored
19216 * and the fullest possible range in that dimension is used instead.
19217 * @param shrinkAxisMask: a bitmask where bit i implies that
19218 * the ith specification should shrink the dimensionality. begin and end must
19219 * imply a slice of size 1 in the dimension.
19220 *
19221 * @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
19222 */
// Extracts a strided slice of `x`; all slicing logic lives in the
// StridedSlice kernel — this wrapper only packages the attributes.
function stridedSlice_(x, begin, end, strides, beginMask = 0, endMask = 0, ellipsisMask = 0, newAxisMask = 0, shrinkAxisMask = 0) {
    const $x = convertToTensor(x, 'x', 'stridedSlice', 'string_or_numeric');
    const sliceAttrs = {
        begin,
        end,
        strides,
        beginMask,
        endMask,
        ellipsisMask,
        newAxisMask,
        shrinkAxisMask
    };
    return ENGINE.runKernel(StridedSlice, { x: $x }, sliceAttrs);
}
const stridedSlice$2 = /* @__PURE__ */ op({ stridedSlice_ });
19239
19240 /**
19241 * @license
19242 * Copyright 2018 Google LLC. All Rights Reserved.
19243 * Licensed under the Apache License, Version 2.0 (the "License");
19244 * you may not use this file except in compliance with the License.
19245 * You may obtain a copy of the License at
19246 *
19247 * http://www.apache.org/licenses/LICENSE-2.0
19248 *
19249 * Unless required by applicable law or agreed to in writing, software
19250 * distributed under the License is distributed on an "AS IS" BASIS,
19251 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19252 * See the License for the specific language governing permissions and
19253 * limitations under the License.
19254 * =============================================================================
19255 */
19256 /**
19257 * Computes tan of the input `tf.Tensor` element-wise, `tan(x)`
19258 *
19259 * ```js
19260 * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);
19261 *
19262 * x.tan().print(); // or tf.tan(x)
19263 * ```
19264 * @param x The input tensor.
19265 *
19266 * @doc {heading: 'Operations', subheading: 'Basic math'}
19267 */
// Element-wise tangent: converts the input to a float32 tensor and
// dispatches the Tan kernel (no attributes).
function tan_(x) {
    const $x = convertToTensor(x, 'x', 'tan', 'float32');
    return ENGINE.runKernel(Tan, { x: $x });
}
const tan$2 = /* @__PURE__ */ op({ tan_ });
19274
19275 /**
19276 * @license
19277 * Copyright 2018 Google LLC. All Rights Reserved.
19278 * Licensed under the Apache License, Version 2.0 (the "License");
19279 * you may not use this file except in compliance with the License.
19280 * You may obtain a copy of the License at
19281 *
19282 * http://www.apache.org/licenses/LICENSE-2.0
19283 *
19284 * Unless required by applicable law or agreed to in writing, software
19285 * distributed under the License is distributed on an "AS IS" BASIS,
19286 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19287 * See the License for the specific language governing permissions and
19288 * limitations under the License.
19289 * =============================================================================
19290 */
19291 /**
19292 * Creates rank-1 `tf.Tensor` with the provided values, shape and dtype.
19293 *
19294 * The same functionality can be achieved with `tf.tensor`, but in general
19295 * we recommend using `tf.tensor1d` as it makes the code more readable.
19296 *
19297 * ```js
19298 * tf.tensor1d([1, 2, 3]).print();
19299 * ```
19300 *
19301 * @param values The values of the tensor. Can be array of numbers,
19302 * or a `TypedArray`.
19303 * @param dtype The data type.
19304 *
19305 * @doc {heading: 'Tensors', subheading: 'Creation'}
19306 */
// Creates a rank-1 tensor. No explicit shape parameter exists: the shape is
// always inferred from `values`, so `null` is passed to makeTensor.
function tensor1d(values, dtype) {
    assertNonNull(values);
    const inferredShape = inferShape(values, dtype);
    if (inferredShape.length !== 1) {
        throw new Error('tensor1d() requires values to be a flat/TypedArray');
    }
    return makeTensor(values, null, inferredShape, dtype);
}
19316
19317 /**
19318 * @license
19319 * Copyright 2018 Google LLC. All Rights Reserved.
19320 * Licensed under the Apache License, Version 2.0 (the "License");
19321 * you may not use this file except in compliance with the License.
19322 * You may obtain a copy of the License at
19323 *
19324 * http://www.apache.org/licenses/LICENSE-2.0
19325 *
19326 * Unless required by applicable law or agreed to in writing, software
19327 * distributed under the License is distributed on an "AS IS" BASIS,
19328 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19329 * See the License for the specific language governing permissions and
19330 * limitations under the License.
19331 * =============================================================================
19332 */
19333 /**
19334 * Creates rank-2 `tf.Tensor` with the provided values, shape and dtype.
19335 *
19336 * The same functionality can be achieved with `tf.tensor`, but in general
19337 * we recommend using `tf.tensor2d` as it makes the code more readable.
19338 *
19339 * ```js
19340 * // Pass a nested array.
19341 * tf.tensor2d([[1, 2], [3, 4]]).print();
19342 * ```
19343 * ```js
19344 * // Pass a flat array and specify a shape.
19345 * tf.tensor2d([1, 2, 3, 4], [2, 2]).print();
19346 * ```
19347 *
19348 * @param values The values of the tensor. Can be nested array of numbers,
19349 * or a flat array, or a `TypedArray`.
19350 * @param shape The shape of the tensor. If not provided, it is inferred from
19351 * `values`.
19352 * @param dtype The data type.
19353 *
19354 * @doc {heading: 'Tensors', subheading: 'Creation'}
19355 */
// Creates a rank-2 tensor from a nested array, or from a flat array plus an
// explicit 2-element shape.
function tensor2d(values, shape, dtype) {
    assertNonNull(values);
    if (shape != null && shape.length !== 2) {
        throw new Error('tensor2d() requires shape to have two numbers');
    }
    const inferredShape = inferShape(values, dtype);
    const inferredRank = inferredShape.length;
    if (inferredRank !== 2 && inferredRank !== 1) {
        throw new Error('tensor2d() requires values to be number[][] or flat/TypedArray');
    }
    if (inferredRank === 1 && shape == null) {
        // A flat array carries no rank-2 shape information of its own.
        throw new Error('tensor2d() requires shape to be provided when `values` are a flat/TypedArray');
    }
    return makeTensor(values, shape, inferredShape, dtype);
}
19371
19372 /**
19373 * @license
19374 * Copyright 2018 Google LLC. All Rights Reserved.
19375 * Licensed under the Apache License, Version 2.0 (the "License");
19376 * you may not use this file except in compliance with the License.
19377 * You may obtain a copy of the License at
19378 *
19379 * http://www.apache.org/licenses/LICENSE-2.0
19380 *
19381 * Unless required by applicable law or agreed to in writing, software
19382 * distributed under the License is distributed on an "AS IS" BASIS,
19383 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19384 * See the License for the specific language governing permissions and
19385 * limitations under the License.
19386 * =============================================================================
19387 */
19388 /**
19389 * Creates rank-3 `tf.Tensor` with the provided values, shape and dtype.
19390 *
19391 * The same functionality can be achieved with `tf.tensor`, but in general
19392 * we recommend using `tf.tensor3d` as it makes the code more readable.
19393 *
19394 * ```js
19395 * // Pass a nested array.
19396 * tf.tensor3d([[[1], [2]], [[3], [4]]]).print();
19397 * ```
19398 * ```js
19399 * // Pass a flat array and specify a shape.
19400 * tf.tensor3d([1, 2, 3, 4], [2, 2, 1]).print();
19401 * ```
19402 *
19403 * @param values The values of the tensor. Can be nested array of numbers,
19404 * or a flat array, or a `TypedArray`.
19405 * @param shape The shape of the tensor. If not provided, it is inferred from
19406 * `values`.
19407 * @param dtype The data type.
19408 *
19409 * @doc {heading: 'Tensors', subheading: 'Creation'}
19410 */
// Creates a rank-3 tensor from a nested array, or from a flat array plus an
// explicit 3-element shape.
function tensor3d(values, shape, dtype) {
    assertNonNull(values);
    if (shape != null && shape.length !== 3) {
        throw new Error('tensor3d() requires shape to have three numbers');
    }
    const inferredShape = inferShape(values, dtype);
    const inferredRank = inferredShape.length;
    if (inferredRank !== 3 && inferredRank !== 1) {
        throw new Error('tensor3d() requires values to be number[][][] or flat/TypedArray');
    }
    if (inferredRank === 1 && shape == null) {
        // A flat array carries no rank-3 shape information of its own.
        throw new Error('tensor3d() requires shape to be provided when `values` are a flat array');
    }
    return makeTensor(values, shape, inferredShape, dtype);
}
19426
19427 /**
19428 * @license
19429 * Copyright 2018 Google LLC. All Rights Reserved.
19430 * Licensed under the Apache License, Version 2.0 (the "License");
19431 * you may not use this file except in compliance with the License.
19432 * You may obtain a copy of the License at
19433 *
19434 * http://www.apache.org/licenses/LICENSE-2.0
19435 *
19436 * Unless required by applicable law or agreed to in writing, software
19437 * distributed under the License is distributed on an "AS IS" BASIS,
19438 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19439 * See the License for the specific language governing permissions and
19440 * limitations under the License.
19441 * =============================================================================
19442 */
19443 /**
19444 * Creates rank-4 `tf.Tensor` with the provided values, shape and dtype.
19445 *
19446 * The same functionality can be achieved with `tf.tensor`, but in general
19447 * we recommend using `tf.tensor4d` as it makes the code more readable.
19448 *
19449 * ```js
19450 * // Pass a nested array.
19451 * tf.tensor4d([[[[1], [2]], [[3], [4]]]]).print();
19452 * ```
19453 * ```js
19454 * // Pass a flat array and specify a shape.
19455 * tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]).print();
19456 * ```
19457 *
19458 * @param values The values of the tensor. Can be nested array of numbers,
19459 * or a flat array, or a `TypedArray`.
19460 * @param shape The shape of the tensor. Optional. If not provided,
19461 * it is inferred from `values`.
19462 * @param dtype The data type.
19463 *
19464 * @doc {heading: 'Tensors', subheading: 'Creation'}
19465 */
// Creates a rank-4 tensor from a nested array, or from a flat array plus an
// explicit 4-element shape.
function tensor4d(values, shape, dtype) {
    assertNonNull(values);
    if (shape != null && shape.length !== 4) {
        throw new Error('tensor4d() requires shape to have four numbers');
    }
    const inferredShape = inferShape(values, dtype);
    const inferredRank = inferredShape.length;
    if (inferredRank !== 4 && inferredRank !== 1) {
        throw new Error('tensor4d() requires values to be number[][][][] or flat/TypedArray');
    }
    if (inferredRank === 1 && shape == null) {
        // A flat array carries no rank-4 shape information of its own.
        throw new Error('tensor4d() requires shape to be provided when `values` are a flat array');
    }
    return makeTensor(values, shape, inferredShape, dtype);
}
19481
19482 /**
19483 * @license
19484 * Copyright 2018 Google LLC. All Rights Reserved.
19485 * Licensed under the Apache License, Version 2.0 (the "License");
19486 * you may not use this file except in compliance with the License.
19487 * You may obtain a copy of the License at
19488 *
19489 * http://www.apache.org/licenses/LICENSE-2.0
19490 *
19491 * Unless required by applicable law or agreed to in writing, software
19492 * distributed under the License is distributed on an "AS IS" BASIS,
19493 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19494 * See the License for the specific language governing permissions and
19495 * limitations under the License.
19496 * =============================================================================
19497 */
19498 /**
19499 * Creates rank-5 `tf.Tensor` with the provided values, shape and dtype.
19500 *
19501 * The same functionality can be achieved with `tf.tensor`, but in general
19502 * we recommend using `tf.tensor5d` as it makes the code more readable.
19503 *
19504 * ```js
19505 * // Pass a nested array.
19506 * tf.tensor5d([[[[[1],[2]],[[3],[4]]],[[[5],[6]],[[7],[8]]]]]).print();
19507 * ```
19508 * ```js
19509 * // Pass a flat array and specify a shape.
19510 * tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]).print();
19511 * ```
19512 *
19513 * @param values The values of the tensor. Can be nested array of numbers,
19514 * or a flat array, or a `TypedArray`.
19515 * @param shape The shape of the tensor. Optional. If not provided,
19516 * it is inferred from `values`.
19517 * @param dtype The data type.
19518 *
19519 * @doc {heading: 'Tensors', subheading: 'Creation'}
19520 */
// Creates a rank-5 tensor from a nested array, or from a flat array plus an
// explicit 5-element shape.
function tensor5d(values, shape, dtype) {
    assertNonNull(values);
    if (shape != null && shape.length !== 5) {
        throw new Error('tensor5d() requires shape to have five numbers');
    }
    const inferredShape = inferShape(values, dtype);
    const inferredRank = inferredShape.length;
    if (inferredRank !== 5 && inferredRank !== 1) {
        throw new Error('tensor5d() requires values to be number[][][][][] or flat/TypedArray');
    }
    if (inferredRank === 1 && shape == null) {
        // A flat array carries no rank-5 shape information of its own.
        throw new Error('tensor5d() requires shape to be provided when `values` are a flat array');
    }
    return makeTensor(values, shape, inferredShape, dtype);
}
19537
19538 /**
19539 * @license
19540 * Copyright 2018 Google LLC. All Rights Reserved.
19541 * Licensed under the Apache License, Version 2.0 (the "License");
19542 * you may not use this file except in compliance with the License.
19543 * You may obtain a copy of the License at
19544 *
19545 * http://www.apache.org/licenses/LICENSE-2.0
19546 *
19547 * Unless required by applicable law or agreed to in writing, software
19548 * distributed under the License is distributed on an "AS IS" BASIS,
19549 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19550 * See the License for the specific language governing permissions and
19551 * limitations under the License.
19552 * =============================================================================
19553 */
19554 /**
19555 * Creates rank-6 `tf.Tensor` with the provided values, shape and dtype.
19556 *
19557 * The same functionality can be achieved with `tf.tensor`, but in general
19558 * we recommend using `tf.tensor6d` as it makes the code more readable.
19559 *
19560 * ```js
19561 * // Pass a nested array.
19562 * tf.tensor6d([[[[[[1],[2]],[[3],[4]]],[[[5],[6]],[[7],[8]]]]]]).print();
19563 * ```
19564 * ```js
19565 * // Pass a flat array and specify a shape.
19566 * tf.tensor6d([1, 2, 3, 4, 5, 6, 7, 8], [1, 1, 2, 2, 2, 1]).print();
19567 * ```
19568 *
19569 * @param values The values of the tensor. Can be nested array of numbers,
19570 * or a flat array, or a `TypedArray`.
19571 * @param shape The shape of the tensor. Optional. If not provided,
19572 * it is inferred from `values`.
19573 * @param dtype The data type.
19574 *
19575 * @doc {heading: 'Tensors', subheading: 'Creation'}
19576 */
// Creates a rank-6 tensor from a nested array, or from a flat array plus an
// explicit 6-element shape. Unlike the lower-rank factories, this one
// forwards `inferredShape` as the shape when no shape was given.
function tensor6d(values, shape, dtype) {
    assertNonNull(values);
    if (shape != null && shape.length !== 6) {
        throw new Error('tensor6d() requires shape to have six numbers');
    }
    const inferredShape = inferShape(values, dtype);
    const inferredRank = inferredShape.length;
    if (inferredRank !== 6 && inferredRank !== 1) {
        throw new Error('tensor6d() requires values to be number[][][][][][] or flat/TypedArray');
    }
    if (inferredRank === 1 && shape == null) {
        // A flat array carries no rank-6 shape information of its own.
        throw new Error('tensor6d() requires shape to be provided when `values` are a flat array');
    }
    // `||` (not `??`) to preserve the original fallback semantics exactly.
    return makeTensor(values, shape || inferredShape, inferredShape, dtype);
}
19595
19596 /**
19597 * Check whether updates.shape = indices.shape[:batchDim] +
19598 * shape[sliceDim:]
19599 *
19600 * @param x The input tensor.
19601 */
/**
 * Checks that `updates.shape === indices.shape[:batchDim] + shape[sliceDim:]`,
 * i.e. that the update values line up with the indices that address them and
 * with the output slices they overwrite.
 *
 * @param shape The shape of the output tensor (number[]).
 * @param indices Tensor-like holding the scatter indices (`rank`, `shape`).
 * @param updates Tensor-like holding the update values (`rank`, `shape`).
 * @throws Error when any dimension of `updates` disagrees with `indices` or
 *     with the output `shape`.
 */
function validateUpdateShape(shape, indices, updates) {
    // The innermost dimension of `indices` addresses the first `sliceDim`
    // output dims; the leading `batchDim` dims of `updates` must match the
    // leading dims of `indices`.
    const sliceDim = (indices.rank > 1) ? indices.shape[indices.rank - 1] : 1;
    const batchDim = (indices.rank > 1) ? indices.rank - 1 : 1;
    const shapeError = 'Must have updates.shape = indices.shape[:batchDim] + ' +
        `shape[sliceDim:], got updates.shape: ${updates.shape}` +
        `, indices.shape: ${indices.shape}, shape: ${shape}` +
        `, sliceDim: ${sliceDim}, and batchDim: ${batchDim}.`;
    if (updates.rank < batchDim) {
        throw new Error(shapeError + ` update.rank < ${batchDim}. `);
    }
    if (shape.length < sliceDim + (updates.rank - batchDim)) {
        throw new Error(shapeError +
            ` Output shape length < ${sliceDim + (updates.rank - batchDim)}`);
    }
    if (updates.rank !== batchDim + shape.length - sliceDim) {
        throw new Error(shapeError + ` update.rank != ${batchDim + shape.length - sliceDim}`);
    }
    for (let d = 0; d < batchDim; ++d) {
        if (updates.shape[d] !== indices.shape[d]) {
            throw new Error(shapeError +
                ` updates.shape[${d}] (${updates.shape[d]}) != indices.shape[${d}] (${indices.shape[d]}).`);
        }
    }
    for (let d = 0; d < updates.rank - batchDim; ++d) {
        if (updates.shape[d + batchDim] !== shape[d + sliceDim]) {
            // Bug fix: the comparison reads shape[d + sliceDim], but the
            // original message reported shape[d + batchDim] — the wrong index
            // and value whenever sliceDim !== batchDim.
            throw new Error(shapeError +
                ` updates.shape[${d + batchDim}] (${updates.shape[d + batchDim]}) != shape[${d + sliceDim}] (${shape[d + sliceDim]})`);
        }
    }
}
19632 /**
19633 * Validate scatter nd inputs.
19634 *
19635 * @param update The tensor contains the update values.
19636 * @param indices The tensor contains the indices for the update values.
19637 * @param shape The shape of the output tensor.
19638 */
/**
 * Validates scatterND inputs: ranks, index dtype, and output shape, then
 * delegates the dimension-by-dimension check to validateUpdateShape.
 *
 * @param updates Tensor-like containing the update values.
 * @param indices Tensor-like containing the indices (must be int32).
 * @param shape The shape of the output tensor (number[]).
 * @throws Error on any invalid combination.
 */
function validateInput$1(updates, indices, shape) {
    if (indices.rank < 1) {
        throw new Error('tf.scatterND() expects the indices to be rank 1 or higher,' +
            ` but the rank was ${indices.rank}.`);
    }
    if (updates.rank < 1) {
        throw new Error('tf.scatterND() expects the updates to be rank 1 or higher,' +
            ` but the rank was ${updates.rank}.`);
    }
    if (indices.dtype !== 'int32') {
        throw new Error(`The dtype of 'indices' should be int32, but got dtype: ${indices.dtype}`);
    }
    if (shape.length < 1) {
        throw new Error(`Output rank must be greater or equal to 1, but got shape: ${shape}`);
    }
    // NOTE(review): the original code also had an `if (shape.length === 0)`
    // branch here; it was unreachable because the check above already throws
    // for shape.length < 1, so it has been removed (no behavior change).
    validateUpdateShape(shape, indices, updates);
}
19664 /**
19665 * Calculate the shape information for the output.
19666 *
19667 * @param update The tensor contains the update values.
19668 * @param indices The tensor contains the indices for the update values.
19669 * @param shape The shape of the output tensor.
19670 *
19671 * @returns ScatterShapeInfo
19672 */
/**
 * Derives the flattened-iteration geometry for a scatter operation.
 *
 * @param updates Unused here; kept for signature parity with validateInput.
 * @param indices Tensor-like whose innermost dim defines the slice rank.
 * @param shape The output tensor shape (number[]).
 * @returns {sliceRank, numUpdates, sliceSize, strides, outputSize}
 */
function calculateShapes(updates, indices, shape) {
    // Rank of each index vector: the innermost dim of `indices` when indices
    // are at least rank 2, otherwise 1 (each scalar index addresses one dim).
    const indicesRank = indices.shape.length;
    const sliceRank = indicesRank > 1 ? indices.shape[indicesRank - 1] : 1;
    // Elements per contiguous output slice — lets the kernels copy whole
    // slices through flattened views.
    let sliceSize = 1;
    for (let dim = sliceRank; dim < shape.length; ++dim) {
        sliceSize *= shape[dim];
    }
    const safeSliceDim = sliceRank < 1 ? 1 : sliceRank;
    const numUpdates = sizeFromShape(indices.shape) / safeSliceDim;
    const strides = [...computeStrides(shape.slice(0, sliceRank)), 1];
    const outputSize = sizeFromShape(shape);
    return { sliceRank, numUpdates, sliceSize, strides, outputSize };
}
19691
    // Frozen namespace object re-exporting the scatter_nd utility helpers;
    // mirrors the original ES module's named exports inside the UMD bundle.
    var scatter_nd_util = /*#__PURE__*/Object.freeze({
        __proto__: null,
        calculateShapes: calculateShapes,
        validateInput: validateInput$1,
        validateUpdateShape: validateUpdateShape
    });
19698
19699 /**
19700 * @license
19701 * Copyright 2022 Google LLC. All Rights Reserved.
19702 * Licensed under the Apache License, Version 2.0 (the "License");
19703 * you may not use this file except in compliance with the License.
19704 * You may obtain a copy of the License at
19705 *
19706 * http://www.apache.org/licenses/LICENSE-2.0
19707 *
19708 * Unless required by applicable law or agreed to in writing, software
19709 * distributed under the License is distributed on an "AS IS" BASIS,
19710 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19711 * See the License for the specific language governing permissions and
19712 * limitations under the License.
19713 * =============================================================================
19714 */
19715 /**
19716 * Creates a new tensor by applying sparse updates to individual
19717 * values or slices to the passed in tensor according to
     * indices. This operator is similar to the scatterNd op, except that the
     * updates are scattered on an existing tensor (as opposed to a zero-tensor).
19720 *
19721 * If indices contains duplicates, then we pick the last update for the index.
19722 *
19723 * If an out of bound index is found on CPU, an error is returned.
19724 *
19725 * Warning: There are some GPU specific semantics for this operation.
19726 * - If an out of bound index is found, the index is ignored.
19727 * - The order in which updates are applied is nondeterministic, so the output
19728 * will be nondeterministic if indices contains duplicates.
19729 * ```js
19730 * const shape = [8];
19731 * const tensor = tf.ones(shape);
19732 * const indices = tf.tensor2d([4, 3, 1, 7], [4, 1], 'int32');
19733 * const updates = tf.tensor1d([9, 10, 11, 12]);
19734 *
19735 * tf.tensorScatterUpdate(tensor, indices, updates).print();
19736 * //[1, 11, 1, 10, 9, 1, 1, 12]
19737 * ```
19738 *
19739 * @param tensor A Tensor. Tensor to copy/update.
19740 * @param indices The tensor contains the indices into the output tensor, must
19741 * have at least 2 axes: (num_updates, index_depth).
19742 * @param updates The tensor contains the value for the indices.
19743 *
19744 * @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
19745 */
19746 function tensorScatterUpdate_(tensor, indices, updates) {
19747 const $tensor = convertToTensor(tensor, 'tensor', 'tensorScatterupdate');
19748 const $indices = convertToTensor(indices, 'indices', 'tensorScatterupdate', 'int32');
19749 const $updates = convertToTensor(updates, 'updates', 'tensorScatterupdate');
19750 validateInput$1($updates, $indices, $tensor.shape);
19751 if ($tensor.dtype !== $updates.dtype) {
19752 throw new Error(`tensor and updates must have the same dtype, instead they are ${$tensor.dtype} and ${$updates.dtype}.`);
19753 }
19754 const inputs = {
19755 tensor: $tensor,
19756 indices: $indices,
19757 updates: $updates
19758 };
19759 const attrs = {};
19760 // tslint:disable-next-line: no-unnecessary-type-assertion
19761 return ENGINE.runKernel(TensorScatterUpdate, inputs, attrs);
19762 }
19763 const tensorScatterUpdate$2 = op({ tensorScatterUpdate_ });
19764
19765 /**
19766 * @license
19767 * Copyright 2018 Google LLC. All Rights Reserved.
19768 * Licensed under the Apache License, Version 2.0 (the "License");
19769 * you may not use this file except in compliance with the License.
19770 * You may obtain a copy of the License at
19771 *
19772 * http://www.apache.org/licenses/LICENSE-2.0
19773 *
19774 * Unless required by applicable law or agreed to in writing, software
19775 * distributed under the License is distributed on an "AS IS" BASIS,
19776 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19777 * See the License for the specific language governing permissions and
19778 * limitations under the License.
19779 * =============================================================================
19780 */
19781 /**
19782 * Finds the values and indices of the `k` largest entries along the last
19783 * dimension.
19784 *
19785 * If the input is a vector (rank=1), finds the k largest entries in the vector
19786 * and outputs their values and indices as vectors. Thus values[j] is the j-th
19787 * largest entry in input, and its index is indices[j].
19788 * For higher rank inputs, computes the top k entries along the last dimension.
19789 *
19790 * If two elements are equal, the lower-index element appears first.
19791 *
19792 * ```js
19793 * const a = tf.tensor2d([[1, 5], [4, 3]]);
19794 * const {values, indices} = tf.topk(a);
19795 * values.print();
19796 * indices.print();
19797 * ```
19798 * @param x 1-D or higher `tf.Tensor` with last dimension being at least `k`.
19799 * @param k Number of top elements to look for along the last dimension.
19800 * @param sorted If true, the resulting `k` elements will be sorted by the
19801 * values in descending order.
19802 *
19803 * @doc {heading: 'Operations', subheading: 'Evaluation'}
19804 */
19805 function topk_(x, k = 1, sorted = true) {
19806 const $x = convertToTensor(x, 'x', 'topk');
19807 if ($x.rank === 0) {
19808 throw new Error('topk() expects the input to be of rank 1 or higher');
19809 }
19810 const lastDim = $x.shape[$x.shape.length - 1];
19811 if (k < 0) {
19812 throw new Error(`'k' passed to topk() must be >= 0 but got ${k}`);
19813 }
19814 if (k > lastDim) {
19815 throw new Error(`'k' passed to topk() must be <= the last dimension (${lastDim}) ` +
19816 `but got ${k}`);
19817 }
19818 const inputs = { x: $x };
19819 const attrs = { k, sorted };
19820 const [values, indices] = ENGINE.runKernel(TopK, inputs, attrs);
19821 return { values, indices };
19822 }
19823 const topk = /* @__PURE__ */ op({ topk_ });
19824
19825 /**
19826 * @license
19827 * Copyright 2020 Google LLC. All Rights Reserved.
19828 * Licensed under the Apache License, Version 2.0 (the "License");
19829 * you may not use this file except in compliance with the License.
19830 * You may obtain a copy of the License at
19831 *
19832 * http://www.apache.org/licenses/LICENSE-2.0
19833 *
19834 * Unless required by applicable law or agreed to in writing, software
19835 * distributed under the License is distributed on an "AS IS" BASIS,
19836 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19837 * See the License for the specific language governing permissions and
19838 * limitations under the License.
19839 * =============================================================================
19840 */
19841 /**
19842 * Creates a `tf.Tensor` with values sampled from a truncated normal
19843 * distribution.
19844 *
19845 * ```js
19846 * tf.truncatedNormal([2, 2]).print();
19847 * ```
19848 *
19849 * The generated values follow a normal distribution with specified mean and
19850 * standard deviation, except that values whose magnitude is more than 2
19851 * standard deviations from the mean are dropped and re-picked.
19852 *
19853 * @param shape An array of integers defining the output tensor shape.
19854 * @param mean The mean of the normal distribution.
19855 * @param stdDev The standard deviation of the normal distribution.
19856 * @param dtype The data type of the output tensor.
19857 * @param seed The seed for the random number generator.
19858 *
19859 * @doc {heading: 'Tensors', subheading: 'Creation'}
19860 */
19861 function truncatedNormal_(shape, mean = 0, stdDev = 1, dtype, seed) {
19862 assertNonNegativeIntegerDimensions(shape);
19863 if (dtype != null && dtype === 'bool') {
19864 throw new Error(`Unsupported data type $ { dtype }`);
19865 }
19866 const randGauss = new MPRandGauss(mean, stdDev, dtype, true /* truncated */, seed);
19867 const res = buffer(shape, dtype);
19868 for (let i = 0; i < res.values.length; i++) {
19869 res.values[i] = randGauss.nextValue();
19870 }
19871 return res.toTensor();
19872 }
19873 const truncatedNormal$1 = /* @__PURE__ */ op({ truncatedNormal_ });
19874
19875 /**
19876 * @license
19877 * Copyright 2020 Google LLC. All Rights Reserved.
19878 * Licensed under the Apache License, Version 2.0 (the "License");
19879 * you may not use this file except in compliance with the License.
19880 * You may obtain a copy of the License at
19881 *
19882 * http://www.apache.org/licenses/LICENSE-2.0
19883 *
19884 * Unless required by applicable law or agreed to in writing, software
19885 * distributed under the License is distributed on an "AS IS" BASIS,
19886 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19887 * See the License for the specific language governing permissions and
19888 * limitations under the License.
19889 * =============================================================================
19890 */
19891 /**
19892 * Finds unique elements along an axis of a tensor.
19893 *
19894 * It returns a tensor `values` containing all of the unique elements along the
19895 * `axis` of the given tensor `x` in the same order that they occur along the
19896 * `axis` in `x`; `x` does not need to be sorted. It also returns a tensor
19897 * `indices` the same size as the number of the elements in `x` along the `axis`
19898 * dimension. It contains the index in the unique output `values`.
19899 *
19900 * ```js
19901 * // A 1-D tensor
19902 * const a = tf.tensor1d([1, 1, 2, 4, 4, 4, 7, 8, 8]);
19903 * const {values, indices} = tf.unique(a);
     * values.print(); // [1, 2, 4, 7, 8]
19905 * indices.print(); // [0, 0, 1, 2, 2, 2, 3, 4, 4]
19906 * ```
19907 *
19908 * ```js
19909 * // A 2-D tensor with axis=0
19910 * //
19911 * // 'a' is: [[1, 0, 0],
19912 * // [1, 0, 0],
19913 * // [2, 0, 0]]
19914 * const a = tf.tensor2d([[1, 0, 0], [1, 0, 0], [2, 0, 0]]);
19915 * const {values, indices} = tf.unique(a, 0)
19916 * values.print(); // [[1, 0, 0],
19917 * // [2, 0, 0]]
19918 * indices.print(); // [0, 0, 1]
19919 * ```
19920 *
19921 * ```js
19922 * // A 2-D tensor with axis=1
19923 * //
19924 * // 'a' is: [[1, 0, 0],
19925 * // [1, 0, 0],
19926 * // [2, 0, 0]]
19927 * const a = tf.tensor2d([[1, 0, 0], [1, 0, 0], [2, 0, 0]]);
19928 * const {values, indices} = tf.unique(a, 1)
19929 * values.print(); // [[1, 0],
19930 * // [1, 0],
19931 * // [2, 0]]
19932 * indices.print(); // [0, 1, 1]
19933 * ```
19934 * @param x A tensor (int32, string, bool).
19935 * @param axis The axis of the tensor to find the unique elements.
19936 * @returns [uniqueElements, indices] (see above for details)
19937 *
19938 * @doc {heading: 'Operations', subheading: 'Evaluation'}
19939 */
19940 function unique_(x, axis = 0) {
19941 const $x = convertToTensor(x, 'x', 'unique', 'string_or_numeric');
19942 assert$1($x.rank > 0, () => 'The input tensor must be at least 1D');
19943 const inputs = { x: $x };
19944 const attrs = { axis };
19945 const [values, indices] = ENGINE.runKernel(Unique, inputs, attrs);
19946 return { values, indices };
19947 }
19948 const unique$3 = /* @__PURE__ */ op({ unique_ });
19949
19950 /**
19951 * @license
19952 * Copyright 2020 Google LLC. All Rights Reserved.
19953 * Licensed under the Apache License, Version 2.0 (the "License");
19954 * you may not use this file except in compliance with the License.
19955 * You may obtain a copy of the License at
19956 *
19957 * http://www.apache.org/licenses/LICENSE-2.0
19958 *
19959 * Unless required by applicable law or agreed to in writing, software
19960 * distributed under the License is distributed on an "AS IS" BASIS,
19961 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19962 * See the License for the specific language governing permissions and
19963 * limitations under the License.
19964 * =============================================================================
19965 */
19966 /**
19967 * Computes the sum along segments of a `tf.Tensor`.
19968 *
19969 * ```js
19970 * const x = tf.tensor1d([1, 2, 3, 4]);
19971 * const segmentIds = tf.tensor1d([1, 2, 0, 1], 'int32');
19972 * const numSegments = 3;
19973 *
19974 * x.unsortedSegmentSum(segmentIds, numSegments).print()
19975 * //or tf.unsortedSegmentSum(x, segmentIds, numSegments)
19976 * ```
19977 * @param x The `tf.Tensor` that will be summed along its segments.
19978 * @param segmentIds A `tf.Tensor1D` whose rank is equal to the rank of `x`'s
19979 * dimension along the `axis`. Maps each element of `x` to a segment.
19980 * @param numSegments The number of distinct `segmentIds`.
19981 *
19982 * @doc {heading: 'Operations', subheading: 'Segment'}
19983 */
19984 function unsortedSegmentSum_(x, segmentIds, numSegments) {
19985 const $x = convertToTensor(x, 'x', 'unsortedSegmentSum');
19986 const $segmentIds = convertToTensor(segmentIds, 'segmentIds', 'unsortedSegmentSum', 'int32');
19987 assert$1(isInt(numSegments), () => 'numSegments must be of dtype int');
19988 const inputs = { x: $x, segmentIds: $segmentIds };
19989 const attrs = { numSegments };
19990 return ENGINE.runKernel(UnsortedSegmentSum, inputs, attrs);
19991 }
19992 const unsortedSegmentSum$2 = /* @__PURE__ */ op({ unsortedSegmentSum_ });
19993
19994 /**
19995 * @license
19996 * Copyright 2020 Google LLC. All Rights Reserved.
19997 * Licensed under the Apache License, Version 2.0 (the "License");
19998 * you may not use this file except in compliance with the License.
19999 * You may obtain a copy of the License at
20000 *
20001 * http://www.apache.org/licenses/LICENSE-2.0
20002 *
20003 * Unless required by applicable law or agreed to in writing, software
20004 * distributed under the License is distributed on an "AS IS" BASIS,
20005 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20006 * See the License for the specific language governing permissions and
20007 * limitations under the License.
20008 * =============================================================================
20009 */
20010 /**
20011 * Unstacks a `tf.Tensor` of rank-`R` into a list of rank-`(R-1)` `tf.Tensor`s.
20012 *
20013 * ```js
20014 * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
20015 *
20016 * tf.unstack(a).forEach(tensor => tensor.print());
20017 * ```
20018 *
20019 * @param x A tensor object.
20020 * @param axis The axis to unstack along. Defaults to 0 (the first dim).
20021 *
20022 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
20023 */
20024 function unstack_(x, axis = 0) {
20025 const $x = convertToTensor(x, 'x', 'unstack', 'string_or_numeric');
20026 assert$1(axis >= -$x.shape.length && axis < $x.shape.length, () => `Axis = ${axis} is not in [-${$x.shape.length}, ${$x.shape.length})`);
20027 const inputs = { value: $x };
20028 const attrs = { axis };
20029 return ENGINE.runKernel(Unpack, inputs, attrs);
20030 }
20031 const unstack = /* @__PURE__ */ op({ unstack_ });
20032
20033 /**
20034 * @license
20035 * Copyright 2022 Google LLC. All Rights Reserved.
20036 * Licensed under the Apache License, Version 2.0 (the "License");
20037 * you may not use this file except in compliance with the License.
20038 * You may obtain a copy of the License at
20039 *
20040 * http://www.apache.org/licenses/LICENSE-2.0
20041 *
20042 * Unless required by applicable law or agreed to in writing, software
20043 * distributed under the License is distributed on an "AS IS" BASIS,
20044 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20045 * See the License for the specific language governing permissions and
20046 * limitations under the License.
20047 * =============================================================================
20048 */
20049 /**
20050 * Searches for where a value would go in a sorted sequence.
20051 *
20052 * This is not a method for checking containment (like javascript in).
20053 *
20054 * The typical use case for this operation is "binning", "bucketing", or
20055 * "discretizing". The values are assigned to bucket-indices based on the edges
20056 * listed in 'sortedSequence'. This operation returns the bucket-index for each
20057 * value.
20058 *
20059 * The index returned corresponds to the first edge greater than the value.
20060 *
20061 * The axis is not settable for this operation. It always operates on the
20062 * innermost dimension (axis=-1). The operation will accept any number of outer
20063 * dimensions.
20064 *
20065 * Note: This operation assumes that 'upperBound' is sorted along the
20066 * innermost axis, maybe using 'sort(..., axis=-1)'. If the sequence is not
20067 * sorted no error is raised and the content of the returned tensor is not well
20068 * defined.
20069 *
20070 * ```js
20071 * const seq = tf.tensor1d([0, 3, 9, 10, 10]);
20072 * const values = tf.tensor1d([0, 4, 10]);
20073 * const result = tf.upperBound(seq, values);
20074 * result.print(); // [1, 2, 5]
20075 * ```
20076 * @param sortedSequence: N-D. Sorted sequence.
20077 * @param values: N-D. Search values.
20078 * @return An N-D int32 tensor the size of values containing the result of
20079 * applying upper bound to each value. The result is not a global index to
20080 * the entire Tensor, but the index in the last dimension.
20081 * @doc {heading: 'Operations', subheading: 'Evaluation'}
20082 */
    // upperBound is searchSorted with side='right': for each value it returns
    // the index of the first element in sortedSequence strictly greater than
    // the value (along the innermost axis).
    function upperBound$1(sortedSequence, values) {
        return searchSorted$2(sortedSequence, values, 'right');
    }
20086
20087 /**
20088 * @license
20089 * Copyright 2018 Google LLC. All Rights Reserved.
20090 * Licensed under the Apache License, Version 2.0 (the "License");
20091 * you may not use this file except in compliance with the License.
20092 * You may obtain a copy of the License at
20093 *
20094 * http://www.apache.org/licenses/LICENSE-2.0
20095 *
20096 * Unless required by applicable law or agreed to in writing, software
20097 * distributed under the License is distributed on an "AS IS" BASIS,
20098 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20099 * See the License for the specific language governing permissions and
20100 * limitations under the License.
20101 * =============================================================================
20102 */
20103 /**
20104 * Creates a new variable with the provided initial value.
20105 * ```js
20106 * const x = tf.variable(tf.tensor([1, 2, 3]));
20107 * x.assign(tf.tensor([4, 5, 6]));
20108 *
20109 * x.print();
20110 * ```
20111 *
20112 * @param initialValue Initial value for the tensor.
20113 * @param trainable If true, optimizers are allowed to update it.
20114 * @param name Name of the variable. Defaults to a unique id.
20115 * @param dtype If set, initialValue will be converted to the given type.
20116 *
20117 * @doc {heading: 'Tensors', subheading: 'Creation'}
20118 */
    // Creates an engine-tracked Variable from initialValue. `trainable`
    // controls whether optimizers may update it; `name` defaults to a unique
    // id and `dtype`, if given, converts the initial value.
    function variable$1(initialValue, trainable = true, name, dtype) {
        return ENGINE.makeVariable(initialValue, trainable, name, dtype);
    }
20122
20123 /**
20124 * @license
20125 * Copyright 2018 Google LLC. All Rights Reserved.
20126 * Licensed under the Apache License, Version 2.0 (the "License");
20127 * you may not use this file except in compliance with the License.
20128 * You may obtain a copy of the License at
20129 *
20130 * http://www.apache.org/licenses/LICENSE-2.0
20131 *
20132 * Unless required by applicable law or agreed to in writing, software
20133 * distributed under the License is distributed on an "AS IS" BASIS,
20134 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20135 * See the License for the specific language governing permissions and
20136 * limitations under the License.
20137 * =============================================================================
20138 */
20139 function whereImpl$2(condShape, condVals) {
20140 const indices = [];
20141 for (let i = 0; i < condVals.length; i++) {
20142 if (condVals[i]) {
20143 indices.push(i);
20144 }
20145 }
20146 const inBuffer = buffer(condShape, 'int32');
20147 const out = buffer([indices.length, condShape.length], 'int32');
20148 for (let i = 0; i < indices.length; i++) {
20149 const loc = inBuffer.indexToLoc(indices[i]);
20150 const offset = i * condShape.length;
20151 out.values.set(loc, offset);
20152 }
20153 return out.toTensor();
20154 }
20155
20156 /**
20157 * @license
20158 * Copyright 2020 Google LLC. All Rights Reserved.
20159 * Licensed under the Apache License, Version 2.0 (the "License");
20160 * you may not use this file except in compliance with the License.
20161 * You may obtain a copy of the License at
20162 *
20163 * http://www.apache.org/licenses/LICENSE-2.0
20164 *
20165 * Unless required by applicable law or agreed to in writing, software
20166 * distributed under the License is distributed on an "AS IS" BASIS,
20167 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20168 * See the License for the specific language governing permissions and
20169 * limitations under the License.
20170 * =============================================================================
20171 */
20172 /**
20173 * Returns the coordinates of true elements of condition.
20174 *
20175 * The coordinates are returned in a 2-D tensor where the first dimension (rows)
20176 * represents the number of true elements, and the second dimension (columns)
20177 * represents the coordinates of the true elements. Keep in mind, the shape of
20178 * the output tensor can vary depending on how many true values there are in
20179 * input. Indices are output in row-major order. The resulting tensor has the
20180 * shape `[numTrueElems, condition.rank]`.
20181 *
20182 * This is analogous to calling the python `tf.where(cond)` without an x or y.
20183 *
20184 * ```js
20185 * const cond = tf.tensor1d([false, false, true], 'bool');
20186 * const result = await tf.whereAsync(cond);
20187 * result.print();
20188 * ```
20189 *
20190 * @doc {heading: 'Operations', subheading: 'Logical'}
20191 */
20192 async function whereAsync_(condition) {
20193 const $condition = convertToTensor(condition, 'condition', 'whereAsync', 'bool');
20194 const vals = await $condition.data();
20195 const res = whereImpl$2($condition.shape, vals);
20196 if (condition !== $condition) {
20197 $condition.dispose();
20198 }
20199 return res;
20200 }
20201 const whereAsync = whereAsync_;
20202
20203 /**
20204 * @license
20205 * Copyright 2018 Google LLC. All Rights Reserved.
20206 * Licensed under the Apache License, Version 2.0 (the "License");
20207 * you may not use this file except in compliance with the License.
20208 * You may obtain a copy of the License at
20209 *
20210 * http://www.apache.org/licenses/LICENSE-2.0
20211 *
20212 * Unless required by applicable law or agreed to in writing, software
20213 * distributed under the License is distributed on an "AS IS" BASIS,
20214 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20215 * See the License for the specific language governing permissions and
20216 * limitations under the License.
20217 * =============================================================================
20218 */
20219 /**
20220 * Apply boolean mask to tensor.
20221 *
20222 * ```js
20223 * const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
20224 * const mask = tf.tensor1d([1, 0, 1], 'bool');
20225 * const result = await tf.booleanMaskAsync(tensor, mask);
20226 * result.print();
20227 * ```
20228 *
20229 * @param tensor N-D tensor.
20230 * @param mask K-D boolean tensor, K <= N and K must be known statically.
20231 * @param axis A 0-D int Tensor representing the axis in tensor to mask from.
20232 * By default, axis is 0 which will mask from the first dimension.
20233 * Otherwise K + axis <= N.
20234 *
20235 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
20236 */
    // Applies a K-D boolean mask to an N-D tensor starting at `axis` (see the
    // doc comment above). Strategy: collapse the K masked dims into one,
    // find the positions where the flattened mask is true, then gather them.
    // Kept byte-identical; only comments added — the dispose ordering at the
    // end is deliberate and must not be disturbed.
    async function booleanMaskAsync_(tensor, mask, axis) {
        const $tensor = convertToTensor(tensor, 'tensor', 'boolMask');
        const $mask = convertToTensor(mask, 'mask', 'boolMask', 'bool');
        const axisFrom = axis == null ? 0 : axis;
        const maskDim = $mask.rank;
        const tensorShape = $tensor.shape;
        assert$1(maskDim > 0, () => 'mask cannot be scalar');
        // The mask must exactly cover tensor dims [axisFrom, axisFrom + maskDim).
        assertShapesMatch(tensorShape.slice(axisFrom, axisFrom + maskDim), $mask.shape, `mask's shape must match the first K dimensions of tensor's shape,`);
        // Total element count of the masked dims, which get collapsed into one.
        let leadingSize = 1;
        for (let i = axisFrom; i < axisFrom + maskDim; i++) {
            leadingSize *= tensorShape[i];
        }
        const targetTensorShape = tensorShape.slice(0, axisFrom)
            .concat([leadingSize], tensorShape.slice(axisFrom + maskDim));
        const reshapedTensor = reshape$3($tensor, targetTensorShape);
        const reshapedMask = reshape$3($mask, [-1]);
        // whereAsync yields a [numTrue, 1] tensor; squeeze to a 1-D index list.
        const positivePositions = await whereAsync(reshapedMask);
        const indices = squeeze(positivePositions, [1]);
        const res = gather$1(reshapedTensor, indices, axisFrom);
        // Ensure no memory leak.
        if (tensor !== $tensor) {
            $tensor.dispose();
        }
        if (mask !== $mask) {
            $mask.dispose();
        }
        indices.dispose();
        reshapedTensor.dispose();
        reshapedMask.dispose();
        positivePositions.dispose();
        return res;
    }
    const booleanMaskAsync = booleanMaskAsync_;
20270
20271 /**
20272 * @license
20273 * Copyright 2018 Google LLC. All Rights Reserved.
20274 * Licensed under the Apache License, Version 2.0 (the "License");
20275 * you may not use this file except in compliance with the License.
20276 * You may obtain a copy of the License at
20277 *
20278 * http://www.apache.org/licenses/LICENSE-2.0
20279 *
20280 * Unless required by applicable law or agreed to in writing, software
20281 * distributed under the License is distributed on an "AS IS" BASIS,
20282 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20283 * See the License for the specific language governing permissions and
20284 * limitations under the License.
20285 * =============================================================================
20286 */
20287 /**
20288 * Transposes the `tf.Tensor`. Permutes the dimensions according to `perm`.
20289 *
20290 * The returned `tf.Tensor`'s dimension `i` will correspond to the input
20291 * dimension `perm[i]`. If `perm` is not given, it is set to `[n-1...0]`,
20292 * where `n` is the rank of the input `tf.Tensor`. Hence by default, this
20293 * operation performs a regular matrix transpose on 2-D input `tf.Tensor`s.
20294 *
20295 * ```js
20296 * const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
20297 *
20298 * a.transpose().print(); // or tf.transpose(a)
20299 * ```
20300 *
20301 * @param x The tensor to transpose.
20302 * @param perm The permutation of the dimensions of a.
20303 * @param conjugate Will conjugate complex input if true.
20304 *
20305 * @doc {heading: 'Operations', subheading: 'Matrices'}
20306 */
20307 function transpose_(x, perm, conjugate) {
20308 const $x = convertToTensor(x, 'x', 'transpose');
20309 if (perm == null) {
20310 perm = $x.shape.map((s, i) => i).reverse();
20311 }
20312 assert$1($x.rank === perm.length, () => `Error in transpose: rank of input ${$x.rank} ` +
20313 `must match length of perm ${perm}.`);
20314 perm.forEach(axis => {
20315 assert$1(axis >= 0 && axis < $x.rank, () => `All entries in 'perm' must be between 0 and ${$x.rank - 1}` +
20316 ` but got ${perm}`);
20317 });
20318 if ($x.rank <= 1) {
20319 return $x.clone();
20320 }
20321 const inputs = { x: $x };
20322 const attrs = { perm };
20323 if ($x.dtype === 'complex64') {
20324 return tidy(() => {
20325 let $real = real$2($x);
20326 let $imag = imag$2($x);
20327 $real = ENGINE.runKernel(Transpose, { x: $real }, attrs);
20328 $imag = ENGINE.runKernel(Transpose, { x: $imag }, attrs);
20329 if (conjugate) {
20330 $imag = neg$2($imag);
20331 }
20332 return complex$2($real, $imag);
20333 });
20334 }
20335 return ENGINE.runKernel(Transpose, inputs, attrs);
20336 }
20337 const transpose$2 = /* @__PURE__ */ op({ transpose_ });
20338
20339 /**
20340 * @license
20341 * Copyright 2018 Google LLC. All Rights Reserved.
20342 * Licensed under the Apache License, Version 2.0 (the "License");
20343 * you may not use this file except in compliance with the License.
20344 * You may obtain a copy of the License at
20345 *
20346 * http://www.apache.org/licenses/LICENSE-2.0
20347 *
20348 * Unless required by applicable law or agreed to in writing, software
20349 * distributed under the License is distributed on an "AS IS" BASIS,
20350 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20351 * See the License for the specific language governing permissions and
20352 * limitations under the License.
20353 * =============================================================================
20354 */
20355 /**
20356 * Compute the moving average of a variable.
20357 *
20358 * Without zeroDebias, the moving average operation is defined by:
20359 * `v += delta`
20360 * where
20361 * `delta = (1 - decay) * (x - v)`
20362 *
20363 * With zeroDebias (default), the `delta` term is scaled to debias the
20364 * effect of the (assumed) zero-initialization of `v`.
20365 * `delta /= (1 - decay ^ step)`
20366 *
20367 * For more details on the zero-debiasing algorithm, see:
20368 * https://arxiv.org/abs/1412.6980
20369 *
20370 * Note that this function is completely stateless and does not keep track of
20371 * step count. The step count needs to be maintained by the caller and passed
20372 * in as `step`.
20373 *
20374 * @param v The current moving average value.
20375 * @param x New input value, must have the same shape and dtype as `v`.
20376 * @param decay The decay factor. Typical values are 0.95 and 0.99.
20377 * @param step Step count.
20378 * @param zeroDebias: Whether zeroDebias is to be performed (default: `true`).
20379 * @returns The new moving average value.
20380 *
20381 * @doc {heading: 'Operations', subheading: 'Moving Average'}
20382 */
20383 function movingAverage_(v, x, decay, step, zeroDebias = true) {
20384 const $v = convertToTensor(v, 'v', 'movingAverage');
20385 const $x = convertToTensor(x, 'x', 'movingAverage');
20386 const $decay = convertToTensor(decay, 'decay', 'movingAverage');
20387 assertTypesMatch($v, $x);
20388 assert$1(arraysEqual($v.shape, $x.shape), () => 'Shape mismatch in v and x');
20389 const one = scalar(1);
20390 const oneMinusDecay = sub$2(one, $decay);
20391 let update = mul(sub$2($x, $v), oneMinusDecay);
20392 if (zeroDebias) {
20393 assert$1(step != null, () => 'When using zeroDebias: true, step is required.');
20394 const $step = convertToTensor(step, 'step', 'movingAverage');
20395 update = div$1(update, sub$2(one, pow$3($decay, $step)));
20396 }
20397 return add$3($v, update);
20398 }
20399 const movingAverage = /* @__PURE__ */ op({ movingAverage_ });
20400
20401 /**
20402 * @license
20403 * Copyright 2018 Google LLC. All Rights Reserved.
20404 * Licensed under the Apache License, Version 2.0 (the "License");
20405 * you may not use this file except in compliance with the License.
20406 * You may obtain a copy of the License at
20407 *
20408 * http://www.apache.org/licenses/LICENSE-2.0
20409 *
20410 * Unless required by applicable law or agreed to in writing, software
20411 * distributed under the License is distributed on an "AS IS" BASIS,
20412 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20413 * See the License for the specific language governing permissions and
20414 * limitations under the License.
20415 * =============================================================================
20416 */
20417 /**
20418 * Creates a new tensor by applying sparse updates to individual
20419 * values or slices within a zero tensor of the given shape tensor according to
20420 * indices. This operator is the inverse of the `tf.gatherND` operator which
20421 * extracts values or slices from a given tensor.
20422 *
20423 * ```js
20424 * const indices = tf.tensor2d([4, 3, 1, 7], [4, 1], 'int32');
20425 * const updates = tf.tensor1d([9, 10, 11, 12]);
20426 * const shape = [8];
20427 * tf.scatterND(indices, updates, shape).print() //[0, 11, 0, 10, 9, 0, 0, 12]
20428 * ```
20429 *
20430 * @param indices The tensor contains the indices into the output tensor.
20431 * @param updates The tensor contains the value for the indices.
20432 * @param shape: The shape of the output tensor.
20433 *
20434 * @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
20435 */
// Scatters `updates` into a zero tensor of the given `shape` at `indices`
// (the inverse of gatherND). Validation happens before the kernel call.
function scatterND_(indices, updates, shape) {
    assertNonNegativeIntegerDimensions(shape);
    const $indices = convertToTensor(indices, 'indices', 'scatterND', 'int32');
    const $updates = convertToTensor(updates, 'updates', 'scatterND');
    validateInput$1($updates, $indices, shape);
    const attrs = { shape };
    const inputs = { indices: $indices, updates: $updates };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    return ENGINE.runKernel(ScatterNd, inputs, attrs);
}
const scatterND = /* @__PURE__ */ op({ scatterND_ });
20447
/**
 * Validate sparseToDense inputs.
 *
 * @param sparseIndices A 0-D, 1-D, or 2-D Tensor of type int32.
 * sparseIndices[i] contains the complete index where sparseValues[i] will be
 * placed.
 * @param sparseValues A 0-D or 1-D Tensor. Values
 * corresponding to each row of sparseIndices, or a scalar value to be used for
 * all sparse indices.
 * @param outputShape number[]. Shape of the dense output tensor.
 * @param defaultValues Scalar tensor whose dtype must match
 * `sparseValues.dtype`; used for indices not listed in `sparseIndices`.
 * @throws Error when any of the shape/rank/dtype constraints is violated.
 */
function validateInput(sparseIndices, sparseValues, outputShape, defaultValues) {
    if (sparseIndices.dtype !== 'int32') {
        throw new Error('tf.sparseToDense() expects the indices to be int32 type,' +
            ` but the dtype was ${sparseIndices.dtype}.`);
    }
    if (sparseIndices.rank > 2) {
        throw new Error('sparseIndices should be a scalar, vector, or matrix,' +
            ` but got shape ${sparseIndices.shape}.`);
    }
    // A scalar index addresses one element; a vector holds n scalar indices;
    // an n x d matrix holds n indices of d dimensions each.
    const numElems = sparseIndices.rank > 0 ? sparseIndices.shape[0] : 1;
    const numDims = sparseIndices.rank > 1 ? sparseIndices.shape[1] : 1;
    if (outputShape.length !== numDims) {
        // Fixed: original message had a stray comma ("elements:,").
        throw new Error('outputShape has incorrect number of elements:' +
            ` ${outputShape.length}, should be: ${numDims}.`);
    }
    const numValues = sparseValues.size;
    if (!(sparseValues.rank === 0 ||
        sparseValues.rank === 1 && numValues === numElems)) {
        throw new Error('sparseValues has incorrect shape ' +
            `${sparseValues.shape}, should be [] or [${numElems}]`);
    }
    if (sparseValues.dtype !== defaultValues.dtype) {
        throw new Error('sparseValues.dtype must match defaultValues.dtype');
    }
}
20486
20487 /**
20488 * @license
20489 * Copyright 2018 Google LLC. All Rights Reserved.
20490 * Licensed under the Apache License, Version 2.0 (the "License");
20491 * you may not use this file except in compliance with the License.
20492 * You may obtain a copy of the License at
20493 *
20494 * http://www.apache.org/licenses/LICENSE-2.0
20495 *
20496 * Unless required by applicable law or agreed to in writing, software
20497 * distributed under the License is distributed on an "AS IS" BASIS,
20498 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20499 * See the License for the specific language governing permissions and
20500 * limitations under the License.
20501 * =============================================================================
20502 */
20503 /**
20504 * Converts a sparse representation into a dense tensor.
20505 *
20506 * Builds an array dense with shape outputShape such that:
20507 *
20508 * // If sparseIndices is scalar
20509 * dense[i] = (i == sparseIndices ? sparseValues : defaultValue)
20510 *
20511 * // If sparseIndices is a vector, then for each i
20512 * dense[sparseIndices[i]] = sparseValues[i]
20513 *
20514 * // If sparseIndices is an n by d matrix, then for each i in [0, n)
20515 * dense[sparseIndices[i][0], ..., sparseIndices[i][d-1]] = sparseValues[i]
20516 * All other values in dense are set to defaultValue. If sparseValues is a
20517 * scalar, all sparse indices are set to this single value.
20518 *
20519 * If indices are repeated the final value is summed over all values for those
20520 * indices.
20521 *
20522 * ```js
20523 * const indices = tf.tensor1d([4, 5, 6, 1, 2, 3], 'int32');
20524 * const values = tf.tensor1d([10, 11, 12, 13, 14, 15], 'float32');
20525 * const shape = [8];
20526 * tf.sparseToDense(indices, values, shape).print();
20527 * ```
20528 *
20529 * @param sparseIndices A 0-D, 1-D, or 2-D Tensor of type int32.
20530 * sparseIndices[i] contains the complete index where sparseValues[i] will be
20531 * placed.
20532 * @param sparseValues A 0-D or 1-D Tensor. Values
20533 * corresponding to each row of sparseIndices, or a scalar value to be used for
20534 * all sparse indices.
20535 * @param outputShape Shape of the dense output tensor. The type is inferred.
20536 * @param defaultValue Scalar. Value to set for indices not specified in
20537 * sparseIndices. Defaults to zero.
20538 *
20539 * @doc {heading: 'Operations', subheading: 'Normalization'}
20540 */
// Builds a dense tensor of `outputShape` by placing `sparseValues` at
// `sparseIndices`, filling everything else with `defaultValue`.
function sparseToDense_(sparseIndices, sparseValues, outputShape, defaultValue = 0) {
    assertNonNegativeIntegerDimensions(outputShape);
    const $sparseIndices = convertToTensor(sparseIndices, 'sparseIndices', 'sparseToDense', 'int32');
    const $sparseValues = convertToTensor(sparseValues, 'sparseValues', 'sparseToDense', 'string_or_numeric');
    // The default value is coerced to the sparse values' dtype.
    const $defaultValue = convertToTensor(defaultValue, 'defaultValue', 'sparseToDense', $sparseValues.dtype);
    validateInput($sparseIndices, $sparseValues, outputShape, $defaultValue);
    const attrs = { outputShape };
    const inputs = {
        sparseIndices: $sparseIndices,
        sparseValues: $sparseValues,
        defaultValue: $defaultValue
    };
    return ENGINE.runKernel(SparseToDense, inputs, attrs);
}
const sparseToDense$2 = /* @__PURE__ */ op({ sparseToDense_ });
20556
20557 /**
20558 * @license
20559 * Copyright 2018 Google LLC. All Rights Reserved.
20560 * Licensed under the Apache License, Version 2.0 (the "License");
20561 * you may not use this file except in compliance with the License.
20562 * You may obtain a copy of the License at
20563 *
20564 * http://www.apache.org/licenses/LICENSE-2.0
20565 *
20566 * Unless required by applicable law or agreed to in writing, software
20567 * distributed under the License is distributed on an "AS IS" BASIS,
20568 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20569 * See the License for the specific language governing permissions and
20570 * limitations under the License.
20571 * =============================================================================
20572 */
20573 /**
20574 * Gather slices from input tensor into a Tensor with shape specified by
20575 * `indices`.
20576 *
20577 * `indices` is a K-dimensional integer tensor, best thought of as a
20578 * (K-1)-dimensional tensor of indices into input, where each element defines a
20579 * slice of input:
20580 * output[\\(i_0, ..., i_{K-2}\\)] = input[indices[\\(i_0, ..., i_{K-2}\\)]]
20581 *
20582 * Whereas in `tf.gather`, `indices` defines slices into the first dimension of
20583 * input, in `tf.gatherND`, `indices` defines slices into the first N dimensions
20584 * of input, where N = indices.shape[-1].
20585 *
20586 * The last dimension of indices can be at most the rank of input:
20587 * indices.shape[-1] <= input.rank
20588 *
20589 * The last dimension of `indices` corresponds to elements
20590 * (if indices.shape[-1] == input.rank) or slices
20591 * (if indices.shape[-1] < input.rank) along dimension indices.shape[-1] of
20592 * input.
20593 * The output tensor has shape
20594 * indices.shape[:-1] + input.shape[indices.shape[-1]:]
20595 *
20596 * Note that on CPU, if an out of bound index is found, an error is returned. On
20597 * GPU, if an out of bound index is found, a 0 is stored in the corresponding
20598 * output value.
20599 *
20600 * ```js
20601 * const indices = tf.tensor2d([0, 1, 1, 0], [2,2], 'int32');
20602 * const input = tf.tensor2d([9, 10, 11, 12], [2, 2]);
20603 * tf.gatherND(input, indices).print() // [10, 11]
20604 * ```
20605 *
20606 * @param x The tensor from which to gather values.
20607 * @param indices Index tensor, must be of type int32.
20608 *
20609 * @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
20610 */
// Gathers slices of `x` addressed by the last axis of `indices`
// (inverse of scatterND).
function gatherND_(x, indices) {
    const $indices = convertToTensor(indices, 'indices', 'gatherND', 'int32');
    const $x = convertToTensor(x, 'x', 'gatherND', 'string_or_numeric');
    return ENGINE.runKernel(GatherNd, { params: $x, indices: $indices });
}
const gatherND = /* @__PURE__ */ op({ gatherND_ });
20618
20619 /**
20620 * @license
20621 * Copyright 2019 Google LLC. All Rights Reserved.
20622 * Licensed under the Apache License, Version 2.0 (the "License");
20623 * you may not use this file except in compliance with the License.
20624 * You may obtain a copy of the License at
20625 *
20626 * http://www.apache.org/licenses/LICENSE-2.0
20627 *
20628 * Unless required by applicable law or agreed to in writing, software
20629 * distributed under the License is distributed on an "AS IS" BASIS,
20630 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20631 * See the License for the specific language governing permissions and
20632 * limitations under the License.
20633 * =============================================================================
20634 */
20635 /**
20636 * Normalize noise shape based on provided tensor and noise shape.
20637 *
20638 * @param x Tensor.
20639 * @param noiseShape The shape for the randomly generated keep/drop flags, as
20640 * an array of numbers. Optional.
20641 * @returns Normalized noise shape.
20642 */
function getNoiseShape(x, noiseShape) {
    // No noise shape given: the flags cover every element of x.
    if (noiseShape == null) {
        return x.shape.slice();
    }
    if (arraysEqual(x.shape, noiseShape)) {
        return noiseShape;
    }
    if (x.shape.length === noiseShape.length) {
        // Same rank: fill in each null entry of noiseShape from x's shape.
        return noiseShape.map((dim, i) =>
            dim == null && x.shape[i] != null ? x.shape[i] : dim);
    }
    // Ranks differ: pass the caller's shape through unchanged.
    return noiseShape;
}
20664
20665 /**
20666 * @license
20667 * Copyright 2018 Google LLC. All Rights Reserved.
20668 * Licensed under the Apache License, Version 2.0 (the "License");
20669 * you may not use this file except in compliance with the License.
20670 * You may obtain a copy of the License at
20671 *
20672 * http://www.apache.org/licenses/LICENSE-2.0
20673 *
20674 * Unless required by applicable law or agreed to in writing, software
20675 * distributed under the License is distributed on an "AS IS" BASIS,
20676 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20677 * See the License for the specific language governing permissions and
20678 * limitations under the License.
20679 * =============================================================================
20680 */
20681 /**
20682 * Computes dropout.
20683 *
20684 * ```js
20685 * const x = tf.tensor1d([1, 2, 2, 1]);
20686 * const rate = 0.75;
20687 * const output = tf.dropout(x, rate);
20688 * output.print();
20689 * ```
20690 *
20691 * @param x A floating point Tensor or TensorLike.
20692 * @param rate A float in the range [0, 1). The probability that each element
20693 * of x is discarded.
20694 * @param noiseShape An array of numbers of type int32, representing the
20695 * shape for randomly generated keep/drop flags. If the noiseShape has null
20696 * value, it will be automatically replaced with the x's relative dimension
20697 * size. Optional.
20698 * @param seed Used to create random seeds. Optional.
20699 * @returns A Tensor of the same shape of x.
20700 *
20701 * @doc {heading: 'Operations', subheading: 'Dropout'}
20702 */
// Computes dropout: zeroes each element with probability `rate` and scales
// survivors by 1/(1-rate) so the expected value is preserved.
function dropout_(x, rate, noiseShape, seed) {
    const $x = convertToTensor(x, 'x', 'dropout');
    assert$1($x.dtype === 'float32', () => `x has to be a floating point tensor since it's going to be ` +
        `scaled, but got a ${$x.dtype} tensor instead.`);
    assert$1(rate >= 0 && rate < 1, () => `rate must be a float in the range [0, 1), but got ${rate}.`);
    if (rate === 0) {
        // Nothing to drop; clone only if the caller handed us a Tensor.
        return x instanceof Tensor ? $x.clone() : $x;
    }
    const $noiseShape = getNoiseShape($x, noiseShape);
    const keepProb = 1 - rate;
    // floor(uniform[0,1) + keepProb) is 1 with probability keepProb, else 0.
    const mask = floor$2(add$3(randomUniform$1($noiseShape, 0, 1, 'float32', seed), keepProb));
    return mul($x, div$1(mask, keepProb));
}
const dropout$2 = /* @__PURE__ */ op({ dropout_ });
20717
20718 /**
20719 * @license
20720 * Copyright 2019 Google LLC. All Rights Reserved.
20721 * Licensed under the Apache License, Version 2.0 (the "License");
20722 * you may not use this file except in compliance with the License.
20723 * You may obtain a copy of the License at
20724 *
20725 * http://www.apache.org/licenses/LICENSE-2.0
20726 *
20727 * Unless required by applicable law or agreed to in writing, software
20728 * distributed under the License is distributed on an "AS IS" BASIS,
20729 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20730 * See the License for the specific language governing permissions and
20731 * limitations under the License.
20732 * =============================================================================
20733 */
// Smallest 2**N (integer N) with 2**N >= value, via a base-2 change of
// logarithm. Math.floor guards against floating-point residue in Math.pow.
function enclosingPowerOfTwo(value) {
    const exponent = Math.ceil(Math.log(value) / Math.log(2.0));
    return Math.floor(Math.pow(2, exponent));
}
// Generalized cosine window of length N: w[i] = a - b*cos(2*pi*i/D) where
// D = N for even N and N-1 for odd N, returned as a float32 tensor1d.
function cosineWindow(windowLength, a, b) {
    const even = 1 - windowLength % 2;
    const denominator = windowLength + even - 1;
    const values = new Float32Array(windowLength);
    for (let i = 0; i < windowLength; ++i) {
        values[i] = a - b * Math.cos((2.0 * Math.PI * i) / denominator);
    }
    return tensor1d(values, 'float32');
}
20747
20748 /**
20749 * @license
20750 * Copyright 2019 Google LLC. All Rights Reserved.
20751 * Licensed under the Apache License, Version 2.0 (the "License");
20752 * you may not use this file except in compliance with the License.
20753 * You may obtain a copy of the License at
20754 *
20755 * http://www.apache.org/licenses/LICENSE-2.0
20756 *
20757 * Unless required by applicable law or agreed to in writing, software
20758 * distributed under the License is distributed on an "AS IS" BASIS,
20759 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20760 * See the License for the specific language governing permissions and
20761 * limitations under the License.
20762 * =============================================================================
20763 */
20764 /**
20765 * Returns whether the targets are in the top K predictions.
20766 *
20767 * ```js
20768 * const predictions = tf.tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]);
20769 * const targets = tf.tensor1d([2, 0]);
20770 * const precision = await tf.inTopKAsync(predictions, targets);
20771 * precision.print();
20772 * ```
20773 * @param predictions 2-D or higher `tf.Tensor` with last dimension being
20774 * at least `k`.
20775 * @param targets 1-D or higher `tf.Tensor`.
20776 * @param k Optional Number of top elements to look at for computing precision,
20777 * default to 1.
20778 *
20779 * @doc {heading: 'Operations', subheading: 'Evaluation'}
20780 */
// Async: for each target index, reports whether it appears among the top `k`
// predictions along the last dimension. Resolves to a bool tensor shaped
// like `targets`.
async function inTopKAsync_(predictions, targets, k = 1) {
    const $predictions = convertToTensor(predictions, 'predictions', 'inTopK');
    const $targets = convertToTensor(targets, 'targets', 'inTopK');
    assert$1($predictions.rank > 1, () => 'inTopK() expects the predictions to be of rank 2 or higher, ' +
        `but got ${$predictions.rank}`);
    assert$1($predictions.rank - 1 === $targets.rank, () => `predictions rank should be 1 larger than ` +
        `targets rank, but got predictions rank ` +
        `${$predictions.rank} and targets rank ${$targets.rank}`);
    assertShapesMatch($predictions.shape.slice(0, $predictions.shape.length - 1), $targets.shape, `predictions's shape should be align with the targets' shape, ` +
        'except the last dimension.');
    const lastDim = $predictions.shape[$predictions.shape.length - 1];
    assert$1(k > 0 && k <= lastDim, () => `'k' passed to inTopK() must be > 0 && <= the predictions last ` +
        `dimension (${lastDim}), but got ${k}`);
    const predictionsVals = await $predictions.data();
    const targetsVals = await $targets.data();
    // View the flat prediction values as [batch, lastDim] rows and scan the
    // top k entries of each row for the target index.
    const batch = predictionsVals.length / lastDim;
    const precision = getTypedArrayFromDType('bool', batch);
    for (let row = 0; row < batch; row++) {
        const start = row * lastDim;
        const rowVals = predictionsVals.subarray(start, start + lastDim);
        // Pair each value with its index, then order by descending value.
        // Array#sort is stable, so ties keep the lower index first — same
        // tie-breaking as a full topK.
        const ranked = Array.from(rowVals, (value, index) => ({ value, index }));
        ranked.sort((lhs, rhs) => rhs.value - lhs.value);
        precision[row] = 0;
        for (let i = 0; i < k; i++) {
            if (ranked[i].index === targetsVals[row]) {
                precision[row] = 1;
                break;
            }
        }
    }
    if (predictions !== $predictions) {
        $predictions.dispose();
    }
    if (targets !== $targets) {
        $targets.dispose();
    }
    // Output precision has the same shape as targets.
    return tensor(precision, $targets.shape, 'bool');
}
const inTopKAsync = inTopKAsync_;
20826
20827 /**
20828 * @license
20829 * Copyright 2020 Google LLC. All Rights Reserved.
20830 * Licensed under the Apache License, Version 2.0 (the "License");
20831 * you may not use this file except in compliance with the License.
20832 * You may obtain a copy of the License at
20833 *
20834 * http://www.apache.org/licenses/LICENSE-2.0
20835 *
20836 * Unless required by applicable law or agreed to in writing, software
20837 * distributed under the License is distributed on an "AS IS" BASIS,
20838 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20839 * See the License for the specific language governing permissions and
20840 * limitations under the License.
20841 * =============================================================================
20842 */
20843 /**
20844 * Computes the derivative of the filter of a 2D convolution.
20845 *
20846 * @param x The input tensor, of rank 4 or rank 3 of shape
20847 * [batch, height, width, inChannels]. If rank 3, batch of 1 is assumed.
20848 * @param dy The dy image, of rank 4 or rank 3, of shape
20849 * [batch, height, width, outDepth]. If rank 3, batch of 1 is assumed.
20850 * @param filterShape The shape of the filter, length 4,
20851 * [filterHeight, filterWidth, inDepth, outDepth].
20852 * @param strides The strides of the convolution: [strideHeight,
20853 * strideWidth].
20854 * @param pad A string from: 'same', 'valid'. The type of padding algorithm
20855 * used in the forward prop of the op.
20856 * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
20857 * "NHWC". Specify the data format of the input and output data. With the
20858 * default format "NHWC", the data is stored in the order of: [batch,
20859 * height, width, channels].
20860 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
20861 * provided, it will default to truncate.
20862 */
// Computes the gradient of a 2D convolution with respect to its filter.
// See the JSDoc above for parameter semantics.
function conv2DBackpropFilter_(x, dy, filterShape, strides, pad, dataFormat = 'NHWC', dimRoundingMode) {
    // Promote rank-3 inputs to rank 4 by adding a singleton batch dimension.
    let x4D = x;
    if (x.rank === 3) {
        x4D = reshape$3(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
    }
    let dy4D = dy;
    if (dy4D.rank === 3) {
        dy4D = reshape$3(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
    }
    assert$1(x4D.rank === 4, () => `Error in conv2dDerFilter: input must be rank 4, but got shape ` +
        `${x4D.shape}.`);
    assert$1(dy4D.rank === 4, () => `Error in conv2dDerFilter: dy must be rank 4, but got shape ` +
        `${dy4D.shape}.`);
    assert$1(filterShape.length === 4, () => `Error in conv2dDerFilter: filterShape must be length 4, but got ` +
        `${filterShape}.`);
    const inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
    const outDepth = dataFormat === 'NHWC' ? dy4D.shape[3] : dy4D.shape[1];
    // Fixed: original message had unbalanced parentheses
    // ("input ${inDepth}) ... filter (${filterShape[2]}.").
    assert$1(inDepth === filterShape[2], () => `Error in conv2dDerFilter: depth of input (${inDepth}) must ` +
        `match input depth in filter (${filterShape[2]}).`);
    assert$1(outDepth === filterShape[3], () => `Error in conv2dDerFilter: depth of dy (${outDepth}) must ` +
        `match output depth for filter (${filterShape[3]}).`);
    checkPadOnDimRoundingMode('conv2dDerFilter', pad, dimRoundingMode);
    const inputs = { x: x4D, dy: dy4D };
    const attrs = { strides, pad, dataFormat, dimRoundingMode, filterShape };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    return ENGINE.runKernel(Conv2DBackpropFilter, inputs, attrs);
}
const conv2DBackpropFilter$2 = /* @__PURE__ */ op({ conv2DBackpropFilter_ });
20891
20892 /**
20893 * @license
20894 * Copyright 2019 Google LLC. All Rights Reserved.
20895 * Licensed under the Apache License, Version 2.0 (the "License");
20896 * you may not use this file except in compliance with the License.
20897 * You may obtain a copy of the License at
20898 *
20899 * http://www.apache.org/licenses/LICENSE-2.0
20900 *
20901 * Unless required by applicable law or agreed to in writing, software
20902 * distributed under the License is distributed on an "AS IS" BASIS,
20903 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20904 * See the License for the specific language governing permissions and
20905 * limitations under the License.
20906 * =============================================================================
20907 */
// Adjusts the upstream gradient `dy` for a fused activation applied to `y`.
function getFusedDyActivation(dy, y, activation) {
    switch (activation) {
        case null:
        case undefined:
        case 'linear':
            // Identity activation: gradient passes through unchanged.
            return dy;
        case 'relu':
            // d(relu)/dx is 1 where the output is positive, i.e. step(y).
            return mul(dy, step$2(y));
        default:
            throw new Error(`Cannot compute gradient for fused activation ${activation}.`);
    }
}
// Gradient for a fused bias: sum the activation gradient over the axes the
// bias was broadcast across, then restore the bias's shape.
function getFusedBiasGradient(bias, dyActivation) {
    const reduceAxes = getReductionAxes(bias.shape, dyActivation.shape);
    const reduced =
        reduceAxes.length > 0 ? sum$3(dyActivation, reduceAxes) : dyActivation;
    return reshape$3(reduced, bias.shape);
}
// Applies the named fused activation to `x`. `preluActivationWeights` and
// `leakyreluAlpha` are only consulted by their respective activations.
function applyActivation$1(x, activation, preluActivationWeights, leakyreluAlpha) {
    switch (activation) {
        case 'linear':
            return x;
        case 'relu':
            return relu$2(x);
        case 'elu':
            return elu$4(x);
        case 'relu6':
            return relu6$2(x);
        case 'prelu':
            return prelu$3(x, preluActivationWeights);
        case 'leakyrelu':
            return leakyRelu$2(x, leakyreluAlpha);
        case 'sigmoid':
            return sigmoid$2(x);
        default:
            throw new Error(`Unknown fused activation ${activation}.`);
    }
}
// Whether we should call fused ops. Fusing is safe when no gradient is being
// recorded, or when the activation is linear (gradient passes straight through).
const shouldFuse = (gradientDepth, activation) => {
    const recordingGradient = gradientDepth > 0;
    return !recordingGradient || activation === 'linear';
};
20956
20957 /**
20958 * @license
20959 * Copyright 2019 Google LLC. All Rights Reserved.
20960 * Licensed under the Apache License, Version 2.0 (the "License");
20961 * you may not use this file except in compliance with the License.
20962 * You may obtain a copy of the License at
20963 *
20964 * http://www.apache.org/licenses/LICENSE-2.0
20965 *
20966 * Unless required by applicable law or agreed to in writing, software
20967 * distributed under the License is distributed on an "AS IS" BASIS,
20968 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20969 * See the License for the specific language governing permissions and
20970 * limitations under the License.
20971 * =============================================================================
20972 */
20973 /**
20974 * Computes a 2D convolution over the input x, optionally fused with adding a
20975 * bias and applying an activation.
20976 *
20977 * ```js
20978 * const inputDepth = 2;
20979 * const inShape = [2, 2, 2, inputDepth];
20980 * const outputDepth = 2;
20981 * const fSize = 1;
20982 * const pad = 0;
20983 * const strides = 1;
20984 *
20985 * const x = tf.tensor4d( [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
20986 * 16], inShape);
20987 * const w = tf.tensor4d([-1, 1, -2, 0.5], [fSize, fSize, inputDepth,
20988 * outputDepth]);
20989 *
20990 * tf.fused.conv2d({ x, filter: w, strides, pad, dataFormat: 'NHWC',
20991 * dilations: [1, 1], bias: tf.scalar(5), activation: 'relu' }).print();
20992 * ```
20993 *
20994 * @param obj An object with the following properties:
20995 * @param x The input tensor, of rank 4 or rank 3, of shape
20996 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
20997 * assumed.
20998 * @param filter The filter, rank 4, of shape
20999 * `[filterHeight, filterWidth, inDepth, outDepth]`.
21000 * @param strides The strides of the convolution: `[strideHeight,
21001 * strideWidth]`.
21002 * @param pad The type of padding algorithm.
21003 * - `same` and stride 1: output will be of same size as input,
21004 * regardless of filter size.
21005 * - `valid` output will be smaller than input if filter is larger
21006 * than 1x1.
21007 * - For more info, see this guide:
21008 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
21009 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
21010 * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to
21011 * "NHWC". Specify the data format of the input and output data. With the
21012 * default format "NHWC", the data is stored in the order of: [batch,
21013 * height, width, channels]. Only "NHWC" is currently supported.
21014 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
21015 * in which we sample input values across the height and width dimensions
21016 * in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single
21017 * number, then `dilationHeight == dilationWidth`. If it is greater than
21018 * 1, then all values of `strides` must be 1.
21019 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
21020 * provided, it will default to truncate.
21021 * @param bias Tensor to be added to the result.
21022 * @param activation Name of activation kernel (defaults to `linear`) to be
21023 * applied
21024 * after biasAdd.
21025 * @param preluActivationWeights Tensor of prelu weights to be applied as part
21026 * of a `prelu` activation, typically the same shape as `x`.
21027 * @param leakyreluAlpha Optional. Alpha to be applied as part of a `leakyrelu`
21028 * activation.
21029 */
21030 function fusedConv2d_({ x, filter, strides, pad, dataFormat = 'NHWC', dilations = [1, 1], dimRoundingMode, bias, activation = 'linear', preluActivationWeights, leakyreluAlpha }) {
21031 activation = activation || 'linear';
21032 if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
21033 // TODO: Transpose bias and preluActivationWeights properly for NCHW
21034 // format before computation.
21035 assert$1(dataFormat === 'NHWC', () => `Error in fused conv2d: got dataFormat of ${dataFormat} but ` +
21036 `only NHWC is currently supported for the case of gradient depth ` +
21037 `is 0 and the activation is not linear.`);
21038 let result = conv2d$4(x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
21039 if (bias != null) {
21040 result = add$3(result, bias);
21041 }
21042 return applyActivation$1(result, activation, preluActivationWeights, leakyreluAlpha);
21043 }
21044 const $x = convertToTensor(x, 'x', 'conv2d', 'float32');
21045 const $filter = convertToTensor(filter, 'filter', 'conv2d', 'float32');
21046 let x4D = $x;
21047 let reshapedTo4D = false;
21048 if ($x.rank === 3) {
21049 reshapedTo4D = true;
21050 x4D = reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
21051 }
21052 assert$1(x4D.rank === 4, () => `Error in fused conv2d: input must be rank 4, but got rank ` +
21053 `${x4D.rank}.`);
21054 assert$1($filter.rank === 4, () => `Error in fused conv2d: filter must be rank 4, but got rank ` +
21055 `${$filter.rank}.`);
21056 checkPadOnDimRoundingMode('fused conv2d', pad, dimRoundingMode);
21057 const inputChannels = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
21058 assert$1($filter.shape[2] === inputChannels, () => `Error in conv2d: depth of input (${inputChannels}) must match ` +
21059 `input depth for filter ${$filter.shape[2]}.`);
21060 assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in conv2D: Either strides or dilations must be 1. ' +
21061 `Got strides ${strides} and dilations '${dilations}'`);
21062 const convInfo = computeConv2DInfo(x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode);
21063 let $bias;
21064 if (bias != null) {
21065 $bias = convertToTensor(bias, 'bias', 'fused conv2d');
21066 [$bias] = makeTypesMatch($bias, $x);
21067 // According to TensorFlow, the bias is supposed be a 1-D tensor or a
21068 // scalar.
21069 //
21070 // 3-D or 4-D bias is not disabled for NHWC format, because they are
21071 // currently being used in some cases. For examplem in our code base,
21072 // https://github.com/tensorflow/tfjs/blob/b53bd47e880367ae57493f0ea628abaf08db2d5d/tfjs-core/src/ops/fused/fused_conv2d_test.ts#L1972.
21073 if (dataFormat === 'NHWC') {
21074 assertAndGetBroadcastShape(convInfo.outShape, $bias.shape);
21075 }
21076 else {
21077 assert$1($bias.shape.length <= 1, () => `Error in fused conv2d: only supports scalar or 1-D Tensor ` +
21078 `bias for NCHW format but got the bias of ` +
21079 `rank-${$bias.shape.length}.`);
21080 assert$1($bias.shape.length === 0 || $bias.shape[0] === convInfo.outChannels ||
21081 $bias.shape[0] === 1, () => `Error in fused conv2d: bias shape (${$bias.shape}) is not ` +
21082 `compatible with the number of output channels ` +
21083 `(${convInfo.outChannels})`);
21084 }
21085 }
21086 let $preluActivationWeights;
21087 if (preluActivationWeights != null) {
21088 // PReLU's activation weights could be a scalar, a 1-D tensor or a 3-D
21089 // tensor.
21090 const alphaShape = preluActivationWeights.shape;
21091 assert$1(alphaShape.length <= 1 || alphaShape.length === 3, () => `Error in fused conv2d: only supports scalar, 1-D Tensor or ` +
21092 `3-D Tensor PReLU activation weights but got a tensor of ` +
21093 `rank-${alphaShape.length}.`);
21094 if (alphaShape.length === 1) {
21095 // Whether the data format is NCHW or NHWC, the 1-D PReLU activation
21096 // weights tensor should be aligned with the output channels of conv2d
21097 // result.
21098 assert$1(alphaShape[0] === 1 || alphaShape[0] === convInfo.outChannels, () => `Error in fused conv2d: PReLU activation weights ` +
21099 `(${alphaShape}) is not compatible with the number of output ` +
21100 `channels (${convInfo.outChannels}).`);
21101 }
21102 else if (alphaShape.length === 3) {
 21103             // Whether the data format is NCHW or NHWC, the PReLU activation weights
 21104             // tensor should have a shape compatible with the result of conv2d.
21105 try {
21106 assertAndGetBroadcastShape(alphaShape, convInfo.outShape);
21107 }
21108 catch (e) {
21109 const errMsg = `Error in fused conv2d: PReLU activation weights (${alphaShape}) ` +
21110 `is not compatible with the output shape of the conv2d ` +
21111 `(${convInfo.outShape}).`;
21112 throw Error(errMsg);
21113 }
21114 }
21115 $preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused conv2d');
21116 }
21117 const grad = (dy, saved) => {
21118 assert$1(dataFormat === 'NHWC', () => `Error in gradient of fused conv2D: got dataFormat of ${dataFormat} but only NHWC is currently supported.`);
21119 const [$filter, x4D, y, $bias] = saved;
21120 const dyActivation = getFusedDyActivation(dy, y, activation);
21121 assert$1(tupleValuesAreOne(dilations), () => 'Error in gradient of fused conv2D: ' +
21122 `dilation rates greater than 1 ` +
21123 `are not yet supported in gradients. Got dilations '${dilations}'`);
21124 const xDer = conv2DBackpropInput$2(x4D.shape, dyActivation, $filter, strides, pad);
21125 const filterDer = conv2DBackpropFilter$2(x4D, dyActivation, $filter.shape, strides, pad);
21126 const der = [xDer, filterDer];
21127 if ($bias != null) {
21128 const biasDer = getFusedBiasGradient($bias, dyActivation);
21129 der.push(biasDer);
21130 }
21131 return der;
21132 };
21133 const inputs = {
21134 x: x4D,
21135 filter: $filter,
21136 bias: $bias,
21137 preluActivationWeights: $preluActivationWeights
21138 };
21139 const attrs = {
21140 strides,
21141 pad,
21142 dataFormat,
21143 dilations,
21144 dimRoundingMode,
21145 activation,
21146 leakyreluAlpha
21147 };
 21148     // Depending on the params passed in we will have a different number of
 21149     // inputs and thus a different number of elements in the gradient.
21150 if (bias == null) {
21151 const customOp = customGrad((x4D, filter, save) => {
21152 let res =
21153 // tslint:disable-next-line: no-unnecessary-type-assertion
21154 ENGINE.runKernel(FusedConv2D, inputs, attrs);
21155 save([filter, x4D, res]);
21156 if (reshapedTo4D) {
21157 // tslint:disable-next-line: no-unnecessary-type-assertion
21158 res = reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
21159 }
21160 return { value: res, gradFunc: grad };
21161 });
21162 return customOp(x4D, $filter);
21163 }
21164 else {
21165 const customOpWithBias = customGrad((x4D, filter, bias, save) => {
21166 let res = ENGINE.runKernel(FusedConv2D, inputs, attrs);
21167 save([filter, x4D, res, bias]);
21168 if (reshapedTo4D) {
21169 // tslint:disable-next-line: no-unnecessary-type-assertion
21170 res = reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
21171 }
21172 return { value: res, gradFunc: grad };
21173 });
21174 return customOpWithBias(x4D, $filter, $bias);
21175 }
21176 }
// Fused conv2d op (exported as `tf.fused.conv2d`): conv2d with optional bias
// add and activation fused into a single kernel invocation.
const conv2d$3 = /* @__PURE__ */ op({ fusedConv2d_ });
21178
21179 /**
21180 * @license
21181 * Copyright 2020 Google LLC. All Rights Reserved.
21182 * Licensed under the Apache License, Version 2.0 (the "License");
21183 * you may not use this file except in compliance with the License.
21184 * You may obtain a copy of the License at
21185 *
21186 * http://www.apache.org/licenses/LICENSE-2.0
21187 *
21188 * Unless required by applicable law or agreed to in writing, software
21189 * distributed under the License is distributed on an "AS IS" BASIS,
21190 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21191 * See the License for the specific language governing permissions and
21192 * limitations under the License.
21193 * =============================================================================
21194 */
/**
 * Computes the gradient of a depthwise conv2d with respect to its filter.
 *
 * @param x The input tensor (rank 3 or 4); rank-3 inputs get a batch of 1.
 * @param dy The upstream gradient (rank 3 or 4); rank-3 gets a batch of 1.
 * @param filterShape Shape of the filter whose gradient is computed.
 * @param strides Convolution strides.
 * @param pad Padding algorithm.
 * @param dilations Dilation rates; defaults to [1, 1].
 * @param dimRoundingMode Rounding mode: 'ceil', 'round' or 'floor'.
 */
function depthwiseConv2dNativeBackpropFilter_(x, dy, filterShape, strides, pad, dilations = [1, 1], dimRoundingMode) {
    // Promote rank-3 tensors to rank 4 by prepending a batch dimension of 1.
    const x4D = x.rank === 3 ?
        reshape$3(x, [1, x.shape[0], x.shape[1], x.shape[2]]) :
        x;
    const dy4D = dy.rank === 3 ?
        reshape$3(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]) :
        dy;
    const inputs = { x: x4D, dy: dy4D };
    const attrs = { strides, pad, dimRoundingMode, dilations, filterShape };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    return ENGINE.runKernel(DepthwiseConv2dNativeBackpropFilter, inputs, attrs);
}
const depthwiseConv2dNativeBackpropFilter$2 = op({ depthwiseConv2dNativeBackpropFilter_ });
21210
21211 /**
21212 * @license
21213 * Copyright 2020 Google LLC. All Rights Reserved.
21214 * Licensed under the Apache License, Version 2.0 (the "License");
21215 * you may not use this file except in compliance with the License.
21216 * You may obtain a copy of the License at
21217 *
21218 * http://www.apache.org/licenses/LICENSE-2.0
21219 *
21220 * Unless required by applicable law or agreed to in writing, software
21221 * distributed under the License is distributed on an "AS IS" BASIS,
21222 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21223 * See the License for the specific language governing permissions and
21224 * limitations under the License.
21225 * =============================================================================
21226 */
/**
 * Computes the gradient of a depthwise conv2d with respect to its input.
 *
 * @param xShape Shape of the original input tensor.
 * @param dy The upstream gradient (rank 3 or 4); rank-3 gets a batch of 1
 *     and the result is squeezed back to rank 3.
 * @param filter The filter used in the forward pass.
 * @param strides Convolution strides.
 * @param pad Padding algorithm.
 * @param dilations Dilation rates; defaults to [1, 1].
 * @param dimRoundingMode Rounding mode: 'ceil', 'round' or 'floor'.
 */
function depthwiseConv2dNativeBackpropInput_(xShape, dy, filter, strides, pad, dilations = [1, 1], dimRoundingMode) {
    // Track whether we added a batch dimension so it can be stripped again.
    const promoted = dy.rank === 3;
    const dy4D = promoted ?
        reshape$3(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]) :
        dy;
    const inputs = { dy: dy4D, filter };
    const attrs = { strides, pad, dimRoundingMode, dilations, inputShape: xShape };
    const res =
        // tslint:disable-next-line: no-unnecessary-type-assertion
        ENGINE.runKernel(DepthwiseConv2dNativeBackpropInput, inputs, attrs);
    // Squeeze the synthetic batch dimension back out for rank-3 callers.
    return promoted ?
        reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]) :
        res;
}
const depthwiseConv2dNativeBackpropInput$2 = op({ depthwiseConv2dNativeBackpropInput_ });
21245
21246 /**
21247 * @license
21248 * Copyright 2019 Google LLC. All Rights Reserved.
21249 * Licensed under the Apache License, Version 2.0 (the "License");
21250 * you may not use this file except in compliance with the License.
21251 * You may obtain a copy of the License at
21252 *
21253 * http://www.apache.org/licenses/LICENSE-2.0
21254 *
21255 * Unless required by applicable law or agreed to in writing, software
21256 * distributed under the License is distributed on an "AS IS" BASIS,
21257 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21258 * See the License for the specific language governing permissions and
21259 * limitations under the License.
21260 * =============================================================================
21261 */
21262 /**
21263 * Computes depthwise 2D convolution, optionally fused with adding a
21264 * bias and applying an activation.
21265 *
21266 * Given a 4D `input` array and a `filter` array of shape
21267 * `[filterHeight, filterWidth, inChannels, channelMultiplier]` containing
21268 * `inChannels` convolutional filters of depth 1, this op applies a
21269 * different filter to each input channel (expanding from 1 channel to
21270 * `channelMultiplier` channels for each), then concatenates the results
21271 * together. The output has `inChannels * channelMultiplier` channels.
21272 *
21273 * See
21274 * [https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d](
21275 * https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d)
21276 * for more details.
21277 *
21278 * @param obj An object with the following properties:
21279 * @param x The input tensor, of rank 4 or rank 3, of shape
21280 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
21281 * assumed.
21282 * @param filter The filter tensor, rank 4, of shape
21283 * `[filterHeight, filterWidth, inChannels, channelMultiplier]`.
21284 * @param strides The strides of the convolution: `[strideHeight,
21285 * strideWidth]`. If strides is a single number, then `strideHeight ==
21286 * strideWidth`.
21287 * @param pad The type of padding algorithm.
21288 * - `same` and stride 1: output will be of same size as input,
21289 * regardless of filter size.
21290 * - `valid`: output will be smaller than input if filter is larger
21291 * than 1x1.
21292 * - For more info, see this guide:
21293 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
21294 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
21295 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
21296 * in which we sample input values across the height and width dimensions
21297 * in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single
21298 * number, then `dilationHeight == dilationWidth`. If it is greater than
21299 * 1, then all values of `strides` must be 1.
21300 * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
21301 * "NHWC". Specify the data format of the input and output data. With the
21302 * default format "NHWC", the data is stored in the order of: [batch,
21303 * height, width, channels]. Only "NHWC" is currently supported.
21304 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
21305 * provided, it will default to truncate.
21306 * @param bias Tensor to be added to the result.
21307 * @param activation Name of activation kernel (defaults to `linear`).
21308 * @param preluActivationWeights Tensor of prelu weights to be applied as part
21309 * of a `prelu` activation, typically the same shape as `x`.
21310 * @param leakyreluAlpha Optional. Alpha to be applied as part of a `leakyrelu`
21311 * activation.
21312 */
/**
 * Computes depthwise 2D convolution, optionally fused with adding a bias and
 * applying an activation (see the doc comment above for parameter details).
 */
function fusedDepthwiseConv2d_({ x, filter, strides, pad, dataFormat = 'NHWC', dilations = [1, 1], dimRoundingMode, bias, activation = 'linear', preluActivationWeights, leakyreluAlpha }) {
    // When fusing would interfere with gradient computation, fall back to the
    // unfused sequence: depthwiseConv2d -> add bias -> activation.
    if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
        let result = depthwiseConv2d$3(x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
        if (bias != null) {
            result = add$3(result, bias);
        }
        return applyActivation$1(result, activation, preluActivationWeights, leakyreluAlpha);
    }
    const $x = convertToTensor(x, 'x', 'depthwiseConv2d', 'float32');
    const $filter = convertToTensor(filter, 'filter', 'depthwiseConv2d', 'float32');
    // Promote a rank-3 input to rank 4 with a batch dimension of 1; the result
    // is squeezed back to rank 3 before being returned.
    let x4D = $x;
    let reshapedTo4D = false;
    if ($x.rank === 3) {
        reshapedTo4D = true;
        x4D = reshape$3($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
    }
    assert$1(x4D.rank === 4, () => `Error in fused depthwiseConv2d: input must be rank 4, but got ` +
        `rank ${x4D.rank}.`);
    assert$1($filter.rank === 4, () => `Error in fused depthwiseConv2d: filter must be rank 4, ` +
        `but got rank ${$filter.rank}.`);
    assert$1(x4D.shape[3] === $filter.shape[2], () => `Error in fused depthwiseConv2d: number of input channels ` +
        `(${x4D.shape[3]}) must match the inChannels dimension in ` +
        `filter ${$filter.shape[2]}.`);
    if (dilations == null) {
        dilations = [1, 1];
    }
    assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in fused depthwiseConv2d: Either strides or dilations must ' +
        `be 1. Got strides ${strides} and dilations '${dilations}'`);
    checkPadOnDimRoundingMode('fused depthwiseConv2d', pad, dimRoundingMode);
    const convInfo = computeConv2DInfo(x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
    let $bias;
    if (bias != null) {
        $bias = convertToTensor(bias, 'bias', 'fused conv2d');
        [$bias] = makeTypesMatch($bias, $x);
        // The bias must broadcast against the convolution output shape.
        assertAndGetBroadcastShape(convInfo.outShape, $bias.shape);
    }
    let $preluActivationWeights;
    if (preluActivationWeights != null) {
        $preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused depthwiseConv2d');
    }
    const grad = (dy, saved) => {
        assert$1(tupleValuesAreOne(dilations), () => 'Error in gradient of fused depthwiseConv2d: dilation rates ' +
            `greater than 1 are not yet supported. Got dilations ` +
            `'${dilations}'`);
        const [$filter, x4D, y, bias] = saved;
        const dyActivation = getFusedDyActivation(dy, y, activation);
        const xDer = depthwiseConv2dNativeBackpropInput$2(x4D.shape, dyActivation, $filter, strides, pad, dilations, dimRoundingMode);
        const filterDer = depthwiseConv2dNativeBackpropFilter$2(x4D, dyActivation, $filter.shape, strides, pad, dilations, dimRoundingMode);
        if (bias != null) {
            // Fix: use the saved bias tensor (destructured from `saved`) rather
            // than the closure-captured `$bias`, so the gradient is computed
            // against the exact tensor tracked by customGrad. This mirrors the
            // gradient of fusedConv2d_ above, which uses the saved tensor.
            const biasDer = getFusedBiasGradient(bias, dyActivation);
            return [xDer, filterDer, biasDer];
        }
        return [xDer, filterDer];
    };
    const inputs = {
        x: x4D,
        filter: $filter,
        bias: $bias,
        preluActivationWeights: $preluActivationWeights
    };
    const attrs = {
        strides,
        pad,
        dataFormat,
        dilations,
        dimRoundingMode,
        activation,
        leakyreluAlpha
    };
    // Depending on the params passed in we will have a different number of
    // inputs and thus a different number of elements in the gradient.
    if (bias == null) {
        const customOp = customGrad((x4D, filter, save) => {
            // tslint:disable-next-line: no-unnecessary-type-assertion
            let res = ENGINE.runKernel(FusedDepthwiseConv2D, inputs, attrs);
            save([filter, x4D, res]);
            if (reshapedTo4D) {
                // tslint:disable-next-line: no-unnecessary-type-assertion
                res = reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
            }
            return { value: res, gradFunc: grad };
        });
        return customOp(x4D, $filter);
    }
    else {
        const customOpWithBias = customGrad((x4D, filter, bias, save) => {
            // tslint:disable-next-line: no-unnecessary-type-assertion
            let res = ENGINE.runKernel(FusedDepthwiseConv2D, inputs, attrs);
            save([filter, x4D, res, bias]);
            if (reshapedTo4D) {
                // tslint:disable-next-line: no-unnecessary-type-assertion
                res = reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
            }
            return { value: res, gradFunc: grad };
        });
        return customOpWithBias(x4D, $filter, $bias);
    }
}
const depthwiseConv2d$2 = /* @__PURE__ */ op({ fusedDepthwiseConv2d_ });
21412
21413 /**
21414 * @license
21415 * Copyright 2019 Google LLC. All Rights Reserved.
21416 * Licensed under the Apache License, Version 2.0 (the "License");
21417 * you may not use this file except in compliance with the License.
21418 * You may obtain a copy of the License at
21419 *
21420 * http://www.apache.org/licenses/LICENSE-2.0
21421 *
21422 * Unless required by applicable law or agreed to in writing, software
21423 * distributed under the License is distributed on an "AS IS" BASIS,
21424 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21425 * See the License for the specific language governing permissions and
21426 * limitations under the License.
21427 * =============================================================================
21428 */
21429 /**
21430 * Computes the dot product of two matrices with optional activation and bias.
21431 *
21432 * ```js
21433 * const a = tf.tensor2d([-1, -2], [1, 2]);
21434 * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]);
21435 * const bias = tf.tensor2d([1, 2], [1, 2]);
21436 *
21437 * tf.fused.matMul({a, b, bias, activation: 'relu'}).print();
21438 * ```
21439 *
21440 * @param obj An object with the following properties:
21441 * - `a` First matrix in dot product operation.
21442 * - `b` Second matrix in dot product operation.
21443 * - `transposeA` If true, `a` is transposed before multiplication.
21444 * - `transposeB` If true, `b` is transposed before multiplication.
21445 * - `bias` Matrix to be added to the result.
21446 * - `activation` Name of activation kernel (defaults to `linear`).
21447 * - `preluActivationWeights` Tensor of prelu weights.
21448 * - `leakyreluAlpha` Alpha of leakyrelu.
21449 */
/**
 * Computes the dot product of two matrices with optional activation and bias
 * (see the doc comment above for parameter details).
 */
function fusedMatMul_({ a, b, transposeA = false, transposeB = false, bias, activation = 'linear', preluActivationWeights, leakyreluAlpha = 0.2, }) {
    // When fusing would interfere with gradient computation, fall back to the
    // unfused sequence: matMul -> add bias -> activation.
    if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
        let result = matMul$1(a, b, transposeA, transposeB);
        if (bias != null) {
            result = add$3(result, bias);
        }
        return applyActivation$1(result, activation, preluActivationWeights, leakyreluAlpha);
    }
    let $a = convertToTensor(a, 'a', 'fused matMul');
    let $b = convertToTensor(b, 'b', 'fused matMul');
    [$a, $b] = makeTypesMatch($a, $b);
    // Contracted (inner) and result (outer) dimensions of each operand,
    // accounting for the transpose flags.
    const innerShapeA = transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1];
    const innerShapeB = transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2];
    const outerShapeA = transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2];
    const outerShapeB = transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1];
    // Leading (batch) dimensions, flattened to a single batch dimension for
    // the 3-D kernel call below.
    const outerDimsA = $a.shape.slice(0, -2);
    const outerDimsB = $b.shape.slice(0, -2);
    const batchDimA = sizeFromShape(outerDimsA);
    const batchDimB = sizeFromShape(outerDimsB);
    assert$1(innerShapeA === innerShapeB, () => `Error in fused matMul: inner shapes (${innerShapeA}) and (` +
        `${innerShapeB}) of Tensors with shapes ${$a.shape} and ` +
        `${$b.shape} and transposeA=${transposeA}` +
        ` and transposeB=${transposeB} must match.`);
    // Broadcast the batch dimensions of a and b to form the output shape.
    const outShapeOuterDims = assertAndGetBroadcastShape($a.shape.slice(0, -2), $b.shape.slice(0, -2));
    const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
    const a3D = transposeA ?
        reshape$3($a, [batchDimA, innerShapeA, outerShapeA]) :
        reshape$3($a, [batchDimA, outerShapeA, innerShapeA]);
    const b3D = transposeB ?
        reshape$3($b, [batchDimB, outerShapeB, innerShapeB]) :
        reshape$3($b, [batchDimB, innerShapeB, outerShapeB]);
    let $bias;
    if (bias != null) {
        $bias = convertToTensor(bias, 'bias', 'fused matMul');
        [$bias] = makeTypesMatch($bias, $a);
        // The bias must broadcast against the matMul output shape.
        assertAndGetBroadcastShape(outShape, $bias.shape);
    }
    let $preluActivationWeights;
    if (preluActivationWeights != null) {
        $preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused matMul');
    }
    const grad = (dy, saved) => {
        const [a3D, b3D, y, $bias] = saved;
        // we reshape dy because the result of the forward is not
        // necessarily going to be a 3d tensor due to a reshape done at the end of
        // the customOp.
        const dyActivation = getFusedDyActivation(reshape$3(dy, y.shape), y, activation);
        let aDer;
        let bDer;
        // The four transpose combinations each need a different pair of
        // matMuls for the input gradients.
        if (!transposeA && !transposeB) {
            aDer = matMul$1(dyActivation, b3D, false, true);
            bDer = matMul$1(a3D, dyActivation, true, false);
        }
        else if (!transposeA && transposeB) {
            aDer = matMul$1(dyActivation, b3D, false, false);
            bDer = matMul$1(dyActivation, a3D, true, false);
        }
        else if (transposeA && !transposeB) {
            aDer = matMul$1(b3D, dyActivation, false, true);
            bDer = matMul$1(a3D, dyActivation, false, false);
        }
        else {
            aDer = matMul$1(b3D, dyActivation, true, true);
            bDer = matMul$1(dyActivation, a3D, true, true);
        }
        if (bias != null) {
            const biasDer = getFusedBiasGradient($bias, dyActivation);
            return [aDer, bDer, biasDer];
        }
        else {
            return [aDer, bDer];
        }
    };
    const inputs = {
        a: a3D,
        b: b3D,
        bias: $bias,
        preluActivationWeights: $preluActivationWeights
    };
    const attrs = { transposeA, transposeB, activation, leakyreluAlpha };
    // Depending on the params passed in we will have a different number of
    // inputs and thus a different number of elements in the gradient.
    if (bias == null) {
        const customOp = customGrad((a3D, b3D, save) => {
            const res =
            // tslint:disable-next-line: no-unnecessary-type-assertion
            ENGINE.runKernel(_FusedMatMul, inputs, attrs);
            save([a3D, b3D, res]);
            return { value: reshape$3(res, outShape), gradFunc: grad };
        });
        return customOp(a3D, b3D);
    }
    else {
        const customOpWithBias = customGrad((a3D, b3D, $bias, save) => {
            const res =
            // tslint:disable-next-line: no-unnecessary-type-assertion
            ENGINE.runKernel(_FusedMatMul, inputs, attrs);
            save([a3D, b3D, res, $bias]);
            return { value: reshape$3(res, outShape), gradFunc: grad };
        });
        return customOpWithBias(a3D, b3D, $bias);
    }
}
const matMul = /* @__PURE__ */ op({ fusedMatMul_ });
21554
21555 /**
21556 * @license
21557 * Copyright 2019 Google LLC. All Rights Reserved.
21558 * Licensed under the Apache License, Version 2.0 (the "License");
21559 * you may not use this file except in compliance with the License.
21560 * You may obtain a copy of the License at
21561 *
21562 * http://www.apache.org/licenses/LICENSE-2.0
21563 *
21564 * Unless required by applicable law or agreed to in writing, software
21565 * distributed under the License is distributed on an "AS IS" BASIS,
21566 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21567 * See the License for the specific language governing permissions and
21568 * limitations under the License.
21569 * =============================================================================
21570 */
21571
// Namespace of fused ops (exposed as `tf.fused`): each op fuses the optional
// bias addition and activation into the main kernel invocation.
var fused_ops = /*#__PURE__*/Object.freeze({
    __proto__: null,
    conv2d: conv2d$3,
    depthwiseConv2d: depthwiseConv2d$2,
    matMul: matMul
});
21578
21579 /**
21580 * @license
21581 * Copyright 2019 Google LLC. All Rights Reserved.
21582 * Licensed under the Apache License, Version 2.0 (the "License");
21583 * you may not use this file except in compliance with the License.
21584 * You may obtain a copy of the License at
21585 *
21586 * http://www.apache.org/licenses/LICENSE-2.0
21587 *
21588 * Unless required by applicable law or agreed to in writing, software
21589 * distributed under the License is distributed on an "AS IS" BASIS,
21590 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21591 * See the License for the specific language governing permissions and
21592 * limitations under the License.
21593 * =============================================================================
21594 */
21595 /**
21596 * Generate a hamming window.
21597 *
21598 * See: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
21599 *
21600 * ```js
21601 * tf.signal.hammingWindow(10).print();
21602 * ```
 21603  * @param windowLength The length of the window.
21604 *
21605 * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
21606 */
/**
 * Generates a Hamming window of the given length as a cosine window with
 * the standard Hamming coefficients.
 *
 * @param windowLength The length of the window.
 */
function hammingWindow_(windowLength) {
    // Standard Hamming window coefficients: w[n] = a - b*cos(2*pi*n/(N-1)).
    const a = 0.54;
    const b = 0.46;
    return cosineWindow(windowLength, a, b);
}
const hammingWindow = /* @__PURE__ */ op({ hammingWindow_ });
21611
21612 /**
21613 * @license
21614 * Copyright 2019 Google LLC. All Rights Reserved.
21615 * Licensed under the Apache License, Version 2.0 (the "License");
21616 * you may not use this file except in compliance with the License.
21617 * You may obtain a copy of the License at
21618 *
21619 * http://www.apache.org/licenses/LICENSE-2.0
21620 *
21621 * Unless required by applicable law or agreed to in writing, software
21622 * distributed under the License is distributed on an "AS IS" BASIS,
21623 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21624 * See the License for the specific language governing permissions and
21625 * limitations under the License.
21626 * =============================================================================
21627 */
21628 /**
21629 * Generate a Hann window.
21630 *
21631 * See: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
21632 *
21633 * ```js
21634 * tf.signal.hannWindow(10).print();
21635 * ```
 21636  * @param windowLength The length of the window.
21637 *
21638 * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
21639 */
/**
 * Generates a Hann window of the given length as a cosine window with
 * the standard Hann coefficients.
 *
 * @param windowLength The length of the window.
 */
function hannWindow_(windowLength) {
    // Standard Hann window coefficients: w[n] = a - b*cos(2*pi*n/(N-1)).
    const a = 0.5;
    const b = 0.5;
    return cosineWindow(windowLength, a, b);
}
const hannWindow = /* @__PURE__ */ op({ hannWindow_ });
21644
21645 /**
21646 * @license
21647 * Copyright 2019 Google LLC. All Rights Reserved.
21648 * Licensed under the Apache License, Version 2.0 (the "License");
21649 * you may not use this file except in compliance with the License.
21650 * You may obtain a copy of the License at
21651 *
21652 * http://www.apache.org/licenses/LICENSE-2.0
21653 *
21654 * Unless required by applicable law or agreed to in writing, software
21655 * distributed under the License is distributed on an "AS IS" BASIS,
21656 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21657 * See the License for the specific language governing permissions and
21658 * limitations under the License.
21659 * =============================================================================
21660 */
21661 /**
21662 * Expands input into frames of frameLength.
21663 * Slides a window size with frameStep.
21664 *
21665 * ```js
21666 * tf.signal.frame([1, 2, 3], 2, 1).print();
21667 * ```
21668 * @param signal The input tensor to be expanded
21669 * @param frameLength Length of each frame
21670 * @param frameStep The frame hop size in samples.
21671 * @param padEnd Whether to pad the end of signal with padValue.
21672 * @param padValue A number to use where the input signal does
21673 * not exist when padEnd is True.
21674 *
21675 * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
21676 */
/**
 * Expands the input signal into frames of frameLength, sliding the window
 * by frameStep samples each time; optionally pads the final partial frames
 * with padValue.
 *
 * @param signal The input tensor to be framed.
 * @param frameLength Length of each frame.
 * @param frameStep The frame hop size in samples.
 * @param padEnd Whether to pad partial end frames with padValue.
 * @param padValue Fill value used where the signal does not exist when
 *     padEnd is true.
 */
function frame_(signal, frameLength, frameStep, padEnd = false, padValue = 0) {
    const frames = [];
    let offset = 0;
    // Collect every fully-populated frame.
    for (; offset + frameLength <= signal.size; offset += frameStep) {
        frames.push(slice$2(signal, offset, frameLength));
    }
    if (padEnd) {
        // Emit the remaining partial frames, filling the tail with padValue.
        for (; offset < signal.size; offset += frameStep) {
            const missing = (offset + frameLength) - signal.size;
            const padded = concat$2([
                slice$2(signal, offset, frameLength - missing),
                fill$2([missing], padValue)
            ]);
            frames.push(padded);
        }
    }
    if (frames.length === 0) {
        // No frames fit: return an empty [0, frameLength] tensor.
        return tensor2d([], [0, frameLength]);
    }
    return reshape$3(concat$2(frames), [frames.length, frameLength]);
}
const frame = /* @__PURE__ */ op({ frame_ });
21700
21701 /**
21702 * @license
21703 * Copyright 2019 Google LLC. All Rights Reserved.
21704 * Licensed under the Apache License, Version 2.0 (the "License");
21705 * you may not use this file except in compliance with the License.
21706 * You may obtain a copy of the License at
21707 *
21708 * http://www.apache.org/licenses/LICENSE-2.0
21709 *
21710 * Unless required by applicable law or agreed to in writing, software
21711 * distributed under the License is distributed on an "AS IS" BASIS,
21712 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21713 * See the License for the specific language governing permissions and
21714 * limitations under the License.
21715 * =============================================================================
21716 */
21717 /**
21718 * Computes the Short-time Fourier Transform of signals
21719 * See: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
21720 *
21721 * ```js
21722 * const input = tf.tensor1d([1, 1, 1, 1, 1])
21723 * tf.signal.stft(input, 3, 1).print();
21724 * ```
21725 * @param signal 1-dimensional real value tensor.
21726 * @param frameLength The window length of samples.
21727 * @param frameStep The number of samples to step.
21728 * @param fftLength The size of the FFT to apply.
21729 * @param windowFn A callable that takes a window length and returns 1-d tensor.
21730 *
21731 * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
21732 */
/**
 * Computes the Short-time Fourier Transform of a signal: frame, apply the
 * window function, then take the real FFT of each windowed frame.
 *
 * @param signal 1-dimensional real-valued tensor.
 * @param frameLength The window length in samples.
 * @param frameStep The number of samples to step.
 * @param fftLength The FFT size; defaults to the enclosing power of two of
 *     frameLength.
 * @param windowFn Callable mapping a window length to a 1-d window tensor.
 */
function stft_(signal, frameLength, frameStep, fftLength, windowFn = hannWindow) {
    // Default the FFT size to the smallest power of two >= frameLength.
    const $fftLength = fftLength == null ? enclosingPowerOfTwo(frameLength) : fftLength;
    const framedSignal = frame(signal, frameLength, frameStep);
    const windowedSignal = mul(framedSignal, windowFn(frameLength));
    return rfft(windowedSignal, $fftLength);
}
const stft = /* @__PURE__ */ op({ stft_ });
21742
21743 /**
21744 * @license
21745 * Copyright 2020 Google LLC. All Rights Reserved.
21746 * Licensed under the Apache License, Version 2.0 (the "License");
21747 * you may not use this file except in compliance with the License.
21748 * You may obtain a copy of the License at
21749 *
21750 * http://www.apache.org/licenses/LICENSE-2.0
21751 *
21752 * Unless required by applicable law or agreed to in writing, software
21753 * distributed under the License is distributed on an "AS IS" BASIS,
21754 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21755 * See the License for the specific language governing permissions and
21756 * limitations under the License.
21757 * =============================================================================
21758 */
21759 /**
21760 * Extracts crops from the input image tensor and resizes them using bilinear
21761 * sampling or nearest neighbor sampling (possibly with aspect ratio change)
21762 * to a common output size specified by cropSize.
21763 *
21764 * @param image 4d tensor of shape `[batch,imageHeight,imageWidth, depth]`,
21765 * where imageHeight and imageWidth must be positive, specifying the
21766 * batch of images from which to take crops
21767 * @param boxes 2d float32 tensor of shape `[numBoxes, 4]`. Each entry is
21768 * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the normalized
21769 * coordinates of the box in the `boxInd[i]`th image in the batch
21770 * @param boxInd 1d int32 tensor of shape `[numBoxes]` with values in range
21771 * `[0, batch)` that specifies the image that the `i`-th box refers to.
21772 * @param cropSize 1d int32 tensor of 2 elements `[cropHeigh, cropWidth]`
21773 * specifying the size to which all crops are resized to.
21774 * @param method Optional string from `'bilinear' | 'nearest'`,
21775 * defaults to bilinear, which specifies the sampling method for resizing
21776 * @param extrapolationValue A threshold for deciding when to remove boxes based
21777 * on score. Defaults to 0.
21778 * @return A 4D tensor of the shape `[numBoxes,cropHeight,cropWidth,depth]`
21779 *
21780 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
21781 */
/**
 * Extracts crops from the input image tensor and resizes them (see the doc
 * comment above for parameter details).
 */
function cropAndResize_(image, boxes, boxInd, cropSize, method = 'bilinear', extrapolationValue = 0) {
    const $image = convertToTensor(image, 'image', 'cropAndResize');
    const $boxes = convertToTensor(boxes, 'boxes', 'cropAndResize', 'float32');
    const $boxInd = convertToTensor(boxInd, 'boxInd', 'cropAndResize', 'int32');
    const numBoxes = $boxes.shape[0];
    assert$1($image.rank === 4, () => 'Error in cropAndResize: image must be rank 4,' +
        `but got rank ${$image.rank}.`);
    assert$1($boxes.rank === 2 && $boxes.shape[1] === 4, () => `Error in cropAndResize: boxes must have size [${numBoxes},4] ` +
        `but had shape ${$boxes.shape}.`);
    // Fix: the boxInd message previously reported `$boxes.shape` instead of
    // the offending `$boxInd.shape`.
    assert$1($boxInd.rank === 1 && $boxInd.shape[0] === numBoxes, () => `Error in cropAndResize: boxInd must have size [${numBoxes}] ` +
        `but had shape ${$boxInd.shape}.`);
    assert$1(cropSize.length === 2, () => `Error in cropAndResize: cropSize must be of length 2, but got ` +
        `length ${cropSize.length}.`);
    assert$1(cropSize[0] >= 1 && cropSize[1] >= 1, () => `cropSize must be atleast [1,1], but was ${cropSize}`);
    assert$1(method === 'bilinear' || method === 'nearest', () => `method must be bilinear or nearest, but was ${method}`);
    const inputs = { image: $image, boxes: $boxes, boxInd: $boxInd };
    const attrs = { method, extrapolationValue, cropSize };
    const res = ENGINE.runKernel(CropAndResize, inputs, attrs);
    return res;
}
const cropAndResize$3 = /* @__PURE__ */ op({ cropAndResize_ });
21803
21804 /**
21805 * @license
21806 * Copyright 2020 Google LLC. All Rights Reserved.
21807 * Licensed under the Apache License, Version 2.0 (the "License");
21808 * you may not use this file except in compliance with the License.
21809 * You may obtain a copy of the License at
21810 *
21811 * http://www.apache.org/licenses/LICENSE-2.0
21812 *
21813 * Unless required by applicable law or agreed to in writing, software
21814 * distributed under the License is distributed on an "AS IS" BASIS,
21815 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21816 * See the License for the specific language governing permissions and
21817 * limitations under the License.
21818 * =============================================================================
21819 */
21820 /**
21821 * Flips the image left to right. Currently available in the CPU, WebGL, and
21822 * WASM backends.
21823 *
21824 * @param image 4d tensor of shape `[batch, imageHeight, imageWidth, depth]`.
21825 */
21826 /** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */
function flipLeftRight_(image) {
    // The kernel operates on float32 batches of images; convert up front.
    const $image = convertToTensor(image, 'image', 'flipLeftRight', 'float32');
    assert$1($image.rank === 4, () => 'Error in flipLeftRight: image must be rank 4,' +
        `but got rank ${$image.rank}.`);
    // Delegate directly to the registered backend kernel; no attributes needed.
    return ENGINE.runKernel(FlipLeftRight, { image: $image }, {});
}
const flipLeftRight = /* @__PURE__ */ op({ flipLeftRight_ });
21836
21837 /**
21838 * @license
21839 * Copyright 2021 Google LLC. All Rights Reserved.
21840 * Licensed under the Apache License, Version 2.0 (the "License");
21841 * you may not use this file except in compliance with the License.
21842 * You may obtain a copy of the License at
21843 *
21844 * http://www.apache.org/licenses/LICENSE-2.0
21845 *
21846 * Unless required by applicable law or agreed to in writing, software
21847 * distributed under the License is distributed on an "AS IS" BASIS,
21848 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21849 * See the License for the specific language governing permissions and
21850 * limitations under the License.
21851 * =============================================================================
21852 */
21853 /**
21854 * Converts images from grayscale to RGB format.
21855 *
21856 * @param image A grayscale tensor to convert. The `image`'s last dimension must
21857 * be size 1 with at least a two-dimensional shape.
21858 *
21859 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
21860 */
function grayscaleToRGB_(image) {
    const $image = convertToTensor(image, 'image', 'grayscaleToRGB');
    const channelAxis = $image.rank - 1;
    const numChannels = $image.shape[channelAxis];
    assert$1($image.rank >= 2, () => 'Error in grayscaleToRGB: images must be at least rank 2, ' +
        `but got rank ${$image.rank}.`);
    assert$1(numChannels === 1, () => 'Error in grayscaleToRGB: last dimension of a grayscale image ' +
        `should be size 1, but got size ${numChannels}.`);
    // Replicate the single grayscale channel three times; every other axis
    // keeps a tile multiplier of 1 and is left untouched.
    const reps = $image.shape.map(() => 1);
    reps[channelAxis] = 3;
    return tile$3($image, reps);
}
const grayscaleToRGB = /* @__PURE__ */ op({ grayscaleToRGB_ });
21875
21876 /**
21877 * @license
21878 * Copyright 2023 Google LLC.
21879 * Licensed under the Apache License, Version 2.0 (the "License");
21880 * you may not use this file except in compliance with the License.
21881 * You may obtain a copy of the License at
21882 *
21883 * http://www.apache.org/licenses/LICENSE-2.0
21884 *
21885 * Unless required by applicable law or agreed to in writing, software
21886 * distributed under the License is distributed on an "AS IS" BASIS,
21887 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21888 * See the License for the specific language governing permissions and
21889 * limitations under the License.
21890 * =============================================================================
21891 */
21892 /**
21893 * Converts images from RGB format to grayscale.
21894 *
21895 * @param image A RGB tensor to convert. The `image`'s last dimension must
21896 * be size 3 with at least a two-dimensional shape.
21897 *
21898 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
21899 */
function rgbToGrayscale_(image) {
    const $image = convertToTensor(image, 'image', 'RGBToGrayscale');
    const lastDimsIdx = $image.rank - 1;
    const lastDims = $image.shape[lastDimsIdx];
    assert$1($image.rank >= 2, () => 'Error in RGBToGrayscale: images must be at least rank 2, ' +
        `but got rank ${$image.rank}.`);
    assert$1(lastDims === 3, () => 'Error in RGBToGrayscale: last dimension of an RGB image ' +
        `should be size 3, but got size ${lastDims}.`);
    // Remember original dtype so we can convert back if needed.
    const origDtype = $image.dtype;
    const fltImage = cast$3($image, 'float32');
    // ITU-R BT.601 luma coefficients for the R, G and B channels.
    const rgbWeights = tensor1d([0.2989, 0.5870, 0.1140]);
    // Only ranks 2..6 are supported (same limit as the original rank switch).
    if ($image.rank > 6) {
        throw new Error('Not a valid tensor rank.');
    }
    // Build the einsum equation that contracts the trailing (channel) axis
    // against the weight vector, e.g. rank 3 yields 'ijk,k->ij'.
    const axes = 'ijklmn'.slice(0, $image.rank);
    const equation = `${axes},${axes[lastDimsIdx]}->${axes.slice(0, lastDimsIdx)}`;
    let grayFloat = einsum$2(equation, fltImage, rgbWeights);
    // Restore the (now size-1) channel axis and the caller's dtype.
    grayFloat = expandDims$3(grayFloat, -1);
    return cast$3(grayFloat, origDtype);
}
const rgbToGrayscale = /* @__PURE__ */ op({ rgbToGrayscale_ });
21936
21937 /**
21938 * @license
21939 * Copyright 2020 Google LLC. All Rights Reserved.
21940 * Licensed under the Apache License, Version 2.0 (the "License");
21941 * you may not use this file except in compliance with the License.
21942 * You may obtain a copy of the License at
21943 *
21944 * http://www.apache.org/licenses/LICENSE-2.0
21945 *
21946 * Unless required by applicable law or agreed to in writing, software
21947 * distributed under the License is distributed on an "AS IS" BASIS,
21948 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21949 * See the License for the specific language governing permissions and
21950 * limitations under the License.
21951 * =============================================================================
21952 */
21953 /**
21954 * Rotates the input image tensor counter-clockwise with an optional offset
21955 * center of rotation. Currently available in the CPU, WebGL, and WASM backends.
21956 *
21957 * @param image 4d tensor of shape `[batch, imageHeight, imageWidth, depth]`.
21958 * @param radians The amount of rotation.
21959 * @param fillValue The value to fill in the empty space leftover
21960 * after rotation. Can be either a single grayscale value (0-255), or an
21961 * array of three numbers `[red, green, blue]` specifying the red, green,
21962 * and blue channels. Defaults to `0` (black).
21963 * @param center The center of rotation. Can be either a single value (0-1), or
21964 * an array of two numbers `[centerX, centerY]`. Defaults to `0.5` (rotates
21965 * the image around its center).
21966 *
21967 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
21968 */
function rotateWithOffset_(image, radians, fillValue = 0, center = 0.5) {
    // The rotation kernel works on float32 batches of images.
    const $image = convertToTensor(image, 'image', 'rotateWithOffset', 'float32');
    assert$1($image.rank === 4, () => 'Error in rotateWithOffset: image must be rank 4,' +
        `but got rank ${$image.rank}.`);
    // Forward the rotation parameters to the backend kernel unchanged.
    return ENGINE.runKernel(RotateWithOffset, { image: $image }, { radians, fillValue, center });
}
const rotateWithOffset = /* @__PURE__ */ op({ rotateWithOffset_ });
21979
21980 /**
21981 * @license
21982 * Copyright 2020 Google LLC. All Rights Reserved.
21983 * Licensed under the Apache License, Version 2.0 (the "License");
21984 * you may not use this file except in compliance with the License.
21985 * You may obtain a copy of the License at
21986 *
21987 * http://www.apache.org/licenses/LICENSE-2.0
21988 *
21989 * Unless required by applicable law or agreed to in writing, software
21990 * distributed under the License is distributed on an "AS IS" BASIS,
21991 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21992 * See the License for the specific language governing permissions and
21993 * limitations under the License.
21994 * =============================================================================
21995 */
/**
 * Validates and normalizes the parameters shared by all NMS variants:
 * fills in defaults for unspecified thresholds, clips `maxOutputSize` to the
 * number of boxes, and asserts the tensor shapes/ranges are sane.
 * Returns the normalized parameter bundle.
 */
function nonMaxSuppSanityCheck(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
    // Apply defaults for any threshold the caller left unspecified.
    iouThreshold = iouThreshold == null ? 0.5 : iouThreshold;
    scoreThreshold = scoreThreshold == null ? Number.NEGATIVE_INFINITY : scoreThreshold;
    softNmsSigma = softNmsSigma == null ? 0.0 : softNmsSigma;
    const numBoxes = boxes.shape[0];
    // Never select more boxes than are available.
    maxOutputSize = Math.min(maxOutputSize, numBoxes);
    assert$1(0 <= iouThreshold && iouThreshold <= 1, () => `iouThreshold must be in [0, 1], but was '${iouThreshold}'`);
    assert$1(boxes.rank === 2, () => `boxes must be a 2D tensor, but was of rank '${boxes.rank}'`);
    assert$1(boxes.shape[1] === 4, () => `boxes must have 4 columns, but 2nd dimension was ${boxes.shape[1]}`);
    assert$1(scores.rank === 1, () => 'scores must be a 1D tensor');
    assert$1(scores.shape[0] === numBoxes, () => `scores has incompatible shape with boxes. Expected ${numBoxes}, ` +
        `but was ${scores.shape[0]}`);
    assert$1(0 <= softNmsSigma && softNmsSigma <= 1, () => `softNmsSigma must be in [0, 1], but was '${softNmsSigma}'`);
    return { maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma };
}
22017
22018 /**
22019 * @license
22020 * Copyright 2020 Google LLC. All Rights Reserved.
22021 * Licensed under the Apache License, Version 2.0 (the "License");
22022 * you may not use this file except in compliance with the License.
22023 * You may obtain a copy of the License at
22024 *
22025 * http://www.apache.org/licenses/LICENSE-2.0
22026 *
22027 * Unless required by applicable law or agreed to in writing, software
22028 * distributed under the License is distributed on an "AS IS" BASIS,
22029 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22030 * See the License for the specific language governing permissions and
22031 * limitations under the License.
22032 * =============================================================================
22033 */
22034 /**
22035 * Performs non maximum suppression of bounding boxes based on
22036 * iou (intersection over union).
22037 *
22038 * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
22039 * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
22040 * the bounding box.
22041 * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
22042 * @param maxOutputSize The maximum number of boxes to be selected.
22043 * @param iouThreshold A float representing the threshold for deciding whether
22044 * boxes overlap too much with respect to IOU. Must be between [0, 1].
22045 * Defaults to 0.5 (50% box overlap).
22046 * @param scoreThreshold A threshold for deciding when to remove boxes based
22047 * on score. Defaults to -inf, which means any score is accepted.
22048 * @return A 1D tensor with the selected box indices.
22049 *
22050 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
22051 */
function nonMaxSuppression_(boxes, scores, maxOutputSize, iouThreshold = 0.5, scoreThreshold = Number.NEGATIVE_INFINITY) {
    const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression', 'float32');
    const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression', 'float32');
    // Normalize parameters (also clips maxOutputSize to the number of boxes).
    const sanitized = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold);
    const attrs = {
        maxOutputSize: sanitized.maxOutputSize,
        iouThreshold: sanitized.iouThreshold,
        scoreThreshold: sanitized.scoreThreshold,
    };
    return ENGINE.runKernel(NonMaxSuppressionV3, { boxes: $boxes, scores: $scores }, attrs);
}
const nonMaxSuppression = /* @__PURE__ */ op({ nonMaxSuppression_ });
22063
22064 /**
22065 * @license
22066 * Copyright 2019 Google LLC. All Rights Reserved.
22067 * Licensed under the Apache License, Version 2.0 (the "License");
22068 * you may not use this file except in compliance with the License.
22069 * You may obtain a copy of the License at
22070 *
22071 * http://www.apache.org/licenses/LICENSE-2.0
22072 *
22073 * Unless required by applicable law or agreed to in writing, software
22074 * distributed under the License is distributed on an "AS IS" BASIS,
22075 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22076 * See the License for the specific language governing permissions and
22077 * limitations under the License.
22078 * =============================================================================
22079 */
22080 /**
* Inserts a value into a sorted array. This method allows duplicates, meaning
* it allows inserting a duplicate value, in which case the element will be
* inserted at the lowest index of the value.
22084 * @param arr The array to modify.
22085 * @param element The element to insert.
22086 * @param comparator Optional. If no comparator is specified, elements are
22087 * compared using array_util.defaultComparator, which is suitable for Strings
22088 * and Numbers in ascending arrays. If the array contains multiple instances of
22089 * the target value, the left-most instance will be returned. To provide a
22090 * comparator, it should take 2 arguments to compare and return a negative,
22091 * zero, or a positive number.
22092 */
function binaryInsert(arr, element, comparator) {
    const searchResult = binarySearch(arr, element, comparator);
    // A negative search result encodes the insertion point as -(index + 1).
    const insertAt = searchResult >= 0 ? searchResult : -(searchResult + 1);
    arr.splice(insertAt, 0, element);
}
22098 /**
22099 * Searches the array for the target using binary search, returns the index
22100 * of the found element, or position to insert if element not found. If no
22101 * comparator is specified, elements are compared using array_
22102 * util.defaultComparator, which is suitable for Strings and Numbers in
22103 * ascending arrays. If the array contains multiple instances of the target
22104 * value, the left-most instance will be returned.
22105 * @param arr The array to be searched in.
22106 * @param target The target to be searched for.
22107 * @param comparator Should take 2 arguments to compare and return a negative,
22108 * zero, or a positive number.
22109 * @return Lowest index of the target value if found, otherwise the insertion
22110 * point where the target should be inserted, in the form of
22111 * (-insertionPoint - 1).
22112 */
/**
 * Binary-searches `arr` for `target`, falling back to the default ascending
 * comparator (suitable for numbers and strings) when none is supplied.
 * Returns the lowest matching index, or `-insertionPoint - 1` when absent.
 */
function binarySearch(arr, target, comparator) {
    const compare = comparator || defaultComparator;
    return binarySearch_(arr, target, compare);
}
22116 /**
22117 * Compares its two arguments for order.
22118 * @param a The first element to be compared.
22119 * @param b The second element to be compared.
22120 * @return A negative number, zero, or a positive number as the first
22121 * argument is less than, equal to, or greater than the second.
22122 */
/** Natural ascending ordering for numbers and strings. */
function defaultComparator(a, b) {
    if (a > b) {
        return 1;
    }
    if (a < b) {
        return -1;
    }
    return 0;
}
/**
 * Core binary search. Returns the lowest index of `target` if present,
 * otherwise `-insertionPoint - 1`, where insertionPoint is where the target
 * would be inserted to keep the array sorted.
 */
function binarySearch_(arr, target, comparator) {
    let lo = 0;
    let hi = arr.length;
    let matched = false;
    while (lo < hi) {
        // Overflow-safe midpoint.
        const mid = lo + ((hi - lo) >>> 1);
        const order = comparator(target, arr[mid]);
        if (order > 0) {
            lo = mid + 1;
        }
        else {
            // Keep narrowing toward the left-most occurrence; remember
            // whether this probe was an exact match.
            hi = mid;
            matched = !order;
        }
    }
    return matched ? lo : -lo - 1;
}
22146
22147 /**
22148 * @license
22149 * Copyright 2020 Google LLC. All Rights Reserved.
22150 * Licensed under the Apache License, Version 2.0 (the "License");
22151 * you may not use this file except in compliance with the License.
22152 * You may obtain a copy of the License at
22153 *
22154 * http://www.apache.org/licenses/LICENSE-2.0
22155 *
22156 * Unless required by applicable law or agreed to in writing, software
22157 * distributed under the License is distributed on an "AS IS" BASIS,
22158 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22159 * See the License for the specific language governing permissions and
22160 * limitations under the License.
22161 * =============================================================================
22162 */
function nonMaxSuppressionV3Impl$2(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) {
    // V3 is plain hard NMS: no Soft-NMS smoothing, indices only.
    const softNmsSigma = 0;
    return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
}
function nonMaxSuppressionV4Impl$2(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize) {
    // V4 optionally pads the output up to maxOutputSize and reports the
    // number of valid (un-padded) outputs; scores are not returned.
    const softNmsSigma = 0;
    const returnScoresTensor = false;
    const returnValidOutputs = true;
    return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, returnScoresTensor, padToMaxOutputSize, returnValidOutputs);
}
function nonMaxSuppressionV5Impl$2(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
    // V5 supports Soft-NMS (caller-provided sigma) and returns the scores.
    const returnScoresTensor = true;
    return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, returnScoresTensor);
}
/**
 * Reference CPU implementation shared by the NonMaxSuppression V3/V4/V5
 * kernels, including optional Soft-NMS score decay (Bodla et al.,
 * https://arxiv.org/abs/1704.04503).
 *
 * @param boxes Flat array of box corners, 4 floats `[y1, x1, y2, x2]` per box.
 * @param scores One score per box.
 * @param maxOutputSize Maximum number of boxes to select.
 * @param iouThreshold Candidates overlapping a selected box with IOU at or
 *     above this value are discarded outright.
 * @param scoreThreshold Candidates whose (possibly decayed) score does not
 *     exceed this value are dropped.
 * @param softNmsSigma Soft-NMS sigma; 0 reproduces plain hard NMS.
 * @param returnScoresTensor Whether to include `selectedScores` in the result.
 * @param padToMaxOutputSize Whether to zero-pad outputs up to maxOutputSize
 *     (NonMaxSuppressionV4 behavior).
 * @param returnValidOutputs Whether to include `validOutputs` (the un-padded
 *     count) in the result.
 */
function nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, returnScoresTensor = false, padToMaxOutputSize = false, returnValidOutputs = false) {
    // The list is sorted in ascending order, so that we can always pop the
    // candidate with the largest score in O(1) time.
    const candidates = [];
    for (let i = 0; i < scores.length; i++) {
        if (scores[i] > scoreThreshold) {
            candidates.push({ score: scores[i], boxIndex: i, suppressBeginIndex: 0 });
        }
    }
    candidates.sort(ascendingComparator);
    // If softNmsSigma is 0, the outcome of this algorithm is exactly same as
    // before.
    const scale = softNmsSigma > 0 ? (-0.5 / softNmsSigma) : 0.0;
    const selectedIndices = [];
    const selectedScores = [];
    while (selectedIndices.length < maxOutputSize && candidates.length > 0) {
        const candidate = candidates.pop();
        const { score: originalScore, boxIndex, suppressBeginIndex } = candidate;
        // Scores only ever decay, so once the best remaining candidate is
        // below the threshold no later candidate can exceed it.
        if (originalScore < scoreThreshold) {
            break;
        }
        // Overlapping boxes are likely to have similar scores, therefore we
        // iterate through the previously selected boxes backwards in order to
        // see if candidate's score should be suppressed. We use
        // suppressBeginIndex to track and ensure a candidate can be suppressed
        // by a selected box no more than once. Also, if the overlap exceeds
        // iouThreshold, we simply ignore the candidate.
        let ignoreCandidate = false;
        for (let j = selectedIndices.length - 1; j >= suppressBeginIndex; --j) {
            const iou = intersectionOverUnion(boxes, boxIndex, selectedIndices[j]);
            if (iou >= iouThreshold) {
                ignoreCandidate = true;
                break;
            }
            // Soft-NMS: decay the candidate's score in place based on overlap.
            candidate.score =
                candidate.score * suppressWeight(iouThreshold, scale, iou);
            if (candidate.score <= scoreThreshold) {
                break;
            }
        }
        // At this point, if `candidate.score` has not dropped below
        // `scoreThreshold`, then we know that we went through all of the
        // previous selections and can safely update `suppressBeginIndex` to the
        // end of the selected array. Then we can re-insert the candidate with
        // the updated score and suppressBeginIndex back in the candidate list.
        // If on the other hand, `candidate.score` has dropped below the score
        // threshold, we will not add it back to the candidates list.
        candidate.suppressBeginIndex = selectedIndices.length;
        if (!ignoreCandidate) {
            // Candidate has passed all the tests, and is not suppressed, so
            // select the candidate.
            if (candidate.score === originalScore) {
                selectedIndices.push(boxIndex);
                selectedScores.push(candidate.score);
            }
            else if (candidate.score > scoreThreshold) {
                // Candidate's score is suppressed but is still high enough to be
                // considered, so add back to the candidates list.
                binaryInsert(candidates, candidate, ascendingComparator);
            }
        }
    }
    // NonMaxSuppressionV4 feature: padding output to maxOutputSize.
    const validOutputs = selectedIndices.length;
    const elemsToPad = maxOutputSize - validOutputs;
    if (padToMaxOutputSize && elemsToPad > 0) {
        selectedIndices.push(...new Array(elemsToPad).fill(0));
        selectedScores.push(...new Array(elemsToPad).fill(0.0));
    }
    const result = { selectedIndices };
    if (returnScoresTensor) {
        result['selectedScores'] = selectedScores;
    }
    if (returnValidOutputs) {
        result['validOutputs'] = validOutputs;
    }
    return result;
}
/**
 * Computes the intersection-over-union of boxes `i` and `j` stored in the
 * flat `boxes` array (4 floats `[y1, x1, y2, x2]` per box). Returns 0 for
 * degenerate (zero/negative area) boxes.
 */
function intersectionOverUnion(boxes, i, j) {
    const boxI = boxes.subarray(i * 4, i * 4 + 4);
    const boxJ = boxes.subarray(j * 4, j * 4 + 4);
    // Normalize corners so (min, min) is the top-left and (max, max) the
    // bottom-right, regardless of the order the coordinates were stored in.
    const yMinI = Math.min(boxI[0], boxI[2]);
    const xMinI = Math.min(boxI[1], boxI[3]);
    const yMaxI = Math.max(boxI[0], boxI[2]);
    const xMaxI = Math.max(boxI[1], boxI[3]);
    const yMinJ = Math.min(boxJ[0], boxJ[2]);
    const xMinJ = Math.min(boxJ[1], boxJ[3]);
    const yMaxJ = Math.max(boxJ[0], boxJ[2]);
    const xMaxJ = Math.max(boxJ[1], boxJ[3]);
    const areaI = (yMaxI - yMinI) * (xMaxI - xMinI);
    const areaJ = (yMaxJ - yMinJ) * (xMaxJ - xMinJ);
    // Degenerate boxes cannot overlap anything.
    if (areaI <= 0 || areaJ <= 0) {
        return 0.0;
    }
    const overlapYMin = Math.max(yMinI, yMinJ);
    const overlapXMin = Math.max(xMinI, xMinJ);
    const overlapYMax = Math.min(yMaxI, yMaxJ);
    const overlapXMax = Math.min(xMaxI, xMaxJ);
    const overlapArea = Math.max(overlapYMax - overlapYMin, 0.0) *
        Math.max(overlapXMax - overlapXMin, 0.0);
    return overlapArea / (areaI + areaJ - overlapArea);
}
// Gaussian penalty used by Soft-NMS; always returns a value in [0, 1].
// The more two boxes overlap, the smaller the returned weight, meaning
// highly overlapping boxes are significantly penalized while disjoint
// boxes are not penalized at all. Above the hard iouThreshold the weight
// drops straight to 0.
function suppressWeight(iouThreshold, scale, iou) {
    const gaussian = Math.exp(scale * iou * iou);
    return iou <= iouThreshold ? gaussian : 0.0;
}
// Orders candidates by ascending score; ties are broken so that the larger
// boxIndex sorts first. Because candidates are popped from the END of the
// array, the candidate with the SMALLER index is popped first on a tie,
// which matches the output order of the TensorFlow python implementation.
function ascendingComparator(c1, c2) {
    const scoreDelta = c1.score - c2.score;
    // On a tie (delta is 0) fall through to the reversed index comparison.
    return scoreDelta || ((c1.score === c2.score) && (c2.boxIndex - c1.boxIndex));
}
22291
22292 /**
22293 * @license
22294 * Copyright 2020 Google LLC. All Rights Reserved.
22295 * Licensed under the Apache License, Version 2.0 (the "License");
22296 * you may not use this file except in compliance with the License.
22297 * You may obtain a copy of the License at
22298 *
22299 * http://www.apache.org/licenses/LICENSE-2.0
22300 *
22301 * Unless required by applicable law or agreed to in writing, software
22302 * distributed under the License is distributed on an "AS IS" BASIS,
22303 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22304 * See the License for the specific language governing permissions and
22305 * limitations under the License.
22306 * =============================================================================
22307 */
22308 /**
22309 * Performs non maximum suppression of bounding boxes based on
22310 * iou (intersection over union).
22311 *
22312 * This is the async version of `nonMaxSuppression`
22313 *
22314 * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
22315 * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
22316 * the bounding box.
22317 * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
22318 * @param maxOutputSize The maximum number of boxes to be selected.
22319 * @param iouThreshold A float representing the threshold for deciding whether
22320 * boxes overlap too much with respect to IOU. Must be between [0, 1].
22321 * Defaults to 0.5 (50% box overlap).
22322 * @param scoreThreshold A threshold for deciding when to remove boxes based
22323 * on score. Defaults to -inf, which means any score is accepted.
22324 * @return A 1D tensor with the selected box indices.
22325 *
22326 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
22327 */
async function nonMaxSuppressionAsync_(boxes, scores, maxOutputSize, iouThreshold = 0.5, scoreThreshold = Number.NEGATIVE_INFINITY) {
    // NOTE(review): unlike the sync version, inputs are not forced to
    // float32 here — confirm this is intentional.
    const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync');
    const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync');
    // Normalize parameters (also clips maxOutputSize to the number of boxes).
    const params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold);
    maxOutputSize = params.maxOutputSize;
    iouThreshold = params.iouThreshold;
    scoreThreshold = params.scoreThreshold;
    // Download both tensors in parallel.
    const [boxesVals, scoresVals] = await Promise.all([$boxes.data(), $scores.data()]);
    // We call a cpu based impl directly with the typedarray data here rather
    // than a kernel because all kernels are synchronous (and thus cannot await
    // .data()).
    const { selectedIndices } = nonMaxSuppressionV3Impl$2(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold);
    // Dispose only the tensors this function created during conversion.
    if ($boxes !== boxes) {
        $boxes.dispose();
    }
    if ($scores !== scores) {
        $scores.dispose();
    }
    return tensor1d(selectedIndices, 'int32');
}
const nonMaxSuppressionAsync = nonMaxSuppressionAsync_;
22351
22352 /**
22353 * @license
22354 * Copyright 2020 Google LLC. All Rights Reserved.
22355 * Licensed under the Apache License, Version 2.0 (the "License");
22356 * you may not use this file except in compliance with the License.
22357 * You may obtain a copy of the License at
22358 *
22359 * http://www.apache.org/licenses/LICENSE-2.0
22360 *
22361 * Unless required by applicable law or agreed to in writing, software
22362 * distributed under the License is distributed on an "AS IS" BASIS,
22363 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22364 * See the License for the specific language governing permissions and
22365 * limitations under the License.
22366 * =============================================================================
22367 */
22368 /**
22369 * Performs non maximum suppression of bounding boxes based on
22370 * iou (intersection over union).
22371 *
22372 * This op also supports a Soft-NMS mode (cf.
22373 * Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score
22374 * of other overlapping boxes, therefore favoring different regions of the image
22375 * with high scores. To enable this Soft-NMS mode, set the `softNmsSigma`
22376 * parameter to be larger than 0.
22377 *
22378 * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
22379 * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
22380 * the bounding box.
22381 * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
22382 * @param maxOutputSize The maximum number of boxes to be selected.
22383 * @param iouThreshold A float representing the threshold for deciding whether
22384 * boxes overlap too much with respect to IOU. Must be between [0, 1].
22385 * Defaults to 0.5 (50% box overlap).
22386 * @param scoreThreshold A threshold for deciding when to remove boxes based
22387 * on score. Defaults to -inf, which means any score is accepted.
22388 * @param softNmsSigma A float representing the sigma parameter for Soft NMS.
22389 * When sigma is 0, it falls back to nonMaxSuppression.
22390 * @return A map with the following properties:
22391 * - selectedIndices: A 1D tensor with the selected box indices.
22392 * - selectedScores: A 1D tensor with the corresponding scores for each
22393 * selected box.
22394 *
22395 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
22396 */
function nonMaxSuppressionWithScore_(boxes, scores, maxOutputSize, iouThreshold = 0.5, scoreThreshold = Number.NEGATIVE_INFINITY, softNmsSigma = 0.0) {
    const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression');
    const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression');
    // Normalize/validate all NMS parameters, including the Soft-NMS sigma.
    const params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
    const inputs = { boxes: $boxes, scores: $scores };
    const attrs = {
        maxOutputSize: params.maxOutputSize,
        iouThreshold: params.iouThreshold,
        scoreThreshold: params.scoreThreshold,
        softNmsSigma: params.softNmsSigma,
    };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    const result = ENGINE.runKernel(NonMaxSuppressionV5, inputs, attrs);
    // The kernel yields a pair: [selectedIndices, selectedScores].
    return { selectedIndices: result[0], selectedScores: result[1] };
}
const nonMaxSuppressionWithScore = /* @__PURE__ */ op({ nonMaxSuppressionWithScore_ });
22412
22413 /**
22414 * @license
22415 * Copyright 2020 Google LLC. All Rights Reserved.
22416 * Licensed under the Apache License, Version 2.0 (the "License");
22417 * you may not use this file except in compliance with the License.
22418 * You may obtain a copy of the License at
22419 *
22420 * http://www.apache.org/licenses/LICENSE-2.0
22421 *
22422 * Unless required by applicable law or agreed to in writing, software
22423 * distributed under the License is distributed on an "AS IS" BASIS,
22424 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22425 * See the License for the specific language governing permissions and
22426 * limitations under the License.
22427 * =============================================================================
22428 */
22429 /**
22430 * Asynchronously performs non maximum suppression of bounding boxes based on
22431 * iou (intersection over union).
22432 *
22433 * This op also supports a Soft-NMS mode (cf.
22434 * Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score
22435 * of other overlapping boxes, therefore favoring different regions of the image
22436 * with high scores. To enable this Soft-NMS mode, set the `softNmsSigma`
22437 * parameter to be larger than 0.
22438 *
22439 * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
22440 * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
22441 * the bounding box.
22442 * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
22443 * @param maxOutputSize The maximum number of boxes to be selected.
22444 * @param iouThreshold A float representing the threshold for deciding whether
22445 * boxes overlap too much with respect to IOU. Must be between [0, 1].
22446 * Defaults to 0.5 (50% box overlap).
22447 * @param scoreThreshold A threshold for deciding when to remove boxes based
22448 * on score. Defaults to -inf, which means any score is accepted.
22449 * @param softNmsSigma A float representing the sigma parameter for Soft NMS.
22450 * When sigma is 0, it falls back to nonMaxSuppression.
22451 * @return A map with the following properties:
22452 * - selectedIndices: A 1D tensor with the selected box indices.
22453 * - selectedScores: A 1D tensor with the corresponding scores for each
22454 * selected box.
22455 *
22456 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
22457 */
/**
 * Asynchronous Soft-NMS: downloads the box/score data and runs the CPU
 * implementation directly, since kernels are synchronous and cannot
 * await `.data()`. See the doc comment above for the parameter contract.
 */
async function nonMaxSuppressionWithScoreAsync_(boxes, scores, maxOutputSize, iouThreshold = 0.5, scoreThreshold = Number.NEGATIVE_INFINITY, softNmsSigma = 0.0) {
    const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync');
    const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync');
    // Validate and normalize the numeric parameters.
    const params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
    const [boxesVals, scoresVals] = await Promise.all([$boxes.data(), $scores.data()]);
    // We call a cpu based impl directly with the typedarray data here rather
    // than a kernel because all kernels are synchronous (and thus cannot await
    // .data()).
    const res = nonMaxSuppressionV5Impl$2(boxesVals, scoresVals, params.maxOutputSize, params.iouThreshold, params.scoreThreshold, params.softNmsSigma);
    // Dispose intermediate tensors that were created by convertToTensor.
    if ($boxes !== boxes) {
        $boxes.dispose();
    }
    if ($scores !== scores) {
        $scores.dispose();
    }
    return {
        selectedIndices: tensor1d(res.selectedIndices, 'int32'),
        selectedScores: tensor1d(res.selectedScores)
    };
}
const nonMaxSuppressionWithScoreAsync = nonMaxSuppressionWithScoreAsync_;
22485
22486 /**
22487 * @license
22488 * Copyright 2020 Google LLC. All Rights Reserved.
22489 * Licensed under the Apache License, Version 2.0 (the "License");
22490 * you may not use this file except in compliance with the License.
22491 * You may obtain a copy of the License at
22492 *
22493 * http://www.apache.org/licenses/LICENSE-2.0
22494 *
22495 * Unless required by applicable law or agreed to in writing, software
22496 * distributed under the License is distributed on an "AS IS" BASIS,
22497 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22498 * See the License for the specific language governing permissions and
22499 * limitations under the License.
22500 * =============================================================================
22501 */
22502 /**
22503 * Asynchronously performs non maximum suppression of bounding boxes based on
22504 * iou (intersection over union), with an option to pad results.
22505 *
22506 * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
22507 * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
22508 * the bounding box.
22509 * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
22510 * @param maxOutputSize The maximum number of boxes to be selected.
22511 * @param iouThreshold A float representing the threshold for deciding whether
22512 * boxes overlap too much with respect to IOU. Must be between [0, 1].
22513 * Defaults to 0.5 (50% box overlap).
22514 * @param scoreThreshold A threshold for deciding when to remove boxes based
22515 * on score. Defaults to -inf, which means any score is accepted.
22516 * @param padToMaxOutputSize Defaults to false. If true, size of output
22517 * `selectedIndices` is padded to maxOutputSize.
22518 * @return A map with the following properties:
22519 * - selectedIndices: A 1D tensor with the selected box indices.
22520 * - validOutputs: A scalar denoting how many elements in `selectedIndices`
22521 * are valid. Valid elements occur first, then padding.
22522 *
22523 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
22524 */
/**
 * Performs non maximum suppression, optionally padding `selectedIndices`
 * to `maxOutputSize` and reporting the count of valid entries.
 * See the doc comment above for the full parameter contract.
 */
function nonMaxSuppressionPadded_(boxes, scores, maxOutputSize, iouThreshold = 0.5, scoreThreshold = Number.NEGATIVE_INFINITY, padToMaxOutputSize = false) {
    const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression');
    const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression');
    // Validate and normalize the numeric parameters (no Soft-NMS here).
    const params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, null /* softNmsSigma */);
    const inputs = { boxes: $boxes, scores: $scores };
    const attrs = {
        maxOutputSize: params.maxOutputSize,
        iouThreshold: params.iouThreshold,
        scoreThreshold: params.scoreThreshold,
        padToMaxOutputSize
    };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    const [selectedIndices, validOutputs] = ENGINE.runKernel(NonMaxSuppressionV4, inputs, attrs);
    return { selectedIndices, validOutputs };
}
const nonMaxSuppressionPadded = /* @__PURE__ */ op({ nonMaxSuppressionPadded_ });
22544
22545 /**
22546 * @license
22547 * Copyright 2020 Google LLC. All Rights Reserved.
22548 * Licensed under the Apache License, Version 2.0 (the "License");
22549 * you may not use this file except in compliance with the License.
22550 * You may obtain a copy of the License at
22551 *
22552 * http://www.apache.org/licenses/LICENSE-2.0
22553 *
22554 * Unless required by applicable law or agreed to in writing, software
22555 * distributed under the License is distributed on an "AS IS" BASIS,
22556 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22557 * See the License for the specific language governing permissions and
22558 * limitations under the License.
22559 * =============================================================================
22560 */
22561 /**
22562 * Asynchronously performs non maximum suppression of bounding boxes based on
22563 * iou (intersection over union), with an option to pad results.
22564 *
22565 * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
22566 * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
22567 * the bounding box.
22568 * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
22569 * @param maxOutputSize The maximum number of boxes to be selected.
22570 * @param iouThreshold A float representing the threshold for deciding whether
22571 * boxes overlap too much with respect to IOU. Must be between [0, 1].
22572 * Defaults to 0.5 (50% box overlap).
22573 * @param scoreThreshold A threshold for deciding when to remove boxes based
22574 * on score. Defaults to -inf, which means any score is accepted.
22575 * @param padToMaxOutputSize Defaults to false. If true, size of output
22576 * `selectedIndices` is padded to maxOutputSize.
22577 * @return A map with the following properties:
22578 * - selectedIndices: A 1D tensor with the selected box indices.
22579 * - validOutputs: A scalar denoting how many elements in `selectedIndices`
22580 * are valid. Valid elements occur first, then padding.
22581 *
22582 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
22583 */
/**
 * Asynchronous padded NMS: downloads box/score data and runs the CPU
 * implementation directly, since kernels are synchronous and cannot
 * await `.data()`. See the doc comment above for the parameter contract.
 */
async function nonMaxSuppressionPaddedAsync_(boxes, scores, maxOutputSize, iouThreshold = 0.5, scoreThreshold = Number.NEGATIVE_INFINITY, padToMaxOutputSize = false) {
    const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync');
    const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync');
    // Validate and normalize the numeric parameters (no Soft-NMS here).
    const params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, null /* softNmsSigma */);
    const boxesAndScores = await Promise.all([$boxes.data(), $scores.data()]);
    // We call a cpu based impl directly with the typedarray data here rather
    // than a kernel because all kernels are synchronous (and thus cannot await
    // .data()).
    const res = nonMaxSuppressionV4Impl$2(boxesAndScores[0], boxesAndScores[1], params.maxOutputSize, params.iouThreshold, params.scoreThreshold, padToMaxOutputSize);
    // Dispose intermediate tensors that were created by convertToTensor.
    if ($boxes !== boxes) {
        $boxes.dispose();
    }
    if ($scores !== scores) {
        $scores.dispose();
    }
    return {
        selectedIndices: tensor1d(res.selectedIndices, 'int32'),
        validOutputs: scalar(res.validOutputs, 'int32')
    };
}
const nonMaxSuppressionPaddedAsync = nonMaxSuppressionPaddedAsync_;
22608
22609 /**
22610 * @license
22611 * Copyright 2020 Google LLC. All Rights Reserved.
22612 * Licensed under the Apache License, Version 2.0 (the "License");
22613 * you may not use this file except in compliance with the License.
22614 * You may obtain a copy of the License at
22615 *
22616 * http://www.apache.org/licenses/LICENSE-2.0
22617 *
22618 * Unless required by applicable law or agreed to in writing, software
22619 * distributed under the License is distributed on an "AS IS" BASIS,
22620 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22621 * See the License for the specific language governing permissions and
22622 * limitations under the License.
22623 * =============================================================================
22624 */
22625 /**
22626 * Bilinear resize a single 3D image or a batch of 3D images to a new shape.
22627 *
22628 * @param images The images, of rank 4 or rank 3, of shape
22629 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
22630 * @param size The new shape `[newHeight, newWidth]` to resize the
22631 * images to. Each channel is resized individually.
22632 * @param alignCorners Defaults to `false`. If true, rescale
22633 * input by `(new_height - 1) / (height - 1)`, which exactly aligns the 4
22634 * corners of images and resized images. If false, rescale by
22635 * `new_height / height`. Treat similarly the width dimension.
22636 * @param halfPixelCenters Defaults to `false`. Whether to assume pixel centers
22637 * are at 0.5, which would make the floating point coordinates of the top
22638 * left pixel 0.5, 0.5.
22639 *
22640 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
22641 */
/**
 * Bilinear resize of a single 3D image or a batch of 3D images to `size`.
 * Rank-3 inputs are temporarily promoted to rank 4 for the kernel and the
 * synthetic batch dimension is removed from the result.
 * See the doc comment above for the full parameter contract.
 */
function resizeBilinear_(images, size, alignCorners = false, halfPixelCenters = false) {
    const $images = convertToTensor(images, 'images', 'resizeBilinear');
    assert$1($images.rank === 3 || $images.rank === 4, () => `Error in resizeBilinear: x must be rank 3 or 4, but got ` +
        `rank ${$images.rank}.`);
    assert$1(size.length === 2, () => `Error in resizeBilinear: new shape must 2D, but got shape ` +
        `${size}.`);
    // halfPixelCenters and alignCorners are mutually exclusive.
    assert$1(halfPixelCenters === false || alignCorners === false, () => `Error in resizeBilinear: If halfPixelCenters is true, ` +
        `alignCorners must be false.`);
    let batchImages = $images;
    let reshapedTo4D = false;
    // The kernel expects rank-4 input; add a batch dimension of 1 if needed.
    if ($images.rank === 3) {
        reshapedTo4D = true;
        batchImages = reshape$3($images, [1, $images.shape[0], $images.shape[1], $images.shape[2]]);
    }
    // Fix: removed dead no-op destructuring `const [] = size;` (leftover
    // from the TypeScript build) — it had no effect.
    const inputs = { images: batchImages };
    const attrs = { alignCorners, halfPixelCenters, size };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    const res = ENGINE.runKernel(ResizeBilinear, inputs, attrs);
    // Strip the synthetic batch dimension for rank-3 callers.
    if (reshapedTo4D) {
        return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
    }
    return res;
}
const resizeBilinear$3 = /* @__PURE__ */ op({ resizeBilinear_ });
22667
22668 /**
22669 * @license
22670 * Copyright 2020 Google LLC. All Rights Reserved.
22671 * Licensed under the Apache License, Version 2.0 (the "License");
22672 * you may not use this file except in compliance with the License.
22673 * You may obtain a copy of the License at
22674 *
22675 * http://www.apache.org/licenses/LICENSE-2.0
22676 *
22677 * Unless required by applicable law or agreed to in writing, software
22678 * distributed under the License is distributed on an "AS IS" BASIS,
22679 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22680 * See the License for the specific language governing permissions and
22681 * limitations under the License.
22682 * =============================================================================
22683 */
22684 /**
22685 * NearestNeighbor resize a batch of 3D images to a new shape.
22686 *
22687 * @param images The images, of rank 4 or rank 3, of shape
22688 * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
22689 * @param size The new shape `[newHeight, newWidth]` to resize the
22690 * images to. Each channel is resized individually.
22691 * @param alignCorners Defaults to False. If true, rescale
22692 * input by `(new_height - 1) / (height - 1)`, which exactly aligns the 4
22693 * corners of images and resized images. If false, rescale by
22694 * `new_height / height`. Treat similarly the width dimension.
22695 * @param halfPixelCenters Defaults to `false`. Whether to assume pixels are of
22696 * half the actual dimensions, and yield more accurate resizes. This flag
22697 * would also make the floating point coordinates of the top left pixel
22698 * 0.5, 0.5.
22699 *
22700 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
22701 */
/**
 * Nearest-neighbor resize of a single 3D image or a batch of 3D images to
 * `size`. Rank-3 inputs are temporarily promoted to rank 4 for the kernel
 * and the synthetic batch dimension is removed from the result.
 * See the doc comment above for the full parameter contract.
 */
function resizeNearestNeighbor_(images, size, alignCorners = false, halfPixelCenters = false) {
    const $images = convertToTensor(images, 'images', 'resizeNearestNeighbor');
    assert$1($images.rank === 3 || $images.rank === 4, () => `Error in resizeNearestNeighbor: x must be rank 3 or 4, but got ` +
        `rank ${$images.rank}.`);
    assert$1(size.length === 2, () => `Error in resizeNearestNeighbor: new shape must 2D, but got shape ` +
        `${size}.`);
    assert$1($images.dtype === 'float32' || $images.dtype === 'int32', () => '`images` must have `int32` or `float32` as dtype');
    // halfPixelCenters and alignCorners are mutually exclusive.
    assert$1(halfPixelCenters === false || alignCorners === false, () => `Error in resizeNearestNeighbor: If halfPixelCenters is true, ` +
        `alignCorners must be false.`);
    let batchImages = $images;
    let reshapedTo4D = false;
    // The kernel expects rank-4 input; add a batch dimension of 1 if needed.
    if ($images.rank === 3) {
        reshapedTo4D = true;
        batchImages = reshape$3($images, [1, $images.shape[0], $images.shape[1], $images.shape[2]]);
    }
    // Fix: removed dead no-op destructuring `const [] = size;` (leftover
    // from the TypeScript build) — it had no effect.
    const inputs = { images: batchImages };
    const attrs = { alignCorners, halfPixelCenters, size };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    const res = ENGINE.runKernel(ResizeNearestNeighbor, inputs, attrs);
    // Strip the synthetic batch dimension for rank-3 callers.
    if (reshapedTo4D) {
        return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
    }
    return res;
}
const resizeNearestNeighbor$2 = /* @__PURE__ */ op({ resizeNearestNeighbor_ });
22728
22729 /**
22730 * @license
22731 * Copyright 2021 Google LLC. All Rights Reserved.
22732 * Licensed under the Apache License, Version 2.0 (the "License");
22733 * you may not use this file except in compliance with the License.
22734 * You may obtain a copy of the License at
22735 *
22736 * https://www.apache.org/licenses/LICENSE-2.0
22737 *
22738 * Unless required by applicable law or agreed to in writing, software
22739 * distributed under the License is distributed on an "AS IS" BASIS,
22740 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22741 * See the License for the specific language governing permissions and
22742 * limitations under the License.
22743 * =============================================================================
22744 */
22745 /**
22746 * Performs image binarization with corresponding threshold
 * (depends on the method) value, which creates a binary image from a grayscale.
22748 * @param image 3d tensor of shape [imageHeight,imageWidth, depth],
22749 * where imageHeight and imageWidth must be positive.The image color
22750 * range should be [0, 255].
22751 * @param method Optional string from `'binary' | 'otsu'`
22752 * which specifies the method for thresholding. Defaults to 'binary'.
 * @param inverted Optional boolean which specifies
22754 * if colours should be inverted. Defaults to false.
22755 * @param threshValue Optional number which defines threshold value from 0 to 1.
22756 * Defaults to 0.5.
22757 * @return A 3d tensor of shape [imageHeight,imageWidth, depth], which
22758 * contains binarized image.
22759 */
/**
 * Performs image binarization with the given threshold value (or one derived
 * via Otsu's method), producing a binary {0, 255} int32 image.
 * See the doc comment above for the full parameter contract.
 */
function threshold_(image, method = 'binary', inverted = false, threshValue = 0.5) {
    const $image = convertToTensor(image, 'image', 'threshold');
    /* 0.2989, 0.5870, 0.1140 are represent luma coefficients in CCIR601.
    Reference for converting between RGB and grayscale: https://en.wikipedia.org/wiki/Luma_%28video%29 */
    const RED_INTENCITY_COEF = 0.2989;
    const GREEN_INTENCITY_COEF = 0.5870;
    const BLUE_INTENCITY_COEF = 0.1140;
    const totalPixelsInImage = $image.shape[0] * $image.shape[1];
    // Scale the [0, 1] threshValue into the image's [0, 255] intensity range.
    let $threshold = mul(tensor1d([threshValue]), 255);
    let r, g, b, grayscale;
    assert$1($image.rank === 3, () => 'Error in threshold: image must be rank 3,' +
        `but got rank ${$image.rank}.`);
    assert$1($image.shape[2] === 3 || $image.shape[2] === 1, () => 'Error in threshold: ' +
        'image color channel must be equal to 3 or 1' +
        `but got ${$image.shape[2]}.`);
    assert$1($image.dtype === 'int32' || $image.dtype === 'float32', () => 'Error in dtype: image dtype must be int32 or float32,' +
        `but got dtype ${$image.dtype}.`);
    assert$1(method === 'otsu' || method === 'binary', () => `Method must be binary or otsu, but was ${method}`);
    if ($image.shape[2] === 3) {
        // RGB input: convert to luma-weighted grayscale.
        [r, g, b] = split$3($image, [1, 1, 1], -1);
        const $r = mul(r, RED_INTENCITY_COEF);
        const $g = mul(g, GREEN_INTENCITY_COEF);
        const $b = mul(b, BLUE_INTENCITY_COEF);
        grayscale = add$3(add$3($r, $g), $b);
    }
    else {
        // Single-channel input is already grayscale. Fix: use the converted
        // tensor ($image) rather than the raw `image` argument, consistent
        // with the RGB branch (downstream ops previously had to re-convert).
        grayscale = $image;
    }
    if (method === 'otsu') {
        // Build a 256-bin histogram of rounded intensities and pick the
        // threshold maximizing between-class variance.
        const $histogram = bincount$2(cast$3(round$2(grayscale), 'int32'), tensor([]), 256);
        $threshold = otsu($histogram, totalPixelsInImage);
    }
    // `inverted` flips which side of the threshold maps to 255.
    const invCondition = inverted ?
        lessEqual$2(grayscale, $threshold) : greater$3(grayscale, $threshold);
    const result = cast$3(mul(invCondition, 255), 'int32');
    return result;
}
/**
 * Computes Otsu's threshold for an intensity histogram: the bin index that
 * maximizes the between-class variance `wF * wB * (muF - muB)^2` over all
 * candidate split points.
 *
 * @param histogram 1D tensor of bin counts.
 * @param total Total number of pixels in the image.
 * @return 1D tensor holding the best threshold index (-1 if none improved).
 */
function otsu(histogram, total) {
    let bestThresh = tensor1d([-1]);
    let bestInBetVar = tensor1d([0]);
    let cInBetVar = tensor1d([0]);
    let classFirst, classSecond, meanFirst, meanSec, weightForeground, weightBack;
    for (let index = 0; index < histogram.size - 1; index++) {
        // Split the histogram at `index` into foreground/background classes.
        classFirst = slice$2(histogram, 0, index + 1);
        classSecond = slice$2(histogram, index + 1);
        weightForeground = div$1(sum$3(classFirst), total);
        weightBack = div$1(sum$3(classSecond), total);
        // Mean intensity of each class (bin index weighted by count).
        const meanFirstDivA = sum$3(mul(classFirst, range$3(0, classFirst.size)));
        meanFirst = div$1(meanFirstDivA, sum$3(classFirst));
        const meanSecFill = fill$2(classSecond.shape, classFirst.size);
        const meanSecAdd = add$3(range$3(0, classSecond.size), meanSecFill);
        const meanSecMul = mul(classSecond, (meanSecAdd));
        meanSec = div$1(sum$3(meanSecMul), sum$3(classSecond));
        // Between-class variance wF * wB * (muF - muB)^2. Fix: the mean
        // difference is computed once and squared — the original evaluated
        // the identical subtraction twice.
        const meanDiff = sub$2(meanFirst, meanSec);
        cInBetVar = mul(mul(mul(weightForeground, weightBack), meanDiff), meanDiff);
        // Track the split with the highest between-class variance so far.
        const condition = greater$3(cInBetVar, bestInBetVar);
        bestInBetVar = where(condition, cInBetVar, bestInBetVar);
        bestThresh = where(condition, tensor1d([index]), bestThresh);
    }
    return bestThresh;
}
const threshold$1 = /* @__PURE__ */ op({ threshold_ });
22824
22825 /**
22826 * @license
22827 * Copyright 2021 Google LLC. All Rights Reserved.
22828 * Licensed under the Apache License, Version 2.0 (the "License");
22829 * you may not use this file except in compliance with the License.
22830 * You may obtain a copy of the License at
22831 *
22832 * http://www.apache.org/licenses/LICENSE-2.0
22833 *
22834 * Unless required by applicable law or agreed to in writing, software
22835 * distributed under the License is distributed on an "AS IS" BASIS,
22836 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22837 * See the License for the specific language governing permissions and
22838 * limitations under the License.
22839 * =============================================================================
22840 */
22841 /**
22842 * Applies the given transform(s) to the image(s).
22843 *
22844 * @param image 4d tensor of shape `[batch, imageHeight, imageWidth, depth]`.
22845 * @param transforms Projective transform matrix/matrices. A tensor1d of length
22846 * 8 or tensor of size N x 8. If one row of transforms is [a0, a1, a2, b0,
22847 * b1, b2, c0, c1], then it maps the output point (x, y) to a transformed
22848 * input point (x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k),
22849 * where k = c0 x + c1 y + 1. The transforms are inverted compared to the
22850 * transform mapping input points to output points.
22851 * @param interpolation Interpolation mode.
22852 * Supported values: 'nearest', 'bilinear'. Default to 'nearest'.
22853 * @param fillMode Points outside the boundaries of the input are filled
22854 * according to the given mode, one of 'constant', 'reflect', 'wrap',
22855 * 'nearest'. Default to 'constant'.
22856 * 'reflect': (d c b a | a b c d | d c b a ) The input is extended by
22857 * reflecting about the edge of the last pixel.
22858 * 'constant': (k k k k | a b c d | k k k k) The input is extended by
22859 * filling all values beyond the edge with the same constant value k.
22860 * 'wrap': (a b c d | a b c d | a b c d) The input is extended by
22861 * wrapping around to the opposite edge.
22862 * 'nearest': (a a a a | a b c d | d d d d) The input is extended by
22863 * the nearest pixel.
22864 * @param fillValue A float represents the value to be filled outside the
22865 * boundaries when fillMode is 'constant'.
 * @param outputShape Output dimensions after the transform, [height, width].
 *     If undefined,
22867 * output is the same size as input image.
22868 *
22869 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
22870 */
/**
 * Applies projective transform matrix/matrices to a batch of images via the
 * Transform kernel. See the doc comment above for the parameter contract.
 */
function transform_(image, transforms, interpolation = 'nearest', fillMode = 'constant', fillValue = 0, outputShape) {
    const $image = convertToTensor(image, 'image', 'transform', 'float32');
    const $transforms = convertToTensor(transforms, 'transforms', 'transform', 'float32');
    assert$1($image.rank === 4, () => 'Error in transform: image must be rank 4,' +
        `but got rank ${$image.rank}.`);
    // Transforms must be shaped [batch, 8] or a broadcastable [1, 8].
    const transformsShapeOk = $transforms.rank === 2 &&
        ($transforms.shape[0] === $image.shape[0] ||
            $transforms.shape[0] === 1) &&
        $transforms.shape[1] === 8;
    assert$1(transformsShapeOk, () => `Error in transform: Input transform should be batch x 8 or 1 x 8`);
    assert$1(outputShape == null || outputShape.length === 2, () => 'Error in transform: outputShape must be [height, width] or null, ' +
        `but got ${outputShape}.`);
    const inputs = { image: $image, transforms: $transforms };
    const attrs = { interpolation, fillMode, fillValue, outputShape };
    return ENGINE.runKernel(Transform, inputs, attrs);
}
const transform$2 = /* @__PURE__ */ op({ transform_ });
22887
22888 /**
22889 * @license
22890 * Copyright 2020 Google LLC. All Rights Reserved.
22891 * Licensed under the Apache License, Version 2.0 (the "License");
22892 * you may not use this file except in compliance with the License.
22893 * You may obtain a copy of the License at
22894 *
22895 * http://www.apache.org/licenses/LICENSE-2.0
22896 *
22897 * Unless required by applicable law or agreed to in writing, software
22898 * distributed under the License is distributed on an "AS IS" BASIS,
22899 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22900 * See the License for the specific language governing permissions and
22901 * limitations under the License.
22902 * =============================================================================
22903 */
22904 /**
22905 * Copy a tensor setting everything outside a central band in each innermost
22906 * matrix to zero.
22907 *
22908 * The band part is computed as follows: Assume input has `k` dimensions
22909 * `[I, J, K, ..., M, N]`, then the output is a tensor with the same shape where
22910 * `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
22911 * The indicator function
22912 * `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)`
22913 * `&& (num_upper < 0 || (n-m) <= num_upper)`
22914 *
22915 * ```js
22916 * const x = tf.tensor2d([[ 0, 1, 2, 3],
22917 * [-1, 0, 1, 2],
22918 * [-2, -1, 0, 1],
22919 * [-3, -2, -1, 0]]);
22920 * let y = tf.linalg.bandPart(x, 1, -1);
22921 * y.print(); // [[ 0, 1, 2, 3],
22922 * // [-1, 0, 1, 2],
22923 * // [ 0, -1, 0, 1],
22924 * // [ 0, 0 , -1, 0]]
22925 * let z = tf.linalg.bandPart(x, 2, 1);
22926 * z.print(); // [[ 0, 1, 0, 0],
22927 * // [-1, 0, 1, 0],
22928 * // [-2, -1, 0, 1],
22929 * // [ 0, -2, -1, 0]]
22930 * ```
22931 *
22932 * @param x Rank `k` tensor
22933 * @param numLower Number of subdiagonals to keep.
22934 * If negative, keep entire lower triangle.
 * @param numUpper Number of superdiagonals to keep.
22936 * If negative, keep entire upper triangle.
22937 * @returns Rank `k` tensor of the same shape as input.
22938 * The extracted banded tensor.
22939 *
22940 * @doc {heading:'Operations', subheading:'Linear Algebra', namespace:'linalg'}
22941 */
function bandPart_(a, numLower, numUpper) {
    const $a = convertToTensor(a, 'a', 'bandPart');
    assert$1($a.rank >= 2, () => `bandPart(): Rank must be at least 2, got ${$a.rank}.`);
    const shape = $a.shape;
    // M, N are the sizes of the two innermost (matrix) dimensions.
    const [M, N] = $a.shape.slice(-2);
    let $numLower;
    let $numUpper;
    // Normalize numLower to a tensor. A negative value means "keep the whole
    // lower triangle", which is equivalent to keeping M subdiagonals.
    if (typeof numLower === 'number') {
        assert$1(numLower % 1 === 0, () => `bandPart(): numLower must be an integer, got ${numLower}.`);
        assert$1(numLower <= M, () => `bandPart(): numLower (${numLower})` +
            ` must not be greater than the number of rows (${M}).`);
        $numLower =
            convertToTensor(numLower < 0 ? M : numLower, 'numLower', 'bandPart');
    }
    else {
        assert$1(numLower.dtype === 'int32', () => `bandPart(): numLower's dtype must be an int32.`);
        // If numLower is a Scalar, checking `numLower <= M` could hurt performance,
        // but minimum(numLower, M) could avoid unexpected results.
        $numLower = where(less$3(numLower, 0), M, minimum$4(numLower, M));
    }
    // Same normalization for numUpper against the number of columns N.
    if (typeof numUpper === 'number') {
        assert$1(numUpper % 1 === 0, () => `bandPart(): numUpper must be an integer, got ${numUpper}.`);
        assert$1(numUpper <= N, () => `bandPart(): numUpper (${numUpper})` +
            ` must not be greater than the number of columns (${N}).`);
        $numUpper =
            convertToTensor(numUpper < 0 ? N : numUpper, 'numUpper', 'bandPart');
    }
    else {
        assert$1(numUpper.dtype === 'int32', () => `bandPart(): numUpper's dtype must be an int32.`);
        $numUpper = where(less$3(numUpper, 0), N, minimum$4(numUpper, N));
    }
    // Build the in-band indicator mask over an M x N index grid:
    // keep entry (m, n) iff (m - n) <= numLower && (n - m) <= numUpper.
    const i = reshape$3(range$3(0, M, 1, 'int32'), [-1, 1]);
    const j = range$3(0, N, 1, 'int32');
    const ij = sub$2(i, j);
    const inBand = logicalAnd$2(lessEqual$2(ij, $numLower), greaterEqual$2(ij, neg$2($numUpper)));
    const zero = zeros$2([M, N], $a.dtype);
    // Apply the mask to every innermost matrix, then restore the input shape.
    return reshape$3(stack(unstack(reshape$3($a, [-1, M, N]))
        .map(mat => where(inBand, mat, zero))), shape);
}
const bandPart = /* @__PURE__ */ op({ bandPart_ });
22982
22983 /**
22984 * @license
22985 * Copyright 2020 Google LLC. All Rights Reserved.
22986 * Licensed under the Apache License, Version 2.0 (the "License");
22987 * you may not use this file except in compliance with the License.
22988 * You may obtain a copy of the License at
22989 *
22990 * http://www.apache.org/licenses/LICENSE-2.0
22991 *
22992 * Unless required by applicable law or agreed to in writing, software
22993 * distributed under the License is distributed on an "AS IS" BASIS,
22994 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22995 * See the License for the specific language governing permissions and
22996 * limitations under the License.
22997 * =============================================================================
22998 */
22999 /**
23000 * Gram-Schmidt orthogonalization.
23001 *
23002 * ```js
23003 * const x = tf.tensor2d([[1, 2], [3, 4]]);
23004 * let y = tf.linalg.gramSchmidt(x);
23005 * y.print();
23006 * console.log('Orthogonalized:');
23007 * y.dot(y.transpose()).print(); // should be nearly the identity matrix.
23008 * console.log('First row direction maintained:');
23009 * const data = await y.array();
23010 * console.log(data[0][1] / data[0][0]); // should be nearly 2.
23011 * ```
23012 *
23013 * @param xs The vectors to be orthogonalized, in one of the two following
23014 * formats:
23015 * - An Array of `tf.Tensor1D`.
23016 * - A `tf.Tensor2D`, i.e., a matrix, in which case the vectors are the rows
23017 * of `xs`.
23018 * In each case, all the vectors must have the same length and the length
23019 * must be greater than or equal to the number of vectors.
23020 * @returns The orthogonalized and normalized vectors or matrix.
23021 * Orthogonalization means that the vectors or the rows of the matrix
23022 * are orthogonal (zero inner products). Normalization means that each
23023 * vector or each row of the matrix has an L2 norm that equals `1`.
23024 *
23025 * @doc {heading:'Operations', subheading:'Linear Algebra', namespace:'linalg'}
23026 */
function gramSchmidt_(xs) {
    // Tracks whether the caller passed a Tensor2D (rows = vectors) so the
    // result can be re-stacked into a matrix at the end.
    let inputIsTensor2D;
    if (Array.isArray(xs)) {
        inputIsTensor2D = false;
        assert$1(xs != null && xs.length > 0, () => 'Gram-Schmidt process: input must not be null, undefined, or ' +
            'empty');
        // All vectors must share the same length as the first one.
        const dim = xs[0].shape[0];
        for (let i = 1; i < xs.length; ++i) {
            assert$1(xs[i].shape[0] === dim, () => 'Gram-Schmidt: Non-unique lengths found in the input vectors: ' +
                `(${xs[i].shape[0]} vs. ${dim})`);
        }
    }
    else {
        inputIsTensor2D = true;
        // Split the matrix into its rows as a list of 1D tensors.
        xs = split$3(xs, xs.shape[0], 0).map(x => squeeze(x, [0]));
    }
    assert$1(xs.length <= xs[0].shape[0], () => `Gram-Schmidt: Number of vectors (${xs.length}) exceeds ` +
        `number of dimensions (${xs[0].shape[0]}).`);
    const ys = [];
    const xs1d = xs;
    for (let i = 0; i < xs.length; ++i) {
        ys.push(ENGINE.tidy(() => {
            let x = xs1d[i];
            if (i > 0) {
                // Subtract the projection of x onto each previously
                // orthonormalized vector (classical Gram-Schmidt step).
                for (let j = 0; j < i; ++j) {
                    const proj = mul(sum$3(mul(ys[j], x)), ys[j]);
                    x = sub$2(x, proj);
                }
            }
            // Normalize to unit L2 norm.
            return div$1(x, norm(x, 'euclidean'));
        }));
    }
    // Return a matrix if the input was a matrix, else the list of vectors.
    if (inputIsTensor2D) {
        return stack(ys, 0);
    }
    else {
        return ys;
    }
}
const gramSchmidt = /* @__PURE__ */ op({ gramSchmidt_ });
23067
23068 /**
23069 * @license
23070 * Copyright 2020 Google LLC. All Rights Reserved.
23071 * Licensed under the Apache License, Version 2.0 (the "License");
23072 * you may not use this file except in compliance with the License.
23073 * You may obtain a copy of the License at
23074 *
23075 * http://www.apache.org/licenses/LICENSE-2.0
23076 *
23077 * Unless required by applicable law or agreed to in writing, software
23078 * distributed under the License is distributed on an "AS IS" BASIS,
23079 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23080 * See the License for the specific language governing permissions and
23081 * limitations under the License.
23082 * =============================================================================
23083 */
23084 /**
23085 * Compute QR decomposition of m-by-n matrix using Householder transformation.
23086 *
23087 * Implementation based on
23088 * [http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf]
23089 * (http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf)
23090 *
23091 * ```js
23092 * const a = tf.tensor2d([[1, 2], [3, 4]]);
23093 * let [q, r] = tf.linalg.qr(a);
23094 * console.log('Q');
23095 * q.print();
23096 * console.log('R');
23097 * r.print();
23098 * console.log('Orthogonalized');
23099 * q.dot(q.transpose()).print() // should be nearly the identity matrix.
23100 * console.log('Reconstructed');
23101 * q.dot(r).print(); // should be nearly [[1, 2], [3, 4]];
23102 * ```
23103 *
23104 * @param x The `tf.Tensor` to be QR-decomposed. Must have rank >= 2. Suppose
23105 * it has the shape `[..., M, N]`.
23106 * @param fullMatrices An optional boolean parameter. Defaults to `false`.
23107 * If `true`, compute full-sized `Q`. If `false` (the default),
23108 * compute only the leading N columns of `Q` and `R`.
23109 * @returns An `Array` of two `tf.Tensor`s: `[Q, R]`. `Q` is a unitary matrix,
23110 * i.e., its columns all have unit norm and are mutually orthogonal.
23111 * If `M >= N`,
23112 * If `fullMatrices` is `false` (default),
23113 * - `Q` has a shape of `[..., M, N]`,
23114 * - `R` has a shape of `[..., N, N]`.
23115 * If `fullMatrices` is `true` (default),
23116 * - `Q` has a shape of `[..., M, M]`,
23117 * - `R` has a shape of `[..., M, N]`.
23118 * If `M < N`,
23119 * - `Q` has a shape of `[..., M, M]`,
23120 * - `R` has a shape of `[..., M, N]`.
23121 * @throws If the rank of `x` is less than 2.
23122 *
23123 * @doc {heading:'Operations',
23124 * subheading:'Linear Algebra',
23125 * namespace:'linalg'}
23126 */
/**
 * Entry point for QR decomposition: dispatches a rank-2 input directly to
 * `qr2d` and handles higher ranks by decomposing each 2D slice of the
 * flattened batch, then restacking the per-slice factors.
 */
function qr_(x, fullMatrices = false) {
    assert$1(x.rank >= 2, () => `qr() requires input tensor to have a rank >= 2, but got rank ${x.rank}`);
    if (x.rank === 2) {
        return qr2d(x, fullMatrices);
    }
    // Rank > 2.
    // TODO(cais): Below we split the input into individual 2D tensors,
    // perform QR decomposition on them and then stack the results back
    // together. We should explore whether this can be parallelized.
    const [numRows, numCols] = x.shape.slice(-2);
    const batchSize = x.shape.slice(0, -2).reduce((acc, d) => acc * d);
    const matrices = unstack(reshape$3(x, [batchSize, numRows, numCols]), 0);
    const qParts = [];
    const rParts = [];
    for (const matrix of matrices) {
        const [qPart, rPart] = qr2d(matrix, fullMatrices);
        qParts.push(qPart);
        rParts.push(rPart);
    }
    const q = reshape$3(stack(qParts, 0), x.shape);
    const r = reshape$3(stack(rParts, 0), x.shape);
    return [q, r];
}
/**
 * QR decomposition of a single 2D matrix using Householder reflections.
 * Returns `[q, r]`; when `fullMatrices` is false and m > n, the thin
 * factorization is returned (Q: [m, n], R: [n, n]).
 */
function qr2d(x, fullMatrices = false) {
    return ENGINE.tidy(() => {
        assert$1(x.shape.length === 2, () => `qr2d() requires a 2D Tensor, but got a ${x.shape.length}D Tensor.`);
        const m = x.shape[0];
        const n = x.shape[1];
        let q = eye(m); // Orthogonal transform so far.
        let r = clone(x); // Transformed matrix so far.
        const one2D = tensor2d([[1]], [1, 1]);
        let w = clone(one2D);
        // One Householder step per column of the factorization.
        const iters = m >= n ? n : m;
        for (let j = 0; j < iters; ++j) {
            // This tidy within the for-loop ensures we clean up temporary
            // tensors as soon as they are no longer needed.
            const rTemp = r;
            const wTemp = w;
            const qTemp = q;
            [w, r, q] = ENGINE.tidy(() => {
                // Find H = I - tau * w * w', to put zeros below R(j, j).
                const rjEnd1 = slice$2(r, [j, j], [m - j, 1]);
                const normX = norm(rjEnd1);
                const rjj = slice$2(r, [j, j], [1, 1]);
                // The sign() function returns 0 on 0, which causes division by zero.
                const s = where(greater$3(rjj, 0), tensor2d([[-1]]), tensor2d([[1]]));
                const u1 = sub$2(rjj, mul(s, normX));
                const wPre = div$1(rjEnd1, u1);
                // Householder vector w with its leading entry fixed to 1.
                if (wPre.shape[0] === 1) {
                    w = clone(one2D);
                }
                else {
                    w = concat$2([
                        one2D,
                        slice$2(wPre, [1, 0], [wPre.shape[0] - 1, wPre.shape[1]])
                    ], 0);
                }
                const tau = neg$2(div$1(matMul$1(s, u1), normX));
                // -- R := HR, Q := QH.
                const rjEndAll = slice$2(r, [j, 0], [m - j, n]);
                const tauTimesW = mul(tau, w);
                const wT = transpose$2(w);
                if (j === 0) {
                    r = sub$2(rjEndAll, matMul$1(tauTimesW, matMul$1(wT, rjEndAll)));
                }
                else {
                    // Rows above j are already in final form; update only the rest.
                    const rTimesTau = sub$2(rjEndAll, matMul$1(tauTimesW, matMul$1(wT, rjEndAll)));
                    r = concat$2([slice$2(r, [0, 0], [j, n]), rTimesTau], 0);
                }
                const tawTimesWT = transpose$2(tauTimesW);
                const qAllJEnd = slice$2(q, [0, j], [m, q.shape[1] - j]);
                if (j === 0) {
                    q = sub$2(qAllJEnd, matMul$1(matMul$1(qAllJEnd, w), tawTimesWT));
                }
                else {
                    // Columns left of j are already final; update only the rest.
                    const qTimesTau = sub$2(qAllJEnd, matMul$1(matMul$1(qAllJEnd, w), tawTimesWT));
                    q = concat$2([slice$2(q, [0, 0], [m, j]), qTimesTau], 1);
                }
                return [w, r, q];
            });
            // The previous iteration's w, r, q are no longer referenced;
            // dispose them explicitly since they escaped the inner tidy.
            dispose([rTemp, wTemp, qTemp]);
        }
        if (!fullMatrices && m > n) {
            // Thin factorization: keep only the leading n columns of Q and the
            // top n-by-n block of R.
            q = slice$2(q, [0, 0], [m, n]);
            r = slice$2(r, [0, 0], [n, n]);
        }
        return [q, r];
    });
}
const qr = /* @__PURE__ */ op({ qr_ });
23222
23223 /**
23224 * @license
23225 * Copyright 2020 Google LLC. All Rights Reserved.
23226 * Licensed under the Apache License, Version 2.0 (the "License");
23227 * you may not use this file except in compliance with the License.
23228 * You may obtain a copy of the License at
23229 *
23230 * http://www.apache.org/licenses/LICENSE-2.0
23231 *
23232 * Unless required by applicable law or agreed to in writing, software
23233 * distributed under the License is distributed on an "AS IS" BASIS,
23234 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23235 * See the License for the specific language governing permissions and
23236 * limitations under the License.
23237 * =============================================================================
23238 */
exports.Reduction = void 0;
// Enum of supported loss reductions; maps names to ordinals and back
// (NONE=0, MEAN=1, SUM=2, SUM_BY_NONZERO_WEIGHTS=3).
(function (Reduction) {
    ['NONE', 'MEAN', 'SUM', 'SUM_BY_NONZERO_WEIGHTS'].forEach((name, ordinal) => {
        Reduction[Reduction[name] = ordinal] = name;
    });
})(exports.Reduction || (exports.Reduction = {}));
23246
23247 /**
23248 * Computes the weighted loss between two tensors.
23249 *
23250 * @param losses Tensor of shape `[batch_size, d1, ..., dN]`.
23251 * @param weights Tensor whose rank is either 0, or the same rank as
23252 * `losses`, and must be broadcastable to `losses` (i.e., all
23253 * dimensions must be either `1`, or the same as the corresponding
23254 * `losses` dimension).
23255 *
23256 * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
23257 */
/**
 * Computes the weighted loss between two tensors.
 *
 * @param losses Tensor of losses.
 * @param weights Optional tensor broadcastable to `losses`.
 * @param reduction How to reduce the weighted losses to a single value.
 */
function computeWeightedLoss_(losses, weights, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
    const $losses = convertToTensor(losses, 'losses', 'computeWeightedLoss');
    let $weights = null;
    if (weights != null) {
        $weights = convertToTensor(weights, 'weights', 'computeWeightedLoss');
    }
    const weightedLoss = ($weights == null) ? $losses : mul($losses, $weights);
    switch (reduction) {
        case exports.Reduction.NONE:
            return weightedLoss;
        case exports.Reduction.SUM:
            return sum$3(weightedLoss);
        case exports.Reduction.MEAN: {
            if ($weights == null) {
                return mean$3(weightedLoss);
            }
            // When weights were broadcast across losses, normalize by the
            // broadcast factor so the result is a true mean.
            const broadcastFactor = $losses.size / $weights.size;
            const quotient = div$1(sum$3(weightedLoss), sum$3($weights));
            return broadcastFactor > 1 ? div$1(quotient, scalar(broadcastFactor)) :
                quotient;
        }
        case exports.Reduction.SUM_BY_NONZERO_WEIGHTS: {
            if ($weights == null) {
                return div$1(sum$3(weightedLoss), scalar($losses.size));
            }
            // Count weights that are non-zero after broadcasting to the loss
            // shape, and divide by that count.
            const broadcastedWeights = mul($weights, ones$1($losses.shape));
            const numNonZeros = cast$3(sum$3(notEqual$2(broadcastedWeights, scalar(0))), 'float32');
            return div$1(sum$3(weightedLoss), numNonZeros);
        }
        default:
            throw Error(`Unknown reduction: ${reduction}`);
    }
}
const computeWeightedLoss$1 = /* @__PURE__ */ op({ computeWeightedLoss_ });
23295
23296 /**
23297 * @license
23298 * Copyright 2020 Google LLC. All Rights Reserved.
23299 * Licensed under the Apache License, Version 2.0 (the "License");
23300 * you may not use this file except in compliance with the License.
23301 * You may obtain a copy of the License at
23302 *
23303 * http://www.apache.org/licenses/LICENSE-2.0
23304 *
23305 * Unless required by applicable law or agreed to in writing, software
23306 * distributed under the License is distributed on an "AS IS" BASIS,
23307 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23308 * See the License for the specific language governing permissions and
23309 * limitations under the License.
23310 * =============================================================================
23311 */
23312 /**
23313 * Computes the absolute difference loss between two tensors.
23314 *
23315 * @param labels The ground truth output tensor, same dimensions as
23316 * 'predictions'.
23317 * @param predictions The predicted outputs.
23318 * @param weights Tensor whose rank is either 0, or the same rank as
23319 * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
23320 * must be either `1`, or the same as the corresponding `losses`
23321 * dimension).
23322 * @param reduction Type of reduction to apply to loss. Should be of type
23323 * `Reduction`
23324 *
23325 * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
23326 */
/**
 * Computes the absolute difference loss |labels - predictions|, optionally
 * weighted and reduced.
 */
function absoluteDifference_(labels, predictions, weights, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
    const $labels = convertToTensor(labels, 'labels', 'absoluteDifference');
    const $predictions = convertToTensor(predictions, 'predictions', 'absoluteDifference');
    const $weights =
        weights == null ? null : convertToTensor(weights, 'weights', 'absoluteDifference');
    assertShapesMatch($labels.shape, $predictions.shape, 'Error in absoluteDifference: ');
    // Element-wise |labels - predictions|.
    const elementLosses = abs$2(sub$2($labels, $predictions));
    return computeWeightedLoss$1(elementLosses, $weights, reduction);
}
const absoluteDifference = /* @__PURE__ */ op({ absoluteDifference_ });
23339
23340 /**
23341 * Computes the cosine distance loss between two tensors.
23342 *
23343 * @param labels The ground truth output tensor, same dimensions as
23344 * 'predictions'.
23345 * @param predictions The predicted outputs.
23346 * @param axis The dimension along which the cosine distance is computed.
23347 * @param weights Tensor whose rank is either 0, or the same rank as
23348 * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
23349 * must be either `1`, or the same as the corresponding `losses`
23350 * dimension).
23351 * @param reduction Type of reduction to apply to loss. Should be of type
23352 * `Reduction`
23353 *
23354 * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
23355 */
/**
 * Computes the cosine distance loss 1 - sum(labels * predictions, axis),
 * optionally weighted and reduced.
 */
function cosineDistance_(labels, predictions, axis, weights, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
    const $labels = convertToTensor(labels, 'labels', 'cosineDistance');
    const $predictions = convertToTensor(predictions, 'predictions', 'cosineDistance');
    const $weights =
        weights == null ? null : convertToTensor(weights, 'weights', 'cosineDistance');
    assertShapesMatch($labels.shape, $predictions.shape, 'Error in cosineDistance: ');
    const one = scalar(1);
    // 1 - sum(labels * predictions) along `axis`, keeping the reduced dim.
    const elementLosses = sub$2(one, sum$3(mul($labels, $predictions), axis, true));
    return computeWeightedLoss$1(elementLosses, $weights, reduction);
}
const cosineDistance = /* @__PURE__ */ op({ cosineDistance_ });
23369
23370 /**
23371 * Computes the Hinge loss between two tensors.
23372 *
23373 * @param labels The ground truth output tensor, same dimensions as
23374 * 'predictions'.
23375 * @param predictions The predicted outputs.
23376 * @param weights Tensor whose rank is either 0, or the same rank as
23377 * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
23378 * must be either `1`, or the same as the corresponding `losses`
23379 * dimension).
23380 * @param reduction Type of reduction to apply to loss. Should be of type
23381 * `Reduction`
23382 *
23383 * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
23384 */
/**
 * Computes the hinge loss max(0, 1 - y * y_hat) with binary {0, 1} labels
 * remapped to {-1, 1}, optionally weighted and reduced.
 */
function hingeLoss_(labels, predictions, weights, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
    const $labels = convertToTensor(labels, 'labels', 'hingeLoss');
    const $predictions = convertToTensor(predictions, 'predictions', 'hingeLoss');
    const $weights =
        weights == null ? null : convertToTensor(weights, 'weights', 'hingeLoss');
    assertShapesMatch($labels.shape, $predictions.shape, 'Error in hingeLoss: ');
    const one = scalar(1);
    // Convert binary labels to (-1, 1)
    const signedLabels = sub$2(mul(scalar(2), $labels), one);
    // max(0, 1 - y * y_hat)
    const elementLosses = relu$2(sub$2(one, mul(signedLabels, $predictions)));
    return computeWeightedLoss$1(elementLosses, $weights, reduction);
}
const hingeLoss = /* @__PURE__ */ op({ hingeLoss_ });
23400
23401 /**
23402 * @license
23403 * Copyright 2020 Google LLC. All Rights Reserved.
23404 * Licensed under the Apache License, Version 2.0 (the "License");
23405 * you may not use this file except in compliance with the License.
23406 * You may obtain a copy of the License at
23407 *
23408 * http://www.apache.org/licenses/LICENSE-2.0
23409 *
23410 * Unless required by applicable law or agreed to in writing, software
23411 * distributed under the License is distributed on an "AS IS" BASIS,
23412 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23413 * See the License for the specific language governing permissions and
23414 * limitations under the License.
23415 * =============================================================================
23416 */
23417 /**
23418 * Computes the Huber loss between two tensors.
23419 *
23420 * @param labels The ground truth output tensor, same dimensions as
23421 * 'predictions'.
23422 * @param predictions The predicted outputs.
23423 * @param weights Tensor whose rank is either 0, or the same rank as
23424 * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
23425 * must be either `1`, or the same as the corresponding `losses`
23426 * dimension).
23427 * @param delta Point where Huber loss changes from quadratic to linear.
23428 * @param reduction Type of reduction to apply to loss. Should be of type
23429 * `Reduction`.
23430 *
23431 * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
23432 */
/**
 * Computes the Huber loss: quadratic for residuals up to `delta`, linear
 * beyond it; optionally weighted and reduced.
 */
function huberLoss_(labels, predictions, weights, delta = 1.0, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
    const $labels = convertToTensor(labels, 'labels', 'huberLoss');
    const $predictions = convertToTensor(predictions, 'predictions', 'huberLoss');
    const $weights =
        weights == null ? null : convertToTensor(weights, 'weights', 'huberLoss');
    assertShapesMatch($labels.shape, $predictions.shape, 'Error in huberLoss: ');
    const deltaScalar = scalar(delta);
    // Split the absolute residual into a quadratic part (<= delta) and a
    // linear remainder (> delta).
    const residual = abs$2(sub$2($predictions, $labels));
    const quadratic = minimum$4(residual, deltaScalar);
    const linear = sub$2(residual, quadratic);
    // 0.5 * quadratic^2 + delta * linear
    const elementLosses = add$3(mul(scalar(0.5), square$2(quadratic)), mul(deltaScalar, linear));
    return computeWeightedLoss$1(elementLosses, $weights, reduction);
}
const huberLoss = /* @__PURE__ */ op({ huberLoss_ });
23449
23450 /**
23451 * @license
23452 * Copyright 2020 Google LLC. All Rights Reserved.
23453 * Licensed under the Apache License, Version 2.0 (the "License");
23454 * you may not use this file except in compliance with the License.
23455 * You may obtain a copy of the License at
23456 *
23457 * http://www.apache.org/licenses/LICENSE-2.0
23458 *
23459 * Unless required by applicable law or agreed to in writing, software
23460 * distributed under the License is distributed on an "AS IS" BASIS,
23461 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23462 * See the License for the specific language governing permissions and
23463 * limitations under the License.
23464 * =============================================================================
23465 */
23466 /**
23467 * Computes the log loss between two tensors.
23468 *
23469 * @param labels The ground truth output tensor, same dimensions as
23470 * 'predictions'.
23471 * @param predictions The predicted outputs.
23472 * @param weights Tensor whose rank is either 0, or the same rank as
23473 * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
23474 * must be either `1`, or the same as the corresponding `losses`
23475 * dimension).
23476 * @param epsilon A small increment to avoid taking log of zero
23477 * @param reduction Type of reduction to apply to loss. Should be of type
23478 * `Reduction`
23479 *
23480 * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
23481 */
/**
 * Computes the log (cross-entropy) loss
 * -y*log(p + eps) - (1 - y)*log(1 - p + eps), optionally weighted and
 * reduced. `epsilon` guards against log(0).
 */
function logLoss_(labels, predictions, weights, epsilon = 1e-7, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
    const $labels = convertToTensor(labels, 'labels', 'logLoss');
    const $predictions = convertToTensor(predictions, 'predictions', 'logLoss');
    const $weights =
        weights == null ? null : convertToTensor(weights, 'weights', 'logLoss');
    assertShapesMatch($labels.shape, $predictions.shape, 'Error in logLoss: ');
    const one = scalar(1);
    const epsilonScalar = scalar(epsilon);
    // -y * log(p + eps)
    const positiveTerm = neg$2(mul($labels, log$2(add$3($predictions, epsilonScalar))));
    // (1 - y) * log(1 - p + eps)
    const negativeTerm = mul(sub$2(one, $labels), log$2(add$3(sub$2(one, $predictions), epsilonScalar)));
    const elementLosses = sub$2(positiveTerm, negativeTerm);
    return computeWeightedLoss$1(elementLosses, $weights, reduction);
}
const logLoss = /* @__PURE__ */ op({ logLoss_ });
23498
23499 /**
23500 * @license
23501 * Copyright 2020 Google LLC. All Rights Reserved.
23502 * Licensed under the Apache License, Version 2.0 (the "License");
23503 * you may not use this file except in compliance with the License.
23504 * You may obtain a copy of the License at
23505 *
23506 * http://www.apache.org/licenses/LICENSE-2.0
23507 *
23508 * Unless required by applicable law or agreed to in writing, software
23509 * distributed under the License is distributed on an "AS IS" BASIS,
23510 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23511 * See the License for the specific language governing permissions and
23512 * limitations under the License.
23513 * =============================================================================
23514 */
23515 /**
23516 * Computes the mean squared error between two tensors.
23517 *
23518 * @param labels The ground truth output tensor, same dimensions as
23519 * 'predictions'.
23520 * @param predictions The predicted outputs.
23521 * @param weights Tensor whose rank is either 0, or the same rank as
23522 * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
23523 * must be either `1`, or the same as the corresponding `losses`
23524 * dimension).
23525 * @param reduction Type of reduction to apply to loss. Should be of type
23526 * `Reduction`
23527 *
23528 * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
23529 */
/**
 * Computes the mean squared error (labels - predictions)^2, optionally
 * weighted and reduced.
 */
function meanSquaredError_(labels, predictions, weights, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
    const $labels = convertToTensor(labels, 'labels', 'meanSquaredError');
    const $predictions = convertToTensor(predictions, 'predictions', 'meanSquaredError');
    const $weights =
        weights == null ? null : convertToTensor(weights, 'weights', 'meanSquaredError');
    assertShapesMatch($labels.shape, $predictions.shape, 'Error in meanSquaredError: ');
    // Element-wise squared difference.
    const elementLosses = squaredDifference$2($labels, $predictions);
    return computeWeightedLoss$1(elementLosses, $weights, reduction);
}
const meanSquaredError$2 = /* @__PURE__ */ op({ meanSquaredError_ });
23542
23543 /**
23544 * @license
23545 * Copyright 2020 Google LLC. All Rights Reserved.
23546 * Licensed under the Apache License, Version 2.0 (the "License");
23547 * you may not use this file except in compliance with the License.
23548 * You may obtain a copy of the License at
23549 *
23550 * http://www.apache.org/licenses/LICENSE-2.0
23551 *
23552 * Unless required by applicable law or agreed to in writing, software
23553 * distributed under the License is distributed on an "AS IS" BASIS,
23554 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23555 * See the License for the specific language governing permissions and
23556 * limitations under the License.
23557 * =============================================================================
23558 */
/**
 * Element-wise sigmoid cross entropy from unscaled logits.
 *
 * For x = logits, z = labels, the logistic loss
 *   z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
 * algebraically reduces to
 *   x - x * z + log(1 + exp(-x)).
 * To stay numerically stable for both signs of x (avoiding overflow in
 * exp(-x) when x < 0), the equivalent formulation
 *   max(x, 0) - x * z + log(1 + exp(-abs(x)))
 * is computed instead.
 */
function sigmoidCrossEntropyWithLogits_(labels, logits) {
    const $labels = convertToTensor(labels, 'labels', 'sigmoidCrossEntropyWithLogits');
    const $logits = convertToTensor(logits, 'logits', 'sigmoidCrossEntropyWithLogits');
    assertShapesMatch($labels.shape, $logits.shape, 'Error in sigmoidCrossEntropyWithLogits: ');
    const positivePart = relu$2($logits);            // max(x, 0)
    const logitsTimesLabels = mul($logits, $labels); // x * z
    const stableLogTerm = log1p$2(exp$2(neg$2(abs$2($logits)))); // log(1 + exp(-|x|))
    return add$3(sub$2(positivePart, logitsTimesLabels), stableLogTerm);
}
23588 /**
23589 * Computes the sigmoid cross entropy loss between two tensors.
23590 *
23591 * If labelSmoothing is nonzero, smooth the labels towards 1/2:
23592 *
23593 * newMulticlassLabels = multiclassLabels * (1 - labelSmoothing)
23594 * + 0.5 * labelSmoothing
23595 *
23596 * @param multiClassLabels The ground truth output tensor of shape
23597 * [batch_size, num_classes], same dimensions as 'predictions'.
23598 * @param logits The predicted outputs.
23599 * @param weights Tensor whose rank is either 0, or the same rank as
23600 * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
23601 * must be either `1`, or the same as the corresponding `losses`
23602 * dimension).
23603 * @param labelSmoothing If greater than 0, then smooth the labels.
23604 * @param reduction Type of reduction to apply to loss. Should be of type
23605 * `Reduction`
23606 *
23607 * @doc { heading: 'Training', subheading: 'Losses', namespace: 'losses' }
23608 */
/**
 * Computes the sigmoid cross entropy loss between labels and logits, with
 * optional label smoothing towards 1/2, weighting, and reduction.
 */
function sigmoidCrossEntropy_(multiClassLabels, logits, weights, labelSmoothing = 0, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
    let $multiClassLabels = convertToTensor(multiClassLabels, 'multiClassLabels', 'sigmoidCrossEntropy');
    const $logits = convertToTensor(logits, 'logits', 'sigmoidCrossEntropy');
    const $weights =
        weights == null ? null : convertToTensor(weights, 'weights', 'sigmoidCrossEntropy');
    assertShapesMatch($multiClassLabels.shape, $logits.shape, 'Error in sigmoidCrossEntropy: ');
    if (labelSmoothing > 0) {
        // Pull labels towards 1/2: labels * (1 - s) + 0.5 * s.
        const smoothing = scalar(labelSmoothing);
        const one = scalar(1);
        const half = scalar(0.5);
        $multiClassLabels =
            add$3(mul($multiClassLabels, sub$2(one, smoothing)), mul(half, smoothing));
    }
    const elementLosses = sigmoidCrossEntropyWithLogits_($multiClassLabels, $logits);
    return computeWeightedLoss$1(elementLosses, $weights, reduction);
}
const sigmoidCrossEntropy = /* @__PURE__ */ op({ sigmoidCrossEntropy_ });
23628
23629 /**
23630 * @license
23631 * Copyright 2020 Google LLC. All Rights Reserved.
23632 * Licensed under the Apache License, Version 2.0 (the "License");
23633 * you may not use this file except in compliance with the License.
23634 * You may obtain a copy of the License at
23635 *
23636 * http://www.apache.org/licenses/LICENSE-2.0
23637 *
23638 * Unless required by applicable law or agreed to in writing, software
23639 * distributed under the License is distributed on an "AS IS" BASIS,
23640 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23641 * See the License for the specific language governing permissions and
23642 * limitations under the License.
23643 * =============================================================================
23644 */
23645 /**
23646 * Computes softmax cross entropy between logits and labels.
23647 *
23648 * Measures the probability error in discrete classification tasks in which
23649 * the classes are mutually exclusive (each entry is in exactly one class).
23650 * For example, each CIFAR-10 image is labeled with one and only one label: an
23651 * image can be a dog or a truck, but not both.
23652 *
23653 * `NOTE`: While the classes are mutually exclusive, their probabilities need
23654 * not be. All that is required is that each row of labels is a valid
23655 * probability distribution. If they are not, the computation of the gradient
23656 * will be incorrect.
23657 *
23658 * `WARNING`: This op expects unscaled logits, since it performs a softmax on
23659 * logits internally for efficiency. Do not call this op with the output of
23660 * softmax, as it will produce incorrect results.
23661 *
23662 * logits and labels must have the same shape, e.g. [batch_size, num_classes]
23663 * and the same dtype.
23664 * @param labels The labels array.
23665 * @param logits The logits array.
23666 * @param dim The dimension softmax would be performed on. Defaults to `-1`
23667 * which indicates the last dimension.
23668 */
/**
 * Softmax cross entropy from unscaled logits along dimension `dim` (only the
 * last dimension is supported). Uses a custom gradient built around the
 * log-sum-exp trick for numerical stability.
 */
function softmaxCrossEntropyWithLogits_(labels, logits, dim = -1) {
    // Normalize the default -1 to the concrete last-dimension index.
    if (dim === -1) {
        dim = logits.rank - 1;
    }
    if (dim !== logits.rank - 1) {
        throw Error(`Softmax cross entropy along a non-last dimension is not yet ` +
            `supported. Labels / logits was rank ${logits.rank} ` +
            `and dim was ${dim}`);
    }
    // Use a custom gradient for numerical stability.
    // Note: the callback parameters deliberately shadow the outer
    // `labels`/`logits` — customGrad passes the concrete tensors through.
    const customOp = customGrad((labels, logits, save) => {
        // Reference:
        // 1. http://cs231n.github.io/linear-classify/#softmax
        // 2. https://blog.feedly.com/tricks-of-the-trade-logsumexp/
        const keepDims = true;
        const lse = logSumExp(logits, [dim], keepDims);
        // log(softmax(logits)) computed stably as logits - logSumExp(logits).
        const logResult = sub$2(cast$3(logits, 'float32'), lse);
        // Save tensors needed by the gradient function.
        save([labels, logResult]);
        // Cross entropy: -sum(labels * log(softmax(logits))) over `dim`.
        const costVector = neg$2(mul(logResult, labels));
        const value = sum$3(costVector, [dim]);
        const gradFunc = (dy, saved) => {
            const [labels, logResult] = saved;
            // Re-expand dy so it broadcasts over the reduced dimension.
            const dyShape = expandShapeToKeepDim(dy.shape, [dim]);
            return [
                // d/d(labels): dy * (labels - softmax(logits))
                mul(reshape$3(dy, dyShape), sub$2(cast$3(labels, 'float32'), exp$2(logResult))),
                // d/d(logits): dy * (softmax(logits) - labels)
                mul(reshape$3(dy, dyShape), sub$2(exp$2(logResult), cast$3(labels, 'float32'))),
            ];
        };
        return { value, gradFunc };
    });
    return customOp(labels, logits);
}
23701 /**
23702 * Computes the softmax cross entropy loss between two tensors.
23703 *
23704 * If labelSmoothing is nonzero, smooth the labels towards 1/2:
23705 *
23706 * newOnehotLabels = onehotLabels * (1 - labelSmoothing)
23707 * + labelSmoothing / numClasses
23708 *
23709 * @param onehotLabels One hot encoded labels
23710 * [batch_size, num_classes], same dimensions as 'predictions'.
23711 * @param logits The predicted outputs.
23712 * @param weights Tensor whose rank is either 0, or 1, and must be
23713 * broadcastable to `loss` of shape [batch_size]
23714 * @param labelSmoothing If greater than 0, then smooth the labels.
23715 * @param reduction Type of reduction to apply to loss. Should be of type
23716 * `Reduction`
23717 *
23718 * @doc { heading: 'Training', subheading: 'Losses', namespace: 'losses' }
23719 */
/**
 * Computes the softmax cross entropy loss between one-hot labels and logits,
 * with optional label smoothing, weighting, and reduction.
 */
function softmaxCrossEntropy_(onehotLabels, logits, weights, labelSmoothing = 0, reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
    let $onehotLabels = convertToTensor(onehotLabels, 'onehotLabels', 'softmaxCrossEntropy');
    const $logits = convertToTensor(logits, 'logits', 'softmaxCrossEntropy');
    const $weights =
        weights == null ? null : convertToTensor(weights, 'weights', 'softmaxCrossEntropy');
    assertShapesMatch($onehotLabels.shape, $logits.shape, 'Error in softmaxCrossEntropy: ');
    if (labelSmoothing > 0) {
        // Smooth one-hot labels: labels * (1 - s) + s / numClasses.
        const smoothing = scalar(labelSmoothing);
        const one = scalar(1);
        const numClasses = scalar($onehotLabels.shape[1]);
        $onehotLabels =
            add$3(mul($onehotLabels, sub$2(one, smoothing)), div$1(smoothing, numClasses));
    }
    const elementLosses = softmaxCrossEntropyWithLogits_($onehotLabels, $logits);
    return computeWeightedLoss$1(elementLosses, $weights, reduction);
}
const softmaxCrossEntropy = /* @__PURE__ */ op({ softmaxCrossEntropy_ });
23739
23740 /**
23741 * @license
23742 * Copyright 2021 Google LLC. All Rights Reserved.
23743 * Licensed under the Apache License, Version 2.0 (the "License");
23744 * you may not use this file except in compliance with the License.
23745 * You may obtain a copy of the License at
23746 *
23747 * http://www.apache.org/licenses/LICENSE-2.0
23748 *
23749 * Unless required by applicable law or agreed to in writing, software
23750 * distributed under the License is distributed on an "AS IS" BASIS,
23751 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23752 * See the License for the specific language governing permissions and
23753 * limitations under the License.
23754 * =============================================================================
23755 */
23756 /**
23757 * The input SparseTensor is represented via the map of inputs {`indices`,
23758 * `values`, `denseShape`}. The output SparseTensor has the same `denseShape`
23759 * but with indices `outputIndices` and values `outputValues`. This op inserts a
23760 * single entry for every row that doesn't have any values. The index is created
23761 * as `[row, 0, ..., 0]` and the inserted value is `defaultValue`.
23762 *
23763 * For example, suppose `spInput` has shape [5, 6] and non-empty values:
23764 * [0, 1]: a
23765 * [0, 3]: b
23766 * [2, 0]: c
23767 * [3, 1]: d
23768 *
23769 * Rows 1 and 4 are empty, so the output will be of shape [5, 6] with values:
23770 * [0, 1]: a
23771 * [0, 3]: b
23772 * [1, 0]: `defaultValue`
23773 * [2, 0]: c
23774 * [3, 1]: d
23775 * [4, 0]: `defaultValue`
23776 *
23777 * The output SparseTensor will be in row-major order and will have the same
23778 * shape as the input.
23779 *
23780 * This op also returns an indicator vector shaped [dense_shape[0]] such that
23781 * emptyRowIndicator[i] = True iff row i was an empty row.
23782 *
23783 * And a reverse index map vector shaped [indices.shape[0]] that is used during
23784 * backpropagation, reverseIndexMap[i] = outi s.t. indices[i, j] ==
23785 * outputIndices[outi, j] for all j
23786 *
23787 * ```js
23788 * const result = tf.sparse.sparseFillEmptyRows(
23789 * [[0, 0], [1, 0], [1, 3], [1, 4], [3, 2], [3, 3]],
23790 * [0, 10, 13, 14, 32, 33], [5, 6], -1);
23791 * console.log(result);
23792 * result['outputIndices'].print(); // [[0, 0], [1, 0], [1, 3], [1, 4],
23793 * // [2, 0], [3, 2], [3, 3], [4, 0]]
23794 * result['outputValues'].print(); // [0, 10, 13, 14,-1, 32, 33, -1]
23795 * result['emptyRowIndicator'].print(); // [false, false, true, false, true]
23796 * result['reverseIndexMap'].print(); // [0, 1, 2, 3, 5, 6]
23797 * ```
23798 * @param indices: 2-D. The indices of the sparse tensor.
23799 * @param values: 1-D. The values of the sparse tensor.
23800 * @param denseShape: 1-D. The shape of the sparse tensor.
23801 * @param defaultValue: 0-D. Default value to insert into location [row, 0, ...,
23802 * 0] for rows missing from the input sparse tensor.
23803 * @return A map with the following properties:
23804 * - outputIndices
23805 * - outputValues: 1-D. The values of the filled sparse tensor.
23806 * - emptyRowIndicator: 1-D. Whether the dense row was missing in the input
23807 * sparse tensor.
23808 * - reverseIndexMap: 1-D. A map from the input indices to the output
23809 * indices.
23810 * @doc {heading: 'Operations', subheading: 'Sparse'}
23811 */
23812 function sparseFillEmptyRows_(indices, values, denseShape, defaultValue) {
23813 const $indices = convertToTensor(indices, 'indices', 'sparseFillEmptyRows', 'int32');
23814 const $values = convertToTensor(values, 'values', 'sparseFillEmptyRows');
23815 const $denseShape = convertToTensor(denseShape, 'denseShape', 'sparseFillEmptyRows', 'int32');
23816 const $defaultValue = convertToTensor(defaultValue, 'defaultValue', 'sparseFillEmptyRows', $values.dtype);
23817 if ($indices.rank !== 2) {
23818 throw new Error(`Indices should be Tensor2D but received shape
23819 ${$indices.shape}`);
23820 }
23821 if ($values.rank !== 1) {
23822 throw new Error(`Values should be Tensor1D but received shape ${$values.shape}`);
23823 }
23824 if ($denseShape.rank !== 1) {
23825 throw new Error(`Dense shape should be Tensor1D but received shape ${$denseShape.shape}`);
23826 }
23827 if ($defaultValue.rank !== 0) {
23828 throw new Error(`Default value should be a scalar but received shape ${$defaultValue.shape}`);
23829 }
23830 const inputs = {
23831 indices: $indices,
23832 values: $values,
23833 denseShape: $denseShape,
23834 defaultValue: $defaultValue
23835 };
23836 const result = ENGINE.runKernel(SparseFillEmptyRows, inputs);
23837 return {
23838 outputIndices: result[0],
23839 outputValues: result[1],
23840 emptyRowIndicator: result[2],
23841 reverseIndexMap: result[3]
23842 };
23843 }
23844 const sparseFillEmptyRows$2 = /* @__PURE__ */ op({ sparseFillEmptyRows_ });
23845
23846 /**
23847 * @license
23848 * Copyright 2021 Google LLC. All Rights Reserved.
23849 * Licensed under the Apache License, Version 2.0 (the "License");
23850 * you may not use this file except in compliance with the License.
23851 * You may obtain a copy of the License at
23852 *
23853 * http://www.apache.org/licenses/LICENSE-2.0
23854 *
23855 * Unless required by applicable law or agreed to in writing, software
23856 * distributed under the License is distributed on an "AS IS" BASIS,
23857 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23858 * See the License for the specific language governing permissions and
23859 * limitations under the License.
23860 * =============================================================================
23861 */
23862 /**
23863 * This operation has the same semantics as reshape on the represented dense
23864 * tensor. The `inputIndices` are recomputed based on the requested `newShape`.
23865 * If one component of `newShape` is the special value -1, the size of that
23866 * dimension is computed so that the total dense size remains constant. At most
23867 * one component of `newShape` can be -1. The number of dense elements implied
23868 * by `newShape` must be the same as the number of dense elements originally
23869 * implied by `inputShape`. Reshaping does not affect the order of values in the
23870 * SparseTensor. If the input tensor has rank R_in and N non-empty values, and
23871 * `newShape` has length R_out, then `inputIndices` has shape [N, R_in],
23872 * `inputShape` has length R_in, `outputIndices` has shape [N, R_out], and
23873 * `outputShape` has length R_out.
23874 *
23875 * ```js
23876 * const result = tf.sparse.sparseReshape(
23877 * [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]],
23878 * [2, 3, 6], [9, -1]);
23879 * console.log(result);
23880 * result['outputIndices'].print(); //[[0, 0], [0, 1], [1, 2], [4, 2], [8, 1]]
23881 * result['outputShape'].print(); // [9, 4]
23882 * ```
23883 * @param inputIndices: 2-D. N x R_in matrix with the indices of non-empty
23884 * values in a SparseTensor.
23885 * @param inputShape: 1-D. R_in Tensor1D with the input SparseTensor's dense
23886 * shape.
23887 * @param newShape: 1-D. R_out Tensor1D with the requested new dense shape.
23888 * @return A map with the following properties:
23889 * - outputIndices: 2-D. N x R_out matrix with the updated indices of
23890 * non-empty values in the output SparseTensor.
23891 * - outputShape: 1-D. R_out vector with the full dense shape of the output
23892 * SparseTensor. This is the same as newShape but with any -1 dimensions
23893 * filled in.
23894 * @doc {heading: 'Operations', subheading: 'Sparse'}
23895 */
23896 function sparseReshape_(inputIndices, inputShape, newShape) {
23897 const $inputIndices = convertToTensor(inputIndices, 'inputIndices', 'sparseReshape', 'int32');
23898 const $inputShape = convertToTensor(inputShape, 'inputShape', 'sparseReshape', 'int32');
23899 const $newShape = convertToTensor(newShape, 'newShape', 'sparseReshape', 'int32');
23900 if ($inputIndices.rank !== 2) {
23901 throw new Error(`Input indices should be Tensor2D but received shape
23902 ${$inputIndices.shape}`);
23903 }
23904 if ($inputShape.rank !== 1) {
23905 throw new Error(`Input shape should be Tensor1D but received shape ${$inputShape.shape}`);
23906 }
23907 if ($newShape.rank !== 1) {
23908 throw new Error(`New shape should be Tensor1D but received shape ${$newShape.shape}`);
23909 }
23910 const inputs = {
23911 inputIndices: $inputIndices,
23912 inputShape: $inputShape,
23913 newShape: $newShape
23914 };
23915 const result = ENGINE.runKernel(SparseReshape, inputs);
23916 return { outputIndices: result[0], outputShape: result[1] };
23917 }
23918 const sparseReshape$2 = /* @__PURE__ */ op({ sparseReshape_ });
23919
23920 /**
23921 * @license
23922 * Copyright 2021 Google LLC. All Rights Reserved.
23923 * Licensed under the Apache License, Version 2.0 (the "License");
23924 * you may not use this file except in compliance with the License.
23925 * You may obtain a copy of the License at
23926 *
23927 * http://www.apache.org/licenses/LICENSE-2.0
23928 *
23929 * Unless required by applicable law or agreed to in writing, software
23930 * distributed under the License is distributed on an "AS IS" BASIS,
23931 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23932 * See the License for the specific language governing permissions and
23933 * limitations under the License.
23934 * =============================================================================
23935 */
23936 /**
23937 * Computes the mean along sparse segments of a tensor.
23938 *
23939 * ```js
23940 * const c = tf.tensor2d([[1,2,3,4], [-1,-2,-3,-4], [6,7,8,9]]);
23941 * // Select two rows, one segment.
23942 * const result1 = tf.sparse.sparseSegmentMean(c,
23943 * tf.tensor1d([0, 1], 'int32'),
23944 * tf.tensor1d([0, 0], 'int32'));
23945 * result1.print(); // [[0, 0, 0, 0]]
23946 *
23947 * // Select two rows, two segments.
23948 * const result2 = tf.sparse.sparseSegmentMean(c,
23949 * tf.tensor1d([0, 1], 'int32'),
23950 * tf.tensor1d([0, 1], 'int32'));
23951 * result2.print(); // [[1, 2, 3, 4], [-1, -2, -3, -4]]
23952 *
23953 * // Select all rows, two segments.
23954 * const result3 = tf.sparse.sparseSegmentMean(c,
23955 * tf.tensor1d([0, 1, 2], 'int32'),
23956 * tf.tensor1d([0, 1, 1], 'int32'));
23957 * result3.print(); // [[1.0, 2.0, 3.0, 4.0], [2.5, 2.5, 2.5, 2.5]]
23958 * ```
23959 * @param data: A Tensor of at least one dimension with data that will be
23960 * assembled in the output.
23961 * @param indices: A 1-D Tensor with indices into data. Has same rank as
23962 * segmentIds.
23963 * @param segmentIds: A 1-D Tensor with indices into the output Tensor. Values
23964 * should be sorted and can be repeated.
23965 * @return Has same shape as data, except for dimension 0 which has equal to
23966 * the number of segments.
23967 *
23968 * @doc {heading: 'Operations', subheading: 'Sparse'}
23969 */
23970 function sparseSegmentMean_(data, indices, segmentIds) {
23971 const $data = convertToTensor(data, 'data', 'sparseSegmentMean');
23972 const $indices = convertToTensor(indices, 'indices', 'sparseSegmentMean', 'int32');
23973 const $segmentIds = convertToTensor(segmentIds, 'segmentIds', 'sparseSegmentMean', 'int32');
23974 if ($data.rank < 1) {
23975 throw new Error(`Data should be at least 1 dimensional but received scalar`);
23976 }
23977 if ($indices.rank !== 1) {
23978 throw new Error(`Indices should be Tensor1D but received shape
23979 ${$indices.shape}`);
23980 }
23981 if ($segmentIds.rank !== 1) {
23982 throw new Error(`Segment ids should be Tensor1D but received shape
23983 ${$segmentIds.shape}`);
23984 }
23985 const inputs = {
23986 data: $data,
23987 indices: $indices,
23988 segmentIds: $segmentIds
23989 };
23990 return ENGINE.runKernel(SparseSegmentMean, inputs);
23991 }
23992 const sparseSegmentMean$2 = /* @__PURE__ */ op({ sparseSegmentMean_ });
23993
23994 /**
23995 * @license
23996 * Copyright 2021 Google LLC. All Rights Reserved.
23997 * Licensed under the Apache License, Version 2.0 (the "License");
23998 * you may not use this file except in compliance with the License.
23999 * You may obtain a copy of the License at
24000 *
24001 * http://www.apache.org/licenses/LICENSE-2.0
24002 *
24003 * Unless required by applicable law or agreed to in writing, software
24004 * distributed under the License is distributed on an "AS IS" BASIS,
24005 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24006 * See the License for the specific language governing permissions and
24007 * limitations under the License.
24008 * =============================================================================
24009 */
24010 /**
24011 * Computes the sum along sparse segments of a tensor.
24012 *
24013 * ```js
24014 * const c = tf.tensor2d([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]);
24015 * // Select two rows, one segment.
24016 * const result1 = tf.sparse.sparseSegmentSum(c,
24017 * tf.tensor1d([0, 1], 'int32'),
24018 * tf.tensor1d([0, 0], 'int32'));
24019 * result1.print(); // [[0, 0, 0, 0]]
24020 *
24021 * // Select two rows, two segments.
24022 * const result2 = tf.sparse.sparseSegmentSum(c,
24023 * tf.tensor1d([0, 1], 'int32'),
24024 * tf.tensor1d([0, 1], 'int32'));
24025 * result2.print(); // [[1, 2, 3, 4], [-1, -2, -3, -4]]
24026 *
24027 * // Select all rows, two segments.
24028 * const result3 = tf.sparse.sparseSegmentSum(c,
24029 * tf.tensor1d([0, 1, 2], 'int32'),
24030 * tf.tensor1d([0, 0, 1], 'int32'));
24031 * result3.print(); // [[0, 0, 0, 0], [5, 6, 7, 8]]
24032 * ```
24033 * @param data: A Tensor of at least one dimension with data that will be
24034 * assembled in the output.
24035 * @param indices: A 1-D Tensor with indices into data. Has same rank as
24036 * segmentIds.
24037 * @param segmentIds: A 1-D Tensor with indices into the output Tensor. Values
24038 * should be sorted and can be repeated.
24039 * @return Has same shape as data, except for dimension 0 which has equal to
24040 * the number of segments.
24041 *
24042 * @doc {heading: 'Operations', subheading: 'Sparse'}
24043 */
24044 function sparseSegmentSum_(data, indices, segmentIds) {
24045 const $data = convertToTensor(data, 'data', 'sparseSegmentSum');
24046 const $indices = convertToTensor(indices, 'indices', 'sparseSegmentSum', 'int32');
24047 const $segmentIds = convertToTensor(segmentIds, 'segmentIds', 'sparseSegmentSum', 'int32');
24048 if ($data.rank < 1) {
24049 throw new Error(`Data should be at least 1 dimensional but received scalar`);
24050 }
24051 if ($indices.rank !== 1) {
24052 throw new Error(`Indices should be Tensor1D but received shape
24053 ${$indices.shape}`);
24054 }
24055 if ($segmentIds.rank !== 1) {
24056 throw new Error(`Segment ids should be Tensor1D but received shape
24057 ${$segmentIds.shape}`);
24058 }
24059 const inputs = {
24060 data: $data,
24061 indices: $indices,
24062 segmentIds: $segmentIds
24063 };
24064 return ENGINE.runKernel(SparseSegmentSum, inputs);
24065 }
24066 const sparseSegmentSum$2 = /* @__PURE__ */ op({ sparseSegmentSum_ });
24067
24068 /**
24069 * @license
24070 * Copyright 2021 Google LLC. All Rights Reserved.
24071 * Licensed under the Apache License, Version 2.0 (the "License");
24072 * you may not use this file except in compliance with the License.
24073 * You may obtain a copy of the License at
24074 *
24075 * http://www.apache.org/licenses/LICENSE-2.0
24076 *
24077 * Unless required by applicable law or agreed to in writing, software
24078 * distributed under the License is distributed on an "AS IS" BASIS,
24079 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24080 * See the License for the specific language governing permissions and
24081 * limitations under the License.
24082 * =============================================================================
24083 */
24084 /**
24085 * Creates ngrams from ragged string data.
24086 *
24087 * This op accepts a ragged tensor with 1 ragged dimension containing only
24088 * strings and outputs a ragged tensor with 1 ragged dimension containing ngrams
24089 * of that string, joined along the innermost axis.
24090 *
24091 * ```js
24092 * const result = tf.string.stringNGrams(
24093 * ['a', 'b', 'c', 'd'], tf.tensor1d([0, 2, 4], 'int32'),
24094 * '|', [1, 2], 'LP', 'RP', -1, false);
24095 * result['nGrams'].print(); // ['a', 'b', 'LP|a', 'a|b', 'b|RP',
24096 * // 'c', 'd', 'LP|c', 'c|d', 'd|RP']
24097 * result['nGramsSplits'].print(); // [0, 5, 10]
24098 * ```
24099 * @param data: The values tensor of the ragged string tensor to make ngrams out
24100 * of. Must be a 1D string tensor.
24101 * @param dataSplits: The splits tensor of the ragged string tensor to make
24102 * ngrams out of.
24103 * @param separator: The string to append between elements of the token. Use ""
24104 * for no separator.
24105 * @param nGramWidths: The sizes of the ngrams to create.
24106 * @param leftPad: The string to use to pad the left side of the ngram sequence.
24107 * Only used if pad_width !== 0.
24108 * @param rightPad: The string to use to pad the right side of the ngram
24109 * sequence. Only used if pad_width !== 0.
24110 * @param padWidth: The number of padding elements to add to each side of each
24111 * sequence. Note that padding will never be greater than `nGramWidths`-1
24112 * regardless of this value. If `padWidth`=-1, then add max(`nGramWidths`)-1
24113 * elements.
24114 * @param preserveShortSequences: If true, then ensure that at least one ngram
24115 * is generated for each input sequence. In particular, if an input sequence
24116 * is shorter than min(ngramWidth) + 2*padWidth, then generate a single
24117 * ngram containing the entire sequence. If false, then no ngrams are
24118 * generated for these short input sequences.
24119 * @return A map with the following properties:
24120 * - nGrams: The values tensor of the output ngrams ragged tensor.
24121 * - nGramsSplits: The splits tensor of the output ngrams ragged tensor.
24122 *
24123 * @doc {heading: 'Operations', subheading: 'String'}
24124 */
24125 function stringNGrams_(data, dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) {
24126 const $data = convertToTensor(data, 'data', 'stringNGrams', 'string');
24127 if ($data.dtype !== 'string') {
24128 throw new Error('Data must be of datatype string');
24129 }
24130 if ($data.shape.length !== 1) {
24131 throw new Error(`Data must be a vector, saw: ${$data.shape}`);
24132 }
24133 const $dataSplits = convertToTensor(dataSplits, 'dataSplits', 'stringNGrams');
24134 if ($dataSplits.dtype !== 'int32') {
24135 throw new Error('Data splits must be of datatype int32');
24136 }
24137 const attrs = {
24138 separator,
24139 nGramWidths,
24140 leftPad,
24141 rightPad,
24142 padWidth,
24143 preserveShortSequences
24144 };
24145 const inputs = { data: $data, dataSplits: $dataSplits };
24146 const result = ENGINE.runKernel(StringNGrams, inputs, attrs);
24147 return { nGrams: result[0], nGramsSplits: result[1] };
24148 }
24149 const stringNGrams$2 = /* @__PURE__ */ op({ stringNGrams_ });
24150
24151 /**
24152 * @license
24153 * Copyright 2021 Google LLC. All Rights Reserved.
24154 * Licensed under the Apache License, Version 2.0 (the "License");
24155 * you may not use this file except in compliance with the License.
24156 * You may obtain a copy of the License at
24157 *
24158 * http://www.apache.org/licenses/LICENSE-2.0
24159 *
24160 * Unless required by applicable law or agreed to in writing, software
24161 * distributed under the License is distributed on an "AS IS" BASIS,
24162 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24163 * See the License for the specific language governing permissions and
24164 * limitations under the License.
24165 * =============================================================================
24166 */
24167 /**
24168 * Split elements of `input` based on `delimiter` into a SparseTensor .
24169 *
24170 * Let N be the size of source (typically N will be the batch size). Split each
24171 * element of `input` based on `delimiter` and return a SparseTensor containing
24172 * the splitted tokens. Empty tokens are ignored if `skipEmpty` is set to True.
24173 *
24174 * `delimiter` can be empty, or a string of split characters. If `delimiter` is
24175 * an empty string, each element of `input` is split into individual
24176 * character strings. Otherwise every character of `delimiter` is a potential
24177 * split point.
24178 *
24179 * ```js
24180 * const result = tf.string.stringSplit(['hello world', 'a b c'], ' ');
24181 * result['indices'].print(); // [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]]
24182 * result['values'].print(); // ['hello', 'world', 'a', 'b', 'c']
24183 * result['shape'].print(); // [2, 3]
24184 * ```
24185 * @param input: 1-D. Strings to split.
24186 * @param delimiter: 0-D. Delimiter characters, or empty string.
24187 * @param skipEmpty: Optional. If true, skip the empty strings from the result.
24188 * Defaults to true.
24189 * @return A map with the following properties:
24190 * - indices: A dense matrix of int32 representing the indices of the sparse
24191 * tensor.
24192 * - values: A vector of strings corresponding to the splited values.
24193 * - shape: a length-2 vector of int32 representing the shape of the sparse
24194 * tensor, where the first value is N and the second value is the maximum number
24195 * of tokens in a single input entry.
24196 *
24197 * @doc {heading: 'Operations', subheading: 'String'}
24198 */
24199 function stringSplit_(input, delimiter, skipEmpty = true) {
24200 const $input = convertToTensor(input, 'input', 'stringSplit', 'string');
24201 const $delimiter = convertToTensor(delimiter, 'delimiter', 'stringSplit', 'string');
24202 if ($input.rank !== 1) {
24203 throw new Error(`Input should be Tensor1D but received shape ${$input.shape}`);
24204 }
24205 if ($delimiter.rank !== 0) {
24206 throw new Error(`Delimiter should be a scalar but received shape ${$delimiter.shape}`);
24207 }
24208 const attrs = { skipEmpty };
24209 const inputs = { input: $input, delimiter: $delimiter };
24210 const result = ENGINE.runKernel(StringSplit, inputs, attrs);
24211 return { indices: result[0], values: result[1], shape: result[2] };
24212 }
24213 const stringSplit$2 = /* @__PURE__ */ op({ stringSplit_ });
24214
24215 /**
24216 * @license
24217 * Copyright 2021 Google LLC. All Rights Reserved.
24218 * Licensed under the Apache License, Version 2.0 (the "License");
24219 * you may not use this file except in compliance with the License.
24220 * You may obtain a copy of the License at
24221 *
24222 * http://www.apache.org/licenses/LICENSE-2.0
24223 *
24224 * Unless required by applicable law or agreed to in writing, software
24225 * distributed under the License is distributed on an "AS IS" BASIS,
24226 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24227 * See the License for the specific language governing permissions and
24228 * limitations under the License.
24229 * =============================================================================
24230 */
24231 /**
24232 * Converts each string in the input Tensor to its hash mod by a number of
24233 * buckets.
24234 *
24235 * The hash function is deterministic on the content of the string within the
24236 * process and will never change. However, it is not suitable for cryptography.
24237 * This function may be used when CPU time is scarce and inputs are trusted or
24238 * unimportant. There is a risk of adversaries constructing inputs that all hash
24239 * to the same bucket.
24240 *
24241 * ```js
24242 * const result = tf.string.stringToHashBucketFast(
24243 * ['Hello', 'TensorFlow', '2.x'], 3);
24244 * result.print(); // [0, 2, 2]
24245 * ```
24246 * @param input: The strings to assign a hash bucket.
24247 * @param numBuckets: The number of buckets.
24248 * @return A Tensor of the same shape as the input tensor.
24249 *
24250 * @doc {heading: 'Operations', subheading: 'String'}
24251 */
24252 function stringToHashBucketFast_(input, numBuckets) {
24253 const $input = convertToTensor(input, 'input', 'stringToHashBucketFast', 'string');
24254 const attrs = { numBuckets };
24255 if (numBuckets <= 0) {
24256 throw new Error(`Number of buckets must be at least 1`);
24257 }
24258 const inputs = { input: $input };
24259 return ENGINE.runKernel(StringToHashBucketFast, inputs, attrs);
24260 }
24261 const stringToHashBucketFast$2 = /* @__PURE__ */ op({ stringToHashBucketFast_ });
24262
24263 /**
24264 * @license
24265 * Copyright 2023 Google LLC.
24266 * Licensed under the Apache License, Version 2.0 (the "License");
24267 * you may not use this file except in compliance with the License.
24268 * You may obtain a copy of the License at
24269 *
24270 * http://www.apache.org/licenses/LICENSE-2.0
24271 *
24272 * Unless required by applicable law or agreed to in writing, software
24273 * distributed under the License is distributed on an "AS IS" BASIS,
24274 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24275 * See the License for the specific language governing permissions and
24276 * limitations under the License.
24277 * =============================================================================
24278 */
24279 /**
24280 * Replace the match of a `pattern` in `input` with `rewrite`.
24281 *
24282 * ```js
24283 * const result = tf.string.staticRegexReplace(
24284 * ['format this spacing better'], ' +', ' ');
24285 * result.print(); // ['format this spacing better']
24286 * ```
24287 * @param input: A Tensor of type string. The text to be processed.
24288 * @param pattern: A string. The regular expression to match the input.
24289 * @param rewrite: A string. The rewrite to be applied to the matched
24290 * expression.
24291 * @param replaceGlobal: An optional bool. Defaults to True. If True, the
24292 * replacement is global, otherwise the replacement is done only on the
24293 * first match.
24294 * @return A Tensor of type string.
24295 *
24296 * @doc {heading: 'Operations', subheading: 'String'}
24297 */
24298 function staticRegexReplace_(input, pattern, rewrite, replaceGlobal = true) {
24299 const $input = convertToTensor(input, 'input', 'staticRegexReplace', 'string');
24300 const attrs = { pattern, rewrite, replaceGlobal };
24301 return ENGINE.runKernel(StaticRegexReplace, { x: $input }, attrs);
24302 }
24303 const staticRegexReplace$2 = /* @__PURE__ */ op({ staticRegexReplace_ });
24304
24305 /**
24306 * @license
24307 * Copyright 2020 Google LLC. All Rights Reserved.
24308 * Licensed under the Apache License, Version 2.0 (the "License");
24309 * you may not use this file except in compliance with the License.
24310 * You may obtain a copy of the License at
24311 *
24312 * http://www.apache.org/licenses/LICENSE-2.0
24313 *
24314 * Unless required by applicable law or agreed to in writing, software
24315 * distributed under the License is distributed on an "AS IS" BASIS,
24316 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24317 * See the License for the specific language governing permissions and
24318 * limitations under the License.
24319 * =============================================================================
24320 */
    // Grouped op namespaces exposed on the public API surface
    // (e.g. `tf.spectral.fft`, `tf.signal.stft`).
    const spectral$1 = {
        fft: fft$2,
        ifft: ifft$2,
        rfft,
        irfft
    };
    const signal = {
        hammingWindow,
        hannWindow,
        frame,
        stft,
    };
    // Image-processing ops, exposed as `tf.image.*`.
    const image$1 = {
        flipLeftRight,
        grayscaleToRGB,
        resizeNearestNeighbor: resizeNearestNeighbor$2,
        resizeBilinear: resizeBilinear$3,
        rgbToGrayscale,
        rotateWithOffset,
        cropAndResize: cropAndResize$3,
        nonMaxSuppression,
        nonMaxSuppressionAsync,
        nonMaxSuppressionWithScore,
        nonMaxSuppressionWithScoreAsync,
        nonMaxSuppressionPadded,
        nonMaxSuppressionPaddedAsync,
        threshold: threshold$1,
        transform: transform$2
    };
    // Linear-algebra ops, exposed as `tf.linalg.*`.
    const linalg = {
        bandPart,
        gramSchmidt,
        qr
    };
    // Loss functions, exposed as `tf.losses.*`.
    const losses = {
        absoluteDifference,
        computeWeightedLoss: computeWeightedLoss$1,
        cosineDistance,
        hingeLoss,
        huberLoss,
        logLoss,
        meanSquaredError: meanSquaredError$2,
        sigmoidCrossEntropy,
        softmaxCrossEntropy
    };
    // Sparse-tensor ops, exposed as `tf.sparse.*`.
    const sparse$1 = {
        sparseFillEmptyRows: sparseFillEmptyRows$2,
        sparseReshape: sparseReshape$2,
        sparseSegmentMean: sparseSegmentMean$2,
        sparseSegmentSum: sparseSegmentSum$2
    };
    // String ops, exposed as `tf.string.*`.
    // tslint:disable-next-line:variable-name
    const string$1 = {
        stringNGrams: stringNGrams$2,
        stringSplit: stringSplit$2,
        stringToHashBucketFast: stringToHashBucketFast$2,
        staticRegexReplace: staticRegexReplace$2,
    };
24379
24380 /**
24381 * @license
24382 * Copyright 2018 Google LLC. All Rights Reserved.
24383 * Licensed under the Apache License, Version 2.0 (the "License");
24384 * you may not use this file except in compliance with the License.
24385 * You may obtain a copy of the License at
24386 *
24387 * http://www.apache.org/licenses/LICENSE-2.0
24388 *
24389 * Unless required by applicable law or agreed to in writing, software
24390 * distributed under the License is distributed on an "AS IS" BASIS,
24391 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24392 * See the License for the specific language governing permissions and
24393 * limitations under the License.
24394 * =============================================================================
24395 */
24396 /**
24397 * Maps to mapping between the custom object and its name.
24398 *
24399 * After registering a custom class, these two maps will add key-value pairs
24400 * for the class object and the registered name.
24401 *
24402 * Therefore we can get the relative registered name by calling
24403 * getRegisteredName() function.
24404 *
24405 * For example:
24406 * GLOBAL_CUSTOM_OBJECT: {key=registeredName: value=corresponding
24407 * CustomObjectClass}
24408 *
24409 * GLOBAL_CUSTOM_NAMES: {key=CustomObjectClass: value=corresponding
24410 * registeredName}
24411 *
24412 */
    const GLOBAL_CUSTOM_OBJECT = new Map(); // registeredName -> custom class
    const GLOBAL_CUSTOM_NAMES = new Map(); // custom class -> registeredName
24415 /**
24416 * Serializable defines the serialization contract.
24417 *
24418 * TFJS requires serializable classes to return their className when asked
24419 * to avoid issues with minification.
24420 */
24421 class Serializable {
24422 /**
24423 * Return the class name for this class to use in serialization contexts.
24424 *
24425 * Generally speaking this will be the same thing that constructor.name
24426 * would have returned. However, the class name needs to be robust
24427 * against minification for serialization/deserialization to work properly.
24428 *
24429 * There's also places such as initializers.VarianceScaling, where
24430 * implementation details between different languages led to different
24431 * class hierarchies and a non-leaf node is used for serialization purposes.
24432 */
24433 getClassName() {
24434 return this.constructor
24435 .className;
24436 }
24437 /**
24438 * Creates an instance of T from a ConfigDict.
24439 *
24440 * This works for most descendants of serializable. A few need to
24441 * provide special handling.
24442 * @param cls A Constructor for the class to instantiate.
24443 * @param config The Configuration for the object.
24444 */
24445 /** @nocollapse */
24446 static fromConfig(cls, config) {
24447 return new cls(config);
24448 }
24449 }
24450 /**
24451 * Maps string keys to class constructors.
24452 *
24453 * Used during (de)serialization from the cross-language JSON format, which
24454 * requires the class name in the serialization format matches the class
24455 * names as used in Python, should it exist.
24456 */
24457 class SerializationMap {
24458 constructor() {
24459 this.classNameMap = {};
24460 }
24461 /**
24462 * Returns the singleton instance of the map.
24463 */
24464 static getMap() {
24465 if (SerializationMap.instance == null) {
24466 SerializationMap.instance = new SerializationMap();
24467 }
24468 return SerializationMap.instance;
24469 }
24470 /**
24471 * Registers the class as serializable.
24472 */
24473 static register(cls) {
24474 SerializationMap.getMap().classNameMap[cls.className] =
24475 [cls, cls.fromConfig];
24476 }
24477 }
24478 /**
24479 * Register a class with the serialization map of TensorFlow.js.
24480 *
24481 * This is often used for registering custom Layers, so they can be
24482 * serialized and deserialized.
24483 *
24484 * Example 1. Register the class without package name and specified name.
24485 *
24486 * ```js
24487 * class MyCustomLayer extends tf.layers.Layer {
24488 * static className = 'MyCustomLayer';
24489 *
24490 * constructor(config) {
24491 * super(config);
24492 * }
24493 * }
24494 * tf.serialization.registerClass(MyCustomLayer);
24495 * console.log(tf.serialization.GLOBALCUSTOMOBJECT.get("Custom>MyCustomLayer"));
24496 * console.log(tf.serialization.GLOBALCUSTOMNAMES.get(MyCustomLayer));
24497 * ```
24498 *
24499 * Example 2. Register the class with package name: "Package" and specified
24500 * name: "MyLayer".
24501 * ```js
24502 * class MyCustomLayer extends tf.layers.Layer {
24503 * static className = 'MyCustomLayer';
24504 *
24505 * constructor(config) {
24506 * super(config);
24507 * }
24508 * }
24509 * tf.serialization.registerClass(MyCustomLayer, "Package", "MyLayer");
24510 * console.log(tf.serialization.GLOBALCUSTOMOBJECT.get("Package>MyLayer"));
24511 * console.log(tf.serialization.GLOBALCUSTOMNAMES.get(MyCustomLayer));
24512 * ```
24513 *
24514 * Example 3. Register the class with specified name: "MyLayer".
24515 * ```js
24516 * class MyCustomLayer extends tf.layers.Layer {
24517 * static className = 'MyCustomLayer';
24518 *
24519 * constructor(config) {
24520 * super(config);
24521 * }
24522 * }
24523 * tf.serialization.registerClass(MyCustomLayer, undefined, "MyLayer");
24524 * console.log(tf.serialization.GLOBALCUSTOMOBJECT.get("Custom>MyLayer"));
24525 * console.log(tf.serialization.GLOBALCUSTOMNAMES.get(MyCustomLayer));
24526 * ```
24527 *
24528 * Example 4. Register the class with specified package name: "Package".
24529 * ```js
24530 * class MyCustomLayer extends tf.layers.Layer {
24531 * static className = 'MyCustomLayer';
24532 *
24533 * constructor(config) {
24534 * super(config);
24535 * }
24536 * }
24537 * tf.serialization.registerClass(MyCustomLayer, "Package");
24538 * console.log(tf.serialization.GLOBALCUSTOMOBJECT
24539 * .get("Package>MyCustomLayer"));
24540 * console.log(tf.serialization.GLOBALCUSTOMNAMES
24541 * .get(MyCustomLayer));
24542 * ```
24543 *
24544 * @param cls The class to be registered. It must have a public static member
24545 * called `className` defined and the value must be a non-empty string.
24546 * @param pkg The package name that this class belongs to. This used to define
24547 * the key in GlobalCustomObject. If not defined, it defaults to `Custom`.
24548 * @param name The name that user specified. It defaults to the actual name of
24549 * the class as specified by its static `className` property.
24550 * @doc {heading: 'Models', subheading: 'Serialization', ignoreCI: true}
24551 */
24552 function registerClass(cls, pkg, name) {
24553 assert$1(cls.className != null, () => `Class being registered does not have the static className ` +
24554 `property defined.`);
24555 assert$1(typeof cls.className === 'string', () => `className is required to be a string, but got type ` +
24556 typeof cls.className);
24557 assert$1(cls.className.length > 0, () => `Class being registered has an empty-string as its className, ` +
24558 `which is disallowed.`);
24559 if (typeof pkg === 'undefined') {
24560 pkg = 'Custom';
24561 }
24562 if (typeof name === 'undefined') {
24563 name = cls.className;
24564 }
24565 const className = name;
24566 const registerName = pkg + '>' + className;
24567 SerializationMap.register(cls);
24568 GLOBAL_CUSTOM_OBJECT.set(registerName, cls);
24569 GLOBAL_CUSTOM_NAMES.set(cls, registerName);
24570 return cls;
24571 }
24572 /**
24573 * Get the registered name of a class. If the class has not been registered,
24574 * return the class name.
24575 *
24576 * @param cls The class we want to get register name for. It must have a public
24577 * static member called `className` defined.
24578 * @returns registered name or class name.
24579 */
24580 function getRegisteredName(cls) {
24581 if (GLOBAL_CUSTOM_NAMES.has(cls)) {
24582 return GLOBAL_CUSTOM_NAMES.get(cls);
24583 }
24584 else {
24585 return cls.className;
24586 }
24587 }
24588
    // Public `tf.serialization` namespace object, frozen so consumers cannot
    // mutate it. Note the GLOBAL_CUSTOM_* registries are intentionally not
    // exported here.
    var serialization = /*#__PURE__*/Object.freeze({
        __proto__: null,
        Serializable: Serializable,
        SerializationMap: SerializationMap,
        getRegisteredName: getRegisteredName,
        registerClass: registerClass
    });
24596
24597 /**
24598 * @license
24599 * Copyright 2018 Google LLC. All Rights Reserved.
24600 * Licensed under the Apache License, Version 2.0 (the "License");
24601 * you may not use this file except in compliance with the License.
24602 * You may obtain a copy of the License at
24603 *
24604 * http://www.apache.org/licenses/LICENSE-2.0
24605 *
24606 * Unless required by applicable law or agreed to in writing, software
24607 * distributed under the License is distributed on an "AS IS" BASIS,
24608 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24609 * See the License for the specific language governing permissions and
24610 * limitations under the License.
24611 * =============================================================================
24612 */
    /** @doc {heading: 'Training', subheading: 'Classes', namespace: 'train'} */
    class Optimizer extends Serializable {
        /**
         * Executes `f()` and minimizes the scalar output of `f()` by computing
         * gradients of y with respect to the list of trainable variables provided by
         * `varList`. If no list is provided, it defaults to all trainable variables.
         *
         * @param f The function to execute and whose output to minimize.
         * @param returnCost Whether to return the scalar cost value produced by
         * executing `f()`.
         * @param varList An optional list of variables to update. If specified, only
         * the trainable variables in varList will be updated by minimize. Defaults to
         * all trainable variables.
         *
         * @doc {heading: 'Training', subheading: 'Optimizers'}
         */
        minimize(f, returnCost = false, varList) {
            const { value, grads } = this.computeGradients(f, varList);
            if (varList != null) {
                // Preserve the caller-specified variable order when handing the
                // gradients to applyGradients().
                const gradArray = varList.map(v => ({ name: v.name, tensor: grads[v.name] }));
                this.applyGradients(gradArray);
            }
            else {
                this.applyGradients(grads);
            }
            // Dispose gradients.
            dispose(grads);
            if (returnCost) {
                // Ownership of the scalar cost tensor transfers to the caller.
                return value;
            }
            else {
                value.dispose();
                return null;
            }
        }
        /**
         * The number of iterations that this optimizer instance has been invoked for.
         */
        get iterations() {
            // Lazily initialized so subclasses need not set it in their constructors.
            if (this.iterations_ == null) {
                this.iterations_ = 0;
            }
            return this.iterations_;
        }
        // Bumps the iteration counter; subclasses call this at the end of each
        // applyGradients() step.
        incrementIterations() {
            this.iterations_ = this.iterations + 1;
        }
        /**
         * Executes f() and computes the gradient of the scalar output of f() with
         * respect to the list of trainable variables provided by `varList`. If no
         * list is provided, it defaults to all trainable variables.
         *
         * @param f The function to execute and whose output to use for computing
         * gradients with respect to variables.
         * @param varList An optional list of variables to compute gradients with
         * respect to. If specified, only the trainable variables in varList will have
         * gradients computed with respect to. Defaults to all trainable variables.
         *
         * @doc {heading: 'Training', subheading: 'Optimizers'}
         */
        computeGradients(f, varList) {
            return variableGrads(f, varList);
        }
        /**
         * Dispose the variables (if any) owned by this optimizer instance.
         */
        dispose() {
            // NOTE(review): `iterations_` appears to be a plain number (see
            // saveIterations/extractIterations), so this dispose() call looks
            // like a no-op for it — confirm against tf.dispose semantics.
            if (this.iterations_ != null) {
                dispose(this.iterations_);
            }
        }
        // Serializes the iteration counter as a named int32 scalar weight entry.
        async saveIterations() {
            if (this.iterations_ == null) {
                this.iterations_ = 0;
            }
            return {
                name: 'iter',
                // TODO(cais): Use 'int64' type when available.
                tensor: scalar(this.iterations_, 'int32')
            };
        }
        // Subclasses that support weight export override this.
        async getWeights() {
            throw new Error('getWeights() is not implemented for this optimizer yet.');
        }
        // Subclasses that support weight import override this.
        async setWeights(weightValues) {
            throw new Error(`setWeights() is not implemented for this optimizer class ` +
                `${this.getClassName()}`);
        }
        /**
         * Extract the first element of the weight values and set it
         * as the iterations counter variable of this instance of optimizer.
         *
         * @param weightValues
         * @returns Weight values with the first element consumed and excluded.
         */
        async extractIterations(weightValues) {
            this.iterations_ = (await weightValues[0].tensor.data())[0];
            return weightValues.slice(1);
        }
    }
    // Duck-typed `instanceof` check: any object providing minimize,
    // computeGradients and applyGradients counts as an Optimizer. This lets
    // optimizer instances created by a different copy of the library pass
    // `instanceof Optimizer`.
    Object.defineProperty(Optimizer, Symbol.hasInstance, {
        value: (instance) => {
            return instance.minimize != null && instance.computeGradients != null &&
                instance.applyGradients != null;
        }
    });
24719
24720 /**
24721 * @license
24722 * Copyright 2018 Google LLC. All Rights Reserved.
24723 * Licensed under the Apache License, Version 2.0 (the "License");
24724 * you may not use this file except in compliance with the License.
24725 * You may obtain a copy of the License at
24726 *
24727 * http://www.apache.org/licenses/LICENSE-2.0
24728 *
24729 * Unless required by applicable law or agreed to in writing, software
24730 * distributed under the License is distributed on an "AS IS" BASIS,
24731 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24732 * See the License for the specific language governing permissions and
24733 * limitations under the License.
24734 * =============================================================================
24735 */
    /** @doclink Optimizer */
    class AdadeltaOptimizer extends Optimizer {
        /** @nocollapse */
        static get className() {
            // Name matters for Python compatibility.
            // This is a getter instead of a property because when it's a property, it
            // prevents the entire class from being tree-shaken.
            return 'Adadelta';
        }
        /**
         * @param learningRate Scaling factor applied to each computed update.
         * @param rho Decay rate for the running averages of squared gradients
         *     and squared updates.
         * @param epsilon Numerical-stability constant; falls back to the current
         *     backend's epsilon when null.
         */
        constructor(learningRate, rho, epsilon = null) {
            super();
            this.learningRate = learningRate;
            this.rho = rho;
            this.epsilon = epsilon;
            // Per-variable running averages, lazily created in applyGradients().
            this.accumulatedGrads = [];
            this.accumulatedUpdates = [];
            if (epsilon == null) {
                this.epsilon = ENGINE.backend.epsilon();
            }
        }
        /**
         * Applies one Adadelta step. Accepts either an array of `{name, tensor}`
         * pairs or a name-to-tensor map; entries without a gradient are skipped
         * (though accumulator slots are still created for them).
         */
        applyGradients(variableGradients) {
            const variableNames = Array.isArray(variableGradients) ?
                variableGradients.map(item => item.name) :
                Object.keys(variableGradients);
            variableNames.forEach((name, i) => {
                const value = ENGINE.registeredVariables[name];
                const trainable = false;
                // Lazily create the squared-gradient accumulator for this variable.
                if (this.accumulatedGrads[i] == null) {
                    this.accumulatedGrads[i] = {
                        originalName: `${name}/accum_grad`,
                        variable: tidy(() => zerosLike$3(value).variable(trainable))
                    };
                }
                // Lazily create the squared-update accumulator for this variable.
                if (this.accumulatedUpdates[i] == null) {
                    this.accumulatedUpdates[i] = {
                        originalName: `${name}/accum_var`,
                        variable: tidy(() => zerosLike$3(value).variable(trainable))
                    };
                }
                const gradient = Array.isArray(variableGradients) ?
                    variableGradients[i].tensor :
                    variableGradients[name];
                if (gradient == null) {
                    return;
                }
                const accumulatedGrad = this.accumulatedGrads[i].variable;
                const accumulatedUpdate = this.accumulatedUpdates[i].variable;
                tidy(() => {
                    // E[g^2] <- rho * E[g^2] + (1 - rho) * g^2
                    const newAccumulatedGrad = add$3(mul(accumulatedGrad, this.rho), mul(square$2(gradient), 1 - this.rho));
                    // update = (RMS of past updates / RMS of gradients) * g
                    const updates = mul(div$1(sqrt$2(add$3(accumulatedUpdate, this.epsilon)), sqrt$2(add$3(accumulatedGrad, this.epsilon))), gradient);
                    // E[dx^2] <- rho * E[dx^2] + (1 - rho) * update^2
                    const newAccumulatedUpdate = add$3(mul(accumulatedUpdate, this.rho), mul(square$2(updates), 1 - this.rho));
                    accumulatedGrad.assign(newAccumulatedGrad);
                    accumulatedUpdate.assign(newAccumulatedUpdate);
                    // value <- value - learningRate * update
                    const newValue = add$3(mul(updates, -this.learningRate), value);
                    value.assign(newValue);
                });
            });
            this.incrementIterations();
        }
        dispose() {
            // Both arrays are initialized together in the constructor, so a
            // single null check covers them.
            if (this.accumulatedUpdates != null) {
                dispose(this.accumulatedGrads.map(v => v.variable));
                dispose(this.accumulatedUpdates.map(v => v.variable));
            }
        }
        async getWeights() {
            // Order matters for Python compatibility.
            const variables = [...this.accumulatedGrads, ...this.accumulatedUpdates];
            return [await this.saveIterations()].concat(variables.map(v => ({ name: v.originalName, tensor: v.variable })));
        }
        // Restores the iteration count plus the two accumulator halves
        // (gradients first, then updates), mirroring getWeights().
        async setWeights(weightValues) {
            weightValues = await this.extractIterations(weightValues);
            const variableCount = weightValues.length / 2;
            const trainable = false;
            this.accumulatedGrads =
                weightValues.slice(0, variableCount).map(v => ({
                    originalName: v.name,
                    variable: v.tensor.variable(trainable)
                }));
            this.accumulatedUpdates =
                weightValues.slice(variableCount, variableCount * 2)
                    .map(v => ({
                    originalName: v.name,
                    variable: v.tensor.variable(trainable)
                }));
        }
        getConfig() {
            return {
                'learningRate': this.learningRate,
                'rho': this.rho,
                'epsilon': this.epsilon
            };
        }
        /** @nocollapse */
        static fromConfig(cls, config) {
            return new cls(config['learningRate'], config['rho'], config['epsilon']);
        }
    }
24834
24835 /**
24836 * @license
24837 * Copyright 2018 Google LLC. All Rights Reserved.
24838 * Licensed under the Apache License, Version 2.0 (the "License");
24839 * you may not use this file except in compliance with the License.
24840 * You may obtain a copy of the License at
24841 *
24842 * http://www.apache.org/licenses/LICENSE-2.0
24843 *
24844 * Unless required by applicable law or agreed to in writing, software
24845 * distributed under the License is distributed on an "AS IS" BASIS,
24846 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24847 * See the License for the specific language governing permissions and
24848 * limitations under the License.
24849 * =============================================================================
24850 */
    /** @doclink Optimizer */
    class AdagradOptimizer extends Optimizer {
        /** @nocollapse */
        static get className() {
            // Name matters for Python compatibility.
            // This is a getter instead of a property because when it's a property, it
            // prevents the entire class from being tree-shaken.
            return 'Adagrad';
        }
        /**
         * @param learningRate Step size.
         * @param initialAccumulatorValue Starting value for the per-variable
         *     squared-gradient accumulators.
         */
        constructor(learningRate, initialAccumulatorValue = 0.1) {
            super();
            this.learningRate = learningRate;
            this.initialAccumulatorValue = initialAccumulatorValue;
            // Per-variable accumulators, lazily created in applyGradients().
            this.accumulatedGrads = [];
        }
        /**
         * Applies one Adagrad step. Accepts either an array of `{name, tensor}`
         * pairs or a name-to-tensor map; entries without a gradient are skipped.
         */
        applyGradients(variableGradients) {
            const variableNames = Array.isArray(variableGradients) ?
                variableGradients.map(item => item.name) :
                Object.keys(variableGradients);
            variableNames.forEach((name, i) => {
                const value = ENGINE.registeredVariables[name];
                // Lazily create this variable's accumulator, filled with the
                // configured initial value.
                if (this.accumulatedGrads[i] == null) {
                    const trainable = false;
                    this.accumulatedGrads[i] = {
                        originalName: `${name}/accumulator`,
                        variable: tidy(() => fill$2(value.shape, this.initialAccumulatorValue)
                            .variable(trainable))
                    };
                }
                const gradient = Array.isArray(variableGradients) ?
                    variableGradients[i].tensor :
                    variableGradients[name];
                if (gradient == null) {
                    return;
                }
                const accumulatedGrad = this.accumulatedGrads[i].variable;
                tidy(() => {
                    // accum <- accum + g^2
                    const newAccumulatedGrad = add$3(accumulatedGrad, square$2(gradient));
                    accumulatedGrad.assign(newAccumulatedGrad);
                    // value <- value - lr * g / sqrt(accum + epsilon)
                    const newValue = add$3(mul(div$1(gradient, sqrt$2(add$3(newAccumulatedGrad, ENGINE.backend.epsilon()))), -this.learningRate), value);
                    value.assign(newValue);
                });
            });
            this.incrementIterations();
        }
        dispose() {
            if (this.accumulatedGrads != null) {
                dispose(this.accumulatedGrads.map(v => v.variable));
            }
        }
        async getWeights() {
            // Order matters for Python compatibility.
            return [await this.saveIterations()].concat(this.accumulatedGrads.map(v => ({ name: v.originalName, tensor: v.variable })));
        }
        // Restores the iteration count plus the accumulators saved by getWeights().
        async setWeights(weightValues) {
            weightValues = await this.extractIterations(weightValues);
            const trainable = false;
            this.accumulatedGrads = weightValues.map(v => ({ originalName: v.name, variable: v.tensor.variable(trainable) }));
        }
        getConfig() {
            return {
                'learningRate': this.learningRate,
                'initialAccumulatorValue': this.initialAccumulatorValue,
            };
        }
        /** @nocollapse */
        static fromConfig(cls, config) {
            return new cls(config['learningRate'], config['initialAccumulatorValue']);
        }
    }
24921
24922 /**
24923 * @license
24924 * Copyright 2018 Google LLC. All Rights Reserved.
24925 * Licensed under the Apache License, Version 2.0 (the "License");
24926 * you may not use this file except in compliance with the License.
24927 * You may obtain a copy of the License at
24928 *
24929 * http://www.apache.org/licenses/LICENSE-2.0
24930 *
24931 * Unless required by applicable law or agreed to in writing, software
24932 * distributed under the License is distributed on an "AS IS" BASIS,
24933 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24934 * See the License for the specific language governing permissions and
24935 * limitations under the License.
24936 * =============================================================================
24937 */
    class AdamOptimizer extends Optimizer {
        /** @nocollapse */
        static get className() {
            // Name matters for Python compatibility.
            // This is a getter instead of a property because when it's a property, it
            // prevents the entire class from being tree-shaken.
            return 'Adam';
        }
        /**
         * @param learningRate Step size.
         * @param beta1 Exponential decay rate for the first-moment estimates.
         * @param beta2 Exponential decay rate for the second-moment estimates.
         * @param epsilon Numerical-stability constant; falls back to the current
         *     backend's epsilon when null.
         */
        constructor(learningRate, beta1, beta2, epsilon = null) {
            super();
            this.learningRate = learningRate;
            this.beta1 = beta1;
            this.beta2 = beta2;
            this.epsilon = epsilon;
            // Per-variable moment estimates, lazily created in applyGradients().
            this.accumulatedFirstMoment = [];
            this.accumulatedSecondMoment = [];
            tidy(() => {
                // accB* will be updated by batch; they track beta^t for the
                // bias correction applied each step.
                this.accBeta1 = scalar(beta1).variable();
                this.accBeta2 = scalar(beta2).variable();
            });
            if (epsilon == null) {
                this.epsilon = ENGINE.backend.epsilon();
            }
        }
        /**
         * Applies one Adam step. Accepts either an array of `{name, tensor}`
         * pairs or a name-to-tensor map; entries without a gradient are skipped.
         */
        applyGradients(variableGradients) {
            const varNames = Array.isArray(variableGradients) ?
                variableGradients.map(v => v.name) :
                Object.keys(variableGradients);
            tidy(() => {
                // Bias-correction denominators: 1 - beta^t.
                const oneMinusAccBeta1 = sub$2(1, this.accBeta1);
                const oneMinusAccBeta2 = sub$2(1, this.accBeta2);
                varNames.forEach((name, i) => {
                    const value = ENGINE.registeredVariables[name];
                    const trainable = false;
                    if (this.accumulatedFirstMoment[i] == null) {
                        this.accumulatedFirstMoment[i] = {
                            originalName: `${name}/m`,
                            variable: tidy(() => zerosLike$3(value).variable(trainable))
                        };
                    }
                    if (this.accumulatedSecondMoment[i] == null) {
                        this.accumulatedSecondMoment[i] = {
                            originalName: `${name}/v`,
                            variable: tidy(() => zerosLike$3(value).variable(trainable))
                        };
                    }
                    const gradient = Array.isArray(variableGradients) ?
                        variableGradients[i].tensor :
                        variableGradients[name];
                    if (gradient == null) {
                        return;
                    }
                    const firstMoment = this.accumulatedFirstMoment[i].variable;
                    const secondMoment = this.accumulatedSecondMoment[i].variable;
                    // m <- beta1 * m + (1 - beta1) * g
                    const newFirstMoment = add$3(mul(firstMoment, this.beta1), mul(gradient, 1 - this.beta1));
                    // v <- beta2 * v + (1 - beta2) * g^2
                    const newSecondMoment = add$3(mul(secondMoment, this.beta2), mul(square$2(gradient), 1 - this.beta2));
                    const biasCorrectedFirstMoment = div$1(newFirstMoment, oneMinusAccBeta1);
                    const biasCorrectedSecondMoment = div$1(newSecondMoment, oneMinusAccBeta2);
                    firstMoment.assign(newFirstMoment);
                    secondMoment.assign(newSecondMoment);
                    // value <- value - lr * m_hat / (sqrt(v_hat) + epsilon)
                    const newValue = add$3(mul(div$1(biasCorrectedFirstMoment, add$3(sqrt$2(biasCorrectedSecondMoment), this.epsilon)), -this.learningRate), value);
                    value.assign(newValue);
                });
                // Advance the beta^t trackers for the next step.
                this.accBeta1.assign(mul(this.accBeta1, this.beta1));
                this.accBeta2.assign(mul(this.accBeta2, this.beta2));
            });
            this.incrementIterations();
        }
        dispose() {
            this.accBeta1.dispose();
            this.accBeta2.dispose();
            if (this.accumulatedFirstMoment != null) {
                dispose(this.accumulatedFirstMoment.map(v => v.variable));
            }
            if (this.accumulatedSecondMoment != null) {
                dispose(this.accumulatedSecondMoment.map(v => v.variable));
            }
        }
        async getWeights() {
            // Order matters for Python compatibility.
            const variables = [...this.accumulatedFirstMoment, ...this.accumulatedSecondMoment];
            return [await this.saveIterations()].concat(variables.map(v => ({ name: v.originalName, tensor: v.variable })));
        }
        // Restores iteration count and both moment halves (first moments, then
        // second moments), and recomputes the beta^t trackers from the restored
        // iteration count.
        async setWeights(weightValues) {
            weightValues = await this.extractIterations(weightValues);
            tidy(() => {
                this.accBeta1.assign(pow$3(this.beta1, this.iterations_ + 1));
                this.accBeta2.assign(pow$3(this.beta2, this.iterations_ + 1));
            });
            const variableCount = weightValues.length / 2;
            const trainable = false;
            this.accumulatedFirstMoment =
                weightValues.slice(0, variableCount).map(v => ({
                    originalName: v.name,
                    variable: v.tensor.variable(trainable)
                }));
            this.accumulatedSecondMoment =
                weightValues.slice(variableCount, variableCount * 2)
                    .map(v => ({
                    originalName: v.name,
                    variable: v.tensor.variable(trainable)
                }));
        }
        getConfig() {
            return {
                'learningRate': this.learningRate,
                'beta1': this.beta1,
                'beta2': this.beta2,
                'epsilon': this.epsilon,
            };
        }
        /** @nocollapse */
        static fromConfig(cls, config) {
            return new cls(config['learningRate'], config['beta1'], config['beta2'], config['epsilon']);
        }
    }
25055
25056 /**
25057 * @license
25058 * Copyright 2018 Google LLC. All Rights Reserved.
25059 * Licensed under the Apache License, Version 2.0 (the "License");
25060 * you may not use this file except in compliance with the License.
25061 * You may obtain a copy of the License at
25062 *
25063 * http://www.apache.org/licenses/LICENSE-2.0
25064 *
25065 * Unless required by applicable law or agreed to in writing, software
25066 * distributed under the License is distributed on an "AS IS" BASIS,
25067 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25068 * See the License for the specific language governing permissions and
25069 * limitations under the License.
25070 * =============================================================================
25071 */
    class AdamaxOptimizer extends Optimizer {
        /** @nocollapse */
        static get className() {
            // Name matters for Python compatibility.
            // This is a getter instead of a property because when it's a property, it
            // prevents the entire class from being tree-shaken.
            return 'Adamax';
        }
        /**
         * @param learningRate Base step size.
         * @param beta1 Exponential decay rate for the first-moment estimates.
         * @param beta2 Decay rate for the exponentially weighted infinity norm.
         * @param epsilon Numerical-stability constant; falls back to the current
         *     backend's epsilon when null.
         * @param decay Learning-rate decay applied per iteration.
         */
        constructor(learningRate, beta1, beta2, epsilon = null, decay = 0.0) {
            super();
            this.learningRate = learningRate;
            this.beta1 = beta1;
            this.beta2 = beta2;
            this.epsilon = epsilon;
            this.decay = decay;
            // Per-variable state, lazily created in applyGradients().
            this.accumulatedFirstMoment = [];
            this.accumulatedWeightedInfNorm = [];
            tidy(() => {
                // `iteration` drives the learning-rate decay schedule;
                // `accBeta1` tracks beta1^t for bias correction.
                this.iteration = scalar(0).variable();
                this.accBeta1 = scalar(beta1).variable();
            });
            if (epsilon == null) {
                this.epsilon = ENGINE.backend.epsilon();
            }
        }
        /**
         * Applies one Adamax step. Accepts either an array of `{name, tensor}`
         * pairs or a name-to-tensor map; entries without a gradient are skipped.
         */
        applyGradients(variableGradients) {
            const variableNames = Array.isArray(variableGradients) ?
                variableGradients.map(item => item.name) :
                Object.keys(variableGradients);
            tidy(() => {
                const oneMinusAccBeta1 = sub$2(1, this.accBeta1);
                // Decayed (negated) learning rate: -lr / (1 + iteration * decay).
                const lr = div$1(-this.learningRate, add$3(mul(this.iteration, this.decay), 1));
                variableNames.forEach((name, i) => {
                    const value = ENGINE.registeredVariables[name];
                    const trainable = false;
                    if (this.accumulatedFirstMoment[i] == null) {
                        this.accumulatedFirstMoment[i] = {
                            originalName: `${name}/m`,
                            variable: zerosLike$3(value).variable(trainable)
                        };
                    }
                    if (this.accumulatedWeightedInfNorm[i] == null) {
                        this.accumulatedWeightedInfNorm[i] = {
                            originalName: `${name}/v`,
                            variable: zerosLike$3(value).variable(trainable)
                        };
                    }
                    const gradient = Array.isArray(variableGradients) ?
                        variableGradients[i].tensor :
                        variableGradients[name];
                    if (gradient == null) {
                        return;
                    }
                    const firstMoment = this.accumulatedFirstMoment[i].variable;
                    const weightedInfNorm = this.accumulatedWeightedInfNorm[i].variable;
                    // m <- beta1 * m + (1 - beta1) * g
                    const newFirstMoment = add$3(mul(firstMoment, this.beta1), mul(gradient, 1 - this.beta1));
                    // u <- max(beta2 * u, |g|)  (weighted infinity norm)
                    const ut0 = mul(weightedInfNorm, this.beta2);
                    const ut1 = abs$2(gradient);
                    const newWeightedInfNorm = maximum$4(ut0, ut1);
                    firstMoment.assign(newFirstMoment);
                    weightedInfNorm.assign(newWeightedInfNorm);
                    // value <- value - (lr / (1 - beta1^t)) * m / (u + epsilon)
                    const newValue = add$3(mul(div$1(lr, oneMinusAccBeta1), div$1(newFirstMoment, add$3(newWeightedInfNorm, this.epsilon))), value);
                    value.assign(newValue);
                });
                // Advance the decay step and the beta1^t tracker.
                this.iteration.assign(add$3(this.iteration, 1));
                this.accBeta1.assign(mul(this.accBeta1, this.beta1));
            });
            this.incrementIterations();
        }
        dispose() {
            this.accBeta1.dispose();
            this.iteration.dispose();
            if (this.accumulatedFirstMoment != null) {
                dispose(this.accumulatedFirstMoment.map(v => v.variable));
            }
            if (this.accumulatedWeightedInfNorm != null) {
                dispose(this.accumulatedWeightedInfNorm.map(v => v.variable));
            }
        }
        // Weight (de)serialization is not yet supported for Adamax upstream.
        async getWeights() {
            throw new Error('getWeights() is not implemented for Adamax yet.');
        }
        async setWeights(weightValues) {
            throw new Error('setWeights() is not implemented for Adamax yet.');
        }
        getConfig() {
            return {
                'learningRate': this.learningRate,
                'beta1': this.beta1,
                'beta2': this.beta2,
                'epsilon': this.epsilon,
                'decay': this.decay
            };
        }
        /** @nocollapse */
        static fromConfig(cls, config) {
            return new cls(config['learningRate'], config['beta1'], config['beta2'], config['epsilon'], config['decay']);
        }
    }
25171
25172 /**
25173 * @license
25174 * Copyright 2018 Google LLC. All Rights Reserved.
25175 * Licensed under the Apache License, Version 2.0 (the "License");
25176 * you may not use this file except in compliance with the License.
25177 * You may obtain a copy of the License at
25178 *
25179 * http://www.apache.org/licenses/LICENSE-2.0
25180 *
25181 * Unless required by applicable law or agreed to in writing, software
25182 * distributed under the License is distributed on an "AS IS" BASIS,
25183 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25184 * See the License for the specific language governing permissions and
25185 * limitations under the License.
25186 * =============================================================================
25187 */
25188 /** @doclink Optimizer */
25189 class SGDOptimizer extends Optimizer {
25190 /** @nocollapse */
25191 static get className() {
25192 // Name matters for Python compatibility.
25193 // This is a getter instead of a property because when it's a property, it
25194 // prevents the entire class from being tree-shaken.
25195 return 'SGD';
25196 }
25197 constructor(learningRate) {
25198 super();
25199 this.learningRate = learningRate;
25200 this.setLearningRate(learningRate);
25201 }
25202 applyGradients(variableGradients) {
25203 const varNames = Array.isArray(variableGradients) ?
25204 variableGradients.map(v => v.name) :
25205 Object.keys(variableGradients);
25206 varNames.forEach((name, i) => {
25207 const gradient = Array.isArray(variableGradients) ?
25208 variableGradients[i].tensor :
25209 variableGradients[name];
25210 if (gradient == null) {
25211 return;
25212 }
25213 const value = ENGINE.registeredVariables[name];
25214 tidy(() => {
25215 const newValue = add$3(mul(this.c, gradient), value);
25216 value.assign(newValue);
25217 });
25218 });
25219 this.incrementIterations();
25220 }
25221 /**
25222 * Sets the learning rate of the optimizer.
25223 */
25224 setLearningRate(learningRate) {
25225 this.learningRate = learningRate;
25226 if (this.c != null) {
25227 this.c.dispose();
25228 }
25229 this.c = keep(scalar(-learningRate));
25230 }
25231 dispose() {
25232 this.c.dispose();
25233 }
25234 async getWeights() {
25235 return [await this.saveIterations()];
25236 }
25237 async setWeights(weightValues) {
25238 weightValues = await this.extractIterations(weightValues);
25239 if (weightValues.length !== 0) {
25240 throw new Error('SGD optimizer does not have settable weights.');
25241 }
25242 }
25243 getConfig() {
25244 return { 'learningRate': this.learningRate };
25245 }
25246 /** @nocollapse */
25247 static fromConfig(cls, config) {
25248 return new cls(config['learningRate']);
25249 }
25250 }
25251
25252 /**
25253 * @license
25254 * Copyright 2018 Google LLC. All Rights Reserved.
25255 * Licensed under the Apache License, Version 2.0 (the "License");
25256 * you may not use this file except in compliance with the License.
25257 * You may obtain a copy of the License at
25258 *
25259 * http://www.apache.org/licenses/LICENSE-2.0
25260 *
25261 * Unless required by applicable law or agreed to in writing, software
25262 * distributed under the License is distributed on an "AS IS" BASIS,
25263 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25264 * See the License for the specific language governing permissions and
25265 * limitations under the License.
25266 * =============================================================================
25267 */
    /** @doclink Optimizer */
    class MomentumOptimizer extends SGDOptimizer {
        /** @nocollapse */
        static get className() {
            // Name matters for Python compatibility.
            // This is a getter instead of a property because when it's a property, it
            // prevents the entire class from being tree-shaken.
            return 'Momentum';
        }
        /**
         * @param learningRate Step size used by the parent SGDOptimizer.
         * @param momentum Decay factor applied to the accumulated gradient
         *     history on every step.
         * @param useNesterov Whether to use Nesterov-style (look-ahead)
         *     momentum instead of classic momentum. Defaults to false.
         */
        constructor(learningRate, momentum, useNesterov = false) {
            super(learningRate);
            this.learningRate = learningRate;
            this.momentum = momentum;
            this.useNesterov = useNesterov;
            // One accumulator slot per trained variable; created lazily in
            // applyGradients() and indexed by the variable's position.
            this.accumulations = [];
            // Momentum kept as a scalar tensor so it can be used in tensor ops.
            this.m = scalar(this.momentum);
        }
        /**
         * Applies one momentum update to every registered variable:
         *
         *   acc_new  = m * acc + grad
         *   classic:  value += c * acc_new
         *   Nesterov: value += c * (grad + m * acc_new)
         *
         * NOTE(review): `this.c` is inherited from SGDOptimizer and is not
         * visible in this chunk — presumably scalar(-learningRate), which
         * would make the additions above descent steps; TODO confirm.
         *
         * @param variableGradients Either an array of {name, tensor} pairs or
         *     a map from variable name to gradient tensor.
         */
        applyGradients(variableGradients) {
            const variableNames = Array.isArray(variableGradients) ?
                variableGradients.map(item => item.name) :
                Object.keys(variableGradients);
            variableNames.forEach((name, i) => {
                const value = ENGINE.registeredVariables[name];
                if (this.accumulations[i] == null) {
                    const trainable = false;
                    this.accumulations[i] = {
                        // Slot name mirrors TF Python for checkpoint compatibility.
                        originalName: `${name}/momentum`,
                        variable: tidy(() => zerosLike$3(value).variable(trainable))
                    };
                }
                const accumulation = this.accumulations[i].variable;
                const gradient = Array.isArray(variableGradients) ?
                    variableGradients[i].tensor :
                    variableGradients[name];
                // A variable with no gradient this step is left untouched.
                if (gradient == null) {
                    return;
                }
                // tidy() disposes all intermediates except the assigned values.
                tidy(() => {
                    let newValue;
                    const newAccumulation = add$3(mul(this.m, accumulation), gradient);
                    if (this.useNesterov) {
                        newValue = add$3(mul(this.c, add$3(gradient, mul(newAccumulation, this.m))), value);
                    }
                    else {
                        newValue = add$3(mul(this.c, newAccumulation), value);
                    }
                    accumulation.assign(newAccumulation);
                    value.assign(newValue);
                });
            });
            this.incrementIterations();
        }
        /** Releases the momentum scalar and all accumulator variables. */
        dispose() {
            this.m.dispose();
            if (this.accumulations != null) {
                dispose(this.accumulations.map(v => v.variable));
            }
        }
        /**
         * Sets the momentum of the optimizer.
         *
         * NOTE(review): this updates only `this.momentum`; the scalar tensor
         * `this.m` consumed by applyGradients() is NOT rebuilt here, so the
         * effective momentum of subsequent steps is unchanged — looks like a
         * latent upstream quirk; confirm before relying on this method.
         *
         * @param momentum
         */
        setMomentum(momentum) {
            this.momentum = momentum;
        }
        /** Serializes the iteration counter followed by accumulator slots. */
        async getWeights() {
            // Order matters for Python compatibility.
            return [await this.saveIterations()].concat(this.accumulations.map(v => ({ name: v.originalName, tensor: v.variable })));
        }
        /** Restores accumulator slots; inverse of getWeights(). */
        async setWeights(weightValues) {
            weightValues = await this.extractIterations(weightValues);
            const trainable = false;
            this.accumulations = weightValues.map(v => ({ originalName: v.name, variable: v.tensor.variable(trainable) }));
        }
        /** Returns the serializable configuration of this optimizer. */
        getConfig() {
            return {
                'learningRate': this.learningRate,
                'momentum': this.momentum,
                'useNesterov': this.useNesterov
            };
        }
        /** @nocollapse */
        static fromConfig(cls, config) {
            return new cls(config['learningRate'], config['momentum'], config['useNesterov']);
        }
    }
25356
25357 /**
25358 * @license
25359 * Copyright 2018 Google LLC. All Rights Reserved.
25360 * Licensed under the Apache License, Version 2.0 (the "License");
25361 * you may not use this file except in compliance with the License.
25362 * You may obtain a copy of the License at
25363 *
25364 * http://www.apache.org/licenses/LICENSE-2.0
25365 *
25366 * Unless required by applicable law or agreed to in writing, software
25367 * distributed under the License is distributed on an "AS IS" BASIS,
25368 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25369 * See the License for the specific language governing permissions and
25370 * limitations under the License.
25371 * =============================================================================
25372 */
25373 /** @doclink Optimizer */
25374 class RMSPropOptimizer extends Optimizer {
25375 /** @nocollapse */
25376 static get className() {
25377 // Name matters for Python compatibility.
25378 // This is a getter instead of a property because when it's a property, it
25379 // prevents the entire class from being tree-shaken.
25380 return 'RMSProp';
25381 }
25382 constructor(learningRate, decay = 0.9, momentum = 0.0, epsilon = null, centered = false) {
25383 super();
25384 this.learningRate = learningRate;
25385 this.decay = decay;
25386 this.momentum = momentum;
25387 this.epsilon = epsilon;
25388 this.accumulatedMeanSquares = [];
25389 this.accumulatedMoments = [];
25390 this.accumulatedMeanGrads = [];
25391 this.centered = centered;
25392 if (epsilon == null) {
25393 this.epsilon = ENGINE.backend.epsilon();
25394 }
25395 if (learningRate == null) {
25396 throw new Error(`learningRate for RMSPropOptimizer must be defined.`);
25397 }
25398 }
25399 applyGradients(variableGradients) {
25400 const variableNames = Array.isArray(variableGradients) ?
25401 variableGradients.map(item => item.name) :
25402 Object.keys(variableGradients);
25403 variableNames.forEach((name, i) => {
25404 const value = ENGINE.registeredVariables[name];
25405 const trainable = false;
25406 if (this.accumulatedMeanSquares[i] == null) {
25407 this.accumulatedMeanSquares[i] = {
25408 originalName: `${name}/rms`,
25409 variable: tidy(() => zerosLike$3(value).variable(trainable))
25410 };
25411 }
25412 if (this.accumulatedMoments[i] == null) {
25413 this.accumulatedMoments[i] = {
25414 originalName: `${name}/momentum`,
25415 variable: tidy(() => zerosLike$3(value).variable(trainable))
25416 };
25417 }
25418 if (this.accumulatedMeanGrads[i] == null && this.centered) {
25419 this.accumulatedMeanGrads[i] = {
25420 originalName: `${name}/mg`,
25421 variable: tidy(() => zerosLike$3(value).variable(trainable))
25422 };
25423 }
25424 const gradient = Array.isArray(variableGradients) ?
25425 variableGradients[i].tensor :
25426 variableGradients[name];
25427 if (gradient == null) {
25428 return;
25429 }
25430 const accumulatedMeanSquare = this.accumulatedMeanSquares[i].variable;
25431 const accumulatedMoments = this.accumulatedMoments[i].variable;
25432 tidy(() => {
25433 const newAccumulatedMeanSquare = add$3(mul(accumulatedMeanSquare, this.decay), mul(square$2(gradient), 1 - this.decay));
25434 if (this.centered) {
25435 const accumulatedMeanGrad = this.accumulatedMeanGrads[i].variable;
25436 // Centered gradient
25437 const newAccumulatedMeanGrad = add$3(mul(accumulatedMeanGrad, this.decay), mul(gradient, 1 - this.decay));
25438 const gradContribution = div$1(mul(gradient, this.learningRate), sqrt$2(sub$2(newAccumulatedMeanSquare, add$3(square$2(newAccumulatedMeanGrad), this.epsilon))));
25439 const newAccumulatedMoments = add$3(mul(accumulatedMoments, this.momentum), gradContribution);
25440 accumulatedMeanSquare.assign(newAccumulatedMeanSquare);
25441 accumulatedMeanGrad.assign(newAccumulatedMeanGrad);
25442 accumulatedMoments.assign(newAccumulatedMoments);
25443 const newValue = sub$2(value, newAccumulatedMoments);
25444 value.assign(newValue);
25445 }
25446 else {
25447 // Plain gradient
25448 const newAccumulatedMeanSquare = add$3(mul(accumulatedMeanSquare, this.decay), mul(square$2(gradient), 1 - this.decay));
25449 const newAccumulatedMoments = add$3(mul(accumulatedMoments, this.momentum), div$1(mul(gradient, this.learningRate), sqrt$2(add$3(newAccumulatedMeanSquare, this.epsilon))));
25450 accumulatedMeanSquare.assign(newAccumulatedMeanSquare);
25451 accumulatedMoments.assign(newAccumulatedMoments);
25452 const newValue = sub$2(value, newAccumulatedMoments);
25453 value.assign(newValue);
25454 }
25455 });
25456 });
25457 this.incrementIterations();
25458 }
25459 dispose() {
25460 if (this.accumulatedMeanSquares != null) {
25461 dispose(this.accumulatedMeanSquares.map(v => v.variable));
25462 }
25463 if (this.accumulatedMeanGrads != null && this.centered) {
25464 dispose(this.accumulatedMeanGrads.map(v => v.variable));
25465 }
25466 if (this.accumulatedMoments != null) {
25467 dispose(this.accumulatedMoments.map(v => v.variable));
25468 }
25469 }
25470 async getWeights() {
25471 // Order matters for Python compatibility.
25472 const variables = [...this.accumulatedMeanSquares, ...this.accumulatedMoments];
25473 if (this.centered) {
25474 variables.push(...this.accumulatedMeanGrads);
25475 }
25476 return [await this.saveIterations()].concat(variables.map(v => ({ name: v.originalName, tensor: v.variable })));
25477 }
25478 async setWeights(weightValues) {
25479 weightValues = await this.extractIterations(weightValues);
25480 const variableCount = this.centered ? weightValues.length / 3 : weightValues.length / 2;
25481 const trainable = false;
25482 this.accumulatedMeanSquares =
25483 weightValues.slice(0, variableCount).map(v => ({
25484 originalName: v.name,
25485 variable: v.tensor.variable(trainable)
25486 }));
25487 this.accumulatedMoments =
25488 weightValues.slice(variableCount, variableCount * 2)
25489 .map(v => ({
25490 originalName: v.name,
25491 variable: v.tensor.variable(trainable)
25492 }));
25493 if (this.centered) {
25494 this.accumulatedMeanGrads =
25495 weightValues.slice(variableCount * 2, variableCount * 3)
25496 .map(v => ({
25497 originalName: v.name,
25498 variable: v.tensor.variable(trainable)
25499 }));
25500 }
25501 }
25502 getConfig() {
25503 return {
25504 'learningRate': this.learningRate,
25505 'decay': this.decay,
25506 'momentum': this.momentum,
25507 'epsilon': this.epsilon,
25508 'centered': this.centered
25509 };
25510 }
25511 /** @nocollapse */
25512 static fromConfig(cls, config) {
25513 return new cls(config['learningRate'], config['decay'], config['momentum'], config['epsilon'], config['centered']);
25514 }
25515 }
25516
25517 /**
25518 * @license
25519 * Copyright 2022 Google LLC.
25520 * Licensed under the Apache License, Version 2.0 (the "License");
25521 * you may not use this file except in compliance with the License.
25522 * You may obtain a copy of the License at
25523 *
25524 * http://www.apache.org/licenses/LICENSE-2.0
25525 *
25526 * Unless required by applicable law or agreed to in writing, software
25527 * distributed under the License is distributed on an "AS IS" BASIS,
25528 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25529 * See the License for the specific language governing permissions and
25530 * limitations under the License.
25531 * =============================================================================
25532 */
25533 const OPTIMIZERS = [
25534 AdadeltaOptimizer,
25535 AdagradOptimizer,
25536 AdamOptimizer,
25537 AdamaxOptimizer,
25538 MomentumOptimizer,
25539 RMSPropOptimizer,
25540 SGDOptimizer,
25541 ];
25542 function registerOptimizers() {
25543 for (const optimizer of OPTIMIZERS) {
25544 registerClass(optimizer);
25545 }
25546 }
25547
25548 /**
25549 * @license
25550 * Copyright 2018 Google LLC. All Rights Reserved.
25551 * Licensed under the Apache License, Version 2.0 (the "License");
25552 * you may not use this file except in compliance with the License.
25553 * You may obtain a copy of the License at
25554 *
25555 * http://www.apache.org/licenses/LICENSE-2.0
25556 *
25557 * Unless required by applicable law or agreed to in writing, software
25558 * distributed under the License is distributed on an "AS IS" BASIS,
25559 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25560 * See the License for the specific language governing permissions and
25561 * limitations under the License.
25562 * =============================================================================
25563 */
25564 const DEFAULT_FILE_NAME_PREFIX = 'model';
25565 const DEFAULT_JSON_EXTENSION_NAME = '.json';
25566 const DEFAULT_WEIGHT_DATA_EXTENSION_NAME = '.weights.bin';
25567 function defer(f) {
25568 return new Promise(resolve => setTimeout(resolve)).then(f);
25569 }
25570 class BrowserDownloads {
25571 constructor(fileNamePrefix) {
25572 if (!env().getBool('IS_BROWSER')) {
25573 // TODO(cais): Provide info on what IOHandlers are available under the
25574 // current environment.
25575 throw new Error('browserDownloads() cannot proceed because the current environment ' +
25576 'is not a browser.');
25577 }
25578 if (fileNamePrefix.startsWith(BrowserDownloads.URL_SCHEME)) {
25579 fileNamePrefix = fileNamePrefix.slice(BrowserDownloads.URL_SCHEME.length);
25580 }
25581 if (fileNamePrefix == null || fileNamePrefix.length === 0) {
25582 fileNamePrefix = DEFAULT_FILE_NAME_PREFIX;
25583 }
25584 this.modelJsonFileName = fileNamePrefix + DEFAULT_JSON_EXTENSION_NAME;
25585 this.weightDataFileName =
25586 fileNamePrefix + DEFAULT_WEIGHT_DATA_EXTENSION_NAME;
25587 }
25588 async save(modelArtifacts) {
25589 if (typeof (document) === 'undefined') {
25590 throw new Error('Browser downloads are not supported in ' +
25591 'this environment since `document` is not present');
25592 }
25593 // TODO(mattsoulanille): Support saving models over 2GB that exceed
25594 // Chrome's ArrayBuffer size limit.
25595 const weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData);
25596 const weightsURL = window.URL.createObjectURL(new Blob([weightBuffer], { type: 'application/octet-stream' }));
25597 if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
25598 throw new Error('BrowserDownloads.save() does not support saving model topology ' +
25599 'in binary formats yet.');
25600 }
25601 else {
25602 const weightsManifest = [{
25603 paths: ['./' + this.weightDataFileName],
25604 weights: modelArtifacts.weightSpecs
25605 }];
25606 const modelJSON = getModelJSONForModelArtifacts(modelArtifacts, weightsManifest);
25607 const modelJsonURL = window.URL.createObjectURL(new Blob([JSON.stringify(modelJSON)], { type: 'application/json' }));
25608 // If anchor elements are not provided, create them without attaching them
25609 // to parents, so that the downloaded file names can be controlled.
25610 const jsonAnchor = this.modelJsonAnchor == null ?
25611 document.createElement('a') :
25612 this.modelJsonAnchor;
25613 jsonAnchor.download = this.modelJsonFileName;
25614 jsonAnchor.href = modelJsonURL;
25615 // Trigger downloads by evoking a click event on the download anchors.
25616 // When multiple downloads are started synchronously, Firefox will only
25617 // save the last one.
25618 await defer(() => jsonAnchor.dispatchEvent(new MouseEvent('click')));
25619 if (modelArtifacts.weightData != null) {
25620 const weightDataAnchor = this.weightDataAnchor == null ?
25621 document.createElement('a') :
25622 this.weightDataAnchor;
25623 weightDataAnchor.download = this.weightDataFileName;
25624 weightDataAnchor.href = weightsURL;
25625 await defer(() => weightDataAnchor.dispatchEvent(new MouseEvent('click')));
25626 }
25627 return { modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts) };
25628 }
25629 }
25630 }
25631 BrowserDownloads.URL_SCHEME = 'downloads://';
    /**
     * IOHandler that loads model artifacts from user-selected files: one model
     * JSON file followed by zero or more binary weight files.
     */
    class BrowserFiles {
        constructor(files) {
            if (files == null || files.length < 1) {
                throw new Error(`When calling browserFiles, at least 1 file is required, ` +
                    `but received ${files}`);
            }
            // First file is the model JSON; any remaining files are weight shards.
            this.jsonFile = files[0];
            this.weightsFiles = files.slice(1);
        }
        /**
         * Reads the model JSON file; when weight files were supplied, resolves
         * with full model artifacts (topology + decoded weights), otherwise
         * with the topology only.
         *
         * Rejects when the JSON lacks modelTopology/weightsManifest or the
         * file cannot be read.
         */
        async load() {
            return new Promise((resolve, reject) => {
                const jsonReader = new FileReader();
                jsonReader.onload = (event) => {
                    // tslint:disable-next-line:no-any
                    const modelJSON = JSON.parse(event.target.result);
                    const modelTopology = modelJSON.modelTopology;
                    if (modelTopology == null) {
                        reject(new Error(`modelTopology field is missing from file ${this.jsonFile.name}`));
                        return;
                    }
                    const weightsManifest = modelJSON.weightsManifest;
                    if (weightsManifest == null) {
                        reject(new Error(`weightManifest field is missing from file ${this.jsonFile.name}`));
                        return;
                    }
                    // Topology-only load when no weight files were provided.
                    if (this.weightsFiles.length === 0) {
                        resolve({ modelTopology });
                        return;
                    }
                    const modelArtifactsPromise = getModelArtifactsForJSON(modelJSON, (weightsManifest) => this.loadWeights(weightsManifest));
                    resolve(modelArtifactsPromise);
                };
                jsonReader.onerror = error => reject(`Failed to read model topology and weights manifest JSON ` +
                    `from file '${this.jsonFile.name}'. BrowserFiles supports loading ` +
                    `Keras-style tf.Model artifacts only.`);
                jsonReader.readAsText(this.jsonFile);
            });
        }
        /**
         * Loads all weight files referenced by the manifest.
         *
         * @param weightsManifest Manifest groups listing weight specs + paths.
         * @returns Promise of [weightSpecs, buffers], with buffers in the same
         *     order as the manifest paths.
         */
        loadWeights(weightsManifest) {
            const weightSpecs = [];
            const paths = [];
            for (const entry of weightsManifest) {
                weightSpecs.push(...entry.weights);
                paths.push(...entry.paths);
            }
            // Validates the manifest against the provided files and maps each
            // manifest path to its matching File object.
            const pathToFile = this.checkManifestAndWeightFiles(weightsManifest);
            const promises = paths.map(path => this.loadWeightsFile(path, pathToFile[path]));
            return Promise.all(promises).then(buffers => [weightSpecs, buffers]);
        }
        /** Reads a single weight file into an ArrayBuffer. */
        loadWeightsFile(path, file) {
            return new Promise((resolve, reject) => {
                const weightFileReader = new FileReader();
                weightFileReader.onload = (event) => {
                    // tslint:disable-next-line:no-any
                    const weightData = event.target.result;
                    resolve(weightData);
                };
                weightFileReader.onerror = error => reject(`Failed to weights data from file of path '${path}'.`);
                weightFileReader.readAsArrayBuffer(file);
            });
        }
        /**
         * Check the compatibility between weights manifest and weight files.
         *
         * Ensures every manifest path basename is unique, that each one has a
         * matching user-provided file, and that no extra files were supplied.
         *
         * @returns Map from manifest path to the corresponding File.
         */
        checkManifestAndWeightFiles(manifest) {
            const basenames = [];
            const fileNames = this.weightsFiles.map(file => basename(file.name));
            const pathToFile = {};
            for (const group of manifest) {
                group.paths.forEach(path => {
                    const pathBasename = basename(path);
                    if (basenames.indexOf(pathBasename) !== -1) {
                        throw new Error(`Duplicate file basename found in weights manifest: ` +
                            `'${pathBasename}'`);
                    }
                    basenames.push(pathBasename);
                    if (fileNames.indexOf(pathBasename) === -1) {
                        throw new Error(`Weight file with basename '${pathBasename}' is not provided.`);
                    }
                    else {
                        pathToFile[path] = this.weightsFiles[fileNames.indexOf(pathBasename)];
                    }
                });
            }
            if (basenames.length !== this.weightsFiles.length) {
                throw new Error(`Mismatch in the number of files in weights manifest ` +
                    `(${basenames.length}) and the number of weight files provided ` +
                    `(${this.weightsFiles.length}).`);
            }
            return pathToFile;
        }
    }
25724 const browserDownloadsRouter = (url) => {
25725 if (!env().getBool('IS_BROWSER')) {
25726 return null;
25727 }
25728 else {
25729 if (!Array.isArray(url) && url.startsWith(BrowserDownloads.URL_SCHEME)) {
25730 return browserDownloads(url.slice(BrowserDownloads.URL_SCHEME.length));
25731 }
25732 else {
25733 return null;
25734 }
25735 }
25736 };
25737 IORouterRegistry.registerSaveRouter(browserDownloadsRouter);
25738 /**
25739 * Creates an IOHandler that triggers file downloads from the browser.
25740 *
25741 * The returned `IOHandler` instance can be used as model exporting methods such
25742 * as `tf.Model.save` and supports only saving.
25743 *
25744 * ```js
25745 * const model = tf.sequential();
25746 * model.add(tf.layers.dense(
25747 * {units: 1, inputShape: [10], activation: 'sigmoid'}));
25748 * const saveResult = await model.save('downloads://mymodel');
25749 * // This will trigger downloading of two files:
25750 * // 'mymodel.json' and 'mymodel.weights.bin'.
25751 * console.log(saveResult);
25752 * ```
25753 *
25754 * @param fileNamePrefix Prefix name of the files to be downloaded. For use with
25755 * `tf.Model`, `fileNamePrefix` should follow either of the following two
25756 * formats:
25757 * 1. `null` or `undefined`, in which case the default file
25758 * names will be used:
25759 * - 'model.json' for the JSON file containing the model topology and
25760 * weights manifest.
25761 * - 'model.weights.bin' for the binary file containing the binary weight
25762 * values.
25763 * 2. A single string or an Array of a single string, as the file name prefix.
25764 * For example, if `'foo'` is provided, the downloaded JSON
25765 * file and binary weights file will be named 'foo.json' and
25766 * 'foo.weights.bin', respectively.
25767 * @param config Additional configuration for triggering downloads.
25768 * @returns An instance of `BrowserDownloads` `IOHandler`.
25769 *
25770 * @doc {
25771 * heading: 'Models',
25772 * subheading: 'Loading',
25773 * namespace: 'io',
25774 * ignoreCI: true
25775 * }
25776 */
25777 function browserDownloads(fileNamePrefix = 'model') {
25778 return new BrowserDownloads(fileNamePrefix);
25779 }
25780 /**
25781 * Creates an IOHandler that loads model artifacts from user-selected files.
25782 *
25783 * This method can be used for loading from files such as user-selected files
25784 * in the browser.
25785 * When used in conjunction with `tf.loadLayersModel`, an instance of
25786 * `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts.
25787 *
25788 * ```js
25789 * // Note: This code snippet won't run properly without the actual file input
25790 * // elements in the HTML DOM.
25791 *
25792 * // Suppose there are two HTML file input (`<input type="file" ...>`)
25793 * // elements.
25794 * const uploadJSONInput = document.getElementById('upload-json');
25795 * const uploadWeightsInput = document.getElementById('upload-weights');
25796 * const model = await tf.loadLayersModel(tf.io.browserFiles(
25797 * [uploadJSONInput.files[0], uploadWeightsInput.files[0]]));
25798 * ```
25799 *
25800 * @param files `File`s to load from. Currently, this function supports only
25801 * loading from files that contain Keras-style models (i.e., `tf.Model`s), for
25802 * which an `Array` of `File`s is expected (in that order):
25803 * - A JSON file containing the model topology and weight manifest.
25804 * - Optionally, one or more binary files containing the binary weights.
25805 * These files must have names that match the paths in the `weightsManifest`
25806 * contained by the aforementioned JSON file, or errors will be thrown
25807 * during loading. These weights files have the same format as the ones
25808 * generated by `tensorflowjs_converter` that comes with the `tensorflowjs`
25809 * Python PIP package. If no weights files are provided, only the model
25810 * topology will be loaded from the JSON file above.
25811 * @returns An instance of `Files` `IOHandler`.
25812 *
25813 * @doc {
25814 * heading: 'Models',
25815 * subheading: 'Loading',
25816 * namespace: 'io',
25817 * ignoreCI: true
25818 * }
25819 */
25820 function browserFiles(files) {
25821 return new BrowserFiles(files);
25822 }
25823
25824 /**
25825 * @license
25826 * Copyright 2019 Google LLC. All Rights Reserved.
25827 * Licensed under the Apache License, Version 2.0 (the "License");
25828 * you may not use this file except in compliance with the License.
25829 * You may obtain a copy of the License at
25830 *
25831 * http://www.apache.org/licenses/LICENSE-2.0
25832 *
25833 * Unless required by applicable law or agreed to in writing, software
25834 * distributed under the License is distributed on an "AS IS" BASIS,
25835 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25836 * See the License for the specific language governing permissions and
25837 * limitations under the License.
25838 * =============================================================================
25839 */
25840 /**
25841 * Monitor Promise.all progress, fire onProgress callback function.
25842 *
25843 * @param promises Promise list going to be monitored
25844 * @param onProgress Callback function. Fired when a promise resolved.
25845 * @param startFraction Optional fraction start. Default to 0.
25846 * @param endFraction Optional fraction end. Default to 1.
25847 */
25848 function monitorPromisesProgress(promises, onProgress, startFraction, endFraction) {
25849 checkPromises(promises);
25850 startFraction = startFraction == null ? 0 : startFraction;
25851 endFraction = endFraction == null ? 1 : endFraction;
25852 checkFraction(startFraction, endFraction);
25853 let resolvedPromise = 0;
25854 const registerMonitor = (promise) => {
25855 promise.then(value => {
25856 const fraction = startFraction +
25857 ++resolvedPromise / promises.length * (endFraction - startFraction);
25858 // pass fraction as parameter to callback function.
25859 onProgress(fraction);
25860 return value;
25861 });
25862 return promise;
25863 };
25864 function checkPromises(promises) {
25865 assert$1(promises != null && Array.isArray(promises) && promises.length > 0, () => 'promises must be a none empty array');
25866 }
25867 function checkFraction(startFraction, endFraction) {
25868 assert$1(startFraction >= 0 && startFraction <= 1, () => `Progress fraction must be in range [0, 1], but ` +
25869 `got startFraction ${startFraction}`);
25870 assert$1(endFraction >= 0 && endFraction <= 1, () => `Progress fraction must be in range [0, 1], but ` +
25871 `got endFraction ${endFraction}`);
25872 assert$1(endFraction >= startFraction, () => `startFraction must be no more than endFraction, but ` +
25873 `got startFraction ${startFraction} and endFraction ` +
25874 `${endFraction}`);
25875 }
25876 return Promise.all(promises.map(registerMonitor));
25877 }
25878
25879 /**
25880 * @license
25881 * Copyright 2018 Google LLC. All Rights Reserved.
25882 * Licensed under the Apache License, Version 2.0 (the "License");
25883 * you may not use this file except in compliance with the License.
25884 * You may obtain a copy of the License at
25885 *
25886 * http://www.apache.org/licenses/LICENSE-2.0
25887 *
25888 * Unless required by applicable law or agreed to in writing, software
25889 * distributed under the License is distributed on an "AS IS" BASIS,
25890 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25891 * See the License for the specific language governing permissions and
25892 * limitations under the License.
25893 * =============================================================================
25894 */
25895 /**
25896 * Reads binary weights data from a number of URLs.
25897 *
25898 * @param fetchURLs URLs to send the HTTP requests at, using `fetch` calls.
25899 * @param requestOptions RequestInit (options) for the HTTP requests.
25900 * @param fetchFunc Optional overriding value for the `window.fetch` function.
25901 * @param onProgress Optional, progress callback function, fired periodically
25902 * before the load is completed.
25903 * @returns A `Promise` of an Array of `ArrayBuffer`. The Array has the same
25904 * length as `fetchURLs`.
25905 */
25906 async function loadWeightsAsArrayBuffer(fetchURLs, loadOptions) {
25907 if (loadOptions == null) {
25908 loadOptions = {};
25909 }
25910 const fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch :
25911 loadOptions.fetchFunc;
25912 // Create the requests for all of the weights in parallel.
25913 const requests = fetchURLs.map(fetchURL => fetchFunc(fetchURL, loadOptions.requestInit, { isBinary: true }));
25914 const fetchStartFraction = 0;
25915 const fetchEndFraction = 0.5;
25916 const responses = loadOptions.onProgress == null ?
25917 await Promise.all(requests) :
25918 await monitorPromisesProgress(requests, loadOptions.onProgress, fetchStartFraction, fetchEndFraction);
25919 const bufferPromises = responses.map(response => response.arrayBuffer());
25920 const bufferStartFraction = 0.5;
25921 const bufferEndFraction = 1;
25922 const buffers = loadOptions.onProgress == null ?
25923 await Promise.all(bufferPromises) :
25924 await monitorPromisesProgress(bufferPromises, loadOptions.onProgress, bufferStartFraction, bufferEndFraction);
25925 return buffers;
25926 }
    /**
     * Returns a ReadableStream yielding the bytes of the given weight files in
     * order, fetching each file lazily as the consumer pulls.
     *
     * Progress (when `loadOptions.onProgress` is provided) is reported as the
     * fraction of files fully consumed, starting at 0. The `_a` temporaries
     * are transpiler output for optional chaining on onProgress.
     */
    function streamWeights(fetchURLs, loadOptions) {
        var _a;
        const fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch :
            loadOptions.fetchFunc;
        let fetchIndex = 0;
        let chunkReader;
        (_a = loadOptions.onProgress) === null || _a === void 0 ? void 0 : _a.call(loadOptions, 0);
        return new ReadableStream({
            pull: async (controller) => {
                var _a;
                // Each pull enqueues exactly one chunk (or closes the stream).
                while (fetchIndex < fetchURLs.length) {
                    if (!chunkReader) {
                        // Begin streaming the next file's response body.
                        const body = (await fetchFunc(fetchURLs[fetchIndex], loadOptions.requestInit, { isBinary: true })).body;
                        chunkReader = body.getReader();
                    }
                    const { done, value } = await chunkReader.read();
                    if (done) {
                        // Current file exhausted: advance to the next URL and
                        // report file-level progress.
                        fetchIndex++;
                        chunkReader = undefined;
                        (_a = loadOptions.onProgress) === null || _a === void 0 ? void 0 : _a.call(loadOptions, fetchIndex / fetchURLs.length);
                        continue;
                    }
                    controller.enqueue(value);
                    return;
                }
                controller.close();
            },
        });
    }
25956 /**
25957 * Reads a weights manifest JSON configuration, fetches the weights and
25958 * returns them as `Tensor`s.
25959 *
25960 * @param manifest The weights manifest JSON.
25961 * @param filePathPrefix The path prefix for filenames given in the manifest.
25962 * Defaults to the empty string.
25963 * @param weightNames The names of the weights to be fetched.
25964 */
25965 async function loadWeights(manifest, filePathPrefix = '', weightNames, requestInit) {
25966 // TODO(nsthorat): Groups are currently fetched atomically. If you need a
25967 // single weight from a group, the whole group will be fetched. At a future
25968 // date, we should support fetching only the individual shards within a
25969 // group that are needed to reconstruct the requested weight.
25970 // TODO(cais): Use `decodeWeights` for implementation.
25971 const fetchWeights = (fetchUrls) => loadWeightsAsArrayBuffer(fetchUrls, { requestInit });
25972 const loadWeights = weightsLoaderFactory(fetchWeights);
25973 return loadWeights(manifest, filePathPrefix, weightNames);
25974 }
25975 /**
25976 * Creates a function, which reads a weights manifest JSON configuration,
25977 * fetches the weight files using the specified function and returns them as
25978 * `Tensor`s.
25979 *
25980 * ```js
25981 * // example for creating a nodejs weight loader, which reads the weight files
25982 * // from disk using fs.readFileSync
25983 *
25984 * import * as fs from 'fs'
25985 *
25986 * const fetchWeightsFromDisk = (filePaths: string[]) =>
25987 * filePaths.map(filePath => fs.readFileSync(filePath).buffer)
25988 *
25989 * const loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk)
25990 *
25991 * const manifest = JSON.parse(
25992 * fs.readFileSync('./my_model-weights_manifest').toString()
25993 * )
25994 * const weightMap = await loadWeights(manifest, './')
25995 * ```
25996 * @param fetchWeightsFunction The function used for fetching the weight files.
25997 * @returns Weight loading function.
25998 */
    /**
     * Creates a weights-loading function from a raw byte-fetching function.
     * The returned async function takes (manifest, filePathPrefix,
     * weightNames) and resolves to a name-to-Tensor map.
     */
    function weightsLoaderFactory(fetchWeightsFunction) {
        return async (manifest, filePathPrefix = '', weightNames) => {
            // Collect all the groups, weights, and their relative offsets to be
            // fetched.
            const groupIndicesToFetchMap = manifest.map(() => false);
            const groupWeightsToFetch = {};
            // Tracks which requested weightNames have been located in the
            // manifest; used for the missing-weights error below.
            const weightsFound = weightNames != null ? weightNames.map(() => false) : [];
            const allManifestWeightNames = [];
            manifest.forEach((manifestGroupConfig, groupIndex) => {
                let groupOffset = 0;
                manifestGroupConfig.weights.forEach(weightsEntry => {
                    // Quantized entries carry their on-disk dtype under
                    // `quantization.dtype`.
                    const rawDtype = ('quantization' in weightsEntry) ?
                        weightsEntry.quantization.dtype :
                        weightsEntry.dtype;
                    const weightsBytes = DTYPE_VALUE_SIZE_MAP[rawDtype] *
                        sizeFromShape(weightsEntry.shape);
                    // Marks this entry's group for fetching and records the
                    // entry's byte span (offset + size) within that group.
                    const enqueueWeightsForFetchingFn = () => {
                        groupIndicesToFetchMap[groupIndex] = true;
                        if (groupWeightsToFetch[groupIndex] == null) {
                            groupWeightsToFetch[groupIndex] = [];
                        }
                        groupWeightsToFetch[groupIndex].push({
                            manifestEntry: weightsEntry,
                            groupOffset,
                            sizeBytes: weightsBytes
                        });
                    };
                    if (weightNames != null) {
                        weightNames.forEach((weightName, weightIndex) => {
                            if (weightName === weightsEntry.name) {
                                enqueueWeightsForFetchingFn();
                                weightsFound[weightIndex] = true;
                            }
                        });
                    }
                    else {
                        // No name filter: fetch every weight.
                        enqueueWeightsForFetchingFn();
                    }
                    allManifestWeightNames.push(weightsEntry.name);
                    groupOffset += weightsBytes;
                });
            });
            if (!weightsFound.every(found => found)) {
                const weightsNotFound = weightNames.filter((_, i) => !weightsFound[i]);
                throw new Error(`Could not find weights in manifest with names: ` +
                    `${weightsNotFound.join(', ')}. \n` +
                    `Manifest JSON has weights with names: ` +
                    `${allManifestWeightNames.join(', ')}.`);
            }
            // Convert the one-hot boolean groupId => shouldFetch map to a list of group
            // IDs.
            const groupIndicesToFetch = groupIndicesToFetchMap.reduce((accumulator, shouldFetch, i) => {
                if (shouldFetch) {
                    accumulator.push(i);
                }
                return accumulator;
            }, []);
            const fetchUrls = [];
            groupIndicesToFetch.forEach(i => {
                manifest[i].paths.forEach(filepath => {
                    const fetchUrl = filePathPrefix +
                        (!filePathPrefix.endsWith('/') ? '/' : '') + filepath;
                    fetchUrls.push(fetchUrl);
                });
            });
            const buffers = await fetchWeightsFunction(fetchUrls);
            const weightsTensorMap = {};
            let bufferIndexOffset = 0;
            groupIndicesToFetch.forEach(i => {
                const numBuffers = manifest[i].paths.length;
                // All shards of a group are viewed as one contiguous buffer so
                // entries can be sliced by their in-group byte offsets.
                const weightsBuffer = new CompositeArrayBuffer(buffers.slice(bufferIndexOffset, bufferIndexOffset + numBuffers));
                const weightsEntries = groupWeightsToFetch[i];
                weightsEntries.forEach(weightsEntry => {
                    // Slice out this weight's bytes and decode them into
                    // named tensors.
                    const byteBuffer = weightsBuffer.slice(weightsEntry.groupOffset, weightsEntry.groupOffset + weightsEntry.sizeBytes);
                    const nameToTensorMap = decodeWeights(byteBuffer, [weightsEntry.manifestEntry]);
                    for (const name in nameToTensorMap) {
                        weightsTensorMap[name] = nameToTensorMap[name];
                    }
                });
                bufferIndexOffset += numBuffers;
            });
            return weightsTensorMap;
        };
    }
26083
26084 /**
26085 * @license
26086 * Copyright 2018 Google LLC. All Rights Reserved.
26087 * Licensed under the Apache License, Version 2.0 (the "License");
26088 * you may not use this file except in compliance with the License.
26089 * You may obtain a copy of the License at
26090 *
26091 * http://www.apache.org/licenses/LICENSE-2.0
26092 *
26093 * Unless required by applicable law or agreed to in writing, software
26094 * distributed under the License is distributed on an "AS IS" BASIS,
26095 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26096 * See the License for the specific language governing permissions and
26097 * limitations under the License.
26098 * =============================================================================
26099 */
    // MIME types used by `HTTPRequest.save()` when building the multipart
    // form body: the weights blob is sent as raw bytes, the topology as JSON.
    const OCTET_STREAM_MIME_TYPE = 'application/octet-stream';
    const JSON_TYPE = 'application/json';
    /**
     * IOHandler that saves model artifacts to and loads them from an HTTP
     * endpoint via `fetch`.
     *
     * `save()` uploads the model as `multipart/form-data` (a `model.json` blob
     * plus a `model.weights.bin` blob); `load()` fetches a model JSON file and
     * the binary weight shards its manifest references.
     */
    class HTTPRequest {
        /**
         * @param path URL of the model JSON, or a length-2 array
         *     `[topologyUrl, weightsUrl]`.
         * @param loadOptions Optional settings: `fetchFunc`, `weightPathPrefix`,
         *     `weightUrlConverter`, `requestInit`, `streamWeights`, etc.
         */
        constructor(path, loadOptions) {
            // HTTP method used by `save()` unless overridden via `requestInit`.
            this.DEFAULT_METHOD = 'POST';
            if (loadOptions == null) {
                loadOptions = {};
            }
            this.weightPathPrefix = loadOptions.weightPathPrefix;
            this.weightUrlConverter = loadOptions.weightUrlConverter;
            if (loadOptions.fetchFunc != null) {
                assert$1(typeof loadOptions.fetchFunc === 'function', () => 'Must pass a function that matches the signature of ' +
                    '`fetch` (see ' +
                    'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)');
                this.fetch = loadOptions.fetchFunc;
            }
            else {
                // Fall back to the fetch implementation registered on the
                // current platform.
                this.fetch = env().platform.fetch;
            }
            assert$1(path != null && path.length > 0, () => 'URL path for http must not be null, undefined or ' +
                'empty.');
            if (Array.isArray(path)) {
                assert$1(path.length === 2, () => 'URL paths for http must have a length of 2, ' +
                    `(actual length is ${path.length}).`);
            }
            this.path = path;
            // `save()` builds its own FormData body, so a caller-supplied body
            // would be silently discarded — reject it up front.
            if (loadOptions.requestInit != null &&
                loadOptions.requestInit.body != null) {
                throw new Error('requestInit is expected to have no pre-existing body, but has one.');
            }
            this.requestInit = loadOptions.requestInit || {};
            this.loadOptions = loadOptions;
        }
        /**
         * Upload model artifacts to the server as multipart/form-data.
         *
         * @param modelArtifacts Artifacts to save. ArrayBuffer (binary)
         *     topologies are not supported.
         * @returns A `SaveResult` carrying the artifacts info and the raw
         *     HTTP response.
         * @throws Error for binary topologies or non-OK HTTP status.
         */
        async save(modelArtifacts) {
            if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
                throw new Error('BrowserHTTPRequest.save() does not support saving model topology ' +
                    'in binary formats yet.');
            }
            const init = Object.assign({ method: this.DEFAULT_METHOD }, this.requestInit);
            init.body = new FormData();
            // Single-group manifest pointing at the one weights file appended
            // below.
            const weightsManifest = [{
                    paths: ['./model.weights.bin'],
                    weights: modelArtifacts.weightSpecs,
                }];
            const modelTopologyAndWeightManifest = getModelJSONForModelArtifacts(modelArtifacts, weightsManifest);
            init.body.append('model.json', new Blob([JSON.stringify(modelTopologyAndWeightManifest)], { type: JSON_TYPE }), 'model.json');
            if (modelArtifacts.weightData != null) {
                // TODO(mattsoulanille): Support saving models over 2GB that exceed
                // Chrome's ArrayBuffer size limit.
                const weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData);
                init.body.append('model.weights.bin', new Blob([weightBuffer], { type: OCTET_STREAM_MIME_TYPE }), 'model.weights.bin');
            }
            const response = await this.fetch(this.path, init);
            if (response.ok) {
                return {
                    modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts),
                    responses: [response],
                };
            }
            else {
                throw new Error(`BrowserHTTPRequest.save() failed due to HTTP response status ` +
                    `${response.status}.`);
            }
        }
        /**
         * Fetch and parse the model JSON from `this.path`.
         *
         * @returns The parsed model JSON object.
         * @throws Error on non-OK HTTP status, unparsable JSON, or when the
         *     JSON contains neither `modelTopology` nor `weightsManifest`.
         */
        async loadModelJSON() {
            const modelConfigRequest = await this.fetch(this.path, this.requestInit);
            if (!modelConfigRequest.ok) {
                throw new Error(`Request to ${this.path} failed with status code ` +
                    `${modelConfigRequest.status}. Please verify this URL points to ` +
                    `the model JSON of the model to load.`);
            }
            let modelJSON;
            try {
                modelJSON = await modelConfigRequest.json();
            }
            catch (e) {
                let message = `Failed to parse model JSON of response from ${this.path}.`;
                // TODO(nsthorat): Remove this after some time when we're comfortable that
                // .pb files are mostly gone.
                if (this.path.endsWith('.pb')) {
                    message += ' Your path contains a .pb file extension. ' +
                        'Support for .pb models have been removed in TensorFlow.js 1.0 ' +
                        'in favor of .json models. You can re-convert your Python ' +
                        'TensorFlow model using the TensorFlow.js 1.0 conversion scripts ' +
                        'or you can convert your.pb models with the \'pb2json\'' +
                        'NPM script in the tensorflow/tfjs-converter repository.';
                }
                else {
                    message += ' Please make sure the server is serving valid ' +
                        'JSON for this request.';
                }
                throw new Error(message);
            }
            // We do not allow both modelTopology and weightsManifest to be missing.
            const modelTopology = modelJSON.modelTopology;
            const weightsManifest = modelJSON.weightsManifest;
            if (modelTopology == null && weightsManifest == null) {
                throw new Error(`The JSON from HTTP path ${this.path} contains neither model ` +
                    `topology or manifest for weights.`);
            }
            return modelJSON;
        }
        /**
         * Load model artifacts via HTTP request(s).
         *
         * See the documentation to `tf.io.http` for details on the saved
         * artifacts.
         *
         * @returns The loaded model artifacts (if loading succeeds).
         */
        async load() {
            // Streaming mode defers weight downloads to a stream consumer.
            if (this.loadOptions.streamWeights) {
                return this.loadStream();
            }
            const modelJSON = await this.loadModelJSON();
            return getModelArtifactsForJSON(modelJSON, (weightsManifest) => this.loadWeights(weightsManifest));
        }
        /**
         * Like `load()`, but instead of eagerly downloading weight buffers,
         * returns artifacts with a `getWeightStream` factory so weights can be
         * consumed incrementally.
         */
        async loadStream() {
            const modelJSON = await this.loadModelJSON();
            const fetchURLs = await this.getWeightUrls(modelJSON.weightsManifest);
            const weightSpecs = getWeightSpecs(modelJSON.weightsManifest);
            const stream = () => streamWeights(fetchURLs, this.loadOptions);
            return Object.assign(Object.assign({}, modelJSON), { weightSpecs, getWeightStream: stream });
        }
        /**
         * Resolve the absolute URLs of every weight file listed in the
         * manifest, applying `weightPathPrefix` / `weightUrlConverter` when
         * configured.
         *
         * @param weightsManifest Manifest from the model JSON.
         * @returns Array of fully-resolved weight-file URLs.
         */
        async getWeightUrls(weightsManifest) {
            // When `path` is a [topologyUrl, weightsUrl] pair, the second entry
            // anchors the weight paths.
            const weightPath = Array.isArray(this.path) ? this.path[1] : this.path;
            const [prefix, suffix] = parseUrl(weightPath);
            const pathPrefix = this.weightPathPrefix || prefix;
            const fetchURLs = [];
            const urlPromises = [];
            for (const weightsGroup of weightsManifest) {
                for (const path of weightsGroup.paths) {
                    if (this.weightUrlConverter != null) {
                        urlPromises.push(this.weightUrlConverter(path));
                    }
                    else {
                        // Re-attach any query string (e.g. '?tfjs-format=file').
                        fetchURLs.push(pathPrefix + path + suffix);
                    }
                }
            }
            if (this.weightUrlConverter) {
                fetchURLs.push(...await Promise.all(urlPromises));
            }
            return fetchURLs;
        }
        /**
         * Download all weight files referenced by the manifest.
         *
         * @param weightsManifest Manifest from the model JSON.
         * @returns `[weightSpecs, buffers]` — specs describing the weights and
         *     the fetched binary buffers.
         */
        async loadWeights(weightsManifest) {
            const fetchURLs = await this.getWeightUrls(weightsManifest);
            const weightSpecs = getWeightSpecs(weightsManifest);
            const buffers = await loadWeightsAsArrayBuffer(fetchURLs, this.loadOptions);
            return [weightSpecs, buffers];
        }
    }
    // Matches both 'http://' and 'https://' URL prefixes (used by the router).
    HTTPRequest.URL_SCHEME_REGEX = /^https?:\/\//;
26253 /**
26254 * Extract the prefix and suffix of the url, where the prefix is the path before
26255 * the last file, and suffix is the search params after the last file.
26256 * ```
26257 * const url = 'http://tfhub.dev/model/1/tensorflowjs_model.pb?tfjs-format=file'
26258 * [prefix, suffix] = parseUrl(url)
26259 * // prefix = 'http://tfhub.dev/model/1/'
26260 * // suffix = '?tfjs-format=file'
26261 * ```
26262 * @param url the model url to be parsed.
26263 */
26264 function parseUrl(url) {
26265 const lastSlash = url.lastIndexOf('/');
26266 const lastSearchParam = url.lastIndexOf('?');
26267 const prefix = url.substring(0, lastSlash);
26268 const suffix = lastSearchParam > lastSlash ? url.substring(lastSearchParam) : '';
26269 return [prefix + '/', suffix];
26270 }
26271 function isHTTPScheme(url) {
26272 return url.match(HTTPRequest.URL_SCHEME_REGEX) != null;
26273 }
26274 const httpRouter = (url, loadOptions) => {
26275 if (typeof fetch === 'undefined' &&
26276 (loadOptions == null || loadOptions.fetchFunc == null)) {
26277 // `http` uses `fetch` or `node-fetch`, if one wants to use it in
26278 // an environment that is not the browser or node they have to setup a
26279 // global fetch polyfill.
26280 return null;
26281 }
26282 else {
26283 let isHTTP = true;
26284 if (Array.isArray(url)) {
26285 isHTTP = url.every(urlItem => isHTTPScheme(urlItem));
26286 }
26287 else {
26288 isHTTP = isHTTPScheme(url);
26289 }
26290 if (isHTTP) {
26291 return http(url, loadOptions);
26292 }
26293 }
26294 return null;
26295 };
    // Register the HTTP handler for both directions so URL-scheme dispatch
    // (e.g. `model.save('https://...')` / `tf.loadLayersModel('https://...')`)
    // finds it.
    IORouterRegistry.registerSaveRouter(httpRouter);
    IORouterRegistry.registerLoadRouter(httpRouter);
26298 /**
26299 * Creates an IOHandler subtype that sends model artifacts to HTTP server.
26300 *
26301 * An HTTP request of the `multipart/form-data` mime type will be sent to the
26302 * `path` URL. The form data includes artifacts that represent the topology
26303 * and/or weights of the model. In the case of Keras-style `tf.Model`, two
26304 * blobs (files) exist in form-data:
26305 * - A JSON file consisting of `modelTopology` and `weightsManifest`.
26306 * - A binary weights file consisting of the concatenated weight values.
26307 * These files are in the same format as the one generated by
26308 * [tfjs_converter](https://js.tensorflow.org/tutorials/import-keras.html).
26309 *
26310 * The following code snippet exemplifies the client-side code that uses this
26311 * function:
26312 *
26313 * ```js
26314 * const model = tf.sequential();
26315 * model.add(
26316 * tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'}));
26317 *
26318 * const saveResult = await model.save(tf.io.http(
26319 * 'http://model-server:5000/upload', {requestInit: {method: 'PUT'}}));
26320 * console.log(saveResult);
26321 * ```
26322 *
26323 * If the default `POST` method is to be used, without any custom parameters
26324 * such as headers, you can simply pass an HTTP or HTTPS URL to `model.save`:
26325 *
26326 * ```js
26327 * const saveResult = await model.save('http://model-server:5000/upload');
26328 * ```
26329 *
26330 * The following GitHub Gist
26331 * https://gist.github.com/dsmilkov/1b6046fd6132d7408d5257b0976f7864
26332 * implements a server based on [flask](https://github.com/pallets/flask) that
26333 * can receive the request. Upon receiving the model artifacts via the request,
26334 * this particular server reconstitutes instances of [Keras
26335 * Models](https://keras.io/models/model/) in memory.
26336 *
26337 *
26338 * @param path A URL path to the model.
26339 * Can be an absolute HTTP path (e.g.,
26340 * 'http://localhost:8000/model-upload)') or a relative path (e.g.,
26341 * './model-upload').
26342 * @param requestInit Request configurations to be used when sending
26343 * HTTP request to server using `fetch`. It can contain fields such as
26344 * `method`, `credentials`, `headers`, `mode`, etc. See
26345 * https://developer.mozilla.org/en-US/docs/Web/API/Request/Request
26346 * for more information. `requestInit` must not have a body, because the
26347 * body will be set by TensorFlow.js. File blobs representing the model
26348 * topology (filename: 'model.json') and the weights of the model (filename:
26349 * 'model.weights.bin') will be appended to the body. If `requestInit` has a
26350 * `body`, an Error will be thrown.
26351 * @param loadOptions Optional configuration for the loading. It includes the
26352 * following fields:
26353 * - weightPathPrefix Optional, this specifies the path prefix for weight
26354 * files, by default this is calculated from the path param.
26355 * - fetchFunc Optional, custom `fetch` function. E.g., in Node.js,
26356 * the `fetch` from node-fetch can be used here.
26357 * - onProgress Optional, progress callback function, fired periodically
26358 * before the load is completed.
26359 * @returns An instance of `IOHandler`.
26360 *
26361 * @doc {
26362 * heading: 'Models',
26363 * subheading: 'Loading',
26364 * namespace: 'io',
26365 * ignoreCI: true
26366 * }
26367 */
    function http(path, loadOptions) {
        // Factory: the returned handler implements both `save` and `load` over
        // HTTP (see the doc comment above for usage details).
        return new HTTPRequest(path, loadOptions);
    }
    /**
     * Deprecated. Use `tf.io.http`.
     *
     * @deprecated Thin backward-compatibility alias; forwards directly to
     *     `http()` with the same arguments.
     * @param path See `tf.io.http`.
     * @param loadOptions See `tf.io.http`.
     */
    function browserHTTPRequest(path, loadOptions) {
        return http(path, loadOptions);
    }
26379
26380 /**
26381 * @license
26382 * Copyright 2018 Google LLC. All Rights Reserved.
26383 * Licensed under the Apache License, Version 2.0 (the "License");
26384 * you may not use this file except in compliance with the License.
26385 * You may obtain a copy of the License at
26386 *
26387 * http://www.apache.org/licenses/LICENSE-2.0
26388 *
26389 * Unless required by applicable law or agreed to in writing, software
26390 * distributed under the License is distributed on an "AS IS" BASIS,
26391 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26392 * See the License for the specific language governing permissions and
26393 * limitations under the License.
26394 * =============================================================================
26395 */
    /**
     * Synchronous IOHandler whose `load()` simply returns the in-memory
     * artifacts it was constructed with, unchanged.
     */
    class PassthroughLoader {
        constructor(modelArtifacts) {
            this.modelArtifacts = modelArtifacts;
        }
        load() {
            return this.modelArtifacts;
        }
    }
    /**
     * IOHandler whose `save()` forwards the artifacts to a user-supplied
     * callback and returns the callback's result.
     */
    class PassthroughSaver {
        constructor(saveHandler) {
            this.saveHandler = saveHandler;
        }
        save(modelArtifacts) {
            return this.saveHandler(modelArtifacts);
        }
    }
    /**
     * Adapts a synchronous IOHandler to the async IOHandler interface by
     * wrapping whichever of `load`/`save` the handler defines in a resolved
     * Promise. Methods absent on the wrapped handler stay absent here.
     */
    class PassthroughAsync {
        constructor(handler) {
            if (handler.load) {
                this.load = () => Promise.resolve(handler.load());
            }
            if (handler.save) {
                this.save = (modelArtifacts) => Promise.resolve(handler.save(modelArtifacts));
            }
        }
    }
26422 /**
26423 * Creates an IOHandler that loads model artifacts from memory.
26424 *
26425 * When used in conjunction with `tf.loadLayersModel`, an instance of
26426 * `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts.
26427 *
26428 * ```js
26429 * const model = await tf.loadLayersModel(tf.io.fromMemory(
26430 * modelTopology, weightSpecs, weightData));
26431 * ```
26432 *
26433 * @param modelArtifacts a object containing model topology (i.e., parsed from
26434 * the JSON format).
26435 * @param weightSpecs An array of `WeightsManifestEntry` objects describing the
26436 * names, shapes, types, and quantization of the weight data. Optional.
26437 * @param weightData A single `ArrayBuffer` containing the weight data,
26438 * concatenated in the order described by the weightSpecs. Optional.
26439 * @param trainingConfig Model training configuration. Optional.
26440 *
26441 * @returns A passthrough `IOHandler` that simply loads the provided data.
26442 */
26443 function fromMemory(modelArtifacts, weightSpecs, weightData, trainingConfig) {
26444 const args = arguments;
26445 return new PassthroughAsync(fromMemorySync(...args));
26446 }
26447 /**
26448 * Creates an IOHandler that loads model artifacts from memory.
26449 *
26450 * When used in conjunction with `tf.loadLayersModel`, an instance of
26451 * `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts.
26452 *
26453 * ```js
26454 * const model = await tf.loadLayersModel(tf.io.fromMemory(
26455 * modelTopology, weightSpecs, weightData));
26456 * ```
26457 *
26458 * @param modelArtifacts a object containing model topology (i.e., parsed from
26459 * the JSON format).
26460 * @param weightSpecs An array of `WeightsManifestEntry` objects describing the
26461 * names, shapes, types, and quantization of the weight data. Optional.
26462 * @param weightData A single `ArrayBuffer` containing the weight data,
26463 * concatenated in the order described by the weightSpecs. Optional.
26464 * @param trainingConfig Model training configuration. Optional.
26465 *
26466 * @returns A passthrough `IOHandlerSync` that simply loads the provided data.
26467 */
26468 function fromMemorySync(modelArtifacts, weightSpecs, weightData, trainingConfig) {
26469 if (arguments.length === 1) {
26470 const isModelArtifacts = modelArtifacts.modelTopology != null ||
26471 modelArtifacts.weightSpecs != null;
26472 if (isModelArtifacts) {
26473 return new PassthroughLoader(modelArtifacts);
26474 }
26475 else {
26476 // Legacy support: with only modelTopology.
26477 // TODO(cais): Remove this deprecated API.
26478 console.warn('Please call tf.io.fromMemory() with only one argument. ' +
26479 'The argument should be of type ModelArtifacts. ' +
26480 'The multi-argument signature of tf.io.fromMemory() has been ' +
26481 'deprecated and will be removed in a future release.');
26482 return new PassthroughLoader({ modelTopology: modelArtifacts });
26483 }
26484 }
26485 else {
26486 // Legacy support.
26487 // TODO(cais): Remove this deprecated API.
26488 console.warn('Please call tf.io.fromMemory() with only one argument. ' +
26489 'The argument should be of type ModelArtifacts. ' +
26490 'The multi-argument signature of tf.io.fromMemory() has been ' +
26491 'deprecated and will be removed in a future release.');
26492 return new PassthroughLoader({
26493 modelTopology: modelArtifacts,
26494 weightSpecs,
26495 weightData,
26496 trainingConfig
26497 });
26498 }
26499 }
26500 /**
26501 * Creates an IOHandler that passes saved model artifacts to a callback.
26502 *
26503 * ```js
26504 * function handleSave(artifacts) {
26505 * // ... do something with the artifacts ...
26506 * return {modelArtifactsInfo: {...}, ...};
26507 * }
26508 *
26509 * const saveResult = model.save(tf.io.withSaveHandler(handleSave));
26510 * ```
26511 *
26512 * @param saveHandler A function that accepts a `ModelArtifacts` and returns a
26513 * promise that resolves to a `SaveResult`.
26514 */
    function withSaveHandler(saveHandler) {
        // The passthrough saver forwards `save()` straight to the callback.
        return new PassthroughSaver(saveHandler);
    }
26518 /**
26519 * Creates an IOHandlerSync that passes saved model artifacts to a callback.
26520 *
26521 * ```js
26522 * function handleSave(artifacts) {
26523 * // ... do something with the artifacts ...
26524 * return {modelArtifactsInfo: {...}, ...};
26525 * }
26526 *
26527 * const saveResult = model.save(tf.io.withSaveHandler(handleSave));
26528 * ```
26529 *
26530 * @param saveHandler A function that accepts a `ModelArtifacts` and returns a
26531 * `SaveResult`.
26532 */
    function withSaveHandlerSync(saveHandler) {
        // Identical wrapper to `withSaveHandler`; the sync/async distinction
        // lives only in the callback's declared return type.
        return new PassthroughSaver(saveHandler);
    }
26536
26537 /**
26538 * @license
26539 * Copyright 2018 Google LLC. All Rights Reserved.
26540 * Licensed under the Apache License, Version 2.0 (the "License");
26541 * you may not use this file except in compliance with the License.
26542 * You may obtain a copy of the License at
26543 *
26544 * http://www.apache.org/licenses/LICENSE-2.0
26545 *
26546 * Unless required by applicable law or agreed to in writing, software
26547 * distributed under the License is distributed on an "AS IS" BASIS,
26548 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26549 * See the License for the specific language governing permissions and
26550 * limitations under the License.
26551 * =============================================================================
26552 */
26553
    // Public `tf.io` namespace: frozen aggregate of the IO handlers, routers
    // and weight codecs defined above.
    var io = /*#__PURE__*/Object.freeze({
        __proto__: null,
        CompositeArrayBuffer: CompositeArrayBuffer,
        browserFiles: browserFiles,
        browserHTTPRequest: browserHTTPRequest,
        concatenateArrayBuffers: concatenateArrayBuffers,
        copyModel: copyModel,
        decodeWeights: decodeWeights,
        decodeWeightsStream: decodeWeightsStream,
        encodeWeights: encodeWeights,
        fromMemory: fromMemory,
        fromMemorySync: fromMemorySync,
        getLoadHandlers: getLoadHandlers,
        getModelArtifactsForJSON: getModelArtifactsForJSON,
        getModelArtifactsForJSONSync: getModelArtifactsForJSONSync,
        getModelArtifactsInfoForJSON: getModelArtifactsInfoForJSON,
        getSaveHandlers: getSaveHandlers,
        getWeightSpecs: getWeightSpecs,
        http: http,
        isHTTPScheme: isHTTPScheme,
        listModels: listModels,
        loadWeights: loadWeights,
        moveModel: moveModel,
        registerLoadRouter: registerLoadRouter,
        registerSaveRouter: registerSaveRouter,
        removeModel: removeModel,
        weightsLoaderFactory: weightsLoaderFactory,
        withSaveHandler: withSaveHandler,
        withSaveHandlerSync: withSaveHandlerSync
    });
26584
26585 /**
26586 * @license
26587 * Copyright 2018 Google LLC. All Rights Reserved.
26588 * Licensed under the Apache License, Version 2.0 (the "License");
26589 * you may not use this file except in compliance with the License.
26590 * You may obtain a copy of the License at
26591 *
26592 * http://www.apache.org/licenses/LICENSE-2.0
26593 *
26594 * Unless required by applicable law or agreed to in writing, software
26595 * distributed under the License is distributed on an "AS IS" BASIS,
26596 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26597 * See the License for the specific language governing permissions and
26598 * limitations under the License.
26599 * =============================================================================
26600 */
26601 /**
26602 * Computes the confusion matrix from true labels and predicted labels.
26603 *
26604 * ```js
26605 * const labels = tf.tensor1d([0, 1, 2, 1, 0], 'int32');
26606 * const predictions = tf.tensor1d([0, 2, 2, 1, 0], 'int32');
26607 * const numClasses = 3;
26608 * const out = tf.math.confusionMatrix(labels, predictions, numClasses);
26609 * out.print();
26610 * // Expected output matrix:
26611 * // [[2, 0, 0],
26612 * // [0, 1, 1],
26613 * // [0, 0, 1]]
26614 * ```
26615 *
26616 * @param labels The target labels, assumed to be 0-based integers
26617 * for the classes. The shape is `[numExamples]`, where
26618 * `numExamples` is the number of examples included.
26619 * @param predictions The predicted classes, assumed to be
26620 * 0-based integers for the classes. Must have the same shape as `labels`.
26621 * @param numClasses Number of all classes, as an integer.
26622 * Its value must be larger than the largest element in `labels` and
26623 * `predictions`.
26624 * @returns The confusion matrix as a int32-type 2D tensor. The value at
26625 * row `r` and column `c` is the number of times examples of actual class
26626 * `r` were predicted as class `c`.
26627 *
26628 * @doc {heading: 'Operations', subheading: 'Evaluation'}
26629 */
    function confusionMatrix_(labels, predictions, numClasses) {
        const $labels = convertToTensor(labels, 'labels', 'confusionMatrix');
        const $predictions = convertToTensor(predictions, 'predictions', 'confusionMatrix');
        // NOTE(review): this first assertion tolerates numClasses == null, yet
        // the later assertion below rejects null anyway — numClasses is
        // effectively required today (see the TODO about making it optional).
        assert$1(numClasses == null || numClasses > 0 && Number.isInteger(numClasses), () => `If provided, numClasses must be a positive integer, ` +
            `but got ${numClasses}`);
        assert$1($labels.rank === 1, () => `Expected the rank of labels to be 1, but got ${$labels.rank}`);
        assert$1($predictions.rank === 1, () => `Expected the rank of predictions to be 1, ` +
            `but got ${$predictions.rank}`);
        assert$1($labels.shape[0] === $predictions.shape[0], () => `Mismatch in the number of examples: ` +
            `${$labels.shape[0]} vs. ${$predictions.shape[0]}. ` +
            `Labels and predictions should have the same number of elements.`);
        assert$1(numClasses > 0 && Number.isInteger(numClasses), () => `numClasses is required to be a positive integer, but got ` +
            `${numClasses}`);
        // TODO(cais): In the future, if oneHot supports tensors inputs for
        // `numClasses`, `confusionMatrix` can make `numClasses` optional.
        // One-hot encode both int32 vectors to shape [numExamples, numClasses].
        const oneHotLabels = oneHot$3(cast$3($labels, 'int32'), numClasses);
        const oneHotPredictions = oneHot$3(cast$3($predictions, 'int32'), numClasses);
        // labelsᵀ · predictions accumulates the co-occurrence count for each
        // (actual class r, predicted class c) cell.
        const oneHotLabelsT = transpose$2(oneHotLabels);
        const product = matMul$1(oneHotLabelsT, oneHotPredictions);
        return cast$3(product, 'int32');
    }
    const confusionMatrix = /* @__PURE__ */ op({ confusionMatrix_ });
26652
26653 /**
26654 * @license
26655 * Copyright 2018 Google LLC. All Rights Reserved.
26656 * Licensed under the Apache License, Version 2.0 (the "License");
26657 * you may not use this file except in compliance with the License.
26658 * You may obtain a copy of the License at
26659 *
26660 * http://www.apache.org/licenses/LICENSE-2.0
26661 *
26662 * Unless required by applicable law or agreed to in writing, software
26663 * distributed under the License is distributed on an "AS IS" BASIS,
26664 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26665 * See the License for the specific language governing permissions and
26666 * limitations under the License.
26667 * =============================================================================
26668 */
26669
    // Public `tf.math` namespace.
    var math = /*#__PURE__*/Object.freeze({
        __proto__: null,
        confusionMatrix: confusionMatrix
    });
26674
26675 /**
26676 * @license
26677 * Copyright 2019 Google LLC. All Rights Reserved.
26678 * Licensed under the Apache License, Version 2.0 (the "License");
26679 * you may not use this file except in compliance with the License.
26680 * You may obtain a copy of the License at
26681 *
26682 * http://www.apache.org/licenses/LICENSE-2.0
26683 *
26684 * Unless required by applicable law or agreed to in writing, software
26685 * distributed under the License is distributed on an "AS IS" BASIS,
26686 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26687 * See the License for the specific language governing permissions and
26688 * limitations under the License.
26689 * =============================================================================
26690 */
    // Lazily-created, shared 2D canvas context used by `fromPixels_` to
    // rasterize image/video/bitmap inputs; reused across calls.
    let fromPixels2DContext$1;
    // Presumably guards a one-time `toPixels` warning defined later in this
    // module (not visible here) — TODO confirm.
    let hasToPixelsWarned = false;
26693 /**
26694 * Creates a `tf.Tensor` from an image.
26695 *
26696 * ```js
26697 * const image = new ImageData(1, 1);
26698 * image.data[0] = 100;
26699 * image.data[1] = 150;
26700 * image.data[2] = 200;
26701 * image.data[3] = 255;
26702 *
26703 * tf.browser.fromPixels(image).print();
26704 * ```
26705 *
26706 * @param pixels The input image to construct the tensor from. The
26707 * supported image types are all 4-channel. You can also pass in an image
26708 * object with following attributes:
26709 * `{data: Uint8Array; width: number; height: number}`
26710 * @param numChannels The number of channels of the output tensor. A
26711 * numChannels value less than 4 allows you to ignore channels. Defaults to
26712 * 3 (ignores alpha channel of input image).
26713 *
26714 * @returns A Tensor3D with the shape `[height, width, numChannels]`.
26715 *
26716 * Note: fromPixels can be lossy in some cases, same image may result in
26717 * slightly different tensor values, if rendered by different rendering
26718 * engines. This means that results from different browsers, or even same
26719 * browser with CPU and GPU rendering engines can be different. See discussion
26720 * in details:
26721 * https://github.com/tensorflow/tfjs/issues/5482
26722 *
26723 * @doc {heading: 'Browser', namespace: 'browser', ignoreCI: true}
26724 */
26725 function fromPixels_(pixels, numChannels = 3) {
26726 // Sanity checks.
26727 if (numChannels > 4) {
26728 throw new Error('Cannot construct Tensor with more than 4 channels from pixels.');
26729 }
26730 if (pixels == null) {
26731 throw new Error('pixels passed to tf.browser.fromPixels() can not be null');
26732 }
26733 let isPixelData = false;
26734 let isImageData = false;
26735 let isVideo = false;
26736 let isImage = false;
26737 let isCanvasLike = false;
26738 let isImageBitmap = false;
26739 if (pixels.data instanceof Uint8Array) {
26740 isPixelData = true;
26741 }
26742 else if (typeof (ImageData) !== 'undefined' && pixels instanceof ImageData) {
26743 isImageData = true;
26744 }
26745 else if (typeof (HTMLVideoElement) !== 'undefined' &&
26746 pixels instanceof HTMLVideoElement) {
26747 isVideo = true;
26748 }
26749 else if (typeof (HTMLImageElement) !== 'undefined' &&
26750 pixels instanceof HTMLImageElement) {
26751 isImage = true;
26752 // tslint:disable-next-line: no-any
26753 }
26754 else if (pixels.getContext != null) {
26755 isCanvasLike = true;
26756 }
26757 else if (typeof (ImageBitmap) !== 'undefined' && pixels instanceof ImageBitmap) {
26758 isImageBitmap = true;
26759 }
26760 else {
26761 throw new Error('pixels passed to tf.browser.fromPixels() must be either an ' +
26762 `HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData ` +
26763 `in browser, or OffscreenCanvas, ImageData in webworker` +
26764 ` or {data: Uint32Array, width: number, height: number}, ` +
26765 `but was ${pixels.constructor.name}`);
26766 }
26767 // If the current backend has 'FromPixels' registered, it has a more
26768 // efficient way of handling pixel uploads, so we call that.
26769 const kernel = getKernel(FromPixels, ENGINE.backendName);
26770 if (kernel != null) {
26771 const inputs = { pixels };
26772 const attrs = { numChannels };
26773 return ENGINE.runKernel(FromPixels, inputs, attrs);
26774 }
26775 const [width, height] = isVideo ?
26776 [
26777 pixels.videoWidth,
26778 pixels.videoHeight
26779 ] :
26780 [pixels.width, pixels.height];
26781 let vals;
26782 if (isCanvasLike) {
26783 vals =
26784 // tslint:disable-next-line:no-any
26785 pixels.getContext('2d').getImageData(0, 0, width, height).data;
26786 }
26787 else if (isImageData || isPixelData) {
26788 vals = pixels.data;
26789 }
26790 else if (isImage || isVideo || isImageBitmap) {
26791 if (fromPixels2DContext$1 == null) {
26792 if (typeof document === 'undefined') {
26793 if (typeof OffscreenCanvas !== 'undefined' &&
26794 typeof OffscreenCanvasRenderingContext2D !== 'undefined') {
26795 // @ts-ignore
26796 fromPixels2DContext$1 = new OffscreenCanvas(1, 1).getContext('2d');
26797 }
26798 else {
26799 throw new Error('Cannot parse input in current context. ' +
26800 'Reason: OffscreenCanvas Context2D rendering is not supported.');
26801 }
26802 }
26803 else {
26804 fromPixels2DContext$1 = document.createElement('canvas').getContext('2d', { willReadFrequently: true });
26805 }
26806 }
26807 fromPixels2DContext$1.canvas.width = width;
26808 fromPixels2DContext$1.canvas.height = height;
26809 fromPixels2DContext$1.drawImage(pixels, 0, 0, width, height);
26810 vals = fromPixels2DContext$1.getImageData(0, 0, width, height).data;
26811 }
26812 let values;
26813 if (numChannels === 4) {
26814 values = new Int32Array(vals);
26815 }
26816 else {
26817 const numPixels = width * height;
26818 values = new Int32Array(numPixels * numChannels);
26819 for (let i = 0; i < numPixels; i++) {
26820 for (let channel = 0; channel < numChannels; ++channel) {
26821 values[i * numChannels + channel] = vals[i * 4 + channel];
26822 }
26823 }
26824 }
26825 const outShape = [height, width, numChannels];
26826 return tensor3d(values, outShape, 'int32');
26827 }
26828 // Helper functions for |fromPixelsAsync| to check whether the input can
26829 // be wrapped into imageBitmap.
function isPixelData(pixels) {
    // True when |pixels| looks like a PixelData object: a non-null value
    // whose `data` field is a Uint8Array.
    if (pixels == null) {
        return false;
    }
    return pixels.data instanceof Uint8Array;
}
function isImageBitmapFullySupported() {
    // Full ImageBitmap support requires a browser `window`, the ImageBitmap
    // constructor and a `createImageBitmap` factory on the window object.
    const hasWindow = typeof window !== 'undefined';
    const hasImageBitmap = typeof (ImageBitmap) !== 'undefined';
    return hasWindow && hasImageBitmap &&
        window.hasOwnProperty('createImageBitmap');
}
function isNonEmptyPixels(pixels) {
    // Guards against null/undefined inputs and zero-sized width or height.
    if (pixels == null) {
        return false;
    }
    return pixels.width !== 0 && pixels.height !== 0;
}
function canWrapPixelsToImageBitmap(pixels) {
    // Wrapping is worthwhile only when ImageBitmap is fully supported, the
    // input is not already an ImageBitmap, has non-zero dimensions, and is
    // not a raw PixelData object. The support check runs first so that
    // `ImageBitmap` is never referenced in environments where it is missing.
    if (!isImageBitmapFullySupported()) {
        return false;
    }
    if (pixels instanceof ImageBitmap) {
        return false;
    }
    return isNonEmptyPixels(pixels) && !isPixelData(pixels);
}
26845 /**
26846 * Creates a `tf.Tensor` from an image in async way.
26847 *
26848 * ```js
26849 * const image = new ImageData(1, 1);
26850 * image.data[0] = 100;
26851 * image.data[1] = 150;
26852 * image.data[2] = 200;
26853 * image.data[3] = 255;
26854 *
26855 * (await tf.browser.fromPixelsAsync(image)).print();
26856 * ```
26857 * This API is the async version of fromPixels. The API will first
26858 * check |WRAP_TO_IMAGEBITMAP| flag, and try to wrap the input to
26859 * imageBitmap if the flag is set to true.
26860 *
26861 * @param pixels The input image to construct the tensor from. The
26862 * supported image types are all 4-channel. You can also pass in an image
26863 * object with following attributes:
26864 * `{data: Uint8Array; width: number; height: number}`
26865 * @param numChannels The number of channels of the output tensor. A
26866 * numChannels value less than 4 allows you to ignore channels. Defaults to
26867 * 3 (ignores alpha channel of input image).
26868 *
26869 * @doc {heading: 'Browser', namespace: 'browser', ignoreCI: true}
26870 */
async function fromPixelsAsync(pixels, numChannels = 3) {
    // When the WRAP_TO_IMAGEBITMAP flag is set and the input qualifies, try
    // converting |pixels| to an ImageBitmap before handing it to the
    // synchronous fromPixels implementation.
    const shouldWrap = env().getBool('WRAP_TO_IMAGEBITMAP') &&
        canWrapPixelsToImageBitmap(pixels);
    if (!shouldWrap) {
        return fromPixels_(pixels, numChannels);
    }
    let imageBitmap;
    try {
        // createImageBitmap may misbehave in some browsers (see
        // https://bugzilla.mozilla.org/show_bug.cgi?id=1335594), so failures
        // fall back to the raw input. Premultiplied alpha is disabled so
        // pixel values are read back unchanged.
        // tslint:disable-next-line: no-any
        imageBitmap = await createImageBitmap(pixels, { premultiplyAlpha: 'none' });
    }
    catch (e) {
        imageBitmap = null;
    }
    // createImageBitmap clips to the content size. In some cases the input
    // declares a larger size than its content, e.g. new Image(10, 10) with
    // 1 x 1 content yields a 1 x 1 bitmap, which would be incorrect. Only use
    // the bitmap when its dimensions match the declared input size.
    let inputs = null;
    if (imageBitmap != null && imageBitmap.width === pixels.width &&
        imageBitmap.height === pixels.height) {
        inputs = imageBitmap;
    }
    else {
        inputs = pixels;
    }
    return fromPixels_(inputs, numChannels);
}
function validateImgTensor(img) {
    // toPixels/draw accept rank-2 (grayscale) or rank-3 images whose depth
    // is 1, 3 or 4, backed by float32 or int32 data.
    const rank = img.rank;
    if (rank !== 2 && rank !== 3) {
        throw new Error(`toPixels only supports rank 2 or 3 tensors, got rank ${rank}.`);
    }
    const depth = rank === 2 ? 1 : img.shape[2];
    if (depth === 2 || depth > 4) {
        throw new Error(`toPixels only supports depth of size ` +
            `1, 3 or 4 but got ${depth}`);
    }
    const dtype = img.dtype;
    if (dtype !== 'float32' && dtype !== 'int32') {
        throw new Error(`Unsupported type for toPixels: ${dtype}.` +
            ` Please use float32 or int32 tensors.`);
    }
}
/**
 * Validates the `imageOptions` passed to draw/toPixels: the optional
 * alpha value must lie in the range [0, 1].
 *
 * @param imageOptions Optional `{alpha?: number}` config; may be undefined.
 * @throws Error when alpha is outside [0, 1].
 */
function validateImageOptions(imageOptions) {
    // Default alpha is 1 (opaque). Note a falsy alpha (e.g. 0) also falls
    // back to 1 here; since 0 is within [0, 1] this does not change which
    // inputs throw.
    const alpha = (imageOptions === null || imageOptions === void 0 ? void 0 : imageOptions.alpha) || 1;
    if (alpha > 1 || alpha < 0) {
        // Fixed typo in the user-facing message: "suppoed" -> "supposed".
        throw new Error(`Alpha value ${alpha} is supposed to be in range [0 - 1].`);
    }
}
26929 /**
26930 * Draws a `tf.Tensor` of pixel values to a byte array or optionally a
26931 * canvas.
26932 *
26933 * When the dtype of the input is 'float32', we assume values in the range
26934 * [0-1]. Otherwise, when input is 'int32', we assume values in the range
26935 * [0-255].
26936 *
26937 * Returns a promise that resolves when the canvas has been drawn to.
26938 *
26939 * @param img A rank-2 tensor with shape `[height, width]`, or a rank-3 tensor
26940 * of shape `[height, width, numChannels]`. If rank-2, draws grayscale. If
26941 * rank-3, must have depth of 1, 3 or 4. When depth of 1, draws
26942 * grayscale. When depth of 3, we draw with the first three components of
26943 * the depth dimension corresponding to r, g, b and alpha = 1. When depth of
26944 * 4, all four components of the depth dimension correspond to r, g, b, a.
26945 * @param canvas The canvas to draw to.
26946 *
26947 * @doc {heading: 'Browser', namespace: 'browser'}
26948 */
async function toPixels(img, canvas) {
    // Normalize the input (tensor or native array) into a tensor we own.
    let $img = convertToTensor(img, 'img', 'toPixels');
    if (!(img instanceof Tensor)) {
        // Assume int32 if user passed a native array.
        const originalImgTensor = $img;
        $img = cast$3(originalImgTensor, 'int32');
        originalImgTensor.dispose();
    }
    validateImgTensor($img);
    const [height, width] = $img.shape.slice(0, 2);
    const depth = $img.rank === 2 ? 1 : $img.shape[2];
    const data = await $img.data();
    // float32 values are expected in [0, 1] and scaled up to [0, 255];
    // int32 values are used as-is.
    const multiplier = $img.dtype === 'float32' ? 255 : 1;
    // Output is RGBA: 4 bytes per pixel regardless of input depth.
    const bytes = new Uint8ClampedArray(width * height * 4);
    for (let i = 0; i < height * width; ++i) {
        // Default to opaque black; inputs with depth < 4 leave alpha at 255.
        const rgba = [0, 0, 0, 255];
        for (let d = 0; d < depth; d++) {
            const value = data[i * depth + d];
            // Range-check each value according to the tensor dtype.
            if ($img.dtype === 'float32') {
                if (value < 0 || value > 1) {
                    throw new Error(`Tensor values for a float32 Tensor must be in the ` +
                        `range [0 - 1] but encountered ${value}.`);
                }
            }
            else if ($img.dtype === 'int32') {
                if (value < 0 || value > 255) {
                    throw new Error(`Tensor values for a int32 Tensor must be in the ` +
                        `range [0 - 255] but encountered ${value}.`);
                }
            }
            if (depth === 1) {
                // Grayscale: replicate the single channel into r, g and b.
                rgba[0] = value * multiplier;
                rgba[1] = value * multiplier;
                rgba[2] = value * multiplier;
            }
            else {
                rgba[d] = value * multiplier;
            }
        }
        const j = i * 4;
        bytes[j + 0] = Math.round(rgba[0]);
        bytes[j + 1] = Math.round(rgba[1]);
        bytes[j + 2] = Math.round(rgba[2]);
        bytes[j + 3] = Math.round(rgba[3]);
    }
    if (canvas != null) {
        // Warn once when the backend has a dedicated Draw kernel, since
        // tf.browser.draw avoids the byte copy through ImageData.
        if (!hasToPixelsWarned) {
            const kernel = getKernel(Draw, ENGINE.backendName);
            if (kernel != null) {
                console.warn('tf.browser.toPixels is not efficient to draw tensor on canvas. ' +
                    'Please try tf.browser.draw instead.');
                hasToPixelsWarned = true;
            }
        }
        canvas.width = width;
        canvas.height = height;
        const ctx = canvas.getContext('2d');
        const imageData = new ImageData(bytes, width, height);
        ctx.putImageData(imageData, 0, 0);
    }
    // Dispose the intermediate tensor if one was created above.
    if ($img !== img) {
        $img.dispose();
    }
    return bytes;
}
27014 /**
27015 * Draws a `tf.Tensor` to a canvas.
27016 *
27017 * When the dtype of the input is 'float32', we assume values in the range
27018 * [0-1]. Otherwise, when input is 'int32', we assume values in the range
27019 * [0-255].
27020 *
27021 * @param image The tensor to draw on the canvas. Must match one of
27022 * these shapes:
27023 * - Rank-2 with shape `[height, width`]: Drawn as grayscale.
27024 * - Rank-3 with shape `[height, width, 1]`: Drawn as grayscale.
27025 * - Rank-3 with shape `[height, width, 3]`: Drawn as RGB with alpha set in
27026 * `imageOptions` (defaults to 1, which is opaque).
27027 * - Rank-3 with shape `[height, width, 4]`: Drawn as RGBA.
27028 * @param canvas The canvas to draw to.
27029 * @param options The configuration arguments for image to be drawn and the
27030 * canvas to draw to.
27031 *
27032 * @doc {heading: 'Browser', namespace: 'browser'}
27033 */
function draw$1(image, canvas, options) {
    // Normalize the input (tensor or native array) into a tensor we own.
    let $img = convertToTensor(image, 'img', 'draw');
    if (!(image instanceof Tensor)) {
        // A native array was passed; treat its values as int32.
        const rawTensor = $img;
        $img = cast$3(rawTensor, 'int32');
        rawTensor.dispose();
    }
    validateImgTensor($img);
    const imageOptions = options == null ? void 0 : options.imageOptions;
    validateImageOptions(imageOptions);
    // Delegate the actual rendering to the backend's Draw kernel.
    const inputs = { image: $img };
    const attrs = { canvas, options };
    ENGINE.runKernel(Draw, inputs, attrs);
}
// `fromPixels` wrapped in op() for engine bookkeeping/kernel dispatch.
const fromPixels$1 = /* @__PURE__ */ op({ fromPixels_ });

// Frozen `tf.browser` namespace: pixel <-> tensor conversion utilities.
var browser = /*#__PURE__*/Object.freeze({
    __proto__: null,
    draw: draw$1,
    fromPixels: fromPixels$1,
    fromPixelsAsync: fromPixelsAsync,
    toPixels: toPixels
});
27057
27058 /**
27059 * Validate gather nd inputs.
27060 *
27061 * @param tensor The tensor contains the source values.
27062 * @param indices The tensor contains the indices to slice the source.
27063 *
27064 * @returns [resultShape, numUpdates, sliceSize, strides]
27065 */
function prepareAndValidate(tensor, indices) {
    // Validates the inputs of gatherND and derives its bookkeeping values.
    // Returns [resultShape, numUpdates, sliceSize, strides].
    const tensorRank = tensor.shape.length;
    const indicesRank = indices.shape.length;
    if (tensorRank < 1) {
        throw new Error(`tf.gatherND() expects the input to be rank 1 or higher, but the rank was ${tensorRank}.`);
    }
    if (indicesRank < 1) {
        throw new Error(`tf.gatherND() expects the indices to be rank 1 or higher, but the rank was ${indicesRank}.`);
    }
    if (indices.dtype !== 'int32') {
        throw new Error(`tf.gatherND() expects the indices to be int32 type, but the dtype was ${indices.dtype}.`);
    }
    // The innermost indices dimension selects coordinates into the tensor.
    const sliceRank = indices.shape[indicesRank - 1];
    if (sliceRank > tensorRank) {
        throw new Error(`index innermost dimension length must be <= tensor rank; saw: ${sliceRank} vs. ${tensorRank}`);
    }
    if (sizeFromShape(tensor.shape) === 0) {
        throw new Error(`Requested more than 0 entries, but input is empty. Input shape: ${tensor.shape}.`);
    }
    // Number of gathered slices = product of all but the innermost
    // indices dimension.
    const outerDims = indices.shape.slice(0, -1);
    const nResult = outerDims.reduce((acc, dim) => acc * dim, 1);
    // Result shape is indices.shape[:-1] + tensor.shape[sliceRank:].
    const resultShape = outerDims.slice();
    let sliceSize = 1;
    for (let dim = sliceRank; dim < tensorRank; ++dim) {
        sliceSize *= tensor.shape[dim];
        resultShape.push(tensor.shape[dim]);
    }
    // Strides in units of whole slices, truncated to the slice rank.
    const strides = [
        ...computeStrides(tensor.shape).map(stride => stride / sliceSize), 1
    ].slice(0, sliceRank);
    return [resultShape, nResult, sliceSize, strides];
}
27109
// Frozen re-export of the gatherND helpers (mirrors the `gather_nd_util`
// module of the unbundled source).
var gather_nd_util = /*#__PURE__*/Object.freeze({
    __proto__: null,
    prepareAndValidate: prepareAndValidate
});
27114
27115 /**
27116 * @license
27117 * Copyright 2021 Google LLC. All Rights Reserved.
27118 * Licensed under the Apache License, Version 2.0 (the "License");
27119 * you may not use this file except in compliance with the License.
27120 * You may obtain a copy of the License at
27121 *
27122 * http://www.apache.org/licenses/LICENSE-2.0
27123 *
27124 * Unless required by applicable law or agreed to in writing, software
27125 * distributed under the License is distributed on an "AS IS" BASIS,
27126 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27127 * See the License for the specific language governing permissions and
27128 * limitations under the License.
27129 * =============================================================================
27130 */
// Sentinel values stored in finalShapeGatherIndices to mark axes that a
// strided-slice spec inserts (newaxis) or removes (shrink), respectively.
const NEW_AXIS = -2;
const SHRINK_AXIS = -1;
function assertParamsValid(input, begin, size) {
    // Checks that |begin| and |size| each have one entry per input axis and
    // that every requested slice interval stays inside the input bounds.
    const inputRank = input.shape.length;
    assert$1(inputRank === begin.length, () => `Error in slice${inputRank}D: Length of begin ${begin} must ` +
        `match the rank of the array (${inputRank}).`);
    assert$1(inputRank === size.length, () => `Error in slice${inputRank}D: Length of size ${size} must ` +
        `match the rank of the array (${inputRank}).`);
    input.shape.forEach((dim, i) => {
        assert$1(begin[i] + size[i] <= dim, () => `Error in slice${inputRank}D: begin[${i}] + size[${i}] ` +
            `(${begin[i] + size[i]}) would overflow input.shape[${i}] (${dim})`);
    });
}
/** Converts a binary mask to an array of axes. Used in stridedSlice(). */
function maskToAxes(mask) {
    const axes = [];
    let axis = 0;
    while (mask > 0) {
        if (mask & 1) {
            axes.push(axis);
        }
        // Integer halving. The previous `mask /= 2` produced fractional
        // values for non-power-of-two masks, so the loop only terminated once
        // the float underflowed to 0 (~1000 useless iterations); flooring
        // keeps the results identical while terminating after the top bit.
        mask = Math.floor(mask / 2);
        axis++;
    }
    return axes;
}
/** Computes the output shape given the strided slice params. */
function computeOutShape$2(begin, end, strides) {
    // Each axis contributes ceil((end - begin) / stride) elements.
    return begin.map((beginValue, axis) => Math.ceil((end[axis] - beginValue) / strides[axis]));
}
// Creates full selection at the elided dimensions. If the dimension matches
// the ellipsis mask, override the current stride value. Otherwise, insert.
function stridesWithElidedDims(strides, ellipsisInsertionIndex, numElidedAxes, inputShape) {
    // Pad the strides out to the input rank with unit strides.
    const newStrides = strides.slice();
    while (newStrides.length < inputShape.length) {
        newStrides.push(1);
    }
    for (let i = 0; i < numElidedAxes; i++) {
        if (i === 0) {
            // The ellipsis position itself takes a unit stride.
            newStrides[ellipsisInsertionIndex] = 1;
        }
        else {
            // Insert a unit stride for each additional elided axis, dropping
            // the last entry to keep the overall length unchanged.
            newStrides.splice(ellipsisInsertionIndex, 0 /* num elements to delete */, 1 /* element to add */);
            newStrides.pop();
        }
    }
    return newStrides;
}
function unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, normalizedAxis) {
    // Axes at or before the ellipsis keep their position; later axes shift
    // back by the number of axes the ellipsis expanded into (minus itself).
    return normalizedAxis <= ellipsisInsertionIndex ?
        normalizedAxis :
        normalizedAxis - (numElidedAxes - 1);
}
function getElidedAxes(numElidedAxes, ellipsisInsertionIndex) {
    // Consecutive axis indices covered by the ellipsis expansion.
    return Array.from({ length: numElidedAxes }, (_, i) => ellipsisInsertionIndex + i);
}
// Normalize the start, end and strides.
function getNormalizedAxes(inputShape, ellipsisAxes, numInterpolatedAxes, begin, end, strides, beginMask, endMask, ellipsisMask) {
    const inputRank = inputShape.length;
    if (ellipsisAxes.length && numInterpolatedAxes > 0) {
        // An ellipsis is present: it elides its own masked axis plus every
        // interpolated dimension.
        const fullIndex = ellipsisAxes[0];
        const numElidedAxes = numInterpolatedAxes + 1;
        return {
            begin: startIndicesWithElidedDims(beginMask, fullIndex, numElidedAxes, begin, inputShape),
            end: stopIndicesWithElidedDims(endMask, fullIndex, numElidedAxes, end, inputShape),
            strides: stridesWithElidedDims(strides, fullIndex, numElidedAxes, inputShape)
        };
    }
    // No ellipsis expansion: normalize each axis independently.
    const normalizedBegin = new Array(inputRank);
    const normalizedEnd = new Array(inputRank);
    const normalizedStrides = new Array(inputRank);
    for (let axis = 0; axis < inputRank; axis++) {
        normalizedBegin[axis] = startForAxis(beginMask, begin, strides, inputShape, axis, ellipsisMask);
        normalizedEnd[axis] =
            stopForAxis(endMask, end, strides, inputShape, axis, ellipsisMask);
        normalizedStrides[axis] = stridesForAxis(strides, axis, ellipsisMask);
    }
    return {
        begin: normalizedBegin,
        end: normalizedEnd,
        strides: normalizedStrides
    };
}
// Creates full selection at the elided dimensions. If the dimension matches
// the ellipsis mask, override the current start value. Otherwise, insert.
function startIndicesWithElidedDims(beginMask, ellipsisInsertionIndex, numElidedAxes, originalBegin, inputShape) {
    const elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex);
    return inputShape.map((_, axis) => {
        if (elidedAxes.indexOf(axis) > -1) {
            // Elided dimensions select everything from the start.
            return 0;
        }
        const originalAxis = unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis);
        // A set beginMask bit means "start from the beginning" for that axis.
        if (beginMask & 1 << originalAxis) {
            return 0;
        }
        return originalBegin[originalAxis];
    });
}
// Creates full selection at the elided dimensions. If the dimension matches
// the ellipsis mask, override the current stop value. Otherwise, insert.
function stopIndicesWithElidedDims(endMask, ellipsisInsertionIndex, numElidedAxes, originalEnd, inputShape) {
    const elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex);
    const rawIndices = inputShape.map((_, axis) => {
        if (elidedAxes.indexOf(axis) > -1) {
            // Elided dimensions select everything up to the end.
            return Number.MAX_SAFE_INTEGER;
        }
        const originalAxis = unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis);
        // A set endMask bit means "run to the end" for that axis.
        if (endMask & 1 << originalAxis) {
            return Number.MAX_SAFE_INTEGER;
        }
        return originalEnd[originalAxis];
    });
    // Wrap negative indices around the axis size, then clamp every stop
    // index into [0, axisSize].
    return rawIndices.map((index, i) => {
        const wrapped = index < 0 ? index + inputShape[i] : index;
        return clamp(0, wrapped, inputShape[i]);
    });
}
function stridesForAxis(strides, axis, ellipsisMask) {
    // Ellipsis-masked axes and unspecified strides default to a unit stride.
    const stride = strides[axis];
    if (stride == null || (ellipsisMask & (1 << axis))) {
        return 1;
    }
    return stride;
}
function startForAxis(beginMask, startIndices, strides, inputShape, axis, ellipsisMask) {
    const stride = strides[axis] || 1;
    let start = startIndices[axis];
    // The axis bit being set in beginMask or ellipsisMask, or a missing begin
    // index, means "no explicit start" for this axis.
    const startUnspecified = !!(beginMask & 1 << axis) ||
        !!(ellipsisMask & 1 << axis) || start == null;
    if (startUnspecified) {
        // Take the extreme matching the iteration direction; the clamp below
        // folds it into the valid range (mirrors stopForAxis()).
        start = stride > 0 ? Number.MIN_SAFE_INTEGER : Number.MAX_SAFE_INTEGER;
    }
    // Wrap negative indices around the axis size.
    const axisSize = inputShape[axis];
    if (start < 0) {
        start += axisSize;
    }
    // Clamp into the valid index range [0, axisSize - 1].
    return clamp(0, start, axisSize - 1);
}
function stopForAxis(endMask, stopIndices, strides, inputShape, axis, ellipsisMask) {
    const stride = strides[axis] || 1;
    let stop = stopIndices[axis];
    // The axis bit being set in endMask or ellipsisMask, or a missing stop
    // index, means "no explicit stop" for this axis.
    const stopUnspecified = !!(endMask & (1 << axis)) ||
        !!(ellipsisMask & (1 << axis)) || stop == null;
    if (stopUnspecified) {
        // Take the extreme matching the iteration direction; the clamp below
        // folds it into the valid range.
        stop = stride > 0 ? Number.MAX_SAFE_INTEGER : Number.MIN_SAFE_INTEGER;
    }
    // Wrap negative indices around the axis size.
    const axisSize = inputShape[axis];
    if (stop < 0) {
        stop += axisSize;
    }
    // The stop index points one past the last element, so the clamp range
    // depends on the iteration direction.
    if (stride > 0) {
        // Forward iteration.
        return clamp(0, stop, axisSize);
    }
    // Backward iteration.
    return clamp(-1, stop, axisSize - 1);
}
/**
 * Returns true if the slice occupies a continous set of elements in the
 * 'flat' space.
 */
function isSliceContinous(shape, begin, size) {
    // Find the first axis that selects more than one element; every later
    // axis must be taken in full (from offset 0) for the slice to be
    // contiguous in flat memory.
    let firstNonOneAxis = size.findIndex(d => d > 1);
    if (firstNonOneAxis === -1) {
        firstNonOneAxis = size.length;
    }
    for (let i = firstNonOneAxis + 1; i < size.length; i++) {
        if (begin[i] > 0 || size[i] !== shape[i]) {
            return false;
        }
    }
    return true;
}
function computeFlatOffset(begin, strides) {
    // An empty begin contributes a base offset of 1 (preserves the
    // historical behavior of this helper).
    if (begin.length === 0) {
        return 1;
    }
    // The innermost axis has an implicit stride of 1; outer axes use
    // the provided strides.
    const outerOffset = begin.slice(0, -1).reduce(
        (acc, beginValue, i) => acc + beginValue * strides[i], 0);
    return outerOffset + begin[begin.length - 1];
}
function parseSliceParams(x, begin, size) {
    // Normalizes the ergonomic begin/size forms (scalar, short array, or
    // missing size) into full-rank arrays. Returns [begin_, size_].
    const xRank = x.shape.length;
    let begin_;
    if (typeof begin === 'number') {
        // Scalar begin applies to axis 0; remaining axes start at 0.
        begin_ = [begin, ...new Array(xRank - 1).fill(0)];
    }
    else if (begin.length < xRank) {
        // Short array: missing axes start at 0.
        begin_ = [...begin, ...new Array(xRank - begin.length).fill(0)];
    }
    else {
        begin_ = begin.slice();
    }
    begin_.forEach(d => {
        assert$1(d !== -1, () => 'slice() does not support negative begin indexing.');
    });
    let size_;
    if (size == null) {
        // Missing size means "to the end" on every axis.
        size_ = new Array(xRank).fill(-1);
    }
    else if (typeof size === 'number') {
        size_ = [size, ...new Array(xRank - 1).fill(-1)];
    }
    else if (size.length < xRank) {
        size_ = [...size, ...new Array(xRank - size.length).fill(-1)];
    }
    else {
        size_ = size;
    }
    // Resolve each -1 size to "everything from begin to the end of the axis".
    size_ = size_.map((d, i) => {
        if (d >= 0) {
            return d;
        }
        assert$1(d === -1, () => `Negative size values should be exactly -1 but got ` +
            `${d} for the slice() size at index ${i}.`);
        return x.shape[i] - begin_[i];
    });
    return [begin_, size_];
}
// Convert the slicing specification from a sparse representation to a dense
// representation. This means that all ellipses and newaxis are expanded out.
//
// Mirrors TensorFlow's StridedSlice shape inference. Returns the final
// (dense) shape, the sparse final shape (shrink axes removed, newaxis kept
// out), the normalized begin/end/strides, and flags describing whether the
// slice is an identity, only slices dim 0, or uses only unit strides.
function sliceInfo(xShape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) {
    // Missing strides default to 1 on every specified axis.
    let stridesNonNull;
    if (strides == null) {
        stridesNonNull = new Array(begin.length);
        stridesNonNull.fill(1);
    }
    else {
        stridesNonNull = strides;
    }
    // Only one non-zero bit is allowed in ellipsisMask, which means ellipsisMask
    // is a power of 2. Use bit compares to ensure ellipsisMask is 0 or a power
    // of 2. When i is a power of 2, i & (i - 1) is always 0.
    // Also ref:
    // https://stackoverflow.com/questions/600293/how-to-check-if-a-number-is-a-power-of-2
    if (ellipsisMask != null && (ellipsisMask & (ellipsisMask - 1)) !== 0) {
        throw new Error('Multiple ellipses in slice is not allowed.');
    }
    // Step 1: Account for ellipsis and new axis.
    // Check for ellipsis and count how many non-newaxis there are after.
    let ellipsisSeen = false;
    const sparseSpec = {
        dims: stridesNonNull.length,
        numAddAxisAfterEllipsis: 0,
        begin: begin.slice(),
        end: end.slice(),
        strides: stridesNonNull.slice(),
        beginMask,
        endMask,
        ellipsisMask,
        newAxisMask,
        shrinkAxisMask
    };
    for (let i = 0; i < sparseSpec.dims; i++) {
        if (ellipsisSeen && ((1 << i) & newAxisMask) !== 0) {
            sparseSpec.numAddAxisAfterEllipsis++;
        }
        if ((1 << i) & ellipsisMask) {
            ellipsisSeen = true;
        }
    }
    // If no ellipsis insert one at the end.
    if (!ellipsisSeen) {
        sparseSpec.ellipsisMask |= (1 << sparseSpec.dims);
        sparseSpec.dims++; // this affects loop iteration below
    }
    // Step 2: Make a sparse spec into a full index spec.
    //
    // The sparse spec does not correspond to the number of dimensions.
    // Make a dense spec that corresponds to the number of dimensions.
    //
    // For example suppose foo[...,3:] on foo.shape = [2, 2, 3] then we need to
    // produce the missing beginMask for the first two dimensions i.e. from
    // beginMaskSpec = 0, endMaskSpec = 2, we achieve beginMask = 6 (110),
    // endMask = 7 (111).
    const denseSpec = {
        dims: xShape.length,
        beginMask: 0,
        endMask: 0,
        beginValid: false,
        endValid: false
    };
    buildDenseSpec(sparseSpec, denseSpec);
    // Step 3: Make implicit ranges (non-zero beginMasks and endMasks) explicit
    // and bounds check.
    let isIdentity = true;
    let sliceDim0 = true;
    let isSimpleSlice = true;
    const processingShape = [];
    const finalShape = [];
    for (let i = 0; i < xShape.length; ++i) {
        if (denseSpec.strides[i] === 0) {
            throw Error(`strides[${i}] must be non-zero`);
        }
        const shrinkI = !!(denseSpec.shrinkAxisMask & (1 << i));
        const dimI = xShape[i];
        // Unknown dimension (-1): size stays unknown unless shrunk to 1.
        if (dimI === -1) {
            processingShape.push(shrinkI ? 1 : -1);
            continue;
        }
        const masks = [denseSpec.beginMask & (1 << i), denseSpec.endMask & (1 << i)];
        // Valid index range depends on iteration direction (the end index
        // points one past the last element taken).
        const validRange = [
            denseSpec.strides[i] > 0 ? 0 : -1,
            denseSpec.strides[i] > 0 ? dimI : dimI - 1
        ];
        if (shrinkI && denseSpec.strides[i] <= 0) {
            throw Error('only stride 1 allowed on non-range indexing.');
        }
        isSimpleSlice = isSimpleSlice && (denseSpec.strides[i] === 1);
        const beginAndEndMasked = !!((denseSpec.beginMask & (1 << i)) && (denseSpec.endMask & (1 << i)));
        if (denseSpec.beginValid && denseSpec.endValid) {
            if (shrinkI) {
                // If we are shrinking, the end index is now possibly incorrect. In
                // particular foo[-1] produces sparseBegin = -1, sparseEnd = 0.
                // and canonical puts these to n-1 and 0, which implies a degenerate
                // interval. Fortunately, it is now safe to re-create end as begin + 1.
                const xFwd = denseSpec.begin[i] < 0 ? dimI + denseSpec.begin[i] :
                    denseSpec.begin[i];
                denseSpec.begin[i] = xFwd;
                denseSpec.end[i] = denseSpec.begin[i] + 1;
                if (xFwd < 0 || xFwd >= dimI) {
                    throw Error(`slice index ${denseSpec.begin[i]} of dimension ${i} out of bounds.`);
                }
            }
            else {
                denseSpec.begin[i] = canonical(denseSpec.begin[i], 0, denseSpec.strides[i], dimI, masks, validRange);
                denseSpec.end[i] = canonical(denseSpec.end[i], 1, denseSpec.strides[i], dimI, masks, validRange);
            }
            // Update optimization values
            const takeAllInDimension = denseSpec.strides[i] === 1 &&
                denseSpec.begin[i] === 0 && denseSpec.end[i] === dimI;
            isIdentity = isIdentity && takeAllInDimension;
            sliceDim0 = sliceDim0 &&
                ((i === 0 && denseSpec.strides[i] === 1) || takeAllInDimension);
        }
        else {
            isIdentity =
                isIdentity && ((denseSpec.strides[i] === 1) && beginAndEndMasked);
            sliceDim0 = sliceDim0 &&
                ((i === 0 && denseSpec.strides[i] === 1) || beginAndEndMasked);
        }
        // Compute the processing shape (the intermediate Eigen will produce)
        let intervalLength;
        let knownInterval = false;
        if (denseSpec.beginValid && denseSpec.endValid) {
            intervalLength = denseSpec.end[i] - denseSpec.begin[i];
            knownInterval = true;
        }
        else if (shrinkI) {
            // The dimension is still known as 1 for the processingShape, but will be
            // discarded for the final shape.
            intervalLength = 1;
            knownInterval = true;
        }
        else if (beginAndEndMasked) {
            // Even if we don't have values for begin or end, we do know that this
            // dimension covers the whole interval. If we have shape information for
            // this dimension, that tells us the interval length.
            if (dimI >= 0) {
                if (denseSpec.strides[i] < 0) {
                    intervalLength = -dimI;
                }
                else {
                    intervalLength = dimI;
                }
                knownInterval = true;
            }
        }
        if (knownInterval) {
            let sizeI;
            // Hold zero if the interval is degenerate, otherwise account for
            // remainder
            if (intervalLength === 0 ||
                ((intervalLength < 0) !== (denseSpec.strides[i] < 0))) {
                sizeI = 0;
            }
            else {
                sizeI = Math.trunc(intervalLength / denseSpec.strides[i]) +
                    (intervalLength % denseSpec.strides[i] !== 0 ? 1 : 0);
            }
            processingShape.push(sizeI);
        }
        else {
            processingShape.push(-1);
        }
    }
    // Step 4: Compute the final shape
    //
    // newAxis will increase dimension by 1 (with a one-size dimension)
    // slices like foo[3, ...] will reduce dimension by 1.
    // This cannot be done earlier, because it depends on Step 3.
    for (let denseDim = 0; denseDim < denseSpec.finalShapeGatherIndices.length; ++denseDim) {
        const gatherIndex = denseSpec.finalShapeGatherIndices[denseDim];
        if (gatherIndex >= 0) {
            finalShape.push(processingShape[gatherIndex]);
        }
        else if (gatherIndex === NEW_AXIS) {
            finalShape.push(1);
        }
    }
    // The sparse final shape drops the synthetic newaxis dimensions.
    const finalShapeSparse = finalShape.filter((dim, i) => denseSpec.finalShapeGatherIndices[i] !== NEW_AXIS);
    return {
        finalShapeSparse,
        finalShape,
        isIdentity,
        sliceDim0,
        isSimpleSlice,
        begin: denseSpec.begin,
        end: denseSpec.end,
        strides: denseSpec.strides
    };
}
// Expands a sparse strided-slice spec (which may contain one ellipsis and
// newaxis entries) into a dense spec with exactly `dense.dims` entries.
// Mutates `dense` in place: fills begin/end/strides, the begin/end/shrink
// masks, and the gather-index bookkeeping later used to build the final
// shape in sliceInfo().
function buildDenseSpec(sparse, dense) {
    dense.beginMask = 0;
    dense.endMask = 0;
    dense.shrinkAxisMask = 0;
    // Next dense axis to be filled as the sparse axes are walked.
    let fullIndex = 0;
    dense.beginValid = sparse.begin != null;
    dense.endValid = sparse.end != null;
    dense.begin = new Array(dense.dims);
    dense.end = new Array(dense.dims);
    dense.strides = new Array(dense.dims);
    dense.finalShapeGatherIndices = [];
    dense.finalShapeGatherIndicesSparse = [];
    dense.inputShapeGatherIndicesSparse = new Array(dense.dims);
    for (let i = 0; i < sparse.dims; i++) {
        if ((1 << i) & sparse.ellipsisMask) {
            // Only the bit that has ellipsis will fall in this condition.
            // Expand the ellipsis into the appropriate indices
            // Note: this only works because we guaranteed one ellipsis.
            const nextIndex = Math.min(dense.dims - (sparse.dims - i) + 1 + sparse.numAddAxisAfterEllipsis, dense.dims);
            for (; fullIndex < nextIndex; fullIndex++) {
                // newAxis aren't real axis so you have to skip.
                // Elided axes take the whole dimension: begin/end are masked
                // out and stride is 1.
                dense.begin[fullIndex] = 0;
                dense.end[fullIndex] = 0;
                dense.strides[fullIndex] = 1;
                dense.beginMask |= (1 << fullIndex);
                dense.endMask |= (1 << fullIndex);
                dense.finalShapeGatherIndices.push(fullIndex);
                dense.finalShapeGatherIndicesSparse.push(-1);
                dense.inputShapeGatherIndicesSparse[fullIndex] = i;
            }
        }
        else if ((1 << i) & sparse.newAxisMask) {
            // Only the bit that has newAxis will fall in this condition.
            dense.finalShapeGatherIndices.push(NEW_AXIS);
            dense.finalShapeGatherIndicesSparse.push(-1);
        }
        else {
            if (fullIndex === dense.begin.length) {
                throw Error(`Index out of range using input dim ${fullIndex}; input ` +
                    `has only ${dense.dims} dims, ${dense.begin.length}.`);
            }
            // Gather slicing spec into appropriate index.
            if (sparse.begin != null) {
                dense.begin[fullIndex] = sparse.begin[i];
            }
            if (sparse.end != null) {
                dense.end[fullIndex] = sparse.end[i];
            }
            dense.strides[fullIndex] = sparse.strides[i];
            if (sparse.beginMask & (1 << i)) {
                dense.beginMask |= (1 << fullIndex);
            }
            if (sparse.endMask & (1 << i)) {
                dense.endMask |= (1 << fullIndex);
            }
            // If shrink, record where to get the dimensionality from (i.e. newAxis)
            // creates a fake 1 size dimension. Also remember shrink axis (now in
            // dense form) so we can ignore dense.end below.
            if (sparse.shrinkAxisMask & (1 << i)) {
                dense.finalShapeGatherIndices.push(SHRINK_AXIS);
                dense.finalShapeGatherIndicesSparse.push(-1);
                dense.shrinkAxisMask |= (1 << fullIndex);
            }
            else {
                dense.finalShapeGatherIndices.push(fullIndex);
                // Remember that where in the sparse shape the dense dim comes from.
                dense.finalShapeGatherIndicesSparse.push(i);
            }
            dense.inputShapeGatherIndicesSparse[fullIndex] = i;
            fullIndex++;
        }
    }
}
27675 function canonical(x, c, strideI, dimI, masks, validRange) {
27676 if (masks[c]) {
27677 return strideI > 0 ? validRange[c] : validRange[(c + 1) & 1];
27678 }
27679 else {
27680 const xFwd = x < 0 ? dimI + x : x; // make negative indices positive
27681 return xFwd < validRange[0] ? validRange[0] :
27682 xFwd > validRange[1] ? validRange[1] : xFwd;
27683 }
27684 }
27685
    // Bundler-generated namespace object grouping the slice utility helpers
    // defined in this module; frozen so consumers cannot mutate the shared
    // namespace.
    var slice_util = /*#__PURE__*/Object.freeze({
        __proto__: null,
        assertParamsValid: assertParamsValid,
        computeFlatOffset: computeFlatOffset,
        computeOutShape: computeOutShape$2,
        getNormalizedAxes: getNormalizedAxes,
        isSliceContinous: isSliceContinous,
        maskToAxes: maskToAxes,
        parseSliceParams: parseSliceParams,
        sliceInfo: sliceInfo,
        startForAxis: startForAxis,
        startIndicesWithElidedDims: startIndicesWithElidedDims,
        stopForAxis: stopForAxis,
        stopIndicesWithElidedDims: stopIndicesWithElidedDims,
        stridesForAxis: stridesForAxis,
        stridesWithElidedDims: stridesWithElidedDims
    });
27703
27704 /** @license See the LICENSE file. */
27705 // This code is auto-generated, do not modify this file!
    const version$7 = '4.22.0'; // Package version string (auto-generated; do not edit).
27707
27708 /**
27709 * @license
27710 * Copyright 2018 Google LLC. All Rights Reserved.
27711 * Licensed under the Apache License, Version 2.0 (the "License");
27712 * you may not use this file except in compliance with the License.
27713 * You may obtain a copy of the License at
27714 *
27715 * http://www.apache.org/licenses/LICENSE-2.0
27716 *
27717 * Unless required by applicable law or agreed to in writing, software
27718 * distributed under the License is distributed on an "AS IS" BASIS,
27719 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27720 * See the License for the specific language governing permissions and
27721 * limitations under the License.
27722 * =============================================================================
27723 */
27724 class OptimizerConstructors {
27725 /**
27726 * Constructs a `tf.SGDOptimizer` that uses stochastic gradient descent.
27727 *
27728 * ```js
27729 * // Fit a quadratic function by learning the coefficients a, b, c.
27730 * const xs = tf.tensor1d([0, 1, 2, 3]);
27731 * const ys = tf.tensor1d([1.1, 5.9, 16.8, 33.9]);
27732 *
27733 * const a = tf.scalar(Math.random()).variable();
27734 * const b = tf.scalar(Math.random()).variable();
27735 * const c = tf.scalar(Math.random()).variable();
27736 *
27737 * // y = a * x^2 + b * x + c.
27738 * const f = x => a.mul(x.square()).add(b.mul(x)).add(c);
27739 * const loss = (pred, label) => pred.sub(label).square().mean();
27740 *
27741 * const learningRate = 0.01;
27742 * const optimizer = tf.train.sgd(learningRate);
27743 *
27744 * // Train the model.
27745 * for (let i = 0; i < 10; i++) {
27746 * optimizer.minimize(() => loss(f(xs), ys));
27747 * }
27748 *
27749 * // Make predictions.
27750 * console.log(
27751 * `a: ${a.dataSync()}, b: ${b.dataSync()}, c: ${c.dataSync()}`);
27752 * const preds = f(xs).dataSync();
27753 * preds.forEach((pred, i) => {
27754 * console.log(`x: ${i}, pred: ${pred}`);
27755 * });
27756 * ```
27757 *
27758 * @param learningRate The learning rate to use for the SGD algorithm.
27759 *
27760 * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
27761 */
27762 static sgd(learningRate) {
27763 return new SGDOptimizer(learningRate);
27764 }
27765 /**
27766 * Constructs a `tf.MomentumOptimizer` that uses momentum gradient
27767 * descent.
27768 *
27769 * See
27770 * [http://proceedings.mlr.press/v28/sutskever13.pdf](
27771 * http://proceedings.mlr.press/v28/sutskever13.pdf)
27772 *
27773 * @param learningRate The learning rate to use for the Momentum gradient
27774 * descent algorithm.
27775 * @param momentum The momentum to use for the momentum gradient descent
27776 * algorithm.
27777 *
27778 * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
27779 */
27780 static momentum(learningRate, momentum, useNesterov = false) {
27781 return new MomentumOptimizer(learningRate, momentum, useNesterov);
27782 }
27783 /**
27784 * Constructs a `tf.RMSPropOptimizer` that uses RMSProp gradient
27785 * descent. This implementation uses plain momentum and is not centered
27786 * version of RMSProp.
27787 *
27788 * See
27789 * [http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf](
27790 * http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
27791 *
27792 * @param learningRate The learning rate to use for the RMSProp gradient
27793 * descent algorithm.
27794 * @param decay The discounting factor for the history/coming gradient.
27795 * @param momentum The momentum to use for the RMSProp gradient descent
27796 * algorithm.
27797 * @param epsilon Small value to avoid zero denominator.
27798 * @param centered If true, gradients are normalized by the estimated
27799 * variance of the gradient.
27800 *
27801 * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
27802 */
27803 static rmsprop(learningRate, decay = .9, momentum = 0.0, epsilon = null, centered = false) {
27804 return new RMSPropOptimizer(learningRate, decay, momentum, epsilon, centered);
27805 }
27806 /**
27807 * Constructs a `tf.AdamOptimizer` that uses the Adam algorithm.
27808 * See [https://arxiv.org/abs/1412.6980](https://arxiv.org/abs/1412.6980)
27809 *
27810 * @param learningRate The learning rate to use for the Adam gradient
27811 * descent algorithm.
27812 * @param beta1 The exponential decay rate for the 1st moment estimates.
27813 * @param beta2 The exponential decay rate for the 2nd moment estimates.
27814 * @param epsilon A small constant for numerical stability.
27815 *
27816 * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
27817 */
27818 static adam(learningRate = 0.001, beta1 = 0.9, beta2 = 0.999, epsilon = null) {
27819 return new AdamOptimizer(learningRate, beta1, beta2, epsilon);
27820 }
27821 /**
27822 * Constructs a `tf.AdadeltaOptimizer` that uses the Adadelta algorithm.
27823 * See [https://arxiv.org/abs/1212.5701](https://arxiv.org/abs/1212.5701)
27824 *
27825 * @param learningRate The learning rate to use for the Adadelta gradient
27826 * descent algorithm.
27827 * @param rho The learning rate decay over each update.
27828 * @param epsilon A constant epsilon used to better condition the grad
27829 * update.
27830 *
27831 * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
27832 */
27833 static adadelta(learningRate = .001, rho = .95, epsilon = null) {
27834 return new AdadeltaOptimizer(learningRate, rho, epsilon);
27835 }
27836 /**
27837 * Constructs a `tf.AdamaxOptimizer` that uses the Adamax algorithm.
27838 * See [https://arxiv.org/abs/1412.6980](https://arxiv.org/abs/1412.6980)
27839 *
27840 * @param learningRate The learning rate to use for the Adamax gradient
27841 * descent algorithm.
27842 * @param beta1 The exponential decay rate for the 1st moment estimates.
27843 * @param beta2 The exponential decay rate for the 2nd moment estimates.
27844 * @param epsilon A small constant for numerical stability.
27845 * @param decay The learning rate decay over each update.
27846 *
27847 * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
27848 */
27849 static adamax(learningRate = 0.002, beta1 = 0.9, beta2 = 0.999, epsilon = null, decay = 0.0) {
27850 return new AdamaxOptimizer(learningRate, beta1, beta2, epsilon, decay);
27851 }
27852 /**
27853 * Constructs a `tf.AdagradOptimizer` that uses the Adagrad algorithm.
27854 * See
27855 * [http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf](
27856 * http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
27857 * or
27858 * [http://ruder.io/optimizing-gradient-descent/index.html#adagrad](
27859 * http://ruder.io/optimizing-gradient-descent/index.html#adagrad)
27860 *
27861 * @param learningRate The learning rate to use for the Adagrad gradient
27862 * descent algorithm.
27863 * @param initialAccumulatorValue Starting value for the accumulators, must be
27864 * positive.
27865 *
27866 * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
27867 */
27868 static adagrad(learningRate, initialAccumulatorValue = 0.1) {
27869 return new AdagradOptimizer(learningRate, initialAccumulatorValue);
27870 }
27871 }
27872
27873 /**
27874 * @license
27875 * Copyright 2018 Google LLC. All Rights Reserved.
27876 * Licensed under the Apache License, Version 2.0 (the "License");
27877 * you may not use this file except in compliance with the License.
27878 * You may obtain a copy of the License at
27879 *
27880 * http://www.apache.org/licenses/LICENSE-2.0
27881 *
27882 * Unless required by applicable law or agreed to in writing, software
27883 * distributed under the License is distributed on an "AS IS" BASIS,
27884 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27885 * See the License for the specific language governing permissions and
27886 * limitations under the License.
27887 * =============================================================================
27888 */
    // Public alias: the optimizer factory methods are exposed as `tf.train.*`.
    const train = OptimizerConstructors;
27890
27891 /**
27892 * @license
27893 * Copyright 2017 Google LLC. All Rights Reserved.
27894 * Licensed under the Apache License, Version 2.0 (the "License");
27895 * you may not use this file except in compliance with the License.
27896 * You may obtain a copy of the License at
27897 *
27898 * http://www.apache.org/licenses/LICENSE-2.0
27899 *
27900 * Unless required by applicable law or agreed to in writing, software
27901 * distributed under the License is distributed on an "AS IS" BASIS,
27902 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27903 * See the License for the specific language governing permissions and
27904 * limitations under the License.
27905 * =============================================================================
27906 */
27907 const delayCallback = (() => {
27908 if (typeof requestAnimationFrame !== 'undefined') {
27909 return requestAnimationFrame;
27910 }
27911 else if (typeof setImmediate !== 'undefined') {
27912 return setImmediate;
27913 }
27914 return (f) => f(); // no delays
27915 })();
27916 /**
27917 * Returns a promise that resolves when a requestAnimationFrame has completed.
27918 *
27919 * On Node.js this uses setImmediate instead of requestAnimationFrame.
27920 *
27921 * This is simply a sugar method so that users can do the following:
27922 * `await tf.nextFrame();`
27923 *
27924 * @doc {heading: 'Performance', subheading: 'Timing'}
27925 */
27926 function nextFrame() {
27927 return new Promise(resolve => delayCallback(() => resolve()));
27928 }
27929
27930 /**
27931 * @license
27932 * Copyright 2017 Google LLC. All Rights Reserved.
27933 * Licensed under the Apache License, Version 2.0 (the "License");
27934 * you may not use this file except in compliance with the License.
27935 * You may obtain a copy of the License at
27936 *
27937 * http://www.apache.org/licenses/LICENSE-2.0
27938 *
27939 * Unless required by applicable law or agreed to in writing, software
27940 * distributed under the License is distributed on an "AS IS" BASIS,
27941 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27942 * See the License for the specific language governing permissions and
27943 * limitations under the License.
27944 * =============================================================================
27945 */
27946 function assertParamsConsistent(shapes, axis) {
27947 const rank = shapes[0].length;
27948 shapes.forEach((shape, i) => {
27949 assert$1(shape.length === rank, () => `Error in concat${rank}D: rank of tensors[${i}] must be the same ` +
27950 `as the rank of the rest (${rank})`);
27951 });
27952 assert$1(axis >= 0 && axis < rank, () => `Error in concat${rank}D: axis must be between 0 and ${rank - 1}.`);
27953 const firstShape = shapes[0];
27954 shapes.forEach((shape, i) => {
27955 for (let r = 0; r < rank; r++) {
27956 assert$1((r === axis) || (shape[r] === firstShape[r]), () => `Error in concat${rank}D: Shape of tensors[${i}] (${shape}) ` +
27957 `does not match the shape of the rest (${firstShape}) ` +
27958 `along the non-concatenated axis ${i}.`);
27959 }
27960 });
27961 }
27962 function computeOutShape$1(shapes, axis) {
27963 const outputShape = shapes[0].slice();
27964 for (let i = 1; i < shapes.length; i++) {
27965 outputShape[axis] += shapes[i][axis];
27966 }
27967 return outputShape;
27968 }
27969
27970 /**
27971 * @license
27972 * Copyright 2020 Google Inc. All Rights Reserved.
27973 * Licensed under the Apache License, Version 2.0 (the "License");
27974 * you may not use this file except in compliance with the License.
27975 * You may obtain a copy of the License at
27976 *
27977 * http://www.apache.org/licenses/LICENSE-2.0
27978 *
27979 * Unless required by applicable law or agreed to in writing, software
27980 * distributed under the License is distributed on an "AS IS" BASIS,
27981 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27982 * See the License for the specific language governing permissions and
27983 * limitations under the License.
27984 * =============================================================================
27985 */
27986
27987 /**
27988 * @license
27989 * Copyright 2022 Google LLC. All Rights Reserved.
27990 * Licensed under the Apache License, Version 2.0 (the "License");
27991 * you may not use this file except in compliance with the License.
27992 * You may obtain a copy of the License at
27993 *
27994 * http://www.apache.org/licenses/LICENSE-2.0
27995 *
27996 * Unless required by applicable law or agreed to in writing, software
27997 * distributed under the License is distributed on an "AS IS" BASIS,
27998 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27999 * See the License for the specific language governing permissions and
28000 * limitations under the License.
28001 * =============================================================================
28002 */
28003 var RowPartitionType$1;
28004 (function (RowPartitionType) {
28005 RowPartitionType[RowPartitionType["FIRST_DIM_SIZE"] = 0] = "FIRST_DIM_SIZE";
28006 RowPartitionType[RowPartitionType["VALUE_ROWIDS"] = 1] = "VALUE_ROWIDS";
28007 RowPartitionType[RowPartitionType["ROW_LENGTHS"] = 2] = "ROW_LENGTHS";
28008 RowPartitionType[RowPartitionType["ROW_SPLITS"] = 3] = "ROW_SPLITS";
28009 RowPartitionType[RowPartitionType["ROW_LIMITS"] = 4] = "ROW_LIMITS";
28010 RowPartitionType[RowPartitionType["ROW_STARTS"] = 5] = "ROW_STARTS";
28011 })(RowPartitionType$1 || (RowPartitionType$1 = {}));
28012 function combineRaggedTensorToTensorShapes(raggedRank, shape, valueShape) {
28013 // Test for consistency of valueShape and shape specified.
28014 // If shape is unspecified and valueShape is specified, then copy
28015 // over the size from the valueShape dimension.
28016 let outputShape = new Array();
28017 if (valueShape == null && shape == null) {
28018 return outputShape;
28019 }
28020 if (shape == null) {
28021 // Here, value_shape must be of known size.
28022 while (outputShape.length < raggedRank + valueShape.length) {
28023 outputShape.push(-1);
28024 }
28025 }
28026 else {
28027 outputShape = shape.slice();
28028 }
28029 if (valueShape == null) {
28030 return outputShape;
28031 }
28032 // At this point, valueShape and output_shape have known ranks.
28033 if (raggedRank + valueShape.length !== outputShape.length) {
28034 throw new Error(`rt input.shape and shape=${shape} are incompatible: rt input.rank = ${raggedRank +
28035 valueShape.length}, but shape.rank = ${outputShape.length}`);
28036 }
28037 for (let i = 1; i < valueShape.length; ++i) {
28038 const valueDim = valueShape[i];
28039 const outputShapeDimIndex = outputShape[outputShape.length - valueShape.length + i];
28040 const outputShapeDim = outputShape[outputShapeDimIndex];
28041 if (valueDim >= 0) {
28042 if (outputShapeDim >= 0) {
28043 if (outputShapeDim !== valueDim) {
28044 throw new Error(`rt input.shape and shape=${shape} are incompatible: rt input.shape[${i + raggedRank}] = ${valueDim} but shape[${i + raggedRank}] = ${outputShapeDim}`);
28045 }
28046 }
28047 else {
28048 outputShape[outputShapeDimIndex] = valueDim;
28049 }
28050 }
28051 }
28052 return outputShape;
28053 }
28054 function getRowPartitionTypesHelper(rowPartitionTypeStrings) {
28055 const stringToType = {
28056 'FIRST_DIM_SIZE': RowPartitionType$1.FIRST_DIM_SIZE,
28057 'VALUE_ROWIDS': RowPartitionType$1.VALUE_ROWIDS,
28058 'ROW_LENGTHS': RowPartitionType$1.ROW_LENGTHS,
28059 'ROW_SPLITS': RowPartitionType$1.ROW_SPLITS,
28060 'ROW_LIMITS': RowPartitionType$1.ROW_LIMITS,
28061 'ROW_STARTS': RowPartitionType$1.ROW_STARTS
28062 };
28063 const result = [];
28064 for (const typeStr of rowPartitionTypeStrings) {
28065 if (typeStr in stringToType) {
28066 result.push(stringToType[typeStr]);
28067 }
28068 else {
28069 break;
28070 }
28071 }
28072 return result;
28073 }
28074 function getRaggedRank(rowPartitionTypes) {
28075 if (rowPartitionTypes.length === 0) {
28076 return 0;
28077 }
28078 if (rowPartitionTypes[0] === RowPartitionType$1.FIRST_DIM_SIZE) {
28079 return rowPartitionTypes.length - 1;
28080 }
28081 return rowPartitionTypes.length;
28082 }
28083 function validateDefaultValueShape(defaultValueShape, valueShape) {
28084 if (defaultValueShape == null || valueShape == null) {
28085 return;
28086 }
28087 const defaultNDims = defaultValueShape.length;
28088 const valuesNDims = valueShape.length;
28089 if (defaultNDims >= valuesNDims) {
28090 throw new Error(`defaultValue.shape=${defaultValueShape} and ragged tensor flatValues.shape=${valueShape}, are incompatible: defaultValue.rank = ${defaultNDims} must be less than ragged tensor input flatValues.rank = ${valuesNDims})`);
28091 }
28092 for (let i = 0; i < Math.min(defaultNDims, valuesNDims - 1); ++i) {
28093 const defaultDim = defaultValueShape[i];
28094 const valueDim = valueShape[i + 1];
28095 if (defaultDim >= 0 && valueDim >= 0 && defaultDim !== 1 &&
28096 defaultDim !== valueDim) {
28097 throw new Error(`defaultValue.shape=${defaultValueShape}, and ragged tensor input flatValues.shape=${valueShape} are incompatible: defaultValue.shape[${i - defaultValueShape.length}] = ${defaultDim} but ragged tensor input.flatValues.shape[${i - defaultValueShape.length}] = ${valueDim}`);
28098 }
28099 }
28100 }
28101
28102 /**
28103 * @license
28104 * Copyright 2017 Google LLC. All Rights Reserved.
28105 * Licensed under the Apache License, Version 2.0 (the "License");
28106 * you may not use this file except in compliance with the License.
28107 * You may obtain a copy of the License at
28108 *
28109 * http://www.apache.org/licenses/LICENSE-2.0
28110 *
28111 * Unless required by applicable law or agreed to in writing, software
28112 * distributed under the License is distributed on an "AS IS" BASIS,
28113 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28114 * See the License for the specific language governing permissions and
28115 * limitations under the License.
28116 * =============================================================================
28117 */
28118 const PARALLELIZE_THRESHOLD = 30;
28119 function computeOptimalWindowSize(inSize) {
28120 if (inSize <= PARALLELIZE_THRESHOLD) {
28121 return inSize;
28122 }
28123 return nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));
28124 }
28125
28126 /**
28127 * @license
28128 * Copyright 2020 Google LLC. All Rights Reserved.
28129 * Licensed under the Apache License, Version 2.0 (the "License");
28130 * you may not use this file except in compliance with the License.
28131 * You may obtain a copy of the License at
28132 *
28133 * http://www.apache.org/licenses/LICENSE-2.0
28134 *
28135 * Unless required by applicable law or agreed to in writing, software
28136 * distributed under the License is distributed on an "AS IS" BASIS,
28137 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28138 * See the License for the specific language governing permissions and
28139 * limitations under the License.
28140 * =============================================================================
28141 */
28142 // Returns the image center in pixels.
28143 function getImageCenter(center, imageHeight, imageWidth) {
28144 const centerX = imageWidth * (typeof center === 'number' ? center : center[0]);
28145 const centerY = imageHeight * (typeof center === 'number' ? center : center[1]);
28146 return [centerX, centerY];
28147 }
28148
28149 /**
28150 * @license
28151 * Copyright 2018 Google LLC. All Rights Reserved.
28152 * Licensed under the Apache License, Version 2.0 (the "License");
28153 * you may not use this file except in compliance with the License.
28154 * You may obtain a copy of the License at
28155 *
28156 * http://www.apache.org/licenses/LICENSE-2.0
28157 *
28158 * Unless required by applicable law or agreed to in writing, software
28159 * distributed under the License is distributed on an "AS IS" BASIS,
28160 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28161 * See the License for the specific language governing permissions and
28162 * limitations under the License.
28163 * =============================================================================
28164 */
28165 /**
28166 * Gets the new shape of the input Tensor after it's been reshaped
28167 * to:
28168 * [blockShape[0], ..., blockShape[M-1], batch / prod(blockShape),
28169 * inputShape[1], ..., inputShape[N-1]]
28170 *
28171 * See step 1: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
28172 */
28173 function getReshaped(inputShape, blockShape, prod, batchToSpace = true) {
28174 let reshaped = [];
28175 if (batchToSpace) {
28176 reshaped = reshaped.concat(blockShape.slice(0));
28177 reshaped.push(inputShape[0] / prod);
28178 reshaped = reshaped.concat(inputShape.slice(1));
28179 }
28180 else {
28181 reshaped = reshaped.concat(inputShape[0]);
28182 const spatialLength = blockShape.length;
28183 for (let i = 0; i < spatialLength; ++i) {
28184 reshaped =
28185 reshaped.concat([inputShape[i + 1] / blockShape[i], blockShape[i]]);
28186 }
28187 reshaped = reshaped.concat(inputShape.slice(spatialLength + 1));
28188 }
28189 return reshaped;
28190 }
28191 /**
28192 * Gets the permutation that will transpose the dimensions of the
28193 * reshaped tensor to shape:
28194 *
28195 * [batch / prod(block_shape),inputShape[1], blockShape[0], ...,
28196 * inputShape[M], blockShape[M-1],inputShape[M+1], ..., inputShape[N-1]]
28197 *
28198 * see step 2: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
28199 */
28200 function getPermuted(reshapedRank, blockShapeRank, batchToSpace = true) {
28201 const permuted = [];
28202 if (batchToSpace) {
28203 permuted.push(blockShapeRank);
28204 for (let i = blockShapeRank + 1; i < reshapedRank; ++i) {
28205 if (i <= 2 * blockShapeRank) {
28206 permuted.push(i);
28207 permuted.push(i - (blockShapeRank + 1));
28208 }
28209 else {
28210 permuted.push(i);
28211 }
28212 }
28213 }
28214 else {
28215 const permutedBeforeBatch = [];
28216 const permutedAfterBatch = [];
28217 for (let i = 1; i < reshapedRank; ++i) {
28218 if (i >= blockShapeRank * 2 + 1 || i % 2 === 1) {
28219 permutedAfterBatch.push(i);
28220 }
28221 else {
28222 permutedBeforeBatch.push(i);
28223 }
28224 }
28225 permuted.push(...permutedBeforeBatch);
28226 permuted.push(0);
28227 permuted.push(...permutedAfterBatch);
28228 }
28229 return permuted;
28230 }
28231 /**
28232 * Gets the shape of the reshaped and permuted input Tensor before any cropping
28233 * is applied. The new shape will be:
28234 *
28235 * [batch / prod(blockShape),inputShape[1] * blockShape[0], ...,
28236 * inputShape[M] * blockShape[M-1],inputShape[M+1], ..., inputShape[N-1]]
28237 *
28238 * See step 3: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
28239 */
28240 function getReshapedPermuted(inputShape, blockShape, prod, batchToSpace = true) {
28241 const reshapedPermuted = [];
28242 if (batchToSpace) {
28243 reshapedPermuted.push(inputShape[0] / prod);
28244 }
28245 else {
28246 reshapedPermuted.push(inputShape[0] * prod);
28247 }
28248 for (let i = 1; i < inputShape.length; ++i) {
28249 if (i <= blockShape.length) {
28250 if (batchToSpace) {
28251 reshapedPermuted.push(blockShape[i - 1] * inputShape[i]);
28252 }
28253 else {
28254 reshapedPermuted.push(inputShape[i] / blockShape[i - 1]);
28255 }
28256 }
28257 else {
28258 reshapedPermuted.push(inputShape[i]);
28259 }
28260 }
28261 return reshapedPermuted;
28262 }
28263 /**
28264 * Converts the crops argument into the beginning coordinates of a slice
28265 * operation.
28266 */
28267 function getSliceBeginCoords(crops, blockShape) {
28268 const sliceBeginCoords = [0];
28269 for (let i = 0; i < blockShape; ++i) {
28270 sliceBeginCoords.push(crops[i][0]);
28271 }
28272 return sliceBeginCoords;
28273 }
28274 /**
28275 * Converts the crops argument into the size of a slice operation. When
28276 * combined with getSliceBeginCoords this function allows the reshaped and
28277 * permuted Tensor to be cropped to its final output shape of:
28278 *
28279 * inputShape[1] * blockShape[0] - crops[0,0] - crops[0,1], ...,
28280 * inputShape[M] * blockShape[M-1] -crops[M-1,0] -
28281 * crops[M-1,1],inputShape[M+1], ..., inputShape[N-1]]
28282 *
28283 * See step 4: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
28284 */
28285 function getSliceSize(uncroppedShape, crops, blockShape) {
28286 const sliceSize = uncroppedShape.slice(0, 1);
28287 for (let i = 0; i < blockShape; ++i) {
28288 sliceSize.push(uncroppedShape[i + 1] - crops[i][0] - crops[i][1]);
28289 }
28290 return sliceSize;
28291 }
28292
28293 /**
28294 * @license
28295 * Copyright 2018 Google LLC. All Rights Reserved.
28296 * Licensed under the Apache License, Version 2.0 (the "License");
28297 * you may not use this file except in compliance with the License.
28298 * You may obtain a copy of the License at
28299 *
28300 * http://www.apache.org/licenses/LICENSE-2.0
28301 *
28302 * Unless required by applicable law or agreed to in writing, software
28303 * distributed under the License is distributed on an "AS IS" BASIS,
28304 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28305 * See the License for the specific language governing permissions and
28306 * limitations under the License.
28307 * =============================================================================
28308 */
    // Constants of the SELU activation (Klambauer et al. 2017,
    // "Self-Normalizing Neural Networks"): selu(x) = SELU_SCALE * x for x > 0
    // and SELU_SCALEALPHA * (exp(x) - 1) for x <= 0, where SELU_SCALEALPHA is
    // scale * alpha (alpha ~= 1.6732632423543772).
    const SELU_SCALEALPHA = 1.7580993408473768599402175208123;
    const SELU_SCALE = 1.0507009873554804934193349852946;
28311
28312 /**
28313 * @license
28314 * Copyright 2018 Google LLC. All Rights Reserved.
28315 * Licensed under the Apache License, Version 2.0 (the "License");
28316 * you may not use this file except in compliance with the License.
28317 * You may obtain a copy of the License at
28318 *
28319 * http://www.apache.org/licenses/LICENSE-2.0
28320 *
28321 * Unless required by applicable law or agreed to in writing, software
28322 * distributed under the License is distributed on an "AS IS" BASIS,
28323 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28324 * See the License for the specific language governing permissions and
28325 * limitations under the License.
28326 * =============================================================================
28327 */
    // Coefficients of the Abramowitz & Stegun formula 7.1.26 polynomial
    // approximation of the error function erf(x):
    // erf(x) ~= 1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5) * exp(-x^2),
    // with t = 1 / (1 + p*x); maximum absolute error ~1.5e-7.
    const ERF_P = 0.3275911;
    const ERF_A1 = 0.254829592;
    const ERF_A2 = -0.284496736;
    const ERF_A3 = 1.421413741;
    const ERF_A4 = -1.453152027;
    const ERF_A5 = 1.061405429;
28334
28335 /**
28336 * @license
28337 * Copyright 2018 Google LLC. All Rights Reserved.
28338 * Licensed under the Apache License, Version 2.0 (the "License");
28339 * you may not use this file except in compliance with the License.
28340 * You may obtain a copy of the License at
28341 *
28342 * http://www.apache.org/licenses/LICENSE-2.0
28343 *
28344 * Unless required by applicable law or agreed to in writing, software
28345 * distributed under the License is distributed on an "AS IS" BASIS,
28346 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28347 * See the License for the specific language governing permissions and
28348 * limitations under the License.
28349 * =============================================================================
28350 */
28351 /**
28352 * Merges real and imaginary Float32Arrays into a single complex Float32Array.
28353 *
28354 * The memory layout is interleaved as follows:
28355 * real: [r0, r1, r2]
28356 * imag: [i0, i1, i2]
28357 * complex: [r0, i0, r1, i1, r2, i2]
28358 *
28359 * This is the inverse of splitRealAndImagArrays.
28360 *
28361 * @param real The real values of the complex tensor values.
28362 * @param imag The imag values of the complex tensor values.
28363 * @returns A complex tensor as a Float32Array with merged values.
28364 */
28365 function mergeRealAndImagArrays(real, imag) {
28366 if (real.length !== imag.length) {
28367 throw new Error(`Cannot merge real and imag arrays of different lengths. real:` +
28368 `${real.length}, imag: ${imag.length}.`);
28369 }
28370 const result = new Float32Array(real.length * 2);
28371 for (let i = 0; i < result.length; i += 2) {
28372 result[i] = real[i / 2];
28373 result[i + 1] = imag[i / 2];
28374 }
28375 return result;
28376 }
28377 /**
28378 * Splits a complex Float32Array into real and imag parts.
28379 *
28380 * The memory layout is interleaved as follows:
28381 * complex: [r0, i0, r1, i1, r2, i2]
28382 * real: [r0, r1, r2]
28383 * imag: [i0, i1, i2]
28384 *
28385 * This is the inverse of mergeRealAndImagArrays.
28386 *
28387 * @param complex The complex tensor values.
28388 * @returns An object with real and imag Float32Array components of the complex
28389 * tensor.
28390 */
28391 function splitRealAndImagArrays(complex) {
28392 const real = new Float32Array(complex.length / 2);
28393 const imag = new Float32Array(complex.length / 2);
28394 for (let i = 0; i < complex.length; i += 2) {
28395 real[i / 2] = complex[i];
28396 imag[i / 2] = complex[i + 1];
28397 }
28398 return { real, imag };
28399 }
28400 /**
28401 * Extracts even indexed complex values in the given array.
28402 * @param complex The complex tensor values
28403 */
28404 function complexWithEvenIndex(complex) {
28405 const len = Math.ceil(complex.length / 4);
28406 const real = new Float32Array(len);
28407 const imag = new Float32Array(len);
28408 for (let i = 0; i < complex.length; i += 4) {
28409 real[Math.floor(i / 4)] = complex[i];
28410 imag[Math.floor(i / 4)] = complex[i + 1];
28411 }
28412 return { real, imag };
28413 }
28414 /**
28415 * Extracts odd indexed complete values in the given array.
28416 * @param complex The complex tensor values
28417 */
28418 function complexWithOddIndex(complex) {
28419 const len = Math.floor(complex.length / 4);
28420 const real = new Float32Array(len);
28421 const imag = new Float32Array(len);
28422 for (let i = 2; i < complex.length; i += 4) {
28423 real[Math.floor(i / 4)] = complex[i];
28424 imag[Math.floor(i / 4)] = complex[i + 1];
28425 }
28426 return { real, imag };
28427 }
28428 /**
28429 * Get the map representing a complex value in the given array.
28430 * @param complex The complex tensor values.
28431 * @param index An index of the target complex value.
28432 */
28433 function getComplexWithIndex(complex, index) {
28434 const real = complex[index * 2];
28435 const imag = complex[index * 2 + 1];
28436 return { real, imag };
28437 }
28438 /**
28439 * Insert a given complex value into the TypedArray.
28440 * @param data The array in which the complex value is inserted.
28441 * @param c The complex value to be inserted.
28442 * @param index An index of the target complex value.
28443 */
28444 function assignToTypedArray(data, real, imag, index) {
28445 data[index * 2] = real;
28446 data[index * 2 + 1] = imag;
28447 }
28448 /**
28449 * Make the list of exponent terms used by FFT.
28450 */
28451 function exponents(n, inverse) {
28452 const real = new Float32Array(n / 2);
28453 const imag = new Float32Array(n / 2);
28454 for (let i = 0; i < Math.ceil(n / 2); i++) {
28455 const x = (inverse ? 2 : -2) * Math.PI * (i / n);
28456 real[i] = Math.cos(x);
28457 imag[i] = Math.sin(x);
28458 }
28459 return { real, imag };
28460 }
28461 /**
28462 * Make the exponent term used by FFT.
28463 */
28464 function exponent(k, n, inverse) {
28465 const x = (inverse ? 2 : -2) * Math.PI * (k / n);
28466 const real = Math.cos(x);
28467 const imag = Math.sin(x);
28468 return { real, imag };
28469 }
28470
28471 /**
28472 * @license
28473 * Copyright 2021 Google LLC. All Rights Reserved.
28474 * Licensed under the Apache License, Version 2.0 (the "License");
28475 * you may not use this file except in compliance with the License.
28476 * You may obtain a copy of the License at
28477 *
28478 * http://www.apache.org/licenses/LICENSE-2.0
28479 *
28480 * Unless required by applicable law or agreed to in writing, software
28481 * distributed under the License is distributed on an "AS IS" BASIS,
28482 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28483 * See the License for the specific language governing permissions and
28484 * limitations under the License.
28485 * =============================================================================
28486 */
// Tokens recognized while parsing einsum equations (see decodeEinsumEquation).
const ARROW = '->';
const ARROW_REGEX = /->/g;
const COMMA = ',';
const ELLIPSIS = '...';
28491 /**
28492 * Parse an equation for einsum.
28493 *
28494 * @param equation The einsum equation (e.g., "ij,jk->ik").
28495 * @param numTensors Number of tensors provided along with `equation`. Used to
28496 * check matching number of input tensors.
28497 * @returns An object consisting of the following fields:
28498 * - allDims: all dimension names as strings.
28499 * - summedDims: a list of all dimensions being summed over, as indices to
28500 * the elements of `allDims`.
28501 * - idDims: indices of the dimensions in each input tensor, as indices to
28502 * the elements of `allDims.
28503 */
28504 function decodeEinsumEquation(equation, numTensors) {
28505 equation = equation.replace(/\s/g, ''); // Remove witespace in equation.
28506 const numArrows = (equation.length - equation.replace(ARROW_REGEX, '').length) /
28507 ARROW.length;
28508 if (numArrows < 1) {
28509 throw new Error('Equations without an arrow are not supported.');
28510 }
28511 else if (numArrows > 1) {
28512 throw new Error(`Equation must contain exactly one arrow ("${ARROW}").`);
28513 }
28514 const [inputString, outputString] = equation.split(ARROW);
28515 assert$1(inputString.indexOf(ELLIPSIS) === -1, () => `The ellipsis notation ("${ELLIPSIS}") is not supported yet.`);
28516 const inputTerms = inputString.split(COMMA);
28517 const numInputs = inputTerms.length;
28518 if (numTensors !== numInputs) {
28519 throw new Error(`Expected ${numInputs} input tensors, received ${numTensors}`);
28520 }
28521 if (numInputs > 2) {
28522 throw new Error('Support for more than 2 input tensors is not implemented yet.');
28523 }
28524 const allDims = [];
28525 for (let i = 0; i < outputString.length; ++i) {
28526 const dimName = outputString[i];
28527 if (!inputTerms.some(inputTerm => inputTerm.indexOf(dimName) !== -1)) {
28528 throw new Error(`Output subscripts contain the label ${dimName} ` +
28529 `not present in the input subscripts.`);
28530 }
28531 if (allDims.indexOf(dimName) === -1) {
28532 allDims.push(dimName);
28533 }
28534 }
28535 for (let i = 0; i < inputString.length; ++i) {
28536 const dimName = inputString[i];
28537 if (allDims.indexOf(dimName) === -1 && dimName !== COMMA) {
28538 allDims.push(dimName);
28539 }
28540 }
28541 const idDims = new Array(inputTerms.length);
28542 for (let i = 0; i < numInputs; ++i) {
28543 if (new Set(inputTerms[i].split('')).size !== inputTerms[i].length) {
28544 throw new Error(`Found duplicate axes in input component ${inputTerms[i]}. ` +
28545 `Support for duplicate axes in input is not implemented yet.`);
28546 }
28547 idDims[i] = [];
28548 for (let j = 0; j < inputTerms[i].length; ++j) {
28549 idDims[i].push(allDims.indexOf(inputTerms[i][j]));
28550 }
28551 }
28552 const numDims = allDims.length; // Number of unique dimensions.
28553 const numOutDims = outputString.length; // Number of output dimensions.
28554 const summedDims = []; // Dimensions being summed over.
28555 for (let i = numOutDims; i < numDims; ++i) {
28556 summedDims.push(i);
28557 }
28558 return { allDims, summedDims, idDims };
28559 }
28560 /**
28561 * Get the permutation for a given input tensor.
28562 *
28563 * @param nDims Total number of dimension of all tensors involved in the einsum
28564 * operation.
28565 * @param idDims Dimension indices involve in the tensor in question.
28566 * @returns An object consisting of the following fields:
28567 * - permutationIndices: Indices to permute the axes of the tensor with.
28568 * - expandDims: Indices to the dimension that need to be expanded from the
28569 * tensor after permutation.
28570 */
28571 function getEinsumPermutation(nDims, idDims) {
28572 let permutationIndices = new Array(nDims);
28573 permutationIndices.fill(-1);
28574 for (let i = 0; i < idDims.length; ++i) {
28575 permutationIndices[idDims[i]] = i;
28576 }
28577 const expandDims = [];
28578 for (let i = 0; i < nDims; ++i) {
28579 if (permutationIndices[i] === -1) {
28580 expandDims.push(i);
28581 }
28582 }
28583 permutationIndices = permutationIndices.filter(d => d !== -1);
28584 return { permutationIndices, expandDims };
28585 }
28586 /**
28587 * Checks that the dimension sizes from different input tensors match the
28588 * equation.
28589 */
28590 function checkEinsumDimSizes(nDims, idDims, tensors) {
28591 const dimSizes = new Array(nDims);
28592 for (let i = 0; i < tensors.length; ++i) {
28593 const shape = tensors[i].shape;
28594 for (let j = 0; j < idDims[i].length; ++j) {
28595 if (dimSizes[idDims[i][j]] === undefined) {
28596 dimSizes[idDims[i][j]] = shape[j];
28597 }
28598 else {
28599 assert$1(dimSizes[idDims[i][j]] === shape[j], () => `Expected dimension ${dimSizes[idDims[i][j]]} at axis ${j} ` +
28600 `of input shaped ${JSON.stringify(shape)}, ` +
28601 `but got dimension ${shape[j]}`);
28602 }
28603 }
28604 }
28605 }
28606 /**
28607 * Gets path of computation for einsum.
28608 *
28609 * @param summedDims indices to the dimensions being summed over.
28610 * @param idDims A look up table for the dimensions present in each input
28611 * tensor.Each constituent array contains indices for the dimensions in the
28612 * corresponding input tensor.
28613 *
28614 * @return A map with two fields:
28615 * - path: The path of computation, with each element indicating the dimension
28616 * being summed over after the element-wise multiplication in that step.
28617 * - steps: With the same length as `path`. Each element contains the indices
28618 * to the input tensors being used for element-wise multiplication in the
28619 * corresponding step.
28620 */
28621 function getEinsumComputePath(summedDims, idDims) {
28622 const path = summedDims;
28623 const steps = [];
28624 let nSteps = 0;
28625 if (summedDims.length === 0) {
28626 // Einsum that involes no summing: e.g., transpose and outer product.
28627 path.push(-1);
28628 }
28629 nSteps = summedDims.length + 1;
28630 for (let i = 0; i < nSteps; ++i) {
28631 steps.push([]);
28632 }
28633 const computedTermIndices = [];
28634 for (let i = 0; i < path.length; ++i) {
28635 const summedDim = path[i];
28636 const termIndices = findTermsWithDim(idDims, summedDim);
28637 for (const termIndex of termIndices) {
28638 if (computedTermIndices.indexOf(termIndex) === -1) {
28639 steps[i].push(termIndex);
28640 computedTermIndices.push(termIndex);
28641 }
28642 }
28643 }
28644 return { path, steps };
28645 }
/** Determines if an axes permutation is the identity permutation. */
function isIdentityPermutation(perm) {
    for (let i = 0; i < perm.length; ++i) {
        if (perm[i] !== i) {
            return false;
        }
    }
    return true;
}
/**
 * Finds the indices of the input terms that involve the given dimension.
 * A term matches when it contains `dim`, when it has no dimensions at all,
 * or when `dim` is the -1 sentinel (matches every term).
 */
function findTermsWithDim(idDims, dim) {
    const termIndices = [];
    idDims.forEach((dims, i) => {
        if (dim === -1 || dims.length === 0 || dims.indexOf(dim) !== -1) {
            termIndices.push(i);
        }
    });
    return termIndices;
}
28659
28660 /**
28661 * Prepare the split size array. When the input is a number, the axis is evenly
28662 * divided among the split size. When the input contains the negative value, the
28663 * rest of the axis is allocated toward that.
28664 */
28665 function prepareSplitSize(x, numOrSizeSplits, axis = 0) {
28666 let splitSizes = [];
28667 if (typeof (numOrSizeSplits) === 'number') {
28668 assert$1(x.shape[axis] % numOrSizeSplits === 0, () => 'Number of splits must evenly divide the axis.');
28669 splitSizes =
28670 new Array(numOrSizeSplits).fill(x.shape[axis] / numOrSizeSplits);
28671 }
28672 else {
28673 const numOfNegs = numOrSizeSplits.reduce((count, value) => {
28674 if (value === -1) {
28675 count += 1;
28676 }
28677 return count;
28678 }, 0);
28679 assert$1(numOfNegs <= 1, () => 'There should be only one negative value in split array.');
28680 const negIndex = numOrSizeSplits.indexOf(-1);
28681 // Allow the number of split array to be -1, which indicates the rest
28682 // of dimension is allocated to that split.
28683 if (negIndex !== -1) {
28684 const total = numOrSizeSplits.reduce((a, b) => b > 0 ? a + b : a);
28685 numOrSizeSplits[negIndex] = x.shape[axis] - total;
28686 }
28687 assert$1(x.shape[axis] === numOrSizeSplits.reduce((a, b) => a + b), () => 'The sum of sizes must match the size of the axis dimension.');
28688 splitSizes = numOrSizeSplits;
28689 }
28690 return splitSizes;
28691 }
28692
28693 /**
28694 * @license
28695 * Copyright 2021 Google LLC. All Rights Reserved.
28696 * Licensed under the Apache License, Version 2.0 (the "License");
28697 * you may not use this file except in compliance with the License.
28698 * You may obtain a copy of the License at
28699 *
28700 * http://www.apache.org/licenses/LICENSE-2.0
28701 *
28702 * Unless required by applicable law or agreed to in writing, software
28703 * distributed under the License is distributed on an "AS IS" BASIS,
28704 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28705 * See the License for the specific language governing permissions and
28706 * limitations under the License.
28707 * =============================================================================
28708 */
28709 /**
28710 * Generates sparse fill empty rows indices, dense shape mismatch error message.
28711 *
28712 * @param indicesLength The first dimension of indices.
28713 */
28714 function getSparseFillEmptyRowsIndicesDenseShapeMismatch(indicesLength) {
28715 return `Received SparseTensor with denseShape[0] = 0 but
28716 indices.shape[0] = ${indicesLength}`;
28717 }
28718 /**
28719 * Generates sparse fill empty rows negative index error message.
28720 *
28721 * @param index The index with a negative value.
28722 * @param value The negative value.
28723 */
28724 function getSparseFillEmptyRowsNegativeIndexErrorMessage(index, value) {
28725 return `indices(${index}, 0) is invalid: ${value} < 0`;
28726 }
28727 /**
28728 * Generates sparse fill empty rows out of range index error message.
28729 *
28730 * @param index The index with an out of range value.
28731 * @param value The out of range value.
28732 * @param limit The upper limit for indices.
28733 */
28734 function getSparseFillEmptyRowsOutOfRangeIndexErrorMessage(index, value, limit) {
28735 return `indices(${index}, 0) is invalid: ${value} >= ${limit}`;
28736 }
28737
28738 /**
28739 * @license
28740 * Copyright 2021 Google LLC. All Rights Reserved.
28741 * Licensed under the Apache License, Version 2.0 (the "License");
28742 * you may not use this file except in compliance with the License.
28743 * You may obtain a copy of the License at
28744 *
28745 * http://www.apache.org/licenses/LICENSE-2.0
28746 *
28747 * Unless required by applicable law or agreed to in writing, software
28748 * distributed under the License is distributed on an "AS IS" BASIS,
28749 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28750 * See the License for the specific language governing permissions and
28751 * limitations under the License.
28752 * =============================================================================
28753 */
28754 /**
28755 * Generates sparse reshape multiple negative 1 output dimension error message.
28756 *
28757 * @param dim1 The first dimension with a negative 1 value.
28758 * @param dim2 The second dimension with a negative 1 value.
28759 */
28760 function getSparseReshapeMultipleNegativeOneOutputDimErrorMessage(dim1, dim2) {
28761 return `only one output dimension may be -1, not both ${dim1} and ${dim2}`;
28762 }
28763 /**
28764 * Generates sparse reshape negative output dimension error message.
28765 *
28766 * @param dim The dimension with a negative value.
28767 * @param value The negative value.
28768 */
28769 function getSparseReshapeNegativeOutputDimErrorMessage(dim, value) {
28770 return `size ${dim} must be non-negative, not ${value}`;
28771 }
28772 /**
28773 * Generates sparse reshape empty tensor zero output dimension error message.
28774 *
28775 */
28776 function getSparseReshapeEmptyTensorZeroOutputDimErrorMessage() {
28777 return 'reshape cannot infer the missing input size for an empty tensor ' +
28778 'unless all specified input sizes are non-zero';
28779 }
28780 /**
28781 * Generates sparse reshape input output multiple mismatch error message.
28782 *
28783 * @param inputShape the input shape.
28784 * @param outputShape the requested output shape.
28785 */
28786 function getSparseReshapeInputOutputMultipleErrorMessage(inputShape, outputShape) {
28787 const inputSize = sizeFromShape(inputShape);
28788 const outputSize = sizeFromShape(outputShape);
28789 return `Input to reshape is a SparseTensor with ${inputSize}
28790 dense values, but the requested shape requires a multiple of ${outputSize}. inputShape=${inputShape} outputShape= ${outputShape}`;
28791 }
28792 /**
28793 * Generates sparse reshape input output inequality error message.
28794 *
28795 * @param inputShape the input shape.
28796 * @param outputShape the requested output shape.
28797 */
28798 function getSparseReshapeInputOutputMismatchErrorMessage(inputShape, outputShape) {
28799 const inputSize = sizeFromShape(inputShape);
28800 const outputSize = sizeFromShape(outputShape);
28801 return `Input to reshape is a tensor with ${inputSize} dense values, but the requested shape has ${outputSize}. inputShape=${inputShape} outputShape=${outputShape}`;
28802 }
28803
28804 /**
28805 * @license
28806 * Copyright 2021 Google LLC. All Rights Reserved.
28807 * Licensed under the Apache License, Version 2.0 (the "License");
28808 * you may not use this file except in compliance with the License.
28809 * You may obtain a copy of the License at
28810 *
28811 * http://www.apache.org/licenses/LICENSE-2.0
28812 *
28813 * Unless required by applicable law or agreed to in writing, software
28814 * distributed under the License is distributed on an "AS IS" BASIS,
28815 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28816 * See the License for the specific language governing permissions and
28817 * limitations under the License.
28818 * =============================================================================
28819 */
28820 /**
28821 * Generates sparse segment reduction negative segment ids error message.
28822 *
28823 */
28824 function getSparseSegmentReductionNegativeSegmentIdsErrorMessage() {
28825 return `segment ids must be >= 0`;
28826 }
28827 /**
28828 * Generates sparse segment reduction non increasing segment ids error message.
28829 *
28830 */
28831 function getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage() {
28832 return `segment ids are not increasing`;
28833 }
28834 /**
28835 * Generates sparse segment reduction segment id out of range error message.
28836 *
28837 * @param segmentId The segment id index that is out of range.
28838 * @param outputRows Upper bound of valid segment id values.
28839 */
28840 function getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage(segmentId, outputRows) {
28841 return `Segment id ${segmentId} out of range [0, ${outputRows}), possibly because segmentIds input is not sorted.`;
28842 }
28843 /**
28844 * Generates sparse segment reduction input indice out of range error message.
28845 *
28846 * @param index The index that holds the out of range value.
28847 * @param indexValue The value that is out of range.
28848 * @param inputRows Upper bound of valid index values.
28849 */
28850 function getSparseSegmentReductionIndicesOutOfRangeErrorMessage(index, indexValue, inputRows) {
28851 return `Bad: indices[${index}] == ${indexValue} out of range [0, ${inputRows})`;
28852 }
28853
28854 /**
28855 * @license
28856 * Copyright 2018 Google LLC. All Rights Reserved.
28857 * Licensed under the Apache License, Version 2.0 (the "License");
28858 * you may not use this file except in compliance with the License.
28859 * You may obtain a copy of the License at
28860 *
28861 * http://www.apache.org/licenses/LICENSE-2.0
28862 *
28863 * Unless required by applicable law or agreed to in writing, software
28864 * distributed under the License is distributed on an "AS IS" BASIS,
28865 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28866 * See the License for the specific language governing permissions and
28867 * limitations under the License.
28868 * =============================================================================
28869 */
/**
 * Picks a window size for a segment reduction over `inSize` elements.
 *
 * Inputs at or below PARALLELIZE_THRESHOLD are processed in one window.
 * Otherwise, starts from the divisor of `inSize` nearest to sqrt(inSize) and
 * grows it until it exceeds `numSegments` or covers the whole input.
 */
function segOpComputeOptimalWindowSize(inSize, numSegments) {
    if (inSize <= PARALLELIZE_THRESHOLD) {
        // Small input: no windowing needed.
        return inSize;
    }
    let windowSize = nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));
    while (windowSize <= numSegments && windowSize !== inSize) {
        windowSize = nearestDivisor(inSize, windowSize + 1);
    }
    return windowSize;
}
/**
 * Computes the output shape of a segment reduction: the input shape with the
 * reduced axis replaced by `numSegments`.
 */
function computeOutShape(aShape, axis, numSegments) {
    return aShape.map((dimSize, dim) => dim === axis ? numSegments : dimSize);
}
/**
 * Validates gather parameters and computes the shape bookkeeping needed by
 * gather kernels.
 *
 * @param x Tensor-like object; only `x.shape` is read.
 * @param indices Tensor-like object; only `indices.shape` is read.
 * @param axis The axis of `x` to gather along (must be >= batchDims).
 * @param batchDims Number of leading batch dimensions shared by `x` and
 *   `indices`; may be negative (interpreted relative to indices rank).
 * @returns {batchSize, sliceSize, outerSize, dimSize, outputShape} where
 *   batchSize is the product of the batch dims, outerSize the product of dims
 *   between batchDims and axis, sliceSize the product of dims after axis, and
 *   dimSize is x.shape[axis].
 * @throws Error on out-of-range batchDims, batchDims > rank(x),
 *   axis < batchDims, or mismatched batch dimensions between x and indices.
 */
function collectGatherOpShapeInfo(x, indices, axis, batchDims) {
    const indicesRank = indices.shape.length;
    const xRank = x.shape.length;
    // Validate batchDims against the rank of indices (0 is always allowed).
    if (batchDims !== 0) {
        if (batchDims < -indicesRank || batchDims > indicesRank) {
            throw new Error(`Expect batchDims in the range of [-${indicesRank}, ${indicesRank}], but got ${batchDims}`);
        }
    }
    // Normalize negative batchDims relative to the indices rank.
    if (batchDims < 0) {
        batchDims += indicesRank;
    }
    if (batchDims > xRank) {
        throw new Error(`batchDims (${batchDims}) must be less than rank(x) (
    ${xRank}).`);
    }
    if (axis < batchDims) {
        throw new Error(`batchDims (${batchDims}) must be less than or equal to axis (${axis}).`);
    }
    // The leading batch dimensions of x and indices must agree exactly.
    for (let i = 0; i < batchDims; ++i) {
        if (x.shape[i] !== indices.shape[i]) {
            throw new Error(`x.shape[${i}]: ${x.shape[i]} should be equal to indices.shape[${i}]: ${indices.shape[i]}.`);
        }
    }
    const dimSize = x.shape[axis];
    const outputShape = [];
    let batchSize = 1;
    let outerSize = 1;
    let sliceSize = 1;
    // Batch dims: copied from x and folded into batchSize.
    for (let i = 0; i < batchDims; ++i) {
        outputShape.push(x.shape[i]);
        batchSize *= x.shape[i];
    }
    // Dims between the batch dims and the gather axis.
    for (let i = batchDims; i < axis; i++) {
        outputShape.push(x.shape[i]);
        outerSize *= x.shape[i];
    }
    // The gathered axis is replaced by the non-batch dims of indices.
    for (let i = batchDims; i < indicesRank; i++) {
        outputShape.push(indices.shape[i]);
    }
    // Dims after the gather axis, folded into sliceSize.
    for (let i = axis + 1; i < xRank; i++) {
        outputShape.push(x.shape[i]);
        sliceSize *= x.shape[i];
    }
    return { batchSize, sliceSize, outerSize, dimSize, outputShape };
}
28948
// Frozen namespace object bundling the segment/gather shape helpers above.
var segment_util = /*#__PURE__*/Object.freeze({
    __proto__: null,
    collectGatherOpShapeInfo: collectGatherOpShapeInfo,
    computeOutShape: computeOutShape,
    segOpComputeOptimalWindowSize: segOpComputeOptimalWindowSize
});
28955
28956 /**
28957 * @license
28958 * Copyright 2018 Google LLC. All Rights Reserved.
28959 * Licensed under the Apache License, Version 2.0 (the "License");
28960 * you may not use this file except in compliance with the License.
28961 * You may obtain a copy of the License at
28962 *
28963 * http://www.apache.org/licenses/LICENSE-2.0
28964 *
28965 * Unless required by applicable law or agreed to in writing, software
28966 * distributed under the License is distributed on an "AS IS" BASIS,
28967 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28968 * See the License for the specific language governing permissions and
28969 * limitations under the License.
28970 * =============================================================================
28971 */
/**
 * Decodes an array of encoded string byte buffers into JS strings.
 * Wraps any decode failure in a single descriptive Error.
 */
function fromUint8ToStringArray(vals) {
    const decoded = [];
    try {
        for (const val of vals) {
            decoded.push(decodeString(val));
        }
    }
    catch (err) {
        throw new Error(`Failed to decode encoded string bytes into utf-8, error: ${err}`);
    }
    return decoded;
}
/** Encodes an array of JS strings into encoded byte buffers. */
function fromStringArrayToUint8(strings) {
    const encoded = [];
    for (const s of strings) {
        encoded.push(encodeString(s));
    }
    return encoded;
}
28984
// Frozen namespace aggregating the backend helper utilities defined above,
// for re-export to concrete backend implementations.
var backend_util = /*#__PURE__*/Object.freeze({
    __proto__: null,
    ERF_A1: ERF_A1,
    ERF_A2: ERF_A2,
    ERF_A3: ERF_A3,
    ERF_A4: ERF_A4,
    ERF_A5: ERF_A5,
    ERF_P: ERF_P,
    PARALLELIZE_THRESHOLD: PARALLELIZE_THRESHOLD,
    get RowPartitionType () { return RowPartitionType$1; },
    SELU_SCALE: SELU_SCALE,
    SELU_SCALEALPHA: SELU_SCALEALPHA,
    applyActivation: applyActivation$1,
    assertAndGetBroadcastShape: assertAndGetBroadcastShape,
    assertAxesAreInnerMostDims: assertAxesAreInnerMostDims,
    assertParamsConsistent: assertParamsConsistent,
    assignToTypedArray: assignToTypedArray,
    axesAreInnerMostDims: axesAreInnerMostDims,
    calculateShapes: calculateShapes,
    checkEinsumDimSizes: checkEinsumDimSizes,
    checkPadOnDimRoundingMode: checkPadOnDimRoundingMode,
    combineLocations: combineLocations,
    combineRaggedTensorToTensorShapes: combineRaggedTensorToTensorShapes,
    complexWithEvenIndex: complexWithEvenIndex,
    complexWithOddIndex: complexWithOddIndex,
    computeConv2DInfo: computeConv2DInfo,
    computeConv3DInfo: computeConv3DInfo,
    computeDefaultPad: computeDefaultPad,
    computeDilation2DInfo: computeDilation2DInfo,
    computeOptimalWindowSize: computeOptimalWindowSize,
    computeOutAndReduceShapes: computeOutAndReduceShapes,
    computeOutShape: computeOutShape$1,
    computePool2DInfo: computePool2DInfo,
    computePool3DInfo: computePool3DInfo,
    convertConv2DDataFormat: convertConv2DDataFormat,
    decodeEinsumEquation: decodeEinsumEquation,
    eitherStridesOrDilationsAreOne: eitherStridesOrDilationsAreOne,
    expandShapeToKeepDim: expandShapeToKeepDim,
    exponent: exponent,
    exponents: exponents,
    fromStringArrayToUint8: fromStringArrayToUint8,
    fromUint8ToStringArray: fromUint8ToStringArray,
    getAxesPermutation: getAxesPermutation,
    getBroadcastDims: getBroadcastDims$1,
    getComplexWithIndex: getComplexWithIndex,
    getEinsumComputePath: getEinsumComputePath,
    getEinsumPermutation: getEinsumPermutation,
    getFusedBiasGradient: getFusedBiasGradient,
    getFusedDyActivation: getFusedDyActivation,
    getImageCenter: getImageCenter,
    getInnerMostAxes: getInnerMostAxes,
    getPermuted: getPermuted,
    getRaggedRank: getRaggedRank,
    getReductionAxes: getReductionAxes,
    getReshaped: getReshaped,
    getReshapedPermuted: getReshapedPermuted,
    getRowPartitionTypesHelper: getRowPartitionTypesHelper,
    getSliceBeginCoords: getSliceBeginCoords,
    getSliceSize: getSliceSize,
    getSparseFillEmptyRowsIndicesDenseShapeMismatch: getSparseFillEmptyRowsIndicesDenseShapeMismatch,
    getSparseFillEmptyRowsNegativeIndexErrorMessage: getSparseFillEmptyRowsNegativeIndexErrorMessage,
    getSparseFillEmptyRowsOutOfRangeIndexErrorMessage: getSparseFillEmptyRowsOutOfRangeIndexErrorMessage,
    getSparseReshapeEmptyTensorZeroOutputDimErrorMessage: getSparseReshapeEmptyTensorZeroOutputDimErrorMessage,
    getSparseReshapeInputOutputMismatchErrorMessage: getSparseReshapeInputOutputMismatchErrorMessage,
    getSparseReshapeInputOutputMultipleErrorMessage: getSparseReshapeInputOutputMultipleErrorMessage,
    getSparseReshapeMultipleNegativeOneOutputDimErrorMessage: getSparseReshapeMultipleNegativeOneOutputDimErrorMessage,
    getSparseReshapeNegativeOutputDimErrorMessage: getSparseReshapeNegativeOutputDimErrorMessage,
    getSparseSegmentReductionIndicesOutOfRangeErrorMessage: getSparseSegmentReductionIndicesOutOfRangeErrorMessage,
    getSparseSegmentReductionNegativeSegmentIdsErrorMessage: getSparseSegmentReductionNegativeSegmentIdsErrorMessage,
    getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage: getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage,
    getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage: getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage,
    getUndoAxesPermutation: getUndoAxesPermutation,
    isIdentityPermutation: isIdentityPermutation,
    log: log$3,
    mergeRealAndImagArrays: mergeRealAndImagArrays,
    prepareAndValidate: prepareAndValidate,
    prepareSplitSize: prepareSplitSize,
    segment_util: segment_util,
    shouldFuse: shouldFuse,
    slice_util: slice_util,
    splitRealAndImagArrays: splitRealAndImagArrays,
    stridesOrDilationsArePositive: stridesOrDilationsArePositive,
    tupleValuesAreOne: tupleValuesAreOne,
    upcastType: upcastType,
    validateDefaultValueShape: validateDefaultValueShape,
    validateInput: validateInput$1,
    validateUpdateShape: validateUpdateShape,
    warn: warn
});
29074
29075 /**
29076 * @license
29077 * Copyright 2020 Google LLC. All Rights Reserved.
29078 * Licensed under the Apache License, Version 2.0 (the "License");
29079 * you may not use this file except in compliance with the License.
29080 * You may obtain a copy of the License at
29081 *
29082 * http://www.apache.org/licenses/LICENSE-2.0
29083 *
29084 * Unless required by applicable law or agreed to in writing, software
29085 * distributed under the License is distributed on an "AS IS" BASIS,
29086 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29087 * See the License for the specific language governing permissions and
29088 * limitations under the License.
29089 * =============================================================================
29090 */
29091
// Frozen namespace of shared kernel implementations re-exported for backends.
var kernel_impls = /*#__PURE__*/Object.freeze({
    __proto__: null,
    nonMaxSuppressionV3Impl: nonMaxSuppressionV3Impl$2,
    nonMaxSuppressionV4Impl: nonMaxSuppressionV4Impl$2,
    nonMaxSuppressionV5Impl: nonMaxSuppressionV5Impl$2,
    whereImpl: whereImpl$2
});
29099
29100 /**
29101 * @license
29102 * Copyright 2020 Google Inc. All Rights Reserved.
29103 * Licensed under the Apache License, Version 2.0 (the "License");
29104 * you may not use this file except in compliance with the License.
29105 * You may obtain a copy of the License at
29106 *
29107 * http://www.apache.org/licenses/LICENSE-2.0
29108 *
29109 * Unless required by applicable law or agreed to in writing, software
29110 * distributed under the License is distributed on an "AS IS" BASIS,
29111 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29112 * See the License for the specific language governing permissions and
29113 * limitations under the License.
29114 * =============================================================================
29115 */
29116
29117 /**
29118 * @license
29119 * Copyright 2017 Google LLC. All Rights Reserved.
29120 * Licensed under the Apache License, Version 2.0 (the "License");
29121 * you may not use this file except in compliance with the License.
29122 * You may obtain a copy of the License at
29123 *
29124 * http://www.apache.org/licenses/LICENSE-2.0
29125 *
29126 * Unless required by applicable law or agreed to in writing, software
29127 * distributed under the License is distributed on an "AS IS" BASIS,
29128 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29129 * See the License for the specific language governing permissions and
29130 * limitations under the License.
29131 * =============================================================================
29132 */
// Module-load side effect: registers the built-in optimizers (presumably with
// the serialization/class registry — defined elsewhere in this bundle).
registerOptimizers();
29134
29135 /**
29136 * @license
29137 * Copyright 2020 Google LLC. All Rights Reserved.
29138 * Licensed under the Apache License, Version 2.0 (the "License");
29139 * you may not use this file except in compliance with the License.
29140 * You may obtain a copy of the License at
29141 *
29142 * http://www.apache.org/licenses/LICENSE-2.0
29143 *
29144 * Unless required by applicable law or agreed to in writing, software
29145 * distributed under the License is distributed on an "AS IS" BASIS,
29146 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29147 * See the License for the specific language governing permissions and
29148 * limitations under the License.
29149 * =============================================================================
29150 */
/**
 * Gradient config for the Abs kernel: dx = dy * step(x, -1).
 * NOTE(review): step$2 with alpha = -1 appears intended to act as a sign-like
 * factor (±1); confirm the Step kernel's alpha semantics (esp. at x == 0).
 */
const absGradConfig = {
    kernelName: Abs,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        // Cast to float32 before step so the multiply runs in float space.
        return { x: () => mul(dy, step$2(cast$3(x, 'float32'), -1)) };
    }
};
29159
29160 /**
29161 * @license
29162 * Copyright 2020 Google LLC. All Rights Reserved.
29163 * Licensed under the Apache License, Version 2.0 (the "License");
29164 * you may not use this file except in compliance with the License.
29165 * You may obtain a copy of the License at
29166 *
29167 * http://www.apache.org/licenses/LICENSE-2.0
29168 *
29169 * Unless required by applicable law or agreed to in writing, software
29170 * distributed under the License is distributed on an "AS IS" BASIS,
29171 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29172 * See the License for the specific language governing permissions and
29173 * limitations under the License.
29174 * =============================================================================
29175 */
29176 const acosGradConfig = {
29177 kernelName: Acos,
29178 inputsToSave: ['x'],
29179 gradFunc: (dy, saved) => {
29180 const [x] = saved;
29181 return {
29182 x: () => {
29183 const a = square$2(cast$3(x, 'float32'));
29184 const b = sqrt$2(sub$2(scalar(1), a));
29185 return neg$2(div$1(dy, b));
29186 }
29187 };
29188 }
29189 };
29190
29191 /**
29192 * @license
29193 * Copyright 2020 Google LLC. All Rights Reserved.
29194 * Licensed under the Apache License, Version 2.0 (the "License");
29195 * you may not use this file except in compliance with the License.
29196 * You may obtain a copy of the License at
29197 *
29198 * http://www.apache.org/licenses/LICENSE-2.0
29199 *
29200 * Unless required by applicable law or agreed to in writing, software
29201 * distributed under the License is distributed on an "AS IS" BASIS,
29202 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29203 * See the License for the specific language governing permissions and
29204 * limitations under the License.
29205 * =============================================================================
29206 */
29207 const acoshGradConfig = {
29208 kernelName: Acosh,
29209 inputsToSave: ['x'],
29210 gradFunc: (dy, saved) => {
29211 const [x] = saved;
29212 return {
29213 x: () => {
29214 const a = sqrt$2(sub$2(square$2(cast$3(x, 'float32')), 1));
29215 return div$1(dy, a);
29216 }
29217 };
29218 }
29219 };
29220
29221 /**
29222 * @license
29223 * Copyright 2020 Google LLC. All Rights Reserved.
29224 * Licensed under the Apache License, Version 2.0 (the "License");
29225 * you may not use this file except in compliance with the License.
29226 * You may obtain a copy of the License at
29227 *
29228 * http://www.apache.org/licenses/LICENSE-2.0
29229 *
29230 * Unless required by applicable law or agreed to in writing, software
29231 * distributed under the License is distributed on an "AS IS" BASIS,
29232 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29233 * See the License for the specific language governing permissions and
29234 * limitations under the License.
29235 * =============================================================================
29236 */
29237 const addGradConfig = {
29238 kernelName: Add$1,
29239 inputsToSave: ['a', 'b'],
29240 gradFunc: (dy, saved) => {
29241 const [a, b] = saved;
29242 const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
29243 const derA = () => {
29244 let res = dy;
29245 const reduceAxes = getReductionAxes(a.shape, outShape);
29246 if (reduceAxes.length > 0) {
29247 res = sum$3(res, reduceAxes);
29248 }
29249 return reshape$3(res, a.shape);
29250 };
29251 const derB = () => {
29252 let res = dy;
29253 const reduceAxes = getReductionAxes(b.shape, outShape);
29254 if (reduceAxes.length > 0) {
29255 res = sum$3(res, reduceAxes);
29256 }
29257 return reshape$3(res, b.shape);
29258 };
29259 return { a: derA, b: derB };
29260 }
29261 };
29262
29263 /**
29264 * @license
29265 * Copyright 2020 Google LLC. All Rights Reserved.
29266 * Licensed under the Apache License, Version 2.0 (the "License");
29267 * you may not use this file except in compliance with the License.
29268 * You may obtain a copy of the License at
29269 *
29270 * http://www.apache.org/licenses/LICENSE-2.0
29271 *
29272 * Unless required by applicable law or agreed to in writing, software
29273 * distributed under the License is distributed on an "AS IS" BASIS,
29274 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29275 * See the License for the specific language governing permissions and
29276 * limitations under the License.
29277 * =============================================================================
29278 */
29279 const addNGradConfig = {
29280 kernelName: AddN,
29281 saveAllInputs: true,
29282 gradFunc: (dy, saved) => {
29283 const ders = {};
29284 saved.forEach((_, i) => {
29285 ders[i] = () => dy.clone();
29286 });
29287 return ders;
29288 }
29289 };
29290
29291 /**
29292 * @license
29293 * Copyright 2020 Google Inc. All Rights Reserved.
29294 * Licensed under the Apache License, Version 2.0 (the "License");
29295 * you may not use this file except in compliance with the License.
29296 * You may obtain a copy of the License at
29297 *
29298 * http://www.apache.org/licenses/LICENSE-2.0
29299 *
29300 * Unless required by applicable law or agreed to in writing, software
29301 * distributed under the License is distributed on an "AS IS" BASIS,
29302 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29303 * See the License for the specific language governing permissions and
29304 * limitations under the License.
29305 * =============================================================================
29306 */
29307 const argMaxGradConfig = {
29308 kernelName: ArgMax,
29309 inputsToSave: ['x'],
29310 gradFunc: (dy, saved) => {
29311 const [x] = saved;
29312 return { x: () => zerosLike$3(x) };
29313 }
29314 };
29315
29316 /**
29317 * @license
29318 * Copyright 2020 Google Inc. All Rights Reserved.
29319 * Licensed under the Apache License, Version 2.0 (the "License");
29320 * you may not use this file except in compliance with the License.
29321 * You may obtain a copy of the License at
29322 *
29323 * http://www.apache.org/licenses/LICENSE-2.0
29324 *
29325 * Unless required by applicable law or agreed to in writing, software
29326 * distributed under the License is distributed on an "AS IS" BASIS,
29327 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29328 * See the License for the specific language governing permissions and
29329 * limitations under the License.
29330 * =============================================================================
29331 */
29332 const argMinGradConfig = {
29333 kernelName: ArgMin,
29334 inputsToSave: ['x'],
29335 gradFunc: (dy, saved) => {
29336 const [x] = saved;
29337 return { x: () => zerosLike$3(x) };
29338 }
29339 };
29340
29341 /**
29342 * @license
29343 * Copyright 2020 Google LLC. All Rights Reserved.
29344 * Licensed under the Apache License, Version 2.0 (the "License");
29345 * you may not use this file except in compliance with the License.
29346 * You may obtain a copy of the License at
29347 *
29348 * http://www.apache.org/licenses/LICENSE-2.0
29349 *
29350 * Unless required by applicable law or agreed to in writing, software
29351 * distributed under the License is distributed on an "AS IS" BASIS,
29352 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29353 * See the License for the specific language governing permissions and
29354 * limitations under the License.
29355 * =============================================================================
29356 */
29357 const asinGradConfig = {
29358 kernelName: Asin,
29359 inputsToSave: ['x'],
29360 gradFunc: (dy, saved) => {
29361 const [x] = saved;
29362 return { x: () => div$1(dy, sqrt$2(sub$2(scalar(1), square$2(cast$3(x, 'float32'))))) };
29363 }
29364 };
29365
29366 /**
29367 * @license
29368 * Copyright 2020 Google LLC. All Rights Reserved.
29369 * Licensed under the Apache License, Version 2.0 (the "License");
29370 * you may not use this file except in compliance with the License.
29371 * You may obtain a copy of the License at
29372 *
29373 * http://www.apache.org/licenses/LICENSE-2.0
29374 *
29375 * Unless required by applicable law or agreed to in writing, software
29376 * distributed under the License is distributed on an "AS IS" BASIS,
29377 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29378 * See the License for the specific language governing permissions and
29379 * limitations under the License.
29380 * =============================================================================
29381 */
29382 const asinhGradConfig = {
29383 kernelName: Asinh,
29384 inputsToSave: ['x'],
29385 gradFunc: (dy, saved) => {
29386 const [x] = saved;
29387 return {
29388 x: () => {
29389 const a = sqrt$2(add$3(scalar(1), square$2(cast$3(x, 'float32'))));
29390 return div$1(dy, a);
29391 }
29392 };
29393 }
29394 };
29395
29396 /**
29397 * @license
29398 * Copyright 2020 Google LLC. All Rights Reserved.
29399 * Licensed under the Apache License, Version 2.0 (the "License");
29400 * you may not use this file except in compliance with the License.
29401 * You may obtain a copy of the License at
29402 *
29403 * http://www.apache.org/licenses/LICENSE-2.0
29404 *
29405 * Unless required by applicable law or agreed to in writing, software
29406 * distributed under the License is distributed on an "AS IS" BASIS,
29407 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29408 * See the License for the specific language governing permissions and
29409 * limitations under the License.
29410 * =============================================================================
29411 */
29412 const atan2GradConfig = {
29413 kernelName: Atan2,
29414 inputsToSave: ['a', 'b'],
29415 gradFunc: (dy, saved) => {
29416 const [a, b] = saved;
29417 const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
29418 const derA = () => {
29419 const d = add$3(square$2(a), square$2(b));
29420 let res = mul(dy, div$1(b, d));
29421 const reduceAxes = getReductionAxes(a.shape, outShape);
29422 if (reduceAxes.length > 0) {
29423 res = sum$3(res, reduceAxes);
29424 }
29425 return reshape$3(res, a.shape);
29426 };
29427 const derB = () => {
29428 const d = add$3(square$2(a), square$2(b));
29429 let res = neg$2(mul(dy, div$1(a, d)));
29430 const reduceAxes = getReductionAxes(b.shape, outShape);
29431 if (reduceAxes.length > 0) {
29432 res = sum$3(res, reduceAxes);
29433 }
29434 return reshape$3(res, b.shape);
29435 };
29436 return { a: derA, b: derB };
29437 }
29438 };
29439
29440 /**
29441 * @license
29442 * Copyright 2020 Google LLC. All Rights Reserved.
29443 * Licensed under the Apache License, Version 2.0 (the "License");
29444 * you may not use this file except in compliance with the License.
29445 * You may obtain a copy of the License at
29446 *
29447 * http://www.apache.org/licenses/LICENSE-2.0
29448 *
29449 * Unless required by applicable law or agreed to in writing, software
29450 * distributed under the License is distributed on an "AS IS" BASIS,
29451 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29452 * See the License for the specific language governing permissions and
29453 * limitations under the License.
29454 * =============================================================================
29455 */
29456 const atanGradConfig = {
29457 kernelName: Atan,
29458 inputsToSave: ['x'],
29459 gradFunc: (dy, saved) => {
29460 const [x] = saved;
29461 return { x: () => div$1(dy, add$3(square$2(cast$3(x, 'float32')), 1)) };
29462 }
29463 };
29464
29465 /**
29466 * @license
29467 * Copyright 2020 Google LLC. All Rights Reserved.
29468 * Licensed under the Apache License, Version 2.0 (the "License");
29469 * you may not use this file except in compliance with the License.
29470 * You may obtain a copy of the License at
29471 *
29472 * http://www.apache.org/licenses/LICENSE-2.0
29473 *
29474 * Unless required by applicable law or agreed to in writing, software
29475 * distributed under the License is distributed on an "AS IS" BASIS,
29476 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29477 * See the License for the specific language governing permissions and
29478 * limitations under the License.
29479 * =============================================================================
29480 */
29481 const atanhGradConfig = {
29482 kernelName: Atanh,
29483 inputsToSave: ['x'],
29484 gradFunc: (dy, saved) => {
29485 const [x] = saved;
29486 return { x: () => div$1(dy, sub$2(scalar(1), square$2(cast$3(x, 'float32')))) };
29487 }
29488 };
29489
29490 /**
29491 * @license
29492 * Copyright 2020 Google LLC. All Rights Reserved.
29493 * Licensed under the Apache License, Version 2.0 (the "License");
29494 * you may not use this file except in compliance with the License.
29495 * You may obtain a copy of the License at
29496 *
29497 * http://www.apache.org/licenses/LICENSE-2.0
29498 *
29499 * Unless required by applicable law or agreed to in writing, software
29500 * distributed under the License is distributed on an "AS IS" BASIS,
29501 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29502 * See the License for the specific language governing permissions and
29503 * limitations under the License.
29504 * =============================================================================
29505 */
29506 /**
29507 * Computes the backprop of a 3d avg pool.
29508 *
29509 * @param dy The dy error, of rank 5 of shape
29510 * [batchSize, depth, height, width, channels].
29511 * assumed.
29512 * @param input The original input image, of rank 5 or rank4 of shape
29513 * [batchSize, depth, height, width, channels].
29514 * @param filterSize The filter size:
29515 * `[filterDepth, filterHeight, filterWidth]`.
29516 * `filterSize` is a single number,
29517 * then `filterDepth == filterHeight == filterWidth`.
29518 * @param strides The strides of the pooling:
29519 * `[strideDepth, strideHeight, strideWidth]`. If
29520 * `strides` is a single number, then `strideHeight == strideWidth`.
29521 * @param pad A string from: 'same', 'valid'. The type of padding algorithm
29522 * used in the forward prop of the op.
29523 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
29524 * provided, it will default to truncate.
29525 */
29526 function avgPool3dGrad_(dy, input, filterSize, strides, pad, dimRoundingMode) {
29527 const $dy = convertToTensor(dy, 'dy', 'avgPool3dGrad');
29528 const $input = convertToTensor(input, 'input', 'avgPool3dGrad');
29529 let dy5D = $dy;
29530 let input5D = $input;
29531 let reshapedTo5D = false;
29532 if ($input.rank === 4) {
29533 reshapedTo5D = true;
29534 dy5D = reshape$3($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]]);
29535 input5D = reshape$3($input, [
29536 1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3]
29537 ]);
29538 }
29539 assert$1(dy5D.rank === 5, () => `Error in avgPool3dGrad: dy must be rank 5 but got rank ` +
29540 `${dy5D.rank}.`);
29541 assert$1(input5D.rank === 5, () => `Error in avgPool3dGrad: input must be rank 5 but got rank ` +
29542 `${input5D.rank}.`);
29543 checkPadOnDimRoundingMode('avgPool3dGrad', pad, dimRoundingMode);
29544 const inputs = { dy: dy5D, input: input5D };
29545 const attrs = { filterSize, strides, pad, dimRoundingMode };
29546 // tslint:disable-next-line: no-unnecessary-type-assertion
29547 const res = ENGINE.runKernel(AvgPool3DGrad, inputs, attrs);
29548 if (reshapedTo5D) {
29549 return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
29550 }
29551 return res;
29552 }
29553 const avgPool3dGrad = /* @__PURE__ */ op({ avgPool3dGrad_ });
29554
29555 /**
29556 * @license
29557 * Copyright 2020 Google LLC. All Rights Reserved.
29558 * Licensed under the Apache License, Version 2.0 (the "License");
29559 * you may not use this file except in compliance with the License.
29560 * You may obtain a copy of the License at
29561 *
29562 * http://www.apache.org/licenses/LICENSE-2.0
29563 *
29564 * Unless required by applicable law or agreed to in writing, software
29565 * distributed under the License is distributed on an "AS IS" BASIS,
29566 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29567 * See the License for the specific language governing permissions and
29568 * limitations under the License.
29569 * =============================================================================
29570 */
29571 const avgPool3DGradConfig$2 = {
29572 kernelName: AvgPool3D,
29573 inputsToSave: ['x'],
29574 gradFunc: (dy, saved, attrs) => {
29575 const [x] = saved;
29576 const { filterSize, strides, pad, dimRoundingMode } = attrs;
29577 return {
29578 x: () => avgPool3dGrad(dy, x, filterSize, strides, pad, dimRoundingMode)
29579 };
29580 }
29581 };
29582
29583 /**
29584 * @license
29585 * Copyright 2020 Google LLC. All Rights Reserved.
29586 * Licensed under the Apache License, Version 2.0 (the "License");
29587 * you may not use this file except in compliance with the License.
29588 * You may obtain a copy of the License at
29589 *
29590 * http://www.apache.org/licenses/LICENSE-2.0
29591 *
29592 * Unless required by applicable law or agreed to in writing, software
29593 * distributed under the License is distributed on an "AS IS" BASIS,
29594 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29595 * See the License for the specific language governing permissions and
29596 * limitations under the License.
29597 * =============================================================================
29598 */
29599 /**
29600 * Computes the backprop of an 2D avg pool.
29601 *
29602 * @param dy The dy error, of rank 4 or rank 3 of shape
29603 * [batchSize, height, width, channels]. If rank 3, batch of 1 is
29604 * assumed.
29605 * @param input The input image, of rank 4 or rank 3 of shape
29606 * [batchSize, height, width, channels]. If rank 3, batch of 1 is
29607 * assumed.
29608 * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
29609 * `filterSize` is a single number, then `filterHeight == filterWidth`.
29610 * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
29611 * `strides` is a single number, then `strideHeight == strideWidth`.
29612 * @param pad The type of padding algorithm used in the forward prop of the op.
29613 * 'same', 'valid', for more info, see this guide:
29614 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
29615 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
29616 */
29617 function avgPoolGrad_(dy, input, filterSize, strides, pad) {
29618 const $dy = convertToTensor(dy, 'dy', 'avgPoolGrad');
29619 const $input = convertToTensor(input, 'input', 'avgPoolGrad');
29620 assert$1($input.rank === $dy.rank, () => `Rank of input (${$input.rank}) does not match rank of dy (${$dy.rank})`);
29621 let input4D = $input;
29622 let dy4D = $dy;
29623 let reshapedTo4D = false;
29624 if ($input.rank === 3) {
29625 reshapedTo4D = true;
29626 input4D =
29627 reshape$3($input, [1, $input.shape[0], $input.shape[1], $input.shape[2]]);
29628 dy4D = reshape$3($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2]]);
29629 }
29630 assert$1(dy4D.rank === 4, () => `Error in avgPoolGrad: dy must be rank 4 but got rank ` +
29631 `${dy4D.rank}.`);
29632 assert$1(input4D.rank === 4, () => `Error in avgPoolGrad: input must be rank 4 but got rank ` +
29633 `${input4D.rank}.`);
29634 const inputs = { dy: dy4D, input: input4D };
29635 const attrs = { filterSize, strides, pad };
29636 // tslint:disable-next-line: no-unnecessary-type-assertion
29637 const res = ENGINE.runKernel(AvgPoolGrad, inputs, attrs);
29638 if (reshapedTo4D) {
29639 return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3]]);
29640 }
29641 return res;
29642 }
29643 const avgPoolGrad$2 = /* @__PURE__ */ op({ avgPoolGrad_ });
29644
29645 /**
29646 * @license
29647 * Copyright 2020 Google LLC. All Rights Reserved.
29648 * Licensed under the Apache License, Version 2.0 (the "License");
29649 * you may not use this file except in compliance with the License.
29650 * You may obtain a copy of the License at
29651 *
29652 * http://www.apache.org/licenses/LICENSE-2.0
29653 *
29654 * Unless required by applicable law or agreed to in writing, software
29655 * distributed under the License is distributed on an "AS IS" BASIS,
29656 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29657 * See the License for the specific language governing permissions and
29658 * limitations under the License.
29659 * =============================================================================
29660 */
29661 const avgPoolGradConfig$2 = {
29662 kernelName: AvgPool,
29663 inputsToSave: ['x'],
29664 gradFunc: (dy, saved, attrs) => {
29665 const [x] = saved;
29666 const { filterSize, strides, pad } = attrs;
29667 return { x: () => avgPoolGrad$2(dy, x, filterSize, strides, pad) };
29668 }
29669 };
29670
29671 /**
29672 * @license
29673 * Copyright 2020 Google LLC. All Rights Reserved.
29674 * Licensed under the Apache License, Version 2.0 (the "License");
29675 * you may not use this file except in compliance with the License.
29676 * You may obtain a copy of the License at
29677 *
29678 * http://www.apache.org/licenses/LICENSE-2.0
29679 *
29680 * Unless required by applicable law or agreed to in writing, software
29681 * distributed under the License is distributed on an "AS IS" BASIS,
29682 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29683 * See the License for the specific language governing permissions and
29684 * limitations under the License.
29685 * =============================================================================
29686 */
29687 const batchMatMulGradConfig = {
29688 kernelName: BatchMatMul,
29689 inputsToSave: ['a', 'b'],
29690 gradFunc: (dy, saved, attrs) => {
29691 const [a, b] = saved;
29692 const { transposeA, transposeB } = attrs;
29693 if (!transposeA && !transposeB) {
29694 return {
29695 a: () => matMul$1(dy, b, false, true),
29696 b: () => matMul$1(a, dy, true, false)
29697 };
29698 }
29699 else if (!transposeA && transposeB) {
29700 return {
29701 a: () => matMul$1(dy, b, false, false),
29702 b: () => matMul$1(dy, a, true, false)
29703 };
29704 }
29705 else if (transposeA && !transposeB) {
29706 return {
29707 a: () => matMul$1(b, dy, false, true),
29708 b: () => matMul$1(a, dy, false, false)
29709 };
29710 }
29711 else {
29712 return {
29713 a: () => matMul$1(b, dy, true, true),
29714 b: () => matMul$1(dy, a, true, true)
29715 };
29716 }
29717 }
29718 };
29719
29720 /**
29721 * @license
29722 * Copyright 2020 Google LLC. All Rights Reserved.
29723 * Licensed under the Apache License, Version 2.0 (the "License");
29724 * you may not use this file except in compliance with the License.
29725 * You may obtain a copy of the License at
29726 *
29727 * http://www.apache.org/licenses/LICENSE-2.0
29728 *
29729 * Unless required by applicable law or agreed to in writing, software
29730 * distributed under the License is distributed on an "AS IS" BASIS,
29731 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29732 * See the License for the specific language governing permissions and
29733 * limitations under the License.
29734 * =============================================================================
29735 */
29736 const batchToSpaceNDGradConfig = {
29737 kernelName: BatchToSpaceND,
29738 gradFunc: (dy, saved, attrs) => {
29739 const { blockShape, crops } = attrs;
29740 return { x: () => spaceToBatchND$2(dy, blockShape, crops) };
29741 }
29742 };
29743
29744 /**
29745 * @license
29746 * Copyright 2020 Google LLC. All Rights Reserved.
29747 * Licensed under the Apache License, Version 2.0 (the "License");
29748 * you may not use this file except in compliance with the License.
29749 * You may obtain a copy of the License at
29750 *
29751 * http://www.apache.org/licenses/LICENSE-2.0
29752 *
29753 * Unless required by applicable law or agreed to in writing, software
29754 * distributed under the License is distributed on an "AS IS" BASIS,
29755 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29756 * See the License for the specific language governing permissions and
29757 * limitations under the License.
29758 * =============================================================================
29759 */
29760 const broadcastToGradConfig = {
29761 kernelName: BroadcastTo,
29762 gradFunc: (dy, saved, attrs) => {
29763 const broadCastToAttrs = attrs;
29764 const inputShape = broadCastToAttrs.inputShape;
29765 const outputShape = broadCastToAttrs.shape;
29766 const reps = Array.from(outputShape);
29767 for (let i = inputShape.length - 1; i >= 0; i--) {
29768 if (inputShape[i] === outputShape[i]) {
29769 reps[i] = 1;
29770 }
29771 else if (inputShape[i] !== 1) {
29772 throw new Error(`broadcastTo(): [${inputShape}] cannot be broadcast to [${outputShape}].`);
29773 }
29774 }
29775 const axes = [];
29776 for (let i = 0; i < reps.length; i++) {
29777 if (reps[i] > 1) {
29778 axes.push(i);
29779 }
29780 }
29781 return { x: () => sum$3(dy, axes, true /* keepDims */) };
29782 }
29783 };
29784
29785 /**
29786 * @license
29787 * Copyright 2020 Google LLC. All Rights Reserved.
29788 * Licensed under the Apache License, Version 2.0 (the "License");
29789 * you may not use this file except in compliance with the License.
29790 * You may obtain a copy of the License at
29791 *
29792 * http://www.apache.org/licenses/LICENSE-2.0
29793 *
29794 * Unless required by applicable law or agreed to in writing, software
29795 * distributed under the License is distributed on an "AS IS" BASIS,
29796 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29797 * See the License for the specific language governing permissions and
29798 * limitations under the License.
29799 * =============================================================================
29800 */
29801 const castGradConfig = {
29802 kernelName: Cast,
29803 gradFunc: (dy) => {
29804 return { x: () => dy.clone() };
29805 }
29806 };
29807
29808 /**
29809 * @license
29810 * Copyright 2020 Google LLC. All Rights Reserved.
29811 * Licensed under the Apache License, Version 2.0 (the "License");
29812 * you may not use this file except in compliance with the License.
29813 * You may obtain a copy of the License at
29814 *
29815 * http://www.apache.org/licenses/LICENSE-2.0
29816 *
29817 * Unless required by applicable law or agreed to in writing, software
29818 * distributed under the License is distributed on an "AS IS" BASIS,
29819 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29820 * See the License for the specific language governing permissions and
29821 * limitations under the License.
29822 * =============================================================================
29823 */
29824 const ceilGradConfig = {
29825 kernelName: Ceil,
29826 gradFunc: (dy) => {
29827 // TODO(manrajgrover): Return null for gradients when backprop supports it.
29828 return { x: () => zerosLike$3(dy) };
29829 }
29830 };
29831
29832 /**
29833 * @license
29834 * Copyright 2020 Google LLC. All Rights Reserved.
29835 * Licensed under the Apache License, Version 2.0 (the "License");
29836 * you may not use this file except in compliance with the License.
29837 * You may obtain a copy of the License at
29838 *
29839 * http://www.apache.org/licenses/LICENSE-2.0
29840 *
29841 * Unless required by applicable law or agreed to in writing, software
29842 * distributed under the License is distributed on an "AS IS" BASIS,
29843 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29844 * See the License for the specific language governing permissions and
29845 * limitations under the License.
29846 * =============================================================================
29847 */
29848 const clipByValueGradConfig = {
29849 kernelName: ClipByValue,
29850 inputsToSave: ['x'],
29851 gradFunc: (dy, saved, attrs) => {
29852 const [x] = saved;
29853 const { clipValueMin, clipValueMax } = attrs;
29854 return {
29855 x: () => where(logicalAnd$2(greaterEqual$2(x, clipValueMin), lessEqual$2(x, clipValueMax)), dy, zerosLike$3(dy)),
29856 };
29857 }
29858 };
29859
29860 /**
29861 * @license
29862 * Copyright 2020 Google LLC. All Rights Reserved.
29863 * Licensed under the Apache License, Version 2.0 (the "License");
29864 * you may not use this file except in compliance with the License.
29865 * You may obtain a copy of the License at
29866 *
29867 * http://www.apache.org/licenses/LICENSE-2.0
29868 *
29869 * Unless required by applicable law or agreed to in writing, software
29870 * distributed under the License is distributed on an "AS IS" BASIS,
29871 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29872 * See the License for the specific language governing permissions and
29873 * limitations under the License.
29874 * =============================================================================
29875 */
    // Gradient config for ComplexAbs. Reuses absGradConfig.gradFunc directly.
    // NOTE(review): that function computes step(cast(x,'float32'), -1) * dy,
    // i.e. the real-valued abs gradient — confirm this is intentional for
    // complex inputs.
    const complexAbsGradConfig = {
        kernelName: ComplexAbs,
        inputsToSave: ['x'],
        gradFunc: absGradConfig.gradFunc,
    };
29881
29882 /**
29883 * @license
29884 * Copyright 2020 Google LLC. All Rights Reserved.
29885 * Licensed under the Apache License, Version 2.0 (the "License");
29886 * you may not use this file except in compliance with the License.
29887 * You may obtain a copy of the License at
29888 *
29889 * http://www.apache.org/licenses/LICENSE-2.0
29890 *
29891 * Unless required by applicable law or agreed to in writing, software
29892 * distributed under the License is distributed on an "AS IS" BASIS,
29893 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29894 * See the License for the specific language governing permissions and
29895 * limitations under the License.
29896 * =============================================================================
29897 */
29898 const concatGradConfig = {
29899 kernelName: Concat,
29900 saveAllInputs: true,
29901 gradFunc: (dy, saved, attrs) => {
29902 const shapes = saved.map(t => t.shape);
29903 const { axis } = attrs;
29904 const $axis = parseAxisParam(axis, saved[0].shape)[0];
29905 const sizeSplits = shapes.map(s => s[$axis]);
29906 const derTensors = split$3(dy, sizeSplits, $axis);
29907 return derTensors.map(t => () => t);
29908 }
29909 };
29910
29911 /**
29912 * @license
29913 * Copyright 2020 Google LLC. All Rights Reserved.
29914 * Licensed under the Apache License, Version 2.0 (the "License");
29915 * you may not use this file except in compliance with the License.
29916 * You may obtain a copy of the License at
29917 *
29918 * http://www.apache.org/licenses/LICENSE-2.0
29919 *
29920 * Unless required by applicable law or agreed to in writing, software
29921 * distributed under the License is distributed on an "AS IS" BASIS,
29922 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29923 * See the License for the specific language governing permissions and
29924 * limitations under the License.
29925 * =============================================================================
29926 */
29927 const conv2DGradConfig = {
29928 kernelName: Conv2D$1,
29929 inputsToSave: ['x', 'filter'],
29930 gradFunc: (dy, saved, attrs) => {
29931 const [x4D, $filter] = saved;
29932 const { dilations, strides, pad, dataFormat } = attrs;
29933 assert$1(tupleValuesAreOne(dilations), () => 'Error in gradient of conv2D: dilation rates greater than 1 ' +
29934 `are not yet supported in gradients. Got dilations '${dilations}'`);
29935 return {
29936 x: () => conv2DBackpropInput$2(x4D.shape, dy, $filter, strides, pad, dataFormat),
29937 filter: () => conv2DBackpropFilter$2(x4D, dy, $filter.shape, strides, pad, dataFormat)
29938 };
29939 }
29940 };
29941
29942 /**
29943 * @license
29944 * Copyright 2020 Google LLC. All Rights Reserved.
29945 * Licensed under the Apache License, Version 2.0 (the "License");
29946 * you may not use this file except in compliance with the License.
29947 * You may obtain a copy of the License at
29948 *
29949 * http://www.apache.org/licenses/LICENSE-2.0
29950 *
29951 * Unless required by applicable law or agreed to in writing, software
29952 * distributed under the License is distributed on an "AS IS" BASIS,
29953 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29954 * See the License for the specific language governing permissions and
29955 * limitations under the License.
29956 * =============================================================================
29957 */
29958 const conv2DBackpropInputGradConfig = {
29959 kernelName: Conv2DBackpropInput,
29960 inputsToSave: ['dy', 'filter'],
29961 gradFunc: (ddx, saved, attrs) => {
29962 const [dy, filter] = saved;
29963 const { strides, pad, dataFormat, dimRoundingMode } = attrs;
29964 return {
29965 dy: () => conv2d$4(ddx, filter, strides, pad, dataFormat, 1 /* dilations */, dimRoundingMode),
29966 filter: () => conv2DBackpropFilter$2(ddx, dy, filter.shape, strides, pad, dataFormat, dimRoundingMode)
29967 };
29968 }
29969 };
29970
29971 /**
29972 * @license
29973 * Copyright 2020 Google LLC. All Rights Reserved.
29974 * Licensed under the Apache License, Version 2.0 (the "License");
29975 * you may not use this file except in compliance with the License.
29976 * You may obtain a copy of the License at
29977 *
29978 * http://www.apache.org/licenses/LICENSE-2.0
29979 *
29980 * Unless required by applicable law or agreed to in writing, software
29981 * distributed under the License is distributed on an "AS IS" BASIS,
29982 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29983 * See the License for the specific language governing permissions and
29984 * limitations under the License.
29985 * =============================================================================
29986 */
29987 /**
29988 * Computes the derivative of the filter of a 3D convolution.
29989 *
29990 * @param x The input tensor, of rank 5 or rank 4 of shape
29991 * [batch, depth, height, width, inChannels]. If rank 4, batch of 1 is
29992 * assumed.
29993 * @param dy The dy image, of rank 5 or rank 4, of shape
29994 * [batch, depth, height, width, outDepth]. If rank 4, batch of 1 is
29995 * assumed.
29996 * @param filterShape The shape of the filter, length 5,
29997 * [filterDepth, filterHeight, filterWidth, inDepth, outDepth].
29998 * @param strides The strides of the convolution: [strideDepth, strideHeight,
29999 * strideWidth].
30000 * @param pad A string from: 'same', 'valid'. The type of padding algorithm
30001 * used in the forward prop of the op.
30002 */
30003 function conv3DBackpropFilter_(x, dy, filterShape, strides, pad) {
30004 let x5D = x;
30005 if (x.rank === 4) {
30006 x5D = reshape$3(x, [1, x.shape[0], x.shape[1], x.shape[2], x.shape[3]]);
30007 }
30008 let dy5D = dy;
30009 if (dy5D.rank === 4) {
30010 dy5D = reshape$3(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]);
30011 }
30012 assert$1(x5D.rank === 5, () => `Error in conv3dDerFilter: input must be rank 5, but got shape ` +
30013 `${x5D.shape}.`);
30014 assert$1(dy5D.rank === 5, () => `Error in conv3dDerFilter: dy must be rank 5, but got shape ` +
30015 `${dy5D.shape}.`);
30016 assert$1(filterShape.length === 5, () => `Error in conv3dDerFilter: filterShape must be length 5, but got ` +
30017 `${filterShape}.`);
30018 assert$1(x5D.shape[4] === filterShape[3], () => `Error in conv3dDerFilter: depth of input ${x5D.shape[4]}) must ` +
30019 `match input depth in filter (${filterShape[3]}.`);
30020 assert$1(dy5D.shape[4] === filterShape[4], () => `Error in conv3dDerFilter: depth of dy (${dy5D.shape[4]}) must ` +
30021 `match output depth for filter (${filterShape[4]}).`);
30022 const inputs = { x: x5D, dy: dy5D };
30023 const attrs = { strides, pad, filterShape };
30024 // tslint:disable-next-line: no-unnecessary-type-assertion
30025 return ENGINE.runKernel(Conv3DBackpropFilterV2, inputs, attrs);
30026 }
30027 const conv3DBackpropFilter = /* @__PURE__ */ op({ conv3DBackpropFilter_ });
30028
30029 /**
30030 * @license
30031 * Copyright 2020 Google LLC. All Rights Reserved.
30032 * Licensed under the Apache License, Version 2.0 (the "License");
30033 * you may not use this file except in compliance with the License.
30034 * You may obtain a copy of the License at
30035 *
30036 * http://www.apache.org/licenses/LICENSE-2.0
30037 *
30038 * Unless required by applicable law or agreed to in writing, software
30039 * distributed under the License is distributed on an "AS IS" BASIS,
30040 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30041 * See the License for the specific language governing permissions and
30042 * limitations under the License.
30043 * =============================================================================
30044 */
30045 const conv3DGradConfig = {
30046 kernelName: Conv3D$1,
30047 inputsToSave: ['x', 'filter'],
30048 gradFunc: (dy, saved, attrs) => {
30049 const { dilations, strides, pad } = attrs;
30050 assert$1(tupleValuesAreOne(dilations), () => 'Error in gradient of conv3D: dilation rates greater than 1 are ' +
30051 `not yet supported in gradients. Got dilations '${dilations}'`);
30052 const [x5D, $filter] = saved;
30053 return {
30054 x: () => conv3DBackpropInput$1(x5D.shape, dy, $filter, strides, pad),
30055 filter: () => conv3DBackpropFilter(x5D, dy, $filter.shape, strides, pad)
30056 };
30057 }
30058 };
30059
30060 /**
30061 * @license
30062 * Copyright 2020 Google LLC. All Rights Reserved.
30063 * Licensed under the Apache License, Version 2.0 (the "License");
30064 * you may not use this file except in compliance with the License.
30065 * You may obtain a copy of the License at
30066 *
30067 * http://www.apache.org/licenses/LICENSE-2.0
30068 *
30069 * Unless required by applicable law or agreed to in writing, software
30070 * distributed under the License is distributed on an "AS IS" BASIS,
30071 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30072 * See the License for the specific language governing permissions and
30073 * limitations under the License.
30074 * =============================================================================
30075 */
30076 const cosGradConfig = {
30077 kernelName: Cos,
30078 inputsToSave: ['x'],
30079 gradFunc: (dy, saved) => {
30080 const [x] = saved;
30081 return { x: () => mul(neg$2(sin$2(cast$3(x, 'float32'))), dy) };
30082 }
30083 };
30084
30085 /**
30086 * @license
30087 * Copyright 2020 Google LLC. All Rights Reserved.
30088 * Licensed under the Apache License, Version 2.0 (the "License");
30089 * you may not use this file except in compliance with the License.
30090 * You may obtain a copy of the License at
30091 *
30092 * http://www.apache.org/licenses/LICENSE-2.0
30093 *
30094 * Unless required by applicable law or agreed to in writing, software
30095 * distributed under the License is distributed on an "AS IS" BASIS,
30096 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30097 * See the License for the specific language governing permissions and
30098 * limitations under the License.
30099 * =============================================================================
30100 */
30101 const coshGradConfig = {
30102 kernelName: Cosh,
30103 inputsToSave: ['x'],
30104 gradFunc: (dy, saved) => {
30105 const [x] = saved;
30106 return { x: () => mul(sinh$2(cast$3(x, 'float32')), dy) };
30107 }
30108 };
30109
30110 /**
30111 * @license
30112 * Copyright 2020 Google LLC. All Rights Reserved.
30113 * Licensed under the Apache License, Version 2.0 (the "License");
30114 * you may not use this file except in compliance with the License.
30115 * You may obtain a copy of the License at
30116 *
30117 * http://www.apache.org/licenses/LICENSE-2.0
30118 *
30119 * Unless required by applicable law or agreed to in writing, software
30120 * distributed under the License is distributed on an "AS IS" BASIS,
30121 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30122 * See the License for the specific language governing permissions and
30123 * limitations under the License.
30124 * =============================================================================
30125 */
30126 const cumsumGradConfig = {
30127 kernelName: Cumsum,
30128 inputsToSave: ['x'],
30129 gradFunc: (dy, saved, attrs) => {
30130 const [x] = saved;
30131 const { axis, exclusive, reverse } = attrs;
30132 return {
30133 x: () => {
30134 const permutation = getAxesPermutation([axis], x.rank);
30135 let out = cumsum$2(dy, axis, exclusive, !reverse);
30136 if (permutation != null) {
30137 out = transpose$2(out, permutation);
30138 }
30139 return out;
30140 }
30141 };
30142 }
30143 };
30144
30145 /**
30146 * @license
30147 * Copyright 2020 Google LLC. All Rights Reserved.
30148 * Licensed under the Apache License, Version 2.0 (the "License");
30149 * you may not use this file except in compliance with the License.
30150 * You may obtain a copy of the License at
30151 *
30152 * http://www.apache.org/licenses/LICENSE-2.0
30153 *
30154 * Unless required by applicable law or agreed to in writing, software
30155 * distributed under the License is distributed on an "AS IS" BASIS,
30156 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30157 * See the License for the specific language governing permissions and
30158 * limitations under the License.
30159 * =============================================================================
30160 */
30161 const depthwiseConv2dNativeGradConfig = {
30162 kernelName: DepthwiseConv2dNative,
30163 inputsToSave: ['x', 'filter'],
30164 gradFunc: (dy, saved, attrs) => {
30165 const { dilations, strides, pad, dimRoundingMode } = attrs;
30166 const $dilations = dilations == null ? [1, 1] : dilations;
30167 assert$1(tupleValuesAreOne($dilations), () => 'Error in gradient of depthwiseConv2dNative: dilation rates ' +
30168 `greater than 1 are not yet supported. Got dilations ` +
30169 `'${$dilations}'`);
30170 const [x, filter] = saved;
30171 assert$1(x.rank === 4, () => `Error in gradient of depthwiseConv2dNative: input must be ` +
30172 `rank 4, but got rank ${x.rank}.`);
30173 assert$1(filter.rank === 4, () => `Error in gradient of depthwiseConv2dNative: filter must be ` +
30174 `rank 4, but got rank ${filter.rank}.`);
30175 assert$1(x.shape[3] === filter.shape[2], () => `Error in gradient of depthwiseConv2d: number of input ` +
30176 `channels (${x.shape[3]}) must match the inChannels dimension ` +
30177 `in filter ${filter.shape[2]}.`);
30178 assert$1(eitherStridesOrDilationsAreOne(strides, $dilations), () => 'Error in gradient of depthwiseConv2d: Either strides or ' +
30179 `dilations must be 1. Got strides ${strides} and dilations ` +
30180 `'${$dilations}'.`);
30181 checkPadOnDimRoundingMode('depthwiseConv2d', pad, dimRoundingMode);
30182 return {
30183 x: () => depthwiseConv2dNativeBackpropInput$2(x.shape, dy, filter, strides, pad, $dilations, dimRoundingMode),
30184 filter: () => depthwiseConv2dNativeBackpropFilter$2(x, dy, filter.shape, strides, pad, $dilations, dimRoundingMode),
30185 };
30186 }
30187 };
30188
30189 /**
30190 * @license
30191 * Copyright 2020 Google LLC. All Rights Reserved.
30192 * Licensed under the Apache License, Version 2.0 (the "License");
30193 * you may not use this file except in compliance with the License.
30194 * You may obtain a copy of the License at
30195 *
30196 * http://www.apache.org/licenses/LICENSE-2.0
30197 *
30198 * Unless required by applicable law or agreed to in writing, software
30199 * distributed under the License is distributed on an "AS IS" BASIS,
30200 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30201 * See the License for the specific language governing permissions and
30202 * limitations under the License.
30203 * =============================================================================
30204 */
30205 const dilation2dGradConfig = {
30206 kernelName: Dilation2D,
30207 inputsToSave: ['x', 'filter'],
30208 gradFunc: (dy, saved, attrs) => {
30209 const [x, filter] = saved;
30210 const inputInputs = { x, filter, dy };
30211 const filterInputs = { x, filter, dy };
30212 return {
30213 x: () => ENGINE.runKernel(Dilation2DBackpropInput, inputInputs, attrs),
30214 filter: () => ENGINE.runKernel(Dilation2DBackpropFilter, filterInputs, attrs)
30215 };
30216 }
30217 };
30218
30219 /**
30220 * @license
30221 * Copyright 2020 Google LLC. All Rights Reserved.
30222 * Licensed under the Apache License, Version 2.0 (the "License");
30223 * you may not use this file except in compliance with the License.
30224 * You may obtain a copy of the License at
30225 *
30226 * http://www.apache.org/licenses/LICENSE-2.0
30227 *
30228 * Unless required by applicable law or agreed to in writing, software
30229 * distributed under the License is distributed on an "AS IS" BASIS,
30230 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30231 * See the License for the specific language governing permissions and
30232 * limitations under the License.
30233 * =============================================================================
30234 */
30235 const eluGradConfig$2 = {
30236 kernelName: Elu$1,
30237 outputsToSave: [true],
30238 gradFunc: (dy, saved) => {
30239 const [y] = saved;
30240 const inputs = { dy, y };
30241 return { x: () => ENGINE.runKernel(EluGrad, inputs) };
30242 }
30243 };
30244
30245 /**
30246 * @license
30247 * Copyright 2020 Google LLC. All Rights Reserved.
30248 * Licensed under the Apache License, Version 2.0 (the "License");
30249 * you may not use this file except in compliance with the License.
30250 * You may obtain a copy of the License at
30251 *
30252 * http://www.apache.org/licenses/LICENSE-2.0
30253 *
30254 * Unless required by applicable law or agreed to in writing, software
30255 * distributed under the License is distributed on an "AS IS" BASIS,
30256 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30257 * See the License for the specific language governing permissions and
30258 * limitations under the License.
30259 * =============================================================================
30260 */
30261 const erfGradConfig = {
30262 kernelName: Erf,
30263 inputsToSave: ['x'],
30264 gradFunc: (dy, saved) => {
30265 const [x] = saved;
30266 const a = mul(exp$2(neg$2(square$2(x))), 2 / Math.sqrt(Math.PI));
30267 return { x: () => mul(dy, a) };
30268 }
30269 };
30270
30271 /**
30272 * @license
30273 * Copyright 2020 Google LLC. All Rights Reserved.
30274 * Licensed under the Apache License, Version 2.0 (the "License");
30275 * you may not use this file except in compliance with the License.
30276 * You may obtain a copy of the License at
30277 *
30278 * http://www.apache.org/licenses/LICENSE-2.0
30279 *
30280 * Unless required by applicable law or agreed to in writing, software
30281 * distributed under the License is distributed on an "AS IS" BASIS,
30282 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30283 * See the License for the specific language governing permissions and
30284 * limitations under the License.
30285 * =============================================================================
30286 */
30287 const expGradConfig = {
30288 kernelName: Exp,
30289 outputsToSave: [true],
30290 gradFunc: (dy, saved) => {
30291 const [y] = saved;
30292 return { x: () => mul(dy, y) };
30293 }
30294 };
30295
30296 /**
30297 * @license
30298 * Copyright 2020 Google LLC. All Rights Reserved.
30299 * Licensed under the Apache License, Version 2.0 (the "License");
30300 * you may not use this file except in compliance with the License.
30301 * You may obtain a copy of the License at
30302 *
30303 * http://www.apache.org/licenses/LICENSE-2.0
30304 *
30305 * Unless required by applicable law or agreed to in writing, software
30306 * distributed under the License is distributed on an "AS IS" BASIS,
30307 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30308 * See the License for the specific language governing permissions and
30309 * limitations under the License.
30310 * =============================================================================
30311 */
30312 const expandDimsGradConfig = {
30313 kernelName: ExpandDims,
30314 inputsToSave: ['input'],
30315 gradFunc: (dy, saved) => {
30316 const [input] = saved;
30317 return { input: () => reshape$3(dy, input.shape) };
30318 }
30319 };
30320
30321 /**
30322 * @license
30323 * Copyright 2020 Google LLC. All Rights Reserved.
30324 * Licensed under the Apache License, Version 2.0 (the "License");
30325 * you may not use this file except in compliance with the License.
30326 * You may obtain a copy of the License at
30327 *
30328 * http://www.apache.org/licenses/LICENSE-2.0
30329 *
30330 * Unless required by applicable law or agreed to in writing, software
30331 * distributed under the License is distributed on an "AS IS" BASIS,
30332 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30333 * See the License for the specific language governing permissions and
30334 * limitations under the License.
30335 * =============================================================================
30336 */
30337 const expm1GradConfig = {
30338 kernelName: Expm1,
30339 inputsToSave: ['x'],
30340 gradFunc: (dy, saved) => {
30341 const [x] = saved;
30342 return { x: () => mul(dy, exp$2(x)) };
30343 }
30344 };
30345
30346 /**
30347 * @license
30348 * Copyright 2020 Google LLC. All Rights Reserved.
30349 * Licensed under the Apache License, Version 2.0 (the "License");
30350 * you may not use this file except in compliance with the License.
30351 * You may obtain a copy of the License at
30352 *
30353 * http://www.apache.org/licenses/LICENSE-2.0
30354 *
30355 * Unless required by applicable law or agreed to in writing, software
30356 * distributed under the License is distributed on an "AS IS" BASIS,
30357 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30358 * See the License for the specific language governing permissions and
30359 * limitations under the License.
30360 * =============================================================================
30361 */
30362 const floorGradConfig = {
30363 kernelName: Floor,
30364 gradFunc: (dy) => {
30365 return { x: () => zerosLike$3(dy) };
30366 }
30367 };
30368
30369 /**
30370 * @license
30371 * Copyright 2020 Google LLC. All Rights Reserved.
30372 * Licensed under the Apache License, Version 2.0 (the "License");
30373 * you may not use this file except in compliance with the License.
30374 * You may obtain a copy of the License at
30375 *
30376 * http://www.apache.org/licenses/LICENSE-2.0
30377 *
30378 * Unless required by applicable law or agreed to in writing, software
30379 * distributed under the License is distributed on an "AS IS" BASIS,
30380 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30381 * See the License for the specific language governing permissions and
30382 * limitations under the License.
30383 * =============================================================================
30384 */
30385 const floorDivGradConfig = {
30386 kernelName: FloorDiv,
30387 inputsToSave: ['a', 'b'],
30388 gradFunc: (dy, saved) => {
30389 const [a, b] = saved;
30390 const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
30391 const derA = () => {
30392 const res = div$1(dy, cast$3(b, 'float32'));
30393 const reduceAxes = getReductionAxes(a.shape, outShape);
30394 if (reduceAxes.length > 0) {
30395 return reshape$3(sum$3(res, reduceAxes), a.shape);
30396 }
30397 return res;
30398 };
30399 const derB = () => {
30400 let res = mul(dy, cast$3(a, 'float32'));
30401 const reduceAxes = getReductionAxes(b.shape, outShape);
30402 if (reduceAxes.length > 0) {
30403 res = reshape$3(sum$3(res, reduceAxes), b.shape);
30404 }
30405 const tmp = square$2(b);
30406 return neg$2(div$1(res, cast$3(tmp, 'float32')));
30407 };
30408 return { a: derA, b: derB };
30409 }
30410 };
30411
30412 /**
30413 * @license
30414 * Copyright 2020 Google LLC. All Rights Reserved.
30415 * Licensed under the Apache License, Version 2.0 (the "License");
30416 * you may not use this file except in compliance with the License.
30417 * You may obtain a copy of the License at
30418 *
30419 * http://www.apache.org/licenses/LICENSE-2.0
30420 *
30421 * Unless required by applicable law or agreed to in writing, software
30422 * distributed under the License is distributed on an "AS IS" BASIS,
30423 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30424 * See the License for the specific language governing permissions and
30425 * limitations under the License.
30426 * =============================================================================
30427 */
    /**
     * Gradient config for FusedBatchNorm.
     *
     * Recovers the gradients for x, mean, variance, scale and offset from dy
     * and the tensors saved in the forward pass. When `mean` is rank 1 the
     * per-channel factors are tiled across the leading dims (the rank-1 path
     * reshapes to [1, 1, 1, C], i.e. a 4D channel-last layout) and reduced
     * over `reductionAxes`; otherwise everything applies elementwise.
     */
    const fusedBatchNormGradConfig = {
        kernelName: FusedBatchNorm,
        inputsToSave: ['x', 'mean', 'variance', 'scale'],
        gradFunc: (dy, saved, attrs) => {
            const { varianceEpsilon } = attrs;
            const [x, mean, variance, scale] = saved;
            // When no scale tensor was saved, gamma is treated as 1.
            const scaleValue = scale == null ? scalar(1) : scale;
            const reductionAxes = getReductionAxes(mean.shape, x.shape);
            // Tile shape used to broadcast per-channel values: every x dim
            // except the last, then 1 for the channel dim.
            const tileShape = [];
            if (mean.rank === 1) {
                for (let i = 0; i < x.shape.length - 1; ++i) {
                    tileShape.push(x.shape[i]);
                }
                tileShape.push(1);
            }
            const xMinusMean = sub$2(x, mean);
            const dyTimesScaleValue = mul(dy, scaleValue);
            // r = 1 / sqrt(variance + epsilon).
            const oneOverSqrtVariance = rsqrt$2(add$3(variance, scalar(varianceEpsilon)));
            // -0.5 * r^3, the factor for the variance derivative.
            const minusHalfRCube = mul(mul(mul(oneOverSqrtVariance, oneOverSqrtVariance), oneOverSqrtVariance), scalar(-0.5));
            // dx = dy * scale * r, broadcast per channel in the rank-1 case.
            const derX = () => {
                if (mean.rank === 1) {
                    return reshape$3(mul(mul(dy, tile$3(reshape$3(oneOverSqrtVariance, [1, 1, 1, mean.shape[0]]), tileShape)), scaleValue), x.shape);
                }
                else {
                    return reshape$3(mul(mul(dy, oneOverSqrtVariance), scaleValue), x.shape);
                }
            };
            // dMean = -(dy * scale) * r, reduced over the broadcast axes.
            const derMean = () => {
                let meanDer = mul(mul(oneOverSqrtVariance, scalar(-1)), dyTimesScaleValue);
                if (mean.rank === 1) {
                    meanDer = sum$3(meanDer, reductionAxes);
                }
                return reshape$3(meanDer, mean.shape);
            };
            // dVariance = -0.5 * r^3 * (x - mean) * dy * scale, reduced.
            const derVariance = () => {
                let varianceDer = mul(mul(minusHalfRCube, xMinusMean), dyTimesScaleValue);
                if (mean.rank === 1) {
                    varianceDer = sum$3(varianceDer, reductionAxes);
                }
                return reshape$3(varianceDer, mean.shape);
            };
            // dScale = dy * (x - mean) * r, reduced.
            const derScale = () => {
                const xMinusMean2TimesRsqrt = mul(xMinusMean, oneOverSqrtVariance);
                let scaleDer = mul(dy, xMinusMean2TimesRsqrt);
                if (mean.rank === 1) {
                    scaleDer = sum$3(scaleDer, reductionAxes);
                }
                return reshape$3(scaleDer, mean.shape);
            };
            // dOffset = dy (offset is added directly), reduced.
            const derOffset = () => {
                let offsetDer = dy;
                if (mean.rank === 1) {
                    offsetDer = sum$3(offsetDer, reductionAxes);
                }
                return reshape$3(offsetDer, mean.shape);
            };
            return {
                x: derX,
                mean: derMean,
                variance: derVariance,
                scale: derScale,
                offset: derOffset
            };
        }
    };
30493
30494 /**
30495 * @license
30496 * Copyright 2020 Google LLC. All Rights Reserved.
30497 * Licensed under the Apache License, Version 2.0 (the "License");
30498 * you may not use this file except in compliance with the License.
30499 * You may obtain a copy of the License at
30500 *
30501 * http://www.apache.org/licenses/LICENSE-2.0
30502 *
30503 * Unless required by applicable law or agreed to in writing, software
30504 * distributed under the License is distributed on an "AS IS" BASIS,
30505 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30506 * See the License for the specific language governing permissions and
30507 * limitations under the License.
30508 * =============================================================================
30509 */
    /**
     * Gradient config for GatherV2.
     *
     * The gradient w.r.t. `x` scatters slices of dy back to the positions
     * they were gathered from, via unsortedSegmentSum along the gather axis.
     * The "gradient" w.r.t. `indices` just passes the indices through.
     */
    const gatherGradConfig = {
        kernelName: GatherV2,
        inputsToSave: ['x', 'indices'],
        gradFunc: (dy, saved, attrs) => {
            const [x, indices] = saved;
            const { axis, batchDims } = attrs;
            const parsedAxis = parseAxisParam(axis, x.shape)[0];
            // Builds the x-gradient closure for a single (un-batched) gather.
            const derXBatch = (x, indices, dy) => {
                return () => {
                    const paramsShape = x.shape;
                    const indicesSize = indices.size;
                    // Dims before the gather axis / after it (axis itself dropped).
                    const outerShape = paramsShape.slice(0, parsedAxis);
                    const outerDims = outerShape.length;
                    const innerShape = paramsShape.slice(axis, paramsShape.length).slice(1);
                    const innerDims = innerShape.length;
                    const outerAxesIndices = arrayRange(0, outerDims);
                    const innerAxesIndices = arrayRange(outerDims + 1, outerDims + 1 + innerDims);
                    // dy viewed as [outerShape..., indicesSize, innerShape...].
                    const valuesShape = arrayConcat([outerShape, [indicesSize],
                        innerShape]);
                    const values = reshape$3(dy, valuesShape);
                    const reshapedIndices = reshape$3(indices, [indicesSize]);
                    // Move the indices dim to the front so segment-sum can
                    // scatter along it.
                    const transposeDims = arrayConcat([[outerDims], outerAxesIndices, innerAxesIndices]);
                    const valuesTranspose = transpose$2(values, transposeDims);
                    let paramsGrad = unsortedSegmentSum$2(valuesTranspose, reshapedIndices, x.shape[parsedAxis]);
                    // Undo the transpose to restore x's axis order.
                    const invertTransposeDims = getUndoAxesPermutation(transposeDims);
                    paramsGrad = transpose$2(paramsGrad, invertTransposeDims);
                    return paramsGrad;
                };
            };
            if (batchDims === 1) {
                // Batched gather: split along the batch dim, compute each
                // per-example gradient, then stack back to x's shape.
                const batchSize = x.shape[0];
                const xBatch = x.split(batchSize, 0);
                const derXBatched = () => {
                    const stacked = stack(xBatch.map((x, i) => {
                        return derXBatch(x, indices.slice(i, 1), dy.slice(i, 1))();
                    }));
                    return stacked.reshape(x.shape);
                };
                return { x: derXBatched, indices: () => indices };
            }
            else {
                return { x: derXBatch(x, indices, dy), indices: () => indices };
            }
        }
    };
30555 function arrayRange(start, stop) {
30556 const result = [];
30557 for (let i = start; i < stop; ++i) {
30558 result.push(i);
30559 }
30560 return result;
30561 }
30562 function arrayConcat(arrays) {
30563 const result = [];
30564 for (let i = 0; i < arrays.length; ++i) {
30565 for (let j = 0; j < arrays[i].length; ++j) {
30566 result.push(arrays[i][j]);
30567 }
30568 }
30569 return result;
30570 }
30571
30572 /**
30573 * @license
30574 * Copyright 2020 Google LLC. All Rights Reserved.
30575 * Licensed under the Apache License, Version 2.0 (the "License");
30576 * you may not use this file except in compliance with the License.
30577 * You may obtain a copy of the License at
30578 *
30579 * http://www.apache.org/licenses/LICENSE-2.0
30580 *
30581 * Unless required by applicable law or agreed to in writing, software
30582 * distributed under the License is distributed on an "AS IS" BASIS,
30583 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30584 * See the License for the specific language governing permissions and
30585 * limitations under the License.
30586 * =============================================================================
30587 */
30588 const greaterEqualGradConfig = {
30589 kernelName: GreaterEqual,
30590 inputsToSave: ['a', 'b'],
30591 gradFunc: (dy, saved) => {
30592 const [a, b] = saved;
30593 return { a: () => zerosLike$3(a), b: () => zerosLike$3(b) };
30594 }
30595 };
30596
30597 /**
30598 * @license
30599 * Copyright 2020 Google LLC. All Rights Reserved.
30600 * Licensed under the Apache License, Version 2.0 (the "License");
30601 * you may not use this file except in compliance with the License.
30602 * You may obtain a copy of the License at
30603 *
30604 * http://www.apache.org/licenses/LICENSE-2.0
30605 *
30606 * Unless required by applicable law or agreed to in writing, software
30607 * distributed under the License is distributed on an "AS IS" BASIS,
30608 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30609 * See the License for the specific language governing permissions and
30610 * limitations under the License.
30611 * =============================================================================
30612 */
30613 const identityGradConfig = {
30614 kernelName: Identity$1,
30615 gradFunc: (dy) => {
30616 return { x: () => cast$3(dy, 'float32') };
30617 }
30618 };
30619
30620 /**
30621 * @license
30622 * Copyright 2020 Google LLC. All Rights Reserved.
30623 * Licensed under the Apache License, Version 2.0 (the "License");
30624 * you may not use this file except in compliance with the License.
30625 * You may obtain a copy of the License at
30626 *
30627 * http://www.apache.org/licenses/LICENSE-2.0
30628 *
30629 * Unless required by applicable law or agreed to in writing, software
30630 * distributed under the License is distributed on an "AS IS" BASIS,
30631 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30632 * See the License for the specific language governing permissions and
30633 * limitations under the License.
30634 * =============================================================================
30635 */
30636 const isFiniteGradConfig = {
30637 kernelName: IsFinite,
30638 gradFunc: (dy) => {
30639 // TODO(nsthorat): Let gradients be null for cases where we want to stop
30640 // backpropgation.
30641 return { x: () => zerosLike$3(dy) };
30642 }
30643 };
30644
30645 /**
30646 * @license
30647 * Copyright 2020 Google LLC. All Rights Reserved.
30648 * Licensed under the Apache License, Version 2.0 (the "License");
30649 * you may not use this file except in compliance with the License.
30650 * You may obtain a copy of the License at
30651 *
30652 * http://www.apache.org/licenses/LICENSE-2.0
30653 *
30654 * Unless required by applicable law or agreed to in writing, software
30655 * distributed under the License is distributed on an "AS IS" BASIS,
30656 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30657 * See the License for the specific language governing permissions and
30658 * limitations under the License.
30659 * =============================================================================
30660 */
30661 const isInfGradConfig = {
30662 kernelName: IsInf,
30663 gradFunc: (dy) => {
30664 // TODO(nsthorat): Let gradients be null for cases where we want to stop
30665 // backpropgation.
30666 return { x: () => zerosLike$3(dy) };
30667 }
30668 };
30669
30670 /**
30671 * @license
30672 * Copyright 2020 Google LLC. All Rights Reserved.
30673 * Licensed under the Apache License, Version 2.0 (the "License");
30674 * you may not use this file except in compliance with the License.
30675 * You may obtain a copy of the License at
30676 *
30677 * http://www.apache.org/licenses/LICENSE-2.0
30678 *
30679 * Unless required by applicable law or agreed to in writing, software
30680 * distributed under the License is distributed on an "AS IS" BASIS,
30681 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30682 * See the License for the specific language governing permissions and
30683 * limitations under the License.
30684 * =============================================================================
30685 */
30686 const isNanGradConfig = {
30687 kernelName: IsNan,
30688 gradFunc: (dy) => {
30689 // TODO(nsthorat): Let gradients be null for cases where we want to stop
30690 // backpropgation.
30691 return { x: () => zerosLike$3(dy) };
30692 }
30693 };
30694
30695 /**
30696 * @license
30697 * Copyright 2020 Google LLC. All Rights Reserved.
30698 * Licensed under the Apache License, Version 2.0 (the "License");
30699 * you may not use this file except in compliance with the License.
30700 * You may obtain a copy of the License at
30701 *
30702 * http://www.apache.org/licenses/LICENSE-2.0
30703 *
30704 * Unless required by applicable law or agreed to in writing, software
30705 * distributed under the License is distributed on an "AS IS" BASIS,
30706 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30707 * See the License for the specific language governing permissions and
30708 * limitations under the License.
30709 * =============================================================================
30710 */
30711 const leakyReluGradConfig = {
30712 kernelName: LeakyRelu,
30713 inputsToSave: ['x'],
30714 gradFunc: (dy, saved, attrs) => {
30715 const [x] = saved;
30716 const { alpha } = attrs;
30717 const mask = greater$3(x, 0);
30718 // Returns `gradients * (features > 0) + alpha * gradients * (features <=
30719 // 0)`.
30720 return { x: () => where(mask, dy, mul(dy, alpha)) };
30721 }
30722 };
30723
30724 /**
30725 * @license
30726 * Copyright 2020 Google LLC. All Rights Reserved.
30727 * Licensed under the Apache License, Version 2.0 (the "License");
30728 * you may not use this file except in compliance with the License.
30729 * You may obtain a copy of the License at
30730 *
30731 * http://www.apache.org/licenses/LICENSE-2.0
30732 *
30733 * Unless required by applicable law or agreed to in writing, software
30734 * distributed under the License is distributed on an "AS IS" BASIS,
30735 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30736 * See the License for the specific language governing permissions and
30737 * limitations under the License.
30738 * =============================================================================
30739 */
30740 const log1pGradConfig = {
30741 kernelName: Log1p,
30742 inputsToSave: ['x'],
30743 gradFunc: (dy, saved) => {
30744 const [x] = saved;
30745 return { x: () => div$1(dy, add$3(x, 1)) };
30746 }
30747 };
30748
30749 /**
30750 * @license
30751 * Copyright 2020 Google LLC. All Rights Reserved.
30752 * Licensed under the Apache License, Version 2.0 (the "License");
30753 * you may not use this file except in compliance with the License.
30754 * You may obtain a copy of the License at
30755 *
30756 * http://www.apache.org/licenses/LICENSE-2.0
30757 *
30758 * Unless required by applicable law or agreed to in writing, software
30759 * distributed under the License is distributed on an "AS IS" BASIS,
30760 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30761 * See the License for the specific language governing permissions and
30762 * limitations under the License.
30763 * =============================================================================
30764 */
30765 const logGradConfig = {
30766 kernelName: Log,
30767 inputsToSave: ['x'],
30768 gradFunc: (dy, saved) => {
30769 const [x] = saved;
30770 return { x: () => div$1(dy, cast$3(x, 'float32')) };
30771 }
30772 };
30773
30774 /**
30775 * @license
30776 * Copyright 2020 Google LLC. All Rights Reserved.
30777 * Licensed under the Apache License, Version 2.0 (the "License");
30778 * you may not use this file except in compliance with the License.
30779 * You may obtain a copy of the License at
30780 *
30781 * http://www.apache.org/licenses/LICENSE-2.0
30782 *
30783 * Unless required by applicable law or agreed to in writing, software
30784 * distributed under the License is distributed on an "AS IS" BASIS,
30785 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30786 * See the License for the specific language governing permissions and
30787 * limitations under the License.
30788 * =============================================================================
30789 */
30790 const logSoftmaxGradConfig = {
30791 kernelName: LogSoftmax$1,
30792 inputsToSave: [],
30793 outputsToSave: [true],
30794 gradFunc: (dy, saved, attrs) => {
30795 const [value] = saved;
30796 const { axis } = attrs;
30797 return {
30798 logits: () => {
30799 const keepDims = true;
30800 const softmax = exp$2(value);
30801 return sub$2(dy, mul(sum$3(dy, axis, keepDims), softmax));
30802 }
30803 };
30804 }
30805 };
30806
30807 /**
30808 * @license
30809 * Copyright 2020 Google LLC. All Rights Reserved.
30810 * Licensed under the Apache License, Version 2.0 (the "License");
30811 * you may not use this file except in compliance with the License.
30812 * You may obtain a copy of the License at
30813 *
30814 * http://www.apache.org/licenses/LICENSE-2.0
30815 *
30816 * Unless required by applicable law or agreed to in writing, software
30817 * distributed under the License is distributed on an "AS IS" BASIS,
30818 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30819 * See the License for the specific language governing permissions and
30820 * limitations under the License.
30821 * =============================================================================
30822 */
30823 function localResponseNormalizationBackprop_(x, y, dy, depthRadius = 5, bias = 1, alpha = 1, beta = 0.5) {
30824 const inputs = { x, y, dy };
30825 const attrs = { depthRadius, bias, alpha, beta };
30826 return ENGINE.runKernel(LRNGrad, inputs, attrs);
30827 }
30828 const localResponseNormalizationBackprop = op({ localResponseNormalizationBackprop_ });
30829
30830 /**
30831 * @license
30832 * Copyright 2020 Google LLC. All Rights Reserved.
30833 * Licensed under the Apache License, Version 2.0 (the "License");
30834 * you may not use this file except in compliance with the License.
30835 * You may obtain a copy of the License at
30836 *
30837 * http://www.apache.org/licenses/LICENSE-2.0
30838 *
30839 * Unless required by applicable law or agreed to in writing, software
30840 * distributed under the License is distributed on an "AS IS" BASIS,
30841 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30842 * See the License for the specific language governing permissions and
30843 * limitations under the License.
30844 * =============================================================================
30845 */
30846 const lrnGradConfig = {
30847 kernelName: LRN,
30848 inputsToSave: ['x'],
30849 outputsToSave: [true],
30850 gradFunc: (dy, saved, attrs) => {
30851 const [x, y] = saved;
30852 const { depthRadius, bias, alpha, beta } = attrs;
30853 return {
30854 x: () => localResponseNormalizationBackprop(x, y, dy, depthRadius, bias, alpha, beta)
30855 };
30856 }
30857 };
30858
30859 /**
30860 * @license
30861 * Copyright 2020 Google LLC. All Rights Reserved.
30862 * Licensed under the Apache License, Version 2.0 (the "License");
30863 * you may not use this file except in compliance with the License.
30864 * You may obtain a copy of the License at
30865 *
30866 * http://www.apache.org/licenses/LICENSE-2.0
30867 *
30868 * Unless required by applicable law or agreed to in writing, software
30869 * distributed under the License is distributed on an "AS IS" BASIS,
30870 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30871 * See the License for the specific language governing permissions and
30872 * limitations under the License.
30873 * =============================================================================
30874 */
30875 /**
30876 * Gradient helper function for the min and max operations.
30877 */
30878 function gradForMinAndMax(dy, y, xOrig, origAxes) {
30879 if (y.rank < xOrig.rank) {
30880 y = reshape$3(y, expandShapeToKeepDim(y.shape, origAxes));
30881 }
30882 if (dy.rank < xOrig.rank) {
30883 dy = reshape$3(dy, expandShapeToKeepDim(dy.shape, origAxes));
30884 }
30885 return {
30886 x: () => {
30887 const dx = mul(dy, cast$3(equal$2(xOrig, y), dy.dtype));
30888 return dx;
30889 }
30890 };
30891 }
30892
30893 /**
30894 * @license
30895 * Copyright 2020 Google LLC. All Rights Reserved.
30896 * Licensed under the Apache License, Version 2.0 (the "License");
30897 * you may not use this file except in compliance with the License.
30898 * You may obtain a copy of the License at
30899 *
30900 * http://www.apache.org/licenses/LICENSE-2.0
30901 *
30902 * Unless required by applicable law or agreed to in writing, software
30903 * distributed under the License is distributed on an "AS IS" BASIS,
30904 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30905 * See the License for the specific language governing permissions and
30906 * limitations under the License.
30907 * =============================================================================
30908 */
30909 const maxGradConfig = {
30910 kernelName: Max,
30911 inputsToSave: ['x'],
30912 outputsToSave: [true],
30913 gradFunc: (dy, saved, attrs) => {
30914 const maxAttrs = attrs;
30915 const { reductionIndices } = maxAttrs;
30916 const x = saved[0];
30917 const y = saved[1];
30918 const origAxes = parseAxisParam(reductionIndices, x.shape);
30919 const maxGrad = gradForMinAndMax(dy, y, x, origAxes);
30920 return {
30921 x: () => {
30922 return maxGrad['x']();
30923 }
30924 };
30925 }
30926 };
30927
30928 /**
30929 * @license
30930 * Copyright 2020 Google LLC. All Rights Reserved.
30931 * Licensed under the Apache License, Version 2.0 (the "License");
30932 * you may not use this file except in compliance with the License.
30933 * You may obtain a copy of the License at
30934 *
30935 * http://www.apache.org/licenses/LICENSE-2.0
30936 *
30937 * Unless required by applicable law or agreed to in writing, software
30938 * distributed under the License is distributed on an "AS IS" BASIS,
30939 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30940 * See the License for the specific language governing permissions and
30941 * limitations under the License.
30942 * =============================================================================
30943 */
30944 const maximumGradConfig = {
30945 kernelName: Maximum$1,
30946 inputsToSave: ['a', 'b'],
30947 gradFunc: (dy, saved) => {
30948 const [a, b] = saved;
30949 const derA = () => mul(dy, cast$3(greaterEqual$2(a, b), 'float32'));
30950 const derB = () => mul(dy, cast$3(less$3(a, b), 'float32'));
30951 return { a: derA, b: derB };
30952 }
30953 };
30954
30955 /**
30956 * @license
30957 * Copyright 2020 Google LLC. All Rights Reserved.
30958 * Licensed under the Apache License, Version 2.0 (the "License");
30959 * you may not use this file except in compliance with the License.
30960 * You may obtain a copy of the License at
30961 *
30962 * http://www.apache.org/licenses/LICENSE-2.0
30963 *
30964 * Unless required by applicable law or agreed to in writing, software
30965 * distributed under the License is distributed on an "AS IS" BASIS,
30966 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30967 * See the License for the specific language governing permissions and
30968 * limitations under the License.
30969 * =============================================================================
30970 */
30971 /**
30972 * Computes the backprop of a 3d max pool.
30973 *
30974 * @param dy The dy error, of rank 5 of shape
30975 * [batchSize, depth, height, width, channels].
30976 * assumed.
30977 * @param input The original input image, of rank 5 or rank 4 of shape
30978 * [batchSize, depth, height, width, channels].
30979 * @param output The original output image, of rank 5 of shape
30980 * [batchSize, outDepth, outHeight, outWidth, channels].
30981 * @param filterSize The filter size:
30982 * `[filterDepth, filterHeight, filterWidth]`.
30983 * `filterSize` is a single number,
30984 * then `filterDepth == filterHeight == filterWidth`.
30985 * @param strides The strides of the pooling:
30986 * `[strideDepth, strideHeight, strideWidth]`. If
30987 * `strides` is a single number, then `strideHeight == strideWidth`.
30988 * @param pad A string from: 'same', 'valid'. The type of padding algorithm
30989 * used in the forward prop of the op.
30990 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
30991 * provided, it will default to truncate.
30992 */
    function maxPool3dGrad_(dy, input, output, filterSize, strides, pad, dimRoundingMode) {
        // Normalize all three arguments to tensors (accepts TensorLike input).
        const $dy = convertToTensor(dy, 'dy', 'maxPool3dGrad');
        const $input = convertToTensor(input, 'input', 'maxPool3dGrad');
        const $output = convertToTensor(output, 'output', 'maxPool3dGrad');
        let dy5D = $dy;
        let input5D = $input;
        let output5D = $output;
        let reshapedTo5D = false;
        // A rank-4 input is treated as a batch of one: temporarily prepend a
        // batch dimension to all three tensors, and strip it again below.
        if ($input.rank === 4) {
            reshapedTo5D = true;
            dy5D = reshape$3($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]]);
            input5D = reshape$3($input, [
                1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3]
            ]);
            output5D = reshape$3($output, [
                1, $output.shape[0], $output.shape[1], $output.shape[2], $output.shape[3]
            ]);
        }
        // After the optional expansion, everything must be rank 5.
        assert$1(dy5D.rank === 5, () => `Error in maxPool3dGrad: dy must be rank 5 but got rank ` +
            `${dy5D.rank}.`);
        assert$1(input5D.rank === 5, () => `Error in maxPool3dGrad: input must be rank 5 but got rank ` +
            `${input5D.rank}.`);
        assert$1(output5D.rank === 5, () => `Error in maxPool3dGrad: output must be rank 5 but got rank ` +
            `${output5D.rank}.`);
        // `dimRoundingMode` is only valid with an explicit numeric `pad`.
        checkPadOnDimRoundingMode('maxPool3dGrad', pad, dimRoundingMode);
        const inputs = { dy: dy5D, input: input5D, output: output5D };
        const attrs = { filterSize, strides, pad, dimRoundingMode };
        // tslint:disable-next-line: no-unnecessary-type-assertion
        const res = ENGINE.runKernel(MaxPool3DGrad, inputs, attrs);
        // Drop the synthetic batch dimension if one was added above.
        if (reshapedTo5D) {
            return reshape$3(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
        }
        return res;
    }
    const maxPool3dGrad = /* @__PURE__ */ op({ maxPool3dGrad_ });
31028
31029 /**
31030 * @license
31031 * Copyright 2020 Google LLC. All Rights Reserved.
31032 * Licensed under the Apache License, Version 2.0 (the "License");
31033 * you may not use this file except in compliance with the License.
31034 * You may obtain a copy of the License at
31035 *
31036 * http://www.apache.org/licenses/LICENSE-2.0
31037 *
31038 * Unless required by applicable law or agreed to in writing, software
31039 * distributed under the License is distributed on an "AS IS" BASIS,
31040 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31041 * See the License for the specific language governing permissions and
31042 * limitations under the License.
31043 * =============================================================================
31044 */
31045 const maxPool3DGradConfig$2 = {
31046 kernelName: MaxPool3D,
31047 inputsToSave: ['x'],
31048 outputsToSave: [true],
31049 gradFunc: (dy, saved, attrs) => {
31050 const [x, y] = saved;
31051 const { filterSize, strides, pad, dimRoundingMode } = attrs;
31052 return {
31053 x: () => maxPool3dGrad(dy, x, y, filterSize, strides, pad, dimRoundingMode)
31054 };
31055 }
31056 };
31057
31058 /**
31059 * @license
31060 * Copyright 2020 Google LLC. All Rights Reserved.
31061 * Licensed under the Apache License, Version 2.0 (the "License");
31062 * you may not use this file except in compliance with the License.
31063 * You may obtain a copy of the License at
31064 *
31065 * http://www.apache.org/licenses/LICENSE-2.0
31066 *
31067 * Unless required by applicable law or agreed to in writing, software
31068 * distributed under the License is distributed on an "AS IS" BASIS,
31069 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31070 * See the License for the specific language governing permissions and
31071 * limitations under the License.
31072 * =============================================================================
31073 */
31074 /**
31075 * Computes the backprop of a 2D max pool.
31076 *
31077 * @param dy The dy error, of rank 4 or rank 3 of shape
31078 * [batchSize, height, width, channels]. If rank 3, batch of 1 is
31079 * assumed.
31080 * @param input The original input image, of rank 4, of shape
31081 * [batchSize, height, width, channels].
31082 * @param output The original output image, of rank 4, of shape
31083 * [batchSize, outHeight, outWidth, channels].
31084 * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
31085 * `filterSize` is a single number, then `filterHeight == filterWidth`.
31086 * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
31087 * `strides` is a single number, then `strideHeight == strideWidth`.
31088 * @param pad The type of padding algorithm used in the forward prop of the op.
31089 * 'same', 'valid', for more info, see this guide:
31090 * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
31091 * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
31092 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
31093 * provided, it will default to truncate.
31094 */
    function maxPoolGrad_(dy, input, output, filterSize, strides, pad, dimRoundingMode) {
        // Normalize all three arguments to tensors (accepts TensorLike input).
        const $dy = convertToTensor(dy, 'dy', 'maxPoolGrad');
        const $input = convertToTensor(input, 'input', 'maxPoolGrad');
        const $output = convertToTensor(output, 'output', 'maxPoolGrad');
        // Unlike maxPool3dGrad, no rank-3 convenience reshape is done here:
        // both `input` and `dy` must already be rank 4.
        assert$1($input.rank === $dy.rank, () => `Rank of input (${$input.rank}) does not match rank of dy ` +
            `(${$dy.rank})`);
        assert$1($dy.rank === 4, () => `Error in maxPoolGrad: dy must be rank 4 but got rank ` +
            `${$dy.rank}.`);
        assert$1($input.rank === 4, () => `Error in maxPoolGrad: input must be rank 4 but got rank ` +
            `${$input.rank}.`);
        // `dimRoundingMode` is only valid with an explicit numeric `pad`.
        checkPadOnDimRoundingMode('maxPoolGrad', pad, dimRoundingMode);
        const inputs = { dy: $dy, input: $input, output: $output };
        const attrs = { filterSize, strides, pad, dimRoundingMode };
        // tslint:disable-next-line: no-unnecessary-type-assertion
        return ENGINE.runKernel(MaxPoolGrad, inputs, attrs);
    }
    const maxPoolGrad$2 = /* @__PURE__ */ op({ maxPoolGrad_ });
31112
31113 /**
31114 * @license
31115 * Copyright 2020 Google LLC. All Rights Reserved.
31116 * Licensed under the Apache License, Version 2.0 (the "License");
31117 * you may not use this file except in compliance with the License.
31118 * You may obtain a copy of the License at
31119 *
31120 * http://www.apache.org/licenses/LICENSE-2.0
31121 *
31122 * Unless required by applicable law or agreed to in writing, software
31123 * distributed under the License is distributed on an "AS IS" BASIS,
31124 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31125 * See the License for the specific language governing permissions and
31126 * limitations under the License.
31127 * =============================================================================
31128 */
31129 const maxPoolGradConfig$2 = {
31130 kernelName: MaxPool,
31131 inputsToSave: ['x'],
31132 outputsToSave: [true],
31133 gradFunc: (dy, saved, attrs) => {
31134 const [x, y] = saved;
31135 const { filterSize, strides, pad } = attrs;
31136 return {
31137 x: () => maxPoolGrad$2(dy, x, y, filterSize, strides, pad)
31138 };
31139 }
31140 };
31141
31142 /**
31143 * @license
31144 * Copyright 2020 Google LLC. All Rights Reserved.
31145 * Licensed under the Apache License, Version 2.0 (the "License");
31146 * you may not use this file except in compliance with the License.
31147 * You may obtain a copy of the License at
31148 *
31149 * http://www.apache.org/licenses/LICENSE-2.0
31150 *
31151 * Unless required by applicable law or agreed to in writing, software
31152 * distributed under the License is distributed on an "AS IS" BASIS,
31153 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31154 * See the License for the specific language governing permissions and
31155 * limitations under the License.
31156 * =============================================================================
31157 */
31158 const meanGradConfig = {
31159 kernelName: Mean,
31160 inputsToSave: ['x'],
31161 gradFunc: (dy, saved, attrs) => {
31162 const [x] = saved;
31163 const { axis } = attrs;
31164 const axes = parseAxisParam(axis, x.shape);
31165 const shapes = computeOutAndReduceShapes(x.shape, axes);
31166 const reduceShape = shapes[1];
31167 const reduceSize = sizeFromShape(reduceShape);
31168 const derX = () => {
31169 const expandedDyShape = x.shape.slice();
31170 axes.forEach(axis => {
31171 expandedDyShape[axis] = 1;
31172 });
31173 const expandedDy = reshape$3(dy, expandedDyShape);
31174 const res = div$1(mul(expandedDy, ones$1(x.shape, 'float32')), reduceSize);
31175 return res;
31176 };
31177 return { x: derX };
31178 }
31179 };
31180
31181 /**
31182 * @license
31183 * Copyright 2020 Google LLC. All Rights Reserved.
31184 * Licensed under the Apache License, Version 2.0 (the "License");
31185 * you may not use this file except in compliance with the License.
31186 * You may obtain a copy of the License at
31187 *
31188 * http://www.apache.org/licenses/LICENSE-2.0
31189 *
31190 * Unless required by applicable law or agreed to in writing, software
31191 * distributed under the License is distributed on an "AS IS" BASIS,
31192 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31193 * See the License for the specific language governing permissions and
31194 * limitations under the License.
31195 * =============================================================================
31196 */
31197 const minGradConfig = {
31198 kernelName: Min,
31199 inputsToSave: ['x'],
31200 outputsToSave: [true],
31201 gradFunc: (dy, saved, attrs) => {
31202 const minAttrs = attrs;
31203 const { axis } = minAttrs;
31204 const [x, y] = saved;
31205 const origAxes = parseAxisParam(axis, x.shape);
31206 const minGrad = gradForMinAndMax(dy, y, x, origAxes);
31207 return {
31208 x: () => {
31209 return minGrad['x']();
31210 }
31211 };
31212 }
31213 };
31214
31215 /**
31216 * @license
31217 * Copyright 2020 Google LLC. All Rights Reserved.
31218 * Licensed under the Apache License, Version 2.0 (the "License");
31219 * you may not use this file except in compliance with the License.
31220 * You may obtain a copy of the License at
31221 *
31222 * http://www.apache.org/licenses/LICENSE-2.0
31223 *
31224 * Unless required by applicable law or agreed to in writing, software
31225 * distributed under the License is distributed on an "AS IS" BASIS,
31226 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31227 * See the License for the specific language governing permissions and
31228 * limitations under the License.
31229 * =============================================================================
31230 */
31231 const minimumGradConfig = {
31232 kernelName: Minimum$1,
31233 inputsToSave: ['a', 'b'],
31234 gradFunc: (dy, saved) => {
31235 const [a, b] = saved;
31236 const derA = () => mul(dy, cast$3(lessEqual$2(a, b), 'float32'));
31237 const derB = () => mul(dy, cast$3(greater$3(a, b), 'float32'));
31238 return { a: derA, b: derB };
31239 }
31240 };
31241
31242 /**
31243 * @license
31244 * Copyright 2020 Google LLC. All Rights Reserved.
31245 * Licensed under the Apache License, Version 2.0 (the "License");
31246 * you may not use this file except in compliance with the License.
31247 * You may obtain a copy of the License at
31248 *
31249 * http://www.apache.org/licenses/LICENSE-2.0
31250 *
31251 * Unless required by applicable law or agreed to in writing, software
31252 * distributed under the License is distributed on an "AS IS" BASIS,
31253 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31254 * See the License for the specific language governing permissions and
31255 * limitations under the License.
31256 * =============================================================================
31257 */
31258 const mirrorPadGradConfig = {
31259 kernelName: MirrorPad,
31260 inputsToSave: ['x'],
31261 gradFunc: (dy, saved, attrs) => {
31262 // Pad introduces values around the original tensor, so the gradient
31263 // slices the original shape out of the gradient.
31264 const x = saved[0];
31265 const { paddings } = attrs;
31266 const begin = paddings.map(p => p[0]);
31267 return { x: () => slice$2(dy, begin, x.shape) };
31268 }
31269 };
31270
31271 /**
31272 * @license
31273 * Copyright 2020 Google LLC. All Rights Reserved.
31274 * Licensed under the Apache License, Version 2.0 (the "License");
31275 * you may not use this file except in compliance with the License.
31276 * You may obtain a copy of the License at
31277 *
31278 * http://www.apache.org/licenses/LICENSE-2.0
31279 *
31280 * Unless required by applicable law or agreed to in writing, software
31281 * distributed under the License is distributed on an "AS IS" BASIS,
31282 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31283 * See the License for the specific language governing permissions and
31284 * limitations under the License.
31285 * =============================================================================
31286 */
31287 const modGradConfig = {
31288 kernelName: Mod,
31289 inputsToSave: ['a', 'b'],
31290 gradFunc: (dy, saved) => {
31291 const [a, b] = saved;
31292 const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
31293 const derA = () => {
31294 const reduceAxes = getReductionAxes(a.shape, outShape);
31295 if (reduceAxes.length > 0) {
31296 return reshape$3(sum$3(dy, reduceAxes), a.shape);
31297 }
31298 return dy;
31299 };
31300 const derB = () => {
31301 const res = mul(dy, neg$2(floor$2(div$1(a, b))));
31302 const reduceAxes = getReductionAxes(b.shape, outShape);
31303 if (reduceAxes.length > 0) {
31304 return reshape$3(sum$3(res, reduceAxes), b.shape);
31305 }
31306 return res;
31307 };
31308 return { a: derA, b: derB };
31309 }
31310 };
31311
31312 /**
31313 * @license
31314 * Copyright 2020 Google LLC. All Rights Reserved.
31315 * Licensed under the Apache License, Version 2.0 (the "License");
31316 * you may not use this file except in compliance with the License.
31317 * You may obtain a copy of the License at
31318 *
31319 * http://www.apache.org/licenses/LICENSE-2.0
31320 *
31321 * Unless required by applicable law or agreed to in writing, software
31322 * distributed under the License is distributed on an "AS IS" BASIS,
31323 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31324 * See the License for the specific language governing permissions and
31325 * limitations under the License.
31326 * =============================================================================
31327 */
31328 const multiplyGradConfig = {
31329 kernelName: Multiply$1,
31330 inputsToSave: ['a', 'b'],
31331 gradFunc: (dy, saved) => {
31332 const [a, b] = saved;
31333 const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
31334 const derA = () => {
31335 const res = mul(dy, cast$3(b, 'float32'));
31336 const reduceAxes = getReductionAxes(a.shape, outShape);
31337 if (reduceAxes.length > 0) {
31338 return reshape$3(sum$3(res, reduceAxes), a.shape);
31339 }
31340 return res;
31341 };
31342 const derB = () => {
31343 const res = mul(dy, cast$3(a, 'float32'));
31344 const reduceAxes = getReductionAxes(b.shape, outShape);
31345 if (reduceAxes.length > 0) {
31346 return reshape$3(sum$3(res, reduceAxes), b.shape);
31347 }
31348 return res;
31349 };
31350 return { a: derA, b: derB };
31351 }
31352 };
31353
31354 /**
31355 * @license
31356 * Copyright 2020 Google LLC. All Rights Reserved.
31357 * Licensed under the Apache License, Version 2.0 (the "License");
31358 * you may not use this file except in compliance with the License.
31359 * You may obtain a copy of the License at
31360 *
31361 * http://www.apache.org/licenses/LICENSE-2.0
31362 *
31363 * Unless required by applicable law or agreed to in writing, software
31364 * distributed under the License is distributed on an "AS IS" BASIS,
31365 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31366 * See the License for the specific language governing permissions and
31367 * limitations under the License.
31368 * =============================================================================
31369 */
31370 const negGradConfig = {
31371 kernelName: Neg,
31372 gradFunc: (dy) => {
31373 return { x: () => neg$2(dy) };
31374 }
31375 };
31376
31377 /**
31378 * @license
31379 * Copyright 2020 Google LLC. All Rights Reserved.
31380 * Licensed under the Apache License, Version 2.0 (the "License");
31381 * you may not use this file except in compliance with the License.
31382 * You may obtain a copy of the License at
31383 *
31384 * http://www.apache.org/licenses/LICENSE-2.0
31385 *
31386 * Unless required by applicable law or agreed to in writing, software
31387 * distributed under the License is distributed on an "AS IS" BASIS,
31388 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31389 * See the License for the specific language governing permissions and
31390 * limitations under the License.
31391 * =============================================================================
31392 */
31393 const oneHotGradConfig = {
31394 kernelName: OneHot,
31395 inputsToSave: ['indices'],
31396 gradFunc: (dy, saved) => {
31397 const indices = saved[0];
31398 return { indices: () => zeros$2(indices.shape, 'float32') };
31399 }
31400 };
31401
31402 /**
31403 * @license
31404 * Copyright 2020 Google LLC. All Rights Reserved.
31405 * Licensed under the Apache License, Version 2.0 (the "License");
31406 * you may not use this file except in compliance with the License.
31407 * You may obtain a copy of the License at
31408 *
31409 * http://www.apache.org/licenses/LICENSE-2.0
31410 *
31411 * Unless required by applicable law or agreed to in writing, software
31412 * distributed under the License is distributed on an "AS IS" BASIS,
31413 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31414 * See the License for the specific language governing permissions and
31415 * limitations under the License.
31416 * =============================================================================
31417 */
31418 const onesLikeGradConfig = {
31419 kernelName: OnesLike,
31420 gradFunc: (dy) => {
31421 return { x: () => zerosLike$3(dy) };
31422 }
31423 };
31424
31425 /**
31426 * @license
31427 * Copyright 2020 Google LLC. All Rights Reserved.
31428 * Licensed under the Apache License, Version 2.0 (the "License");
31429 * you may not use this file except in compliance with the License.
31430 * You may obtain a copy of the License at
31431 *
31432 * http://www.apache.org/licenses/LICENSE-2.0
31433 *
31434 * Unless required by applicable law or agreed to in writing, software
31435 * distributed under the License is distributed on an "AS IS" BASIS,
31436 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31437 * See the License for the specific language governing permissions and
31438 * limitations under the License.
31439 * =============================================================================
31440 */
31441 const packGradConfig = {
31442 kernelName: Pack,
31443 saveAllInputs: true,
31444 gradFunc: (dy, saved, attrs) => {
31445 const { axis } = attrs;
31446 const derTensors = unstack(dy, axis);
31447 return derTensors.map(t => () => t);
31448 }
31449 };
31450
31451 /**
31452 * @license
31453 * Copyright 2020 Google LLC. All Rights Reserved.
31454 * Licensed under the Apache License, Version 2.0 (the "License");
31455 * you may not use this file except in compliance with the License.
31456 * You may obtain a copy of the License at
31457 *
31458 * http://www.apache.org/licenses/LICENSE-2.0
31459 *
31460 * Unless required by applicable law or agreed to in writing, software
31461 * distributed under the License is distributed on an "AS IS" BASIS,
31462 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31463 * See the License for the specific language governing permissions and
31464 * limitations under the License.
31465 * =============================================================================
31466 */
31467 const padV2GradConfig = {
31468 kernelName: PadV2,
31469 inputsToSave: ['x'],
31470 gradFunc: (dy, saved, attrs) => {
31471 // Pad introduces values around the original tensor, so the gradient
31472 // slices the original shape out of the gradient.
31473 const x = saved[0];
31474 const { paddings } = attrs;
31475 const begin = paddings.map(p => p[0]);
31476 return { x: () => slice$2(dy, begin, x.shape) };
31477 }
31478 };
31479
31480 /**
31481 * @license
31482 * Copyright 2020 Google LLC. All Rights Reserved.
31483 * Licensed under the Apache License, Version 2.0 (the "License");
31484 * you may not use this file except in compliance with the License.
31485 * You may obtain a copy of the License at
31486 *
31487 * http://www.apache.org/licenses/LICENSE-2.0
31488 *
31489 * Unless required by applicable law or agreed to in writing, software
31490 * distributed under the License is distributed on an "AS IS" BASIS,
31491 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31492 * See the License for the specific language governing permissions and
31493 * limitations under the License.
31494 * =============================================================================
31495 */
/**
 * Gradient config for Pow (y = a^b).
 *   d/da a^b = b * a^(b-1)
 *   d/db a^b = a^b * ln(a)   (ln only valid for a > 0; zero elsewhere)
 * Each partial is summed over broadcast axes and reshaped back to the shape
 * of the corresponding input.
 */
const powGradConfig = {
    kernelName: Pow,
    inputsToSave: ['a', 'b'],
    outputsToSave: [true],
    gradFunc: (dy, saved) => {
        const [base, exponent, y] = saved;
        const outShape = assertAndGetBroadcastShape(base.shape, exponent.shape);
        // Sum a gradient over the axes broadcasting introduced, then restore
        // the input's own shape.
        const reduceToShape = (t, shape) => {
            const axes = getReductionAxes(shape, outShape);
            const summed = axes.length > 0 ? sum$3(t, axes) : t;
            return reshape$3(summed, shape);
        };
        const derBase = () => {
            const expFloat = cast$3(exponent, 'float32');
            const grad = mul(dy, mul(expFloat, pow$3(base, sub$2(expFloat, scalar(1)))));
            return reduceToShape(grad, base.shape);
        };
        const derExp = () => {
            // ln(base) is undefined for base <= 0; substitute zeros there.
            const positive = greater$3(base, 0);
            const safeLog = where(positive, log$2(base), zerosLike$3(base));
            return reduceToShape(mul(dy, mul(y, safeLog)), exponent.shape);
        };
        return { a: derBase, b: derExp };
    }
};
31527
31528 /**
31529 * @license
31530 * Copyright 2020 Google LLC. All Rights Reserved.
31531 * Licensed under the Apache License, Version 2.0 (the "License");
31532 * you may not use this file except in compliance with the License.
31533 * You may obtain a copy of the License at
31534 *
31535 * http://www.apache.org/licenses/LICENSE-2.0
31536 *
31537 * Unless required by applicable law or agreed to in writing, software
31538 * distributed under the License is distributed on an "AS IS" BASIS,
31539 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31540 * See the License for the specific language governing permissions and
31541 * limitations under the License.
31542 * =============================================================================
31543 */
/**
 * Gradient config for Prelu (y = x if x > 0 else alpha * x).
 * dx passes dy through where x > 0 and scales by alpha elsewhere; dAlpha is
 * dy * x on the negative side, reduced back to alpha's (broadcast) shape.
 */
const preluGradConfig = {
    kernelName: Prelu,
    inputsToSave: ['x', 'alpha'],
    gradFunc: (dy, saved) => {
        const [x, alpha] = saved;
        const positive = greater$3(x, 0);
        const gradX = () => where(positive, dy, mul(dy, alpha));
        const gradAlpha = () => {
            let res = where(positive, zerosLike$3(dy), mul(dy, x));
            const reduceAxes = getReductionAxes(alpha.shape, dy.shape);
            if (reduceAxes.length > 0) {
                res = sum$3(res, reduceAxes);
            }
            return reshape$3(res, alpha.shape);
        };
        return { x: gradX, alpha: gradAlpha };
    }
};
31563
31564 /**
31565 * @license
31566 * Copyright 2022 Google Inc. All Rights Reserved.
31567 * Licensed under the Apache License, Version 2.0 (the "License");
31568 * you may not use this file except in compliance with the License.
31569 * You may obtain a copy of the License at
31570 *
31571 * http://www.apache.org/licenses/LICENSE-2.0
31572 *
31573 * Unless required by applicable law or agreed to in writing, software
31574 * distributed under the License is distributed on an "AS IS" BASIS,
31575 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31576 * See the License for the specific language governing permissions and
31577 * limitations under the License.
31578 * =============================================================================
31579 */
31580 // Gradient for product operation on a single axis.
// Gradient for product over a single axis. For each element, the partial
// derivative equals the product of every other element along the axis, which
// is the exclusive cumulative product from the left times the exclusive
// cumulative product from the right.
function prodGradFn_(x, dy, axis) {
    // Re-insert the reduced axis as size 1 so `dy` broadcasts over it.
    const expandedYShape = x.shape.slice();
    expandedYShape[axis] = 1;
    const expandedDy = reshape$3(dy, expandedYShape);
    const leftProd = cumprod$2(x, axis, true, false);
    const rightProd = cumprod$2(x, axis, true, true);
    return mul(expandedDy, mul(leftProd, rightProd));
}
// Support gradients when the product is done on many axes at once.
// This is done by pushing all the axes on which the product is applied into a
// single axis.
// Gradient of prod over multiple axes: permute the reduced axes to the end of
// the tensor, collapse them into a single trailing axis, apply the
// single-axis gradient, then undo the reshape and the permutation.
function prodsGradFn_(x, dy, axis) {
    const xRank = x.shape.length;
    const finalProdAxis = xRank - axis.length;
    const xPermutation = getAxesPermutation(axis, xRank);
    let permutedX = x;
    if (xPermutation != null) {
        permutedX = transpose$2(x, xPermutation);
    }
    // Collapse all reduced dimensions into one trailing dimension.
    const newShape = permutedX.shape.slice();
    const removedShape = newShape.splice(xRank - axis.length, axis.length);
    const collapsedSize = removedShape.reduce((p, c) => p * c, 1);
    newShape.push(collapsedSize);
    let prodGrad = prodGradFn_(permutedX.reshape(newShape), dy, finalProdAxis);
    // Restore the original shape, then the original axis order.
    prodGrad = prodGrad.reshape(permutedX.shape);
    if (xPermutation != null) {
        prodGrad = transpose$2(prodGrad, getUndoAxesPermutation(xPermutation));
    }
    return prodGrad;
}
31623 // Running example:
31624 // [
31625 // [
31626 // [3.0, 4.0],
31627 // [5.0, 6.0],
31628 // [7.0, 8.0]
31629 // ],
31630 // [
31631 // [3.0, 5.0],
31632 // [0.0, 6.0],
31633 // [5.0, 6.0]
31634 // ]
31635 // ]
31636 //
/**
 * Gradient config for Prod. Normalizes the `axis` attribute (absent, scalar,
 * or array) into an array of axes and delegates to the multi-axis gradient.
 */
const prodGradConfig = {
    kernelName: Prod,
    inputsToSave: ['x'],
    gradFunc: (dy, saved, attrs) => {
        const [x] = saved;
        const { axis } = attrs;
        let axisArr;
        if (axis == null) {
            // No axis supplied: reduce over every dimension.
            axisArr = x.shape.map((_, i) => i);
        } else if (typeof axis === 'number') {
            axisArr = [axis];
        } else {
            axisArr = axis;
        }
        return { x: () => prodsGradFn_(x, dy, axisArr) };
    }
};
31656
31657 /**
31658 * @license
31659 * Copyright 2020 Google LLC. All Rights Reserved.
31660 * Licensed under the Apache License, Version 2.0 (the "License");
31661 * you may not use this file except in compliance with the License.
31662 * You may obtain a copy of the License at
31663 *
31664 * http://www.apache.org/licenses/LICENSE-2.0
31665 *
31666 * Unless required by applicable law or agreed to in writing, software
31667 * distributed under the License is distributed on an "AS IS" BASIS,
31668 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31669 * See the License for the specific language governing permissions and
31670 * limitations under the License.
31671 * =============================================================================
31672 */
/**
 * Gradient config for RealDiv (y = a / b).
 *   d/da = dy / b
 *   d/db = -dy * a / b^2
 * Each partial is summed over broadcast axes and reshaped to its input shape.
 */
const divGradConfig = {
    kernelName: RealDiv,
    inputsToSave: ['a', 'b'],
    gradFunc: (dy, saved) => {
        const [a, b] = saved;
        const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
        const derA = () => {
            const res = div$1(dy, cast$3(b, 'float32'));
            const reduceAxes = getReductionAxes(a.shape, outShape);
            return reduceAxes.length > 0 ?
                reshape$3(sum$3(res, reduceAxes), a.shape) :
                res;
        };
        const derB = () => {
            let res = mul(dy, cast$3(a, 'float32'));
            const reduceAxes = getReductionAxes(b.shape, outShape);
            if (reduceAxes.length > 0) {
                res = reshape$3(sum$3(res, reduceAxes), b.shape);
            }
            return neg$2(div$1(res, cast$3(square$2(b), 'float32')));
        };
        return { a: derA, b: derB };
    }
};
31699
31700 /**
31701 * @license
31702 * Copyright 2020 Google LLC. All Rights Reserved.
31703 * Licensed under the Apache License, Version 2.0 (the "License");
31704 * you may not use this file except in compliance with the License.
31705 * You may obtain a copy of the License at
31706 *
31707 * http://www.apache.org/licenses/LICENSE-2.0
31708 *
31709 * Unless required by applicable law or agreed to in writing, software
31710 * distributed under the License is distributed on an "AS IS" BASIS,
31711 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31712 * See the License for the specific language governing permissions and
31713 * limitations under the License.
31714 * =============================================================================
31715 */
/**
 * Gradient config for Reciprocal: d/dx (1/x) = -1 / x^2.
 */
const reciprocalGradConfig = {
    kernelName: Reciprocal,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        const gradX = () => div$1(dy, neg$2(square$2(x)));
        return { x: gradX };
    }
};
31724
31725 /**
31726 * @license
31727 * Copyright 2020 Google LLC. All Rights Reserved.
31728 * Licensed under the Apache License, Version 2.0 (the "License");
31729 * you may not use this file except in compliance with the License.
31730 * You may obtain a copy of the License at
31731 *
31732 * http://www.apache.org/licenses/LICENSE-2.0
31733 *
31734 * Unless required by applicable law or agreed to in writing, software
31735 * distributed under the License is distributed on an "AS IS" BASIS,
31736 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31737 * See the License for the specific language governing permissions and
31738 * limitations under the License.
31739 * =============================================================================
31740 */
/**
 * Gradient config for Relu6: the gradient passes through only where the
 * input lies in (0, 6] — i.e. where step(x) is 1 and x <= 6.
 */
const relu6GradConfig = {
    kernelName: Relu6$1,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        const passThrough = mul(lessEqual$2(x, 6), step$2(x));
        return { x: () => mul(dy, cast$3(passThrough, 'float32')) };
    }
};
31750
31751 /**
31752 * @license
31753 * Copyright 2020 Google LLC. All Rights Reserved.
31754 * Licensed under the Apache License, Version 2.0 (the "License");
31755 * you may not use this file except in compliance with the License.
31756 * You may obtain a copy of the License at
31757 *
31758 * http://www.apache.org/licenses/LICENSE-2.0
31759 *
31760 * Unless required by applicable law or agreed to in writing, software
31761 * distributed under the License is distributed on an "AS IS" BASIS,
31762 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31763 * See the License for the specific language governing permissions and
31764 * limitations under the License.
31765 * =============================================================================
31766 */
/**
 * Gradient config for Relu: pass dy through where the input is positive
 * (step(x) is 1), zero elsewhere.
 */
const reluGradConfig = {
    kernelName: Relu$1,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        const gradX = () => mul(dy, cast$3(step$2(x), 'float32'));
        return { x: gradX };
    }
};
31775
31776 /**
31777 * @license
31778 * Copyright 2020 Google Inc. All Rights Reserved.
31779 * Licensed under the Apache License, Version 2.0 (the "License");
31780 * you may not use this file except in compliance with the License.
31781 * You may obtain a copy of the License at
31782 *
31783 * http://www.apache.org/licenses/LICENSE-2.0
31784 *
31785 * Unless required by applicable law or agreed to in writing, software
31786 * distributed under the License is distributed on an "AS IS" BASIS,
31787 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31788 * See the License for the specific language governing permissions and
31789 * limitations under the License.
31790 * =============================================================================
31791 */
/**
 * Gradient config for Reshape: the gradient is simply `dy` reshaped back to
 * the input's original shape.
 */
const reshapeGradConfig = {
    kernelName: Reshape$1,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        const gradX = () => reshape$3(dy, x.shape);
        return { x: gradX };
    }
};
31800
31801 /**
31802 * @license
31803 * Copyright 2020 Google LLC. All Rights Reserved.
31804 * Licensed under the Apache License, Version 2.0 (the "License");
31805 * you may not use this file except in compliance with the License.
31806 * You may obtain a copy of the License at
31807 *
31808 * http://www.apache.org/licenses/LICENSE-2.0
31809 *
31810 * Unless required by applicable law or agreed to in writing, software
31811 * distributed under the License is distributed on an "AS IS" BASIS,
31812 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31813 * See the License for the specific language governing permissions and
31814 * limitations under the License.
31815 * =============================================================================
31816 */
/**
 * Gradient config for ResizeBilinear: delegates to the dedicated
 * ResizeBilinearGrad kernel with the same attrs as the forward op.
 */
const resizeBilinearGradConfig$2 = {
    kernelName: ResizeBilinear,
    inputsToSave: ['images'],
    gradFunc: (dy, saved, attrs) => {
        const [images] = saved;
        return {
            // tslint:disable-next-line: no-unnecessary-type-assertion
            images: () => ENGINE.runKernel(ResizeBilinearGrad, { dy, images }, attrs)
        };
    }
};
31829
31830 /**
31831 * @license
31832 * Copyright 2020 Google LLC. All Rights Reserved.
31833 * Licensed under the Apache License, Version 2.0 (the "License");
31834 * you may not use this file except in compliance with the License.
31835 * You may obtain a copy of the License at
31836 *
31837 * http://www.apache.org/licenses/LICENSE-2.0
31838 *
31839 * Unless required by applicable law or agreed to in writing, software
31840 * distributed under the License is distributed on an "AS IS" BASIS,
31841 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31842 * See the License for the specific language governing permissions and
31843 * limitations under the License.
31844 * =============================================================================
31845 */
/**
 * Gradient config for ResizeNearestNeighbor: delegates to the dedicated
 * ResizeNearestNeighborGrad kernel with the same attrs as the forward op.
 */
const resizeNearestNeighborGradConfig$2 = {
    kernelName: ResizeNearestNeighbor,
    inputsToSave: ['images'],
    gradFunc: (dy, saved, attrs) => {
        const [images] = saved;
        return {
            // tslint:disable-next-line: no-unnecessary-type-assertion
            images: () => ENGINE.runKernel(ResizeNearestNeighborGrad, { dy, images }, attrs)
        };
    }
};
31858
31859 /**
31860 * @license
31861 * Copyright 2020 Google LLC. All Rights Reserved.
31862 * Licensed under the Apache License, Version 2.0 (the "License");
31863 * you may not use this file except in compliance with the License.
31864 * You may obtain a copy of the License at
31865 *
31866 * http://www.apache.org/licenses/LICENSE-2.0
31867 *
31868 * Unless required by applicable law or agreed to in writing, software
31869 * distributed under the License is distributed on an "AS IS" BASIS,
31870 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31871 * See the License for the specific language governing permissions and
31872 * limitations under the License.
31873 * =============================================================================
31874 */
/**
 * Gradient config for Reverse: reversing is its own inverse, so the gradient
 * reverses `dy` along the same (normalized) axes.
 */
const reverseGradConfig = {
    kernelName: Reverse,
    gradFunc: (dy, saved, attrs) => {
        const { dims } = attrs;
        const normalizedAxes = parseAxisParam(dims, dy.shape);
        return { x: () => reverse$2(dy, normalizedAxes) };
    }
};
31883
31884 /**
31885 * @license
31886 * Copyright 2020 Google LLC. All Rights Reserved.
31887 * Licensed under the Apache License, Version 2.0 (the "License");
31888 * you may not use this file except in compliance with the License.
31889 * You may obtain a copy of the License at
31890 *
31891 * http://www.apache.org/licenses/LICENSE-2.0
31892 *
31893 * Unless required by applicable law or agreed to in writing, software
31894 * distributed under the License is distributed on an "AS IS" BASIS,
31895 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31896 * See the License for the specific language governing permissions and
31897 * limitations under the License.
31898 * =============================================================================
31899 */
/**
 * Gradient config for Round: the op is piecewise constant, so its gradient
 * is zero everywhere.
 */
const roundGradConfig = {
    kernelName: Round,
    gradFunc: (dy) => {
        // TODO(nsthorat): Let gradients be null for cases where we want to
        // stop backpropagation.
        return { x: () => zerosLike$3(dy) };
    }
};
31908
31909 /**
31910 * @license
31911 * Copyright 2020 Google LLC. All Rights Reserved.
31912 * Licensed under the Apache License, Version 2.0 (the "License");
31913 * you may not use this file except in compliance with the License.
31914 * You may obtain a copy of the License at
31915 *
31916 * http://www.apache.org/licenses/LICENSE-2.0
31917 *
31918 * Unless required by applicable law or agreed to in writing, software
31919 * distributed under the License is distributed on an "AS IS" BASIS,
31920 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31921 * See the License for the specific language governing permissions and
31922 * limitations under the License.
31923 * =============================================================================
31924 */
/**
 * Gradient config for Rsqrt: d/dx x^(-1/2) = -1 / (2 * x^(3/2)).
 */
const rsqrtGradConfig = {
    kernelName: Rsqrt,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        const gradX = () => neg$2(div$1(dy, mul(pow$3(x, 1.5), 2)));
        return { x: gradX };
    }
};
31933
31934 /**
31935 * @license
31936 * Copyright 2020 Google LLC. All Rights Reserved.
31937 * Licensed under the Apache License, Version 2.0 (the "License");
31938 * you may not use this file except in compliance with the License.
31939 * You may obtain a copy of the License at
31940 *
31941 * http://www.apache.org/licenses/LICENSE-2.0
31942 *
31943 * Unless required by applicable law or agreed to in writing, software
31944 * distributed under the License is distributed on an "AS IS" BASIS,
31945 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31946 * See the License for the specific language governing permissions and
31947 * limitations under the License.
31948 * =============================================================================
31949 */
/**
 * Gradient config for Select: `dy` flows to `t` where the condition is true
 * and to `e` where it is false; the condition itself gets a zero gradient.
 */
const selectGradConfig = {
    kernelName: Select,
    inputsToSave: ['condition'],
    gradFunc: (dy, saved) => {
        const [condition] = saved;
        const gradCondition = () =>
            // TODO(julianoks): Return null for condition gradient
            // when backprop supports it.
            cast$3(zerosLike$3(condition), 'float32');
        const gradT = () => mul(dy, cast$3(condition, dy.dtype));
        const gradE = () => mul(dy, cast$3(logicalNot$2(condition), dy.dtype));
        return { condition: gradCondition, t: gradT, e: gradE };
    }
};
31964
31965 /**
31966 * @license
31967 * Copyright 2020 Google LLC. All Rights Reserved.
31968 * Licensed under the Apache License, Version 2.0 (the "License");
31969 * you may not use this file except in compliance with the License.
31970 * You may obtain a copy of the License at
31971 *
31972 * http://www.apache.org/licenses/LICENSE-2.0
31973 *
31974 * Unless required by applicable law or agreed to in writing, software
31975 * distributed under the License is distributed on an "AS IS" BASIS,
31976 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31977 * See the License for the specific language governing permissions and
31978 * limitations under the License.
31979 * =============================================================================
31980 */
/**
 * Gradient config for Selu: scale * dy for positive inputs, and
 * scaleAlpha * exp(x) * dy for non-positive inputs.
 */
const seluGradConfig = {
    kernelName: Selu$1,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        const gradX = () => {
            const positive = greater$3(x, scalar(0));
            const scaleAlpha = scalar(SELU_SCALEALPHA);
            const scale = scalar(SELU_SCALE);
            const posDer = mul(dy, scale);
            const nonPosDer = mul(mul(dy, scaleAlpha), exp$2(cast$3(x, 'float32')));
            return where(positive, posDer, nonPosDer);
        };
        return { x: gradX };
    }
};
31998
31999 /**
32000 * @license
32001 * Copyright 2020 Google LLC. All Rights Reserved.
32002 * Licensed under the Apache License, Version 2.0 (the "License");
32003 * you may not use this file except in compliance with the License.
32004 * You may obtain a copy of the License at
32005 *
32006 * http://www.apache.org/licenses/LICENSE-2.0
32007 *
32008 * Unless required by applicable law or agreed to in writing, software
32009 * distributed under the License is distributed on an "AS IS" BASIS,
32010 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32011 * See the License for the specific language governing permissions and
32012 * limitations under the License.
32013 * =============================================================================
32014 */
/**
 * Gradient config for Sigmoid, expressed via the saved output:
 * d/dx sigmoid(x) = y * (1 - y).
 */
const sigmoidGradConfig = {
    kernelName: Sigmoid$1,
    outputsToSave: [true],
    gradFunc: (dy, saved) => {
        const [y] = saved;
        const gradX = () => mul(dy, mul(y, sub$2(scalar(1), y)));
        return { x: gradX };
    }
};
32023
32024 /**
32025 * @license
32026 * Copyright 2020 Google LLC. All Rights Reserved.
32027 * Licensed under the Apache License, Version 2.0 (the "License");
32028 * you may not use this file except in compliance with the License.
32029 * You may obtain a copy of the License at
32030 *
32031 * http://www.apache.org/licenses/LICENSE-2.0
32032 *
32033 * Unless required by applicable law or agreed to in writing, software
32034 * distributed under the License is distributed on an "AS IS" BASIS,
32035 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32036 * See the License for the specific language governing permissions and
32037 * limitations under the License.
32038 * =============================================================================
32039 */
/**
 * Gradient config for Sign: the op is piecewise constant, so its gradient
 * is zero everywhere.
 */
const signGradConfig = {
    kernelName: Sign,
    gradFunc: (dy) => ({ x: () => zerosLike$3(dy) })
};
32046
32047 /**
32048 * @license
32049 * Copyright 2020 Google LLC. All Rights Reserved.
32050 * Licensed under the Apache License, Version 2.0 (the "License");
32051 * you may not use this file except in compliance with the License.
32052 * You may obtain a copy of the License at
32053 *
32054 * http://www.apache.org/licenses/LICENSE-2.0
32055 *
32056 * Unless required by applicable law or agreed to in writing, software
32057 * distributed under the License is distributed on an "AS IS" BASIS,
32058 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32059 * See the License for the specific language governing permissions and
32060 * limitations under the License.
32061 * =============================================================================
32062 */
/**
 * Gradient config for Sin: d/dx sin(x) = cos(x).
 */
const sinGradConfig = {
    kernelName: Sin,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        const gradX = () => mul(cos$2(cast$3(x, 'float32')), dy);
        return { x: gradX };
    }
};
32071
32072 /**
32073 * @license
32074 * Copyright 2020 Google LLC. All Rights Reserved.
32075 * Licensed under the Apache License, Version 2.0 (the "License");
32076 * you may not use this file except in compliance with the License.
32077 * You may obtain a copy of the License at
32078 *
32079 * http://www.apache.org/licenses/LICENSE-2.0
32080 *
32081 * Unless required by applicable law or agreed to in writing, software
32082 * distributed under the License is distributed on an "AS IS" BASIS,
32083 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32084 * See the License for the specific language governing permissions and
32085 * limitations under the License.
32086 * =============================================================================
32087 */
/**
 * Gradient config for Sinh: d/dx sinh(x) = cosh(x).
 */
const sinhGradConfig = {
    kernelName: Sinh,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        const gradX = () => mul(cosh$2(cast$3(x, 'float32')), dy);
        return { x: gradX };
    }
};
32096
32097 /**
32098 * @license
32099 * Copyright 2020 Google LLC. All Rights Reserved.
32100 * Licensed under the Apache License, Version 2.0 (the "License");
32101 * you may not use this file except in compliance with the License.
32102 * You may obtain a copy of the License at
32103 *
32104 * http://www.apache.org/licenses/LICENSE-2.0
32105 *
32106 * Unless required by applicable law or agreed to in writing, software
32107 * distributed under the License is distributed on an "AS IS" BASIS,
32108 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32109 * See the License for the specific language governing permissions and
32110 * limitations under the License.
32111 * =============================================================================
32112 */
/**
 * Gradient config for Slice: pads `dy` with zeros back out to the input's
 * shape, placing the slice's gradient at its original position.
 */
const sliceGradConfig = {
    kernelName: Slice,
    inputsToSave: ['x'],
    gradFunc: (dy, saved, attrs) => {
        const [x] = saved;
        const { begin, size } = attrs;
        const inputShape = x.shape;
        const [begin_, size_] = parseSliceParams(x, begin, size);
        // Build an Nx2 padding spec: for each dimension, [zeros prepended,
        // zeros appended]. The appended count is the input extent minus the
        // slice start minus the slice size.
        const paddings = [];
        for (let i = 0; i < dy.rank; i++) {
            const before = begin_[i];
            const after = inputShape[i] - before - size_[i];
            paddings.push([before, after]);
        }
        return { x: () => pad(dy, paddings) };
    }
};
32133
32134 /**
32135 * @license
32136 * Copyright 2020 Google LLC. All Rights Reserved.
32137 * Licensed under the Apache License, Version 2.0 (the "License");
32138 * you may not use this file except in compliance with the License.
32139 * You may obtain a copy of the License at
32140 *
32141 * http://www.apache.org/licenses/LICENSE-2.0
32142 *
32143 * Unless required by applicable law or agreed to in writing, software
32144 * distributed under the License is distributed on an "AS IS" BASIS,
32145 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32146 * See the License for the specific language governing permissions and
32147 * limitations under the License.
32148 * =============================================================================
32149 */
/**
 * Gradient config for Softmax, expressed via the saved output:
 * dLogits = y * dy - y * sum(y * dy, dim).
 */
const softmaxGradConfig = {
    kernelName: Softmax$2,
    outputsToSave: [true],
    gradFunc: (dy, saved, attrs) => {
        const [y] = saved;
        const { dim } = attrs;
        const dyTimesY = mul(dy, y);
        const gradLogits = () => {
            // keepDims so the reduced sum broadcasts back over `dim`.
            const summed = sum$3(dyTimesY, [dim], true /* keepDims */);
            return sub$2(dyTimesY, mul(summed, y));
        };
        return { logits: gradLogits };
    }
};
32163
32164 /**
32165 * @license
32166 * Copyright 2020 Google LLC. All Rights Reserved.
32167 * Licensed under the Apache License, Version 2.0 (the "License");
32168 * you may not use this file except in compliance with the License.
32169 * You may obtain a copy of the License at
32170 *
32171 * http://www.apache.org/licenses/LICENSE-2.0
32172 *
32173 * Unless required by applicable law or agreed to in writing, software
32174 * distributed under the License is distributed on an "AS IS" BASIS,
32175 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32176 * See the License for the specific language governing permissions and
32177 * limitations under the License.
32178 * =============================================================================
32179 */
/**
 * Gradient config for Softplus: d/dx log(1 + e^x) = sigmoid(x).
 */
const softplusGradConfig = {
    kernelName: Softplus$1,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        const gradX = () => mul(dy, sigmoid$2(x));
        return { x: gradX };
    }
};
32188
32189 /**
32190 * @license
32191 * Copyright 2020 Google LLC. All Rights Reserved.
32192 * Licensed under the Apache License, Version 2.0 (the "License");
32193 * you may not use this file except in compliance with the License.
32194 * You may obtain a copy of the License at
32195 *
32196 * http://www.apache.org/licenses/LICENSE-2.0
32197 *
32198 * Unless required by applicable law or agreed to in writing, software
32199 * distributed under the License is distributed on an "AS IS" BASIS,
32200 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32201 * See the License for the specific language governing permissions and
32202 * limitations under the License.
32203 * =============================================================================
32204 */
/**
 * Gradient config for SpaceToBatchND: the inverse op, BatchToSpaceND with
 * the same block shape and paddings (as crops), maps the gradient back.
 */
const spaceToBatchNDGradConfig = {
    kernelName: SpaceToBatchND,
    gradFunc: (dy, saved, attrs) => {
        const { blockShape, paddings } = attrs;
        return { x: () => batchToSpaceND$2(dy, blockShape, paddings) };
    }
};
32212
32213 /**
32214 * @license
32215 * Copyright 2020 Google LLC. All Rights Reserved.
32216 * Licensed under the Apache License, Version 2.0 (the "License");
32217 * you may not use this file except in compliance with the License.
32218 * You may obtain a copy of the License at
32219 *
32220 * http://www.apache.org/licenses/LICENSE-2.0
32221 *
32222 * Unless required by applicable law or agreed to in writing, software
32223 * distributed under the License is distributed on an "AS IS" BASIS,
32224 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32225 * See the License for the specific language governing permissions and
32226 * limitations under the License.
32227 * =============================================================================
32228 */
/**
 * Gradient config for SplitV: concatenating the per-output gradients along
 * the split axis reconstructs the input's gradient.
 */
const splitVGradConfig = {
    kernelName: SplitV,
    gradFunc: (dy, saved, attrs) => {
        const { axis } = attrs;
        return { x: () => concat$2(dy, axis) };
    }
};
32236
32237 /**
32238 * @license
32239 * Copyright 2020 Google LLC. All Rights Reserved.
32240 * Licensed under the Apache License, Version 2.0 (the "License");
32241 * you may not use this file except in compliance with the License.
32242 * You may obtain a copy of the License at
32243 *
32244 * http://www.apache.org/licenses/LICENSE-2.0
32245 *
32246 * Unless required by applicable law or agreed to in writing, software
32247 * distributed under the License is distributed on an "AS IS" BASIS,
32248 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32249 * See the License for the specific language governing permissions and
32250 * limitations under the License.
32251 * =============================================================================
32252 */
/**
 * Gradient config for Sqrt: d/dx sqrt(x) = 1 / (2 * sqrt(x)).
 */
const sqrtGradConfig = {
    kernelName: Sqrt,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        const gradX = () => div$1(dy, mul(sqrt$2(cast$3(x, 'float32')), 2));
        return { x: gradX };
    }
};
32261
32262 /**
32263 * @license
32264 * Copyright 2019 Google LLC. All Rights Reserved.
32265 * Licensed under the Apache License, Version 2.0 (the "License");
32266 * you may not use this file except in compliance with the License.
32267 * You may obtain a copy of the License at
32268 *
32269 * http://www.apache.org/licenses/LICENSE-2.0
32270 *
32271 * Unless required by applicable law or agreed to in writing, software
32272 * distributed under the License is distributed on an "AS IS" BASIS,
32273 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32274 * See the License for the specific language governing permissions and
32275 * limitations under the License.
32276 * =============================================================================
32277 */
/** Gradient config for Square: dx = dy * 2x. */
const squareGradConfig = {
    kernelName: Square,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        // Cast to float32 first so the derivative is computed in floats.
        const derX = () => mul(dy, mul(cast$3(x, 'float32'), 2));
        return { x: derX };
    }
};
32286
32287 /**
32288 * @license
32289 * Copyright 2020 Google LLC. All Rights Reserved.
32290 * Licensed under the Apache License, Version 2.0 (the "License");
32291 * you may not use this file except in compliance with the License.
32292 * You may obtain a copy of the License at
32293 *
32294 * http://www.apache.org/licenses/LICENSE-2.0
32295 *
32296 * Unless required by applicable law or agreed to in writing, software
32297 * distributed under the License is distributed on an "AS IS" BASIS,
32298 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32299 * See the License for the specific language governing permissions and
32300 * limitations under the License.
32301 * =============================================================================
32302 */
/**
 * Gradient config for SquaredDifference ((a - b)^2):
 *   da = dy * 2 * (a - b),  db = dy * 2 * (b - a).
 */
const squaredDifferenceGradConfig = {
    kernelName: SquaredDifference,
    inputsToSave: ['a', 'b'],
    gradFunc: (dy, saved) => {
        const [a, b] = saved;
        // Single shared constant reused by both partial derivatives.
        const two = scalar(2);
        return {
            a: () => mul(dy, mul(two, sub$2(a, b))),
            b: () => mul(dy, mul(two, sub$2(b, a)))
        };
    }
};
32314
32315 /**
32316 * @license
32317 * Copyright 2020 Google LLC. All Rights Reserved.
32318 * Licensed under the Apache License, Version 2.0 (the "License");
32319 * you may not use this file except in compliance with the License.
32320 * You may obtain a copy of the License at
32321 *
32322 * http://www.apache.org/licenses/LICENSE-2.0
32323 *
32324 * Unless required by applicable law or agreed to in writing, software
32325 * distributed under the License is distributed on an "AS IS" BASIS,
32326 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32327 * See the License for the specific language governing permissions and
32328 * limitations under the License.
32329 * =============================================================================
32330 */
/** Gradient config for Step: the step function's gradient is zero everywhere. */
const stepGradConfig = {
    kernelName: Step,
    gradFunc: (dy) => {
        // TODO(manrajgrover): Return null for gradients when backprop supports
        // it.
        const derX = () => zerosLike$3(dy);
        return { x: derX };
    }
};
32339
32340 /**
32341 * @license
32342 * Copyright 2020 Google LLC. All Rights Reserved.
32343 * Licensed under the Apache License, Version 2.0 (the "License");
32344 * you may not use this file except in compliance with the License.
32345 * You may obtain a copy of the License at
32346 *
32347 * http://www.apache.org/licenses/LICENSE-2.0
32348 *
32349 * Unless required by applicable law or agreed to in writing, software
32350 * distributed under the License is distributed on an "AS IS" BASIS,
32351 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32352 * See the License for the specific language governing permissions and
32353 * limitations under the License.
32354 * =============================================================================
32355 */
/**
 * Gradient config for Sub (a - b) with broadcasting support: dy is summed
 * over the axes that broadcasting expanded, then reshaped back to each
 * input's shape. The gradient w.r.t. `b` is negated.
 */
const subGradConfig = {
    kernelName: Sub,
    inputsToSave: ['a', 'b'],
    gradFunc: (dy, saved) => {
        const [a, b] = saved;
        const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
        // Sum `t` over the axes along which `shape` was broadcast to outShape.
        const sumOverBroadcastAxes = (t, shape) => {
            const reduceAxes = getReductionAxes(shape, outShape);
            return reduceAxes.length > 0 ? sum$3(t, reduceAxes) : t;
        };
        const derA = () => reshape$3(sumOverBroadcastAxes(dy, a.shape), a.shape);
        const derB = () => reshape$3(neg$2(sumOverBroadcastAxes(dy, b.shape)), b.shape);
        return { a: derA, b: derB };
    }
};
32381
32382 /**
32383 * @license
32384 * Copyright 2020 Google Inc. All Rights Reserved.
32385 * Licensed under the Apache License, Version 2.0 (the "License");
32386 * you may not use this file except in compliance with the License.
32387 * You may obtain a copy of the License at
32388 *
32389 * http://www.apache.org/licenses/LICENSE-2.0
32390 *
32391 * Unless required by applicable law or agreed to in writing, software
32392 * distributed under the License is distributed on an "AS IS" BASIS,
32393 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32394 * See the License for the specific language governing permissions and
32395 * limitations under the License.
32396 * =============================================================================
32397 */
/**
 * Gradient config for Sum: dy is reshaped so every reduced axis becomes a
 * size-1 dim, then broadcast back to x's full shape by multiplying with ones.
 */
const sumGradConfig = {
    kernelName: Sum,
    inputsToSave: ['x'],
    gradFunc: (dy, saved, attrs) => {
        const [x] = saved;
        const { axis } = attrs;
        const axes = parseAxisParam(axis, x.shape);
        // Start from x's shape and collapse each reduced axis to 1.
        const expandedDyShape = x.shape.slice();
        for (const ax of axes) {
            expandedDyShape[ax] = 1;
        }
        const expandedDy = reshape$3(dy, expandedDyShape);
        // Computed eagerly (as in the original), not lazily inside the closure.
        const derX = mul(expandedDy, ones$1(x.shape, 'float32'));
        return { x: () => derX };
    }
};
32414
32415 /**
32416 * @license
32417 * Copyright 2020 Google LLC. All Rights Reserved.
32418 * Licensed under the Apache License, Version 2.0 (the "License");
32419 * you may not use this file except in compliance with the License.
32420 * You may obtain a copy of the License at
32421 *
32422 * http://www.apache.org/licenses/LICENSE-2.0
32423 *
32424 * Unless required by applicable law or agreed to in writing, software
32425 * distributed under the License is distributed on an "AS IS" BASIS,
32426 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32427 * See the License for the specific language governing permissions and
32428 * limitations under the License.
32429 * =============================================================================
32430 */
/** Gradient config for Tan: dx = dy / cos^2(x). */
const tanGradConfig = {
    kernelName: Tan,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        const derX = () => div$1(dy, square$2(cos$2(x)));
        return { x: derX };
    }
};
32439
32440 /**
32441 * @license
32442 * Copyright 2020 Google LLC. All Rights Reserved.
32443 * Licensed under the Apache License, Version 2.0 (the "License");
32444 * you may not use this file except in compliance with the License.
32445 * You may obtain a copy of the License at
32446 *
32447 * http://www.apache.org/licenses/LICENSE-2.0
32448 *
32449 * Unless required by applicable law or agreed to in writing, software
32450 * distributed under the License is distributed on an "AS IS" BASIS,
32451 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32452 * See the License for the specific language governing permissions and
32453 * limitations under the License.
32454 * =============================================================================
32455 */
/** Gradient config for Tanh: dx = (1 - y^2) * dy, using the saved output y. */
const tanhGradConfig = {
    kernelName: Tanh$1,
    // Save the kernel's output (y = tanh(x)) instead of the input.
    outputsToSave: [true],
    gradFunc: (dy, saved) => {
        const [y] = saved;
        const derX = () => mul(sub$2(scalar(1), square$2(y)), dy);
        return { x: derX };
    }
};
32464
32465 /**
32466 * @license
32467 * Copyright 2020 Google LLC. All Rights Reserved.
32468 * Licensed under the Apache License, Version 2.0 (the "License");
32469 * you may not use this file except in compliance with the License.
32470 * You may obtain a copy of the License at
32471 *
32472 * http://www.apache.org/licenses/LICENSE-2.0
32473 *
32474 * Unless required by applicable law or agreed to in writing, software
32475 * distributed under the License is distributed on an "AS IS" BASIS,
32476 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32477 * See the License for the specific language governing permissions and
32478 * limitations under the License.
32479 * =============================================================================
32480 */
/**
 * Gradient config for Tile: the gradient w.r.t. x is the sum of the dy
 * slices corresponding to each repetition of x. Only ranks 1-4 are
 * supported; higher ranks throw.
 */
const tileGradConfig = {
    kernelName: Tile,
    inputsToSave: ['x'],
    gradFunc: (dy, saved, attrs) => {
        const [x] = saved;
        const { reps } = attrs;
        const derX = () => {
            // Accumulator with x's shape; each tile's slice of dy is added in.
            let xGrad = zerosLike$3(x);
            // TODO(cais): Maybe reduce memory footprint by avoiding repeated
            // slicing.
            if (x.rank === 1) {
                // One slice of length x.shape[0] per repetition along axis 0.
                for (let i = 0; i < reps[0]; ++i) {
                    xGrad = add$3(xGrad, slice$2(dy, [i * x.shape[0]], [x.shape[0]]));
                }
            }
            else if (x.rank === 2) {
                // Slice start is the tile index times x's extent on each axis.
                for (let i = 0; i < reps[0]; ++i) {
                    for (let j = 0; j < reps[1]; ++j) {
                        xGrad = add$3(xGrad, slice$2(dy, [i * x.shape[0], j * x.shape[1]], [
                            x.shape[0], x.shape[1]
                        ]));
                    }
                }
            }
            else if (x.rank === 3) {
                for (let i = 0; i < reps[0]; ++i) {
                    for (let j = 0; j < reps[1]; ++j) {
                        for (let k = 0; k < reps[2]; ++k) {
                            xGrad =
                                add$3(xGrad, slice$2(dy, [i * x.shape[0], j * x.shape[1], k * x.shape[2]], [x.shape[0], x.shape[1], x.shape[2]]));
                        }
                    }
                }
            }
            else if (x.rank === 4) {
                for (let i = 0; i < reps[0]; ++i) {
                    for (let j = 0; j < reps[1]; ++j) {
                        for (let k = 0; k < reps[2]; ++k) {
                            for (let l = 0; l < reps[3]; ++l) {
                                xGrad =
                                    add$3(xGrad, slice$2(dy, [
                                        i * x.shape[0], j * x.shape[1], k * x.shape[2],
                                        l * x.shape[3]
                                    ], [x.shape[0], x.shape[1], x.shape[2], x.shape[3]]));
                            }
                        }
                    }
                }
            }
            else {
                throw new Error(`Gradient for tile operation is not implemented for rank-` +
                    `${x.rank} tensors yet.`);
            }
            return xGrad;
        };
        return { x: derX };
    },
};
32539
32540 /**
32541 * @license
32542 * Copyright 2020 Google LLC. All Rights Reserved.
32543 * Licensed under the Apache License, Version 2.0 (the "License");
32544 * you may not use this file except in compliance with the License.
32545 * You may obtain a copy of the License at
32546 *
32547 * http://www.apache.org/licenses/LICENSE-2.0
32548 *
32549 * Unless required by applicable law or agreed to in writing, software
32550 * distributed under the License is distributed on an "AS IS" BASIS,
32551 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32552 * See the License for the specific language governing permissions and
32553 * limitations under the License.
32554 * =============================================================================
32555 */
/** Gradient config for Transpose: apply the inverse permutation to dy. */
const transposeGradConfig = {
    kernelName: Transpose,
    gradFunc: (dy, saved, attrs) => {
        const { perm } = attrs;
        const undoPerm = getUndoAxesPermutation(perm);
        const derX = () => transpose$2(dy, undoPerm);
        return { x: derX };
    }
};
32565
32566 /**
32567 * @license
32568 * Copyright 2020 Google Inc. All Rights Reserved.
32569 * Licensed under the Apache License, Version 2.0 (the "License");
32570 * you may not use this file except in compliance with the License.
32571 * You may obtain a copy of the License at
32572 *
32573 * http://www.apache.org/licenses/LICENSE-2.0
32574 *
32575 * Unless required by applicable law or agreed to in writing, software
32576 * distributed under the License is distributed on an "AS IS" BASIS,
32577 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32578 * See the License for the specific language governing permissions and
32579 * limitations under the License.
32580 * =============================================================================
32581 */
/**
 * Gradient config for Unpack: stack the per-slice upstream gradients back
 * together along the unpack axis.
 */
const unpackGradConfig = {
    kernelName: Unpack,
    gradFunc: (dy, saved, attrs) => {
        const { axis } = attrs;
        const derValue = () => stack(dy, axis);
        return { value: derValue };
    }
};
32590
32591 /**
32592 * @license
32593 * Copyright 2020 Google LLC. All Rights Reserved.
32594 * Licensed under the Apache License, Version 2.0 (the "License");
32595 * you may not use this file except in compliance with the License.
32596 * You may obtain a copy of the License at
32597 *
32598 * http://www.apache.org/licenses/LICENSE-2.0
32599 *
32600 * Unless required by applicable law or agreed to in writing, software
32601 * distributed under the License is distributed on an "AS IS" BASIS,
32602 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32603 * See the License for the specific language governing permissions and
32604 * limitations under the License.
32605 * =============================================================================
32606 */
/**
 * Gradient config for UnsortedSegmentSum: gather dy back by segment id,
 * producing zeros for entries whose segment id was negative.
 */
const unsortedSegmentSumGradConfig = {
    kernelName: UnsortedSegmentSum,
    inputsToSave: ['segmentIds'],
    gradFunc: (dy, saved) => {
        const [segmentIds] = saved;
        return { x: () => gatherDropNegatives(dy, segmentIds) };
    }
};
function gatherDropNegatives(x, indices) {
    // Helper function for unsorted segment ops. Gathers params for
    // positive segment ids and gathers 0 for inputs with negative segment id.
    // Mirrors _GatherDropNegatives from tensorflow/python/ops/math_grad.py
    // Clamp negative ids to 0 so every entry is a valid gather index.
    const zeroClippedIndices = maximum$4(indices, zerosLike$3(indices));
    const gathered = gather$1(x, zeroClippedIndices);
    // Mask of which original segment ids were valid (>= 0).
    let isPositive = greaterEqual$2(indices, scalar(0, 'int32'));
    // Append size-1 dims to the mask until its rank matches `gathered`,
    // so it can broadcast across the gathered trailing dimensions.
    const numIters = gathered.rank - isPositive.rank;
    for (let i = 0; i < numIters; ++i) {
        isPositive = expandDims$3(isPositive, i + 1);
    }
    // Broadcast the mask to gathered's full shape (logicalAnd with ones).
    isPositive = logicalAnd$2(isPositive, ones$1(gathered.shape, 'bool'));
    // Zero out every row whose original segment id was negative.
    const zeroSlice = zerosLike$3(gathered);
    return where(isPositive, gathered, zeroSlice);
}
32633
32634 /**
32635 * @license
32636 * Copyright 2020 Google LLC. All Rights Reserved.
32637 * Licensed under the Apache License, Version 2.0 (the "License");
32638 * you may not use this file except in compliance with the License.
32639 * You may obtain a copy of the License at
32640 *
32641 * http://www.apache.org/licenses/LICENSE-2.0
32642 *
32643 * Unless required by applicable law or agreed to in writing, software
32644 * distributed under the License is distributed on an "AS IS" BASIS,
32645 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32646 * See the License for the specific language governing permissions and
32647 * limitations under the License.
32648 * =============================================================================
32649 */
32650 const zerosLikeGradConfig = {
32651 kernelName: ZerosLike,
32652 gradFunc: (dy) => {
32653 return { x: () => zerosLike$3(dy) };
32654 }
32655 };
32656
32657 /**
32658 * @license
32659 * Copyright 2020 Google LLC. All Rights Reserved.
32660 * Licensed under the Apache License, Version 2.0 (the "License");
32661 * you may not use this file except in compliance with the License.
32662 * You may obtain a copy of the License at
32663 *
32664 * http://www.apache.org/licenses/LICENSE-2.0
32665 *
32666 * Unless required by applicable law or agreed to in writing, software
32667 * distributed under the License is distributed on an "AS IS" BASIS,
32668 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32669 * See the License for the specific language governing permissions and
32670 * limitations under the License.
32671 * =============================================================================
32672 */
32673 // Export all kernel configs here so that the package can auto register them
// Export all kernel configs here so that the package can auto register them.
// NOTE(review): the original list contained duplicate entries for
// maxGradConfig, padV2GradConfig, spaceToBatchNDGradConfig and
// splitVGradConfig; registering the identical config object twice is
// redundant (the second registration overwrites the first with the same
// object), so the duplicates are dropped here. Final registry state is
// unchanged.
const gradConfigs = [
    absGradConfig,
    acosGradConfig,
    acoshGradConfig,
    addGradConfig,
    addNGradConfig,
    argMaxGradConfig,
    argMinGradConfig,
    asinGradConfig,
    asinhGradConfig,
    atan2GradConfig,
    atanGradConfig,
    atanhGradConfig,
    avgPool3DGradConfig$2,
    avgPoolGradConfig$2,
    batchMatMulGradConfig,
    batchToSpaceNDGradConfig,
    broadcastToGradConfig,
    castGradConfig,
    ceilGradConfig,
    clipByValueGradConfig,
    complexAbsGradConfig,
    concatGradConfig,
    conv2DBackpropInputGradConfig,
    conv2DGradConfig,
    conv3DGradConfig,
    cosGradConfig,
    coshGradConfig,
    cumsumGradConfig,
    depthwiseConv2dNativeGradConfig,
    dilation2dGradConfig,
    divGradConfig,
    eluGradConfig$2,
    erfGradConfig,
    expGradConfig,
    expandDimsGradConfig,
    expm1GradConfig,
    floorDivGradConfig,
    floorGradConfig,
    fusedBatchNormGradConfig,
    gatherGradConfig,
    greaterEqualGradConfig,
    identityGradConfig,
    isFiniteGradConfig,
    isInfGradConfig,
    isNanGradConfig,
    leakyReluGradConfig,
    log1pGradConfig,
    logGradConfig,
    logSoftmaxGradConfig,
    lrnGradConfig,
    maxGradConfig,
    maximumGradConfig,
    maxPool3DGradConfig$2,
    maxPoolGradConfig$2,
    meanGradConfig,
    minGradConfig,
    minimumGradConfig,
    mirrorPadGradConfig,
    modGradConfig,
    multiplyGradConfig,
    negGradConfig,
    oneHotGradConfig,
    onesLikeGradConfig,
    packGradConfig,
    padV2GradConfig,
    powGradConfig,
    preluGradConfig,
    prodGradConfig,
    reciprocalGradConfig,
    relu6GradConfig,
    reluGradConfig,
    reshapeGradConfig,
    resizeBilinearGradConfig$2,
    resizeNearestNeighborGradConfig$2,
    reverseGradConfig,
    roundGradConfig,
    rsqrtGradConfig,
    selectGradConfig,
    seluGradConfig,
    sigmoidGradConfig,
    signGradConfig,
    sinGradConfig,
    sinhGradConfig,
    sliceGradConfig,
    softmaxGradConfig,
    softplusGradConfig,
    spaceToBatchNDGradConfig,
    splitVGradConfig,
    sqrtGradConfig,
    squaredDifferenceGradConfig,
    squareGradConfig,
    stepGradConfig,
    subGradConfig,
    sumGradConfig,
    tanGradConfig,
    tanhGradConfig,
    tileGradConfig,
    transposeGradConfig,
    unpackGradConfig,
    unsortedSegmentSumGradConfig,
    zerosLikeGradConfig
];
// Register every gradient config with the global gradient registry.
for (const gradientConfig of gradConfigs) {
    registerGradient(gradientConfig);
}
32784
32785 /**
32786 * @license
32787 * Copyright 2020 Google LLC. All Rights Reserved.
32788 * Licensed under the Apache License, Version 2.0 (the "License");
32789 * you may not use this file except in compliance with the License.
32790 * You may obtain a copy of the License at
32791 *
32792 * http://www.apache.org/licenses/LICENSE-2.0
32793 *
32794 * Unless required by applicable law or agreed to in writing, software
32795 * distributed under the License is distributed on an "AS IS" BASIS,
32796 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32797 * See the License for the specific language governing permissions and
32798 * limitations under the License.
32799 * =============================================================================
32800 */
// Chain API: tensor.abs() — delegates to the functional abs op.
getGlobalTensorClass().prototype.abs = function () {
    this.throwIfDisposed();
    return abs$2(this);
};
32805
32806 /**
32807 * @license
32808 * Copyright 2020 Google LLC. All Rights Reserved.
32809 * Licensed under the Apache License, Version 2.0 (the "License");
32810 * you may not use this file except in compliance with the License.
32811 * You may obtain a copy of the License at
32812 *
32813 * http://www.apache.org/licenses/LICENSE-2.0
32814 *
32815 * Unless required by applicable law or agreed to in writing, software
32816 * distributed under the License is distributed on an "AS IS" BASIS,
32817 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32818 * See the License for the specific language governing permissions and
32819 * limitations under the License.
32820 * =============================================================================
32821 */
// Chain API: tensor.acos() — delegates to the functional acos op.
getGlobalTensorClass().prototype.acos = function () {
    this.throwIfDisposed();
    return acos$2(this);
};
32826
32827 /**
32828 * @license
32829 * Copyright 2020 Google LLC. All Rights Reserved.
32830 * Licensed under the Apache License, Version 2.0 (the "License");
32831 * you may not use this file except in compliance with the License.
32832 * You may obtain a copy of the License at
32833 *
32834 * http://www.apache.org/licenses/LICENSE-2.0
32835 *
32836 * Unless required by applicable law or agreed to in writing, software
32837 * distributed under the License is distributed on an "AS IS" BASIS,
32838 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32839 * See the License for the specific language governing permissions and
32840 * limitations under the License.
32841 * =============================================================================
32842 */
// Chain API: tensor.acosh() — delegates to the functional acosh op.
getGlobalTensorClass().prototype.acosh = function () {
    this.throwIfDisposed();
    return acosh$2(this);
};
32847
32848 /**
32849 * @license
32850 * Copyright 2020 Google LLC. All Rights Reserved.
32851 * Licensed under the Apache License, Version 2.0 (the "License");
32852 * you may not use this file except in compliance with the License.
32853 * You may obtain a copy of the License at
32854 *
32855 * http://www.apache.org/licenses/LICENSE-2.0
32856 *
32857 * Unless required by applicable law or agreed to in writing, software
32858 * distributed under the License is distributed on an "AS IS" BASIS,
32859 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32860 * See the License for the specific language governing permissions and
32861 * limitations under the License.
32862 * =============================================================================
32863 */
// Chain API: tensor.add(b) — delegates to the functional add op.
getGlobalTensorClass().prototype.add = function (b) {
    this.throwIfDisposed();
    return add$3(this, b);
};
32868
32869 /**
32870 * @license
32871 * Copyright 2020 Google LLC. All Rights Reserved.
32872 * Licensed under the Apache License, Version 2.0 (the "License");
32873 * you may not use this file except in compliance with the License.
32874 * You may obtain a copy of the License at
32875 *
32876 * http://www.apache.org/licenses/LICENSE-2.0
32877 *
32878 * Unless required by applicable law or agreed to in writing, software
32879 * distributed under the License is distributed on an "AS IS" BASIS,
32880 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32881 * See the License for the specific language governing permissions and
32882 * limitations under the License.
32883 * =============================================================================
32884 */
// Chain API: tensor.all(axis, keepDims) — delegates to the functional all op.
getGlobalTensorClass().prototype.all = function (axis, keepDims) {
    this.throwIfDisposed();
    return all$2(this, axis, keepDims);
};
32889
32890 /**
32891 * @license
32892 * Copyright 2020 Google LLC. All Rights Reserved.
32893 * Licensed under the Apache License, Version 2.0 (the "License");
32894 * you may not use this file except in compliance with the License.
32895 * You may obtain a copy of the License at
32896 *
32897 * http://www.apache.org/licenses/LICENSE-2.0
32898 *
32899 * Unless required by applicable law or agreed to in writing, software
32900 * distributed under the License is distributed on an "AS IS" BASIS,
32901 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32902 * See the License for the specific language governing permissions and
32903 * limitations under the License.
32904 * =============================================================================
32905 */
// Chain API: tensor.any(axis, keepDims) — delegates to the functional any op.
getGlobalTensorClass().prototype.any = function (axis, keepDims) {
    this.throwIfDisposed();
    return any$2(this, axis, keepDims);
};
32910
32911 /**
32912 * @license
32913 * Copyright 2020 Google LLC. All Rights Reserved.
32914 * Licensed under the Apache License, Version 2.0 (the "License");
32915 * you may not use this file except in compliance with the License.
32916 * You may obtain a copy of the License at
32917 *
32918 * http://www.apache.org/licenses/LICENSE-2.0
32919 *
32920 * Unless required by applicable law or agreed to in writing, software
32921 * distributed under the License is distributed on an "AS IS" BASIS,
32922 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32923 * See the License for the specific language governing permissions and
32924 * limitations under the License.
32925 * =============================================================================
32926 */
// Chain API: tensor.argMax(axis) — delegates to the functional argMax op.
getGlobalTensorClass().prototype.argMax = function (axis) {
    this.throwIfDisposed();
    return argMax$2(this, axis);
};
32931
32932 /**
32933 * @license
32934 * Copyright 2020 Google LLC. All Rights Reserved.
32935 * Licensed under the Apache License, Version 2.0 (the "License");
32936 * you may not use this file except in compliance with the License.
32937 * You may obtain a copy of the License at
32938 *
32939 * http://www.apache.org/licenses/LICENSE-2.0
32940 *
32941 * Unless required by applicable law or agreed to in writing, software
32942 * distributed under the License is distributed on an "AS IS" BASIS,
32943 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32944 * See the License for the specific language governing permissions and
32945 * limitations under the License.
32946 * =============================================================================
32947 */
// Chain API: tensor.argMin(axis) — delegates to the functional argMin op.
getGlobalTensorClass().prototype.argMin = function (axis) {
    this.throwIfDisposed();
    return argMin$2(this, axis);
};
32952
32953 /**
32954 * @license
32955 * Copyright 2020 Google LLC. All Rights Reserved.
32956 * Licensed under the Apache License, Version 2.0 (the "License");
32957 * you may not use this file except in compliance with the License.
32958 * You may obtain a copy of the License at
32959 *
32960 * http://www.apache.org/licenses/LICENSE-2.0
32961 *
32962 * Unless required by applicable law or agreed to in writing, software
32963 * distributed under the License is distributed on an "AS IS" BASIS,
32964 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32965 * See the License for the specific language governing permissions and
32966 * limitations under the License.
32967 * =============================================================================
32968 */
32969 /**
32970 * Converts a size-1 `tf.Tensor` to a `tf.Scalar`.
32971 * @doc {heading: 'Tensors', subheading: 'Classes'}
32972 */
// Chain API: tensor.asScalar() — view a size-1 tensor as a rank-0 scalar.
getGlobalTensorClass().prototype.asScalar = function () {
    this.throwIfDisposed();
    // Only a tensor with exactly one element can be viewed as a scalar.
    assert$1(this.size === 1, () => 'The array must have only 1 element.');
    return reshape$3(this, []);
};
32978
32979 /**
32980 * @license
32981 * Copyright 2020 Google LLC. All Rights Reserved.
32982 * Licensed under the Apache License, Version 2.0 (the "License");
32983 * you may not use this file except in compliance with the License.
32984 * You may obtain a copy of the License at
32985 *
32986 * http://www.apache.org/licenses/LICENSE-2.0
32987 *
32988 * Unless required by applicable law or agreed to in writing, software
32989 * distributed under the License is distributed on an "AS IS" BASIS,
32990 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32991 * See the License for the specific language governing permissions and
32992 * limitations under the License.
32993 * =============================================================================
32994 */
32995 /**
32996 * Casts a `tf.Tensor` to a specified dtype.
32997 *
32998 * @param dtype Data-type to cast the tensor to.
32999 *
33000 * @doc {heading: 'Tensors', subheading: 'Classes'}
33001 */
// Chain API: tensor.asType(dtype) — delegates to the functional cast op.
getGlobalTensorClass().prototype.asType = function (dtype) {
    this.throwIfDisposed();
    return cast$3(this, dtype);
};
33006
33007 /**
33008 * @license
33009 * Copyright 2020 Google LLC. All Rights Reserved.
33010 * Licensed under the Apache License, Version 2.0 (the "License");
33011 * you may not use this file except in compliance with the License.
33012 * You may obtain a copy of the License at
33013 *
33014 * http://www.apache.org/licenses/LICENSE-2.0
33015 *
33016 * Unless required by applicable law or agreed to in writing, software
33017 * distributed under the License is distributed on an "AS IS" BASIS,
33018 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33019 * See the License for the specific language governing permissions and
33020 * limitations under the License.
33021 * =============================================================================
33022 */
33023 /**
33024 * Converts a `tf.Tensor` to a `tf.Tensor1D`.
33025 * @doc {heading: 'Tensors', subheading: 'Classes'}
33026 */
// Chain API: tensor.as1D() — flatten to a rank-1 view of all elements.
getGlobalTensorClass().prototype.as1D = function () {
    this.throwIfDisposed();
    return reshape$3(this, [this.size]);
};
33031
33032 /**
33033 * @license
33034 * Copyright 2020 Google LLC. All Rights Reserved.
33035 * Licensed under the Apache License, Version 2.0 (the "License");
33036 * you may not use this file except in compliance with the License.
33037 * You may obtain a copy of the License at
33038 *
33039 * http://www.apache.org/licenses/LICENSE-2.0
33040 *
33041 * Unless required by applicable law or agreed to in writing, software
33042 * distributed under the License is distributed on an "AS IS" BASIS,
33043 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33044 * See the License for the specific language governing permissions and
33045 * limitations under the License.
33046 * =============================================================================
33047 */
33048 /**
33049 * Converts a `tf.Tensor` to a `tf.Tensor2D`.
33050 *
33051 * @param rows Number of rows in `tf.Tensor2D`.
33052 * @param columns Number of columns in `tf.Tensor2D`.
33053 * @doc {heading: 'Tensors', subheading: 'Classes'}
33054 */
// Chain API: tensor.as2D(rows, columns) — reshape to rank 2.
getGlobalTensorClass().prototype.as2D = function (rows, columns) {
    this.throwIfDisposed();
    return reshape$3(this, [rows, columns]);
};
33059
33060 /**
33061 * @license
33062 * Copyright 2020 Google LLC. All Rights Reserved.
33063 * Licensed under the Apache License, Version 2.0 (the "License");
33064 * you may not use this file except in compliance with the License.
33065 * You may obtain a copy of the License at
33066 *
33067 * http://www.apache.org/licenses/LICENSE-2.0
33068 *
33069 * Unless required by applicable law or agreed to in writing, software
33070 * distributed under the License is distributed on an "AS IS" BASIS,
33071 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33072 * See the License for the specific language governing permissions and
33073 * limitations under the License.
33074 * =============================================================================
33075 */
33076 /**
33077 * Converts a `tf.Tensor` to a `tf.Tensor3D`.
33078 *
33079 * @param rows Number of rows in `tf.Tensor3D`.
33080 * @param columns Number of columns in `tf.Tensor3D`.
33081 * @param depth Depth of `tf.Tensor3D`.
33082 * @doc {heading: 'Tensors', subheading: 'Classes'}
33083 */
// Chain API: tensor.as3D(rows, columns, depth) — reshape to rank 3.
getGlobalTensorClass().prototype.as3D = function (rows, columns, depth) {
    this.throwIfDisposed();
    return reshape$3(this, [rows, columns, depth]);
};
33088
33089 /**
33090 * @license
33091 * Copyright 2020 Google LLC. All Rights Reserved.
33092 * Licensed under the Apache License, Version 2.0 (the "License");
33093 * you may not use this file except in compliance with the License.
33094 * You may obtain a copy of the License at
33095 *
33096 * http://www.apache.org/licenses/LICENSE-2.0
33097 *
33098 * Unless required by applicable law or agreed to in writing, software
33099 * distributed under the License is distributed on an "AS IS" BASIS,
33100 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33101 * See the License for the specific language governing permissions and
33102 * limitations under the License.
33103 * =============================================================================
33104 */
33105 /**
33106 * Converts a `tf.Tensor` to a `tf.Tensor4D`.
33107 *
33108 * @param rows Number of rows in `tf.Tensor4D`.
33109 * @param columns Number of columns in `tf.Tensor4D`.
33110 * @param depth Depth of `tf.Tensor4D`.
33111 * @param depth2 4th dimension of `tf.Tensor4D`.
33112 * @doc {heading: 'Tensors', subheading: 'Classes'}
33113 */
getGlobalTensorClass().prototype.as4D = function (rows, columns, depth, depth2) {
    const self = this;
    self.throwIfDisposed();
    // Delegate to the functional reshape op with a rank-4 target shape.
    const targetShape = [rows, columns, depth, depth2];
    return reshape$3(self, targetShape);
};
33118
33119 /**
33120 * @license
33121 * Copyright 2020 Google LLC. All Rights Reserved.
33122 * Licensed under the Apache License, Version 2.0 (the "License");
33123 * you may not use this file except in compliance with the License.
33124 * You may obtain a copy of the License at
33125 *
33126 * http://www.apache.org/licenses/LICENSE-2.0
33127 *
33128 * Unless required by applicable law or agreed to in writing, software
33129 * distributed under the License is distributed on an "AS IS" BASIS,
33130 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33131 * See the License for the specific language governing permissions and
33132 * limitations under the License.
33133 * =============================================================================
33134 */
33135 /**
33136 * Converts a `tf.Tensor` to a `tf.Tensor5D`.
33137 *
33138 * @param rows Number of rows in `tf.Tensor5D`.
33139 * @param columns Number of columns in `tf.Tensor5D`.
33140 * @param depth Depth of `tf.Tensor5D`.
33141 * @param depth2 4th dimension of `tf.Tensor5D`.
 * @param depth3 5th dimension of `tf.Tensor5D`.
33143 *
33144 * @doc {heading: 'Tensors', subheading: 'Classes'}
33145 */
getGlobalTensorClass().prototype.as5D = function (rows, columns, depth, depth2, depth3) {
    const self = this;
    self.throwIfDisposed();
    // Delegate to the functional reshape op with a rank-5 target shape.
    const targetShape = [rows, columns, depth, depth2, depth3];
    return reshape$3(self, targetShape);
};
33150
33151 /**
33152 * @license
33153 * Copyright 2020 Google LLC. All Rights Reserved.
33154 * Licensed under the Apache License, Version 2.0 (the "License");
33155 * you may not use this file except in compliance with the License.
33156 * You may obtain a copy of the License at
33157 *
33158 * http://www.apache.org/licenses/LICENSE-2.0
33159 *
33160 * Unless required by applicable law or agreed to in writing, software
33161 * distributed under the License is distributed on an "AS IS" BASIS,
33162 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33163 * See the License for the specific language governing permissions and
33164 * limitations under the License.
33165 * =============================================================================
33166 */
// Chainable element-wise arcsine; delegates to the functional `asin` op.
getGlobalTensorClass().prototype.asin = function () {
    const self = this;
    self.throwIfDisposed();
    return asin$2(self);
};
33171
33172 /**
33173 * @license
33174 * Copyright 2020 Google LLC. All Rights Reserved.
33175 * Licensed under the Apache License, Version 2.0 (the "License");
33176 * you may not use this file except in compliance with the License.
33177 * You may obtain a copy of the License at
33178 *
33179 * http://www.apache.org/licenses/LICENSE-2.0
33180 *
33181 * Unless required by applicable law or agreed to in writing, software
33182 * distributed under the License is distributed on an "AS IS" BASIS,
33183 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33184 * See the License for the specific language governing permissions and
33185 * limitations under the License.
33186 * =============================================================================
33187 */
// Chainable element-wise inverse hyperbolic sine; delegates to `asinh`.
getGlobalTensorClass().prototype.asinh = function () {
    const self = this;
    self.throwIfDisposed();
    return asinh$2(self);
};
33192
33193 /**
33194 * @license
33195 * Copyright 2020 Google LLC. All Rights Reserved.
33196 * Licensed under the Apache License, Version 2.0 (the "License");
33197 * you may not use this file except in compliance with the License.
33198 * You may obtain a copy of the License at
33199 *
33200 * http://www.apache.org/licenses/LICENSE-2.0
33201 *
33202 * Unless required by applicable law or agreed to in writing, software
33203 * distributed under the License is distributed on an "AS IS" BASIS,
33204 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33205 * See the License for the specific language governing permissions and
33206 * limitations under the License.
33207 * =============================================================================
33208 */
// Chainable element-wise arctangent; delegates to the functional `atan` op.
getGlobalTensorClass().prototype.atan = function () {
    const self = this;
    self.throwIfDisposed();
    return atan$2(self);
};
33213
33214 /**
33215 * @license
33216 * Copyright 2020 Google LLC. All Rights Reserved.
33217 * Licensed under the Apache License, Version 2.0 (the "License");
33218 * you may not use this file except in compliance with the License.
33219 * You may obtain a copy of the License at
33220 *
33221 * http://www.apache.org/licenses/LICENSE-2.0
33222 *
33223 * Unless required by applicable law or agreed to in writing, software
33224 * distributed under the License is distributed on an "AS IS" BASIS,
33225 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33226 * See the License for the specific language governing permissions and
33227 * limitations under the License.
33228 * =============================================================================
33229 */
// Chainable element-wise atan2 of this tensor over `b`; delegates to `atan2`.
getGlobalTensorClass().prototype.atan2 = function (b) {
    const self = this;
    self.throwIfDisposed();
    return atan2$2(self, b);
};
33234
33235 /**
33236 * @license
33237 * Copyright 2020 Google LLC. All Rights Reserved.
33238 * Licensed under the Apache License, Version 2.0 (the "License");
33239 * you may not use this file except in compliance with the License.
33240 * You may obtain a copy of the License at
33241 *
33242 * http://www.apache.org/licenses/LICENSE-2.0
33243 *
33244 * Unless required by applicable law or agreed to in writing, software
33245 * distributed under the License is distributed on an "AS IS" BASIS,
33246 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33247 * See the License for the specific language governing permissions and
33248 * limitations under the License.
33249 * =============================================================================
33250 */
// Chainable element-wise inverse hyperbolic tangent; delegates to `atanh`.
getGlobalTensorClass().prototype.atanh = function () {
    const self = this;
    self.throwIfDisposed();
    return atanh$2(self);
};
33255
// Chainable average pooling; forwards all arguments to the `avgPool` op.
getGlobalTensorClass().prototype.avgPool =
    function (filterSize, strides, pad, dimRoundingMode) {
        const self = this;
        self.throwIfDisposed();
        return avgPool$2(self, filterSize, strides, pad, dimRoundingMode);
    };
33261
33262 /**
33263 * @license
33264 * Copyright 2020 Google LLC. All Rights Reserved.
33265 * Licensed under the Apache License, Version 2.0 (the "License");
33266 * you may not use this file except in compliance with the License.
33267 * You may obtain a copy of the License at
33268 *
33269 * http://www.apache.org/licenses/LICENSE-2.0
33270 *
33271 * Unless required by applicable law or agreed to in writing, software
33272 * distributed under the License is distributed on an "AS IS" BASIS,
33273 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33274 * See the License for the specific language governing permissions and
33275 * limitations under the License.
33276 * =============================================================================
33277 */
// Chainable batch-to-space rearrangement; delegates to `batchToSpaceND`.
getGlobalTensorClass().prototype.batchToSpaceND = function (blockShape, crops) {
    const self = this;
    self.throwIfDisposed();
    return batchToSpaceND$2(self, blockShape, crops);
};
33282
33283 /**
33284 * @license
33285 * Copyright 2020 Google LLC. All Rights Reserved.
33286 * Licensed under the Apache License, Version 2.0 (the "License");
33287 * you may not use this file except in compliance with the License.
33288 * You may obtain a copy of the License at
33289 *
33290 * http://www.apache.org/licenses/LICENSE-2.0
33291 *
33292 * Unless required by applicable law or agreed to in writing, software
33293 * distributed under the License is distributed on an "AS IS" BASIS,
33294 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33295 * See the License for the specific language governing permissions and
33296 * limitations under the License.
33297 * =============================================================================
33298 */
// Chainable batch normalization; forwards all arguments to `batchNorm`.
getGlobalTensorClass().prototype.batchNorm = function (mean, variance, offset, scale, varianceEpsilon) {
    const self = this;
    self.throwIfDisposed();
    return batchNorm$2(self, mean, variance, offset, scale, varianceEpsilon);
};
33303
33304 /**
33305 * @license
33306 * Copyright 2020 Google LLC. All Rights Reserved.
33307 * Licensed under the Apache License, Version 2.0 (the "License");
33308 * you may not use this file except in compliance with the License.
33309 * You may obtain a copy of the License at
33310 *
33311 * http://www.apache.org/licenses/LICENSE-2.0
33312 *
33313 * Unless required by applicable law or agreed to in writing, software
33314 * distributed under the License is distributed on an "AS IS" BASIS,
33315 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33316 * See the License for the specific language governing permissions and
33317 * limitations under the License.
33318 * =============================================================================
33319 */
// Chainable broadcast to a target shape; delegates to `broadcastTo`.
getGlobalTensorClass().prototype.broadcastTo = function (shape) {
    const self = this;
    self.throwIfDisposed();
    return broadcastTo(self, shape);
};
33324
33325 /**
33326 * @license
33327 * Copyright 2020 Google LLC. All Rights Reserved.
33328 * Licensed under the Apache License, Version 2.0 (the "License");
33329 * you may not use this file except in compliance with the License.
33330 * You may obtain a copy of the License at
33331 *
33332 * http://www.apache.org/licenses/LICENSE-2.0
33333 *
33334 * Unless required by applicable law or agreed to in writing, software
33335 * distributed under the License is distributed on an "AS IS" BASIS,
33336 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33337 * See the License for the specific language governing permissions and
33338 * limitations under the License.
33339 * =============================================================================
33340 */
// Chainable dtype cast; delegates to the functional `cast` op.
getGlobalTensorClass().prototype.cast = function (dtype) {
    const self = this;
    self.throwIfDisposed();
    return cast$3(self, dtype);
};
33345
33346 /**
33347 * @license
33348 * Copyright 2020 Google LLC. All Rights Reserved.
33349 * Licensed under the Apache License, Version 2.0 (the "License");
33350 * you may not use this file except in compliance with the License.
33351 * You may obtain a copy of the License at
33352 *
33353 * http://www.apache.org/licenses/LICENSE-2.0
33354 *
33355 * Unless required by applicable law or agreed to in writing, software
33356 * distributed under the License is distributed on an "AS IS" BASIS,
33357 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33358 * See the License for the specific language governing permissions and
33359 * limitations under the License.
33360 * =============================================================================
33361 */
// Chainable element-wise ceiling; delegates to the functional `ceil` op.
getGlobalTensorClass().prototype.ceil = function () {
    const self = this;
    self.throwIfDisposed();
    return ceil$2(self);
};
33366
33367 /**
33368 * @license
33369 * Copyright 2020 Google LLC. All Rights Reserved.
33370 * Licensed under the Apache License, Version 2.0 (the "License");
33371 * you may not use this file except in compliance with the License.
33372 * You may obtain a copy of the License at
33373 *
33374 * http://www.apache.org/licenses/LICENSE-2.0
33375 *
33376 * Unless required by applicable law or agreed to in writing, software
33377 * distributed under the License is distributed on an "AS IS" BASIS,
33378 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33379 * See the License for the specific language governing permissions and
33380 * limitations under the License.
33381 * =============================================================================
33382 */
// Chainable element-wise clipping to [min, max]; delegates to `clipByValue`.
getGlobalTensorClass().prototype.clipByValue = function (min, max) {
    const self = this;
    self.throwIfDisposed();
    return clipByValue$2(self, min, max);
};
33387
33388 /**
33389 * @license
33390 * Copyright 2020 Google LLC. All Rights Reserved.
33391 * Licensed under the Apache License, Version 2.0 (the "License");
33392 * you may not use this file except in compliance with the License.
33393 * You may obtain a copy of the License at
33394 *
33395 * http://www.apache.org/licenses/LICENSE-2.0
33396 *
33397 * Unless required by applicable law or agreed to in writing, software
33398 * distributed under the License is distributed on an "AS IS" BASIS,
33399 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33400 * See the License for the specific language governing permissions and
33401 * limitations under the License.
33402 * =============================================================================
33403 */
// Chainable concatenation of this tensor with `x` (a tensor or an array of
// tensors) along `axis`; delegates to the functional `concat` op.
getGlobalTensorClass().prototype.concat = function (x, axis) {
    const self = this;
    self.throwIfDisposed();
    // Normalize a single tensor into a one-element list without mutating `x`.
    const others = x instanceof Tensor ? [x] : x;
    return concat$2([self, ...others], axis);
};
33411
33412 /**
33413 * @license
33414 * Copyright 2020 Google LLC. All Rights Reserved.
33415 * Licensed under the Apache License, Version 2.0 (the "License");
33416 * you may not use this file except in compliance with the License.
33417 * You may obtain a copy of the License at
33418 *
33419 * http://www.apache.org/licenses/LICENSE-2.0
33420 *
33421 * Unless required by applicable law or agreed to in writing, software
33422 * distributed under the License is distributed on an "AS IS" BASIS,
33423 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33424 * See the License for the specific language governing permissions and
33425 * limitations under the License.
33426 * =============================================================================
33427 */
// Chainable 1-D convolution; forwards all arguments to the `conv1d` op.
getGlobalTensorClass().prototype.conv1d = function (filter, stride, pad, dataFormat, dilation, dimRoundingMode) {
    const self = this;
    self.throwIfDisposed();
    return conv1d$2(self, filter, stride, pad, dataFormat, dilation, dimRoundingMode);
};
33432
33433 /**
33434 * @license
33435 * Copyright 2020 Google LLC. All Rights Reserved.
33436 * Licensed under the Apache License, Version 2.0 (the "License");
33437 * you may not use this file except in compliance with the License.
33438 * You may obtain a copy of the License at
33439 *
33440 * http://www.apache.org/licenses/LICENSE-2.0
33441 *
33442 * Unless required by applicable law or agreed to in writing, software
33443 * distributed under the License is distributed on an "AS IS" BASIS,
33444 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33445 * See the License for the specific language governing permissions and
33446 * limitations under the License.
33447 * =============================================================================
33448 */
// Chainable transposed 2-D convolution; forwards to `conv2dTranspose`.
getGlobalTensorClass().prototype.conv2dTranspose =
    function (filter, outputShape, strides, pad, dimRoundingMode) {
        const self = this;
        self.throwIfDisposed();
        return conv2dTranspose$1(self, filter, outputShape, strides, pad, dimRoundingMode);
    };
33454
33455 /**
33456 * @license
33457 * Copyright 2020 Google LLC. All Rights Reserved.
33458 * Licensed under the Apache License, Version 2.0 (the "License");
33459 * you may not use this file except in compliance with the License.
33460 * You may obtain a copy of the License at
33461 *
33462 * http://www.apache.org/licenses/LICENSE-2.0
33463 *
33464 * Unless required by applicable law or agreed to in writing, software
33465 * distributed under the License is distributed on an "AS IS" BASIS,
33466 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33467 * See the License for the specific language governing permissions and
33468 * limitations under the License.
33469 * =============================================================================
33470 */
// Chainable 2-D convolution; forwards all arguments to the `conv2d` op.
getGlobalTensorClass().prototype.conv2d = function (filter, strides, pad, dataFormat, dilations, dimRoundingMode) {
    const self = this;
    self.throwIfDisposed();
    return conv2d$4(self, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
};
33475
33476 /**
33477 * @license
33478 * Copyright 2020 Google LLC. All Rights Reserved.
33479 * Licensed under the Apache License, Version 2.0 (the "License");
33480 * you may not use this file except in compliance with the License.
33481 * You may obtain a copy of the License at
33482 *
33483 * http://www.apache.org/licenses/LICENSE-2.0
33484 *
33485 * Unless required by applicable law or agreed to in writing, software
33486 * distributed under the License is distributed on an "AS IS" BASIS,
33487 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33488 * See the License for the specific language governing permissions and
33489 * limitations under the License.
33490 * =============================================================================
33491 */
// Chainable element-wise cosine; delegates to the functional `cos` op.
getGlobalTensorClass().prototype.cos = function () {
    const self = this;
    self.throwIfDisposed();
    return cos$2(self);
};
33496
33497 /**
33498 * @license
33499 * Copyright 2020 Google LLC. All Rights Reserved.
33500 * Licensed under the Apache License, Version 2.0 (the "License");
33501 * you may not use this file except in compliance with the License.
33502 * You may obtain a copy of the License at
33503 *
33504 * http://www.apache.org/licenses/LICENSE-2.0
33505 *
33506 * Unless required by applicable law or agreed to in writing, software
33507 * distributed under the License is distributed on an "AS IS" BASIS,
33508 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33509 * See the License for the specific language governing permissions and
33510 * limitations under the License.
33511 * =============================================================================
33512 */
// Chainable element-wise hyperbolic cosine; delegates to `cosh`.
getGlobalTensorClass().prototype.cosh = function () {
    const self = this;
    self.throwIfDisposed();
    return cosh$2(self);
};
33517
33518 /**
33519 * @license
33520 * Copyright 2022 Google LLC. All Rights Reserved.
33521 * Licensed under the Apache License, Version 2.0 (the 'License');
33522 * you may not use this file except in compliance with the License.
33523 * You may obtain a copy of the License at
33524 *
33525 * http://www.apache.org/licenses/LICENSE-2.0
33526 *
33527 * Unless required by applicable law or agreed to in writing, software
33528 * distributed under the License is distributed on an 'AS IS' BASIS,
33529 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33530 * See the License for the specific language governing permissions and
33531 * limitations under the License.
33532 * =============================================================================
33533 */
// Chainable cumulative product along `axis`; delegates to `cumprod`.
getGlobalTensorClass().prototype.cumprod = function (axis, exclusive, reverse) {
    const self = this;
    self.throwIfDisposed();
    return cumprod$2(self, axis, exclusive, reverse);
};
33538
33539 /**
33540 * @license
33541 * Copyright 2020 Google LLC. All Rights Reserved.
33542 * Licensed under the Apache License, Version 2.0 (the "License");
33543 * you may not use this file except in compliance with the License.
33544 * You may obtain a copy of the License at
33545 *
33546 * http://www.apache.org/licenses/LICENSE-2.0
33547 *
33548 * Unless required by applicable law or agreed to in writing, software
33549 * distributed under the License is distributed on an "AS IS" BASIS,
33550 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33551 * See the License for the specific language governing permissions and
33552 * limitations under the License.
33553 * =============================================================================
33554 */
// Chainable cumulative sum along `axis`; delegates to `cumsum`.
getGlobalTensorClass().prototype.cumsum = function (axis, exclusive, reverse) {
    const self = this;
    self.throwIfDisposed();
    return cumsum$2(self, axis, exclusive, reverse);
};
33559
33560 /**
33561 * @license
33562 * Copyright 2020 Google LLC. All Rights Reserved.
33563 * Licensed under the Apache License, Version 2.0 (the "License");
33564 * you may not use this file except in compliance with the License.
33565 * You may obtain a copy of the License at
33566 *
33567 * http://www.apache.org/licenses/LICENSE-2.0
33568 *
33569 * Unless required by applicable law or agreed to in writing, software
33570 * distributed under the License is distributed on an "AS IS" BASIS,
33571 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33572 * See the License for the specific language governing permissions and
33573 * limitations under the License.
33574 * =============================================================================
33575 */
// Chainable depth-to-space rearrangement; delegates to `depthToSpace`.
getGlobalTensorClass().prototype.depthToSpace = function (blockSize, dataFormat) {
    const self = this;
    self.throwIfDisposed();
    return depthToSpace$2(self, blockSize, dataFormat);
};
33580
33581 /**
33582 * @license
33583 * Copyright 2020 Google LLC. All Rights Reserved.
33584 * Licensed under the Apache License, Version 2.0 (the "License");
33585 * you may not use this file except in compliance with the License.
33586 * You may obtain a copy of the License at
33587 *
33588 * http://www.apache.org/licenses/LICENSE-2.0
33589 *
33590 * Unless required by applicable law or agreed to in writing, software
33591 * distributed under the License is distributed on an "AS IS" BASIS,
33592 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33593 * See the License for the specific language governing permissions and
33594 * limitations under the License.
33595 * =============================================================================
33596 */
// Chainable depthwise 2-D convolution; forwards to `depthwiseConv2d`.
getGlobalTensorClass().prototype.depthwiseConv2d =
    function (filter, strides, pad, dataFormat, dilations, dimRoundingMode) {
        const self = this;
        self.throwIfDisposed();
        return depthwiseConv2d$3(self, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
    };
33602
33603 /**
33604 * @license
33605 * Copyright 2020 Google LLC. All Rights Reserved.
33606 * Licensed under the Apache License, Version 2.0 (the "License");
33607 * you may not use this file except in compliance with the License.
33608 * You may obtain a copy of the License at
33609 *
33610 * http://www.apache.org/licenses/LICENSE-2.0
33611 *
33612 * Unless required by applicable law or agreed to in writing, software
33613 * distributed under the License is distributed on an "AS IS" BASIS,
33614 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33615 * See the License for the specific language governing permissions and
33616 * limitations under the License.
33617 * =============================================================================
33618 */
// Chainable grayscale 2-D dilation; forwards to the `dilation2d` op.
getGlobalTensorClass().prototype.dilation2d =
    function (filter, strides, pad, dilations, dataFormat) {
        const self = this;
        self.throwIfDisposed();
        return dilation2d(self, filter, strides, pad, dilations, dataFormat);
    };
33624
33625 /**
33626 * @license
33627 * Copyright 2020 Google LLC. All Rights Reserved.
33628 * Licensed under the Apache License, Version 2.0 (the "License");
33629 * you may not use this file except in compliance with the License.
33630 * You may obtain a copy of the License at
33631 *
33632 * http://www.apache.org/licenses/LICENSE-2.0
33633 *
33634 * Unless required by applicable law or agreed to in writing, software
33635 * distributed under the License is distributed on an "AS IS" BASIS,
33636 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33637 * See the License for the specific language governing permissions and
33638 * limitations under the License.
33639 * =============================================================================
33640 */
// Chainable division returning 0 where the divisor is 0; delegates to
// the functional `divNoNan` op.
getGlobalTensorClass().prototype.divNoNan = function (b) {
    const self = this;
    self.throwIfDisposed();
    return divNoNan(self, b);
};
33645
33646 /**
33647 * @license
33648 * Copyright 2020 Google LLC. All Rights Reserved.
33649 * Licensed under the Apache License, Version 2.0 (the "License");
33650 * you may not use this file except in compliance with the License.
33651 * You may obtain a copy of the License at
33652 *
33653 * http://www.apache.org/licenses/LICENSE-2.0
33654 *
33655 * Unless required by applicable law or agreed to in writing, software
33656 * distributed under the License is distributed on an "AS IS" BASIS,
33657 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33658 * See the License for the specific language governing permissions and
33659 * limitations under the License.
33660 * =============================================================================
33661 */
// Chainable element-wise division by `b`; delegates to the `div` op.
getGlobalTensorClass().prototype.div = function (b) {
    const self = this;
    self.throwIfDisposed();
    return div$1(self, b);
};
33666
33667 /**
33668 * @license
33669 * Copyright 2020 Google LLC. All Rights Reserved.
33670 * Licensed under the Apache License, Version 2.0 (the "License");
33671 * you may not use this file except in compliance with the License.
33672 * You may obtain a copy of the License at
33673 *
33674 * http://www.apache.org/licenses/LICENSE-2.0
33675 *
33676 * Unless required by applicable law or agreed to in writing, software
33677 * distributed under the License is distributed on an "AS IS" BASIS,
33678 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33679 * See the License for the specific language governing permissions and
33680 * limitations under the License.
33681 * =============================================================================
33682 */
// Chainable dot product with `b`; delegates to the functional `dot` op.
getGlobalTensorClass().prototype.dot = function (b) {
    const self = this;
    self.throwIfDisposed();
    return dot$2(self, b);
};
33687
33688 /**
33689 * @license
33690 * Copyright 2020 Google LLC. All Rights Reserved.
33691 * Licensed under the Apache License, Version 2.0 (the "License");
33692 * you may not use this file except in compliance with the License.
33693 * You may obtain a copy of the License at
33694 *
33695 * http://www.apache.org/licenses/LICENSE-2.0
33696 *
33697 * Unless required by applicable law or agreed to in writing, software
33698 * distributed under the License is distributed on an "AS IS" BASIS,
33699 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33700 * See the License for the specific language governing permissions and
33701 * limitations under the License.
33702 * =============================================================================
33703 */
// Chainable exponential linear unit activation; delegates to `elu`.
getGlobalTensorClass().prototype.elu = function () {
    const self = this;
    self.throwIfDisposed();
    return elu$4(self);
};
33708
33709 /**
33710 * @license
33711 * Copyright 2020 Google LLC. All Rights Reserved.
33712 * Licensed under the Apache License, Version 2.0 (the "License");
33713 * you may not use this file except in compliance with the License.
33714 * You may obtain a copy of the License at
33715 *
33716 * http://www.apache.org/licenses/LICENSE-2.0
33717 *
33718 * Unless required by applicable law or agreed to in writing, software
33719 * distributed under the License is distributed on an "AS IS" BASIS,
33720 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33721 * See the License for the specific language governing permissions and
33722 * limitations under the License.
33723 * =============================================================================
33724 */
// Chainable element-wise equality comparison with `b`; delegates to `equal`.
getGlobalTensorClass().prototype.equal = function (b) {
    const self = this;
    self.throwIfDisposed();
    return equal$2(self, b);
};
33729
33730 /**
33731 * @license
33732 * Copyright 2020 Google LLC. All Rights Reserved.
33733 * Licensed under the Apache License, Version 2.0 (the "License");
33734 * you may not use this file except in compliance with the License.
33735 * You may obtain a copy of the License at
33736 *
33737 * http://www.apache.org/licenses/LICENSE-2.0
33738 *
33739 * Unless required by applicable law or agreed to in writing, software
33740 * distributed under the License is distributed on an "AS IS" BASIS,
33741 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33742 * See the License for the specific language governing permissions and
33743 * limitations under the License.
33744 * =============================================================================
33745 */
// Chainable element-wise Gauss error function; delegates to `erf`.
getGlobalTensorClass().prototype.erf = function () {
    const self = this;
    self.throwIfDisposed();
    return erf$2(self);
};
33750
33751 /**
33752 * @license
33753 * Copyright 2021 Google LLC. All Rights Reserved.
33754 * Licensed under the Apache License, Version 2.0 (the "License");
33755 * you may not use this file except in compliance with the License.
33756 * You may obtain a copy of the License at
33757 *
33758 * http://www.apache.org/licenses/LICENSE-2.0
33759 *
33760 * Unless required by applicable law or agreed to in writing, software
33761 * distributed under the License is distributed on an "AS IS" BASIS,
33762 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33763 * See the License for the specific language governing permissions and
33764 * limitations under the License.
33765 * =============================================================================
33766 */
// Chainable Euclidean norm over `axis`; delegates to `euclideanNorm`.
getGlobalTensorClass().prototype.euclideanNorm = function (axis, keepDims) {
    const self = this;
    self.throwIfDisposed();
    return euclideanNorm(self, axis, keepDims);
};
33771
33772 /**
33773 * @license
33774 * Copyright 2020 Google LLC. All Rights Reserved.
33775 * Licensed under the Apache License, Version 2.0 (the "License");
33776 * you may not use this file except in compliance with the License.
33777 * You may obtain a copy of the License at
33778 *
33779 * http://www.apache.org/licenses/LICENSE-2.0
33780 *
33781 * Unless required by applicable law or agreed to in writing, software
33782 * distributed under the License is distributed on an "AS IS" BASIS,
33783 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33784 * See the License for the specific language governing permissions and
33785 * limitations under the License.
33786 * =============================================================================
33787 */
// Chainable element-wise exponential; delegates to the `exp` op.
getGlobalTensorClass().prototype.exp = function () {
    const self = this;
    self.throwIfDisposed();
    return exp$2(self);
};
33792
33793 /**
33794 * @license
33795 * Copyright 2020 Google LLC. All Rights Reserved.
33796 * Licensed under the Apache License, Version 2.0 (the "License");
33797 * you may not use this file except in compliance with the License.
33798 * You may obtain a copy of the License at
33799 *
33800 * http://www.apache.org/licenses/LICENSE-2.0
33801 *
33802 * Unless required by applicable law or agreed to in writing, software
33803 * distributed under the License is distributed on an "AS IS" BASIS,
33804 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33805 * See the License for the specific language governing permissions and
33806 * limitations under the License.
33807 * =============================================================================
33808 */
33809 getGlobalTensorClass().prototype.expandDims = function (axis) {
33810 this.throwIfDisposed();
33811 return expandDims$3(this, axis);
33812 };
33813
33814 /**
33815 * @license
33816 * Copyright 2020 Google LLC. All Rights Reserved.
33817 * Licensed under the Apache License, Version 2.0 (the "License");
33818 * you may not use this file except in compliance with the License.
33819 * You may obtain a copy of the License at
33820 *
33821 * http://www.apache.org/licenses/LICENSE-2.0
33822 *
33823 * Unless required by applicable law or agreed to in writing, software
33824 * distributed under the License is distributed on an "AS IS" BASIS,
33825 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33826 * See the License for the specific language governing permissions and
33827 * limitations under the License.
33828 * =============================================================================
33829 */
33830 getGlobalTensorClass().prototype.expm1 = function () {
33831 this.throwIfDisposed();
33832 return expm1$2(this);
33833 };
33834
33835 /**
33836 * @license
33837 * Copyright 2020 Google LLC. All Rights Reserved.
33838 * Licensed under the Apache License, Version 2.0 (the "License");
33839 * you may not use this file except in compliance with the License.
33840 * You may obtain a copy of the License at
33841 *
33842 * http://www.apache.org/licenses/LICENSE-2.0
33843 *
33844 * Unless required by applicable law or agreed to in writing, software
33845 * distributed under the License is distributed on an "AS IS" BASIS,
33846 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33847 * See the License for the specific language governing permissions and
33848 * limitations under the License.
33849 * =============================================================================
33850 */
33851 getGlobalTensorClass().prototype.fft = function () {
33852 this.throwIfDisposed();
33853 return fft$2(this);
33854 };
33855
33856 /**
33857 * @license
33858 * Copyright 2020 Google LLC. All Rights Reserved.
33859 * Licensed under the Apache License, Version 2.0 (the "License");
33860 * you may not use this file except in compliance with the License.
33861 * You may obtain a copy of the License at
33862 *
33863 * http://www.apache.org/licenses/LICENSE-2.0
33864 *
33865 * Unless required by applicable law or agreed to in writing, software
33866 * distributed under the License is distributed on an "AS IS" BASIS,
33867 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33868 * See the License for the specific language governing permissions and
33869 * limitations under the License.
33870 * =============================================================================
33871 */
33872 /**
33873 * Flatten a Tensor to a 1D array.
33874 * @doc {heading: 'Tensors', subheading: 'Classes'}
33875 */
33876 getGlobalTensorClass().prototype.flatten = function () {
33877 this.throwIfDisposed();
33878 return reshape$3(this, [this.size]);
33879 };
33880
33881 /**
33882 * @license
33883 * Copyright 2020 Google LLC. All Rights Reserved.
33884 * Licensed under the Apache License, Version 2.0 (the "License");
33885 * you may not use this file except in compliance with the License.
33886 * You may obtain a copy of the License at
33887 *
33888 * http://www.apache.org/licenses/LICENSE-2.0
33889 *
33890 * Unless required by applicable law or agreed to in writing, software
33891 * distributed under the License is distributed on an "AS IS" BASIS,
33892 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33893 * See the License for the specific language governing permissions and
33894 * limitations under the License.
33895 * =============================================================================
33896 */
33897 getGlobalTensorClass().prototype.floor = function () {
33898 this.throwIfDisposed();
33899 return floor$2(this);
33900 };
33901
33902 /**
33903 * @license
33904 * Copyright 2020 Google LLC. All Rights Reserved.
33905 * Licensed under the Apache License, Version 2.0 (the "License");
33906 * you may not use this file except in compliance with the License.
33907 * You may obtain a copy of the License at
33908 *
33909 * http://www.apache.org/licenses/LICENSE-2.0
33910 *
33911 * Unless required by applicable law or agreed to in writing, software
33912 * distributed under the License is distributed on an "AS IS" BASIS,
33913 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33914 * See the License for the specific language governing permissions and
33915 * limitations under the License.
33916 * =============================================================================
33917 */
33918 getGlobalTensorClass().prototype.floorDiv = function (b) {
33919 this.throwIfDisposed();
33920 return floorDiv$2(this, b);
33921 };
33922
33923 /**
33924 * @license
33925 * Copyright 2020 Google LLC. All Rights Reserved.
33926 * Licensed under the Apache License, Version 2.0 (the "License");
33927 * you may not use this file except in compliance with the License.
33928 * You may obtain a copy of the License at
33929 *
33930 * http://www.apache.org/licenses/LICENSE-2.0
33931 *
33932 * Unless required by applicable law or agreed to in writing, software
33933 * distributed under the License is distributed on an "AS IS" BASIS,
33934 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33935 * See the License for the specific language governing permissions and
33936 * limitations under the License.
33937 * =============================================================================
33938 */
33939 getGlobalTensorClass().prototype.gather = function (indices, axis, batchDims) {
33940 this.throwIfDisposed();
33941 return gather$1(this, indices, axis, batchDims);
33942 };
33943
33944 /**
33945 * @license
33946 * Copyright 2020 Google LLC. All Rights Reserved.
33947 * Licensed under the Apache License, Version 2.0 (the "License");
33948 * you may not use this file except in compliance with the License.
33949 * You may obtain a copy of the License at
33950 *
33951 * http://www.apache.org/licenses/LICENSE-2.0
33952 *
33953 * Unless required by applicable law or agreed to in writing, software
33954 * distributed under the License is distributed on an "AS IS" BASIS,
33955 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33956 * See the License for the specific language governing permissions and
33957 * limitations under the License.
33958 * =============================================================================
33959 */
33960 getGlobalTensorClass().prototype.greaterEqual = function (b) {
33961 this.throwIfDisposed();
33962 return greaterEqual$2(this, b);
33963 };
33964
33965 /**
33966 * @license
33967 * Copyright 2020 Google LLC. All Rights Reserved.
33968 * Licensed under the Apache License, Version 2.0 (the "License");
33969 * you may not use this file except in compliance with the License.
33970 * You may obtain a copy of the License at
33971 *
33972 * http://www.apache.org/licenses/LICENSE-2.0
33973 *
33974 * Unless required by applicable law or agreed to in writing, software
33975 * distributed under the License is distributed on an "AS IS" BASIS,
33976 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33977 * See the License for the specific language governing permissions and
33978 * limitations under the License.
33979 * =============================================================================
33980 */
33981 getGlobalTensorClass().prototype.greater = function (b) {
33982 this.throwIfDisposed();
33983 return greater$3(this, b);
33984 };
33985
33986 /**
33987 * @license
33988 * Copyright 2020 Google LLC. All Rights Reserved.
33989 * Licensed under the Apache License, Version 2.0 (the "License");
33990 * you may not use this file except in compliance with the License.
33991 * You may obtain a copy of the License at
33992 *
33993 * http://www.apache.org/licenses/LICENSE-2.0
33994 *
33995 * Unless required by applicable law or agreed to in writing, software
33996 * distributed under the License is distributed on an "AS IS" BASIS,
33997 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33998 * See the License for the specific language governing permissions and
33999 * limitations under the License.
34000 * =============================================================================
34001 */
34002 getGlobalTensorClass().prototype.ifft = function () {
34003 this.throwIfDisposed();
34004 return ifft$2(this);
34005 };
34006
34007 /**
34008 * @license
34009 * Copyright 2020 Google LLC. All Rights Reserved.
34010 * Licensed under the Apache License, Version 2.0 (the "License");
34011 * you may not use this file except in compliance with the License.
34012 * You may obtain a copy of the License at
34013 *
34014 * http://www.apache.org/licenses/LICENSE-2.0
34015 *
34016 * Unless required by applicable law or agreed to in writing, software
34017 * distributed under the License is distributed on an "AS IS" BASIS,
34018 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34019 * See the License for the specific language governing permissions and
34020 * limitations under the License.
34021 * =============================================================================
34022 */
34023 getGlobalTensorClass().prototype.irfft = function () {
34024 this.throwIfDisposed();
34025 return irfft(this);
34026 };
34027
34028 /**
34029 * @license
34030 * Copyright 2020 Google LLC. All Rights Reserved.
34031 * Licensed under the Apache License, Version 2.0 (the "License");
34032 * you may not use this file except in compliance with the License.
34033 * You may obtain a copy of the License at
34034 *
34035 * http://www.apache.org/licenses/LICENSE-2.0
34036 *
34037 * Unless required by applicable law or agreed to in writing, software
34038 * distributed under the License is distributed on an "AS IS" BASIS,
34039 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34040 * See the License for the specific language governing permissions and
34041 * limitations under the License.
34042 * =============================================================================
34043 */
34044 getGlobalTensorClass().prototype.isFinite = function () {
34045 this.throwIfDisposed();
34046 return isFinite$3(this);
34047 };
34048
34049 /**
34050 * @license
34051 * Copyright 2020 Google LLC. All Rights Reserved.
34052 * Licensed under the Apache License, Version 2.0 (the "License");
34053 * you may not use this file except in compliance with the License.
34054 * You may obtain a copy of the License at
34055 *
34056 * http://www.apache.org/licenses/LICENSE-2.0
34057 *
34058 * Unless required by applicable law or agreed to in writing, software
34059 * distributed under the License is distributed on an "AS IS" BASIS,
34060 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34061 * See the License for the specific language governing permissions and
34062 * limitations under the License.
34063 * =============================================================================
34064 */
34065 getGlobalTensorClass().prototype.isInf = function () {
34066 this.throwIfDisposed();
34067 return isInf$2(this);
34068 };
34069
34070 /**
34071 * @license
34072 * Copyright 2020 Google LLC. All Rights Reserved.
34073 * Licensed under the Apache License, Version 2.0 (the "License");
34074 * you may not use this file except in compliance with the License.
34075 * You may obtain a copy of the License at
34076 *
34077 * http://www.apache.org/licenses/LICENSE-2.0
34078 *
34079 * Unless required by applicable law or agreed to in writing, software
34080 * distributed under the License is distributed on an "AS IS" BASIS,
34081 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34082 * See the License for the specific language governing permissions and
34083 * limitations under the License.
34084 * =============================================================================
34085 */
34086 getGlobalTensorClass().prototype.isNaN = function () {
34087 this.throwIfDisposed();
34088 return isNaN$3(this);
34089 };
34090
34091 /**
34092 * @license
34093 * Copyright 2020 Google LLC. All Rights Reserved.
34094 * Licensed under the Apache License, Version 2.0 (the "License");
34095 * you may not use this file except in compliance with the License.
34096 * You may obtain a copy of the License at
34097 *
34098 * http://www.apache.org/licenses/LICENSE-2.0
34099 *
34100 * Unless required by applicable law or agreed to in writing, software
34101 * distributed under the License is distributed on an "AS IS" BASIS,
34102 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34103 * See the License for the specific language governing permissions and
34104 * limitations under the License.
34105 * =============================================================================
34106 */
34107 getGlobalTensorClass().prototype.leakyRelu = function (alpha) {
34108 this.throwIfDisposed();
34109 return leakyRelu$2(this, alpha);
34110 };
34111
34112 /**
34113 * @license
34114 * Copyright 2020 Google LLC. All Rights Reserved.
34115 * Licensed under the Apache License, Version 2.0 (the "License");
34116 * you may not use this file except in compliance with the License.
34117 * You may obtain a copy of the License at
34118 *
34119 * http://www.apache.org/licenses/LICENSE-2.0
34120 *
34121 * Unless required by applicable law or agreed to in writing, software
34122 * distributed under the License is distributed on an "AS IS" BASIS,
34123 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34124 * See the License for the specific language governing permissions and
34125 * limitations under the License.
34126 * =============================================================================
34127 */
34128 getGlobalTensorClass().prototype.lessEqual = function (b) {
34129 this.throwIfDisposed();
34130 return lessEqual$2(this, b);
34131 };
34132
34133 /**
34134 * @license
34135 * Copyright 2020 Google LLC. All Rights Reserved.
34136 * Licensed under the Apache License, Version 2.0 (the "License");
34137 * you may not use this file except in compliance with the License.
34138 * You may obtain a copy of the License at
34139 *
34140 * http://www.apache.org/licenses/LICENSE-2.0
34141 *
34142 * Unless required by applicable law or agreed to in writing, software
34143 * distributed under the License is distributed on an "AS IS" BASIS,
34144 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34145 * See the License for the specific language governing permissions and
34146 * limitations under the License.
34147 * =============================================================================
34148 */
34149 getGlobalTensorClass().prototype.less = function (b) {
34150 this.throwIfDisposed();
34151 return less$3(this, b);
34152 };
34153
34154 /**
34155 * @license
34156 * Copyright 2020 Google LLC. All Rights Reserved.
34157 * Licensed under the Apache License, Version 2.0 (the "License");
34158 * you may not use this file except in compliance with the License.
34159 * You may obtain a copy of the License at
34160 *
34161 * http://www.apache.org/licenses/LICENSE-2.0
34162 *
34163 * Unless required by applicable law or agreed to in writing, software
34164 * distributed under the License is distributed on an "AS IS" BASIS,
34165 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34166 * See the License for the specific language governing permissions and
34167 * limitations under the License.
34168 * =============================================================================
34169 */
34170 getGlobalTensorClass().prototype.localResponseNormalization =
34171 function (depthRadius, bias, alpha, beta) {
34172 this.throwIfDisposed();
34173 return localResponseNormalization(this, depthRadius, bias, alpha, beta);
34174 };
34175
34176 /**
34177 * @license
34178 * Copyright 2020 Google LLC. All Rights Reserved.
34179 * Licensed under the Apache License, Version 2.0 (the "License");
34180 * you may not use this file except in compliance with the License.
34181 * You may obtain a copy of the License at
34182 *
34183 * http://www.apache.org/licenses/LICENSE-2.0
34184 *
34185 * Unless required by applicable law or agreed to in writing, software
34186 * distributed under the License is distributed on an "AS IS" BASIS,
34187 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34188 * See the License for the specific language governing permissions and
34189 * limitations under the License.
34190 * =============================================================================
34191 */
34192 getGlobalTensorClass().prototype.logSigmoid = function () {
34193 this.throwIfDisposed();
34194 return logSigmoid(this);
34195 };
34196
34197 /**
34198 * @license
34199 * Copyright 2020 Google LLC. All Rights Reserved.
34200 * Licensed under the Apache License, Version 2.0 (the "License");
34201 * you may not use this file except in compliance with the License.
34202 * You may obtain a copy of the License at
34203 *
34204 * http://www.apache.org/licenses/LICENSE-2.0
34205 *
34206 * Unless required by applicable law or agreed to in writing, software
34207 * distributed under the License is distributed on an "AS IS" BASIS,
34208 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34209 * See the License for the specific language governing permissions and
34210 * limitations under the License.
34211 * =============================================================================
34212 */
34213 getGlobalTensorClass().prototype.logSoftmax = function (axis) {
34214 this.throwIfDisposed();
34215 return logSoftmax(this, axis);
34216 };
34217
34218 /**
34219 * @license
34220 * Copyright 2020 Google LLC. All Rights Reserved.
34221 * Licensed under the Apache License, Version 2.0 (the "License");
34222 * you may not use this file except in compliance with the License.
34223 * You may obtain a copy of the License at
34224 *
34225 * http://www.apache.org/licenses/LICENSE-2.0
34226 *
34227 * Unless required by applicable law or agreed to in writing, software
34228 * distributed under the License is distributed on an "AS IS" BASIS,
34229 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34230 * See the License for the specific language governing permissions and
34231 * limitations under the License.
34232 * =============================================================================
34233 */
34234 getGlobalTensorClass().prototype.logSumExp = function (axis, keepDims) {
34235 this.throwIfDisposed();
34236 return logSumExp(this, axis, keepDims);
34237 };
34238
34239 /**
34240 * @license
34241 * Copyright 2020 Google LLC. All Rights Reserved.
34242 * Licensed under the Apache License, Version 2.0 (the "License");
34243 * you may not use this file except in compliance with the License.
34244 * You may obtain a copy of the License at
34245 *
34246 * http://www.apache.org/licenses/LICENSE-2.0
34247 *
34248 * Unless required by applicable law or agreed to in writing, software
34249 * distributed under the License is distributed on an "AS IS" BASIS,
34250 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34251 * See the License for the specific language governing permissions and
34252 * limitations under the License.
34253 * =============================================================================
34254 */
34255 getGlobalTensorClass().prototype.log = function () {
34256 this.throwIfDisposed();
34257 return log$2(this);
34258 };
34259
34260 /**
34261 * @license
34262 * Copyright 2020 Google LLC. All Rights Reserved.
34263 * Licensed under the Apache License, Version 2.0 (the "License");
34264 * you may not use this file except in compliance with the License.
34265 * You may obtain a copy of the License at
34266 *
34267 * http://www.apache.org/licenses/LICENSE-2.0
34268 *
34269 * Unless required by applicable law or agreed to in writing, software
34270 * distributed under the License is distributed on an "AS IS" BASIS,
34271 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34272 * See the License for the specific language governing permissions and
34273 * limitations under the License.
34274 * =============================================================================
34275 */
34276 getGlobalTensorClass().prototype.log1p = function () {
34277 this.throwIfDisposed();
34278 return log1p$2(this);
34279 };
34280
34281 /**
34282 * @license
34283 * Copyright 2020 Google LLC. All Rights Reserved.
34284 * Licensed under the Apache License, Version 2.0 (the "License");
34285 * you may not use this file except in compliance with the License.
34286 * You may obtain a copy of the License at
34287 *
34288 * http://www.apache.org/licenses/LICENSE-2.0
34289 *
34290 * Unless required by applicable law or agreed to in writing, software
34291 * distributed under the License is distributed on an "AS IS" BASIS,
34292 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34293 * See the License for the specific language governing permissions and
34294 * limitations under the License.
34295 * =============================================================================
34296 */
34297 getGlobalTensorClass().prototype.logicalAnd = function (b) {
34298 this.throwIfDisposed();
34299 return logicalAnd$2(this, b);
34300 };
34301
34302 /**
34303 * @license
34304 * Copyright 2020 Google LLC. All Rights Reserved.
34305 * Licensed under the Apache License, Version 2.0 (the "License");
34306 * you may not use this file except in compliance with the License.
34307 * You may obtain a copy of the License at
34308 *
34309 * http://www.apache.org/licenses/LICENSE-2.0
34310 *
34311 * Unless required by applicable law or agreed to in writing, software
34312 * distributed under the License is distributed on an "AS IS" BASIS,
34313 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34314 * See the License for the specific language governing permissions and
34315 * limitations under the License.
34316 * =============================================================================
34317 */
34318 getGlobalTensorClass().prototype.logicalNot = function () {
34319 this.throwIfDisposed();
34320 return logicalNot$2(this);
34321 };
34322
34323 /**
34324 * @license
34325 * Copyright 2020 Google LLC. All Rights Reserved.
34326 * Licensed under the Apache License, Version 2.0 (the "License");
34327 * you may not use this file except in compliance with the License.
34328 * You may obtain a copy of the License at
34329 *
34330 * http://www.apache.org/licenses/LICENSE-2.0
34331 *
34332 * Unless required by applicable law or agreed to in writing, software
34333 * distributed under the License is distributed on an "AS IS" BASIS,
34334 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34335 * See the License for the specific language governing permissions and
34336 * limitations under the License.
34337 * =============================================================================
34338 */
34339 getGlobalTensorClass().prototype.logicalOr = function (b) {
34340 this.throwIfDisposed();
34341 return logicalOr$2(this, b);
34342 };
34343
34344 /**
34345 * @license
34346 * Copyright 2020 Google LLC. All Rights Reserved.
34347 * Licensed under the Apache License, Version 2.0 (the "License");
34348 * you may not use this file except in compliance with the License.
34349 * You may obtain a copy of the License at
34350 *
34351 * http://www.apache.org/licenses/LICENSE-2.0
34352 *
34353 * Unless required by applicable law or agreed to in writing, software
34354 * distributed under the License is distributed on an "AS IS" BASIS,
34355 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34356 * See the License for the specific language governing permissions and
34357 * limitations under the License.
34358 * =============================================================================
34359 */
34360 getGlobalTensorClass().prototype.logicalXor = function (b) {
34361 this.throwIfDisposed();
34362 return logicalXor(this, b);
34363 };
34364
34365 /**
34366 * @license
34367 * Copyright 2020 Google LLC. All Rights Reserved.
34368 * Licensed under the Apache License, Version 2.0 (the "License");
34369 * you may not use this file except in compliance with the License.
34370 * You may obtain a copy of the License at
34371 *
34372 * http://www.apache.org/licenses/LICENSE-2.0
34373 *
34374 * Unless required by applicable law or agreed to in writing, software
34375 * distributed under the License is distributed on an "AS IS" BASIS,
34376 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34377 * See the License for the specific language governing permissions and
34378 * limitations under the License.
34379 * =============================================================================
34380 */
34381 getGlobalTensorClass().prototype.matMul = function (b, transposeA, transposeB) {
34382 this.throwIfDisposed();
34383 return matMul$1(this, b, transposeA, transposeB);
34384 };
34385
34386 getGlobalTensorClass().prototype.maxPool =
34387 function (filterSize, strides, pad, dimRoundingMode) {
34388 this.throwIfDisposed();
34389 return maxPool$2(this, filterSize, strides, pad, dimRoundingMode);
34390 };
34391
34392 /**
34393 * @license
34394 * Copyright 2020 Google LLC. All Rights Reserved.
34395 * Licensed under the Apache License, Version 2.0 (the "License");
34396 * you may not use this file except in compliance with the License.
34397 * You may obtain a copy of the License at
34398 *
34399 * http://www.apache.org/licenses/LICENSE-2.0
34400 *
34401 * Unless required by applicable law or agreed to in writing, software
34402 * distributed under the License is distributed on an "AS IS" BASIS,
34403 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34404 * See the License for the specific language governing permissions and
34405 * limitations under the License.
34406 * =============================================================================
34407 */
34408 getGlobalTensorClass().prototype.max = function (axis, keepDims) {
34409 this.throwIfDisposed();
34410 return max$3(this, axis, keepDims);
34411 };
34412
34413 /**
34414 * @license
34415 * Copyright 2020 Google LLC. All Rights Reserved.
34416 * Licensed under the Apache License, Version 2.0 (the "License");
34417 * you may not use this file except in compliance with the License.
34418 * You may obtain a copy of the License at
34419 *
34420 * http://www.apache.org/licenses/LICENSE-2.0
34421 *
34422 * Unless required by applicable law or agreed to in writing, software
34423 * distributed under the License is distributed on an "AS IS" BASIS,
34424 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34425 * See the License for the specific language governing permissions and
34426 * limitations under the License.
34427 * =============================================================================
34428 */
34429 getGlobalTensorClass().prototype.maximum = function (b) {
34430 this.throwIfDisposed();
34431 return maximum$4(this, b);
34432 };
34433
34434 /**
34435 * @license
34436 * Copyright 2020 Google LLC. All Rights Reserved.
34437 * Licensed under the Apache License, Version 2.0 (the "License");
34438 * you may not use this file except in compliance with the License.
34439 * You may obtain a copy of the License at
34440 *
34441 * http://www.apache.org/licenses/LICENSE-2.0
34442 *
34443 * Unless required by applicable law or agreed to in writing, software
34444 * distributed under the License is distributed on an "AS IS" BASIS,
34445 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34446 * See the License for the specific language governing permissions and
34447 * limitations under the License.
34448 * =============================================================================
34449 */
34450 getGlobalTensorClass().prototype.mean = function (axis, keepDims) {
34451 this.throwIfDisposed();
34452 return mean$3(this, axis, keepDims);
34453 };
34454
34455 /**
34456 * @license
34457 * Copyright 2020 Google LLC. All Rights Reserved.
34458 * Licensed under the Apache License, Version 2.0 (the "License");
34459 * you may not use this file except in compliance with the License.
34460 * You may obtain a copy of the License at
34461 *
34462 * http://www.apache.org/licenses/LICENSE-2.0
34463 *
34464 * Unless required by applicable law or agreed to in writing, software
34465 * distributed under the License is distributed on an "AS IS" BASIS,
34466 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34467 * See the License for the specific language governing permissions and
34468 * limitations under the License.
34469 * =============================================================================
34470 */
34471 getGlobalTensorClass().prototype.min = function (axis, keepDims) {
34472 this.throwIfDisposed();
34473 return min$3(this, axis, keepDims);
34474 };
34475
34476 /**
34477 * @license
34478 * Copyright 2020 Google LLC. All Rights Reserved.
34479 * Licensed under the Apache License, Version 2.0 (the "License");
34480 * you may not use this file except in compliance with the License.
34481 * You may obtain a copy of the License at
34482 *
34483 * http://www.apache.org/licenses/LICENSE-2.0
34484 *
34485 * Unless required by applicable law or agreed to in writing, software
34486 * distributed under the License is distributed on an "AS IS" BASIS,
34487 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34488 * See the License for the specific language governing permissions and
34489 * limitations under the License.
34490 * =============================================================================
34491 */
34492 getGlobalTensorClass().prototype.minimum = function (b) {
34493 this.throwIfDisposed();
34494 return minimum$4(this, b);
34495 };
34496
34497 /**
34498 * @license
34499 * Copyright 2020 Google LLC. All Rights Reserved.
34500 * Licensed under the Apache License, Version 2.0 (the "License");
34501 * you may not use this file except in compliance with the License.
34502 * You may obtain a copy of the License at
34503 *
34504 * http://www.apache.org/licenses/LICENSE-2.0
34505 *
34506 * Unless required by applicable law or agreed to in writing, software
34507 * distributed under the License is distributed on an "AS IS" BASIS,
34508 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34509 * See the License for the specific language governing permissions and
34510 * limitations under the License.
34511 * =============================================================================
34512 */
34513 getGlobalTensorClass().prototype.mirrorPad = function (paddings, mode) {
34514 this.throwIfDisposed();
34515 return mirrorPad$1(this, paddings, mode);
34516 };
34517
34518 /**
34519 * @license
34520 * Copyright 2020 Google LLC. All Rights Reserved.
34521 * Licensed under the Apache License, Version 2.0 (the "License");
34522 * you may not use this file except in compliance with the License.
34523 * You may obtain a copy of the License at
34524 *
34525 * http://www.apache.org/licenses/LICENSE-2.0
34526 *
34527 * Unless required by applicable law or agreed to in writing, software
34528 * distributed under the License is distributed on an "AS IS" BASIS,
34529 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34530 * See the License for the specific language governing permissions and
34531 * limitations under the License.
34532 * =============================================================================
34533 */
34534 getGlobalTensorClass().prototype.mod = function (b) {
34535 this.throwIfDisposed();
34536 return mod$2(this, b);
34537 };
34538
34539 /**
34540 * @license
34541 * Copyright 2020 Google LLC. All Rights Reserved.
34542 * Licensed under the Apache License, Version 2.0 (the "License");
34543 * you may not use this file except in compliance with the License.
34544 * You may obtain a copy of the License at
34545 *
34546 * http://www.apache.org/licenses/LICENSE-2.0
34547 *
34548 * Unless required by applicable law or agreed to in writing, software
34549 * distributed under the License is distributed on an "AS IS" BASIS,
34550 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34551 * See the License for the specific language governing permissions and
34552 * limitations under the License.
34553 * =============================================================================
34554 */
34555 getGlobalTensorClass().prototype.mul = function (b) {
34556 this.throwIfDisposed();
34557 return mul(this, b);
34558 };
34559
34560 /**
34561 * @license
34562 * Copyright 2020 Google LLC. All Rights Reserved.
34563 * Licensed under the Apache License, Version 2.0 (the "License");
34564 * you may not use this file except in compliance with the License.
34565 * You may obtain a copy of the License at
34566 *
34567 * http://www.apache.org/licenses/LICENSE-2.0
34568 *
34569 * Unless required by applicable law or agreed to in writing, software
34570 * distributed under the License is distributed on an "AS IS" BASIS,
34571 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34572 * See the License for the specific language governing permissions and
34573 * limitations under the License.
34574 * =============================================================================
34575 */
34576 getGlobalTensorClass().prototype.neg = function () {
34577 this.throwIfDisposed();
34578 return neg$2(this);
34579 };
34580
34581 /**
34582 * @license
34583 * Copyright 2020 Google LLC. All Rights Reserved.
34584 * Licensed under the Apache License, Version 2.0 (the "License");
34585 * you may not use this file except in compliance with the License.
34586 * You may obtain a copy of the License at
34587 *
34588 * http://www.apache.org/licenses/LICENSE-2.0
34589 *
34590 * Unless required by applicable law or agreed to in writing, software
34591 * distributed under the License is distributed on an "AS IS" BASIS,
34592 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34593 * See the License for the specific language governing permissions and
34594 * limitations under the License.
34595 * =============================================================================
34596 */
34597 getGlobalTensorClass().prototype.norm = function (ord, axis, keepDims) {
34598 this.throwIfDisposed();
34599 return norm(this, ord, axis, keepDims);
34600 };
34601
34602 /**
34603 * @license
34604 * Copyright 2020 Google LLC. All Rights Reserved.
34605 * Licensed under the Apache License, Version 2.0 (the "License");
34606 * you may not use this file except in compliance with the License.
34607 * You may obtain a copy of the License at
34608 *
34609 * http://www.apache.org/licenses/LICENSE-2.0
34610 *
34611 * Unless required by applicable law or agreed to in writing, software
34612 * distributed under the License is distributed on an "AS IS" BASIS,
34613 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34614 * See the License for the specific language governing permissions and
34615 * limitations under the License.
34616 * =============================================================================
34617 */
34618 getGlobalTensorClass().prototype.notEqual = function (b) {
34619 this.throwIfDisposed();
34620 return notEqual$2(this, b);
34621 };
34622
34623 /**
34624 * @license
34625 * Copyright 2020 Google LLC. All Rights Reserved.
34626 * Licensed under the Apache License, Version 2.0 (the "License");
34627 * you may not use this file except in compliance with the License.
34628 * You may obtain a copy of the License at
34629 *
34630 * http://www.apache.org/licenses/LICENSE-2.0
34631 *
34632 * Unless required by applicable law or agreed to in writing, software
34633 * distributed under the License is distributed on an "AS IS" BASIS,
34634 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34635 * See the License for the specific language governing permissions and
34636 * limitations under the License.
34637 * =============================================================================
34638 */
34639 getGlobalTensorClass().prototype.oneHot = function (depth, onValue = 1, offValue = 0) {
34640 this.throwIfDisposed();
34641 return oneHot$3(this, depth, onValue, offValue);
34642 };
34643
34644 /**
34645 * @license
34646 * Copyright 2020 Google LLC. All Rights Reserved.
34647 * Licensed under the Apache License, Version 2.0 (the "License");
34648 * you may not use this file except in compliance with the License.
34649 * You may obtain a copy of the License at
34650 *
34651 * http://www.apache.org/licenses/LICENSE-2.0
34652 *
34653 * Unless required by applicable law or agreed to in writing, software
34654 * distributed under the License is distributed on an "AS IS" BASIS,
34655 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34656 * See the License for the specific language governing permissions and
34657 * limitations under the License.
34658 * =============================================================================
34659 */
34660 getGlobalTensorClass().prototype.onesLike = function () {
34661 this.throwIfDisposed();
34662 return onesLike$3(this);
34663 };
34664
34665 /**
34666 * @license
34667 * Copyright 2020 Google LLC. All Rights Reserved.
34668 * Licensed under the Apache License, Version 2.0 (the "License");
34669 * you may not use this file except in compliance with the License.
34670 * You may obtain a copy of the License at
34671 *
34672 * http://www.apache.org/licenses/LICENSE-2.0
34673 *
34674 * Unless required by applicable law or agreed to in writing, software
34675 * distributed under the License is distributed on an "AS IS" BASIS,
34676 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34677 * See the License for the specific language governing permissions and
34678 * limitations under the License.
34679 * =============================================================================
34680 */
34681 getGlobalTensorClass().prototype.pad = function (paddings, constantValue) {
34682 this.throwIfDisposed();
34683 return pad(this, paddings, constantValue);
34684 };
34685
34686 getGlobalTensorClass().prototype.pool = function (windowShape, poolingType, padding, dilationRate, strides, dimRoundingMode) {
34687 this.throwIfDisposed();
34688 return pool$1(this, windowShape, poolingType, padding, dilationRate, strides, dimRoundingMode);
34689 };
34690
34691 /**
34692 * @license
34693 * Copyright 2020 Google LLC. All Rights Reserved.
34694 * Licensed under the Apache License, Version 2.0 (the "License");
34695 * you may not use this file except in compliance with the License.
34696 * You may obtain a copy of the License at
34697 *
34698 * http://www.apache.org/licenses/LICENSE-2.0
34699 *
34700 * Unless required by applicable law or agreed to in writing, software
34701 * distributed under the License is distributed on an "AS IS" BASIS,
34702 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34703 * See the License for the specific language governing permissions and
34704 * limitations under the License.
34705 * =============================================================================
34706 */
34707 getGlobalTensorClass().prototype.pow = function (exp) {
34708 this.throwIfDisposed();
34709 return pow$3(this, exp);
34710 };
34711
34712 /**
34713 * @license
34714 * Copyright 2020 Google LLC. All Rights Reserved.
34715 * Licensed under the Apache License, Version 2.0 (the "License");
34716 * you may not use this file except in compliance with the License.
34717 * You may obtain a copy of the License at
34718 *
34719 * http://www.apache.org/licenses/LICENSE-2.0
34720 *
34721 * Unless required by applicable law or agreed to in writing, software
34722 * distributed under the License is distributed on an "AS IS" BASIS,
34723 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34724 * See the License for the specific language governing permissions and
34725 * limitations under the License.
34726 * =============================================================================
34727 */
34728 getGlobalTensorClass().prototype.prelu = function (alpha) {
34729 this.throwIfDisposed();
34730 return prelu$3(this, alpha);
34731 };
34732
34733 /**
34734 * @license
34735 * Copyright 2020 Google LLC. All Rights Reserved.
34736 * Licensed under the Apache License, Version 2.0 (the "License");
34737 * you may not use this file except in compliance with the License.
34738 * You may obtain a copy of the License at
34739 *
34740 * http://www.apache.org/licenses/LICENSE-2.0
34741 *
34742 * Unless required by applicable law or agreed to in writing, software
34743 * distributed under the License is distributed on an "AS IS" BASIS,
34744 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34745 * See the License for the specific language governing permissions and
34746 * limitations under the License.
34747 * =============================================================================
34748 */
34749 getGlobalTensorClass().prototype.prod = function (axis, keepDims) {
34750 this.throwIfDisposed();
34751 return prod$2(this, axis, keepDims);
34752 };
34753
34754 /**
34755 * @license
34756 * Copyright 2020 Google LLC. All Rights Reserved.
34757 * Licensed under the Apache License, Version 2.0 (the "License");
34758 * you may not use this file except in compliance with the License.
34759 * You may obtain a copy of the License at
34760 *
34761 * http://www.apache.org/licenses/LICENSE-2.0
34762 *
34763 * Unless required by applicable law or agreed to in writing, software
34764 * distributed under the License is distributed on an "AS IS" BASIS,
34765 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34766 * See the License for the specific language governing permissions and
34767 * limitations under the License.
34768 * =============================================================================
34769 */
34770 getGlobalTensorClass().prototype.reciprocal = function () {
34771 this.throwIfDisposed();
34772 return reciprocal$2(this);
34773 };
34774
34775 /**
34776 * @license
34777 * Copyright 2020 Google LLC. All Rights Reserved.
34778 * Licensed under the Apache License, Version 2.0 (the "License");
34779 * you may not use this file except in compliance with the License.
34780 * You may obtain a copy of the License at
34781 *
34782 * http://www.apache.org/licenses/LICENSE-2.0
34783 *
34784 * Unless required by applicable law or agreed to in writing, software
34785 * distributed under the License is distributed on an "AS IS" BASIS,
34786 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34787 * See the License for the specific language governing permissions and
34788 * limitations under the License.
34789 * =============================================================================
34790 */
34791 getGlobalTensorClass().prototype.relu = function () {
34792 this.throwIfDisposed();
34793 return relu$2(this);
34794 };
34795
34796 /**
34797 * @license
34798 * Copyright 2020 Google LLC. All Rights Reserved.
34799 * Licensed under the Apache License, Version 2.0 (the "License");
34800 * you may not use this file except in compliance with the License.
34801 * You may obtain a copy of the License at
34802 *
34803 * http://www.apache.org/licenses/LICENSE-2.0
34804 *
34805 * Unless required by applicable law or agreed to in writing, software
34806 * distributed under the License is distributed on an "AS IS" BASIS,
34807 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34808 * See the License for the specific language governing permissions and
34809 * limitations under the License.
34810 * =============================================================================
34811 */
34812 getGlobalTensorClass().prototype.relu6 = function () {
34813 this.throwIfDisposed();
34814 return relu6$2(this);
34815 };
34816
34817 /**
34818 * @license
34819 * Copyright 2020 Google LLC. All Rights Reserved.
34820 * Licensed under the Apache License, Version 2.0 (the "License");
34821 * you may not use this file except in compliance with the License.
34822 * You may obtain a copy of the License at
34823 *
34824 * http://www.apache.org/licenses/LICENSE-2.0
34825 *
34826 * Unless required by applicable law or agreed to in writing, software
34827 * distributed under the License is distributed on an "AS IS" BASIS,
34828 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34829 * See the License for the specific language governing permissions and
34830 * limitations under the License.
34831 * =============================================================================
34832 */
34833 /**
34834 * Reshapes the tensor into the shape of the provided tensor.
34835 *
34836 * @param x The tensor of required shape.
34837 *
34838 * @doc {heading: 'Tensors', subheading: 'Classes'}
34839 */
34840 getGlobalTensorClass().prototype.reshapeAs = function (x) {
34841 this.throwIfDisposed();
34842 return reshape$3(this, x.shape);
34843 };
34844
34845 /**
34846 * @license
34847 * Copyright 2020 Google LLC. All Rights Reserved.
34848 * Licensed under the Apache License, Version 2.0 (the "License");
34849 * you may not use this file except in compliance with the License.
34850 * You may obtain a copy of the License at
34851 *
34852 * http://www.apache.org/licenses/LICENSE-2.0
34853 *
34854 * Unless required by applicable law or agreed to in writing, software
34855 * distributed under the License is distributed on an "AS IS" BASIS,
34856 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34857 * See the License for the specific language governing permissions and
34858 * limitations under the License.
34859 * =============================================================================
34860 */
34861 getGlobalTensorClass().prototype.reshape = function (shape) {
34862 this.throwIfDisposed();
34863 return reshape$3(this, shape);
34864 };
34865
34866 /**
34867 * @license
34868 * Copyright 2020 Google LLC. All Rights Reserved.
34869 * Licensed under the Apache License, Version 2.0 (the "License");
34870 * you may not use this file except in compliance with the License.
34871 * You may obtain a copy of the License at
34872 *
34873 * http://www.apache.org/licenses/LICENSE-2.0
34874 *
34875 * Unless required by applicable law or agreed to in writing, software
34876 * distributed under the License is distributed on an "AS IS" BASIS,
34877 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34878 * See the License for the specific language governing permissions and
34879 * limitations under the License.
34880 * =============================================================================
34881 */
34882 getGlobalTensorClass().prototype.resizeBilinear =
34883 function (newShape2D, alignCorners, halfPixelCenters) {
34884 this.throwIfDisposed();
34885 return resizeBilinear$3(this, newShape2D, alignCorners, halfPixelCenters);
34886 };
34887
34888 /**
34889 * @license
34890 * Copyright 2020 Google LLC. All Rights Reserved.
34891 * Licensed under the Apache License, Version 2.0 (the "License");
34892 * you may not use this file except in compliance with the License.
34893 * You may obtain a copy of the License at
34894 *
34895 * http://www.apache.org/licenses/LICENSE-2.0
34896 *
34897 * Unless required by applicable law or agreed to in writing, software
34898 * distributed under the License is distributed on an "AS IS" BASIS,
34899 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34900 * See the License for the specific language governing permissions and
34901 * limitations under the License.
34902 * =============================================================================
34903 */
34904 getGlobalTensorClass().prototype.resizeNearestNeighbor =
34905 function (newShape2D, alignCorners, halfFloatCenters) {
34906 this.throwIfDisposed();
34907 return resizeNearestNeighbor$2(this, newShape2D, alignCorners, halfFloatCenters);
34908 };
34909
34910 /**
34911 * @license
34912 * Copyright 2020 Google LLC. All Rights Reserved.
34913 * Licensed under the Apache License, Version 2.0 (the "License");
34914 * you may not use this file except in compliance with the License.
34915 * You may obtain a copy of the License at
34916 *
34917 * http://www.apache.org/licenses/LICENSE-2.0
34918 *
34919 * Unless required by applicable law or agreed to in writing, software
34920 * distributed under the License is distributed on an "AS IS" BASIS,
34921 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34922 * See the License for the specific language governing permissions and
34923 * limitations under the License.
34924 * =============================================================================
34925 */
34926 getGlobalTensorClass().prototype.reverse = function (axis) {
34927 this.throwIfDisposed();
34928 return reverse$2(this, axis);
34929 };
34930
34931 /**
34932 * @license
34933 * Copyright 2020 Google LLC. All Rights Reserved.
34934 * Licensed under the Apache License, Version 2.0 (the "License");
34935 * you may not use this file except in compliance with the License.
34936 * You may obtain a copy of the License at
34937 *
34938 * http://www.apache.org/licenses/LICENSE-2.0
34939 *
34940 * Unless required by applicable law or agreed to in writing, software
34941 * distributed under the License is distributed on an "AS IS" BASIS,
34942 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34943 * See the License for the specific language governing permissions and
34944 * limitations under the License.
34945 * =============================================================================
34946 */
34947 getGlobalTensorClass().prototype.rfft = function () {
34948 this.throwIfDisposed();
34949 return rfft(this);
34950 };
34951
34952 /**
34953 * @license
34954 * Copyright 2020 Google LLC. All Rights Reserved.
34955 * Licensed under the Apache License, Version 2.0 (the "License");
34956 * you may not use this file except in compliance with the License.
34957 * You may obtain a copy of the License at
34958 *
34959 * http://www.apache.org/licenses/LICENSE-2.0
34960 *
34961 * Unless required by applicable law or agreed to in writing, software
34962 * distributed under the License is distributed on an "AS IS" BASIS,
34963 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34964 * See the License for the specific language governing permissions and
34965 * limitations under the License.
34966 * =============================================================================
34967 */
34968 getGlobalTensorClass().prototype.round = function () {
34969 this.throwIfDisposed();
34970 return round$2(this);
34971 };
34972
34973 /**
34974 * @license
34975 * Copyright 2020 Google LLC. All Rights Reserved.
34976 * Licensed under the Apache License, Version 2.0 (the "License");
34977 * you may not use this file except in compliance with the License.
34978 * You may obtain a copy of the License at
34979 *
34980 * http://www.apache.org/licenses/LICENSE-2.0
34981 *
34982 * Unless required by applicable law or agreed to in writing, software
34983 * distributed under the License is distributed on an "AS IS" BASIS,
34984 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34985 * See the License for the specific language governing permissions and
34986 * limitations under the License.
34987 * =============================================================================
34988 */
34989 getGlobalTensorClass().prototype.rsqrt = function () {
34990 this.throwIfDisposed();
34991 return rsqrt$2(this);
34992 };
34993
34994 /**
34995 * @license
34996 * Copyright 2020 Google LLC. All Rights Reserved.
34997 * Licensed under the Apache License, Version 2.0 (the "License");
34998 * you may not use this file except in compliance with the License.
34999 * You may obtain a copy of the License at
35000 *
35001 * http://www.apache.org/licenses/LICENSE-2.0
35002 *
35003 * Unless required by applicable law or agreed to in writing, software
35004 * distributed under the License is distributed on an "AS IS" BASIS,
35005 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35006 * See the License for the specific language governing permissions and
35007 * limitations under the License.
35008 * =============================================================================
35009 */
35010 getGlobalTensorClass().prototype.selu = function () {
35011 this.throwIfDisposed();
35012 return selu$2(this);
35013 };
35014
35015 /**
35016 * @license
35017 * Copyright 2020 Google LLC. All Rights Reserved.
35018 * Licensed under the Apache License, Version 2.0 (the "License");
35019 * you may not use this file except in compliance with the License.
35020 * You may obtain a copy of the License at
35021 *
35022 * http://www.apache.org/licenses/LICENSE-2.0
35023 *
35024 * Unless required by applicable law or agreed to in writing, software
35025 * distributed under the License is distributed on an "AS IS" BASIS,
35026 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35027 * See the License for the specific language governing permissions and
35028 * limitations under the License.
35029 * =============================================================================
35030 */
35031 getGlobalTensorClass().prototype.separableConv2d =
35032 function (depthwiseFilter, pointwiseFilter, strides, pad, dilation, dataFormat) {
35033 this.throwIfDisposed();
35034 return separableConv2d$1(this, depthwiseFilter, pointwiseFilter, strides, pad, dilation, dataFormat);
35035 };
35036
35037 /**
35038 * @license
35039 * Copyright 2020 Google LLC. All Rights Reserved.
35040 * Licensed under the Apache License, Version 2.0 (the "License");
35041 * you may not use this file except in compliance with the License.
35042 * You may obtain a copy of the License at
35043 *
35044 * http://www.apache.org/licenses/LICENSE-2.0
35045 *
35046 * Unless required by applicable law or agreed to in writing, software
35047 * distributed under the License is distributed on an "AS IS" BASIS,
35048 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35049 * See the License for the specific language governing permissions and
35050 * limitations under the License.
35051 * =============================================================================
35052 */
35053 getGlobalTensorClass().prototype.sigmoid = function () {
35054 this.throwIfDisposed();
35055 return sigmoid$2(this);
35056 };
35057
35058 /**
35059 * @license
35060 * Copyright 2020 Google LLC. All Rights Reserved.
35061 * Licensed under the Apache License, Version 2.0 (the "License");
35062 * you may not use this file except in compliance with the License.
35063 * You may obtain a copy of the License at
35064 *
35065 * http://www.apache.org/licenses/LICENSE-2.0
35066 *
35067 * Unless required by applicable law or agreed to in writing, software
35068 * distributed under the License is distributed on an "AS IS" BASIS,
35069 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35070 * See the License for the specific language governing permissions and
35071 * limitations under the License.
35072 * =============================================================================
35073 */
35074 getGlobalTensorClass().prototype.sign = function () {
35075 this.throwIfDisposed();
35076 return sign$3(this);
35077 };
35078
35079 /**
35080 * @license
35081 * Copyright 2020 Google LLC. All Rights Reserved.
35082 * Licensed under the Apache License, Version 2.0 (the "License");
35083 * you may not use this file except in compliance with the License.
35084 * You may obtain a copy of the License at
35085 *
35086 * http://www.apache.org/licenses/LICENSE-2.0
35087 *
35088 * Unless required by applicable law or agreed to in writing, software
35089 * distributed under the License is distributed on an "AS IS" BASIS,
35090 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35091 * See the License for the specific language governing permissions and
35092 * limitations under the License.
35093 * =============================================================================
35094 */
35095 getGlobalTensorClass().prototype.sin = function () {
35096 this.throwIfDisposed();
35097 return sin$2(this);
35098 };
35099
35100 /**
35101 * @license
35102 * Copyright 2020 Google LLC. All Rights Reserved.
35103 * Licensed under the Apache License, Version 2.0 (the "License");
35104 * you may not use this file except in compliance with the License.
35105 * You may obtain a copy of the License at
35106 *
35107 * http://www.apache.org/licenses/LICENSE-2.0
35108 *
35109 * Unless required by applicable law or agreed to in writing, software
35110 * distributed under the License is distributed on an "AS IS" BASIS,
35111 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35112 * See the License for the specific language governing permissions and
35113 * limitations under the License.
35114 * =============================================================================
35115 */
35116 getGlobalTensorClass().prototype.sinh = function () {
35117 this.throwIfDisposed();
35118 return sinh$2(this);
35119 };
35120
35121 /**
35122 * @license
35123 * Copyright 2020 Google LLC. All Rights Reserved.
35124 * Licensed under the Apache License, Version 2.0 (the "License");
35125 * you may not use this file except in compliance with the License.
35126 * You may obtain a copy of the License at
35127 *
35128 * http://www.apache.org/licenses/LICENSE-2.0
35129 *
35130 * Unless required by applicable law or agreed to in writing, software
35131 * distributed under the License is distributed on an "AS IS" BASIS,
35132 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35133 * See the License for the specific language governing permissions and
35134 * limitations under the License.
35135 * =============================================================================
35136 */
35137 getGlobalTensorClass().prototype.slice = function (begin, size) {
35138 this.throwIfDisposed();
35139 return slice$2(this, begin, size);
35140 };
35141
35142 /**
35143 * @license
35144 * Copyright 2020 Google LLC. All Rights Reserved.
35145 * Licensed under the Apache License, Version 2.0 (the "License");
35146 * you may not use this file except in compliance with the License.
35147 * You may obtain a copy of the License at
35148 *
35149 * http://www.apache.org/licenses/LICENSE-2.0
35150 *
35151 * Unless required by applicable law or agreed to in writing, software
35152 * distributed under the License is distributed on an "AS IS" BASIS,
35153 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35154 * See the License for the specific language governing permissions and
35155 * limitations under the License.
35156 * =============================================================================
35157 */
35158 getGlobalTensorClass().prototype.softmax = function (dim) {
35159 this.throwIfDisposed();
35160 return softmax$3(this, dim);
35161 };
35162
35163 /**
35164 * @license
35165 * Copyright 2020 Google LLC. All Rights Reserved.
35166 * Licensed under the Apache License, Version 2.0 (the "License");
35167 * you may not use this file except in compliance with the License.
35168 * You may obtain a copy of the License at
35169 *
35170 * http://www.apache.org/licenses/LICENSE-2.0
35171 *
35172 * Unless required by applicable law or agreed to in writing, software
35173 * distributed under the License is distributed on an "AS IS" BASIS,
35174 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35175 * See the License for the specific language governing permissions and
35176 * limitations under the License.
35177 * =============================================================================
35178 */
35179 getGlobalTensorClass().prototype.softplus = function () {
35180 this.throwIfDisposed();
35181 return softplus$2(this);
35182 };
35183
35184 /**
35185 * @license
35186 * Copyright 2020 Google LLC. All Rights Reserved.
35187 * Licensed under the Apache License, Version 2.0 (the "License");
35188 * you may not use this file except in compliance with the License.
35189 * You may obtain a copy of the License at
35190 *
35191 * http://www.apache.org/licenses/LICENSE-2.0
35192 *
35193 * Unless required by applicable law or agreed to in writing, software
35194 * distributed under the License is distributed on an "AS IS" BASIS,
35195 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35196 * See the License for the specific language governing permissions and
35197 * limitations under the License.
35198 * =============================================================================
35199 */
35200 getGlobalTensorClass().prototype.spaceToBatchND = function (blockShape, paddings) {
35201 this.throwIfDisposed();
35202 return spaceToBatchND$2(this, blockShape, paddings);
35203 };
35204
35205 /**
35206 * @license
35207 * Copyright 2020 Google LLC. All Rights Reserved.
35208 * Licensed under the Apache License, Version 2.0 (the "License");
35209 * you may not use this file except in compliance with the License.
35210 * You may obtain a copy of the License at
35211 *
35212 * http://www.apache.org/licenses/LICENSE-2.0
35213 *
35214 * Unless required by applicable law or agreed to in writing, software
35215 * distributed under the License is distributed on an "AS IS" BASIS,
35216 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35217 * See the License for the specific language governing permissions and
35218 * limitations under the License.
35219 * =============================================================================
35220 */
35221 getGlobalTensorClass().prototype.split = function (numOrSizeSplits, axis) {
35222 this.throwIfDisposed();
35223 return split$3(this, numOrSizeSplits, axis);
35224 };
35225
35226 /**
35227 * @license
35228 * Copyright 2020 Google LLC. All Rights Reserved.
35229 * Licensed under the Apache License, Version 2.0 (the "License");
35230 * you may not use this file except in compliance with the License.
35231 * You may obtain a copy of the License at
35232 *
35233 * http://www.apache.org/licenses/LICENSE-2.0
35234 *
35235 * Unless required by applicable law or agreed to in writing, software
35236 * distributed under the License is distributed on an "AS IS" BASIS,
35237 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35238 * See the License for the specific language governing permissions and
35239 * limitations under the License.
35240 * =============================================================================
35241 */
35242 getGlobalTensorClass().prototype.sqrt = function () {
35243 this.throwIfDisposed();
35244 return sqrt$2(this);
35245 };
35246
35247 /**
35248 * @license
35249 * Copyright 2020 Google LLC. All Rights Reserved.
35250 * Licensed under the Apache License, Version 2.0 (the "License");
35251 * you may not use this file except in compliance with the License.
35252 * You may obtain a copy of the License at
35253 *
35254 * http://www.apache.org/licenses/LICENSE-2.0
35255 *
35256 * Unless required by applicable law or agreed to in writing, software
35257 * distributed under the License is distributed on an "AS IS" BASIS,
35258 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35259 * See the License for the specific language governing permissions and
35260 * limitations under the License.
35261 * =============================================================================
35262 */
/**
 * Tensor chaining API: verifies this tensor has not been disposed, then
 * forwards to `square$2(this)`.
 */
getGlobalTensorClass().prototype.square = function () {
    this.throwIfDisposed();
    const result = square$2(this);
    return result;
};
35267
35268 /**
35269 * @license
35270 * Copyright 2020 Google LLC. All Rights Reserved.
35271 * Licensed under the Apache License, Version 2.0 (the "License");
35272 * you may not use this file except in compliance with the License.
35273 * You may obtain a copy of the License at
35274 *
35275 * http://www.apache.org/licenses/LICENSE-2.0
35276 *
35277 * Unless required by applicable law or agreed to in writing, software
35278 * distributed under the License is distributed on an "AS IS" BASIS,
35279 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35280 * See the License for the specific language governing permissions and
35281 * limitations under the License.
35282 * =============================================================================
35283 */
/**
 * Tensor chaining API: verifies this tensor has not been disposed, then
 * forwards to `squaredDifference$2(this, b)`.
 */
getGlobalTensorClass().prototype.squaredDifference = function (b) {
    this.throwIfDisposed();
    const result = squaredDifference$2(this, b);
    return result;
};
35288
35289 /**
35290 * @license
35291 * Copyright 2020 Google LLC. All Rights Reserved.
35292 * Licensed under the Apache License, Version 2.0 (the "License");
35293 * you may not use this file except in compliance with the License.
35294 * You may obtain a copy of the License at
35295 *
35296 * http://www.apache.org/licenses/LICENSE-2.0
35297 *
35298 * Unless required by applicable law or agreed to in writing, software
35299 * distributed under the License is distributed on an "AS IS" BASIS,
35300 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35301 * See the License for the specific language governing permissions and
35302 * limitations under the License.
35303 * =============================================================================
35304 */
/**
 * Tensor chaining API: verifies this tensor has not been disposed, then
 * forwards to `squeeze(this, axis)`.
 */
getGlobalTensorClass().prototype.squeeze = function (axis) {
    this.throwIfDisposed();
    const result = squeeze(this, axis);
    return result;
};
35309
35310 /**
35311 * @license
35312 * Copyright 2020 Google LLC. All Rights Reserved.
35313 * Licensed under the Apache License, Version 2.0 (the "License");
35314 * you may not use this file except in compliance with the License.
35315 * You may obtain a copy of the License at
35316 *
35317 * http://www.apache.org/licenses/LICENSE-2.0
35318 *
35319 * Unless required by applicable law or agreed to in writing, software
35320 * distributed under the License is distributed on an "AS IS" BASIS,
35321 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35322 * See the License for the specific language governing permissions and
35323 * limitations under the License.
35324 * =============================================================================
35325 */
/**
 * Tensor chaining API: verifies this tensor has not been disposed, then
 * stacks this tensor together with `x` via the `stack` op. `x` may be a
 * single Tensor or an array of Tensors; `this` always comes first.
 */
getGlobalTensorClass().prototype.stack = function (x, axis) {
    this.throwIfDisposed();
    const tensors = x instanceof Tensor ? [this, x] : [this, ...x];
    return stack(tensors, axis);
};
35331
35332 /**
35333 * @license
35334 * Copyright 2020 Google LLC. All Rights Reserved.
35335 * Licensed under the Apache License, Version 2.0 (the "License");
35336 * you may not use this file except in compliance with the License.
35337 * You may obtain a copy of the License at
35338 *
35339 * http://www.apache.org/licenses/LICENSE-2.0
35340 *
35341 * Unless required by applicable law or agreed to in writing, software
35342 * distributed under the License is distributed on an "AS IS" BASIS,
35343 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35344 * See the License for the specific language governing permissions and
35345 * limitations under the License.
35346 * =============================================================================
35347 */
/**
 * Tensor chaining API: verifies this tensor has not been disposed, then
 * forwards to `step$2(this, alpha)`.
 */
getGlobalTensorClass().prototype.step = function (alpha) {
    this.throwIfDisposed();
    const result = step$2(this, alpha);
    return result;
};
35352
35353 /**
35354 * @license
35355 * Copyright 2020 Google LLC. All Rights Reserved.
35356 * Licensed under the Apache License, Version 2.0 (the "License");
35357 * you may not use this file except in compliance with the License.
35358 * You may obtain a copy of the License at
35359 *
35360 * http://www.apache.org/licenses/LICENSE-2.0
35361 *
35362 * Unless required by applicable law or agreed to in writing, software
35363 * distributed under the License is distributed on an "AS IS" BASIS,
35364 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35365 * See the License for the specific language governing permissions and
35366 * limitations under the License.
35367 * =============================================================================
35368 */
/**
 * Tensor chaining API: verifies this tensor has not been disposed, then
 * forwards all arguments to `stridedSlice$2` with `this` prepended.
 */
getGlobalTensorClass().prototype.stridedSlice = function (begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) {
    this.throwIfDisposed();
    const result = stridedSlice$2(this, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
    return result;
};
35373
35374 /**
35375 * @license
35376 * Copyright 2020 Google LLC. All Rights Reserved.
35377 * Licensed under the Apache License, Version 2.0 (the "License");
35378 * you may not use this file except in compliance with the License.
35379 * You may obtain a copy of the License at
35380 *
35381 * http://www.apache.org/licenses/LICENSE-2.0
35382 *
35383 * Unless required by applicable law or agreed to in writing, software
35384 * distributed under the License is distributed on an "AS IS" BASIS,
35385 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35386 * See the License for the specific language governing permissions and
35387 * limitations under the License.
35388 * =============================================================================
35389 */
/**
 * Tensor chaining API: verifies this tensor has not been disposed, then
 * forwards to `sub$2(this, b)`.
 */
getGlobalTensorClass().prototype.sub = function (b) {
    this.throwIfDisposed();
    const result = sub$2(this, b);
    return result;
};
35394
35395 /**
35396 * @license
35397 * Copyright 2020 Google LLC. All Rights Reserved.
35398 * Licensed under the Apache License, Version 2.0 (the "License");
35399 * you may not use this file except in compliance with the License.
35400 * You may obtain a copy of the License at
35401 *
35402 * http://www.apache.org/licenses/LICENSE-2.0
35403 *
35404 * Unless required by applicable law or agreed to in writing, software
35405 * distributed under the License is distributed on an "AS IS" BASIS,
35406 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35407 * See the License for the specific language governing permissions and
35408 * limitations under the License.
35409 * =============================================================================
35410 */
/**
 * Tensor chaining API: verifies this tensor has not been disposed, then
 * forwards to `sum$3(this, axis, keepDims)`.
 */
getGlobalTensorClass().prototype.sum = function (axis, keepDims) {
    this.throwIfDisposed();
    const result = sum$3(this, axis, keepDims);
    return result;
};
35415
35416 /**
35417 * @license
35418 * Copyright 2020 Google LLC. All Rights Reserved.
35419 * Licensed under the Apache License, Version 2.0 (the "License");
35420 * you may not use this file except in compliance with the License.
35421 * You may obtain a copy of the License at
35422 *
35423 * http://www.apache.org/licenses/LICENSE-2.0
35424 *
35425 * Unless required by applicable law or agreed to in writing, software
35426 * distributed under the License is distributed on an "AS IS" BASIS,
35427 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35428 * See the License for the specific language governing permissions and
35429 * limitations under the License.
35430 * =============================================================================
35431 */
/**
 * Tensor chaining API: verifies this tensor has not been disposed, then
 * forwards to `tan$2(this)`.
 */
getGlobalTensorClass().prototype.tan = function () {
    this.throwIfDisposed();
    const result = tan$2(this);
    return result;
};
35436
35437 /**
35438 * @license
35439 * Copyright 2020 Google LLC. All Rights Reserved.
35440 * Licensed under the Apache License, Version 2.0 (the "License");
35441 * you may not use this file except in compliance with the License.
35442 * You may obtain a copy of the License at
35443 *
35444 * http://www.apache.org/licenses/LICENSE-2.0
35445 *
35446 * Unless required by applicable law or agreed to in writing, software
35447 * distributed under the License is distributed on an "AS IS" BASIS,
35448 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35449 * See the License for the specific language governing permissions and
35450 * limitations under the License.
35451 * =============================================================================
35452 */
/**
 * Tensor chaining API: verifies this tensor has not been disposed, then
 * forwards to `tanh$2(this)`.
 */
getGlobalTensorClass().prototype.tanh = function () {
    this.throwIfDisposed();
    const result = tanh$2(this);
    return result;
};
35457
35458 /**
35459 * @license
35460 * Copyright 2020 Google LLC. All Rights Reserved.
35461 * Licensed under the Apache License, Version 2.0 (the "License");
35462 * you may not use this file except in compliance with the License.
35463 * You may obtain a copy of the License at
35464 *
35465 * http://www.apache.org/licenses/LICENSE-2.0
35466 *
35467 * Unless required by applicable law or agreed to in writing, software
35468 * distributed under the License is distributed on an "AS IS" BASIS,
35469 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35470 * See the License for the specific language governing permissions and
35471 * limitations under the License.
35472 * =============================================================================
35473 */
/**
 * Tensor chaining API: verifies this tensor has not been disposed, then
 * forwards to `tile$3(this, reps)`.
 */
getGlobalTensorClass().prototype.tile = function (reps) {
    this.throwIfDisposed();
    const result = tile$3(this, reps);
    return result;
};
35478
35479 /**
35480 * @license
35481 * Copyright 2020 Google LLC. All Rights Reserved.
35482 * Licensed under the Apache License, Version 2.0 (the "License");
35483 * you may not use this file except in compliance with the License.
35484 * You may obtain a copy of the License at
35485 *
35486 * http://www.apache.org/licenses/LICENSE-2.0
35487 *
35488 * Unless required by applicable law or agreed to in writing, software
35489 * distributed under the License is distributed on an "AS IS" BASIS,
35490 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35491 * See the License for the specific language governing permissions and
35492 * limitations under the License.
35493 * =============================================================================
35494 */
/**
 * Casts the array to type `bool`
 *
 * @doc {heading: 'Tensors', subheading: 'Classes'}
 */
getGlobalTensorClass().prototype.toBool = function () {
    this.throwIfDisposed();
    // Delegates to the shared cast op with a fixed 'bool' target dtype.
    const result = cast$3(this, 'bool');
    return result;
};
35504
35505 /**
35506 * @license
35507 * Copyright 2020 Google LLC. All Rights Reserved.
35508 * Licensed under the Apache License, Version 2.0 (the "License");
35509 * you may not use this file except in compliance with the License.
35510 * You may obtain a copy of the License at
35511 *
35512 * http://www.apache.org/licenses/LICENSE-2.0
35513 *
35514 * Unless required by applicable law or agreed to in writing, software
35515 * distributed under the License is distributed on an "AS IS" BASIS,
35516 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35517 * See the License for the specific language governing permissions and
35518 * limitations under the License.
35519 * =============================================================================
35520 */
/**
 * Casts the array to type `float32`
 *
 * @doc {heading: 'Tensors', subheading: 'Classes'}
 */
getGlobalTensorClass().prototype.toFloat = function () {
    this.throwIfDisposed();
    // Delegates to the shared cast op with a fixed 'float32' target dtype.
    const result = cast$3(this, 'float32');
    return result;
};
35530
35531 /**
35532 * @license
35533 * Copyright 2020 Google LLC. All Rights Reserved.
35534 * Licensed under the Apache License, Version 2.0 (the "License");
35535 * you may not use this file except in compliance with the License.
35536 * You may obtain a copy of the License at
35537 *
35538 * http://www.apache.org/licenses/LICENSE-2.0
35539 *
35540 * Unless required by applicable law or agreed to in writing, software
35541 * distributed under the License is distributed on an "AS IS" BASIS,
35542 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35543 * See the License for the specific language governing permissions and
35544 * limitations under the License.
35545 * =============================================================================
35546 */
/**
 * Casts the array to type `int32`
 *
 * @doc {heading: 'Tensors', subheading: 'Classes'}
 */
getGlobalTensorClass().prototype.toInt = function () {
    this.throwIfDisposed();
    // Delegates to the shared cast op with a fixed 'int32' target dtype.
    const result = cast$3(this, 'int32');
    return result;
};
35556
35557 /**
35558 * @license
35559 * Copyright 2020 Google LLC. All Rights Reserved.
35560 * Licensed under the Apache License, Version 2.0 (the "License");
35561 * you may not use this file except in compliance with the License.
35562 * You may obtain a copy of the License at
35563 *
35564 * http://www.apache.org/licenses/LICENSE-2.0
35565 *
35566 * Unless required by applicable law or agreed to in writing, software
35567 * distributed under the License is distributed on an "AS IS" BASIS,
35568 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35569 * See the License for the specific language governing permissions and
35570 * limitations under the License.
35571 * =============================================================================
35572 */
/**
 * Tensor chaining API: verifies this tensor has not been disposed, then
 * forwards to `topk(this, k, sorted)`.
 */
getGlobalTensorClass().prototype.topk = function (k, sorted) {
    this.throwIfDisposed();
    const result = topk(this, k, sorted);
    return result;
};
35577
35578 /**
35579 * @license
35580 * Copyright 2020 Google LLC. All Rights Reserved.
35581 * Licensed under the Apache License, Version 2.0 (the "License");
35582 * you may not use this file except in compliance with the License.
35583 * You may obtain a copy of the License at
35584 *
35585 * http://www.apache.org/licenses/LICENSE-2.0
35586 *
35587 * Unless required by applicable law or agreed to in writing, software
35588 * distributed under the License is distributed on an "AS IS" BASIS,
35589 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35590 * See the License for the specific language governing permissions and
35591 * limitations under the License.
35592 * =============================================================================
35593 */
/**
 * Tensor chaining API: verifies this tensor has not been disposed, then
 * forwards to `transpose$2(this, perm)`.
 */
getGlobalTensorClass().prototype.transpose = function (perm) {
    this.throwIfDisposed();
    const result = transpose$2(this, perm);
    return result;
};
35598
35599 /**
35600 * @license
35601 * Copyright 2020 Google LLC. All Rights Reserved.
35602 * Licensed under the Apache License, Version 2.0 (the "License");
35603 * you may not use this file except in compliance with the License.
35604 * You may obtain a copy of the License at
35605 *
35606 * http://www.apache.org/licenses/LICENSE-2.0
35607 *
35608 * Unless required by applicable law or agreed to in writing, software
35609 * distributed under the License is distributed on an "AS IS" BASIS,
35610 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35611 * See the License for the specific language governing permissions and
35612 * limitations under the License.
35613 * =============================================================================
35614 */
/**
 * Tensor chaining API: verifies this tensor has not been disposed, then
 * forwards to `unique$3(this, axis)`.
 */
getGlobalTensorClass().prototype.unique = function (axis) {
    this.throwIfDisposed();
    const result = unique$3(this, axis);
    return result;
};
35619
35620 /**
35621 * @license
35622 * Copyright 2020 Google LLC. All Rights Reserved.
35623 * Licensed under the Apache License, Version 2.0 (the "License");
35624 * you may not use this file except in compliance with the License.
35625 * You may obtain a copy of the License at
35626 *
35627 * http://www.apache.org/licenses/LICENSE-2.0
35628 *
35629 * Unless required by applicable law or agreed to in writing, software
35630 * distributed under the License is distributed on an "AS IS" BASIS,
35631 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35632 * See the License for the specific language governing permissions and
35633 * limitations under the License.
35634 * =============================================================================
35635 */
/**
 * Tensor chaining API: verifies this tensor has not been disposed, then
 * forwards to `unsortedSegmentSum$2(this, segmentIds, numSegments)`.
 */
getGlobalTensorClass().prototype.unsortedSegmentSum = function (segmentIds, numSegments) {
    this.throwIfDisposed();
    const result = unsortedSegmentSum$2(this, segmentIds, numSegments);
    return result;
};
35641
35642 /**
35643 * @license
35644 * Copyright 2020 Google LLC. All Rights Reserved.
35645 * Licensed under the Apache License, Version 2.0 (the "License");
35646 * you may not use this file except in compliance with the License.
35647 * You may obtain a copy of the License at
35648 *
35649 * http://www.apache.org/licenses/LICENSE-2.0
35650 *
35651 * Unless required by applicable law or agreed to in writing, software
35652 * distributed under the License is distributed on an "AS IS" BASIS,
35653 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35654 * See the License for the specific language governing permissions and
35655 * limitations under the License.
35656 * =============================================================================
35657 */
/**
 * Tensor chaining API: verifies this tensor has not been disposed, then
 * forwards to `unstack(this, axis)`.
 */
getGlobalTensorClass().prototype.unstack = function (axis) {
    this.throwIfDisposed();
    const result = unstack(this, axis);
    return result;
};
35662
35663 /**
35664 * @license
35665 * Copyright 2020 Google LLC. All Rights Reserved.
35666 * Licensed under the Apache License, Version 2.0 (the "License");
35667 * you may not use this file except in compliance with the License.
35668 * You may obtain a copy of the License at
35669 *
35670 * http://www.apache.org/licenses/LICENSE-2.0
35671 *
35672 * Unless required by applicable law or agreed to in writing, software
35673 * distributed under the License is distributed on an "AS IS" BASIS,
35674 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35675 * See the License for the specific language governing permissions and
35676 * limitations under the License.
35677 * =============================================================================
35678 */
/**
 * Tensor chaining API: verifies this tensor has not been disposed, then
 * forwards to `where(condition, this, x)` — note `this` is passed as the
 * second argument, after `condition`.
 */
getGlobalTensorClass().prototype.where = function (condition, x) {
    this.throwIfDisposed();
    const result = where(condition, this, x);
    return result;
};
35683
35684 /**
35685 * @license
35686 * Copyright 2020 Google LLC. All Rights Reserved.
35687 * Licensed under the Apache License, Version 2.0 (the "License");
35688 * you may not use this file except in compliance with the License.
35689 * You may obtain a copy of the License at
35690 *
35691 * http://www.apache.org/licenses/LICENSE-2.0
35692 *
35693 * Unless required by applicable law or agreed to in writing, software
35694 * distributed under the License is distributed on an "AS IS" BASIS,
35695 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35696 * See the License for the specific language governing permissions and
35697 * limitations under the License.
35698 * =============================================================================
35699 */
/**
 * Tensor chaining API: verifies this tensor has not been disposed, then
 * forwards to `zerosLike$3(this)`.
 */
getGlobalTensorClass().prototype.zerosLike = function () {
    this.throwIfDisposed();
    const result = zerosLike$3(this);
    return result;
};
35704
35705 /**
35706 * @license
35707 * Copyright 2020 Google LLC. All Rights Reserved.
35708 * Licensed under the Apache License, Version 2.0 (the "License");
35709 * you may not use this file except in compliance with the License.
35710 * You may obtain a copy of the License at
35711 *
35712 * http://www.apache.org/licenses/LICENSE-2.0
35713 *
35714 * Unless required by applicable law or agreed to in writing, software
35715 * distributed under the License is distributed on an "AS IS" BASIS,
35716 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35717 * See the License for the specific language governing permissions and
35718 * limitations under the License.
35719 * =============================================================================
35720 */
35721
35722 /**
35723 * @license
35724 * Copyright 2018 Google LLC
35725 *
35726 * Use of this source code is governed by an MIT-style
35727 * license that can be found in the LICENSE file or at
35728 * https://opensource.org/licenses/MIT.
35729 * =============================================================================
35730 */
35731 /**
35732 * Explicit error types.
35733 *
35734 * See the following link for more information about why the code includes
35735 * calls to setPrototypeOf:
35736 *
35737 * https://github.com/Microsoft/TypeScript-wiki/blob/master/Breaking-Changes.md#extending-built-ins-like-error-array-and-map-may-no-longer-work
35738 */
35739 // tslint:enable
/**
 * Equivalent of Python's AttributeError.
 */
class AttributeError extends Error {
    constructor(message) {
        super(message);
        // Set the prototype explicitly.
        // Restores the prototype chain so `instanceof AttributeError` works
        // even when subclassing built-ins is broken by ES5 transpilation.
        Object.setPrototypeOf(this, AttributeError.prototype);
    }
}
/**
 * Equivalent of Python's RuntimeError.
 */
class RuntimeError extends Error {
    constructor(message) {
        super(message);
        // Set the prototype explicitly.
        // Restores the prototype chain so `instanceof RuntimeError` works
        // even when subclassing built-ins is broken by ES5 transpilation.
        Object.setPrototypeOf(this, RuntimeError.prototype);
    }
}
/**
 * Equivalent of Python's ValueError.
 */
class ValueError extends Error {
    constructor(message) {
        super(message);
        // Set the prototype explicitly.
        // Restores the prototype chain so `instanceof ValueError` works
        // even when subclassing built-ins is broken by ES5 transpilation.
        Object.setPrototypeOf(this, ValueError.prototype);
    }
}
/**
 * Equivalent of Python's NotImplementedError.
 */
class NotImplementedError extends Error {
    constructor(message) {
        super(message);
        // Set the prototype explicitly.
        // Restores the prototype chain so `instanceof NotImplementedError`
        // works even when subclassing built-ins is broken by ES5 transpilation.
        Object.setPrototypeOf(this, NotImplementedError.prototype);
    }
}
/**
 * Equivalent of Python's AssertionError.
 */
class AssertionError extends Error {
    constructor(message) {
        super(message);
        // Set the prototype explicitly.
        // Restores the prototype chain so `instanceof AssertionError` works
        // even when subclassing built-ins is broken by ES5 transpilation.
        Object.setPrototypeOf(this, AssertionError.prototype);
    }
}
/**
 * Equivalent of Python's IndexError.
 */
class IndexError extends Error {
    constructor(message) {
        super(message);
        // Set the prototype explicitly.
        // Restores the prototype chain so `instanceof IndexError` works
        // even when subclassing built-ins is broken by ES5 transpilation.
        Object.setPrototypeOf(this, IndexError.prototype);
    }
}
35800
35801 /**
35802 * @license
35803 * Copyright 2022 Google LLC
35804 *
35805 * Use of this source code is governed by an MIT-style
35806 * license that can be found in the LICENSE file or at
35807 * https://opensource.org/licenses/MIT.
35808 * =============================================================================
35809 */
35810 /**
35811 * LruCache: A mapping from the String to T. If the number of the entries is
35812 * exceeding the `maxEntries`, the LruCache will delete the least recently
35813 * used entry.
35814 */
class LruCache {
    /**
     * @param maxEntries Capacity of the cache. Falsy values (including 0 and
     *     `undefined`) fall back to the default of 100, matching the
     *     original behavior.
     */
    constructor(maxEntries) {
        this.maxEntries = maxEntries || 100;
        this.cache = new Map();
    }
    /**
     * Get the entry for the key and mark it as used recently.
     * Returns `undefined` when the key is absent.
     */
    get(key) {
        let entry;
        if (this.cache.has(key)) {
            entry = this.cache.get(key);
            // Delete + re-set moves the key to the most-recently-used
            // position (Map preserves insertion order; oldest key is first).
            this.cache.delete(key);
            this.cache.set(key, entry);
        }
        return entry;
    }
    /**
     * Put the entry into the cache. If the key already existed, mark the key as
     * used recently.
     */
    put(key, value) {
        if (this.cache.has(key)) {
            this.cache.delete(key);
        }
        else if (this.cache.size >= this.maxEntries) {
            // Evict the least recently used entry (first key in the Map).
            const keyToDelete = this.cache.keys().next().value;
            this.cache.delete(keyToDelete);
        }
        this.cache.set(key, value);
    }
    /**
     * Get the MaxEntries of the cache.
     */
    getMaxEntries() {
        return this.maxEntries;
    }
    /**
     * Set the MaxEntries of the cache. If the maxEntries is decreased, reduce
     * entries in the cache.
     *
     * Throws if `maxEntries` is negative.
     */
    setMaxEntries(maxEntries) {
        if (maxEntries < 0) {
            throw new Error(`The maxEntries of LRU caches must be at least 0, but got ${maxEntries}.`);
        }
        // Bug fix: evict based on the actual number of stored entries rather
        // than `oldMax - newMax`. The previous loop deleted that many entries
        // unconditionally, discarding cached values even when the cache
        // already fit within the new capacity.
        while (this.cache.size > maxEntries) {
            const keyToDelete = this.cache.keys().next().value;
            this.cache.delete(keyToDelete);
        }
        this.maxEntries = maxEntries;
    }
}
35869
35870 /**
35871 * @license
35872 * Copyright 2018 Google LLC
35873 *
35874 * Use of this source code is governed by an MIT-style
35875 * license that can be found in the LICENSE file or at
35876 * https://opensource.org/licenses/MIT.
35877 * =============================================================================
35878 */
35879 // tslint:enable
/**
 * If `value` is an Array, equivalent to Python's `value * numValues`.
 * If `value` is not an Array, equivalent to Python's `[value] * numValues`
 */
// tslint:disable-next-line:no-any
function pyListRepeat(value, numValues) {
    if (!Array.isArray(value)) {
        // Scalar case: a fresh array of `numValues` copies of `value`.
        return new Array(numValues).fill(value);
    }
    // Array case: append the elements of `value`, `numValues` times over.
    // tslint:disable-next-line:no-any
    const repeated = [];
    for (let i = 0; i < numValues; i++) {
        repeated.push(...value);
    }
    return repeated;
}
/**
 * Throws an `AssertionError` carrying `message` when `val` is falsy;
 * otherwise does nothing.
 */
function assert(val, message) {
    if (val) {
        return;
    }
    throw new AssertionError(message);
}
/**
 * Count the number of elements of the `array` that are equal to `reference`.
 *
 * Uses strict equality (`===`): no type coercion, and `NaN` never matches.
 *
 * @param array Iterable of items to scan.
 * @param reference Value to compare each item against.
 * @returns Number of items strictly equal to `reference`.
 */
function count(array, reference) {
    // Fix: the parameter was previously misspelled `refernce`; renaming a
    // positional JS parameter is safe for all callers.
    let counter = 0;
    for (const item of array) {
        if (item === reference) {
            counter++;
        }
    }
    return counter;
}
/**
 * If an array is of length 1, just return the first element. Otherwise, return
 * the full array.
 * @param xs Array to unwrap.
 */
function singletonOrArray(xs) {
    return xs.length === 1 ? xs[0] : xs;
}
/**
 * Normalizes a list/tensor into a list.
 *
 * If a tensor is passed, we return
 * a list of size 1 containing the tensor.
 *
 * @param x target object to be normalized.
 */
// tslint:disable-next-line:no-any
function toList(x) {
    return Array.isArray(x) ? x : [x];
}
/**
 * Generate a UID for a list
 *
 * Builds a comma-separated string of `Math.abs(obj.id)` for each object in
 * `objs` (a single object is treated as a one-element list). Throws a
 * `ValueError` if any object lacks an `id`.
 */
// tslint:disable-next-line:no-any
function objectListUid(objs) {
    const idStrings = [];
    for (const obj of toList(objs)) {
        if (obj.id == null) {
            throw new ValueError(`Object ${obj} passed to objectListUid without an id`);
        }
        idStrings.push(`${Math.abs(obj.id)}`);
    }
    return idStrings.join(', ');
}
/**
 * Converts string to snake-case.
 * @param name
 */
function toSnakeCase(name) {
    // Insert '_' before each capitalized word, then before any remaining
    // lower->upper transitions, and lowercase the result.
    const withBoundaries = name.replace(/(.)([A-Z][a-z0-9]+)/g, '$1_$2');
    const insecure = withBoundaries.replace(/([a-z])([A-Z])/g, '$1_$2').toLowerCase();
    /*
     If the class is private the name starts with "_" which is not secure
     for creating scopes. We prefix the name with "private" in this case.
     */
    return insecure.startsWith('_') ? 'private' + insecure : insecure;
}
/**
 * Converts a snake_case identifier to camelCase. Identifiers that are empty,
 * single-character, or contain no underscore are returned unchanged.
 */
function toCamelCase(identifier) {
    // Trivial strings and names without underscores need no conversion.
    if (identifier.length <= 1 || identifier.indexOf('_') === -1) {
        return identifier;
    }
    // Drop each run of underscores and capitalize the character after it.
    return identifier.replace(/[_]+(\w|$)/g, (m, p1) => p1.toUpperCase());
}
// Module-level registry mapping class names to user-registered
// constructors/configs; consulted by `deserializeKerasObject` before the
// built-in `moduleObjects`.
// tslint:disable-next-line:no-any
let _GLOBAL_CUSTOM_OBJECTS = {};
/**
 * Serializes an object into the `{className, config}` dictionary format,
 * using the instance's own `getClassName()` and `getConfig()` methods.
 * Returns `null` for `null`/`undefined` input.
 */
function serializeKerasObject(instance) {
    if (instance == null) {
        return null;
    }
    return {
        className: instance.getClassName(),
        config: instance.getConfig(),
    };
}
/**
 * Replace ndarray-style scalar objects in serialization objects with numbers.
 *
 * Background: In some versions of tf.keras, certain scalar values in the HDF5
 * model save file can be serialized as: `{'type': 'ndarray', 'value': num}`,
 * where in `num` is a plain number. This method converts such serialization
 * to a `number`.
 *
 * @param config The keras-format serialization object to be processed
 *   (in place).
 */
function convertNDArrayScalarsInConfig(config) {
    // Primitives and null/undefined: nothing to convert.
    if (config == null || typeof config !== 'object') {
        return;
    }
    // Arrays: recurse into each element.
    if (Array.isArray(config)) {
        for (const item of config) {
            convertNDArrayScalarsInConfig(item);
        }
        return;
    }
    // Plain objects: replace ndarray-scalar fields in place, recurse into the
    // rest.
    for (const field of Object.keys(config)) {
        const value = config[field];
        if (value == null || typeof value !== 'object') {
            continue;
        }
        const isNDArrayScalar = !Array.isArray(value) &&
            value['type'] === 'ndarray' && typeof value['value'] === 'number';
        if (isNDArrayScalar) {
            config[field] = value['value'];
        }
        else {
            convertNDArrayScalarsInConfig(value);
        }
    }
}
/**
 * Deserialize a saved Keras Object.
 *
 * If `identifier` is a string, it is resolved to a constructor/function by
 * looking, in order, at `customObjects`, the global custom-object registry
 * (`_GLOBAL_CUSTOM_OBJECTS`), and finally `moduleObjects`. Otherwise
 * `identifier` is treated as a Keras config dictionary of the form
 * `{className, config}` and the named class is reconstituted from it.
 *
 * @param identifier either a string ID or a saved Keras dictionary
 * @param moduleObjects a list of Python class names to object constructors
 * @param customObjects a list of Python class names to object constructors
 * @param printableModuleName debug text for the object being reconstituted
 * @param fastWeightInit Optional flag to use fast weight initialization
 *   during deserialization. This is applicable to cases in which
 *   the initialization will be immediately overwritten by loaded weight
 *   values. Default: `false`.
 * @returns a TensorFlow.js Layers object
 * @throws ValueError if the identifier cannot be resolved or the config
 *   dictionary is malformed.
 */
// tslint:disable:no-any
function deserializeKerasObject(identifier, moduleObjects = {}, customObjects = {}, printableModuleName = 'object', fastWeightInit = false) {
  // tslint:enable
  if (typeof identifier === 'string') {
    // A plain string names a function/class directly. Resolution order:
    // per-call customObjects, then globally registered custom objects, then
    // the built-in moduleObjects.
    const functionName = identifier;
    let fn;
    if (functionName in customObjects) {
      fn = customObjects[functionName];
    }
    else if (functionName in _GLOBAL_CUSTOM_OBJECTS) {
      fn = _GLOBAL_CUSTOM_OBJECTS[functionName];
    }
    else {
      fn = moduleObjects[functionName];
      if (fn == null) {
        throw new ValueError(`Unknown ${printableModuleName}: ${identifier}. ` +
          `This may be due to one of the following reasons:\n` +
          `1. The ${printableModuleName} is defined in Python, in which ` +
          `case it needs to be ported to TensorFlow.js or your JavaScript ` +
          `code.\n` +
          `2. The custom ${printableModuleName} is defined in JavaScript, ` +
          `but is not registered properly with ` +
          `tf.serialization.registerClass().`);
        // TODO(cais): Add link to tutorial page on custom layers.
      }
    }
    return fn;
  }
  else {
    // In this case we are dealing with a Keras config dictionary.
    const config = identifier;
    if (config['className'] == null || config['config'] == null) {
      // Fixed grammar of the error message ("must set" -> "must be set").
      throw new ValueError(`${printableModuleName}: Improper config format: ` +
        `${JSON.stringify(config)}.\n` +
        `'className' and 'config' must be set.`);
    }
    const className = config['className'];
    let cls, fromConfig;
    if (className in customObjects) {
      [cls, fromConfig] = customObjects[className];
    }
    else if (className in _GLOBAL_CUSTOM_OBJECTS) {
      // Bug fix: index with the actual class name. The previous code used the
      // literal string 'className' as the key, so globally registered custom
      // objects could never be found through this branch.
      [cls, fromConfig] = _GLOBAL_CUSTOM_OBJECTS[className];
    }
    else if (className in moduleObjects) {
      [cls, fromConfig] = moduleObjects[className];
    }
    if (cls == null) {
      throw new ValueError(`Unknown ${printableModuleName}: ${className}. ` +
        `This may be due to one of the following reasons:\n` +
        `1. The ${printableModuleName} is defined in Python, in which ` +
        `case it needs to be ported to TensorFlow.js or your JavaScript ` +
        `code.\n` +
        `2. The custom ${printableModuleName} is defined in JavaScript, ` +
        `but is not registered properly with ` +
        `tf.serialization.registerClass().`);
      // TODO(cais): Add link to tutorial page on custom layers.
    }
    if (fromConfig != null) {
      // Porting notes: Instead of checking to see whether fromConfig accepts
      // customObjects, we create a customObjects dictionary and tack it on to
      // config['config'] as config['config'].customObjects. Objects can use it,
      // if they want.
      // tslint:disable-next-line:no-any
      const customObjectsCombined = {};
      for (const key of Object.keys(_GLOBAL_CUSTOM_OBJECTS)) {
        customObjectsCombined[key] = _GLOBAL_CUSTOM_OBJECTS[key];
      }
      for (const key of Object.keys(customObjects)) {
        customObjectsCombined[key] = customObjects[key];
      }
      // Add the customObjects to config
      const nestedConfig = config['config'];
      nestedConfig['customObjects'] = customObjectsCombined;
      // Temporarily merge the per-call customObjects into the global registry
      // for the duration of fromConfig(), then restore the backup.
      const backupCustomObjects = Object.assign({}, _GLOBAL_CUSTOM_OBJECTS);
      for (const key of Object.keys(customObjects)) {
        _GLOBAL_CUSTOM_OBJECTS[key] = customObjects[key];
      }
      convertNDArrayScalarsInConfig(config['config']);
      const returnObj = fromConfig(cls, config['config'], customObjects, fastWeightInit);
      _GLOBAL_CUSTOM_OBJECTS = Object.assign({}, backupCustomObjects);
      return returnObj;
    }
    else {
      // Then `cls` may be a function returning a class.
      // In this case by convention `config` holds
      // the kwargs of the function.
      const backupCustomObjects = Object.assign({}, _GLOBAL_CUSTOM_OBJECTS);
      for (const key of Object.keys(customObjects)) {
        _GLOBAL_CUSTOM_OBJECTS[key] = customObjects[key];
      }
      // In python this is **config['config'], for tfjs-layers we require
      // classes that use this fall-through construction method to take
      // a config interface that mimics the expansion of named parameters.
      const returnObj = new cls(config['config']);
      _GLOBAL_CUSTOM_OBJECTS = Object.assign({}, backupCustomObjects);
      return returnObj;
    }
  }
}
/**
 * Comparator for sorting numbers in ascending order.
 * @param a First number.
 * @param b Second number.
 * @returns -1 if `a < b`, 1 if `a > b`, 0 otherwise.
 */
function numberCompare(a, b) {
  if (a < b) {
    return -1;
  }
  if (a > b) {
    return 1;
  }
  return 0;
}
/**
 * Comparator for sorting numbers in descending order.
 *
 * Equivalent to negating `numberCompare(a, b)` (inlined here).
 * @param a First number.
 * @param b Second number.
 * @returns 1 if `a < b`, -1 if `a > b`, -0 otherwise.
 */
function reverseNumberCompare(a, b) {
  return -1 * ((a < b) ? -1 : ((a > b) ? 1 : 0));
}
/**
 * Convert a string into the corresponding DType.
 *
 * Only 'float32' is currently supported; anything else raises.
 * @param dtype
 * @returns An instance of DType.
 * @throws ValueError for any unsupported dtype string.
 */
function stringToDType(dtype) {
  if (dtype === 'float32') {
    return 'float32';
  }
  throw new ValueError(`Invalid dtype: ${dtype}`);
}
/**
 * Test the element-by-element equality of two Arrays of strings.
 * @param xs First array of strings.
 * @param ys Second array of strings.
 * @returns Whether the two arrays are all equal, element by element.
 *   If either argument is null/undefined, returns true only when both are
 *   the same null-ish value.
 */
function stringsEqual(xs, ys) {
  if (xs == null || ys == null) {
    return xs === ys;
  }
  if (xs.length !== ys.length) {
    return false;
  }
  let i = 0;
  while (i < xs.length) {
    if (xs[i] !== ys[i]) {
      return false;
    }
    i++;
  }
  return true;
}
/**
 * Get the unique elements of an array (or any iterable), preserving the
 * order of first occurrence. Uses strict-equality lookup via indexOf.
 * @param xs Array.
 * @returns An Array consisting of the unique elements in `xs`; `xs` itself
 *   when it is null/undefined.
 */
function unique$2(xs) {
  if (xs == null) {
    return xs;
  }
  const result = [];
  // TODO(cais): Maybe improve performance by sorting.
  for (const item of xs) {
    const alreadySeen = result.indexOf(item) !== -1;
    if (!alreadySeen) {
      result.push(item);
    }
  }
  return result;
}
/**
 * Determine if an Object is empty (i.e., does not have own properties).
 * @param obj Object
 * @returns Whether the Object is empty.
 * @throws ValueError: If object is `null` or `undefined`.
 */
function isObjectEmpty(obj) {
  if (obj == null) {
    throw new ValueError(`Invalid value in obj: ${JSON.stringify(obj)}`);
  }
  let empty = true;
  for (const prop in obj) {
    // Only own properties count; inherited ones are ignored.
    if (obj.hasOwnProperty(prop)) {
      empty = false;
      break;
    }
  }
  return empty;
}
/**
 * Helper function used to build type union/enum run-time checkers.
 * @param values The list of allowed values.
 * @param label A string name for the type
 * @param value The value to test; null/undefined is always accepted.
 * @throws ValueError: If the value is not in values nor `undefined`/`null`.
 */
function checkStringTypeUnionValue(values, label, value) {
  if (value == null) {
    return;
  }
  if (!values.includes(value)) {
    throw new ValueError(`${value} is not a valid ${label}. Valid values are ${values} or null/undefined.`);
  }
}
/**
 * Helper function for verifying the types of inputs.
 *
 * Ensures that the elements of `x` are all of type `expectedType`.
 * Also verifies that the length of `x` is within bounds.
 *
 * @param x Object to test.
 * @param expectedType The string expected type of all of the elements in the
 *   Array (compared via `typeof`).
 * @param minLength Return false if x.length is less than this.
 * @param maxLength Return false if x.length is greater than this.
 * @returns true if and only if `x` is an `Array<expectedType>` with
 *   length >= `minLength` and <= `maxLength`.
 */
// tslint:disable:no-any
function checkArrayTypeAndLength(x, expectedType, minLength = 0, maxLength = Infinity) {
  assert(minLength >= 0);
  assert(maxLength >= minLength);
  if (!Array.isArray(x)) {
    return false;
  }
  if (x.length < minLength || x.length > maxLength) {
    return false;
  }
  return x.every((item) => typeof item === expectedType);
}
// tslint:enable:no-any
/**
 * Assert that a value or an array of value are positive integer.
 *
 * Arrays are validated recursively, element by element; an empty array is
 * rejected.
 * @param value The value being asserted on. May be a single number or an array
 *   of numbers.
 * @param name Name of the value, used to make the error message.
 */
function assertPositiveInteger(value, name) {
  if (Array.isArray(value)) {
    assert$1(value.length > 0, () => `${name} is unexpectedly an empty array.`);
    value.forEach((v, i) => assertPositiveInteger(v, `element ${i + 1} of ${name}`));
    return;
  }
  assert$1(Number.isInteger(value) && value > 0, () => `Expected ${name} to be a positive integer, but got ` +
    `${formatAsFriendlyString(value)}.`);
}
/**
 * Format a value into a display-friendly, human-readable fashion.
 *
 * - `null` is formatted as `'null'`
 * - Strings are formatted with a flanking pair of quotes.
 * - Arrays are formatted recursively, with flanking square brackets.
 * - Everything else is stringified with template-literal coercion.
 *
 * @param value The value to display.
 * @return Formatted string.
 */
// tslint:disable-next-line:no-any
function formatAsFriendlyString(value) {
  if (value === null) {
    return 'null';
  }
  if (Array.isArray(value)) {
    return `[${value.map((v) => formatAsFriendlyString(v)).join(',')}]`;
  }
  if (typeof value === 'string') {
    return `"${value}"`;
  }
  return `${value}`;
}
/**
 * Returns a function `f2` (decorator) which wraps the original function
 * `f`. `f2` guarantees that `f` can be called at most once
 * every `waitMs` ms. If `f2` is called more often, it will return
 * the last returned result of `f`.
 *
 * @param f The original function `f` to wrap.
 * @param waitMs The time between two consecutive calls to `f` in ms.
 * @param nowFunc Optional clock override (returns current time in ms);
 *   defaults to the module's `now()`.
 */
function debounce(f, waitMs, nowFunc) {
  const currentTime = () => (nowFunc != null ? nowFunc() : now());
  let lastTime = currentTime();
  let lastResult;
  return (...args) => {
    const t = currentTime();
    if (t - lastTime < waitMs) {
      // Too soon: replay the cached result without invoking f.
      return lastResult;
    }
    lastTime = t;
    lastResult = f(...args);
    return lastResult;
  };
}
/**
 * Returns the fusable activation given a layers identifier.
 *
 * Only 'relu', 'linear' and 'elu' have fused-kernel equivalents; for these
 * the fused name is identical to the layers name.
 * @param activationName The layers identifier string.
 * @return The name of the fusable activation, or null if not fusable.
 */
function mapActivationToFusedKernel(activationName) {
  switch (activationName) {
    case 'relu':
    case 'linear':
    case 'elu':
      return activationName;
    default:
      return null;
  }
}
/**
 * Returns the cartesian product of sets of values.
 * This works the same as itertools.product in Python.
 *
 * Example:
 *
 *   filters = [128, 256]
 *   paddings = ['same', 'valid']
 *
 *   product = [[128, 'same'], [256, 'same'], [128, 'valid'], [256, 'valid']]
 *
 * (For each value of a later array, all previously accumulated tuples are
 * extended, so later arrays vary slowest.)
 *
 * @param arrayOfValues List/array of values.
 * @return The cartesian product.
 */
function getCartesianProductOfValues(...arrayOfValues) {
  assert(arrayOfValues.length > 0, 'arrayOfValues is empty');
  for (const values of arrayOfValues) {
    assert(Array.isArray(values), 'one of the values is not an array');
    assert(values.length > 0, 'one of the values is empty');
  }
  let products = [];
  for (const values of arrayOfValues) {
    if (products.length === 0) {
      // Seed the accumulator with single-element tuples.
      products = values.map((value) => [value]);
      continue;
    }
    const extended = [];
    for (const value of values) {
      for (const prev of products) {
        extended.push([...prev, value]);
      }
    }
    products = extended;
  }
  return products;
}
36383
36384 /**
36385 * @license
36386 * Copyright 2018 Google LLC
36387 *
36388 * Use of this source code is governed by an MIT-style
36389 * license that can be found in the LICENSE file or at
36390 * https://opensource.org/licenses/MIT.
36391 * =============================================================================
36392 */
/**
 * Utilities related to persistent state in the backend.
 */
/**
 * An ID to track `tf.SymbolicTensor`s and derived classes.
 * Required in different places in engine/topology.ts to identify unique
 * tensors.
 */
let _nextUniqueTensorId = 0;
/** Returns the next unused tensor id (0, 1, 2, ...), advancing the counter. */
function getNextUniqueTensorId() {
  const id = _nextUniqueTensorId;
  _nextUniqueTensorId += 1;
  return id;
}
// Per-prefix counters backing getUid().
const _uidPrefixes = {};
/**
 * Provides a unique UID given a string prefix.
 *
 * Successive calls with the same prefix yield `prefix1`, `prefix2`, ...
 * @param prefix
 */
function getUid(prefix = '') {
  const nextIndex = (prefix in _uidPrefixes) ? _uidPrefixes[prefix] + 1 : 1;
  _uidPrefixes[prefix] = nextIndex;
  return prefix + nextIndex.toString();
}
36418
36419 /**
36420 * @license
36421 * Copyright 2018 Google LLC
36422 *
36423 * Use of this source code is governed by an MIT-style
36424 * license that can be found in the LICENSE file or at
36425 * https://opensource.org/licenses/MIT.
36426 * =============================================================================
36427 */
// String unions accepted by the run-time checkers built on
// checkStringTypeUnionValue (see checkDataFormat / checkPaddingMode /
// checkPoolMode etc. below).
const VALID_DATA_FORMAT_VALUES = ['channelsFirst', 'channelsLast'];
const VALID_INTERPOLATION_FORMAT_VALUES = ['nearest', 'bilinear'];
const VALID_PADDING_MODE_VALUES = ['valid', 'same', 'causal'];
const VALID_POOL_MODE_VALUES = ['max', 'avg'];
// NOTE(review): presumably the merge modes for Bidirectional RNN wrappers and
// the sample-weight modes for model compilation — confirm against the callers.
const VALID_BIDIRECTIONAL_MERGE_MODES = ['sum', 'mul', 'concat', 'ave'];
const VALID_SAMPLE_WEIGHT_MODES = ['temporal'];
36434
36435 /**
36436 * @license
36437 * Copyright 2018 Google LLC
36438 *
36439 * Use of this source code is governed by an MIT-style
36440 * license that can be found in the LICENSE file or at
36441 * https://opensource.org/licenses/MIT.
36442 * =============================================================================
36443 */
// A map from the requested scoped name of a Tensor to the number of Tensors
// wanting that name so far. This allows enforcing name uniqueness by appending
// an incrementing index, e.g. scope/name, scope/name_1, scope/name_2, etc.
// (Consumed by `getUniqueTensorName` below.)
const nameMap = new Map();
/**
 * Validates a DataFormat string ('channelsFirst' | 'channelsLast').
 * Accepts null/undefined; throws ValueError for any other value.
 */
function checkDataFormat(value) {
  checkStringTypeUnionValue(VALID_DATA_FORMAT_VALUES, 'DataFormat', value);
}
/**
 * Validates an InterpolationFormat string ('nearest' | 'bilinear').
 * Accepts null/undefined; throws ValueError for any other value.
 */
function checkInterpolationFormat(value) {
  checkStringTypeUnionValue(VALID_INTERPOLATION_FORMAT_VALUES, 'InterpolationFormat', value);
}
/**
 * Validates a PaddingMode string ('valid' | 'same' | 'causal').
 * Accepts null/undefined; throws ValueError for any other value.
 */
function checkPaddingMode(value) {
  checkStringTypeUnionValue(VALID_PADDING_MODE_VALUES, 'PaddingMode', value);
}
/**
 * Validates a PoolMode string ('max' | 'avg').
 * Accepts null/undefined; throws ValueError for any other value.
 */
function checkPoolMode(value) {
  checkStringTypeUnionValue(VALID_POOL_MODE_VALUES, 'PoolMode', value);
}
// Stack of currently entered name scopes, innermost last.
const _nameScopeStack = [];
const _nameScopeDivider = '/';
/**
 * Enter namescope, which can be nested.
 *
 * Pushes `name` for the duration of `fn` and always pops it again,
 * whether `fn` returns normally or throws.
 */
function nameScope(name, fn) {
  _nameScopeStack.push(name);
  try {
    return fn();
  } finally {
    // Unwind the scope stack on both the normal and the exceptional path.
    _nameScopeStack.pop();
  }
}
/**
 * Get the current namescope as a flat, concatenated string.
 * @returns '' when no scope is active; otherwise the joined scope names
 *   with a trailing divider, e.g. 'a/b/'.
 */
function currentNameScopePrefix() {
  return _nameScopeStack.length === 0
    ? ''
    : _nameScopeStack.join(_nameScopeDivider) + _nameScopeDivider;
}
/**
 * Get the name a Tensor (or Variable) would have if not uniqueified.
 * @param tensorName
 * @return Scoped name string (current scope prefix + tensorName).
 * @throws Error if `tensorName` is not a valid tensor name.
 */
function getScopedTensorName(tensorName) {
  if (!isValidTensorName(tensorName)) {
    throw new Error('Not a valid tensor name: \'' + tensorName + '\'');
  }
  const prefix = currentNameScopePrefix();
  return `${prefix}${tensorName}`;
}
/**
 * Get unique names for Tensors and Variables.
 * @param scopedName The fully-qualified name of the Tensor, i.e. as produced by
 *   `getScopedTensorName()`.
 * @return A unique version of the given fully scoped name.
 *   If this is the first time that the scoped name is seen in this session,
 *   then the given `scopedName` is returned unaltered. If the same name is
 *   seen again (producing a collision), an incrementing suffix is added to the
 *   end of the name, so it takes the form 'scope/name_1', 'scope/name_2', etc.
 * @throws Error if `scopedName` is not a valid tensor name.
 */
function getUniqueTensorName(scopedName) {
  if (!isValidTensorName(scopedName)) {
    throw new Error('Not a valid tensor name: \'' + scopedName + '\'');
  }
  const priorUses = nameMap.has(scopedName) ? nameMap.get(scopedName) : 0;
  nameMap.set(scopedName, priorUses + 1);
  if (priorUses === 0) {
    return scopedName;
  }
  const uniqueName = `${scopedName}_${priorUses}`;
  // Reserve the derived name too, in case someone later requests
  // getUniqueTensorName("name_1") explicitly.
  nameMap.set(uniqueName, 1);
  return uniqueName;
}
// A tensor name must start with an alphanumeric character and may contain
// alphanumerics, '-', '.', '_' and '/'.
const tensorNameRegex = new RegExp(/^[A-Za-z0-9][-A-Za-z0-9\._\/]*$/);
/**
 * Determine whether a string is a valid tensor name.
 * @param name
 * @returns A Boolean indicating whether `name` is a valid tensor name.
 */
function isValidTensorName(name) {
  return tensorNameRegex.test(name);
}
36538
36539 /**
36540 * @license
36541 * Copyright 2018 Google LLC
36542 *
36543 * Use of this source code is governed by an MIT-style
36544 * license that can be found in the LICENSE file or at
36545 * https://opensource.org/licenses/MIT.
36546 * =============================================================================
36547 */
/**
 * Determine if a number is an integer.
 *
 * Uses round-trip through parseInt on the decimal string form (so very large
 * values that stringify in exponential notation are rejected).
 */
function isInteger(x) {
  return parseInt(x.toString(), 10) === x;
}
/**
 * Calculate the product of an array of numbers.
 * @param array The array to calculate the product over.
 * @param begin Beginning index, inclusive. Defaults to 0.
 * @param end Ending index, exclusive. Defaults to `array.length`.
 * @return The product (1 for an empty range).
 */
function arrayProd(array, begin, end) {
  const from = begin == null ? 0 : begin;
  const to = end == null ? array.length : end;
  let product = 1;
  for (let i = from; i < to; ++i) {
    product *= array[i];
  }
  return product;
}
/**
 * Compute minimum value.
 * @param array
 * @return minimum value (NaN for an empty array; NaN elements are skipped,
 *   matching tf.min()'s behavior for the empty case).
 */
function min$2(array) {
  // same behavior as tf.min()
  if (array.length === 0) {
    return Number.NaN;
  }
  let smallest = Number.POSITIVE_INFINITY;
  for (const value of array) {
    if (value < smallest) {
      smallest = value;
    }
  }
  return smallest;
}
/**
 * Compute maximum value.
 * @param array
 * @return maximum value (NaN for an empty array; NaN elements are skipped,
 *   matching tf.max()'s behavior for the empty case).
 */
function max$2(array) {
  // same behavior as tf.max()
  if (array.length === 0) {
    return Number.NaN;
  }
  let largest = Number.NEGATIVE_INFINITY;
  for (const value of array) {
    if (value > largest) {
      largest = value;
    }
  }
  return largest;
}
/**
 * Compute sum of array.
 * @param array
 * @return The sum (0 for an empty array).
 */
function sum$2(array) {
  let total = 0;
  for (const value of array) {
    total += value;
  }
  return total;
}
/**
 * Compute mean of array.
 *
 * Sums in array order (same accumulation as `sum$2`, inlined here) and
 * divides by length; an empty array yields NaN.
 * @param array
 * @return The mean.
 */
function mean$1(array) {
  let total = 0;
  for (let i = 0; i < array.length; i++) {
    total += array[i];
  }
  return total / array.length;
}
/**
 * Compute (population) variance of array: mean of squared deviations from
 * the mean, dividing by `array.length`.
 * @param array
 * @return The variance.
 */
function variance(array) {
  // Mean computed with the same in-order accumulation as mean$1/sum$2.
  let total = 0;
  for (let i = 0; i < array.length; i++) {
    total += array[i];
  }
  const meanValue = total / array.length;
  let sumSquare = 0;
  for (let i = 0; i < array.length; i++) {
    const deviation = array[i] - meanValue;
    sumSquare += deviation * deviation;
  }
  return sumSquare / array.length;
}
/**
 * Compute median of array.
 *
 * Sorts a copy numerically; for even-length arrays, averages the two
 * middle elements.
 * @param array
 * @return The median value.
 */
function median(array) {
  const sorted = array.slice().sort((a, b) => a - b);
  const mid = (sorted.length - 1) / 2;
  const lowIdx = Math.floor(mid);
  const highIdx = Math.ceil(mid);
  if (lowIdx === highIdx) {
    return sorted[lowIdx];
  }
  return (sorted[lowIdx] + sorted[highIdx]) / 2;
}
/**
 * Generate an array of integers in [begin, end).
 * @param begin Beginning integer, inclusive.
 * @param end Ending integer, exclusive.
 * @returns Range array.
 * @throws ValueError, iff `end` < `begin`.
 */
function range$2(begin, end) {
  if (end < begin) {
    throw new ValueError(`end (${end}) < begin (${begin}) is forbidden.`);
  }
  const out = [];
  for (let value = begin; value < end; ++value) {
    out.push(value);
  }
  return out;
}
36679
36680 /**
36681 * @license
36682 * Copyright 2018 Google LLC
36683 *
36684 * Use of this source code is governed by an MIT-style
36685 * license that can be found in the LICENSE file or at
36686 * https://opensource.org/licenses/MIT.
36687 * =============================================================================
36688 */
// Cached fuzz factor; populated lazily from the backend on first use,
// or explicitly via setEpsilon().
let _epsilon;
/**
 * Returns the value of the fuzz factor used in numeric expressions.
 */
function epsilon$1() {
  if (_epsilon != null) {
    return _epsilon;
  }
  // Lazily pull the default epsilon from the active backend.
  _epsilon = backend$1().epsilon();
  return _epsilon;
}
/**
 * Sets the value of the fuzz factor used in numeric expressions.
 * @param e New value of epsilon.
 */
function setEpsilon(e) {
  _epsilon = e;
}
/**
 * Returns the default image data format convention.
 * @returns Always 'channelsLast' (the only data format this build uses as
 *   the default).
 */
function imageDataFormat() {
  return 'channelsLast';
}
36712
36713 /**
36714 * @license
36715 * Copyright 2018 Google LLC
36716 *
36717 * Use of this source code is governed by an MIT-style
36718 * license that can be found in the LICENSE file or at
36719 * https://opensource.org/licenses/MIT.
36720 * =============================================================================
36721 */
// tslint:enable
/* Setting and getting backend from deeplearn.js. */
// Default deeplearn.js backend is WebGL (GPU).
let backend = 'webgl';
/**
 * Sets the active backend via the core setBackend$1, then records the
 * requested name locally for getBackend() to report. If setBackend$1 throws,
 * the cached name is left unchanged.
 * @param requestedBackend Backend name, e.g. 'webgl' or 'cpu'.
 */
function setBackend(requestedBackend) {
  setBackend$1(requestedBackend);
  backend = requestedBackend;
}
/**
 * Returns the backend name last passed to setBackend(), or the default
 * 'webgl' if setBackend() has never been called.
 */
function getBackend() {
  return backend;
}
/**
 * Indicates whether the backend is operating symbolically.
 *
 * This function will be used to determine how to interpret user code. If
 * it returns true, calls to the backend construct a symbolic graph; if
 * it returns false, calls to the backend execute immediately.
 * @returns Always `false` in this implementation (eager execution only).
 */
function isBackendSymbolic() {
  return false;
}
/**
 * Get the number of elements in a Tensor.
 * @param x The Tensor.
 * @return Number of elements in `x` (product of the shape dimensions;
 *   1 for a scalar, whose shape is empty).
 */
function countParams(x) {
  const shape = x.shape;
  // A scalar has an empty shape and exactly one element.
  return shape.length === 0 ? 1 : shape.reduce((a, b) => a * b);
}
/**
 * Casts a tensor to a different dtype and returns it.
 *
 * Thin wrapper that delegates directly to the core `cast$3` op.
 * @param x Input tensor.
 * @param dtype String: 'float32'|'int32'|'bool'.
 * @returns Tensor of the specified `dtype`.
 */
function cast$2(x, dtype) {
  return cast$3(x, dtype);
}
/**
 * Adds a 1-sized dimension at index "axis".
 * @param x Input tensor.
 * @param axis Position where to add the new axis. Negative values count from
 *   the end of the expanded shape (default -1 appends a trailing axis).
 * @returns Result of the dimension expansion.
 */
function expandDims$2(x, axis = -1) {
  const outShape = x.shape.slice();
  if (axis < 0) {
    // Normalize relative to the length of the *expanded* shape.
    axis += outShape.length + 1;
  }
  outShape.splice(axis, 0, 1);
  return reshape$3(x, outShape);
}
/**
 * Repeats a 2D tensor.
 *
 * If `x` has shape `[samples, dim]` and `n` is 2, for example, the output
 * will have shape `[samples, 2, dim]`.
 *
 * @param x Input tensor.
 * @param n Integer, number of times to repeat.
 * @returns The result of the repeat operation.
 * @throws ValueError: If input tensor is not 2D.
 */
function repeat(x, n) {
  return tidy(() => {
    if (x.shape.length !== 2) {
      throw new ValueError(`repeat() expects a rank-2 tensor, but received a ` +
        `rank-${x.shape.length} tensor.`);
    }
    // Insert a middle axis, then tile along it n times.
    const expanded = expandDims$2(x, 1);
    return tile$2(expanded, [1, n, 1]);
  });
}
/**
 * Flatten a Tensor into 1D.
 * @param x Input tensor.
 * @return The result of the flattening `x` (shape: [total element count]).
 */
function flatten$1(x) {
  return reshape$3(x, [arrayProd(x.shape)]);
}
/**
 * Turn a nD tensor into a 2D tensor with same 0th dimension.
 * In other words, it flattens each data samples of a batch.
 *
 * @param x The tensor to flatten. The rank of this tensor is required to be 2
 *   or higher.
 * @return The result of the flattening (shape: [batch, product of the rest]).
 * @throws ValueError if `x` has rank <= 1.
 */
function batchFlatten(x) {
  if (x.rank <= 1) {
    throw new ValueError(`batchFlatten requires a minimum rank of 2. Got rank: ${x.rank}.`);
  }
  return reshape$3(x, [x.shape[0], arrayProd(x.shape, 1)]);
}
/**
 * Do slicing along the first axis.
 * @param array input `tf.Tensor`.
 * @param start starting index, inclusive.
 * @param size size of the slice along the first axis.
 * @returns result of the slicing.
 * @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`.
 */
function sliceAlongFirstAxis(array, start, size) {
  return tidy(() => {
    // Begin at `start` on axis 0 and 0 elsewhere; take `size` elements on
    // axis 0 and the full extent of every other axis.
    const begins = new Array(array.rank).fill(0);
    begins[0] = start;
    const sizes = array.shape.slice();
    sizes[0] = size;
    switch (array.rank) {
      case 1:
        return slice1d(array, start, size);
      case 2:
        return slice2d(array, begins, sizes);
      case 3:
        return slice3d(array, begins, sizes);
      case 4:
        return slice4d(array, begins, sizes);
      case 5:
      case 6:
        return slice$2(array, begins, sizes);
      default:
        throw new ValueError(`sliceAlongFirstAxis() received an unsupported tensor rank: ` +
          `${array.rank}`);
    }
  });
}
/**
 * Do slicing along the last axis.
 * @param array input `tf.Tensor`.
 * @param start starting index, inclusive.
 * @param size size of the slice along the last axis.
 * @returns result of the slicing.
 * @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`.
 */
function sliceAlongLastAxis(array, start, size) {
  return tidy(() => {
    // Begin at `start` on the last axis and 0 elsewhere; take `size`
    // elements there and the full extent of every other axis.
    const lastAxis = array.rank - 1;
    const begins = new Array(array.rank).fill(0);
    begins[lastAxis] = start;
    const sizes = array.shape.slice();
    sizes[lastAxis] = size;
    switch (array.rank) {
      case 1:
        return slice1d(array, start, size);
      case 2:
        return slice2d(array, begins, sizes);
      case 3:
        return slice3d(array, begins, sizes);
      case 4:
        return slice4d(array, begins, sizes);
      default:
        throw new ValueError(`sliceAlongLastAxis() received an unsupported tensor rank: ` +
          `${array.rank}`);
    }
  });
}
/**
 * Do slicing along the specified axis.
 * @param array input `tf.Tensor`.
 * @param start starting index, inclusive.
 * @param size size of the slice along the chosen axis.
 * @param axis the axis to slice along. Note: 1-based here (1 = first axis),
 *   as shown by the dispatch below mapping axis 1 to sliceAlongFirstAxis.
 * @returns result of the slicing.
 * @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`,
 *   or if `axis` is out of range for the tensor's rank.
 */
function sliceAlongAxis(array, start, size, axis) {
  return tidy(() => {
    switch (array.rank) {
      case 1:
        // Rank 1: the only axis is the first axis; `axis` is ignored.
        return slice1d(array, start, size);
      case 2:
        switch (axis) {
          case 1:
            return sliceAlongFirstAxis(array, start, size);
          case 2:
            return sliceAlongLastAxis(array, start, size);
          default:
            throw new ValueError(`The axis is not within the rank of the tensor ` +
              `${axis}`);
        }
      case 3:
        switch (axis) {
          case 1:
            return sliceAlongFirstAxis(array, start, size);
          case 2:
            // Middle axis: no helper exists, slice explicitly with the
            // rank-3 kernel.
            return slice3d(array, [0, start, 0], [array.shape[0], size, array.shape[2]]);
          case 3:
            return sliceAlongLastAxis(array, start, size);
          default:
            throw new ValueError(`The axis is not within the rank of the tensor ` +
              `${axis}`);
        }
      case 4:
        switch (axis) {
          case 1:
            return sliceAlongFirstAxis(array, start, size);
          case 2:
            return slice4d(array, [0, start, 0, 0], [array.shape[0], size, array.shape[2], array.shape[3]]);
          case 3:
            return slice4d(array, [0, 0, start, 0], [array.shape[0], array.shape[1], size, array.shape[3]]);
          case 4:
            return sliceAlongLastAxis(array, start, size);
          default:
            throw new ValueError(`The axis is not within the rank of the tensor ` +
              `${axis}`);
        }
      default:
        // NOTE(review): the message names sliceAlongLastAxis() but this is
        // sliceAlongAxis() — likely a copy-paste slip in the original
        // message text.
        throw new ValueError(`sliceAlongLastAxis() received an unsupported tensor rank: ` +
          `${array.rank}`);
    }
  });
}
/**
 * Concatenates a list of tensors alongside the specified axis.
 * @param tensors `Array` of tensors to concatenate.
 * @param axis Concatenation axis. Negative values are normalized against the
 *   rank of the first tensor.
 * @returns The result of the concatenation.
 */
function concatenate$2(tensors, axis = -1) {
  if (axis < 0) {
    const rank = tensors[0].rank;
    // A negative axis means "last axis"; rank-0 tensors use axis 0.
    axis = rank !== 0 ? rank : 0;
  }
  if (axis === tensors[0].rank) {
    // Porting Note: This is necessary because tfc.concat() requires axis to be
    // in the interval [-rank, rank).
    axis = -1;
  }
  // Porting Note: Sparse concat is not supported yet.
  return concat$2(tensors, axis);
}
/**
 * Concatenate two arrays along the first dimension.
 * @param a The 1st `tf.Tensor` to concatenate.
 * @param b The 2nd `tf.Tensor` to concatenate.
 * @returns Result of the concatenation.
 * @throws ValueError: If `a` is of an unsupported subtype of `tf.Tensor`.
 */
function concatAlongFirstAxis(a, b) {
  const pair = [a, b];
  if (a.rank === 1) {
    return concat1d(pair);
  }
  if (a.rank === 2) {
    return concat2d(pair, 0);
  }
  if (a.rank === 3) {
    return concat3d(pair, 0);
  }
  if (a.rank === 4) {
    return concat4d(pair, 0);
  }
  throw new ValueError(`concatAlongFirstAxis() received an unsupported ` +
    `tensor rank: ${a.rank}`);
}
/**
 * Creates a tensor by tiling `x` by `n`.
 * @param x A tensor.
 * @param n An Array of integers or a single integer. If an Array, the length
 *   must be the same as the number of dimensions in `x`. If a single integer,
 *   it will be treated as an Array of length 1.
 * @throws ValueError when the length of `n` does not match the rank of `x`.
 */
function tile$2(x, n) {
  const reps = Array.isArray(n) ? n : [n];
  if (x.rank !== reps.length) {
    throw new ValueError(`The length of input n (${reps.length}) does not match ` +
      `the number of dimensions in input x (${x.rank})`);
  }
  return tile$3(x, reps);
}
/* Creation of random tensors. */
/**
 * Get a tensor with normal distribution of values.
 *
 * Thin wrapper that delegates directly to the core `randomNormal$2` op.
 * @param shape Shape of the tensor.
 * @param mean mean value of the normal distribution. Default: 0.
 * @param stddev standard deviation of the normal distribution. Default: 1.
 * @param dtype Optional dtype of the result.
 * @param seed Optional seed for reproducible sampling.
 * @return The normal tensor.
 */
function randomNormal$1(shape, mean = 0.0, stddev = 1.0, dtype, seed) {
  return randomNormal$2(shape, mean, stddev, dtype, seed);
}
37019 /* Linear Algebra */
/**
 * Multiply two tensors and returns the result as a tensor.
 *
 * For 2D tensors, this is equivalent to matrix multiplication (matMul).
 * For tensors of higher ranks, it follows the Theano behavior,
 * (e.g. `(2, 3) * (4, 3, 5) -> (2, 4, 5)`). From the Theano documentation:
 *
 * For N dimensions it is a sum product over the last axis of x and the
 * second-to-last of y.
 *
 * @param a A tensor of at least rank 2.
 * @param b A tensor of at least rank 2.
 * @param activation (optional) A string identifying the activation
 *   function; unsupported activations behave as 'linear' (no-op).
 * @param bias (optional) Bias tensor added to the product before the
 *   activation is applied.
 * @return Result of the dot operation.
 * @throws NotImplementedError if either input has rank < 2, or when
 *   `b.rank >= 3` and the contraction dimensions do not match.
 */
function dot$1(a, b, activation, bias) {
    if ((a.rank < 2) || (b.rank < 2)) {
        throw new NotImplementedError(`dot requires both inputs to be rank >= 2` +
            ` but got x shape = ${a.shape} and y shape = ${b.shape}`);
    }
    // For higher-rank y, validate the contraction dims up front.
    if (b.rank >= 3) {
        const xLastDim = a.shape.slice(-1)[0];
        const ySecondLastDim = b.shape.slice(-2)[0];
        if (xLastDim !== ySecondLastDim) {
            throw new NotImplementedError(`If rank y >= 3, then the second last dim` +
                ` of y must equal the last dim of x but got x shape = ${a.shape} and ` +
                ` y shape = ${b.shape}`);
        }
    }
    // Handle basic 2D x 2D case.
    if ((a.rank === 2) && (b.rank === 2)) {
        const transposeA = false;
        const transposeB = false;
        // tfc.fused.matMul only fuses certain activation functions. Unsupported
        // activation functions are treated as 'linear' activations, which is
        // equivalent to a no-op.
        return matMul({
            a,
            b: b,
            transposeA,
            transposeB,
            bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null,
            activation
        });
    }
    else {
        // Reshape x into the analogous 2D Tensor.
        const aFirstDims = a.shape.slice(); // Holds all but the last dim of x.
        const aLastDim = aFirstDims.pop();
        a = reshape$3(a, [-1, aLastDim]);
        // Reshape y into the analogous 2D Tensor, and keep track of the
        // required dimensions to reproduce the output shape.
        const bShape = b.shape.slice();
        const bLastDim = bShape.pop();
        const ySecondLastDim = bShape.pop();
        const yOtherDims = [...bShape, bLastDim];
        // permutation should be like [r-2, 0, 1, 2, ... r-4, r-3, r-1]
        // where r is the rank of y.
        const perm = Array.from({ length: b.rank }, (_, i) => {
            if (i === 0) {
                return b.rank - 2;
            }
            else if (i <= b.rank - 2) {
                return i - 1;
            }
            return i;
        });
        // Move y's contraction axis to the front, then flatten the rest so a
        // single 2D matMul performs the generalized dot product.
        b = reshape$3(transpose$2(b, perm), [ySecondLastDim, -1]);
        // Multiply x and y as 2D Tensors, and then reshape back to original.
        const outputShape = [...aFirstDims, ...yOtherDims];
        const transposeA = false;
        const transposeB = false;
        return reshape$3(matMul({
            a,
            b,
            transposeA,
            transposeB,
            bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null,
            activation
        }), outputShape);
    }
}
/**
 * Compute the sign Tensor of an input Tensor.
 *
 * Elements of the input `tf.Tensor` that are === 0 are mapped to 0.
 * Elements of the input `tf.Tensor` that are > 0 are mapped to 1.
 * Elements of the input `tf.Tensor` that are < 0 are mapped to -1.
 *
 * @param x Input `tf.Tensor`.
 * @return The sign `tf.Tensor`.
 */
function sign$2(x) {
    // TODO(cais): Move to the core.
    return tidy(() => {
        const zeros = zerosLike$3(x);
        const ones = onesLike$3(x);
        const minusOnes = mul(-1, ones);
        // Positive entries become 1, negative become -1...
        const nonZeroSigns = where(greater$3(x, zerosLike$3(x)), ones, minusOnes);
        // ...and exact zeros are mapped back to 0.
        return where(equal$2(x, zeros), zeros, nonZeroSigns);
    });
}
/**
 * Computes the one-hot representation of an integer tensor.
 * @param indices nD integer tensor of shape
 *   `(batch_size, dim1, dim2, ... dim(n-1))`
 * @param numClasses Integer, number of classes to consider.
 * @returns (n + 1)D one hot representation of the input
 *   with shape `(batch_size, dim1, dim2, ... dim(n-1), num_classes)`
 * @throws Error if `indices` is not rank 1 (current backend limitation).
 */
function oneHot$2(indices, numClasses) {
    return tidy(() => {
        if (indices.rank !== 1) {
            throw new Error('Only 1D one-hot tensors are supported in the ' +
                'deeplearn backend, at present.');
        }
        // oneHot$3 requires int32 indices; the result is cast to float32.
        const intIndices = cast$3(indices, 'int32');
        return cast$3(oneHot$3(intIndices, numClasses), 'float32');
    });
}
37139 /* Elementary math functions. */
/**
 * Retrieves the elements of indices `indices` in the tensor `reference`.
 * @param reference A tensor.
 * @param indices An integer tensor of indices or an `Array` of integers.
 * @param axis Axis along which to perform the gather operation.
 * @returns The result of the gathering as a tensor.
 */
function gather(reference, indices, axis) {
    return tidy(() => {
        // Normalize plain arrays and non-int32 tensors to an int32 tensor.
        const indexTensor = Array.isArray(indices) ?
            tensor1d(indices, 'int32') :
            cast$3(indices, 'int32');
        return gather$1(reference, indexTensor, axis);
    });
}
/**
 * Element-wise square.
 *
 * Implemented as an element-wise multiply of `x` with itself.
 *
 * @param x Input tensor.
 * @return element-wise x^2
 */
function square$1(x) {
    return mul(x, x);
}
/**
 * Element-wise exponentiation.
 *
 * Porting Note: In PyKeras, `a` (the exponent) is a Python integer, which
 * takes advantage of the backend's (e.g., TensorFlow's) automatic
 * conversion to tensor. Here we allow `a` to be either a number or a tensor.
 *
 * @param x The base tensor.
 * @param a The exponent, tensor or number. If a number, it is rounded to the
 *   nearest integer and converted to a tensor.
 * @returns A tensor of the same shape as `x`.
 * @throws NotImplementedError if the exponent tensor is not int32.
 */
function pow$2(x, a) {
    return tidy(() => {
        // Avoid reassigning the parameter; build the exponent tensor locally.
        let exponent = a;
        if (typeof exponent === 'number') {
            exponent = scalar(Math.round(exponent), 'int32');
        }
        if (exponent.dtype !== 'int32') {
            throw new NotImplementedError(`Non-int32 dtype (${exponent.dtype}) is not supported by pow() yet`);
        }
        return pow$3(x, exponent);
    });
}
/**
 * Reshapes bias tensor according to rank of x.
 *
 * A rank-1 bias is expanded with singleton dimensions so that it broadcasts
 * along the channel axis implied by `dataFormat`; a bias whose rank already
 * equals `xRank` has its channel axis moved to the position `dataFormat`
 * requires, with a leading batch dimension of 1 prepended.
 *
 * @param xRank Rank of the input tensor the bias will be added to.
 * @param bias The bias tensor; must be rank 1 or rank `xRank`.
 * @param dataFormat 'channelsFirst' or 'channelsLast'.
 * @return The reshaped bias, or `bias` unchanged when `xRank < 3`.
 * @throws ValueError if the bias rank is invalid or `xRank` is unsupported.
 */
function reshapeBias(xRank, bias, dataFormat) {
    const biasShape = bias.shape;
    if (bias.rank !== 1 && bias.rank !== xRank) {
        throw new ValueError(`Unexpected bias dimensions: ${bias.rank}` +
            `; expected it to be 1 or ${xRank}`);
    }
    if (xRank === 5) {
        if (dataFormat === 'channelsFirst') {
            if (biasShape.length === 1) {
                return reshape$3(bias, [1, biasShape[0], 1, 1, 1]);
            }
            else {
                // Move the trailing channel axis to position 1.
                return reshape$3(bias, [1, biasShape[3], biasShape[0], biasShape[1], biasShape[2]]);
            }
        }
        else if (dataFormat === 'channelsLast') {
            if (biasShape.length === 1) {
                return reshape$3(bias, [1, 1, 1, 1, biasShape[0]]);
            }
            else {
                return reshape$3(bias, [1].concat(biasShape));
            }
        }
    }
    else if (xRank === 4) {
        if (dataFormat === 'channelsFirst') {
            if (biasShape.length === 1) {
                return reshape$3(bias, [1, biasShape[0], 1, 1]);
            }
            else {
                return reshape$3(bias, [1, biasShape[2], biasShape[0], biasShape[1]]);
            }
        }
        else if (dataFormat === 'channelsLast') {
            if (biasShape.length === 1) {
                return reshape$3(bias, [1, 1, 1, biasShape[0]]);
            }
            else {
                return reshape$3(bias, [1].concat(biasShape));
            }
        }
    }
    else if (xRank === 3) {
        if (dataFormat === 'channelsFirst') {
            if (biasShape.length === 1) {
                return reshape$3(bias, [1, biasShape[0], 1]);
            }
            else {
                return reshape$3(bias, [1, biasShape[1], biasShape[0]]);
            }
        }
        else if (dataFormat === 'channelsLast') {
            if (biasShape.length === 1) {
                return reshape$3(bias, [1, 1, biasShape[0]]);
            }
            else {
                return reshape$3(bias, [1].concat(biasShape));
            }
        }
    }
    else if (xRank < 3) {
        return bias;
    }
    // Bug fix: report the input rank (xRank), which is what is unsupported
    // here; the bias rank was already validated at the top of the function.
    throw new ValueError(`Unsupported input rank by biasAdd: ${xRank}`);
}
37257 /* Neural-network operations. */
/**
 * Add a bias to a tensor.
 *
 * @param x The tensor to add the bias to.
 * @param bias The bias to add to `x`. Must be 1D or the same rank as `x`.
 * @param dataFormat Optional data format; defaults to the global image
 *   data format when not provided.
 * @return Result of the bias adding.
 * @throws ValueError: If the rank of `bias` is incorrect.
 */
function biasAdd(x, bias, dataFormat) {
    return tidy(() => {
        const format = dataFormat == null ? imageDataFormat() : dataFormat;
        checkDataFormat(format);
        return add$3(x, reshapeBias(x.rank, bias, format));
    });
}
/**
 * Exponential linear unit (ELU).
 * @param x A tensor or variable to compute the activation function for.
 * @param alpha: A scalar, a scaling factor for the negative section.
 *   Currently only the value 1 is supported.
 * @return Output of the ELU operation.
 * @throws NotImplementedError if `alpha` is not 1.
 */
function elu$3(x, alpha = 1) {
    // TODO(cais): Add support for alpha values other than 1.
    if (alpha !== 1) {
        throw new NotImplementedError(`Support for alpha values other than 1 (${alpha}) is not implemented ` +
            `yet.`);
    }
    return elu$4(x);
}
/**
 * Softsign of a tensor.
 *
 * Defined as x / (abs(x) + 1), element-wise.
 *
 * @param x: Input.
 * @returns Output.
 */
function softsign(x) {
    return tidy(() => {
        const denominator = add$3(abs$2(x), 1);
        return div$1(x, denominator);
    });
}
/**
 * Sets entries in `x` to zero at random, while scaling the entire tensor.
 *
 * Thin wrapper delegating to the core dropout op inside a tidy scope.
 *
 * @param x input tensor.
 * @param level fraction of the entries in the tensor that will be set to 0.
 * @param noiseShape shape of randomly generated keep/drop flags, must be
 *   broadcastable to the shape of `x`. Optional.
 * @param seed random seed to ensure determinism. Optional.
 * @returns Result of the dropout operation.
 */
function dropout$1(x, level, noiseShape, seed) {
    return tidy(() => dropout$2(x, level, noiseShape, seed));
}
/**
 * Element-wise, segment-wise linear approximation of sigmoid.
 *
 * Returns `0.` if `x < -2.5`, `1.` if `x > 2.5`.
 * In `-2.5 <= x <= 2.5`, returns `0.2 * x + 0.5`.
 *
 * @param x Input tensor.
 * @returns Output tensor.
 */
function hardSigmoid(x) {
    // Compute 0.2 * x + 0.5, then clamp the result into [0, 1].
    return tidy(() => clipByValue$2(add$3(.5, mul(.2, x)), 0, 1));
}
/**
 * Invoke `x` in the training phase, and `alt` otherwise.
 *
 * Porting Note: We do not create placeholder tensors for the `training`
 * boolean flag here, because there is no such thing in the TF.js imperative
 * backend.
 *
 * @param x The function to invoke iff `training` is `true`.
 * @param alt The function to invoke iff `training` is `false`.
 * @param training Boolean flag for whether training phase is active.
 * @returns The return value of `x()` if `training` is `true`, or the return
 *   value of `alt()` if `training` is `false`.
 */
function inTrainPhase(x, alt, training = false) {
    if (training) {
        return x();
    }
    return alt();
}
37344
37345 /**
37346 * @license
37347 * Copyright 2018 Google LLC
37348 *
37349 * Use of this source code is governed by an MIT-style
37350 * license that can be found in the LICENSE file or at
37351 * https://opensource.org/licenses/MIT.
37352 * =============================================================================
37353 */
/** Valid string values for an initializer's fan `mode` field. */
const VALID_FAN_MODE_VALUES = ['fanIn', 'fanOut', 'fanAvg'];
/** Valid string values for an initializer's `distribution` field. */
const VALID_DISTRIBUTION_VALUES = ['normal', 'uniform', 'truncatedNormal'];
// We can't easily extract a string[] from the string union type, but we can
// recapitulate the list, enforcing at compile time that the values are valid
// and that we have the right number of them.
/**
 * A string array of valid Initializer class names.
 *
 * This is guaranteed to match the `InitializerClassName` union type.
 */
const initializerClassNames = [
    'Zeros', 'Ones', 'Constant', 'RandomNormal', 'RandomUniform',
    'TruncatedNormal', 'VarianceScaling', 'Orthogonal', 'Identity'
];
37368
37369 /**
37370 * @license
37371 * Copyright 2018 Google LLC
37372 *
37373 * Use of this source code is governed by an MIT-style
37374 * license that can be found in the LICENSE file or at
37375 * https://opensource.org/licenses/MIT.
37376 * =============================================================================
37377 */
/**
 * Validates that `value` is one of VALID_FAN_MODE_VALUES; throws otherwise.
 * @param value Candidate FanMode string.
 */
function checkFanMode(value) {
    checkStringTypeUnionValue(VALID_FAN_MODE_VALUES, 'FanMode', value);
}
/**
 * Validates that `value` is one of VALID_DISTRIBUTION_VALUES; throws otherwise.
 * @param value Candidate Distribution string.
 */
function checkDistribution(value) {
    checkStringTypeUnionValue(VALID_DISTRIBUTION_VALUES, 'Distribution', value);
}
/**
 * Initializer base class.
 *
 * @doc {
 *   heading: 'Initializers', subheading: 'Classes', namespace: 'initializers'}
 */
class Initializer extends Serializable {
    // Subclasses that need custom objects during fromConfig override this.
    fromConfigUsesCustomObjects() {
        return false;
    }
    // Base initializers have no configuration; subclasses override.
    getConfig() {
        return {};
    }
}
/** Initializer that generates tensors filled with zeros. */
class Zeros extends Initializer {
    apply(shape, dtype) {
        return zeros$2(shape, dtype);
    }
}
/** @nocollapse */
Zeros.className = 'Zeros';
registerClass(Zeros);
/** Initializer that generates tensors filled with ones. */
class Ones extends Initializer {
    apply(shape, dtype) {
        return ones$1(shape, dtype);
    }
}
/** @nocollapse */
Ones.className = 'Ones';
registerClass(Ones);
/** Initializer that generates tensors filled with a single constant value. */
class Constant extends Initializer {
    /**
     * @param args Config object; must be an object with a `value` field.
     * @throws ValueError if `args` is not an object or lacks `value`.
     */
    constructor(args) {
        super();
        if (typeof args !== 'object') {
            throw new ValueError(`Expected argument of type ConstantConfig but got ${args}`);
        }
        if (args.value === undefined) {
            throw new ValueError(`config must have value set but got ${args}`);
        }
        this.value = args.value;
    }
    // Fills the requested shape by scaling a ones tensor by `value`.
    apply(shape, dtype) {
        return tidy(() => mul(scalar(this.value), ones$1(shape, dtype)));
    }
    getConfig() {
        return {
            value: this.value,
        };
    }
}
/** @nocollapse */
Constant.className = 'Constant';
registerClass(Constant);
/**
 * Initializer that generates tensors with a uniform distribution over
 * [minval, maxval).
 */
class RandomUniform extends Initializer {
    /**
     * @param args Config with optional `minval`, `maxval` and `seed`.
     */
    constructor(args) {
        super();
        this.DEFAULT_MINVAL = -0.05;
        this.DEFAULT_MAXVAL = 0.05;
        // Bug fix: use explicit null checks instead of `||` so that an
        // explicit bound of 0 is honored rather than silently replaced by
        // the default. This matches the null-check convention used by the
        // other initializers in this file (e.g. VarianceScaling, Orthogonal).
        this.minval = args.minval == null ? this.DEFAULT_MINVAL : args.minval;
        this.maxval = args.maxval == null ? this.DEFAULT_MAXVAL : args.maxval;
        this.seed = args.seed;
    }
    apply(shape, dtype) {
        return randomUniform$1(shape, this.minval, this.maxval, dtype, this.seed);
    }
    getConfig() {
        return { minval: this.minval, maxval: this.maxval, seed: this.seed };
    }
}
/** @nocollapse */
RandomUniform.className = 'RandomUniform';
registerClass(RandomUniform);
/**
 * Initializer that generates tensors with a normal distribution.
 */
class RandomNormal extends Initializer {
    /**
     * @param args Config with optional `mean`, `stddev` and `seed`.
     */
    constructor(args) {
        super();
        this.DEFAULT_MEAN = 0.;
        this.DEFAULT_STDDEV = 0.05;
        // Bug fix: use explicit null checks instead of `||` so that an
        // explicit value of 0 (e.g. stddev: 0) is honored rather than
        // silently replaced by the default.
        this.mean = args.mean == null ? this.DEFAULT_MEAN : args.mean;
        this.stddev = args.stddev == null ? this.DEFAULT_STDDEV : args.stddev;
        this.seed = args.seed;
    }
    /**
     * @throws NotImplementedError for dtypes other than float32/int32.
     */
    apply(shape, dtype) {
        dtype = dtype || 'float32';
        if (dtype !== 'float32' && dtype !== 'int32') {
            throw new NotImplementedError(`randomNormal does not support dType ${dtype}.`);
        }
        return randomNormal$1(shape, this.mean, this.stddev, dtype, this.seed);
    }
    getConfig() {
        return { mean: this.mean, stddev: this.stddev, seed: this.seed };
    }
}
/** @nocollapse */
RandomNormal.className = 'RandomNormal';
registerClass(RandomNormal);
/**
 * Initializer that generates tensors with a truncated normal distribution.
 */
class TruncatedNormal extends Initializer {
    /**
     * @param args Config with optional `mean`, `stddev` and `seed`.
     */
    constructor(args) {
        super();
        this.DEFAULT_MEAN = 0.;
        this.DEFAULT_STDDEV = 0.05;
        // Bug fix: use explicit null checks instead of `||` so that an
        // explicit value of 0 (e.g. stddev: 0) is honored rather than
        // silently replaced by the default.
        this.mean = args.mean == null ? this.DEFAULT_MEAN : args.mean;
        this.stddev = args.stddev == null ? this.DEFAULT_STDDEV : args.stddev;
        this.seed = args.seed;
    }
    /**
     * @throws NotImplementedError for dtypes other than float32/int32.
     */
    apply(shape, dtype) {
        dtype = dtype || 'float32';
        if (dtype !== 'float32' && dtype !== 'int32') {
            throw new NotImplementedError(`truncatedNormal does not support dType ${dtype}.`);
        }
        return truncatedNormal$1(shape, this.mean, this.stddev, dtype, this.seed);
    }
    getConfig() {
        return { mean: this.mean, stddev: this.stddev, seed: this.seed };
    }
}
/** @nocollapse */
TruncatedNormal.className = 'TruncatedNormal';
registerClass(TruncatedNormal);
/**
 * Initializer that generates a gain-scaled identity matrix.
 * Only usable for 2D square matrices.
 */
class Identity extends Initializer {
    /** @param args Config with optional `gain` (default 1.0). */
    constructor(args) {
        super();
        this.gain = args.gain != null ? args.gain : 1.0;
    }
    /**
     * @throws ValueError if `shape` is not a 2D square matrix shape.
     */
    apply(shape, dtype) {
        return tidy(() => {
            if (shape.length !== 2 || shape[0] !== shape[1]) {
                throw new ValueError('Identity matrix initializer can only be used for' +
                    ' 2D square matrices.');
            }
            return mul(this.gain, eye(shape[0]));
        });
    }
    getConfig() {
        return { gain: this.gain };
    }
}
/** @nocollapse */
Identity.className = 'Identity';
registerClass(Identity);
/**
 * Computes the number of input and output units for a weight shape.
 * @param shape Shape of weight.
 * @param dataFormat data format to use for convolution kernels.
 *   Note that all kernels in Keras are standardized on the
 *   CHANNEL_LAST ordering (even when inputs are set to CHANNEL_FIRST).
 * @return An length-2 array: fanIn, fanOut.
 */
function computeFans(shape, dataFormat = 'channelsLast') {
    checkDataFormat(dataFormat);
    // Dense kernels: [fanIn, fanOut] directly.
    if (shape.length === 2) {
        return [shape[0], shape[1]];
    }
    // Conv kernels (1D/2D/3D): multiply channel dims by the receptive field.
    if (shape.length === 3 || shape.length === 4 || shape.length === 5) {
        if (dataFormat === 'channelsFirst') {
            const receptiveFieldSize = arrayProd(shape, 2);
            return [shape[1] * receptiveFieldSize, shape[0] * receptiveFieldSize];
        }
        else if (dataFormat === 'channelsLast') {
            const receptiveFieldSize = arrayProd(shape, 0, shape.length - 2);
            return [
                shape[shape.length - 2] * receptiveFieldSize,
                shape[shape.length - 1] * receptiveFieldSize
            ];
        }
        // Unreachable after checkDataFormat; preserved for exact parity with
        // the original behavior (undefined fans).
        return [undefined, undefined];
    }
    // Fallback: use sqrt of the total element count for both fans.
    const shapeProd = arrayProd(shape);
    return [Math.sqrt(shapeProd), Math.sqrt(shapeProd)];
}
/**
 * Initializer capable of adapting its scale to the shape of weights.
 * The variance of the generated values is scaled by `scale` divided by the
 * fan value selected via `mode`, and samples are drawn from either a
 * truncated-normal or a uniform distribution.
 */
class VarianceScaling extends Initializer {
    /**
     * Constructor of VarianceScaling.
     * @throws ValueError for invalid value in scale.
     */
    constructor(args) {
        super();
        if (args.scale < 0.0) {
            throw new ValueError(`scale must be a positive float. Got: ${args.scale}`);
        }
        this.scale = args.scale == null ? 1.0 : args.scale;
        this.mode = args.mode == null ? 'fanIn' : args.mode;
        checkFanMode(this.mode);
        this.distribution =
            args.distribution == null ? 'normal' : args.distribution;
        checkDistribution(this.distribution);
        this.seed = args.seed;
    }
    apply(shape, dtype) {
        const [fanIn, fanOut] = computeFans(shape);
        let scale = this.scale;
        // Divide the scale by the fan value selected by `mode`.
        switch (this.mode) {
            case 'fanIn':
                scale /= Math.max(1, fanIn);
                break;
            case 'fanOut':
                scale /= Math.max(1, fanOut);
                break;
            default: // 'fanAvg'
                scale /= Math.max(1, (fanIn + fanOut) / 2);
        }
        if (this.distribution === 'normal') {
            const stddev = Math.sqrt(scale);
            dtype = dtype || 'float32';
            if (dtype !== 'float32' && dtype !== 'int32') {
                throw new NotImplementedError(`${this.getClassName()} does not support dType ${dtype}.`);
            }
            return truncatedNormal$1(shape, 0, stddev, dtype, this.seed);
        }
        // 'uniform' / 'truncatedNormal' (non-'normal') path.
        const limit = Math.sqrt(3 * scale);
        return randomUniform$1(shape, -limit, limit, dtype, this.seed);
    }
    getConfig() {
        return {
            scale: this.scale,
            mode: this.mode,
            distribution: this.distribution,
            seed: this.seed
        };
    }
}
/** @nocollapse */
VarianceScaling.className = 'VarianceScaling';
registerClass(VarianceScaling);
/**
 * Glorot (Xavier) uniform initializer: a VarianceScaling with scale=1.0,
 * mode='fanAvg' and a uniform distribution.
 */
class GlorotUniform extends VarianceScaling {
    /**
     * Constructor of GlorotUniform
     * @param args Optional config; only `args.seed` is used.
     */
    constructor(args) {
        super({
            scale: 1.0,
            mode: 'fanAvg',
            distribution: 'uniform',
            seed: args == null ? null : args.seed
        });
    }
    getClassName() {
        // In Python Keras, GlorotUniform is not a class, but a helper method
        // that creates a VarianceScaling object. Use 'VarianceScaling' as
        // class name to be compatible with that.
        return VarianceScaling.className;
    }
}
/** @nocollapse */
GlorotUniform.className = 'GlorotUniform';
registerClass(GlorotUniform);
/**
 * Glorot (Xavier) normal initializer: a VarianceScaling with scale=1.0,
 * mode='fanAvg' and a normal distribution.
 */
class GlorotNormal extends VarianceScaling {
    /**
     * Constructor of GlorotNormal.
     * @param args Optional config; only `args.seed` is used.
     */
    constructor(args) {
        super({
            scale: 1.0,
            mode: 'fanAvg',
            distribution: 'normal',
            seed: args == null ? null : args.seed
        });
    }
    getClassName() {
        // In Python Keras, GlorotNormal is not a class, but a helper method
        // that creates a VarianceScaling object. Use 'VarianceScaling' as
        // class name to be compatible with that.
        return VarianceScaling.className;
    }
}
/** @nocollapse */
GlorotNormal.className = 'GlorotNormal';
registerClass(GlorotNormal);
/**
 * He normal initializer: a VarianceScaling with scale=2.0, mode='fanIn'
 * and a normal distribution.
 */
class HeNormal extends VarianceScaling {
    /** @param args Optional config; only `args.seed` is used. */
    constructor(args) {
        super({
            scale: 2.0,
            mode: 'fanIn',
            distribution: 'normal',
            seed: args == null ? null : args.seed
        });
    }
    getClassName() {
        // In Python Keras, HeNormal is not a class, but a helper method
        // that creates a VarianceScaling object. Use 'VarianceScaling' as
        // class name to be compatible with that.
        return VarianceScaling.className;
    }
}
/** @nocollapse */
HeNormal.className = 'HeNormal';
registerClass(HeNormal);
/**
 * He uniform initializer: a VarianceScaling with scale=2.0, mode='fanIn'
 * and a uniform distribution.
 */
class HeUniform extends VarianceScaling {
    /** @param args Optional config; only `args.seed` is used. */
    constructor(args) {
        super({
            scale: 2.0,
            mode: 'fanIn',
            distribution: 'uniform',
            seed: args == null ? null : args.seed
        });
    }
    getClassName() {
        // In Python Keras, HeUniform is not a class, but a helper method
        // that creates a VarianceScaling object. Use 'VarianceScaling' as
        // class name to be compatible with that.
        return VarianceScaling.className;
    }
}
/** @nocollapse */
HeUniform.className = 'HeUniform';
registerClass(HeUniform);
/**
 * LeCun normal initializer: a VarianceScaling with scale=1.0, mode='fanIn'
 * and a normal distribution.
 */
class LeCunNormal extends VarianceScaling {
    /** @param args Optional config; only `args.seed` is used. */
    constructor(args) {
        super({
            scale: 1.0,
            mode: 'fanIn',
            distribution: 'normal',
            seed: args == null ? null : args.seed
        });
    }
    getClassName() {
        // In Python Keras, LeCunNormal is not a class, but a helper method
        // that creates a VarianceScaling object. Use 'VarianceScaling' as
        // class name to be compatible with that.
        return VarianceScaling.className;
    }
}
/** @nocollapse */
LeCunNormal.className = 'LeCunNormal';
registerClass(LeCunNormal);
/**
 * LeCun uniform initializer: a VarianceScaling with scale=1.0, mode='fanIn'
 * and a uniform distribution.
 */
class LeCunUniform extends VarianceScaling {
    /** @param args Optional config; only `args.seed` is used. */
    constructor(args) {
        super({
            scale: 1.0,
            mode: 'fanIn',
            distribution: 'uniform',
            seed: args == null ? null : args.seed
        });
    }
    getClassName() {
        // In Python Keras, LeCunUniform is not a class, but a helper method
        // that creates a VarianceScaling object. Use 'VarianceScaling' as
        // class name to be compatible with that.
        return VarianceScaling.className;
    }
}
/** @nocollapse */
LeCunUniform.className = 'LeCunUniform';
registerClass(LeCunUniform);
/**
 * Initializer that generates a random orthogonal matrix via QR
 * decomposition of a random-normal matrix.
 */
class Orthogonal extends Initializer {
    /**
     * @param args Config with optional `gain` (default 1) and `seed`.
     */
    constructor(args) {
        super();
        // Multiplicative gain applied to the orthogonal matrix.
        this.DEFAULT_GAIN = 1;
        // Element-count threshold above which a slowness warning is logged.
        this.ELEMENTS_WARN_SLOW = 2000;
        this.gain = args.gain == null ? this.DEFAULT_GAIN : args.gain;
        this.seed = args.seed;
    }
    /**
     * @throws NotImplementedError for shapes of rank < 2; TypeError for
     *   dtypes other than float32/int32/undefined.
     */
    apply(shape, dtype) {
        return tidy(() => {
            if (shape.length < 2) {
                throw new NotImplementedError('Shape must be at least 2D.');
            }
            if (dtype !== 'int32' && dtype !== 'float32' && dtype !== undefined) {
                throw new TypeError(`Unsupported data type ${dtype}.`);
            }
            dtype = dtype; // No-op; retained from a compile-time type cast.
            // flatten the input shape with the last dimension remaining its
            // original shape so it works for conv2d
            const numRows = sizeFromShape(shape.slice(0, -1));
            const numCols = shape[shape.length - 1];
            const numElements = numRows * numCols;
            if (numElements > this.ELEMENTS_WARN_SLOW) {
                console.warn(`Orthogonal initializer is being called on a matrix with more ` +
                    `than ${this.ELEMENTS_WARN_SLOW} (${numElements}) elements: ` +
                    `Slowness may result.`);
            }
            // Work on the tall orientation (rows >= cols) and transpose back
            // afterwards if needed.
            const flatShape = [Math.max(numCols, numRows), Math.min(numCols, numRows)];
            // Generate a random matrix
            const randNormalMat = randomNormal$1(flatShape, 0, 1, dtype, this.seed);
            // Compute QR factorization
            const qr = linalg.qr(randNormalMat, false);
            let qMat = qr[0];
            const rMat = qr[1];
            // Make Q uniform: multiply each column of Q by the sign of the
            // corresponding diagonal entry of R (extracted via strided slice
            // of the flattened R).
            const diag = rMat.flatten().stridedSlice([0], [Math.min(numCols, numRows) * Math.min(numCols, numRows)], [Math.min(numCols, numRows) + 1]);
            qMat = mul(qMat, diag.sign());
            if (numRows < numCols) {
                qMat = qMat.transpose();
            }
            return mul(scalar(this.gain), qMat.reshape(shape));
        });
    }
    getConfig() {
        return {
            gain: this.gain,
            seed: this.seed,
        };
    }
}
/** @nocollapse */
Orthogonal.className = 'Orthogonal';
registerClass(Orthogonal);
// Maps the JavaScript-like identifier keys to the corresponding registry
// symbols (registered Initializer class names) used during deserialization.
const INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
    'constant': 'Constant',
    'glorotNormal': 'GlorotNormal',
    'glorotUniform': 'GlorotUniform',
    'heNormal': 'HeNormal',
    'heUniform': 'HeUniform',
    'identity': 'Identity',
    'leCunNormal': 'LeCunNormal',
    'leCunUniform': 'LeCunUniform',
    'ones': 'Ones',
    'orthogonal': 'Orthogonal',
    'randomNormal': 'RandomNormal',
    'randomUniform': 'RandomUniform',
    'truncatedNormal': 'TruncatedNormal',
    'varianceScaling': 'VarianceScaling',
    'zeros': 'Zeros'
};
/**
 * Deserializes an Initializer from a serialized Keras-object config.
 * @param config Serialized config (className + config fields).
 * @param customObjects Optional map of custom class names to constructors.
 */
function deserializeInitializer(config, customObjects = {}) {
    return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'initializer');
}
/**
 * Serializes an Initializer instance into a Keras-object config.
 * @param initializer The Initializer to serialize.
 */
function serializeInitializer(initializer) {
    return serializeKerasObject(initializer);
}
/**
 * Resolves an initializer from a string identifier, an Initializer
 * instance, or a serialized config object.
 */
function getInitializer(identifier) {
    if (typeof identifier === 'string') {
        const className = identifier in INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP ?
            INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] :
            identifier;
        /* We have 'helper' classes for common initializers that
           all get serialized as 'VarianceScaling' and shouldn't go through
           the deserializeInitializer pathway. */
        switch (className) {
            case 'GlorotNormal':
                return new GlorotNormal();
            case 'GlorotUniform':
                return new GlorotUniform();
            case 'HeNormal':
                return new HeNormal();
            case 'HeUniform':
                return new HeUniform();
            case 'LeCunNormal':
                return new LeCunNormal();
            case 'LeCunUniform':
                return new LeCunUniform();
            default: {
                const config = { className, config: {} };
                return deserializeInitializer(config);
            }
        }
    }
    if (identifier instanceof Initializer) {
        return identifier;
    }
    return deserializeInitializer(identifier);
}
37863
37864 /**
37865 * @license
37866 * Copyright 2018 Google LLC
37867 *
37868 * Use of this source code is governed by an MIT-style
37869 * license that can be found in the LICENSE file or at
37870 * https://opensource.org/licenses/MIT.
37871 * =============================================================================
37872 */
37873 // tslint:enable
/**
 * Determine whether the input is an Array of Shapes (i.e. an array whose
 * first element is itself an array).
 */
function isArrayOfShapes(x) {
    if (!Array.isArray(x)) {
        return false;
    }
    return Array.isArray(x[0]);
}
/**
 * Special case of normalizing shapes to lists.
 *
 * @param x A shape or list of shapes to normalize into a list of Shapes.
 * @return A list of Shapes.
 */
function normalizeShapeList(x) {
    if (x.length === 0) {
        return [];
    }
    // A single shape becomes a one-element list; a shape list passes through.
    return Array.isArray(x[0]) ? x : [x];
}
/**
 * Helper function to obtain exactly one Tensor.
 * @param xs: A single `tf.Tensor` or an `Array` of `tf.Tensor`s.
 * @return A single `tf.Tensor`. If `xs` is an `Array`, return the first one.
 * @throws ValueError: If `xs` is an `Array` and its length is not 1.
 */
function getExactlyOneTensor(xs) {
    if (!Array.isArray(xs)) {
        return xs;
    }
    if (xs.length !== 1) {
        throw new ValueError(`Expected Tensor length to be 1; got ${xs.length}`);
    }
    return xs[0];
}
/**
 * Helper function to obtain exactly one instance of Shape.
 *
 * @param shapes Input single `Shape` or Array of `Shape`s.
 * @returns If input is a single `Shape`, return it unchanged. If the input is
 *   an `Array` containing exactly one instance of `Shape`, return the
 *   instance. Otherwise, throw a `ValueError`.
 * @throws ValueError: If input is an `Array` of `Shape`s, and its length is
 *   not 1.
 */
function getExactlyOneShape(shapes) {
    const isShapeList = Array.isArray(shapes) && Array.isArray(shapes[0]);
    if (!isShapeList) {
        return shapes;
    }
    if (shapes.length !== 1) {
        throw new ValueError(`Expected exactly 1 Shape; got ${shapes.length}`);
    }
    return shapes[0];
}
37938
37939 /**
37940 * @license
37941 * Copyright 2018 Google LLC
37942 *
37943 * Use of this source code is governed by an MIT-style
37944 * license that can be found in the LICENSE file or at
37945 * https://opensource.org/licenses/MIT.
37946 * =============================================================================
37947 */
/**
 * Count the elements in an Array of LayerVariables.
 *
 * @param weights: The LayerVariables of which the constituent numbers are to
 *   be counted.
 * @returns A count of the elements in all the LayerVariables
 */
function countParamsInWeights(weights) {
    return weights.reduce((total, weight) => {
        // A scalar (empty shape) contributes a single element; otherwise the
        // contribution is the product of the shape's dimensions.
        const elementCount = weight.shape.length === 0 ?
            1 :
            weight.shape.reduce((a, b) => a * b);
        return total + elementCount;
    }, 0);
}
37967
37968 /**
37969 * @license
37970 * Copyright 2018 Google LLC
37971 *
37972 * Use of this source code is governed by an MIT-style
37973 * license that can be found in the LICENSE file or at
37974 * https://opensource.org/licenses/MIT.
37975 * =============================================================================
37976 */
// Name prefix used for LayerVariables constructed without an explicit name.
const DEFAULT_VARIABLE_NAME_PREFIX = 'Variable';
37978 /**
37979 * A `tf.layers.LayerVariable` is similar to a `tf.Tensor` in that it has a
37980 * dtype and shape, but its value is mutable. The value is itself represented
37981 * as a`tf.Tensor`, and can be read with the `read()` method and updated with
37982 * the `write()` method.
37983 */
class LayerVariable {
    /**
     * Construct Variable from a `tf.Tensor`.
     *
     * If not explicitly named, the Variable will be given a name with the
     * prefix 'Variable'. Variable names are unique. In the case of name
     * collision, suffixies '_<num>' will be added to the name.
     *
     * @param val Initial value of the Variable.
     * @param dtype Data type of the Variable; defaults to `'float32'`
     *   (also used when `null`/`undefined` is passed).
     * @param name Name of the variable. If `null` or `undefined` is provided, it
     *   will default a name with the prefix 'Variable'.
     * @param trainable Whether the underlying `tf.Variable` is trainable;
     *   defaults to `true`.
     * @param constraint Optional, projection function to be applied to the
     *   variable after optimize updates
     * @throws ValueError if `name` is `null` or `undefined`.
     */
    constructor(val, dtype = 'float32', name = DEFAULT_VARIABLE_NAME_PREFIX, trainable = true, constraint = null) {
        this.dtype = dtype == null ? 'float32' : dtype;
        this.shape = val.shape;
        this.id = getNextUniqueTensorId();
        // A null/undefined name falls back to the default prefix, then gets
        // scoped and uniquified (collision => '_<num>' suffix).
        name = name == null ? DEFAULT_VARIABLE_NAME_PREFIX : name;
        this.originalName = getScopedTensorName(name);
        this.name = getUniqueTensorName(this.originalName);
        this.trainable_ = trainable;
        this.constraint = constraint;
        // The mutable backing store: a core tf.Variable.
        this.val = variable$1(val, this.trainable_, this.name, this.dtype);
    }
    /**
     * Get a snapshot of the Variable's value.
     *
     * The returned value is a snapshot of the Variable's value at the time of
     * the invocation. Future mutations in the value of the tensor will only
     * be reflected by future calls to this method.
     */
    read() {
        this.assertNotDisposed();
        return this.val;
    }
    /**
     * Update the value of the Variable.
     *
     * If a `constraint` was supplied at construction, it is applied to the
     * value after the assignment.
     *
     * @param newVal: The new value to update to. Must be consistent with the
     *   dtype and shape of the Variable.
     * @return This Variable.
     */
    write(newVal) {
        // TODO(cais): Once TF.js Core supports Tensor.dtype, check dtype match.
        this.assertNotDisposed();
        checkShapesMatch(this.val, newVal);
        // Skip updating if this is the exact same tensor.
        if (this.val.id !== newVal.id) {
            this.val.assign(newVal);
            // Project the freshly-assigned value through the constraint.
            if (this.constraint != null) {
                this.val.assign(this.constraint.apply(this.val));
            }
        }
        return this;
    }
    /**
     * Dispose this LayersVariable instance from memory.
     */
    dispose() {
        this.assertNotDisposed();
        this.val.dispose();
    }
    // Throws if the backing tf.Variable has already been disposed.
    assertNotDisposed() {
        if (this.val.isDisposed) {
            throw new Error(`LayersVariable ${this.name} is already disposed.`);
        }
    }
    // Whether this variable participates in gradient updates.
    get trainable() {
        return this.trainable_;
    }
    // Also propagates the flag to the backing tf.Variable.
    set trainable(trainable) {
        this.trainable_ = trainable;
        this.val.trainable = trainable;
    }
}
function checkShapesMatch(x, y) {
    // Shapes are compared via their string forms; any mismatch is fatal.
    const xShape = x.shape;
    const yShape = y.shape;
    if (xShape.toString() !== yShape.toString()) {
        throw new Error('Shape mismatch: ' + JSON.stringify(xShape) + ' vs. ' +
            JSON.stringify(yShape));
    }
}
38067 /**
38068 * Create a Variable.
38069 * @param x The initial value of the `Variable`.
38070 * @param dtype optional, the type of the variable.
38071 * @param name optional, the name of the variable, default provided by
38072 * Variable.
38073 * @param constraint optional, a constraint to be applied after every update.
38074 * @return The newly instantiated `Variable`.
38075 */
function variable(x, dtype, name, constraint) {
    // Always constructs a trainable LayerVariable.
    const trainable = true;
    return new LayerVariable(x, dtype, name, trainable, constraint);
}
38079 /**
38080 * Instantiates an all-zeros Variable and returns it.
38081 *
38082 * @param shape Shape of the tensor.
38083 * @param dtype DType of the tensor.
38084 * @param name Name of the tensor.
38085 * @return An all-zero Variable.
38086 */
function zerosVariable(shape, dtype, name) {
    // TODO(cais): Implement logic for dtype.
    const allocated = zeros$2(shape);
    return new LayerVariable(allocated, dtype, name);
}
38091 /**
38092 * Instantiates an all-zeros tensor of the same shape as another tensor.
38093 *
38094 * @param x The other tensor.
38095 * @param dtype DType of the tensor.
38096 * @param name Name of the tensor.
38097 * @return A newly instantiated Variable.
38098 */
function zerosLike$2(x, dtype, name) {
    // Allocate an all-zeros tensor shaped like `x`, then wrap it.
    const allocated = zerosLike$3(x);
    return new LayerVariable(allocated, dtype, name);
}
38102 /**
38103 * Instantiates an all-ones tensor and returns it.
38104 *
38105 * @param shape Shape of the tensor.
38106 * @param dtype DType of the tensor.
38107 * @param name Name of the tensor.
38108 * @return An all-ones Variable.
38109 */
function onesVariable(shape, dtype, name) {
    // TODO(cais): Implement logic for dtype.
    return new LayerVariable(ones$1(shape), dtype, name);
}
38115 /**
38116 * Instantiates an all-ones tensor of the same shape as another tensor.
38117 *
38118 * @param x The other tensor.
38119 * @param dtype DType of the tensor.
38120 * @param name Name of the tensor.
38121 * @return A newly instantiated Variable.
38122 */
function onesLike$2(x, dtype, name) {
    // Allocate an all-ones tensor shaped like `x`, then wrap it.
    return new LayerVariable(onesLike$3(x), dtype, name);
}
38127 /**
38128 * Instantiate an identity matrix and returns it, as a Variable
38129 *
38130 * @param size Number of rows/columns.
38131 * @param dtype Data type of returned Variable.
38132 * @param name Name of returned Variable.
38133 * @return A Variable, an identity matrix.
38134 */
function eyeVariable(size, dtype, name) {
    // Wrap a size x size identity matrix in a LayerVariable.
    const identity = eye(size);
    return new LayerVariable(identity, dtype, name);
}
38138 /**
38139 * Get a Variable with uniform distribution of values.
38140 * @param shape Shape of the tensor.
38141 * @param minval Lower bound of the uniform distribution.
38142 * @param maxval Upper bound of the uniform distribution.
38143 * @param dtype
38144 * @param seed
38145 * @param name Optional name.
38146 * @return The uniform-random Variable.
38147 */
function randomUniformVariable(shape, minval, maxval, dtype, seed, name = 'randomUniform') {
    // NOTE(review): `seed` is accepted but not forwarded to randomUniform$1
    // in this implementation.
    const initialValue = randomUniform$1(shape, minval, maxval, dtype);
    return new LayerVariable(initialValue, dtype, name);
}
38151 /**
38152 * Get a Variable with truncated-normal distribution of values.
38153 * @param shape Shape of the tensor.
38154 * @param mean mean value of the normal distribution.
38155 * @param stddev standard deviation of the normal distribution.
38156 * @param dtype
38157 * @param seed
38158 * @param name Optional name.
38159 * @return The truncated-normal-random Variable.
38160 */
function truncatedNormalVariable(shape, mean = 0.0, stddev = 1.0, dtype, seed, name = 'truncatedNormal') {
    // TODO(cais): Implement logic for dtype and seed once they are supported
    // by deeplearn.js.
    dtype = dtype || 'float32';
    if (dtype !== 'float32' && dtype !== 'int32') {
        // Fixed: the message previously referred to `randomNormal`, which is
        // misleading for this truncated-normal factory.
        throw new NotImplementedError(`truncatedNormal does not support dType ${dtype}.`);
    }
    return new LayerVariable(truncatedNormal$1(shape, mean, stddev, dtype, seed), dtype, name);
}
38170 /**
38171 * Get a Variable with normal distribution of values.
38172 * @param shape Shape of the tensor.
38173 * @param mean mean value of the normal distribution.
38174 * @param stddev standard deviation of the normal distribution.
38175 * @param dtype
38176 * @param seed
38177 * @param name Optional name.
38178 * @return The truncated-normal-random Variable.
38179 */
function randomNormalVariable(shape, mean = 0.0, stddev = 1.0, dtype, seed, name = 'randomNormal') {
    dtype = dtype || 'float32';
    // Only float32 and int32 are supported dtypes.
    const isSupported = dtype === 'float32' || dtype === 'int32';
    if (!isSupported) {
        throw new NotImplementedError(`randomNormalVariable does not support dType ${dtype}.`);
    }
    const initialValue = randomNormal$2(shape, mean, stddev, dtype, seed);
    return new LayerVariable(initialValue, dtype, name);
}
38187 /**
38188 * Update the value of a Variable.
38189 * @param x The Variable to be updated.
38190 * @param xNew The new value to update to.
38191 * @return The Variable updated.
38192 */
function update(x, xNew) {
    // Delegates to the Variable's own write() method; returns the Variable.
    const updated = x.write(xNew);
    return updated;
}
38196 /**
38197 * Update the value of a Variable by adding an increment.
38198 * @param x The Variable to be updated.
38199 * @param increment The incrment to add to `x`.
38200 * @return The Variable updated.
38201 */
function updateAdd(x, increment) {
    // Read the current value, add the increment, and write the sum back.
    const newValue = add$3(x.read(), increment);
    return x.write(newValue);
}
38205 /**
38206 * Update the value of a Variable by subtracting a decrement.
38207 * @param x The Variable to be updated.
38208 * @param decrement The decrement to subtract from `x`.
38209 * @return The Variable updated.
38210 */
function updateSub(x, decrement) {
    // Read the current value, subtract the decrement, and write it back.
    const newValue = sub$2(x.read(), decrement);
    return x.write(newValue);
}
38214 /**
38215 * Get the values of an array of Variables.
38216 *
38217 * @param tensors An `Array` of `Variable`s to get the values of.
38218 * @return The values of the inputs, as an `Array` of`tf.Tensor`s.
38219 */
function batchGetValue(xs) {
    // Read each Variable in order and collect the snapshots.
    const values = [];
    for (const variable of xs) {
        values.push(variable.read());
    }
    return values;
}
38223 /**
38224 * Update the value of multiple Variables at once.
38225 *
38226 * @param variablesAndValues An `Array`, each element is of type
38227 * [Variable, Tensor]. The first item is the
38228 * `Variable` of which the value is to be updated. The second item
38229 * carries the new value.
38230 */
function batchSetValue(variablesAndValues) {
    // Each entry is a [Variable, Tensor] pair; write the value into the
    // corresponding Variable, in order.
    for (const [target, newValue] of variablesAndValues) {
        target.write(newValue);
    }
}
38237 /**
38238 * Returns the gradients of `variables` w.r.t. the return value of `lossFn`.
38239 * @param lossFn A function which returns a Scalar to be used as the function
38240 * value (i.e., numerator) for differentiation.
38241 * @param variables List of variables to be used as the independent variables
38242 * (i.e., denominator) for differentiation.
38243 * @returns An Array of gradients tensors.
38244 */
function gradients(lossFn, variables) {
    // TODO(cais): The return type signature can be simplified if deeplearn makes
    // the corresponding type public.
    const variableList = variables.map(v => v.read());
    // (Renamed misspelled local `valudAndGrads`.)
    const valueAndGrads = variableGrads(lossFn, variableList);
    // Gradients are keyed by variable name; preserve input ordering.
    return variables.map(v => valueAndGrads.grads[v.name]);
}
38252
38253 /**
38254 * @license
38255 * Copyright 2018 Google LLC
38256 *
38257 * Use of this source code is governed by an MIT-style
38258 * license that can be found in the LICENSE file or at
38259 * https://opensource.org/licenses/MIT.
38260 * =============================================================================
38261 */
38262 /**
38263 * Specifies the ndim, dtype and shape of every input to a layer.
38264 *
38265 * Every layer should expose (if appropriate) an `inputSpec` attribute:
38266 * a list of instances of InputSpec (one per input tensor).
38267 *
38268 * A null entry in a shape is compatible with any dimension,
38269 * a null shape is compatible with any shape.
38270 */
class InputSpec {
    /**
     * @param args Specification options: `dtype`, `shape`, `ndim`,
     *   `maxNDim`, `minNDim` and `axes`.
     */
    constructor(args) {
        this.dtype = args.dtype;
        this.shape = args.shape;
        /*
        TODO(michaelterry): Could throw error if ndim and shape are both defined
        (then backport).
        */
        // An explicit shape determines ndim; otherwise use the ndim given.
        this.ndim = args.shape != null ? args.shape.length : args.ndim;
        this.maxNDim = args.maxNDim;
        this.minNDim = args.minNDim;
        this.axes = args.axes || {};
    }
}
38290 /**
38291 * `tf.SymbolicTensor` is a placeholder for a Tensor without any concrete value.
38292 *
38293 * They are most often encountered when building a graph of `Layer`s for a
38294 * `tf.LayersModel` and the input data's shape, but not values are known.
38295 *
38296 * @doc {heading: 'Models', 'subheading': 'Classes'}
38297 */
class SymbolicTensor {
    /**
     * @param dtype Data type of the symbolic tensor.
     * @param shape Shape of the symbolic tensor; its length defines `rank`.
     * @param sourceLayer The Layer that produced this symbolic tensor.
     * @param inputs The inputs passed to sourceLayer's __call__() method.
     * @param callArgs The keyword arguments passed to the __call__() method.
     * @param name Optional name; when given, it is scoped and uniquified
     *   before being stored.
     * @param outputTensorIndex The index of this tensor in the list of outputs
     *   returned by apply().
     */
    constructor(dtype, shape, sourceLayer, inputs, callArgs, name, outputTensorIndex) {
        this.dtype = dtype;
        this.shape = shape;
        this.sourceLayer = sourceLayer;
        this.inputs = inputs;
        this.callArgs = callArgs;
        this.outputTensorIndex = outputTensorIndex;
        // Globally unique id, shared with the LayerVariable id space.
        this.id = getNextUniqueTensorId();
        // `name` / `originalName` remain undefined when no name is given.
        if (name != null) {
            this.originalName = getScopedTensorName(name);
            this.name = getUniqueTensorName(this.originalName);
        }
        this.rank = shape.length;
    }
}
// Monotonically increasing counter used to assign each Node a unique id.
let _nextNodeID = 0;
38328 /**
38329 * A `Node` describes the connectivity between two layers.
38330 *
38331 * Each time a layer is connected to some new input,
38332 * a node is added to `layer.inboundNodes`.
38333 *
38334 * Each time the output of a layer is used by another layer,
38335 * a node is added to `layer.outboundNodes`.
38336 *
38337 * `nodeIndices` and `tensorIndices` are basically fine-grained coordinates
38338 * describing the origin of the `inputTensors`, verifying the following:
38339 *
38340 * `inputTensors[i] ==
38341 * inboundLayers[i].inboundNodes[nodeIndices[i]].outputTensors[
38342 * tensorIndices[i]]`
38343 *
38344 * A node from layer A to layer B is added to:
38345 * A.outboundNodes
38346 * B.inboundNodes
38347 */
class Node {
    /**
     * Constructing a Node has side effects: the new node registers itself on
     * every non-null inbound layer's `outboundNodes` and on the outbound
     * layer's `inboundNodes`.
     */
    constructor(args,
    // TODO(michaelterry): Define actual type for this.
    callArgs) {
        this.callArgs = callArgs;
        // Unique, monotonically increasing node id.
        this.id = _nextNodeID++;
        /*
        Layer instance (NOT a list).
        this is the layer that takes a list of input tensors
        and turns them into a list of output tensors.
        the current node will be added to
        the inboundNodes of outboundLayer.
        */
        this.outboundLayer = args.outboundLayer;
        /*
        The following 3 properties describe where
        the input tensors come from: which layers,
        and for each layer, which node and which
        tensor output of each node.
        */
        // List of layer instances.
        this.inboundLayers = args.inboundLayers;
        // List of integers, 1:1 mapping with inboundLayers.
        this.nodeIndices = args.nodeIndices;
        // List of integers, 1:1 mapping with inboundLayers.
        this.tensorIndices = args.tensorIndices;
        /*
        Following 2 properties:
        tensor inputs and outputs of outboundLayer.
        */
        // List of tensors. 1:1 mapping with inboundLayers.
        this.inputTensors = args.inputTensors;
        // List of tensors, created by outboundLayer.call().
        this.outputTensors = args.outputTensors;
        /*
        Following 2 properties: input and output masks.
        List of tensors, 1:1 mapping with inputTensor.
        */
        this.inputMasks = args.inputMasks;
        // List of tensors, created by outboundLayer.computeMask().
        this.outputMasks = args.outputMasks;
        // Following 2 properties: input and output shapes.
        // List of shape tuples, shapes of inputTensors.
        this.inputShapes = args.inputShapes;
        // List of shape tuples, shapes of outputTensors.
        this.outputShapes = args.outputShapes;
        // Add nodes to all layers involved.
        for (const layer of args.inboundLayers) {
            if (layer != null) {
                layer.outboundNodes.push(this);
            }
        }
        args.outboundLayer.inboundNodes.push(this);
    }
    /**
     * Returns a serializable summary of this node's connectivity: layer
     * names (null entries preserved) plus node/tensor indices.
     */
    getConfig() {
        const inboundNames = [];
        for (const layer of this.inboundLayers) {
            if (layer != null) {
                inboundNames.push(layer.name);
            }
            else {
                inboundNames.push(null);
            }
        }
        return {
            outboundLayer: this.outboundLayer ? this.outboundLayer.name : null,
            inboundLayers: inboundNames,
            nodeIndices: this.nodeIndices,
            tensorIndices: this.tensorIndices
        };
    }
}
// Monotonically increasing counter used to assign each Layer a unique id.
let _nextLayerID = 0;
38421 /**
38422 * A layer is a grouping of operations and weights that can be composed to
38423 * create a `tf.LayersModel`.
38424 *
38425 * Layers are constructed by using the functions under the
38426 * [tf.layers](#Layers-Basic) namespace.
38427 *
38428 * @doc {heading: 'Layers', subheading: 'Classes', namespace: 'layers'}
38429 */
38430 class Layer extends Serializable {
    /**
     * Initializes bookkeeping state (weights, nodes, losses), derives the
     * layer name, and records batchInputShape/dtype when input-shape
     * information is supplied in `args`.
     */
    constructor(args = {}) {
        super();
        this._callHook = null;
        this._addedWeightNames = [];
        // Porting Notes: PyKeras does not have this property in this base Layer
        // class. Instead lets Layer subclass set it dynamically and checks the
        // value with `hasattr`. In tfjs-layers, we let this be a member of this
        // base class.
        this._stateful = false;
        this.id = _nextLayerID++;
        this.activityRegularizer = null;
        this.inputSpec = null;
        this.supportsMasking = false;
        // These properties will be set upon call of this.build()
        this._trainableWeights = [];
        this._nonTrainableWeights = [];
        this._losses = [];
        this._updates = [];
        this._built = false;
        /*
        These lists will be filled via successive calls
        to this.addInboundNode().
        */
        this.inboundNodes = [];
        this.outboundNodes = [];
        let name = args.name;
        if (!name) {
            // Auto-generate a name like 'dense_1' from the class name.
            const prefix = this.getClassName();
            name = toSnakeCase(prefix) + '_' + getUid(prefix);
        }
        this.name = name;
        this.trainable_ = args.trainable == null ? true : args.trainable;
        if (args.inputShape != null || args.batchInputShape != null) {
            /*
            In this case we will later create an input layer
            to insert before the current layer
            */
            let batchInputShape;
            if (args.batchInputShape != null) {
                batchInputShape = args.batchInputShape;
            }
            else if (args.inputShape != null) {
                // Prepend the (possibly unknown) batch dimension.
                let batchSize = null;
                if (args.batchSize != null) {
                    batchSize = args.batchSize;
                }
                batchInputShape = [batchSize].concat(args.inputShape);
            }
            this.batchInputShape = batchInputShape;
            // Set dtype: explicit dtype wins, then inputDType, then 'float32'.
            let dtype = args.dtype;
            if (dtype == null) {
                dtype = args.inputDType;
            }
            if (dtype == null) {
                dtype = 'float32';
            }
            this.dtype = dtype;
        }
        if (args.weights != null) {
            this.initialWeights = args.weights;
        }
        else {
            this.initialWeights = null;
        }
        // The value of `_refCount` is initialized to null. When the layer is used
        // in a symbolic way for the first time, it will be set to 1.
        this._refCount = null;
        this.fastWeightInitDuringBuild = false;
    }
38501 /**
38502 * Converts a layer and its index to a unique (immutable type) name.
38503 * This function is used internally with `this.containerNodes`.
38504 * @param layer The layer.
38505 * @param nodeIndex The layer's position (e.g. via enumerate) in a list of
38506 * nodes.
38507 *
38508 * @returns The unique name.
38509 */
38510 static nodeKey(layer, nodeIndex) {
38511 return layer.name + '_ib-' + nodeIndex.toString();
38512 }
38513 /**
38514 * Returns this.inboundNode at index nodeIndex.
38515 *
38516 * Porting note: This is a replacement for _get_node_attribute_at_index()
38517 * @param nodeIndex
38518 * @param attrName The name of the attribute related to request for this node.
38519 */
38520 getNodeAtIndex(nodeIndex, attrName) {
38521 if (this.inboundNodes.length === 0) {
38522 throw new RuntimeError('The layer has never been called ' +
38523 `and thus has no defined ${attrName}.`);
38524 }
38525 if (this.inboundNodes.length <= nodeIndex) {
38526 throw new ValueError(`Asked to get ${attrName} at node ${nodeIndex}, ` +
38527 `but the layer has only ${this.inboundNodes.length} inbound nodes.`);
38528 }
38529 return this.inboundNodes[nodeIndex];
38530 }
38531 /**
38532 * Retrieves the input tensor(s) of a layer at a given node.
38533 *
38534 * @param nodeIndex Integer, index of the node from which to retrieve the
38535 * attribute. E.g. `nodeIndex=0` will correspond to the first time the layer
38536 * was called.
38537 *
38538 * @return A tensor (or list of tensors if the layer has multiple inputs).
38539 */
38540 getInputAt(nodeIndex) {
38541 return singletonOrArray(this.getNodeAtIndex(nodeIndex, 'input').inputTensors);
38542 }
38543 /**
38544 * Retrieves the output tensor(s) of a layer at a given node.
38545 *
38546 * @param nodeIndex Integer, index of the node from which to retrieve the
38547 * attribute. E.g. `nodeIndex=0` will correspond to the first time the layer
38548 * was called.
38549 *
38550 * @return A tensor (or list of tensors if the layer has multiple outputs).
38551 */
38552 getOutputAt(nodeIndex) {
38553 return singletonOrArray(this.getNodeAtIndex(nodeIndex, 'output').outputTensors);
38554 }
38555 // Properties
38556 /**
38557 * Retrieves the input tensor(s) of a layer.
38558 *
38559 * Only applicable if the layer has exactly one inbound node,
38560 * i.e. if it is connected to one incoming layer.
38561 *
38562 * @return Input tensor or list of input tensors.
38563 *
38564 * @exception AttributeError if the layer is connected to more than one
38565 * incoming layers.
38566 */
38567 get input() {
38568 if (this.inboundNodes.length > 1) {
38569 throw new AttributeError(`Layer ${this.name}` +
38570 ' has multiple inbound nodes, ' +
38571 'hence the notion of "layer input" ' +
38572 'is ill-defined. ' +
38573 'Use `getInputAt(nodeIndex)` instead.');
38574 }
38575 else if (this.inboundNodes.length === 0) {
38576 throw new AttributeError(`Layer ${this.name}` +
38577 ' is not connected, no input to return.');
38578 }
38579 return singletonOrArray(this.getNodeAtIndex(0, 'input').inputTensors);
38580 }
38581 /**
38582 * Retrieves the output tensor(s) of a layer.
38583 *
38584 * Only applicable if the layer has exactly one inbound node,
38585 * i.e. if it is connected to one incoming layer.
38586 *
38587 * @return Output tensor or list of output tensors.
38588 *
38589 * @exception AttributeError if the layer is connected to more than one
38590 * incoming layers.
38591 */
38592 get output() {
38593 if (this.inboundNodes.length === 0) {
38594 throw new AttributeError(`Layer ${this.name}` +
38595 ' has no inbound nodes.');
38596 }
38597 if (this.inboundNodes.length > 1) {
38598 throw new AttributeError(`Layer ${this.name}` +
38599 ' has multiple inbound nodes, ' +
38600 'hence the notion of "layer output" ' +
38601 'is ill-defined. ' +
38602 'Use `getOutputAt(nodeIndex)` instead.');
38603 }
38604 return singletonOrArray(this.getNodeAtIndex(0, 'output').outputTensors);
38605 }
    // Loss callbacks registered on this layer (e.g. by regularizers).
    get losses() {
        return this._losses;
    }
    /**
     * Retrieves the Layer's current loss values.
     *
     * Used for regularizers during training.
     */
    calculateLosses() {
        // Porting Node: This is an augmentation to Layer.loss in PyKeras.
        // In PyKeras, Layer.loss returns symbolic tensors. Here a concrete
        // Tensor (specifically Scalar) values are returned. This is due to the
        // imperative backend.
        return this.losses.map(lossFn => lossFn());
    }
    // Pending state updates registered on this layer.
    get updates() {
        return this._updates;
    }
    // Whether build() has completed for this layer.
    get built() {
        return this._built;
    }
    set built(built) {
        this._built = built;
    }
38630 get trainable() {
38631 return this.trainable_;
38632 }
38633 set trainable(trainable) {
38634 this._trainableWeights.forEach(w => w.trainable = trainable);
38635 this.trainable_ = trainable;
38636 }
38637 get trainableWeights() {
38638 if (this.trainable_) {
38639 return this._trainableWeights.filter(w => w.trainable);
38640 }
38641 else {
38642 return [];
38643 }
38644 }
38645 set trainableWeights(weights) {
38646 this._trainableWeights = weights;
38647 }
38648 get nonTrainableWeights() {
38649 if (this.trainable) {
38650 return this._trainableWeights.filter(w => !w.trainable)
38651 .concat(this._nonTrainableWeights);
38652 }
38653 else {
38654 return this._trainableWeights.concat(this._nonTrainableWeights);
38655 }
38656 }
38657 set nonTrainableWeights(weights) {
38658 this._nonTrainableWeights = weights;
38659 }
38660 /**
38661 * The concatenation of the lists trainableWeights and nonTrainableWeights
38662 * (in this order).
38663 */
38664 get weights() {
38665 return this.trainableWeights.concat(this.nonTrainableWeights);
38666 }
38667 get stateful() {
38668 return this._stateful;
38669 }
38670 /**
38671 * Reset the states of the layer.
38672 *
38673 * This method of the base Layer class is essentially a no-op.
38674 * Subclasses that are stateful (e.g., stateful RNNs) should override this
38675 * method.
38676 */
38677 resetStates() {
38678 if (!this.stateful) {
38679 throw new Error('Cannot call the resetStates() method of a non-stateful Layer ' +
38680 'object.');
38681 }
38682 }
38683 /**
38684 * Checks compatibility between the layer and provided inputs.
38685 *
38686 * This checks that the tensor(s) `input`
38687 * verify the input assumptions of the layer
38688 * (if any). If not, exceptions are raised.
38689 *
38690 * @param inputs Input tensor or list of input tensors.
38691 *
38692 * @exception ValueError in case of mismatch between
38693 * the provided inputs and the expectations of the layer.
38694 */
    assertInputCompatibility(inputs) {
        const inputsList = toList(inputs);
        // No input spec => nothing to validate.
        if (this.inputSpec == null || this.inputSpec.length === 0) {
            return;
        }
        const inputSpec = toList(this.inputSpec);
        // The number of inputs must match the number of specs exactly.
        if (inputsList.length !== inputSpec.length) {
            throw new ValueError(`Layer ${this.name} expects ${inputSpec.length} inputs, ` +
                `but it received ${inputsList.length} input tensors. ` +
                `Input received: ${inputs}`);
        }
        // Validate each input against its positional spec.
        for (let inputIndex = 0; inputIndex < inputsList.length; inputIndex++) {
            const x = inputsList[inputIndex];
            const spec = inputSpec[inputIndex];
            if (spec == null) {
                continue;
            }
            // Check ndim (exact, then max, then min bounds).
            const ndim = x.rank;
            if (spec.ndim != null) {
                if (ndim !== spec.ndim) {
                    throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name}: ` +
                        `expected ndim=${spec.ndim}, found ndim=${ndim}`);
                }
            }
            if (spec.maxNDim != null) {
                if (ndim > spec.maxNDim) {
                    throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name}` +
                        `: expected max_ndim=${spec.maxNDim}, found ndim=${ndim}`);
                }
            }
            if (spec.minNDim != null) {
                if (ndim < spec.minNDim) {
                    throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name}` +
                        `: expected min_ndim=${spec.minNDim}, found ndim=${ndim}.`);
                }
            }
            // Check dtype.
            if (spec.dtype != null) {
                if (x.dtype !== spec.dtype) {
                    throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name} ` +
                        `: expected dtype=${spec.dtype}, found dtype=${x.dtype}.`);
                }
            }
            // Check specific shape axes. Keys of spec.axes are axis indices
            // (possibly negative); values are the required dimension sizes.
            if (spec.axes) {
                const xShape = x.shape;
                for (const key in spec.axes) {
                    const axis = Number(key);
                    const value = spec.axes[key];
                    // Perform Python-style slicing in case axis < 0;
                    // TODO(cais): Use https://github.com/alvivi/typescript-underscore to
                    // ensure type safety through Underscore calls.
                    const xShapeAtAxis = axis >= 0 ? xShape[axis] : xShape[xShape.length + axis];
                    // A null dimension in the input is compatible with any value.
                    if (value != null && [value, null].indexOf(xShapeAtAxis) === -1) {
                        throw new ValueError(`Input ${inputIndex} is incompatible with layer ` +
                            `${this.name}: expected axis ${axis} of input shape to ` +
                            `have value ${value} but got shape ${xShape}.`);
                    }
                }
            }
            // Check shape. Null entries (in the spec or the input) match any
            // dimension.
            if (spec.shape != null) {
                for (let i = 0; i < spec.shape.length; ++i) {
                    const specDim = spec.shape[i];
                    const dim = x.shape[i];
                    if (specDim != null && dim != null) {
                        if (specDim !== dim) {
                            throw new ValueError(`Input ${inputIndex} is incompatible with layer ` +
                                `${this.name}: expected shape=${spec.shape}, ` +
                                `found shape=${x.shape}.`);
                        }
                    }
                }
            }
        }
    }
38772 /**
38773 * This is where the layer's logic lives.
38774 *
38775 * @param inputs Input tensor, or list/tuple of input tensors.
38776 * @param kwargs Additional keyword arguments.
38777 *
38778 * @return A tensor or list/tuple of tensors.
38779 */
    call(inputs, kwargs) {
        // The base Layer is an identity op; subclasses override this method.
        return inputs;
    }
38783 invokeCallHook(inputs, kwargs) {
38784 if (this._callHook != null) {
38785 this._callHook(inputs, kwargs);
38786 }
38787 }
38788 /**
38789 * Set call hook.
38790 * This is currently used for testing only.
38791 * @param callHook
38792 */
    // Register a hook invoked by invokeCallHook(); testing only.
    setCallHook(callHook) {
        this._callHook = callHook;
    }
    /**
     * Clear call hook.
     * This is currently used for testing only.
     */
    clearCallHook() {
        this._callHook = null;
    }
38803 /**
38804 * Builds or executes a `Layer`'s logic.
38805 *
38806 * When called with `tf.Tensor`(s), execute the `Layer`'s computation and
38807 * return Tensor(s). For example:
38808 *
38809 * ```js
38810 * const denseLayer = tf.layers.dense({
38811 * units: 1,
38812 * kernelInitializer: 'zeros',
38813 * useBias: false
38814 * });
38815 *
38816 * // Invoke the layer's apply() method with a `tf.Tensor` (with concrete
38817 * // numeric values).
38818 * const input = tf.ones([2, 2]);
38819 * const output = denseLayer.apply(input);
38820 *
38821 * // The output's value is expected to be [[0], [0]], due to the fact that
38822 * // the dense layer has a kernel initialized to all-zeros and does not have
38823 * // a bias.
38824 * output.print();
38825 * ```
38826 *
38827 * When called with `tf.SymbolicTensor`(s), this will prepare the layer for
38828 * future execution. This entails internal book-keeping on shapes of
38829 * expected Tensors, wiring layers together, and initializing weights.
38830 *
38831 * Calling `apply` with `tf.SymbolicTensor`s are typically used during the
38832 * building of non-`tf.Sequential` models. For example:
38833 *
38834 * ```js
38835 * const flattenLayer = tf.layers.flatten();
38836 * const denseLayer = tf.layers.dense({units: 1});
38837 *
38838 * // Use tf.layers.input() to obtain a SymbolicTensor as input to apply().
38839 * const input = tf.input({shape: [2, 2]});
38840 * const output1 = flattenLayer.apply(input);
38841 *
38842 * // output1.shape is [null, 4]. The first dimension is the undetermined
38843 * // batch size. The second dimension comes from flattening the [2, 2]
38844 * // shape.
38845 * console.log(JSON.stringify(output1.shape));
38846 *
38847 * // The output SymbolicTensor of the flatten layer can be used to call
38848 * // the apply() of the dense layer:
38849 * const output2 = denseLayer.apply(output1);
38850 *
38851 * // output2.shape is [null, 1]. The first dimension is the undetermined
38852 * // batch size. The second dimension matches the number of units of the
38853 * // dense layer.
38854 * console.log(JSON.stringify(output2.shape));
38855 *
38856 * // The input and output can be used to construct a model that consists
38857 * // of the flatten and dense layers.
38858 * const model = tf.model({inputs: input, outputs: output2});
38859 * ```
38860 *
38861 * @param inputs a `tf.Tensor` or `tf.SymbolicTensor` or an Array of them.
38862 * @param kwargs Additional keyword arguments to be passed to `call()`.
38863 *
38864 * @return Output of the layer's `call` method.
38865 *
38866 * @exception ValueError error in case the layer is missing shape information
38867 * for its `build` call.
38868 *
38869 * @doc {heading: 'Models', 'subheading': 'Classes'}
38870 */
// Porting Note: This is a replacement for __call__() in Python.
apply(inputs, kwargs) {
    kwargs = kwargs || {};
    this.assertNotDisposed();
    // Ensure inputs are all the same type: either every element is a
    // SymbolicTensor (graph-construction call) or none is (eager call).
    const inputsList = toList(inputs);
    const allAreSymbolic = checkAllSymbolic(inputs);
    const noneAreSymbolic = checkNoneSymbolic(inputs);
    if (allAreSymbolic === noneAreSymbolic) {
        throw new ValueError('Arguments to apply() must be all ' +
            'SymbolicTensors or all Tensors');
    }
    // TODO(michaelterry): nameScope() may not be necessary.
    return nameScope(this.name, () => {
        // Handle lazy layer building (weight creation, input spec locking)
        // on the first call.
        if (!this.built) {
            /*
              Throw exceptions in case the input is not compatible
              with the inputSpec specified in the layer constructor.
            */
            this.assertInputCompatibility(inputs);
            // Collect input shapes to build layer.
            const inputShapes = [];
            for (const xElem of toList(inputs)) {
                inputShapes.push(xElem.shape);
            }
            this.build(singletonOrArray(inputShapes));
            this.built = true;
            // Load weights that were specified at layer instantiation.
            if (this.initialWeights) {
                this.setWeights(this.initialWeights);
            }
            if (this._refCount === null && noneAreSymbolic) {
                // The first use of this layer is a non-symbolic call, set ref count
                // to 1 so the Layer can be properly disposed if its dispose() method
                // is called.
                this._refCount = 1;
            }
        }
        /*
          Throw exceptions in case the input is not compatible
          with the inputSpec set at build time.
        */
        this.assertInputCompatibility(inputs);
        // Handle mask propagation.
        // TODO(michaelterry): Mask propagation not currently implemented.
        // Actually call the layer, collecting output(s), mask(s), and shape(s).
        if (noneAreSymbolic) {
            // Eager path: run the layer's computation on concrete tensors.
            let output = this.call(inputs, kwargs);
            // Apply masks to the output tensors if the layer supports it.
            if (this.supportsMasking) {
                // TODO(mattsoulanille): pass the input tensors' masks to computeMask
                this.setMaskMetadata(inputs, output);
            }
            // If the layer returns tensors from its inputs, unmodified,
            // we copy them to avoid loss of tensor metadata.
            const outputList = toList(output);
            const outputListCopy = [];
            // TODO(michaelterry): This copying may not be necessary given our eager
            // backend.
            for (let x of outputList) {
                if (inputsList.indexOf(x) !== -1) {
                    x = x.clone();
                }
                outputListCopy.push(x);
            }
            output = singletonOrArray(outputListCopy);
            if (this.activityRegularizer != null) {
                throw new NotImplementedError('Layer invocation in the presence of activity ' +
                    'regularizer(s) is not supported yet.');
            }
            // TODO(michaelterry): Call addInboundNode()?
            return output;
        }
        else {
            // Symbolic path: infer output shape(s) and dtype, then wire the
            // graph without performing any numeric computation.
            const inputShape = collectInputShape(inputs);
            const outputShape = this.computeOutputShape(inputShape);
            let output;
            const outputDType = guessOutputDType(inputs);
            this.warnOnIncompatibleInputShape(Array.isArray(inputs) ? inputShape[0] :
                inputShape);
            if (outputShape != null && outputShape.length > 0 &&
                Array.isArray(outputShape[0])) {
                // We have multiple output shapes. Create multiple output tensors.
                output = outputShape
                    .map((shape, index) => new SymbolicTensor(outputDType, shape, this, toList(inputs), kwargs, this.name, index));
            }
            else {
                output = new SymbolicTensor(outputDType, outputShape, this, toList(inputs), kwargs, this.name);
            }
            /*
              Add an inbound node to the layer, so that it keeps track
              of the call and of all new variables created during the call.
              This also updates the layer history of the output tensor(s).
              If the input tensor(s) had no previous history,
              this does nothing.
            */
            this.addInboundNode(inputs, output, null, null, inputShape, outputShape, kwargs);
            this._refCount++;
            if (this.activityRegularizer != null) {
                throw new NotImplementedError('Layer invocation in the presence of activity ' +
                    'regularizer(s) is not supported yet.');
            }
            return output;
        }
    });
}
38978 /**
38979 * Check compatibility between input shape and this layer's batchInputShape.
38980 *
38981 * Print warning if any incompatibility is found.
38982 *
38983 * @param inputShape Input shape to be checked.
38984 */
38985 warnOnIncompatibleInputShape(inputShape) {
38986 if (this.batchInputShape == null) {
38987 return;
38988 }
38989 else if (inputShape.length !== this.batchInputShape.length) {
38990 console.warn(`The rank of the input tensor provided (shape: ` +
38991 `${JSON.stringify(inputShape)}) does not match that of the ` +
38992 `batchInputShape (${JSON.stringify(this.batchInputShape)}) ` +
38993 `of the layer ${this.name}`);
38994 }
38995 else {
38996 let dimMismatch = false;
38997 this.batchInputShape.forEach((dimension, i) => {
38998 if (dimension != null && inputShape[i] != null &&
38999 inputShape[i] !== dimension) {
39000 dimMismatch = true;
39001 }
39002 });
39003 if (dimMismatch) {
39004 console.warn(`The shape of the input tensor ` +
39005 `(${JSON.stringify(inputShape)}) does not ` +
39006 `match the expectation of layer ${this.name}: ` +
39007 `${JSON.stringify(this.batchInputShape)}`);
39008 }
39009 }
39010 }
39011 /**
39012 * Retrieves the output shape(s) of a layer.
39013 *
39014 * Only applicable if the layer has only one inbound node, or if all inbound
39015 * nodes have the same output shape.
39016 *
39017 * @returns Output shape or shapes.
39018 * @throws AttributeError: if the layer is connected to more than one incoming
39019 * nodes.
39020 *
39021 * @doc {heading: 'Models', 'subheading': 'Classes'}
39022 */
39023 get outputShape() {
39024 if (this.inboundNodes == null || this.inboundNodes.length === 0) {
39025 throw new AttributeError(`The layer ${this.name} has never been called and thus has no ` +
39026 `defined output shape.`);
39027 }
39028 const allOutputShapes = [];
39029 for (const node of this.inboundNodes) {
39030 const shapeString = JSON.stringify(node.outputShapes);
39031 if (allOutputShapes.indexOf(shapeString) === -1) {
39032 allOutputShapes.push(shapeString);
39033 }
39034 }
39035 if (allOutputShapes.length === 1) {
39036 const outputShapes = this.inboundNodes[0].outputShapes;
39037 if (Array.isArray(outputShapes) && Array.isArray(outputShapes[0]) &&
39038 outputShapes.length === 1) {
39039 return outputShapes[0];
39040 }
39041 else {
39042 return outputShapes;
39043 }
39044 }
39045 else {
39046 throw new AttributeError(`The layer ${this.name} has multiple inbound nodes with different ` +
39047 `output shapes. Hence the notion of "output shape" is ill-defined ` +
39048 `for the layer.`);
39049 // TODO(cais): Implement getOutputShapeAt().
39050 }
39051 }
39052 /**
39053 * Counts the total number of numbers (e.g., float32, int32) in the
39054 * weights.
39055 *
39056 * @returns An integer count.
39057 * @throws RuntimeError: If the layer is not built yet (in which case its
39058 * weights are not defined yet.)
39059 *
39060 * @doc {heading: 'Models', 'subheading': 'Classes'}
39061 */
39062 countParams() {
39063 if (!this.built) {
39064 throw new RuntimeError(`You tried to call countParams() on ${this.name}, ` +
39065 `but the layer is not built yet. Build it first by calling ` +
39066 `build(batchInputShape).`);
39067 }
39068 return countParamsInWeights(this.weights);
39069 }
/**
 * Creates the layer weights.
 *
 * Must be implemented on all layers that have weights.
 *
 * Called when apply() is called to construct the weights.
 *
 * The base-class implementation has no weights; it only marks the layer
 * as built. Subclasses with weights override this method.
 *
 * @param inputShape A `Shape` or array of `Shape` (unused).
 *
 * @doc {heading: 'Models', 'subheading': 'Classes'}
 */
build(inputShape) {
    this.built = true;
}
39084 /**
39085 * Returns the current values of the weights of the layer.
39086 *
39087 * @param trainableOnly Whether to get the values of only trainable weights.
39088 * @returns Weight values as an `Array` of `tf.Tensor`s.
39089 *
39090 * @doc {heading: 'Models', 'subheading': 'Classes'}
39091 */
39092 getWeights(trainableOnly = false) {
39093 return batchGetValue(trainableOnly ? this.trainableWeights : this.weights);
39094 }
39095 /**
39096 * Sets the weights of the layer, from Tensors.
39097 *
39098 * @param weights a list of Tensors. The number of arrays and their shape
39099 * must match number of the dimensions of the weights of the layer (i.e.
39100 * it should match the output of `getWeights`).
39101 *
39102 * @exception ValueError If the provided weights list does not match the
39103 * layer's specifications.
39104 *
39105 * @doc {heading: 'Models', 'subheading': 'Classes'}
39106 */
39107 setWeights(weights) {
39108 tidy(() => {
39109 const params = this.weights;
39110 if (params.length !== weights.length) {
39111 // TODO(cais): Restore the following and use `providedWeights`, instead
39112 // of `weights` in the error message, once the deeplearn.js bug is
39113 // fixed: https://github.com/PAIR-code/deeplearnjs/issues/498 const
39114 // providedWeights = JSON.stringify(weights).slice(0, 50);
39115 throw new ValueError(`You called setWeights(weights) on layer "${this.name}" ` +
39116 `with a weight list of length ${weights.length}, ` +
39117 `but the layer was expecting ${params.length} weights. ` +
39118 `Provided weights: ${weights}...`);
39119 }
39120 if (params.length === 0) {
39121 return;
39122 }
39123 const weightValueTuples = [];
39124 const paramValues = batchGetValue(params);
39125 for (let i = 0; i < paramValues.length; ++i) {
39126 const pv = paramValues[i];
39127 const p = params[i];
39128 const w = weights[i];
39129 if (!arraysEqual(pv.shape, w.shape)) {
39130 throw new ValueError(`Layer weight shape ${pv.shape} ` +
39131 `not compatible with provided weight shape ${w.shape}`);
39132 }
39133 weightValueTuples.push([p, w]);
39134 }
39135 batchSetValue(weightValueTuples);
39136 });
39137 }
/**
 * Adds a weight variable to the layer.
 *
 * @param name Name of the new weight variable. Must be unique within this
 *   layer; a duplicate raises a ValueError.
 * @param shape The shape of the weight.
 * @param dtype The dtype of the weight. Defaults to 'float32' when null.
 * @param initializer An initializer instance. Ignored (replaced by a cheap
 *   initializer) when the fastWeightInitDuringBuild flag is set.
 * @param regularizer A regularizer instance.
 * @param trainable Whether the weight should be trained via backprop or not
 *     (assuming that the layer itself is also trainable). Defaults to true
 *     when null/undefined.
 * @param constraint An optional constraint.
 * @param getInitializerFunc Optional lazily-evaluated initializer factory,
 *   preferred over the 'zeros' fallback during fast weight init.
 * @return The created weight variable.
 *
 * @doc {heading: 'Models', 'subheading': 'Classes'}
 */
addWeight(name, shape, dtype, initializer, regularizer, trainable, constraint, getInitializerFunc) {
    // Reject duplicate weight names.
    if (this._addedWeightNames.indexOf(name) !== -1) {
        throw new ValueError(`Duplicate weight name ${name} for layer ${this.name}`);
    }
    this._addedWeightNames.push(name);
    if (dtype == null) {
        dtype = 'float32';
    }
    if (this.fastWeightInitDuringBuild) {
        // Skip the potentially expensive initializer when the values will be
        // immediately overwritten (e.g., by loaded weights).
        initializer = getInitializerFunc != null ? getInitializerFunc() :
            getInitializer('zeros');
    }
    const initValue = initializer.apply(shape, dtype);
    // NOTE(review): `trainable` may still be null/undefined at this point;
    // its default (true) is only applied below, after LayerVariable
    // construction. Presumably LayerVariable applies its own default —
    // confirm before reordering.
    const weight = new LayerVariable(initValue, dtype, name, trainable, constraint);
    initValue.dispose();
    // Request backend not to dispose the weights of the model on scope() exit.
    if (regularizer != null) {
        this.addLoss(() => regularizer.apply(weight.read()));
    }
    if (trainable == null) {
        trainable = true;
    }
    if (trainable) {
        this._trainableWeights.push(weight);
    }
    else {
        this._nonTrainableWeights.push(weight);
    }
    return weight;
}
/**
 * Set the fast-weight-initialization flag.
 *
 * In cases where the initialized weight values will be immediately
 * overwritten by loaded weight values during model loading, setting
 * the flag to `true` saves unnecessary calls to potentially expensive
 * initializers and speeds up the loading process.
 *
 * The flag is consumed by addWeight(), which substitutes a cheap
 * initializer when it is set.
 *
 * @param value Target value of the flag.
 */
setFastWeightInitDuringBuild(value) {
    this.fastWeightInitDuringBuild = value;
}
39197 /**
39198 * Add losses to the layer.
39199 *
39200 * The loss may potentially be conditional on some inputs tensors,
39201 * for instance activity losses are conditional on the layer's inputs.
39202 *
39203 * @doc {heading: 'Models', 'subheading': 'Classes'}
39204 */
39205 addLoss(losses) {
39206 if (losses == null || Array.isArray(losses) && losses.length === 0) {
39207 return;
39208 }
39209 // Update this.losses
39210 losses = toList(losses);
39211 if (this._losses !== undefined && this._losses !== null) {
39212 this.losses.push(...losses);
39213 }
39214 }
/**
 * Computes the output shape of the layer.
 *
 * Assumes that the layer will be built to match that input shape provided.
 *
 * The base-class implementation assumes a shape-preserving layer and
 * returns the input shape unchanged; shape-changing layers override this.
 *
 * @param inputShape A shape (tuple of integers) or a list of shape tuples
 *   (one per output tensor of the layer). Shape tuples can include null for
 *   free dimensions, instead of an integer.
 *
 * @doc {heading: 'Models', 'subheading': 'Classes'}
 */
computeOutputShape(inputShape) {
    return inputShape;
}
39229 /**
39230 * Computes an output mask tensor.
39231 *
39232 * @param inputs Tensor or list of tensors.
39233 * @param mask Tensor or list of tensors.
39234 *
39235 * @return null or a tensor (or list of tensors, one per output tensor of the
39236 * layer).
39237 */
39238 computeMask(inputs, mask) {
39239 if (!this.supportsMasking) {
39240 if (mask != null) {
39241 if (Array.isArray(mask)) {
39242 mask.forEach(maskElement => {
39243 if (maskElement != null) {
39244 throw new TypeError(`Layer ${this.name} does not support masking, ` +
39245 'but was passed an inputMask.');
39246 }
39247 });
39248 }
39249 else {
39250 throw new TypeError(`Layer ${this.name} does not support masking, ` +
39251 'but was passed an inputMask.');
39252 }
39253 }
39254 // masking not explicitly supported: return null as mask
39255 return null;
39256 }
39257 // if masking is explictly supported, by default
39258 // carry over the input mask
39259 return mask;
39260 }
39261 setMaskMetadata(inputs, outputs, previousMask) {
39262 if (!this.supportsMasking) {
39263 return;
39264 }
39265 const outputMasks = this.computeMask(inputs, previousMask);
39266 const outputsList = toList(outputs);
39267 const outputMasksList = toList(outputMasks);
39268 if (outputsList.length !== outputMasksList.length) {
39269 throw new Error(`${this.name} outputs ${outputsList.length} tensors ` +
39270 `but ${outputsList.length} masks for those tensors`);
39271 }
39272 for (let i = 0; i < outputsList.length; i++) {
39273 outputsList[i].kerasMask = outputMasksList[i];
39274 }
39275 }
/**
 * Internal method to create an inbound node for the layer.
 *
 * @param inputTensors List of input tensors.
 * @param outputTensors List of output tensors.
 * @param inputMasks List of input masks (a mask can be a tensor, or null).
 * @param outputMasks List of output masks (a mask can be a tensor, or null).
 * @param inputShapes List of input shape tuples.
 * @param outputShapes List of output shape tuples.
 * @param kwargs Dictionary of keyword arguments that were passed to the
 *   `call` method of the layer at the call that created the node.
 */
addInboundNode(inputTensors, outputTensors, inputMasks, outputMasks, inputShapes, outputShapes, kwargs = null) {
    // Normalize all arguments to list/shape-list form.
    const inputTensorList = toList(inputTensors);
    outputTensors = toList(outputTensors);
    inputMasks = toList(inputMasks);
    outputMasks = toList(outputMasks);
    inputShapes = normalizeShapeList(inputShapes);
    outputShapes = normalizeShapeList(outputShapes);
    // Collect input tensor(s) coordinates: the layer, node index and tensor
    // index each input tensor originated from.
    const inboundLayers = [];
    const nodeIndices = [];
    const tensorIndices = [];
    for (const x of inputTensorList) {
        /*
         * TODO(michaelterry): Keras adds this value to tensors; it's not
         * clear whether we'll use this or not.
         */
        inboundLayers.push(x.sourceLayer);
        nodeIndices.push(x.nodeIndex);
        tensorIndices.push(x.tensorIndex);
    }
    // Create node, add it to inbound nodes.
    // (This call has side effects: the Node constructor registers itself
    // with the involved layers.)
    // tslint:disable-next-line:no-unused-expression
    new Node({
        outboundLayer: this,
        inboundLayers,
        nodeIndices,
        tensorIndices,
        inputTensors: inputTensorList,
        outputTensors,
        inputMasks,
        outputMasks,
        inputShapes,
        outputShapes
    }, kwargs);
    // Update tensor history: stamp each output tensor with its origin
    // (this layer, the just-added node, and its position).
    for (let i = 0; i < outputTensors.length; i++) {
        // TODO(michaelterry): _uses_learning_phase not tracked.
        outputTensors[i].sourceLayer = this;
        outputTensors[i].nodeIndex = this.inboundNodes.length - 1;
        outputTensors[i].tensorIndex = i;
    }
}
39331 /**
39332 * Returns the config of the layer.
39333 *
39334 * A layer config is a TS dictionary (serializable)
39335 * containing the configuration of a layer.
39336 * The same layer can be reinstantiated later
39337 * (without its trained weights) from this configuration.
39338 *
39339 * The config of a layer does not include connectivity
39340 * information, nor the layer class name. These are handled
39341 * by 'Container' (one layer of abstraction above).
39342 *
39343 * Porting Note: The TS dictionary follows TS naming standards for
39344 * keys, and uses tfjs-layers type-safe Enums. Serialization methods
39345 * should use a helper function to convert to the pythonic storage
39346 * standard. (see serialization_utils.convertTsToPythonic)
39347 *
39348 * @returns TS dictionary of configuration.
39349 *
39350 * @doc {heading: 'Models', 'subheading': 'Classes'}
39351 */
39352 getConfig() {
39353 const config = { name: this.name, trainable: this.trainable };
39354 if (this.batchInputShape != null) {
39355 config['batchInputShape'] = this.batchInputShape;
39356 }
39357 if (this.dtype != null) {
39358 config['dtype'] = this.dtype;
39359 }
39360 return config;
39361 }
39362 /**
39363 * Dispose the weight variables that this Layer instance holds.
39364 *
39365 * @returns {number} Number of disposed variables.
39366 */
39367 disposeWeights() {
39368 this.weights.forEach(weight => weight.dispose());
39369 return this.weights.length;
39370 }
/**
 * Throws if this layer has already been disposed, i.e., its reference
 * count has been decremented to 0 by dispose().
 */
assertNotDisposed() {
    if (this._refCount === 0) {
        throw new Error(`Layer '${this.name}' is already disposed.`);
    }
}
39376 /**
39377 * Attempt to dispose layer's weights.
39378 *
39379 * This method decreases the reference count of the Layer object by 1.
39380 *
39381 * A Layer is reference-counted. Its reference count is incremented by 1
39382 * the first item its `apply()` method is called and when it becomes a part
39383 * of a new `Node` (through calling the `apply()` method on a
39384 * `tf.SymbolicTensor`).
39385 *
39386 * If the reference count of a Layer becomes 0, all the weights will be
39387 * disposed and the underlying memory (e.g., the textures allocated in WebGL)
39388 * will be freed.
39389 *
39390 * Note: If the reference count is greater than 0 after the decrement, the
39391 * weights of the Layer will *not* be disposed.
39392 *
39393 * After a Layer is disposed, it cannot be used in calls such as `apply()`,
39394 * `getWeights()` or `setWeights()` anymore.
39395 *
39396 * @returns A DisposeResult Object with the following fields:
39397 * - refCountAfterDispose: The reference count of the Container after this
39398 * `dispose()` call.
39399 * - numDisposedVariables: Number of `tf.Variable`s (i.e., weights) disposed
39400 * during this `dispose()` call.
39401 * @throws {Error} If the layer is not built yet, or if the layer has already
39402 * been disposed.
39403 *
39404 * @doc {heading: 'Models', 'subheading': 'Classes'}
39405 */
39406 dispose() {
39407 if (!this.built) {
39408 throw new Error(`Cannot dispose Layer ${this.name} because it has not been ` +
39409 `built yet.`);
39410 }
39411 if (this._refCount === null) {
39412 throw new Error(`Cannot dispose Layer ${this.name} because it has not been used ` +
39413 `yet.`);
39414 }
39415 this.assertNotDisposed();
39416 let numDisposedVariables = 0;
39417 if (--this._refCount === 0) {
39418 numDisposedVariables = this.disposeWeights();
39419 }
39420 return { refCountAfterDispose: this._refCount, numDisposedVariables };
39421 }
39422 }
/**
 * Collects the input shape(s) of a list of `tf.Tensor`s or
 * `tf.SymbolicTensor`s.
 *
 * TODO(michaelterry): Update PyKeras docs (backport).
 *
 * @param inputTensors List of input tensors (or single input tensor).
 *
 * @return List of shape tuples (or single tuple), one tuple per input.
 */
function collectInputShape(inputTensors) {
    // Normalize to a list, map each tensor to its shape, and unwrap a
    // singleton result back to a single shape.
    const tensorList = toList(inputTensors);
    return singletonOrArray(tensorList.map(x => x.shape));
}
/**
 * Guesses output dtype based on inputs.
 *
 * At present, just returns 'float32' for any input.
 *
 * @param inputTensors List of input tensors (or single input tensor).
 *
 * @return The guessed DType. At present, always returns 'float32'.
 */
function guessOutputDType(inputTensors) {
    // Placeholder heuristic: the layers API currently always assumes
    // float32 outputs regardless of the inputs.
    return 'float32';
}
/**
 * Returns the list of input tensors necessary to compute `tensor`.
 *
 * Output will always be a list of tensors (potentially with 1 element).
 *
 * @param tensor The tensor to start from.
 * @param layer Origin layer of the tensor.
 * @param nodeIndex Origin node index of the tensor.
 *
 * @return Array of input tensors, deduplicated.
 */
function getSourceInputs(tensor, layer, nodeIndex) {
    // When no origin layer is supplied (or the node index points past the
    // first node), fall back to the origin recorded on the tensor itself.
    if (layer == null || (nodeIndex != null && nodeIndex > 0)) {
        layer = tensor.sourceLayer;
        nodeIndex = tensor.nodeIndex;
    }
    if (layer.inboundNodes.length === 0) {
        // Reached a source (input) layer: the tensor itself is the input.
        return [tensor];
    }
    const node = layer.inboundNodes[nodeIndex];
    if (node.inboundLayers.length === 0) {
        // The node has no upstream layers: its own input tensors are the sources.
        return node.inputTensors;
    }
    // Recurse upstream through each inbound layer, merging results while
    // avoiding input redundancy.
    const sourceTensors = [];
    node.inboundLayers.forEach((originLayer, i) => {
        const upstream = getSourceInputs(node.inputTensors[i], originLayer, node.nodeIndices[i]);
        for (const source of upstream) {
            if (sourceTensors.indexOf(source) === -1) {
                sourceTensors.push(source);
            }
        }
    });
    return sourceTensors;
}
/**
 * Returns true iff every element of `tensors` (normalized to a list) is a
 * SymbolicTensor. An empty list yields true.
 */
function checkAllSymbolic(tensors) {
    return toList(tensors).every(tensor => tensor instanceof SymbolicTensor);
}
/**
 * Returns true iff no element of `tensors` (normalized to a list) is a
 * SymbolicTensor. An empty list yields true.
 */
function checkNoneSymbolic(tensors) {
    return toList(tensors).every(tensor => !(tensor instanceof SymbolicTensor));
}
39516
39517 /**
39518 * @license
39519 * Copyright 2018 Google LLC
39520 *
39521 * Use of this source code is governed by an MIT-style
39522 * license that can be found in the LICENSE file or at
39523 * https://opensource.org/licenses/MIT.
39524 * =============================================================================
39525 */
/**
 * A layer that serves as an entry point into a model graph: it has no
 * inputs of its own and produces a single SymbolicTensor whose shape is
 * the configured batch-input shape.
 */
class InputLayer extends Layer {
    constructor(args) {
        super({
            dtype: args.dtype,
            name: args.name != null ? args.name : getUid('input').toString()
        });
        // Normalize config.batchSize and config.sparse
        if (args.batchSize == null) {
            args.batchSize = null;
        }
        if (args.sparse == null) {
            args.sparse = false;
        }
        // An InputLayer has no weights: it is untrainable and considered
        // built from the moment of construction.
        this.trainable = false;
        this.built = true;
        this.sparse = args.sparse;
        if (args.inputShape != null && args.batchInputShape != null) {
            throw new ValueError('Only provide the inputShape OR ' +
                'batchInputShape argument to inputLayer, not both at the same time.');
        }
        let batchInputShape = args.batchInputShape;
        if (batchInputShape == null) {
            if (args.inputShape == null) {
                throw new ValueError('An InputLayer should be passed either a ' +
                    '`batchInputShape` or an `inputShape`.');
            }
            else {
                // Prepend the (possibly null) batch dimension to inputShape.
                batchInputShape = [args.batchSize].concat(args.inputShape);
            }
        }
        else {
            // TODO(michaelterry): Backport to PyKeras
            if (args.batchSize != null) {
                throw new ValueError('Cannot specify batchSize if batchInputShape is ' +
                    'specified when creating an InputLayer.');
            }
        }
        const dtype = args.dtype || 'float32';
        this.batchInputShape = batchInputShape;
        this.dtype = dtype;
        // TODO(michaelterry): Backport this to PyKeras?
        this.inputSpec = [{ shape: batchInputShape }];
        // The single symbolic output of this layer; it originates at node 0,
        // tensor index 0 of this layer.
        const inputTensor = new SymbolicTensor(this.dtype, this.batchInputShape, this, [], {}, this.name);
        inputTensor.nodeIndex = 0;
        inputTensor.tensorIndex = 0;
        // Create an input node to add to this.outboundNode.
        // (This call has side effects.)
        // tslint:disable-next-line:no-unused-expression
        new Node({
            outboundLayer: this,
            inboundLayers: [],
            nodeIndices: [],
            tensorIndices: [],
            inputTensors: [inputTensor],
            outputTensors: [inputTensor],
            inputMasks: [null],
            outputMasks: [null],
            inputShapes: [batchInputShape],
            outputShapes: [batchInputShape]
        });
    }
    // An InputLayer cannot be applied to other tensors; always throws.
    apply(inputs, kwargs) {
        throw new ValueError('Cannot pass any input to an ' +
            `InputLayer's apply() method. InputLayer name: ${this.name}`);
    }
    dispose() {
        // dispose() for InputLayer is overridden as no-op: the layer holds
        // no weights, so nothing is ever freed here.
        return { refCountAfterDispose: this._refCount, numDisposedVariables: 0 };
    }
    // Returns the serializable configuration of this InputLayer.
    getConfig() {
        return {
            batchInputShape: this.batchInputShape,
            dtype: this.dtype,
            sparse: this.sparse,
            name: this.name
        };
    }
}
/** @nocollapse */
InputLayer.className = 'InputLayer';
registerClass(InputLayer);
/**
 * Factory for a symbolic input tensor: constructs an InputLayer from the
 * given config (exactly one of `shape` or `batchShape` must be supplied)
 * and returns its single output SymbolicTensor.
 *
 * @param config Input configuration (shape/batchShape, name, dtype, sparse).
 * @return The SymbolicTensor produced by the new InputLayer.
 */
function Input(config) {
    // Exactly one of `shape` / `batchShape` must be supplied.
    if (config.batchShape == null && config.shape == null) {
        throw new Error('Please provide to Input either a `shape`' +
            ' or a `batchShape` argument. Note that ' +
            '`shape` does not include the batch ' +
            'dimension.');
    }
    if (config.batchShape != null && config.shape != null) {
        // TODO(michaelterry): Backport to PyKeras.
        throw new ValueError('Please provide either a `shape` or `batchShape` ' +
            'argument to Input, but not both.');
    }
    // Derive the batch shape: prepend an undetermined batch dimension when
    // only `shape` was provided.
    const batchShape = config.batchShape != null ?
        config.batchShape :
        [null].concat(config.shape);
    const dtype = config.dtype != null ? config.dtype : 'float32';
    const inputLayer = new InputLayer({
        batchInputShape: batchShape,
        name: config.name,
        dtype,
        sparse: config.sparse
    });
    // The InputLayer's constructor registered exactly one inbound node whose
    // sole output tensor is the symbolic input.
    return inputLayer.inboundNodes[0].outputTensors[0];
}
39636
39637 /**
39638 * @license
39639 * Copyright 2018 Google LLC
39640 *
39641 * Use of this source code is governed by an MIT-style
39642 * license that can be found in the LICENSE file or at
39643 * https://opensource.org/licenses/MIT.
39644 * =============================================================================
39645 */
/**
 * Helper function to check the dtype and shape compatibility of a feed value.
 *
 * @param key The SymbolicTensor being fed (provides the expected dtype).
 * @param val The concrete Tensor value supplied for the key.
 * @return `val` unchanged when dtypes match (or the key has no dtype),
 *   otherwise `val` cast to the key's dtype.
 * @throws ValueError when the cast is not possible.
 */
function assertFeedCompatibility(key, val) {
    const expectedDType = key.dtype;
    // a. No dtype requirement, or dtypes already match: pass through as-is.
    if (expectedDType == null || expectedDType === val.dtype) {
        return val;
    }
    try {
        // b. Attempt to convert to the expected dtype.
        return cast$3(val, expectedDType);
    }
    catch (err) {
        // c. Translate a failed cast into a descriptive error.
        throw new ValueError(`The dtype of the feed (${val.dtype}) can not be cast to the dtype ` +
            `of the key '${key.name}' (${key.dtype}).`);
    }
}
39665 /**
39666 * FeedDict: A mapping from unique SymbolicTensors to feed values for them.
39667 * A feed value is a concrete value represented as an `Tensor`.
39668 */
39669 class FeedDict {
39670 /**
39671 * Constructor, optionally does copy-construction.
39672 * @param feeds An Array of `Feed`s, or another `FeedDict`, in which case
39673 * copy-construction will be performed.
39674 */
39675 constructor(feeds) {
39676 this.id2Value = {};
39677 this.id2Mask = {};
39678 this.name2Id = {};
39679 if (feeds instanceof FeedDict) {
39680 for (const id in feeds.id2Value) {
39681 this.id2Value[id] = feeds.id2Value[id];
39682 if (id in feeds.id2Mask) {
39683 this.id2Mask[id] = feeds.id2Mask[id];
39684 }
39685 }
39686 }
39687 else {
39688 if (feeds == null) {
39689 return;
39690 }
39691 for (const feed of feeds) {
39692 this.add(feed.key, feed.value);
39693 }
39694 }
39695 }
39696 /**
39697 * Add a key-value pair to the FeedDict.
39698 *
39699 * @param key The key of the feed.
39700 * @param value The value of the tensor feed.
39701 * @param mask The value of the mask feed (optional).
39702 * @returns This `FeedDict`.
39703 * @throws ValueError: If the key `SymbolicTensor` already exists in the
39704 * `FeedDict`.
39705 */
39706 add(key, value, mask) {
39707 if (this.id2Value[key.id] == null) {
39708 this.id2Value[key.id] = assertFeedCompatibility(key, value);
39709 this.name2Id[key.name] = key.id;
39710 if (mask != null) {
39711 this.id2Mask[key.id] = mask;
39712 }
39713 }
39714 else {
39715 throw new ValueError(`Duplicate key: name=${key.name}, id=${key.id}`);
39716 }
39717 return this;
39718 }
39719 /**
39720 * Add a Feed to the FeedDict.
39721 * @param feed The new `Feed` to add.
39722 * @returns This `FeedDict`.
39723 */
39724 addFeed(feed) {
39725 this.add(feed.key, feed.value);
39726 }
39727 /**
39728 * Probe whether a key already exists in the FeedDict.
39729 * @param key
39730 */
39731 hasKey(key) {
39732 return this.id2Value[key.id] != null;
39733 }
39734 /**
39735 * Get all the SymbolicTensor available in this FeedDict.
39736 */
39737 names() {
39738 return Object.keys(this.name2Id);
39739 }
39740 /**
39741 * Get the feed value for given key.
39742 * @param key The SymbolicTensor, or its name (as a string), of which the
39743 * value is sought.
39744 * @returns If `key` exists, the corresponding feed value.
39745 * @throws ValueError: If `key` does not exist in this `FeedDict`.
39746 */
39747 getValue(key) {
39748 if (key instanceof SymbolicTensor) {
39749 if (this.id2Value[key.id] == null) {
39750 throw new ValueError(`Nonexistent key: ${key.name}`);
39751 }
39752 else {
39753 return this.id2Value[key.id];
39754 }
39755 }
39756 else {
39757 const id = this.name2Id[key];
39758 if (id == null) {
39759 throw new ValueError(`Feed dict has no SymbolicTensor name: ${key}`);
39760 }
39761 return this.id2Value[id];
39762 }
39763 }
39764 /**
39765 * Get the feed mask for given key.
39766 * @param key The SymbolicTensor, or its name (as a string), of which the
39767 * value is sought.
39768 * @returns If `key` exists, the corresponding feed mask.
39769 * @throws ValueError: If `key` does not exist in this `FeedDict`.
39770 */
39771 getMask(key) {
39772 if (key instanceof SymbolicTensor) {
39773 if (this.id2Value[key.id] == null) {
39774 throw new ValueError(`Nonexistent key: ${key.name}`);
39775 }
39776 else {
39777 return this.id2Mask[key.id];
39778 }
39779 }
39780 else {
39781 const id = this.name2Id[key];
39782 if (id == null) {
39783 throw new ValueError(`Feed dict has no SymbolicTensor name: ${key}`);
39784 }
39785 return this.id2Mask[id];
39786 }
39787 }
39788 /** Dispose all mask Tensors held by this object. */
disposeMasks() {
    // `dispose` walks the id -> mask map and frees every mask tensor.
    if (this.id2Mask != null) {
        dispose(this.id2Mask);
    }
}
39794 }
// Cache for topologically sorted SymbolicTensors for given execution
// targets (i.e., fetches). Keyed by a string combining fetch names and
// fed-tensor names (see `execute`).
const cachedSorted = new LruCache();
// Cache for recipient count maps for given execution targets (i.e., fetches).
// Kept in sync with `cachedSorted`: both are written under the same key.
const cachedRecipientCounts = new LruCache();
function updateCacheMaxEntries(maxEntries) {
    // Resize both topological-sort caches, tolerating either cache being
    // absent.
    for (const cache of [cachedSorted, cachedRecipientCounts]) {
        if (cache != null) {
            cache.setMaxEntries(maxEntries);
        }
    }
}
39808 /**
39809 * Execute a SymbolicTensor by using concrete feed values.
39810 *
39811 * A `SymbolicTensor` object is a node in a computation graph of TF.js
39812 * Layers. The object is backed by a source layer and input
39813 * `SymbolicTensor`s to the source layer. This method evaluates
39814 * the `call()` method of the source layer, using concrete values of the
39815 * inputs obtained from either
39816 * * `feedDict`, if the input key exists in `feedDict`, or else,
39817 * * a recursive call to `execute()` itself.
39818 *
 * @param fetches: The `SymbolicTensor`(s) to execute.
 * @param feedDict: The feed values, used as the base condition of the
 *   recursive execution.
39822 * @param kwargs: Optional keyword arguments.
39823 * @param probe: A probe object (of interface `ExecutionProbe`) used for
39824 * testing memory footprint of `execute` calls.
39825 * @returns Result of the execution.
39826 * @throws ValueError: If any `SymbolicTensor`s from `InputLayer`s
39827 * encountered during the execution lacks a feed value in `feedDict`.
39828 */
function execute(fetches, feedDict, kwargs, probe) {
    const training = kwargs == null ? false : kwargs['training'];
    const arrayFetches = Array.isArray(fetches);
    const fetchArray = arrayFetches ? fetches : [fetches];
    const outputNames = fetchArray.map(t => t.name);
    const finalOutputs = [];
    // Fetches that are directly fed can be resolved without running any layer.
    const feedNames = feedDict.names();
    for (const outputName of outputNames) {
        if (feedNames.indexOf(outputName) !== -1) {
            finalOutputs.push(feedDict.getValue(outputName));
        }
        else {
            finalOutputs.push(null);
        }
    }
    if (probe != null) {
        // For optional probing of memory footprint during execution.
        probe.maxNumTensors = -Infinity;
        probe.minNumTensors = Infinity;
    }
    // Check cache. The key encodes both the fetches and the set of fed
    // tensors, since either changes the required execution order.
    const fetchAndFeedKey = outputNames.join(',') + '|' + feedDict.names().sort().join(',');
    let sorted = cachedSorted.get(fetchAndFeedKey);
    let recipientCounts;
    if (sorted == null) {
        // Cache doesn't contain the desired combination of fetches. Compute
        // topological sort for the combination for the first time.
        const out = getTopologicalSortAndRecipientCounts(fetchArray, feedDict);
        sorted = out.sorted;
        recipientCounts = out.recipientCounts;
        // Store results in cache for future use.
        cachedSorted.put(fetchAndFeedKey, sorted);
        cachedRecipientCounts.put(fetchAndFeedKey, recipientCounts);
    }
    // Use a fresh, mutable copy of the cached recipient counts: the counts
    // are decremented below to decide when an intermediate tensor can be
    // disposed. During training the disposal logic is skipped entirely, so
    // the copy stays empty.
    recipientCounts = {};
    if (!training) {
        Object.assign(recipientCounts, cachedRecipientCounts.get(fetchAndFeedKey));
    }
    const internalFeedDict = new FeedDict(feedDict);
    // Start iterative execution on the topologically-sorted SymbolicTensors.
    for (let i = 0; i < sorted.length; ++i) {
        if (probe != null) {
            // For optional probing of memory usage during execution.
            const numTensors = memory().numTensors;
            if (numTensors > probe.maxNumTensors) {
                probe.maxNumTensors = numTensors;
            }
            if (numTensors < probe.minNumTensors) {
                probe.minNumTensors = numTensors;
            }
        }
        const symbolic = sorted[i];
        const srcLayer = symbolic.sourceLayer;
        // Input layers carry no computation; their values come from feeds.
        if (srcLayer instanceof InputLayer) {
            continue;
        }
        const inputValues = [];
        const inputMasks = [];
        const tensorsToDispose = [];
        let maskExists = false;
        for (const input of symbolic.inputs) {
            const value = internalFeedDict.getValue(input);
            const mask = internalFeedDict.getMask(input);
            inputValues.push(value);
            inputMasks.push(mask);
            if (mask != null) {
                maskExists = true;
            }
            if (!training) {
                // Decrement this input's recipient count. Once no downstream
                // consumer remains, and the tensor is neither a feed, a
                // requested output, already disposed, nor owned by a stateful
                // layer, it can be disposed after this layer runs.
                recipientCounts[input.name]--;
                if (recipientCounts[input.name] === 0 && !feedDict.hasKey(input) &&
                    outputNames.indexOf(input.name) === -1 && !value.isDisposed &&
                    input.sourceLayer.stateful !== true) {
                    tensorsToDispose.push(value);
                }
            }
        }
        if (maskExists) {
            // NOTE: kwargs is mutated in place here; only the first input's
            // mask is forwarded to the layer's call().
            kwargs = kwargs || {};
            kwargs['mask'] = inputMasks[0];
        }
        const outputTensors = toList(srcLayer.apply(inputValues, kwargs));
        let outputMask = null;
        if (srcLayer.supportsMasking) {
            outputMask = srcLayer.computeMask(inputValues, inputMasks);
        }
        const layerOutputs = getNodeOutputs(symbolic);
        const outputSymbolicTensors = Array.isArray(layerOutputs) ? layerOutputs : [layerOutputs];
        for (let i = 0; i < outputSymbolicTensors.length; ++i) {
            if (!internalFeedDict.hasKey(outputSymbolicTensors[i])) {
                internalFeedDict.add(outputSymbolicTensors[i], outputTensors[i], Array.isArray(outputMask) ? outputMask[0] : outputMask);
            }
            // If this output is one of the requested fetches, record it.
            const index = outputNames.indexOf(outputSymbolicTensors[i].name);
            if (index !== -1) {
                finalOutputs[index] = outputTensors[i];
            }
        }
        if (!training) {
            // Clean up Tensors that are no longer needed.
            dispose(tensorsToDispose);
        }
    }
    // NOTE(cais): Unlike intermediate tensors, we don't discard mask
    // tensors as we go, because these tensors are sometimes passed over a
    // series of multiple layers, i.e., not obeying the immediate input
    // relations in the graph. If this becomes a memory-usage concern,
    // we can improve this in the future.
    internalFeedDict.disposeMasks();
    return arrayFetches ? finalOutputs : finalOutputs[0];
}
39939 /**
39940 * Sort the `SymbolicTensor`s topologically, for an array of fetches.
39941 *
39942 * This function calls getTopologicalSortAndRecipientCountsForOneFetch and
39943 * merges their results.
39944 *
 * @param fetches The array of fetches requested. Must be a non-empty array.
39946 * @param feedDict The dictionary of fed values.
39947 * @returns sorted: Topologically-sorted array of SymbolicTensors.
39948 * recipientCounts: Recipient counts for all SymbolicTensors in `sorted`.
39949 */
function getTopologicalSortAndRecipientCounts(fetches, feedDict) {
    assert$1(fetches != null && fetches.length > 0, () => `Expected at least one fetch, got none`);
    let finalSorted = [];
    let finalRecipientMap = {};
    if (fetches.length === 1) {
        // Fast path: a single fetch needs no merging.
        const single = getTopologicalSortAndRecipientCountsForOneFetch(fetches[0], feedDict);
        finalSorted = single.sorted;
        finalRecipientMap = single.recipientMap;
    }
    else {
        const seenNames = new Set();
        for (const fetch of fetches) {
            const { sorted, recipientMap } = getTopologicalSortAndRecipientCountsForOneFetch(fetch, feedDict);
            // Concatenate sorted SymbolicTensors, skipping any already
            // emitted by an earlier fetch.
            for (const symbolicTensor of sorted) {
                if (seenNames.has(symbolicTensor.name)) {
                    continue;
                }
                finalSorted.push(symbolicTensor);
                seenNames.add(symbolicTensor.name);
            }
            // Union the per-tensor recipient sets.
            for (const name in recipientMap) {
                if (finalRecipientMap[name] == null) {
                    finalRecipientMap[name] = new Set();
                }
                recipientMap[name].forEach(recipient => finalRecipientMap[name].add(recipient));
            }
        }
    }
    return {
        sorted: finalSorted,
        recipientCounts: recipientMap2Counts(finalRecipientMap)
    };
}
function recipientMap2Counts(recipientMap) {
    // Collapse each name -> Set-of-recipients entry into name -> count.
    const recipientCounts = {};
    Object.keys(recipientMap).forEach(name => {
        recipientCounts[name] = recipientMap[name].size;
    });
    return recipientCounts;
}
39992 /**
39993 * Sort the `SymbolicTensor`s topologically, for a single fetch.
39994 *
39995 * This helper function processes the upstream SymbolicTensors of a single
39996 * fetch.
39997 *
39998 * @param fetch The single fetch requested.
39999 * @param feedDict The dictionary of fed values.
40000 * @returns sorted: Topologically-sorted array of SymbolicTensors.
40001 * recipientMap: Recipient names for all SymbolicTensors in `sorted`.
40002 */
function getTopologicalSortAndRecipientCountsForOneFetch(fetch, feedDict) {
    const visited = new Set();
    const sorted = [];
    const recipientMap = {};
    // Put keys of the feedDict into visited first, so they don't have to be
    // walked. This is needed in case where there are feeds for intermediate
    // SymbolicTensors of the graph.
    for (const key of feedDict.names()) {
        visited.add(key);
    }
    // Iterative post-order DFS. `marks` records stack positions of nodes
    // whose inputs have already been pushed, so a node is emitted into
    // `sorted` only after all of its inputs.
    const stack = [];
    const marks = [];
    // Initial population of stack and marks.
    stack.push(fetch);
    while (stack.length > 0) {
        const top = stack[stack.length - 1];
        if (visited.has(top.name)) {
            stack.pop();
            continue;
        }
        // `top` is "marked" iff its inputs were pushed in a previous
        // iteration (its stack position matches the most recent mark).
        const topIsMarked = marks[marks.length - 1] === stack.length - 1;
        if (top.inputs.length === 0 || topIsMarked) {
            // Input SymbolicTensor or all children have been visited.
            stack.pop();
            sorted.push(top);
            visited.add(top.name);
            if (topIsMarked) {
                marks.pop();
            }
        }
        else {
            // A non-input SymbolicTensor whose upstream SymbolicTensors haven't
            // been visited yet. Push them onto the stack.
            marks.push(stack.length - 1);
            for (const input of top.inputs) {
                // Increment the recipient count. Note that this needs to happen
                // regardless of whether the SymbolicTensor has been visited before.
                if (recipientMap[input.name] == null) {
                    recipientMap[input.name] = new Set();
                }
                recipientMap[input.name].add(top.name);
                if (visited.has(input.name)) {
                    continue; // Avoid repeated visits to the same SymbolicTensor.
                }
                stack.push(input);
            }
        }
    }
    return { sorted, recipientMap };
}
40053 /**
40054 * Get the symbolic output tensors of the node to which a given fetch belongs.
40055 * @param fetch The fetched symbolic tensor.
40056 * @returns The Array of symbolic tensors output by the node to which `fetch`
40057 * belongs.
40058 */
function getNodeOutputs(fetch) {
    let layerOutputs;
    if (fetch.sourceLayer.inboundNodes.length === 1) {
        // Single inbound node: its outputs are simply the layer's outputs.
        layerOutputs = fetch.sourceLayer.output;
    }
    else {
        // Multiple inbound nodes: find which node produced `fetch` by
        // matching tensor ids.
        let nodeIndex = null;
        for (let i = 0; i < fetch.sourceLayer.inboundNodes.length; ++i) {
            for (const outputTensor of fetch.sourceLayer.inboundNodes[i]
                .outputTensors) {
                if (outputTensor.id === fetch.id) {
                    nodeIndex = i;
                    break;
                }
            }
            // Fix: the inner `break` only exited the inner loop, so the scan
            // previously continued over the remaining inbound nodes after a
            // match. Tensor ids are unique, so stopping early is safe.
            if (nodeIndex !== null) {
                break;
            }
        }
        layerOutputs = fetch.sourceLayer.getOutputAt(nodeIndex);
    }
    return layerOutputs;
}
40079
40080 /**
40081 * @license
40082 * Copyright 2022 Google LLC. All Rights Reserved.
40083 * Licensed under the Apache License, Version 2.0 (the "License");
40084 * you may not use this file except in compliance with the License.
40085 * You may obtain a copy of the License at
40086 *
40087 * http://www.apache.org/licenses/LICENSE-2.0
40088 *
40089 * Unless required by applicable law or agreed to in writing, software
40090 * distributed under the License is distributed on an "AS IS" BASIS,
40091 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
40092 * See the License for the specific language governing permissions and
40093 * limitations under the License.
40094 * =============================================================================
40095 */
const ENV$2 = env();
/** The max number of entries for the caches of layers' topological sort. */
// The third argument appears to be a change-hook: `updateCacheMaxEntries`
// resizes both LRU caches when the flag value changes — confirm against
// `registerFlag`'s contract.
ENV$2.registerFlag('TOPOLOGICAL_SORT_CACHE_MAX_ENTRIES', () => 100, updateCacheMaxEntries);
40099
40100 /**
40101 * @license
40102 * Copyright 2018 Google LLC
40103 *
40104 * Use of this source code is governed by an MIT-style
40105 * license that can be found in the LICENSE file or at
40106 * https://opensource.org/licenses/MIT.
40107 * =============================================================================
40108 */
40109 /**
40110 * Helper function used by many of the Constraints to find the L2Norms.
40111 */
function calcL2Norms(w, axis) {
    // Sum of squares along `axis` (dims kept), then element-wise sqrt.
    return tidy(() => {
        const sumOfSquares = sum$3(mul(w, w), axis, true);
        return sqrt$2(sumOfSquares);
    });
}
40115 /**
40116 * Base class for functions that impose constraints on weight values
40117 *
40118 * @doc {
40119 * heading: 'Constraints',
40120 * subheading: 'Classes',
40121 * namespace: 'constraints'
40122 * }
40123 */
class Constraint extends Serializable {
    // Base implementation: constraints are parameterless by default;
    // subclasses override getConfig() to expose their hyperparameters.
    getConfig() {
        return {};
    }
}
class MaxNorm extends Constraint {
    constructor(args) {
        super();
        this.defaultMaxValue = 2;
        this.defaultAxis = 0;
        this.maxValue = args.maxValue != null ? args.maxValue : this.defaultMaxValue;
        this.axis = args.axis != null ? args.axis : this.defaultAxis;
    }
    /** Rescale slices along `axis` whose L2 norm exceeds `maxValue`. */
    apply(w) {
        return tidy(() => {
            const norms = calcL2Norms(w, this.axis);
            const clipped = clipByValue$2(norms, 0, this.maxValue);
            const scale = div$1(clipped, add$3(epsilon$1(), norms));
            return mul(w, scale);
        });
    }
    getConfig() {
        return { maxValue: this.maxValue, axis: this.axis };
    }
}
/** @nocollapse */
MaxNorm.className = 'MaxNorm';
registerClass(MaxNorm);
class UnitNorm extends Constraint {
    constructor(args) {
        super();
        this.defaultAxis = 0;
        this.axis = args.axis != null ? args.axis : this.defaultAxis;
    }
    /** Divide `w` by its (epsilon-stabilized) L2 norms along `axis`. */
    apply(w) {
        return tidy(() => {
            const stabilizedNorms = add$3(epsilon$1(), calcL2Norms(w, this.axis));
            return div$1(w, stabilizedNorms);
        });
    }
    getConfig() {
        return { axis: this.axis };
    }
}
/** @nocollapse */
UnitNorm.className = 'UnitNorm';
registerClass(UnitNorm);
class NonNeg extends Constraint {
    // Project weights onto the non-negative orthant by zeroing out
    // negative entries (ReLU).
    apply(w) {
        return relu$2(w);
    }
}
/** @nocollapse */
NonNeg.className = 'NonNeg';
registerClass(NonNeg);
class MinMaxNorm extends Constraint {
    constructor(args) {
        super();
        this.defaultMinValue = 0.0;
        this.defaultMaxValue = 1.0;
        this.defaultRate = 1.0;
        this.defaultAxis = 0;
        this.minValue = args.minValue != null ? args.minValue : this.defaultMinValue;
        this.maxValue = args.maxValue != null ? args.maxValue : this.defaultMaxValue;
        this.rate = args.rate != null ? args.rate : this.defaultRate;
        this.axis = args.axis != null ? args.axis : this.defaultAxis;
    }
    /**
     * Interpolate (by `rate`) between the raw L2 norms and their values
     * clipped into [minValue, maxValue], then rescale `w` accordingly.
     */
    apply(w) {
        return tidy(() => {
            const norms = calcL2Norms(w, this.axis);
            const clipped = clipByValue$2(norms, this.minValue, this.maxValue);
            const desired = add$3(mul(this.rate, clipped), mul(1.0 - this.rate, norms));
            return mul(w, div$1(desired, add$3(epsilon$1(), norms)));
        });
    }
    getConfig() {
        return {
            minValue: this.minValue,
            maxValue: this.maxValue,
            rate: this.rate,
            axis: this.axis
        };
    }
}
/** @nocollapse */
MinMaxNorm.className = 'MinMaxNorm';
registerClass(MinMaxNorm);
// Maps the JavaScript-like identifier keys to the corresponding registry
// symbols. Used by `getConstraint` to resolve camelCase identifiers to
// registered class names.
const CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
    'maxNorm': 'MaxNorm',
    'minMaxNorm': 'MinMaxNorm',
    'nonNeg': 'NonNeg',
    'unitNorm': 'UnitNorm'
};
function serializeConstraint(constraint) {
    // Delegates to the generic Keras-object serializer.
    return serializeKerasObject(constraint);
}
function deserializeConstraint(config, customObjects = {}) {
    // Delegates to the generic deserializer, resolving class names via the
    // global serialization registry.
    return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'constraint');
}
function getConstraint(identifier) {
    // Accepts null/undefined, a Constraint instance, a string identifier,
    // or a serialized config object.
    if (identifier == null) {
        return null;
    }
    if (identifier instanceof Constraint) {
        return identifier;
    }
    if (typeof identifier === 'string') {
        // Translate camelCase identifiers (e.g. 'maxNorm') to registered
        // class names; unknown strings are passed through unchanged.
        const className = identifier in CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP
            ? CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier]
            : identifier;
        return deserializeConstraint({ className, config: {} });
    }
    return deserializeConstraint(identifier);
}
40241
40242 /**
40243 * @license
40244 * Copyright 2018 Google LLC
40245 *
40246 * Use of this source code is governed by an MIT-style
40247 * license that can be found in the LICENSE file or at
40248 * https://opensource.org/licenses/MIT.
40249 * =============================================================================
40250 */
40251 /**
40252 * MaxNorm weight constraint.
40253 *
40254 * Constrains the weights incident to each hidden unit
40255 * to have a norm less than or equal to a desired value.
40256 *
40257 * References
40258 * - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting
40259 * Srivastava, Hinton, et al.
40260 * 2014](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf)
40261 *
40262 * @doc {heading: 'Constraints',namespace: 'constraints'}
40263 */
function maxNorm(args) {
    // Factory wrapper around the `MaxNorm` constraint class.
    return new MaxNorm(args);
}
40267 /**
40268 * Constrains the weights incident to each hidden unit to have unit norm.
40269 *
40270 * @doc {heading: 'Constraints', namespace: 'constraints'}
40271 */
function unitNorm(args) {
    // Factory wrapper around the `UnitNorm` constraint class.
    return new UnitNorm(args);
}
40275 /**
40276 * Constrains the weight to be non-negative.
40277 *
40278 * @doc {heading: 'Constraints', namespace: 'constraints'}
40279 */
function nonNeg() {
    // Factory wrapper around the `NonNeg` constraint class.
    return new NonNeg();
}
40283 /** @doc {heading: 'Constraints', namespace: 'constraints'} */
function minMaxNorm(config) {
    // Factory wrapper around the `MinMaxNorm` constraint class.
    return new MinMaxNorm(config);
}
40287
// Public `constraints` namespace: a frozen map of the factory functions.
var exports_constraints = /*#__PURE__*/Object.freeze({
    __proto__: null,
    maxNorm: maxNorm,
    minMaxNorm: minMaxNorm,
    nonNeg: nonNeg,
    unitNorm: unitNorm
});
40295
40296 /**
40297 * @license
40298 * Copyright 2018 Google LLC
40299 *
40300 * Use of this source code is governed by an MIT-style
40301 * license that can be found in the LICENSE file or at
40302 * https://opensource.org/licenses/MIT.
40303 * =============================================================================
40304 */
40305 /**
40306 * Initializer that generates tensors initialized to 0.
40307 *
40308 * @doc {heading: 'Initializers', namespace: 'initializers'}
40309 */
function zeros$1() {
    // Factory wrapper around the `Zeros` initializer class.
    return new Zeros();
}
40313 /**
40314 * Initializer that generates tensors initialized to 1.
40315 *
40316 * @doc {heading: 'Initializers', namespace: 'initializers'}
40317 */
function ones() {
    // Factory wrapper around the `Ones` initializer class.
    return new Ones();
}
40321 /**
40322 * Initializer that generates values initialized to some constant.
40323 *
40324 * @doc {heading: 'Initializers', namespace: 'initializers'}
40325 */
function constant(args) {
    // Factory wrapper around the `Constant` initializer class.
    return new Constant(args);
}
40329 /**
40330 * Initializer that generates random values initialized to a uniform
40331 * distribution.
40332 *
40333 * Values will be distributed uniformly between the configured minval and
40334 * maxval.
40335 *
40336 * @doc {heading: 'Initializers', namespace: 'initializers'}
40337 */
function randomUniform(args) {
    // Factory wrapper around the `RandomUniform` initializer class.
    return new RandomUniform(args);
}
40341 /**
40342 * Initializer that generates random values initialized to a normal
40343 * distribution.
40344 *
40345 * @doc {heading: 'Initializers', namespace: 'initializers'}
40346 */
function randomNormal(args) {
    // Factory wrapper around the `RandomNormal` initializer class.
    return new RandomNormal(args);
}
40350 /**
40351 * Initializer that generates random values initialized to a truncated normal
40352 * distribution.
40353 *
40354 * These values are similar to values from a `RandomNormal` except that values
40355 * more than two standard deviations from the mean are discarded and re-drawn.
40356 * This is the recommended initializer for neural network weights and filters.
40357 *
40358 * @doc {heading: 'Initializers', namespace: 'initializers'}
40359 */
function truncatedNormal(args) {
    // Factory wrapper around the `TruncatedNormal` initializer class.
    return new TruncatedNormal(args);
}
40363 /**
40364 * Initializer that generates the identity matrix.
40365 * Only use for square 2D matrices.
40366 *
40367 * @doc {heading: 'Initializers', namespace: 'initializers'}
40368 */
function identity$2(args) {
    // Factory wrapper around the `Identity` initializer class.
    return new Identity(args);
}
40372 /**
40373 * Initializer capable of adapting its scale to the shape of weights.
40374 * With distribution=NORMAL, samples are drawn from a truncated normal
40375 * distribution centered on zero, with `stddev = sqrt(scale / n)` where n is:
40376 * - number of input units in the weight tensor, if mode = FAN_IN.
40377 * - number of output units, if mode = FAN_OUT.
40378 * - average of the numbers of input and output units, if mode = FAN_AVG.
40379 * With distribution=UNIFORM,
40380 * samples are drawn from a uniform distribution
40381 * within [-limit, limit], with `limit = sqrt(3 * scale / n)`.
40382 *
40383 * @doc {heading: 'Initializers',namespace: 'initializers'}
40384 */
function varianceScaling(config) {
    // Factory wrapper around the `VarianceScaling` initializer class.
    return new VarianceScaling(config);
}
40388 /**
40389 * Glorot uniform initializer, also called Xavier uniform initializer.
40390 * It draws samples from a uniform distribution within [-limit, limit]
40391 * where `limit` is `sqrt(6 / (fan_in + fan_out))`
40392 * where `fan_in` is the number of input units in the weight tensor
40393 * and `fan_out` is the number of output units in the weight tensor
40394 *
40395 * Reference:
40396 * Glorot & Bengio, AISTATS 2010
40397 * http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf.
40398 *
40399 * @doc {heading: 'Initializers', namespace: 'initializers'}
40400 */
function glorotUniform(args) {
    // Factory wrapper around the `GlorotUniform` initializer class.
    return new GlorotUniform(args);
}
40404 /**
40405 * Glorot normal initializer, also called Xavier normal initializer.
40406 * It draws samples from a truncated normal distribution centered on 0
40407 * with `stddev = sqrt(2 / (fan_in + fan_out))`
40408 * where `fan_in` is the number of input units in the weight tensor
40409 * and `fan_out` is the number of output units in the weight tensor.
40410 *
40411 * Reference:
40412 * Glorot & Bengio, AISTATS 2010
40413 * http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
40414 *
40415 * @doc {heading: 'Initializers', namespace: 'initializers'}
40416 */
function glorotNormal(args) {
    // Factory wrapper around the `GlorotNormal` initializer class.
    return new GlorotNormal(args);
}
40420 /**
40421 * He normal initializer.
40422 *
40423 * It draws samples from a truncated normal distribution centered on 0
40424 * with `stddev = sqrt(2 / fanIn)`
40425 * where `fanIn` is the number of input units in the weight tensor.
40426 *
40427 * Reference:
40428 * He et al., http://arxiv.org/abs/1502.01852
40429 *
40430 * @doc {heading: 'Initializers', namespace: 'initializers'}
40431 */
function heNormal(args) {
    // Factory wrapper around the `HeNormal` initializer class.
    return new HeNormal(args);
}
40435 /**
40436 * He uniform initializer.
40437 *
40438 * It draws samples from a uniform distribution within [-limit, limit]
40439 * where `limit` is `sqrt(6 / fan_in)`
40440 * where `fanIn` is the number of input units in the weight tensor.
40441 *
40442 * Reference:
40443 * He et al., http://arxiv.org/abs/1502.01852
40444 *
40445 * @doc {heading: 'Initializers',namespace: 'initializers'}
40446 */
function heUniform(args) {
    // Factory wrapper around the `HeUniform` initializer class.
    return new HeUniform(args);
}
40450 /**
40451 * LeCun normal initializer.
40452 *
40453 * It draws samples from a truncated normal distribution centered on 0
40454 * with `stddev = sqrt(1 / fanIn)`
40455 * where `fanIn` is the number of input units in the weight tensor.
40456 *
40457 * References:
40458 * [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
40459 * [Efficient Backprop](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
40460 *
40461 * @doc {heading: 'Initializers', namespace: 'initializers'}
40462 */
function leCunNormal(args) {
    // Factory wrapper around the `LeCunNormal` initializer class.
    return new LeCunNormal(args);
}
40466 /**
40467 * LeCun uniform initializer.
40468 *
40469 * It draws samples from a uniform distribution in the interval
40470 * `[-limit, limit]` with `limit = sqrt(3 / fanIn)`,
40471 * where `fanIn` is the number of input units in the weight tensor.
40472 *
40473 * @doc {heading: 'Initializers', namespace: 'initializers'}
40474 */
function leCunUniform(args) {
    // Factory wrapper around the `LeCunUniform` initializer class.
    return new LeCunUniform(args);
}
40478 /**
40479 * Initializer that generates a random orthogonal matrix.
40480 *
40481 * Reference:
40482 * [Saxe et al., http://arxiv.org/abs/1312.6120](http://arxiv.org/abs/1312.6120)
40483 *
40484 * @doc {heading: 'Initializers', namespace: 'initializers'}
40485 */
function orthogonal(args) {
    // Factory wrapper around the `Orthogonal` initializer class.
    return new Orthogonal(args);
}
40489
// Public `initializers` namespace: a frozen map of the factory functions.
var exports_initializers = /*#__PURE__*/Object.freeze({
    __proto__: null,
    constant: constant,
    glorotNormal: glorotNormal,
    glorotUniform: glorotUniform,
    heNormal: heNormal,
    heUniform: heUniform,
    identity: identity$2,
    leCunNormal: leCunNormal,
    leCunUniform: leCunUniform,
    ones: ones,
    orthogonal: orthogonal,
    randomNormal: randomNormal,
    randomUniform: randomUniform,
    truncatedNormal: truncatedNormal,
    varianceScaling: varianceScaling,
    zeros: zeros$1
});
40508
40509 /**
40510 * @license
40511 * Copyright 2018 Google LLC
40512 *
40513 * Use of this source code is governed by an MIT-style
40514 * license that can be found in the LICENSE file or at
40515 * https://opensource.org/licenses/MIT.
40516 * =============================================================================
40517 */
40518 /**
40519 * Turn any Scalar values in a Logs object into actual number values.
40520 *
40521 * @param logs The `Logs` object to be resolved in place.
40522 */
async function resolveScalarsInLogs(logs) {
    if (logs == null) {
        return;
    }
    // Collect the async downloads of every non-number (Scalar) entry.
    const promises = [];
    const scalarKeys = [];
    const scalarsToDispose = [];
    for (const key in logs) {
        const value = logs[key];
        if (typeof value === 'number') {
            continue;
        }
        const valueScalar = value;
        promises.push(valueScalar.data());
        scalarKeys.push(key);
        scalarsToDispose.push(valueScalar);
    }
    if (promises.length === 0) {
        return;
    }
    // Await all downloads at once, then write the numeric results back
    // into `logs` in place.
    const resolved = await Promise.all(promises);
    resolved.forEach((data, i) => {
        logs[scalarKeys[i]] = data[0];
    });
    // Dispose the original scalar tensors.
    dispose(scalarsToDispose);
}
40548 /**
40549 * Dispose all Tensors in an UnresolvedLogs object.
40550 *
40551 * @param logs An `UnresolvedLogs` object potentially containing `tf.Tensor`s in
40552 * places where the values can be `tf.Tensor` or `number`.
40553 */
function disposeTensorsInLogs(logs) {
    if (logs == null) {
        return;
    }
    for (const key in logs) {
        const value = logs[key];
        // Numbers are plain JS values; anything else is treated as a
        // Tensor and disposed explicitly.
        if (typeof value === 'number') {
            continue;
        }
        value.dispose();
    }
}
40565
40566 /**
40567 * @license
40568 * Copyright 2018 Google LLC
40569 *
40570 * Use of this source code is governed by an MIT-style
40571 * license that can be found in the LICENSE file or at
40572 * https://opensource.org/licenses/MIT.
40573 * =============================================================================
40574 */
/** Verbosity logging level when fitting a model. */
var ModelLoggingVerbosity;
(function (ModelLoggingVerbosity) {
    // Transpiled TypeScript enum: builds a two-way mapping
    // (name -> value and value -> name).
    ModelLoggingVerbosity[ModelLoggingVerbosity["SILENT"] = 0] = "SILENT";
    ModelLoggingVerbosity[ModelLoggingVerbosity["VERBOSE"] = 1] = "VERBOSE";
})(ModelLoggingVerbosity || (ModelLoggingVerbosity = {}));
/** How often to yield to the main thread when training (in ms). */
const DEFAULT_YIELD_EVERY_MS = 125;
40583 /**
40584 * Abstract base class used to build new callbacks.
40585 *
40586 * The `logs` dictionary that callback methods take as argument will contain
40587 * keys for quantities relevant to the current batch or epoch.
40588 *
40589 * Currently, the `.fit()` method of the `Sequential` model class
40590 * will include the following quantities in the `logs` that
40591 * it passes to its callbacks:
40592 *
40593 * onEpochEnd: Logs include `acc` and `loss`, and optionally include `valLoss`
40594 * (if validation is enabled in `fit`), and `valAcc` (if validation and
40595 * accuracy monitoring are enabled).
40596 * onBatchBegin: Logs include `size`, the number of samples in the current
40597 * batch.
40598 * onBatchEnd: Logs include `loss`, and optionally `acc` (if accuracy monitoring
40599 * is enabled).
40600 */
class BaseCallback {
    constructor() {
        // TODO(michaelterry): This type is a best guess.
        this.validationData = null;
    }
    /** Store the training parameters for later use by the callback. */
    setParams(params) {
        this.params = params;
    }
    // Lifecycle hooks. All are no-ops by default; subclasses override the
    // ones they care about. Each receives an optional `logs` dictionary.
    async onEpochBegin(epoch, logs) { }
    async onEpochEnd(epoch, logs) { }
    async onBatchBegin(batch, logs) { }
    async onBatchEnd(batch, logs) { }
    async onTrainBegin(logs) { }
    async onTrainEnd(logs) { }
    // LayersModel needs to call Callback.setModel(), but cannot actually depend
    // on Callback because that creates a cyclic dependency. Providing this no-op
    // method on BaseCallback breaks the cycle: this way LayersModel can depend on
    // BaseCallback but not on Callback. The argument is typed as `Container`
    // (the superclass of LayersModel) to avoid recapitulating the cycle. Callback
    // overrides this method and enforces that the argument is really a
    // LayersModel.
    setModel(model) {
        // Do nothing. Use Callback instead of BaseCallback to track the model.
    }
}
40626 /**
40627 * Container abstracting a list of callbacks.
40628 */
class CallbackList {
    // TODO(cais): When the need arises, uncomment the following lines and
    // implement the queue for time values.
    // private deltaTBatch: number;
    // private deltaTsBatchBegin: Array<number>;
    // private deltaTsBatchEnd: Array<number>;
    /**
     * Constructor of CallbackList.
     * @param callbacks Array of `Callback` instances.
     * @param queueLength Queue length for keeping running statistics over
     *   callback execution time.
     */
    constructor(callbacks, queueLength = 10) {
        // TODO(cais): Make use of queueLength when implementing the queue
        // for time values.
        this.callbacks = callbacks == null ? [] : callbacks;
        this.queueLength = queueLength;
    }
    /** Append one callback to the end of the list. */
    append(callback) {
        this.callbacks.push(callback);
    }
    /** Forward training params to every registered callback. */
    setParams(params) {
        this.callbacks.forEach(callback => callback.setParams(params));
    }
    /** Forward the model reference to every registered callback. */
    setModel(model) {
        this.callbacks.forEach(callback => callback.setModel(model));
    }
    /**
     * Called at the start of an epoch.
     * @param epoch Index of epoch.
     * @param logs Dictionary of logs.
     */
    async onEpochBegin(epoch, logs) {
        logs = logs == null ? {} : logs;
        for (const callback of this.callbacks) {
            await callback.onEpochBegin(epoch, logs);
        }
    }
    /**
     * Called at the end of an epoch.
     * @param epoch Index of epoch.
     * @param logs Dictionary of logs.
     */
    async onEpochEnd(epoch, logs) {
        logs = logs == null ? {} : logs;
        for (const callback of this.callbacks) {
            await callback.onEpochEnd(epoch, logs);
        }
    }
    /**
     * Called right before processing a batch.
     * @param batch Index of batch within the current epoch.
     * @param logs Dictionary of logs.
     */
    async onBatchBegin(batch, logs) {
        logs = logs == null ? {} : logs;
        for (const callback of this.callbacks) {
            await callback.onBatchBegin(batch, logs);
        }
    }
    /**
     * Called at the end of a batch.
     * @param batch Index of batch within the current epoch.
     * @param logs Dictionary of logs.
     */
    async onBatchEnd(batch, logs) {
        logs = logs == null ? {} : logs;
        for (const callback of this.callbacks) {
            await callback.onBatchEnd(batch, logs);
        }
    }
    /**
     * Called at the beginning of training.
     * @param logs Dictionary of logs.
     */
    async onTrainBegin(logs) {
        logs = logs == null ? {} : logs;
        for (const callback of this.callbacks) {
            await callback.onTrainBegin(logs);
        }
    }
    /**
     * Called at the end of training.
     * @param logs Dictionary of logs.
     */
    async onTrainEnd(logs) {
        logs = logs == null ? {} : logs;
        for (const callback of this.callbacks) {
            await callback.onTrainEnd(logs);
        }
    }
}
/**
 * Callback that accumulates epoch averages of metrics.
 *
 * This callback is automatically applied to every LayersModel.
 */
class BaseLogger extends BaseCallback {
    constructor() {
        super();
    }
    async onEpochBegin(epoch) {
        // Reset the per-epoch accumulators: `seen` counts examples processed
        // so far this epoch; `totals` maps metric name -> running weighted sum
        // (a plain number or a tensor, depending on the log value's type).
        this.seen = 0;
        this.totals = {};
    }
    async onBatchEnd(batch, logs) {
        if (logs == null) {
            logs = {};
        }
        // `size` is the number of examples in this batch; it weights each
        // metric value so onEpochEnd can compute per-example averages.
        const batchSize = logs['size'] == null ? 0 : logs['size'];
        this.seen += batchSize;
        for (const key in logs) {
            const value = logs[key];
            if (typeof value === 'number') {
                if (!this.totals.hasOwnProperty(key)) {
                    this.totals[key] = 0;
                }
                // Accumulate value weighted by batch size.
                this.totals[key] = this.totals[key] + value * batchSize;
            }
            else {
                // Tensor-valued log entry (e.g. a Scalar loss): accumulate via
                // tensor ops and dispose the superseded running total so tensor
                // memory is not leaked.
                let oldTotalsToDispose;
                if (key in this.totals) {
                    oldTotalsToDispose = this.totals[key];
                }
                else {
                    this.totals[key] = 0;
                }
                const total = tidy(() => add$3((this.totals[key]), mul(value, batchSize)));
                this.totals[key] = total;
                if (oldTotalsToDispose != null) {
                    oldTotalsToDispose.dispose();
                }
            }
        }
    }
    async onEpochEnd(epoch, logs) {
        if (logs != null) {
            for (const key of this.params['metrics']) {
                if (this.totals[key] == null) {
                    continue;
                }
                if (typeof this.totals[key] === 'number') {
                    // Numeric total: per-example average over the epoch.
                    logs[key] = this.totals[key] / this.seen;
                }
                else {
                    // Tensor total: compute the average inside tidy(), dispose
                    // the running total, and keep() the result so it survives
                    // the tidy scope for the caller.
                    tidy(() => {
                        const log = mul(div$1(1, this.seen), this.totals[key]);
                        logs[key] = log;
                        this.totals[key].dispose();
                        keep(logs[key]);
                    });
                }
            }
        }
    }
}
/**
 * Callback that records events into a `History` object. This callback is
 * automatically applied to every TF.js Layers model. The `History` object
 * gets returned by the `fit` method of models.
 */
class History extends BaseCallback {
    async onTrainBegin(logs) {
        // Start every training run with fresh records.
        this.epoch = [];
        this.history = {};
    }
    async onEpochEnd(epoch, logs) {
        logs = logs == null ? {} : logs;
        this.epoch.push(epoch);
        // Append each logged value to the per-key history array.
        Object.keys(logs).forEach(key => {
            if (this.history[key] == null) {
                this.history[key] = [];
            }
            this.history[key].push(logs[key]);
        });
    }
    /**
     * Await the values of all losses and metrics.
     */
    async syncData() {
        // Collect the data() promises of every tensor-valued entry, remembering
        // where each came from so the result can be written back in place.
        const promises = [];
        const keys = [];
        const indices = [];
        Object.keys(this.history).forEach(key => {
            this.history[key].forEach((value, i) => {
                if (typeof value !== 'number') {
                    promises.push(value.data());
                    keys.push(key);
                    indices.push(i);
                }
            });
        });
        const values = await Promise.all(promises);
        // Replace each tensor with its scalar value and free its memory.
        values.forEach((value, n) => {
            const tensorToDispose = this.history[keys[n]][indices[n]];
            tensorToDispose.dispose();
            this.history[keys[n]][indices[n]] = value[0];
        });
    }
}
/**
 * Custom callback for training.
 */
class CustomCallback extends BaseCallback {
    constructor(args, yieldEvery) {
        super();
        this.currentEpoch = 0;
        // Injectable clock / frame functions (useful for tests); fall back to
        // the real `nextFrame` for yielding control to the UI thread.
        this.nowFunc = args.nowFunc;
        this.nextFrameFunc = args.nextFrameFunc || nextFrame;
        // `yieldEvery` may be 'auto', 'batch', 'epoch', 'never', or a number
        // of milliseconds; 'auto' resolves to the default interval.
        this.yieldEvery = yieldEvery || 'auto';
        if (this.yieldEvery === 'auto') {
            this.yieldEvery = DEFAULT_YIELD_EVERY_MS;
        }
        if (this.yieldEvery === 'never' && args.onYield != null) {
            throw new Error('yieldEvery is `never` but you provided an `onYield` callback. ' +
                'Either change `yieldEvery` or remove the callback');
        }
        if (isNumber(this.yieldEvery)) {
            // Decorate `maybeWait` so it will be called at most once every
            // `yieldEvery` ms.
            this.maybeWait = debounce(this.maybeWait.bind(this), this.yieldEvery, this.nowFunc);
        }
        // User-supplied hooks; each may be undefined.
        this.trainBegin = args.onTrainBegin;
        this.trainEnd = args.onTrainEnd;
        this.epochBegin = args.onEpochBegin;
        this.epochEnd = args.onEpochEnd;
        this.batchBegin = args.onBatchBegin;
        this.batchEnd = args.onBatchEnd;
        this.yield = args.onYield;
    }
    // Runs the user's onYield hook (if any) and yields a frame to the UI.
    // When `yieldEvery` is numeric, this method is the debounced version
    // installed in the constructor.
    async maybeWait(epoch, batch, logs) {
        const ps = [];
        if (this.yield != null) {
            // Resolve Scalar tensors in the logs to plain numbers before
            // exposing them to user code.
            await resolveScalarsInLogs(logs);
            ps.push(this.yield(epoch, batch, logs));
        }
        ps.push(this.nextFrameFunc());
        await Promise.all(ps);
    }
    async onEpochBegin(epoch, logs) {
        // Remember the epoch so onBatchEnd can pass it to the yield hook.
        this.currentEpoch = epoch;
        if (this.epochBegin != null) {
            await resolveScalarsInLogs(logs);
            await this.epochBegin(epoch, logs);
        }
    }
    async onEpochEnd(epoch, logs) {
        const ps = [];
        if (this.epochEnd != null) {
            await resolveScalarsInLogs(logs);
            ps.push(this.epochEnd(epoch, logs));
        }
        if (this.yieldEvery === 'epoch') {
            // Yield a UI frame once per epoch.
            ps.push(this.nextFrameFunc());
        }
        await Promise.all(ps);
    }
    async onBatchBegin(batch, logs) {
        if (this.batchBegin != null) {
            await resolveScalarsInLogs(logs);
            await this.batchBegin(batch, logs);
        }
    }
    async onBatchEnd(batch, logs) {
        const ps = [];
        if (this.batchEnd != null) {
            await resolveScalarsInLogs(logs);
            ps.push(this.batchEnd(batch, logs));
        }
        if (this.yieldEvery === 'batch') {
            // Yield a UI frame once per batch.
            ps.push(this.nextFrameFunc());
        }
        else if (isNumber(this.yieldEvery)) {
            // Time-based yielding; the debounced maybeWait rate-limits this.
            ps.push(this.maybeWait(this.currentEpoch, batch, logs));
        }
        await Promise.all(ps);
    }
    async onTrainBegin(logs) {
        if (this.trainBegin != null) {
            await resolveScalarsInLogs(logs);
            await this.trainBegin(logs);
        }
    }
    async onTrainEnd(logs) {
        if (this.trainEnd != null) {
            await resolveScalarsInLogs(logs);
            await this.trainEnd(logs);
        }
    }
}
/**
 * Standardize callbacks or configurations of them to an Array of callbacks.
 */
function standardizeCallbacks(callbacks, yieldEvery) {
    if (callbacks == null) {
        callbacks = {};
    }
    // A single callback instance: wrap it in an array.
    if (callbacks instanceof BaseCallback) {
        return [callbacks];
    }
    // Already an array of callback instances: pass it through untouched.
    if (Array.isArray(callbacks) && callbacks[0] instanceof BaseCallback) {
        return callbacks;
    }
    // Otherwise treat the input as one or more custom callback configs and
    // wrap each one in a CustomCallback.
    return toList(callbacks).map(callbackConfig => new CustomCallback(callbackConfig, yieldEvery));
}
/**
 * A global registry for callback constructors to be used during
 * LayersModel.fit().
 */
class CallbackConstructorRegistry {
    /**
     * Blocks public access to constructor.
     */
    constructor() { }
    /**
     * Register a tf.LayersModel.fit() callback constructor.
     *
     * The registered callback constructor will be used to instantiate
     * callbacks for every tf.LayersModel.fit() call afterwards.
     *
     * @param verbosityLevel Level of verbosity at which the
     *   `callbackConstructor` is to be registered.
     * @param callbackConstructor A no-arg constructor for `tf.Callback`.
     * @throws Error, if the same callbackConstructor has been registered
     *   before, either at the same or a different `verbosityLevel`.
     */
    static registerCallbackConstructor(verbosityLevel, callbackConstructor) {
        assert$1(verbosityLevel >= 0 && Number.isInteger(verbosityLevel), () => `Verbosity level is expected to be an integer >= 0, ` +
            `but got ${verbosityLevel}`);
        CallbackConstructorRegistry.checkForDuplicate(callbackConstructor);
        const registry = CallbackConstructorRegistry.constructors;
        if (registry[verbosityLevel] == null) {
            registry[verbosityLevel] = [];
        }
        registry[verbosityLevel].push(callbackConstructor);
    }
    static checkForDuplicate(callbackConstructor) {
        // Scan every verbosity level for an identical constructor reference.
        for (const levelName of Object.keys(CallbackConstructorRegistry.constructors)) {
            const ctors = CallbackConstructorRegistry.constructors[+levelName];
            if (ctors.some(ctor => ctor === callbackConstructor)) {
                throw new ValueError('Duplicate callback constructor.');
            }
        }
    }
    /**
     * Clear all registered callback constructors.
     */
    static clear() {
        CallbackConstructorRegistry.constructors = {};
    }
    /**
     * Create callbacks using the registered callback constructors.
     *
     * Given `verbosityLevel`, all constructors registered at that level or
     * above will be called and the instantiated callbacks will be used.
     *
     * @param verbosityLevel: Level of verbosity.
     */
    static createCallbacks(verbosityLevel) {
        const constructors = [];
        Object.keys(CallbackConstructorRegistry.constructors).forEach(levelName => {
            const level = +levelName;
            if (verbosityLevel >= level) {
                constructors.push(...CallbackConstructorRegistry.constructors[level]);
            }
        });
        return constructors.map(ctor => new ctor());
    }
}
CallbackConstructorRegistry.constructors = {};
/**
 * Assemble the full callback list for a LayersModel.fit() call and seed it
 * with the training parameters. The standard BaseLogger comes first, then
 * registry-created callbacks, then user-provided callbacks, and the History
 * recorder always last.
 *
 * @returns The CallbackList plus the History instance it contains.
 */
function configureCallbacks(callbacks, verbose, epochs, initialEpoch, numTrainSamples, stepsPerEpoch, batchSize, doValidation, callbackMetrics) {
    const history = new History();
    const actualCallbacks = [new BaseLogger()];
    actualCallbacks.push(...CallbackConstructorRegistry.createCallbacks(verbose));
    if (callbacks != null) {
        actualCallbacks.push(...callbacks);
    }
    actualCallbacks.push(history);
    const callbackList = new CallbackList(actualCallbacks);
    // TODO(cais): Figure out when this LayersModel instance can have a
    // dynamically set property called 'callback_model' as in PyKeras.
    callbackList.setParams({
        epochs,
        initialEpoch,
        samples: numTrainSamples,
        steps: stepsPerEpoch,
        batchSize,
        verbose,
        doValidation,
        metrics: callbackMetrics,
    });
    return { callbackList, history };
}
41050
41051 /**
41052 * @license
41053 * Copyright 2018 Google LLC
41054 *
41055 * Use of this source code is governed by an MIT-style
41056 * license that can be found in the LICENSE file or at
41057 * https://opensource.org/licenses/MIT.
41058 * =============================================================================
41059 */
/**
 * Instantiate a layer from a config dictionary.
 * @param config dict of the form {class_name: str, config: dict}
 * @param customObjects dict mapping class names (or function names)
 *   of custom (non-Keras) objects to class/functions
 * @param fastWeightInit Optional flag to use fast weight initialization
 *   during deserialization. This is applicable to cases in which
 *   the initialization will be immediately overwritten by loaded weight
 *   values. Default: `false`.
 * @returns Layer instance (may be LayersModel, Sequential, Layer...)
 */
function deserialize(config, customObjects = {}, fastWeightInit = false) {
    const classNameMap = SerializationMap.getMap().classNameMap;
    return deserializeKerasObject(config, classNameMap, customObjects, 'layer', fastWeightInit);
}
41074
41075 /**
41076 * @license
41077 * Copyright 2018 Google LLC
41078 *
41079 * Use of this source code is governed by an MIT-style
41080 * license that can be found in the LICENSE file or at
41081 * https://opensource.org/licenses/MIT.
41082 * =============================================================================
41083 */
/**
 * Normalizes a tensor wrt the L2 norm alongside the specified axis.
 * @param x
 * @param axis Axis along which to perform normalization.
 */
function l2Normalize(x, axis) {
    return tidy(() => {
        const xFloat = x.dtype === 'float32' ? x : cast$3(x, 'float32');
        // Floor the squared sum at the backend epsilon to avoid dividing by
        // zero for all-zero slices.
        const squareSum = sum$3(square$1(xFloat), axis, true);
        const floored = maximum$4(squareSum, fill$2(squareSum.shape, epsilon$1()));
        return div$1(xFloat, sqrt$2(floored));
    });
}
/** Mean of the squared differences between predictions and targets. */
function meanSquaredError$1(yTrue, yPred) {
    return tidy(() => {
        const squaredDiff = square$1(sub$2(yPred, yTrue));
        return mean$3(squaredDiff, -1);
    });
}
/** Mean of the absolute differences between predictions and targets. */
function meanAbsoluteError$1(yTrue, yPred) {
    return tidy(() => {
        const absDiff = abs$2(sub$2(yPred, yTrue));
        return mean$3(absDiff, -1);
    });
}
/** Mean absolute percentage error between targets and predictions. */
function meanAbsolutePercentageError$1(yTrue, yPred) {
    return tidy(() => {
        // Clip |yTrue| away from zero so the division cannot blow up.
        const safeTrue = clipByValue$2(abs$2(yTrue), epsilon$1(), Number.MAX_VALUE);
        const relativeError = abs$2(div$1(sub$2(yTrue, yPred), safeTrue));
        return mul(100, mean$3(relativeError, -1));
    });
}
/** Mean squared difference between log(1 + yPred) and log(1 + yTrue). */
function meanSquaredLogarithmicError(yTrue, yPred) {
    return tidy(() => {
        // Clip both inputs to the positive range before taking logs.
        const logPred = log$2(add$3(1, clipByValue$2(yPred, epsilon$1(), Number.MAX_VALUE)));
        const logTrue = log$2(add$3(1, clipByValue$2(yTrue, epsilon$1(), Number.MAX_VALUE)));
        return mean$3(square$1(sub$2(logPred, logTrue)), -1);
    });
}
/** Squared hinge loss: mean of max(0, 1 - yTrue * yPred)^2 over the last axis. */
function squaredHinge(yTrue, yPred) {
    return tidy(() => {
        const margin = maximum$4(0, sub$2(1, mul(yTrue, yPred)));
        return mean$3(square$1(margin), -1);
    });
}
/** Hinge loss: mean of max(0, 1 - yTrue * yPred) over the last axis. */
function hinge(yTrue, yPred) {
    return tidy(() => {
        const margin = maximum$4(0, sub$2(1, mul(yTrue, yPred)));
        return mean$3(margin, -1);
    });
}
/** Categorical hinge loss for one-hot encoded targets. */
function categoricalHinge(yTrue, yPred) {
    return tidy(() => {
        // Score of the true class vs. the best score among the other classes.
        const positive = sum$3(mul(yTrue, yPred), -1);
        const negative = max$3(mul(sub$2(1, yTrue), yPred), -1);
        return maximum$4(0, add$3(1, sub$2(negative, positive)));
    });
}
/**
 * Logarithm of the hyperbolic cosine of the prediction error.
 *
 * `log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small `x` and
 * to `abs(x) - log(2)` for large `x`. This means that 'logcosh' works mostly
 * like the mean squared error, but will not be so strongly affected by the
 * occasional wildly incorrect prediction.
 */
function logcosh(yTrue, yPred) {
    return tidy(() => {
        // Numerically stable identity: log(cosh(x)) = x + softplus(-2x) - log(2).
        const diff = sub$2(yPred, yTrue);
        const stableLogCosh = sub$2(add$3(diff, softplus$2(mul(-2, diff))), Math.log(2));
        return mean$3(stableLogCosh, -1);
    });
}
/**
 * Categorical crossentropy between one-hot targets and predictions.
 * @param fromLogits When true, `output` is interpreted as unnormalized logits
 *   and run through softmax first.
 */
function categoricalCrossentropy$2(target, output, fromLogits = false) {
    return tidy(() => {
        if (fromLogits) {
            output = softmax$3(output);
        }
        else {
            // Scale preds so that the class probabilities of each sample sum to 1.
            output = div$1(output, sum$3(output, output.shape.length - 1, true));
        }
        // Clip away from 0 and 1 so the log below stays finite.
        const clipped = clipByValue$2(output, epsilon$1(), 1 - epsilon$1());
        const weightedLogProbs = mul(cast$3(target, 'float32'), log$2(clipped));
        return neg$2(sum$3(weightedLogProbs, output.shape.length - 1));
    });
}
/**
 * Categorical crossentropy with integer targets.
 *
 * @param target An integer tensor.
 * @param output A tensor resulting from a softmax (unless `fromLogits` is
 *   `true`, in which case `output` is expected to be the logits).
 * @param fromLogits Boolean, whether `output` is the result of a softmax, or is
 *   a tensor of logits.
 */
function sparseCategoricalCrossentropy$1(target, output, fromLogits = false) {
    return tidy(() => {
        // Convert the integer targets to one-hot form and defer to the dense
        // categorical crossentropy.
        const flatTarget = cast$3(floor$2(flatten$1(target)), 'int32');
        output = clipByValue$2(output, epsilon$1(), 1 - epsilon$1());
        const outputShape = output.shape;
        const numClasses = outputShape[outputShape.length - 1];
        const oneHotTarget = reshape$3(oneHot$3(flatTarget, numClasses), outputShape);
        return categoricalCrossentropy$2(oneHotTarget, output, fromLogits);
    });
}
/**
 * From TensorFlow's implementation in nn_impl.py:
 *
 * For brevity, let `x = logits`, `z = labels`. The logistic loss is
 *     z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
 *   = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
 *   = x - x * z + log(1 + exp(-x))
 * For x < 0, to avoid overflow in exp(-x), the equivalent form
 *   - x * z + log(1 + exp(x))
 * is used instead. Combining the two yields the stable formulation
 *   max(x, 0) - x * z + log(1 + exp(-abs(x)))
 * which is what is computed below.
 *
 * @param labels The labels.
 * @param logits The logits.
 */
function sigmoidCrossEntropyWithLogits(labels, logits) {
    if (!arraysEqual(labels.shape, logits.shape)) {
        throw new ValueError(`logits and labels must have the same shape, but got shapes ` +
            `${JSON.stringify(labels.shape)} and ${JSON.stringify(logits.shape)}`);
    }
    return tidy(() => {
        // max(x, 0) - x * z + log(1 + exp(-abs(x)))
        const maxLogitsOrZero = relu$2(logits);
        const logTerm = log1p$2(exp$2(neg$2(abs$2(logits))));
        return add$3(sub$2(maxLogitsOrZero, mul(logits, labels)), logTerm);
    });
}
/** Binary crossentropy between targets and predicted probabilities. */
function binaryCrossentropy$2(yTrue, yPred) {
    return tidy(() => {
        // Convert probabilities back to logits, then reuse the numerically
        // stable sigmoid cross-entropy formulation.
        const clipped = clipByValue$2(yPred, epsilon$1(), 1 - epsilon$1());
        const logits = log$2(div$1(clipped, sub$2(1, clipped)));
        return mean$3(sigmoidCrossEntropyWithLogits(yTrue, logits), -1);
    });
}
/** Kullback-Leibler divergence between target and predicted distributions. */
function kullbackLeiblerDivergence(yTrue, yPred) {
    return tidy(() => {
        // Clip both distributions away from zero before taking the log ratio.
        const clippedTrue = clipByValue$2(yTrue, epsilon$1(), 1);
        const clippedPred = clipByValue$2(yPred, epsilon$1(), 1);
        const logRatio = log$2(div$1(clippedTrue, clippedPred));
        return sum$3(mul(yTrue, logRatio), -1);
    });
}
/** Poisson loss: mean of yPred - yTrue * log(yPred) over the last axis. */
function poisson(yTrue, yPred) {
    return tidy(() => {
        // Add epsilon inside the log so it stays finite at yPred == 0.
        const logPred = log$2(add$3(epsilon$1(), yPred));
        return mean$3(sub$2(yPred, mul(yTrue, logPred)), -1);
    });
}
/** Negative cosine similarity between targets and predictions. */
function cosineProximity$1(yTrue, yPred) {
    return tidy(() => {
        const dotProduct = mul(l2Normalize(yTrue, -1), l2Normalize(yPred, -1));
        return neg$2(sum$3(dotProduct, -1));
    });
}
// Short-form aliases for the loss functions above, matching Keras naming.
const mse$2 = meanSquaredError$1;
const MSE$2 = meanSquaredError$1;
const mae$1 = meanAbsoluteError$1;
const MAE$1 = meanAbsoluteError$1;
const mape$2 = meanAbsolutePercentageError$1;
const MAPE$2 = meanAbsolutePercentageError$1;
const msle = meanSquaredLogarithmicError;
const MSLE = meanSquaredLogarithmicError;
const kld = kullbackLeiblerDivergence;
const KLD = kullbackLeiblerDivergence;
const cosine$1 = cosineProximity$1;
// TODO(michaelterry): Add deserialize() function.
// Registry mapping loss names (as used in serialized configs) to their
// implementations; consulted by get$1() to resolve string identifiers.
const lossesMap = {
    meanSquaredError: meanSquaredError$1,
    meanAbsoluteError: meanAbsoluteError$1,
    meanAbsolutePercentageError: meanAbsolutePercentageError$1,
    meanSquaredLogarithmicError,
    squaredHinge,
    hinge,
    categoricalHinge,
    logcosh,
    categoricalCrossentropy: categoricalCrossentropy$2,
    sparseCategoricalCrossentropy: sparseCategoricalCrossentropy$1,
    binaryCrossentropy: binaryCrossentropy$2,
    kullbackLeiblerDivergence,
    poisson,
    cosineProximity: cosineProximity$1
};
// Porting note: This diverges from the PyKeras implementation and may need to
// change based on (de)serialization requirements.
/** Resolve a loss identifier (string name or function) to a loss function. */
function get$1(identifierOrFn) {
    if (typeof identifierOrFn !== 'string') {
        // Already a function: return it as-is.
        return identifierOrFn;
    }
    if (identifierOrFn in lossesMap) {
        return lossesMap[identifierOrFn];
    }
    let errMsg = `Unknown loss ${identifierOrFn}`;
    if (identifierOrFn.toLowerCase().includes('softmaxcrossentropy')) {
        errMsg = `Unknown loss ${identifierOrFn}. ` +
            'Use "categoricalCrossentropy" as the string name for ' +
            'tf.losses.softmaxCrossEntropy';
    }
    throw new ValueError(errMsg);
}
41304
41305 /**
41306 * @license
41307 * Copyright 2018 Google LLC
41308 *
41309 * Use of this source code is governed by an MIT-style
41310 * license that can be found in the LICENSE file or at
41311 * https://opensource.org/licenses/MIT.
41312 * =============================================================================
41313 */
/** Fraction of elements where the 0.5-thresholded predictions match targets. */
function binaryAccuracy$1(yTrue, yPred) {
    return tidy(() => {
        const threshold = mul(.5, onesLike$3(yPred));
        const binaryPred = cast$2(greater$3(yPred, threshold), yTrue.dtype);
        return mean$3(equal$2(yTrue, binaryPred), -1);
    });
}
/** 1 where the argmax of targets and predictions agree, else 0. */
function categoricalAccuracy$1(yTrue, yPred) {
    return tidy(() => {
        const trueClass = argMax$2(yTrue, -1);
        const predClass = argMax$2(yPred, -1);
        return cast$2(equal$2(trueClass, predClass), 'float32');
    });
}
/** Count of positions where both target and prediction equal 1. */
function truePositives(yTrue, yPred) {
    return tidy(() => cast$3(sum$3(logicalAnd$2(equal$2(yTrue, 1), equal$2(yPred, 1))), 'float32'));
}
/** Count of positions where the target is 1 but the prediction is 0. */
function falseNegatives(yTrue, yPred) {
    return tidy(() => cast$3(sum$3(logicalAnd$2(equal$2(yTrue, 1), equal$2(yPred, 0))), 'float32'));
}
/** Count of positions where the target is 0 but the prediction is 1. */
function falsePositives(yTrue, yPred) {
    return tidy(() => cast$3(sum$3(logicalAnd$2(equal$2(yTrue, 0), equal$2(yPred, 1))), 'float32'));
}
/** Precision = TP / (TP + FP); yields 0 when nothing was predicted positive. */
function precision$1(yTrue, yPred) {
    return tidy(() => {
        const tp = truePositives(yTrue, yPred);
        const predictedPositives = add$3(tp, falsePositives(yTrue, yPred));
        // Guard against 0/0 with an explicit zero fallback.
        return cast$3(where(greater$3(predictedPositives, 0), div$1(tp, predictedPositives), 0), 'float32');
    });
}
/** Recall = TP / (TP + FN); yields 0 when there are no actual positives. */
function recall$1(yTrue, yPred) {
    return tidy(() => {
        const tp = truePositives(yTrue, yPred);
        const actualPositives = add$3(tp, falseNegatives(yTrue, yPred));
        // Guard against 0/0 with an explicit zero fallback.
        return cast$3(where(greater$3(actualPositives, 0), div$1(tp, actualPositives), 0), 'float32');
    });
}
// Thin metric wrapper delegating to the binary crossentropy loss above.
function binaryCrossentropy$1(yTrue, yPred) {
    return binaryCrossentropy$2(yTrue, yPred);
}
/** 1 where the integer target matches the argmax of the prediction, else 0. */
function sparseCategoricalAccuracy$1(yTrue, yPred) {
    if (yTrue.rank === yPred.rank) {
        // Targets carry a redundant trailing singleton dimension; remove it.
        yTrue = squeeze(yTrue, [yTrue.rank - 1]);
    }
    let predClasses = argMax$2(yPred, -1);
    if (predClasses.dtype !== yTrue.dtype) {
        predClasses = cast$3(predClasses, yTrue.dtype);
    }
    return cast$3(equal$2(yTrue, predClasses), 'float32');
}
// Not yet ported from PyKeras; throws unconditionally.
function topKCategoricalAccuracy(yTrue, yPred) {
    throw new NotImplementedError();
}
// Not yet ported from PyKeras; throws unconditionally.
function sparseTopKCategoricalAccuracy(yTrue, yPred) {
    throw new NotImplementedError();
}
/** R^2 (coefficient of determination): 1 - SS_residual / SS_total. */
function r2Score$1(yTrue, yPred) {
    return tidy(() => {
        const residualSumSquares = yTrue.sub(yPred).square().sum();
        const totalSumSquares = yTrue.sub(yTrue.mean()).square().sum();
        return scalar(1).sub(residualSumSquares.div(totalSumSquares));
    });
}
// Aliases.
const mse$1 = meanSquaredError$1;
const MSE$1 = meanSquaredError$1;
const mae = meanAbsoluteError$1;
const MAE = meanAbsoluteError$1;
const mape$1 = meanAbsolutePercentageError$1;
const MAPE$1 = meanAbsolutePercentageError$1;
const categoricalCrossentropy$1 = categoricalCrossentropy$2;
const cosine = cosineProximity$1;
const sparseCategoricalCrossentropy = sparseCategoricalCrossentropy$1;
// TODO(cais, nielsene): Add serialize().
// Registry mapping metric names to their implementations; consulted by
// get() below to resolve string identifiers.
const metricsMap = {
    binaryAccuracy: binaryAccuracy$1,
    categoricalAccuracy: categoricalAccuracy$1,
    precision: precision$1,
    categoricalCrossentropy: categoricalCrossentropy$1,
    sparseCategoricalCrossentropy,
    mse: mse$1,
    MSE: MSE$1,
    mae,
    MAE,
    mape: mape$1,
    MAPE: MAPE$1,
    cosine
};
/** Resolve a metric identifier (string name or function) to a function. */
function get(identifier) {
    if (typeof identifier === 'string' && identifier in metricsMap) {
        return metricsMap[identifier];
    }
    if (typeof identifier !== 'string' && identifier != null) {
        // Already a metric function: return it as-is.
        return identifier;
    }
    throw new ValueError(`Unknown metric ${identifier}`);
}
/**
 * Get the shortcut function name.
 *
 * If the fn name is a string,
 *   directly return the string name.
 * If the function is included in metricsMap or lossesMap,
 *   return key of the map.
 *   - If the function relative to multiple keys,
 *     return the first found key as the function name.
 *   - If the function exists in both lossesMap and metricsMap,
 *     search lossesMap first.
 * If the function is not included in metricsMap or lossesMap,
 *   return the function name.
 *
 * @param fn loss function, metric function, or short cut name.
 * @returns Loss or Metric name in string.
 */
function getLossOrMetricName(fn) {
    // Use `!= null` so that `undefined` is rejected here with a clear message
    // instead of failing later on the `fn.name` access with a bare TypeError.
    assert(fn != null, `Unknown LossOrMetricFn ${fn}`);
    if (typeof fn === 'string') {
        return fn;
    }
    // Search the losses first, then the metrics; return the first key whose
    // registered function is identical to `fn`.
    for (const map of [lossesMap, metricsMap]) {
        for (const key of Object.keys(map)) {
            if (map[key] === fn) {
                return key;
            }
        }
    }
    // Not a registered loss or metric: fall back to the function's own name.
    return fn.name;
}
41462
41463 /**
41464 * @license
41465 * Copyright 2018 Google LLC
41466 *
41467 * Use of this source code is governed by an MIT-style
41468 * license that can be found in the LICENSE file or at
41469 * https://opensource.org/licenses/MIT.
41470 * =============================================================================
41471 */
// Add (de)serialize()
// Porting note: This diverges from the PyKeras implementation and may need to
// change based on (de)serialization requirements.
/**
 * Instantiate an optimizer from a string identifier with default
 * hyperparameters. Both the canonical names ('Adam', 'SGD', ...) and their
 * all-lowercase spellings are accepted.
 * @throws ValueError for an unrecognized identifier.
 */
function getOptimizer(identifier) {
    // Factories keyed by canonical name; lowercase spellings are added below.
    const factories = {
        'Adagrad': () => train.adagrad(0.01),
        'Adadelta': () => train.adadelta(1, 0.95, epsilon$1()),
        'Adam': () => train.adam(0.001, 0.9, 0.999, epsilon$1()),
        'Adamax': () => train.adamax(0.002, 0.9, 0.999, epsilon$1(), 0),
        'RMSProp': () => train.rmsprop(0.001, 0.9, 0, epsilon$1()),
        'SGD': () => train.sgd(0.01)
    };
    for (const name of Object.keys(factories)) {
        factories[name.toLowerCase()] = factories[name];
    }
    if (identifier in factories) {
        return factories[identifier]();
    }
    throw new ValueError(`Unknown Optimizer ${identifier}`);
}
41495
41496 /**
41497 * @license
41498 * Copyright 2019 Google LLC
41499 *
41500 * Use of this source code is governed by an MIT-style
41501 * license that can be found in the LICENSE file or at
41502 * https://opensource.org/licenses/MIT.
41503 * =============================================================================
41504 */
/** Utility functions related to user-defined metadata. */
// Maximum recommended serialized size for user-defined metadata.
// Beyond this limit, a warning message will be printed during model loading and
// saving. (1 * 1024 * 1024 bytes = 1 MiB.)
const MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH = 1 * 1024 * 1024;
/**
 * Check validity of user-defined metadata.
 *
 * @param userDefinedMetadata
 * @param modelName Name of the model that the user-defined metadata belongs to.
 *   Used during construction of error messages.
 * @param checkSize Whether to check the size of the metadata is under
 *   recommended limit. Default: `false`. If `true`, will try stringify the
 *   JSON object and print a console warning if the serialized size is above
 *   the limit.
 * @throws Error if `userDefinedMetadata` is not a plain JSON object.
 */
function checkUserDefinedMetadata(userDefinedMetadata, modelName, checkSize = false) {
    const isPlainJsonObject = userDefinedMetadata != null &&
        typeof userDefinedMetadata === 'object' &&
        Object.getPrototypeOf(userDefinedMetadata) === Object.prototype &&
        plainObjectCheck(userDefinedMetadata);
    if (!isPlainJsonObject) {
        throw new Error('User-defined metadata is expected to be a JSON object, but is not.');
    }
    if (checkSize) {
        const out = JSON.stringify(userDefinedMetadata);
        if (out.length > MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH) {
            console.warn(`User-defined metadata of model "${modelName}" is too large in ` +
                `size (length=${out.length} when serialized). It is not ` +
                `recommended to store such large objects in user-defined metadata. ` +
                `Please make sure its serialized length is <= ` +
                `${MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH}.`);
        }
    }
}
/**
 * Check if an input is plain JSON object or any valid subfield of it.
 *
 * @param x The input to be checked.
 * @param assertObject Whether to assert `x` is a JSON object, i.e., reject
 *   cases of arrays and primitives.
 * @return Returns `true` if and only if `x` is a plain JSON object,
 *   a JSON-valid primitive including string, number, boolean and null,
 *   or an array of the said types.
 */
// tslint:disable-next-line:no-any
function plainObjectCheck(x) {
    if (x === null) {
        // `null` is valid JSON even though `typeof null` is 'object'.
        return true;
    }
    if (typeof x !== 'object') {
        // Among non-objects, only the JSON primitives are acceptable.
        const xType = typeof x;
        return xType === 'string' || xType === 'number' || xType === 'boolean';
    }
    if (Array.isArray(x)) {
        // Every element of an array must itself be JSON-valid.
        return x.every(item => plainObjectCheck(item));
    }
    if (Object.getPrototypeOf(x) !== Object.prototype) {
        // Complex objects such as `Error` and `Date` are rejected.
        return false;
    }
    // A plain object: every key must be a string and every value JSON-valid.
    return Object.keys(x).every(key => typeof key === 'string' && plainObjectCheck(x[key]));
}
41596
41597 /**
41598 * @license
41599 * Copyright 2018 Google LLC
41600 *
41601 * Use of this source code is governed by an MIT-style
41602 * license that can be found in the LICENSE file or at
41603 * https://opensource.org/licenses/MIT.
41604 * =============================================================================
41605 */
/**
 * Print the summary of a LayersModel object.
 *
 * @param model tf.LayersModel instance.
 * @param lineLength Total length of printed lines. Set this to adapt to the
 *   display to different terminal or console sizes. Defaults to 90 for
 *   sequential-like models and 115 otherwise.
 * @param positions Relative or absolute positions of log elements in each
 *   line. Each number corresponds to the right-most (i.e., ending) position
 *   of a column. If not provided, defaults to `[0.32, 0.61, 0.89, 1]` for
 *   sequential-like models and `[0.24, 0.48, 0.70, 0.80, 1]` for
 *   non-sequential-like models.
 * @param printFn Print function to use. It will be called on each line of
 *   the summary; provide a custom function to capture the string summary.
 *   Defaults to `console.log`.
 */
function printSummary(model, lineLength, positions,
// tslint:disable-next-line:no-any
printFn = console.log) {
    const sequentialLike = isModelSequentialLike(model);
    // Column headers for the summary table.
    const toDisplay = ['Layer (type)', 'Input Shape', 'Output shape', 'Param #'];
    lineLength = lineLength || (sequentialLike ? 90 : 115);
    positions = positions ||
        (sequentialLike ? [0.32, 0.61, 0.89, 1] : [0.24, 0.48, 0.70, 0.80, 1]);
    if (positions[positions.length - 1] <= 1) {
        // `positions` is relative: convert to absolute character columns.
        positions = positions.map((p) => Math.floor(lineLength * p));
    }
    let relevantNodes;
    if (!sequentialLike) {
        // Non-sequential models get an extra connectivity column and need the
        // set of nodes that belong to this graph.
        toDisplay.push('Receives inputs');
        relevantNodes = [];
        for (const depth in model.nodesByDepth) {
            relevantNodes.push(...model.nodesByDepth[depth]);
        }
    }
    printFn('_'.repeat(lineLength));
    printRow(toDisplay, positions, printFn);
    printFn('='.repeat(lineLength));
    const layers = model.layers;
    layers.forEach((layer, i) => {
        if (sequentialLike) {
            printLayerSummary(layer, positions, printFn);
        }
        else {
            printLayerSummaryWithConnections(layer, positions, relevantNodes, printFn);
        }
        // The last layer is followed by a '=' rule, all others by a '_' rule.
        printFn((i === layers.length - 1 ? '=' : '_').repeat(lineLength));
    });
    // tslint:disable-next-line:no-any
    model.checkTrainableWeightsConsistency();
    const trainableCount = countTrainableParams(model);
    const nonTrainableCount = countParamsInWeights(model.nonTrainableWeights);
    printFn(`Total params: ${trainableCount + nonTrainableCount}`);
    printFn(`Trainable params: ${trainableCount}`);
    printFn(`Non-trainable params: ${nonTrainableCount}`);
    printFn('_'.repeat(lineLength));
}
/**
 * Count the trainable parameters of a model.
 *
 * Prefers `collectedTrainableWeights` when the model has been compiled (it
 * reflects the trainable flags at compile time); otherwise falls back to the
 * live `trainableWeights`.
 */
function countTrainableParams(model) {
    // tslint:disable:no-any
    const weightsToCount = model.collectedTrainableWeights != null ?
        model.collectedTrainableWeights :
        model.trainableWeights;
    // tslint:enable:no-any
    return countParamsInWeights(weightsToCount);
}
/**
 * Determine whether a model's topology is sequential-like: at most one node
 * per depth level, each node with at most one inbound layer, and no layer
 * owning more than one node of the graph (i.e., no shared layers).
 */
function isModelSequentialLike(model) {
    const depthGroups = [];
    for (const depth in model.nodesByDepth) {
        depthGroups.push(model.nodesByDepth[depth]);
    }
    const graphNodes = [];
    for (const group of depthGroups) {
        // More than one node at a depth, or a node fed by multiple layers,
        // means the model is not a simple chain.
        if (group.length > 1 ||
            (group.length === 1 && group[0].inboundLayers.length > 1)) {
            return false;
        }
        graphNodes.push(...group);
    }
    // Search for shared layers: a layer with two or more of its inbound nodes
    // inside this graph is used more than once.
    for (const layer of model.layers) {
        let inGraphCount = 0;
        for (const node of layer.inboundNodes) {
            if (graphNodes.indexOf(node) !== -1) {
                inGraphCount++;
                if (inGraphCount > 1) {
                    return false;
                }
            }
        }
    }
    return true;
}
/**
 * Print one row of the summary table: each field is truncated or padded so
 * that it ends exactly at the corresponding absolute column position.
 *
 * @param fields Text for each column.
 * @param positions Absolute right-edge character positions per column.
 * @param printFn Print function; defaults to `console.log`.
 */
function printRow(fields, positions,
// tslint:disable-next-line:no-any
printFn = console.log) {
    let line = '';
    fields.forEach((field, i) => {
        if (i > 0) {
            // Replace the final padding character with a separator space.
            line = line.slice(0, line.length - 1) + ' ';
        }
        line += field;
        // Truncate to the column's right edge, then pad out to it.
        line = line.slice(0, positions[i]);
        line += ' '.repeat(positions[i] - line.length);
    });
    printFn(line);
}
/**
 * Prints a summary for a single Layer, without connectivity information.
 *
 * @param layer Layer instance to print.
 * @param positions Absolute right-edge column positions.
 * @param printFn Print function used for the output row.
 */
function printLayerSummary(layer, positions,
// tslint:disable-next-line:no-any
printFn) {
    let inputShape;
    try {
        inputShape =
            layer.inboundNodes.map((node) => JSON.stringify(node.inputShapes))
                .join(',');
    }
    catch (err) {
        // Shape cannot be stated unambiguously (e.g. a shared layer).
        inputShape = 'multiple';
    }
    let outputShape;
    try {
        outputShape = JSON.stringify(layer.outputShape);
    }
    catch (err) {
        outputShape = 'multiple';
    }
    printRow([
        `${layer.name} (${layer.getClassName()})`, inputShape, outputShape,
        layer.countParams().toString()
    ], positions, printFn);
}
/**
 * Prints a summary for a single Layer, with connectivity information.
 *
 * The first printed row carries the layer's first inbound connection; any
 * further connections are printed on continuation rows with the other
 * columns left blank.
 */
function printLayerSummaryWithConnections(layer, positions, relevantNodes,
// tslint:disable-next-line:no-any
printFn) {
    let inputShape;
    try {
        inputShape =
            layer.inboundNodes.map((node) => JSON.stringify(node.inputShapes))
                .join(',');
    }
    catch (err) {
        inputShape = 'multiple';
    }
    let outputShape;
    try {
        outputShape = JSON.stringify(layer.outputShape);
    }
    catch (err) {
        outputShape = 'multiple';
    }
    const connections = [];
    for (const node of layer.inboundNodes) {
        // Skip nodes that are not part of the graph being summarized.
        if (relevantNodes != null && relevantNodes.length > 0 &&
            relevantNodes.indexOf(node) === -1) {
            continue;
        }
        node.inboundLayers.forEach((inbound, i) => {
            connections.push(`${inbound.name}[${node.nodeIndices[i]}][${node.tensorIndices[i]}]`);
        });
    }
    const firstConnection = connections.length === 0 ? '' : connections[0];
    printRow([
        `${layer.name} (${layer.getClassName()})`, inputShape, outputShape,
        layer.countParams().toString(), firstConnection
    ], positions, printFn);
    connections.slice(1).forEach((connection) => {
        printRow(['', '', '', '', connection], positions, printFn);
    });
}
41808
41809 /**
41810 * @license
41811 * Copyright 2018 Google LLC
41812 *
41813 * Use of this source code is governed by an MIT-style
41814 * license that can be found in the LICENSE file or at
41815 * https://opensource.org/licenses/MIT.
41816 * =============================================================================
41817 */
41818 // tslint:enable
/**
 * Test whether a value in an array is the name of a LayersModel or Layer.
 *
 * Only the first element (`index === 0`) of an array found under one of the
 * keys `inboundNodes`, `outputLayers` or `inputLayers` is treated as a name,
 * and only when it is a string. Note that the key may not be at the level
 * immediately above the value if the value sits in a nested array.
 *
 * @param key The key name that the value is found under.
 * @param index Index of the value in the array it is found in.
 * @param value The value object.
 * @returns A boolean indicating whether value is a name.
 */
function isArrayItemInputOrOutputName(key, index, value) {
    if (index !== 0 || typeof value !== 'string') {
        return false;
    }
    return key === 'inboundNodes' || key === 'outputLayers' ||
        key === 'inputLayers';
}
/**
 * Convert a Pythonic config object to a TypeScript config object.
 *
 * Strings are camel-cased, numbers/booleans/null pass through, arrays and
 * plain objects are converted recursively. Layer/model name strings are left
 * untouched.
 *
 * @param pythonicConfig The config object to convert.
 * @param key Optional key name of the object being converted.
 * @returns Result of the conversion.
 */
function convertPythonicToTs(pythonicConfig, key) {
    if (pythonicConfig === null) {
        return null;
    }
    if (typeof pythonicConfig === 'string') {
        return toCamelCase(pythonicConfig);
    }
    if (typeof pythonicConfig === 'number' ||
        typeof pythonicConfig === 'boolean') {
        return pythonicConfig;
    }
    if (pythonicConfig instanceof Array) {
        // Array items that are layer/model names are kept verbatim; all other
        // items are converted recursively under the same key.
        return pythonicConfig.map(
            (item, i) => isArrayItemInputOrOutputName(key, i, item) ?
                item :
                convertPythonicToTs(item, key));
    }
    const tsDict = {};
    for (const [pythonicKey, pythonicValue] of Object.entries(pythonicConfig)) {
        if (pythonicKey === 'name' && typeof pythonicValue === 'string') {
            // Special case the 'name' key with a string value. Name values,
            // such as the names of LayersModel and Layer instances, should
            // not undergo the camel-case conversion.
            tsDict[pythonicKey] = pythonicValue;
        }
        else {
            const tsKey = toCamelCase(pythonicKey);
            tsDict[tsKey] = convertPythonicToTs(pythonicValue, tsKey);
        }
    }
    return tsDict;
}
/**
 * Convert a TypeScript config object to a Python config object.
 *
 * Strings are snake-cased, numbers/booleans pass through, `null`/`undefined`
 * become `null`, arrays and plain objects are converted recursively.
 * Layer/model names and class names are left untouched.
 *
 * @param tsConfig The config object to convert.
 * @param key Optional key name of the object being converted.
 * @returns Result of the conversion.
 */
function convertTsToPythonic(tsConfig, key) {
    if (tsConfig === null || tsConfig === undefined) {
        return null;
    }
    if (typeof tsConfig === 'string') {
        return toSnakeCase(tsConfig);
    }
    if (typeof tsConfig === 'number' || typeof tsConfig === 'boolean') {
        return tsConfig;
    }
    if (tsConfig instanceof Array) {
        // Array items that are layer/model names are kept verbatim; all other
        // items are converted recursively under the same key.
        return tsConfig.map(
            (item, i) => isArrayItemInputOrOutputName(key, i, item) ?
                item :
                convertTsToPythonic(item, key));
    }
    const pyDict = {};
    for (const [tsKey, tsValue] of Object.entries(tsConfig)) {
        const pyKey = toSnakeCase(tsKey);
        if ((tsKey === 'name' || tsKey === 'className') &&
            typeof tsValue === 'string') {
            // Special case the 'name' key with a string value. Name values,
            // such as the names of LayersModel and Layer instances, should
            // not undergo the snake-case conversion.
            pyDict[pyKey] = tsValue;
        }
        else {
            pyDict[pyKey] = convertTsToPythonic(tsValue, tsKey);
        }
    }
    return pyDict;
}
41931
/** @license See the LICENSE file. */
// This code is auto-generated, do not modify this file!
// Version string of the layers package embedded in this bundle.
const version$6 = '4.22.0';
41935
41936 /**
41937 * @license
41938 * Copyright 2018 Google LLC
41939 *
41940 * Use of this source code is governed by an MIT-style
41941 * license that can be found in the LICENSE file or at
41942 * https://opensource.org/licenses/MIT.
41943 * =============================================================================
41944 */
// Detect whether a weights map uses the Keras v3 saved-model naming scheme,
// in which each weight key ends in a numeric index (e.g. `dense/0`).
// Only the first key is inspected; an empty map is not considered v3 format.
const isKerasSavedModelFormat = (weights) => {
    const keys = Object.keys(weights);
    if (keys.length === 0) {
        return false;
    }
    const segments = keys[0].split('/');
    const lastSegment = segments[segments.length - 1];
    return !Number.isNaN(Number.parseInt(lastSegment, 10));
};
41955 /**
41956 * A Container is a directed acyclic graph of layers.
41957 *
41958 * It is the topological form of a "model". A LayersModel
41959 * is simply a Container with added training routines.
41960 *
41961 */
41962 class Container extends Layer {
    /**
     * Build a Container from its symbolic input and output tensors.
     *
     * Walks the graph backwards from `args.outputs`, collecting every node and
     * layer reachable from the outputs, assigning depths, and validating that
     * the graph is acyclic, fully computable from `args.inputs`, and that all
     * layer names are unique.
     *
     * @param args Expected to provide `inputs`, `outputs` (single symbolic
     *   tensor or array thereof) and optionally `name`.
     * @throws ValueError for redundant inputs or duplicate layer names;
     *   RuntimeError for cycles or disconnected graphs; TypeError if any
     *   input does not originate from an InputLayer.
     */
    constructor(args) {
        // No args passed to super's constructor.
        super({});
        this.containerNodes = new Set();
        this.name = args.name;
        if (this.name == null) {
            // Derive a unique default name from the class name.
            const prefix = this.getClassName().toLowerCase();
            this.name = getUid(prefix);
        }
        this.supportsMasking = false;
        this.trainable_ = true;
        // TODO(michaelterry): Initialize perInputLosses/Updates here.
        // Container-specific properties: normalize inputs/outputs to arrays
        // (defensively copied when already arrays).
        if (Array.isArray(args.inputs)) {
            this.inputs = args.inputs.slice();
        }
        else {
            this.inputs = [args.inputs];
        }
        if (Array.isArray(args.outputs)) {
            this.outputs = args.outputs.slice();
        }
        else {
            this.outputs = [args.outputs];
        }
        // Check for redundancy in inputs.
        if (unique$2(this.inputs).length !== this.inputs.length) {
            throw new ValueError('The list of inputs passed to the model is ' +
                'redundant. All inputs should only appear once. Found: ' +
                `${this.inputs.map(x => x.name)}`);
        }
        // Check for redundancy in outputs. Unlike inputs, this is only a
        // warning, not an error.
        if (unique$2(this.outputs).length !== this.outputs.length) {
            console.warn('The list of outputs passed to the model is redundant. ' +
                'All outputs should only appear once. Found: ' +
                `${this.outputs.map(x => x.name)}`);
        }
        /*
          List of initial layers (1 to 1 mapping with this.inputs, hence the same
          layer might appear twice)
        */
        this.inputLayers = [];
        this.inputLayersNodeIndices = [];
        this.inputLayersTensorIndices = [];
        /*
          List of layers (1 to 1 mapping with this.outputs, hence the same layer
          might appear twice)
        */
        this.outputLayers = [];
        this.outputLayersNodeIndices = [];
        this.outputLayersTensorIndices = [];
        /*
          All layers in order of horizontal graph traversal. Entries are unique.
          Includes input and output layers.
        */
        this.layers = [];
        /*
          References to container layers that were constructed internally. We need
          these to properly dispose of tensors from nested containers.
        */
        this.internalContainerRefs = [];
        // TODO(michaelterry): Determine if caching still needed with eager
        // backend.
        /*
          This is for performance optimization when calling the Container on new
          inputs. Every time the Container is called on a set on input tensors,
          we compute the output tensors, output masks and output shapes in one pass,
          then cache them here. When one of these outputs is queried later,
          we retrieve it from there instead of recomputing it.
        */
        // this.outputTensorCache = {};
        // this.outputShapeCache = {};
        // Build this.outputLayers:
        for (const x of this.outputs) {
            const layer = x.sourceLayer;
            const nodeIndex = x.nodeIndex;
            const tensorIndex = x.tensorIndex;
            this.outputLayers.push(layer);
            this.outputLayersNodeIndices.push(nodeIndex);
            this.outputLayersTensorIndices.push(tensorIndex);
        }
        // TODO(michaelterry): Add output mask cache code.
        // Build this.inputLayers:
        for (const x of this.inputs) {
            const layer = x.sourceLayer;
            const nodeIndex = x.nodeIndex;
            const tensorIndex = x.tensorIndex;
            /*
              It's supposed to be an input layer, so only one node
              and one tensor output.
            */
            assert(nodeIndex === 0, 'input layer has >1 nodes');
            assert(tensorIndex === 0, 'input layer has >1 tensors');
            this.inputLayers.push(layer);
            this.inputLayersNodeIndices.push(nodeIndex);
            this.inputLayersTensorIndices.push(tensorIndex);
        }
        // Build this.inputNames and this.outputNames.
        this.inputNames = [];
        this.outputNames = [];
        this.feedInputShapes = [];
        this.feedInputNames = [];
        this.feedOutputNames = [];
        for (let i = 0; i < this.inputLayers.length; i++) {
            const layer = this.inputLayers[i];
            // Check that layer is an InputLayer.
            if (!(layer instanceof InputLayer)) {
                throw new TypeError('Input layers to a LayersModel must be InputLayer objects. ' +
                    `Received inputs: ${args.inputs}. ` +
                    `Input ${i} (0-based) originates ` +
                    `from layer type ${layer.getClassName()}.`);
            }
            this.inputNames.push(layer.name);
            this.feedInputShapes.push(layer.batchInputShape);
            this.feedInputNames.push(layer.name);
        }
        for (const layer of this.outputLayers) {
            this.outputNames.push(layer.name);
        }
        this.internalInputShapes = this.inputs.map(x => x.shape);
        this.internalOutputShapes = this.outputs.map(x => x.shape);
        /*
          Container_nodes: set of nodes included in the graph (not all nodes
          included in the layers are relevant to the current graph).
        */
        // ids of all nodes relevant to the Container:
        const nodesDepths = {};
        // To recover nodes from their ID.
        const nodeIDToNode = {};
        const layersDepths = {};
        // To layers from their ID.
        const layerIDToLayer = {};
        const layerIndices = {};
        const nodesInDecreasingDepth = [];
        /**
         * Builds a map of the graph of layers.
         *
         * This recursively updates the map `layerIndices`,
         * the list `nodesInDecreasingDepth` and the set `containerNodes`.
         *
         * @param tensor Some tensor in a graph.
         * @param finishedNodes Set of nodes whose subgraphs have been traversed
         *         completely. Useful to prevent duplicated work.
         * @param nodesInProgress Set of nodes that are currently active on the
         *         recursion stack. Useful to detect cycles.
         * @param layer Layer from which `tensor` comes from. If not provided,
         *   will be obtained from tensor.sourceLayer.
         * @param nodeIndex Node index from which `tensor` comes from.
         * @param tensorIndex TensorIndex from which `tensor` comes from.
         *
         * @exception RuntimeError if a cycle is detected.
         */
        const buildMapOfGraph = (tensor, finishedNodes, nodesInProgress, layer, nodeIndex, tensorIndex) => {
            if (layer == null || nodeIndex == null || tensorIndex == null) {
                layer = tensor.sourceLayer;
                nodeIndex = tensor.nodeIndex;
                tensorIndex = tensor.tensorIndex;
            }
            const node = layer.inboundNodes[nodeIndex];
            // Prevent cycles.
            if (nodesInProgress.indexOf(node) !== -1) {
                throw new RuntimeError(`The tensor ${tensor.name} at layer "${layer.name}" ` +
                    'is part of a cycle.');
            }
            // Don't repeat work for shared subgraphs
            if (finishedNodes.indexOf(node) !== -1) {
                return;
            }
            // Update containerNodes.
            this.containerNodes.add(Container.nodeKey(layer, nodeIndex));
            // Store the traversal order for layer sorting.
            if (!(layer.id in layerIndices)) {
                layerIndices[layer.id] = Object.keys(layerIndices).length;
            }
            if (nodesInProgress.indexOf(node) === -1) {
                nodesInProgress.push(node);
            }
            // Propagate to all previous tensors connected to this node.
            const numInboundLayers = node.inboundLayers.length;
            for (let i = 0; i < numInboundLayers; i++) {
                const x = node.inputTensors[i];
                const layer = node.inboundLayers[i];
                const nodeIndex = node.nodeIndices[i];
                const tensorIndex = node.tensorIndices[i];
                buildMapOfGraph(x, finishedNodes, nodesInProgress, layer, nodeIndex, tensorIndex);
            }
            finishedNodes.push(node);
            // Remove every occurrence of `node` from the in-progress stack.
            while (nodesInProgress.indexOf(node) >= 0) {
                nodesInProgress.splice(nodesInProgress.indexOf(node), 1);
            }
            nodesInDecreasingDepth.push(node);
        };
        const finishedNodes = [];
        const nodesInProgress = [];
        for (const x of this.outputs) {
            buildMapOfGraph(x, finishedNodes, nodesInProgress);
        }
        const reversedNodesInDecreasingDepth = nodesInDecreasingDepth.slice().reverse();
        for (const node of reversedNodesInDecreasingDepth) {
            nodeIDToNode[node.id] = node;
            // If the depth is not set, the node has no outbound nodes (depth 0).
            if (!(node.id in nodesDepths)) {
                nodesDepths[node.id] = 0;
            }
            let depth = nodesDepths[node.id];
            // Update the depth of the corresponding layer
            const previousDepth = (layersDepths[node.outboundLayer.id] == null ?
                0 :
                layersDepths[node.outboundLayer.id]);
            /*
              If we've seen this layer before at a higher depth, we should use that
              depth instead of the node depth. This is necessary for shared layers
              that have inputs at different depth levels in the graph.
            */
            depth = Math.max(depth, previousDepth);
            layersDepths[node.outboundLayer.id] = depth;
            layerIDToLayer[node.outboundLayer.id] = node.outboundLayer;
            nodesDepths[node.id] = depth;
            // Update the depth of inbound nodes.
            for (let i = 0; i < node.inboundLayers.length; i++) {
                const inboundLayer = node.inboundLayers[i];
                const nodeIndex = node.nodeIndices[i];
                const inboundNode = inboundLayer.inboundNodes[nodeIndex];
                const previousDepth = (nodesDepths[inboundNode.id] == null ? 0 :
                    nodesDepths[inboundNode.id]);
                nodesDepths[inboundNode.id] = Math.max(depth + 1, previousDepth);
                nodeIDToNode[inboundNode.id] = inboundNode;
            }
        }
        // Build a dict {depth: list of nodes with this depth}
        const nodesByDepth = {};
        for (const nodeID in nodesDepths) {
            const depth = nodesDepths[nodeID];
            if (!(depth in nodesByDepth)) {
                nodesByDepth[depth] = [];
            }
            nodesByDepth[depth].push(nodeIDToNode[nodeID]);
        }
        // Build a dict {depth: list of layers with this depth}
        const layersByDepth = {};
        for (const layerID in layersDepths) {
            const depth = layersDepths[layerID];
            if (!(depth in layersByDepth)) {
                layersByDepth[depth] = [];
            }
            layersByDepth[depth].push(layerIDToLayer[layerID]);
        }
        // Get sorted list of layer depths.
        let depthKeys = Object.keys(layersByDepth)
            .map(x => parseInt(x, 10))
            .sort(reverseNumberCompare);
        // Set this.layers and this.layersByDepth.
        this.layers = [];
        for (const depth of depthKeys) {
            const layersForDepth = layersByDepth[depth];
            // Container.layers needs to have a deterministic order:
            // here we order them by traversal order.
            layersForDepth.sort((a, b) => {
                const aIndex = layerIndices[a.id];
                const bIndex = layerIndices[b.id];
                if (aIndex < bIndex) {
                    return -1;
                }
                if (aIndex > bIndex) {
                    return 1;
                }
                return 0;
            });
            for (const layer of layersForDepth) {
                if (layer instanceof Container) {
                    // Track nested containers so dispose() can release them.
                    this.internalContainerRefs.push(layer);
                }
                this.layers.push(layer);
            }
        }
        this.layersByDepth = layersByDepth;
        // Get sorted list of node depths;
        depthKeys = Object.keys(nodesByDepth)
            .map(x => parseInt(x, 10))
            .sort(reverseNumberCompare);
        // Check that all tensors required are computable.
        // computable_tensors: all tensors in the graph
        // that can be computed from the inputs provided.
        const computableTensors = this.inputs.slice();
        // To provide a better error msg.
        const layersWithCompleteInput = [];
        for (const depth of depthKeys) {
            for (const node of nodesByDepth[depth]) {
                const layer = node.outboundLayer;
                if (layer != null) {
                    for (const x of node.inputTensors) {
                        if (computableTensors.indexOf(x) === -1) {
                            throw new RuntimeError(`Graph disconnected: cannot obtain value for tensor ${x}` +
                                ` at layer "${layer.name}". ` +
                                'The following previous layers were accessed without ' +
                                `issue: ${layersWithCompleteInput}`);
                        }
                    }
                    for (const x of node.outputTensors) {
                        computableTensors.push(x);
                    }
                    layersWithCompleteInput.push(layer.name);
                }
            }
        }
        // Set this.containerNodes and this.nodesByDepth.
        this.nodesByDepth = nodesByDepth;
        // Ensure name unicity, which will be crucial for serialization
        // (since serialized nodes refer to layers by their name).
        const allNames = this.layers.map(x => x.name);
        for (const name of allNames) {
            const numOccurrences = allNames.filter(x => x === name).length;
            if (numOccurrences !== 1) {
                throw new RuntimeError(`The name "${name}" is used ${numOccurrences} times ` +
                    'in the model. All layer names should be unique. Layer names: ' +
                    JSON.stringify(allNames));
            }
        }
        // Layer parameters.
        // The new container starts with a single inbound node
        // for its inputs, and no outbound nodes.
        // Will be appended to by future calls to apply().
        this.outboundNodes = [];
        // Will be appended to below, and by future calls to apply().
        this.inboundNodes = [];
        // Create the node linking internal inputs to internal outputs.
        // (This call has side effects.)
        // tslint:disable-next-line:no-unused-expression
        new Node({
            outboundLayer: this,
            inboundLayers: [],
            nodeIndices: [],
            tensorIndices: [],
            inputTensors: this.inputs,
            outputTensors: this.outputs,
            inputMasks: this.inputs.map(x => null),
            outputMasks: this.outputs.map(x => null),
            inputShapes: this.inputs.map(x => x.shape),
            outputShapes: this.outputs.map(x => x.shape)
        });
        this.built = true;
        this._refCount = 1; // The ref count of a container always start at 1.
    }
    /**
     * Guard used by public methods: throws if this container's reference
     * count has already dropped to zero (i.e., `dispose()` fully released it).
     */
    assertNotDisposed() {
        if (this._refCount === 0) {
            throw new Error(`Container '${this.name}' is already disposed.`);
        }
    }
42311 /**
42312 * Attempt to dispose a LayersModel's weights.
42313 *
42314 * This method decrease the reference count of the LayersModel object by 1.
42315 *
42316 * A LayersModel is reference-counted. Its reference count is incremented by 1
42317 * when it is first constructed and when it is used as a Layer of another
42318 * LayersModel.
42319 *
42320 * If the reference count of a LayersModel becomes 0, the `dispose` method of
42321 * all its constituent `Layer`s will be called.
42322 *
42323 * Note: If the reference count is greater than 0 after the decrement, the
42324 * `dispose` method of its constituent `Layer`s will *not* be called.
42325 *
42326 * After a LayersModel is disposed, it cannot be used in calls such as
42327 * 'predict`, `evaluate` or `fit` anymore.
42328 *
42329 * @returns A DisposeResult Object with the following fields:
42330 * - refCountAfterDispose: The reference count of the LayersModel after this
42331 * `dispose()` call.
42332 * - numDisposedVariables: Number of `tf.Variable`s (i.e., weights) disposed
42333 * during this `dispose()` call.
42334 * @throws {Error} If the layer is not built yet, or if the LayersModel has
42335 * already been disposed.
42336 */
42337 dispose() {
42338 this.assertNotDisposed();
42339 const result = { refCountAfterDispose: null, numDisposedVariables: 0 };
42340 if (--this._refCount === 0) {
42341 for (const layer of this.layers) {
42342 result.numDisposedVariables += layer.dispose().numDisposedVariables;
42343 }
42344 // Call dispose on each internally created container layer again to ensure
42345 // their refCounts hit zero and their tensors are subsequently deleted.
42346 for (const container of this.internalContainerRefs) {
42347 result.numDisposedVariables += container.dispose().numDisposedVariables;
42348 }
42349 }
42350 result.refCountAfterDispose = this._refCount;
42351 return result;
42352 }
    /** Whether the container's weights are updated during training. */
    get trainable() {
        return this.trainable_;
    }
    // Setting `trainable` propagates the flag to every weight of every
    // constituent layer before recording it on the container itself.
    set trainable(trainable) {
        this.layers.forEach(layer => {
            // tslint:disable-next-line:no-any
            layer._trainableWeights
                .forEach(w => w.trainable = trainable);
        });
        this.trainable_ = trainable;
    }
42364 get trainableWeights() {
42365 // Porting Note: This check below is to prevent errors where the
42366 // _trainableWeights inherited from the parent class (Layer) gets
42367 // inadvertently used.
42368 if (this._trainableWeights.length > 0) {
42369 throw new ValueError('Container instance unexpectedly contains _trainableWeights.' +
42370 'The trainable weights of a Container are a union of the ' +
42371 'trainable weights of its consituent Layers. Its own ' +
42372 '_trainableWeights must remain an empty Array.');
42373 }
42374 if (!this.trainable) {
42375 return [];
42376 }
42377 let weights = [];
42378 for (const layer of this.layers) {
42379 weights = weights.concat(layer.trainableWeights);
42380 }
42381 return weights;
42382 }
42383 get nonTrainableWeights() {
42384 const weights = [];
42385 for (const layer of this.layers) {
42386 weights.push(...layer.nonTrainableWeights);
42387 }
42388 if (!this.trainable) {
42389 const trainableWeights = [];
42390 for (const layer of this.layers) {
42391 trainableWeights.push(...layer.trainableWeights);
42392 }
42393 return trainableWeights.concat(weights);
42394 }
42395 return weights;
42396 }
    /** All weights of the container: trainable followed by non-trainable. */
    get weights() {
        return this.trainableWeights.concat(this.nonTrainableWeights);
    }
    /**
     * Loads all layer weights from a JSON object.
     *
     * Porting Note: HDF5 weight files cannot be directly loaded in JavaScript /
     * TypeScript. The utility script at `scripts/pykeras.py` offers means
     * to convert them into JSON strings compatible with this method.
     * Porting Note: TensorFlow.js Layers supports only loading by name currently.
     *
     * @param weights A JSON mapping weight names to weight values as nested
     *   arrays of numbers, or a `NamedTensorMap`, i.e., a JSON mapping weight
     *   names to `tf.Tensor` objects.
     * @param strict Require that the provided weights exactly match those
     *   required by the container. Default: `true`. Passing `false` means that
     *   extra weights and missing weights will be silently ignored.
     * @throws ValueError on duplicate weight names, and (when `strict`) on
     *   provided weights with no target variable or on unset target weights.
     */
    loadWeights(weights, strict = true) {
        // Map from normalized weight name to the model's weight variable.
        const nameToWeight = {};
        let totalWeightsCount = 0;
        const modelIsKerasSavedModelFormat = isKerasSavedModelFormat(weights);
        if (modelIsKerasSavedModelFormat) {
            // Keras-v3 keys are folder-style; rewrite them in place before
            // name matching (see parseWeights).
            this.parseWeights(weights);
        }
        // Check if weights from keras v3.
        for (const layer of this.layers) {
            for (const [index, weight] of layer.weights.entries()) {
                // Parse the name to layerName/index.
                // e.g. dense/0, dense/1, dense_1/0, dense_1/1
                const parsedName = modelIsKerasSavedModelFormat ?
                    `${weight.name.split('/').slice(0, -1).join('/') + '/'}${index}` :
                    weight.originalName;
                if (nameToWeight[parsedName] != null) {
                    throw new ValueError(`Duplicate weight name: ${parsedName}`);
                }
                nameToWeight[parsedName] = weight;
                totalWeightsCount++;
            }
        }
        // Pair each provided value with its target variable.
        const weightValueTuples = [];
        for (const name in weights) {
            // TF 2.2.0 added cell name to the weight name in the format of
            // layer_name/cell_name/weight_name, we need to remove
            // the inner cell name.
            let validatedName = name;
            if (nameToWeight[name] == null) {
                const tokens = name.split('/');
                const shortenNameArray = tokens.slice(0, -2).concat([tokens[tokens.length - 1]]);
                validatedName = shortenNameArray.join('/');
            }
            if (nameToWeight[validatedName] != null) {
                weightValueTuples.push([nameToWeight[validatedName], weights[name]]);
            }
            else if (strict) {
                throw new ValueError(`Provided weight data has no target variable: ${name}`);
            }
            // Remove matched entries so any leftovers can be reported below.
            delete nameToWeight[validatedName];
        }
        if (strict) {
            // Check that all weights are set.
            const unsetNames = [];
            for (const name in nameToWeight) {
                unsetNames.push(name);
            }
            if (unsetNames.length > 0) {
                throw new ValueError(`${unsetNames.length} of ${totalWeightsCount} weights are not set: ` +
                    `${unsetNames}`);
            }
        }
        // Assign all values in one batched operation.
        batchSetValue(weightValueTuples);
    }
42469 parseWeights(weights) {
42470 for (const key in Object.keys(weights)) {
42471 const listParts = key.split('/');
42472 const list = ['vars', 'layer_checkpoint_dependencies'];
42473 // For keras v3, the weights name are saved based on the folder structure.
42474 // e.g. _backbone/_layer_checkpoint_dependencies/transformer/_self../
42475 // _output_dense/vars/0
42476 // Therefore we discard the `vars` and `layer_checkpoint_depencies` within
42477 // the saved name and only keeps the layer name and weights.
42478 // This can help to mapping the actual name of the layers and load each
42479 // weight accordingly.
42480 const newKey = listParts
42481 .map(str => {
42482 if (str.startsWith('_')) {
42483 return str.slice(1);
42484 }
42485 return str;
42486 })
42487 .filter(str => !list.includes(str))
42488 .join('/');
42489 if (newKey !== key) {
42490 weights[newKey] = weights[key];
42491 delete weights[key];
42492 }
42493 }
42494 }
42495 /**
42496 * Util shared between different serialization methods.
42497 * @returns LayersModel config with Keras version information added.
42498 */
42499 updatedConfig() {
42500 const theConfig = this.getConfig();
42501 const modelConfig = {};
42502 modelConfig['className'] = this.getClassName();
42503 modelConfig['config'] = theConfig;
42504 modelConfig['kerasVersion'] = `tfjs-layers ${version$6}`;
42505 // TODO(nielsene): Replace something like K.backend() once
42506 // possible.
42507 modelConfig['backend'] = 'TensorFlow.js';
42508 return modelConfig;
42509 }
42510 /**
42511 * Returns a JSON string containing the network configuration.
42512 *
42513 * To load a network from a JSON save file, use
42514 * models.modelFromJSON(jsonString);
42515 * @param extraJsonArgs Unused in tfjs-layers, maintained for PyKeras
42516 * @param returnString Whether the return value should be stringified
42517 * (default: `true`).
42518 * @returns a JSON string if `returnString` (default), or a JSON object if
42519 * `!returnString`.
42520 */
42521 // tslint:disable-next-line:no-any
42522 toJSON(unused, returnString = true) {
42523 const modelConfig = convertTsToPythonic(this.updatedConfig());
42524 return returnString ? JSON.stringify(modelConfig) : modelConfig;
42525 }
42526 /**
42527 * Call the model on new inputs.
42528 *
42529 * In this case `call` just reapplies all ops in the graph to the new inputs
42530 * (e.g. build a new computational graph from the provided inputs).
42531 *
42532 * @param inputs A tensor or list of tensors.
42533 * @param mask A mask or list of masks. A mask can be either a tensor or null
42534 * (no mask).
42535 *
42536 * @return A tensor if there is a single output, or a list of tensors if there
42537 * are more than one outputs.
42538 */
42539 call(inputs, kwargs) {
42540 return tidy(() => {
42541 inputs = toList(inputs);
42542 const feedDict = new FeedDict();
42543 for (let i = 0; i < this.inputs.length; ++i) {
42544 feedDict.add(this.inputs[i], inputs[i]);
42545 }
42546 return execute(this.outputs, feedDict, kwargs);
42547 });
42548 }
42549 /**
42550 * Computes an output mask tensor.
42551 *
42552 * @param inputs Tensor or list of tensors.
42553 * @param mask Tensor or list of tensors.
42554 *
42555 * @return null or a tensor (or list of tensors, one per output tensor of the
42556 * layer).
42557 */
42558 computeMask(inputs, mask) {
42559 return tidy(() => {
42560 inputs = toList(inputs);
42561 let masks;
42562 if (mask == null) {
42563 masks = pyListRepeat(null, inputs.length);
42564 }
42565 else {
42566 masks = toList(mask);
42567 }
42568 // TODO(michaelterry): Add support for mask caching.
42569 return this.runInternalGraph(inputs, masks)[1];
42570 });
42571 }
    /**
     * Computes the output shape of the layer.
     *
     * Assumes that the layer will be built to match that input shape provided.
     *
     * @param inputShape A shape (tuple of integers) or a list of shape tuples
     *   (one per output tensor of the layer). Shape tuples can include null for
     *   free dimensions, instead of an integer.
     * @returns A single shape, or a list of shapes when the model has more
     *   than one output.
     * @throws ValueError if the number of provided shapes does not match the
     *   model's number of inputs.
     */
    computeOutputShape(inputShape) {
        const inputShapes = normalizeShapeList(inputShape);
        if (inputShapes.length !== this.inputLayers.length) {
            throw new ValueError(`Invalid inputShape argument ${inputShape}: ` +
                `model has ${this.inputLayers.length} tensor inputs.`);
        }
        // TODO(michaelterry): Add caching
        // Map from "<layerName>_<nodeIndex>_<tensorIndex>" to the computed
        // output shape of that tensor.
        const layersToOutputShapes = {};
        for (let i = 0; i < inputShapes.length; i++) {
            const layer = this.inputLayers[i];
            const inputShape = inputShapes[i];
            // It's an input layer: computeOutputShape is identity,
            // and there is only one node and one tensor output.
            const shapeKey = layer.name + '_0_0';
            layersToOutputShapes[shapeKey] = inputShape;
        }
        // Visit nodes from the deepest level (inputs) toward the outputs.
        const depthKeys = Object.keys(this.nodesByDepth)
            .map(x => parseInt(x, 10))
            .sort(reverseNumberCompare);
        // Iterate over nodes, by depth level.
        if (depthKeys.length > 1) {
            for (const depth of depthKeys) {
                const nodes = this.nodesByDepth[depth];
                for (const node of nodes) {
                    // This is always a single layer, never a list.
                    const layer = node.outboundLayer;
                    if (this.inputLayers.map(x => x.id).indexOf(layer.id) !== -1) {
                        // We've already covered the input layers a few lines above.
                        continue;
                    }
                    // Potentially redundant list, same size of node.inputTensors.
                    const inputShapes = [];
                    for (let j = 0; j < node.inboundLayers.length; j++) {
                        const inboundLayer = node.inboundLayers[j];
                        const nodeIndex = node.nodeIndices[j];
                        const tensorIndex = node.tensorIndices[j];
                        const shapeKey = `${inboundLayer.name}_${nodeIndex}_${tensorIndex}`;
                        const inputShape = layersToOutputShapes[shapeKey];
                        inputShapes.push(inputShape);
                    }
                    // Propagate shapes through this layer and record them for
                    // consumers at shallower depths.
                    const outputShape = layer.computeOutputShape(singletonOrArray(inputShapes));
                    const outputShapes = normalizeShapeList(outputShape);
                    const nodeIndex = layer.inboundNodes.indexOf(node);
                    for (let j = 0; j < outputShapes.length; j++) {
                        const shapeKey = `${layer.name}_${nodeIndex}_${j}`;
                        layersToOutputShapes[shapeKey] = outputShapes[j];
                    }
                }
            }
        }
        // Read final output shapes from layersToOutputShapes.
        const outputShapes = [];
        const outputShapeKeys = [];
        for (let i = 0; i < this.outputLayers.length; i++) {
            const layer = this.outputLayers[i];
            const nodeIndex = this.outputLayersNodeIndices[i];
            const tensorIndex = this.outputLayersTensorIndices[i];
            const shapeKey = `${layer.name}_${nodeIndex}_${tensorIndex}`;
            outputShapeKeys.push(shapeKey);
        }
        for (let i = 0; i < outputShapeKeys.length; i++) {
            const key = outputShapeKeys[i];
            assert(key in layersToOutputShapes);
            outputShapes.push(layersToOutputShapes[key]);
        }
        // TODO(michaelterry): Update cache
        return singletonOrArray(outputShapes);
    }
    /**
     * Computes output tensors for new inputs.
     *
     * Note:
     * - Expects `inputs` to be a list (potentially with 1 element).
     *
     * @param inputs List of tensors
     * @param masks List of masks (tensors or null).
     * @return Three lists: outputTensors, outputMasks, outputShapes
     */
    runInternalGraph(inputs, masks) {
        if (masks == null) {
            masks = pyListRepeat(null, inputs.length);
        }
        // Dictionary mapping reference tensors to tuples
        // (computed tensor, compute mask)
        // we assume a 1:1 mapping from tensor to mask
        // TODO: raise exception when a `.computeMask()` call
        // does not return a list the same size as `call`
        const tensorMap = {};
        for (let i = 0; i < this.inputs.length; ++i) {
            const x = this.inputs[i];
            const y = inputs[i];
            const mask = masks[i];
            tensorMap[x.id] = [y, mask];
        }
        // Visit nodes from the deepest level (inputs) toward the outputs, so
        // every node's inputs are computed before the node itself runs.
        const depthKeys = Object.keys(this.nodesByDepth)
            .map(x => parseInt(x, 10))
            .sort(reverseNumberCompare);
        for (const depth of depthKeys) {
            const nodes = this.nodesByDepth[depth];
            for (const node of nodes) {
                // This is always a single layer, never a list.
                const layer = node.outboundLayer;
                const referenceInputTensors = node.inputTensors;
                const referenceOutputTensors = node.outputTensors;
                // If all previous input tensors are available in tensorMap,
                // then call node.inboundLayer on them.
                // List of tuples [input, mask]:
                const computedData = new Array();
                for (const x of referenceInputTensors) {
                    if (x.id in tensorMap) {
                        computedData.push(tensorMap[x.id]);
                    }
                }
                if (computedData.length === referenceInputTensors.length) {
                    // TODO(michaelterry): Add K.name_scope here, if we need it.
                    let kwargs = {};
                    let computedTensors;
                    let computedMasks;
                    let outputTensors;
                    let outputMasks;
                    // call layer
                    if (node.callArgs != null) {
                        kwargs = node.callArgs;
                    }
                    if (computedData.length === 1) {
                        // Single-input layer: call with a bare tensor.
                        const [computedTensor, computedMask] = computedData[0];
                        if (kwargs['mask'] == null) {
                            kwargs['mask'] = computedMask;
                        }
                        outputTensors =
                            toList(layer.call(computedTensor, kwargs));
                        outputMasks = toList(layer.computeMask(computedTensor, computedMask));
                        computedTensors = [computedTensor];
                        computedMasks = [computedMask];
                    }
                    else {
                        // Multi-input layer: call with lists of tensors/masks.
                        computedTensors = computedData.map(x => x[0]);
                        computedMasks = computedData.map(x => x[1]);
                        if (kwargs['mask'] == null) {
                            kwargs['mask'] = computedMasks;
                        }
                        outputTensors =
                            toList(layer.call(computedTensors, kwargs));
                        outputMasks = toList(layer.computeMask(computedTensors, computedMasks));
                    }
                    if (layer.activityRegularizer) {
                        throw new NotImplementedError('LayersModel invocation with concrete Tensor value(s) in the ' +
                            'presence of activity regularizer(s) is not supported yet.');
                    }
                    // TODO(michaelterry): Add model updates and losses
                    // Update tensor map.
                    for (let i = 0; i < referenceOutputTensors.length; ++i) {
                        const x = referenceOutputTensors[i];
                        const y = outputTensors[i];
                        const mask = outputMasks[i];
                        tensorMap[x.id] = [y, mask];
                    }
                }
            }
        }
        // Collect the concrete tensors/masks/shapes for the model outputs.
        const outputTensors = [];
        const outputMasks = [];
        const outputShapes = [];
        for (const x of this.outputs) {
            assert(x.id in tensorMap, `Could not compute output ${x.name} : ${x.id}`);
            const [tensor, mask] = tensorMap[x.id];
            outputShapes.push(tensor.shape);
            outputTensors.push(tensor);
            outputMasks.push(mask);
        }
        // TODO(michaelterry): Add support for caches.
        return [outputTensors, outputMasks, outputShapes];
    }
    /**
     * Builds a map of internal node keys to node ordering.
     * Used in serialization as node orderings may change as unused nodes are
     * dropped. Porting Note: This helper method was pulled out of getConfig to
     * improve readability.
     *
     * NOTE(review): the `layers` parameter is never referenced — the method
     * iterates `this.layers` instead. The visible caller (getConfig) passes
     * `this.layers`, so behavior matches, but the parameter is effectively
     * unused.
     *
     * @param layers An array of Layers in the model.
     * @returns Map of Node Keys to index order within the layer.
     */
    buildNodeConversionMap(layers) {
        const nodeConversionMap = {};
        let keptNodes;
        for (const layer of this.layers) {
            // Kept-node numbering starts at 1 for nested Containers and at 0
            // for plain Layers — presumably to skip the Container's internal
            // node; TODO confirm against the Keras serialization format.
            keptNodes = layer instanceof Container ? 1 : 0;
            for (let originalNodeIndex = 0; originalNodeIndex < layer.inboundNodes.length; originalNodeIndex++) {
                const nodeKey = Container.nodeKey(layer, originalNodeIndex);
                if (this.containerNodes.has(nodeKey)) {
                    // i.e. we mark it to be saved
                    nodeConversionMap[nodeKey] = keptNodes;
                    keptNodes += 1;
                }
            }
        }
        return nodeConversionMap;
    }
42778 getLayer(nameOrIndex, index) {
42779 if (index != null) {
42780 return this.findLayer(index);
42781 }
42782 else {
42783 if (nameOrIndex == null) {
42784 throw new ValueError('Provide either a layer name or layer index');
42785 }
42786 if (typeof nameOrIndex === 'number') {
42787 return this.findLayer(nameOrIndex);
42788 }
42789 }
42790 for (const layer of this.layers) {
42791 if (layer.name === nameOrIndex) {
42792 return layer;
42793 }
42794 }
42795 throw new ValueError(`No such layer: ${nameOrIndex}`);
42796 }
42797 findLayer(index) {
42798 if (this.layers.length <= index) {
42799 throw new ValueError(`Was asked to retrieve layer at index ${index}, but model only ` +
42800 `has ${this.layers.length} layer(s).`);
42801 }
42802 else {
42803 return this.layers[index];
42804 }
42805 }
42806 /**
42807 * Retrieves the Container's current loss values.
42808 *
42809 * Used for regularizers during training.
42810 */
42811 calculateLosses() {
42812 // Porting Node: This is an augmentation to Container.loss in PyKeras.
42813 // In PyKeras, Container.loss returns symbolic tensors. Here a concrete
42814 // Tensor (specifically Scalar) values are returned. This is due to the
42815 // imperative backend.
42816 return tidy(() => {
42817 const losses = [];
42818 for (const layer of this.layers) {
42819 for (let nodeIndex = 0; nodeIndex < layer.inboundNodes.length; ++nodeIndex) {
42820 const nodeKey = Container.nodeKey(layer, nodeIndex);
42821 if (this.containerNodes.has(nodeKey)) {
42822 losses.push(...layer.calculateLosses());
42823 }
42824 }
42825 }
42826 // TODO(cais): Add any unconditional model-level losses?
42827 return losses;
42828 });
42829 }
    /**
     * Returns the config of the container as a serializable object.
     *
     * The config contains the container's name, the serialized configs of its
     * layers (with their kept inbound nodes), and the model's input/output
     * layer references, remapped through `buildNodeConversionMap`.
     */
    getConfig() {
        const config = { name: this.name };
        // Build a map from layer unique name (self._node_key)
        // to the index of the nodes that are saved in the config.
        // Only nodes in container_nodes are saved.
        const nodeConversionMap = this.buildNodeConversionMap(this.layers);
        // Serialize and save the layers in layerConfigs
        const layerConfigs = [];
        for (const layer of this.layers) {
            const layerClassName = layer.getClassName();
            const layerConfig = layer.getConfig();
            const filteredInboundNodes = [];
            for (let originalNodeIndex = 0; originalNodeIndex < layer.inboundNodes.length; originalNodeIndex++) {
                const node = layer.inboundNodes[originalNodeIndex];
                const nodeKey = Container.nodeKey(layer, originalNodeIndex);
                let kwargs = {};
                if (this.containerNodes.has(nodeKey)) {
                    // The node is relevant to the model:
                    // add to filteredInboundNodes.
                    if (node.callArgs) {
                        // Only keep callArgs if they round-trip through JSON;
                        // otherwise warn and drop them.
                        try {
                            JSON.stringify(node.callArgs);
                            kwargs = node.callArgs;
                        }
                        catch (err) {
                            console.warn(`Layer ${layer.name} was passed ` +
                                `non-serializable keyword arguments: ` +
                                `${node.callArgs}. They will not be included ` +
                                `in the serialized model (and thus will be ` +
                                `missing at deserialization time).`);
                            kwargs = {};
                        }
                    }
                    if (node.inboundLayers.length > 0) {
                        // Record each inbound connection as
                        // [layerName, newNodeIndex, tensorIndex, kwargs].
                        const nodeData = [];
                        for (let i = 0; i < node.inboundLayers.length; i++) {
                            const inboundLayer = node.inboundLayers[i];
                            const nodeIndex = node.nodeIndices[i];
                            const tensorIndex = node.tensorIndices[i];
                            const nodeKey = Container.nodeKey(inboundLayer, nodeIndex);
                            let newNodeIndex = nodeConversionMap[nodeKey];
                            if (newNodeIndex == null) {
                                newNodeIndex = 0;
                            }
                            nodeData.push([inboundLayer.name, newNodeIndex, tensorIndex, kwargs]);
                        }
                        filteredInboundNodes.push(nodeData);
                    }
                }
            }
            const dict = {};
            dict['name'] = layer.name;
            dict['className'] = layerClassName;
            dict['config'] = layerConfig;
            dict['inboundNodes'] = filteredInboundNodes;
            layerConfigs.push(dict);
        }
        config['layers'] = layerConfigs;
        // Gather info about inputs and outputs
        const modelInputs = [];
        for (let i = 0; i < this.inputLayers.length; i++) {
            const layer = this.inputLayers[i];
            const nodeIndex = this.inputLayersNodeIndices[i];
            const nodeKey = Container.nodeKey(layer, nodeIndex);
            if (!this.containerNodes.has(nodeKey)) {
                continue;
            }
            let newNodeIndex = nodeConversionMap[nodeKey];
            if (newNodeIndex === null || newNodeIndex === undefined) {
                newNodeIndex = 0;
            }
            const tensorIndex = this.inputLayersTensorIndices[i];
            modelInputs.push([layer.name, newNodeIndex, tensorIndex]);
        }
        config['inputLayers'] = modelInputs;
        const modelOutputs = [];
        for (let i = 0; i < this.outputLayers.length; i++) {
            const layer = this.outputLayers[i];
            const nodeIndex = this.outputLayersNodeIndices[i];
            const nodeKey = Container.nodeKey(layer, nodeIndex);
            if (!this.containerNodes.has(nodeKey)) {
                continue;
            }
            let newNodeIndex = nodeConversionMap[nodeKey];
            if (newNodeIndex === null || newNodeIndex === undefined) {
                newNodeIndex = 0;
            }
            const tensorIndex = this.outputLayersTensorIndices[i];
            modelOutputs.push([layer.name, newNodeIndex, tensorIndex]);
        }
        config['outputLayers'] = modelOutputs;
        return config;
    }
    /**
     * Instantiates a LayersModel from its config (output of `get_config()`).
     * @param cls the class to create
     * @param config LayersModel config dictionary.
     * @param customObjects An optional dictionary of custom objects.
     * @param fastWeightInit Optional flag to use fast weight initialization
     *   during deserialization. This is applicable to cases in which
     *   the initialization will be immediately overwritten by loaded weight
     *   values. Default: `false`.
     * @returns A LayersModel instance.
     * @throws ValueError: In case of improperly formatted config dict.
     */
    /** @nocollapse */
    static fromConfig(cls, config, customObjects = {}, fastWeightInit = false) {
        // Layer instances created during
        // the graph reconstruction process
        const createdLayers = {};
        // Dictionary mapping layer instances to
        // node data that specifies a layer call.
        // It acts as a queue that maintains any unprocessed
        // layer call until it becomes possible to process it
        // (i.e. until the input tensors to the call all exist).
        const unprocessedNodes = {};
        // Enqueue a node-data entry under its layer's name.
        function addUnprocessedNode(layer, nodeData) {
            if (!(layer.name in unprocessedNodes)) {
                unprocessedNodes[layer.name] = [nodeData];
            }
            else {
                unprocessedNodes[layer.name].push(nodeData);
            }
        }
        // Apply `layer` to the tensors described by `nodeData`, re-enqueuing
        // the entry if any inbound tensor does not exist yet.
        function processNode(layer, nodeData) {
            const inputTensors = [];
            let kwargs;
            for (const inputData of nodeData) {
                const inboundLayerName = inputData[0];
                const inboundNodeIndex = inputData[1];
                const inboundTensorIndex = inputData[2];
                kwargs = inputData[3] == null ?
                    {} :
                    inputData[3];
                if (!(inboundLayerName in createdLayers)) {
                    addUnprocessedNode(layer, nodeData);
                    return;
                }
                const inboundLayer = createdLayers[inboundLayerName];
                if (inboundLayer.inboundNodes.length <= inboundNodeIndex) {
                    addUnprocessedNode(layer, nodeData);
                    return;
                }
                const inboundNode = inboundLayer.inboundNodes[inboundNodeIndex];
                inputTensors.push(inboundNode.outputTensors[inboundTensorIndex]);
            }
            // Call layer on its inputs, thus creating the node
            // and building the layer if needed.
            // Note: This has Eager vs Graph Implications.
            if (inputTensors.length > 0) {
                layer.apply(singletonOrArray(inputTensors), kwargs); // was ** kwargs
            }
        }
        /**
         * Deserialize a layer, then call it on appropriate inputs.
         * @param layerData: layer config dict.
         * @throws ValueError: In case of improperly formatted `layer_data`
         *   dict.
         */
        function processLayer(layerData) {
            const layerName = layerData['name'];
            // Instantiate layer.
            const layer = deserialize(layerData, config['customObjects'] != null ?
                config['customObjects'] :
                {});
            layer.setFastWeightInitDuringBuild(fastWeightInit);
            createdLayers[layerName] = layer;
            // Gather layer inputs.
            const inboundNodesData = layerData['inboundNodes'];
            inboundNodesData.forEach(nodeData => {
                if (!(nodeData instanceof Array)) {
                    throw new ValueError(`Corrupted configuration, expected array for nodeData: ${nodeData}`);
                }
                // We don't process nodes (i.e. make layer calls)
                // on the fly because the inbound node may not yet exist,
                // in case of layer shared at different topological depths
                // (e.g.a model such as A(B(A(B(x)))))
                addUnprocessedNode(layer, nodeData);
            });
        }
        // First, we create all layers and enqueue nodes to be processed.
        const name = config['name'];
        const layersFromConfig = config['layers'];
        for (const layerData of layersFromConfig) {
            processLayer(layerData);
        }
        // Then we process nodes in order of layer depth.
        // Nodes that cannot yet be processed(if the inbound node
        // does not yet exist) are re - enqueued, and the process
        // is repeated until all nodes are processed.
        while (!isObjectEmpty(unprocessedNodes)) {
            for (const layerData of layersFromConfig) {
                const layer = createdLayers[layerData['name']];
                if (layer.name in unprocessedNodes) {
                    const currentUnprocessedNodesForLayer = unprocessedNodes[layer.name];
                    delete unprocessedNodes[layer.name];
                    for (const nodeData of currentUnprocessedNodesForLayer) {
                        processNode(layer, nodeData);
                    }
                }
            }
        }
        // Resolve the model's input and output tensors from the reconstructed
        // layers, using the [layerName, nodeIndex, tensorIndex] triples.
        const inputTensors = [];
        const outputTensors = [];
        const inputLayersFromConfig = config['inputLayers'];
        for (const layerData of inputLayersFromConfig) {
            const layerName = layerData[0];
            const nodeIndex = layerData[1];
            const tensorIndex = layerData[2];
            assert(layerName in createdLayers);
            const layer = createdLayers[layerName];
            const layerOutputTensors = layer.inboundNodes[nodeIndex].outputTensors;
            inputTensors.push(layerOutputTensors[tensorIndex]);
        }
        const outputLayersFromConfig = config['outputLayers'];
        for (const layerData of outputLayersFromConfig) {
            const layerName = layerData[0];
            const nodeIndex = layerData[1];
            const tensorIndex = layerData[2];
            assert(layerName in createdLayers);
            const layer = createdLayers[layerName];
            const layerOutputTensors = layer.inboundNodes[nodeIndex].outputTensors;
            outputTensors.push(layerOutputTensors[tensorIndex]);
        }
        return new cls({ inputs: inputTensors, outputs: outputTensors, name });
    }
43056 /**
43057 * Determine whether the container is stateful.
43058 *
43059 * Porting Note: this is the equivalent of the stateful @property of
43060 * the Container class in PyKeras.
43061 */
43062 get stateful() {
43063 // Porting Note: This check is to prevent inadvertent setting of the
43064 // _stateful property of the Container instance.
43065 if (this._stateful) {
43066 throw new ValueError('Container instance unexpectedly has _stateful = true. The ' +
43067 'statefulness of a Container is determined by the Layers it ' +
43068 'contains. Its _stateful property must remain the default false.');
43069 }
43070 for (const layer of this.layers) {
43071 if (layer.stateful) {
43072 return true;
43073 }
43074 }
43075 return false;
43076 }
43077 /**
43078 * Reset the state of all stateful constituent layers (if any).
43079 *
43080 * Examples of stateful layers include RNN layers whose `stateful` property
43081 * is set as `true`.
43082 */
43083 resetStates() {
43084 tidy(() => {
43085 this.layers.forEach(layer => {
43086 // tslint:disable:no-any
43087 if (layer.stateful) {
43088 layer.resetStates();
43089 }
43090 // tslint:enable:no-any
43091 });
43092 });
43093 }
43094 }
43095
43096 /**
43097 * @license
43098 * Copyright 2018 Google LLC
43099 *
43100 * Use of this source code is governed by an MIT-style
43101 * license that can be found in the LICENSE file or at
43102 * https://opensource.org/licenses/MIT.
43103 * =============================================================================
43104 */
43105 function standardizeSampleOrClassWeights(xWeight, outputNames, weightType) {
43106 const numOutputs = outputNames.length;
43107 if (xWeight == null || (Array.isArray(xWeight) && xWeight.length === 0)) {
43108 return outputNames.map(name => null);
43109 }
43110 if (numOutputs === 1) {
43111 if (Array.isArray(xWeight) && xWeight.length === 1) {
43112 return xWeight;
43113 }
43114 else if (typeof xWeight === 'object' && outputNames[0] in xWeight) {
43115 return [xWeight[outputNames[0]]];
43116 }
43117 else {
43118 return [xWeight];
43119 }
43120 }
43121 if (Array.isArray(xWeight)) {
43122 if (xWeight.length !== numOutputs) {
43123 throw new Error(`Provided ${weightType} is an array of ${xWeight.length} ` +
43124 `element(s), but the model has ${numOutputs} outputs. ` +
43125 `Make sure a set of weights is provided for each model output.`);
43126 }
43127 return xWeight;
43128 }
43129 else if (typeof xWeight === 'object' && Object.keys(xWeight).length > 0 &&
43130 typeof xWeight[Object.keys(xWeight)[0]] ===
43131 'object') {
43132 const output = [];
43133 outputNames.forEach(outputName => {
43134 if (outputName in xWeight) {
43135 output.push(xWeight[outputName]);
43136 }
43137 else {
43138 output.push(null);
43139 }
43140 });
43141 return output;
43142 }
43143 else {
43144 throw new Error(`The model has multiple (${numOutputs}) outputs, ` +
43145 `so ${weightType} must be either an array with ` +
43146 `${numOutputs} elements or an object with ${outputNames} keys. ` +
43147 `Provided ${weightType} not understood: ${JSON.stringify(xWeight)}`);
43148 }
43149 }
43150 /**
43151 * Standardize class weighting objects.
43152 *
43153 * This function takes a single class-weighting object, an array of them,
43154 * or a map from output name to class-weighting object. It compares it to the
43155 * output name(s) of the model, base on which it outputs an array of
43156 * class-weighting objects of which the length matches the number of outputs.
43157 *
43158 * @param classWeight Input class-weighting object(s).
43159 * @param outputNames All output name(s) of the model.
43160 * @return An array of class-weighting objects. The length of the array matches
43161 * the model's number of outputs.
43162 */
43163 function standardizeClassWeights(classWeight, outputNames) {
43164 return standardizeSampleOrClassWeights(classWeight, outputNames, 'classWeight');
43165 }
43166 function standardizeSampleWeights(classWeight, outputNames) {
43167 return standardizeSampleOrClassWeights(classWeight, outputNames, 'sampleWeight');
43168 }
    /**
     * Standardize by-sample and/or by-class weights for training.
     *
     * Note that this function operates on one model output at a time. For a model
     * with multiple outputs, you must call this function multiple times.
     *
     * @param y The target tensor that the by-sample and/or by-class weight is for.
     *   The values of y are assumed to encode the classes, either directly
     *   as an integer index, or as one-hot encoding.
     * @param sampleWeight By-sample weights. Not implemented yet: any non-null
     *   value causes an Error.
     * @param classWeight By-class weights: an object mapping class indices
     *   (integers) to a weight (float) to apply to the model's loss for the
     *   samples from this class during training. This can be useful to tell the
     *   model to "pay more attention" to samples from an under-represented class.
     * @param sampleWeightMode The mode for the sample weights. Not implemented
     *   yet: any non-null value causes an Error.
     * @return A Promise of weight tensor, of which the size of the first dimension
     *   matches that of `y`; `null` when no `classWeight` is given.
     * @throws Error when sample weighting is requested, when `y` has an
     *   unsupported rank/last-dimension size, or when a class present in `y`
     *   is missing from `classWeight`.
     */
    async function standardizeWeights(y, sampleWeight, classWeight, sampleWeightMode) {
        if (sampleWeight != null || sampleWeightMode != null) {
            // TODO(cais): Once 'temporal' mode is implemented, document it in the doc
            // string.
            throw new Error('Support sampleWeight is not implemented yet');
        }
        if (classWeight != null) {
            // Apply class weights per sample.
            // Derive a rank-1 tensor of class indices from `y` inside tidy() so
            // intermediates are cleaned up.
            const yClasses = tidy(() => {
                if (y.shape.length === 1) {
                    // Assume class indices.
                    return clone(y);
                }
                else if (y.shape.length === 2) {
                    if (y.shape[1] > 1) {
                        // Assume one-hot encoding of classes.
                        const axis = 1;
                        return argMax$2(y, axis);
                    }
                    else if (y.shape[1] === 1) {
                        // Class index.
                        return reshape$3(y, [y.shape[0]]);
                    }
                    else {
                        throw new Error(`Encountered unexpected last-dimension size (${y.shape[1]}) ` +
                            `during handling of class weights. The size is expected to be ` +
                            `>= 1.`);
                    }
                }
                else {
                    throw new Error(`Unexpected rank of target (y) tensor (${y.rank}) during ` +
                        `handling of class weights. The rank is expected to be 1 or 2.`);
                }
            });
            const yClassIndices = Array.from(await yClasses.data());
            // yClasses survived tidy() (it was returned); dispose it explicitly
            // once its values have been read.
            dispose(yClasses);
            const classSampleWeight = [];
            yClassIndices.forEach(classIndex => {
                if (classWeight[classIndex] == null) {
                    throw new Error(`classWeight must contain all classes in the training data. ` +
                        `The class ${classIndex} exists in the data but not in ` +
                        `classWeight`);
                }
                else {
                    classSampleWeight.push(classWeight[classIndex]);
                }
            });
            return tensor1d(classSampleWeight, 'float32');
        }
        else {
            return null;
        }
    }
43240 /**
43241 * Apply per-sample weights on the loss values from a number of samples.
43242 *
43243 * @param losses Loss tensor of shape `[batchSize]`.
43244 * @param sampleWeights Per-sample weight tensor of shape `[batchSize]`.
43245 * @returns Tensor of the same shape as`losses`.
43246 */
43247 function computeWeightedLoss(losses, sampleWeights) {
43248 return mul(losses, sampleWeights);
43249 }
43250
43251 /**
43252 * @license
43253 * Copyright 2018 Google LLC
43254 *
43255 * Use of this source code is governed by an MIT-style
43256 * license that can be found in the LICENSE file or at
43257 * https://opensource.org/licenses/MIT.
43258 * =============================================================================
43259 */
    // Default batch size used during tensor-based validation in fitDataset(),
    // when the caller does not specify `validationBatchSize`.
    const DEFAULT_VALIDATION_BATCH_SIZE = 32;
43262 /**
43263 * Standardize the output of a dataset iterator for use by
43264 * LayersModel.fitDataset().
43265 *
43266 * @param model: A `tf.LayersModel` object.
43267 * @param iteratorOut The output of a dataset iterator. It is required to be
43268 * an object of the form `{xs: TensorOrArrayOrMap, ys:
43269 * TensorOrArrayOrMap}`, where `TensorOrArrayOrMap` is a single `tf.Tensor`,
43270 * a `tf.Tensor[]`, or a flat map from string names to `tf.Tensor`s.
43271 * @returns A flat array of `tf.Tensor` objects: the input `tf.Tensor`s
43272 * followed by the target `tf.Tensor`s. When `tf.Tensor`s are provided
43273 * as a map, the order in the resulting array is taken from the `inputNames`
43274 * and `outputNames` of the model.
43275 */
43276 function standardizeDataIteratorOutput(
43277 // Type `model` as `any` here to avoid circular dependency w/
43278 // training.ts.
43279 // tslint:disable-next-line:no-any
43280 model, iteratorOut) {
43281 let xs;
43282 let ys;
43283 const iteratorOutObj = iteratorOut;
43284 xs = iteratorOutObj['xs'];
43285 ys = iteratorOutObj['ys'];
43286 assert$1(xs != null && ys != null, () => 'A Dataset iterator for fitDataset() is expected to generate ' +
43287 'objects of the form `{xs: xVal, ys: yVal}`, where the two ' +
43288 'values may be `tf.Tensor`, an array of Tensors, or a map of ' +
43289 'string to Tensor. The provided Dataset instead generates ' +
43290 `${iteratorOut}`);
43291 const flattenedXs = flattenTensorOrArrayOrMap('input', model.inputNames, xs);
43292 const flattenedYs = flattenTensorOrArrayOrMap('output', model.outputNames, ys);
43293 const batchSize = flattenedXs[0].shape[0];
43294 assert$1(flattenedXs.length === model.inputs.length, () => `LayersModel has ${model.inputs.length} inputs, but the dataset ` +
43295 `provides ${flattenedXs.length} inputs. (Expected input keys: ` +
43296 `${JSON.stringify(model.inputNames)})`);
43297 assert$1(flattenedYs.length === model.outputs.length, () => `LayersModel has ${model.outputs.length} outputs, but the dataset ` +
43298 `provides ${flattenedYs.length} outputs. (Expected output keys: ` +
43299 `${JSON.stringify(model.outputNames)})`);
43300 for (let xIndex = 0; xIndex < flattenedXs.length; xIndex++) {
43301 assert$1(flattenedXs[xIndex].shape[0] === batchSize, () => `Batch size mismatch: input ` +
43302 `${model.inputNames[xIndex]} has ${flattenedXs[xIndex].shape[0]}; ` +
43303 `expected ${batchSize} based on input ${model.inputNames[0]}.`);
43304 }
43305 for (let yIndex = 0; yIndex < flattenedYs.length; yIndex++) {
43306 assert$1(flattenedYs[yIndex].shape[0] === batchSize, () => `Batch size mismatch: output ` +
43307 `${model.outputNames[yIndex]} has ${flattenedYs[yIndex].shape[0]}; ` +
43308 `expected ${batchSize} based on input ${model.inputNames[0]}.`);
43309 }
43310 return { xs: flattenedXs, ys: flattenedYs };
43311 }
43312 function flattenTensorOrArrayOrMap(inputOrOutput, names, values) {
43313 if (values instanceof Tensor) {
43314 return [values];
43315 }
43316 else if (Array.isArray(values)) {
43317 assert$1(values.length === names.length, () => `Received an array of ${values.length} Tensors, but expected ${names.length} to match the ${inputOrOutput} keys ${names}.`);
43318 return values;
43319 }
43320 else {
43321 const result = [];
43322 // Check that all the required keys are available.
43323 for (const name of names) {
43324 if (values[name] == null) {
43325 throw new ValueError(`The feature data generated by the dataset lacks the required ` +
43326 `${inputOrOutput} key '${name}'.`);
43327 }
43328 result.push(values[name]);
43329 }
43330 return result;
43331 }
43332 }
43333 function standardizeTensorValidationData(data) {
43334 if (data.length === 3) {
43335 throw new NotImplementedError('Validation with sample weights is not implemented yet.');
43336 }
43337 return { xs: data[0], ys: data[1] };
43338 }
    /**
     * Train a model on a dataset, as invoked by `LayersModel.fitDataset()`.
     *
     * @param model The `tf.LayersModel` to train. Typed as `any` to avoid a
     *     circular dependency with training.ts.
     * @param dataset A dataset whose iterator yields `{xs, ys}` batches.
     * @param args Training configuration (epochs, batchesPerEpoch, callbacks,
     *     validationData, classWeight, etc.).
     * @returns The model's `History` object holding per-epoch loss/metric logs.
     */
    async function fitDataset(
    // Type `model` as `any` here to avoid circular dependency w/
    // training.ts.
    // tslint:disable-next-line:no-any
    model, dataset, args) {
        const hasBatchesPerEpoch = args.batchesPerEpoch != null;
        // Validate the model and configuration before any work begins.
        assert$1(model.optimizer != null, () => 'You must compile a model before training/testing. Use ' +
            'LayersModel.compile(modelCompileConfig).');
        assert$1(args != null, () => `For fitDataset(), the 2nd argument (config) is required, ` +
            `but it is not provided in this call.`);
        assert$1(args.epochs != null && args.epochs > 0 && Number.isInteger(args.epochs), () => `For fitDataset(), config.epochs is expected to be a positive ` +
            `integer, but got ${args.epochs}`);
        assert$1(!hasBatchesPerEpoch ||
            (args.batchesPerEpoch > 0 && Number.isInteger(args.batchesPerEpoch)), () => `For fitDataset(), config.batchesPerEpoch is expected to be a ` +
            `positive integer if specified, but got ${args.batchesPerEpoch}`);
        assert$1(
        // tslint:disable-next-line:no-any
        args['validationSplit'] == null, () => '`validationSplit` is not supported by `fitDataset()`. ' +
            'Use validationData instead.');
        // Only one fit()/fitDataset() call may be in flight at a time.
        if (model.isTraining) {
            throw new Error('Cannot start training because another fit() call is ongoing.');
        }
        model.isTraining = true;
        try {
            const doValidation = args.validationData != null;
            let valXs;
            let valYs;
            if (doValidation) {
                if (isDatasetObject(args.validationData)) {
                    // Dataset-based validation: `validationBatches`, when
                    // given, must be a positive integer.
                    assert$1(args.validationBatches == null ||
                        (args.validationBatches > 0 &&
                            Number.isInteger(args.validationBatches)), () => `For fitDataset() with dataset-based validation, ` +
                        `config.validationBatches is expected not to be provided, ` +
                        `or to be a positive integer, ` +
                        `but got ${args.validationBatches}`);
                }
                else {
                    // Tensor-based validation data: split into xs and ys.
                    const validationData = standardizeTensorValidationData(args.validationData);
                    valXs = validationData.xs;
                    valYs = validationData.ys;
                }
            }
            const trainFunction = model.makeTrainFunction();
            const outLabels = model.getDedupedMetricsNames();
            let callbackMetrics;
            if (doValidation) {
                // Validation metrics are reported under a `val_` prefix.
                callbackMetrics =
                    outLabels.slice().concat(outLabels.map(n => 'val_' + n));
            }
            else {
                callbackMetrics = outLabels.slice();
            }
            const callbacks = standardizeCallbacks(args.callbacks, args.yieldEvery);
            const verbose = args.verbose == null ? 1 : args.verbose;
            const { callbackList, history } = configureCallbacks(callbacks, verbose, args.epochs, null, null, getStepsPerEpoch(dataset, args), null, // Batch size determined by the dataset itself.
            doValidation, callbackMetrics);
            callbackList.setModel(model);
            model.history = history;
            await callbackList.onTrainBegin();
            model.stopTraining_ = false;
            let epoch = args.initialEpoch == null ? 0 : args.initialEpoch;
            let dataIterator = await dataset.iterator();
            while (epoch < args.epochs) {
                const epochLogs = {};
                await callbackList.onEpochBegin(epoch);
                let stepsDone = 0;
                let batchIndex = 0;
                if (!hasBatchesPerEpoch) {
                    // Without `batchesPerEpoch`, each epoch consumes a fresh
                    // iterator until the dataset is exhausted.
                    dataIterator = await dataset.iterator();
                }
                while (hasBatchesPerEpoch ? stepsDone < args.batchesPerEpoch : true) {
                    const iteratorOut = await dataIterator.next();
                    // If `batchesPerEpoch` is specified, the dataset should not be
                    // exhausted until all epochs are done.
                    if (hasBatchesPerEpoch && iteratorOut.done) {
                        console.warn('You provided `batchesPerEpoch` as ' +
                            `${args.batchesPerEpoch}, ` +
                            'but your dataset iterator ran out of data after ' +
                            `${stepsDone} batches; ` +
                            'interrupting training. Make sure that your ' +
                            'dataset can generate at least `batchesPerEpoch * epochs` ' +
                            'batches (in this case, ' +
                            `${args.batchesPerEpoch * args.epochs} batches). ` +
                            'You may need to use the repeat() function when building ' +
                            'your dataset.');
                        break;
                    }
                    if (iteratorOut.value != null) {
                        const { xs, ys } = standardizeDataIteratorOutput(model, iteratorOut.value);
                        const batchLogs = {};
                        batchLogs['batch'] = batchIndex;
                        batchLogs['size'] = xs[0].shape[0];
                        await callbackList.onBatchBegin(batchIndex, batchLogs);
                        const sampleWeights = [];
                        if (args.classWeight != null) {
                            // Convert classWeight maps into per-sample weight
                            // tensors, one per model output.
                            const standardClassWeights = standardizeClassWeights(args.classWeight, model.outputNames);
                            for (let i = 0; i < standardClassWeights.length; ++i) {
                                sampleWeights.push(await standardizeWeights(ys[i], null, standardClassWeights[i]));
                            }
                        }
                        // Train on batch.
                        const ins = xs.concat(ys).concat(sampleWeights);
                        const outs = trainFunction(ins);
                        dispose(ins);
                        // keep() the per-batch metric tensors so they survive
                        // until the onBatchEnd callbacks have consumed them.
                        for (let i = 0; i < outLabels.length; ++i) {
                            const label = outLabels[i];
                            const out = outs[i];
                            batchLogs[label] = out;
                            keep(out);
                        }
                        await callbackList.onBatchEnd(batchIndex, batchLogs);
                        disposeTensorsInLogs(batchLogs);
                        batchIndex++;
                        stepsDone++;
                    }
                    if (hasBatchesPerEpoch ? stepsDone >= args.batchesPerEpoch :
                        iteratorOut.done) {
                        // Epoch finished. Perform validation.
                        if (doValidation) {
                            let valOuts;
                            if (isDatasetObject(args.validationData)) {
                                valOuts = toList(await model.evaluateDataset(args.validationData, { batches: args.validationBatches }));
                            }
                            else {
                                valOuts = toList(model.evaluate(valXs, valYs, {
                                    batchSize: args.validationBatchSize == null ?
                                        DEFAULT_VALIDATION_BATCH_SIZE :
                                        args.validationBatchSize,
                                    verbose: 0
                                }));
                            }
                            for (let i = 0; i < model.metricsNames.length; ++i) {
                                epochLogs[`val_${model.metricsNames[i]}`] = valOuts[i];
                            }
                        }
                        // Call `break` to exit one epoch loop after validation is done. If
                        // config.batchesPerEpoch is specified, an epoch while loop will
                        // stop when `stepsDone >= config.batchesPerEpoch`. When
                        // config.batchesPerEpoch is not provided, the following `break` is
                        // required to exit the while loop after dataset is exhausted.
                        break;
                    }
                    if (model.stopTraining_) {
                        break;
                    }
                }
                await callbackList.onEpochEnd(epoch, epochLogs);
                epoch++;
                if (model.stopTraining_) {
                    break;
                }
            }
            await callbackList.onTrainEnd();
            await model.history.syncData();
            return model.history;
        }
        finally {
            // Always clear the training flag, even when an error was thrown.
            model.isTraining = false;
        }
    }
43499 /** Helper function that determines number of steps (batches) per epoch. */
43500 function getStepsPerEpoch(dataset, args) {
43501 // Attempt to determine # of batches in an epoch.
43502 let stepsPerEpoch = null;
43503 if (args.batchesPerEpoch != null) {
43504 stepsPerEpoch = args.batchesPerEpoch;
43505 }
43506 else if (Number.isFinite(dataset.size)) {
43507 stepsPerEpoch = dataset.size;
43508 }
43509 return stepsPerEpoch;
43510 }
43511 // Check if provided object is a Dataset object by checking its .iterator
43512 // element.
43513 function isDatasetObject(dataset) {
43514 return (typeof dataset.iterator === 'function');
43515 }
43516 // Check if provided object is a LazyIterator object by checking it's .next
43517 // element.
43518 function isLazyIteratorObject(iterator) {
43519 return (typeof iterator.next === 'function');
43520 }
43521 async function evaluateDataset(
43522 // Type `model` as `any` here to avoid circular dependency w/
43523 // training.ts.
43524 // tslint:disable-next-line:no-any
43525 model, dataset, args) {
43526 args = args || {};
43527 const hasBatches = args.batches != null;
43528 const f = model.testFunction;
43529 let outs = [];
43530 if (args.verbose > 0) {
43531 throw new NotImplementedError('Verbose mode is not implemented yet.');
43532 }
43533 assert$1(!hasBatches || (args.batches > 0 && Number.isInteger(args.batches)), () => 'Test loop expects `batches` to be a positive integer, but ' +
43534 `received ${JSON.stringify(args.batches)}`);
43535 const dataIterator = isLazyIteratorObject(dataset) ?
43536 dataset :
43537 await dataset.iterator();
43538 // Keeps track of number of examples used in this evaluation.
43539 let numExamples = 0;
43540 let batch = 0;
43541 while (hasBatches ? batch < args.batches : true) {
43542 const iteratorOut = await dataIterator.next();
43543 outs = tidy(() => {
43544 if (iteratorOut.value) {
43545 // TODO(cais): Once real dataset is available, use
43546 // `map(x => standardizeDataIteratorOutput(model, x).map(f)`.
43547 const { xs, ys } = standardizeDataIteratorOutput(model, iteratorOut.value);
43548 const xsAndYs = xs.concat(ys);
43549 const batchOuts = tidy(() => f(xsAndYs));
43550 dispose(xsAndYs);
43551 if (batch === 0) {
43552 for (let i = 0; i < batchOuts.length; ++i) {
43553 outs.push(scalar(0));
43554 }
43555 }
43556 const batchSize = xsAndYs[0].shape[0];
43557 for (let i = 0; i < batchOuts.length; ++i) {
43558 const batchOut = batchOuts[i];
43559 const oldScalar = outs[i];
43560 outs[i] =
43561 tidy(() => add$3(outs[i], mul(batchSize, batchOut)));
43562 if (batch > 0) {
43563 dispose(oldScalar);
43564 }
43565 }
43566 dispose(batchOuts);
43567 numExamples += batchSize;
43568 ++batch;
43569 }
43570 return outs;
43571 });
43572 if (iteratorOut.done) {
43573 if (hasBatches) {
43574 console.warn('Your dataset iterator ran out of data during evaluateDataset(). ' +
43575 'Interrupting evalution. Make sure that your ' +
43576 'dataset can generate at least `batches` ' +
43577 `batches (in this case, ${args.batches} batches). ` +
43578 'You may need to use the repeat() function when building ' +
43579 'your dataset.');
43580 }
43581 break;
43582 }
43583 }
43584 for (let i = 0; i < outs.length; ++i) {
43585 const oldScalar = outs[i];
43586 outs[i] = div$1(outs[i], numExamples);
43587 dispose(oldScalar);
43588 }
43589 return singletonOrArray(outs);
43590 }
43591
43592 /**
43593 * @license
43594 * Copyright 2018 Google LLC
43595 *
43596 * Use of this source code is governed by an MIT-style
43597 * license that can be found in the LICENSE file or at
43598 * https://opensource.org/licenses/MIT.
43599 * =============================================================================
43600 */
43601 function checkBatchSize(batchSize) {
43602 assert$1(batchSize > 0 && Number.isInteger(batchSize), () => `batchSize is required to be a positive integer, but got ${batchSize}`);
43603 }
43604 /**
43605 * Slice a Tensor or an Array of Tensors, by start and stop indices.
43606 *
43607 * Porting Note: The `_slice_arrays` function in PyKeras is covered by this
43608 * function and `sliceArraysByIndices()` together.
43609 *
43610 * @param arrays: the input.
43611 * @param start: the starting index (inclusive).
43612 * @param stop: the stopping index (exclusive).
43613 * @returns The result of the slicing. If `arrays` is an `Array` of
43614 * `tf.Tensor`s, the slicing will be applied to all elements of the `Array`
43615 * in the same way.
43616 */
43617 function sliceArrays(arrays, start, stop) {
43618 if (arrays == null) {
43619 return [null];
43620 }
43621 else if (Array.isArray(arrays)) {
43622 return arrays.map(array => sliceAlongFirstAxis(array, start, stop - start));
43623 }
43624 else { // Tensor.
43625 return sliceAlongFirstAxis(arrays, start, stop - start);
43626 }
43627 }
43628 /**
43629 * Slice a Tensor or an Array of Tensors, by random-order indices.
43630 *
43631 * Porting Note: The `_slice_arrays` function in PyKeras is covered by this
43632 * function and `sliceArrays()` together.
43633 *
43634 * @param arrays The input `tf.Tensor` or `Array` of `tf.Tensor`s to slice.
43635 * If an `Array` of `tf.Tensor`s, all `tf.Tensor`s will be sliced in the
43636 * same fashion.
43637 * @param indices The indices to use for slicing along the first (batch)
43638 * dimension.
43639 * @returns Result(s) of the slicing.
43640 */
43641 function sliceArraysByIndices(arrays, indices) {
43642 return tidy(() => {
43643 if (arrays == null) {
43644 return null;
43645 }
43646 else if (Array.isArray(arrays)) {
43647 return arrays.map(array => sliceArraysByIndices(array, indices));
43648 }
43649 else {
43650 // TODO(cais): indices should be a pre-constructed Tensor1D to avoid
43651 // tensor1d() calls.
43652 return gather(arrays, indices.dtype === 'int32' ? indices : cast$3(indices, 'int32'));
43653 }
43654 });
43655 }
43656 /**
43657 * Returns a list of batch indices (tuples of indices).
43658 * @param size: Integer, total size of the data to slice into batches.
43659 * @param batchSize: Integer, batch size.
43660 * @returns An Array of [batchStart, batchEnd] tuples. batchStart is
43661 * inclusive; batchEnd is exclusive. I.e., each batch consists of indices x
43662 * that satisfy batchStart <= x < batchEnd.
43663 */
43664 function makeBatches(size, batchSize) {
43665 const output = [];
43666 let batchStart = 0;
43667 let batchEnd = null;
43668 while (batchStart < size) {
43669 batchEnd = batchStart + batchSize;
43670 if (batchEnd >= size) {
43671 batchEnd = size;
43672 }
43673 output.push([batchStart, batchEnd]);
43674 batchStart = batchEnd;
43675 }
43676 return output;
43677 }
43678 /**
43679 * Ensure tensors all have a rank of at least 2.
43680 *
43681 * If a tensor has a rank of 1, it is dimension-expanded to rank 2.
43682 * If any tensor has a rank of 0 (i.e., is a scalar), an error will be thrown.
43683 */
43684 function ensureTensorsRank2OrHigher(tensors) {
43685 const outs = [];
43686 if (tensors instanceof Tensor) {
43687 tensors = [tensors];
43688 }
43689 // Make Tensors at least 2D.
43690 for (let i = 0; i < tensors.length; ++i) {
43691 const tensor = tensors[i];
43692 if (tensor.rank === 1) {
43693 outs.push(expandDims$2(tensor, 1));
43694 }
43695 else if (tensor.rank === 0) {
43696 throw new Error('Expected tensor to be at least 1D, but received a 0D tensor ' +
43697 '(scalar).');
43698 }
43699 else {
43700 outs.push(tensor);
43701 }
43702 }
43703 return outs;
43704 }
43705 /**
43706 * Compare a set of tensors with a reference (old) set, discard the ones
43707 * in the new set that are not present in the reference set.
43708 *
43709 * This method is used for memory clenaup during calls such as
43710 * LayersModel.fit().
43711 *
43712 * @param tensors New set which may contain Tensors not present in
43713 * `refTensors`.
43714 * @param refTensors Reference Tensor set.
43715 */
43716 // TODO(cais, kangyizhang): Deduplicate with tfjs-data.
43717 function disposeNewTensors(tensors, refTensors) {
43718 if (tensors == null) {
43719 return;
43720 }
43721 const oldTensorIds = [];
43722 if (refTensors instanceof Tensor) {
43723 oldTensorIds.push(refTensors.id);
43724 }
43725 else if (Array.isArray(refTensors)) {
43726 refTensors.forEach(t => oldTensorIds.push(t.id));
43727 }
43728 else if (refTensors != null) {
43729 // `oldTensors` is a map from string name to Tensor.
43730 for (const name in refTensors) {
43731 const oldTensor = refTensors[name];
43732 oldTensorIds.push(oldTensor.id);
43733 }
43734 }
43735 const tensorsToDispose = [];
43736 if (tensors instanceof Tensor) {
43737 if (oldTensorIds.indexOf(tensors.id) === -1) {
43738 tensorsToDispose.push(tensors);
43739 }
43740 }
43741 else if (Array.isArray(tensors)) {
43742 tensors.forEach(t => {
43743 if (oldTensorIds.indexOf(t.id) === -1) {
43744 tensorsToDispose.push(t);
43745 }
43746 });
43747 }
43748 else if (tensors != null) {
43749 // `oldTensors` is a map from string name to Tensor.
43750 for (const name in tensors) {
43751 const tensor = tensors[name];
43752 if (oldTensorIds.indexOf(tensor.id) === -1) {
43753 tensorsToDispose.push(tensor);
43754 }
43755 }
43756 }
43757 tensorsToDispose.forEach(t => {
43758 if (!t.isDisposed) {
43759 t.dispose();
43760 }
43761 });
43762 }
43763
43764 /**
43765 * @license
43766 * Copyright 2018 Google LLC
43767 *
43768 * Use of this source code is governed by an MIT-style
43769 * license that can be found in the LICENSE file or at
43770 * https://opensource.org/licenses/MIT.
43771 * =============================================================================
43772 */
43773 /**
43774 * Helper function for polymorphic input data: 1. singleton Tensor.
43775 */
43776 function isDataTensor(x) {
43777 return x instanceof Tensor;
43778 }
43779 /**
43780 * Helper function for polymorphic input data: 2. Array of Tensor.
43781 */
43782 function isDataArray(x) {
43783 return Array.isArray(x);
43784 }
43785 /**
43786 * Helper function for polymorphic input data: 3. "dict" of Tensor.
43787 */
43788 function isDataDict(x) {
43789 return !isDataTensor(x) && !isDataArray(x);
43790 }
    /**
     * Normalizes inputs and targets provided by users.
     * @param data User-provided input data (polymorphic): a single Tensor, an
     *   Array of Tensors, or a map from name to Tensor.
     * @param names An Array of expected Tensor names.
     * @param shapes Optional Array of expected Tensor shapes.
     * @param checkBatchAxis Whether to check that the batch axis of the arrays
     *   match the expected value found in `shapes`.
     * @param exceptionPrefix String prefix used for exception formatting.
     * @returns List of standardized input Tensors (one Tensor per model input).
     * @throws ValueError: in case of improperly formatted user data.
     */
    function standardizeInputData(data, names, shapes, checkBatchAxis = true, exceptionPrefix = '') {
        if (names == null || names.length === 0) {
            // Check for the case where the model expected no data, but some data got
            // sent.
            if (data != null) {
                let gotUnexpectedData = false;
                if (isDataArray(data) && data.length > 0) {
                    gotUnexpectedData = true;
                }
                else if (isDataDict(data)) {
                    // A dict counts as unexpected data only when it has at
                    // least one own key.
                    for (const key in data) {
                        if (data.hasOwnProperty(key)) {
                            gotUnexpectedData = true;
                            break;
                        }
                    }
                }
                else {
                    // `data` is a singleton Tensor in this case.
                    gotUnexpectedData = true;
                }
                if (gotUnexpectedData) {
                    throw new ValueError(`Error when checking model ${exceptionPrefix} expected no data, ` +
                        `but got ${data}`);
                }
            }
            return [];
        }
        if (data == null) {
            // No data provided: emit one null placeholder per expected name.
            return names.map(name => null);
        }
        let arrays;
        if (isDataDict(data)) {
            // Name-to-Tensor map: order the Tensors by `names`.
            data = data;
            arrays = [];
            for (const name of names) {
                if (data[name] == null) {
                    throw new ValueError(`No data provided for "${name}". Need data for each key in: ` +
                        `${names}`);
                }
                arrays.push(data[name]);
            }
        }
        else if (isDataArray(data)) {
            data = data;
            if (data.length !== names.length) {
                throw new ValueError(`Error when checking model ${exceptionPrefix}: the Array of ` +
                    `Tensors that you are passing to your model is not the size the ` +
                    `model expected. Expected to see ${names.length} Tensor(s), but ` +
                    `instead got the following list of Tensor(s): ${data}`);
            }
            arrays = data;
        }
        else {
            // A singleton Tensor: only valid for single-input models.
            data = data;
            if (names.length > 1) {
                throw new ValueError(`The model ${exceptionPrefix} expects ${names.length} Tensor(s), ` +
                    `but only received one Tensor. Found: Tensor with shape ${data.shape}`);
            }
            arrays = [data];
        }
        // Rank-1 Tensors are expanded to rank 2; rank-0 Tensors are rejected.
        arrays = ensureTensorsRank2OrHigher(arrays);
        // Check shape compatibility.
        if (shapes != null) {
            for (let i = 0; i < names.length; ++i) {
                if (shapes[i] == null) {
                    continue;
                }
                const array = arrays[i];
                if (array.shape.length !== shapes[i].length) {
                    throw new ValueError(`Error when checking ${exceptionPrefix}: expected ${names[i]} ` +
                        `to have ${shapes[i].length} dimension(s). but got array with ` +
                        `shape ${array.shape}`);
                }
                for (let j = 0; j < shapes[i].length; ++j) {
                    if (j === 0 && !checkBatchAxis) {
                        // Skip the first (batch) axis.
                        continue;
                    }
                    const dim = array.shape[j];
                    const refDim = shapes[i][j];
                    // A null or negative refDim means "any size" on this axis.
                    if (refDim != null && refDim >= 0 && dim !== refDim) {
                        throw new ValueError(`${exceptionPrefix} expected a batch of elements where each ` +
                            `example has shape [${shapes[i].slice(1, shapes[i].length)}] ` +
                            `(i.e.,tensor shape [*,${shapes[i].slice(1, shapes[i].length)}])` +
                            ` but the ${exceptionPrefix} received an input with ${array.shape[0]}` +
                            ` examples, each with shape [${array.shape.slice(1, array.shape.length)}]` +
                            ` (tensor shape [${array.shape}])`);
                    }
                }
            }
        }
        return arrays;
    }
43896 /**
43897 * User input validation for Tensors.
43898 * @param inputs `Array` of `tf.Tensor`s for inputs.
43899 * @param targets `Array` of `tf.Tensor`s for targets.
43900 * @param weights Optional `Array` of `tf.Tensor`s for sample weights.
43901 * @throws ValueError: in case of incorrectly formatted data.
43902 */
43903 function checkArrayLengths(inputs, targets, weights) {
43904 const setX = unique$2(inputs.map(input => input.shape[0]));
43905 setX.sort();
43906 const setY = unique$2(targets.map(target => target.shape[0]));
43907 setY.sort();
43908 // TODO(cais): Check `weights` as well.
43909 if (setX.length > 1) {
43910 throw new ValueError(`All input Tensors (x) should have the same number of samples. ` +
43911 `Got array shapes: ` +
43912 `${JSON.stringify(inputs.map(input => input.shape))}`);
43913 }
43914 if (setY.length > 1) {
43915 throw new ValueError(`All target Tensors (y) should have the same number of samples. ` +
43916 `Got array shapes: ` +
43917 `${JSON.stringify(targets.map(target => target.shape))}`);
43918 }
43919 if (setX.length > 0 && setY.length > 0 && !arraysEqual(setX, setY)) {
43920 throw new ValueError(`Input Tensors should have the same number of samples as target ` +
43921 `Tensors. Found ${setX[0]} input sample(s) and ${setY[0]} target ` +
43922 `sample(s).`);
43923 }
43924 }
43925 /**
43926 * Validation on the compatibility of targes and loss functions.
43927 *
43928 * This helps prevent users from using loss functions incorrectly.
43929 *
43930 * @param targets `Array` of `tf.Tensor`s of targets.
43931 * @param lossFns `Array` of loss functions.
43932 * @param outputShapes `Array` of shapes of model outputs.
43933 */
43934 function checkLossAndTargetCompatibility(targets, lossFns, outputShapes) {
43935 // TODO(cais): Dedicated test coverage?
43936 const keyLosses = [
43937 meanSquaredError$1, binaryCrossentropy$2,
43938 categoricalCrossentropy$2
43939 ];
43940 for (let i = 0; i < targets.length; ++i) {
43941 const y = targets[i];
43942 const loss = lossFns[i];
43943 const shape = outputShapes[i];
43944 if (loss == null) {
43945 continue;
43946 }
43947 if (loss === categoricalCrossentropy$2) {
43948 if (y.shape[y.shape.length - 1] === 1) {
43949 throw new ValueError(`You are passing a target array of shape ${y.shape} while using ` +
43950 `a loss 'categorical_crossentropy'. 'categorical_crossentropy'` +
43951 `expects targets to be binary matrices (1s and 0s) of shape ` +
43952 `[samples, classes].`);
43953 // TODO(cais): Example code in error message.
43954 }
43955 }
43956 if (keyLosses.indexOf(loss) !== -1) {
43957 const slicedYShape = y.shape.slice(1);
43958 const slicedShape = shape.slice(1);
43959 for (let j = 0; j < slicedYShape.length; ++j) {
43960 const targetDim = slicedYShape[j];
43961 const outDim = slicedShape[j];
43962 if (outDim != null && targetDim !== outDim) {
43963 throw new ValueError(`A target Tensor with shape ${y.shape} was passed for an ` +
43964 `output of shape ${shape}, while using a loss function that ` +
43965 `expects targets to have the same shape as the output.`);
43966 }
43967 }
43968 }
43969 }
43970 }
43971 /**
43972 * Check inputs provided by the user.
43973 *
43974 * Porting Note: This corresponds to _standardize_input_data() in Python
43975 * Keras. Because of the strong typing in TF.js, we do not need to convert
43976 * the data. Specifically:
43977 * 1) in PyKeras, `data` can be `DataFrame` instances from pandas, for
43978 * example. We don't need to worry about that here because there is no
43979 * widely popular javascript/typesdcript equivalent of pandas (so far).
43980 * If one becomes available in the future, we can add support.
43981 * 2) in PyKeras, inputs can be Python dict. But here we are stipulating
43982 * that the data is either a single `tf.Tensor` or an Array of `tf.Tensor`s. We
43983 * may add support for `Object` data inputs in the future when the need
43984 * arises.
43985 *
43986 * Instead, we perform basic checks for number of parameters and shapes.
43987 *
43988 * @param data: The input data.
43989 * @param names: Name for the inputs, from the model.
43990 * @param shapes: Expected shapes for the input data, from the model.
43991 * @param checkBatchAxis: Whether the size along the batch axis (i.e., the
43992 * first dimension) will be checked for matching.
43993 * @param exceptionPrefix: Execption prefix message, used in generating error
43994 * messages.
43995 * @throws ValueError: on incorrect number of inputs or mismatches in shapes.
43996 */
43997 function checkInputData(data, names, shapes, checkBatchAxis = true, exceptionPrefix = '') {
43998 let arrays;
43999 if (Array.isArray(data)) {
44000 if (data.length !== names.length) {
44001 throw new ValueError(`Error when checking model ${exceptionPrefix}: the Array of ` +
44002 `Tensors that you are passing to your model is not the size the ` +
44003 `the model expected. Expected to see ${names.length} Tensor(s),` +
44004 ` but instead got ${data.length} Tensors(s).`);
44005 }
44006 arrays = data;
44007 }
44008 else {
44009 if (names.length > 1) {
44010 throw new ValueError(`The model expects ${names.length} ${exceptionPrefix} Tensors, ` +
44011 `but only received one Tensor. Found: array with shape ` +
44012 `${JSON.stringify(data.shape)}.`);
44013 }
44014 arrays = [data];
44015 }
44016 if (shapes != null) {
44017 for (let i = 0; i < names.length; ++i) {
44018 if (shapes[i] == null) {
44019 continue;
44020 }
44021 const array = arrays[i];
44022 if (array.shape.length !== shapes[i].length) {
44023 throw new ValueError(`Error when checking ${exceptionPrefix}: expected ${names[i]} ` +
44024 `to have ${shapes[i].length} dimension(s), but got array with ` +
44025 `shape ${JSON.stringify(array.shape)}`);
44026 }
44027 for (let j = 0; j < shapes[i].length; ++j) {
44028 if (j === 0 && !checkBatchAxis) {
44029 continue;
44030 }
44031 const dim = array.shape[j];
44032 const refDim = shapes[i][j];
44033 if (refDim != null) {
44034 if (refDim !== dim) {
44035 throw new ValueError(`Error when checking ${exceptionPrefix}: expected ` +
44036 `${names[i]} to have shape ${JSON.stringify(shapes[i])} but ` +
44037 `got array with shape ${JSON.stringify(array.shape)}.`);
44038 }
44039 }
44040 }
44041 }
44042 }
44043 }
44044 /**
44045 * Maps metric functions to model outputs.
44046 * @param metrics An shortcut strings name, metric function, `Array` or dict
44047 * (`Object`) of metric functions.
44048 * @param outputNames An `Array` of the names of model outputs.
44049 * @returns An `Array` (one entry per model output) of `Array` of metric
44050 * functions. For instance, if the model has 2 outputs, and for the first
44051 * output we want to compute `binaryAccuracy` and `binaryCrossentropy`,
44052 * and just `binaryAccuracy` for the second output, the `Array` would look
44053 * like:
44054 * `[[binaryAccuracy, binaryCrossentropy], [binaryAccuracy]]`
44055 * @throws TypeError: incompatible metrics format.
44056 */
44057 function collectMetrics(metrics, outputNames) {
44058 if (metrics == null || Array.isArray(metrics) && metrics.length === 0) {
44059 return outputNames.map(name => []);
44060 }
44061 let wrappedMetrics;
44062 if (typeof metrics === 'string' || typeof metrics === 'function') {
44063 wrappedMetrics = [metrics];
44064 }
44065 else if (Array.isArray(metrics) || typeof metrics === 'object') {
44066 wrappedMetrics = metrics;
44067 }
44068 else {
44069 throw new TypeError('Type of metrics argument not understood. Expected an string,' +
44070 `function, Array, or Object, found: ${metrics}`);
44071 }
44072 if (Array.isArray(wrappedMetrics)) {
44073 // We then apply all metrics to all outputs.
44074 return outputNames.map(name => wrappedMetrics);
44075 }
44076 else {
44077 // In this case, metrics is a dict.
44078 const nestedMetrics = [];
44079 for (const name of outputNames) {
44080 let outputMetrics = wrappedMetrics.hasOwnProperty(name) ? wrappedMetrics[name] : [];
44081 if (!Array.isArray(outputMetrics)) {
44082 outputMetrics = [outputMetrics];
44083 }
44084 nestedMetrics.push(outputMetrics);
44085 }
44086 return nestedMetrics;
44087 }
44088 }
    // Format-name tag for serialized layers models; presumably consumed by the
    // save/load routines elsewhere in this bundle — not visible in this chunk.
    const LAYERS_MODEL_FORMAT_NAME = 'layers-model';
44090 /**
44091 * A `tf.LayersModel` is a directed, acyclic graph of `tf.Layer`s plus methods
44092 * for training, evaluation, prediction and saving.
44093 *
44094 * `tf.LayersModel` is the basic unit of training, inference and evaluation in
44095 * TensorFlow.js. To create a `tf.LayersModel`, use `tf.LayersModel`.
44096 *
44097 * See also:
44098 * `tf.Sequential`, `tf.loadLayersModel`.
44099 *
44100 * @doc {heading: 'Models', subheading: 'Classes'}
44101 */
44102 class LayersModel extends Container {
        constructor(args) {
            super(args);
            // Initialized false; training routines toggle this flag (the code that
            // flips it is outside this chunk).
            this.isTraining = false;
        }
44107 /**
44108 * Print a text summary of the model's layers.
44109 *
44110 * The summary includes
44111 * - Name and type of all layers that comprise the model.
44112 * - Output shape(s) of the layers
44113 * - Number of weight parameters of each layer
44114 * - If the model has non-sequential-like topology, the inputs each layer
44115 * receives
44116 * - The total number of trainable and non-trainable parameters of the model.
44117 *
44118 * ```js
44119 * const input1 = tf.input({shape: [10]});
44120 * const input2 = tf.input({shape: [20]});
44121 * const dense1 = tf.layers.dense({units: 4}).apply(input1);
44122 * const dense2 = tf.layers.dense({units: 8}).apply(input2);
44123 * const concat = tf.layers.concatenate().apply([dense1, dense2]);
44124 * const output =
44125 * tf.layers.dense({units: 3, activation: 'softmax'}).apply(concat);
44126 *
44127 * const model = tf.model({inputs: [input1, input2], outputs: output});
44128 * model.summary();
44129 * ```
44130 *
44131 * @param lineLength Custom line length, in number of characters.
44132 * @param positions Custom widths of each of the columns, as either
44133 * fractions of `lineLength` (e.g., `[0.5, 0.75, 1]`) or absolute number
44134 * of characters (e.g., `[30, 50, 65]`). Each number corresponds to
44135 * right-most (i.e., ending) position of a column.
44136 * @param printFn Custom print function. Can be used to replace the default
44137 * `console.log`. For example, you can use `x => {}` to mute the printed
44138 * messages in the console.
44139 *
44140 * @doc {heading: 'Models', subheading: 'Classes'}
44141 */
44142 summary(lineLength, positions, printFn = console.log) {
44143 if (!this.built) {
44144 throw new ValueError(`This model has never been called, thus its weights have not been ` +
44145 `created yet. So no summary can be displayed. Build the model ` +
44146 `first (e.g., by calling it on some test data).`);
44147 }
44148 printSummary(this, lineLength, positions, printFn);
44149 }
        /**
         * Configures and prepares the model for training and evaluation. Compiling
         * outfits the model with an optimizer, loss, and/or metrics. Calling `fit`
         * or `evaluate` on an un-compiled model will throw an error.
         *
         * Side effects: populates `this.optimizer_`, `this.lossFunctions`,
         * `this.feedOutputNames/Shapes`, `this.feedLossFns`, `this.metricsNames`,
         * `this.metricsTensors` and `this.collectedTrainableWeights`.
         *
         * @param args a `ModelCompileArgs` specifying the loss, optimizer, and
         *   metrics to be used for fitting and evaluating this model.
         *
         * @doc {heading: 'Models', subheading: 'Classes'}
         */
        compile(args) {
            if (args.loss == null) {
                args.loss = [];
            }
            this.loss = args.loss;
            // Resolve the optimizer: a string name is looked up via getOptimizer()
            // and marked as owned by this model; an instance must subclass
            // tf.Optimizer and is used as-is (caller retains ownership).
            if (typeof args.optimizer === 'string') {
                this.optimizer_ = getOptimizer(args.optimizer);
                this.isOptimizerOwned = true;
            }
            else {
                if (!(args.optimizer instanceof Optimizer)) {
                    throw new ValueError(`User-defined optimizer must be an instance of tf.Optimizer.`);
                }
                this.optimizer_ = args.optimizer;
                this.isOptimizerOwned = false;
            }
            // TODO(cais): Add lossWeights.
            // TODO(cais): Add sampleWeightMode.
            // Prepare loss functions. `args.loss` may be a dict keyed by output
            // name, an Array with one entry per output, or a single spec applied
            // to every output.
            let lossFunctions = [];
            if (!Array.isArray(args.loss) && typeof args.loss !== 'string' &&
                typeof args.loss !== 'function') {
                // Dict case. (The self-assignment below is a no-op in JS; residue
                // of a TypeScript cast in the original source.)
                args.loss = args.loss;
                for (const name in args.loss) {
                    if (this.outputNames.indexOf(name) === -1) {
                        throw new ValueError(`Unknown entry in loss dictionary: "${name}". ` +
                            `Only expected the following keys: ${this.outputNames}`);
                    }
                }
                for (const name of this.outputNames) {
                    if (args.loss[name] == null) {
                        console.warn(`Output "${name}" is missing from loss dictionary. We assume ` +
                            `this was done on purpose, and we will not be expecting data ` +
                            `to be passed to ${name} during training`);
                    }
                    lossFunctions.push(get$1(args.loss[name]));
                }
            }
            else if (Array.isArray(args.loss)) {
                if (args.loss.length !== this.outputs.length) {
                    throw new ValueError(`When passing an Array as loss, it should have one entry per ` +
                        `model output. The model has ${this.outputs.length} output(s), ` +
                        `but you passed loss=${args.loss}.`);
                }
                const theLosses = args.loss;
                lossFunctions = theLosses.map(l => get$1(l));
            }
            else {
                // Single loss spec: replicate it for every output.
                const lossFunction = get$1(args.loss);
                this.outputs.forEach(_ => {
                    lossFunctions.push(lossFunction);
                });
            }
            this.lossFunctions = lossFunctions;
            this.feedOutputNames = [];
            this.feedOutputShapes = [];
            this.feedLossFns = [];
            for (let i = 0; i < this.outputs.length; ++i) {
                // TODO(cais): Logic for skipping target(s).
                const shape = this.internalOutputShapes[i];
                const name = this.outputNames[i];
                this.feedOutputNames.push(name);
                this.feedOutputShapes.push(shape);
                this.feedLossFns.push(this.lossFunctions[i]);
            }
            // TODO(cais): Add logic for output masks.
            // TODO(cais): Add logic for sample weights.
            // Currently always empty — no targets are ever skipped (see TODOs).
            const skipTargetIndices = [];
            // Prepare metrics.
            this.metrics = args.metrics;
            // TODO(cais): Add weightedMetrics.
            this.metricsNames = ['loss'];
            this.metricsTensors = [];
            // Compute total loss.
            // Porting Note: In PyKeras, metrics_tensors are symbolic tensor objects.
            // Here, metricsTensors are TypeScript functions. This difference is due
            // to the difference in symbolic/imperative property of the backends.
            nameScope('loss', () => {
                // Multi-output models additionally get a per-output "<name>_loss"
                // metric entry; single-output models report only the total loss.
                for (let i = 0; i < this.outputs.length; ++i) {
                    if (skipTargetIndices.indexOf(i) !== -1) {
                        continue;
                    }
                    // TODO(cais): Add weightedLoss, sampleWeight and mask.
                    // The following line should be weightedLoss
                    const weightedLoss = this.lossFunctions[i];
                    if (this.outputs.length > 1) {
                        this.metricsTensors.push([weightedLoss, i]);
                        this.metricsNames.push(this.outputNames[i] + '_loss');
                    }
                }
                // Porting Note: Due to the imperative nature of the backend, we calculate
                // the regularizer penalties in the totalLossFunction, instead of here.
            });
            // One Array of metric specs per output name.
            const nestedMetrics = collectMetrics(args.metrics, this.outputNames);
            // TODO(cais): Add nestedWeightedMetrics.
            /**
             * Helper function used in loop below. Registers one metric, prefixing
             * its name with the output name when the model has multiple outputs.
             */
            const appendMetric = (outputIndex, metricName, metricTensor) => {
                if (this.outputNames.length > 1) {
                    metricName = this.outputNames[outputIndex] + '_' + metricName;
                }
                this.metricsNames.push(metricName);
                this.metricsTensors.push([metricTensor, outputIndex]);
            };
            nameScope('metric', () => {
                for (let i = 0; i < this.outputs.length; ++i) {
                    if (skipTargetIndices.indexOf(i) !== -1) {
                        continue;
                    }
                    const outputMetrics = nestedMetrics[i];
                    // TODO(cais): Add weights and outputWeightedMetrics.
                    // TODO(cais): Add optional arg `weights` to the following function.
                    const handleMetrics = (metrics) => {
                        const metricNamePrefix = '';
                        let metricName;
                        let accFn;
                        let weightedMetricFn;
                        // TODO(cais): Use 'weights_' for weighted metrics.
                        for (const metric of metrics) {
                            if (typeof metric === 'string' &&
                                ['accuracy', 'acc', 'crossentropy', 'ce'].indexOf(metric) !==
                                    -1) {
                                // Shortcut names: pick the concrete metric function
                                // from the output shape and the configured loss.
                                const outputShape = this.internalOutputShapes[i];
                                if (outputShape[outputShape.length - 1] === 1 ||
                                    this.lossFunctions[i] === binaryCrossentropy$2) {
                                    // case: binary accuracy/crossentropy.
                                    if (['accuracy', 'acc'].indexOf(metric) !== -1) {
                                        accFn = binaryAccuracy$1;
                                    }
                                    else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
                                        accFn = binaryCrossentropy$1;
                                    }
                                }
                                else if (this.lossFunctions[i] ===
                                    sparseCategoricalCrossentropy$1) {
                                    // case: categorical accuracy / crossentropy with sparse
                                    // targets.
                                    if (['accuracy', 'acc'].indexOf(metric) !== -1) {
                                        accFn = sparseCategoricalAccuracy$1;
                                    }
                                    else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
                                        accFn = sparseCategoricalCrossentropy;
                                    }
                                }
                                else {
                                    // case: categorical accuracy / crossentropy.
                                    if (['accuracy', 'acc'].indexOf(metric) !== -1) {
                                        accFn = categoricalAccuracy$1;
                                    }
                                    else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
                                        accFn = categoricalCrossentropy$1;
                                    }
                                }
                                let suffix;
                                if (['accuracy', 'acc'].indexOf(metric) !== -1) {
                                    suffix = 'acc';
                                }
                                else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
                                    suffix = 'ce';
                                }
                                // TODO(cais): Add weighting actually.
                                weightedMetricFn = accFn;
                                metricName = metricNamePrefix + suffix;
                            }
                            else {
                                // Explicit metric function or non-shortcut string.
                                const metricFn = get(metric);
                                // TODO(cais): Add weighting actually.
                                weightedMetricFn = metricFn;
                                metricName =
                                    metricNamePrefix + getLossOrMetricName(metric);
                            }
                            // TODO(cais): Add weighting and masking to metricResult.
                            let metricResult;
                            nameScope(metricName, () => {
                                metricResult = weightedMetricFn;
                            });
                            appendMetric(i, metricName, metricResult);
                        }
                    };
                    handleMetrics(outputMetrics);
                    // TODO(cais): Call handleMetrics with weights.
                }
            });
            // Porting Notes: Given the imperative backend of tfjs-core,
            // there is no need for constructing the symbolic graph and placeholders.
            this.collectedTrainableWeights = this.trainableWeights;
        }
44348 /**
44349 * Check trainable weights count consistency.
44350 *
44351 * This will raise a warning if `this.trainableWeights` and
44352 * `this.collectedTrainableWeights` are inconsistent (i.e., have different
44353 * numbers of parameters).
44354 * Inconsistency will typically arise when one modifies `model.trainable`
44355 * without calling `model.compile()` again.
44356 */
44357 checkTrainableWeightsConsistency() {
44358 if (this.collectedTrainableWeights == null) {
44359 return;
44360 }
44361 if (this.trainableWeights.length !==
44362 this.collectedTrainableWeights.length) {
44363 console.warn('Discrepancy between trainableweights and collected trainable ' +
44364 'weights. Did you set `model.trainable` without calling ' +
44365 '`model.compile()` afterwards?');
44366 }
44367 }
44368 /**
44369 * Returns the loss value & metrics values for the model in test mode.
44370 *
44371 * Loss and metrics are specified during `compile()`, which needs to happen
44372 * before calls to `evaluate()`.
44373 *
44374 * Computation is done in batches.
44375 *
44376 * ```js
44377 * const model = tf.sequential({
44378 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
44379 * });
44380 * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
44381 * const result = model.evaluate(
44382 * tf.ones([8, 10]), tf.ones([8, 1]), {batchSize: 4});
44383 * result.print();
44384 * ```
44385 *
44386 * @param x `tf.Tensor` of test data, or an `Array` of `tf.Tensor`s if the
44387 * model has multiple inputs.
44388 * @param y `tf.Tensor` of target data, or an `Array` of `tf.Tensor`s if the
44389 * model has multiple outputs.
44390 * @param args A `ModelEvaluateArgs`, containing optional fields.
44391 *
44392 * @return `Scalar` test loss (if the model has a single output and no
44393 * metrics) or `Array` of `Scalar`s (if the model has multiple outputs
44394 * and/or metrics). The attribute `model.metricsNames`
44395 * will give you the display labels for the scalar outputs.
44396 *
44397 * @doc {heading: 'Models', subheading: 'Classes'}
44398 */
44399 evaluate(x, y, args = {}) {
44400 const batchSize = args.batchSize == null ? 32 : args.batchSize;
44401 checkBatchSize(batchSize);
44402 // TODO(cais): Standardize `config.sampleWeights` as well.
44403 // Validate user data.
44404 const checkBatchAxis = true;
44405 const standardizedOuts = this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize);
44406 try {
44407 // TODO(cais): If uses `useLearningPhase`, set the corresponding element
44408 // of the input to 0.
44409 const ins = standardizedOuts[0].concat(standardizedOuts[1]);
44410 this.makeTestFunction();
44411 const f = this.testFunction;
44412 const testOuts = this.testLoop(f, ins, batchSize, args.verbose, args.steps);
44413 return singletonOrArray(testOuts);
44414 }
44415 finally {
44416 disposeNewTensors(standardizedOuts[0], x);
44417 disposeNewTensors(standardizedOuts[1], y);
44418 }
44419 }
        // TODO(cais): Add code snippet below once real dataset objects are
        // available.
        /**
         * Evaluate model using a dataset object.
         *
         * Note: Unlike `evaluate()`, this method is asynchronous (`async`).
         *
         * @param dataset A dataset object. Its `iterator()` method is expected
         *   to generate a dataset iterator object, the `next()` method of which
         *   is expected to produce data batches for evaluation. The return value
         *   of the `next()` call ought to contain a boolean `done` field and a
         *   `value` field. The `value` field is expected to be an array of two
         *   `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former
         *   case is for models with exactly one input and one output (e.g.
         *   a sequential model). The latter case is for models with multiple
         *   inputs and/or multiple outputs. Of the two items in the array, the
         *   first is the input feature(s) and the second is the output target(s).
         * @param args A configuration object for the dataset-based evaluation.
         * @returns Loss and metric values as an Array of `Scalar` objects.
         *
         * @doc {heading: 'Models', subheading: 'Classes'}
         */
        async evaluateDataset(dataset, args) {
            // Ensure this.testFunction exists, then delegate to the module-level
            // evaluateDataset() helper (defined elsewhere in this bundle).
            this.makeTestFunction();
            return evaluateDataset(this, dataset, args);
        }
44446 /**
44447 * Get number of samples provided for training, evaluation or prediction.
44448 *
44449 * @param ins Input `tf.Tensor`.
44450 * @param batchSize Integer batch size, optional.
44451 * @param steps Total number of steps (batches of samples) before
44452 * declaring loop finished. Optional.
44453 * @param stepsName The public API's parameter name for `steps`.
44454 * @returns Number of samples provided.
44455 */
44456 checkNumSamples(ins, batchSize, steps, stepsName = 'steps') {
44457 let numSamples;
44458 if (steps != null) {
44459 numSamples = null;
44460 if (batchSize != null) {
44461 throw new ValueError(`If ${stepsName} is set, batchSize must be null or undefined.` +
44462 `Got batchSize = ${batchSize}`);
44463 }
44464 }
44465 else if (ins != null) {
44466 if (Array.isArray(ins)) {
44467 numSamples = ins[0].shape[0];
44468 }
44469 else {
44470 numSamples = ins.shape[0];
44471 }
44472 }
44473 else {
44474 throw new ValueError(`Either the input data should have a defined shape, or ` +
44475 `${stepsName} shoud be specified.`);
44476 }
44477 return numSamples;
44478 }
44479 /**
44480 * Execute internal tensors of the model with input data feed.
44481 * @param inputs Input data feed. Must match the inputs of the model.
44482 * @param outputs Names of the output tensors to be fetched. Must match
44483 * names of the SymbolicTensors that belong to the graph.
44484 * @returns Fetched values for `outputs`.
44485 */
44486 execute(inputs, outputs) {
44487 if (Array.isArray(outputs) && outputs.length === 0) {
44488 throw new ValueError('`outputs` is an empty Array, which is not allowed.');
44489 }
44490 const outputsIsArray = Array.isArray(outputs);
44491 const outputNames = (outputsIsArray ? outputs : [outputs]);
44492 const outputSymbolicTensors = this.retrieveSymbolicTensors(outputNames);
44493 // Format the input into a FeedDict.
44494 const feedDict = new FeedDict();
44495 if (inputs instanceof Tensor) {
44496 inputs = [inputs];
44497 }
44498 if (Array.isArray(inputs)) {
44499 if (inputs.length !== this.inputs.length) {
44500 throw new ValueError(`The number of inputs provided (${inputs.length}) ` +
44501 `does not match the number of inputs of this model ` +
44502 `(${this.inputs.length}).`);
44503 }
44504 for (let i = 0; i < this.inputs.length; ++i) {
44505 feedDict.add(this.inputs[i], inputs[i]);
44506 }
44507 }
44508 else {
44509 for (const input of this.inputs) {
44510 const tensorValue = inputs[input.name];
44511 if (tensorValue == null) {
44512 throw new ValueError(`No value is provided for the model's input ${input.name}`);
44513 }
44514 feedDict.add(input, tensorValue);
44515 }
44516 }
44517 // Run execution.
44518 const executeOutputs = execute(outputSymbolicTensors, feedDict);
44519 return outputsIsArray ? executeOutputs : executeOutputs[0];
44520 }
44521 /**
44522 * Retrieve the model's internal symbolic tensors from symbolic-tensor names.
44523 */
44524 retrieveSymbolicTensors(symbolicTensorNames) {
44525 const outputSymbolicTensors = pyListRepeat(null, symbolicTensorNames.length);
44526 let outputsRemaining = symbolicTensorNames.length;
44527 for (const layer of this.layers) {
44528 const layerOutputs = Array.isArray(layer.output) ? layer.output : [layer.output];
44529 const layerOutputNames = layerOutputs.map(output => output.name);
44530 for (let i = 0; i < symbolicTensorNames.length; ++i) {
44531 const index = layerOutputNames.indexOf(symbolicTensorNames[i]);
44532 if (index !== -1) {
44533 outputSymbolicTensors[i] = layerOutputs[index];
44534 outputsRemaining--;
44535 }
44536 if (outputsRemaining === 0) {
44537 break;
44538 }
44539 }
44540 if (outputsRemaining === 0) {
44541 break;
44542 }
44543 }
44544 if (outputsRemaining > 0) {
44545 const remainingNames = [];
44546 outputSymbolicTensors.forEach((tensor, i) => {
44547 if (tensor == null) {
44548 remainingNames.push(symbolicTensorNames[i]);
44549 }
44550 });
44551 throw new ValueError(`Cannot find SymbolicTensors for output name(s): ` +
44552 `${JSON.stringify(remainingNames)}`);
44553 }
44554 return outputSymbolicTensors;
44555 }
        /**
         * Helper method to loop over some data in batches.
         *
         * Porting Note: Not using the functional approach in the Python equivalent
         * due to the imperative backend.
         * Porting Note: Does not support step mode currently.
         *
         * @param ins: input data
         * @param batchSize: integer batch size.
         * @param verbose: verbosity model
         * @returns: Predictions as `tf.Tensor` (if a single output) or an `Array` of
         *   `tf.Tensor` (if multipe outputs).
         */
        predictLoop(ins, batchSize = 32, verbose = false) {
            // Outer tidy() releases all intermediates except the returned tensors.
            return tidy(() => {
                const numSamples = this.checkNumSamples(ins);
                if (verbose) {
                    throw new NotImplementedError('Verbose predictLoop() is not implemented yet.');
                }
                // Sample-based predictions.
                // Porting Note: Tensor currently does not support sliced assignments as
                // in numpy, e.g., x[1:3] = y. Therefore we use concatenation while
                // iterating over the batches.
                const batches = makeBatches(numSamples, batchSize);
                // One accumulator Array per model output.
                const outsBatches = this.outputs.map(output => []);
                // TODO(cais): Can the scope() be pushed down inside the for loop?
                for (let batchIndex = 0; batchIndex < batches.length; ++batchIndex) {
                    // Inner tidy() frees the per-batch slices; the executed outputs
                    // are returned, so they survive into outsBatches.
                    const batchOuts = tidy(() => {
                        const batchStart = batches[batchIndex][0];
                        const batchEnd = batches[batchIndex][1];
                        // TODO(cais): Take care of the case of the last element is a flag for
                        // training/test.
                        const insBatch = sliceArrays(ins, batchStart, batchEnd);
                        // Construct the feeds for execute();
                        const feeds = [];
                        if (Array.isArray(insBatch)) {
                            for (let i = 0; i < insBatch.length; ++i) {
                                feeds.push({ key: this.inputs[i], value: insBatch[i] });
                            }
                        }
                        else {
                            feeds.push({ key: this.inputs[0], value: insBatch });
                        }
                        const feedDict = new FeedDict(feeds);
                        return execute(this.outputs, feedDict);
                    });
                    batchOuts.forEach((batchOut, i) => outsBatches[i].push(batchOut));
                }
                // Concatenate the per-batch results along the batch axis.
                return singletonOrArray(outsBatches.map(batches => concat$2(batches, 0)));
            });
        }
44607 /**
44608 * Generates output predictions for the input samples.
44609 *
44610 * Computation is done in batches.
44611 *
44612 * Note: the "step" mode of predict() is currently not supported.
44613 * This is because the TensorFlow.js core backend is imperative only.
44614 *
44615 * ```js
44616 * const model = tf.sequential({
44617 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
44618 * });
44619 * model.predict(tf.ones([8, 10]), {batchSize: 4}).print();
44620 * ```
44621 *
44622 * @param x The input data, as a Tensor, or an `Array` of `tf.Tensor`s if
44623 * the model has multiple inputs.
44624 * @param args A `ModelPredictArgs` object containing optional fields.
44625 *
44626 * @return Prediction results as a `tf.Tensor`(s).
44627 *
44628 * @exception ValueError In case of mismatch between the provided input data
44629 * and the model's expectations, or in case a stateful model receives a
44630 * number of samples that is not a multiple of the batch size.
44631 *
44632 * @doc {heading: 'Models', subheading: 'Classes'}
44633 */
44634 predict(x, args = {}) {
44635 const xsRank2OrHigher = ensureTensorsRank2OrHigher(x);
44636 checkInputData(xsRank2OrHigher, this.inputNames, this.feedInputShapes, false);
44637 try {
44638 // TODO(cais): Take care of stateful models.
44639 // if (this.stateful) ...
44640 // TODO(cais): Take care of the learning_phase boolean flag.
44641 // if (this.useLearningPhase) ...
44642 const batchSize = args.batchSize == null ? 32 : args.batchSize;
44643 checkBatchSize(batchSize);
44644 return this.predictLoop(xsRank2OrHigher, batchSize);
44645 }
44646 finally {
44647 disposeNewTensors(xsRank2OrHigher, x);
44648 }
44649 }
44650 /**
44651 * Returns predictions for a single batch of samples.
44652 *
44653 * ```js
44654 * const model = tf.sequential({
44655 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
44656 * });
44657 * model.predictOnBatch(tf.ones([8, 10])).print();
44658 * ```
44659 * @param x: Input samples, as a Tensor (for models with exactly one
44660 * input) or an array of Tensors (for models with more than one input).
44661 * @return Tensor(s) of predictions
44662 *
44663 * @doc {heading: 'Models', subheading: 'Classes'}
44664 */
44665 predictOnBatch(x) {
44666 checkInputData(x, this.inputNames, this.feedInputShapes, true);
44667 // TODO(cais): Take care of the learning_phase boolean flag.
44668 // if (this.useLearningPhase) ...
44669 const batchSize = (Array.isArray(x) ? x[0] : x).shape[0];
44670 return this.predictLoop(x, batchSize);
44671 }
44672 standardizeUserDataXY(x, y, checkBatchAxis = true, batchSize) {
44673 // TODO(cais): Add sampleWeight, classWeight
44674 if (this.optimizer_ == null) {
44675 throw new RuntimeError('You must compile a model before training/testing. Use ' +
44676 'LayersModel.compile(modelCompileArgs).');
44677 }
44678 const outputShapes = [];
44679 for (let i = 0; i < this.feedOutputShapes.length; ++i) {
44680 const outputShape = this.feedOutputShapes[i];
44681 const lossFn = this.feedLossFns[i];
44682 if (lossFn === sparseCategoricalCrossentropy$1) {
44683 outputShapes.push(outputShape.slice(0, outputShape.length - 1).concat([1]));
44684 }
44685 else {
44686 // Porting Note: Because of strong typing `lossFn` must be a function.
44687 outputShapes.push(outputShape);
44688 }
44689 }
44690 x = standardizeInputData(x, this.feedInputNames, this.feedInputShapes, false, 'input');
44691 y = standardizeInputData(y, this.feedOutputNames, outputShapes, false, 'target');
44692 // TODO(cais): Standardize sampleWeights & classWeights.
44693 checkArrayLengths(x, y, null);
44694 // TODO(cais): Check sampleWeights as well.
44695 checkLossAndTargetCompatibility(y, this.feedLossFns, this.feedOutputShapes);
44696 if (this.stateful && batchSize != null && batchSize > 0) {
44697 if (x[0].shape[0] % batchSize !== 0) {
44698 throw new ValueError(`In a stateful network, you should only pass inputs with a ` +
44699 `number of samples that is divisible by the batch size ` +
44700 `${batchSize}. Found: ${x[0].shape[0]} sample(s).`);
44701 }
44702 }
44703 return [x, y];
44704 }
44705 async standardizeUserData(x, y, sampleWeight, classWeight, checkBatchAxis = true, batchSize) {
44706 const [standardXs, standardYs] = this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize);
44707 // TODO(cais): Handle sampleWeights.
44708 if (sampleWeight != null) {
44709 throw new Error('sample weight is not supported yet.');
44710 }
44711 let standardSampleWeights = null;
44712 if (classWeight != null) {
44713 const classWeights = standardizeClassWeights(classWeight, this.outputNames);
44714 standardSampleWeights = [];
44715 for (let i = 0; i < classWeights.length; ++i) {
44716 standardSampleWeights.push(await standardizeWeights(standardYs[i], null, classWeights[i]));
44717 }
44718 }
44719 // TODO(cais): Deal with the case of model.stateful == true.
44720 return [standardXs, standardYs, standardSampleWeights];
44721 }
        /**
         * Loop over some test data in batches.
         * @param f A Function returning a list of tensors.
         * @param ins Array of tensors to be fed to `f`.
         * @param batchSize Integer batch size or `null` / `undefined`.
         * @param verbose verbosity mode.
         * @param steps Total number of steps (batches of samples) before
         *   declaring test finished. Ignored with the default value of `null` /
         *   `undefined`.
         * @returns Array of Scalars (sample-weighted averages over all batches).
         */
        testLoop(f, ins, batchSize, verbose = 0, steps) {
            // tidy() frees every intermediate except the returned `outs`.
            return tidy(() => {
                const numSamples = this.checkNumSamples(ins, batchSize, steps, 'steps');
                const outs = [];
                if (verbose > 0) {
                    throw new NotImplementedError('Verbose mode is not implemented yet.');
                }
                // TODO(cais): Use `indicesForConversionToDense' to prevent slow down.
                if (steps != null) {
                    throw new NotImplementedError('steps mode in testLoop() is not implemented yet');
                }
                else {
                    const batches = makeBatches(numSamples, batchSize);
                    // Index tensor [0..numSamples) used to gather each batch.
                    const indexArray = tensor1d(range$2(0, numSamples));
                    for (let batchIndex = 0; batchIndex < batches.length; ++batchIndex) {
                        const batchStart = batches[batchIndex][0];
                        const batchEnd = batches[batchIndex][1];
                        const batchIds = sliceAlongFirstAxis(indexArray, batchStart, batchEnd - batchStart);
                        // TODO(cais): In ins, train flag can be a number, instead of an
                        //   Tensor? Do we need to handle this in tfjs-layers?
                        const insBatch = sliceArraysByIndices(ins, batchIds);
                        const batchOuts = f(insBatch);
                        if (batchIndex === 0) {
                            // Initialize one zero accumulator per output of `f`.
                            for (let i = 0; i < batchOuts.length; ++i) {
                                outs.push(scalar(0));
                            }
                        }
                        // Accumulate batch results weighted by batch size, so the
                        // final division yields a per-sample average.
                        for (let i = 0; i < batchOuts.length; ++i) {
                            const batchOut = batchOuts[i];
                            outs[i] =
                                add$3(outs[i], mul(batchEnd - batchStart, batchOut));
                        }
                    }
                    for (let i = 0; i < outs.length; ++i) {
                        outs[i] = div$1(outs[i], numSamples);
                    }
                }
                return outs;
            });
        }
44773 getDedupedMetricsNames() {
44774 const outLabels = this.metricsNames;
44775 // Rename duplicated metrics names (can happen with an output layer
44776 // shared among multiple dataflows).
44777 const dedupedOutLabels = [];
44778 for (let i = 0; i < outLabels.length; ++i) {
44779 const label = outLabels[i];
44780 let newLabel = label;
44781 if (count(outLabels, label) > 1) {
44782 const dupIndex = count(outLabels.slice(0, i), label);
44783 newLabel += `_${dupIndex}`;
44784 }
44785 dedupedOutLabels.push(newLabel);
44786 }
44787 return dedupedOutLabels;
44788 }
        /**
         * Creates a function that performs the following actions:
         *
         * 1. computes the losses
         * 2. sums them to get the total loss
         * 3. call the optimizer computes the gradients of the LayersModel's
         *    trainable weights w.r.t. the total loss and update the variables
         * 4. calculates the metrics
         * 5. returns the values of the losses and metrics.
         *
         * The returned function expects `data` laid out as
         * [inputs..., targets..., sampleWeights...] and returns
         * [totalLoss, ...metricValues].
         */
        makeTrainFunction() {
            return (data) => {
                const lossValues = [];
                // Split the flat `data` array into its three positional sections.
                const inputs = data.slice(0, this.inputs.length);
                const targets = data.slice(this.inputs.length, this.inputs.length + this.outputs.length);
                const sampleWeights = data.slice(this.inputs.length + this.outputs.length, this.inputs.length + this.outputs.length * 2);
                const metricsValues = [];
                // Create a function that computes the total loss based on the
                // inputs. This function is used for obtaining gradients through
                // backprop.
                const totalLossFunction = () => {
                    const feeds = [];
                    for (let i = 0; i < this.inputs.length; ++i) {
                        feeds.push({ key: this.inputs[i], value: inputs[i] });
                    }
                    const feedDict = new FeedDict(feeds);
                    // Forward pass in training mode.
                    const outputs = execute(this.outputs, feedDict, { 'training': true });
                    // TODO(cais): Take care of the case of multiple outputs from a
                    //   single layer?
                    let totalLoss;
                    for (let i = 0; i < this.lossFunctions.length; ++i) {
                        const lossFunction = this.lossFunctions[i];
                        let loss = lossFunction(targets[i], outputs[i]);
                        if (sampleWeights[i] != null) {
                            loss = computeWeightedLoss(loss, sampleWeights[i]);
                        }
                        // TODO(cais): push Scalar instead.
                        const meanLoss = mean$3(loss);
                        // TODO(cais): Use a scope() instead, to avoid ownership.
                        lossValues.push(meanLoss);
                        if (i === 0) {
                            totalLoss = loss;
                        }
                        else {
                            totalLoss = add$3(totalLoss, loss);
                        }
                    }
                    // Compute the metrics.
                    // TODO(cais): These should probably be calculated outside
                    //   totalLossFunction to benefit speed?
                    for (let i = 0; i < this.metricsTensors.length; ++i) {
                        let weightedMetric;
                        // For multi-output models the first `outputs.length` metric
                        // slots are the per-output losses already computed above.
                        if (this.outputs.length > 1 && i < this.outputs.length) {
                            weightedMetric = lossValues[i];
                        }
                        else {
                            const metric = this.metricsTensors[i][0];
                            const outputIndex = this.metricsTensors[i][1];
                            weightedMetric =
                                mean$3(metric(targets[outputIndex], outputs[outputIndex]));
                        }
                        // keep() prevents the metric value from being disposed by the
                        // surrounding memory scope, so it survives to the return value.
                        keep(weightedMetric);
                        // TODO(cais): Use a scope() instead, to avoid ownership.
                        metricsValues.push(weightedMetric);
                    }
                    totalLoss = mean$3(totalLoss);
                    // Add regularizer penalties.
                    this.calculateLosses().forEach(regularizerLoss => {
                        totalLoss = add$3(totalLoss, regularizerLoss);
                    });
                    return totalLoss;
                };
                // Minimize over the weights collected at compile() time.
                const variables = this.collectedTrainableWeights.map(param => param.read());
                const returnCost = true;
                const totalLossValue = this.optimizer_.minimize(totalLossFunction, returnCost, variables);
                return [totalLossValue].concat(metricsValues);
            };
        }
    /**
     * Create a function which, when invoked with an array of `tf.Tensor`s as a
     * batch of inputs, returns the prespecified loss and metrics of the model
     * under the batch of input data.
     *
     * The created function is stored on `this.testFunction`. It expects `data`
     * to be the concatenation of input tensors followed by target tensors
     * (inputs first, in model order), and returns `[loss(es)..., metrics...]`.
     */
    makeTestFunction() {
        this.testFunction = (data) => {
            // tidy() disposes all intermediate tensors; the returned array of
            // tensors (valOutputs) survives the scope.
            return tidy(() => {
                const valOutputs = [];
                let totalLoss;
                // Split the flat `data` array into inputs and targets.
                const inputs = data.slice(0, this.inputs.length);
                const targets = data.slice(this.inputs.length, this.inputs.length + this.outputs.length);
                const feeds = [];
                for (let i = 0; i < this.inputs.length; ++i) {
                    feeds.push({ key: this.inputs[i], value: inputs[i] });
                }
                const feedDict = new FeedDict(feeds);
                // Forward pass in inference mode (no `training: true` flag,
                // unlike the train function).
                const outputs = execute(this.outputs, feedDict);
                // Compute total loss.
                for (let i = 0; i < this.lossFunctions.length; ++i) {
                    const lossFunction = this.lossFunctions[i];
                    // TODO(cais): Add sample weighting and replace the simple
                    // averaging.
                    const loss = mean$3(lossFunction(targets[i], outputs[i]));
                    if (i === 0) {
                        totalLoss = loss;
                    }
                    else {
                        totalLoss = add$3(totalLoss, loss);
                    }
                    // NOTE(review): the *running* cumulative total is pushed on
                    // every iteration, so for multi-output models the entries
                    // are cumulative sums, not per-output losses.
                    valOutputs.push(totalLoss);
                }
                // Compute the metrics.
                for (let i = 0; i < this.metricsTensors.length; ++i) {
                    const metric = this.metricsTensors[i][0];
                    const outputIndex = this.metricsTensors[i][1];
                    // TODO(cais): Replace K.mean() with a proper weighting function.
                    const meanMetric = mean$3(metric(targets[outputIndex], outputs[outputIndex]));
                    valOutputs.push(meanMetric);
                }
                return valOutputs;
            });
        };
    }
    /**
     * Trains the model for a fixed number of epochs (iterations on a
     * dataset).
     *
     * ```js
     * const model = tf.sequential({
     *     layers: [tf.layers.dense({units: 1, inputShape: [10]})]
     * });
     * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
     * for (let i = 1; i < 5 ; ++i) {
     *   const h = await model.fit(tf.ones([8, 10]), tf.ones([8, 1]), {
     *       batchSize: 4,
     *       epochs: 3
     *   });
     *   console.log("Loss after Epoch " + i + " : " + h.history.loss[0]);
     * }
     * ```
     *
     * @param x `tf.Tensor` of training data, or an array of `tf.Tensor`s if the
     *   model has multiple inputs. If all inputs in the model are named, you
     *   can also pass a dictionary mapping input names to `tf.Tensor`s.
     * @param y `tf.Tensor` of target (label) data, or an array of `tf.Tensor`s if
     *   the model has multiple outputs. If all outputs in the model are named,
     *   you can also pass a dictionary mapping output names to `tf.Tensor`s.
     * @param args A `ModelFitArgs`, containing optional fields.
     *
     * @return A `History` instance. Its `history` attribute contains all
     *   information collected during training.
     *
     * @exception ValueError In case of mismatch between the provided input
     *   data and what the model expects.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    async fit(x, y, args = {}) {
        // Guard against concurrent fit() calls: training mutates shared
        // optimizer/model state, so only one fit() may run at a time.
        if (this.isTraining) {
            throw new Error('Cannot start training because another fit() call is ongoing.');
        }
        this.isTraining = true;
        // Declared outside `try` so the `finally` block can dispose them.
        let inputs;
        let targets;
        let originalInputs;
        let originalTargets;
        let inputValX;
        let inputValY;
        let valX;
        let valY;
        let sampleWeights;
        try {
            const batchSize = args.batchSize == null ? 32 : args.batchSize;
            checkBatchSize(batchSize);
            // Validate user data.
            // TODO(cais): Support sampleWeight.
            const checkBatchAxis = false;
            const standardizedOuts = await this.standardizeUserData(x, y, args.sampleWeight, args.classWeight, checkBatchAxis, batchSize);
            inputs = standardizedOuts[0];
            targets = standardizedOuts[1];
            sampleWeights = standardizedOuts[2];
            // Prepare validation data. Three mutually exclusive routes:
            // explicit validationData, a validationSplit fraction, or
            // validationSteps.
            let doValidation = false;
            let valIns;
            if (args.validationData != null && args.validationData.length > 0) {
                doValidation = true;
                if (args.validationData.length === 2) {
                    // config.validationData consists of valX and valY.
                    inputValX = args.validationData[0];
                    inputValY = args.validationData[1];
                }
                else if (args.validationData.length === 3) {
                    throw new NotImplementedError('validationData including sample weights is not supported yet.');
                }
                else {
                    throw new ValueError(`When passing validation data, it must contain 2 (valX, valY) ` +
                        `or 3 (valX, valY, valSampleWeight) items; ` +
                        `${args.validationData} is invalid.`);
                }
                const checkBatchAxis = true;
                const valStandardized = await this.standardizeUserData(inputValX, inputValY, null, /** Unused sample weights. */ null, /** Unused class weights. */ checkBatchAxis, batchSize);
                valX = valStandardized[0];
                valY = valStandardized[1];
                valIns = valX.concat(valY);
                // TODO(cais): Add useLearningPhase data properly.
            }
            else if (args.validationSplit != null && args.validationSplit > 0 &&
                args.validationSplit < 1) {
                doValidation = true;
                // Porting Note: In tfjs-layers, inputs[0] is always a Tensor.
                // The tail `validationSplit` fraction of samples becomes the
                // validation set; the head becomes the (new) training set.
                const splitAt = Math.floor(inputs[0].shape[0] * (1 - args.validationSplit));
                const originalBatchSize = inputs[0].shape[0];
                valX = sliceArrays(inputs, splitAt, originalBatchSize);
                // Keep references to the pre-split tensors so they can be
                // disposed in `finally`.
                originalInputs = inputs;
                inputs = sliceArrays(inputs, 0, splitAt);
                valY = sliceArrays(targets, splitAt, originalBatchSize);
                originalTargets = targets;
                targets = sliceArrays(targets, 0, splitAt);
                // TODO(cais): Once sampleWeights becomes available, slice it to get
                // valSampleWeights.
                valIns = valX.concat(valY);
                // TODO(cais): Add useLearningPhase data properly.
            }
            else if (args.validationSteps != null) {
                doValidation = true;
                // TODO(cais): Add useLearningPhase.
            }
            // Flat feed array: inputs, then targets, then sample weights.
            const ins = inputs.concat(targets).concat(sampleWeights);
            this.checkTrainableWeightsConsistency();
            // TODO(cais): Handle use_learning_phase and learning_phase?
            // Porting Note: Here we see a key deviation of tfjs-layers from
            // Keras.
            // Due to the imperative nature of tfjs-layers' backend (tfjs-core),
            // we do not construct symbolic computation graphs to embody the
            // training process. Instead, we define a function that performs the
            // training action. In PyKeras, the data (inputs and targets) are fed
            // through graph placeholders. In tfjs-layers, the data are fed as
            // function arguments. Since the function are defined below in the
            // scope, we don't have equivalents of PyKeras's
            // `_make_train_funciton`.
            const trainFunction = this.makeTrainFunction();
            const outLabels = this.getDedupedMetricsNames();
            let valFunction;
            let callbackMetrics;
            if (doValidation) {
                this.makeTestFunction();
                valFunction = this.testFunction;
                // Validation metrics are reported under 'val_'-prefixed names.
                callbackMetrics =
                    outLabels.slice().concat(outLabels.map(n => 'val_' + n));
            }
            else {
                valFunction = null;
                valIns = [];
                callbackMetrics = outLabels.slice();
            }
            const callbacks = standardizeCallbacks(args.callbacks, args.yieldEvery);
            const out = await this.fitLoop(trainFunction, ins, outLabels, batchSize, args.epochs, args.verbose, callbacks, valFunction, valIns, args.shuffle, callbackMetrics, args.initialEpoch, null, null);
            return out;
        }
        finally {
            this.isTraining = false;
            // Memory clean up: dispose only tensors this method created
            // (disposeNewTensors compares against the user-supplied originals).
            disposeNewTensors(inputs, x);
            disposeNewTensors(targets, y);
            disposeNewTensors(originalInputs, x);
            disposeNewTensors(originalTargets, y);
            disposeNewTensors(valX, inputValX);
            disposeNewTensors(valY, inputValY);
            if (sampleWeights != null) {
                dispose(sampleWeights);
            }
        }
        // TODO(cais): Add value to outLabels.
    }
    /**
     * Abstract fit function for `f(ins)`.
     * @param f A Function returning a list of tensors. For training, this
     *   function is expected to perform the updates to the variables.
     * @param ins List of tensors to be fed to `f`.
     * @param outLabels List of strings, display names of the outputs of `f`.
     * @param batchSize Integer batch size or `== null` if unknown. Default : 32.
     * @param epochs Number of times to iterate over the data. Default : 1.
     * @param verbose Verbosity mode: 0, 1, or 2. Default: 1.
     * @param callbacks List of callbacks to be called during training.
     * @param valF Function to call for validation.
     * @param valIns List of tensors to be fed to `valF`.
     * @param shuffle Whether to shuffle the data at the beginning of every
     *   epoch. Default : true.
     * @param callbackMetrics List of strings, the display names of the metrics
     *   passed to the callbacks. They should be the concatenation of the
     *   display names of the outputs of `f` and the list of display names
     *   of the outputs of `valF`.
     * @param initialEpoch Epoch at which to start training (useful for
     *   resuming a previous training run). Default : 0.
     * @param stepsPerEpoch Total number of steps (batches on samples) before
     *   declaring one epoch finished and starting the next epoch. Ignored with
     *   the default value of `undefined` or `null`.
     * @param validationSteps Number of steps to run validation for (only if
     *   doing validation from data tensors). Not applicable for tfjs-layers.
     * @returns A `History` object.
     */
    async fitLoop(f, ins, outLabels, batchSize, epochs, verbose, callbacks, valF, valIns, shuffle$1, callbackMetrics, initialEpoch, stepsPerEpoch, validationSteps) {
        // Apply defaults for all optional arguments.
        if (batchSize == null) {
            batchSize = 32;
        }
        if (epochs == null) {
            epochs = 1;
        }
        if (shuffle$1 == null) {
            shuffle$1 = true;
        }
        if (initialEpoch == null) {
            initialEpoch = 0;
        }
        // TODO(cais): Change const to let below when implementing validation.
        let doValidation = false;
        if (valF != null && valIns != null) {
            doValidation = true;
            // TODO(cais): verbose message.
        }
        if (validationSteps != null) {
            doValidation = true;
            if (stepsPerEpoch == null) {
                throw new ValueError('Can only use `validationSteps` when doing step-wise training, ' +
                    'i.e., `stepsPerEpoch` must be set.');
            }
        }
        const numTrainSamples = this.checkNumSamples(ins, batchSize, stepsPerEpoch, 'steps_per_epoch');
        let indexArray;
        if (numTrainSamples != null) {
            // Sample indices; may be shuffled per epoch below.
            indexArray = range$2(0, numTrainSamples);
        }
        if (verbose == null) {
            verbose = 1;
        }
        const { callbackList, history } = configureCallbacks(callbacks, verbose, epochs, initialEpoch, numTrainSamples, stepsPerEpoch, batchSize, doValidation, callbackMetrics);
        callbackList.setModel(this);
        this.history = history;
        await callbackList.onTrainBegin();
        this.stopTraining_ = false;
        // TODO(cais): Take care of callbacks.validation_data as in PyKeras.
        // TODO(cais): Pre-convert feeds for performance as in PyKeras.
        for (let epoch = initialEpoch; epoch < epochs; ++epoch) {
            await callbackList.onEpochBegin(epoch);
            const epochLogs = {};
            if (stepsPerEpoch != null) {
                throw new NotImplementedError('stepsPerEpoch mode is not implemented yet.');
            }
            else {
                if (shuffle$1 === 'batch') {
                    throw new NotImplementedError('batch shuffling is not implemneted'
                        + ' yet');
                }
                else if (shuffle$1) {
                    // In-place Fisher-Yates-style shuffle of the index array.
                    shuffle(indexArray);
                }
                // Convert the potentially shuffled indices to Tensor1D, to avoid the
                // cost of repeated creation of Array1Ds later on.
                const epochIndexArray1D = tensor1d(indexArray);
                const batches = makeBatches(numTrainSamples, batchSize);
                for (let batchIndex = 0; batchIndex < batches.length; ++batchIndex) {
                    const batchLogs = {};
                    await callbackList.onBatchBegin(batchIndex, batchLogs);
                    // tidy() disposes per-batch intermediates; outputs that must
                    // outlive the scope are pinned with keep() below.
                    tidy(() => {
                        const batchStart = batches[batchIndex][0];
                        const batchEnd = batches[batchIndex][1];
                        const batchIds = sliceAlongFirstAxis(epochIndexArray1D, batchStart, batchEnd - batchStart);
                        batchLogs['batch'] = batchIndex;
                        batchLogs['size'] = batchEnd - batchStart;
                        // TODO(cais): In ins, train flag can be a number, instead of an
                        //   Tensor? Do we need to handle this in tfjs-layers?
                        const insBatch = sliceArraysByIndices(ins, batchIds);
                        const outs = f(insBatch);
                        for (let i = 0; i < outLabels.length; ++i) {
                            const label = outLabels[i];
                            const out = outs[i];
                            batchLogs[label] = out;
                            // Pin the output so it survives tidy(); disposed via
                            // disposeTensorsInLogs() after onBatchEnd.
                            keep(out);
                            // TODO(cais): Use scope() to avoid ownership.
                        }
                        if (batchIndex === batches.length - 1) { // Last batch.
                            if (doValidation) {
                                const valOuts = this.testLoop(valF, valIns, batchSize);
                                // Porting Notes: In tfjs-layers, valOuts is always an Array.
                                for (let i = 0; i < outLabels.length; ++i) {
                                    const label = outLabels[i];
                                    const out = valOuts[i];
                                    keep(out);
                                    // TODO(cais): Use scope() to avoid ownership.
                                    epochLogs['val_' + label] = out;
                                }
                            }
                        }
                    });
                    await callbackList.onBatchEnd(batchIndex, batchLogs);
                    disposeTensorsInLogs(batchLogs);
                    if (this.stopTraining_) {
                        break;
                    }
                    // TODO(cais): return outs as list of Tensor.
                }
                epochIndexArray1D.dispose();
            }
            // TODO(cais): Run validation at the end of the epoch.
            await callbackList.onEpochEnd(epoch, epochLogs);
            if (this.stopTraining_) {
                break;
            }
        }
        await callbackList.onTrainEnd();
        await this.history.syncData();
        return this.history;
    }
    // TODO(cais): Add code snippet below when it's possible to instantiate
    //   actual dataset objects.
    /**
     * Trains the model using a dataset object.
     *
     * Thin delegation to the module-level `fitDataset` helper, passing this
     * model as the first argument.
     *
     * @param dataset A dataset object. Its `iterator()` method is expected
     *   to generate a dataset iterator object, the `next()` method of which
     *   is expected to produce data batches for training. The return value
     *   of the `next()` call ought to contain a boolean `done` field and a
     *   `value` field. The `value` field is expected to be an array of two
     *   `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former
     *   case is for models with exactly one input and one output (e.g.
     *   a sequential model). The latter case is for models with multiple
     *   inputs and/or multiple outputs.
     *   Of the two items in the array, the first is the input feature(s) and
     *   the second is the output target(s).
     * @param args A `ModelFitDatasetArgs`, containing optional fields.
     *
     * @return A `History` instance. Its `history` attribute contains all
     *   information collected during training.
     *
     * @doc {heading: 'Models', subheading: 'Classes'}
     */
    async fitDataset(dataset, args) {
        return fitDataset(this, dataset, args);
    }
45227 /**
45228 * Runs a single gradient update on a single batch of data.
45229 *
45230 * This method differs from `fit()` and `fitDataset()` in the following
45231 * regards:
45232 * - It operates on exactly one batch of data.
45233 * - It returns only the loss and metric values, instead of
45234 * returning the batch-by-batch loss and metric values.
45235 * - It doesn't support fine-grained options such as verbosity and
45236 * callbacks.
45237 *
45238 * @param x Input data. It could be one of the following:
45239 * - A `tf.Tensor`, or an Array of `tf.Tensor`s (in case the model has
45240 * multiple inputs).
45241 * - An Object mapping input names to corresponding `tf.Tensor` (if the
45242 * model has named inputs).
45243 * @param y Target data. It could be either a `tf.Tensor` or multiple
45244 * `tf.Tensor`s. It should be consistent with `x`.
45245 * @returns Training loss or losses (in case the model has
45246 * multiple outputs), along with metrics (if any), as numbers.
45247 *
45248 * @doc {heading: 'Models', subheading: 'Classes'}
45249 */
45250 async trainOnBatch(x, y) {
45251 // TODO(cais): Support sampleWeight and classWeight.
45252 // TODO(cais): Support Dataset objects.
45253 const standardizeOut = await this.standardizeUserData(x, y);
45254 const inputs = standardizeOut[0];
45255 const targets = standardizeOut[1];
45256 const trainFunction = this.makeTrainFunction();
45257 const losses = trainFunction(inputs.concat(targets));
45258 const lossValues = [];
45259 for (const loss of losses) {
45260 const v = await loss.data();
45261 lossValues.push(v[0]);
45262 }
45263 dispose(losses);
45264 disposeNewTensors(standardizeOut[0], x);
45265 disposeNewTensors(standardizeOut[1], y);
45266 return singletonOrArray(lossValues);
45267 }
45268 /**
45269 * Extract weight values of the model.
45270 *
45271 * @param config: An instance of `io.SaveConfig`, which specifies
45272 * model-saving options such as whether only trainable weights are to be
45273 * saved.
45274 * @returns A `NamedTensorMap` mapping original weight names (i.e.,
45275 * non-uniqueified weight names) to their values.
45276 */
45277 getNamedWeights(config) {
45278 const namedWeights = [];
45279 const trainableOnly = config != null && config.trainableOnly;
45280 const weights = trainableOnly ? this.trainableWeights : this.weights;
45281 const weightValues = this.getWeights(trainableOnly);
45282 for (let i = 0; i < weights.length; ++i) {
45283 if (trainableOnly && !weights[i].trainable) {
45284 // Optionally skip non-trainable weights.
45285 continue;
45286 }
45287 namedWeights.push({ name: weights[i].originalName, tensor: weightValues[i] });
45288 }
45289 return namedWeights;
45290 }
    /**
     * Setter used for force stopping of LayersModel.fit() (i.e., training).
     *
     * The flag is polled by fitLoop() after each batch and each epoch; setting
     * it to `true` makes the ongoing fit() call end early.
     *
     * Example:
     *
     * ```js
     * const input = tf.input({shape: [10]});
     * const output = tf.layers.dense({units: 1}).apply(input);
     * const model = tf.model({inputs: [input], outputs: [output]});
     * model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
     * const xs = tf.ones([8, 10]);
     * const ys = tf.zeros([8, 1]);
     *
     * const history = await model.fit(xs, ys, {
     *   epochs: 10,
     *   callbacks: {
     *     onEpochEnd: async (epoch, logs) => {
     *       if (epoch === 2) {
     *         model.stopTraining = true;
     *       }
     *     }
     *   }
     * });
     *
     * // There should be only 3 values in the loss array, instead of 10
     * values,
     * // due to the stopping after 3 epochs.
     * console.log(history.history.loss);
     * ```
     */
    set stopTraining(stop) {
        this.stopTraining_ = stop;
    }
    /** Whether an early stop of training has been requested. */
    get stopTraining() {
        return this.stopTraining_;
    }
    /** The optimizer configured via compile() (or assigned directly). */
    get optimizer() {
        return this.optimizer_;
    }
45330 set optimizer(optimizer) {
45331 if (this.optimizer_ !== optimizer) {
45332 this.optimizer_ = optimizer;
45333 this.isOptimizerOwned = false;
45334 }
45335 }
    /**
     * Dispose the model's weights (via the superclass) and, when this model
     * owns its optimizer, the optimizer's variables as well.
     *
     * @returns The `DisposeResult` from the superclass, with
     *   `numDisposedVariables` augmented by the count of optimizer tensors
     *   freed here.
     */
    dispose() {
        const result = super.dispose();
        // Only dispose the optimizer if the model's refcount hit zero and the
        // optimizer was created internally (not assigned by the user).
        if (result.refCountAfterDispose === 0 && this.optimizer != null &&
            this.isOptimizerOwned) {
            // Infer how many variables the optimizer held by diffing the
            // engine's tensor count around the dispose call.
            const numTensorsBeforeOptmizerDisposal = memory().numTensors;
            this.optimizer_.dispose();
            result.numDisposedVariables +=
                numTensorsBeforeOptmizerDisposal - memory().numTensors;
        }
        return result;
    }
45347 getLossIdentifiers() {
45348 let lossNames;
45349 if (typeof this.loss === 'string') {
45350 lossNames = toSnakeCase(this.loss);
45351 }
45352 else if (Array.isArray(this.loss)) {
45353 for (const loss of this.loss) {
45354 if (typeof loss !== 'string') {
45355 throw new Error('Serialization of non-string loss is not supported.');
45356 }
45357 }
45358 lossNames = this.loss.map(name => toSnakeCase(name));
45359 }
45360 else {
45361 const outputNames = Object.keys(this.loss);
45362 lossNames = {};
45363 const losses = this.loss;
45364 for (const outputName of outputNames) {
45365 if (typeof losses[outputName] === 'string') {
45366 lossNames[outputName] =
45367 toSnakeCase(losses[outputName]);
45368 }
45369 else {
45370 throw new Error('Serialization of non-string loss is not supported.');
45371 }
45372 }
45373 }
45374 return lossNames;
45375 }
45376 getMetricIdentifiers() {
45377 if (typeof this.metrics === 'string' ||
45378 typeof this.metrics === 'function') {
45379 return [toSnakeCase(getLossOrMetricName(this.metrics))];
45380 }
45381 else if (Array.isArray(this.metrics)) {
45382 return this.metrics.map(metric => toSnakeCase(getLossOrMetricName(metric)));
45383 }
45384 else {
45385 const metricsIdentifiers = {};
45386 for (const key in this.metrics) {
45387 metricsIdentifiers[key] =
45388 toSnakeCase(getLossOrMetricName(this.metrics[key]));
45389 }
45390 return metricsIdentifiers;
45391 }
45392 }
    /**
     * Assemble the Python-Keras-compatible training config for serialization:
     * snake_case loss/metric identifiers plus the optimizer's class name and
     * config.
     */
    getTrainingConfig() {
        return {
            loss: this.getLossIdentifiers(),
            metrics: this.getMetricIdentifiers(),
            optimizer_config: {
                class_name: this.optimizer.getClassName(),
                config: this.optimizer.getConfig()
            }
        };
        // TODO(cais): Add weight_metrics when they are supported.
        // TODO(cais): Add sample_weight_mode when it's supported.
        // TODO(cais): Add loss_weights when it's supported.
    }
45406 loadTrainingConfig(trainingConfig) {
45407 if (trainingConfig.weighted_metrics != null) {
45408 throw new Error('Loading weight_metrics is not supported yet.');
45409 }
45410 if (trainingConfig.loss_weights != null) {
45411 throw new Error('Loading loss_weights is not supported yet.');
45412 }
45413 if (trainingConfig.sample_weight_mode != null) {
45414 throw new Error('Loading sample_weight_mode is not supported yet.');
45415 }
45416 const tsConfig = convertPythonicToTs(trainingConfig.optimizer_config);
45417 const optimizer = deserialize(tsConfig);
45418 let loss;
45419 if (typeof trainingConfig.loss === 'string') {
45420 loss = toCamelCase(trainingConfig.loss);
45421 }
45422 else if (Array.isArray(trainingConfig.loss)) {
45423 loss = trainingConfig.loss.map(lossEntry => toCamelCase(lossEntry));
45424 }
45425 else if (trainingConfig.loss != null) {
45426 loss = {};
45427 for (const key in trainingConfig.loss) {
45428 loss[key] = toCamelCase(trainingConfig.loss[key]);
45429 }
45430 }
45431 let metrics;
45432 if (Array.isArray(trainingConfig.metrics)) {
45433 metrics = trainingConfig.metrics.map(metric => toCamelCase(metric));
45434 }
45435 else if (trainingConfig.metrics != null) {
45436 metrics = {};
45437 for (const key in trainingConfig.metrics) {
45438 metrics[key] = toCamelCase(trainingConfig.metrics[key]);
45439 }
45440 }
45441 this.compile({ loss, metrics, optimizer });
45442 }
    /**
     * Save the configuration and/or weights of the LayersModel.
     *
     * An `IOHandler` is an object that has a `save` method of the proper
     * signature defined. The `save` method manages the storing or
     * transmission of serialized data ("artifacts") that represent the
     * model's topology and weights onto or via a specific medium, such as
     * file downloads, local storage, IndexedDB in the web browser and HTTP
     * requests to a server. TensorFlow.js provides `IOHandler`
     * implementations for a number of frequently used saving mediums, such as
     * `tf.io.browserDownloads` and `tf.io.browserLocalStorage`. See `tf.io`
     * for more details.
     *
     * This method also allows you to refer to certain types of `IOHandler`s
     * as URL-like string shortcuts, such as 'localstorage://' and
     * 'indexeddb://'.
     *
     * Example 1: Save `model`'s topology and weights to browser [local
     * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
     * then load it back.
     *
     * ```js
     * const model = tf.sequential(
     *     {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
     * console.log('Prediction from original model:');
     * model.predict(tf.ones([1, 3])).print();
     *
     * const saveResults = await model.save('localstorage://my-model-1');
     *
     * const loadedModel = await tf.loadLayersModel('localstorage://my-model-1');
     * console.log('Prediction from loaded model:');
     * loadedModel.predict(tf.ones([1, 3])).print();
     * ```
     *
     * Example 2. Saving `model`'s topology and weights to browser
     * [IndexedDB](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API);
     * then load it back.
     *
     * ```js
     * const model = tf.sequential(
     *     {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
     * console.log('Prediction from original model:');
     * model.predict(tf.ones([1, 3])).print();
     *
     * const saveResults = await model.save('indexeddb://my-model-1');
     *
     * const loadedModel = await tf.loadLayersModel('indexeddb://my-model-1');
     * console.log('Prediction from loaded model:');
     * loadedModel.predict(tf.ones([1, 3])).print();
     * ```
     *
     * Example 3. Saving `model`'s topology and weights as two files
     * (`my-model-1.json` and `my-model-1.weights.bin`) downloaded from
     * browser.
     *
     * ```js
     * const model = tf.sequential(
     *     {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
     * const saveResults = await model.save('downloads://my-model-1');
     * ```
     *
     * Example 4. Send  `model`'s topology and weights to an HTTP server.
     * See the documentation of `tf.io.http` for more details
     * including specifying request parameters and implementation of the
     * server.
     *
     * ```js
     * const model = tf.sequential(
     *     {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
     * const saveResults = await model.save('http://my-server/model/upload');
     * ```
     *
     * @param handlerOrURL An instance of `IOHandler` or a URL-like,
     * scheme-based string shortcut for `IOHandler`.
     * @param config Options for saving the model.
     * @returns A `Promise` of `SaveResult`, which summarizes the result of
     * the saving, such as byte sizes of the saved artifacts for the model's
     * topology and weight values.
     *
     * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
     */
    async save(handlerOrURL, config) {
        // Resolve a scheme-based string shortcut to a concrete IOHandler;
        // exactly one registered handler must match.
        if (typeof handlerOrURL === 'string') {
            const handlers = getSaveHandlers(handlerOrURL);
            if (handlers.length === 0) {
                throw new ValueError(`Cannot find any save handlers for URL '${handlerOrURL}'`);
            }
            else if (handlers.length > 1) {
                throw new ValueError(`Found more than one (${handlers.length}) save handlers for ` +
                    `URL '${handlerOrURL}'`);
            }
            handlerOrURL = handlers[0];
        }
        if (handlerOrURL.save == null) {
            throw new ValueError('LayersModel.save() cannot proceed because the IOHandler ' +
                'provided does not have the `save` attribute defined.');
        }
        // Serialize model weights (optionally trainable-only, per config).
        const weightDataAndSpecs = await encodeWeights(this.getNamedWeights(config));
        const returnString = false;
        const unusedArg = null;
        const modelConfig = this.toJSON(unusedArg, returnString);
        const modelArtifacts = {
            modelTopology: modelConfig,
            format: LAYERS_MODEL_FORMAT_NAME,
            generatedBy: `TensorFlow.js tfjs-layers v${version$6}`,
            convertedBy: null,
        };
        // Optionally append the optimizer's weights and training config so
        // training can be resumed after loading.
        const includeOptimizer = config == null ? false : config.includeOptimizer;
        if (includeOptimizer && this.optimizer != null) {
            modelArtifacts.trainingConfig = this.getTrainingConfig();
            const weightType = 'optimizer';
            const { data: optimizerWeightData, specs: optimizerWeightSpecs } = await encodeWeights(await this.optimizer.getWeights(), weightType);
            weightDataAndSpecs.specs.push(...optimizerWeightSpecs);
            weightDataAndSpecs.data = concatenateArrayBuffers([weightDataAndSpecs.data, optimizerWeightData]);
        }
        if (this.userDefinedMetadata != null) {
            // Check serialized size of user-defined metadata.
            const checkSize = true;
            checkUserDefinedMetadata(this.userDefinedMetadata, this.name, checkSize);
            modelArtifacts.userDefinedMetadata = this.userDefinedMetadata;
        }
        modelArtifacts.weightData = weightDataAndSpecs.data;
        modelArtifacts.weightSpecs = weightDataAndSpecs.specs;
        return handlerOrURL.save(modelArtifacts);
    }
    /**
     * Set user-defined metadata.
     *
     * The set metadata will be serialized together with the topology
     * and weights of the model during `save()` calls.
     *
     * The metadata is validated (e.g., for JSON-serializability) before being
     * stored; an invalid value throws.
     *
     * @param userDefinedMetadata The metadata object to attach to this model.
     */
    setUserDefinedMetadata(userDefinedMetadata) {
        checkUserDefinedMetadata(userDefinedMetadata, this.name);
        this.userDefinedMetadata = userDefinedMetadata;
    }
    /**
     * Get user-defined metadata.
     *
     * The metadata is supplied via one of the two routes:
     *   1. By calling `setUserDefinedMetadata()`.
     *   2. Loaded during model loading (if the model is constructed
     *      via `tf.loadLayersModel()`.)
     *
     * If no user-defined metadata is available from either of the
     * two routes, this function will return `undefined`.
     */
    getUserDefinedMetadata() {
        return this.userDefinedMetadata;
    }
45594 }
    // The class name is 'Model' rather than 'LayersModel' for backwards
    // compatibility since this class name shows up in the serialization format.
    /** @nocollapse */
    LayersModel.className = 'Model';
    // Register so deserialize() can reconstruct instances from saved configs.
    registerClass(LayersModel);
    /**
     * A `tf.Functional` is an alias to `tf.LayersModel`.
     *
     * See also:
     *   `tf.LayersModel`, `tf.Sequential`, `tf.loadLayersModel`.
     */
    /** @doc {heading: 'Models', subheading: 'Classes'} */
    class Functional extends LayersModel {
    }
    // Registered under its own class name so Keras 'Functional'-style
    // serialized models deserialize to this alias.
    Functional.className = 'Functional';
    registerClass(Functional);
45611
45612 /**
45613 * @license
45614 * Copyright 2018 Google LLC
45615 *
45616 * Use of this source code is governed by an MIT-style
45617 * license that can be found in the LICENSE file or at
45618 * https://opensource.org/licenses/MIT.
45619 * =============================================================================
45620 */
45621 /**
45622 * Parses a JSON model configuration file and returns a model instance.
45623 *
45624 * ```js
45625 * // This example shows how to serialize a model using `toJSON()` and
45626 * // deserialize it as another model using `tf.models.modelFromJSON()`.
45627 * // Note: this example serializes and deserializes only the topology
45628 * // of the model; the weights of the loaded model will be different
45629 * // from those of the the original model, due to random weight
45630 * // initialization.
45631 * // To load the topology and weights of a model, use `tf.loadLayersModel()`.
45632 * const model1 = tf.sequential();
45633 * model1.add(tf.layers.repeatVector({inputShape: [2], n: 4}));
45634 * // Serialize `model1` as a JSON object.
45635 * const model1JSON = model1.toJSON(null, false);
45636 * model1.summary();
45637 *
45638 * const model2 = await tf.models.modelFromJSON(model1JSON);
45639 * model2.summary();
45640 * ```
45641 *
45642 * @param modelAndWeightsConfig JSON object or string encoding a model and
45643 * weights configuration. It can also be only the topology JSON of the
45644 * model, in which case the weights will not be loaded.
 * @param customObjects Optional dictionary mapping names
45646 * (strings) to custom classes or functions to be
45647 * considered during deserialization.
45648 * @returns A TensorFlow.js Layers `tf.LayersModel` instance (uncompiled).
45649 */
45650 async function modelFromJSON(modelAndWeightsConfig, customObjects) {
45651 if (!('modelTopology' in modelAndWeightsConfig)) {
45652 modelAndWeightsConfig = { modelTopology: modelAndWeightsConfig };
45653 }
45654 modelAndWeightsConfig = modelAndWeightsConfig;
45655 let modelTopology = modelAndWeightsConfig.modelTopology;
45656 if (modelTopology['model_config'] != null) {
45657 // If the model-topology JSON contains a 'model_config' field, then it is
45658 // a full model JSON (e.g., from `keras.Model.save()`), which contains
45659 // not only the model's architecture in its 'model_config' field, but
45660 // additional information such as the model's optimizer. We use only the
45661 // 'model_config' field currently.
45662 modelTopology = modelTopology['model_config'];
45663 }
45664 const tsConfig = convertPythonicToTs(modelTopology);
45665 const model = deserialize(tsConfig, customObjects);
45666 if (modelAndWeightsConfig.weightsManifest != null) {
45667 // Load the weight values keyed by the original tensor names in the model
45668 // file that was loaded. These should match the keys of the weight
45669 // manifest.
45670 const weightValues = await loadWeights(modelAndWeightsConfig.weightsManifest, modelAndWeightsConfig.pathPrefix, model.weights.map(weight => weight.originalName));
45671 // Map the weights to the unique tensor names generated during model loading
45672 const uniqueWeightValues = {};
45673 for (const weight of model.weights) {
45674 uniqueWeightValues[weight.originalName] =
45675 weightValues[weight.originalName];
45676 }
45677 model.loadWeights(uniqueWeightValues);
45678 // Dispose temporary weight values.
45679 dispose(weightValues);
45680 }
45681 return model;
45682 }
45683 /**
45684 * Load a model composed of Layer objects, including its topology and optionally
45685 * weights. See the Tutorial named "How to import a Keras Model" for usage
45686 * examples.
45687 *
45688 * This method is applicable to:
45689 *
45690 * 1. Models created with the `tf.layers.*`, `tf.sequential`, and
45691 * `tf.model` APIs of TensorFlow.js and later saved with the
45692 * `tf.LayersModel.save` method.
45693 * 2. Models converted from Keras or TensorFlow tf.keras using the
45694 * [tensorflowjs_converter](https://github.com/tensorflow/tfjs/tree/master/tfjs-converter).
45695 *
45696 * This mode is *not* applicable to TensorFlow `SavedModel`s or their converted
45697 * forms. For those models, use `tf.loadGraphModel`.
45698 *
45699 * Example 1. Load a model from an HTTP server.
45700 *
45701 * ```js
45702 * const model = await tf.loadLayersModel(
45703 * 'https://storage.googleapis.com/tfjs-models/tfjs/iris_v1/model.json');
45704 * model.summary();
45705 * ```
45706 *
45707 * Example 2: Save `model`'s topology and weights to browser [local
45708 * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
45709 * then load it back.
45710 *
45711 * ```js
45712 * const model = tf.sequential(
45713 * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
45714 * console.log('Prediction from original model:');
45715 * model.predict(tf.ones([1, 3])).print();
45716 *
45717 * const saveResults = await model.save('localstorage://my-model-1');
45718 *
45719 * const loadedModel = await tf.loadLayersModel('localstorage://my-model-1');
45720 * console.log('Prediction from loaded model:');
45721 * loadedModel.predict(tf.ones([1, 3])).print();
45722 * ```
45723 *
45724 * Example 3. Saving `model`'s topology and weights to browser
45725 * [IndexedDB](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API);
45726 * then load it back.
45727 *
45728 * ```js
45729 * const model = tf.sequential(
45730 * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
45731 * console.log('Prediction from original model:');
45732 * model.predict(tf.ones([1, 3])).print();
45733 *
45734 * const saveResults = await model.save('indexeddb://my-model-1');
45735 *
45736 * const loadedModel = await tf.loadLayersModel('indexeddb://my-model-1');
45737 * console.log('Prediction from loaded model:');
45738 * loadedModel.predict(tf.ones([1, 3])).print();
45739 * ```
45740 *
45741 * Example 4. Load a model from user-selected files from HTML
45742 * [file input
45743 * elements](https://developer.mozilla.org/en-US/docs/Web/HTML/Element/input/file).
45744 *
45745 * ```js
45746 * // Note: this code snippet will not work without the HTML elements in the
45747 * // page
45748 * const jsonUpload = document.getElementById('json-upload');
45749 * const weightsUpload = document.getElementById('weights-upload');
45750 *
45751 * const model = await tf.loadLayersModel(
45752 * tf.io.browserFiles([jsonUpload.files[0], weightsUpload.files[0]]));
45753 * ```
45754 *
45755 * @param pathOrIOHandler Can be either of the two formats
45756 * 1. A string path to the `ModelAndWeightsConfig` JSON describing
45757 * the model in the canonical TensorFlow.js format. For file://
45758 * (tfjs-node-only), http:// and https:// schemas, the path can be
45759 * either absolute or relative. The content of the JSON file is assumed to
45760 * be a JSON object with the following fields and values:
45761 * - 'modelTopology': A JSON object that can be either of:
45762 * 1. a model architecture JSON consistent with the format of the return
45763 * value of `keras.Model.to_json()`
45764 * 2. a full model JSON in the format of `keras.models.save_model()`.
45765 * - 'weightsManifest': A TensorFlow.js weights manifest.
45766 * See the Python converter function `save_model()` for more details.
45767 * It is also assumed that model weights can be accessed from relative
45768 * paths described by the `paths` fields in weights manifest.
45769 * 2. A `tf.io.IOHandler` object that loads model artifacts with its `load`
45770 * method.
45771 * @param options Optional configuration arguments for the model loading,
45772 * including:
45773 * - `strict`: Require that the provided weights exactly match those required
45774 * by the layers. Default true. Passing false means that both extra
45775 * weights and missing weights will be silently ignored.
45776 * - `onProgress`: A progress callback of the form:
45777 * `(fraction: number) => void`. This callback can be used to monitor the
45778 * model-loading process.
45779 * @returns A `Promise` of `tf.LayersModel`, with the topology and weights
45780 * loaded.
45781 *
45782 * @doc {heading: 'Models', subheading: 'Loading'}
45783 */
45784 async function loadLayersModel(pathOrIOHandler, options) {
45785 if (options == null) {
45786 options = {};
45787 }
45788 if (typeof pathOrIOHandler === 'string') {
45789 const handlers = getLoadHandlers(pathOrIOHandler, options);
45790 if (handlers.length === 0) {
45791 // For backward compatibility: if no load handler can be found,
45792 // assume it is a relative http path.
45793 // TODO(cais): Reformat the args into a single `LoadOptions` once the core
45794 // is refactored.
45795 handlers.push(browserHTTPRequest(pathOrIOHandler, options));
45796 }
45797 else if (handlers.length > 1) {
45798 throw new ValueError(`Found more than one (${handlers.length}) load handlers for ` +
45799 `URL '${pathOrIOHandler}'`);
45800 }
45801 pathOrIOHandler = handlers[0];
45802 }
45803 return loadLayersModelFromIOHandler(pathOrIOHandler, undefined, options);
45804 }
45805 /**
45806 * Load a model and optionally its weights, using an IOHandler object.
45807 *
45808 * @param handler The instance of `IOHandler` to be used during the model
45809 * loading.
45810 * @param customObjects Any optional custom objects to be used during model
45811 * loading.
 * @param options Optional configuration arguments; `options.strict` controls
 *   whether the weight loading will be done in strict mode. Default: `true`.
45814 */
45815 async function loadLayersModelFromIOHandler(handler, customObjects, options) {
45816 if (options == null) {
45817 options = {};
45818 }
45819 if (handler.load == null) {
45820 throw new ValueError('Cannot proceed with model loading because the IOHandler provided ' +
45821 'does not have the `load` method implemented.');
45822 }
45823 const artifacts = await handler.load();
45824 let modelTopology = artifacts.modelTopology;
45825 if (modelTopology['model_config'] != null) {
45826 modelTopology = modelTopology['model_config'];
45827 }
45828 const strict = options.strict == null ? true : options.strict;
45829 // If weights are provided and the weight-loading mode is strict, use
45830 // fast weight initialization. This skips costly initializers such as
45831 // 'orthogonal' and saves unnecessary computation in cases where
45832 // the initialized weight values will immediately be overwritten by
45833 // loaded weight values.
45834 const fastWeightInit = artifacts.weightData != null && artifacts.weightSpecs != null && strict;
45835 const model = deserialize(convertPythonicToTs(modelTopology), customObjects, fastWeightInit);
45836 const trainingConfig = artifacts.trainingConfig;
45837 if (trainingConfig != null) {
45838 model.loadTrainingConfig(trainingConfig);
45839 }
45840 if (artifacts.userDefinedMetadata != null) {
45841 model.setUserDefinedMetadata(artifacts.userDefinedMetadata);
45842 }
45843 // If weightData is present, load the weights into the model.
45844 if (artifacts.weightData != null) {
45845 // Loading weights requires weightSpecs.
45846 if (artifacts.weightSpecs == null) {
45847 throw new ValueError('LayersModel artifacts contains weight data, but not weight specs. ' +
45848 'Therefore loading of weights cannot proceed.');
45849 }
45850 const { modelWeights, optimizerWeights } = decodeModelAndOptimizerWeights(artifacts.weightData, artifacts.weightSpecs);
45851 model.loadWeights(modelWeights, strict);
45852 if (model.optimizer != null && optimizerWeights.length > 0) {
45853 await model.optimizer.setWeights(optimizerWeights);
45854 }
45855 // Dispose temporary weight values.
45856 dispose(modelWeights);
45857 dispose(optimizerWeights.map(w => w.tensor));
45858 }
45859 return model;
45860 }
45861 function decodeModelAndOptimizerWeights(weightData, specs) {
45862 const name2Tensor = decodeWeights(weightData, specs);
45863 const modelWeights = {};
45864 const optimizerWeights = [];
45865 specs.forEach(spec => {
45866 if (spec.group === 'optimizer') {
45867 optimizerWeights.push({ name: spec.name, tensor: name2Tensor[spec.name] });
45868 }
45869 else {
45870 modelWeights[spec.name] = name2Tensor[spec.name];
45871 }
45872 });
45873 return { modelWeights, optimizerWeights };
45874 }
45875 /**
45876 * A model with a stack of layers, feeding linearly from one to the next.
45877 *
45878 * `tf.sequential` is a factory function that creates an instance of
45879 * `tf.Sequential`.
45880 *
45881 * ```js
45882 * // Define a model for linear regression.
45883 * const model = tf.sequential();
45884 * model.add(tf.layers.dense({units: 1, inputShape: [1]}));
45885 *
45886 * // Prepare the model for training: Specify the loss and the optimizer.
45887 * model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
45888 *
45889 * // Generate some synthetic data for training.
45890 * const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
45891 * const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);
45892 *
45893 * // Train the model using the data then do inference on a data point the
45894 * // model hasn't seen:
45895 * await model.fit(xs, ys);
45896 * model.predict(tf.tensor2d([5], [1, 1])).print();
45897 * ```
45898 *
45899 * @doc {heading: 'Models', subheading: 'Classes'}
45900 */
45901 class Sequential extends LayersModel {
45902 constructor(args) {
45903 super({ inputs: [], outputs: [] });
45904 args = args || {};
45905 this.trainable = true;
45906 this.built = false;
45907 // Set model name.
45908 this.name = (args.name != null) ? args.name : getUid('sequential_');
45909 // Add to the model any layers passed to the constructor.
45910 if (args.layers != null) {
45911 for (const layer of args.layers) {
45912 this.add(layer);
45913 }
45914 }
45915 }
45916 // Helper function to Sequential.add Throws if the new output shape will be
45917 // invalid.
45918 checkShape(layer) {
45919 const shape = layer.inboundNodes[0].outputTensors[0].shape;
45920 if (shape.some(x => x < 0)) {
45921 throw new ValueError('Negative dimension size caused by adding layer ' +
45922 `${layer.name} with input shape [` +
45923 `${layer.inboundNodes[0].inputTensors[0].shape}]`);
45924 }
45925 }
45926 /**
45927 * Adds a layer instance on top of the layer stack.
45928 *
45929 * ```js
45930 * const model = tf.sequential();
45931 * model.add(tf.layers.dense({units: 8, inputShape: [1]}));
45932 * model.add(tf.layers.dense({units: 4, activation: 'relu6'}));
45933 * model.add(tf.layers.dense({units: 1, activation: 'relu6'}));
45934 * // Note that the untrained model is random at this point.
45935 * model.predict(tf.randomNormal([10, 1])).print();
45936 * ```
45937 * @param layer Layer instance.
45938 *
45939 * @exception ValueError In case the `layer` argument does not know its
45940 * input shape.
45941 * @exception ValueError In case the `layer` argument has multiple output
45942 * tensors, or is already connected somewhere else (forbidden in
45943 * `Sequential` models).
45944 *
45945 * @doc {heading: 'Models', subheading: 'Classes'}
45946 */
        add(layer) {
            // A whole model may be stacked as if it were a single layer, but
            // only if it exposes exactly one input and one output tensor.
            const isLayerModelInstance = layer instanceof Sequential || layer instanceof LayersModel;
            let modelLayer;
            if (isLayerModelInstance) {
                modelLayer = layer;
                if (modelLayer.outputs.length !== 1) {
                    throw new ValueError('All layers in a Sequential model ' +
                        'should have a single output tensor. ' +
                        'For multi-output layers, ' +
                        'use the functional API.');
                }
                if (modelLayer.inputs.length !== 1) {
                    throw new ValueError('All layers in a Sequential model ' +
                        'should have a single input tensor. ' +
                        'For multi-input layers, ' +
                        'use the functional API.');
                }
            }
            if (this.outputs.length === 0) {
                // first layer in model: check that it is an input layer
                if (layer.inboundNodes.length === 0) {
                    // create an input layer
                    if (layer.batchInputShape == null) {
                        throw new ValueError('The first layer in a Sequential model must ' +
                            'get an `inputShape` or `batchInputShape` argument.');
                    }
                    // Instantiate the input layer.
                    const x = Input({
                        batchShape: layer.batchInputShape,
                        dtype: layer.dtype,
                        name: layer.name + '_input'
                    });
                    // This will build the current layer and create the node connecting
                    // the current layer to the input layer we just created.
                    layer.apply(x);
                }
                if (isLayerModelInstance) {
                    // Adopt the sub-model's topology wholesale.
                    this.outputs = modelLayer.outputs;
                    this.inputs = modelLayer.inputs;
                }
                else {
                    // A plain layer must be freshly connected: exactly one
                    // inbound node with exactly one output tensor.
                    if (layer.inboundNodes.length !== 1) {
                        throw new ValueError('A layer added to a Sequential model must not already be ' +
                            `connected somewhere else. LayersModel received layer ${layer.name} ` +
                            `which has ${layer.inboundNodes.length} pre-existing inbound ` +
                            'connections.');
                    }
                    if (layer.inboundNodes[0].outputTensors.length !== 1) {
                        throw new ValueError('All layers in a Sequential model ' +
                            'should have a single output tensor. ' +
                            'For multi-output layers, ' +
                            'use the functional API.');
                    }
                    this.checkShape(layer);
                    this.outputs = [layer.inboundNodes[0].outputTensors[0]];
                    this.inputs = getSourceInputs(this.outputs[0]);
                }
                this.inboundNodes = [];
                // We create an input node, which we will keep updated
                // as we add more layers.
                // (This call has side effects.)
                // tslint:disable-next-line:no-unused-expression
                new Node({
                    outboundLayer: this,
                    inboundLayers: [],
                    nodeIndices: [],
                    tensorIndices: [],
                    inputTensors: this.inputs,
                    outputTensors: this.outputs,
                    // no model-level masking for now
                    inputMasks: pyListRepeat(null, this.inputs.length),
                    outputMasks: [null],
                    inputShapes: this.inputs.map(x => x.shape),
                    outputShapes: this.outputs[0].shape
                });
            }
            else {
                // Not the first layer: feed the model's current output tensor
                // into the new layer and make its output the model's output.
                const outputTensor = layer.apply(this.outputs[0]);
                if (Array.isArray(outputTensor)) {
                    throw new TypeError('All layers in a Sequential model ' +
                        'should have a single output tensor. ' +
                        'For multi-output layers, ' +
                        'use the functional API.');
                }
                this.checkShape(layer);
                this.outputs = [outputTensor];
                // update self.inbound_nodes
                this.inboundNodes[0].outputTensors = this.outputs;
                this.inboundNodes[0].outputShapes = [this.outputs[0].shape];
            }
            this.layers.push(layer);
            // Force a rebuild of the inner LayersModel before the next use.
            this.built = false;
        }
46040 /**
46041 * Removes the last layer in the model.
46042 *
46043 * @exception TypeError if there are no layers in the model.
46044 */
46045 pop() {
46046 if (this.layers.length === 0) {
46047 throw new TypeError('There are no layers in the model.');
46048 }
46049 this.layers.pop();
46050 if (this.layers.length === 0) {
46051 this.outputs = [];
46052 this.inboundNodes = [];
46053 this.outboundNodes = [];
46054 }
46055 else {
46056 const lastLayerIndex = this.layers.length - 1;
46057 this.layers[lastLayerIndex].outboundNodes = [];
46058 this.outputs = [this.layers[lastLayerIndex].output];
46059 // update self.inbound_nodes
46060 this.inboundNodes[0].outputTensors = this.outputs;
46061 this.inboundNodes[0].outputShapes = [this.outputs[0].shape];
46062 }
46063 }
        call(inputs, kwargs) {
            // The inner `LayersModel` is created lazily; build it on first use
            // before delegating.
            if (this.model == null) {
                this.build();
            }
            return this.model.call(inputs, kwargs);
        }
46070 build(inputShape) {
46071 // Call `getExactlyOneShape` without using its return value,
46072 // to verify that exactly one input shape is provided.
46073 getExactlyOneShape(inputShape);
46074 if (this.inputs.length === 0 || this.outputs.length === 0) {
46075 throw new TypeError('Sequential model cannot be built: model is empty.' +
46076 ' Add some layers first.');
46077 }
46078 // actually create the model
46079 this.model = new LayersModel({
46080 inputs: this.inputs,
46081 outputs: this.outputs[0],
46082 name: this.name + '_model'
46083 });
46084 this.model.trainable = this.trainable;
46085 // mirror model attributes
46086 this.supportsMasking = this.model.supportsMasking;
46087 // TODO(michaelterry): Add caches
46088 this.inputLayers = this.model.inputLayers;
46089 this.inputLayersNodeIndices = this.model.inputLayersNodeIndices;
46090 this.inputLayersTensorIndices = this.model.inputLayersTensorIndices;
46091 this.outputLayers = this.model.outputLayers;
46092 this.outputLayersNodeIndices = this.model.outputLayersNodeIndices;
46093 this.outputLayersTensorIndices = this.model.outputLayersTensorIndices;
46094 this.nodesByDepth = this.model.nodesByDepth;
46095 this.containerNodes = this.model.containerNodes;
46096 this.outputNames = this.model.outputNames;
46097 this.inputNames = this.model.inputNames;
46098 // TODO(michaelterry): Add feedInputNames, feedInputs, if needed.
46099 // TODO(michaelterry): Add callbackModel if needed.
46100 this.built = true;
46101 }
        countParams() {
            // The model must be built before weights exist to be counted.
            if (!this.built) {
                this.build();
            }
            return super.countParams();
        }
46108 /**
46109 * Print a text summary of the Sequential model's layers.
46110 *
46111 * The summary includes
46112 * - Name and type of all layers that comprise the model.
46113 * - Output shape(s) of the layers
46114 * - Number of weight parameters of each layer
46115 * - The total number of trainable and non-trainable parameters of the
46116 * model.
46117 *
46118 * ```js
46119 * const model = tf.sequential();
46120 * model.add(
46121 * tf.layers.dense({units: 100, inputShape: [10], activation: 'relu'}));
46122 * model.add(tf.layers.dense({units: 1, activation: 'sigmoid'}));
46123 *
46124 * model.summary();
46125 * ```
46126 *
46127 * @param lineLength Custom line length, in number of characters.
46128 * @param positions Custom widths of each of the columns, as either
46129 * fractions of `lineLength` (e.g., `[0.5, 0.75, 1]`) or absolute number
46130 * of characters (e.g., `[30, 50, 65]`). Each number corresponds to
46131 * right-most (i.e., ending) position of a column.
46132 * @param printFn Custom print function. Can be used to replace the default
46133 * `console.log`. For example, you can use `x => {}` to mute the printed
46134 * messages in the console.
46135 *
46136 * @doc {heading: 'Models', subheading: 'Classes'}
46137 */
        summary(lineLength, positions, printFn = console.log) {
            // The summary is produced by the parent class; ensure the inner
            // model exists first so layer/output-shape info is available.
            if (!this.built) {
                this.build();
            }
            super.summary(lineLength, positions, printFn);
        }
46144 /**
46145 * Sets the weights of the model.
46146 *
46147 * @param weights Should be a list of Tensors with shapes and types matching
46148 * the output of `model.getWeights()`.
46149 */
        setWeights(weights) {
            // Delegates to the inner model, building it first if necessary.
            if (this.model == null) {
                this.build();
            }
            this.model.setWeights(weights);
        }
46156 /**
46157 * Returns the loss value & metrics values for the model in test mode.
46158 *
46159 * Loss and metrics are specified during `compile()`, which needs to happen
46160 * before calls to `evaluate()`.
46161 *
46162 * Computation is done in batches.
46163 *
46164 * ```js
46165 * const model = tf.sequential({
46166 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
46167 * });
46168 * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
46169 * const result = model.evaluate(tf.ones([8, 10]), tf.ones([8, 1]), {
46170 * batchSize: 4,
46171 * });
46172 * result.print();
46173 * ```
46174 *
46175 * @param x `tf.Tensor` of test data, or an `Array` of `tf.Tensor`s if the
46176 * model has multiple inputs.
46177 * @param y `tf.Tensor` of target data, or an `Array` of `tf.Tensor`s if the
46178 * model has multiple outputs.
46179 * @param args A `ModelEvaluateConfig`, containing optional fields.
46180 *
46181 * @return `Scalar` test loss (if the model has a single output and no
46182 * metrics) or `Array` of `Scalar`s (if the model has multiple outputs
46183 * and/or metrics). The attribute `model.metricsNames`
46184 * will give you the display labels for the scalar outputs.
46185 *
46186 * @doc {heading: 'Models', subheading: 'Classes'}
46187 */
        evaluate(x, y, args = {}) {
            // Unlike predict(), evaluation requires a compiled model (loss and
            // metrics are defined by compile(), which sets `built`).
            if (!this.built) {
                throw new RuntimeError('The model needs to be compiled before being used.');
            }
            return this.model.evaluate(x, y, args);
        }
46194 // TODO(cais): Add code snippet below once real dataset objects are
46195 // available.
46196 /**
46197 * Evaluate model using a dataset object.
46198 *
46199 * Note: Unlike `evaluate()`, this method is asynchronous (`async`).
46200 *
46201 * @param dataset A dataset object. Its `iterator()` method is expected
46202 * to generate a dataset iterator object, the `next()` method of which
46203 * is expected to produce data batches for evaluation. The return value
46204 * of the `next()` call ought to contain a boolean `done` field and a
46205 * `value` field. The `value` field is expected to be an array of two
46206 * `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former
46207 * case is for models with exactly one input and one output (e.g.
46208 * a sequential model). The latter case is for models with multiple
46209 * inputs and/or multiple outputs. Of the two items in the array, the
46210 * first is the input feature(s) and the second is the output target(s).
46211 * @param args A configuration object for the dataset-based evaluation.
46212 * @returns Loss and metric values as an Array of `Scalar` objects.
46213 *
46214 * @doc {heading: 'Models', subheading: 'Classes'}
46215 */
        async evaluateDataset(dataset, args) {
            // Requires a compiled model, same as evaluate().
            if (!this.built) {
                throw new RuntimeError('The model needs to be compiled before being used.');
            }
            return this.model.evaluateDataset(dataset, args);
        }
46222 /**
46223 * Generates output predictions for the input samples.
46224 *
46225 * Computation is done in batches.
46226 *
46227 * Note: the "step" mode of predict() is currently not supported.
46228 * This is because the TensorFlow.js core backend is imperative only.
46229 *
46230 * ```js
46231 * const model = tf.sequential({
46232 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
46233 * });
46234 * model.predict(tf.ones([2, 10])).print();
46235 * ```
46236 *
46237 * @param x The input data, as a Tensor, or an `Array` of `tf.Tensor`s if
46238 * the model has multiple inputs.
 * @param args A `ModelPredictConfig` object containing optional fields.
46240 *
46241 * @return `tf.Tensor`(s) of predictions.
46242 *
46243 * @exception ValueError In case of mismatch between the provided input data
46244 * and the model's expectations, or in case a stateful model receives a
46245 * number of samples that is not a multiple of the batch size.
46246 *
46247 * @doc {heading: 'Models', subheading: 'Classes'}
46248 */
        predict(x, args = {}) {
            // Prediction does not require compile(); only the inner model,
            // which is built lazily here.
            if (this.model == null) {
                this.build();
            }
            return this.model.predict(x, args);
        }
46255 /**
46256 * Returns predictions for a single batch of samples.
46257 *
46258 * @param x: Input samples, as a Tensor, or list of Tensors (if the model
46259 * has multiple inputs).
46260 * @return Tensor(s) of predictions
46261 */
        predictOnBatch(x) {
            // Delegates to the inner model, building it first if necessary.
            if (this.model == null) {
                this.build();
            }
            return this.model.predictOnBatch(x);
        }
46268 /**
46269 * See `LayersModel.compile`.
46270 *
46271 * @param args
46272 */
46273 compile(args) {
46274 this.build();
46275 this.model.compile(args);
46276 this.optimizer_ = this.model.optimizer;
46277 // tslint:disable-next-line:no-any
46278 this.isOptimizerOwned = this.model.isOptimizerOwned;
46279 this.loss = this.model.loss;
46280 this.metrics = this.model.metrics;
46281 // TODO(cais): Add this.lossWeights, this.sampleWeightMode,
46282 // this.weightedMetrics, this.targets.
46283 this.metricsTensors = this.model.metricsTensors;
46284 this.metricsNames = this.model.metricsNames;
46285 // TODO(cais): Add sampleWeights.
46286 }
46287 get optimizer() {
46288 return this.model == null ? undefined : this.model.optimizer;
46289 }
46290 set optimizer(optimizer) {
46291 this.model.optimizer = optimizer;
46292 }
46293 /**
46294 * Trains the model for a fixed number of epochs (iterations on a dataset).
46295 *
46296 * ```js
46297 * const model = tf.sequential({
46298 * layers: [tf.layers.dense({units: 1, inputShape: [10]})]
46299 * });
46300 * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
46301 * const history = await model.fit(tf.ones([8, 10]), tf.ones([8, 1]), {
46302 * batchSize: 4,
46303 * epochs: 3
46304 * });
46305 * console.log(history.history.loss[0]);
46306 * ```
46307 *
46308 * @param x `tf.Tensor` of training data, or an array of `tf.Tensor`s if the
46309 * model has multiple inputs. If all inputs in the model are named, you can
46310 * also pass a dictionary mapping input names to `tf.Tensor`s.
46311 * @param y `tf.Tensor` of target (label) data, or an array of `tf.Tensor`s if
46312 * the model has multiple outputs. If all outputs in the model are named, you
46313 * can also pass a dictionary mapping output names to `tf.Tensor`s.
46314 * @param args A `ModelFitConfig`, containing optional fields.
46315 *
46316 * @return A `History` instance. Its `history` attribute contains all
46317 * information collected during training.
46318 *
46319 * @exception ValueError In case of mismatch between the provided input data
46320 * and what the model expects.
46321 *
46322 * @doc {heading: 'Models', subheading: 'Classes'}
46323 */
        async fit(x, y, args = {}) {
            // Training requires a compiled model (compile() sets `built`).
            if (!this.built) {
                throw new RuntimeError('The model needs to be compiled before ' +
                    'being used.');
            }
            return this.model.fit(x, y, args);
        }
46331 /**
46332 * Trains the model using a dataset object.
46333 *
46334 * ```js
46335 * const xArray = [
46336 * [1, 1, 1, 1, 1, 1, 1, 1, 1],
46337 * [1, 1, 1, 1, 1, 1, 1, 1, 1],
46338 * [1, 1, 1, 1, 1, 1, 1, 1, 1],
46339 * [1, 1, 1, 1, 1, 1, 1, 1, 1],
46340 * ];
46341 * const yArray = [1, 1, 1, 1];
46342 * // Create a dataset from the JavaScript array.
46343 * const xDataset = tf.data.array(xArray);
46344 * const yDataset = tf.data.array(yArray);
46345 * // Zip combines the `x` and `y` Datasets into a single Dataset, the
46346 * // iterator of which will return an object containing of two tensors,
46347 * // corresponding to `x` and `y`. The call to `batch(4)` will bundle
46348 * // four such samples into a single object, with the same keys now pointing
46349 * // to tensors that hold 4 examples, organized along the batch dimension.
46350 * // The call to `shuffle(4)` causes each iteration through the dataset to
46351 * // happen in a different order. The size of the shuffle window is 4.
46352 * const xyDataset = tf.data.zip({xs: xDataset, ys: yDataset})
46353 * .batch(4)
46354 * .shuffle(4);
46355 * const model = tf.sequential({
46356 * layers: [tf.layers.dense({units: 1, inputShape: [9]})]
46357 * });
46358 * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
46359 * const history = await model.fitDataset(xyDataset, {
46360 * epochs: 4,
46361 * callbacks: {onEpochEnd: (epoch, logs) => console.log(logs.loss)}
46362 * });
46363 * ```
46364 *
46365 * @param dataset A dataset object. Its `iterator()` method is expected to
46366 * generate a dataset iterator object, the `next()` method of which is
46367 * expected to produce data batches for evaluation. The return value of the
46368 * `next()` call ought to contain a boolean `done` field and a `value`
46369 * field.
46370 *
46371 * The `value` field is expected to be an object of with fields
46372 * `xs` and `ys`, which point to the feature tensor and the target tensor,
46373 * respectively. This case is for models with exactly one input and one
46374 * output (e.g. a sequential model). For example:
46375 * ```js
46376 * {value: {xs: xsTensor, ys: ysTensor}, done: false}
46377 * ```
46378 *
46379 * If the model has multiple inputs, the `xs` field of `value` should
46380 * be an object mapping input names to their respective feature tensors.
46381 * For example:
46382 * ```js
46383 * {
46384 * value: {
46385 * xs: {
46386 * input_1: xsTensor1,
46387 * input_2: xsTensor2
46388 * },
46389 * ys: ysTensor
46390 * },
46391 * done: false
46392 * }
46393 * ```
46394 * If the model has multiple outputs, the `ys` field of `value` should
46395 * be an object mapping output names to their respective target tensors.
46396 * For example:
46397 * ```js
46398 * {
46399 * value: {
46400 * xs: xsTensor,
46401 * ys: {
46402 * output_1: ysTensor1,
46403 * output_2: ysTensor2
46404 * },
46405 * },
46406 * done: false
46407 * }
46408 * ```
46409 * @param args A `ModelFitDatasetArgs`, containing optional fields.
46410 *
46411 * @return A `History` instance. Its `history` attribute contains all
46412 * information collected during training.
46413 *
46414 * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
46415 */
        async fitDataset(dataset, args) {
            // Training requires a compiled model, same as fit().
            if (!this.built) {
                throw new RuntimeError('The model needs to be compiled before ' +
                    'being used.');
            }
            return this.model.fitDataset(dataset, args);
        }
/**
 * Runs a single gradient update on a single batch of data.
 *
 * This method differs from `fit()` and `fitDataset()` in the following
 * regards:
 *   - It operates on exactly one batch of data.
 *   - It returns only the loss and metric values, instead of
 *     returning the batch-by-batch loss and metric values.
 *   - It doesn't support fine-grained options such as verbosity and
 *     callbacks.
 *
 * @param x Input data. It could be one of the following:
 *   - A `tf.Tensor`, or an Array of `tf.Tensor`s (in case the model has
 *     multiple inputs).
 *   - An Object mapping input names to corresponding `tf.Tensor` (if the
 *     model has named inputs).
 * @param y Target data. It could be either a `tf.Tensor` or multiple
 *   `tf.Tensor`s. It should be consistent with `x`.
 * @returns Training loss or losses (in case the model has
 *   multiple outputs), along with metrics (if any), as numbers.
 *
 * @doc {heading: 'Models', subheading: 'Classes'}
 */
async trainOnBatch(x, y) {
    // Pure delegation to the composed LayersModel (composition pattern).
    return this.model.trainOnBatch(x, y);
}
46449 /* See parent class for JsDoc */
46450 /** @nocollapse */
46451 static fromConfig(cls, config, customObjects = {}, fastWeightInit = false) {
46452 let configArray;
46453 let extraModelConfig = {};
46454 if (config instanceof Array) {
46455 if (!(config[0].className != null) ||
46456 config[0]['className'] === 'Merge') {
46457 throw new ValueError('Legacy serialization format not supported yet.');
46458 }
46459 configArray = config;
46460 }
46461 else {
46462 assert$1(config['layers'] != null, () => `When the config data for a Sequential model is not an Array, ` +
46463 `it must be an Object that contains the 'layers' field.`);
46464 configArray = config['layers'];
46465 delete config['layers'];
46466 extraModelConfig = config;
46467 }
46468 const model = new cls(extraModelConfig);
46469 if (!(model instanceof Sequential)) {
46470 throw new NotImplementedError(`Sequential.fromConfig called on non-Sequential input: ${model}`);
46471 }
46472 for (const conf of configArray) {
46473 const customObjects = undefined;
46474 const layer = deserialize(conf, customObjects, fastWeightInit);
46475 if (fastWeightInit) {
46476 layer.setFastWeightInitDuringBuild(true);
46477 }
46478 model.add(layer);
46479 }
46480 return model;
46481 }
46482 /**
46483 * Setter used for force stopping of LayersModel.fit() (i.e., training).
46484 *
46485 * Example:
46486 *
46487 * ```js
46488 * const model = tf.sequential();
46489 * model.add(tf.layers.dense({units: 1, inputShape: [10]}));
46490 * model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
46491 * const xs = tf.ones([8, 10]);
46492 * const ys = tf.zeros([8, 1]);
46493 *
46494 * const history = await model.fit(xs, ys, {
46495 * epochs: 10,
46496 * callbacks: {
46497 * onEpochEnd: async (epoch, logs) => {
46498 * if (epoch === 2) {
46499 * model.stopTraining = true;
46500 * }
46501 * }
46502 * }
46503 * });
46504 *
46505 * // There should be only 3 values in the loss array, instead of 10 values,
46506 * // due to the stopping after 3 epochs.
46507 * console.log(history.history.loss);
46508 * ```
46509 */
46510 set stopTraining(stop) {
46511 // TODO(cais): When refactoring to remove the composition pattern happens,
46512 // remove this method overriding.
46513 if (this.model == null) {
46514 throw new ValueError('Cannot set the stopTraining property of a sequential model before ' +
46515 'it is compiled.');
46516 }
46517 this.model.stopTraining = stop;
46518 }
46519 get stopTraining() {
46520 if (this.model == null) {
46521 throw new ValueError('Cannot get the stopTraining property of a sequential model before ' +
46522 'it is compiled.');
46523 }
46524 return this.model.stopTraining;
46525 }
46526 // TODO(cais): Override get trainableWeights() here
46527 // tslint:disable-next-line:no-any
46528 getConfig() {
46529 // NOTE(cais): We override the return type of getConfig() to `any` here,
46530 // because the `Sequential` class is a special case among `Container`
46531 // subtypes in that its getConfig() method returns an Array (not a
46532 // dict).
46533 const layers = [];
46534 for (const layer of this.layers) {
46535 const dict = {};
46536 dict['className'] = layer.getClassName();
46537 dict['config'] = layer.getConfig();
46538 layers.push(dict);
46539 }
46540 return { name: this.name, layers };
46541 }
46542 }
/** @nocollapse */
Sequential.className = 'Sequential';
// Register so the class can be looked up by name during deserialization.
registerClass(Sequential);
46546
46547 /**
46548 * @license
46549 * Copyright 2018 Google LLC
46550 *
46551 * Use of this source code is governed by an MIT-style
46552 * license that can be found in the LICENSE file or at
46553 * https://opensource.org/licenses/MIT.
46554 * =============================================================================
46555 */
46556 // TODO(cais): Add doc string to all the public static functions in this
// class; include executable JavaScript code snippets where applicable
46558 // (b/74074458).
46559 // LayersModel and related factory methods.
46560 /**
46561 * A model is a data structure that consists of `Layers` and defines inputs
46562 * and outputs.
46563 *
46564 * The key difference between `tf.model` and `tf.sequential` is that
46565 * `tf.model` is more generic, supporting an arbitrary graph (without
46566 * cycles) of layers. `tf.sequential` is less generic and supports only a linear
46567 * stack of layers.
46568 *
46569 * When creating a `tf.LayersModel`, specify its input(s) and output(s). Layers
46570 * are used to wire input(s) to output(s).
46571 *
46572 * For example, the following code snippet defines a model consisting of
46573 * two `dense` layers, with 10 and 4 units, respectively.
46574 *
46575 * ```js
46576 * // Define input, which has a size of 5 (not including batch dimension).
46577 * const input = tf.input({shape: [5]});
46578 *
46579 * // First dense layer uses relu activation.
46580 * const denseLayer1 = tf.layers.dense({units: 10, activation: 'relu'});
46581 * // Second dense layer uses softmax activation.
46582 * const denseLayer2 = tf.layers.dense({units: 4, activation: 'softmax'});
46583 *
46584 * // Obtain the output symbolic tensor by applying the layers on the input.
46585 * const output = denseLayer2.apply(denseLayer1.apply(input));
46586 *
46587 * // Create the model based on the inputs.
46588 * const model = tf.model({inputs: input, outputs: output});
46589 *
46590 * // The model can be used for training, evaluation and prediction.
46591 * // For example, the following line runs prediction with the model on
46592 * // some fake data.
46593 * model.predict(tf.ones([2, 5])).print();
46594 * ```
46595 * See also:
46596 * `tf.sequential`, `tf.loadLayersModel`.
46597 *
46598 * @doc {heading: 'Models', subheading: 'Creation'}
46599 */
function model(args) {
    // Thin factory wrapper around the LayersModel constructor; see the
    // JSDoc above for usage examples.
    return new LayersModel(args);
}
46603 /**
46604 * Creates a `tf.Sequential` model. A sequential model is any model where the
46605 * outputs of one layer are the inputs to the next layer, i.e. the model
46606 * topology is a simple 'stack' of layers, with no branching or skipping.
46607 *
46608 * This means that the first layer passed to a `tf.Sequential` model should have
46609 * a defined input shape. What that means is that it should have received an
46610 * `inputShape` or `batchInputShape` argument, or for some type of layers
46611 * (recurrent, Dense...) an `inputDim` argument.
46612 *
46613 * The key difference between `tf.model` and `tf.sequential` is that
46614 * `tf.sequential` is less generic, supporting only a linear stack of layers.
46615 * `tf.model` is more generic and supports an arbitrary graph (without
46616 * cycles) of layers.
46617 *
46618 * Examples:
46619 *
46620 * ```js
46621 * const model = tf.sequential();
46622 *
46623 * // First layer must have an input shape defined.
46624 * model.add(tf.layers.dense({units: 32, inputShape: [50]}));
46625 * // Afterwards, TF.js does automatic shape inference.
46626 * model.add(tf.layers.dense({units: 4}));
46627 *
46628 * // Inspect the inferred shape of the model's output, which equals
46629 * // `[null, 4]`. The 1st dimension is the undetermined batch dimension; the
46630 * // 2nd is the output size of the model's last layer.
46631 * console.log(JSON.stringify(model.outputs[0].shape));
46632 * ```
46633 *
46634 * It is also possible to specify a batch size (with potentially undetermined
46635 * batch dimension, denoted by "null") for the first layer using the
46636 * `batchInputShape` key. The following example is equivalent to the above:
46637 *
46638 * ```js
46639 * const model = tf.sequential();
46640 *
46641 * // First layer must have a defined input shape
46642 * model.add(tf.layers.dense({units: 32, batchInputShape: [null, 50]}));
46643 * // Afterwards, TF.js does automatic shape inference.
46644 * model.add(tf.layers.dense({units: 4}));
46645 *
46646 * // Inspect the inferred shape of the model's output.
46647 * console.log(JSON.stringify(model.outputs[0].shape));
46648 * ```
46649 *
46650 * You can also use an `Array` of already-constructed `Layer`s to create
46651 * a `tf.Sequential` model:
46652 *
46653 * ```js
46654 * const model = tf.sequential({
46655 * layers: [tf.layers.dense({units: 32, inputShape: [50]}),
46656 * tf.layers.dense({units: 4})]
46657 * });
46658 * console.log(JSON.stringify(model.outputs[0].shape));
46659 * ```
46660 *
46661 * @doc {heading: 'Models', subheading: 'Creation'}
46662 */
function sequential(config) {
    // Thin factory wrapper around the Sequential constructor; see the
    // JSDoc above for usage examples.
    return new Sequential(config);
}
46666 /**
46667 * Used to instantiate an input to a model as a `tf.SymbolicTensor`.
46668 *
46669 * Users should call the `input` factory function for
46670 * consistency with other generator functions.
46671 *
46672 * Example:
46673 *
46674 * ```js
46675 * // Defines a simple logistic regression model with 32 dimensional input
46676 * // and 3 dimensional output.
46677 * const x = tf.input({shape: [32]});
46678 * const y = tf.layers.dense({units: 3, activation: 'softmax'}).apply(x);
46679 * const model = tf.model({inputs: x, outputs: y});
46680 * model.predict(tf.ones([2, 32])).print();
46681 * ```
46682 *
46683 * Note: `input` is only necessary when using `model`. When using
46684 * `sequential`, specify `inputShape` for the first layer or use `inputLayer`
46685 * as the first layer.
46686 *
46687 * @doc {heading: 'Models', subheading: 'Inputs'}
46688 */
function input(config) {
    // Thin factory wrapper around the Input() function, exposed for
    // API-naming consistency with other generator functions.
    return Input(config);
}
function registerCallbackConstructor(verbosityLevel, callbackConstructor) {
    // Delegates to the CallbackConstructorRegistry singleton.
    CallbackConstructorRegistry.registerCallbackConstructor(verbosityLevel, callbackConstructor);
}
46695
46696 /**
46697 * @license
46698 * Copyright 2018 Google LLC
46699 *
46700 * Use of this source code is governed by an MIT-style
46701 * license that can be found in the LICENSE file or at
46702 * https://opensource.org/licenses/MIT.
46703 * =============================================================================
46704 */
/**
 * Base class for Activations.
 *
 * Special note: due to cross-language compatibility reasons, the
 * static readonly className field in this family of classes must be set to
 * the initialLowerCamelCase name of the activation.
 */
let Activation$1 = class Activation extends Serializable {
    // Activations are stateless, so there is no configuration to persist.
    getConfig() {
        return {};
    }
};
/**
 * Exponential linear unit (ELU).
 * Reference: https://arxiv.org/abs/1511.07289
 */
class Elu extends Activation$1 {
    /**
     * Calculate the activation function.
     *
     * @param x: Input.
     * @param alpha: Scaling factor the negative section.
     * @return Output of the ELU activation.
     */
    apply(x, alpha = 1) {
        // Delegates to the core ELU op.
        return elu$3(x, alpha);
    }
}
/** @nocollapse */
Elu.className = 'elu';
registerClass(Elu);
/**
 * Scaled Exponential Linear Unit. (Klambauer et al., 2017).
 * Reference: Self-Normalizing Neural Networks, https://arxiv.org/abs/1706.02515
 * Notes:
 *   - To be used together with the initialization "lecunNormal".
 *   - To be used together with the dropout variant "AlphaDropout".
 */
class Selu extends Activation$1 {
    // Element-wise SELU; delegates to the core op.
    apply(x) {
        return selu$2(x);
    }
}
/** @nocollapse */
Selu.className = 'selu';
registerClass(Selu);
/**
 * Rectified linear unit: max(0, x), applied element-wise.
 */
class Relu extends Activation$1 {
    // Delegates to the core ReLU op.
    apply(x) {
        return relu$2(x);
    }
}
/** @nocollapse */
Relu.className = 'relu';
registerClass(Relu);
/**
 * Rectified linear unit activation maxing out at 6.0.
 */
class Relu6 extends Activation$1 {
    // relu6(x) = min(6, max(0, x)); wrapped in tidy so the intermediate
    // relu tensor is disposed.
    apply(x) {
        return tidy(() => minimum$4(6.0, relu$2(x)));
    }
}
/** @nocollapse */
Relu6.className = 'relu6';
registerClass(Relu6);
/** Linear activation (no-op). */
class Linear extends Activation$1 {
    // Identity activation: returns the input tensor unmodified.
    apply(x) {
        return x;
    }
}
/** @nocollapse */
Linear.className = 'linear';
registerClass(Linear);
/**
 * Sigmoid activation function: 1 / (1 + exp(-x)), element-wise.
 */
class Sigmoid extends Activation$1 {
    // Delegates to the core sigmoid op.
    apply(x) {
        return sigmoid$2(x);
    }
}
/** @nocollapse */
Sigmoid.className = 'sigmoid';
registerClass(Sigmoid);
/**
 * Segment-wise linear approximation of sigmoid.
 */
class HardSigmoid extends Activation$1 {
    // Delegates to the hardSigmoid helper defined elsewhere in this file.
    apply(x) {
        return hardSigmoid(x);
    }
}
/** @nocollapse */
HardSigmoid.className = 'hardSigmoid';
registerClass(HardSigmoid);
/**
 * Softplus activation function: log(exp(x) + 1), element-wise.
 */
class Softplus extends Activation$1 {
    // Delegates to the core softplus op.
    apply(x) {
        return softplus$2(x);
    }
}
/** @nocollapse */
Softplus.className = 'softplus';
registerClass(Softplus);
/**
 * Softsign activation function.
 */
class Softsign extends Activation$1 {
    // Delegates to the softsign helper defined elsewhere in this file.
    apply(x) {
        return softsign(x);
    }
}
/** @nocollapse */
Softsign.className = 'softsign';
registerClass(Softsign);
/**
 * Hyperbolic tangent function, element-wise.
 */
class Tanh extends Activation$1 {
    // Delegates to the core tanh op.
    apply(x) {
        return tanh$2(x);
    }
}
/** @nocollapse */
Tanh.className = 'tanh';
registerClass(Tanh);
/**
 * Softmax activation function
 */
let Softmax$1 = class Softmax extends Activation$1 {
    /**
     * Calculate the activation function.
     *
     * @param x Tensor.
     * @param axis Integer, axis along which the softmax normalization is applied.
     * Invalid if < 2, as softmax across 1 (the batch dimension) is assumed to be
     * an error.
     *
     * @returns a Tensor of the same shape as x
     *
     * @throws ValueError: In case `dim(x) < 2`.
     */
    apply(x, axis = (-1)) {
        // Delegates to the core softmax op; default normalizes the last axis.
        return softmax$3(x, axis);
    }
};
/** @nocollapse */
Softmax$1.className = 'softmax';
registerClass(Softmax$1);
/**
 * Log softmax activation function
 */
class LogSoftmax extends Activation$1 {
    /**
     * Calculate the activation function of log softmax:
     * log( exp(x_i) / sum(exp(x)) )
     *
     * @param x Tensor.
     * @param axis Integer, axis along which the softmax normalization is applied.
     * Invalid if < 2, as softmax across 1 (the batch dimension) is assumed to be
     * an error.
     *
     * @returns a Tensor of the same shape as x
     *
     * @throws ValueError: In case `dim(x) < 2`.
     */
    apply(x, axis = (-1)) {
        // Delegates to the core logSoftmax op.
        return logSoftmax(x, axis);
    }
}
/** @nocollapse */
LogSoftmax.className = 'logSoftmax';
registerClass(LogSoftmax);
/**
 * Gelu activation function (exact form).
 *
 * GELU(x) = x * Φ(x), where Φ is the standard normal CDF, computed via the
 * error function: Φ(x) = 0.5 * (1 + erf(x / √2)).
 */
class Gelu extends Activation$1 {
    /**
     * Calculate the activation function.
     *
     * @param x Tensor.
     * @returns a Tensor of the same shape as x
     */
    apply(x) {
        // Fix: the original wrapped the computation in two identical nested
        // `tidy` calls; a single tidy scope is sufficient and the inner one
        // was pure overhead.
        return tidy(() => {
            const sqrtTwo = Math.sqrt(2);
            // Compute Φ(x) using the erf function.
            const cdf = mul(0.5, add$3(1, erf$2(div$1(x, sqrtTwo))));
            // Compute GELU(x) = x * Φ(x).
            return mul(x, cdf);
        });
    }
}
/** @nocollapse */
Gelu.className = 'gelu';
registerClass(Gelu);
/**
 * GeluNew activation function: the tanh approximation of GELU,
 * 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))).
 */
class GeluNew extends Activation$1 {
    /**
     * Calculate the activation function.
     *
     * @param x Tensor.
     * @returns a Tensor of the same shape as x
     */
    apply(x) {
        return tidy(() => {
            const coeff = sqrt$2(div$1(2, Math.PI));
            const cubicTerm = add$3(x, mul(0.044715, pow$3(x, 3)));
            const gate = add$3(1, tanh$2(mul(coeff, cubicTerm)));
            return mul(0.5, mul(x, gate));
        });
    }
}
/** @nocollapse */
GeluNew.className = 'gelu_new';
registerClass(GeluNew);
/**
 * Mish activation function: x * tanh(softplus(x)), element-wise.
 */
class Mish extends Activation$1 {
    /**
     * Calculate the activation function.
     *
     * @param x Tensor.
     * @returns a Tensor of the same shape as x
     */
    apply(x) {
        return tidy(() => {
            const sp = softplus$2(x);
            return mul(x, tanh$2(sp));
        });
    }
}
/** @nocollapse */
Mish.className = 'mish';
registerClass(Mish);
/**
 * Swish activation function: x * sigmoid(alpha * x), element-wise.
 */
class Swish extends Activation$1 {
    /**
     * Calculate the activation function.
     *
     * @param x Tensor.
     * @param alpha Scaling factor for the sigmoid function.
     * @returns a Tensor of the same shape as x
     */
    apply(x, alpha = 1) {
        return tidy(() => {
            const gate = sigmoid$2(mul(x, alpha));
            return mul(gate, x);
        });
    }
}
/** @nocollapse */
Swish.className = 'swish';
registerClass(Swish);
// Serializes an Activation instance to its registered class-name string.
function serializeActivation(activation) {
    return activation.getClassName();
}
// Deserializes an activation config dict via the global SerializationMap.
function deserializeActivation(config, customObjects = {}) {
    return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'activation');
}
/**
 * Resolves an activation identifier to an Activation instance.
 *
 * Accepts null/undefined (→ linear activation), a class-name string,
 * an Activation instance (returned as-is), or a serialized config dict.
 */
function getActivation(identifier) {
    if (identifier == null) {
        // No identifier: fall back to the linear (identity) activation.
        return deserializeActivation({ 'className': 'linear', 'config': {} });
    }
    if (typeof identifier === 'string') {
        return deserializeActivation({ 'className': identifier, 'config': {} });
    }
    if (identifier instanceof Activation$1) {
        return identifier;
    }
    // Otherwise assume a serialized config dict.
    return deserializeActivation(identifier);
}
46989
46990 /**
46991 * @license
46992 * Copyright 2018 Google LLC
46993 *
46994 * Use of this source code is governed by an MIT-style
46995 * license that can be found in the LICENSE file or at
46996 * https://opensource.org/licenses/MIT.
46997 * =============================================================================
46998 */
/**
 * Validates the constructor argument of an L1L2 regularizer.
 *
 * @param args The candidate argument; may be null/undefined.
 * @throws Error if `args` is provided but is not of type 'object'.
 */
function assertObjectArgs(args) {
    const provided = args != null;
    if (provided && typeof args !== 'object') {
        throw new Error(`Argument to L1L2 regularizer's constructor is expected to be an ` +
            `object, but received: ${args}`);
    }
}
/**
 * Regularizer base class.
 */
// Marker base class: concrete regularizers implement apply()/getConfig().
class Regularizer extends Serializable {
}
/**
 * Regularizer combining L1 (lasso) and L2 (ridge) penalty terms.
 */
class L1L2 extends Regularizer {
    constructor(args) {
        super();
        assertObjectArgs(args);
        // Both factors default to 0.01 when unspecified.
        this.l1 = (args == null || args.l1 == null) ? 0.01 : args.l1;
        this.l2 = (args == null || args.l2 == null) ? 0.01 : args.l2;
        // Cache which terms are active so apply() can skip zero factors.
        this.hasL1 = this.l1 !== 0;
        this.hasL2 = this.l2 !== 0;
    }
    /**
     * Porting note: Renamed from __call__.
     * @param x Variable of which to calculate the regularization score.
     */
    apply(x) {
        return tidy(() => {
            let score = zeros$2([1]);
            if (this.hasL1) {
                score = add$3(score, sum$3(mul(this.l1, abs$2(x))));
            }
            if (this.hasL2) {
                score = add$3(score, sum$3(mul(this.l2, square$1(x))));
            }
            // Collapse the length-1 vector into a scalar.
            return reshape$3(score, []);
        });
    }
    getConfig() {
        return { 'l1': this.l1, 'l2': this.l2 };
    }
    /** @nocollapse */
    static fromConfig(cls, config) {
        return new cls({ l1: config['l1'], l2: config['l2'] });
    }
}
/** @nocollapse */
L1L2.className = 'L1L2';
registerClass(L1L2);
// Factory for a pure-L1 regularizer (L2 factor fixed at 0).
function l1$1(args) {
    assertObjectArgs(args);
    const l1 = args != null ? args.l1 : null;
    return new L1L2({ l1, l2: 0 });
}
// Factory for a pure-L2 regularizer (L1 factor fixed at 0).
function l2$1(args) {
    assertObjectArgs(args);
    const l2 = args != null ? args.l2 : null;
    return new L1L2({ l2, l1: 0 });
}
// Maps the JavaScript-like identifier keys to the corresponding keras symbols.
// Used by getRegularizer() to translate string identifiers (e.g. 'l1l2')
// into registered class names before deserialization.
const REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
    'l1l2': 'L1L2'
};
/**
 * Serializes a Regularizer instance into its JSON-compatible form.
 * @param regularizer The regularizer to serialize.
 */
function serializeRegularizer(regularizer) {
    // Fix: the parameter was misleadingly named `constraint` (copy-paste
    // from the constraints module); it receives a Regularizer. Positional
    // rename only — no behavior change.
    return serializeKerasObject(regularizer);
}
// Deserializes a regularizer config dict via the global SerializationMap.
function deserializeRegularizer(config, customObjects = {}) {
    return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'regularizer');
}
/**
 * Resolves a regularizer identifier to a Regularizer instance.
 *
 * Accepts null (→ null), a string identifier (e.g. 'l1l2' or a class name),
 * a Regularizer instance (returned as-is), or a serialized config dict.
 */
function getRegularizer(identifier) {
    if (identifier == null) {
        return null;
    }
    if (typeof identifier === 'string') {
        // Translate the camelCase identifier into a registered class name.
        const className = identifier in REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP ?
            REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] :
            identifier;
        return deserializeRegularizer({ className, config: {} });
    }
    if (identifier instanceof Regularizer) {
        return identifier;
    }
    // Otherwise assume a serialized config dict.
    return deserializeRegularizer(identifier);
}
47083
47084 /**
47085 * @license
47086 * Copyright 2018 Google LLC
47087 *
47088 * Use of this source code is governed by an MIT-style
47089 * license that can be found in the LICENSE file or at
47090 * https://opensource.org/licenses/MIT.
47091 * =============================================================================
47092 */
/**
 * ReLU layer: element-wise max(0, x), with an optional upper cap.
 */
class ReLU extends Layer {
    constructor(args) {
        super(args == null ? {} : args);
        this.supportsMasking = true;
        if (args != null) {
            // Optional cap on the output (e.g. 6 gives ReLU6 behavior).
            this.maxValue = args.maxValue;
        }
    }
    call(inputs, kwargs) {
        const x = getExactlyOneTensor(inputs);
        const activated = relu$2(x);
        return this.maxValue == null ?
            activated :
            clipByValue$2(activated, 0, this.maxValue);
    }
    computeOutputShape(inputShape) {
        // Element-wise op: shape is unchanged.
        return inputShape;
    }
    getConfig() {
        const config = { maxValue: this.maxValue };
        Object.assign(config, super.getConfig());
        return config;
    }
}
/** @nocollapse */
ReLU.className = 'ReLU';
registerClass(ReLU);
/**
 * LeakyReLU layer: f(x) = x for x > 0, alpha * x otherwise.
 */
class LeakyReLU extends Layer {
    constructor(args) {
        super(args == null ? {} : args);
        this.DEFAULT_ALPHA = 0.3;
        const opts = args == null ? {} : args;
        // Negative-slope coefficient; defaults to 0.3 (Keras convention).
        this.alpha = opts.alpha == null ? this.DEFAULT_ALPHA : opts.alpha;
    }
    call(inputs, kwargs) {
        const x = getExactlyOneTensor(inputs);
        return leakyRelu$2(x, this.alpha);
    }
    computeOutputShape(inputShape) {
        // Element-wise op: shape is unchanged.
        return inputShape;
    }
    getConfig() {
        const config = { alpha: this.alpha };
        Object.assign(config, super.getConfig());
        return config;
    }
}
/** @nocollapse */
LeakyReLU.className = 'LeakyReLU';
registerClass(LeakyReLU);
/**
 * Parametric ReLU layer: like LeakyReLU, but the negative-slope coefficient
 * `alpha` is a learned weight (optionally shared across specified axes).
 */
class PReLU extends Layer {
    constructor(args) {
        super(args == null ? {} : args);
        this.DEFAULT_ALPHA_INITIALIZER = 'zeros';
        if (args == null) {
            args = {};
        }
        this.supportsMasking = true;
        this.alphaInitializer =
            getInitializer(args.alphaInitializer || this.DEFAULT_ALPHA_INITIALIZER);
        this.alphaRegularizer = getRegularizer(args.alphaRegularizer);
        this.alphaConstraint = getConstraint(args.alphaConstraint);
        // Normalize sharedAxes to null | number[]; a single number becomes
        // a one-element array.
        if (args.sharedAxes == null) {
            this.sharedAxes = null;
        }
        else if (Array.isArray(args.sharedAxes)) {
            this.sharedAxes = args.sharedAxes;
        }
        else if (typeof args.sharedAxes === 'number') {
            this.sharedAxes = [args.sharedAxes];
        }
        else {
            throw new ValueError(`Expected sharedAxes to be a number or an array of numbers, ` +
                `but got ${args.sharedAxes}`);
        }
    }
    build(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        // alpha has one entry per non-batch dimension of the input...
        const paramShape = inputShape.slice(1);
        if (this.sharedAxes != null) {
            // ...except along shared axes, where a single value is broadcast.
            // Axis i is 1-based relative to the non-batch dims, hence i - 1.
            for (const i of this.sharedAxes) {
                paramShape[i - 1] = 1;
            }
        }
        this.alpha = this.addWeight('alpha', paramShape, 'float32', this.alphaInitializer, this.alphaRegularizer, true, this.alphaConstraint);
        // Set input spec.
        const axes = {};
        if (this.sharedAxes != null) {
            // Pin every non-batch input dimension so later calls must match
            // the shape the alpha weight was built for.
            for (let i = 1; i < inputShape.length; ++i) {
                axes[i] = inputShape[i];
            }
        }
        this.inputSpec = [new InputSpec({
                ndim: inputShape.length,
                axes,
            })];
        this.built = true;
    }
    call(inputs, kwargs) {
        inputs = getExactlyOneTensor(inputs);
        // prelu(x, alpha) = x if x > 0 else alpha * x, with alpha broadcast.
        return prelu$3(inputs, this.alpha.read());
    }
    getConfig() {
        const config = {
            alphaInitializer: serializeInitializer(this.alphaInitializer),
            alphaRegularizer: serializeRegularizer(this.alphaRegularizer),
            alphaConstraint: serializeConstraint(this.alphaConstraint),
            sharedAxes: this.sharedAxes
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}
/** @nocollapse */
PReLU.className = 'PReLU';
registerClass(PReLU);
/**
 * ELU layer. Only the default alpha (1.0) is currently implemented.
 */
let ELU$3 = class ELU extends Layer {
    constructor(args) {
        super(args == null ? {} : args);
        this.DEFAULT_ALPHA = 1.0;
        const opts = args == null ? {} : args;
        // Reject non-default alpha up front: the underlying op only
        // implements alpha === 1.0.
        if (opts.alpha != null && opts.alpha !== this.DEFAULT_ALPHA) {
            throw new NotImplementedError(`Non-default alpha value (${opts.alpha}) is not supported by the ` +
                `ELU layer yet.`);
        }
        this.alpha = opts.alpha == null ? this.DEFAULT_ALPHA : opts.alpha;
    }
    call(inputs, kwargs) {
        const x = getExactlyOneTensor(inputs);
        return elu$4(x);
    }
    computeOutputShape(inputShape) {
        // Element-wise op: shape is unchanged.
        return inputShape;
    }
    getConfig() {
        const config = { alpha: this.alpha };
        Object.assign(config, super.getConfig());
        return config;
    }
};
/** @nocollapse */
ELU$3.className = 'ELU';
registerClass(ELU$3);
/**
 * Thresholded ReLU layer: f(x) = x for x > theta, 0 otherwise.
 */
class ThresholdedReLU extends Layer {
    constructor(args) {
        super(args == null ? {} : args);
        this.DEFAULT_THETA = 1.0;
        const opts = args == null ? {} : args;
        // Activation threshold; defaults to 1.0.
        this.theta = opts.theta == null ? this.DEFAULT_THETA : opts.theta;
    }
    call(inputs, kwargs) {
        const x = getExactlyOneTensor(inputs);
        // Zero out entries at or below theta; pass the rest through.
        const gate = cast$3(greater$3(x, this.theta), 'float32');
        return mul(x, gate);
    }
    computeOutputShape(inputShape) {
        // Element-wise op: shape is unchanged.
        return inputShape;
    }
    getConfig() {
        const config = { theta: this.theta };
        Object.assign(config, super.getConfig());
        return config;
    }
}
/** @nocollapse */
ThresholdedReLU.className = 'ThresholdedReLU';
registerClass(ThresholdedReLU);
/**
 * Softmax layer: normalizes the input along the configured axis (or axes),
 * with optional additive masking.
 */
class Softmax extends Layer {
    constructor(args) {
        super(args == null ? {} : args);
        // NOTE(review): Keras defaults the softmax axis to -1 (last axis);
        // here the default is 1.0 — confirm this is intentional before
        // changing, since existing callers may rely on it.
        this.DEFAULT_AXIS = 1.0;
        if (args == null) {
            args = {};
        }
        // Unbound method reference; safe because Softmax$1.apply does not
        // reference `this`.
        this.softmax = new Softmax$1().apply;
        this.axis = args.axis == null ? this.DEFAULT_AXIS : args.axis;
    }
    call(inputs, kwargs) {
        // TODO(pforderique): Add tests for when `this.axis` is a number[].
        return tidy(() => {
            let x = getExactlyOneTensor(inputs);
            const mask = kwargs['mask'];
            if (mask != null) {
                // Since mask is 1.0 for positions we want to keep and 0.0 for masked
                // positions, this operation will create a tensor which is 0.0 for
                // positions we want to attend and -1e9 for masked positions.
                const adder = mul(sub$2(ones$1(x.shape), cast$3(mask, x.dtype)), scalar(-1e9));
                // Since we are adding it to the raw scores before the softmax, this
                // is effectively the same as removing these entirely.
                x = add$3(x, adder);
            }
            if (this.axis instanceof Array) {
                if (this.axis.length > 1) {
                    // Multi-axis softmax: exp(x - logSumExp(x)) normalizes
                    // jointly over all listed axes.
                    return exp$2(sub$2(x, logSumExp(x, this.axis, true)));
                }
                else {
                    // Single-element array: unwrap and use the regular op.
                    return this.softmax(x, this.axis[0]);
                }
            }
            return this.softmax(x, this.axis);
        });
    }
    computeOutputShape(inputShape) {
        // Normalization is element-wise along axes: shape is unchanged.
        return inputShape;
    }
    getConfig() {
        const config = { axis: this.axis };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}
/** @nocollapse */
Softmax.className = 'Softmax';
registerClass(Softmax);
47319
47320 /**
47321 * @license
47322 * Copyright 2018 Google LLC
47323 *
47324 * Use of this source code is governed by an MIT-style
47325 * license that can be found in the LICENSE file or at
47326 * https://opensource.org/licenses/MIT.
47327 * =============================================================================
47328 */
/**
 * Transforms a single number or an array of numbers into an array of numbers.
 * @param value A number (expanded to an n-tuple) or an array of length n.
 * @param n: The size of the tuple to be returned.
 * @param name: Name of the parameter, used for generating error messages.
 * @returns An array of n integers.
 * @throws ValueError if `value` has the wrong length or contains a
 *     non-integer entry.
 */
function normalizeArray(value, n, name) {
    if (typeof value === 'number') {
        // A single number expands to an n-tuple of that number.
        return pyListRepeat(value, n);
    }
    if (value.length !== n) {
        throw new ValueError(`The ${name} argument must be an integer or tuple of ${n} integers.` +
            ` Received: ${value.length} elements.`);
    }
    for (let i = 0; i < n; ++i) {
        const singleValue = value[i];
        if (!isInteger(singleValue)) {
            throw new ValueError(`The ${name} argument must be an integer or tuple of ${n}` +
                ` integers. Received: ${JSON.stringify(value)} including a` +
                ` non-integer number ${singleValue}`);
        }
    }
    return value;
}
/**
 * Determines output length of a convolution given input length.
 * @param inputLength Input size along the dimension (may be null/undefined
 *     for an unknown dimension, which is passed through unchanged).
 * @param filterSize Kernel size along the dimension.
 * @param padding Padding mode: 'same' preserves the length; anything else
 *     is treated as 'valid'.
 * @param stride Stride along the dimension.
 * @param dilation: dilation rate.
 * @returns The output length along the dimension.
 */
function convOutputLength(inputLength, filterSize, padding, stride, dilation = 1) {
    if (inputLength == null) {
        return inputLength;
    }
    // Effective filter size once dilation gaps are inserted.
    const dilatedFilterSize = filterSize + (filterSize - 1) * (dilation - 1);
    const preStrideLength = padding === 'same' ?
        inputLength :
        inputLength - dilatedFilterSize + 1;
    // Ceiling division by the stride.
    return Math.floor((preStrideLength + stride - 1) / stride);
}
/**
 * Determines the output length of a transposed convolution (deconvolution)
 * along one dimension.
 *
 * @param dimSize Input dimension size; null (unknown) is passed through.
 * @param strideSize Stride along this dimension.
 * @param kernelSize Kernel size along this dimension.
 * @param padding Padding mode: 'valid' or 'same'.
 * @returns The output dimension size, or null if `dimSize` is null.
 * @throws ValueError if `padding` is neither 'valid' nor 'same'.
 */
function deconvLength(dimSize, strideSize, kernelSize, padding) {
    if (dimSize == null) {
        return null;
    }
    if (padding === 'valid') {
        // Math.max replaces the array-based max helper; identical result
        // for two plain numbers.
        dimSize = dimSize * strideSize + Math.max(kernelSize - strideSize, 0);
    }
    else if (padding === 'same') {
        dimSize = dimSize * strideSize;
    }
    else {
        // Fix: typo in the user-facing error message
        // ('Unsupport' -> 'Unsupported').
        throw new ValueError(`Unsupported padding mode: ${padding}.`);
    }
    return dimSize;
}
47393
47394 /**
47395 * @license
47396 * Copyright 2018 Google LLC
47397 *
47398 * Use of this source code is governed by an MIT-style
47399 * license that can be found in the LICENSE file or at
47400 * https://opensource.org/licenses/MIT.
47401 * =============================================================================
47402 */
/**
 * Transpose and cast the input before the conv2d.
 *
 * The underlying tfjs-core conv2d expects NHWC, so channels-first (NCHW)
 * inputs are transposed; channels-last inputs are returned as-is.
 *
 * @param x Input image tensor.
 * @param dataFormat 'channelsFirst' or 'channelsLast'.
 */
function preprocessConv2DInput(x, dataFormat) {
    // TODO(cais): Cast type to float32 if not.
    return tidy(() => {
        checkDataFormat(dataFormat);
        return dataFormat === 'channelsFirst' ?
            transpose$2(x, [0, 2, 3, 1]) : // NCHW -> NHWC.
            x;
    });
}
/**
 * Transpose and cast the input before the conv3d.
 *
 * The underlying tfjs-core conv3d expects NDHWC, so channels-first (NCDHW)
 * inputs are transposed; channels-last inputs are returned as-is.
 *
 * @param x Input image tensor.
 * @param dataFormat 'channelsFirst' or 'channelsLast'.
 */
function preprocessConv3DInput(x, dataFormat) {
    return tidy(() => {
        checkDataFormat(dataFormat);
        return dataFormat === 'channelsFirst' ?
            transpose$2(x, [0, 2, 3, 4, 1]) : // NCDHW -> NDHWC.
            x;
    });
}
/**
 * 1D-convolution with bias added.
 *
 * Porting Note: This function does not exist in the Python Keras backend.
 * It is exactly the same as `conv1d`, except for the added `bias`.
 *
 * @param x Input tensor, rank-3, of shape `[batchSize, width, inChannels]`.
 * @param kernel Kernel, rank-3, of shape `[filterWidth, inDepth, outDepth]`.
 * @param bias Optional bias, rank-1, of shape `[outDepth]`; may be null.
 * @param strides Convolution stride; defaults to 1.
 * @param padding Padding mode; 'causal' is not implemented yet.
 * @param dataFormat Data format; defaults to the global image data format.
 * @param dilationRate Dilation rate; defaults to 1.
 * @returns The result of the 1D convolution.
 * @throws ValueError, if `x`, `kernel` or `bias` is not of the correct rank.
 */
function conv1dWithBias(x, kernel, bias, strides = 1, padding = 'valid', dataFormat, dilationRate = 1) {
    return tidy(() => {
        dataFormat = dataFormat == null ? imageDataFormat() : dataFormat;
        checkDataFormat(dataFormat);
        // Validate the ranks of the input, kernel and (optional) bias.
        if (x.shape.length !== 3) {
            throw new ValueError(`The input of a conv1dWithBias operation should be 3, but is ` +
                `${x.shape.length} instead.`);
        }
        if (kernel.shape.length !== 3) {
            throw new ValueError(`The kernel for a conv1dWithBias operation should be 3, but is ` +
                `${kernel.shape.length} instead`);
        }
        if (bias != null && bias.shape.length !== 1) {
            throw new ValueError(`The bias for a conv1dWithBias operation should be 1, but is ` +
                `${bias.shape.length} instead`);
        }
        // TODO(cais): Support CAUSAL padding mode.
        if (dataFormat === 'channelsFirst') {
            x = transpose$2(x, [0, 2, 1]); // NCW -> NWC.
        }
        if (padding === 'causal') {
            throw new NotImplementedError('The support for CAUSAL padding mode in conv1dWithBias is not ' +
                'implemented yet.');
        }
        const convOut = conv1d$2(x, kernel, strides, padding === 'same' ? 'same' : 'valid', 'NWC', dilationRate);
        return bias == null ? convOut : biasAdd(convOut, bias);
    });
}
/**
 * 1D-convolution.
 *
 * @param x Input tensor, rank-3, of shape `[batchSize, width, inChannels]`.
 * @param kernel Kernel, rank-3, of shape `[filterWidth, inDepth, outDepth]`.
 * @param strides Convolution stride; defaults to 1.
 * @param padding Padding mode.
 * @param dataFormat Data format.
 * @param dilationRate Dilation rate; defaults to 1.
 * @returns The result of the 1D convolution.
 * @throws ValueError, if `x` or `kernel` is not of the correct rank.
 */
function conv1d$1(x, kernel, strides = 1, padding = 'valid', dataFormat, dilationRate = 1) {
    return tidy(() => {
        checkDataFormat(dataFormat);
        // Delegate to the bias-aware variant with a null bias.
        const result = conv1dWithBias(x, kernel, null, strides, padding, dataFormat, dilationRate);
        return result;
    });
}
/**
 * 2D Convolution.
 *
 * @param x Input tensor.
 * @param kernel kernel of the convolution.
 * @param strides strides array; defaults to `[1, 1]`.
 * @param padding padding mode. Default to 'valid'.
 * @param dataFormat data format. Defaults to 'channelsLast'.
 * @param dilationRate dilation rate array.
 * @returns Result of the 2D convolution.
 */
function conv2d$2(x, kernel, strides = [1, 1], padding = 'valid', dataFormat, dilationRate) {
    return tidy(() => {
        checkDataFormat(dataFormat);
        // Delegate to the fused implementation with no bias and no activation.
        const result = conv2dWithBiasActivation(x, kernel, null, strides, padding, dataFormat, dilationRate);
        return result;
    });
}
/**
 * 2D Convolution with an added bias and optional activation.
 * Note: This function does not exist in the Python Keras Backend. This function
 * is exactly the same as `conv2d`, except the added `bias`.
 *
 * @param x Input tensor, of rank 3 or 4.
 * @param kernel Convolution kernel, of rank 3 or 4.
 * @param bias Optional bias; may be null.
 * @param strides Strides array; defaults to `[1, 1]`.
 * @param padding Padding mode; 'causal' is not implemented yet.
 * @param dataFormat Data format; defaults to the global image data format.
 * @param dilationRate Dilation rate array.
 * @param activation Optional name of a fused activation kernel.
 * @returns Result of the convolution, in the requested data format.
 */
function conv2dWithBiasActivation(x, kernel, bias, strides = [1, 1], padding = 'valid', dataFormat, dilationRate, activation = null) {
    return tidy(() => {
        if (dataFormat == null) {
            dataFormat = imageDataFormat();
        }
        checkDataFormat(dataFormat);
        if (x.rank !== 3 && x.rank !== 4) {
            throw new ValueError(`conv2dWithBiasActivation expects input to be of rank 3 or 4, ` +
                `but received ${x.rank}.`);
        }
        // Bug fix: report the kernel's rank here, not the input's rank.
        if (kernel.rank !== 3 && kernel.rank !== 4) {
            throw new ValueError(`conv2dWithBiasActivation expects kernel to be of rank 3 or 4, ` +
                `but received ${kernel.rank}.`);
        }
        let y = preprocessConv2DInput(x, dataFormat);
        if (padding === 'causal') {
            // Bug fix: the message previously named conv1dWithBias (copy-paste).
            throw new NotImplementedError('The support for CAUSAL padding mode in conv2dWithBiasActivation ' +
                'is not implemented yet.');
        }
        // The fused core op always operates in NHWC; the input was transposed
        // above if necessary.
        y = conv2d$3({
            x: y,
            filter: kernel,
            strides: strides,
            pad: padding === 'same' ? 'same' : 'valid',
            dilations: dilationRate,
            dataFormat: 'NHWC',
            bias,
            activation
        });
        if (dataFormat === 'channelsFirst') {
            y = transpose$2(y, [0, 3, 1, 2]); // NHWC -> NCHW.
        }
        return y;
    });
}
/**
 * 3D Convolution.
 *
 * @param x Input tensor.
 * @param kernel kernel of the convolution.
 * @param strides strides array; defaults to `[1, 1, 1]`.
 * @param padding padding mode. Default to 'valid'.
 * @param dataFormat data format. Defaults to 'channelsLast'.
 * @param dilationRate dilation rate array.
 * @returns Result of the 3D convolution.
 */
function conv3d$1(x, kernel, strides = [1, 1, 1], padding = 'valid', dataFormat, dilationRate) {
    return tidy(() => {
        checkDataFormat(dataFormat);
        // Delegate to the bias-aware variant with a null bias.
        const result = conv3dWithBias(x, kernel, null, strides, padding, dataFormat, dilationRate);
        return result;
    });
}
/**
 * 3D Convolution with an added bias.
 * Note: This function does not exist in the Python Keras Backend. This function
 * is exactly the same as `conv3d`, except the added `bias`.
 *
 * @param x Input tensor, of rank 4 or 5.
 * @param kernel Convolution kernel, of rank 4 or 5.
 * @param bias Optional bias; may be null.
 * @param strides Strides array; defaults to `[1, 1, 1]`.
 * @param padding Padding mode; 'causal' is not implemented yet.
 * @param dataFormat Data format; defaults to the global image data format.
 * @param dilationRate Dilation rate array.
 * @returns Result of the convolution, in the requested data format.
 */
function conv3dWithBias(x, kernel, bias, strides = [1, 1, 1], padding = 'valid', dataFormat, dilationRate) {
    return tidy(() => {
        if (dataFormat == null) {
            dataFormat = imageDataFormat();
        }
        checkDataFormat(dataFormat);
        if (x.rank !== 4 && x.rank !== 5) {
            throw new ValueError(`conv3dWithBias expects input to be of rank 4 or 5, but received ` +
                `${x.rank}.`);
        }
        // Bug fix: report the kernel's rank here, not the input's rank.
        if (kernel.rank !== 4 && kernel.rank !== 5) {
            throw new ValueError(`conv3dWithBias expects kernel to be of rank 4 or 5, but received ` +
                `${kernel.rank}.`);
        }
        let y = preprocessConv3DInput(x, dataFormat);
        if (padding === 'causal') {
            throw new NotImplementedError('The support for CAUSAL padding mode in conv3dWithBias is not ' +
                'implemented yet.');
        }
        // The core op operates in NDHWC; the input was transposed above if
        // necessary, and the output is transposed back below.
        y = conv3d$2(y, kernel, strides, padding === 'same' ? 'same' : 'valid', 'NDHWC', dilationRate);
        if (bias != null) {
            y = biasAdd(y, bias);
        }
        if (dataFormat === 'channelsFirst') {
            y = transpose$2(y, [0, 4, 1, 2, 3]); // NDHWC -> NCDHW.
        }
        return y;
    });
}
/**
 * Abstract convolution layer.
 *
 * Holds and validates the configuration shared by all concrete convolution
 * layers of rank 1, 2 or 3: kernel size, strides, padding, data format,
 * dilation rate, activation, and bias settings.
 */
class BaseConv extends Layer {
    constructor(rank, args) {
        super(args);
        // Bias weight variable; created in subclasses' build() when useBias.
        this.bias = null;
        this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
        this.DEFAULT_BIAS_INITIALIZER = 'zeros';
        BaseConv.verifyArgs(args);
        this.rank = rank;
        assertPositiveInteger(this.rank, 'rank');
        // Only 1D, 2D and 3D convolutions are supported.
        if (this.rank !== 1 && this.rank !== 2 && this.rank !== 3) {
            throw new NotImplementedError(`Convolution layer for rank other than 1, 2, or 3 (${this.rank}) is ` +
                `not implemented yet.`);
        }
        // Normalize scalar-or-array arguments to arrays of length `rank`.
        this.kernelSize = normalizeArray(args.kernelSize, rank, 'kernelSize');
        this.strides = normalizeArray(args.strides == null ? 1 : args.strides, rank, 'strides');
        this.padding = args.padding == null ? 'valid' : args.padding;
        checkPaddingMode(this.padding);
        this.dataFormat =
            args.dataFormat == null ? 'channelsLast' : args.dataFormat;
        checkDataFormat(this.dataFormat);
        this.activation = getActivation(args.activation);
        this.useBias = args.useBias == null ? true : args.useBias;
        this.biasInitializer =
            getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER);
        this.biasConstraint = getConstraint(args.biasConstraint);
        this.biasRegularizer = getRegularizer(args.biasRegularizer);
        this.activityRegularizer = getRegularizer(args.activityRegularizer);
        this.dilationRate = normalizeArray(args.dilationRate == null ? 1 : args.dilationRate, rank, 'dilationRate');
        // Validate that the dilation rate's arity matches the rank; scalars
        // are broadcast to arrays for rank 2 and 3.
        if (this.rank === 1 &&
            (Array.isArray(this.dilationRate) && this.dilationRate.length !== 1)) {
            throw new ValueError(`dilationRate must be a number or an array of a single number ` +
                `for 1D convolution, but received ` +
                `${JSON.stringify(this.dilationRate)}`);
        }
        else if (this.rank === 2) {
            if (typeof this.dilationRate === 'number') {
                this.dilationRate = [this.dilationRate, this.dilationRate];
            }
            else if (this.dilationRate.length !== 2) {
                throw new ValueError(`dilationRate must be a number or array of two numbers for 2D ` +
                    `convolution, but received ${JSON.stringify(this.dilationRate)}`);
            }
        }
        else if (this.rank === 3) {
            if (typeof this.dilationRate === 'number') {
                this.dilationRate =
                    [this.dilationRate, this.dilationRate, this.dilationRate];
            }
            else if (this.dilationRate.length !== 3) {
                throw new ValueError(`dilationRate must be a number or array of three numbers for 3D ` +
                    `convolution, but received ${JSON.stringify(this.dilationRate)}`);
            }
        }
    }
    // Validates the raw constructor args: kernelSize must be present and be a
    // number or a number[] of length 1-3.
    static verifyArgs(args) {
        // Check config.kernelSize type and shape.
        assert('kernelSize' in args, `required key 'kernelSize' not in config`);
        if (typeof args.kernelSize !== 'number' &&
            !checkArrayTypeAndLength(args.kernelSize, 'number', 1, 3)) {
            throw new ValueError(`BaseConv expects config.kernelSize to be number or number[] with ` +
                `length 1, 2, or 3, but received ${JSON.stringify(args.kernelSize)}.`);
        }
    }
    // Serializes this layer's configuration, merged over the base Layer's.
    getConfig() {
        const config = {
            kernelSize: this.kernelSize,
            strides: this.strides,
            padding: this.padding,
            dataFormat: this.dataFormat,
            dilationRate: this.dilationRate,
            activation: serializeActivation(this.activation),
            useBias: this.useBias,
            biasInitializer: serializeInitializer(this.biasInitializer),
            biasRegularizer: serializeRegularizer(this.biasRegularizer),
            activityRegularizer: serializeRegularizer(this.activityRegularizer),
            biasConstraint: serializeConstraint(this.biasConstraint)
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}
/**
 * Abstract nD convolution layer. Ancestor of convolution layers which reduce
 * across channels, i.e., Conv1D and Conv2D, but not DepthwiseConv2D.
 */
class Conv extends BaseConv {
    constructor(rank, args) {
        super(rank, args);
        // Main convolution kernel weight; created in build().
        this.kernel = null;
        Conv.verifyArgs(args);
        this.filters = args.filters;
        assertPositiveInteger(this.filters, 'filters');
        this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER);
        this.kernelConstraint = getConstraint(args.kernelConstraint);
        this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
    }
    /**
     * Creates the kernel (and optional bias) weights once the input shape —
     * in particular the channel dimension — is known.
     */
    build(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        const channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1;
        if (inputShape[channelAxis] == null) {
            throw new ValueError(`The channel dimension of the input should be defined. ` +
                `Found ${inputShape[channelAxis]}`);
        }
        const inputDim = inputShape[channelAxis];
        // Kernel shape: spatial dims followed by [inChannels, outChannels].
        const kernelShape = this.kernelSize.concat([inputDim, this.filters]);
        this.kernel = this.addWeight('kernel', kernelShape, null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
        if (this.useBias) {
            this.bias = this.addWeight('bias', [this.filters], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
        }
        this.inputSpec = [{ ndim: this.rank + 2, axes: { [channelAxis]: inputDim } }];
        this.built = true;
    }
    call(inputs, kwargs) {
        return tidy(() => {
            inputs = getExactlyOneTensor(inputs);
            let outputs;
            const biasValue = this.bias == null ? null : this.bias.read();
            // If the activation maps to a fused conv2d kernel, use the fused
            // path (bias + activation applied inside the core op).
            const fusedActivationName = mapActivationToFusedKernel(this.activation.getClassName());
            if (fusedActivationName != null && this.rank === 2) {
                outputs = conv2dWithBiasActivation(inputs, this.kernel.read(), biasValue, this.strides, this.padding, this.dataFormat, this.dilationRate, fusedActivationName);
            }
            else {
                // Non-fused path: dispatch on rank, then apply the activation
                // separately below.
                if (this.rank === 1) {
                    outputs = conv1dWithBias(inputs, this.kernel.read(), biasValue, this.strides[0], this.padding, this.dataFormat, this.dilationRate[0]);
                }
                else if (this.rank === 2) {
                    // TODO(cais): Move up to constructor.
                    outputs = conv2dWithBiasActivation(inputs, this.kernel.read(), biasValue, this.strides, this.padding, this.dataFormat, this.dilationRate);
                }
                else if (this.rank === 3) {
                    outputs = conv3dWithBias(inputs, this.kernel.read(), biasValue, this.strides, this.padding, this.dataFormat, this.dilationRate);
                }
                else {
                    throw new NotImplementedError('convolutions greater than 3D are not implemented yet.');
                }
                if (this.activation != null) {
                    outputs = this.activation.apply(outputs);
                }
            }
            return outputs;
        });
    }
    /**
     * Computes the output shape by applying convOutputLength to each spatial
     * dimension and appending/prepending the filter count per data format.
     */
    computeOutputShape(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        const newSpace = [];
        // Spatial dims exclude batch and channel axes.
        const space = (this.dataFormat === 'channelsLast') ?
            inputShape.slice(1, inputShape.length - 1) :
            inputShape.slice(2);
        for (let i = 0; i < space.length; ++i) {
            const newDim = convOutputLength(space[i], this.kernelSize[i], this.padding, this.strides[i], typeof this.dilationRate === 'number' ? this.dilationRate :
                this.dilationRate[i]);
            newSpace.push(newDim);
        }
        let outputShape = [inputShape[0]];
        if (this.dataFormat === 'channelsLast') {
            outputShape = outputShape.concat(newSpace);
            outputShape.push(this.filters);
        }
        else {
            outputShape.push(this.filters);
            outputShape = outputShape.concat(newSpace);
        }
        return outputShape;
    }
    getConfig() {
        const config = {
            filters: this.filters,
            kernelInitializer: serializeInitializer(this.kernelInitializer),
            kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
            kernelConstraint: serializeConstraint(this.kernelConstraint)
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
    // Validates that `filters` is a positive number.
    static verifyArgs(args) {
        // Check config.filters type, shape, and value.
        if (!('filters' in args) || typeof args.filters !== 'number' ||
            args.filters < 1) {
            throw new ValueError(`Convolution layer expected config.filters to be a 'number' > 0 ` +
                `but got ${JSON.stringify(args.filters)}`);
        }
    }
}
/**
 * 2D convolution layer (rank fixed at 2).
 */
class Conv2D extends Conv {
    constructor(args) {
        super(2, args);
        Conv2D.verifyArgs(args);
    }
    getConfig() {
        // `rank` is implied by the class and omitted from serialization.
        const config = super.getConfig();
        delete config['rank'];
        return config;
    }
    static verifyArgs(args) {
        // kernelSize must be a number or a number[] of length 1 or 2.
        const kernelSizeValid = typeof args.kernelSize === 'number' ||
            checkArrayTypeAndLength(args.kernelSize, 'number', 1, 2);
        if (!kernelSizeValid) {
            throw new ValueError(`Conv2D expects config.kernelSize to be number or number[] with ` +
                `length 1 or 2, but received ${JSON.stringify(args.kernelSize)}.`);
        }
    }
}
/** @nocollapse */
Conv2D.className = 'Conv2D';
registerClass(Conv2D);
/**
 * 3D convolution layer (rank fixed at 3).
 */
class Conv3D extends Conv {
    constructor(args) {
        super(3, args);
        Conv3D.verifyArgs(args);
    }
    getConfig() {
        // `rank` is implied by the class and omitted from serialization.
        const config = super.getConfig();
        delete config['rank'];
        return config;
    }
    static verifyArgs(args) {
        // kernelSize must be a number, or an array of 1 or 3 elements.
        if (typeof args.kernelSize === 'number') {
            return;
        }
        const isValidArray = Array.isArray(args.kernelSize) &&
            (args.kernelSize.length === 1 || args.kernelSize.length === 3);
        if (!isValidArray) {
            throw new ValueError(`Conv3D expects config.kernelSize to be number or` +
                ` [number, number, number], but received ${JSON.stringify(args.kernelSize)}.`);
        }
    }
}
/** @nocollapse */
Conv3D.className = 'Conv3D';
registerClass(Conv3D);
/**
 * Transposed 2D convolution ("deconvolution") layer.
 *
 * Supports only 'same' and 'valid' padding; the kernel is stored with shape
 * `[kernelH, kernelW, outChannels, inChannels]` (filters before input dim,
 * unlike forward Conv2D).
 */
class Conv2DTranspose extends Conv2D {
    constructor(args) {
        super(args);
        this.inputSpec = [new InputSpec({ ndim: 4 })];
        if (this.padding !== 'same' && this.padding !== 'valid') {
            throw new ValueError(`Conv2DTranspose currently supports only padding modes 'same' ` +
                `and 'valid', but received padding mode ${this.padding}`);
        }
    }
    // Creates the kernel and optional bias; requires a known channel dim.
    build(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        if (inputShape.length !== 4) {
            throw new ValueError('Input should have rank 4; Received input shape: ' +
                JSON.stringify(inputShape));
        }
        const channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1;
        if (inputShape[channelAxis] == null) {
            throw new ValueError('The channel dimension of the inputs should be defined. ' +
                'Found `None`.');
        }
        const inputDim = inputShape[channelAxis];
        // Note the transposed kernel layout: [..., filters, inputDim].
        const kernelShape = this.kernelSize.concat([this.filters, inputDim]);
        this.kernel = this.addWeight('kernel', kernelShape, 'float32', this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
        if (this.useBias) {
            this.bias = this.addWeight('bias', [this.filters], 'float32', this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
        }
        // Set input spec.
        this.inputSpec =
            [new InputSpec({ ndim: 4, axes: { [channelAxis]: inputDim } })];
        this.built = true;
    }
    call(inputs, kwargs) {
        return tidy(() => {
            let input = getExactlyOneTensor(inputs);
            if (input.shape.length !== 4) {
                throw new ValueError(`Conv2DTranspose.call() expects input tensor to be rank-4, but ` +
                    `received a tensor of rank-${input.shape.length}`);
            }
            const inputShape = input.shape;
            const batchSize = inputShape[0];
            // Locate the spatial axes for the current data format.
            let hAxis;
            let wAxis;
            if (this.dataFormat === 'channelsFirst') {
                hAxis = 2;
                wAxis = 3;
            }
            else {
                hAxis = 1;
                wAxis = 2;
            }
            const height = inputShape[hAxis];
            const width = inputShape[wAxis];
            const kernelH = this.kernelSize[0];
            const kernelW = this.kernelSize[1];
            const strideH = this.strides[0];
            const strideW = this.strides[1];
            // Infer the dynamic output shape.
            const outHeight = deconvLength(height, strideH, kernelH, this.padding);
            const outWidth = deconvLength(width, strideW, kernelW, this.padding);
            // Porting Note: We don't branch based on `this.dataFormat` here,
            // because
            // the tjfs-core function `conv2dTranspose` called below always
            // assumes channelsLast.
            const outputShape = [batchSize, outHeight, outWidth, this.filters];
            if (this.dataFormat !== 'channelsLast') {
                input = transpose$2(input, [0, 2, 3, 1]); // NCHW -> NHWC.
            }
            let outputs = conv2dTranspose$1(input, this.kernel.read(), outputShape, this.strides, this.padding);
            if (this.dataFormat !== 'channelsLast') {
                outputs = transpose$2(outputs, [0, 3, 1, 2]); // NHWC -> NCHW.
            }
            if (this.bias != null) {
                outputs =
                    biasAdd(outputs, this.bias.read(), this.dataFormat);
            }
            if (this.activation != null) {
                outputs = this.activation.apply(outputs);
            }
            return outputs;
        });
    }
    // Output shape uses deconvLength (the inverse of convOutputLength) on
    // each spatial dimension.
    computeOutputShape(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        const outputShape = inputShape.slice();
        let channelAxis;
        let heightAxis;
        let widthAxis;
        if (this.dataFormat === 'channelsFirst') {
            channelAxis = 1;
            heightAxis = 2;
            widthAxis = 3;
        }
        else {
            channelAxis = 3;
            heightAxis = 1;
            widthAxis = 2;
        }
        const kernelH = this.kernelSize[0];
        const kernelW = this.kernelSize[1];
        const strideH = this.strides[0];
        const strideW = this.strides[1];
        outputShape[channelAxis] = this.filters;
        outputShape[heightAxis] =
            deconvLength(outputShape[heightAxis], strideH, kernelH, this.padding);
        outputShape[widthAxis] =
            deconvLength(outputShape[widthAxis], strideW, kernelW, this.padding);
        return outputShape;
    }
    getConfig() {
        // Dilation is not supported by the transpose op; drop it from config.
        const config = super.getConfig();
        delete config['dilationRate'];
        return config;
    }
}
/** @nocollapse */
Conv2DTranspose.className = 'Conv2DTranspose';
registerClass(Conv2DTranspose);
/**
 * Transposed 3D convolution ("deconvolution") layer.
 *
 * Supports only 'same' and 'valid' padding; the kernel is stored with shape
 * `[kernelD, kernelH, kernelW, outChannels, inChannels]` (filters before the
 * input dim, unlike forward Conv3D).
 */
class Conv3DTranspose extends Conv3D {
    constructor(args) {
        super(args);
        this.inputSpec = [new InputSpec({ ndim: 5 })];
        if (this.padding !== 'same' && this.padding !== 'valid') {
            throw new ValueError(`Conv3DTranspose currently supports only padding modes 'same' ` +
                `and 'valid', but received padding mode ${this.padding}`);
        }
    }
    // Creates the kernel and optional bias; requires a known channel dim.
    build(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        if (inputShape.length !== 5) {
            throw new ValueError('Input should have rank 5; Received input shape: ' +
                JSON.stringify(inputShape));
        }
        const channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1;
        if (inputShape[channelAxis] == null) {
            throw new ValueError('The channel dimension of the inputs should be defined. ' +
                'Found `None`.');
        }
        const inputDim = inputShape[channelAxis];
        // Note the transposed kernel layout: [..., filters, inputDim].
        const kernelShape = this.kernelSize.concat([this.filters, inputDim]);
        this.kernel = this.addWeight('kernel', kernelShape, 'float32', this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
        if (this.useBias) {
            this.bias = this.addWeight('bias', [this.filters], 'float32', this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
        }
        // Set input spec.
        this.inputSpec =
            [new InputSpec({ ndim: 5, axes: { [channelAxis]: inputDim } })];
        this.built = true;
    }
    call(inputs, kwargs) {
        return tidy(() => {
            let input = getExactlyOneTensor(inputs);
            if (input.shape.length !== 5) {
                // Bug fix: the message previously said "rank-4" for this
                // rank-5 check.
                throw new ValueError(`Conv3DTranspose.call() expects input tensor to be rank-5, but ` +
                    `received a tensor of rank-${input.shape.length}`);
            }
            const inputShape = input.shape;
            const batchSize = inputShape[0];
            // Locate the spatial axes for the current data format.
            let hAxis;
            let wAxis;
            let dAxis;
            if (this.dataFormat === 'channelsFirst') {
                dAxis = 2;
                hAxis = 3;
                wAxis = 4;
            }
            else {
                dAxis = 1;
                hAxis = 2;
                wAxis = 3;
            }
            const depth = inputShape[dAxis];
            const height = inputShape[hAxis];
            const width = inputShape[wAxis];
            const kernelD = this.kernelSize[0];
            const kernelH = this.kernelSize[1];
            const kernelW = this.kernelSize[2];
            const strideD = this.strides[0];
            const strideH = this.strides[1];
            const strideW = this.strides[2];
            // Infer the dynamic output shape.
            const outDepth = deconvLength(depth, strideD, kernelD, this.padding);
            const outHeight = deconvLength(height, strideH, kernelH, this.padding);
            const outWidth = deconvLength(width, strideW, kernelW, this.padding);
            // Same as `conv2dTranspose`. We always assumes channelsLast.
            const outputShape = [batchSize, outDepth, outHeight, outWidth, this.filters];
            if (this.dataFormat !== 'channelsLast') {
                input = transpose$2(input, [0, 2, 3, 4, 1]); // NCDHW -> NDHWC.
            }
            let outputs = conv3dTranspose$1(input, this.kernel.read(), outputShape, this.strides, this.padding);
            if (this.dataFormat !== 'channelsLast') {
                outputs = transpose$2(outputs, [0, 4, 1, 2, 3]); // NDHWC -> NCDHW.
            }
            if (this.bias !== null) {
                outputs =
                    biasAdd(outputs, this.bias.read(), this.dataFormat);
            }
            if (this.activation !== null) {
                outputs = this.activation.apply(outputs);
            }
            return outputs;
        });
    }
    // Output shape uses deconvLength (the inverse of convOutputLength) on
    // each spatial dimension.
    computeOutputShape(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        const outputShape = inputShape.slice();
        let channelAxis;
        let depthAxis;
        let heightAxis;
        let widthAxis;
        if (this.dataFormat === 'channelsFirst') {
            channelAxis = 1;
            depthAxis = 2;
            heightAxis = 3;
            widthAxis = 4;
        }
        else {
            channelAxis = 4;
            depthAxis = 1;
            heightAxis = 2;
            widthAxis = 3;
        }
        const kernelD = this.kernelSize[0];
        const kernelH = this.kernelSize[1];
        const kernelW = this.kernelSize[2];
        const strideD = this.strides[0];
        const strideH = this.strides[1];
        const strideW = this.strides[2];
        outputShape[channelAxis] = this.filters;
        outputShape[depthAxis] =
            deconvLength(outputShape[depthAxis], strideD, kernelD, this.padding);
        outputShape[heightAxis] =
            deconvLength(outputShape[heightAxis], strideH, kernelH, this.padding);
        outputShape[widthAxis] =
            deconvLength(outputShape[widthAxis], strideW, kernelW, this.padding);
        return outputShape;
    }
    getConfig() {
        // Dilation is not supported by the transpose op; drop it from config.
        const config = super.getConfig();
        delete config['dilationRate'];
        return config;
    }
}
/** @nocollapse */
Conv3DTranspose.className = 'Conv3DTranspose';
registerClass(Conv3DTranspose);
48089 class SeparableConv extends Conv {
48090 constructor(rank, config) {
48091 super(rank, config);
48092 this.DEFAULT_DEPTHWISE_INITIALIZER = 'glorotUniform';
48093 this.DEFAULT_POINTWISE_INITIALIZER = 'glorotUniform';
48094 this.depthwiseKernel = null;
48095 this.pointwiseKernel = null;
48096 if (config.filters == null) {
48097 throw new ValueError('The `filters` configuration field is required by SeparableConv, ' +
48098 'but is unspecified.');
48099 }
48100 if (config.kernelInitializer != null || config.kernelRegularizer != null ||
48101 config.kernelConstraint != null) {
48102 throw new ValueError('Fields kernelInitializer, kernelRegularizer and kernelConstraint ' +
48103 'are invalid for SeparableConv2D. Use depthwiseInitializer, ' +
48104 'depthwiseRegularizer, depthwiseConstraint, pointwiseInitializer, ' +
48105 'pointwiseRegularizer and pointwiseConstraint instead.');
48106 }
48107 if (config.padding != null && config.padding !== 'same' &&
48108 config.padding !== 'valid') {
48109 throw new ValueError(`SeparableConv${this.rank}D supports only padding modes: ` +
48110 `'same' and 'valid', but received ${JSON.stringify(config.padding)}`);
48111 }
48112 this.depthMultiplier =
48113 config.depthMultiplier == null ? 1 : config.depthMultiplier;
48114 this.depthwiseInitializer = getInitializer(config.depthwiseInitializer || this.DEFAULT_DEPTHWISE_INITIALIZER);
48115 this.depthwiseRegularizer = getRegularizer(config.depthwiseRegularizer);
48116 this.depthwiseConstraint = getConstraint(config.depthwiseConstraint);
48117 this.pointwiseInitializer = getInitializer(config.depthwiseInitializer || this.DEFAULT_POINTWISE_INITIALIZER);
48118 this.pointwiseRegularizer = getRegularizer(config.pointwiseRegularizer);
48119 this.pointwiseConstraint = getConstraint(config.pointwiseConstraint);
48120 }
48121 build(inputShape) {
48122 inputShape = getExactlyOneShape(inputShape);
48123 if (inputShape.length < this.rank + 2) {
48124 throw new ValueError(`Inputs to SeparableConv${this.rank}D should have rank ` +
48125 `${this.rank + 2}, but received input shape: ` +
48126 `${JSON.stringify(inputShape)}`);
48127 }
48128 const channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1;
48129 if (inputShape[channelAxis] == null || inputShape[channelAxis] < 0) {
48130 throw new ValueError(`The channel dimension of the inputs should be defined, ` +
48131 `but found ${JSON.stringify(inputShape[channelAxis])}`);
48132 }
48133 const inputDim = inputShape[channelAxis];
48134 const depthwiseKernelShape = this.kernelSize.concat([inputDim, this.depthMultiplier]);
48135 const pointwiseKernelShape = [];
48136 for (let i = 0; i < this.rank; ++i) {
48137 pointwiseKernelShape.push(1);
48138 }
48139 pointwiseKernelShape.push(inputDim * this.depthMultiplier, this.filters);
48140 const trainable = true;
48141 this.depthwiseKernel = this.addWeight('depthwise_kernel', depthwiseKernelShape, 'float32', this.depthwiseInitializer, this.depthwiseRegularizer, trainable, this.depthwiseConstraint);
48142 this.pointwiseKernel = this.addWeight('pointwise_kernel', pointwiseKernelShape, 'float32', this.pointwiseInitializer, this.pointwiseRegularizer, trainable, this.pointwiseConstraint);
48143 if (this.useBias) {
48144 this.bias = this.addWeight('bias', [this.filters], 'float32', this.biasInitializer, this.biasRegularizer, trainable, this.biasConstraint);
48145 }
48146 else {
48147 this.bias = null;
48148 }
48149 this.inputSpec =
48150 [new InputSpec({ ndim: this.rank + 2, axes: { [channelAxis]: inputDim } })];
48151 this.built = true;
48152 }
48153 call(inputs, kwargs) {
48154 return tidy(() => {
48155 inputs = getExactlyOneTensor(inputs);
48156 let output;
48157 if (this.rank === 1) {
48158 throw new NotImplementedError('1D separable convolution is not implemented yet.');
48159 }
48160 else if (this.rank === 2) {
48161 if (this.dataFormat === 'channelsFirst') {
48162 inputs = transpose$2(inputs, [0, 2, 3, 1]); // NCHW -> NHWC.
48163 }
48164 output = separableConv2d$1(inputs, this.depthwiseKernel.read(), this.pointwiseKernel.read(), this.strides, this.padding, this.dilationRate, 'NHWC');
48165 }
48166 if (this.useBias) {
48167 output = biasAdd(output, this.bias.read(), this.dataFormat);
48168 }
48169 if (this.activation != null) {
48170 output = this.activation.apply(output);
48171 }
48172 if (this.dataFormat === 'channelsFirst') {
48173 output = transpose$2(output, [0, 3, 1, 2]); // NHWC -> NCHW.
48174 }
48175 return output;
48176 });
48177 }
48178 getConfig() {
48179 const config = super.getConfig();
48180 delete config['rank'];
48181 delete config['kernelInitializer'];
48182 delete config['kernelRegularizer'];
48183 delete config['kernelConstraint'];
48184 config['depthwiseInitializer'] =
48185 serializeInitializer(this.depthwiseInitializer);
48186 config['pointwiseInitializer'] =
48187 serializeInitializer(this.pointwiseInitializer);
48188 config['depthwiseRegularizer'] =
48189 serializeRegularizer(this.depthwiseRegularizer);
48190 config['pointwiseRegularizer'] =
48191 serializeRegularizer(this.pointwiseRegularizer);
48192 config['depthwiseConstraint'] =
48193 serializeConstraint(this.depthwiseConstraint);
48194 config['pointwiseConstraint'] =
48195 serializeConstraint(this.pointwiseConstraint);
48196 return config;
48197 }
48198 }
/** @nocollapse */
// Serialization key under which this base class is identified.
SeparableConv.className = 'SeparableConv';
/**
 * Depthwise-separable 2D convolution layer.
 *
 * All behavior lives in the rank-generic `SeparableConv` base class; this
 * subclass only pins the rank to 2 for construction and serialization.
 */
class SeparableConv2D extends SeparableConv {
    constructor(args) {
        super(2, args);
    }
}
/** @nocollapse */
SeparableConv2D.className = 'SeparableConv2D';
registerClass(SeparableConv2D);
/**
 * 1D convolution layer (e.g., temporal convolution over 3D inputs of shape
 * `[batch, steps, channels]`). Delegates the math to the rank-generic `Conv`
 * base class with rank fixed to 1.
 */
class Conv1D extends Conv {
    constructor(args) {
        super(1, args);
        Conv1D.verifyArgs(args);
        this.inputSpec = [{ ndim: 3 }];
    }
    getConfig() {
        const config = super.getConfig();
        // Rank and data format are implied for 1D convolutions, so they are
        // stripped from the serialized config.
        for (const field of ['rank', 'dataFormat']) {
            delete config[field];
        }
        return config;
    }
    static verifyArgs(args) {
        // kernelSize must be a single number or a length-1 number array.
        const kernelSizeValid = typeof args.kernelSize === 'number' ||
            checkArrayTypeAndLength(args.kernelSize, 'number', 1, 1);
        if (!kernelSizeValid) {
            throw new ValueError(`Conv1D expects config.kernelSize to be number or number[] with length 1, but received ${JSON.stringify(args.kernelSize)}.`);
        }
    }
}
/** @nocollapse */
Conv1D.className = 'Conv1D';
registerClass(Conv1D);
/**
 * Cropping layer for 4D (image-like) inputs: trims rows and columns from the
 * spatial dimensions of the input.
 */
class Cropping2D extends Layer {
    constructor(args) {
        super(args);
        // Normalize `cropping` to the canonical form
        // [[top, bottom], [left, right]].
        if (typeof args.cropping === 'number') {
            const c = args.cropping;
            this.cropping = [[c, c], [c, c]];
        }
        else if (typeof args.cropping[0] === 'number') {
            const [h, w] = args.cropping;
            this.cropping = [[h, h], [w, w]];
        }
        else {
            this.cropping = args.cropping;
        }
        this.dataFormat =
            args.dataFormat === undefined ? 'channelsLast' : args.dataFormat;
        this.inputSpec = [{ ndim: 4 }];
    }
    computeOutputShape(inputShape) {
        const [[top, bottom], [left, right]] = this.cropping;
        if (this.dataFormat === 'channelsFirst') {
            return [
                inputShape[0], inputShape[1],
                inputShape[2] - top - bottom,
                inputShape[3] - left - right
            ];
        }
        return [
            inputShape[0],
            inputShape[1] - top - bottom,
            inputShape[2] - left - right,
            inputShape[3]
        ];
    }
    call(inputs, kwargs) {
        return tidy(() => {
            inputs = getExactlyOneTensor(inputs);
            const [[top, bottom], [left, right]] = this.cropping;
            // Crop rows first, then columns, on the appropriate spatial axes
            // for the layer's data format (1-based axis indices).
            if (this.dataFormat === 'channelsLast') {
                const rowsCropped = sliceAlongAxis(inputs, top, inputs.shape[1] - top - bottom, 2);
                return sliceAlongAxis(rowsCropped, left, inputs.shape[2] - right - left, 3);
            }
            const rowsCropped = sliceAlongAxis(inputs, top, inputs.shape[2] - top - bottom, 3);
            return sliceAlongAxis(rowsCropped, left, inputs.shape[3] - right - left, 4);
        });
    }
    getConfig() {
        const config = { cropping: this.cropping, dataFormat: this.dataFormat };
        Object.assign(config, super.getConfig());
        return config;
    }
}
/** @nocollapse */
Cropping2D.className = 'Cropping2D';
registerClass(Cropping2D);
/**
 * Upsampling layer for 4D (image-like) inputs: repeats rows and columns by
 * `size[0]` and `size[1]` using nearest-neighbor or bilinear interpolation.
 */
class UpSampling2D extends Layer {
    constructor(args) {
        super(args);
        this.DEFAULT_SIZE = [2, 2];
        this.inputSpec = [{ ndim: 4 }];
        this.size = args.size == null ? this.DEFAULT_SIZE : args.size;
        this.dataFormat =
            args.dataFormat == null ? 'channelsLast' : args.dataFormat;
        checkDataFormat(this.dataFormat);
        this.interpolation =
            args.interpolation == null ? 'nearest' : args.interpolation;
        checkInterpolationFormat(this.interpolation);
    }
    computeOutputShape(inputShape) {
        // Unknown (null) spatial dimensions stay unknown after scaling.
        const scale = (dim, factor) => (dim == null ? null : factor * dim);
        if (this.dataFormat === 'channelsFirst') {
            return [
                inputShape[0], inputShape[1],
                scale(inputShape[2], this.size[0]),
                scale(inputShape[3], this.size[1])
            ];
        }
        return [
            inputShape[0],
            scale(inputShape[1], this.size[0]),
            scale(inputShape[2], this.size[1]),
            inputShape[3]
        ];
    }
    call(inputs, kwargs) {
        return tidy(() => {
            let input = getExactlyOneTensor(inputs);
            const inputShape = input.shape;
            const resize = (x, targetHW) => this.interpolation === 'nearest' ?
                image$1.resizeNearestNeighbor(x, targetHW) :
                image$1.resizeBilinear(x, targetHW);
            // Resize ops operate in NHWC, so channels-first inputs are
            // transposed in and back out.
            if (this.dataFormat === 'channelsFirst') {
                input = transpose$2(input, [0, 2, 3, 1]);
                const target = [this.size[0] * inputShape[2], this.size[1] * inputShape[3]];
                return transpose$2(resize(input, target), [0, 3, 1, 2]);
            }
            const target = [this.size[0] * inputShape[1], this.size[1] * inputShape[2]];
            return resize(input, target);
        });
    }
    getConfig() {
        const config = {
            size: this.size,
            dataFormat: this.dataFormat,
            interpolation: this.interpolation
        };
        Object.assign(config, super.getConfig());
        return config;
    }
}
/** @nocollapse */
UpSampling2D.className = 'UpSampling2D';
registerClass(UpSampling2D);
48353
48354 /**
48355 * @license
48356 * Copyright 2018 Google LLC
48357 *
48358 * Use of this source code is governed by an MIT-style
48359 * license that can be found in the LICENSE file or at
48360 * https://opensource.org/licenses/MIT.
48361 * =============================================================================
48362 */
48363 /**
48364 * 2D convolution with separable filters.
48365 * @param x Input tensor.
48366 * @param depthwiseKernel Convolution kernel for depthwise convolution.
48367 * @param strides Strides (Array of two integers).
48368 * @param padding Padding model.
48369 * @param dataFormat Data format.
48370 * @param dilationRate Array of two integers, dilation rates for the separable
48371 * convolution.
48372 * @returns Output tensor.
48373 * @throws ValueError If depthwiseKernel is not a 4D array.
48374 */
function depthwiseConv2d$1(x, depthwiseKernel, strides = [1, 1], padding = 'valid', dataFormat, dilationRate) {
    return tidy(() => {
        // Fall back to the globally configured image data format when the
        // caller does not specify one.
        if (dataFormat == null) {
            dataFormat = imageDataFormat();
        }
        checkDataFormat(dataFormat);
        // Convert the input to NHWC layout if needed; the core op below only
        // operates in NHWC. Note the rank checks happen after this conversion.
        let y = preprocessConv2DInput(x, dataFormat);
        if (x.rank !== 4) {
            throw new ValueError(`Input for depthwiseConv2d is required to be 4-D, but is instead ` +
                `${x.rank}-D`);
        }
        if (depthwiseKernel.rank !== 4) {
            throw new ValueError(`depthwiseKernel is required to be 4-D, but is instead ` +
                `${depthwiseKernel.rank}-D`);
        }
        y = depthwiseConv2d$3(y, depthwiseKernel, strides, padding === 'same' ? 'same' : 'valid', 'NHWC', dilationRate);
        // Transpose the NHWC result back to NCHW for channels-first callers.
        if (dataFormat === 'channelsFirst') {
            y = transpose$2(y, [0, 3, 1, 2]);
        }
        return y;
    });
}
/**
 * Depthwise 2D convolution layer: convolves each input channel independently
 * with `depthMultiplier` filters, without mixing information across channels.
 */
class DepthwiseConv2D extends BaseConv {
    constructor(args) {
        super(2, args);
        this.depthwiseKernel = null;
        // Number of output channels produced per input channel; defaults to 1.
        this.depthMultiplier =
            args.depthMultiplier == null ? 1 : args.depthMultiplier;
        this.depthwiseInitializer = getInitializer(args.depthwiseInitializer || this.DEFAULT_KERNEL_INITIALIZER);
        this.depthwiseConstraint = getConstraint(args.depthwiseConstraint);
        this.depthwiseRegularizer = getRegularizer(args.depthwiseRegularizer);
    }
    build(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        if (inputShape.length < 4) {
            throw new ValueError(`Inputs to DepthwiseConv2D should have rank 4. ` +
                `Received input shape: ${JSON.stringify(inputShape)}.`);
        }
        const channelAxis = this.dataFormat === 'channelsFirst' ? 1 : 3;
        // The channel dimension must be statically known to size the kernel.
        if (inputShape[channelAxis] == null || inputShape[channelAxis] < 0) {
            throw new ValueError('The channel dimension of the inputs to DepthwiseConv2D should ' +
                `be defined, but is not (${inputShape[channelAxis]}).`);
        }
        const inputDim = inputShape[channelAxis];
        // Kernel layout: [filterHeight, filterWidth, inChannels, depthMultiplier].
        const depthwiseKernelShape = [
            this.kernelSize[0], this.kernelSize[1], inputDim, this.depthMultiplier
        ];
        this.depthwiseKernel = this.addWeight('depthwise_kernel', depthwiseKernelShape, null, this.depthwiseInitializer, this.depthwiseRegularizer, true, this.depthwiseConstraint);
        if (this.useBias) {
            // One bias element per output channel (inputDim * depthMultiplier).
            this.bias = this.addWeight('bias', [inputDim * this.depthMultiplier], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
        }
        else {
            this.bias = null;
        }
        this.built = true;
    }
    call(inputs, kwargs) {
        return tidy(() => {
            inputs = getExactlyOneTensor(inputs);
            let outputs = depthwiseConv2d$1(inputs, this.depthwiseKernel.read(), this.strides, this.padding, this.dataFormat, null);
            // TODO(cais): Add support for dilation.
            if (this.useBias) {
                outputs = biasAdd(outputs, this.bias.read(), this.dataFormat);
            }
            if (this.activation != null) {
                outputs = this.activation.apply(outputs);
            }
            return outputs;
        });
    }
    computeOutputShape(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        const rows = this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1];
        const cols = this.dataFormat === 'channelsFirst' ? inputShape[3] : inputShape[2];
        // Depthwise conv multiplies the channel count by depthMultiplier.
        const outFilters = this.dataFormat === 'channelsFirst' ?
            inputShape[1] * this.depthMultiplier :
            inputShape[3] * this.depthMultiplier;
        const outRows = convOutputLength(rows, this.kernelSize[0], this.padding, this.strides[0]);
        const outCols = convOutputLength(cols, this.kernelSize[1], this.padding, this.strides[1]);
        if (this.dataFormat === 'channelsFirst') {
            return [inputShape[0], outFilters, outRows, outCols];
        }
        else {
            // In this case, assume 'channelsLast'.
            return [inputShape[0], outRows, outCols, outFilters];
        }
    }
    getConfig() {
        const config = super.getConfig();
        config['depthMultiplier'] = this.depthMultiplier;
        config['depthwiseInitializer'] =
            serializeInitializer(this.depthwiseInitializer);
        config['depthwiseRegularizer'] =
            serializeRegularizer(this.depthwiseRegularizer);
        // Bug fix: this previously serialized `this.depthwiseRegularizer`
        // under the 'depthwiseConstraint' key, which both emitted the wrong
        // value and broke round-trip deserialization of the constraint.
        config['depthwiseConstraint'] =
            serializeConstraint(this.depthwiseConstraint);
        return config;
    }
}
/** @nocollapse */
DepthwiseConv2D.className = 'DepthwiseConv2D';
registerClass(DepthwiseConv2D);
48477
48478 /**
48479 * @license
48480 * Copyright 2018 Google LLC
48481 *
48482 * Use of this source code is governed by an MIT-style
48483 * license that can be found in the LICENSE file or at
48484 * https://opensource.org/licenses/MIT.
48485 * =============================================================================
48486 */
48487 /**
48488 * Standardize `apply()` args to a single list of tensor inputs.
48489 *
48490 * When running a model loaded from file, the input tensors `initialState` and
48491 * `constants` are passed to `RNN.apply()` as part of `inputs` instead of the
48492 * dedicated kwargs fields. `inputs` consists of
48493 * `[inputs, initialState0, initialState1, ..., constant0, constant1]` in this
48494 * case.
48495 * This method makes sure that arguments are
48496 * separated and that `initialState` and `constants` are `Array`s of tensors
48497 * (or None).
48498 *
48499 * @param inputs Tensor or `Array` of tensors.
48500 * @param initialState Tensor or `Array` of tensors or `null`/`undefined`.
48501 * @param constants Tensor or `Array` of tensors or `null`/`undefined`.
48502 * @returns An object consisting of
48503 * inputs: A tensor.
48504 * initialState: `Array` of tensors or `null`.
48505 * constants: `Array` of tensors or `null`.
48506 * @throws ValueError, if `inputs` is an `Array` but either `initialState` or
48507 * `constants` is provided.
48508 */
function standardizeArgs(inputs, initialState, constants, numConstants) {
    // Normalize a tensor-or-array value to an array (null/undefined pass
    // through unchanged).
    const asArrayOrNull = (x) => (x == null || Array.isArray(x) ? x : [x]);
    if (Array.isArray(inputs)) {
        if (initialState != null || constants != null) {
            throw new ValueError('When inputs is an array, neither initialState or constants should be provided');
        }
        // A combined inputs array has the layout
        // [inputs, initialState..., constants...]; peel the trailing
        // constants off first, then treat the remainder past index 0 as
        // initial states.
        if (numConstants != null) {
            const splitIndex = inputs.length - numConstants;
            constants = inputs.slice(splitIndex, inputs.length);
            inputs = inputs.slice(0, splitIndex);
        }
        if (inputs.length > 1) {
            initialState = inputs.slice(1, inputs.length);
        }
        inputs = inputs[0];
    }
    return {
        inputs,
        initialState: asArrayOrNull(initialState),
        constants: asArrayOrNull(constants)
    };
}
48536 /**
48537 * Iterates over the time dimension of a tensor.
48538 *
48539 * @param stepFunction RNN step function.
48540 * Parameters:
48541 * inputs: tensor with shape `[samples, ...]` (no time dimension),
48542 * representing input for the batch of samples at a certain time step.
48543 * states: an Array of tensors.
48544 * Returns:
48545 * outputs: tensor with shape `[samples, outputDim]` (no time dimension).
48546 * newStates: list of tensors, same length and shapes as `states`. The first
48547 * state in the list must be the output tensor at the previous timestep.
48548 * @param inputs Tensor of temporal data of shape `[samples, time, ...]` (at
48549 * least 3D).
48550 * @param initialStates Tensor with shape `[samples, outputDim]` (no time
48551 * dimension), containing the initial values of the states used in the step
48552 * function.
48553 * @param goBackwards If `true`, do the iteration over the time dimension in
48554 * reverse order and return the reversed sequence.
48555 * @param mask Binary tensor with shape `[sample, time, 1]`, with a zero for
48556 * every element that is masked.
48557 * @param constants An Array of constant values passed at each step.
48558 * @param unroll Whether to unroll the RNN or to use a symbolic loop. *Not*
48559 * applicable to this imperative deeplearn.js backend. Its value is ignored.
48560 * @param needPerStepOutputs Whether the per-step outputs are to be
48561 * concatenated into a single tensor and returned (as the second return
48562 * value). Default: `false`. This arg is included so that the relatively
48563 * expensive concatenation of the stepwise outputs can be omitted unless
48564 * the stepwise outputs need to be kept (e.g., for an LSTM layer of which
48565 * `returnSequence` is `true`.)
48566 * @returns An Array: `[lastOutput, outputs, newStates]`.
48567 * lastOutput: the lastest output of the RNN, of shape `[samples, ...]`.
48568 * outputs: tensor with shape `[samples, time, ...]` where each entry
48569 * `output[s, t]` is the output of the step function at time `t` for sample
48570 * `s`. This return value is provided if and only if the
48571 * `needPerStepOutputs` is set as `true`. If it is set as `false`, this
48572 * return value will be `undefined`.
48573 * newStates: Array of tensors, latest states returned by the step function,
48574 * of shape `(samples, ...)`.
48575 * @throws ValueError If input dimension is less than 3.
48576 *
48577 * TODO(nielsene): This needs to be tidy-ed.
48578 */
function rnn$1(stepFunction, inputs, initialStates, goBackwards = false, mask, constants, unroll = false, needPerStepOutputs = false) {
    return tidy(() => {
        const ndim = inputs.shape.length;
        if (ndim < 3) {
            throw new ValueError(`Input should be at least 3D, but is ${ndim}D.`);
        }
        // Transpose to time-major, i.e., from [batch, time, ...] to [time, batch,
        // ...].
        const axes = [1, 0].concat(range$2(2, ndim));
        inputs = transpose$2(inputs, axes);
        if (constants != null) {
            // Bug fix: corrected the "functoin" typo in this error message.
            throw new NotImplementedError('The rnn() function of the deeplearn.js backend does not support ' +
                'constants yet.');
        }
        // Porting Note: the unroll option is ignored by the imperative backend.
        if (unroll) {
            console.warn('Backend rnn(): the unroll = true option is not applicable to the ' +
                'imperative deeplearn.js backend.');
        }
        if (mask != null) {
            // Binarize the mask (non-zero -> 1) and convert it to float so it
            // can be multiplied with outputs/states below.
            mask = cast$3(cast$3(mask, 'bool'), 'float32');
            if (mask.rank === ndim - 1) {
                mask = expandDims$3(mask, -1);
            }
            // Keep the mask time-major, in step with `inputs`.
            mask = transpose$2(mask, axes);
        }
        if (goBackwards) {
            inputs = reverse$2(inputs, 0);
            if (mask != null) {
                mask = reverse$2(mask, 0);
            }
        }
        // Porting Note: PyKeras with TensorFlow backend uses a symbolic loop
        // (tf.while_loop). But for the imperative deeplearn.js backend, we just
        // use the usual TypeScript control flow to iterate over the time steps in
        // the inputs.
        // Porting Note: PyKeras patches a "_use_learning_phase" attribute to
        // outputs.
        // This is not idiomatic in TypeScript. The info regarding whether we are
        // in a learning (i.e., training) phase for RNN is passed in a different
        // way.
        const perStepOutputs = [];
        let lastOutput;
        let states = initialStates;
        const timeSteps = inputs.shape[0];
        const perStepInputs = unstack(inputs);
        let perStepMasks;
        if (mask != null) {
            perStepMasks = unstack(mask);
        }
        for (let t = 0; t < timeSteps; ++t) {
            const currentInput = perStepInputs[t];
            const stepOutputs = tidy(() => stepFunction(currentInput, states));
            if (mask == null) {
                lastOutput = stepOutputs[0];
                states = stepOutputs[1];
            }
            else {
                // Masked step: where the mask is zero, carry forward the
                // previous output/state instead of the new one.
                const maskedOutputs = tidy(() => {
                    const stepMask = perStepMasks[t];
                    const negStepMask = sub$2(onesLike$3(stepMask), stepMask);
                    // TODO(cais): Would tfc.where() be better for performance?
                    const output = add$3(mul(stepOutputs[0], stepMask), mul(states[0], negStepMask));
                    const newStates = states.map((state, i) => {
                        return add$3(mul(stepOutputs[1][i], stepMask), mul(state, negStepMask));
                    });
                    return { output, newStates };
                });
                lastOutput = maskedOutputs.output;
                states = maskedOutputs.newStates;
            }
            if (needPerStepOutputs) {
                perStepOutputs.push(lastOutput);
            }
        }
        let outputs;
        if (needPerStepOutputs) {
            // Stack stepwise outputs back into a batch-major [batch, time, ...]
            // tensor.
            const axis = 1;
            outputs = stack(perStepOutputs, axis);
        }
        return [lastOutput, outputs, states];
    });
}
48662 class RNN extends Layer {
constructor(args) {
    super(args);
    let cell;
    if (args.cell == null) {
        throw new ValueError('cell property is missing for the constructor of RNN.');
    }
    else if (Array.isArray(args.cell)) {
        // An array of cells is wrapped into a single stacked cell.
        cell = new StackedRNNCells({ cells: args.cell });
    }
    else {
        cell = args.cell;
    }
    // Every cell must expose `stateSize` so the layer can shape its states.
    if (cell.stateSize == null) {
        throw new ValueError('The RNN cell should have an attribute `stateSize` (tuple of ' +
            'integers, one integer per RNN state).');
    }
    this.cell = cell;
    // All behavior flags default to false when unspecified.
    this.returnSequences =
        args.returnSequences == null ? false : args.returnSequences;
    this.returnState = args.returnState == null ? false : args.returnState;
    this.goBackwards = args.goBackwards == null ? false : args.goBackwards;
    this._stateful = args.stateful == null ? false : args.stateful;
    this.unroll = args.unroll == null ? false : args.unroll;
    this.supportsMasking = true;
    this.inputSpec = [new InputSpec({ ndim: 3 })];
    this.stateSpec = null;
    // Cached state tensors; populated lazily by build()/resetStates().
    this.states_ = null;
    // TODO(cais): Add constantsSpec and numConstants.
    this.numConstants = null;
    // TODO(cais): Look into the use of initial_state in the kwargs of the
    // constructor.
    // Old states retained during training so BPTT can still reach them.
    this.keptStates = [];
}
48696 // Porting Note: This is the equivalent of `RNN.states` property getter in
48697 // PyKeras.
48698 getStates() {
48699 if (this.states_ == null) {
48700 const numStates = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1;
48701 return range$2(0, numStates).map(x => null);
48702 }
48703 else {
48704 return this.states_;
48705 }
48706 }
48707 // Porting Note: This is the equivalent of the `RNN.states` property setter in
48708 // PyKeras.
setStates(states) {
    // Replaces the cached state tensors wholesale; shape validation happens
    // in resetStates() when states are provided there.
    this.states_ = states;
}
48712 computeOutputShape(inputShape) {
48713 if (isArrayOfShapes(inputShape)) {
48714 inputShape = inputShape[0];
48715 }
48716 inputShape = inputShape;
48717 // TODO(cais): Remove the casting once stacked RNN cells become supported.
48718 let stateSize = this.cell.stateSize;
48719 if (!Array.isArray(stateSize)) {
48720 stateSize = [stateSize];
48721 }
48722 const outputDim = stateSize[0];
48723 let outputShape;
48724 if (this.returnSequences) {
48725 outputShape = [inputShape[0], inputShape[1], outputDim];
48726 }
48727 else {
48728 outputShape = [inputShape[0], outputDim];
48729 }
48730 if (this.returnState) {
48731 const stateShape = [];
48732 for (const dim of stateSize) {
48733 stateShape.push([inputShape[0], dim]);
48734 }
48735 return [outputShape].concat(stateShape);
48736 }
48737 else {
48738 return outputShape;
48739 }
48740 }
48741 computeMask(inputs, mask) {
48742 return tidy(() => {
48743 if (Array.isArray(mask)) {
48744 mask = mask[0];
48745 }
48746 const outputMask = this.returnSequences ? mask : null;
48747 if (this.returnState) {
48748 const stateMask = this.states.map(s => null);
48749 return [outputMask].concat(stateMask);
48750 }
48751 else {
48752 return outputMask;
48753 }
48754 });
48755 }
48756 /**
48757 * Get the current state tensors of the RNN.
48758 *
48759 * If the state hasn't been set, return an array of `null`s of the correct
48760 * length.
48761 */
48762 get states() {
48763 if (this.states_ == null) {
48764 const numStates = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1;
48765 const output = [];
48766 for (let i = 0; i < numStates; ++i) {
48767 output.push(null);
48768 }
48769 return output;
48770 }
48771 else {
48772 return this.states_;
48773 }
48774 }
set states(s) {
    // Direct assignment; no validation or disposal of previous states here.
    this.states_ = s;
}
build(inputShape) {
    // Note inputShape will be an Array of Shapes of initial states and
    // constants if these are passed in apply().
    const constantShape = null;
    if (this.numConstants != null) {
        throw new NotImplementedError('Constants support is not implemented in RNN yet.');
    }
    if (isArrayOfShapes(inputShape)) {
        inputShape = inputShape[0];
    }
    inputShape = inputShape;
    // Stateful layers need a fixed batch size; otherwise leave it open.
    const batchSize = this.stateful ? inputShape[0] : null;
    const inputDim = inputShape.slice(2);
    this.inputSpec[0] = new InputSpec({ shape: [batchSize, null, ...inputDim] });
    // Allow cell (if RNNCell Layer) to build before we set or validate
    // stateSpec.
    // The cell sees a single time step: [batch, ...features].
    const stepInputShape = [inputShape[0]].concat(inputShape.slice(2));
    if (constantShape != null) {
        throw new NotImplementedError('Constants support is not implemented in RNN yet.');
    }
    else {
        this.cell.build(stepInputShape);
    }
    // Set or validate stateSpec.
    let stateSize;
    if (Array.isArray(this.cell.stateSize)) {
        stateSize = this.cell.stateSize;
    }
    else {
        stateSize = [this.cell.stateSize];
    }
    if (this.stateSpec != null) {
        // A stateSpec may already exist if initial states were passed to
        // apply(); its last dimensions must match the cell's state sizes.
        if (!arraysEqual(this.stateSpec.map(spec => spec.shape[spec.shape.length - 1]), stateSize)) {
            throw new ValueError(`An initialState was passed that is not compatible with ` +
                `cell.stateSize. Received stateSpec=${this.stateSpec}; ` +
                `However cell.stateSize is ${this.cell.stateSize}`);
        }
    }
    else {
        this.stateSpec =
            stateSize.map(dim => new InputSpec({ shape: [null, dim] }));
    }
    if (this.stateful) {
        this.resetStates();
    }
}
48824 /**
48825 * Reset the state tensors of the RNN.
48826 *
48827 * If the `states` argument is `undefined` or `null`, will set the
48828 * state tensor(s) of the RNN to all-zero tensors of the appropriate
48829 * shape(s).
48830 *
48831 * If `states` is provided, will set the state tensors of the RNN to its
48832 * value.
48833 *
48834 * @param states Optional externally-provided initial states.
48835 * @param training Whether this call is done during training. For stateful
48836 * RNNs, this affects whether the old states are kept or discarded. In
48837 * particular, if `training` is `true`, the old states will be kept so
48838 * that subsequent backpropgataion through time (BPTT) may work properly.
48839 * Else, the old states will be discarded.
48840 */
resetStates(states, training = false) {
    tidy(() => {
        if (!this.stateful) {
            throw new AttributeError('Cannot call resetStates() on an RNN Layer that is not stateful.');
        }
        const batchSize = this.inputSpec[0].shape[0];
        if (batchSize == null) {
            throw new ValueError('If an RNN is stateful, it needs to know its batch size. Specify ' +
                'the batch size of your input tensors: \n' +
                '- If using a Sequential model, specify the batch size by ' +
                'passing a `batchInputShape` option to your first layer.\n' +
                '- If using the functional API, specify the batch size by ' +
                'passing a `batchShape` option to your Input layer.');
        }
        // Initialize state if null.
        if (this.states_ == null) {
            // First call: create zero-filled states, one per state slot.
            if (Array.isArray(this.cell.stateSize)) {
                this.states_ =
                    this.cell.stateSize.map(dim => zeros$2([batchSize, dim]));
            }
            else {
                this.states_ = [zeros$2([batchSize, this.cell.stateSize])];
            }
        }
        else if (states == null) {
            // No-arg reset: drop the old tensors and start from zeros again.
            // Dispose old state tensors.
            dispose(this.states_);
            // For stateful RNNs, fully dispose kept old states.
            if (this.keptStates != null) {
                dispose(this.keptStates);
                this.keptStates = [];
            }
            if (Array.isArray(this.cell.stateSize)) {
                this.states_ =
                    this.cell.stateSize.map(dim => zeros$2([batchSize, dim]));
            }
            else {
                this.states_[0] = zeros$2([batchSize, this.cell.stateSize]);
            }
        }
        else {
            // Explicit states supplied: validate count and shapes, then adopt.
            if (!Array.isArray(states)) {
                states = [states];
            }
            if (states.length !== this.states_.length) {
                throw new ValueError(`Layer ${this.name} expects ${this.states_.length} state(s), ` +
                    `but it received ${states.length} state value(s). Input ` +
                    `received: ${states}`);
            }
            if (training === true) {
                // Store old state tensors for complete disposal later, i.e., during
                // the next no-arg call to this method. We do not dispose the old
                // states immediately because that BPTT (among other things) require
                // them.
                this.keptStates.push(this.states_.slice());
            }
            else {
                dispose(this.states_);
            }
            for (let index = 0; index < this.states_.length; ++index) {
                const value = states[index];
                const dim = Array.isArray(this.cell.stateSize) ?
                    this.cell.stateSize[index] :
                    this.cell.stateSize;
                const expectedShape = [batchSize, dim];
                if (!arraysEqual(value.shape, expectedShape)) {
                    throw new ValueError(`State ${index} is incompatible with layer ${this.name}: ` +
                        `expected shape=${expectedShape}, received shape=${value.shape}`);
                }
                this.states_[index] = value;
            }
        }
        // Clone and keep() the adopted states so they survive this tidy scope.
        this.states_ = this.states_.map(state => keep(state.clone()));
    });
}
apply(inputs, kwargs) {
    // TODO(cais): Figure out whether initialState is in kwargs or inputs.
    let initialState = kwargs == null ? null : kwargs['initialState'];
    let constants = kwargs == null ? null : kwargs['constants'];
    if (kwargs == null) {
        kwargs = {};
    }
    // Untangle inputs / initial states / constants, whichever way the
    // caller packed them (see standardizeArgs).
    const standardized = standardizeArgs(inputs, initialState, constants, this.numConstants);
    inputs = standardized.inputs;
    initialState = standardized.initialState;
    constants = standardized.constants;
    // If any of `initial_state` or `constants` are specified and are
    // `tf.SymbolicTensor`s, then add them to the inputs and temporarily modify
    // the input_spec to include them.
    let additionalInputs = [];
    let additionalSpecs = [];
    if (initialState != null) {
        kwargs['initialState'] = initialState;
        additionalInputs = additionalInputs.concat(initialState);
        // Derive stateSpec from the provided states' shapes.
        this.stateSpec = [];
        for (const state of initialState) {
            this.stateSpec.push(new InputSpec({ shape: state.shape }));
        }
        // TODO(cais): Use the following instead.
        // this.stateSpec = initialState.map(state => new InputSpec({shape:
        // state.shape}));
        additionalSpecs = additionalSpecs.concat(this.stateSpec);
    }
    if (constants != null) {
        kwargs['constants'] = constants;
        additionalInputs = additionalInputs.concat(constants);
        // TODO(cais): Add this.constantsSpec.
        this.numConstants = constants.length;
    }
    const isTensor = additionalInputs[0] instanceof SymbolicTensor;
    if (isTensor) {
        // Compute full input spec, including state and constants.
        const fullInput = [inputs].concat(additionalInputs);
        const fullInputSpec = this.inputSpec.concat(additionalSpecs);
        // Perform the call with temporarily replaced inputSpec.
        const originalInputSpec = this.inputSpec;
        this.inputSpec = fullInputSpec;
        const output = super.apply(fullInput, kwargs);
        // Restore the original spec so subsequent calls are unaffected.
        this.inputSpec = originalInputSpec;
        return output;
    }
    else {
        return super.apply(inputs, kwargs);
    }
}
48966 // tslint:disable-next-line:no-any
call(inputs, kwargs) {
    // Input shape: `[samples, time (padded with zeros), input_dim]`.
    // Note that the .build() method of subclasses **must** define
    // this.inputSpec and this.stateSpec owith complete input shapes.
    return tidy(() => {
        const mask = kwargs == null ? null : kwargs['mask'];
        const training = kwargs == null ? null : kwargs['training'];
        let initialState = kwargs == null ? null : kwargs['initialState'];
        inputs = getExactlyOneTensor(inputs);
        // Resolve initial state: explicit kwarg > cached stateful state >
        // freshly derived zero state.
        if (initialState == null) {
            if (this.stateful) {
                initialState = this.states_;
            }
            else {
                initialState = this.getInitialState(inputs);
            }
        }
        const numStates = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1;
        if (initialState.length !== numStates) {
            throw new ValueError(`RNN Layer has ${numStates} state(s) but was passed ` +
                `${initialState.length} initial state(s).`);
        }
        if (this.unroll) {
            console.warn('Ignoring unroll = true for RNN layer, due to imperative backend.');
        }
        const cellCallKwargs = { training };
        // TODO(cais): Add support for constants.
        const step = (inputs, states) => {
            // `inputs` and `states` are concatenated to form a single `Array` of
            // `tf.Tensor`s as the input to `cell.call()`.
            const outputs = this.cell.call([inputs].concat(states), cellCallKwargs);
            // Marshall the return value into output and new states.
            return [outputs[0], outputs.slice(1)];
        };
        // TODO(cais): Add support for constants.
        // Per-step outputs are only materialized when returnSequences is set.
        const rnnOutputs = rnn$1(step, inputs, initialState, this.goBackwards, mask, null, this.unroll, this.returnSequences);
        const lastOutput = rnnOutputs[0];
        const outputs = rnnOutputs[1];
        const states = rnnOutputs[2];
        // Stateful layers persist the final states for the next batch.
        if (this.stateful) {
            this.resetStates(states, training);
        }
        const output = this.returnSequences ? outputs : lastOutput;
        // TODO(cais): Property set learning phase flag.
        if (this.returnState) {
            return [output].concat(states);
        }
        else {
            return output;
        }
    });
}
49019 getInitialState(inputs) {
49020 return tidy(() => {
49021 // Build an all-zero tensor of shape [samples, outputDim].
49022 // [Samples, timeSteps, inputDim].
49023 let initialState = zeros$2(inputs.shape);
49024 // [Samples].
49025 initialState = sum$3(initialState, [1, 2]);
49026 initialState = expandDims$2(initialState); // [Samples, 1].
49027 if (Array.isArray(this.cell.stateSize)) {
49028 return this.cell.stateSize.map(dim => dim > 1 ? tile$2(initialState, [1, dim]) : initialState);
49029 }
49030 else {
49031 return this.cell.stateSize > 1 ?
49032 [tile$2(initialState, [1, this.cell.stateSize])] :
49033 [initialState];
49034 }
49035 });
49036 }
49037 get trainableWeights() {
49038 if (!this.trainable) {
49039 return [];
49040 }
49041 // Porting Note: In TypeScript, `this` is always an instance of `Layer`.
49042 return this.cell.trainableWeights;
49043 }
49044 get nonTrainableWeights() {
49045 // Porting Note: In TypeScript, `this` is always an instance of `Layer`.
49046 if (!this.trainable) {
49047 return this.cell.weights;
49048 }
49049 return this.cell.nonTrainableWeights;
49050 }
49051 setFastWeightInitDuringBuild(value) {
49052 super.setFastWeightInitDuringBuild(value);
49053 if (this.cell != null) {
49054 this.cell.setFastWeightInitDuringBuild(value);
49055 }
49056 }
49057 getConfig() {
49058 const baseConfig = super.getConfig();
49059 const config = {
49060 returnSequences: this.returnSequences,
49061 returnState: this.returnState,
49062 goBackwards: this.goBackwards,
49063 stateful: this.stateful,
49064 unroll: this.unroll,
49065 };
49066 if (this.numConstants != null) {
49067 config['numConstants'] = this.numConstants;
49068 }
49069 const cellConfig = this.cell.getConfig();
49070 if (this.getClassName() === RNN.className) {
49071 config['cell'] = {
49072 'className': this.cell.getClassName(),
49073 'config': cellConfig,
49074 };
49075 }
49076 // this order is necessary, to prevent cell name from replacing layer name
49077 return Object.assign(Object.assign(Object.assign({}, cellConfig), baseConfig), config);
49078 }
49079 /** @nocollapse */
49080 static fromConfig(cls, config, customObjects = {}) {
49081 const cellConfig = config['cell'];
49082 const cell = deserialize(cellConfig, customObjects);
49083 return new cls(Object.assign(config, { cell }));
49084 }
49085 }
    /** @nocollapse */
    RNN.className = 'RNN';
    // Register under its class name so serialized models can be deserialized.
    registerClass(RNN);
49089 // Porting Note: This is a common parent class for RNN cells. There is no
49090 // equivalent of this in PyKeras. Having a common parent class forgoes the
49091 // need for `has_attr(cell, ...)` checks or its TypeScript equivalent.
49092 /**
49093 * An RNNCell layer.
49094 *
49095 * @doc {heading: 'Layers', subheading: 'Classes'}
49096 */
    // Abstract marker class. Concrete cells below (SimpleRNNCell, GRUCell,
    // LSTMCell, StackedRNNCells) define `stateSize` and implement `call()`
    // over `[input, ...states]`, returning `[output, ...newStates]`.
    class RNNCell extends Layer {
    }
    /**
     * Cell class for `SimpleRNN`. Computes one time step:
     * `output = activation(inputs · kernel + prevOutput · recurrentKernel + bias)`
     * and returns `[output, output]` (the output doubles as the next state).
     */
    class SimpleRNNCell extends RNNCell {
        constructor(args) {
            super(args);
            this.DEFAULT_ACTIVATION = 'tanh';
            this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
            this.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal';
            this.DEFAULT_BIAS_INITIALIZER = 'zeros';
            this.units = args.units;
            assertPositiveInteger(this.units, `units`);
            this.activation = getActivation(args.activation == null ? this.DEFAULT_ACTIVATION : args.activation);
            this.useBias = args.useBias == null ? true : args.useBias;
            this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER);
            this.recurrentInitializer = getInitializer(args.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER);
            this.biasInitializer =
                getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER);
            this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
            this.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
            this.biasRegularizer = getRegularizer(args.biasRegularizer);
            this.kernelConstraint = getConstraint(args.kernelConstraint);
            this.recurrentConstraint = getConstraint(args.recurrentConstraint);
            this.biasConstraint = getConstraint(args.biasConstraint);
            // Dropout rates are clamped to the closed interval [0, 1].
            this.dropout = min$2([1, max$2([0, args.dropout == null ? 0 : args.dropout])]);
            this.recurrentDropout = min$2([
                1,
                max$2([0, args.recurrentDropout == null ? 0 : args.recurrentDropout])
            ]);
            this.dropoutFunc = args.dropoutFunc;
            // A SimpleRNN cell carries a single state of width `units`.
            this.stateSize = this.units;
            // Dropout masks are created lazily in call() and cached across steps;
            // the owning RNN layer disposes them per forward pass.
            this.dropoutMask = null;
            this.recurrentDropoutMask = null;
        }
        build(inputShape) {
            // Creates kernel [inputDim, units], recurrentKernel [units, units]
            // and (optionally) bias [units].
            inputShape = getExactlyOneShape(inputShape);
            // TODO(cais): Use regularizer.
            this.kernel = this.addWeight('kernel', [inputShape[inputShape.length - 1], this.units], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
            this.recurrentKernel = this.addWeight('recurrent_kernel', [this.units, this.units], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
            if (this.useBias) {
                this.bias = this.addWeight('bias', [this.units], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
            }
            else {
                this.bias = null;
            }
            this.built = true;
        }
        // Porting Note: PyKeras' equivalent of this method takes two tensor inputs:
        // `inputs` and `states`. Here, the two tensors are combined into an
        // `Tensor[]` Array as the first input argument.
        // Similarly, PyKeras' equivalent of this method returns two values:
        // `output` and `[output]`. Here the two are combined into one length-2
        // `Tensor[]`, consisting of `output` repeated.
        call(inputs, kwargs) {
            return tidy(() => {
                inputs = inputs;
                if (inputs.length !== 2) {
                    throw new ValueError(`SimpleRNNCell expects 2 input Tensors, got ${inputs.length}.`);
                }
                let prevOutput = inputs[1];
                inputs = inputs[0];
                const training = kwargs['training'] == null ? false : kwargs['training'];
                // Lazily create (and cache) the input dropout mask.
                if (0 < this.dropout && this.dropout < 1 && this.dropoutMask == null) {
                    this.dropoutMask = generateDropoutMask({
                        ones: () => onesLike$3(inputs),
                        rate: this.dropout,
                        training,
                        dropoutFunc: this.dropoutFunc,
                    });
                }
                // Lazily create (and cache) the recurrent dropout mask.
                if (0 < this.recurrentDropout && this.recurrentDropout < 1 &&
                    this.recurrentDropoutMask == null) {
                    this.recurrentDropoutMask = generateDropoutMask({
                        ones: () => onesLike$3(prevOutput),
                        rate: this.recurrentDropout,
                        training,
                        dropoutFunc: this.dropoutFunc,
                    });
                }
                let h;
                const dpMask = this.dropoutMask;
                const recDpMask = this.recurrentDropoutMask;
                // h = (inputs ∘ dropoutMask) · kernel  [+ bias]
                if (dpMask != null) {
                    h = dot$1(mul(inputs, dpMask), this.kernel.read());
                }
                else {
                    h = dot$1(inputs, this.kernel.read());
                }
                if (this.bias != null) {
                    h = biasAdd(h, this.bias.read());
                }
                if (recDpMask != null) {
                    prevOutput = mul(prevOutput, recDpMask);
                }
                // output = activation(h + prevOutput · recurrentKernel)
                let output = add$3(h, dot$1(prevOutput, this.recurrentKernel.read()));
                if (this.activation != null) {
                    output = this.activation.apply(output);
                }
                // TODO(cais): Properly set learning phase on output tensor?
                return [output, output];
            });
        }
        getConfig() {
            const baseConfig = super.getConfig();
            const config = {
                units: this.units,
                activation: serializeActivation(this.activation),
                useBias: this.useBias,
                kernelInitializer: serializeInitializer(this.kernelInitializer),
                recurrentInitializer: serializeInitializer(this.recurrentInitializer),
                biasInitializer: serializeInitializer(this.biasInitializer),
                kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
                recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer),
                biasRegularizer: serializeRegularizer(this.biasRegularizer),
                activityRegularizer: serializeRegularizer(this.activityRegularizer),
                kernelConstraint: serializeConstraint(this.kernelConstraint),
                recurrentConstraint: serializeConstraint(this.recurrentConstraint),
                biasConstraint: serializeConstraint(this.biasConstraint),
                dropout: this.dropout,
                recurrentDropout: this.recurrentDropout,
            };
            return Object.assign(Object.assign({}, baseConfig), config);
        }
    }
    /** @nocollapse */
    SimpleRNNCell.className = 'SimpleRNNCell';
    registerClass(SimpleRNNCell);
49223 class SimpleRNN extends RNN {
49224 constructor(args) {
49225 args.cell = new SimpleRNNCell(args);
49226 super(args);
49227 // TODO(cais): Add activityRegularizer.
49228 }
49229 call(inputs, kwargs) {
49230 return tidy(() => {
49231 if (this.cell.dropoutMask != null) {
49232 dispose(this.cell.dropoutMask);
49233 this.cell.dropoutMask = null;
49234 }
49235 if (this.cell.recurrentDropoutMask != null) {
49236 dispose(this.cell.recurrentDropoutMask);
49237 this.cell.recurrentDropoutMask = null;
49238 }
49239 const mask = kwargs == null ? null : kwargs['mask'];
49240 const training = kwargs == null ? null : kwargs['training'];
49241 const initialState = kwargs == null ? null : kwargs['initialState'];
49242 return super.call(inputs, { mask, training, initialState });
49243 });
49244 }
49245 /** @nocollapse */
49246 static fromConfig(cls, config) {
49247 return new cls(config);
49248 }
49249 }
49250 /** @nocollapse */
49251 SimpleRNN.className = 'SimpleRNN';
49252 registerClass(SimpleRNN);
49253 class GRUCell extends RNNCell {
49254 constructor(args) {
49255 super(args);
49256 this.DEFAULT_ACTIVATION = 'tanh';
49257 this.DEFAULT_RECURRENT_ACTIVATION = 'hardSigmoid';
49258 this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
49259 this.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal';
49260 this.DEFAULT_BIAS_INITIALIZER = 'zeros';
49261 if (args.resetAfter) {
49262 throw new ValueError(`GRUCell does not support reset_after parameter set to true.`);
49263 }
49264 this.units = args.units;
49265 assertPositiveInteger(this.units, 'units');
49266 this.activation = getActivation(args.activation === undefined ? this.DEFAULT_ACTIVATION :
49267 args.activation);
49268 this.recurrentActivation = getActivation(args.recurrentActivation === undefined ?
49269 this.DEFAULT_RECURRENT_ACTIVATION :
49270 args.recurrentActivation);
49271 this.useBias = args.useBias == null ? true : args.useBias;
49272 this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER);
49273 this.recurrentInitializer = getInitializer(args.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER);
49274 this.biasInitializer =
49275 getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER);
49276 this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
49277 this.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
49278 this.biasRegularizer = getRegularizer(args.biasRegularizer);
49279 this.kernelConstraint = getConstraint(args.kernelConstraint);
49280 this.recurrentConstraint = getConstraint(args.recurrentConstraint);
49281 this.biasConstraint = getConstraint(args.biasConstraint);
49282 this.dropout = min$2([1, max$2([0, args.dropout == null ? 0 : args.dropout])]);
49283 this.recurrentDropout = min$2([
49284 1,
49285 max$2([0, args.recurrentDropout == null ? 0 : args.recurrentDropout])
49286 ]);
49287 this.dropoutFunc = args.dropoutFunc;
49288 this.implementation = args.implementation;
49289 this.stateSize = this.units;
49290 this.dropoutMask = null;
49291 this.recurrentDropoutMask = null;
49292 }
49293 build(inputShape) {
49294 inputShape = getExactlyOneShape(inputShape);
49295 const inputDim = inputShape[inputShape.length - 1];
49296 this.kernel = this.addWeight('kernel', [inputDim, this.units * 3], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
49297 this.recurrentKernel = this.addWeight('recurrent_kernel', [this.units, this.units * 3], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
49298 if (this.useBias) {
49299 this.bias = this.addWeight('bias', [this.units * 3], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
49300 }
49301 else {
49302 this.bias = null;
49303 }
49304 // Porting Notes: Unlike the PyKeras implementation, we perform slicing
49305 // of the weights and bias in the call() method, at execution time.
49306 this.built = true;
49307 }
49308 call(inputs, kwargs) {
49309 return tidy(() => {
49310 inputs = inputs;
49311 if (inputs.length !== 2) {
49312 throw new ValueError(`GRUCell expects 2 input Tensors (inputs, h, c), got ` +
49313 `${inputs.length}.`);
49314 }
49315 const training = kwargs['training'] == null ? false : kwargs['training'];
49316 let hTMinus1 = inputs[1]; // Previous memory state.
49317 inputs = inputs[0];
49318 // Note: For superior performance, TensorFlow.js always uses
49319 // implementation 2, regardless of the actual value of
49320 // config.implementation.
49321 if (0 < this.dropout && this.dropout < 1 && this.dropoutMask == null) {
49322 this.dropoutMask = generateDropoutMask({
49323 ones: () => onesLike$3(inputs),
49324 rate: this.dropout,
49325 training,
49326 count: 3,
49327 dropoutFunc: this.dropoutFunc,
49328 });
49329 }
49330 if (0 < this.recurrentDropout && this.recurrentDropout < 1 &&
49331 this.recurrentDropoutMask == null) {
49332 this.recurrentDropoutMask = generateDropoutMask({
49333 ones: () => onesLike$3(hTMinus1),
49334 rate: this.recurrentDropout,
49335 training,
49336 count: 3,
49337 dropoutFunc: this.dropoutFunc,
49338 });
49339 }
49340 const dpMask = this.dropoutMask;
49341 const recDpMask = this.recurrentDropoutMask;
49342 let z;
49343 let r;
49344 let hh;
49345 if (0 < this.dropout && this.dropout < 1) {
49346 inputs = mul(inputs, dpMask[0]);
49347 }
49348 let matrixX = dot$1(inputs, this.kernel.read());
49349 if (this.useBias) {
49350 matrixX = biasAdd(matrixX, this.bias.read());
49351 }
49352 if (0 < this.recurrentDropout && this.recurrentDropout < 1) {
49353 hTMinus1 = mul(hTMinus1, recDpMask[0]);
49354 }
49355 const recurrentKernelValue = this.recurrentKernel.read();
49356 const [rk1, rk2] = split$3(recurrentKernelValue, [2 * this.units, this.units], recurrentKernelValue.rank - 1);
49357 const matrixInner = dot$1(hTMinus1, rk1);
49358 const [xZ, xR, xH] = split$3(matrixX, 3, matrixX.rank - 1);
49359 const [recurrentZ, recurrentR] = split$3(matrixInner, 2, matrixInner.rank - 1);
49360 z = this.recurrentActivation.apply(add$3(xZ, recurrentZ));
49361 r = this.recurrentActivation.apply(add$3(xR, recurrentR));
49362 const recurrentH = dot$1(mul(r, hTMinus1), rk2);
49363 hh = this.activation.apply(add$3(xH, recurrentH));
49364 const h = add$3(mul(z, hTMinus1), mul(add$3(1, neg$2(z)), hh));
49365 // TODO(cais): Add use_learning_phase flag properly.
49366 return [h, h];
49367 });
49368 }
49369 getConfig() {
49370 const baseConfig = super.getConfig();
49371 const config = {
49372 units: this.units,
49373 activation: serializeActivation(this.activation),
49374 recurrentActivation: serializeActivation(this.recurrentActivation),
49375 useBias: this.useBias,
49376 kernelInitializer: serializeInitializer(this.kernelInitializer),
49377 recurrentInitializer: serializeInitializer(this.recurrentInitializer),
49378 biasInitializer: serializeInitializer(this.biasInitializer),
49379 kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
49380 recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer),
49381 biasRegularizer: serializeRegularizer(this.biasRegularizer),
49382 activityRegularizer: serializeRegularizer(this.activityRegularizer),
49383 kernelConstraint: serializeConstraint(this.kernelConstraint),
49384 recurrentConstraint: serializeConstraint(this.recurrentConstraint),
49385 biasConstraint: serializeConstraint(this.biasConstraint),
49386 dropout: this.dropout,
49387 recurrentDropout: this.recurrentDropout,
49388 implementation: this.implementation,
49389 resetAfter: false
49390 };
49391 return Object.assign(Object.assign({}, baseConfig), config);
49392 }
49393 }
49394 /** @nocollapse */
49395 GRUCell.className = 'GRUCell';
49396 registerClass(GRUCell);
49397 class GRU extends RNN {
49398 constructor(args) {
49399 if (args.implementation === 0) {
49400 console.warn('`implementation=0` has been deprecated, and now defaults to ' +
49401 '`implementation=1`. Please update your layer call.');
49402 }
49403 args.cell = new GRUCell(args);
49404 super(args);
49405 // TODO(cais): Add activityRegularizer.
49406 }
49407 call(inputs, kwargs) {
49408 return tidy(() => {
49409 if (this.cell.dropoutMask != null) {
49410 dispose(this.cell.dropoutMask);
49411 this.cell.dropoutMask = null;
49412 }
49413 if (this.cell.recurrentDropoutMask != null) {
49414 dispose(this.cell.recurrentDropoutMask);
49415 this.cell.recurrentDropoutMask = null;
49416 }
49417 const mask = kwargs == null ? null : kwargs['mask'];
49418 const training = kwargs == null ? null : kwargs['training'];
49419 const initialState = kwargs == null ? null : kwargs['initialState'];
49420 return super.call(inputs, { mask, training, initialState });
49421 });
49422 }
49423 /** @nocollapse */
49424 static fromConfig(cls, config) {
49425 if (config['implmentation'] === 0) {
49426 config['implementation'] = 1;
49427 }
49428 return new cls(config);
49429 }
49430 }
49431 /** @nocollapse */
49432 GRU.className = 'GRU';
49433 registerClass(GRU);
    /**
     * Cell class for `LSTM`. Computes one long short-term memory step over
     * `[input, hPrev, cPrev]` and returns `[h, h, c]` (output, hidden state,
     * carry state).
     */
    class LSTMCell extends RNNCell {
        constructor(args) {
            super(args);
            this.DEFAULT_ACTIVATION = 'tanh';
            this.DEFAULT_RECURRENT_ACTIVATION = 'hardSigmoid';
            this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
            this.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal';
            this.DEFAULT_BIAS_INITIALIZER = 'zeros';
            this.units = args.units;
            assertPositiveInteger(this.units, 'units');
            this.activation = getActivation(args.activation === undefined ? this.DEFAULT_ACTIVATION :
                args.activation);
            this.recurrentActivation = getActivation(args.recurrentActivation === undefined ?
                this.DEFAULT_RECURRENT_ACTIVATION :
                args.recurrentActivation);
            this.useBias = args.useBias == null ? true : args.useBias;
            this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER);
            this.recurrentInitializer = getInitializer(args.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER);
            this.biasInitializer =
                getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER);
            // When true, build() forces the forget-gate bias slice to ones.
            this.unitForgetBias = args.unitForgetBias;
            this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
            this.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
            this.biasRegularizer = getRegularizer(args.biasRegularizer);
            this.kernelConstraint = getConstraint(args.kernelConstraint);
            this.recurrentConstraint = getConstraint(args.recurrentConstraint);
            this.biasConstraint = getConstraint(args.biasConstraint);
            // Dropout rates are clamped to [0, 1].
            this.dropout = min$2([1, max$2([0, args.dropout == null ? 0 : args.dropout])]);
            this.recurrentDropout = min$2([
                1,
                max$2([0, args.recurrentDropout == null ? 0 : args.recurrentDropout])
            ]);
            this.dropoutFunc = args.dropoutFunc;
            this.implementation = args.implementation;
            // Two states: [hidden (h), carry (c)], both of width `units`.
            this.stateSize = [this.units, this.units];
            this.dropoutMask = null;
            this.recurrentDropoutMask = null;
        }
        build(inputShape) {
            // Fused weights for the i, f, c, o gates: kernel [inputDim, 4*units],
            // recurrentKernel [units, 4*units], bias [4*units].
            var _a;
            inputShape = getExactlyOneShape(inputShape);
            const inputDim = inputShape[inputShape.length - 1];
            this.kernel = this.addWeight('kernel', [inputDim, this.units * 4], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
            this.recurrentKernel = this.addWeight('recurrent_kernel', [this.units, this.units * 4], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
            let biasInitializer;
            if (this.useBias) {
                if (this.unitForgetBias) {
                    const capturedBiasInit = this.biasInitializer;
                    const capturedUnits = this.units;
                    // Ad-hoc initializer: [bI (user init), bF (ones), bC+bO (user
                    // init)] concatenated, so the forget gate starts wide open.
                    biasInitializer = new (_a = class CustomInit extends Initializer {
                        apply(shape, dtype) {
                            // TODO(cais): More informative variable names?
                            const bI = capturedBiasInit.apply([capturedUnits]);
                            const bF = (new Ones()).apply([capturedUnits]);
                            const bCAndH = capturedBiasInit.apply([capturedUnits * 2]);
                            return concatAlongFirstAxis(concatAlongFirstAxis(bI, bF), bCAndH);
                        }
                    },
                    /** @nocollapse */
                    _a.className = 'CustomInit',
                    _a)();
                }
                else {
                    biasInitializer = this.biasInitializer;
                }
                this.bias = this.addWeight('bias', [this.units * 4], null, biasInitializer, this.biasRegularizer, true, this.biasConstraint);
            }
            else {
                this.bias = null;
            }
            // Porting Notes: Unlike the PyKeras implementation, we perform slicing
            // of the weights and bias in the call() method, at execution time.
            this.built = true;
        }
        call(inputs, kwargs) {
            return tidy(() => {
                const training = kwargs['training'] == null ? false : kwargs['training'];
                inputs = inputs;
                if (inputs.length !== 3) {
                    throw new ValueError(`LSTMCell expects 3 input Tensors (inputs, h, c), got ` +
                        `${inputs.length}.`);
                }
                let hTMinus1 = inputs[1]; // Previous memory state.
                const cTMinus1 = inputs[2]; // Previous carry state.
                inputs = inputs[0];
                // Lazily create (and cache) one dropout mask per gate.
                if (0 < this.dropout && this.dropout < 1 && this.dropoutMask == null) {
                    this.dropoutMask = generateDropoutMask({
                        ones: () => onesLike$3(inputs),
                        rate: this.dropout,
                        training,
                        count: 4,
                        dropoutFunc: this.dropoutFunc
                    });
                }
                if (0 < this.recurrentDropout && this.recurrentDropout < 1 &&
                    this.recurrentDropoutMask == null) {
                    this.recurrentDropoutMask = generateDropoutMask({
                        ones: () => onesLike$3(hTMinus1),
                        rate: this.recurrentDropout,
                        training,
                        count: 4,
                        dropoutFunc: this.dropoutFunc
                    });
                }
                const dpMask = this.dropoutMask;
                const recDpMask = this.recurrentDropoutMask;
                // Note: For superior performance, TensorFlow.js always uses
                // implementation 2 regardless of the actual value of
                // config.implementation.
                let i;
                let f;
                let c;
                let o;
                if (0 < this.dropout && this.dropout < 1) {
                    inputs = mul(inputs, dpMask[0]);
                }
                // Fused pre-activation for all four gates.
                let z = dot$1(inputs, this.kernel.read());
                if (0 < this.recurrentDropout && this.recurrentDropout < 1) {
                    hTMinus1 = mul(hTMinus1, recDpMask[0]);
                }
                z = add$3(z, dot$1(hTMinus1, this.recurrentKernel.read()));
                if (this.useBias) {
                    z = biasAdd(z, this.bias.read());
                }
                // Slice into input, forget, candidate and output gate blocks.
                const [z0, z1, z2, z3] = split$3(z, 4, z.rank - 1);
                i = this.recurrentActivation.apply(z0);
                f = this.recurrentActivation.apply(z1);
                // c = f ∘ cPrev + i ∘ activation(z2)
                c = add$3(mul(f, cTMinus1), mul(i, this.activation.apply(z2)));
                o = this.recurrentActivation.apply(z3);
                // h = o ∘ activation(c)
                const h = mul(o, this.activation.apply(c));
                // TODO(cais): Add use_learning_phase flag properly.
                return [h, h, c];
            });
        }
        getConfig() {
            const baseConfig = super.getConfig();
            const config = {
                units: this.units,
                activation: serializeActivation(this.activation),
                recurrentActivation: serializeActivation(this.recurrentActivation),
                useBias: this.useBias,
                kernelInitializer: serializeInitializer(this.kernelInitializer),
                recurrentInitializer: serializeInitializer(this.recurrentInitializer),
                biasInitializer: serializeInitializer(this.biasInitializer),
                unitForgetBias: this.unitForgetBias,
                kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
                recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer),
                biasRegularizer: serializeRegularizer(this.biasRegularizer),
                activityRegularizer: serializeRegularizer(this.activityRegularizer),
                kernelConstraint: serializeConstraint(this.kernelConstraint),
                recurrentConstraint: serializeConstraint(this.recurrentConstraint),
                biasConstraint: serializeConstraint(this.biasConstraint),
                dropout: this.dropout,
                recurrentDropout: this.recurrentDropout,
                implementation: this.implementation,
            };
            return Object.assign(Object.assign({}, baseConfig), config);
        }
    }
    /** @nocollapse */
    LSTMCell.className = 'LSTMCell';
    registerClass(LSTMCell);
49596 class LSTM extends RNN {
49597 constructor(args) {
49598 if (args.implementation === 0) {
49599 console.warn('`implementation=0` has been deprecated, and now defaults to ' +
49600 '`implementation=1`. Please update your layer call.');
49601 }
49602 args.cell = new LSTMCell(args);
49603 super(args);
49604 // TODO(cais): Add activityRegularizer.
49605 }
49606 call(inputs, kwargs) {
49607 return tidy(() => {
49608 if (this.cell.dropoutMask != null) {
49609 dispose(this.cell.dropoutMask);
49610 this.cell.dropoutMask = null;
49611 }
49612 if (this.cell.recurrentDropoutMask != null) {
49613 dispose(this.cell.recurrentDropoutMask);
49614 this.cell.recurrentDropoutMask = null;
49615 }
49616 const mask = kwargs == null ? null : kwargs['mask'];
49617 const training = kwargs == null ? null : kwargs['training'];
49618 const initialState = kwargs == null ? null : kwargs['initialState'];
49619 return super.call(inputs, { mask, training, initialState });
49620 });
49621 }
49622 /** @nocollapse */
49623 static fromConfig(cls, config) {
49624 if (config['implmentation'] === 0) {
49625 config['implementation'] = 1;
49626 }
49627 return new cls(config);
49628 }
49629 }
49630 /** @nocollapse */
49631 LSTM.className = 'LSTM';
49632 registerClass(LSTM);
49633 class StackedRNNCells extends RNNCell {
49634 constructor(args) {
49635 super(args);
49636 this.cells = args.cells;
49637 }
49638 get stateSize() {
49639 // States are a flat list in reverse order of the cell stack.
49640 // This allows preserving the requirement `stack.statesize[0] ===
49641 // outputDim`. E.g., states of a 2-layer LSTM would be `[h2, c2, h1, c1]`,
49642 // assuming one LSTM has states `[h, c]`.
49643 const stateSize = [];
49644 for (const cell of this.cells.slice().reverse()) {
49645 if (Array.isArray(cell.stateSize)) {
49646 stateSize.push(...cell.stateSize);
49647 }
49648 else {
49649 stateSize.push(cell.stateSize);
49650 }
49651 }
49652 return stateSize;
49653 }
49654 call(inputs, kwargs) {
49655 return tidy(() => {
49656 inputs = inputs;
49657 let states = inputs.slice(1);
49658 // Recover per-cell states.
49659 const nestedStates = [];
49660 for (const cell of this.cells.slice().reverse()) {
49661 if (Array.isArray(cell.stateSize)) {
49662 nestedStates.push(states.splice(0, cell.stateSize.length));
49663 }
49664 else {
49665 nestedStates.push(states.splice(0, 1));
49666 }
49667 }
49668 nestedStates.reverse();
49669 // Call the cells in order and store the returned states.
49670 const newNestedStates = [];
49671 let callInputs;
49672 for (let i = 0; i < this.cells.length; ++i) {
49673 const cell = this.cells[i];
49674 states = nestedStates[i];
49675 // TODO(cais): Take care of constants.
49676 if (i === 0) {
49677 callInputs = [inputs[0]].concat(states);
49678 }
49679 else {
49680 callInputs = [callInputs[0]].concat(states);
49681 }
49682 callInputs = cell.call(callInputs, kwargs);
49683 newNestedStates.push(callInputs.slice(1));
49684 }
49685 // Format the new states as a flat list in reverse cell order.
49686 states = [];
49687 for (const cellStates of newNestedStates.slice().reverse()) {
49688 states.push(...cellStates);
49689 }
49690 return [callInputs[0]].concat(states);
49691 });
49692 }
49693 build(inputShape) {
49694 if (isArrayOfShapes(inputShape)) {
49695 // TODO(cais): Take care of input constants.
49696 // const constantShape = inputShape.slice(1);
49697 inputShape = inputShape[0];
49698 }
49699 inputShape = inputShape;
49700 let outputDim;
49701 this.cells.forEach((cell, i) => {
49702 nameScope(`RNNCell_${i}`, () => {
49703 // TODO(cais): Take care of input constants.
49704 cell.build(inputShape);
49705 if (Array.isArray(cell.stateSize)) {
49706 outputDim = cell.stateSize[0];
49707 }
49708 else {
49709 outputDim = cell.stateSize;
49710 }
49711 inputShape = [inputShape[0], outputDim];
49712 });
49713 });
49714 this.built = true;
49715 }
49716 getConfig() {
49717 const baseConfig = super.getConfig();
49718 const getCellConfig = (cell) => {
49719 return {
49720 'className': cell.getClassName(),
49721 'config': cell.getConfig(),
49722 };
49723 };
49724 const cellConfigs = this.cells.map(getCellConfig);
49725 const config = { 'cells': cellConfigs };
49726 return Object.assign(Object.assign({}, baseConfig), config);
49727 }
49728 /** @nocollapse */
49729 static fromConfig(cls, config, customObjects = {}) {
49730 const cells = [];
49731 for (const cellConfig of config['cells']) {
49732 cells.push(deserialize(cellConfig, customObjects));
49733 }
49734 return new cls({ cells });
49735 }
49736 get trainableWeights() {
49737 if (!this.trainable) {
49738 return [];
49739 }
49740 const weights = [];
49741 for (const cell of this.cells) {
49742 weights.push(...cell.trainableWeights);
49743 }
49744 return weights;
49745 }
49746 get nonTrainableWeights() {
49747 const weights = [];
49748 for (const cell of this.cells) {
49749 weights.push(...cell.nonTrainableWeights);
49750 }
49751 if (!this.trainable) {
49752 const trainableWeights = [];
49753 for (const cell of this.cells) {
49754 trainableWeights.push(...cell.trainableWeights);
49755 }
49756 return trainableWeights.concat(weights);
49757 }
49758 return weights;
49759 }
49760 /**
49761 * Retrieve the weights of a the model.
49762 *
49763 * @returns A flat `Array` of `tf.Tensor`s.
49764 */
49765 getWeights() {
49766 const weights = [];
49767 for (const cell of this.cells) {
49768 weights.push(...cell.weights);
49769 }
49770 return batchGetValue(weights);
49771 }
49772 /**
49773 * Set the weights of the model.
49774 *
49775 * @param weights An `Array` of `tf.Tensor`s with shapes and types matching
49776 * the output of `getWeights()`.
49777 */
49778 setWeights(weights) {
49779 const tuples = [];
49780 for (const cell of this.cells) {
49781 const numParams = cell.weights.length;
49782 const inputWeights = weights.splice(numParams);
49783 for (let i = 0; i < cell.weights.length; ++i) {
49784 tuples.push([cell.weights[i], inputWeights[i]]);
49785 }
49786 }
49787 batchSetValue(tuples);
49788 }
49789 }
49790 /** @nocollapse */
49791 StackedRNNCells.className = 'StackedRNNCells';
49792 registerClass(StackedRNNCells);
49793 function generateDropoutMask(args) {
49794 const { ones, rate, training = false, count = 1, dropoutFunc } = args;
49795 const droppedInputs = () => dropoutFunc != null ? dropoutFunc(ones(), rate) : dropout$1(ones(), rate);
49796 const createMask = () => inTrainPhase(droppedInputs, ones, training);
49797 // just in case count is provided with null or undefined
49798 if (!count || count <= 1) {
49799 return keep(createMask().clone());
49800 }
49801 const masks = Array(count).fill(undefined).map(createMask);
49802 return masks.map(m => keep(m.clone()));
49803 }
49804
49805 /**
49806 * @license
49807 * Copyright 2020 Google LLC
49808 *
49809 * Use of this source code is governed by an MIT-style
49810 * license that can be found in the LICENSE file or at
49811 * https://opensource.org/licenses/MIT.
49812 * =============================================================================
49813 */
    // TypeScript `__rest` helper (emitted by tsc): returns a shallow copy of `s`
    // without the own properties whose keys are listed in `e`. Used below by
    // `ConvLSTM2DCell.getConfig()` to strip `units` from the base config.
    var __rest = (undefined && undefined.__rest) || function (s, e) {
        var t = {};
        // Copy own enumerable string-keyed properties not excluded by `e`.
        for (var p in s) if (Object.prototype.hasOwnProperty.call(s, p) && e.indexOf(p) < 0)
            t[p] = s[p];
        // Also copy own enumerable symbol-keyed properties (note: `p` is
        // deliberately reused here as the symbol array — generated code).
        if (s != null && typeof Object.getOwnPropertySymbols === "function")
            for (var i = 0, p = Object.getOwnPropertySymbols(s); i < p.length; i++) {
                if (e.indexOf(p[i]) < 0 && Object.prototype.propertyIsEnumerable.call(s, p[i]))
                    t[p[i]] = s[p[i]];
            }
        return t;
    };
    /**
     * Abstract base class for convolutional-recurrent cells
     * (e.g. `ConvLSTM2DCell`). Serves purely as a marker type; all behavior
     * is inherited from `RNNCell`.
     */
    class ConvRNN2DCell extends RNNCell {
    }
    /**
     * Base class for convolutional-recurrent layers.
     *
     * Drives a single `ConvRNN2DCell` over 5D input
     * `[batch, time, ...spatial, channels]`. Unrolling and stacked cells are
     * explicitly unsupported.
     */
    class ConvRNN2D extends RNN {
        constructor(args) {
            if (args.unroll) {
                throw new NotImplementedError('Unrolling is not possible with convolutional RNNs.');
            }
            if (Array.isArray(args.cell)) {
                throw new NotImplementedError('It is not possible at the moment to stack convolutional cells.');
            }
            super(args);
            // Convolutional RNNs take rank-5 input.
            this.inputSpec = [new InputSpec({ ndim: 5 })];
        }
        call(inputs, kwargs) {
            return tidy(() => {
                // Discard cached dropout masks so each forward pass resamples them.
                if (this.cell.dropoutMask != null) {
                    dispose(this.cell.dropoutMask);
                    this.cell.dropoutMask = null;
                }
                if (this.cell.recurrentDropoutMask != null) {
                    dispose(this.cell.recurrentDropoutMask);
                    this.cell.recurrentDropoutMask = null;
                }
                if (kwargs && kwargs['constants']) {
                    throw new ValueError('ConvRNN2D cell does not support constants');
                }
                const mask = kwargs == null ? null : kwargs['mask'];
                const training = kwargs == null ? null : kwargs['training'];
                const initialState = kwargs == null ? null : kwargs['initialState'];
                // Delegate the actual recurrence to the generic RNN driver.
                return super.call(inputs, { mask, training, initialState });
            });
        }
        /**
         * Output shape; may be a nested array when `returnState` is true
         * (output shape followed by the two state shapes).
         */
        computeOutputShape(inputShape) {
            // Per-step output shape with the time axis retained.
            let outShape = this.computeSingleOutputShape(inputShape);
            if (!this.returnSequences) {
                // Only the last step is returned: drop the time axis (index 1).
                outShape = [outShape[0], ...outShape.slice(2)];
            }
            if (this.returnState) {
                // Append two state shapes (h and c) built from the batch dim and
                // the last three (spatial/channel) dims of the output shape.
                outShape =
                    [outShape, ...Array(2).fill([inputShape[0], ...outShape.slice(-3)])];
            }
            return outShape;
        }
        /** Builds zero-filled initial state(s) shaped like the per-step output. */
        getInitialState(inputs) {
            return tidy(() => {
                const { stateSize } = this.cell;
                const inputShape = inputs.shape;
                const outputShape = this.computeSingleOutputShape(inputShape);
                // State shape = per-step output shape without the time axis.
                const stateShape = [outputShape[0], ...outputShape.slice(2)];
                const initialState = zeros$2(stateShape);
                if (Array.isArray(stateSize)) {
                    // NOTE: all entries share the same zero tensor instance.
                    return Array(stateSize.length).fill(initialState);
                }
                return [initialState];
            });
        }
        /**
         * Resets (or sets) the states of a stateful layer.
         *
         * - No stored states yet: initialize zeros.
         * - `states` omitted: dispose old states and re-zero them.
         * - `states` given: validate shapes and install the provided tensors.
         */
        resetStates(states, training = false) {
            tidy(() => {
                if (!this.stateful) {
                    throw new AttributeError('Cannot call resetStates() on an RNN Layer that is not stateful.');
                }
                const inputShape = this.inputSpec[0].shape;
                const outputShape = this.computeSingleOutputShape(inputShape);
                const stateShape = [outputShape[0], ...outputShape.slice(2)];
                const batchSize = inputShape[0];
                // Stateful RNNs require a fixed, known batch size.
                if (batchSize == null) {
                    throw new ValueError('If an RNN is stateful, it needs to know its batch size. Specify ' +
                        'the batch size of your input tensors: \n' +
                        '- If using a Sequential model, specify the batch size by ' +
                        'passing a `batchInputShape` option to your first layer.\n' +
                        '- If using the functional API, specify the batch size by ' +
                        'passing a `batchShape` option to your Input layer.');
                }
                // Initialize state if null.
                if (this.getStates() == null) {
                    if (Array.isArray(this.cell.stateSize)) {
                        this.states_ = this.cell.stateSize.map(() => zeros$2(stateShape));
                    }
                    else {
                        this.states_ = [zeros$2(stateShape)];
                    }
                }
                else if (states == null) {
                    // Dispose old state tensors.
                    dispose(this.states_);
                    // For stateful RNNs, fully dispose kept old states.
                    if (this.keptStates != null) {
                        dispose(this.keptStates);
                        this.keptStates = [];
                    }
                    if (Array.isArray(this.cell.stateSize)) {
                        this.states_ = this.cell.stateSize.map(() => zeros$2(stateShape));
                    }
                    else {
                        this.states_[0] = zeros$2(stateShape);
                    }
                }
                else {
                    if (!Array.isArray(states)) {
                        states = [states];
                    }
                    if (states.length !== this.states_.length) {
                        throw new ValueError(`Layer ${this.name} expects ${this.states_.length} state(s), ` +
                            `but it received ${states.length} state value(s). Input ` +
                            `received: ${states}`);
                    }
                    if (training) {
                        // Store old state tensors for complete disposal later, i.e., during
                        // the next no-arg call to this method. We do not dispose the old
                        // states immediately because that BPTT (among other things) require
                        // them.
                        this.keptStates.push(this.states_.slice());
                    }
                    else {
                        dispose(this.states_);
                    }
                    for (let index = 0; index < this.states_.length; ++index) {
                        const value = states[index];
                        const expectedShape = stateShape;
                        if (!arraysEqual(value.shape, expectedShape)) {
                            throw new ValueError(`State ${index} is incompatible with layer ${this.name}: ` +
                                `expected shape=${expectedShape}, received shape=${value.shape}`);
                        }
                        this.states_[index] = value;
                    }
                }
                // Clone + `keep` so the states survive outside this `tidy` scope.
                this.states_ = this.states_.map(state => keep(state.clone()));
            });
        }
        /**
         * Output shape for a single time step: applies the cell's convolution
         * geometry (kernel, stride, dilation, padding) to the spatial dims.
         */
        computeSingleOutputShape(inputShape) {
            const { dataFormat, filters, kernelSize, padding, strides, dilationRate } = this.cell;
            const isChannelsFirst = dataFormat === 'channelsFirst';
            // Spatial dims are at [3, 4] for channelsFirst, [2, 3] otherwise
            // (axes 0 and 1 being batch and time).
            const h = inputShape[isChannelsFirst ? 3 : 2];
            const w = inputShape[isChannelsFirst ? 4 : 3];
            const hOut = convOutputLength(h, kernelSize[0], padding, strides[0], dilationRate[0]);
            const wOut = convOutputLength(w, kernelSize[1], padding, strides[1], dilationRate[1]);
            const outShape = [
                ...inputShape.slice(0, 2),
                ...(isChannelsFirst ? [filters, hOut, wOut] : [hOut, wOut, filters])
            ];
            return outShape;
        }
    }
    /** @nocollapse */
    ConvRNN2D.className = 'ConvRNN2D';
    /**
     * Cell of the `ConvLSTM2D` layer: an LSTM cell whose input and recurrent
     * transforms are 2D convolutions. Reuses `LSTMCell` configuration by
     * mapping `filters` onto `units`.
     */
    class ConvLSTM2DCell extends LSTMCell {
        constructor(args) {
            const { filters, kernelSize, strides, padding, dataFormat, dilationRate, } = args;
            // `filters` plays the role of LSTMCell's `units`.
            super(Object.assign(Object.assign({}, args), { units: filters }));
            this.filters = filters;
            assertPositiveInteger(this.filters, 'filters');
            // Normalize scalar-or-pair convolution params to length-2 arrays.
            this.kernelSize = normalizeArray(kernelSize, 2, 'kernelSize');
            this.kernelSize.forEach(size => assertPositiveInteger(size, 'kernelSize'));
            this.strides = normalizeArray(strides || 1, 2, 'strides');
            this.strides.forEach(stride => assertPositiveInteger(stride, 'strides'));
            this.padding = padding || 'valid';
            checkPaddingMode(this.padding);
            this.dataFormat = dataFormat || 'channelsLast';
            checkDataFormat(this.dataFormat);
            this.dilationRate = normalizeArray(dilationRate || 1, 2, 'dilationRate');
            this.dilationRate.forEach(rate => assertPositiveInteger(rate, 'dilationRate'));
        }
        /**
         * Creates the conv kernels and bias. The four LSTM gates (i, f, c, o)
         * are packed into single tensors, stacked along the last axis.
         */
        build(inputShape) {
            var _a;
            inputShape = getExactlyOneShape(inputShape);
            const channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1;
            if (inputShape[channelAxis] == null) {
                throw new ValueError(`The channel dimension of the input should be defined. ` +
                    `Found ${inputShape[channelAxis]}`);
            }
            const inputDim = inputShape[channelAxis];
            // One kernel slice per LSTM gate.
            const numOfKernels = 4;
            const kernelShape = this.kernelSize.concat([inputDim, this.filters * numOfKernels]);
            this.kernel = this.addWeight('kernel', kernelShape, null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
            const recurrentKernelShape = this.kernelSize.concat([this.filters, this.filters * numOfKernels]);
            this.recurrentKernel = this.addWeight('recurrent_kernel', recurrentKernelShape, null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
            if (this.useBias) {
                let biasInitializer;
                if (this.unitForgetBias) {
                    // Jozefowicz et al.: initialize the forget-gate bias slice to
                    // ones, the other gate slices via the configured initializer.
                    const init = this.biasInitializer;
                    const filters = this.filters;
                    biasInitializer = new (_a = class CustomInit extends Initializer {
                        apply(shape, dtype) {
                            const biasI = init.apply([filters]);
                            const biasF = ones$1([filters]);
                            const biasCAndO = init.apply([filters * 2]);
                            return concatenate$2([biasI, biasF, biasCAndO]);
                        }
                    },
                        /** @nocollapse */
                        _a.className = 'CustomInit',
                        _a)();
                }
                else {
                    biasInitializer = this.biasInitializer;
                }
                this.bias = this.addWeight('bias', [this.filters * numOfKernels], null, biasInitializer, this.biasRegularizer, true, this.biasConstraint);
            }
            this.built = true;
        }
        /**
         * One recurrence step. `inputs` is `[x, hPrev, cPrev]`; returns
         * `[h, h, c]` (output plus the two new states).
         */
        call(inputs, kwargs) {
            return tidy(() => {
                if (inputs.length !== 3) {
                    throw new ValueError(`ConvLSTM2DCell expects 3 input Tensors (inputs, h, c), got ` +
                        `${inputs.length}.`);
                }
                const training = kwargs['training'] || false;
                const x = inputs[0]; // Current input
                const hTMinus1 = inputs[1]; // Previous memory state.
                const cTMinus1 = inputs[2]; // Previous carry state.
                const numOfKernels = 4;
                // Lazily sample one dropout mask per gate; masks are cached on the
                // cell and cleared by the driver (`ConvRNN2D.call`) between passes.
                if (0 < this.dropout && this.dropout < 1 && this.dropoutMask == null) {
                    this.dropoutMask = generateDropoutMask({
                        ones: () => onesLike$3(x),
                        rate: this.dropout,
                        training,
                        count: numOfKernels,
                        dropoutFunc: this.dropoutFunc
                    });
                }
                const dropoutMask = this.dropoutMask;
                // Apply the i-th gate's mask, or pass through when unmasked.
                const applyDropout = (x, mask, index) => {
                    if (!mask || !mask[index]) {
                        return x;
                    }
                    return mul(mask[index], x);
                };
                let xI = applyDropout(x, dropoutMask, 0);
                let xF = applyDropout(x, dropoutMask, 1);
                let xC = applyDropout(x, dropoutMask, 2);
                let xO = applyDropout(x, dropoutMask, 3);
                if (0 < this.recurrentDropout && this.recurrentDropout < 1 &&
                    this.recurrentDropoutMask == null) {
                    this.recurrentDropoutMask = generateDropoutMask({
                        ones: () => onesLike$3(hTMinus1),
                        rate: this.recurrentDropout,
                        training,
                        count: numOfKernels,
                        dropoutFunc: this.dropoutFunc
                    });
                }
                const recDropoutMask = this.recurrentDropoutMask;
                let hI = applyDropout(hTMinus1, recDropoutMask, 0);
                let hF = applyDropout(hTMinus1, recDropoutMask, 1);
                let hC = applyDropout(hTMinus1, recDropoutMask, 2);
                let hO = applyDropout(hTMinus1, recDropoutMask, 3);
                // Unpack the per-gate kernel slices (stacked along the last axis).
                const kernelChannelAxis = 3;
                const [kernelI, kernelF, kernelC, kernelO] = split$3(this.kernel.read(), numOfKernels, kernelChannelAxis);
                const [biasI, biasF, biasC, biasO] = this.useBias ?
                    split$3(this.bias.read(), numOfKernels) :
                    [null, null, null, null];
                xI = this.inputConv(xI, kernelI, biasI, this.padding);
                xF = this.inputConv(xF, kernelF, biasF, this.padding);
                xC = this.inputConv(xC, kernelC, biasC, this.padding);
                xO = this.inputConv(xO, kernelO, biasO, this.padding);
                const [recKernelI, recKernelF, recKernelC, recKernelO] = split$3(this.recurrentKernel.read(), numOfKernels, kernelChannelAxis);
                hI = this.recurrentConv(hI, recKernelI);
                hF = this.recurrentConv(hF, recKernelF);
                hC = this.recurrentConv(hC, recKernelC);
                hO = this.recurrentConv(hO, recKernelO);
                // Standard LSTM gate equations with conv pre-activations.
                const i = this.recurrentActivation.apply(add$3(xI, hI));
                const f = this.recurrentActivation.apply(add$3(xF, hF));
                const c = add$3(mul(f, cTMinus1), mul(i, this.activation.apply(add$3(xC, hC))));
                const h = mul(this.recurrentActivation.apply(add$3(xO, hO)), this.activation.apply(c));
                return [h, h, c];
            });
        }
        getConfig() {
            // Drop LSTMCell's `units` (derived from `filters`) from the config.
            const _a = super.getConfig(), { 'units': _ } = _a, baseConfig = __rest(_a, ['units']);
            const config = {
                filters: this.filters,
                kernelSize: this.kernelSize,
                padding: this.padding,
                dataFormat: this.dataFormat,
                dilationRate: this.dilationRate,
                strides: this.strides,
            };
            return Object.assign(Object.assign({}, baseConfig), config);
        }
        /** Convolution over the input, with optional bias add. */
        inputConv(x, w, b, padding) {
            const out = conv2d$4(x, w, this.strides, (padding || 'valid'), this.dataFormat === 'channelsFirst' ? 'NCHW' : 'NHWC', this.dilationRate);
            if (b) {
                return biasAdd(out, b, this.dataFormat);
            }
            return out;
        }
        /** Convolution over the recurrent state: stride 1, 'same' padding. */
        recurrentConv(x, w) {
            const strides = 1;
            return conv2d$4(x, w, strides, 'same', this.dataFormat === 'channelsFirst' ? 'NCHW' : 'NHWC');
        }
    }
    /** @nocollapse */
    ConvLSTM2DCell.className = 'ConvLSTM2DCell';
    registerClass(ConvLSTM2DCell);
50122 class ConvLSTM2D extends ConvRNN2D {
50123 constructor(args) {
50124 const cell = new ConvLSTM2DCell(args);
50125 super(Object.assign(Object.assign({}, args), { cell }));
50126 }
50127 /** @nocollapse */
50128 static fromConfig(cls, config) {
50129 return new cls(config);
50130 }
50131 }
50132 /** @nocollapse */
50133 ConvLSTM2D.className = 'ConvLSTM2D';
50134 registerClass(ConvLSTM2D);
50135
50136 /**
50137 * @license
50138 * Copyright 2018 Google LLC
50139 *
50140 * Use of this source code is governed by an MIT-style
50141 * license that can be found in the LICENSE file or at
50142 * https://opensource.org/licenses/MIT.
50143 * =============================================================================
50144 */
50145 class Dropout extends Layer {
50146 constructor(args) {
50147 super(args);
50148 this.rate = Math.max(Math.min(args.rate, 1), 0);
50149 // So that the scalar doesn't get tidied up between executions.
50150 this.noiseShape = args.noiseShape;
50151 this.seed = args.seed;
50152 this.supportsMasking = true;
50153 }
50154 getNoiseShape(input) {
50155 if (this.noiseShape == null) {
50156 return this.noiseShape;
50157 }
50158 const inputShape = input.shape;
50159 const noiseShape = [];
50160 for (let i = 0; i < this.noiseShape.length; ++i) {
50161 noiseShape.push(this.noiseShape[i] == null ? inputShape[i] : this.noiseShape[i]);
50162 }
50163 return noiseShape;
50164 }
50165 call(inputs, kwargs) {
50166 return tidy(() => {
50167 this.invokeCallHook(inputs, kwargs);
50168 const input = getExactlyOneTensor(inputs);
50169 if (0 < this.rate && this.rate < 1) {
50170 const training = kwargs['training'] == null ? false : kwargs['training'];
50171 const noiseShape = this.getNoiseShape(input);
50172 const output = inTrainPhase(() => dropout$1(input, this.rate, noiseShape, this.seed), () => input, training);
50173 return output;
50174 }
50175 return inputs;
50176 });
50177 }
50178 getConfig() {
50179 const config = {
50180 rate: this.rate,
50181 noiseShape: this.noiseShape,
50182 seed: this.seed,
50183 };
50184 const baseConfig = super.getConfig();
50185 Object.assign(config, baseConfig);
50186 return config;
50187 }
50188 dispose() {
50189 return super.dispose();
50190 }
50191 }
50192 /** @nocollapse */
50193 Dropout.className = 'Dropout';
50194 registerClass(Dropout);
50195 class SpatialDropout1D extends Dropout {
50196 constructor(args) {
50197 super(args);
50198 this.inputSpec = [{ ndim: 3 }];
50199 }
50200 getNoiseShape(input) {
50201 const inputShape = input.shape;
50202 return [inputShape[0], 1, inputShape[2]];
50203 }
50204 }
50205 /** @nocollapse */
50206 SpatialDropout1D.className = 'SpatialDropout1D';
50207 registerClass(SpatialDropout1D);
50208 class Dense extends Layer {
50209 constructor(args) {
50210 super(args);
50211 // Default activation: Linear (none).
50212 this.activation = null;
50213 this.useBias = true;
50214 this.kernel = null;
50215 this.bias = null;
50216 this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
50217 this.DEFAULT_BIAS_INITIALIZER = 'zeros';
50218 if (args.batchInputShape == null && args.inputShape == null &&
50219 args.inputDim != null) {
50220 // This logic is copied from Layer's constructor, since we can't
50221 // do exactly what the Python constructor does for Dense().
50222 let batchSize = null;
50223 if (args.batchSize != null) {
50224 batchSize = args.batchSize;
50225 }
50226 this.batchInputShape = [batchSize, args.inputDim];
50227 }
50228 this.units = args.units;
50229 assertPositiveInteger(this.units, 'units');
50230 this.activation = getActivation(args.activation);
50231 if (args.useBias != null) {
50232 this.useBias = args.useBias;
50233 }
50234 this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER);
50235 this.biasInitializer =
50236 getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER);
50237 this.kernelConstraint = getConstraint(args.kernelConstraint);
50238 this.biasConstraint = getConstraint(args.biasConstraint);
50239 this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
50240 this.biasRegularizer = getRegularizer(args.biasRegularizer);
50241 this.activityRegularizer = getRegularizer(args.activityRegularizer);
50242 this.supportsMasking = true;
50243 this.inputSpec = [{ minNDim: 2 }];
50244 }
50245 build(inputShape) {
50246 inputShape = getExactlyOneShape(inputShape);
50247 const inputLastDim = inputShape[inputShape.length - 1];
50248 if (this.kernel == null) {
50249 this.kernel = this.addWeight('kernel', [inputLastDim, this.units], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
50250 if (this.useBias) {
50251 this.bias = this.addWeight('bias', [this.units], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
50252 }
50253 }
50254 this.inputSpec = [{ minNDim: 2, axes: { [-1]: inputLastDim } }];
50255 this.built = true;
50256 }
50257 computeOutputShape(inputShape) {
50258 inputShape = getExactlyOneShape(inputShape);
50259 const outputShape = inputShape.slice();
50260 outputShape[outputShape.length - 1] = this.units;
50261 return outputShape;
50262 }
50263 call(inputs, kwargs) {
50264 return tidy(() => {
50265 this.invokeCallHook(inputs, kwargs);
50266 // Dense layer accepts only a single input.
50267 const input = getExactlyOneTensor(inputs);
50268 const fusedActivationName = mapActivationToFusedKernel(this.activation.getClassName());
50269 let output;
50270 if (fusedActivationName != null) {
50271 output = dot$1(input, this.kernel.read(), fusedActivationName, this.bias ? this.bias.read() : null);
50272 }
50273 else {
50274 output = dot$1(input, this.kernel.read());
50275 if (this.bias != null) {
50276 output = biasAdd(output, this.bias.read());
50277 }
50278 if (this.activation != null) {
50279 output = this.activation.apply(output);
50280 }
50281 }
50282 return output;
50283 });
50284 }
50285 getConfig() {
50286 const config = {
50287 units: this.units,
50288 activation: serializeActivation(this.activation),
50289 useBias: this.useBias,
50290 kernelInitializer: serializeInitializer(this.kernelInitializer),
50291 biasInitializer: serializeInitializer(this.biasInitializer),
50292 kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
50293 biasRegularizer: serializeRegularizer(this.biasRegularizer),
50294 activityRegularizer: serializeRegularizer(this.activityRegularizer),
50295 kernelConstraint: serializeConstraint(this.kernelConstraint),
50296 biasConstraint: serializeConstraint(this.biasConstraint)
50297 };
50298 const baseConfig = super.getConfig();
50299 Object.assign(config, baseConfig);
50300 return config;
50301 }
50302 }
50303 /** @nocollapse */
50304 Dense.className = 'Dense';
50305 registerClass(Dense);
50306 class Flatten extends Layer {
50307 constructor(args) {
50308 args = args || {};
50309 super(args);
50310 this.inputSpec = [{ minNDim: 3 }];
50311 this.dataFormat = args.dataFormat;
50312 }
50313 computeOutputShape(inputShape) {
50314 inputShape = getExactlyOneShape(inputShape);
50315 for (const dim of inputShape.slice(1)) {
50316 if (dim == null) {
50317 throw new ValueError(`The shape of the input to "Flatten" is not fully defined ` +
50318 `(got ${inputShape.slice(1)}). Make sure to pass a complete ` +
50319 `"input_shape" or "batch_input_shape" argument to the first ` +
50320 `layer in your model.`);
50321 }
50322 }
50323 return [inputShape[0], arrayProd(inputShape, 1)];
50324 }
50325 call(inputs, kwargs) {
50326 return tidy(() => {
50327 this.invokeCallHook(inputs, kwargs);
50328 let input = getExactlyOneTensor(inputs);
50329 if (this.dataFormat === 'channelsFirst' && input.rank > 1) {
50330 const permutation = [0];
50331 for (let i = 2; i < input.rank; ++i) {
50332 permutation.push(i);
50333 }
50334 permutation.push(1);
50335 input = transpose$2(input, permutation);
50336 }
50337 return batchFlatten(input);
50338 });
50339 }
50340 getConfig() {
50341 const config = {};
50342 if (this.dataFormat != null) {
50343 config['dataFormat'] = this.dataFormat;
50344 }
50345 const baseConfig = super.getConfig();
50346 Object.assign(config, baseConfig);
50347 return config;
50348 }
50349 }
50350 /** @nocollapse */
50351 Flatten.className = 'Flatten';
50352 registerClass(Flatten);
50353 class Activation extends Layer {
50354 constructor(args) {
50355 super(args);
50356 this.supportsMasking = true;
50357 this.activation = getActivation(args.activation);
50358 }
50359 call(inputs, kwargs) {
50360 return tidy(() => {
50361 this.invokeCallHook(inputs, kwargs);
50362 const input = getExactlyOneTensor(inputs);
50363 return this.activation.apply(input);
50364 });
50365 }
50366 getConfig() {
50367 const config = { activation: serializeActivation(this.activation) };
50368 const baseConfig = super.getConfig();
50369 Object.assign(config, baseConfig);
50370 return config;
50371 }
50372 }
50373 /** @nocollapse */
50374 Activation.className = 'Activation';
50375 registerClass(Activation);
50376 class RepeatVector extends Layer {
50377 constructor(args) {
50378 super(args);
50379 this.n = args.n;
50380 this.inputSpec = [{ ndim: 2 }];
50381 }
50382 computeOutputShape(inputShape) {
50383 return [inputShape[0], this.n, inputShape[1]];
50384 }
50385 call(inputs, kwargs) {
50386 return tidy(() => {
50387 inputs = getExactlyOneTensor(inputs);
50388 return repeat(inputs, this.n);
50389 });
50390 }
50391 getConfig() {
50392 const config = {
50393 n: this.n,
50394 };
50395 const baseConfig = super.getConfig();
50396 Object.assign(config, baseConfig);
50397 return config;
50398 }
50399 }
50400 /** @nocollapse */
50401 RepeatVector.className = 'RepeatVector';
50402 registerClass(RepeatVector);
50403 class Reshape extends Layer {
50404 constructor(args) {
50405 super(args);
50406 this.targetShape = args.targetShape;
50407 // Make sure that all unknown dimensions are represented as `null`.
50408 for (let i = 0; i < this.targetShape.length; ++i) {
50409 if (this.isUnknown(this.targetShape[i])) {
50410 this.targetShape[i] = null;
50411 }
50412 }
50413 }
50414 isUnknown(dim) {
50415 return dim < 0 || dim == null;
50416 }
50417 /**
50418 * Finds and replaces a missing dimension in output shape.
50419 *
50420 * This is a near direct port of the internal Numpy function
50421 * `_fix_unknown_dimension` in `numpy/core/src/multiarray/shape.c`.
50422 *
50423 * @param inputShape: Original shape of array begin reshape.
50424 * @param outputShape: Target shape of the array, with at most a single
50425 * `null` or negative number, which indicates an underdetermined dimension
50426 * that should be derived from `inputShape` and the known dimensions of
50427 * `outputShape`.
50428 * @returns: The output shape with `null` replaced with its computed value.
50429 * @throws: ValueError: If `inputShape` and `outputShape` do not match.
50430 */
50431 fixUnknownDimension(inputShape, outputShape) {
50432 const errorMsg = 'Total size of new array must be unchanged.';
50433 const finalShape = outputShape.slice();
50434 let known = 1;
50435 let unknown = null;
50436 for (let i = 0; i < finalShape.length; ++i) {
50437 const dim = finalShape[i];
50438 if (this.isUnknown(dim)) {
50439 if (unknown === null) {
50440 unknown = i;
50441 }
50442 else {
50443 throw new ValueError('Can only specifiy one unknown dimension.');
50444 }
50445 }
50446 else {
50447 known *= dim;
50448 }
50449 }
50450 const originalSize = arrayProd(inputShape);
50451 if (unknown !== null) {
50452 if (known === 0 || originalSize % known !== 0) {
50453 throw new ValueError(errorMsg);
50454 }
50455 finalShape[unknown] = originalSize / known;
50456 }
50457 else if (originalSize !== known) {
50458 throw new ValueError(errorMsg);
50459 }
50460 return finalShape;
50461 }
50462 computeOutputShape(inputShape) {
50463 let anyUnknownDims = false;
50464 for (let i = 0; i < inputShape.length; ++i) {
50465 if (this.isUnknown(inputShape[i])) {
50466 anyUnknownDims = true;
50467 break;
50468 }
50469 }
50470 if (anyUnknownDims) {
50471 return inputShape.slice(0, 1).concat(this.targetShape);
50472 }
50473 else {
50474 return inputShape.slice(0, 1).concat(this.fixUnknownDimension(inputShape.slice(1), this.targetShape));
50475 }
50476 }
50477 call(inputs, kwargs) {
50478 return tidy(() => {
50479 this.invokeCallHook(inputs, kwargs);
50480 const input = getExactlyOneTensor(inputs);
50481 const inputShape = input.shape;
50482 const outputShape = inputShape.slice(0, 1).concat(this.fixUnknownDimension(inputShape.slice(1), this.targetShape));
50483 return reshape$3(input, outputShape);
50484 });
50485 }
50486 getConfig() {
50487 const config = {
50488 targetShape: this.targetShape,
50489 };
50490 const baseConfig = super.getConfig();
50491 Object.assign(config, baseConfig);
50492 return config;
50493 }
50494 }
50495 /** @nocollapse */
50496 Reshape.className = 'Reshape';
50497 registerClass(Reshape);
50498 class Permute extends Layer {
50499 constructor(args) {
50500 super(args);
50501 if (args.dims == null) {
50502 throw new Error('Required configuration field `dims` is missing during Permute ' +
50503 'constructor call.');
50504 }
50505 if (!Array.isArray(args.dims)) {
50506 throw new Error('Permute constructor requires `dims` to be an Array, but received ' +
50507 `${args.dims} instead.`);
50508 }
50509 // Check the validity of the permutation indices.
50510 const expectedSortedIndices = range$2(1, args.dims.length + 1);
50511 if (!arraysEqual(args.dims.slice().sort(), expectedSortedIndices)) {
50512 throw new Error('Invalid permutation `dims`: ' + JSON.stringify(args.dims) +
50513 ' `dims` must contain consecutive integers starting from 1.');
50514 }
50515 this.dims = args.dims;
50516 this.dimsIncludingBatch = [0].concat(this.dims);
50517 this.inputSpec = [new InputSpec({ ndim: this.dims.length + 1 })];
50518 }
50519 computeOutputShape(inputShape) {
50520 inputShape = getExactlyOneShape(inputShape);
50521 const outputShape = inputShape.slice();
50522 this.dims.forEach((dim, i) => {
50523 outputShape[i + 1] = inputShape[dim];
50524 });
50525 return outputShape;
50526 }
50527 call(inputs, kwargs) {
50528 return transpose$2(getExactlyOneTensor(inputs), this.dimsIncludingBatch);
50529 }
50530 getConfig() {
50531 const config = {
50532 dims: this.dims,
50533 };
50534 const baseConfig = super.getConfig();
50535 Object.assign(config, baseConfig);
50536 return config;
50537 }
50538 }
50539 /** @nocollapse */
50540 Permute.className = 'Permute';
50541 registerClass(Permute);
50542 class Masking extends Layer {
50543 constructor(args) {
50544 super(args == null ? {} : args);
50545 this.supportsMasking = true;
50546 if (args != null) {
50547 this.maskValue = args.maskValue == null ? 0 : args.maskValue;
50548 }
50549 else {
50550 this.maskValue = 0;
50551 }
50552 }
50553 computeOutputShape(inputShape) {
50554 return inputShape;
50555 }
50556 getConfig() {
50557 const baseConfig = super.getConfig();
50558 const config = { maskValue: this.maskValue };
50559 Object.assign(config, baseConfig);
50560 return config;
50561 }
50562 computeMask(inputs, mask) {
50563 const input = getExactlyOneTensor(inputs);
50564 const axis = -1;
50565 return any$2(notEqual$2(input, this.maskValue), axis);
50566 }
50567 call(inputs, kwargs) {
50568 return tidy(() => {
50569 this.invokeCallHook(inputs, kwargs);
50570 const input = getExactlyOneTensor(inputs);
50571 const axis = -1;
50572 const keepDims = true;
50573 const booleanMask = any$2(notEqual$2(input, this.maskValue), axis, keepDims);
50574 const output = mul(input, cast$3(booleanMask, input.dtype));
50575 return output;
50576 });
50577 }
50578 }
50579 /** @nocollapse */
50580 Masking.className = 'Masking';
50581 registerClass(Masking);
50582
50583 /**
50584 * @license
50585 * Copyright 2018 Google LLC
50586 *
50587 * Use of this source code is governed by an MIT-style
50588 * license that can be found in the LICENSE file or at
50589 * https://opensource.org/licenses/MIT.
50590 * =============================================================================
50591 */
50592 class Embedding extends Layer {
50593 constructor(args) {
50594 super(args);
50595 this.embeddings = null;
50596 this.DEFAULT_EMBEDDINGS_INITIALIZER = 'randomUniform';
50597 if (args.batchInputShape == null && args.inputShape == null) {
50598 // Porting Note: This logic is copied from Layer's constructor, since we
50599 // can't do exactly what the Python constructor does for Embedding().
50600 // Specifically, the super constructor can not be called after the
50601 // mutation of the `config` argument.
50602 let batchSize = null;
50603 if (args.batchSize != null) {
50604 batchSize = args.batchSize;
50605 }
50606 if (args.inputLength == null) {
50607 // Fix super-constructor to what it would have done if
50608 // 'config.inputShape' were (None, )
50609 this.batchInputShape = [batchSize, null];
50610 }
50611 else {
50612 // Fix super-constructor to what it would have done if
50613 // 'config.inputShape' were (config.inputLength, )
50614 this.batchInputShape =
50615 [batchSize].concat(toList(args.inputLength));
50616 }
50617 }
50618 this.inputDim = args.inputDim;
50619 assertPositiveInteger(this.inputDim, 'inputDim');
50620 this.outputDim = args.outputDim;
50621 assertPositiveInteger(this.outputDim, 'outputDim');
50622 this.embeddingsInitializer = getInitializer(args.embeddingsInitializer || this.DEFAULT_EMBEDDINGS_INITIALIZER);
50623 this.embeddingsRegularizer = getRegularizer(args.embeddingsRegularizer);
50624 this.activityRegularizer = getRegularizer(args.activityRegularizer);
50625 this.embeddingsConstraint = getConstraint(args.embeddingsConstraint);
50626 this.maskZero = args.maskZero;
50627 this.supportsMasking = args.maskZero;
50628 this.inputLength = args.inputLength;
50629 }
50630 build(inputShape) {
50631 this.embeddings = this.addWeight('embeddings', [this.inputDim, this.outputDim], this.dtype, this.embeddingsInitializer, this.embeddingsRegularizer, true, this.embeddingsConstraint);
50632 this.built = true;
50633 }
50634 // Override warnOnIncompatibleInputShape because an embedding layer allows
50635 // the input to have varying ranks.
50636 warnOnIncompatibleInputShape(inputShape) { }
50637 computeMask(inputs, mask) {
50638 return tidy(() => {
50639 if (!this.maskZero) {
50640 return null;
50641 }
50642 else {
50643 inputs = getExactlyOneTensor(inputs);
50644 return notEqual$2(inputs, zerosLike$3(inputs));
50645 }
50646 });
50647 }
50648 computeOutputShape(inputShape) {
50649 inputShape = getExactlyOneShape(inputShape);
50650 if (this.inputLength == null) {
50651 return [...inputShape, this.outputDim];
50652 }
50653 // inputLength can be an array if input is 3D or higher.
50654 const inLens = toList(this.inputLength);
50655 if (inLens.length !== inputShape.length - 1) {
50656 throw new ValueError(`"inputLength" is ${this.inputLength}, but received ` +
50657 `input shape has shape ${inputShape}`);
50658 }
50659 else {
50660 let i = 0;
50661 for (let k = 0; k < inLens.length; ++k) {
50662 const s1 = inLens[k];
50663 const s2 = inputShape[k + 1];
50664 if ((s1 != null) && (s2 != null) && (s1 !== s2)) {
50665 throw new ValueError(`"inputLength" is ${this.inputLength}, but received ` +
50666 `input shape has shape ${inputShape}`);
50667 }
50668 else if (s1 == null) {
50669 inLens[i] = s2;
50670 }
50671 i++;
50672 }
50673 }
50674 return [inputShape[0], ...inLens, this.outputDim];
50675 }
50676 call(inputs, kwargs) {
50677 return tidy(() => {
50678 this.invokeCallHook(inputs, kwargs);
50679 // Embedding layer accepts only a single input.
50680 let input = getExactlyOneTensor(inputs);
50681 if (input.dtype !== 'int32') {
50682 input = cast$2(input, 'int32');
50683 }
50684 const output = gather(this.embeddings.read(), reshape$3(input, [input.size]));
50685 return reshape$3(output, getExactlyOneShape(this.computeOutputShape(input.shape)));
50686 });
50687 }
50688 getConfig() {
50689 const config = {
50690 inputDim: this.inputDim,
50691 outputDim: this.outputDim,
50692 embeddingsInitializer: serializeInitializer(this.embeddingsInitializer),
50693 embeddingsRegularizer: serializeRegularizer(this.embeddingsRegularizer),
50694 activityRegularizer: serializeRegularizer(this.activityRegularizer),
50695 embeddingsConstraint: serializeConstraint(this.embeddingsConstraint),
50696 maskZero: this.maskZero,
50697 inputLength: this.inputLength
50698 };
50699 const baseConfig = super.getConfig();
50700 Object.assign(config, baseConfig);
50701 return config;
50702 }
50703 }
// Register under the canonical Keras class name so that serialized models
// containing this layer can be deserialized.
/** @nocollapse */
Embedding.className = 'Embedding';
registerClass(Embedding);
50707
50708 /**
50709 * @license
50710 * Copyright 2018 Google LLC
50711 *
50712 * Use of this source code is governed by an MIT-style
50713 * license that can be found in the LICENSE file or at
50714 * https://opensource.org/licenses/MIT.
50715 * =============================================================================
50716 */
/**
 * Generic Merge layer for element-wise merge functions.
 *
 * Used to implement `Sum`, `Average`, `Concatenate`, etc.
 */
class Merge extends Layer {
    constructor(args) {
        super(args || {});
        this.supportsMasking = true;
    }
    /**
     * Logic for merging multiple tensors, to be overridden by subclasses.
     * @param inputs Tensors to merge.
     */
    mergeFunction(inputs) {
        throw new NotImplementedError();
    }
    /**
     * Computes the shape of the result of an elementwise operation.
     *
     * @param shape1: Shape of the first tensor.
     * @param shape2: Shape of the second tensor.
     * @returns Expected output shape when an elementwise operation is carried
     *   out on 2 tensors with shapes `shape1` and `shape2`.
     * @throws ValueError: If `shape1` and `shape2` are not compatible for
     *   element-wise operations.
     */
    computeElementwiseOpOutputShape(shape1, shape2) {
        if (shape1 == null || shape2 == null) {
            return null;
        }
        else if (shape1.length < shape2.length) {
            // Normalize so that shape1 is always the higher-rank shape.
            return this.computeElementwiseOpOutputShape(shape2, shape1);
        }
        else if (shape2.length === 0) {
            return shape1;
        }
        // Leading (non-broadcast) dimensions come straight from shape1.
        const outputShape = shape1.slice(0, shape1.length - shape2.length);
        for (let k = 0; k < shape2.length; ++k) {
            const i = shape1[shape1.length - shape2.length + k];
            const j = shape2[k];
            if (i == null || j == null || i < 0 || j < 0) {
                // Unknown dimension on either side -> unknown output dimension.
                outputShape.push(null);
            }
            else if (i === 1) {
                outputShape.push(j);
            }
            else if (j === 1) {
                outputShape.push(i);
            }
            else {
                if (i !== j) {
                    throw new ValueError('Operands could not be broadcast together with shapes ' +
                        JSON.stringify(shape1) + ' ' + JSON.stringify(shape2));
                }
                outputShape.push(i);
            }
        }
        return outputShape;
    }
    build(inputShape) {
        // Used purely for shape validation.
        if (Array.isArray(inputShape) && !Array.isArray(inputShape[0])) {
            // Make sure that inputShape is an Array of shape.
            inputShape = [getExactlyOneShape(inputShape)];
        }
        if (inputShape.length < 2) {
            throw new ValueError('A merge layer should be called on an Array of at least 2 inputs.' +
                ` Got ${inputShape.length} input(s).`);
        }
        // Make sure that there is at most one unique batch size among the input
        // shapes.
        let batchSizes = [];
        for (const shape of inputShape) {
            if (shape != null && shape[0] !== null) {
                batchSizes.push(shape[0]);
            }
        }
        batchSizes = unique$2(batchSizes);
        if (batchSizes.length > 1) {
            throw new ValueError(`Can not merge tensors with different batch sizes. ` +
                `Got tensors with shapes: ${JSON.stringify(inputShape)}.`);
        }
        let outputShape = inputShape[0] == null ? null : inputShape[0].slice(1);
        for (let i = 1; i < inputShape.length; ++i) {
            const shape = inputShape[i] == null ? null : inputShape[i].slice(1);
            outputShape = this.computeElementwiseOpOutputShape(outputShape, shape);
        }
        // If the inputs have different ranks, we have to reshape them to make
        // them broadcastable.
        const allRanks = inputShape.map(shape => shape.length);
        if (inputShape.indexOf(null) === -1 &&
            unique$2(allRanks).length === 1) {
            this.reshapeRequired = false;
        }
        else {
            this.reshapeRequired = true;
        }
    }
    call(inputs, kwargs) {
        return tidy(() => {
            if (this.reshapeRequired) {
                const reshapedInputs = [];
                const inputDims = inputs.map(input => input.rank);
                if (inputDims.indexOf(null) === -1) {
                    // If ranks of all inputs are available, we simply expand each
                    // of them at axis=1 until all of them have the same rank.
                    const maxNDim = max$2(inputDims);
                    for (let x of inputs) {
                        const xNDim = x.rank;
                        for (let k = 0; k < maxNDim - xNDim; ++k) {
                            x = expandDims$2(x, 1);
                        }
                        reshapedInputs.push(x);
                    }
                    return this.mergeFunction(reshapedInputs);
                }
                else {
                    // Transpose all inputs so that batch size is the last dimension.
                    // [batchSize, dim1, dim2, ...] -> [dim1, dim2, ..., batchSize]
                    let transposed = false;
                    for (const x of inputs) {
                        const xNDim = x.rank;
                        if (xNDim == null) {
                            const xShape = x.shape;
                            const batchSize = xShape[0];
                            const newShape = xShape.slice(1).concat([batchSize]);
                            let xTransposed = reshape$3(x, [batchSize].concat(arrayProd(xShape.slice(1))));
                            xTransposed = transpose$2(xTransposed, [1, 0]);
                            xTransposed = reshape$3(xTransposed, newShape);
                            reshapedInputs.push(xTransposed);
                            transposed = true;
                        }
                        else if (xNDim > 1) {
                            const dims = range$2(1, xNDim).concat([0]);
                            reshapedInputs.push(transpose$2(x, dims));
                            transposed = true;
                        }
                        else {
                            // We don't transpose inputs if they are 1D vectors or
                            // scalars.
                            reshapedInputs.push(x);
                        }
                    }
                    let y = this.mergeFunction(reshapedInputs);
                    const yNDim = y.rank;
                    if (transposed) {
                        // If inputs have been transposed, we have to transpose the
                        // output too.
                        if (yNDim == null) {
                            const yShape = y.shape;
                            const yNDim = yShape.length;
                            const batchSize = yShape[yNDim - 1];
                            const newShape = [batchSize].concat(yShape.slice(0, yShape.length - 1));
                            y = reshape$3(transpose$2(reshape$3(y, [-1, batchSize]), [1, 0]), newShape);
                        }
                        else if (yNDim > 1) {
                            const dims = [yNDim - 1].concat(range$2(0, yNDim - 1));
                            y = transpose$2(y, dims);
                        }
                    }
                    return y;
                }
            }
            else {
                return this.mergeFunction(inputs);
            }
        });
    }
    computeOutputShape(inputShape) {
        let outputShape;
        if (inputShape[0] == null) {
            outputShape = null;
        }
        else {
            outputShape = inputShape[0].slice(1);
        }
        for (let i = 1; i < inputShape.length; ++i) {
            const shape = inputShape[i] == null ? null : inputShape[i].slice(1);
            outputShape = this.computeElementwiseOpOutputShape(outputShape, shape);
        }
        // Prepend the batch dimension: concrete if all inputs agree on one,
        // otherwise unknown (null).
        let batchSizes = [];
        for (const shape of inputShape) {
            if (shape != null && shape[0] !== null) {
                batchSizes.push(shape[0]);
            }
        }
        batchSizes = unique$2(batchSizes);
        if (batchSizes.length === 1) {
            outputShape = batchSizes.concat(outputShape);
        }
        else {
            outputShape = [null].concat(outputShape);
        }
        return outputShape;
    }
    computeMask(inputs, mask) {
        return tidy(() => {
            if (mask == null) {
                return null;
            }
            if (!Array.isArray(mask)) {
                throw new ValueError('`mask` should be an Array');
            }
            if (!Array.isArray(inputs)) {
                throw new ValueError('`inputs` should be an Array');
            }
            if (mask.length !== inputs.length) {
                throw new ValueError(`The Array 'inputs' and 'mask' are expected to have the same ` +
                    `length, but have different lengths ` +
                    `(${inputs.length} vs ${mask.length})`);
            }
            if (mask.every(m => m == null)) {
                return null;
            }
            // BUGFIX: the reduction loop previously ran to `mask.length - 1`,
            // silently ignoring the final input's mask (with two inputs, only
            // the first mask was returned). Null masks are now filtered out
            // (an absent mask means "all unmasked") and ALL remaining masks
            // are combined with logical AND.
            const expanded = mask.filter(m => m != null).map(m => expandDims$3(m, 0));
            let output = expanded[0];
            for (let i = 1; i < expanded.length; ++i) {
                output = logicalAnd$2(output, expanded[i]);
            }
            return output;
        });
    }
}
/** Element-wise sum of its inputs, which must all share the same shape. */
class Add extends Merge {
    constructor(args) {
        super(args);
    }
    mergeFunction(inputs) {
        return tidy(() => {
            // Start from a copy of the first input so that no caller-owned
            // tensor is reused as the accumulator.
            let total = inputs[0].clone();
            for (let k = 1; k < inputs.length; ++k) {
                total = add$3(total, inputs[k]);
            }
            return total;
        });
    }
}
/** @nocollapse */
Add.className = 'Add';
registerClass(Add);
/**
 * Calculate the element-wise sum of inputs, which all have the same shape.
 *
 * This function can be invoked in three ways.
 *
 * 1. Construct an instance of `Add` layer, by using no input argument
 *    or a single configuration argument. The resultant `Add` layer can then
 *    be used on `tf.SymbolicTensor`s or `tf.Tensor`s. For example:
 *
 * ```js
 * const addLayer = tf.layers.add();
 *
 * // The layer can be applied to inputs.
 * const input1 = tf.input({shape: [2, 2]});
 * const input2 = tf.input({shape: [2, 2]});
 * const output = addLayer.apply([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 2], with the first dimension as the undetermined batch
 * // dimension.
 * ```
 *
 * 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. This constructs
 *    a `Layer` object internally and calls its `apply` method on the inputs,
 *    generating a new `tf.SymbolicTensor`. For example:
 *
 * ```js
 * const input1 = tf.input({shape: [2, 2]});
 * const input2 = tf.input({shape: [2, 2]});
 * const output = tf.layers.add([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 2], with the first dimension as the undetermined batch
 * // dimension.
 * ```
 *
 * 3. Invoke directly on `tf.Tensor`s, i.e., concrete values. This constructs
 *    a `Layer` object internally and calls its `apply` method on the inputs,
 *    generating a new `tf.Tensor` as the result of the computation. For
 *    example:
 *
 * ```js
 * const input1 = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 * const input2 = tf.tensor2d([10, 20, 30, 40], [2, 2]);
 * tf.layers.add([input1, input2]).print();
 * // Gives [[11, 22], [33, 44]].
 * ```
 */
function add$2(config) {
    // An Array argument means "apply right away"; anything else is a layer
    // configuration object (possibly undefined).
    return Array.isArray(config) ? new Add({}).apply(config) : new Add(config);
}
/** Element-wise product of its inputs, which must all share the same shape. */
class Multiply extends Merge {
    constructor(args) {
        super(args);
    }
    mergeFunction(inputs) {
        return tidy(() => {
            // Accumulate the product starting from a copy of the first input.
            let product = inputs[0].clone();
            for (let k = 1; k < inputs.length; ++k) {
                product = mul(product, inputs[k]);
            }
            return product;
        });
    }
}
/** @nocollapse */
Multiply.className = 'Multiply';
registerClass(Multiply);
/**
 * Calculate the element-wise product of inputs, which all have the same shape.
 *
 * This function can be invoked in three ways.
 *
 * 1. Construct an instance of `Multiply` layer, by using no input argument
 *    or a single configuration argument. The resultant `Multiply` layer can
 *    then be used on `tf.SymbolicTensor`s or `tf.Tensor`s. For example:
 *
 * ```js
 * const multiplyLayer = tf.layers.multiply();
 *
 * // The layer can be applied to inputs.
 * const input1 = tf.input({shape: [2, 2]});
 * const input2 = tf.input({shape: [2, 2]});
 * const output = multiplyLayer.apply([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 2], with the first dimension as the undetermined batch
 * // dimension.
 * ```
 *
 * 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. This constructs
 *    a `Layer` object internally and calls its `apply` method on the inputs,
 *    generating a new `tf.SymbolicTensor`. For example:
 *
 * ```js
 * const input1 = tf.input({shape: [2, 2]});
 * const input2 = tf.input({shape: [2, 2]});
 * const output = tf.layers.multiply([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 2], with the first dimension as the undetermined batch
 * // dimension.
 * ```
 *
 * 3. Invoke directly on `tf.Tensor`s, i.e., concrete values. This constructs
 *    a `Layer` object internally and calls its `apply` method on the inputs,
 *    generating a new `tf.Tensor` as the result of the computation. For
 *    example:
 *
 * ```js
 * const input1 = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 * const input2 = tf.tensor2d([10, 20, 30, 40], [2, 2]);
 * tf.layers.multiply([input1, input2]).print();
 * // Gives [[10, 40], [90, 160]].
 * ```
 */
function multiply$3(config) {
    // Array argument -> apply a fresh layer immediately; otherwise treat the
    // argument as a layer configuration.
    return Array.isArray(config) ? new Multiply({}).apply(config) : new Multiply(config);
}
/** Element-wise arithmetic mean of its equally-shaped inputs. */
class Average extends Merge {
    constructor(args) {
        super(args);
    }
    mergeFunction(inputs) {
        return tidy(() => {
            // Sum all inputs, then scale by the reciprocal of their count.
            let total = inputs[0].clone();
            for (let k = 1; k < inputs.length; ++k) {
                total = add$3(total, inputs[k]);
            }
            return mul(1 / inputs.length, total);
        });
    }
}
/** @nocollapse */
Average.className = 'Average';
registerClass(Average);
/**
 * Calculate the element-wise arithmetic mean of inputs, which all have the
 * same shape.
 *
 * This function can be invoked in three ways.
 *
 * 1. Construct an instance of `Average` layer, by using no input argument
 *    or a single configuration argument. The resultant `Average` layer can
 *    then be used on `tf.SymbolicTensor`s or `tf.Tensor`s. For example:
 *
 * ```js
 * const averageLayer = tf.layers.average();
 *
 * // The layer can be applied to inputs.
 * const input1 = tf.input({shape: [2, 2]});
 * const input2 = tf.input({shape: [2, 2]});
 * const output = averageLayer.apply([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 2], with the first dimension as the undetermined batch
 * // dimension.
 * ```
 *
 * 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. This constructs
 *    a `Layer` object internally and calls its `apply` method on the inputs,
 *    generating a new `tf.SymbolicTensor`. For example:
 *
 * ```js
 * const input1 = tf.input({shape: [2, 2]});
 * const input2 = tf.input({shape: [2, 2]});
 * const output = tf.layers.average([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 2], with the first dimension as the undetermined batch
 * // dimension.
 * ```
 *
 * 3. Invoke directly on `tf.Tensor`s, i.e., concrete values. This constructs
 *    a `Layer` object internally and calls its `apply` method on the inputs,
 *    generating a new `tf.Tensor` as the result of the computation. For
 *    example:
 *
 * ```js
 * const input1 = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 * const input2 = tf.tensor2d([10, 20, 30, 40], [2, 2]);
 * tf.layers.average([input1, input2]).print();
 * // Gives [[5.5, 11], [16.5, 22]].
 * ```
 */
function average$1(config) {
    // Array argument -> apply a fresh layer immediately; otherwise treat the
    // argument as a layer configuration.
    return Array.isArray(config) ? new Average({}).apply(config) : new Average(config);
}
/** Element-wise maximum of its equally-shaped inputs. */
class Maximum extends Merge {
    constructor(args) {
        super(args);
    }
    mergeFunction(inputs) {
        return tidy(() => {
            // maximum$4 always produces a new tensor, so no initial clone
            // is needed here.
            let best = inputs[0];
            for (let k = 1; k < inputs.length; ++k) {
                best = maximum$4(best, inputs[k]);
            }
            return best;
        });
    }
}
/** @nocollapse */
Maximum.className = 'Maximum';
registerClass(Maximum);
/**
 * Calculate the element-wise maximum of inputs, which all have the same shape.
 *
 * This function can be invoked in three ways.
 *
 * 1. Construct an instance of `Maximum` layer, by using no input argument
 *    or a single configuration argument. The resultant `Maximum` layer can
 *    then be used on `tf.SymbolicTensor`s or `tf.Tensor`s. For example:
 *
 * ```js
 * const maximumLayer = tf.layers.maximum();
 *
 * // The layer can be applied to inputs.
 * const input1 = tf.input({shape: [2, 2]});
 * const input2 = tf.input({shape: [2, 2]});
 * const output = maximumLayer.apply([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 2], with the first dimension as the undetermined batch
 * // dimension.
 * ```
 *
 * 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. This constructs
 *    a `Layer` object internally and calls its `apply` method on the inputs,
 *    generating a new `tf.SymbolicTensor`. For example:
 *
 * ```js
 * const input1 = tf.input({shape: [2, 2]});
 * const input2 = tf.input({shape: [2, 2]});
 * const output = tf.layers.maximum([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 2], with the first dimension as the undetermined batch
 * // dimension.
 * ```
 *
 * 3. Invoke directly on `tf.Tensor`s, i.e., concrete values. This constructs
 *    a `Layer` object internally and calls its `apply` method on the inputs,
 *    generating a new `tf.Tensor` as the result of the computation. For
 *    example:
 *
 * ```js
 * const input1 = tf.tensor2d([1, 20, 3, 40], [2, 2]);
 * const input2 = tf.tensor2d([10, 2, 30, 4], [2, 2]);
 * tf.layers.maximum([input1, input2]).print();
 * // Gives [[10, 20], [30, 40]].
 * ```
 */
function maximum$3(config) {
    // Array argument -> apply a fresh layer immediately; otherwise treat the
    // argument as a layer configuration.
    return Array.isArray(config) ? new Maximum({}).apply(config) : new Maximum(config);
}
/** Element-wise minimum of its equally-shaped inputs. */
class Minimum extends Merge {
    constructor(args) {
        super(args);
    }
    mergeFunction(inputs) {
        return tidy(() => {
            // minimum$4 always produces a new tensor, so no initial clone
            // is needed here.
            let best = inputs[0];
            for (let k = 1; k < inputs.length; ++k) {
                best = minimum$4(best, inputs[k]);
            }
            return best;
        });
    }
}
/** @nocollapse */
Minimum.className = 'Minimum';
registerClass(Minimum);
/**
 * Calculate the element-wise minimum of inputs, which all have the same shape.
 *
 * This function can be invoked in three ways.
 *
 * 1. Construct an instance of `Minimum` layer, by using no input argument
 *    or a single configuration argument. The resultant `Minimum` layer can
 *    then be used on `tf.SymbolicTensor`s or `tf.Tensor`s. For example:
 *
 * ```js
 * const minimumLayer = tf.layers.minimum();
 *
 * // The layer can be applied to inputs.
 * const input1 = tf.input({shape: [2, 2]});
 * const input2 = tf.input({shape: [2, 2]});
 * const output = minimumLayer.apply([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 2], with the first dimension as the undetermined batch
 * // dimension.
 * ```
 *
 * 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. This constructs
 *    a `Layer` object internally and calls its `apply` method on the inputs,
 *    generating a new `tf.SymbolicTensor`. For example:
 *
 * ```js
 * const input1 = tf.input({shape: [2, 2]});
 * const input2 = tf.input({shape: [2, 2]});
 * const output = tf.layers.minimum([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 2], with the first dimension as the undetermined batch
 * // dimension.
 * ```
 *
 * 3. Invoke directly on `tf.Tensor`s, i.e., concrete values. This constructs
 *    a `Layer` object internally and calls its `apply` method on the inputs,
 *    generating a new `tf.Tensor` as the result of the computation. For
 *    example:
 *
 * ```js
 * const input1 = tf.tensor2d([1, 20, 3, 40], [2, 2]);
 * const input2 = tf.tensor2d([10, 2, 30, 4], [2, 2]);
 * tf.layers.minimum([input1, input2]).print();
 * // Gives [[1, 2], [3, 4]].
 * ```
 */
function minimum$3(config) {
    // Array argument -> apply a fresh layer immediately; otherwise treat the
    // argument as a layer configuration.
    return Array.isArray(config) ? new Minimum({}).apply(config) : new Minimum(config);
}
/** Concatenates its inputs along a configurable axis (default: last). */
class Concatenate extends Merge {
    constructor(args) {
        super(args);
        this.DEFAULT_AXIS = -1;
        if (args == null) {
            args = {};
        }
        this.axis = args.axis == null ? this.DEFAULT_AXIS : args.axis;
        this.supportsMasking = true;
        // Concatenation requires equal ranks, so Merge's reshape logic is
        // never needed.
        this.reshapeRequired = false;
    }
    build(inputShape) {
        // Used purely for shape validation.
        if (!(Array.isArray(inputShape) && Array.isArray(inputShape[0])) ||
            inputShape.length === 1) {
            throw new ValueError('A `Concatenate` layer should be called on a list of at least 2 ' +
                'inputs');
        }
        // If every shape is unknown there is nothing to validate.
        let allNoneShape = true;
        for (const shape of inputShape) {
            if (shape != null) {
                allNoneShape = false;
                break;
            }
        }
        if (allNoneShape) {
            return;
        }
        // All shapes must agree on every dimension except the concat axis.
        const shapeSet = [];
        for (let i = 0; i < inputShape.length; ++i) {
            const shapeWithoutConcatAxis = inputShape[i].slice();
            shapeWithoutConcatAxis.splice(this.axis, 1);
            let exists = false;
            for (const shape of shapeSet) {
                if (arraysEqual(shape, shapeWithoutConcatAxis)) {
                    exists = true;
                    break;
                }
            }
            if (!exists) {
                shapeSet.push(shapeWithoutConcatAxis);
            }
        }
        if (shapeSet.length > 1) {
            throw new ValueError('A `Concatenate` layer requires inputs with matching shapes ' +
                'except for the concat axis. Got input shapes: ' +
                JSON.stringify(inputShape));
        }
    }
    mergeFunction(inputs) {
        return tidy(() => {
            return concatenate$2(inputs, this.axis);
        });
    }
    computeOutputShape(inputShape) {
        if (!(Array.isArray(inputShape) && Array.isArray(inputShape[0]))) {
            throw new ValueError('A `Concatenate` layer should be called on a list of inputs.');
        }
        const inputShapes = inputShape;
        const outputShape = inputShapes[0].slice();
        const axis = this.axis < 0 ? outputShape.length + this.axis : this.axis;
        // Porting Note: the line above is because TypeScript doesn't support
        // negative indices.
        for (const shape of inputShapes.slice(1)) {
            // Any unknown dimension along the concat axis makes the output
            // dimension unknown.
            if (outputShape[axis] == null || shape[axis] == null) {
                outputShape[axis] = null;
                break;
            }
            outputShape[axis] += shape[axis];
        }
        return outputShape;
    }
    computeMask(inputs, mask) {
        if (mask == null) {
            return null;
        }
        if (!Array.isArray(mask)) {
            throw new ValueError('`mask` should be an array for Concatenate');
        }
        if (!Array.isArray(inputs)) {
            throw new ValueError('`inputs` should be an array for Concatenate');
        }
        if (mask.length !== inputs.length) {
            // BUGFIX: corrected "legnth" typo in the user-facing message.
            throw new ValueError(`Mismatch in the length of mask (${mask.length}) ` +
                `and the length of inputs (${inputs.length})`);
        }
        return tidy(() => {
            let allNullMasks = true;
            mask.forEach(m => {
                if (m != null) {
                    allNullMasks = false;
                    return;
                }
            });
            if (allNullMasks) {
                return null;
            }
            const outputMasks = [];
            for (let i = 0; i < inputs.length; ++i) {
                if (mask[i] == null) {
                    // Input is unmasked. Append all 1's to masks.
                    outputMasks.push(cast$3(onesLike$3(inputs[i]), 'bool'));
                }
                else if (mask[i].rank < inputs[i].rank) {
                    // Mask is smaller than the input, expand it.
                    outputMasks.push(expandDims$3(mask[i], -1));
                }
                else {
                    outputMasks.push(mask[i]);
                }
            }
            const concatenatedMasks = concat$2(outputMasks, this.axis);
            return all$2(concatenatedMasks, -1, false);
        });
    }
    getConfig() {
        const config = {
            'axis': this.axis,
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}
/** @nocollapse */
Concatenate.className = 'Concatenate';
registerClass(Concatenate);
/**
 * Concatenate an `Array` of inputs.
 *
 * This function can be invoked in three ways.
 *
 * 1. Construct an instance of `Concatenate` layer, by using no input argument
 *    or a single configuration argument. The resultant `Concatenate` layer
 *    can then be used on `tf.SymbolicTensor`s or `tf.Tensor`s. For example:
 *
 * ```js
 * const concatLayer = tf.layers.concatenate();
 *
 * // The layer can be applied to inputs.
 * const input1 = tf.input({shape: [2, 3]});
 * const input2 = tf.input({shape: [2, 4]});
 * const output = concatLayer.apply([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 7], with the first dimension as the undetermined batch
 * // dimension and the last dimension as the result of concatenating the
 * // last dimensions of the two inputs.
 * ```
 *
 * 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. This constructs
 *    a `Layer` object internally and calls its `apply` method on the inputs,
 *    generating a new `tf.SymbolicTensor`. For example:
 *
 * ```js
 * const input1 = tf.input({shape: [2, 3]});
 * const input2 = tf.input({shape: [2, 4]});
 * const output = tf.layers.concatenate([input1, input2]);
 * console.log(output.shape);
 * // You get [null, 2, 7], with the first dimension as the undetermined batch
 * // dimension and the last dimension as the result of concatenating the
 * // last dimensions of the two inputs.
 * ```
 *
 * 3. Invoke directly on `tf.Tensor`s, i.e., concrete values. This constructs
 *    a `Layer` object internally and calls its `apply` method on the inputs,
 *    generating a new `tf.Tensor` as the result of the computation. For
 *    example:
 *
 * ```js
 * const input1 = tf.tensor2d([[1, 2], [3, 4]], [2, 2]);
 * const input2 = tf.tensor2d([[10, 20], [30, 40]], [2, 2]);
 * tf.layers.concatenate([input1, input2]).print();
 * // Gives [[1, 2, 10, 20], [3, 4, 30, 40]].
 * ```
 */
function concatenate$1(config) {
    // Array argument -> apply a fresh layer immediately; otherwise treat the
    // argument as a layer configuration.
    return Array.isArray(config) ? new Concatenate({}).apply(config) : new Concatenate(config);
}
/**
 * Interpretable potentially negative axis index.
 *
 * For example, given axis = -1, and dim = 3, this function will return 2.
 *
 * Note that, as in the original implementation, non-negative axes are
 * returned unchanged even when `axis >= dim`.
 *
 * @param axis The axis index, may be a positive, zero or negative integer.
 * @param dim Total number of dimensions, a positive integer.
 * @returns A non-negative axis index equivalent to the input `axis`.
 */
function interpretAxis(axis, dim) {
    // Repeatedly shift negative indices up by `dim` until non-negative.
    return axis < 0 ? interpretAxis(axis + dim, dim) : axis;
}
/**
 * Batchwise dot product of `x` and `y` along the given axes, mirroring the
 * Keras backend `batchDot` semantics. Supports tensors up to rank 3.
 *
 * @param x First tensor, rank >= 2.
 * @param y Second tensor, rank >= 2.
 * @param axes A single axis (applied to both tensors) or a pair of axes;
 *   when null, behaves like a batch matmul over the trailing dimensions.
 * @returns The batch dot product; rank-1 results are expanded to rank 2.
 */
function batchDot(x, y, axes) {
    if (x.shape.length > 3 || y.shape.length > 3) {
        throw new NotImplementedError('batchDot is not implemented for tensors of 4D or higher rank yet');
    }
    assert$1(x.shape.length >= 2, () => `batchDot requires the rank of x to be >= 2, ` +
        `but got ${x.shape.length}`);
    // BUGFIX: this assertion previously re-checked `x.shape.length` while its
    // message (and intent) referred to `y`.
    assert$1(y.shape.length >= 2, () => `batchDot requires the rank of y to be >= 2, ` +
        `but got ${y.shape.length}`);
    if (typeof axes === 'number') {
        axes = [axes, axes];
    }
    if (x.dtype === 'complex64' || y.dtype === 'complex64') {
        throw new NotImplementedError('batchDot is not implemented for complex64-type Tensors yet.');
    }
    const xNDim = x.shape.length;
    const yNDim = y.shape.length;
    if (axes == null) {
        // Behave like batchMatmul by default.
        axes = [xNDim - 1, yNDim - 2];
    }
    const axesArray = axes;
    return tidy(() => {
        // Pad the lower-rank operand with trailing singleton dimensions so
        // both operands have the same rank; `diff` remembers how many were
        // added so they can be squeezed out of the result below.
        let diff;
        if (xNDim > yNDim) {
            diff = xNDim - yNDim;
            const diffShape = [];
            for (let i = 0; i < diff; ++i) {
                diffShape.push(1);
            }
            y = reshape$3(y, y.shape.concat(diffShape));
        }
        else if (yNDim > xNDim) {
            diff = yNDim - xNDim;
            const diffShape = [];
            for (let i = 0; i < diff; ++i) {
                diffShape.push(1);
            }
            x = reshape$3(x, x.shape.concat(diffShape));
        }
        else {
            diff = 0;
        }
        let out;
        if (x.shape.length === 2 && y.shape.length === 2) {
            // 2D x 2D reduces to an elementwise product summed over the
            // shared axis.
            if (axesArray[0] === axesArray[1]) {
                out = sum$3(mul(x, y), axesArray[0]);
            }
            else {
                out = sum$3(mul(transpose$2(x, [1, 0]), y), axesArray[1]);
            }
        }
        else {
            // General case: delegate to batched matmul with the appropriate
            // adjoint flags.
            const adjX = axesArray[0] !== x.shape.length - 1;
            const adjY = axesArray[1] === y.shape.length - 1;
            out = matMul$1(x, y, adjX, adjY);
        }
        if (diff > 0) {
            // Squeeze out the singleton dimensions introduced by the rank
            // padding above.
            let idx;
            if (xNDim > yNDim) {
                idx = xNDim + yNDim - 3;
            }
            else {
                idx = xNDim - 1;
            }
            const squeezeAxes = [];
            for (let i = idx; i < idx + diff; ++i) {
                squeezeAxes.push(i);
            }
            out = squeeze(out, squeezeAxes);
        }
        if (out.shape.length === 1) {
            out = expandDims$3(out, 1);
        }
        return out;
    });
}
/**
 * Layer that computes a dot product between samples in two tensors.
 *
 * Expects exactly two inputs of rank <= 3. `axes` selects the axis (single
 * integer applied to both inputs) or axes (array of two integers, one per
 * input) along which the dot product is taken. When `normalize` is true,
 * both inputs are L2-normalized along their dot axes first, so the output
 * is the cosine proximity.
 */
class Dot extends Merge {
    constructor(args) {
        super(args);
        this.axes = args.axes;
        // Default: no L2 normalization before the dot product.
        this.normalize = args.normalize == null ? false : args.normalize;
        this.supportsMasking = true;
        // Unlike other Merge layers, Dot never reshapes its inputs.
        this.reshapeRequired = false;
    }
    // Validates the two input shapes; builds no weights.
    build(inputShape) {
        assert$1(Array.isArray(inputShape) && inputShape.length === 2 &&
            Array.isArray(inputShape[0]) && Array.isArray(inputShape[1]), () => 'A `Dot` layer should be called on a list of exactly 2 inputs.');
        const shape1 = inputShape[0];
        const shape2 = inputShape[1];
        if (shape1.length > 3 || shape2.length > 3) {
            throw new NotImplementedError('Dot layer does not support tensors of 4D or higher rank yet.');
        }
        const axes = this.interpretAxes(shape1, shape2);
        // The contracted dimensions must agree.
        if (shape1[axes[0]] !== shape2[axes[1]]) {
            throw new ValueError(`Dimension incompatibility: ` +
                `${shape1[axes[0]]} !== ${shape2[axes[1]]}`);
        }
    }
    // Computes the (optionally normalized) batch-wise dot product.
    mergeFunction(inputs) {
        if (inputs.length !== 2) {
            throw new ValueError('A `Dot` layer must be called on exactly 2 inputs, ' +
                `but received ${inputs.length} input(s).`);
        }
        let x1 = inputs[0];
        let x2 = inputs[1];
        let axes;
        if (!Array.isArray(this.axes)) {
            // Single integer: resolve (possibly negative) axis per input rank.
            axes = [
                interpretAxis(this.axes, x1.shape.length),
                interpretAxis(this.axes, x2.shape.length)
            ];
        }
        else {
            axes = this.axes.map((axis, i) => interpretAxis(axis, inputs[i].shape.length));
        }
        if (this.normalize) {
            // Cosine-proximity mode: unit-normalize along the dot axes.
            x1 = l2Normalize(x1, axes[0]);
            x2 = l2Normalize(x2, axes[1]);
        }
        return batchDot(x1, x2, axes);
    }
    // Resolves `this.axes` against two static shapes (build-time variant of
    // the logic in `mergeFunction`).
    interpretAxes(shape1, shape2) {
        let axes;
        if (!Array.isArray(this.axes)) {
            // `this.axes` is a single integer.
            axes = [
                interpretAxis(this.axes, shape1.length),
                interpretAxis(this.axes, shape2.length)
            ];
        }
        else {
            // `this.axes` is an Array of integers.
            axes = this.axes;
        }
        return axes;
    }
    computeOutputShape(inputShape) {
        assert$1(Array.isArray(inputShape) && inputShape.length === 2 &&
            Array.isArray(inputShape[0]) && Array.isArray(inputShape[1]), () => 'A `Dot` layer should be called on a list of exactly 2 inputs.');
        // Copy before splicing so the caller's arrays are not mutated.
        const shape1 = inputShape[0].slice();
        const shape2 = inputShape[1].slice();
        if (shape1.length > 3 || shape2.length > 3) {
            throw new NotImplementedError('Dot layer does not support tensors of 4D or higher rank yet.');
        }
        const axes = this.interpretAxes(shape1, shape2);
        // Drop the contracted axis from each shape, and the batch axis from
        // the second input; concatenate what remains.
        shape1.splice(axes[0], 1);
        shape2.splice(axes[1], 1);
        shape2.splice(0, 1);
        const outputShape = shape1.concat(shape2);
        if (outputShape.length === 1) {
            // Batch-only result: pad to rank 2, matching batchDot's expandDims.
            outputShape.push(1);
        }
        return outputShape;
    }
    // Dot does not propagate masks.
    computeMask(inputs, mask) {
        return null;
    }
    getConfig() {
        const config = {
            'axes': this.axes,
            'normalize': this.normalize
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}
/** @nocollapse */
Dot.className = 'Dot';
registerClass(Dot);
// TODO(cais): Add functional interfaces for the merge layers.
51675
51676 /**
51677 * @license
51678 * Copyright 2018 Google LLC
51679 *
51680 * Use of this source code is governed by an MIT-style
51681 * license that can be found in the LICENSE file or at
51682 * https://opensource.org/licenses/MIT.
51683 * =============================================================================
51684 */
/**
 * Layer that adds zero-centered Gaussian noise to its input.
 *
 * Regularization layer: the noise is applied only in the training phase
 * (`kwargs['training']`); at inference the input passes through unchanged.
 */
class GaussianNoise extends Layer {
    constructor(args) {
        super(args);
        this.supportsMasking = true;
        // Standard deviation of the additive noise.
        this.stddev = args.stddev;
    }
    computeOutputShape(inputShape) {
        // Noise is element-wise, so the shape is unchanged.
        return inputShape;
    }
    getConfig() {
        // Base config keys take precedence, as in the other layers here.
        return Object.assign({ stddev: this.stddev }, super.getConfig());
    }
    call(inputs, kwargs) {
        return tidy(() => {
            this.invokeCallHook(inputs, kwargs);
            const x = getExactlyOneTensor(inputs);
            // Training branch: x + N(0, stddev); inference branch: identity.
            const applyNoise = () => add$3(randomNormal$1(x.shape, 0, this.stddev), x);
            return inTrainPhase(applyNoise, () => x, kwargs['training'] || false);
        });
    }
}
/** @nocollapse */
GaussianNoise.className = 'GaussianNoise';
registerClass(GaussianNoise);
/**
 * Layer that applies multiplicative 1-centered Gaussian noise.
 *
 * Regularization layer: active only in the training phase. The noise has
 * standard deviation `sqrt(rate / (1 - rate))`.
 */
class GaussianDropout extends Layer {
    constructor(args) {
        super(args);
        this.supportsMasking = true;
        // Drop probability, as with classical Dropout.
        this.rate = args.rate;
    }
    computeOutputShape(inputShape) {
        // Element-wise noise: shape is unchanged.
        return inputShape;
    }
    getConfig() {
        // Base config keys take precedence, as in the other layers here.
        return Object.assign({ rate: this.rate }, super.getConfig());
    }
    call(inputs, kwargs) {
        return tidy(() => {
            this.invokeCallHook(inputs, kwargs);
            const x = getExactlyOneTensor(inputs);
            if (!(this.rate > 0 && this.rate < 1)) {
                // Degenerate rate: pass the input through untouched.
                return x;
            }
            const applyNoise = () => {
                const sigma = Math.sqrt(this.rate / (1 - this.rate));
                return mul(x, randomNormal$1(x.shape, 1, sigma));
            };
            return inTrainPhase(applyNoise, () => x, kwargs['training'] || false);
        });
    }
}
/** @nocollapse */
GaussianDropout.className = 'GaussianDropout';
registerClass(GaussianDropout);
/**
 * Applies Alpha Dropout to the input.
 *
 * As it is a regularization layer, it is only active at training time.
 *
 * Alpha Dropout is a `Dropout` that keeps mean and variance of inputs
 * to their original values, in order to ensure the self-normalizing property
 * even after this dropout.
 * Alpha Dropout fits well to Scaled Exponential Linear Units
 * by randomly setting activations to the negative saturation value.
 *
 * Arguments:
 *   - `rate`: float, drop probability (as with `Dropout`).
 *     The multiplicative noise will have
 *     standard deviation `sqrt(rate / (1 - rate))`.
 *   - `noise_shape`: A 1-D `Tensor` of type `int32`, representing the
 *     shape for randomly generated keep/drop flags.
 *
 * Input shape:
 *   Arbitrary. Use the keyword argument `inputShape`
 *   (tuple of integers, does not include the samples axis)
 *   when using this layer as the first layer in a model.
 *
 * Output shape:
 *   Same shape as input.
 *
 * References:
 *   - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
 */
class AlphaDropout extends Layer {
    constructor(args) {
        super(args);
        this.supportsMasking = true;
        // Drop probability.
        this.rate = args.rate;
        // Optional explicit shape for the keep/drop mask.
        this.noiseShape = args.noiseShape;
    }
    // Mask shape: the configured noiseShape, or the input's own shape.
    _getNoiseShape(inputs) {
        return this.noiseShape || getExactlyOneTensor(inputs).shape;
    }
    computeOutputShape(inputShape) {
        return inputShape;
    }
    getConfig() {
        const baseConfig = super.getConfig();
        const config = { rate: this.rate };
        Object.assign(config, baseConfig);
        return config;
    }
    call(inputs, kwargs) {
        return tidy(() => {
            if (this.rate < 1 && this.rate > 0) {
                const noiseShape = this._getNoiseShape(inputs);
                const droppedInputs = () => {
                    const input = getExactlyOneTensor(inputs);
                    // SELU constants (Klambauer et al., 2017).
                    const alpha = 1.6732632423543772848170429916717;
                    const scale = 1.0507009873554804934193349852946;
                    // Negative saturation value that dropped units are set to.
                    const alphaP = -alpha * scale;
                    // Bernoulli keep mask: 1 with probability (1 - rate).
                    let keptIdx = greaterEqual$2(randomUniform$1(noiseShape), this.rate);
                    keptIdx = cast$2(keptIdx, 'float32'); // get default dtype.
                    // Get affine transformation params.
                    // a, b restore zero mean / unit variance after dropping.
                    const a = ((1 - this.rate) * (1 + this.rate * alphaP ** 2)) ** -0.5;
                    const b = -a * alphaP * this.rate;
                    // Apply mask.
                    // NOTE(review): this computes input*kept + alphaP*(kept - 1),
                    // whereas the reference Keras formula is
                    // input*kept + alphaP*(1 - kept) — the dropped-unit term has
                    // the opposite sign here. Verify against upstream before
                    // relying on the saturation value.
                    const x = add$3(mul(input, keptIdx), mul(add$3(keptIdx, -1), alphaP));
                    return add$3(mul(x, a), b);
                };
                return inTrainPhase(droppedInputs, () => getExactlyOneTensor(inputs), kwargs['training'] || false);
            }
            // Degenerate rate: return the inputs argument as received
            // (note: not unwrapped via getExactlyOneTensor, unlike above).
            return inputs;
        });
    }
}
/** @nocollapse */
AlphaDropout.className = 'AlphaDropout';
registerClass(AlphaDropout);
51821
51822 /**
51823 * @license
51824 * Copyright 2018 Google LLC
51825 *
51826 * Use of this source code is governed by an MIT-style
51827 * license that can be found in the LICENSE file or at
51828 * https://opensource.org/licenses/MIT.
51829 * =============================================================================
51830 */
/**
 * Applies batch normalization on x given mean, var, beta and gamma.
 *
 * I.e. returns:
 *   `output = (x - mean) / (sqrt(var) + epsilon) * gamma + beta`
 *
 * Supports tensors of rank 2, 3, or 4 only.
 *
 * @param x Input tensor.
 * @param mean Mean of batch.
 * @param variance Variance of batch.
 * @param beta Tensor with which to center the input.
 * @param gamma Tensor by which to scale the input.
 * @param epsilon Fuzz factor.
 * @returns The result of the batch normalization.
 */
function batchNormalization$1(x, mean, variance, beta, gamma, epsilon = 1e-3) {
    // Dispatch on rank to the corresponding core batchNorm op.
    switch (x.rank) {
        case 2:
            return batchNorm2d(x, mean, variance, beta, gamma, epsilon);
        case 3:
            return batchNorm3d(x, mean, variance, beta, gamma, epsilon);
        case 4:
            return batchNorm4d(x, mean, variance, beta, gamma, epsilon);
        default:
            throw new NotImplementedError(`batchNormalization is not implemented for array of rank ${x.rank} ` +
                `yet`);
    }
}
/**
 * Non-broadcasting batch normalization for use in training (not inference).
 *
 * The input is normalized to zero mean and unit variance along the
 * `reductionAxes`, followed by scaling with `gamma` and shifting by `beta`.
 *
 * @param x Input tensor to be normalized.
 * @param gamma Tensor by which to scale the input.
 * @param beta Tensor by which to center the input.
 * @param reductionAxes Axes over which to normalize.
 * @param epsilon Fuzz factor.
 * @returns An `Array` of three `Tensors`:
 *   [normalized tensor, mean of input, variance of input].
 */
function regularNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon = 1e-3) {
    return tidy(() => {
        // Batch statistics along the reduction axes.
        const { mean, variance } = moments(x, reductionAxes);
        const normed = batchNormalization$1(x, mean, variance, beta, gamma, epsilon);
        return [normed, mean, variance];
    });
}
/**
 * Broadcasting batch normalization for use in training (not inference).
 *
 * The input is normalized to zero mean and unit variance along the
 * `reductionAxes`, followed by scaling with `gamma` and shifting by `beta`.
 * Unlike the "regular" variant, the batch statistics (and gamma/beta) are
 * first reshaped to a broadcast-compatible shape with 1s on the reduced
 * axes.
 *
 * @param x Input tensor to be normalized.
 * @param gamma Tensor by which to scale the input.
 * @param beta Tensor by which to center the input.
 * @param reductionAxes Axes over which to normalize.
 * @param epsilon Fuzz factor.
 * @returns An `Array` of three `Tensors`:
 *   [normalized tensor, mean of input, variance of input].
 */
function broadcastNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon = 1e-3) {
    return tidy(() => {
        const { mean, variance } = moments(x, reductionAxes);
        // 1 on every reduced axis, original extent elsewhere.
        const targetShape = range$2(0, x.rank).map((axis) => reductionAxes.indexOf(axis) !== -1 ? 1 : x.shape[axis]);
        const broadcastMean = reshape$3(mean, targetShape);
        const broadcastVariance = reshape$3(variance, targetShape);
        // gamma/beta are optional (scale/center may be disabled).
        const broadcastGamma = gamma == null ? null : reshape$3(gamma, targetShape);
        const broadcastBeta = beta == null ? null : reshape$3(beta, targetShape);
        const normed = batchNormalization$1(x, broadcastMean, broadcastVariance, broadcastBeta, broadcastGamma, epsilon);
        // Return the un-broadcast mean/variance for moving-average updates.
        return [normed, mean, variance];
    });
}
/**
 * Batch normalization for use in training (not inference).
 *
 * Chooses the non-broadcasting fast path when the reduction axes are exactly
 * all axes but the last; otherwise falls back to the broadcasting variant.
 *
 * @param x Input tensor to be normalized.
 * @param gamma Tensor by which to scale the input.
 * @param beta Tensor by which to center the input.
 * @param reductionAxes Axes over which to normalize.
 * @param epsilon Fuzz factor.
 * @returns An `Array` of three `Tensors`:
 *   [normalized tensor, mean of input, variance of input].
 */
function normalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon = 1e-3) {
    const isRegular = arraysEqual(reductionAxes.slice().sort(), range$2(0, x.rank - 1));
    const impl = isRegular ? regularNormalizeBatchInTraining :
        broadcastNormalizeBatchInTraining;
    return impl(x, gamma, beta, reductionAxes, epsilon);
}
/**
 * Batch normalization layer.
 *
 * Normalizes the activations of the previous layer along `axis` so that the
 * mean is close to 0 and the standard deviation close to 1. During training,
 * batch statistics are used and exponential moving averages of mean/variance
 * are updated; during inference, the moving averages are used.
 */
class BatchNormalization extends Layer {
    constructor(args) {
        if (args == null) {
            args = {};
        }
        super(args);
        this.supportsMasking = true;
        // Axis to normalize along; -1 (default) means the last (feature) axis.
        this.axis = args.axis == null ? -1 : args.axis;
        // Momentum of the moving-average updates.
        this.momentum = args.momentum == null ? 0.99 : args.momentum;
        this.epsilon = args.epsilon == null ? 1e-3 : args.epsilon;
        // center -> learn beta (offset); scale -> learn gamma (scale).
        this.center = args.center == null ? true : args.center;
        this.scale = args.scale == null ? true : args.scale;
        this.betaInitializer = getInitializer(args.betaInitializer || 'zeros');
        this.gammaInitializer = getInitializer(args.gammaInitializer || 'ones');
        this.movingMeanInitializer =
            getInitializer(args.movingMeanInitializer || 'zeros');
        this.movingVarianceInitializer =
            getInitializer(args.movingVarianceInitializer || 'ones');
        this.betaConstraint = getConstraint(args.betaConstraint);
        this.gammaConstraint = getConstraint(args.gammaConstraint);
        this.betaRegularizer = getRegularizer(args.betaRegularizer);
        this.gammaRegularizer = getRegularizer(args.gammaRegularizer);
    }
    /**
     * Creates gamma/beta (if enabled) plus the non-trainable moving mean and
     * moving variance, all shaped [dim] where dim is the size of `axis`.
     */
    build(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        // Resolve a negative axis against the input rank.
        const axis = this.axis >= 0 ? this.axis : (this.axis + inputShape.length);
        const dim = inputShape[axis];
        if (dim == null) {
            throw new ValueError(`Axis ${axis} of input tensor should have a defined dimension but ` +
                `the layer received an input with shape ` +
                `${JSON.stringify(inputShape)}.`);
        }
        this.inputSpec =
            [new InputSpec({ ndim: inputShape.length, axes: { [axis]: dim } })];
        const shape = [dim];
        if (this.scale) {
            this.gamma = this.addWeight('gamma', shape, null, this.gammaInitializer, this.gammaRegularizer, true, this.gammaConstraint);
        }
        if (this.center) {
            this.beta = this.addWeight('beta', shape, null, this.betaInitializer, this.betaRegularizer, true, this.betaConstraint);
        }
        // Moving statistics are non-trainable (last argument `false`).
        this.movingMean = this.addWeight('moving_mean', shape, null, this.movingMeanInitializer, null, false);
        this.movingVariance = this.addWeight('moving_variance', shape, null, this.movingVarianceInitializer, null, false);
        this.built = true;
    }
    call(inputs, kwargs) {
        return tidy(() => {
            const training = kwargs['training'] == null ? false : kwargs['training'];
            const input = getExactlyOneTensor(inputs);
            const inputShape = input.shape;
            const ndim = inputShape.length;
            // All axes except the normalized one are reduced over.
            const reductionAxes = range$2(0, ndim);
            const axis = this.axis >= 0 ? this.axis : (this.axis + ndim);
            reductionAxes.splice(axis, 1);
            // Shape with 1s everywhere except on `axis`, for broadcasting.
            const broadcastShape = pyListRepeat(1, ndim);
            broadcastShape[axis] = inputShape[axis];
            const sortedReductionAxes = reductionAxes.slice();
            sortedReductionAxes.sort();
            // Broadcasting is only needed when the normalized axis is not the
            // last axis (i.e. reductionAxes !== [0..ndim-2]).
            const needsBroadcasting = !arraysEqual(sortedReductionAxes, range$2(0, ndim).slice(0, ndim - 1));
            const normalizeInference = () => {
                if (needsBroadcasting) {
                    const broadcastMovingMean = reshape$3(this.movingMean.read(), broadcastShape);
                    const broadcastMovingVariance = reshape$3(this.movingVariance.read(), broadcastShape);
                    const broadcastBeta = this.center ? reshape$3(this.beta.read(), broadcastShape) : null;
                    const broadcastGamma = this.scale ? reshape$3(this.gamma.read(), broadcastShape) : null;
                    return batchNormalization$1(input, broadcastMovingMean, broadcastMovingVariance, broadcastBeta, broadcastGamma, this.epsilon);
                }
                else {
                    return batchNormalization$1(input, this.movingMean.read(), this.movingVariance.read(), this.beta == null ? null : this.beta.read(), this.gamma == null ? null : this.gamma.read(), this.epsilon);
                }
            };
            if (!training) {
                return normalizeInference();
            }
            // NOTE(review): unlike the inference path, this reads gamma/beta
            // unconditionally — presumably assumes scale/center are enabled
            // when training; verify behavior with scale=false or center=false.
            const [normedTraining, mean, variance] = normalizeBatchInTraining(input, this.gamma.read(), this.beta.read(), reductionAxes, this.epsilon);
            // v <- v - (1 - momentum) * (v - value): exponential moving average.
            const doMovingAverage = (variable, value, momentum) => {
                tidy(() => {
                    const decay = 1 - momentum;
                    const origValue = variable.read();
                    const updateDelta = mul(sub$2(origValue, value), decay);
                    variable.write(sub$2(origValue, updateDelta));
                });
            };
            // Perform updates to moving mean and moving variance for training.
            // Porting Note: In PyKeras, these updates to `movingMean` and
            // `movingAverage` are done as a deferred Graph, added to the `Layer`'s
            // `update`s using the `add_update()` method. Here we do it imperatively
            // and encapsulate the updates in a function that is invoked
            // immediately.
            const updateMovingMeanAndVariance = () => {
                doMovingAverage(this.movingMean, mean, this.momentum);
                doMovingAverage(this.movingVariance, variance, this.momentum);
            };
            updateMovingMeanAndVariance();
            return normedTraining;
        });
    }
    getConfig() {
        const config = {
            axis: this.axis,
            momentum: this.momentum,
            epsilon: this.epsilon,
            center: this.center,
            scale: this.scale,
            betaInitializer: serializeInitializer(this.betaInitializer),
            gammaInitializer: serializeInitializer(this.gammaInitializer),
            movingMeanInitializer: serializeInitializer(this.movingMeanInitializer),
            movingVarianceInitializer: serializeInitializer(this.movingVarianceInitializer),
            betaRegularizer: serializeRegularizer(this.betaRegularizer),
            gammaRegularizer: serializeRegularizer(this.gammaRegularizer),
            betaConstraint: serializeConstraint(this.betaConstraint),
            gammaConstraint: serializeConstraint(this.gammaConstraint)
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}
/** @nocollapse */
BatchNormalization.className = 'BatchNormalization';
registerClass(BatchNormalization);
/**
 * Layer normalization layer.
 *
 * Normalizes each sample independently over the given `axis` (or axes),
 * rather than over the batch, then applies learned scale (gamma) and offset
 * (beta) along those axes.
 */
class LayerNormalization extends Layer {
    constructor(args) {
        if (args == null) {
            args = {};
        }
        super(args);
        // Axis (int) or axes (int[]) to normalize over; default: last axis.
        this.axis = args.axis == null ? -1 : args.axis;
        if (typeof this.axis === 'number') {
            if (!Number.isInteger(this.axis)) {
                throw new Error(`Expected axis to be an integer, but received ${this.axis}`);
            }
        }
        else if (Array.isArray(this.axis)) {
            for (const axis of this.axis) {
                if (!Number.isInteger(axis)) {
                    throw new Error(`Expected axis to be an array of integers, ` +
                        `but received ${JSON.stringify(this.axis)}`);
                }
            }
        }
        else {
            throw new Error(`Expected axis to be an integer or an array of integers, ` +
                `but received ${JSON.stringify(this.axis)}`);
        }
        this.epsilon = args.epsilon == null ? 1e-3 : args.epsilon;
        // center -> learn beta (offset); scale -> learn gamma (scale).
        this.center = args.center == null ? true : args.center;
        this.scale = args.scale == null ? true : args.scale;
        this.betaInitializer = getInitializer(args.betaInitializer || 'zeros');
        this.gammaInitializer = getInitializer(args.gammaInitializer || 'ones');
        this.betaRegularizer = getRegularizer(args.betaRegularizer);
        this.gammaRegularizer = getRegularizer(args.gammaRegularizer);
        this.supportsMasking = true;
    }
    /**
     * Normalizes `this.axis` to an array of non-negative, unique, in-range
     * axes and creates gamma/beta weights shaped by those axes' extents.
     */
    build(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        const nDims = inputShape.length;
        // Convert axis to array and resolve negatives.
        if (typeof this.axis === 'number') {
            this.axis = [this.axis];
        }
        for (let i = 0; i < this.axis.length; ++i) {
            if (this.axis[i] < 0) {
                this.axis[i] += nDims;
            }
        }
        // Further validate axes.
        for (const axis of this.axis) {
            if (axis < 0 || axis >= nDims) {
                throw new Error(`Invalid axis: ${axis}`);
            }
        }
        if (this.axis.length !== unique$2(this.axis).length) {
            throw new Error(`Found duplicate axes in: ${this.axis}`);
        }
        const paramShape = this.axis.map(axis => inputShape[axis]);
        const trainable = true;
        if (this.scale) {
            this.gamma = this.addWeight('gamma', paramShape, 'float32', this.gammaInitializer, this.gammaRegularizer, trainable);
        }
        else {
            this.gamma = null;
        }
        if (this.center) {
            this.beta = this.addWeight('beta', paramShape, 'float32', this.betaInitializer, this.betaRegularizer, trainable);
        }
        else {
            this.beta = null;
        }
        this.built = true;
    }
    call(inputs, kwargs) {
        const input = getExactlyOneTensor(inputs);
        const inputShape = input.shape;
        const nDims = inputShape.length;
        return tidy(() => {
            const keepDims = true;
            // Per-sample statistics over the normalized axes.
            let { mean, variance } = moments(input, this.axis, keepDims);
            // Shape with the input's extent on normalized axes, 1 elsewhere.
            const broadcastShape = pyListRepeat(1, nDims);
            for (const dim of this.axis) {
                broadcastShape[dim] = inputShape[dim];
            }
            const broadcast = (v) => {
                if (v != null && v.shape.length !== nDims) {
                    return reshape$3(v, broadcastShape);
                }
                else {
                    return v;
                }
            };
            let scale = this.scale ? broadcast(this.gamma.read()) : null;
            let offset = this.center ? broadcast(this.beta.read()) : null;
            // TODO(https://github.com/tensorflow/tfjs/issues/2120): The tiling
            // below is a workaround for the limitation that core's batchNorm
            // ops don't support broadcasting in their gradients. In addition,
            // the tiling is necessary to ensure correctness on the browser CPU
            // backend regardless of forward or backward computation. Remove
            // this workaround once the limitation is addressed.
            // Tile mean/variance to the full extent of the normalized axes and
            // scale/offset to the full extent of the other axes, so that all
            // operands have matching shapes without broadcasting.
            const momentsTiling = [];
            const scaleOffsetTiling = [];
            for (let i = 0; i < nDims; ++i) {
                if (this.axis.indexOf(i) !== -1) {
                    momentsTiling.push(inputShape[i]);
                    scaleOffsetTiling.push(1);
                }
                else {
                    momentsTiling.push(1);
                    scaleOffsetTiling.push(inputShape[i]);
                }
            }
            mean = tile$3(mean, momentsTiling);
            variance = tile$3(variance, momentsTiling);
            if (scale != null) {
                scale = tile$3(scale, scaleOffsetTiling);
            }
            if (offset != null) {
                offset = tile$3(offset, scaleOffsetTiling);
            }
            return batchNormalization$1(input, mean, variance, offset, scale, this.epsilon);
        });
    }
    getConfig() {
        const config = {
            axis: this.axis,
            epsilon: this.epsilon,
            center: this.center,
            scale: this.scale,
            betaInitializer: serializeInitializer(this.betaInitializer),
            gammaInitializer: serializeInitializer(this.gammaInitializer),
            betaRegularizer: serializeRegularizer(this.betaRegularizer),
            gammaRegularizer: serializeRegularizer(this.gammaRegularizer)
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}
/** @nocollapse */
LayerNormalization.className = 'LayerNormalization';
registerClass(LayerNormalization);
52207
52208 /**
52209 * @license
52210 * Copyright 2018 Google LLC
52211 *
52212 * Use of this source code is governed by an MIT-style
52213 * license that can be found in the LICENSE file or at
52214 * https://opensource.org/licenses/MIT.
52215 * =============================================================================
52216 */
/**
 * Pads the middle dimension of a 3D tensor.
 *
 * @param x Input `tf.Tensor` to be padded.
 * @param padding `Array` of 2 integers, how many zeros to add at the start and
 *   end of the middle dimension (i.e., dimension 1). Defaults to `[1, 1]`.
 * @return A padded 3D `tf.Tensor`.
 */
function temporalPadding(x, padding) {
    return tidy(() => {
        if (x.rank !== 3) {
            throw new ValueError(`temporalPadding expects input tensor to be 3-D, but received a ` +
                `${x.rank}-D tensor.`);
        }
        // Default: one zero on each side of the time dimension.
        const timePad = padding == null ? [1, 1] : padding;
        if (timePad.length !== 2) {
            throw new ValueError(`temporalPadding expects input padding pattern to be a length-2 ` +
                `array, but received a length-${timePad.length} array.`);
        }
        // Only dimension 1 is padded; batch and feature dims are untouched.
        return pad(x, [[0, 0], timePad, [0, 0]]);
    });
}
/**
 * Pads the 2nd and 3rd dimensions of a 4D tensor.
 *
 * @param x Input `tf.Tensor` to be padded.
 * @param padding `Array` of two `Array`s, each of which is an `Array` of two
 *   integers. The amount of padding at the beginning and end of the 2nd and 3rd
 *   dimensions, respectively. Defaults to `[[1, 1], [1, 1]]`.
 * @param dataFormat 'channelsLast' (default) or 'channelsFirst'.
 * @return Padded 4D `tf.Tensor`.
 * @throws ValueError if `x` is not 4-D, `padding` is malformed, or
 *   `dataFormat` is unknown.
 */
function spatial2dPadding(x, padding, dataFormat) {
    return tidy(() => {
        if (x.rank !== 4) {
            // Bug fix: the message used to say "temporalPadding" (copy-paste
            // from the 3-D helper above); name the correct function.
            throw new ValueError(`spatial2dPadding expects input tensor to be 4-D, but received a ` +
                `${x.rank}-D tensor.`);
        }
        if (padding == null) {
            padding = [[1, 1], [1, 1]];
        }
        if (padding.length !== 2 || padding[0].length !== 2 ||
            padding[1].length !== 2) {
            throw new ValueError('spatial2dPadding expects `padding` to be an Array of two Arrays, ' +
                'each of which is an Array of two integers.');
        }
        if (dataFormat == null) {
            dataFormat = imageDataFormat();
        }
        if (dataFormat !== 'channelsLast' && dataFormat !== 'channelsFirst') {
            // Bug fix: closing quote after 'channelsFirst' was missing.
            throw new ValueError(`Unknown data format: ${dataFormat}. ` +
                `Supported data formats are 'channelsLast' and 'channelsFirst'.`);
        }
        // Spatial dims are [2, 3] for channelsFirst, [1, 2] for channelsLast.
        let pattern;
        if (dataFormat === 'channelsFirst') {
            pattern = [[0, 0], [0, 0], padding[0], padding[1]];
        }
        else {
            pattern = [[0, 0], padding[0], padding[1], [0, 0]];
        }
        return pad(x, pattern);
    });
}
/**
 * Zero-padding layer for 2D input (e.g., images).
 *
 * `padding` may be a single integer (symmetric padding on both spatial
 * dims), a length-2 array of integers (symmetric per dim), or a length-2
 * array of `[before, after]` pairs (explicit per dim).
 */
class ZeroPadding2D extends Layer {
    constructor(args) {
        if (args == null) {
            args = {};
        }
        super(args);
        this.dataFormat =
            args.dataFormat == null ? imageDataFormat() : args.dataFormat;
        // TODO(cais): Maybe refactor the following logic surrounding `padding`
        //   into a helper method.
        // Normalize `padding` to [[top, bottom], [left, right]].
        if (args.padding == null) {
            this.padding = [[1, 1], [1, 1]];
        }
        else if (typeof args.padding === 'number') {
            this.padding =
                [[args.padding, args.padding], [args.padding, args.padding]];
        }
        else {
            // Removed two no-op `args.padding = args.padding;` self-assignments
            // (residue of TypeScript type casts).
            if (args.padding.length !== 2) {
                throw new ValueError(`ZeroPadding2D expects padding to be a length-2 array, but ` +
                    `received a length-${args.padding.length} array.`);
            }
            let heightPadding;
            let widthPadding;
            if (typeof args.padding[0] === 'number') {
                // Symmetric padding per spatial dimension.
                heightPadding = [args.padding[0], args.padding[0]];
                widthPadding = [args.padding[1], args.padding[1]];
            }
            else {
                // Explicit [before, after] pairs per spatial dimension.
                if (args.padding[0].length !== 2) {
                    throw new ValueError(`ZeroPadding2D expects height padding to be a length-2 array, ` +
                        `but received a length-${args.padding[0].length} array.`);
                }
                heightPadding = args.padding[0];
                if (args.padding[1].length !== 2) {
                    throw new ValueError(`ZeroPadding2D expects width padding to be a length-2 array, ` +
                        `but received a length-${args.padding[1].length} array.`);
                }
                widthPadding = args.padding[1];
            }
            this.padding = [heightPadding, widthPadding];
        }
        this.inputSpec = [new InputSpec({ ndim: 4 })];
    }
    /**
     * Output shape: spatial dims grow by the total padding; unknown (null)
     * dims stay unknown.
     */
    computeOutputShape(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        let rows;
        let cols;
        if (this.dataFormat === 'channelsFirst') {
            if (inputShape[2] != null && inputShape[2] >= 0) {
                rows = inputShape[2] + this.padding[0][0] + this.padding[0][1];
            }
            else {
                rows = null;
            }
            if (inputShape[3] != null && inputShape[3] >= 0) {
                cols = inputShape[3] + this.padding[1][0] + this.padding[1][1];
            }
            else {
                cols = null;
            }
            return [inputShape[0], inputShape[1], rows, cols];
        }
        else {
            if (inputShape[1] != null && inputShape[1] >= 0) {
                rows = inputShape[1] + this.padding[0][0] + this.padding[0][1];
            }
            else {
                rows = null;
            }
            if (inputShape[2] != null && inputShape[2] >= 0) {
                cols = inputShape[2] + this.padding[1][0] + this.padding[1][1];
            }
            else {
                cols = null;
            }
            return [inputShape[0], rows, cols, inputShape[3]];
        }
    }
    call(inputs, kwargs) {
        return tidy(() => spatial2dPadding(getExactlyOneTensor(inputs), this.padding, this.dataFormat));
    }
    getConfig() {
        const config = {
            padding: this.padding,
            dataFormat: this.dataFormat,
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}
/** @nocollapse */
ZeroPadding2D.className = 'ZeroPadding2D';
registerClass(ZeroPadding2D);
52380
52381 /**
52382 * @license
52383 * Copyright 2018 Google LLC
52384 *
52385 * Use of this source code is governed by an MIT-style
52386 * license that can be found in the LICENSE file or at
52387 * https://opensource.org/licenses/MIT.
52388 * =============================================================================
52389 */
/**
 * 2D pooling.
 * @param x Input tensor.
 * @param poolSize Size of the pooling window.
 * @param strides strides. Defaults to [1, 1].
 * @param padding padding. Defaults to 'valid'.
 * @param dataFormat data format. Defaults to 'channelsLast'.
 * @param poolMode Mode of pooling ('max' or 'avg'). Defaults to 'max'.
 * @returns Result of the 2D pooling.
 */
function pool2d(x, poolSize, strides, padding, dataFormat, poolMode) {
    return tidy(() => {
        // Validate the raw arguments before applying defaults, as the
        // checkers accept null/undefined.
        checkDataFormat(dataFormat);
        checkPoolMode(poolMode);
        checkPaddingMode(padding);
        const strides_ = strides == null ? [1, 1] : strides;
        const padding_ = padding == null ? 'valid' : padding;
        const dataFormat_ = dataFormat == null ? imageDataFormat() : dataFormat;
        const poolMode_ = poolMode == null ? 'max' : poolMode;
        // TODO(cais): Remove the preprocessing step once deeplearn.js supports
        //   dataFormat as an input argument.
        // After preprocessing the tensor is laid out as NHWC.
        const input = preprocessConv2DInput(x, dataFormat_);
        const paddingString = padding_ === 'same' ? 'same' : 'valid';
        // TODO(cais): Rank/dtype check with a clear error message?
        const pooled = poolMode_ === 'max' ?
            maxPool$2(input, poolSize, strides_, paddingString) :
            avgPool$2(input, poolSize, strides_, paddingString);
        // NHWC -> NCHW if the caller expects channels-first output.
        return dataFormat_ === 'channelsFirst' ?
            transpose$2(pooled, [0, 3, 1, 2]) :
            pooled;
    });
}
/**
 * 3D pooling.
 * @param x Input tensor.
 * @param poolSize Size of the pooling window.
 * @param strides strides. Defaults to [1, 1, 1].
 * @param padding padding. Defaults to 'valid'.
 * @param dataFormat data format. Defaults to 'channelsLast'.
 * @param poolMode Mode of pooling ('max' or 'avg'). Defaults to 'max'.
 * @returns Result of the 3D pooling.
 */
function pool3d$1(x, poolSize, strides, padding, dataFormat, poolMode) {
    return tidy(() => {
        // Validate the raw arguments before applying defaults, as the
        // checkers accept null/undefined.
        checkDataFormat(dataFormat);
        checkPoolMode(poolMode);
        checkPaddingMode(padding);
        const strides_ = strides == null ? [1, 1, 1] : strides;
        const padding_ = padding == null ? 'valid' : padding;
        const dataFormat_ = dataFormat == null ? imageDataFormat() : dataFormat;
        const poolMode_ = poolMode == null ? 'max' : poolMode;
        // After preprocessing the tensor is laid out as NDHWC.
        const input = preprocessConv3DInput(x, dataFormat_);
        const paddingString = padding_ === 'same' ? 'same' : 'valid';
        const pooled = poolMode_ === 'max' ?
            maxPool3d$1(input, poolSize, strides_, paddingString) :
            avgPool3d$1(input, poolSize, strides_, paddingString);
        // NDHWC -> NCDHW if the caller expects channels-first output.
        return dataFormat_ === 'channelsFirst' ?
            transpose$2(pooled, [0, 4, 1, 2, 3]) :
            pooled;
    });
}
/**
 * Abstract class for different pooling 1D layers.
 *
 * Concrete subclasses provide `poolingFunction` (max or average). The 1D
 * pooling is implemented by adding a dummy spatial dimension and delegating
 * to the 2D pooling routine.
 */
class Pooling1D extends Layer {
    /**
     *
     * @param args Parameters for the Pooling layer.
     *
     * config.poolSize defaults to 2.
     */
    constructor(args) {
        // Mutate `args` before `super(args)` so the base class sees the
        // defaulted pool size.
        if (args.poolSize == null) {
            args.poolSize = 2;
        }
        super(args);
        // Normalize `poolSize` to a single-element array of a number.
        if (typeof args.poolSize === 'number') {
            this.poolSize = [args.poolSize];
        }
        else if (Array.isArray(args.poolSize) &&
            args.poolSize.length === 1 &&
            typeof args.poolSize[0] === 'number') {
            this.poolSize = args.poolSize;
        }
        else {
            // NOTE(review): the message says "convolutional" although this is
            // a pooling layer; kept as-is to preserve the exact error text.
            throw new ValueError(`poolSize for 1D convolutional layer must be a number or an ` +
                `Array of a single number, but received ` +
                `${JSON.stringify(args.poolSize)}`);
        }
        assertPositiveInteger(this.poolSize, 'poolSize');
        if (args.strides == null) {
            // Strides default to the pool size (non-overlapping windows).
            this.strides = this.poolSize;
        }
        else {
            // Otherwise normalize `strides` the same way as `poolSize`.
            if (typeof args.strides === 'number') {
                this.strides = [args.strides];
            }
            else if (Array.isArray(args.strides) &&
                args.strides.length === 1 &&
                typeof args.strides[0] === 'number') {
                this.strides = args.strides;
            }
            else {
                throw new ValueError(`strides for 1D convolutional layer must be a number or an ` +
                    `Array of a single number, but received ` +
                    `${JSON.stringify(args.strides)}`);
            }
        }
        assertPositiveInteger(this.strides, 'strides');
        this.padding = args.padding == null ? 'valid' : args.padding;
        checkPaddingMode(this.padding);
        // Inputs are rank-3: [batch, steps, features].
        this.inputSpec = [new InputSpec({ ndim: 3 })];
    }
    /** Output shape is [batch, pooledSteps, features]. */
    computeOutputShape(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        const length = convOutputLength(inputShape[1], this.poolSize[0], this.padding, this.strides[0]);
        return [inputShape[0], length, inputShape[2]];
    }
    call(inputs, kwargs) {
        return tidy(() => {
            this.invokeCallHook(inputs, kwargs);
            // Add dummy last dimension.
            inputs = expandDims$2(getExactlyOneTensor(inputs), 2);
            const output = this.poolingFunction(getExactlyOneTensor(inputs), [this.poolSize[0], 1], [this.strides[0], 1], this.padding, 'channelsLast');
            // Remove dummy last dimension.
            return squeeze(output, [2]);
        });
    }
    getConfig() {
        const config = {
            poolSize: this.poolSize,
            padding: this.padding,
            strides: this.strides,
        };
        // Base-class keys override in case of collision (Object.assign order).
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}
/**
 * Max pooling for temporal (1D) inputs. Delegates to `pool2d` with
 * poolMode 'max'; the base class supplies the dummy spatial dimension.
 */
class MaxPooling1D extends Pooling1D {
    // The implicit constructor forwards `args` to Pooling1D unchanged.
    poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
        // Validate before delegating to the backend pooling helper.
        checkDataFormat(dataFormat);
        checkPaddingMode(padding);
        return pool2d(inputs, poolSize, strides, padding, dataFormat, 'max');
    }
}
/** @nocollapse */
MaxPooling1D.className = 'MaxPooling1D';
registerClass(MaxPooling1D);
/**
 * Average pooling for temporal (1D) inputs. Delegates to `pool2d` with
 * poolMode 'avg'; the base class supplies the dummy spatial dimension.
 */
class AveragePooling1D extends Pooling1D {
    // The implicit constructor forwards `args` to Pooling1D unchanged.
    poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
        // Validate before delegating to the backend pooling helper.
        checkDataFormat(dataFormat);
        checkPaddingMode(padding);
        return pool2d(inputs, poolSize, strides, padding, dataFormat, 'avg');
    }
}
/** @nocollapse */
AveragePooling1D.className = 'AveragePooling1D';
registerClass(AveragePooling1D);
/**
 * Abstract class for different pooling 2D layers.
 *
 * Concrete subclasses provide `poolingFunction` (max or average).
 */
class Pooling2D extends Layer {
    constructor(args) {
        // Mutate `args` before `super(args)` so the base class sees the
        // defaulted pool size.
        if (args.poolSize == null) {
            args.poolSize = [2, 2];
        }
        super(args);
        // Normalize a scalar poolSize into [size, size].
        // NOTE(review): an Array poolSize is accepted without a length check
        // here, unlike `strides` below.
        this.poolSize = Array.isArray(args.poolSize) ?
            args.poolSize :
            [args.poolSize, args.poolSize];
        if (args.strides == null) {
            // Strides default to the pool size (non-overlapping windows).
            this.strides = this.poolSize;
        }
        else if (Array.isArray(args.strides)) {
            if (args.strides.length !== 2) {
                throw new ValueError(`If the strides property of a 2D pooling layer is an Array, ` +
                    `it is expected to have a length of 2, but received length ` +
                    `${args.strides.length}.`);
            }
            this.strides = args.strides;
        }
        else {
            // `config.strides` is a number.
            this.strides = [args.strides, args.strides];
        }
        assertPositiveInteger(this.poolSize, 'poolSize');
        assertPositiveInteger(this.strides, 'strides');
        this.padding = args.padding == null ? 'valid' : args.padding;
        this.dataFormat =
            args.dataFormat == null ? 'channelsLast' : args.dataFormat;
        checkDataFormat(this.dataFormat);
        checkPaddingMode(this.padding);
        // Inputs are rank-4: NHWC or NCHW depending on dataFormat.
        this.inputSpec = [new InputSpec({ ndim: 4 })];
    }
    computeOutputShape(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        // Pick the spatial dims according to the data format.
        let rows = this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1];
        let cols = this.dataFormat === 'channelsFirst' ? inputShape[3] : inputShape[2];
        rows =
            convOutputLength(rows, this.poolSize[0], this.padding, this.strides[0]);
        cols =
            convOutputLength(cols, this.poolSize[1], this.padding, this.strides[1]);
        if (this.dataFormat === 'channelsFirst') {
            return [inputShape[0], inputShape[1], rows, cols];
        }
        else {
            return [inputShape[0], rows, cols, inputShape[3]];
        }
    }
    call(inputs, kwargs) {
        return tidy(() => {
            this.invokeCallHook(inputs, kwargs);
            return this.poolingFunction(getExactlyOneTensor(inputs), this.poolSize, this.strides, this.padding, this.dataFormat);
        });
    }
    getConfig() {
        const config = {
            poolSize: this.poolSize,
            padding: this.padding,
            strides: this.strides,
            dataFormat: this.dataFormat
        };
        // Base-class keys override in case of collision (Object.assign order).
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}
/**
 * Max pooling over 2D spatial dimensions. Delegates to `pool2d` with
 * poolMode 'max'.
 */
class MaxPooling2D extends Pooling2D {
    // The implicit constructor forwards `args` to Pooling2D unchanged.
    poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
        // Validate before delegating to the backend pooling helper.
        checkDataFormat(dataFormat);
        checkPaddingMode(padding);
        return pool2d(inputs, poolSize, strides, padding, dataFormat, 'max');
    }
}
/** @nocollapse */
MaxPooling2D.className = 'MaxPooling2D';
registerClass(MaxPooling2D);
/**
 * Average pooling over 2D spatial dimensions. Delegates to `pool2d` with
 * poolMode 'avg'.
 */
class AveragePooling2D extends Pooling2D {
    // The implicit constructor forwards `args` to Pooling2D unchanged.
    poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
        // Validate before delegating to the backend pooling helper.
        checkDataFormat(dataFormat);
        checkPaddingMode(padding);
        return pool2d(inputs, poolSize, strides, padding, dataFormat, 'avg');
    }
}
/** @nocollapse */
AveragePooling2D.className = 'AveragePooling2D';
registerClass(AveragePooling2D);
/**
 * Abstract class for different pooling 3D layers.
 *
 * Concrete subclasses provide `poolingFunction` (max or average).
 */
class Pooling3D extends Layer {
    constructor(args) {
        // Mutate `args` before `super(args)` so the base class sees the
        // defaulted pool size.
        if (args.poolSize == null) {
            args.poolSize = [2, 2, 2];
        }
        super(args);
        // Normalize a scalar poolSize into [size, size, size].
        // NOTE(review): an Array poolSize is accepted without a length check
        // here, unlike `strides` below.
        this.poolSize = Array.isArray(args.poolSize) ?
            args.poolSize :
            [args.poolSize, args.poolSize, args.poolSize];
        if (args.strides == null) {
            // Strides default to the pool size (non-overlapping windows).
            this.strides = this.poolSize;
        }
        else if (Array.isArray(args.strides)) {
            if (args.strides.length !== 3) {
                throw new ValueError(`If the strides property of a 3D pooling layer is an Array, ` +
                    `it is expected to have a length of 3, but received length ` +
                    `${args.strides.length}.`);
            }
            this.strides = args.strides;
        }
        else {
            // `config.strides` is a number.
            this.strides = [args.strides, args.strides, args.strides];
        }
        assertPositiveInteger(this.poolSize, 'poolSize');
        assertPositiveInteger(this.strides, 'strides');
        this.padding = args.padding == null ? 'valid' : args.padding;
        this.dataFormat =
            args.dataFormat == null ? 'channelsLast' : args.dataFormat;
        checkDataFormat(this.dataFormat);
        checkPaddingMode(this.padding);
        // Inputs are rank-5: NDHWC or NCDHW depending on dataFormat.
        this.inputSpec = [new InputSpec({ ndim: 5 })];
    }
    computeOutputShape(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        // Pick the spatial dims according to the data format.
        let depths = this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1];
        let rows = this.dataFormat === 'channelsFirst' ? inputShape[3] : inputShape[2];
        let cols = this.dataFormat === 'channelsFirst' ? inputShape[4] : inputShape[3];
        depths = convOutputLength(depths, this.poolSize[0], this.padding, this.strides[0]);
        rows =
            convOutputLength(rows, this.poolSize[1], this.padding, this.strides[1]);
        cols =
            convOutputLength(cols, this.poolSize[2], this.padding, this.strides[2]);
        if (this.dataFormat === 'channelsFirst') {
            return [inputShape[0], inputShape[1], depths, rows, cols];
        }
        else {
            return [inputShape[0], depths, rows, cols, inputShape[4]];
        }
    }
    call(inputs, kwargs) {
        return tidy(() => {
            this.invokeCallHook(inputs, kwargs);
            return this.poolingFunction(getExactlyOneTensor(inputs), this.poolSize, this.strides, this.padding, this.dataFormat);
        });
    }
    getConfig() {
        const config = {
            poolSize: this.poolSize,
            padding: this.padding,
            strides: this.strides,
            dataFormat: this.dataFormat
        };
        // Base-class keys override in case of collision (Object.assign order).
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}
/**
 * Max pooling over 3D spatial dimensions. Delegates to `pool3d$1` with
 * poolMode 'max'.
 */
class MaxPooling3D extends Pooling3D {
    // The implicit constructor forwards `args` to Pooling3D unchanged.
    poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
        // Validate before delegating to the backend pooling helper.
        checkDataFormat(dataFormat);
        checkPaddingMode(padding);
        return pool3d$1(inputs, poolSize, strides, padding, dataFormat, 'max');
    }
}
/** @nocollapse */
MaxPooling3D.className = 'MaxPooling3D';
registerClass(MaxPooling3D);
/**
 * Average pooling over 3D spatial dimensions. Delegates to `pool3d$1` with
 * poolMode 'avg'.
 */
class AveragePooling3D extends Pooling3D {
    // The implicit constructor forwards `args` to Pooling3D unchanged.
    poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
        // Validate before delegating to the backend pooling helper.
        checkDataFormat(dataFormat);
        checkPaddingMode(padding);
        return pool3d$1(inputs, poolSize, strides, padding, dataFormat, 'avg');
    }
}
/** @nocollapse */
AveragePooling3D.className = 'AveragePooling3D';
registerClass(AveragePooling3D);
/**
 * Abstract class for different global pooling 1D layers.
 *
 * Subclasses reduce the whole temporal dimension to a single value per
 * feature, so [batch, steps, features] becomes [batch, features].
 */
class GlobalPooling1D extends Layer {
    constructor(args) {
        super(args);
        // Expects rank-3 inputs: [batch, steps, features].
        this.inputSpec = [new InputSpec({ ndim: 3 })];
    }
    computeOutputShape(inputShape) {
        // The temporal dimension (axis 1) is pooled away.
        const batchDim = inputShape[0];
        const featureDim = inputShape[2];
        return [batchDim, featureDim];
    }
    call(inputs, kwargs) {
        // Concrete subclasses implement the actual reduction.
        throw new NotImplementedError();
    }
}
/**
 * Global average pooling over the temporal dimension.
 */
class GlobalAveragePooling1D extends GlobalPooling1D {
    constructor(args) {
        // `args` is optional for this layer.
        super(args || {});
    }
    call(inputs, kwargs) {
        // Mean over axis 1: [batch, steps, features] -> [batch, features].
        return tidy(() => mean$3(getExactlyOneTensor(inputs), 1));
    }
}
/** @nocollapse */
GlobalAveragePooling1D.className = 'GlobalAveragePooling1D';
registerClass(GlobalAveragePooling1D);
/**
 * Global max pooling over the temporal dimension.
 */
class GlobalMaxPooling1D extends GlobalPooling1D {
    constructor(args) {
        // `args` is optional for this layer.
        super(args || {});
    }
    call(inputs, kwargs) {
        // Max over axis 1: [batch, steps, features] -> [batch, features].
        return tidy(() => max$3(getExactlyOneTensor(inputs), 1));
    }
}
/** @nocollapse */
GlobalMaxPooling1D.className = 'GlobalMaxPooling1D';
registerClass(GlobalMaxPooling1D);
/**
 * Abstract class for different global pooling 2D layers.
 *
 * Subclasses reduce both spatial dimensions to a single value per channel,
 * producing [batch, channels].
 */
class GlobalPooling2D extends Layer {
    constructor(args) {
        super(args);
        this.dataFormat =
            args.dataFormat == null ? 'channelsLast' : args.dataFormat;
        checkDataFormat(this.dataFormat);
        // Expects rank-4 inputs: NHWC or NCHW depending on dataFormat.
        this.inputSpec = [new InputSpec({ ndim: 4 })];
    }
    computeOutputShape(inputShape) {
        // Keep the batch dim and the channel dim; spatial dims are pooled away.
        const channelAxis = this.dataFormat === 'channelsLast' ? 3 : 1;
        return [inputShape[0], inputShape[channelAxis]];
    }
    call(inputs, kwargs) {
        // Concrete subclasses implement the actual reduction.
        throw new NotImplementedError();
    }
    getConfig() {
        const config = { dataFormat: this.dataFormat };
        // Base-class keys override in case of collision (Object.assign order).
        Object.assign(config, super.getConfig());
        return config;
    }
}
/**
 * Global average pooling over both spatial dimensions.
 */
class GlobalAveragePooling2D extends GlobalPooling2D {
    call(inputs, kwargs) {
        return tidy(() => {
            const input = getExactlyOneTensor(inputs);
            // Mean over the spatial axes, leaving [batch, channels].
            const spatialAxes = this.dataFormat === 'channelsLast' ? [1, 2] : [2, 3];
            return mean$3(input, spatialAxes);
        });
    }
}
/** @nocollapse */
GlobalAveragePooling2D.className = 'GlobalAveragePooling2D';
registerClass(GlobalAveragePooling2D);
/**
 * Global max pooling over both spatial dimensions.
 */
class GlobalMaxPooling2D extends GlobalPooling2D {
    call(inputs, kwargs) {
        return tidy(() => {
            const input = getExactlyOneTensor(inputs);
            // Max over the spatial axes, leaving [batch, channels].
            const spatialAxes = this.dataFormat === 'channelsLast' ? [1, 2] : [2, 3];
            return max$3(input, spatialAxes);
        });
    }
}
/** @nocollapse */
GlobalMaxPooling2D.className = 'GlobalMaxPooling2D';
registerClass(GlobalMaxPooling2D);
52883
52884 /**
52885 * @license
52886 * Copyright 2018 Google LLC
52887 *
52888 * Use of this source code is governed by an MIT-style
52889 * license that can be found in the LICENSE file or at
52890 * https://opensource.org/licenses/MIT.
52891 * =============================================================================
52892 */
/**
 * Abstract wrapper base class.
 *
 * Wrappers take another layer and augment it in various ways.
 * Do not use this class as a layer, it is only an abstract base class.
 * Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers.
 *
 * Most accessors below simply delegate to the wrapped `this.layer`.
 */
class Wrapper extends Layer {
    constructor(args) {
        // Porting Note: In PyKeras, `self.layer` is set prior to the calling
        // `super()`. But we can't do that here due to TypeScript's restriction.
        // See: https://github.com/Microsoft/TypeScript/issues/8277
        // As a result, we have to add checks in `get trainable()` and
        // `set trainable()` below in order to prevent using `this.layer` when
        // its value is `undefined`. The super constructor does use the getter
        // and the setter of `this.layer`.
        super(args);
        this.layer = args.layer;
    }
    build(inputShape) {
        // The wrapper itself has no weights to build by default; subclasses
        // build the wrapped layer as needed.
        this.built = true;
    }
    // TODO(cais): Implement activityRegularizer getter.
    get trainable() {
        // Porting Note: the check of `this.layer` here is necessary due to the
        // way the `constructor` of this class is written (see Porting Note
        // above).
        if (this.layer != null) {
            return this.layer.trainable;
        }
        else {
            return false;
        }
    }
    set trainable(value) {
        // Porting Note: the check of `this.layer` here is necessary due to the
        // way the `constructor` of this class is written (see Porting Note
        // above).
        if (this.layer != null) {
            this.layer.trainable = value;
        }
    }
    get trainableWeights() {
        return this.layer.trainableWeights;
    }
    // TODO(cais): Implement setter for trainableWeights.
    get nonTrainableWeights() {
        return this.layer.nonTrainableWeights;
    }
    // TODO(cais): Implement setter for nonTrainableWeights.
    get updates() {
        // tslint:disable-next-line:no-any
        return this.layer._updates;
    }
    // TODO(cais): Implement getUpdatesFor().
    get losses() {
        return this.layer.losses;
    }
    // TODO(cais): Implement getLossesFor().
    getWeights() {
        return this.layer.getWeights();
    }
    setWeights(weights) {
        this.layer.setWeights(weights);
    }
    getConfig() {
        // Serialize the wrapped layer as a nested {className, config} dict.
        const config = {
            'layer': {
                'className': this.layer.getClassName(),
                'config': this.layer.getConfig(),
            }
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
    setFastWeightInitDuringBuild(value) {
        super.setFastWeightInitDuringBuild(value);
        // Propagate to the wrapped layer if it is already set.
        if (this.layer != null) {
            this.layer.setFastWeightInitDuringBuild(value);
        }
    }
    /** @nocollapse */
    static fromConfig(cls, config, customObjects = {}) {
        // Deserialize the nested layer first, then hand the remaining config
        // (with the live layer object substituted in) to the constructor.
        const layerConfig = config['layer'];
        const layer = deserialize(layerConfig, customObjects);
        delete config['layer'];
        const newConfig = { layer };
        Object.assign(newConfig, config);
        return new cls(newConfig);
    }
}
/**
 * Wrapper that applies the wrapped layer to every temporal slice of the
 * input independently (axis 1 is treated as time).
 */
class TimeDistributed extends Wrapper {
    constructor(args) {
        super(args);
        this.supportsMasking = true;
    }
    build(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        if (inputShape.length < 3) {
            throw new ValueError(`TimeDistributed layer expects an input shape >= 3D, but received ` +
                `input shape ${JSON.stringify(inputShape)}`);
        }
        this.inputSpec = [{ shape: inputShape }];
        // The wrapped layer sees the input with the time axis removed.
        const childInputShape = [inputShape[0]].concat(inputShape.slice(2));
        if (!this.layer.built) {
            this.layer.build(childInputShape);
            this.layer.built = true;
        }
        super.build(inputShape);
    }
    computeOutputShape(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        // Compute the wrapped layer's output shape on a single time slice,
        // then re-insert the time axis at position 1.
        const childInputShape = [inputShape[0]].concat(inputShape.slice(2));
        const childOutputShape = this.layer.computeOutputShape(childInputShape);
        const timesteps = inputShape[1];
        return [childOutputShape[0], timesteps].concat(childOutputShape.slice(1));
    }
    call(inputs, kwargs) {
        return tidy(() => {
            // TODO(cais): Add 'training' and 'useLearningPhase' to kwargs.
            inputs = getExactlyOneTensor(inputs);
            // Porting Note: In tfjs-layers, `inputs` are always concrete tensor
            // values. Hence the inputs can't have an undetermined first (batch)
            // dimension, which is why we always use the K.rnn approach here.
            const step = (inputs, states) => {
                // TODO(cais): Add useLearningPhase.
                // NOTE(cais): `layer.call` may return a length-1 array of Tensor in
                // some cases (e.g., `layer` is a `Sequential` instance), which is
                // why `getExactlyOneTensor` is used below.
                const output = getExactlyOneTensor(this.layer.call(inputs, kwargs));
                // Stateless step: no carried states between time steps.
                return [output, []];
            };
            const rnnOutputs = rnn$1(step, inputs, [], false /* goBackwards */, null /* mask */, null /* constants */, false /* unroll */, true /* needPerStepOutputs */);
            // rnnOutputs[1] holds the per-step outputs stacked along time.
            const y = rnnOutputs[1];
            // TODO(cais): Add activity regularization.
            // TODO(cais): Add useLearningPhase.
            return y;
        });
    }
}
/** @nocollapse */
TimeDistributed.className = 'TimeDistributed';
registerClass(TimeDistributed);
/**
 * Throws if `value` is not a valid bidirectional merge mode
 * (per VALID_BIDIRECTIONAL_MERGE_MODES).
 */
function checkBidirectionalMergeMode(value) {
    checkStringTypeUnionValue(VALID_BIDIRECTIONAL_MERGE_MODES, 'BidirectionalMergeMode', value);
}
// Merge mode used by Bidirectional when the caller does not specify one.
const DEFAULT_BIDIRECTIONAL_MERGE_MODE = 'concat';
/**
 * Bidirectional wrapper: runs the wrapped RNN layer forwards and backwards
 * over the input and merges the two outputs according to `mergeMode`
 * ('concat' | 'sum' | 'ave' | 'mul' | null for a pair of outputs).
 */
class Bidirectional extends Wrapper {
    constructor(args) {
        super(args);
        // Note: When creating `this.forwardLayer`, the original Layer object
        // (`config.layer`) ought to be cloned. This is why we call
        // `getConfig()` followed by `deserialize()`. Without this cloning,
        // the layer names saved during serialization will incorrectly contain
        // the 'forward_' prefix. In Python Keras, this is done using
        // `copy.copy` (shallow copy), which does not have a simple equivalent
        // in JavaScript. JavaScript's `Object.assign()` does not copy
        // methods.
        const layerConfig = args.layer.getConfig();
        const forwDict = {};
        forwDict['className'] = args.layer.getClassName();
        forwDict['config'] = layerConfig;
        this.forwardLayer = deserialize(forwDict);
        // Flip goBackwards for the backward clone. (The same `layerConfig`
        // object is mutated after the forward layer has been deserialized.)
        layerConfig['goBackwards'] =
            layerConfig['goBackwards'] === true ? false : true;
        const backDict = {};
        backDict['className'] = args.layer.getClassName();
        backDict['config'] = layerConfig;
        this.backwardLayer = deserialize(backDict);
        this.forwardLayer.name = 'forward_' + this.forwardLayer.name;
        this.backwardLayer.name = 'backward_' + this.backwardLayer.name;
        // Only an explicit `undefined` falls back to the default; an explicit
        // null mergeMode (meaning "return both outputs") is preserved.
        this.mergeMode = args.mergeMode === undefined ?
            DEFAULT_BIDIRECTIONAL_MERGE_MODE :
            args.mergeMode;
        checkBidirectionalMergeMode(this.mergeMode);
        if (args.weights) {
            throw new NotImplementedError('weights support is not implemented for Bidirectional layer yet.');
        }
        // Mirror the wrapped layer's RNN-related flags.
        this._stateful = args.layer.stateful;
        this.returnSequences = args.layer.returnSequences;
        this.returnState = args.layer.returnState;
        this.supportsMasking = true;
        this._trainable = true;
        this.inputSpec = args.layer.inputSpec;
        this.numConstants = null;
    }
    get trainable() {
        return this._trainable;
    }
    set trainable(value) {
        // Porting Note: the check of `this.layer` here is necessary due to the
        // way the `constructor` of this class is written (see Porting Note
        // above).
        this._trainable = value;
        if (this.forwardLayer != null) {
            this.forwardLayer.trainable = value;
        }
        if (this.backwardLayer != null) {
            this.backwardLayer.trainable = value;
        }
    }
    getWeights() {
        // Forward weights first, then backward — setWeights relies on this
        // ordering.
        return this.forwardLayer.getWeights().concat(this.backwardLayer.getWeights());
    }
    setWeights(weights) {
        // First half goes to the forward layer, second half to the backward
        // layer (matching getWeights' ordering).
        const numWeights = weights.length;
        const numeightsOver2 = Math.floor(numWeights / 2);
        this.forwardLayer.setWeights(weights.slice(0, numeightsOver2));
        this.backwardLayer.setWeights(weights.slice(numeightsOver2));
    }
    computeOutputShape(inputShape) {
        let layerShapes = this.forwardLayer.computeOutputShape(inputShape);
        // Normalize to an Array of shapes.
        if (!(Array.isArray(layerShapes) && Array.isArray(layerShapes[0]))) {
            layerShapes = [layerShapes];
        }
        layerShapes = layerShapes;
        let outputShape;
        let outputShapes;
        let stateShape;
        if (this.returnState) {
            // Shapes after the first are the state shapes.
            stateShape = layerShapes.slice(1);
            outputShape = layerShapes[0];
        }
        else {
            outputShape = layerShapes[0];
        }
        outputShape = outputShape;
        if (this.mergeMode === 'concat') {
            // Concatenation doubles the last (feature) dimension.
            outputShape[outputShape.length - 1] *= 2;
            outputShapes = [outputShape];
        }
        else if (this.mergeMode == null) {
            // No merging: forward and backward outputs are returned separately.
            outputShapes = [outputShape, outputShape.slice()];
        }
        else {
            outputShapes = [outputShape];
        }
        if (this.returnState) {
            // States are duplicated: one set per direction.
            if (this.mergeMode == null) {
                return outputShapes.concat(stateShape).concat(stateShape.slice());
            }
            return [outputShape].concat(stateShape).concat(stateShape.slice());
        }
        return singletonOrArray(outputShapes);
    }
    apply(inputs, kwargs) {
        let initialState = kwargs == null ? null : kwargs['initialState'];
        let constants = kwargs == null ? null : kwargs['constants'];
        if (kwargs == null) {
            kwargs = {};
        }
        const standardized = standardizeArgs(inputs, initialState, constants, this.numConstants);
        inputs = standardized.inputs;
        initialState = standardized.initialState;
        constants = standardized.constants;
        // If `inputs` is an Array, the trailing entries are initial states.
        if (Array.isArray(inputs)) {
            initialState = inputs.slice(1);
            inputs = inputs[0];
        }
        // Without initial states or constants, fall through to the plain
        // Layer.apply path.
        if ((initialState == null || initialState.length === 0) &&
            constants == null) {
            return super.apply(inputs, kwargs);
        }
        const additionalInputs = [];
        const additionalSpecs = [];
        if (initialState != null) {
            const numStates = initialState.length;
            // States must split evenly between the two directions.
            if (numStates % 2 > 0) {
                throw new ValueError('When passing `initialState` to a Bidrectional RNN, ' +
                    'the state should be an Array containing the states of ' +
                    'the underlying RNNs.');
            }
            kwargs['initialState'] = initialState;
            additionalInputs.push(...initialState);
            const stateSpecs = initialState
                .map(state => new InputSpec({ shape: state.shape }));
            // First half of the states belongs to the forward layer, second
            // half to the backward layer.
            this.forwardLayer.stateSpec = stateSpecs.slice(0, numStates / 2);
            this.backwardLayer.stateSpec = stateSpecs.slice(numStates / 2);
            additionalSpecs.push(...stateSpecs);
        }
        if (constants != null) {
            throw new NotImplementedError('Support for constants in Bidirectional layers is not ' +
                'implemented yet.');
        }
        // All additional inputs must be either all symbolic or all concrete.
        const isSymbolicTensor = additionalInputs[0] instanceof SymbolicTensor;
        for (const tensor of additionalInputs) {
            if (tensor instanceof SymbolicTensor !== isSymbolicTensor) {
                throw new ValueError('The initial state of a Bidirectional layer cannot be ' +
                    'specified as a mix of symbolic and non-symbolic tensors');
            }
        }
        if (isSymbolicTensor) {
            // Compute the full input and specs, including the states.
            const fullInput = [inputs].concat(additionalInputs);
            const fullInputSpec = this.inputSpec.concat(additionalSpecs);
            // Perform the call temporarily and replace inputSpec.
            // Note: with initial states symbolic calls and non-symbolic calls to
            // this method differ in how the initial states are passed. For
            // symbolic calls, the initial states are passed in the first arg, as
            // an Array of SymbolicTensors; for non-symbolic calls, they are
            // passed in the second arg as a part of the kwargs. Hence the need to
            // temporarily modify inputSpec here.
            // TODO(cais): Make refactoring so that this hacky code below is no
            // longer needed.
            const originalInputSpec = this.inputSpec;
            this.inputSpec = fullInputSpec;
            const output = super.apply(fullInput, kwargs);
            this.inputSpec = originalInputSpec;
            return output;
        }
        else {
            return super.apply(inputs, kwargs);
        }
    }
    call(inputs, kwargs) {
        return tidy(() => {
            const initialState = kwargs['initialState'];
            let y;
            let yRev;
            if (initialState == null) {
                y = this.forwardLayer.call(inputs, kwargs);
                yRev = this.backwardLayer.call(inputs, kwargs);
            }
            else {
                // Split the states between the two directions.
                const forwardState = initialState.slice(0, initialState.length / 2);
                const backwardState = initialState.slice(initialState.length / 2);
                y = this.forwardLayer.call(inputs, Object.assign(kwargs, { initialState: forwardState }));
                yRev = this.backwardLayer.call(inputs, Object.assign(kwargs, { initialState: backwardState }));
            }
            let states;
            if (this.returnState) {
                if (Array.isArray(y)) {
                    states = y.slice(1).concat(yRev.slice(1));
                }
                else {
                    // NOTE(review): non-Array case intentionally left unhandled
                    // here (matches upstream); `states` stays undefined.
                }
                y = y[0];
                yRev = yRev[0];
            }
            if (this.returnSequences) {
                // Flip the backward output along time so both sequences align.
                yRev = reverse$2(yRev, 1);
            }
            let output;
            if (this.mergeMode === 'concat') {
                output = concatenate$2([y, yRev]);
            }
            else if (this.mergeMode === 'sum') {
                output = add$3(y, yRev);
            }
            else if (this.mergeMode === 'ave') {
                output = mul(.5, add$3(y, yRev));
            }
            else if (this.mergeMode === 'mul') {
                output = mul(y, yRev);
            }
            else if (this.mergeMode == null) {
                // No merging: return both directions' outputs.
                output = [y, yRev];
            }
            // TODO(cais): Properly set learning phase.
            if (this.returnState) {
                if (this.mergeMode == null) {
                    return output.concat(states);
                }
                return [output].concat(states);
            }
            return output;
        });
    }
    resetStates(states) {
        // NOTE(review): the `states` argument is ignored; both sublayers are
        // reset to their defaults.
        this.forwardLayer.resetStates();
        this.backwardLayer.resetStates();
    }
    build(inputShape) {
        // Build each direction under its own name scope so weight names get
        // the forward_/backward_ prefixes.
        nameScope(this.forwardLayer.name, () => {
            this.forwardLayer.build(inputShape);
        });
        nameScope(this.backwardLayer.name, () => {
            this.backwardLayer.build(inputShape);
        });
        this.built = true;
    }
    computeMask(inputs, mask) {
        if (Array.isArray(mask)) {
            mask = mask[0];
        }
        let outputMask;
        if (this.returnSequences) {
            if (this.mergeMode == null) {
                outputMask = [mask, mask];
            }
            else {
                outputMask = mask;
            }
        }
        else {
            // Non-sequence outputs carry no mask.
            if (this.mergeMode == null) {
                outputMask = [null, null];
            }
            else {
                outputMask = null;
            }
        }
        if (this.returnState) {
            // Returned states are never masked; append one null per state,
            // once per direction.
            const states = this.forwardLayer.states;
            const stateMask = states.map(state => null);
            if (Array.isArray(outputMask)) {
                return outputMask.concat(stateMask).concat(stateMask);
            }
            else {
                return [outputMask].concat(stateMask).concat(stateMask);
            }
        }
        else {
            return outputMask;
        }
    }
    get trainableWeights() {
        return this.forwardLayer.trainableWeights.concat(this.backwardLayer.trainableWeights);
    }
    get nonTrainableWeights() {
        return this.forwardLayer.nonTrainableWeights.concat(this.backwardLayer.nonTrainableWeights);
    }
    // TODO(cais): Implement constraints().
    setFastWeightInitDuringBuild(value) {
        super.setFastWeightInitDuringBuild(value);
        if (this.forwardLayer != null) {
            this.forwardLayer.setFastWeightInitDuringBuild(value);
        }
        if (this.backwardLayer != null) {
            this.backwardLayer.setFastWeightInitDuringBuild(value);
        }
    }
    getConfig() {
        const config = {
            'mergeMode': this.mergeMode,
        };
        // TODO(cais): Add logic for `numConstants` once the property is added.
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
    /** @nocollapse */
    static fromConfig(cls, config) {
        // Deserialize the wrapped RNN layer, then pass the live object in via
        // the mutated config.
        const rnnLayer = deserialize(config['layer']);
        delete config['layer'];
        // TODO(cais): Add logic for `numConstants` once the property is added.
        if (config['numConstants'] != null) {
            throw new NotImplementedError(`Deserialization of a Bidirectional layer with numConstants ` +
                `present is not supported yet.`);
        }
        // tslint:disable-next-line:no-any
        const newConfig = config;
        newConfig['layer'] = rnnLayer;
        return new cls(newConfig);
    }
}
/** @nocollapse */
Bidirectional.className = 'Bidirectional';
registerClass(Bidirectional);
53353
53354 /**
53355 * @license
53356 * Copyright 2022 CodeSmith LLC
53357 *
53358 * Use of this source code is governed by an MIT-style
53359 * license that can be found in the LICENSE file or at
53360 * https://opensource.org/licenses/MIT.
53361 * =============================================================================
53362 */
53363 /**
53364 * Preprocessing Rescaling Layer
53365 *
53366 * This rescales images by a scaling and offset factor
53367 */
/**
 * Preprocessing layer that rescales image tensors element-wise:
 * `output = input * scale + offset` (offset defaults to 0).
 * Non-float32 inputs are cast to float32 first.
 */
class Rescaling extends Layer {
    constructor(args) {
        super(args);
        this.scale = args.scale;
        // Falsy offset (including an explicit 0) falls back to 0,
        // matching the original truthiness check.
        this.offset = args.offset ? args.offset : 0;
    }
    getConfig() {
        // Base config merged last so base keys take precedence.
        const config = { 'scale': this.scale, 'offset': this.offset };
        return Object.assign(config, super.getConfig());
    }
    call(inputs, kwargs) {
        return tidy(() => {
            let x = getExactlyOneTensor(inputs);
            if (x.dtype !== 'float32') {
                x = cast$2(x, 'float32');
            }
            return add$3(mul(x, this.scale), this.offset);
        });
    }
}
/** @nocollapse */
Rescaling.className = 'Rescaling';
registerClass(Rescaling);
53401
53402 /**
53403 * @license
53404 * Copyright 2022 CodeSmith LLC
53405 *
53406 * Use of this source code is governed by an MIT-style
53407 * license that can be found in the LICENSE file or at
53408 * https://opensource.org/licenses/MIT.
53409 * =============================================================================
53410 */
const { resizeBilinear: resizeBilinear$2, cropAndResize: cropAndResize$2 } = image$1;
/**
 * Preprocessing layer that crops images to a target `height` x `width`,
 * keeping the spatial center of the input. When the computed crop offsets
 * would be negative (input smaller than the target), the image is instead
 * resized up to the target via bilinear interpolation.
 */
class CenterCrop extends Layer {
    constructor(args) {
        super(args);
        this.height = args.height; // target output height (pixels)
        this.width = args.width; // target output width (pixels)
    }
    /**
     * Crops a centered `height` x `width` window out of the input.
     * `hBuffer`/`wBuffer` are the top/left pixel offsets of the window;
     * bounds are normalized to [0, 1] as required by cropAndResize.
     * Accepts rank-3 (unbatched) or rank-4 (batched) inputs; rank-3 inputs
     * are temporarily batched and unbatched again before returning.
     */
    centerCrop(inputs, hBuffer, wBuffer, height, width, inputHeight, inputWidth, dtype) {
        return tidy(() => {
            let input;
            let isRank3 = false;
            // Normalized box corners: [top, left, bottom, right] in [0, 1].
            const top = hBuffer / inputHeight;
            const left = wBuffer / inputWidth;
            const bottom = ((height) + hBuffer) / inputHeight;
            const right = ((width) + wBuffer) / inputWidth;
            const bound = [top, left, bottom, right];
            const boxesArr = [];
            if (inputs.rank === 3) {
                // Unbatched image: add a batch dimension of 1.
                isRank3 = true;
                input = stack([inputs]);
            }
            else {
                input = inputs;
            }
            // One identical box per image in the batch.
            for (let i = 0; i < input.shape[0]; i++) {
                boxesArr.push(bound);
            }
            const boxes = tensor(boxesArr, [boxesArr.length, 4]);
            const boxInd = range$3(0, boxesArr.length, 1, 'int32');
            const cropSize = [height, width];
            // 'nearest' keeps pixel values intact since box size == crop size.
            const cropped = cropAndResize$2(input, boxes, boxInd, cropSize, 'nearest');
            if (isRank3) {
                // Drop the temporary batch dimension again.
                return cast$2(getExactlyOneTensor(unstack(cropped)), dtype);
            }
            return cast$2(cropped, dtype);
        });
    }
    // Bilinearly resizes the input up to [height, width]; used when the
    // input is smaller than the crop target.
    upsize(inputs, height, width, dtype) {
        return tidy(() => {
            const outputs = resizeBilinear$2(inputs, [height, width]);
            return cast$2(outputs, dtype);
        });
    }
    call(inputs, kwargs) {
        return tidy(() => {
            const rankedInputs = getExactlyOneTensor(inputs);
            const dtype = rankedInputs.dtype;
            const inputShape = rankedInputs.shape;
            // Height/width are the 3rd- and 2nd-to-last axes (channels-last).
            const inputHeight = inputShape[inputShape.length - 3];
            const inputWidth = inputShape[inputShape.length - 2];
            let hBuffer = 0;
            if (inputHeight !== this.height) {
                hBuffer = Math.floor((inputHeight - this.height) / 2);
            }
            let wBuffer = 0;
            if (inputWidth !== this.width) {
                wBuffer = Math.floor((inputWidth - this.width) / 2);
                // NOTE(review): forcing a zero width offset to 1 shifts the
                // crop off-center by one pixel when inputWidth - width == 1;
                // looks intentional in upstream tfjs but worth confirming.
                if (wBuffer === 0) {
                    wBuffer = 1;
                }
            }
            if (hBuffer >= 0 && wBuffer >= 0) {
                return this.centerCrop(rankedInputs, hBuffer, wBuffer, this.height, this.width, inputHeight, inputWidth, dtype);
            }
            else {
                // Input smaller than target along some axis: upscale instead.
                return this.upsize(inputs, this.height, this.width, dtype);
            }
        });
    }
    getConfig() {
        const config = {
            'height': this.height,
            'width': this.width
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
    computeOutputShape(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        // Replace the spatial dims with the fixed target size.
        const hAxis = inputShape.length - 3;
        const wAxis = inputShape.length - 2;
        inputShape[hAxis] = this.height;
        inputShape[wAxis] = this.width;
        return inputShape;
    }
}
/** @nocollapse */
CenterCrop.className = 'CenterCrop';
registerClass(CenterCrop);
53501
53502 /**
53503 * @license
53504 * Copyright 2022 CodeSmith LLC
53505 *
53506 * Use of this source code is governed by an MIT-style
53507 * license that can be found in the LICENSE file or at
53508 * https://opensource.org/licenses/MIT.
53509 * =============================================================================
53510 */
/**
 * Encodes integer category indices according to `outputMode`:
 * 'int' returns the (int32-cast) input unchanged; 'oneHot'/'multiHot'
 * produce binary indicators; 'count' produces occurrence counts (optionally
 * weighted); 'tfIdf' multiplies the counts by the provided `weights`.
 * Throws a ValueError if the encoded output would exceed rank 2, or if
 * 'tfIdf' is requested without weights.
 */
function encodeCategoricalInputs(inputs, outputMode, depth, weights) {
    let x = getExactlyOneTensor(inputs);
    if (x.dtype !== 'int32') {
        x = cast$2(x, 'int32');
    }
    if (outputMode === 'int') {
        return x;
    }
    const originalShape = x.shape;
    // Scalars gain a trailing axis so bincount sees a 1-D tensor.
    if (x.rank === 0) {
        x = expandDims$3(x, -1);
    }
    // oneHot expects a trailing axis of length 1 per sample.
    if (outputMode === 'oneHot' && x.shape[x.shape.length - 1] !== 1) {
        x = expandDims$3(x, -1);
    }
    if (x.rank > 2) {
        throw new ValueError(`When outputMode is not int, maximum output rank is 2`
            + ` Received outputMode ${outputMode} and input shape ${originalShape}`
            + ` which would result in output rank ${x.rank}.`);
    }
    const binaryOutput = outputMode === 'multiHot' || outputMode === 'oneHot';
    // Per-sample weights are only honored in 'count' mode.
    const bincountWeights =
        (typeof weights) !== 'undefined' && outputMode === 'count' ? weights : [];
    const binCounts = denseBincount$2(x, bincountWeights, depth, binaryOutput);
    if (outputMode !== 'tfIdf') {
        return binCounts;
    }
    if (!weights) {
        throw new ValueError(`When outputMode is 'tfIdf', weights must be provided.`);
    }
    // tfIdf: scale the raw counts by the inverse-document-frequency weights.
    return mul(binCounts, weights);
}
53552
53553 /**
53554 * @license
53555 * Copyright 2022 CodeSmith LLC
53556 *
53557 * Use of this source code is governed by an MIT-style
53558 * license that can be found in the LICENSE file or at
53559 * https://opensource.org/licenses/MIT.
53560 * =============================================================================
53561 */
/**
 * Preprocessing layer that encodes integer category indices into fixed-size
 * vectors of length `numTokens`, according to `outputMode` ('oneHot',
 * 'multiHot' — the default — 'count', 'int', or 'tfIdf' via
 * encodeCategoricalInputs).
 */
class CategoryEncoding extends Layer {
    constructor(args) {
        super(args);
        this.numTokens = args.numTokens;
        if (args.outputMode) {
            this.outputMode = args.outputMode;
        }
        else {
            this.outputMode = 'multiHot';
        }
    }
    getConfig() {
        const config = {
            'numTokens': this.numTokens,
            'outputMode': this.outputMode,
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
    computeOutputShape(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        if (inputShape == null) {
            return [this.numTokens];
        }
        // oneHot appends a new axis unless the last axis is already length 1.
        if (this.outputMode === 'oneHot' && inputShape[inputShape.length - 1] !== 1) {
            inputShape.push(this.numTokens);
            return inputShape;
        }
        inputShape[inputShape.length - 1] = this.numTokens;
        return inputShape;
    }
    /**
     * Encodes the input indices. `kwargs['countWeights']` is only legal when
     * `outputMode === 'count'`. Throws a ValueError if any input value falls
     * outside `[0, numTokens)`.
     */
    call(inputs, kwargs) {
        return tidy(() => {
            inputs = getExactlyOneTensor(inputs);
            if (inputs.dtype !== 'int32') {
                inputs = cast$2(inputs, 'int32');
            }
            let countWeights;
            if ((typeof kwargs['countWeights']) !== 'undefined') {
                if (this.outputMode !== 'count') {
                    throw new ValueError(`countWeights is not used when outputMode !== count.
              Received countWeights=${kwargs['countWeights']}`);
                }
                countWeights
                    = getExactlyOneTensor(kwargs['countWeights']);
            }
            const maxValue = max$3(inputs);
            const minValue = min$3(inputs);
            // Valid iff numTokens > max(inputs) and min(inputs) >= 0,
            // i.e. every value lies in [0, numTokens).
            const greaterEqualMax = greater$3(this.numTokens, maxValue)
                .bufferSync().get(0);
            const greaterMin = greaterEqual$2(minValue, 0).bufferSync().get(0);
            if (!(greaterEqualMax && greaterMin)) {
                // Bug fix: the old message claimed `0 < values <= numTokens`,
                // but the checks above enforce `0 <= values < numTokens`.
                throw new ValueError('Input values must be in the range 0 <= values <'
                    + ` numTokens with numTokens=${this.numTokens}`);
            }
            return encodeCategoricalInputs(inputs, this.outputMode, this.numTokens, countWeights);
        });
    }
}
/** @nocollapse */
CategoryEncoding.className = 'CategoryEncoding';
registerClass(CategoryEncoding);
53625
53626 /**
53627 * @license
53628 * Copyright 2022 CodeSmith LLC
53629 *
53630 * Use of this source code is governed by an MIT-style
53631 * license that can be found in the LICENSE file or at
53632 * https://opensource.org/licenses/MIT.
53633 * =============================================================================
53634 */
53635 // tf methods unimplemented in tfjs: 'bicubic', 'area', 'lanczos3', 'lanczos5',
53636 // 'gaussian', 'mitchellcubic'
const INTERPOLATION_KEYS$1 = ['bilinear', 'nearest'];
const INTERPOLATION_METHODS$1 = new Set(INTERPOLATION_KEYS$1);
/**
 * Preprocessing Resizing Layer
 *
 * Resizes images to a target `height` x `width` using the configured
 * interpolation method ('bilinear', the default, or 'nearest').
 */
class Resizing extends Layer {
    constructor(args) {
        super(args);
        this.height = args.height;
        this.width = args.width;
        const requested = args.interpolation;
        if (!requested) {
            this.interpolation = 'bilinear';
        }
        else if (INTERPOLATION_METHODS$1.has(requested)) {
            this.interpolation = requested;
        }
        else {
            throw new ValueError(`Invalid interpolation parameter: ${requested} is not implemented`);
        }
        this.cropToAspectRatio = Boolean(args.cropToAspectRatio);
    }
    computeOutputShape(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        // Spatial dims are replaced by the target size; channels pass through.
        return [this.height, this.width, inputShape[2]];
    }
    getConfig() {
        const config = {
            'height': this.height,
            'width': this.width,
            'interpolation': this.interpolation,
            'cropToAspectRatio': this.cropToAspectRatio
        };
        return Object.assign(config, super.getConfig());
    }
    call(inputs, kwargs) {
        return tidy(() => {
            const size = [this.height, this.width];
            switch (this.interpolation) {
                case 'bilinear':
                    return image$1.resizeBilinear(inputs, size, !this.cropToAspectRatio);
                case 'nearest':
                    return image$1.resizeNearestNeighbor(inputs, size, !this.cropToAspectRatio);
                default:
                    throw new Error(`Interpolation is ${this.interpolation} but only ${[...INTERPOLATION_METHODS$1]} are supported`);
            }
        });
    }
}
/** @nocollapse */
Resizing.className = 'Resizing';
registerClass(Resizing);
53696
53697 /**
53698 * @license
53699 * Copyright 2023 CodeSmith LLC
53700 *
53701 * Use of this source code is governed by an MIT-style
53702 * license that can be found in the LICENSE file or at
53703 * https://opensource.org/licenses/MIT.
53704 * =============================================================================
53705 */
53706 /**
53707 * Keeps track of seed and handles pseudorandomness
53708 * Instance created in BaseRandomLayer class
53709 * Utilized for random preprocessing layers
53710 */
/**
 * Deterministic seed sequence for the random preprocessing layers.
 * `next()` returns the current seed and advances it by one; when no seed
 * was provided it always yields `undefined` (non-deterministic mode).
 */
class RandomSeed {
    constructor(seed) {
        this.seed = seed;
    }
    next() {
        // Post-increment: hand out the current value, then advance.
        return this.seed === undefined ? undefined : this.seed++;
    }
}
RandomSeed.className = 'RandomSeed';
53723
53724 /**
53725 * @license
53726 * Copyright 2023 CodeSmith LLC
53727 *
53728 * Use of this source code is governed by an MIT-style
53729 * license that can be found in the LICENSE file or at
53730 * https://opensource.org/licenses/MIT.
53731 * =============================================================================
53732 */
/**
 * Base class for random preprocessing layers. Owns a RandomSeed instance
 * so that pseudorandomness is reproducible and the seed round-trips
 * through getConfig()/serialization.
 */
class BaseRandomLayer extends Layer {
    constructor(args) {
        super(args);
        this.randomGenerator = new RandomSeed(args.seed);
    }
    getConfig() {
        // Base config merged last so base keys take precedence.
        const config = { 'seed': this.randomGenerator.seed };
        return Object.assign(config, super.getConfig());
    }
}
/** @nocollapse */
BaseRandomLayer.className = 'BaseRandomLayer';
53750
53751 /**
53752 * @license
53753 * Copyright 2023 CodeSmith LLC
53754 *
53755 * Use of this source code is governed by an MIT-style
53756 * license that can be found in the LICENSE file or at
53757 * https://opensource.org/licenses/MIT.
53758 * =============================================================================
53759 */
const INTERPOLATION_KEYS = ['bilinear', 'nearest'];
const INTERPOLATION_METHODS = new Set(INTERPOLATION_KEYS);
/**
 * Preprocessing layer that randomly varies image width during training.
 *
 * This layer randomly adjusts the width of a batch of images by a random
 * factor.
 *
 * The input should be a 3D (unbatched) or
 * 4D (batched) tensor in the `"channels_last"` image data format. Input pixel
 * values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and of integer
 * or floating point dtype. By default, the layer will output floats.
 *
 * tf methods implemented in tfjs: 'bilinear', 'nearest',
 * tf methods unimplemented in tfjs: 'bicubic', 'area', 'lanczos3', 'lanczos5',
 * 'gaussian', 'mitchellcubic'
 *
 */
class RandomWidth extends BaseRandomLayer {
    constructor(args) {
        super(args);
        const { factor, interpolation = 'bilinear' } = args;
        this.factor = factor;
        // A 2-element array gives explicit [lower, upper] bounds; a single
        // positive number f means the symmetric range [-f, +f].
        if (Array.isArray(this.factor) && this.factor.length === 2) {
            this.widthLower = this.factor[0];
            this.widthUpper = this.factor[1];
        }
        else if (!Array.isArray(this.factor) && this.factor > 0) {
            this.widthLower = -this.factor;
            this.widthUpper = this.factor;
        }
        else {
            throw new ValueError(`Invalid factor: ${this.factor}. Must be positive number or tuple of 2 numbers`);
        }
        // Width is scaled by (1 + factor); a bound <= -1 would yield a
        // non-positive width.
        if (this.widthLower < -1.0 || this.widthUpper < -1.0) {
            throw new ValueError(`factor must have values larger than -1. Got: ${this.factor}`);
        }
        if (this.widthUpper < this.widthLower) {
            throw new ValueError(`factor cannot have upper bound less than lower bound.
            Got upper bound: ${this.widthUpper}.
            Got lower bound: ${this.widthLower}
            `);
        }
        // NOTE: the destructuring default makes `interpolation` always truthy
        // here, so the outer guard never skips assignment.
        if (interpolation) {
            if (INTERPOLATION_METHODS.has(interpolation)) {
                this.interpolation = interpolation;
            }
            else {
                throw new ValueError(`Invalid interpolation parameter: ${interpolation} is not implemented`);
            }
        }
    }
    getConfig() {
        const config = {
            'factor': this.factor,
            'interpolation': this.interpolation,
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
    computeOutputShape(inputShape) {
        // NOTE(review): `this.imgHeight` is only assigned inside call(), so
        // this returns an undefined height if invoked before the first call.
        inputShape = getExactlyOneShape(inputShape);
        const numChannels = inputShape[2];
        // Output width is unknown until runtime, hence -1.
        return [this.imgHeight, -1, numChannels];
    }
    call(inputs, kwargs) {
        return tidy(() => {
            const input = getExactlyOneTensor(inputs);
            // Height/width are the 3rd- and 2nd-to-last axes (channels-last).
            this.imgHeight = input.shape[input.shape.length - 3];
            const imgWidth = input.shape[input.shape.length - 2];
            // Single scale factor drawn from [1 + lower, 1 + upper), seeded
            // via the layer's RandomSeed generator for reproducibility.
            this.widthFactor = randomUniform$1([1], (1.0 + this.widthLower), (1.0 + this.widthUpper), 'float32', this.randomGenerator.next());
            let adjustedWidth = this.widthFactor.dataSync()[0] * imgWidth;
            adjustedWidth = Math.round(adjustedWidth);
            const size = [this.imgHeight, adjustedWidth];
            switch (this.interpolation) {
                case 'bilinear':
                    return image$1.resizeBilinear(inputs, size);
                case 'nearest':
                    return image$1.resizeNearestNeighbor(inputs, size);
                default:
                    throw new Error(`Interpolation is ${this.interpolation}
                but only ${[...INTERPOLATION_METHODS]} are supported`);
            }
        });
    }
}
/** @nocollapse */
RandomWidth.className = 'RandomWidth';
registerClass(RandomWidth);
53850
53851 /**
53852 * @license
53853 * Copyright 2018 Google LLC
53854 *
53855 * Use of this source code is governed by an MIT-style
53856 * license that can be found in the LICENSE file or at
53857 * https://opensource.org/licenses/MIT.
53858 * =============================================================================
53859 */
53860 // TODO(cais): Add doc string to all the public static functions in this
// class; include executable JavaScript code snippets where applicable
53862 // (b/74074458).
53863 // Input Layer.
53864 /**
53865 * An input layer is an entry point into a `tf.LayersModel`.
53866 *
53867 * `InputLayer` is generated automatically for `tf.Sequential` models by
53868 * specifying the `inputshape` or `batchInputShape` for the first layer. It
53869 * should not be specified explicitly. However, it can be useful sometimes,
53870 * e.g., when constructing a sequential model from a subset of another
53871 * sequential model's layers. Like the code snippet below shows.
53872 *
53873 * ```js
53874 * // Define a model which simply adds two inputs.
53875 * const model1 = tf.sequential();
53876 * model1.add(tf.layers.dense({inputShape: [4], units: 3, activation: 'relu'}));
53877 * model1.add(tf.layers.dense({units: 1, activation: 'sigmoid'}));
53878 * model1.summary();
53879 * model1.predict(tf.zeros([1, 4])).print();
53880 *
53881 * // Construct another model, reusing the second layer of `model1` while
53882 * // not using the first layer of `model1`. Note that you cannot add the second
53883 * // layer of `model` directly as the first layer of the new sequential model,
53884 * // because doing so will lead to an error related to the fact that the layer
53885 * // is not an input layer. Instead, you need to create an `inputLayer` and add
53886 * // it to the new sequential model before adding the reused layer.
53887 * const model2 = tf.sequential();
53888 * // Use an inputShape that matches the input shape of `model1`'s second
53889 * // layer.
53890 * model2.add(tf.layers.inputLayer({inputShape: [3]}));
53891 * model2.add(model1.layers[1]);
53892 * model2.summary();
53893 * model2.predict(tf.zeros([1, 3])).print();
53894 * ```
53895 *
53896 * @doc {heading: 'Layers', subheading: 'Inputs', namespace: 'layers'}
53897 */
function inputLayer(args) {
    // Thin factory around the InputLayer class constructor.
    const layer = new InputLayer(args);
    return layer;
}
53901 // Advanced Activation Layers.
53902 /**
53903 * Exponential Linear Unit (ELU).
53904 *
53905 * It follows:
53906 * `f(x) = alpha * (exp(x) - 1.) for x < 0`,
53907 * `f(x) = x for x >= 0`.
53908 *
53909 * Input shape:
53910 * Arbitrary. Use the configuration `inputShape` when using this layer as the
53911 * first layer in a model.
53912 *
53913 * Output shape:
53914 * Same shape as the input.
53915 *
53916 * References:
53917 * - [Fast and Accurate Deep Network Learning by Exponential Linear Units
53918 * (ELUs)](https://arxiv.org/abs/1511.07289v1)
53919 *
53920 * @doc {
53921 * heading: 'Layers',
53922 * subheading: 'Advanced Activation',
53923 * namespace: 'layers'
53924 * }
53925 */
function elu$2(args) {
    // Thin factory around the ELU advanced-activation layer class.
    const layer = new ELU$3(args);
    return layer;
}
53929 /**
53930 * Rectified Linear Unit activation function.
53931 *
53932 * Input shape:
53933 * Arbitrary. Use the config field `inputShape` (Array of integers, does
53934 * not include the sample axis) when using this layer as the first layer
53935 * in a model.
53936 *
53937 * Output shape:
53938 * Same shape as the input.
53939 *
53940 * @doc {
53941 * heading: 'Layers',
53942 * subheading: 'Advanced Activation',
53943 * namespace: 'layers'
53944 * }
53945 */
function reLU(args) {
    // Thin factory around the ReLU advanced-activation layer class.
    const layer = new ReLU(args);
    return layer;
}
53949 /**
53950 * Leaky version of a rectified linear unit.
53951 *
53952 * It allows a small gradient when the unit is not active:
53953 * `f(x) = alpha * x for x < 0.`
53954 * `f(x) = x for x >= 0.`
53955 *
53956 * Input shape:
53957 * Arbitrary. Use the configuration `inputShape` when using this layer as the
53958 * first layer in a model.
53959 *
53960 * Output shape:
53961 * Same shape as the input.
53962 *
53963 * @doc {
53964 * heading: 'Layers',
53965 * subheading: 'Advanced Activation',
53966 * namespace: 'layers'
53967 * }
53968 */
function leakyReLU(args) {
    // Thin factory around the LeakyReLU advanced-activation layer class.
    const layer = new LeakyReLU(args);
    return layer;
}
53972 /**
53973 * Parameterized version of a leaky rectified linear unit.
53974 *
53975 * It follows
53976 * `f(x) = alpha * x for x < 0.`
53977 * `f(x) = x for x >= 0.`
53978 * wherein `alpha` is a trainable weight.
53979 *
53980 * Input shape:
53981 * Arbitrary. Use the configuration `inputShape` when using this layer as the
53982 * first layer in a model.
53983 *
53984 * Output shape:
53985 * Same shape as the input.
53986 *
53987 * @doc {
53988 * heading: 'Layers',
53989 * subheading: 'Advanced Activation',
53990 * namespace: 'layers'
53991 * }
53992 */
function prelu$2(args) {
    // Thin factory around the PReLU advanced-activation layer class.
    const layer = new PReLU(args);
    return layer;
}
53996 /**
53997 * Softmax activation layer.
53998 *
53999 * Input shape:
54000 * Arbitrary. Use the configuration `inputShape` when using this layer as the
54001 * first layer in a model.
54002 *
54003 * Output shape:
54004 * Same shape as the input.
54005 *
54006 * @doc {
54007 * heading: 'Layers',
54008 * subheading: 'Advanced Activation',
54009 * namespace: 'layers'
54010 * }
54011 */
function softmax$2(args) {
    // Thin factory around the Softmax advanced-activation layer class.
    const layer = new Softmax(args);
    return layer;
}
54015 /**
54016 * Thresholded Rectified Linear Unit.
54017 *
54018 * It follows:
54019 * `f(x) = x for x > theta`,
54020 * `f(x) = 0 otherwise`.
54021 *
54022 * Input shape:
54023 * Arbitrary. Use the configuration `inputShape` when using this layer as the
54024 * first layer in a model.
54025 *
54026 * Output shape:
54027 * Same shape as the input.
54028 *
54029 * References:
54030 * - [Zero-Bias Autoencoders and the Benefits of Co-Adapting
54031 * Features](http://arxiv.org/abs/1402.3337)
54032 *
54033 * @doc {
54034 * heading: 'Layers',
54035 * subheading: 'Advanced Activation',
54036 * namespace: 'layers'
54037 * }
54038 */
function thresholdedReLU(args) {
    // Thin factory around the ThresholdedReLU advanced-activation layer class.
    const layer = new ThresholdedReLU(args);
    return layer;
}
54042 // Convolutional Layers.
54043 /**
54044 * 1D convolution layer (e.g., temporal convolution).
54045 *
54046 * This layer creates a convolution kernel that is convolved
54047 * with the layer input over a single spatial (or temporal) dimension
54048 * to produce a tensor of outputs.
54049 *
54050 * If `use_bias` is True, a bias vector is created and added to the outputs.
54051 *
54052 * If `activation` is not `null`, it is applied to the outputs as well.
54053 *
54054 * When using this layer as the first layer in a model, provide an
54055 * `inputShape` argument `Array` or `null`.
54056 *
54057 * For example, `inputShape` would be:
54058 * - `[10, 128]` for sequences of 10 vectors of 128-dimensional vectors
54059 * - `[null, 128]` for variable-length sequences of 128-dimensional vectors.
54060 *
54061 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
54062 */
function conv1d(args) {
    // Thin factory around the Conv1D layer class.
    const layer = new Conv1D(args);
    return layer;
}
54066 /**
54067 * 2D convolution layer (e.g. spatial convolution over images).
54068 *
54069 * This layer creates a convolution kernel that is convolved
54070 * with the layer input to produce a tensor of outputs.
54071 *
54072 * If `useBias` is True, a bias vector is created and added to the outputs.
54073 *
54074 * If `activation` is not `null`, it is applied to the outputs as well.
54075 *
54076 * When using this layer as the first layer in a model,
54077 * provide the keyword argument `inputShape`
54078 * (Array of integers, does not include the sample axis),
54079 * e.g. `inputShape=[128, 128, 3]` for 128x128 RGB pictures
54080 * in `dataFormat='channelsLast'`.
54081 *
54082 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
54083 */
function conv2d$1(args) {
    // Thin factory around the Conv2D layer class.
    const layer = new Conv2D(args);
    return layer;
}
54087 /**
54088 * Transposed convolutional layer (sometimes called Deconvolution).
54089 *
54090 * The need for transposed convolutions generally arises
54091 * from the desire to use a transformation going in the opposite direction of
54092 * a normal convolution, i.e., from something that has the shape of the output
54093 * of some convolution to something that has the shape of its input while
54094 * maintaining a connectivity pattern that is compatible with said
54095 * convolution.
54096 *
54097 * When using this layer as the first layer in a model, provide the
54098 * configuration `inputShape` (`Array` of integers, does not include the
54099 * sample axis), e.g., `inputShape: [128, 128, 3]` for 128x128 RGB pictures in
54100 * `dataFormat: 'channelsLast'`.
54101 *
54102 * Input shape:
54103 * 4D tensor with shape:
54104 * `[batch, channels, rows, cols]` if `dataFormat` is `'channelsFirst'`.
54105 * or 4D tensor with shape
54106 * `[batch, rows, cols, channels]` if `dataFormat` is `'channelsLast'`.
54107 *
54108 * Output shape:
54109 * 4D tensor with shape:
54110 * `[batch, filters, newRows, newCols]` if `dataFormat` is
54111 * `'channelsFirst'`. or 4D tensor with shape:
54112 * `[batch, newRows, newCols, filters]` if `dataFormat` is `'channelsLast'`.
54113 *
54114 * References:
54115 * - [A guide to convolution arithmetic for deep
54116 * learning](https://arxiv.org/abs/1603.07285v1)
54117 * - [Deconvolutional
54118 * Networks](http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf)
54119 *
54120 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
54121 */
function conv2dTranspose(args) {
    // Thin factory around the Conv2DTranspose layer class.
    const layer = new Conv2DTranspose(args);
    return layer;
}
54125 /**
54126 * 3D convolution layer (e.g. spatial convolution over volumes).
54127 *
54128 * This layer creates a convolution kernel that is convolved
54129 * with the layer input to produce a tensor of outputs.
54130 *
54131 * If `useBias` is True, a bias vector is created and added to the outputs.
54132 *
54133 * If `activation` is not `null`, it is applied to the outputs as well.
54134 *
54135 * When using this layer as the first layer in a model,
54136 * provide the keyword argument `inputShape`
54137 * (Array of integers, does not include the sample axis),
54138 * e.g. `inputShape=[128, 128, 128, 1]` for 128x128x128 grayscale volumes
54139 * in `dataFormat='channelsLast'`.
54140 *
54141 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
54142 */
function conv3d(args) {
    // Thin factory around the Conv3D layer class.
    const layer = new Conv3D(args);
    return layer;
}
/**
 * Transposed 3D convolution layer (sometimes called deconvolution over
 * volumes). See `conv2dTranspose` for background on transposed convolutions.
 */
function conv3dTranspose(args) {
    const layer = new Conv3DTranspose(args);
    return layer;
}
54149 /**
54150 * Depthwise separable 2D convolution.
54151 *
54152 * Separable convolution consists of first performing
54153 * a depthwise spatial convolution
54154 * (which acts on each input channel separately)
54155 * followed by a pointwise convolution which mixes together the resulting
54156 * output channels. The `depthMultiplier` argument controls how many
54157 * output channels are generated per input channel in the depthwise step.
54158 *
54159 * Intuitively, separable convolutions can be understood as
54160 * a way to factorize a convolution kernel into two smaller kernels,
54161 * or as an extreme version of an Inception block.
54162 *
54163 * Input shape:
54164 * 4D tensor with shape:
54165 * `[batch, channels, rows, cols]` if data_format='channelsFirst'
54166 * or 4D tensor with shape:
54167 * `[batch, rows, cols, channels]` if data_format='channelsLast'.
54168 *
54169 * Output shape:
54170 * 4D tensor with shape:
54171 * `[batch, filters, newRows, newCols]` if data_format='channelsFirst'
54172 * or 4D tensor with shape:
54173 * `[batch, newRows, newCols, filters]` if data_format='channelsLast'.
54174 * `rows` and `cols` values might have changed due to padding.
54175 *
54176 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
54177 */
function separableConv2d(args) {
    // Thin factory around the SeparableConv2D layer class.
    const layer = new SeparableConv2D(args);
    return layer;
}
54181 /**
54182 * Cropping layer for 2D input (e.g., image).
54183 *
54184 * This layer can crop an input
54185 * at the top, bottom, left and right side of an image tensor.
54186 *
54187 * Input shape:
54188 * 4D tensor with shape:
54189 * - If `dataFormat` is `"channelsLast"`:
54190 * `[batch, rows, cols, channels]`
54191 * - If `data_format` is `"channels_first"`:
54192 * `[batch, channels, rows, cols]`.
54193 *
54194 * Output shape:
54195 * 4D with shape:
54196 * - If `dataFormat` is `"channelsLast"`:
54197 * `[batch, croppedRows, croppedCols, channels]`
54198 * - If `dataFormat` is `"channelsFirst"`:
54199 * `[batch, channels, croppedRows, croppedCols]`.
54200 *
54201 * Examples
54202 * ```js
54203 *
54204 * const model = tf.sequential();
54205 * model.add(tf.layers.cropping2D({cropping:[[2, 2], [2, 2]],
54206 * inputShape: [128, 128, 3]}));
54207 * //now output shape is [batch, 124, 124, 3]
54208 * ```
54209 *
54210 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
54211 */
function cropping2D(args) {
    // Thin factory around the Cropping2D layer class.
    const layer = new Cropping2D(args);
    return layer;
}
54215 /**
54216 * Upsampling layer for 2D inputs.
54217 *
54218 * Repeats the rows and columns of the data
54219 * by size[0] and size[1] respectively.
54220 *
54221 *
54222 * Input shape:
54223 * 4D tensor with shape:
54224 * - If `dataFormat` is `"channelsLast"`:
54225 * `[batch, rows, cols, channels]`
54226 * - If `dataFormat` is `"channelsFirst"`:
54227 * `[batch, channels, rows, cols]`
54228 *
54229 * Output shape:
54230 * 4D tensor with shape:
54231 * - If `dataFormat` is `"channelsLast"`:
54232 * `[batch, upsampledRows, upsampledCols, channels]`
54233 * - If `dataFormat` is `"channelsFirst"`:
54234 * `[batch, channels, upsampledRows, upsampledCols]`
54235 *
54236 *
54237 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
54238 */
function upSampling2d(args) {
    // Thin factory around the UpSampling2D layer class.
    const layer = new UpSampling2D(args);
    return layer;
}
54242 // Convolutional(depthwise) Layers.
54243 /**
54244 * Depthwise separable 2D convolution.
54245 *
54246 * Depthwise Separable convolutions consists in performing just the first step
54247 * in a depthwise spatial convolution (which acts on each input channel
54248 * separately). The `depthMultiplier` argument controls how many output channels
54249 * are generated per input channel in the depthwise step.
54250 *
54251 * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
54252 */
function depthwiseConv2d(args) {
    // Thin factory around the DepthwiseConv2D layer class.
    const layer = new DepthwiseConv2D(args);
    return layer;
}
54256 // Basic Layers.
54257 /**
54258 * Applies an activation function to an output.
54259 *
54260 * This layer applies element-wise activation function. Other layers, notably
54261 * `dense` can also apply activation functions. Use this isolated activation
54262 * function to extract the values before and after the
54263 * activation. For instance:
54264 *
54265 * ```js
54266 * const input = tf.input({shape: [5]});
54267 * const denseLayer = tf.layers.dense({units: 1});
54268 * const activationLayer = tf.layers.activation({activation: 'relu6'});
54269 *
54270 * // Obtain the output symbolic tensors by applying the layers in order.
54271 * const denseOutput = denseLayer.apply(input);
54272 * const activationOutput = activationLayer.apply(denseOutput);
54273 *
54274 * // Create the model based on the inputs.
54275 * const model = tf.model({
54276 * inputs: input,
54277 * outputs: [denseOutput, activationOutput]
54278 * });
54279 *
54280 * // Collect both outputs and print separately.
54281 * const [denseOut, activationOut] = model.predict(tf.randomNormal([6, 5]));
54282 * denseOut.print();
54283 * activationOut.print();
54284 * ```
54285 *
54286 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
54287 */
54288 function activation(args) {
54289 return new Activation(args);
54290 }
54291 /**
54292 * Creates a dense (fully connected) layer.
54293 *
54294 * This layer implements the operation:
54295 * `output = activation(dot(input, kernel) + bias)`
54296 *
54297 * `activation` is the element-wise activation function
54298 * passed as the `activation` argument.
54299 *
54300 * `kernel` is a weights matrix created by the layer.
54301 *
54302 * `bias` is a bias vector created by the layer (only applicable if `useBias`
54303 * is `true`).
54304 *
54305 * **Input shape:**
54306 *
54307 * nD `tf.Tensor` with shape: `(batchSize, ..., inputDim)`.
54308 *
54309 * The most common situation would be
54310 * a 2D input with shape `(batchSize, inputDim)`.
54311 *
54312 * **Output shape:**
54313 *
54314 * nD tensor with shape: `(batchSize, ..., units)`.
54315 *
54316 * For instance, for a 2D input with shape `(batchSize, inputDim)`,
54317 * the output would have shape `(batchSize, units)`.
54318 *
54319 * Note: if the input to the layer has a rank greater than 2, then it is
54320 * flattened prior to the initial dot product with the kernel.
54321 *
54322 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
54323 */
54324 function dense(args) {
54325 return new Dense(args);
54326 }
54327 /**
54328 * Applies
54329 * [dropout](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf) to
54330 * the input.
54331 *
54332 * Dropout consists in randomly setting a fraction `rate` of input units to 0 at
54333 * each update during training time, which helps prevent overfitting.
54334 *
54335 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
54336 */
54337 function dropout(args) {
54338 return new Dropout(args);
54339 }
54340 /**
54341 * Spatial 1D version of Dropout.
54342 *
54343 * This Layer type performs the same function as the Dropout layer, but it drops
54344 * entire 1D feature maps instead of individual elements. For example, if an
54345 * input example consists of 3 timesteps and the feature map for each timestep
54346 * has a size of 4, a `spatialDropout1d` layer may zero out the feature maps
54347 * of the 1st timesteps and 2nd timesteps completely while sparing all feature
54348 * elements of the 3rd timestep.
54349 *
54350 * If adjacent frames (timesteps) are strongly correlated (as is normally the
54351 * case in early convolution layers), regular dropout will not regularize the
 * activation and will otherwise just result in an effective learning
54353 * rate decrease. In this case, `spatialDropout1d` will help promote
54354 * independence among feature maps and should be used instead.
54355 *
54356 * **Arguments:**
54357 * rate: A floating-point number >=0 and <=1. Fraction of the input elements
54358 * to drop.
54359 *
54360 * **Input shape:**
54361 * 3D tensor with shape `(samples, timesteps, channels)`.
54362 *
54363 * **Output shape:**
54364 * Same as the input shape.
54365 *
54366 * References:
54367 * - [Efficient Object Localization Using Convolutional
54368 * Networks](https://arxiv.org/abs/1411.4280)
54369 *
54370 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
54371 */
54372 function spatialDropout1d(args) {
54373 return new SpatialDropout1D(args);
54374 }
54375 /**
54376 * Flattens the input. Does not affect the batch size.
54377 *
54378 * A `Flatten` layer flattens each batch in its inputs to 1D (making the output
54379 * 2D).
54380 *
54381 * For example:
54382 *
54383 * ```js
54384 * const input = tf.input({shape: [4, 3]});
54385 * const flattenLayer = tf.layers.flatten();
54386 * // Inspect the inferred output shape of the flatten layer, which
54387 * // equals `[null, 12]`. The 2nd dimension is 4 * 3, i.e., the result of the
 * // flattening. (The 1st dimension is the undetermined batch size.)
54389 * console.log(JSON.stringify(flattenLayer.apply(input).shape));
54390 * ```
54391 *
54392 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
54393 */
54394 function flatten(args) {
54395 return new Flatten(args);
54396 }
54397 /**
54398 * Repeats the input n times in a new dimension.
54399 *
54400 * ```js
54401 * const model = tf.sequential();
54402 * model.add(tf.layers.repeatVector({n: 4, inputShape: [2]}));
54403 * const x = tf.tensor2d([[10, 20]]);
54404 * // Use the model to do inference on a data point the model hasn't seen
54405 * model.predict(x).print();
 * // output shape is now [batch, 4, 2]
54407 * ```
54408 *
54409 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
54410 */
54411 function repeatVector(args) {
54412 return new RepeatVector(args);
54413 }
54414 /**
54415 * Reshapes an input to a certain shape.
54416 *
54417 * ```js
54418 * const input = tf.input({shape: [4, 3]});
54419 * const reshapeLayer = tf.layers.reshape({targetShape: [2, 6]});
54420 * // Inspect the inferred output shape of the Reshape layer, which
 * // equals `[null, 2, 6]`. (The 1st dimension is the undetermined batch size.)
54422 * console.log(JSON.stringify(reshapeLayer.apply(input).shape));
54423 * ```
54424 *
54425 * Input shape:
54426 * Arbitrary, although all dimensions in the input shape must be fixed.
54427 * Use the configuration `inputShape` when using this layer as the
54428 * first layer in a model.
54429 *
54430 *
54431 * Output shape:
54432 * [batchSize, targetShape[0], targetShape[1], ...,
54433 * targetShape[targetShape.length - 1]].
54434 *
54435 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
54436 */
54437 function reshape$2(args) {
54438 return new Reshape(args);
54439 }
54440 /**
54441 * Permutes the dimensions of the input according to a given pattern.
54442 *
54443 * Useful for, e.g., connecting RNNs and convnets together.
54444 *
54445 * Example:
54446 *
54447 * ```js
54448 * const model = tf.sequential();
54449 * model.add(tf.layers.permute({
54450 * dims: [2, 1],
54451 * inputShape: [10, 64]
54452 * }));
54453 * console.log(model.outputShape);
54454 * // Now model's output shape is [null, 64, 10], where null is the
54455 * // unpermuted sample (batch) dimension.
54456 * ```
54457 *
54458 * Input shape:
54459 * Arbitrary. Use the configuration field `inputShape` when using this
54460 * layer as the first layer in a model.
54461 *
54462 * Output shape:
54463 * Same rank as the input shape, but with the dimensions re-ordered (i.e.,
54464 * permuted) according to the `dims` configuration of this layer.
54465 *
54466 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
54467 */
54468 function permute(args) {
54469 return new Permute(args);
54470 }
54471 /**
54472 * Maps positive integers (indices) into dense vectors of fixed size.
54473 * E.g. [[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]
54474 *
54475 * **Input shape:** 2D tensor with shape: `[batchSize, sequenceLength]`.
54476 *
54477 * **Output shape:** 3D tensor with shape: `[batchSize, sequenceLength,
54478 * outputDim]`.
54479 *
54480 * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
54481 */
54482 function embedding(args) {
54483 return new Embedding(args);
54484 }
54485 // Merge Layers.
54486 /**
54487 * Layer that performs element-wise addition on an `Array` of inputs.
54488 *
54489 * It takes as input a list of tensors, all of the same shape, and returns a
54490 * single tensor (also of the same shape). The inputs are specified as an
54491 * `Array` when the `apply` method of the `Add` layer instance is called. For
54492 * example:
54493 *
54494 * ```js
54495 * const input1 = tf.input({shape: [2, 2]});
54496 * const input2 = tf.input({shape: [2, 2]});
54497 * const addLayer = tf.layers.add();
54498 * const sum = addLayer.apply([input1, input2]);
54499 * console.log(JSON.stringify(sum.shape));
54500 * // You get [null, 2, 2], with the first dimension as the undetermined batch
54501 * // dimension.
54502 * ```
54503 *
54504 * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
54505 */
54506 function add$1(args) {
54507 return new Add(args);
54508 }
54509 /**
54510 * Layer that performs element-wise averaging on an `Array` of inputs.
54511 *
54512 * It takes as input a list of tensors, all of the same shape, and returns a
54513 * single tensor (also of the same shape). For example:
54514 *
54515 * ```js
54516 * const input1 = tf.input({shape: [2, 2]});
54517 * const input2 = tf.input({shape: [2, 2]});
54518 * const averageLayer = tf.layers.average();
54519 * const average = averageLayer.apply([input1, input2]);
54520 * console.log(JSON.stringify(average.shape));
54521 * // You get [null, 2, 2], with the first dimension as the undetermined batch
54522 * // dimension.
54523 * ```
54524 *
54525 * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
54526 */
54527 function average(args) {
54528 return new Average(args);
54529 }
54530 /**
54531 * Layer that concatenates an `Array` of inputs.
54532 *
54533 * It takes a list of tensors, all of the same shape except for the
54534 * concatenation axis, and returns a single tensor, the concatenation
54535 * of all inputs. For example:
54536 *
54537 * ```js
54538 * const input1 = tf.input({shape: [2, 2]});
54539 * const input2 = tf.input({shape: [2, 3]});
54540 * const concatLayer = tf.layers.concatenate();
54541 * const output = concatLayer.apply([input1, input2]);
54542 * console.log(JSON.stringify(output.shape));
54543 * // You get [null, 2, 5], with the first dimension as the undetermined batch
54544 * // dimension. The last dimension (5) is the result of concatenating the
54545 * // last dimensions of the inputs (2 and 3).
54546 * ```
54547 *
54548 * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
54549 */
54550 function concatenate(args) {
54551 return new Concatenate(args);
54552 }
54553 /**
54554 * Layer that computes the element-wise maximum of an `Array` of inputs.
54555 *
54556 * It takes as input a list of tensors, all of the same shape, and returns a
54557 * single tensor (also of the same shape). For example:
54558 *
54559 * ```js
54560 * const input1 = tf.input({shape: [2, 2]});
54561 * const input2 = tf.input({shape: [2, 2]});
54562 * const maxLayer = tf.layers.maximum();
54563 * const max = maxLayer.apply([input1, input2]);
54564 * console.log(JSON.stringify(max.shape));
54565 * // You get [null, 2, 2], with the first dimension as the undetermined batch
54566 * // dimension.
54567 * ```
54568 *
54569 * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
54570 */
54571 function maximum$2(args) {
54572 return new Maximum(args);
54573 }
54574 /**
54575 * Layer that computes the element-wise minimum of an `Array` of inputs.
54576 *
54577 * It takes as input a list of tensors, all of the same shape, and returns a
54578 * single tensor (also of the same shape). For example:
54579 *
54580 * ```js
54581 * const input1 = tf.input({shape: [2, 2]});
54582 * const input2 = tf.input({shape: [2, 2]});
54583 * const minLayer = tf.layers.minimum();
54584 * const min = minLayer.apply([input1, input2]);
54585 * console.log(JSON.stringify(min.shape));
54586 * // You get [null, 2, 2], with the first dimension as the undetermined batch
54587 * // dimension.
54588 * ```
54589 *
54590 * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
54591 */
54592 function minimum$2(args) {
54593 return new Minimum(args);
54594 }
54595 /**
54596 * Layer that multiplies (element-wise) an `Array` of inputs.
54597 *
54598 * It takes as input an Array of tensors, all of the same
54599 * shape, and returns a single tensor (also of the same shape).
54600 * For example:
54601 *
54602 * ```js
54603 * const input1 = tf.input({shape: [2, 2]});
54604 * const input2 = tf.input({shape: [2, 2]});
54605 * const input3 = tf.input({shape: [2, 2]});
54606 * const multiplyLayer = tf.layers.multiply();
54607 * const product = multiplyLayer.apply([input1, input2, input3]);
54608 * console.log(product.shape);
54609 * // You get [null, 2, 2], with the first dimension as the undetermined batch
54610 * // dimension.
 * ```
 *
54612 * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
54613 */
54614 function multiply$2(args) {
54615 return new Multiply(args);
54616 }
54617 /**
54618 * Layer that computes a dot product between samples in two tensors.
54619 *
54620 * E.g., if applied to a list of two tensors `a` and `b` both of shape
54621 * `[batchSize, n]`, the output will be a tensor of shape `[batchSize, 1]`,
54622 * where each entry at index `[i, 0]` will be the dot product between
54623 * `a[i, :]` and `b[i, :]`.
54624 *
54625 * Example:
54626 *
54627 * ```js
54628 * const dotLayer = tf.layers.dot({axes: -1});
54629 * const x1 = tf.tensor2d([[10, 20], [30, 40]]);
54630 * const x2 = tf.tensor2d([[-1, -2], [-3, -4]]);
54631 *
54632 * // Invoke the layer's apply() method in eager (imperative) mode.
54633 * const y = dotLayer.apply([x1, x2]);
54634 * y.print();
54635 * ```
54636 *
54637 * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
54638 */
54639 function dot(args) {
54640 return new Dot(args);
54641 }
54642 // Normalization Layers.
54643 /**
54644 * Batch normalization layer (Ioffe and Szegedy, 2014).
54645 *
54646 * Normalize the activations of the previous layer at each batch,
54647 * i.e. applies a transformation that maintains the mean activation
54648 * close to 0 and the activation standard deviation close to 1.
54649 *
54650 * Input shape:
54651 * Arbitrary. Use the keyword argument `inputShape` (Array of integers, does
54652 * not include the sample axis) when calling the constructor of this class,
54653 * if this layer is used as a first layer in a model.
54654 *
54655 * Output shape:
54656 * Same shape as input.
54657 *
54658 * References:
54659 * - [Batch Normalization: Accelerating Deep Network Training by Reducing
54660 * Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
54661 *
54662 * @doc {heading: 'Layers', subheading: 'Normalization', namespace: 'layers'}
54663 */
54664 function batchNormalization(args) {
54665 return new BatchNormalization(args);
54666 }
54667 /**
54668 * Layer-normalization layer (Ba et al., 2016).
54669 *
54670 * Normalizes the activations of the previous layer for each given example in a
54671 * batch independently, instead of across a batch like in `batchNormalization`.
54672 * In other words, this layer applies a transformation that maintains the mean
54673 * activation within each example close to 0 and activation variance close to 1.
54674 *
54675 * Input shape:
54676 * Arbitrary. Use the argument `inputShape` when using this layer as the first
54677 * layer in a model.
54678 *
54679 * Output shape:
54680 * Same as input.
54681 *
54682 * References:
54683 * - [Layer Normalization](https://arxiv.org/abs/1607.06450)
54684 *
54685 * @doc {heading: 'Layers', subheading: 'Normalization', namespace: 'layers'}
54686 */
54687 function layerNormalization(args) {
54688 return new LayerNormalization(args);
54689 }
54690 // Padding Layers.
54691 /**
54692 * Zero-padding layer for 2D input (e.g., image).
54693 *
54694 * This layer can add rows and columns of zeros
54695 * at the top, bottom, left and right side of an image tensor.
54696 *
54697 * Input shape:
54698 * 4D tensor with shape:
54699 * - If `dataFormat` is `"channelsLast"`:
54700 * `[batch, rows, cols, channels]`
 *   - If `dataFormat` is `"channelsFirst"`:
54702 * `[batch, channels, rows, cols]`.
54703 *
54704 * Output shape:
54705 * 4D with shape:
54706 * - If `dataFormat` is `"channelsLast"`:
54707 * `[batch, paddedRows, paddedCols, channels]`
54708 * - If `dataFormat` is `"channelsFirst"`:
54709 * `[batch, channels, paddedRows, paddedCols]`.
54710 *
54711 * @doc {heading: 'Layers', subheading: 'Padding', namespace: 'layers'}
54712 */
54713 function zeroPadding2d(args) {
54714 return new ZeroPadding2D(args);
54715 }
54716 // Pooling Layers.
54717 /**
54718 * Average pooling operation for spatial data.
54719 *
54720 * Input shape: `[batchSize, inLength, channels]`
54721 *
54722 * Output shape: `[batchSize, pooledLength, channels]`
54723 *
54724 * `tf.avgPool1d` is an alias.
54725 *
54726 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
54727 */
54728 function averagePooling1d(args) {
54729 return new AveragePooling1D(args);
54730 }
54731 function avgPool1d(args) {
54732 return averagePooling1d(args);
54733 }
54734 // For backwards compatibility.
54735 // See https://github.com/tensorflow/tfjs/issues/152
54736 function avgPooling1d(args) {
54737 return averagePooling1d(args);
54738 }
54739 /**
54740 * Average pooling operation for spatial data.
54741 *
54742 * Input shape:
54743 * - If `dataFormat === CHANNEL_LAST`:
54744 * 4D tensor with shape:
54745 * `[batchSize, rows, cols, channels]`
54746 * - If `dataFormat === CHANNEL_FIRST`:
54747 * 4D tensor with shape:
54748 * `[batchSize, channels, rows, cols]`
54749 *
54750 * Output shape
54751 * - If `dataFormat === CHANNEL_LAST`:
54752 * 4D tensor with shape:
54753 * `[batchSize, pooledRows, pooledCols, channels]`
54754 * - If `dataFormat === CHANNEL_FIRST`:
54755 * 4D tensor with shape:
54756 * `[batchSize, channels, pooledRows, pooledCols]`
54757 *
54758 * `tf.avgPool2d` is an alias.
54759 *
54760 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
54761 */
54762 function averagePooling2d(args) {
54763 return new AveragePooling2D(args);
54764 }
54765 function avgPool2d(args) {
54766 return averagePooling2d(args);
54767 }
54768 // For backwards compatibility.
54769 // See https://github.com/tensorflow/tfjs/issues/152
54770 function avgPooling2d(args) {
54771 return averagePooling2d(args);
54772 }
54773 /**
54774 * Average pooling operation for 3D data.
54775 *
54776 * Input shape
54777 * - If `dataFormat === channelsLast`:
54778 * 5D tensor with shape:
54779 * `[batchSize, depths, rows, cols, channels]`
54780 * - If `dataFormat === channelsFirst`:
 *       5D tensor with shape:
54782 * `[batchSize, channels, depths, rows, cols]`
54783 *
54784 * Output shape
54785 * - If `dataFormat=channelsLast`:
54786 * 5D tensor with shape:
54787 * `[batchSize, pooledDepths, pooledRows, pooledCols, channels]`
54788 * - If `dataFormat=channelsFirst`:
54789 * 5D tensor with shape:
54790 * `[batchSize, channels, pooledDepths, pooledRows, pooledCols]`
54791 *
54792 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
54793 */
54794 function averagePooling3d(args) {
54795 return new AveragePooling3D(args);
54796 }
54797 function avgPool3d(args) {
54798 return averagePooling3d(args);
54799 }
54800 // For backwards compatibility.
54801 // See https://github.com/tensorflow/tfjs/issues/152
54802 function avgPooling3d(args) {
54803 return averagePooling3d(args);
54804 }
54805 /**
54806 * Global average pooling operation for temporal data.
54807 *
54808 * Input Shape: 3D tensor with shape: `[batchSize, steps, features]`.
54809 *
54810 * Output Shape: 2D tensor with shape: `[batchSize, features]`.
54811 *
54812 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
54813 */
54814 function globalAveragePooling1d(args) {
54815 return new GlobalAveragePooling1D(args);
54816 }
54817 /**
54818 * Global average pooling operation for spatial data.
54819 *
54820 * Input shape:
54821 * - If `dataFormat` is `CHANNEL_LAST`:
54822 * 4D tensor with shape: `[batchSize, rows, cols, channels]`.
54823 * - If `dataFormat` is `CHANNEL_FIRST`:
54824 * 4D tensor with shape: `[batchSize, channels, rows, cols]`.
54825 *
54826 * Output shape:
54827 * 2D tensor with shape: `[batchSize, channels]`.
54828 *
54829 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
54830 */
54831 function globalAveragePooling2d(args) {
54832 return new GlobalAveragePooling2D(args);
54833 }
54834 /**
54835 * Global max pooling operation for temporal data.
54836 *
54837 * Input Shape: 3D tensor with shape: `[batchSize, steps, features]`.
54838 *
54839 * Output Shape: 2D tensor with shape: `[batchSize, features]`.
54840 *
54841 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
54842 */
54843 function globalMaxPooling1d(args) {
54844 return new GlobalMaxPooling1D(args);
54845 }
54846 /**
54847 * Global max pooling operation for spatial data.
54848 *
54849 * Input shape:
54850 * - If `dataFormat` is `CHANNEL_LAST`:
54851 * 4D tensor with shape: `[batchSize, rows, cols, channels]`.
54852 * - If `dataFormat` is `CHANNEL_FIRST`:
54853 * 4D tensor with shape: `[batchSize, channels, rows, cols]`.
54854 *
54855 * Output shape:
54856 * 2D tensor with shape: `[batchSize, channels]`.
54857 *
54858 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
54859 */
54860 function globalMaxPooling2d(args) {
54861 return new GlobalMaxPooling2D(args);
54862 }
54863 /**
54864 * Max pooling operation for temporal data.
54865 *
54866 * Input shape: `[batchSize, inLength, channels]`
54867 *
54868 * Output shape: `[batchSize, pooledLength, channels]`
54869 *
54870 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
54871 */
54872 function maxPooling1d(args) {
54873 return new MaxPooling1D(args);
54874 }
54875 /**
54876 * Max pooling operation for spatial data.
54877 *
54878 * Input shape
54879 * - If `dataFormat === CHANNEL_LAST`:
54880 * 4D tensor with shape:
54881 * `[batchSize, rows, cols, channels]`
54882 * - If `dataFormat === CHANNEL_FIRST`:
54883 * 4D tensor with shape:
54884 * `[batchSize, channels, rows, cols]`
54885 *
54886 * Output shape
54887 * - If `dataFormat=CHANNEL_LAST`:
54888 * 4D tensor with shape:
54889 * `[batchSize, pooledRows, pooledCols, channels]`
54890 * - If `dataFormat=CHANNEL_FIRST`:
54891 * 4D tensor with shape:
54892 * `[batchSize, channels, pooledRows, pooledCols]`
54893 *
54894 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
54895 */
54896 function maxPooling2d(args) {
54897 return new MaxPooling2D(args);
54898 }
54899 /**
54900 * Max pooling operation for 3D data.
54901 *
54902 * Input shape
54903 * - If `dataFormat === channelsLast`:
54904 * 5D tensor with shape:
54905 * `[batchSize, depths, rows, cols, channels]`
54906 * - If `dataFormat === channelsFirst`:
54907 * 5D tensor with shape:
54908 * `[batchSize, channels, depths, rows, cols]`
54909 *
54910 * Output shape
54911 * - If `dataFormat=channelsLast`:
54912 * 5D tensor with shape:
54913 * `[batchSize, pooledDepths, pooledRows, pooledCols, channels]`
54914 * - If `dataFormat=channelsFirst`:
54915 * 5D tensor with shape:
54916 * `[batchSize, channels, pooledDepths, pooledRows, pooledCols]`
54917 *
54918 * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
54919 */
54920 function maxPooling3d(args) {
54921 return new MaxPooling3D(args);
54922 }
54923 // Recurrent Layers.
54924 /**
54925 * Gated Recurrent Unit - Cho et al. 2014.
54926 *
54927 * This is an `RNN` layer consisting of one `GRUCell`. However, unlike
 * the underlying `GRUCell`, the `apply` method of `GRU` operates
54929 * on a sequence of inputs. The shape of the input (not including the first,
54930 * batch dimension) needs to be at least 2-D, with the first dimension being
54931 * time steps. For example:
54932 *
54933 * ```js
54934 * const rnn = tf.layers.gru({units: 8, returnSequences: true});
54935 *
54936 * // Create an input with 10 time steps.
54937 * const input = tf.input({shape: [10, 20]});
54938 * const output = rnn.apply(input);
54939 *
54940 * console.log(JSON.stringify(output.shape));
54941 * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
54942 * // same as the sequence length of `input`, due to `returnSequences`: `true`;
54943 * // 3rd dimension is the `GRUCell`'s number of units.
 * ```
 *
54945 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
54946 */
54947 function gru(args) {
54948 return new GRU(args);
54949 }
54950 /**
54951 * Cell class for `GRU`.
54952 *
54953 * `GRUCell` is distinct from the `RNN` subclass `GRU` in that its
54954 * `apply` method takes the input data of only a single time step and returns
54955 * the cell's output at the time step, while `GRU` takes the input data
54956 * over a number of time steps. For example:
54957 *
54958 * ```js
54959 * const cell = tf.layers.gruCell({units: 2});
54960 * const input = tf.input({shape: [10]});
54961 * const output = cell.apply(input);
54962 *
54963 * console.log(JSON.stringify(output.shape));
54964 * // [null, 10]: This is the cell's output at a single time step. The 1st
54965 * // dimension is the unknown batch size.
54966 * ```
54967 *
54968 * Instance(s) of `GRUCell` can be used to construct `RNN` layers. The
54969 * most typical use of this workflow is to combine a number of cells into a
54970 * stacked RNN cell (i.e., `StackedRNNCell` internally) and use it to create an
54971 * RNN. For example:
54972 *
54973 * ```js
54974 * const cells = [
54975 * tf.layers.gruCell({units: 4}),
54976 * tf.layers.gruCell({units: 8}),
54977 * ];
54978 * const rnn = tf.layers.rnn({cell: cells, returnSequences: true});
54979 *
54980 * // Create an input with 10 time steps and a length-20 vector at each step.
54981 * const input = tf.input({shape: [10, 20]});
54982 * const output = rnn.apply(input);
54983 *
54984 * console.log(JSON.stringify(output.shape));
54985 * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
54986 * // same as the sequence length of `input`, due to `returnSequences`: `true`;
54987 * // 3rd dimension is the last `gruCell`'s number of units.
54988 * ```
54989 *
54990 * To create an `RNN` consisting of only *one* `GRUCell`, use the
54991 * `tf.layers.gru`.
54992 *
54993 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
54994 */
54995 function gruCell(args) {
54996 return new GRUCell(args);
54997 }
54998 /**
54999 * Long-Short Term Memory layer - Hochreiter 1997.
55000 *
55001 * This is an `RNN` layer consisting of one `LSTMCell`. However, unlike
55002 * the underlying `LSTMCell`, the `apply` method of `LSTM` operates
55003 * on a sequence of inputs. The shape of the input (not including the first,
55004 * batch dimension) needs to be at least 2-D, with the first dimension being
55005 * time steps. For example:
55006 *
55007 * ```js
55008 * const lstm = tf.layers.lstm({units: 8, returnSequences: true});
55009 *
55010 * // Create an input with 10 time steps.
55011 * const input = tf.input({shape: [10, 20]});
55012 * const output = lstm.apply(input);
55013 *
55014 * console.log(JSON.stringify(output.shape));
55015 * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
55016 * // same as the sequence length of `input`, due to `returnSequences`: `true`;
55017 * // 3rd dimension is the `LSTMCell`'s number of units.
 * ```
 *
55019 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
55020 */
55021 function lstm(args) {
55022 return new LSTM(args);
55023 }
55024 /**
55025 * Cell class for `LSTM`.
55026 *
55027 * `LSTMCell` is distinct from the `RNN` subclass `LSTM` in that its
55028 * `apply` method takes the input data of only a single time step and returns
55029 * the cell's output at the time step, while `LSTM` takes the input data
55030 * over a number of time steps. For example:
55031 *
55032 * ```js
55033 * const cell = tf.layers.lstmCell({units: 2});
55034 * const input = tf.input({shape: [10]});
55035 * const output = cell.apply(input);
55036 *
55037 * console.log(JSON.stringify(output.shape));
55038 * // [null, 10]: This is the cell's output at a single time step. The 1st
55039 * // dimension is the unknown batch size.
55040 * ```
55041 *
55042 * Instance(s) of `LSTMCell` can be used to construct `RNN` layers. The
55043 * most typical use of this workflow is to combine a number of cells into a
55044 * stacked RNN cell (i.e., `StackedRNNCell` internally) and use it to create an
55045 * RNN. For example:
55046 *
55047 * ```js
55048 * const cells = [
55049 * tf.layers.lstmCell({units: 4}),
55050 * tf.layers.lstmCell({units: 8}),
55051 * ];
55052 * const rnn = tf.layers.rnn({cell: cells, returnSequences: true});
55053 *
55054 * // Create an input with 10 time steps and a length-20 vector at each step.
55055 * const input = tf.input({shape: [10, 20]});
55056 * const output = rnn.apply(input);
55057 *
55058 * console.log(JSON.stringify(output.shape));
55059 * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
55060 * // same as the sequence length of `input`, due to `returnSequences`: `true`;
55061 * // 3rd dimension is the last `lstmCell`'s number of units.
55062 * ```
55063 *
55064 * To create an `RNN` consisting of only *one* `LSTMCell`, use the
55065 * `tf.layers.lstm`.
55066 *
55067 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
55068 */
55069 function lstmCell(args) {
55070 return new LSTMCell(args);
55071 }
55072 /**
55073 * Fully-connected RNN where the output is to be fed back to input.
55074 *
55075 * This is an `RNN` layer consisting of one `SimpleRNNCell`. However, unlike
55076 * the underlying `SimpleRNNCell`, the `apply` method of `SimpleRNN` operates
55077 * on a sequence of inputs. The shape of the input (not including the first,
55078 * batch dimension) needs to be at least 2-D, with the first dimension being
55079 * time steps. For example:
55080 *
55081 * ```js
55082 * const rnn = tf.layers.simpleRNN({units: 8, returnSequences: true});
55083 *
55084 * // Create an input with 10 time steps.
55085 * const input = tf.input({shape: [10, 20]});
55086 * const output = rnn.apply(input);
55087 *
55088 * console.log(JSON.stringify(output.shape));
55089 * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
55090 * // same as the sequence length of `input`, due to `returnSequences`: `true`;
55091 * // 3rd dimension is the `SimpleRNNCell`'s number of units.
55092 * ```
55093 *
55094 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
55095 */
55096 function simpleRNN(args) {
55097 return new SimpleRNN(args);
55098 }
/**
 * Cell class for `SimpleRNN`.
 *
 * Unlike the `RNN` subclass `SimpleRNN`, whose `apply` method consumes input
 * over a number of time steps, a `SimpleRNNCell`'s `apply` method takes the
 * input of only a single time step and returns the cell's output at that
 * step. For example:
 *
 * ```js
 * const cell = tf.layers.simpleRNNCell({units: 2});
 * const input = tf.input({shape: [10]});
 * const output = cell.apply(input);
 *
 * console.log(JSON.stringify(output.shape));
 * // [null, 10]: This is the cell's output at a single time step. The 1st
 * // dimension is the unknown batch size.
 * ```
 *
 * Instance(s) of `SimpleRNNCell` can be used to construct `RNN` layers. The
 * most typical use of this workflow is to combine a number of cells into a
 * stacked RNN cell (i.e., `StackedRNNCell` internally) and use it to create
 * an RNN. For example:
 *
 * ```js
 * const cells = [
 *   tf.layers.simpleRNNCell({units: 4}),
 *   tf.layers.simpleRNNCell({units: 8}),
 * ];
 * const rnn = tf.layers.rnn({cell: cells, returnSequences: true});
 *
 * // Create an input with 10 time steps and a length-20 vector at each step.
 * const input = tf.input({shape: [10, 20]});
 * const output = rnn.apply(input);
 *
 * console.log(JSON.stringify(output.shape));
 * // [null, 10, 8]: 1st dimension is the unknown batch size; 2nd dimension
 * // equals the sequence length of `input`, due to `returnSequences`:
 * // `true`; 3rd dimension is the last `SimpleRNNCell`'s number of units.
 * ```
 *
 * To create an `RNN` consisting of only *one* `SimpleRNNCell`, use
 * `tf.layers.simpleRNN` instead.
 *
 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
 */
function simpleRNNCell(args) {
  const cell = new SimpleRNNCell(args);
  return cell;
}
/**
 * Convolutional LSTM layer - Xingjian Shi 2015.
 *
 * A `ConvRNN2D` layer built around a single `ConvLSTM2DCell`. In contrast to
 * the underlying `ConvLSTM2DCell`, the `apply` method of `ConvLSTM2D`
 * operates on a sequence of inputs: excluding the first (batch) dimension,
 * the input must be 4-D, with time steps along the first of those four
 * dimensions. For example:
 *
 * ```js
 * const filters = 3;
 * const kernelSize = 3;
 *
 * const batchSize = 4;
 * const sequenceLength = 2;
 * const size = 5;
 * const channels = 3;
 *
 * const inputShape = [batchSize, sequenceLength, size, size, channels];
 * const input = tf.ones(inputShape);
 *
 * const layer = tf.layers.convLstm2d({filters, kernelSize});
 *
 * const output = layer.apply(input);
 * ```
 *
 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
 */
function convLstm2d(args) {
  const layer = new ConvLSTM2D(args);
  return layer;
}
/**
 * Cell class for `ConvLSTM2D`.
 *
 * Unlike the `ConvRNN2D` subclass `ConvLSTM2D`, whose `apply` method
 * processes input over a number of time steps, a `ConvLSTM2DCell`'s `call`
 * method takes the input of only a single time step and returns the cell's
 * output at that step. For example:
 *
 * ```js
 * const filters = 3;
 * const kernelSize = 3;
 *
 * const sequenceLength = 1;
 * const size = 5;
 * const channels = 3;
 *
 * const inputShape = [sequenceLength, size, size, channels];
 * const input = tf.ones(inputShape);
 *
 * const cell = tf.layers.convLstm2dCell({filters, kernelSize});
 *
 * cell.build(input.shape);
 *
 * const outputSize = size - kernelSize + 1;
 * const outShape = [sequenceLength, outputSize, outputSize, filters];
 *
 * const initialH = tf.zeros(outShape);
 * const initialC = tf.zeros(outShape);
 *
 * const [o, h, c] = cell.call([input, initialH, initialC], {});
 * ```
 *
 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
 */
function convLstm2dCell(args) {
  const cell = new ConvLSTM2DCell(args);
  return cell;
}
/**
 * Base class for recurrent layers.
 *
 * Input shape:
 *   3D tensor with shape `[batchSize, timeSteps, inputDim]`.
 *
 * Output shape:
 *   - if `returnState`, an Array of tensors (i.e., `tf.Tensor`s). The first
 *     tensor is the output; the remaining tensors are the states at the last
 *     time step, each with shape `[batchSize, units]`.
 *   - if `returnSequences`, the output has shape
 *     `[batchSize, timeSteps, units]`.
 *   - otherwise, the output has shape `[batchSize, units]`.
 *
 * Masking:
 *   This layer supports masking for input data with a variable number of
 *   timesteps. To introduce masks to your data, use an embedding layer with
 *   the `mask_zero` parameter set to `True`.
 *
 * Notes on using statefulness in RNNs:
 *   You can set RNN layers to be 'stateful', which means that the states
 *   computed for the samples in one batch are reused as initial states for
 *   the samples in the next batch. This assumes a one-to-one mapping between
 *   samples in different successive batches.
 *
 *   To enable statefulness:
 *     - specify `stateful: true` in the layer constructor.
 *     - specify a fixed batch size for your model:
 *         for a sequential model, pass `batchInputShape=[...]` to the first
 *         layer in your model;
 *         for a functional model with 1 or more Input layers, pass
 *         `batchShape=[...]` to all the first layers in your model.
 *       This is the expected shape of your inputs *including the batch
 *       size*. It should be a tuple of integers, e.g. `(32, 10, 100)`.
 *     - specify `shuffle=False` when calling fit().
 *
 *   To reset the states of your model, call `.resetStates()` on either a
 *   specific layer, or on your entire model.
 *
 * Note on specifying the initial state of RNNs:
 *   You can specify the initial state symbolically by calling the layer with
 *   the option `initialState`, whose value should be a tensor or list of
 *   tensors representing the initial state of the RNN layer. You can also
 *   specify it numerically by calling `resetStates` with the keyword
 *   argument `states`, whose value should be a numpy array or list of numpy
 *   arrays representing the initial state of the RNN layer.
 *
 * Note on passing external constants to RNNs:
 *   You can pass "external" constants to the cell using the `constants`
 *   keyword argument of the `RNN.call` method; this requires that
 *   `cell.call` accepts the same keyword argument. Such constants can be
 *   used to condition the cell transformation on additional static inputs
 *   (not changing over time), a.k.a. an attention mechanism.
 *
 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
 */
function rnn(args) {
  const layer = new RNN(args);
  return layer;
}
/**
 * Wrapper allowing a stack of RNN cells to behave as a single cell, used to
 * implement efficient stacked RNNs.
 *
 * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
 */
function stackedRNNCells(args) {
  const cell = new StackedRNNCells(args);
  return cell;
}
// Wrapper Layers.
/** @doc {heading: 'Layers', subheading: 'Wrapper', namespace: 'layers'} */
function bidirectional(args) {
  const layer = new Bidirectional(args);
  return layer;
}
/**
 * This wrapper applies a layer to every temporal slice of an input.
 *
 * The input should be at least 3D, and the dimension at index `1` is
 * treated as the temporal dimension.
 *
 * Consider a batch of 32 samples, where each sample is a sequence of 10
 * vectors of 16 dimensions. The batch input shape of the layer is then
 * `[32, 10, 16]`, and the `inputShape`, not including the sample dimension,
 * is `[10, 16]`.
 *
 * You can then use `TimeDistributed` to apply a `Dense` layer to each of the
 * 10 timesteps, independently:
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.timeDistributed({
 *   layer: tf.layers.dense({units: 8}),
 *   inputShape: [10, 16],
 * }));
 *
 * // Now model.outputShape = [null, 10, 8].
 * // The output will then have shape `[32, 10, 8]`.
 *
 * // In subsequent layers, there is no need for `inputShape`:
 * model.add(tf.layers.timeDistributed({layer: tf.layers.dense({units: 32})}));
 * console.log(JSON.stringify(model.outputs[0].shape));
 * // Now model.outputShape = [null, 10, 32].
 * ```
 *
 * The output will then have shape `[32, 10, 32]`.
 *
 * `TimeDistributed` can be used with arbitrary layers, not just `Dense`; for
 * instance, with a `Conv2D` layer:
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.timeDistributed({
 *   layer: tf.layers.conv2d({filters: 64, kernelSize: [3, 3]}),
 *   inputShape: [10, 299, 299, 3],
 * }));
 * console.log(JSON.stringify(model.outputs[0].shape));
 * ```
 *
 * @doc {heading: 'Layers', subheading: 'Wrapper', namespace: 'layers'}
 */
function timeDistributed(args) {
  const layer = new TimeDistributed(args);
  return layer;
}
// Aliases for pooling. Each alias refers to the same factory function as its
// longer-named counterpart and is kept for API convenience/compatibility.
const globalMaxPool1d = globalMaxPooling1d;
const globalMaxPool2d = globalMaxPooling2d;
const maxPool1d = maxPooling1d;
const maxPool2d = maxPooling2d;
/**
 * Apply additive zero-centered Gaussian noise.
 *
 * As a regularization layer, it is only active at training time. This is
 * useful to mitigate overfitting (you could see it as a form of random data
 * augmentation). Gaussian Noise (GS) is a natural choice as a corruption
 * process for real-valued inputs.
 *
 * # Arguments
 *   stddev: float, standard deviation of the noise distribution.
 *
 * # Input shape
 *   Arbitrary. Use the keyword argument `input_shape` (tuple of integers,
 *   does not include the samples axis) when using this layer as the first
 *   layer in a model.
 *
 * # Output shape
 *   Same shape as input.
 *
 * @doc {heading: 'Layers', subheading: 'Noise', namespace: 'layers'}
 */
function gaussianNoise(args) {
  const layer = new GaussianNoise(args);
  return layer;
}
/**
 * Apply multiplicative 1-centered Gaussian noise.
 *
 * As a regularization layer, it is only active at training time.
 *
 * Arguments:
 *   - `rate`: float, drop probability (as with `Dropout`). The
 *     multiplicative noise will have standard deviation
 *     `sqrt(rate / (1 - rate))`.
 *
 * Input shape:
 *   Arbitrary. Use the keyword argument `inputShape` (tuple of integers,
 *   does not include the samples axis) when using this layer as the first
 *   layer in a model.
 *
 * Output shape:
 *   Same shape as input.
 *
 * References:
 *   - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting](
 *     http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf)
 *
 * @doc {heading: 'Layers', subheading: 'Noise', namespace: 'layers'}
 */
function gaussianDropout(args) {
  const layer = new GaussianDropout(args);
  return layer;
}
/**
 * Applies Alpha Dropout to the input.
 *
 * As a regularization layer, it is only active at training time.
 *
 * Alpha Dropout is a `Dropout` that keeps the mean and variance of its
 * inputs at their original values, so as to preserve the self-normalizing
 * property even after dropout. It fits well with Scaled Exponential Linear
 * Units by randomly setting activations to the negative saturation value.
 *
 * Arguments:
 *   - `rate`: float, drop probability (as with `Dropout`). The
 *     multiplicative noise will have standard deviation
 *     `sqrt(rate / (1 - rate))`.
 *   - `noise_shape`: A 1-D `Tensor` of type `int32`, representing the shape
 *     for randomly generated keep/drop flags.
 *
 * Input shape:
 *   Arbitrary. Use the keyword argument `inputShape` (tuple of integers,
 *   does not include the samples axis) when using this layer as the first
 *   layer in a model.
 *
 * Output shape:
 *   Same shape as input.
 *
 * References:
 *   - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
 *
 * @doc {heading: 'Layers', subheading: 'Noise', namespace: 'layers'}
 */
function alphaDropout(args) {
  const layer = new AlphaDropout(args);
  return layer;
}
/**
 * Masks a sequence by using a mask value to skip timesteps.
 *
 * If all features for a given sample timestep equal `mask_value`, that
 * sample timestep is masked (skipped) in all downstream layers, as long as
 * they support masking. If any downstream layer does not support masking yet
 * receives such an input mask, an exception is raised.
 *
 * Arguments:
 *   - `maskValue`: Either None or mask value to skip.
 *
 * Input shape:
 *   Arbitrary. Use the keyword argument `inputShape` (tuple of integers,
 *   does not include the samples axis) when using this layer as the first
 *   layer in a model.
 *
 * Output shape:
 *   Same shape as input.
 *
 * @doc {heading: 'Layers', subheading: 'Mask', namespace: 'layers'}
 */
function masking(args) {
  const layer = new Masking(args);
  return layer;
}
/**
 * A preprocessing layer which rescales input values to a new range.
 *
 * This layer rescales every value of an input (often an image) by
 * multiplying by `scale` and adding `offset`.
 *
 * For instance:
 *   1. To rescale an input in the `[0, 255]` range to be in the `[0, 1]`
 *      range, you would pass `scale=1/255`.
 *   2. To rescale an input in the `[0, 255]` range to be in the `[-1, 1]`
 *      range, you would pass `scale=1./127.5, offset=-1`.
 *
 * The rescaling is applied both during training and inference. Inputs may be
 * of integer or floating point dtype; by default the layer outputs floats.
 *
 * Arguments:
 *   - `scale`: Float, the scale to apply to the inputs.
 *   - `offset`: Float, the offset to apply to the inputs.
 *
 * Input shape:
 *   Arbitrary.
 *
 * Output shape:
 *   Same as input.
 *
 * @doc {heading: 'Layers', subheading: 'Rescaling', namespace: 'layers'}
 */
function rescaling(args) {
  const layer = new Rescaling(args);
  return layer;
}
/**
 * A preprocessing layer which center crops images.
 *
 * This layer crops the central portion of the images to a target size. If an
 * image is smaller than the target size, it is resized and cropped so as to
 * return the largest possible window in the image that matches the target
 * aspect ratio.
 *
 * Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and
 * of integer or floating point dtype.
 *
 * If the input height/width is even and the target height/width is odd (or
 * inversely), the input image is left-padded by 1 pixel.
 *
 * Arguments:
 *   `height`: Integer, the height of the output shape.
 *   `width`: Integer, the width of the output shape.
 *
 * Input shape:
 *   3D (unbatched) or 4D (batched) tensor with shape:
 *   `(..., height, width, channels)`, in `channelsLast` format.
 *
 * Output shape:
 *   3D (unbatched) or 4D (batched) tensor with shape:
 *   `(..., targetHeight, targetWidth, channels)`.
 *
 * @doc {heading: 'Layers', subheading: 'CenterCrop', namespace: 'layers'}
 */
function centerCrop(args) {
  const layer = new CenterCrop(args);
  return layer;
}
/**
 * A preprocessing layer which resizes images.
 *
 * This layer resizes an image input to a target height and width. The input
 * should be a 4D (batched) or 3D (unbatched) tensor in `"channels_last"`
 * format. Input pixel values can be of any range (e.g. `[0., 1.)` or
 * `[0, 255]`) and of integer or floating point dtype. By default, the layer
 * will output floats.
 *
 * Arguments:
 *   - `height`: number, the height for the output tensor.
 *   - `width`: number, the width for the output tensor.
 *   - `interpolation`: string, the method for image resizing interpolation.
 *   - `cropToAspectRatio`: boolean, whether to keep image aspect ratio.
 *
 * Input shape:
 *   Arbitrary.
 *
 * Output shape:
 *   height, width, num channels.
 *
 * @doc {heading: 'Layers', subheading: 'Resizing', namespace: 'layers'}
 */
function resizing(args) {
  return new Resizing(args);
}
/**
 * A preprocessing layer which encodes integer features.
 *
 * This layer provides options for condensing data into a categorical
 * encoding when the total number of tokens is known in advance. It accepts
 * integer values as inputs, and it outputs a dense representation of those
 * inputs.
 *
 * Arguments:
 *
 *   numTokens: The total number of tokens the layer should support. All
 *     inputs to the layer must be integers in the range `0 <= value <
 *     numTokens`, or an error will be thrown.
 *
 *   outputMode: Specification for the output of the layer.
 *     Defaults to `multiHot`. Values can be `oneHot`, `multiHot` or
 *     `count`, configuring the layer as follows:
 *
 *       oneHot: Encodes each individual element in the input into an
 *         array of `numTokens` size, containing a 1 at the element index. If
 *         the last dimension is size 1, will encode on that dimension. If
 *         the last dimension is not size 1, will append a new dimension for
 *         the encoded output.
 *
 *       multiHot: Encodes each sample in the input into a single array
 *         of `numTokens` size, containing a 1 for each vocabulary term
 *         present in the sample. Treats the last dimension as the sample
 *         dimension; if input shape is `(..., sampleLength)`, output shape
 *         will be `(..., numTokens)`.
 *
 *       count: Like `multiHot`, but the int array contains a count of
 *         the number of times the token at that index appeared in the
 *         sample.
 *
 * For all output modes, currently only output up to rank 2 is supported.
 *
 * Call arguments:
 *   inputs: A 1D or 2D tensor of integer inputs.
 *   countWeights: A tensor in the same shape as `inputs` indicating the
 *     weight for each sample value when summing up in `count` mode. Not used
 *     in `multiHot` or `oneHot` modes.
 *
 * @doc {heading: 'Layers', subheading: 'CategoryEncoding', namespace: 'layers'}
 */
function categoryEncoding(args) {
  return new CategoryEncoding(args);
}
/**
 * A preprocessing layer which randomly varies image width during training.
 *
 * This layer randomly adjusts the width of a batch of images by a random
 * factor.
 *
 * The input should be a 3D (unbatched) or 4D (batched) tensor in the
 * `"channels_last"` image data format. Input pixel values can be of any
 * range (e.g. `[0., 1.)` or `[0, 255]`) and of integer or floating point
 * dtype. By default, the layer will output floats. By default, this layer is
 * inactive during inference. For an overview and full list of preprocessing
 * layers, see the preprocessing [guide]
 * (https://www.tensorflow.org/guide/keras/preprocessing_layers).
 *
 * Arguments:
 *
 *   factor:
 *     A positive float (fraction of original width), or a tuple of size 2
 *     representing the lower and upper bound for resizing horizontally
 *     (i.e., along the width axis). When represented as a single float, this
 *     value is used for both the upper and lower bound. For instance,
 *     `factor=(0.2, 0.3)` results in an output with width changed by a
 *     random amount in the range `[20%, 30%]`. `factor=(-0.2, 0.3)` results
 *     in an output with width changed by a random amount in the range
 *     `[-20%, +30%]`. `factor=0.2` results in an output with width changed
 *     by a random amount in the range `[-20%, +20%]`.
 *   interpolation:
 *     String, the interpolation method. Defaults to `bilinear`.
 *     Supports `"bilinear"`, `"nearest"`.
 *     The tf methods `"bicubic"`, `"area"`, `"lanczos3"`, `"lanczos5"`,
 *     `"gaussian"`, `"mitchellcubic"` are unimplemented in tfjs.
 *   seed:
 *     Integer. Used to create a random seed.
 *
 * Input shape:
 *   3D (unbatched) or 4D (batched) tensor with shape:
 *   `(..., height, width, channels)`, in `"channels_last"` format.
 * Output shape:
 *   3D (unbatched) or 4D (batched) tensor with shape:
 *   `(..., height, random_width, channels)`.
 *
 * @doc {heading: 'Layers', subheading: 'RandomWidth', namespace: 'layers'}
 */
function randomWidth(args) {
  return new RandomWidth(args);
}
55638
// Frozen namespace object collecting the public `tf.layers.*` factory
// functions, classes, and aliases defined above. Property order is preserved
// as written; the /*#__PURE__*/ annotation lets bundlers tree-shake this
// expression when unused.
var exports_layers = /*#__PURE__*/Object.freeze({
  __proto__: null,
  Layer: Layer,
  RNN: RNN,
  RNNCell: RNNCell,
  activation: activation,
  add: add$1,
  alphaDropout: alphaDropout,
  average: average,
  averagePooling1d: averagePooling1d,
  averagePooling2d: averagePooling2d,
  averagePooling3d: averagePooling3d,
  avgPool1d: avgPool1d,
  avgPool2d: avgPool2d,
  avgPool3d: avgPool3d,
  avgPooling1d: avgPooling1d,
  avgPooling2d: avgPooling2d,
  avgPooling3d: avgPooling3d,
  batchNormalization: batchNormalization,
  bidirectional: bidirectional,
  categoryEncoding: categoryEncoding,
  centerCrop: centerCrop,
  concatenate: concatenate,
  conv1d: conv1d,
  conv2d: conv2d$1,
  conv2dTranspose: conv2dTranspose,
  conv3d: conv3d,
  conv3dTranspose: conv3dTranspose,
  convLstm2d: convLstm2d,
  convLstm2dCell: convLstm2dCell,
  cropping2D: cropping2D,
  dense: dense,
  depthwiseConv2d: depthwiseConv2d,
  dot: dot,
  dropout: dropout,
  elu: elu$2,
  embedding: embedding,
  flatten: flatten,
  gaussianDropout: gaussianDropout,
  gaussianNoise: gaussianNoise,
  globalAveragePooling1d: globalAveragePooling1d,
  globalAveragePooling2d: globalAveragePooling2d,
  globalMaxPool1d: globalMaxPool1d,
  globalMaxPool2d: globalMaxPool2d,
  globalMaxPooling1d: globalMaxPooling1d,
  globalMaxPooling2d: globalMaxPooling2d,
  gru: gru,
  gruCell: gruCell,
  input: input,
  inputLayer: inputLayer,
  layerNormalization: layerNormalization,
  leakyReLU: leakyReLU,
  lstm: lstm,
  lstmCell: lstmCell,
  masking: masking,
  maxPool1d: maxPool1d,
  maxPool2d: maxPool2d,
  maxPooling1d: maxPooling1d,
  maxPooling2d: maxPooling2d,
  maxPooling3d: maxPooling3d,
  maximum: maximum$2,
  minimum: minimum$2,
  multiply: multiply$2,
  permute: permute,
  prelu: prelu$2,
  randomWidth: randomWidth,
  reLU: reLU,
  repeatVector: repeatVector,
  rescaling: rescaling,
  reshape: reshape$2,
  resizing: resizing,
  rnn: rnn,
  separableConv2d: separableConv2d,
  simpleRNN: simpleRNN,
  simpleRNNCell: simpleRNNCell,
  softmax: softmax$2,
  spatialDropout1d: spatialDropout1d,
  stackedRNNCells: stackedRNNCells,
  thresholdedReLU: thresholdedReLU,
  timeDistributed: timeDistributed,
  upSampling2d: upSampling2d,
  zeroPadding2d: zeroPadding2d
});
55722
/**
 * Binary accuracy metric function.
 *
 * `yTrue` and `yPred` can have 0-1 values. Example:
 * ```js
 * const x = tf.tensor2d([[1, 1, 1, 1], [0, 0, 0, 0]], [2, 4]);
 * const y = tf.tensor2d([[1, 0, 1, 0], [0, 0, 0, 1]], [2, 4]);
 * const accuracy = tf.metrics.binaryAccuracy(x, y);
 * accuracy.print();
 * ```
 *
 * `yTrue` and `yPred` can also have floating-number values between 0 and 1,
 * in which case the values are thresholded at 0.5 to yield 0-1 values (i.e.,
 * a value >= 0.5 and <= 1.0 is interpreted as 1).
 *
 * Example:
 * ```js
 * const x = tf.tensor1d([1, 1, 1, 1, 0, 0, 0, 0]);
 * const y = tf.tensor1d([0.2, 0.4, 0.6, 0.8, 0.2, 0.3, 0.4, 0.7]);
 * const accuracy = tf.metrics.binaryAccuracy(x, y);
 * accuracy.print();
 * ```
 *
 * @param yTrue Binary Tensor of truth.
 * @param yPred Binary Tensor of prediction.
 * @return Accuracy Tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function binaryAccuracy(yTrue, yPred) {
  const result = binaryAccuracy$1(yTrue, yPred);
  return result;
}
/**
 * Binary crossentropy metric function.
 *
 * Example:
 * ```js
 * const x = tf.tensor2d([[0], [1], [1], [1]]);
 * const y = tf.tensor2d([[0], [0], [0.5], [1]]);
 * const crossentropy = tf.metrics.binaryCrossentropy(x, y);
 * crossentropy.print();
 * ```
 *
 * @param yTrue Binary Tensor of truth.
 * @param yPred Binary Tensor of prediction, probabilities for the `1` case.
 * @return Crossentropy Tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function binaryCrossentropy(yTrue, yPred) {
  const result = binaryCrossentropy$1(yTrue, yPred);
  return result;
}
/**
 * Sparse categorical accuracy metric function.
 *
 * Example:
 * ```js
 * const yTrue = tf.tensor1d([1, 1, 2, 2, 0]);
 * const yPred = tf.tensor2d(
 *     [[0, 1, 0], [1, 0, 0], [0, 0.4, 0.6], [0, 0.6, 0.4], [0.7, 0.3, 0]]);
 * const accuracy = tf.metrics.sparseCategoricalAccuracy(yTrue, yPred);
 * accuracy.print();
 * ```
 *
 * @param yTrue True labels: indices.
 * @param yPred Predicted probabilities or logits.
 * @returns Accuracy tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function sparseCategoricalAccuracy(yTrue, yPred) {
  const result = sparseCategoricalAccuracy$1(yTrue, yPred);
  return result;
}
/**
 * Categorical accuracy metric function.
 *
 * Example:
 * ```js
 * const x = tf.tensor2d([[0, 0, 0, 1], [0, 0, 0, 1]]);
 * const y = tf.tensor2d([[0.1, 0.8, 0.05, 0.05], [0.1, 0.05, 0.05, 0.8]]);
 * const accuracy = tf.metrics.categoricalAccuracy(x, y);
 * accuracy.print();
 * ```
 *
 * @param yTrue Binary Tensor of truth: one-hot encoding of categories.
 * @param yPred Binary Tensor of prediction: probabilities or logits for the
 *   same categories as in `yTrue`.
 * @return Accuracy Tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function categoricalAccuracy(yTrue, yPred) {
  const result = categoricalAccuracy$1(yTrue, yPred);
  return result;
}
/**
 * Categorical crossentropy between a truth tensor and a prediction tensor.
 *
 * (The previous doc described `target`/`output`/`fromLogits` parameters that
 * do not match this wrapper's actual `(yTrue, yPred)` signature; this
 * wrapper takes no `fromLogits` argument.)
 *
 * @param yTrue A tensor of the same shape as `yPred`; typically a one-hot
 *   encoding of the target categories.
 * @param yPred A tensor resulting from a softmax: predicted probabilities
 *   for the same categories as in `yTrue`.
 * @return Crossentropy Tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function categoricalCrossentropy(yTrue, yPred) {
  return categoricalCrossentropy$1(yTrue, yPred);
}
/**
 * Computes the precision of the predictions with respect to the labels.
 *
 * Example:
 * ```js
 * const x = tf.tensor2d(
 *     [
 *       [0, 0, 0, 1],
 *       [0, 1, 0, 0],
 *       [0, 0, 0, 1],
 *       [1, 0, 0, 0],
 *       [0, 0, 1, 0]
 *     ]
 * );
 *
 * const y = tf.tensor2d(
 *     [
 *       [0, 0, 1, 0],
 *       [0, 1, 0, 0],
 *       [0, 0, 0, 1],
 *       [0, 1, 0, 0],
 *       [0, 1, 0, 0]
 *     ]
 * );
 *
 * const precision = tf.metrics.precision(x, y);
 * precision.print();
 * ```
 *
 * @param yTrue The ground truth values. Expected to contain only 0-1 values.
 * @param yPred The predicted values. Expected to contain only 0-1 values.
 * @return Precision Tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function precision(yTrue, yPred) {
  const result = precision$1(yTrue, yPred);
  return result;
}
/**
 * Computes the recall of the predictions with respect to the labels.
 *
 * Example:
 * ```js
 * const x = tf.tensor2d(
 *     [
 *       [0, 0, 0, 1],
 *       [0, 1, 0, 0],
 *       [0, 0, 0, 1],
 *       [1, 0, 0, 0],
 *       [0, 0, 1, 0]
 *     ]
 * );
 *
 * const y = tf.tensor2d(
 *     [
 *       [0, 0, 1, 0],
 *       [0, 1, 0, 0],
 *       [0, 0, 0, 1],
 *       [0, 1, 0, 0],
 *       [0, 1, 0, 0]
 *     ]
 * );
 *
 * const recall = tf.metrics.recall(x, y);
 * recall.print();
 * ```
 *
 * @param yTrue The ground truth values. Expected to contain only 0-1 values.
 * @param yPred The predicted values. Expected to contain only 0-1 values.
 * @return Recall Tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function recall(yTrue, yPred) {
  const result = recall$1(yTrue, yPred);
  return result;
}
/**
 * Loss or metric function: Cosine proximity.
 *
 * Mathematically, cosine proximity is defined as:
 *   `-sum(l2Normalize(yTrue) * l2Normalize(yPred))`,
 * where `l2Normalize()` normalizes the L2 norm of the input to 1 and `*`
 * denotes element-wise multiplication.
 *
 * ```js
 * const yTrue = tf.tensor2d([[1, 0], [1, 0]]);
 * const yPred = tf.tensor2d([[1 / Math.sqrt(2), 1 / Math.sqrt(2)], [0, 1]]);
 * const proximity = tf.metrics.cosineProximity(yTrue, yPred);
 * proximity.print();
 * ```
 *
 * @param yTrue Truth Tensor.
 * @param yPred Prediction Tensor.
 * @return Cosine proximity Tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function cosineProximity(yTrue, yPred) {
  const result = cosineProximity$1(yTrue, yPred);
  return result;
}
/**
 * Loss or metric function: Mean absolute error.
 *
 * Mathematically, mean absolute error is defined as:
 *   `mean(abs(yPred - yTrue))`,
 * where the `mean` is applied over feature dimensions.
 *
 * ```js
 * const yTrue = tf.tensor2d([[0, 1], [0, 0], [2, 3]]);
 * const yPred = tf.tensor2d([[0, 1], [0, 1], [-2, -3]]);
 * const mae = tf.metrics.meanAbsoluteError(yTrue, yPred);
 * mae.print();
 * ```
 *
 * @param yTrue Truth Tensor.
 * @param yPred Prediction Tensor.
 * @return Mean absolute error Tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function meanAbsoluteError(yTrue, yPred) {
  const result = meanAbsoluteError$1(yTrue, yPred);
  return result;
}
/**
 * Loss or metric function: Mean absolute percentage error.
 *
 * ```js
 * const yTrue = tf.tensor2d([[0, 1], [10, 20]]);
 * const yPred = tf.tensor2d([[0, 1], [11, 24]]);
 * const mape = tf.metrics.meanAbsolutePercentageError(yTrue, yPred);
 * mape.print();
 * ```
 *
 * Aliases: `tf.metrics.MAPE`, `tf.metrics.mape`.
 *
 * @param yTrue Truth Tensor.
 * @param yPred Prediction Tensor.
 * @return Mean absolute percentage error Tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function meanAbsolutePercentageError(yTrue, yPred) {
  const result = meanAbsolutePercentageError$1(yTrue, yPred);
  return result;
}
// Alias for `meanAbsolutePercentageError`.
function MAPE(yTrue, yPred) {
  return meanAbsolutePercentageError$1(yTrue, yPred);
}
// Alias for `meanAbsolutePercentageError`.
function mape(yTrue, yPred) {
  return meanAbsolutePercentageError$1(yTrue, yPred);
}
/**
 * Loss or metric function: Mean squared error.
 *
 * ```js
 * const yTrue = tf.tensor2d([[0, 1], [3, 4]]);
 * const yPred = tf.tensor2d([[0, 1], [-3, -4]]);
 * const mse = tf.metrics.meanSquaredError(yTrue, yPred);
 * mse.print();
 * ```
 *
 * Aliases: `tf.metrics.MSE`, `tf.metrics.mse`.
 *
 * @param yTrue Truth Tensor.
 * @param yPred Prediction Tensor.
 * @return Mean squared error Tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function meanSquaredError(yTrue, yPred) {
  const result = meanSquaredError$1(yTrue, yPred);
  return result;
}
// Alias for `meanSquaredError`.
function MSE(yTrue, yPred) {
  return meanSquaredError$1(yTrue, yPred);
}
// Alias for `meanSquaredError`.
function mse(yTrue, yPred) {
  return meanSquaredError$1(yTrue, yPred);
}
/**
 * Computes R2 score.
 *
 * ```js
 * const yTrue = tf.tensor2d([[0, 1], [3, 4]]);
 * const yPred = tf.tensor2d([[0, 1], [-3, -4]]);
 * const r2Score = tf.metrics.r2Score(yTrue, yPred);
 * r2Score.print();
 * ```
 *
 * @param yTrue Truth Tensor.
 * @param yPred Prediction Tensor.
 * @return R2 score Tensor.
 *
 * @doc {heading: 'Metrics', namespace: 'metrics'}
 */
function r2Score(yTrue, yPred) {
  const result = r2Score$1(yTrue, yPred);
  return result;
}
56027
// Frozen namespace object backing `tf.metrics`: maps the public metric names
// (including the MAPE/MSE shorthand aliases) to their implementations.
// NOTE: key order is preserved by Object.freeze and observable via
// Object.keys, so entries are kept exactly as generated.
var exports_metrics = /*#__PURE__*/Object.freeze({
    __proto__: null,
    MAPE: MAPE,
    MSE: MSE,
    binaryAccuracy: binaryAccuracy,
    binaryCrossentropy: binaryCrossentropy,
    categoricalAccuracy: categoricalAccuracy,
    categoricalCrossentropy: categoricalCrossentropy,
    cosineProximity: cosineProximity,
    mape: mape,
    meanAbsoluteError: meanAbsoluteError,
    meanAbsolutePercentageError: meanAbsolutePercentageError,
    meanSquaredError: meanSquaredError,
    mse: mse,
    precision: precision,
    r2Score: r2Score,
    recall: recall,
    sparseCategoricalAccuracy: sparseCategoricalAccuracy
});
56047
56048 /**
56049 * @license
56050 * Copyright 2018 Google LLC
56051 *
56052 * Use of this source code is governed by an MIT-style
56053 * license that can be found in the LICENSE file or at
56054 * https://opensource.org/licenses/MIT.
56055 * =============================================================================
56056 */
56057
// Frozen namespace object backing `tf.models`: model-construction helpers.
var exports_models = /*#__PURE__*/Object.freeze({
    __proto__: null,
    modelFromJSON: modelFromJSON
});
56062
56063 /**
56064 * @license
56065 * Copyright 2018 Google LLC
56066 *
56067 * Use of this source code is governed by an MIT-style
56068 * license that can be found in the LICENSE file or at
56069 * https://opensource.org/licenses/MIT.
56070 * =============================================================================
56071 */
/**
 * Regularizer for combined L1 and L2 regularization.
 *
 * Adds a term to the loss to penalize large weights:
 * loss += sum(l1 * abs(x)) + sum(l2 * x^2)
 *
 * @param config Combined l1/l2 configuration.
 *
 * @doc {heading: 'Regularizers', namespace: 'regularizers'}
 */
function l1l2(config) {
    // Construct the serializable L1L2 regularizer directly.
    const regularizer = new L1L2(config);
    return regularizer;
}
/**
 * Regularizer for L1 regularization.
 *
 * Adds a term to the loss to penalize large weights:
 * loss += sum(l1 * abs(x))
 * @param config l1 config.
 *
 * @doc {heading: 'Regularizers', namespace: 'regularizers'}
 */
function l1(config) {
    // Delegate to the underlying l1 regularizer factory.
    const regularizer = l1$1(config);
    return regularizer;
}
/**
 * Regularizer for L2 regularization.
 *
 * Adds a term to the loss to penalize large weights:
 * loss += sum(l2 * x^2)
 * @param config l2 config.
 *
 * @doc {heading: 'Regularizers', namespace: 'regularizers'}
 */
function l2(config) {
    // Delegate to the underlying l2 regularizer factory.
    const regularizer = l2$1(config);
    return regularizer;
}
56107
// Frozen namespace object backing `tf.regularizers`.
var exports_regularizers = /*#__PURE__*/Object.freeze({
    __proto__: null,
    l1: l1,
    l1l2: l1l2,
    l2: l2
});
56114
56115 /**
56116 * @license
56117 * Copyright 2018 Google LLC
56118 *
56119 * Use of this source code is governed by an MIT-style
56120 * license that can be found in the LICENSE file or at
56121 * https://opensource.org/licenses/MIT.
56122 * =============================================================================
56123 */
/**
 * Base class for user-facing training callbacks.
 * Extends `BaseCallback` and additionally holds a reference to the
 * `LayersModel` being trained, validated in `setModel`.
 */
class Callback extends BaseCallback {
    constructor() {
        super(...arguments);
        /** Instance of `keras.models.Model`. Reference of the model being trained. */
        this.model = null;
    }
    /**
     * Attaches the model being trained.
     * @throws Error if `model` is not a `LayersModel` (e.g. a bare Container).
     */
    setModel(model) {
        if (!(model instanceof LayersModel)) {
            throw new Error('model must be a LayersModel, not some other Container');
        }
        this.model = model;
    }
}
// Comparator used by EarlyStopping in 'min' mode: the monitored quantity
// improved iff it strictly decreased.
function less$2(currVal, prevVal) {
    return prevVal > currVal;
}
// Comparator used by EarlyStopping in 'max' mode: the monitored quantity
// improved iff it strictly increased.
function greater$2(currVal, prevVal) {
    return prevVal < currVal;
}
56143 /**
56144 * A Callback that stops training when a monitored quantity has stopped
56145 * improving.
56146 */
class EarlyStopping extends Callback {
    /**
     * @param args Optional configuration:
     *   - monitor: name of the metric to watch (default 'val_loss').
     *   - minDelta: minimum change counting as an improvement (default 0).
     *   - patience: epochs with no improvement before stopping (default 0).
     *   - verbose: verbosity level (default 0).
     *   - mode: 'auto' | 'min' | 'max' (default 'auto').
     *   - baseline: value the monitored metric must beat.
     * @throws NotImplementedError if `restoreBestWeights` is requested.
     */
    constructor(args) {
        super();
        if (args == null) {
            args = {};
        }
        if (args.restoreBestWeights) {
            throw new NotImplementedError('restoreBestWeights = True is not implemented in EarlyStopping yet.');
        }
        this.monitor = args.monitor || 'val_loss';
        this.minDelta = Math.abs(args.minDelta || 0);
        this.patience = args.patience || 0;
        this.verbose = args.verbose || 0;
        this.mode = args.mode || 'auto';
        this.baseline = args.baseline;
        if (['auto', 'min', 'max'].indexOf(this.mode) === -1) {
            console.warn(`EarlyStopping mode '${this.mode}' is invalid. ` +
                `Falling back to mode 'auto'.`);
            this.mode = 'auto';
        }
        // Choose the comparison that means "improved" for this mode.
        if (this.mode === 'min') {
            this.monitorFunc = less$2;
        }
        else if (this.mode === 'max') {
            this.monitorFunc = greater$2;
        }
        else {
            // For mode === 'auto'.
            // Metric names containing 'acc' are treated as higher-is-better.
            if (this.monitor.indexOf('acc') !== -1) {
                this.monitorFunc = greater$2;
            }
            else {
                this.monitorFunc = less$2;
            }
        }
        // For lower-is-better metrics, negate minDelta so that the
        // `current - minDelta` comparison in onEpochEnd tightens the
        // improvement threshold in the right direction.
        if (this.monitorFunc === less$2) {
            this.minDelta *= -1;
        }
    }
    /** Resets the wait counter and the best value seen so far. */
    async onTrainBegin(logs) {
        this.wait = 0;
        this.stoppedEpoch = 0;
        if (this.baseline != null) {
            this.best = this.baseline;
        }
        else {
            this.best = this.monitorFunc === less$2 ? Infinity : -Infinity;
        }
    }
    /**
     * Checks the monitored metric after each epoch and sets
     * `model.stopTraining` once `patience` epochs pass without improvement.
     * Missing metrics are ignored (a warning is emitted by getMonitorValue).
     */
    async onEpochEnd(epoch, logs) {
        await resolveScalarsInLogs(logs);
        const current = this.getMonitorValue(logs);
        if (current == null) {
            return;
        }
        if (this.monitorFunc(current - this.minDelta, this.best)) {
            this.best = current;
            this.wait = 0;
            // TODO(cais): Logic for restoreBestWeights.
        }
        else {
            this.wait++;
            if (this.wait >= this.patience) {
                this.stoppedEpoch = epoch;
                this.model.stopTraining = true;
            }
            // TODO(cais): Logic for restoreBestWeights.
        }
    }
    /** Logs the epoch at which training was stopped, if verbose. */
    async onTrainEnd(logs) {
        if (this.stoppedEpoch > 0 && this.verbose) {
            console.log(`Epoch ${this.stoppedEpoch}: early stopping.`);
        }
    }
    /** Reads the monitored metric from `logs`, warning when it is absent. */
    getMonitorValue(logs) {
        if (logs == null) {
            logs = {};
        }
        const monitorValue = logs[this.monitor];
        if (monitorValue == null) {
            console.warn(`Metric for EarlyStopping ${this.monitor} is not available. ` +
                `Available metrics are: ${Object.keys(logs)}`);
        }
        return monitorValue;
    }
}
56233 /**
56234 * Factory function for a Callback that stops training when a monitored
56235 * quantity has stopped improving.
56236 *
56237 * Early stopping is a type of regularization, and protects model against
56238 * overfitting.
56239 *
56240 * The following example based on fake data illustrates how this callback
56241 * can be used during `tf.LayersModel.fit()`:
56242 *
56243 * ```js
56244 * const model = tf.sequential();
56245 * model.add(tf.layers.dense({
56246 * units: 3,
56247 * activation: 'softmax',
56248 * kernelInitializer: 'ones',
56249 * inputShape: [2]
56250 * }));
56251 * const xs = tf.tensor2d([1, 2, 3, 4], [2, 2]);
56252 * const ys = tf.tensor2d([[1, 0, 0], [0, 1, 0]], [2, 3]);
56253 * const xsVal = tf.tensor2d([4, 3, 2, 1], [2, 2]);
56254 * const ysVal = tf.tensor2d([[0, 0, 1], [0, 1, 0]], [2, 3]);
56255 * model.compile(
56256 * {loss: 'categoricalCrossentropy', optimizer: 'sgd', metrics: ['acc']});
56257 *
56258 * // Without the EarlyStopping callback, the val_acc value would be:
56259 * // 0.5, 0.5, 0.5, 0.5, ...
56260 * // With val_acc being monitored, training should stop after the 2nd epoch.
56261 * const history = await model.fit(xs, ys, {
56262 * epochs: 10,
56263 * validationData: [xsVal, ysVal],
56264 * callbacks: tf.callbacks.earlyStopping({monitor: 'val_acc'})
56265 * });
56266 *
56267 * // Expect to see a length-2 array.
56268 * console.log(history.history.val_acc);
56269 * ```
56270 *
56271 * @doc {
56272 * heading: 'Callbacks',
56273 * namespace: 'callbacks'
56274 * }
56275 */
function earlyStopping(args) {
    return new EarlyStopping(args);
}
// Public `tf.callbacks` namespace: factory functions for built-in callbacks.
const callbacks = { earlyStopping };
56280
56281 /**
56282 * @license
56283 * Copyright 2018 Google LLC
56284 *
56285 * Use of this source code is governed by an MIT-style
56286 * license that can be found in the LICENSE file or at
56287 * https://opensource.org/licenses/MIT.
56288 * =============================================================================
56289 */
56290
56291 /**
56292 * @license
56293 * Copyright 2021 Google LLC. All Rights Reserved.
56294 * Licensed under the Apache License, Version 2.0 (the "License");
56295 * you may not use this file except in compliance with the License.
56296 * You may obtain a copy of the License at
56297 *
56298 * http://www.apache.org/licenses/LICENSE-2.0
56299 *
56300 * Unless required by applicable law or agreed to in writing, software
56301 * distributed under the License is distributed on an "AS IS" BASIS,
56302 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56303 * See the License for the specific language governing permissions and
56304 * limitations under the License.
56305 * =============================================================================
56306 */
// Cached reference to the global environment flag registry; the flag below
// is registered as a module-load side effect.
const ENV$1 = env();
/** Whether to keep intermediate tensors. */
ENV$1.registerFlag('KEEP_INTERMEDIATE_TENSORS', () => false, debugValue => {
    if (debugValue) {
        console.warn('Keep intermediate tensors is ON. This will print the values of all ' +
            'intermediate tensors during model inference. Not all models ' +
            'support this mode. For details, check e2e/benchmarks/ ' +
            'model_config.js. This significantly impacts performance.');
    }
});
56317
56318 /**
56319 * @license
56320 * Copyright 2019 Google LLC. All Rights Reserved.
56321 * Licensed under the Apache License, Version 2.0 (the "License");
56322 * you may not use this file except in compliance with the License.
56323 * You may obtain a copy of the License at
56324 *
56325 * http://www.apache.org/licenses/LICENSE-2.0
56326 *
56327 * Unless required by applicable law or agreed to in writing, software
56328 * distributed under the License is distributed on an "AS IS" BASIS,
56329 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56330 * See the License for the specific language governing permissions and
56331 * limitations under the License.
56332 *
56333 * =============================================================================
56334 */
/** DataType enum. */
// Runtime (TypeScript-compiled) enum mirroring TensorFlow's `DataType`
// proto. Provides both name -> number and number -> name lookups. The
// numeric values are a wire-format contract and must not be changed.
var DataType;
(function (DataType) {
    // These properties must be quoted since they are used by parseDtypeParam
    // in tfjs-converter/src/operations/operation_mapper.ts to look up dtypes
    // by string name. If they are not quoted, Closure will mangle their names.
    // Not a legal value for DataType. Used to indicate a DataType field
    // has not been set.
    DataType[DataType["DT_INVALID"] = 0] = "DT_INVALID";
    // Data types that all computation devices are expected to be
    // capable to support.
    DataType[DataType["DT_FLOAT"] = 1] = "DT_FLOAT";
    DataType[DataType["DT_DOUBLE"] = 2] = "DT_DOUBLE";
    DataType[DataType["DT_INT32"] = 3] = "DT_INT32";
    DataType[DataType["DT_UINT8"] = 4] = "DT_UINT8";
    DataType[DataType["DT_INT16"] = 5] = "DT_INT16";
    DataType[DataType["DT_INT8"] = 6] = "DT_INT8";
    DataType[DataType["DT_STRING"] = 7] = "DT_STRING";
    DataType[DataType["DT_COMPLEX64"] = 8] = "DT_COMPLEX64";
    DataType[DataType["DT_INT64"] = 9] = "DT_INT64";
    DataType[DataType["DT_BOOL"] = 10] = "DT_BOOL";
    DataType[DataType["DT_QINT8"] = 11] = "DT_QINT8";
    DataType[DataType["DT_QUINT8"] = 12] = "DT_QUINT8";
    DataType[DataType["DT_QINT32"] = 13] = "DT_QINT32";
    DataType[DataType["DT_BFLOAT16"] = 14] = "DT_BFLOAT16";
    DataType[DataType["DT_QINT16"] = 15] = "DT_QINT16";
    DataType[DataType["DT_QUINT16"] = 16] = "DT_QUINT16";
    DataType[DataType["DT_UINT16"] = 17] = "DT_UINT16";
    DataType[DataType["DT_COMPLEX128"] = 18] = "DT_COMPLEX128";
    DataType[DataType["DT_HALF"] = 19] = "DT_HALF";
    DataType[DataType["DT_RESOURCE"] = 20] = "DT_RESOURCE";
    DataType[DataType["DT_VARIANT"] = 21] = "DT_VARIANT";
    DataType[DataType["DT_UINT32"] = 22] = "DT_UINT32";
    DataType[DataType["DT_UINT64"] = 23] = "DT_UINT64";
    // Do not use! These are only for parameters. Every enum above
    // should have a corresponding value below (verified by types_test).
    DataType[DataType["DT_FLOAT_REF"] = 101] = "DT_FLOAT_REF";
    DataType[DataType["DT_DOUBLE_REF"] = 102] = "DT_DOUBLE_REF";
    DataType[DataType["DT_INT32_REF"] = 103] = "DT_INT32_REF";
    DataType[DataType["DT_UINT8_REF"] = 104] = "DT_UINT8_REF";
    DataType[DataType["DT_INT16_REF"] = 105] = "DT_INT16_REF";
    DataType[DataType["DT_INT8_REF"] = 106] = "DT_INT8_REF";
    DataType[DataType["DT_STRING_REF"] = 107] = "DT_STRING_REF";
    DataType[DataType["DT_COMPLEX64_REF"] = 108] = "DT_COMPLEX64_REF";
    DataType[DataType["DT_INT64_REF"] = 109] = "DT_INT64_REF";
    DataType[DataType["DT_BOOL_REF"] = 110] = "DT_BOOL_REF";
    DataType[DataType["DT_QINT8_REF"] = 111] = "DT_QINT8_REF";
    DataType[DataType["DT_QUINT8_REF"] = 112] = "DT_QUINT8_REF";
    DataType[DataType["DT_QINT32_REF"] = 113] = "DT_QINT32_REF";
    DataType[DataType["DT_BFLOAT16_REF"] = 114] = "DT_BFLOAT16_REF";
    DataType[DataType["DT_QINT16_REF"] = 115] = "DT_QINT16_REF";
    DataType[DataType["DT_QUINT16_REF"] = 116] = "DT_QUINT16_REF";
    DataType[DataType["DT_UINT16_REF"] = 117] = "DT_UINT16_REF";
    DataType[DataType["DT_COMPLEX128_REF"] = 118] = "DT_COMPLEX128_REF";
    DataType[DataType["DT_HALF_REF"] = 119] = "DT_HALF_REF";
    DataType[DataType["DT_RESOURCE_REF"] = 120] = "DT_RESOURCE_REF";
    DataType[DataType["DT_VARIANT_REF"] = 121] = "DT_VARIANT_REF";
    DataType[DataType["DT_UINT32_REF"] = 122] = "DT_UINT32_REF";
    DataType[DataType["DT_UINT64_REF"] = 123] = "DT_UINT64_REF";
})(DataType || (DataType = {}));
// Runtime namespace mirroring TensorFlow's `SaverDef` proto; currently only
// carries the checkpoint-format-version enum.
var SaverDef;
(function (SaverDef) {
    /** CheckpointFormatVersion enum. */
    let CheckpointFormatVersion;
    (function (CheckpointFormatVersion) {
        CheckpointFormatVersion[CheckpointFormatVersion["LEGACY"] = 0] = "LEGACY";
        CheckpointFormatVersion[CheckpointFormatVersion["V1"] = 1] = "V1";
        CheckpointFormatVersion[CheckpointFormatVersion["V2"] = 2] = "V2";
    })(CheckpointFormatVersion = SaverDef.CheckpointFormatVersion || (SaverDef.CheckpointFormatVersion = {}));
})(SaverDef || (SaverDef = {}));
56405
56406 /**
56407 * @license
56408 * Copyright 2019 Google LLC. All Rights Reserved.
56409 * Licensed under the Apache License, Version 2.0 (the "License");
56410 * you may not use this file except in compliance with the License.
56411 * You may obtain a copy of the License at
56412 *
56413 * http://www.apache.org/licenses/LICENSE-2.0
56414 *
56415 * Unless required by applicable law or agreed to in writing, software
56416 * distributed under the License is distributed on an "AS IS" BASIS,
56417 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56418 * See the License for the specific language governing permissions and
56419 * limitations under the License.
56420 * =============================================================================
56421 */
// Registry of user-registered (or overridden) graph-executor ops, keyed by
// TensorFlow op name.
const CUSTOM_OPS = {};
/**
 * Register an Op for graph model executor. This allows you to register
 * TensorFlow custom op or override existing op.
 *
 * Here is an example of registering a new MatMul Op.
 * ```js
 * const customMatmul = (node) =>
 *    tf.matMul(
 *        node.inputs[0], node.inputs[1],
 *        node.attrs['transpose_a'], node.attrs['transpose_b']);
 *
 * tf.registerOp('MatMul', customMatmul);
 * ```
 * The inputs and attrs of the node object are based on the TensorFlow op
 * registry.
 *
 * @param name The Tensorflow Op name.
 * @param opFunc An op function which is called with the current graph node
 * during execution and needs to return a tensor or a list of tensors. The node
 * has the following attributes:
 *    - attr: A map from attribute name to its value
 *    - inputs: A list of input tensors
 *
 * @doc {heading: 'Models', subheading: 'Op Registry'}
 */
function registerOp(name, opFunc) {
    // A custom op has no declared inputs/attrs; the executor hands the raw
    // node to `customExecutor` instead.
    CUSTOM_OPS[name] = {
        tfOpName: name,
        category: 'custom',
        inputs: [],
        attrs: [],
        customExecutor: opFunc
    };
}
/**
 * Retrieve the OpMapper object for the registered op.
 *
 * @param name The Tensorflow Op name.
 * @returns The mapper registered via `registerOp`, or `undefined` if no op
 *     was registered under `name`.
 *
 * @doc {heading: 'Models', subheading: 'Op Registry'}
 */
function getRegisteredOp(name) {
    return CUSTOM_OPS[name];
}
/**
 * Deregister the Op for graph model executor. Removes the mapper registered
 * under `name`; a no-op if nothing was registered.
 *
 * @param name The Tensorflow Op name.
 *
 * @doc {heading: 'Models', subheading: 'Op Registry'}
 */
function deregisterOp(name) {
    delete CUSTOM_OPS[name];
}
56478
56479 /**
56480 * @license
56481 * Copyright 2018 Google LLC. All Rights Reserved.
56482 * Licensed under the Apache License, Version 2.0 (the "License");
56483 * you may not use this file except in compliance with the License.
56484 * You may obtain a copy of the License at
56485 *
56486 * http://www.apache.org/licenses/LICENSE-2.0
56487 *
56488 * Unless required by applicable law or agreed to in writing, software
56489 * distributed under the License is distributed on an "AS IS" BASIS,
56490 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56491 * See the License for the specific language governing permissions and
56492 * limitations under the License.
56493 * =============================================================================
56494 */
/**
 * Resolves the value of parameter `paramName` for `node`: either from one of
 * the node's input tensors (declared in `node.inputParams`) or, failing that,
 * from a static attribute (`node.attrParams`).
 * @param paramName Name of the parameter to resolve.
 * @param node The graph node being executed.
 * @param tensorMap Tensors keyed by node name.
 * @param context Execution context for the current node.
 * @param resourceManager Optional. Global resources of the model.
 */
function getParamValue(paramName, node, tensorMap, context, resourceManager) {
    const inputParam = node.inputParams[paramName];
    if (inputParam && inputParam.inputIndexStart !== undefined) {
        const start = inputParam.inputIndexStart;
        // inputIndexEnd === 0 means "through the end of the input list";
        // undefined means a single input at `start`.
        const end = inputParam.inputIndexEnd === 0 ?
            undefined :
            (inputParam.inputIndexEnd === undefined ? start + 1 :
                inputParam.inputIndexEnd);
        // A negative start indexes from the end of inputNames.
        const shiftedStart = start < 0 ? node.inputNames.length + start : start;
        if (inputParam.type === 'tensor') {
            return getTensor(node.inputNames[shiftedStart], tensorMap, context, resourceManager);
        }
        if (inputParam.type === 'tensors') {
            // TODO(mattSoulanille): This filters out NoOp nodes during execution, but
            // these should really never be in the execution graph in the first place.
            // They're necessary for ordering the graph, but should not be visible
            // during execution. Perhaps have different sets of children, one for
            // control dependencies and another for real dependencies.
            const inputs = node.inputs.slice(start, end);
            const inputNames = node.inputNames.slice(start, end)
                .filter((_name, index) => { var _a; return ((_a = inputs[index]) === null || _a === void 0 ? void 0 : _a.op) !== 'NoOp'; });
            return inputNames.map(name => getTensor(name, tensorMap, context, resourceManager));
        }
        // Scalar / nested-array params are materialized by reading the
        // tensor's data synchronously.
        const tensor = getTensor(node.inputNames[shiftedStart], tensorMap, context, resourceManager);
        const data = tensor.dataSync();
        return inputParam.type === 'number' ?
            data[0] :
            toNestedArray(tensor.shape, data);
    }
    const attrParam = node.attrParams[paramName];
    return attrParam && attrParam.value;
}
/**
 * Retrieve the tensor from tensorsMap based on input name.
 * @param name Node input name
 * @param tensorsMap Tensors map keyed by the node
 * @param context contains tensors and information for running the current node.
 * @param resourceManager Optional. Contains global resources of the model.
 * @returns The resolved tensor, or `undefined` if no context id yields an
 *     entry in `tensorsMap` for the node.
 */
function getTensor(name, tensorsMap, context, resourceManager) {
    const [nodeName, index] = parseNodeName(name, context);
    // Hash-table handles registered with the resource manager take
    // precedence over regular tensors.
    if (resourceManager != null) {
        const tensor = resourceManager.getHashTableHandleByName(nodeName);
        if (tensor != null) {
            return tensor;
        }
    }
    // Find the first of the current context ids that has an entry for this
    // node (inner frames are searched before outer ones).
    const contextId = context.currentContextIds.find(contextId => {
        return !!tensorsMap[getNodeNameWithContextId(nodeName, contextId)];
    });
    return contextId !== undefined ?
        tensorsMap[getNodeNameWithContextId(nodeName, contextId)][index] :
        undefined;
}
/**
 * Retrieve the tensors based on input name for current context.
 * @param name Node input name
 * @param tensorsMap Tensors map keyed by the node
 * @param context Execution context providing `currentContextId`.
 * @returns All output tensors recorded for `name` in the current context,
 *     or `undefined` if none were recorded.
 */
function getTensorsForCurrentContext(name, tensorsMap, context) {
    return tensorsMap[getNodeNameWithContextId(name, context.currentContextId)];
}
/**
 * Returns the node name, outputName and index from the Node input name.
 * @param inputName The input name of the node, in format of
 * node_name:output_index, i.e. MatMul:0, if the output_index is not set, it is
 * default to 0.
 * If the input name contains output name i.e. StringSplit:indices:0, it will
 * return ['StringSplit', 0, 'indices'].
 * The node-name portion is additionally suffixed with the current context id
 * (see `getNodeNameWithContextId`) when one is active.
 */
function getNodeNameAndIndex(inputName, context) {
    const [nodeName, index, outputName] = parseNodeName(inputName, context);
    return [
        getNodeNameWithContextId(nodeName, context && context.currentContextId),
        index, outputName
    ];
}
// Suffixes the node name with `-<contextId>` when a truthy context id is
// given; an empty/undefined context id leaves the name unchanged.
function getNodeNameWithContextId(name, contextId) {
    if (contextId) {
        return `${name}-${contextId}`;
    }
    return name;
}
/**
 * Parses a node input name of the form `node`, `node:index`, or
 * `node:outputName:index` into `[nodeName, outputIndex, outputName]`.
 * Results are memoized in `context.parseNodeNameCache` when one is supplied.
 */
function parseNodeName(name, context) {
    if (name === '') {
        return ['', 0, undefined];
    }
    // Memoization is active only when the caller supplied a cache.
    const cache = context == null ? undefined : context.parseNodeNameCache;
    if (cache != null) {
        const hit = cache.get(name);
        if (hit != null) {
            return hit;
        }
    }
    const parts = name.split(':');
    let parsed;
    if (parts.length === 1) {
        parsed = [name, 0, undefined];
    }
    else {
        const index = Number(parts[parts.length - 1]);
        // Only the 3-part form carries an explicit output name.
        parsed = [parts[0], index, parts.length === 3 ? parts[1] : undefined];
    }
    if (cache != null) {
        cache.set(name, parsed);
    }
    return parsed;
}
/**
 * Partitions `arr` into consecutive chunks of length `size`; the final
 * chunk may be shorter when `arr.length` is not a multiple of `size`.
 */
function split$2(arr, size) {
    const chunks = [];
    let offset = 0;
    while (offset < arr.length) {
        chunks.push(arr.slice(offset, offset + size));
        offset += size;
    }
    return chunks;
}
/**
 * Resolves the `pad` parameter of a conv-style node. Returns the raw `pad`
 * value unless it is the string 'explicit', in which case the flat
 * `explicitPaddings` list of 8 numbers is reshaped into four
 * [before, after] pairs (presumably one pair per NHWC dimension —
 * TODO confirm against the consuming kernels).
 */
function getPadding(node, tensorMap, context) {
    let pad = getParamValue('pad', node, tensorMap, context);
    if (pad === 'explicit') {
        // `pad` is a flat 1d array here; convert it to the 4x2 nested form.
        pad = getParamValue('explicitPaddings', node, tensorMap, context);
        const explicitPadding = [[0, 0], [0, 0], [0, 0], [0, 0]];
        for (let i = 0; i < 4; i++) {
            explicitPadding[i][0] = pad[i * 2];
            explicitPadding[i][1] = pad[i * 2 + 1];
        }
        return explicitPadding;
    }
    return pad;
}
/**
 * Reuse the tensor if it is marked as kept, otherwise clone the tensor to
 * avoid disposal. This is important for TensorArray and TensorList ops, since
 * internally they use a tensor as the id for TensorArray and TensorList, and
 * to simplify lookup, they also use Tensor.id as the key to the internal map.
 * These id tensors have been marked as kept in the backend; we need to avoid
 * cloning them in order not to create a new Tensor.id.
 * @param tensor The tensor to reuse or clone.
 */
function cloneTensor(tensor) {
    return tensor.kept ? tensor : clone(tensor);
}
56635
56636 /**
56637 * @license
56638 * Copyright 2023 Google LLC. All Rights Reserved.
56639 * Licensed under the Apache License, Version 2.0 (the "License");
56640 * you may not use this file except in compliance with the License.
56641 * You may obtain a copy of the License at
56642 *
56643 * http://www.apache.org/licenses/LICENSE-2.0
56644 *
56645 * Unless required by applicable law or agreed to in writing, software
56646 * distributed under the License is distributed on an "AS IS" BASIS,
56647 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
56648 * See the License for the specific language governing permissions and
56649 * limitations under the License.
56650 * =============================================================================
56651 */
/**
 * Op mappers for the converter's 'arithmetic' category.
 *
 * All but two of these ops share the exact same mapper shape (two tensor
 * inputs `a`/`b` plus an unsupported dtype attr `T`), so the shared shape is
 * built by a local helper instead of being repeated 14 times. The resulting
 * value is deep-equal to the previous hand-written array, in the same order,
 * and every entry owns fresh (non-shared) input/attr objects.
 */
const json$i = (() => {
    /** Fresh input spec: tensor operands `a` (index 0) and `b` (index 1). */
    const binaryInputs = () => ([
        {
            'start': 0,
            'name': 'a',
            'type': 'tensor'
        },
        {
            'start': 1,
            'name': 'b',
            'type': 'tensor'
        }
    ]);
    /** Fresh attr spec for the (unsupported) dtype attribute `T`. */
    const dtypeAttr = () => ({
        'tfName': 'T',
        'name': 'dtype',
        'type': 'dtype',
        'notSupported': true
    });
    /** Mapper for a plain binary arithmetic op. */
    const binaryOp = (tfOpName) => ({
        'tfOpName': tfOpName,
        'category': 'arithmetic',
        'inputs': binaryInputs(),
        'attrs': [dtypeAttr()]
    });
    return [
        binaryOp('Add'),
        binaryOp('AddV2'),
        {
            // AddN consumes its entire input list ('end': 0 = open-ended)
            // and declares no attrs.
            'tfOpName': 'AddN',
            'category': 'arithmetic',
            'inputs': [
                {
                    'start': 0,
                    'end': 0,
                    'name': 'tensors',
                    'type': 'tensors'
                }
            ]
        },
        {
            // BiasAdd additionally carries an (unsupported) data_format attr.
            'tfOpName': 'BiasAdd',
            'category': 'arithmetic',
            'inputs': binaryInputs(),
            'attrs': [
                dtypeAttr(),
                {
                    'tfName': 'data_format',
                    'name': 'dataFormat',
                    'type': 'string',
                    'notSupported': true
                }
            ]
        },
        binaryOp('Sub'),
        binaryOp('RealDiv'),
        binaryOp('Div'),
        binaryOp('DivNoNan'),
        binaryOp('FloorDiv'),
        binaryOp('Mul'),
        binaryOp('Maximum'),
        binaryOp('Minimum'),
        binaryOp('Pow'),
        binaryOp('SquaredDifference'),
        binaryOp('Mod'),
        binaryOp('FloorMod')
    ];
})();
57032
// Module-shaped frozen namespace exposing the arithmetic op mappers as
// `arithmetic.json`, consumed by the converter's operation-mapper registry.
var arithmetic = /*#__PURE__*/Object.freeze({
    __proto__: null,
    json: json$i
});
57037
57038 /**
57039 * @license
57040 * Copyright 2023 Google LLC. All Rights Reserved.
57041 * Licensed under the Apache License, Version 2.0 (the "License");
57042 * you may not use this file except in compliance with the License.
57043 * You may obtain a copy of the License at
57044 *
57045 * http://www.apache.org/licenses/LICENSE-2.0
57046 *
57047 * Unless required by applicable law or agreed to in writing, software
57048 * distributed under the License is distributed on an "AS IS" BASIS,
57049 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
57050 * See the License for the specific language governing permissions and
57051 * limitations under the License.
57052 * =============================================================================
57053 */
/**
 * TensorFlow -> TFJS op-mapper definitions for the `basic_math` category.
 *
 * Each entry describes how a TensorFlow GraphDef node is translated:
 *  - `inputs`: which node input slots feed which named parameters;
 *  - `attrs`:  which TF node attributes map to which named parameters
 *              (`notSupported: true` marks attributes the converter ignores).
 *
 * Built with small factories purely for brevity; the resulting array is
 * structurally identical (same entries, same order, fresh objects per entry)
 * to a fully spelled-out literal.
 */
const json$h = (() => {
    // A tensor input read from node input slot `start`.
    const tensorIn = (start, name) => ({ 'start': start, 'name': name, 'type': 'tensor' });
    // The ubiquitous TF `T` dtype attribute, present on nodes but ignored here.
    const ignoredDtype = () => ({
        'tfName': 'T',
        'name': 'dtype',
        'type': 'dtype',
        'notSupported': true
    });
    // Element-wise op with a single tensor input `x` and only the `T` attribute.
    const unary = (tfOpName) => ({
        'tfOpName': tfOpName,
        'category': 'basic_math',
        'inputs': [tensorIn(0, 'x')],
        'attrs': [ignoredDtype()]
    });
    return [
        unary('Abs'),
        unary('Acos'),
        unary('Asin'),
        unary('Atan'),
        {
            'tfOpName': 'Atan2',
            'category': 'basic_math',
            'inputs': [tensorIn(0, 'x'), tensorIn(1, 'y')],
            'attrs': [ignoredDtype()]
        },
        unary('Ceil'),
        {
            'tfOpName': 'ClipByValue',
            'category': 'basic_math',
            'inputs': [
                tensorIn(0, 'x'),
                { 'start': 1, 'name': 'clipValueMin', 'type': 'number' },
                { 'start': 2, 'name': 'clipValueMax', 'type': 'number' }
            ],
            'attrs': [ignoredDtype()]
        },
        {
            'tfOpName': 'Complex',
            'category': 'basic_math',
            'inputs': [tensorIn(0, 'real'), tensorIn(1, 'imag')],
            'attrs': [ignoredDtype()]
        },
        unary('ComplexAbs'),
        unary('Cos'),
        unary('Cosh'),
        unary('Elu'),
        unary('Exp'),
        unary('Floor'),
        unary('Log'),
        {
            'tfOpName': 'Imag',
            'category': 'basic_math',
            'inputs': [tensorIn(0, 'x')],
            'attrs': [
                ignoredDtype(),
                {
                    'tfName': 'Tout',
                    'name': 'outputType',
                    'type': 'dtype',
                    'notSupported': true
                }
            ]
        },
        unary('Neg'),
        {
            'tfOpName': 'Real',
            'category': 'basic_math',
            'inputs': [tensorIn(0, 'x')],
            'attrs': [
                ignoredDtype(),
                {
                    'tfName': 'Tout',
                    'name': 'outputType',
                    'type': 'dtype',
                    'notSupported': true
                }
            ]
        },
        {
            'tfOpName': 'Prelu',
            'category': 'basic_math',
            'inputs': [tensorIn(0, 'x'), tensorIn(1, 'alpha')],
            'attrs': [ignoredDtype()]
        },
        unary('Relu'),
        unary('Relu6'),
        unary('Selu'),
        unary('Sigmoid'),
        unary('Sin'),
        unary('Sinh'),
        unary('Sqrt'),
        unary('Rsqrt'),
        unary('Square'),
        unary('Tan'),
        unary('Tanh'),
        unary('Sign'),
        unary('Round'),
        unary('Expm1'),
        unary('Log1p'),
        unary('Reciprocal'),
        unary('Softplus'),
        unary('Asinh'),
        unary('Acosh'),
        unary('Atanh'),
        unary('Erf'),
        {
            'tfOpName': 'LeakyRelu',
            'category': 'basic_math',
            'inputs': [tensorIn(0, 'x')],
            'attrs': [
                {
                    'tfName': 'alpha',
                    'name': 'alpha',
                    'type': 'number',
                    'defaultValue': 0.2
                },
                ignoredDtype()
            ]
        },
        unary('IsNan'),
        unary('IsFinite'),
        unary('IsInf')
    ];
})();
57935
// Frozen, null-prototype namespace object exposing the basic_math op-mapper
// list as its `json` member (emulates an ES-module namespace record).
var basicMath = /*#__PURE__*/ Object.freeze({ __proto__: null, json: json$h });
57940
57941 /**
57942 * @license
57943 * Copyright 2023 Google LLC. All Rights Reserved.
57944 * Licensed under the Apache License, Version 2.0 (the "License");
57945 * you may not use this file except in compliance with the License.
57946 * You may obtain a copy of the License at
57947 *
57948 * http://www.apache.org/licenses/LICENSE-2.0
57949 *
57950 * Unless required by applicable law or agreed to in writing, software
57951 * distributed under the License is distributed on an "AS IS" BASIS,
57952 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
57953 * See the License for the specific language governing permissions and
57954 * limitations under the License.
57955 * =============================================================================
57956 */
/**
 * TensorFlow -> TFJS op-mapper definitions for the `control` category
 * (control-flow, TensorArray and TensorList ops).
 *
 * Each entry describes how a TensorFlow GraphDef node is translated:
 *  - `inputs`: which node input slots feed which named parameters;
 *  - `attrs`:  which TF node attributes map to which named parameters
 *              (`notSupported: true` marks attributes the converter ignores).
 * Entries with no mappable attributes omit the `attrs` key entirely.
 *
 * Built with small factories purely for brevity; the resulting array is
 * structurally identical (same entries, same order, fresh objects per entry)
 * to a fully spelled-out literal.
 */
const json$g = (() => {
    // A single input read from node input slot `start`.
    const inp = (start, name, type) => ({ 'start': start, 'name': name, 'type': type });
    // A variadic input consuming slots from `start` onward (`end: 0` = through the last).
    const variadic = (start, name) =>
        ({ 'start': start, 'end': 0, 'name': name, 'type': 'tensors' });
    const attr = (tfName, name, type) => ({ 'tfName': tfName, 'name': name, 'type': type });
    // The common `element_dtype` attribute shared by the TensorList ops.
    const elementDType = () => attr('element_dtype', 'elementDType', 'dtype');
    // The ubiquitous TF `T` dtype attribute, present on nodes but ignored here.
    const ignoredT = () =>
        ({ 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true });
    return [
        {
            'tfOpName': 'EmptyTensorList',
            'category': 'control',
            'inputs': [inp(0, 'elementShape', 'shape'), inp(1, 'maxNumElements', 'number')],
            'attrs': [elementDType()]
        },
        {
            'tfOpName': 'LoopCond',
            'category': 'control',
            'inputs': [inp(0, 'pred', 'tensor')]
        },
        {
            'tfOpName': 'Switch',
            'category': 'control',
            'inputs': [inp(0, 'data', 'tensor'), inp(1, 'pred', 'tensor')]
        },
        {
            'tfOpName': 'Merge',
            'category': 'control',
            'inputs': [variadic(0, 'tensors')]
        },
        {
            'tfOpName': 'Enter',
            'category': 'control',
            'inputs': [inp(0, 'tensor', 'tensor')],
            'attrs': [
                ignoredT(),
                attr('frame_name', 'frameName', 'string'),
                attr('is_constant', 'isConstant', 'bool')
            ]
        },
        {
            'tfOpName': 'Exit',
            'category': 'control',
            'inputs': [inp(0, 'tensor', 'tensor')],
            'attrs': [ignoredT()]
        },
        {
            'tfOpName': 'NextIteration',
            'category': 'control',
            'inputs': [inp(0, 'tensor', 'tensor')],
            'attrs': [ignoredT()]
        },
        {
            'tfOpName': 'TensorArrayV3',
            'category': 'control',
            'inputs': [inp(0, 'size', 'number')],
            'attrs': [
                attr('dtype', 'dtype', 'dtype'),
                attr('element_shape', 'elementShape', 'shape'),
                attr('dynamic_size', 'dynamicSize', 'bool'),
                attr('clear_after_read', 'clearAfterRead', 'bool'),
                attr('identical_element_shapes', 'identicalElementShapes', 'bool'),
                attr('tensor_array_name', 'name', 'string')
            ]
        },
        {
            'tfOpName': 'TensorArrayWriteV3',
            'category': 'control',
            'inputs': [
                inp(0, 'tensorArrayId', 'tensor'),
                inp(1, 'index', 'number'),
                inp(2, 'tensor', 'tensor'),
                inp(3, 'flowIn', 'number')
            ],
            'attrs': [ignoredT()]
        },
        {
            'tfOpName': 'TensorArrayReadV3',
            'category': 'control',
            'inputs': [
                inp(0, 'tensorArrayId', 'tensor'),
                inp(1, 'index', 'number'),
                inp(2, 'flowIn', 'number')
            ],
            'attrs': [
                { 'tfName': 'dtype', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
            ]
        },
        {
            'tfOpName': 'TensorArrayGatherV3',
            'category': 'control',
            'inputs': [
                inp(0, 'tensorArrayId', 'tensor'),
                inp(1, 'indices', 'number[]'),
                inp(2, 'flowIn', 'number')
            ],
            'attrs': [
                attr('dtype', 'dtype', 'dtype'),
                attr('element_shape', 'elementShape', 'shape')
            ]
        },
        {
            'tfOpName': 'TensorArrayScatterV3',
            'category': 'control',
            'inputs': [
                inp(0, 'tensorArrayId', 'tensor'),
                inp(1, 'indices', 'number[]'),
                inp(2, 'tensor', 'tensor'),
                inp(3, 'flowIn', 'number')
            ],
            'attrs': [attr('T', 'dtype', 'dtype')]
        },
        {
            'tfOpName': 'TensorArrayConcatV3',
            'category': 'control',
            'inputs': [inp(0, 'tensorArrayId', 'tensor'), inp(1, 'flowIn', 'number')],
            'attrs': [
                attr('dtype', 'dtype', 'dtype'),
                {
                    'tfName': 'element_shape_except0',
                    'name': 'elementShapeExcept0',
                    'type': 'shape',
                    'notSupported': true
                }
            ]
        },
        {
            'tfOpName': 'TensorArraySplitV3',
            'category': 'control',
            'inputs': [
                inp(0, 'tensorArrayId', 'tensor'),
                inp(1, 'tensor', 'tensor'),
                inp(2, 'lengths', 'number[]'),
                inp(3, 'flowIn', 'number')
            ],
            'attrs': [attr('T', 'dtype', 'dtype')]
        },
        {
            'tfOpName': 'TensorArraySizeV3',
            'category': 'control',
            'inputs': [inp(0, 'tensorArrayId', 'tensor'), inp(1, 'flowIn', 'number')]
        },
        {
            'tfOpName': 'TensorArrayCloseV3',
            'category': 'control',
            'inputs': [inp(0, 'tensorArrayId', 'tensor')]
        },
        {
            'tfOpName': 'StatelessIf',
            'category': 'control',
            'inputs': [inp(0, 'cond', 'tensor'), variadic(1, 'args')],
            'attrs': [
                attr('then_branch', 'thenBranch', 'func'),
                attr('else_branch', 'elseBranch', 'func')
            ]
        },
        {
            'tfOpName': 'If',
            'category': 'control',
            'inputs': [inp(0, 'cond', 'tensor'), variadic(1, 'args')],
            'attrs': [
                attr('then_branch', 'thenBranch', 'func'),
                attr('else_branch', 'elseBranch', 'func')
            ]
        },
        {
            'tfOpName': 'StatelessWhile',
            'category': 'control',
            'inputs': [variadic(0, 'args')],
            'attrs': [attr('cond', 'cond', 'func'), attr('body', 'body', 'func')]
        },
        {
            'tfOpName': 'While',
            'category': 'control',
            'inputs': [variadic(0, 'args')],
            'attrs': [attr('cond', 'cond', 'func'), attr('body', 'body', 'func')]
        },
        {
            'tfOpName': 'TensorListScatter',
            'category': 'control',
            'inputs': [
                inp(0, 'tensor', 'tensor'),
                inp(1, 'indices', 'number[]'),
                inp(2, 'elementShape', 'shape')
            ],
            'attrs': [elementDType()]
        },
        {
            'tfOpName': 'TensorListScatterV2',
            'category': 'control',
            'inputs': [
                inp(0, 'tensor', 'tensor'),
                inp(1, 'indices', 'number[]'),
                inp(2, 'elementShape', 'shape'),
                inp(3, 'numElements', 'number')
            ],
            'attrs': [elementDType()]
        },
        {
            'tfOpName': 'TensorListGather',
            'category': 'control',
            'inputs': [
                inp(0, 'tensorListId', 'tensor'),
                inp(1, 'indices', 'number[]'),
                inp(2, 'elementShape', 'shape')
            ],
            'attrs': [elementDType()]
        },
        {
            'tfOpName': 'TensorListGetItem',
            'category': 'control',
            'inputs': [
                inp(0, 'tensorListId', 'tensor'),
                inp(1, 'index', 'number'),
                inp(2, 'elementShape', 'shape')
            ],
            'attrs': [elementDType()]
        },
        {
            'tfOpName': 'TensorListSetItem',
            'category': 'control',
            'inputs': [
                inp(0, 'tensorListId', 'tensor'),
                inp(1, 'index', 'number'),
                inp(2, 'tensor', 'tensor')
            ],
            'attrs': [elementDType()]
        },
        {
            'tfOpName': 'TensorListReserve',
            'category': 'control',
            'inputs': [inp(0, 'elementShape', 'shape'), inp(1, 'numElements', 'number')],
            'attrs': [elementDType()]
        },
        {
            'tfOpName': 'TensorListFromTensor',
            'category': 'control',
            'inputs': [inp(0, 'tensor', 'tensor'), inp(1, 'elementShape', 'shape')],
            'attrs': [elementDType()]
        },
        {
            'tfOpName': 'TensorListStack',
            'category': 'control',
            'inputs': [inp(0, 'tensorListId', 'tensor'), inp(1, 'elementShape', 'shape')],
            'attrs': [
                elementDType(),
                // NOTE(review): 'dtype' is reproduced verbatim from the original
                // data, but TF's `num_elements` is an integer attr — this looks
                // like it should be type 'number'. Confirm against the attr
                // parser / TensorListStack executor before changing.
                attr('num_elements', 'numElements', 'dtype')
            ]
        },
        {
            'tfOpName': 'TensorListSplit',
            'category': 'control',
            'inputs': [
                inp(0, 'tensor', 'tensor'),
                inp(1, 'elementShape', 'shape'),
                inp(2, 'lengths', 'number[]')
            ],
            'attrs': [elementDType()]
        },
        {
            'tfOpName': 'TensorListConcat',
            'category': 'control',
            'inputs': [inp(0, 'tensorListId', 'tensor')],
            'attrs': [attr('element_shape', 'elementShape', 'shape'), elementDType()]
        },
        {
            'tfOpName': 'TensorListConcatV2',
            'category': 'control',
            'inputs': [inp(0, 'tensorListId', 'tensor')],
            'attrs': [attr('element_shape', 'elementShape', 'shape'), elementDType()]
        },
        {
            'tfOpName': 'TensorListPopBack',
            'category': 'control',
            'inputs': [inp(0, 'tensorListId', 'tensor'), inp(1, 'elementShape', 'shape')],
            'attrs': [elementDType()]
        },
        {
            'tfOpName': 'TensorListPushBack',
            'category': 'control',
            'inputs': [inp(0, 'tensorListId', 'tensor'), inp(1, 'tensor', 'tensor')],
            'attrs': [elementDType()]
        },
        {
            'tfOpName': 'TensorListLength',
            'category': 'control',
            'inputs': [inp(0, 'tensorListId', 'tensor')]
        },
        {
            'tfOpName': 'TensorListResize',
            'category': 'control',
            'inputs': [inp(0, 'tensorListId', 'tensor'), inp(1, 'size', 'number')]
        }
    ];
})();
58821
// Frozen, null-prototype namespace object exposing the control-flow op-mapper
// list as its `json` member (emulates an ES-module namespace record).
var control = /*#__PURE__*/ Object.freeze({ __proto__: null, json: json$g });
58826
58827 /**
58828 * @license
58829 * Copyright 2023 Google LLC. All Rights Reserved.
58830 * Licensed under the Apache License, Version 2.0 (the "License");
58831 * you may not use this file except in compliance with the License.
58832 * You may obtain a copy of the License at
58833 *
58834 * http://www.apache.org/licenses/LICENSE-2.0
58835 *
58836 * Unless required by applicable law or agreed to in writing, software
58837 * distributed under the License is distributed on an "AS IS" BASIS,
58838 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
58839 * See the License for the specific language governing permissions and
58840 * limitations under the License.
58841 * =============================================================================
58842 */
/**
 * Op-mapper table for the 'convolution' category: each entry names a TF op
 * (`tfOpName`), lists the positional node inputs to extract (`start`, and
 * `end: 0` meaning "through the last input" for variadic tensor lists), and
 * maps TF attributes (`tfName`) to target parameter names (`name`) with a
 * declared `type`, optional `defaultValue`, and `notSupported` markers.
 */
const json$f = [
  {
    tfOpName: 'AvgPool',
    category: 'convolution',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
    attrs: [
      { tfName: 'strides', name: 'strides', type: 'number[]' },
      { tfName: 'padding', name: 'pad', type: 'string' },
      { tfName: 'data_format', name: 'dataFormat', type: 'string', notSupported: true },
      { tfName: 'ksize', name: 'kernelSize', type: 'number[]' },
      { tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true },
    ],
  },
  {
    tfOpName: 'MaxPool',
    category: 'convolution',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
    attrs: [
      { tfName: 'strides', name: 'strides', type: 'number[]' },
      { tfName: 'padding', name: 'pad', type: 'string' },
      { tfName: 'data_format', name: 'dataFormat', type: 'string', notSupported: true },
      { tfName: 'ksize', name: 'kernelSize', type: 'number[]' },
      { tfName: 'explicit_paddings', name: 'explicitPaddings', type: 'number[]', defaultValue: [], notSupported: true },
      { tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true },
    ],
  },
  {
    tfOpName: 'MaxPoolWithArgmax',
    category: 'convolution',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
    attrs: [
      { tfName: 'strides', name: 'strides', type: 'number[]' },
      { tfName: 'padding', name: 'pad', type: 'string' },
      { tfName: 'ksize', name: 'kernelSize', type: 'number[]' },
      { tfName: 'include_batch_in_index', name: 'includeBatchInIndex', type: 'bool' },
      { tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true },
    ],
  },
  {
    tfOpName: 'AvgPool3D',
    category: 'convolution',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
    attrs: [
      { tfName: 'strides', name: 'strides', type: 'number[]' },
      { tfName: 'padding', name: 'pad', type: 'string' },
      { tfName: 'data_format', name: 'dataFormat', type: 'string', notSupported: true },
      { tfName: 'ksize', name: 'kernelSize', type: 'number[]' },
      { tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true },
    ],
  },
  {
    tfOpName: 'MaxPool3D',
    category: 'convolution',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
    attrs: [
      { tfName: 'strides', name: 'strides', type: 'number[]' },
      { tfName: 'padding', name: 'pad', type: 'string' },
      { tfName: 'data_format', name: 'dataFormat', type: 'string', notSupported: true },
      { tfName: 'ksize', name: 'kernelSize', type: 'number[]' },
      { tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true },
    ],
  },
  {
    tfOpName: 'Conv1D',
    category: 'convolution',
    inputs: [
      { start: 0, name: 'x', type: 'tensor' },
      { start: 1, name: 'filter', type: 'tensor' },
    ],
    attrs: [
      { tfName: 'stride', name: 'stride', type: 'number' },
      { tfName: 'padding', name: 'pad', type: 'string' },
      { tfName: 'data_format', name: 'dataFormat', type: 'string', defaultValue: 'NWC' },
      { tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true },
      { tfName: 'dilation', name: 'dilation', type: 'number', defaultValue: 1 },
    ],
  },
  {
    tfOpName: 'Conv2D',
    category: 'convolution',
    inputs: [
      { start: 0, name: 'x', type: 'tensor' },
      { start: 1, name: 'filter', type: 'tensor' },
    ],
    attrs: [
      { tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true },
      { tfName: 'strides', name: 'strides', type: 'number[]' },
      { tfName: 'padding', name: 'pad', type: 'string' },
      { tfName: 'useCudnnOnGpu', name: 'useCudnnOnGpu', type: 'bool' },
      { tfName: 'data_format', name: 'dataFormat', type: 'string', defaultValue: 'NHWC' },
      { tfName: 'explicit_paddings', name: 'explicitPaddings', type: 'number[]', defaultValue: [] },
      { tfName: 'dilations', name: 'dilations', type: 'number[]' },
    ],
  },
  {
    tfOpName: '_FusedConv2D',
    category: 'convolution',
    inputs: [
      { start: 0, name: 'x', type: 'tensor' },
      { start: 1, name: 'filter', type: 'tensor' },
      // Variadic tail: all remaining node inputs become the fused-op args.
      { start: 2, end: 0, name: 'args', type: 'tensors' },
    ],
    attrs: [
      { tfName: 'num_args', name: 'numArgs', type: 'number' },
      { tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true },
      { tfName: 'strides', name: 'strides', type: 'number[]' },
      { tfName: 'padding', name: 'pad', type: 'string' },
      { tfName: 'explicit_paddings', name: 'explicitPaddings', type: 'number[]', defaultValue: [] },
      { tfName: 'use_cudnn_on_gpu', name: 'useCudnnOnGpu', type: 'bool', defaultValue: true },
      { tfName: 'data_format', name: 'dataFormat', type: 'string', defaultValue: 'NHWC' },
      { tfName: 'dilations', name: 'dilations', type: 'number[]', defaultValue: [1, 1, 1, 1] },
      { tfName: 'fused_ops', name: 'fusedOps', type: 'string[]', defaultValue: [] },
      { tfName: 'epsilon', name: 'epsilon', type: 'number', defaultValue: 0.0001 },
      { tfName: 'leakyrelu_alpha', name: 'leakyreluAlpha', type: 'number', defaultValue: 0.2 },
    ],
  },
  {
    tfOpName: 'Conv2DBackpropInput',
    category: 'convolution',
    // Note the deliberate re-ordering: TF input 2 is mapped to `x`,
    // input 1 to `filter`, and input 0 to `outputShape`.
    inputs: [
      { start: 2, name: 'x', type: 'tensor' },
      { start: 1, name: 'filter', type: 'tensor' },
      { start: 0, name: 'outputShape', type: 'number[]' },
    ],
    attrs: [
      { tfName: 'strides', name: 'strides', type: 'number[]' },
      { tfName: 'padding', name: 'pad', type: 'string' },
      { tfName: 'data_format', name: 'dataFormat', type: 'string', notSupported: true },
      { tfName: 'explicit_paddings', name: 'explicitPaddings', type: 'number[]', defaultValue: [] },
      { tfName: 'dilations', name: 'dilations', type: 'number[]', notSupported: true },
    ],
  },
  {
    tfOpName: 'DepthwiseConv2d',
    category: 'convolution',
    inputs: [
      { start: 0, name: 'input', type: 'tensor' },
      { start: 1, name: 'filter', type: 'tensor' },
    ],
    attrs: [
      { tfName: 'strides', name: 'strides', type: 'number[]' },
      { tfName: 'padding', name: 'pad', type: 'string' },
      { tfName: 'data_format', name: 'dataFormat', type: 'string', defaultValue: 'NHWC' },
      { tfName: 'explicit_paddings', name: 'explicitPaddings', type: 'number[]', defaultValue: [] },
      { tfName: 'dilations', name: 'dilations', type: 'number[]' },
    ],
  },
  {
    tfOpName: 'DepthwiseConv2dNative',
    category: 'convolution',
    inputs: [
      { start: 0, name: 'input', type: 'tensor' },
      { start: 1, name: 'filter', type: 'tensor' },
    ],
    attrs: [
      { tfName: 'strides', name: 'strides', type: 'number[]' },
      { tfName: 'padding', name: 'pad', type: 'string' },
      { tfName: 'data_format', name: 'dataFormat', type: 'string', defaultValue: 'NHWC' },
      { tfName: 'explicit_paddings', name: 'explicitPaddings', type: 'number[]', defaultValue: [] },
      { tfName: 'dilations', name: 'dilations', type: 'number[]' },
    ],
  },
  {
    tfOpName: 'FusedDepthwiseConv2dNative',
    category: 'convolution',
    inputs: [
      { start: 0, name: 'x', type: 'tensor' },
      { start: 1, name: 'filter', type: 'tensor' },
      { start: 2, end: 0, name: 'args', type: 'tensors' },
    ],
    attrs: [
      { tfName: 'num_args', name: 'numArgs', type: 'number' },
      { tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true },
      { tfName: 'strides', name: 'strides', type: 'number[]' },
      { tfName: 'padding', name: 'pad', type: 'string' },
      { tfName: 'data_format', name: 'dataFormat', type: 'string', defaultValue: 'NHWC' },
      { tfName: 'dilations', name: 'dilations', type: 'number[]', defaultValue: [1, 1, 1, 1] },
      { tfName: 'fused_ops', name: 'fusedOps', type: 'string[]', defaultValue: [] },
      { tfName: 'explicit_paddings', name: 'explicitPaddings', type: 'number[]', defaultValue: [] },
    ],
  },
  {
    tfOpName: 'Conv3D',
    category: 'convolution',
    inputs: [
      { start: 0, name: 'x', type: 'tensor' },
      { start: 1, name: 'filter', type: 'tensor' },
    ],
    attrs: [
      { tfName: 'strides', name: 'strides', type: 'number[]' },
      { tfName: 'padding', name: 'pad', type: 'string' },
      { tfName: 'data_format', name: 'dataFormat', type: 'string', defaultValue: 'NHWC' },
      { tfName: 'dilations', name: 'dilations', type: 'number[]' },
    ],
  },
  {
    tfOpName: 'Dilation2D',
    category: 'convolution',
    inputs: [
      { start: 0, name: 'x', type: 'tensor' },
      { start: 1, name: 'filter', type: 'tensor' },
    ],
    attrs: [
      { tfName: 'strides', name: 'strides', type: 'number[]' },
      { tfName: 'rates', name: 'dilations', type: 'number[]' },
      { tfName: 'padding', name: 'pad', type: 'string' },
    ],
  },
];
59532
// Frozen, prototype-less namespace object re-exporting the convolution
// op-mapping table as `convolution.json` (emulates an ES-module namespace;
// the /*#__PURE__*/ annotation lets bundlers tree-shake it when unused).
var convolution = /*#__PURE__*/Object.freeze({
    __proto__: null,
    json: json$f
});
59537
59538 /**
59539 * @license
59540 * Copyright 2023 Google LLC. All Rights Reserved.
59541 * Licensed under the Apache License, Version 2.0 (the "License");
59542 * you may not use this file except in compliance with the License.
59543 * You may obtain a copy of the License at
59544 *
59545 * http://www.apache.org/licenses/LICENSE-2.0
59546 *
59547 * Unless required by applicable law or agreed to in writing, software
59548 * distributed under the License is distributed on an "AS IS" BASIS,
59549 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
59550 * See the License for the specific language governing permissions and
59551 * limitations under the License.
59552 * =============================================================================
59553 */
/**
 * Op-mapper table for the 'creation' category: each entry names a TF op
 * (`tfOpName`), lists the positional node inputs to extract, and maps TF
 * attributes (`tfName`) to target parameter names (`name`) with a declared
 * `type`, optional `defaultValue`, and `notSupported` markers.
 */
const json$e = [
  {
    tfOpName: 'Fill',
    category: 'creation',
    inputs: [
      { start: 0, name: 'shape', type: 'number[]' },
      { start: 1, name: 'value', type: 'number' },
    ],
    attrs: [{ tfName: 'T', name: 'dtype', type: 'dtype' }],
  },
  {
    tfOpName: 'LinSpace',
    category: 'creation',
    inputs: [
      { start: 0, name: 'start', type: 'number' },
      { start: 1, name: 'stop', type: 'number' },
      { start: 2, name: 'num', type: 'number' },
    ],
    attrs: [{ tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true }],
  },
  {
    tfOpName: 'OneHot',
    category: 'creation',
    inputs: [
      { start: 0, name: 'indices', type: 'tensor' },
      { start: 1, name: 'depth', type: 'number' },
      { start: 2, name: 'onValue', type: 'number', defaultValue: 1 },
      { start: 3, name: 'offValue', type: 'number', defaultValue: 0 },
    ],
    attrs: [
      { tfName: 'axis', name: 'axis', type: 'number', notSupported: true },
      { tfName: 'T', name: 'dtype', type: 'dtype' },
    ],
  },
  {
    tfOpName: 'Ones',
    category: 'creation',
    inputs: [{ start: 0, name: 'shape', type: 'number[]' }],
    attrs: [{ tfName: 'T', name: 'dtype', type: 'dtype' }],
  },
  {
    tfOpName: 'OnesLike',
    category: 'creation',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
    attrs: [{ tfName: 'dtype', name: 'dtype', type: 'dtype' }],
  },
  {
    tfOpName: 'RandomStandardNormal',
    category: 'creation',
    inputs: [{ start: 0, name: 'shape', type: 'number[]' }],
    attrs: [
      { tfName: 'seed', name: 'seed', type: 'number', defaultValue: 0 },
      { tfName: 'seed2', name: 'seed2', type: 'number', defaultValue: 0, notSupported: true },
      { tfName: 'dtype', name: 'dtype', type: 'dtype' },
      { tfName: 'T', name: 'T', type: 'number', notSupported: true },
    ],
  },
  {
    tfOpName: 'RandomUniform',
    category: 'creation',
    inputs: [{ start: 0, name: 'shape', type: 'number[]' }],
    attrs: [
      { tfName: 'minval', name: 'minval', type: 'number', defaultValue: 0 },
      { tfName: 'maxval', name: 'maxval', type: 'number', defaultValue: 1 },
      { tfName: 'dtype', name: 'dtype', type: 'dtype' },
      { tfName: 'seed', name: 'seed', type: 'number', defaultValue: 0 },
      { tfName: 'seed2', name: 'seed2', type: 'number', defaultValue: 0, notSupported: true },
      { tfName: 'T', name: 'T', type: 'number', notSupported: true },
    ],
  },
  {
    tfOpName: 'RandomUniformInt',
    category: 'creation',
    inputs: [{ start: 0, name: 'shape', type: 'number[]' }],
    attrs: [
      { tfName: 'minval', name: 'minval', type: 'number' },
      { tfName: 'maxval', name: 'maxval', type: 'number' },
      { tfName: 'seed', name: 'seed', type: 'number', defaultValue: 0 },
      { tfName: 'seed2', name: 'seed2', type: 'number', defaultValue: 0, notSupported: true },
    ],
  },
  {
    tfOpName: 'Range',
    category: 'creation',
    inputs: [
      { start: 0, name: 'start', type: 'number' },
      { start: 1, name: 'stop', type: 'number' },
      { start: 2, name: 'step', type: 'number', defaultValue: 0 },
    ],
    attrs: [{ tfName: 'Tidx', name: 'dtype', type: 'dtype' }],
  },
  {
    tfOpName: 'TruncatedNormal',
    category: 'creation',
    inputs: [{ start: 0, name: 'shape', type: 'number[]' }],
    attrs: [
      { tfName: 'means', name: 'mean', type: 'number', defaultValue: 0 },
      { tfName: 'stddev', name: 'stdDev', type: 'number', defaultValue: 1 },
      { tfName: 'seed', name: 'seed', type: 'number' },
      { tfName: 'seed2', name: 'seed2', type: 'number', defaultValue: 0, notSupported: true },
      { tfName: 'dtype', name: 'dtype', type: 'dtype' },
      { tfName: 'T', name: 'T', type: 'number', notSupported: true },
    ],
  },
  {
    tfOpName: 'Zeros',
    category: 'creation',
    inputs: [{ start: 0, name: 'shape', type: 'number[]' }],
    attrs: [{ tfName: 'T', name: 'dtype', type: 'dtype' }],
  },
  {
    tfOpName: 'ZerosLike',
    category: 'creation',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
    attrs: [{ tfName: 'T', name: 'dtype', type: 'dtype' }],
  },
  {
    tfOpName: 'Multinomial',
    category: 'creation',
    inputs: [
      { start: 0, name: 'logits', type: 'tensor' },
      { start: 1, name: 'numSamples', type: 'number' },
    ],
    attrs: [
      { tfName: 'seed', name: 'seed', type: 'number' },
      { tfName: 'seed2', name: 'seed2', type: 'number' },
      { tfName: 'T', name: 'dtype', type: 'dtype' },
      { tfName: 'output_dtype', name: 'output_dtype', type: 'dtype' },
    ],
  },
];
59958
// Frozen, prototype-less namespace object re-exporting the creation
// op-mapping table as `creation.json` (emulates an ES-module namespace;
// the /*#__PURE__*/ annotation lets bundlers tree-shake it when unused).
var creation = /*#__PURE__*/Object.freeze({
    __proto__: null,
    json: json$e
});
59963
59964 /**
59965 * @license
59966 * Copyright 2023 Google LLC. All Rights Reserved.
59967 * Licensed under the Apache License, Version 2.0 (the "License");
59968 * you may not use this file except in compliance with the License.
59969 * You may obtain a copy of the License at
59970 *
59971 * http://www.apache.org/licenses/LICENSE-2.0
59972 *
59973 * Unless required by applicable law or agreed to in writing, software
59974 * distributed under the License is distributed on an "AS IS" BASIS,
59975 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
59976 * See the License for the specific language governing permissions and
59977 * limitations under the License.
59978 * =============================================================================
59979 */
/**
 * Op-mapper table for the 'dynamic' category (ops whose output shape depends
 * on runtime values): each entry names a TF op (`tfOpName`), lists the
 * positional node inputs to extract, and maps TF attributes (`tfName`) to
 * target parameter names (`name`), with `notSupported` markers.
 */
const json$d = [
  {
    tfOpName: 'NonMaxSuppressionV2',
    category: 'dynamic',
    inputs: [
      { start: 0, name: 'boxes', type: 'tensor' },
      { start: 1, name: 'scores', type: 'tensor' },
      { start: 2, name: 'maxOutputSize', type: 'number' },
      { start: 3, name: 'iouThreshold', type: 'number' },
    ],
  },
  {
    tfOpName: 'NonMaxSuppressionV3',
    category: 'dynamic',
    inputs: [
      { start: 0, name: 'boxes', type: 'tensor' },
      { start: 1, name: 'scores', type: 'tensor' },
      { start: 2, name: 'maxOutputSize', type: 'number' },
      { start: 3, name: 'iouThreshold', type: 'number' },
      { start: 4, name: 'scoreThreshold', type: 'number' },
    ],
  },
  {
    tfOpName: 'NonMaxSuppressionV4',
    category: 'dynamic',
    inputs: [
      { start: 0, name: 'boxes', type: 'tensor' },
      { start: 1, name: 'scores', type: 'tensor' },
      { start: 2, name: 'maxOutputSize', type: 'number' },
      { start: 3, name: 'iouThreshold', type: 'number' },
      { start: 4, name: 'scoreThreshold', type: 'number' },
    ],
    attrs: [
      { tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true },
      { tfName: 'T_threshold', name: 'threshold', type: 'dtype', notSupported: true },
      { tfName: 'pad_to_max_output_size', name: 'padToMaxOutputSize', type: 'bool' },
    ],
  },
  {
    tfOpName: 'NonMaxSuppressionV5',
    category: 'dynamic',
    inputs: [
      { start: 0, name: 'boxes', type: 'tensor' },
      { start: 1, name: 'scores', type: 'tensor' },
      { start: 2, name: 'maxOutputSize', type: 'number' },
      { start: 3, name: 'iouThreshold', type: 'number' },
      { start: 4, name: 'scoreThreshold', type: 'number' },
      { start: 5, name: 'softNmsSigma', type: 'number' },
    ],
  },
  {
    tfOpName: 'Where',
    category: 'dynamic',
    inputs: [{ start: 0, name: 'condition', type: 'tensor' }],
    attrs: [{ tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true }],
  },
  {
    tfOpName: 'ListDiff',
    category: 'dynamic',
    inputs: [
      { start: 0, name: 'x', type: 'tensor' },
      { start: 1, name: 'y', type: 'tensor' },
    ],
    attrs: [{ tfName: 'T', name: 'dtype', type: 'dtype', notSupported: true }],
  },
];
60168
// Frozen, prototype-less namespace object re-exporting the dynamic-shape
// op-mapping table as `dynamic.json` (emulates an ES-module namespace;
// the /*#__PURE__*/ annotation lets bundlers tree-shake it when unused).
var dynamic = /*#__PURE__*/Object.freeze({
    __proto__: null,
    json: json$d
});
60173
60174 /**
60175 * @license
60176 * Copyright 2023 Google LLC. All Rights Reserved.
60177 * Licensed under the Apache License, Version 2.0 (the "License");
60178 * you may not use this file except in compliance with the License.
60179 * You may obtain a copy of the License at
60180 *
60181 * http://www.apache.org/licenses/LICENSE-2.0
60182 *
60183 * Unless required by applicable law or agreed to in writing, software
60184 * distributed under the License is distributed on an "AS IS" BASIS,
60185 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
60186 * See the License for the specific language governing permissions and
60187 * limitations under the License.
60188 * =============================================================================
60189 */
/**
 * Op-mapper table for the 'evaluation' category: each entry names a TF op
 * (`tfOpName`) and lists the positional node inputs to extract, plus any
 * TF-attribute (`tfName`) to parameter-name (`name`) mappings.
 */
const json$c = [
  {
    tfOpName: 'LowerBound',
    category: 'evaluation',
    inputs: [
      { start: 0, name: 'sortedSequence', type: 'tensor' },
      { start: 1, name: 'values', type: 'tensor' },
    ],
  },
  {
    tfOpName: 'TopKV2',
    category: 'evaluation',
    inputs: [
      { start: 0, name: 'x', type: 'tensor' },
      { start: 1, name: 'k', type: 'number' },
    ],
    attrs: [{ tfName: 'sorted', name: 'sorted', type: 'bool' }],
  },
  {
    tfOpName: 'UpperBound',
    category: 'evaluation',
    inputs: [
      { start: 0, name: 'sortedSequence', type: 'tensor' },
      { start: 1, name: 'values', type: 'tensor' },
    ],
  },
  {
    tfOpName: 'Unique',
    category: 'evaluation',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
  },
  {
    tfOpName: 'UniqueV2',
    category: 'evaluation',
    inputs: [
      { start: 0, name: 'x', type: 'tensor' },
      { start: 1, name: 'axis', type: 'number' },
    ],
  },
];
60274
// Frozen, prototype-less namespace object re-exporting the evaluation
// op-mapping table as `evaluation.json` (emulates an ES-module namespace;
// the /*#__PURE__*/ annotation lets bundlers tree-shake it when unused).
var evaluation = /*#__PURE__*/Object.freeze({
    __proto__: null,
    json: json$c
});
60279
60280 /**
60281 * @license
60282 * Copyright 2023 Google LLC. All Rights Reserved.
60283 * Licensed under the Apache License, Version 2.0 (the "License");
60284 * you may not use this file except in compliance with the License.
60285 * You may obtain a copy of the License at
60286 *
60287 * http://www.apache.org/licenses/LICENSE-2.0
60288 *
60289 * Unless required by applicable law or agreed to in writing, software
60290 * distributed under the License is distributed on an "AS IS" BASIS,
60291 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
60292 * See the License for the specific language governing permissions and
60293 * limitations under the License.
60294 * =============================================================================
60295 */
/**
 * Op-mapper table for the 'graph' category (structural / pass-through ops):
 * each entry names a TF op (`tfOpName`), lists the positional node inputs to
 * extract (`end: 0` meaning "through the last input" for variadic tensor
 * lists), and maps TF attributes (`tfName`) to parameter names (`name`).
 */
const json$b = [
  {
    tfOpName: 'PlaceholderWithDefault',
    category: 'graph',
    inputs: [{ start: 0, name: 'default', type: 'tensor' }],
    attrs: [
      { tfName: 'shape', name: 'shape', type: 'shape' },
      { tfName: 'dtype', name: 'dtype', type: 'dtype' },
    ],
  },
  {
    tfOpName: 'Placeholder',
    category: 'graph',
    attrs: [
      { tfName: 'shape', name: 'shape', type: 'shape' },
      { tfName: 'dtype', name: 'dtype', type: 'dtype' },
    ],
  },
  // Const carries its value in the node itself; no inputs or attrs mapped.
  { tfOpName: 'Const', category: 'graph' },
  {
    tfOpName: 'Identity',
    category: 'graph',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
  },
  {
    tfOpName: 'IdentityN',
    category: 'graph',
    inputs: [{ start: 0, end: 0, name: 'x', type: 'tensors' }],
  },
  {
    tfOpName: 'Snapshot',
    category: 'graph',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
  },
  {
    tfOpName: 'Rank',
    category: 'graph',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
  },
  {
    tfOpName: 'Size',
    category: 'graph',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
  },
  {
    tfOpName: 'Shape',
    category: 'graph',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
  },
  {
    tfOpName: 'ShapeN',
    category: 'graph',
    inputs: [{ start: 0, end: 0, name: 'x', type: 'tensors' }],
  },
  {
    tfOpName: 'Print',
    category: 'graph',
    inputs: [
      { start: 0, name: 'x', type: 'tensor' },
      { start: 1, name: 'data', type: 'tensors' },
    ],
    attrs: [
      { tfName: 'message', name: 'message', type: 'string' },
      { tfName: 'first_n', name: 'firstN', type: 'number', notSupported: true },
      { tfName: 'summarize', name: 'summarize', type: 'number', defaultValue: 3 },
    ],
  },
  { tfOpName: 'NoOp', category: 'graph', inputs: [] },
  {
    tfOpName: 'StopGradient',
    category: 'graph',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
  },
  {
    tfOpName: 'FakeQuantWithMinMaxVars',
    category: 'graph',
    inputs: [{ start: 0, name: 'x', type: 'tensor' }],
    attrs: [
      { tfName: 'min', name: 'min', type: 'number' },
      { tfName: 'max', name: 'max', type: 'number' },
    ],
  },
];
60494
// Frozen, prototype-less namespace object re-exporting the graph-op
// mapping table as `graph.json` (emulates an ES-module namespace;
// the /*#__PURE__*/ annotation lets bundlers tree-shake it when unused).
var graph = /*#__PURE__*/Object.freeze({
    __proto__: null,
    json: json$b
});
60499
60500 /**
60501 * @license
60502 * Copyright 2023 Google LLC. All Rights Reserved.
60503 * Licensed under the Apache License, Version 2.0 (the "License");
60504 * you may not use this file except in compliance with the License.
60505 * You may obtain a copy of the License at
60506 *
60507 * http://www.apache.org/licenses/LICENSE-2.0
60508 *
60509 * Unless required by applicable law or agreed to in writing, software
60510 * distributed under the License is distributed on an "AS IS" BASIS,
60511 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
60512 * See the License for the specific language governing permissions and
60513 * limitations under the License.
60514 * =============================================================================
60515 */
    /**
     * Op mapper table for the TF `hash_table` category.
     *
     * Each entry maps a TensorFlow graph-node op (`tfOpName`) for the
     * converter's executor: `inputs` bind node inputs by position (`start`
     * appears to be the node-input index — note the consecutive 0/1/2 runs),
     * and `attrs` map TF attribute names (`tfName`, snake_case) onto the
     * camelCase `name` used internally, with a parse `type`.
     * `notSupported: true` marks attributes that are declared but not
     * supported by the runtime.
     */
    const json$a = [
        {
            // Creates a hash-table resource; configuration comes entirely
            // from attrs, so there are no tensor inputs.
            'tfOpName': 'HashTable',
            'category': 'hash_table',
            'inputs': [],
            'attrs': [
                { 'tfName': 'shared_name', 'name': 'sharedName', 'type': 'string' },
                { 'tfName': 'use_node_name_sharing', 'name': 'useNodeNameSharing', 'type': 'bool' },
                { 'tfName': 'key_dtype', 'name': 'keyDType', 'type': 'dtype' },
                { 'tfName': 'value_dtype', 'name': 'valueDType', 'type': 'dtype' }
            ]
        },
        {
            // V2 variant — identical mapping to HashTable above.
            'tfOpName': 'HashTableV2',
            'category': 'hash_table',
            'inputs': [],
            'attrs': [
                { 'tfName': 'shared_name', 'name': 'sharedName', 'type': 'string' },
                { 'tfName': 'use_node_name_sharing', 'name': 'useNodeNameSharing', 'type': 'bool' },
                { 'tfName': 'key_dtype', 'name': 'keyDType', 'type': 'dtype' },
                { 'tfName': 'value_dtype', 'name': 'valueDType', 'type': 'dtype' }
            ]
        },
        {
            // Bulk-loads keys/values into an existing table handle.
            'tfOpName': 'LookupTableImport',
            'category': 'hash_table',
            'inputs': [
                { 'start': 0, 'name': 'tableHandle', 'type': 'tensor' },
                { 'start': 1, 'name': 'keys', 'type': 'tensor' },
                { 'start': 2, 'name': 'values', 'type': 'tensor' }
            ],
            'attrs': [
                { 'tfName': 'Tin', 'name': 'tIn', 'type': 'dtype', 'notSupported': true },
                { 'tfName': 'Tout', 'name': 'tOut', 'type': 'dtype', 'notSupported': true }
            ]
        },
        {
            'tfOpName': 'LookupTableImportV2',
            'category': 'hash_table',
            'inputs': [
                { 'start': 0, 'name': 'tableHandle', 'type': 'tensor' },
                { 'start': 1, 'name': 'keys', 'type': 'tensor' },
                { 'start': 2, 'name': 'values', 'type': 'tensor' }
            ],
            'attrs': [
                { 'tfName': 'Tin', 'name': 'tIn', 'type': 'dtype', 'notSupported': true },
                { 'tfName': 'Tout', 'name': 'tOut', 'type': 'dtype', 'notSupported': true }
            ]
        },
        {
            // Looks up `keys`, falling back to `defaultValue` for misses.
            'tfOpName': 'LookupTableFind',
            'category': 'hash_table',
            'inputs': [
                { 'start': 0, 'name': 'tableHandle', 'type': 'tensor' },
                { 'start': 1, 'name': 'keys', 'type': 'tensor' },
                { 'start': 2, 'name': 'defaultValue', 'type': 'tensor' }
            ],
            'attrs': [
                { 'tfName': 'Tin', 'name': 'tIn', 'type': 'dtype', 'notSupported': true },
                { 'tfName': 'Tout', 'name': 'tOut', 'type': 'dtype', 'notSupported': true }
            ]
        },
        {
            'tfOpName': 'LookupTableFindV2',
            'category': 'hash_table',
            'inputs': [
                { 'start': 0, 'name': 'tableHandle', 'type': 'tensor' },
                { 'start': 1, 'name': 'keys', 'type': 'tensor' },
                { 'start': 2, 'name': 'defaultValue', 'type': 'tensor' }
            ],
            'attrs': [
                { 'tfName': 'Tin', 'name': 'tIn', 'type': 'dtype', 'notSupported': true },
                { 'tfName': 'Tout', 'name': 'tOut', 'type': 'dtype', 'notSupported': true }
            ]
        },
        {
            // Returns the number of entries in the table.
            'tfOpName': 'LookupTableSize',
            'category': 'hash_table',
            'inputs': [
                { 'start': 0, 'name': 'tableHandle', 'type': 'tensor' }
            ]
        },
        {
            'tfOpName': 'LookupTableSizeV2',
            'category': 'hash_table',
            'inputs': [
                { 'start': 0, 'name': 'tableHandle', 'type': 'tensor' }
            ]
        },
        {
            // Same input layout as LookupTableImport (handle, keys, values),
            // but declares no attrs.
            'tfOpName': 'InitializeTable',
            'category': 'hash_table',
            'inputs': [
                { 'start': 0, 'name': 'tableHandle', 'type': 'tensor' },
                { 'start': 1, 'name': 'keys', 'type': 'tensor' },
                { 'start': 2, 'name': 'values', 'type': 'tensor' }
            ]
        },
        {
            'tfOpName': 'InitializeTableV2',
            'category': 'hash_table',
            'inputs': [
                { 'start': 0, 'name': 'tableHandle', 'type': 'tensor' },
                { 'start': 1, 'name': 'keys', 'type': 'tensor' },
                { 'start': 2, 'name': 'values', 'type': 'tensor' }
            ]
        }
    ];
60776
60777 var hashTable = /*#__PURE__*/Object.freeze({
60778 __proto__: null,
60779 json: json$a
60780 });
60781
60782 /**
60783 * @license
60784 * Copyright 2023 Google LLC. All Rights Reserved.
60785 * Licensed under the Apache License, Version 2.0 (the "License");
60786 * you may not use this file except in compliance with the License.
60787 * You may obtain a copy of the License at
60788 *
60789 * http://www.apache.org/licenses/LICENSE-2.0
60790 *
60791 * Unless required by applicable law or agreed to in writing, software
60792 * distributed under the License is distributed on an "AS IS" BASIS,
60793 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
60794 * See the License for the specific language governing permissions and
60795 * limitations under the License.
60796 * =============================================================================
60797 */
    /**
     * Op mapper table for the TF `image` category.
     *
     * Each entry maps a TensorFlow graph-node op (`tfOpName`) for the
     * converter's executor: `inputs` bind node inputs by position (`start`
     * appears to be the node-input index), and `attrs` map TF attribute
     * names (`tfName`) onto camelCase internal names with a parse `type`.
     * `notSupported: true` marks attributes that are declared but not
     * supported by the runtime.
     */
    const json$9 = [
        {
            'tfOpName': 'ResizeBilinear',
            'category': 'image',
            'inputs': [
                { 'start': 0, 'name': 'images', 'type': 'tensor' },
                // Target output size, read as a number array.
                { 'start': 1, 'name': 'size', 'type': 'number[]' }
            ],
            'attrs': [
                { 'tfName': 'align_corners', 'name': 'alignCorners', 'type': 'bool' },
                { 'tfName': 'half_pixel_centers', 'name': 'halfPixelCenters', 'type': 'bool' },
                { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
            ]
        },
        {
            // Same input/attr layout as ResizeBilinear.
            'tfOpName': 'ResizeNearestNeighbor',
            'category': 'image',
            'inputs': [
                { 'start': 0, 'name': 'images', 'type': 'tensor' },
                { 'start': 1, 'name': 'size', 'type': 'number[]' }
            ],
            'attrs': [
                { 'tfName': 'align_corners', 'name': 'alignCorners', 'type': 'bool' },
                { 'tfName': 'half_pixel_centers', 'name': 'halfPixelCenters', 'type': 'bool' },
                { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
            ]
        },
        {
            'tfOpName': 'CropAndResize',
            'category': 'image',
            'inputs': [
                { 'start': 0, 'name': 'image', 'type': 'tensor' },
                { 'start': 1, 'name': 'boxes', 'type': 'tensor' },
                { 'start': 2, 'name': 'boxInd', 'type': 'tensor' },
                { 'start': 3, 'name': 'cropSize', 'type': 'number[]' }
            ],
            'attrs': [
                { 'tfName': 'method', 'name': 'method', 'type': 'string' },
                { 'tfName': 'extrapolation_value', 'name': 'extrapolationValue', 'type': 'number' }
            ]
        },
        {
            'tfOpName': 'ImageProjectiveTransformV3',
            'category': 'image',
            'inputs': [
                { 'start': 0, 'name': 'images', 'type': 'tensor' },
                { 'start': 1, 'name': 'transforms', 'type': 'tensor' },
                { 'start': 2, 'name': 'outputShape', 'type': 'number[]' },
                { 'start': 3, 'name': 'fillValue', 'type': 'number' }
            ],
            'attrs': [
                { 'tfName': 'interpolation', 'name': 'interpolation', 'type': 'string' },
                { 'tfName': 'fill_mode', 'name': 'fillMode', 'type': 'string' }
            ]
        }
    ];
60944
60945 var image = /*#__PURE__*/Object.freeze({
60946 __proto__: null,
60947 json: json$9
60948 });
60949
60950 /**
60951 * @license
60952 * Copyright 2023 Google LLC. All Rights Reserved.
60953 * Licensed under the Apache License, Version 2.0 (the "License");
60954 * you may not use this file except in compliance with the License.
60955 * You may obtain a copy of the License at
60956 *
60957 * http://www.apache.org/licenses/LICENSE-2.0
60958 *
60959 * Unless required by applicable law or agreed to in writing, software
60960 * distributed under the License is distributed on an "AS IS" BASIS,
60961 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
60962 * See the License for the specific language governing permissions and
60963 * limitations under the License.
60964 * =============================================================================
60965 */
    /**
     * Op mapper table for the TF `logical` category (comparisons, boolean
     * ops and element selection).
     *
     * Each entry maps a TensorFlow graph-node op (`tfOpName`) for the
     * converter's executor: `inputs` bind node inputs by position (`start`
     * appears to be the node-input index), and `attrs` map TF attribute
     * names (`tfName`) onto camelCase internal names with a parse `type`.
     * `notSupported: true` marks attributes that are declared but not
     * supported by the runtime. Most binary ops below share the same
     * two-tensor (`a`, `b`) input layout with an ignored `T` dtype attr.
     */
    const json$8 = [
        {
            'tfOpName': 'Equal',
            'category': 'logical',
            'inputs': [
                { 'start': 0, 'name': 'a', 'type': 'tensor' },
                { 'start': 1, 'name': 'b', 'type': 'tensor' }
            ],
            'attrs': [
                { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
            ]
        },
        {
            'tfOpName': 'NotEqual',
            'category': 'logical',
            'inputs': [
                { 'start': 0, 'name': 'a', 'type': 'tensor' },
                { 'start': 1, 'name': 'b', 'type': 'tensor' }
            ],
            'attrs': [
                { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
            ]
        },
        {
            'tfOpName': 'Greater',
            'category': 'logical',
            'inputs': [
                { 'start': 0, 'name': 'a', 'type': 'tensor' },
                { 'start': 1, 'name': 'b', 'type': 'tensor' }
            ],
            'attrs': [
                { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
            ]
        },
        {
            'tfOpName': 'GreaterEqual',
            'category': 'logical',
            'inputs': [
                { 'start': 0, 'name': 'a', 'type': 'tensor' },
                { 'start': 1, 'name': 'b', 'type': 'tensor' }
            ],
            'attrs': [
                { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
            ]
        },
        {
            'tfOpName': 'Less',
            'category': 'logical',
            'inputs': [
                { 'start': 0, 'name': 'a', 'type': 'tensor' },
                { 'start': 1, 'name': 'b', 'type': 'tensor' }
            ],
            'attrs': [
                { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
            ]
        },
        {
            'tfOpName': 'LessEqual',
            'category': 'logical',
            'inputs': [
                { 'start': 0, 'name': 'a', 'type': 'tensor' },
                { 'start': 1, 'name': 'b', 'type': 'tensor' }
            ],
            'attrs': [
                { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
            ]
        },
        {
            'tfOpName': 'LogicalAnd',
            'category': 'logical',
            'inputs': [
                { 'start': 0, 'name': 'a', 'type': 'tensor' },
                { 'start': 1, 'name': 'b', 'type': 'tensor' }
            ],
            'attrs': [
                { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
            ]
        },
        {
            // Unary: only takes `a`.
            'tfOpName': 'LogicalNot',
            'category': 'logical',
            'inputs': [
                { 'start': 0, 'name': 'a', 'type': 'tensor' }
            ],
            'attrs': [
                { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
            ]
        },
        {
            'tfOpName': 'LogicalOr',
            'category': 'logical',
            'inputs': [
                { 'start': 0, 'name': 'a', 'type': 'tensor' },
                { 'start': 1, 'name': 'b', 'type': 'tensor' }
            ],
            'attrs': [
                { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
            ]
        },
        {
            // Ternary select: condition chooses between `a` and `b`.
            'tfOpName': 'Select',
            'category': 'logical',
            'inputs': [
                { 'start': 0, 'name': 'condition', 'type': 'tensor' },
                { 'start': 1, 'name': 'a', 'type': 'tensor' },
                { 'start': 2, 'name': 'b', 'type': 'tensor' }
            ],
            'attrs': [
                { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
            ]
        },
        {
            'tfOpName': 'SelectV2',
            'category': 'logical',
            'inputs': [
                { 'start': 0, 'name': 'condition', 'type': 'tensor' },
                { 'start': 1, 'name': 'a', 'type': 'tensor' },
                { 'start': 2, 'name': 'b', 'type': 'tensor' }
            ],
            'attrs': [
                { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
            ]
        },
        {
            // Unlike the ops above, uses `x`/`y` input names and has no attrs.
            'tfOpName': 'BitwiseAnd',
            'category': 'logical',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' },
                { 'start': 1, 'name': 'y', 'type': 'tensor' }
            ]
        }
    ];
61253
61254 var logical = /*#__PURE__*/Object.freeze({
61255 __proto__: null,
61256 json: json$8
61257 });
61258
61259 /**
61260 * @license
61261 * Copyright 2023 Google LLC. All Rights Reserved.
61262 * Licensed under the Apache License, Version 2.0 (the "License");
61263 * you may not use this file except in compliance with the License.
61264 * You may obtain a copy of the License at
61265 *
61266 * http://www.apache.org/licenses/LICENSE-2.0
61267 *
61268 * Unless required by applicable law or agreed to in writing, software
61269 * distributed under the License is distributed on an "AS IS" BASIS,
61270 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
61271 * See the License for the specific language governing permissions and
61272 * limitations under the License.
61273 * =============================================================================
61274 */
61275 const json$7 = [
61276 {
61277 'tfOpName': '_FusedMatMul',
61278 'category': 'matrices',
61279 'inputs': [
61280 {
61281 'start': 0,
61282 'name': 'a',
61283 'type': 'tensor'
61284 },
61285 {
61286 'start': 1,
61287 'name': 'b',
61288 'type': 'tensor'
61289 },
61290 {
61291 'start': 2,
61292 'end': 0,
61293 'name': 'args',
61294 'type': 'tensors'
61295 }
61296 ],
61297 'attrs': [
61298 {
61299 'tfName': 'num_args',
61300 'name': 'numArgs',
61301 'type': 'number'
61302 },
61303 {
61304 'tfName': 'fused_ops',
61305 'name': 'fusedOps',
61306 'type': 'string[]',
61307 'defaultValue': []
61308 },
61309 {
61310 'tfName': 'epsilon',
61311 'name': 'epsilon',
61312 'type': 'number',
61313 'defaultValue': 0.0001
61314 },
61315 {
61316 'tfName': 'transpose_a',
61317 'name': 'transposeA',
61318 'type': 'bool',
61319 'defaultValue': false
61320 },
61321 {
61322 'tfName': 'transpose_b',
61323 'name': 'transposeB',
61324 'type': 'bool',
61325 'defaultValue': false
61326 },
61327 {
61328 'tfName': 'leakyrelu_alpha',
61329 'name': 'leakyreluAlpha',
61330 'type': 'number',
61331 'defaultValue': 0.2
61332 },
61333 {
61334 'tfName': 'T',
61335 'name': 'dtype',
61336 'type': 'dtype',
61337 'notSupported': true
61338 }
61339 ]
61340 },
61341 {
61342 'tfOpName': 'MatMul',
61343 'category': 'matrices',
61344 'inputs': [
61345 {
61346 'start': 0,
61347 'name': 'a',
61348 'type': 'tensor'
61349 },
61350 {
61351 'start': 1,
61352 'name': 'b',
61353 'type': 'tensor'
61354 }
61355 ],
61356 'attrs': [
61357 {
61358 'tfName': 'transpose_a',
61359 'name': 'transposeA',
61360 'type': 'bool',
61361 'defaultValue': false
61362 },
61363 {
61364 'tfName': 'transpose_b',
61365 'name': 'transposeB',
61366 'type': 'bool',
61367 'defaultValue': false
61368 },
61369 {
61370 'tfName': 'T',
61371 'name': 'dtype',
61372 'type': 'dtype',
61373 'notSupported': true
61374 }
61375 ]
61376 },
61377 {
61378 'tfOpName': 'BatchMatMul',
61379 'category': 'matrices',
61380 'inputs': [
61381 {
61382 'start': 0,
61383 'name': 'a',
61384 'type': 'tensor'
61385 },
61386 {
61387 'start': 1,
61388 'name': 'b',
61389 'type': 'tensor'
61390 }
61391 ],
61392 'attrs': [
61393 {
61394 'tfName': 'adj_x',
61395 'name': 'transposeA',
61396 'type': 'bool',
61397 'defaultValue': false
61398 },
61399 {
61400 'tfName': 'adj_y',
61401 'name': 'transposeB',
61402 'type': 'bool',
61403 'defaultValue': false
61404 },
61405 {
61406 'tfName': 'T',
61407 'name': 'dtype',
61408 'type': 'dtype',
61409 'notSupported': true
61410 }
61411 ]
61412 },
61413 {
61414 'tfOpName': 'BatchMatMulV2',
61415 'category': 'matrices',
61416 'inputs': [
61417 {
61418 'start': 0,
61419 'name': 'a',
61420 'type': 'tensor'
61421 },
61422 {
61423 'start': 1,
61424 'name': 'b',
61425 'type': 'tensor'
61426 }
61427 ],
61428 'attrs': [
61429 {
61430 'tfName': 'adj_x',
61431 'name': 'transposeA',
61432 'type': 'bool',
61433 'defaultValue': false
61434 },
61435 {
61436 'tfName': 'adj_y',
61437 'name': 'transposeB',
61438 'type': 'bool',
61439 'defaultValue': false
61440 },
61441 {
61442 'tfName': 'T',
61443 'name': 'dtype',
61444 'type': 'dtype',
61445 'notSupported': true
61446 }
61447 ]
61448 },
61449 {
61450 'tfOpName': 'Transpose',
61451 'category': 'matrices',
61452 'inputs': [
61453 {
61454 'start': 0,
61455 'name': 'x',
61456 'type': 'tensor'
61457 },
61458 {
61459 'start': 1,
61460 'name': 'perm',
61461 'type': 'number[]'
61462 }
61463 ],
61464 'attrs': [
61465 {
61466 'tfName': 'T',
61467 'name': 'dtype',
61468 'type': 'dtype',
61469 'notSupported': true
61470 }
61471 ]
61472 },
61473 {
61474 'tfOpName': 'Einsum',
61475 'category': 'matrices',
61476 'inputs': [
61477 {
61478 'start': 0,
61479 'end': 0,
61480 'name': 'tensors',
61481 'type': 'tensors'
61482 }
61483 ],
61484 'attrs': [
61485 {
61486 'tfName': 'equation',
61487 'name': 'equation',
61488 'type': 'string'
61489 },
61490 {
61491 'tfName': 'N',
61492 'name': 'n',
61493 'type': 'number',
61494 'defaultValue': 2
61495 },
61496 {
61497 'tfName': 'T',
61498 'name': 'dtype',
61499 'type': 'dtype'
61500 }
61501 ]
61502 },
61503 {
61504 'tfOpName': 'MatrixBandPart',
61505 'category': 'matrices',
61506 'inputs': [
61507 {
61508 'start': 0,
61509 'name': 'a',
61510 'type': 'tensor'
61511 },
61512 {
61513 'start': 1,
61514 'name': 'numLower',
61515 'type': 'tensor'
61516 },
61517 {
61518 'start': 1,
61519 'name': 'numUpper',
61520 'type': 'tensor'
61521 }
61522 ]
61523 }
61524 ];
61525
61526 var matrices = /*#__PURE__*/Object.freeze({
61527 __proto__: null,
61528 json: json$7
61529 });
61530
61531 /**
61532 * @license
61533 * Copyright 2023 Google LLC. All Rights Reserved.
61534 * Licensed under the Apache License, Version 2.0 (the "License");
61535 * you may not use this file except in compliance with the License.
61536 * You may obtain a copy of the License at
61537 *
61538 * http://www.apache.org/licenses/LICENSE-2.0
61539 *
61540 * Unless required by applicable law or agreed to in writing, software
61541 * distributed under the License is distributed on an "AS IS" BASIS,
61542 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
61543 * See the License for the specific language governing permissions and
61544 * limitations under the License.
61545 * =============================================================================
61546 */
    /**
     * Op mapper table for the TF `normalization` category.
     *
     * Each entry maps a TensorFlow graph-node op (`tfOpName`) for the
     * converter's executor: `inputs` bind node inputs by position (`start`
     * appears to be the node-input index), and `attrs` map TF attribute
     * names (`tfName`) onto camelCase internal names with a parse `type`
     * and optional `defaultValue`. `notSupported: true` marks attributes
     * that are declared but not supported by the runtime.
     */
    const json$6 = [
        {
            'tfOpName': 'EuclideanNorm',
            'category': 'normalization',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' },
                { 'start': 1, 'name': 'axis', 'type': 'number[]' }
            ],
            'attrs': [
                { 'tfName': 'keep_dims', 'name': 'keepDims', 'type': 'bool', 'defaultValue': false }
            ]
        },
        {
            // FusedBatchNorm V1/V2/V3 all share this five-tensor layout;
            // dataFormat is declared but flagged as not supported.
            'tfOpName': 'FusedBatchNorm',
            'category': 'normalization',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' },
                { 'start': 1, 'name': 'scale', 'type': 'tensor' },
                { 'start': 2, 'name': 'offset', 'type': 'tensor' },
                { 'start': 3, 'name': 'mean', 'type': 'tensor' },
                { 'start': 4, 'name': 'variance', 'type': 'tensor' }
            ],
            'attrs': [
                { 'tfName': 'epsilon', 'name': 'epsilon', 'type': 'number', 'defaultValue': 0.001 },
                { 'tfName': 'data_format', 'name': 'dataFormat', 'type': 'string', 'notSupported': true }
            ]
        },
        {
            'tfOpName': 'FusedBatchNormV2',
            'category': 'normalization',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' },
                { 'start': 1, 'name': 'scale', 'type': 'tensor' },
                { 'start': 2, 'name': 'offset', 'type': 'tensor' },
                { 'start': 3, 'name': 'mean', 'type': 'tensor' },
                { 'start': 4, 'name': 'variance', 'type': 'tensor' }
            ],
            'attrs': [
                { 'tfName': 'epsilon', 'name': 'epsilon', 'type': 'number', 'defaultValue': 0.001 },
                { 'tfName': 'data_format', 'name': 'dataFormat', 'type': 'string', 'notSupported': true }
            ]
        },
        {
            'tfOpName': 'FusedBatchNormV3',
            'category': 'normalization',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' },
                { 'start': 1, 'name': 'scale', 'type': 'tensor' },
                { 'start': 2, 'name': 'offset', 'type': 'tensor' },
                { 'start': 3, 'name': 'mean', 'type': 'tensor' },
                { 'start': 4, 'name': 'variance', 'type': 'tensor' }
            ],
            'attrs': [
                { 'tfName': 'epsilon', 'name': 'epsilon', 'type': 'number', 'defaultValue': 0.001 },
                { 'tfName': 'data_format', 'name': 'dataFormat', 'type': 'string', 'notSupported': true }
            ]
        },
        {
            // Local response normalization; defaults mirror TF's LRN attrs.
            'tfOpName': 'LRN',
            'category': 'normalization',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' }
            ],
            'attrs': [
                { 'tfName': 'depth_radius', 'name': 'radius', 'type': 'number', 'defaultValue': 5 },
                { 'tfName': 'bias', 'name': 'bias', 'type': 'number', 'defaultValue': 1 },
                { 'tfName': 'alpha', 'name': 'alpha', 'type': 'number', 'defaultValue': 1 },
                { 'tfName': 'beta', 'name': 'beta', 'type': 'number', 'defaultValue': 0.5 }
            ]
        },
        {
            'tfOpName': 'Softmax',
            'category': 'normalization',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' }
            ]
        },
        {
            'tfOpName': 'LogSoftmax',
            'category': 'normalization',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' }
            ]
        }
    ];
61767
61768 var normalization = /*#__PURE__*/Object.freeze({
61769 __proto__: null,
61770 json: json$6
61771 });
61772
61773 /**
61774 * @license
61775 * Copyright 2023 Google LLC. All Rights Reserved.
61776 * Licensed under the Apache License, Version 2.0 (the "License");
61777 * you may not use this file except in compliance with the License.
61778 * You may obtain a copy of the License at
61779 *
61780 * http://www.apache.org/licenses/LICENSE-2.0
61781 *
61782 * Unless required by applicable law or agreed to in writing, software
61783 * distributed under the License is distributed on an "AS IS" BASIS,
61784 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
61785 * See the License for the specific language governing permissions and
61786 * limitations under the License.
61787 * =============================================================================
61788 */
    /**
     * Op mapper table for the TF `reduction` category.
     *
     * Each entry maps a TensorFlow graph-node op (`tfOpName`) for the
     * converter's executor: `inputs` bind node inputs by position (`start`
     * appears to be the node-input index), and `attrs` map TF attribute
     * names (`tfName`) onto camelCase internal names with a parse `type`.
     * `notSupported: true` marks attributes that are declared but not
     * supported by the runtime.
     */
    const json$5 = [
        {
            'tfOpName': 'Bincount',
            'category': 'reduction',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' },
                { 'start': 1, 'name': 'size', 'type': 'number' },
                { 'start': 2, 'name': 'weights', 'type': 'tensor' }
            ]
        },
        {
            'tfOpName': 'DenseBincount',
            'category': 'reduction',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' },
                { 'start': 1, 'name': 'size', 'type': 'number' },
                { 'start': 2, 'name': 'weights', 'type': 'tensor' }
            ],
            'attrs': [
                { 'tfName': 'binary_output', 'name': 'binaryOutput', 'type': 'bool' }
            ]
        },
        {
            // Max/Mean/Min/Sum/All/Any share this (x, axis[]) + keep_dims layout.
            'tfOpName': 'Max',
            'category': 'reduction',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' },
                { 'start': 1, 'name': 'axis', 'type': 'number[]' }
            ],
            'attrs': [
                { 'tfName': 'keep_dims', 'name': 'keepDims', 'type': 'bool' }
            ]
        },
        {
            'tfOpName': 'Mean',
            'category': 'reduction',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' },
                { 'start': 1, 'name': 'axis', 'type': 'number[]' }
            ],
            'attrs': [
                { 'tfName': 'keep_dims', 'name': 'keepDims', 'type': 'bool' }
            ]
        },
        {
            'tfOpName': 'Min',
            'category': 'reduction',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' },
                { 'start': 1, 'name': 'axis', 'type': 'number[]' }
            ],
            'attrs': [
                { 'tfName': 'keep_dims', 'name': 'keepDims', 'type': 'bool' }
            ]
        },
        {
            'tfOpName': 'Sum',
            'category': 'reduction',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' },
                { 'start': 1, 'name': 'axis', 'type': 'number[]' }
            ],
            'attrs': [
                { 'tfName': 'keep_dims', 'name': 'keepDims', 'type': 'bool' }
            ]
        },
        {
            'tfOpName': 'All',
            'category': 'reduction',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' },
                { 'start': 1, 'name': 'axis', 'type': 'number[]' }
            ],
            'attrs': [
                { 'tfName': 'keep_dims', 'name': 'keepDims', 'type': 'bool' }
            ]
        },
        {
            'tfOpName': 'Any',
            'category': 'reduction',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' },
                { 'start': 1, 'name': 'axis', 'type': 'number[]' }
            ],
            'attrs': [
                { 'tfName': 'keep_dims', 'name': 'keepDims', 'type': 'bool' }
            ]
        },
        {
            // ArgMax/ArgMin take a scalar axis (type 'number', not 'number[]').
            'tfOpName': 'ArgMax',
            'category': 'reduction',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' },
                { 'start': 1, 'name': 'axis', 'type': 'number' }
            ]
        },
        {
            'tfOpName': 'ArgMin',
            'category': 'reduction',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' },
                { 'start': 1, 'name': 'axis', 'type': 'number' }
            ]
        },
        {
            'tfOpName': 'Prod',
            'category': 'reduction',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' },
                { 'start': 1, 'name': 'axis', 'type': 'number[]' }
            ],
            'attrs': [
                { 'tfName': 'keep_dims', 'name': 'keepDims', 'type': 'bool' },
                { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
            ]
        },
        {
            // Cumulative ops: scalar axis plus exclusive/reverse flags.
            'tfOpName': 'Cumprod',
            'category': 'reduction',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' },
                { 'start': 1, 'name': 'axis', 'type': 'number' }
            ],
            'attrs': [
                { 'tfName': 'exclusive', 'name': 'exclusive', 'type': 'bool' },
                { 'tfName': 'reverse', 'name': 'reverse', 'type': 'bool' }
            ]
        },
        {
            'tfOpName': 'Cumsum',
            'category': 'reduction',
            'inputs': [
                { 'start': 0, 'name': 'x', 'type': 'tensor' },
                { 'start': 1, 'name': 'axis', 'type': 'number' }
            ],
            'attrs': [
                { 'tfName': 'exclusive', 'name': 'exclusive', 'type': 'bool' },
                { 'tfName': 'reverse', 'name': 'reverse', 'type': 'bool' }
            ]
        }
    ];
62095
62096 var reduction = /*#__PURE__*/Object.freeze({
62097 __proto__: null,
62098 json: json$5
62099 });
62100
62101 /**
62102 * @license
62103 * Copyright 2023 Google LLC. All Rights Reserved.
62104 * Licensed under the Apache License, Version 2.0 (the "License");
62105 * you may not use this file except in compliance with the License.
62106 * You may obtain a copy of the License at
62107 *
62108 * http://www.apache.org/licenses/LICENSE-2.0
62109 *
62110 * Unless required by applicable law or agreed to in writing, software
62111 * distributed under the License is distributed on an "AS IS" BASIS,
62112 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
62113 * See the License for the specific language governing permissions and
62114 * limitations under the License.
62115 * =============================================================================
62116 */
// Op-mapper table for the 'slice_join' category. Each entry maps a TF op name
// ('tfOpName') to parameter metadata consumed by OperationMapper.mapNode:
// 'inputs' describes node-input positions ('start'/'end' become
// inputIndexStart/inputIndexEnd; type 'tensors' spans a range of inputs) and
// 'attrs' maps TF attribute names ('tfName') to tfjs parameter names/types,
// with optional 'defaultValue'. 'notSupported' marks attrs the converter
// recognizes but does not honor.
const json$4 = [
    {
        'tfOpName': 'ConcatV2',
        'category': 'slice_join',
        'inputs': [
            { 'start': 0, 'end': -1, 'name': 'tensors', 'type': 'tensors' },
            { 'start': -1, 'name': 'axis', 'type': 'number' }
        ],
        'attrs': [
            { 'tfName': 'N', 'name': 'n', 'type': 'number', 'defaultValue': 2 }
        ]
    },
    {
        'tfOpName': 'Concat',
        'category': 'slice_join',
        'inputs': [
            { 'start': 1, 'end': 0, 'name': 'tensors', 'type': 'tensors' },
            { 'start': 0, 'name': 'axis', 'type': 'number' }
        ],
        'attrs': [
            { 'tfName': 'N', 'name': 'n', 'type': 'number', 'defaultValue': 2 }
        ]
    },
    {
        'tfOpName': 'GatherV2',
        'category': 'slice_join',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' },
            { 'start': 1, 'name': 'indices', 'type': 'tensor' },
            { 'start': 2, 'name': 'axis', 'type': 'number', 'defaultValue': 0 }
        ],
        'attrs': [
            { 'tfName': 'batch_dims', 'name': 'batchDims', 'type': 'number', 'defaultValue': 0 }
        ]
    },
    {
        'tfOpName': 'Gather',
        'category': 'slice_join',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' },
            { 'start': 1, 'name': 'indices', 'type': 'tensor' }
        ],
        'attrs': [
            { 'tfName': 'validate_indices', 'name': 'validateIndices', 'type': 'bool', 'notSupported': true }
        ]
    },
    {
        'tfOpName': 'Reverse',
        'category': 'slice_join',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' },
            { 'start': 1, 'name': 'dims', 'type': 'bool[]' }
        ]
    },
    {
        'tfOpName': 'ReverseV2',
        'category': 'slice_join',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' },
            { 'start': 1, 'name': 'axis', 'type': 'number[]' }
        ]
    },
    {
        'tfOpName': 'Slice',
        'category': 'slice_join',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' },
            { 'start': 1, 'name': 'begin', 'type': 'number[]' },
            { 'start': 2, 'name': 'size', 'type': 'number[]' }
        ]
    },
    {
        'tfOpName': 'StridedSlice',
        'category': 'slice_join',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' },
            { 'start': 1, 'name': 'begin', 'type': 'number[]' },
            { 'start': 2, 'name': 'end', 'type': 'number[]' },
            { 'start': 3, 'name': 'strides', 'type': 'number[]' }
        ],
        'attrs': [
            { 'tfName': 'begin_mask', 'name': 'beginMask', 'type': 'number', 'defaultValue': 0 },
            { 'tfName': 'end_mask', 'name': 'endMask', 'type': 'number', 'defaultValue': 0 },
            { 'tfName': 'new_axis_mask', 'name': 'newAxisMask', 'type': 'number', 'defaultValue': 0 },
            { 'tfName': 'ellipsis_mask', 'name': 'ellipsisMask', 'type': 'number', 'defaultValue': 0 },
            { 'tfName': 'shrink_axis_mask', 'name': 'shrinkAxisMask', 'type': 'number', 'defaultValue': 0 }
        ]
    },
    {
        'tfOpName': 'Pack',
        'category': 'slice_join',
        'inputs': [
            { 'start': 0, 'end': 0, 'name': 'tensors', 'type': 'tensors' }
        ],
        'attrs': [
            { 'tfName': 'axis', 'name': 'axis', 'type': 'number', 'defaultValue': 0 }
        ]
    },
    {
        'tfOpName': 'Unpack',
        'category': 'slice_join',
        'inputs': [
            { 'start': 0, 'name': 'tensor', 'type': 'tensor' }
        ],
        'attrs': [
            { 'tfName': 'axis', 'name': 'axis', 'type': 'number', 'defaultValue': 0 },
            { 'tfName': 'num', 'name': 'num', 'type': 'number', 'defaultValue': 0, 'notSupported': true }
        ]
    },
    {
        'tfOpName': 'Tile',
        'category': 'slice_join',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' },
            { 'start': 1, 'name': 'reps', 'type': 'number[]' }
        ]
    },
    {
        'tfOpName': 'Split',
        'category': 'slice_join',
        'inputs': [
            { 'start': 0, 'name': 'axis', 'type': 'number', 'defaultValue': 0 },
            { 'start': 1, 'name': 'x', 'type': 'tensor' }
        ],
        'attrs': [
            { 'tfName': 'num_split', 'name': 'numOrSizeSplits', 'type': 'number', 'defaultValue': 1 }
        ]
    },
    {
        'tfOpName': 'SplitV',
        'category': 'slice_join',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' },
            { 'start': 1, 'name': 'numOrSizeSplits', 'type': 'number[]' },
            { 'start': 2, 'name': 'axis', 'type': 'number', 'defaultValue': 0 }
        ]
    },
    {
        'tfOpName': 'ScatterNd',
        'category': 'slice_join',
        'inputs': [
            { 'start': 0, 'name': 'indices', 'type': 'tensor' },
            { 'start': 1, 'name': 'values', 'type': 'tensor' },
            { 'start': 2, 'name': 'shape', 'type': 'number[]' }
        ]
    },
    {
        'tfOpName': 'GatherNd',
        'category': 'slice_join',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' },
            { 'start': 1, 'name': 'indices', 'type': 'tensor' }
        ]
    },
    {
        'tfOpName': 'SparseToDense',
        'category': 'slice_join',
        'inputs': [
            { 'start': 0, 'name': 'sparseIndices', 'type': 'tensor' },
            { 'start': 1, 'name': 'outputShape', 'type': 'number[]' },
            { 'start': 2, 'name': 'sparseValues', 'type': 'tensor' },
            { 'start': 3, 'name': 'defaultValue', 'type': 'tensor' }
        ],
        'attrs': [
            { 'tfName': 'validate_indices', 'name': 'validateIndices', 'type': 'bool', 'defaultValue': false, 'notSupported': true }
        ]
    },
    {
        'tfOpName': 'TensorScatterUpdate',
        'category': 'slice_join',
        'inputs': [
            { 'start': 0, 'name': 'tensor', 'type': 'tensor' },
            { 'start': 1, 'name': 'indices', 'type': 'tensor' },
            { 'start': 2, 'name': 'values', 'type': 'tensor' }
        ]
    }
];

// Frozen module-namespace object; OperationMapper's constructor reads
// `sliceJoin.json` when building its op-name lookup table.
var sliceJoin = /*#__PURE__*/Object.freeze({
    __proto__: null,
    json: json$4
});
62541
62542 /**
62543 * @license
62544 * Copyright 2023 Google LLC. All Rights Reserved.
62545 * Licensed under the Apache License, Version 2.0 (the "License");
62546 * you may not use this file except in compliance with the License.
62547 * You may obtain a copy of the License at
62548 *
62549 * http://www.apache.org/licenses/LICENSE-2.0
62550 *
62551 * Unless required by applicable law or agreed to in writing, software
62552 * distributed under the License is distributed on an "AS IS" BASIS,
62553 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
62554 * See the License for the specific language governing permissions and
62555 * limitations under the License.
62556 * =============================================================================
62557 */
// Op-mapper table for the 'sparse' category; same schema as the other
// category tables (TF op name -> input positions and attr name mappings),
// consumed by OperationMapper via `sparse.json`.
const json$3 = [
    {
        'tfOpName': 'SparseFillEmptyRows',
        'category': 'sparse',
        'inputs': [
            { 'start': 0, 'name': 'indices', 'type': 'tensor' },
            { 'start': 1, 'name': 'values', 'type': 'tensor' },
            { 'start': 2, 'name': 'denseShape', 'type': 'tensor' },
            { 'start': 3, 'name': 'defaultValue', 'type': 'tensor' }
        ]
    },
    {
        'tfOpName': 'SparseReshape',
        'category': 'sparse',
        'inputs': [
            { 'start': 0, 'name': 'inputIndices', 'type': 'tensor' },
            { 'start': 1, 'name': 'inputShape', 'type': 'tensor' },
            { 'start': 2, 'name': 'newShape', 'type': 'tensor' }
        ],
        'attrs': [
            { 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }
        ]
    },
    {
        'tfOpName': 'SparseSegmentMean',
        'category': 'sparse',
        'inputs': [
            { 'start': 0, 'name': 'data', 'type': 'tensor' },
            { 'start': 1, 'name': 'indices', 'type': 'tensor' },
            { 'start': 2, 'name': 'segmentIds', 'type': 'tensor' }
        ]
    },
    {
        'tfOpName': 'SparseSegmentSum',
        'category': 'sparse',
        'inputs': [
            { 'start': 0, 'name': 'data', 'type': 'tensor' },
            { 'start': 1, 'name': 'indices', 'type': 'tensor' },
            { 'start': 2, 'name': 'segmentIds', 'type': 'tensor' }
        ]
    }
];

// Frozen namespace read as `sparse.json` by OperationMapper's constructor.
var sparse = /*#__PURE__*/Object.freeze({
    __proto__: null,
    json: json$3
});
62662
62663 /**
62664 * @license
62665 * Copyright 2023 Google LLC. All Rights Reserved.
62666 * Licensed under the Apache License, Version 2.0 (the "License");
62667 * you may not use this file except in compliance with the License.
62668 * You may obtain a copy of the License at
62669 *
62670 * http://www.apache.org/licenses/LICENSE-2.0
62671 *
62672 * Unless required by applicable law or agreed to in writing, software
62673 * distributed under the License is distributed on an "AS IS" BASIS,
62674 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
62675 * See the License for the specific language governing permissions and
62676 * limitations under the License.
62677 * =============================================================================
62678 */
// Op-mapper table for the 'spectral' (FFT-family) category; same schema as
// the other category tables, consumed by OperationMapper via `spectral.json`.
const json$2 = [
    {
        'tfOpName': 'FFT',
        'category': 'spectral',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' }
        ]
    },
    {
        'tfOpName': 'IFFT',
        'category': 'spectral',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' }
        ]
    },
    {
        'tfOpName': 'RFFT',
        'category': 'spectral',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' },
            { 'start': 1, 'name': 'fft_length', 'type': 'number', 'notSupported': true }
        ]
    },
    {
        'tfOpName': 'IRFFT',
        'category': 'spectral',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' },
            { 'start': 1, 'name': 'fft_length', 'type': 'number', 'notSupported': true }
        ]
    }
];

// Frozen namespace read as `spectral.json` by OperationMapper's constructor.
var spectral = /*#__PURE__*/Object.freeze({
    __proto__: null,
    json: json$2
});
62742
62743 /**
62744 * @license
62745 * Copyright 2023 Google LLC. All Rights Reserved.
62746 * Licensed under the Apache License, Version 2.0 (the "License");
62747 * you may not use this file except in compliance with the License.
62748 * You may obtain a copy of the License at
62749 *
62750 * http://www.apache.org/licenses/LICENSE-2.0
62751 *
62752 * Unless required by applicable law or agreed to in writing, software
62753 * distributed under the License is distributed on an "AS IS" BASIS,
62754 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
62755 * See the License for the specific language governing permissions and
62756 * limitations under the License.
62757 * =============================================================================
62758 */
// Op-mapper table for the 'string' category; same schema as the other
// category tables. 'outputs' lists the named output tensors of multi-output
// ops so "node:outputName" input references can be resolved to an index
// (see OperationMapper.transformGraph).
const json$1 = [
    {
        'tfOpName': 'StaticRegexReplace',
        'category': 'string',
        'inputs': [
            { 'start': 0, 'name': 'input', 'type': 'tensor' }
        ],
        'attrs': [
            { 'tfName': 'pattern', 'name': 'pattern', 'type': 'string' },
            { 'tfName': 'rewrite', 'name': 'rewrite', 'type': 'string' },
            { 'tfName': 'replace_global', 'name': 'replaceGlobal', 'type': 'bool' }
        ]
    },
    {
        'tfOpName': 'StringNGrams',
        'category': 'string',
        'inputs': [
            { 'start': 0, 'name': 'data', 'type': 'tensor' },
            { 'start': 1, 'name': 'dataSplits', 'type': 'tensor' }
        ],
        'attrs': [
            { 'tfName': 'separator', 'name': 'separator', 'type': 'string' },
            { 'tfName': 'ngram_widths', 'name': 'nGramWidths', 'type': 'number[]' },
            { 'tfName': 'left_pad', 'name': 'leftPad', 'type': 'string' },
            { 'tfName': 'right_pad', 'name': 'rightPad', 'type': 'string' },
            { 'tfName': 'pad_width', 'name': 'padWidth', 'type': 'number' },
            { 'tfName': 'preserve_short_sequences', 'name': 'preserveShortSequences', 'type': 'bool' }
        ],
        'outputs': [
            'ngrams',
            'ngrams_splits'
        ]
    },
    {
        'tfOpName': 'StringSplit',
        'category': 'string',
        'inputs': [
            { 'start': 0, 'name': 'input', 'type': 'tensor' },
            { 'start': 1, 'name': 'delimiter', 'type': 'tensor' }
        ],
        'attrs': [
            { 'tfName': 'skip_empty', 'name': 'skipEmpty', 'type': 'bool' }
        ],
        'outputs': [
            'indices',
            'values',
            'shape'
        ]
    },
    {
        'tfOpName': 'StringToHashBucketFast',
        'category': 'string',
        'inputs': [
            { 'start': 0, 'name': 'input', 'type': 'tensor' }
        ],
        'attrs': [
            { 'tfName': 'num_buckets', 'name': 'numBuckets', 'type': 'number' }
        ]
    }
];

// Frozen namespace read as `string.json` by OperationMapper's constructor.
var string = /*#__PURE__*/Object.freeze({
    __proto__: null,
    json: json$1
});
62892
62893 /**
62894 * @license
62895 * Copyright 2023 Google LLC. All Rights Reserved.
62896 * Licensed under the Apache License, Version 2.0 (the "License");
62897 * you may not use this file except in compliance with the License.
62898 * You may obtain a copy of the License at
62899 *
62900 * http://www.apache.org/licenses/LICENSE-2.0
62901 *
62902 * Unless required by applicable law or agreed to in writing, software
62903 * distributed under the License is distributed on an "AS IS" BASIS,
62904 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
62905 * See the License for the specific language governing permissions and
62906 * limitations under the License.
62907 * =============================================================================
62908 */
// Op-mapper table for the 'transformation' category (reshapes, pads,
// space/batch rearrangements, broadcasts); same schema as the other
// category tables, consumed by OperationMapper via `transformation.json`.
const json = [
    {
        'tfOpName': 'Cast',
        'category': 'transformation',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' }
        ],
        'attrs': [
            { 'tfName': 'SrcT', 'name': 'sdtype', 'type': 'dtype', 'notSupported': true },
            { 'tfName': 'DstT', 'name': 'dtype', 'type': 'dtype' }
        ]
    },
    {
        'tfOpName': 'ExpandDims',
        'category': 'transformation',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' },
            { 'start': 1, 'name': 'axis', 'type': 'number' }
        ]
    },
    {
        'tfOpName': 'MirrorPad',
        'category': 'transformation',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' },
            { 'start': 1, 'name': 'padding', 'type': 'number[]' }
        ],
        'attrs': [
            { 'tfName': 'mode', 'name': 'mode', 'type': 'string' }
        ]
    },
    {
        'tfOpName': 'Pad',
        'category': 'transformation',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' },
            { 'start': 1, 'name': 'padding', 'type': 'number[]' }
        ],
        'attrs': [
            { 'tfName': 'constant_value', 'name': 'constantValue', 'type': 'number', 'defaultValue': 0 }
        ]
    },
    {
        'tfOpName': 'PadV2',
        'category': 'transformation',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' },
            { 'start': 1, 'name': 'padding', 'type': 'number[]' },
            { 'start': 2, 'name': 'constantValue', 'type': 'number', 'defaultValue': 0 }
        ]
    },
    {
        'tfOpName': 'Reshape',
        'category': 'transformation',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' },
            { 'start': 1, 'name': 'shape', 'type': 'number[]' }
        ]
    },
    {
        'tfOpName': 'EnsureShape',
        'category': 'transformation',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' },
            { 'start': 1, 'name': 'shape', 'type': 'number[]' }
        ]
    },
    {
        'tfOpName': 'Squeeze',
        'category': 'transformation',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' }
        ],
        'attrs': [
            // 'squeeze_dims' is the legacy TF attr name; mapNode falls back
            // to 'tfDeprecatedName' when the primary name is absent.
            { 'tfName': 'axis', 'tfDeprecatedName': 'squeeze_dims', 'name': 'axis', 'type': 'number[]' }
        ]
    },
    {
        'tfOpName': 'SpaceToBatchND',
        'category': 'transformation',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' },
            { 'start': 1, 'name': 'blockShape', 'type': 'number[]' },
            { 'start': 2, 'name': 'paddings', 'type': 'number[]' }
        ]
    },
    {
        'tfOpName': 'BatchToSpaceND',
        'category': 'transformation',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' },
            { 'start': 1, 'name': 'blockShape', 'type': 'number[]' },
            { 'start': 2, 'name': 'crops', 'type': 'number[]' }
        ]
    },
    {
        'tfOpName': 'DepthToSpace',
        'category': 'transformation',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' }
        ],
        'attrs': [
            { 'tfName': 'block_size', 'name': 'blockSize', 'type': 'number' },
            { 'tfName': 'data_format', 'name': 'dataFormat', 'type': 'string' }
        ]
    },
    {
        'tfOpName': 'BroadcastTo',
        'category': 'transformation',
        'inputs': [
            { 'start': 0, 'name': 'x', 'type': 'tensor' },
            { 'start': 1, 'name': 'shape', 'type': 'number[]' }
        ],
        'attrs': []
    },
    {
        'tfOpName': 'BroadcastArgs',
        'category': 'transformation',
        'inputs': [
            { 'start': 0, 'name': 's0', 'type': 'tensor' },
            { 'start': 1, 'name': 's1', 'type': 'tensor' }
        ],
        'attrs': []
    }
];

// Frozen namespace read as `transformation.json` by OperationMapper's
// constructor.
var transformation = /*#__PURE__*/Object.freeze({
    __proto__: null,
    json: json
});
63175
63176 /**
63177 * @license
63178 * Copyright 2018 Google LLC. All Rights Reserved.
63179 * Licensed under the Apache License, Version 2.0 (the "License");
63180 * you may not use this file except in compliance with the License.
63181 * You may obtain a copy of the License at
63182 *
63183 * http://www.apache.org/licenses/LICENSE-2.0
63184 *
63185 * Unless required by applicable law or agreed to in writing, software
63186 * distributed under the License is distributed on an "AS IS" BASIS,
63187 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
63188 * See the License for the specific language governing permissions and
63189 * limitations under the License.
63190 * =============================================================================
63191 */
// Translates a TensorFlow GraphDef (and any embedded FunctionDefs) into the
// converter's internal node-graph representation, using the per-category op
// mapper tables declared above. Used as a lazily-created singleton.
class OperationMapper {
    // Singleton instance for the mapper
    static get Instance() {
        return this._instance || (this._instance = new this());
    }
    // Loads the op mapping from the JSON file.
    constructor() {
        const ops = [
            arithmetic, basicMath, control, convolution, creation, dynamic,
            evaluation, graph, hashTable, image, logical, matrices, normalization,
            reduction, sliceJoin, sparse, spectral, string, transformation
        ];
        // Flatten all category tables into one tfOpName -> mapper lookup.
        const mappersJson = [].concat(...ops.map(op => op.json));
        this.opMappers = mappersJson.reduce((map, mapper) => {
            map[mapper.tfOpName] = mapper;
            return map;
        }, {});
    }
    // Converts the model inference graph from Tensorflow GraphDef to local
    // representation for TensorFlow.js API
    transformGraph(graph, signature = {}) {
        const tfNodes = graph.node;
        const placeholders = [];
        const weights = [];
        const initNodes = [];
        // Map each GraphDef node, bucketing Placeholder* ops, Const ops
        // (weights) and input-less ops (treated as initialization nodes).
        const nodes = tfNodes.reduce((map, node) => {
            map[node.name] = this.mapNode(node);
            if (node.op.startsWith('Placeholder')) {
                placeholders.push(map[node.name]);
            }
            else if (node.op === 'Const') {
                weights.push(map[node.name]);
            }
            else if (node.input == null || node.input.length === 0) {
                initNodes.push(map[node.name]);
            }
            return map;
        }, {});
        let inputs = [];
        const outputs = [];
        let inputNodeNameToKey = {};
        let outputNodeNameToKey = {};
        if (signature != null) {
            inputNodeNameToKey = this.mapSignatureEntries(signature.inputs);
            outputNodeNameToKey = this.mapSignatureEntries(signature.outputs);
        }
        const allNodes = Object.keys(nodes);
        // Wire up the graph edges: for every node, resolve each input name to
        // its producer node and record the parent/child relationship.
        allNodes.forEach(key => {
            const node = nodes[key];
            node.inputNames.forEach((name, index) => {
                const [nodeName, , outputName] = getNodeNameAndIndex(name);
                const inputNode = nodes[nodeName];
                if (inputNode.outputs != null) {
                    // The producer has named outputs; rewrite "node:outputName"
                    // references to positional "node:index" form.
                    const outputIndex = inputNode.outputs.indexOf(outputName);
                    if (outputIndex !== -1) {
                        const inputName = `${nodeName}:${outputIndex}`;
                        // update the input name to use the mapped output index directly.
                        node.inputNames[index] = inputName;
                    }
                }
                node.inputs.push(inputNode);
                inputNode.children.push(node);
            });
        });
        // If the signature defines no outputs, treat every leaf node (a node
        // with no children) as a graph output.
        if (Object.keys(outputNodeNameToKey).length === 0) {
            allNodes.forEach(key => {
                const node = nodes[key];
                if (node.children.length === 0) {
                    outputs.push(node);
                }
            });
        }
        else {
            // Otherwise resolve each signature output name to its node and tag
            // the node with its signature key.
            Object.keys(outputNodeNameToKey).forEach(name => {
                const [nodeName,] = getNodeNameAndIndex(name);
                const node = nodes[nodeName];
                if (node != null) {
                    node.signatureKey = outputNodeNameToKey[name];
                    outputs.push(node);
                }
            });
        }
        // Inputs come from the signature when present, otherwise fall back to
        // the Placeholder nodes collected above.
        if (Object.keys(inputNodeNameToKey).length > 0) {
            Object.keys(inputNodeNameToKey).forEach(name => {
                const [nodeName,] = getNodeNameAndIndex(name);
                const node = nodes[nodeName];
                if (node) {
                    node.signatureKey = inputNodeNameToKey[name];
                    inputs.push(node);
                }
            });
        }
        else {
            inputs = placeholders;
        }
        // Convert any FunctionDefs in the graph's function library.
        let functions = {};
        if (graph.library != null && graph.library.function != null) {
            functions = graph.library.function.reduce((functions, func) => {
                functions[func.signature.name] = this.mapFunction(func);
                return functions;
            }, {});
        }
        const result = { nodes, inputs, outputs, weights, placeholders, signature, functions };
        if (initNodes.length > 0) {
            result.initNodes = initNodes;
        }
        return result;
    }
    // Inverts a signature's {key: TensorInfo} map into {tensorName: key},
    // so nodes can be looked up by their graph tensor name.
    mapSignatureEntries(entries) {
        return Object.keys(entries || {})
            .reduce((prev, curr) => {
            prev[entries[curr].name] = curr;
            return prev;
        }, {});
    }
    // Converts one GraphDef NodeDef into the internal node representation,
    // resolving its input/attr parameters through the op mapper tables.
    mapNode(node) {
        // Unsupported ops will cause an error at run-time (not parse time), since
        // they may not be used by the actual execution subgraph.
        const mapper = getRegisteredOp(node.op) || this.opMappers[node.op] || {};
        if (node.attr == null) {
            node.attr = {};
        }
        const newNode = {
            name: node.name,
            op: node.op,
            category: mapper.category,
            // '^name' inputs are control dependencies; strip the marker.
            inputNames: (node.input ||
                []).map(input => input.startsWith('^') ? input.slice(1) : input),
            inputs: [],
            children: [],
            inputParams: {},
            attrParams: {},
            rawAttrs: node.attr,
            outputs: mapper.outputs
        };
        // Translate the mapper's input descriptors into inputParams
        // (type plus the start/end input-index range).
        if (mapper.inputs != null) {
            newNode.inputParams =
                mapper.inputs.reduce((map, param) => {
                    map[param.name] = {
                        type: param.type,
                        inputIndexStart: param.start,
                        inputIndexEnd: param.end
                    };
                    return map;
                }, {});
        }
        // Parse each declared attr out of the raw NodeDef attrs, falling back
        // to the deprecated TF attr name when the primary one is absent.
        if (mapper.attrs != null) {
            newNode.attrParams =
                mapper.attrs.reduce((map, param) => {
                    const type = param.type;
                    let value = undefined;
                    switch (param.type) {
                        case 'string':
                            value = getStringParam(node.attr, param.tfName, param.defaultValue);
                            if (value === undefined && !!param.tfDeprecatedName) {
                                value = getStringParam(node.attr, param.tfDeprecatedName, param.defaultValue);
                            }
                            break;
                        case 'string[]':
                            value = getStringArrayParam(node.attr, param.tfName, param.defaultValue);
                            if (value === undefined && !!param.tfDeprecatedName) {
                                value = getStringArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
                            }
                            break;
                        case 'number':
                            value = getNumberParam(node.attr, param.tfName, (param.defaultValue || 0));
                            if (value === undefined && !!param.tfDeprecatedName) {
                                value = getNumberParam(node.attr, param.tfDeprecatedName, param.defaultValue);
                            }
                            break;
                        case 'number[]':
                            value = getNumericArrayParam(node.attr, param.tfName, param.defaultValue);
                            if (value === undefined && !!param.tfDeprecatedName) {
                                value = getNumericArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
                            }
                            break;
                        case 'bool':
                            value = getBoolParam(node.attr, param.tfName, param.defaultValue);
                            if (value === undefined && !!param.tfDeprecatedName) {
                                value = getBoolParam(node.attr, param.tfDeprecatedName, param.defaultValue);
                            }
                            break;
                        case 'bool[]':
                            value = getBoolArrayParam(node.attr, param.tfName, param.defaultValue);
                            if (value === undefined && !!param.tfDeprecatedName) {
                                value = getBoolArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
                            }
                            break;
                        case 'shape':
                            value = getTensorShapeParam(node.attr, param.tfName, param.defaultValue);
                            if (value === undefined && !!param.tfDeprecatedName) {
                                value = getTensorShapeParam(node.attr, param.tfDeprecatedName, param.defaultValue);
                            }
                            break;
                        case 'shape[]':
                            value = getTensorShapeArrayParam(node.attr, param.tfName, param.defaultValue);
                            if (value === undefined && !!param.tfDeprecatedName) {
                                value = getTensorShapeArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
                            }
                            break;
                        case 'dtype':
                            value = getDtypeParam(node.attr, param.tfName, param.defaultValue);
                            if (value === undefined && !!param.tfDeprecatedName) {
                                value = getDtypeParam(node.attr, param.tfDeprecatedName, param.defaultValue);
                            }
                            break;
                        case 'dtype[]':
                            value = getDtypeArrayParam(node.attr, param.tfName, param.defaultValue);
                            if (value === undefined && !!param.tfDeprecatedName) {
                                value = getDtypeArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
                            }
                            break;
                        case 'func':
                            value = getFuncParam(node.attr, param.tfName, param.defaultValue);
                            if (value === undefined && !!param.tfDeprecatedName) {
                                value = getFuncParam(node.attr, param.tfDeprecatedName, param.defaultValue);
                            }
                            break;
                        case 'tensor':
                        case 'tensors':
                            // Tensor-valued params come from node inputs, not attrs.
                            break;
                        default:
                            throw new Error(`Unsupported param type: ${param.type} for op: ${node.op}`);
                    }
                    map[param.name] = { value, type };
                    return map;
                }, {});
        }
        return newNode;
    }
    // map the TFunctionDef to TFJS graph object
    mapFunction(functionDef) {
        const tfNodes = functionDef.nodeDef;
        const placeholders = [];
        const weights = [];
        let nodes = {};
        if (tfNodes != null) {
            nodes = tfNodes.reduce((map, node) => {
                map[node.name] = this.mapNode(node);
                if (node.op === 'Const') {
                    weights.push(map[node.name]);
                }
                return map;
            }, {});
        }
        const inputs = [];
        const outputs = [];
        // Synthesize a Placeholder node for each declared function input arg.
        functionDef.signature.inputArg.forEach(arg => {
            const [nodeName,] = getNodeNameAndIndex(arg.name);
            const node = {
                name: nodeName,
                op: 'Placeholder',
                inputs: [],
                inputNames: [],
                category: 'graph',
                inputParams: {},
                attrParams: { dtype: { value: parseDtypeParam(arg.type), type: 'dtype' } },
                children: []
            };
            node.signatureKey = arg.name;
            inputs.push(node);
            nodes[nodeName] = node;
        });
        const allNodes = Object.keys(nodes);
        // Wire up edges exactly as transformGraph does for the main graph.
        allNodes.forEach(key => {
            const node = nodes[key];
            node.inputNames.forEach((name, index) => {
                const [nodeName, , outputName] = getNodeNameAndIndex(name);
                const inputNode = nodes[nodeName];
                if (inputNode.outputs != null) {
                    const outputIndex = inputNode.outputs.indexOf(outputName);
                    if (outputIndex !== -1) {
                        const inputName = `${nodeName}:${outputIndex}`;
                        // update the input name to use the mapped output index directly.
                        node.inputNames[index] = inputName;
                    }
                }
                node.inputs.push(inputNode);
                inputNode.children.push(node);
            });
        });
        // Resolve each declared output arg through the function's return map
        // (output arg name -> producing "node:index" reference).
        const returnNodeMap = functionDef.ret;
        functionDef.signature.outputArg.forEach(output => {
            const [nodeName, index] = getNodeNameAndIndex(returnNodeMap[output.name]);
            const node = nodes[nodeName];
            if (node != null) {
                node.defaultOutput = index;
                outputs.push(node);
            }
        });
        const signature = this.mapArgsToSignature(functionDef);
        return { nodes, inputs, outputs, weights, placeholders, signature };
    }
    // Builds a SignatureDef-like object from a FunctionDef's declared
    // input/output args.
    mapArgsToSignature(functionDef) {
        return {
            methodName: functionDef.signature.name,
            inputs: functionDef.signature.inputArg.reduce((map, arg) => {
                map[arg.name] = this.mapArgToTensorInfo(arg);
                return map;
            }, {}),
            outputs: functionDef.signature.outputArg.reduce((map, arg) => {
                map[arg.name] = this.mapArgToTensorInfo(arg, functionDef.ret);
                return map;
            }, {}),
        };
    }
    // Converts one function arg into {name, dtype}, optionally renaming it
    // through the function's return map.
    mapArgToTensorInfo(arg, nameMap) {
        let name = arg.name;
        if (nameMap != null) {
            name = nameMap[name];
        }
        return { name, dtype: arg.type };
    }
}
/**
 * Decodes a base64 string using whichever facility the current environment
 * provides: the global `atob` (browsers) or Node's `Buffer`.
 * @param {string} text base64-encoded input.
 * @returns {string} the decoded string.
 * @throws {Error} when neither `atob` nor `Buffer` is available.
 */
function decodeBase64(text) {
    const global = env().global;
    if (typeof global.atob !== 'undefined') {
        return global.atob(text);
    }
    else if (typeof Buffer !== 'undefined') {
        // `new Buffer(...)` is deprecated (Node DEP0005) and emits runtime
        // warnings; `Buffer.from` is the supported, equivalent replacement.
        return Buffer.from(text, 'base64').toString();
    }
    else {
        throw new Error('Unable to decode base64 in this environment. ' +
            'Missing built-in atob() or Buffer()');
    }
}
// Decodes a serialized string attr value. `s` is either an array of char
// codes or a base64-encoded string; the result is lowercased unless
// `keepCase` is true.
function parseStringParam(s, keepCase) {
    let decoded;
    if (Array.isArray(s)) {
        decoded = String.fromCharCode.apply(null, s);
    }
    else {
        decoded = decodeBase64(s);
    }
    if (keepCase) {
        return decoded;
    }
    return decoded.toLowerCase();
}
// Reads a string attr by name, decoding its `s` field; returns `def` when
// the attr is absent. Case is folded unless `keepCase` is set.
function getStringParam(attrs, name, def, keepCase = false) {
    const attr = attrs[name];
    if (attr == null) {
        return def;
    }
    return parseStringParam(attr.s, keepCase);
}
// Reads a boolean attr by name (`b` field); returns `def` when the attr is
// missing (or falsy).
function getBoolParam(attrs, name, def) {
    const attr = attrs[name];
    if (attr) {
        return attr.b;
    }
    return def;
}
// Reads a numeric attr by name: prefers the integer field `i`, then the
// float field `f`, then `def`. Int64 values arrive as decimal strings and
// are parsed to a JS number.
function getNumberParam(attrs, name, def) {
    const attr = attrs[name] || {};
    let raw;
    if (attr['i'] != null) {
        raw = attr['i'];
    }
    else if (attr['f'] != null) {
        raw = attr['f'];
    }
    else {
        raw = def;
    }
    if (typeof raw === 'number') {
        return raw;
    }
    return parseInt(raw, 10);
}
// Maps a TF DataType enum value (or its string name) to the corresponding
// tfjs dtype string; returns null for unrecognized dtypes so the error
// surfaces at runtime only if the node is actually executed.
function parseDtypeParam(value) {
    if (typeof (value) === 'string') {
        // Resolve enum names serialized as strings to their numeric value.
        // tslint:disable-next-line:no-any
        value = DataType[value];
    }
    switch (value) {
        case DataType.DT_FLOAT:
        case DataType.DT_HALF:
        case DataType.DT_DOUBLE:
            return 'float32';
        case DataType.DT_INT32:
        case DataType.DT_INT64:
        case DataType.DT_INT8:
        case DataType.DT_UINT8:
            return 'int32';
        case DataType.DT_BOOL:
            return 'bool';
        case DataType.DT_STRING:
            return 'string';
        case DataType.DT_COMPLEX64:
        case DataType.DT_COMPLEX128:
            return 'complex64';
        default:
            // Unknown dtype error will happen at runtime (instead of parse time),
            // since these nodes might not be used by the actual subgraph execution.
            return null;
    }
}
63570 function getFuncParam(attrs, name, def) {
63571 const param = attrs[name];
63572 if (param && param.func) {
63573 return param.func.name;
63574 }
63575 return def;
63576 }
63577 function getDtypeParam(attrs, name, def) {
63578 const param = attrs[name];
63579 if (param && param.type) {
63580 return parseDtypeParam(param.type);
63581 }
63582 return def;
63583 }
63584 function getDtypeArrayParam(attrs, name, def) {
63585 const param = attrs[name];
63586 if (param && param.list && param.list.type) {
63587 return param.list.type.map(v => parseDtypeParam(v));
63588 }
63589 return def;
63590 }
63591 function parseTensorShapeParam(shape) {
63592 if (shape.unknownRank) {
63593 return undefined;
63594 }
63595 if (shape.dim != null) {
63596 return shape.dim.map(dim => (typeof dim.size === 'number') ? dim.size : parseInt(dim.size, 10));
63597 }
63598 return [];
63599 }
63600 function getTensorShapeParam(attrs, name, def) {
63601 const param = attrs[name];
63602 if (param && param.shape) {
63603 return parseTensorShapeParam(param.shape);
63604 }
63605 return def;
63606 }
63607 function getNumericArrayParam(attrs, name, def) {
63608 const param = attrs[name];
63609 if (param) {
63610 return ((param.list.f && param.list.f.length ? param.list.f :
63611 param.list.i) ||
63612 [])
63613 .map(v => (typeof v === 'number') ? v : parseInt(v, 10));
63614 }
63615 return def;
63616 }
63617 function getStringArrayParam(attrs, name, def, keepCase = false) {
63618 const param = attrs[name];
63619 if (param && param.list && param.list.s) {
63620 return param.list.s.map((v) => {
63621 return parseStringParam(v, keepCase);
63622 });
63623 }
63624 return def;
63625 }
63626 function getTensorShapeArrayParam(attrs, name, def) {
63627 const param = attrs[name];
63628 if (param && param.list && param.list.shape) {
63629 return param.list.shape.map((v) => {
63630 return parseTensorShapeParam(v);
63631 });
63632 }
63633 return def;
63634 }
63635 function getBoolArrayParam(attrs, name, def) {
63636 const param = attrs[name];
63637 if (param && param.list && param.list.b) {
63638 return param.list.b;
63639 }
63640 return def;
63641 }
63642
63643 /**
63644 * @license
63645 * Copyright 2019 Google LLC. All Rights Reserved.
63646 * Licensed under the Apache License, Version 2.0 (the "License");
63647 * you may not use this file except in compliance with the License.
63648 * You may obtain a copy of the License at
63649 *
63650 * http://www.apache.org/licenses/LICENSE-2.0
63651 *
63652 * Unless required by applicable law or agreed to in writing, software
63653 * distributed under the License is distributed on an "AS IS" BASIS,
63654 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
63655 * See the License for the specific language governing permissions and
63656 * limitations under the License.
63657 * =============================================================================
63658 */
63659 /**
63660 * Helper class for lookup inputs and params for nodes in the model graph.
63661 */
    class NodeValueImpl {
        /**
         * @param node Graph node whose inputs/attrs should be resolved.
         * @param tensorMap Map of tensor name to tensor value.
         * @param context Execution context used to resolve tensors.
         */
        constructor(node, tensorMap, context) {
            this.node = node;
            this.tensorMap = tensorMap;
            this.context = context;
            this.inputs = [];
            this.attrs = {};
            // Resolve all input tensors eagerly, in declared input order.
            this.inputs = node.inputNames.map(name => this.getInput(name));
            // Materialize every raw attribute into a plain JS value up front.
            if (node.rawAttrs != null) {
                this.attrs = Object.keys(node.rawAttrs)
                    .reduce((attrs, key) => {
                    attrs[key] = this.getAttr(key);
                    return attrs;
                }, {});
            }
        }
        /**
         * Resolve the named input tensor from the tensor map / context.
         * @param name String: name of the input tensor.
         */
        getInput(name) {
            return getTensor(name, this.tensorMap, this.context);
        }
        /**
         * Return the decoded value of a raw attribute, dispatching on which
         * field of the attr value is populated. NOTE: the check order matters
         * (tensor first, then scalar kinds, then list kinds); an attr that
         * matches no known field falls through to `defaultValue`.
         * @param name String: name of the attribute.
         * @param defaultValue Returned when the attr kind is unrecognized.
         */
        getAttr(name, defaultValue) {
            const value = this.node.rawAttrs[name];
            if (value.tensor != null) {
                return getTensor(name, this.tensorMap, this.context);
            }
            if (value.i != null || value.f != null) {
                return getNumberParam(this.node.rawAttrs, name, defaultValue);
            }
            if (value.s != null) {
                return getStringParam(this.node.rawAttrs, name, defaultValue);
            }
            if (value.b != null) {
                return getBoolParam(this.node.rawAttrs, name, defaultValue);
            }
            if (value.shape != null) {
                return getTensorShapeParam(this.node.rawAttrs, name, defaultValue);
            }
            if (value.type != null) {
                return getDtypeParam(this.node.rawAttrs, name, defaultValue);
            }
            if (value.list != null) {
                if (value.list.i != null || value.list.f != null) {
                    return getNumericArrayParam(this.node.rawAttrs, name, defaultValue);
                }
                if (value.list.s != null) {
                    return getStringArrayParam(this.node.rawAttrs, name, defaultValue);
                }
                if (value.list.shape != null) {
                    return getTensorShapeArrayParam(this.node.rawAttrs, name, defaultValue);
                }
                if (value.list.b != null) {
                    return getBoolArrayParam(this.node.rawAttrs, name, defaultValue);
                }
                if (value.list.type != null) {
                    return getDtypeArrayParam(this.node.rawAttrs, name, defaultValue);
                }
            }
            return defaultValue;
        }
    }
63729
63730 /**
63731 * @license
63732 * Copyright 2020 Google LLC. All Rights Reserved.
63733 * Licensed under the Apache License, Version 2.0 (the "License");
63734 * you may not use this file except in compliance with the License.
63735 * You may obtain a copy of the License at
63736 *
63737 * http://www.apache.org/licenses/LICENSE-2.0
63738 *
63739 * Unless required by applicable law or agreed to in writing, software
63740 * distributed under the License is distributed on an "AS IS" BASIS,
63741 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
63742 * See the License for the specific language governing permissions and
63743 * limitations under the License.
63744 * =============================================================================
63745 */
63746
    // Frozen namespace collecting the TF.js op implementations under their
    // public names; used as the default `ops` argument of the executeOp
    // implementations below.
    var tfOps = /*#__PURE__*/Object.freeze({
        __proto__: null,
        OP_SCOPE_SUFFIX: OP_SCOPE_SUFFIX,
        abs: abs$2,
        acos: acos$2,
        acosh: acosh$2,
        add: add$3,
        addN: addN$2,
        all: all$2,
        any: any$2,
        argMax: argMax$2,
        argMin: argMin$2,
        asin: asin$2,
        asinh: asinh$2,
        atan: atan$2,
        atan2: atan2$2,
        atanh: atanh$2,
        avgPool: avgPool$2,
        avgPool3d: avgPool3d$1,
        basicLSTMCell: basicLSTMCell,
        batchNorm: batchNorm$2,
        batchNorm2d: batchNorm2d,
        batchNorm3d: batchNorm3d,
        batchNorm4d: batchNorm4d,
        batchToSpaceND: batchToSpaceND$2,
        bincount: bincount$2,
        bitwiseAnd: bitwiseAnd$2,
        booleanMaskAsync: booleanMaskAsync,
        broadcastArgs: broadcastArgs$2,
        broadcastTo: broadcastTo,
        buffer: buffer,
        cast: cast$3,
        ceil: ceil$2,
        clipByValue: clipByValue$2,
        clone: clone,
        complex: complex$2,
        concat: concat$2,
        concat1d: concat1d,
        concat2d: concat2d,
        concat3d: concat3d,
        concat4d: concat4d,
        conv1d: conv1d$2,
        conv2d: conv2d$4,
        conv2dTranspose: conv2dTranspose$1,
        conv3d: conv3d$2,
        conv3dTranspose: conv3dTranspose$1,
        cos: cos$2,
        cosh: cosh$2,
        cosineWindow: cosineWindow,
        cumprod: cumprod$2,
        cumsum: cumsum$2,
        denseBincount: denseBincount$2,
        depthToSpace: depthToSpace$2,
        depthwiseConv2d: depthwiseConv2d$3,
        diag: diag$2,
        dilation2d: dilation2d,
        div: div$1,
        divNoNan: divNoNan,
        dot: dot$2,
        dropout: dropout$2,
        einsum: einsum$2,
        elu: elu$4,
        enclosingPowerOfTwo: enclosingPowerOfTwo,
        ensureShape: ensureShape,
        equal: equal$2,
        erf: erf$2,
        euclideanNorm: euclideanNorm,
        exp: exp$2,
        expandDims: expandDims$3,
        expm1: expm1$2,
        eye: eye,
        fft: fft$2,
        fill: fill$2,
        floor: floor$2,
        floorDiv: floorDiv$2,
        fused: fused_ops,
        gather: gather$1,
        gatherND: gatherND,
        greater: greater$3,
        greaterEqual: greaterEqual$2,
        ifft: ifft$2,
        imag: imag$2,
        image: image$1,
        inTopKAsync: inTopKAsync,
        irfft: irfft,
        isFinite: isFinite$3,
        isInf: isInf$2,
        isNaN: isNaN$3,
        leakyRelu: leakyRelu$2,
        less: less$3,
        lessEqual: lessEqual$2,
        linalg: linalg,
        linspace: linspace,
        localResponseNormalization: localResponseNormalization,
        log: log$2,
        log1p: log1p$2,
        logSigmoid: logSigmoid,
        logSoftmax: logSoftmax,
        logSumExp: logSumExp,
        logicalAnd: logicalAnd$2,
        logicalNot: logicalNot$2,
        logicalOr: logicalOr$2,
        logicalXor: logicalXor,
        losses: losses,
        lowerBound: lowerBound$1,
        matMul: matMul$1,
        max: max$3,
        maxPool: maxPool$2,
        maxPool3d: maxPool3d$1,
        maxPoolWithArgmax: maxPoolWithArgmax,
        maximum: maximum$4,
        mean: mean$3,
        meshgrid: meshgrid,
        min: min$3,
        minimum: minimum$4,
        mirrorPad: mirrorPad$1,
        mod: mod$2,
        moments: moments,
        movingAverage: movingAverage,
        mul: mul,
        multiRNNCell: multiRNNCell,
        multinomial: multinomial$2,
        neg: neg$2,
        norm: norm,
        notEqual: notEqual$2,
        oneHot: oneHot$3,
        ones: ones$1,
        onesLike: onesLike$3,
        op: op,
        outerProduct: outerProduct,
        pad: pad,
        pad1d: pad1d,
        pad2d: pad2d,
        pad3d: pad3d,
        pad4d: pad4d,
        pool: pool$1,
        pow: pow$3,
        prelu: prelu$3,
        print: print,
        prod: prod$2,
        raggedGather: raggedGather$2,
        raggedRange: raggedRange$2,
        raggedTensorToTensor: raggedTensorToTensor$2,
        rand: rand,
        randomGamma: randomGamma,
        randomNormal: randomNormal$2,
        randomStandardNormal: randomStandardNormal,
        randomUniform: randomUniform$1,
        randomUniformInt: randomUniformInt,
        range: range$3,
        real: real$2,
        reciprocal: reciprocal$2,
        relu: relu$2,
        relu6: relu6$2,
        reshape: reshape$3,
        reverse: reverse$2,
        reverse1d: reverse1d,
        reverse2d: reverse2d,
        reverse3d: reverse3d,
        reverse4d: reverse4d,
        rfft: rfft,
        round: round$2,
        rsqrt: rsqrt$2,
        scalar: scalar,
        scatterND: scatterND,
        searchSorted: searchSorted$2,
        selu: selu$2,
        separableConv2d: separableConv2d$1,
        setdiff1dAsync: setdiff1dAsync,
        sigmoid: sigmoid$2,
        sign: sign$3,
        signal: signal,
        sin: sin$2,
        sinh: sinh$2,
        slice: slice$2,
        slice1d: slice1d,
        slice2d: slice2d,
        slice3d: slice3d,
        slice4d: slice4d,
        softmax: softmax$3,
        softplus: softplus$2,
        spaceToBatchND: spaceToBatchND$2,
        sparse: sparse$1,
        sparseToDense: sparseToDense$2,
        spectral: spectral$1,
        split: split$3,
        sqrt: sqrt$2,
        square: square$2,
        squaredDifference: squaredDifference$2,
        squeeze: squeeze,
        stack: stack,
        step: step$2,
        stridedSlice: stridedSlice$2,
        string: string$1,
        sub: sub$2,
        sum: sum$3,
        tan: tan$2,
        tanh: tanh$2,
        tensor: tensor,
        tensor1d: tensor1d,
        tensor2d: tensor2d,
        tensor3d: tensor3d,
        tensor4d: tensor4d,
        tensor5d: tensor5d,
        tensor6d: tensor6d,
        tensorScatterUpdate: tensorScatterUpdate$2,
        tile: tile$3,
        topk: topk,
        transpose: transpose$2,
        truncatedNormal: truncatedNormal$1,
        unique: unique$3,
        unsortedSegmentSum: unsortedSegmentSum$2,
        unstack: unstack,
        upperBound: upperBound$1,
        variable: variable$1,
        where: where,
        whereAsync: whereAsync,
        zeros: zeros$2,
        zerosLike: zerosLike$3
    });
63967
63968 /**
63969 * @license
63970 * Copyright 2018 Google LLC. All Rights Reserved.
63971 * Licensed under the Apache License, Version 2.0 (the "License");
63972 * you may not use this file except in compliance with the License.
63973 * You may obtain a copy of the License at
63974 *
63975 * http://www.apache.org/licenses/LICENSE-2.0
63976 *
63977 * Unless required by applicable law or agreed to in writing, software
63978 * distributed under the License is distributed on an "AS IS" BASIS,
63979 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
63980 * See the License for the specific language governing permissions and
63981 * limitations under the License.
63982 * =============================================================================
63983 */
63984 const executeOp$k = (node, tensorMap, context, ops = tfOps) => {
63985 switch (node.op) {
63986 case 'BiasAdd':
63987 case 'AddV2':
63988 case 'Add': {
63989 return [ops.add(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
63990 }
63991 case 'AddN': {
63992 return [ops.addN(getParamValue('tensors', node, tensorMap, context))];
63993 }
63994 case 'FloorMod':
63995 case 'Mod':
63996 return [ops.mod(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
63997 case 'Mul':
63998 return [ops.mul(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
63999 case 'RealDiv':
64000 case 'Div': {
64001 return [ops.div(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
64002 }
64003 case 'DivNoNan': {
64004 return [ops.divNoNan(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
64005 }
64006 case 'FloorDiv': {
64007 return [ops.floorDiv(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
64008 }
64009 case 'Sub': {
64010 return [ops.sub(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
64011 }
64012 case 'Minimum': {
64013 return [ops.minimum(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
64014 }
64015 case 'Maximum': {
64016 return [ops.maximum(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
64017 }
64018 case 'Pow': {
64019 return [ops.pow(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
64020 }
64021 case 'SquaredDifference': {
64022 return [ops.squaredDifference(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
64023 }
64024 default:
64025 throw TypeError(`Node type ${node.op} is not implemented`);
64026 }
64027 };
64028 const CATEGORY$j = 'arithmetic';
64029
64030 /**
64031 * @license
64032 * Copyright 2018 Google LLC. All Rights Reserved.
64033 * Licensed under the Apache License, Version 2.0 (the "License");
64034 * you may not use this file except in compliance with the License.
64035 * You may obtain a copy of the License at
64036 *
64037 * http://www.apache.org/licenses/LICENSE-2.0
64038 *
64039 * Unless required by applicable law or agreed to in writing, software
64040 * distributed under the License is distributed on an "AS IS" BASIS,
64041 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
64042 * See the License for the specific language governing permissions and
64043 * limitations under the License.
64044 * =============================================================================
64045 */
64046 const executeOp$j = (node, tensorMap, context, ops = tfOps) => {
64047 switch (node.op) {
64048 case 'Abs':
64049 case 'ComplexAbs':
64050 return [ops.abs(getParamValue('x', node, tensorMap, context))];
64051 case 'Acos':
64052 return [ops.acos(getParamValue('x', node, tensorMap, context))];
64053 case 'Acosh':
64054 return [ops.acosh(getParamValue('x', node, tensorMap, context))];
64055 case 'Asin':
64056 return [ops.asin(getParamValue('x', node, tensorMap, context))];
64057 case 'Asinh':
64058 return [ops.asinh(getParamValue('x', node, tensorMap, context))];
64059 case 'Atan':
64060 return [ops.atan(getParamValue('x', node, tensorMap, context))];
64061 case 'Atan2':
64062 return [ops.atan2(getParamValue('x', node, tensorMap, context), getParamValue('y', node, tensorMap, context))];
64063 case 'Atanh':
64064 return [ops.atanh(getParamValue('x', node, tensorMap, context))];
64065 case 'Ceil':
64066 return [ops.ceil(getParamValue('x', node, tensorMap, context))];
64067 case 'Complex':
64068 return [ops.complex(getParamValue('real', node, tensorMap, context), getParamValue('imag', node, tensorMap, context))];
64069 case 'Cos':
64070 return [ops.cos(getParamValue('x', node, tensorMap, context))];
64071 case 'Cosh':
64072 return [ops.cosh(getParamValue('x', node, tensorMap, context))];
64073 case 'Elu':
64074 return [ops.elu(getParamValue('x', node, tensorMap, context))];
64075 case 'Erf':
64076 return [ops.erf(getParamValue('x', node, tensorMap, context))];
64077 case 'Exp':
64078 return [ops.exp(getParamValue('x', node, tensorMap, context))];
64079 case 'Expm1': {
64080 return [ops.expm1(getParamValue('x', node, tensorMap, context))];
64081 }
64082 case 'Floor':
64083 return [ops.floor(getParamValue('x', node, tensorMap, context))];
64084 case 'Log':
64085 return [ops.log(getParamValue('x', node, tensorMap, context))];
64086 case 'Log1p': {
64087 return [ops.log1p(getParamValue('x', node, tensorMap, context))];
64088 }
64089 case 'Imag':
64090 return [ops.imag(getParamValue('x', node, tensorMap, context))];
64091 case 'Neg':
64092 return [ops.neg(getParamValue('x', node, tensorMap, context))];
64093 case 'Reciprocal': {
64094 return [ops.reciprocal(getParamValue('x', node, tensorMap, context))];
64095 }
64096 case 'Real':
64097 return [ops.real(getParamValue('x', node, tensorMap, context))];
64098 case 'Relu':
64099 return [ops.relu(getParamValue('x', node, tensorMap, context))];
64100 case 'Round': {
64101 return [ops.round(getParamValue('x', node, tensorMap, context))];
64102 }
64103 case 'Selu':
64104 return [ops.selu(getParamValue('x', node, tensorMap, context))];
64105 case 'Sigmoid':
64106 return [ops.sigmoid(getParamValue('x', node, tensorMap, context))];
64107 case 'Sin':
64108 return [ops.sin(getParamValue('x', node, tensorMap, context))];
64109 case 'Sign': {
64110 return [ops.sign(getParamValue('x', node, tensorMap, context))];
64111 }
64112 case 'Sinh': {
64113 return [ops.sinh(getParamValue('x', node, tensorMap, context))];
64114 }
64115 case 'Softplus': {
64116 return [ops.softplus(getParamValue('x', node, tensorMap, context))];
64117 }
64118 case 'Sqrt': {
64119 return [ops.sqrt(getParamValue('x', node, tensorMap, context))];
64120 }
64121 case 'Square': {
64122 return [ops.square(getParamValue('x', node, tensorMap, context))];
64123 }
64124 case 'Tanh': {
64125 return [ops.tanh(getParamValue('x', node, tensorMap, context))];
64126 }
64127 case 'Tan':
64128 return [ops.tan(getParamValue('x', node, tensorMap, context))];
64129 case 'ClipByValue':
64130 return [ops.clipByValue(getParamValue('x', node, tensorMap, context), getParamValue('clipValueMin', node, tensorMap, context), getParamValue('clipValueMax', node, tensorMap, context))];
64131 case 'Relu6':
64132 return [ops.relu6(getParamValue('x', node, tensorMap, context))];
64133 case 'Rsqrt':
64134 return [ops.rsqrt(getTensor(node.inputNames[0], tensorMap, context))];
64135 case 'LeakyRelu':
64136 return [ops.leakyRelu(getParamValue('x', node, tensorMap, context), getParamValue('alpha', node, tensorMap, context))];
64137 case 'Prelu':
64138 return [ops.prelu(getParamValue('x', node, tensorMap, context), getParamValue('alpha', node, tensorMap, context))];
64139 case 'IsNan':
64140 return [ops.isNaN(getTensor(node.inputNames[0], tensorMap, context))];
64141 case 'IsInf':
64142 return [ops.isInf(getTensor(node.inputNames[0], tensorMap, context))];
64143 case 'IsFinite':
64144 return [ops.isFinite(getTensor(node.inputNames[0], tensorMap, context))];
64145 default:
64146 throw TypeError(`Node type ${node.op} is not implemented`);
64147 }
64148 };
64149 const CATEGORY$i = 'basic_math';
64150
64151 /**
64152 * @license
64153 * Copyright 2020 Google LLC. All Rights Reserved.
64154 * Licensed under the Apache License, Version 2.0 (the "License");
64155 * you may not use this file except in compliance with the License.
64156 * You may obtain a copy of the License at
64157 *
64158 * http://www.apache.org/licenses/LICENSE-2.0
64159 *
64160 * Unless required by applicable law or agreed to in writing, software
64161 * distributed under the License is distributed on an "AS IS" BASIS,
64162 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
64163 * See the License for the specific language governing permissions and
64164 * limitations under the License.
64165 * =============================================================================
64166 */
64167 /**
64168 * Used by TensorList and TensorArray to verify if elementShape matches, support
64169 * negative value as the dim shape.
64170 * @param shapeA
64171 * @param shapeB
64172 * @param errorMessagePrefix
64173 */
64174 function assertShapesMatchAllowUndefinedSize(shapeA, shapeB, errorMessagePrefix = '') {
64175 // constant shape means unknown rank
64176 if (typeof shapeA === 'number' || typeof shapeB === 'number') {
64177 return;
64178 }
64179 assert$1(shapeA.length === shapeB.length, () => errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`);
64180 for (let i = 0; i < shapeA.length; i++) {
64181 const dim0 = shapeA[i];
64182 const dim1 = shapeB[i];
64183 assert$1(dim0 < 0 || dim1 < 0 || dim0 === dim1, () => errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`);
64184 }
64185 }
64186 function fullDefinedShape(elementShape) {
64187 if (typeof elementShape === 'number' || elementShape.some(dim => dim < 0)) {
64188 return false;
64189 }
64190 return true;
64191 }
64192 /**
64193 * Generate the output element shape from the list elementShape, list tensors
64194 * and input param.
64195 * @param listElementShape
64196 * @param tensors
64197 * @param elementShape
64198 */
64199 function inferElementShape(listElementShape, tensors, elementShape) {
64200 let partialShape = mergeElementShape(listElementShape, elementShape);
64201 const notfullDefinedShape = !fullDefinedShape(partialShape);
64202 if (notfullDefinedShape && tensors.length === 0) {
64203 throw new Error(`Tried to calculate elements of an empty list` +
64204 ` with non-fully-defined elementShape: ${partialShape}`);
64205 }
64206 if (notfullDefinedShape) {
64207 tensors.forEach(tensor => {
64208 partialShape = mergeElementShape(tensor.shape, partialShape);
64209 });
64210 }
64211 if (!fullDefinedShape(partialShape)) {
64212 throw new Error(`Non-fully-defined elementShape: ${partialShape}`);
64213 }
64214 return partialShape;
64215 }
64216 function mergeElementShape(elementShapeA, elementShapeB) {
64217 if (typeof elementShapeA === 'number') {
64218 return elementShapeB;
64219 }
64220 if (typeof elementShapeB === 'number') {
64221 return elementShapeA;
64222 }
64223 if (elementShapeA.length !== elementShapeB.length) {
64224 throw new Error(`Incompatible ranks during merge: ${elementShapeA} vs. ${elementShapeB}`);
64225 }
64226 const result = [];
64227 for (let i = 0; i < elementShapeA.length; ++i) {
64228 const dim0 = elementShapeA[i];
64229 const dim1 = elementShapeB[i];
64230 if (dim0 >= 0 && dim1 >= 0 && dim0 !== dim1) {
64231 throw new Error(`Incompatible shape during merge: ${elementShapeA} vs. ${elementShapeB}`);
64232 }
64233 result[i] = dim0 >= 0 ? dim0 : dim1;
64234 }
64235 return result;
64236 }
64237
64238 /**
64239 * @license
64240 * Copyright 2018 Google LLC. All Rights Reserved.
64241 * Licensed under the Apache License, Version 2.0 (the "License");
64242 * you may not use this file except in compliance with the License.
64243 * You may obtain a copy of the License at
64244 *
64245 * http://www.apache.org/licenses/LICENSE-2.0
64246 *
64247 * Unless required by applicable law or agreed to in writing, software
64248 * distributed under the License is distributed on an "AS IS" BASIS,
64249 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
64250 * See the License for the specific language governing permissions and
64251 * limitations under the License.
64252 * =============================================================================
64253 */
64254 /**
64255 * The TensorArray object keeps an array of Tensors. It
64256 * allows reading from the array and writing to the array.
64257 */
64258 class TensorArray {
        /**
         * @param name Name of this TensorArray.
         * @param dtype dtype required of every stored tensor.
         * @param maxSize Capacity when the array is not dynamically sized.
         * @param elementShape Expected shape of each element (may be partial).
         * @param identicalElementShapes Whether all elements share one shape.
         * @param dynamicSize Whether the array may grow beyond maxSize.
         * @param clearAfterRead Whether an element is cleared once read.
         */
        constructor(name, dtype, maxSize, elementShape, identicalElementShapes, dynamicSize, clearAfterRead) {
            this.name = name;
            this.dtype = dtype;
            this.maxSize = maxSize;
            this.elementShape = elementShape;
            this.identicalElementShapes = identicalElementShapes;
            this.dynamicSize = dynamicSize;
            this.clearAfterRead = clearAfterRead;
            this.tensors = [];
            this.closed_ = false;
            // A kept scalar whose tensor id serves as this array's identity.
            this.idTensor = scalar(0);
            keep(this.idTensor);
        }
        // Numeric identity of this array, backed by the kept idTensor.
        get id() {
            return this.idTensor.id;
        }
        // True once clearAndClose() has been called.
        get closed() {
            return this.closed_;
        }
64278 /**
64279 * Dispose the tensors and idTensor and mark the TensoryArray as closed.
64280 */
64281 clearAndClose(keepIds) {
64282 this.tensors.forEach(tensor => {
64283 if (keepIds == null || !keepIds.has(tensor.tensor.id)) {
64284 tensor.tensor.dispose();
64285 }
64286 });
64287 this.tensors = [];
64288 this.closed_ = true;
64289 this.idTensor.dispose();
64290 }
        // Number of elements currently stored.
        size() {
            return this.tensors.length;
        }
64294 /**
64295 * Read the value at location index in the TensorArray.
64296 * @param index Number the index to read from.
64297 */
64298 read(index) {
64299 if (this.closed_) {
64300 throw new Error(`TensorArray ${this.name} has already been closed.`);
64301 }
64302 if (index < 0 || index >= this.size()) {
64303 throw new Error(`Tried to read from index ${index}, but array size is: ${this.size()}`);
64304 }
64305 const tensorWithState = this.tensors[index];
64306 if (tensorWithState.cleared) {
64307 throw new Error(`TensorArray ${this.name}: Could not read index ${index} twice because it was cleared after a previous read ` +
64308 `(perhaps try setting clear_after_read = false?).`);
64309 }
64310 if (this.clearAfterRead) {
64311 tensorWithState.cleared = true;
64312 }
64313 tensorWithState.read = true;
64314 return tensorWithState.tensor;
64315 }
64316 /**
64317 * Helper method to read multiple tensors from the specified indices.
64318 */
64319 readMany(indices) {
64320 return indices.map(index => this.read(index));
64321 }
64322 /**
64323 * Write value into the index of the TensorArray.
64324 * @param index number the index to write to.
64325 * @param tensor
64326 */
64327 write(index, tensor) {
64328 if (this.closed_) {
64329 throw new Error(`TensorArray ${this.name} has already been closed.`);
64330 }
64331 if (index < 0 || !this.dynamicSize && index >= this.maxSize) {
64332 throw new Error(`Tried to write to index ${index}, but array is not resizeable and size is: ${this.maxSize}`);
64333 }
64334 const t = this.tensors[index] || {};
64335 if (tensor.dtype !== this.dtype) {
64336 throw new Error(`TensorArray ${this.name}: Could not write to TensorArray index ${index},
64337 because the value dtype is ${tensor.dtype}, but TensorArray dtype is ${this.dtype}.`);
64338 }
64339 // Set the shape for the first time write to unknow shape tensor array
64340 if (this.size() === 0 &&
64341 (this.elementShape == null || this.elementShape.length === 0)) {
64342 this.elementShape = tensor.shape;
64343 }
64344 assertShapesMatchAllowUndefinedSize(this.elementShape, tensor.shape, `TensorArray ${this.name}: Could not write to TensorArray index ${index}.`);
64345 if (t.read) {
64346 throw new Error(`TensorArray ${this.name}: Could not write to TensorArray index ${index}, because it has already been read.`);
64347 }
64348 if (t.written) {
64349 throw new Error(`TensorArray ${this.name}: Could not write to TensorArray index ${index}, because it has already been written.`);
64350 }
64351 t.tensor = tensor;
64352 keep(tensor);
64353 t.written = true;
64354 this.tensors[index] = t;
64355 }
64356 /**
64357 * Helper method to write multiple tensors to the specified indices.
64358 */
64359 writeMany(indices, tensors) {
64360 if (indices.length !== tensors.length) {
64361 throw new Error(`TensorArray ${this.name}: could not write multiple tensors,` +
64362 `because the index size: ${indices.length} is not the same as tensors size: ${tensors.length}.`);
64363 }
64364 indices.forEach((i, index) => this.write(i, tensors[index]));
64365 }
64366 /**
64367 * Return selected values in the TensorArray as a packed Tensor. All of
64368 * selected values must have been written and their shapes must all match.
64369 * @param [indices] number[] Optional. Taking values in [0, max_value). If the
64370 * TensorArray is not dynamic, max_value=size(). If not specified returns
64371 * all tensors in the original order.
64372 * @param [dtype]
64373 */
64374 gather(indices, dtype) {
64375 if (!!dtype && dtype !== this.dtype) {
64376 throw new Error(`TensorArray dtype is ${this.dtype} but gather requested dtype ${dtype}`);
64377 }
64378 if (!indices) {
64379 indices = [];
64380 for (let i = 0; i < this.size(); i++) {
64381 indices.push(i);
64382 }
64383 }
64384 else {
64385 indices = indices.slice(0, this.size());
64386 }
64387 if (indices.length === 0) {
64388 return tensor([], [0].concat(this.elementShape));
64389 }
64390 // Read all the PersistentTensors into a vector to keep track of
64391 // their memory.
64392 const tensors = this.readMany(indices);
64393 assertShapesMatchAllowUndefinedSize(this.elementShape, tensors[0].shape, 'TensorArray shape mismatch: ');
64394 return stack(tensors, 0);
64395 }
64396 /**
64397 * Return the values in the TensorArray as a concatenated Tensor.
64398 */
64399 concat(dtype) {
64400 if (!!dtype && dtype !== this.dtype) {
64401 throw new Error(`TensorArray dtype is ${this.dtype} but concat requested dtype ${dtype}`);
64402 }
64403 if (this.size() === 0) {
64404 return tensor([], [0].concat(this.elementShape));
64405 }
64406 const indices = [];
64407 for (let i = 0; i < this.size(); i++) {
64408 indices.push(i);
64409 }
64410 // Collect all the tensors from the tensors array.
64411 const tensors = this.readMany(indices);
64412 assertShapesMatchAllowUndefinedSize(this.elementShape, tensors[0].shape, `TensorArray shape mismatch: tensor array shape (${this.elementShape}) vs first tensor shape (${tensors[0].shape})`);
64413 return concat$2(tensors, 0);
64414 }
64415 /**
64416 * Scatter the values of a Tensor in specific indices of a TensorArray.
64417 * @param indices number[] values in [0, max_value). If the
64418 * TensorArray is not dynamic, max_value=size().
64419 * @param tensor Tensor input tensor.
64420 */
64421 scatter(indices, tensor) {
64422 if (tensor.dtype !== this.dtype) {
64423 throw new Error(`TensorArray dtype is ${this.dtype} but tensor has dtype ${tensor.dtype}`);
64424 }
64425 if (indices.length !== tensor.shape[0]) {
64426 throw new Error(`Expected len(indices) == tensor.shape[0], but saw: ${indices.length} vs. ${tensor.shape[0]}`);
64427 }
64428 const maxIndex = Math.max(...indices);
64429 if (!this.dynamicSize && maxIndex >= this.maxSize) {
64430 throw new Error(`Max index must be < array size (${maxIndex} vs. ${this.maxSize})`);
64431 }
64432 this.writeMany(indices, unstack(tensor, 0));
64433 }
/**
 * Split the values of a Tensor into the TensorArray.
 * @param length number[] with the lengths to use when splitting value along
 *    its first dimension.
 * @param tensor Tensor, the tensor to split.
 * @throws Error when the tensor dtype differs from the array dtype, when the
 *    lengths do not sum to tensor.shape[0], or when the array is not
 *    dynamically sized and length.length differs from maxSize.
 */
split(length, tensor) {
    if (tensor.dtype !== this.dtype) {
        throw new Error(`TensorArray dtype is ${this.dtype} but tensor has dtype ${tensor.dtype}`);
    }
    // Running sum of the chunk lengths; cumulativeLengths[i] is the end
    // offset of chunk i along the first dimension.
    let totalLength = 0;
    const cumulativeLengths = length.map(len => {
        totalLength += len;
        return totalLength;
    });
    if (totalLength !== tensor.shape[0]) {
        throw new Error(`Expected sum of lengths to be equal to
          tensor.shape[0], but sum of lengths is
        ${totalLength}, and tensor's shape is: ${tensor.shape}`);
    }
    if (!this.dynamicSize && length.length !== this.maxSize) {
        throw new Error(`TensorArray's size is not equal to the size of lengths (${this.maxSize} vs. ${length.length}), ` +
            'and the TensorArray is not marked as dynamically resizeable');
    }
    // Flatten rows: size per unit of the first dimension (0 guards the
    // division when the tensor is empty).
    const elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength;
    const tensors = [];
    tidy(() => {
        // View the input as [1, totalLength, elementPerRow] and slice each
        // chunk out of the middle dimension.
        tensor = reshape$3(tensor, [1, totalLength, elementPerRow]);
        for (let i = 0; i < length.length; ++i) {
            const previousLength = (i === 0) ? 0 : cumulativeLengths[i - 1];
            const indices = [0, previousLength, 0];
            const sizes = [1, length[i], elementPerRow];
            tensors[i] = reshape$3(slice$2(tensor, indices, sizes), this.elementShape);
        }
        // Returning the chunk tensors keeps them alive past the tidy scope;
        // the intermediate reshape/slice results are reclaimed by tidy.
        return tensors;
    });
    const indices = [];
    for (let i = 0; i < length.length; i++) {
        indices[i] = i;
    }
    this.writeMany(indices, tensors);
}
64476 }
64477
64478 /**
64479 * @license
64480 * Copyright 2020 Google LLC. All Rights Reserved.
64481 * Licensed under the Apache License, Version 2.0 (the "License");
64482 * you may not use this file except in compliance with the License.
64483 * You may obtain a copy of the License at
64484 *
64485 * http://www.apache.org/licenses/LICENSE-2.0
64486 *
64487 * Unless required by applicable law or agreed to in writing, software
64488 * distributed under the License is distributed on an "AS IS" BASIS,
64489 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
64490 * See the License for the specific language governing permissions and
64491 * limitations under the License.
64492 * =============================================================================
64493 */
/**
 * TensorList stores a container of `tf.Tensor` objects, which are accessible
 * via tensors field.
 *
 * In order to get a copy of the underlying list, use the copy method:
 * ```
 * TensorList b = a.copy();
 * b.tensors().pushBack(t); // This does not modify a.tensors().
 * ```
 *
 * Note that this is not a deep copy: the memory locations of the underlying
 * tensors will still point to the same locations of the corresponding tensors
 * in the original.
 */
class TensorList {
    // Unique id of the list, borrowed from the internal idTensor handle.
    get id() {
        return this.idTensor.id;
    }
    /**
     *
     * @param tensors list of tensors
     * @param elementShape shape of each tensor, this can be a single number (any
     * shape is allowed) or partial shape (dim = -1).
     * @param elementDtype data type of each tensor
     * @param maxNumElements The maximum allowed size of `tensors`. Defaults to -1
     * meaning that the size of `tensors` is unbounded.
     * @throws Error when a provided tensor's dtype differs from elementDtype.
     */
    constructor(tensors, elementShape, elementDtype, maxNumElements = -1) {
        this.tensors = tensors;
        this.elementShape = elementShape;
        this.elementDtype = elementDtype;
        if (tensors != null) {
            tensors.forEach(tensor => {
                if (elementDtype !== tensor.dtype) {
                    throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${tensor.dtype}`);
                }
                assertShapesMatchAllowUndefinedSize(elementShape, tensor.shape, 'TensorList shape mismatch: ');
                // Pin each element so it survives enclosing tidy scopes.
                keep(tensor);
            });
        }
        // Scalar used purely as a unique handle identifying this list.
        this.idTensor = scalar(0);
        this.maxNumElements = maxNumElements;
        keep(this.idTensor);
    }
    /**
     * Get a new TensorList containing a copy of the underlying tensor container.
     * Note: the copy shares the element tensors and does not carry over
     * maxNumElements (the copy is unbounded).
     */
    copy() {
        return new TensorList([...this.tensors], this.elementShape, this.elementDtype);
    }
    /**
     * Dispose the tensors and idTensor and clear the tensor list.
     * @param keepIds optional set of tensor ids that must NOT be disposed
     *     (e.g. tensors still referenced elsewhere).
     */
    clearAndClose(keepIds) {
        this.tensors.forEach(tensor => {
            if (keepIds == null || !keepIds.has(tensor.id)) {
                tensor.dispose();
            }
        });
        this.tensors.length = 0;
        this.idTensor.dispose();
    }
    /**
     * The size of the tensors in the tensor list.
     */
    size() {
        return this.tensors.length;
    }
    /**
     * Return a tensor that stacks a list of rank-R tf.Tensors into one rank-(R+1)
     * tf.Tensor.
     * @param elementShape shape of each tensor
     * @param elementDtype data type of each tensor
     * @param numElements the number of elements to stack; -1 (default) accepts
     *     any list length.
     */
    stack(elementShape, elementDtype, numElements = -1) {
        if (elementDtype !== this.elementDtype) {
            throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${this.elementDtype}`);
        }
        if (numElements !== -1 && this.tensors.length !== numElements) {
            throw new Error(`Operation expected a list with ${numElements} elements but got a list with ${this.tensors.length} elements.`);
        }
        assertShapesMatchAllowUndefinedSize(elementShape, this.elementShape, 'TensorList shape mismatch: ');
        // Resolve any -1 dims from the stored tensors and the requested shape.
        const outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
        return tidy(() => {
            const reshapedTensors = this.tensors.map(tensor => reshape$3(tensor, outputElementShape));
            return stack(reshapedTensors, 0);
        });
    }
    /**
     * Pop a tensor from the end of the list.
     * @param elementShape shape of the tensor
     * @param elementDtype data type of the tensor
     */
    popBack(elementShape, elementDtype) {
        if (elementDtype !== this.elementDtype) {
            throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${this.elementDtype}`);
        }
        if (this.size() === 0) {
            throw new Error('Trying to pop from an empty list.');
        }
        const outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
        const tensor = this.tensors.pop();
        // The popped tensor leaves the list's ownership; clear the kept flag so
        // it can be reclaimed by the caller's scope.
        tensor.kept = false;
        assertShapesMatchAllowUndefinedSize(tensor.shape, elementShape, 'TensorList shape mismatch: ');
        return reshape$3(tensor, outputElementShape);
    }
    /**
     * Push a tensor to the end of the list.
     * @param tensor Tensor to be pushed.
     * @throws Error on dtype/shape mismatch or when the list is full.
     */
    pushBack(tensor) {
        if (tensor.dtype !== this.elementDtype) {
            throw new Error(`Invalid data types; op elements ${tensor.dtype}, but list elements ${this.elementDtype}`);
        }
        assertShapesMatchAllowUndefinedSize(tensor.shape, this.elementShape, 'TensorList shape mismatch: ');
        if (this.maxNumElements === this.size()) {
            throw new Error(`Trying to push element into a full list.`);
        }
        // The list takes ownership: pin the tensor against tidy disposal.
        keep(tensor);
        this.tensors.push(tensor);
    }
    /**
     * Update the size of the list.
     * @param size the new size of the list.
     * @returns a NEW TensorList of the requested size sharing the first
     *     min(current size, size) element tensors; this list is unchanged.
     */
    resize(size) {
        if (size < 0) {
            throw new Error(`TensorListResize expects size to be non-negative. Got: ${size}`);
        }
        if (this.maxNumElements !== -1 && size > this.maxNumElements) {
            throw new Error(`TensorListResize input size ${size} is greater maxNumElement ${this.maxNumElements}.`);
        }
        const destTensorList = new TensorList([], this.elementShape, this.elementDtype, this.maxNumElements);
        destTensorList.tensors.length = size;
        for (let i = 0; i < Math.min(this.tensors.length, size); ++i) {
            destTensorList.tensors[i] = this.tensors[i];
        }
        return destTensorList;
    }
    /**
     * Retrieve the element at the provided index
     * @param elementShape shape of the tensor
     * @param elementDtype dtype of the tensor
     * @param elementIndex index of the tensor
     */
    getItem(elementIndex, elementShape, elementDtype) {
        if (elementDtype !== this.elementDtype) {
            throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${this.elementDtype}`);
        }
        if (elementIndex < 0 || elementIndex > this.tensors.length) {
            throw new Error(`Trying to access element ${elementIndex} in a list with ${this.tensors.length} elements.`);
        }
        // elementIndex === length slips past the bound check above but is
        // caught here (slot is undefined == null).
        if (this.tensors[elementIndex] == null) {
            throw new Error(`element at index ${elementIndex} is null.`);
        }
        assertShapesMatchAllowUndefinedSize(this.tensors[elementIndex].shape, elementShape, 'TensorList shape mismatch: ');
        const outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
        return reshape$3(this.tensors[elementIndex], outputElementShape);
    }
    /**
     * Set the tensor at the index
     * @param elementIndex index of the tensor
     * @param tensor the tensor to be inserted into the list
     */
    setItem(elementIndex, tensor) {
        if (tensor.dtype !== this.elementDtype) {
            throw new Error(`Invalid data types; op elements ${tensor.dtype}, but list elements ${this.elementDtype}`);
        }
        if (elementIndex < 0 ||
            this.maxNumElements !== -1 && elementIndex >= this.maxNumElements) {
            throw new Error(`Trying to set element ${elementIndex} in a list with max ${this.maxNumElements} elements.`);
        }
        assertShapesMatchAllowUndefinedSize(this.elementShape, tensor.shape, 'TensorList shape mismatch: ');
        // The list takes ownership of the new value.
        keep(tensor);
        // dispose the previous value if it is replacing.
        if (this.tensors[elementIndex] != null) {
            this.tensors[elementIndex].kept = false;
        }
        this.tensors[elementIndex] = tensor;
    }
    /**
     * Return selected values in the TensorList as a stacked Tensor. All of
     * selected values must have been written and their shapes must all match.
     * @param indices indices of tensors to gather
     * @param elementDtype output tensor dtype
     * @param elementShape output tensor element shape
     */
    gather(indices, elementDtype, elementShape) {
        if (elementDtype !== this.elementDtype) {
            throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${this.elementDtype}`);
        }
        assertShapesMatchAllowUndefinedSize(this.elementShape, elementShape, 'TensorList shape mismatch: ');
        // When indices is greater than the size of the list, indices beyond the
        // size of the list are ignored.
        indices = indices.slice(0, this.size());
        const outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
        if (indices.length === 0) {
            return tensor([], [0].concat(outputElementShape));
        }
        return tidy(() => {
            const tensors = indices.map(i => reshape$3(this.tensors[i], outputElementShape));
            return stack(tensors, 0);
        });
    }
    /**
     * Return the values in the TensorList as a concatenated Tensor.
     * @param elementDtype output tensor dtype
     * @param elementShape output tensor element shape
     */
    concat(elementDtype, elementShape) {
        if (!!elementDtype && elementDtype !== this.elementDtype) {
            throw new Error(`TensorList dtype is ${this.elementDtype} but concat requested dtype ${elementDtype}`);
        }
        assertShapesMatchAllowUndefinedSize(this.elementShape, elementShape, 'TensorList shape mismatch: ');
        const outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
        if (this.size() === 0) {
            return tensor([], [0].concat(outputElementShape));
        }
        return tidy(() => {
            const tensors = this.tensors.map(t => reshape$3(t, outputElementShape));
            return concat$2(tensors, 0);
        });
    }
}
/**
 * Creates a TensorList which, when stacked, has the value of tensor.
 * @param tensor from tensor; must be at least rank 1 and have dtype
 *     elementDtype.
 * @param elementShape output tensor element shape
 * @param elementDtype expected dtype of each list element
 */
function fromTensor(tensor, elementShape, elementDtype) {
    if (tensor.shape.length < 1) {
        throw new Error(`Tensor must be at least a vector, but saw shape: ${tensor.shape}`);
    }
    if (tensor.dtype !== elementDtype) {
        throw new Error(`Invalid data types; op elements ${tensor.dtype}, but list elements ${elementDtype}`);
    }
    // The per-element shape is the input shape with the leading dim dropped.
    const [, ...innerShape] = tensor.shape;
    assertShapesMatchAllowUndefinedSize(innerShape, elementShape, 'TensorList shape mismatch: ');
    return new TensorList(unstack(tensor), elementShape, tensor.dtype);
}
/**
 * Return a TensorList of the given size with empty elements.
 * @param elementShape the shape of the future elements of the list
 * @param elementDtype the desired type of elements in the list
 * @param numElements the number of elements to reserve (currently unused;
 *     the list starts empty regardless)
 * @param maxNumElements the maximum number of elements in th list
 */
function reserve(elementShape, elementDtype, numElements, maxNumElements) {
    const emptyTensors = [];
    return new TensorList(emptyTensors, elementShape, elementDtype, maxNumElements);
}
/**
 * Put tensors at specific indices of a stacked tensor into a TensorList.
 * @param tensor input tensor.
 * @param indices list of indices on how to scatter the tensor.
 * @param elementShape the shape of the future elements of the list
 * @param numElements the number of elements to scatter
 */
function scatter(tensor, indices, elementShape, numElements) {
    const numRows = tensor.shape[0];
    if (indices.length !== numRows) {
        throw new Error(`Expected len(indices) == tensor.shape[0], but saw: ${indices.length} vs. ${numRows}`);
    }
    const maxIndex = indices.reduce((acc, cur) => Math.max(acc, cur), -Infinity);
    if (numElements != null && numElements !== -1 && maxIndex >= numElements) {
        throw new Error(`Max index must be < array size (${maxIndex} vs. ${numElements})`);
    }
    const result = new TensorList([], elementShape, tensor.dtype, numElements);
    // Unpack the tensor along axis 0 and place each row at its target index.
    const rows = unstack(tensor, 0);
    for (let i = 0; i < indices.length; i++) {
        result.setItem(indices[i], rows[i]);
    }
    return result;
}
/**
 * Split the values of a Tensor into a TensorList.
 * @param length the lengths to use when splitting value along
 *    its first dimension.
 * @param tensor the tensor to split.
 * @param elementShape the shape of the future elements of the list
 * @throws Error when the lengths do not sum to tensor.shape[0].
 */
function split$1(tensor, length, elementShape) {
    // Running sum of the chunk lengths; cumulativeLengths[i] is the end
    // offset of chunk i along the first dimension.
    let totalLength = 0;
    const cumulativeLengths = length.map(len => {
        totalLength += len;
        return totalLength;
    });
    if (totalLength !== tensor.shape[0]) {
        throw new Error(`Expected sum of lengths to be equal to
          tensor.shape[0], but sum of lengths is
        ${totalLength}, and tensor's shape is: ${tensor.shape}`);
    }
    const shapeWithoutFirstDim = tensor.shape.slice(1);
    const outputElementShape = mergeElementShape(shapeWithoutFirstDim, elementShape);
    // Size per unit of the first dimension (0 guards division when empty).
    const elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength;
    const tensors = tidy(() => {
        const tensors = [];
        // NOTE: `tensor` is rebound to the reshaped view here; the reshaped
        // view is explicitly disposed below, and the returned chunk tensors
        // are kept alive by tidy.
        tensor = reshape$3(tensor, [1, totalLength, elementPerRow]);
        for (let i = 0; i < length.length; ++i) {
            const previousLength = (i === 0) ? 0 : cumulativeLengths[i - 1];
            const indices = [0, previousLength, 0];
            const sizes = [1, length[i], elementPerRow];
            tensors[i] = reshape$3(slice$2(tensor, indices, sizes), outputElementShape);
        }
        tensor.dispose();
        return tensors;
    });
    // `tensor.dtype` is still readable after dispose (metadata only).
    const list = new TensorList([], elementShape, tensor.dtype, length.length);
    for (let i = 0; i < tensors.length; i++) {
        list.setItem(i, tensors[i]);
    }
    return list;
}
64808
64809 /**
64810 * @license
64811 * Copyright 2018 Google LLC. All Rights Reserved.
64812 * Licensed under the Apache License, Version 2.0 (the "License");
64813 * you may not use this file except in compliance with the License.
64814 * You may obtain a copy of the License at
64815 *
64816 * http://www.apache.org/licenses/LICENSE-2.0
64817 *
64818 * Unless required by applicable law or agreed to in writing, software
64819 * distributed under the License is distributed on an "AS IS" BASIS,
64820 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
64821 * See the License for the specific language governing permissions and
64822 * limitations under the License.
64823 * =============================================================================
64824 */
/**
 * Executes one control-flow, TensorArray, or TensorList op for the given
 * graph node. Control-flow ops (If/While) invoke graph functions from
 * `context.functionMap` and manually dispose intermediate tensors;
 * TensorArray/TensorList ops look up their container via the id tensor
 * passed as `tensorArrayId`/`tensorListId` and return the container's
 * idTensor as the op output handle.
 *
 * @param node the graph node to execute.
 * @param tensorMap map of resolved tensors keyed by node name.
 * @param context execution context (frames, functionMap, tensorArrayMap,
 *     tensorListMap).
 * @returns a promise of the node's output tensors (or undefined slots for
 *     the untaken Switch branch / unresolved Merge).
 * @throws TypeError when the node op is not implemented.
 */
const executeOp$i = async (node, tensorMap, context) => {
    switch (node.op) {
        case 'If':
        case 'StatelessIf': {
            const thenFunc = getParamValue('thenBranch', node, tensorMap, context);
            const elseFunc = getParamValue('elseBranch', node, tensorMap, context);
            const cond = getParamValue('cond', node, tensorMap, context);
            const args = getParamValue('args', node, tensorMap, context);
            // The condition value must be downloaded before branch selection.
            const condValue = await cond.data();
            if (condValue[0]) {
                return context.functionMap[thenFunc].executeFunctionAsync(args, context.tensorArrayMap, context.tensorListMap);
            }
            else {
                return context.functionMap[elseFunc].executeFunctionAsync(args, context.tensorArrayMap, context.tensorListMap);
            }
        }
        case 'While':
        case 'StatelessWhile': {
            const bodyFunc = getParamValue('body', node, tensorMap, context);
            const condFunc = getParamValue('cond', node, tensorMap, context);
            const args = getParamValue('args', node, tensorMap, context);
            // Calculate the condition of the loop
            const condResult = (await context.functionMap[condFunc].executeFunctionAsync(args, context.tensorArrayMap, context.tensorListMap));
            const argIds = args.map(tensor => tensor.id);
            let condValue = await condResult[0].data();
            // Dispose the intermediate tensors for condition function
            // (anything not globally kept and not one of the loop inputs).
            condResult.forEach(tensor => {
                if (!tensor.kept && argIds.indexOf(tensor.id) === -1) {
                    tensor.dispose();
                }
            });
            let result = args;
            while (condValue[0]) {
                // Record the previous result for intermediate tensor tracking
                const origResult = result;
                // Execution the body of the loop
                result = await context.functionMap[bodyFunc].executeFunctionAsync(result, context.tensorArrayMap, context.tensorListMap);
                const resultIds = result.map(tensor => tensor.id);
                // Dispose the intermediate tensor for body function that is not global
                // kept, not input/output of the body function
                origResult.forEach(tensor => {
                    if (!tensor.kept && argIds.indexOf(tensor.id) === -1 &&
                        resultIds.indexOf(tensor.id) === -1) {
                        tensor.dispose();
                    }
                });
                // Recalcuate the condition of the loop using the latest results.
                const condResult = (await context.functionMap[condFunc].executeFunctionAsync(result, context.tensorArrayMap, context.tensorListMap));
                condValue = await condResult[0].data();
                // Dispose the intermediate tensors for condition function
                condResult.forEach(tensor => {
                    if (!tensor.kept && argIds.indexOf(tensor.id) === -1 &&
                        resultIds.indexOf(tensor.id) === -1) {
                        tensor.dispose();
                    }
                });
            }
            return result;
        }
        case 'LoopCond': {
            const pred = getParamValue('pred', node, tensorMap, context);
            return [cloneTensor(pred)];
        }
        case 'Switch': {
            const pred = getParamValue('pred', node, tensorMap, context);
            let data = getParamValue('data', node, tensorMap, context);
            if (!data.kept) {
                data = cloneTensor(data);
            }
            // Outputs nodes :0 => false, :1 => true
            return (await pred.data())[0] ? [undefined, data] : [data, undefined];
        }
        case 'Merge': {
            // Forward the first input that has already been resolved.
            const inputName = node.inputNames.find(name => getTensor(name, tensorMap, context) !== undefined);
            if (inputName) {
                const data = getTensor(inputName, tensorMap, context);
                return [cloneTensor(data)];
            }
            return undefined;
        }
        case 'Enter': {
            const frameId = getParamValue('frameName', node, tensorMap, context);
            const data = getParamValue('tensor', node, tensorMap, context);
            context.enterFrame(frameId);
            return [cloneTensor(data)];
        }
        case 'Exit': {
            const data = getParamValue('tensor', node, tensorMap, context);
            context.exitFrame();
            return [cloneTensor(data)];
        }
        case 'NextIteration': {
            const data = getParamValue('tensor', node, tensorMap, context);
            context.nextIteration();
            return [cloneTensor(data)];
        }
        case 'TensorArrayV3': {
            const size = getParamValue('size', node, tensorMap, context);
            const dtype = getParamValue('dtype', node, tensorMap, context);
            const elementShape = getParamValue('elementShape', node, tensorMap, context);
            const dynamicSize = getParamValue('dynamicSize', node, tensorMap, context);
            const clearAfterRead = getParamValue('clearAfterRead', node, tensorMap, context);
            const identicalElementShapes = getParamValue('identicalElementShapes', node, tensorMap, context);
            const name = getParamValue('name', node, tensorMap, context);
            const tensorArray = new TensorArray(name, dtype, size, elementShape, identicalElementShapes, dynamicSize, clearAfterRead);
            context.addTensorArray(tensorArray);
            // Second output mirrors TF's flow-out scalar.
            return [tensorArray.idTensor, scalar(1.0)];
        }
        case 'TensorArrayWriteV3': {
            const id = getParamValue('tensorArrayId', node, tensorMap, context);
            const index = getParamValue('index', node, tensorMap, context);
            const writeTensor = getParamValue('tensor', node, tensorMap, context);
            const writeTensorArray = context.getTensorArray(id.id);
            writeTensorArray.write(index, writeTensor);
            return [writeTensorArray.idTensor];
        }
        case 'TensorArrayReadV3': {
            const readId = getParamValue('tensorArrayId', node, tensorMap, context);
            const readIndex = getParamValue('index', node, tensorMap, context);
            const readTensorArray = context.getTensorArray(readId.id);
            return [readTensorArray.read(readIndex)];
        }
        case 'TensorArrayGatherV3': {
            const gatherId = getParamValue('tensorArrayId', node, tensorMap, context);
            const gatherIndices = getParamValue('indices', node, tensorMap, context);
            const gatherDtype = getParamValue('dtype', node, tensorMap, context);
            const gatherTensorArray = context.getTensorArray(gatherId.id);
            return [gatherTensorArray.gather(gatherIndices, gatherDtype)];
        }
        case 'TensorArrayScatterV3': {
            const scatterId = getParamValue('tensorArrayId', node, tensorMap, context);
            const scatterIndices = getParamValue('indices', node, tensorMap, context);
            const scatterTensor = getParamValue('tensor', node, tensorMap, context);
            const scatterTensorArray = context.getTensorArray(scatterId.id);
            scatterTensorArray.scatter(scatterIndices, scatterTensor);
            return [scatterTensorArray.idTensor];
        }
        case 'TensorArrayConcatV3': {
            const concatId = getParamValue('tensorArrayId', node, tensorMap, context);
            const concatTensorArray = context.getTensorArray(concatId.id);
            const concatDtype = getParamValue('dtype', node, tensorMap, context);
            return [concatTensorArray.concat(concatDtype)];
        }
        case 'TensorArraySplitV3': {
            const splitId = getParamValue('tensorArrayId', node, tensorMap, context);
            const splitTensor = getParamValue('tensor', node, tensorMap, context);
            const lengths = getParamValue('lengths', node, tensorMap, context);
            const splitTensorArray = context.getTensorArray(splitId.id);
            splitTensorArray.split(lengths, splitTensor);
            return [splitTensorArray.idTensor];
        }
        case 'TensorArraySizeV3': {
            const sizeId = getParamValue('tensorArrayId', node, tensorMap, context);
            const sizeTensorArray = context.getTensorArray(sizeId.id);
            return [scalar(sizeTensorArray.size(), 'int32')];
        }
        case 'TensorArrayCloseV3': {
            const closeId = getParamValue('tensorArrayId', node, tensorMap, context);
            const closeTensorArray = context.getTensorArray(closeId.id);
            closeTensorArray.clearAndClose();
            return [closeTensorArray.idTensor];
        }
        case 'TensorListSetItem': {
            const idTensor = getParamValue('tensorListId', node, tensorMap, context);
            const index = getParamValue('index', node, tensorMap, context);
            const writeTensor = getParamValue('tensor', node, tensorMap, context);
            const tensorList = context.getTensorList(idTensor.id);
            tensorList.setItem(index, writeTensor);
            return [tensorList.idTensor];
        }
        case 'TensorListGetItem': {
            const idTensor = getParamValue('tensorListId', node, tensorMap, context);
            const readIndex = getParamValue('index', node, tensorMap, context);
            const elementShape = getParamValue('elementShape', node, tensorMap, context);
            const elementDType = getParamValue('elementDType', node, tensorMap, context);
            const tensorList = context.getTensorList(idTensor.id);
            return [tensorList.getItem(readIndex, elementShape, elementDType)];
        }
        case 'TensorListScatterV2':
        case 'TensorListScatter': {
            const scatterIndices = getParamValue('indices', node, tensorMap, context);
            const scatterTensor = getParamValue('tensor', node, tensorMap, context);
            const elementShape = getParamValue('elementShape', node, tensorMap, context);
            const numElements = getParamValue('numElements', node, tensorMap, context);
            const tensorList = scatter(scatterTensor, scatterIndices, elementShape, numElements);
            context.addTensorList(tensorList);
            return [tensorList.idTensor];
        }
        case 'TensorListReserve':
        case 'EmptyTensorList': {
            const elementShape = getParamValue('elementShape', node, tensorMap, context);
            const elementDtype = getParamValue('elementDType', node, tensorMap, context);
            // The two ops name their size parameter differently.
            let numElementsParam;
            if (node.op === 'TensorListReserve') {
                numElementsParam = 'numElements';
            }
            else {
                numElementsParam = 'maxNumElements';
            }
            const numElements = getParamValue(numElementsParam, node, tensorMap, context);
            // Only EmptyTensorList bounds the list size.
            const maxNumElements = node.op === 'TensorListReserve' ? -1 : numElements;
            const tensorList = reserve(elementShape, elementDtype, numElements, maxNumElements);
            context.addTensorList(tensorList);
            return [tensorList.idTensor];
        }
        case 'TensorListGather': {
            const gatherId = getParamValue('tensorListId', node, tensorMap, context);
            const gatherIndices = getParamValue('indices', node, tensorMap, context);
            const elementShape = getParamValue('elementShape', node, tensorMap, context);
            const elementDtype = getParamValue('elementDType', node, tensorMap, context);
            const tensorList = context.getTensorList(gatherId.id);
            return [tensorList.gather(gatherIndices, elementDtype, elementShape)];
        }
        case 'TensorListStack': {
            const idTensor = getParamValue('tensorListId', node, tensorMap, context);
            const elementShape = getParamValue('elementShape', node, tensorMap, context);
            const elementDtype = getParamValue('elementDType', node, tensorMap, context);
            const numElements = getParamValue('numElements', node, tensorMap, context);
            const tensorList = context.getTensorList(idTensor.id);
            return [tensorList.stack(elementShape, elementDtype, numElements)];
        }
        case 'TensorListFromTensor': {
            const tensor = getParamValue('tensor', node, tensorMap, context);
            const elementShape = getParamValue('elementShape', node, tensorMap, context);
            const elementDtype = getParamValue('elementDType', node, tensorMap, context);
            const tensorList = fromTensor(tensor, elementShape, elementDtype);
            context.addTensorList(tensorList);
            return [tensorList.idTensor];
        }
        case 'TensorListConcat':
        case 'TensorListConcatV2': {
            const concatId = getParamValue('tensorListId', node, tensorMap, context);
            const tensorList = context.getTensorList(concatId.id);
            const concatDtype = getParamValue('dtype', node, tensorMap, context);
            const elementShape = getParamValue('elementShape', node, tensorMap, context);
            return [tensorList.concat(concatDtype, elementShape)];
        }
        case 'TensorListPushBack': {
            const idTensor = getParamValue('tensorListId', node, tensorMap, context);
            const writeTensor = getParamValue('tensor', node, tensorMap, context);
            const tensorList = context.getTensorList(idTensor.id);
            tensorList.pushBack(writeTensor);
            return [tensorList.idTensor];
        }
        case 'TensorListPopBack': {
            const idTensor = getParamValue('tensorListId', node, tensorMap, context);
            const elementShape = getParamValue('elementShape', node, tensorMap, context);
            const elementDType = getParamValue('elementDType', node, tensorMap, context);
            const tensorList = context.getTensorList(idTensor.id);
            return [tensorList.popBack(elementShape, elementDType)];
        }
        case 'TensorListSplit': {
            const splitTensor = getParamValue('tensor', node, tensorMap, context);
            const elementShape = getParamValue('elementShape', node, tensorMap, context);
            const lengths = getParamValue('lengths', node, tensorMap, context);
            const tensorList = split$1(splitTensor, lengths, elementShape);
            context.addTensorList(tensorList);
            return [tensorList.idTensor];
        }
        case 'TensorListLength': {
            const idTensor = getParamValue('tensorListId', node, tensorMap, context);
            const tensorList = context.getTensorList(idTensor.id);
            return [scalar(tensorList.size(), 'int32')];
        }
        case 'TensorListResize': {
            const idTensor = getParamValue('tensorListId', node, tensorMap, context);
            const size = getParamValue('size', node, tensorMap, context);
            const srcTensorList = context.getTensorList(idTensor.id);
            const destTensorList = srcTensorList.resize(size);
            context.addTensorList(destTensorList);
            return [destTensorList.idTensor];
        }
        default:
            throw TypeError(`Node type ${node.op} is not implemented`);
    }
};
// Category label for this group of ops (control-flow executor).
const CATEGORY$h = 'control';
65102
65103 /**
65104 * @license
65105 * Copyright 2018 Google LLC. All Rights Reserved.
65106 * Licensed under the Apache License, Version 2.0 (the "License");
65107 * you may not use this file except in compliance with the License.
65108 * You may obtain a copy of the License at
65109 *
65110 * http://www.apache.org/licenses/LICENSE-2.0
65111 *
65112 * Unless required by applicable law or agreed to in writing, software
65113 * distributed under the License is distributed on an "AS IS" BASIS,
65114 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
65115 * See the License for the specific language governing permissions and
65116 * limitations under the License.
65117 * =============================================================================
65118 */
/**
 * Collects the shared attributes for fused Conv2D / DepthwiseConv2D graph
 * nodes and splits the extra input args into bias and prelu weights.
 *
 * @param node Graph node carrying the fused-op attributes.
 * @param tensorMap Map of resolved tensors by node name.
 * @param context Execution context used to resolve params.
 * @returns {stride, pad, dataFormat, dilations, biasArg, preluArg,
 *     activationFunc, leakyreluAlpha}
 * @throws Error when the number of extra args does not match the fused-op
 *     configuration, or when FusedBatchNorm is requested (unsupported).
 */
function fusedConvAndDepthWiseParams(node, tensorMap, context) {
    const [extraOp, activationFunc] = getParamValue('fusedOps', node, tensorMap, context);
    const hasBias = extraOp === 'biasadd';
    const usesPrelu = activationFunc === 'prelu';
    const numArgs = getParamValue('numArgs', node, tensorMap, context);
    if (hasBias) {
        if (usesPrelu && numArgs !== 2) {
            throw new Error('FusedConv2d and DepthwiseConv2d with BiasAdd and Prelu ' +
                'must have two extra arguments: bias and alpha.');
        }
        if (!usesPrelu && numArgs !== 1) {
            throw new Error('FusedConv2d and DepthwiseConv2d with BiasAdd must have ' +
                'one extra argument: bias.');
        }
    }
    if (extraOp === 'fusedbatchnorm') {
        throw new Error('FusedConv2d and DepthwiseConv2d with FusedBatchNorm is not supported');
    }
    const stride = getParamValue('strides', node, tensorMap, context);
    const pad = getPadding(node, tensorMap, context);
    const dataFormat = getParamValue('dataFormat', node, tensorMap, context)
        .toUpperCase();
    const dilations = getParamValue('dilations', node, tensorMap, context);
    let [biasArg, preluArg] = getParamValue('args', node, tensorMap, context);
    if (!hasBias) {
        // Without a fused BiasAdd the single extra arg is the prelu weights.
        preluArg = biasArg;
        biasArg = undefined;
    }
    const leakyreluAlpha = getParamValue('leakyreluAlpha', node, tensorMap, context);
    return {
        stride,
        pad,
        dataFormat,
        dilations,
        biasArg,
        preluArg,
        activationFunc,
        leakyreluAlpha
    };
}
/**
 * Executes a convolution/pooling-category graph node and returns its output
 * tensor(s). Dispatches on `node.op`; throws a TypeError for unknown ops.
 *
 * Note: the 4-element `strides` / `dilations` / `kernelSize` attributes
 * include the batch and channel dimensions (e.g. strides is
 * [1, stride_height, stride_width, 1] — see the Dilation2D case), so only
 * indices [1] and [2] (plus [3] for the 3D ops) are forwarded to the kernels.
 */
const executeOp$h = (node, tensorMap, context, ops = tfOps) => {
    switch (node.op) {
        case 'Conv1D': {
            // Conv1D uses scalar 'stride'/'dilation' attributes, unlike the
            // 2D/3D ops below which use per-dimension arrays.
            const stride = getParamValue('stride', node, tensorMap, context);
            const pad = getParamValue('pad', node, tensorMap, context);
            const dataFormat = getParamValue('dataFormat', node, tensorMap, context)
                .toUpperCase();
            const dilation = getParamValue('dilation', node, tensorMap, context);
            return [ops.conv1d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), stride, pad, dataFormat, dilation)];
        }
        case 'Conv2D': {
            const stride = getParamValue('strides', node, tensorMap, context);
            const pad = getPadding(node, tensorMap, context);
            const dataFormat = getParamValue('dataFormat', node, tensorMap, context)
                .toUpperCase();
            const dilations = getParamValue('dilations', node, tensorMap, context);
            return [ops.conv2d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [stride[1], stride[2]], pad, dataFormat, [dilations[1], dilations[2]])];
        }
        case '_FusedConv2D': {
            // Fused bias/activation parameters are validated and unpacked by
            // the shared helper above.
            const { stride, pad, dataFormat, dilations, biasArg, preluArg, activationFunc, leakyreluAlpha } = fusedConvAndDepthWiseParams(node, tensorMap, context);
            return [ops.fused.conv2d({
                x: getParamValue('x', node, tensorMap, context),
                filter: getParamValue('filter', node, tensorMap, context),
                strides: [stride[1], stride[2]],
                pad: pad,
                dataFormat: dataFormat,
                dilations: [dilations[1], dilations[2]],
                bias: biasArg,
                activation: activationFunc,
                preluActivationWeights: preluArg,
                leakyreluAlpha
            })];
        }
        case 'FusedDepthwiseConv2dNative': {
            const { stride, pad, dataFormat, dilations, biasArg, preluArg, activationFunc, leakyreluAlpha, } = fusedConvAndDepthWiseParams(node, tensorMap, context);
            return [ops.fused.depthwiseConv2d({
                x: getParamValue('x', node, tensorMap, context),
                filter: getParamValue('filter', node, tensorMap, context),
                strides: [stride[1], stride[2]],
                pad: pad,
                dataFormat: dataFormat,
                dilations: [dilations[1], dilations[2]],
                bias: biasArg,
                activation: activationFunc,
                preluActivationWeights: preluArg,
                leakyreluAlpha
            })];
        }
        case 'Conv2DBackpropInput':
        case 'Conv2dTranspose': {
            const shape = getParamValue('outputShape', node, tensorMap, context);
            const stride = getParamValue('strides', node, tensorMap, context);
            const pad = getPadding(node, tensorMap, context);
            return [ops.conv2dTranspose(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), shape, [stride[1], stride[2]], pad)];
        }
        case 'DepthwiseConv2dNative':
        case 'DepthwiseConv2d': {
            const stride = getParamValue('strides', node, tensorMap, context);
            const pad = getPadding(node, tensorMap, context);
            const dilations = getParamValue('dilations', node, tensorMap, context);
            const dataFormat = getParamValue('dataFormat', node, tensorMap, context)
                .toUpperCase();
            // NOTE: reads the 'input' param, not 'x' as in the other conv ops.
            return [ops.depthwiseConv2d(getParamValue('input', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [stride[1], stride[2]], pad, dataFormat, [dilations[1], dilations[2]])];
        }
        case 'Conv3D': {
            const stride = getParamValue('strides', node, tensorMap, context);
            const pad = getParamValue('pad', node, tensorMap, context);
            const dataFormat = getParamValue('dataFormat', node, tensorMap, context)
                .toUpperCase();
            const dilations = getParamValue('dilations', node, tensorMap, context);
            return [ops.conv3d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [stride[1], stride[2], stride[3]], pad, dataFormat, [dilations[1], dilations[2], dilations[3]])];
        }
        case 'AvgPool': {
            const stride = getParamValue('strides', node, tensorMap, context);
            const pad = getParamValue('pad', node, tensorMap, context);
            const kernelSize = getParamValue('kernelSize', node, tensorMap, context);
            return [ops.avgPool(getParamValue('x', node, tensorMap, context), [kernelSize[1], kernelSize[2]], [stride[1], stride[2]], pad)];
        }
        case 'MaxPool': {
            const stride = getParamValue('strides', node, tensorMap, context);
            const pad = getParamValue('pad', node, tensorMap, context);
            const kernelSize = getParamValue('kernelSize', node, tensorMap, context);
            return [ops.maxPool(getParamValue('x', node, tensorMap, context), [kernelSize[1], kernelSize[2]], [stride[1], stride[2]], pad)];
        }
        case 'MaxPoolWithArgmax': {
            // Returns two tensors: the pooled result and the argmax indexes.
            const stride = getParamValue('strides', node, tensorMap, context);
            const pad = getParamValue('pad', node, tensorMap, context);
            const kernelSize = getParamValue('kernelSize', node, tensorMap, context);
            const includeBatchInIndex = getParamValue('includeBatchInIndex', node, tensorMap, context);
            const { result, indexes } = ops.maxPoolWithArgmax(getParamValue('x', node, tensorMap, context), [kernelSize[1], kernelSize[2]], [stride[1], stride[2]], pad, includeBatchInIndex);
            return [result, indexes];
        }
        case 'AvgPool3D': {
            const stride = getParamValue('strides', node, tensorMap, context);
            const pad = getParamValue('pad', node, tensorMap, context);
            const kernelSize = getParamValue('kernelSize', node, tensorMap, context);
            return [ops.avgPool3d(getParamValue('x', node, tensorMap, context), [kernelSize[1], kernelSize[2], kernelSize[3]], [stride[1], stride[2], stride[3]], pad)];
        }
        case 'MaxPool3D': {
            const stride = getParamValue('strides', node, tensorMap, context);
            const pad = getParamValue('pad', node, tensorMap, context);
            const kernelSize = getParamValue('kernelSize', node, tensorMap, context);
            return [ops.maxPool3d(getParamValue('x', node, tensorMap, context), [kernelSize[1], kernelSize[2], kernelSize[3]], [stride[1], stride[2], stride[3]], pad)];
        }
        case 'Dilation2D': {
            const strides = getParamValue('strides', node, tensorMap, context);
            const pad = getParamValue('pad', node, tensorMap, context);
            const dilations = getParamValue('dilations', node, tensorMap, context);
            // strides: [1, stride_height, stride_width, 1].
            const strideHeight = strides[1];
            const strideWidth = strides[2];
            // dilations: [1, dilation_height, dilation_width, 1].
            const dilationHeight = dilations[1];
            const dilationWidth = dilations[2];
            return [ops.dilation2d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [strideHeight, strideWidth], pad, [dilationHeight, dilationWidth], 'NHWC' /* dataFormat */)];
        }
        default:
            throw TypeError(`Node type ${node.op} is not implemented`);
    }
};
// Category tag for the convolution op executor above.
const CATEGORY$g = 'convolution';
65282
65283 /**
65284 * @license
65285 * Copyright 2018 Google LLC. All Rights Reserved.
65286 * Licensed under the Apache License, Version 2.0 (the "License");
65287 * you may not use this file except in compliance with the License.
65288 * You may obtain a copy of the License at
65289 *
65290 * http://www.apache.org/licenses/LICENSE-2.0
65291 *
65292 * Unless required by applicable law or agreed to in writing, software
65293 * distributed under the License is distributed on an "AS IS" BASIS,
65294 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
65295 * See the License for the specific language governing permissions and
65296 * limitations under the License.
65297 * =============================================================================
65298 */
/**
 * Executes a tensor-creation-category graph node (fill, ranges, random
 * tensors, ones/zeros, one-hot) and returns its output tensor in a
 * single-element array. Throws a TypeError for unknown ops.
 */
const executeOp$g = (node, tensorMap, context, ops = tfOps) => {
    switch (node.op) {
        case 'Fill': {
            const shape = getParamValue('shape', node, tensorMap, context);
            const dtype = getParamValue('dtype', node, tensorMap, context);
            const value = getParamValue('value', node, tensorMap, context);
            return [ops.fill(shape, value, dtype)];
        }
        case 'LinSpace': {
            const start = getParamValue('start', node, tensorMap, context);
            const stop = getParamValue('stop', node, tensorMap, context);
            const num = getParamValue('num', node, tensorMap, context);
            return [ops.linspace(start, stop, num)];
        }
        case 'Multinomial': {
            const logits = getParamValue('logits', node, tensorMap, context);
            const numSamples = getParamValue('numSamples', node, tensorMap, context);
            const seed = getParamValue('seed', node, tensorMap, context);
            return [ops.multinomial(logits, numSamples, seed)];
        }
        case 'OneHot': {
            const indices = getParamValue('indices', node, tensorMap, context);
            const depth = getParamValue('depth', node, tensorMap, context);
            const onValue = getParamValue('onValue', node, tensorMap, context);
            const offValue = getParamValue('offValue', node, tensorMap, context);
            const dtype = getParamValue('dtype', node, tensorMap, context);
            return [ops.oneHot(indices, depth, onValue, offValue, dtype)];
        }
        case 'Ones': {
            return [ops.ones(getParamValue('shape', node, tensorMap, context), getParamValue('dtype', node, tensorMap, context))];
        }
        case 'OnesLike': {
            return [ops.onesLike(getParamValue('x', node, tensorMap, context))];
        }
        case 'RandomStandardNormal': {
            return [ops.randomStandardNormal(getParamValue('shape', node, tensorMap, context), getParamValue('dtype', node, tensorMap, context), getParamValue('seed', node, tensorMap, context))];
        }
        case 'RandomUniform': {
            return [ops.randomUniform(
            // tslint:disable-next-line:no-any
            getParamValue('shape', node, tensorMap, context), getParamValue('minval', node, tensorMap, context), getParamValue('maxval', node, tensorMap, context), getParamValue('dtype', node, tensorMap, context))];
        }
        case 'RandomUniformInt': {
            return [ops.randomUniformInt(getParamValue('shape', node, tensorMap, context), getParamValue('minval', node, tensorMap, context), getParamValue('maxval', node, tensorMap, context), getParamValue('seed', node, tensorMap, context))];
        }
        case 'Range': {
            const start = getParamValue('start', node, tensorMap, context);
            const stop = getParamValue('stop', node, tensorMap, context);
            const step = getParamValue('step', node, tensorMap, context);
            return [ops.range(start, stop, step, getParamValue('dtype', node, tensorMap, context))];
        }
        case 'TruncatedNormal': {
            const shape = getParamValue('shape', node, tensorMap, context);
            const mean = getParamValue('mean', node, tensorMap, context);
            const stdDev = getParamValue('stdDev', node, tensorMap, context);
            const seed = getParamValue('seed', node, tensorMap, context);
            return [ops.truncatedNormal(shape, mean, stdDev, getParamValue('dtype', node, tensorMap, context), seed)];
        }
        case 'Zeros': {
            return [ops.zeros(getParamValue('shape', node, tensorMap, context), getParamValue('dtype', node, tensorMap, context))];
        }
        case 'ZerosLike': {
            return [ops.zerosLike(getParamValue('x', node, tensorMap, context))];
        }
        default:
            throw TypeError(`Node type ${node.op} is not implemented`);
    }
};
// Category tag for the creation op executor above.
const CATEGORY$f = 'creation';
65368
65369 /**
65370 * @license
65371 * Copyright 2018 Google LLC. All Rights Reserved.
65372 * Licensed under the Apache License, Version 2.0 (the "License");
65373 * you may not use this file except in compliance with the License.
65374 * You may obtain a copy of the License at
65375 *
65376 * http://www.apache.org/licenses/LICENSE-2.0
65377 *
65378 * Unless required by applicable law or agreed to in writing, software
65379 * distributed under the License is distributed on an "AS IS" BASIS,
65380 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
65381 * See the License for the specific language governing permissions and
65382 * limitations under the License.
65383 * =============================================================================
65384 */
/**
 * Reads the parameters shared by the NonMaxSuppression op variants from a
 * graph node.
 *
 * @param node Graph node carrying the NMS attributes.
 * @param tensorMap Map of resolved tensors by node name.
 * @param context Execution context used to resolve params.
 * @returns {boxes, scores, maxOutputSize, iouThreshold, scoreThreshold,
 *     softNmsSigma}
 */
function nmsParams(node, tensorMap, context) {
    const read = (name) => getParamValue(name, node, tensorMap, context);
    return {
        boxes: read('boxes'),
        scores: read('scores'),
        maxOutputSize: read('maxOutputSize'),
        iouThreshold: read('iouThreshold'),
        scoreThreshold: read('scoreThreshold'),
        softNmsSigma: read('softNmsSigma')
    };
}
/**
 * Executes a dynamic-category graph node — ops whose output shape depends on
 * tensor values (NMS variants, Where, ListDiff) — hence the async signature.
 * `resourceManager` is not referenced in this executor; presumably kept for
 * signature parity with the other executors — TODO confirm at call site.
 */
const executeOp$f = async (node, tensorMap, context, resourceManager, ops = tfOps) => {
    switch (node.op) {
        case 'NonMaxSuppressionV5': {
            // V5 adds soft-NMS scoring; returns indices and their scores.
            const { boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma } = nmsParams(node, tensorMap, context);
            const result = await ops.image.nonMaxSuppressionWithScoreAsync(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
            return [result.selectedIndices, result.selectedScores];
        }
        case 'NonMaxSuppressionV4': {
            // V4 supports padding the output to maxOutputSize.
            const { boxes, scores, maxOutputSize, iouThreshold, scoreThreshold } = nmsParams(node, tensorMap, context);
            const padToMaxOutputSize = getParamValue('padToMaxOutputSize', node, tensorMap, context);
            const result = await ops.image.nonMaxSuppressionPaddedAsync(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize);
            return [result.selectedIndices, result.validOutputs];
        }
        case 'NonMaxSuppressionV3':
        case 'NonMaxSuppressionV2': {
            const { boxes, scores, maxOutputSize, iouThreshold, scoreThreshold } = nmsParams(node, tensorMap, context);
            return [await ops.image.nonMaxSuppressionAsync(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold)];
        }
        case 'Where': {
            // The cast creates an intermediate tensor, disposed after use.
            const condition = ops.cast(getParamValue('condition', node, tensorMap, context), 'bool');
            const result = [await ops.whereAsync(condition)];
            condition.dispose();
            return result;
        }
        case 'ListDiff': {
            return ops.setdiff1dAsync(getParamValue('x', node, tensorMap, context), getParamValue('y', node, tensorMap, context));
        }
        default:
            throw TypeError(`Node type ${node.op} is not implemented`);
    }
};
// Category tag for the dynamic op executor above.
const CATEGORY$e = 'dynamic';
65433
65434 /**
65435 * @license
65436 * Copyright 2018 Google LLC. All Rights Reserved.
65437 * Licensed under the Apache License, Version 2.0 (the "License");
65438 * you may not use this file except in compliance with the License.
65439 * You may obtain a copy of the License at
65440 *
65441 * http://www.apache.org/licenses/LICENSE-2.0
65442 *
65443 * Unless required by applicable law or agreed to in writing, software
65444 * distributed under the License is distributed on an "AS IS" BASIS,
65445 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
65446 * See the License for the specific language governing permissions and
65447 * limitations under the License.
65448 * =============================================================================
65449 */
/**
 * Executes an evaluation-category graph node (searchsorted bounds, top-k,
 * unique) and returns its output tensor(s). Throws a TypeError for unknown
 * ops.
 */
const executeOp$e = (node, tensorMap, context, ops = tfOps) => {
    const param = (name) => getParamValue(name, node, tensorMap, context);
    switch (node.op) {
        case 'LowerBound':
            return [ops.lowerBound(param('sortedSequence'), param('values'))];
        case 'TopKV2': {
            const { values, indices } = ops.topk(param('x'), param('k'), param('sorted'));
            return [values, indices];
        }
        case 'UpperBound':
            return [ops.upperBound(param('sortedSequence'), param('values'))];
        case 'Unique': {
            const { values, indices } = ops.unique(param('x'));
            return [values, indices];
        }
        case 'UniqueV2': {
            const { values, indices } = ops.unique(param('x'), param('axis'));
            return [values, indices];
        }
        default:
            throw TypeError(`Node type ${node.op} is not implemented`);
    }
};
// Category tag for the evaluation op executor above.
const CATEGORY$d = 'evaluation';
65485
65486 /**
65487 * @license
65488 * Copyright 2018 Google LLC. All Rights Reserved.
65489 * Licensed under the Apache License, Version 2.0 (the "License");
65490 * you may not use this file except in compliance with the License.
65491 * You may obtain a copy of the License at
65492 *
65493 * http://www.apache.org/licenses/LICENSE-2.0
65494 *
65495 * Unless required by applicable law or agreed to in writing, software
65496 * distributed under the License is distributed on an "AS IS" BASIS,
65497 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
65498 * See the License for the specific language governing permissions and
65499 * limitations under the License.
65500 * =============================================================================
65501 */
/**
 * Executes a graph-category node: constants, placeholders, identity-style
 * pass-throughs, shape/size/rank introspection, NoOp and debug Print.
 * Returns the node's output tensors; throws a TypeError for unknown ops.
 *
 * Fixes relative to the previous revision:
 * - The tf.print() warning string was missing a space between its two
 *   concatenated halves ("operation,usually").
 * - 'ShapeN' now builds its tensors with dtype 'int32', consistent with the
 *   'Shape', 'Size' and 'Rank' cases (the float32 default cannot represent
 *   large dimension sizes exactly).
 * - Cases that declare locals are wrapped in braces so each `const` is
 *   scoped to its own case rather than the whole switch block.
 */
const executeOp$d = (node, tensorMap, context, ops = tfOps) => {
    switch (node.op) {
        case 'Const': {
            return tensorMap[node.name];
        }
        case 'PlaceholderWithDefault': {
            // Fall back to the node's default value when no tensor was fed.
            const def = getParamValue('default', node, tensorMap, context);
            return [getTensor(node.name, tensorMap, context) || def];
        }
        case 'Placeholder':
            return [getTensor(node.name, tensorMap, context)];
        case 'Identity':
        case 'StopGradient':
        case 'FakeQuantWithMinMaxVars': { // This op is currently ignored.
            const data = getParamValue('x', node, tensorMap, context);
            return [cloneTensor(data)];
        }
        case 'IdentityN':
            return getParamValue('x', node, tensorMap, context)
                .map((t) => cloneTensor(t));
        case 'Snapshot': {
            const snapshot = getParamValue('x', node, tensorMap, context);
            return [cloneTensor(snapshot)];
        }
        case 'Shape':
            return [ops.tensor1d(getParamValue('x', node, tensorMap, context).shape, 'int32')];
        case 'ShapeN':
            return getParamValue('x', node, tensorMap, context)
                .map((t) => ops.tensor1d(t.shape, 'int32'));
        case 'Size':
            return [ops.scalar(getParamValue('x', node, tensorMap, context).size, 'int32')];
        case 'Rank':
            return [ops.scalar(getParamValue('x', node, tensorMap, context).rank, 'int32')];
        case 'NoOp':
            return [ops.scalar(1)];
        case 'Print': {
            const input = getParamValue('x', node, tensorMap, context);
            const data = getParamValue('data', node, tensorMap, context);
            const message = getParamValue('message', node, tensorMap, context);
            const summarize = getParamValue('summarize', node, tensorMap, context);
            console.warn('The graph has a tf.print() operation, ' +
                'usually used for debugging, which slows down performance.');
            console.log(message);
            for (let i = 0; i < data.length; i++) {
                // dataSync() blocks; acceptable here since Print is debug-only.
                console.log(Array.prototype.slice.call(data[i].dataSync())
                    .slice(0, summarize));
            }
            return [input];
        }
        default:
            throw TypeError(`Node type ${node.op} is not implemented`);
    }
};
// Category tag for the graph op executor above.
const CATEGORY$c = 'graph';
65553
65554 /**
65555 * @license
65556 * Copyright 2020 Google LLC. All Rights Reserved.
65557 * Licensed under the Apache License, Version 2.0 (the "License");
65558 * you may not use this file except in compliance with the License.
65559 * You may obtain a copy of the License at
65560 *
65561 * http://www.apache.org/licenses/LICENSE-2.0
65562 *
65563 * Unless required by applicable law or agreed to in writing, software
65564 * distributed under the License is distributed on an "AS IS" BASIS,
65565 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
65566 * See the License for the specific language governing permissions and
65567 * limitations under the License.
65568 * =============================================================================
65569 */
65570 /**
65571 * Hashtable contains a set of tensors, which can be accessed by key.
65572 */
/**
 * Hashtable contains a set of tensors, which can be accessed by key.
 *
 * Keys are stored as JS primitive values (read via `keys.data()`), so lookup
 * is O(1); values are stored as tensors kept alive with `keep()` until
 * replaced or `clearAndClose()` is called.
 */
class HashTable {
    // The id of the scalar handle tensor uniquely identifies this table.
    get id() {
        return this.handle.id;
    }
    /**
     * Constructor of HashTable. Creates a hash table.
     *
     * @param keyDType `dtype` of the table keys.
     * @param valueDType `dtype` of the table values.
     */
    constructor(keyDType, valueDType) {
        this.keyDType = keyDType;
        this.valueDType = valueDType;
        // Scalar tensor used only as an opaque handle for this table.
        this.handle = scalar(0);
        // tslint:disable-next-line: no-any
        this.tensorMap = new Map();
        // Prevent the handle from being disposed by surrounding tidy() scopes.
        keep(this.handle);
    }
    /**
     * Dispose the tensors and handle and clear the hashtable.
     */
    clearAndClose() {
        this.tensorMap.forEach(value => value.dispose());
        this.tensorMap.clear();
        this.handle.dispose();
    }
    /**
     * The number of items in the hash table.
     */
    size() {
        return this.tensorMap.size;
    }
    /**
     * The number of items in the hash table as a rank-0 tensor.
     */
    tensorSize() {
        return scalar(this.size(), 'int32');
    }
    /**
     * Replaces the contents of the table with the specified keys and values.
     * @param keys Keys to store in the hashtable.
     * @param values Values to store in the hashtable.
     * @returns The table's handle tensor.
     * @throws Error if key/value dtypes mismatch the table, or if the number
     *     of keys and values differ.
     */
    async import(keys, values) {
        this.checkKeyAndValueTensor(keys, values);
        // We only store the primitive values of the keys, this allows lookup
        // to be O(1).
        const $keys = await keys.data();
        // Clear the hashTable before inserting new values.
        this.tensorMap.forEach(value => value.dispose());
        this.tensorMap.clear();
        return tidy(() => {
            const $values = unstack(values);
            const keysLength = $keys.length;
            const valuesLength = $values.length;
            assert$1(keysLength === valuesLength, () => `The number of elements doesn't match, keys has ` +
                `${keysLength} elements, the values has ${valuesLength} ` +
                `elements.`);
            for (let i = 0; i < keysLength; i++) {
                const key = $keys[i];
                const value = $values[i];
                // keep() the value so it survives this tidy() scope; it is
                // disposed when the table is re-imported or closed.
                keep(value);
                this.tensorMap.set(key, value);
            }
            return this.handle;
        });
    }
    /**
     * Looks up keys in a hash table, outputs the corresponding values.
     *
     * Performs batch lookups, for every element in the key tensor, `find`
     * stacks the corresponding value into the return tensor.
     *
     * If an element is not present in the table, the given `defaultValue` is
     * used.
     *
     * @param keys Keys to look up. Must have the same type as the keys of the
     *     table.
     * @param defaultValue The scalar `defaultValue` is the value output for keys
     *     not present in the table. It must also be of the same type as the
     *     table values.
     */
    async find(keys, defaultValue) {
        this.checkKeyAndValueTensor(keys, defaultValue);
        const $keys = await keys.data();
        return tidy(() => {
            const result = [];
            for (let i = 0; i < $keys.length; i++) {
                const key = $keys[i];
                const value = this.findWithDefault(key, defaultValue);
                result.push(value);
            }
            // Stack the per-key values into a single output tensor.
            return stack(result);
        });
    }
    // Returns the stored value for `key`, or `defaultValue` when absent.
    // tslint:disable-next-line: no-any
    findWithDefault(key, defaultValue) {
        const result = this.tensorMap.get(key);
        return result != null ? result : defaultValue;
    }
    // Validates that the key and value tensors match the table's dtypes.
    checkKeyAndValueTensor(key, value) {
        if (key.dtype !== this.keyDType) {
            throw new Error(`Expect key dtype ${this.keyDType}, but got ` +
                `${key.dtype}`);
        }
        if (value.dtype !== this.valueDType) {
            throw new Error(`Expect value dtype ${this.valueDType}, but got ` +
                `${value.dtype}`);
        }
    }
}
65684
65685 /**
65686 * @license
65687 * Copyright 2020 Google LLC. All Rights Reserved.
65688 * Licensed under the Apache License, Version 2.0 (the "License");
65689 * you may not use this file except in compliance with the License.
65690 * You may obtain a copy of the License at
65691 *
65692 * http://www.apache.org/licenses/LICENSE-2.0
65693 *
65694 * Unless required by applicable law or agreed to in writing, software
65695 * distributed under the License is distributed on an "AS IS" BASIS,
65696 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
65697 * See the License for the specific language governing permissions and
65698 * limitations under the License.
65699 * =============================================================================
65700 */
/**
 * Executes a hash-table-category graph node. Tables live in the
 * `resourceManager`; ops reference them through their scalar handle tensors.
 * Async because import/find read tensor data. Throws a TypeError for unknown
 * ops.
 */
const executeOp$c = async (node, tensorMap, context, resourceManager) => {
    switch (node.op) {
        case 'HashTable':
        case 'HashTableV2': {
            const existingTableHandle = resourceManager.getHashTableHandleByName(node.name);
            // Table is shared with initializer.
            if (existingTableHandle != null) {
                return [existingTableHandle];
            }
            else {
                // No shared table: create one and register it by node name.
                const keyDType = getParamValue('keyDType', node, tensorMap, context);
                const valueDType = getParamValue('valueDType', node, tensorMap, context);
                const hashTable = new HashTable(keyDType, valueDType);
                resourceManager.addHashTable(node.name, hashTable);
                return [hashTable.handle];
            }
        }
        case 'InitializeTable':
        case 'InitializeTableV2':
        case 'LookupTableImport':
        case 'LookupTableImportV2': {
            // Replaces the table contents with the given keys/values.
            const handle = getParamValue('tableHandle', node, tensorMap, context, resourceManager);
            const keys = getParamValue('keys', node, tensorMap, context);
            const values = getParamValue('values', node, tensorMap, context);
            const hashTable = resourceManager.getHashTableById(handle.id);
            return [await hashTable.import(keys, values)];
        }
        case 'LookupTableFind':
        case 'LookupTableFindV2': {
            const handle = getParamValue('tableHandle', node, tensorMap, context, resourceManager);
            const keys = getParamValue('keys', node, tensorMap, context);
            const defaultValue = getParamValue('defaultValue', node, tensorMap, context);
            const hashTable = resourceManager.getHashTableById(handle.id);
            return [await hashTable.find(keys, defaultValue)];
        }
        case 'LookupTableSize':
        case 'LookupTableSizeV2': {
            const handle = getParamValue('tableHandle', node, tensorMap, context, resourceManager);
            const hashTable = resourceManager.getHashTableById(handle.id);
            return [hashTable.tensorSize()];
        }
        default:
            throw TypeError(`Node type ${node.op} is not implemented`);
    }
};
// Category tag for the hash-table op executor above.
const CATEGORY$b = 'hash_table';
65747
65748 /**
65749 * @license
65750 * Copyright 2018 Google LLC. All Rights Reserved.
65751 * Licensed under the Apache License, Version 2.0 (the "License");
65752 * you may not use this file except in compliance with the License.
65753 * You may obtain a copy of the License at
65754 *
65755 * http://www.apache.org/licenses/LICENSE-2.0
65756 *
65757 * Unless required by applicable law or agreed to in writing, software
65758 * distributed under the License is distributed on an "AS IS" BASIS,
65759 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
65760 * See the License for the specific language governing permissions and
65761 * limitations under the License.
65762 * =============================================================================
65763 */
/**
 * Executes an image-category graph node (resize, crop-and-resize, projective
 * transform) and returns its output tensor in a single-element array.
 * Throws a TypeError for unknown ops.
 */
const executeOp$b = (node, tensorMap, context, ops = tfOps) => {
    switch (node.op) {
        case 'ResizeBilinear': {
            const images = getParamValue('images', node, tensorMap, context);
            const size = getParamValue('size', node, tensorMap, context);
            const alignCorners = getParamValue('alignCorners', node, tensorMap, context);
            const halfPixelCenters = getParamValue('halfPixelCenters', node, tensorMap, context);
            return [ops.image.resizeBilinear(images, [size[0], size[1]], alignCorners, halfPixelCenters)];
        }
        case 'ResizeNearestNeighbor': {
            const images = getParamValue('images', node, tensorMap, context);
            const size = getParamValue('size', node, tensorMap, context);
            const alignCorners = getParamValue('alignCorners', node, tensorMap, context);
            const halfPixelCenters = getParamValue('halfPixelCenters', node, tensorMap, context);
            return [ops.image.resizeNearestNeighbor(images, [size[0], size[1]], alignCorners, halfPixelCenters)];
        }
        case 'CropAndResize': {
            const image = getParamValue('image', node, tensorMap, context);
            const boxes = getParamValue('boxes', node, tensorMap, context);
            const boxInd = getParamValue('boxInd', node, tensorMap, context);
            const cropSize = getParamValue('cropSize', node, tensorMap, context);
            const method = getParamValue('method', node, tensorMap, context);
            const extrapolationValue = getParamValue('extrapolationValue', node, tensorMap, context);
            return [ops.image.cropAndResize(image, boxes, boxInd, cropSize, method, extrapolationValue)];
        }
        case 'ImageProjectiveTransformV3': {
            const images = getParamValue('images', node, tensorMap, context);
            const transforms = getParamValue('transforms', node, tensorMap, context);
            const outputShape = getParamValue('outputShape', node, tensorMap, context);
            const fillValue = getParamValue('fillValue', node, tensorMap, context);
            const interpolation = getParamValue('interpolation', node, tensorMap, context);
            const fillMode = getParamValue('fillMode', node, tensorMap, context);
            // The graph attrs use upper-case enum names; the kernel expects
            // lower-case strings.
            return [ops.image.transform(images, transforms, interpolation.toLowerCase(), fillMode.toLowerCase(), fillValue, outputShape)];
        }
        default:
            throw TypeError(`Node type ${node.op} is not implemented`);
    }
};
// Category tag for the image op executor above.
const CATEGORY$a = 'image';
65803
65804 /**
65805 * @license
65806 * Copyright 2018 Google LLC. All Rights Reserved.
65807 * Licensed under the Apache License, Version 2.0 (the "License");
65808 * you may not use this file except in compliance with the License.
65809 * You may obtain a copy of the License at
65810 *
65811 * http://www.apache.org/licenses/LICENSE-2.0
65812 *
65813 * Unless required by applicable law or agreed to in writing, software
65814 * distributed under the License is distributed on an "AS IS" BASIS,
65815 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
65816 * See the License for the specific language governing permissions and
65817 * limitations under the License.
65818 * =============================================================================
65819 */
/**
 * Executes a logical-category graph node (comparisons, boolean logic,
 * select) and returns its output tensors.
 * @param node graph node (op name plus parameter bindings)
 * @param tensorMap tensors for executed nodes and weights
 * @param context execution context for resolving node parameters
 * @param ops op registry to execute with (defaults to tfOps)
 * @throws TypeError when `node.op` is not a logical op
 */
const executeOp$a = (node, tensorMap, context, ops = tfOps) => {
    // Resolve one named parameter of `node` from the tensor map.
    const param = (name) => getParamValue(name, node, tensorMap, context);
    // Most logical ops are element-wise binaries over inputs `a` and `b`.
    const binary = (opName) => [ops[opName](param('a'), param('b'))];
    switch (node.op) {
        case 'Equal':
            return binary('equal');
        case 'NotEqual':
            return binary('notEqual');
        case 'Greater':
            return binary('greater');
        case 'GreaterEqual':
            return binary('greaterEqual');
        case 'Less':
            return binary('less');
        case 'LessEqual':
            return binary('lessEqual');
        case 'LogicalAnd':
            return binary('logicalAnd');
        case 'LogicalNot':
            return [ops.logicalNot(param('a'))];
        case 'LogicalOr':
            return binary('logicalOr');
        case 'Select':
        case 'SelectV2':
            return [ops.where(param('condition'), param('a'), param('b'))];
        case 'BitwiseAnd':
            return binary('bitwiseAnd');
        default:
            throw new TypeError(`Node type ${node.op} is not implemented`);
    }
};
const CATEGORY$9 = 'logical';
65861
65862 /**
65863 * @license
65864 * Copyright 2018 Google LLC. All Rights Reserved.
65865 * Licensed under the Apache License, Version 2.0 (the "License");
65866 * you may not use this file except in compliance with the License.
65867 * You may obtain a copy of the License at
65868 *
65869 * http://www.apache.org/licenses/LICENSE-2.0
65870 *
65871 * Unless required by applicable law or agreed to in writing, software
65872 * distributed under the License is distributed on an "AS IS" BASIS,
65873 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
65874 * See the License for the specific language governing permissions and
65875 * limitations under the License.
65876 * =============================================================================
65877 */
/**
 * Executes a matrices-category graph node (MatMul variants, Einsum,
 * Transpose, fused MatMul, MatrixBandPart) and returns its output tensors.
 * @param node graph node (op name plus parameter bindings)
 * @param tensorMap tensors for executed nodes and weights
 * @param context execution context for resolving node parameters
 * @param ops op registry to execute with (defaults to tfOps)
 * @throws TypeError when `node.op` is not a matrices op
 * @throws Error when a fused MatMul's extra-argument count is inconsistent
 *     with its fused ops
 */
const executeOp$9 = (node, tensorMap, context, ops = tfOps) => {
    switch (node.op) {
        case 'BatchMatMul':
        case 'BatchMatMulV2':
        case 'MatMul':
            return [ops.matMul(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context), getParamValue('transposeA', node, tensorMap, context), getParamValue('transposeB', node, tensorMap, context))];
        case 'Einsum':
            return [ops.einsum(getParamValue('equation', node, tensorMap, context), ...getParamValue('tensors', node, tensorMap, context))];
        case 'Transpose':
            return [ops.transpose(getParamValue('x', node, tensorMap, context), getParamValue('perm', node, tensorMap, context))];
        case '_FusedMatMul': {
            // Braced so these `const` declarations are scoped to this case
            // instead of leaking into the whole switch (no-case-declarations).
            const [extraOp, activationFunc] = getParamValue('fusedOps', node, tensorMap, context);
            const isBiasAdd = extraOp === 'biasadd';
            const isPrelu = activationFunc === 'prelu';
            const numArgs = getParamValue('numArgs', node, tensorMap, context);
            const leakyreluAlpha = getParamValue('leakyreluAlpha', node, tensorMap, context);
            // Validate that the number of extra args matches the fused ops:
            // bias always takes one; prelu additionally needs its alpha.
            if (isBiasAdd) {
                if (isPrelu && numArgs !== 2) {
                    throw new Error('Fused MatMul with BiasAdd and Prelu must have two ' +
                        'extra arguments: bias and alpha.');
                }
                if (!isPrelu && numArgs !== 1) {
                    throw new Error('Fused MatMul with BiasAdd must have one extra argument: bias.');
                }
            }
            const [biasArg, preluArg] = getParamValue('args', node, tensorMap, context);
            return [ops.fused.matMul({
                a: getParamValue('a', node, tensorMap, context),
                b: getParamValue('b', node, tensorMap, context),
                transposeA: getParamValue('transposeA', node, tensorMap, context),
                transposeB: getParamValue('transposeB', node, tensorMap, context),
                bias: biasArg,
                activation: activationFunc,
                preluActivationWeights: preluArg,
                leakyreluAlpha
            })];
        }
        case 'MatrixBandPart':
            return [ops.linalg.bandPart(getParamValue('a', node, tensorMap, context), getParamValue('numLower', node, tensorMap, context), getParamValue('numUpper', node, tensorMap, context))];
        default:
            throw new TypeError(`Node type ${node.op} is not implemented`);
    }
};
const CATEGORY$8 = 'matrices';
65921
65922 /**
65923 * @license
65924 * Copyright 2018 Google LLC. All Rights Reserved.
65925 * Licensed under the Apache License, Version 2.0 (the "License");
65926 * you may not use this file except in compliance with the License.
65927 * You may obtain a copy of the License at
65928 *
65929 * http://www.apache.org/licenses/LICENSE-2.0
65930 *
65931 * Unless required by applicable law or agreed to in writing, software
65932 * distributed under the License is distributed on an "AS IS" BASIS,
65933 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
65934 * See the License for the specific language governing permissions and
65935 * limitations under the License.
65936 * =============================================================================
65937 */
/**
 * Executes a normalization-category graph node and returns its output
 * tensors.
 * @param node graph node (op name plus parameter bindings)
 * @param tensorMap tensors for executed nodes and weights
 * @param context execution context for resolving node parameters
 * @param ops op registry to execute with (defaults to tfOps)
 * @throws TypeError when `node.op` is not a normalization op
 */
const executeOp$8 = (node, tensorMap, context, ops = tfOps) => {
    // Resolve one named parameter of `node` from the tensor map.
    const param = (name) => getParamValue(name, node, tensorMap, context);
    switch (node.op) {
        case 'EuclideanNorm':
            return [ops.euclideanNorm(param('x'), param('axis'), param('keepDims'))];
        // All FusedBatchNorm versions map onto the same batchNorm call.
        case 'FusedBatchNorm':
        case 'FusedBatchNormV2':
        case 'FusedBatchNormV3':
            return [ops.batchNorm(param('x'), param('mean'), param('variance'), param('offset'), param('scale'), param('epsilon'))];
        case 'LRN':
            return [ops.localResponseNormalization(param('x'), param('radius'), param('bias'), param('alpha'), param('beta'))];
        case 'Softmax':
            return [ops.softmax(param('x'))];
        case 'LogSoftmax':
            return [ops.logSoftmax(param('x'))];
        default:
            throw new TypeError(`Node type ${node.op} is not implemented`);
    }
};
const CATEGORY$7 = 'normalization';
65963
65964 /**
65965 * @license
65966 * Copyright 2022 Google LLC. All Rights Reserved.
65967 * Licensed under the Apache License, Version 2.0 (the "License");
65968 * you may not use this file except in compliance with the License.
65969 * You may obtain a copy of the License at
65970 *
65971 * http://www.apache.org/licenses/LICENSE-2.0
65972 *
65973 * Unless required by applicable law or agreed to in writing, software
65974 * distributed under the License is distributed on an "AS IS" BASIS,
65975 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
65976 * See the License for the specific language governing permissions and
65977 * limitations under the License.
65978 * =============================================================================
65979 */
/**
 * Executes a ragged-tensor-category graph node and returns its output
 * tensors.
 * @param node graph node (op name plus parameter bindings)
 * @param tensorMap tensors for executed nodes and weights
 * @param context execution context for resolving node parameters
 * @param ops op registry to execute with (defaults to tfOps)
 * @throws TypeError when `node.op` is not a ragged op
 */
const executeOp$7 = (node, tensorMap, context, ops = tfOps) => {
    // Resolve one named parameter of `node` from the tensor map.
    const param = (name) => getParamValue(name, node, tensorMap, context);
    switch (node.op) {
        case 'RaggedGather': {
            const gathered = ops.raggedGather(param('paramsNestedSplits'), param('paramsDenseValues'), param('indices'), param('outputRaggedRank'));
            // Outputs are flattened: all nested splits first, then the values.
            return gathered.outputNestedSplits.concat(gathered.outputDenseValues);
        }
        case 'RaggedRange': {
            const ranged = ops.raggedRange(param('starts'), param('limits'), param('splits'));
            return [ranged.rtNestedSplits, ranged.rtDenseValues];
        }
        case 'RaggedTensorToTensor':
            return [ops.raggedTensorToTensor(param('shape'), param('values'), param('defaultValue'), param('rowPartitionTensors'), param('rowPartitionTypes'))];
        default:
            throw new TypeError(`Node type ${node.op} is not implemented`);
    }
};
const CATEGORY$6 = 'ragged';
65998
65999 /**
66000 * @license
66001 * Copyright 2018 Google LLC. All Rights Reserved.
66002 * Licensed under the Apache License, Version 2.0 (the "License");
66003 * you may not use this file except in compliance with the License.
66004 * You may obtain a copy of the License at
66005 *
66006 * http://www.apache.org/licenses/LICENSE-2.0
66007 *
66008 * Unless required by applicable law or agreed to in writing, software
66009 * distributed under the License is distributed on an "AS IS" BASIS,
66010 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
66011 * See the License for the specific language governing permissions and
66012 * limitations under the License.
66013 * =============================================================================
66014 */
/**
 * Executes a reduction-category graph node (Max/Mean/Min/Sum/All/Any,
 * ArgMax/ArgMin, Prod, Cumprod/Cumsum, Bincount variants) and returns its
 * output tensors.
 * @param node graph node (op name plus parameter bindings)
 * @param tensorMap tensors for executed nodes and weights
 * @param context execution context for resolving node parameters
 * @param ops op registry to execute with (defaults to tfOps)
 * @throws TypeError when `node.op` is not a reduction op
 */
const executeOp$6 = (node, tensorMap, context, ops = tfOps) => {
    // Resolve one named parameter of `node` from the tensor map.
    const param = (name) => getParamValue(name, node, tensorMap, context);
    // Most reductions share the shape: reduce `x` over `axis` with keepDims.
    const reduce = (opName) => {
        const axis = param('axis');
        const keepDims = param('keepDims');
        return [ops[opName](param('x'), axis, keepDims)];
    };
    switch (node.op) {
        case 'Max':
            return reduce('max');
        case 'Mean':
            return reduce('mean');
        case 'Min':
            return reduce('min');
        case 'Sum':
            return reduce('sum');
        case 'All':
            return reduce('all');
        case 'Any':
            return reduce('any');
        case 'ArgMax': {
            const axis = param('axis');
            return [ops.argMax(param('x'), axis)];
        }
        case 'ArgMin': {
            const axis = param('axis');
            return [ops.argMin(param('x'), axis)];
        }
        case 'Prod':
            return reduce('prod');
        case 'Cumprod': {
            const axis = param('axis');
            const exclusive = param('exclusive');
            const reverse = param('reverse');
            return [ops.cumprod(param('x'), axis, exclusive, reverse)];
        }
        case 'Cumsum': {
            const axis = param('axis');
            const exclusive = param('exclusive');
            const reverse = param('reverse');
            return [ops.cumsum(param('x'), axis, exclusive, reverse)];
        }
        case 'Bincount': {
            // Braced: keeps these `const`s scoped to this case (the original
            // let them leak into the whole switch — no-case-declarations).
            const x = param('x');
            const weights = param('weights');
            const size = param('size');
            return [ops.bincount(x, weights, size)];
        }
        case 'DenseBincount': {
            const x = param('x');
            const weights = param('weights');
            const size = param('size');
            const binaryOutput = param('binaryOutput');
            return [ops.denseBincount(x, weights, size, binaryOutput)];
        }
        default:
            throw new TypeError(`Node type ${node.op} is not implemented`);
    }
};
const CATEGORY$5 = 'reduction';
66089
66090 /**
66091 * @license
66092 * Copyright 2018 Google LLC. All Rights Reserved.
66093 * Licensed under the Apache License, Version 2.0 (the "License");
66094 * you may not use this file except in compliance with the License.
66095 * You may obtain a copy of the License at
66096 *
66097 * http://www.apache.org/licenses/LICENSE-2.0
66098 *
66099 * Unless required by applicable law or agreed to in writing, software
66100 * distributed under the License is distributed on an "AS IS" BASIS,
66101 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
66102 * See the License for the specific language governing permissions and
66103 * limitations under the License.
66104 * =============================================================================
66105 */
/**
 * Executes a slice/join-category graph node (concat, gather, slice,
 * pack/unpack, split, scatter, etc.) and returns its output tensors.
 * @param node graph node (op name plus parameter bindings)
 * @param tensorMap tensors for executed nodes and weights
 * @param context execution context for resolving node parameters
 * @param ops op registry to execute with (defaults to tfOps)
 * @throws TypeError when `node.op` is not a slice/join op
 */
const executeOp$5 = (node, tensorMap, context, ops = tfOps) => {
    // Resolve one named parameter of `node` from the tensor map.
    const param = (name) => getParamValue(name, node, tensorMap, context);
    switch (node.op) {
        case 'ConcatV2':
        case 'Concat': {
            const n = param('n');
            const axis = param('axis');
            // Only the first `n` inputs participate in the concat.
            const inputs = param('tensors').slice(0, n);
            return [ops.concat(inputs, axis)];
        }
        case 'Gather': {
            const input = param('x');
            const indices = param('indices');
            return [ops.gather(input, ops.cast(indices, 'int32'), 0)];
        }
        case 'GatherV2': {
            const axis = param('axis');
            const batchDims = param('batchDims');
            const input = param('x');
            const indices = param('indices');
            return [ops.gather(input, ops.cast(indices, 'int32'), axis, batchDims)];
        }
        case 'Reverse': {
            // `dims` is a boolean mask; collect the indices that are set.
            const dims = param('dims');
            const axis = [];
            dims.forEach((flagged, i) => {
                if (flagged) {
                    axis.push(i);
                }
            });
            return [ops.reverse(param('x'), axis)];
        }
        case 'ReverseV2': {
            const axis = param('axis');
            return [ops.reverse(param('x'), axis)];
        }
        case 'Slice': {
            const begin = param('begin');
            const size = param('size');
            return [ops.slice(param('x'), begin, size)];
        }
        case 'StridedSlice': {
            const begin = param('begin');
            const end = param('end');
            const strides = param('strides');
            const beginMask = param('beginMask');
            const endMask = param('endMask');
            const ellipsisMask = param('ellipsisMask');
            const newAxisMask = param('newAxisMask');
            const shrinkAxisMask = param('shrinkAxisMask');
            return [ops.stridedSlice(param('x'), begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask)];
        }
        case 'Pack': {
            return tidy(() => {
                const axis = param('axis');
                const tensors = param('tensors');
                // Inputs must match the first tensor's shape, modulo size-1
                // dimensions; reshape the ones that only differ there.
                const shape = tensors[0].shape;
                const squeezedShape = ops.squeeze(tensors[0]).shape;
                const mapped = tensors.map((tensor) => {
                    if (arraysEqual(tensor.shape, shape)) {
                        return tensor;
                    }
                    if (!arraysEqual(ops.squeeze(tensor).shape, squeezedShape)) {
                        throw new Error('the input tensors shape does not match');
                    }
                    return ops.reshape(tensor, shape);
                });
                return [ops.stack(mapped, axis)];
            });
        }
        case 'Unpack': {
            const axis = param('axis');
            const tensor = param('tensor');
            return ops.unstack(tensor, axis);
        }
        case 'Tile': {
            const reps = param('reps');
            return [ops.tile(param('x'), reps)];
        }
        case 'Split':
        case 'SplitV': {
            const axis = param('axis');
            const numOrSizeSplits = param('numOrSizeSplits');
            return ops.split(param('x'), numOrSizeSplits, axis);
        }
        case 'ScatterNd': {
            const indices = param('indices');
            const values = param('values');
            return [ops.scatterND(indices, values, param('shape'))];
        }
        case 'GatherNd': {
            const input = param('x');
            return [ops.gatherND(input, param('indices'))];
        }
        case 'SparseToDense': {
            const indices = param('sparseIndices');
            const shape = param('outputShape');
            const sparseValues = param('sparseValues');
            const defaultValue = param('defaultValue');
            // The default value must share the sparse values' dtype.
            const typedDefault = sparseValues.dtype === defaultValue.dtype ?
                defaultValue :
                ops.cast(defaultValue, sparseValues.dtype);
            return [ops.sparseToDense(indices, sparseValues, shape, typedDefault)];
        }
        case 'TensorScatterUpdate': {
            const indices = param('indices');
            const values = param('values');
            const tensor = param('tensor');
            return [ops.tensorScatterUpdate(tensor, indices, values)];
        }
        default:
            throw new TypeError(`Node type ${node.op} is not implemented`);
    }
};
const CATEGORY$4 = 'slice_join';
66229
66230 /**
66231 * @license
66232 * Copyright 2021 Google LLC. All Rights Reserved.
66233 * Licensed under the Apache License, Version 2.0 (the "License");
66234 * you may not use this file except in compliance with the License.
66235 * You may obtain a copy of the License at
66236 *
66237 * http://www.apache.org/licenses/LICENSE-2.0
66238 *
66239 * Unless required by applicable law or agreed to in writing, software
66240 * distributed under the License is distributed on an "AS IS" BASIS,
66241 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
66242 * See the License for the specific language governing permissions and
66243 * limitations under the License.
66244 * =============================================================================
66245 */
/**
 * Executes a sparse-category graph node and returns its output tensors.
 * @param node graph node (op name plus parameter bindings)
 * @param tensorMap tensors for executed nodes and weights
 * @param context execution context for resolving node parameters
 * @param ops op registry to execute with (defaults to tfOps)
 * @throws TypeError when `node.op` is not a sparse op
 */
const executeOp$4 = (node, tensorMap, context, ops = tfOps) => {
    // Resolve one named parameter of `node` from the tensor map.
    const param = (name) => getParamValue(name, node, tensorMap, context);
    switch (node.op) {
        case 'SparseFillEmptyRows': {
            const result = ops.sparse.sparseFillEmptyRows(param('indices'), param('values'), param('denseShape'), param('defaultValue'));
            return [
                result.outputIndices, result.outputValues,
                result.emptyRowIndicator, result.reverseIndexMap
            ];
        }
        case 'SparseReshape': {
            const result = ops.sparse.sparseReshape(param('inputIndices'), param('inputShape'), param('newShape'));
            return [result.outputIndices, result.outputShape];
        }
        case 'SparseSegmentMean':
            return [ops.sparse.sparseSegmentMean(param('data'), param('indices'), param('segmentIds'))];
        case 'SparseSegmentSum':
            return [ops.sparse.sparseSegmentSum(param('data'), param('indices'), param('segmentIds'))];
        default:
            throw new TypeError(`Node type ${node.op} is not implemented`);
    }
};
const CATEGORY$3 = 'sparse';
66271
66272 /**
66273 * @license
66274 * Copyright 2018 Google LLC. All Rights Reserved.
66275 * Licensed under the Apache License, Version 2.0 (the "License");
66276 * you may not use this file except in compliance with the License.
66277 * You may obtain a copy of the License at
66278 *
66279 * http://www.apache.org/licenses/LICENSE-2.0
66280 *
66281 * Unless required by applicable law or agreed to in writing, software
66282 * distributed under the License is distributed on an "AS IS" BASIS,
66283 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
66284 * See the License for the specific language governing permissions and
66285 * limitations under the License.
66286 * =============================================================================
66287 */
/**
 * Executes a spectral-category graph node (FFT family) and returns its
 * output tensors.
 * @param node graph node (op name plus parameter bindings)
 * @param tensorMap tensors for executed nodes and weights
 * @param context execution context for resolving node parameters
 * @param ops op registry to execute with (defaults to tfOps)
 * @throws TypeError when `node.op` is not a spectral op
 */
const executeOp$3 = (node, tensorMap, context, ops = tfOps) => {
    // Every spectral op is unary over the node input `x`.
    const input = () => getParamValue('x', node, tensorMap, context);
    switch (node.op) {
        case 'FFT':
            return [ops.fft(input())];
        case 'IFFT':
            return [ops.ifft(input())];
        case 'RFFT':
            return [ops.rfft(input())];
        case 'IRFFT':
            return [ops.irfft(input())];
        default:
            throw new TypeError(`Node type ${node.op} is not implemented`);
    }
};
const CATEGORY$2 = 'spectral';
66307
66308 /**
66309 * @license
66310 * Copyright 2021 Google LLC. All Rights Reserved.
66311 * Licensed under the Apache License, Version 2.0 (the "License");
66312 * you may not use this file except in compliance with the License.
66313 * You may obtain a copy of the License at
66314 *
66315 * http://www.apache.org/licenses/LICENSE-2.0
66316 *
66317 * Unless required by applicable law or agreed to in writing, software
66318 * distributed under the License is distributed on an "AS IS" BASIS,
66319 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
66320 * See the License for the specific language governing permissions and
66321 * limitations under the License.
66322 * =============================================================================
66323 */
/**
 * Executes a string-category graph node and returns its output tensors.
 * @param node graph node (op name plus parameter bindings)
 * @param tensorMap tensors for executed nodes and weights
 * @param context execution context for resolving node parameters
 * @param ops op registry to execute with (defaults to tfOps)
 * @throws TypeError when `node.op` is not a string op
 */
const executeOp$2 = (node, tensorMap, context, ops = tfOps) => {
    // Resolve one named parameter of `node` from the tensor map.
    const param = (name) => getParamValue(name, node, tensorMap, context);
    switch (node.op) {
        case 'StaticRegexReplace':
            return [ops.string.staticRegexReplace(param('input'), param('pattern'), param('rewrite'), param('replaceGlobal'))];
        case 'StringNGrams': {
            const result = ops.string.stringNGrams(param('data'), param('dataSplits'), param('separator'), param('nGramWidths'), param('leftPad'), param('rightPad'), param('padWidth'), param('preserveShortSequences'));
            return [result.nGrams, result.nGramsSplits];
        }
        case 'StringSplit': {
            const result = ops.string.stringSplit(param('input'), param('delimiter'), param('skipEmpty'));
            return [result.indices, result.values, result.shape];
        }
        case 'StringToHashBucketFast':
            return [ops.string.stringToHashBucketFast(param('input'), param('numBuckets'))];
        default:
            throw new TypeError(`Node type ${node.op} is not implemented`);
    }
};
const CATEGORY$1 = 'string';
66346
66347 /**
66348 * @license
66349 * Copyright 2018 Google LLC. All Rights Reserved.
66350 * Licensed under the Apache License, Version 2.0 (the "License");
66351 * you may not use this file except in compliance with the License.
66352 * You may obtain a copy of the License at
66353 *
66354 * http://www.apache.org/licenses/LICENSE-2.0
66355 *
66356 * Unless required by applicable law or agreed to in writing, software
66357 * distributed under the License is distributed on an "AS IS" BASIS,
66358 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
66359 * See the License for the specific language governing permissions and
66360 * limitations under the License.
66361 * =============================================================================
66362 */
/**
 * Executes a transformation-category graph node (cast, reshape, pad,
 * space/batch rearrangement, broadcast) and returns its output tensors.
 * @param node graph node (op name plus parameter bindings)
 * @param tensorMap tensors for executed nodes and weights
 * @param context execution context for resolving node parameters
 * @param ops op registry to execute with (defaults to tfOps)
 * @throws TypeError when `node.op` is not a transformation op
 */
const executeOp$1 = (node, tensorMap, context, ops = tfOps) => {
    // Resolve one named parameter of `node` from the tensor map.
    const param = (name) => getParamValue(name, node, tensorMap, context);
    switch (node.op) {
        case 'Cast':
            return [ops.cast(param('x'), param('dtype'))];
        case 'ExpandDims': {
            const axis = param('axis');
            return [ops.expandDims(param('x'), axis)];
        }
        case 'Squeeze': {
            const axis = param('axis');
            return [ops.squeeze(param('x'), axis)];
        }
        case 'Reshape':
            return [ops.reshape(param('x'), param('shape'))];
        case 'EnsureShape':
            return [ops.ensureShape(param('x'), param('shape'))];
        case 'MirrorPad':
            return [ops.mirrorPad(param('x'), param('padding'), param('mode'))];
        case 'PadV2':
        case 'Pad':
            return [ops.pad(param('x'), param('padding'), param('constantValue'))];
        case 'SpaceToBatchND': {
            const blockShape = param('blockShape');
            const paddings = param('paddings');
            return [ops.spaceToBatchND(param('x'), blockShape, paddings)];
        }
        case 'BatchToSpaceND': {
            const blockShape = param('blockShape');
            const crops = param('crops');
            return [ops.batchToSpaceND(param('x'), blockShape, crops)];
        }
        case 'DepthToSpace': {
            const blockSize = param('blockSize');
            // The op API expects the data format in upper case ('NHWC'/'NCHW').
            const dataFormat = param('dataFormat').toUpperCase();
            return [ops.depthToSpace(param('x'), blockSize, dataFormat)];
        }
        case 'BroadcastTo':
            return [ops.broadcastTo(param('x'), param('shape'))];
        case 'BroadcastArgs':
            return [ops.broadcastArgs(param('s0'), param('s1'))];
        default:
            throw new TypeError(`Node type ${node.op} is not implemented`);
    }
};
const CATEGORY = 'transformation';
66415
66416 /**
66417 * @license
66418 * Copyright 2018 Google LLC. All Rights Reserved.
66419 * Licensed under the Apache License, Version 2.0 (the "License");
66420 * you may not use this file except in compliance with the License.
66421 * You may obtain a copy of the License at
66422 *
66423 * http://www.apache.org/licenses/LICENSE-2.0
66424 *
66425 * Unless required by applicable law or agreed to in writing, software
66426 * distributed under the License is distributed on an "AS IS" BASIS,
66427 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
66428 * See the License for the specific language governing permissions and
66429 * limitations under the License.
66430 * =============================================================================
66431 */
66432 /**
66433 * Executes the op defined by the node object.
66434 * @param node
66435 * @param tensorMap contains tensors for executed nodes and weights
66436 * @param context contains tensors and information for running the current node.
66437 * @param resourceManager Optional. Contains global resources of the model.
66438 */
function executeOp(node, tensorMap, context, resourceManager, tidy$1 = tidy) {
    // Runs a synchronous category executor inside `tidy` so that any
    // intermediate tensors it creates are cleaned up automatically.
    const inTidy = (opExecutor) => tidy$1(() => opExecutor(node, tensorMap, context));
    const run = () => {
        switch (node.category) {
            case 'arithmetic':
                return inTidy(executeOp$k);
            case 'basic_math':
                return inTidy(executeOp$j);
            case 'control':
                // Control flow ops are not wrapped in tidy: they may be async.
                return executeOp$i(node, tensorMap, context);
            case 'convolution':
                return inTidy(executeOp$h);
            case 'creation':
                return inTidy(executeOp$g);
            case 'dynamic':
                // Dynamic ops are not wrapped in tidy: they may be async.
                return executeOp$f(node, tensorMap, context);
            case 'evaluation':
                return inTidy(executeOp$e);
            case 'image':
                return inTidy(executeOp$b);
            case 'graph':
                return inTidy(executeOp$d);
            case 'logical':
                return inTidy(executeOp$a);
            case 'matrices':
                return inTidy(executeOp$9);
            case 'normalization':
                return inTidy(executeOp$8);
            case 'ragged':
                return inTidy(executeOp$7);
            case 'reduction':
                return inTidy(executeOp$6);
            case 'slice_join':
                return inTidy(executeOp$5);
            case 'sparse':
                return inTidy(executeOp$4);
            case 'spectral':
                return inTidy(executeOp$3);
            case 'string':
                return inTidy(executeOp$2);
            case 'transformation':
                return inTidy(executeOp$1);
            case 'hash_table':
                // Hash table ops need access to the model-wide resource manager.
                return executeOp$c(node, tensorMap, context, resourceManager);
            case 'custom': {
                const opMapper = getRegisteredOp(node.op);
                if (opMapper && opMapper.customExecutor) {
                    return opMapper.customExecutor(new NodeValueImpl(node, tensorMap, context));
                }
                throw TypeError(`Custom op ${node.op} is not registered.`);
            }
            default:
                throw TypeError(`Unknown op '${node.op}'. File an issue at ` +
                    `https://github.com/tensorflow/tfjs/issues so we can add it` +
                    `, or register a custom execution with tf.registerOp()`);
        }
    };
    const value = run();
    // Normalize both sync and async results to (a promise of) a tensor array.
    if (isPromise(value)) {
        return value.then((data) => [].concat(data));
    }
    return [].concat(value);
}
66501
66502 /**
66503 * ExecutionContext captures the runtime environment of the node. It keeps
66504 * track of the current frame and iteration for the control flow ops.
66505 *
66506 * For example, typical Dynamic RNN model may contain loops, for which
66507 * TensorFlow will generate graphs with Enter/Exit nodes to control the
66508 * current execution frame, and NextIteration Nodes for iteration id increment.
66509 * For model with branch logic, TensorFLow will generate Switch/Merge ops.
66510 */
class ExecutionContext {
    constructor(weightMap = {}, tensorArrayMap = {}, tensorListMap = {}, functionMap = {}, parseNodeNameCache) {
        this.weightMap = weightMap;
        this.tensorArrayMap = tensorArrayMap;
        this.tensorListMap = tensorListMap;
        this.functionMap = functionMap;
        this.parseNodeNameCache = parseNodeNameCache;
        // The root frame: id 0, no name, iteration 0.
        this.rootContext = { id: 0, frameName: '', iterationId: 0 };
        this.contexts = [this.rootContext];
        this.lastId = 0;
        this.generateCurrentContextIds();
    }
    // Creates a fresh frame descriptor starting at iteration 0.
    newFrame(id, frameName) {
        return { id, frameName, iterationId: 0 };
    }
    /**
     * Set the current context.
     * @param contexts: ExecutionContextInfo[] the current path of execution
     * frames
     */
    set currentContext(contexts) {
        if (this.contexts === contexts) {
            return;
        }
        this.contexts = contexts;
        this.generateCurrentContextIds();
    }
    get currentContext() {
        return this.contexts;
    }
    /**
     * Returns the current context in string format.
     */
    get currentContextId() {
        return this._currentContextIds[0];
    }
    /**
     * Returns the current context and all parent contexts in string format.
     * This allows access to the nodes in the current and parent frames.
     */
    get currentContextIds() {
        return this._currentContextIds;
    }
    // Rebuilds the id list: deepest context first, root ('') last.
    generateCurrentContextIds() {
        const ids = [];
        for (let depth = this.contexts.length; depth >= 2; depth--) {
            ids.push(this.contextIdforContexts(this.contexts.slice(0, depth)));
        }
        ids.push('');
        this._currentContextIds = ids;
    }
    // Serializes a context path as '/'-joined 'frameName-iterationId' parts;
    // the root context serializes to the empty string.
    contextIdforContexts(contexts) {
        if (!contexts) {
            return '';
        }
        const parts = contexts.map((context) => (context.id === 0 && context.iterationId === 0) ?
            '' :
            `${context.frameName}-${context.iterationId}`);
        return parts.join('/');
    }
    /**
     * Enter a new frame, a new context is pushed on the current context list.
     * @param frameId new frame id
     */
    enterFrame(frameId) {
        if (!this.contexts) {
            return;
        }
        this.lastId++;
        this.contexts = [...this.contexts, this.newFrame(this.lastId, frameId)];
        this._currentContextIds.unshift(this.contextIdforContexts(this.contexts));
    }
    /**
     * Exit the current frame, the last context is removed from the current
     * context list.
     */
    exitFrame() {
        if (!this.contexts || this.contexts.length <= 1) {
            throw new Error('Cannot exit frame, the context is empty');
        }
        this.contexts = this.contexts.slice(0, -1);
        this.currentContextIds.shift();
    }
    /**
     * Enter the next iteration of a loop, the iteration id of last context is
     * increased.
     */
    nextIteration() {
        if (!this.contexts || this.contexts.length === 0) {
            throw new Error('Cannot increase frame iteration, the context is empty');
        }
        this.lastId++;
        // Replace the deepest frame with a copy carrying the next iteration.
        const frame = Object.assign({}, this.contexts[this.contexts.length - 1]);
        frame.iterationId += 1;
        frame.id = this.lastId;
        this.contexts = [...this.contexts.slice(0, -1), frame];
        this._currentContextIds.splice(0, 1, this.contextIdforContexts(this.contexts));
    }
    getWeight(name) {
        return this.weightMap[name];
    }
    addTensorArray(tensorArray) {
        this.tensorArrayMap[tensorArray.id] = tensorArray;
    }
    getTensorArray(id) {
        return this.tensorArrayMap[id];
    }
    addTensorList(tensorList) {
        this.tensorListMap[tensorList.id] = tensorList;
    }
    getTensorList(id) {
        return this.tensorListMap[id];
    }
    // Closes every tracked TensorArray/TensorList, sparing ids in keepIds.
    dispose(keepIds) {
        for (const tensorArray of Object.values(this.tensorArrayMap)) {
            tensorArray.clearAndClose(keepIds);
        }
        for (const tensorList of Object.values(this.tensorListMap)) {
            tensorList.clearAndClose(keepIds);
        }
    }
}
66639
66640 /**
66641 * @license
66642 * Copyright 2019 Google LLC. All Rights Reserved.
66643 * Licensed under the Apache License, Version 2.0 (the "License");
66644 * you may not use this file except in compliance with the License.
66645 * You may obtain a copy of the License at
66646 *
66647 * http://www.apache.org/licenses/LICENSE-2.0
66648 *
66649 * Unless required by applicable law or agreed to in writing, software
66650 * distributed under the License is distributed on an "AS IS" BASIS,
66651 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
66652 * See the License for the specific language governing permissions and
66653 * limitations under the License.
66654 * =============================================================================
66655 */
66656 /**
66657 * Given graph inputs and desired outputs, find the minimal set of nodes
66658 * to execute in order to compute the outputs. In addition return other useful
66659 * info such:
66660 * - Missing inputs needed to compute the output.
66661 * - Whether the subgraph contains dynamic ops (control flow, dynamic shape).
66662 * - Alternative inputs in order to avoid async (dynamic op) execution.
66663 */
function getExecutionSubgraph(inputs, outputs, weightMap, initNodes) {
    const usedNodes = new Set();
    const missingInputs = [];
    let dynamicNode = null;
    let syncInputs = null;
    // Names already pushed onto the frontier, to avoid revisiting nodes.
    const seen = new Set();
    const inputNodeNames = new Set(Object.keys(inputs).map((name) => parseNodeName(name)[0]));
    const initNodeNames = new Set((initNodes || []).map((node) => parseNodeName(node.name)[0]));
    // Walk backwards from the outputs and collect every node needed to
    // compute them.
    const frontier = [...outputs];
    while (frontier.length > 0) {
        const node = frontier.pop();
        if (isControlFlow(node) || isDynamicShape(node) || isHashTable(node)) {
            // Record the first dynamic node found; its already-visited
            // children are the alternative inputs for sync execution.
            if (dynamicNode == null) {
                dynamicNode = node;
                syncInputs = dynamicNode.children.map((child) => child.name)
                    .filter((name) => usedNodes.has(name));
            }
        }
        usedNodes.add(node.name);
        // Dead ends: the value is already available as a weight, a
        // user-provided input, or an init node.
        if (weightMap[node.name] != null || inputNodeNames.has(node.name) ||
            initNodeNames.has(node.name)) {
            continue;
        }
        if (node.inputs.length === 0) {
            // Not a weight/input/init node and nothing produces it.
            missingInputs.push(node.name);
            continue;
        }
        for (const input of node.inputs) {
            if (!seen.has(input.name)) {
                seen.add(input.name);
                frontier.push(input);
            }
        }
    }
    return { inputs, outputs, usedNodes, missingInputs, dynamicNode, syncInputs };
}
66713 /**
66714 * Given the execution info, return a list of nodes in topological order that
66715 * need to be executed to compute the output.
66716 */
function getNodesInTopologicalOrder(graph, executionInfo) {
    const { usedNodes, inputs } = executionInfo;
    const isUsed = (node) => usedNodes.has(typeof node === 'string' ? node : node.name);
    // De-duplicates nodes by name, keeping the last occurrence of each name.
    const dedupe = (nodes) => [...new Map(nodes.map((node) => [node.name, node])).values()];
    const inputNodes = Object.keys(inputs)
        .map((name) => graph.nodes[parseNodeName(name)[0]]);
    const predefinedNodes = dedupe([
        ...inputNodes,
        ...graph.weights,
        ...(graph.initNodes || []),
    ]).filter(isUsed);
    const allNodes = dedupe([
        ...predefinedNodes,
        ...Object.values(graph.nodes),
    ]).filter(isUsed);
    const nameToNode = new Map(allNodes.map((node) => [node.name, node]));
    // In-degree of each used node. Unused children get Infinity so they can
    // never be decreased to 0 and added to the execution list.
    const inCounts = {};
    for (const node of allNodes) {
        inCounts[node.name] = inCounts[node.name] || 0;
        for (const child of node.children) {
            if (!isUsed(child)) {
                inCounts[child.name] = Number.POSITIVE_INFINITY;
            }
            inCounts[child.name] = (inCounts[child.name] || 0) + 1;
        }
    }
    // Kahn's algorithm: start from nodes with in-degree 0 and release
    // children as their last parent is scheduled.
    const frontier = Object.entries(inCounts)
        .filter(([, inCount]) => inCount === 0)
        .map(([name]) => name);
    const orderedNodeNames = [...frontier];
    while (frontier.length > 0) {
        const node = nameToNode.get(frontier.pop());
        for (const child of node.children.filter(isUsed)) {
            if (--inCounts[child.name] === 0) {
                orderedNodeNames.push(child.name);
                frontier.push(child.name);
            }
        }
    }
    const orderedNodes = orderedNodeNames.map((name) => nameToNode.get(name));
    const filteredOrderedNodes = filterPredefinedReachableNodes(orderedNodes, predefinedNodes);
    // TODO: Turn validation on/off with tf env flag.
    validateNodesExecutionOrder(filteredOrderedNodes, predefinedNodes);
    return filteredOrderedNodes;
}
66771 /**
66772 * This is a helper function of `getNodesInTopologicalOrder`.
66773 * Returns ordered nodes reachable by at least one predefined node.
66774 * This can help us filter out redundant nodes from the returned node list.
66775 * For example:
66776 * If we have four nodes with dependencies like this:
66777 * a --> b --> c --> d
66778 * when node `c` is predefined (e.g. given as an input tensor), we can
66779 * skip node `a` and `b` since their outputs will never be used.
66780 *
66781 * @param orderedNodes Graph nodes in execution order.
66782 * @param predefinedNodes Graph inputs, weights, and init nodes. Nodes in this
66783 * list must have distinct names.
66784 */
function filterPredefinedReachableNodes(orderedNodes, predefinedNodes) {
    const nameToNode = new Map(orderedNodes.map((node) => [node.name, node]));
    // TODO: Filter out more nodes when >=2 nodes are predefined in a path.
    // DFS from all predefined nodes to find everything reachable from them.
    const reachableNames = new Set(predefinedNodes.map((node) => node.name));
    const stack = [...reachableNames];
    while (stack.length > 0) {
        const node = nameToNode.get(stack.pop());
        for (const child of node.children) {
            // Skip children that are not scheduled or already discovered.
            if (nameToNode.has(child.name) && !reachableNames.has(child.name)) {
                reachableNames.add(child.name);
                stack.push(child.name);
            }
        }
    }
    // Keep only reachable nodes, preserving the original execution order.
    return orderedNodes.filter((node) => reachableNames.has(node.name));
}
/**
 * Error thrown when the computed node execution order violates the
 * parent/child ordering constraints (see `validateNodesExecutionOrder`).
 */
class NodesExecutionOrderError extends Error {
    constructor(message) {
        super(`NodesExecutionOrderError: ${message}`);
        // Set the name so logs and handlers identify this error type instead
        // of the generic 'Error'.
        this.name = 'NodesExecutionOrderError';
    }
}
66813 /**
66814 * This is a helper function of `getNodesInTopologicalOrder`.
66815 * Validates property: given nodes `a` and `b`, Order(a) > Order(b) if `a`
66816 * is a child of `b`. This function throws an error if validation fails.
66817 *
66818 * @param orderedNodes Graph nodes in execution order.
66819 * @param predefinedNodes Graph inputs, weights, and init nodes. Nodes in this
66820 * list must have distinct names.
66821 */
function validateNodesExecutionOrder(orderedNodes, predefinedNodes) {
    const orderOf = new Map(orderedNodes.map((node, order) => [node.name, order]));
    const predefinedNodeNames = new Set(predefinedNodes.map((node) => node.name));
    const isPredefined = (node) => predefinedNodeNames.has(typeof node === 'string' ? node : node.name);
    const scheduledNames = new Set(orderedNodes.map((node) => node.name));
    const willBeExecuted = (node) => scheduledNames.has(typeof node === 'string' ? node : node.name);
    for (const node of orderedNodes) {
        // Every scheduled child must run at or after this node.
        for (const child of node.children.filter(willBeExecuted)) {
            if (!orderOf.has(child.name)) {
                throw new NodesExecutionOrderError(`Child ${child.name} of node ${node.name} is unreachable.`);
            }
            if (orderOf.get(node.name) > orderOf.get(child.name)) {
                throw new NodesExecutionOrderError(`Node ${node.name} is scheduled to run after its child ${child.name}.`);
            }
        }
        // Predefined nodes (inputs/weights/init) need no input checks.
        if (isPredefined(node)) {
            continue;
        }
        // Every input of a non-predefined node must run at or before it.
        for (const input of node.inputs) {
            if (!orderOf.has(input.name)) {
                throw new NodesExecutionOrderError(`Input ${input.name} of node ${node.name} is unreachable.`);
            }
            if (orderOf.get(input.name) > orderOf.get(node.name)) {
                throw new NodesExecutionOrderError(`Node ${node.name} is scheduled to run before its input ${input.name}.`);
            }
        }
    }
}
66849 /**
66850 * Given the execution info, return a map from node name to the disposable
66851 * node name list after its execution.
66852 *
66853 * @returns A map from node name to disposable nodes after its
66854 * execution. That is, for a node `x`, `nodeLiveUntilMap[x]` indicates
66855 * all nodes which their intermediate tensors should be disposed after `x`
66856 * being executed.
66857 */
function getNodeLiveUntilMap(orderedNodes) {
    const orderOfNode = new Map(orderedNodes.map((node, order) => [node.name, order]));
    const INF_LIFE = Number.MAX_SAFE_INTEGER;
    // Make control flow nodes (and consequently their direct parents)
    // live forever since they're tricky to track correctly.
    const selfLifespans = orderedNodes.map((node, nodeOrder) => isControlFlow(node) ? INF_LIFE : nodeOrder);
    const getSelfLifeSpan = (node) => {
        const selfLife = selfLifespans[orderOfNode.get(node.name)];
        // Nodes missing from the order are unused/unreachable in the graph.
        return selfLife == null ? -1 : selfLife;
    };
    // `liveUntilOrders[i]` is the position of the last node that may still
    // depend on tensors from `orderedNodes[i]`: at least the node's own
    // position, extended by each child's lifespan.
    const liveUntilOrders = orderedNodes.map((node, nodeOrder) => node.children.reduce((latest, child) => Math.max(latest, getSelfLifeSpan(child)), selfLifespans[nodeOrder]));
    // liveUntilMap:
    // - Key: Name of a node `x`
    // - Values: All nodes whose intermediate tensors should be disposed
    //           after `x` is executed.
    const liveUntilMap = new Map();
    liveUntilOrders.forEach((liveUntilOrder, nodeOrder) => {
        if (liveUntilOrder === INF_LIFE) {
            return;
        }
        const ownerName = orderedNodes[liveUntilOrder].name;
        if (!liveUntilMap.has(ownerName)) {
            liveUntilMap.set(ownerName, []);
        }
        liveUntilMap.get(ownerName).push(orderedNodes[nodeOrder]);
    });
    return liveUntilMap;
}
// Ops treated as control flow when analyzing the graph.
const CONTROL_FLOW_OPS = new Set([
    'Switch',
    'Merge',
    'Enter',
    'Exit',
    'NextIteration',
    'StatelessIf',
    'StatelessWhile',
    'if',
    'While',
]);
// Ops whose output shape is only known at runtime.
const DYNAMIC_SHAPE_OPS = new Set([
    'NonMaxSuppressionV2',
    'NonMaxSuppressionV3',
    'NonMaxSuppressionV5',
    'Where',
]);
// Ops that create or operate on hash table resources.
const HASH_TABLE_OPS = new Set([
    'HashTable',
    'HashTableV2',
    'LookupTableImport',
    'LookupTableImportV2',
    'LookupTableFind',
    'LookupTableFindV2',
    'LookupTableSize',
    'LookupTableSizeV2',
]);
/** Returns true if the node's op is a control flow op. */
function isControlFlow(node) {
    return CONTROL_FLOW_OPS.has(node.op);
}
/** Returns true if the node's op has a dynamically-shaped output. */
function isDynamicShape(node) {
    return DYNAMIC_SHAPE_OPS.has(node.op);
}
/** Returns true if the node's op is a hash table op. */
function isHashTable(node) {
    return HASH_TABLE_OPS.has(node.op);
}
66922
66923 /**
66924 * @license
66925 * Copyright 2018 Google LLC. All Rights Reserved.
66926 * Licensed under the Apache License, Version 2.0 (the "License");
66927 * you may not use this file except in compliance with the License.
66928 * You may obtain a copy of the License at
66929 *
66930 * http://www.apache.org/licenses/LICENSE-2.0
66931 *
66932 * Unless required by applicable law or agreed to in writing, software
66933 * distributed under the License is distributed on an "AS IS" BASIS,
66934 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
66935 * See the License for the specific language governing permissions and
66936 * limitations under the License.
66937 * =============================================================================
66938 */
66939 class GraphExecutor {
66940 get weightIds() {
66941 return this.parent ? this.parent.weightIds : this._weightIds;
66942 }
66943 get functionExecutorMap() {
66944 return this.parent ? this.parent.functionExecutorMap :
66945 this._functionExecutorMap;
66946 }
66947 get weightMap() {
66948 return this.parent ? this.parent.weightMap : this._weightMap;
66949 }
66950 set weightMap(weightMap) {
66951 const weightIds = Object.keys(weightMap).map(key => weightMap[key].map(tensor => tensor.id));
66952 this._weightIds = [].concat(...weightIds);
66953 this._weightMap = weightMap;
66954 }
66955 /**
66956 * Set `ResourceManager` shared by executors of a model.
66957 * @param resourceManager: `ResourceManager` of the `GraphModel`.
66958 */
    set resourceManager(resourceManager) {
        // Stored so it can be passed to executeOp during execution (used by
        // hash table ops).
        this._resourceManager = resourceManager;
    }
66962 get inputs() {
66963 return this._inputs.map(node => {
66964 return {
66965 name: node.name,
66966 shape: node.attrParams['shape'] ?
66967 node.attrParams['shape'].value :
66968 undefined,
66969 dtype: node.attrParams['dtype'] ?
66970 node.attrParams['dtype'].value :
66971 undefined
66972 };
66973 });
66974 }
66975 get outputs() {
66976 return this._outputs.map(node => {
66977 return {
66978 name: node.name,
66979 shape: node.attrParams['shape'] ?
66980 node.attrParams['shape'].value :
66981 undefined,
66982 dtype: node.attrParams['dtype'] ?
66983 node.attrParams['dtype'].value :
66984 undefined
66985 };
66986 });
66987 }
66988 get inputNodes() {
66989 return this._inputs.map(node => node.signatureKey || node.name);
66990 }
66991 get outputNodes() {
66992 return this._outputs.map((node) => {
66993 const name = node.signatureKey || node.name;
66994 return node.defaultOutput ? (`${name}:${node.defaultOutput}`) : name;
66995 });
66996 }
66997 get functions() {
66998 return Object.keys(this._functions).reduce((map, key) => {
66999 map[key] = this._functions[key].signature;
67000 return map;
67001 }, {});
67002 }
67003 /**
67004 *
67005 * @param graph Graph the model or function graph to be executed.
67006 * @param parent When building function exector you need to set the parent
67007 * executor. Since the weights and function executor maps are set at parant
67008 * level, that function executor can access the function maps and weight maps
67009 * through the parent.
67010 */
    constructor(graph, parent) {
        this.graph = graph;
        this.parent = parent;
        // Cache of compiled execution plans, keyed by sorted input/output
        // node names (see getCompilationKey).
        this.compiledMap = new Map();
        // Cache shared with ExecutionContext for parsed node names.
        this.parseNodeNameCache = new Map();
        this._weightMap = {};
        // Separator used when building compilation cache keys.
        this.SEPARATOR = ',';
        this._functions = {};
        this._functionExecutorMap = {};
        // Refreshed from the KEEP_INTERMEDIATE_TENSORS env flag on each run.
        this.keepIntermediateTensors = false;
        this._outputs = graph.outputs;
        this._inputs = graph.inputs;
        this._initNodes = graph.initNodes;
        this._signature = graph.signature;
        this._functions = graph.functions;
        // create sub-graph executors
        if (graph.functions != null) {
            Object.keys(graph.functions).forEach(name => {
                this._functionExecutorMap[name] =
                    new GraphExecutor(graph.functions[name], this);
            });
        }
    }
67034 getCompilationKey(inputs, outputs) {
67035 const sortedInputs = inputs.map(node => node.name).sort();
67036 const sortedOutputs = outputs.map(node => node.name).sort();
67037 return sortedInputs.join(this.SEPARATOR) + '--' +
67038 sortedOutputs.join(this.SEPARATOR);
67039 }
67040 /**
67041 * Compiles the inference graph and returns the minimal set of nodes that are
67042 * required for execution, in the correct execution order.
67043 * @returns {Object} compilation The compile result.
67044 * @returns {Node[]} compilation.orderedNodes Nodes in the correct execution
67045 * order.
67046 * @returns {Map<string, Node[]>} compilation.nodeLiveUntilMap A map from node
67047 * to disposable nodes after its execution. That is, for a node `x`,
67048 * `nodeLiveUntilMap[x]` indicates all nodes whose intermediate
67049 * tensors should be disposed after `x` is executed.
67050 */
    compile(inputs, outputs) {
        const executionInfo = getExecutionSubgraph(inputs, outputs, this.weightMap, this._initNodes);
        const { missingInputs, dynamicNode, syncInputs } = executionInfo;
        // Dynamic ops (control flow, dynamic shapes, hash tables) cannot be
        // executed synchronously; direct the caller to executeAsync().
        if (dynamicNode != null) {
            throw new Error(`This execution contains the node '${dynamicNode.name}', which has ` +
                `the dynamic op '${dynamicNode.op}'. Please use ` +
                `model.executeAsync() instead. Alternatively, to avoid the ` +
                `dynamic ops, specify the inputs [${syncInputs}]`);
        }
        // Every node in the subgraph must be computable from the provided
        // inputs, the weights, or the init nodes.
        if (missingInputs.length > 0) {
            const outNames = outputs.map(n => n.name);
            const inNames = Object.keys(inputs);
            throw new Error(`Cannot compute the outputs [${outNames}] from the provided inputs ` +
                `[${inNames}]. Missing the following inputs: [${missingInputs}]`);
        }
        const orderedNodes = getNodesInTopologicalOrder(this.graph, executionInfo);
        // Precompute, per node, which nodes' tensors can be disposed after it.
        const nodeLiveUntilMap = getNodeLiveUntilMap(orderedNodes);
        return { orderedNodes, nodeLiveUntilMap };
    }
67070 cloneAndKeepTensor(tensor) {
67071 if (tensor == null) {
67072 return null;
67073 }
67074 const clone = tensor.clone();
67075 // Keep the clone because`model.execute()` may be called within
67076 // a `tidy()`, but the user may inspect these tensors after the
67077 // tidy.
67078 keep(clone);
67079 return clone;
67080 }
67081 cloneTensorList(tensors) {
67082 if (!tensors) {
67083 return null;
67084 }
67085 const clonedTensor = tensors.map(tensor => {
67086 return this.cloneAndKeepTensor(tensor);
67087 });
67088 return clonedTensor;
67089 }
67090 cloneTensorMap(tensorsMap) {
67091 return Object.fromEntries(Object.entries(tensorsMap).map(([name, tensorsList]) => {
67092 return [name, this.cloneTensorList(tensorsList)];
67093 }));
67094 }
67095 /**
67096 * Executes the inference for given input tensors.
67097 * @param inputs Tensor map for the model inputs, keyed by the input node
67098 * names.
67099 * @param outputs Optional. output node name from the Tensorflow model, if
67100 * no outputs are specified, the default outputs of the model would be used.
67101 * You can inspect intermediate nodes of the model by adding them to the
67102 * outputs array.
67103 */
    execute(inputs, outputs) {
        // Dispose any tensors from a prior run to avoid leaking them.
        this.disposeIntermediateTensors();
        inputs = this.mapInputs(inputs);
        const names = Object.keys(inputs).sort();
        this.checkInputs(inputs);
        this.checkInputShapeAndType(inputs);
        outputs = this.mapOutputs(outputs);
        this.checkOutputs(outputs);
        const inputNodes = names.map(name => this.graph.nodes[parseNodeName(name)[0]]);
        const outputNodeNames = outputs.map(name => parseNodeName(name)[0]);
        const outputNodeNameSet = new Set(outputNodeNames);
        let outputNodes = outputNodeNames.map(name => this.graph.nodes[name]);
        // If no outputs are specified, then use the default outputs of the model.
        if (outputNodes.length === 0) {
            outputNodes = this._outputs;
        }
        const compilationKey = this.getCompilationKey(inputNodes, outputNodes);
        // Do nothing if the compiled graph cache contains the input.
        let compilation = this.compiledMap.get(compilationKey);
        if (compilation == null) {
            compilation = this.compile(inputs, outputNodes);
            this.compiledMap.set(compilationKey, compilation);
        }
        // Keep tensors if KEEP_INTERMEDIATE_TENSORS is on.
        try {
            this.keepIntermediateTensors = env().getBool('KEEP_INTERMEDIATE_TENSORS');
        }
        catch (e) {
            // Best effort: if the flag cannot be read, don't keep tensors.
            this.keepIntermediateTensors = false;
            console.warn(e.message);
        }
        const tensorArrayMap = {};
        const tensorListMap = {};
        return tidy(() => {
            const context = new ExecutionContext(this.weightMap, tensorArrayMap, tensorListMap, this.functionExecutorMap, this.parseNodeNameCache);
            // Start from the weights; executed nodes add their outputs here.
            const tensorsMap = Object.assign({}, this.weightMap);
            if (this.keepIntermediateTensors) {
                this.clonedTensorsMap = this.cloneTensorMap(this.weightMap);
            }
            // Seed the tensor map with the user-provided input tensors.
            Object.keys(inputs).forEach(name => {
                const [nodeName, index] = parseNodeName(name, context);
                const tensors = [];
                tensors[index] = inputs[name];
                tensorsMap[nodeName] = tensors;
                if (this.keepIntermediateTensors) {
                    this.clonedTensorsMap[nodeName] = this.cloneTensorList(tensors);
                }
            });
            // Inputs and weights must never be disposed during execution.
            const tensorsToKeep = this.getFrozenTensorIds(tensorsMap);
            const { orderedNodes, nodeLiveUntilMap } = compilation;
            for (const node of orderedNodes) {
                // Skip nodes whose outputs are already available (inputs/weights).
                if (tensorsMap[node.name]) {
                    continue;
                }
                const tensors = executeOp(node, tensorsMap, context, this._resourceManager);
                // Synchronous execution cannot handle async (dynamic) ops.
                if (isPromise(tensors)) {
                    throw new Error(`The execution of the op '${node.op}' returned a promise. ` +
                        `Please use model.executeAsync() instead.`);
                }
                tensorsMap[node.name] = tensors;
                if (this.keepIntermediateTensors) {
                    this.clonedTensorsMap[node.name] = this.cloneTensorList(tensors);
                }
                // Dispose intermediate tensors that are no longer needed past
                // this node, per the precomputed liveness info.
                this.checkTensorForDisposalWithNodeLiveUntilInfo(node, tensorsMap, context, tensorsToKeep, outputNodeNameSet, nodeLiveUntilMap.get(node.name));
            }
            // dispose the context for the root executor
            if (this.parent == null) {
                context.dispose(tensorsToKeep);
            }
            return outputs.map(name => getTensor(name, tensorsMap, context));
        });
    }
67177 getFrozenTensorIds(tensorMap) {
67178 const ids = [].concat.apply([], Object.keys(tensorMap)
67179 .map(key => tensorMap[key])
67180 .map(tensors => tensors.map(tensor => tensor.id)));
67181 return new Set(ids);
67182 }
    /**
     * Reference-counting disposal: bumps the consumer count of this node's
     * output tensors by its number of children, then releases this node's
     * input tensors whose remaining consumer count reaches zero.
     */
    checkTensorForDisposal(nodeName, node, tensorMap, context, tensorsToKeep, outputNodeNameSet, intermediateTensorConsumerCount) {
        // Skip output nodes and any control flow nodes, since its dependency is
        // tricky to track correctly.
        if (isControlFlow(node) || outputNodeNameSet.has(nodeName)) {
            return;
        }
        for (const tensor of tensorMap[nodeName]) {
            if (tensor == null) {
                continue;
            }
            // Each child of this node will consume the tensor once.
            intermediateTensorConsumerCount[tensor.id] =
                (intermediateTensorConsumerCount[tensor.id] || 0) +
                node.children.length;
        }
        for (const input of node.inputs) {
            // Skip any control flow nodes, since its dependency is tricky to track
            // correctly.
            if (isControlFlow(input)) {
                continue;
            }
            const tensors = getTensorsForCurrentContext(input.name, tensorMap, context);
            if (tensors == null) {
                continue;
            }
            for (const tensor of tensors) {
                if (!tensor || tensor.kept || tensorsToKeep.has(tensor.id)) {
                    continue;
                }
                // Only intermediate nodes' tensors have counts set, not marked as
                // kept, and not in `tensorsToKeep`.
                // Input and weight nodes' tensors should exist in `tensorsToKeep`.
                // Output and control flow nodes' tensors should never have count set.
                const count = intermediateTensorConsumerCount[tensor.id];
                if (count === 1) {
                    // Last consumer: free the tensor and drop its count entry.
                    tensor.dispose();
                    delete intermediateTensorConsumerCount[tensor.id];
                }
                else if (count != null) {
                    intermediateTensorConsumerCount[tensor.id]--;
                }
            }
        }
    }
/**
 * Disposes the tensors listed in `liveUntilNodes` — nodes whose outputs'
 * lifetimes (precomputed at compile time) end once `node` has executed.
 * Used on the synchronous execution path instead of the reference-counting
 * scheme in `checkTensorForDisposal`.
 *
 * @param node The node that just executed.
 * @param liveUntilNodes Nodes whose output tensors die at this point, or
 *     undefined when nothing is scheduled to die here.
 */
checkTensorForDisposalWithNodeLiveUntilInfo(node, tensorMap, context, tensorsToKeep, outputNodeNameSet, liveUntilNodes) {
    function isNonDisposableNode(node) {
        // Skip output nodes and any control flow nodes, since its dependency is
        // tricky to track correctly.
        return isControlFlow(node) || outputNodeNameSet.has(node.name);
    }
    if (isControlFlow(node) || liveUntilNodes == null) {
        return;
    }
    for (const nodeToDispose of liveUntilNodes) {
        if (isNonDisposableNode(nodeToDispose)) {
            continue;
        }
        const tensors = getTensorsForCurrentContext(nodeToDispose.name, tensorMap, context);
        for (const tensor of tensors) {
            // Kept tensors and frozen tensors (inputs/weights) are never freed.
            if (!tensor || tensor.kept || tensorsToKeep.has(tensor.id)) {
                continue;
            }
            tensor.dispose();
        }
    }
}
67248 /**
67249 * Executes the inference for given input tensors in Async fashion.
67250 * @param inputs Tensor map for the model inputs, keyed by the input node
67251 * names.
67252 * @param outputs output node name from the Tensorflow model, if no outputs
67253 * are specified, the default outputs of the model would be used. You can
67254 * inspect intermediate nodes of the model by adding them to the outputs
67255 * array.
67256 */
67257 async executeAsync(inputs, outputs) {
67258 return this._executeAsync(inputs, outputs);
67259 }
67260 disposeIntermediateTensors() {
67261 if (!this.clonedTensorsMap) {
67262 return;
67263 }
67264 Object.values(this.clonedTensorsMap).forEach(tensorsList => {
67265 for (const tensor of tensorsList) {
67266 if (tensor && !tensor.isDisposed) {
67267 tensor.dispose();
67268 }
67269 }
67270 });
67271 this.clonedTensorsMap = null;
67272 }
/**
 * Returns the map of cloned intermediate tensors kept from the last run,
 * or null when none were kept (or they have since been disposed via
 * disposeIntermediateTensors).
 */
getIntermediateTensors() {
    return this.clonedTensorsMap;
}
67276 /**
67277 * Executes the inference for given input tensors in Async fashion.
67278 * @param inputs Tensor map for the model inputs, keyed by the input node
67279 * names.
67280 * @param outputs Optional. output node name from the Tensorflow model,
67281 * if no outputs are specified, the default outputs of the model would be
67282 * used. You can inspect intermediate nodes of the model by adding them to
67283 * the outputs array.
67284 * @param isFunctionExecution Optional. Flag for executing a function.
67285 * @param tensorArrayMap Optional, global TensorArray map by id. Used for
67286 * function execution.
67287 * @param tensorArrayMap Optional global TensorList map by id. Used for
67288 * function execution.
67289 */
async _executeAsync(inputs, outputs, isFunctionExecution = false, tensorArrayMap = {}, tensorListMap = {}) {
    // Dispose any tensors from a prior run to avoid leaking them.
    this.disposeIntermediateTensors();
    if (!isFunctionExecution) {
        // Top-level call: resolve names through the signature and validate
        // the provided inputs/outputs against the graph before executing.
        inputs = this.mapInputs(inputs);
        this.checkInputs(inputs);
        this.checkInputShapeAndType(inputs);
        outputs = this.mapOutputs(outputs);
        this.checkOutputs(outputs);
    }
    // Keep tensors if KEEP_INTERMEDIATE_TENSORS is on.
    try {
        this.keepIntermediateTensors = env().getBool('KEEP_INTERMEDIATE_TENSORS');
    }
    catch (e) {
        // Flag lookup can throw (e.g. flag not registered); default to off.
        this.keepIntermediateTensors = false;
        console.warn(e.message);
    }
    const context = new ExecutionContext(this.weightMap, tensorArrayMap, tensorListMap, this.functionExecutorMap, this.parseNodeNameCache);
    if (this.keepIntermediateTensors) {
        // Clone the weights as well so the retained map is self-contained.
        this.clonedTensorsMap = this.cloneTensorMap(this.weightMap);
    }
    // Graph with control flow op requires runtime evaluation of the execution
    // order, while without control flow the execution order is pre-determined
    // in the compile method.
    const tensorsMap = await this.executeWithControlFlow(inputs, context, outputs, isFunctionExecution);
    const results = outputs.map(name => getTensor(name, tensorsMap, context));
    // dispose all the intermediate tensors
    const outputIds = results.map(t => t.id);
    const inputIds = Object.keys(inputs).map(name => inputs[name].id);
    // Outputs, inputs, and weights must survive; everything else is freed.
    const keepIds = new Set([...outputIds, ...inputIds, ...this.weightIds]);
    Object.values(tensorsMap).forEach(tensorsList => {
        tensorsList.forEach(tensor => {
            if (tensor && !tensor.isDisposed && !keepIds.has(tensor.id)) {
                tensor.dispose();
            }
        });
    });
    // dispose the context for the root executor
    if (this.parent == null) {
        context.dispose(keepIds);
    }
    return results;
}
67334 async executeFunctionAsync(inputs, tensorArrayMap, tensorListMap) {
67335 const mappedInputs = inputs.reduce((map, tensor, index) => {
67336 map[this.inputs[index].name] = tensor;
67337 return map;
67338 }, {});
67339 return this._executeAsync(mappedInputs, this.outputNodes, true, tensorArrayMap, tensorListMap);
67340 }
67341 /**
67342 * When there are control flow nodes in the graph, the graph execution use
67343 * ExecutionContext to keep track of the frames and loop iterators.
67344 * @param inputs placeholder tensors for the graph.
67345 * @param context the execution context object for current execution.
67346 * @param outputNames Optional. output node name from the Tensorflow model,
67347 * if no outputs are specified, the default outputs of the model would be
67348 * used. You can inspect intermediate nodes of the model by adding them to
67349 * the outputs array.
67350 * @param isFunctionExecution Flag for executing a function.
67351 */
async executeWithControlFlow(inputs, context, outputNames, isFunctionExecution) {
    const names = Object.keys(inputs);
    const inputNodes = names.map(name => this.graph.nodes[parseNodeName(name)[0]]);
    const outputNodeNames = outputNames.map(name => parseNodeName(name)[0]);
    const outputNodeNameSet = new Set(outputNodeNames);
    let outputNodes = outputNodeNames.map(name => this.graph.nodes[name]);
    // If no outputs are specified, then use the default outputs of the model.
    if (outputNodes.length === 0) {
        outputNodes = this._outputs;
    }
    const { usedNodes, missingInputs, dynamicNode, syncInputs } = getExecutionSubgraph(inputs, outputNodes, this.weightMap, this._initNodes);
    // First nodes to execute include inputNodes, weights, and initNodes.
    const stack = [
        ...inputNodes, ...this.graph.weights, ...(this._initNodes || [])
    ].map(node => {
        return { node, contexts: context.currentContext };
    });
    // Seed the tensor map with the weights, then overlay the feed tensors at
    // the (nodeName, outputIndex) slots parsed from each input key.
    const tensorsMap = Object.assign({}, this.weightMap);
    Object.keys(inputs).forEach(name => {
        const [nodeName, index] = parseNodeName(name);
        const tensors = [];
        tensors[index] = inputs[name];
        tensorsMap[nodeName] = tensors;
    });
    // Per-tensor pending-consumer counts, used to dispose intermediates as
    // soon as their last consumer has executed.
    const intermediateTensorConsumerCount = {};
    const tensorsToKeep = this.getFrozenTensorIds(tensorsMap);
    const added = {};
    // Drain the stack repeatedly: each pass may start async ops whose
    // results must resolve before more nodes become ready.
    while (stack.length > 0) {
        const promises = this.processStack(inputNodes, stack, context, tensorsMap, added, tensorsToKeep, outputNodeNameSet, intermediateTensorConsumerCount, usedNodes);
        await Promise.all(promises);
    }
    if (dynamicNode == null && !isFunctionExecution) {
        console.warn(`This model execution did not contain any nodes with control flow ` +
            `or dynamic output shapes. You can use model.execute() instead.`);
    }
    // Verify every requested (non-control-flow) output was actually produced.
    const missingOutputs = outputNodes
        .filter(node => !isControlFlow(node) &&
        !getTensor(node.name, tensorsMap, context))
        .map(node => node.name);
    if (missingOutputs.length > 0) {
        let alternativeMsg = '';
        if (dynamicNode != null) {
            alternativeMsg =
                `Alternatively, to avoid the dynamic ops, use model.execute() ` +
                `and specify the inputs [${syncInputs}]`;
        }
        throw new Error(`Cannot compute the outputs [${missingOutputs}] from the provided ` +
            `inputs [${names}]. Consider providing the following inputs: ` +
            `[${missingInputs}]. ${alternativeMsg}`);
    }
    return tensorsMap;
}
/**
 * Drains `stack`, executing each ready node. Synchronous op results are
 * recorded immediately; asynchronous ops push a promise onto the returned
 * array and record their results (and schedule children) on resolution.
 * @returns Promises for the asynchronous ops started during this pass.
 */
processStack(inputNodes, stack, context, tensorMap, added, tensorsToKeep, outputNodeNameSet, intermediateTensorConsumerCount, usedNodes) {
    const promises = [];
    while (stack.length > 0) {
        const item = stack.pop();
        context.currentContext = item.contexts;
        let nodeName = '';
        // The tensor of the Enter op with isConstant set should be set
        // in the parent scope, so it will be available as constant for the
        // whole loop.
        if (item.node.op === 'Enter' &&
            getParamValue('isConstant', item.node, tensorMap, context)) {
            [nodeName] = getNodeNameAndIndex(item.node.name, context);
        }
        // only process nodes that are not in the tensorMap yet, this include
        // inputNodes and internal initNodes.
        if (tensorMap[item.node.name] == null) {
            const tensors = executeOp(item.node, tensorMap, context, this._resourceManager);
            if (!nodeName) {
                [nodeName] = getNodeNameAndIndex(item.node.name, context);
            }
            // Capture the context now: it can change before an async op
            // resolves, and the continuation must restore it.
            const currentContext = context.currentContext;
            if (isPromise(tensors)) {
                promises.push(tensors.then(t => {
                    tensorMap[nodeName] = t;
                    if (this.keepIntermediateTensors) {
                        this.clonedTensorsMap[nodeName] = this.cloneTensorList(t);
                    }
                    context.currentContext = currentContext;
                    this.checkTensorForDisposal(nodeName, item.node, tensorMap, context, tensorsToKeep, outputNodeNameSet, intermediateTensorConsumerCount);
                    this.processChildNodes(item.node, stack, context, tensorMap, added, usedNodes);
                    return t;
                }));
            }
            else {
                tensorMap[nodeName] = tensors;
                if (this.keepIntermediateTensors) {
                    this.clonedTensorsMap[nodeName] = this.cloneTensorList(tensors);
                }
                this.checkTensorForDisposal(nodeName, item.node, tensorMap, context, tensorsToKeep, outputNodeNameSet, intermediateTensorConsumerCount);
                this.processChildNodes(item.node, stack, context, tensorMap, added, usedNodes);
            }
        }
        else {
            // Tensor already available (input/weight node): just make sure
            // its children get scheduled.
            this.processChildNodes(item.node, stack, context, tensorMap, added, usedNodes);
        }
    }
    return promises;
}
67452 processChildNodes(node, stack, context, tensorMap, added, usedNodes) {
67453 node.children.forEach((childNode) => {
67454 const [nodeName,] = getNodeNameAndIndex(childNode.name, context);
67455 if (added[nodeName] || !usedNodes.has(childNode.name)) {
67456 return;
67457 }
67458 // Merge op can be pushed if any of its inputs has value.
67459 if (childNode.op === 'Merge') {
67460 if (childNode.inputNames.some(name => {
67461 return !!getTensor(name, tensorMap, context);
67462 })) {
67463 added[nodeName] = true;
67464 stack.push({ contexts: context.currentContext, node: childNode });
67465 }
67466 }
67467 else // Otherwise all inputs must to have value.
67468 if (childNode.inputNames.every(name => {
67469 return !!getTensor(name, tensorMap, context);
67470 })) {
67471 added[nodeName] = true;
67472 stack.push({ contexts: context.currentContext, node: childNode });
67473 }
67474 });
67475 }
67476 /**
67477 * Releases the memory used by the weight tensors.
67478 */
67479 dispose() {
67480 Object.keys(this.weightMap)
67481 .forEach(key => this.weightMap[key].forEach(tensor => tensor.dispose()));
67482 }
67483 checkInputShapeAndType(inputs) {
67484 Object.keys(inputs).forEach(name => {
67485 const input = inputs[name];
67486 const [nodeName,] = parseNodeName(name);
67487 const node = this.graph.nodes[nodeName];
67488 if (node.attrParams['shape'] && node.attrParams['shape'].value) {
67489 const shape = node.attrParams['shape'].value;
67490 const match = shape.length === input.shape.length &&
67491 input.shape.every((dim, index) => shape[index] === -1 || shape[index] === dim);
67492 assert$1(match, () => `The shape of dict['${node.name}'] provided in ` +
67493 `model.execute(dict) must be [${shape}], but was ` +
67494 `[${input.shape}]`);
67495 }
67496 if (node.attrParams['dtype'] && node.attrParams['dtype'].value) {
67497 assert$1(input.dtype === node.attrParams['dtype'].value, () => `The dtype of dict['${node.name}'] provided in ` +
67498 `model.execute(dict) must be ` +
67499 `${node.attrParams['dtype'].value}, but was ${input.dtype}`);
67500 }
67501 });
67502 }
67503 mapInputs(inputs) {
67504 var _a, _b;
67505 const result = {};
67506 for (const inputName in inputs) {
67507 const tensor = (_b = (_a = this._signature) === null || _a === void 0 ? void 0 : _a.inputs) === null || _b === void 0 ? void 0 : _b[inputName];
67508 if (tensor != null) {
67509 result[tensor.name] = inputs[inputName];
67510 }
67511 else {
67512 result[inputName] = inputs[inputName];
67513 }
67514 }
67515 return result;
67516 }
67517 checkInputs(inputs) {
67518 const notInGraph = Object.keys(inputs).filter(name => {
67519 const [nodeName] = parseNodeName(name);
67520 return this.graph.nodes[nodeName] == null;
67521 });
67522 if (notInGraph.length > 0) {
67523 throw new Error(`The dict provided in model.execute(dict) has ` +
67524 `keys: [${notInGraph}] that are not part of graph`);
67525 }
67526 }
67527 mapOutputs(outputs) {
67528 return outputs.map(name => {
67529 var _a, _b;
67530 const tensor = (_b = (_a = this._signature) === null || _a === void 0 ? void 0 : _a.outputs) === null || _b === void 0 ? void 0 : _b[name];
67531 if (tensor != null) {
67532 return tensor.name;
67533 }
67534 return name;
67535 }, {});
67536 }
67537 checkOutputs(outputs) {
67538 outputs.forEach(name => {
67539 const [normalizedName] = parseNodeName(name);
67540 if (!this.graph.nodes[normalizedName]) {
67541 throw new Error(`The output '${name}' is not found in the graph`);
67542 }
67543 });
67544 }
67545 }
67546
67547 /**
67548 * Contains global resources of a model.
67549 */
class ResourceManager {
    /**
     * @param hashTableNameToHandle Map from creator op node name to the
     *     table's handle tensor.
     * @param hashTableMap Map from handle tensor id to the `HashTable`.
     */
    constructor(hashTableNameToHandle = {}, hashTableMap = {}) {
        this.hashTableNameToHandle = hashTableNameToHandle;
        this.hashTableMap = hashTableMap;
    }
    /**
     * Register a `HashTable` in the resource manager.
     *
     * The `HashTable` can be retrieved by `resourceManager.getHashTableById`,
     * where id is the table handle tensor's id.
     *
     * @param name Op node name that creates the `HashTable`.
     * @param hashTable The `HashTable` to be added to resource manager.
     */
    addHashTable(name, hashTable) {
        this.hashTableNameToHandle[name] = hashTable.handle;
        this.hashTableMap[hashTable.id] = hashTable;
    }
    /**
     * Get the table handle by node name.
     * @param name Op node name that creates the `HashTable`. This name is also
     *     used in the inputs list of lookup and import `HashTable` ops.
     */
    getHashTableHandleByName(name) {
        return this.hashTableNameToHandle[name];
    }
    /**
     * Get the actual `HashTable` by its handle tensor's id.
     * @param id The id of the handle tensor.
     */
    getHashTableById(id) {
        return this.hashTableMap[id];
    }
    /**
     * Dispose `ResourceManager`, including its hashTables and tensors in them.
     */
    dispose() {
        for (const key of Object.keys(this.hashTableMap)) {
            this.hashTableMap[key].clearAndClose();
            delete this.hashTableMap[key];
        }
        for (const name of Object.keys(this.hashTableNameToHandle)) {
            this.hashTableNameToHandle[name].dispose();
            delete this.hashTableNameToHandle[name];
        }
    }
}
67597
67598 /**
67599 * @license
67600 * Copyright 2018 Google LLC. All Rights Reserved.
67601 * Licensed under the Apache License, Version 2.0 (the "License");
67602 * you may not use this file except in compliance with the License.
67603 * You may obtain a copy of the License at
67604 *
67605 * http://www.apache.org/licenses/LICENSE-2.0
67606 *
67607 * Unless required by applicable law or agreed to in writing, software
67608 * distributed under the License is distributed on an "AS IS" BASIS,
67609 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
67610 * See the License for the specific language governing permissions and
67611 * limitations under the License.
67612 * =============================================================================
67613 */
67614 const TFHUB_SEARCH_PARAM = '?tfjs-format=file';
67615 const DEFAULT_MODEL_NAME = 'model.json';
67616 /**
67617 * A `tf.GraphModel` is a directed, acyclic graph built from a
67618 * SavedModel GraphDef and allows inference execution.
67619 *
67620 * A `tf.GraphModel` can only be created by loading from a model converted from
67621 * a [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model) using
67622 * the command line converter tool and loaded via `tf.loadGraphModel`.
67623 *
67624 * @doc {heading: 'Models', subheading: 'Classes'}
67625 */
67626 class GraphModel {
// Returns the version information for the tensorflow model GraphDef.
get modelVersion() {
    return this.version;
}
// Names of the model's input nodes, delegated to the executor.
get inputNodes() {
    return this.executor.inputNodes;
}
// Names of the model's output nodes, delegated to the executor.
get outputNodes() {
    return this.executor.outputNodes;
}
// The executor's input node metadata.
get inputs() {
    return this.executor.inputs;
}
// The executor's output node metadata.
get outputs() {
    return this.executor.outputs;
}
// The executor's weight map (node name -> tensor list).
get weights() {
    return this.executor.weightMap;
}
// User-defined metadata stored with the loaded model artifacts.
get metadata() {
    return this.artifacts.userDefinedMetadata;
}
// The model signature, as resolved during loadWithWeightMap.
get modelSignature() {
    return this.signature;
}
// Structured output keys, when the converted model declared them.
get modelStructuredOutputKeys() {
    return this.structuredOutputKeys;
}
67655 /**
67656 * @param modelUrl url for the model, or an `io.IOHandler`.
67657 * @param weightManifestUrl url for the weight file generated by
67658 * scripts/convert.py script.
67659 * @param requestOption options for Request, which allows to send credentials
67660 * and custom headers.
67661 * @param onProgress Optional, progress callback function, fired periodically
67662 * before the load is completed.
67663 */
67664 constructor(modelUrl, loadOptions = {}, tfio = io) {
67665 this.modelUrl = modelUrl;
67666 this.loadOptions = loadOptions;
67667 this.version = 'n/a';
67668 this.io = tfio;
67669 if (loadOptions == null) {
67670 this.loadOptions = {};
67671 }
67672 this.resourceManager = new ResourceManager();
67673 }
67674 findIOHandler() {
67675 const path = this.modelUrl;
67676 if (path.load != null) {
67677 // Path is an IO Handler.
67678 this.handler = path;
67679 }
67680 else if (this.loadOptions.requestInit != null) {
67681 this.handler = this.io.browserHTTPRequest(path, this.loadOptions);
67682 }
67683 else {
67684 const handlers = this.io.getLoadHandlers(path, this.loadOptions);
67685 if (handlers.length === 0) {
67686 // For backward compatibility: if no load handler can be found,
67687 // assume it is a relative http path.
67688 handlers.push(this.io.browserHTTPRequest(path, this.loadOptions));
67689 }
67690 else if (handlers.length > 1) {
67691 throw new Error(`Found more than one (${handlers.length}) load handlers for ` +
67692 `URL '${[path]}'`);
67693 }
67694 this.handler = handlers[0];
67695 }
67696 }
67697 /**
67698 * Loads the model and weight files, construct the in memory weight map and
67699 * compile the inference graph.
67700 */
67701 load() {
67702 this.findIOHandler();
67703 if (this.handler.load == null) {
67704 throw new Error('Cannot proceed with model loading because the IOHandler provided ' +
67705 'does not have the `load` method implemented.');
67706 }
67707 const loadResult = this.handler.load();
67708 if (isPromise(loadResult)) {
67709 return loadResult.then(artifacts => {
67710 if (artifacts.getWeightStream == null) {
67711 return this.loadSync(artifacts);
67712 }
67713 return this.loadStreaming(artifacts);
67714 });
67715 }
67716 return this.loadSync(loadResult);
67717 }
67718 /**
67719 * Synchronously construct the in memory weight map and
67720 * compile the inference graph.
67721 *
67722 * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
67723 */
67724 loadSync(artifacts) {
67725 const weightMap = this.io.decodeWeights(artifacts.weightData, artifacts.weightSpecs);
67726 return this.loadWithWeightMap(artifacts, weightMap);
67727 }
67728 async loadStreaming(artifacts) {
67729 if (artifacts.getWeightStream == null) {
67730 throw new Error('Model artifacts missing streamWeights function');
67731 }
67732 const weightMap = await decodeWeightsStream(artifacts.getWeightStream(), artifacts.weightSpecs);
67733 return this.loadWithWeightMap(artifacts, weightMap);
67734 }
/**
 * Finishes loading: records the artifacts, resolves the signature and
 * version, builds the main GraphExecutor, and — when the artifacts carry a
 * model initializer — builds an initializer executor sharing the same
 * weight map and resource manager.
 * @returns true once the executors are constructed.
 */
loadWithWeightMap(artifacts, weightMap) {
    this.artifacts = artifacts;
    const graph = this.artifacts.modelTopology;
    let signature = this.artifacts.signature;
    // User-defined metadata may override the signature and may provide
    // structured output keys.
    if (this.artifacts.userDefinedMetadata != null) {
        const metadata = this.artifacts.userDefinedMetadata;
        if (metadata.signature != null) {
            signature = metadata.signature;
        }
        if (metadata.structuredOutputKeys != null) {
            this.structuredOutputKeys = metadata.structuredOutputKeys;
        }
    }
    this.signature = signature;
    this.version = `${graph.versions.producer}.${graph.versions.minConsumer}`;
    this.executor = new GraphExecutor(OperationMapper.Instance.transformGraph(graph, this.signature));
    this.executor.weightMap = this.convertTensorMapToTensorsMap(weightMap);
    // Attach a model-level resourceManager to each executor to share resources,
    // such as `HashTable`.
    this.executor.resourceManager = this.resourceManager;
    if (artifacts.modelInitializer != null &&
        artifacts.modelInitializer.node != null) {
        const initializer = OperationMapper.Instance.transformGraph(artifacts.modelInitializer);
        this.initializer = new GraphExecutor(initializer);
        // The initializer executes against the same weights as the model.
        this.initializer.weightMap = this.executor.weightMap;
        // Attach a model-level resourceManager to the initializer, the
        // hashTables created from when executing the initializer will be stored
        // in the resourceManager.
        this.initializer.resourceManager = this.resourceManager;
        this.initializerSignature = artifacts.initializerSignature;
    }
    return true;
}
67768 /**
67769 * Save the configuration and/or weights of the GraphModel.
67770 *
67771 * An `IOHandler` is an object that has a `save` method of the proper
67772 * signature defined. The `save` method manages the storing or
67773 * transmission of serialized data ("artifacts") that represent the
67774 * model's topology and weights onto or via a specific medium, such as
67775 * file downloads, local storage, IndexedDB in the web browser and HTTP
67776 * requests to a server. TensorFlow.js provides `IOHandler`
67777 * implementations for a number of frequently used saving mediums, such as
67778 * `tf.io.browserDownloads` and `tf.io.browserLocalStorage`. See `tf.io`
67779 * for more details.
67780 *
67781 * This method also allows you to refer to certain types of `IOHandler`s
67782 * as URL-like string shortcuts, such as 'localstorage://' and
67783 * 'indexeddb://'.
67784 *
67785 * Example 1: Save `model`'s topology and weights to browser [local
67786 * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
67787 * then load it back.
67788 *
67789 * ```js
67790 * const modelUrl =
67791 * 'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
67792 * const model = await tf.loadGraphModel(modelUrl);
67793 * const zeros = tf.zeros([1, 224, 224, 3]);
67794 * model.predict(zeros).print();
67795 *
67796 * const saveResults = await model.save('localstorage://my-model-1');
67797 *
67798 * const loadedModel = await tf.loadGraphModel('localstorage://my-model-1');
67799 * console.log('Prediction from loaded model:');
67800 * model.predict(zeros).print();
67801 * ```
67802 *
67803 * @param handlerOrURL An instance of `IOHandler` or a URL-like,
67804 * scheme-based string shortcut for `IOHandler`.
67805 * @param config Options for saving the model.
67806 * @returns A `Promise` of `SaveResult`, which summarizes the result of
67807 * the saving, such as byte sizes of the saved artifacts for the model's
67808 * topology and weight values.
67809 *
67810 * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
67811 */
67812 async save(handlerOrURL, config) {
67813 if (typeof handlerOrURL === 'string') {
67814 const handlers = this.io.getSaveHandlers(handlerOrURL);
67815 if (handlers.length === 0) {
67816 throw new Error(`Cannot find any save handlers for URL '${handlerOrURL}'`);
67817 }
67818 else if (handlers.length > 1) {
67819 throw new Error(`Found more than one (${handlers.length}) save handlers for ` +
67820 `URL '${handlerOrURL}'`);
67821 }
67822 handlerOrURL = handlers[0];
67823 }
67824 if (handlerOrURL.save == null) {
67825 throw new Error('GraphModel.save() cannot proceed because the IOHandler ' +
67826 'provided does not have the `save` attribute defined.');
67827 }
67828 return handlerOrURL.save(this.artifacts);
67829 }
67830 addStructuredOutputNames(outputTensors) {
67831 if (this.structuredOutputKeys) {
67832 const outputTensorsArray = outputTensors instanceof Tensor ? [outputTensors] : outputTensors;
67833 const outputTensorMap = {};
67834 outputTensorsArray.forEach((outputTensor, i) => outputTensorMap[this.structuredOutputKeys[i]] =
67835 outputTensor);
67836 return outputTensorMap;
67837 }
67838 return outputTensors;
67839 }
67840 /**
67841 * Execute the inference for the input tensors.
67842 *
67843 * @param input The input tensors, when there is single input for the model,
67844 * inputs param should be a `tf.Tensor`. For models with multiple inputs,
67845 * inputs params should be in either `tf.Tensor`[] if the input order is
67846 * fixed, or otherwise NamedTensorMap format.
67847 *
67848 * For model with multiple inputs, we recommend you use NamedTensorMap as the
67849 * input type, if you use `tf.Tensor`[], the order of the array needs to
67850 * follow the
67851 * order of inputNodes array. @see {@link GraphModel.inputNodes}
67852 *
67853 * You can also feed any intermediate nodes using the NamedTensorMap as the
67854 * input type. For example, given the graph
67855 * InputNode => Intermediate => OutputNode,
67856 * you can execute the subgraph Intermediate => OutputNode by calling
67857 * model.execute('IntermediateNode' : tf.tensor(...));
67858 *
67859 * This is useful for models that uses tf.dynamic_rnn, where the intermediate
67860 * state needs to be fed manually.
67861 *
67862 * For batch inference execution, the tensors for each input need to be
67863 * concatenated together. For example with mobilenet, the required input shape
67864 * is [1, 244, 244, 3], which represents the [batch, height, width, channel].
67865 * If we are provide a batched data of 100 images, the input tensor should be
67866 * in the shape of [100, 244, 244, 3].
67867 *
67868 * @param config Prediction configuration for specifying the batch size.
67869 * Currently the batch size option is ignored for graph model.
67870 *
67871 * @returns Inference result tensors. If the model is converted and it
67872 * originally had structured_outputs in tensorflow, then a NamedTensorMap
67873 * will be returned matching the structured_outputs. If no structured_outputs
67874 * are present, the output will be single `tf.Tensor` if the model has single
67875 * output node, otherwise Tensor[].
67876 *
67877 * @doc {heading: 'Models', subheading: 'Classes'}
67878 */
67879 predict(inputs, config) {
67880 const outputTensors = this.execute(inputs, this.outputNodes);
67881 return this.addStructuredOutputNames(outputTensors);
67882 }
67883 /**
67884 * Execute the inference for the input tensors in async fashion, use this
67885 * method when your model contains control flow ops.
67886 *
67887 * @param input The input tensors, when there is single input for the model,
67888 * inputs param should be a `tf.Tensor`. For models with mutliple inputs,
67889 * inputs params should be in either `tf.Tensor`[] if the input order is
67890 * fixed, or otherwise NamedTensorMap format.
67891 *
67892 * For model with multiple inputs, we recommend you use NamedTensorMap as the
67893 * input type, if you use `tf.Tensor`[], the order of the array needs to
67894 * follow the
67895 * order of inputNodes array. @see {@link GraphModel.inputNodes}
67896 *
67897 * You can also feed any intermediate nodes using the NamedTensorMap as the
67898 * input type. For example, given the graph
67899 * InputNode => Intermediate => OutputNode,
67900 * you can execute the subgraph Intermediate => OutputNode by calling
67901 * model.execute('IntermediateNode' : tf.tensor(...));
67902 *
67903 * This is useful for models that uses tf.dynamic_rnn, where the intermediate
67904 * state needs to be fed manually.
67905 *
67906 * For batch inference execution, the tensors for each input need to be
67907 * concatenated together. For example with mobilenet, the required input shape
67908 * is [1, 244, 244, 3], which represents the [batch, height, width, channel].
67909 * If we are provide a batched data of 100 images, the input tensor should be
67910 * in the shape of [100, 244, 244, 3].
67911 *
67912 * @param config Prediction configuration for specifying the batch size.
67913 * Currently the batch size option is ignored for graph model.
67914 *
67915 * @returns A Promise of inference result tensors. If the model is converted
67916 * and it originally had structured_outputs in tensorflow, then a
67917 * NamedTensorMap will be returned matching the structured_outputs. If no
67918 * structured_outputs are present, the output will be single `tf.Tensor` if
67919 * the model has single output node, otherwise Tensor[].
67920 *
67921 * @doc {heading: 'Models', subheading: 'Classes'}
67922 */
67923 async predictAsync(inputs, config) {
67924 const outputTensors = await this.executeAsync(inputs, this.outputNodes);
67925 return this.addStructuredOutputNames(outputTensors);
67926 }
/**
 * Normalizes user-provided inputs (single tensor, tensor array, or
 * NamedTensorMap) into a NamedTensorMap keyed by input node name, filling
 * resource-backed placeholders from the captured initializer outputs.
 */
normalizeInputs(inputs) {
    var _a;
    if (!(inputs instanceof Tensor) && !Array.isArray(inputs)) {
        // The input is already a NamedTensorMap.
        const signatureInputs = (_a = this.signature) === null || _a === void 0 ? void 0 : _a.inputs;
        if (signatureInputs != null) {
            // Inject captured tensors for signature inputs that reference a
            // resource id.
            for (const input in signatureInputs) {
                const tensor = signatureInputs[input];
                if (tensor.resourceId != null) {
                    inputs[input] = this.resourceIdToCapturedInput[tensor.resourceId];
                }
            }
        }
        return inputs;
    }
    inputs = Array.isArray(inputs) ? inputs : [inputs];
    // The positional tensors must cover exactly the non-resource placeholders.
    const numCapturedInputs = Object.keys(this.resourceIdToCapturedInput).length;
    if (inputs.length + numCapturedInputs !== this.inputNodes.length) {
        throw new Error(`Input tensor count mismatch, the graph model has ${this.inputNodes.length -
            numCapturedInputs} non-resource placeholders, while there are ${inputs.length} input tensors provided.`);
    }
    let inputIndex = 0;
    return this.inputNodes.reduce((map, inputName) => {
        var _a, _b, _c;
        const resourceId = (_c = (_b = (_a = this.signature) === null || _a === void 0 ? void 0 : _a.inputs) === null || _b === void 0 ? void 0 : _b[inputName]) === null || _c === void 0 ? void 0 : _c.resourceId;
        if (resourceId != null) {
            // Resource placeholder: fed from the initializer's captured output.
            map[inputName] = this.resourceIdToCapturedInput[resourceId];
        }
        else {
            // Regular placeholder: consume the next positional tensor.
            map[inputName] = inputs[inputIndex++];
        }
        return map;
    }, {});
}
67961 normalizeOutputs(outputs) {
67962 outputs = outputs || this.outputNodes;
67963 return !Array.isArray(outputs) ? [outputs] : outputs;
67964 }
67965 executeInitializerGraph() {
67966 if (this.initializer == null) {
67967 return [];
67968 }
67969 if (this.initializerSignature == null) {
67970 return this.initializer.execute({}, []);
67971 }
67972 else {
67973 return this.initializer.execute({}, Object.keys(this.initializerSignature.outputs));
67974 }
67975 }
67976 async executeInitializerGraphAsync() {
67977 if (this.initializer == null) {
67978 return [];
67979 }
67980 if (this.initializerSignature == null) {
67981 return this.initializer.executeAsync({}, []);
67982 }
67983 else {
67984 return this.initializer.executeAsync({}, Object.keys(this.initializerSignature.outputs));
67985 }
67986 }
67987 setResourceIdToCapturedInput(outputs) {
67988 this.resourceIdToCapturedInput = {};
67989 if (this.initializerSignature) {
67990 const signatureOutputs = this.initializerSignature.outputs;
67991 const outputNames = Object.keys(signatureOutputs);
67992 for (let i = 0; i < outputNames.length; i++) {
67993 const outputName = outputNames[i];
67994 const tensorInfo = signatureOutputs[outputName];
67995 this.resourceIdToCapturedInput[tensorInfo.resourceId] = outputs[i];
67996 }
67997 }
67998 }
67999 /**
68000 * Executes inference for the model for given input tensors.
68001 * @param inputs tensor, tensor array or tensor map of the inputs for the
68002 * model, keyed by the input node names.
68003 * @param outputs output node name from the TensorFlow model, if no
68004 * outputs are specified, the default outputs of the model would be used.
68005 * You can inspect intermediate nodes of the model by adding them to the
68006 * outputs array.
68007 *
68008 * @returns A single tensor if provided with a single output or no outputs
68009 * are provided and there is only one default output, otherwise return a
68010 * tensor array. The order of the tensor array is the same as the outputs
68011 * if provided, otherwise the order of outputNodes attribute of the model.
68012 *
68013 * @doc {heading: 'Models', subheading: 'Classes'}
68014 */
68015 execute(inputs, outputs) {
68016 if (this.resourceIdToCapturedInput == null) {
68017 this.setResourceIdToCapturedInput(this.executeInitializerGraph());
68018 }
68019 inputs = this.normalizeInputs(inputs);
68020 outputs = this.normalizeOutputs(outputs);
68021 const result = this.executor.execute(inputs, outputs);
68022 return result.length > 1 ? result : result[0];
68023 }
68024 /**
68025 * Executes inference for the model for given input tensors in async
68026 * fashion, use this method when your model contains control flow ops.
68027 * @param inputs tensor, tensor array or tensor map of the inputs for the
68028 * model, keyed by the input node names.
68029 * @param outputs output node name from the TensorFlow model, if no outputs
68030 * are specified, the default outputs of the model would be used. You can
68031 * inspect intermediate nodes of the model by adding them to the outputs
68032 * array.
68033 *
68034 * @returns A Promise of single tensor if provided with a single output or
68035 * no outputs are provided and there is only one default output, otherwise
68036 * return a tensor map.
68037 *
68038 * @doc {heading: 'Models', subheading: 'Classes'}
68039 */
68040 async executeAsync(inputs, outputs) {
68041 if (this.resourceIdToCapturedInput == null) {
68042 this.setResourceIdToCapturedInput(await this.executeInitializerGraphAsync());
68043 }
68044 inputs = this.normalizeInputs(inputs);
68045 outputs = this.normalizeOutputs(outputs);
68046 const result = await this.executor.executeAsync(inputs, outputs);
68047 return result.length > 1 ? result : result[0];
68048 }
68049 /**
68050 * Get intermediate tensors for model debugging mode (flag
68051 * KEEP_INTERMEDIATE_TENSORS is true).
68052 *
68053 * @doc {heading: 'Models', subheading: 'Classes'}
68054 */
68055 getIntermediateTensors() {
68056 return this.executor.getIntermediateTensors();
68057 }
68058 /**
68059 * Dispose intermediate tensors for model debugging mode (flag
68060 * KEEP_INTERMEDIATE_TENSORS is true).
68061 *
68062 * @doc {heading: 'Models', subheading: 'Classes'}
68063 */
68064 disposeIntermediateTensors() {
68065 this.executor.disposeIntermediateTensors();
68066 }
68067 convertTensorMapToTensorsMap(map) {
68068 return Object.keys(map).reduce((newMap, key) => {
68069 newMap[key] = [map[key]];
68070 return newMap;
68071 }, {});
68072 }
68073 /**
68074 * Releases the memory used by the weight tensors and resourceManager.
68075 *
68076 * @doc {heading: 'Models', subheading: 'Classes'}
68077 */
68078 dispose() {
68079 this.executor.dispose();
68080 if (this.initializer) {
68081 this.initializer.dispose();
68082 if (this.resourceIdToCapturedInput) {
68083 dispose(this.resourceIdToCapturedInput);
68084 }
68085 }
68086 this.resourceManager.dispose();
68087 }
68088 }
68089 /**
68090 * Load a graph model given a URL to the model definition.
68091 *
68092 * Example of loading MobileNetV2 from a URL and making a prediction with a
68093 * zeros input:
68094 *
68095 * ```js
68096 * const modelUrl =
68097 * 'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
68098 * const model = await tf.loadGraphModel(modelUrl);
68099 * const zeros = tf.zeros([1, 224, 224, 3]);
68100 * model.predict(zeros).print();
68101 * ```
68102 *
68103 * Example of loading MobileNetV2 from a TF Hub URL and making a prediction
68104 * with a zeros input:
68105 *
68106 * ```js
68107 * const modelUrl =
68108 * 'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2';
68109 * const model = await tf.loadGraphModel(modelUrl, {fromTFHub: true});
68110 * const zeros = tf.zeros([1, 224, 224, 3]);
68111 * model.predict(zeros).print();
68112 * ```
68113 * @param modelUrl The url or an `io.IOHandler` that loads the model.
68114 * @param options Options for the HTTP request, which allows to send
68115 * credentials
68116 * and custom headers.
68117 *
68118 * @doc {heading: 'Models', subheading: 'Loading'}
68119 */
68120 async function loadGraphModel(modelUrl, options = {}, tfio = io) {
68121 if (modelUrl == null) {
68122 throw new Error('modelUrl in loadGraphModel() cannot be null. Please provide a url ' +
68123 'or an IOHandler that loads the model');
68124 }
68125 if (options == null) {
68126 options = {};
68127 }
68128 if (options.fromTFHub && typeof modelUrl === 'string') {
68129 modelUrl = getTFHubUrl(modelUrl);
68130 }
68131 const model = new GraphModel(modelUrl, options, tfio);
68132 await model.load();
68133 return model;
68134 }
68135 /**
68136 * Load a graph model given a synchronous IO handler with a 'load' method.
68137 *
68138 * @param modelSource The `io.IOHandlerSync` that loads the model, or the
68139 * `io.ModelArtifacts` that encode the model, or a tuple of
68140 * `[io.ModelJSON, ArrayBuffer]` of which the first element encodes the
68141 * model and the second contains the weights.
68142 *
68143 * @doc {heading: 'Models', subheading: 'Loading'}
68144 */
68145 function loadGraphModelSync(modelSource) {
68146 if (modelSource == null) {
68147 throw new Error('modelUrl in loadGraphModelSync() cannot be null. Please provide ' +
68148 'model artifacts or an IOHandler that loads the model');
68149 }
68150 let ioHandler;
68151 if (modelSource instanceof Array) {
68152 const [modelJSON, weights] = modelSource;
68153 if (!modelJSON) {
68154 throw new Error('modelJSON must be the first element of the array');
68155 }
68156 if (!weights || !(weights instanceof ArrayBuffer)) {
68157 throw new Error('An ArrayBuffer of weights must be the second element of' +
68158 ' the array');
68159 }
68160 if (!('modelTopology' in modelJSON)) {
68161 throw new Error('Model JSON is missing \'modelTopology\'');
68162 }
68163 if (!('weightsManifest' in modelJSON)) {
68164 throw new Error('Model JSON is missing \'weightsManifest\'');
68165 }
68166 const weightSpecs = getWeightSpecs(modelJSON.weightsManifest);
68167 const modelArtifacts = getModelArtifactsForJSONSync(modelJSON, weightSpecs, weights);
68168 ioHandler = fromMemorySync(modelArtifacts);
68169 }
68170 else if ('load' in modelSource) {
68171 // Then modelSource is already an IOHandlerSync.
68172 ioHandler = modelSource;
68173 }
68174 else if ('modelTopology' in modelSource && 'weightSpecs' in modelSource &&
68175 'weightData' in modelSource) {
68176 // modelSource is of type ModelArtifacts.
68177 ioHandler = fromMemorySync(modelSource);
68178 }
68179 else {
68180 throw new Error('Unknown model format');
68181 }
68182 const model = new GraphModel(ioHandler);
68183 model.load();
68184 return model;
68185 }
68186 function getTFHubUrl(modelUrl) {
68187 if (!modelUrl.endsWith('/')) {
68188 modelUrl = (modelUrl) + '/';
68189 }
68190 return `${modelUrl}${DEFAULT_MODEL_NAME}${TFHUB_SEARCH_PARAM}`;
68191 }
68192
    /** @license See the LICENSE file. */
    // This code is auto-generated, do not modify this file!
    // Package version baked in at build time.
    const version$5 = '4.22.0';
68196
68197 /**
68198 * @license
68199 * Copyright 2018 Google LLC. All Rights Reserved.
68200 * Licensed under the Apache License, Version 2.0 (the "License");
68201 * you may not use this file except in compliance with the License.
68202 * You may obtain a copy of the License at
68203 *
68204 * http://www.apache.org/licenses/LICENSE-2.0
68205 *
68206 * Unless required by applicable law or agreed to in writing, software
68207 * distributed under the License is distributed on an "AS IS" BASIS,
68208 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
68209 * See the License for the specific language governing permissions and
68210 * limitations under the License.
68211 * =============================================================================
68212 */
68213
68214 /**
68215 * @license
68216 * Copyright 2018 Google LLC. All Rights Reserved.
68217 * Licensed under the Apache License, Version 2.0 (the "License");
68218 * you may not use this file except in compliance with the License.
68219 * You may obtain a copy of the License at
68220 *
68221 * http://www.apache.org/licenses/LICENSE-2.0
68222 *
68223 * Unless required by applicable law or agreed to in writing, software
68224 * distributed under the License is distributed on an "AS IS" BASIS,
68225 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
68226 * See the License for the specific language governing permissions and
68227 * limitations under the License.
68228 *
68229 * =============================================================================
68230 */
68231 /**
68232 * Apply a mapping function to a nested structure in a recursive manner.
68233 *
68234 * The result of the mapping is an object with the same nested structure (i.e.,
68235 * of arrays and dicts) as the input, except that some subtrees are replaced,
68236 * according to the results of the mapping function.
68237 *
68238 * Mappings are memoized. Thus, if the nested structure contains the same
68239 * object in multiple positions, the output will contain the same mapped object
68240 * in those positions. Cycles are not supported, however.
68241 *
68242 * @param input: The object to which to apply the mapping function.
68243 * @param mapFn: A function that expects a single node of the object tree, and
68244 * returns a `DeepMapResult`. The `DeepMapResult` either provides a
68245 * replacement value for that node (i.e., replacing the subtree), or indicates
68246 * that the node should be processed recursively.
68247 */
68248 function deepMap(input, mapFn) {
68249 return deepMapInternal(input, mapFn);
68250 }
68251 /**
68252 * @param seen: A Map of known object mappings (i.e., memoized results of
68253 * `mapFn()`)
68254 * @param containedIn: An set containing objects on the reference path currently
68255 * being processed (used to detect cycles).
68256 */
68257 function deepMapInternal(input, mapFn, seen = new Map(), containedIn = new Set()) {
68258 if (input == null) {
68259 return null;
68260 }
68261 if (typeof Blob === 'function' && input instanceof Blob) {
68262 return input.slice();
68263 }
68264 if (containedIn.has(input)) {
68265 throw new Error('Circular references are not supported.');
68266 }
68267 if (seen.has(input)) {
68268 return seen.get(input);
68269 }
68270 const result = mapFn(input);
68271 if (result.recurse && result.value !== null) {
68272 throw new Error('A deep map function may not return both a value and recurse=true.');
68273 }
68274 if (!result.recurse) {
68275 seen.set(input, result.value);
68276 return result.value;
68277 }
68278 else if (isIterable(input)) {
68279 // tslint:disable-next-line:no-any
68280 const mappedIterable = Array.isArray(input) ? [] : {};
68281 containedIn.add(input);
68282 for (const k in input) {
68283 const child = input[k];
68284 const childResult = deepMapInternal(child, mapFn, seen, containedIn);
68285 mappedIterable[k] = childResult;
68286 }
68287 containedIn.delete(input);
68288 if (input.__proto__) {
68289 mappedIterable.__proto__ = input.__proto__;
68290 }
68291 return mappedIterable;
68292 }
68293 else {
68294 throw new Error(`Can't recurse into non-iterable type: ${input}`);
68295 }
68296 }
    // TODO(soergel, kangyizhang) Reconsider naming of deepZip() to avoid confusion
    // with zip()
    /**
     * Zip nested structures together in a recursive manner.
     *
     * This has the effect of transposing or pivoting data, e.g. converting it from
     * a row-major representation to a column-major representation.
     *
     * For example, `deepZip([{a: 1, b: 2}, {a: 3, b: 4}])` returns
     * `{a: [1, 3], b: [2, 4]}`.
     *
     * The inputs should all have the same nested structure (i.e., of arrays and
     * dicts). The result is a single object with the same nested structure, where
     * the leaves are arrays collecting the values of the inputs at that location
     * (or, optionally, the result of a custom function applied to those arrays).
     *
     * @param inputs: An array of the objects to zip together.
     * @param zipFn: (optional) A function that expects an array of elements at a
     *   single node of the object tree, and returns a `DeepMapResult`. The
     *   `DeepMapResult` either provides a result value for that node (i.e.,
     *   representing the subtree), or indicates that the node should be processed
     *   recursively. The default zipFn recurses as far as possible and places
     *   arrays at the leaves.
     * @returns A single structure whose leaves are the zipped values.
     */
    function deepZip(inputs, zipFn = zipToList) {
        // Thin wrapper: starts the recursion with fresh cycle-detection state.
        return deepZipInternal(inputs, zipFn);
    }
68324 /**
68325 * @param containedIn: An set containing objects on the reference path currently
68326 * being processed (used to detect cycles).
68327 */
68328 function deepZipInternal(inputs, zipFn, containedIn = new Set()) {
68329 // The recursion follows the structure of input 0; it's assumed that all the
68330 // other inputs have the same structure.
68331 const input = inputs[0];
68332 if (containedIn.has(input)) {
68333 throw new Error('Circular references are not supported.');
68334 }
68335 const result = zipFn(inputs);
68336 if (result.recurse && result.value !== null) {
68337 throw new Error('A deep zip function may not return both a value and recurse=true.');
68338 }
68339 if (!result.recurse) {
68340 return result.value;
68341 }
68342 else if (isIterable(input)) {
68343 // tslint:disable-next-line:no-any
68344 const mappedIterable = Array.isArray(input) ? [] : {};
68345 containedIn.add(input);
68346 for (const k in input) {
68347 const children = inputs.map(x => x[k]);
68348 const childResult = deepZipInternal(children, zipFn, containedIn);
68349 mappedIterable[k] = childResult;
68350 }
68351 containedIn.delete(input);
68352 return mappedIterable;
68353 }
68354 else {
68355 throw new Error(`Can't recurse into non-iterable type: ${input}`);
68356 }
68357 }
68358 // tslint:disable-next-line:no-any
68359 function zipToList(x) {
68360 if (x === null) {
68361 return null;
68362 }
68363 // TODO(soergel): validate array type?
68364 if (isIterable(x[0])) {
68365 return { value: null, recurse: true };
68366 }
68367 else {
68368 return { value: x, recurse: false };
68369 }
68370 }
68371 /**
68372 * Apply an async mapping function to a nested structure in a recursive manner.
68373 *
68374 * This first creates a nested structure of Promises, and then awaits all of
68375 * those, resulting in a single Promise for a resolved nested structure.
68376 *
68377 * The result of the mapping is an object with the same nested structure (i.e.,
68378 * of arrays and dicts) as the input, except that some subtrees are replaced,
68379 * according to the results of the mapping function.
68380 *
68381 * Mappings are memoized. Thus, if the nested structure contains the same
68382 * object in multiple positions, the output will contain the same mapped object
68383 * in those positions. Cycles are not supported, however.
68384 *
68385 * @param input: The object to which to apply the mapping function.
68386 * @param mapFn: A function that expects a single node of the object tree, and
68387 * returns a `DeepMapAsyncResult`. The `DeepMapAsyncResult` either provides
68388 * a `Promise` for a replacement value for that node (i.e., replacing the
68389 * subtree), or indicates that the node should be processed recursively. Note
68390 * that the decision whether or not to recurse must be made immediately; only
68391 * the mapped value may be promised.
68392 */
68393 async function deepMapAndAwaitAll(input, mapFn) {
68394 const seen = new Map();
68395 // First do a normal deepMap, collecting Promises in 'seen' as a side effect.
68396 deepMapInternal(input, mapFn, seen);
68397 // Replace the Promises in 'seen' in place.
68398 // Note TypeScript provides no async map iteration, and regular map iteration
68399 // is broken too, so sadly we have to do Array.from() to make it work.
68400 // (There's no advantage to Promise.all(), and that would be tricky anyway.)
68401 for (const key of Array.from(seen.keys())) {
68402 const value = seen.get(key);
68403 if (isPromise(value)) {
68404 const mappedValue = await value;
68405 seen.set(key, mappedValue);
68406 }
68407 }
68408 // Normal deepMap again, this time filling in the resolved values.
68409 // It's unfortunate that we have to do two passes.
68410 // TODO(soergel): test performance and think harder about a fast solution.
68411 const result = deepMapInternal(input, mapFn, seen);
68412 return result;
68413 }
68414 /**
68415 * Determine whether the argument is iterable.
68416 *
68417 * @returns true if the argument is an array or any non-Tensor object.
68418 */
68419 // tslint:disable-next-line:no-any
68420 function isIterable(obj) {
68421 let isTextDecoder = false;
68422 if (env().get('IS_BROWSER')) {
68423 isTextDecoder = obj instanceof TextDecoder;
68424 }
68425 else {
68426 // tslint:disable-next-line:no-require-imports
68427 const { StringDecoder } = require('string_decoder');
68428 isTextDecoder = obj instanceof StringDecoder;
68429 }
68430 return obj != null && (!ArrayBuffer.isView(obj)) &&
68431 (Array.isArray(obj) ||
68432 (typeof obj === 'object' && !(obj instanceof Tensor) &&
68433 !(obj instanceof Promise) && !isTextDecoder));
68434 }
68435 /**
68436 * Determine whether the argument can be converted to Tensor.
68437 *
68438 * Tensors, primitives, arrays, and TypedArrays all qualify; anything else does
68439 * not.
68440 *
68441 * @returns true if the argument can be converted to Tensor.
68442 */
68443 // tslint:disable-next-line:no-any
68444 function canTensorify(obj) {
68445 return obj == null || isPrimitive(obj) || Array.isArray(obj) ||
68446 (typeof obj === 'object' && (obj instanceof Tensor)) ||
68447 isTypedArray(obj);
68448 }
68449 /**
68450 * Returns true if the given `value` is a primitive type. Otherwise returns
68451 * false. This is equivalant to node util.isPrimitive
68452 */
68453 function isPrimitive(value) {
68454 return (value === null ||
68455 (typeof value !== 'object' && typeof value !== 'function'));
68456 }
68457
68458 /**
68459 * @license
68460 * Copyright 2018 Google LLC. All Rights Reserved.
68461 * Licensed under the Apache License, Version 2.0 (the "License");
68462 * you may not use this file except in compliance with the License.
68463 * You may obtain a copy of the License at
68464 *
68465 * http://www.apache.org/licenses/LICENSE-2.0
68466 *
68467 * Unless required by applicable law or agreed to in writing, software
68468 * distributed under the License is distributed on an "AS IS" BASIS,
68469 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
68470 * See the License for the specific language governing permissions and
68471 * limitations under the License.
68472 *
68473 * =============================================================================
68474 */
    /**
     * Recursively clones a nested structure of arrays/dicts, cloning any
     * `tf.Tensor` leaves and passing other leaf values through unchanged.
     */
    function deepClone(container) {
        return deepMap(container, cloneIfTensor);
    }
68478 // tslint:disable-next-line: no-any
68479 function cloneIfTensor(item) {
68480 if (item instanceof Tensor) {
68481 return ({ value: item.clone(), recurse: false });
68482 }
68483 else if (isIterable(item)) {
68484 return { value: null, recurse: true };
68485 }
68486 else {
68487 return { value: item, recurse: false };
68488 }
68489 }
68490
68491 /**
68492 * @license
68493 * Copyright 2018 Google LLC. All Rights Reserved.
68494 * Licensed under the Apache License, Version 2.0 (the "License");
68495 * you may not use this file except in compliance with the License.
68496 * You may obtain a copy of the License at
68497 *
68498 * http://www.apache.org/licenses/LICENSE-2.0
68499 *
68500 * Unless required by applicable law or agreed to in writing, software
68501 * distributed under the License is distributed on an "AS IS" BASIS,
68502 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
68503 * See the License for the specific language governing permissions and
68504 * limitations under the License.
68505 *
68506 * =============================================================================
68507 */
68508 /**
68509 * A ring buffer, providing O(1) FIFO, LIFO, and related operations.
68510 */
68511 class RingBuffer {
68512 /**
68513 * Constructs a `RingBuffer`.
68514 * @param capacity The number of items that the buffer can accomodate.
68515 */
68516 constructor(capacity) {
68517 this.capacity = capacity;
68518 // Note we store the indices in the range 0 <= index < 2*capacity.
68519 // This allows us to distinguish the full from the empty case.
68520 // See https://www.snellman.net/blog/archive/2016-12-13-ring-buffers/
68521 this.begin = 0; // inclusive
68522 this.end = 0; // exclusive
68523 if (capacity == null) {
68524 throw new RangeError('Can\'t create a ring buffer of unknown capacity.');
68525 }
68526 if (capacity < 1) {
68527 throw new RangeError('Can\'t create ring buffer of capacity < 1.');
68528 }
68529 this.data = new Array(capacity);
68530 this.doubledCapacity = 2 * capacity;
68531 }
68532 /**
68533 * Map any index into the range 0 <= index < 2*capacity.
68534 */
68535 wrap(index) {
68536 // don't trust % on negative numbers
68537 while (index < 0) {
68538 index += this.doubledCapacity;
68539 }
68540 return index % this.doubledCapacity;
68541 }
68542 get(index) {
68543 if (index < 0) {
68544 throw new RangeError('Can\'t get item at a negative index.');
68545 }
68546 return this.data[index % this.capacity];
68547 }
68548 set(index, value) {
68549 if (index < 0) {
68550 throw new RangeError('Can\'t set item at a negative index.');
68551 }
68552 this.data[index % this.capacity] = value;
68553 }
68554 /**
68555 * Returns the current number of items in the buffer.
68556 */
68557 length() {
68558 let length = this.end - this.begin;
68559 if (length < 0) {
68560 length = this.doubledCapacity + length;
68561 }
68562 return length;
68563 }
68564 /**
68565 * Reports whether the buffer is full.
68566 * @returns true if the number of items in the buffer equals its capacity, and
68567 * false otherwise.
68568 */
68569 isFull() {
68570 return this.length() === this.capacity;
68571 }
68572 /**
68573 * Reports whether the buffer is empty.
68574 * @returns true if the number of items in the buffer equals zero, and
68575 * false otherwise.
68576 */
68577 isEmpty() {
68578 return this.length() === 0;
68579 }
68580 /**
68581 * Adds an item to the end of the buffer.
68582 */
68583 push(value) {
68584 if (this.isFull()) {
68585 throw new RangeError('Ring buffer is full.');
68586 }
68587 this.set(this.end, value);
68588 this.end = this.wrap(this.end + 1);
68589 }
68590 /**
68591 * Adds many items to the end of the buffer, in order.
68592 */
68593 pushAll(values) {
68594 for (const value of values) {
68595 this.push(value);
68596 }
68597 }
68598 /**
68599 * Removes and returns the last item in the buffer.
68600 */
68601 pop() {
68602 if (this.isEmpty()) {
68603 throw new RangeError('Ring buffer is empty.');
68604 }
68605 this.end = this.wrap(this.end - 1);
68606 const result = this.get(this.end);
68607 this.set(this.end, undefined);
68608 return result;
68609 }
68610 /**
68611 * Adds an item to the beginning of the buffer.
68612 */
68613 unshift(value) {
68614 if (this.isFull()) {
68615 throw new RangeError('Ring buffer is full.');
68616 }
68617 this.begin = this.wrap(this.begin - 1);
68618 this.set(this.begin, value);
68619 }
68620 /**
68621 * Removes and returns the first item in the buffer.
68622 */
68623 shift() {
68624 if (this.isEmpty()) {
68625 throw new RangeError('Ring buffer is empty.');
68626 }
68627 const result = this.get(this.begin);
68628 this.set(this.begin, undefined);
68629 this.begin = this.wrap(this.begin + 1);
68630 return result;
68631 }
68632 /**
68633 * Removes and returns a specific item in the buffer, and moves the last item
68634 * to the vacated slot. This is useful for implementing a shuffling stream.
68635 * Note that this operation necessarily scrambles the original order.
68636 *
68637 * @param relativeIndex: the index of the item to remove, relative to the
68638 * first item in the buffer (e.g., hiding the ring nature of the underlying
68639 * storage).
68640 */
68641 shuffleExcise(relativeIndex) {
68642 if (this.isEmpty()) {
68643 throw new RangeError('Ring buffer is empty.');
68644 }
68645 const index = this.wrap(this.begin + relativeIndex);
68646 const result = this.get(index);
68647 this.set(index, this.pop());
68648 return result;
68649 }
68650 }
68651
68652 /**
68653 * @license
68654 * Copyright 2018 Google LLC. All Rights Reserved.
68655 * Licensed under the Apache License, Version 2.0 (the "License");
68656 * you may not use this file except in compliance with the License.
68657 * You may obtain a copy of the License at
68658 *
68659 * http://www.apache.org/licenses/LICENSE-2.0
68660 *
68661 * Unless required by applicable law or agreed to in writing, software
68662 * distributed under the License is distributed on an "AS IS" BASIS,
68663 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
68664 * See the License for the specific language governing permissions and
68665 * limitations under the License.
68666 *
68667 * =============================================================================
68668 */
68669 class GrowingRingBuffer extends RingBuffer {
68670 /**
68671 * Constructs a `GrowingRingBuffer`.
68672 */
68673 constructor() {
68674 super(GrowingRingBuffer.INITIAL_CAPACITY);
68675 }
68676 isFull() {
68677 return false;
68678 }
68679 push(value) {
68680 if (super.isFull()) {
68681 this.expand();
68682 }
68683 super.push(value);
68684 }
68685 unshift(value) {
68686 if (super.isFull()) {
68687 this.expand();
68688 }
68689 super.unshift(value);
68690 }
68691 /**
68692 * Doubles the capacity of the buffer.
68693 */
68694 expand() {
68695 const newCapacity = this.capacity * 2;
68696 const newData = new Array(newCapacity);
68697 const len = this.length();
68698 // Rotate the buffer to start at index 0 again, since we can't just
68699 // allocate more space at the end.
68700 for (let i = 0; i < len; i++) {
68701 newData[i] = this.get(this.wrap(this.begin + i));
68702 }
68703 this.data = newData;
68704 this.capacity = newCapacity;
68705 this.doubledCapacity = 2 * this.capacity;
68706 this.begin = 0;
68707 this.end = len;
68708 }
68709 }
68710 GrowingRingBuffer.INITIAL_CAPACITY = 32;
68711
68712 /**
68713 * @license
68714 * Copyright 2018 Google LLC. All Rights Reserved.
68715 * Licensed under the Apache License, Version 2.0 (the "License");
68716 * you may not use this file except in compliance with the License.
68717 * You may obtain a copy of the License at
68718 *
68719 * http://www.apache.org/licenses/LICENSE-2.0
68720 *
68721 * Unless required by applicable law or agreed to in writing, software
68722 * distributed under the License is distributed on an "AS IS" BASIS,
68723 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
68724 * See the License for the specific language governing permissions and
68725 * limitations under the License.
68726 *
68727 * =============================================================================
68728 */
68729 // Here we implement a simple asynchronous iterator.
68730 // This lets us avoid using either third-party stream libraries or
68731 // recent TypeScript language support requiring polyfills.
68732 /**
68733 * Create a `LazyIterator` from an array of items.
68734 */
68735 function iteratorFromItems(items) {
68736 return new ArrayIterator(items);
68737 }
68738 /**
68739 * Create a `LazyIterator` of incrementing integers.
68740 */
68741 function iteratorFromIncrementing(start) {
68742 let i = start;
68743 return iteratorFromFunction(() => ({ value: i++, done: false }));
68744 }
68745 /**
68746 * Create a `LazyIterator` from a function.
68747 *
68748 * ```js
68749 * let i = -1;
68750 * const func = () =>
68751 * ++i < 5 ? {value: i, done: false} : {value: null, done: true};
68752 * const iter = tf.data.iteratorFromFunction(func);
68753 * await iter.forEachAsync(e => console.log(e));
68754 * ```
68755 *
68756 * @param func A function that produces data on each call.
68757 */
68758 function iteratorFromFunction(func) {
68759 return new FunctionCallIterator(func);
68760 }
68761 /**
68762 * Create a `LazyIterator` by concatenating underlying streams, which are
68763 * themselves provided as a stream.
68764 *
68765 * This can also be thought of as a "stream flatten" operation.
68766 *
68767 * @param baseIterators A stream of streams to be concatenated.
68768 * @param baseErrorHandler An optional function that can intercept `Error`s
68769 * raised during a `next()` call on the base stream. This function can decide
68770 * whether the error should be propagated, whether the error should be
68771 * ignored, or whether the base stream should be terminated.
68772 */
68773 function iteratorFromConcatenated(baseIterators, baseErrorHandler) {
68774 return new ChainedIterator(baseIterators, baseErrorHandler);
68775 }
68776 /**
68777 * Create a `LazyIterator` by concatenating streams produced by calling a
68778 * stream-generating function a given number of times.
68779 *
68780 * Since a `LazyIterator` is read-once, it cannot be repeated, but this
68781 * function can be used to achieve a similar effect:
68782 *
68783 * LazyIterator.ofConcatenatedFunction(() => new MyIterator(), 6);
68784 *
68785 * @param iteratorFunc: A function that produces a new stream on each call.
68786 * @param count: The number of times to call the function.
68787 * @param baseErrorHandler An optional function that can intercept `Error`s
68788 * raised during a `next()` call on the base stream. This function can decide
68789 * whether the error should be propagated, whether the error should be
68790 * ignored, or whether the base stream should be terminated.
68791 */
68792 function iteratorFromConcatenatedFunction(iteratorFunc, count, baseErrorHandler) {
68793 return iteratorFromConcatenated(iteratorFromFunction(iteratorFunc).take(count), baseErrorHandler);
68794 }
68795 /**
68796 * Create a `LazyIterator` by zipping together an array, dict, or nested
68797 * structure of `LazyIterator`s (and perhaps additional constants).
68798 *
68799 * The underlying streams must provide elements in a consistent order such
68800 * that they correspond.
68801 *
68802 * Typically, the underlying streams should have the same number of
68803 * elements. If they do not, the behavior is determined by the
68804 * `mismatchMode` argument.
68805 *
68806 * The nested structure of the `iterators` argument determines the
68807 * structure of elements in the resulting iterator.
68808 *
68809 * @param iterators: An array or object containing LazyIterators at the
68810 * leaves.
68811 * @param mismatchMode: Determines what to do when one underlying iterator
68812 * is exhausted before the others. `ZipMismatchMode.FAIL` (the default)
68813 * causes an error to be thrown in this case. `ZipMismatchMode.SHORTEST`
68814 * causes the zipped iterator to terminate with the first underlying
68815 * stream to finish, so elements remaining on the longer streams are ignored.
68816 * `ZipMismatchMode.LONGEST` causes the zipped stream to continue, filling
68817 * in nulls for the exhausted streams, until all streams are exhausted.
68818 */
68819 function iteratorFromZipped(iterators, mismatchMode = ZipMismatchMode.FAIL) {
68820 return new ZipIterator(iterators, mismatchMode);
68821 }
68822 /**
68823 * An asynchronous iterator, providing lazy access to a potentially
68824 * unbounded stream of elements.
68825 *
68826 * Iterator can be obtained from a dataset:
68827 * `const iter = await dataset.iterator();`
68828 */
    class LazyIterator {
        /**
         * Collect all remaining elements of a bounded stream into an array.
         * Obviously this will succeed only for small streams that fit in memory.
         * Useful for testing.
         *
         * @returns A Promise for an array of stream elements, which will resolve
         *   when the stream is exhausted.
         */
        async toArray() {
            const result = [];
            let x = await this.next();
            while (!x.done) {
                result.push(x.value);
                x = await this.next();
            }
            return result;
        }
        /**
         * Collect all elements of this dataset into an array while prefetching
         * 100 elements. This is useful for testing, because the prefetch changes
         * the order in which the Promises are resolved along the processing
         * pipeline. This may help expose bugs where results are dependent on the
         * order of Promise resolution rather than on the logical order of the
         * stream (i.e., due to hidden mutable state).
         *
         * @returns A Promise for an array of stream elements, which will resolve
         *   when the stream is exhausted.
         */
        async toArrayForTest() {
            const stream = this.prefetch(100);
            const result = [];
            let x = await stream.next();
            while (!x.done) {
                result.push(x.value);
                x = await stream.next();
            }
            return result;
        }
        /**
         * Draw items from the stream until it is exhausted, discarding them.
         *
         * This can be useful when the stream has side effects but no output. In
         * that case, calling this function guarantees that the stream will be
         * fully processed.
         */
        async resolveFully() {
            let x = await this.next();
            while (!x.done) {
                x = await this.next();
            }
        }
        /**
         * Draw items from the stream until it is exhausted, or `predicate`
         * returns a falsy value for an element.
         *
         * NOTE(review): the predicate is applied before the done-check, so on an
         * already-exhausted stream it is invoked once with the terminal `null`
         * value — predicates should tolerate that.
         */
        async resolveWhile(predicate) {
            let x = await this.next();
            let shouldContinue = predicate(x.value);
            while ((!x.done) && shouldContinue) {
                x = await this.next();
                shouldContinue = predicate(x.value);
            }
        }
        /**
         * Handles errors thrown on this stream using a provided handler function.
         *
         * @param handler A function that handles any `Error` thrown during a
         *   `next()` call and returns true if the stream should continue
         *   (dropping the failed call) or false if the stream should quietly
         *   terminate. If the handler itself throws (or rethrows) an `Error`,
         *   that will be propagated.
         *
         * @returns A `LazyIterator` of elements passed through from upstream,
         *   possibly filtering or terminating on upstream `next()` calls that
         *   throw an `Error`.
         */
        handleErrors(handler) {
            return new ErrorHandlingLazyIterator(this, handler);
        }
        // TODO(soergel): Implement reduce() etc.
        /**
         * Filters this stream according to `predicate`.
         *
         * @param predicate A function mapping a stream element to a boolean or a
         *   `Promise` for one.
         *
         * @returns A `LazyIterator` of elements for which the predicate was true.
         */
        filter(predicate) {
            return new FilterIterator(this, predicate);
        }
        /**
         * Maps this stream through a synchronous 1-to-1 transform.
         *
         * @param transform A function mapping a stream element to a transformed
         *   element.
         *
         * @returns A `LazyIterator` of transformed elements.
         */
        map(transform) {
            return new MapIterator(this, transform);
        }
        /**
         * Maps this stream through an async 1-to-1 transform.
         *
         * @param transform A function mapping a stream element to a `Promise`
         *   for a transformed stream element.
         *
         * @returns A `LazyIterator` of transformed elements.
         */
        mapAsync(transform) {
            return new AsyncMapIterator(this, transform);
        }
        /**
         * Maps this stream through an async 1-to-1 transform, forcing serial
         * execution (each transform starts only after the previous completes).
         *
         * @param transform A function mapping a stream element to a transformed
         *   element.
         *
         * @returns A `LazyIterator` of transformed elements.
         */
        serialMapAsync(transform) {
            return new AsyncMapIterator(this, transform).serial();
        }
        /**
         * Maps this stream through a 1-to-many transform.
         *
         * @param transform A function mapping a stream element to an array of
         *   transformed elements.
         *
         * @returns A `DataStream` of transformed elements.
         */
        flatmap(transform) {
            return new FlatmapIterator(this, transform);
        }
        /**
         * Apply a function to every element of the stream.
         *
         * @param f A function to apply to each stream element.
         */
        async forEachAsync(f) {
            return this.map(f).resolveFully();
        }
        /**
         * Apply a function to every element of the stream, forcing serial
         * execution.
         *
         * @param f A function to apply to each stream element. Should return
         *   'true' to indicate that the stream should continue, or 'false' to
         *   cause it to terminate.
         */
        async serialForEach(f) {
            return this.serialMapAsync(f).resolveWhile(x => (x === true));
        }
        /**
         * Groups elements into batches, represented as arrays of elements.
         *
         * We can think of the elements of this iterator as 'rows' (even if they
         * are nested structures). By the same token, consecutive values for a
         * given key within the elements form a 'column'. This matches the usual
         * sense of 'row' and 'column' when processing tabular data (e.g.,
         * parsing a CSV).
         *
         * Thus, "row-major" means that the resulting batch is simply a
         * collection of rows: `[row1, row2, row3, ...]`. This is in contrast to
         * the column-major form, which is needed for vectorized computation.
         *
         * @param batchSize The number of elements desired per batch.
         * @param smallLastBatch Whether to emit the final batch when it has
         *   fewer than batchSize elements. Default true.
         * @returns A `LazyIterator` of batches of elements, represented as
         *   arrays of the original element type.
         */
        rowMajorBatch(batchSize, smallLastBatch = true) {
            return new RowMajorBatchIterator(this, batchSize, smallLastBatch);
        }
        /**
         * Groups elements into batches, represented in column-major form.
         *
         * "Column-major" means that the resulting batch is a (potentially
         * nested) structure representing the columns: each column entry
         * collects the values found at that position across a range of input
         * elements. This representation allows for vectorized computation, in
         * contrast to the row-major form.
         *
         * The inputs should all have the same nested structure (i.e., of arrays
         * and dicts). The result is a single object with the same nested
         * structure, where the leaves are arrays collecting the values of the
         * inputs at that location (or, optionally, the result of a custom
         * function applied to those arrays).
         *
         * @param batchSize The number of elements desired per batch.
         * @param smallLastBatch Whether to emit the final batch when it has
         *   fewer than batchSize elements. Default true.
         * @param zipFn: (optional) A function that expects an array of elements
         *   at a single node of the object tree, and returns a `DeepMapResult`.
         *   The `DeepMapResult` either provides a result value for that node
         *   (i.e., representing the subtree), or indicates that the node should
         *   be processed recursively. The default zipFn recurses as far as
         *   possible and places arrays at the leaves.
         * @returns A `LazyIterator` of batches of elements, represented as an
         *   object with collections at the leaves.
         */
        columnMajorBatch(batchSize, smallLastBatch = true,
        // tslint:disable-next-line:no-any
        zipFn = zipToList) {
            // First collect the desired number of input elements as a row-major batch.
            const rowBatches = this.rowMajorBatch(batchSize, smallLastBatch);
            // Now 'rotate' or 'pivot' the data, collecting all values from each column
            // in the batch (i.e., for each key within the elements) into an array.
            return rowBatches.map(x => deepZip(x, zipFn));
        }
        /**
         * Concatenate this `LazyIterator` with another.
         *
         * @param iterator A `LazyIterator` to be concatenated onto this one.
         * @param baseErrorHandler An optional function that can intercept
         *   `Error`s raised during a `next()` call on the base stream. This
         *   function can decide whether the error should be propagated, whether
         *   the error should be ignored, or whether the base stream should be
         *   terminated.
         * @returns A `LazyIterator`.
         */
        concatenate(iterator, baseErrorHandler) {
            return new ChainedIterator(iteratorFromItems([this, iterator]), baseErrorHandler);
        }
        /**
         * Limits this stream to return at most `count` items.
         *
         * @param count The maximum number of items to provide from the stream.
         *   If a negative or undefined value is given, the entire stream is
         *   returned unaltered.
         */
        take(count) {
            // Note: `count < 0` is false for null/undefined, so the null check
            // is still reached; `== null` intentionally matches both.
            if (count < 0 || count == null) {
                return this;
            }
            return new TakeIterator(this, count);
        }
        /**
         * Skips the first `count` items in this stream.
         *
         * @param count The number of items to skip. If a negative or undefined
         *   value is given, the entire stream is returned unaltered.
         */
        skip(count) {
            if (count < 0 || count == null) {
                return this;
            }
            return new SkipIterator(this, count);
        }
        /**
         * Prefetch the first `bufferSize` items in this stream.
         *
         * Note this prefetches Promises, but makes no guarantees about when
         * those Promises resolve.
         *
         * @param bufferSize: An integer specifying the number of elements to be
         *   prefetched.
         */
        prefetch(bufferSize) {
            return new PrefetchIterator(this, bufferSize);
        }
        // TODO(soergel): deep sharded shuffle, where supported
        /**
         * Randomly shuffles the elements of this stream using a sliding window.
         *
         * @param windowSize: An integer specifying the number of elements from
         *   this stream from which the new stream will sample.
         * @param seed: (Optional.) An integer specifying the random seed that
         *   will be used to create the distribution.
         */
        shuffle(windowSize, seed) {
            return new ShuffleIterator(this, windowSize, seed);
        }
        /**
         * Force an iterator to execute serially: each next() call will await
         * the prior one, so that they cannot execute concurrently.
         */
        serial() {
            return new SerialIterator(this);
        }
    }
69117 // ============================================================================
69118 // The following private classes serve to implement the chainable methods
69119 // on LazyIterator. Unfortunately they can't be placed in separate files,
69120 // due to resulting trouble with circular imports.
69121 // ============================================================================
69122 // Iterators that just extend LazyIterator directly
69123 // ============================================================================
69124 class ArrayIterator extends LazyIterator {
69125 constructor(items) {
69126 super();
69127 this.items = items;
69128 this.trav = 0;
69129 }
69130 summary() {
69131 return `Array of ${this.items.length} items`;
69132 }
69133 async next() {
69134 if (this.trav >= this.items.length) {
69135 return { value: null, done: true };
69136 }
69137 const item = this.items[this.trav];
69138 this.trav++;
69139 return { value: deepClone(item), done: false };
69140 }
69141 }
69142 class FunctionCallIterator extends LazyIterator {
69143 constructor(nextFn) {
69144 super();
69145 this.nextFn = nextFn;
69146 }
69147 summary() {
69148 return `Function call`;
69149 }
69150 async next() {
69151 try {
69152 return this.nextFn();
69153 }
69154 catch (e) {
69155 // Modify the error message but leave the stack trace intact
69156 e.message =
69157 `Error thrown while iterating through a dataset: ${e.message}`;
69158 throw e;
69159 }
69160 }
69161 }
69162 class SerialIterator extends LazyIterator {
69163 constructor(upstream) {
69164 super();
69165 this.upstream = upstream;
69166 this.lastRead = Promise.resolve({ value: null, done: false });
69167 }
69168 summary() {
69169 return `${this.upstream.summary()} -> Serial`;
69170 }
69171 async next() {
69172 // This sets this.lastRead to a new Promise right away, as opposed to
69173 // saying `await this.lastRead; this.lastRead = this.serialNext();` which
69174 // would not work because this.nextRead would be updated only after the
69175 // promise resolves.
69176 this.lastRead = this.lastRead.then(() => this.serialNext());
69177 return this.lastRead;
69178 }
69179 async serialNext() {
69180 return this.upstream.next();
69181 }
69182 }
69183 class SkipIterator extends LazyIterator {
69184 constructor(upstream, maxCount) {
69185 super();
69186 this.upstream = upstream;
69187 this.maxCount = maxCount;
69188 // Local state that should not be clobbered by out-of-order execution.
69189 this.count = 0;
69190 this.lastRead = Promise.resolve({ value: null, done: false });
69191 }
69192 summary() {
69193 return `${this.upstream.summary()} -> Skip`;
69194 }
69195 async next() {
69196 // This sets this.lastRead to a new Promise right away, as opposed to
69197 // saying `await this.lastRead; this.lastRead = this.serialNext();` which
69198 // would not work because this.nextRead would be updated only after the
69199 // promise resolves.
69200 this.lastRead = this.lastRead.then(() => this.serialNext());
69201 return this.lastRead;
69202 }
69203 async serialNext() {
69204 // TODO(soergel): consider tradeoffs of reading in parallel, eg.
69205 // collecting next() promises in an Array and then waiting for
69206 // Promise.all() of those. Benefit: pseudo-parallel execution. Drawback:
69207 // maybe delayed GC.
69208 while (this.count++ < this.maxCount) {
69209 const skipped = await this.upstream.next();
69210 // short-circuit if upstream is already empty
69211 if (skipped.done) {
69212 return skipped;
69213 }
69214 dispose(skipped.value);
69215 }
69216 return this.upstream.next();
69217 }
69218 }
69219 class TakeIterator extends LazyIterator {
69220 constructor(upstream, maxCount) {
69221 super();
69222 this.upstream = upstream;
69223 this.maxCount = maxCount;
69224 this.count = 0;
69225 }
69226 summary() {
69227 return `${this.upstream.summary()} -> Take`;
69228 }
69229 async next() {
69230 if (this.count++ >= this.maxCount) {
69231 return { value: null, done: true };
69232 }
69233 return this.upstream.next();
69234 }
69235 }
69236 // Note this batch just groups items into row-wise element arrays.
69237 // Rotating these to a column-wise representation happens only at the dataset
69238 // level.
69239 class RowMajorBatchIterator extends LazyIterator {
69240 constructor(upstream, batchSize, enableSmallLastBatch = true) {
69241 super();
69242 this.upstream = upstream;
69243 this.batchSize = batchSize;
69244 this.enableSmallLastBatch = enableSmallLastBatch;
69245 this.lastRead = Promise.resolve({ value: null, done: false });
69246 }
69247 summary() {
69248 return `${this.upstream.summary()} -> RowMajorBatch`;
69249 }
69250 async next() {
69251 // This sets this.lastRead to a new Promise right away, as opposed to
69252 // saying `await this.lastRead; this.lastRead = this.serialNext();` which
69253 // would not work because this.nextRead would be updated only after the
69254 // promise resolves.
69255 this.lastRead = this.lastRead.then(() => this.serialNext());
69256 return this.lastRead;
69257 }
69258 async serialNext() {
69259 const batch = [];
69260 while (batch.length < this.batchSize) {
69261 const item = await this.upstream.next();
69262 if (item.done) {
69263 if (this.enableSmallLastBatch && batch.length > 0) {
69264 return { value: batch, done: false };
69265 }
69266 return { value: null, done: true };
69267 }
69268 batch.push(item.value);
69269 }
69270 return { value: batch, done: false };
69271 }
69272 }
69273 class FilterIterator extends LazyIterator {
69274 constructor(upstream, predicate) {
69275 super();
69276 this.upstream = upstream;
69277 this.predicate = predicate;
69278 this.lastRead = Promise.resolve({ value: null, done: false });
69279 }
69280 summary() {
69281 return `${this.upstream.summary()} -> Filter`;
69282 }
69283 async next() {
69284 // This sets this.lastRead to a new Promise right away, as opposed to
69285 // saying `await this.lastRead; this.lastRead = this.serialNext();` which
69286 // would not work because this.nextRead would be updated only after the
69287 // promise resolves.
69288 this.lastRead = this.lastRead.then(() => this.serialNext());
69289 return this.lastRead;
69290 }
69291 async serialNext() {
69292 while (true) {
69293 const item = await this.upstream.next();
69294 if (item.done || this.predicate(item.value)) {
69295 return item;
69296 }
69297 dispose(item.value);
69298 }
69299 }
69300 }
    /**
     * Applies a synchronous 1-to-1 `transform` to each upstream element,
     * disposing input tensors that the transform did not pass through to its
     * output. The transform itself is responsible for tidying any intermediate
     * tensors it creates.
     */
    class MapIterator extends LazyIterator {
        constructor(upstream, transform) {
            super();
            this.upstream = upstream;
            this.transform = transform;
        }
        summary() {
            return `${this.upstream.summary()} -> Map`;
        }
        async next() {
            const item = await this.upstream.next();
            if (item.done) {
                return { value: null, done: true };
            }
            // Snapshot the input tensors BEFORE transforming: the transform may
            // mutate the item in place, so we must remember the inputs now and
            // later dispose only those that were not passed through to the
            // output. Note too that the transform function is responsible for
            // tidying any intermediate Tensors; here we are concerned only
            // about the inputs.
            const inputTensors = getTensorsInContainer(item.value);
            const mapped = this.transform(item.value);
            const outputTensors = getTensorsInContainer(mapped);
            // TODO(soergel) faster intersection
            // TODO(soergel) move to tf.disposeExcept(in, out)?
            for (const t of inputTensors) {
                if (!isTensorInList(t, outputTensors)) {
                    t.dispose();
                }
            }
            return { value: mapped, done: false };
        }
    }
69334 class ErrorHandlingLazyIterator extends LazyIterator {
69335 constructor(upstream, handler) {
69336 super();
69337 this.upstream = upstream;
69338 this.handler = handler;
69339 this.count = 0;
69340 this.lastRead = Promise.resolve({ value: null, done: false });
69341 }
69342 summary() {
69343 return `${this.upstream.summary()} -> handleErrors`;
69344 }
69345 async next() {
69346 // This sets this.lastRead to a new Promise right away, as opposed to
69347 // saying `await this.lastRead; this.lastRead = this.serialNext();` which
69348 // would not work because this.nextRead would be updated only after the
69349 // promise resolves.
69350 this.lastRead = this.lastRead.then(() => this.serialNext());
69351 return this.lastRead;
69352 }
69353 async serialNext() {
69354 while (true) {
69355 try {
69356 return await this.upstream.next();
69357 }
69358 catch (e) {
69359 if (!this.handler(e)) {
69360 return { value: null, done: true };
69361 }
69362 // If the handler returns true, loop and fetch the next upstream item.
69363 // If the upstream iterator throws an endless stream of errors, and if
69364 // the handler says to ignore them, then we loop forever here. That is
69365 // the correct behavior-- it's up to the handler to decide when to stop.
69366 }
69367 }
69368 }
69369 }
    /**
     * Applies an async 1-to-1 `transform` to each upstream element, disposing
     * input tensors that the transform did not pass through to its output. The
     * transform itself is responsible for tidying any intermediate tensors it
     * creates.
     */
    class AsyncMapIterator extends LazyIterator {
        constructor(upstream, transform) {
            super();
            this.upstream = upstream;
            this.transform = transform;
        }
        summary() {
            return `${this.upstream.summary()} -> AsyncMap`;
        }
        async next() {
            const item = await this.upstream.next();
            if (item.done) {
                return { value: null, done: true };
            }
            // Snapshot the input tensors BEFORE transforming: the transform may
            // mutate the item in place, so we must remember the inputs now and
            // later dispose only those that were not passed through to the
            // output. Note too that the transform function is responsible for
            // tidying any intermediate Tensors; here we are concerned only
            // about the inputs.
            const inputTensors = getTensorsInContainer(item.value);
            const mapped = await this.transform(item.value);
            const outputTensors = getTensorsInContainer(mapped);
            // TODO(soergel) faster intersection
            // TODO(soergel) move to tf.disposeExcept(in, out)?
            for (const t of inputTensors) {
                if (!isTensorInList(t, outputTensors)) {
                    t.dispose();
                }
            }
            return { value: mapped, done: false };
        }
    }
69403 // Iterators that maintain a queue of pending items
69404 // ============================================================================
69405 /**
69406 * A base class for transforming streams that operate by maintaining an
69407 * output queue of elements that are ready to return via next(). This is
69408 * commonly required when the transformation is 1-to-many: A call to next()
69409 * may trigger a call to the underlying stream, which will produce many
69410 * mapped elements of this stream-- of which we need to return only one, so
69411 * we have to queue the rest.
69412 */
    class OneToManyIterator extends LazyIterator {
        constructor() {
            super();
            // Elements ready to be returned by next(), in FIFO order.
            // Subclasses fill this via pump().
            this.outputQueue = new GrowingRingBuffer();
            this.lastRead = Promise.resolve({ value: null, done: false });
        }
        async next() {
            // This sets this.lastRead to a new Promise right away (before any
            // await), as opposed to saying `await this.lastRead; this.lastRead
            // = this.serialNext();` which would not work because lastRead would
            // be updated only after the promise resolves — letting concurrent
            // next() calls race instead of chaining serially.
            this.lastRead = this.lastRead.then(() => this.serialNext());
            return this.lastRead;
        }
        async serialNext() {
            // Pump until the queue contains at least one item if possible.
            // If the upstream source is exhausted (pump() returned false) AND
            // there are no items left in the output queue, then this stream is
            // also exhausted.
            while (this.outputQueue.length() === 0) {
                // TODO(soergel): consider parallel reads.
                if (!await this.pump()) {
                    return { value: null, done: true };
                }
            }
            return { value: this.outputQueue.shift(), done: false };
        }
    }
    /**
     * Maps each upstream element to an array of output elements via
     * `transform`, flattening the results into a single stream. Input tensors
     * not passed through to the output are disposed.
     */
    class FlatmapIterator extends OneToManyIterator {
        constructor(upstream, transform) {
            super();
            this.upstream = upstream;
            this.transform = transform;
        }
        summary() {
            return `${this.upstream.summary()} -> Flatmap`;
        }
        async pump() {
            const item = await this.upstream.next();
            if (item.done) {
                // Upstream exhausted; tell OneToManyIterator to stop pumping.
                return false;
            }
            // Snapshot the input tensors BEFORE transforming: the transform may
            // mutate the item in place, so we must remember the inputs now and
            // later dispose only those that were not passed through to the
            // output. Note too that the transform function is responsible for
            // tidying any intermediate Tensors; here we are concerned only
            // about the inputs.
            const inputTensors = getTensorsInContainer(item.value);
            const mappedArray = this.transform(item.value);
            const outputTensors = getTensorsInContainer(mappedArray);
            this.outputQueue.pushAll(mappedArray);
            // TODO(soergel) faster intersection, and deduplicate outputTensors
            // TODO(soergel) move to tf.disposeExcept(in, out)?
            for (const t of inputTensors) {
                if (!isTensorInList(t, outputTensors)) {
                    t.dispose();
                }
            }
            return true;
        }
    }
69473 /**
69474 * Provides a `LazyIterator` that concatenates a stream of underlying
69475 * streams.
69476 *
69477 * Doing this in a concurrency-safe way requires some trickery. In
69478 * particular, we want this stream to return the elements from the
69479 * underlying streams in the correct order according to when next() was
69480 * called, even if the resulting Promises resolve in a different order.
69481 */
    class ChainedIterator extends LazyIterator {
        constructor(iterators, baseErrorHandler) {
            super();
            this.baseErrorHandler = baseErrorHandler;
            // Strict Promise execution order:
            // a next() call may not even begin until the previous one completes.
            this.lastRead = null;
            // Local state that should not be clobbered by out-of-order execution.
            this.iterator = null; // the underlying stream currently being drained
            this.moreIterators = iterators; // the stream of upcoming streams
        }
        summary() {
            const upstreamSummaries = 'TODO: fill in upstream of chained summaries';
            return `${upstreamSummaries} -> Chained`;
        }
        async next() {
            // Chain synchronously so concurrent next() calls execute in order.
            this.lastRead = this.readFromChain(this.lastRead);
            return this.lastRead;
        }
        async readFromChain(lastRead) {
            // Must await on the previous read since the previous read may have
            // advanced the stream of streams, from which we need to read.
            // This is unfortunate since we can't parallelize reads. Which means
            // prefetching of chained streams is a no-op.
            // One solution is to prefetch immediately upstream of this.
            await lastRead;
            if (this.iterator == null) {
                // No current stream: pull the next one from the stream of streams.
                const iteratorResult = await this.moreIterators.next();
                if (iteratorResult.done) {
                    // No more streams to stream from.
                    return { value: null, done: true };
                }
                this.iterator = iteratorResult.value;
                if (this.baseErrorHandler != null) {
                    this.iterator = this.iterator.handleErrors(this.baseErrorHandler);
                }
            }
            const itemResult = await this.iterator.next();
            if (itemResult.done) {
                // Current stream exhausted; recurse to advance to the next one.
                this.iterator = null;
                return this.readFromChain(lastRead);
            }
            return itemResult;
        }
    }
69527 var ZipMismatchMode;
69528 (function (ZipMismatchMode) {
69529 ZipMismatchMode[ZipMismatchMode["FAIL"] = 0] = "FAIL";
69530 ZipMismatchMode[ZipMismatchMode["SHORTEST"] = 1] = "SHORTEST";
69531 ZipMismatchMode[ZipMismatchMode["LONGEST"] = 2] = "LONGEST"; // use nulls for exhausted streams; use up the longest stream.
69532 })(ZipMismatchMode || (ZipMismatchMode = {}));
69533 /**
69534 * Provides a `LazyIterator` that zips together an array, dict, or nested
69535 * structure of `LazyIterator`s (and perhaps additional constants).
69536 *
69537 * The underlying streams must provide elements in a consistent order such
69538 * that they correspond.
69539 *
69540 * Typically, the underlying streams should have the same number of
69541 * elements. If they do not, the behavior is determined by the
69542 * `mismatchMode` argument.
69543 *
69544 * The nested structure of the `iterators` argument determines the
69545 * structure of elements in the resulting iterator.
69546 *
69547 * Doing this in a concurrency-safe way requires some trickery. In
69548 * particular, we want this stream to return the elements from the
69549 * underlying streams in the correct order according to when next() was
69550 * called, even if the resulting Promises resolve in a different order.
69551 *
69552 * @param iterators: An array or object containing LazyIterators at the
69553 * leaves.
69554 * @param mismatchMode: Determines what to do when one underlying iterator
69555 * is exhausted before the others. `ZipMismatchMode.FAIL` (the default)
69556 * causes an error to be thrown in this case. `ZipMismatchMode.SHORTEST`
69557 * causes the zipped iterator to terminate with the first underlying
69558 * stream to finish, so elements remaining on the longer streams are ignored.
69559 * `ZipMismatchMode.LONGEST` causes the zipped stream to continue, filling
69560 * in nulls for the exhausted streams, until all streams are exhausted.
69561 */
    class ZipIterator extends LazyIterator {
        constructor(iterators, mismatchMode = ZipMismatchMode.FAIL) {
            super();
            this.iterators = iterators;
            this.mismatchMode = mismatchMode;
            this.count = 0; // number of zipped elements produced so far
            this.currentPromise = null; // chains next() calls serially
        }
        summary() {
            const upstreamSummaries = 'TODO: fill in upstream of zip summaries';
            return `{${upstreamSummaries}} -> Zip`;
        }
        async nextState(afterState) {
            // This chaining ensures that the underlying next() are not even
            // called before the previous ones have resolved.
            await afterState;
            // Collect underlying iterator "done" signals as a side effect in
            // getNext(). The counters are tallied as each underlying promise
            // resolves, i.e. during the deepMapAndAwaitAll below.
            let numIterators = 0;
            let iteratorsDone = 0;
            function getNext(container) {
                if (container instanceof LazyIterator) {
                    const result = container.next();
                    return {
                        value: result.then(x => {
                            numIterators++;
                            if (x.done) {
                                iteratorsDone++;
                            }
                            return x.value;
                        }),
                        recurse: false
                    };
                }
                else {
                    // Non-iterator node: recurse into the nested structure.
                    return { value: null, recurse: true };
                }
            }
            const mapped = await deepMapAndAwaitAll(this.iterators, getNext);
            if (numIterators === iteratorsDone) {
                // The streams have all ended.
                return { value: null, done: true };
            }
            if (iteratorsDone > 0) {
                // Some (but not all) streams ended; behavior depends on the mode.
                switch (this.mismatchMode) {
                    case ZipMismatchMode.FAIL:
                        throw new Error('Zipped streams should have the same length. ' +
                            `Mismatched at element ${this.count}.`);
                    case ZipMismatchMode.SHORTEST:
                        return { value: null, done: true };
                    case ZipMismatchMode.LONGEST:
                    default:
                    // Continue. The exhausted streams already produced value: null.
                }
            }
            this.count++;
            return { value: mapped, done: false };
        }
        async next() {
            // Chain synchronously so concurrent next() calls execute in order.
            this.currentPromise = this.nextState(this.currentPromise);
            return this.currentPromise;
        }
    }
69625 // Iterators that maintain a ring buffer of pending promises
69626 // ============================================================================
69627 /**
69628 * A stream that prefetches a given number of items from an upstream source,
69629 * returning them in FIFO order.
69630 *
69631 * Note this prefetches Promises, but makes no guarantees about when those
69632 * Promises resolve.
69633 */
69634 class PrefetchIterator extends LazyIterator {
69635 constructor(upstream, bufferSize) {
69636 super();
69637 this.upstream = upstream;
69638 this.bufferSize = bufferSize;
69639 this.buffer = new RingBuffer(bufferSize);
69640 }
69641 summary() {
69642 return `${this.upstream.summary()} -> Prefetch`;
69643 }
69644 /**
69645 * Refill the prefetch buffer. Returns only after the buffer is full, or
69646 * the upstream source is exhausted.
69647 */
69648 refill() {
69649 while (!this.buffer.isFull()) {
69650 const v = this.upstream.next();
69651 this.buffer.push(v);
69652 }
69653 }
69654 next() {
69655 this.refill();
69656 // This shift will never throw an error because the buffer is always
69657 // full after a refill. If the stream is exhausted, the buffer will be
69658 // full of Promises that will resolve to the end-of-stream signal.
69659 return this.buffer.shift();
69660 }
69661 }
69662 /**
69663 * A stream that performs a sliding-window random shuffle on an upstream
69664 * source. This is like a `PrefetchIterator` except that the items are
69665 * returned in randomized order. Mixing naturally improves as the buffer
69666 * size increases.
69667 */
    class ShuffleIterator extends PrefetchIterator {
        constructor(upstream, windowSize, seed) {
            super(upstream, windowSize);
            this.upstream = upstream;
            this.windowSize = windowSize;
            // Local state that should not be clobbered by out-of-order execution.
            this.upstreamExhausted = false;
            // NOTE(review): because of `||`, a falsy seed ('' or 0) silently falls
            // back to a time-based seed — confirm this is intended.
            this.random = seedrandom.alea(seed || now().toString());
            this.lastRead = Promise.resolve({ value: null, done: false });
        }
        async next() {
            // This sets this.lastRead to a new Promise right away, as opposed to
            // saying `await this.lastRead; this.lastRead = this.serialNext();` which
            // would not work because this.nextRead would be updated only after the
            // promise resolves.
            this.lastRead = this.lastRead.then(() => this.serialNext());
            return this.lastRead;
        }
        // Returns a uniform pseudo-random integer in [0, max).
        randomInt(max) {
            return Math.floor(this.random() * max);
        }
        // Picks a random slot in the current (possibly partially full) buffer.
        chooseIndex() {
            return this.randomInt(this.buffer.length());
        }
        async serialNext() {
            // TODO(soergel): consider performance
            if (!this.upstreamExhausted) {
                this.refill();
            }
            // Excise random entries until one is a real value; entries that carry
            // the end-of-stream signal mark the upstream as exhausted instead.
            while (!this.buffer.isEmpty()) {
                const chosenIndex = this.chooseIndex();
                const result = await this.buffer.shuffleExcise(chosenIndex);
                if (result.done) {
                    this.upstreamExhausted = true;
                }
                else {
                    this.refill();
                    return result;
                }
            }
            return { value: null, done: true };
        }
    }
69711
69712 /**
69713 * @license
69714 * Copyright 2018 Google LLC. All Rights Reserved.
69715 * Licensed under the Apache License, Version 2.0 (the "License");
69716 * you may not use this file except in compliance with the License.
69717 * You may obtain a copy of the License at
69718 *
69719 * http://www.apache.org/licenses/LICENSE-2.0
69720 *
69721 * Unless required by applicable law or agreed to in writing, software
69722 * distributed under the License is distributed on an "AS IS" BASIS,
69723 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69724 * See the License for the specific language governing permissions and
69725 * limitations under the License.
69726 *
69727 * =============================================================================
69728 */
69729 // TODO(soergel): consider vectorized operations within the pipeline.
69730 /**
69731 * Represents a potentially large list of independent data elements (typically
69732 * 'samples' or 'examples').
69733 *
69734 * A 'data example' may be a primitive, an array, a map from string keys to
69735 * values, or any nested structure of these.
69736 *
69737 * A `Dataset` represents an ordered collection of elements, together with a
69738 * chain of transformations to be performed on those elements. Each
69739 * transformation is a method of `Dataset` that returns another `Dataset`, so
69740 * these may be chained, e.g.
69741 * `const processedDataset = rawDataset.filter(...).map(...).batch(...)`.
69742 *
69743 * Data loading and transformation is done in a lazy, streaming fashion. The
69744 * dataset may be iterated over multiple times; each iteration starts the data
69745 * loading anew and recapitulates the transformations.
69746 *
69747 * A `Dataset` is typically processed as a stream of unbatched examples -- i.e.,
69748 * its transformations are applied one example at a time. Batching produces a
69749 * new `Dataset` where each element is a batch. Batching should usually come
69750 * last in a pipeline, because data transformations are easier to express on a
69751 * per-example basis than on a per-batch basis.
69752 *
69753 * The following code examples are calling `await dataset.forEachAsync(...)` to
69754 * iterate once over the entire dataset in order to print out the data.
69755 *
69756 * @doc {heading: 'Data', subheading: 'Classes', namespace: 'data'}
69757 */
    class Dataset {
        constructor() {
            // Number of elements, or Infinity for endless datasets, or null when
            // the size cannot be determined.
            this.size = null;
        }
        // TODO(soergel): Make Datasets report whether repeated iterator() calls
        // produce the same result (e.g., reading from a file) or different results
        // (e.g., from the webcam). Currently we don't make this distinction but it
        // could be important for the user to know.
        // abstract isDeterministic(): boolean;
        /**
         * Groups elements into batches.
         *
         * It is assumed that each of the incoming dataset elements has the same
         * structure -- i.e. the same set of keys at each location in an object
         * hierarchy. For each key, the resulting `Dataset` provides a batched
         * element collecting all of the incoming values for that key.
         *
         * * Incoming primitives are grouped into a 1-D Tensor.
         * * Incoming Tensors are grouped into a new Tensor where the 0th axis is
         *   the batch dimension.
         * * Incoming arrays are converted to Tensor and then batched.
         * * A nested array is interpreted as an n-D Tensor, so the batched result
         *   has n+1 dimensions.
         * * An array that cannot be converted to Tensor produces an error.
         *
         * If an array should not be batched as a unit, it should first be converted
         * to an object with integer keys.
         *
         * Here are a few examples:
         *
         * Batch a dataset of numbers:
         * ```js
         * const a = tf.data.array([1, 2, 3, 4, 5, 6, 7, 8]).batch(4);
         * await a.forEachAsync(e => e.print());
         * ```
         *
         * Batch a dataset of arrays:
         * ```js
         * const b = tf.data.array([[1], [2], [3], [4], [5], [6], [7], [8]]).batch(4);
         * await b.forEachAsync(e => e.print());
         * ```
         *
         * Batch a dataset of objects:
         * ```js
         * const c = tf.data.array([{a: 1, b: 11}, {a: 2, b: 12}, {a: 3, b: 13},
         *   {a: 4, b: 14}, {a: 5, b: 15}, {a: 6, b: 16}, {a: 7, b: 17},
         *   {a: 8, b: 18}]).batch(4);
         * await c.forEachAsync(e => {
         *   console.log('{');
         *   for(var key in e) {
         *     console.log(key+':');
         *     e[key].print();
         *   }
         *   console.log('}');
         * })
         * ```
         *
         * @param batchSize The number of elements desired per batch.
         * @param smallLastBatch Whether to emit the final batch when it has fewer
         *   than batchSize elements. Default true.
         * @returns A `Dataset`, from which a stream of batches can be obtained.
         *
         * @doc {heading: 'Data', subheading: 'Classes'}
         */
        batch(batchSize, smallLastBatch = true) {
            const base = this;
            assert$1(batchSize > 0, () => `batchSize needs to be positive, but it is
      ${batchSize}`);
            let size;
            if (this.size === Infinity || this.size == null) {
                // If the size of this dataset is infinity or null, the new size keeps the
                // same.
                size = this.size;
            }
            else if (smallLastBatch) {
                // If the size of this dataset is known and include small last batch, the
                // new size is full batch count plus last batch.
                size = Math.ceil(this.size / batchSize);
            }
            else {
                // If the size of this dataset is known and not include small last batch,
                // the new size is full batch count.
                size = Math.floor(this.size / batchSize);
            }
            return datasetFromIteratorFn(async () => {
                return (await base.iterator())
                    .columnMajorBatch(batchSize, smallLastBatch, deepBatchConcat);
            }, size);
        }
        /**
         * Concatenates this `Dataset` with another.
         *
         * ```js
         * const a = tf.data.array([1, 2, 3]);
         * const b = tf.data.array([4, 5, 6]);
         * const c = a.concatenate(b);
         * await c.forEachAsync(e => console.log(e));
         * ```
         *
         * @param dataset A `Dataset` to be concatenated onto this one.
         * @returns A `Dataset`.
         *
         * @doc {heading: 'Data', subheading: 'Classes'}
         */
        concatenate(dataset) {
            const base = this;
            let size;
            if (this.size === Infinity || dataset.size === Infinity) {
                // If the size of any of these two dataset is infinity, new size is
                // infinity.
                size = Infinity;
            }
            else if (this.size != null && dataset.size != null) {
                // If the size of both datasets are known and not infinity, new size is
                // sum the size of these two datasets.
                size = this.size + dataset.size;
            }
            else {
                // If neither of these two datasets has infinite size and any of these two
                // datasets' size is null, the new size is null.
                size = null;
            }
            return datasetFromIteratorFn(async () => (await base.iterator()).concatenate(await dataset.iterator()), size);
        }
        /**
         * Filters this dataset according to `predicate`.
         *
         * ```js
         * const a = tf.data.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
         *   .filter(x => x%2 === 0);
         * await a.forEachAsync(e => console.log(e));
         * ```
         *
         * @param predicate A function mapping a dataset element to a boolean or a
         * `Promise` for one.
         *
         * @returns A `Dataset` of elements for which the predicate was true.
         *
         * @doc {heading: 'Data', subheading: 'Classes'}
         */
        filter(predicate) {
            const base = this;
            let size;
            if (this.size === Infinity) {
                // If the size of this dataset is infinity, new size is infinity
                size = Infinity;
            }
            else {
                // If this dataset has limited elements, new size is null because it might
                // exhausted randomly.
                size = null;
            }
            return datasetFromIteratorFn(async () => {
                // tidy() disposes intermediate Tensors created by the predicate.
                return (await base.iterator()).filter(x => tidy(() => predicate(x)));
            }, size);
        }
        /**
         * Apply a function to every element of the dataset.
         *
         * After the function is applied to a dataset element, any Tensors contained
         * within that element are disposed.
         *
         * ```js
         * const a = tf.data.array([1, 2, 3]);
         * await a.forEachAsync(e => console.log(e));
         * ```
         *
         * @param f A function to apply to each dataset element.
         * @returns A `Promise` that resolves after all elements have been processed.
         *
         * @doc {heading: 'Data', subheading: 'Classes'}
         */
        async forEachAsync(f) {
            return (await this.iterator()).forEachAsync(f);
        }
        /**
         * Maps this dataset through a 1-to-1 transform.
         *
         * ```js
         * const a = tf.data.array([1, 2, 3]).map(x => x*x);
         * await a.forEachAsync(e => console.log(e));
         * ```
         *
         * @param transform A function mapping a dataset element to a transformed
         *   dataset element.
         *
         * @returns A `Dataset` of transformed elements.
         *
         * @doc {heading: 'Data', subheading: 'Classes'}
         */
        map(transform) {
            const base = this;
            return datasetFromIteratorFn(async () => {
                // tidy() disposes intermediate Tensors created by the transform.
                return (await base.iterator()).map(x => tidy(() => transform(x)));
            }, this.size);
        }
        /**
         * Maps this dataset through an async 1-to-1 transform.
         *
         * ```js
         * const a =
         *  tf.data.array([1, 2, 3]).mapAsync(x => new Promise(function(resolve){
         *    setTimeout(() => {
         *      resolve(x * x);
         *    }, Math.random()*1000 + 500);
         *  }));
         * console.log(await a.toArray());
         * ```
         *
         * @param transform A function mapping a dataset element to a `Promise` for a
         *   transformed dataset element.  This transform is responsible for disposing
         *   any intermediate `Tensor`s, i.e. by wrapping its computation in
         *   `tf.tidy()`; that cannot be automated here (as it is in the synchronous
         *   `map()` case).
         *
         * @returns A `Dataset` of transformed elements.
         *
         * @doc {heading: 'Data', subheading: 'Classes'}
         */
        mapAsync(transform) {
            const base = this;
            return datasetFromIteratorFn(async () => {
                return (await base.iterator()).mapAsync(transform);
            }, this.size);
        }
        /**
         * Creates a `Dataset` that prefetches elements from this dataset.
         *
         * @param bufferSize: An integer specifying the number of elements to be
         *   prefetched.
         * @returns A `Dataset`.
         *
         * @doc {heading: 'Data', subheading: 'Classes'}
         */
        prefetch(bufferSize) {
            if (bufferSize == null) {
                throw new RangeError('`Dataset.prefetch()` requires bufferSize to be specified.');
            }
            const base = this;
            return datasetFromIteratorFn(async () => (await base.iterator()).prefetch(bufferSize), this.size);
        }
        /**
         * Repeats this dataset `count` times.
         *
         * NOTE: If this dataset is a function of global state (e.g. a random number
         * generator), then different repetitions may produce different elements.
         *
         * ```js
         * const a = tf.data.array([1, 2, 3]).repeat(3);
         * await a.forEachAsync(e => console.log(e));
         * ```
         *
         * @param count: (Optional) An integer, representing the number of times
         *   the dataset should be repeated. The default behavior (if `count` is
         *   `undefined` or negative) is for the dataset be repeated indefinitely.
         * @returns A `Dataset`.
         *
         * @doc {heading: 'Data', subheading: 'Classes'}
         */
        repeat(count) {
            const base = this;
            let size;
            if (this.size != null && count > 0) {
                // If this dataset has size and count is positive, new size is current
                // size multiply count. This also covers the case that current size is
                // infinity.
                size = this.size * count;
            }
            else if (count === 0) {
                // If count is 0, new size is 0.
                size = 0;
            }
            else if (this.size != null && (count === undefined || count < 0)) {
                // If this dataset has size and count is undefined or negative, the
                // dataset will be repeated indefinitely and new size is infinity.
                size = Infinity;
            }
            else {
                // If the size of this dataset is null, the new dataset's size is null.
                size = null;
            }
            return datasetFromIteratorFn(async () => {
                // An endless stream of fresh iterators over the base dataset,
                // truncated to `count` repetitions by take().
                const iteratorIterator = iteratorFromFunction(async () => ({ value: await base.iterator(), done: false }));
                return iteratorFromConcatenated(iteratorIterator.take(count));
            }, size);
        }
        /**
         * Creates a `Dataset` that skips `count` initial elements from this dataset.
         *
         * ```js
         * const a = tf.data.array([1, 2, 3, 4, 5, 6]).skip(3);
         * await a.forEachAsync(e => console.log(e));
         * ```
         *
         * @param count: The number of elements of this dataset that should be skipped
         *   to form the new dataset.  If `count` is greater than the size of this
         *   dataset, the new dataset will contain no elements.  If `count`
         *   is `undefined` or negative, skips the entire dataset.
         *
         * @returns A `Dataset`.
         *
         * @doc {heading: 'Data', subheading: 'Classes'}
         */
        skip(count) {
            const base = this;
            let size;
            if (this.size != null && count >= 0 && this.size >= count) {
                // If the size of this dataset is greater than count, the new dataset's
                // size is current size minus skipped size.This also covers the case that
                // current size is infinity.
                size = this.size - count;
            }
            else if (this.size != null &&
                (this.size < count || count === undefined || count < 0)) {
                // If the size of this dataset is smaller than count, or count is
                // undefined or negative, skips the entire dataset and the new size is 0.
                size = 0;
            }
            else {
                // If the size of this dataset is null, the new dataset's size is null.
                size = null;
            }
            return datasetFromIteratorFn(async () => (await base.iterator()).skip(count), size);
        }
        /**
         * Pseudorandomly shuffles the elements of this dataset. This is done in a
         * streaming manner, by sampling from a given number of prefetched elements.
         *
         * ```js
         * const a = tf.data.array([1, 2, 3, 4, 5, 6]).shuffle(3);
         * await a.forEachAsync(e => console.log(e));
         * ```
         *
         * @param bufferSize: An integer specifying the number of elements from this
         *   dataset from which the new dataset will sample.
         * @param seed: (Optional) An integer specifying the random seed that will
         *   be used to create the distribution.
         * @param reshuffleEachIteration: (Optional) A boolean, which if true
         *   indicates that the dataset should be pseudorandomly reshuffled each time
         *   it is iterated over. If false, elements will be returned in the same
         *   shuffled order on each iteration. (Defaults to `true`.)
         * @returns A `Dataset`.
         *
         * @doc {heading: 'Data', subheading: 'Classes'}
         */
        shuffle(bufferSize, seed, reshuffleEachIteration = true) {
            if (bufferSize == null || bufferSize < 0) {
                if (this.size == null) {
                    throw new RangeError('`Dataset.shuffle()` requires bufferSize to be specified.');
                }
                else {
                    throw new RangeError('`Dataset.shuffle()` requires bufferSize to be specified. ' +
                        'If your data fits in main memory (for regular JS objects), ' +
                        'and/or GPU memory (for `tf.Tensor`s), consider setting ' +
                        `bufferSize to the dataset size (${this.size} elements)`);
                }
            }
            const base = this;
            // NOTE(review): because of `||`, a falsy seed (0 or '') falls back to a
            // time-based seed — confirm intended.
            const random = seedrandom.alea(seed || now().toString());
            return datasetFromIteratorFn(async () => {
                let seed2 = random.int32();
                if (reshuffleEachIteration) {
                    // Advancing the outer PRNG yields a different seed per iteration.
                    seed2 += random.int32();
                }
                return (await base.iterator()).shuffle(bufferSize, seed2.toString());
            }, this.size);
        }
        /**
         * Creates a `Dataset` with at most `count` initial elements from this
         * dataset.
         *
         * ```js
         * const a = tf.data.array([1, 2, 3, 4, 5, 6]).take(3);
         * await a.forEachAsync(e => console.log(e));
         * ```
         *
         * @param count: The number of elements of this dataset that should be taken
         *   to form the new dataset.  If `count` is `undefined` or negative, or if
         *   `count` is greater than the size of this dataset, the new dataset will
         *   contain all elements of this dataset.
         * @returns A `Dataset`.
         *
         * @doc {heading: 'Data', subheading: 'Classes'}
         */
        take(count) {
            const base = this;
            let size;
            if (this.size != null && this.size > count) {
                // If the size of this dataset is greater than count, the new dataset's
                // size is count.
                size = count;
            }
            else if (this.size != null && this.size <= count) {
                // If the size of this dataset is equal or smaller than count, the new
                // dataset's size is the size of this dataset.
                size = this.size;
            }
            else {
                // If the size of this dataset is null, the new dataset's size is null.
                size = null;
            }
            return datasetFromIteratorFn(async () => (await base.iterator()).take(count), size);
        }
        /**
         * Collect all elements of this dataset into an array.
         *
         * Obviously this will succeed only for small datasets that fit in memory.
         * Useful for testing and generally should be avoided if possible.
         *
         * ```js
         * const a = tf.data.array([1, 2, 3, 4, 5, 6]);
         * console.log(await a.toArray());
         * ```
         *
         * @returns A Promise for an array of elements, which will resolve
         *   when a new stream has been obtained and fully consumed.
         *
         * @doc {heading: 'Data', subheading: 'Classes'}
         */
        async toArray() {
            if (this.size === Infinity) {
                throw new Error('Can not convert infinite data stream to array.');
            }
            return (await this.iterator()).toArray();
        }
        /**
         * Collect all elements of this dataset into an array with prefetching 100
         * elements. This is useful for testing, because the prefetch changes the
         * order in which the Promises are resolved along the processing pipeline.
         * This may help expose bugs where results are dependent on the order of
         * Promise resolution rather than on the logical order of the stream (i.e.,
         * due to hidden mutable state).
         *
         * @returns A Promise for an array of elements, which will resolve
         *   when a new stream has been obtained and fully consumed.
         */
        async toArrayForTest() {
            if (this.size === Infinity) {
                throw new Error('Can not convert infinite data stream to array.');
            }
            return (await this.iterator()).toArrayForTest();
        }
    }
    // TODO(soergel): deep sharded shuffle, where supported
    // Upper bound suggested for internal shuffle/prefetch buffer sizing.
    Dataset.MAX_BUFFER_SIZE = 10000;
70203 /**
70204 * Create a `Dataset` defined by a provided iterator() function.
70205 *
70206 * ```js
70207 * let i = -1;
70208 * const func = () =>
70209 * ++i < 5 ? {value: i, done: false} : {value: null, done: true};
 * const ds = tf.data.datasetFromIteratorFn(
 *     async () => tf.data.iteratorFromFunction(func));
70212 * await ds.forEachAsync(e => console.log(e));
70213 * ```
70214 */
70215 function datasetFromIteratorFn(iteratorFn, size = null) {
70216 return new class extends Dataset {
70217 constructor() {
70218 super(...arguments);
70219 this.size = size;
70220 }
70221 /*
70222 * Provide a new stream of elements. Note this will also start new streams
70223 * from any underlying `Dataset`s.
70224 */
70225 async iterator() {
70226 return iteratorFn();
70227 }
70228 }();
70229 }
70230 /**
70231 * Create a `Dataset` from an array of elements.
70232 *
70233 * Create a Dataset from an array of objects:
70234 * ```js
70235 * const a = tf.data.array([{'item': 1}, {'item': 2}, {'item': 3}]);
70236 * await a.forEachAsync(e => console.log(e));
70237 * ```
70238 *
70239 * Create a Dataset from an array of numbers:
70240 * ```js
70241 * const a = tf.data.array([4, 5, 6]);
70242 * await a.forEachAsync(e => console.log(e));
70243 * ```
70244 * @param items An array of elements that will be parsed as items in a dataset.
70245 *
70246 * @doc {heading: 'Data', subheading: 'Creation', namespace: 'data'}
70247 */
70248 function array(items) {
70249 return datasetFromIteratorFn(async () => iteratorFromItems(items), items.length);
70250 }
70251 /**
70252 * Create a `Dataset` by zipping together an array, dict, or nested
70253 * structure of `Dataset`s (and perhaps additional constants).
70254 * The underlying datasets must provide elements in a consistent order such that
70255 * they correspond.
70256 *
70257 * The number of elements in the resulting dataset is the same as the size of
70258 * the smallest dataset in datasets.
70259 *
70260 * The nested structure of the `datasets` argument determines the
70261 * structure of elements in the resulting iterator.
70262 *
70263 * Note this means that, given an array of two datasets that produce dict
70264 * elements, the result is a dataset that produces elements that are arrays
70265 * of two dicts:
70266 *
70267 * Zip an array of datasets:
70268 * ```js
70269 * console.log('Zip two datasets of objects:');
70270 * const ds1 = tf.data.array([{a: 1}, {a: 2}, {a: 3}]);
70271 * const ds2 = tf.data.array([{b: 4}, {b: 5}, {b: 6}]);
70272 * const ds3 = tf.data.zip([ds1, ds2]);
70273 * await ds3.forEachAsync(e => console.log(JSON.stringify(e)));
70274 *
70275 * // If the goal is to merge the dicts in order to produce elements like
70276 * // {a: ..., b: ...}, this requires a second step such as:
70277 * console.log('Merge the objects:');
70278 * const ds4 = ds3.map(x => {return {a: x[0].a, b: x[1].b}});
70279 * await ds4.forEachAsync(e => console.log(e));
70280 * ```
70281 *
70282 * Zip a dict of datasets:
70283 * ```js
70284 * const a = tf.data.array([{a: 1}, {a: 2}, {a: 3}]);
70285 * const b = tf.data.array([{b: 4}, {b: 5}, {b: 6}]);
70286 * const c = tf.data.zip({c: a, d: b});
70287 * await c.forEachAsync(e => console.log(JSON.stringify(e)));
70288 * ```
70289 *
70290 * @doc {heading: 'Data', subheading: 'Operations', namespace: 'data'}
70291 */
70292 function zip(datasets) {
70293 // manually type-check the argument for JS users
70294 if (!isIterable(datasets)) {
70295 throw new Error('The argument to zip() must be an object or array.');
70296 }
70297 let size;
70298 if (Array.isArray(datasets)) {
70299 for (let i = 0; i < datasets.length; i++) {
70300 size = size == null ? datasets[i].size :
70301 Math.min(size, datasets[i].size);
70302 }
70303 }
70304 else if (datasets instanceof Object) {
70305 for (const ds in datasets) {
70306 size = size == null ? datasets[ds].size :
70307 Math.min(size, datasets[ds].size);
70308 }
70309 }
70310 return datasetFromIteratorFn(async () => {
70311 const streams = await deepMapAndAwaitAll(datasets, d => {
70312 if (d instanceof Dataset) {
70313 return { value: d.iterator(), recurse: false };
70314 }
70315 else if (isIterable(d)) {
70316 return { value: null, recurse: true };
70317 }
70318 else {
70319 throw new Error('Leaves of the structure passed to zip() must be Datasets, ' +
70320 'not primitives.');
70321 }
70322 });
70323 return iteratorFromZipped(streams, ZipMismatchMode.SHORTEST);
70324 }, size);
70325 }
70326 /**
70327 * A zip function for use with deepZip, passed via the columnMajorBatch call.
70328 *
70329 * Accepts an array of identically-structured nested elements and either batches
70330 * them (if they are primitives, numeric arrays, or Tensors) or requests
70331 * recursion (if not).
70332 */
70333 // tslint:disable-next-line:no-any
70334 function deepBatchConcat(rows) {
70335 if (rows === null) {
70336 return null;
70337 }
70338 // use the first item to decide whether to recurse or batch here.
70339 const exampleRow = rows[0];
70340 if (canTensorify(exampleRow)) {
70341 // rows is an array of primitives, Tensors, or arrays. Batch them.
70342 const value = batchConcat(rows);
70343 return { value, recurse: false };
70344 }
70345 // the example row is an object, so recurse into it.
70346 return { value: null, recurse: true };
70347 }
70348 /**
70349 * Assembles a list of same-shaped numbers, number arrays, or Tensors
70350 * into a single new Tensor where axis 0 is the batch dimension.
70351 */
70352 function batchConcat(arrays) {
70353 if (arrays.length === 0) {
70354 // We can't return an empty Tensor because we don't know the element shape.
70355 throw new Error('Can\'t make a batch of zero elements.');
70356 }
70357 if (arrays[0] instanceof Tensor) {
70358 // Input is an array of Tensors
70359 return stack(arrays);
70360 }
70361 else {
70362 // Input is a possibly-nested array of numbers.
70363 return tensor(arrays);
70364 }
70365 }
70366
70367 /**
70368 * @license
70369 * Copyright 2018 Google LLC. All Rights Reserved.
70370 * Licensed under the Apache License, Version 2.0 (the "License");
70371 * you may not use this file except in compliance with the License.
70372 * You may obtain a copy of the License at
70373 *
70374 * http://www.apache.org/licenses/LICENSE-2.0
70375 *
70376 * Unless required by applicable law or agreed to in writing, software
70377 * distributed under the License is distributed on an "AS IS" BASIS,
70378 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70379 * See the License for the specific language governing permissions and
70380 * limitations under the License.
70381 *
70382 * =============================================================================
70383 */
70384 /**
70385 * Represents a potentially large collection of text lines.
70386 *
70387 * The results are not batched.
70388 */
70389 class TextLineDataset extends Dataset {
70390 /**
70391 * Create a `TextLineDataset`.
70392 *
70393 * @param input A `DataSource` providing a chunked, UTF8-encoded byte stream.
70394 */
70395 constructor(input) {
70396 super();
70397 this.input = input;
70398 }
70399 async iterator() {
70400 const inputIterator = await this.input.iterator();
70401 const utf8Iterator = inputIterator.decodeUTF8();
70402 const lineIterator = utf8Iterator.split('\n').map(line => {
70403 // Windows/DOS format text file has extra line breaker at the end of line.
70404 if (line.endsWith('\r')) {
70405 line = line.slice(0, -1);
70406 }
70407 return line;
70408 });
70409 return lineIterator;
70410 }
70411 }
70412
70413 /**
70414 * @license
70415 * Copyright 2018 Google LLC. All Rights Reserved.
70416 * Licensed under the Apache License, Version 2.0 (the "License");
70417 * you may not use this file except in compliance with the License.
70418 * You may obtain a copy of the License at
70419 *
70420 * http://www.apache.org/licenses/LICENSE-2.0
70421 *
70422 * Unless required by applicable law or agreed to in writing, software
70423 * distributed under the License is distributed on an "AS IS" BASIS,
70424 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70425 * See the License for the specific language governing permissions and
70426 * limitations under the License.
70427 *
70428 * =============================================================================
70429 */
    // Quote character recognized by the CSV field parser.
    const CODE_QUOTE = '"';
    // States of the CSV line-parsing state machine: outside any field, inside an
    // unquoted field, inside a quoted field, just after a quote while quoted
    // (potential escaped quote or field end), and an escaped quote within a
    // quoted field.
    const STATE_OUT = Symbol('out');
    const STATE_FIELD = Symbol('field');
    const STATE_QUOTE = Symbol('quote');
    const STATE_QUOTE_AFTER_QUOTE = Symbol('quoteafterquote');
    const STATE_WITHIN_QUOTE_IN_QUOTE = Symbol('quoteinquote');
70436 /**
70437 * Represents a potentially large collection of delimited text records.
70438 *
70439 * The produced `TensorContainer`s each contain one key-value pair for
70440 * every column of the table. When a field is empty in the incoming data, the
70441 * resulting value is `undefined`, or throw error if it is required. Values
70442 * that can be parsed as numbers are emitted as type `number`, other values
70443 * are parsed as `string`.
70444 *
70445 * The results are not batched.
70446 *
70447 * @doc {heading: 'Data', subheading: 'Classes', namespace: 'data'}
70448 */
    class CSVDataset extends Dataset {
        /**
         * Returns column names of the csv dataset. If `configuredColumnsOnly` is
         * true, return column names in `columnConfigs`. If `configuredColumnsOnly` is
         * false and `columnNames` is provided, return `columnNames`. If
         * `configuredColumnsOnly` is false and `columnNames` is not provided, return
         * all column names parsed from the csv file. For example usage please go to
         * `tf.data.csv`.
         *
         * @doc {heading: 'Data', subheading: 'Classes'}
         */
        async columnNames() {
            // Header parsing and validation are deferred until first use.
            if (!this.columnNamesValidated) {
                await this.setColumnNames();
            }
            return this.configuredColumnsOnly ? Object.keys(this.columnConfigs) :
                this.fullColumnNames;
        }
        /* Resolves and validates `this.fullColumnNames`:
         * 1) If `columnNames` is provided as string[], use this string[] as output
         * keys in corresponding order. The length must match the number of inferred
         * columns if `hasHeader` is true.
         * 2) If `columnNames` is not provided, parse header line as `columnNames` if
         * hasHeader is true. If `hasHeader` is false, throw an error.
         * 3) If `columnConfigs` is provided, all the keys in `columnConfigs` must
         * exist in parsed `columnNames`.
         */
        async setColumnNames() {
            const columnNamesFromFile = await this.maybeReadHeaderLine();
            if (!this.fullColumnNames && !columnNamesFromFile) {
                // Throw an error if columnNames is not provided and no header line.
                throw new Error('Column names must be provided if there is no header line.');
            }
            else if (this.fullColumnNames && columnNamesFromFile) {
                // Check provided columnNames match header line.
                assert$1(columnNamesFromFile.length === this.fullColumnNames.length, () => 'The length of provided columnNames (' +
                    this.fullColumnNames.length.toString() +
                    ') does not match the length of the header line read from ' +
                    'file (' + columnNamesFromFile.length.toString() + ').');
            }
            if (!this.fullColumnNames) {
                this.fullColumnNames = columnNamesFromFile;
            }
            // Check if there are duplicate column names, by counting occurrences
            // of each name into a dictionary.
            const counts = this.fullColumnNames.reduce((countAcc, name) => {
                countAcc[name] = (countAcc[name] + 1) || 1;
                return countAcc;
            }, {});
            const duplicateNames = Object.keys(counts).filter((name) => (counts[name] > 1));
            assert$1(duplicateNames.length === 0, () => 'Duplicate column names found: ' + duplicateNames.toString());
            // Check if keys in columnConfigs match columnNames.
            if (this.columnConfigs) {
                for (const key of Object.keys(this.columnConfigs)) {
                    const index = this.fullColumnNames.indexOf(key);
                    if (index === -1) {
                        throw new Error('The key "' + key +
                            '" provided in columnConfigs does not match any of the column ' +
                            'names (' + this.fullColumnNames.toString() + ').');
                    }
                }
            }
            this.columnNamesValidated = true;
        }
        // Reads (and consumes) the first line of the underlying line dataset when
        // `hasHeader` is true and returns the parsed column names; returns null
        // when there is no header row.
        async maybeReadHeaderLine() {
            if (this.hasHeader) {
                const iter = await this.base.iterator();
                const firstElement = await iter.next();
                if (firstElement.done) {
                    throw new Error('No data was found for CSV parsing.');
                }
                const firstLine = firstElement.value;
                // Skip element-count validation: column names are not known yet
                // while parsing the header itself.
                const headers = this.parseRow(firstLine, false);
                return headers;
            }
            else {
                return null;
            }
        }
        /**
         * Create a `CSVDataset`.
         *
         * @param input A `DataSource` providing a chunked, UTF8-encoded byte stream.
         * @param csvConfig (Optional) A CSVConfig object that contains configurations
         *     of reading and decoding from CSV file(s).
         *
         *     hasHeader: (Optional) A boolean value that indicates whether the first
         *     row of provided CSV file is a header line with column names, and should
         *     not be included in the data. Defaults to `true`.
         *
         *     columnNames: (Optional) A list of strings that corresponds to
         *     the CSV column names, in order. If provided, it ignores the column
         *     names inferred from the header row. If not provided, infers the column
         *     names from the first row of the records. If hasHeader is false and
         *     columnNames is not provided, this method throws an error.
         *
         *     columnConfigs: (Optional) A dictionary whose key is column names, value
         *     is an object stating if this column is required, column's data type,
         *     default value, and if this column is label. If provided, keys must
         *     correspond to names provided in columnNames or inferred from the file
         *     header lines. If isLabel is true for any column, returns an array of
         *     two items: the first item is a dict of features key/value pairs, the
         *     second item is a dict of labels key/value pairs. If no feature is
         *     marked as label, returns a dict of features only.
         *
         *     configuredColumnsOnly (Optional) If true, only columns provided in
         *     columnConfigs will be parsed and provided during iteration.
         *
         *     delimiter (Optional) The string used to parse each line of the input
         *     file. Defaults to `,`.
         */
        constructor(input, csvConfig) {
            super();
            this.input = input;
            // Defaults below are overwritten from csvConfig when provided.
            this.hasHeader = true;
            this.fullColumnNames = null;
            this.columnNamesValidated = false;
            this.columnConfigs = null;
            this.configuredColumnsOnly = false;
            this.delimiter = ',';
            this.delimWhitespace = false;
            this.base = new TextLineDataset(input);
            if (!csvConfig) {
                csvConfig = {};
            }
            // Only an explicit `false` disables the header; undefined keeps true.
            this.hasHeader = csvConfig.hasHeader === false ? false : true;
            this.fullColumnNames = csvConfig.columnNames;
            this.columnConfigs = csvConfig.columnConfigs;
            this.configuredColumnsOnly = csvConfig.configuredColumnsOnly;
            if (csvConfig.delimWhitespace) {
                // delimWhitespace and an explicit delimiter are mutually exclusive;
                // whitespace mode collapses runs of spaces during parsing.
                assert$1(csvConfig.delimiter == null, () => 'Delimiter should not be provided when delimWhitespace is true.');
                this.delimWhitespace = true;
                this.delimiter = ' ';
            }
            else {
                this.delimiter = csvConfig.delimiter ? csvConfig.delimiter : ',';
            }
        }
        async iterator() {
            if (!this.columnNamesValidated) {
                await this.setColumnNames();
            }
            let lines = await this.base.iterator();
            if (this.hasHeader) {
                // We previously read the first line to get the columnNames.
                // Now that we're providing data, skip it.
                lines = lines.skip(1);
            }
            return lines.map(x => this.makeDataElement(x));
        }
        // Parses one CSV line into a features dict (and a labels dict when any
        // column config has isLabel), applying per-column defaults and dtypes.
        makeDataElement(line) {
            const values = this.parseRow(line);
            const features = {};
            const labels = {};
            for (let i = 0; i < this.fullColumnNames.length; i++) {
                const key = this.fullColumnNames[i];
                const config = this.columnConfigs ? this.columnConfigs[key] : null;
                if (this.configuredColumnsOnly && !config) {
                    // This column is not selected.
                    continue;
                }
                else {
                    const value = values[i];
                    let parsedValue = null;
                    if (value === '') {
                        // If default value is provided, use it. If default value is not
                        // provided, set as undefined.
                        if (config && config.default !== undefined) {
                            parsedValue = config.default;
                        }
                        else if (config && (config.required || config.isLabel)) {
                            throw new Error(`Required column ${key} is empty in this line: ${line}`);
                        }
                        else {
                            parsedValue = undefined;
                        }
                    }
                    else {
                        // A value is present, so parse it based on type
                        const valueAsNum = Number(value);
                        if (isNaN(valueAsNum)) {
                            // The value is a string and this column is declared as boolean
                            // in config, parse it as boolean.
                            if (config && config.dtype === 'bool') {
                                parsedValue = this.getBoolean(value);
                            }
                            else {
                                // Set value as string
                                parsedValue = value;
                            }
                        }
                        else if (!config || !config.dtype) {
                            // If this value is a number and no type config is provided, return
                            // it as number.
                            parsedValue = valueAsNum;
                        }
                        else {
                            // If this value is a number and data type is provided, parse it
                            // according to provided data type.
                            switch (config.dtype) {
                                case 'float32':
                                    parsedValue = valueAsNum;
                                    break;
                                case 'int32':
                                    parsedValue = Math.floor(valueAsNum);
                                    break;
                                case 'bool':
                                    parsedValue = this.getBoolean(value);
                                    break;
                                default:
                                    parsedValue = valueAsNum;
                            }
                        }
                    }
                    // Check if this column is label.
                    (config && config.isLabel) ? labels[key] = parsedValue :
                        features[key] = parsedValue;
                }
            }
            // If label exists, return an object of features and labels as {xs:features,
            // ys:labels}, otherwise return features only.
            if (Object.keys(labels).length === 0) {
                return features;
            }
            else {
                return { xs: features, ys: labels };
            }
        }
        // Maps a textual boolean to 1/0: '1' or case-insensitive 'true' -> 1,
        // anything else -> 0.
        getBoolean(value) {
            if (value === '1' || value.toLowerCase() === 'true') {
                return 1;
            }
            else {
                return 0;
            }
        }
        // Splits one CSV line into field strings, honoring quoted fields and
        // doubled-quote escapes via a small state machine.
        // adapted from https://beta.observablehq.com/@mbostock/streaming-csv
        parseRow(line, validateElementCount = true) {
            const result = [];
            // Start index (in `line`) of the field currently being read.
            let readOffset = 0;
            const readLength = line.length;
            let currentState = STATE_OUT;
            // Goes through the line to parse quote.
            for (let i = 0; i < readLength; i++) {
                switch (currentState) {
                    // Before enter a new field
                    case STATE_OUT:
                        switch (line.charAt(i)) {
                            // Enter a quoted field
                            case CODE_QUOTE:
                                readOffset = i + 1;
                                currentState = STATE_QUOTE;
                                break;
                            // Read an empty field
                            case this.delimiter:
                                readOffset = i + 1;
                                // If delimiter is white space and configured to collapse
                                // multiple white spaces, ignore this white space.
                                if (this.delimiter === ' ' && this.delimWhitespace) {
                                    break;
                                }
                                result.push('');
                                currentState = STATE_OUT;
                                break;
                            // Enter an unquoted field
                            default:
                                currentState = STATE_FIELD;
                                readOffset = i;
                                break;
                        }
                        break;
                    // In an unquoted field
                    case STATE_FIELD:
                        switch (line.charAt(i)) {
                            // Exit an unquoted field, add it to result
                            case this.delimiter:
                                result.push(line.substring(readOffset, i));
                                currentState = STATE_OUT;
                                readOffset = i + 1;
                                break;
                            default:
                        }
                        break;
                    // In a quoted field
                    case STATE_QUOTE:
                        switch (line.charAt(i)) {
                            // Read a quote after a quote
                            case CODE_QUOTE:
                                currentState = STATE_QUOTE_AFTER_QUOTE;
                                break;
                            default:
                        }
                        break;
                    // This state means it's right after a second quote in a field
                    case STATE_QUOTE_AFTER_QUOTE:
                        switch (line.charAt(i)) {
                            // Finished a quoted field; exclude the closing quote by
                            // cutting the substring one character short.
                            case this.delimiter:
                                result.push(line.substring(readOffset, i - 1));
                                currentState = STATE_OUT;
                                readOffset = i + 1;
                                break;
                            // Finished a quoted part in a quoted field
                            case CODE_QUOTE:
                                currentState = STATE_QUOTE;
                                break;
                            // In a quoted part in a quoted field
                            default:
                                currentState = STATE_WITHIN_QUOTE_IN_QUOTE;
                                break;
                        }
                        break;
                    case STATE_WITHIN_QUOTE_IN_QUOTE:
                        switch (line.charAt(i)) {
                            // Exit a quoted part in a quoted field
                            case CODE_QUOTE:
                                currentState = STATE_QUOTE;
                                break;
                            default:
                        }
                        break;
                    default:
                }
            }
            // Adds last item based on if it is quoted.
            if (currentState === STATE_QUOTE_AFTER_QUOTE) {
                result.push(line.substring(readOffset, readLength - 1));
            }
            else {
                result.push(line.substring(readOffset));
            }
            // Check if each row has the same number of elements as column names.
            if (validateElementCount && result.length !== this.fullColumnNames.length) {
                throw new Error(`Invalid row in csv file. Should have ${this.fullColumnNames.length} elements in a row, but got ${result}`);
            }
            return result;
        }
    }
70785 // TODO(soergel): add more basic datasets for parity with tf.data
70786 // tf.data.FixedLengthRecordDataset()
70787 // tf.data.TFRecordDataset()
70788
70789 /**
70790 * @license
70791 * Copyright 2019 Google LLC. All Rights Reserved.
70792 * Licensed under the Apache License, Version 2.0 (the "License");
70793 * you may not use this file except in compliance with the License.
70794 * You may obtain a copy of the License at
70795 *
70796 * http://www.apache.org/licenses/LICENSE-2.0
70797 *
70798 * Unless required by applicable law or agreed to in writing, software
70799 * distributed under the License is distributed on an "AS IS" BASIS,
70800 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70801 * See the License for the specific language governing permissions and
70802 * limitations under the License.
70803 *
70804 * =============================================================================
70805 */
70806 /**
70807 * Provide a stream of tensors from microphone audio stream. The tensors are
70808 * representing audio data as frequency-domain spectrogram generated with
70809 * browser's native FFT. Tensors representing time-domain waveform is available
70810 * based on configuration. Only works in browser environment.
70811 */
70812 class MicrophoneIterator extends LazyIterator {
70813 constructor(microphoneConfig) {
70814 super();
70815 this.microphoneConfig = microphoneConfig;
70816 this.isClosed = false;
70817 this.fftSize = microphoneConfig.fftSize || 1024;
70818 const fftSizeLog2 = Math.log2(this.fftSize);
70819 if (this.fftSize < 0 || fftSizeLog2 < 4 || fftSizeLog2 > 14 ||
70820 !Number.isInteger(fftSizeLog2)) {
70821 throw new Error(`Invalid fftSize: it must be a power of 2 between ` +
70822 `2 to 4 and 2 to 14, but got ${this.fftSize}`);
70823 }
70824 this.numFrames = microphoneConfig.numFramesPerSpectrogram || 43;
70825 this.sampleRateHz = microphoneConfig.sampleRateHz;
70826 this.columnTruncateLength =
70827 microphoneConfig.columnTruncateLength || this.fftSize;
70828 this.audioTrackConstraints = microphoneConfig.audioTrackConstraints;
70829 this.smoothingTimeConstant = microphoneConfig.smoothingTimeConstant || 0;
70830 this.includeSpectrogram =
70831 microphoneConfig.includeSpectrogram === false ? false : true;
70832 this.includeWaveform =
70833 microphoneConfig.includeWaveform === true ? true : false;
70834 if (!this.includeSpectrogram && !this.includeWaveform) {
70835 throw new Error('Both includeSpectrogram and includeWaveform are false. ' +
70836 'At least one type of data should be returned.');
70837 }
70838 }
70839 summary() {
70840 return `microphone`;
70841 }
70842 // Construct a MicrophoneIterator and start the audio stream.
70843 static async create(microphoneConfig = {}) {
70844 if (!env().get('IS_BROWSER')) {
70845 throw new Error('microphone API is only supported in browser environment.');
70846 }
70847 const microphoneIterator = new MicrophoneIterator(microphoneConfig);
70848 // Call async function start() to initialize the audio stream.
70849 await microphoneIterator.start();
70850 return microphoneIterator;
70851 }
70852 // Start the audio stream and FFT.
70853 async start() {
70854 try {
70855 this.stream = await navigator.mediaDevices.getUserMedia({
70856 audio: this.audioTrackConstraints == null ? true :
70857 this.audioTrackConstraints,
70858 video: false
70859 });
70860 }
70861 catch (e) {
70862 throw new Error(`Error thrown while initializing video stream: ${e.message}`);
70863 }
70864 if (!this.stream) {
70865 throw new Error('Could not obtain audio from microphone.');
70866 }
70867 const ctxConstructor =
70868 // tslint:disable-next-line:no-any
70869 window.AudioContext || window.webkitAudioContext;
70870 this.audioContext = new ctxConstructor();
70871 if (!this.sampleRateHz) {
70872 // If sample rate is not provided, use the available sample rate on
70873 // device.
70874 this.sampleRateHz = this.audioContext.sampleRate;
70875 }
70876 else if (this.audioContext.sampleRate !== this.sampleRateHz) {
70877 throw new Error(`Mismatch in sampling rate: ` +
70878 `Expected: ${this.sampleRateHz}; ` +
70879 `Actual: ${this.audioContext.sampleRate}`);
70880 }
70881 const streamSource = this.audioContext.createMediaStreamSource(this.stream);
70882 this.analyser = this.audioContext.createAnalyser();
70883 this.analyser.fftSize = this.fftSize * 2;
70884 this.analyser.smoothingTimeConstant = this.smoothingTimeConstant;
70885 streamSource.connect(this.analyser);
70886 this.freqData = new Float32Array(this.fftSize);
70887 this.timeData = new Float32Array(this.fftSize);
70888 return;
70889 }
70890 async next() {
70891 if (this.isClosed) {
70892 return { value: null, done: true };
70893 }
70894 let spectrogramTensor;
70895 let waveformTensor;
70896 const audioDataQueue = await this.getAudioData();
70897 if (this.includeSpectrogram) {
70898 const freqData = this.flattenQueue(audioDataQueue.freqDataQueue);
70899 spectrogramTensor = this.getTensorFromAudioDataArray(freqData, [this.numFrames, this.columnTruncateLength, 1]);
70900 }
70901 if (this.includeWaveform) {
70902 const timeData = this.flattenQueue(audioDataQueue.timeDataQueue);
70903 waveformTensor = this.getTensorFromAudioDataArray(timeData, [this.numFrames * this.fftSize, 1]);
70904 }
70905 return {
70906 value: { 'spectrogram': spectrogramTensor, 'waveform': waveformTensor },
70907 done: false
70908 };
70909 }
70910 // Capture one result from the audio stream, and extract the value from
70911 // iterator.next() result.
70912 async capture() {
70913 return (await this.next()).value;
70914 }
70915 async getAudioData() {
70916 const freqDataQueue = [];
70917 const timeDataQueue = [];
70918 let currentFrames = 0;
70919 return new Promise(resolve => {
70920 const intervalID = setInterval(() => {
70921 if (this.includeSpectrogram) {
70922 this.analyser.getFloatFrequencyData(this.freqData);
70923 // If the audio stream is initializing, return empty queue.
70924 if (this.freqData[0] === -Infinity) {
70925 resolve({ freqDataQueue, timeDataQueue });
70926 }
70927 freqDataQueue.push(this.freqData.slice(0, this.columnTruncateLength));
70928 }
70929 if (this.includeWaveform) {
70930 this.analyser.getFloatTimeDomainData(this.timeData);
70931 timeDataQueue.push(this.timeData.slice());
70932 }
70933 // Clean interval and return when all frames have been collected
70934 if (++currentFrames === this.numFrames) {
70935 clearInterval(intervalID);
70936 resolve({ freqDataQueue, timeDataQueue });
70937 }
70938 }, this.fftSize / this.sampleRateHz * 1e3);
70939 });
70940 }
70941 // Stop the audio stream and pause the iterator.
70942 stop() {
70943 if (!this.isClosed) {
70944 this.isClosed = true;
70945 this.analyser.disconnect();
70946 this.audioContext.close();
70947 if (this.stream != null && this.stream.getTracks().length > 0) {
70948 this.stream.getTracks()[0].stop();
70949 }
70950 }
70951 }
70952 // Override toArray() function to prevent collecting.
70953 toArray() {
70954 throw new Error('Can not convert infinite audio stream to array.');
70955 }
70956 // Return audio sampling rate in Hz
70957 getSampleRate() {
70958 return this.sampleRateHz;
70959 }
70960 flattenQueue(queue) {
70961 const frameSize = queue[0].length;
70962 const freqData = new Float32Array(queue.length * frameSize);
70963 queue.forEach((data, i) => freqData.set(data, i * frameSize));
70964 return freqData;
70965 }
70966 getTensorFromAudioDataArray(freqData, shape) {
70967 const vals = new Float32Array(sizeFromShape(shape));
70968 // If the data is less than the output shape, the rest is padded with zeros.
70969 vals.set(freqData, vals.length - freqData.length);
70970 return tensor(vals, shape);
70971 }
70972 }
70973
70974 /**
70975 * @license
70976 * Copyright 2018 Google LLC. All Rights Reserved.
70977 * Licensed under the Apache License, Version 2.0 (the "License");
70978 * you may not use this file except in compliance with the License.
70979 * You may obtain a copy of the License at
70980 *
70981 * http://www.apache.org/licenses/LICENSE-2.0
70982 *
70983 * Unless required by applicable law or agreed to in writing, software
70984 * distributed under the License is distributed on an "AS IS" BASIS,
70985 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70986 * See the License for the specific language governing permissions and
70987 * limitations under the License.
70988 *
70989 * =============================================================================
70990 */
70991 /**
70992 * Provide a stream of image tensors from webcam video stream. Only works in
70993 * browser environment.
70994 */
    class WebcamIterator extends LazyIterator {
        /**
         * @param webcamVideoElement The HTMLVideoElement that plays the webcam
         *     stream and is read from on each next() call.
         * @param webcamConfig Configuration; recognized fields include
         *     resizeWidth, resizeHeight, centerCrop, facingMode and deviceId.
         */
        constructor(webcamVideoElement, webcamConfig) {
            super();
            this.webcamVideoElement = webcamVideoElement;
            this.webcamConfig = webcamConfig;
            // Remains closed until start() succeeds.
            this.isClosed = true;
            this.resize = false;
            if (this.needToResize()) {
                this.resize = true;
                this.cropSize =
                    [this.webcamConfig.resizeHeight, this.webcamConfig.resizeWidth];
                this.cropBoxInd = tensor1d([0], 'int32');
                if (this.webcamConfig.centerCrop) {
                    // Calculate the box based on resizing shape: a normalized
                    // [y1, x1, y2, x2] box centered in the frame.
                    const widthCroppingRatio = this.webcamConfig.resizeWidth * 1.0 / this.webcamVideoElement.width;
                    const heightCroppingRatio = this.webcamConfig.resizeHeight * 1.0 /
                        this.webcamVideoElement.height;
                    const widthCropStart = (1 - widthCroppingRatio) / 2;
                    const heightCropStart = (1 - heightCroppingRatio) / 2;
                    const widthCropEnd = widthCropStart + widthCroppingRatio;
                    const heightCropEnd = heightCroppingRatio + heightCropStart;
                    this.cropBox = tensor2d([heightCropStart, widthCropStart, heightCropEnd, widthCropEnd], [1, 4]);
                }
                else {
                    // No center crop: use the full frame.
                    this.cropBox = tensor2d([0, 0, 1, 1], [1, 4]);
                }
            }
        }
        summary() {
            return `webcam`;
        }
        // Construct a WebcamIterator and start its video stream.
        static async create(webcamVideoElement, webcamConfig = {}) {
            if (!env().get('IS_BROWSER')) {
                throw new Error('tf.data.webcam is only supported in browser environment.');
            }
            if (!webcamVideoElement) {
                // If webcam video element is not provided, create a hidden video element
                // with provided width and height.
                webcamVideoElement = document.createElement('video');
                if (!webcamConfig.resizeWidth || !webcamConfig.resizeHeight) {
                    throw new Error('Please provide webcam video element, or resizeWidth and ' +
                        'resizeHeight to create a hidden video element.');
                }
                webcamVideoElement.width = webcamConfig.resizeWidth;
                webcamVideoElement.height = webcamConfig.resizeHeight;
            }
            const webcamIterator = new WebcamIterator(webcamVideoElement, webcamConfig);
            // Call async function to initialize the video stream.
            await webcamIterator.start();
            return webcamIterator;
        }
        // Async function to start video stream.
        async start() {
            if (this.webcamConfig.facingMode) {
                assert$1((this.webcamConfig.facingMode === 'user') ||
                    (this.webcamConfig.facingMode === 'environment'), () => `Invalid webcam facing mode: ${this.webcamConfig.facingMode}. ` +
                    `Please provide 'user' or 'environment'`);
            }
            try {
                this.stream = await navigator.mediaDevices.getUserMedia({
                    video: {
                        deviceId: this.webcamConfig.deviceId,
                        facingMode: this.webcamConfig.facingMode ?
                            this.webcamConfig.facingMode :
                            'user',
                        width: this.webcamVideoElement.width,
                        height: this.webcamVideoElement.height
                    }
                });
            }
            catch (e) {
                // Modify the error message but leave the stack trace intact
                e.message = `Error thrown while initializing video stream: ${e.message}`;
                throw e;
            }
            if (!this.stream) {
                throw new Error('Could not obtain video from webcam.');
            }
            // Older browsers may not have srcObject
            try {
                this.webcamVideoElement.srcObject = this.stream;
            }
            catch (error) {
                console.log(error);
                this.webcamVideoElement.src = window.URL.createObjectURL(this.stream);
            }
            // Start the webcam video stream
            this.webcamVideoElement.play();
            this.isClosed = false;
            return new Promise(resolve => {
                // Add event listener to make sure the webcam has been fully initialized.
                this.webcamVideoElement.onloadedmetadata = () => {
                    resolve();
                };
            });
        }
        /**
         * Produces the next video frame as a tensor (cropped/resized if the
         * config requests it), or {value: null, done: true} once stopped.
         */
        async next() {
            if (this.isClosed) {
                return { value: null, done: true };
            }
            let img;
            try {
                img = fromPixels$1(this.webcamVideoElement);
            }
            catch (e) {
                throw new Error(`Error thrown converting video to pixels: ${JSON.stringify(e)}`);
            }
            if (this.resize) {
                try {
                    return { value: this.cropAndResizeFrame(img), done: false };
                }
                catch (e) {
                    throw new Error(`Error thrown cropping the video: ${e.message}`);
                }
                finally {
                    // The raw frame tensor is always disposed; cropAndResizeFrame
                    // returned a new tensor.
                    img.dispose();
                }
            }
            else {
                return { value: img, done: false };
            }
        }
        needToResize() {
            // If resizeWidth and resizeHeight are provided, and different from the
            // width and height of original HTMLVideoElement, then resizing and cropping
            // is required.
            if (this.webcamConfig.resizeWidth && this.webcamConfig.resizeHeight &&
                (this.webcamVideoElement.width !== this.webcamConfig.resizeWidth ||
                    this.webcamVideoElement.height !== this.webcamConfig.resizeHeight)) {
                return true;
            }
            return false;
        }
        // Cropping and resizing each frame based on config
        cropAndResizeFrame(img) {
            return tidy(() => {
                // cropAndResize expects a batch, so expand to rank 4 first.
                const expandedImage = expandDims$3(cast$3(img, 'float32'), (0));
                let resizedImage;
                resizedImage = image$1.cropAndResize(expandedImage, this.cropBox, this.cropBoxInd, this.cropSize, 'bilinear');
                // Extract image from batch cropping.
                const shape = resizedImage.shape;
                return reshape$3(resizedImage, shape.slice(1));
            });
        }
        // Capture one frame from the video stream, and extract the value from
        // iterator.next() result.
        async capture() {
            return (await this.next()).value;
        }
        // Stop the video stream and pause webcam iterator.
        stop() {
            const tracks = this.stream.getTracks();
            tracks.forEach(track => track.stop());
            try {
                this.webcamVideoElement.srcObject = null;
            }
            catch (error) {
                console.log(error);
                this.webcamVideoElement.src = null;
            }
            this.isClosed = true;
        }
        // Override toArray() function to prevent collecting.
        toArray() {
            throw new Error('Can not convert infinite video stream to array.');
        }
    }
71163
71164 /**
71165 * @license
71166 * Copyright 2018 Google LLC. All Rights Reserved.
71167 * Licensed under the Apache License, Version 2.0 (the "License");
71168 * you may not use this file except in compliance with the License.
71169 * You may obtain a copy of the License at
71170 *
71171 * http://www.apache.org/licenses/LICENSE-2.0
71172 *
71173 * Unless required by applicable law or agreed to in writing, software
71174 * distributed under the License is distributed on an "AS IS" BASIS,
71175 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71176 * See the License for the specific language governing permissions and
71177 * limitations under the License.
71178 *
71179 * =============================================================================
71180 */
71181 /**
71182 * Represents a data source readable as a stream of binary data chunks.
71183 *
71184 * Because `Dataset`s can be read repeatedly (via `Dataset.iterator()`), this
71185 * provides a means to repeatedly create streams from the underlying data
71186 * sources.
71187 */
    class DataSource {
        // Abstract marker base class; concrete sources presumably supply an
        // iterator() that creates a fresh byte-chunk stream on each call —
        // NOTE(review): confirm against the subclasses, which are not visible here.
    }
71190 // TODO(soergel): consider convenience factory functions here
71191 // in combination with chainable source->dataset above, e.g.:
71192 // tf.data.url(...).asCsvDataset().shuffle().batch()
71193
71194 /**
71195 * @license
71196 * Copyright 2018 Google LLC. All Rights Reserved.
71197 * Licensed under the Apache License, Version 2.0 (the "License");
71198 * you may not use this file except in compliance with the License.
71199 * You may obtain a copy of the License at
71200 *
71201 * http://www.apache.org/licenses/LICENSE-2.0
71202 *
71203 * Unless required by applicable law or agreed to in writing, software
71204 * distributed under the License is distributed on an "AS IS" BASIS,
71205 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71206 * See the License for the specific language governing permissions and
71207 * limitations under the License.
71208 *
71209 * =============================================================================
71210 */
71211 class StringIterator extends LazyIterator {
71212 /**
71213 * Splits a string stream on a given separator.
71214 *
71215 * It is assumed that the incoming chunk boundaries have no semantic meaning,
71216 * so conceptually the incoming stream is treated simply as the concatenation
71217 * of its elements.
71218 *
71219 * The outgoing stream provides chunks corresponding to the results of the
71220 * standard string split() operation (even if such a chunk spanned incoming
71221 * chunks). The separators are not included.
71222 *
71223 * A typical usage is to split a text file (represented as a stream with
71224 * arbitrary chunk boundaries) into lines.
71225 *
71226 * @param upstream A readable stream of strings that can be treated as
71227 * concatenated.
71228 * @param separator A character to split on.
71229 */
71230 split(separator) {
71231 return new SplitIterator(this, separator);
71232 }
71233 }
71234 // ============================================================================
71235 // The following private classes serve to implement the chainable methods
71236 // on StringIterator. Unfortunately they can't be placed in separate files, due
71237 // to resulting trouble with circular imports.
71238 // ============================================================================
71239 // We wanted multiple inheritance, e.g.
71240 // class SplitIterator extends QueueIterator<string>, StringIterator
71241 // but the TypeScript mixin approach is a bit hacky, so we take this adapter
71242 // approach instead.
71243 class SplitIterator extends StringIterator {
71244 constructor(upstream, separator) {
71245 super();
71246 this.upstream = upstream;
71247 this.impl = new SplitIteratorImpl(upstream, separator);
71248 }
71249 summary() {
71250 return this.impl.summary();
71251 }
71252 async next() {
71253 return this.impl.next();
71254 }
71255 }
71256 class SplitIteratorImpl extends OneToManyIterator {
71257 constructor(upstream, separator) {
71258 super();
71259 this.upstream = upstream;
71260 this.separator = separator;
71261 // A partial string at the end of an upstream chunk
71262 this.carryover = '';
71263 }
71264 summary() {
71265 return `${this.upstream.summary()} -> Split('${this.separator}')`;
71266 }
71267 async pump() {
71268 const chunkResult = await this.upstream.next();
71269 if (chunkResult.done) {
71270 if (this.carryover === '') {
71271 return false;
71272 }
71273 // Pretend that the pump succeeded in order to emit the small last batch.
71274 // The next pump() call will actually fail.
71275 this.outputQueue.push(this.carryover);
71276 this.carryover = '';
71277 return true;
71278 }
71279 const lines = chunkResult.value.split(this.separator);
71280 // Note the behavior: " ab ".split(' ') === ['', 'ab', '']
71281 // Thus the carryover may be '' if the separator falls on a chunk
71282 // boundary; this produces the correct result.
71283 lines[0] = this.carryover + lines[0];
71284 for (const line of lines.slice(0, -1)) {
71285 this.outputQueue.push(line);
71286 }
71287 this.carryover = lines[lines.length - 1];
71288 return true;
71289 }
71290 }
71291
71292 /**
71293 * @license
71294 * Copyright 2018 Google LLC. All Rights Reserved.
71295 * Licensed under the Apache License, Version 2.0 (the "License");
71296 * you may not use this file except in compliance with the License.
71297 * You may obtain a copy of the License at
71298 *
71299 * http://www.apache.org/licenses/LICENSE-2.0
71300 *
71301 * Unless required by applicable law or agreed to in writing, software
71302 * distributed under the License is distributed on an "AS IS" BASIS,
71303 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71304 * See the License for the specific language governing permissions and
71305 * limitations under the License.
71306 *
71307 * =============================================================================
71308 */
class ByteChunkIterator extends LazyIterator {
    /**
     * Decode a stream of UTF8-encoded byte arrays to a stream of strings.
     *
     * The byte arrays produced by the ByteChunkIterator on which this is
     * called are interpreted as one concatenated byte sequence. No assumptions
     * are made about the boundaries of the incoming chunks, so a multi-byte
     * UTF8 encoding of a character may span the boundary between chunks. This
     * naturally happens, for instance, when reading fixed-size byte arrays
     * from a file.
     */
    decodeUTF8() {
        const utf8Stream = new Utf8Iterator(this);
        return utf8Stream;
    }
}
71323 // ============================================================================
71324 // The following private classes serve to implement the chainable methods
71325 // on ByteChunkIterator. Unfortunately they can't be placed in separate files,
71326 // due to resulting trouble with circular imports.
71327 // ============================================================================
71328 // We wanted multiple inheritance, e.g.
71329 // class Utf8Iterator extends QueueIterator<string>, StringIterator
71330 // but the TypeScript mixin approach is a bit hacky, so we take this adapter
71331 // approach instead.
// Adapter exposing Utf8IteratorImpl through the StringIterator interface,
// standing in for the multiple inheritance TypeScript does not support.
class Utf8Iterator extends StringIterator {
    constructor(upstream) {
        super();
        this.upstream = upstream;
        // All iteration work is delegated to the impl.
        this.impl = new Utf8IteratorImpl(upstream);
    }
    summary() {
        const delegate = this.impl;
        return delegate.summary();
    }
    async next() {
        const delegate = this.impl;
        return delegate.next();
    }
}
71345 /**
71346 * Decode a stream of UTF8-encoded byte arrays to a stream of strings.
71347 *
71348 * This is tricky because the incoming byte array boundaries may disrupt a
71349 * multi-byte UTF8 character. Thus any incomplete character data at the end of
71350 * a chunk must be carried over and prepended to the next chunk before
71351 * decoding. Luckily with native decoder, TextDecoder in browser and
71352 * string_decoder in node, byte array boundaries are handled automatically.
71353 *
71354 * In the context of an input pipeline for machine learning, UTF8 decoding is
71355 * needed to parse text files containing training examples or prediction
71356 * requests (e.g., formatted as CSV or JSON). We cannot use the built-in
71357 * decoding provided by FileReader.readAsText() because here we are in a
71358 * streaming context, which FileReader does not support.
71359 *
71360 * @param upstream A `LazyIterator` of `Uint8Arrays` containing UTF8-encoded
71361 * text, which should be interpreted as concatenated. No assumptions are
71362 * made about the boundaries of the incoming chunks, so a multi-byte UTF8
71363 * encoding of a character may span the boundary between chunks. This
71364 * naturally happens, for instance, when reading fixed-size byte arrays from a
71365 * file.
71366 */
// Decodes an upstream stream of UTF8 byte chunks into strings. The native
// decoders (TextDecoder in browsers, string_decoder in Node) carry multi-byte
// characters split across chunk boundaries over to the next call.
class Utf8IteratorImpl extends OneToManyIterator {
    constructor(upstream) {
        super();
        this.upstream = upstream;
        const inBrowser = env().get('IS_BROWSER');
        if (inBrowser) {
            this.decoder = new TextDecoder('utf-8');
        }
        else {
            // tslint:disable-next-line:no-require-imports
            const { StringDecoder } = require('string_decoder');
            this.decoder = new StringDecoder('utf8');
        }
    }
    summary() {
        return `${this.upstream.summary()} -> Utf8`;
    }
    async pump() {
        const chunkResult = await this.upstream.next();
        if (chunkResult.done) {
            return false;
        }
        const chunk = chunkResult.value;
        // {stream: true} tells TextDecoder more input follows, so it buffers
        // any incomplete trailing multi-byte sequence instead of mangling it.
        const text = env().get('IS_BROWSER') ?
            this.decoder.decode(chunk, { stream: true }) :
            this.decoder.write(Buffer.from(chunk.buffer));
        this.outputQueue.push(text);
        return true;
    }
}
71403
71404 /**
71405 * @license
71406 * Copyright 2018 Google LLC. All Rights Reserved.
71407 * Licensed under the Apache License, Version 2.0 (the "License");
71408 * you may not use this file except in compliance with the License.
71409 * You may obtain a copy of the License at
71410 *
71411 * http://www.apache.org/licenses/LICENSE-2.0
71412 *
71413 * Unless required by applicable law or agreed to in writing, software
71414 * distributed under the License is distributed on an "AS IS" BASIS,
71415 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71416 * See the License for the specific language governing permissions and
71417 * limitations under the License.
71418 *
71419 * =============================================================================
71420 */
71421 /**
71422 * Provide a stream of chunks from a File, Blob, or Uint8Array.
71423 * @param file The source File, Blob or Uint8Array.
71424 * @param options Optional settings controlling file reading.
71425 * @returns a lazy Iterator of Uint8Arrays containing sequential chunks of the
71426 * input File, Blob or Uint8Array.
71427 */
// Yields sequential fixed-size Uint8Array chunks from a File, Blob, or
// Uint8Array source.
class FileChunkIterator extends ByteChunkIterator {
    /**
     * @param file The source File, Blob, or Uint8Array to read from.
     * @param options Optional {offset, chunkSize} controlling where reading
     *     starts and how many bytes each chunk contains.
     */
    constructor(file, options = {}) {
        super();
        this.file = file;
        this.options = options;
        // Uint8Array works everywhere; File/Blob are only valid in browsers.
        assert$1((file instanceof Uint8Array) ||
            (env().get('IS_BROWSER') ?
                (file instanceof File || file instanceof Blob) :
                false), () => 'FileChunkIterator only supports File, Blob and Uint8Array ' +
            'right now.');
        // Byte position where the next chunk read will begin.
        this.offset = options.offset || 0;
        // default 1MB chunk has tolerable perf on large files
        this.chunkSize = options.chunkSize || 1024 * 1024;
    }
    summary() {
        return `FileChunks ${this.file}`;
    }
    /**
     * Returns the next chunk as {value: Uint8Array, done: false}, or
     * {value: null, done: true} once the offset has passed the end of input.
     */
    async next() {
        if (this.offset >= ((this.file instanceof Uint8Array) ?
            this.file.byteLength :
            this.file.size)) {
            return { value: null, done: true };
        }
        const chunk = new Promise((resolve, reject) => {
            const end = this.offset + this.chunkSize;
            if (this.file instanceof Uint8Array) {
                // Note if end > this.uint8Array.byteLength, we just get a small last
                // chunk.
                resolve(new Uint8Array(this.file.slice(this.offset, end)));
            }
            else {
                // This branch assumes that this.file type is File or Blob, which
                // means it is in the browser environment.
                // TODO(soergel): is this a performance issue?
                const fileReader = new FileReader();
                fileReader.onload = (event) => {
                    let data = fileReader.result;
                    // Not sure we can trust the return type of
                    // FileReader.readAsArrayBuffer See e.g.
                    // https://github.com/node-file-api/FileReader/issues/2
                    if (data instanceof ArrayBuffer) {
                        data = new Uint8Array(data);
                    }
                    if (!(data instanceof Uint8Array)) {
                        return reject(new TypeError('FileReader returned unknown type.'));
                    }
                    resolve(data);
                };
                fileReader.onabort = (event) => {
                    return reject(new Error('Aborted'));
                };
                fileReader.onerror = (event) => {
                    return reject(new Error(event.type));
                };
                // TODO(soergel): better handle onabort, onerror
                // Note if end > this.file.size, we just get a small last chunk.
                const slice = this.file.slice(this.offset, end);
                // We can't use readAsText here (even if we know the file is text)
                // because the slice boundary may fall within a multi-byte character.
                fileReader.readAsArrayBuffer(slice);
            }
            // The Promise executor runs synchronously, so the offset advances
            // before the next call to next(), even though the FileReader load
            // completes asynchronously.
            this.offset = end;
        });
        return { value: (await chunk), done: false };
    }
}
71494
71495 /**
71496 * @license
71497 * Copyright 2018 Google LLC. All Rights Reserved.
71498 * Licensed under the Apache License, Version 2.0 (the "License");
71499 * you may not use this file except in compliance with the License.
71500 * You may obtain a copy of the License at
71501 *
71502 * http://www.apache.org/licenses/LICENSE-2.0
71503 *
71504 * Unless required by applicable law or agreed to in writing, software
71505 * distributed under the License is distributed on an "AS IS" BASIS,
71506 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71507 * See the License for the specific language governing permissions and
71508 * limitations under the License.
71509 *
71510 * =============================================================================
71511 */
71512 /**
71513 * Provide a stream of chunks from a URL.
71514 *
71515 * Note this class first downloads the entire file into memory before providing
71516 * the first element from the stream. This is because the Fetch API does not
71517 * yet reliably provide a reader stream for the response body.
71518 */
// Downloads the entire resource at `url` into memory, then wraps the bytes in
// a FileChunkIterator. `url` may be a string or a Request-like object.
async function urlChunkIterator(url, options = {}, fetchFunc) {
    const isStringUrl = (typeof url) === 'string';
    const urlString = isStringUrl ? url : url.url;
    const requestInit = isStringUrl ? undefined : getRequestInitFromRequest(url);
    const doFetch = fetchFunc || fetch$1;
    const response = await doFetch(urlString, requestInit);
    if (!response.ok) {
        throw new Error(response.statusText);
    }
    const uint8Array = new Uint8Array(await response.arrayBuffer());
    return new FileChunkIterator(uint8Array, options);
}
71538 // Generate RequestInit from Request to match tf.util.fetch signature.
// Generate RequestInit from Request to match tf.util.fetch signature.
// Copies only the nine fields fetch() consumes from a Request object.
const getRequestInitFromRequest = (request) => {
    const { method, headers, body, mode, credentials, cache, redirect, referrer, integrity } = request;
    return {
        method,
        headers,
        body,
        mode,
        credentials,
        cache,
        redirect,
        referrer,
        integrity,
    };
};
71553
71554 /**
71555 * @license
71556 * Copyright 2018 Google LLC. All Rights Reserved.
71557 * Licensed under the Apache License, Version 2.0 (the "License");
71558 * you may not use this file except in compliance with the License.
71559 * You may obtain a copy of the License at
71560 *
71561 * http://www.apache.org/licenses/LICENSE-2.0
71562 *
71563 * Unless required by applicable law or agreed to in writing, software
71564 * distributed under the License is distributed on an "AS IS" BASIS,
71565 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71566 * See the License for the specific language governing permissions and
71567 * limitations under the License.
71568 *
71569 * =============================================================================
71570 */
71571 // Skip tslint any type check cause this method is aiming to check type of
71572 // input.
71573 // tslint:disable-next-line:no-any
// Skip tslint any type check cause this method is aiming to check type of
// input.
// tslint:disable-next-line:no-any
/**
 * Returns true iff `source` is a string beginning with the `file://` scheme
 * prefix (i.e. a local file path usable in the Node environment).
 * Non-string inputs always return false.
 */
function isLocalPath(source) {
    // startsWith is the idiomatic prefix test (vs. slice-and-compare).
    return (typeof source === 'string') && source.startsWith('file://');
}
71577
71578 /**
71579 * @license
71580 * Copyright 2018 Google LLC. All Rights Reserved.
71581 * Licensed under the Apache License, Version 2.0 (the "License");
71582 * you may not use this file except in compliance with the License.
71583 * You may obtain a copy of the License at
71584 *
71585 * http://www.apache.org/licenses/LICENSE-2.0
71586 *
71587 * Unless required by applicable law or agreed to in writing, software
71588 * distributed under the License is distributed on an "AS IS" BASIS,
71589 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71590 * See the License for the specific language governing permissions and
71591 * limitations under the License.
71592 *
71593 * =============================================================================
71594 */
71595 /**
71596 * Represents a file, blob, or Uint8Array readable as a stream of binary data
71597 * chunks.
71598 */
class FileDataSource extends DataSource {
    /**
     * Create a `FileDataSource`.
     *
     * @param input Local file path, or `File`/`Blob`/`Uint8Array` object to
     *   read. Local file only works in node environment.
     * @param options Options passed to the underlying `FileChunkIterator`s,
     *   such as {chunksize: 1024}.
     */
    constructor(input, options = {}) {
        super();
        this.input = input;
        this.options = options;
    }
    async iterator() {
        const isNodeLocalFile = isLocalPath(this.input) && env().get('IS_NODE');
        if (isNodeLocalFile) {
            // tslint:disable-next-line:no-require-imports
            const fs = require('fs');
            // Strip the 'file://' prefix and load the file's bytes into memory.
            this.input = fs.readFileSync(this.input.slice(7));
        }
        // TODO(kangyizhang): Add LocalFileChunkIterator to split local streaming
        // with file in browser.
        return new FileChunkIterator(this.input, this.options);
    }
}
71624
71625 /**
71626 * @license
71627 * Copyright 2018 Google LLC. All Rights Reserved.
71628 * Licensed under the Apache License, Version 2.0 (the "License");
71629 * you may not use this file except in compliance with the License.
71630 * You may obtain a copy of the License at
71631 *
71632 * http://www.apache.org/licenses/LICENSE-2.0
71633 *
71634 * Unless required by applicable law or agreed to in writing, software
71635 * distributed under the License is distributed on an "AS IS" BASIS,
71636 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71637 * See the License for the specific language governing permissions and
71638 * limitations under the License.
71639 *
71640 * =============================================================================
71641 */
71642 /*
71643 * Represents a URL readable as a stream of binary data chunks.
71644 */
class URLDataSource extends DataSource {
    /**
     * Create a `URLDataSource`.
     *
     * @param url A source URL string, or a `Request` object.
     * @param fileOptions Options passed to the underlying `FileChunkIterator`s,
     *   such as {chunksize: 1024}.
     */
    constructor(url, fileOptions = {}) {
        super();
        this.url = url;
        this.fileOptions = fileOptions;
    }
    // TODO(soergel): provide appropriate caching options. Currently this
    // will download the URL anew for each call to iterator(). Since we have
    // to treat the downloaded file as a blob/buffer anyway, we may as well retain
    // it-- but that raises GC issues. Also we may want a persistent disk cache.
    async iterator() {
        // file:// URLs are served through FileDataSource (Node only); anything
        // else is fetched over the network.
        if (!isLocalPath(this.url)) {
            return urlChunkIterator(this.url, this.fileOptions);
        }
        const localSource = new FileDataSource(this.url, this.fileOptions);
        return localSource.iterator();
    }
}
71672
71673 /**
71674 * @license
71675 * Copyright 2018 Google LLC. All Rights Reserved.
71676 * Licensed under the Apache License, Version 2.0 (the "License");
71677 * you may not use this file except in compliance with the License.
71678 * You may obtain a copy of the License at
71679 *
71680 * http://www.apache.org/licenses/LICENSE-2.0
71681 *
71682 * Unless required by applicable law or agreed to in writing, software
71683 * distributed under the License is distributed on an "AS IS" BASIS,
71684 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71685 * See the License for the specific language governing permissions and
71686 * limitations under the License.
71687 *
71688 * =============================================================================
71689 */
71690 /**
71691 * Create a `CSVDataset` by reading and decoding CSV file(s) from provided URL
71692 * or local path if it's in Node environment.
71693 *
71694 * Note: If isLabel in columnConfigs is `true` for at least one column, the
71695 * element in returned `CSVDataset` will be an object of
71696 * `{xs:features, ys:labels}`: xs is a dict of features key/value pairs, ys
71697 * is a dict of labels key/value pairs. If no column is marked as label,
71698 * returns a dict of features only.
71699 *
71700 * ```js
71701 * const csvUrl =
71702 * 'https://storage.googleapis.com/tfjs-examples/multivariate-linear-regression/data/boston-housing-train.csv';
71703 *
71704 * async function run() {
71705 * // We want to predict the column "medv", which represents a median value of
71706 * // a home (in $1000s), so we mark it as a label.
71707 * const csvDataset = tf.data.csv(
71708 * csvUrl, {
71709 * columnConfigs: {
71710 * medv: {
71711 * isLabel: true
71712 * }
71713 * }
71714 * });
71715 *
71716 * // Number of features is the number of column names minus one for the label
71717 * // column.
71718 * const numOfFeatures = (await csvDataset.columnNames()).length - 1;
71719 *
71720 * // Prepare the Dataset for training.
71721 * const flattenedDataset =
71722 * csvDataset
71723 * .map(({xs, ys}) =>
71724 * {
71725 * // Convert xs(features) and ys(labels) from object form (keyed by
71726 * // column name) to array form.
71727 * return {xs:Object.values(xs), ys:Object.values(ys)};
71728 * })
71729 * .batch(10);
71730 *
71731 * // Define the model.
71732 * const model = tf.sequential();
71733 * model.add(tf.layers.dense({
71734 * inputShape: [numOfFeatures],
71735 * units: 1
71736 * }));
71737 * model.compile({
71738 * optimizer: tf.train.sgd(0.000001),
71739 * loss: 'meanSquaredError'
71740 * });
71741 *
71742 * // Fit the model using the prepared Dataset
71743 * return model.fitDataset(flattenedDataset, {
71744 * epochs: 10,
71745 * callbacks: {
71746 * onEpochEnd: async (epoch, logs) => {
71747 * console.log(epoch + ':' + logs.loss);
71748 * }
71749 * }
71750 * });
71751 * }
71752 *
71753 * await run();
71754 * ```
71755 *
71756 * @param source URL or local path to get CSV file. If it's a local path, it
71757 * must have prefix `file://` and it only works in node environment.
71758 * @param csvConfig (Optional) A CSVConfig object that contains configurations
71759 * of reading and decoding from CSV file(s).
71760 *
71761 * @doc {
71762 * heading: 'Data',
71763 * subheading: 'Creation',
71764 * namespace: 'data',
71765 * configParamIndices: [1]
71766 * }
71767 */
function csv(source, csvConfig = {}) {
    // Wrap the URL/path in a data source, then decode it as CSV.
    const dataSource = new URLDataSource(source);
    return new CSVDataset(dataSource, csvConfig);
}
71771 /**
71772 * Create a `Dataset` that produces each element by calling a provided function.
71773 *
71774 * Note that repeated iterations over this `Dataset` may produce different
71775 * results, because the function will be called anew for each element of each
71776 * iteration.
71777 *
71778 * Also, beware that the sequence of calls to this function may be out of order
71779 * in time with respect to the logical order of the Dataset. This is due to the
71780 * asynchronous lazy nature of stream processing, and depends on downstream
71781 * transformations (e.g. .shuffle()). If the provided function is pure, this is
71782 * no problem, but if it is a closure over a mutable state (e.g., a traversal
71783 * pointer), then the order of the produced elements may be scrambled.
71784 *
71785 * ```js
71786 * let i = -1;
71787 * const func = () =>
71788 * ++i < 5 ? {value: i, done: false} : {value: null, done: true};
71789 * const ds = tf.data.func(func);
71790 * await ds.forEachAsync(e => console.log(e));
71791 * ```
71792 *
71793 * @param f A function that produces one data element on each call.
71794 */
function func(f) {
    // A single iterator is shared across all iterations of the dataset.
    const sharedIterator = iteratorFromFunction(f);
    return datasetFromIteratorFn(async () => sharedIterator);
}
71799 /**
71800 * Create a `Dataset` that produces each element from provided JavaScript
71801 * generator, which is a function that returns a (potentially async) iterator.
71802 *
71803 * For more information on iterators and generators, see
71804 * https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Iterators_and_Generators .
71805 * For the iterator protocol, see
71806 * https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Iteration_protocols .
71807 *
71808 * Example of creating a dataset from an iterator factory:
71809 * ```js
71810 * function makeIterator() {
71811 * const numElements = 10;
71812 * let index = 0;
71813 *
71814 * const iterator = {
71815 * next: () => {
71816 * let result;
71817 * if (index < numElements) {
71818 * result = {value: index, done: false};
71819 * index++;
71820 * return result;
71821 * }
71822 * return {value: index, done: true};
71823 * }
71824 * };
71825 * return iterator;
71826 * }
71827 * const ds = tf.data.generator(makeIterator);
71828 * await ds.forEachAsync(e => console.log(e));
71829 * ```
71830 *
71831 * Example of creating a dataset from a generator:
71832 * ```js
71833 * function* dataGenerator() {
71834 * const numElements = 10;
71835 * let index = 0;
71836 * while (index < numElements) {
71837 * const x = index;
71838 * index++;
71839 * yield x;
71840 * }
71841 * }
71842 *
71843 * const ds = tf.data.generator(dataGenerator);
71844 * await ds.forEachAsync(e => console.log(e));
71845 * ```
71846 *
71847 * @param generator A JavaScript function that returns
71848 * a (potentially async) JavaScript iterator.
71849 *
71850 * @doc {
71851 * heading: 'Data',
71852 * subheading: 'Creation',
71853 * namespace: 'data',
71854 * configParamIndices: [1]
71855 * }
71856 */
function generator(generator) {
    // Each dataset iteration invokes the (possibly async) factory anew, then
    // adapts the resulting iterator's next() to the LazyIterator protocol.
    const makeIterator = async () => {
        const gen = await generator();
        return iteratorFromFunction(() => gen.next());
    };
    return datasetFromIteratorFn(makeIterator);
}
71863 /**
71864 * Create an iterator that generates `Tensor`s from webcam video stream. This
71865 * API only works in Browser environment when the device has webcam.
71866 *
71867 * Note: this code snippet only works when the device has a webcam. It will
71868 * request permission to open the webcam when running.
71869 * ```js
71870 * const videoElement = document.createElement('video');
71871 * videoElement.width = 100;
71872 * videoElement.height = 100;
71873 * const cam = await tf.data.webcam(videoElement);
71874 * const img = await cam.capture();
71875 * img.print();
71876 * cam.stop();
71877 * ```
71878 *
71879 * @param webcamVideoElement A `HTMLVideoElement` used to play video from
71880 * webcam. If this element is not provided, a hidden `HTMLVideoElement` will
71881 * be created. In that case, `resizeWidth` and `resizeHeight` must be
71882 * provided to set the generated tensor shape.
71883 * @param webcamConfig A `WebcamConfig` object that contains configurations of
71884 * reading and manipulating data from webcam video stream.
71885 *
71886 * @doc {
71887 * heading: 'Data',
71888 * subheading: 'Creation',
71889 * namespace: 'data',
71890 * ignoreCI: true
71891 * }
71892 */
async function webcam(webcamVideoElement, webcamConfig) {
    // Delegate construction to the async factory on WebcamIterator.
    const iterator = WebcamIterator.create(webcamVideoElement, webcamConfig);
    return iterator;
}
71896 /**
71897 * Create an iterator that generates frequency-domain spectrogram `Tensor`s from
71898 * microphone audio stream with browser's native FFT. This API only works in
71899 * browser environment when the device has microphone.
71900 *
71901 * Note: this code snippet only works when the device has a microphone. It will
71902 * request permission to open the microphone when running.
71903 * ```js
71904 * const mic = await tf.data.microphone({
71905 * fftSize: 1024,
71906 * columnTruncateLength: 232,
71907 * numFramesPerSpectrogram: 43,
71908 * sampleRateHz:44100,
71909 * includeSpectrogram: true,
71910 * includeWaveform: true
71911 * });
71912 * const audioData = await mic.capture();
71913 * const spectrogramTensor = audioData.spectrogram;
71914 * spectrogramTensor.print();
71915 * const waveformTensor = audioData.waveform;
71916 * waveformTensor.print();
71917 * mic.stop();
71918 * ```
71919 *
71920 * @param microphoneConfig A `MicrophoneConfig` object that contains
71921 * configurations of reading audio data from microphone.
71922 *
71923 * @doc {
71924 * heading: 'Data',
71925 * subheading: 'Creation',
71926 * namespace: 'data',
71927 * ignoreCI: true
71928 * }
71929 */
async function microphone(microphoneConfig) {
    // Delegate construction to the async factory on MicrophoneIterator.
    const iterator = MicrophoneIterator.create(microphoneConfig);
    return iterator;
}
71933
/** @license See the LICENSE file. */
// This code is auto-generated, do not modify this file!
// Version string of the bundled tfjs-data package (exported as version_data).
const version$4 = '4.22.0';
71937
71938 /**
71939 * @license
71940 * Copyright 2018 Google LLC. All Rights Reserved.
71941 * Licensed under the Apache License, Version 2.0 (the "License");
71942 * you may not use this file except in compliance with the License.
71943 * You may obtain a copy of the License at
71944 *
71945 * http://www.apache.org/licenses/LICENSE-2.0
71946 *
71947 * Unless required by applicable law or agreed to in writing, software
71948 * distributed under the License is distributed on an "AS IS" BASIS,
71949 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71950 * See the License for the specific language governing permissions and
71951 * limitations under the License.
71952 * =============================================================================
71953 */
71954
// Frozen namespace object exported as tf.data: dataset classes, data
// sources, and the creation helpers defined above.
var index = /*#__PURE__*/Object.freeze({
    __proto__: null,
    CSVDataset: CSVDataset,
    Dataset: Dataset,
    FileDataSource: FileDataSource,
    TextLineDataset: TextLineDataset,
    URLDataSource: URLDataSource,
    array: array,
    csv: csv,
    func: func,
    generator: generator,
    microphone: microphone,
    version_data: version$4,
    webcam: webcam,
    zip: zip
});
71971
71972 /**
71973 * @license
71974 * Copyright 2019 Google LLC. All Rights Reserved.
71975 * Licensed under the Apache License, Version 2.0 (the "License");
71976 * you may not use this file except in compliance with the License.
71977 * You may obtain a copy of the License at
71978 *
71979 * http://www.apache.org/licenses/LICENSE-2.0
71980 *
71981 * Unless required by applicable law or agreed to in writing, software
71982 * distributed under the License is distributed on an "AS IS" BASIS,
71983 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71984 * See the License for the specific language governing permissions and
71985 * limitations under the License.
71986 * =============================================================================
71987 */
// Asserts that none of the given tensors (a single TensorInfo or an array of
// them; null entries are ignored) has dtype complex64, which the CPU backend
// does not support for the named op.
function assertNotComplex$1(tensor, opName) {
    const tensors = Array.isArray(tensor) ? tensor : [tensor];
    for (const t of tensors) {
        if (t != null) {
            assert$1(t.dtype !== 'complex64', () => `${opName} does not support complex64 tensors in the CPU backend.`);
        }
    }
}
71998
71999 /**
72000 * @license
72001 * Copyright 2021 Google LLC. All Rights Reserved.
72002 * Licensed under the Apache License, Version 2.0 (the "License");
72003 * you may not use this file except in compliance with the License.
72004 * You may obtain a copy of the License at
72005 *
72006 * http://www.apache.org/licenses/LICENSE-2.0
72007 *
72008 * Unless required by applicable law or agreed to in writing, software
72009 * distributed under the License is distributed on an "AS IS" BASIS,
72010 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72011 * See the License for the specific language governing permissions and
72012 * limitations under the License.
72013 * =============================================================================
72014 */
72015 const whereImpl$1 = whereImpl$2;
72016 class MathBackendCPU extends KernelBackend {
72017 nextDataId() {
72018 return MathBackendCPU.nextDataId++;
72019 }
72020 constructor() {
72021 super();
72022 this.blockSize = 48;
72023 this.firstUse = true;
72024 this.data = new DataStorage(this, engine());
72025 }
72026 write(values, shape, dtype) {
72027 if (this.firstUse) {
72028 this.firstUse = false;
72029 if (env().get('IS_NODE')) {
72030 warn('\n============================\n' +
72031 'Hi, looks like you are running TensorFlow.js in ' +
72032 'Node.js. To speed things up dramatically, install our node ' +
72033 'backend, visit https://github.com/tensorflow/tfjs-node for more details. ' +
72034 '\n============================');
72035 }
72036 }
72037 const dataId = { id: this.nextDataId() };
72038 this.data.set(dataId, { values, dtype, refCount: 1 });
72039 return dataId;
72040 }
72041 /**
72042 * Create a data bucket in cpu backend.
72043 * @param shape Shape of the `TensorInfo`.
72044 * @param dtype DType of the `TensorInfo`.
72045 * @param values The value of the `TensorInfo` stored as a flattened array.
72046 */
72047 makeTensorInfo(shape, dtype, values) {
72048 let outId;
72049 if (dtype === 'string' && values != null && values.length > 0 &&
72050 isString(values[0])) {
72051 const encodedValues = values.map(d => encodeString(d));
72052 outId = this.write(encodedValues, shape, dtype);
72053 }
72054 else {
72055 outId = this.write(values, shape, dtype);
72056 }
72057 return { dataId: outId, shape, dtype };
72058 }
72059 /** Return refCount of a `TensorData`. */
72060 refCount(dataId) {
72061 if (this.data.has(dataId)) {
72062 const tensorData = this.data.get(dataId);
72063 return tensorData.refCount;
72064 }
72065 return 0;
72066 }
72067 /** Increase refCount of a `TensorData`. */
72068 incRef(dataId) {
72069 const tensorData = this.data.get(dataId);
72070 tensorData.refCount++;
72071 }
72072 /** Decrease refCount of a `TensorData`. */
72073 decRef(dataId) {
72074 if (this.data.has(dataId)) {
72075 const tensorData = this.data.get(dataId);
72076 tensorData.refCount--;
72077 }
72078 }
72079 move(dataId, values, shape, dtype, refCount) {
72080 this.data.set(dataId, { values, dtype, refCount });
72081 }
72082 numDataIds() {
72083 return this.data.numDataIds();
72084 }
72085 async read(dataId) {
72086 return this.readSync(dataId);
72087 }
72088 readSync(dataId) {
72089 const { dtype, complexTensorInfos } = this.data.get(dataId);
72090 if (dtype === 'complex64') {
72091 const realValues = this.readSync(complexTensorInfos.real.dataId);
72092 const imagValues = this.readSync(complexTensorInfos.imag.dataId);
72093 return mergeRealAndImagArrays(realValues, imagValues);
72094 }
72095 return convertBackendValuesAndArrayBuffer(this.data.get(dataId).values, dtype);
72096 }
72097 bufferSync(t) {
72098 const data = this.readSync(t.dataId);
72099 if (t.dtype === 'string') {
72100 try {
72101 // Decode the bytes into string.
72102 const strings = data.map(d => decodeString(d));
72103 return buffer(t.shape, t.dtype, strings);
72104 }
72105 catch (_a) {
72106 throw new Error('Failed to decode encoded string bytes into utf-8');
72107 }
72108 }
72109 return buffer(t.shape, t.dtype, data);
72110 }
72111 makeOutput(values, shape, dtype) {
72112 return engine().makeTensorFromTensorInfo(this.makeTensorInfo(shape, dtype, values), this);
72113 }
72114 /**
72115 * Dispose the memory if the dataId has 0 refCount. Return true if the memory
72116 * is released or memory is not managed in this backend, false if memory is
72117 * not cleared.
72118 * @param dataId
72119 * @oaram force Optional, remove the data regardless of refCount
72120 */
72121 disposeData(dataId, force = false) {
72122 if (this.data.has(dataId)) {
72123 this.data.get(dataId).refCount--;
72124 if (!force && this.data.get(dataId).refCount > 0) {
72125 return false;
72126 }
72127 const { complexTensorInfos } = this.data.get(dataId);
72128 if (complexTensorInfos != null) {
72129 this.disposeData(complexTensorInfos.real.dataId, true);
72130 this.disposeData(complexTensorInfos.imag.dataId, true);
72131 }
72132 this.data.delete(dataId);
72133 }
72134 return true;
72135 }
        // Releases (one reference to) a TensorInfo created inside a kernel.
        disposeIntermediateTensorInfo(tensorInfo) {
            this.disposeData(tensorInfo.dataId);
        }
72139 async time(f) {
72140 const start = now();
72141 f();
72142 const kernelMs = now() - start;
72143 return { kernelMs };
72144 }
72145 memory() {
72146 return {
72147 // Unreliable due to automatic gc. The numbers above are cumulative.
72148 unreliable: true,
72149 reasons: ['The reported memory is an upper bound. Due to automatic garbage ' +
72150 'collection, the true allocated memory may be less.']
72151 };
72152 }
        // Returns the coordinates of non-zero elements of `condition`
        // (the data-dependent part of tf.where); complex input unsupported.
        where(condition) {
            assertNotComplex$1([condition], 'where');
            const condVals = this.readSync(condition.dataId);
            return whereImpl$1(condition.shape, condVals);
        }
72158 dispose() { }
        // CPU backend computes in 32-bit floats.
        floatPrecision() {
            return 32;
        }
        /** Returns the smallest representable number. */
        epsilon() {
            // Delegates to the base class default (float32 epsilon here, since
            // floatPrecision() is 32 — base implementation not visible in this file).
            return super.epsilon();
        }
72166 }
72167 MathBackendCPU.nextDataId = 0;
72168
72169 /**
72170 * @license
72171 * Copyright 2020 Google LLC. All Rights Reserved.
72172 * Licensed under the Apache License, Version 2.0 (the License);
72173 * you may not use this file except in compliance with the License.
72174 * You may obtain a copy of the License at
72175 *
72176 * http://www.apache.org/licenses/LICENSE-2.0
72177 *
72178 * Unless required by applicable law or agreed to in writing, software
72179 * distributed under the License is distributed on an AS IS BASIS,
72180 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72181 * See the License for the specific language governing permissions and
72182 * limitations under the License.
72183 * =============================================================================
72184 */
72185 function simpleAbsImpl(vals) {
72186 const resultValues = new Float32Array(vals.length);
72187 for (let i = 0; i < vals.length; ++i) {
72188 resultValues[i] = Math.abs(vals[i]);
72189 }
72190 return resultValues;
72191 }
72192 const abs$1 = (args) => {
72193 const { x } = args.inputs;
72194 const cpuBackend = args.backend;
72195 assertNotComplex$1(x, 'abs');
72196 let resultValues = new Float32Array(sizeFromShape(x.shape));
72197 const values = cpuBackend.data.get(x.dataId).values;
72198 resultValues = simpleAbsImpl(values);
72199 return cpuBackend.makeOutput(resultValues, x.shape, x.dtype);
72200 };
    // Kernel registration record binding the Abs kernel name to abs$1.
    const absConfig$1 = {
        kernelName: Abs,
        backendName: 'cpu',
        kernelFunc: abs$1,
    };
72206
72207 /**
72208 * @license
72209 * Copyright 2020 Google LLC. All Rights Reserved.
72210 * Licensed under the Apache License, Version 2.0 (the "License");
72211 * you may not use this file except in compliance with the License.
72212 * You may obtain a copy of the License at
72213 *
72214 * http://www.apache.org/licenses/LICENSE-2.0
72215 *
72216 * Unless required by applicable law or agreed to in writing, software
72217 * distributed under the License is distributed on an "AS IS" BASIS,
72218 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72219 * See the License for the specific language governing permissions and
72220 * limitations under the License.
72221 * =============================================================================
72222 */
72223 /**
72224 * Template that creates implementation for binary ops. Supports broadcast.
72225 */
    /**
     * Template that creates implementation for binary ops. Supports broadcast.
     * Returns a function mapping (aShape, bShape, aVals, bVals, dtype) to
     * [resultVals, resultShape].
     */
    function createSimpleBinaryKernelImpl(op) {
        return (aShape, bShape, aVals, bVals, dtype) => {
            const newShape = assertAndGetBroadcastShape(aShape, bShape);
            const resultRank = newShape.length;
            const resultStrides = computeStrides(newShape);
            const resultSize = sizeFromShape(newShape);
            const result = getTypedArrayFromDType(dtype, resultSize);
            const aRank = aShape.length;
            const bRank = bShape.length;
            const aStrides = computeStrides(aShape);
            const bStrides = computeStrides(bShape);
            const aBroadcastDims = getBroadcastDims$1(aShape, newShape);
            const bBroadcastDims = getBroadcastDims$1(bShape, newShape);
            if (aBroadcastDims.length + bBroadcastDims.length === 0) {
                // Fast path: no per-dimension broadcasting, so flat indices line
                // up (modulo wrap-around for a smaller operand).
                for (let i = 0; i < result.length; ++i) {
                    result[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]);
                }
            }
            else {
                // Slow path: map each output position to a multi-dim location,
                // zero out broadcast dimensions, and re-linearize per operand.
                for (let i = 0; i < result.length; ++i) {
                    const loc = indexToLoc(i, resultRank, resultStrides);
                    const aLoc = loc.slice(-aRank);
                    aBroadcastDims.forEach(d => aLoc[d] = 0);
                    const aIndex = locToIndex(aLoc, aRank, aStrides);
                    const bLoc = loc.slice(-bRank);
                    bBroadcastDims.forEach(d => bLoc[d] = 0);
                    const bIndex = locToIndex(bLoc, bRank, bStrides);
                    result[i] = op(aVals[aIndex], bVals[bIndex]);
                }
            }
            return [result, newShape];
        };
    }
72259
72260 /**
72261 * @license
72262 * Copyright 2020 Google LLC. All Rights Reserved.
72263 * Licensed under the Apache License, Version 2.0 (the "License");
72264 * you may not use this file except in compliance with the License.
72265 * You may obtain a copy of the License at
72266 *
72267 * http://www.apache.org/licenses/LICENSE-2.0
72268 *
72269 * Unless required by applicable law or agreed to in writing, software
72270 * distributed under the License is distributed on an "AS IS" BASIS,
72271 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72272 * See the License for the specific language governing permissions and
72273 * limitations under the License.
72274 * =============================================================================
72275 */
    // Complex kernel: packs separate real and imag float32 tensors into a
    // single complex64 TensorInfo that owns copies of both parts.
    function complex$1(args) {
        const { inputs, backend } = args;
        const { real, imag } = inputs;
        const realVals = backend.data.get(real.dataId).values;
        const imagVals = backend.data.get(imag.dataId).values;
        const complexInfo = backend.makeTensorInfo(real.shape, 'complex64');
        const complex = backend.data.get(complexInfo.dataId);
        // The complex tensor owns the underlying real and imag tensorInfos, only the
        // complex tensor tracks refCount, when complexData is disposed the
        // underlying tensorData will be disposed.
        complex.complexTensorInfos = {
            real: backend.makeTensorInfo(real.shape, 'float32', realVals),
            imag: backend.makeTensorInfo(imag.shape, 'float32', imagVals)
        };
        return complexInfo;
    }
    // Kernel registration record for the Complex kernel.
    const complexConfig$1 = {
        kernelName: Complex,
        backendName: 'cpu',
        kernelFunc: complex$1
    };
72297
72298 /**
72299 * @license
72300 * Copyright 2020 Google LLC. All Rights Reserved.
72301 * Licensed under the Apache License, Version 2.0 (the "License");
72302 * you may not use this file except in compliance with the License.
72303 * You may obtain a copy of the License at
72304 *
72305 * http://www.apache.org/licenses/LICENSE-2.0
72306 *
72307 * Unless required by applicable law or agreed to in writing, software
72308 * distributed under the License is distributed on an "AS IS" BASIS,
72309 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72310 * See the License for the specific language governing permissions and
72311 * limitations under the License.
72312 * =============================================================================
72313 */
72314 /**
72315 * Generates a tensorInfo with all zeros value.
72316 * @param backend cpu backend.
72317 * @param shape Shape for the zeros tensor.
72318 * @param dtype Optional. If set, the result has this dtype.
72319 */
72320 function zeros(backend, shape, dtype = 'float32') {
72321 if (dtype === 'complex64') {
72322 const real = zeros(backend, shape, 'float32');
72323 const imag = zeros(backend, shape, 'float32');
72324 return complex$1({ inputs: { real, imag }, backend });
72325 }
72326 const values = makeZerosTypedArray(sizeFromShape(shape), dtype);
72327 return backend.makeTensorInfo(shape, dtype, values);
72328 }
72329
72330 /**
72331 * @license
72332 * Copyright 2020 Google LLC. All Rights Reserved.
72333 * Licensed under the Apache License, Version 2.0 (the "License");
72334 * you may not use this file except in compliance with the License.
72335 * You may obtain a copy of the License at
72336 *
72337 * http://www.apache.org/licenses/LICENSE-2.0
72338 *
72339 * Unless required by applicable law or agreed to in writing, software
72340 * distributed under the License is distributed on an "AS IS" BASIS,
72341 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72342 * See the License for the specific language governing permissions and
72343 * limitations under the License.
72344 * =============================================================================
72345 */
72346 function identity$1(args) {
72347 const { inputs, backend } = args;
72348 const { x } = inputs;
72349 backend.incRef(x.dataId);
72350 return { dataId: x.dataId, shape: x.shape, dtype: x.dtype };
72351 }
    // Kernel registration record for the Identity kernel.
    const identityConfig$1 = {
        kernelName: Identity$1,
        backendName: 'cpu',
        kernelFunc: identity$1
    };
72357
72358 /**
72359 * @license
72360 * Copyright 2020 Google LLC. All Rights Reserved.
72361 * Licensed under the Apache License, Version 2.0 (the "License");
72362 * you may not use this file except in compliance with the License.
72363 * You may obtain a copy of the License at
72364 *
72365 * http://www.apache.org/licenses/LICENSE-2.0
72366 *
72367 * Unless required by applicable law or agreed to in writing, software
72368 * distributed under the License is distributed on an "AS IS" BASIS,
72369 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72370 * See the License for the specific language governing permissions and
72371 * limitations under the License.
72372 * =============================================================================
72373 */
    // Real kernel: extracts the real component of a complex64 tensor as a
    // new, independently-owned float32 tensor.
    function real$1(args) {
        const { inputs, backend } = args;
        const { input } = inputs;
        const real = backend.data.get(input.dataId).complexTensorInfos.real;
        const realVal = backend.data.get(real.dataId).values;
        // When complex tensor is disposed, its underlying parts will be disposed too.
        // Make new tensor out of the real value of the complex. This makes sure the
        // value is still accessible even if complex tensor is disposed.
        return backend.makeTensorInfo(real.shape, real.dtype, realVal);
    }
    // Kernel registration record for the Real kernel.
    const realConfig$1 = {
        kernelName: Real,
        backendName: 'cpu',
        kernelFunc: real$1
    };
72389
72390 /**
72391 * @license
72392 * Copyright 2020 Google LLC. All Rights Reserved.
72393 * Licensed under the Apache License, Version 2.0 (the "License");
72394 * you may not use this file except in compliance with the License.
72395 * You may obtain a copy of the License at
72396 *
72397 * http://www.apache.org/licenses/LICENSE-2.0
72398 *
72399 * Unless required by applicable law or agreed to in writing, software
72400 * distributed under the License is distributed on an "AS IS" BASIS,
72401 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72402 * See the License for the specific language governing permissions and
72403 * limitations under the License.
72404 * =============================================================================
72405 */
    // Shared cast implementation for non-complex dtypes. Returns
    // [shape, dtype, values]; throws for unsupported target dtypes
    // (float32 upcasts are handled by the caller without copying).
    function castImpl(values, shape, inputType, dtype) {
        if (dtype === 'int32') {
            const resultValues = Int32Array.from(values);
            return [shape, 'int32', resultValues];
        }
        if (dtype === 'bool') {
            // This is essentially the result of notEqual(x, 0). We avoid using
            // kernel notEqual to avoid circular dependency, i.e. binary_utils ->
            // cast -> notEqual -> binary_utils.
            const zero = toTypedArray([0], inputType);
            const [resultData, resultShape] = createSimpleBinaryKernelImpl((a, b) => (a !== b) ? 1 : 0)(shape, [], values, zero, 'bool');
            return [resultShape, 'bool', resultData];
        }
        throw new Error(`Error in Cast: failed to cast ${inputType} to ${dtype}`);
    }
    // Cast kernel. Handles casts to/from complex64 by composing the
    // complex/real/identity kernels, uses identity (no copy) for lossless
    // upcasts, and falls back to castImpl for value-changing casts.
    function cast$1(args) {
        const { inputs, backend, attrs } = args;
        const { x } = inputs;
        const { dtype } = attrs;
        // Casting to complex64.
        if (dtype === 'complex64') {
            if (x.dtype === 'complex64') {
                return identity$1({ inputs: { x }, backend });
            }
            const zerosTensorInfo = zeros(backend, x.shape, x.dtype);
            const floatX = cast$1({ inputs: { x }, backend, attrs: { dtype: 'float32' } });
            const result = complex$1({ inputs: { real: floatX, imag: zerosTensorInfo }, backend });
            // complex$1 copied the data, so the intermediates can be released.
            backend.disposeIntermediateTensorInfo(zerosTensorInfo);
            backend.disposeIntermediateTensorInfo(floatX);
            return result;
        }
        // Casting from complex64
        if (x.dtype === 'complex64') {
            const realPart = real$1({ inputs: { input: x }, backend });
            const result = cast$1({ inputs: { x: realPart }, backend, attrs: { dtype } });
            backend.disposeIntermediateTensorInfo(realPart);
            return result;
        }
        if (!hasEncodingLoss(x.dtype, dtype)) {
            // We don't change the underlying data, since we cast to higher
            // precision.
            const result = identity$1({ inputs: { x }, backend });
            return { dataId: result.dataId, shape: result.shape, dtype };
        }
        const values = backend.data.get(x.dataId).values;
        const [resultShape, resultType, resultData] = castImpl(values, x.shape, x.dtype, dtype);
        return backend.makeTensorInfo(resultShape, resultType, resultData);
    }
    // Kernel registration record for the Cast kernel.
    const castConfig$1 = {
        kernelName: Cast,
        backendName: 'cpu',
        kernelFunc: cast$1
    };
72459
72460 /**
72461 * @license
72462 * Copyright 2020 Google LLC. All Rights Reserved.
72463 * Licensed under the Apache License, Version 2.0 (the "License");
72464 * you may not use this file except in compliance with the License.
72465 * You may obtain a copy of the License at
72466 *
72467 * http://www.apache.org/licenses/LICENSE-2.0
72468 *
72469 * Unless required by applicable law or agreed to in writing, software
72470 * distributed under the License is distributed on an "AS IS" BASIS,
72471 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72472 * See the License for the specific language governing permissions and
72473 * limitations under the License.
72474 * =============================================================================
72475 */
72476 /**
72477 * Template that creates a `KernelFunc` for binary ops.
72478 * @param name Kernel name.
72479 * @param binaryKernelImpl A `SimpleBinaryKernelImpl` for the kernel.
72480 * @param binaryKernelComplexImpl Optional. If exists, represents a
72481 * `ComplexBinaryKernelImpl` for the kernel, will be used when input dtype
72482 * is `complex64`.
72483 * @param dtype Optional. If set, the result has this dtype. Otherwise, the
72484 * result has the same dtype as the first input. This is mainly used in
72485 * comparison kernels, such as Equal, Less, Greater, etc.
72486 */
    // Builds a KernelFunc for a binary op from a simple (real-valued) impl and
    // an optional complex impl. See the JSDoc template comment above.
    function binaryKernelFunc$1(name, simpleImpl, complexImpl, dtype) {
        if (complexImpl == null) {
            // No complex support: reject complex inputs and run the simple impl,
            // decoding string tensors from their UTF-8 byte form first.
            return ({ inputs, backend }) => {
                const { a, b } = inputs;
                const cpuBackend = backend;
                assertNotComplex$1([a, b], name);
                const aVals = cpuBackend.data.get(a.dataId).values;
                const bVals = cpuBackend.data.get(b.dataId).values;
                const decodedAVals = a.dtype === 'string' ?
                    // tslint:disable-next-line: no-any
                    fromUint8ToStringArray(aVals) :
                    aVals;
                // NOTE(review): b is decoded based on a.dtype, not b.dtype —
                // presumably both operands are guaranteed the same dtype by the
                // op layer; confirm before relying on mixed-dtype inputs.
                const decodedBVals = a.dtype === 'string' ?
                    // tslint:disable-next-line: no-any
                    fromUint8ToStringArray(bVals) :
                    bVals;
                const $dtype = dtype || a.dtype;
                const [resultData, resultShape] = simpleImpl(a.shape, b.shape, decodedAVals, decodedBVals, $dtype);
                return cpuBackend.makeTensorInfo(resultShape, $dtype, resultData);
            };
        }
        return ({ inputs, backend }) => {
            const { a, b } = inputs;
            const cpuBackend = backend;
            if (a.dtype === 'complex64' || b.dtype === 'complex64') {
                // Upcast both operands to complex64, run the complex impl on the
                // separated real/imag parts, then repack into a complex tensor.
                const $aComplex = cast$1({ inputs: { x: a }, backend: cpuBackend, attrs: { dtype: 'complex64' } });
                const $aComplexVals = cpuBackend.data.get($aComplex.dataId);
                const aReal = $aComplexVals.complexTensorInfos.real;
                const aImag = $aComplexVals.complexTensorInfos.imag;
                const aRealVals = cpuBackend.data.get(aReal.dataId).values;
                const aImagVals = cpuBackend.data.get(aImag.dataId).values;
                const $bComplex = cast$1({ inputs: { x: b }, backend: cpuBackend, attrs: { dtype: 'complex64' } });
                const $bComplexVals = cpuBackend.data.get($bComplex.dataId);
                const bReal = $bComplexVals.complexTensorInfos.real;
                const bImag = $bComplexVals.complexTensorInfos.imag;
                const bRealVals = cpuBackend.data.get(bReal.dataId).values;
                const bImagVals = cpuBackend.data.get(bImag.dataId).values;
                const [resultRealData, resultImagData, resultShape] = complexImpl(a.shape, b.shape, aRealVals, aImagVals, bRealVals, bImagVals);
                const resultReal = cpuBackend.makeTensorInfo(resultShape, 'float32', resultRealData);
                const resultImag = cpuBackend.makeTensorInfo(resultShape, 'float32', resultImagData);
                const result = complex$1({ inputs: { real: resultReal, imag: resultImag }, backend: cpuBackend });
                // All four intermediates were copied into `result`; release them.
                cpuBackend.disposeIntermediateTensorInfo($aComplex);
                cpuBackend.disposeIntermediateTensorInfo($bComplex);
                cpuBackend.disposeIntermediateTensorInfo(resultReal);
                cpuBackend.disposeIntermediateTensorInfo(resultImag);
                return result;
            }
            else {
                const aVals = cpuBackend.data.get(a.dataId).values;
                const bVals = cpuBackend.data.get(b.dataId).values;
                const $dtype = dtype || a.dtype;
                const [resultData, resultShape] = simpleImpl(a.shape, b.shape, aVals, bVals, $dtype);
                return cpuBackend.makeTensorInfo(resultShape, $dtype, resultData);
            }
        };
    }
72543 /**
72544 * Template that creates the complex type implementation for binary ops.
72545 * Supports broadcast.
72546 */
    /**
     * Template that creates the complex type implementation for binary ops.
     * Supports broadcast. `op` receives (aReal, aImag, bReal, bImag) and
     * returns {real, imag}; the result is [realVals, imagVals, shape].
     */
    function createComplexBinaryKernelImpl(op) {
        return (aShape, bShape, aRealVals, aImagVals, bRealVals, bImagVals) => {
            const resultShape = assertAndGetBroadcastShape(aShape, bShape);
            const resultSize = sizeFromShape(resultShape);
            const resultRank = resultShape.length;
            const resultStrides = computeStrides(resultShape);
            const resultRealVals = getTypedArrayFromDType('float32', resultSize);
            const resultImagVals = getTypedArrayFromDType('float32', resultSize);
            const aBroadcastDims = getBroadcastDims$1(aShape, resultShape);
            const bBroadcastDims = getBroadcastDims$1(bShape, resultShape);
            // Interleave parts into [re, im, re, im, ...] so each logical element
            // occupies two slots (hence the `* 2` indexing below).
            const aVals = mergeRealAndImagArrays(aRealVals, aImagVals);
            const bVals = mergeRealAndImagArrays(bRealVals, bImagVals);
            const aRank = aShape.length;
            const aStrides = computeStrides(aShape);
            const bRank = bShape.length;
            const bStrides = computeStrides(bShape);
            if (aBroadcastDims.length + bBroadcastDims.length === 0) {
                // Fast path: no broadcasting, indices line up modulo wrap-around.
                for (let i = 0; i < resultRealVals.length; i++) {
                    const aIdx = i % aVals.length;
                    const bIdx = i % bVals.length;
                    const result = op(aVals[aIdx * 2], aVals[aIdx * 2 + 1], bVals[bIdx * 2], bVals[bIdx * 2 + 1]);
                    resultRealVals[i] = result.real;
                    resultImagVals[i] = result.imag;
                }
            }
            else {
                // Slow path: translate the output index into per-operand indices,
                // zeroing out broadcast dimensions.
                for (let i = 0; i < resultRealVals.length; i++) {
                    const loc = indexToLoc(i, resultRank, resultStrides);
                    const aLoc = loc.slice(-aRank);
                    aBroadcastDims.forEach(d => aLoc[d] = 0);
                    const aIndex = locToIndex(aLoc, aRank, aStrides);
                    const bLoc = loc.slice(-bRank);
                    bBroadcastDims.forEach(d => bLoc[d] = 0);
                    const bIndex = locToIndex(bLoc, bRank, bStrides);
                    const opResult = op(aVals[aIndex * 2], aVals[aIndex * 2 + 1], bVals[bIndex * 2], bVals[bIndex * 2 + 1]);
                    resultRealVals[i] = opResult.real;
                    resultImagVals[i] = opResult.imag;
                }
            }
            return [resultRealVals, resultImagVals, resultShape];
        };
    }
72589
72590 /**
72591 * @license
72592 * Copyright 2020 Google LLC. All Rights Reserved.
72593 * Licensed under the Apache License, Version 2.0 (the "License");
72594 * you may not use this file except in compliance with the License.
72595 * You may obtain a copy of the License at
72596 *
72597 * http://www.apache.org/licenses/LICENSE-2.0
72598 *
72599 * Unless required by applicable law or agreed to in writing, software
72600 * distributed under the License is distributed on an "AS IS" BASIS,
72601 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72602 * See the License for the specific language governing permissions and
72603 * limitations under the License.
72604 * =============================================================================
72605 */
    // Add kernel: element-wise sum with broadcasting; complex64 adds the real
    // and imaginary parts independently.
    const addImpl = createSimpleBinaryKernelImpl(((a, b) => a + b));
    const addComplexImpl = createComplexBinaryKernelImpl(((aReal, aImag, bReal, bImag) => {
        return { real: aReal + bReal, imag: aImag + bImag };
    }));
    const add = binaryKernelFunc$1(Add$1, addImpl, addComplexImpl);
    const addConfig$1 = {
        kernelName: Add$1,
        backendName: 'cpu',
        kernelFunc: add
    };
72616
72617 /**
72618 * @license
72619 * Copyright 2020 Google LLC. All Rights Reserved.
72620 * Licensed under the Apache License, Version 2.0 (the "License");
72621 * you may not use this file except in compliance with the License.
72622 * You may obtain a copy of the License at
72623 *
72624 * http://www.apache.org/licenses/LICENSE-2.0
72625 *
72626 * Unless required by applicable law or agreed to in writing, software
72627 * distributed under the License is distributed on an "AS IS" BASIS,
72628 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72629 * See the License for the specific language governing permissions and
72630 * limitations under the License.
72631 * =============================================================================
72632 */
    // Bincount over a flat input: out[v] accumulates the weight (or 1 when no
    // weights are supplied) for each occurrence of value v in xVals. Values
    // >= size are ignored; negative values are an error.
    function bincountImpl(xVals, weightsVals, weightsDtype, weightsShape, size) {
        const weightsSize = sizeFromShape(weightsShape);
        const outVals = makeZerosTypedArray(size, weightsDtype);
        for (let i = 0; i < xVals.length; i++) {
            const value = xVals[i];
            if (value < 0) {
                throw new Error('Input x must be non-negative!');
            }
            // Out-of-range bins are silently dropped, matching bincount semantics.
            if (value >= size) {
                continue;
            }
            if (weightsSize > 0) {
                outVals[value] += weightsVals[i];
            }
            else {
                outVals[value] += 1;
            }
        }
        return outVals;
    }
    // Row-wise bincount for a 2-D input: returns a [numRows, size] buffer
    // where each row is the bincount of the corresponding input row. With
    // binaryOutput, bins record presence (1) instead of counts/weights.
    function bincountReduceImpl(xBuf, weightsBuf, size, binaryOutput = false) {
        const numRows = xBuf.shape[0];
        const numCols = xBuf.shape[1];
        const outBuf = buffer([numRows, size], weightsBuf.dtype);
        for (let i = 0; i < numRows; i++) {
            for (let j = 0; j < numCols; j++) {
                const value = xBuf.get(i, j);
                if (value < 0) {
                    throw new Error('Input x must be non-negative!');
                }
                // Out-of-range bins are silently dropped.
                if (value >= size) {
                    continue;
                }
                if (binaryOutput) {
                    outBuf.set(1, i, value);
                }
                else {
                    if (weightsBuf.size > 0) {
                        outBuf.set(outBuf.get(i, value) + weightsBuf.get(i, j), i, value);
                    }
                    else {
                        outBuf.set(outBuf.get(i, value) + 1, i, value);
                    }
                }
            }
        }
        return outBuf;
    }
72681
72682 /**
72683 * @license
72684 * Copyright 2023 Google LLC.
72685 * Licensed under the Apache License, Version 2.0 (the "License");
72686 * you may not use this file except in compliance with the License.
72687 * You may obtain a copy of the License at
72688 *
72689 * http://www.apache.org/licenses/LICENSE-2.0
72690 *
72691 * Unless required by applicable law or agreed to in writing, software
72692 * distributed under the License is distributed on an "AS IS" BASIS,
72693 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72694 * See the License for the specific language governing permissions and
72695 * limitations under the License.
72696 * =============================================================================
72697 */
    // BitwiseAnd kernel. Note JS `&` coerces operands to 32-bit signed ints,
    // so inputs are effectively treated as int32 values.
    const bitwiseAndImpl = createSimpleBinaryKernelImpl(((a, b) => a & b));
    const bitwiseAnd$1 = binaryKernelFunc$1(BitwiseAnd, bitwiseAndImpl);
    const bitwiseAndConfig$1 = {
        kernelName: BitwiseAnd,
        backendName: 'cpu',
        kernelFunc: bitwiseAnd$1
    };
72705
72706 /**
72707 * @license
72708 * Copyright 2020 Google LLC. All Rights Reserved.
72709 * Licensed under the Apache License, Version 2.0 (the "License");
72710 * you may not use this file except in compliance with the License.
72711 * You may obtain a copy of the License at
72712 *
72713 * http://www.apache.org/licenses/LICENSE-2.0
72714 *
72715 * Unless required by applicable law or agreed to in writing, software
72716 * distributed under the License is distributed on an "AS IS" BASIS,
72717 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72718 * See the License for the specific language governing permissions and
72719 * limitations under the License.
72720 * =============================================================================
72721 */
72722 /**
72723 * Template that creates implementation for unary op.
72724 */
72725 function createSimpleUnaryImpl(op) {
72726 return (values, dtype, attrs) => {
72727 const newValues = getArrayFromDType(dtype, values.length);
72728 for (let i = 0; i < values.length; ++i) {
72729 newValues[i] = op(values[i], attrs);
72730 }
72731 return newValues;
72732 };
72733 }
72734
72735 /**
72736 * @license
72737 * Copyright 2020 Google LLC. All Rights Reserved.
72738 * Licensed under the Apache License, Version 2.0 (the "License");
72739 * you may not use this file except in compliance with the License.
72740 * You may obtain a copy of the License at
72741 *
72742 * http://www.apache.org/licenses/LICENSE-2.0
72743 *
72744 * Unless required by applicable law or agreed to in writing, software
72745 * distributed under the License is distributed on an "AS IS" BASIS,
72746 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72747 * See the License for the specific language governing permissions and
72748 * limitations under the License.
72749 * =============================================================================
72750 */
72751 /**
72752 * Template that creates a `KernelFunc` for unary ops.
72753 * @param name Kernel name.
72754 * @param op A `SimpleUnaryOperation` for the kernel.
72755 * @param dtype Optional. If set, the result has this dtype. Otherwise, the
72756 * result has the same dtype as the input. This is mainly used in certain
72757 * kernels that return bool type, such as isFinite, isInf, etc.
72758 */
    // Builds a KernelFunc for a unary op directly from a per-element operation
    // (see the JSDoc template comment above).
    function unaryKernelFunc$1(name, op, dtype) {
        const impl = createSimpleUnaryImpl(op);
        return unaryKernelFuncFromImpl(name, impl, dtype);
    }
72763 /**
72764 * Template that creates a `KernelFunc` for unary ops from the given
72765 * `SimpleUnaryImpl`..
72766 * @param name Kernel name.
72767 * @param unaryImpl A `SimpleUnaryImpl` that implements the op.
72768 * @param dtype Optional. If set, the result has this dtype. Otherwise, the
72769 * result has the same dtype as the input. This is mainly used in certain
72770 * kernels that return bool type, such as isFinite, isInf, etc.
72771 */
    // Builds a KernelFunc for a unary op from a SimpleUnaryImpl (see the JSDoc
    // template comment above). Rejects complex input; decodes string tensors.
    function unaryKernelFuncFromImpl(name, unaryImpl, dtype) {
        return ({ inputs, attrs, backend }) => {
            const { x } = inputs;
            assertNotComplex$1(x, name);
            const cpuBackend = backend;
            const values = cpuBackend.data.get(x.dataId).values;
            let decoded;
            if (x.dtype === 'string') {
                // String tensors are stored as arrays of UTF-8 byte buffers.
                if (!Array.isArray(values)) {
                    throw new Error('String tensor\'s value was not an instance of Array');
                }
                decoded = fromUint8ToStringArray(values);
            }
            else {
                decoded = values;
            }
            // Result dtype defaults to the input's unless overridden (e.g. bool
            // for predicates such as isFinite).
            const $dtype = dtype || x.dtype;
            const newValues = unaryImpl(decoded, $dtype, attrs);
            return cpuBackend.makeTensorInfo(x.shape, $dtype, newValues);
        };
    }
72793
72794 /**
72795 * @license
72796 * Copyright 2020 Google LLC. All Rights Reserved.
72797 * Licensed under the Apache License, Version 2.0 (the License);
72798 * you may not use this file except in compliance with the License.
72799 * You may obtain a copy of the License at
72800 *
72801 * http://www.apache.org/licenses/LICENSE-2.0
72802 *
72803 * Unless required by applicable law or agreed to in writing, software
72804 * distributed under the License is distributed on an AS IS BASIS,
72805 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72806 * See the License for the specific language governing permissions and
72807 * limitations under the License.
72808 * =============================================================================
72809 */
    // Ceil kernel: element-wise Math.ceil.
    const ceilImpl = createSimpleUnaryImpl((xi) => Math.ceil(xi));
    const ceil$1 = unaryKernelFuncFromImpl(Ceil, ceilImpl);
    const ceilConfig$1 = {
        kernelName: Ceil,
        backendName: 'cpu',
        kernelFunc: ceil$1,
    };
72817
72818 /**
72819 * @license
72820 * Copyright 2020 Google LLC. All Rights Reserved.
72821 * Licensed under the Apache License, Version 2.0 (the "License");
72822 * you may not use this file except in compliance with the License.
72823 * You may obtain a copy of the License at
72824 *
72825 * http://www.apache.org/licenses/LICENSE-2.0
72826 *
72827 * Unless required by applicable law or agreed to in writing, software
72828 * distributed under the License is distributed on an "AS IS" BASIS,
72829 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72830 * See the License for the specific language governing permissions and
72831 * limitations under the License.
72832 * =============================================================================
72833 */
    // Concatenates pre-reshaped 2-D inputs along the last axis into `outShape`.
    // `simplyConcat` indicates the inputs can be laid out back-to-back (e.g.
    // outShape[0] === 1), enabling the fast TypedArray.set path.
    function concatImpl$1(inputs, outShape, dtype, simplyConcat) {
        const outVals = getArrayFromDType(dtype, sizeFromShape(outShape));
        if (simplyConcat && dtype !== 'string') {
            // Use built-in TypedArray.set() method for speed.
            let offset = 0;
            inputs.forEach(input => {
                const size = sizeFromShape(input.shape);
                outVals.set(input.vals, offset);
                offset += size;
            });
        }
        else {
            // General path: copy each input row-by-row into its column slice of
            // the output, decoding string data first.
            let colOffset = 0;
            inputs.forEach(input => {
                const decodedData = dtype === 'string' ?
                    fromUint8ToStringArray(input.vals) :
                    input.vals;
                let tIdx = 0;
                for (let row = 0; row < input.shape[0]; ++row) {
                    const resIdx = row * outShape[1] + colOffset;
                    for (let col = 0; col < input.shape[1]; ++col) {
                        outVals[resIdx + col] = decodedData[tIdx++];
                    }
                }
                colOffset += input.shape[1];
            });
        }
        return outVals;
    }
72863
72864 /**
72865 * @license
72866 * Copyright 2020 Google LLC. All Rights Reserved.
72867 * Licensed under the Apache License, Version 2.0 (the "License");
72868 * you may not use this file except in compliance with the License.
72869 * You may obtain a copy of the License at
72870 *
72871 * http://www.apache.org/licenses/LICENSE-2.0
72872 *
72873 * Unless required by applicable law or agreed to in writing, software
72874 * distributed under the License is distributed on an "AS IS" BASIS,
72875 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72876 * See the License for the specific language governing permissions and
72877 * limitations under the License.
72878 * =============================================================================
72879 */
    // Equal kernel: element-wise strict comparison producing a bool tensor.
    const equalImpl = createSimpleBinaryKernelImpl((a, b) => (a === b) ? 1 : 0);
    const equal$1 = binaryKernelFunc$1(Equal, equalImpl, null /* complexImpl */, 'bool');
    const equalConfig$1 = {
        kernelName: Equal,
        backendName: 'cpu',
        kernelFunc: equal$1
    };
72887
72888 /**
72889 * @license
72890 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
72892 * you may not use this file except in compliance with the License.
72893 * You may obtain a copy of the License at
72894 *
72895 * http://www.apache.org/licenses/LICENSE-2.0
72896 *
72897 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
72899 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72900 * See the License for the specific language governing permissions and
72901 * limitations under the License.
72902 * =============================================================================
72903 */
// Element-wise natural exponential; the kernel always emits float32 output.
const expImpl = createSimpleUnaryImpl((value) => Math.exp(value));
const exp$1 = unaryKernelFuncFromImpl(Exp, expImpl, 'float32');
const expConfig$1 = {
    kernelName: Exp,
    backendName: 'cpu',
    kernelFunc: exp$1,
};
72911
72912 /**
72913 * @license
72914 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
72916 * you may not use this file except in compliance with the License.
72917 * You may obtain a copy of the License at
72918 *
72919 * http://www.apache.org/licenses/LICENSE-2.0
72920 *
72921 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
72923 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72924 * See the License for the specific language governing permissions and
72925 * limitations under the License.
72926 * =============================================================================
72927 */
// Element-wise exp(x) - 1, using Math.expm1 for precision near zero.
const expm1Impl = createSimpleUnaryImpl((value) => Math.expm1(value));
const expm1$1 = unaryKernelFuncFromImpl(Expm1, expm1Impl);
const expm1Config$1 = {
    kernelName: Expm1,
    backendName: 'cpu',
    kernelFunc: expm1$1,
};
72935
72936 /**
72937 * @license
72938 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
72940 * you may not use this file except in compliance with the License.
72941 * You may obtain a copy of the License at
72942 *
72943 * http://www.apache.org/licenses/LICENSE-2.0
72944 *
72945 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
72947 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72948 * See the License for the specific language governing permissions and
72949 * limitations under the License.
72950 * =============================================================================
72951 */
// Element-wise round toward negative infinity.
const floorImpl = createSimpleUnaryImpl((value) => Math.floor(value));
const floor$1 = unaryKernelFuncFromImpl(Floor, floorImpl);
const floorConfig$1 = {
    kernelName: Floor,
    backendName: 'cpu',
    kernelFunc: floor$1,
};
72959
72960 /**
72961 * @license
72962 * Copyright 2020 Google LLC. All Rights Reserved.
72963 * Licensed under the Apache License, Version 2.0 (the "License");
72964 * you may not use this file except in compliance with the License.
72965 * You may obtain a copy of the License at
72966 *
72967 * http://www.apache.org/licenses/LICENSE-2.0
72968 *
72969 * Unless required by applicable law or agreed to in writing, software
72970 * distributed under the License is distributed on an "AS IS" BASIS,
72971 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72972 * See the License for the specific language governing permissions and
72973 * limitations under the License.
72974 * =============================================================================
72975 */
// Element-wise division rounded toward negative infinity; emits 'int32'.
const floorDivImpl = createSimpleBinaryKernelImpl((lhs, rhs) => Math.floor(lhs / rhs));
const floorDiv$1 = binaryKernelFunc$1(FloorDiv, floorDivImpl, null /* complexImpl */, 'int32');
const floorDivConfig$1 = {
    kernelName: FloorDiv,
    backendName: 'cpu',
    kernelFunc: floorDiv$1,
};
72983
72984 /**
72985 * @license
72986 * Copyright 2021 Google LLC. All Rights Reserved.
72987 * Licensed under the Apache License, Version 2.0 (the "License");
72988 * you may not use this file except in compliance with the License.
72989 * You may obtain a copy of the License at
72990 *
72991 * http://www.apache.org/licenses/LICENSE-2.0
72992 *
72993 * Unless required by applicable law or agreed to in writing, software
72994 * distributed under the License is distributed on an "AS IS" BASIS,
72995 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
72996 * See the License for the specific language governing permissions and
72997 * limitations under the License.
72998 * =============================================================================
72999 */
/**
 * Gathers `numSlices` slices of `sliceSize` elements from `paramsBuf`,
 * addressed by the multi-dimensional indices in `indicesData`.
 * Throws when a decoded index falls outside the params' slice range.
 * Returns a buffer of shape [numSlices, sliceSize].
 */
function gatherNdImpl(indicesData, paramsBuf, dtype, numSlices, sliceRank, sliceSize, strides, paramsShape, paramsSize) {
    const outBuf = buffer([numSlices, sliceSize], dtype);
    // Number of addressable slices in the params tensor.
    const numParamSlices = paramsSize / sliceSize;
    for (let slice = 0; slice < numSlices; slice++) {
        // Decode this slice's coordinates and flatten them via the strides.
        const loc = [];
        let flatIndex = 0;
        for (let dim = 0; dim < sliceRank; dim++) {
            const coord = indicesData[slice * sliceRank + dim];
            flatIndex += coord * strides[dim];
            loc.push(coord);
        }
        if (flatIndex < 0 || flatIndex >= numParamSlices) {
            throw new Error(`Invalid indices: ${loc} does not index into ${paramsShape}`);
        }
        // Copy the addressed slice into row `slice` of the output.
        for (let k = 0; k < sliceSize; k++) {
            outBuf.values[slice * sliceSize + k] =
                paramsBuf.get(...paramsBuf.indexToLoc(flatIndex * sliceSize + k));
        }
    }
    return outBuf;
}
73020
73021 /**
73022 * @license
73023 * Copyright 2020 Google LLC. All Rights Reserved.
73024 * Licensed under the Apache License, Version 2.0 (the "License");
73025 * you may not use this file except in compliance with the License.
73026 * You may obtain a copy of the License at
73027 *
73028 * http://www.apache.org/licenses/LICENSE-2.0
73029 *
73030 * Unless required by applicable law or agreed to in writing, software
73031 * distributed under the License is distributed on an "AS IS" BASIS,
73032 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73033 * See the License for the specific language governing permissions and
73034 * limitations under the License.
73035 * =============================================================================
73036 */
/**
 * Gather along the third axis of the flattened view: for each output
 * position, the coordinate on axis 2 is replaced by the value looked up in
 * `indicesBuf` (indexed by [batch, position]) before reading from `xBuf`.
 */
function gatherV2Impl(xBuf, indicesBuf, flattenOutputShape) {
    const outBuf = buffer(flattenOutputShape, xBuf.dtype);
    for (let outIdx = 0; outIdx < outBuf.size; ++outIdx) {
        const outLoc = outBuf.indexToLoc(outIdx);
        const srcLoc = outLoc.slice();
        // Swap the gather-axis coordinate for the looked-up index value.
        const lookup = indicesBuf.locToIndex([srcLoc[0], srcLoc[2]]);
        srcLoc[2] = indicesBuf.values[lookup];
        const srcIdx = xBuf.locToIndex(srcLoc);
        // Out-of-bounds indices keep the buffer's default zero value.
        if (srcIdx >= 0 && srcIdx < xBuf.values.length) {
            outBuf.values[outIdx] = xBuf.values[srcIdx];
        }
    }
    return outBuf;
}
73053
73054 /**
73055 * @license
73056 * Copyright 2020 Google LLC. All Rights Reserved.
73057 * Licensed under the Apache License, Version 2.0 (the "License");
73058 * you may not use this file except in compliance with the License.
73059 * You may obtain a copy of the License at
73060 *
73061 * http://www.apache.org/licenses/LICENSE-2.0
73062 *
73063 * Unless required by applicable law or agreed to in writing, software
73064 * distributed under the License is distributed on an "AS IS" BASIS,
73065 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73066 * See the License for the specific language governing permissions and
73067 * limitations under the License.
73068 * =============================================================================
73069 */
// Element-wise strict greater-than: 1 where lhs > rhs, else 0 ('bool' output).
const greaterImpl = createSimpleBinaryKernelImpl((lhs, rhs) => Number(lhs > rhs));
const greater$1 = binaryKernelFunc$1(Greater, greaterImpl, null /* complexImpl */, 'bool');
const greaterConfig$1 = {
    kernelName: Greater,
    backendName: 'cpu',
    kernelFunc: greater$1,
};
73077
73078 /**
73079 * @license
73080 * Copyright 2020 Google LLC. All Rights Reserved.
73081 * Licensed under the Apache License, Version 2.0 (the "License");
73082 * you may not use this file except in compliance with the License.
73083 * You may obtain a copy of the License at
73084 *
73085 * http://www.apache.org/licenses/LICENSE-2.0
73086 *
73087 * Unless required by applicable law or agreed to in writing, software
73088 * distributed under the License is distributed on an "AS IS" BASIS,
73089 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73090 * See the License for the specific language governing permissions and
73091 * limitations under the License.
73092 * =============================================================================
73093 */
// Element-wise greater-than-or-equal: 1 where lhs >= rhs, else 0 ('bool').
const greaterEqualImpl = createSimpleBinaryKernelImpl((lhs, rhs) => Number(lhs >= rhs));
const greaterEqual$1 = binaryKernelFunc$1(GreaterEqual, greaterEqualImpl, null /* complexImpl */, 'bool');
const greaterEqualConfig$1 = {
    kernelName: GreaterEqual,
    backendName: 'cpu',
    kernelFunc: greaterEqual$1,
};
73101
73102 /**
73103 * @license
73104 * Copyright 2020 Google LLC. All Rights Reserved.
73105 * Licensed under the Apache License, Version 2.0 (the "License");
73106 * you may not use this file except in compliance with the License.
73107 * You may obtain a copy of the License at
73108 *
73109 * http://www.apache.org/licenses/LICENSE-2.0
73110 *
73111 * Unless required by applicable law or agreed to in writing, software
73112 * distributed under the License is distributed on an "AS IS" BASIS,
73113 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73114 * See the License for the specific language governing permissions and
73115 * limitations under the License.
73116 * =============================================================================
73117 */
// Element-wise strict less-than: 1 where lhs < rhs, else 0 ('bool' output).
const lessImpl = createSimpleBinaryKernelImpl((lhs, rhs) => Number(lhs < rhs));
const less$1 = binaryKernelFunc$1(Less, lessImpl, null /* complexImpl */, 'bool');
const lessConfig$1 = {
    kernelName: Less,
    backendName: 'cpu',
    kernelFunc: less$1,
};
73125
73126 /**
73127 * @license
73128 * Copyright 2020 Google LLC. All Rights Reserved.
73129 * Licensed under the Apache License, Version 2.0 (the "License");
73130 * you may not use this file except in compliance with the License.
73131 * You may obtain a copy of the License at
73132 *
73133 * http://www.apache.org/licenses/LICENSE-2.0
73134 *
73135 * Unless required by applicable law or agreed to in writing, software
73136 * distributed under the License is distributed on an "AS IS" BASIS,
73137 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73138 * See the License for the specific language governing permissions and
73139 * limitations under the License.
73140 * =============================================================================
73141 */
// Element-wise less-than-or-equal: 1 where lhs <= rhs, else 0 ('bool').
const lessEqualImpl = createSimpleBinaryKernelImpl((lhs, rhs) => Number(lhs <= rhs));
const lessEqual$1 = binaryKernelFunc$1(LessEqual, lessEqualImpl, null /* complexImpl */, 'bool');
const lessEqualConfig$1 = {
    kernelName: LessEqual,
    backendName: 'cpu',
    kernelFunc: lessEqual$1,
};
73149
73150 /**
73151 * @license
73152 * Copyright 2020 Google LLC. All Rights Reserved.
73153 * Licensed under the Apache License, Version 2.0 (the "License");
73154 * you may not use this file except in compliance with the License.
73155 * You may obtain a copy of the License at
73156 *
73157 * http://www.apache.org/licenses/LICENSE-2.0
73158 *
73159 * Unless required by applicable law or agreed to in writing, software
73160 * distributed under the License is distributed on an "AS IS" BASIS,
73161 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73162 * See the License for the specific language governing permissions and
73163 * limitations under the License.
73164 * =============================================================================
73165 */
/**
 * Generates `num` evenly spaced float32 values from `start` toward `stop`.
 * Each entry is built by accumulating the step onto the previously stored
 * (float32-rounded) entry, so rounding matches the original element-by-element
 * accumulation exactly.
 */
function linSpaceImpl(start, stop, num) {
    const increment = (stop - start) / (num - 1);
    const result = makeZerosTypedArray(num, 'float32');
    result[0] = start;
    let i = 1;
    while (i < result.length) {
        result[i] = result[i - 1] + increment;
        i++;
    }
    return result;
}
73175
73176 /**
73177 * @license
73178 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
73180 * you may not use this file except in compliance with the License.
73181 * You may obtain a copy of the License at
73182 *
73183 * http://www.apache.org/licenses/LICENSE-2.0
73184 *
73185 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
73187 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73188 * See the License for the specific language governing permissions and
73189 * limitations under the License.
73190 * =============================================================================
73191 */
// Element-wise natural logarithm.
const logImpl = createSimpleUnaryImpl((value) => Math.log(value));
const log$1 = unaryKernelFuncFromImpl(Log, logImpl);
const logConfig$1 = {
    kernelName: Log,
    backendName: 'cpu',
    kernelFunc: log$1,
};
73199
73200 /**
73201 * @license
73202 * Copyright 2020 Google LLC. All Rights Reserved.
73203 * Licensed under the Apache License, Version 2.0 (the "License");
73204 * you may not use this file except in compliance with the License.
73205 * You may obtain a copy of the License at
73206 *
73207 * http://www.apache.org/licenses/LICENSE-2.0
73208 *
73209 * Unless required by applicable law or agreed to in writing, software
73210 * distributed under the License is distributed on an "AS IS" BASIS,
73211 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73212 * See the License for the specific language governing permissions and
73213 * limitations under the License.
73214 * =============================================================================
73215 */
/**
 * Reduces each contiguous group of `reduceSize` elements of `aVals` to its
 * maximum. NaN is propagated: once seen in a group, the group's result is NaN
 * (plain `>` is always false against NaN, so it is tested explicitly).
 */
function maxImpl$1(aVals, reduceSize, outShape, dtype) {
    const out = getTypedArrayFromDType(dtype, sizeFromShape(outShape));
    for (let outIdx = 0; outIdx < out.length; ++outIdx) {
        const base = outIdx * reduceSize;
        let best = aVals[base];
        for (let j = 0; j < reduceSize; ++j) {
            const candidate = aVals[base + j];
            if (Number.isNaN(candidate) || candidate > best) {
                best = candidate;
            }
        }
        out[outIdx] = best;
    }
    return out;
}
73232
73233 /**
73234 * @license
73235 * Copyright 2020 Google LLC. All Rights Reserved.
73236 * Licensed under the Apache License, Version 2.0 (the "License");
73237 * you may not use this file except in compliance with the License.
73238 * You may obtain a copy of the License at
73239 *
73240 * http://www.apache.org/licenses/LICENSE-2.0
73241 *
73242 * Unless required by applicable law or agreed to in writing, software
73243 * distributed under the License is distributed on an "AS IS" BASIS,
73244 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73245 * See the License for the specific language governing permissions and
73246 * limitations under the License.
73247 * =============================================================================
73248 */
// Element-wise maximum of the two operands.
const maximumImpl = createSimpleBinaryKernelImpl((lhs, rhs) => Math.max(lhs, rhs));
const maximum$1 = binaryKernelFunc$1(Maximum$1, maximumImpl);
const maximumConfig$1 = {
    kernelName: Maximum$1,
    backendName: 'cpu',
    kernelFunc: maximum$1,
};
73256
73257 /**
73258 * @license
73259 * Copyright 2020 Google LLC. All Rights Reserved.
73260 * Licensed under the Apache License, Version 2.0 (the "License");
73261 * you may not use this file except in compliance with the License.
73262 * You may obtain a copy of the License at
73263 *
73264 * http://www.apache.org/licenses/LICENSE-2.0
73265 *
73266 * Unless required by applicable law or agreed to in writing, software
73267 * distributed under the License is distributed on an "AS IS" BASIS,
73268 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73269 * See the License for the specific language governing permissions and
73270 * limitations under the License.
73271 * =============================================================================
73272 */
// Element-wise minimum of the two operands.
const minimumImpl = createSimpleBinaryKernelImpl((lhs, rhs) => Math.min(lhs, rhs));
const minimum$1 = binaryKernelFunc$1(Minimum$1, minimumImpl);
const minimumConfig$1 = {
    kernelName: Minimum$1,
    backendName: 'cpu',
    kernelFunc: minimum$1,
};
73280
73281 /**
73282 * @license
73283 * Copyright 2020 Google LLC. All Rights Reserved.
73284 * Licensed under the Apache License, Version 2.0 (the "License");
73285 * you may not use this file except in compliance with the License.
73286 * You may obtain a copy of the License at
73287 *
73288 * http://www.apache.org/licenses/LICENSE-2.0
73289 *
73290 * Unless required by applicable law or agreed to in writing, software
73291 * distributed under the License is distributed on an "AS IS" BASIS,
73292 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73293 * See the License for the specific language governing permissions and
73294 * limitations under the License.
73295 * =============================================================================
73296 */
// Element-wise product for real operands.
const multiplyImpl = createSimpleBinaryKernelImpl((lhs, rhs) => lhs * rhs);
// Complex product: (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
const multiplyComplexImpl = createComplexBinaryKernelImpl((aReal, aImag, bReal, bImag) => {
    const real = aReal * bReal - aImag * bImag;
    const imag = aReal * bImag + aImag * bReal;
    return { real, imag };
});
const multiply$1 = binaryKernelFunc$1(Multiply$1, multiplyImpl, multiplyComplexImpl);
const multiplyConfig$1 = {
    kernelName: Multiply$1,
    backendName: 'cpu',
    kernelFunc: multiply$1,
};
73310
73311 /**
73312 * @license
73313 * Copyright 2020 Google LLC. All Rights Reserved.
73314 * Licensed under the Apache License, Version 2.0 (the "License");
73315 * you may not use this file except in compliance with the License.
73316 * You may obtain a copy of the License at
73317 *
73318 * http://www.apache.org/licenses/LICENSE-2.0
73319 *
73320 * Unless required by applicable law or agreed to in writing, software
73321 * distributed under the License is distributed on an "AS IS" BASIS,
73322 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73323 * See the License for the specific language governing permissions and
73324 * limitations under the License.
73325 * =============================================================================
73326 */
/**
 * Negates the values by broadcasting a scalar -1 (of the input's dtype)
 * through the shared multiply implementation; returns [values, shape].
 */
function negImpl(xVals, xShape, xDtype) {
    const negativeOne = createScalarValue(-1, xDtype);
    return multiplyImpl([], xShape, negativeOne, xVals, xDtype);
}
/** CPU Neg kernel. Rejects complex inputs. */
function neg$1(args) {
    const { inputs, backend } = args;
    const { x } = inputs;
    assertNotComplex$1(x, 'neg');
    const xVals = backend.data.get(x.dataId).values;
    const [resultValues, resultShape] = negImpl(xVals, x.shape, x.dtype);
    return backend.makeTensorInfo(resultShape, x.dtype, resultValues);
}
const negConfig$1 = {
    kernelName: Neg,
    backendName: 'cpu',
    kernelFunc: neg$1,
};
73344
73345 /**
73346 * @license
73347 * Copyright 2020 Google LLC. All Rights Reserved.
73348 * Licensed under the Apache License, Version 2.0 (the "License");
73349 * you may not use this file except in compliance with the License.
73350 * You may obtain a copy of the License at
73351 *
73352 * http://www.apache.org/licenses/LICENSE-2.0
73353 *
73354 * Unless required by applicable law or agreed to in writing, software
73355 * distributed under the License is distributed on an "AS IS" BASIS,
73356 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73357 * See the License for the specific language governing permissions and
73358 * limitations under the License.
73359 * =============================================================================
73360 */
// Element-wise inequality: 1 where the operands differ, else 0 ('bool').
const notEqualImpl = createSimpleBinaryKernelImpl((lhs, rhs) => Number(lhs !== rhs));
const notEqual$1 = binaryKernelFunc$1(NotEqual, notEqualImpl, null /* complexOp */, 'bool');
const notEqualConfig$1 = {
    kernelName: NotEqual,
    backendName: 'cpu',
    kernelFunc: notEqual$1,
};
73368
73369 /**
73370 * @license
73371 * Copyright 2020 Google LLC. All Rights Reserved.
73372 * Licensed under the Apache License, Version 2.0 (the "License");
73373 * you may not use this file except in compliance with the License.
73374 * You may obtain a copy of the License at
73375 *
73376 * http://www.apache.org/licenses/LICENSE-2.0
73377 *
73378 * Unless required by applicable law or agreed to in writing, software
73379 * distributed under the License is distributed on an "AS IS" BASIS,
73380 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73381 * See the License for the specific language governing permissions and
73382 * limitations under the License.
73383 * =============================================================================
73384 */
/**
 * Permutes the elements of `xVals` (shape `xShape`) according to `perm`,
 * writing into a new typed array of shape `newShape` (where
 * newShape[axis] === xShape[perm[axis]]).
 *
 * Fix: the inner permutation loop reused the outer loop's variable name `i`
 * (shadowing — behavior was correct but error-prone to maintain); it now
 * uses a distinct `axis` index.
 */
function transposeImpl$1(xVals, xShape, dtype, perm, newShape) {
    const xRank = xShape.length;
    const xSize = sizeFromShape(xShape);
    const xStrides = computeStrides(xShape);
    const newStrides = computeStrides(newShape);
    const result = getTypedArrayFromDType(dtype, sizeFromShape(newShape));
    for (let i = 0; i < xSize; ++i) {
        const loc = indexToLoc(i, xRank, xStrides);
        // Permute the location: output axis `axis` reads input axis `perm[axis]`.
        const newLoc = new Array(loc.length);
        for (let axis = 0; axis < newLoc.length; axis++) {
            newLoc[axis] = loc[perm[axis]];
        }
        const newIndex = locToIndex(newLoc, xRank, newStrides);
        result[newIndex] = xVals[i];
    }
    return result;
}
73403
73404 /**
73405 * @license
73406 * Copyright 2020 Google LLC. All Rights Reserved.
73407 * Licensed under the Apache License, Version 2.0 (the "License");
73408 * you may not use this file except in compliance with the License.
73409 * You may obtain a copy of the License at
73410 *
73411 * http://www.apache.org/licenses/LICENSE-2.0
73412 *
73413 * Unless required by applicable law or agreed to in writing, software
73414 * distributed under the License is distributed on an "AS IS" BASIS,
73415 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73416 * See the License for the specific language governing permissions and
73417 * limitations under the License.
73418 * =============================================================================
73419 */
/**
 * CPU Transpose kernel: permutes the axes of `x` per `attrs.perm`.
 * Rejects complex inputs; writes the permuted data to a new backend buffer.
 */
function transpose$1(args) {
    const { inputs, attrs, backend } = args;
    const { x } = inputs;
    const { perm } = attrs;
    assertNotComplex$1(x, 'transpose');
    // Axis `axis` of the output is axis `perm[axis]` of the input.
    const rank = x.shape.length;
    const outShape = new Array(rank);
    for (let axis = 0; axis < outShape.length; axis++) {
        outShape[axis] = x.shape[perm[axis]];
    }
    const xValues = backend.data.get(x.dataId).values;
    const outValues = transposeImpl$1(xValues, x.shape, x.dtype, perm, outShape);
    const dataId = backend.write(outValues, outShape, x.dtype);
    return { dataId, shape: outShape, dtype: x.dtype };
}
const transposeConfig$1 = {
    kernelName: Transpose,
    backendName: 'cpu',
    kernelFunc: transpose$1,
};
73440
73441 /**
73442 * @license
73443 * Copyright 2020 Google LLC. All Rights Reserved.
73444 * Licensed under the Apache License, Version 2.0 (the "License");
73445 * you may not use this file except in compliance with the License.
73446 * You may obtain a copy of the License at
73447 *
73448 * http://www.apache.org/licenses/LICENSE-2.0
73449 *
73450 * Unless required by applicable law or agreed to in writing, software
73451 * distributed under the License is distributed on an "AS IS" BASIS,
73452 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73453 * See the License for the specific language governing permissions and
73454 * limitations under the License.
73455 * =============================================================================
73456 */
/**
 * Product-reduction over the given (inner-most) axes of x.
 * The output dtype is x's dtype upcast with 'int32'.
 * Returns { outVals, outShape, outDtype }.
 */
function prodImpl(xShape, xDtype, xVals, reductionAxes) {
    const [outShape, reduceShape] = computeOutAndReduceShapes(xShape, reductionAxes);
    const outDtype = upcastType(xDtype, 'int32');
    const outVals = makeZerosTypedArray(sizeFromShape(outShape), outDtype);
    const reduceSize = sizeFromShape(reduceShape);
    for (let outIdx = 0; outIdx < outVals.length; ++outIdx) {
        // Each output element is the product of one contiguous group.
        const base = outIdx * reduceSize;
        let product = 1;
        for (let j = 0; j < reduceSize; ++j) {
            product *= xVals[base + j];
        }
        outVals[outIdx] = product;
    }
    return { outVals, outShape, outDtype };
}
/**
 * CPU Prod kernel: multiplies the elements of `x` along `attrs.axis`.
 * If the reduction axes are not already inner-most, the input is first
 * transposed so they are, then prodImpl reduces contiguous groups. With
 * `keepDims`, the reduced axes are re-inserted as size-1 dimensions.
 * Rejects complex inputs.
 */
function prod$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    assertNotComplex$1(x, 'prod');
    const xRank = x.shape.length;
    const axes = parseAxisParam(axis, x.shape);
    // Non-null only when the requested axes must be moved inner-most first.
    const permutation = getAxesPermutation(axes, xRank);
    let reductionAxes = axes;
    let permutedX = x;
    const intermediateTensorInfos = [];
    if (permutation != null) {
        permutedX = transpose$1({ inputs: { x }, backend, attrs: { perm: permutation } });
        // Track the transposed tensor so it can be disposed below.
        intermediateTensorInfos.push(permutedX);
        reductionAxes = getInnerMostAxes(reductionAxes.length, xRank);
    }
    const xVals = backend.data.get(permutedX.dataId).values;
    const { outVals, outShape, outDtype } = prodImpl(permutedX.shape, permutedX.dtype, xVals, reductionAxes);
    let resultShape = outShape;
    if (keepDims) {
        // Re-insert the reduced axes as size-1 dims.
        resultShape = expandShapeToKeepDim(outShape, axes);
    }
    intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return backend.makeTensorInfo(resultShape, outDtype, outVals);
}
const prodConfig$1 = {
    kernelName: Prod,
    backendName: 'cpu',
    kernelFunc: prod$1
};
73502
73503 /**
73504 * @license
73505 * Copyright 2022 Google LLC. All Rights Reserved.
73506 * Licensed under the Apache License, Version 2.0 (the "License");
73507 * you may not use this file except in compliance with the License.
73508 * You may obtain a copy of the License at
73509 *
73510 * http://www.apache.org/licenses/LICENSE-2.0
73511 *
73512 * Unless required by applicable law or agreed to in writing, software
73513 * distributed under the License is distributed on an "AS IS" BASIS,
73514 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73515 * See the License for the specific language governing permissions and
73516 * limitations under the License.
73517 * =============================================================================
73518 */
/**
 * Verifies that every gather index addresses an existing param row,
 * i.e. lies in [0, numParams). Throws with the offending index's
 * multi-dimensional location on the first violation.
 */
function validateIndices(indices, indicesShape, numParams) {
    indices.forEach((index, flatPos) => {
        if (index >= 0 && index < numParams) {
            return;
        }
        const loc = indexToLoc(flatPos, indicesShape.length, computeStrides(indicesShape));
        throw new Error(`indices[${loc.join(',')}] = ${index} is not in [0, ${numParams})`);
    });
}
/**
 * Checks each nested ragged-splits vector for well-formedness: non-empty,
 * starting at a non-negative value, sorted ascending, and never pointing
 * past its bound. The bound for the last dimension is the dense-values
 * count; for earlier dimensions it is the length of the next splits vector.
 */
function validateSplits(paramsNestedSplits, numParamsDenseValues) {
    paramsNestedSplits.forEach((splits, dim) => {
        const upperBound = dim === paramsNestedSplits.length - 1 ?
            numParamsDenseValues :
            paramsNestedSplits[dim + 1].length;
        if (splits.length === 0) {
            throw new Error('Ragged splits may not be empty');
        }
        if (splits[0] < 0) {
            throw new Error('Ragged splits must be non-negative');
        }
        if (splits[splits.length - 1] > upperBound) {
            throw new Error('Ragged splits must not point past values');
        }
        for (let i = 1; i < splits.length; ++i) {
            if (splits[i - 1] > splits[i]) {
                throw new Error('Ragged splits must be sorted in ascending order');
            }
        }
    });
}
73551 // Construct the `splits` output tensors, encoded using a nested vector.
73552 // Also find the slices of values that need to be copied, and store them
73553 // in `valueSlices`. The total number of values that will be copied (which
73554 // we need for allocating the output values tensor) is stored in `numValues`.
/**
 * Builds the output `splits` vectors for a ragged gather and records which
 * slices of the dense values must be copied.
 *
 * Returns { outSplits, valueSlices, numValues } where `outSplits` is a
 * nested array of split points, `valueSlices` holds [start, limit) ranges
 * into the dense values, and `numValues` is the total count to be copied.
 */
function makeSplits(indices, indicesShape, paramsNestedSplits, numParamsDenseValues) {
    const valueSlices = [];
    let numValues = 0;
    const numSplits = indicesShape.length - 1 + paramsNestedSplits.length;
    // Every splits vector starts with the split point 0.
    const outSplits = new Array(numSplits).fill(null).map(() => [0]);
    validateSplits(paramsNestedSplits, numParamsDenseValues);
    // Add `splits` that come from all but the last dimension of the dense
    // Tensor `indices`. In particular, for each dimension D, we add a
    // splits tensor whose values are:
    //   range(reduceProd(splits.shape[:D]) + 1) * splits.shape[D+1]
    // E.g., if indices.shape=[2, 3, 4] then we will add splits tensors:
    //   [0, 3, 6]                    # length=2+1, stride=3
    //   [0, 4, 8, 12, 16, 20, 24]    # length=2*3+1, stride=4
    let nrows = 1;
    for (let dim = 0; dim < indicesShape.length - 1; ++dim) {
        nrows *= indicesShape[dim];
        const rowLength = indicesShape[dim + 1];
        for (let i = 1; i < nrows + 1; ++i) {
            outSplits[dim].push(i * rowLength);
        }
    }
    // Add `splits` that come from `paramsNestedSplits`. Starting with the
    // outermost ragged dimension (i.e., the first `splits` tensor), we work
    // our way in, finding the range of values that should be copied. As we
    // go, we update the output `splits` for each dimension with the
    // appropriate values. In particular, the *lengths* of the slices from
    // `param_splits` should be copied to generate corresponding slice
    // lengths in the output splits. E.g., if we are copying a ragged row
    // with length 4, then we should add a new split point to outSplits that
    // is 4 greater than the previous split point in outSplits.
    for (let i = 0; i < indices.length; ++i) {
        let start = indices[i];
        let limit = indices[i] + 1;
        // Copy splits.
        for (let dim = 0; dim < paramsNestedSplits.length; ++dim) {
            const splits = paramsNestedSplits[dim];
            const outDim = dim + indicesShape.length - 1;
            if (outDim >= 0) {
                const outSplitsOutDim = outSplits[outDim];
                // Shift the copied split points so they continue from the
                // last split point already emitted for this dimension.
                const delta = outSplitsOutDim[outSplitsOutDim.length - 1] - splits[start];
                for (let j = start; j < limit; ++j) {
                    outSplits[outDim].push(splits[j + 1] + delta);
                }
            }
            // Descend one ragged level: map the [start, limit) row range to
            // the corresponding range in the next level (or the values).
            start = splits[start];
            limit = splits[limit];
        }
        if (limit !== start) {
            valueSlices.push([start, limit]);
            numValues += limit - start;
        }
    }
    return { outSplits, valueSlices, numValues };
}
/**
 * Materializes each nested splits vector as an int32 typed array.
 */
function getSplits(outSplits) {
    return outSplits.map((splitsVec) => {
        const typed = getArrayFromDType('int32', splitsVec.length);
        splitsVec.forEach((value, j) => {
            typed[j] = value;
        });
        return typed;
    });
}
73619 function computeFlatOuterDims(orig, numOutDims) {
73620 const outDims = orig.slice(0, numOutDims);
73621 while (outDims.length < numOutDims) {
73622 outDims.push(1);
73623 }
73624 for (let inDim = numOutDims; inDim < orig.length; inDim++) {
73625 outDims[numOutDims - 1] *= orig[inDim];
73626 }
73627 return outDims;
73628 }
73629 // For each slice in `(start, limit)` in `valueSlices`, append
73630 // `paramsDenseValues[start,...,limit] to `values`. `valueSize` indicates
73631 // the number of scalars contained in each value paramsDenseValues[i].
73632 function writeValueSlices(paramsDenseValues, paramsDenseValuesShape, valueSlices, valueSize, values, valuesShape) {
73633 const denseM = computeFlatOuterDims(paramsDenseValuesShape, 2)[1];
73634 const valuesM = computeFlatOuterDims(valuesShape, 2)[1];
73635 let outPos = 0;
73636 for (const slice of valueSlices) {
73637 for (let i = slice[0]; i < slice[1]; ++i) {
73638 for (let j = 0; j < valueSize; ++j) {
73639 values[outPos * valuesM + j] = paramsDenseValues[i * denseM + j];
73640 }
73641 ++outPos;
73642 }
73643 }
73644 }
73645 function getValues(paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, valueSlices, numValues) {
73646 const valuesShape = paramsDenseValuesShape.slice();
73647 valuesShape[0] = numValues;
73648 const valuesOut = getArrayFromDType(paramsDenseValuesDType, sizeFromShape(valuesShape));
73649 const numElements = paramsDenseValues.length;
73650 const valueSize = numElements === 0 ? 0 : (numElements / paramsDenseValuesShape[0]);
73651 writeValueSlices(paramsDenseValues, paramsDenseValuesShape, valueSlices, valueSize, valuesOut, valuesShape);
73652 return [valuesOut, valuesShape];
73653 }
73654 function raggedGatherImpl(paramsNestedSplits, paramsNestedSplitsShapes, paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, indices, indicesShape, outputRaggedRank) {
73655 if (paramsNestedSplits.length === 0) {
73656 throw new Error('paramsNestedSplits must be non empty');
73657 }
73658 if (paramsNestedSplitsShapes[0].length === 0) {
73659 throw new Error('Split tensors must not be scalars');
73660 }
73661 const numParams = paramsNestedSplitsShapes[0][0] - 1;
73662 validateIndices(indices, indicesShape, numParams);
73663 if (paramsDenseValuesShape.length === 0) {
73664 throw new Error('params.rank must be nonzero');
73665 }
73666 const numParamsDenseValues = paramsDenseValuesShape[0];
73667 // Calculate the `splits`, and store the value slices that we need to
73668 // copy in `valueSlices`.
73669 const { outSplits, valueSlices, numValues } = makeSplits(indices, indicesShape, paramsNestedSplits, numParamsDenseValues);
73670 // Write the output tensors.
73671 const outputNestedSplits = getSplits(outSplits);
73672 const outputDenseValues = getValues(paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, valueSlices, numValues);
73673 return [outputNestedSplits, outputDenseValues[0], outputDenseValues[1]];
73674 }
73675
73676 /**
73677 * @license
73678 * Copyright 2022 Google LLC.
73679 * Licensed under the Apache License, Version 2.0 (the "License");
73680 * you may not use this file except in compliance with the License.
73681 * You may obtain a copy of the License at
73682 *
73683 * http://www.apache.org/licenses/LICENSE-2.0
73684 *
73685 * Unless required by applicable law or agreed to in writing, software
73686 * distributed under the License is distributed on an "AS IS" BASIS,
73687 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73688 * See the License for the specific language governing permissions and
73689 * limitations under the License.
73690 * =============================================================================
73691 */
73692 const INT32_MAX = 2147483647;
73693 function raggedRangeImpl(starts, startsShape, startsDType, limits, limitsShape, deltas, deltasShape) {
73694 // Check input tensor shapes.
73695 if (startsShape.length > 1) {
73696 throw new Error('starts must be a scalar or vector');
73697 }
73698 if (limitsShape.length > 1) {
73699 throw new Error('limits must be a scalar or vector');
73700 }
73701 if (deltasShape.length > 1) {
73702 throw new Error('deltas must be a scalar or vector');
73703 }
73704 // Determine which tensors we need to broadcast.
73705 const broadcastStarts = startsShape.length === 0;
73706 const broadcastLimits = limitsShape.length === 0;
73707 const broadcastDeltas = deltasShape.length === 0;
73708 // nRows (number of output rows) is the size of the non-broadcast inputs,
73709 // or 1 if all inputs are scalars.
73710 const inSizes = [];
73711 if (!broadcastStarts) {
73712 inSizes.push(startsShape[0]);
73713 }
73714 if (!broadcastLimits) {
73715 inSizes.push(limitsShape[0]);
73716 }
73717 if (!broadcastDeltas) {
73718 inSizes.push(deltasShape[0]);
73719 }
73720 for (let i = 1; i < inSizes.length; ++i) {
73721 if (inSizes[i] !== inSizes[i - 1]) {
73722 throw new Error('starts, limits, and deltas must have the same shape');
73723 }
73724 }
73725 const nRows = inSizes.length === 0 ? 1 : inSizes[0];
73726 // Construct the rtNestedSplits tensor.
73727 const rtNestedSplits = getArrayFromDType('int32', nRows + 1);
73728 rtNestedSplits[0] = 0;
73729 for (let row = 0; row < nRows; ++row) {
73730 const start = broadcastStarts ? starts[0] : starts[row];
73731 const limit = broadcastLimits ? limits[0] : limits[row];
73732 const delta = broadcastDeltas ? deltas[0] : deltas[row];
73733 if (delta === 0) {
73734 throw new Error('Requires delta != 0');
73735 }
73736 let size; // The number of elements in the specified range.
73737 if (((delta > 0) && (limit < start)) || ((delta < 0) && (limit > start))) {
73738 size = 0;
73739 }
73740 else {
73741 size = Math.ceil(Math.abs((limit - start) / delta));
73742 if (size > INT32_MAX) {
73743 throw new Error(`Requires ((limit - start) / delta) <= ${INT32_MAX}`);
73744 }
73745 }
73746 rtNestedSplits[row + 1] = rtNestedSplits[row] + size;
73747 }
73748 const nVals = rtNestedSplits[nRows];
73749 // Construct the rtDenseValues tensor.
73750 const rtDenseValues = getArrayFromDType(startsDType, nVals);
73751 let valueIndex = 0;
73752 for (let row = 0; row < nRows; ++row) {
73753 const rowSize = rtNestedSplits[row + 1] - rtNestedSplits[row];
73754 let value = broadcastStarts ? starts[0] : starts[row];
73755 const delta = broadcastDeltas ? deltas[0] : deltas[row];
73756 for (let i = 0; i < rowSize; ++i) {
73757 rtDenseValues[valueIndex++] = value;
73758 value += delta;
73759 }
73760 }
73761 return [rtNestedSplits, rtDenseValues];
73762 }
73763
73764 /**
73765 * @license
73766 * Copyright 2022 Google LLC. All Rights Reserved.
73767 * Licensed under the Apache License, Version 2.0 (the "License");
73768 * you may not use this file except in compliance with the License.
73769 * You may obtain a copy of the License at
73770 *
73771 * http://www.apache.org/licenses/LICENSE-2.0
73772 *
73773 * Unless required by applicable law or agreed to in writing, software
73774 * distributed under the License is distributed on an "AS IS" BASIS,
73775 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73776 * See the License for the specific language governing permissions and
73777 * limitations under the License.
73778 * =============================================================================
73779 */
    // Local alias for the RowPartitionType enum defined earlier in this bundle,
    // matching the name used by the upstream TensorFlow kernel.
    var RowPartitionType = RowPartitionType$1;
    // Based on
    // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc
    // Converts a ragged tensor (dense `values` plus a stack of row-partition
    // tensors) into a dense tensor of the requested `shape`, padding holes with
    // `defaultValue`. Port of TF's ragged_tensor_to_tensor_op.cc (see link above).
    class RaggedTensorToTensorOp {
        constructor(shape, shapeShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypeStrings) {
            this.shape = shape;
            this.shapeShape = shapeShape;
            this.values = values;
            this.valuesShape = valuesShape;
            this.valuesDType = valuesDType;
            this.defaultValue = defaultValue;
            this.defaultValueShape = defaultValueShape;
            this.rowPartitionValues = rowPartitionValues;
            this.rowPartitionValuesShapes = rowPartitionValuesShapes;
            // Decode the partition-type name strings into enum values.
            this.rowPartitionTypes =
                getRowPartitionTypesHelper(rowPartitionTypeStrings);
            this.raggedRank = getRaggedRank(this.rowPartitionTypes);
        }
        // When the first entry is FIRST_DIM_SIZE, per-dimension entries are
        // shifted by one relative to the dimension index.
        getRowPartitionTypeByDimension(dimension) {
            if (this.rowPartitionTypes[0] === RowPartitionType.FIRST_DIM_SIZE) {
                return this.rowPartitionTypes[dimension + 1];
            }
            else {
                return this.rowPartitionTypes[dimension];
            }
        }
        // Returns the relationship between dimension and dimension + 1.
        getRowPartitionTensor(dimension) {
            if (this.rowPartitionTypes[0] === RowPartitionType.FIRST_DIM_SIZE) {
                return this.rowPartitionValues[dimension + 1];
            }
            else {
                return this.rowPartitionValues[dimension];
            }
        }
        // Widest row at `dimension`; used to resolve an unspecified (-1) entry
        // in the requested output shape.
        getMaxWidth(dimension) {
            const rowPartitionTensor = this.getRowPartitionTensor(dimension - 1);
            switch (this.getRowPartitionTypeByDimension(dimension - 1)) {
                case RowPartitionType.VALUE_ROWIDS:
                    return RaggedTensorToTensorOp.getMaxWidthValueRowID(rowPartitionTensor);
                case RowPartitionType.ROW_SPLITS:
                    return RaggedTensorToTensorOp.getMaxWidthRowSplit(rowPartitionTensor);
                default:
                    throw new Error(`Cannot handle partition type ${RowPartitionType[this.getRowPartitionTypeByDimension(dimension - 1)]}`);
            }
        }
        // Largest gap between consecutive entries of a ROW_SPLITS tensor.
        static getMaxWidthRowSplit(rowSplit) {
            const tensorLength = rowSplit.length;
            if (tensorLength === 0 || tensorLength === 1) {
                return 0;
            }
            let maxWidth = 0;
            for (let i = 0; i < tensorLength - 1; ++i) {
                const currentWidth = rowSplit[i + 1] - rowSplit[i];
                if (currentWidth > maxWidth) {
                    maxWidth = currentWidth;
                }
            }
            return maxWidth;
        }
        // Longest run of equal ids in a VALUE_ROWIDS tensor (= widest row).
        static getMaxWidthValueRowID(valueRowIds) {
            const indexLength = valueRowIds.length;
            if (indexLength === 0) {
                return 0;
            }
            let firstEqualIndex = 0;
            let firstEqualIndexValue = valueRowIds[0];
            let maxWidth = 0;
            for (let i = 1; i < indexLength; ++i) {
                const value = valueRowIds[i];
                if (value !== firstEqualIndexValue) {
                    firstEqualIndexValue = value;
                    maxWidth = Math.max(i - firstEqualIndex, maxWidth);
                    firstEqualIndex = i;
                }
            }
            // Account for the final (still open) run of equal ids.
            return Math.max(indexLength - firstEqualIndex, maxWidth);
        }
        // Converts a shape tensor into a (possibly partial) shape array. A
        // scalar shape tensor is only valid as the fully-unknown shape `-1`.
        tensorShapeFromTensor(t, tShape, isPartial = true) {
            if (tShape.length === 0) {
                if (t[0] === -1) {
                    return [];
                }
                throw new Error(`The only valid scalar shape tensor is the fully unknown shape specified as -1.`);
            }
            // MakePartialShape/MakeShapeHelper.
            return makeShape(t, isPartial);
        }
        // Resolves the final dense output shape, filling unknown (-1) dims from
        // `firstDim` and the maximum ragged row widths.
        calculateOutputSize(firstDim) {
            const valueShape = this.valuesShape;
            const defaultValueShape = this.defaultValueShape;
            validateDefaultValueShape(defaultValueShape, valueShape);
            const shape = this.tensorShapeFromTensor(this.shape, this.shapeShape);
            const outputShape = combineRaggedTensorToTensorShapes(this.raggedRank, shape, valueShape);
            const result = outputShape;
            if (result[0] < 0) {
                result[0] = firstDim;
            }
            for (let i = 1; i <= this.raggedRank; ++i) {
                if (result[i] < 0) {
                    result[i] = this.getMaxWidth(i);
                }
            }
            return result;
        }
        /**
         * The outputIndex represents the index in the output tensor
         * where the first element of a particular dimension would be written.
         * If it is -1, it indicates that the index is out of scope.
         * Example, given firstDimension = 10, firstDimensionOutput = 6,
         * and outputIndexMultiplier = 100:
         * result = [0 100 200 300 400 500 -1 -1 -1 -1]
         * If firstDimensionOutput = 11 instead, then:
         * result = [0 100 200 300 400 500 600 700 800 900]
         */
        calculateFirstParentOutputIndex(firstDimension, outputIndexMultiplier, firstDimensionOutput) {
            const minDimension = Math.min(firstDimension, firstDimensionOutput);
            const result = [];
            let currentOutputIndex = 0;
            for (let i = 0; i < minDimension; ++i, currentOutputIndex += outputIndexMultiplier) {
                result.push(currentOutputIndex);
            }
            // Rows that do not fit in the output are marked out-of-scope (-1).
            for (let i = minDimension; i < firstDimension; ++i) {
                result.push(-1);
            }
            assert$1(result.length === firstDimension, () => 'Final length of result must be equal to firstDimension.');
            return result;
        }
        // ROW_SPLITS variant of calculateOutputIndex (see comment below): emits
        // one output index per value, clamped to `outputSize` columns per row.
        calculateOutputIndexRowSplit(rowSplit, parentOutputIndex, outputIndexMultiplier, outputSize) {
            const rowSplitSize = rowSplit.length;
            const result = [];
            for (let i = 0; i < rowSplitSize - 1; ++i) {
                const rowLength = rowSplit[i + 1] - rowSplit[i];
                let realLength = Math.min(outputSize, rowLength);
                let parentOutputIndexCurrent = parentOutputIndex[i];
                // A parent of -1 means this whole row is out of scope.
                if (parentOutputIndexCurrent === -1) {
                    realLength = 0;
                }
                for (let j = 0; j < realLength; ++j) {
                    result.push(parentOutputIndexCurrent);
                    parentOutputIndexCurrent += outputIndexMultiplier;
                }
                for (let j = 0; j < rowLength - realLength; ++j) {
                    result.push(-1);
                }
            }
            if (rowSplitSize > 0 && result.length !== rowSplit[rowSplitSize - 1]) {
                throw new Error('Invalid row split size.');
            }
            return result;
        }
        // Calculate the output index of the first element of a list.
        // The parentOutputIndex is the same computation for the previous list.
        // -1 indicates an element or list that is out of range.
        // The outputIndexMultiplier is the number of output indices one moves
        // forward for each column.
        // E.g., given:
        // valueRowIds:[0 1 2 2 2 3 5 5 6]
        // parentOutputIndex:[1000 1100 2000 2100 -1 3000 4000]
        // outputIndexMultiplier: 10
        // outputSize: 2
        // You get:
        // result = [1000 1100 2000 2010 -1 2100 -1 -1 3000]
        // result[0] = parentOutputIndex[valueRowIds[0]]
        // result[1] = parentOutputIndex[valueRowIds[1]]
        // result[2] = parentOutputIndex[valueRowIds[2]]
        // result[3] = parentOutputIndex[valueRowIds[2] + 10]
        // result[4] = -1 because it is the third element the size is 2.
        // result[5] = parentOutputIndex[valueRowIds[3]]
        // result[6] = -1 because parentOutputIndex[valueRowIds[6]] == -1
        // result[7] = -1 because parentOutputIndex[valueRowIds[6]] == -1
        // result[8] = parentOutputIndex[valueRowIds[7]]
        calculateOutputIndexValueRowID(valueRowIds, parentOutputIndex, outputIndexMultiplier, outputSize) {
            const indexSize = valueRowIds.length;
            const result = [];
            if (indexSize === 0) {
                return [];
            }
            let currentOutputColumn = 0;
            let currentValueRowId = valueRowIds[0];
            if (currentValueRowId >= parentOutputIndex.length) {
                throw new Error(`Got currentValueRowId=${currentValueRowId}, which is not less than ${parentOutputIndex.length}`);
            }
            let currentOutputIndex = parentOutputIndex[currentValueRowId];
            result.push(currentOutputIndex);
            for (let i = 1; i < indexSize; ++i) {
                const nextValueRowId = valueRowIds[i];
                if (nextValueRowId === currentValueRowId) {
                    // Same row: advance one column, or fall out of scope once the
                    // row exceeds outputSize columns.
                    if (currentOutputIndex >= 0) {
                        ++currentOutputColumn;
                        if (currentOutputColumn < outputSize) {
                            currentOutputIndex += outputIndexMultiplier;
                        }
                        else {
                            currentOutputIndex = -1;
                        }
                    }
                }
                else {
                    // New row: restart at column 0 under the new parent index.
                    currentOutputColumn = 0;
                    currentValueRowId = nextValueRowId;
                    if (nextValueRowId >= parentOutputIndex.length) {
                        throw new Error(`Got nextValueRowId=${nextValueRowId} which is not less than ${parentOutputIndex.length}`);
                    }
                    currentOutputIndex = parentOutputIndex[nextValueRowId];
                }
                result.push(currentOutputIndex);
            }
            if (result.length !== valueRowIds.length) {
                throw new Error('Invalid row ids.');
            }
            return result;
        }
        // Dispatches to the ROW_SPLITS or VALUE_ROWIDS flavor for `dimension`.
        calculateOutputIndex(dimension, parentOutputIndex, outputIndexMultiplier, outputSize) {
            const rowPartitionTensor = this.getRowPartitionTensor(dimension);
            const partitionType = this.getRowPartitionTypeByDimension(dimension);
            switch (partitionType) {
                case RowPartitionType.VALUE_ROWIDS:
                    return this.calculateOutputIndexValueRowID(rowPartitionTensor, parentOutputIndex, outputIndexMultiplier, outputSize);
                case RowPartitionType.ROW_SPLITS:
                    if (rowPartitionTensor.length - 1 > parentOutputIndex.length) {
                        throw new Error(`Row partition size is greater than output size: ${rowPartitionTensor.length - 1} > ${parentOutputIndex.length}`);
                    }
                    return this.calculateOutputIndexRowSplit(rowPartitionTensor, parentOutputIndex, outputIndexMultiplier, outputSize);
                default:
                    throw new Error(`Unsupported partition type: ${RowPartitionType[partitionType]}`);
            }
        }
        // Number of rows in the outermost dimension, derived from the first
        // partition tensor according to its type.
        getFirstDimensionSize() {
            const firstPartitionTensor = this.rowPartitionValues[0];
            if (this.rowPartitionTypes.length === 0) {
                throw new Error('No row_partition_types given.');
            }
            const firstPartitionType = this.rowPartitionTypes[0];
            switch (firstPartitionType) {
                case RowPartitionType.FIRST_DIM_SIZE:
                    return firstPartitionTensor[0];
                case RowPartitionType.VALUE_ROWIDS:
                    throw new Error('Cannot handle VALUE_ROWIDS in first dimension.');
                case RowPartitionType.ROW_SPLITS:
                    return this.rowPartitionValuesShapes[0][0] - 1;
                default:
                    throw new Error(`Cannot handle type ${RowPartitionType[firstPartitionType]}`);
            }
        }
        // Runs the conversion; returns [outputShape, outputTensor(flat values)].
        compute() {
            const firstPartitionTensor = this.rowPartitionValues[0];
            if (firstPartitionTensor.length <= 0) {
                throw new Error('Invalid first partition input. ' +
                    'Tensor requires at least one element.');
            }
            const firstDimension = this.getFirstDimensionSize();
            const outputSize = this.calculateOutputSize(firstDimension);
            // multiplier[i] = flat stride of dimension i in the output tensor.
            const multiplier = new Array(this.raggedRank + 1);
            multiplier[multiplier.length - 1] = 1;
            for (let i = multiplier.length - 2; i >= 0; --i) {
                multiplier[i] = multiplier[i + 1] * outputSize[i + 1];
            }
            // Full size of the tensor.
            const outputShape = makeShape(outputSize, false);
            const outputTensor = getArrayFromDType(this.valuesDType, sizeFromShape(outputShape));
            const fullSize = multiplier[0] * outputSize[0];
            if (fullSize > 0) {
                // Fold each ragged level into flat output indices, innermost last.
                let outputIndex = this.calculateFirstParentOutputIndex(firstDimension, multiplier[0], outputSize[0]);
                for (let i = 1; i <= this.raggedRank; ++i) {
                    const newOutputIndex = this.calculateOutputIndex(i - 1, outputIndex, multiplier[i], outputSize[i]);
                    outputIndex = newOutputIndex;
                }
                this.setOutput(this.raggedRank, outputIndex, outputTensor, outputShape);
            }
            return [outputShape, outputTensor];
        }
        // Scatters `this.values` into `outputTensor` at the flat positions in
        // `outputIndex`, padding every uncovered position with the default value.
        setOutput(raggedRank, outputIndex, outputTensor, outputShape) {
            if (outputTensor.length === 0) {
                return;
            }
            const valuesBase = this.values;
            const outputBase = outputTensor;
            let elementShape = outputShape.slice();
            elementShape = elementShape.slice(raggedRank + 1);
            const valueElementSize = sizeFromShape(elementShape);
            const outputIndexSize = outputIndex.length;
            // Broadcast the default value to value_element_size. (We can skip this
            // if defaultValueTensor.size == 1, since we use fill when that's true.)
            let defaultValue = this.defaultValue;
            if (defaultValue.length !== valueElementSize && defaultValue.length !== 1) {
                const srcShape = this.defaultValueShape;
                tidy(() => {
                    const defaultValueTensor = reshape$3(defaultValue, srcShape);
                    const bCastDefault = broadcastTo(defaultValueTensor, elementShape);
                    defaultValue = bCastDefault.dataSync();
                });
            }
            // Loop through the outputIndex array, finding contiguous regions that
            // should be copied. Once we find the end of a contiguous region, copy it
            // and add any necessary padding (with defaultValue).
            let srcStart = 0; // Start of contiguous region (in values)
            let dstStart = 0; // Destination for contiguous region (in output)
            let dstEnd = 0; // Destination for contiguous region (in output)
            for (let srcI = 0; srcI <= outputIndexSize; ++srcI) {
                // dstI is the destination where the value at srcI should be copied.
                let dstI = srcI < outputIndexSize ? outputIndex[srcI] : -1;
                // If we're still in a contiguous region, then update dstEnd go to the
                // next srcI.
                if (dstI === dstEnd) {
                    ++dstEnd;
                    continue;
                }
                // We found the end of contiguous region. This can be because we found
                // a gap (dstI > dstEnd), or a source value that shouldn't be copied
                // because it's out-of-bounds (dstI == -1), or the end of the tensor
                // (dstI === -1).
                if (dstStart < dstEnd) {
                    // Copy the contiguous region.
                    const src = valuesBase.subarray(srcStart * valueElementSize);
                    const dst = outputBase.subarray(dstStart * valueElementSize);
                    const nVals = (dstEnd - dstStart) * valueElementSize;
                    copyArray(dst, src, nVals);
                }
                // Add any necessary padding (w/ defaultValue).
                if (srcI >= outputIndexSize) {
                    // We reached the end of values: pad to the end of output.
                    const outputSize = outputTensor.length;
                    dstI = Math.floor(outputSize / valueElementSize);
                }
                if (dstI > dstEnd) {
                    if (this.defaultValue.length === 1) {
                        outputBase
                            .subarray(dstEnd * valueElementSize, dstI * valueElementSize)
                            .fill(this.defaultValue[0]);
                        dstEnd = dstI;
                    }
                    else {
                        // NOTE(review): `.slice` on a TypedArray returns a copy, so
                        // these writes may not reach `outputTensor` (contrast with the
                        // `.subarray` branch above). Verify against the upstream
                        // kernel before relying on multi-element default padding.
                        while (dstI > dstEnd) {
                            const dst = outputBase.slice(dstEnd * valueElementSize);
                            copyArray(dst, defaultValue, valueElementSize);
                            ++dstEnd;
                        }
                    }
                }
                // Update indices.
                if (dstI < 0) {
                    // srcI should be skipped -- leave it out of the contiguous region.
                    srcStart = srcI + 1;
                    dstStart = dstEnd;
                }
                else {
                    // srcI should be copied -- include it in the contiguous region.
                    srcStart = srcI;
                    dstStart = dstEnd;
                    dstEnd = dstStart + 1;
                }
            }
        }
    }
74135 function copyArray(dst, src, size) {
74136 for (let i = 0; i < size; i++) {
74137 dst[i] = src[i];
74138 }
74139 }
74140 function makeShape(shape, isPartial) {
74141 const out = [];
74142 for (let dim of shape) {
74143 if (dim < 0) {
74144 if (!isPartial) {
74145 throw new Error(`Dimension ${dim} must be >= 0`);
74146 }
74147 if (dim < -1) {
74148 throw new Error(`Dimension ${dim} must be >= -1`);
74149 }
74150 dim = -1;
74151 }
74152 out.push(dim);
74153 }
74154 return out;
74155 }
74156 function raggedTensorToTensorImpl(shape, shapesShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes) {
74157 return new RaggedTensorToTensorOp(shape, shapesShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes)
74158 .compute();
74159 }
74160
74161 /**
74162 * @license
74163 * Copyright 2020 Google LLC. All Rights Reserved.
74164 * Licensed under the Apache License, Version 2.0 (the "License");
74165 * you may not use this file except in compliance with the License.
74166 * You may obtain a copy of the License at
74167 *
74168 * http://www.apache.org/licenses/LICENSE-2.0
74169 *
74170 * Unless required by applicable law or agreed to in writing, software
74171 * distributed under the License is distributed on an "AS IS" BASIS,
74172 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74173 * See the License for the specific language governing permissions and
74174 * limitations under the License.
74175 * =============================================================================
74176 */
74177 function rangeImpl(start, stop, step, dtype) {
74178 const sameStartStop = start === stop;
74179 const increasingRangeNegativeStep = start < stop && step < 0;
74180 const decreasingRangePositiveStep = stop < start && step > 1;
74181 if (sameStartStop || increasingRangeNegativeStep ||
74182 decreasingRangePositiveStep) {
74183 return makeZerosTypedArray(0, dtype);
74184 }
74185 const numElements = Math.abs(Math.ceil((stop - start) / step));
74186 const values = makeZerosTypedArray(numElements, dtype);
74187 if (stop < start && step === 1) {
74188 // Auto adjust the step's sign if it hasn't been set
74189 // (or was set to 1)
74190 step = -1;
74191 }
74192 values[0] = start;
74193 for (let i = 1; i < values.length; i++) {
74194 values[i] = values[i - 1] + step;
74195 }
74196 return values;
74197 }
74198
74199 /**
74200 * @license
74201 * Copyright 2020 Google LLC. All Rights Reserved.
74202 * Licensed under the Apache License, Version 2.0 (the License);
74203 * you may not use this file except in compliance with the License.
74204 * You may obtain a copy of the License at
74205 *
74206 * http://www.apache.org/licenses/LICENSE-2.0
74207 *
74208 * Unless required by applicable law or agreed to in writing, software
74209 * distributed under the License is distributed on an AS IS BASIS,
74210 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74211 * See the License for the specific language governing permissions and
74212 * limitations under the License.
74213 * =============================================================================
74214 */
    // rsqrt(x) = 1 / sqrt(x), applied element-wise.
    const rsqrtImpl = createSimpleUnaryImpl((xi) => 1 / Math.sqrt(xi));
    const rsqrt$1 = unaryKernelFuncFromImpl(Rsqrt, rsqrtImpl);
    // Kernel registration config for the Rsqrt op on the CPU backend.
    const rsqrtConfig$1 = {
        kernelName: Rsqrt,
        backendName: 'cpu',
        kernelFunc: rsqrt$1,
    };
74222
74223 /**
74224 * @license
74225 * Copyright 2020 Google LLC. All Rights Reserved.
74226 * Licensed under the Apache License, Version 2.0 (the "License");
74227 * you may not use this file except in compliance with the License.
74228 * You may obtain a copy of the License at
74229 *
74230 * http://www.apache.org/licenses/LICENSE-2.0
74231 *
74232 * Unless required by applicable law or agreed to in writing, software
74233 * distributed under the License is distributed on an "AS IS" BASIS,
74234 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74235 * See the License for the specific language governing permissions and
74236 * limitations under the License.
74237 * =============================================================================
74238 */
    // Scatters `updates` into a buffer of `shape` at positions given by
    // `indices` (an [numUpdates, sliceRank] tensor of slice coordinates).
    // `defaultValue` seeds uncovered positions, or, when it is already a
    // TensorBuffer, serves as the output buffer itself. With `sumDupeIndices`
    // duplicate indices accumulate; otherwise later updates overwrite earlier
    // ones. Returns the (flattened) output buffer; throws on out-of-range
    // indices.
    function scatterImpl(indices, updates, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, defaultValue, sumDupeIndices) {
        // View the output as [numSlices, sliceSize].
        const flattenShape = [outputSize / sliceSize, sliceSize];
        const indicesData = indices.values;
        const updatesData = updates.values;
        if (outputSize === 0) {
            return buffer(shape, updates.dtype);
        }
        const outBuf = (defaultValue instanceof TensorBuffer) ?
            defaultValue :
            buffer(flattenShape, updates.dtype);
        // Seed the buffer with the scalar default (a TensorBuffer default is
        // assumed to be pre-filled, so no branch matches it).
        if (typeof defaultValue === 'string') {
            outBuf.values.fill(defaultValue);
        }
        else if (typeof defaultValue === 'number') {
            outBuf.values.fill(defaultValue);
        }
        else if (typeof defaultValue === 'boolean') {
            // Coerce boolean to 0/1 for numeric storage.
            outBuf.values.fill(+defaultValue);
        }
        for (let i = 0; i < numUpdates; i++) {
            const index = [];
            // Resolve the i-th slice coordinate to a flat slice offset.
            let flattenIndex = 0;
            for (let j = 0; j < sliceRank; j++) {
                const dim = indicesData[i * sliceRank + j];
                index.push(dim);
                flattenIndex += dim * strides[j];
            }
            if (flattenIndex < 0 || flattenIndex >= outputSize / sliceSize) {
                throw new Error(`Invalid indices: ${index} does not index into ${shape}`);
            }
            for (let k = 0; k < sliceSize; k++) {
                if (sumDupeIndices) {
                    outBuf.values[flattenIndex * sliceSize + k] +=
                        updatesData[i * sliceSize + k];
                }
                else {
                    // Rank-0 updates broadcast a single scalar to every slice.
                    outBuf.values[flattenIndex * sliceSize + k] = updates.rank === 0 ?
                        updatesData[0] :
                        updatesData[i * sliceSize + k];
                }
            }
        }
        return outBuf;
    }
74283
74284 /**
74285 * @license
74286 * Copyright 2020 Google LLC. All Rights Reserved.
74287 * Licensed under the Apache License, Version 2.0 (the License);
74288 * you may not use this file except in compliance with the License.
74289 * You may obtain a copy of the License at
74290 *
74291 * http://www.apache.org/licenses/LICENSE-2.0
74292 *
74293 * Unless required by applicable law or agreed to in writing, software
74294 * distributed under the License is distributed on an AS IS BASIS,
74295 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74296 * See the License for the specific language governing permissions and
74297 * limitations under the License.
74298 * =============================================================================
74299 */
    // sigmoid(x) = 1 / (1 + exp(-x)), applied element-wise.
    const sigmoidImpl = createSimpleUnaryImpl((xi) => 1 / (1 + Math.exp(-xi)));
    // NOTE(review): unlike rsqrt above, the kernel restates the lambda instead
    // of wrapping sigmoidImpl; both compute the same formula.
    const sigmoid$1 = unaryKernelFunc$1(Sigmoid$1, (xi) => 1 / (1 + Math.exp(-xi)));
    // Kernel registration config for the Sigmoid op on the CPU backend.
    const sigmoidConfig$1 = {
        kernelName: Sigmoid$1,
        backendName: 'cpu',
        kernelFunc: sigmoid$1,
    };
74307
74308 /**
74309 * @license
74310 * Copyright 2020 Google LLC. All Rights Reserved.
74311 * Licensed under the Apache License, Version 2.0 (the "License");
74312 * you may not use this file except in compliance with the License.
74313 * You may obtain a copy of the License at
74314 *
74315 * http://www.apache.org/licenses/LICENSE-2.0
74316 *
74317 * Unless required by applicable law or agreed to in writing, software
74318 * distributed under the License is distributed on an "AS IS" BASIS,
74319 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74320 * See the License for the specific language governing permissions and
74321 * limitations under the License.
74322 * =============================================================================
74323 */
74324 function sliceImpl(vals, begin, size, shape, dtype) {
74325 const isContinous = isSliceContinous(shape, begin, size);
74326 const length = sizeFromShape(size);
74327 const xStrides = computeStrides(shape);
74328 if (isContinous) {
74329 const flatOffset = computeFlatOffset(begin, xStrides);
74330 if (dtype === 'string') {
74331 return vals.slice(flatOffset, flatOffset + length);
74332 }
74333 return vals.subarray(flatOffset, flatOffset + length);
74334 }
74335 const decodedData = dtype === 'string' ?
74336 fromUint8ToStringArray(vals) :
74337 vals;
74338 const inBuf = buffer(shape, dtype, decodedData);
74339 const outBuf = buffer(size, dtype);
74340 for (let i = 0; i < outBuf.size; ++i) {
74341 const outLoc = outBuf.indexToLoc(i);
74342 const inLoc = outLoc.map((idx, j) => idx + begin[j]);
74343 outBuf.set(inBuf.get(...inLoc), ...outLoc);
74344 }
74345 if (dtype === 'string') {
74346 return fromStringArrayToUint8(outBuf.values);
74347 }
74348 return outBuf.values;
74349 }
74350 function slice$1(args) {
74351 const { inputs, backend, attrs } = args;
74352 const { x } = inputs;
74353 const { begin, size } = attrs;
74354 assertNotComplex$1(x, 'slice');
74355 const [$begin, $size] = parseSliceParams(x, begin, size);
74356 assertParamsValid(x, $begin, $size);
74357 const vals = backend.data.get(x.dataId).values;
74358 const outVals = sliceImpl(vals, $begin, $size, x.shape, x.dtype);
74359 return backend.makeTensorInfo($size, x.dtype, outVals);
74360 }
    // Kernel registration config for the Slice op on the CPU backend.
    const sliceConfig$1 = {
        kernelName: Slice,
        backendName: 'cpu',
        kernelFunc: slice$1
    };
74366
74367 /**
74368 * @license
74369 * Copyright 2021 Google LLC. All Rights Reserved.
74370 * Licensed under the Apache License, Version 2.0 (the "License");
74371 * you may not use this file except in compliance with the License.
74372 * You may obtain a copy of the License at
74373 *
74374 * http://www.apache.org/licenses/LICENSE-2.0
74375 *
74376 * Unless required by applicable law or agreed to in writing, software
74377 * distributed under the License is distributed on an "AS IS" BASIS,
74378 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74379 * See the License for the specific language governing permissions and
74380 * limitations under the License.
74381 * =============================================================================
74382 */
/**
 * Shared implementation of SparseFillEmptyRows: rows of the sparse tensor
 * that contain no values get a single entry with `defaultValue`.
 *
 * @param indices Flattened [N, rank] indices; the first column of each row
 *     is that value's dense row.
 * @param indicesShape Shape of `indices`, i.e. [N, rank].
 * @param indicesDType DType used to allocate the output indices.
 * @param values The N values of the sparse tensor.
 * @param valuesDType DType used to allocate the output values.
 * @param denseShape Dense shape; only denseShape[0] (row count) is read.
 * @param defaultValue Value inserted into rows that have no entries.
 * @returns [outputIndices, outputIndicesShape, outputValues,
 *     emptyRowIndicator, reverseIndexMap], where emptyRowIndicator[row]
 *     marks originally-empty rows and reverseIndexMap[i] is the output
 *     position of input value i (used for backprop).
 */
function sparseFillEmptyRowsImpl(indices, indicesShape, indicesDType, values, valuesDType, denseShape, defaultValue) {
    const indicesCount = indicesShape[0];
    const denseRows = denseShape[0];
    const emptyRowIndicator = new Array(denseRows);
    const reverseIndexMap = new Array(indicesCount);
    const rank = indicesShape[1];
    // Degenerate case: zero dense rows implies there can be no values.
    if (denseRows === 0) {
        if (indicesCount !== 0) {
            throw new Error(getSparseFillEmptyRowsIndicesDenseShapeMismatch(indicesCount));
        }
        const outputIndices = getArrayFromDType(indicesDType, 0);
        const outputValues = getArrayFromDType(valuesDType, 0);
        return [
            outputIndices, [0, rank], outputValues, emptyRowIndicator, reverseIndexMap
        ];
    }
    let rowsAreOrdered = true;
    let lastIndicesRow = 0;
    const csrOffset = new Array(denseRows).fill(0);
    // First pass: count values per row and validate every row index.
    for (let i = 0; i < indicesCount; ++i) {
        // indices is a 2d tensor with shape of [N, rank]
        const row = indices[i * rank];
        if (row < 0) {
            throw new Error(getSparseFillEmptyRowsNegativeIndexErrorMessage(i, row));
        }
        if (row >= denseRows) {
            throw new Error(getSparseFillEmptyRowsOutOfRangeIndexErrorMessage(i, row, denseRows));
        }
        ++csrOffset[row];
        rowsAreOrdered = rowsAreOrdered && (row >= lastIndicesRow);
        lastIndicesRow = row;
    }
    let allRowsFull = true;
    // Second pass: turn per-row counts into CSR-style cumulative offsets.
    for (let row = 0; row < denseRows; ++row) {
        // csrOffset here describes the number of elements in this dense row
        const rowEmpty = (csrOffset[row] === 0);
        emptyRowIndicator[row] = rowEmpty;
        allRowsFull = allRowsFull && !rowEmpty;
        // In filled version, each row has at least one element.
        csrOffset[row] = Math.max(csrOffset[row], 1);
        // Update csrOffset to represent the number of elements up to and
        // including denseRows + 1:
        // csrOffset[0] == #{elements of row 0}
        // csrOffset[1] == #{elements of row 1} + #{elements of row 0}
        // ..
        // csrOffset[i] == starting index for elements in row i + 1.
        if (row > 0) {
            csrOffset[row] += csrOffset[row - 1];
        }
    }
    // Fast path: nothing to fill and the input is already row-ordered, so the
    // input can be returned unchanged with an identity reverse map.
    if (allRowsFull && rowsAreOrdered) {
        const outputIndices = indices;
        const outputValues = values;
        for (let i = 0; i < indicesCount; ++i) {
            reverseIndexMap[i] = i;
        }
        return [
            outputIndices, [indicesCount, rank], outputValues, emptyRowIndicator,
            reverseIndexMap
        ];
    }
    else {
        // csrOffset[denseRows - 1] is the total element count after filling.
        const fullIndicesCount = csrOffset[denseRows - 1];
        const outputIndices = getArrayFromDType(indicesDType, fullIndicesCount * rank);
        const outputValues = getArrayFromDType(valuesDType, fullIndicesCount);
        const filledCount = new Array(denseRows).fill(0);
        // Fill in values for rows that are not missing
        for (let i = 0; i < indicesCount; ++i) {
            // indices is a 2d tensor with shape of [N, rank]
            const row = indices[i * rank];
            const offset = filledCount[row];
            // Row `row`'s block starts at csrOffset[row - 1] (0 for row 0).
            const outputI = ((row === 0) ? 0 : csrOffset[row - 1]) + offset;
            filledCount[row]++; // Increment the filled count for this row.
            for (let j = 0; j < rank; ++j) {
                // indices and outputIndices are 2d tensors with shape of [N, rank]
                outputIndices[outputI * rank + j] = indices[i * rank + j];
            }
            outputValues[outputI] = values[i];
            // We'll need this reverse index map to backprop correctly.
            reverseIndexMap[i] = outputI;
        }
        // Fill in values for rows that are missing
        for (let row = 0; row < denseRows; ++row) {
            const rowCount = filledCount[row];
            if (rowCount === 0) { // We haven't filled this row
                const startingIndex = (row === 0) ? 0 : csrOffset[row - 1];
                // Remaining index values were set to zero already.
                // Just need to set the row index in the right location.
                // outputIndices is a 2d tensor with shape of [N, rank]
                outputIndices[startingIndex * rank + 0] = row;
                for (let col = 1; col < rank; ++col) {
                    outputIndices[startingIndex * rank + col] = 0;
                }
                outputValues[startingIndex] = defaultValue;
            }
        }
        return [
            outputIndices, [fullIndicesCount, rank], outputValues, emptyRowIndicator,
            reverseIndexMap
        ];
    }
}
74485
74486 /**
74487 * @license
74488 * Copyright 2021 Google LLC. All Rights Reserved.
74489 * Licensed under the Apache License, Version 2.0 (the "License");
74490 * you may not use this file except in compliance with the License.
74491 * You may obtain a copy of the License at
74492 *
74493 * http://www.apache.org/licenses/LICENSE-2.0
74494 *
74495 * Unless required by applicable law or agreed to in writing, software
74496 * distributed under the License is distributed on an "AS IS" BASIS,
74497 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74498 * See the License for the specific language governing permissions and
74499 * limitations under the License.
74500 * =============================================================================
74501 */
/**
 * Shared implementation of SparseReshape: maps each sparse index from the
 * input shape's coordinate system to the target shape's, preserving the
 * flat (row-major) position of every value.
 *
 * @param inputIndices Flattened [nnz, inputRank] indices.
 * @param inputIndicesShape Shape of `inputIndices`, i.e. [nnz, inputRank].
 * @param inputDType DType used to allocate the new indices.
 * @param inputShape Dense shape of the input sparse tensor.
 * @param targetShape Requested output shape; at most one dimension may be
 *     -1, in which case it is inferred from the remaining dimensions.
 * @returns [newIndices, newIndicesShape, outputShape].
 */
function sparseReshapeImpl(inputIndices, inputIndicesShape, inputDType, inputShape, targetShape) {
    const denseSize = sizeFromShape(inputShape);
    const nnz = inputIndicesShape[0];
    const outputRank = targetShape.length;
    // Compute the output shape. Determine product of specified dimensions, and
    // find the index of the unspecified one.
    const outputShape = [];
    let product = 1;
    let unknownIndex = -1;
    for (let d = 0; d < outputRank; ++d) {
        const size = targetShape[d];
        if (size === -1) {
            if (unknownIndex !== -1) {
                throw new Error(getSparseReshapeMultipleNegativeOneOutputDimErrorMessage(unknownIndex, d));
            }
            unknownIndex = d;
            outputShape.push(1);
        }
        else {
            if (size < 0) {
                throw new Error(getSparseReshapeNegativeOutputDimErrorMessage(d, size));
            }
            product *= size;
            outputShape.push(size);
        }
    }
    // Infer the -1 dimension; it must divide the total size evenly.
    if (unknownIndex !== -1) {
        if (product <= 0) {
            throw new Error(getSparseReshapeEmptyTensorZeroOutputDimErrorMessage());
        }
        const missing = Math.trunc(denseSize / product);
        if (product * missing !== denseSize) {
            throw new Error(getSparseReshapeInputOutputMultipleErrorMessage(inputShape, outputShape));
        }
        outputShape[unknownIndex] = missing;
    }
    const outputSize = sizeFromShape(outputShape);
    if (outputSize !== denseSize) {
        throw new Error(getSparseReshapeInputOutputMismatchErrorMessage(inputShape, outputShape));
    }
    // Row-major strides for the input shape.
    const inputRank = inputShape.length;
    const inputStrides = [];
    if (inputRank > 0) {
        inputStrides[inputRank - 1] = 1;
        for (let d = inputRank - 2; d >= 0; --d) {
            inputStrides[d] = inputStrides[d + 1] * inputShape[d + 1];
        }
    }
    // Row-major strides for the output shape.
    const outputStrides = [];
    if (outputRank > 0) {
        outputStrides[outputRank - 1] = 1;
        for (let d = outputRank - 2; d >= 0; --d) {
            outputStrides[d] = outputStrides[d + 1] * outputShape[d + 1];
        }
    }
    const newIndices = getArrayFromDType(inputDType, nnz * outputRank);
    // For each value: flatten its input coordinate to a linear id, then
    // expand the id back out into output coordinates.
    for (let i = 0; i < nnz; ++i) {
        let id = 0;
        for (let j = 0; j < inputRank; ++j) {
            // inputIndices is a 2d tensor with shape of [nnz, inputRank]
            id += inputIndices[i * inputRank + j] * inputStrides[j];
        }
        for (let j = 0; j < outputRank; ++j) {
            // newIndices is a 2d tensor with shape of [nnz, outputRank]
            newIndices[i * outputRank + j] = Math.trunc(id / outputStrides[j]);
            id %= outputStrides[j];
        }
    }
    return [newIndices, [nnz, outputRank], outputShape];
}
74572
74573 /**
74574 * @license
74575 * Copyright 2021 Google LLC. All Rights Reserved.
74576 * Licensed under the Apache License, Version 2.0 (the "License");
74577 * you may not use this file except in compliance with the License.
74578 * You may obtain a copy of the License at
74579 *
74580 * http://www.apache.org/licenses/LICENSE-2.0
74581 *
74582 * Unless required by applicable law or agreed to in writing, software
74583 * distributed under the License is distributed on an "AS IS" BASIS,
74584 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74585 * See the License for the specific language governing permissions and
74586 * limitations under the License.
74587 * =============================================================================
74588 */
/**
 * Shared implementation of sparse segment sum/mean: gathers rows of the
 * (flattened-to-2d) input selected by `indices` and reduces them per
 * segment id.
 *
 * @param input Flattened input values.
 * @param inputShape Shape of the input; inputShape[0] is the row count.
 * @param inputDType DType used to allocate the output.
 * @param indices One input-row index per entry.
 * @param segmentIds Segment id per entry; must be sorted non-decreasing.
 * @param isMean If true, divide each segment's sum by its entry count.
 * @param defaultValue Value written into segments that have no entries.
 * @returns [output, outputShape].
 */
function sparseSegmentReductionImpl(input, inputShape, inputDType, indices, segmentIds, isMean = false, defaultValue = 0) {
    const numIndices = indices.length;
    // Flatten the array to two dimensions
    const inputFlat = [inputShape[0], input.length / inputShape[0]];
    const numCol = inputFlat[1];
    // Note that the current implementation assumes that segmentIds values are
    // sorted.
    const lastSegmentIdPlusOne = numIndices > 0 ? segmentIds[numIndices - 1] + 1 : 0;
    // The output has one row per segment id, up to the largest id seen.
    const outputRows = lastSegmentIdPlusOne;
    if (outputRows < 0) {
        throw new Error(getSparseSegmentReductionNegativeSegmentIdsErrorMessage());
    }
    const outputShape = inputShape.slice();
    outputShape[0] = outputRows;
    const outputLength = outputShape.reduce((product, value) => product * value, 1);
    // Output array is initialized with the value 0 by default.
    const output = getArrayFromDType(inputDType, outputLength);
    // Note that we do not initialize the output buffer with a default value, so
    // we need to explicitly set missing indices to the default value.
    if (numIndices === 0) {
        if (outputRows > 0) {
            output.fill(defaultValue);
        }
        return [output, outputShape];
    }
    if (outputRows <= 0) {
        throw new Error(getSparseSegmentReductionNegativeSegmentIdsErrorMessage());
    }
    // [start, end) brackets the run of entries sharing the current segment id.
    let start = 0, end = 1;
    // Index from which the output is not initialized.
    let uninitializedIndex = 0;
    let outIndex = segmentIds[start];
    while (true) {
        // We initialize nextIndex to 0 to avoid may be uninitialized warning
        let nextIndex = 0;
        if (end < numIndices) {
            nextIndex = segmentIds[end];
            if (outIndex === nextIndex) {
                // Still inside the current segment; extend the run.
                ++end;
                continue;
            }
            // We have a new segment here. Verify that the segment ids are growing.
            if (outIndex >= nextIndex) {
                throw new Error(getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage());
            }
        }
        if (outIndex < 0 || outIndex >= outputRows) {
            throw new Error(getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage(outIndex, outputRows));
        }
        // If there is a gap between two indices, we need to set that gap to the
        // default value.
        if (outIndex > uninitializedIndex) {
            output.fill(defaultValue, uninitializedIndex * numCol, outIndex * numCol);
        }
        // Accumulate the gathered rows of this segment into the output row.
        for (let i = start; i < end; ++i) {
            const index = indices[i];
            if (index < 0 || index >= inputFlat[0]) {
                throw new Error(getSparseSegmentReductionIndicesOutOfRangeErrorMessage(i, indices[i], inputFlat[0]));
            }
            for (let j = 0; j < numCol; j++) {
                output[outIndex * numCol + j] += input[index * numCol + j];
            }
        }
        if (isMean) {
            // `end - start` is the number of entries in this segment.
            for (let j = 0; j < numCol; j++) {
                output[outIndex * numCol + j] /= end - start;
            }
        }
        // Advance to the next segment's run.
        start = end;
        ++end;
        uninitializedIndex = outIndex + 1;
        outIndex = nextIndex;
        if (end > numIndices) {
            break;
        }
    }
    // Fill the gap at the end with the default value.
    if (uninitializedIndex < outputRows) {
        output.fill(defaultValue, uninitializedIndex * numCol, outputRows * numCol);
    }
    return [output, outputShape];
}
74671
74672 /**
74673 * @license
74674 * Copyright 2020 Google LLC. All Rights Reserved.
74675 * Licensed under the Apache License, Version 2.0 (the "License");
74676 * you may not use this file except in compliance with the License.
74677 * You may obtain a copy of the License at
74678 *
74679 * http://www.apache.org/licenses/LICENSE-2.0
74680 *
74681 * Unless required by applicable law or agreed to in writing, software
74682 * distributed under the License is distributed on an "AS IS" BASIS,
74683 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74684 * See the License for the specific language governing permissions and
74685 * limitations under the License.
74686 * =============================================================================
74687 */
// Element-wise square root: shared impl, CPU kernel, and its registration
// config.
const sqrtImpl = createSimpleUnaryImpl((xi) => Math.sqrt(xi));
const sqrt$1 = unaryKernelFunc$1(Sqrt, (xi) => Math.sqrt(xi));
const sqrtConfig$1 = {
    kernelName: Sqrt,
    backendName: 'cpu',
    kernelFunc: sqrt$1,
};
74695
74696 /**
74697 * @license
74698 * Copyright 2020 Google LLC. All Rights Reserved.
74699 * Licensed under the Apache License, Version 2.0 (the "License");
74700 * you may not use this file except in compliance with the License.
74701 * You may obtain a copy of the License at
74702 *
74703 * http://www.apache.org/licenses/LICENSE-2.0
74704 *
74705 * Unless required by applicable law or agreed to in writing, software
74706 * distributed under the License is distributed on an "AS IS" BASIS,
74707 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74708 * See the License for the specific language governing permissions and
74709 * limitations under the License.
74710 * =============================================================================
74711 */
// Element-wise (a - b)^2: shared impl, CPU kernel, and its registration
// config.
const squaredDifferenceImpl = createSimpleBinaryKernelImpl(((a, b) => {
    const diff = a - b;
    return diff * diff;
}));
const squaredDifference$1 = binaryKernelFunc$1(SquaredDifference, squaredDifferenceImpl);
const squaredDifferenceConfig$1 = {
    kernelName: SquaredDifference,
    backendName: 'cpu',
    kernelFunc: squaredDifference$1
};
74722
74723 /**
74724 * @license
74725 * Copyright 2023 Google LLC.
74726 * Licensed under the Apache License, Version 2.0 (the "License");
74727 * you may not use this file except in compliance with the License.
74728 * You may obtain a copy of the License at
74729 *
74730 * http://www.apache.org/licenses/LICENSE-2.0
74731 *
74732 * Unless required by applicable law or agreed to in writing, software
74733 * distributed under the License is distributed on an "AS IS" BASIS,
74734 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74735 * See the License for the specific language governing permissions and
74736 * limitations under the License.
74737 * =============================================================================
74738 */
// Element-wise regex replacement for string tensors. `pattern` and `rewrite`
// come from the op's attributes; `replaceGlobal` toggles the 'g' flag
// (replace all matches vs. only the first).
const staticRegexReplaceImpl = createSimpleUnaryImpl((x, attrs) => {
    const { pattern, replaceGlobal, rewrite } = attrs;
    // TODO(mattSoulanille): Don't create a regex each time.
    return x.replace(new RegExp(pattern, replaceGlobal ? 'g' : ''), rewrite);
});
const staticRegexReplace$1 = unaryKernelFuncFromImpl(StaticRegexReplace, staticRegexReplaceImpl);
const staticRegexReplaceConfig$1 = {
    kernelName: StaticRegexReplace,
    backendName: 'cpu',
    kernelFunc: staticRegexReplace$1,
};
74750
74751 /**
74752 * @license
74753 * Copyright 2020 Google LLC. All Rights Reserved.
74754 * Licensed under the Apache License, Version 2.0 (the "License");
74755 * you may not use this file except in compliance with the License.
74756 * You may obtain a copy of the License at
74757 *
74758 * http://www.apache.org/licenses/LICENSE-2.0
74759 *
74760 * Unless required by applicable law or agreed to in writing, software
74761 * distributed under the License is distributed on an "AS IS" BASIS,
74762 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74763 * See the License for the specific language governing permissions and
74764 * limitations under the License.
74765 * =============================================================================
74766 */
/**
 * Gathers elements of `xBuf` into a buffer of shape `outShape`, starting at
 * `begin` and stepping by `strides` along each axis.
 */
function stridedSliceImpl(outShape, xBuf, strides, begin) {
    const result = buffer(outShape, xBuf.dtype);
    for (let outIdx = 0; outIdx < result.size; outIdx++) {
        const outLoc = result.indexToLoc(outIdx);
        // Map the output coordinate back to its source coordinate.
        const srcLoc = outLoc.map((coord, axis) => coord * strides[axis] + begin[axis]);
        result.set(xBuf.get(...srcLoc), ...outLoc);
    }
    return result;
}
74779
74780 /**
74781 * @license
74782 * Copyright 2021 Google LLC. All Rights Reserved.
74783 * Licensed under the Apache License, Version 2.0 (the "License");
74784 * you may not use this file except in compliance with the License.
74785 * You may obtain a copy of the License at
74786 *
74787 * http://www.apache.org/licenses/LICENSE-2.0
74788 *
74789 * Unless required by applicable law or agreed to in writing, software
74790 * distributed under the License is distributed on an "AS IS" BASIS,
74791 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74792 * See the License for the specific language governing permissions and
74793 * limitations under the License.
74794 * =============================================================================
74795 */
74796 /**
74797 * The StringNGramsOp class creates ngrams from ragged string data.
74798 * The constructor contains all attributes related to the operation such as
74799 * padding widths and strings, and the compute function can be used to
74800 * compute the ngrams for different ragged tensor inputs.
74801 */
class StringNGramsOp {
    /**
     * @param separator String joined between adjacent tokens within an ngram.
     * @param nGramWidths The ngram sizes (token counts) to produce.
     * @param leftPad Padding string for the start of each sequence.
     * @param rightPad Padding string for the end of each sequence.
     * @param padWidth Pad elements per side; negative selects maximal
     *     padding (nGramWidth - 1).
     * @param preserveShortSequences If true, a sequence too short for every
     *     requested width still yields one fully padded ngram.
     */
    constructor(separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) {
        // Strings are stored as encoded byte arrays (via encodeString).
        this.separator = encodeString(separator);
        this.nGramWidths = nGramWidths;
        this.leftPad = encodeString(leftPad);
        this.rightPad = encodeString(rightPad);
        this.padWidth = padWidth;
        this.preserveShort = preserveShortSequences;
    }
    // Effective padding per side for a given ngram width.
    getPadWidth(nGramWidth) {
        // Ngrams can be padded with either a fixed pad width or a dynamic pad
        // width depending on the 'padWidth' arg, but in no case should the padding
        // ever be wider than 'nGramWidth' - 1.
        return Math.min(this.padWidth < 0 ? nGramWidth - 1 : this.padWidth, nGramWidth - 1);
    }
    // Number of ngrams of width `nGramWidth` produced by a sequence of
    // `length` tokens, once padding is accounted for.
    getNumNGrams(length, nGramWidth) {
        const padWidth = this.getPadWidth(nGramWidth);
        return Math.max(0, ((length + 2 * padWidth) - nGramWidth) + 1);
    }
    /**
     * Writes `numNGrams` ngrams of width `nGramWidth` into `output` starting
     * at `outputStartIndex`. Tokens are read from `data`; `splitIndex` is
     * the offset of the current sequence's first token.
     */
    createNGrams(data, splitIndex, output, outputStartIndex, numNGrams, nGramWidth) {
        for (let nGramIndex = 0; nGramIndex < numNGrams; ++nGramIndex) {
            const padWidth = this.getPadWidth(nGramWidth);
            // How many pad elements this particular ngram needs on each side.
            const leftPadding = Math.max(0, padWidth - nGramIndex);
            const rightPadding = Math.max(0, padWidth - (numNGrams - (nGramIndex + 1)));
            const numTokens = nGramWidth - (leftPadding + rightPadding);
            const dataStartIndex = splitIndex + (leftPadding > 0 ? 0 : nGramIndex - padWidth);
            // Calculate the total expected size of the nGram so we can reserve the
            // correct amount of space in the string.
            let nGramSize = 0;
            // Size of the left padding.
            nGramSize += leftPadding * this.leftPad.length;
            // Size of the tokens.
            for (let n = 0; n < numTokens; ++n) {
                nGramSize += data[dataStartIndex + n].length;
            }
            // Size of the right padding.
            nGramSize += rightPadding * this.rightPad.length;
            // Size of the separators.
            const numSeparators = leftPadding + rightPadding + numTokens - 1;
            nGramSize += numSeparators * this.separator.length;
            // Build the nGram.
            output[outputStartIndex + nGramIndex] = new Uint8Array(nGramSize);
            const nGram = output[outputStartIndex + nGramIndex];
            let nextNGramIndex = 0;
            const appendToNGram = (str) => str.forEach((value) => nGram[nextNGramIndex++] = value);
            for (let n = 0; n < leftPadding; ++n) {
                appendToNGram(this.leftPad);
                appendToNGram(this.separator);
            }
            // Only output first numTokens - 1 pairs of data and separator
            for (let n = 0; n < numTokens - 1; ++n) {
                appendToNGram(data[dataStartIndex + n]);
                appendToNGram(this.separator);
            }
            // Handle case when there are no tokens or no right padding as these
            // can result in consecutive separators.
            if (numTokens > 0) {
                // If we have tokens, then output last and then pair each separator
                // with the right padding that follows, to ensure nGram ends either with
                // the token or with the right pad.
                appendToNGram(data[dataStartIndex + numTokens - 1]);
                for (let n = 0; n < rightPadding; ++n) {
                    appendToNGram(this.separator);
                    appendToNGram(this.rightPad);
                }
            }
            else {
                // If we don't have tokens, then the last item inserted into the nGram
                // has been the separator from the left padding loop above. Hence,
                // output right pad and separator and make sure to finish with a
                // padding, not a separator.
                for (let n = 0; n < rightPadding - 1; ++n) {
                    appendToNGram(this.rightPad);
                    appendToNGram(this.separator);
                }
                appendToNGram(this.rightPad);
            }
        }
    }
    // Data and splits together form the definition of the ragged tensor,
    // where data is 1 dimensional and contains the values of the tensor
    // and splits denotes the indices at which each row starts.
    compute(data, splits) {
        // Validate that the splits are valid indices into data, only if there are
        // splits specified.
        const inputDataSize = data.length;
        const splitsSize = splits.length;
        if (splitsSize > 0) {
            let prevSplit = splits[0];
            if (prevSplit !== 0) {
                throw new Error(`First split value must be 0, got ${prevSplit}`);
            }
            // Splits must be non-decreasing and within [0, inputDataSize].
            for (let i = 1; i < splitsSize; ++i) {
                let validSplits = splits[i] >= prevSplit;
                validSplits = validSplits && (splits[i] <= inputDataSize);
                if (!validSplits) {
                    throw new Error(`Invalid split value ${splits[i]}, must be in [${prevSplit}, ${inputDataSize}]`);
                }
                prevSplit = splits[i];
            }
            if (prevSplit !== inputDataSize) {
                throw new Error(`Last split value must be data size. Expected ${inputDataSize}, got ${prevSplit}`);
            }
        }
        const numBatchItems = splitsSize - 1;
        const nGramsSplits = getArrayFromDType('int32', splitsSize);
        // If there is no data or size, return an empty ragged tensor.
        if (inputDataSize === 0 || splitsSize === 0) {
            const empty = new Array(inputDataSize);
            for (let i = 0; i <= numBatchItems; ++i) {
                nGramsSplits[i] = 0;
            }
            return [empty, nGramsSplits];
        }
        // First pass: compute the output row splits (ngram count per batch item).
        nGramsSplits[0] = 0;
        for (let i = 1; i <= numBatchItems; ++i) {
            const length = splits[i] - splits[i - 1];
            let numNGrams = 0;
            this.nGramWidths.forEach((nGramWidth) => {
                numNGrams += this.getNumNGrams(length, nGramWidth);
            });
            if (this.preserveShort && length > 0 && numNGrams === 0) {
                numNGrams = 1;
            }
            nGramsSplits[i] = nGramsSplits[i - 1] + numNGrams;
        }
        // Second pass: materialize the ngrams for every batch item and width.
        const nGrams = new Array(nGramsSplits[numBatchItems]);
        for (let i = 0; i < numBatchItems; ++i) {
            const splitIndex = splits[i];
            let outputStartIdx = nGramsSplits[i];
            this.nGramWidths.forEach((nGramWidth) => {
                const length = splits[i + 1] - splits[i];
                const numNGrams = this.getNumNGrams(length, nGramWidth);
                this.createNGrams(data, splitIndex, nGrams, outputStartIdx, numNGrams, nGramWidth);
                outputStartIdx += numNGrams;
            });
            // If we're preserving short sequences, check to see if no sequence was
            // generated by comparing the current output start idx to the original
            // one (nGramSplitsdata). If no ngrams were generated, then they will
            // be equal (since we increment outputStartIdx by numNGrams every
            // time we create a set of ngrams.)
            if (this.preserveShort && outputStartIdx === nGramsSplits[i]) {
                const dataLength = splits[i + 1] - splits[i];
                // One legitimate reason to not have any ngrams when this.preserveShort
                // is true is if the sequence itself is empty. In that case, move on.
                if (dataLength === 0) {
                    continue;
                }
                // We don't have to worry about dynamic padding sizes here: if padding
                // was dynamic, every sequence would have had sufficient padding to
                // generate at least one nGram.
                const nGramWidth = dataLength + 2 * this.padWidth;
                const numNGrams = 1;
                this.createNGrams(data, splitIndex, nGrams, outputStartIdx, numNGrams, nGramWidth);
            }
        }
        return [nGrams, nGramsSplits];
    }
}
/**
 * Convenience wrapper: builds a StringNGramsOp from the given attributes and
 * runs it over the ragged input (`data` values + `dataSplits` row splits).
 */
function stringNGramsImpl(data, dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) {
    const op = new StringNGramsOp(separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences);
    return op.compute(data, dataSplits);
}
74965
74966 /**
74967 * @license
74968 * Copyright 2021 Google LLC. All Rights Reserved.
74969 * Licensed under the Apache License, Version 2.0 (the "License");
74970 * you may not use this file except in compliance with the License.
74971 * You may obtain a copy of the License at
74972 *
74973 * http://www.apache.org/licenses/LICENSE-2.0
74974 *
74975 * Unless required by applicable law or agreed to in writing, software
74976 * distributed under the License is distributed on an "AS IS" BASIS,
74977 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74978 * See the License for the specific language governing permissions and
74979 * limitations under the License.
74980 * =============================================================================
74981 */
/**
 * Tokenizes the encoded string `str` at the given delimiter characters,
 * pushing each token (a subarray view of `str`) onto `result`.
 * With an empty delimiter set, every character becomes its own token.
 * When `skipEmpty` is true, zero-length tokens are dropped.
 */
function split(str, delimiters, skipEmpty, result) {
    if (!str.length) {
        return;
    }
    // Empty delimiter set: emit each character as its own token.
    if (delimiters.length === 0) {
        for (let pos = 0; pos < str.length; ++pos) {
            result.push(str.subarray(pos, pos + 1));
        }
        return;
    }
    // Single delimiter: repeatedly cut at the next occurrence.
    if (delimiters.length === 1) {
        const delim = delimiters[0];
        let cut = str.indexOf(delim);
        while (cut !== -1) {
            const piece = str.subarray(0, cut);
            if (piece.length !== 0 || !skipEmpty) {
                result.push(piece);
            }
            str = str.subarray(cut + 1);
            cut = str.indexOf(delim);
        }
        // The tail after the last delimiter is also a token.
        if (str.length !== 0 || !skipEmpty) {
            result.push(str);
        }
        return;
    }
    // Multiple delimiters: cut wherever any one of them occurs.
    let tokenStart = 0;
    for (let pos = 0; pos <= str.length; pos++) {
        const atEnd = pos === str.length;
        if (atEnd || delimiters.indexOf(str[pos]) !== -1) {
            const piece = str.subarray(tokenStart, pos);
            if (piece.length !== 0 || !skipEmpty) {
                result.push(piece);
            }
            tokenStart = pos + 1;
        }
    }
}
/**
 * Shared implementation of StringSplit: tokenizes each input string at the
 * delimiter characters and returns the result encoded as a sparse tensor
 * [indices, values, denseShape].
 */
function stringSplitImpl(input, delimiter, skipEmpty) {
    const batchSize = input.length;
    // Empty delimiter means split the input character by character.
    const tokens = [];
    const numIndices = new Array(batchSize);
    let outputSize = 0;
    let maxNumEntries = 0;
    for (let row = 0; row < batchSize; ++row) {
        const before = tokens.length;
        split(input[row], delimiter, skipEmpty, tokens);
        const added = tokens.length - before;
        numIndices[row] = added;
        outputSize += added;
        maxNumEntries = Math.max(maxNumEntries, added);
    }
    // indices is a 2d tensor with shape of [outputSize, 2]
    const indices = getArrayFromDType('int32', outputSize * 2);
    const values = new Array(outputSize);
    const shape = [batchSize, maxNumEntries];
    let pos = 0;
    for (let row = 0; row < batchSize; ++row) {
        for (let col = 0; col < numIndices[row]; ++col) {
            indices[pos * 2] = row;
            indices[pos * 2 + 1] = col;
            values[pos] = tokens[pos];
            ++pos;
        }
    }
    return [indices, values, shape];
}
75053
75054 /**
75055 * @license
75056 * Copyright 2021 Google LLC. All Rights Reserved.
75057 * Licensed under the Apache License, Version 2.0 (the "License");
75058 * you may not use this file except in compliance with the License.
75059 * You may obtain a copy of the License at
75060 *
75061 * http://www.apache.org/licenses/LICENSE-2.0
75062 *
75063 * Unless required by applicable law or agreed to in writing, software
75064 * distributed under the License is distributed on an "AS IS" BASIS,
75065 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75066 * See the License for the specific language governing permissions and
75067 * limitations under the License.
75068 * =============================================================================
75069 */
/**
 * Hashes each encoded input string into a bucket in [0, numBuckets) using
 * the fingerPrint64 hash. Returns an int32 array of bucket ids.
 */
function stringToHashBucketFastImpl(input, numBuckets) {
    const output = getArrayFromDType('int32', input.length);
    for (let i = 0; i < input.length; ++i) {
        output[i] =
            // The 64-bit fingerprint is reduced modulo numBuckets; the low 32
            // bits hold the full result since the modulus fits in 32 bits.
            fingerPrint64(input[i]).modulo(numBuckets).getLowBitsUnsigned();
    }
    return output;
}
75078
75079 /**
75080 * @license
75081 * Copyright 2020 Google LLC. All Rights Reserved.
75082 * Licensed under the Apache License, Version 2.0 (the "License");
75083 * you may not use this file except in compliance with the License.
75084 * You may obtain a copy of the License at
75085 *
75086 * http://www.apache.org/licenses/LICENSE-2.0
75087 *
75088 * Unless required by applicable law or agreed to in writing, software
75089 * distributed under the License is distributed on an "AS IS" BASIS,
75090 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75091 * See the License for the specific language governing permissions and
75092 * limitations under the License.
75093 * =============================================================================
75094 */
// Element-wise subtraction for real and complex tensors: shared impls, CPU
// kernel, and its registration config.
const subImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => aValue - bValue));
const subComplexImpl = createComplexBinaryKernelImpl(((aReal, aImag, bReal, bImag) => {
    return { real: aReal - bReal, imag: aImag - bImag };
}));
const sub$1 = binaryKernelFunc$1(Sub, subImpl, subComplexImpl);
const subConfig$1 = {
    kernelName: Sub,
    backendName: 'cpu',
    kernelFunc: sub$1
};
75105
75106 /**
75107 * @license
75108 * Copyright 2019 Google LLC. All Rights Reserved.
75109 * Licensed under the Apache License, Version 2.0 (the "License");
75110 * you may not use this file except in compliance with the License.
75111 * You may obtain a copy of the License at
75112 *
75113 * http://www.apache.org/licenses/LICENSE-2.0
75114 *
75115 * Unless required by applicable law or agreed to in writing, software
75116 * distributed under the License is distributed on an "AS IS" BASIS,
75117 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75118 * See the License for the specific language governing permissions and
75119 * limitations under the License.
75120 * =============================================================================
75121 */
/**
 * Tile kernel implementation shared between the webgl and cpu backends,
 * used for string tensors only.
 *
 * Each output dimension `d` has size `xBuf.shape[d] * reps[d]`; an output
 * location reads the input at that location wrapped element-wise by the
 * input shape.
 */
function tileImpl(xBuf, reps) {
    const outShape = xBuf.shape.map((dim, axis) => dim * reps[axis]);
    const outBuf = buffer(outShape, xBuf.dtype);
    for (let flatIndex = 0; flatIndex < outBuf.values.length; ++flatIndex) {
        const outLoc = outBuf.indexToLoc(flatIndex);
        // Wrap each coordinate back into the input's range.
        const srcLoc = outLoc.map((coord, axis) => coord % xBuf.shape[axis]);
        outBuf.values[flatIndex] = xBuf.values[xBuf.locToIndex(srcLoc)];
    }
    return outBuf;
}
75143
75144 /**
75145 * @license
75146 * Copyright 2020 Google LLC. All Rights Reserved.
75147 * Licensed under the Apache License, Version 2.0 (the "License");
75148 * you may not use this file except in compliance with the License.
75149 * You may obtain a copy of the License at
75150 *
75151 * http://www.apache.org/licenses/LICENSE-2.0
75152 *
75153 * Unless required by applicable law or agreed to in writing, software
75154 * distributed under the License is distributed on an "AS IS" BASIS,
75155 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75156 * See the License for the specific language governing permissions and
75157 * limitations under the License.
75158 * =============================================================================
75159 */
/**
 * Comparator ordering {value, index} pairs by descending value; equal
 * values fall back to ascending index (a stable tie-break).
 */
const comparePair = (a, b) => {
    const valueDiff = b.value - a.value;
    if (valueDiff !== 0) {
        return valueDiff;
    }
    return a.index - b.index;
};
/**
 * Partitions `array` IN PLACE so that `array[k]` holds the element that
 * belongs at position k under `comparePair` (descending by value), with all
 * "smaller" elements to its left and all "larger" to its right. Elements
 * within each side are NOT sorted relative to each other.
 * Based on the Floyd-Rivest Algorithm, ref:
 * https://en.wikipedia.org/wiki/Floyd%E2%80%93Rivest_algorithm
 * @param array: Array to partition
 * @param k: Desired index value, where array[k] is the (k+1)th smallest element
 *           when left = 0
 * @param left: Left index for the interval
 * @param right: Right index for the interval
 */
function select$2(array, k, left = 0, right = array.length - 1) {
    while (right > left) {
        // Use select recursively to sample a smaller set of size s
        // the arbitrary constants 600 and 0.5 are used in the original
        // version to minimize execution time.
        if (right - left > 600) {
            const n = right - left + 1;
            const i = k - left + 1;
            const z = Math.log(n);
            const s = 0.5 * Math.exp(2 * z / 3);
            const sd = 0.5 * Math.sqrt(z * s * (n - s) / n) * Math.sign(i - n / 2);
            // Recurse on a narrowed window expected to contain index k.
            const newLeft = Math.max(left, Math.floor(k - i * s / n + sd));
            const newRight = Math.min(right, Math.floor(k + (n - i) * s / n + sd));
            select$2(array, k, newLeft, newRight);
        }
        // partition the elements between left and right around t
        const t = array[k];
        let i = left;
        let j = right;
        swap(array, left, k);
        if (comparePair(array[right], t) > 0) {
            swap(array, left, right);
        }
        // Hoare-style partition: advance i/j past elements already on the
        // correct side of the pivot t, swapping the out-of-place pair.
        while (i < j) {
            swap(array, i, j);
            i++;
            j--;
            while (comparePair(array[i], t) < 0) {
                i = i + 1;
            }
            while (comparePair(array[j], t) > 0) {
                j = j - 1;
            }
        }
        // Put the pivot back into its final slot.
        if (comparePair(array[left], t) === 0) {
            swap(array, left, j);
        }
        else {
            j = j + 1;
            swap(array, j, right);
        }
        // Adjust left and right towards the boundaries of the subset
        // containing the (k - left + 1)th smallest element.
        if (j <= k) {
            left = j + 1;
        }
        if (k <= j) {
            right = j - 1;
        }
    }
}
/**
 * Shared top-k implementation: for each row of the (logically
 * [batch, lastDim]) input, returns the k largest values and their indices
 * within the row. Ties are broken by the lower index. If `sorted` is true,
 * each row's results are ordered by descending value.
 *
 * @returns A pair of buffers [values, indices] whose shape is the input
 *     shape with the last dimension replaced by k.
 */
function topKImpl(x, xShape, xDtype, k, sorted) {
    // Flatten to [batch, lastDim] and select along the last dimension.
    const lastDim = xShape[xShape.length - 1];
    const batch = x.length / lastDim;
    const outVals = getTypedArrayFromDType(xDtype, batch * k);
    const outIndices = getTypedArrayFromDType('int32', batch * k);
    for (let b = 0; b < batch; b++) {
        const rowStart = b * lastDim;
        const row = x.subarray(rowStart, rowStart + lastDim);
        let pairs = new Array(row.length);
        row.forEach((value, index) => pairs[index] = { value, index });
        if (k < pairs.length) {
            // Partial selection (Floyd-Rivest): after this, the first k
            // entries are the k largest, in arbitrary order.
            select$2(pairs, k);
            pairs = pairs.slice(0, k);
        }
        if (sorted) {
            pairs.sort(comparePair);
        }
        const outStart = b * k;
        for (let i = 0; i < k; i++) {
            outVals[outStart + i] = pairs[i].value;
            outIndices[outStart + i] = pairs[i].index;
        }
    }
    // Reshape back to the original input shape, except that the last
    // dimension is k.
    const outShape = xShape.slice();
    outShape[outShape.length - 1] = k;
    return [
        buffer(outShape, xDtype, outVals),
        buffer(outShape, 'int32', outIndices)
    ];
}
75261
75262 /**
75263 * @license
75264 * Copyright 2020 Google LLC. All Rights Reserved.
75265 * Licensed under the Apache License, Version 2.0 (the "License");
75266 * you may not use this file except in compliance with the License.
75267 * You may obtain a copy of the License at
75268 *
75269 * http://www.apache.org/licenses/LICENSE-2.0
75270 *
75271 * Unless required by applicable law or agreed to in writing, software
75272 * distributed under the License is distributed on an "AS IS" BASIS,
75273 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75274 * See the License for the specific language governing permissions and
75275 * limitations under the License.
75276 * =============================================================================
75277 */
/**
 * Shared unique implementation: de-duplicates slices of the input tensor
 * along `axis`.
 *
 * Returns:
 * - outputValues/outputShape: the input with duplicate slices along the
 *   axis removed (first occurrence kept, original order preserved);
 * - indices: int32 array of length shape[axis] mapping each original
 *   position along the axis to its slice's index in the unique output.
 */
function uniqueImpl(values, axis, shape, dtype) {
    // Normalize and validate the (possibly negative) axis.
    const $axis = parseAxisParam(axis, shape)[0];
    // View the tensor as rank 3: [product of dims before the axis, the axis
    // dim, product of dims after the axis]. E.g. shape=[2,3,5,4], axis=2
    // yields [6, 5, 4]. A "slice" along the axis is then all (m, i, n)
    // values for a fixed i, which lets a 3-level loop extract each slice.
    const newShape = [1, shape[$axis], 1];
    for (let i = 0; i < $axis; i++) {
        newShape[0] *= shape[i];
    }
    for (let i = $axis + 1; i < shape.length; i++) {
        newShape[2] *= shape[i];
    }
    // Maps a slice's string key to its index in the unique output.
    const uniqueElements = new Map();
    // For each position along the axis, the unique index it maps to.
    const indices = new Int32Array(shape[$axis]);
    // Buffer wrapper for easy (m, i, n) indexing into the flat values.
    const inputBuffer = new TensorBuffer(newShape, dtype, values);
    // Axis positions holding the first occurrence of each unique slice.
    const uniqueIndices = [];
    const is1DTensor = newShape[0] === 1 && newShape[2] === 1;
    for (let i = 0; i < shape[$axis]; i++) {
        // Build a string key for the slice at axis position i.
        let element;
        if (is1DTensor) {
            // Fast path: each slice is a single scalar.
            element = values[i].toString();
        }
        else {
            const axisValues = [];
            for (let m = 0; m < newShape[0]; m++) {
                for (let n = 0; n < newShape[2]; n++) {
                    axisValues.push(inputBuffer.get(m, i, n));
                }
            }
            element = axisValues.join(',');
        }
        // Dedup: reuse the index of an already-seen slice, otherwise
        // assign the next fresh unique index.
        if (uniqueElements.has(element)) {
            indices[i] = uniqueElements.get(element);
        }
        else {
            const uniqueIndex = uniqueElements.size;
            uniqueElements.set(element, uniqueIndex);
            indices[i] = uniqueIndex;
            uniqueIndices.push(i);
        }
    }
    // Copy each first-occurrence slice into the output buffer, preserving
    // encounter order.
    const outputTmpShape = [newShape[0], uniqueElements.size, newShape[2]];
    const outputBuffer = new TensorBuffer(outputTmpShape, dtype);
    uniqueIndices.forEach((srcAxisIndex, destAxisIndex) => {
        for (let m = 0; m < newShape[0]; m++) {
            for (let n = 0; n < newShape[2]; n++) {
                outputBuffer.set(inputBuffer.get(m, srcAxisIndex, n), m, destAxisIndex, n);
            }
        }
    });
    // Output shape = input shape with the axis dim replaced by the number
    // of unique slices found.
    const outputShape = shape.slice();
    outputShape[$axis] = outputTmpShape[1];
    return {
        outputValues: outputBuffer.values,
        outputShape,
        indices,
    };
}
75407
75408 /**
75409 * @license
75410 * Copyright 2020 Google LLC. All Rights Reserved.
75411 * Licensed under the Apache License, Version 2.0 (the "License");
75412 * you may not use this file except in compliance with the License.
75413 * You may obtain a copy of the License at
75414 *
75415 * http://www.apache.org/licenses/LICENSE-2.0
75416 *
75417 * Unless required by applicable law or agreed to in writing, software
75418 * distributed under the License is distributed on an "AS IS" BASIS,
75419 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75420 * See the License for the specific language governing permissions and
75421 * limitations under the License.
75422 * =============================================================================
75423 */
75424
// Frozen namespace aggregating the CPU kernel implementations so other
// backends (e.g. webgl) can reuse them without re-implementing.
var shared = /*#__PURE__*/Object.freeze({
    __proto__: null,
    addImpl: addImpl,
    bincountImpl: bincountImpl,
    bincountReduceImpl: bincountReduceImpl,
    bitwiseAndImpl: bitwiseAndImpl,
    castImpl: castImpl,
    ceilImpl: ceilImpl,
    concatImpl: concatImpl$1,
    equalImpl: equalImpl,
    expImpl: expImpl,
    expm1Impl: expm1Impl,
    floorDivImpl: floorDivImpl,
    floorImpl: floorImpl,
    gatherNdImpl: gatherNdImpl,
    gatherV2Impl: gatherV2Impl,
    greaterEqualImpl: greaterEqualImpl,
    greaterImpl: greaterImpl,
    lessEqualImpl: lessEqualImpl,
    lessImpl: lessImpl,
    linSpaceImpl: linSpaceImpl,
    logImpl: logImpl,
    maxImpl: maxImpl$1,
    maximumImpl: maximumImpl,
    minimumImpl: minimumImpl,
    multiplyImpl: multiplyImpl,
    negImpl: negImpl,
    notEqualImpl: notEqualImpl,
    prodImpl: prodImpl,
    raggedGatherImpl: raggedGatherImpl,
    raggedRangeImpl: raggedRangeImpl,
    raggedTensorToTensorImpl: raggedTensorToTensorImpl,
    rangeImpl: rangeImpl,
    rsqrtImpl: rsqrtImpl,
    scatterImpl: scatterImpl,
    sigmoidImpl: sigmoidImpl,
    simpleAbsImpl: simpleAbsImpl,
    sliceImpl: sliceImpl,
    sparseFillEmptyRowsImpl: sparseFillEmptyRowsImpl,
    sparseReshapeImpl: sparseReshapeImpl,
    sparseSegmentReductionImpl: sparseSegmentReductionImpl,
    sqrtImpl: sqrtImpl,
    squaredDifferenceImpl: squaredDifferenceImpl,
    staticRegexReplaceImpl: staticRegexReplaceImpl,
    stridedSliceImpl: stridedSliceImpl,
    stringNGramsImpl: stringNGramsImpl,
    stringSplitImpl: stringSplitImpl,
    stringToHashBucketFastImpl: stringToHashBucketFastImpl,
    subImpl: subImpl,
    tileImpl: tileImpl,
    topKImpl: topKImpl,
    transposeImpl: transposeImpl$1,
    uniqueImpl: uniqueImpl
});
75479
/** @license See the LICENSE file. */
// This code is auto-generated, do not modify this file!
// Version string of the CPU backend package.
const version$3 = '4.22.0';
75483
75484 /**
75485 * @license
75486 * Copyright 2020 Google LLC. All Rights Reserved.
75487 * Licensed under the Apache License, Version 2.0 (the "License");
75488 * you may not use this file except in compliance with the License.
75489 * You may obtain a copy of the License at
75490 *
75491 * http://www.apache.org/licenses/LICENSE-2.0
75492 *
75493 * Unless required by applicable law or agreed to in writing, software
75494 * distributed under the License is distributed on an "AS IS" BASIS,
75495 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75496 * See the License for the specific language governing permissions and
75497 * limitations under the License.
75498 * =============================================================================
75499 */
// Side effects for default initialization of MathBackendCPU
// NOTE(review): priority is 1; presumably other backends register with a
// higher priority so they win when available — confirm against the
// registry's selection rules.
registerBackend('cpu', () => new MathBackendCPU(), 1 /* priority */);
75502
75503 /**
75504 * @license
75505 * Copyright 2020 Google LLC. All Rights Reserved.
75506 * Licensed under the Apache License, Version 2.0 (the License);
75507 * you may not use this file except in compliance with the License.
75508 * You may obtain a copy of the License at
75509 *
75510 * http://www.apache.org/licenses/LICENSE-2.0
75511 *
75512 * Unless required by applicable law or agreed to in writing, software
75513 * distributed under the License is distributed on an AS IS BASIS,
75514 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75515 * See the License for the specific language governing permissions and
75516 * limitations under the License.
75517 * =============================================================================
75518 */
/**
 * Elu CPU kernel: identity for non-negative inputs, exp(x) - 1 below zero.
 */
const elu$1 = unaryKernelFunc$1(Elu$1, (xi) => {
    if (xi >= 0) {
        return xi;
    }
    return Math.exp(xi) - 1;
});
const eluConfig$1 = {
    kernelName: Elu$1,
    backendName: 'cpu',
    kernelFunc: elu$1
};
75525
75526 /**
75527 * @license
75528 * Copyright 2020 Google LLC. All Rights Reserved.
75529 * Licensed under the Apache License, Version 2.0 (the "License");
75530 * you may not use this file except in compliance with the License.
75531 * You may obtain a copy of the License at
75532 *
75533 * http://www.apache.org/licenses/LICENSE-2.0
75534 *
75535 * Unless required by applicable law or agreed to in writing, software
75536 * distributed under the License is distributed on an "AS IS" BASIS,
75537 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75538 * See the License for the specific language governing permissions and
75539 * limitations under the License.
75540 * =============================================================================
75541 */
/**
 * LeakyRelu CPU kernel: y = x for x >= 0, y = alpha * x otherwise.
 * Rejects complex inputs and always produces a float32 output.
 */
function leakyRelu$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { alpha } = attrs;
    assertNotComplex$1([x], 'leakyRelu');
    const xVals = backend.data.get(x.dataId).values;
    const outVals = getTypedArrayFromDType('float32', sizeFromShape(x.shape));
    for (let i = 0; i < xVals.length; i++) {
        const value = xVals[i];
        outVals[i] = value < 0 ? alpha * value : value;
    }
    return backend.makeTensorInfo(x.shape, 'float32', outVals);
}
const leakyReluConfig$1 = {
    kernelName: LeakyRelu,
    backendName: 'cpu',
    kernelFunc: leakyRelu$1
};
75560
75561 /**
75562 * @license
75563 * Copyright 2020 Google LLC. All Rights Reserved.
75564 * Licensed under the Apache License, Version 2.0 (the License);
75565 * you may not use this file except in compliance with the License.
75566 * You may obtain a copy of the License at
75567 *
75568 * http://www.apache.org/licenses/LICENSE-2.0
75569 *
75570 * Unless required by applicable law or agreed to in writing, software
75571 * distributed under the License is distributed on an AS IS BASIS,
75572 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75573 * See the License for the specific language governing permissions and
75574 * limitations under the License.
75575 * =============================================================================
75576 */
/**
 * Prelu CPU kernel: y = x for x >= 0, y = alpha * x otherwise, where
 * `alpha` is a tensor broadcast against `x`. Output dtype is float32.
 */
const preluImpl = createSimpleBinaryKernelImpl((xValue, aValue) => xValue < 0 ? aValue * xValue : xValue);
function prelu$1(args) {
    const { inputs, backend } = args;
    const { x, alpha } = inputs;
    assertNotComplex$1([x, alpha], 'prelu');
    const xVals = backend.data.get(x.dataId).values;
    const alphaVals = backend.data.get(alpha.dataId).values;
    const [resultData, resultShape] = preluImpl(x.shape, alpha.shape, xVals, alphaVals, 'float32');
    return backend.makeTensorInfo(resultShape, 'float32', resultData);
}
const preluConfig$1 = {
    kernelName: Prelu,
    backendName: 'cpu',
    kernelFunc: prelu$1
};
75592
75593 /**
75594 * @license
75595 * Copyright 2020 Google LLC. All Rights Reserved.
75596 * Licensed under the Apache License, Version 2.0 (the License);
75597 * you may not use this file except in compliance with the License.
75598 * You may obtain a copy of the License at
75599 *
75600 * http://www.apache.org/licenses/LICENSE-2.0
75601 *
75602 * Unless required by applicable law or agreed to in writing, software
75603 * distributed under the License is distributed on an AS IS BASIS,
75604 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75605 * See the License for the specific language governing permissions and
75606 * limitations under the License.
75607 * =============================================================================
75608 */
/**
 * Relu CPU kernel: max(0, x) element-wise. `Math.max` is kept (rather than
 * a ternary) so NaN inputs propagate as NaN.
 */
const relu$1 = unaryKernelFunc$1(Relu$1, (value) => Math.max(0, value));
const reluConfig$1 = {
    kernelName: Relu$1,
    backendName: 'cpu',
    kernelFunc: relu$1
};
75615
75616 /**
75617 * @license
75618 * Copyright 2020 Google LLC. All Rights Reserved.
75619 * Licensed under the Apache License, Version 2.0 (the License);
75620 * you may not use this file except in compliance with the License.
75621 * You may obtain a copy of the License at
75622 *
75623 * http://www.apache.org/licenses/LICENSE-2.0
75624 *
75625 * Unless required by applicable law or agreed to in writing, software
75626 * distributed under the License is distributed on an AS IS BASIS,
75627 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75628 * See the License for the specific language governing permissions and
75629 * limitations under the License.
75630 * =============================================================================
75631 */
/**
 * Relu6 CPU kernel: clamps each element to the range [0, 6].
 */
const relu6$1 = unaryKernelFunc$1(Relu6$1, (value) => Math.min(6, Math.max(0, value)));
const relu6Config$1 = {
    kernelName: Relu6$1,
    backendName: 'cpu',
    kernelFunc: relu6$1
};
75638
75639 /**
75640 * @license
75641 * Copyright 2020 Google LLC. All Rights Reserved.
75642 * Licensed under the Apache License, Version 2.0 (the "License");
75643 * you may not use this file except in compliance with the License.
75644 * You may obtain a copy of the License at
75645 *
75646 * http://www.apache.org/licenses/LICENSE-2.0
75647 *
75648 * Unless required by applicable law or agreed to in writing, software
75649 * distributed under the License is distributed on an "AS IS" BASIS,
75650 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75651 * See the License for the specific language governing permissions and
75652 * limitations under the License.
75653 * =============================================================================
75654 */
/**
 * Dispatches a fused-op activation name to the corresponding CPU kernel.
 * `preluActivationWeights` is only used for 'prelu' and `leakyreluAlpha`
 * only for 'leakyrelu'. Throws for unsupported activation names.
 */
function applyActivation(backend, x, activation, preluActivationWeights, leakyreluAlpha) {
    switch (activation) {
        case 'linear':
            return identity$1({ inputs: { x }, backend });
        case 'relu':
            return relu$1({ inputs: { x }, backend });
        case 'elu':
            return elu$1({ inputs: { x }, backend });
        case 'relu6':
            return relu6$1({ inputs: { x }, backend });
        case 'prelu':
            return prelu$1({ inputs: { x, alpha: preluActivationWeights }, backend });
        case 'leakyrelu':
            return leakyRelu$1({ inputs: { x }, backend, attrs: { alpha: leakyreluAlpha } });
        case 'sigmoid':
            return sigmoid$1({ inputs: { x }, backend });
        default:
            throw new Error(`Activation ${activation} has not been implemented for the CPU backend.`);
    }
}
75679
75680 /**
75681 * @license
75682 * Copyright 2020 Google LLC. All Rights Reserved.
75683 * Licensed under the Apache License, Version 2.0 (the "License");
75684 * you may not use this file except in compliance with the License.
75685 * You may obtain a copy of the License at
75686 *
75687 * http://www.apache.org/licenses/LICENSE-2.0
75688 *
75689 * Unless required by applicable law or agreed to in writing, software
75690 * distributed under the License is distributed on an "AS IS" BASIS,
75691 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75692 * See the License for the specific language governing permissions and
75693 * limitations under the License.
75694 * =============================================================================
75695 */
/**
 * Reshape CPU kernel. Zero-copy: the output aliases the input's data
 * bucket (after incRef) and only carries the new shape. For complex
 * tensors, the shapes of the underlying real/imag parts are updated too.
 */
function reshape$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { shape } = attrs;
    const xSize = sizeFromShape(x.shape);
    // Resolve an implicit (-1) dimension, if any, against the element count.
    const $shape = inferFromImplicitShape(shape, xSize);
    const $xSize = sizeFromShape($shape);
    assert$1(xSize === $xSize, () => `The new shape (${$shape}) has ${$xSize} elements and the old shape (${x.shape}) has ${xSize} elements. The new shape and old shape must have the same number of elements.`);
    // The output shares the input's buffer, so bump its refcount.
    backend.incRef(x.dataId);
    const xData = backend.data.get(x.dataId);
    if (xData.complexTensorInfos != null) {
        const { real, imag } = xData.complexTensorInfos;
        real.shape = $shape;
        imag.shape = $shape;
    }
    return { dataId: x.dataId, shape: $shape, dtype: x.dtype };
}
const reshapeConfig$1 = {
    kernelName: Reshape$1,
    backendName: 'cpu',
    kernelFunc: reshape$1
};
75721
75722 /**
75723 * @license
75724 * Copyright 2020 Google LLC. All Rights Reserved.
75725 * Licensed under the Apache License, Version 2.0 (the License);
75726 * you may not use this file except in compliance with the License.
75727 * You may obtain a copy of the License at
75728 *
75729 * http://www.apache.org/licenses/LICENSE-2.0
75730 *
75731 * Unless required by applicable law or agreed to in writing, software
75732 * distributed under the License is distributed on an AS IS BASIS,
75733 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75734 * See the License for the specific language governing permissions and
75735 * limitations under the License.
75736 * =============================================================================
75737 */
/**
 * BatchMatMul CPU kernel: multiplies the last two dims of `a` and `b`,
 * with optional transposes and broadcasting over the leading (batch)
 * dims. Non-complex inputs only. Uses a blocked (tiled) triple loop sized
 * by `backend.blockSize` over the left/right/shared dimensions.
 */
function batchMatMul$1(args) {
    const { inputs, backend, attrs } = args;
    const { a, b } = inputs;
    const { transposeA, transposeB } = attrs;
    assertNotComplex$1([a, b], 'matMul');
    const aRank = a.shape.length;
    const bRank = b.shape.length;
    // Contraction (inner) and output (outer) dims of each operand,
    // accounting for the requested transposes.
    const innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1];
    const innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2];
    const outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2];
    const outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1];
    const outerDimsA = a.shape.slice(0, -2);
    const outerDimsB = b.shape.slice(0, -2);
    const batchDimA = sizeFromShape(outerDimsA);
    const batchDimB = sizeFromShape(outerDimsB);
    // Leading dims broadcast against each other to form the output batch.
    const outShapeOuterDims = assertAndGetBroadcastShape(a.shape.slice(0, -2), b.shape.slice(0, -2));
    const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
    assert$1(innerShapeA === innerShapeB, () => `Error in matMul: inner shapes (${innerShapeA}) and (` +
        `${innerShapeB}) of Tensors with shapes ${a.shape} and ` +
        `${b.shape} and transposeA=${transposeA}` +
        ` and transposeB=${transposeB} must match.`);
    const a3dShape = transposeA ? [batchDimA, innerShapeA, outerShapeA] :
        [batchDimA, outerShapeA, innerShapeA];
    const b3dShape = transposeB ? [batchDimB, outerShapeB, innerShapeB] :
        [batchDimB, innerShapeB, outerShapeB];
    // The rest of the implementation is designed to operate on rank-3 tensors
    const a3d = reshape$1({ inputs: { x: a }, backend, attrs: { shape: a3dShape } });
    const b3d = reshape$1({ inputs: { x: b }, backend, attrs: { shape: b3dShape } });
    const sharedDim = transposeA ? a3d.shape[1] : a3d.shape[2];
    const leftDim = transposeA ? a3d.shape[2] : a3d.shape[1];
    const rightDim = transposeB ? b3d.shape[1] : b3d.shape[2];
    const batchDim = Math.max(batchDimA, batchDimB);
    const a3dValues = backend.data.get(a3d.dataId).values;
    const b3dValues = backend.data.get(b3d.dataId).values;
    const a3dStrides = computeStrides(a3d.shape);
    const b3dStrides = computeStrides(b3d.shape);
    // Flat-index step sizes chosen so the inner loops walk A along its
    // contraction dim and B along its output dim whatever the transpose
    // flags are, avoiding any physical transpose.
    const [aBatch, aOuterStep, aInnerStep] = transposeA ?
        [a3dStrides[0], 1, a3dStrides[1]] :
        [a3dStrides[0], a3dStrides[1], 1];
    const [bInnerStep, bOuterStep, bBatch] = transposeB ?
        [1, b3dStrides[1], b3dStrides[0]] :
        [b3dStrides[1], 1, b3dStrides[0]];
    const size = leftDim * rightDim;
    const result = buffer([batchDim, leftDim, rightDim], a3d.dtype);
    const resVals = result.values;
    const blockSize = backend.blockSize;
    for (let bi = 0; bi < batchDim; bi++) {
        // Modulo repeats the smaller operand's batch (broadcasting).
        const batchIndexA = bi % batchDimA;
        const batchIndexB = bi % batchDimB;
        for (let i0 = 0; i0 < leftDim; i0 += blockSize) {
            // for when blockSize doesn't evenly divide the input
            const iBlock = Math.min(i0 + blockSize, leftDim);
            for (let j0 = 0; j0 < rightDim; j0 += blockSize) {
                const jBlock = Math.min(j0 + blockSize, rightDim);
                for (let k0 = 0; k0 < sharedDim; k0 += blockSize) {
                    const kBlock = Math.min(k0 + blockSize, sharedDim);
                    for (let i = i0; i < iBlock; i++) {
                        for (let j = j0; j < jBlock; j++) {
                            let sum = 0.0;
                            for (let k = k0; k < kBlock; k++) {
                                const aVal =
                                // tslint:disable-next-line: max-line-length
                                a3dValues[batchIndexA * aBatch + i * aOuterStep + k * aInnerStep];
                                const bVal =
                                // tslint:disable-next-line: max-line-length
                                b3dValues[k * bInnerStep + j * bOuterStep + batchIndexB * bBatch];
                                sum += aVal * bVal;
                            }
                            // Accumulate across k-blocks; resVals starts zeroed.
                            resVals[bi * size + (i * rightDim + j)] += sum;
                        }
                    }
                }
            }
        }
    }
    backend.disposeIntermediateTensorInfo(a3d);
    backend.disposeIntermediateTensorInfo(b3d);
    // set correct shape on output.
    return backend.makeTensorInfo(outShape, result.dtype, result.values);
}
const batchMatMulConfig$1 = {
    kernelName: BatchMatMul,
    backendName: 'cpu',
    kernelFunc: batchMatMul$1,
};
75823
75824 /**
75825 * @license
75826 * Copyright 2020 Google LLC. All Rights Reserved.
75827 * Licensed under the Apache License, Version 2.0 (the License);
75828 * you may not use this file except in compliance with the License.
75829 * You may obtain a copy of the License at
75830 *
75831 * http://www.apache.org/licenses/LICENSE-2.0
75832 *
75833 * Unless required by applicable law or agreed to in writing, software
75834 * distributed under the License is distributed on an AS IS BASIS,
75835 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75836 * See the License for the specific language governing permissions and
75837 * limitations under the License.
75838 * =============================================================================
75839 */
/**
 * Fused MatMul CPU kernel: matmul, then an optional bias add, then an
 * optional activation. Intermediate tensors produced along the way are
 * disposed before returning the final result.
 */
function _fusedMatMul$1(args) {
    const { inputs, backend, attrs } = args;
    const { a, b, bias, preluActivationWeights } = inputs;
    const { transposeA, transposeB, activation, leakyreluAlpha } = attrs;
    const intermediates = [];
    let current = batchMatMul$1({ inputs: { a, b }, attrs: { transposeA, transposeB }, backend });
    if (bias) {
        const withBias = add({ inputs: { a: current, b: bias }, backend });
        intermediates.push(current);
        current = withBias;
    }
    if (activation) {
        const activated = applyActivation(backend, current, activation, preluActivationWeights, leakyreluAlpha);
        intermediates.push(current);
        current = activated;
    }
    intermediates.forEach((tensorInfo) => backend.disposeIntermediateTensorInfo(tensorInfo));
    return current;
}
const _fusedMatMulConfig$1 = {
    kernelName: _FusedMatMul,
    backendName: 'cpu',
    kernelFunc: _fusedMatMul$1
};
75870
75871 /**
75872 * @license
75873 * Copyright 2020 Google LLC. All Rights Reserved.
75874 * Licensed under the Apache License, Version 2.0 (the License);
75875 * you may not use this file except in compliance with the License.
75876 * You may obtain a copy of the License at
75877 *
75878 * http://www.apache.org/licenses/LICENSE-2.0
75879 *
75880 * Unless required by applicable law or agreed to in writing, software
75881 * distributed under the License is distributed on an AS IS BASIS,
75882 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75883 * See the License for the specific language governing permissions and
75884 * limitations under the License.
75885 * =============================================================================
75886 */
/** Acos CPU kernel: element-wise arccosine. */
const acos$1 = unaryKernelFunc$1(Acos, (value) => Math.acos(value));
const acosConfig$1 = {
    kernelName: Acos,
    backendName: 'cpu',
    kernelFunc: acos$1
};
75893
75894 /**
75895 * @license
75896 * Copyright 2020 Google LLC. All Rights Reserved.
75897 * Licensed under the Apache License, Version 2.0 (the "License");
75898 * you may not use this file except in compliance with the License.
75899 * You may obtain a copy of the License at
75900 *
75901 * http://www.apache.org/licenses/LICENSE-2.0
75902 *
75903 * Unless required by applicable law or agreed to in writing, software
75904 * distributed under the License is distributed on an "AS IS" BASIS,
75905 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75906 * See the License for the specific language governing permissions and
75907 * limitations under the License.
75908 * =============================================================================
75909 */
/**
 * CPU kernel for element-wise inverse hyperbolic cosine: maps each element x
 * to Math.acosh(x). Built from the shared unary-kernel helper.
 */
const acosh$1 = unaryKernelFunc$1(Acosh, (x) => Math.acosh(x));
/** Registration entry binding the Acosh kernel to the 'cpu' backend. */
const acoshConfig$1 = {
    kernelName: Acosh,
    backendName: 'cpu',
    kernelFunc: acosh$1,
};
75916
75917 /**
75918 * @license
75919 * Copyright 2020 Google LLC. All Rights Reserved.
75920 * Licensed under the Apache License, Version 2.0 (the "License");
75921 * you may not use this file except in compliance with the License.
75922 * You may obtain a copy of the License at
75923 *
75924 * http://www.apache.org/licenses/LICENSE-2.0
75925 *
75926 * Unless required by applicable law or agreed to in writing, software
75927 * distributed under the License is distributed on an "AS IS" BASIS,
75928 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75929 * See the License for the specific language governing permissions and
75930 * limitations under the License.
75931 * =============================================================================
75932 */
/**
 * CPU kernel for AddN: element-wise sum of a variadic list of tensors.
 * All inputs share one shape; the output takes the shape and dtype of the
 * first input. Complex inputs are rejected by assertNotComplex$1.
 */
function addN$1(args) {
    const { inputs, backend } = args;
    // The op's inputs are the variadic tensor list itself.
    const tensors = inputs;
    assertNotComplex$1(inputs, 'addN');
    // Fetch every input's backing typed array up front.
    const inputValues = tensors.map(t => backend.data.get(t.dataId).values);
    // Accumulate into a zero-initialized buffer shaped like the first input.
    const result = buffer(tensors[0].shape, tensors[0].dtype);
    const resultValues = result.values;
    for (const values of inputValues) {
        for (let j = 0; j < resultValues.length; j++) {
            resultValues[j] += values[j];
        }
    }
    return backend.makeTensorInfo(result.shape, result.dtype, result.values);
}
/** Registration entry binding the AddN kernel to the 'cpu' backend. */
const addNConfig$1 = {
    kernelName: AddN,
    backendName: 'cpu',
    kernelFunc: addN$1
};
75953
75954 /**
75955 * @license
75956 * Copyright 2020 Google LLC. All Rights Reserved.
75957 * Licensed under the Apache License, Version 2.0 (the "License");
75958 * you may not use this file except in compliance with the License.
75959 * You may obtain a copy of the License at
75960 *
75961 * http://www.apache.org/licenses/LICENSE-2.0
75962 *
75963 * Unless required by applicable law or agreed to in writing, software
75964 * distributed under the License is distributed on an "AS IS" BASIS,
75965 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75966 * See the License for the specific language governing permissions and
75967 * limitations under the License.
75968 * =============================================================================
75969 */
/**
 * CPU kernel for All: logical-AND reduction of `x` over `axis`.
 * When the reduced axes are not innermost, the input is first transposed so
 * each output element reads a contiguous run of `reduceSize` input values.
 * Honors `keepDims` by reinserting the reduced dimensions as size-1 axes.
 */
function all$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    assertNotComplex$1(x, 'all');
    const origAxes = parseAxisParam(axis, x.shape);
    let axes = origAxes;
    const permutedAxes = getAxesPermutation(axes, x.shape.length);
    let $x = x;
    if (permutedAxes != null) {
        // Bring the reduced axes innermost first.
        $x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
        axes = getInnerMostAxes(axes.length, x.shape.length);
    }
    assertAxesAreInnerMostDims('all', axes, $x.shape.length);
    const [outShape, reduceShape] = computeOutAndReduceShapes($x.shape, axes);
    const reduceSize = sizeFromShape(reduceShape);
    const outValues = makeZerosTypedArray(sizeFromShape(outShape), $x.dtype);
    const xValues = backend.data.get($x.dataId).values;
    for (let i = 0; i < outValues.length; ++i) {
        const base = i * reduceSize;
        // AND-fold the contiguous slice for this output element.
        let acc = xValues[base];
        for (let j = 0; j < reduceSize; ++j) {
            acc = acc && xValues[base + j];
        }
        outValues[i] = acc;
    }
    if (permutedAxes != null) {
        backend.disposeIntermediateTensorInfo($x);
    }
    const result = backend.makeTensorInfo(outShape, $x.dtype, outValues);
    if (!keepDims) {
        return result;
    }
    // keepDims: restore the input rank with size-1 reduced dimensions.
    const expandedShape = expandShapeToKeepDim(outShape, origAxes);
    const reshaped = reshape$1({ inputs: { x: result }, backend, attrs: { shape: expandedShape } });
    backend.disposeIntermediateTensorInfo(result);
    return reshaped;
}
/** Registration entry binding the All kernel to the 'cpu' backend. */
const allConfig$1 = {
    kernelName: All,
    backendName: 'cpu',
    kernelFunc: all$1
};
76014
76015 /**
76016 * @license
76017 * Copyright 2020 Google LLC. All Rights Reserved.
76018 * Licensed under the Apache License, Version 2.0 (the "License");
76019 * you may not use this file except in compliance with the License.
76020 * You may obtain a copy of the License at
76021 *
76022 * http://www.apache.org/licenses/LICENSE-2.0
76023 *
76024 * Unless required by applicable law or agreed to in writing, software
76025 * distributed under the License is distributed on an "AS IS" BASIS,
76026 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76027 * See the License for the specific language governing permissions and
76028 * limitations under the License.
76029 * =============================================================================
76030 */
/**
 * CPU kernel for Any: logical-OR reduction of `x` over `axis`.
 * When the reduced axes are not innermost, the input is first transposed so
 * each output element reads a contiguous run of `reduceSize` input values.
 * Honors `keepDims` by reinserting the reduced dimensions as size-1 axes.
 */
function any$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    assertNotComplex$1(x, 'any');
    const origAxes = parseAxisParam(axis, x.shape);
    let axes = origAxes;
    const permutedAxes = getAxesPermutation(axes, x.shape.length);
    let $x = x;
    if (permutedAxes != null) {
        // Bring the reduced axes innermost first.
        $x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
        axes = getInnerMostAxes(axes.length, x.shape.length);
    }
    assertAxesAreInnerMostDims('any', axes, $x.shape.length);
    const [outShape, reduceShape] = computeOutAndReduceShapes($x.shape, axes);
    const reduceSize = sizeFromShape(reduceShape);
    const outValues = makeZerosTypedArray(sizeFromShape(outShape), $x.dtype);
    const xValues = backend.data.get($x.dataId).values;
    for (let i = 0; i < outValues.length; ++i) {
        const base = i * reduceSize;
        // OR-fold the contiguous slice for this output element.
        let acc = xValues[base];
        for (let j = 0; j < reduceSize; ++j) {
            acc = acc || xValues[base + j];
        }
        outValues[i] = acc;
    }
    if (permutedAxes != null) {
        backend.disposeIntermediateTensorInfo($x);
    }
    const result = backend.makeTensorInfo(outShape, $x.dtype, outValues);
    if (!keepDims) {
        return result;
    }
    // keepDims: restore the input rank with size-1 reduced dimensions.
    const expandedShape = expandShapeToKeepDim(outShape, origAxes);
    const reshaped = reshape$1({ inputs: { x: result }, backend, attrs: { shape: expandedShape } });
    backend.disposeIntermediateTensorInfo(result);
    return reshaped;
}
/** Registration entry binding the Any kernel to the 'cpu' backend. */
const anyConfig$1 = {
    kernelName: Any,
    backendName: 'cpu',
    kernelFunc: any$1
};
76075
76076 /**
76077 * @license
76078 * Copyright 2020 Google LLC. All Rights Reserved.
76079 * Licensed under the Apache License, Version 2.0 (the "License");
76080 * you may not use this file except in compliance with the License.
76081 * You may obtain a copy of the License at
76082 *
76083 * http://www.apache.org/licenses/LICENSE-2.0
76084 *
76085 * Unless required by applicable law or agreed to in writing, software
76086 * distributed under the License is distributed on an "AS IS" BASIS,
76087 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76088 * See the License for the specific language governing permissions and
76089 * limitations under the License.
76090 * =============================================================================
76091 */
/**
 * CPU kernel for ArgMax: the index of the maximum value of `x` along `axis`.
 * Produces an int32 tensor. Ties resolve to the lowest index, and NaN values
 * never win the '>' comparison.
 */
function argMax$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis } = attrs;
    assertNotComplex$1(x, 'argMax');
    let axes = parseAxisParam(axis, x.shape);
    const permutedAxes = getAxesPermutation(axes, x.shape.length);
    let $x = x;
    const toDispose = [];
    if (permutedAxes != null) {
        // Move the reduction axis innermost before scanning.
        $x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
        toDispose.push($x);
        axes = getInnerMostAxes(axes.length, $x.shape.length);
    }
    // Only a single (innermost) axis is reduced.
    axes = [axes[0]];
    assertAxesAreInnerMostDims('argMax', axes, $x.shape.length);
    const [outShape, reduceShape] = computeOutAndReduceShapes($x.shape, axes);
    const reduceSize = sizeFromShape(reduceShape);
    const indices = makeZerosTypedArray(sizeFromShape(outShape), 'int32');
    const xValues = backend.data.get($x.dataId).values;
    for (let i = 0; i < indices.length; ++i) {
        const rowStart = i * reduceSize;
        let best = xValues[rowStart];
        let bestIndex = 0;
        for (let j = 0; j < reduceSize; ++j) {
            const candidate = xValues[rowStart + j];
            if (candidate > best) {
                best = candidate;
                bestIndex = j;
            }
        }
        indices[i] = bestIndex;
    }
    toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return backend.makeTensorInfo(outShape, 'int32', indices);
}
/** Registration entry binding the ArgMax kernel to the 'cpu' backend. */
const argMaxConfig$1 = {
    kernelName: ArgMax,
    backendName: 'cpu',
    kernelFunc: argMax$1
};
76134
76135 /**
76136 * @license
76137 * Copyright 2020 Google LLC. All Rights Reserved.
76138 * Licensed under the Apache License, Version 2.0 (the "License");
76139 * you may not use this file except in compliance with the License.
76140 * You may obtain a copy of the License at
76141 *
76142 * http://www.apache.org/licenses/LICENSE-2.0
76143 *
76144 * Unless required by applicable law or agreed to in writing, software
76145 * distributed under the License is distributed on an "AS IS" BASIS,
76146 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76147 * See the License for the specific language governing permissions and
76148 * limitations under the License.
76149 * =============================================================================
76150 */
/**
 * CPU kernel for ArgMin: the index of the minimum value of `x` along `axis`.
 * Produces an int32 tensor. Ties resolve to the lowest index, and NaN values
 * never win the '<' comparison.
 */
function argMin$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis } = attrs;
    assertNotComplex$1(x, 'argMin');
    let axes = parseAxisParam(axis, x.shape);
    const permutedAxes = getAxesPermutation(axes, x.shape.length);
    let $x = x;
    const toDispose = [];
    if (permutedAxes != null) {
        // Move the reduction axis innermost before scanning.
        $x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
        toDispose.push($x);
        axes = getInnerMostAxes(axes.length, $x.shape.length);
    }
    // Only a single (innermost) axis is reduced.
    axes = [axes[0]];
    assertAxesAreInnerMostDims('argMin', axes, $x.shape.length);
    const [outShape, reduceShape] = computeOutAndReduceShapes($x.shape, axes);
    const reduceSize = sizeFromShape(reduceShape);
    const indices = makeZerosTypedArray(sizeFromShape(outShape), 'int32');
    const xValues = backend.data.get($x.dataId).values;
    for (let i = 0; i < indices.length; ++i) {
        const rowStart = i * reduceSize;
        let best = xValues[rowStart];
        let bestIndex = 0;
        for (let j = 0; j < reduceSize; ++j) {
            const candidate = xValues[rowStart + j];
            if (candidate < best) {
                best = candidate;
                bestIndex = j;
            }
        }
        indices[i] = bestIndex;
    }
    toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return backend.makeTensorInfo(outShape, 'int32', indices);
}
/** Registration entry binding the ArgMin kernel to the 'cpu' backend. */
const argMinConfig$1 = {
    kernelName: ArgMin,
    backendName: 'cpu',
    kernelFunc: argMin$1
};
76193
76194 /**
76195 * @license
76196 * Copyright 2020 Google LLC. All Rights Reserved.
76197 * Licensed under the Apache License, Version 2.0 (the "License");
76198 * you may not use this file except in compliance with the License.
76199 * You may obtain a copy of the License at
76200 *
76201 * http://www.apache.org/licenses/LICENSE-2.0
76202 *
76203 * Unless required by applicable law or agreed to in writing, software
76204 * distributed under the License is distributed on an "AS IS" BASIS,
76205 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76206 * See the License for the specific language governing permissions and
76207 * limitations under the License.
76208 * =============================================================================
76209 */
/**
 * CPU kernel for element-wise arcsine: maps each element x to Math.asin(x).
 * Built from the shared unary-kernel helper.
 */
const asin$1 = unaryKernelFunc$1(Asin, (x) => Math.asin(x));
/** Registration entry binding the Asin kernel to the 'cpu' backend. */
const asinConfig$1 = {
    kernelName: Asin,
    backendName: 'cpu',
    kernelFunc: asin$1,
};
76216
76217 /**
76218 * @license
76219 * Copyright 2020 Google LLC. All Rights Reserved.
76220 * Licensed under the Apache License, Version 2.0 (the "License");
76221 * you may not use this file except in compliance with the License.
76222 * You may obtain a copy of the License at
76223 *
76224 * http://www.apache.org/licenses/LICENSE-2.0
76225 *
76226 * Unless required by applicable law or agreed to in writing, software
76227 * distributed under the License is distributed on an "AS IS" BASIS,
76228 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76229 * See the License for the specific language governing permissions and
76230 * limitations under the License.
76231 * =============================================================================
76232 */
/**
 * CPU kernel for element-wise inverse hyperbolic sine: maps each element x to
 * Math.asinh(x). Built from the shared unary-kernel helper.
 */
const asinh$1 = unaryKernelFunc$1(Asinh, (x) => Math.asinh(x));
/** Registration entry binding the Asinh kernel to the 'cpu' backend. */
const asinhConfig$1 = {
    kernelName: Asinh,
    backendName: 'cpu',
    kernelFunc: asinh$1,
};
76239
76240 /**
76241 * @license
76242 * Copyright 2020 Google LLC. All Rights Reserved.
76243 * Licensed under the Apache License, Version 2.0 (the "License");
76244 * you may not use this file except in compliance with the License.
76245 * You may obtain a copy of the License at
76246 *
76247 * http://www.apache.org/licenses/LICENSE-2.0
76248 *
76249 * Unless required by applicable law or agreed to in writing, software
76250 * distributed under the License is distributed on an "AS IS" BASIS,
76251 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76252 * See the License for the specific language governing permissions and
76253 * limitations under the License.
76254 * =============================================================================
76255 */
/**
 * CPU kernel for element-wise arctangent: maps each element x to Math.atan(x).
 * Built from the shared unary-kernel helper.
 */
const atan$1 = unaryKernelFunc$1(Atan, (x) => Math.atan(x));
/** Registration entry binding the Atan kernel to the 'cpu' backend. */
const atanConfig$1 = {
    kernelName: Atan,
    backendName: 'cpu',
    kernelFunc: atan$1,
};
76262
76263 /**
76264 * @license
76265 * Copyright 2020 Google LLC. All Rights Reserved.
76266 * Licensed under the Apache License, Version 2.0 (the "License");
76267 * you may not use this file except in compliance with the License.
76268 * You may obtain a copy of the License at
76269 *
76270 * http://www.apache.org/licenses/LICENSE-2.0
76271 *
76272 * Unless required by applicable law or agreed to in writing, software
76273 * distributed under the License is distributed on an "AS IS" BASIS,
76274 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76275 * See the License for the specific language governing permissions and
76276 * limitations under the License.
76277 * =============================================================================
76278 */
/**
 * CPU kernel for element-wise two-argument arctangent: atan2(a, b) per
 * element pair, via the shared binary-kernel helpers.
 */
const atan2Impl = createSimpleBinaryKernelImpl((y, x) => Math.atan2(y, x));
const atan2$1 = binaryKernelFunc$1(Atan2, atan2Impl);
/** Registration entry binding the Atan2 kernel to the 'cpu' backend. */
const atan2Config$1 = {
    kernelName: Atan2,
    backendName: 'cpu',
    kernelFunc: atan2$1,
};
76286
76287 /**
76288 * @license
76289 * Copyright 2020 Google LLC. All Rights Reserved.
76290 * Licensed under the Apache License, Version 2.0 (the "License");
76291 * you may not use this file except in compliance with the License.
76292 * You may obtain a copy of the License at
76293 *
76294 * http://www.apache.org/licenses/LICENSE-2.0
76295 *
76296 * Unless required by applicable law or agreed to in writing, software
76297 * distributed under the License is distributed on an "AS IS" BASIS,
76298 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76299 * See the License for the specific language governing permissions and
76300 * limitations under the License.
76301 * =============================================================================
76302 */
/**
 * CPU kernel for element-wise inverse hyperbolic tangent: maps each element x
 * to Math.atanh(x). Built from the shared unary-kernel helper.
 */
const atanh$1 = unaryKernelFunc$1(Atanh, (x) => Math.atanh(x));
/** Registration entry binding the Atanh kernel to the 'cpu' backend. */
const atanhConfig$1 = {
    kernelName: Atanh,
    backendName: 'cpu',
    kernelFunc: atanh$1,
};
76309
76310 /**
76311 * @license
76312 * Copyright 2020 Google LLC. All Rights Reserved.
76313 * Licensed under the Apache License, Version 2.0 (the "License");
76314 * you may not use this file except in compliance with the License.
76315 * You may obtain a copy of the License at
76316 *
76317 * http://www.apache.org/licenses/LICENSE-2.0
76318 *
76319 * Unless required by applicable law or agreed to in writing, software
76320 * distributed under the License is distributed on an "AS IS" BASIS,
76321 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76322 * See the License for the specific language governing permissions and
76323 * limitations under the License.
76324 * =============================================================================
76325 */
/**
 * Generic 2D pooling over a flat input array in NHWC layout.
 *
 * xValues: flattened input values; xShape: input shape (assumes
 * [batch, height, width, channels] — the indexing below fixes this layout);
 * strides: element strides for xShape; convInfo: precomputed pooling geometry;
 * poolType: 'max' or 'avg'. Returns a buffer of shape convInfo.outShape.
 */
function pool(xValues, xShape, dtype, strides, convInfo, poolType) {
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    const padTop = convInfo.padInfo.top;
    const padLeft = convInfo.padInfo.left;
    // 'max' starts at -Infinity so any pixel wins; otherwise +Infinity
    // (only 'max' and 'avg' are handled in the loop below).
    const initialValue = (poolType === 'max' ? Number.NEGATIVE_INFINITY :
        Number.POSITIVE_INFINITY);
    const output = buffer(convInfo.outShape, dtype);
    const outputVals = output.values;
    // Flat strides of the NHWC output tensor.
    const outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] * convInfo.outShape[3];
    const outputRowStrides = convInfo.outShape[2] * convInfo.outShape[3];
    const outputColStrides = convInfo.outShape[3];
    for (let b = 0; b < convInfo.batchSize; ++b) {
        const outputBatchOffset = b * outputBatchStrides;
        const inputBatchOffset = b * strides[0];
        for (let d = 0; d < convInfo.inChannels; ++d) {
            for (let yR = 0; yR < convInfo.outHeight; ++yR) {
                // Window corner in input coordinates, clipped to the input.
                const xRCorner = yR * strideHeight - padTop;
                const xRMin = Math.max(0, xRCorner);
                const xRMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner);
                const outputRowOffset = outputBatchOffset + yR * outputRowStrides;
                for (let yC = 0; yC < convInfo.outWidth; ++yC) {
                    const xCCorner = yC * strideWidth - padLeft;
                    const xCMin = Math.max(0, xCCorner);
                    const xCMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner);
                    let minMaxValue = initialValue;
                    let avgValue = 0;
                    let count = 0;
                    // Scan the (dilated) window and fold max or sum+count.
                    for (let xR = xRMin; xR < xRMax; xR += dilationHeight) {
                        const xROffset = inputBatchOffset + xR * strides[1];
                        for (let xC = xCMin; xC < xCMax; xC += dilationWidth) {
                            const xCOffset = xROffset + xC * strides[2];
                            const pixel = xValues[xCOffset + d];
                            if ((poolType === 'max' && pixel > minMaxValue)) {
                                minMaxValue = pixel;
                            }
                            else if (poolType === 'avg') {
                                avgValue += pixel;
                                count++;
                            }
                        }
                        // NOTE(review): intended as NaN propagation, but
                        // 'pixel > minMaxValue' is false for NaN pixels, so
                        // minMaxValue never becomes NaN here — this break
                        // looks unreachable; confirm against upstream intent.
                        if (isNaN(minMaxValue)) {
                            break;
                        }
                    }
                    const outputOffset = outputRowOffset + yC * outputColStrides + d;
                    // NOTE(review): 'avg' divides by count with no guard; an
                    // empty window would yield NaN (pool3d below guards with
                    // Math.max(count, 1)) — verify windows are never empty.
                    outputVals[outputOffset] =
                        poolType === 'avg' ? avgValue / count : minMaxValue;
                }
            }
        }
    }
    return output;
}
/**
 * For each max-pool output position, computes the int32 index of the maximal
 * input pixel in its pooling window (used by MaxPool gradients/argmax).
 *
 * flattenPositions=false: the index is an offset within the effective filter
 * window (wR * effectiveFilterWidth + wC). flattenPositions=true: the index
 * is a flat offset into the input tensor, including the batch dimension only
 * when includeBatchInIndex is true. Ties resolve to the first maximum
 * ('>' comparison); windows that contain no in-bounds pixel keep -1.
 */
function maxPoolPositions(xValues, xShape, dtype, convInfo, flattenPositions = false, includeBatchInIndex = false) {
    const maxPositions = buffer(convInfo.outShape, 'int32');
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    const padTop = convInfo.padInfo.top;
    const padLeft = convInfo.padInfo.left;
    // Wrap the raw values so pixels can be read by (b, row, col, channel).
    const xBuf = buffer(xShape, dtype, xValues);
    for (let b = 0; b < convInfo.batchSize; ++b) {
        for (let d = 0; d < convInfo.inChannels; ++d) {
            for (let yR = 0; yR < convInfo.outHeight; ++yR) {
                const xRCorner = yR * strideHeight - padTop;
                // First in-bounds row tap that stays on the dilation grid
                // (unlike Math.max(0, corner), which could land off-grid).
                let xRMin = xRCorner;
                while (xRMin < 0) {
                    xRMin += dilationHeight;
                }
                // const xRMin = Math.max(0, xRCorner);
                const xRMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner);
                for (let yC = 0; yC < convInfo.outWidth; ++yC) {
                    const xCCorner = yC * strideWidth - padLeft;
                    let xCMin = xCCorner;
                    while (xCMin < 0) {
                        xCMin += dilationWidth;
                    }
                    const xCMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner);
                    let maxValue = Number.NEGATIVE_INFINITY;
                    let maxPosition = -1;
                    for (let xR = xRMin; xR < xRMax; xR += dilationHeight) {
                        const wR = xR - xRCorner;
                        for (let xC = xCMin; xC < xCMax; xC += dilationWidth) {
                            const wC = xC - xCCorner;
                            // For some reason, disable-next-line is not working
                            // TODO(mattsoulanille): Remove this when switching to TS5.
                            /* tslint:disable: no-unnecessary-type-assertion */
                            const pixel = xBuf.get(b, xR, xC, d);
                            if (pixel > maxValue) {
                                maxValue = pixel;
                                if (flattenPositions) {
                                    // Flat NHWC index, with or without batch.
                                    maxPosition = includeBatchInIndex ?
                                        ((b * convInfo.inHeight + xR) * convInfo.inWidth + xC) *
                                            convInfo.inChannels +
                                            d :
                                        (xR * convInfo.inWidth + xC) * convInfo.inChannels + d;
                                }
                                else {
                                    // Offset within the effective filter window.
                                    maxPosition = wR * effectiveFilterWidth + wC;
                                }
                            }
                        }
                    }
                    maxPositions.set(maxPosition, b, yR, yC, d);
                }
            }
        }
    }
    return maxPositions;
}
/**
 * Generic 3D pooling over a flat input array in NDHWC layout.
 *
 * xValues: flattened input values; xShape: input shape (assumes
 * [batch, depth, height, width, channels]); strides: element strides for
 * xShape; convInfo: precomputed 3D pooling geometry; poolType: 'max' or
 * 'avg'. Returns a buffer of shape convInfo.outShape.
 */
function pool3d(xValues, xShape, dtype, strides, convInfo, poolType) {
    const strideDepth = convInfo.strideDepth;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationDepth = convInfo.dilationDepth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterDepth = convInfo.effectiveFilterDepth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    const padFront = convInfo.padInfo.front;
    const padTop = convInfo.padInfo.top;
    const padLeft = convInfo.padInfo.left;
    // 'max' starts at -Infinity so any pixel wins; otherwise +Infinity
    // (only 'max' and 'avg' are handled in the loop below).
    const initialValue = (poolType === 'max' ? Number.NEGATIVE_INFINITY :
        Number.POSITIVE_INFINITY);
    const output = buffer(convInfo.outShape, dtype);
    const outputVals = output.values;
    // Flat strides of the NDHWC output tensor.
    const outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] *
        convInfo.outShape[3] * convInfo.outShape[4];
    const outputDepthStrides = convInfo.outShape[2] * convInfo.outShape[3] * convInfo.outShape[4];
    const outputRowStrides = convInfo.outShape[3] * convInfo.outShape[4];
    const outputColStrides = convInfo.outShape[4];
    for (let batch = 0; batch < convInfo.batchSize; ++batch) {
        const outputBatchOffset = batch * outputBatchStrides;
        const inputBatchOffset = batch * strides[0];
        for (let channel = 0; channel < convInfo.inChannels; ++channel) {
            for (let yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) {
                const xDepthCorner = yDepth * strideDepth - padFront;
                // First in-bounds depth tap that stays on the dilation grid.
                let xDepthMin = xDepthCorner;
                while (xDepthMin < 0) {
                    xDepthMin += dilationDepth;
                }
                const xDepthMax = Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner);
                const outputDepthOffset = outputBatchOffset + yDepth * outputDepthStrides;
                for (let yRow = 0; yRow < convInfo.outHeight; ++yRow) {
                    const xRowCorner = yRow * strideHeight - padTop;
                    let xRowMin = xRowCorner;
                    while (xRowMin < 0) {
                        xRowMin += dilationHeight;
                    }
                    const xRowMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner);
                    const outputRowOffset = outputDepthOffset + yRow * outputRowStrides;
                    for (let yCol = 0; yCol < convInfo.outWidth; ++yCol) {
                        const xColCorner = yCol * strideWidth - padLeft;
                        let xColMin = xColCorner;
                        while (xColMin < 0) {
                            xColMin += dilationWidth;
                        }
                        const xColMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner);
                        // Shader code begins
                        const outputColOffset = outputRowOffset + yCol * outputColStrides;
                        let minMaxValue = initialValue;
                        let avgValue = 0;
                        let count = 0;
                        // Scan the (dilated) 3D window; fold max or sum+count.
                        for (let xDepth = xDepthMin; xDepth < xDepthMax; xDepth += dilationDepth) {
                            const xDepthOffset = inputBatchOffset + xDepth * strides[1];
                            for (let xRow = xRowMin; xRow < xRowMax; xRow += dilationHeight) {
                                const xRowOffset = xDepthOffset + xRow * strides[2];
                                for (let xCol = xColMin; xCol < xColMax; xCol += dilationWidth) {
                                    const xColOffset = xRowOffset + xCol * strides[3];
                                    const pixel = xValues[xColOffset + channel];
                                    if ((poolType === 'max' && pixel > minMaxValue)) {
                                        minMaxValue = pixel;
                                    }
                                    else if (poolType === 'avg') {
                                        avgValue += pixel;
                                        count++;
                                    }
                                    // NOTE(review): these NaN breaks mirror
                                    // pool() above; '>' is false for NaN, so
                                    // minMaxValue appears never to become NaN
                                    // — confirm intended NaN handling.
                                    if (isNaN(minMaxValue)) {
                                        break;
                                    }
                                }
                                if (isNaN(minMaxValue)) {
                                    break;
                                }
                            }
                            if (isNaN(minMaxValue)) {
                                break;
                            }
                        }
                        const outputOffset = outputColOffset + channel;
                        // Guard against an empty window: divide by at least 1.
                        outputVals[outputOffset] = poolType === 'avg' ?
                            avgValue / Math.max(count, 1) :
                            minMaxValue;
                    }
                }
            }
        }
    }
    return output;
}
/**
 * For each 3D max-pool output position, computes the int32 offset (within the
 * effective filter volume) of the maximal input pixel — used by the
 * MaxPool3D gradient.
 *
 * xBuf is a rank-5 buffer read as (batch, depth, row, col, channel).
 * Note: unlike the 2D variant, ties resolve to the LAST maximum ('>='),
 * and windows with no in-bounds pixel keep -1.
 */
function maxPool3dPositions(xBuf, convInfo) {
    const maxPositions = buffer(convInfo.outShape, 'int32');
    const strideDepth = convInfo.strideDepth;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationDepth = convInfo.dilationDepth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterDepth = convInfo.effectiveFilterDepth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    const padFront = convInfo.padInfo.front;
    const padTop = convInfo.padInfo.top;
    const padLeft = convInfo.padInfo.left;
    for (let batch = 0; batch < convInfo.batchSize; ++batch) {
        for (let channel = 0; channel < convInfo.inChannels; ++channel) {
            for (let yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) {
                const xDepthCorner = yDepth * strideDepth - padFront;
                // First in-bounds depth tap that stays on the dilation grid.
                let xDepthMin = xDepthCorner;
                while (xDepthMin < 0) {
                    xDepthMin += dilationDepth;
                }
                const xDepthMax = Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner);
                for (let yRow = 0; yRow < convInfo.outHeight; ++yRow) {
                    const xRowCorner = yRow * strideHeight - padTop;
                    let xRowMin = xRowCorner;
                    while (xRowMin < 0) {
                        xRowMin += dilationHeight;
                    }
                    const xRowMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner);
                    for (let yCol = 0; yCol < convInfo.outWidth; ++yCol) {
                        const xColCorner = yCol * strideWidth - padLeft;
                        let xColMin = xColCorner;
                        while (xColMin < 0) {
                            xColMin += dilationWidth;
                        }
                        const xColMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner);
                        // Shader code begins
                        let maxValue = Number.NEGATIVE_INFINITY;
                        let maxPosition = -1;
                        for (let xDepth = xDepthMin; xDepth < xDepthMax; xDepth += dilationDepth) {
                            const wDepth = xDepth - xDepthCorner;
                            for (let xRow = xRowMin; xRow < xRowMax; xRow += dilationHeight) {
                                const wRow = xRow - xRowCorner;
                                for (let xCol = xColMin; xCol < xColMax; xCol += dilationWidth) {
                                    const wCol = xCol - xColCorner;
                                    const pixel = xBuf.get(batch, xDepth, xRow, xCol, channel);
                                    if (pixel >= maxValue) {
                                        maxValue = pixel;
                                        // NOTE(review): the wRow term scales by
                                        // effectiveFilterHeight (not Width);
                                        // only consistent if the paired
                                        // gradient decode uses the same
                                        // formula — verify against the
                                        // MaxPool3DGrad kernel.
                                        maxPosition =
                                            wDepth * effectiveFilterHeight * effectiveFilterWidth +
                                                wRow * effectiveFilterHeight + wCol;
                                    }
                                }
                            }
                        }
                        maxPositions.set(maxPosition, batch, yDepth, yRow, yCol, channel);
                    }
                }
            }
        }
    }
    return maxPositions;
}
76599
76600 /**
76601 * @license
76602 * Copyright 2020 Google LLC. All Rights Reserved.
76603 * Licensed under the Apache License, Version 2.0 (the "License");
76604 * you may not use this file except in compliance with the License.
76605 * You may obtain a copy of the License at
76606 *
76607 * http://www.apache.org/licenses/LICENSE-2.0
76608 *
76609 * Unless required by applicable law or agreed to in writing, software
76610 * distributed under the License is distributed on an "AS IS" BASIS,
76611 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76612 * See the License for the specific language governing permissions and
76613 * limitations under the License.
76614 * =============================================================================
76615 */
/**
 * CPU kernel for AvgPool (2D average pooling, NHWC). When the filter is 1x1
 * and input/output shapes match, the pooling is a no-op and the input is
 * passed through via Identity.
 */
function avgPool$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    assertNotComplex$1(x, 'avgPool');
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    // Average pooling never dilates.
    const dilations = 1;
    assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in avgPool: Either strides or dilations must be 1. ' +
        `Got strides ${strides} and dilations '${dilations}'`);
    const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
    // Trivial case: a 1x1 window over an unchanged shape is the identity.
    if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
        arraysEqual(convInfo.inShape, convInfo.outShape)) {
        return identity$1({ inputs: { x }, backend });
    }
    const xValues = backend.data.get(x.dataId).values;
    const xStrides = computeStrides(x.shape);
    const pooled = pool(xValues, x.shape, x.dtype, xStrides, convInfo, 'avg');
    return backend.makeTensorInfo(convInfo.outShape, x.dtype, pooled.values);
}
/** Registration entry binding the AvgPool kernel to the 'cpu' backend. */
const avgPoolConfig$1 = {
    kernelName: AvgPool,
    backendName: 'cpu',
    kernelFunc: avgPool$1
};
76643
76644 /**
76645 * @license
76646 * Copyright 2020 Google LLC. All Rights Reserved.
76647 * Licensed under the Apache License, Version 2.0 (the "License");
76648 * you may not use this file except in compliance with the License.
76649 * You may obtain a copy of the License at
76650 *
76651 * http://www.apache.org/licenses/LICENSE-2.0
76652 *
76653 * Unless required by applicable law or agreed to in writing, software
76654 * distributed under the License is distributed on an "AS IS" BASIS,
76655 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76656 * See the License for the specific language governing permissions and
76657 * limitations under the License.
76658 * =============================================================================
76659 */
/**
 * CPU kernel for AvgPool3D: 3D average pooling.
 *
 * Delegates to the shared `pool3d` implementation with the 'avg'
 * reduction; the result is always float32.
 */
function avgPool3D$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { filterSize, strides, pad, dimRoundingMode, dataFormat } = attrs;
    assertNotComplex$1(x, 'avgPool3d');
    // avgPool3d does not support dilation; it is fixed at 1.
    const dilations = 1;
    const convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat);
    const inputValues = backend.data.get(x.dataId).values;
    const resultBuffer = pool3d(inputValues, x.shape, x.dtype, computeStrides(x.shape), convInfo, 'avg');
    return backend.makeTensorInfo(resultBuffer.shape, 'float32', resultBuffer.values);
}
/** Kernel registration entry for AvgPool3D on the CPU backend. */
const avgPool3DConfig$1 = {
    kernelName: AvgPool3D,
    backendName: 'cpu',
    kernelFunc: avgPool3D$1
};
76675
76676 /**
76677 * @license
76678 * Copyright 2020 Google LLC. All Rights Reserved.
76679 * Licensed under the Apache License, Version 2.0 (the "License");
76680 * you may not use this file except in compliance with the License.
76681 * You may obtain a copy of the License at
76682 *
76683 * http://www.apache.org/licenses/LICENSE-2.0
76684 *
76685 * Unless required by applicable law or agreed to in writing, software
76686 * distributed under the License is distributed on an "AS IS" BASIS,
76687 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76688 * See the License for the specific language governing permissions and
76689 * limitations under the License.
76690 * =============================================================================
76691 */
/**
 * CPU kernel for AvgPool3DGrad: gradient of 3D average pooling.
 *
 * For each input position, sums the output-gradient (`dy`) pixels whose
 * pooling windows covered that position, then scales by
 * 1 / (filter volume) — the same uniform weight average pooling applied
 * in the forward pass.
 */
function avgPool3DGrad$1(args) {
    const { inputs, backend, attrs } = args;
    const { dy, input } = inputs;
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    assertNotComplex$1([dy, input], 'avgPool3DGrad');
    const convInfo = computePool3DInfo(input.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
    const strideDepth = convInfo.strideDepth;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const filterDepth = convInfo.filterDepth;
    const filterHeight = convInfo.filterHeight;
    const filterWidth = convInfo.filterWidth;
    const dilationDepth = convInfo.dilationDepth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterDepth = convInfo.effectiveFilterDepth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    // Padding is mirrored relative to the forward pass (transposed-conv
    // style), so the backward loops can walk the filter forward.
    const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
    const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
    const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
    const dx = buffer(input.shape, 'float32');
    // Each dy pixel contributed the average of filter-volume inputs.
    const avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth);
    const dyBuf = backend.bufferSync(dy);
    for (let batch = 0; batch < convInfo.batchSize; ++batch) {
        for (let channel = 0; channel < convInfo.inChannels; ++channel) {
            for (let dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) {
                for (let dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) {
                    for (let dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) {
                        // Shader code begins.
                        const dyDepthCorner = dxDepth - padFront;
                        const dyRowCorner = dxRow - padTop;
                        const dyColCorner = dxCol - padLeft;
                        let dotProd = 0;
                        for (let wDepth = 0; wDepth < effectiveFilterDepth; wDepth += dilationDepth) {
                            const dyDepth = (dyDepthCorner + wDepth) / strideDepth;
                            // Skip positions outside dy or that fall between
                            // strides (non-integer index => no dy pixel there).
                            if (dyDepth < 0 || dyDepth >= convInfo.outDepth ||
                                Math.floor(dyDepth) !== dyDepth) {
                                continue;
                            }
                            for (let wRow = 0; wRow < effectiveFilterHeight; wRow += dilationHeight) {
                                const dyRow = (dyRowCorner + wRow) / strideHeight;
                                if (dyRow < 0 || dyRow >= convInfo.outHeight ||
                                    Math.floor(dyRow) !== dyRow) {
                                    continue;
                                }
                                for (let wCol = 0; wCol < effectiveFilterWidth; wCol += dilationWidth) {
                                    const dyCol = (dyColCorner + wCol) / strideWidth;
                                    if (dyCol < 0 || dyCol >= convInfo.outWidth ||
                                        Math.floor(dyCol) !== dyCol) {
                                        continue;
                                    }
                                    const pixel = dyBuf.get(batch, dyDepth, dyRow, dyCol, channel);
                                    dotProd += pixel;
                                }
                            }
                        }
                        dx.set(dotProd * avgMultiplier, batch, dxDepth, dxRow, dxCol, channel);
                    }
                }
            }
        }
    }
    return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
/** Kernel registration entry for AvgPool3DGrad on the CPU backend. */
const avgPool3DGradConfig$1 = {
    kernelName: AvgPool3DGrad,
    backendName: 'cpu',
    kernelFunc: avgPool3DGrad$1
};
76762
76763 /**
76764 * @license
76765 * Copyright 2020 Google LLC. All Rights Reserved.
76766 * Licensed under the Apache License, Version 2.0 (the "License");
76767 * you may not use this file except in compliance with the License.
76768 * You may obtain a copy of the License at
76769 *
76770 * http://www.apache.org/licenses/LICENSE-2.0
76771 *
76772 * Unless required by applicable law or agreed to in writing, software
76773 * distributed under the License is distributed on an "AS IS" BASIS,
76774 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76775 * See the License for the specific language governing permissions and
76776 * limitations under the License.
76777 * =============================================================================
76778 */
/**
 * CPU kernel for AvgPoolGrad: gradient of 2D average pooling.
 *
 * For each input position, sums the output-gradient (`dy`) pixels whose
 * pooling windows covered that position, scaled by 1 / (filter area).
 */
function avgPoolGrad$1(args) {
    const { inputs, backend, attrs } = args;
    const { dy, input } = inputs;
    const x = input;
    assertNotComplex$1([dy, input], 'avgPoolGrad');
    const { filterSize, strides, pad } = attrs;
    const convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad);
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const filterHeight = convInfo.filterHeight;
    const filterWidth = convInfo.filterWidth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    // Padding is mirrored relative to the forward pass (transposed-conv
    // style), so the backward loops can walk the filter forward.
    const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
    const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
    const dx = buffer(x.shape, 'float32');
    // Each dy pixel contributed the average of filter-area inputs.
    const avgMultiplier = 1 / (filterHeight * filterWidth);
    const dyData = backend.data.get(dy.dataId).values;
    const dyBuf = buffer(dy.shape, 'float32', dyData);
    for (let b = 0; b < convInfo.batchSize; ++b) {
        for (let d = 0; d < convInfo.inChannels; ++d) {
            for (let dxR = 0; dxR < convInfo.inHeight; ++dxR) {
                for (let dxC = 0; dxC < convInfo.inWidth; ++dxC) {
                    // Shader code begins.
                    const dyRCorner = dxR - padTop;
                    const dyCCorner = dxC - padLeft;
                    let dotProd = 0;
                    for (let wR = 0; wR < effectiveFilterHeight; wR += dilationHeight) {
                        const dyR = (dyRCorner + wR) / strideHeight;
                        // Skip positions outside dy or between strides
                        // (non-integer index => no dy pixel there).
                        if (dyR < 0 || dyR >= convInfo.outHeight ||
                            Math.floor(dyR) !== dyR) {
                            continue;
                        }
                        for (let wC = 0; wC < effectiveFilterWidth; wC += dilationWidth) {
                            const dyC = (dyCCorner + wC) / strideWidth;
                            if (dyC < 0 || dyC >= convInfo.outWidth ||
                                Math.floor(dyC) !== dyC) {
                                continue;
                            }
                            const pixel = dyBuf.get(b, dyR, dyC, d);
                            dotProd += pixel;
                        }
                    }
                    dx.set(dotProd * avgMultiplier, b, dxR, dxC, d);
                }
            }
        }
    }
    return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
/** Kernel registration entry for AvgPoolGrad on the CPU backend. */
const avgPoolGradConfig$1 = {
    kernelName: AvgPoolGrad,
    backendName: 'cpu',
    kernelFunc: avgPoolGrad$1
};
76836
76837 /**
76838 * @license
76839 * Copyright 2020 Google LLC. All Rights Reserved.
76840 * Licensed under the Apache License, Version 2.0 (the "License");
76841 * you may not use this file except in compliance with the License.
76842 * You may obtain a copy of the License at
76843 *
76844 * http://www.apache.org/licenses/LICENSE-2.0
76845 *
76846 * Unless required by applicable law or agreed to in writing, software
76847 * distributed under the License is distributed on an "AS IS" BASIS,
76848 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76849 * See the License for the specific language governing permissions and
76850 * limitations under the License.
76851 * =============================================================================
76852 */
/**
 * CPU kernel for FusedBatchNorm.
 *
 * Computes offset + (x - mean) * scale / sqrt(variance + epsilon)
 * element-wise. mean/variance/scale/offset may be smaller than x and are
 * broadcast by cycling through their values (index i maps to i mod len).
 * scale defaults to 1 and offset to 0 when not provided.
 */
function batchNorm$1(args) {
    const { inputs, backend, attrs } = args;
    const { x, scale, offset, mean, variance } = inputs;
    assert$1(mean.shape.length === variance.shape.length, () => 'Batch normalization gradient requires mean and variance to have ' +
        'equal ranks.');
    assert$1(offset == null || mean.shape.length === offset.shape.length, () => 'Batch normalization gradient requires mean and offset to have ' +
        'equal ranks.');
    assert$1(scale == null || mean.shape.length === scale.shape.length, () => 'Batch normalization gradient requires mean and scale to have ' +
        'equal ranks.');
    assertNotComplex$1([x, mean, variance, scale, offset], 'batchNorm');
    const varianceEpsilon = attrs.varianceEpsilon == null ? 0.001 : attrs.varianceEpsilon;
    const xVals = backend.data.get(x.dataId).values;
    const meanVals = backend.data.get(mean.dataId).values;
    const varianceVals = backend.data.get(variance.dataId).values;
    const scaleVals = scale ? backend.data.get(scale.dataId).values :
        new Float32Array([1]);
    const offsetVals = offset ?
        backend.data.get(offset.dataId).values :
        new Float32Array([0]);
    const outVals = new Float32Array(xVals.length);
    const offsetLen = offsetVals.length;
    const scaleLen = scaleVals.length;
    const varianceLen = varianceVals.length;
    const meanLen = meanVals.length;
    for (let i = 0; i < xVals.length; ++i) {
        // Modulo indexing cycles the (possibly shorter) parameter arrays,
        // matching the broadcast semantics of the fused op.
        outVals[i] = offsetVals[i % offsetLen] +
            (xVals[i] - meanVals[i % meanLen]) * scaleVals[i % scaleLen] /
                Math.sqrt(varianceVals[i % varianceLen] + varianceEpsilon);
    }
    return backend.makeTensorInfo(x.shape, x.dtype, outVals);
}
/** Kernel registration entry for FusedBatchNorm on the CPU backend. */
const batchNormConfig$1 = {
    kernelName: FusedBatchNorm,
    backendName: 'cpu',
    kernelFunc: batchNorm$1,
};
76908
76909 /**
76910 * @license
76911 * Copyright 2020 Google LLC. All Rights Reserved.
76912 * Licensed under the Apache License, Version 2.0 (the "License");
76913 * you may not use this file except in compliance with the License.
76914 * You may obtain a copy of the License at
76915 *
76916 * http://www.apache.org/licenses/LICENSE-2.0
76917 *
76918 * Unless required by applicable law or agreed to in writing, software
76919 * distributed under the License is distributed on an "AS IS" BASIS,
76920 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76921 * See the License for the specific language governing permissions and
76922 * limitations under the License.
76923 * =============================================================================
76924 */
/**
 * CPU kernel for BatchToSpaceND.
 *
 * Implemented as the standard reshape -> transpose -> reshape -> slice
 * pipeline: the batch dimension is split by the block shape, interleaved
 * into the spatial dimensions, and the crops are removed with a slice.
 */
function batchToSpaceND$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { blockShape, crops } = attrs;
    assertNotComplex$1([x], 'batchToSpaceND');
    const blockSize = blockShape.reduce((product, d) => product * d);
    const reshapedShape = getReshaped(x.shape, blockShape, blockSize);
    const permutation = getPermuted(reshapedShape.length, blockShape.length);
    const permutedShape = getReshapedPermuted(x.shape, blockShape, blockSize);
    const cropBegin = getSliceBeginCoords(crops, blockShape.length);
    const cropSize = getSliceSize(permutedShape, crops, blockShape.length);
    const splitBatch = reshape$1({ inputs: { x }, backend, attrs: { shape: reshapedShape } });
    const interleaved = transpose$1({ inputs: { x: splitBatch }, backend, attrs: { perm: permutation } });
    const merged = reshape$1({ inputs: { x: interleaved }, backend, attrs: { shape: permutedShape } });
    const result = slice$1({
        inputs: { x: merged },
        backend,
        attrs: { begin: cropBegin, size: cropSize }
    });
    // Release the three intermediate tensors created above.
    [splitBatch, interleaved, merged].forEach(t => backend.disposeIntermediateTensorInfo(t));
    return result;
}
/** Kernel registration entry for BatchToSpaceND on the CPU backend. */
const batchToSpaceNDConfig$1 = {
    kernelName: BatchToSpaceND,
    backendName: 'cpu',
    kernelFunc: batchToSpaceND$1
};
76954
76955 /**
76956 * @license
76957 * Copyright 2020 Google LLC. All Rights Reserved.
76958 * Licensed under the Apache License, Version 2.0 (the "License");
76959 * you may not use this file except in compliance with the License.
76960 * You may obtain a copy of the License at
76961 *
76962 * http://www.apache.org/licenses/LICENSE-2.0
76963 *
76964 * Unless required by applicable law or agreed to in writing, software
76965 * distributed under the License is distributed on an "AS IS" BASIS,
76966 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76967 * See the License for the specific language governing permissions and
76968 * limitations under the License.
76969 * =============================================================================
76970 */
/**
 * CPU kernel for Bincount: counts occurrences of each value in `x`
 * (optionally weighted by `weights`) into an output of length `size`.
 */
function bincount$1(args) {
    const { inputs, backend, attrs } = args;
    const { x, weights } = inputs;
    const { size } = attrs;
    const xValues = backend.data.get(x.dataId).values;
    const weightValues = backend.data.get(weights.dataId).values;
    const counts = bincountImpl(xValues, weightValues, weights.dtype, weights.shape, size);
    return backend.makeTensorInfo([size], weights.dtype, counts);
}
/** Kernel registration entry for Bincount on the CPU backend. */
const bincountConfig$1 = {
    kernelName: Bincount,
    backendName: 'cpu',
    kernelFunc: bincount$1
};
76985
76986 /**
76987 * @license
76988 * Copyright 2021 Google LLC. All Rights Reserved.
76989 * Licensed under the Apache License, Version 2.0 (the "License");
76990 * you may not use this file except in compliance with the License.
76991 * You may obtain a copy of the License at
76992 *
76993 * http://www.apache.org/licenses/LICENSE-2.0
76994 *
76995 * Unless required by applicable law or agreed to in writing, software
76996 * distributed under the License is distributed on an "AS IS" BASIS,
76997 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
76998 * See the License for the specific language governing permissions and
76999 * limitations under the License.
77000 * =============================================================================
77001 */
/**
 * CPU kernel for BroadcastArgs: returns the broadcast shape of the two
 * input shape vectors `s0` and `s1` as a 1-D int32 tensor.
 */
function broadcastArgs$1(args) {
    const { inputs, backend } = args;
    const { s0, s1 } = inputs;
    const shape0 = Array.from(backend.data.get(s0.dataId).values);
    const shape1 = Array.from(backend.data.get(s1.dataId).values);
    const outShape = assertAndGetBroadcastShape(shape0, shape1);
    return backend.makeTensorInfo([outShape.length], 'int32', Int32Array.from(outShape));
}
/** Kernel registration entry for BroadcastArgs on the CPU backend. */
const broadcastArgsConfig$1 = {
    kernelName: BroadcastArgs,
    backendName: 'cpu',
    kernelFunc: broadcastArgs$1
};
77015
77016 /**
77017 * @license
77018 * Copyright 2020 Google LLC. All Rights Reserved.
77019 * Licensed under the Apache License, Version 2.0 (the License);
77020 * you may not use this file except in compliance with the License.
77021 * You may obtain a copy of the License at
77022 *
77023 * http://www.apache.org/licenses/LICENSE-2.0
77024 *
77025 * Unless required by applicable law or agreed to in writing, software
77026 * distributed under the License is distributed on an AS IS BASIS,
77027 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77028 * See the License for the specific language governing permissions and
77029 * limitations under the License.
77030 * =============================================================================
77031 */
/**
 * CPU kernel for ClipByValue: clamps each element into
 * [clipValueMin, clipValueMax].
 */
const clipByValue$1 = unaryKernelFunc$1(ClipByValue, (xi, attrs) => {
    const { clipValueMin, clipValueMax } = attrs;
    // The upper bound is checked first, matching the original evaluation
    // order (observable only if min > max, which op-level validation
    // is expected to reject — TODO confirm).
    if (xi > clipValueMax) {
        return clipValueMax;
    }
    if (xi < clipValueMin) {
        return clipValueMin;
    }
    return xi;
});
/** Kernel registration entry for ClipByValue on the CPU backend. */
const clipByValueConfig$1 = {
    kernelName: ClipByValue,
    backendName: 'cpu',
    kernelFunc: clipByValue$1,
};
77044
77045 /**
77046 * @license
77047 * Copyright 2020 Google LLC. All Rights Reserved.
77048 * Licensed under the Apache License, Version 2.0 (the License);
77049 * you may not use this file except in compliance with the License.
77050 * You may obtain a copy of the License at
77051 *
77052 * http://www.apache.org/licenses/LICENSE-2.0
77053 *
77054 * Unless required by applicable law or agreed to in writing, software
77055 * distributed under the License is distributed on an AS IS BASIS,
77056 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77057 * See the License for the specific language governing permissions and
77058 * limitations under the License.
77059 * =============================================================================
77060 */
/**
 * CPU kernel for ComplexAbs: element-wise magnitude of a complex64
 * tensor, sqrt(re^2 + im^2), returned as float32.
 */
const complexAbs$1 = (args) => {
    const { x } = args.inputs;
    const cpuBackend = args.backend;
    const magnitudes = new Float32Array(sizeFromShape(x.shape));
    // A complex tensor stores its real/imag parts as separate tensors.
    const { complexTensorInfos } = cpuBackend.data.get(x.dataId);
    const realVals = cpuBackend.data.get(complexTensorInfos.real.dataId).values;
    const imagVals = cpuBackend.data.get(complexTensorInfos.imag.dataId).values;
    for (let i = 0; i < realVals.length; i++) {
        magnitudes[i] = Math.hypot(realVals[i], imagVals[i]);
    }
    return cpuBackend.makeOutput(magnitudes, x.shape, 'float32');
};
/** Kernel registration entry for ComplexAbs on the CPU backend. */
const complexAbsConfig$1 = {
    kernelName: ComplexAbs,
    backendName: 'cpu',
    kernelFunc: complexAbs$1,
};
77082
77083 /**
77084 * @license
77085 * Copyright 2020 Google LLC. All Rights Reserved.
77086 * Licensed under the Apache License, Version 2.0 (the "License");
77087 * you may not use this file except in compliance with the License.
77088 * You may obtain a copy of the License at
77089 *
77090 * http://www.apache.org/licenses/LICENSE-2.0
77091 *
77092 * Unless required by applicable law or agreed to in writing, software
77093 * distributed under the License is distributed on an "AS IS" BASIS,
77094 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77095 * See the License for the specific language governing permissions and
77096 * limitations under the License.
77097 * =============================================================================
77098 */
/**
 * CPU kernel for Imag: extracts the imaginary component of a complex
 * tensor as a new tensor.
 */
function imag$1(args) {
    const { inputs, backend } = args;
    const { input } = inputs;
    const imagInfo = backend.data.get(input.dataId).complexTensorInfos.imag;
    const imagValues = backend.data.get(imagInfo.dataId).values;
    // When a complex tensor is disposed, its underlying parts are disposed
    // too. Materialize a new tensor from the imag values so the result
    // stays accessible even after the complex tensor is disposed.
    return backend.makeTensorInfo(imagInfo.shape, imagInfo.dtype, imagValues);
}
/** Kernel registration entry for Imag on the CPU backend. */
const imagConfig$1 = {
    kernelName: Imag,
    backendName: 'cpu',
    kernelFunc: imag$1
};
77114
77115 /**
77116 * @license
77117 * Copyright 2020 Google LLC. All Rights Reserved.
77118 * Licensed under the Apache License, Version 2.0 (the "License");
77119 * you may not use this file except in compliance with the License.
77120 * You may obtain a copy of the License at
77121 *
77122 * http://www.apache.org/licenses/LICENSE-2.0
77123 *
77124 * Unless required by applicable law or agreed to in writing, software
77125 * distributed under the License is distributed on an "AS IS" BASIS,
77126 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77127 * See the License for the specific language governing permissions and
77128 * limitations under the License.
77129 * =============================================================================
77130 */
/**
 * CPU kernel for Concat: concatenates the input tensors along `axis`.
 *
 * Fast paths: empty output returns an empty tensor; a single non-empty
 * input is returned via identity. complex64 inputs are handled by
 * concatenating real and imaginary parts separately and recombining.
 * The general case collapses each input to 2-D and concatenates along
 * axis 1 (see the inline comment below).
 */
function concat$1(args) {
    const { inputs, backend, attrs } = args;
    const { axis } = attrs;
    // Normalize a possibly-negative axis against the first input's rank.
    const $axis = parseAxisParam(axis, inputs[0].shape)[0];
    const shapes = inputs.map(t => t.shape);
    assertParamsConsistent(shapes, $axis);
    let outShape = computeOutShape$1(inputs.map(t => t.shape), $axis);
    if (sizeFromShape(outShape) === 0) {
        return backend.makeTensorInfo(outShape, inputs[0].dtype, []);
    }
    // Keep only non-empty tensors (ignore tensors with 0 in their shape).
    const $inputs = inputs.filter(t => sizeFromShape(t.shape) > 0);
    if ($inputs.length === 1) {
        return identity$1({ inputs: { x: $inputs[0] }, backend });
    }
    if ($inputs[0].dtype === 'complex64') {
        // Recurse on the real and imaginary components, then recombine.
        const reals = $inputs.map((t) => real$1({ inputs: { input: t }, backend }));
        const imags = $inputs.map((t) => imag$1({ inputs: { input: t }, backend }));
        const realConcated = concat$1({ inputs: reals, backend, attrs: { axis: $axis } });
        const imagConcated = concat$1({ inputs: imags, backend, attrs: { axis: $axis } });
        const result = complex$1({ inputs: { real: realConcated, imag: imagConcated }, backend });
        reals.forEach(r => backend.disposeIntermediateTensorInfo(r));
        imags.forEach(i => backend.disposeIntermediateTensorInfo(i));
        backend.disposeIntermediateTensorInfo(realConcated);
        backend.disposeIntermediateTensorInfo(imagConcated);
        return result;
    }
    // Any concat of n-dimensional tensors across any axis can be reduced to
    // a concatenation of two-dimensional tensors across the axis 1 by first
    // partitioning the axes of the original tensors into those less than the
    // axis to be concatenated and the rest. Then reshape the tensors
    // into a two-dimensional tensor by collapsing these two sets of axes and
    // concatenate the resulting matrices across the axis 1, finally reshaping
    // the result to have the proper shape.
    const inputs2D = $inputs.map(t => {
        const innerSize = sizeFromShape(t.shape.slice($axis));
        const shape = [-1, innerSize];
        return reshape$1({ inputs: { x: t }, backend, attrs: { shape } });
    });
    const inputsValShapes = inputs2D.map(t => {
        return { vals: backend.data.get(t.dataId).values, shape: t.shape };
    });
    // Concats 2d tensors along axis=1.
    outShape =
        computeOutShape$1(inputs2D.map(t => t.shape), 1 /* axis */);
    // With a single outer row the concat is a plain back-to-back copy.
    const simplyConcat = inputs2D[0].shape[0] === 1;
    const outVals = concatImpl$1(inputsValShapes, outShape, inputs[0].dtype, simplyConcat);
    const finalOutShape = computeOutShape$1($inputs.map(t => t.shape), $axis);
    const outInfo = backend.makeTensorInfo(finalOutShape, inputs[0].dtype, outVals);
    inputs2D.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return outInfo;
}
/** Kernel registration entry for Concat on the CPU backend. */
const concatConfig$1 = {
    kernelName: Concat,
    backendName: 'cpu',
    kernelFunc: concat$1
};
77188
77189 /**
77190 * @license
77191 * Copyright 2020 Google LLC. All Rights Reserved.
77192 * Licensed under the Apache License, Version 2.0 (the "License");
77193 * you may not use this file except in compliance with the License.
77194 * You may obtain a copy of the License at
77195 *
77196 * http://www.apache.org/licenses/LICENSE-2.0
77197 *
77198 * Unless required by applicable law or agreed to in writing, software
77199 * distributed under the License is distributed on an "AS IS" BASIS,
77200 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77201 * See the License for the specific language governing permissions and
77202 * limitations under the License.
77203 * =============================================================================
77204 */
/**
 * CPU kernel for Conv2D: direct (non-im2col) 2D convolution.
 *
 * Iterates batch -> output row -> filter row -> output col -> filter col
 * -> input channel -> output channel, accumulating into the output
 * buffer. Stride constants are precomputed for both channelsLast and
 * channelsFirst data formats so the inner loops stay format-agnostic.
 */
function conv2D(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter } = inputs;
    const { strides, pad, dataFormat, dilations, dimRoundingMode } = attrs;
    assertNotComplex$1([x, filter], 'conv2d');
    const $dataFormat = convertConv2DDataFormat(dataFormat);
    const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
    const filterHeight = convInfo.filterHeight;
    const filterWidth = convInfo.filterWidth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const padLeft = convInfo.padInfo.left;
    const padTop = convInfo.padInfo.top;
    const isChannelsLast = convInfo.dataFormat === 'channelsLast';
    // Output buffer starts zero-filled; the loops accumulate into it.
    const y = new TensorBuffer(convInfo.outShape, x.dtype);
    const xStrides = computeStrides(x.shape);
    const filterStrides = computeStrides(filter.shape);
    // Row/col/channel strides depend on the data format (NHWC vs NCHW).
    const xBatchStride = xStrides[0];
    const xRowStride = isChannelsLast ? xStrides[1] : xStrides[2];
    const xColStride = isChannelsLast ? xStrides[2] : 1;
    const xChannelStride = isChannelsLast ? 1 : xStrides[1];
    const yBatchStride = y.strides[0];
    const yRowStride = isChannelsLast ? y.strides[1] : y.strides[2];
    const yColStride = isChannelsLast ? y.strides[2] : 1;
    const yChannelStride = isChannelsLast ? 1 : y.strides[1];
    const xVals = backend.data.get(x.dataId).values;
    const wVals = backend.data.get(filter.dataId).values;
    const yVals = y.values;
    for (let b = 0; b < convInfo.batchSize; ++b) {
        const xOffset1 = b * xBatchStride;
        const yOffset1 = b * yBatchStride;
        for (let yR = 0; yR < convInfo.outHeight; ++yR) {
            const yOffset2 = yOffset1 + yR * yRowStride;
            // Top-left input row covered by this output row's window.
            const xRCorner = yR * convInfo.strideHeight - padTop;
            for (let wR = 0; wR < filterHeight; ++wR) {
                const xR = xRCorner + wR * dilationHeight;
                // Skip filter rows that fall in the padding region.
                if (xR < 0 || xR >= convInfo.inHeight) {
                    continue;
                }
                const wOffset1 = wR * filterStrides[0];
                const xOffset2 = xOffset1 + xR * xRowStride;
                for (let yC = 0; yC < convInfo.outWidth; ++yC) {
                    const yOffset3 = yOffset2 + yC * yColStride;
                    const xCCorner = yC * convInfo.strideWidth - padLeft;
                    for (let wC = 0; wC < filterWidth; ++wC) {
                        const xC = xCCorner + wC * dilationWidth;
                        // Skip filter cols that fall in the padding region.
                        if (xC < 0 || xC >= convInfo.inWidth) {
                            continue;
                        }
                        const wOffset2 = wOffset1 + wC * filterStrides[1];
                        const xOffset3 = xOffset2 + xC * xColStride;
                        let wOffset3 = wOffset2;
                        // Accumulate over input channels into every output
                        // channel; filter layout is [h, w, inC, outC].
                        for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
                            const xVal = xVals[xOffset3 + d1 * xChannelStride];
                            for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
                                yVals[yOffset3 + d2 * yChannelStride] +=
                                    xVal * wVals[wOffset3 + d2];
                            }
                            wOffset3 += convInfo.outChannels;
                        }
                    }
                }
            }
        }
    }
    return backend.makeTensorInfo(y.shape, y.dtype, yVals);
}
/** Kernel registration entry for Conv2D on the CPU backend. */
const conv2DConfig$1 = {
    kernelName: Conv2D$1,
    backendName: 'cpu',
    kernelFunc: conv2D
};
77277
77278 /**
77279 * @license
77280 * Copyright 2020 Google LLC. All Rights Reserved.
77281 * Licensed under the Apache License, Version 2.0 (the "License");
77282 * you may not use this file except in compliance with the License.
77283 * You may obtain a copy of the License at
77284 *
77285 * http://www.apache.org/licenses/LICENSE-2.0
77286 *
77287 * Unless required by applicable law or agreed to in writing, software
77288 * distributed under the License is distributed on an "AS IS" BASIS,
77289 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77290 * See the License for the specific language governing permissions and
77291 * limitations under the License.
77292 * =============================================================================
77293 */
/**
 * CPU kernel for Conv2DBackpropFilter: gradient of Conv2D w.r.t. the
 * filter.
 *
 * For each filter tap (wR, wC) and channel pair (d1, d2), accumulates
 * the dot product of the input patches and the output gradient over all
 * batches and valid output positions. Dilation is fixed at 1.
 */
function conv2DBackpropFilter$1(args) {
    const { inputs, backend, attrs } = args;
    const { x, dy } = inputs;
    const { strides, pad, dataFormat, dimRoundingMode, filterShape } = attrs;
    assertNotComplex$1([x, dy], 'conv2dBackpropFilter');
    const $dataFormat = convertConv2DDataFormat(dataFormat);
    const convInfo = computeConv2DInfo(x.shape, filterShape, strides, 1 /* dilations */, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
    const { strideHeight, strideWidth, filterHeight, filterWidth } = convInfo;
    const isChannelsLast = convInfo.dataFormat === 'channelsLast';
    const dW = new TensorBuffer(convInfo.filterShape, 'float32');
    const leftPad = convInfo.padInfo.left;
    const topPad = convInfo.padInfo.top;
    const xVals = backend.data.get(x.dataId).values;
    const dyVals = backend.data.get(dy.dataId).values;
    const xBuf = new TensorBuffer(x.shape, x.dtype, xVals);
    const dyBuf = new TensorBuffer(dy.shape, dy.dtype, dyVals);
    for (let wR = 0; wR < filterHeight; ++wR) {
        // Range of output rows for which this filter row touches a valid
        // (non-padding) input row.
        const yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
        const yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);
        for (let wC = 0; wC < filterWidth; ++wC) {
            // Same bounds computation for output columns.
            const yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));
            const yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);
            for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
                for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
                    let dotProd = 0;
                    for (let b = 0; b < convInfo.batchSize; ++b) {
                        for (let yR = yRMin; yR < yRMax; ++yR) {
                            const xR = wR + yR * strideHeight - topPad;
                            for (let yC = yCMin; yC < yCMax; ++yC) {
                                const xC = wC + yC * strideWidth - leftPad;
                                if (isChannelsLast) {
                                    dotProd += xBuf.get(b, xR, xC, d1) *
                                        dyBuf.get(b, yR, yC, d2);
                                }
                                else {
                                    dotProd += xBuf.get(b, d1, xR, xC) *
                                        dyBuf.get(b, d2, yR, yC);
                                }
                            }
                        }
                    }
                    dW.set(dotProd, wR, wC, d1, d2);
                }
            }
        }
    }
    return backend.makeTensorInfo(dW.shape, dW.dtype, dW.values);
}
/** Kernel registration entry for Conv2DBackpropFilter on the CPU backend. */
const conv2DBackpropFilterConfig$1 = {
    kernelName: Conv2DBackpropFilter,
    backendName: 'cpu',
    kernelFunc: conv2DBackpropFilter$1
};
77347
77348 /**
77349 * @license
77350 * Copyright 2020 Google LLC. All Rights Reserved.
77351 * Licensed under the Apache License, Version 2.0 (the "License");
77352 * you may not use this file except in compliance with the License.
77353 * You may obtain a copy of the License at
77354 *
77355 * http://www.apache.org/licenses/LICENSE-2.0
77356 *
77357 * Unless required by applicable law or agreed to in writing, software
77358 * distributed under the License is distributed on an "AS IS" BASIS,
77359 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77360 * See the License for the specific language governing permissions and
77361 * limitations under the License.
77362 * =============================================================================
77363 */
    /**
     * CPU kernel for Conv2DBackpropInput: the gradient of a 2D convolution
     * with respect to its input.
     *
     * For each input pixel, accumulates dy * filter over every output
     * position whose receptive field covered that pixel. The filter is read
     * in 180°-rotated order (the `filterHeight - 1 - wR` /
     * `filterWidth - 1 - wC` indexing below), which is the standard
     * transposed-convolution formulation. Dilations are fixed at 1.
     *
     * Handles both 'channelsLast' (NHWC) and 'channelsFirst' (NCHW) layouts
     * by selecting flat strides once up front instead of branching inside
     * the hot loops.
     *
     * inputs: dy (output gradient), filter (forward-pass filter, HWIO).
     * attrs:  inputShape, strides, pad, dataFormat, dimRoundingMode.
     * Returns a float32 TensorInfo of shape convInfo.inShape.
     */
    function conv2DBackpropInput$1(args) {
        const { inputs, backend, attrs } = args;
        const { dy, filter } = inputs;
        const { inputShape, strides, pad, dataFormat, dimRoundingMode } = attrs;
        assertNotComplex$1([dy, filter], 'conv2dBackpropInput');
        const filterStrides = computeStrides(filter.shape);
        const dyStrides = computeStrides(dy.shape);
        let $dataFormat = convertConv2DDataFormat(dataFormat);
        const convInfo = computeConv2DInfo(inputShape, filter.shape, strides, 1 /* dilations */, pad, dimRoundingMode, false, $dataFormat);
        const dx = new TensorBuffer(convInfo.inShape, 'float32');
        const dxValues = dx.values;
        const dyValues = backend.data.get(dy.dataId).values;
        const fltValues = backend.data.get(filter.dataId).values;
        const [fltS0, fltS1, fltS2] = filterStrides;
        const { batchSize, filterHeight, filterWidth, inChannels, inHeight, inWidth, outChannels, outHeight, outWidth, strideHeight, strideWidth } = convInfo;
        // Re-read the canonical data format resolved by computeConv2DInfo.
        $dataFormat = convInfo.dataFormat;
        // Mirrored pads: a transposed convolution pads with
        // (filterSize - 1 - forwardPad) on the leading edge.
        const topPad = filterHeight - 1 - convInfo.padInfo.top;
        const leftPad = filterWidth - 1 - convInfo.padInfo.left;
        const isChannelsLast = $dataFormat === 'channelsLast';
        // Flat strides for dx and dy, chosen per layout so the inner loops
        // are layout-agnostic (channel stride is 1 for NHWC, row-major
        // otherwise).
        const xBatchStride = dx.strides[0];
        const xRowStride = isChannelsLast ? dx.strides[1] : dx.strides[2];
        const xColStride = isChannelsLast ? dx.strides[2] : 1;
        const xChannelStride = isChannelsLast ? 1 : dx.strides[1];
        const yBatchStride = dyStrides[0];
        const yRowStride = isChannelsLast ? dyStrides[1] : dyStrides[2];
        const yColStride = isChannelsLast ? dyStrides[2] : 1;
        const yChannelStride = isChannelsLast ? 1 : dyStrides[1];
        for (let b = 0; b < batchSize; ++b) {
            for (let d1 = 0; d1 < inChannels; ++d1) {
                for (let xR = 0; xR < inHeight; ++xR) {
                    const xRCorner = xR - topPad;
                    // Range of output rows whose receptive field includes xR.
                    const xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));
                    const yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);
                    for (let xC = 0; xC < inWidth; ++xC) {
                        const xCCorner = xC - leftPad;
                        // Range of output cols whose receptive field includes xC.
                        const xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));
                        const yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);
                        let dotProd = 0;
                        for (let yR = xRMin; yR < yRMax; ++yR) {
                            const wR = yR * strideHeight - xRCorner;
                            for (let yC = xCMin; yC < yCMax; ++yC) {
                                const wC = yC * strideWidth - xCCorner;
                                const dyOffset = yBatchStride * b + yRowStride * yR + yColStride * yC;
                                // Rotated filter tap: (fh-1-wR, fw-1-wC, d1, :).
                                const fltOffset = fltS0 * (filterHeight - 1 - wR) +
                                    fltS1 * (filterWidth - 1 - wC) + fltS2 * d1;
                                for (let d2 = 0; d2 < outChannels; ++d2) {
                                    const pixel = dyValues[dyOffset + yChannelStride * d2];
                                    const weight = fltValues[fltOffset + d2];
                                    dotProd += pixel * weight;
                                }
                            }
                        }
                        const dxOffset = xBatchStride * b + xRowStride * xR +
                            xColStride * xC + xChannelStride * d1;
                        dxValues[dxOffset] = dotProd;
                    }
                }
            }
        }
        return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
    }
77425 const conv2DBackpropInputConfig$1 = {
77426 kernelName: Conv2DBackpropInput,
77427 backendName: 'cpu',
77428 kernelFunc: conv2DBackpropInput$1
77429 };
77430
77431 /**
77432 * @license
77433 * Copyright 2020 Google LLC. All Rights Reserved.
77434 * Licensed under the Apache License, Version 2.0 (the "License");
77435 * you may not use this file except in compliance with the License.
77436 * You may obtain a copy of the License at
77437 *
77438 * http://www.apache.org/licenses/LICENSE-2.0
77439 *
77440 * Unless required by applicable law or agreed to in writing, software
77441 * distributed under the License is distributed on an "AS IS" BASIS,
77442 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77443 * See the License for the specific language governing permissions and
77444 * limitations under the License.
77445 * =============================================================================
77446 */
    /**
     * CPU kernel for Conv3D: direct (non-im2col) forward 3D convolution.
     *
     * Iterates output positions and filter taps, skipping taps that fall in
     * the padded region, and accumulates x * filter into the output buffer
     * with `+=` (the freshly allocated TensorBuffer starts zero-filled).
     * The `xOffset4 + d1` / `yOffset4 + d2` indexing makes channels the
     * innermost dimension, i.e. NDHWC layout.
     *
     * inputs: x (input volume), filter (5D filter).
     * attrs:  strides, pad, dilations.
     * Returns a TensorInfo of shape convInfo.outShape with x's dtype.
     */
    function conv3D$1(args) {
        const { inputs, backend, attrs } = args;
        const { x, filter } = inputs;
        const { strides, pad, dilations } = attrs;
        assertNotComplex$1([x, filter], 'conv3d');
        const convInfo = computeConv3DInfo(x.shape, filter.shape, strides, dilations, pad);
        const { filterDepth, filterHeight, filterWidth, dilationDepth, dilationHeight, dilationWidth, padInfo } = convInfo;
        const padFront = padInfo.front;
        const padLeft = padInfo.left;
        const padTop = padInfo.top;
        const y = new TensorBuffer(convInfo.outShape, x.dtype);
        const xVals = backend.data.get(x.dataId).values;
        const wVals = backend.data.get(filter.dataId).values;
        const yVals = y.values;
        const xStrides = computeStrides(x.shape);
        const filterStrides = computeStrides(filter.shape);
        for (let b = 0; b < convInfo.batchSize; ++b) {
            // Partial flat offsets are built up one loop level at a time so
            // the innermost loops only do additions.
            const xOffset1 = b * xStrides[0];
            const yOffset1 = b * y.strides[0];
            for (let yF = 0; yF < convInfo.outDepth; ++yF) {
                const yOffset2 = yOffset1 + yF * y.strides[1];
                const xFCorner = yF * convInfo.strideDepth - padFront;
                for (let wF = 0; wF < filterDepth; ++wF) {
                    const xF = xFCorner + wF * dilationDepth;
                    // Tap lies in the front/back padding: contributes zero.
                    if (xF < 0 || xF >= convInfo.inDepth) {
                        continue;
                    }
                    const wOffset1 = wF * filterStrides[0];
                    const xOffset2 = xOffset1 + xF * xStrides[1];
                    for (let yR = 0; yR < convInfo.outHeight; ++yR) {
                        const yOffset3 = yOffset2 + yR * y.strides[2];
                        const xRCorner = yR * convInfo.strideHeight - padTop;
                        for (let wR = 0; wR < filterHeight; ++wR) {
                            const xR = xRCorner + wR * dilationHeight;
                            // Tap lies in the top/bottom padding.
                            if (xR < 0 || xR >= convInfo.inHeight) {
                                continue;
                            }
                            const wOffset2 = wOffset1 + wR * filterStrides[1];
                            const xOffset3 = xOffset2 + xR * xStrides[2];
                            for (let yC = 0; yC < convInfo.outWidth; ++yC) {
                                const yOffset4 = yOffset3 + yC * convInfo.outChannels;
                                const xCCorner = yC * convInfo.strideWidth - padLeft;
                                for (let wC = 0; wC < filterWidth; ++wC) {
                                    const xC = xCCorner + wC * dilationWidth;
                                    // Tap lies in the left/right padding.
                                    if (xC < 0 || xC >= convInfo.inWidth) {
                                        continue;
                                    }
                                    const wOffset3 = wOffset2 + wC * filterStrides[2];
                                    const xOffset4 = xOffset3 + xC * convInfo.inChannels;
                                    // wOffset4 walks the filter's [inChannels, outChannels]
                                    // tail one input channel (= outChannels values) at a time.
                                    let wOffset4 = wOffset3;
                                    for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
                                        const xVal = xVals[xOffset4 + d1];
                                        for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
                                            yVals[yOffset4 + d2] += xVal * wVals[wOffset4 + d2];
                                        }
                                        wOffset4 += convInfo.outChannels;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
        return backend.makeTensorInfo(y.shape, y.dtype, y.values);
    }
77513 const conv3DConfig$1 = {
77514 kernelName: Conv3D$1,
77515 backendName: 'cpu',
77516 kernelFunc: conv3D$1
77517 };
77518
77519 /**
77520 * @license
77521 * Copyright 2020 Google LLC. All Rights Reserved.
77522 * Licensed under the Apache License, Version 2.0 (the "License");
77523 * you may not use this file except in compliance with the License.
77524 * You may obtain a copy of the License at
77525 *
77526 * http://www.apache.org/licenses/LICENSE-2.0
77527 *
77528 * Unless required by applicable law or agreed to in writing, software
77529 * distributed under the License is distributed on an "AS IS" BASIS,
77530 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77531 * See the License for the specific language governing permissions and
77532 * limitations under the License.
77533 * =============================================================================
77534 */
    /**
     * CPU kernel for Conv3DBackpropFilterV2: the gradient of a 3D
     * convolution with respect to its filter.
     *
     * For each filter tap (wF, wR, wC) and channel pair (d1, d2), sums
     * x * dy over all batches and over the output positions whose receptive
     * field includes that tap. The y*Min/y*Max bounds clip away output
     * positions that would read padded input. Dilations are fixed at 1, and
     * the `xOffset4 + d1` / `yOffset4 + d2` indexing implies channels-last
     * (NDHWC) layout for both x and dy.
     *
     * inputs: x (forward input), dy (output gradient).
     * attrs:  strides, pad, filterShape.
     * Returns a float32 TensorInfo of shape convInfo.filterShape.
     */
    function conv3DBackpropFilterV2$1(args) {
        const { inputs, backend, attrs } = args;
        const { x, dy } = inputs;
        const { strides, pad, filterShape } = attrs;
        assertNotComplex$1([x, dy], 'conv3dBackpropFilterV2');
        const xStrides = computeStrides(x.shape);
        const dyStrides = computeStrides(dy.shape);
        const convInfo = computeConv3DInfo(x.shape, filterShape, strides, 1 /* dilations */, pad);
        const strideDepth = convInfo.strideDepth;
        const strideHeight = convInfo.strideHeight;
        const strideWidth = convInfo.strideWidth;
        const filterDepth = convInfo.filterDepth;
        const filterHeight = convInfo.filterHeight;
        const filterWidth = convInfo.filterWidth;
        const dw = new TensorBuffer(convInfo.filterShape, 'float32');
        const dwValues = dw.values;
        const [dwS0, dwS1, dwS2, dwS3] = dw.strides;
        const dyValues = backend.data.get(dy.dataId).values;
        const [dyS0, dyS1, dyS2, dyS3] = dyStrides;
        const xValues = backend.data.get(x.dataId).values;
        const [xS0, xS1, xS2, xS3] = xStrides;
        const frontPad = convInfo.padInfo.front;
        const leftPad = convInfo.padInfo.left;
        const topPad = convInfo.padInfo.top;
        for (let wF = 0; wF < filterDepth; ++wF) {
            // Output-depth range for which tap wF reads inside the
            // (unpadded) input volume.
            const yFMin = Math.max(0, Math.ceil((frontPad - wF) / strideDepth));
            const yFMax = Math.min(convInfo.outDepth, (convInfo.inDepth + frontPad - wF) / strideDepth);
            const wOffset1 = wF * dwS0;
            for (let wR = 0; wR < filterHeight; ++wR) {
                // Same clipping for the height dimension.
                const yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
                const yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);
                const wOffset2 = wR * dwS1 + wOffset1;
                for (let wC = 0; wC < filterWidth; ++wC) {
                    // Same clipping for the width dimension.
                    const yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));
                    const yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);
                    const wOffset3 = wC * dwS2 + wOffset2;
                    for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
                        const wOffset4 = d1 * dwS3 + wOffset3;
                        for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
                            let dotProd = 0;
                            for (let b = 0; b < convInfo.batchSize; ++b) {
                                const xOffset1 = b * xS0;
                                const yOffset1 = b * dyS0;
                                for (let yF = yFMin; yF < yFMax; ++yF) {
                                    // Input coordinate read by tap wF at output depth yF.
                                    const xF = wF + yF * strideDepth - frontPad;
                                    const xOffset2 = xF * xS1 + xOffset1;
                                    const yOffset2 = yF * dyS1 + yOffset1;
                                    for (let yR = yRMin; yR < yRMax; ++yR) {
                                        const xR = wR + yR * strideHeight - topPad;
                                        const xOffset3 = xR * xS2 + xOffset2;
                                        const yOffset3 = yR * dyS2 + yOffset2;
                                        for (let yC = yCMin; yC < yCMax; ++yC) {
                                            const xC = wC + yC * strideWidth - leftPad;
                                            const xOffset4 = xC * xS3 + xOffset3;
                                            const yOffset4 = yC * dyS3 + yOffset3;
                                            dotProd += xValues[xOffset4 + d1] * dyValues[yOffset4 + d2];
                                        }
                                    }
                                }
                            }
                            dwValues[wOffset4 + d2] = dotProd;
                        }
                    }
                }
            }
        }
        return backend.makeTensorInfo(dw.shape, dw.dtype, dw.values);
    }
77603 const conv3DBackpropFilterV2Config$1 = {
77604 kernelName: Conv3DBackpropFilterV2,
77605 backendName: 'cpu',
77606 kernelFunc: conv3DBackpropFilterV2$1
77607 };
77608
77609 /**
77610 * @license
77611 * Copyright 2020 Google LLC. All Rights Reserved.
77612 * Licensed under the Apache License, Version 2.0 (the "License");
77613 * you may not use this file except in compliance with the License.
77614 * You may obtain a copy of the License at
77615 *
77616 * http://www.apache.org/licenses/LICENSE-2.0
77617 *
77618 * Unless required by applicable law or agreed to in writing, software
77619 * distributed under the License is distributed on an "AS IS" BASIS,
77620 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77621 * See the License for the specific language governing permissions and
77622 * limitations under the License.
77623 * =============================================================================
77624 */
    /**
     * CPU kernel for Conv3DBackpropInputV2: the gradient of a 3D convolution
     * with respect to its input.
     *
     * 3D analogue of conv2DBackpropInput: for each input voxel, accumulates
     * dy * filter over every output position whose receptive field covered
     * it, reading the filter in fully mirrored order
     * (`filterDepth - 1 - wF`, `filterHeight - 1 - wR`,
     * `filterWidth - 1 - wC`). Dilations are fixed at 1. The trailing
     * `... + d1` / `... + d2` flat indexing implies channels-last (NDHWC)
     * layout.
     *
     * inputs: dy (output gradient), filter (forward-pass 5D filter).
     * attrs:  pad, strides, inputShape.
     * Returns a float32 TensorInfo of shape convInfo.inShape.
     */
    function conv3DBackpropInputV2(args) {
        const { inputs, backend, attrs } = args;
        const { dy, filter } = inputs;
        const { pad, strides, inputShape } = attrs;
        assertNotComplex$1([dy], 'conv3dBackpropInputV2');
        const dyStrides = computeStrides(dy.shape);
        const filterStrides = computeStrides(filter.shape);
        const convInfo = computeConv3DInfo(inputShape, filter.shape, strides, 1 /* dilations */, pad);
        const dx = new TensorBuffer(convInfo.inShape, 'float32');
        const dxValues = dx.values;
        const [dxS0, dxS1, dxS2, dxS3] = dx.strides;
        const dyValues = backend.data.get(dy.dataId).values;
        const [dyS0, dyS1, dyS2, dyS3] = dyStrides;
        const fltValues = backend.data.get(filter.dataId).values;
        const [fltS0, fltS1, fltS2, fltS3] = filterStrides;
        const { batchSize, filterDepth, filterHeight, filterWidth, inChannels, inDepth, inHeight, inWidth, outChannels, outDepth, outHeight, outWidth, strideDepth, strideHeight, strideWidth } = convInfo;
        // Mirrored pads for the transposed convolution:
        // (filterSize - 1 - forwardPad) on each leading edge.
        const frontPad = filterDepth - 1 - convInfo.padInfo.front;
        const topPad = filterHeight - 1 - convInfo.padInfo.top;
        const leftPad = filterWidth - 1 - convInfo.padInfo.left;
        for (let b = 0; b < batchSize; ++b) {
            for (let d1 = 0; d1 < inChannels; ++d1) {
                // Frames of depth
                for (let xF = 0; xF < inDepth; ++xF) {
                    const xFCorner = xF - frontPad;
                    // Output-depth range whose receptive field includes xF.
                    const xFMin = Math.max(0, Math.ceil(xFCorner / strideDepth));
                    const yFMax = Math.min(outDepth, (filterDepth + xFCorner) / strideDepth);
                    // Rows as per standard 2d matrix notation
                    for (let xR = 0; xR < inHeight; ++xR) {
                        const xRCorner = xR - topPad;
                        const xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));
                        const yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);
                        // Columns as per standard 2d matrix notation
                        for (let xC = 0; xC < inWidth; ++xC) {
                            const xCCorner = xC - leftPad;
                            const xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));
                            const yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);
                            let dotProd = 0;
                            for (let yF = xFMin; yF < yFMax; ++yF) {
                                const wF = yF * strideDepth - xFCorner;
                                for (let yR = xRMin; yR < yRMax; ++yR) {
                                    const wR = yR * strideHeight - xRCorner;
                                    for (let yC = xCMin; yC < yCMax; ++yC) {
                                        const wC = yC * strideWidth - xCCorner;
                                        const dyOffset = dyS0 * b + dyS1 * yF + dyS2 * yR + dyS3 * yC;
                                        // Mirrored filter tap for the transposed conv.
                                        const fltOffset = fltS0 * (filterDepth - 1 - wF) +
                                            fltS1 * (filterHeight - 1 - wR) +
                                            fltS2 * (filterWidth - 1 - wC) + fltS3 * d1;
                                        for (let d2 = 0; d2 < outChannels; ++d2) {
                                            const pixel = dyValues[dyOffset + d2];
                                            const weight = fltValues[fltOffset + d2];
                                            dotProd += pixel * weight;
                                        }
                                    }
                                }
                            }
                            dxValues[dxS0 * b + dxS1 * xF + dxS2 * xR + dxS3 * xC + d1] =
                                dotProd;
                        }
                    }
                }
            }
        }
        return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
    }
77689 const conv3DBackpropInputV2Config = {
77690 kernelName: Conv3DBackpropInputV2,
77691 backendName: 'cpu',
77692 kernelFunc: conv3DBackpropInputV2
77693 };
77694
77695 /**
77696 * @license
77697 * Copyright 2020 Google LLC. All Rights Reserved.
77698 * Licensed under the Apache License, Version 2.0 (the "License");
77699 * you may not use this file except in compliance with the License.
77700 * You may obtain a copy of the License at
77701 *
77702 * http://www.apache.org/licenses/LICENSE-2.0
77703 *
77704 * Unless required by applicable law or agreed to in writing, software
77705 * distributed under the License is distributed on an "AS IS" BASIS,
77706 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77707 * See the License for the specific language governing permissions and
77708 * limitations under the License.
77709 * =============================================================================
77710 */
77711 const cos$1 = unaryKernelFunc$1(Cos, (xi) => Math.cos(xi));
77712 const cosConfig$1 = {
77713 kernelName: Cos,
77714 backendName: 'cpu',
77715 kernelFunc: cos$1,
77716 };
77717
77718 /**
77719 * @license
77720 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
77722 * you may not use this file except in compliance with the License.
77723 * You may obtain a copy of the License at
77724 *
77725 * http://www.apache.org/licenses/LICENSE-2.0
77726 *
77727 * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
77729 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77730 * See the License for the specific language governing permissions and
77731 * limitations under the License.
77732 * =============================================================================
77733 */
77734 const cosh$1 = unaryKernelFunc$1(Cosh, (xi) => Math.cosh(xi));
77735 const coshConfig$1 = {
77736 kernelName: Cosh,
77737 backendName: 'cpu',
77738 kernelFunc: cosh$1,
77739 };
77740
77741 /**
77742 * @license
77743 * Copyright 2020 Google LLC. All Rights Reserved.
77744 * Licensed under the Apache License, Version 2.0 (the "License");
77745 * you may not use this file except in compliance with the License.
77746 * You may obtain a copy of the License at
77747 *
77748 * http://www.apache.org/licenses/LICENSE-2.0
77749 *
77750 * Unless required by applicable law or agreed to in writing, software
77751 * distributed under the License is distributed on an "AS IS" BASIS,
77752 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77753 * See the License for the specific language governing permissions and
77754 * limitations under the License.
77755 * =============================================================================
77756 */
    /**
     * CPU kernel for CropAndResize: extracts `numBoxes` crops from a batch of
     * NHWC images and resizes each to [cropHeight, cropWidth] using either
     * bilinear interpolation or nearest-neighbor sampling.
     *
     * Each row of `boxes` holds normalized coordinates [y1, x1, y2, x2];
     * `boxInd[b]` selects which image in the batch box b crops from. Sample
     * points falling outside the source image are filled with
     * `extrapolationValue`.
     *
     * NOTE(review): only the upper bound of boxInd is checked here
     * (`bInd >= batch`); negative indices are not rejected — presumably
     * validated upstream, confirm against callers.
     */
    function cropAndResize$1(args) {
        const { inputs, backend, attrs } = args;
        const { image, boxes, boxInd } = inputs;
        const { cropSize, method, extrapolationValue } = attrs;
        const [batch, imageHeight, imageWidth, numChannels] = image.shape;
        const numBoxes = boxes.shape[0];
        const [cropHeight, cropWidth] = cropSize;
        const output = buffer([numBoxes, cropHeight, cropWidth, numChannels], 'float32');
        const boxVals = backend.data.get(boxes.dataId).values;
        const boxIndVals = backend.data.get(boxInd.dataId).values;
        const imageVals = backend.data.get(image.dataId).values;
        const inStride = computeStrides(image.shape); // to calculate flat indexes into image
        const outStride = computeStrides(output.shape); // to calculate flat indexes into output
        // Reference implementation
        // tslint:disable-next-line:max-line-length
        // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/crop_and_resize_op.cc
        for (let b = 0; b < numBoxes; b++) {
            const startInd = b * 4;
            const y1 = boxVals[startInd];
            const x1 = boxVals[startInd + 1];
            const y2 = boxVals[startInd + 2];
            const x2 = boxVals[startInd + 3];
            const bInd = boxIndVals[b];
            if (bInd >= batch) {
                continue;
            }
            // Scale factors mapping crop pixels to image pixels; 0 when the
            // crop dimension is degenerate (a single row/column).
            const heightScale = (cropHeight > 1) ? (y2 - y1) * (imageHeight - 1) / (cropHeight - 1) : 0;
            const widthScale = (cropWidth > 1) ? (x2 - x1) * (imageWidth - 1) / (cropWidth - 1) : 0;
            for (let y = 0; y < cropHeight; y++) {
                // Source row (fractional); a 1-row crop samples the box's
                // vertical center.
                const yInd = (cropHeight > 1) ?
                    y1 * (imageHeight - 1) + y * (heightScale) :
                    0.5 * (y1 + y2) * (imageHeight - 1);
                // Entire output row falls outside the image: extrapolate.
                if (yInd < 0 || yInd > imageHeight - 1) {
                    for (let x = 0; x < cropWidth; x++) {
                        for (let c = 0; c < numChannels; c++) {
                            const ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
                            output.values[ind] = extrapolationValue;
                        }
                    }
                    continue;
                }
                if (method === 'bilinear') {
                    const topInd = Math.floor(yInd);
                    const bottomInd = Math.ceil(yInd);
                    const yLerp = yInd - topInd;
                    for (let x = 0; x < cropWidth; x++) {
                        const xInd = (cropWidth > 1) ?
                            x1 * (imageWidth - 1) + x * widthScale :
                            0.5 * (x1 + x2) * (imageWidth - 1);
                        // Horizontal sample outside the image: extrapolate
                        // this output pixel.
                        if (xInd < 0 || xInd > imageWidth - 1) {
                            for (let c = 0; c < numChannels; c++) {
                                const ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
                                output.values[ind] = extrapolationValue;
                            }
                            continue;
                        }
                        const leftInd = Math.floor(xInd);
                        const rightInd = Math.ceil(xInd);
                        const xLerp = xInd - leftInd;
                        for (let c = 0; c < numChannels; c++) {
                            // Gather the four neighbors and lerp horizontally
                            // then vertically.
                            let ind = c + leftInd * inStride[2] + topInd * inStride[1] +
                                bInd * inStride[0];
                            const topLeft = imageVals[ind];
                            ind = c + rightInd * inStride[2] + topInd * inStride[1] +
                                bInd * inStride[0];
                            const topRight = imageVals[ind];
                            ind = c + leftInd * inStride[2] + bottomInd * inStride[1] +
                                bInd * inStride[0];
                            const bottomLeft = imageVals[ind];
                            ind = c + rightInd * inStride[2] + bottomInd * inStride[1] +
                                bInd * inStride[0];
                            const bottomRight = imageVals[ind];
                            const top = topLeft + (topRight - topLeft) * xLerp;
                            const bottom = bottomLeft + (bottomRight - bottomLeft) * xLerp;
                            ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
                            output.values[ind] = top + ((bottom - top) * yLerp);
                        }
                    }
                }
                else { // method == "nearest"
                    for (let x = 0; x < cropWidth; ++x) {
                        const xInd = (cropWidth > 1) ?
                            x1 * (imageWidth - 1) + x * widthScale :
                            0.5 * (x1 + x2) * (imageWidth - 1);
                        if (xInd < 0 || xInd > imageWidth - 1) {
                            for (let c = 0; c < numChannels; c++) {
                                const ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
                                output.values[ind] = extrapolationValue;
                            }
                            continue;
                        }
                        // Round to the nearest source pixel and copy.
                        const closestX = Math.round(xInd);
                        const closestY = Math.round(yInd);
                        for (let c = 0; c < numChannels; c++) {
                            const inInd = c + closestX * inStride[2] + closestY * inStride[1] +
                                bInd * inStride[0];
                            const outInd = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
                            output.values[outInd] = imageVals[inInd];
                        }
                    }
                }
            }
        }
        return backend.makeTensorInfo(output.shape, output.dtype, output.values);
    }
77862 const cropAndResizeConfig$1 = {
77863 kernelName: CropAndResize,
77864 backendName: 'cpu',
77865 kernelFunc: cropAndResize$1
77866 };
77867
77868 /**
77869 * @license
77870 * Copyright 2022 Google LLC. All Rights Reserved.
77871 * Licensed under the Apache License, Version 2.0 (the "License");
77872 * you may not use this file except in compliance with the License.
77873 * You may obtain a copy of the License at
77874 *
77875 * http://www.apache.org/licenses/LICENSE-2.0
77876 *
77877 * Unless required by applicable law or agreed to in writing, software
77878 * distributed under the License is distributed on an "AS IS" BASIS,
77879 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77880 * See the License for the specific language governing permissions and
77881 * limitations under the License.
77882 * =============================================================================
77883 */
    /**
     * CPU kernel for Cumprod: cumulative product along `axis`.
     *
     * If `axis` is not already the innermost dimension, the input is
     * transposed so it is, the scan runs along the last dimension of the
     * flat data, and the result is transposed back (intermediates are
     * disposed). `exclusive` shifts the scan so element j holds the product
     * of the elements before it (starting at 1); `reverse` scans each row
     * from its far end. The result dtype is `upcastType($x.dtype, 'int32')`,
     * so sub-int32 inputs widen to int32.
     */
    function cumprod$1(args) {
        const { inputs, backend, attrs } = args;
        const { x } = inputs;
        const { axis, exclusive, reverse } = attrs;
        assertNotComplex$1(x, 'cumprod');
        const permutation = getAxesPermutation([axis], x.shape.length);
        let $x = x;
        if (permutation != null) {
            // Move `axis` to the innermost position so the scan is a simple
            // pass over contiguous rows.
            $x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutation } });
        }
        const permutedAxis = getInnerMostAxes(1, x.shape.length)[0];
        // Sanity check: after the (optional) transpose, the scan axis must be
        // the last one.
        if (permutedAxis !== $x.shape.length - 1) {
            throw new Error(`backend.cumprod in CPU expects an inner-most ` +
                `axis=${$x.shape.length - 1} but got axis=${permutedAxis}`);
        }
        const resultDtype = upcastType($x.dtype, 'int32');
        // Ones-initialized so untouched slots are the multiplicative identity.
        const vals = makeOnesTypedArray(sizeFromShape($x.shape), resultDtype);
        const aVals = backend.data.get($x.dataId).values;
        const finalDim = $x.shape[$x.shape.length - 1];
        // Maps (row start, position) to a flat index, walking backwards when
        // `reverse` is set.
        const indexAdjuster = reverse ?
            (i, j) => i + finalDim - j - 1 :
            (i, j) => i + j;
        for (let i = 0; i < aVals.length; i += finalDim) {
            for (let j = 0; j < finalDim; j++) {
                const idx = indexAdjuster(i, j);
                if (j === 0) {
                    // First element of the scan: identity (1) when exclusive.
                    vals[idx] = exclusive ? 1 : aVals[idx];
                }
                else {
                    const prevIdx = indexAdjuster(i, j - 1);
                    vals[idx] = exclusive ? aVals[prevIdx] * vals[prevIdx] :
                        aVals[idx] * vals[prevIdx];
                }
            }
        }
        const result = backend.makeTensorInfo($x.shape, resultDtype, vals);
        if (permutation != null) {
            // Undo the earlier transpose and release the intermediates.
            const reversePermutation = getUndoAxesPermutation(permutation);
            const reverseTransposedResult = transpose$1({ inputs: { x: result }, backend, attrs: { perm: reversePermutation } });
            backend.disposeIntermediateTensorInfo(result);
            backend.disposeIntermediateTensorInfo($x);
            return reverseTransposedResult;
        }
        return result;
    }
77929 const cumprodConfig$1 = {
77930 kernelName: Cumprod,
77931 backendName: 'cpu',
77932 kernelFunc: cumprod$1
77933 };
77934
77935 /**
77936 * @license
77937 * Copyright 2020 Google LLC. All Rights Reserved.
77938 * Licensed under the Apache License, Version 2.0 (the "License");
77939 * you may not use this file except in compliance with the License.
77940 * You may obtain a copy of the License at
77941 *
77942 * http://www.apache.org/licenses/LICENSE-2.0
77943 *
77944 * Unless required by applicable law or agreed to in writing, software
77945 * distributed under the License is distributed on an "AS IS" BASIS,
77946 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
77947 * See the License for the specific language governing permissions and
77948 * limitations under the License.
77949 * =============================================================================
77950 */
    /**
     * CPU kernel for Cumsum: cumulative sum along `axis`.
     *
     * Mirrors cumprod$1: if `axis` is not the innermost dimension the input
     * is transposed, scanned along its last dimension, and transposed back
     * (intermediates are disposed). `exclusive` shifts the scan so element j
     * holds the sum of the elements before it (starting at 0); `reverse`
     * scans each row from its far end. The result dtype is
     * `upcastType($x.dtype, 'int32')`, so sub-int32 inputs widen to int32.
     */
    function cumsum$1(args) {
        const { inputs, backend, attrs } = args;
        const { x } = inputs;
        const { axis, exclusive, reverse } = attrs;
        assertNotComplex$1(x, 'cumsum');
        const permutation = getAxesPermutation([axis], x.shape.length);
        let $x = x;
        if (permutation != null) {
            // Move `axis` to the innermost position so the scan is a simple
            // pass over contiguous rows.
            $x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutation } });
        }
        const permutedAxis = getInnerMostAxes(1, x.shape.length)[0];
        // Sanity check: after the (optional) transpose, the scan axis must be
        // the last one.
        if (permutedAxis !== $x.shape.length - 1) {
            throw new Error(`backend.cumsum in CPU expects an inner-most ` +
                `axis=${$x.shape.length - 1} but got axis=${permutedAxis}`);
        }
        const resultDtype = upcastType($x.dtype, 'int32');
        // Zero-initialized so untouched slots are the additive identity.
        const vals = makeZerosTypedArray(sizeFromShape($x.shape), resultDtype);
        const aVals = backend.data.get($x.dataId).values;
        const finalDim = $x.shape[$x.shape.length - 1];
        // Maps (row start, position) to a flat index, walking backwards when
        // `reverse` is set.
        const indexAdjuster = reverse ?
            (i, j) => i + finalDim - j - 1 :
            (i, j) => i + j;
        for (let i = 0; i < aVals.length; i += finalDim) {
            for (let j = 0; j < finalDim; j++) {
                const idx = indexAdjuster(i, j);
                if (j === 0) {
                    // First element of the scan: identity (0) when exclusive.
                    vals[idx] = exclusive ? 0 : aVals[idx];
                }
                else {
                    const prevIdx = indexAdjuster(i, j - 1);
                    vals[idx] = exclusive ? aVals[prevIdx] + vals[prevIdx] :
                        aVals[idx] + vals[prevIdx];
                }
            }
        }
        const result = backend.makeTensorInfo($x.shape, resultDtype, vals);
        if (permutation != null) {
            // Undo the earlier transpose and release the intermediates.
            const reversePermutation = getUndoAxesPermutation(permutation);
            const reverseTransposedResult = transpose$1({ inputs: { x: result }, backend, attrs: { perm: reversePermutation } });
            backend.disposeIntermediateTensorInfo(result);
            backend.disposeIntermediateTensorInfo($x);
            return reverseTransposedResult;
        }
        return result;
    }
77996 const cumsumConfig$1 = {
77997 kernelName: Cumsum,
77998 backendName: 'cpu',
77999 kernelFunc: cumsum$1
78000 };
78001
78002 /**
78003 * @license
78004 * Copyright 2020 Google LLC. All Rights Reserved.
78005 * Licensed under the Apache License, Version 2.0 (the "License");
78006 * you may not use this file except in compliance with the License.
78007 * You may obtain a copy of the License at
78008 *
78009 * http://www.apache.org/licenses/LICENSE-2.0
78010 *
78011 * Unless required by applicable law or agreed to in writing, software
78012 * distributed under the License is distributed on an "AS IS" BASIS,
78013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78014 * See the License for the specific language governing permissions and
78015 * limitations under the License.
78016 * =============================================================================
78017 */
78018 function denseBincount$1(args) {
78019 const { inputs, backend, attrs } = args;
78020 const { x, weights } = inputs;
78021 const { size, binaryOutput } = attrs;
78022 if (x.shape.length === 1) {
78023 const xVals = backend.data.get(x.dataId).values;
78024 const weightsVals = backend.data.get(weights.dataId).values;
78025 const outVals = bincountImpl(xVals, weightsVals, weights.dtype, weights.shape, size);
78026 return backend.makeTensorInfo([size], weights.dtype, outVals);
78027 }
78028 else if (x.shape.length === 2) {
78029 const xBuf = backend.bufferSync(x);
78030 const weightsBuf = backend.bufferSync(weights);
78031 const outBuf = bincountReduceImpl(xBuf, weightsBuf, size, binaryOutput);
78032 return backend.makeTensorInfo(outBuf.shape, weights.dtype, outBuf.values);
78033 }
78034 throw new Error(`Error in denseBincount: input must be at most rank 2, but got rank` +
78035 `${x.shape.length}.`);
78036 }
78037 const denseBincountConfig$1 = {
78038 kernelName: DenseBincount,
78039 backendName: 'cpu',
78040 kernelFunc: denseBincount$1
78041 };
78042
78043 /**
78044 * @license
78045 * Copyright 2020 Google LLC. All Rights Reserved.
78046 * Licensed under the Apache License, Version 2.0 (the "License");
78047 * you may not use this file except in compliance with the License.
78048 * You may obtain a copy of the License at
78049 *
78050 * http://www.apache.org/licenses/LICENSE-2.0
78051 *
78052 * Unless required by applicable law or agreed to in writing, software
78053 * distributed under the License is distributed on an "AS IS" BASIS,
78054 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78055 * See the License for the specific language governing permissions and
78056 * limitations under the License.
78057 * =============================================================================
78058 */
78059 function depthToSpace$1(args) {
78060 const { inputs, backend, attrs } = args;
78061 const { x } = inputs;
78062 const { blockSize, dataFormat } = attrs;
78063 assert$1(dataFormat === 'NHWC', () => `Only NHWC dataFormat supported on CPU for depthToSpace. Got ${dataFormat}`);
78064 const batchSize = x.shape[0];
78065 const inputHeight = x.shape[1];
78066 const inputWidth = x.shape[2];
78067 const inputDepth = x.shape[3];
78068 const outputHeight = inputHeight * blockSize;
78069 const outputWidth = inputWidth * blockSize;
78070 const outputDepth = inputDepth / (blockSize * blockSize);
78071 const xValues = backend.data.get(x.dataId).values;
78072 const result = new Float32Array(batchSize * outputHeight * outputWidth * outputDepth);
78073 let outputIdx = 0;
78074 for (let b = 0; b < batchSize; ++b) {
78075 for (let h = 0; h < outputHeight; ++h) {
78076 const inH = Math.floor(h / blockSize);
78077 const offsetH = (h % blockSize);
78078 for (let w = 0; w < outputWidth; ++w) {
78079 const inW = Math.floor(w / blockSize);
78080 const offsetW = (w % blockSize);
78081 const offsetD = (offsetH * blockSize + offsetW) * outputDepth;
78082 for (let d = 0; d < outputDepth; ++d) {
78083 const inD = d + offsetD;
78084 const inputIdx = inD + inputDepth * (inW + inputWidth * (inH + inputHeight * b));
78085 result[outputIdx++] = xValues[inputIdx];
78086 }
78087 }
78088 }
78089 }
78090 return backend.makeTensorInfo([batchSize, outputHeight, outputWidth, outputDepth], x.dtype, result);
78091 }
78092 const depthToSpaceConfig$1 = {
78093 kernelName: DepthToSpace,
78094 backendName: 'cpu',
78095 kernelFunc: depthToSpace$1
78096 };
78097
78098 /**
78099 * @license
78100 * Copyright 2020 Google LLC. All Rights Reserved.
78101 * Licensed under the Apache License, Version 2.0 (the "License");
78102 * you may not use this file except in compliance with the License.
78103 * You may obtain a copy of the License at
78104 *
78105 * http://www.apache.org/licenses/LICENSE-2.0
78106 *
78107 * Unless required by applicable law or agreed to in writing, software
78108 * distributed under the License is distributed on an "AS IS" BASIS,
78109 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78110 * See the License for the specific language governing permissions and
78111 * limitations under the License.
78112 * =============================================================================
78113 */
    /**
     * CPU kernel for DepthwiseConv2dNative (NHWC).
     *
     * Each input channel `d1` is convolved with its own set of `chMul`
     * filters, producing `inChannels * chMul` output channels. Accumulates
     * with `+=` into a freshly allocated TensorBuffer (assumes its backing
     * typed array starts zeroed); filter taps that fall in the padded
     * region are skipped, contributing zero.
     */
    function depthwiseConv2dNative$1(args) {
        const { inputs, backend, attrs } = args;
        const { x, filter } = inputs;
        const { strides, pad, dilations, dimRoundingMode } = attrs;
        assertNotComplex$1([x, filter], 'depthwiseConv2DNative');
        const xStrides = computeStrides(x.shape);
        const filterStrides = computeStrides(filter.shape);
        // Default to no dilation when the attribute is null/undefined.
        let $dilations = dilations;
        if ($dilations == null) {
            $dilations = [1, 1];
        }
        assert$1(eitherStridesOrDilationsAreOne(strides, $dilations), () => 'Error in depthwiseConv2d: Either strides or dilations must be ' +
            `1. Got strides ${strides} and dilations '${$dilations}'`);
        const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true /* depthwise */);
        const { filterHeight, filterWidth, dilationHeight, dilationWidth, padInfo } = convInfo;
        const padLeft = padInfo.left;
        const padTop = padInfo.top;
        // Channel multiplier: number of output channels per input channel.
        const chMul = convInfo.outChannels / convInfo.inChannels;
        const y = new TensorBuffer(convInfo.outShape, x.dtype);
        const xVals = backend.data.get(x.dataId).values;
        const wVals = backend.data.get(filter.dataId).values;
        const yVals = y.values;
        for (let b = 0; b < convInfo.batchSize; ++b) {
            // Flat offsets of the current batch in the input and output buffers.
            const xOffset1 = b * xStrides[0];
            const yOffset1 = b * y.strides[0];
            for (let yR = 0; yR < convInfo.outHeight; ++yR) {
                const yOffset2 = yOffset1 + yR * y.strides[1];
                // Top input row covered by this output row (negative inside padding).
                const xRCorner = yR * convInfo.strideHeight - padTop;
                for (let wR = 0; wR < filterHeight; ++wR) {
                    const xR = xRCorner + wR * dilationHeight;
                    if (xR < 0 || xR >= convInfo.inHeight) {
                        // Row falls in the padding; contributes zero.
                        continue;
                    }
                    const wOffset1 = wR * filterStrides[0];
                    const xOffset2 = xOffset1 + xR * xStrides[1];
                    for (let yC = 0; yC < convInfo.outWidth; ++yC) {
                        const yOffset3 = yOffset2 + yC * y.strides[2];
                        const xCCorner = yC * convInfo.strideWidth - padLeft;
                        for (let wC = 0; wC < filterWidth; ++wC) {
                            const xC = xCCorner + wC * dilationWidth;
                            if (xC < 0 || xC >= convInfo.inWidth) {
                                // Column falls in the padding; contributes zero.
                                continue;
                            }
                            const wOffset2 = wOffset1 + wC * filterStrides[1];
                            const xOffset3 = xOffset2 + xC * convInfo.inChannels;
                            let yOffset4 = yOffset3;
                            let wOffset3 = wOffset2;
                            for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
                                const xVal = xVals[xOffset3 + d1];
                                // Multiply this input channel against each of its
                                // chMul filters and accumulate into the output.
                                for (let q = 0; q < chMul; ++q) {
                                    yVals[yOffset4 + q] += xVal * wVals[wOffset3 + q];
                                }
                                // Advance by chMul: both output channels and filter
                                // taps are laid out chMul-per-input-channel.
                                yOffset4 += chMul;
                                wOffset3 += chMul;
                            }
                        }
                    }
                }
            }
        }
        return backend.makeTensorInfo(y.shape, y.dtype, y.values);
    }
    // Kernel registration for DepthwiseConv2dNative on the 'cpu' backend.
    const depthwiseConv2dNativeConfig$1 = {
        kernelName: DepthwiseConv2dNative,
        backendName: 'cpu',
        kernelFunc: depthwiseConv2dNative$1
    };
78181
78182 /**
78183 * @license
78184 * Copyright 2020 Google LLC. All Rights Reserved.
78185 * Licensed under the Apache License, Version 2.0 (the "License");
78186 * you may not use this file except in compliance with the License.
78187 * You may obtain a copy of the License at
78188 *
78189 * http://www.apache.org/licenses/LICENSE-2.0
78190 *
78191 * Unless required by applicable law or agreed to in writing, software
78192 * distributed under the License is distributed on an "AS IS" BASIS,
78193 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78194 * See the License for the specific language governing permissions and
78195 * limitations under the License.
78196 * =============================================================================
78197 */
    /**
     * CPU kernel for DepthwiseConv2dNativeBackpropFilter: gradient of a
     * depthwise conv2d with respect to its filter.
     *
     * For every filter tap (wR, wC) and output channel d2, accumulates the
     * dot product of the upstream gradient `dy` with the input values that
     * this tap touched during the forward pass, over all batches and valid
     * output positions.
     */
    function depthwiseConv2dNativeBackpropFilter$1(args) {
        const { inputs, backend, attrs } = args;
        const { x, dy } = inputs;
        const { strides, dilations, pad, dimRoundingMode, filterShape } = attrs;
        assertNotComplex$1([x, dy], 'depthwiseConv2dNativeBackpropFilter');
        const convInfo = computeConv2DInfo(x.shape, filterShape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
        const { strideHeight, strideWidth, filterHeight, filterWidth } = convInfo;
        const dW = new TensorBuffer(convInfo.filterShape, 'float32');
        const leftPad = convInfo.padInfo.left;
        const topPad = convInfo.padInfo.top;
        // Channel multiplier: number of output channels per input channel.
        const chMul = convInfo.outChannels / convInfo.inChannels;
        const xVals = backend.data.get(x.dataId).values;
        const xBuf = new TensorBuffer(x.shape, x.dtype, xVals);
        const dyVals = backend.data.get(dy.dataId).values;
        const dyBuf = new TensorBuffer(dy.shape, dy.dtype, dyVals);
        for (let wR = 0; wR < filterHeight; ++wR) {
            // Range of output rows for which this filter row lands inside the input.
            const yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
            const yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);
            for (let wC = 0; wC < filterWidth; ++wC) {
                // Range of output columns for which this filter column is in bounds.
                const yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));
                const yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);
                for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
                    // Map output channel d2 to input channel d1 and multiplier index dm.
                    const d1 = Math.trunc(d2 / chMul);
                    const dm = d2 % chMul;
                    let dotProd = 0;
                    for (let b = 0; b < convInfo.batchSize; ++b) {
                        for (let yR = yRMin; yR < yRMax; ++yR) {
                            const xR = wR + yR * strideHeight - topPad;
                            for (let yC = yCMin; yC < yCMax; ++yC) {
                                const xC = wC + yC * strideWidth - leftPad;
                                dotProd += xBuf.get(b, xR, xC, d1) *
                                    dyBuf.get(b, yR, yC, d2);
                            }
                        }
                    }
                    dW.set(dotProd, wR, wC, d1, dm);
                }
            }
        }
        return backend.makeTensorInfo(dW.shape, dW.dtype, dW.values);
    }
    // Kernel registration for DepthwiseConv2dNativeBackpropFilter on 'cpu'.
    const depthwiseConv2dNativeBackpropFilterConfig$1 = {
        kernelName: DepthwiseConv2dNativeBackpropFilter,
        backendName: 'cpu',
        kernelFunc: depthwiseConv2dNativeBackpropFilter$1
    };
78244
78245 /**
78246 * @license
78247 * Copyright 2020 Google LLC. All Rights Reserved.
78248 * Licensed under the Apache License, Version 2.0 (the "License");
78249 * you may not use this file except in compliance with the License.
78250 * You may obtain a copy of the License at
78251 *
78252 * http://www.apache.org/licenses/LICENSE-2.0
78253 *
78254 * Unless required by applicable law or agreed to in writing, software
78255 * distributed under the License is distributed on an "AS IS" BASIS,
78256 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78257 * See the License for the specific language governing permissions and
78258 * limitations under the License.
78259 * =============================================================================
78260 */
    /**
     * CPU kernel for DepthwiseConv2dNativeBackpropInput: gradient of a
     * depthwise conv2d with respect to its input.
     *
     * Computed as a correlation of `dy` with the spatially flipped filter
     * (note the `filterHeight - 1 - wR` / `filterWidth - 1 - wC` indexing).
     * For each input position, the (xRMin..yRMax, xCMin..yCMax) bounds
     * restrict the sum to output positions that actually exist.
     */
    function depthwiseConv2dNativeBackpropInput$1(args) {
        const { inputs, backend, attrs } = args;
        const { dy, filter } = inputs;
        const { strides, dilations, pad, dimRoundingMode, inputShape } = attrs;
        assertNotComplex$1([dy, filter], 'depthwiseConv2DNativeBackpropInput');
        const dyStrides = computeStrides(dy.shape);
        const filterStrides = computeStrides(filter.shape);
        const convInfo = computeConv2DInfo(inputShape, filter.shape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
        const dx = new TensorBuffer(convInfo.inShape, 'float32');
        const dxValues = dx.values;
        const [dxS0, dxS1, dxS2] = dx.strides;
        const dyValues = backend.data.get(dy.dataId).values;
        const [dyS0, dyS1, dyS2] = dyStrides;
        const fltValues = backend.data.get(filter.dataId).values;
        const [fltS0, fltS1, fltS2] = filterStrides;
        const { batchSize, filterHeight, filterWidth, inChannels, inHeight, inWidth, outChannels, outHeight, outWidth, strideHeight, strideWidth } = convInfo;
        // Padding of the equivalent transposed (gradient) convolution.
        const topPad = filterHeight - 1 - convInfo.padInfo.top;
        const leftPad = filterWidth - 1 - convInfo.padInfo.left;
        // Channel multiplier: number of output channels per input channel.
        const chMul = outChannels / inChannels;
        for (let b = 0; b < batchSize; ++b) {
            for (let d1 = 0; d1 < inChannels; ++d1) {
                for (let xR = 0; xR < inHeight; ++xR) {
                    const xRCorner = xR - topPad;
                    // Output rows that contribute to this input row.
                    const xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));
                    const yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);
                    for (let xC = 0; xC < inWidth; ++xC) {
                        const xCCorner = xC - leftPad;
                        // Output columns that contribute to this input column.
                        const xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));
                        const yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);
                        let dotProd = 0;
                        for (let yR = xRMin; yR < yRMax; ++yR) {
                            const wR = yR * strideHeight - xRCorner;
                            for (let yC = xCMin; yC < yCMax; ++yC) {
                                const wC = yC * strideWidth - xCCorner;
                                const dyOffset = dyS0 * b + dyS1 * yR + dyS2 * yC;
                                // Index the filter flipped in both spatial dimensions.
                                const fltOffset = fltS0 * (filterHeight - 1 - wR) +
                                    fltS1 * (filterWidth - 1 - wC) + fltS2 * d1;
                                // Sum contributions from every multiplier channel of d1.
                                for (let dm = 0; dm < chMul; ++dm) {
                                    const d2 = d1 * chMul + dm;
                                    const pixel = dyValues[dyOffset + d2];
                                    const weight = fltValues[fltOffset + dm];
                                    dotProd += pixel * weight;
                                }
                            }
                        }
                        dxValues[dxS0 * b + dxS1 * xR + dxS2 * xC + d1] = dotProd;
                    }
                }
            }
        }
        return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
    }
    // Kernel registration for DepthwiseConv2dNativeBackpropInput on 'cpu'.
    const depthwiseConv2dNativeBackpropInputConfig$1 = {
        kernelName: DepthwiseConv2dNativeBackpropInput,
        backendName: 'cpu',
        kernelFunc: depthwiseConv2dNativeBackpropInput$1
    };
78318
78319 /**
78320 * @license
78321 * Copyright 2020 Google LLC. All Rights Reserved.
78322 * Licensed under the Apache License, Version 2.0 (the "License");
78323 * you may not use this file except in compliance with the License.
78324 * You may obtain a copy of the License at
78325 *
78326 * http://www.apache.org/licenses/LICENSE-2.0
78327 *
78328 * Unless required by applicable law or agreed to in writing, software
78329 * distributed under the License is distributed on an "AS IS" BASIS,
78330 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78331 * See the License for the specific language governing permissions and
78332 * limitations under the License.
78333 * =============================================================================
78334 */
78335 function diag$1(args) {
78336 const { inputs, backend } = args;
78337 const { x } = inputs;
78338 const xSize = sizeFromShape(x.shape);
78339 const xVals = backend.data.get(x.dataId).values;
78340 const outBuf = buffer([xSize, xSize], x.dtype);
78341 const vals = outBuf.values;
78342 for (let i = 0; i < xVals.length; i++) {
78343 vals[i * xSize + i] = xVals[i];
78344 }
78345 const outShape = [...x.shape, ...x.shape];
78346 return backend.makeTensorInfo(outShape, outBuf.dtype, outBuf.values);
78347 }
78348 const diagConfig$1 = {
78349 kernelName: Diag,
78350 backendName: 'cpu',
78351 kernelFunc: diag$1
78352 };
78353
78354 /**
78355 * @license
78356 * Copyright 2020 Google LLC. All Rights Reserved.
78357 * Licensed under the Apache License, Version 2.0 (the "License");
78358 * you may not use this file except in compliance with the License.
78359 * You may obtain a copy of the License at
78360 *
78361 * http://www.apache.org/licenses/LICENSE-2.0
78362 *
78363 * Unless required by applicable law or agreed to in writing, software
78364 * distributed under the License is distributed on an "AS IS" BASIS,
78365 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78366 * See the License for the specific language governing permissions and
78367 * limitations under the License.
78368 * =============================================================================
78369 */
    /**
     * CPU kernel for Dilation2D (NHWC): grayscale morphological dilation.
     * For each output position and channel, takes the maximum of
     * (input + filter) over the dilated filter window; window taps that
     * fall outside the input are skipped.
     */
    const dilation2DConfig$1 = {
        kernelName: Dilation2D,
        backendName: 'cpu',
        kernelFunc: ({ inputs, backend, attrs }) => {
            const { x, filter } = inputs;
            const { strides, pad, dilations } = attrs;
            const cpuBackend = backend;
            const xVals = cpuBackend.data.get(x.dataId).values;
            const xRank = x.shape.length;
            const filterVals = cpuBackend.data.get(filter.dataId).values;
            const filterRank = filter.shape.length;
            const { batchSize, inHeight, inWidth, inChannels, outHeight, outWidth, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth, outShape } = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations);
            const outSize = sizeFromShape(outShape);
            const outRank = outShape.length;
            const outputVals = getArrayFromDType(x.dtype, outSize);
            // Upsampling the input by fill in `dilation size - 1` values between each
            // input value.
            // This implementation follows the TF c++ implementation:
            // https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc
            for (let b = 0; b < batchSize; ++b) {
                for (let hOut = 0; hOut < outHeight; ++hOut) {
                    // Top-left corner of the receptive field; negative inside padding.
                    const hBeg = hOut * strideHeight - padInfo.top;
                    for (let wOut = 0; wOut < outWidth; ++wOut) {
                        const wBeg = wOut * strideWidth - padInfo.left;
                        for (let d = 0; d < inChannels; ++d) {
                            // Running maximum over the window.
                            let curVal = Number.MIN_SAFE_INTEGER;
                            for (let h = 0; h < filterHeight; ++h) {
                                const hIn = hBeg + h * dilationHeight;
                                if (hIn >= 0 && hIn < inHeight) {
                                    for (let w = 0; w < filterWidth; ++w) {
                                        const wIn = wBeg + w * dilationWidth;
                                        if (wIn >= 0 && wIn < inWidth) {
                                            const xIndex = locToIndex([b, hIn, wIn, d], xRank, computeStrides(x.shape));
                                            const filterIndex = locToIndex([h, w, d], filterRank, computeStrides(filter.shape));
                                            // Grayscale dilation combines input and
                                            // filter by addition, then takes the max.
                                            const val = xVals[xIndex] + filterVals[filterIndex];
                                            if (val > curVal) {
                                                curVal = val;
                                            }
                                        }
                                    }
                                }
                            }
                            const outputIndex = locToIndex([b, hOut, wOut, d], outRank, computeStrides(outShape));
                            outputVals[outputIndex] = curVal;
                        }
                    }
                }
            }
            const dataId = cpuBackend.write(toTypedArray(outputVals, x.dtype), outShape, x.dtype);
            return { dataId, shape: outShape, dtype: x.dtype };
        }
    };
78422
78423 /**
78424 * @license
78425 * Copyright 2020 Google LLC. All Rights Reserved.
78426 * Licensed under the Apache License, Version 2.0 (the "License");
78427 * you may not use this file except in compliance with the License.
78428 * You may obtain a copy of the License at
78429 *
78430 * http://www.apache.org/licenses/LICENSE-2.0
78431 *
78432 * Unless required by applicable law or agreed to in writing, software
78433 * distributed under the License is distributed on an "AS IS" BASIS,
78434 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78435 * See the License for the specific language governing permissions and
78436 * limitations under the License.
78437 * =============================================================================
78438 */
    /**
     * CPU kernel for Dilation2DBackpropFilter: gradient of Dilation2D with
     * respect to the filter. Re-runs the forward max search for each output
     * element and routes the upstream gradient to the winning filter tap.
     */
    const dilation2DBackpropFilterConfig = {
        kernelName: Dilation2DBackpropFilter,
        backendName: 'cpu',
        kernelFunc: ({ inputs, backend, attrs }) => {
            const { x, filter, dy } = inputs;
            const { strides, pad, dilations } = attrs;
            const cpuBackend = backend;
            // Nested-array views simplify the 4-d / 3-d indexing below.
            const $x = toNestedArray(x.shape, cpuBackend.data.get(x.dataId).values);
            const $filter = toNestedArray(filter.shape, cpuBackend.data.get(filter.dataId).values);
            const { batchSize, inHeight, inWidth, inChannels, outHeight, outWidth, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth, outShape } = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations);
            assert$1(dy.rank === outShape.length, () => `Error in ${Dilation2DBackpropFilter}, dy ` +
                `must have the same rank as output ${outShape.length}, but got ` +
                `${dy.rank}`);
            const $dy = toNestedArray(outShape, cpuBackend.data.get(dy.dataId).values);
            // The computed filter gradients has the same dimensions as the filter:
            // [filterHeight, filterWidth, depth]
            const gradients = makeZerosNestedTypedArray(filter.shape, filter.dtype);
            // In the case of multiple argmax branches, we only back-propagate along the
            // last branch, i.e., the one with largest value of `h * filter_cols + w`,
            // similarly to the max-pooling backward routines.
            // This implementation follows the TF c++ implementation:
            // https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc
            for (let b = 0; b < batchSize; ++b) {
                for (let hOut = 0; hOut < outHeight; ++hOut) {
                    const hBeg = hOut * strideHeight - padInfo.top;
                    for (let wOut = 0; wOut < outWidth; ++wOut) {
                        const wBeg = wOut * strideWidth - padInfo.left;
                        for (let d = 0; d < inChannels; ++d) {
                            // Track the filter tap (hMax, wMax) that produced the max.
                            let curVal = Number.MIN_SAFE_INTEGER;
                            let hMax = 0;
                            let wMax = 0;
                            for (let h = 0; h < filterHeight; ++h) {
                                const hIn = hBeg + h * dilationHeight;
                                if (hIn >= 0 && hIn < inHeight) {
                                    for (let w = 0; w < filterWidth; ++w) {
                                        const wIn = wBeg + w * dilationWidth;
                                        if (wIn >= 0 && wIn < inWidth) {
                                            const val = $x[b][hIn][wIn][d] + $filter[h][w][d];
                                            if (val > curVal) {
                                                curVal = val;
                                                hMax = h;
                                                wMax = w;
                                            }
                                        }
                                    }
                                }
                            }
                            // The whole upstream gradient flows to the argmax tap.
                            gradients[hMax][wMax][d] += $dy[b][hOut][wOut][d];
                        }
                    }
                }
            }
            const dataId = cpuBackend.write(toTypedArray(gradients, x.dtype), filter.shape, filter.dtype);
            return { dataId, shape: filter.shape, dtype: filter.dtype };
        }
    };
78495
78496 /**
78497 * @license
78498 * Copyright 2020 Google LLC. All Rights Reserved.
78499 * Licensed under the Apache License, Version 2.0 (the "License");
78500 * you may not use this file except in compliance with the License.
78501 * You may obtain a copy of the License at
78502 *
78503 * http://www.apache.org/licenses/LICENSE-2.0
78504 *
78505 * Unless required by applicable law or agreed to in writing, software
78506 * distributed under the License is distributed on an "AS IS" BASIS,
78507 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78508 * See the License for the specific language governing permissions and
78509 * limitations under the License.
78510 * =============================================================================
78511 */
    /**
     * CPU kernel for Dilation2DBackpropInput: gradient of Dilation2D with
     * respect to the input. Re-runs the forward max search for each output
     * element and routes the upstream gradient to the winning input pixel.
     */
    const dilation2DBackpropInputConfig = {
        kernelName: Dilation2DBackpropInput,
        backendName: 'cpu',
        kernelFunc: ({ inputs, backend, attrs }) => {
            const { x, filter, dy } = inputs;
            const { strides, pad, dilations } = attrs;
            const cpuBackend = backend;
            // Nested-array views simplify the 4-d / 3-d indexing below.
            const $x = toNestedArray(x.shape, cpuBackend.data.get(x.dataId).values);
            const $filter = toNestedArray(filter.shape, cpuBackend.data.get(filter.dataId).values);
            const { batchSize, inHeight, inWidth, inChannels, outHeight, outWidth, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth, outShape } = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations);
            assert$1(dy.rank === outShape.length, () => `Error in ${Dilation2DBackpropInput}, dy ` +
                `must have the same rank as output ${outShape.length}, but got ` +
                `${dy.rank}`);
            const $dy = toNestedArray(outShape, cpuBackend.data.get(dy.dataId).values);
            // The computed gradients has the same dimensions as the input:
            // [batch, inputHeight, inputCols, inChannel]
            const gradients = makeZerosNestedTypedArray(x.shape, x.dtype);
            // In the case of multiple argmax branches, we only back-propagate along the
            // last branch, i.e., the one with largest value of `h * filter_cols + w`,
            // similarly to the max-pooling backward routines.
            // This implementation follows the TF c++ implementation:
            // https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc
            for (let b = 0; b < batchSize; ++b) {
                for (let hOut = 0; hOut < outHeight; ++hOut) {
                    const hBeg = hOut * strideHeight - padInfo.top;
                    for (let wOut = 0; wOut < outWidth; ++wOut) {
                        const wBeg = wOut * strideWidth - padInfo.left;
                        for (let d = 0; d < inChannels; ++d) {
                            let curVal = Number.MIN_SAFE_INTEGER;
                            // Track the input pixel (hInMax, wInMax) that produced the
                            // max; defaults to the window corner clamped to the input.
                            let hInMax = (hBeg < 0) ? 0 : hBeg;
                            let wInMax = (wBeg < 0) ? 0 : wBeg;
                            for (let h = 0; h < filterHeight; ++h) {
                                const hIn = hBeg + h * dilationHeight;
                                if (hIn >= 0 && hIn < inHeight) {
                                    for (let w = 0; w < filterWidth; ++w) {
                                        const wIn = wBeg + w * dilationWidth;
                                        if (wIn >= 0 && wIn < inWidth) {
                                            const val = $x[b][hIn][wIn][d] + $filter[h][w][d];
                                            if (val > curVal) {
                                                curVal = val;
                                                hInMax = hIn;
                                                wInMax = wIn;
                                            }
                                        }
                                    }
                                }
                            }
                            // The whole upstream gradient flows to the argmax pixel.
                            gradients[b][hInMax][wInMax][d] += $dy[b][hOut][wOut][d];
                        }
                    }
                }
            }
            const dataId = cpuBackend.write(toTypedArray(gradients, x.dtype), x.shape, x.dtype);
            return { dataId, shape: x.shape, dtype: x.dtype };
        }
    };
78568
78569 /**
78570 * @license
78571 * Copyright 2023 Google LLC.
78572 * Licensed under the Apache License, Version 2.0 (the "License");
78573 * you may not use this file except in compliance with the License.
78574 * You may obtain a copy of the License at
78575 *
78576 * http://www.apache.org/licenses/LICENSE-2.0
78577 *
78578 * Unless required by applicable law or agreed to in writing, software
78579 * distributed under the License is distributed on an "AS IS" BASIS,
78580 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78581 * See the License for the specific language governing permissions and
78582 * limitations under the License.
78583 * =============================================================================
78584 */
78585 function draw(args) {
78586 const { inputs, backend, attrs } = args;
78587 const { image } = inputs;
78588 const { canvas, options } = attrs;
78589 const { contextOptions, imageOptions } = options || {};
78590 const alpha = (imageOptions === null || imageOptions === void 0 ? void 0 : imageOptions.alpha) || 1;
78591 const contextType = (contextOptions === null || contextOptions === void 0 ? void 0 : contextOptions.contextType) || '2d';
78592 if (contextType !== '2d') {
78593 throw new Error(`Context type ${contextOptions.contextType} is not supported by the CPU backend.`);
78594 }
78595 const ctx = canvas.getContext(contextType, (contextOptions === null || contextOptions === void 0 ? void 0 : contextOptions.contextAttributes) || {});
78596 if (ctx == null) {
78597 throw new Error(`Could not get the context with ${contextType} type.`);
78598 }
78599 const [height, width] = image.shape.slice(0, 2);
78600 const depth = image.shape.length === 2 ? 1 : image.shape[2];
78601 const data = backend.data.get(image.dataId).values;
78602 const multiplier = image.dtype === 'float32' ? 255 : 1;
78603 const bytes = new Uint8ClampedArray(width * height * 4);
78604 for (let i = 0; i < height * width; ++i) {
78605 const rgba = [0, 0, 0, 255 * alpha];
78606 for (let d = 0; d < depth; d++) {
78607 const value = data[i * depth + d];
78608 if (image.dtype === 'float32') {
78609 if (value < 0 || value > 1) {
78610 throw new Error(`Tensor values for a float32 Tensor must be in the ` +
78611 `range [0 - 1] but encountered ${value}.`);
78612 }
78613 }
78614 else if (image.dtype === 'int32') {
78615 if (value < 0 || value > 255) {
78616 throw new Error(`Tensor values for a int32 Tensor must be in the ` +
78617 `range [0 - 255] but encountered ${value}.`);
78618 }
78619 }
78620 if (depth === 1) {
78621 rgba[0] = value * multiplier;
78622 rgba[1] = value * multiplier;
78623 rgba[2] = value * multiplier;
78624 }
78625 else {
78626 rgba[d] = value * multiplier;
78627 }
78628 }
78629 const j = i * 4;
78630 bytes[j + 0] = Math.round(rgba[0]);
78631 bytes[j + 1] = Math.round(rgba[1]);
78632 bytes[j + 2] = Math.round(rgba[2]);
78633 bytes[j + 3] = Math.round(rgba[3]);
78634 }
78635 canvas.width = width;
78636 canvas.height = height;
78637 const imageData = new ImageData(bytes, width, height);
78638 ctx.putImageData(imageData, 0, 0);
78639 return image;
78640 }
78641 const drawConfig = {
78642 kernelName: Draw,
78643 backendName: 'cpu',
78644 kernelFunc: draw
78645 };
78646
78647 /**
78648 * @license
78649 * Copyright 2020 Google LLC. All Rights Reserved.
78650 * Licensed under the Apache License, Version 2.0 (the "License");
78651 * you may not use this file except in compliance with the License.
78652 * You may obtain a copy of the License at
78653 *
78654 * http://www.apache.org/licenses/LICENSE-2.0
78655 *
78656 * Unless required by applicable law or agreed to in writing, software
78657 * distributed under the License is distributed on an "AS IS" BASIS,
78658 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78659 * See the License for the specific language governing permissions and
78660 * limitations under the License.
78661 * =============================================================================
78662 */
    /**
     * CPU kernel for Sum: reduces `x` along `axis`, optionally keeping the
     * reduced dimensions as size-1 axes.
     *
     * bool inputs are cast to int32 first. If the reduction axes are not
     * already the inner-most dimensions, the input is transposed so they
     * are; each output element then sums a contiguous run of `reduceSize`
     * values. Intermediates created here are disposed before returning.
     */
    function sum$1(args) {
        const { inputs, backend, attrs } = args;
        const { x } = inputs;
        const { axis, keepDims } = attrs;
        assertNotComplex$1(x, 'sum');
        let $x;
        if (x.dtype === 'bool') {
            // Summing booleans counts true values, so promote to int32.
            $x = cast$1({ inputs: { x }, backend, attrs: { dtype: 'int32' } });
        }
        else {
            $x = identity$1({ inputs: { x }, backend });
        }
        const xRank = $x.shape.length;
        const axes = parseAxisParam(axis, $x.shape);
        // Non-null only if the reduction axes must be moved inner-most.
        const permutation = getAxesPermutation(axes, xRank);
        let reductionAxes = axes;
        let permutedX = $x;
        if (permutation != null) {
            permutedX =
                transpose$1({ inputs: { x: $x }, backend, attrs: { perm: permutation } });
            reductionAxes = getInnerMostAxes(reductionAxes.length, xRank);
        }
        assertAxesAreInnerMostDims('sum', reductionAxes, permutedX.shape.length);
        const [outShape, reduceShape] = computeOutAndReduceShapes(permutedX.shape, reductionAxes);
        const resultDtype = upcastType(permutedX.dtype, 'int32');
        let result = zeros(backend, outShape, resultDtype);
        const reduceSize = sizeFromShape(reduceShape);
        const vals = backend.data.get(result.dataId).values;
        const aVals = backend.data.get(permutedX.dataId).values;
        for (let i = 0; i < vals.length; ++i) {
            // Each output element sums a contiguous run of `reduceSize` inputs.
            const offset = i * reduceSize;
            let sum = 0;
            for (let j = 0; j < reduceSize; ++j) {
                sum += aVals[offset + j];
            }
            vals[i] = sum;
        }
        if (keepDims) {
            // Re-insert the reduced dimensions as size-1 axes.
            const newShape = expandShapeToKeepDim(result.shape, axes);
            const oldResult = result;
            result = reshape$1({ inputs: { x: result }, backend, attrs: { shape: newShape } });
            backend.disposeIntermediateTensorInfo(oldResult);
        }
        backend.disposeIntermediateTensorInfo($x);
        if (permutation != null) {
            backend.disposeIntermediateTensorInfo(permutedX);
        }
        return result;
    }
    // Kernel registration for Sum on the 'cpu' backend.
    const sumConfig$1 = {
        kernelName: Sum,
        backendName: 'cpu',
        kernelFunc: sum$1
    };
78717
78718 /**
78719 * @license
78720 * Copyright 2021 Google LLC. All Rights Reserved.
78721 * Licensed under the Apache License, Version 2.0 (the "License");
78722 * you may not use this file except in compliance with the License.
78723 * You may obtain a copy of the License at
78724 *
78725 * http://www.apache.org/licenses/LICENSE-2.0
78726 *
78727 * Unless required by applicable law or agreed to in writing, software
78728 * distributed under the License is distributed on an "AS IS" BASIS,
78729 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78730 * See the License for the specific language governing permissions and
78731 * limitations under the License.
78732 * =============================================================================
78733 */
    /**
     * CPU kernel for Einsum.
     *
     * Decodes the equation, then walks the computed step order: each operand
     * is transposed / expanded to a broadcast-compatible layout, multiplied
     * into the running product `out`, and after every step but the last one
     * dimension is summed out per the compute path. All intermediates except
     * the final output are disposed before returning.
     */
    function einsum$1(args) {
        const { inputs, backend, attrs } = args;
        const { equation } = attrs;
        const tensors = inputs;
        const { allDims, summedDims, idDims } = decodeEinsumEquation(equation, tensors.length);
        checkEinsumDimSizes(allDims.length, idDims, tensors);
        const { path, steps } = getEinsumComputePath(summedDims, idDims);
        const nSteps = steps.length;
        let out = null;
        let numDimsRemaining = allDims.length;
        const tensorsToDispose = [];
        for (let i = 0; i < nSteps; ++i) {
            for (const idTerm of steps[i]) {
                // Layout needed to align this operand with the running product.
                const { permutationIndices: perm, expandDims: dimsToExpand } = getEinsumPermutation(numDimsRemaining, idDims[idTerm]);
                let x;
                if (isIdentityPermutation(perm)) {
                    // Already in the right order; avoid a needless transpose.
                    x = tensors[idTerm];
                }
                else {
                    x = transpose$1({ inputs: { x: tensors[idTerm] }, backend, attrs: { perm } });
                    tensorsToDispose.push(x);
                }
                // Insert size-1 axes so the operand broadcasts against `out`.
                const targetShape = x.shape.slice();
                for (let k = 0; k < dimsToExpand.length; ++k) {
                    targetShape.splice(dimsToExpand[k], 0, 1);
                }
                if (!arraysEqual(x.shape, targetShape)) {
                    x = reshape$1({ inputs: { x }, backend, attrs: { shape: targetShape } });
                    tensorsToDispose.push(x);
                }
                if (out === null) {
                    out = x;
                }
                else {
                    // tslint:disable-next-line: no-unnecessary-type-assertion
                    out = multiply$1({ inputs: { a: x, b: out }, backend });
                    tensorsToDispose.push(out);
                }
            }
            if (i < nSteps - 1) {
                // Sum out the dimension this step resolves (path[i] < 0 means
                // nothing to reduce at this step).
                if (path[i] >= 0) {
                    out = sum$1({
                        inputs: { x: out },
                        backend,
                        attrs: {
                            axis: path[i] - (allDims.length - numDimsRemaining),
                            keepDims: false
                        }
                    });
                    tensorsToDispose.push(out);
                }
                numDimsRemaining--;
            }
        }
        // Clean up intermediate tensors.
        for (const tensorInfo of tensorsToDispose) {
            if (tensorInfo === out) {
                continue;
            }
            backend.disposeIntermediateTensorInfo(tensorInfo);
        }
        return out;
    }
    // Kernel registration for Einsum on the 'cpu' backend.
    const einsumConfig$1 = {
        kernelName: Einsum,
        backendName: 'cpu',
        kernelFunc: einsum$1
    };
78802
78803 /**
78804 * @license
78805 * Copyright 2020 Google LLC. All Rights Reserved.
78806 * Licensed under the Apache License, Version 2.0 (the "License");
78807 * you may not use this file except in compliance with the License.
78808 * You may obtain a copy of the License at
78809 *
78810 * http://www.apache.org/licenses/LICENSE-2.0
78811 *
78812 * Unless required by applicable law or agreed to in writing, software
78813 * distributed under the License is distributed on an "AS IS" BASIS,
78814 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78815 * See the License for the specific language governing permissions and
78816 * limitations under the License.
78817 * =============================================================================
78818 */
78819 function eluGrad$1(args) {
78820 const { inputs, backend } = args;
78821 const { dy, y } = inputs;
78822 assertNotComplex$1([dy, y], 'eluGrad');
78823 const resultValues = new Float32Array(sizeFromShape(y.shape));
78824 const values = backend.data.get(y.dataId).values;
78825 const dyValues = backend.data.get(dy.dataId).values;
78826 for (let i = 0; i < values.length; ++i) {
78827 const v = values[i];
78828 if (v >= 0) {
78829 resultValues[i] = dyValues[i];
78830 }
78831 else {
78832 resultValues[i] = dyValues[i] * (v + 1);
78833 }
78834 }
78835 return backend.makeTensorInfo(y.shape, 'float32', resultValues);
78836 }
    // Registration config binding the EluGrad kernel implementation to the
    // CPU backend.
    const eluGradConfig$1 = {
        kernelName: EluGrad,
        backendName: 'cpu',
        kernelFunc: eluGrad$1
    };
78842
78843 /**
78844 * @license
78845 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
     * you may not use this file except in compliance with the License.
     * You may obtain a copy of the License at
     *
     * http://www.apache.org/licenses/LICENSE-2.0
     *
     * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
78854 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78855 * See the License for the specific language governing permissions and
78856 * limitations under the License.
78857 * =============================================================================
78858 */
78859 const p = ERF_P;
78860 const a1 = ERF_A1;
78861 const a2 = ERF_A2;
78862 const a3 = ERF_A3;
78863 const a4 = ERF_A4;
78864 const a5 = ERF_A5;
78865 const erf$1 = unaryKernelFunc$1(Erf, (xi) => {
78866 const sign = Math.sign(xi);
78867 const v = Math.abs(xi);
78868 const t = 1.0 / (1.0 + p * v);
78869 return sign *
78870 (1.0 -
78871 (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t *
78872 Math.exp(-v * v));
78873 });
    // Registration config binding the Erf kernel implementation to the
    // CPU backend.
    const erfConfig$1 = {
        kernelName: Erf,
        backendName: 'cpu',
        kernelFunc: erf$1,
    };
78879
78880 /**
78881 * @license
78882 * Copyright 2020 Google LLC. All Rights Reserved.
78883 * Licensed under the Apache License, Version 2.0 (the "License");
78884 * you may not use this file except in compliance with the License.
78885 * You may obtain a copy of the License at
78886 *
78887 * http://www.apache.org/licenses/LICENSE-2.0
78888 *
78889 * Unless required by applicable law or agreed to in writing, software
78890 * distributed under the License is distributed on an "AS IS" BASIS,
78891 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78892 * See the License for the specific language governing permissions and
78893 * limitations under the License.
78894 * =============================================================================
78895 */
78896 function expandDims$1(args) {
78897 const { inputs, backend, attrs } = args;
78898 const { input } = inputs;
78899 const { dim } = attrs;
78900 const inputRank = input.shape.length;
78901 const newShape = input.shape.slice();
78902 let $dim = dim;
78903 if (dim < 0) {
78904 // Negative value is counted from the tail of rank.
78905 assert$1(-(inputRank + 1) <= dim, () => `Axis must be in the interval [${-(inputRank + 1)}, ${inputRank}]`);
78906 $dim = inputRank + dim + 1;
78907 }
78908 newShape.splice($dim, 0, 1);
78909 return reshape$1({ inputs: { x: input }, backend, attrs: { shape: newShape } });
78910 }
    // Registration config binding the ExpandDims kernel implementation to the
    // CPU backend.
    const expandDimsConfig$1 = {
        kernelName: ExpandDims,
        backendName: 'cpu',
        kernelFunc: expandDims$1
    };
78916
78917 /**
78918 * @license
78919 * Copyright 2020 Google LLC. All Rights Reserved.
78920 * Licensed under the Apache License, Version 2.0 (the "License");
78921 * you may not use this file except in compliance with the License.
78922 * You may obtain a copy of the License at
78923 *
78924 * http://www.apache.org/licenses/LICENSE-2.0
78925 *
78926 * Unless required by applicable law or agreed to in writing, software
78927 * distributed under the License is distributed on an "AS IS" BASIS,
78928 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78929 * See the License for the specific language governing permissions and
78930 * limitations under the License.
78931 * =============================================================================
78932 */
    // Element-wise real division implemented via the shared binary-kernel
    // helpers (which handle broadcasting over the two inputs).
    const realDivImpl = createSimpleBinaryKernelImpl((a, b) => a / b);
    const div = binaryKernelFunc$1(RealDiv, realDivImpl);
    // Registration config binding the RealDiv kernel implementation to the
    // CPU backend.
    const realDivConfig$1 = {
        kernelName: RealDiv,
        backendName: 'cpu',
        kernelFunc: div
    };
78940
78941 /**
78942 * @license
78943 * Copyright 2020 Google LLC. All Rights Reserved.
78944 * Licensed under the Apache License, Version 2.0 (the "License");
78945 * you may not use this file except in compliance with the License.
78946 * You may obtain a copy of the License at
78947 *
78948 * http://www.apache.org/licenses/LICENSE-2.0
78949 *
78950 * Unless required by applicable law or agreed to in writing, software
78951 * distributed under the License is distributed on an "AS IS" BASIS,
78952 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78953 * See the License for the specific language governing permissions and
78954 * limitations under the License.
78955 * =============================================================================
78956 */
78957 /**
78958 * Calculate FFT of inner most elements of batch tensor.
78959 */
78960 function fftBatch(input, inverse, cpuBackend) {
78961 const inputShape = input.shape;
78962 const batch = inputShape[0];
78963 const innerDim = inputShape[1];
78964 const inputVals = cpuBackend.data.get(input.dataId);
78965 const real2D = inputVals.complexTensorInfos.real;
78966 const imag2D = inputVals.complexTensorInfos.imag;
78967 // Collects real and imaginary values separately.
78968 const resultShape = [batch, innerDim];
78969 const resultSize = sizeFromShape(resultShape);
78970 const resultReal = getTypedArrayFromDType('float32', resultSize);
78971 const resultImag = getTypedArrayFromDType('float32', resultSize);
78972 for (let b = 0; b < batch; b++) {
78973 // TODO: Support slice ops for complex type.
78974 const r = slice$1({
78975 inputs: { x: real2D },
78976 backend: cpuBackend,
78977 attrs: { begin: [b, 0], size: [1, innerDim] }
78978 });
78979 const i = slice$1({
78980 inputs: { x: imag2D },
78981 backend: cpuBackend,
78982 attrs: { begin: [b, 0], size: [1, innerDim] }
78983 });
78984 const input = complex$1({ inputs: { real: r, imag: i }, backend: cpuBackend });
78985 // Run FFT by batch element.
78986 const { real, imag } = fftImpl$1(input, inverse, cpuBackend);
78987 const res = mergeRealAndImagArrays(real, imag);
78988 for (let d = 0; d < innerDim; d++) {
78989 const c = getComplexWithIndex(res, d);
78990 resultReal[b * innerDim + d] = c.real;
78991 resultImag[b * innerDim + d] = c.imag;
78992 }
78993 cpuBackend.disposeIntermediateTensorInfo(r);
78994 cpuBackend.disposeIntermediateTensorInfo(i);
78995 cpuBackend.disposeIntermediateTensorInfo(input);
78996 }
78997 const $realInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', resultReal);
78998 const $imagInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', resultImag);
78999 const result = complex$1({ inputs: { real: $realInfo, imag: $imagInfo }, backend: cpuBackend });
79000 cpuBackend.disposeIntermediateTensorInfo($realInfo);
79001 cpuBackend.disposeIntermediateTensorInfo($imagInfo);
79002 return result;
79003 }
    /**
     * Runs the 1-D (inverse) FFT over a complex tensor.
     *
     * Uses the O(n log n) radix-2 Cooley-Tukey path when the input length is
     * a power of two, and an O(n^2) matrix-multiply fallback otherwise.
     * Returns plain { real, imag } typed arrays, not a TensorInfo.
     */
    function fftImpl$1(input, inverse, cpuBackend) {
        const inputSize = sizeFromShape(input.shape);
        const inputVals = cpuBackend.data.get(input.dataId);
        const realVals = cpuBackend.data.get(inputVals.complexTensorInfos.real.dataId).values;
        const imagVals = cpuBackend.data.get(inputVals.complexTensorInfos.imag.dataId).values;
        if (isExponentOf2(inputSize)) {
            const result = fftRadix2(realVals, imagVals, inputSize, inverse, cpuBackend);
            const resultShape = [input.shape[0], input.shape[1]];
            if (inverse) {
                // Inverse FFT requires scaling every element by 1/inputSize;
                // done via the RealDiv kernel on real and imaginary parts.
                const realInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', result.real);
                const imagInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', result.imag);
                const sizeInfo = cpuBackend.makeTensorInfo([], 'float32', createScalarValue(inputSize, 'float32'));
                // A separate copy of the size scalar is used for the second
                // division; both are disposed below.
                const sizeInfoCopy = identity$1({ inputs: { x: sizeInfo }, backend: cpuBackend });
                const divRealInfo = realDivConfig$1.kernelFunc({ inputs: { a: realInfo, b: sizeInfo }, backend: cpuBackend });
                const divImagInfo = realDivConfig$1.kernelFunc({ inputs: { a: imagInfo, b: sizeInfoCopy }, backend: cpuBackend });
                // Read the values out before disposing the backing TensorInfos.
                const divRealVals = cpuBackend.data.get(divRealInfo.dataId).values;
                const divImagVals = cpuBackend.data.get(divImagInfo.dataId).values;
                cpuBackend.disposeIntermediateTensorInfo(realInfo);
                cpuBackend.disposeIntermediateTensorInfo(imagInfo);
                cpuBackend.disposeIntermediateTensorInfo(sizeInfo);
                cpuBackend.disposeIntermediateTensorInfo(sizeInfoCopy);
                cpuBackend.disposeIntermediateTensorInfo(divRealInfo);
                cpuBackend.disposeIntermediateTensorInfo(divImagInfo);
                return { real: divRealVals, imag: divImagVals };
            }
            return result;
        }
        else {
            // Non-power-of-two length: fall back to the naive DFT.
            const data = mergeRealAndImagArrays(realVals, imagVals);
            const rawOutput = fourierTransformByMatmul(data, inputSize, inverse);
            return splitRealAndImagArrays(rawOutput);
        }
    }
79037 function isExponentOf2(size) {
79038 return (size & size - 1) === 0;
79039 }
    // FFT using the Cooley-Tukey algorithm on power-of-two (radix-2) sized input.
    /**
     * Recursive radix-2 Cooley-Tukey FFT over interleaved real/imag arrays.
     *
     * Splits the input into even- and odd-indexed halves, transforms each
     * half recursively, multiplies the odd half by the complex exponents
     * for this size, and combines via the butterfly (even + e*odd,
     * even - e*odd). Returns plain { real, imag } typed arrays.
     */
    function fftRadix2(realVals, imagVals, size, inverse, cpuBackend) {
        // Base case: the FFT of a single element is the element itself.
        if (size === 1) {
            return { real: realVals, imag: imagVals };
        }
        const data = mergeRealAndImagArrays(realVals, imagVals);
        const half = size / 2;
        // Even-indexed half of the interleaved data.
        const evenComplex = complexWithEvenIndex(data);
        const evenRealVals = evenComplex.real;
        const evenImagVals = evenComplex.imag;
        const evenShape = [evenRealVals.length];
        // NOTE(review): evenTensorInfo (and oddTensorInfo below) are created
        // and then only disposed — they look like dead allocations; verify
        // before removing.
        const evenRealInfo = cpuBackend.makeTensorInfo(evenShape, 'float32', evenRealVals);
        const evenImagInfo = cpuBackend.makeTensorInfo(evenShape, 'float32', evenImagVals);
        const evenTensorInfo = complex$1({ inputs: { real: evenRealInfo, imag: evenImagInfo }, backend: cpuBackend });
        // Odd-indexed half of the interleaved data.
        const oddComplex = complexWithOddIndex(data);
        const oddRealVals = oddComplex.real;
        const oddImagVals = oddComplex.imag;
        const oddShape = [oddRealVals.length];
        const oddRealInfo = cpuBackend.makeTensorInfo(oddShape, 'float32', oddRealVals);
        const oddImagInfo = cpuBackend.makeTensorInfo(oddShape, 'float32', oddImagVals);
        const oddTensorInfo = complex$1({ inputs: { real: oddRealInfo, imag: oddImagInfo }, backend: cpuBackend });
        // Recursive call for half part of original input.
        const $evenComplex = fftRadix2(evenRealVals, evenImagVals, half, inverse, cpuBackend);
        const $evenRealVals = $evenComplex.real;
        const $evenImagVals = $evenComplex.imag;
        const $evenShape = [$evenRealVals.length];
        const $evenRealInfo = cpuBackend.makeTensorInfo($evenShape, 'float32', $evenRealVals);
        const $evenImagInfo = cpuBackend.makeTensorInfo($evenShape, 'float32', $evenImagVals);
        const $evenTensorInfo = complex$1({
            inputs: { real: $evenRealInfo, imag: $evenImagInfo },
            backend: cpuBackend
        });
        const $oddComplex = fftRadix2(oddRealVals, oddImagVals, half, inverse, cpuBackend);
        const $oddRealVals = $oddComplex.real;
        const $oddImagVals = $oddComplex.imag;
        const $oddShape = [$oddRealVals.length];
        const $oddRealInfo = cpuBackend.makeTensorInfo($oddShape, 'float32', $oddRealVals);
        const $oddImagInfo = cpuBackend.makeTensorInfo($oddShape, 'float32', $oddImagVals);
        const $oddTensorInfo = complex$1({ inputs: { real: $oddRealInfo, imag: $oddImagInfo }, backend: cpuBackend });
        // Complex exponents for this size (the per-index twiddle factors).
        const e = exponents(size, inverse);
        const eShape = [e.real.length];
        const eRealInfo = cpuBackend.makeTensorInfo(eShape, 'float32', e.real);
        const eImagInfo = cpuBackend.makeTensorInfo(eShape, 'float32', e.imag);
        const complexInfo = complex$1({ inputs: { real: eRealInfo, imag: eImagInfo }, backend: cpuBackend });
        const exponentInfo = multiply$1({ inputs: { a: complexInfo, b: $oddTensorInfo }, backend: cpuBackend });
        // Butterfly combine: first half = even + e*odd, second = even - e*odd.
        const addPart = add({
            inputs: { a: $evenTensorInfo, b: exponentInfo },
            backend: cpuBackend
        });
        const subPart = sub$1({
            inputs: { a: $evenTensorInfo, b: exponentInfo },
            backend: cpuBackend
        });
        const addPartReal = real$1({ inputs: { input: addPart }, backend: cpuBackend });
        const subPartReal = real$1({ inputs: { input: subPart }, backend: cpuBackend });
        const addPartImag = imag$1({ inputs: { input: addPart }, backend: cpuBackend });
        const subPartImag = imag$1({ inputs: { input: subPart }, backend: cpuBackend });
        const $real = concat$1({
            inputs: [addPartReal, subPartReal],
            backend: cpuBackend,
            attrs: { axis: 0 }
        });
        const $imag = concat$1({
            inputs: [addPartImag, subPartImag],
            backend: cpuBackend,
            attrs: { axis: 0 }
        });
        // Read the output values before disposing all intermediates.
        const $realVals = cpuBackend.data.get($real.dataId).values;
        const $imagVals = cpuBackend.data.get($imag.dataId).values;
        cpuBackend.disposeIntermediateTensorInfo(evenRealInfo);
        cpuBackend.disposeIntermediateTensorInfo(evenImagInfo);
        cpuBackend.disposeIntermediateTensorInfo(evenTensorInfo);
        cpuBackend.disposeIntermediateTensorInfo(oddRealInfo);
        cpuBackend.disposeIntermediateTensorInfo(oddImagInfo);
        cpuBackend.disposeIntermediateTensorInfo(oddTensorInfo);
        cpuBackend.disposeIntermediateTensorInfo($evenRealInfo);
        cpuBackend.disposeIntermediateTensorInfo($evenImagInfo);
        cpuBackend.disposeIntermediateTensorInfo($evenTensorInfo);
        cpuBackend.disposeIntermediateTensorInfo($oddRealInfo);
        cpuBackend.disposeIntermediateTensorInfo($oddImagInfo);
        cpuBackend.disposeIntermediateTensorInfo($oddTensorInfo);
        cpuBackend.disposeIntermediateTensorInfo(eRealInfo);
        cpuBackend.disposeIntermediateTensorInfo(eImagInfo);
        cpuBackend.disposeIntermediateTensorInfo(complexInfo);
        cpuBackend.disposeIntermediateTensorInfo(exponentInfo);
        cpuBackend.disposeIntermediateTensorInfo(addPart);
        cpuBackend.disposeIntermediateTensorInfo(subPart);
        cpuBackend.disposeIntermediateTensorInfo(addPartReal);
        cpuBackend.disposeIntermediateTensorInfo(addPartImag);
        cpuBackend.disposeIntermediateTensorInfo(subPartReal);
        cpuBackend.disposeIntermediateTensorInfo(subPartImag);
        cpuBackend.disposeIntermediateTensorInfo($real);
        cpuBackend.disposeIntermediateTensorInfo($imag);
        return { real: $realVals, imag: $imagVals };
    }
    // Calculate the Fourier transform by multiplying by the sinusoid matrix.
79136 function fourierTransformByMatmul(data, size, inverse) {
79137 const ret = new Float32Array(size * 2);
79138 // TODO: Use matmul instead once it supports complex64 type.
79139 for (let r = 0; r < size; r++) {
79140 let real = 0.0;
79141 let imag = 0.0;
79142 for (let c = 0; c < size; c++) {
79143 const e = exponent(r * c, size, inverse);
79144 const term = getComplexWithIndex(data, c);
79145 real += term.real * e.real - term.imag * e.imag;
79146 imag += term.real * e.imag + term.imag * e.real;
79147 }
79148 if (inverse) {
79149 real /= size;
79150 imag /= size;
79151 }
79152 assignToTypedArray(ret, real, imag, r);
79153 }
79154 return ret;
79155 }
79156
79157 /**
79158 * @license
79159 * Copyright 2020 Google LLC. All Rights Reserved.
79160 * Licensed under the Apache License, Version 2.0 (the "License");
79161 * you may not use this file except in compliance with the License.
79162 * You may obtain a copy of the License at
79163 *
79164 * http://www.apache.org/licenses/LICENSE-2.0
79165 *
79166 * Unless required by applicable law or agreed to in writing, software
79167 * distributed under the License is distributed on an "AS IS" BASIS,
79168 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
79169 * See the License for the specific language governing permissions and
79170 * limitations under the License.
79171 * =============================================================================
79172 */
79173 function fft$1(args) {
79174 const { inputs, backend } = args;
79175 const { input } = inputs;
79176 const inputSize = sizeFromShape(input.shape);
79177 // Collapse all outer dimensions to a single batch dimension.
79178 const innerDimensionSize = input.shape[input.shape.length - 1];
79179 const batch = inputSize / innerDimensionSize;
79180 const input2D = reshape$1({
79181 inputs: { x: input },
79182 backend,
79183 attrs: { shape: [batch, innerDimensionSize] }
79184 });
79185 const result = fftBatch(input2D, false, backend);
79186 const resultReshaped = reshape$1({ inputs: { x: result }, backend, attrs: { shape: input.shape } });
79187 backend.disposeIntermediateTensorInfo(input2D);
79188 backend.disposeIntermediateTensorInfo(result);
79189 return resultReshaped;
79190 }
    // Registration config binding the FFT kernel implementation to the
    // CPU backend.
    const fftConfig$1 = {
        kernelName: FFT,
        backendName: 'cpu',
        kernelFunc: fft$1
    };
79196
79197 /**
79198 * @license
79199 * Copyright 2020 Google LLC. All Rights Reserved.
79200 * Licensed under the Apache License, Version 2.0 (the "License");
79201 * you may not use this file except in compliance with the License.
79202 * You may obtain a copy of the License at
79203 *
79204 * http://www.apache.org/licenses/LICENSE-2.0
79205 *
79206 * Unless required by applicable law or agreed to in writing, software
79207 * distributed under the License is distributed on an "AS IS" BASIS,
79208 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
79209 * See the License for the specific language governing permissions and
79210 * limitations under the License.
79211 * =============================================================================
79212 */
79213 function fill$1(args) {
79214 const { backend, attrs } = args;
79215 const { shape, value, dtype } = attrs;
79216 const $dtype = dtype || inferDtype(value);
79217 const values = getArrayFromDType($dtype, sizeFromShape(shape));
79218 fillValues(values, value, $dtype);
79219 return backend.makeTensorInfo(shape, $dtype, values);
79220 }
    // Registration config binding the Fill kernel implementation to the
    // CPU backend.
    const fillConfig$1 = {
        kernelName: Fill,
        backendName: 'cpu',
        kernelFunc: fill$1
    };
79226 function fillValues(values, value, dtype) {
79227 if (dtype === 'string') {
79228 values.fill(value);
79229 }
79230 else {
79231 values.fill(value);
79232 }
79233 }
79234
79235 /**
79236 * @license
79237 * Copyright 2020 Google LLC. All Rights Reserved.
79238 * Licensed under the Apache License, Version 2.0 (the "License");
79239 * you may not use this file except in compliance with the License.
79240 * You may obtain a copy of the License at
79241 *
79242 * http://www.apache.org/licenses/LICENSE-2.0
79243 *
79244 * Unless required by applicable law or agreed to in writing, software
79245 * distributed under the License is distributed on an "AS IS" BASIS,
79246 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
79247 * See the License for the specific language governing permissions and
79248 * limitations under the License.
79249 * =============================================================================
79250 */
79251 const flipLeftRightConfig$1 = {
79252 kernelName: FlipLeftRight,
79253 backendName: 'cpu',
79254 kernelFunc: ({ inputs, attrs, backend }) => {
79255 const { image } = inputs;
79256 const cpuBackend = backend;
79257 const output = getTypedArrayFromDType(image.dtype, sizeFromShape(image.shape));
79258 const [batch, imageHeight, imageWidth, numChannels] = image.shape;
79259 const imageVals = cpuBackend.data.get(image.dataId).values;
79260 for (let batchIdx = 0; batchIdx < batch; batchIdx++) {
79261 const batchOffset = batchIdx * imageWidth * imageHeight * numChannels;
79262 for (let row = 0; row < imageHeight; row++) {
79263 const rowOffset = row * (imageWidth * numChannels);
79264 for (let col = 0; col < imageWidth; col++) {
79265 const colOffset = col * numChannels;
79266 for (let channel = 0; channel < numChannels; channel++) {
79267 const coordX = Math.round(imageWidth - col - 1);
79268 const outIdx = batchOffset + rowOffset + colOffset + channel;
79269 let outputValue = imageVals[outIdx];
79270 // If the coordinate position falls within the image boundaries...
79271 if (coordX >= 0 && coordX < imageWidth) {
79272 // set the output to the image value at the coordinate position.
79273 const rotatedColOffset = coordX * numChannels;
79274 const imageIdx = batchOffset + rowOffset + rotatedColOffset + channel;
79275 outputValue = imageVals[imageIdx];
79276 }
79277 output[outIdx] = outputValue;
79278 }
79279 }
79280 }
79281 }
79282 const dataId = cpuBackend.write(output, image.shape, image.dtype);
79283 return { dataId, shape: image.shape, dtype: image.dtype };
79284 }
79285 };
79286
79287 /**
79288 * @license
79289 * Copyright 2020 Google LLC. All Rights Reserved.
79290 * Licensed under the Apache License, Version 2.0 (the "License");
79291 * you may not use this file except in compliance with the License.
79292 * You may obtain a copy of the License at
79293 *
79294 * http://www.apache.org/licenses/LICENSE-2.0
79295 *
79296 * Unless required by applicable law or agreed to in writing, software
79297 * distributed under the License is distributed on an "AS IS" BASIS,
79298 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
79299 * See the License for the specific language governing permissions and
79300 * limitations under the License.
79301 * =============================================================================
79302 */
79303 function fusedConv2D(args) {
79304 const { inputs, backend, attrs } = args;
79305 const { x, filter, bias, preluActivationWeights } = inputs;
79306 const { strides, pad, dataFormat, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;
79307 let result = conv2D({
79308 inputs: { x, filter },
79309 backend,
79310 attrs: { strides, pad, dataFormat, dilations, dimRoundingMode }
79311 });
79312 if (bias) {
79313 const resultOld = result;
79314 // For NCHW format, if bias is a 1-D tensor, it is supposed to be aligned
79315 // to the channel of the conv2d's result; if the bias is a scalar, the
79316 // bias_add is computed as if the bias was broadcasted to the shape of the
79317 // conv2d's result.
79318 if (dataFormat === 'NCHW' && bias.shape.length === 1 &&
79319 bias.shape[0] !== 1) {
79320 const reshapedBias = reshape$1({ inputs: { x: bias }, backend, attrs: { shape: [bias.shape[0], 1, 1] } });
79321 result =
79322 add({ inputs: { a: result, b: reshapedBias }, backend });
79323 backend.disposeIntermediateTensorInfo(reshapedBias);
79324 }
79325 else {
79326 // This condition handles NHWC and NCHW (scalar case). The only other case
79327 // for NCHW (1D case) is handled above.
79328 result = add({ inputs: { a: result, b: bias }, backend });
79329 }
79330 backend.disposeIntermediateTensorInfo(resultOld);
79331 }
79332 if (activation) {
79333 const resultOld = result;
79334 // For NCHW format, if PReLu activation weights is a 1-D tensor, it is
79335 // supposed to be aligned with the channel of the conv2d's result. For other
79336 // cases, whether NCHW or NHWC data format, the conv2d result is
79337 // already aligned with the activation weights.
79338 if (dataFormat === 'NCHW' && activation === 'prelu' &&
79339 preluActivationWeights.shape.length === 1 &&
79340 preluActivationWeights.shape[0] !== 1) {
79341 const reshapedAlpha = reshape$1({
79342 inputs: { x: preluActivationWeights },
79343 backend,
79344 attrs: { shape: [preluActivationWeights.shape[0], 1, 1] }
79345 });
79346 result = applyActivation(backend, result, activation, reshapedAlpha, leakyreluAlpha);
79347 backend.disposeIntermediateTensorInfo(reshapedAlpha);
79348 }
79349 else {
79350 result = applyActivation(backend, result, activation, preluActivationWeights, leakyreluAlpha);
79351 }
79352 backend.disposeIntermediateTensorInfo(resultOld);
79353 }
79354 return result;
79355 }
    // Registration config binding the FusedConv2D kernel implementation to
    // the CPU backend.
    const fusedConv2DConfig$1 = {
        kernelName: FusedConv2D,
        backendName: 'cpu',
        kernelFunc: fusedConv2D
    };
79361
79362 /**
79363 * @license
79364 * Copyright 2020 Google LLC. All Rights Reserved.
79365 * Licensed under the Apache License, Version 2.0 (the "License");
79366 * you may not use this file except in compliance with the License.
79367 * You may obtain a copy of the License at
79368 *
79369 * http://www.apache.org/licenses/LICENSE-2.0
79370 *
79371 * Unless required by applicable law or agreed to in writing, software
79372 * distributed under the License is distributed on an "AS IS" BASIS,
79373 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
79374 * See the License for the specific language governing permissions and
79375 * limitations under the License.
79376 * =============================================================================
79377 */
79378 function fusedDepthwiseConv2D$1(args) {
79379 const { inputs, backend, attrs } = args;
79380 const { x, filter, bias, preluActivationWeights } = inputs;
79381 const { strides, pad, dataFormat, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;
79382 let result = depthwiseConv2dNative$1({
79383 inputs: { x, filter },
79384 backend,
79385 attrs: { strides, pad, dataFormat, dilations, dimRoundingMode }
79386 });
79387 if (bias) {
79388 const oldResult = result;
79389 result = add({ inputs: { a: result, b: bias }, backend });
79390 backend.disposeIntermediateTensorInfo(oldResult);
79391 }
79392 if (activation) {
79393 const oldResult = result;
79394 result = applyActivation(backend, result, activation, preluActivationWeights, leakyreluAlpha);
79395 backend.disposeIntermediateTensorInfo(oldResult);
79396 }
79397 return result;
79398 }
    // Registration config binding the FusedDepthwiseConv2D kernel
    // implementation to the CPU backend.
    const fusedDepthwiseConv2DConfig$1 = {
        kernelName: FusedDepthwiseConv2D,
        backendName: 'cpu',
        kernelFunc: fusedDepthwiseConv2D$1
    };
79404
79405 /**
79406 * @license
79407 * Copyright 2020 Google LLC. All Rights Reserved.
79408 * Licensed under the Apache License, Version 2.0 (the "License");
79409 * you may not use this file except in compliance with the License.
79410 * You may obtain a copy of the License at
79411 *
79412 * http://www.apache.org/licenses/LICENSE-2.0
79413 *
79414 * Unless required by applicable law or agreed to in writing, software
79415 * distributed under the License is distributed on an "AS IS" BASIS,
79416 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
79417 * See the License for the specific language governing permissions and
79418 * limitations under the License.
79419 * =============================================================================
79420 */
79421 function gatherNd$1(args) {
79422 const { inputs, backend } = args;
79423 const { params, indices } = inputs;
79424 const paramsSize = sizeFromShape(params.shape);
79425 const indicesShape = indices.shape;
79426 const sliceRank = indicesShape[indicesShape.length - 1];
79427 const [resultShape, numSlices, sliceSize, strides] = prepareAndValidate(params, indices);
79428 if (numSlices === 0) {
79429 return backend.makeTensorInfo(resultShape, params.dtype, []);
79430 }
79431 const indicesData = backend.data.get(indices.dataId).values;
79432 const paramsBuf = backend.bufferSync(params);
79433 const outBuf = gatherNdImpl(indicesData, paramsBuf, params.dtype, numSlices, sliceRank, sliceSize, strides, params.shape, paramsSize);
79434 return backend.makeTensorInfo(resultShape, params.dtype, outBuf.values);
79435 }
    // Registration config binding the GatherNd kernel implementation to the
    // CPU backend.
    const gatherNdConfig$1 = {
        kernelName: GatherNd,
        backendName: 'cpu',
        kernelFunc: gatherNd$1
    };
79441
79442 /**
79443 * @license
79444 * Copyright 2020 Google LLC. All Rights Reserved.
79445 * Licensed under the Apache License, Version 2.0 (the "License");
79446 * you may not use this file except in compliance with the License.
79447 * You may obtain a copy of the License at
79448 *
79449 * http://www.apache.org/licenses/LICENSE-2.0
79450 *
79451 * Unless required by applicable law or agreed to in writing, software
79452 * distributed under the License is distributed on an "AS IS" BASIS,
79453 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
79454 * See the License for the specific language governing permissions and
79455 * limitations under the License.
79456 * =============================================================================
79457 */
79458 function gatherV2$1(args) {
79459 const { inputs, backend, attrs } = args;
79460 const { x, indices } = inputs;
79461 const { axis, batchDims } = attrs;
79462 assertNotComplex$1([x, indices], 'gatherV2');
79463 // Throw error when any index is out of bound.
79464 const parsedAxis = parseAxisParam(axis, x.shape)[0];
79465 const indicesVals = backend.data.get(indices.dataId).values;
79466 const axisDim = x.shape[parsedAxis];
79467 for (let i = 0; i < indicesVals.length; ++i) {
79468 const index = indicesVals[i];
79469 assert$1(index <= axisDim - 1 && index >= 0, () => `GatherV2: the index value ${index} is not in [0, ${axisDim - 1}]`);
79470 }
79471 let $batchDims = batchDims;
79472 if (batchDims == null) {
79473 $batchDims = 0;
79474 }
79475 const indicesSize = sizeFromShape(indices.shape);
79476 const shapeInfo = collectGatherOpShapeInfo(x, indices, parsedAxis, $batchDims);
79477 const flattenX = reshape$1({
79478 inputs: { x },
79479 backend,
79480 attrs: {
79481 shape: [
79482 shapeInfo.batchSize, shapeInfo.outerSize, shapeInfo.dimSize,
79483 shapeInfo.sliceSize
79484 ]
79485 }
79486 });
79487 const flattenIndex = reshape$1({
79488 inputs: { x: indices },
79489 backend,
79490 attrs: { shape: [shapeInfo.batchSize, indicesSize / shapeInfo.batchSize] }
79491 });
79492 const flattenOutputShape = [
79493 shapeInfo.batchSize, shapeInfo.outerSize, indicesSize / shapeInfo.batchSize,
79494 shapeInfo.sliceSize
79495 ];
79496 const indicesBuf = backend.bufferSync(flattenIndex);
79497 const xBuf = backend.bufferSync(flattenX);
79498 const outBuf = gatherV2Impl(xBuf, indicesBuf, flattenOutputShape);
79499 backend.disposeIntermediateTensorInfo(flattenX);
79500 backend.disposeIntermediateTensorInfo(flattenIndex);
79501 return backend.makeTensorInfo(shapeInfo.outputShape, outBuf.dtype, outBuf.values);
79502 }
    // Registration config binding the GatherV2 kernel implementation to the
    // CPU backend.
    const gatherV2Config$1 = {
        kernelName: GatherV2,
        backendName: 'cpu',
        kernelFunc: gatherV2$1
    };
79508
79509 /**
79510 * @license
79511 * Copyright 2020 Google LLC. All Rights Reserved.
79512 * Licensed under the Apache License, Version 2.0 (the "License");
79513 * you may not use this file except in compliance with the License.
79514 * You may obtain a copy of the License at
79515 *
79516 * http://www.apache.org/licenses/LICENSE-2.0
79517 *
79518 * Unless required by applicable law or agreed to in writing, software
79519 * distributed under the License is distributed on an "AS IS" BASIS,
79520 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
79521 * See the License for the specific language governing permissions and
79522 * limitations under the License.
79523 * =============================================================================
79524 */
/**
 * CPU kernel for IFFT: inverse fast Fourier transform over the innermost
 * dimension of `input`, applied independently per outer "batch" slice.
 */
function ifft$1(args) {
    const { inputs, backend } = args;
    const { input } = inputs;
    // Collapse every outer dimension into a single batch dimension so the
    // FFT helper can work on a [batch, innerSize] view.
    const innerSize = input.shape[input.shape.length - 1];
    const batchSize = sizeFromShape(input.shape) / innerSize;
    const flat = reshape$1({
        inputs: { x: input },
        backend,
        attrs: { shape: [batchSize, innerSize] }
    });
    // Second argument `true` selects the inverse transform.
    const transformed = fftBatch(flat, true, backend);
    const out = reshape$1({ inputs: { x: transformed }, backend, attrs: { shape: input.shape } });
    backend.disposeIntermediateTensorInfo(flat);
    backend.disposeIntermediateTensorInfo(transformed);
    return out;
}
// Registration for the IFFT kernel on the CPU backend.
const ifftConfig$1 = {
    kernelName: IFFT,
    backendName: 'cpu',
    kernelFunc: ifft$1
};
79548
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Element-wise finiteness test: 1 for finite values, 0 for NaN/±Infinity.
const isFinite$2 = unaryKernelFunc$1(IsFinite, (xi) => (Number.isFinite(xi) ? 1 : 0), 'bool');
// Registration for the IsFinite kernel on the CPU backend.
const isFiniteConfig$1 = {
    kernelName: IsFinite,
    backendName: 'cpu',
    kernelFunc: isFinite$2,
};
79571
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Element-wise infinity test: 1 when |x| is Infinity, 0 otherwise.
const isInf$1 = unaryKernelFunc$1(IsInf, (xi) => (Math.abs(xi) === Infinity ? 1 : 0), 'bool');
// Registration for the IsInf kernel on the CPU backend.
const isInfConfig$1 = {
    kernelName: IsInf,
    backendName: 'cpu',
    kernelFunc: isInf$1,
};
79594
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Element-wise NaN test: 1 when the value is NaN, 0 otherwise.
const isNaN$2 = unaryKernelFunc$1(IsNan, (xi) => (Number.isNaN(xi) ? 1 : 0), 'bool');
// Registration for the IsNan kernel on the CPU backend.
const isNaNConfig$1 = {
    kernelName: IsNan,
    backendName: 'cpu',
    kernelFunc: isNaN$2,
};
79617
79618 /**
79619 * @license
79620 * Copyright 2020 Google LLC. All Rights Reserved.
79621 * Licensed under the Apache License, Version 2.0 (the "License");
79622 * you may not use this file except in compliance with the License.
79623 * You may obtain a copy of the License at
79624 *
79625 * http://www.apache.org/licenses/LICENSE-2.0
79626 *
79627 * Unless required by applicable law or agreed to in writing, software
79628 * distributed under the License is distributed on an "AS IS" BASIS,
79629 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
79630 * See the License for the specific language governing permissions and
79631 * limitations under the License.
79632 * =============================================================================
79633 */
/**
 * CPU kernel for LinSpace: produces `num` evenly spaced float32 values
 * from `start` to `stop`.
 */
function linSpace$1(args) {
    const { backend, attrs } = args;
    const { start, stop, num } = attrs;
    const values = linSpaceImpl(start, stop, num);
    return backend.makeTensorInfo([values.length], 'float32', values);
}
// Registration for the LinSpace kernel on the CPU backend.
const linSpaceConfig$1 = {
    kernelName: LinSpace,
    backendName: 'cpu',
    kernelFunc: linSpace$1
};
79645
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Element-wise log(1 + x).
const log1p$1 = unaryKernelFunc$1(Log1p, (value) => Math.log1p(value));
// Registration for the Log1p kernel on the CPU backend.
const log1pConfig$1 = {
    kernelName: Log1p,
    backendName: 'cpu',
    kernelFunc: log1p$1,
};
79668
79669 /**
79670 * @license
79671 * Copyright 2020 Google LLC. All Rights Reserved.
79672 * Licensed under the Apache License, Version 2.0 (the "License");
79673 * you may not use this file except in compliance with the License.
79674 * You may obtain a copy of the License at
79675 *
79676 * http://www.apache.org/licenses/LICENSE-2.0
79677 *
79678 * Unless required by applicable law or agreed to in writing, software
79679 * distributed under the License is distributed on an "AS IS" BASIS,
79680 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
79681 * See the License for the specific language governing permissions and
79682 * limitations under the License.
79683 * =============================================================================
79684 */
// Element-wise logical AND on 0/1 inputs (JS `&&` on numbers).
const logicalAndImpl = createSimpleBinaryKernelImpl((aVal, bVal) => aVal && bVal);
const logicalAnd$1 = binaryKernelFunc$1(LogicalAnd, logicalAndImpl, null /* complexImpl */, 'bool');
// Registration for the LogicalAnd kernel on the CPU backend.
const logicalAndConfig$1 = {
    kernelName: LogicalAnd,
    backendName: 'cpu',
    kernelFunc: logicalAnd$1
};
79692
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Element-wise logical NOT: truthy values map to 0, falsy values to 1.
const logicalNot$1 = unaryKernelFunc$1(LogicalNot, (value) => (value ? 0 : 1), 'bool');
// Registration for the LogicalNot kernel on the CPU backend.
const logicalNotConfig$1 = {
    kernelName: LogicalNot,
    backendName: 'cpu',
    kernelFunc: logicalNot$1,
};
79715
79716 /**
79717 * @license
79718 * Copyright 2020 Google LLC. All Rights Reserved.
79719 * Licensed under the Apache License, Version 2.0 (the "License");
79720 * you may not use this file except in compliance with the License.
79721 * You may obtain a copy of the License at
79722 *
79723 * http://www.apache.org/licenses/LICENSE-2.0
79724 *
79725 * Unless required by applicable law or agreed to in writing, software
79726 * distributed under the License is distributed on an "AS IS" BASIS,
79727 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
79728 * See the License for the specific language governing permissions and
79729 * limitations under the License.
79730 * =============================================================================
79731 */
// Element-wise logical OR on 0/1 inputs (JS `||` on numbers).
const logicalOrImpl = createSimpleBinaryKernelImpl((aVal, bVal) => aVal || bVal);
const logicalOr$1 = binaryKernelFunc$1(LogicalOr, logicalOrImpl, null /* complexImpl */, 'bool');
// Registration for the LogicalOr kernel on the CPU backend.
const logicalOrConfig$1 = {
    kernelName: LogicalOr,
    backendName: 'cpu',
    kernelFunc: logicalOr$1
};
79739
79740 /**
79741 * @license
79742 * Copyright 2020 Google LLC. All Rights Reserved.
79743 * Licensed under the Apache License, Version 2.0 (the "License");
79744 * you may not use this file except in compliance with the License.
79745 * You may obtain a copy of the License at
79746 *
79747 * http://www.apache.org/licenses/LICENSE-2.0
79748 *
79749 * Unless required by applicable law or agreed to in writing, software
79750 * distributed under the License is distributed on an "AS IS" BASIS,
79751 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
79752 * See the License for the specific language governing permissions and
79753 * limitations under the License.
79754 * =============================================================================
79755 */
/**
 * CPU kernel for Local Response Normalization.
 *
 * Each element is scaled by (bias + alpha * sumSq)^(-beta), where sumSq is
 * the sum of squares over a window of channels centered on the element
 * (clamped to the valid channel range). Channels are the innermost
 * dimension (read from shape[3]).
 */
function lRN(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { depthRadius, bias, alpha, beta } = attrs;
    assertNotComplex$1(x, 'LRN');
    const channels = x.shape[3];
    const maxD = channels - 1;
    const xValues = backend.data.get(x.dataId).values;
    const size = sizeFromShape(x.shape);
    const result = new Float32Array(size);
    for (let offset = 0; offset < size; offset++) {
        const channel = offset % channels;
        // Flat index of channel 0 for this spatial position.
        const rowStart = offset - channel;
        // Window bounds, clamped to [0, maxD] in channel space.
        const from = rowStart + Math.max(0, channel - depthRadius);
        const to = rowStart + Math.min(channel + depthRadius, maxD);
        let sumSq = 0.0;
        for (let k = from; k <= to; k++) {
            const v = xValues[k];
            sumSq += v * v;
        }
        result[offset] = xValues[offset] * Math.pow(bias + alpha * sumSq, -beta);
    }
    return backend.makeTensorInfo(x.shape, x.dtype, result);
}
// tslint:disable-next-line: variable-name
const LRNConfig$1 = {
    kernelName: LRN,
    backendName: 'cpu',
    kernelFunc: lRN
};
79790
79791 /**
79792 * @license
79793 * Copyright 2020 Google LLC. All Rights Reserved.
79794 * Licensed under the Apache License, Version 2.0 (the "License");
79795 * you may not use this file except in compliance with the License.
79796 * You may obtain a copy of the License at
79797 *
79798 * http://www.apache.org/licenses/LICENSE-2.0
79799 *
79800 * Unless required by applicable law or agreed to in writing, software
79801 * distributed under the License is distributed on an "AS IS" BASIS,
79802 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
79803 * See the License for the specific language governing permissions and
79804 * limitations under the License.
79805 * =============================================================================
79806 */
/**
 * CPU kernel for the gradient of Local Response Normalization.
 *
 * For each element, the forward pass normalized across a channel window of
 * up to `2 * depthRadius + 1` entries; here the incoming gradient `dy` is
 * scattered back onto every channel that participated in that window.
 * Channels are the innermost dimension (count read from dy.shape[3]).
 */
function lRNGrad(args) {
    const { inputs, backend, attrs } = args;
    const { x, y, dy } = inputs;
    const { depthRadius, bias, alpha, beta } = attrs;
    assertNotComplex$1(dy, 'LRNGrad');
    const dySize = sizeFromShape(dy.shape);
    const channels = dy.shape[3];
    const dyValues = backend.data.get(dy.dataId).values;
    const xValues = backend.data.get(x.dataId).values;
    const yValues = backend.data.get(y.dataId).values;
    const result = new Float32Array(dySize);
    const size = dySize;
    for (let offset = 0; offset < size; offset++) {
        const currentChannel = offset % channels;
        // Flat-index bounds of the channel window centered on this element,
        // clamped to the valid channel range.
        const depthBegin = (offset - currentChannel) + Math.max(0, currentChannel - depthRadius);
        const depthEnd = (offset - currentChannel) +
            Math.min(channels, currentChannel + depthRadius + 1);
        // norm = bias + alpha * sum(x_k^2) over the window.
        let norm = 0;
        for (let k = depthBegin; k < depthEnd; k++) {
            norm += Math.pow(xValues[k], 2);
        }
        norm = alpha * norm + bias;
        // Accumulate this element's gradient contribution into every window
        // member (note: result[k] is written, indexed by the inner loop).
        for (let k = depthBegin; k < depthEnd; k++) {
            let dyi = -2 * alpha * beta * xValues[k] * yValues[offset] / norm;
            if (offset === k) {
                // The diagonal term adds the direct x -> y path.
                dyi += Math.pow(norm, -beta);
            }
            dyi *= dyValues[offset];
            result[k] += dyi;
        }
    }
    return backend.makeTensorInfo(dy.shape, x.dtype, result);
}
// tslint:disable-next-line: variable-name
const LRNGradConfig$1 = {
    kernelName: LRNGrad,
    backendName: 'cpu',
    kernelFunc: lRNGrad
};
79846
79847 /**
79848 * @license
79849 * Copyright 2020 Google LLC. All Rights Reserved.
79850 * Licensed under the Apache License, Version 2.0 (the "License");
79851 * you may not use this file except in compliance with the License.
79852 * You may obtain a copy of the License at
79853 *
79854 * http://www.apache.org/licenses/LICENSE-2.0
79855 *
79856 * Unless required by applicable law or agreed to in writing, software
79857 * distributed under the License is distributed on an "AS IS" BASIS,
79858 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
79859 * See the License for the specific language governing permissions and
79860 * limitations under the License.
79861 * =============================================================================
79862 */
/**
 * CPU kernel for Max: reduces `x` by taking the maximum over the axes in
 * `reductionIndices`. When `keepDims` is true the reduced axes are kept as
 * size-1 dimensions in the output shape.
 */
function max$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { reductionIndices, keepDims } = attrs;
    const cpuBackend = backend;
    let xShape = x.shape;
    const xRank = xShape.length;
    const origAxes = parseAxisParam(reductionIndices, xShape);
    let axes = origAxes;
    // The reduction impl below only handles innermost axes; if the requested
    // axes are not innermost, transpose the data first.
    const permutedAxes = getAxesPermutation(axes, xRank);
    let xVals = cpuBackend.data.get(x.dataId).values;
    if (permutedAxes != null) {
        const newShape = new Array(xRank);
        for (let i = 0; i < newShape.length; i++) {
            newShape[i] = xShape[permutedAxes[i]];
        }
        xVals = transposeImpl$1(xVals, xShape, x.dtype, permutedAxes, newShape);
        axes = getInnerMostAxes(axes.length, xRank);
        // xShape now refers to the transposed layout.
        xShape = newShape;
    }
    assertNotComplex$1(x, 'max');
    assertAxesAreInnerMostDims('max', axes, xRank);
    const [maxOutShape, reduceShape] = computeOutAndReduceShapes(xShape, axes);
    // Each output element is the max over `reduceSize` consecutive values.
    const reduceSize = sizeFromShape(reduceShape);
    const result = maxImpl$1(xVals, reduceSize, maxOutShape, x.dtype);
    const dataId = cpuBackend.write(result, maxOutShape, x.dtype);
    let outShape = maxOutShape;
    if (keepDims) {
        // reshape: re-insert the reduced axes as size-1 dimensions.
        const newShape = expandShapeToKeepDim(maxOutShape, origAxes);
        outShape = newShape;
    }
    return { dataId, shape: outShape, dtype: x.dtype };
}
// Registration for the Max kernel on the CPU backend.
const maxConfig$1 = {
    kernelName: Max,
    backendName: 'cpu',
    kernelFunc: max$1
};
79902
79903 /**
79904 * @license
79905 * Copyright 2020 Google LLC. All Rights Reserved.
79906 * Licensed under the Apache License, Version 2.0 (the "License");
79907 * you may not use this file except in compliance with the License.
79908 * You may obtain a copy of the License at
79909 *
79910 * http://www.apache.org/licenses/LICENSE-2.0
79911 *
79912 * Unless required by applicable law or agreed to in writing, software
79913 * distributed under the License is distributed on an "AS IS" BASIS,
79914 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
79915 * See the License for the specific language governing permissions and
79916 * limitations under the License.
79917 * =============================================================================
79918 */
/**
 * CPU kernel for 2D max pooling. Dilations are fixed at 1 for this op.
 */
function maxPool$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    assertNotComplex$1(x, 'maxPool');
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const dilations = 1;
    assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. ' +
        `Got strides ${strides} and dilations '${dilations}'`);
    const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
    // A 1x1 window that leaves the shape unchanged is the identity.
    if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
        arraysEqual(convInfo.inShape, convInfo.outShape)) {
        return identity$1({ inputs: { x }, backend });
    }
    const xValues = backend.data.get(x.dataId).values;
    const xStrides = computeStrides(x.shape);
    const pooled = pool(xValues, x.shape, x.dtype, xStrides, convInfo, 'max');
    return backend.makeTensorInfo(convInfo.outShape, x.dtype, pooled.values);
}
// Registration for the MaxPool kernel on the CPU backend.
const maxPoolConfig$1 = {
    kernelName: MaxPool,
    backendName: 'cpu',
    kernelFunc: maxPool$1
};
79946
79947 /**
79948 * @license
79949 * Copyright 2020 Google LLC. All Rights Reserved.
79950 * Licensed under the Apache License, Version 2.0 (the "License");
79951 * you may not use this file except in compliance with the License.
79952 * You may obtain a copy of the License at
79953 *
79954 * http://www.apache.org/licenses/LICENSE-2.0
79955 *
79956 * Unless required by applicable law or agreed to in writing, software
79957 * distributed under the License is distributed on an "AS IS" BASIS,
79958 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
79959 * See the License for the specific language governing permissions and
79960 * limitations under the License.
79961 * =============================================================================
79962 */
/**
 * CPU kernel for 3D max pooling. Dilations are fixed at 1 for this op;
 * the output is always written as float32.
 */
function maxPool3D(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { filterSize, strides, pad, dimRoundingMode, dataFormat } = attrs;
    assertNotComplex$1(x, 'maxPool3d');
    const convInfo = computePool3DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode, dataFormat);
    const values = backend.data.get(x.dataId).values;
    const pooled = pool3d(values, x.shape, x.dtype, computeStrides(x.shape), convInfo, 'max');
    return backend.makeTensorInfo(pooled.shape, 'float32', pooled.values);
}
// Registration for the MaxPool3D kernel on the CPU backend.
const maxPool3DConfig$1 = {
    kernelName: MaxPool3D,
    backendName: 'cpu',
    kernelFunc: maxPool3D
};
79978
79979 /**
79980 * @license
79981 * Copyright 2020 Google LLC. All Rights Reserved.
79982 * Licensed under the Apache License, Version 2.0 (the "License");
79983 * you may not use this file except in compliance with the License.
79984 * You may obtain a copy of the License at
79985 *
79986 * http://www.apache.org/licenses/LICENSE-2.0
79987 *
79988 * Unless required by applicable law or agreed to in writing, software
79989 * distributed under the License is distributed on an "AS IS" BASIS,
79990 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
79991 * See the License for the specific language governing permissions and
79992 * limitations under the License.
79993 * =============================================================================
79994 */
/**
 * CPU kernel for the gradient of 3D max pooling.
 *
 * Iterates over every input (dx) position and accumulates gradient from
 * each output window whose forward-pass maximum was taken at that
 * position, as recorded by `maxPool3dPositions`.
 */
function maxPool3DGrad$1(args) {
    const { inputs, backend, attrs } = args;
    const { dy, input } = inputs;
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    assertNotComplex$1([dy, input], 'maxPool3DGrad');
    const convInfo = computePool3DInfo(input.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
    const inputBuf = backend.bufferSync(input);
    // Per output element: position (within its pooling window) of the max
    // chosen on the forward pass.
    const maxPosBuf = maxPool3dPositions(inputBuf, convInfo);
    const strideDepth = convInfo.strideDepth;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationDepth = convInfo.dilationDepth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterDepth = convInfo.effectiveFilterDepth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    // Padding is mirrored because the loop walks dx and looks "back" at dy.
    const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
    const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
    const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
    const dx = buffer(input.shape, 'float32');
    const dyBuf = backend.bufferSync(dy);
    for (let batch = 0; batch < convInfo.batchSize; ++batch) {
        for (let channel = 0; channel < convInfo.inChannels; ++channel) {
            for (let dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) {
                for (let dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) {
                    for (let dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) {
                        // Shader code begins
                        const dyDepthCorner = dxDepth - padFront;
                        const dyRowCorner = dxRow - padTop;
                        const dyColCorner = dxCol - padLeft;
                        let dotProd = 0;
                        for (let wDepth = 0; wDepth < effectiveFilterDepth; wDepth += dilationDepth) {
                            const dyDepth = (dyDepthCorner + wDepth) / strideDepth;
                            // Skip output coordinates that don't exist: out of
                            // range, or fractional (not on the stride grid).
                            if (dyDepth < 0 || dyDepth >= convInfo.outDepth ||
                                Math.floor(dyDepth) !== dyDepth) {
                                continue;
                            }
                            for (let wRow = 0; wRow < effectiveFilterHeight; wRow += dilationHeight) {
                                const dyRow = (dyRowCorner + wRow) / strideHeight;
                                if (dyRow < 0 || dyRow >= convInfo.outHeight ||
                                    Math.floor(dyRow) !== dyRow) {
                                    continue;
                                }
                                for (let wCol = 0; wCol < effectiveFilterWidth; wCol += dilationWidth) {
                                    const dyCol = (dyColCorner + wCol) / strideWidth;
                                    if (dyCol < 0 || dyCol >= convInfo.outWidth ||
                                        Math.floor(dyCol) !== dyCol) {
                                        continue;
                                    }
                                    // Flip the stored position (it is measured
                                    // from the end of the window) before
                                    // comparing with the current position.
                                    const maxPos = effectiveFilterDepth * effectiveFilterHeight *
                                        effectiveFilterWidth -
                                        1 -
                                        maxPosBuf.get(batch, dyDepth, dyRow, dyCol, channel);
                                    const curPos = wDepth * effectiveFilterHeight * effectiveFilterWidth +
                                        wRow * effectiveFilterWidth + wCol;
                                    // Only the window position that won the
                                    // forward max receives gradient.
                                    const mask = maxPos === curPos ? 1 : 0;
                                    if (mask === 0) {
                                        continue;
                                    }
                                    const pixel = dyBuf.get(batch, dyDepth, dyRow, dyCol, channel);
                                    dotProd += pixel * mask;
                                }
                            }
                        }
                        dx.set(dotProd, batch, dxDepth, dxRow, dxCol, channel);
                    }
                }
            }
        }
    }
    return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
// Registration for the MaxPool3DGrad kernel on the CPU backend.
const maxPool3DGradConfig$1 = {
    kernelName: MaxPool3DGrad,
    backendName: 'cpu',
    kernelFunc: maxPool3DGrad$1
};
80073
80074 /**
80075 * @license
80076 * Copyright 2020 Google LLC. All Rights Reserved.
80077 * Licensed under the Apache License, Version 2.0 (the "License");
80078 * you may not use this file except in compliance with the License.
80079 * You may obtain a copy of the License at
80080 *
80081 * http://www.apache.org/licenses/LICENSE-2.0
80082 *
80083 * Unless required by applicable law or agreed to in writing, software
80084 * distributed under the License is distributed on an "AS IS" BASIS,
80085 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
80086 * See the License for the specific language governing permissions and
80087 * limitations under the License.
80088 * =============================================================================
80089 */
/**
 * CPU kernel for the gradient of 2D max pooling.
 *
 * For each input (dx) position, accumulates gradient from every output
 * window whose forward-pass maximum was taken at that position, as
 * recorded by `maxPoolPositions`.
 */
function maxPoolGrad$1(args) {
    const { inputs, backend, attrs } = args;
    const { dy, input, output } = inputs;
    const x = input;
    assertNotComplex$1([input, output], 'maxPoolGrad');
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
    const xValues = backend.data.get(x.dataId).values;
    // Per output element: position of the forward max within its window.
    const maxPosBuf = buffer(convInfo.outShape, x.dtype, maxPoolPositions(xValues, x.shape, x.dtype, convInfo).values);
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    // Padding is mirrored because the loop walks dx and looks "back" at dy.
    const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
    const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
    const dx = buffer(x.shape, 'float32');
    const dyData = backend.data.get(dy.dataId).values;
    const dyBuf = buffer(dy.shape, 'float32', dyData);
    for (let b = 0; b < convInfo.batchSize; ++b) {
        for (let d = 0; d < convInfo.inChannels; ++d) {
            for (let dxR = 0; dxR < convInfo.inHeight; ++dxR) {
                for (let dxC = 0; dxC < convInfo.inWidth; ++dxC) {
                    // Shader code begins.
                    const dyRCorner = dxR - padTop;
                    const dyCCorner = dxC - padLeft;
                    let dotProd = 0;
                    for (let wR = 0; wR < effectiveFilterHeight; wR += dilationHeight) {
                        const dyR = (dyRCorner + wR) / strideHeight;
                        // Skip output coordinates that don't exist: out of
                        // range, or fractional (not on the stride grid).
                        if (dyR < 0 || dyR >= convInfo.outHeight ||
                            Math.floor(dyR) !== dyR) {
                            continue;
                        }
                        for (let wC = 0; wC < effectiveFilterWidth; wC += dilationWidth) {
                            const dyC = (dyCCorner + wC) / strideWidth;
                            if (dyC < 0 || dyC >= convInfo.outWidth ||
                                Math.floor(dyC) !== dyC) {
                                continue;
                            }
                            // Flip the stored position (measured from the end
                            // of the window) before comparing.
                            const maxPos = effectiveFilterHeight * effectiveFilterWidth - 1 -
                                maxPosBuf.get(b, dyR, dyC, d);
                            const curPos = wR * effectiveFilterWidth + wC;
                            // Only the winning window position gets gradient.
                            const mask = maxPos === curPos ? 1 : 0;
                            if (mask === 0) {
                                continue;
                            }
                            const pixel = dyBuf.get(b, dyR, dyC, d);
                            dotProd += pixel * mask;
                        }
                    }
                    dx.set(dotProd, b, dxR, dxC, d);
                }
            }
        }
    }
    return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
// Registration for the MaxPoolGrad kernel on the CPU backend.
const maxPoolGradConfig$1 = {
    kernelName: MaxPoolGrad,
    backendName: 'cpu',
    kernelFunc: maxPoolGrad$1
};
80153
80154 /**
80155 * @license
80156 * Copyright 2020 Google LLC. All Rights Reserved.
80157 * Licensed under the Apache License, Version 2.0 (the "License");
80158 * you may not use this file except in compliance with the License.
80159 * You may obtain a copy of the License at
80160 *
80161 * http://www.apache.org/licenses/LICENSE-2.0
80162 *
80163 * Unless required by applicable law or agreed to in writing, software
80164 * distributed under the License is distributed on an "AS IS" BASIS,
80165 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
80166 * See the License for the specific language governing permissions and
80167 * limitations under the License.
80168 * =============================================================================
80169 */
/**
 * Computes both outputs of MaxPoolWithArgmax: the pooled maxima and the
 * positions of those maxima, as a pair of flat value arrays.
 */
function maxPoolWithArgmaxImpl$1(xValues, xShape, dtype, includeBatchInIndex, convInfo) {
    const xStrides = computeStrides(xShape);
    const pooled = pool(xValues, xShape, dtype, xStrides, convInfo, 'max');
    // `true` requests flattened positions; `includeBatchInIndex` controls
    // whether the batch dimension is folded into each index.
    const positions = maxPoolPositions(xValues, xShape, dtype, convInfo, true, includeBatchInIndex);
    return [pooled.values, positions.values];
}
80176
80177 /**
80178 * @license
80179 * Copyright 2020 Google LLC. All Rights Reserved.
80180 * Licensed under the Apache License, Version 2.0 (the "License");
80181 * you may not use this file except in compliance with the License.
80182 * You may obtain a copy of the License at
80183 *
80184 * http://www.apache.org/licenses/LICENSE-2.0
80185 *
80186 * Unless required by applicable law or agreed to in writing, software
80187 * distributed under the License is distributed on an "AS IS" BASIS,
80188 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
80189 * See the License for the specific language governing permissions and
80190 * limitations under the License.
80191 * =============================================================================
80192 */
/**
 * CPU kernel for MaxPoolWithArgmax: returns two tensors — the pooled
 * maxima and the index of each maximum (reported as int32).
 */
const maxPoolWithArgmaxConfig$1 = {
    kernelName: MaxPoolWithArgmax,
    backendName: 'cpu',
    kernelFunc: ({ inputs, attrs, backend }) => {
        const { x } = inputs;
        const { filterSize, strides, pad, includeBatchInIndex } = attrs;
        const cpuBackend = backend;
        assertNotComplex$1(x, 'MaxPoolWithArgmax');
        const values = cpuBackend.data.get(x.dataId).values;
        // Dilations are fixed at [1, 1] for this op.
        const convInfo = computePool2DInfo(x.shape, filterSize, strides, [1, 1], pad);
        const [pooled, indexes] = maxPoolWithArgmaxImpl$1(values, x.shape, x.dtype, includeBatchInIndex, convInfo);
        const pooledDataId = cpuBackend.write(pooled, convInfo.outShape, x.dtype);
        // NOTE(review): indexes are written with x.dtype here but the
        // returned TensorInfo declares 'int32' — confirm this mismatch is
        // intentional before relying on the stored dtype.
        const indexesDataId = cpuBackend.write(indexes, convInfo.outShape, x.dtype);
        return [
            { dataId: pooledDataId, shape: convInfo.outShape, dtype: x.dtype },
            { dataId: indexesDataId, shape: convInfo.outShape, dtype: 'int32' }
        ];
    }
};
80212
80213 /**
80214 * @license
80215 * Copyright 2020 Google LLC. All Rights Reserved.
80216 * Licensed under the Apache License, Version 2.0 (the "License");
80217 * you may not use this file except in compliance with the License.
80218 * You may obtain a copy of the License at
80219 *
80220 * http://www.apache.org/licenses/LICENSE-2.0
80221 *
80222 * Unless required by applicable law or agreed to in writing, software
80223 * distributed under the License is distributed on an "AS IS" BASIS,
80224 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
80225 * See the License for the specific language governing permissions and
80226 * limitations under the License.
80227 * =============================================================================
80228 */
/**
 * CPU kernel for Mean: mean(x) over `axis`, computed as sum(x / n) by
 * composing the Cast, Div (RealDiv) and Sum kernels.
 */
function mean(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    const axes = parseAxisParam(axis, x.shape);
    const [, reduceShape] = computeOutAndReduceShapes(x.shape, axes);
    const reduceSize = sizeFromShape(reduceShape);
    const intermediates = [];
    // Scalar divisor holding the number of reduced elements.
    const divisor = backend.makeTensorInfo([], 'float32', new Float32Array([reduceSize]));
    intermediates.push(divisor);
    // Work in float32 so integer inputs divide correctly.
    const xFloat = cast$1({ inputs: { x }, backend, attrs: { dtype: 'float32' } });
    intermediates.push(xFloat);
    // Divide first, then reduce with Sum.
    const scaled = div({ inputs: { a: xFloat, b: divisor }, backend });
    intermediates.push(scaled);
    const out = sum$1({ inputs: { x: scaled }, backend, attrs: { axis, keepDims } });
    intermediates.forEach((t) => backend.disposeIntermediateTensorInfo(t));
    return out;
}
// Registration for the Mean kernel on the CPU backend.
const meanConfig$1 = {
    kernelName: Mean,
    backendName: 'cpu',
    kernelFunc: mean
};
80253
80254 /**
80255 * @license
80256 * Copyright 2020 Google LLC. All Rights Reserved.
80257 * Licensed under the Apache License, Version 2.0 (the "License");
80258 * you may not use this file except in compliance with the License.
80259 * You may obtain a copy of the License at
80260 *
80261 * http://www.apache.org/licenses/LICENSE-2.0
80262 *
80263 * Unless required by applicable law or agreed to in writing, software
80264 * distributed under the License is distributed on an "AS IS" BASIS,
80265 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
80266 * See the License for the specific language governing permissions and
80267 * limitations under the License.
80268 * =============================================================================
80269 */
/**
 * CPU kernel for Min: reduces `x` by taking the minimum over `axis`,
 * propagating NaN. When `keepDims` is true the reduced axes are kept as
 * size-1 dimensions.
 */
function min$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    assertNotComplex$1(x, 'min');
    const origAxes = parseAxisParam(axis, x.shape);
    let axes = origAxes;
    // If the reduction axes are not innermost, transpose so they are.
    const permutedAxes = getAxesPermutation(axes, x.shape.length);
    let $x = x;
    if (permutedAxes != null) {
        $x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
        axes = getInnerMostAxes(axes.length, x.shape.length);
    }
    assertAxesAreInnerMostDims('min', axes, $x.shape.length);
    const [outShape, reduceShape] = computeOutAndReduceShapes($x.shape, axes);
    // Each output element is the min over `reduceSize` consecutive values.
    const reduceSize = sizeFromShape(reduceShape);
    const vals = makeZerosTypedArray(sizeFromShape(outShape), $x.dtype);
    const aVals = backend.data.get($x.dataId).values;
    for (let i = 0; i < vals.length; ++i) {
        const offset = i * reduceSize;
        let min = aVals[offset];
        for (let j = 0; j < reduceSize; ++j) {
            const value = aVals[offset + j];
            // NaN is checked explicitly so it propagates: `value < min` is
            // always false when either operand is NaN.
            if (Number.isNaN(value) ||
                value < min) { // comparison with NaN always return false
                min = value;
            }
        }
        vals[i] = min;
    }
    if (permutedAxes != null) {
        // Free the transposed intermediate.
        backend.disposeIntermediateTensorInfo($x);
    }
    const result = backend.makeTensorInfo(outShape, $x.dtype, vals);
    if (keepDims) {
        // Re-insert the reduced axes as size-1 dimensions.
        const expandedShape = expandShapeToKeepDim(outShape, origAxes);
        const reshapedResult = reshape$1({ inputs: { x: result }, backend, attrs: { shape: expandedShape } });
        backend.disposeIntermediateTensorInfo(result);
        return reshapedResult;
    }
    return result;
}
// Registration for the Min kernel on the CPU backend.
const minConfig$1 = {
    kernelName: Min,
    backendName: 'cpu',
    kernelFunc: min$1
};
80317
80318 /**
80319 * @license
80320 * Copyright 2020 Google LLC. All Rights Reserved.
80321 * Licensed under the Apache License, Version 2.0 (the "License");
80322 * you may not use this file except in compliance with the License.
80323 * You may obtain a copy of the License at
80324 *
80325 * http://www.apache.org/licenses/LICENSE-2.0
80326 *
80327 * Unless required by applicable law or agreed to in writing, software
80328 * distributed under the License is distributed on an "AS IS" BASIS,
80329 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
80330 * See the License for the specific language governing permissions and
80331 * limitations under the License.
80332 * =============================================================================
80333 */
/**
 * CPU kernel for MirrorPad: pads `x` by mirroring values at each border.
 * 'reflect' mode excludes the border element from the mirror; 'symmetric'
 * includes it (controlled by `offset` below).
 */
function mirrorPad(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { paddings, mode } = attrs;
    assertNotComplex$1(x, 'mirrorPad');
    // Each output dim grows by beforePad + afterPad.
    const outShape = paddings.map((p, i) => p[0] /* beforePad */ + x.shape[i] + p[1] /* afterPad */);
    const start = paddings.map(p => p[0]);
    const end = paddings.map((p, i) => p[0] + x.shape[i]);
    const offset = mode === 'reflect' ? 0 : 1;
    const xVals = backend.data.get(x.dataId).values;
    const xRank = x.shape.length;
    const xStrides = computeStrides(x.shape);
    const resultSize = sizeFromShape(outShape);
    const resultRank = outShape.length;
    const resultStrides = computeStrides(outShape);
    const resVals = getTypedArrayFromDType(x.dtype, resultSize);
    for (let i = 0; i < resultSize; i++) {
        let coords = indexToLoc(i, resultRank, resultStrides);
        // FIX: the dimension loop previously redeclared `i`, shadowing the
        // flat output index above. `let` scoping made it work, but the
        // shadowing was fragile and misleading; renamed to `d`.
        for (let d = 0; d < resultRank; d++) {
            // Reflect coordinates that fall in the left or right pad region
            // back into the valid [start, end) interval.
            if (coords[d] < start[d]) {
                coords[d] = start[d] * 2 - coords[d] - offset;
            }
            else if (coords[d] >= end[d]) {
                coords[d] = (end[d] - 1) * 2 - coords[d] + offset;
            }
        }
        // Shift back into input-coordinate space and copy the source value.
        coords = coords.map((c, d) => c - start[d]);
        const inIndex = locToIndex(coords, xRank, xStrides);
        resVals[i] = xVals[inIndex];
    }
    const outId = backend.write(resVals, outShape, x.dtype);
    return { dataId: outId, shape: outShape, dtype: x.dtype };
}
const mirrorPadConfig$1 = {
    kernelName: MirrorPad,
    backendName: 'cpu',
    kernelFunc: mirrorPad
};
80372
80373 /**
80374 * @license
80375 * Copyright 2020 Google LLC. All Rights Reserved.
80376 * Licensed under the Apache License, Version 2.0 (the "License");
80377 * you may not use this file except in compliance with the License.
80378 * You may obtain a copy of the License at
80379 *
80380 * http://www.apache.org/licenses/LICENSE-2.0
80381 *
80382 * Unless required by applicable law or agreed to in writing, software
80383 * distributed under the License is distributed on an "AS IS" BASIS,
80384 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
80385 * See the License for the specific language governing permissions and
80386 * limitations under the License.
80387 * =============================================================================
80388 */
/**
 * Element-wise Mod. When the operands share a sign the plain JS remainder
 * is already correct; otherwise shift it by the divisor so the result takes
 * the sign of the divisor (Python-style modulo).
 */
const modImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => {
    const rem = aValue % bValue;
    const sameSign = (aValue < 0) === (bValue < 0);
    return sameSign ? rem : (rem + bValue) % bValue;
}));
const mod$1 = binaryKernelFunc$1(Mod, modImpl);
const modConfig$1 = {
    kernelName: Mod,
    backendName: 'cpu',
    kernelFunc: mod$1
};
80404
80405 /**
80406 * @license
80407 * Copyright 2020 Google LLC. All Rights Reserved.
80408 * Licensed under the Apache License, Version 2.0 (the "License");
80409 * you may not use this file except in compliance with the License.
80410 * You may obtain a copy of the License at
80411 *
80412 * http://www.apache.org/licenses/LICENSE-2.0
80413 *
80414 * Unless required by applicable law or agreed to in writing, software
80415 * distributed under the License is distributed on an "AS IS" BASIS,
80416 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
80417 * See the License for the specific language governing permissions and
80418 * limitations under the License.
80419 * =============================================================================
80420 */
/**
 * CPU kernel for Softmax along the last dimension:
 * exp(logits - max(logits)) / sum(exp(logits - max(logits))).
 * The max is subtracted first for numerical stability. Only the last
 * axis is supported; other axes throw.
 */
function softmax$1(args) {
    const { inputs, backend, attrs } = args;
    const { logits } = inputs;
    const { dim } = attrs;
    const logitsRank = logits.shape.length;
    let $dim = dim;
    if ($dim === -1) {
        $dim = logitsRank - 1;
    }
    if ($dim !== logitsRank - 1) {
        throw Error('Softmax along a non-last dimension is not yet supported. ' +
            `Logits was rank ${logitsRank} and dim was ${$dim}`);
    }
    const axes = parseAxisParam([$dim], logits.shape);
    // Per-row maximum, later broadcast back via a keep-dims reshape.
    const maxLogit = max$1({
        inputs: { x: logits },
        backend,
        attrs: { reductionIndices: axes, keepDims: false }
    });
    const keepDimsShape = expandShapeToKeepDim(maxLogit.shape, axes);
    const maxReshaped = reshape$1({ inputs: { x: maxLogit }, backend, attrs: { shape: keepDimsShape } });
    const shifted = sub$1({ inputs: { a: logits, b: maxReshaped }, backend });
    const exponentiated = exp$1({ inputs: { x: shifted }, backend });
    const sumExp = sum$1({ inputs: { x: exponentiated }, backend, attrs: { axis: axes, keepDims: false } });
    const sumReshaped = reshape$1({ inputs: { x: sumExp }, backend, attrs: { shape: keepDimsShape } });
    const result = div({ inputs: { a: exponentiated, b: sumReshaped }, backend });
    // Release every intermediate tensor created above.
    for (const t of [maxLogit, maxReshaped, shifted, exponentiated, sumExp, sumReshaped]) {
        backend.disposeIntermediateTensorInfo(t);
    }
    return result;
}
const softmaxConfig$1 = {
    kernelName: Softmax$2,
    backendName: 'cpu',
    kernelFunc: softmax$1
};
80460
80461 /**
80462 * @license
80463 * Copyright 2020 Google LLC. All Rights Reserved.
80464 * Licensed under the Apache License, Version 2.0 (the "License");
80465 * you may not use this file except in compliance with the License.
80466 * You may obtain a copy of the License at
80467 *
80468 * http://www.apache.org/licenses/LICENSE-2.0
80469 *
80470 * Unless required by applicable law or agreed to in writing, software
80471 * distributed under the License is distributed on an "AS IS" BASIS,
80472 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
80473 * See the License for the specific language governing permissions and
80474 * limitations under the License.
80475 * =============================================================================
80476 */
/**
 * CPU kernel for Multinomial: draws `numSamples` class indices per batch
 * row from the distribution given by `logits` (softmaxed first unless
 * `normalized` is true). Output shape is [batchSize, numSamples], int32.
 */
function multinomial$1(args) {
    const { inputs, backend, attrs } = args;
    const { logits } = inputs;
    const { numSamples, seed, normalized } = attrs;
    assertNotComplex$1(logits, 'multinomial');
    // Convert raw logits into probabilities when they are not pre-normalized.
    const probabilities = normalized ?
        logits :
        softmax$1({ inputs: { logits }, backend, attrs: { dim: -1 } });
    const batchSize = probabilities.shape[0];
    const numEvents = probabilities.shape[1];
    const probVals = backend.data.get(probabilities.dataId).values;
    const resShape = [batchSize, numSamples];
    const resVals = makeZerosTypedArray(sizeFromShape(resShape), 'int32');
    for (let b = 0; b < batchSize; ++b) {
        const offset = b * numEvents;
        // The cdf won't include the last event. It will be implicit if no other
        // event happened.
        const cdf = new Float32Array(numEvents - 1);
        cdf[0] = probVals[offset];
        for (let event = 1; event < cdf.length; ++event) {
            cdf[event] = cdf[event - 1] + probVals[offset + event];
        }
        // NOTE(review): the RNG is re-seeded with the same `seed` on every
        // batch iteration, so all rows draw an identical random sequence —
        // confirm this is intended rather than seeding once outside the loop.
        const random = seedrandom.alea(seed.toString());
        const outOffset = b * numSamples;
        for (let sampleId = 0; sampleId < numSamples; ++sampleId) {
            const r = random();
            // Assume last event happened by default.
            resVals[outOffset + sampleId] = cdf.length;
            // Inverse-CDF sampling: the first cdf entry exceeding r wins.
            for (let event = 0; event < cdf.length; event++) {
                if (r < cdf[event]) {
                    resVals[outOffset + sampleId] = event;
                    break;
                }
            }
        }
    }
    if (!normalized) {
        // The softmax result was an intermediate tensor.
        backend.disposeIntermediateTensorInfo(probabilities);
    }
    return backend.makeTensorInfo(resShape, 'int32', resVals);
}
const multinomialConfig$1 = {
    kernelName: Multinomial,
    backendName: 'cpu',
    kernelFunc: multinomial$1
};
80523
80524 /**
80525 * @license
80526 * Copyright 2020 Google LLC. All Rights Reserved.
80527 * Licensed under the Apache License, Version 2.0 (the "License");
80528 * you may not use this file except in compliance with the License.
80529 * You may obtain a copy of the License at
80530 *
80531 * http://www.apache.org/licenses/LICENSE-2.0
80532 *
80533 * Unless required by applicable law or agreed to in writing, software
80534 * distributed under the License is distributed on an "AS IS" BASIS,
80535 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
80536 * See the License for the specific language governing permissions and
80537 * limitations under the License.
80538 * =============================================================================
80539 */
const nonMaxSuppressionV3Impl$1 = nonMaxSuppressionV3Impl$2;
/**
 * CPU kernel for NonMaxSuppressionV3: selects up to `maxOutputSize` box
 * indices, greedily by score, suppressing boxes whose IoU with an already
 * selected box exceeds `iouThreshold`.
 */
function nonMaxSuppressionV3$1(args) {
    const { inputs, backend, attrs } = args;
    const { boxes, scores } = inputs;
    const { maxOutputSize, iouThreshold, scoreThreshold } = attrs;
    assertNotComplex$1(boxes, 'NonMaxSuppression');
    const boxesData = backend.data.get(boxes.dataId).values;
    const scoresData = backend.data.get(scores.dataId).values;
    const nmsResult = nonMaxSuppressionV3Impl$1(boxesData, scoresData, maxOutputSize, iouThreshold, scoreThreshold);
    const { selectedIndices } = nmsResult;
    return backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices));
}
const nonMaxSuppressionV3Config$1 = {
    kernelName: NonMaxSuppressionV3,
    backendName: 'cpu',
    kernelFunc: nonMaxSuppressionV3$1
};
80556
80557 /**
80558 * @license
80559 * Copyright 2020 Google LLC. All Rights Reserved.
80560 * Licensed under the Apache License, Version 2.0 (the "License");
80561 * you may not use this file except in compliance with the License.
80562 * You may obtain a copy of the License at
80563 *
80564 * http://www.apache.org/licenses/LICENSE-2.0
80565 *
80566 * Unless required by applicable law or agreed to in writing, software
80567 * distributed under the License is distributed on an "AS IS" BASIS,
80568 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
80569 * See the License for the specific language governing permissions and
80570 * limitations under the License.
80571 * =============================================================================
80572 */
const nonMaxSuppressionV4Impl$1 = nonMaxSuppressionV4Impl$2;
/**
 * CPU kernel for NonMaxSuppressionV4 (padded variant): returns the selected
 * box indices (optionally padded to maxOutputSize) plus a scalar count of
 * valid entries.
 */
function nonMaxSuppressionV4$1(args) {
    const { inputs, backend, attrs } = args;
    const { boxes, scores } = inputs;
    const { maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize } = attrs;
    assertNotComplex$1(boxes, 'NonMaxSuppressionPadded');
    const boxesData = backend.data.get(boxes.dataId).values;
    const scoresData = backend.data.get(scores.dataId).values;
    const { selectedIndices, validOutputs } = nonMaxSuppressionV4Impl$1(boxesData, scoresData, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize);
    const indicesTensor = backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices));
    const validCountTensor = backend.makeTensorInfo([], 'int32', new Int32Array([validOutputs]));
    return [indicesTensor, validCountTensor];
}
const nonMaxSuppressionV4Config$1 = {
    kernelName: NonMaxSuppressionV4,
    backendName: 'cpu',
    kernelFunc: nonMaxSuppressionV4$1
};
80592
80593 /**
80594 * @license
80595 * Copyright 2019 Google LLC. All Rights Reserved.
80596 * Licensed under the Apache License, Version 2.0 (the "License");
80597 * you may not use this file except in compliance with the License.
80598 * You may obtain a copy of the License at
80599 *
80600 * http://www.apache.org/licenses/LICENSE-2.0
80601 *
80602 * Unless required by applicable law or agreed to in writing, software
80603 * distributed under the License is distributed on an "AS IS" BASIS,
80604 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
80605 * See the License for the specific language governing permissions and
80606 * limitations under the License.
80607 * =============================================================================
80608 */
const nonMaxSuppressionV5Impl$1 = nonMaxSuppressionV5Impl$2;
/**
 * CPU kernel for NonMaxSuppressionV5 (Soft-NMS): returns the selected box
 * indices together with their (possibly sigma-decayed) scores.
 */
function nonMaxSuppressionV5$1(args) {
    const { inputs, backend, attrs } = args;
    const { boxes, scores } = inputs;
    const { maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma } = attrs;
    assertNotComplex$1(boxes, 'NonMaxSuppressionWithScore');
    const boxesData = backend.data.get(boxes.dataId).values;
    const scoresData = backend.data.get(scores.dataId).values;
    // The attrs are passed through to the shared impl unchanged.
    const { selectedIndices, selectedScores } = nonMaxSuppressionV5Impl$1(boxesData, scoresData, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
    const indicesTensor = backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices));
    const scoresTensor = backend.makeTensorInfo([selectedScores.length], 'float32', new Float32Array(selectedScores));
    return [indicesTensor, scoresTensor];
}
const nonMaxSuppressionV5Config$1 = {
    kernelName: NonMaxSuppressionV5,
    backendName: 'cpu',
    kernelFunc: nonMaxSuppressionV5$1
};
80632
80633 /**
80634 * @license
80635 * Copyright 2020 Google LLC. All Rights Reserved.
80636 * Licensed under the Apache License, Version 2.0 (the "License");
80637 * you may not use this file except in compliance with the License.
80638 * You may obtain a copy of the License at
80639 *
80640 * http://www.apache.org/licenses/LICENSE-2.0
80641 *
80642 * Unless required by applicable law or agreed to in writing, software
80643 * distributed under the License is distributed on an "AS IS" BASIS,
80644 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
80645 * See the License for the specific language governing permissions and
80646 * limitations under the License.
80647 * =============================================================================
80648 */
/**
 * CPU kernel for OneHot: expands each index in `indices` into a length-`depth`
 * vector that is `onValue` at the index position and `offValue` elsewhere.
 * Out-of-range indices produce an all-`offValue` row.
 */
function oneHot$1(args) {
    const { inputs, backend, attrs } = args;
    const { indices } = inputs;
    const { dtype, depth, onValue, offValue } = attrs;
    assertNotComplex$1(indices, 'oneHot');
    const indicesSize = sizeFromShape(indices.shape);
    // FIX: allocate the buffer with the requested output dtype instead of
    // unconditionally using Float32Array, so the stored values match the
    // dtype reported on the returned tensor. (dtype is assumed numeric —
    // 'string' one-hot is not supported on this path.)
    const res = getTypedArrayFromDType(dtype, indicesSize * depth);
    res.fill(offValue);
    const indicesVal = backend.data.get(indices.dataId).values;
    for (let event = 0; event < indicesSize; ++event) {
        // Indices outside [0, depth) are silently skipped (row stays offValue).
        if (indicesVal[event] >= 0 && indicesVal[event] < depth) {
            res[event * depth + indicesVal[event]] = onValue;
        }
    }
    return backend.makeTensorInfo([...indices.shape, depth], dtype, res);
}
const oneHotConfig$1 = {
    kernelName: OneHot,
    backendName: 'cpu',
    kernelFunc: oneHot$1
};
80670
80671 /**
80672 * @license
80673 * Copyright 2020 Google LLC. All Rights Reserved.
80674 * Licensed under the Apache License, Version 2.0 (the "License");
80675 * you may not use this file except in compliance with the License.
80676 * You may obtain a copy of the License at
80677 *
80678 * http://www.apache.org/licenses/LICENSE-2.0
80679 *
80680 * Unless required by applicable law or agreed to in writing, software
80681 * distributed under the License is distributed on an "AS IS" BASIS,
80682 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
80683 * See the License for the specific language governing permissions and
80684 * limitations under the License.
80685 * =============================================================================
80686 */
/**
 * CPU kernel for ZerosLike: returns a tensor of zeros with the shape and
 * dtype of `x`. Complex tensors are built by zeroing both components;
 * string tensors are rejected.
 */
function zerosLike$1(args) {
    const { inputs, backend } = args;
    const { x } = inputs;
    if (x.dtype === 'string') {
        throw new Error('zerosLike is not supported for string tensors');
    }
    if (x.dtype !== 'complex64') {
        // Numeric case: a plain fill with 0.
        return fill$1({ backend, attrs: { shape: x.shape, value: 0, dtype: x.dtype } });
    }
    // Complex case: zero the real and imaginary parts separately, then
    // recombine. Every piece along the way is an intermediate.
    const realPart = real$1({ inputs: { input: x }, backend });
    const r = zerosLike$1({ inputs: { x: realPart }, backend });
    const imagPart = imag$1({ inputs: { input: x }, backend });
    const i = zerosLike$1({ inputs: { x: imagPart }, backend });
    const result = complex$1({ inputs: { real: r, imag: i }, backend });
    for (const t of [realPart, r, imagPart, i]) {
        backend.disposeIntermediateTensorInfo(t);
    }
    return result;
}
const zerosLikeConfig$1 = {
    kernelName: ZerosLike,
    backendName: 'cpu',
    kernelFunc: zerosLike$1
};
80714
80715 /**
80716 * @license
80717 * Copyright 2020 Google LLC. All Rights Reserved.
80718 * Licensed under the Apache License, Version 2.0 (the "License");
80719 * you may not use this file except in compliance with the License.
80720 * You may obtain a copy of the License at
80721 *
80722 * http://www.apache.org/licenses/LICENSE-2.0
80723 *
80724 * Unless required by applicable law or agreed to in writing, software
80725 * distributed under the License is distributed on an "AS IS" BASIS,
80726 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
80727 * See the License for the specific language governing permissions and
80728 * limitations under the License.
80729 * =============================================================================
80730 */
/**
 * CPU kernel for OnesLike: returns a tensor of ones with the shape and
 * dtype of `x`. For complex64 the result is 1 + 0i (real part ones,
 * imaginary part zeros); string tensors are rejected.
 */
function onesLike$1(args) {
    const { inputs, backend } = args;
    const { x } = inputs;
    if (x.dtype === 'string') {
        throw new Error('onesLike is not supported for string tensors');
    }
    if (x.dtype !== 'complex64') {
        // Numeric case: a plain fill with 1.
        return fill$1({ backend, attrs: { shape: x.shape, value: 1, dtype: x.dtype } });
    }
    // Complex case: ones for the real component, zeros for the imaginary one.
    const realPart = real$1({ inputs: { input: x }, backend });
    const r = onesLike$1({ inputs: { x: realPart }, backend });
    const imagPart = imag$1({ inputs: { input: x }, backend });
    const i = zerosLike$1({ inputs: { x: imagPart }, backend });
    const result = complex$1({ inputs: { real: r, imag: i }, backend });
    for (const t of [realPart, r, imagPart, i]) {
        backend.disposeIntermediateTensorInfo(t);
    }
    return result;
}
const onesLikeConfig$1 = {
    kernelName: OnesLike,
    backendName: 'cpu',
    kernelFunc: onesLike$1
};
80758
80759 /**
80760 * @license
80761 * Copyright 2020 Google LLC. All Rights Reserved.
80762 * Licensed under the Apache License, Version 2.0 (the "License");
80763 * you may not use this file except in compliance with the License.
80764 * You may obtain a copy of the License at
80765 *
80766 * http://www.apache.org/licenses/LICENSE-2.0
80767 *
80768 * Unless required by applicable law or agreed to in writing, software
80769 * distributed under the License is distributed on an "AS IS" BASIS,
80770 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
80771 * See the License for the specific language governing permissions and
80772 * limitations under the License.
80773 * =============================================================================
80774 */
/**
 * CPU kernel for Pack (stack): stacks N same-shaped, same-dtype tensors
 * along a new `axis` dimension by expanding each and concatenating.
 * A single input degenerates to ExpandDims.
 */
function pack$1(args) {
    const { inputs, backend, attrs } = args;
    const { axis } = attrs;
    if (inputs.length === 1) {
        // Stacking one tensor is just inserting the new dimension.
        return expandDims$1({ inputs: { input: inputs[0] }, backend, attrs: { dim: axis } });
    }
    const referenceShape = inputs[0].shape;
    const referenceDtype = inputs[0].dtype;
    // All inputs must agree on shape and dtype.
    inputs.forEach(t => {
        assertShapesMatch(referenceShape, t.shape, 'All tensors passed to stack must have matching shapes');
        assert$1(referenceDtype === t.dtype, () => 'All tensors passed to stack must have matching dtypes');
    });
    const expandedInputs = [];
    for (const t of inputs) {
        expandedInputs.push(expandDims$1({ inputs: { input: t }, backend, attrs: { dim: axis } }));
    }
    const stacked = concat$1({ inputs: expandedInputs, backend, attrs: { axis } });
    // The expanded views were intermediates.
    expandedInputs.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return stacked;
}
const packConfig$1 = {
    kernelName: Pack,
    backendName: 'cpu',
    kernelFunc: pack$1
};
80802
80803 /**
80804 * @license
80805 * Copyright 2020 Google LLC. All Rights Reserved.
80806 * Licensed under the Apache License, Version 2.0 (the "License");
80807 * you may not use this file except in compliance with the License.
80808 * You may obtain a copy of the License at
80809 *
80810 * http://www.apache.org/licenses/LICENSE-2.0
80811 *
80812 * Unless required by applicable law or agreed to in writing, software
80813 * distributed under the License is distributed on an "AS IS" BASIS,
80814 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
80815 * See the License for the specific language governing permissions and
80816 * limitations under the License.
80817 * =============================================================================
80818 */
/**
 * CPU kernel for PadV2: pads `x` with `constantValue` according to
 * per-dimension [beforePad, afterPad] pairs.
 */
function padV2$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { paddings, constantValue } = attrs;
    assertNotComplex$1(x, 'pad');
    // Each output dim is beforePad + inputDim + afterPad.
    const outShape = paddings.map((p, i) => p[0] /* beforePad */ + x.shape[i] + p[1] /* afterPad */);
    const start = paddings.map(p => p[0]);
    const xVals = backend.data.get(x.dataId).values;
    const xSize = sizeFromShape(x.shape);
    const xRank = x.shape.length;
    const xStrides = computeStrides(x.shape);
    const resultSize = sizeFromShape(outShape);
    const resultRank = outShape.length;
    const resultStrides = computeStrides(outShape);
    const resVals = getTypedArrayFromDType(x.dtype, resultSize);
    // Typed arrays are zero-initialized, so only fill for non-zero padding.
    if (constantValue !== 0) {
        resVals.fill(constantValue);
    }
    // Copy every input element to its shifted position in the output.
    for (let flatIdx = 0; flatIdx < xSize; flatIdx++) {
        const inCoords = indexToLoc(flatIdx, xRank, xStrides);
        const outCoords = inCoords.map((c, d) => c + start[d]);
        resVals[locToIndex(outCoords, resultRank, resultStrides)] = xVals[flatIdx];
    }
    const outId = backend.write(resVals, outShape, x.dtype);
    return { dataId: outId, shape: outShape, dtype: x.dtype };
}
const padV2Config$1 = {
    kernelName: PadV2,
    backendName: 'cpu',
    kernelFunc: padV2$1
};
80851
80852 /**
80853 * @license
80854 * Copyright 2020 Google LLC. All Rights Reserved.
80855 * Licensed under the Apache License, Version 2.0 (the "License");
80856 * you may not use this file except in compliance with the License.
80857 * You may obtain a copy of the License at
80858 *
80859 * http://www.apache.org/licenses/LICENSE-2.0
80860 *
80861 * Unless required by applicable law or agreed to in writing, software
80862 * distributed under the License is distributed on an "AS IS" BASIS,
80863 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
80864 * See the License for the specific language governing permissions and
80865 * limitations under the License.
80866 * =============================================================================
80867 */
// Element-wise Pow: raises each base element to the corresponding exponent.
const powImpl = createSimpleBinaryKernelImpl((base, exponent) => Math.pow(base, exponent));
const pow$1 = binaryKernelFunc$1(Pow, powImpl);
const powConfig$1 = {
    kernelName: Pow,
    backendName: 'cpu',
    kernelFunc: pow$1
};
80875
80876 /**
80877 * @license
80878 * Copyright 2022 Google LLC. All Rights Reserved.
80879 * Licensed under the Apache License, Version 2.0 (the "License");
80880 * you may not use this file except in compliance with the License.
80881 * You may obtain a copy of the License at
80882 *
80883 * http://www.apache.org/licenses/LICENSE-2.0
80884 *
80885 * Unless required by applicable law or agreed to in writing, software
80886 * distributed under the License is distributed on an "AS IS" BASIS,
80887 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
80888 * See the License for the specific language governing permissions and
80889 * limitations under the License.
80890 * =============================================================================
80891 */
/**
 * CPU kernel for RaggedGather: gathers rows from a ragged tensor
 * (nested splits + dense values) by `indices`, delegating the heavy
 * lifting to raggedGatherImpl. Returns the output nested-splits tensors
 * followed by the output dense-values tensor.
 */
function raggedGather$1(args) {
    const { inputs, backend, attrs } = args;
    const { paramsNestedSplits, paramsDenseValues, indices } = inputs;
    const { outputRaggedRank } = attrs;
    const splitsValues = paramsNestedSplits.map(t => backend.data.get(t.dataId).values);
    const splitsShapes = paramsNestedSplits.map(t => t.shape);
    const denseValues = backend.data.get(paramsDenseValues.dataId).values;
    const indicesValues = backend.data.get(indices.dataId).values;
    const [outputNestedSplits, outputDenseValues, outputDenseValuesShape] = raggedGatherImpl(splitsValues, splitsShapes, denseValues, paramsDenseValues.shape, paramsDenseValues.dtype, indicesValues, indices.shape, outputRaggedRank);
    const splitsTensors = outputNestedSplits.map((splits) => backend.makeTensorInfo([splits.length], 'int32', splits));
    const denseTensor = backend.makeTensorInfo(outputDenseValuesShape, paramsDenseValues.dtype, outputDenseValues);
    return splitsTensors.concat([denseTensor]);
}
const raggedGatherConfig$1 = {
    kernelName: RaggedGather,
    backendName: 'cpu',
    kernelFunc: raggedGather$1,
};
80910
80911 /**
80912 * @license
80913 * Copyright 2022 Google LLC.
80914 * Licensed under the Apache License, Version 2.0 (the "License");
80915 * you may not use this file except in compliance with the License.
80916 * You may obtain a copy of the License at
80917 *
80918 * http://www.apache.org/licenses/LICENSE-2.0
80919 *
80920 * Unless required by applicable law or agreed to in writing, software
80921 * distributed under the License is distributed on an "AS IS" BASIS,
80922 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
80923 * See the License for the specific language governing permissions and
80924 * limitations under the License.
80925 * =============================================================================
80926 */
/**
 * CPU kernel for RaggedRange: builds per-row ranges from (starts, limits,
 * deltas) and returns [rowSplits (int32), denseValues (starts.dtype)].
 */
function raggedRange$1(args) {
    const { inputs, backend } = args;
    const { starts, limits, deltas } = inputs;
    const startsValues = backend.data.get(starts.dataId).values;
    const limitsValues = backend.data.get(limits.dataId).values;
    const deltasValues = backend.data.get(deltas.dataId).values;
    const [rtNestedSplitsData, rtDenseValuesData] = raggedRangeImpl(startsValues, starts.shape, starts.dtype, limitsValues, limits.shape, deltasValues, deltas.shape);
    const splitsTensor = backend.makeTensorInfo([rtNestedSplitsData.length], 'int32', rtNestedSplitsData);
    const valuesTensor = backend.makeTensorInfo([rtDenseValuesData.length], starts.dtype, rtDenseValuesData);
    return [splitsTensor, valuesTensor];
}
const raggedRangeConfig$1 = {
    kernelName: RaggedRange,
    backendName: 'cpu',
    kernelFunc: raggedRange$1,
};
80943
80944 /**
80945 * @license
80946 * Copyright 2022 Google LLC. All Rights Reserved.
80947 * Licensed under the Apache License, Version 2.0 (the "License");
80948 * you may not use this file except in compliance with the License.
80949 * You may obtain a copy of the License at
80950 *
80951 * http://www.apache.org/licenses/LICENSE-2.0
80952 *
80953 * Unless required by applicable law or agreed to in writing, software
80954 * distributed under the License is distributed on an "AS IS" BASIS,
80955 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
80956 * See the License for the specific language governing permissions and
80957 * limitations under the License.
80958 * =============================================================================
80959 */
/**
 * CPU kernel for RaggedTensorToTensor: converts a ragged tensor (values +
 * row-partition tensors) into a dense tensor of the requested shape,
 * filling missing positions with `defaultValue`.
 */
function raggedTensorToTensor$1(args) {
    const { inputs, backend, attrs } = args;
    const { shape, values, defaultValue, rowPartitionTensors } = inputs;
    const { rowPartitionTypes } = attrs;
    const shapeValues = backend.data.get(shape.dataId).values;
    const valuesData = backend.data.get(values.dataId).values;
    const defaultValueData = backend.data.get(defaultValue.dataId).values;
    const partitionValues = rowPartitionTensors.map(t => backend.data.get(t.dataId).values);
    const partitionShapes = rowPartitionTensors.map(t => t.shape);
    const [outputShape, output] = raggedTensorToTensorImpl(shapeValues, shape.shape, valuesData, values.shape, values.dtype, defaultValueData, defaultValue.shape, partitionValues, partitionShapes, rowPartitionTypes);
    return backend.makeTensorInfo(outputShape, values.dtype, output);
}
const raggedTensorToTensorConfig$1 = {
    kernelName: RaggedTensorToTensor,
    backendName: 'cpu',
    kernelFunc: raggedTensorToTensor$1,
};
80977
80978 /**
80979 * @license
80980 * Copyright 2020 Google LLC. All Rights Reserved.
80981 * Licensed under the Apache License, Version 2.0 (the "License");
80982 * you may not use this file except in compliance with the License.
80983 * You may obtain a copy of the License at
80984 *
80985 * http://www.apache.org/licenses/LICENSE-2.0
80986 *
80987 * Unless required by applicable law or agreed to in writing, software
80988 * distributed under the License is distributed on an "AS IS" BASIS,
80989 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
80990 * See the License for the specific language governing permissions and
80991 * limitations under the License.
80992 * =============================================================================
80993 */
/**
 * CPU kernel for Range: materializes the arithmetic sequence from
 * `start` towards `stop` advancing by `step`, as a rank-1 tensor of
 * `dtype` (the heavy lifting is done by rangeImpl).
 */
function range$1(args) {
    const { start, stop, dtype, step } = args.attrs;
    const sequence = rangeImpl(start, stop, step, dtype);
    return args.backend.makeTensorInfo([sequence.length], dtype, sequence);
}
const rangeConfig$1 = {
    kernelName: Range,
    backendName: 'cpu',
    kernelFunc: range$1
};
81005
81006 /**
81007 * @license
81008 * Copyright 2020 Google LLC. All Rights Reserved.
81009 * Licensed under the Apache License, Version 2.0 (the License);
81010 * you may not use this file except in compliance with the License.
81011 * You may obtain a copy of the License at
81012 *
81013 * http://www.apache.org/licenses/LICENSE-2.0
81014 *
81015 * Unless required by applicable law or agreed to in writing, software
81016 * distributed under the License is distributed on an AS IS BASIS,
81017 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81018 * See the License for the specific language governing permissions and
81019 * limitations under the License.
81020 * =============================================================================
81021 */
// Elementwise multiplicative inverse: maps each element x to 1 / x
// (0 maps to Infinity per IEEE-754 division).
const reciprocal$1 = unaryKernelFunc$1(Reciprocal, (xi) => {
    return 1 / xi;
});
const reciprocalConfig$1 = {
    kernelName: Reciprocal,
    backendName: 'cpu',
    kernelFunc: reciprocal$1,
};
81028
81029 /**
81030 * @license
81031 * Copyright 2020 Google LLC. All Rights Reserved.
81032 * Licensed under the Apache License, Version 2.0 (the "License");
81033 * you may not use this file except in compliance with the License.
81034 * You may obtain a copy of the License at
81035 *
81036 * http://www.apache.org/licenses/LICENSE-2.0
81037 *
81038 * Unless required by applicable law or agreed to in writing, software
81039 * distributed under the License is distributed on an "AS IS" BASIS,
81040 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81041 * See the License for the specific language governing permissions and
81042 * limitations under the License.
81043 * =============================================================================
81044 */
/**
 * CPU kernel for ResizeBilinear: resamples a batch of NHWC images to
 * `size` = [newHeight, newWidth] using bilinear interpolation.
 *
 * Each output pixel is a blend of the four nearest source pixels:
 * first the two horizontal lerps (weight `colFrac`), then one vertical
 * lerp (weight `rowFrac`).
 * - `alignCorners`: scale by (n - 1) so corner pixels of input and
 *   output coincide exactly.
 * - `halfPixelCenters`: sample at pixel centers (the +0.5 / -0.5 shift).
 * Returns a float32 tensor of shape [batch, newHeight, newWidth, channels].
 */
function resizeBilinear$1(args) {
    const { inputs, backend, attrs } = args;
    const { images } = inputs;
    const { alignCorners, halfPixelCenters, size } = attrs;
    // Complex dtypes are not supported by this kernel.
    assertNotComplex$1(images, 'resizeBilinear');
    const imagesStrides = computeStrides(images.shape);
    const [newHeight, newWidth] = size;
    const [batch, oldHeight, oldWidth, numChannels] = images.shape;
    const xValues = backend.data.get(images.dataId).values;
    const result = new Float32Array(sizeFromShape([batch, newHeight, newWidth, numChannels]));
    // Effective sizes implement the alignCorners (n - 1) convention.
    const effectiveInputSize = [
        (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
        (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
    ];
    const effectiveOutputSize = [
        (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
        (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
    ];
    let outputIdx = 0;
    // Input-to-output coordinate ratio along each spatial axis.
    const effectiveRowSizeRatio = effectiveInputSize[0] / effectiveOutputSize[0];
    const effectiveColSizeRatio = effectiveInputSize[1] / effectiveOutputSize[1];
    for (let b = 0; b < batch; b++) {
        for (let r = 0; r < newHeight; r++) {
            // Fractional source row this output row maps back to.
            let sourceFracRow;
            if (halfPixelCenters) {
                sourceFracRow = effectiveRowSizeRatio * (r + 0.5) - 0.5;
            }
            else {
                sourceFracRow = effectiveRowSizeRatio * r;
            }
            // Clamp to valid row indices; rowFrac is the vertical lerp weight.
            const sourceRowFloor = Math.max(0, Math.floor(sourceFracRow));
            const rowFrac = sourceFracRow - sourceRowFloor;
            const sourceRowCeil = Math.min(oldHeight - 1, Math.ceil(sourceFracRow));
            const topRowOffset = b * imagesStrides[0] + sourceRowFloor * imagesStrides[1];
            const botRowOffset = b * imagesStrides[0] + sourceRowCeil * imagesStrides[1];
            for (let c = 0; c < newWidth; c++) {
                // Fractional source column this output column maps back to.
                let sourceFracCol;
                if (halfPixelCenters) {
                    sourceFracCol = effectiveColSizeRatio * (c + 0.5) - 0.5;
                }
                else {
                    sourceFracCol = effectiveColSizeRatio * c;
                }
                const sourceColFloor = Math.max(0, Math.floor(sourceFracCol));
                const colFrac = sourceFracCol - sourceColFloor;
                const sourceColCeil = Math.min(oldWidth - 1, Math.ceil(sourceFracCol));
                // Flat offsets of the four neighboring source pixels.
                // (Note: "Offest" spelling kept from the original locals.)
                const topLeftOffest = topRowOffset + sourceColFloor * imagesStrides[2];
                const botLeftOffset = botRowOffset + sourceColFloor * imagesStrides[2];
                const topRightOffset = topRowOffset + sourceColCeil * imagesStrides[2];
                const botRightOffest = botRowOffset + sourceColCeil * imagesStrides[2];
                for (let d = 0; d < numChannels; d++) {
                    // Begin shader.
                    // Compute the fractional index of the source.
                    const topLeft = xValues[topLeftOffest + d];
                    const bottomLeft = xValues[botLeftOffset + d];
                    const topRight = xValues[topRightOffset + d];
                    const bottomRight = xValues[botRightOffest + d];
                    // Horizontal lerps on the top and bottom rows, then the
                    // vertical lerp between the two.
                    const top = topLeft + (topRight - topLeft) * colFrac;
                    const bottom = bottomLeft + (bottomRight - bottomLeft) * colFrac;
                    const newValue = top + (bottom - top) * rowFrac;
                    result[outputIdx++] = newValue;
                }
            }
        }
    }
    return backend.makeTensorInfo([batch, newHeight, newWidth, numChannels], 'float32', result);
}
const resizeBilinearConfig$1 = {
    kernelName: ResizeBilinear,
    backendName: 'cpu',
    kernelFunc: resizeBilinear$1
};
81117
81118 /**
81119 * @license
81120 * Copyright 2020 Google LLC. All Rights Reserved.
81121 * Licensed under the Apache License, Version 2.0 (the "License");
81122 * you may not use this file except in compliance with the License.
81123 * You may obtain a copy of the License at
81124 *
81125 * http://www.apache.org/licenses/LICENSE-2.0
81126 *
81127 * Unless required by applicable law or agreed to in writing, software
81128 * distributed under the License is distributed on an "AS IS" BASIS,
81129 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81130 * See the License for the specific language governing permissions and
81131 * limitations under the License.
81132 * =============================================================================
81133 */
/**
 * CPU kernel for ResizeBilinearGrad: gradient of bilinear image resizing.
 * For every pixel of the forward output, the incoming gradient `dy` is
 * distributed back onto the (up to) four input pixels that contributed
 * to it, weighted by the same interpolation coefficients the forward
 * pass used.
 *
 * `images` is the original forward input (NHWC); `dy` is the gradient
 * w.r.t. the forward output. Returns a float32 gradient with the same
 * shape as `images`.
 */
function resizeBilinearGrad$1(args) {
    const { inputs, backend, attrs } = args;
    const { images, dy } = inputs;
    const { alignCorners } = attrs;
    assertNotComplex$1([dy, images], 'resizeBilinearGrad');
    const imagesStrides = computeStrides(images.shape);
    const [batch, xHeight, xWidth, depth] = images.shape;
    const [, yHeight, yWidth] = dy.shape;
    // Accumulator laid out exactly like `images` (uses imagesStrides below).
    const output = new Float32Array(batch * xHeight * xWidth * depth);
    // In the backwards pass, we want to find the pixels that were generated
    // for each pixel in the input image the forward pass and add the
    // corresponding coefficient from dy to the gradient (with some
    // interpolation).
    const effectiveXSize = [
        (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
        (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
    ];
    const effectiveYSize = [
        (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
        (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
    ];
    const heightScale = effectiveXSize[0] / effectiveYSize[0];
    const widthScale = effectiveXSize[1] / effectiveYSize[1];
    // Reference implementation
    // tslint:disable-next-line:max-line-length
    // https://github.com/tensorflow/tensorflow/blob/3039375c86a5bbc9610c7725dcaa95d635f87ba2/tensorflow/core/kernels/resize_bilinear_op.cc#L275
    const dyValues = backend.data.get(dy.dataId).values;
    let offset = 0;
    for (let b = 0; b < batch; b++) {
        const bOffset = b * imagesStrides[0];
        for (let r = 0; r < yHeight; r++) {
            // Fractional source row of forward-output row r, split into the
            // two bracketing input rows and the lerp weight between them.
            const dxR = r * heightScale;
            const topDxRIndex = Math.floor(dxR);
            const bottomDxRIndex = Math.min(Math.ceil(dxR), xHeight - 1);
            const topDxROffset = bOffset + topDxRIndex * imagesStrides[1];
            const bottomDxROffset = bOffset + bottomDxRIndex * imagesStrides[1];
            const dxRLerp = dxR - topDxRIndex;
            const inverseDxRLerp = 1.0 - dxRLerp;
            for (let c = 0; c < yWidth; c++) {
                const dxC = c * widthScale;
                const leftDxCIndex = Math.floor(dxC);
                const rightDxCIndex = Math.min(Math.ceil(dxC), xWidth - 1);
                const dxCLerp = dxC - leftDxCIndex;
                const inverseDxCLerp = 1.0 - dxCLerp;
                const topLeftRCOffset = topDxROffset + leftDxCIndex * imagesStrides[2];
                const topRightRCOffset = topDxROffset + rightDxCIndex * imagesStrides[2];
                const bottomLeftRCOffset = bottomDxROffset + leftDxCIndex * imagesStrides[2];
                const bottomRightRCOffset = bottomDxROffset + rightDxCIndex * imagesStrides[2];
                // Bilinear weights for the four contributing input pixels.
                const inverseDxRLerpTimesInverseDxCLerp = inverseDxRLerp * inverseDxCLerp;
                const inverseDxRLerpTimesDxCLerp = inverseDxRLerp * dxCLerp;
                const dxRLerpTimesInverseDxCLerp = dxRLerp * inverseDxCLerp;
                const dxRLerpTimesDxCLerp = dxRLerp * dxCLerp;
                for (let d = 0; d < depth; d++) {
                    const dyVal = dyValues[offset++];
                    output[topLeftRCOffset + d] +=
                        dyVal * inverseDxRLerpTimesInverseDxCLerp;
                    output[topRightRCOffset + d] += dyVal * inverseDxRLerpTimesDxCLerp;
                    output[bottomLeftRCOffset + d] += dyVal * dxRLerpTimesInverseDxCLerp;
                    output[bottomRightRCOffset + d] += dyVal * dxRLerpTimesDxCLerp;
                }
            }
        }
    }
    // BUG FIX: the gradient must have the same shape as `images`,
    // i.e. [batch, xHeight, xWidth, depth]. The previous code reported
    // [batch, xWidth, xHeight, depth], transposing height/width in the
    // shape even though the buffer is laid out via imagesStrides
    // (images.shape order) — wrong for non-square images.
    return backend.makeTensorInfo([batch, xHeight, xWidth, depth], 'float32', output);
}
const resizeBilinearGradConfig$1 = {
    kernelName: ResizeBilinearGrad,
    backendName: 'cpu',
    kernelFunc: resizeBilinearGrad$1
};
81204
81205 /**
81206 * @license
81207 * Copyright 2020 Google LLC. All Rights Reserved.
81208 * Licensed under the Apache License, Version 2.0 (the "License");
81209 * you may not use this file except in compliance with the License.
81210 * You may obtain a copy of the License at
81211 *
81212 * http://www.apache.org/licenses/LICENSE-2.0
81213 *
81214 * Unless required by applicable law or agreed to in writing, software
81215 * distributed under the License is distributed on an "AS IS" BASIS,
81216 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81217 * See the License for the specific language governing permissions and
81218 * limitations under the License.
81219 * =============================================================================
81220 */
/**
 * CPU kernel for ResizeNearestNeighbor: resamples a batch of NHWC
 * images to `size` = [newHeight, newWidth] by copying, for each output
 * pixel, its nearest source pixel.
 * - `alignCorners`: uses (n - 1) scaling and rounds to the nearest
 *   source index instead of flooring.
 * - `halfPixelCenters`: samples at pixel centers (the +0.5 offset).
 */
function resizeNearestNeighbor$1(args) {
    const { inputs, backend, attrs } = args;
    const { images } = inputs;
    const { alignCorners, halfPixelCenters, size } = attrs;
    // Complex dtypes are not supported by this kernel.
    assertNotComplex$1(images, 'resizeNearestNeighbor');
    const imagesStrides = computeStrides(images.shape);
    const [newHeight, newWidth] = size;
    const [batch, oldHeight, oldWidth, numChannels] = images.shape;
    const xValues = backend.data.get(images.dataId).values;
    // NOTE(review): the buffer is Float32Array but the tensor is reported
    // with images.dtype below — presumably callers only rely on the dtype
    // tag; confirm for int inputs.
    const output = new Float32Array(batch * newHeight * newWidth * numChannels);
    // Effective sizes implement the alignCorners (n - 1) convention.
    const effectiveInputSize = [
        (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
        (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
    ];
    const effectiveOutputSize = [
        (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
        (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
    ];
    const effectiveRowSizeRatio = effectiveInputSize[0] / effectiveOutputSize[0];
    const effectiveColSizeRatio = effectiveInputSize[1] / effectiveOutputSize[1];
    let outputOffset = 0;
    for (let b = 0; b < batch; b++) {
        const batchOffset = b * imagesStrides[0];
        for (let r = 0; r < newHeight; r++) {
            // Fractional source row; round (alignCorners) or floor to get
            // the nearest source row, clamped into [0, oldHeight - 1].
            const sourceFracRow = halfPixelCenters ?
                effectiveRowSizeRatio * (r + 0.5) :
                effectiveRowSizeRatio * r;
            let sourceNearestRow = Math.min(oldHeight - 1, alignCorners ? Math.round(sourceFracRow) : Math.floor(sourceFracRow));
            if (halfPixelCenters) {
                sourceNearestRow = Math.max(0, sourceNearestRow);
            }
            const rowOffset = batchOffset + sourceNearestRow * imagesStrides[1];
            for (let c = 0; c < newWidth; c++) {
                // Same mapping for the column axis.
                const sourceFracCol = halfPixelCenters ?
                    effectiveColSizeRatio * (c + 0.5) :
                    effectiveColSizeRatio * c;
                let sourceNearestCol = Math.min(oldWidth - 1, alignCorners ? Math.round(sourceFracCol) :
                    Math.floor(sourceFracCol));
                if (halfPixelCenters) {
                    sourceNearestCol = Math.max(0, sourceNearestCol);
                }
                const colOffset = rowOffset + sourceNearestCol * imagesStrides[2];
                for (let d = 0; d < numChannels; d++) {
                    // Begin shader.
                    // Compute the fractional index of the source.
                    const newVal = xValues[colOffset + d];
                    output[outputOffset++] = newVal;
                }
            }
        }
    }
    return backend.makeTensorInfo([batch, newHeight, newWidth, numChannels], images.dtype, output);
}
const resizeNearestNeighborConfig$1 = {
    kernelName: ResizeNearestNeighbor,
    backendName: 'cpu',
    kernelFunc: resizeNearestNeighbor$1
};
81279
81280 /**
81281 * @license
81282 * Copyright 2020 Google LLC. All Rights Reserved.
81283 * Licensed under the Apache License, Version 2.0 (the "License");
81284 * you may not use this file except in compliance with the License.
81285 * You may obtain a copy of the License at
81286 *
81287 * http://www.apache.org/licenses/LICENSE-2.0
81288 *
81289 * Unless required by applicable law or agreed to in writing, software
81290 * distributed under the License is distributed on an "AS IS" BASIS,
81291 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81292 * See the License for the specific language governing permissions and
81293 * limitations under the License.
81294 * =============================================================================
81295 */
/**
 * CPU kernel for ResizeNearestNeighborGrad: gradient of nearest-neighbor
 * image resizing. The gradient of each input pixel is the sum of the dy
 * values of every forward-output pixel whose nearest source pixel it was.
 *
 * `images` is the original forward input (NHWC); `dy` is the gradient
 * w.r.t. the forward output. Returns a gradient shaped like `images`.
 */
function resizeNearestNeighborGrad$1(args) {
    const { inputs, backend, attrs } = args;
    const { images, dy } = inputs;
    const { alignCorners } = attrs;
    assertNotComplex$1([dy, images], 'resizeNearestNeighborGrad');
    const imagesStrides = computeStrides(images.shape);
    const dyStrides = computeStrides(dy.shape);
    const [batch, xHeight, xWidth, depth] = images.shape;
    const [, yHeight, yWidth] = dy.shape;
    const output = new Float32Array(batch * xHeight * xWidth * depth);
    const dyValues = backend.data.get(dy.dataId).values;
    // In the backwards pass, we want to find the pixels that were generated
    // for each pixel in the input image the forward pass
    const effectiveXSize = [
        (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
        (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
    ];
    const effectiveYSize = [
        (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
        (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
    ];
    const heightScale = effectiveXSize[0] / effectiveYSize[0];
    const widthScale = effectiveXSize[1] / effectiveYSize[1];
    const invHeightScale = 1 / heightScale;
    const invWidthScale = 1 / widthScale;
    // This defines the size of the window of values around a particular
    // index in dy that we want to search for contributions to dx.
    const winHeight = (Math.ceil(invHeightScale) * 2) + 2;
    const winWidth = (Math.ceil(invWidthScale) * 2) + 2;
    // Loop over the output (dx) space.
    for (let b = 0; b < batch; b++) {
        const batchOffset = b * imagesStrides[0];
        // BUG FIX: dy has its own layout, so its batch offset must use
        // dyStrides[0]. The original reused `batchOffset` (built from
        // imagesStrides) for dy reads, which addresses the wrong batch
        // whenever dy's spatial size differs from the image size and
        // batch > 1.
        const dyBatchOffset = b * dyStrides[0];
        for (let r = 0; r < xHeight; r++) {
            const rowOffset = batchOffset + r * imagesStrides[1];
            // Compute bounds for where in dy we will look
            const startRLerp = Math.floor(r * invHeightScale);
            const startDyR = Math.floor(startRLerp - (winHeight / 2));
            for (let c = 0; c < xWidth; c++) {
                const colOffset = rowOffset + c * imagesStrides[2];
                // Compute bounds for where in dy we will look
                const startCLerp = Math.floor(c * invWidthScale);
                const startDyC = Math.floor(startCLerp - (winWidth / 2));
                for (let d = 0; d < depth; d++) {
                    let accum = 0;
                    // loop over dy
                    for (let dyRIndex = 0; dyRIndex < winHeight; dyRIndex++) {
                        const dyR = dyRIndex + startDyR;
                        // Guard against the window exceeding the bounds of dy
                        if (dyR < 0 || dyR >= yHeight) {
                            continue;
                        }
                        const dyROffset = dyBatchOffset + dyR * dyStrides[1];
                        const sourceFracRow = dyR * heightScale;
                        const sourceNearestRow = Math.min(xHeight - 1, alignCorners ? Math.round(sourceFracRow) :
                            Math.floor(sourceFracRow));
                        // Only dy rows whose nearest source row is r contribute.
                        if (r !== sourceNearestRow) {
                            continue;
                        }
                        for (let dyCIndex = 0; dyCIndex < winWidth; dyCIndex++) {
                            const dyC = dyCIndex + startDyC;
                            // Guard against the window exceeding the bounds of dy
                            if (dyC < 0 || dyC >= yWidth) {
                                continue;
                            }
                            const dyCOffset = dyROffset + dyC * dyStrides[2];
                            const sourceFracCol = dyC * widthScale;
                            const sourceNearestCol = Math.min(xWidth - 1, alignCorners ? Math.round(sourceFracCol) :
                                Math.floor(sourceFracCol));
                            if (c === sourceNearestCol) {
                                accum += dyValues[dyCOffset + d];
                            }
                        }
                    }
                    output[colOffset + d] = accum;
                }
            }
        }
    }
    return backend.makeTensorInfo(images.shape, images.dtype, output);
}
const resizeNearestNeighborGradConfig$1 = {
    kernelName: ResizeNearestNeighborGrad,
    backendName: 'cpu',
    kernelFunc: resizeNearestNeighborGrad$1
};
81381
81382 /**
81383 * @license
81384 * Copyright 2020 Google LLC. All Rights Reserved.
81385 * Licensed under the Apache License, Version 2.0 (the "License");
81386 * you may not use this file except in compliance with the License.
81387 * You may obtain a copy of the License at
81388 *
81389 * http://www.apache.org/licenses/LICENSE-2.0
81390 *
81391 * Unless required by applicable law or agreed to in writing, software
81392 * distributed under the License is distributed on an "AS IS" BASIS,
81393 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81394 * See the License for the specific language governing permissions and
81395 * limitations under the License.
81396 * =============================================================================
81397 */
/**
 * CPU kernel for Reverse: flips `x` along the axes named by `dims`.
 * Each output coordinate is mirrored along every reversed axis and the
 * corresponding input element is copied over.
 */
function reverse$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { dims } = attrs;
    assertNotComplex$1(x, 'reverse');
    const rank = x.shape.length;
    const axes = parseAxisParam(dims, x.shape);
    // A scalar has nothing to flip; hand it straight back.
    if (rank === 0) {
        return identity$1({ inputs: { x }, backend });
    }
    const result = new TensorBuffer(x.shape, x.dtype);
    const source = backend.bufferSync(x);
    for (let i = 0; i < result.size; i++) {
        const outLoc = result.indexToLoc(i);
        // Mirror the coordinate along every reversed axis.
        const inLoc = outLoc.slice();
        for (const axis of axes) {
            inLoc[axis] = x.shape[axis] - 1 - inLoc[axis];
        }
        result.set(source.get(...inLoc), ...outLoc);
    }
    return backend.makeTensorInfo(result.shape, result.dtype, result.values);
}
const reverseConfig$1 = {
    kernelName: Reverse,
    backendName: 'cpu',
    kernelFunc: reverse$1
};
81423
81424 /**
81425 * @license
81426 * Copyright 2020 Google LLC. All Rights Reserved.
81427 * Licensed under the Apache License, Version 2.0 (the "License");
81428 * you may not use this file except in compliance with the License.
81429 * You may obtain a copy of the License at
81430 *
81431 * http://www.apache.org/licenses/LICENSE-2.0
81432 *
81433 * Unless required by applicable law or agreed to in writing, software
81434 * distributed under the License is distributed on an "AS IS" BASIS,
81435 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81436 * See the License for the specific language governing permissions and
81437 * limitations under the License.
81438 * =============================================================================
81439 */
/**
 * CPU kernel for RotateWithOffset: rotates a batch of NHWC images by
 * `radians` around `center`. Output pixels whose rotated source falls
 * outside the image are filled with `fillValue` (scalar, or per-channel
 * array; for a 4-channel image the alpha channel gets full opacity).
 */
const rotateWithOffsetConfig$1 = {
    kernelName: RotateWithOffset,
    backendName: 'cpu',
    kernelFunc: ({ inputs, attrs, backend }) => {
        const { image } = inputs;
        const { radians, fillValue, center } = attrs;
        const cpuBackend = backend;
        const output = getTypedArrayFromDType(image.dtype, sizeFromShape(image.shape));
        const [batch, imageHeight, imageWidth, numChannels] = image.shape;
        const [centerX, centerY] = getImageCenter(center, imageHeight, imageWidth);
        const fullOpacityValue = 255;
        // Rotation matrix entries are loop-invariant; compute them once.
        const sinFactor = Math.sin(radians);
        const cosFactor = Math.cos(radians);
        const imageVals = cpuBackend.data.get(image.dataId).values;
        for (let batchIdx = 0; batchIdx < batch; batchIdx++) {
            const batchOffset = batchIdx * imageWidth * imageHeight * numChannels;
            for (let row = 0; row < imageHeight; row++) {
                const rowOffset = row * (imageWidth * numChannels);
                for (let col = 0; col < imageWidth; col++) {
                    const colOffset = col * numChannels;
                    for (let channel = 0; channel < numChannels; channel++) {
                        // NOTE(review): coords[0] holds `batch` (the batch
                        // count, not batchIdx), but only coords[1]/coords[2]
                        // are ever read, so this is dead data, not a bug.
                        const coords = [batch, row, col, channel];
                        const x = coords[2];
                        const y = coords[1];
                        // coordX/coordY are the result of rotating and translating x/y.
                        let coordX = (x - centerX) * cosFactor - (y - centerY) * sinFactor;
                        let coordY = (x - centerX) * sinFactor + (y - centerY) * cosFactor;
                        coordX = Math.round(coordX + centerX);
                        coordY = Math.round(coordY + centerY);
                        // Default to the fill value; per-channel fills use the
                        // channel's entry, with alpha forced to full opacity.
                        let outputValue = fillValue;
                        if (typeof fillValue !== 'number') {
                            if (channel === 3) {
                                outputValue = fullOpacityValue;
                            }
                            else {
                                outputValue = fillValue[channel];
                            }
                        }
                        // If the coordinate position falls within the image boundaries...
                        if (coordX >= 0 && coordX < imageWidth && coordY >= 0 &&
                            coordY < imageHeight) {
                            // set the output to the image value at the coordinate position.
                            const rotatedRowOffset = coordY * (imageWidth * numChannels);
                            const rotatedColOffset = coordX * numChannels;
                            const imageIdx = batchOffset + rotatedRowOffset + rotatedColOffset + channel;
                            outputValue = imageVals[imageIdx];
                        }
                        const outIdx = batchOffset + rowOffset + colOffset + channel;
                        output[outIdx] = outputValue;
                    }
                }
            }
        }
        const dataId = cpuBackend.write(output, image.shape, image.dtype);
        return { dataId, shape: image.shape, dtype: image.dtype };
    }
};
81497
81498 /**
81499 * @license
81500 * Copyright 2020 Google LLC. All Rights Reserved.
81501 * Licensed under the Apache License, Version 2.0 (the License);
81502 * you may not use this file except in compliance with the License.
81503 * You may obtain a copy of the License at
81504 *
81505 * http://www.apache.org/licenses/LICENSE-2.0
81506 *
81507 * Unless required by applicable law or agreed to in writing, software
81508 * distributed under the License is distributed on an AS IS BASIS,
81509 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81510 * See the License for the specific language governing permissions and
81511 * limitations under the License.
81512 * =============================================================================
81513 */
// Elementwise Round using banker's rounding: ties (x exactly halfway
// between two integers) go to the nearest even integer.
const round$1 = unaryKernelFunc$1(Round, (xi) => {
    const base = Math.floor(xi);
    const frac = xi - base;
    if (frac < 0.5) {
        return base;
    }
    if (frac > 0.5) {
        return base + 1;
    }
    // Exactly halfway: choose the even neighbor.
    return base % 2 === 0 ? base : base + 1;
});
const roundConfig$1 = {
    kernelName: Round,
    backendName: 'cpu',
    kernelFunc: round$1,
};
81537
81538 /**
81539 * @license
81540 * Copyright 2020 Google LLC. All Rights Reserved.
81541 * Licensed under the Apache License, Version 2.0 (the "License");
81542 * you may not use this file except in compliance with the License.
81543 * You may obtain a copy of the License at
81544 *
81545 * http://www.apache.org/licenses/LICENSE-2.0
81546 *
81547 * Unless required by applicable law or agreed to in writing, software
81548 * distributed under the License is distributed on an "AS IS" BASIS,
81549 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81550 * See the License for the specific language governing permissions and
81551 * limitations under the License.
81552 * =============================================================================
81553 */
/**
 * CPU kernel for ScatterNd: writes `updates` into a zero-initialized
 * tensor of `shape` at the positions given by `indices`; contributions
 * of duplicate indices are summed.
 */
function scatterNd$1(args) {
    const { inputs, backend, attrs } = args;
    const { indices, updates } = inputs;
    const { shape } = attrs;
    // Decompose the scatter into flat slice writes over the output.
    const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(updates, indices, shape);
    const indicesBuf = backend.bufferSync(indices);
    const updatesBuf = backend.bufferSync(updates);
    // ScatterNd semantics: duplicate indices accumulate.
    const sumDupeIndices = true;
    const outBuf = scatterImpl(indicesBuf, updatesBuf, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, 0 /* defaultValue */, sumDupeIndices);
    return backend.makeTensorInfo(shape, outBuf.dtype, outBuf.values);
}
const scatterNdConfig$1 = {
    kernelName: ScatterNd,
    backendName: 'cpu',
    kernelFunc: scatterNd$1
};
81570
81571 /**
81572 * @license
81573 * Copyright 2022 Google LLC. All Rights Reserved.
81574 * Licensed under the Apache License, Version 2.0 (the "License");
81575 * you may not use this file except in compliance with the License.
81576 * You may obtain a copy of the License at
81577 *
81578 * http://www.apache.org/licenses/LICENSE-2.0
81579 *
81580 * Unless required by applicable law or agreed to in writing, software
81581 * distributed under the License is distributed on an "AS IS" BASIS,
81582 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81583 * See the License for the specific language governing permissions and
81584 * limitations under the License.
81585 * =============================================================================
81586 */
/**
 * Binary search over an ascending `array`: returns the smallest index
 * whose element is >= `value` (i.e. the leftmost valid insertion point;
 * array.length when every element is smaller).
 */
function lowerBound(array, value) {
    let lo = 0;
    let hi = array.length;
    while (lo < hi) {
        const mid = (lo + hi) >>> 1;
        if (array[mid] < value) {
            lo = mid + 1;
        }
        else {
            hi = mid;
        }
    }
    return hi;
}
/**
 * Binary search over an ascending `array`: returns the smallest index
 * whose element is > `value` (i.e. the rightmost valid insertion point;
 * array.length when no element is larger).
 */
function upperBound(array, value) {
    let lo = 0;
    let hi = array.length;
    while (lo < hi) {
        const mid = (lo + hi) >>> 1;
        if (array[mid] <= value) {
            lo = mid + 1;
        }
        else {
            hi = mid;
        }
    }
    return hi;
}
/**
 * Batched searchsorted: for each of `batchSize` rows (`numInputs` sorted
 * inputs, `numValues` query values) computes the insertion index of every
 * query via lowerBound ('left') or upperBound (otherwise).
 * Returns a flat int32 array of batchSize * numValues indices.
 */
function searchSortedImpl(sortedInputs, values, batchSize, numInputs, numValues, side) {
    const output = getArrayFromDType('int32', batchSize * numValues);
    // Pick the bound function once instead of per element.
    const search = side === 'left' ? lowerBound : upperBound;
    for (let b = 0; b < batchSize; ++b) {
        const row = sortedInputs.slice(b * numInputs, (b + 1) * numInputs);
        const base = b * numValues;
        for (let i = 0; i < numValues; ++i) {
            output[base + i] = search(row, values[base + i]);
        }
    }
    return output;
}
81630
81631 /**
81632 * @license
81633 * Copyright 2022 Google LLC. All Rights Reserved.
81634 * Licensed under the Apache License, Version 2.0 (the "License");
81635 * you may not use this file except in compliance with the License.
81636 * You may obtain a copy of the License at
81637 *
81638 * http://www.apache.org/licenses/LICENSE-2.0
81639 *
81640 * Unless required by applicable law or agreed to in writing, software
81641 * distributed under the License is distributed on an "AS IS" BASIS,
81642 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81643 * See the License for the specific language governing permissions and
81644 * limitations under the License.
81645 * =============================================================================
81646 */
/**
 * CPU kernel for SearchSorted: for every value in `values`, finds its
 * insertion index into the corresponding row of `sortedSequence`
 * ('left' or 'right' semantics via `side`). Returns int32 indices
 * shaped like `values`.
 */
function searchSorted$1(args) {
    const { inputs, backend, attrs } = args;
    const { sortedSequence, values } = inputs;
    const { side } = attrs;
    const sequenceData = backend.data.get(sortedSequence.dataId).values;
    const valuesData = backend.data.get(values.dataId).values;
    const output = searchSortedImpl(sequenceData, valuesData, sortedSequence.shape[0], sortedSequence.shape[1], values.shape[1], side);
    return backend.makeTensorInfo(values.shape, 'int32', output);
}
const searchSortedConfig$1 = {
    kernelName: SearchSorted,
    backendName: 'cpu',
    kernelFunc: searchSorted$1,
};
81661
81662 /**
81663 * @license
81664 * Copyright 2020 Google LLC. All Rights Reserved.
81665 * Licensed under the Apache License, Version 2.0 (the "License");
81666 * you may not use this file except in compliance with the License.
81667 * You may obtain a copy of the License at
81668 *
81669 * http://www.apache.org/licenses/LICENSE-2.0
81670 *
81671 * Unless required by applicable law or agreed to in writing, software
81672 * distributed under the License is distributed on an "AS IS" BASIS,
81673 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81674 * See the License for the specific language governing permissions and
81675 * limitations under the License.
81676 * =============================================================================
81677 */
/**
 * CPU kernel for Select (tf.where): elementwise picks from `t` where
 * `condition` is truthy and from `e` otherwise. A rank-1 condition over
 * a higher-rank `t` broadcasts row-wise: condition[i] selects the whole
 * i-th row of t or e (`offset` = elements per row).
 */
function select$1(args) {
    const { inputs, backend } = args;
    const { condition, t, e } = inputs;
    assertNotComplex$1([condition, t, e], 'select');
    const conditionRank = condition.shape.length;
    const values = backend.data.get(condition.dataId).values;
    const tValues = backend.data.get(t.dataId).values;
    const eValues = backend.data.get(e.dataId).values;
    const resultDtype = upcastType(t.dtype, e.dtype);
    const newValues = makeZerosTypedArray(sizeFromShape(t.shape), resultDtype);
    let index = 0;
    // offset > 1 only for the rank-1-condition broadcast case; otherwise
    // the selection is purely elementwise.
    const offset = conditionRank === 0 || conditionRank > 1 || t.shape.length === 1 ?
        1 :
        sizeFromShape(t.shape.slice(1));
    for (let i = 0; i < values.length; i++) {
        for (let j = 0; j < offset; j++) {
            // BUG FIX: read t/e at the running flat position. The original
            // used tValues[i]/eValues[i], which for offset > 1 replicated a
            // single scalar across the entire selected row instead of
            // copying the row's elements.
            newValues[index] = values[i] === 1 ? tValues[index] : eValues[index];
            index++;
        }
    }
    return backend.makeTensorInfo(t.shape, resultDtype, newValues);
}
const selectConfig$1 = {
    kernelName: Select,
    backendName: 'cpu',
    kernelFunc: select$1
};
81709
81710 /**
81711 * @license
81712 * Copyright 2020 Google LLC. All Rights Reserved.
81713 * Licensed under the Apache License, Version 2.0 (the License);
81714 * you may not use this file except in compliance with the License.
81715 * You may obtain a copy of the License at
81716 *
81717 * http://www.apache.org/licenses/LICENSE-2.0
81718 *
81719 * Unless required by applicable law or agreed to in writing, software
81720 * distributed under the License is distributed on an AS IS BASIS,
81721 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81722 * See the License for the specific language governing permissions and
81723 * limitations under the License.
81724 * =============================================================================
81725 */
81726 const scaleAlpha = SELU_SCALEALPHA;
81727 const scale = SELU_SCALE;
81728 const selu$1 = unaryKernelFunc$1(Selu$1, (xi) => {
81729 if (xi >= 0) {
81730 return scale * xi;
81731 }
81732 else {
81733 return scaleAlpha * (Math.exp(xi) - 1);
81734 }
81735 });
81736 const seluConfig$1 = {
81737 kernelName: Selu$1,
81738 backendName: 'cpu',
81739 kernelFunc: selu$1,
81740 };
81741
81742 /**
81743 * @license
81744 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
     * you may not use this file except in compliance with the License.
     * You may obtain a copy of the License at
     *
     * http://www.apache.org/licenses/LICENSE-2.0
     *
     * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
81753 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81754 * See the License for the specific language governing permissions and
81755 * limitations under the License.
81756 * =============================================================================
81757 */
81758 const sign$1 = unaryKernelFunc$1(Sign, (xi) => {
81759 if (xi < 0) {
81760 return -1;
81761 }
81762 else if (xi > 0) {
81763 return 1;
81764 }
81765 else {
81766 return 0;
81767 }
81768 });
81769 const signConfig$1 = {
81770 kernelName: Sign,
81771 backendName: 'cpu',
81772 kernelFunc: sign$1,
81773 };
81774
81775 /**
81776 * @license
81777 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
     * you may not use this file except in compliance with the License.
     * You may obtain a copy of the License at
     *
     * http://www.apache.org/licenses/LICENSE-2.0
     *
     * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
81786 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81787 * See the License for the specific language governing permissions and
81788 * limitations under the License.
81789 * =============================================================================
81790 */
81791 const sin$1 = unaryKernelFunc$1(Sin, (xi) => Math.sin(xi));
81792 const sinConfig$1 = {
81793 kernelName: Sin,
81794 backendName: 'cpu',
81795 kernelFunc: sin$1,
81796 };
81797
81798 /**
81799 * @license
81800 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
     * you may not use this file except in compliance with the License.
     * You may obtain a copy of the License at
     *
     * http://www.apache.org/licenses/LICENSE-2.0
     *
     * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
81809 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81810 * See the License for the specific language governing permissions and
81811 * limitations under the License.
81812 * =============================================================================
81813 */
81814 const sinh$1 = unaryKernelFunc$1(Sinh, (xi) => Math.sinh(xi));
81815 const sinhConfig$1 = {
81816 kernelName: Sinh,
81817 backendName: 'cpu',
81818 kernelFunc: sinh$1,
81819 };
81820
81821 /**
81822 * @license
81823 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
     * you may not use this file except in compliance with the License.
     * You may obtain a copy of the License at
     *
     * http://www.apache.org/licenses/LICENSE-2.0
     *
     * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
81832 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81833 * See the License for the specific language governing permissions and
81834 * limitations under the License.
81835 * =============================================================================
81836 */
81837 // mirrors the implementation of tf.nn.softplus: https://goo.gl/vkcvwX
81838 // epsilon is the difference between 1.0 and the next representable float.
81839 // For a single precision 32 bit float this should be 2^-23, see:
81840 // https://math.byu.edu/~schow/work/IEEEFloatingPoint.htm
81841 const epsilon = 1.1920928955078125e-7;
81842 const threshold = Math.log(epsilon) + 2.0;
81843 const softplus$1 = unaryKernelFunc$1(Softplus$1, (xi) => {
81844 // Value above which exp(x) may overflow, but softplus(x) == x
81845 // is within machine epsilon.
81846 const tooLarge = xi > -threshold;
81847 // Value below which exp(x) may underflow, but softplus(x) == exp(x)
81848 // is within machine epsilon.
81849 const tooSmall = xi < threshold;
81850 const expX = Math.exp(xi);
81851 let result;
81852 if (tooSmall) {
81853 result = expX;
81854 }
81855 else if (tooLarge) {
81856 result = xi;
81857 }
81858 else {
81859 result = Math.log(1.0 + expX);
81860 }
81861 return result;
81862 });
81863 const softplusConfig$1 = {
81864 kernelName: Softplus$1,
81865 backendName: 'cpu',
81866 kernelFunc: softplus$1,
81867 };
81868
81869 /**
81870 * @license
81871 * Copyright 2020 Google LLC. All Rights Reserved.
81872 * Licensed under the Apache License, Version 2.0 (the "License");
81873 * you may not use this file except in compliance with the License.
81874 * You may obtain a copy of the License at
81875 *
81876 * http://www.apache.org/licenses/LICENSE-2.0
81877 *
81878 * Unless required by applicable law or agreed to in writing, software
81879 * distributed under the License is distributed on an "AS IS" BASIS,
81880 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81881 * See the License for the specific language governing permissions and
81882 * limitations under the License.
81883 * =============================================================================
81884 */
81885 function spaceToBatchND$1(args) {
81886 const { inputs, backend, attrs } = args;
81887 const { x } = inputs;
81888 const { blockShape, paddings } = attrs;
81889 assertNotComplex$1([x], 'spaceToBatchND');
81890 const prod = sizeFromShape(blockShape);
81891 const completePaddings = [[0, 0]];
81892 completePaddings.push(...paddings);
81893 for (let i = 1 + blockShape.length; i < x.shape.length; ++i) {
81894 completePaddings.push([0, 0]);
81895 }
81896 const paddedX = padV2Config$1.kernelFunc({
81897 inputs: { x },
81898 backend,
81899 attrs: { paddings: completePaddings, constantValue: 0 }
81900 });
81901 const reshapedPaddedShape = getReshaped(paddedX.shape, blockShape, prod, false);
81902 const permutedReshapedPaddedPermutation = getPermuted(reshapedPaddedShape.length, blockShape.length, false);
81903 const flattenShape = getReshapedPermuted(paddedX.shape, blockShape, prod, false);
81904 const reshapeInputs = { x: paddedX };
81905 const reshapeAttrs = { shape: reshapedPaddedShape };
81906 const paddedXReshaped = reshape$1({ inputs: reshapeInputs, backend, attrs: reshapeAttrs });
81907 const transposeInputs = { x: paddedXReshaped };
81908 const transposeAttrs = { perm: permutedReshapedPaddedPermutation };
81909 const paddedXT = transpose$1({ inputs: transposeInputs, backend, attrs: transposeAttrs });
81910 const resultReshapeInputs = { x: paddedXT };
81911 const resultReshapeAttrs = { shape: flattenShape };
81912 const result = reshape$1({ inputs: resultReshapeInputs, backend, attrs: resultReshapeAttrs });
81913 backend.disposeIntermediateTensorInfo(paddedX);
81914 backend.disposeIntermediateTensorInfo(paddedXReshaped);
81915 backend.disposeIntermediateTensorInfo(paddedXT);
81916 return result;
81917 }
81918 const spaceToBatchNDConfig$1 = {
81919 kernelName: SpaceToBatchND,
81920 backendName: 'cpu',
81921 kernelFunc: spaceToBatchND$1
81922 };
81923
81924 /**
81925 * @license
81926 * Copyright 2021 Google LLC. All Rights Reserved.
81927 * Licensed under the Apache License, Version 2.0 (the "License");
81928 * you may not use this file except in compliance with the License.
81929 * You may obtain a copy of the License at
81930 *
81931 * http://www.apache.org/licenses/LICENSE-2.0
81932 *
81933 * Unless required by applicable law or agreed to in writing, software
81934 * distributed under the License is distributed on an "AS IS" BASIS,
81935 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81936 * See the License for the specific language governing permissions and
81937 * limitations under the License.
81938 * =============================================================================
81939 */
    /**
     * SparseFillEmptyRows CPU kernel.
     *
     * Inputs: `indices` (2-D matrix), `values` (vector), `denseShape`
     * (vector), `defaultValue` (scalar). Delegates the actual row-filling
     * algorithm to sparseFillEmptyRowsImpl.
     *
     * Returns four tensor infos: output indices, output values, a per-row
     * boolean empty-row indicator, and the reverse index map.
     */
    function sparseFillEmptyRows$1(args) {
        const { inputs, backend } = args;
        const { indices, values, denseShape, defaultValue } = inputs;
        // Rank validation: denseShape/values are vectors, indices a matrix,
        // defaultValue a scalar.
        if (denseShape.shape.length !== 1) {
            throw new Error(`Dense shape must be a vector, saw:
        ${denseShape.shape}`);
        }
        if (indices.shape.length !== 2) {
            throw new Error(`Indices must be a matrix, saw:
        ${indices.shape}`);
        }
        if (values.shape.length !== 1) {
            throw new Error(`Values must be a vector, saw:
        ${values.shape}`);
        }
        if (defaultValue.shape.length !== 0) {
            throw new Error(`Default value must be a scalar, saw:
        ${defaultValue.shape}`);
        }
        const $indices = backend.data.get(indices.dataId).values;
        const $values = backend.data.get(values.dataId).values;
        const $denseShape = backend.data.get(denseShape.dataId).values;
        const $defaultValue = backend.data.get(defaultValue.dataId).values[0];
        const [outputIndices, outputIndicesShape, outputValues, emptyRowIndicator, reverseIndexMap] = sparseFillEmptyRowsImpl($indices, indices.shape, indices.dtype, $values, values.dtype, $denseShape, $defaultValue);
        return [
            backend.makeTensorInfo(outputIndicesShape, indices.dtype, outputIndices),
            backend.makeTensorInfo([outputIndicesShape[0]], values.dtype, outputValues),
            // Booleans are stored as Uint8 for the 'bool' dtype.
            backend.makeTensorInfo([emptyRowIndicator.length], 'bool', new Uint8Array(emptyRowIndicator.map((value) => Number(value)))),
            backend.makeTensorInfo([reverseIndexMap.length], indices.dtype, new Int32Array(reverseIndexMap)),
        ];
    }
    const sparseFillEmptyRowsConfig$1 = {
        kernelName: SparseFillEmptyRows,
        backendName: 'cpu',
        kernelFunc: sparseFillEmptyRows$1,
    };
81976
81977 /**
81978 * @license
81979 * Copyright 2021 Google LLC. All Rights Reserved.
81980 * Licensed under the Apache License, Version 2.0 (the "License");
81981 * you may not use this file except in compliance with the License.
81982 * You may obtain a copy of the License at
81983 *
81984 * http://www.apache.org/licenses/LICENSE-2.0
81985 *
81986 * Unless required by applicable law or agreed to in writing, software
81987 * distributed under the License is distributed on an "AS IS" BASIS,
81988 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
81989 * See the License for the specific language governing permissions and
81990 * limitations under the License.
81991 * =============================================================================
81992 */
    /**
     * SparseReshape CPU kernel.
     *
     * Rewrites the indices of a SparseTensor (`inputIndices`, a matrix) from
     * `inputShape` into `newShape` (both vectors); the index arithmetic lives
     * in sparseReshapeImpl.
     *
     * Returns [newIndices, resolvedOutputShape] as tensor infos.
     */
    function sparseReshape$1(args) {
        const { inputs, backend } = args;
        const { inputIndices, inputShape, newShape } = inputs;
        // Rank validation for all three inputs.
        if (inputIndices.shape.length !== 2) {
            throw new Error(`Input indices should be a matrix but received shape
        ${inputIndices.shape}`);
        }
        if (inputShape.shape.length !== 1) {
            throw new Error(`Input shape should be a vector but received shape
        ${inputShape.shape}`);
        }
        if (newShape.shape.length !== 1) {
            throw new Error(`Target shape should be a vector but received shape ${newShape.shape}`);
        }
        // Shapes are converted to plain arrays for the impl helper.
        const $inputShape = Array.from(backend.data.get(inputShape.dataId).values);
        const $inputIndices = backend.data.get(inputIndices.dataId).values;
        const targetShape = Array.from(backend.data.get(newShape.dataId).values);
        const [newIndices, indicesShape, outputShape] = sparseReshapeImpl($inputIndices, inputIndices.shape, inputIndices.dtype, $inputShape, targetShape);
        return [
            backend.makeTensorInfo(indicesShape, inputIndices.dtype, newIndices),
            backend.makeTensorInfo([outputShape.length], newShape.dtype, new Int32Array(outputShape)),
        ];
    }
    const sparseReshapeConfig$1 = {
        kernelName: SparseReshape,
        backendName: 'cpu',
        kernelFunc: sparseReshape$1,
    };
82021
82022 /**
82023 * @license
82024 * Copyright 2021 Google LLC. All Rights Reserved.
82025 * Licensed under the Apache License, Version 2.0 (the "License");
82026 * you may not use this file except in compliance with the License.
82027 * You may obtain a copy of the License at
82028 *
82029 * http://www.apache.org/licenses/LICENSE-2.0
82030 *
82031 * Unless required by applicable law or agreed to in writing, software
82032 * distributed under the License is distributed on an "AS IS" BASIS,
82033 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82034 * See the License for the specific language governing permissions and
82035 * limitations under the License.
82036 * =============================================================================
82037 */
    /**
     * SparseSegmentMean CPU kernel.
     *
     * Validates shapes, then delegates to sparseSegmentReductionImpl with its
     * final flag set to `true` (the mean variant; compare sparseSegmentSum$1,
     * which omits it).
     */
    function sparseSegmentMean$1(args) {
        const { inputs, backend } = args;
        const { data, indices, segmentIds } = inputs;
        // Rank / size validation.
        if (data.shape.length < 1) {
            throw new Error(`Data should be at least 1 dimensional but received scalar`);
        }
        if (indices.shape.length !== 1) {
            throw new Error(`Indices should be a vector but received shape
        ${indices.shape}`);
        }
        if (segmentIds.shape.length !== 1) {
            throw new Error(`Segment ids should be a vector but received shape
        ${segmentIds.shape}`);
        }
        if (indices.shape[0] !== segmentIds.shape[0]) {
            throw new Error(`segmentIds and indices should have same size.`);
        }
        const $data = backend.data.get(data.dataId).values;
        const $indices = backend.data.get(indices.dataId).values;
        const $segmentIds = backend.data.get(segmentIds.dataId).values;
        // Trailing `true` selects the mean reduction.
        const [outputData, outputDataShape] = sparseSegmentReductionImpl($data, data.shape, data.dtype, $indices, $segmentIds, true);
        return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
    }
    const sparseSegmentMeanConfig$1 = {
        kernelName: SparseSegmentMean,
        backendName: 'cpu',
        kernelFunc: sparseSegmentMean$1,
    };
82066
82067 /**
82068 * @license
82069 * Copyright 2021 Google LLC. All Rights Reserved.
82070 * Licensed under the Apache License, Version 2.0 (the "License");
82071 * you may not use this file except in compliance with the License.
82072 * You may obtain a copy of the License at
82073 *
82074 * http://www.apache.org/licenses/LICENSE-2.0
82075 *
82076 * Unless required by applicable law or agreed to in writing, software
82077 * distributed under the License is distributed on an "AS IS" BASIS,
82078 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82079 * See the License for the specific language governing permissions and
82080 * limitations under the License.
82081 * =============================================================================
82082 */
    /**
     * SparseSegmentSum CPU kernel.
     *
     * Validates shapes, then delegates to sparseSegmentReductionImpl without
     * the trailing mean flag (the sum variant; compare sparseSegmentMean$1,
     * which passes `true`).
     */
    function sparseSegmentSum$1(args) {
        const { inputs, backend } = args;
        const { data, indices, segmentIds } = inputs;
        // Rank / size validation.
        if (data.shape.length < 1) {
            throw new Error(`Data should be at least 1 dimensional but received scalar`);
        }
        if (indices.shape.length !== 1) {
            throw new Error(`Indices should be a vector but received shape
        ${indices.shape}`);
        }
        if (segmentIds.shape.length !== 1) {
            throw new Error(`Segment ids should be a vector but received shape
        ${segmentIds.shape}`);
        }
        if (indices.shape[0] !== segmentIds.shape[0]) {
            throw new Error(`segmentIds and indices should have same size.`);
        }
        const $data = backend.data.get(data.dataId).values;
        const $indices = backend.data.get(indices.dataId).values;
        const $segmentIds = backend.data.get(segmentIds.dataId).values;
        const [outputData, outputDataShape] = sparseSegmentReductionImpl($data, data.shape, data.dtype, $indices, $segmentIds);
        return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
    }
    const sparseSegmentSumConfig$1 = {
        kernelName: SparseSegmentSum,
        backendName: 'cpu',
        kernelFunc: sparseSegmentSum$1,
    };
82111
82112 /**
82113 * @license
82114 * Copyright 2020 Google LLC. All Rights Reserved.
82115 * Licensed under the Apache License, Version 2.0 (the "License");
82116 * you may not use this file except in compliance with the License.
82117 * You may obtain a copy of the License at
82118 *
82119 * http://www.apache.org/licenses/LICENSE-2.0
82120 *
82121 * Unless required by applicable law or agreed to in writing, software
82122 * distributed under the License is distributed on an "AS IS" BASIS,
82123 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82124 * See the License for the specific language governing permissions and
82125 * limitations under the License.
82126 * =============================================================================
82127 */
82128 function sparseToDense$1(args) {
82129 const { inputs, backend, attrs } = args;
82130 const { sparseIndices, sparseValues, defaultValue } = inputs;
82131 const { outputShape } = attrs;
82132 const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(sparseValues, sparseIndices, outputShape);
82133 const sumDupeIndices = false;
82134 const indicesBuf = backend.bufferSync(sparseIndices);
82135 let outBuf;
82136 switch (sparseValues.dtype) {
82137 case 'bool': {
82138 const updatesBuf = backend.bufferSync(sparseValues);
82139 const $defaultValue = Boolean(backend.data.get(defaultValue.dataId).values[0]);
82140 outBuf = scatterImpl(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices);
82141 break;
82142 }
82143 case 'float32': {
82144 const updatesBuf = backend.bufferSync(sparseValues);
82145 const $defaultValue = backend.data.get(defaultValue.dataId).values[0];
82146 outBuf = scatterImpl(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices);
82147 break;
82148 }
82149 case 'int32': {
82150 const updatesBuf = backend.bufferSync(sparseValues);
82151 const $defaultValue = backend.data.get(defaultValue.dataId).values[0];
82152 outBuf = scatterImpl(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices);
82153 break;
82154 }
82155 case 'string': {
82156 const updatesBuf = backend.bufferSync(sparseValues);
82157 const $defaultValue = decodeString(backend.data.get(defaultValue.dataId).values[0]);
82158 outBuf = scatterImpl(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices);
82159 break;
82160 }
82161 default:
82162 throw new Error(`Unsupported type ${sparseValues.dtype}`);
82163 }
82164 return backend.makeTensorInfo(outputShape, outBuf.dtype, outBuf.values);
82165 }
82166 const sparseToDenseConfig$1 = {
82167 kernelName: SparseToDense,
82168 backendName: 'cpu',
82169 kernelFunc: sparseToDense$1
82170 };
82171
82172 /**
82173 * @license
82174 * Copyright 2020 Google LLC. All Rights Reserved.
82175 * Licensed under the Apache License, Version 2.0 (the "License");
82176 * you may not use this file except in compliance with the License.
82177 * You may obtain a copy of the License at
82178 *
82179 * http://www.apache.org/licenses/LICENSE-2.0
82180 *
82181 * Unless required by applicable law or agreed to in writing, software
82182 * distributed under the License is distributed on an "AS IS" BASIS,
82183 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82184 * See the License for the specific language governing permissions and
82185 * limitations under the License.
82186 * =============================================================================
82187 */
82188 function splitV$1(args) {
82189 const { inputs, backend, attrs } = args;
82190 const { x } = inputs;
82191 const { numOrSizeSplits, axis } = attrs;
82192 const $axis = parseAxisParam(axis, x.shape)[0];
82193 const splitSizes = prepareSplitSize(x, numOrSizeSplits, $axis);
82194 const begin = new Array(x.shape.length).fill(0);
82195 const size = x.shape.slice();
82196 return splitSizes.map(s => {
82197 const sliceSize = [...size];
82198 sliceSize[$axis] = s;
82199 const sliceT = slice$1({ inputs: { x }, backend, attrs: { begin, size: sliceSize } });
82200 begin[$axis] += s;
82201 return sliceT;
82202 });
82203 }
82204 const splitVConfig$1 = {
82205 kernelName: SplitV,
82206 backendName: 'cpu',
82207 kernelFunc: splitV$1
82208 };
82209
82210 /**
82211 * @license
82212 * Copyright 2019 Google LLC. All Rights Reserved.
82213 * Licensed under the Apache License, Version 2.0 (the "License");
82214 * you may not use this file except in compliance with the License.
82215 * You may obtain a copy of the License at
82216 *
82217 * http://www.apache.org/licenses/LICENSE-2.0
82218 *
82219 * Unless required by applicable law or agreed to in writing, software
82220 * distributed under the License is distributed on an "AS IS" BASIS,
82221 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82222 * See the License for the specific language governing permissions and
82223 * limitations under the License.
82224 * =============================================================================
82225 */
82226 const squareConfig$1 = {
82227 kernelName: Square,
82228 backendName: 'cpu',
82229 kernelFunc: ({ inputs, backend }) => {
82230 const { x } = inputs;
82231 const cpuBackend = backend;
82232 assertNotComplex$1(x, 'square');
82233 const values = cpuBackend.data.get(x.dataId).values;
82234 const newValues = new Float32Array(values.length);
82235 for (let i = 0; i < values.length; ++i) {
82236 const value = values[i];
82237 newValues[i] = value * value;
82238 }
82239 const dataId = cpuBackend.write(newValues, x.shape, x.dtype);
82240 return { dataId, shape: x.shape, dtype: x.dtype };
82241 }
82242 };
82243
82244 /**
82245 * @license
82246 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
     * you may not use this file except in compliance with the License.
     * You may obtain a copy of the License at
     *
     * http://www.apache.org/licenses/LICENSE-2.0
     *
     * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
82255 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82256 * See the License for the specific language governing permissions and
82257 * limitations under the License.
82258 * =============================================================================
82259 */
82260 const step$1 = unaryKernelFunc$1(Step, (xi, attrs) => {
82261 const stepAttrs = attrs;
82262 if (isNaN(xi)) {
82263 return NaN;
82264 }
82265 else {
82266 return xi > 0 ? 1 : stepAttrs.alpha;
82267 }
82268 });
82269 const stepConfig$1 = {
82270 kernelName: Step,
82271 backendName: 'cpu',
82272 kernelFunc: step$1,
82273 };
82274
82275 /**
82276 * @license
82277 * Copyright 2020 Google LLC. All Rights Reserved.
82278 * Licensed under the Apache License, Version 2.0 (the "License");
82279 * you may not use this file except in compliance with the License.
82280 * You may obtain a copy of the License at
82281 *
82282 * http://www.apache.org/licenses/LICENSE-2.0
82283 *
82284 * Unless required by applicable law or agreed to in writing, software
82285 * distributed under the License is distributed on an "AS IS" BASIS,
82286 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82287 * See the License for the specific language governing permissions and
82288 * limitations under the License.
82289 * =============================================================================
82290 */
82291 function stridedSlice$1(args) {
82292 const { inputs, backend, attrs } = args;
82293 const { x } = inputs;
82294 const { begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask } = attrs;
82295 assertNotComplex$1(x, 'stridedSlice');
82296 const { finalShapeSparse, finalShape, isIdentity, sliceDim0, isSimpleSlice, begin: $begin, end: $end, strides: $strides } = sliceInfo(x.shape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
82297 let result;
82298 // ref:
82299 // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/strided_slice_op.cc
82300 if (isIdentity) {
82301 // Optimization #1, slice is a no-op plus reshape
82302 result = reshape$1({ inputs: { x }, backend, attrs: { shape: finalShape } });
82303 }
82304 else if (sliceDim0 || isSimpleSlice) {
82305 // Optimization #2, slice is memory contiguous (only occurs in dim 0)
82306 assert$1(x.shape.length >= 1, () => `Input must have rank at least 1, got: ${x.shape.length}`);
82307 const size = computeOutShape$2($begin, $end, $strides);
82308 // To tolerate begin[0] > end[0] (a 0-output slice), we min(begin, end).
82309 const sliced = slice$1({ inputs: { x }, backend, attrs: { begin: $begin, size } });
82310 result =
82311 reshape$1({ inputs: { x: sliced }, backend, attrs: { shape: finalShape } });
82312 backend.disposeIntermediateTensorInfo(sliced);
82313 }
82314 else {
82315 const xBuf = backend.bufferSync(x);
82316 const outBuf = stridedSliceImpl(finalShapeSparse, xBuf, $strides, $begin);
82317 result = backend.makeTensorInfo(finalShape, outBuf.dtype, outBuf.values);
82318 }
82319 return result;
82320 }
82321 const stridedSliceConfig$1 = {
82322 kernelName: StridedSlice,
82323 backendName: 'cpu',
82324 kernelFunc: stridedSlice$1
82325 };
82326
82327 /**
82328 * @license
82329 * Copyright 2021 Google LLC. All Rights Reserved.
82330 * Licensed under the Apache License, Version 2.0 (the "License");
82331 * you may not use this file except in compliance with the License.
82332 * You may obtain a copy of the License at
82333 *
82334 * http://www.apache.org/licenses/LICENSE-2.0
82335 *
82336 * Unless required by applicable law or agreed to in writing, software
82337 * distributed under the License is distributed on an "AS IS" BASIS,
82338 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82339 * See the License for the specific language governing permissions and
82340 * limitations under the License.
82341 * =============================================================================
82342 */
82343 function stringNGrams$1(args) {
82344 const { inputs, backend, attrs } = args;
82345 const { separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences } = attrs;
82346 const { data, dataSplits } = inputs;
82347 const $data = backend.data.get(data.dataId).values;
82348 const $dataSplits = backend.data.get(dataSplits.dataId).values;
82349 const [nGrams, nGramsSplits] = stringNGramsImpl($data, $dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences);
82350 return [
82351 backend.makeTensorInfo([nGrams.length], 'string', nGrams),
82352 backend.makeTensorInfo(dataSplits.shape, 'int32', nGramsSplits),
82353 ];
82354 }
82355 const stringNGramsConfig$1 = {
82356 kernelName: StringNGrams,
82357 backendName: 'cpu',
82358 kernelFunc: stringNGrams$1,
82359 };
82360
82361 /**
82362 * @license
82363 * Copyright 2021 Google LLC. All Rights Reserved.
82364 * Licensed under the Apache License, Version 2.0 (the "License");
82365 * you may not use this file except in compliance with the License.
82366 * You may obtain a copy of the License at
82367 *
82368 * http://www.apache.org/licenses/LICENSE-2.0
82369 *
82370 * Unless required by applicable law or agreed to in writing, software
82371 * distributed under the License is distributed on an "AS IS" BASIS,
82372 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82373 * See the License for the specific language governing permissions and
82374 * limitations under the License.
82375 * =============================================================================
82376 */
82377 function stringSplit$1(args) {
82378 const { inputs, backend, attrs } = args;
82379 const { skipEmpty } = attrs;
82380 const { input, delimiter } = inputs;
82381 if (input.dtype !== 'string') {
82382 throw new Error('Input must be of datatype string');
82383 }
82384 if (input.shape.length !== 1) {
82385 throw new Error(`Input must be a vector, got shape: ${input.shape}`);
82386 }
82387 if (delimiter.shape.length !== 0) {
82388 throw new Error(`Delimiter must be a scalar, got shape: ${delimiter.shape}`);
82389 }
82390 const $input = backend.data.get(input.dataId).values;
82391 const $delimiter = backend.data.get(delimiter.dataId).values[0];
82392 const [indices, values, shape] = stringSplitImpl($input, $delimiter, skipEmpty);
82393 const outputSize = values.length;
82394 return [
82395 backend.makeTensorInfo([outputSize, 2], 'int32', indices),
82396 backend.makeTensorInfo([outputSize], 'string', values),
82397 backend.makeTensorInfo([2], 'int32', new Int32Array(shape))
82398 ];
82399 }
82400 const stringSplitConfig$1 = {
82401 kernelName: StringSplit,
82402 backendName: 'cpu',
82403 kernelFunc: stringSplit$1,
82404 };
82405
82406 /**
82407 * @license
82408 * Copyright 2021 Google LLC. All Rights Reserved.
82409 * Licensed under the Apache License, Version 2.0 (the "License");
82410 * you may not use this file except in compliance with the License.
82411 * You may obtain a copy of the License at
82412 *
82413 * http://www.apache.org/licenses/LICENSE-2.0
82414 *
82415 * Unless required by applicable law or agreed to in writing, software
82416 * distributed under the License is distributed on an "AS IS" BASIS,
82417 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82418 * See the License for the specific language governing permissions and
82419 * limitations under the License.
82420 * =============================================================================
82421 */
82422 function stringToHashBucketFast$1(args) {
82423 const { inputs, backend, attrs } = args;
82424 const { numBuckets } = attrs;
82425 const { input } = inputs;
82426 if (input.dtype !== 'string') {
82427 throw new Error('Input must be of datatype string');
82428 }
82429 if (numBuckets <= 0) {
82430 throw new Error(`Number of buckets must be at least 1`);
82431 }
82432 const $input = backend.data.get(input.dataId).values;
82433 const output = stringToHashBucketFastImpl($input, numBuckets);
82434 return backend.makeTensorInfo(input.shape, 'int32', output);
82435 }
82436 const stringToHashBucketFastConfig$1 = {
82437 kernelName: StringToHashBucketFast,
82438 backendName: 'cpu',
82439 kernelFunc: stringToHashBucketFast$1,
82440 };
82441
82442 /**
82443 * @license
82444 * Copyright 2020 Google LLC. All Rights Reserved.
82445 * Licensed under the Apache License, Version 2.0 (the License);
82446 * you may not use this file except in compliance with the License.
82447 * You may obtain a copy of the License at
82448 *
82449 * http://www.apache.org/licenses/LICENSE-2.0
82450 *
82451 * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
82453 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82454 * See the License for the specific language governing permissions and
82455 * limitations under the License.
82456 * =============================================================================
82457 */
    // CPU kernel for Tan: element-wise Math.tan built from the shared unary
    // kernel helper, plus its registry entry.
    const tan$1 = unaryKernelFunc$1(Tan, (xi) => Math.tan(xi));
    const tanConfig$1 = {
        kernelName: Tan,
        backendName: 'cpu',
        kernelFunc: tan$1,
    };
82464
82465 /**
82466 * @license
82467 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
82469 * you may not use this file except in compliance with the License.
82470 * You may obtain a copy of the License at
82471 *
82472 * http://www.apache.org/licenses/LICENSE-2.0
82473 *
82474 * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
82476 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82477 * See the License for the specific language governing permissions and
82478 * limitations under the License.
82479 * =============================================================================
82480 */
    // CPU kernel for Tanh: element-wise Math.tanh built from the shared unary
    // kernel helper, plus its registry entry.
    const tanh$1 = unaryKernelFunc$1(Tanh$1, (xi) => Math.tanh(xi));
    const tanhConfig$1 = {
        kernelName: Tanh$1,
        backendName: 'cpu',
        kernelFunc: tanh$1,
    };
82487
82488 /**
82489 * @license
82490 * Copyright 2022 Google LLC. All Rights Reserved.
82491 * Licensed under the Apache License, Version 2.0 (the "License");
82492 * you may not use this file except in compliance with the License.
82493 * You may obtain a copy of the License at
82494 *
82495 * http://www.apache.org/licenses/LICENSE-2.0
82496 *
82497 * Unless required by applicable law or agreed to in writing, software
82498 * distributed under the License is distributed on an "AS IS" BASIS,
82499 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82500 * See the License for the specific language governing permissions and
82501 * limitations under the License.
82502 * =============================================================================
82503 */
82504 function tensorScatterUpdate$1(args) {
82505 const { inputs, backend } = args;
82506 const { tensor, indices, updates } = inputs;
82507 const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(updates, indices, tensor.shape);
82508 const sumDupeIndices = false;
82509 const indicesBuf = backend.bufferSync(indices);
82510 const updatesBuf = backend.bufferSync(updates);
82511 const tensorBuf = backend.bufferSync(tensor);
82512 const outBuf = scatterImpl(indicesBuf, updatesBuf, tensor.shape, outputSize, sliceSize, numUpdates, sliceRank, strides, tensorBuf, sumDupeIndices);
82513 return backend.makeTensorInfo(tensor.shape, outBuf.dtype, outBuf.values);
82514 }
    // Registry entry binding TensorScatterUpdate to its CPU implementation.
    const tensorScatterUpdateConfig$1 = {
        kernelName: TensorScatterUpdate,
        backendName: 'cpu',
        kernelFunc: tensorScatterUpdate$1
    };
82520
82521 /**
82522 * @license
82523 * Copyright 2020 Google LLC. All Rights Reserved.
82524 * Licensed under the Apache License, Version 2.0 (the "License");
82525 * you may not use this file except in compliance with the License.
82526 * You may obtain a copy of the License at
82527 *
82528 * http://www.apache.org/licenses/LICENSE-2.0
82529 *
82530 * Unless required by applicable law or agreed to in writing, software
82531 * distributed under the License is distributed on an "AS IS" BASIS,
82532 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82533 * See the License for the specific language governing permissions and
82534 * limitations under the License.
82535 * =============================================================================
82536 */
82537 function tile$1(args) {
82538 const { inputs, backend, attrs } = args;
82539 const { x } = inputs;
82540 const { reps } = attrs;
82541 assertNotComplex$1(x, 'tile');
82542 const outBuf = tileImpl(backend.bufferSync(x), reps);
82543 return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
82544 }
    // Registry entry binding Tile to its CPU implementation.
    const tileConfig$1 = {
        kernelName: Tile,
        backendName: 'cpu',
        kernelFunc: tile$1
    };
82550
82551 /**
82552 * @license
82553 * Copyright 2020 Google LLC. All Rights Reserved.
82554 * Licensed under the Apache License, Version 2.0 (the "License");
82555 * you may not use this file except in compliance with the License.
82556 * You may obtain a copy of the License at
82557 *
82558 * http://www.apache.org/licenses/LICENSE-2.0
82559 *
82560 * Unless required by applicable law or agreed to in writing, software
82561 * distributed under the License is distributed on an "AS IS" BASIS,
82562 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82563 * See the License for the specific language governing permissions and
82564 * limitations under the License.
82565 * =============================================================================
82566 */
82567 function topK$1(args) {
82568 const { inputs, backend, attrs } = args;
82569 const { x } = inputs;
82570 const { k, sorted } = attrs;
82571 assertNotComplex$1(x, 'topk');
82572 const xVals = backend.data.get(x.dataId).values;
82573 const [allTopKVals, allTopKIndices] = topKImpl(xVals, x.shape, x.dtype, k, sorted);
82574 return [
82575 backend.makeTensorInfo(allTopKVals.shape, allTopKVals.dtype, allTopKVals.values),
82576 backend.makeTensorInfo(allTopKIndices.shape, allTopKIndices.dtype, allTopKIndices.values)
82577 ];
82578 }
    // Registry entry binding TopK to its CPU implementation.
    const topKConfig$1 = {
        kernelName: TopK,
        backendName: 'cpu',
        kernelFunc: topK$1
    };
82584
82585 /**
82586 * @license
82587 * Copyright 2021 Google LLC. All Rights Reserved.
82588 * Licensed under the Apache License, Version 2.0 (the "License");
82589 * you may not use this file except in compliance with the License.
82590 * You may obtain a copy of the License at
82591 *
82592 * http://www.apache.org/licenses/LICENSE-2.0
82593 *
82594 * Unless required by applicable law or agreed to in writing, software
82595 * distributed under the License is distributed on an "AS IS" BASIS,
82596 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82597 * See the License for the specific language governing permissions and
82598 * limitations under the License.
82599 * =============================================================================
82600 */
82601 function transform$1(args) {
82602 const { inputs, attrs, backend } = args;
82603 const { image, transforms } = inputs;
82604 const { interpolation, fillMode, fillValue, outputShape } = attrs;
82605 const [batch, imageHeight, imageWidth, numChannels] = image.shape;
82606 const [outHeight, outWidth] = outputShape != null ? outputShape : [imageHeight, imageWidth];
82607 const outShape = [batch, outHeight, outWidth, numChannels];
82608 const inStrides = computeStrides(image.shape);
82609 const batchInStride = inStrides[0];
82610 const rowInStride = inStrides[1];
82611 const colInStride = inStrides[2];
82612 const outStrides = computeStrides(outShape);
82613 const batchOutStride = outStrides[0];
82614 const rowOutStride = outStrides[1];
82615 const colOutStride = outStrides[2];
82616 const outVals = getTypedArrayFromDType(image.dtype, sizeFromShape(outShape));
82617 outVals.fill(fillValue);
82618 const imageVals = backend.data.get(image.dataId).values;
82619 const transformVals = backend.data.get(transforms.dataId).values;
82620 // Ref TF implementation:
82621 // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/image/image_ops.h
82622 for (let b = 0; b < batch; ++b) {
82623 const transform = transforms.shape[0] === 1 ?
82624 transformVals :
82625 transformVals.subarray(b * 8, b * 8 + 8);
82626 for (let outY = 0; outY < outHeight; ++outY) {
82627 for (let outX = 0; outX < outWidth; ++outX) {
82628 for (let channel = 0; channel < numChannels; ++channel) {
82629 let val;
82630 const projection = transform[6] * outX + transform[7] * outY + 1;
82631 if (projection === 0) {
82632 // Return the fill value for infinite coordinates,
82633 // which are outside the input image
82634 continue;
82635 }
82636 const inX = (transform[0] * outX + transform[1] * outY + transform[2]) /
82637 projection;
82638 const inY = (transform[3] * outX + transform[4] * outY + transform[5]) /
82639 projection;
82640 const x = mapCoord(inX, imageWidth, fillMode);
82641 const y = mapCoord(inY, imageHeight, fillMode);
82642 switch (interpolation) {
82643 case 'nearest':
82644 val = nearestInterpolation(imageVals, imageHeight, imageWidth, batchInStride, rowInStride, colInStride, b, y, x, channel, fillValue);
82645 break;
82646 case 'bilinear':
82647 val = bilinearInterpolation(imageVals, imageHeight, imageWidth, batchInStride, rowInStride, colInStride, b, y, x, channel, fillValue);
82648 break;
82649 default:
82650 throw new Error(`Error in Transform: Expect 'nearest' or ` +
82651 `'bilinear', but got ${interpolation}`);
82652 }
82653 const ind = b * batchOutStride + outY * rowOutStride +
82654 outX * colOutStride + channel;
82655 outVals[ind] = val;
82656 }
82657 }
82658 }
82659 return backend.makeTensorInfo(outShape, image.dtype, outVals);
82660 }
82661 const dataId = backend.write(outVals, outShape, image.dtype);
82662 return { dataId, shape: image.shape, dtype: image.dtype };
82663 }
    // Registry entry binding Transform to its CPU implementation.
    const transformConfig$1 = {
        kernelName: Transform,
        backendName: 'cpu',
        kernelFunc: transform$1
    };
82669 function mapCoord(outCoord, len, mode) {
82670 switch (mode) {
82671 case 'reflect':
82672 return mapCoordReflect(outCoord, len);
82673 case 'wrap':
82674 return mapCoordWrap(outCoord, len);
82675 case 'nearest':
82676 return mapCoordNearest(outCoord, len);
82677 case 'constant':
82678 default:
82679 return mapCoordConstant(outCoord, len);
82680 }
82681 }
    // 'reflect' fill mode: mirror the coordinate back into [0, len - 1], i.e.
    // reflect [abcd] to [dcba|abcd|dcba]. Follows the TF reference kernel
    // (tensorflow/core/kernels/image/image_ops.h).
    function mapCoordReflect(outCoord, len) {
        // Reflect [abcd] to [dcba|abcd|dcba].
        let inCoord = outCoord;
        if (inCoord < 0) {
            if (len <= 1) {
                // Degenerate axis: everything maps to the only element.
                inCoord = 0;
            }
            else {
                const sz2 = 2 * len;
                if (inCoord < sz2) {
                    // Shift up by whole periods of 2*len into (-2*len, 0].
                    inCoord = sz2 * Math.trunc(-inCoord / sz2) + inCoord;
                }
                inCoord = inCoord < -len ? inCoord + sz2 : -inCoord - 1;
            }
        }
        else if (inCoord > len - 1) {
            if (len <= 1) {
                inCoord = 0;
            }
            else {
                const sz2 = 2 * len;
                // Reduce modulo the 2*len period, then mirror the upper half.
                inCoord -= sz2 * Math.trunc(inCoord / sz2);
                if (inCoord >= len) {
                    inCoord = sz2 - inCoord - 1;
                }
            }
        }
        // clamp is necessary because when outCoord = 3.5 and len = 4,
        // inCoord = 3.5 and will be rounded to 4 in nearest interpolation.
        return clamp(0, inCoord, len - 1);
    }
    // 'wrap' fill mode: tile the coordinate into [0, len - 1], i.e. wrap
    // [abcd] to [abcd|abcd|abcd]. Follows the TF reference kernel
    // (tensorflow/core/kernels/image/image_ops.h), including its use of
    // sz = len - 1 as the period divisor.
    function mapCoordWrap(outCoord, len) {
        // Wrap [abcd] to [abcd|abcd|abcd].
        let inCoord = outCoord;
        if (inCoord < 0) {
            if (len <= 1) {
                // Degenerate axis: everything maps to the only element.
                inCoord = 0;
            }
            else {
                const sz = len - 1;
                inCoord += len * (Math.trunc(-inCoord / sz) + 1);
            }
        }
        else if (inCoord > len - 1) {
            if (len <= 1) {
                inCoord = 0;
            }
            else {
                const sz = len - 1;
                inCoord -= len * Math.trunc(inCoord / sz);
            }
        }
        // clamp is necessary because when outCoord = -0.5 and len = 4,
        // inCoord = 3.5 and will be rounded to 4 in nearest interpolation.
        return clamp(0, inCoord, len - 1);
    }
    // 'constant' fill mode: the coordinate passes through unchanged;
    // out-of-range reads later resolve to the fill value in
    // readWithFillValue. `len` is unused but kept for a uniform signature.
    function mapCoordConstant(outCoord, len) {
        return outCoord;
    }
    // 'nearest' fill mode: clamp the coordinate to the nearest valid edge
    // position in [0, len - 1].
    function mapCoordNearest(outCoord, len) {
        return clamp(0, outCoord, len - 1);
    }
82744 function readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, y, x, channel, fillValue) {
82745 const ind = batch * batchStride + y * rowStride + x * colStride + channel;
82746 if (0 <= y && y < imageHeight && 0 <= x && x < imageWidth) {
82747 return imageVals[ind];
82748 }
82749 else {
82750 return fillValue;
82751 }
82752 }
82753 function nearestInterpolation(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, y, x, channel, fillValue) {
82754 const $y = Math.round(y);
82755 const $x = Math.round(x);
82756 return readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, $y, $x, channel, fillValue);
82757 }
82758 function bilinearInterpolation(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, y, x, channel, fillValue) {
82759 const yFloor = Math.floor(y);
82760 const xFloor = Math.floor(x);
82761 const yCeil = yFloor + 1;
82762 const xCeil = xFloor + 1;
82763 // f(x, yFloor) = (xCeil - x) / (xCeil - xFloor) * f(xFloor, yFloor)
82764 // + (x - xFloor) / (xCeil - xFloor) * f(xCeil, yFloor)
82765 const valueYFloor = (xCeil - x) *
82766 readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yFloor, xFloor, channel, fillValue) +
82767 (x - xFloor) *
82768 readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yFloor, xCeil, channel, fillValue);
82769 // f(x, yCeil) = (xCeil - x) / (xCeil - xFloor) * f(xFloor, yCeil)
82770 // + (x - xFloor) / (xCeil - xFloor) * f(xCeil, yCeil)
82771 const valueYCeil = (xCeil - x) *
82772 readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yCeil, xFloor, channel, fillValue) +
82773 (x - xFloor) *
82774 readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yCeil, xCeil, channel, fillValue);
82775 // f(x, y) = (yCeil - y) / (yCeil - yFloor) * f(x, yFloor)
82776 // + (y - yFloor) / (yCeil - yFloor) * f(x, yCeil)
82777 return (yCeil - y) * valueYFloor + (y - yFloor) * valueYCeil;
82778 }
82779
82780 /**
82781 * @license
82782 * Copyright 2020 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
82784 * you may not use this file except in compliance with the License.
82785 * You may obtain a copy of the License at
82786 *
82787 * http://www.apache.org/licenses/LICENSE-2.0
82788 *
82789 * Unless required by applicable law or agreed to in writing, software
     * distributed under the License is distributed on an "AS IS" BASIS,
82791 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82792 * See the License for the specific language governing permissions and
82793 * limitations under the License.
82794 * =============================================================================
82795 */
82796 function unique$1(args) {
82797 const { inputs, attrs, backend } = args;
82798 const { axis } = attrs;
82799 const { x } = inputs;
82800 assertNotComplex$1(x, 'unique');
82801 const values = backend.data.get(x.dataId).values;
82802 const { outputValues, outputShape, indices } = uniqueImpl(values, axis, x.shape, x.dtype);
82803 return [
82804 backend.makeTensorInfo(outputShape, x.dtype, outputValues),
82805 backend.makeTensorInfo([indices.length], 'int32', indices),
82806 ];
82807 }
    // Registry entry binding Unique to its CPU implementation.
    const uniqueConfig$1 = {
        kernelName: Unique,
        backendName: 'cpu',
        kernelFunc: unique$1,
    };
82813
82814 /**
82815 * @license
82816 * Copyright 2020 Google LLC. All Rights Reserved.
82817 * Licensed under the Apache License, Version 2.0 (the "License");
82818 * you may not use this file except in compliance with the License.
82819 * You may obtain a copy of the License at
82820 *
82821 * http://www.apache.org/licenses/LICENSE-2.0
82822 *
82823 * Unless required by applicable law or agreed to in writing, software
82824 * distributed under the License is distributed on an "AS IS" BASIS,
82825 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82826 * See the License for the specific language governing permissions and
82827 * limitations under the License.
82828 * =============================================================================
82829 */
82830 function unpack$1(args) {
82831 const { inputs, backend, attrs } = args;
82832 const { value } = inputs;
82833 let { axis } = attrs;
82834 if (axis < 0) {
82835 axis += value.shape.length;
82836 }
82837 const valueRank = value.shape.length;
82838 const num = value.shape[axis];
82839 const outShape = new Array(valueRank - 1);
82840 let outIndex = 0;
82841 for (let i = 0; i < valueRank; i++) {
82842 if (i !== axis) {
82843 outShape[outIndex++] = value.shape[i];
82844 }
82845 }
82846 const begin = new Array(valueRank).fill(0);
82847 const size = value.shape.slice();
82848 size[axis] = 1;
82849 const res = new Array(num);
82850 for (let i = 0; i < res.length; i++) {
82851 begin[axis] = i;
82852 const tempRes = slice$1({ inputs: { x: value }, backend, attrs: { begin, size } });
82853 res[i] = reshape$1({ inputs: { x: tempRes }, backend, attrs: { shape: outShape } });
82854 backend.disposeIntermediateTensorInfo(tempRes);
82855 }
82856 return res;
82857 }
    // Registry entry binding Unpack to its CPU implementation.
    const unpackConfig$1 = {
        kernelName: Unpack,
        backendName: 'cpu',
        kernelFunc: unpack$1
    };
82863
82864 /**
82865 * @license
82866 * Copyright 2020 Google LLC. All Rights Reserved.
82867 * Licensed under the Apache License, Version 2.0 (the "License");
82868 * you may not use this file except in compliance with the License.
82869 * You may obtain a copy of the License at
82870 *
82871 * http://www.apache.org/licenses/LICENSE-2.0
82872 *
82873 * Unless required by applicable law or agreed to in writing, software
82874 * distributed under the License is distributed on an "AS IS" BASIS,
82875 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82876 * See the License for the specific language governing permissions and
82877 * limitations under the License.
82878 * =============================================================================
82879 */
82880 function unsortedSegmentSum$1(args) {
82881 const { inputs, backend, attrs } = args;
82882 const { x, segmentIds } = inputs;
82883 const { numSegments } = attrs;
82884 assertNotComplex$1(x, 'unsortedSegmentSum');
82885 const xRank = x.shape.length;
82886 const segmentIdsRank = segmentIds.shape.length;
82887 const res = [];
82888 const intermediates = [];
82889 // Reshape the segment id's so that they can be broadcast with
82890 // x. The new shape should be [segmentIds.shape, 1, ..., 1]
82891 const numIters = xRank - segmentIdsRank;
82892 let $segmentIds = segmentIds;
82893 for (let i = 0; i < numIters; ++i) {
82894 const expanded = expandDims$1({ inputs: { input: $segmentIds }, backend, attrs: { dim: i + 1 } });
82895 $segmentIds = expanded;
82896 intermediates.push(expanded);
82897 }
82898 for (let i = 0; i < numSegments; ++i) {
82899 const scalarValue = createScalarValue(i, 'int32');
82900 const segmentId = backend.makeTensorInfo([], 'int32', scalarValue);
82901 const mask = equal$1({ inputs: { a: segmentId, b: $segmentIds }, backend });
82902 const maskCasted = cast$1({ inputs: { x: mask }, backend, attrs: { dtype: 'float32' } });
82903 const mul = multiply$1({ inputs: { a: maskCasted, b: x }, backend });
82904 const sumTensorInfo = sum$1({ inputs: { x: mul }, backend, attrs: { axis: 0, keepDims: false } });
82905 res.push(sumTensorInfo);
82906 intermediates.push(segmentId);
82907 intermediates.push(mask);
82908 intermediates.push(maskCasted);
82909 intermediates.push(mul);
82910 intermediates.push(sumTensorInfo);
82911 }
82912 const result = pack$1({ inputs: res, backend, attrs: { axis: 0 } });
82913 intermediates.forEach(t => backend.disposeIntermediateTensorInfo(t));
82914 return result;
82915 }
    // Registry entry binding UnsortedSegmentSum to its CPU implementation.
    const unsortedSegmentSumConfig$1 = {
        kernelName: UnsortedSegmentSum,
        backendName: 'cpu',
        kernelFunc: unsortedSegmentSum$1
    };
82921
82922 /**
82923 * @license
82924 * Copyright 2020 Google LLC. All Rights Reserved.
82925 * Licensed under the Apache License, Version 2.0 (the "License");
82926 * you may not use this file except in compliance with the License.
82927 * You may obtain a copy of the License at
82928 *
82929 * http://www.apache.org/licenses/LICENSE-2.0
82930 *
82931 * Unless required by applicable law or agreed to in writing, software
82932 * distributed under the License is distributed on an "AS IS" BASIS,
82933 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
82934 * See the License for the specific language governing permissions and
82935 * limitations under the License.
82936 * =============================================================================
82937 */
    // All CPU kernel configs bundled in this build; each entry below is
    // registered with the global kernel registry by the loop that follows.
    const kernelConfigs$1 = [
        _fusedMatMulConfig$1,
        absConfig$1,
        acosConfig$1,
        acoshConfig$1,
        addConfig$1,
        addNConfig$1,
        allConfig$1,
        anyConfig$1,
        argMaxConfig$1,
        argMinConfig$1,
        asinConfig$1,
        asinhConfig$1,
        atanConfig$1,
        atan2Config$1,
        atanhConfig$1,
        avgPoolConfig$1,
        avgPool3DConfig$1,
        avgPool3DGradConfig$1,
        avgPoolGradConfig$1,
        batchMatMulConfig$1,
        batchNormConfig$1,
        batchToSpaceNDConfig$1,
        bincountConfig$1,
        bitwiseAndConfig$1,
        broadcastArgsConfig$1,
        castConfig$1,
        ceilConfig$1,
        clipByValueConfig$1,
        complexConfig$1,
        complexAbsConfig$1,
        concatConfig$1,
        conv2DConfig$1,
        conv2DBackpropFilterConfig$1,
        conv2DBackpropInputConfig$1,
        conv3DConfig$1,
        conv3DBackpropFilterV2Config$1,
        conv3DBackpropInputV2Config,
        cosConfig$1,
        coshConfig$1,
        cropAndResizeConfig$1,
        cumprodConfig$1,
        cumsumConfig$1,
        denseBincountConfig$1,
        depthToSpaceConfig$1,
        depthwiseConv2dNativeConfig$1,
        depthwiseConv2dNativeBackpropFilterConfig$1,
        depthwiseConv2dNativeBackpropInputConfig$1,
        diagConfig$1,
        dilation2DConfig$1,
        dilation2DBackpropFilterConfig,
        dilation2DBackpropInputConfig,
        drawConfig,
        einsumConfig$1,
        eluConfig$1,
        eluGradConfig$1,
        equalConfig$1,
        erfConfig$1,
        expConfig$1,
        expandDimsConfig$1,
        expm1Config$1,
        fftConfig$1,
        fillConfig$1,
        flipLeftRightConfig$1,
        floorConfig$1,
        floorDivConfig$1,
        fusedConv2DConfig$1,
        fusedDepthwiseConv2DConfig$1,
        gatherNdConfig$1,
        gatherV2Config$1,
        greaterConfig$1,
        greaterEqualConfig$1,
        identityConfig$1,
        ifftConfig$1,
        imagConfig$1,
        isFiniteConfig$1,
        isInfConfig$1,
        isNaNConfig$1,
        leakyReluConfig$1,
        lessConfig$1,
        lessEqualConfig$1,
        linSpaceConfig$1,
        logConfig$1,
        log1pConfig$1,
        logicalAndConfig$1,
        logicalNotConfig$1,
        logicalOrConfig$1,
        LRNConfig$1,
        LRNGradConfig$1,
        maxConfig$1,
        maximumConfig$1,
        maxPoolConfig$1,
        maxPool3DConfig$1,
        maxPool3DGradConfig$1,
        maxPoolGradConfig$1,
        maxPoolWithArgmaxConfig$1,
        meanConfig$1,
        minConfig$1,
        minimumConfig$1,
        mirrorPadConfig$1,
        modConfig$1,
        multinomialConfig$1,
        multiplyConfig$1,
        negConfig$1,
        nonMaxSuppressionV3Config$1,
        nonMaxSuppressionV4Config$1,
        nonMaxSuppressionV5Config$1,
        notEqualConfig$1,
        oneHotConfig$1,
        onesLikeConfig$1,
        packConfig$1,
        padV2Config$1,
        powConfig$1,
        preluConfig$1,
        prodConfig$1,
        raggedGatherConfig$1,
        raggedRangeConfig$1,
        raggedTensorToTensorConfig$1,
        rangeConfig$1,
        realConfig$1,
        realDivConfig$1,
        reciprocalConfig$1,
        reluConfig$1,
        relu6Config$1,
        reshapeConfig$1,
        resizeBilinearConfig$1,
        resizeBilinearGradConfig$1,
        resizeNearestNeighborConfig$1,
        resizeNearestNeighborGradConfig$1,
        reverseConfig$1,
        rotateWithOffsetConfig$1,
        roundConfig$1,
        rsqrtConfig$1,
        scatterNdConfig$1,
        searchSortedConfig$1,
        selectConfig$1,
        seluConfig$1,
        sigmoidConfig$1,
        signConfig$1,
        sinConfig$1,
        sinhConfig$1,
        sliceConfig$1,
        softmaxConfig$1,
        softplusConfig$1,
        spaceToBatchNDConfig$1,
        sparseFillEmptyRowsConfig$1,
        sparseReshapeConfig$1,
        sparseSegmentMeanConfig$1,
        sparseSegmentSumConfig$1,
        sparseToDenseConfig$1,
        splitVConfig$1,
        sqrtConfig$1,
        squareConfig$1,
        squaredDifferenceConfig$1,
        staticRegexReplaceConfig$1,
        stepConfig$1,
        stridedSliceConfig$1,
        stringNGramsConfig$1,
        stringSplitConfig$1,
        stringToHashBucketFastConfig$1,
        subConfig$1,
        sumConfig$1,
        tanConfig$1,
        tanhConfig$1,
        tensorScatterUpdateConfig$1,
        tileConfig$1,
        topKConfig$1,
        transformConfig$1,
        transposeConfig$1,
        uniqueConfig$1,
        unpackConfig$1,
        unsortedSegmentSumConfig$1,
        zerosLikeConfig$1
    ];
    // Register every CPU kernel config with the global kernel registry.
    for (const kernelConfig of kernelConfigs$1) {
        registerKernel(kernelConfig);
    }
83116
83117 /**
83118 * @license
83119 * Copyright 2020 Google LLC. All Rights Reserved.
83120 * Licensed under the Apache License, Version 2.0 (the "License");
83121 * you may not use this file except in compliance with the License.
83122 * You may obtain a copy of the License at
83123 *
83124 * http://www.apache.org/licenses/LICENSE-2.0
83125 *
83126 * Unless required by applicable law or agreed to in writing, software
83127 * distributed under the License is distributed on an "AS IS" BASIS,
83128 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
83129 * See the License for the specific language governing permissions and
83130 * limitations under the License.
83131 * =============================================================================
83132 */
83133
83134 /**
83135 * @license
83136 * Copyright 2018 Google LLC. All Rights Reserved.
83137 * Licensed under the Apache License, Version 2.0 (the "License");
83138 * you may not use this file except in compliance with the License.
83139 * You may obtain a copy of the License at
83140 *
83141 * http://www.apache.org/licenses/LICENSE-2.0
83142 *
83143 * Unless required by applicable law or agreed to in writing, software
83144 * distributed under the License is distributed on an "AS IS" BASIS,
83145 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
83146 * See the License for the specific language governing permissions and
83147 * limitations under the License.
83148 * =============================================================================
83149 */
    // Cache of WebGL rendering contexts keyed by WebGL version (1 or 2).
    const contexts = {};
    // Context attributes used for every created context: no alpha, depth, or
    // stencil buffers, no antialiasing, and by default fail rather than fall
    // back to a slow software implementation. NOTE: this object is mutated at
    // runtime when the SOFTWARE_WEBGL_ENABLED flag is set (see
    // getWebGLRenderingContext).
    const WEBGL_ATTRIBUTES = {
        alpha: false,
        antialias: false,
        premultipliedAlpha: false,
        preserveDrawingBuffer: false,
        depth: false,
        stencil: false,
        failIfMajorPerformanceCaveat: true
    };
    // Removes the cached context for the given WebGL version, forcing the
    // next getWebGLContext call to create a fresh one.
    function clearWebGLContext(webGLVersion) {
        delete contexts[webGLVersion];
    }
    // Installs an externally-created context into the cache for the given
    // WebGL version.
    function setWebGLContext(webGLVersion, gl) {
        contexts[webGLVersion] = gl;
    }
83166 function getWebGLContext(webGLVersion, customCanvas) {
83167 if (!(webGLVersion in contexts) || customCanvas != null) {
83168 const newCtx = getWebGLRenderingContext(webGLVersion, customCanvas);
83169 if (newCtx !== null) {
83170 contexts[webGLVersion] = newCtx;
83171 }
83172 else {
83173 console.log('Could not get context for WebGL version', webGLVersion);
83174 return null;
83175 }
83176 }
83177 const gl = contexts[webGLVersion];
83178 if (gl == null || gl.isContextLost()) {
83179 delete contexts[webGLVersion];
83180 return getWebGLContext(webGLVersion);
83181 }
83182 gl.disable(gl.DEPTH_TEST);
83183 gl.disable(gl.STENCIL_TEST);
83184 gl.disable(gl.BLEND);
83185 gl.disable(gl.DITHER);
83186 gl.disable(gl.POLYGON_OFFSET_FILL);
83187 gl.disable(gl.SAMPLE_COVERAGE);
83188 gl.enable(gl.SCISSOR_TEST);
83189 gl.enable(gl.CULL_FACE);
83190 gl.cullFace(gl.BACK);
83191 return contexts[webGLVersion];
83192 }
83193 function createCanvas(webGLVersion) {
83194 // Use canvas element for Safari, since its offscreen canvas does not support
83195 // fencing.
83196 if (!env().getBool('IS_SAFARI') && typeof OffscreenCanvas !== 'undefined' &&
83197 webGLVersion === 2) {
83198 return new OffscreenCanvas(300, 150);
83199 }
83200 else if (typeof document !== 'undefined') {
83201 return document.createElement('canvas');
83202 }
83203 else {
83204 throw new Error('Cannot create a canvas in this context');
83205 }
83206 }
83207 function getWebGLRenderingContext(webGLVersion, customCanvas) {
83208 if (webGLVersion !== 1 && webGLVersion !== 2) {
83209 throw new Error('Cannot get WebGL rendering context, WebGL is disabled.');
83210 }
83211 const canvas = customCanvas == null ? createCanvas(webGLVersion) : customCanvas;
83212 canvas.addEventListener('webglcontextlost', (ev) => {
83213 ev.preventDefault();
83214 delete contexts[webGLVersion];
83215 }, false);
83216 if (env().getBool('SOFTWARE_WEBGL_ENABLED')) {
83217 WEBGL_ATTRIBUTES.failIfMajorPerformanceCaveat = false;
83218 }
83219 if (webGLVersion === 1) {
83220 return (
83221 // tslint:disable-next-line
83222 canvas.getContext('webgl', WEBGL_ATTRIBUTES) ||
83223 canvas
83224 .getContext('experimental-webgl', WEBGL_ATTRIBUTES));
83225 }
83226 return canvas.getContext('webgl2', WEBGL_ATTRIBUTES);
83227 }
83228
83229 /**
83230 * @license
83231 * Copyright 2017 Google LLC. All Rights Reserved.
83232 * Licensed under the Apache License, Version 2.0 (the "License");
83233 * you may not use this file except in compliance with the License.
83234 * You may obtain a copy of the License at
83235 *
83236 * http://www.apache.org/licenses/LICENSE-2.0
83237 *
83238 * Unless required by applicable law or agreed to in writing, software
83239 * distributed under the License is distributed on an "AS IS" BASIS,
83240 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
83241 * See the License for the specific language governing permissions and
83242 * limitations under the License.
83243 * =============================================================================
83244 */
    // Compiler-generated reverse-mapped enums describing how tensor data is
    // laid out in WebGL textures.
    var PackingScheme;
    (function (PackingScheme) {
        /**
         * All values in a single texel are densely packed without any constraints.
         *
         * This is how the shader encodes a tensor with shape = [2, 3, 4]
         * (indices are [batch, row, col]).
         *
         * 000|001   010|011   020|021
         * -------   -------   -------
         * 002|003   012|013   022|023
         *
         * 100|101   110|111   120|121
         * -------   -------   -------
         * 102|103   112|113   122|123
         *
         */
        PackingScheme[PackingScheme["DENSE"] = 0] = "DENSE";
        /**
         * Single texels contain only values from the same batch, and from adjacent
         * rows and columns.
         *
         * This is how the shader encodes a tensor with shape = [2, 3, 5]
         * (indices are [batch, row, col]).
         *
         * 000|001   002|003   004|xxx   020|021   022|023   024|xxx
         * -------   -------   -------   -------   -------   -------
         * 010|011   012|013   014|xxx   xxx|xxx   xxx|xxx   xxx|xxx
         *
         * 100|101   102|103   104|xxx   120|121   122|123   124|xxx
         * -------   -------   -------   -------   -------   -------
         * 110|111   112|113   114|xxx   xxx|xxx   xxx|xxx   xxx|xxx
         *
         */
        PackingScheme[PackingScheme["SHARED_BATCH"] = 1] = "SHARED_BATCH";
    })(PackingScheme || (PackingScheme = {}));
    // How a texture is being used at a given point in time.
    var TextureUsage;
    (function (TextureUsage) {
        TextureUsage[TextureUsage["RENDER"] = 0] = "RENDER";
        TextureUsage[TextureUsage["UPLOAD"] = 1] = "UPLOAD";
        TextureUsage[TextureUsage["PIXELS"] = 2] = "PIXELS";
        TextureUsage[TextureUsage["DOWNLOAD"] = 3] = "DOWNLOAD";
    })(TextureUsage || (TextureUsage = {}));
    // Concrete storage formats a texture can be allocated with.
    var PhysicalTextureType;
    (function (PhysicalTextureType) {
        PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT16"] = 0] = "UNPACKED_FLOAT16";
        PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT32"] = 1] = "UNPACKED_FLOAT32";
        PhysicalTextureType[PhysicalTextureType["PACKED_4X1_UNSIGNED_BYTE"] = 2] = "PACKED_4X1_UNSIGNED_BYTE";
        PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT32"] = 3] = "PACKED_2X2_FLOAT32";
        PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT16"] = 4] = "PACKED_2X2_FLOAT16";
    })(PhysicalTextureType || (PhysicalTextureType = {}));
/**
 * Texture [width, height] for an unpacked matrix: one value per texel, so
 * width = columns and height = rows.
 */
function getUnpackedMatrixTextureShapeWidthHeight(rows, columns) {
    const widthHeight = [columns, rows];
    return widthHeight;
}
/** Number of array entries needed to hold `matrixSize` values when each value occupies `channelsPerTexture` channels. */
function getUnpackedArraySizeFromMatrixSize(matrixSize, channelsPerTexture) {
    const arraySize = matrixSize * channelsPerTexture;
    return arraySize;
}
/** Texture [width, height] when each matrix value is stored across 4 color channels (width = columns * 4). */
function getColorMatrixTextureShapeWidthHeight(rows, columns) {
    const width = columns * 4;
    return [width, rows];
}
/**
 * Get shape for a densely packed RGBA texture: four values per texel, texels
 * arranged into a squarish rectangle (via sizeToSquarishShape).
 */
function getDenseTexShape(shape) {
    const numValues = sizeFromShape(shape);
    const numTexels = Math.ceil(numValues / 4);
    return sizeToSquarishShape(numTexels);
}
/**
 * Inverse of getUnpackedArraySizeFromMatrixSize. Throws if `unpackedSize` is
 * not an exact multiple of `channelsPerTexture`.
 */
function getMatrixSizeFromUnpackedArraySize(unpackedSize, channelsPerTexture) {
    const remainder = unpackedSize % channelsPerTexture;
    if (remainder !== 0) {
        throw new Error(`unpackedSize (${unpackedSize}) must be a multiple of ${channelsPerTexture}`);
    }
    return unpackedSize / channelsPerTexture;
}
/**
 * Copies the first `channels` components of every RGBA texel in
 * `unpackedArray` into `matrix`, in order. `matrix` is mutated in place and
 * must be able to hold `unpackedArray.length * channels / 4` values.
 */
function decodeMatrixFromUnpackedColorRGBAArray(unpackedArray, matrix, channels) {
    const requiredSize = unpackedArray.length * channels / 4;
    if (matrix.length < requiredSize) {
        throw new Error(`matrix length (${matrix.length}) must be >= ${requiredSize}`);
    }
    let writeIndex = 0;
    for (let texelStart = 0; texelStart < unpackedArray.length; texelStart += 4) {
        for (let channel = 0; channel < channels; channel++) {
            matrix[writeIndex] = unpackedArray[texelStart + channel];
            writeIndex++;
        }
    }
}
/**
 * Texture [width, height] for a 2x2-packed matrix: each texel covers a 2x2
 * patch, so both dimensions are halved (rounded up, minimum 1).
 */
function getPackedMatrixTextureShapeWidthHeight(rows, columns) {
    const width = Math.max(1, Math.ceil(columns / 2));
    const height = Math.max(1, Math.ceil(rows / 2));
    return [width, height];
}
/**
 * Number of float entries (4 per texel) in the RGBA array backing a
 * 2x2-packed texture for a rows x columns matrix. Inlines the packed
 * width/height computation (ceil(dim / 2), minimum 1 per axis).
 */
function getPackedRGBAArraySizeFromMatrixShape(rows, columns) {
    const width = Math.max(1, Math.ceil(columns / 2));
    const height = Math.max(1, Math.ceil(rows / 2));
    return width * height * 4;
}
/**
 * Selects the internal formats / texture types used to store float data,
 * keyed off the WEBGL_VERSION flag.
 *
 * WebGL2 path: sized single-channel formats (R32F / R16F) for unpacked
 * textures, RGBA32F / RGBA16F for packed ones, HALF_FLOAT / FLOAT types, and
 * RGBA8 for byte downloads.
 *
 * WebGL1 path: no sized internal formats exist, so everything falls back to
 * gl.RGBA; half-float support comes from `textureHalfFloatExtension`
 * (OES_texture_half_float), and `textureTypeHalfFloat` is null when that
 * extension is unavailable.
 *
 * @param gl WebGL rendering context (typed as any upstream).
 * @param textureHalfFloatExtension Optional OES_texture_half_float extension
 *     object; only consulted on the WebGL1 path.
 * @returns An object bundling all chosen format/type enums plus channel
 *     counts (downloadUnpackNumChannels, defaultNumChannels).
 */
function getTextureConfig(
// tslint:disable-next-line:no-any
gl, textureHalfFloatExtension) {
    // tslint:disable-next-line:no-any
    const glany = gl;
    let internalFormatFloat;
    let internalFormatHalfFloat;
    let internalFormatPackedHalfFloat;
    let internalFormatPackedFloat;
    let textureFormatFloat;
    let downloadTextureFormat;
    let downloadUnpackNumChannels;
    let defaultNumChannels;
    let textureTypeHalfFloat;
    let textureTypeFloat;
    if (env().getNumber('WEBGL_VERSION') === 2) {
        internalFormatFloat = glany.R32F;
        internalFormatHalfFloat = glany.R16F;
        internalFormatPackedHalfFloat = glany.RGBA16F;
        internalFormatPackedFloat = glany.RGBA32F;
        textureFormatFloat = glany.RED;
        downloadUnpackNumChannels = 4;
        // WebGL2 stores one value per texel in the unpacked case.
        defaultNumChannels = 1;
        textureTypeHalfFloat = glany.HALF_FLOAT;
        textureTypeFloat = glany.FLOAT;
        downloadTextureFormat = glany.RGBA8;
    }
    else {
        internalFormatFloat = gl.RGBA;
        internalFormatHalfFloat = gl.RGBA;
        internalFormatPackedHalfFloat = gl.RGBA;
        internalFormatPackedFloat = glany.RGBA;
        textureFormatFloat = gl.RGBA;
        downloadUnpackNumChannels = 4;
        // WebGL1 always uses all four RGBA channels.
        defaultNumChannels = 4;
        // Null when the half-float extension is missing; callers must check.
        textureTypeHalfFloat = textureHalfFloatExtension != null ?
            textureHalfFloatExtension.HALF_FLOAT_OES :
            null;
        textureTypeFloat = gl.FLOAT;
        downloadTextureFormat = gl.RGBA;
    }
    return {
        internalFormatFloat,
        internalFormatHalfFloat,
        internalFormatPackedHalfFloat,
        internalFormatPackedFloat,
        textureFormatFloat,
        downloadTextureFormat,
        downloadUnpackNumChannels,
        defaultNumChannels,
        textureTypeHalfFloat,
        textureTypeFloat
    };
}
83395
83396 /**
83397 * @license
83398 * Copyright 2017 Google LLC. All Rights Reserved.
83399 * Licensed under the Apache License, Version 2.0 (the "License");
83400 * you may not use this file except in compliance with the License.
83401 * You may obtain a copy of the License at
83402 *
83403 * http://www.apache.org/licenses/LICENSE-2.0
83404 *
83405 * Unless required by applicable law or agreed to in writing, software
83406 * distributed under the License is distributed on an "AS IS" BASIS,
83407 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
83408 * See the License for the specific language governing permissions and
83409 * limitations under the License.
83410 * =============================================================================
83411 */
/**
 * Invokes `func` and, when the DEBUG flag is set, checks the GL error state
 * afterwards (throwing on any pending WebGL error). Returns func's result.
 */
function callAndCheck(gl, func) {
    const result = func();
    if (env().getBool('DEBUG')) {
        checkWebGLError(gl);
    }
    return result;
}
/** Throws if the context has a pending WebGL error. */
function checkWebGLError(gl) {
    const error = gl.getError();
    if (error === gl.NO_ERROR) {
        return;
    }
    throw new Error('WebGL Error: ' + getWebGLErrorMessage(gl, error));
}
// Half-precision float range, per
// https://en.wikipedia.org/wiki/Half-precision_floating-point_format
const MIN_FLOAT16 = 5.96e-8;
const MAX_FLOAT16 = 65504;
/**
 * Whether `num` can be stored in the backend's float textures: always true
 * when full float32 rendering is enabled, otherwise true only for zero or
 * magnitudes strictly inside the normal float16 range.
 */
function canBeRepresented(num) {
    if (env().getBool('WEBGL_RENDER_FLOAT32_ENABLED')) {
        return true;
    }
    if (num === 0) {
        return true;
    }
    const magnitude = Math.abs(num);
    return MIN_FLOAT16 < magnitude && magnitude < MAX_FLOAT16;
}
/**
 * Maps a gl.getError() status code to its symbolic name, or
 * `Unknown error code ${status}` when it matches none of the known enums.
 */
function getWebGLErrorMessage(gl, status) {
    const errorNames = [
        'NO_ERROR',
        'INVALID_ENUM',
        'INVALID_VALUE',
        'INVALID_OPERATION',
        'INVALID_FRAMEBUFFER_OPERATION',
        'OUT_OF_MEMORY',
        'CONTEXT_LOST_WEBGL',
    ];
    for (const name of errorNames) {
        if (status === gl[name]) {
            return name;
        }
    }
    return `Unknown error code ${status}`;
}
/** Returns the named WebGL extension, throwing when it is unavailable. */
function getExtensionOrThrow(gl, extensionName) {
    const fetchExtension = () => gl.getExtension(extensionName);
    return throwIfNull(gl, fetchExtension, `Extension "${extensionName}" not supported on this browser.`);
}
/** Compiles `vertexShaderSource` into a vertex shader, throwing (after logging the info log) on compile failure. */
function createVertexShader$1(gl, vertexShaderSource) {
    const shader = throwIfNull(gl, () => gl.createShader(gl.VERTEX_SHADER), 'Unable to create vertex WebGLShader.');
    callAndCheck(gl, () => gl.shaderSource(shader, vertexShaderSource));
    callAndCheck(gl, () => gl.compileShader(shader));
    const compiled = gl.getShaderParameter(shader, gl.COMPILE_STATUS);
    if (compiled === false) {
        console.log(gl.getShaderInfoLog(shader));
        throw new Error('Failed to compile vertex shader.');
    }
    return shader;
}
/**
 * Compiles `fragmentShaderSource` into a fragment shader. When the
 * ENGINE_COMPILE_ONLY flag is set the compile status is not checked (it is
 * queried later, in parallel). On failure, pretty-prints the source with the
 * offending line highlighted, then throws.
 */
function createFragmentShader(gl, fragmentShaderSource) {
    const shader = throwIfNull(gl, () => gl.createShader(gl.FRAGMENT_SHADER), 'Unable to create fragment WebGLShader.');
    callAndCheck(gl, () => gl.shaderSource(shader, fragmentShaderSource));
    callAndCheck(gl, () => gl.compileShader(shader));
    if (env().get('ENGINE_COMPILE_ONLY')) {
        return shader;
    }
    const compiled = gl.getShaderParameter(shader, gl.COMPILE_STATUS);
    if (compiled === false) {
        logShaderSourceAndInfoLog(fragmentShaderSource, gl.getShaderInfoLog(shader));
        throw new Error('Failed to compile fragment shader.');
    }
    return shader;
}
// Extracts the 1-based GLSL line number from an info-log line such as
// "ERROR: 0:24: 'foo' : undeclared identifier".
// NOTE: deliberately NOT a /g regex. A module-level global regex used with
// `exec` retains `lastIndex` state between calls, so a second invocation of
// `logShaderSourceAndInfoLog` would resume scanning mid-string and could
// fail to find the error location.
const lineNumberRegex = /ERROR: [0-9]+:([0-9]+):/;
/**
 * Pretty-prints a failed shader compile: logs the shader source with a
 * line-number gutter and highlights the line the driver reported in
 * `shaderInfoLog`. Falls back to dumping the raw source when no line number
 * can be parsed from the info log.
 */
function logShaderSourceAndInfoLog(shaderSource, shaderInfoLog) {
    const lineNumberRegexResult = lineNumberRegex.exec(shaderInfoLog);
    if (lineNumberRegexResult == null) {
        console.log(`Couldn't parse line number in error: ${shaderInfoLog}`);
        console.log(shaderSource);
        return;
    }
    const lineNumber = +lineNumberRegexResult[1];
    const shaderLines = shaderSource.split('\n');
    // Gutter width: enough for the largest line number plus padding.
    const pad = shaderLines.length.toString().length + 2;
    // Use `i` for the 0-based index to avoid shadowing `lineNumber` above.
    const linesWithLineNumbers = shaderLines.map((line, i) => rightPad((i + 1).toString(), pad) + line);
    let maxLineLength = 0;
    for (let i = 0; i < linesWithLineNumbers.length; i++) {
        maxLineLength = Math.max(linesWithLineNumbers[i].length, maxLineLength);
    }
    const beforeErrorLines = linesWithLineNumbers.slice(0, lineNumber - 1);
    const errorLine = linesWithLineNumbers.slice(lineNumber - 1, lineNumber);
    const afterErrorLines = linesWithLineNumbers.slice(lineNumber);
    console.log(beforeErrorLines.join('\n'));
    console.log(shaderInfoLog.split('\n')[0]);
    console.log(`%c ${rightPad(errorLine[0], maxLineLength)}`, 'border:1px solid red; background-color:#e3d2d2; color:#a61717');
    console.log(afterErrorLines.join('\n'));
}
/** Creates a WebGLProgram, throwing when creation fails. */
function createProgram(gl) {
    const create = () => gl.createProgram();
    return throwIfNull(gl, create, 'Unable to create WebGLProgram.');
}
/**
 * Links `program`, throwing (after logging the program info log) on failure.
 * Skips the link-status check when ENGINE_COMPILE_ONLY is set.
 */
function linkProgram(gl, program) {
    callAndCheck(gl, () => gl.linkProgram(program));
    if (env().get('ENGINE_COMPILE_ONLY')) {
        return;
    }
    const linked = gl.getProgramParameter(program, gl.LINK_STATUS);
    if (linked === false) {
        console.log(gl.getProgramInfoLog(program));
        throw new Error('Failed to link vertex and fragment shaders.');
    }
}
/**
 * Runs gl.validateProgram and throws if validation fails.
 *
 * validateProgram effectively asks: "if we `useProgram(program);
 * drawArrays();` right now, would the driver report perf/correctness
 * warnings or errors?" — so all vertex/texture/sampler/uniform state must be
 * set up before calling this.
 */
function validateProgram(gl, program) {
    callAndCheck(gl, () => gl.validateProgram(program));
    const valid = gl.getProgramParameter(program, gl.VALIDATE_STATUS);
    if (valid === false) {
        console.log(gl.getProgramInfoLog(program));
        throw new Error('Shader program validation failed.');
    }
}
/** Creates an ARRAY_BUFFER, uploads `data` with STATIC_DRAW usage, and returns the buffer. */
function createStaticVertexBuffer(gl, data) {
    const target = gl.ARRAY_BUFFER;
    const buffer = throwIfNull(gl, () => gl.createBuffer(), 'Unable to create WebGLBuffer');
    callAndCheck(gl, () => gl.bindBuffer(target, buffer));
    callAndCheck(gl, () => gl.bufferData(target, data, gl.STATIC_DRAW));
    return buffer;
}
/** Creates an ELEMENT_ARRAY_BUFFER, uploads `data` with STATIC_DRAW usage, and returns the buffer. */
function createStaticIndexBuffer(gl, data) {
    const target = gl.ELEMENT_ARRAY_BUFFER;
    const buffer = throwIfNull(gl, () => gl.createBuffer(), 'Unable to create WebGLBuffer');
    callAndCheck(gl, () => gl.bindBuffer(target, buffer));
    callAndCheck(gl, () => gl.bufferData(target, data, gl.STATIC_DRAW));
    return buffer;
}
/** Channels used per value: 1 on WebGL2 (R32F textures), 4 on WebGL1 (RGBA). */
function getNumChannels() {
    const isWebGL2 = env().getNumber('WEBGL_VERSION') === 2;
    return isWebGL2 ? 1 : 4;
}
/** Creates a WebGLTexture, throwing when creation fails. */
function createTexture(gl) {
    const create = () => gl.createTexture();
    return throwIfNull(gl, create, 'Unable to create WebGLTexture.');
}
/**
 * Throws when a requested texture size is non-positive or exceeds the
 * device's WEBGL_MAX_TEXTURE_SIZE on either axis.
 */
function validateTextureSize(width, height) {
    const maxTextureSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
    const requested = `[${width}x${height}]`;
    if (width <= 0 || height <= 0) {
        throw new Error(`Requested texture size ${requested} is invalid.`);
    }
    if (width > maxTextureSize || height > maxTextureSize) {
        const max = `[${maxTextureSize}x${maxTextureSize}]`;
        throw new Error(`Requested texture size ${requested} greater than WebGL maximum on this browser / GPU ${max}.`);
    }
}
/** Creates a WebGLFramebuffer, throwing when creation fails. */
function createFramebuffer(gl) {
    const create = () => gl.createFramebuffer();
    return throwIfNull(gl, create, 'Unable to create WebGLFramebuffer.');
}
/**
 * Points the named vertex attribute at `buffer` (FLOAT components) and
 * enables it. Returns false without binding when the compiler stripped the
 * attribute as unused (location -1), true otherwise.
 */
function bindVertexBufferToProgramAttribute(gl, program, attribute, buffer, arrayEntriesPerItem, itemStrideInBytes, itemOffsetInBytes) {
    const location = gl.getAttribLocation(program, attribute);
    if (location === -1) {
        // Stripped by the GPU compiler as unused — nothing to bind.
        return false;
    }
    callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, buffer));
    callAndCheck(gl, () => gl.vertexAttribPointer(location, arrayEntriesPerItem, gl.FLOAT, false, itemStrideInBytes, itemOffsetInBytes));
    callAndCheck(gl, () => gl.enableVertexAttribArray(location));
    return true;
}
/** Activates texture unit `textureUnit` and binds `texture` to its TEXTURE_2D target. */
function bindTextureUnit(gl, texture, textureUnit) {
    validateTextureUnit(gl, textureUnit);
    const unit = gl.TEXTURE0 + textureUnit;
    callAndCheck(gl, () => gl.activeTexture(unit));
    callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture));
}
/** Activates texture unit `textureUnit` and unbinds whatever TEXTURE_2D texture was attached. */
function unbindTextureUnit(gl, textureUnit) {
    validateTextureUnit(gl, textureUnit);
    const unit = gl.TEXTURE0 + textureUnit;
    callAndCheck(gl, () => gl.activeTexture(unit));
    callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
}
/** Returns the uniform's location, throwing when it is not present in `program`. */
function getProgramUniformLocationOrThrow(gl, program, uniformName) {
    const lookup = () => gl.getUniformLocation(program, uniformName);
    return throwIfNull(gl, lookup, `uniform "${uniformName}" not present in program.`);
}
/** Returns the uniform's location, or null when absent (no throw variant). */
function getProgramUniformLocation(gl, program, uniformName) {
    const location = gl.getUniformLocation(program, uniformName);
    return location;
}
/** Binds `texture` to the given unit and points the sampler uniform at that unit. */
function bindTextureToProgramUniformSampler(gl, texture, uniformSamplerLocation, textureUnit) {
    callAndCheck(gl, () => bindTextureUnit(gl, texture, textureUnit));
    callAndCheck(gl, () => gl.uniform1i(uniformSamplerLocation, textureUnit));
}
/** Targets the default (canvas) framebuffer and sizes the viewport/scissor to the full canvas. */
function bindCanvasToFramebuffer(gl) {
    callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, null));
    const { width, height } = gl.canvas;
    callAndCheck(gl, () => gl.viewport(0, 0, width, height));
    callAndCheck(gl, () => gl.scissor(0, 0, width, height));
}
/** Attaches `texture` as COLOR_ATTACHMENT0 of `framebuffer`. */
function bindColorTextureToFramebuffer(gl, texture, framebuffer) {
    const target = gl.FRAMEBUFFER;
    callAndCheck(gl, () => gl.bindFramebuffer(target, framebuffer));
    callAndCheck(gl, () => gl.framebufferTexture2D(target, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0));
}
/** Detaches the COLOR_ATTACHMENT0 texture from `framebuffer`. */
function unbindColorTextureFromFramebuffer(gl, framebuffer) {
    const target = gl.FRAMEBUFFER;
    callAndCheck(gl, () => gl.bindFramebuffer(target, framebuffer));
    callAndCheck(gl, () => gl.framebufferTexture2D(target, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, null, 0));
}
/** Throws when the currently bound framebuffer is not FRAMEBUFFER_COMPLETE. */
function validateFramebuffer(gl) {
    const status = gl.checkFramebufferStatus(gl.FRAMEBUFFER);
    if (status === gl.FRAMEBUFFER_COMPLETE) {
        return;
    }
    throw new Error('Error binding framebuffer: ' + getFramebufferErrorMessage(gl, status));
}
/**
 * Maps a checkFramebufferStatus code to its symbolic name, or
 * `unknown error ${status}` for unrecognized codes.
 */
function getFramebufferErrorMessage(gl, status) {
    const statusNames = [
        'FRAMEBUFFER_INCOMPLETE_ATTACHMENT',
        'FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT',
        'FRAMEBUFFER_INCOMPLETE_DIMENSIONS',
        'FRAMEBUFFER_UNSUPPORTED',
    ];
    for (const name of statusNames) {
        if (status === gl[name]) {
            return name;
        }
    }
    return `unknown error ${status}`;
}
/**
 * Runs `returnTOrNull` through callAndCheck and throws `failureMessage` when
 * the result is null/undefined; otherwise returns it.
 */
function throwIfNull(gl, returnTOrNull, failureMessage) {
    const result = callAndCheck(gl, returnTOrNull);
    if (result == null) {
        throw new Error(failureMessage);
    }
    return result;
}
/**
 * Throws when `textureUnit` falls outside the accepted range.
 *
 * NOTE(review): `gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS` here is the GLenum
 * *constant* (a parameter name, value 35661), not the device limit from
 * `gl.getParameter(gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS)`. The check is
 * therefore far looser than the hardware's real unit count — presumably a
 * long-standing upstream quirk; confirm before tightening, since tightening
 * would reject units that currently pass.
 */
function validateTextureUnit(gl, textureUnit) {
    const maxTextureUnit = gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS - 1;
    const glTextureUnit = textureUnit + gl.TEXTURE0;
    if (glTextureUnit < gl.TEXTURE0 || glTextureUnit > maxTextureUnit) {
        const textureUnitRange = `[gl.TEXTURE0, gl.TEXTURE${maxTextureUnit}]`;
        throw new Error(`textureUnit must be in ${textureUnitRange}.`);
    }
}
/** Product of all leading dimensions, skipping the trailing `dimsToSkip` (default: the 2 spatial dims). */
function getBatchDim(shape, dimsToSkip = 2) {
    const batchShape = shape.slice(0, shape.length - dimsToSkip);
    return sizeFromShape(batchShape);
}
/**
 * The trailing two dimensions of `shape` as [rows, cols]; rows defaults to 1
 * for a 1-D shape. Throws on an empty shape array.
 */
function getRowsCols(shape) {
    if (shape.length === 0) {
        throw Error('Cannot get rows and columns of an empty shape array.');
    }
    const cols = shape[shape.length - 1];
    const rows = shape.length > 1 ? shape[shape.length - 2] : 1;
    return [rows, cols];
}
/** Collapses any shape to [batch, rows, cols]; scalars (rank 0, or [1]) become [1, 1, 1]. */
function getShapeAs3D(shape) {
    const isScalar = shape.length === 0 || (shape.length === 1 && shape[0] === 1);
    if (isScalar) {
        return [1, 1, 1];
    }
    return [getBatchDim(shape), ...getRowsCols(shape)];
}
/**
 * Chooses a physical 2-D texture shape [height, width] for a tensor with
 * logical shape `logShape`.
 *
 * Strategy: try to flatten adjacent logical dimensions into two axes that
 * both fit under the max texture size; otherwise fall back to a squarish
 * shape of the right total size. Long, narrow candidate shapes (one edge of
 * length 1 — or 2 when packed — and the other exceeding
 * WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE) are also squarified.
 *
 * When `isPacked`, dimensions are first rounded up to even so the 2x2-texel
 * accounting is exact, and the returned dimensions are in channel units
 * (2x the texel grid).
 */
function getTextureShapeFromLogicalShape(logShape, isPacked = false) {
    let maxTexSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
    let maxSizeForNarrowTex = env().getNumber('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE');
    if (maxSizeForNarrowTex === Infinity &&
        env().getBool('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE')) {
        maxSizeForNarrowTex = maxTexSize / 2;
    }
    if (isPacked) {
        // Packed limits are in channel units: each texel spans 2 channels per axis.
        maxTexSize = maxTexSize * 2;
        maxSizeForNarrowTex = maxSizeForNarrowTex * 2;
        // This logic ensures we accurately count the number of packed texels needed
        // to accommodate the tensor. We can only pack values in the same texel if
        // they are from adjacent pairs of rows/cols within the same batch. So if a
        // tensor has 3 rows, we pretend it has 4 rows in order to account for the
        // fact that the texels containing the third row are half empty.
        logShape = logShape.map((d, i) => i >= logShape.length - 2 ?
            nearestLargerEven(logShape[i]) :
            logShape[i]);
        // Packed texture height is at least 2 (the channel height of a single
        // texel).
        if (logShape.length === 1) {
            logShape = [2, logShape[0]];
        }
    }
    // If logical shape is 2, we don't squeeze, since we want to match physical.
    if (logShape.length !== 2) {
        const squeezeResult = squeezeShape(logShape);
        logShape = squeezeResult.newShape;
    }
    let size = sizeFromShape(logShape);
    let textureShape = null;
    // Try each way of flattening the (up to 4) logical dims into two axes,
    // in order of preference, keeping both axes within maxTexSize.
    if (logShape.length <= 1 && size <= maxTexSize) {
        textureShape = [1, size];
    }
    else if (logShape.length === 2 && logShape[0] <= maxTexSize &&
        logShape[1] <= maxTexSize) {
        textureShape = logShape;
    }
    else if (logShape.length === 3 && logShape[0] * logShape[1] <= maxTexSize &&
        logShape[2] <= maxTexSize) {
        textureShape = [logShape[0] * logShape[1], logShape[2]];
    }
    else if (logShape.length === 3 && logShape[0] <= maxTexSize &&
        logShape[1] * logShape[2] <= maxTexSize) {
        textureShape = [logShape[0], logShape[1] * logShape[2]];
    }
    else if (logShape.length === 4 &&
        logShape[0] * logShape[1] * logShape[2] <= maxTexSize &&
        logShape[3] <= maxTexSize) {
        textureShape = [logShape[0] * logShape[1] * logShape[2], logShape[3]];
    }
    else if (logShape.length === 4 && logShape[0] <= maxTexSize &&
        logShape[1] * logShape[2] * logShape[3] <= maTexSizeTypoGuard) {
        textureShape = [logShape[0], logShape[1] * logShape[2] * logShape[3]];
    }
    // true if one edge length is 1 (1 or 2, if packed), while another edge
    // length exceeds maxSizeForNarrowTex.
    const isLongNarrowTex = textureShape != null &&
        Math.max(...textureShape) > maxSizeForNarrowTex &&
        Math.min(...textureShape) <= (isPacked ? 2 : 1) &&
        Math.min(...textureShape) > 0;
    if (textureShape == null || isLongNarrowTex) {
        if (isPacked) {
            // For packed textures size equals the number of channels required to
            // accommodate the texture data. However in order to squarify such that
            // inner dimensions stay even, we rewrite size to equal the number of
            // texels. Then in the return statement we rehydrate the squarified
            // dimensions to channel units.
            const batchDim = getBatchDim(logShape);
            let rows = 2, cols = 2;
            if (logShape.length) {
                [rows, cols] = getRowsCols(logShape);
            }
            size = batchDim * (rows / 2) * (cols / 2);
            textureShape =
                sizeToSquarishShape(size).map(d => d * 2);
        }
        else {
            textureShape = sizeToSquarishShape(size);
        }
    }
    return textureShape;
}
/** True when `n` is evenly divisible by 2. */
function isEven(n) {
    const remainder = n % 2;
    return remainder === 0;
}
/**
 * Whether reshaping from `shape1` to `shape2` can reuse the packed texture
 * data as-is (no re-arrangement), assuming 2x2 packing. Only the trailing
 * two dimensions matter; a reshape is free when they already line up on
 * texel boundaries (scalars, zero-sized axes, matching columns, or even
 * row/column counts in the right configuration).
 */
function isReshapeFree(shape1, shape2) {
    const a = shape1.slice(-2);
    const b = shape2.slice(-2);
    if (arraysEqual(a, b)) {
        return true;
    }
    // A scalar on either side packs trivially.
    if (!a.length || !b.length) {
        return true;
    }
    // Any zero-sized axis means there is no data to re-arrange.
    if (a[0] === 0 || a[1] === 0 || b[0] === 0 || b[1] === 0) {
        return true;
    }
    if (a.length !== b.length) {
        // One of the shapes is a vector.
        const aCols = a[a.length - 1];
        const bCols = b[b.length - 1];
        if (aCols === bCols) {
            return true;
        }
        if (isEven(aCols) && isEven(bCols) && (a[0] === 1 || b[0] === 1)) {
            return true;
        }
    }
    // Same column count and even row counts keep texel boundaries aligned.
    return a[1] === b[1] && isEven(a[0]) && isEven(b[0]);
}
// We cache webgl params because the environment gets reset between
// unit tests and we don't want to constantly query the WebGLContext for
// MAX_TEXTURE_SIZE. Cleared via resetMaxTextureSize / resetMaxTexturesInShader.
let MAX_TEXTURE_SIZE;
let MAX_TEXTURES_IN_SHADER;
/** Lazily queries and caches the context's MAX_TEXTURE_SIZE parameter. */
function getWebGLMaxTextureSize(webGLVersion) {
    if (MAX_TEXTURE_SIZE == null) {
        const gl = getWebGLContext(webGLVersion);
        MAX_TEXTURE_SIZE = gl.getParameter(gl.MAX_TEXTURE_SIZE);
    }
    return MAX_TEXTURE_SIZE;
}
/** Clears the cached MAX_TEXTURE_SIZE so the next query hits the context again (used between unit tests). */
function resetMaxTextureSize() {
    MAX_TEXTURE_SIZE = null;
}
/** Clears the cached MAX_TEXTURES_IN_SHADER so the next query hits the context again (used between unit tests). */
function resetMaxTexturesInShader() {
    MAX_TEXTURES_IN_SHADER = null;
}
/** Lazily queries and caches MAX_TEXTURE_IMAGE_UNITS, capped at 16. */
function getMaxTexturesInShader(webGLVersion) {
    if (MAX_TEXTURES_IN_SHADER == null) {
        const gl = getWebGLContext(webGLVersion);
        MAX_TEXTURES_IN_SHADER = gl.getParameter(gl.MAX_TEXTURE_IMAGE_UNITS);
    }
    // We cap at 16 to avoid spurious runtime "memory exhausted" error.
    return Math.min(16, MAX_TEXTURES_IN_SHADER);
}
/**
 * Which disjoint-timer-query extension the context supports: 2 for the
 * WebGL2 variant, 1 for the WebGL1 variant, 0 for none (or no WebGL).
 * Extension probes are kept in the original order since getExtension can
 * have side effects.
 */
function getWebGLDisjointQueryTimerVersion(webGLVersion) {
    if (webGLVersion === 0) {
        return 0;
    }
    const gl = getWebGLContext(webGLVersion);
    if (hasExtension(gl, 'EXT_disjoint_timer_query_webgl2') &&
        webGLVersion === 2) {
        return 2;
    }
    if (hasExtension(gl, 'EXT_disjoint_timer_query')) {
        return 1;
    }
    return 0;
}
/** True when the context exposes the named extension. */
function hasExtension(gl, extensionName) {
    return gl.getExtension(extensionName) != null;
}
/**
 * Whether a context of the given WebGL version can be created. Context
 * creation errors are logged and reported as "not enabled".
 */
function isWebGLVersionEnabled(webGLVersion) {
    try {
        return getWebGLContext(webGLVersion) != null;
    }
    catch (e) {
        console.log('Error when getting WebGL context: ', e);
    }
    return false;
}
/**
 * Whether the device can render into a float texture: requires the
 * version-appropriate float extension plus a successful trial of binding a
 * 1x1 float texture to a framebuffer.
 */
function isCapableOfRenderingToFloatTexture(webGLVersion) {
    if (webGLVersion === 0) {
        return false;
    }
    const gl = getWebGLContext(webGLVersion);
    const requiredExtension =
        webGLVersion === 1 ? 'OES_texture_float' : 'EXT_color_buffer_float';
    if (!hasExtension(gl, requiredExtension)) {
        return false;
    }
    return createFloatTextureAndBindToFramebuffer(gl);
}
/**
 * Check if we can download values from a float/half-float texture.
 *
 * Note that for performance reasons we use binding a texture to a framebuffer
 * as a proxy for ability to download float values later using readPixels. The
 * texture params of this texture will not match those in readPixels exactly
 * but if we are unable to bind some kind of float texture to the frameBuffer
 * then we definitely will not be able to read float values from it.
 */
function isDownloadFloatTextureEnabled(webGLVersion) {
    if (webGLVersion === 0) {
        return false;
    }
    const gl = getWebGLContext(webGLVersion);
    if (webGLVersion === 1) {
        // WebGL1 needs both the float-texture and color-buffer-float extensions.
        if (!hasExtension(gl, 'OES_texture_float')) {
            return false;
        }
        if (!hasExtension(gl, 'WEBGL_color_buffer_float')) {
            return false;
        }
    }
    else {
        // WebGL2: prefer full float, then fall back to half float; every path
        // in this branch returns, so the trial below only runs for WebGL1.
        if (hasExtension(gl, 'EXT_color_buffer_float')) {
            return createFloatTextureAndBindToFramebuffer(gl);
        }
        const COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';
        if (hasExtension(gl, COLOR_BUFFER_HALF_FLOAT)) {
            const textureHalfFloatExtension = gl.getExtension(COLOR_BUFFER_HALF_FLOAT);
            return createHalfFloatTextureAndBindToFramebuffer(gl, textureHalfFloatExtension);
        }
        return false;
    }
    // WebGL1 path: trial-bind a float texture to a framebuffer.
    const isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl);
    return isFrameBufferComplete;
}
/**
 * Trial: allocates a 1x1 float texture, attaches it to a framebuffer, and
 * returns whether the framebuffer is complete. All GL objects are deleted
 * and bindings cleared before returning; the exact call order follows the
 * GL state machine and must not be changed.
 */
function createFloatTextureAndBindToFramebuffer(gl) {
    const texConfig = getTextureConfig(gl);
    const texture = gl.createTexture();
    gl.bindTexture(gl.TEXTURE_2D, texture);
    const width = 1;
    const height = 1;
    gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeFloat, null);
    const frameBuffer = gl.createFramebuffer();
    gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
    gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
    const isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;
    // Clean up: unbind and delete the trial objects.
    gl.bindTexture(gl.TEXTURE_2D, null);
    gl.bindFramebuffer(gl.FRAMEBUFFER, null);
    gl.deleteTexture(texture);
    gl.deleteFramebuffer(frameBuffer);
    return isFrameBufferComplete;
}
/**
 * Same trial as createFloatTextureAndBindToFramebuffer, but with a 1x1
 * half-float texture (using the half-float extension's type). Returns
 * whether the framebuffer is complete; cleans up all trial objects.
 */
function createHalfFloatTextureAndBindToFramebuffer(
// tslint:disable-next-line:no-any
gl, textureHalfFloatExtension) {
    const texConfig = getTextureConfig(gl, textureHalfFloatExtension);
    const texture = gl.createTexture();
    gl.bindTexture(gl.TEXTURE_2D, texture);
    const width = 1;
    const height = 1;
    gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatHalfFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeHalfFloat, null);
    const frameBuffer = gl.createFramebuffer();
    gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
    gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
    const isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;
    // Clean up: unbind and delete the trial objects.
    gl.bindTexture(gl.TEXTURE_2D, null);
    gl.bindFramebuffer(gl.FRAMEBUFFER, null);
    gl.deleteTexture(texture);
    gl.deleteFramebuffer(frameBuffer);
    return isFrameBufferComplete;
}
/** True only on WebGL2 contexts that expose fenceSync (GPU sync objects). */
function isWebGLFenceEnabled(webGLVersion) {
    if (webGLVersion !== 2) {
        return false;
    }
    const gl = getWebGLContext(webGLVersion);
    // tslint:disable-next-line:no-any
    return gl.fenceSync != null;
}
/**
 * Asserts that none of the given tensors (single tensor or array; null
 * entries skipped) have dtype complex64, which the WebGL backend does not
 * support for `opName`.
 */
function assertNotComplex(tensor, opName) {
    const tensors = Array.isArray(tensor) ? tensor : [tensor];
    for (const t of tensors) {
        if (t == null) {
            continue;
        }
        assert$1(t.dtype !== 'complex64', () => `${opName} does not support complex64 tensors ` +
            'in the WebGL backend.');
    }
}
83954
// Frozen namespace object re-exporting the WebGL helper utilities above
// (the bundled equivalent of `import * as webgl_util from './webgl_util'`).
// Key order and the null prototype are preserved as generated.
var webgl_util = /*#__PURE__*/Object.freeze({
    __proto__: null,
    assertNotComplex: assertNotComplex,
    bindCanvasToFramebuffer: bindCanvasToFramebuffer,
    bindColorTextureToFramebuffer: bindColorTextureToFramebuffer,
    bindTextureToProgramUniformSampler: bindTextureToProgramUniformSampler,
    bindTextureUnit: bindTextureUnit,
    bindVertexBufferToProgramAttribute: bindVertexBufferToProgramAttribute,
    callAndCheck: callAndCheck,
    canBeRepresented: canBeRepresented,
    createFragmentShader: createFragmentShader,
    createFramebuffer: createFramebuffer,
    createProgram: createProgram,
    createStaticIndexBuffer: createStaticIndexBuffer,
    createStaticVertexBuffer: createStaticVertexBuffer,
    createTexture: createTexture,
    createVertexShader: createVertexShader$1,
    getBatchDim: getBatchDim,
    getExtensionOrThrow: getExtensionOrThrow,
    getFramebufferErrorMessage: getFramebufferErrorMessage,
    getMaxTexturesInShader: getMaxTexturesInShader,
    getNumChannels: getNumChannels,
    getProgramUniformLocation: getProgramUniformLocation,
    getProgramUniformLocationOrThrow: getProgramUniformLocationOrThrow,
    getRowsCols: getRowsCols,
    getShapeAs3D: getShapeAs3D,
    getTextureShapeFromLogicalShape: getTextureShapeFromLogicalShape,
    getWebGLDisjointQueryTimerVersion: getWebGLDisjointQueryTimerVersion,
    getWebGLErrorMessage: getWebGLErrorMessage,
    getWebGLMaxTextureSize: getWebGLMaxTextureSize,
    hasExtension: hasExtension,
    isCapableOfRenderingToFloatTexture: isCapableOfRenderingToFloatTexture,
    isDownloadFloatTextureEnabled: isDownloadFloatTextureEnabled,
    isReshapeFree: isReshapeFree,
    isWebGLFenceEnabled: isWebGLFenceEnabled,
    isWebGLVersionEnabled: isWebGLVersionEnabled,
    linkProgram: linkProgram,
    logShaderSourceAndInfoLog: logShaderSourceAndInfoLog,
    resetMaxTextureSize: resetMaxTextureSize,
    resetMaxTexturesInShader: resetMaxTexturesInShader,
    unbindColorTextureFromFramebuffer: unbindColorTextureFromFramebuffer,
    unbindTextureUnit: unbindTextureUnit,
    validateFramebuffer: validateFramebuffer,
    validateProgram: validateProgram,
    validateTextureSize: validateTextureSize
});
84001
84002 /**
84003 * @license
84004 * Copyright 2019 Google LLC. All Rights Reserved.
84005 * Licensed under the Apache License, Version 2.0 (the "License");
84006 * you may not use this file except in compliance with the License.
84007 * You may obtain a copy of the License at
84008 *
84009 * http://www.apache.org/licenses/LICENSE-2.0
84010 *
84011 * Unless required by applicable law or agreed to in writing, software
84012 * distributed under the License is distributed on an "AS IS" BASIS,
84013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
84014 * See the License for the specific language governing permissions and
84015 * limitations under the License.
84016 * =============================================================================
84017 */
    const ENV = env();
    /**
     * This file contains WebGL-specific flag registrations.
     */
    /**
     * True if WebGL is supported.
     */
    ENV.registerFlag('HAS_WEBGL', () => ENV.getNumber('WEBGL_VERSION') > 0);
    /** 0: No WebGL, 1: WebGL 1.0, 2: WebGL 2.0. */
    ENV.registerFlag('WEBGL_VERSION', () => {
        // Prefer WebGL 2 when the browser can create a WebGL2 context.
        if (isWebGLVersionEnabled(2)) {
            return 2;
        }
        else if (isWebGLVersionEnabled(1)) {
            return 1;
        }
        return 0;
    });
    /** Whether to check for numerical representation problems. */
    ENV.registerFlag('WEBGL_CHECK_NUMERICAL_PROBLEMS', () => false);
    /** Whether WebGL buffer objects are available (WebGL 2 only). */
    ENV.registerFlag('WEBGL_BUFFER_SUPPORTED', () => ENV.get('WEBGL_VERSION') === 2);
    /** Whether the WebGL backend will sometimes forward ops to the CPU. */
    ENV.registerFlag('WEBGL_CPU_FORWARD', () => true);
    /** Whether the WebGL backend will always use f16 textures for rendering. */
    ENV.registerFlag('WEBGL_FORCE_F16_TEXTURES', () => false);
    /** Whether to turn all packing related flags on. */
    ENV.registerFlag('WEBGL_PACK', () => ENV.getBool('HAS_WEBGL'));
    /** Whether we will pack the batchnormalization op. */
    ENV.registerFlag('WEBGL_PACK_NORMALIZATION', () => ENV.getBool('WEBGL_PACK'));
    /** Whether we will pack the clip op. */
    ENV.registerFlag('WEBGL_PACK_CLIP', () => ENV.getBool('WEBGL_PACK'));
    /** Whether we will pack the depthwise conv op. */
    ENV.registerFlag('WEBGL_PACK_DEPTHWISECONV', () => ENV.getBool('WEBGL_PACK'));
    /** Whether we will pack binary ops. */
    ENV.registerFlag('WEBGL_PACK_BINARY_OPERATIONS', () => ENV.getBool('WEBGL_PACK'));
    /** Whether we will pack unary ops. */
    ENV.registerFlag('WEBGL_PACK_UNARY_OPERATIONS', () => ENV.getBool('WEBGL_PACK'));
    /** Whether we will pack array ops. */
    ENV.registerFlag('WEBGL_PACK_ARRAY_OPERATIONS', () => ENV.getBool('WEBGL_PACK'));
    /** Whether we will pack image ops. */
    ENV.registerFlag('WEBGL_PACK_IMAGE_OPERATIONS', () => ENV.getBool('WEBGL_PACK'));
    /** Whether we will pack reduce ops. */
    ENV.registerFlag('WEBGL_PACK_REDUCE', () => ENV.getBool('WEBGL_PACK'));
    /** Whether packed WebGL kernels lazily unpack their outputs. */
    ENV.registerFlag('WEBGL_LAZILY_UNPACK', () => ENV.getBool('WEBGL_PACK'));
    /** Whether we will use the im2col algorithm to speed up convolutions. */
    ENV.registerFlag('WEBGL_CONV_IM2COL', () => ENV.getBool('WEBGL_PACK'));
    /** Whether we will pack conv2dTranspose op. */
    ENV.registerFlag('WEBGL_PACK_CONV2DTRANSPOSE', () => ENV.getBool('WEBGL_PACK'));
    /** The maximum texture dimension. */
    ENV.registerFlag('WEBGL_MAX_TEXTURE_SIZE', () => getWebGLMaxTextureSize(ENV.getNumber('WEBGL_VERSION')));
    /** The maximum number of textures that can be bound in a shader. */
    ENV.registerFlag('WEBGL_MAX_TEXTURES_IN_SHADER', () => getMaxTexturesInShader(ENV.getNumber('WEBGL_VERSION')));
    /**
     * The disjoint_query_timer extension version.
     * 0: disabled, 1: EXT_disjoint_timer_query, 2:
     * EXT_disjoint_timer_query_webgl2.
     * In Firefox with WebGL 2.0,
     * EXT_disjoint_timer_query_webgl2 is not available, so we must use the
     * WebGL 1.0 extension.
     */
    ENV.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', () => {
        const webGLVersion = ENV.getNumber('WEBGL_VERSION');
        if (webGLVersion === 0) {
            return 0;
        }
        return getWebGLDisjointQueryTimerVersion(webGLVersion);
    });
    /**
     * Whether the timer object from the disjoint_query_timer extension gives
     * timing information that is reliable.
     */
    ENV.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', () => ENV.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0 &&
        !isMobile());
    /**
     * Whether the device is physically capable of rendering to float32 textures.
     */
    ENV.registerFlag('WEBGL_RENDER_FLOAT32_CAPABLE', () => isCapableOfRenderingToFloatTexture(ENV.getNumber('WEBGL_VERSION')));
    /**
     * Whether rendering to float32 textures is enabled. If disabled, renders to
     * float16 textures.
     */
    ENV.registerFlag('WEBGL_RENDER_FLOAT32_ENABLED', () => {
        // Forcing f16 textures wins over physical float32 capability.
        return ENV.getBool('WEBGL_FORCE_F16_TEXTURES') ?
            false :
            ENV.getBool('WEBGL_RENDER_FLOAT32_CAPABLE');
    });
    /**
     * Whether downloading float textures is enabled (16 or 32 bit). If disabled,
     * uses IEEE 754 encoding of the float32 values to 4 uint8 when downloading.
     */
    ENV.registerFlag('WEBGL_DOWNLOAD_FLOAT_ENABLED', () => isDownloadFloatTextureEnabled(ENV.getNumber('WEBGL_VERSION')));
    /** Whether the fence API is available. */
    ENV.registerFlag('WEBGL_FENCE_API_ENABLED', () => isWebGLFenceEnabled(ENV.getNumber('WEBGL_VERSION')));
    /**
     * Tensors with size <= than this will be uploaded as uniforms, not textures.
     */
    ENV.registerFlag('WEBGL_SIZE_UPLOAD_UNIFORM', () => {
        // Use uniform uploads only when 32bit floats are supported. In 16bit
        // environments there are problems with comparing a 16bit texture value
        // with a 32bit uniform value.
        const useUniforms = ENV.getBool('WEBGL_RENDER_FLOAT32_ENABLED');
        return useUniforms ? 4 : 0;
    });
    /**
     * If the total number of bytes allocated on the GPU is greater than this
     * number, we will aggressively delete textures upon disposal with
     * gl.deleteMatrixTexture, rather than making them available for reuse.
     *
     * Default value -1 indicates that we will never aggressively delete textures.
     */
    ENV.registerFlag('WEBGL_DELETE_TEXTURE_THRESHOLD', () => {
        return -1;
    }, threshold => {
        // Validator: only -1 (never delete) or a non-negative byte count is legal.
        if (!(typeof threshold === 'number')) {
            throw new Error('WEBGL_DELETE_TEXTURE_THRESHOLD must be a number but ' +
                `got ${threshold}.`);
        }
        if (threshold < 0 && threshold !== -1) {
            throw new Error(`WEBGL_DELETE_TEXTURE_THRESHOLD must be -1 (indicating never ` +
                `delete) or at least 0, but got ${threshold}.`);
        }
    });
    /**
     * Trigger a manual GL command flush if the threshold of time has passed since
     * the previous kernel execution. This can be useful for Android devices where
     * GL command flushes are delayed until the end of the javascript task. This
     * value is measured in milliseconds. Typically you want to set this value
     * close to 1.
     *
     * Default value 1 for mobile chrome, and -1 for rest cases. -1 indicates that
     * we will not enforce manual flush and depend on system default flush schedule.
     */
    ENV.registerFlag('WEBGL_FLUSH_THRESHOLD', () => {
        return isMobile() ? 1 : -1;
    }, threshold => {
        // Validator: only -1 (never flush manually) or a non-negative delay is legal.
        if (!(typeof threshold === 'number')) {
            throw new Error('WEBGL_FLUSH_THRESHOLD must be a number but got ' +
                `${threshold}.`);
        }
        if (threshold < 0 && threshold !== -1) {
            throw new Error(`WEBGL_FLUSH_THRESHOLD must be -1 (indicating never ` +
                `manual flush) or at least 0, but got ${threshold}.`);
        }
    });
    /**
     * Threshold for input tensor size that determines whether WebGL backend will
     * delegate computation to CPU.
     *
     * Default value is 128.
     */
    ENV.registerFlag('CPU_HANDOFF_SIZE_THRESHOLD', () => 128);
    /** Whether we will use shapes uniforms. */
    ENV.registerFlag('WEBGL_USE_SHAPES_UNIFORMS', () => false);
    /**
     * Threshold for last dimension of input tensor that determines whether
     * WebGL backend for the Top K op will delegate computation to CPU. If input
     * is smaller than threshold then CPU will be used
     *
     * Default value is 100000.
     */
    ENV.registerFlag('TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD', () => 100000);
    /**
     * Threshold for K that determines whether
     * WebGL backend for the Top K op will delegate computation to CPU. If k
     * is larger than threshold then CPU will be used
     *
     * Default value is 128.
     */
    ENV.registerFlag('TOPK_K_CPU_HANDOFF_THRESHOLD', () => 128);
    /** Whether we will use the experimental conv op. */
    ENV.registerFlag('WEBGL_EXP_CONV', () => false);
    /**
     * If the device performance is low or if no hardware GPU is available, whether
     * software WebGL will be used.
     */
    ENV.registerFlag('SOFTWARE_WEBGL_ENABLED', () => ENV.getBool('IS_TEST'));
    /**
     * For narrow texture (physical height or physical width is 1), if the length of
     * any texture edges exceed the threshold, the texture will be reshaped to be
     * more squarish.
     *
     * This flag is used to help some GPUs that could not provide correct
     * interpolations for long skinny triangles. We found Mali GPU probably has this
     * problem: https://github.com/tensorflow/tfjs/issues/6775.
     */
    ENV.registerFlag('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE', () => Infinity);
    /**
     * If the flag is set to true, the max size of the narrow texture will be auto
     * computed and it will be considered as a threshold to reshape the narrow
     * texture to be more squarish.
     *
     * This flag is used to help some GPUs that could not provide correct
     * interpolations for long skinny triangles. We found Mali GPU probably has this
     * problem: https://github.com/tensorflow/tfjs/issues/6775.
     */
    ENV.registerFlag('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE', () => false);
    /**
     * Whether to use the customized isnan. It's only useful for webgl2 since webgl1
     * doesn't have the builtin isnan.
     */
    ENV.registerFlag('WEBGL2_ISNAN_CUSTOM', () => false);
    /** Experimental flag, whether enter compile only phase. */
    ENV.registerFlag('ENGINE_COMPILE_ONLY', () => false);
84222
84223 /**
84224 * @license
84225 * Copyright 2018 Google LLC. All Rights Reserved.
84226 * Licensed under the Apache License, Version 2.0 (the "License");
84227 * you may not use this file except in compliance with the License.
84228 * You may obtain a copy of the License at
84229 *
84230 * http://www.apache.org/licenses/LICENSE-2.0
84231 *
84232 * Unless required by applicable law or agreed to in writing, software
84233 * distributed under the License is distributed on an "AS IS" BASIS,
84234 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
84235 * See the License for the specific language governing permissions and
84236 * limitations under the License.
84237 * =============================================================================
84238 */
84239 function getGlslDifferences() {
84240 let version;
84241 let attribute;
84242 let varyingVs;
84243 let varyingFs;
84244 let texture2D;
84245 let output;
84246 let defineOutput;
84247 let defineSpecialNaN;
84248 let defineSpecialInf;
84249 let defineRound;
84250 if (env().getNumber('WEBGL_VERSION') === 2) {
84251 version = '#version 300 es';
84252 attribute = 'in';
84253 varyingVs = 'out';
84254 varyingFs = 'in';
84255 texture2D = 'texture';
84256 output = 'outputColor';
84257 defineOutput = 'out vec4 outputColor;';
84258 // Use custom isnan definition to work across differences between
84259 // implementations on various platforms. While this should happen in ANGLE
84260 // we still see differences between android and windows (on chrome) when
84261 // using isnan directly. Since WebGL2 supports uint type and
84262 // floatBitsToUinT built-in function, we could implment isnan following
84263 // IEEE 754 rules.
84264 // NaN defination in IEEE 754-1985 is :
84265 // - sign = either 0 or 1.
84266 // - biased exponent = all 1 bits.
84267 // - fraction = anything except all 0 bits (since all 0 bits represents
84268 // infinity).
84269 // https://en.wikipedia.org/wiki/IEEE_754-1985#Representation_of_non-numbers
84270 defineSpecialNaN = env().getBool('WEBGL2_ISNAN_CUSTOM') ? `
84271 bool isnan_custom(float val) {
84272 uint floatToUint = floatBitsToUint(val);
84273 return (floatToUint & 0x7fffffffu) > 0x7f800000u;
84274 }
84275
84276 bvec4 isnan_custom(vec4 val) {
84277 return bvec4(isnan_custom(val.x),
84278 isnan_custom(val.y), isnan_custom(val.z), isnan_custom(val.w));
84279 }
84280
84281 #define isnan(value) isnan_custom(value)
84282 ` :
84283 '';
84284 // In webgl 2 we do not need to specify a custom isinf so there is no
84285 // need for a special INFINITY constant.
84286 defineSpecialInf = ``;
84287 defineRound = `
84288 #define round(value) newRound(value)
84289 int newRound(float value) {
84290 return int(floor(value + 0.5));
84291 }
84292
84293 ivec4 newRound(vec4 value) {
84294 return ivec4(floor(value + vec4(0.5)));
84295 }
84296 `;
84297 }
84298 else {
84299 version = '';
84300 attribute = 'attribute';
84301 varyingVs = 'varying';
84302 varyingFs = 'varying';
84303 texture2D = 'texture2D';
84304 output = 'gl_FragColor';
84305 defineOutput = '';
84306 // WebGL1 has no built in isnan so we define one here.
84307 defineSpecialNaN = `
84308 #define isnan(value) isnan_custom(value)
84309 bool isnan_custom(float val) {
84310 return (val > 0. || val < 1. || val == 0.) ? false : true;
84311 }
84312 bvec4 isnan_custom(vec4 val) {
84313 return bvec4(isnan(val.x), isnan(val.y), isnan(val.z), isnan(val.w));
84314 }
84315 `;
84316 defineSpecialInf = `
84317 uniform float INFINITY;
84318
84319 bool isinf(float val) {
84320 return abs(val) == INFINITY;
84321 }
84322 bvec4 isinf(vec4 val) {
84323 return equal(abs(val), vec4(INFINITY));
84324 }
84325 `;
84326 defineRound = `
84327 int round(float value) {
84328 return int(floor(value + 0.5));
84329 }
84330
84331 ivec4 round(vec4 value) {
84332 return ivec4(floor(value + vec4(0.5)));
84333 }
84334 `;
84335 }
84336 return {
84337 version,
84338 attribute,
84339 varyingVs,
84340 varyingFs,
84341 texture2D,
84342 output,
84343 defineOutput,
84344 defineSpecialNaN,
84345 defineSpecialInf,
84346 defineRound
84347 };
84348 }
84349
84350 /**
84351 * @license
84352 * Copyright 2018 Google LLC. All Rights Reserved.
84353 * Licensed under the Apache License, Version 2.0 (the "License");
84354 * you may not use this file except in compliance with the License.
84355 * You may obtain a copy of the License at
84356 *
84357 * http://www.apache.org/licenses/LICENSE-2.0
84358 *
84359 * Unless required by applicable law or agreed to in writing, software
84360 * distributed under the License is distributed on an "AS IS" BASIS,
84361 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
84362 * See the License for the specific language governing permissions and
84363 * limitations under the License.
84364 * =============================================================================
84365 */
84366 /**
84367 * Produces GLSL code that derives logical coordinates from a flat
84368 * index. The code performs integer division with each stride and decrements
84369 * the index until the index equals the final dimension coordinate.
84370 */
84371 function getLogicalCoordinatesFromFlatIndex(coords, shape, index = 'index') {
84372 const strides = computeStrides(shape);
84373 return strides
84374 .map((stride, i) => {
84375 const line1 = `int ${coords[i]} = ${index} / ${stride}`;
84376 const line2 = i === strides.length - 1 ?
84377 `int ${coords[i + 1]} = ${index} - ${coords[i]} * ${stride}` :
84378 `index -= ${coords[i]} * ${stride}`;
84379 return `${line1}; ${line2};`;
84380 })
84381 .join('');
84382 }
84383 function getOutputLogicalCoordinatesFromFlatIndexByUniform(coords, shape, index = 'index') {
84384 const strides = computeStrides(shape);
84385 return strides
84386 .map((_, i) => {
84387 const line1 = `int ${coords[i]} = ${index} / outShapeStrides[${i}]`;
84388 const line2 = i === strides.length - 1 ?
84389 `int ${coords[i + 1]} = ${index} - ${coords[i]} * outShapeStrides[${i}]` :
84390 `index -= ${coords[i]} * outShapeStrides[${i}]`;
84391 return `${line1}; ${line2};`;
84392 })
84393 .join('');
84394 }
84395 // Produces GLSL code that computes strides.
84396 function symbolicallyComputeStrides(indicesArr, variableName) {
84397 const numCoords = indicesArr.length;
84398 const shape = indicesArr.map(d => `${variableName}[${d}]`);
84399 const strides = new Array(numCoords - 1);
84400 strides[numCoords - 2] = shape[numCoords - 1];
84401 for (let i = numCoords - 3; i >= 0; --i) {
84402 strides[i] = `(${strides[i + 1]} * ${shape[i + 1]})`;
84403 }
84404 return strides;
84405 }
84406 function getLogicalCoordinatesFromFlatIndexByUniform(coords, variableName, index = 'index') {
84407 const indicesArray = coords.map((_, i) => i);
84408 const strides = symbolicallyComputeStrides(indicesArray, variableName);
84409 return strides
84410 .map((_, i) => {
84411 const line1 = `int ${coords[i]} = ${index} / ${strides[i]}`;
84412 const line2 = i === strides.length - 1 ?
84413 `int ${coords[i + 1]} = ${index} - ${coords[i]} * ${strides[i]}` :
84414 `index -= ${coords[i]} * ${strides[i]}`;
84415 return `${line1}; ${line2};`;
84416 })
84417 .join('');
84418 }
84419 function buildVec(x) {
84420 if (x.length === 1) {
84421 return `${x[0]}`;
84422 }
84423 return `vec${x.length}(${x.join(',')})`;
84424 }
84425 /**
84426 * Produces GLSL code that computes the dot product of the input x and y
84427 * vectors. Handles splitting inputs into increments of vec4s when necessary.
84428 */
84429 function dotify(x, y) {
84430 if (x.length !== y.length) {
84431 throw new Error(`Vectors to be dotted must be of the same length -` +
84432 `got ${x.length} and ${y.length}`);
84433 }
84434 const slices = [];
84435 const nearestVec4 = Math.floor(x.length / 4);
84436 const nearestVec4Remainder = x.length % 4;
84437 for (let i = 0; i < nearestVec4; i++) {
84438 const xSlice = x.slice(i * 4, i * 4 + 4);
84439 const ySlice = y.slice(i * 4, i * 4 + 4);
84440 slices.push(`${buildVec(xSlice)}, ${buildVec(ySlice)}`);
84441 }
84442 if (nearestVec4Remainder !== 0) {
84443 let xSlice = x.slice(nearestVec4 * 4);
84444 let ySlice = y.slice(nearestVec4 * 4);
84445 if (xSlice.length === 1) {
84446 xSlice = xSlice.map(d => `float(${d})`);
84447 ySlice = ySlice.map(d => `float(${d})`);
84448 }
84449 slices.push(`${buildVec(xSlice)}, ${buildVec(ySlice)}`);
84450 }
84451 return slices.map((d, i) => `dot(${d})`).join('+');
84452 }
84453 /**
84454 * Produces GLSL that computes the flat index from 3D coordinates.
84455 */
84456 function getFlatIndexFrom3D(shape) {
84457 const strides = computeStrides(shape).map(d => d.toString());
84458 return `
84459 int getFlatIndex(ivec3 coords) {
84460 return coords.x * ${strides[0]} + coords.y * ${strides[1]} + coords.z;
84461 }
84462`;
84463 }
84464 function getFlatIndexFrom3DOutput() {
84465 return `
84466 int getFlatIndex(ivec3 coords) {
84467 return coords.x * outShapeStrides[0] + coords.y * outShapeStrides[1] + coords.z;
84468 }
84469`;
84470 }
    // GLSL routine that encodes a highp float into four uint8 channels (an
    // IEEE-754-style sign/exponent/mantissa packing) so float results can be
    // read back on devices where downloading float textures is disabled
    // (see the 'WEBGL_DOWNLOAD_FLOAT_ENABLED' flag above).
    const ENCODE_FLOAT_SNIPPET = `
  const float FLOAT_MAX = 1.70141184e38;
  const float FLOAT_MIN = 1.17549435e-38;

  lowp vec4 encode_float(highp float v) {
    if (isnan(v)) {
      return vec4(255, 255, 255, 255);
    }

    highp float av = abs(v);

    if(av < FLOAT_MIN) {
      return vec4(0.0, 0.0, 0.0, 0.0);
    } else if(v > FLOAT_MAX) {
      return vec4(0.0, 0.0, 128.0, 127.0) / 255.0;
    } else if(v < -FLOAT_MAX) {
      return vec4(0.0, 0.0, 128.0, 255.0) / 255.0;
    }

    highp vec4 c = vec4(0,0,0,0);

    highp float e = floor(log2(av));
    highp float m = exp2(fract(log2(av))) - 1.0;

    c[2] = floor(128.0 * m);
    m -= c[2] / 128.0;
    c[1] = floor(32768.0 * m);
    m -= c[1] / 32768.0;
    c[0] = floor(8388608.0 * m);

    highp float ebias = e + 127.0;
    c[3] = floor(ebias / 2.0);
    ebias -= c[3] * 2.0;
    c[2] += floor(ebias) * 128.0;

    c[3] += 128.0 * step(0.0, -v);

    return c / 255.0;
  }
`;
84511
84512 /**
84513 * @license
84514 * Copyright 2017 Google LLC. All Rights Reserved.
84515 * Licensed under the Apache License, Version 2.0 (the "License");
84516 * you may not use this file except in compliance with the License.
84517 * You may obtain a copy of the License at
84518 *
84519 * http://www.apache.org/licenses/LICENSE-2.0
84520 *
84521 * Unless required by applicable law or agreed to in writing, software
84522 * distributed under the License is distributed on an "AS IS" BASIS,
84523 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
84524 * See the License for the specific language governing permissions and
84525 * limitations under the License.
84526 * =============================================================================
84527 */
    // Helper re-exported from core's backend_util; presumably used by the
    // *AtOutputCoords samplers further down this module — not visible in this
    // chunk.
    const { getBroadcastDims } = backend_util;
    /**
     * Assembles the complete fragment shader source for a program: a prefix of
     * uniform declarations, shared helper snippets, per-input sampler functions,
     * an output-coordinate function, and finally the program's own user code.
     *
     * @param inputsInfo Per-input descriptors (name, shapeInfo with
     *     logicalShape/texShape/isUniform).
     * @param outputShape Output descriptor (logicalShape, texShape, isPacked).
     * @param program The GPGPU program (packedInputs, enableShapeUniforms,
     *     customUniforms, userCode).
     * @returns The concatenated GLSL source string.
     */
    function makeShader(inputsInfo, outputShape, program) {
        const prefixSnippets = [];
        inputsInfo.forEach(x => {
            const size = sizeFromShape(x.shapeInfo.logicalShape);
            // Snippet when we decided to upload the values as uniform.
            if (x.shapeInfo.isUniform) {
                prefixSnippets.push(`uniform float ${x.name}${size > 1 ? `[${size}]` : ''};`);
            }
            else {
                // Texture-backed input: sampler plus a flat offset into it.
                prefixSnippets.push(`uniform sampler2D ${x.name};`);
                prefixSnippets.push(`uniform int offset${x.name};`);
            }
            if (program.enableShapeUniforms) {
                // Shapes are passed as uniforms (int/ivecN by rank) so one compiled
                // shader can serve multiple shapes.
                const { uniformShape } = getUniformInfoFromShape(program.packedInputs, x.shapeInfo.logicalShape, x.shapeInfo.texShape);
                switch (uniformShape.length) {
                    case 1:
                        prefixSnippets.push(`uniform int ${x.name}Shape;`);
                        break;
                    case 2:
                        prefixSnippets.push(`uniform ivec2 ${x.name}Shape;`);
                        break;
                    case 3:
                        prefixSnippets.push(`uniform ivec3 ${x.name}Shape;`);
                        break;
                    case 4:
                        prefixSnippets.push(`uniform ivec4 ${x.name}Shape;`);
                        break;
                    default:
                        break;
                }
                prefixSnippets.push(`uniform ivec2 ${x.name}TexShape;`);
            }
        });
        if (program.enableShapeUniforms) {
            // Output shape / strides / texture shape uniforms, sized by rank.
            switch (outputShape.logicalShape.length) {
                case 1:
                    prefixSnippets.push(`uniform int outShape;`);
                    break;
                case 2:
                    prefixSnippets.push(`uniform ivec2 outShape;`);
                    prefixSnippets.push(`uniform int outShapeStrides;`);
                    break;
                case 3:
                    prefixSnippets.push(`uniform ivec3 outShape;`);
                    prefixSnippets.push(`uniform ivec2 outShapeStrides;`);
                    break;
                case 4:
                    prefixSnippets.push(`uniform ivec4 outShape;`);
                    prefixSnippets.push(`uniform ivec3 outShapeStrides;`);
                    break;
                default:
                    break;
            }
            prefixSnippets.push(`uniform ivec2 outTexShape;`);
        }
        if (program.customUniforms) {
            program.customUniforms.forEach((d) => {
                prefixSnippets.push(`uniform ${d.type} ${d.name}${d.arrayIndex ? `[${d.arrayIndex}]` : ''};`);
            });
        }
        const inputPrefixSnippet = prefixSnippets.join('\n');
        const inputSamplingSnippet = inputsInfo
            .map(x => getInputSamplingSnippet(x, outputShape, program.packedInputs, program.enableShapeUniforms))
            .join('\n');
        const outTexShape = outputShape.texShape;
        const glsl = getGlslDifferences();
        const floatTextureSampleSnippet = getFloatTextureSampleSnippet(glsl);
        let outputSamplingSnippet;
        let floatTextureSetOutputSnippet;
        let shaderPrefix = getShaderPrefix(glsl);
        // Packed outputs write a full vec4 texel; unpacked write only the R
        // channel.
        if (outputShape.isPacked) {
            outputSamplingSnippet = getPackedOutputSamplingSnippet(outputShape.logicalShape, outTexShape, program.enableShapeUniforms);
            floatTextureSetOutputSnippet = getFloatTextureSetRGBASnippet(glsl);
        }
        else {
            outputSamplingSnippet = getOutputSamplingSnippet(outputShape.logicalShape, outTexShape, program.enableShapeUniforms);
            floatTextureSetOutputSnippet = getFloatTextureSetRSnippet(glsl);
        }
        if (program.packedInputs) {
            shaderPrefix += SHADER_PACKED_PREFIX;
        }
        // NOTE: the concatenation order matters — user code may reference any of
        // the helpers declared in the earlier sections.
        const source = [
            shaderPrefix, floatTextureSampleSnippet, floatTextureSetOutputSnippet,
            inputPrefixSnippet, outputSamplingSnippet, inputSamplingSnippet,
            program.userCode
        ].join('\n');
        return source;
    }
84617 function getSamplerFromInInfo(inInfo, enableShapeUniforms = false) {
84618 const shape = inInfo.shapeInfo.logicalShape;
84619 switch (shape.length) {
84620 case 0:
84621 return getSamplerScalar(inInfo, enableShapeUniforms);
84622 case 1:
84623 return getSampler1D(inInfo, enableShapeUniforms);
84624 case 2:
84625 return getSampler2D(inInfo, enableShapeUniforms);
84626 case 3:
84627 return getSampler3D(inInfo, enableShapeUniforms);
84628 case 4:
84629 return getSampler4D(inInfo, enableShapeUniforms);
84630 case 5:
84631 return getSampler5D(inInfo);
84632 case 6:
84633 return getSampler6D(inInfo);
84634 default:
84635 throw new Error(`${shape.length}-D input sampling` +
84636 ` is not yet supported`);
84637 }
84638 }
84639 function getPackedSamplerFromInInfo(inInfo, enableShapeUniforms) {
84640 const shape = inInfo.shapeInfo.logicalShape;
84641 switch (shape.length) {
84642 case 0:
84643 return getPackedSamplerScalar(inInfo);
84644 case 1:
84645 return getPackedSampler1D(inInfo, enableShapeUniforms);
84646 case 2:
84647 return getPackedSampler2D(inInfo, enableShapeUniforms);
84648 case 3:
84649 return getPackedSampler3D(inInfo, enableShapeUniforms);
84650 default:
84651 return getPackedSamplerND(inInfo, enableShapeUniforms);
84652 }
84653 }
84654 function getInputSamplingSnippet(inInfo, outShapeInfo, usesPackedTextures = false, enableShapeUniforms) {
84655 let res = '';
84656 if (usesPackedTextures) {
84657 res += getPackedSamplerFromInInfo(inInfo, enableShapeUniforms);
84658 }
84659 else {
84660 res += getSamplerFromInInfo(inInfo, enableShapeUniforms);
84661 }
84662 const inShape = inInfo.shapeInfo.logicalShape;
84663 const outShape = outShapeInfo.logicalShape;
84664 if (inShape.length <= outShape.length) {
84665 if (usesPackedTextures) {
84666 res += getPackedSamplerAtOutputCoords(inInfo, outShapeInfo);
84667 }
84668 else {
84669 res += getSamplerAtOutputCoords(inInfo, outShapeInfo);
84670 }
84671 }
84672 return res;
84673 }
84674 function getPackedOutputSamplingSnippet(outShape, outTexShape, enableShapeUniforms) {
84675 switch (outShape.length) {
84676 case 0:
84677 return getOutputScalarCoords();
84678 case 1:
84679 return getOutputPacked1DCoords(outShape, outTexShape, enableShapeUniforms);
84680 case 2:
84681 return getOutputPacked2DCoords(outShape, outTexShape, enableShapeUniforms);
84682 case 3:
84683 return getOutputPacked3DCoords(outShape, outTexShape, enableShapeUniforms);
84684 default:
84685 return getOutputPackedNDCoords(outShape, outTexShape, enableShapeUniforms);
84686 }
84687 }
84688 function getOutputSamplingSnippet(outShape, outTexShape, enableShapeUniforms) {
84689 switch (outShape.length) {
84690 case 0:
84691 return getOutputScalarCoords();
84692 case 1:
84693 return getOutput1DCoords(outShape, outTexShape, enableShapeUniforms);
84694 case 2:
84695 return getOutput2DCoords(outShape, outTexShape, enableShapeUniforms);
84696 case 3:
84697 return getOutput3DCoords(outShape, outTexShape, enableShapeUniforms);
84698 case 4:
84699 return getOutput4DCoords(outShape, outTexShape, enableShapeUniforms);
84700 case 5:
84701 return getOutput5DCoords(outShape, outTexShape);
84702 case 6:
84703 return getOutput6DCoords(outShape, outTexShape);
84704 default:
84705 throw new Error(`${outShape.length}-D output sampling is not yet supported`);
84706 }
84707 }
84708 function getFloatTextureSampleSnippet(glsl) {
84709 return `
84710 float sampleTexture(sampler2D textureSampler, vec2 uv) {
84711 return ${glsl.texture2D}(textureSampler, uv).r;
84712 }
84713 `;
84714 }
84715 function getFloatTextureSetRSnippet(glsl) {
84716 return `
84717 void setOutput(float val) {
84718 ${glsl.output} = vec4(val, 0, 0, 0);
84719 }
84720 `;
84721 }
84722 function getFloatTextureSetRGBASnippet(glsl) {
84723 return `
84724 void setOutput(vec4 val) {
84725 ${glsl.output} = val;
84726 }
84727 `;
84728 }
    /**
     * Returns the common shader preamble: precision qualifiers, the resultUV
     * varying, the output declaration, ivec5/ivec6 struct emulation, NaN/Inf/
     * round helpers (supplied by `glsl`), integer mod/div helpers, a hash-based
     * random(), and the flat-index -> UV sampling snippets.
     * NOTE: do not add comments inside the template literal — its bytes are the
     * shader source.
     */
    function getShaderPrefix(glsl) {
        const SHADER_PREFIX = `${glsl.version}
    precision highp float;
    precision highp int;
    precision highp sampler2D;
    ${glsl.varyingFs} vec2 resultUV;
    ${glsl.defineOutput}
    const vec2 halfCR = vec2(0.5, 0.5);

    struct ivec5
    {
      int x;
      int y;
      int z;
      int w;
      int u;
    };

    struct ivec6
    {
      int x;
      int y;
      int z;
      int w;
      int u;
      int v;
    };

    uniform float NAN;
    ${glsl.defineSpecialNaN}
    ${glsl.defineSpecialInf}
    ${glsl.defineRound}

    int imod(int x, int y) {
      return x - y * (x / y);
    }

    int idiv(int a, int b, float sign) {
      int res = a / b;
      int mod = imod(a, b);
      if (sign < 0. && mod != 0) {
        res -= 1;
      }
      return res;
    }

    //Based on the work of Dave Hoskins
    //https://www.shadertoy.com/view/4djSRW
    #define HASHSCALE1 443.8975
    float random(float seed){
      vec2 p = resultUV * seed;
      vec3 p3 = fract(vec3(p.xyx) * HASHSCALE1);
      p3 += dot(p3, p3.yzx + 19.19);
      return fract((p3.x + p3.y) * p3.z);
    }

    ${SAMPLE_1D_SNIPPET}
    ${SAMPLE_2D_SNIPPET}
    ${SAMPLE_3D_SNIPPET}
  `;
        return SHADER_PREFIX;
    }
    // Flat-index -> texture-UV helpers for 1D data; the packed variant maps two
    // logical values per texel (hence index / 2).
    const SAMPLE_1D_SNIPPET = `
vec2 uvFromFlat(int texNumR, int texNumC, int index) {
  int texR = index / texNumC;
  int texC = index - texR * texNumC;
  return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
}
vec2 packedUVfrom1D(int texNumR, int texNumC, int index) {
  int texelIndex = index / 2;
  int texR = texelIndex / texNumC;
  int texC = texelIndex - texR * texNumC;
  return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
}
`;
    // Packed 2D lookup: each texel covers a 2x2 logical block, so row and
    // column are halved before computing the texel index.
    const SAMPLE_2D_SNIPPET = `
vec2 packedUVfrom2D(int texelsInLogicalRow, int texNumR,
  int texNumC, int row, int col) {
  int texelIndex = (row / 2) * texelsInLogicalRow + (col / 2);
  int texR = texelIndex / texNumC;
  int texC = texelIndex - texR * texNumC;
  return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
}
`;
    // Packed 3D lookup: adds a batch offset (b * texelsInBatch) on top of the
    // 2D packed scheme.
    const SAMPLE_3D_SNIPPET = `
vec2 packedUVfrom3D(int texNumR, int texNumC,
    int texelsInBatch, int texelsInLogicalRow, int b,
    int row, int col) {
  int index = b * texelsInBatch + (row / 2) * texelsInLogicalRow + (col / 2);
  int texR = index / texNumC;
  int texC = index - texR * texNumC;
  return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
}
`;
    // Selects one logical value out of a packed 2x2 texel by the parity of the
    // innermost coordinate(s). Appended to the prefix when inputs are packed.
    const SHADER_PACKED_PREFIX = `
  float getChannel(vec4 frag, vec2 innerDims) {
    vec2 modCoord = mod(innerDims, 2.);
    return modCoord.x == 0. ?
      (modCoord.y == 0. ? frag.r : frag.g) :
      (modCoord.y == 0. ? frag.b : frag.a);
  }
  float getChannel(vec4 frag, int dim) {
    float modCoord = mod(float(dim), 2.);
    return modCoord == 0. ? frag.r : frag.g;
  }
`;
84835 function getOutputScalarCoords() {
84836 return `
84837 int getOutputCoords() {
84838 return 0;
84839 }
84840 `;
84841 }
84842 function getOutputPacked1DCoords(shape, texShape, enableShapeUniforms) {
84843 const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
84844 if (packedTexShape[0] === 1) {
84845 if (enableShapeUniforms) {
84846 return `
84847 int getOutputCoords() {
84848 return 2 * int(resultUV.x * ceil(float(outTexShape[1]) / 2.0));
84849 }
84850 `;
84851 }
84852 return `
84853 int getOutputCoords() {
84854 return 2 * int(resultUV.x * ${packedTexShape[1]}.0);
84855 }
84856 `;
84857 }
84858 if (packedTexShape[1] === 1) {
84859 if (enableShapeUniforms) {
84860 return `
84861 int getOutputCoords() {
84862 return 2 * int(resultUV.y * ceil(float(outTexShape[0]) / 2.0));
84863 }
84864 `;
84865 }
84866 return `
84867 int getOutputCoords() {
84868 return 2 * int(resultUV.y * ${packedTexShape[0]}.0);
84869 }
84870 `;
84871 }
84872 if (enableShapeUniforms) {
84873 return `
84874 int getOutputCoords() {
84875 ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));
84876 ivec2 resTexRC = ivec2(resultUV.yx *
84877 vec2(packedTexShape[0], packedTexShape[1]));
84878 return 2 * (resTexRC.x * packedTexShape[1] + resTexRC.y);
84879 }
84880 `;
84881 }
84882 return `
84883 int getOutputCoords() {
84884 ivec2 resTexRC = ivec2(resultUV.yx *
84885 vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
84886 return 2 * (resTexRC.x * ${packedTexShape[1]} + resTexRC.y);
84887 }
84888 `;
84889 }
84890 function getOutput1DCoords(shape, texShape, enableShapeUniforms) {
84891 if (texShape[0] === 1) {
84892 if (enableShapeUniforms) {
84893 return `
84894 int getOutputCoords() {
84895 return int(resultUV.x * float(outTexShape[1]));
84896 }
84897 `;
84898 }
84899 return `
84900 int getOutputCoords() {
84901 return int(resultUV.x * ${texShape[1]}.0);
84902 }
84903 `;
84904 }
84905 if (texShape[1] === 1) {
84906 if (enableShapeUniforms) {
84907 return `
84908 int getOutputCoords() {
84909 return int(resultUV.y * float(outTexShape[0]));
84910 }
84911 `;
84912 }
84913 return `
84914 int getOutputCoords() {
84915 return int(resultUV.y * ${texShape[0]}.0);
84916 }
84917 `;
84918 }
84919 if (enableShapeUniforms) {
84920 return `
84921 int getOutputCoords() {
84922 ivec2 resTexRC = ivec2(resultUV.yx *
84923 vec2(outTexShape[0], outTexShape[1]));
84924 return resTexRC.x * outTexShape[1] + resTexRC.y;
84925 }
84926 `;
84927 }
84928 return `
84929 int getOutputCoords() {
84930 ivec2 resTexRC = ivec2(resultUV.yx *
84931 vec2(${texShape[0]}, ${texShape[1]}));
84932 return resTexRC.x * ${texShape[1]} + resTexRC.y;
84933 }
84934 `;
84935 }
    /**
     * Builds the GLSL `getOutputCoords` function for a packed 3D output.
     * Each texel covers a 2x2 block of logical (row, col) values (hence the
     * ceil(dim / 2) sizing and the `* 2` scaling of r and c), so the flat
     * texel index is decomposed into (batch, row, col) pointing at the
     * top-left value of the texel.
     * @param shape Logical output shape [batch, rows, cols].
     * @param texShape Physical texture shape [height, width].
     * @param enableShapeUniforms When true, shapes come from the
     *     `outShape`/`outTexShape` uniforms instead of baked-in constants.
     * @returns GLSL source for `ivec3 getOutputCoords()`.
     */
    function getOutputPacked3DCoords(shape, texShape, enableShapeUniforms) {
        if (enableShapeUniforms) {
            return `
    ivec3 getOutputCoords() {
      ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));
      int texelsInLogicalRow = int(ceil(float(outShape[2]) / 2.0));
      int texelsInBatch = texelsInLogicalRow * int(ceil(float(outShape[1]) / 2.0));
      ivec2 resTexRC = ivec2(resultUV.yx *
                             vec2(packedTexShape[0], packedTexShape[1]));
      int index = resTexRC.x * packedTexShape[1] + resTexRC.y;

      int b = index / texelsInBatch;
      index -= b * texelsInBatch;

      int r = 2 * (index / texelsInLogicalRow);
      int c = imod(index, texelsInLogicalRow) * 2;

      return ivec3(b, r, c);
    }
  `;
        }
        // Bake the packed dimensions into the shader as compile-time constants.
        const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
        const texelsInLogicalRow = Math.ceil(shape[2] / 2);
        const texelsInBatch = texelsInLogicalRow * Math.ceil(shape[1] / 2);
        return `
    ivec3 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
                             vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
      int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y;

      int b = index / ${texelsInBatch};
      index -= b * ${texelsInBatch};

      int r = 2 * (index / ${texelsInLogicalRow});
      int c = imod(index, ${texelsInLogicalRow}) * 2;

      return ivec3(b, r, c);
    }
  `;
    }
    /**
     * Builds the GLSL `getOutputCoords` function for an unpacked 3D output.
     * Recovers the flat row-major index from the texel position, then expands
     * it into (r, c, d) via a snippet produced by the flat-index helper.
     * @param shape Logical output shape [r, c, d].
     * @param texShape Physical texture shape [height, width].
     * @param enableShapeUniforms Use `outTexShape` uniforms instead of
     *     baked-in constants (shape expansion also comes from the
     *     uniform-based helper in that case).
     * @returns GLSL source for `ivec3 getOutputCoords()`.
     */
    function getOutput3DCoords(shape, texShape, enableShapeUniforms) {
        if (enableShapeUniforms) {
            const coordsFromIndexSnippet = getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], shape);
            return `
  ivec3 getOutputCoords() {
    ivec2 resTexRC = ivec2(resultUV.yx *
      vec2(outTexShape[0], outTexShape[1]));
    int index = resTexRC.x * outTexShape[1] + resTexRC.y;
    ${coordsFromIndexSnippet}
    return ivec3(r, c, d);
  }
`;
        }
        // Snippet declares GLSL locals r, c, d derived from `index`.
        const coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape);
        return `
  ivec3 getOutputCoords() {
    ivec2 resTexRC = ivec2(resultUV.yx *
      vec2(${texShape[0]}, ${texShape[1]}));
    int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
    ${coordsFromIndexSnippet}
    return ivec3(r, c, d);
  }
`;
    }
85000 function getOutputPackedNDCoords(shape, texShape, enableShapeUniforms) {
85001 if (enableShapeUniforms) {
85002 // TODO: support 5d and 6d
85003 return `
85004 ivec4 getOutputCoords() {
85005 ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));
85006 ivec2 resTexRC = ivec2(resultUV.yx *
85007 vec2(packedTexShape[0], packedTexShape[1]));
85008 int index = resTexRC.x * packedTexShape[1] + resTexRC.y;
85009
85010 int texelsInLogicalRow = int(ceil(float(outShape[3]) / 2.0));
85011 int texelsInBatch = texelsInLogicalRow * int(ceil(float(outShape[2]) / 2.0));
85012 int texelsInBatchN = texelsInBatch * outShape[1];
85013
85014 int b2 = index / texelsInBatchN;
85015 index -= b2 * texelsInBatchN;
85016
85017 int b = index / texelsInBatch;
85018 index -= b * texelsInBatch;
85019
85020 int r = 2 * (index / texelsInLogicalRow);
85021 int c = imod(index, texelsInLogicalRow) * 2;
85022
85023 return ivec4(b2, b, r, c);
85024 }
85025 `;
85026 }
85027 const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
85028 const texelsInLogicalRow = Math.ceil(shape[shape.length - 1] / 2);
85029 const texelsInBatch = texelsInLogicalRow * Math.ceil(shape[shape.length - 2] / 2);
85030 let texelsInBatchN = texelsInBatch;
85031 let batches = ``;
85032 let coords = 'b, r, c';
85033 for (let b = 2; b < shape.length - 1; b++) {
85034 texelsInBatchN *= shape[shape.length - b - 1];
85035 batches = `
85036 int b${b} = index / ${texelsInBatchN};
85037 index -= b${b} * ${texelsInBatchN};
85038 ` + batches;
85039 coords = `b${b}, ` + coords;
85040 }
85041 return `
85042 ivec${shape.length} getOutputCoords() {
85043 ivec2 resTexRC = ivec2(resultUV.yx *
85044 vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
85045 int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y;
85046
85047 ${batches}
85048
85049 int b = index / ${texelsInBatch};
85050 index -= b * ${texelsInBatch};
85051
85052 int r = 2 * (index / ${texelsInLogicalRow});
85053 int c = imod(index, ${texelsInLogicalRow}) * 2;
85054
85055 return ivec${shape.length}(${coords});
85056 }
85057 `;
85058 }
    /**
     * Builds the GLSL `getOutputCoords` function for an unpacked 4D output.
     * Recovers the flat row-major index from the texel position, then expands
     * it into (r, c, d, d2) via a snippet from the flat-index helper.
     * @param shape Logical output shape [r, c, d, d2].
     * @param texShape Physical texture shape [height, width].
     * @param enableShapeUniforms Use `outTexShape` uniforms instead of
     *     baked-in constants.
     * @returns GLSL source for `ivec4 getOutputCoords()`.
     */
    function getOutput4DCoords(shape, texShape, enableShapeUniforms) {
        if (enableShapeUniforms) {
            const coordsFromIndexSnippet = getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd', 'd2'], shape);
            return `
    ivec4 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
        vec2(outTexShape[0], outTexShape[1]));
      int index = resTexRC.x * outTexShape[1] + resTexRC.y;
      ${coordsFromIndexSnippet}
      return ivec4(r, c, d, d2);
    }
  `;
        }
        // Snippet declares GLSL locals r, c, d, d2 derived from `index`.
        const coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2'], shape);
        return `
    ivec4 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
        vec2(${texShape[0]}, ${texShape[1]}));
      int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
      ${coordsFromIndexSnippet}
      return ivec4(r, c, d, d2);
    }
  `;
    }
    /**
     * Builds the GLSL `getOutputCoords` function for an unpacked 5D output.
     * Same scheme as the 3D/4D variants: flat index from texel position,
     * then expansion into (r, c, d, d2, d3). Note there is no shape-uniform
     * variant at rank 5.
     * @param shape Logical output shape [r, c, d, d2, d3].
     * @param texShape Physical texture shape [height, width].
     * @returns GLSL source for `ivec5 getOutputCoords()` (ivec5 is a
     *     project-defined GLSL type, not a built-in).
     */
    function getOutput5DCoords(shape, texShape) {
        const coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3'], shape);
        return `
    ivec5 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx * vec2(${texShape[0]},
                             ${texShape[1]}));

      int index = resTexRC.x * ${texShape[1]} + resTexRC.y;

      ${coordsFromIndexSnippet}

      ivec5 outShape = ivec5(r, c, d, d2, d3);
      return outShape;
    }
  `;
    }
    /**
     * Builds the GLSL `getOutputCoords` function for an unpacked 6D output.
     * Flat index from texel position, expanded into (r, c, d, d2, d3, d4).
     * No shape-uniform variant exists at rank 6.
     * @param shape Logical output shape [r, c, d, d2, d3, d4].
     * @param texShape Physical texture shape [height, width].
     * @returns GLSL source for `ivec6 getOutputCoords()` (ivec6 is a
     *     project-defined GLSL type, not a built-in).
     */
    function getOutput6DCoords(shape, texShape) {
        const coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3', 'd4'], shape);
        return `
    ivec6 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
        vec2(${texShape[0]}, ${texShape[1]}));
      int index = resTexRC.x * ${texShape[1]} + resTexRC.y;

      ${coordsFromIndexSnippet}

      ivec6 result = ivec6(r, c, d, d2, d3, d4);
      return result;
    }
  `;
    }
    /**
     * Builds the GLSL `getOutputCoords` function for a packed 2D output.
     * Each texel covers a 2x2 block of logical values. When the logical
     * shape equals the physical texture shape, the texel position maps
     * directly to coordinates (scaled by 2); otherwise the texel index is
     * split into row/column using texels-per-logical-row.
     * @param shape Logical output shape [rows, cols].
     * @param texShape Physical texture shape [height, width].
     * @param enableShapeUniforms Use `outShape`/`outTexShape` uniforms
     *     instead of baked-in constants.
     * @returns GLSL source for `ivec2 getOutputCoords()`.
     */
    function getOutputPacked2DCoords(shape, texShape, enableShapeUniforms) {
        const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
        if (arraysEqual(shape, texShape)) {
            // Fast path: logical and physical layouts coincide.
            if (enableShapeUniforms) {
                return `
      ivec2 getOutputCoords() {
        ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));
        return 2 * ivec2(resultUV.yx * vec2(packedTexShape[0], packedTexShape[1]));
      }
    `;
            }
            return `
      ivec2 getOutputCoords() {
        return 2 * ivec2(resultUV.yx * vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
      }
    `;
        }
        // texels needed to accommodate a logical row
        const texelsInLogicalRow = Math.ceil(shape[1] / 2);
        /**
         * getOutputCoords
         *
         * resTexRC: The rows and columns of the texels. If you move over one
         * texel to the right in the packed texture, you are moving over one column
         * (not two).
         *
         * index: The texel index
         */
        if (enableShapeUniforms) {
            return `
      ivec2 getOutputCoords() {
        ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));
        int texelsInLogicalRow = int(ceil(float(outShape[1]) / 2.0));
        ivec2 resTexRC = ivec2(resultUV.yx *
                               vec2(packedTexShape[0], packedTexShape[1]));

        int index = resTexRC.x * packedTexShape[1] + resTexRC.y;
        int r = 2 * (index / texelsInLogicalRow);
        int c = imod(index, texelsInLogicalRow) * 2;

        return ivec2(r, c);
      }
    `;
        }
        return `
      ivec2 getOutputCoords() {
        ivec2 resTexRC = ivec2(resultUV.yx *
                               vec2(${packedTexShape[0]}, ${packedTexShape[1]}));

        int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y;
        int r = 2 * (index / ${texelsInLogicalRow});
        int c = imod(index, ${texelsInLogicalRow}) * 2;

        return ivec2(r, c);
      }
    `;
    }
    /**
     * Builds the GLSL `getOutputCoords` function for an unpacked 2D output.
     * Four cases, cheapest first: (1) logical shape equals texture shape —
     * direct UV-to-coordinate mapping; (2) a single logical column — the
     * flat index is the row; (3) a single logical row — the flat index is
     * the column; (4) general row-major index split into (r, c).
     * @param shape Logical output shape [rows, cols].
     * @param texShape Physical texture shape [height, width].
     * @param enableShapeUniforms Use `outShape`/`outTexShape` uniforms
     *     instead of baked-in constants.
     * @returns GLSL source for `ivec2 getOutputCoords()`.
     */
    function getOutput2DCoords(shape, texShape, enableShapeUniforms) {
        if (arraysEqual(shape, texShape)) {
            if (enableShapeUniforms) {
                return `
      ivec2 getOutputCoords() {
        return ivec2(resultUV.yx * vec2(outTexShape[0], outTexShape[1]));
      }
    `;
            }
            return `
      ivec2 getOutputCoords() {
        return ivec2(resultUV.yx * vec2(${texShape[0]}, ${texShape[1]}));
      }
    `;
        }
        if (shape[1] === 1) {
            // Single column: every flat index is a distinct row.
            if (enableShapeUniforms) {
                return `
      ivec2 getOutputCoords() {
        ivec2 resTexRC = ivec2(resultUV.yx *
                               vec2(outTexShape[0], outTexShape[1]));
        int index = resTexRC.x * outTexShape[1] + resTexRC.y;
        return ivec2(index, 0);
      }
    `;
            }
            return `
      ivec2 getOutputCoords() {
        ivec2 resTexRC = ivec2(resultUV.yx *
                               vec2(${texShape[0]}, ${texShape[1]}));
        int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
        return ivec2(index, 0);
      }
    `;
        }
        if (shape[0] === 1) {
            // Single row: every flat index is a distinct column.
            if (enableShapeUniforms) {
                return `
      ivec2 getOutputCoords() {
        ivec2 resTexRC = ivec2(resultUV.yx *
                               vec2(outTexShape[0], outTexShape[1]));
        int index = resTexRC.x * outTexShape[1] + resTexRC.y;
        return ivec2(0, index);
      }
    `;
            }
            return `
      ivec2 getOutputCoords() {
        ivec2 resTexRC = ivec2(resultUV.yx *
                               vec2(${texShape[0]}, ${texShape[1]}));
        int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
        return ivec2(0, index);
      }
    `;
        }
        if (enableShapeUniforms) {
            return `
      ivec2 getOutputCoords() {
        ivec2 resTexRC = ivec2(resultUV.yx *
                               vec2(outTexShape[0], outTexShape[1]));
        int index = resTexRC.x * outTexShape[1] + resTexRC.y;
        int r = index / outShape[1];
        int c = index - r * outShape[1];
        return ivec2(r, c);
      }
    `;
        }
        return `
      ivec2 getOutputCoords() {
        ivec2 resTexRC = ivec2(resultUV.yx *
                               vec2(${texShape[0]}, ${texShape[1]}));
        int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
        int r = index / ${shape[1]};
        int c = index - r * ${shape[1]};
        return ivec2(r, c);
      }
    `;
    }
85249 function getFlatOffsetUniformName(texName) {
85250 return `offset${texName}`;
85251 }
85252 function getPackedSamplerScalar(inputInfo) {
85253 const texName = inputInfo.name;
85254 const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
85255 const glsl = getGlslDifferences();
85256 return `
85257 vec4 ${funcName}() {
85258 return ${glsl.texture2D}(${texName}, halfCR);
85259 }
85260 `;
85261 }
85262 function getSamplerScalar(inputInfo, enableShapeUniforms) {
85263 const texName = inputInfo.name;
85264 const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
85265 if (inputInfo.shapeInfo.isUniform) {
85266 return `float ${funcName}() {return ${texName};}`;
85267 }
85268 const [texNumR, texNumC] = inputInfo.shapeInfo.texShape;
85269 if (texNumR === 1 && texNumC === 1) {
85270 return `
85271 float ${funcName}() {
85272 return sampleTexture(${texName}, halfCR);
85273 }
85274 `;
85275 }
85276 const offset = getFlatOffsetUniformName(texName);
85277 if (enableShapeUniforms) {
85278 return `
85279 float ${funcName}() {
85280 vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], ${offset});
85281 return sampleTexture(${texName}, uv);
85282 }
85283 `;
85284 }
85285 const [tNumR, tNumC] = inputInfo.shapeInfo.texShape;
85286 return `
85287 float ${funcName}() {
85288 vec2 uv = uvFromFlat(${tNumR}, ${tNumC}, ${offset});
85289 return sampleTexture(${texName}, uv);
85290 }
85291 `;
85292 }
    /**
     * Generates the GLSL sampler `get<Name>(int index)` for a packed 1D
     * input. The flat logical index is converted to a packed-texture UV by
     * the shader helper `packedUVfrom1D`, using packed texture dimensions
     * (ceil(dim / 2), since each texel holds two values).
     * @param inputInfo Input descriptor (`name`, `shapeInfo.texShape`).
     * @param enableShapeUniforms Read texture dimensions from the
     *     `<tex>TexShape` uniforms instead of baking them in.
     * @returns GLSL source for `vec4 get<Name>(int index)`.
     */
    function getPackedSampler1D(inputInfo, enableShapeUniforms) {
        const texName = inputInfo.name;
        const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
        const texShape = inputInfo.shapeInfo.texShape;
        const glsl = getGlslDifferences();
        if (enableShapeUniforms) {
            return `
      vec4 ${funcName}(int index) {
        ivec2 packedTexShape = ivec2(ceil(float(${texName}TexShape[0]) / 2.0), ceil(float(${texName}TexShape[1]) / 2.0));
        vec2 uv = packedUVfrom1D(
          packedTexShape[0], packedTexShape[1], index);
        return ${glsl.texture2D}(${texName}, uv);
      }
    `;
        }
        const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
        return `
      vec4 ${funcName}(int index) {
        vec2 uv = packedUVfrom1D(
          ${packedTexShape[0]}, ${packedTexShape[1]}, index);
        return ${glsl.texture2D}(${texName}, uv);
      }
    `;
    }
    /**
     * Generates the GLSL sampler `get<Name>(int index)` for an unpacked 1D
     * input. Picks the cheapest addressing scheme the storage allows:
     * uniform-array lookup, single texel, single column, single row, or a
     * general flat-index-to-UV conversion via `uvFromFlat`.
     * @param inputInfo Input descriptor (`name`, `shapeInfo`).
     * @param enableShapeUniforms Read texture dimensions from the
     *     `<tex>TexShape` uniforms instead of baking them in.
     * @returns GLSL source for `float get<Name>(int index)`.
     */
    function getSampler1D(inputInfo, enableShapeUniforms) {
        const texName = inputInfo.name;
        const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
        if (inputInfo.shapeInfo.isUniform) {
            // Uniform arrays will be less than 65505 (no risk of float16 overflow).
            return `
      float ${funcName}(int index) {
        ${getUniformSampler(inputInfo)}
      }
    `;
        }
        const texShape = inputInfo.shapeInfo.texShape;
        const tNumR = texShape[0];
        const tNumC = texShape[1];
        if (tNumC === 1 && tNumR === 1) {
            // Whole tensor fits in one texel: sample the texture centre.
            return `
      float ${funcName}(int index) {
        return sampleTexture(${texName}, halfCR);
      }
    `;
        }
        const offset = getFlatOffsetUniformName(texName);
        if (tNumC === 1) {
            // Single physical column: index maps straight to the v coordinate.
            if (enableShapeUniforms) {
                return `
      float ${funcName}(int index) {
        vec2 uv = vec2(0.5, (float(index + ${offset}) + 0.5) / float(${texName}TexShape[0]));
        return sampleTexture(${texName}, uv);
      }
    `;
            }
            return `
      float ${funcName}(int index) {
        vec2 uv = vec2(0.5, (float(index + ${offset}) + 0.5) / ${tNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        if (tNumR === 1) {
            // Single physical row: index maps straight to the u coordinate.
            if (enableShapeUniforms) {
                return `
      float ${funcName}(int index) {
        vec2 uv = vec2((float(index + ${offset}) + 0.5) / float(${texName}TexShape[1]), 0.5);
        return sampleTexture(${texName}, uv);
      }
    `;
            }
            return `
      float ${funcName}(int index) {
        vec2 uv = vec2((float(index + ${offset}) + 0.5) / ${tNumC}.0, 0.5);
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        if (enableShapeUniforms) {
            return `
      float ${funcName}(int index) {
        vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], index + ${offset});
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        return `
      float ${funcName}(int index) {
        vec2 uv = uvFromFlat(${tNumR}, ${tNumC}, index + ${offset});
        return sampleTexture(${texName}, uv);
      }
    `;
    }
    /**
     * Generates the GLSL sampler `get<Name>(int row, int col)` for a packed
     * 2D input. When the logical shape equals the physical texture shape the
     * lookup is a direct UV conversion; otherwise (row, col) is mapped
     * through the shader helper `packedUVfrom2D` with packed (ceil(dim / 2))
     * dimensions.
     * @param inputInfo Input descriptor (`name`, `shapeInfo`).
     * @param enableShapeUniforms Read shapes from `<tex>Shape` /
     *     `<tex>TexShape` uniforms instead of baking them in.
     * @returns GLSL source for `vec4 get<Name>(int row, int col)`.
     */
    function getPackedSampler2D(inputInfo, enableShapeUniforms) {
        const shape = inputInfo.shapeInfo.logicalShape;
        const texName = inputInfo.name;
        const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
        const texShape = inputInfo.shapeInfo.texShape;
        const texNumR = texShape[0];
        const texNumC = texShape[1];
        const glsl = getGlslDifferences();
        if (texShape != null && arraysEqual(shape, texShape)) {
            if (enableShapeUniforms) {
                return `
      vec4 ${funcName}(int row, int col) {
        vec2 uv = (vec2(col, row) + halfCR) / vec2(${texName}TexShape[1], ${texName}TexShape[0]);

        return ${glsl.texture2D}(${texName}, uv);
      }
    `;
            }
            return `
      vec4 ${funcName}(int row, int col) {
        vec2 uv = (vec2(col, row) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0);

        return ${glsl.texture2D}(${texName}, uv);
      }
    `;
        }
        if (enableShapeUniforms) {
            return `
      vec4 ${funcName}(int row, int col) {
        ivec2 packedTexShape = ivec2(ceil(float(${texName}TexShape[0]) / 2.0), ceil(float(${texName}TexShape[1]) / 2.0));
        int valuesPerRow = int(ceil(float(${texName}Shape[1]) / 2.0));
        vec2 uv = packedUVfrom2D(valuesPerRow, packedTexShape[0], packedTexShape[1], row, col);
        return ${glsl.texture2D}(${texName}, uv);
      }
    `;
        }
        // ceil(dim / 2): each packed texel covers two logical columns/rows.
        const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
        const valuesPerRow = Math.ceil(shape[1] / 2);
        return `
      vec4 ${funcName}(int row, int col) {
        vec2 uv = packedUVfrom2D(${valuesPerRow}, ${packedTexShape[0]}, ${packedTexShape[1]}, row, col);
        return ${glsl.texture2D}(${texName}, uv);
      }
    `;
    }
    /**
     * Generates the GLSL sampler `get<Name>(int row, int col)` for an
     * unpacked 2D input. Tries progressively cheaper addressing schemes:
     * direct UV when logical and physical shapes match, delegation to a
     * lower-rank sampler when a dimension can be squeezed away, uniform-array
     * lookup, single-column / single-row textures, and finally a general
     * flat-index-to-UV conversion via `uvFromFlat`.
     * @param inputInfo Input descriptor (`name`, `shapeInfo`).
     * @param enableShapeUniforms Read shapes from `<tex>Shape` /
     *     `<tex>TexShape` uniforms instead of baking them in.
     * @returns GLSL source for `float get<Name>(int row, int col)`.
     */
    function getSampler2D(inputInfo, enableShapeUniforms) {
        const shape = inputInfo.shapeInfo.logicalShape;
        const texName = inputInfo.name;
        const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
        const texShape = inputInfo.shapeInfo.texShape;
        if (texShape != null && arraysEqual(shape, texShape)) {
            if (enableShapeUniforms) {
                return `
      float ${funcName}(int row, int col) {
        vec2 uv = (vec2(col, row) + halfCR) / vec2(${texName}TexShape[1], ${texName}TexShape[0]);
        return sampleTexture(${texName}, uv);
      }
    `;
            }
            const texNumR = texShape[0];
            const texNumC = texShape[1];
            return `
    float ${funcName}(int row, int col) {
      vec2 uv = (vec2(col, row) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0);
      return sampleTexture(${texName}, uv);
    }
  `;
        }
        // If any size-1 dimension can be squeezed away, emit the lower-rank
        // sampler plus a 2-arg wrapper that forwards the kept dimensions.
        const { newShape, keptDims } = squeezeShape(shape);
        const squeezedShape = newShape;
        if (squeezedShape.length < shape.length) {
            const newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
            const params = ['row', 'col'];
            return `
      ${getSamplerFromInInfo(newInputInfo, enableShapeUniforms)}
      float ${funcName}(int row, int col) {
        return ${funcName}(${getSqueezedParams(params, keptDims)});
      }
    `;
        }
        if (inputInfo.shapeInfo.isUniform) {
            // Uniform arrays will be less than 65505 (no risk of float16 overflow).
            return `
      float ${funcName}(int row, int col) {
        int index = round(dot(vec2(row, col), vec2(${shape[1]}, 1)));
        ${getUniformSampler(inputInfo)}
      }
    `;
        }
        const texNumR = texShape[0];
        const texNumC = texShape[1];
        const offset = getFlatOffsetUniformName(texName);
        if (texNumC === 1) {
            // index is used directly as physical (no risk of float16 overflow).
            if (enableShapeUniforms) {
                return `
      float ${funcName}(int row, int col) {
        float index = dot(vec3(row, col, ${offset}), vec3(${texName}Shape[1], 1, 1));
        vec2 uv = vec2(0.5, (index + 0.5) / float(${texName}TexShape[0]));
        return sampleTexture(${texName}, uv);
      }
    `;
            }
            return `
      float ${funcName}(int row, int col) {
        float index = dot(vec3(row, col, ${offset}), vec3(${shape[1]}, 1, 1));
        vec2 uv = vec2(0.5, (index + 0.5) / ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        if (texNumR === 1) {
            // index is used directly as physical (no risk of float16 overflow).
            if (enableShapeUniforms) {
                return `
      float ${funcName}(int row, int col) {
        float index = dot(vec3(row, col, ${offset}), vec3(${texName}Shape[1], 1, 1));
        vec2 uv = vec2((index + 0.5) / float(${texName}TexShape[1]), 0.5);
        return sampleTexture(${texName}, uv);
      }
    `;
            }
            return `
      float ${funcName}(int row, int col) {
        float index = dot(vec3(row, col, ${offset}), vec3(${shape[1]}, 1, 1));
        vec2 uv = vec2((index + 0.5) / ${texNumC}.0, 0.5);
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        if (enableShapeUniforms) {
            return `
      float ${funcName}(int row, int col) {
        // Explicitly use integer operations as dot() only works on floats.
        int index = row * ${texName}Shape[1] + col + ${offset};
        vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], index);
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        return `
    float ${funcName}(int row, int col) {
      // Explicitly use integer operations as dot() only works on floats.
      int index = row * ${shape[1]} + col + ${offset};
      vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
      return sampleTexture(${texName}, uv);
    }
`;
    }
    /**
     * Generates the GLSL sampler `get<Name>(int b, int row, int col)` for a
     * packed 3D input. A leading batch of size 1 delegates to the squeezed
     * 2D sampler; otherwise coordinates are mapped through the shader helper
     * `packedUVfrom3D` using packed (ceil(dim / 2)) dimensions.
     * @param inputInfo Input descriptor (`name`, `shapeInfo`).
     * @param enableShapeUniforms Read shapes from `<tex>Shape` /
     *     `<tex>TexShape` uniforms instead of baking them in.
     * @returns GLSL source for `vec4 get<Name>(int b, int row, int col)`.
     */
    function getPackedSampler3D(inputInfo, enableShapeUniforms) {
        const shape = inputInfo.shapeInfo.logicalShape;
        const texName = inputInfo.name;
        const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
        const texShape = inputInfo.shapeInfo.texShape;
        const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
        if (shape[0] === 1) {
            // Leading batch of 1: emit the 2D sampler and a forwarding wrapper.
            const squeezedShape = shape.slice(1);
            const keptDims = [1, 2];
            const newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
            const params = ['b', 'row', 'col'];
            return `
        ${getPackedSamplerFromInInfo(newInputInfo, enableShapeUniforms)}
        vec4 ${funcName}(int b, int row, int col) {
          return ${funcName}(${getSqueezedParams(params, keptDims)});
        }
      `;
        }
        const glsl = getGlslDifferences();
        if (enableShapeUniforms) {
            return `
      vec4 ${funcName}(int b, int row, int col) {
        ivec2 packedTexShape = ivec2(ceil(float(${texName}TexShape[0]) / 2.0), ceil(float(${texName}TexShape[1]) / 2.0));
        int valuesPerRow = int(ceil(float(${texName}Shape[2]) / 2.0));
        int texelsInBatch = valuesPerRow * int(ceil(float(${texName}Shape[1]) / 2.0));
        vec2 uv = packedUVfrom3D(
          packedTexShape[0], packedTexShape[1], texelsInBatch, valuesPerRow, b, row, col);
        return ${glsl.texture2D}(${texName}, uv);
      }
    `;
        }
        const texNumR = packedTexShape[0];
        const texNumC = packedTexShape[1];
        const valuesPerRow = Math.ceil(shape[2] / 2);
        const texelsInBatch = valuesPerRow * Math.ceil(shape[1] / 2);
        return `
      vec4 ${funcName}(int b, int row, int col) {
        vec2 uv = packedUVfrom3D(
          ${texNumR}, ${texNumC}, ${texelsInBatch}, ${valuesPerRow}, b, row, col);
        return ${glsl.texture2D}(${texName}, uv);
      }
    `;
    }
    /**
     * Generates the GLSL sampler `get<Name>(int row, int col, int depth)`
     * for an unpacked 3D input. Tries: squeeze to a lower-rank sampler,
     * uniform-array lookup, two fast paths where the texture width matches a
     * stride (so one logical coordinate maps directly to a physical axis),
     * and finally a general flat-index-to-UV conversion via `uvFromFlat`.
     * @param inputInfo Input descriptor (`name`, `shapeInfo`).
     * @param enableShapeUniforms Read shapes from `<tex>Shape` /
     *     `<tex>TexShape` uniforms instead of baking them in.
     * @returns GLSL source for `float get<Name>(int row, int col, int depth)`.
     */
    function getSampler3D(inputInfo, enableShapeUniforms) {
        const shape = inputInfo.shapeInfo.logicalShape;
        const texName = inputInfo.name;
        const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
        // Row-major strides of the logical shape.
        const stride0 = shape[1] * shape[2];
        const stride1 = shape[2];
        const { newShape, keptDims } = squeezeShape(shape);
        const squeezedShape = newShape;
        if (squeezedShape.length < shape.length) {
            const newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
            const params = ['row', 'col', 'depth'];
            return `
        ${getSamplerFromInInfo(newInputInfo, enableShapeUniforms)}
        float ${funcName}(int row, int col, int depth) {
          return ${funcName}(${getSqueezedParams(params, keptDims)});
        }
      `;
        }
        if (inputInfo.shapeInfo.isUniform) {
            // Uniform arrays will be less than 65505 (no risk of float16 overflow).
            return `
      float ${funcName}(int row, int col, int depth) {
        int index = round(dot(vec3(row, col, depth),
                          vec3(${stride0}, ${stride1}, 1)));
        ${getUniformSampler(inputInfo)}
      }
    `;
        }
        const texShape = inputInfo.shapeInfo.texShape;
        const texNumR = texShape[0];
        const texNumC = texShape[1];
        const flatOffset = inputInfo.shapeInfo.flatOffset;
        if (texNumC === stride0 && flatOffset == null) {
            // texC is used directly as physical (no risk of float16 overflow).
            if (enableShapeUniforms) {
                return `
      float ${funcName}(int row, int col, int depth) {
        int stride1 = ${texName}Shape[2];
        float texR = float(row);
        float texC = dot(vec2(col, depth), vec2(stride1, 1));
        vec2 uv = (vec2(texC, texR) + halfCR) /
                   vec2(${texName}TexShape[1], ${texName}TexShape[0]);
        return sampleTexture(${texName}, uv);
      }
    `;
            }
            return `
      float ${funcName}(int row, int col, int depth) {
        float texR = float(row);
        float texC = dot(vec2(col, depth), vec2(${stride1}, 1));
        vec2 uv = (vec2(texC, texR) + halfCR) /
                   vec2(${texNumC}.0, ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        if (texNumC === stride1 && flatOffset == null) {
            // texR is used directly as physical (no risk of float16 overflow).
            if (enableShapeUniforms) {
                return `
      float ${funcName}(int row, int col, int depth) {
        float texR = dot(vec2(row, col), vec2(${texName}Shape[1], 1));
        float texC = float(depth);
        vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texName}TexShape[1], ${texName}TexShape[0]);
        return sampleTexture(${texName}, uv);
      }
    `;
            }
            return `
      float ${funcName}(int row, int col, int depth) {
        float texR = dot(vec2(row, col), vec2(${shape[1]}, 1));
        float texC = float(depth);
        vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        const offset = getFlatOffsetUniformName(texName);
        if (enableShapeUniforms) {
            return `
      float ${funcName}(int row, int col, int depth) {
        // Explicitly use integer operations as dot() only works on floats.
        int stride0 = ${texName}Shape[1] * ${texName}Shape[2];
        int stride1 = ${texName}Shape[2];
        int index = row * stride0 + col * stride1 + depth + ${offset};
        vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], index);
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        return `
      float ${funcName}(int row, int col, int depth) {
        // Explicitly use integer operations as dot() only works on floats.
        int index = row * ${stride0} + col * ${stride1} + depth + ${offset};
        vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
        return sampleTexture(${texName}, uv);
      }
    `;
    }
    /**
     * Generates the GLSL sampler for a packed input of rank 4 and above.
     * The two innermost dimensions are packed 2x2 per texel; each extra
     * leading dimension adds a `b<i>` parameter whose contribution is folded
     * into the flat texel index, which is then converted to a UV.
     * @param inputInfo Input descriptor (`name`, `shapeInfo`).
     * @param enableShapeUniforms Read shapes from `<tex>Shape` /
     *     `<tex>TexShape` uniforms (currently only the rank-4 form; see
     *     TODO below).
     * @returns GLSL source for `vec4 get<Name>(...)`.
     */
    function getPackedSamplerND(inputInfo, enableShapeUniforms) {
        const texName = inputInfo.name;
        const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
        const glsl = getGlslDifferences();
        if (enableShapeUniforms) {
            // TODO: support 5d and 6d
            return `
      vec4 ${funcName}(int b2, int b, int row, int col) {
        int valuesPerRow = int(ceil(float(${texName}Shape[3]) / 2.0));
        int texelsInBatch = valuesPerRow * int(ceil(float(${texName}Shape[2]) / 2.0));
        int index = b * texelsInBatch + (row / 2) * valuesPerRow + (col / 2);
        texelsInBatch *= ${texName}Shape[1];
        index = b2 * texelsInBatch + index;
        ivec2 packedTexShape = ivec2(ceil(float(${texName}TexShape[0]) / 2.0), ceil(float(${texName}TexShape[1]) / 2.0));
        int texR = index / packedTexShape[1];
        int texC = index - texR * packedTexShape[1];
        vec2 uv = (vec2(texC, texR) + halfCR) / vec2(packedTexShape[1], packedTexShape[0]); return ${glsl.texture2D}(${texName}, uv);
      }
    `;
        }
        const shape = inputInfo.shapeInfo.logicalShape;
        const rank = shape.length;
        const texShape = inputInfo.shapeInfo.texShape;
        const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
        const texNumR = packedTexShape[0];
        const texNumC = packedTexShape[1];
        const valuesPerRow = Math.ceil(shape[rank - 1] / 2);
        let texelsInBatch = valuesPerRow * Math.ceil(shape[rank - 2] / 2);
        // Build the parameter list and flat-index expression, prepending one
        // `b<i>` term per extra leading dimension (outermost first).
        let params = `int b, int row, int col`;
        let index = `b * ${texelsInBatch} + (row / 2) * ${valuesPerRow} + (col / 2)`;
        for (let b = 2; b < rank - 1; b++) {
            params = `int b${b}, ` + params;
            texelsInBatch *= shape[rank - b - 1];
            index = `b${b} * ${texelsInBatch} + ` + index;
        }
        return `
      vec4 ${funcName}(${params}) {
        int index = ${index};
        int texR = index / ${texNumC};
        int texC = index - texR * ${texNumC};
        vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}, ${texNumR});
        return ${glsl.texture2D}(${texName}, uv);
      }
    `;
    }
    /**
     * Generates the GLSL sampler
     * `get<Name>(int row, int col, int depth, int depth2)` for an unpacked
     * 4D input. Tries: squeeze to a lower-rank sampler, uniform-array
     * lookup, two fast paths where the texture width equals a stride (one
     * logical coordinate maps directly to a physical axis), and finally a
     * general flat-index-to-UV conversion via `uvFromFlat`.
     * @param inputInfo Input descriptor (`name`, `shapeInfo`).
     * @param enableShapeUniforms Read shapes from `<tex>Shape` /
     *     `<tex>TexShape` uniforms instead of baking them in.
     * @returns GLSL source for the 4-argument float sampler.
     */
    function getSampler4D(inputInfo, enableShapeUniforms) {
        const shape = inputInfo.shapeInfo.logicalShape;
        const texName = inputInfo.name;
        const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
        // Row-major strides of the logical shape.
        const stride2 = shape[3];
        const stride1 = shape[2] * stride2;
        const stride0 = shape[1] * stride1;
        const { newShape, keptDims } = squeezeShape(shape);
        if (newShape.length < shape.length) {
            const newInputInfo = squeezeInputInfo(inputInfo, newShape);
            const params = ['row', 'col', 'depth', 'depth2'];
            return `
      ${getSamplerFromInInfo(newInputInfo, enableShapeUniforms)}
      float ${funcName}(int row, int col, int depth, int depth2) {
        return ${funcName}(${getSqueezedParams(params, keptDims)});
      }
    `;
        }
        if (inputInfo.shapeInfo.isUniform) {
            // Uniform arrays will be less than 65505 (no risk of float16 overflow).
            return `
      float ${funcName}(int row, int col, int depth, int depth2) {
        int index = round(dot(vec4(row, col, depth, depth2),
                          vec4(${stride0}, ${stride1}, ${stride2}, 1)));
        ${getUniformSampler(inputInfo)}
      }
    `;
        }
        const flatOffset = inputInfo.shapeInfo.flatOffset;
        const texShape = inputInfo.shapeInfo.texShape;
        const texNumR = texShape[0];
        const texNumC = texShape[1];
        // GLSL snippets computing strides from the shape uniforms; reused in
        // the uniform-enabled branches below.
        const stride2Str = `int stride2 = ${texName}Shape[3];`;
        const stride1Str = `int stride1 = ${texName}Shape[2] * stride2;`;
        const stride0Str = `int stride0 = ${texName}Shape[1] * stride1;`;
        if (texNumC === stride0 && flatOffset == null) {
            // texC is used directly as physical (no risk of float16 overflow).
            if (enableShapeUniforms) {
                return `
      float ${funcName}(int row, int col, int depth, int depth2) {
        ${stride2Str}
        ${stride1Str}
        float texR = float(row);
        float texC =
            dot(vec3(col, depth, depth2),
                vec3(stride1, stride2, 1));
        vec2 uv = (vec2(texC, texR) + halfCR) /
                   vec2(${texName}TexShape[1], ${texName}TexShape[0]);
        return sampleTexture(${texName}, uv);
      }
    `;
            }
            return `
      float ${funcName}(int row, int col, int depth, int depth2) {
        float texR = float(row);
        float texC =
            dot(vec3(col, depth, depth2),
                vec3(${stride1}, ${stride2}, 1));
        vec2 uv = (vec2(texC, texR) + halfCR) /
                   vec2(${texNumC}.0, ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        if (texNumC === stride2 && flatOffset == null) {
            // texR is used directly as physical (no risk of float16 overflow).
            if (enableShapeUniforms) {
                return `
      float ${funcName}(int row, int col, int depth, int depth2) {
        float texR = dot(vec3(row, col, depth),
                         vec3(${texName}Shape[1] * ${texName}Shape[2], ${texName}Shape[2], 1));
        float texC = float(depth2);
        vec2 uv = (vec2(texC, texR) + halfCR) /
                  vec2(${texName}TexShape[1], ${texName}TexShape[0]);
        return sampleTexture(${texName}, uv);
      }
    `;
            }
            return `
      float ${funcName}(int row, int col, int depth, int depth2) {
        float texR = dot(vec3(row, col, depth),
                         vec3(${shape[1] * shape[2]}, ${shape[2]}, 1));
        float texC = float(depth2);
        vec2 uv = (vec2(texC, texR) + halfCR) /
                  vec2(${texNumC}.0, ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        const offset = getFlatOffsetUniformName(texName);
        if (enableShapeUniforms) {
            return `
      float ${funcName}(int row, int col, int depth, int depth2) {
        // Explicitly use integer operations as dot() only works on floats.
        ${stride2Str}
        ${stride1Str}
        ${stride0Str}
        int index = row * stride0 + col * stride1 +
            depth * stride2 + depth2;
        vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], index + ${offset});
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        return `
      float ${funcName}(int row, int col, int depth, int depth2) {
        // Explicitly use integer operations as dot() only works on floats.
        int index = row * ${stride0} + col * ${stride1} +
            depth * ${stride2} + depth2;
        vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index + ${offset});
        return sampleTexture(${texName}, uv);
      }
    `;
    }
    /**
     * Emits a GLSL sampler `float get<Name>(int row, int col, int depth,
     * int depth2, int depth3)` that reads one value of a 5D input from its 2D
     * physical texture (or from a uniform array), choosing the cheapest
     * addressing scheme the input's layout allows.
     */
    function getSampler5D(inputInfo) {
        const shape = inputInfo.shapeInfo.logicalShape;
        const texName = inputInfo.name;
        const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
        // Row-major strides of the 5D logical shape.
        const stride3 = shape[4];
        const stride2 = shape[3] * stride3;
        const stride1 = shape[2] * stride2;
        const stride0 = shape[1] * stride1;
        const { newShape, keptDims } = squeezeShape(shape);
        if (newShape.length < shape.length) {
            // Size-1 dims can be dropped: emit a lower-rank sampler for the
            // squeezed shape and forward only the kept coordinates to it.
            const newInputInfo = squeezeInputInfo(inputInfo, newShape);
            const params = ['row', 'col', 'depth', 'depth2', 'depth3'];
            return `
      ${getSamplerFromInInfo(newInputInfo)}
      float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
        return ${funcName}(${getSqueezedParams(params, keptDims)});
      }
    `;
        }
        if (inputInfo.shapeInfo.isUniform) {
            // Uniform arrays will be less than 65505 (no risk of float16 overflow).
            // NOTE(review): `index` is declared float here, while the body from
            // getUniformSampler compares it against an int loop counter —
            // confirm the targeted GLSL version accepts that comparison.
            return `
      float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
        float index = dot(
          vec4(row, col, depth, depth2),
          vec4(${stride0}, ${stride1}, ${stride2}, ${stride3})) +
          depth3;
        ${getUniformSampler(inputInfo)}
      }
    `;
        }
        const flatOffset = inputInfo.shapeInfo.flatOffset;
        const texShape = inputInfo.shapeInfo.texShape;
        const texNumR = texShape[0];
        const texNumC = texShape[1];
        if (texNumC === stride0 && flatOffset == null) {
            // texC is used directly as physical (no risk of float16 overflow).
            return `
      float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
        int texR = row;
        float texC = dot(vec4(col, depth, depth2, depth3),
                         vec4(${stride1}, ${stride2}, ${stride3}, 1));
        vec2 uv = (vec2(texC, texR) + halfCR) /
                   vec2(${texNumC}.0, ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        if (texNumC === stride3 && flatOffset == null) {
            // texR is used directly as physical (no risk of float16 overflow).
            return `
      float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
        float texR = dot(
          vec4(row, col, depth, depth2),
          vec4(${shape[1] * shape[2] * shape[3]},
            ${shape[2] * shape[3]}, ${shape[3]}, 1));
        int texC = depth3;
        vec2 uv = (vec2(texC, texR) + halfCR) /
                  vec2(${texNumC}.0, ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        // General case: flatten the 5D index (plus any slice offset uniform)
        // and convert it to a texel coordinate via uvFromFlat.
        const offset = getFlatOffsetUniformName(texName);
        return `
    float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
      // Explicitly use integer operations as dot() only works on floats.
      int index = row * ${stride0} + col * ${stride1} + depth * ${stride2} +
          depth2 * ${stride3} + depth3 + ${offset};
      vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
      return sampleTexture(${texName}, uv);
    }
  `;
    }
    /**
     * Emits a GLSL sampler `float get<Name>(int row, int col, int depth,
     * int depth2, int depth3, int depth4)` that reads one value of a 6D input
     * from its 2D physical texture (or from a uniform array), choosing the
     * cheapest addressing scheme the input's layout allows.
     */
    function getSampler6D(inputInfo) {
        const shape = inputInfo.shapeInfo.logicalShape;
        const texName = inputInfo.name;
        const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
        const { newShape, keptDims } = squeezeShape(shape);
        if (newShape.length < shape.length) {
            // Size-1 dims can be dropped: emit a lower-rank sampler for the
            // squeezed shape and forward only the kept coordinates to it.
            const newInputInfo = squeezeInputInfo(inputInfo, newShape);
            const params = ['row', 'col', 'depth', 'depth2', 'depth3', 'depth4'];
            return `
      ${getSamplerFromInInfo(newInputInfo)}
      float ${funcName}(int row, int col, int depth,
        int depth2, int depth3, int depth4) {
        return ${funcName}(${getSqueezedParams(params, keptDims)});
      }
    `;
        }
        // Row-major strides of the 6D logical shape.
        const stride4 = shape[5];
        const stride3 = shape[4] * stride4;
        const stride2 = shape[3] * stride3;
        const stride1 = shape[2] * stride2;
        const stride0 = shape[1] * stride1;
        if (inputInfo.shapeInfo.isUniform) {
            // Uniform arrays will be less than 65505 (no risk of float16 overflow).
            return `
      float ${funcName}(int row, int col, int depth,
        int depth2, int depth3, int depth4) {
        int index = round(dot(
          vec4(row, col, depth, depth2),
          vec4(${stride0}, ${stride1}, ${stride2}, ${stride3})) +
          dot(
            vec2(depth3, depth4),
            vec2(${stride4}, 1)));
        ${getUniformSampler(inputInfo)}
      }
    `;
        }
        const flatOffset = inputInfo.shapeInfo.flatOffset;
        const texShape = inputInfo.shapeInfo.texShape;
        const texNumR = texShape[0];
        const texNumC = texShape[1];
        if (texNumC === stride0 && flatOffset == null) {
            // texC is used directly as physical (no risk of float16 overflow).
            return `
      float ${funcName}(int row, int col, int depth,
        int depth2, int depth3, int depth4) {
        int texR = row;
        float texC = dot(vec4(col, depth, depth2, depth3),
          vec4(${stride1}, ${stride2}, ${stride3}, ${stride4})) +
               float(depth4);
        vec2 uv = (vec2(texC, texR) + halfCR) /
                   vec2(${texNumC}.0, ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        if (texNumC === stride4 && flatOffset == null) {
            // texR is used directly as physical (no risk of float16 overflow).
            return `
      float ${funcName}(int row, int col, int depth,
        int depth2, int depth3, int depth4) {
        float texR = dot(vec4(row, col, depth, depth2),
          vec4(${shape[1] * shape[2] * shape[3] * shape[4]},
               ${shape[2] * shape[3] * shape[4]},
               ${shape[3] * shape[4]},
               ${shape[4]})) + float(depth3);
        int texC = depth4;
        vec2 uv = (vec2(texC, texR) + halfCR) /
                  vec2(${texNumC}.0, ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        // General case: flatten the 6D index (plus any slice offset uniform)
        // and convert it to a texel coordinate via uvFromFlat.
        const offset = getFlatOffsetUniformName(texName);
        return `
    float ${funcName}(int row, int col, int depth,
      int depth2, int depth3, int depth4) {
      // Explicitly use integer operations as dot() only works on floats.
      int index = row * ${stride0} + col * ${stride1} + depth * ${stride2} +
          depth2 * ${stride3} + depth3 * ${stride4} + depth4 + ${offset};
      vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
      return sampleTexture(${texName}, uv);
    }
  `;
    }
85994 function getUniformSampler(inputInfo) {
85995 const texName = inputInfo.name;
85996 const inSize = sizeFromShape(inputInfo.shapeInfo.logicalShape);
85997 if (inSize < 2) {
85998 return `return ${texName};`;
85999 }
86000 return `
86001 for (int i = 0; i < ${inSize}; i++) {
86002 if (i == index) {
86003 return ${texName}[i];
86004 }
86005 }
86006 `;
86007 }
    /**
     * Emits a GLSL function `vec4 get<Name>AtOutCoords()` that samples a packed
     * input at the current output coordinates, handling broadcasting between
     * the input and output shapes.
     *
     * Broadcasted dimensions are zeroed in the output coordinates before the
     * lookup; the fetched packed texel is then rearranged (the `output`
     * snippet) so its four channels line up with the output's packing.
     */
    function getPackedSamplerAtOutputCoords(inputInfo, outShapeInfo) {
        const texName = inputInfo.name;
        const texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
        const funcName = 'get' + texFuncSnippet + 'AtOutCoords';
        const inRank = inputInfo.shapeInfo.logicalShape.length;
        const outRank = outShapeInfo.logicalShape.length;
        const broadcastDims = getBroadcastDims(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape);
        const type = getCoordsDataType(outRank);
        const rankDiff = outRank - inRank;
        let coordsSnippet;
        // Component names of the output-coords vector, by dimension index.
        const fields = ['x', 'y', 'z', 'w', 'u', 'v'];
        if (inRank === 0) {
            // Scalar input: coordinates are irrelevant.
            coordsSnippet = '';
        }
        else if (outRank < 2 && broadcastDims.length >= 1) {
            // 1D/scalar output coords are a bare int, not a vector.
            coordsSnippet = 'coords = 0;';
        }
        else {
            // Zero out each broadcasted dimension (shifted by the rank gap).
            coordsSnippet =
                broadcastDims.map(d => `coords.${fields[d + rankDiff]} = 0;`)
                    .join('\n');
        }
        let unpackedCoordsSnippet = '';
        if (outRank < 2 && inRank > 0) {
            unpackedCoordsSnippet = 'coords';
        }
        else {
            // Pass only the trailing `inRank` components of the output coords.
            unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape
                .map((s, i) => `coords.${fields[i + rankDiff]}`)
                .join(', ');
        }
        // Default: the packed texel already matches the output's packing.
        let output = `return outputValue;`;
        const inSize = sizeFromShape(inputInfo.shapeInfo.logicalShape);
        const isInputScalar = inSize === 1;
        const outSize = sizeFromShape(outShapeInfo.logicalShape);
        const isOutputScalar = outSize === 1;
        if (inRank === 1 && !isInputScalar && !isOutputScalar) {
            // 1D input broadcast across rows: duplicate the row pair.
            output = `
      return vec4(outputValue.xy, outputValue.xy);
    `;
        }
        else if (isInputScalar && !isOutputScalar) {
            if (outRank === 1) {
                output = `
        return vec4(outputValue.x, outputValue.x, 0., 0.);
      `;
            }
            else {
                // Splat the scalar into all four channels.
                output = `
        return vec4(outputValue.x);
      `;
            }
        }
        else if (broadcastDims.length) {
            // Broadcasting over the last two (packed) dims requires channel
            // duplication so each output texel lane sees the right value.
            const rows = inRank - 2;
            const cols = inRank - 1;
            if (broadcastDims.indexOf(rows) > -1 && broadcastDims.indexOf(cols) > -1) {
                output = `return vec4(outputValue.x);`;
            }
            else if (broadcastDims.indexOf(rows) > -1) {
                output = `return vec4(outputValue.x, outputValue.y, ` +
                    `outputValue.x, outputValue.y);`;
            }
            else if (broadcastDims.indexOf(cols) > -1) {
                output = `return vec4(outputValue.xx, outputValue.zz);`;
            }
        }
        return `
    vec4 ${funcName}() {
      ${type} coords = getOutputCoords();
      ${coordsSnippet}
      vec4 outputValue = get${texFuncSnippet}(${unpackedCoordsSnippet});
      ${output}
    }
  `;
    }
    /**
     * Emits a GLSL function `float get<Name>AtOutCoords()` that samples an
     * unpacked input at the current output coordinates, handling broadcasting.
     *
     * Fast path: when the input and output share rank and texture shape (and
     * the input is neither a uniform nor a slice with a flat offset), the
     * output's resultUV can be used to sample the input directly.
     */
    function getSamplerAtOutputCoords(inputInfo, outShapeInfo) {
        const texName = inputInfo.name;
        const texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
        const funcName = 'get' + texFuncSnippet + 'AtOutCoords';
        const outTexShape = outShapeInfo.texShape;
        const inTexShape = inputInfo.shapeInfo.texShape;
        const inRank = inputInfo.shapeInfo.logicalShape.length;
        const outRank = outShapeInfo.logicalShape.length;
        if (!inputInfo.shapeInfo.isUniform && inRank === outRank &&
            inputInfo.shapeInfo.flatOffset == null &&
            arraysEqual(inTexShape, outTexShape)) {
            return `
      float ${funcName}() {
        return sampleTexture(${texName}, resultUV);
      }
    `;
        }
        const type = getCoordsDataType(outRank);
        const broadcastDims = getBroadcastDims(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape);
        const rankDiff = outRank - inRank;
        let coordsSnippet;
        // Component names of the output-coords vector, by dimension index.
        const fields = ['x', 'y', 'z', 'w', 'u', 'v'];
        if (inRank === 0) {
            // Scalar input: coordinates are irrelevant.
            coordsSnippet = '';
        }
        else if (outRank < 2 && broadcastDims.length >= 1) {
            // 1D/scalar output coords are a bare int, not a vector.
            coordsSnippet = 'coords = 0;';
        }
        else {
            // Zero out each broadcasted dimension (shifted by the rank gap).
            coordsSnippet =
                broadcastDims.map(d => `coords.${fields[d + rankDiff]} = 0;`)
                    .join('\n');
        }
        let unpackedCoordsSnippet = '';
        if (outRank < 2 && inRank > 0) {
            unpackedCoordsSnippet = 'coords';
        }
        else {
            // Pass only the trailing `inRank` components of the output coords.
            unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape
                .map((s, i) => `coords.${fields[i + rankDiff]}`)
                .join(', ');
        }
        return `
    float ${funcName}() {
      ${type} coords = getOutputCoords();
      ${coordsSnippet}
      return get${texFuncSnippet}(${unpackedCoordsSnippet});
    }
  `;
    }
86134 function getCoordsDataType(rank) {
86135 if (rank <= 1) {
86136 return 'int';
86137 }
86138 else if (rank === 2) {
86139 return 'ivec2';
86140 }
86141 else if (rank === 3) {
86142 return 'ivec3';
86143 }
86144 else if (rank === 4) {
86145 return 'ivec4';
86146 }
86147 else if (rank === 5) {
86148 return 'ivec5';
86149 }
86150 else if (rank === 6) {
86151 return 'ivec6';
86152 }
86153 else {
86154 throw Error(`GPU for rank ${rank} is not yet supported`);
86155 }
86156 }
86157 function getUniformInfoFromShape(isPacked, shape, texShape) {
86158 const { newShape, keptDims } = squeezeShape(shape);
86159 const rank = shape.length;
86160 const useSqueezePackedShape = isPacked && rank === 3 && shape[0] === 1;
86161 const squeezeShape$1 = useSqueezePackedShape ? shape.slice(1) : newShape;
86162 const useSqueezeShape = (!isPacked && rank > 1 && !arraysEqual(shape, texShape) &&
86163 newShape.length < rank) ||
86164 useSqueezePackedShape;
86165 const uniformShape = useSqueezeShape ? squeezeShape$1 : shape;
86166 return { useSqueezeShape, uniformShape, keptDims };
86167 }
86168 /** Returns a new input info (a copy) that has a squeezed logical shape. */
86169 function squeezeInputInfo(inInfo, squeezedShape) {
86170 // Deep copy.
86171 const newInputInfo = JSON.parse(JSON.stringify(inInfo));
86172 newInputInfo.shapeInfo.logicalShape = squeezedShape;
86173 return newInputInfo;
86174 }
86175 function getSqueezedParams(params, keptDims) {
86176 return keptDims.map(d => params[d]).join(', ');
86177 }
86178
86179 /**
86180 * @license
86181 * Copyright 2017 Google LLC. All Rights Reserved.
86182 * Licensed under the Apache License, Version 2.0 (the "License");
86183 * you may not use this file except in compliance with the License.
86184 * You may obtain a copy of the License at
86185 *
86186 * http://www.apache.org/licenses/LICENSE-2.0
86187 *
86188 * Unless required by applicable law or agreed to in writing, software
86189 * distributed under the License is distributed on an "AS IS" BASIS,
86190 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86191 * See the License for the specific language governing permissions and
86192 * limitations under the License.
86193 * =============================================================================
86194 */
    /**
     * Compiles `program`'s fragment shader for the given inputs/output and
     * returns the binary: the GPGPU program plus shader source, shape infos and
     * (unless compiling only) the resolved uniform locations.
     *
     * In ENGINE_COMPILE_ONLY mode the VAO is not built and all uniform
     * locations are returned as null, to be resolved later.
     */
    function compileProgram(gpgpu, program, inputs, output) {
        const inputInfos = inputs.map((input, i) => {
            const shapeInfo = {
                logicalShape: input.shape,
                // Uniform-uploaded inputs have no backing texture.
                texShape: input.isUniform ? null : input.texData.texShape,
                isUniform: input.isUniform,
                isPacked: input.isUniform ? false : input.texData.isPacked,
                flatOffset: null
            };
            // Sliced tensors carry a flat offset into the shared texture.
            if (input.texData != null && input.texData.slice != null &&
                input.texData.slice.flatOffset > 0) {
                shapeInfo.flatOffset = input.texData.slice.flatOffset;
            }
            return { name: program.variableNames[i], shapeInfo };
        });
        const inShapeInfos = inputInfos.map(x => x.shapeInfo);
        const outShapeInfo = {
            logicalShape: output.shape,
            texShape: output.texData.texShape,
            isUniform: false,
            isPacked: output.texData.isPacked,
            flatOffset: null
        };
        const source = makeShader(inputInfos, outShapeInfo, program);
        const fragmentShader = createFragmentShader(gpgpu.gl, source);
        const webGLProgram = gpgpu.createProgram(fragmentShader);
        if (!env().get('ENGINE_COMPILE_ONLY')) {
            gpgpu.buildVao(webGLProgram);
            return Object.assign({ program,
                fragmentShader,
                source,
                webGLProgram,
                inShapeInfos,
                outShapeInfo }, getUniformLocations(gpgpu, program, webGLProgram));
        }
        else {
            // Compile-only: defer uniform-location lookup and VAO creation.
            return {
                program,
                fragmentShader,
                source,
                webGLProgram,
                inShapeInfos,
                outShapeInfo,
                variablesLocations: null,
                customUniformLocations: null,
                infLoc: null,
                nanLoc: null,
                outShapeLocation: null,
                outShapeStridesLocation: null,
                outTexShapeLocation: null
            };
        }
    }
    /**
     * Resolves all uniform locations for a linked WebGL program: the special
     * NAN/INFINITY uniforms, per-variable sampler/offset (and, with shape
     * uniforms enabled, shape/texShape) locations, output shape locations, and
     * any program-specific custom uniforms.
     *
     * Lookups do not throw: the compiler may have eliminated unused uniforms,
     * in which case the location is null and is skipped at run time.
     */
    function getUniformLocations(gpgpu, program, webGLProgram) {
        const variablesLocations = [];
        const customUniformLocations = [];
        let outShapeLocation;
        let outTexShapeLocation;
        let outShapeStridesLocation;
        let infLoc = null;
        let nanLoc = null;
        // Add special uniforms (NAN, INFINITY)
        nanLoc = gpgpu.getUniformLocation(webGLProgram, 'NAN', false);
        // INFINITY is only emitted for WebGL1 shaders.
        if (env().getNumber('WEBGL_VERSION') === 1) {
            infLoc = gpgpu.getUniformLocation(webGLProgram, 'INFINITY', false);
        }
        // Add user-defined uniforms
        const shouldThrow = false;
        for (const varName of program.variableNames) {
            const varLocs = {
                name: varName,
                uniform: gpgpu.getUniformLocation(webGLProgram, varName, shouldThrow),
                offset: gpgpu.getUniformLocation(webGLProgram, `offset${varName}`, shouldThrow),
            };
            if (program.enableShapeUniforms) {
                varLocs.shape = gpgpu.getUniformLocation(webGLProgram, `${varName}Shape`, shouldThrow);
                varLocs.texShape = gpgpu.getUniformLocation(webGLProgram, `${varName}TexShape`, shouldThrow);
            }
            variablesLocations.push(varLocs);
        }
        if (program.enableShapeUniforms) {
            outShapeLocation =
                gpgpu.getUniformLocation(webGLProgram, 'outShape', shouldThrow);
            outShapeStridesLocation =
                gpgpu.getUniformLocation(webGLProgram, 'outShapeStrides', shouldThrow);
            outTexShapeLocation =
                gpgpu.getUniformLocation(webGLProgram, 'outTexShape', shouldThrow);
        }
        if (program.customUniforms) {
            for (const d of program.customUniforms) {
                customUniformLocations.push(gpgpu.getUniformLocation(webGLProgram, d.name, shouldThrow));
            }
        }
        return {
            variablesLocations,
            customUniformLocations,
            infLoc,
            nanLoc,
            outShapeLocation,
            outShapeStridesLocation,
            outTexShapeLocation
        };
    }
86298 function validateBinaryAndProgram(shapeInfos, inputs) {
86299 if (shapeInfos.length !== inputs.length) {
86300 throw Error(`Binary was compiled with ${shapeInfos.length} inputs, but ` +
86301 `was executed with ${inputs.length} inputs`);
86302 }
86303 shapeInfos.forEach((s, i) => {
86304 const shapeA = s.logicalShape;
86305 const input = inputs[i];
86306 const shapeB = input.shape;
86307 if (!arraysEqual(shapeA, shapeB)) {
86308 throw Error(`Binary was compiled with different shapes than ` +
86309 `the current args. Shapes ${shapeA} and ${shapeB} must match`);
86310 }
86311 // The input is uploaded as uniform.
86312 if (s.isUniform && input.isUniform) {
86313 return;
86314 }
86315 const texShapeA = s.texShape;
86316 const texShapeB = input.isUniform ? null : input.texData.texShape;
86317 if (!arraysEqual(texShapeA, texShapeB)) {
86318 throw Error(`Binary was compiled with different texture shapes than the` +
86319 ` current args. Shape ${texShapeA} and ${texShapeB} must match`);
86320 }
86321 });
86322 }
    /**
     * Executes a compiled binary: binds the output texture, activates the
     * program/VAO, uploads all uniforms (special, per-input, output shape,
     * custom) and input textures, then draws.
     *
     * Shape validation is skipped when shape uniforms are enabled, since the
     * same binary is then reused across shapes.
     */
    function runProgram(gpgpu, binary, inputs, output, customUniformValues) {
        if (!binary.program.enableShapeUniforms) {
            validateBinaryAndProgram(binary.inShapeInfos, inputs);
            validateBinaryAndProgram([binary.outShapeInfo], [output]);
        }
        const outTex = output.texData.texture;
        const outTexShape = output.texData.texShape;
        // Bind the output as the render target (packed vs unpacked layout).
        if (output.texData.isPacked) {
            gpgpu.setOutputPackedMatrixTexture(outTex.texture, outTexShape[0], outTexShape[1]);
        }
        else {
            gpgpu.setOutputMatrixTexture(outTex.texture, outTexShape[0], outTexShape[1]);
        }
        gpgpu.setProgram(binary.webGLProgram);
        gpgpu.bindVertexArray(binary.webGLProgram.vao);
        // Set special uniforms (NAN, INFINITY)
        // WebGL1 has no shader literal for infinity, so it is passed in.
        if (env().getNumber('WEBGL_VERSION') === 1) {
            if (binary.infLoc !== null) {
                gpgpu.gl.uniform1f(binary.infLoc, Infinity);
            }
        }
        if (binary.nanLoc !== null) {
            gpgpu.gl.uniform1f(binary.nanLoc, NaN);
        }
        // Set user-defined inputs
        for (let i = 0; i < inputs.length; ++i) {
            const input = inputs[i];
            const { uniform: varLoc, offset: varOffsetLoc, shape: varShapeLoc, texShape: varTexShapeLoc, } = binary.variablesLocations[i];
            // Shape uniforms (only present when enableShapeUniforms is on).
            if (varShapeLoc) {
                const { uniformShape } = getUniformInfoFromShape(binary.program.packedInputs, input.shape, input.texData.texShape);
                switch (uniformShape.length) {
                    case 1:
                        gpgpu.gl.uniform1iv(varShapeLoc, new Int32Array(uniformShape));
                        break;
                    case 2:
                        gpgpu.gl.uniform2iv(varShapeLoc, new Int32Array(uniformShape));
                        break;
                    case 3:
                        gpgpu.gl.uniform3iv(varShapeLoc, new Int32Array(uniformShape));
                        break;
                    case 4:
                        gpgpu.gl.uniform4iv(varShapeLoc, new Int32Array(uniformShape));
                        break;
                    default:
                        break;
                }
            }
            if (varTexShapeLoc) {
                gpgpu.gl.uniform2i(varTexShapeLoc, input.texData.texShape[0], input.texData.texShape[1]);
            }
            if (varLoc == null) {
                // The compiler inferred that this variable is not used in this shader.
                continue;
            }
            if (input.isUniform) {
                // Upload the values of the tensor as uniform.
                if (sizeFromShape(input.shape) < 2) {
                    gpgpu.gl.uniform1f(varLoc, input.uniformValues[0]);
                }
                else {
                    let vals = input.uniformValues;
                    if (!(vals instanceof Float32Array)) {
                        vals = new Float32Array(vals);
                    }
                    gpgpu.gl.uniform1fv(varLoc, vals);
                }
                continue;
            }
            // If the input was sliced, upload the flat offset index.
            if (input.texData.slice != null && varOffsetLoc != null) {
                gpgpu.gl.uniform1i(varOffsetLoc, input.texData.slice.flatOffset);
            }
            gpgpu.setInputMatrixTexture(input.texData.texture.texture, varLoc, i);
        }
        // Output shape uniform (sized by the output's rank).
        const outShapeLoc = binary.outShapeLocation;
        if (outShapeLoc) {
            switch (output.shape.length) {
                case 1:
                    gpgpu.gl.uniform1iv(outShapeLoc, new Int32Array(output.shape));
                    break;
                case 2:
                    gpgpu.gl.uniform2iv(outShapeLoc, new Int32Array(output.shape));
                    break;
                case 3:
                    gpgpu.gl.uniform3iv(outShapeLoc, new Int32Array(output.shape));
                    break;
                case 4:
                    gpgpu.gl.uniform4iv(outShapeLoc, new Int32Array(output.shape));
                    break;
                default:
                    break;
            }
        }
        // Output strides uniform: rank N has N-1 strides, hence the shifted
        // switch (rank 2 -> uniform1iv, rank 4 -> uniform3iv).
        if (binary.outShapeStridesLocation) {
            const strides = computeStrides(output.shape);
            switch (output.shape.length) {
                case 2:
                    gpgpu.gl.uniform1iv(binary.outShapeStridesLocation, new Int32Array(strides));
                    break;
                case 3:
                    gpgpu.gl.uniform2iv(binary.outShapeStridesLocation, new Int32Array(strides));
                    break;
                case 4:
                    gpgpu.gl.uniform3iv(binary.outShapeStridesLocation, new Int32Array(strides));
                    break;
                default:
                    break;
            }
        }
        if (binary.outTexShapeLocation) {
            gpgpu.gl.uniform2i(binary.outTexShapeLocation, output.texData.texShape[0], output.texData.texShape[1]);
        }
        // Program-specific custom uniforms, dispatched by declared GLSL type.
        if (binary.program.customUniforms && customUniformValues) {
            for (let i = 0; i < binary.program.customUniforms.length; ++i) {
                const d = binary.program.customUniforms[i];
                const customLoc = binary.customUniformLocations[i];
                const customValue = customUniformValues[i];
                if (d.type === 'float') {
                    gpgpu.gl.uniform1fv(customLoc, customValue);
                }
                else if (d.type === 'vec2') {
                    gpgpu.gl.uniform2fv(customLoc, customValue);
                }
                else if (d.type === 'vec3') {
                    gpgpu.gl.uniform3fv(customLoc, customValue);
                }
                else if (d.type === 'vec4') {
                    gpgpu.gl.uniform4fv(customLoc, customValue);
                }
                else if (d.type === 'int') {
                    gpgpu.gl.uniform1iv(customLoc, customValue);
                }
                else if (d.type === 'ivec2') {
                    gpgpu.gl.uniform2iv(customLoc, customValue);
                }
                else if (d.type === 'ivec3') {
                    gpgpu.gl.uniform3iv(customLoc, customValue);
                }
                else if (d.type === 'ivec4') {
                    gpgpu.gl.uniform4iv(customLoc, customValue);
                }
                else {
                    throw Error(`uniform type ${d.type} is not supported yet.`);
                }
            }
        }
        gpgpu.executeProgram();
    }
    /**
     * Builds the cache key under which a compiled shader binary is stored.
     *
     * The key must capture every input/output property that shader_compiler
     * bakes into the generated GLSL, so two tensors that would produce
     * different shaders never share a key. With shape uniforms enabled the key
     * deliberately omits concrete shapes so binaries are reused across shapes.
     */
    function makeShaderKey(program, inputs, output) {
        let keyInputs = '';
        inputs.concat(output).forEach(x => {
            const hasOffset = x.texData != null && x.texData.slice != null &&
                x.texData.slice.flatOffset > 0;
            // TODO: Remove the condition of !x.isUniform.
            if (program.enableShapeUniforms && !x.isUniform) {
                const xTexShape = x.texData.texShape;
                const { useSqueezeShape, uniformShape, keptDims } = getUniformInfoFromShape(program.packedInputs, x.shape, xTexShape);
                let rank1 = '', rank2 = '', rank34 = '';
                if (uniformShape.length === 1 && program.packedInputs) {
                    const packedTexShape = [Math.ceil(xTexShape[0] / 2), Math.ceil(xTexShape[1] / 2)];
                    rank1 = `${packedTexShape[0] > 1}_${packedTexShape[1] > 1}`;
                }
                else if (uniformShape.length === 2 && !program.packedInputs) {
                    rank2 = `${uniformShape[0] > 1}_${uniformShape[1] > 1}`;
                }
                else if (uniformShape.length > 2 && !program.packedInputs) {
                    const strides = computeStrides(uniformShape);
                    rank34 = `${strides[0] === xTexShape[1]}_${strides[strides.length - 1] === xTexShape[1]}`;
                }
                const xRank = x.shape.length;
                const isLogicalShapTexShapeEqual = uniformShape.length === 2 && arraysEqual(x.shape, xTexShape);
                const isScalar = sizeFromShape(x.shape) === 1;
                const broadcastDims = getBroadcastDims$1(x.shape, output.shape);
                const isInOutTexShapeEqual = !program.packedInputs &&
                    xRank === output.shape.length &&
                    arraysEqual(xTexShape, output.texData.texShape);
                const isTexShapeGreaterThanOne = program.packedInputs || uniformShape.length > 2 ?
                    '' :
                    `${xTexShape[0] > 1}_${xTexShape[1] > 1}`;
                // These key components are needed due to shader_compiler is embedding
                // them in the shader.
                // |xRank| is used to determine the coords length. See
                // get[Packed]SamplerAtOutputCoords.
                // |isInOutTexShapeEqual| is used to determine whether going to an
                // optimization path in getSamplerAtOutputCoords.
                // |useSqueezeShape| is extracted from squeezeInputInfo of
                // getSampler[2|3|4]D/getPackedSampler3D.
                // |isScalar| is extracted from isInputScalar/isOutputScalar in
                // getPackedSamplerAtOutputCoords.
                // |broadcastDims| is extracted from get[Packed]SamplerAtOutputCoords.
                // |isLogicalShapTexShapeEqual| is used in
                // getOutput[Packed]2DCoords/get[Packed]Sampler2D.
                // |rank1| is used in getOutputPacked1DCoords.
                // |rank2| is used in getOutput2DCoords.
                // |rank34| is used in getSampler3D/getSampler4D.
                // |isTexShapeGreaterThanOne| are used in
                // getSampler[Scalar|1D|2D]/getOutput1DCoords.
                keyInputs += `${xRank}_${isInOutTexShapeEqual}_${useSqueezeShape ? keptDims : ''}_${uniformShape.length}_${isScalar}_${broadcastDims}_${isLogicalShapTexShapeEqual}_${rank1}_${rank2}_${rank34}_${isTexShapeGreaterThanOne}_${hasOffset}`;
            }
            else {
                // Without shape uniforms the concrete shapes are baked into the
                // shader, so they must be part of the key.
                const texShape = x.isUniform ? 'uniform' : x.texData.texShape;
                keyInputs += `${x.shape}_${texShape}_${hasOffset}`;
            }
        });
        const keyUserCode = program.userCode;
        let key = program.constructor.name;
        // Fast string concat. See https://jsperf.com/string-concatenation/14.
        key += '_' + keyInputs + '_' + keyUserCode +
            `${env().getNumber('WEBGL_VERSION')}`;
        return key;
    }
86534 function useShapeUniforms(rank) {
86535 // TODO: Remove the limitaion of rank <= 4.
86536 return env().getBool('WEBGL_USE_SHAPES_UNIFORMS') && rank <= 4;
86537 }
86538
86539 /**
86540 * @license
86541 * Copyright 2019 Google LLC. All Rights Reserved.
86542 * Licensed under the Apache License, Version 2.0 (the "License");
86543 * you may not use this file except in compliance with the License.
86544 * You may obtain a copy of the License at
86545 *
86546 * http://www.apache.org/licenses/LICENSE-2.0
86547 *
86548 * Unless required by applicable law or agreed to in writing, software
86549 * distributed under the License is distributed on an "AS IS" BASIS,
86550 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86551 * See the License for the specific language governing permissions and
86552 * limitations under the License.
86553 * =============================================================================
86554 */
    /**
     * GPGPU program that re-encodes an unpacked input texture into the dense
     * packed output layout: each output texel's four channels hold four
     * consecutive flat-index values of the logical tensor.
     */
    class DecodeMatrixProgram {
        constructor(outputShape) {
            this.variableNames = ['A'];
            this.packedInputs = false;
            this.packedOutput = true;
            // DENSE packing: output texels are filled in flat-index order.
            this.outPackingScheme = PackingScheme.DENSE;
            // The physical texture size is passed at run time so the program
            // can be reused across texture shapes.
            this.customUniforms = [{ name: 'texShape', type: 'ivec2' }];
            const glsl = getGlslDifferences();
            this.outputShape = outputShape;
            this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
            // With shape uniforms, flat-index decoding reads the shape from
            // uniforms; otherwise the shape is baked into the shader source.
            this.userCode = `
      ivec3 outCoordsFromFlatIndex(int index) {
        ${this.enableShapeUniforms ?
                getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], outputShape) :
                getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape)}
        return ivec3(r, c, d);
      }

      void main() {
        ivec2 resTexRC = ivec2(resultUV.yx * vec2(texShape[0], texShape[1]));
        int index = 4 * (resTexRC.x * texShape[1] + resTexRC.y);

        vec4 result = vec4(0.);

        for (int i=0; i<4; i++) {
          int flatIndex = index + i;
          ivec3 rc = outCoordsFromFlatIndex(flatIndex);
          result[i] = getA(rc.x, rc.y, rc.z);
        }

        ${glsl.output} = result;
      }
    `;
        }
    }
86590
86591 /**
86592 * @license
86593 * Copyright 2019 Google LLC. All Rights Reserved.
86594 * Licensed under the Apache License, Version 2.0 (the "License");
86595 * you may not use this file except in compliance with the License.
86596 * You may obtain a copy of the License at
86597 *
86598 * http://www.apache.org/licenses/LICENSE-2.0
86599 *
86600 * Unless required by applicable law or agreed to in writing, software
86601 * distributed under the License is distributed on an "AS IS" BASIS,
86602 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86603 * See the License for the specific language governing permissions and
86604 * limitations under the License.
86605 * =============================================================================
86606 */
    /**
     * Packed-input counterpart of DecodeMatrixProgram: re-encodes a packed
     * input texture into the dense packed output layout, extracting the right
     * channel of each source texel via getChannel.
     */
    class DecodeMatrixPackedProgram {
        constructor(outputShape) {
            this.variableNames = ['A'];
            this.packedInputs = true;
            this.packedOutput = true;
            // DENSE packing: output texels are filled in flat-index order.
            this.outPackingScheme = PackingScheme.DENSE;
            // The physical texture size is passed at run time so the program
            // can be reused across texture shapes.
            this.customUniforms = [{ name: 'texShape', type: 'ivec2' }];
            const glsl = getGlslDifferences();
            this.outputShape = outputShape;
            this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
            // With shape uniforms, flat-index decoding reads the shape from
            // uniforms; otherwise the shape is baked into the shader source.
            this.userCode = `
      ivec3 outCoordsFromFlatIndex(int index) {
        ${this.enableShapeUniforms ?
                getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], outputShape) :
                getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape)}
        return ivec3(r, c, d);
      }

      void main() {
        ivec2 resTexRC = ivec2(resultUV.yx * vec2(texShape[0], texShape[1]));
        int index = 4 * (resTexRC.x * texShape[1] + resTexRC.y);

        vec4 result = vec4(0.);

        for (int i=0; i<4; i++) {
          int flatIndex = index + i;
          ivec3 rc = outCoordsFromFlatIndex(flatIndex);
          result[i] = getChannel(getA(rc.x, rc.y, rc.z), vec2(rc.y, rc.z));
        }

        ${glsl.output} = result;
      }
    `;
        }
    }
86642
86643 /**
86644 * @license
86645 * Copyright 2018 Google LLC. All Rights Reserved.
86646 * Licensed under the Apache License, Version 2.0 (the "License");
86647 * you may not use this file except in compliance with the License.
86648 * You may obtain a copy of the License at
86649 *
86650 * http://www.apache.org/licenses/LICENSE-2.0
86651 *
86652 * Unless required by applicable law or agreed to in writing, software
86653 * distributed under the License is distributed on an "AS IS" BASIS,
86654 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86655 * See the License for the specific language governing permissions and
86656 * limitations under the License.
86657 * =============================================================================
86658 */
86659 class EncodeFloatProgram {
86660 constructor(outputShape) {
86661 this.variableNames = ['A'];
86662 this.outTexUsage = TextureUsage.DOWNLOAD;
86663 const glsl = getGlslDifferences();
86664 this.outputShape = outputShape;
86665 this.userCode = `
86666 ${ENCODE_FLOAT_SNIPPET}
86667
86668 void main() {
86669 float x = getAAtOutCoords();
86670 ${glsl.output} = encode_float(x);
86671 }
86672 `;
86673 }
86674 }
86675
86676 /**
86677 * @license
86678 * Copyright 2018 Google LLC. All Rights Reserved.
86679 * Licensed under the Apache License, Version 2.0 (the "License");
86680 * you may not use this file except in compliance with the License.
86681 * You may obtain a copy of the License at
86682 *
86683 * http://www.apache.org/licenses/LICENSE-2.0
86684 *
86685 * Unless required by applicable law or agreed to in writing, software
86686 * distributed under the License is distributed on an "AS IS" BASIS,
86687 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86688 * See the License for the specific language governing permissions and
86689 * limitations under the License.
86690 * =============================================================================
86691 */
86692 class EncodeFloatPackedProgram {
86693 constructor(outputShape) {
86694 this.variableNames = ['A'];
86695 this.packedInputs = true;
86696 this.packedOutput = false;
86697 this.outTexUsage = TextureUsage.DOWNLOAD;
86698 const glsl = getGlslDifferences();
86699 this.outputShape = outputShape;
86700 this.userCode = `
86701 ${ENCODE_FLOAT_SNIPPET}
86702
86703 void main() {
86704 ivec3 coords = getOutputCoords();
86705 float x = getChannel(getAAtOutCoords(), vec2(coords.y, coords.z));
86706 ${glsl.output} = encode_float(x);
86707 }
86708 `;
86709 }
86710 }
86711
86712 /**
86713 * @license
86714 * Copyright 2018 Google LLC. All Rights Reserved.
86715 * Licensed under the Apache License, Version 2.0 (the "License");
86716 * you may not use this file except in compliance with the License.
86717 * You may obtain a copy of the License at
86718 *
86719 * http://www.apache.org/licenses/LICENSE-2.0
86720 *
86721 * Unless required by applicable law or agreed to in writing, software
86722 * distributed under the License is distributed on an "AS IS" BASIS,
86723 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86724 * See the License for the specific language governing permissions and
86725 * limitations under the License.
86726 * =============================================================================
86727 */
    // Maps a texture channel letter to its index within an RGBA texel.
    const CHANNEL_CHAR_TO_INDEX_MAP = {
        'R': 0,
        'G': 1,
        'B': 2,
        'A': 3
    };
86734 class EncodeMatrixProgram {
86735 constructor(outputShape, inputIsUnsignedByte = false, usedChannels = 'RGBA') {
86736 this.variableNames = ['A'];
86737 this.customUniforms = [{ name: 'texShape', type: 'ivec2' }];
86738 const glsl = getGlslDifferences();
86739 this.outputShape = outputShape;
86740 this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
86741 let output = `result`;
86742 if (inputIsUnsignedByte) {
86743 output = `floor(result * 255. + 0.5)`;
86744 }
86745 let mainLoop = '';
86746 for (let usedChannelIndex = 0; usedChannelIndex < usedChannels.length; usedChannelIndex++) {
86747 const curChannel = usedChannels[usedChannelIndex];
86748 mainLoop += `
86749 if(offset == ${usedChannelIndex}) {
86750 result = values[${CHANNEL_CHAR_TO_INDEX_MAP[curChannel]}];
86751 }`;
86752 }
86753 this.userCode = `
86754 ${this.enableShapeUniforms ? getFlatIndexFrom3DOutput() :
86755 getFlatIndexFrom3D(outputShape)}
86756
86757 void main() {
86758 ivec3 coords = getOutputCoords();
86759 int flatIndex = getFlatIndex(coords);
86760 float result = 0.;
86761 int offset = imod(flatIndex, ${usedChannels.length});
86762
86763 flatIndex = idiv(flatIndex, ${usedChannels.length}, 1.);
86764
86765 int r = flatIndex / texShape[1];
86766 if (r < texShape[0]) {
86767 int c = imod(flatIndex, texShape[1]);
86768 vec2 uv = (vec2(c, r) + halfCR) / vec2(texShape[1], texShape[0]);
86769 vec4 values = ${glsl.texture2D}(A, uv);
86770 ${mainLoop}
86771 }
86772 ${glsl.output} = vec4(${output}, 0., 0., 0.);
86773 }
86774 `;
86775 }
86776 }
86777
86778 /**
86779 * @license
86780 * Copyright 2018 Google LLC. All Rights Reserved.
86781 * Licensed under the Apache License, Version 2.0 (the "License");
86782 * you may not use this file except in compliance with the License.
86783 * You may obtain a copy of the License at
86784 *
86785 * http://www.apache.org/licenses/LICENSE-2.0
86786 *
86787 * Unless required by applicable law or agreed to in writing, software
86788 * distributed under the License is distributed on an "AS IS" BASIS,
86789 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86790 * See the License for the specific language governing permissions and
86791 * limitations under the License.
86792 * =============================================================================
86793 */
86794 /*
86795 This is how the shader encodes a tensor with shape = [2, 3, 5]
86796 (indices are [batch, row, col]).
86797
86798 000|001 002|003 004|xxx 020|021 022|023 024|xxx
86799 ------- ------- ------- ------- ------- -------
86800 010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx
86801
86802 100|101 102|103 104|xxx 120|121 122|123 124|xxx
86803 ------- ------- ------- ------- ------- -------
86804 110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx
86805
86806 Single texels contain only values from the same batch, and from adjacent rows
86807 and columns.
86808 */
    /**
     * GPGPU program that re-encodes a flat input texture into the packed
     * (2x2-per-texel) output layout. Each output texel gathers up to four
     * logical values from the same batch at adjacent rows/columns (see the
     * layout diagram above); channels past the tensor bounds stay zero.
     */
    class EncodeMatrixPackedProgram {
        /**
         * @param outputShape Logical [batch, rows, cols] output shape.
         * @param inputIsUnsignedByte When true, rescales normalized byte data
         *     back to the 0-255 range on output.
         */
        constructor(outputShape, inputIsUnsignedByte = false) {
            this.variableNames = ['A'];
            this.packedInputs = false;
            this.packedOutput = true;
            this.customUniforms = [{ name: 'texShape', type: 'ivec2' }];
            const glsl = getGlslDifferences();
            this.outputShape = outputShape;
            this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
            let mainLoop = '';
            let output = 'result';
            if (inputIsUnsignedByte) {
                // Byte inputs are sampled normalized; rescale on the way out.
                output = 'floor(result * 255. + 0.5)';
            }
            // Unrolled 2x2 gather: one snippet per output channel (r,g,b,a),
            // guarded against reading past the logical row/col bounds.
            for (let row = 0; row <= 1; row++) {
                for (let col = 0; col <= 1; col++) {
                    const channel = row * 2 + col;
                    mainLoop += `
          localCoords = coords;
          if(localCoords[2] + ${col} < ${this.enableShapeUniforms ? 'outShape[2]' : `${outputShape[2]}`}) {
            localCoords[2] += ${col};
            if (localCoords[1] + ${row} < ${this.enableShapeUniforms ? 'outShape[1]' : `${outputShape[1]}`}) {
              localCoords[1] += ${row};

              flatIndex = getFlatIndex(localCoords);
              offset = imod(flatIndex, 4);

              flatIndex = idiv(flatIndex, 4, 1.);

              int r = flatIndex / texShape[1];
              int c = imod(flatIndex, texShape[1]);
              vec2 uv = (vec2(c, r) + halfCR) / vec2(texShape[1], texShape[0]);
              values = ${glsl.texture2D}(A, uv);

              if (offset == 0) {
                result[${channel}] = values[0];
              } else if (offset == 1) {
                result[${channel}] = values[1];
              } else if (offset == 2) {
                result[${channel}] = values[2];
              } else {
                result[${channel}] = values[3];
              }
            }
          }
        `;
                }
            }
            this.userCode = `
      ${this.enableShapeUniforms ? getFlatIndexFrom3DOutput() :
                getFlatIndexFrom3D(outputShape)}

      void main() {
        ivec3 coords = getOutputCoords();

        vec4 result = vec4(0.);
        int flatIndex, r, c, offset;
        ivec3 localCoords;
        vec2 uv;
        vec4 values;

        ${mainLoop}

        ${glsl.output} = ${output};
      }
    `;
        }
    }
86877
86878 /**
86879 * @license
86880 * Copyright 2017 Google LLC. All Rights Reserved.
86881 * Licensed under the Apache License, Version 2.0 (the "License");
86882 * you may not use this file except in compliance with the License.
86883 * You may obtain a copy of the License at
86884 *
86885 * http://www.apache.org/licenses/LICENSE-2.0
86886 *
86887 * Unless required by applicable law or agreed to in writing, software
86888 * distributed under the License is distributed on an "AS IS" BASIS,
86889 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86890 * See the License for the specific language governing permissions and
86891 * limitations under the License.
86892 * =============================================================================
86893 */
86894 function createVertexShader(gl) {
86895 const glsl = getGlslDifferences();
86896 const vertexShaderSource = `${glsl.version}
86897 precision highp float;
86898 ${glsl.attribute} vec3 clipSpacePos;
86899 ${glsl.attribute} vec2 uv;
86900 ${glsl.varyingVs} vec2 resultUV;
86901
86902 void main() {
86903 gl_Position = vec4(clipSpacePos, 1);
86904 resultUV = uv;
86905 }`;
86906 return createVertexShader$1(gl, vertexShaderSource);
86907 }
86908 function createVertexBuffer(gl) {
86909 // [x y z u v] * [upper-left, lower-left, upper-right, lower-right]
86910 const vertexArray = new Float32Array([-1, 1, 0, 0, 1, -1, -1, 0, 0, 0, 1, 1, 0, 1, 1, 1, -1, 0, 1, 0]);
86911 return createStaticVertexBuffer(gl, vertexArray);
86912 }
86913 function createIndexBuffer(gl) {
86914 // OpenGL (and WebGL) have "CCW == front" winding
86915 const triangleVertexIndices = new Uint16Array([0, 1, 2, 2, 1, 3]);
86916 return createStaticIndexBuffer(gl, triangleVertexIndices);
86917 }
    /**
     * Allocates a width x height texture configured for GPGPU use: clamped
     * wrapping and NEAREST filtering so data values are never interpolated.
     * @returns The texture plus its [height, width] shape.
     */
    function createAndConfigureTexture(gl, width, height, internalFormat, textureFormat, textureType) {
        validateTextureSize(width, height);
        const texture = createTexture(gl);
        const tex2d = gl.TEXTURE_2D;
        callAndCheck(gl, () => gl.bindTexture(tex2d, texture));
        callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE));
        callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE));
        callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_MIN_FILTER, gl.NEAREST));
        callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_MAG_FILTER, gl.NEAREST));
        if (env().getNumber('WEBGL_VERSION') === 1) {
            // WebGL1: allocate mutable storage with a null data upload.
            callAndCheck(gl, () => gl.texImage2D(tex2d, 0, internalFormat, width, height, 0, textureFormat, textureType, null));
        }
        else {
            // WebGL2: immutable storage; data is uploaded later (texSubImage2D).
            callAndCheck(gl, () => gl
                .texStorage2D(tex2d, 1, internalFormat, width, height));
        }
        callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
        return { texture, texShape: [height, width] };
    }
    // Internal format used for unpacked float32 matrix textures.
    function getInternalFormatForFloat32MatrixTexture(textureConfig) {
        return textureConfig.internalFormatFloat;
    }
86940 function createFloat32MatrixTexture(gl, rows, columns, textureConfig) {
86941 const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
86942 return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat32MatrixTexture(textureConfig), textureConfig.textureFormatFloat, gl.FLOAT);
86943 }
    // Internal format used for unpacked float16 matrix textures.
    function getInternalFormatForFloat16MatrixTexture(textureConfig) {
        return textureConfig.internalFormatHalfFloat;
    }
86947 function createFloat16MatrixTexture(gl, rows, columns, textureConfig) {
86948 const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
86949 return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat16MatrixTexture(textureConfig), textureConfig.textureFormatFloat, textureConfig.textureTypeHalfFloat);
86950 }
    // Internal format used for unsigned-byte (download) matrix textures.
    function getInternalFormatForUnsignedBytesMatrixTexture(textureConfig) {
        return textureConfig.downloadTextureFormat;
    }
86954 function createUnsignedBytesMatrixTexture(gl, rows, columns, textureConfig) {
86955 const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
86956 return createAndConfigureTexture(gl, width, height, getInternalFormatForUnsignedBytesMatrixTexture(textureConfig), gl.RGBA, gl.UNSIGNED_BYTE);
86957 }
    // Internal format used for packed float32 matrix textures.
    function getInternalFormatForPackedMatrixTexture(textureConfig) {
        return textureConfig.internalFormatPackedFloat;
    }
86961 function createPackedMatrixTexture(gl, rows, columns, textureConfig) {
86962 const [width, height] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
86963 return createAndConfigureTexture(gl, width, height, getInternalFormatForPackedMatrixTexture(textureConfig), gl.RGBA, gl.FLOAT);
86964 }
    // Internal format used for packed float16 matrix textures.
    function getInternalFormatForFloat16PackedMatrixTexture(textureConfig) {
        return textureConfig.internalFormatPackedHalfFloat;
    }
86968 function createFloat16PackedMatrixTexture(gl, rows, columns, textureConfig) {
86969 const [width, height] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
86970 return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat16PackedMatrixTexture(textureConfig), gl.RGBA, textureConfig.textureTypeHalfFloat);
86971 }
86972 function bindVertexProgramAttributeStreams(gl, program, vertexBuffer) {
86973 const posOffset = 0; // x is the first buffer element
86974 const uvOffset = 3 * 4; // uv comes after [x y z]
86975 const stride = (3 * 4) + (2 * 4); // xyz + uv, each entry is 4-byte float.
86976 callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer));
86977 const success = bindVertexBufferToProgramAttribute(gl, program, 'clipSpacePos', vertexBuffer, 3, stride, posOffset);
86978 return success &&
86979 bindVertexBufferToProgramAttribute(gl, program, 'uv', vertexBuffer, 2, stride, uvOffset);
86980 }
86981 function uploadDenseMatrixToTexture(gl, texture, width, height, data, textureConfig) {
86982 callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture));
86983 let dataForUpload, texelDataType, internalFormat;
86984 if (data instanceof Uint8Array) {
86985 dataForUpload = new Uint8Array(width * height * 4);
86986 texelDataType = gl.UNSIGNED_BYTE;
86987 internalFormat = gl.RGBA;
86988 }
86989 else {
86990 dataForUpload = new Float32Array(width * height * 4);
86991 texelDataType = gl.FLOAT;
86992 internalFormat = textureConfig.internalFormatPackedFloat;
86993 }
86994 dataForUpload.set(data);
86995 if (env().getNumber('WEBGL_VERSION') === 2) {
86996 callAndCheck(gl, () => gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, width, height, gl.RGBA, texelDataType, dataForUpload));
86997 }
86998 else {
86999 callAndCheck(gl, () => gl.texImage2D(gl.TEXTURE_2D, 0, internalFormat, width, height, 0, gl.RGBA, texelDataType, dataForUpload));
87000 }
87001 callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
87002 }
    /**
     * Uploads browser pixel data into a texture as RGBA unsigned bytes.
     * `pixels` is either an ImageData-like object (with a Uint8Array `data`
     * field plus width/height) or a TexImageSource element whose size WebGL
     * infers from the element itself.
     */
    function uploadPixelDataToTexture(gl, texture, pixels) {
        callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture));
        if (pixels.data instanceof Uint8Array) {
            if (env().getNumber('WEBGL_VERSION') === 2) {
                // WebGL2 storage is immutable: sub-upload only.
                callAndCheck(gl, () => gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, pixels.width, pixels.height, gl.RGBA, gl.UNSIGNED_BYTE, pixels.data));
            }
            else {
                callAndCheck(gl, () => gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, pixels.width, pixels.height, 0, gl.RGBA, gl.UNSIGNED_BYTE, pixels.data));
            }
        }
        else {
            // TexImageSource overloads (video/canvas/image element).
            if (env().getNumber('WEBGL_VERSION') === 2) {
                callAndCheck(gl, () => gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, gl.RGBA, gl.UNSIGNED_BYTE, pixels));
            }
            else {
                callAndCheck(gl, () => gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, pixels));
            }
        }
        callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
    }
87023 function createBufferFromOutputTexture(gl2, rows, columns, textureConfig) {
87024 // Create and bind the buffer.
87025 const buffer = gl2.createBuffer();
87026 callAndCheck(gl2, () => gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer));
87027 // Initialize the buffer to the size of the texture in bytes.
87028 const bytesPerFloat = 4;
87029 const valuesPerTexel = 4;
87030 const bufferSizeBytes = bytesPerFloat * valuesPerTexel * rows * columns;
87031 callAndCheck(gl2, () => gl2.bufferData(gl2.PIXEL_PACK_BUFFER, bufferSizeBytes, gl2.STREAM_READ));
87032 // Enqueue a command on the GPU command queue to copy of texture into the
87033 // buffer.
87034 callAndCheck(gl2, () => gl2.readPixels(0, 0, columns, rows, gl2.RGBA, gl2.FLOAT, 0));
87035 callAndCheck(gl2, () => gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null));
87036 return buffer;
87037 }
87038 function downloadFloat32MatrixFromBuffer(gl, buffer, size) {
87039 const gl2 = gl;
87040 const downloadTarget = new Float32Array(size);
87041 gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer);
87042 gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget);
87043 gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null);
87044 return downloadTarget;
87045 }
87046 function downloadByteEncodedFloatMatrixFromOutputTexture(gl, rows, columns, textureConfig) {
87047 const [w, h] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
87048 const numChannels = 4;
87049 const downloadTarget = new Uint8Array(getUnpackedArraySizeFromMatrixSize(rows * columns, numChannels));
87050 callAndCheck(gl, () => gl.readPixels(0, 0, w, h, textureConfig.downloadTextureFormat, gl.UNSIGNED_BYTE, downloadTarget));
87051 // By wrapping the buffer in a Float32Array, we use native browser IEEE 754
87052 // decoding of the 4 bytes that back each 32 bit float.
87053 return new Float32Array(downloadTarget.buffer);
87054 }
87055 function downloadPackedMatrixFromBuffer(gl, buffer, batch, rows, cols, physicalRows, physicalCols, textureConfig) {
87056 const gl2 = gl;
87057 const downloadTarget = new Float32Array(getPackedRGBAArraySizeFromMatrixShape(physicalRows, physicalCols));
87058 gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer);
87059 gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget);
87060 gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null);
87061 return downloadTarget;
87062 }
87063 function downloadMatrixFromPackedOutputTexture(gl, physicalRows, physicalCols) {
87064 const packedRGBA = new Float32Array(physicalRows * physicalCols * 4);
87065 callAndCheck(gl, () => gl.readPixels(0, 0, physicalCols, physicalRows, gl.RGBA, gl.FLOAT, packedRGBA));
87066 return packedRGBA;
87067 }
87068
    // Frozen namespace object re-exporting the gpgpu_util helpers above.
    var gpgpu_util = /*#__PURE__*/Object.freeze({
        __proto__: null,
        bindVertexProgramAttributeStreams: bindVertexProgramAttributeStreams,
        createBufferFromOutputTexture: createBufferFromOutputTexture,
        createFloat16MatrixTexture: createFloat16MatrixTexture,
        createFloat16PackedMatrixTexture: createFloat16PackedMatrixTexture,
        createFloat32MatrixTexture: createFloat32MatrixTexture,
        createIndexBuffer: createIndexBuffer,
        createPackedMatrixTexture: createPackedMatrixTexture,
        createUnsignedBytesMatrixTexture: createUnsignedBytesMatrixTexture,
        createVertexBuffer: createVertexBuffer,
        createVertexShader: createVertexShader,
        downloadByteEncodedFloatMatrixFromOutputTexture: downloadByteEncodedFloatMatrixFromOutputTexture,
        downloadFloat32MatrixFromBuffer: downloadFloat32MatrixFromBuffer,
        downloadMatrixFromPackedOutputTexture: downloadMatrixFromPackedOutputTexture,
        downloadPackedMatrixFromBuffer: downloadPackedMatrixFromBuffer,
        getInternalFormatForFloat16MatrixTexture: getInternalFormatForFloat16MatrixTexture,
        getInternalFormatForFloat16PackedMatrixTexture: getInternalFormatForFloat16PackedMatrixTexture,
        getInternalFormatForFloat32MatrixTexture: getInternalFormatForFloat32MatrixTexture,
        getInternalFormatForPackedMatrixTexture: getInternalFormatForPackedMatrixTexture,
        getInternalFormatForUnsignedBytesMatrixTexture: getInternalFormatForUnsignedBytesMatrixTexture,
        uploadDenseMatrixToTexture: uploadDenseMatrixToTexture,
        uploadPixelDataToTexture: uploadPixelDataToTexture
    });
87093
87094 /**
87095 * @license
87096 * Copyright 2017 Google LLC. All Rights Reserved.
87097 * Licensed under the Apache License, Version 2.0 (the "License");
87098 * you may not use this file except in compliance with the License.
87099 * You may obtain a copy of the License at
87100 *
87101 * http://www.apache.org/licenses/LICENSE-2.0
87102 *
87103 * Unless required by applicable law or agreed to in writing, software
87104 * distributed under the License is distributed on an "AS IS" BASIS,
87105 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87106 * See the License for the specific language governing permissions and
87107 * limitations under the License.
87108 * =============================================================================
87109 */
87110 class GPGPUContext {
        /**
         * Wraps (or acquires) a WebGL context and caches the extensions,
         * vertex/index buffers, framebuffer and texture configuration needed
         * to run GPGPU programs.
         *
         * @param gl Optional existing WebGL rendering context; when omitted, a
         *     context matching the WEBGL_VERSION flag is fetched.
         */
        constructor(gl) {
            this.outputTexture = null;
            this.program = null;
            this.disposed = false;
            this.itemsToPoll = [];
            const glVersion = env().getNumber('WEBGL_VERSION');
            if (gl != null) {
                this.gl = gl;
                setWebGLContext(glVersion, gl);
            }
            else {
                this.gl = getWebGLContext(glVersion);
            }
            gl = this.gl;
            if (env().getNumber('WEBGL_VERSION') === 2) {
                // WebGL2 has native vertex-array-object support.
                const gl2 = gl;
                this.createVertexArray = () => {
                    return callAndCheck(gl2, () => gl2.createVertexArray());
                };
                this.bindVertexArray = (vao) => {
                    return callAndCheck(gl2, () => gl2.bindVertexArray(vao));
                };
                this.deleteVertexArray = (vao) => {
                    return callAndCheck(gl2, () => gl2.deleteVertexArray(vao));
                };
                this.getVertexArray = () => {
                    return callAndCheck(gl2, () => gl2.getParameter(gl2.VERTEX_ARRAY_BINDING));
                };
            }
            else if (gl != null) {
                // WebGL1 needs the OES_vertex_array_object extension for VAOs.
                const ext = gl.getExtension('OES_vertex_array_object');
                if (ext == null) {
                    throw new Error('All WebGL1 implementations are expected to offer' +
                        ' OES_vertex_array_object.');
                }
                this.createVertexArray = () => {
                    return callAndCheck(gl, () => ext.createVertexArrayOES());
                };
                this.bindVertexArray = (vao) => {
                    return callAndCheck(gl, () => ext.bindVertexArrayOES(vao));
                };
                this.deleteVertexArray = (vao) => {
                    return callAndCheck(gl, () => ext.deleteVertexArrayOES(vao));
                };
                this.getVertexArray = () => {
                    return callAndCheck(gl, () => gl.getParameter(ext.VERTEX_ARRAY_BINDING_OES));
                };
            }
            // WebGL 2.0 enables texture floats without an extension.
            let COLOR_BUFFER_FLOAT = 'WEBGL_color_buffer_float';
            const COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';
            this.parallelCompilationExtension =
                this.gl.getExtension('KHR_parallel_shader_compile');
            if (env().getNumber('WEBGL_VERSION') === 1) {
                const TEXTURE_FLOAT = 'OES_texture_float';
                const TEXTURE_HALF_FLOAT = 'OES_texture_half_float';
                // Float textures are mandatory on WebGL1; half floats are
                // best-effort unless WEBGL_FORCE_F16_TEXTURES demands them.
                this.textureFloatExtension =
                    getExtensionOrThrow(this.gl, TEXTURE_FLOAT);
                if (hasExtension(this.gl, TEXTURE_HALF_FLOAT)) {
                    this.textureHalfFloatExtension =
                        getExtensionOrThrow(this.gl, TEXTURE_HALF_FLOAT);
                }
                else if (env().get('WEBGL_FORCE_F16_TEXTURES')) {
                    throw new Error('GL context does not support half float textures, yet the ' +
                        'environment flag WEBGL_FORCE_F16_TEXTURES is set to true.');
                }
                this.colorBufferFloatExtension = this.gl.getExtension(COLOR_BUFFER_FLOAT);
                if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {
                    this.colorBufferHalfFloatExtension =
                        getExtensionOrThrow(this.gl, COLOR_BUFFER_HALF_FLOAT);
                }
                else if (env().get('WEBGL_FORCE_F16_TEXTURES')) {
                    throw new Error('GL context does not support color renderable half floats, yet ' +
                        'the environment flag WEBGL_FORCE_F16_TEXTURES is set to true.');
                }
            }
            else {
                // WebGL2: a float (or at least half-float) renderable color
                // buffer extension is required.
                COLOR_BUFFER_FLOAT = 'EXT_color_buffer_float';
                if (hasExtension(this.gl, COLOR_BUFFER_FLOAT)) {
                    this.colorBufferFloatExtension =
                        this.gl.getExtension(COLOR_BUFFER_FLOAT);
                }
                else if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {
                    this.colorBufferHalfFloatExtension =
                        this.gl.getExtension(COLOR_BUFFER_HALF_FLOAT);
                }
                else {
                    throw new Error('GL context does not support color renderable floats');
                }
            }
            this.vertexBuffer = createVertexBuffer(this.gl);
            this.indexBuffer = createIndexBuffer(this.gl);
            this.framebuffer = createFramebuffer(this.gl);
            this.textureConfig =
                getTextureConfig(this.gl, this.textureHalfFloatExtension);
        }
        // Whether per-call WebGL validation is enabled (reads the DEBUG flag).
        get debug() {
            return env().getBool('DEBUG');
        }
        /**
         * Releases the GL resources owned by this context. Warns (but does not
         * throw) if a program or output texture is still bound, since those
         * are leaks the caller should fix first. Safe to call more than once.
         */
        dispose() {
            if (this.disposed) {
                return;
            }
            if (this.program != null) {
                console.warn('Disposing a GPGPUContext that still has a bound WebGLProgram.' +
                    ' This is probably a resource leak, delete the program with ' +
                    'GPGPUContext.deleteProgram before disposing.');
            }
            if (this.outputTexture != null) {
                console.warn('Disposing a GPGPUContext that still has a bound output matrix ' +
                    'texture. This is probably a resource leak, delete the output ' +
                    'matrix texture with GPGPUContext.deleteMatrixTexture before ' +
                    'disposing.');
            }
            const gl = this.gl;
            // Drain pending GL work before deleting shared objects.
            callAndCheck(gl, () => gl.finish());
            callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, null));
            callAndCheck(gl, () => gl.deleteFramebuffer(this.framebuffer));
            callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, null));
            callAndCheck(gl, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, null));
            callAndCheck(gl, () => gl.deleteBuffer(this.indexBuffer));
            this.disposed = true;
        }
        // Allocates an unpacked float32 matrix texture on this context.
        createFloat32MatrixTexture(rows, columns) {
            this.throwIfDisposed();
            return createFloat32MatrixTexture(this.gl, rows, columns, this.textureConfig);
        }
        // Allocates an unpacked float16 matrix texture on this context.
        createFloat16MatrixTexture(rows, columns) {
            this.throwIfDisposed();
            return createFloat16MatrixTexture(this.gl, rows, columns, this.textureConfig);
        }
        // Allocates an unsigned-byte matrix texture on this context.
        createUnsignedBytesMatrixTexture(rows, columns) {
            this.throwIfDisposed();
            return createUnsignedBytesMatrixTexture(this.gl, rows, columns, this.textureConfig);
        }
        // Uploads browser pixel data (ImageData/TexImageSource) to a texture.
        uploadPixelDataToTexture(texture, pixels) {
            this.throwIfDisposed();
            uploadPixelDataToTexture(this.gl, texture, pixels);
        }
        // Uploads dense data (4 values per texel) into an existing texture.
        uploadDenseMatrixToTexture(texture, width, height, data) {
            this.throwIfDisposed();
            uploadDenseMatrixToTexture(this.gl, texture, width, height, data, this.textureConfig);
        }
        // Allocates a packed float16 matrix texture on this context.
        createFloat16PackedMatrixTexture(rows, columns) {
            this.throwIfDisposed();
            return createFloat16PackedMatrixTexture(this.gl, rows, columns, this.textureConfig);
        }
        // Allocates a packed float32 matrix texture on this context.
        createPackedMatrixTexture(rows, columns) {
            this.throwIfDisposed();
            return createPackedMatrixTexture(this.gl, rows, columns, this.textureConfig);
        }
        /**
         * Deletes a matrix texture, first detaching it from the framebuffer if
         * it is the current output texture.
         */
        deleteMatrixTexture(texture) {
            this.throwIfDisposed();
            if (this.outputTexture === texture) {
                unbindColorTextureFromFramebuffer(this.gl, this.framebuffer);
                this.outputTexture = null;
            }
            callAndCheck(this.gl, () => this.gl.deleteTexture(texture));
        }
        // Reads byte-encoded floats from `texture` via the framebuffer driver.
        downloadByteEncodedFloatMatrixFromOutputTexture(texture, rows, columns) {
            return this.downloadMatrixDriver(texture, () => downloadByteEncodedFloatMatrixFromOutputTexture(this.gl, rows, columns, this.textureConfig));
        }
        // Synchronously reads packed-RGBA values out of a PIXEL_PACK buffer.
        downloadPackedMatrixFromBuffer(buffer, batch, rows, columns, physicalRows, physicalCols) {
            return downloadPackedMatrixFromBuffer(this.gl, buffer, batch, rows, columns, physicalRows, physicalCols, this.textureConfig);
        }
        // Synchronously reads `size` float32 values from a PIXEL_PACK buffer.
        downloadFloat32MatrixFromBuffer(buffer, size) {
            return downloadFloat32MatrixFromBuffer(this.gl, buffer, size);
        }
87279 createBufferFromTexture(texture, rows, columns) {
87280 this.bindTextureToFrameBuffer(texture);
87281 const result = createBufferFromOutputTexture(this.gl, rows, columns, this.textureConfig);
87282 this.unbindTextureToFrameBuffer();
87283 return result;
87284 }
        // Inserts a fence and returns a promise that resolves when it passes.
        createAndWaitForFence() {
            const fenceContext = this.createFence(this.gl);
            return this.pollFence(fenceContext);
        }
        /**
         * Creates a GPU fence plus an `isFencePassed` polling predicate.
         * Prefers the WebGL2 sync API, then a disjoint-query-timer query, and
         * finally an always-true predicate on platforms with no fencing.
         */
        createFence(gl) {
            let query;
            let isFencePassed;
            if (env().getBool('WEBGL_FENCE_API_ENABLED')) {
                const gl2 = gl;
                const sync = gl2.fenceSync(gl2.SYNC_GPU_COMMANDS_COMPLETE, 0);
                // Flush so the fence command is actually submitted to the GPU.
                gl.flush();
                isFencePassed = () => {
                    const status = gl2.clientWaitSync(sync, 0, 0);
                    return status === gl2.ALREADY_SIGNALED ||
                        status === gl2.CONDITION_SATISFIED;
                };
                query = sync;
            }
            else if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) {
                // Fall back to an empty timer query as a completion marker.
                query = this.beginQuery();
                this.endQuery();
                isFencePassed = () => this.isQueryAvailable(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'));
            }
            else {
                // If we have no way to fence, return true immediately. This will fire in
                // WebGL 1.0 when there is no disjoint query timer. In this case, because
                // the fence passes immediately, we'll immediately ask for a download of
                // the texture, which will cause the UI thread to hang.
                isFencePassed = () => true;
            }
            return { query, isFencePassed };
        }
        // Blocking read of a packed texture via the framebuffer driver.
        downloadMatrixFromPackedTexture(texture, physicalRows, physicalCols) {
            return this.downloadMatrixDriver(texture, () => downloadMatrixFromPackedOutputTexture(this.gl, physicalRows, physicalCols));
        }
        /**
         * Links a complete GPGPU program from the shared (lazily created)
         * vertex shader and the given fragment shader, attaching a fresh VAO
         * to the returned program object.
         */
        createProgram(fragmentShader) {
            this.throwIfDisposed();
            const gl = this.gl;
            if (this.vertexShader == null) {
                this.vertexShader = createVertexShader(gl);
            }
            const program = createProgram(gl);
            callAndCheck(gl, () => gl.attachShader(program, this.vertexShader));
            callAndCheck(gl, () => gl.attachShader(program, fragmentShader));
            linkProgram(gl, program);
            // Carry the VAO on the program object so draws can rebind it.
            const program2 = Object.assign(program, { vao: this.createVertexArray() });
            if (this.debug) {
                validateProgram(gl, program2);
            }
            return program2;
        }
        /**
         * Activates the program, binds its VAO and wires the quad's index and
         * vertex buffers to the program's attribute locations.
         */
        buildVao(program) {
            this.setProgram(program);
            this.bindVertexArray(program.vao);
            const gl = this.gl;
            // Bind index buffer, and vertex buffers based on program attrib
            // locations.
            callAndCheck(gl, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, this.indexBuffer));
            bindVertexProgramAttributeStreams(gl, program, this.vertexBuffer);
        }
87345 deleteProgram(program) {
87346 this.throwIfDisposed();
87347 if (program === this.program) {
87348 this.program = null;
87349 }
87350 if (program != null) {
87351 callAndCheck(this.gl, () => this.gl.deleteProgram(program));
87352 this.deleteVertexArray(program.vao);
87353 }
87354 }
87355 setProgram(program) {
87356 this.throwIfDisposed();
87357 this.program = program;
87358 if (this.program != null) {
87359 if (this.debug) {
87360 validateProgram(this.gl, this.program);
87361 }
87362 }
87363 callAndCheck(this.gl, () => this.gl.useProgram(program));
87364 }
87365 getUniformLocation(program, uniformName, shouldThrow = true) {
87366 this.throwIfDisposed();
87367 if (shouldThrow) {
87368 return getProgramUniformLocationOrThrow(this.gl, program, uniformName);
87369 }
87370 else {
87371 return getProgramUniformLocation(this.gl, program, uniformName);
87372 }
87373 }
87374 getAttributeLocation(program, attribute) {
87375 this.throwIfDisposed();
87376 return callAndCheck(this.gl, () => this.gl.getAttribLocation(program, attribute));
87377 }
87378 getUniformLocationNoThrow(program, uniformName) {
87379 this.throwIfDisposed();
87380 return this.gl.getUniformLocation(program, uniformName);
87381 }
87382 setInputMatrixTexture(inputMatrixTexture, uniformLocation, textureUnit) {
87383 this.throwIfDisposed();
87384 this.throwIfNoProgram();
87385 bindTextureToProgramUniformSampler(this.gl, inputMatrixTexture, uniformLocation, textureUnit);
87386 }
87387 setOutputMatrixTexture(outputMatrixTexture, rows, columns) {
87388 this.setOutputMatrixTextureDriver(outputMatrixTexture, columns, rows);
87389 }
87390 setOutputPackedMatrixTexture(outputPackedMatrixTexture, rows, columns) {
87391 this.throwIfDisposed();
87392 const [width, height] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
87393 this.setOutputMatrixTextureDriver(outputPackedMatrixTexture, width, height);
87394 }
87395 setOutputMatrixWriteRegion(startRow, numRows, startColumn, numColumns) {
87396 this.setOutputMatrixWriteRegionDriver(startColumn, startRow, numColumns, numRows);
87397 }
        /** Partial writes are not supported for packed outputs; always throws. */
        setOutputPackedMatrixWriteRegion(startRow, numRows, startColumn, numColumns) {
            throw new Error('setOutputPackedMatrixWriteRegion not implemented.');
        }
87401 debugValidate() {
87402 if (this.program != null) {
87403 validateProgram(this.gl, this.program);
87404 }
87405 validateFramebuffer(this.gl);
87406 }
87407 executeProgram() {
87408 this.throwIfDisposed();
87409 this.throwIfNoProgram();
87410 const gl = this.gl;
87411 if (this.debug) {
87412 const boundVao = this.getVertexArray();
87413 console.assert(boundVao === this.program.vao, 'VAO changed between setProgram and executeProgram!');
87414 this.debugValidate();
87415 }
87416 callAndCheck(gl, () => gl.drawElements(gl.TRIANGLES, 6, gl.UNSIGNED_SHORT, 0));
87417 }
87418 blockUntilAllProgramsCompleted() {
87419 this.throwIfDisposed();
87420 callAndCheck(this.gl, () => this.gl.finish());
87421 }
87422 getQueryTimerExtension() {
87423 if (this.disjointQueryTimerExtension == null) {
87424 this.disjointQueryTimerExtension =
87425 getExtensionOrThrow(this.gl, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2 ?
87426 'EXT_disjoint_timer_query_webgl2' :
87427 'EXT_disjoint_timer_query');
87428 }
87429 return this.disjointQueryTimerExtension;
87430 }
        /** WebGL2 alias for the shared, cached timer query extension. */
        getQueryTimerExtensionWebGL2() {
            return this.getQueryTimerExtension();
        }
        /** WebGL1 alias for the shared, cached timer query extension. */
        getQueryTimerExtensionWebGL1() {
            return this.getQueryTimerExtension();
        }
87437 beginQuery() {
87438 if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) {
87439 const gl2 = this.gl;
87440 const ext = this.getQueryTimerExtensionWebGL2();
87441 const query = gl2.createQuery();
87442 gl2.beginQuery(ext.TIME_ELAPSED_EXT, query);
87443 return query;
87444 }
87445 const ext = this.getQueryTimerExtensionWebGL1();
87446 const query = ext.createQueryEXT();
87447 ext.beginQueryEXT(ext.TIME_ELAPSED_EXT, query);
87448 return query;
87449 }
87450 endQuery() {
87451 if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) {
87452 const gl2 = this.gl;
87453 const ext = this.getQueryTimerExtensionWebGL2();
87454 gl2.endQuery(ext.TIME_ELAPSED_EXT);
87455 return;
87456 }
87457 const ext = this.getQueryTimerExtensionWebGL1();
87458 ext.endQueryEXT(ext.TIME_ELAPSED_EXT);
87459 }
87460 async waitForQueryAndGetTime(query) {
87461 await repeatedTry(() => this.disposed || // while testing contexts are created / disposed
87462 // in rapid succession, so without this check we
87463 // may poll for the query timer indefinitely
87464 this.isQueryAvailable(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')));
87465 return this.getQueryTime(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'));
87466 }
87467 getQueryTime(query, queryTimerVersion) {
87468 if (queryTimerVersion === 0) {
87469 return null;
87470 }
87471 if (queryTimerVersion === 2) {
87472 const gl2 = this.gl;
87473 const timeElapsedNanos = gl2.getQueryParameter(query, gl2.QUERY_RESULT);
87474 // Return milliseconds.
87475 return timeElapsedNanos / 1000000;
87476 }
87477 else {
87478 const ext = this.getQueryTimerExtensionWebGL1();
87479 const timeElapsedNanos = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_EXT);
87480 // Return milliseconds.
87481 return timeElapsedNanos / 1000000;
87482 }
87483 }
87484 isQueryAvailable(query, queryTimerVersion) {
87485 if (queryTimerVersion === 0) {
87486 return true;
87487 }
87488 if (queryTimerVersion === 2) {
87489 const gl2 = this.gl;
87490 const ext = this.getQueryTimerExtensionWebGL2();
87491 const available = gl2.getQueryParameter(query, gl2.QUERY_RESULT_AVAILABLE);
87492 if (this.disjoint == null) {
87493 this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT);
87494 }
87495 return available && !this.disjoint;
87496 }
87497 else {
87498 const ext = this.getQueryTimerExtensionWebGL1();
87499 const available = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_AVAILABLE_EXT);
87500 if (this.disjoint == null) {
87501 this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT);
87502 }
87503 return available && !this.disjoint;
87504 }
87505 }
87506 pollFence(fenceContext) {
87507 return new Promise(resolve => {
87508 this.addItemToPoll(() => fenceContext.isFencePassed(), () => resolve());
87509 });
87510 }
87511 pollItems() {
87512 // Find the last query that has finished.
87513 const index = linearSearchLastTrue(this.itemsToPoll.map(x => x.isDoneFn));
87514 for (let i = 0; i <= index; ++i) {
87515 const { resolveFn } = this.itemsToPoll[i];
87516 resolveFn();
87517 }
87518 this.itemsToPoll = this.itemsToPoll.slice(index + 1);
87519 }
87520 addItemToPoll(isDoneFn, resolveFn) {
87521 this.itemsToPoll.push({ isDoneFn, resolveFn });
87522 if (this.itemsToPoll.length > 1) {
87523 // We already have a running loop that polls.
87524 return;
87525 }
87526 // Start a new loop that polls.
87527 let scheduleFn = undefined;
87528 if ('setTimeoutCustom' in env().platform) {
87529 scheduleFn = env().platform.setTimeoutCustom.bind(env().platform);
87530 }
87531 repeatedTry(() => {
87532 this.pollItems();
87533 // End the loop if no more items to poll.
87534 return this.itemsToPoll.length === 0;
87535 }, () => 0, null, scheduleFn);
87536 }
87537 bindTextureToFrameBuffer(texture) {
87538 this.throwIfDisposed();
87539 bindColorTextureToFramebuffer(this.gl, texture, this.framebuffer);
87540 if (this.debug) {
87541 validateFramebuffer(this.gl);
87542 }
87543 }
87544 unbindTextureToFrameBuffer() {
87545 if (this.outputTexture != null) {
87546 bindColorTextureToFramebuffer(this.gl, this.outputTexture, this.framebuffer);
87547 if (this.debug) {
87548 validateFramebuffer(this.gl);
87549 }
87550 }
87551 else {
87552 unbindColorTextureFromFramebuffer(this.gl, this.framebuffer);
87553 }
87554 }
87555 downloadMatrixDriver(texture, downloadAndDecode) {
87556 this.bindTextureToFrameBuffer(texture);
87557 const result = downloadAndDecode();
87558 this.unbindTextureToFrameBuffer();
87559 return result;
87560 }
87561 setOutputMatrixTextureDriver(outputMatrixTextureMaybePacked, width, height) {
87562 this.throwIfDisposed();
87563 const gl = this.gl;
87564 bindColorTextureToFramebuffer(gl, outputMatrixTextureMaybePacked, this.framebuffer);
87565 if (this.debug) {
87566 validateFramebuffer(gl);
87567 }
87568 this.outputTexture = outputMatrixTextureMaybePacked;
87569 callAndCheck(gl, () => gl.viewport(0, 0, width, height));
87570 callAndCheck(gl, () => gl.scissor(0, 0, width, height));
87571 }
87572 setOutputMatrixWriteRegionDriver(x, y, width, height) {
87573 this.throwIfDisposed();
87574 callAndCheck(this.gl, () => this.gl.scissor(x, y, width, height));
87575 }
        /** Guards against use-after-dispose of this context. */
        throwIfDisposed() {
            if (this.disposed) {
                throw new Error('Attempted to use disposed GPGPUContext.');
            }
        }
        /** Guards operations that require an active program to be set. */
        throwIfNoProgram() {
            if (this.program == null) {
                throw new Error('No GPU program is currently set.');
            }
        }
87586 }
87587 /**
87588 * Finds the index of the last true element using linear search.
87589 * Note: We can't do binary search because Chrome expects us to explicitly
87590 * test all fences before download:
87591 * https://github.com/tensorflow/tfjs/issues/1145
87592 */
87593 function linearSearchLastTrue(arr) {
87594 let i = 0;
87595 for (; i < arr.length; ++i) {
87596 const isDone = arr[i]();
87597 if (!isDone) {
87598 break;
87599 }
87600 }
87601 return i - 1;
87602 }
87603
87604 /**
87605 * @license
87606 * Copyright 2020 Google LLC. All Rights Reserved.
87607 * Licensed under the Apache License, Version 2.0 (the "License");
87608 * you may not use this file except in compliance with the License.
87609 * You may obtain a copy of the License at
87610 *
87611 * http://www.apache.org/licenses/LICENSE-2.0
87612 *
87613 * Unless required by applicable law or agreed to in writing, software
87614 * distributed under the License is distributed on an "AS IS" BASIS,
87615 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87616 * See the License for the specific language governing permissions and
87617 * limitations under the License.
87618 * =============================================================================
87619 */
    // Plain-JS kernel implementations shared with the CPU backend, re-exported
    // under *-CPU aliases; the WebGL backend uses them as fast paths when data
    // already lives on the CPU or the tensor is too small to justify the GPU.
    const { addImpl: addImplCPU, bincountImpl: bincountImplCPU, bincountReduceImpl: bincountReduceImplCPU, bitwiseAndImpl: bitwiseAndImplCPU, castImpl: castImplCPU, ceilImpl: ceilImplCPU, concatImpl: concatImplCPU, equalImpl: equalImplCPU, expImpl: expImplCPU, expm1Impl: expm1ImplCPU, floorImpl: floorImplCPU, gatherNdImpl: gatherNdImplCPU, gatherV2Impl: gatherV2ImplCPU, greaterImpl: greaterImplCPU, greaterEqualImpl: greaterEqualImplCPU, lessImpl: lessImplCPU, lessEqualImpl: lessEqualImplCPU, linSpaceImpl: linSpaceImplCPU, logImpl: logImplCPU, maxImpl: maxImplCPU, maximumImpl: maximumImplCPU, minimumImpl: minimumImplCPU, multiplyImpl: multiplyImplCPU, negImpl: negImplCPU, notEqualImpl: notEqualImplCPU, prodImpl: prodImplCPU, raggedGatherImpl: raggedGatherImplCPU, raggedRangeImpl: raggedRangeImplCPU, raggedTensorToTensorImpl: raggedTensorToTensorImplCPU, rangeImpl: rangeImplCPU, rsqrtImpl: rsqrtImplCPU, scatterImpl: scatterImplCPU, sigmoidImpl: sigmoidImplCPU, simpleAbsImpl: simpleAbsImplCPU, sliceImpl: sliceImplCPU, sparseFillEmptyRowsImpl: sparseFillEmptyRowsImplCPU, sparseReshapeImpl: sparseReshapeImplCPU, sparseSegmentReductionImpl: sparseSegmentReductionImplCPU, sqrtImpl: sqrtImplCPU, staticRegexReplaceImpl: staticRegexReplaceImplCPU, stridedSliceImpl: stridedSliceImplCPU, stringNGramsImpl: stringNGramsImplCPU, stringSplitImpl: stringSplitImplCPU, stringToHashBucketFastImpl: stringToHashBucketFastImplCPU, subImpl: subImplCPU, tileImpl: tileImplCPU, topKImpl: topKImplCPU, transposeImpl: transposeImplCPU, uniqueImpl: uniqueImplCPU, } = shared;
87621
87622 /**
87623 * @license
87624 * Copyright 2018 Google LLC. All Rights Reserved.
87625 * Licensed under the Apache License, Version 2.0 (the "License");
87626 * you may not use this file except in compliance with the License.
87627 * You may obtain a copy of the License at
87628 *
87629 * http://www.apache.org/licenses/LICENSE-2.0
87630 *
87631 * Unless required by applicable law or agreed to in writing, software
87632 * distributed under the License is distributed on an "AS IS" BASIS,
87633 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87634 * See the License for the specific language governing permissions and
87635 * limitations under the License.
87636 * =============================================================================
87637 */
87638 function getVecChannels(name, rank) {
87639 return ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank).map(d => `${name}.${d}`);
87640 }
87641 function getChannels(name, rank) {
87642 if (rank === 1) {
87643 return [name];
87644 }
87645 return getVecChannels(name, rank);
87646 }
87647 function getSourceCoords$2(rank, dims) {
87648 if (rank === 1) {
87649 return 'rc';
87650 }
87651 let coords = '';
87652 for (let i = 0; i < rank; i++) {
87653 coords += dims[i];
87654 if (i < rank - 1) {
87655 coords += ',';
87656 }
87657 }
87658 return coords;
87659 }
87660
87661 /**
87662 * @license
87663 * Copyright 2018 Google LLC. All Rights Reserved.
87664 * Licensed under the Apache License, Version 2.0 (the "License");
87665 * you may not use this file except in compliance with the License.
87666 * You may obtain a copy of the License at
87667 *
87668 * http://www.apache.org/licenses/LICENSE-2.0
87669 *
87670 * Unless required by applicable law or agreed to in writing, software
87671 * distributed under the License is distributed on an "AS IS" BASIS,
87672 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87673 * See the License for the specific language governing permissions and
87674 * limitations under the License.
87675 * =============================================================================
87676 */
    /**
     * GPGPU program that converts an unpacked tensor (one value per texel)
     * into packed layout (a 2x2 block of values per RGBA texel). The four
     * channels of each output texel hold (r,c), (r,c+1), (r+1,c), (r+1,c+1)
     * of the innermost two dimensions, zero-filled past the tensor's edge.
     */
    class PackProgram {
        constructor(outputShape) {
            this.variableNames = ['A'];
            this.packedInputs = false;
            this.packedOutput = true;
            // Only input / output 3D tensors.
            this.outputShape = outputShape;
            this.rank = outputShape.length;
            this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
            if (this.rank === 0) {
                // Scalar case: a single texel with the value in the r channel.
                this.userCode = `
        void main() {
          setOutput(vec4(getA(), 0., 0., 0.));
        }
      `;
            }
            else {
                const channels = getChannels('rc', this.rank);
                const dtype = getCoordsDataType(this.rank);
                const outOfBoundsCondition = this.getOutOfBoundsCondition(channels);
                const setup = this.getSetup(channels);
                const output = this.getOutput(channels);
                this.userCode = `
        void main() {
          ${dtype} rc = getOutputCoords();

          if(${outOfBoundsCondition}) {
            setOutput(vec4(0));
          } else {
            ${setup}

            setOutput(vec4(${output}));
          }
        }
      `;
            }
        }
        // Returns the 4 source-coordinate strings for the texel's channels:
        // (r,c), (r,cp1), (rp1,c), (rp1,cp1), prefixed by any outer dims.
        getSourceCoordsArr(dims) {
            const coords = [];
            for (let row = 0; row <= 1; row++) {
                for (let col = 0; col <= 1; col++) {
                    let coord = `${row === 0 ? 'r' : 'rp1'}, ${col === 0 ? 'c' : 'cp1'}`;
                    for (let d = 2; d < this.rank; d++) {
                        coord = `${dims[dims.length - 1 - d]},` + coord;
                    }
                    coords.push(coord);
                }
            }
            return coords;
        }
        // GLSL condition that is true when `rc` lies outside the output shape.
        getOutOfBoundsCondition(dims) {
            if (this.rank === 1) {
                return `rc > ${this.enableShapeUniforms ? 'outShape' : this.outputShape[0]}`;
            }
            // Only the two innermost dims can exceed bounds for a packed texel.
            let cond = '';
            for (let i = this.rank - 2; i < this.rank; i++) {
                cond += `${dims[i]} >= ${this.enableShapeUniforms ? `outShape[${i}]` : this.outputShape[i]}`;
                if (i < this.rank - 1) {
                    cond += '||';
                }
            }
            return cond;
        }
        // GLSL prelude defining r/c, their +1 neighbors, and edge flags used
        // to zero out channels that fall past the last row/column.
        getSetup(dims) {
            if (this.rank === 1) {
                return '';
            }
            const innerDims = dims.slice(-2);
            const col = this.enableShapeUniforms ? `outShape[${this.rank} - 1]` :
                this.outputShape[this.rank - 1];
            const row = this.enableShapeUniforms ? `outShape[${this.rank} - 2]` :
                this.outputShape[this.rank - 2];
            return `
      int r = ${innerDims[0]};
      int c = ${innerDims[1]};
      int rp1 = r + 1;
      int cp1 = c + 1;

      bool cEdge = cp1 >= ${col};
      bool rEdge = rp1 >= ${row};
    `;
        }
        // GLSL for the four packed channel values, zeroing out-of-bounds lanes.
        getOutput(dims) {
            const sourceCoords = this.getSourceCoordsArr(dims);
            if (this.rank === 1) {
                const outShape = this.enableShapeUniforms ? 'outShape' : this.outputShape[0];
                return `getA(rc), (rc + 1 >= ${outShape} ? 0. : getA(rc + 1)), 0, 0`;
            }
            return `getA(${sourceCoords[0]}),
        cEdge ? 0. : getA(${sourceCoords[1]}),
        rEdge ? 0. : getA(${sourceCoords[2]}),
        rEdge || cEdge ? 0. : getA(${sourceCoords[3]})`;
        }
    }
87771
87772 /**
87773 * @license
87774 * Copyright 2018 Google LLC. All Rights Reserved.
87775 * Licensed under the Apache License, Version 2.0 (the "License");
87776 * you may not use this file except in compliance with the License.
87777 * You may obtain a copy of the License at
87778 *
87779 * http://www.apache.org/licenses/LICENSE-2.0
87780 *
87781 * Unless required by applicable law or agreed to in writing, software
87782 * distributed under the License is distributed on an "AS IS" BASIS,
87783 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87784 * See the License for the specific language governing permissions and
87785 * limitations under the License.
87786 * =============================================================================
87787 */
    /**
     * GPGPU program that reshapes a packed 3D tensor without a round trip
     * through unpacked layout: for each of the 4 lanes of an output texel it
     * computes the flat index that lane represents, maps it back to input
     * coordinates, and samples the matching lane from the packed input.
     */
    class ReshapePackedProgram {
        constructor(outputShape, inputShape) {
            this.variableNames = ['A'];
            this.packedInputs = true;
            this.packedOutput = true;
            this.customUniforms = [{ name: 'inputShape', type: 'ivec3' }];
            this.outputShape = outputShape;
            this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
            // One unrolled iteration per output lane:
            // i=0 -> (y, z), i=1 -> (y, z+1), i=2 -> (y+1, z), i=3 -> (y+1, z+1).
            let mainLoop = ``;
            for (let i = 0; i < 4; i++) {
                let thisRC = `thisRC = rc;`;
                if (i % 2 === 1) {
                    thisRC += `thisRC.z += 1;`;
                }
                if (i > 1) {
                    thisRC += `thisRC.y += 1;`;
                }
                // Lanes other than the first may fall off the edge of the
                // output, hence the bounds guard for i > 0.
                mainLoop += `
        ${thisRC}
        ${i > 0 ? `if(thisRC.y < rows && thisRC.z < cols){` : ''}
          int flatIndex = getFlatIndex(thisRC);

          ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);
          vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));

          result[${i}] =
            getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);
        ${i > 0 ? '}' : ''}
      `;
            }
            this.userCode = `
      ${getReshapedInputCoords(inputShape, this.enableShapeUniforms)}
      ${this.enableShapeUniforms ? getFlatIndexFrom3DOutput() :
                getFlatIndexFrom3D(outputShape)}

      void main() {
        ivec3 rc = getOutputCoords();

        vec4 result = vec4(0.);

        ivec3 thisRC;
        int rows = ${this.enableShapeUniforms ? 'outShape[1]' : outputShape[1]};
        int cols = ${this.enableShapeUniforms ? 'outShape[2]' : outputShape[2]};

        ${mainLoop}

        setOutput(result);
      }
    `;
        }
    }
87839 function getReshapedInputCoords(shape, enableShapeUniforms) {
87840 const coordsFromIndexSnippet = enableShapeUniforms ?
87841 getLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], 'inputShape') :
87842 getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape);
87843 return `
87844 ivec3 inputCoordsFromReshapedOutCoords(int index) {
87845 ${coordsFromIndexSnippet}
87846 return ivec3(r, c, d);
87847 }
87848 `;
87849 }
87850
87851 /**
87852 * @license
87853 * Copyright 2017 Google LLC. All Rights Reserved.
87854 * Licensed under the Apache License, Version 2.0 (the "License");
87855 * you may not use this file except in compliance with the License.
87856 * You may obtain a copy of the License at
87857 *
87858 * http://www.apache.org/licenses/LICENSE-2.0
87859 *
87860 * Unless required by applicable law or agreed to in writing, software
87861 * distributed under the License is distributed on an "AS IS" BASIS,
87862 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
87863 * See the License for the specific language governing permissions and
87864 * limitations under the License.
87865 * =============================================================================
87866 */
    /**
     * Pools WebGL textures keyed by (shape, physical texture type, packedness)
     * so they can be recycled across ops instead of re-allocated, and tracks
     * allocated/free byte counts for memory reporting and eviction.
     */
    class TextureManager {
        constructor(gpgpu) {
            this.gpgpu = gpgpu;
            this.numUsedTextures = 0;
            this.numFreeTextures = 0;
            this._numBytesAllocated = 0;
            // Number of bytes that have been allocated and available for reuse.
            this._numBytesFree = 0;
            this.freeTextures = {};
            this.usedTextures = {};
            this.logEnabled = false;
        }
        /**
         * Returns a texture for `shapeRC`, reusing a pooled texture with the
         * same key when possible; otherwise allocates a new texture of the
         * appropriate physical type.
         */
        acquireTexture(shapeRC, usage, isPacked) {
            const physicalTexType = getPhysicalFromLogicalTextureType(usage, isPacked);
            const shapeKey = getKeyFromTextureShape(shapeRC, physicalTexType, isPacked);
            if (!(shapeKey in this.freeTextures)) {
                this.freeTextures[shapeKey] = [];
            }
            if (!(shapeKey in this.usedTextures)) {
                this.usedTextures[shapeKey] = [];
            }
            const texBytes = computeBytes(shapeRC, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked);
            if (this.freeTextures[shapeKey].length > 0) {
                // Pool hit: move a texture from the free list to the used list.
                this.numFreeTextures--;
                this.numUsedTextures++;
                this._numBytesFree -= texBytes;
                this.log();
                const newTexture = this.freeTextures[shapeKey].pop();
                this.usedTextures[shapeKey].push(newTexture);
                return newTexture;
            }
            // Pool miss: allocate a fresh texture of the right physical type.
            let newTexture;
            if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT32) {
                newTexture = this.gpgpu.createPackedMatrixTexture(shapeRC[0], shapeRC[1]);
            }
            else if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT16) {
                newTexture =
                    this.gpgpu.createFloat16PackedMatrixTexture(shapeRC[0], shapeRC[1]);
            }
            else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT32) {
                newTexture =
                    this.gpgpu.createFloat32MatrixTexture(shapeRC[0], shapeRC[1]);
            }
            else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT16) {
                newTexture =
                    this.gpgpu.createFloat16MatrixTexture(shapeRC[0], shapeRC[1]);
            }
            else if (physicalTexType === PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE) {
                newTexture =
                    this.gpgpu.createUnsignedBytesMatrixTexture(shapeRC[0], shapeRC[1]);
            }
            this.usedTextures[shapeKey].push(newTexture);
            this.numUsedTextures++;
            this._numBytesAllocated += texBytes;
            this.log();
            return newTexture;
        }
        /**
         * Returns `texture` to the free pool, or deletes it outright when the
         * total allocation exceeds WEBGL_DELETE_TEXTURE_THRESHOLD (-1 disables
         * eviction). Throws if the texture was not handed out by this manager.
         */
        releaseTexture(texture, shape, logicalTexType, isPacked) {
            if (this.freeTextures == null) {
                // Already disposed.
                return;
            }
            const physicalTexType = getPhysicalFromLogicalTextureType(logicalTexType, isPacked);
            const shapeKey = getKeyFromTextureShape(shape, physicalTexType, isPacked);
            if (!(shapeKey in this.freeTextures)) {
                this.freeTextures[shapeKey] = [];
            }
            const texBytes = computeBytes(shape, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked);
            const deleteTexThreshold = env()
                .getNumber('WEBGL_DELETE_TEXTURE_THRESHOLD');
            if (deleteTexThreshold !== -1 &&
                this._numBytesAllocated > deleteTexThreshold) {
                // Over budget: free GPU memory rather than pooling.
                this.gpgpu.deleteMatrixTexture(texture.texture);
                this._numBytesAllocated -= texBytes;
            }
            else {
                this.freeTextures[shapeKey].push(texture);
                this.numFreeTextures++;
                this._numBytesFree += texBytes;
            }
            this.numUsedTextures--;
            const texList = this.usedTextures[shapeKey];
            const texIndex = texList && texList.indexOf(texture);
            if (texIndex == null || texIndex < 0) {
                throw new Error('Cannot release a texture that was never provided by this ' +
                    'texture manager');
            }
            // Swap-remove from the used list; ordering within it is irrelevant.
            texList[texIndex] = texList[texList.length - 1];
            texList.pop();
            this.log();
        }
        // Logs pool statistics when `logEnabled` is set.
        log() {
            if (!this.logEnabled) {
                return;
            }
            const total = this.numFreeTextures + this.numUsedTextures;
            console.log('Free/Used', `${this.numFreeTextures} / ${this.numUsedTextures}`, `(${total})`);
            const freeRatio = this._numBytesFree / this._numBytesAllocated;
            console.log(`Bytes allocated: ${this._numBytesAllocated}`);
            console.log(`Bytes unused: ${this._numBytesFree} (${Math.round(100 * freeRatio)}%)`);
        }
        /** Total bytes currently allocated (used + pooled). */
        get numBytesAllocated() {
            return this._numBytesAllocated;
        }
        /** Bytes held in the free pool, available for reuse. */
        get numBytesFree() {
            return this._numBytesFree;
        }
        getNumUsedTextures() {
            return this.numUsedTextures;
        }
        getNumFreeTextures() {
            return this.numFreeTextures;
        }
        /** Deletes every pooled and outstanding texture and resets counters. */
        dispose() {
            if (this.freeTextures == null) {
                // Already disposed.
                return;
            }
            for (const texShape in this.freeTextures) {
                this.freeTextures[texShape].forEach(tex => {
                    this.gpgpu.deleteMatrixTexture(tex.texture);
                });
            }
            for (const texShape in this.usedTextures) {
                this.usedTextures[texShape].forEach(tex => {
                    this.gpgpu.deleteMatrixTexture(tex.texture);
                });
            }
            // TODO: Assign non-null value (empty object) to textures after disposed.
            this.freeTextures = null;
            this.usedTextures = null;
            this.numUsedTextures = 0;
            this.numFreeTextures = 0;
            this._numBytesAllocated = 0;
            this._numBytesFree = 0;
        }
    }
88004 function numBytesForInternalFormat(gl, internalFormat) {
88005 // tslint:disable-next-line:no-any
88006 const glany = gl;
88007 if (internalFormat === glany.R32F) {
88008 return 4;
88009 }
88010 else if (internalFormat === glany.R16F) {
88011 return 2;
88012 }
88013 else if (internalFormat === glany.RGBA32F) {
88014 return 16;
88015 }
88016 else if (internalFormat === gl.RGBA) {
88017 return 16;
88018 }
88019 else if (internalFormat === glany.RGBA16F) {
88020 return 8;
88021 }
88022 else if (internalFormat === glany.RGBA8) {
88023 return 4;
88024 }
88025 throw new Error(`Unknown internal format ${internalFormat}`);
88026 }
88027 function computeBytes(shape, physicalTexType, gl, textureConfig, isPacked) {
88028 // It is not possible to infer packed status from the texture type because
88029 // depending on the textureConfig, different texture types may resolve to the
88030 // same internal format (e.g. in WebGL1, the internal format for
88031 // UNPACKED_FLOAT16 textures is gl.RGBA). Therefore we pass in `isPacked`
88032 // explicitly.
88033 const internalFormat = internalFormatForPhysicalTexType(physicalTexType, textureConfig);
88034 let numElements;
88035 if (isPacked) {
88036 const [packedWidth, packedHeight] = getPackedMatrixTextureShapeWidthHeight(shape[0], shape[1]);
88037 numElements = packedWidth * packedHeight;
88038 }
88039 else {
88040 const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(shape[0], shape[1]);
88041 numElements = width * height;
88042 }
88043 const bytesPerElement = numBytesForInternalFormat(gl, internalFormat);
88044 return numElements * bytesPerElement;
88045 }
88046 function internalFormatForPhysicalTexType(physicalTexType, textureConfig) {
88047 switch (physicalTexType) {
88048 case PhysicalTextureType.PACKED_2X2_FLOAT32:
88049 return getInternalFormatForPackedMatrixTexture(textureConfig);
88050 case PhysicalTextureType.PACKED_2X2_FLOAT16:
88051 return getInternalFormatForFloat16PackedMatrixTexture(textureConfig);
88052 case PhysicalTextureType.UNPACKED_FLOAT32:
88053 return getInternalFormatForFloat32MatrixTexture(textureConfig);
88054 case PhysicalTextureType.UNPACKED_FLOAT16:
88055 return getInternalFormatForFloat16MatrixTexture(textureConfig);
88056 case PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE:
88057 return getInternalFormatForUnsignedBytesMatrixTexture(textureConfig);
88058 default:
88059 throw new Error(`Unknown physical texture type ${physicalTexType}`);
88060 }
88061 }
88062 function getPhysicalTextureForRendering(isPacked) {
88063 if (env().getBool('WEBGL_RENDER_FLOAT32_ENABLED')) {
88064 if (isPacked) {
88065 return PhysicalTextureType.PACKED_2X2_FLOAT32;
88066 }
88067 return PhysicalTextureType.UNPACKED_FLOAT32;
88068 }
88069 if (isPacked) {
88070 return PhysicalTextureType.PACKED_2X2_FLOAT16;
88071 }
88072 return PhysicalTextureType.UNPACKED_FLOAT16;
88073 }
88074 function getPhysicalFromLogicalTextureType(logicalTexType, isPacked) {
88075 if (logicalTexType === TextureUsage.UPLOAD) {
88076 return PhysicalTextureType.PACKED_2X2_FLOAT32;
88077 }
88078 else if (logicalTexType === TextureUsage.RENDER || logicalTexType == null) {
88079 return getPhysicalTextureForRendering(isPacked);
88080 }
88081 else if (logicalTexType === TextureUsage.DOWNLOAD ||
88082 logicalTexType === TextureUsage.PIXELS) {
88083 return PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE;
88084 }
88085 throw new Error(`Unknown logical texture type ${logicalTexType}`);
88086 }
88087 function getKeyFromTextureShape(shapeRowsCol, physicalTexType, isPacked) {
88088 return `${shapeRowsCol[0]}_${shapeRowsCol[1]}_${physicalTexType}_${isPacked}`;
88089 }
88090
88091 /**
88092 * @license
88093 * Copyright 2017 Google LLC. All Rights Reserved.
88094 * Licensed under the Apache License, Version 2.0 (the "License");
88095 * you may not use this file except in compliance with the License.
88096 * You may obtain a copy of the License at
88097 *
88098 * http://www.apache.org/licenses/LICENSE-2.0
88099 *
88100 * Unless required by applicable law or agreed to in writing, software
88101 * distributed under the License is distributed on an "AS IS" BASIS,
88102 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88103 * See the License for the specific language governing permissions and
88104 * limitations under the License.
88105 * =============================================================================
88106 */
    /**
     * GPGPU program that applies an element-wise operation to every value of
     * a tensor of shape `aShape`. `opSnippet` is the GLSL body of
     * `unaryOperation(float x)` and must return a float.
     */
    class UnaryOpProgram {
        constructor(aShape, opSnippet) {
            this.variableNames = ['A'];
            this.outputShape = aShape;
            this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
            this.userCode = `
      float unaryOperation(float x) {
        ${opSnippet}
      }

      void main() {
        float x = getAAtOutCoords();
        float y = unaryOperation(x);

        setOutput(y);
      }
    `;
        }
    }
    // GLSL snippet bodies for UnaryOpProgram (each is the body of
    // `unaryOperation(float x)`).
    // Prefix for ops that must propagate NaN unchanged.
    const CHECK_NAN_SNIPPET$1 = `if (isnan(x)) return x;`;
    // Identity.
    const LINEAR$1 = `return x;`;
    // Absolute value.
    const ABS$1 = `return abs(x);`;
    /** Heaviside step: 1 for x > 0, `alpha` otherwise; NaN passes through. */
    function STEP(alpha = 0.0) {
        return CHECK_NAN_SNIPPET$1 + `
    return x > 0.0 ? 1.0 : float(${alpha});
  `;
    }
    // Exponential linear unit: x for x >= 0, exp(x) - 1 otherwise.
    const ELU$2 = `return (x >= 0.0) ? x : (exp(x) - 1.0);`;
    // ReLU with NaN propagation.
    const RELU$2 = CHECK_NAN_SNIPPET$1 + `
  return (x < 0.0) ? 0.0 : x;
`;
    // ReLU clamped at 6, with NaN propagation.
    const RELU6$2 = CHECK_NAN_SNIPPET$1 + `
  return (x < 0.0) ? 0.0 : min(6.0, x);
`;
    // Identity, used for tensor cloning.
    const CLONE = 'return x;';
    // Logistic sigmoid.
    const SIGMOID$2 = `return 1.0 / (1.0 + exp(-1.0 * x));`;
88143
88144 /**
88145 * @license
88146 * Copyright 2018 Google LLC. All Rights Reserved.
88147 * Licensed under the Apache License, Version 2.0 (the "License");
88148 * you may not use this file except in compliance with the License.
88149 * You may obtain a copy of the License at
88150 *
88151 * http://www.apache.org/licenses/LICENSE-2.0
88152 *
88153 * Unless required by applicable law or agreed to in writing, software
88154 * distributed under the License is distributed on an "AS IS" BASIS,
88155 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88156 * See the License for the specific language governing permissions and
88157 * limitations under the License.
88158 * =============================================================================
88159 */
    // Packed GLSL snippet bodies for UnaryOpPackedProgram (each is the body
    // of `unaryOperation(vec4 x)`, operating on 4 values per texel).
    // Identity.
    const LINEAR = `return x;`;
    // ELU applied per channel: x for x >= 0, exp(x) - 1 otherwise.
    const ELU$1 = `
  vec4 result;

  result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);
  result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);
  result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);
  result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);

  return result;
`;
    // ReLU per channel, with NaN lanes passed through unchanged.
    const RELU$1 = `
  vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));
  bvec4 isNaN = isnan(x);

  result.r = isNaN.r ? x.r : result.r;
  result.g = isNaN.g ? x.g : result.g;
  result.b = isNaN.b ? x.b : result.b;
  result.a = isNaN.a ? x.a : result.a;

  return result;
`;
    // ReLU clamped at 6 per channel, with NaN lanes passed through unchanged.
    const RELU6$1 = `
  vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0)));
  bvec4 isNaN = isnan(x);

  result.r = isNaN.r ? x.r : result.r;
  result.g = isNaN.g ? x.g : result.g;
  result.b = isNaN.b ? x.b : result.b;
  result.a = isNaN.a ? x.a : result.a;

  return result;
`;
    // Logistic sigmoid (componentwise on the vec4).
    const SIGMOID$1 = `return 1.0 / (1.0 + exp(-1.0 * x));`;
88194 class UnaryOpPackedProgram {
88195 constructor(aShape, opSnippet) {
88196 this.variableNames = ['A'];
88197 this.packedInputs = true;
88198 this.packedOutput = true;
88199 this.outputShape = aShape;
88200 this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
88201 this.userCode = `
88202 vec4 unaryOperation(vec4 x) {
88203 ${opSnippet}
88204 }
88205
88206 void main() {
88207 vec4 x = getAAtOutCoords();
88208 vec4 y = unaryOperation(x);
88209
88210 setOutput(y);
88211 }
88212 `;
88213 }
88214 }
88215
88216 /**
88217 * @license
88218 * Copyright 2018 Google LLC. All Rights Reserved.
88219 * Licensed under the Apache License, Version 2.0 (the "License");
88220 * you may not use this file except in compliance with the License.
88221 * You may obtain a copy of the License at
88222 *
88223 * http://www.apache.org/licenses/LICENSE-2.0
88224 *
88225 * Unless required by applicable law or agreed to in writing, software
88226 * distributed under the License is distributed on an "AS IS" BASIS,
88227 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88228 * See the License for the specific language governing permissions and
88229 * limitations under the License.
88230 * =============================================================================
88231 */
88232 class UnpackProgram {
88233 constructor(outputShape) {
88234 this.variableNames = ['A'];
88235 this.packedInputs = true;
88236 this.packedOutput = false;
88237 this.outputShape = outputShape;
88238 this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
88239 const rank = outputShape.length;
88240 const channels = getChannels('rc', rank);
88241 const dtype = getCoordsDataType(rank);
88242 const sourceCoords = getSourceCoords$2(rank, channels);
88243 const innerDims = channels.slice(-2);
88244 const coords = rank <= 1 ? 'rc' : `vec2(${innerDims.join(',')})`;
88245 this.userCode = `
88246 void main() {
88247 ${dtype} rc = getOutputCoords();
88248 vec4 packedInput = getA(${sourceCoords});
88249
88250 setOutput(getChannel(packedInput, ${coords}));
88251 }
88252 `;
88253 }
88254 }
88255
88256 /**
88257 * @license
88258 * Copyright 2017 Google LLC. All Rights Reserved.
88259 * Licensed under the Apache License, Version 2.0 (the "License");
88260 * you may not use this file except in compliance with the License.
88261 * You may obtain a copy of the License at
88262 *
88263 * http://www.apache.org/licenses/LICENSE-2.0
88264 *
88265 * Unless required by applicable law or agreed to in writing, software
88266 * distributed under the License is distributed on an "AS IS" BASIS,
88267 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88268 * See the License for the specific language governing permissions and
88269 * limitations under the License.
88270 * =============================================================================
88271 */
88272 const whereImpl = whereImpl$2;
88273 const EPSILON_FLOAT32 = 1e-7;
88274 const EPSILON_FLOAT16 = 1e-4;
88275 const binaryCaches = {};
88276 function getBinaryCache(webGLVersion) {
88277 if (webGLVersion in binaryCaches) {
88278 return binaryCaches[webGLVersion];
88279 }
88280 binaryCaches[webGLVersion] = {};
88281 return binaryCaches[webGLVersion];
88282 }
88283 // Empirically determined constant used to determine size threshold for handing
88284 // off execution to the CPU.
88285 const CPU_HANDOFF_SIZE_THRESHOLD = env().getNumber('CPU_HANDOFF_SIZE_THRESHOLD');
88286 // Empirically determined constant used to decide the number of MB on GPU
88287 // before we warn about high memory use. The MB are this constant * screen area
88288 // * dpi / 1024 / 1024.
88289 const BEFORE_PAGING_CONSTANT = 600;
88290 function numMBBeforeWarning() {
88291 if (env().global.screen == null) {
88292 return 1024; // 1 GB.
88293 }
88294 return (env().global.screen.height * env().global.screen.width *
88295 window.devicePixelRatio) *
88296 BEFORE_PAGING_CONSTANT / 1024 / 1024;
88297 }
88298 class MathBackendWebGL extends KernelBackend {
88299 nextDataId() {
88300 return MathBackendWebGL.nextDataId++;
88301 }
        /**
         * Creates the WebGL backend.
         * @param gpuResource Optional. Either an existing GPGPUContext, or a
         *     resource (passed through to getWebGLContext) from which a GL
         *     context is built. When omitted, a context is created locally and
         *     the global per-WebGL-version binary cache is used.
         * Throws when the environment reports no WebGL support.
         */
        constructor(gpuResource) {
            super();
            // Maps data ids that have a pending read operation, to list of subscribers.
            this.pendingRead = new WeakMap();
            // List of data ids that are scheduled for disposal, but are waiting on a
            // pending read operation.
            this.pendingDisposal = new WeakSet();
            // Used to count the number of 'shallow' sliced tensors that point to the
            // same data id.
            this.dataRefCount = new WeakMap();
            this.numBytesInGPU = 0;
            // Accumulated time spent (including blocking) in uploading data to webgl.
            this.uploadWaitMs = 0;
            // Accumulated time spent (including blocking) in downloading data from webgl.
            this.downloadWaitMs = 0;
            // record the last manual GL Flush time.
            this.lastGlFlushTime = 0;
            this.warnedAboutMemory = false;
            this.pendingDeletes = 0;
            this.disposed = false;
            if (!env().getBool('HAS_WEBGL')) {
                throw new Error('WebGL is not supported on this device');
            }
            let newGPGPU;
            if (gpuResource != null) {
                if (gpuResource instanceof GPGPUContext) {
                    newGPGPU = gpuResource;
                }
                else {
                    // Build a GL context from the provided resource.
                    const gl = getWebGLContext(env().getNumber('WEBGL_VERSION'), gpuResource);
                    newGPGPU = new GPGPUContext(gl);
                }
                // Externally-provided contexts get a private binary cache and
                // are not owned (gpgpuCreatedLocally === false).
                this.binaryCache = {};
                this.gpgpuCreatedLocally = false;
            }
            else {
                const gl = getWebGLContext(env().getNumber('WEBGL_VERSION'));
                newGPGPU = new GPGPUContext(gl);
                // Locally-created contexts share the global per-version cache.
                this.binaryCache = getBinaryCache(env().getNumber('WEBGL_VERSION'));
                this.gpgpuCreatedLocally = true;
            }
            this.gpgpu = newGPGPU;
            this.canvas = this.gpgpu.gl.canvas;
            this.textureManager = new TextureManager(this.gpgpu);
            this.numMBBeforeWarning = numMBBeforeWarning();
            this.texData = new DataStorage(this, engine());
        }
88349 numDataIds() {
88350 return this.texData.numDataIds() - this.pendingDeletes;
88351 }
        // Writes a new entry to the data store with a WebGL texture, and registers it
        // to the texture manager.
        // Returns the dataId of the newly written (re-encoded) entry.
        writeTexture(texture, shape, dtype, texHeight, texWidth, channels) {
            // Temporarily create an tensor info to make the texture compatible with
            // the runWebGLProgram's input.
            const input = this.makeTensorInfo(shape, dtype);
            const inData = this.texData.get(input.dataId);
            // Even though the input texture could be unpacked or dense packed, it is
            // always considered as unpacked for EncodeMatrixProgram.
            inData.isPacked = false;
            // Bind texture to the input tensor.
            inData.texture = { texture, texShape: [texHeight, texWidth] };
            inData.texShape = [texHeight, texWidth];
            const shapeAs3D = getShapeAs3D(shape);
            const program = new EncodeMatrixProgram(shapeAs3D, false /* isByteArray */, channels);
            // The custom uniform carries the source texture dimensions.
            const output = this.runWebGLProgram(program, [input], dtype, [[texHeight, texWidth]]);
            output.shape = shape;
            // Unbind the texture from the input tensor to avoid the texture being
            // released.
            inData.texture = null;
            this.disposeIntermediateTensorInfo(input);
            return output.dataId;
        }
88375 write(values, shape, dtype) {
88376 if (env().getBool('WEBGL_CHECK_NUMERICAL_PROBLEMS') ||
88377 env().getBool('DEBUG')) {
88378 this.checkNumericalProblems(values);
88379 }
88380 if (dtype === 'complex64' && values != null) {
88381 throw new Error(`Cannot write to a complex64 dtype. ` +
88382 `Please use tf.complex(real, imag).`);
88383 }
88384 const dataId = { id: this.nextDataId() };
88385 this.texData.set(dataId, { shape, dtype, values, usage: TextureUsage.UPLOAD, refCount: 1 });
88386 return dataId;
88387 }
88388 /** Return refCount of a `TensorData`. */
88389 refCount(dataId) {
88390 if (this.texData.has(dataId)) {
88391 const tensorData = this.texData.get(dataId);
88392 return tensorData.refCount;
88393 }
88394 return 0;
88395 }
88396 /** Increase refCount of a `TextureData`. */
88397 incRef(dataId) {
88398 const texData = this.texData.get(dataId);
88399 texData.refCount++;
88400 }
88401 /** Decrease refCount of a `TextureData`. */
88402 decRef(dataId) {
88403 if (this.texData.has(dataId)) {
88404 const texData = this.texData.get(dataId);
88405 texData.refCount--;
88406 }
88407 }
88408 move(dataId, values, shape, dtype, refCount) {
88409 if (env().getBool('DEBUG')) {
88410 this.checkNumericalProblems(values);
88411 }
88412 if (dtype === 'complex64') {
88413 throw new Error(`Cannot write to a complex64 dtype. ` +
88414 `Please use tf.complex(real, imag).`);
88415 }
88416 this.texData.set(dataId, { shape, dtype, values, usage: TextureUsage.UPLOAD, refCount });
88417 }
88418 disposeIntermediateTensorInfo(tensorInfo) {
88419 this.disposeData(tensorInfo.dataId);
88420 }
        /**
         * Synchronously downloads the values for `dataId` to the CPU, blocking
         * the main thread when the data still lives on the GPU. Results are
         * cached on the CPU via convertAndCacheOnCPU.
         */
        readSync(dataId) {
            const texData = this.texData.get(dataId);
            const { values, dtype, complexTensorInfos, slice, shape, isPacked } = texData;
            // The presence of `slice` indicates this tensor is a shallow slice of a
            // different tensor, and is using that original tensor's texture. Run
            // `clone` in order to copy that texture and read from it.
            if (slice != null) {
                let program;
                if (isPacked) {
                    program = new UnaryOpPackedProgram(shape, CLONE);
                }
                else {
                    program = new UnaryOpProgram(shape, CLONE);
                }
                const res = this.runWebGLProgram(program, [{ dataId, shape, dtype }], dtype);
                const data = this.readSync(res.dataId);
                this.disposeIntermediateTensorInfo(res);
                return data;
            }
            // Values already cached on the CPU.
            if (values != null) {
                return this.convertAndCacheOnCPU(dataId);
            }
            // NOTE(review): reached only when `values` is null, so string
            // tensors without CPU values return null/undefined here —
            // presumably intentional since strings have no GPU representation;
            // confirm against callers.
            if (dtype === 'string') {
                return values;
            }
            const shouldTimeProgram = this.activeTimers != null;
            let start;
            if (shouldTimeProgram) {
                start = now();
            }
            let result;
            if (dtype === 'complex64') {
                // Complex data is stored as separate real/imag tensors; read
                // both and interleave.
                const realValues = this.readSync(complexTensorInfos.real.dataId);
                const imagValues = this.readSync(complexTensorInfos.imag.dataId);
                result = mergeRealAndImagArrays(realValues, imagValues);
            }
            else {
                result = this.getValuesFromTexture(dataId);
            }
            if (shouldTimeProgram) {
                // Attribute the blocking download time to the profiler.
                this.downloadWaitMs += now() - start;
            }
            return this.convertAndCacheOnCPU(dataId, result);
        }
        /**
         * Asynchronously downloads the values for `dataId` to the CPU.
         * Concurrent reads of the same id share one download via pendingRead;
         * disposal requested mid-read is deferred through pendingDisposal.
         */
        async read(dataId) {
            // Piggy-back on an in-flight read of the same data.
            if (this.pendingRead.has(dataId)) {
                const subscribers = this.pendingRead.get(dataId);
                return new Promise(resolve => subscribers.push(resolve));
            }
            const texData = this.texData.get(dataId);
            const { values, shape, slice, dtype, complexTensorInfos, isPacked } = texData;
            // The presence of `slice` indicates this tensor is a shallow slice of a
            // different tensor, and is using that original tensor's texture. Run
            // `clone` in order to copy that texture and read from it.
            if (slice != null) {
                let program;
                if (isPacked) {
                    program = new UnaryOpPackedProgram(shape, CLONE);
                }
                else {
                    program = new UnaryOpProgram(shape, CLONE);
                }
                const res = this.runWebGLProgram(program, [{ dataId, shape, dtype }], dtype);
                // `data` is a promise; disposal of the clone is safe because
                // the recursive read defers actual deletion via pendingRead.
                const data = this.read(res.dataId);
                this.disposeIntermediateTensorInfo(res);
                return data;
            }
            // Values already cached on the CPU.
            if (values != null) {
                return this.convertAndCacheOnCPU(dataId);
            }
            if (env().getBool('DEBUG')) {
                // getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') caused a blocking GPU call.
                // For performance reason, only check it for debugging. In production,
                // it doesn't handle this use case anyway, so behavior is not changed.
                if (!env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') &&
                    env().getNumber('WEBGL_VERSION') === 2) {
                    throw new Error(`tensor.data() with WEBGL_DOWNLOAD_FLOAT_ENABLED=false and ` +
                        `WEBGL_VERSION=2 not yet supported.`);
                }
            }
            let buffer = null;
            let tmpDownloadTarget;
            if (dtype !== 'complex64' && env().get('WEBGL_BUFFER_SUPPORTED')) {
                // Possibly copy the texture into a buffer before inserting a fence.
                tmpDownloadTarget = this.decode(dataId);
                const tmpData = this.texData.get(tmpDownloadTarget.dataId);
                buffer = this.gpgpu.createBufferFromTexture(tmpData.texture.texture, ...getDenseTexShape(shape));
            }
            // Register this read so later readers subscribe instead of
            // re-downloading, and so disposal is deferred until we finish.
            this.pendingRead.set(dataId, []);
            if (dtype !== 'complex64') {
                // Create a fence and wait for it to resolve.
                await this.gpgpu.createAndWaitForFence();
            }
            // Download the values from the GPU.
            let vals;
            if (dtype === 'complex64') {
                // Read real/imag parts in parallel and interleave.
                const ps = await Promise.all([
                    this.read(complexTensorInfos.real.dataId),
                    this.read(complexTensorInfos.imag.dataId)
                ]);
                const realValues = ps[0];
                const imagValues = ps[1];
                vals = mergeRealAndImagArrays(realValues, imagValues);
            }
            else if (buffer == null) {
                // No buffer path available: fall back to the (blocking) texture read.
                vals = this.getValuesFromTexture(dataId);
            }
            else {
                const size = sizeFromShape(shape);
                vals = this.gpgpu.downloadFloat32MatrixFromBuffer(buffer, size);
            }
            if (tmpDownloadTarget != null) {
                this.disposeIntermediateTensorInfo(tmpDownloadTarget);
            }
            if (buffer != null) {
                const gl = this.gpgpu.gl;
                callAndCheck(gl, () => gl.deleteBuffer(buffer));
            }
            const dTypeVals = this.convertAndCacheOnCPU(dataId, vals);
            const subscribers = this.pendingRead.get(dataId);
            this.pendingRead.delete(dataId);
            // Notify all pending reads.
            subscribers.forEach(resolve => resolve(dTypeVals));
            // Honor a disposal that was requested while the read was in flight.
            if (this.pendingDisposal.has(dataId)) {
                this.pendingDisposal.delete(dataId);
                if (this.disposeData(dataId)) {
                    engine().removeDataId(dataId, this);
                }
                this.pendingDeletes--;
            }
            return dTypeVals;
        }
88553 /**
88554 * Read tensor to a new texture that is densely packed for ease of use.
88555 * @param dataId The source tensor.
88556 * @param options
88557 * customTexShape: Optional. If set, will use the user defined texture
88558 * shape to create the texture.
88559 */
88560 readToGPU(dataId, options = {}) {
88561 const texData = this.texData.get(dataId);
88562 const { values, shape, slice, dtype, isPacked, texture } = texData;
88563 if (dtype === 'complex64') {
88564 throw new Error('Does not support reading texture for complex64 dtype.');
88565 }
88566 // The presence of `slice` indicates this tensor is a shallow slice of a
88567 // different tensor, and is using that original tensor's texture. Run
88568 // `clone` in order to copy that texture and read from it.
88569 if (slice != null) {
88570 let program;
88571 if (isPacked) {
88572 program = new UnaryOpPackedProgram(shape, CLONE);
88573 }
88574 else {
88575 program = new UnaryOpProgram(shape, CLONE);
88576 }
88577 const res = this.runWebGLProgram(program, [{ dataId, shape, dtype }], dtype);
88578 const gpuResouorce = this.readToGPU(res, options);
88579 this.disposeIntermediateTensorInfo(res);
88580 return gpuResouorce;
88581 }
88582 if (texture == null) {
88583 if (values != null) {
88584 throw new Error('Data is not on GPU but on CPU.');
88585 }
88586 else {
88587 throw new Error('There is no data on GPU or CPU.');
88588 }
88589 }
88590 // Decode the texture so that it is stored densely (using four channels).
88591 const tmpTarget = this.decode(dataId, options.customTexShape);
88592 // Make engine track this tensor, so that we can dispose it later.
88593 const tensorRef = engine().makeTensorFromTensorInfo(tmpTarget);
88594 const tmpData = this.texData.get(tmpTarget.dataId);
88595 return Object.assign({ tensorRef }, tmpData.texture);
88596 }
88597 bufferSync(t) {
88598 const data = this.readSync(t.dataId);
88599 if (t.dtype === 'string') {
88600 try {
88601 // Decode the bytes into string.
88602 const strings = data.map(d => decodeString(d));
88603 return buffer(t.shape, t.dtype, strings);
88604 }
88605 catch (_a) {
88606 throw new Error('Failed to decode encoded string bytes into utf-8');
88607 }
88608 }
88609 return buffer(t.shape, t.dtype, data);
88610 }
88611 checkNumericalProblems(values) {
88612 if (values == null) {
88613 return;
88614 }
88615 for (let i = 0; i < values.length; i++) {
88616 const num = values[i];
88617 if (!canBeRepresented(num)) {
88618 if (env().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')) {
88619 throw Error(`The value ${num} cannot be represented with your ` +
88620 `current settings. Consider enabling float32 rendering: ` +
88621 `'tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', true);'`);
88622 }
88623 throw Error(`The value ${num} cannot be represented on this device.`);
88624 }
88625 }
88626 }
        /**
         * Downloads the float values backing `dataId` from its texture.
         * Uses a dense decode + direct float download when the environment
         * supports it; otherwise re-encodes floats as bytes and decodes on CPU.
         */
        getValuesFromTexture(dataId) {
            const { shape, dtype, isPacked } = this.texData.get(dataId);
            const size = sizeFromShape(shape);
            if (env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED')) {
                // Fast path: decode into a densely packed texture and read floats
                // directly; subarray trims texture padding to the tensor size.
                const tmpTarget = this.decode(dataId);
                const tmpData = this.texData.get(tmpTarget.dataId);
                const vals = this.gpgpu
                    .downloadMatrixFromPackedTexture(tmpData.texture.texture, ...getDenseTexShape(shape))
                    .subarray(0, size);
                this.disposeIntermediateTensorInfo(tmpTarget);
                return vals;
            }
            // Slow path: encode floats into bytes on the GPU, then download and
            // reinterpret on the CPU.
            const shouldUsePackedProgram = env().getBool('WEBGL_PACK') && isPacked === true;
            const outputShape = shouldUsePackedProgram ? getShapeAs3D(shape) : shape;
            const program = shouldUsePackedProgram ?
                new EncodeFloatPackedProgram(outputShape) :
                new EncodeFloatProgram(outputShape);
            const output = this.runWebGLProgram(program, [{ shape: outputShape, dtype, dataId }], 'float32');
            const tmpData = this.texData.get(output.dataId);
            const vals = this.gpgpu
                .downloadByteEncodedFloatMatrixFromOutputTexture(tmpData.texture.texture, tmpData.texShape[0], tmpData.texShape[1])
                .subarray(0, size);
            this.disposeIntermediateTensorInfo(output);
            return vals;
        }
88652 timerAvailable() {
88653 return env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0;
88654 }
        /**
         * Profiles `f`, collecting per-kernel timer queries pushed by
         * runWebGLProgram into `activeTimers`. Supports nesting: inner calls
         * push their timer lists onto the outer call's list. Returns a promise
         * of {uploadWaitMs, downloadWaitMs, kernelMs, wallMs}.
         */
        time(f) {
            const oldActiveTimers = this.activeTimers;
            const newActiveTimers = [];
            let outerMostTime = false;
            if (this.programTimersStack == null) {
                this.programTimersStack = newActiveTimers;
                outerMostTime = true;
            }
            else {
                this.activeTimers.push(newActiveTimers);
            }
            this.activeTimers = newActiveTimers;
            f();
            // needing to split these up because util.flatten only accepts certain types
            const flattenedActiveTimerQueries = flatten$2(this.activeTimers.map((d) => d.query))
                .filter(d => d != null);
            const flattenedActiveTimerNames = flatten$2(this.activeTimers.map((d) => d.name))
                .filter(d => d != null);
            // Restore the caller's timer context.
            this.activeTimers = oldActiveTimers;
            if (outerMostTime) {
                this.programTimersStack = null;
            }
            const res = {
                uploadWaitMs: this.uploadWaitMs,
                downloadWaitMs: this.downloadWaitMs,
                kernelMs: null,
                wallMs: null // will be filled by the engine
            };
            return (async () => {
                if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') >
                    0) {
                    // Each entry resolves to a kernel's GPU time in ms.
                    const kernelMs = await Promise.all(flattenedActiveTimerQueries);
                    res['kernelMs'] = sum$4(kernelMs);
                    res['getExtraProfileInfo'] = () => kernelMs
                        .map((d, i) => ({ name: flattenedActiveTimerNames[i], ms: d }))
                        .map(d => `${d.name}: ${d.ms}`)
                        .join(', ');
                }
                else {
                    res['kernelMs'] = {
                        error: 'WebGL query timers are not supported in this environment.'
                    };
                }
                // Reset accumulated upload/download wait for the next profile.
                this.uploadWaitMs = 0;
                this.downloadWaitMs = 0;
                return res;
            })();
        }
88703 memory() {
88704 return {
88705 unreliable: false,
88706 numBytesInGPU: this.numBytesInGPU,
88707 numBytesInGPUAllocated: this.textureManager.numBytesAllocated,
88708 numBytesInGPUFree: this.textureManager.numBytesFree
88709 };
88710 }
88711 startTimer() {
88712 if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
88713 return this.gpgpu.beginQuery();
88714 }
88715 return { startMs: now(), endMs: null };
88716 }
88717 endTimer(query) {
88718 if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
88719 this.gpgpu.endQuery();
88720 return query;
88721 }
88722 query.endMs = now();
88723 return query;
88724 }
88725 async getQueryTime(query) {
88726 if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
88727 return this.gpgpu.waitForQueryAndGetTime(query);
88728 }
88729 const timerQuery = query;
88730 return timerQuery.endMs - timerQuery.startMs;
88731 }
88732 /**
88733 * Decrease the RefCount on the dataId and dispose the memory if the dataId
88734 * has 0 refCount. If there are pending read on the data, the disposal would
88735 * added to the pending delete queue. Return true if the dataId is removed
88736 * from backend or the backend does not contain the dataId, false if the
88737 * dataId is not removed. Memory may or may not be released even when dataId
88738 * is removed, which also depends on dataRefCount, see `releaseGPU`.
88739 * @param dataId
88740 * @oaram force Optional, remove the data regardless of refCount
88741 */
88742 disposeData(dataId, force = false) {
88743 if (this.pendingDisposal.has(dataId)) {
88744 return false;
88745 }
88746 // No-op if already disposed.
88747 if (!this.texData.has(dataId)) {
88748 return true;
88749 }
88750 // if force flag is set, change refCount to 0, this would ensure disposal
88751 // when added to the pendingDisposal queue. Memory may or may not be
88752 // released, which also depends on dataRefCount, see `releaseGPU`.
88753 if (force) {
88754 this.texData.get(dataId).refCount = 0;
88755 }
88756 else {
88757 this.texData.get(dataId).refCount--;
88758 }
88759 if (!force && this.texData.get(dataId).refCount > 0) {
88760 return false;
88761 }
88762 if (this.pendingRead.has(dataId)) {
88763 this.pendingDisposal.add(dataId);
88764 this.pendingDeletes++;
88765 return false;
88766 }
88767 this.releaseGPUData(dataId);
88768 const { complexTensorInfos } = this.texData.get(dataId);
88769 if (complexTensorInfos != null) {
88770 this.disposeData(complexTensorInfos.real.dataId, force);
88771 this.disposeData(complexTensorInfos.imag.dataId, force);
88772 }
88773 this.texData.delete(dataId);
88774 return true;
88775 }
        /**
         * Releases the GPU texture behind `dataId`. For shallow slices the
         * texture is shared and tracked in `dataRefCount` under the original
         * tensor's dataId; the texture is only returned to the texture manager
         * when the last sharer is released. Always clears this entry's
         * texture-related fields.
         */
        releaseGPUData(dataId) {
            const { texture, dtype, texShape, usage, isPacked, slice } = this.texData.get(dataId);
            // Shared-texture bookkeeping is keyed by the original tensor's id.
            const key = slice && slice.origDataId || dataId;
            const refCount = this.dataRefCount.get(key);
            if (refCount > 1) {
                // Other slices still reference this texture; just decrement.
                this.dataRefCount.set(key, refCount - 1);
            }
            else {
                this.dataRefCount.delete(key);
                if (texture != null) {
                    this.numBytesInGPU -= this.computeBytes(texShape, dtype);
                    this.textureManager.releaseTexture(texture, texShape, usage, isPacked);
                }
            }
            const texData = this.texData.get(dataId);
            texData.texture = null;
            texData.texShape = null;
            texData.isPacked = false;
            texData.slice = null;
        }
88796 getTexture(dataId) {
88797 this.uploadToGPU(dataId);
88798 return this.texData.get(dataId).texture.texture;
88799 }
88800 /**
88801 * Returns internal information for the specific data bucket. Used in unit
88802 * tests.
88803 */
88804 getDataInfo(dataId) {
88805 return this.texData.get(dataId);
88806 }
88807 /*
88808 Tests whether all the inputs to an op are small and on the CPU. This heuristic
88809 determines when it would be faster to execute a kernel on the CPU. WebGL
88810 kernels opt into running this check and forwarding when appropriate.
88811 TODO(https://github.com/tensorflow/tfjs/issues/872): Develop a more
88812 sustainable strategy for optimizing backend execution of ops.
88813 */
88814 shouldExecuteOnCPU(inputs, sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD) {
88815 return env().getBool('WEBGL_CPU_FORWARD') &&
88816 inputs.every(input => this.texData.get(input.dataId).texture == null &&
88817 sizeFromShape(input.shape) < sizeThreshold);
88818 }
        /** Returns the backing GPGPUContext of this backend. */
        getGPGPUContext() {
            return this.gpgpu;
        }
88822 where(condition) {
88823 warn('tf.where() in webgl locks the UI thread. ' +
88824 'Call tf.whereAsync() instead');
88825 const condVals = condition.dataSync();
88826 return whereImpl(condition.shape, condVals);
88827 }
88828 packedUnaryOp(x, op, dtype) {
88829 const program = new UnaryOpPackedProgram(x.shape, op);
88830 const outInfo = this.compileAndRun(program, [x], dtype);
88831 return engine().makeTensorFromTensorInfo(outInfo);
88832 }
88833 // TODO(msoulanille) remove this once the backend has been modularized
88834 // a copy is needed here to break a circular dependency.
88835 // Also remove the op from unary_op.
88836 abs(x) {
88837 // TODO: handle cases when x is complex.
88838 if (this.shouldExecuteOnCPU([x]) && x.dtype !== 'complex64') {
88839 const outValues = simpleAbsImplCPU(this.texData.get(x.dataId).values);
88840 return this.makeOutput(x.shape, x.dtype, outValues);
88841 }
88842 if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
88843 return this.packedUnaryOp(x, ABS$1, x.dtype);
88844 }
88845 const program = new UnaryOpProgram(x.shape, ABS$1);
88846 const outInfo = this.compileAndRun(program, [x]);
88847 return engine().makeTensorFromTensorInfo(outInfo);
88848 }
88849 makeTensorInfo(shape, dtype, values) {
88850 let dataId;
88851 if (dtype === 'string' && values != null && values.length > 0 &&
88852 isString(values[0])) {
88853 const encodedValues = values.map(d => encodeString(d));
88854 dataId = this.write(encodedValues, shape, dtype);
88855 }
88856 else {
88857 dataId = this.write(values, shape, dtype);
88858 }
88859 this.texData.get(dataId).usage = null;
88860 return { dataId, shape, dtype };
88861 }
88862 makeOutput(shape, dtype, values) {
88863 return engine().makeTensorFromTensorInfo(this.makeTensorInfo(shape, dtype, values), this);
88864 }
88865 unpackTensor(input) {
88866 const program = new UnpackProgram(input.shape);
88867 return this.runWebGLProgram(program, [input], input.dtype);
88868 }
88869 packTensor(input) {
88870 const program = new PackProgram(input.shape);
88871 const preventEagerUnpackingOutput = true;
88872 return this.runWebGLProgram(program, [input], input.dtype, null /* customUniformValues */, preventEagerUnpackingOutput);
88873 }
        /**
         * Reshapes a packed tensor via the packed-reshape shader. Both shapes
         * are first collapsed to 3D ([batch, rows, cols]) because packing
         * constraints tie texel layout to the innermost two dimensions.
         * Returns a TensorInfo with `afterShape` sharing the program's output.
         */
        packedReshape(input, afterShape) {
            const input3DShape = [
                getBatchDim(input.shape),
                ...getRowsCols(input.shape)
            ];
            const input3D = {
                dtype: input.dtype,
                shape: input3DShape,
                dataId: input.dataId
            };
            const afterShapeAs3D = [
                getBatchDim(afterShape), ...getRowsCols(afterShape)
            ];
            const program = new ReshapePackedProgram(afterShapeAs3D, input3DShape);
            // Keep the output packed; callers expect the packed layout.
            const preventEagerUnpackingOfOutput = true;
            const customValues = [input3DShape];
            const output = this.runWebGLProgram(program, [input3D], input.dtype, customValues, preventEagerUnpackingOfOutput);
            return { dataId: output.dataId, shape: afterShape, dtype: output.dtype };
        }
        /**
         * Re-encodes `dataId` into a densely packed texture (all four channels
         * used) suitable for direct download or hand-off to callers.
         * @param customTexShape Optional [rows, cols] for the output texture;
         *     must satisfy rows * cols * 4 >= tensor size.
         */
        decode(dataId, customTexShape) {
            const texData = this.texData.get(dataId);
            const { isPacked, shape, dtype } = texData;
            if (customTexShape != null) {
                const size = sizeFromShape(shape);
                const texSize = customTexShape[0] * customTexShape[1] * 4;
                assert$1(size <= texSize, () => 'customTexShape is too small. ' +
                    'Row * Column * 4 should be equal or larger than the ' +
                    'size of the tensor data.');
            }
            const shapeAs3D = getShapeAs3D(shape);
            // Pick the decode shader matching the source texture layout.
            let program;
            if (isPacked) {
                program = new DecodeMatrixPackedProgram(shapeAs3D);
            }
            else {
                program = new DecodeMatrixProgram(shapeAs3D);
            }
            const preventEagerUnpackingOfOutput = true;
            // The uniform carries the dense output texture dimensions.
            const customValues = [customTexShape != null ? customTexShape :
                    getDenseTexShape(shapeAs3D)];
            const out = this.runWebGLProgram(program, [{ shape: shapeAs3D, dtype, dataId }], dtype, customValues, preventEagerUnpackingOfOutput, customTexShape);
            return { dtype, shape, dataId: out.dataId };
        }
        /**
         * Compiles (with caching) and runs a GPGPU program over `inputs`,
         * returning a new output TensorInfo of `outputDtype`.
         * Handles: empty outputs (short-circuit), uniform upload of small
         * unpacked inputs, packing/unpacking inputs to match the program's
         * expected layout, shader-binary caching, optional kernel timing,
         * periodic manual GL flushes, and (unless prevented) eager unpacking
         * of packed outputs.
         */
        runWebGLProgram(program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput = false, customTexShape) {
            const output = this.makeTensorInfo(program.outputShape, outputDtype);
            const outData = this.texData.get(output.dataId);
            if (program.packedOutput) {
                outData.isPacked = true;
            }
            if (program.outPackingScheme === PackingScheme.DENSE) {
                const texelShape = customTexShape != null ?
                    customTexShape :
                    getDenseTexShape(program.outputShape);
                // For a densely packed output, we explicitly set texShape
                // so it doesn't get assigned later according to our typical packing
                // scheme wherein a single texel can only contain values from adjacent
                // rows/cols.
                outData.texShape = texelShape.map(d => d * 2);
            }
            if (program.outTexUsage != null) {
                outData.usage = program.outTexUsage;
            }
            if (sizeFromShape(output.shape) === 0) {
                // Short-circuit the computation since the result is empty (has 0 in its
                // shape).
                outData.values =
                    getTypedArrayFromDType(output.dtype, 0);
                return output;
            }
            // Intermediates created while adapting input layouts; disposed
            // after the program runs.
            const dataToDispose = [];
            const inputsData = inputs.map(input => {
                if (input.dtype === 'complex64') {
                    throw new Error(`GPGPUProgram does not support complex64 input. For complex64 ` +
                        `dtypes, please separate the program into real and imaginary ` +
                        `parts.`);
                }
                let texData = this.texData.get(input.dataId);
                if (texData.texture == null) {
                    if (!program.packedInputs &&
                        sizeFromShape(input.shape) <=
                            env().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')) {
                        // Upload small tensors that live on the CPU as uniforms, not as
                        // textures. Do this only when the environment supports 32bit floats
                        // due to problems when comparing 16bit floats with 32bit floats.
                        // TODO(https://github.com/tensorflow/tfjs/issues/821): Make it
                        // possible for packed shaders to sample from uniforms.
                        return {
                            shape: input.shape,
                            texData: null,
                            isUniform: true,
                            uniformValues: texData.values
                        };
                    }
                    // This ensures that if a packed program's inputs have not yet been
                    // uploaded to the GPU, they get uploaded as packed right off the bat.
                    if (program.packedInputs) {
                        texData.isPacked = true;
                        texData.shape = input.shape;
                    }
                }
                this.uploadToGPU(input.dataId);
                if (!!texData.isPacked !== !!program.packedInputs) {
                    // Layout mismatch: convert the input and dispose the copy later.
                    input = texData.isPacked ? this.unpackTensor(input) :
                        this.packTensor(input);
                    dataToDispose.push(input);
                    texData = this.texData.get(input.dataId);
                }
                else if (texData.isPacked &&
                    !isReshapeFree(texData.shape, input.shape)) {
                    // This is a special case where a texture exists for a tensor
                    // but the shapes are incompatible (due to packing constraints) because
                    // the tensor did not have a chance to go through the packed reshape
                    // shader. This only happens when we reshape the *same* tensor to form
                    // *distinct* inputs to an op, e.g. dotting a vector with itself. This
                    // case will disappear once packed uploading is the default.
                    const savedInput = input;
                    const targetShape = input.shape;
                    input.shape = texData.shape;
                    input = this.packedReshape(input, targetShape);
                    dataToDispose.push(input);
                    texData = this.texData.get(input.dataId);
                    savedInput.shape = targetShape;
                }
                return { shape: input.shape, texData, isUniform: false };
            });
            this.uploadToGPU(output.dataId);
            const outputData = { shape: output.shape, texData: outData, isUniform: false };
            // Shader binaries are cached by a key derived from the program and
            // the input/output layouts.
            const key = makeShaderKey(program, inputsData, outputData);
            const binary = this.getAndSaveBinary(key, () => {
                return compileProgram(this.gpgpu, program, inputsData, outputData);
            });
            const shouldTimeProgram = this.activeTimers != null;
            let query;
            if (shouldTimeProgram) {
                query = this.startTimer();
            }
            if (!env().get('ENGINE_COMPILE_ONLY')) {
                runProgram(this.gpgpu, binary, inputsData, outputData, customUniformValues);
            }
            dataToDispose.forEach(info => this.disposeIntermediateTensorInfo(info));
            if (shouldTimeProgram) {
                query = this.endTimer(query);
                this.activeTimers.push({ name: program.constructor.name, query: this.getQueryTime(query) });
            }
            const glFlushThreshold = env().getNumber('WEBGL_FLUSH_THRESHOLD');
            // Manually GL flush requested
            if (glFlushThreshold > 0) {
                const time = now();
                if ((time - this.lastGlFlushTime) > glFlushThreshold) {
                    this.gpgpu.gl.flush();
                    this.lastGlFlushTime = time;
                }
            }
            if (!env().getBool('WEBGL_LAZILY_UNPACK') && outData.isPacked &&
                preventEagerUnpackingOfOutput === false) {
                // Eagerly unpack the packed output for downstream consumers.
                const unpacked = this.unpackTensor(output);
                this.disposeIntermediateTensorInfo(output);
                return unpacked;
            }
            return output;
        }
89035 compileAndRun(program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput = false) {
89036 outputDtype = outputDtype || inputs[0].dtype;
89037 const outInfo = this.runWebGLProgram(program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput);
89038 return outInfo;
89039 }
89040 getAndSaveBinary(key, getBinary) {
89041 if (!(key in this.binaryCache)) {
89042 this.binaryCache[key] = getBinary();
89043 }
89044 return this.binaryCache[key];
89045 }
    /** Returns this backend's TextureManager instance. */
    getTextureManager() {
        return this.textureManager;
    }
    /**
     * Releases all GPU resources held by this backend: cached compiled
     * shader programs, pooled textures, the backing canvas, and (when this
     * backend created it) the GPGPU context itself. Safe to call more than
     * once; subsequent calls are no-ops.
     */
    dispose() {
        if (this.disposed) {
            return;
        }
        // Avoid disposing the compiled webgl programs during unit testing because
        // it slows down test execution.
        if (!env().getBool('IS_TEST')) {
            const allKeys = Object.keys(this.binaryCache);
            allKeys.forEach(key => {
                this.gpgpu.deleteProgram(this.binaryCache[key].webGLProgram);
                delete this.binaryCache[key];
            });
        }
        this.textureManager.dispose();
        // Detach the canvas from the DOM when it is a real HTMLCanvasElement
        // (browser environment); otherwise just drop the reference so it can
        // be garbage collected.
        if (this.canvas != null &&
            (typeof (HTMLCanvasElement) !== 'undefined' &&
                this.canvas instanceof HTMLCanvasElement)) {
            this.canvas.remove();
        }
        else {
            this.canvas = null;
        }
        // Only tear down the GPGPU context if this backend created it;
        // externally supplied contexts remain the caller's responsibility.
        if (this.gpgpuCreatedLocally) {
            this.gpgpu.program = null;
            this.gpgpu.dispose();
        }
        this.disposed = true;
    }
    /**
     * Probes the effective float precision of this WebGL context (16 or 32)
     * and memoizes the answer in `floatPrecisionValue`.
     *
     * When WEBGL_RENDER_FLOAT32_ENABLED is false, a tiny value (1e-8, which
     * underflows to 0 in float16) is pushed through `abs`; a non-zero
     * readback means the device computed in 32-bit despite the flag.
     *
     * NOTE(review): when WEBGL_RENDER_FLOAT32_ENABLED is true this falls
     * straight through to 16, which looks inverted — confirm against
     * upstream intent before relying on the 32-bit path here.
     */
    floatPrecision() {
        if (this.floatPrecisionValue == null) {
            this.floatPrecisionValue = tidy(() => {
                if (!env().get('WEBGL_RENDER_FLOAT32_ENABLED')) {
                    // Momentarily switching DEBUG flag to false so we don't throw an
                    // error trying to upload a small value.
                    const debugFlag = env().getBool('DEBUG');
                    env().set('DEBUG', false);
                    const underflowCheckValue = this.abs(scalar(1e-8)).dataSync()[0];
                    env().set('DEBUG', debugFlag);
                    if (underflowCheckValue > 0) {
                        return 32;
                    }
                }
                return 16;
            });
        }
        return this.floatPrecisionValue;
    }
    /**
     * Returns the smallest representable number for the active float
     * precision (float32 vs float16 textures).
     */
    epsilon() {
        return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16;
    }
    /**
     * Ensures the tensor identified by `dataId` has its data resident in a
     * GPU texture. No-op when a texture already exists. CPU-resident values
     * are encoded into a texture by running an EncodeMatrix(Packed) program;
     * when there are no CPU values, an empty texture of the right physical
     * shape is acquired for later writes.
     */
    uploadToGPU(dataId) {
        const texData = this.texData.get(dataId);
        const { shape, dtype, values, texture, usage, isPacked } = texData;
        if (texture != null) {
            // Array is already on GPU. No-op.
            return;
        }
        const shouldTimeProgram = this.activeTimers != null;
        let start;
        if (shouldTimeProgram) {
            start = now();
        }
        let texShape = texData.texShape;
        if (texShape == null) {
            // This texShape may not be the final texture shape. For packed or dense
            // textures, the texShape will be changed when textures are created.
            texShape = getTextureShapeFromLogicalShape(shape, isPacked);
            texData.texShape = texShape;
        }
        if (values != null) {
            const shapeAs3D = getShapeAs3D(shape);
            let program;
            let width = texShape[1], height = texShape[0];
            const isByteArray = values instanceof Uint8Array || values instanceof Uint8ClampedArray;
            // texture for float array is PhysicalTextureType.PACKED_2X2_FLOAT32, we
            // need to make sure the upload uses the same packed size
            if (isPacked || !isByteArray) {
                [width, height] = getPackedMatrixTextureShapeWidthHeight(texShape[0], texShape[1]);
            }
            if (isPacked) {
                program = new EncodeMatrixPackedProgram(shapeAs3D, isByteArray);
            }
            else {
                program = new EncodeMatrixProgram(shapeAs3D, isByteArray);
            }
            // TexShape for float array needs to be the original shape, which byte
            // array needs to be packed size. This allow the data upload shape to be
            // matched with texture creation logic.
            const tempDenseInputTexShape = isByteArray ? [height, width] : texShape;
            const tempDenseInputHandle = this.makeTensorInfo(tempDenseInputTexShape, dtype);
            const tempDenseInputTexData = this.texData.get(tempDenseInputHandle.dataId);
            if (isByteArray) {
                tempDenseInputTexData.usage = TextureUsage.PIXELS;
            }
            else {
                tempDenseInputTexData.usage = TextureUsage.UPLOAD;
            }
            tempDenseInputTexData.texShape = tempDenseInputTexShape;
            // Raw values go into a temporary "dense" texture first...
            this.gpgpu.uploadDenseMatrixToTexture(this.getTexture(tempDenseInputHandle.dataId), width, height, values);
            const customValues = [[height, width]];
            // We want the output to remain packed regardless of the value of
            // WEBGL_PACK.
            const preventEagerUnpacking = true;
            // ...then an encode program rewrites them into the backend's
            // logical texture layout.
            const encodedOutputTarget = this.runWebGLProgram(program, [tempDenseInputHandle], dtype, customValues, preventEagerUnpacking);
            // Have the original texture assume the identity of the encoded output.
            const outputTexData = this.texData.get(encodedOutputTarget.dataId);
            texData.texShape = outputTexData.texShape;
            texData.isPacked = outputTexData.isPacked;
            texData.usage = outputTexData.usage;
            if (!env().get('ENGINE_COMPILE_ONLY')) {
                texData.texture = outputTexData.texture;
                // Once uploaded, don't store the values on cpu.
                texData.values = null;
                this.texData.delete(encodedOutputTarget.dataId);
            }
            else {
                // Compile-only mode: no real texture was produced, so release
                // the encoded output instead of adopting it.
                this.disposeData(encodedOutputTarget.dataId);
            }
            this.disposeIntermediateTensorInfo(tempDenseInputHandle);
            if (shouldTimeProgram) {
                this.uploadWaitMs += now() - start;
            }
        }
        else {
            const newTexture = this.acquireTexture(texShape, usage, dtype, isPacked);
            texData.texture = newTexture;
        }
    }
89178 convertAndCacheOnCPU(dataId, float32Values) {
89179 const texData = this.texData.get(dataId);
89180 const { dtype } = texData;
89181 if (float32Values != null) {
89182 texData.values = float32ToTypedArray(float32Values, dtype);
89183 }
89184 return texData.values;
89185 }
89186 acquireTexture(texShape, texType, dtype, isPacked) {
89187 this.numBytesInGPU += this.computeBytes(texShape, dtype);
89188 if (!this.warnedAboutMemory &&
89189 this.numBytesInGPU > this.numMBBeforeWarning * 1024 * 1024) {
89190 const mb = (this.numBytesInGPU / 1024 / 1024).toFixed(2);
89191 this.warnedAboutMemory = true;
89192 console.warn(`High memory usage in GPU: ${mb} MB, ` +
89193 `most likely due to a memory leak`);
89194 }
89195 return this.textureManager.acquireTexture(texShape, texType, isPacked);
89196 }
89197 computeBytes(shape, dtype) {
89198 return shape[0] * shape[1] * bytesPerElement(dtype);
89199 }
89200 checkCompileCompletion() {
89201 for (const [, binary] of Object.entries(this.binaryCache)) {
89202 this.checkCompletion_(binary);
89203 }
89204 }
89205 async checkCompileCompletionAsync() {
89206 const ps = [];
89207 if (this.gpgpu.parallelCompilationExtension) {
89208 for (const [, binary] of Object.entries(this.binaryCache)) {
89209 ps.push(this.checkCompletionAsync_(binary));
89210 }
89211 return Promise.all(ps);
89212 }
89213 else {
89214 for (const [, binary] of Object.entries(this.binaryCache)) {
89215 const p = new Promise((resolve) => {
89216 try {
89217 this.checkCompletion_(binary);
89218 resolve(true);
89219 }
89220 catch (error) {
89221 throw error;
89222 }
89223 });
89224 ps.push(p);
89225 }
89226 return Promise.all(ps);
89227 }
89228 }
89229 async checkCompletionAsync_(binary) {
89230 if (this.gpgpu.gl.getProgramParameter(binary.webGLProgram, this.gpgpu.parallelCompilationExtension.COMPLETION_STATUS_KHR)) {
89231 return this.checkCompletion_(binary);
89232 }
89233 else {
89234 await nextFrame();
89235 return this.checkCompletionAsync_(binary);
89236 }
89237 }
    /**
     * Synchronously checks that `binary`'s WebGL program linked. On failure,
     * logs the program info log; if the fragment shader itself failed to
     * compile, dumps its source alongside the shader info log before
     * throwing.
     *
     * @returns true when the program linked successfully.
     * @throws Error when the fragment shader failed to compile or the
     *     program failed to link.
     */
    checkCompletion_(binary) {
        if (this.gpgpu.gl.getProgramParameter(binary.webGLProgram, this.gpgpu.gl.LINK_STATUS) === false) {
            console.log(this.gpgpu.gl.getProgramInfoLog(binary.webGLProgram));
            if (this.gpgpu.gl.getShaderParameter(binary.fragmentShader, this.gpgpu.gl.COMPILE_STATUS) === false) {
                logShaderSourceAndInfoLog(binary.source, this.gpgpu.gl.getShaderInfoLog(binary.fragmentShader));
                throw new Error('Failed to compile fragment shader.');
            }
            throw new Error('Failed to link vertex and fragment shaders.');
        }
        return true;
    }
    /**
     * After (possibly parallel) compilation has finished, builds each cached
     * program's VAO and resolves all of its uniform locations, storing them
     * directly on the binary for use at run time.
     */
    getUniformLocations() {
        for (const binary of Object.values(this.binaryCache)) {
            // TODO: Iterating through all binaries to build VAOs is supposed to be in
            // a separate function, like 'setVaos'. However, to avoid breaking changes
            // for the users using parallel compile feature now, buildVao is silently
            // added here.
            this.gpgpu.buildVao(binary.webGLProgram);
            const { variablesLocations, customUniformLocations, infLoc, nanLoc, outShapeLocation, outShapeStridesLocation, outTexShapeLocation } = getUniformLocations(this.gpgpu, binary.program, binary.webGLProgram);
            binary.variablesLocations = variablesLocations;
            binary.customUniformLocations = customUniformLocations;
            binary.infLoc = infLoc;
            binary.nanLoc = nanLoc;
            binary.outShapeLocation = outShapeLocation;
            binary.outShapeStridesLocation = outShapeStridesLocation;
            binary.outTexShapeLocation = outTexShapeLocation;
        }
    }
    /**
     * Create a TF.js tensor out of an existing WebGL texture. A new texture
     * will be created.
     *
     * @param values Object holding the source `texture`, its `height`,
     *     `width`, and optional `channels` (defaults to 'RGBA'; note that
     *     `values.channels` is mutated in place when absent).
     * @param shape Logical shape of the resulting tensor.
     * @param dtype Dtype of the resulting tensor.
     * @throws Error when `texture` does not belong to the current backend's
     *     GL context.
     */
    createTensorFromGPUData(values, shape, dtype) {
        values.channels = values.channels || 'RGBA';
        const { texture, height, width, channels } = values;
        const backend = engine().backend;
        // Have to throw an error, otherwise WebGL just warns and returns wrong
        // values.
        if (!backend.gpgpu.gl.isTexture(texture)) {
            throw new Error(`The texture is invalid. Also, please make sure the texture and ` +
                `the TFJS WebGL backend are using the same canvas. If you want to ` +
                `use your own custom canvas, you have to create and use the custom ` +
                `TFJS WebGL backend created from the canvas through ` +
                `'new tf.MathBackendWebGL(customCanvas)'.`);
        }
        const dataId = backend.writeTexture(texture, shape, dtype, height, width, channels);
        return engine().makeTensorFromDataId(dataId, shape, dtype, backend);
    }
89286 }
// Static seed for per-tensor data ids (presumably incremented as new data is
// registered elsewhere in the class — confirm against the full class body).
MathBackendWebGL.nextDataId = 0;
/**
 * Converts a Float32Array of decoded GPU values to the typed array matching
 * `dtype`. Float/complex data is returned unchanged (same reference);
 * int32/bool values are rounded into a fresh Int32Array/Uint8Array.
 *
 * @throws Error for any other dtype.
 */
function float32ToTypedArray(a, dtype) {
    switch (dtype) {
        case 'float32':
        case 'complex64':
            return a;
        case 'int32':
        case 'bool': {
            const out = dtype === 'int32' ? new Int32Array(a.length) :
                new Uint8Array(a.length);
            for (let i = 0; i < out.length; ++i) {
                out[i] = Math.round(a[i]);
            }
            return out;
        }
        default:
            throw new Error(`Unknown dtype ${dtype}`);
    }
}
89304
/** @license See the LICENSE file. */
// This code is auto-generated, do not modify this file!
// Version string of the bundled tfjs-backend-webgl package.
const version$2 = '4.22.0';
89308
89309 /**
89310 * @license
89311 * Copyright 2019 Google LLC. All Rights Reserved.
89312 * Licensed under the Apache License, Version 2.0 (the "License");
89313 * you may not use this file except in compliance with the License.
89314 * You may obtain a copy of the License at
89315 *
89316 * http://www.apache.org/licenses/LICENSE-2.0
89317 *
89318 * Unless required by applicable law or agreed to in writing, software
89319 * distributed under the License is distributed on an "AS IS" BASIS,
89320 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89321 * See the License for the specific language governing permissions and
89322 * limitations under the License.
89323 * =============================================================================
89324 */
89325 /**
89326 * Enforce use of half precision textures if available on the platform.
89327 *
89328 * @doc {heading: 'Environment', namespace: 'webgl'}
89329 */
function forceHalfFloat() {
    // Sets the env flag that requests half-precision (16-bit) float
    // textures; the flag is consumed by the WebGL environment elsewhere.
    env().set('WEBGL_FORCE_F16_TEXTURES', true);
}
89333
89334 /**
89335 * @license
89336 * Copyright 2020 Google Inc. All Rights Reserved.
89337 * Licensed under the Apache License, Version 2.0 (the "License");
89338 * you may not use this file except in compliance with the License.
89339 * You may obtain a copy of the License at
89340 *
89341 * http://www.apache.org/licenses/LICENSE-2.0
89342 *
89343 * Unless required by applicable law or agreed to in writing, software
89344 * distributed under the License is distributed on an "AS IS" BASIS,
89345 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89346 * See the License for the specific language governing permissions and
89347 * limitations under the License.
89348 * =============================================================================
89349 */
// Register the WebGL backend with the global engine when running in a
// browser, at priority 2.
if (isBrowser()) {
    registerBackend('webgl', () => new MathBackendWebGL(), 2 /* priority */);
}
// Public `tf.webgl` namespace.
const webgl = { forceHalfFloat };
89354
89355 /**
89356 * @license
89357 * Copyright 2017 Google LLC. All Rights Reserved.
89358 * Licensed under the Apache License, Version 2.0 (the "License");
89359 * you may not use this file except in compliance with the License.
89360 * You may obtain a copy of the License at
89361 *
89362 * http://www.apache.org/licenses/LICENSE-2.0
89363 *
89364 * Unless required by applicable law or agreed to in writing, software
89365 * distributed under the License is distributed on an "AS IS" BASIS,
89366 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89367 * See the License for the specific language governing permissions and
89368 * limitations under the License.
89369 * =============================================================================
89370 */
// GLSL snippet for binary ops: propagate NaN from either operand unchanged.
const CHECK_NAN_SNIPPET = `
  if (isnan(a)) return a;
  if (isnan(b)) return b;
`;
// GLSL body for the SquaredDifference kernel.
const SQUARED_DIFFERENCE$1 = 'return (a - b) * (a - b);';
/**
 * GPGPU program computing an element-wise binary op over two unpacked float
 * textures, with broadcasting. `op` is the GLSL statement body of
 * `float binaryOperation(float a, float b)`.
 */
class BinaryOpProgram {
    constructor(op, aShape, bShape) {
        this.variableNames = ['A', 'B'];
        // The output shape is the broadcast of the two input shapes.
        this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        this.userCode = `
      float binaryOperation(float a, float b) {
        ${op}
      }

      void main() {
        float a = getAAtOutCoords();
        float b = getBAtOutCoords();
        setOutput(binaryOperation(a, b));
      }
    `;
    }
}
89394
89395 /**
89396 * @license
89397 * Copyright 2018 Google LLC. All Rights Reserved.
89398 * Licensed under the Apache License, Version 2.0 (the "License");
89399 * you may not use this file except in compliance with the License.
89400 * You may obtain a copy of the License at
89401 *
89402 * http://www.apache.org/licenses/LICENSE-2.0
89403 *
89404 * Unless required by applicable law or agreed to in writing, software
89405 * distributed under the License is distributed on an "AS IS" BASIS,
89406 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89407 * See the License for the specific language governing permissions and
89408 * limitations under the License.
89409 * =============================================================================
89410 */
// Packed (vec4) variant of the NaN propagation snippet: each of the four
// lanes is replaced by NAN when the corresponding `isNaN` lane is set.
const CHECK_NAN_SNIPPET_PACKED = `
  result.r = isNaN.r ? NAN : result.r;
  result.g = isNaN.g ? NAN : result.g;
  result.b = isNaN.b ? NAN : result.b;
  result.a = isNaN.a ? NAN : result.a;
`;
// GLSL body for the EluGrad kernel (packed).
const ELU_DER$1 = `
  vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.)));
  return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))));
`;
// GLSL body for the NotEqual kernel (packed).
const NOT_EQUAL$1 = `
  return vec4(notEqual(a, b));
`;
/**
 * GPGPU program computing an element-wise binary op over two packed
 * (2x2-values-per-texel) textures, with broadcasting. `op` is the GLSL body
 * of `vec4 binaryOperation(vec4 a, vec4 b)`.
 *
 * When `checkOutOfBounds` is true, lanes of the output texel that fall
 * outside the logical output shape are zeroed (a packed texel can straddle
 * the last row/column of the output).
 */
class BinaryOpPackedProgram {
    constructor(op, aShape, bShape, checkOutOfBounds = false) {
        this.variableNames = ['A', 'B'];
        this.supportsBroadcasting = true;
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
        const rank = this.outputShape.length;
        this.enableShapeUniforms = useShapeUniforms(rank);
        let checkOutOfBoundsString = '';
        if (checkOutOfBounds) {
            if (rank === 0 || sizeFromShape(this.outputShape) === 1) {
                // Scalar output: only the first lane of the texel is real.
                checkOutOfBoundsString = `
          result.y = 0.;
          result.z = 0.;
          result.w = 0.;
        `;
            }
            else {
                const dtype = getCoordsDataType(rank);
                checkOutOfBoundsString = `
          ${dtype} coords = getOutputCoords();
        `;
                if (rank === 1) {
                    if (this.enableShapeUniforms) {
                        checkOutOfBoundsString += `
            result.y = (coords + 1) >= outShape ? 0. : result.y;
            result.z = 0.;
            result.w = 0.;
          `;
                    }
                    else {
                        checkOutOfBoundsString += `
            result.y = (coords + 1) >= ${this.outputShape[0]} ? 0. : result.y;
            result.z = 0.;
            result.w = 0.;
          `;
                    }
                }
                else {
                    const channels = getChannels('coords', rank);
                    if (this.enableShapeUniforms) {
                        checkOutOfBoundsString += `
            bool nextRowOutOfBounds =
              (${channels[rank - 2]} + 1) >= outShape[${rank} - 2];
            bool nextColOutOfBounds =
              (${channels[rank - 1]} + 1) >= outShape[${rank} - 1];
            result.y = nextColOutOfBounds ? 0. : result.y;
            result.z = nextRowOutOfBounds ? 0. : result.z;
            result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;
          `;
                    }
                    else {
                        checkOutOfBoundsString += `
            bool nextRowOutOfBounds =
              (${channels[rank - 2]} + 1) >= ${this.outputShape[rank - 2]};
            bool nextColOutOfBounds =
              (${channels[rank - 1]} + 1) >= ${this.outputShape[rank - 1]};
            result.y = nextColOutOfBounds ? 0. : result.y;
            result.z = nextRowOutOfBounds ? 0. : result.z;
            result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;
          `;
                    }
                }
            }
        }
        this.userCode = `
      vec4 binaryOperation(vec4 a, vec4 b) {
        ${op}
      }

      void main() {
        vec4 a = getAAtOutCoords();
        vec4 b = getBAtOutCoords();

        vec4 result = binaryOperation(a, b);
        ${checkOutOfBoundsString}

        setOutput(result);
      }
    `;
    }
}
89507
89508 /**
89509 * @license
89510 * Copyright 2020 Google LLC. All Rights Reserved.
89511 * Licensed under the Apache License, Version 2.0 (the "License");
89512 * you may not use this file except in compliance with the License.
89513 * You may obtain a copy of the License at
89514 *
89515 * http://www.apache.org/licenses/LICENSE-2.0
89516 *
89517 * Unless required by applicable law or agreed to in writing, software
89518 * distributed under the License is distributed on an "AS IS" BASIS,
89519 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89520 * See the License for the specific language governing permissions and
89521 * limitations under the License.
89522 * =============================================================================
89523 */
/**
 * WebGL Identity kernel: shares the input's underlying data (bumping its
 * reference count on the backend) rather than copying it.
 */
function identity(args) {
    const { x } = args.inputs;
    args.backend.incRef(x.dataId);
    return { dataId: x.dataId, shape: x.shape, dtype: x.dtype };
}
// Kernel registration descriptor for Identity on the WebGL backend.
const identityConfig = {
    kernelName: Identity$1,
    backendName: 'webgl',
    kernelFunc: identity
};
89535
89536 /**
89537 * @license
89538 * Copyright 2020 Google LLC. All Rights Reserved.
89539 * Licensed under the Apache License, Version 2.0 (the "License");
89540 * you may not use this file except in compliance with the License.
89541 * You may obtain a copy of the License at
89542 *
89543 * http://www.apache.org/licenses/LICENSE-2.0
89544 *
89545 * Unless required by applicable law or agreed to in writing, software
89546 * distributed under the License is distributed on an "AS IS" BASIS,
89547 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89548 * See the License for the specific language governing permissions and
89549 * limitations under the License.
89550 * =============================================================================
89551 */
89552 /**
89553 * In WebGL data is stored in GPU textures which can't be efficiently copied, so
89554 * complex tensors share data with their real and imaginary components. Complex
89555 * tensors' reference to the components is tracked by refCount on the individual
89556 * component. The refCounts are increased by the identity call.
89557 *
89558 * When a complex tensor is disposed, it will reduce the refCount on the
89559 * components by calling disposeData on each.
89560 */
function complex(args) {
    const { inputs, backend } = args;
    const { real, imag } = inputs;
    // The complex tensor holds no data of its own; it records references to
    // its real and imaginary component tensors on its TexData.
    const complexInfo = backend.makeTensorInfo(real.shape, 'complex64');
    const complex = backend.texData.get(complexInfo.dataId);
    // identity() bumps the components' refCounts so they outlive the
    // caller-owned inputs (see the comment block above this function).
    const realTensorInfo = identity({ inputs: { x: real }, backend });
    const imagTensorInfo = identity({ inputs: { x: imag }, backend });
    complex.complexTensorInfos = { real: realTensorInfo, imag: imagTensorInfo };
    return complexInfo;
}
// Kernel registration descriptor for Complex on the WebGL backend.
const complexConfig = {
    kernelName: Complex,
    backendName: 'webgl',
    kernelFunc: complex
};
89576
89577 /**
89578 * @license
89579 * Copyright 2020 Google LLC. All Rights Reserved.
89580 * Licensed under the Apache License, Version 2.0 (the "License");
89581 * you may not use this file except in compliance with the License.
89582 * You may obtain a copy of the License at
89583 *
89584 * http://www.apache.org/licenses/LICENSE-2.0
89585 *
89586 * Unless required by applicable law or agreed to in writing, software
89587 * distributed under the License is distributed on an "AS IS" BASIS,
89588 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89589 * See the License for the specific language governing permissions and
89590 * limitations under the License.
89591 * =============================================================================
89592 */
// GLSL bodies for LeakyRelu: `a` is the input, `b` carries alpha.
const LEAKYRELU = `return (a < 0.) ? b * a : a;`;
const LEAKYRELU_PACKED = `
  vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));
  return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);
`;
/**
 * WebGL LeakyRelu kernel: computes `x < 0 ? alpha * x : x` by uploading
 * `alpha` as a scalar tensor and running a binary-op program against it.
 */
function leakyRelu(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { alpha } = attrs;
    const $alpha = backend.makeTensorInfo([], 'float32', createScalarValue(alpha, 'float32'));
    const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?
        new BinaryOpPackedProgram(LEAKYRELU_PACKED, x.shape, $alpha.shape) :
        new BinaryOpProgram(LEAKYRELU, x.shape, $alpha.shape);
    const result = backend.runWebGLProgram(program, [x, $alpha], 'float32');
    // The temporary alpha scalar is no longer needed once the program ran.
    backend.disposeIntermediateTensorInfo($alpha);
    return result;
}
// Kernel registration descriptor for LeakyRelu on the WebGL backend.
const leakyReluConfig = {
    kernelName: LeakyRelu,
    backendName: 'webgl',
    kernelFunc: leakyRelu
};
89615
89616 /**
89617 * @license
89618 * Copyright 2020 Google LLC. All Rights Reserved.
89619 * Licensed under the Apache License, Version 2.0 (the "License");
89620 * you may not use this file except in compliance with the License.
89621 * You may obtain a copy of the License at
89622 *
89623 * http://www.apache.org/licenses/LICENSE-2.0
89624 *
89625 * Unless required by applicable law or agreed to in writing, software
89626 * distributed under the License is distributed on an "AS IS" BASIS,
89627 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89628 * See the License for the specific language governing permissions and
89629 * limitations under the License.
89630 * =============================================================================
89631 */
// GLSL bodies for Prelu: `a` is the input, `b` the per-element alpha.
const PRELU = `return (a < 0.) ? b * a : a;`;
const PRELU_PACKED = `
  vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));
  return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);
`;
/**
 * WebGL Prelu kernel: computes `x < 0 ? alpha * x : x` with `alpha`
 * supplied as a tensor input (unlike LeakyRelu's scalar attribute).
 */
function prelu(args) {
    const { inputs, backend } = args;
    const { x, alpha } = inputs;
    const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?
        new BinaryOpPackedProgram(PRELU_PACKED, x.shape, alpha.shape) :
        new BinaryOpProgram(PRELU, x.shape, alpha.shape);
    return backend.runWebGLProgram(program, [x, alpha], 'float32');
}
// Kernel registration descriptor for Prelu on the WebGL backend.
const preluConfig = {
    kernelName: Prelu,
    backendName: 'webgl',
    kernelFunc: prelu
};
89650
89651 /**
89652 * @license
89653 * Copyright 2020 Google LLC. All Rights Reserved.
89654 * Licensed under the Apache License, Version 2.0 (the "License");
89655 * you may not use this file except in compliance with the License.
89656 * You may obtain a copy of the License at
89657 *
89658 * http://www.apache.org/licenses/LICENSE-2.0
89659 *
89660 * Unless required by applicable law or agreed to in writing, software
89661 * distributed under the License is distributed on an "AS IS" BASIS,
89662 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89663 * See the License for the specific language governing permissions and
89664 * limitations under the License.
89665 * =============================================================================
89666 */
// GLSL snippet for unary ops: propagate NaN inputs unchanged.
const CHECK_NAN_SNIPPET_UNARY = `if (isnan(x)) return x;`;
/**
 * Template that creates a `KernelFunc` for unary ops.
 *
 * @param opSnippet Op snippet to create `UnaryOpProgram`.
 * @param packedOpSnippet Op snippet to create `UnaryOpPackedProgram`.
 * @param cpuKernelImpl Optional CPU fallback invoked when the backend
 *     decides the input is cheaper to process on the CPU.
 * @param dtype Optional. If set, the result has this dtype. Otherwise, the
 *     result has the same dtype as the first input. This is mainly used in
 *     comparison kernels, such as Equal, Less, Greater, etc.
 */
function unaryKernelFunc({ opSnippet, packedOpSnippet, cpuKernelImpl, dtype }) {
    return ({ inputs, backend }) => {
        const { x } = inputs;
        const webglBackend = backend;
        const $dtype = dtype || x.dtype;
        // Forward small/CPU-resident tensors to the CPU implementation.
        if (webglBackend.shouldExecuteOnCPU([x]) && cpuKernelImpl != null) {
            const xData = webglBackend.texData.get(x.dataId);
            const outValues = cpuKernelImpl(xData.values, $dtype);
            return webglBackend.makeTensorInfo(x.shape, $dtype, outValues);
        }
        const usePacked = env().getBool('WEBGL_PACK_UNARY_OPERATIONS') &&
            packedOpSnippet != null;
        const program = usePacked ?
            new UnaryOpPackedProgram(x.shape, packedOpSnippet) :
            new UnaryOpProgram(x.shape, opSnippet);
        return webglBackend.runWebGLProgram(program, [x], $dtype);
    };
}
/**
 * Template that creates a `KernelFunc` for binary ops.
 * @param opSnippet Op snippet to create `BinaryOpProgram`.
 * @param packedOpSnippet Op snippet to create `BinaryOpPackedProgram`.
 * @param checkOutOfBounds Whether to set checkOutOfBounds=true
 *     when creating BinaryOpPackedProgram.
 * @param supportsComplex Whether complex64 inputs are supported by running
 *     the op separately on the real and imaginary components.
 * @param cpuKernelImpl Optional CPU fallback used for string tensors or when
 *     the backend decides the inputs are cheaper to process on the CPU.
 * @param dtype Optional. If set, the result has this dtype. Otherwise, the
 *     result has the same dtype as the first input. This is mainly used in
 *     comparison kernels, such as Equal, Less, Greater, etc.
 */
function binaryKernelFunc({ opSnippet, packedOpSnippet, checkOutOfBounds = false, supportsComplex = false, cpuKernelImpl, dtype }) {
    return ({ inputs, backend }) => {
        const { a, b } = inputs;
        const webglBackend = backend;
        if (supportsComplex && a.dtype === 'complex64') {
            const aData = webglBackend.texData.get(a.dataId);
            const bData = webglBackend.texData.get(b.dataId);
            // Run the (unpacked) program once on the real parts and once on
            // the imaginary parts, then recombine into a complex tensor.
            const [real, imag] = [
                [aData.complexTensorInfos.real, bData.complexTensorInfos.real],
                [aData.complexTensorInfos.imag, bData.complexTensorInfos.imag]
            ].map(complexParts => {
                const [aPart, bPart] = complexParts;
                // Component tensors inherit the logical shape of their parents.
                const aHandle = {
                    dataId: aPart.dataId,
                    dtype: aPart.dtype,
                    shape: a.shape
                };
                const bHandle = {
                    dataId: bPart.dataId,
                    dtype: bPart.dtype,
                    shape: b.shape
                };
                const program = new BinaryOpProgram(opSnippet, a.shape, b.shape);
                return webglBackend.runWebGLProgram(program, [aHandle, bHandle], upcastType(aPart.dtype, bPart.dtype));
            });
            const complexOutput = complex({ inputs: { real, imag }, backend: webglBackend });
            // The component results are owned by the complex tensor now.
            webglBackend.disposeIntermediateTensorInfo(real);
            webglBackend.disposeIntermediateTensorInfo(imag);
            // TODO(annxingyuan): Implement CPU forwarding for complex inputs.
            return complexOutput;
        }
        const $dtype = dtype || upcastType(a.dtype, b.dtype);
        if ((a.dtype === 'string' || b.dtype === 'string' ||
            webglBackend.shouldExecuteOnCPU([a, b])) &&
            cpuKernelImpl != null) {
            const aVals = webglBackend.texData.get(a.dataId).values;
            const bVals = webglBackend.texData.get(b.dataId).values;
            const decodedAVals = a.dtype === 'string' ?
                // tslint:disable-next-line: no-any
                fromUint8ToStringArray(aVals) :
                aVals;
            // NOTE: checks a.dtype (not b.dtype) — the inputs are presumably
            // dtype-matched by this point; confirm with callers.
            const decodedBVals = a.dtype === 'string' ?
                // tslint:disable-next-line: no-any
                fromUint8ToStringArray(bVals) :
                bVals;
            const [outValues, outShape] = cpuKernelImpl(a.shape, b.shape, decodedAVals, decodedBVals, $dtype);
            const out = webglBackend.makeTensorInfo(outShape, $dtype);
            const outData = webglBackend.texData.get(out.dataId);
            outData.values = outValues;
            return out;
        }
        const shouldUsePackedProgram = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') &&
            packedOpSnippet != null;
        let program;
        if (shouldUsePackedProgram) {
            program = new BinaryOpPackedProgram(packedOpSnippet, a.shape, b.shape, checkOutOfBounds);
        }
        else {
            program = new BinaryOpProgram(opSnippet, a.shape, b.shape);
        }
        return webglBackend.runWebGLProgram(program, [a, b], $dtype);
    };
}
/**
 * Maps a fused-activation name to the GLSL snippet implementing it, choosing
 * the packed (vec4) or unpacked (float) variant per `packed`.
 *
 * @throws Error for activations with no WebGL implementation.
 */
function mapActivationToShaderProgram(activation, packed = false) {
    switch (activation) {
        case 'linear':
            return packed ? LINEAR : LINEAR$1;
        case 'relu':
            return packed ? RELU$1 : RELU$2;
        case 'elu':
            return packed ? ELU$1 : ELU$2;
        case 'relu6':
            return packed ? RELU6$1 : RELU6$2;
        case 'prelu':
            return packed ? PRELU_PACKED : PRELU;
        case 'leakyrelu':
            return packed ? LEAKYRELU_PACKED : LEAKYRELU;
        case 'sigmoid':
            return packed ? SIGMOID$1 : SIGMOID$2;
        default:
            throw new Error(`Activation ${activation} has not been implemented for the WebGL backend.`);
    }
}
89815
89816 /**
89817 * @license
89818 * Copyright 2018 Google LLC. All Rights Reserved.
89819 * Licensed under the Apache License, Version 2.0 (the "License");
89820 * you may not use this file except in compliance with the License.
89821 * You may obtain a copy of the License at
89822 *
89823 * http://www.apache.org/licenses/LICENSE-2.0
89824 *
89825 * Unless required by applicable law or agreed to in writing, software
89826 * distributed under the License is distributed on an "AS IS" BASIS,
89827 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89828 * See the License for the specific language governing permissions and
89829 * limitations under the License.
89830 * =============================================================================
89831 */
/**
 * Packed WebGL program computing a batched matrix multiply, optionally with
 * transposed inputs, an added bias, and a fused activation. Inputs and output
 * use the packed (2x2-per-texel) layout, so the shared dimension is consumed
 * two elements per loop iteration.
 */
class MatMulPackedProgram {
    constructor(aShape, bShape, outputShape, transposeA = false, transposeB = false, addBias = false, activation = null, hasPreluActivation = false, hasLeakyreluActivation = false) {
        this.variableNames = ['matrixA', 'matrixB'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = outputShape;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        // Contraction dimension; each iteration handles a 2-wide slice, hence
        // the ceil(sharedDim / 2) trip count baked into the shader below.
        const sharedDim = transposeA ? aShape[1] : aShape[2];
        const sharedDimensionPacked = Math.ceil(sharedDim / 2);
        // Sampling coordinates into A and B; row/column roles swap when the
        // corresponding input is transposed.
        const aSample = transposeA ? 'i * 2, rc.y' : 'rc.y, i * 2';
        const bSample = transposeB ? 'rc.z, i * 2' : 'i * 2, rc.z';
        // Swizzle pairs that rearrange each packed texel so the two partial
        // products line up component-wise in the accumulator.
        const aSwizzle = transposeA ? ['a.xxyy', 'a.zzww'] : ['a.xxzz', 'a.yyww'];
        const bSwizzle = transposeB ? ['b.xzxz', 'b.ywyw'] : ['b.xyxy', 'b.zwzw'];
        let activationSnippet = '', applyActivationSnippet = '';
        if (activation) {
            if (hasPreluActivation) {
                // PRELU reads its per-element weights from an extra input texture.
                activationSnippet = `vec4 activation(vec4 a) {
                  vec4 b = getPreluActivationWeightsAtOutCoords();
                  ${activation}
                }`;
            }
            else if (hasLeakyreluActivation) {
                // Leaky-relu reads its alpha from an extra input texture.
                activationSnippet = `vec4 activation(vec4 a) {
                  vec4 b = getLeakyreluAlphaAtOutCoords();
                  ${activation}
                }`;
            }
            else {
                activationSnippet = `vec4 activation(vec4 x) {
                  ${activation}
                }`;
            }
            applyActivationSnippet = `result = activation(result);`;
        }
        const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
        if (addBias) {
            this.variableNames.push('bias');
        }
        if (hasPreluActivation) {
            this.variableNames.push('preluActivationWeights');
        }
        if (hasLeakyreluActivation) {
            this.variableNames.push('leakyreluAlpha');
        }
        // Broadcast the smaller batch dimension over the larger one via imod.
        let batchASnippet = 'rc.x';
        let batchBSnippet = 'rc.x';
        if (aShape[0] < bShape[0]) {
            batchASnippet = `imod(rc.x, ${aShape[0]})`;
        }
        else if (bShape[0] < aShape[0]) {
            batchBSnippet = `imod(rc.x, ${bShape[0]})`;
        }
        this.userCode = `
      ${activationSnippet}
      // Don't use uniform for sharedDimensionPacked for performance.
      const float sharedDimension = ${sharedDimensionPacked}.0;

      vec4 dot2x2ARowBCol(ivec3 rc) {
        vec4 result = vec4(0);
        int batchA = ${batchASnippet};
        int batchB = ${batchBSnippet};
        for (int i = 0; i < ${sharedDimensionPacked}; i++) {
          vec4 a = getMatrixA(batchA, ${aSample});
          vec4 b = getMatrixB(batchB, ${bSample});

          // These swizzled products need to be separately added.
          // See: https://github.com/tensorflow/tfjs/issues/1735
          result += (${aSwizzle[0]} * ${bSwizzle[0]});
          result += (${aSwizzle[1]} * ${bSwizzle[1]});
        }
        return result;
      }

      void main() {
        ivec3 rc = getOutputCoords();
        vec4 result = dot2x2ARowBCol(rc);

        ${addBiasSnippet}

        ${applyActivationSnippet}

        setOutput(result);
      }
    `;
    }
}
89918
89919 /**
89920 * @license
89921 * Copyright 2018 Google LLC. All Rights Reserved.
89922 * Licensed under the Apache License, Version 2.0 (the "License");
89923 * you may not use this file except in compliance with the License.
89924 * You may obtain a copy of the License at
89925 *
89926 * http://www.apache.org/licenses/LICENSE-2.0
89927 *
89928 * Unless required by applicable law or agreed to in writing, software
89929 * distributed under the License is distributed on an "AS IS" BASIS,
89930 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89931 * See the License for the specific language governing permissions and
89932 * limitations under the License.
89933 * =============================================================================
89934 */
// Complex multiplication:
// (Ar + Ai·i)(Br + Bi·i) =
// ArBr + ArBi·i + AiBr·i + AiBi·i² = (ArBr - AiBi) + (ArBi + AiBr)·i
// Yr = ArBr - AiBi
// Yi = ArBi + AiBr
// GLSL snippets for the real and imaginary components of the product.
const COMPLEX_MULTIPLY = {
    REAL: 'return areal * breal - aimag * bimag;',
    IMAG: 'return areal * bimag + aimag * breal;'
};
/**
 * WebGL program evaluating one component (real or imaginary) of an
 * element-wise binary op over complex inputs. `op` is a GLSL snippet
 * (e.g. COMPLEX_MULTIPLY.REAL) with areal/aimag/breal/bimag in scope.
 */
class BinaryOpComplexProgram {
    constructor(op, aShape, bShape) {
        // The four inputs are the real/imaginary planes of both operands.
        this.variableNames = ['AReal', 'AImag', 'BReal', 'BImag'];
        // Output shape follows standard broadcasting of the two input shapes.
        this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
        this.userCode = `
      float binaryOpComplex(
          float areal, float aimag, float breal, float bimag) {
        ${op}
      }

      void main() {
        float areal = getARealAtOutCoords();
        float aimag = getAImagAtOutCoords();
        float breal = getBRealAtOutCoords();
        float bimag = getBImagAtOutCoords();
        setOutput(binaryOpComplex(areal, aimag, breal, bimag));
      }
    `;
    }
}
89963
89964 /**
89965 * @license
89966 * Copyright 2020 Google LLC. All Rights Reserved.
89967 * Licensed under the Apache License, Version 2.0 (the "License");
89968 * you may not use this file except in compliance with the License.
89969 * You may obtain a copy of the License at
89970 *
89971 * http://www.apache.org/licenses/LICENSE-2.0
89972 *
89973 * Unless required by applicable law or agreed to in writing, software
89974 * distributed under the License is distributed on an "AS IS" BASIS,
89975 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
89976 * See the License for the specific language governing permissions and
89977 * limitations under the License.
89978 * =============================================================================
89979 */
const MUL = 'return a * b;';
/**
 * WebGL kernel for element-wise multiplication.
 *
 * complex64 inputs are handled by running the real/imaginary cross-product
 * shaders over the four underlying float planes; small tensors are forwarded
 * to the CPU implementation; everything else runs a (packed or unpacked)
 * binary-op shader.
 */
function multiply(args) {
    const { inputs, backend } = args;
    const { a, b } = inputs;
    const dtype = upcastType(a.dtype, b.dtype);
    if (a.dtype === 'complex64') {
        const aData = backend.texData.get(a.dataId);
        const bData = backend.texData.get(b.dataId);
        const realProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.REAL, a.shape, b.shape);
        const imagProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.IMAG, a.shape, b.shape);
        // Expose each underlying float plane as a plain tensor info.
        const asTensorInfo = (part, shape) => ({ dataId: part.dataId, dtype: part.dtype, shape });
        const programInputs = [
            asTensorInfo(aData.complexTensorInfos.real, a.shape),
            asTensorInfo(aData.complexTensorInfos.imag, a.shape),
            asTensorInfo(bData.complexTensorInfos.real, b.shape),
            asTensorInfo(bData.complexTensorInfos.imag, b.shape)
        ];
        const realPart = backend.runWebGLProgram(realProgram, programInputs, 'float32');
        const imagPart = backend.runWebGLProgram(imagProgram, programInputs, 'float32');
        const complexOutput = complex({ inputs: { real: realPart, imag: imagPart }, backend });
        backend.disposeIntermediateTensorInfo(realPart);
        backend.disposeIntermediateTensorInfo(imagPart);
        // TODO(annxingyuan): CPU forwarding for complex inputs.
        return complexOutput;
    }
    if (backend.shouldExecuteOnCPU([a, b])) {
        // Small resident tensors: compute on CPU and stash the values.
        const aValues = backend.texData.get(a.dataId).values;
        const bValues = backend.texData.get(b.dataId).values;
        const [outValues, outShape] = multiplyImplCPU(a.shape, b.shape, aValues, bValues, dtype);
        const out = backend.makeTensorInfo(outShape, dtype);
        backend.texData.get(out.dataId).values = outValues;
        return out;
    }
    const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?
        new BinaryOpPackedProgram(MUL, a.shape, b.shape) :
        new BinaryOpProgram(MUL, a.shape, b.shape);
    return backend.runWebGLProgram(program, [a, b], dtype);
}
const multiplyConfig = {
    kernelName: Multiply$1,
    backendName: 'webgl',
    kernelFunc: multiply
};
90043
90044 /**
90045 * @license
90046 * Copyright 2020 Google LLC. All Rights Reserved.
90047 * Licensed under the Apache License, Version 2.0 (the "License");
90048 * you may not use this file except in compliance with the License.
90049 * You may obtain a copy of the License at
90050 *
90051 * http://www.apache.org/licenses/LICENSE-2.0
90052 *
90053 * Unless required by applicable law or agreed to in writing, software
90054 * distributed under the License is distributed on an "AS IS" BASIS,
90055 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90056 * See the License for the specific language governing permissions and
90057 * limitations under the License.
90058 * =============================================================================
90059 */
/**
 * Reshapes a packed tensor by running a ReshapePackedProgram over its 3D
 * (batch, rows, cols) view; packed layouts cannot in general be reshaped by
 * simply relabelling the shape metadata.
 */
function packedReshape(input, afterShape, backend) {
    const input3DShape = [getBatchDim(input.shape), ...getRowsCols(input.shape)];
    const afterShapeAs3D = [getBatchDim(afterShape), ...getRowsCols(afterShape)];
    const input3D = { dtype: input.dtype, shape: input3DShape, dataId: input.dataId };
    const program = new ReshapePackedProgram(afterShapeAs3D, input3DShape);
    // Keep the result packed; eager unpacking here would defeat the purpose.
    const output = backend.runWebGLProgram(program, [input3D], input.dtype, [input3DShape] /* customValues */, true /* preventEagerUnpackingOfOutput */);
    return { dataId: output.dataId, shape: afterShape, dtype: output.dtype };
}
90076
90077 /**
90078 * @license
90079 * Copyright 2020 Google LLC. All Rights Reserved.
90080 * Licensed under the Apache License, Version 2.0 (the "License");
90081 * you may not use this file except in compliance with the License.
90082 * You may obtain a copy of the License at
90083 *
90084 * http://www.apache.org/licenses/LICENSE-2.0
90085 *
90086 * Unless required by applicable law or agreed to in writing, software
90087 * distributed under the License is distributed on an "AS IS" BASIS,
90088 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90089 * See the License for the specific language governing permissions and
90090 * limitations under the License.
90091 * =============================================================================
90092 */
/**
 * WebGL reshape kernel. Reshape is normally metadata-only (the output
 * aliases the input's dataId); only a packed input whose layout cannot be
 * reused requires running a packed-reshape shader.
 */
function reshape(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { shape } = attrs;
    const webglBackend = backend;
    const xSize = sizeFromShape(x.shape);
    // Resolve a possible -1 (implicit) dimension in the requested shape.
    const $shape = inferFromImplicitShape(shape, xSize);
    const $xSize = sizeFromShape($shape);
    assert$1(xSize === $xSize, () => `The new shape (${$shape}) has ${$xSize} elements and the old ` +
        `shape (${x.shape}) has ${xSize} elements. The new shape and old ` +
        `shape must have the same number of elements.`);
    const xTexData = webglBackend.texData.get(x.dataId);
    const needsPackedReshape = xTexData.isPacked && !isReshapeFree(x.shape, $shape) &&
        !(xTexData.texture !== null && isReshapeFree(xTexData.shape, $shape));
    if (needsPackedReshape) {
        return packedReshape(x, $shape, webglBackend);
    }
    // Free reshape: alias the existing buffer under the new shape.
    webglBackend.incRef(x.dataId);
    return { dataId: x.dataId, shape: $shape, dtype: x.dtype };
}
const reshapeConfig = {
    kernelName: Reshape$1,
    backendName: 'webgl',
    kernelFunc: reshape
};
90117
90118 /**
90119 * @license
90120 * Copyright 2020 Google LLC. All Rights Reserved.
90121 * Licensed under the Apache License, Version 2.0 (the "License");
90122 * you may not use this file except in compliance with the License.
90123 * You may obtain a copy of the License at
90124 *
90125 * http://www.apache.org/licenses/LICENSE-2.0
90126 *
90127 * Unless required by applicable law or agreed to in writing, software
90128 * distributed under the License is distributed on an "AS IS" BASIS,
90129 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90130 * See the License for the specific language governing permissions and
90131 * limitations under the License.
90132 * =============================================================================
90133 */
/**
 * WebGL program computing the mean over windows of a [batch, inSize] tensor,
 * producing [batch, outSize]. Values are read four at a time and accumulated
 * with a dot product; when `divisor` is given each value is pre-scaled by
 * 1/divisor (used by the first stage of a staged mean reduction).
 */
class MeanProgram {
    constructor(reduceInfo, divisor) {
        this.variableNames = ['x'];
        const { windowSize, batchSize, inSize, outSize } = reduceInfo;
        this.outputShape = [batchSize, outSize];
        // Process the window in vec4 chunks; the remainder (0-3 values) is
        // handled by the compile-time branches at the end of the shader.
        const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
        const windowSizeVec4Remainder = windowSize % 4;
        let updateSnippet = `sumValue += dot(values, ones);`;
        if (divisor != null) {
            const denominator = 1 / divisor;
            // toPrecision(2) forces a decimal point on integral values so the
            // emitted literal is a valid GLSL float (e.g. "1.0" instead of "1").
            updateSnippet = `sumValue += dot(values * ${isInt(denominator) ? denominator.toPrecision(2) :
                denominator}, ones);`;
        }
        let checkOutOfBounds = '';
        if (inSize % windowSize > 0) {
            // The last window may read past the input; contribute 0 there.
            checkOutOfBounds = `
        if (inIdx < 0 || inIdx >= ${inSize}) {
          return 0.0;
        }
      `;
        }
        this.userCode = `
      const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);

      float getValue(int batch, int inIdx) {
        ${checkOutOfBounds}
        return getX(batch, inIdx);
      }

      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];
        int outIdx = coords[1];
        int inOffset = outIdx * ${windowSize};

        float sumValue = 0.0;

        for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) {
          int inIdx = inOffset + i;
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            getValue(batch, inIdx + 3)
          );

          ${updateSnippet}
        }

        int inIdx = inOffset + ${windowSizeNearestVec4};
        if (${windowSizeVec4Remainder === 1}) {
          vec4 values = vec4(getValue(batch, inIdx), 0.0, 0.0, 0.0);

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 2}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1), 0.0, 0.0);

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 3}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2), 0.0);

          ${updateSnippet}
        }
        setOutput(sumValue);
      }
    `;
    }
}
90207
90208 /**
90209 * @license
90210 * Copyright 2017 Google LLC. All Rights Reserved.
90211 * Licensed under the Apache License, Version 2.0 (the "License");
90212 * you may not use this file except in compliance with the License.
90213 * You may obtain a copy of the License at
90214 *
90215 * http://www.apache.org/licenses/LICENSE-2.0
90216 *
90217 * Unless required by applicable law or agreed to in writing, software
90218 * distributed under the License is distributed on an "AS IS" BASIS,
90219 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90220 * See the License for the specific language governing permissions and
90221 * limitations under the License.
90222 * =============================================================================
90223 */
/**
 * WebGL program reducing windows of a [batch, inSize] tensor down to
 * [batch, outSize] with one of: sum, prod, min, max, all, any.
 * Values are consumed four at a time; partial vec4 loads at the end of a
 * window are padded with the reduction's identity (`initializationValue`).
 */
class ReduceProgram {
    constructor(reduceInfo, reduceType) {
        this.variableNames = ['x'];
        const { windowSize, batchSize, inSize, outSize } = reduceInfo;
        this.outputShape = [batchSize, outSize];
        // Identity element of the reduction, also used to pad partial loads.
        let initializationValue = '0.0';
        let compareOp = ``;
        if (reduceType === 'prod') {
            initializationValue = '1.0';
        }
        else if (reduceType === 'min') {
            // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
            initializationValue = '1.0 / 1e-20';
            compareOp = `min`;
        }
        else if (reduceType === 'max') {
            // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
            initializationValue = '-1.0 / 1e-20';
            compareOp = `max`;
        }
        // min/max keep a vec4 accumulator that is folded to a scalar at the
        // end; the other reductions accumulate into a scalar directly.
        let returnValue = `${reduceType}(${reduceType}(${reduceType}(` +
            'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
        if (reduceType === 'sum') {
            returnValue = `sumValue`;
        }
        else if (reduceType === 'prod') {
            returnValue = `prodValue`;
        }
        else if (reduceType === 'all') {
            returnValue = `allValue`;
        }
        else if (reduceType === 'any') {
            returnValue = `anyValue`;
        }
        const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
        const windowSizeVec4Remainder = windowSize % 4;
        // NOTE(review): for min/max the `minMaxValue = compareOp(...)` line
        // appears twice below (once unconditionally, once inside the if); the
        // second application is redundant but harmless since min/max are
        // idempotent — confirm before simplifying.
        let updateSnippet = `
      if (${reduceType === 'sum'}) {
        sumValue += dot(values, ones);
      } else if (${reduceType === 'prod'}) {
        vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]);
        prodValue *= tmp[0] * tmp[1];
      } else {
        minMaxValue = ${compareOp}(values, minMaxValue);
        if (${reduceType === 'min'} || ${reduceType === 'max'}) {
          minMaxValue = ${compareOp}(values, minMaxValue);
          bvec4 isNaN = isnan(values);
          if (isNaN.r || isNaN.g || isNaN.b || isNaN.a) {
            minMaxValue = vec4(NAN);
          }
        }
      }
    `;
        let vecType = `vec4`;
        if (reduceType === 'all') {
            initializationValue = '1.0';
            updateSnippet = `
        bool reducedAllValue = all(values);
        float floatedReducedAllValue = float(reducedAllValue);
        allValue = float(allValue >= 1.0 && floatedReducedAllValue >= 1.0);
      `;
            vecType = `bvec4`;
        }
        else if (reduceType === 'any') {
            initializationValue = '0.0';
            updateSnippet = `
        bool reducedAnyValue = any(values);
        float floatedReducedAnyValue = float(reducedAnyValue);
        anyValue = float(anyValue >= 1.0 || floatedReducedAnyValue >= 1.0);
      `;
            vecType = `bvec4`;
        }
        let checkOutOfBounds = '';
        if (inSize % windowSize > 0) {
            // The last window may read past the input; contribute the identity.
            checkOutOfBounds = `
        if (inIdx < 0 || inIdx >= ${inSize}) {
          return initializationValue;
        }
      `;
        }
        this.userCode = `
      const float initializationValue = ${initializationValue};
      const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);

      float getValue(int batch, int inIdx) {
        ${checkOutOfBounds}
        return getX(batch, inIdx);
      }

      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];
        int outIdx = coords[1];
        int inOffset = outIdx * ${windowSize};

        vec4 minMaxValue = vec4(${initializationValue});
        float prodValue = 1.0;
        float sumValue = 0.0;
        float allValue = 1.0;
        float anyValue = 0.0;

        for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) {
          int inIdx = inOffset + i;
          ${vecType} values = ${vecType}(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            getValue(batch, inIdx + 3)
          );

          ${updateSnippet}
        }

        int inIdx = inOffset + ${windowSizeNearestVec4};
        if (${windowSizeVec4Remainder === 1}) {
          ${vecType} values = ${vecType}(
            getValue(batch, inIdx),
            initializationValue,
            initializationValue,
            initializationValue
          );

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 2}) {
          ${vecType} values = ${vecType}(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            initializationValue,
            initializationValue
          );

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 3}) {
          ${vecType} values = ${vecType}(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            initializationValue
          );

          ${updateSnippet}
        }
        setOutput(${returnValue});
      }
    `;
    }
}
90371
90372 /**
90373 * @license
90374 * Copyright 2020 Google LLC. All Rights Reserved.
90375 * Licensed under the Apache License, Version 2.0 (the "License");
90376 * you may not use this file except in compliance with the License.
90377 * You may obtain a copy of the License at
90378 *
90379 * http://www.apache.org/licenses/LICENSE-2.0
90380 *
90381 * Unless required by applicable law or agreed to in writing, software
90382 * distributed under the License is distributed on an "AS IS" BASIS,
90383 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90384 * See the License for the specific language governing permissions and
90385 * limitations under the License.
90386 * =============================================================================
90387 */
// Returns an array of configuration objects that describe each stage of the
// reduction. Each stage shrinks the working size by its window until a single
// value per batch remains.
function getReductionStages(inShape) {
    const stages = [];
    let currentSize = inShape[1];
    do {
        const windowSize = computeOptimalWindowSize(currentSize);
        const outSize = Math.ceil(currentSize / windowSize);
        stages.push({ inSize: currentSize, windowSize, outSize });
        currentSize = outSize;
    } while (currentSize !== 1);
    return stages;
}
/**
 * Runs a multi-stage reduction of `x` ([batch, inSize]) down to [batch, 1],
 * executing one program per stage and disposing each intermediate it creates.
 */
function reduce(x, dtype, reductionType, backend) {
    const stages = getReductionStages(x.shape);
    const batchSize = x.shape[0];
    let result = x;
    for (let i = 0; i < stages.length; i++) {
        const { inSize, windowSize, outSize } = stages[i];
        const reduceInfo = { windowSize, inSize, batchSize, outSize };
        let program;
        if (reductionType === 'mean') {
            // Only the first stage divides by its window; later stages just sum
            // the already-averaged partials.
            program = i === 0 ?
                new MeanProgram(reduceInfo, inSize) :
                new MeanProgram(reduceInfo);
        }
        else {
            program = new ReduceProgram(reduceInfo, reductionType);
        }
        const previousResult = result;
        result = backend.runWebGLProgram(program, [result], dtype);
        // Never dispose the caller's input; only our own intermediates.
        if (previousResult.dataId !== x.dataId) {
            backend.disposeIntermediateTensorInfo(previousResult);
        }
    }
    return result;
}
90426
90427 /**
90428 * @license
90429 * Copyright 2017 Google LLC. All Rights Reserved.
90430 * Licensed under the Apache License, Version 2.0 (the "License");
90431 * you may not use this file except in compliance with the License.
90432 * You may obtain a copy of the License at
90433 *
90434 * http://www.apache.org/licenses/LICENSE-2.0
90435 *
90436 * Unless required by applicable law or agreed to in writing, software
90437 * distributed under the License is distributed on an "AS IS" BASIS,
90438 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90439 * See the License for the specific language governing permissions and
90440 * limitations under the License.
90441 * =============================================================================
90442 */
/**
 * WebGL program writing output[coords] = input[permuted coords], i.e. a
 * transpose by the permutation `newDim`. Supports rank <= 6 (limit enforced
 * by getSwitchedCoords).
 */
class TransposeProgram {
    constructor(aShape, newDim) {
        this.variableNames = ['A'];
        // outputShape[i] is the size of the input axis mapped to output axis i.
        const outputShape = new Array(aShape.length);
        for (let i = 0; i < outputShape.length; i++) {
            outputShape[i] = aShape[newDim[i]];
        }
        this.outputShape = outputShape;
        this.rank = outputShape.length;
        const dtype = getCoordsDataType(this.rank);
        const switched = getSwitchedCoords(newDim);
        this.userCode = `
    void main() {
      ${dtype} resRC = getOutputCoords();
      setOutput(getA(${switched}));
    }
    `;
    }
}
/**
 * Builds the comma-separated GLSL coordinate list that addresses the input
 * of a transpose in input-axis order, given the permutation `newDim`
 * (output axis i reads input axis newDim[i]).
 * E.g. getSwitchedCoords([1, 0]) === 'resRC.y,resRC.x'.
 *
 * @throws Error when the rank exceeds 6 (no GLSL channel name available).
 */
function getSwitchedCoords(newDim) {
    const rank = newDim.length;
    if (rank > 6) {
        throw Error(`Transpose for rank ${rank} is not yet supported`);
    }
    const components = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u', 'resRC.v'];
    const switched = [];
    newDim.forEach((sourceAxis, outputAxis) => {
        switched[sourceAxis] = components[outputAxis];
    });
    return switched.join();
}
90474
90475 /**
90476 * @license
90477 * Copyright 2019 Google LLC. All Rights Reserved.
90478 * Licensed under the Apache License, Version 2.0 (the "License");
90479 * you may not use this file except in compliance with the License.
90480 * You may obtain a copy of the License at
90481 *
90482 * http://www.apache.org/licenses/LICENSE-2.0
90483 *
90484 * Unless required by applicable law or agreed to in writing, software
90485 * distributed under the License is distributed on an "AS IS" BASIS,
90486 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90487 * See the License for the specific language governing permissions and
90488 * limitations under the License.
90489 * =============================================================================
90490 */
/**
 * Packed WebGL transpose. Each output texel holds a 2x2 block of the output;
 * the four channels are gathered one at a time by stepping the last two
 * output coordinates and sampling the permuted input location for each.
 */
class TransposePackedProgram {
    constructor(aShape, newDim) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        // outputShape[i] is the size of the input axis mapped to output axis i.
        const outputShape = new Array(aShape.length);
        for (let i = 0; i < outputShape.length; i++) {
            outputShape[i] = aShape[newDim[i]];
        }
        this.outputShape = outputShape;
        this.rank = outputShape.length;
        if (this.rank > 6) {
            throw Error(`Packed transpose for rank ${this.rank} is not yet supported.`);
        }
        const dtype = getCoordsDataType(this.rank);
        // outputOrder: GLSL channel names of the output coords; switchedOrder
        // re-addresses them in input-axis order (the inverse permutation).
        const outputOrder = getVecChannels('rc', this.rank);
        const switchedOrder = new Array(this.rank);
        for (let i = 0; i < newDim.length; i++) {
            switchedOrder[newDim[i]] = outputOrder[i];
        }
        const innerDims = `vec2(${switchedOrder.slice(-2).join()})`;
        // Advances the last output coordinate and tests that it stays in range.
        const nextColumn = `++${outputOrder[this.rank - 1]} < ${outputShape[this.rank - 1]}`;
        const getc = `getChannel(getA(${switchedOrder.join()}), ${innerDims})`;
        this.userCode = `
    void main() {
      ${dtype} rc = getOutputCoords();
      vec4 result = vec4(0.);
      result[0] = ${getc};
      if(${nextColumn}) {
        result[1] = ${getc};
      }
      --${outputOrder[this.rank - 1]};
      if(++${outputOrder[this.rank - 2]} < ${outputShape[this.rank - 2]}) {
        result[2] = ${getc};
        if(${nextColumn}) {
          result[3] = ${getc};
        }
      }
      setOutput(result);
    }
    `;
    }
}
90534
90535 /**
90536 * @license
90537 * Copyright 2020 Google LLC. All Rights Reserved.
90538 * Licensed under the Apache License, Version 2.0 (the "License");
90539 * you may not use this file except in compliance with the License.
90540 * You may obtain a copy of the License at
90541 *
90542 * http://www.apache.org/licenses/LICENSE-2.0
90543 *
90544 * Unless required by applicable law or agreed to in writing, software
90545 * distributed under the License is distributed on an "AS IS" BASIS,
90546 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90547 * See the License for the specific language governing permissions and
90548 * limitations under the License.
90549 * =============================================================================
90550 */
/**
 * Dispatches a transpose shader — the packed variant when
 * WEBGL_PACK_ARRAY_OPERATIONS is enabled — and returns the result.
 */
function transposeImpl(x, perm, backend) {
    let program;
    if (env().getBool('WEBGL_PACK_ARRAY_OPERATIONS')) {
        program = new TransposePackedProgram(x.shape, perm);
    }
    else {
        program = new TransposeProgram(x.shape, perm);
    }
    return backend.runWebGLProgram(program, [x], x.dtype);
}
90557
90558 /**
90559 * @license
90560 * Copyright 2020 Google LLC. All Rights Reserved.
90561 * Licensed under the Apache License, Version 2.0 (the "License");
90562 * you may not use this file except in compliance with the License.
90563 * You may obtain a copy of the License at
90564 *
90565 * http://www.apache.org/licenses/LICENSE-2.0
90566 *
90567 * Unless required by applicable law or agreed to in writing, software
90568 * distributed under the License is distributed on an "AS IS" BASIS,
90569 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90570 * See the License for the specific language governing permissions and
90571 * limitations under the License.
90572 * =============================================================================
90573 */
/**
 * Shared implementation of the Sum kernel: permutes the reduction axes to the
 * innermost positions if necessary, flattens to [batch, inSize], runs the
 * staged reduction, reshapes back, and disposes every intermediate it made.
 */
function sumImpl(x, axis, keepDims, backend) {
    const xRank = x.shape.length;
    const origAxes = parseAxisParam(axis, x.shape);
    let axes = origAxes;
    const permutedAxes = getAxesPermutation(axes, xRank);
    let sumInput = x;
    if (permutedAxes != null) {
        // Move the reduced axes to the inner-most dimensions first.
        sumInput = transposeImpl(x, permutedAxes, backend);
        axes = getInnerMostAxes(axes.length, xRank);
    }
    assertAxesAreInnerMostDims('sum', axes, xRank);
    const [sumOutShape, reduceShape] = computeOutAndReduceShapes(sumInput.shape, axes);
    // With keepDims, target the shape with size-1 reduced dims directly
    // rather than reshaping again at the end.
    const outShape = keepDims ? expandShapeToKeepDim(sumOutShape, origAxes) : sumOutShape;
    const inSize = sizeFromShape(reduceShape);
    const batchSize = sizeFromShape(x.shape) / inSize;
    const reshapedInput = reshape({ inputs: { x: sumInput }, attrs: { shape: [batchSize, inSize] }, backend });
    const reduced = reduce(reshapedInput, sumOutType(x.dtype), 'sum', backend);
    const out = reshape({ inputs: { x: reduced }, attrs: { shape: outShape }, backend });
    backend.disposeIntermediateTensorInfo(reshapedInput);
    backend.disposeIntermediateTensorInfo(reduced);
    if (sumInput !== x) {
        // sumInput was created by the transpose above.
        backend.disposeIntermediateTensorInfo(sumInput);
    }
    return out;
}
90607
90608 /**
90609 * @license
90610 * Copyright 2020 Google LLC. All Rights Reserved.
90611 * Licensed under the Apache License, Version 2.0 (the "License");
90612 * you may not use this file except in compliance with the License.
90613 * You may obtain a copy of the License at
90614 *
90615 * http://www.apache.org/licenses/LICENSE-2.0
90616 *
90617 * Unless required by applicable law or agreed to in writing, software
90618 * distributed under the License is distributed on an "AS IS" BASIS,
90619 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90620 * See the License for the specific language governing permissions and
90621 * limitations under the License.
90622 * =============================================================================
90623 */
/**
 * WebGL kernel for the Sum reduction op.
 *
 * Unpacks the kernel arguments and delegates the actual reduction to
 * `sumImpl` (defined earlier in this file).
 */
function sum(args) {
    const { backend, attrs, inputs } = args;
    return sumImpl(inputs.x, attrs.axis, attrs.keepDims, backend);
}
// Registration entry binding the Sum kernel to the WebGL backend.
const sumConfig = {
    kernelName: Sum,
    backendName: 'webgl',
    kernelFunc: sum
};
90635
90636 /**
90637 * @license
90638 * Copyright 2020 Google LLC. All Rights Reserved.
90639 * Licensed under the Apache License, Version 2.0 (the "License");
90640 * you may not use this file except in compliance with the License.
90641 * You may obtain a copy of the License at
90642 *
90643 * http://www.apache.org/licenses/LICENSE-2.0
90644 *
90645 * Unless required by applicable law or agreed to in writing, software
90646 * distributed under the License is distributed on an "AS IS" BASIS,
90647 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90648 * See the License for the specific language governing permissions and
90649 * limitations under the License.
90650 * =============================================================================
90651 */
/**
 * WebGL kernel for Transpose.
 *
 * When the backend decides the tensor is cheap enough to handle on the CPU
 * (and its values are resident there), the permutation is done with the
 * shared CPU implementation; otherwise a WebGL transpose program runs.
 */
function transpose(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { perm } = attrs;
    const webglBackend = backend;
    // The output shape is the input shape permuted by `perm`.
    const newShape = perm.map((axis) => x.shape[axis]);
    if (!webglBackend.shouldExecuteOnCPU([x])) {
        return transposeImpl(x, perm, webglBackend);
    }
    // CPU path: transpose the typed-array values directly and stash them
    // on a fresh tensor info without uploading to the GPU.
    const values = webglBackend.texData.get(x.dataId).values;
    const outValues = transposeImplCPU(values, x.shape, x.dtype, perm, newShape);
    const out = webglBackend.makeTensorInfo(newShape, x.dtype);
    webglBackend.texData.get(out.dataId).values = outValues;
    return out;
}
// Registration entry binding the Transpose kernel to the WebGL backend.
const transposeConfig = {
    kernelName: Transpose,
    backendName: 'webgl',
    kernelFunc: transpose
};
90681
90682 /**
90683 * @license
90684 * Copyright 2020 Google LLC. All Rights Reserved.
90685 * Licensed under the Apache License, Version 2.0 (the "License");
90686 * you may not use this file except in compliance with the License.
90687 * You may obtain a copy of the License at
90688 *
90689 * http://www.apache.org/licenses/LICENSE-2.0
90690 *
90691 * Unless required by applicable law or agreed to in writing, software
90692 * distributed under the License is distributed on an "AS IS" BASIS,
90693 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90694 * See the License for the specific language governing permissions and
90695 * limitations under the License.
90696 * =============================================================================
90697 */
// Empirically determined minimal shared dimension in matmul before we forward
// to a.mul(b).sum() in order to take advantage of GPU parallelism. See
// https://github.com/tensorflow/tfjs-core/pull/1379 for benchmarks.
const MATMUL_SHARED_DIM_THRESHOLD = 1000;
/**
 * Batched matrix multiplication on the WebGL backend, with optional fused
 * bias, PReLU activation weights, leaky-relu alpha and activation function.
 *
 * Both operands are reshaped to rank 3 ([batch, rows, cols]); batch (outer)
 * dimensions are broadcast against each other. When one operand is
 * effectively a vector and the shared dimension is large, the product is
 * computed as multiply + sum instead of the matmul shader. The result is
 * reshaped to the broadcasted outer shape plus [outerShapeA, outerShapeB],
 * and every intermediate tensor created here is disposed before returning.
 */
function batchMatMulImpl({ a, b, transposeA, transposeB, backend, bias = null, preluActivationWeights = null, leakyreluAlpha = 0, activation = null }) {
    const aRank = a.shape.length;
    const bRank = b.shape.length;
    // Inner (shared) and outer dims depend on whether each operand is
    // logically transposed.
    const innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1];
    const innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2];
    const outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2];
    const outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1];
    const outerDimsA = a.shape.slice(0, -2);
    const outerDimsB = b.shape.slice(0, -2);
    const batchDimA = sizeFromShape(outerDimsA);
    const batchDimB = sizeFromShape(outerDimsB);
    // Broadcast the batch dims; throws when they are incompatible.
    const outShapeOuterDims = assertAndGetBroadcastShape(a.shape.slice(0, -2), b.shape.slice(0, -2));
    const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
    assert$1(innerShapeA === innerShapeB, () => `Error in matMul: inner shapes (${innerShapeA}) and (` +
        `${innerShapeB}) of Tensors with shapes ${a.shape} and ` +
        `${b.shape} and transposeA=${transposeA}` +
        ` and transposeB=${transposeB} must match.`);
    const a3dShape = transposeA ?
        [batchDimA, innerShapeA, outerShapeA] :
        [batchDimA, outerShapeA, innerShapeA];
    const b3dShape = transposeB ?
        [batchDimB, outerShapeB, innerShapeB] :
        [batchDimB, innerShapeB, outerShapeB];
    // The rest of the implementation is designed to operate on rank-3 tensors
    const a3d = reshape({ inputs: { x: a }, backend, attrs: { shape: a3dShape } });
    const b3d = reshape({ inputs: { x: b }, backend, attrs: { shape: b3dShape } });
    const intermediates = [a3d, b3d];
    const batchDim = Math.max(batchDimA, batchDimB);
    const sharedDim = transposeA ? a3d.shape[1] : a3d.shape[2];
    const hasBias = bias != null;
    const hasPreluActivationWeights = preluActivationWeights != null;
    const hasLeakyreluAlpha = activation === 'leakyrelu';
    const fusedActivation = activation != null ?
        mapActivationToShaderProgram(activation, true) :
        null;
    const containsFusedOps = hasBias || hasPreluActivationWeights ||
        hasLeakyreluAlpha || fusedActivation != null;
    let out;
    // Since the matrices are vectors, it is faster to call mul().sum()
    // because sum() is O(sqrt(N)) due to divide-and-conquer.
    // Only taken when nothing is fused into the matmul program.
    if ((outerShapeA === 1 || outerShapeB === 1) &&
        sharedDim > MATMUL_SHARED_DIM_THRESHOLD && containsFusedOps === false) {
        let aVec = a3d;
        let bVec = b3d;
        // Materialize the transposes so the vectors are laid out contiguously.
        if (transposeA) {
            aVec = transpose({ inputs: { x: a3d }, backend, attrs: { perm: [0, 2, 1] } });
            intermediates.push(aVec);
        }
        if (transposeB) {
            bVec = transpose({ inputs: { x: b3d }, backend, attrs: { perm: [0, 2, 1] } });
            intermediates.push(bVec);
        }
        // Reshape so that multiply() broadcasts across the non-vector
        // operand, then sum() reduces along the shared dimension.
        const shouldReshapeA = outerShapeB !== 1;
        const shouldReshapeB = outerShapeB === 1;
        let aVec3d = aVec;
        if (shouldReshapeA) {
            aVec3d = reshape({
                inputs: { x: aVec },
                backend,
                attrs: { shape: [batchDim, sharedDim, 1] }
            });
            intermediates.push(aVec3d);
        }
        const axis = outerShapeB === 1 ? 2 : 1;
        let bVec3d = bVec;
        if (shouldReshapeB) {
            bVec3d = reshape({
                inputs: { x: bVec },
                backend,
                attrs: { shape: [batchDim, 1, sharedDim] }
            });
            intermediates.push(bVec3d);
        }
        const product = multiply({ inputs: { a: aVec3d, b: bVec3d }, backend });
        out = sum({ inputs: { x: product }, backend, attrs: { axis, keepDims: true } });
        intermediates.push(product);
    }
    else {
        // General case: run the packed matmul shader, fusing bias / PReLU /
        // leaky-relu / activation into the program where requested.
        const dtype = upcastType(a.dtype, b.dtype);
        const program = new MatMulPackedProgram(a3dShape, b3dShape, [batchDim, outerShapeA, outerShapeB], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
        const inputs = [a3d, b3d];
        if (bias != null) {
            inputs.push(bias);
        }
        if (hasPreluActivationWeights) {
            inputs.push(preluActivationWeights);
        }
        if (hasLeakyreluAlpha) {
            // The alpha scalar is uploaded as an extra tensor input.
            const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
            inputs.push($leakyreluAlpha);
            intermediates.push($leakyreluAlpha);
        }
        out = backend.runWebGLProgram(program, inputs, dtype);
    }
    const outReshaped = reshape({ inputs: { x: out }, backend, attrs: { shape: outShape } });
    intermediates.push(out);
    for (const i of intermediates) {
        backend.disposeIntermediateTensorInfo(i);
    }
    return outReshaped;
}
90803
90804 /**
90805 * @license
90806 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
90808 * you may not use this file except in compliance with the License.
90809 * You may obtain a copy of the License at
90810 *
90811 * http://www.apache.org/licenses/LICENSE-2.0
90812 *
90813 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
90815 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90816 * See the License for the specific language governing permissions and
90817 * limitations under the License.
90818 * =============================================================================
90819 */
/**
 * WebGL kernel for _FusedMatMul: matrix multiply fused with optional bias
 * addition and activation. Thin adapter that forwards the kernel inputs
 * and attributes to `batchMatMulImpl`.
 */
function _fusedMatMul(args) {
    const { a, b, bias, preluActivationWeights } = args.inputs;
    const { transposeA, transposeB, activation, leakyreluAlpha } = args.attrs;
    return batchMatMulImpl({
        a,
        b,
        transposeA,
        transposeB,
        backend: args.backend,
        bias,
        preluActivationWeights,
        leakyreluAlpha,
        activation
    });
}
// Registration entry binding the _FusedMatMul kernel to the WebGL backend.
const _fusedMatMulConfig = {
    kernelName: _FusedMatMul,
    backendName: 'webgl',
    kernelFunc: _fusedMatMul,
};
90841
90842 /**
90843 * @license
90844 * Copyright 2020 Google LLC. All Rights Reserved.
90845 * Licensed under the Apache License, Version 2.0 (the "License");
90846 * you may not use this file except in compliance with the License.
90847 * You may obtain a copy of the License at
90848 *
90849 * http://www.apache.org/licenses/LICENSE-2.0
90850 *
90851 * Unless required by applicable law or agreed to in writing, software
90852 * distributed under the License is distributed on an "AS IS" BASIS,
90853 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90854 * See the License for the specific language governing permissions and
90855 * limitations under the License.
90856 * =============================================================================
90857 */
// GLSL snippet implementing elementwise absolute value.
const ABS = `return abs(x);`;
/**
 * WebGL kernel for Abs.
 *
 * For non-complex tensors that the backend prefers to handle on the CPU,
 * the values are transformed with the shared CPU implementation. Otherwise
 * a unary WebGL program (packed or unpacked, per the environment flag) runs.
 */
function abs(args) {
    const { inputs, backend } = args;
    const { x } = inputs;
    // TODO: handle cases when x is complex. Once the cpu implementation
    // can handle complex values, refactor to use unaryKernelFunc.
    if (backend.shouldExecuteOnCPU([x]) && x.dtype !== 'complex64') {
        const values = backend.texData.get(x.dataId).values;
        const absValues = simpleAbsImplCPU(values);
        return backend.makeTensorInfo(x.shape, x.dtype, absValues);
    }
    const usePacked = env().getBool('WEBGL_PACK_UNARY_OPERATIONS');
    const program = usePacked ?
        new UnaryOpPackedProgram(x.shape, ABS) :
        new UnaryOpProgram(x.shape, ABS);
    return backend.runWebGLProgram(program, [x], x.dtype);
}
// Registration entry binding the Abs kernel to the WebGL backend.
const absConfig = {
    kernelName: Abs,
    backendName: 'webgl',
    kernelFunc: abs
};
90883
90884 /**
90885 * @license
90886 * Copyright 2020 Google LLC. All Rights Reserved.
90887 * Licensed under the Apache License, Version 2.0 (the "License");
90888 * you may not use this file except in compliance with the License.
90889 * You may obtain a copy of the License at
90890 *
90891 * http://www.apache.org/licenses/LICENSE-2.0
90892 *
90893 * Unless required by applicable law or agreed to in writing, software
90894 * distributed under the License is distributed on an "AS IS" BASIS,
90895 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90896 * See the License for the specific language governing permissions and
90897 * limitations under the License.
90898 * =============================================================================
90899 */
// GLSL arccosine with TF semantics: NaN inputs propagate (via
// CHECK_NAN_SNIPPET$1) and inputs outside [-1, 1] yield NAN.
const ACOS = CHECK_NAN_SNIPPET$1 + `
  if (abs(x) > 1.) {
    return NAN;
  }
  return acos(x);
`;
// Elementwise acos kernel built from the shared unary-op helper.
const acos = unaryKernelFunc({ opSnippet: ACOS });
const acosConfig = {
    kernelName: Acos,
    backendName: 'webgl',
    kernelFunc: acos,
};
90912
90913 /**
90914 * @license
90915 * Copyright 2020 Google LLC. All Rights Reserved.
90916 * Licensed under the Apache License, Version 2.0 (the "License");
90917 * you may not use this file except in compliance with the License.
90918 * You may obtain a copy of the License at
90919 *
90920 * http://www.apache.org/licenses/LICENSE-2.0
90921 *
90922 * Unless required by applicable law or agreed to in writing, software
90923 * distributed under the License is distributed on an "AS IS" BASIS,
90924 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90925 * See the License for the specific language governing permissions and
90926 * limitations under the License.
90927 * =============================================================================
90928 */
// GLSL inverse hyperbolic cosine: NaN inputs propagate (via
// CHECK_NAN_SNIPPET$1); domain is x >= 1, smaller inputs yield NAN.
// Uses the identity acosh(x) = ln(x + sqrt(x^2 - 1)).
const ACOSH = CHECK_NAN_SNIPPET$1 + `
  if (x < 1.0) return NAN;
return log(x + sqrt(x * x - 1.0));`;
const acosh = unaryKernelFunc({ opSnippet: ACOSH });
const acoshConfig = {
    kernelName: Acosh,
    backendName: 'webgl',
    kernelFunc: acosh,
};
90938
90939 /**
90940 * @license
90941 * Copyright 2020 Google LLC. All Rights Reserved.
90942 * Licensed under the Apache License, Version 2.0 (the "License");
90943 * you may not use this file except in compliance with the License.
90944 * You may obtain a copy of the License at
90945 *
90946 * http://www.apache.org/licenses/LICENSE-2.0
90947 *
90948 * Unless required by applicable law or agreed to in writing, software
90949 * distributed under the License is distributed on an "AS IS" BASIS,
90950 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90951 * See the License for the specific language governing permissions and
90952 * limitations under the License.
90953 * =============================================================================
90954 */
// GLSL snippet for elementwise addition; the same snippet serves both the
// packed and unpacked shader variants.
const ADD = 'return a + b;';
// Binary Add kernel: supports complex64 inputs and falls back to the CPU
// implementation (addImplCPU) when the backend prefers the CPU path.
const addKernelFunc = binaryKernelFunc({
    opSnippet: ADD,
    packedOpSnippet: ADD,
    supportsComplex: true,
    cpuKernelImpl: addImplCPU
});
const addConfig = {
    kernelName: Add$1,
    backendName: 'webgl',
    kernelFunc: addKernelFunc
};
90967
90968 /**
90969 * @license
90970 * Copyright 2019 Google LLC. All Rights Reserved.
90971 * Licensed under the Apache License, Version 2.0 (the "License");
90972 * you may not use this file except in compliance with the License.
90973 * You may obtain a copy of the License at
90974 *
90975 * http://www.apache.org/licenses/LICENSE-2.0
90976 *
90977 * Unless required by applicable law or agreed to in writing, software
90978 * distributed under the License is distributed on an "AS IS" BASIS,
90979 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90980 * See the License for the specific language governing permissions and
90981 * limitations under the License.
90982 * =============================================================================
90983 */
/**
 * WebGL program that sums an arbitrary number of input tensors elementwise.
 * One shader variable (T0, T1, ...) is declared per input; all inputs are
 * expected to match the output shape.
 */
class AddNProgram {
    constructor(outputShape, shapes) {
        this.outputShape = [];
        this.outputShape = outputShape;
        this.variableNames = shapes.map((_, i) => `T${i}`);
        // Get target elements from every input tensor.
        const snippets = this.variableNames.map((variable) => `float v${variable} = get${variable}AtOutCoords();`);
        // Calculate the sum of all elements.
        const operation = this.variableNames.map((variable) => `v${variable}`).join(' + ');
        this.userCode = `
      void main() {
        ${snippets.join('\n        ')}

        float result = ${operation};
        setOutput(result);
      }
    `;
    }
}
91010
91011 /**
91012 * @license
91013 * Copyright 2019 Google LLC. All Rights Reserved.
91014 * Licensed under the Apache License, Version 2.0 (the "License");
91015 * you may not use this file except in compliance with the License.
91016 * You may obtain a copy of the License at
91017 *
91018 * http://www.apache.org/licenses/LICENSE-2.0
91019 *
91020 * Unless required by applicable law or agreed to in writing, software
91021 * distributed under the License is distributed on an "AS IS" BASIS,
91022 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91023 * See the License for the specific language governing permissions and
91024 * limitations under the License.
91025 * =============================================================================
91026 */
/**
 * Packed variant of AddNProgram: sums an arbitrary number of input tensors
 * elementwise, operating on vec4 texels. One shader variable (T0, T1, ...)
 * is declared per input; all inputs are expected to match the output shape.
 */
class AddNPackedProgram {
    constructor(outputShape, shapes) {
        this.outputShape = [];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = outputShape;
        this.variableNames = shapes.map((_, i) => `T${i}`);
        // Get target elements from every input tensor.
        const snippets = this.variableNames.map((variable) => `vec4 v${variable} = get${variable}AtOutCoords();`);
        // Calculate the sum of all elements.
        const operation = this.variableNames.map((variable) => `v${variable}`).join(' + ');
        this.userCode = `
      void main() {
        ${snippets.join('\n        ')}

        vec4 result = ${operation};
        setOutput(result);
      }
    `;
    }
}
91055
91056 /**
91057 * @license
91058 * Copyright 2020 Google LLC. All Rights Reserved.
91059 * Licensed under the Apache License, Version 2.0 (the "License");
91060 * you may not use this file except in compliance with the License.
91061 * You may obtain a copy of the License at
91062 *
91063 * http://www.apache.org/licenses/LICENSE-2.0
91064 *
91065 * Unless required by applicable law or agreed to in writing, software
91066 * distributed under the License is distributed on an "AS IS" BASIS,
91067 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91068 * See the License for the specific language governing permissions and
91069 * limitations under the License.
91070 * =============================================================================
91071 */
/**
 * WebGL kernel for AddN: elementwise sum of a list of same-shaped tensors.
 *
 * When the list exceeds the backend's per-shader texture budget, it is
 * recursively split in half and the partial sums are combined.
 */
function addN(args) {
    const { inputs, backend } = args;
    const tensors = inputs;
    // A single tensor sums to itself.
    if (tensors.length === 1) {
        return identity({ inputs: { x: tensors[0] }, backend });
    }
    // Limit the number of uploaded textures for optimization.
    const maxTextures = env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER');
    if (tensors.length > maxTextures) {
        const mid = Math.floor(tensors.length / 2);
        const leftSide = addN({ inputs: tensors.slice(0, mid), backend });
        const rightSide = addN({ inputs: tensors.slice(mid), backend });
        return addN({ inputs: [leftSide, rightSide], backend });
    }
    const dtype = tensors
        .map((t) => t.dtype)
        .reduce((d1, d2) => upcastType(d1, d2));
    const shapes = tensors.map((t) => t.shape);
    // We can make sure shapes are identical in op level.
    const program = env().getBool('WEBGL_PACK') ?
        new AddNPackedProgram(tensors[0].shape, shapes) :
        new AddNProgram(tensors[0].shape, shapes);
    return backend.runWebGLProgram(program, tensors, dtype);
}
// Registration entry binding the AddN kernel to the WebGL backend.
const addNConfig = {
    kernelName: AddN,
    backendName: 'webgl',
    kernelFunc: addN
};
91099
91100 /**
91101 * @license
91102 * Copyright 2020 Google LLC. All Rights Reserved.
91103 * Licensed under the Apache License, Version 2.0 (the "License");
91104 * you may not use this file except in compliance with the License.
91105 * You may obtain a copy of the License at
91106 *
91107 * http://www.apache.org/licenses/LICENSE-2.0
91108 *
91109 * Unless required by applicable law or agreed to in writing, software
91110 * distributed under the License is distributed on an "AS IS" BASIS,
91111 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91112 * See the License for the specific language governing permissions and
91113 * limitations under the License.
91114 * =============================================================================
91115 */
/**
 * WebGL kernel for All: logical-AND reduction over the given axes.
 *
 * Non-innermost reduction axes are first moved to the innermost positions
 * via a transpose, the tensor is flattened to 2D, and the shared `reduce`
 * helper performs the 'all' reduction. Intermediates are disposed before
 * returning.
 */
function all(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    const xRank = x.shape.length;
    const origAxes = parseAxisParam(axis, x.shape);
    let axes = origAxes;
    const permutedAxes = getAxesPermutation(axes, xRank);
    let permutedX = x;
    if (permutedAxes != null) {
        // Move the reduced axes to be innermost.
        permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
        axes = getInnerMostAxes(axes.length, xRank);
    }
    assertAxesAreInnerMostDims('all', axes, xRank);
    const [outShape, reduceShape] = computeOutAndReduceShapes(permutedX.shape, axes);
    const inSize = sizeFromShape(reduceShape);
    const a2D = reshape({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } });
    const reduced = reduce(a2D, a2D.dtype, 'all', backend);
    // With keepDims the reduced axes come back as size-1 dims.
    const finalShape = keepDims ? expandShapeToKeepDim(outShape, origAxes) : outShape;
    const res = reshape({ inputs: { x: reduced }, backend, attrs: { shape: finalShape } });
    backend.disposeIntermediateTensorInfo(a2D);
    backend.disposeIntermediateTensorInfo(reduced);
    if (permutedAxes != null) {
        backend.disposeIntermediateTensorInfo(permutedX);
    }
    return res;
}
// Registration entry binding the All kernel to the WebGL backend.
const allConfig = {
    kernelName: All,
    backendName: 'webgl',
    kernelFunc: all
};
91154
91155 /**
91156 * @license
91157 * Copyright 2020 Google LLC. All Rights Reserved.
91158 * Licensed under the Apache License, Version 2.0 (the "License");
91159 * you may not use this file except in compliance with the License.
91160 * You may obtain a copy of the License at
91161 *
91162 * http://www.apache.org/licenses/LICENSE-2.0
91163 *
91164 * Unless required by applicable law or agreed to in writing, software
91165 * distributed under the License is distributed on an "AS IS" BASIS,
91166 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91167 * See the License for the specific language governing permissions and
91168 * limitations under the License.
91169 * =============================================================================
91170 */
/**
 * WebGL kernel for Any: logical-OR reduction over the given axes.
 *
 * Mirrors the All kernel: non-innermost reduction axes are transposed to
 * the innermost positions, the tensor is flattened to 2D, and the shared
 * `reduce` helper performs the 'any' reduction. Intermediates are disposed
 * before returning.
 */
function any(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    const xRank = x.shape.length;
    const origAxes = parseAxisParam(axis, x.shape);
    let axes = origAxes;
    const permutedAxes = getAxesPermutation(axes, xRank);
    let permutedX = x;
    if (permutedAxes != null) {
        // Move the reduced axes to be innermost.
        permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
        axes = getInnerMostAxes(axes.length, xRank);
    }
    assertAxesAreInnerMostDims('any', axes, xRank);
    const [outShape, reduceShape] = computeOutAndReduceShapes(permutedX.shape, axes);
    const inSize = sizeFromShape(reduceShape);
    const a2D = reshape({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } });
    const reduced = reduce(a2D, a2D.dtype, 'any', backend);
    // With keepDims the reduced axes come back as size-1 dims.
    const finalShape = keepDims ? expandShapeToKeepDim(outShape, origAxes) : outShape;
    const res = reshape({ inputs: { x: reduced }, backend, attrs: { shape: finalShape } });
    backend.disposeIntermediateTensorInfo(a2D);
    backend.disposeIntermediateTensorInfo(reduced);
    if (permutedAxes != null) {
        backend.disposeIntermediateTensorInfo(permutedX);
    }
    return res;
}
// Registration entry binding the Any kernel to the WebGL backend.
const anyConfig = {
    kernelName: Any,
    backendName: 'webgl',
    kernelFunc: any
};
91209
91210 /**
91211 * @license
91212 * Copyright 2017 Google LLC. All Rights Reserved.
91213 * Licensed under the Apache License, Version 2.0 (the "License");
91214 * you may not use this file except in compliance with the License.
91215 * You may obtain a copy of the License at
91216 *
91217 * http://www.apache.org/licenses/LICENSE-2.0
91218 *
91219 * Unless required by applicable law or agreed to in writing, software
91220 * distributed under the License is distributed on an "AS IS" BASIS,
91221 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91222 * See the License for the specific language governing permissions and
91223 * limitations under the License.
91224 * =============================================================================
91225 */
/**
 * WebGL program computing one pass of a windowed argmin/argmax reduction
 * over a 2D [batch, inSize] input.
 *
 * On the first pass candidate indices come straight from the window offset;
 * on later passes they are read from the previous pass' `bestIndicesA`
 * output, so the final result indexes into the original tensor.
 */
class ArgMinMaxProgram {
    constructor(reduceInfo, op, firstPass) {
        this.variableNames = ['A'];
        const { windowSize, batchSize, outSize } = reduceInfo;
        this.outputShape = [batchSize, outSize];
        if (!firstPass) {
            // Later passes also read the best-so-far indices.
            this.variableNames.push('bestIndicesA');
        }
        const compOp = op === 'max' ? '>' : '<';
        const indexSnippet = firstPass ?
            'inOffset + i;' :
            'round(getBestIndicesA(batch, inOffset + i));';
        this.userCode = `
      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];
        int outIdx = coords[1];
        int inOffset = outIdx * ${windowSize};

        int bestIndex = inOffset;
        float bestValue = getA(batch, bestIndex);

        for (int i = 0; i < ${windowSize}; i++) {
          int inIdx = ${indexSnippet};
          float candidate = getA(batch, inIdx);
          if (candidate ${compOp} bestValue) {
            bestValue = candidate;
            bestIndex = inIdx;
          }
        }
        setOutput(float(bestIndex));
      }
    `;
    }
}
91261
91262 /**
91263 * @license
91264 * Copyright 2019 Google LLC. All Rights Reserved.
91265 * Licensed under the Apache License, Version 2.0 (the "License");
91266 * you may not use this file except in compliance with the License.
91267 * You may obtain a copy of the License at
91268 *
91269 * http://www.apache.org/licenses/LICENSE-2.0
91270 *
91271 * Unless required by applicable law or agreed to in writing, software
91272 * distributed under the License is distributed on an "AS IS" BASIS,
91273 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91274 * See the License for the specific language governing permissions and
91275 * limitations under the License.
91276 * =============================================================================
91277 */
/**
 * Packed WebGL program computing one pass of a windowed argmin/argmax
 * reduction over the innermost dimension of a rank>2 tensor.
 *
 * Operates on packed (vec4) textures: each texel holds a 2x2 block, so the
 * shader builds four source locations (R/G/B/A neighbors) and guards the
 * out-of-range ones with hasNextCol/hasNextRow. NaN candidates never
 * replace the current best value. `firstPass` selects whether candidate
 * indices come directly from the window offset or are read from a previous
 * pass' `bestIndicesA` output.
 */
class ArgMinMaxPackedProgram {
    constructor(shape, windowSize, op, firstPass) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        assert$1(shape.length > 2, () => `Packed arg${op.charAt(0).toUpperCase() +
            op.slice(1)} supports only inputs with rank above 2.`);
        const inSize = shape[shape.length - 1];
        const outSize = Math.ceil(inSize / windowSize);
        this.outputShape = shape.slice(0, -1);
        // When more than one window remains, keep a reduced innermost dim;
        // otherwise the innermost dim collapses entirely.
        if (outSize > 1) {
            this.outputShape.push(outSize);
        }
        if (!firstPass) {
            this.variableNames.push('bestIndicesA');
        }
        const outShape = this.outputShape;
        const rank = outShape.length;
        const dtype = getCoordsDataType(rank);
        const coords = getChannels('coords', rank);
        let sourceLocSetup;
        let sourceRank;
        if (outSize === 1) {
            // Final pass: the source has one extra (now collapsed) dimension,
            // so append a trailing 0 coordinate.
            sourceRank = rank + 1;
            const sourceLocDType = getCoordsDataType(sourceRank);
            sourceLocSetup = `
        ${sourceLocDType} sourceLocR = ${sourceLocDType}(${coords.join()}, 0);
        ++${coords[rank - 1]};
        ${sourceLocDType} sourceLocG = ${sourceLocDType}(${coords.join()}, 0);
        ++${coords[rank - 2]};
        ${sourceLocDType} sourceLocA = ${sourceLocDType}(${coords.join()}, 0);
        --${coords[rank - 1]};
        ${sourceLocDType} sourceLocB = ${sourceLocDType}(${coords.join()}, 0);
        --${coords[rank - 2]};`;
        }
        else {
            sourceRank = rank;
            sourceLocSetup = `
        ${dtype} sourceLocR = coords;
        ++${coords[rank - 1]};
        ${dtype} sourceLocG = coords;
        ++${coords[rank - 2]};
        ${dtype} sourceLocA = coords;
        --${coords[rank - 1]};
        ${dtype} sourceLocB = coords;
        --${coords[rank - 2]};`;
        }
        const channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, sourceRank);
        const inChannel = '.' + channels[sourceRank - 1]; // e.g. ".b" for rank 3.
        const intChannels = channels.map(x => 'int ' + x);
        const srcRCoords = getChannels('sourceLocR', sourceRank - 1).concat('inIdx.r');
        const srcGCoords = getChannels('sourceLocG', sourceRank - 1).concat('inIdx.g');
        const srcBCoords = getChannels('sourceLocB', sourceRank - 1).concat('inIdx.b');
        const srcACoords = getChannels('sourceLocA', sourceRank - 1).concat('inIdx.a');
        const compOp = (op === 'max') ? 'greaterThan' : 'lessThan';
        const fetchCandidateIdx = firstPass ? '' : `
          inIdx = round(vec4(getBestIndicesAChannel(${srcRCoords.join()}),
                             getBestIndicesAChannel(${srcGCoords.join()}),
                             getBestIndicesAChannel(${srcBCoords.join()}),
                             getBestIndicesAChannel(${srcACoords.join()})));`;
        const fetchValue = `vec4(
            getAChannel(${srcRCoords.join()}),
            hasNextCol ? getAChannel(${srcGCoords.join()}) : 0.,
            hasNextRow ? getAChannel(${srcBCoords.join()}) : 0.,
            hasNextRow && hasNextCol ? getAChannel(${srcACoords.join()}) : 0.)`;
        const getBestIndicesAChannelSnippet = firstPass ? '' : `
      float getBestIndicesAChannel(${intChannels.join()}) {
        return getChannel(getBestIndicesA(${channels.join()}),
                          vec2(${channels.slice(-2).join()}));
      }`;
        this.userCode = `
      float getAChannel(${intChannels.join()}) {
        return getChannel(getA(${channels.join()}),
                          vec2(${channels.slice(-2).join()}));
      }
      ${getBestIndicesAChannelSnippet}
      void main() {
        ${dtype} coords = getOutputCoords();
        bool hasNextCol = ${coords[rank - 1]} < ${outShape[rank - 1] - 1};
        bool hasNextRow = ${coords[rank - 2]} < ${outShape[rank - 2] - 1};
        ${sourceLocSetup}
        ivec4 srcIdx = ivec4(sourceLocR${inChannel}, sourceLocG${inChannel},
          sourceLocB${inChannel}, sourceLocA${inChannel}) * ${windowSize};
        ivec4 inIdx = srcIdx;
        vec4 bestIndex = vec4(inIdx);
        vec4 bestValue = ${fetchValue};

        for (int i = 0; i < ${windowSize}; i++) {
          inIdx = srcIdx;
          ${fetchCandidateIdx}
          vec4 candidate = ${fetchValue};
          bvec4 nan = isnan(candidate);
          bvec4 replace = bvec4(
            vec4(${compOp}(candidate, bestValue)) * (vec4(1.0) - vec4(nan)));

          bestValue = vec4(replace.x ? candidate.x : bestValue.x,
                           replace.y ? candidate.y : bestValue.y,
                           replace.z ? candidate.z : bestValue.z,
                           replace.w ? candidate.w : bestValue.w);
          bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace));
          srcIdx++;
        }
        setOutput(bestIndex);
      }
    `;
    }
}
91385
91386 /**
91387 * @license
91388 * Copyright 2020 Google LLC. All Rights Reserved.
91389 * Licensed under the Apache License, Version 2.0 (the "License");
91390 * you may not use this file except in compliance with the License.
91391 * You may obtain a copy of the License at
91392 *
91393 * http://www.apache.org/licenses/LICENSE-2.0
91394 *
91395 * Unless required by applicable law or agreed to in writing, software
91396 * distributed under the License is distributed on an "AS IS" BASIS,
91397 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91398 * See the License for the specific language governing permissions and
91399 * limitations under the License.
91400 * =============================================================================
91401 */
/**
 * Runs a windowed argmin/argmax reduction over a 2D [batch, inSize] tensor
 * on the WebGL backend, recursing until the inner dimension collapses to 1.
 *
 * On recursive calls, `bestIndicesA` carries the previous pass' winning
 * indices so the final output indexes into the original tensor.
 */
function argReduce(backend, x, reduceType, bestIndicesA = null) {
    let batchSize = x.shape[0];
    let inSize = x.shape[1];
    if (bestIndicesA != null) {
        // Subsequent passes reduce over the previous pass' output shape.
        batchSize = bestIndicesA.shape[0];
        inSize = bestIndicesA.shape[1];
    }
    const windowSize = computeOptimalWindowSize(inSize);
    const outSize = Math.ceil(inSize / windowSize);
    const reduceInfo = { windowSize, inSize, batchSize, outSize };
    const program = new ArgMinMaxProgram(reduceInfo, reduceType, bestIndicesA == null);
    const inputs = bestIndicesA == null ? [x] : [x, bestIndicesA];
    const output = backend.runWebGLProgram(program, inputs, 'int32');
    // Fully reduced once the inner dimension is 1 — no more passes needed.
    if (output.shape[1] === 1) {
        return output;
    }
    const result = argReduce(backend, x, reduceType, output);
    backend.disposeIntermediateTensorInfo(output);
    return result;
}
91425 function argReducePacked(backend, x, reduceType, bestIndicesA = null) {
91426 const inShape = bestIndicesA != null ? bestIndicesA.shape : x.shape;
91427 const inSize = inShape[inShape.length - 1];
91428 const windowSize = computeOptimalWindowSize(inSize);
91429 const program = new ArgMinMaxPackedProgram(inShape, windowSize, reduceType, bestIndicesA == null);
91430 const inputs = bestIndicesA == null ? [x] : [x, bestIndicesA];
91431 const output = backend.runWebGLProgram(program, inputs, 'int32');
91432 if (output.shape.length === x.shape.length) {
91433 const result = argReducePacked(backend, x, reduceType, output);
91434 backend.disposeIntermediateTensorInfo(output);
91435 return result;
91436 }
91437 return output;
91438 }
91439 function argMinMaxReduce(backend, x, axis, reduceType) {
91440 const axes = [axis];
91441 assertAxesAreInnerMostDims('arg' + reduceType.charAt(0).toUpperCase() + reduceType.slice(1), axes, x.shape.length);
91442 if (!env().getBool('WEBGL_PACK_REDUCE') || x.shape.length <= 2) {
91443 const intermediateTensorInfos = [];
91444 // Eagerly unpack x input since it is passed in to all the shaders which
91445 // require unpacked inputs.
91446 const xtexData = backend.texData.get(x.dataId);
91447 const xIsPacked = xtexData !== null && xtexData.isPacked;
91448 let xUnPacked = x;
91449 if (xIsPacked) {
91450 xUnPacked = backend.unpackTensor(x);
91451 intermediateTensorInfos.push(xUnPacked);
91452 }
91453 const [outShape, reduceShape] = computeOutAndReduceShapes(xUnPacked.shape, axes);
91454 const inSize = sizeFromShape(reduceShape);
91455 const a2D = reshape({ inputs: { x: xUnPacked }, backend, attrs: { shape: [-1, inSize] } });
91456 intermediateTensorInfos.push(a2D);
91457 const reduced = argReduce(backend, a2D, reduceType);
91458 intermediateTensorInfos.push(reduced);
91459 const reshaped = reshape({ inputs: { x: reduced }, backend, attrs: { shape: outShape } });
91460 intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
91461 return reshaped;
91462 }
91463 return argReducePacked(backend, x, reduceType);
91464 }
91465
91466 /**
91467 * @license
91468 * Copyright 2020 Google LLC. All Rights Reserved.
91469 * Licensed under the Apache License, Version 2.0 (the "License");
91470 * you may not use this file except in compliance with the License.
91471 * You may obtain a copy of the License at
91472 *
91473 * http://www.apache.org/licenses/LICENSE-2.0
91474 *
91475 * Unless required by applicable law or agreed to in writing, software
91476 * distributed under the License is distributed on an "AS IS" BASIS,
91477 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91478 * See the License for the specific language governing permissions and
91479 * limitations under the License.
91480 * =============================================================================
91481 */
91482 function argMax(args) {
91483 const { inputs, backend, attrs } = args;
91484 const { x } = inputs;
91485 const { axis } = attrs;
91486 let axes = parseAxisParam(axis, x.shape);
91487 const permutedAxes = getAxesPermutation(axes, x.shape.length);
91488 let $x = x;
91489 const intermediateTensorInfos = [];
91490 if (permutedAxes != null) {
91491 $x = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
91492 intermediateTensorInfos.push($x);
91493 axes = getInnerMostAxes(axes.length, $x.shape.length);
91494 }
91495 assertAxesAreInnerMostDims('argMax', [axes[0]], $x.shape.length);
91496 const out = argMinMaxReduce(backend, $x, axes[0], 'max');
91497 intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
91498 return out;
91499 }
91500 const argMaxConfig = {
91501 kernelName: ArgMax,
91502 backendName: 'webgl',
91503 kernelFunc: argMax
91504 };
91505
91506 /**
91507 * @license
91508 * Copyright 2020 Google LLC. All Rights Reserved.
91509 * Licensed under the Apache License, Version 2.0 (the "License");
91510 * you may not use this file except in compliance with the License.
91511 * You may obtain a copy of the License at
91512 *
91513 * http://www.apache.org/licenses/LICENSE-2.0
91514 *
91515 * Unless required by applicable law or agreed to in writing, software
91516 * distributed under the License is distributed on an "AS IS" BASIS,
91517 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91518 * See the License for the specific language governing permissions and
91519 * limitations under the License.
91520 * =============================================================================
91521 */
91522 function argMin(args) {
91523 const { inputs, backend, attrs } = args;
91524 const { x } = inputs;
91525 const { axis } = attrs;
91526 let axes = parseAxisParam(axis, x.shape);
91527 const permutedAxes = getAxesPermutation(axes, x.shape.length);
91528 let $x = x;
91529 const intermediateTensorInfos = [];
91530 if (permutedAxes != null) {
91531 $x = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
91532 intermediateTensorInfos.push($x);
91533 axes = getInnerMostAxes(axes.length, $x.shape.length);
91534 }
91535 assertAxesAreInnerMostDims('argMin', [axes[0]], $x.shape.length);
91536 const out = argMinMaxReduce(backend, $x, axes[0], 'min');
91537 intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
91538 return out;
91539 }
91540 const argMinConfig = {
91541 kernelName: ArgMin,
91542 backendName: 'webgl',
91543 kernelFunc: argMin
91544 };
91545
91546 /**
91547 * @license
91548 * Copyright 2020 Google LLC. All Rights Reserved.
91549 * Licensed under the Apache License, Version 2.0 (the "License");
91550 * you may not use this file except in compliance with the License.
91551 * You may obtain a copy of the License at
91552 *
91553 * http://www.apache.org/licenses/LICENSE-2.0
91554 *
91555 * Unless required by applicable law or agreed to in writing, software
91556 * distributed under the License is distributed on an "AS IS" BASIS,
91557 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91558 * See the License for the specific language governing permissions and
91559 * limitations under the License.
91560 * =============================================================================
91561 */
    // GLSL snippet for inverse sine. CHECK_NAN_SNIPPET$1 propagates NaN
    // inputs; values with |x| > 1 are outside asin's domain and are mapped
    // to NAN explicitly rather than relying on the driver's asin() behavior.
    const ASIN = CHECK_NAN_SNIPPET$1 + `
  if (abs(x) > 1.) {
    return NAN;
  }
  return asin(x);
`;
    const asin = unaryKernelFunc({ opSnippet: ASIN });
    // Kernel registration for Asin on the WebGL backend.
    const asinConfig = {
        kernelName: Asin,
        backendName: 'webgl',
        kernelFunc: asin,
    };
91574
91575 /**
91576 * @license
91577 * Copyright 2020 Google LLC. All Rights Reserved.
91578 * Licensed under the Apache License, Version 2.0 (the "License");
91579 * you may not use this file except in compliance with the License.
91580 * You may obtain a copy of the License at
91581 *
91582 * http://www.apache.org/licenses/LICENSE-2.0
91583 *
91584 * Unless required by applicable law or agreed to in writing, software
91585 * distributed under the License is distributed on an "AS IS" BASIS,
91586 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91587 * See the License for the specific language governing permissions and
91588 * limitations under the License.
91589 * =============================================================================
91590 */
    // GLSL snippet for inverse hyperbolic sine, computed from the identity
    // asinh(x) = log(x + sqrt(x^2 + 1)) since GLSL ES 1.0 has no asinh
    // builtin. CHECK_NAN_SNIPPET$1 propagates NaN inputs.
    const ASINH = CHECK_NAN_SNIPPET$1 + `return log(x + sqrt(x * x + 1.0));`;
    const asinh = unaryKernelFunc({ opSnippet: ASINH });
    // Kernel registration for Asinh on the WebGL backend.
    const asinhConfig = {
        kernelName: Asinh,
        backendName: 'webgl',
        kernelFunc: asinh,
    };
91598
91599 /**
91600 * @license
91601 * Copyright 2020 Google LLC. All Rights Reserved.
91602 * Licensed under the Apache License, Version 2.0 (the "License");
91603 * you may not use this file except in compliance with the License.
91604 * You may obtain a copy of the License at
91605 *
91606 * http://www.apache.org/licenses/LICENSE-2.0
91607 *
91608 * Unless required by applicable law or agreed to in writing, software
91609 * distributed under the License is distributed on an "AS IS" BASIS,
91610 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91611 * See the License for the specific language governing permissions and
91612 * limitations under the License.
91613 * =============================================================================
91614 */
    // GLSL snippet for arctangent; CHECK_NAN_SNIPPET$1 propagates NaN inputs
    // before delegating to the builtin atan().
    const ATAN = CHECK_NAN_SNIPPET$1 + `
  return atan(x);
`;
    const atan = unaryKernelFunc({ opSnippet: ATAN });
    // Kernel registration for Atan on the WebGL backend.
    const atanConfig = {
        kernelName: Atan,
        backendName: 'webgl',
        kernelFunc: atan,
    };
91624
91625 /**
91626 * @license
91627 * Copyright 2020 Google LLC. All Rights Reserved.
91628 * Licensed under the Apache License, Version 2.0 (the "License");
91629 * you may not use this file except in compliance with the License.
91630 * You may obtain a copy of the License at
91631 *
91632 * http://www.apache.org/licenses/LICENSE-2.0
91633 *
91634 * Unless required by applicable law or agreed to in writing, software
91635 * distributed under the License is distributed on an "AS IS" BASIS,
91636 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91637 * See the License for the specific language governing permissions and
91638 * limitations under the License.
91639 * =============================================================================
91640 */
    // GLSL snippet for two-argument arctangent; the two-parameter atan(a, b)
    // overload is GLSL's atan2. CHECK_NAN_SNIPPET guards NaN operands.
    const ATAN2 = CHECK_NAN_SNIPPET + `
  return atan(a, b);
`;
    // Packed (vec4) variant: atan is applied lane-wise, then a per-lane NaN
    // mask is assembled (isnan() on vec4 is component-wise) so that
    // CHECK_NAN_SNIPPET_PACKED can substitute NAN in any lane where either
    // operand was NaN.
    const ATAN2_PACKED = `
  vec4 result = atan(a, b);
  bvec4 isNaNA = isnan(a);
  bvec4 isNaNB = isnan(b);
  bvec4 isNaN = bvec4(isNaNA.x || isNaNB.x, isNaNA.y || isNaNB.y, isNaNA.z || isNaNB.z, isNaNA.w || isNaNB.w);
  ` +
        CHECK_NAN_SNIPPET_PACKED + `
  return result;
`;
    const atan2 = binaryKernelFunc({ opSnippet: ATAN2, packedOpSnippet: ATAN2_PACKED });
    // Kernel registration for Atan2 on the WebGL backend.
    const atan2Config = {
        kernelName: Atan2,
        backendName: 'webgl',
        kernelFunc: atan2,
    };
91659
91660 /**
91661 * @license
91662 * Copyright 2020 Google LLC. All Rights Reserved.
91663 * Licensed under the Apache License, Version 2.0 (the "License");
91664 * you may not use this file except in compliance with the License.
91665 * You may obtain a copy of the License at
91666 *
91667 * http://www.apache.org/licenses/LICENSE-2.0
91668 *
91669 * Unless required by applicable law or agreed to in writing, software
91670 * distributed under the License is distributed on an "AS IS" BASIS,
91671 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91672 * See the License for the specific language governing permissions and
91673 * limitations under the License.
91674 * =============================================================================
91675 */
    // GLSL snippet for inverse hyperbolic tangent, computed from the
    // identity atanh(x) = (log(1 + x) - log(1 - x)) / 2 since GLSL ES 1.0
    // has no atanh builtin. Inputs outside [-1, 1] return NAN;
    // CHECK_NAN_SNIPPET$1 propagates NaN inputs.
    const ATANH = CHECK_NAN_SNIPPET$1 + `
  if ((x < -1.0) || (x > 1.0)) return NAN;
return (log(1.0 + x) - log(1.0 - x)) / 2.0;`;
    const atanh = unaryKernelFunc({ opSnippet: ATANH });
    // Kernel registration for Atanh on the WebGL backend.
    const atanhConfig = {
        kernelName: Atanh,
        backendName: 'webgl',
        kernelFunc: atanh,
    };
91685
91686 /**
91687 * @license
91688 * Copyright 2017 Google LLC. All Rights Reserved.
91689 * Licensed under the Apache License, Version 2.0 (the "License");
91690 * you may not use this file except in compliance with the License.
91691 * You may obtain a copy of the License at
91692 *
91693 * http://www.apache.org/licenses/LICENSE-2.0
91694 *
91695 * Unless required by applicable law or agreed to in writing, software
91696 * distributed under the License is distributed on an "AS IS" BASIS,
91697 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
91698 * See the License for the specific language governing permissions and
91699 * limitations under the License.
91700 * =============================================================================
91701 */
    /**
     * GPGPU program that generates the GLSL shader for 2D max/avg pooling
     * over NHWC input ([batch, height, width, channels]).
     *
     * Two shader variants are produced:
     *  - `computePositions === true`: emits the flattened position of the
     *    winning element (used for maxPool gradients / maxPoolWithArgmax);
     *    not supported for avg pooling.
     *  - otherwise: emits the pooled value itself, reading the window in
     *    vec4 chunks of 4 columns with a scalar remainder of 0-3 columns.
     */
    class Pool2DProgram {
        /**
         * @param convInfo Conv2DInfo describing shapes, strides, dilations
         *     and padding of the pooling op.
         * @param poolType 'max' or 'avg'.
         * @param computePositions When true, output argmax positions instead
         *     of pooled values.
         * @param flattenPositions When true, positions are flattened over the
         *     input tensor rather than relative to the pooling window.
         * @param includeBatchInIndex When true (and flattenPositions), the
         *     batch dimension participates in the flattened position.
         */
        constructor(convInfo, poolType, computePositions, flattenPositions = false, includeBatchInIndex = false) {
            this.variableNames = ['x'];
            if (poolType === 'avg' && computePositions) {
                throw new Error('Cannot compute positions for average pool.');
            }
            const filterWidth = convInfo.filterWidth;
            const strideHeight = convInfo.strideHeight;
            const strideWidth = convInfo.strideWidth;
            const dilationHeight = convInfo.dilationHeight;
            const dilationWidth = convInfo.dilationWidth;
            const effectiveFilterHeight = convInfo.effectiveFilterHeight;
            const effectiveFilterWidth = convInfo.effectiveFilterWidth;
            const padTop = convInfo.padInfo.top;
            const padLeft = convInfo.padInfo.left;
            this.outputShape = convInfo.outShape;
            const isAvgPool = poolType === 'avg';
            // Flattened-position expressions spliced into the shader when
            // computePositions && flattenPositions.
            const batchFlattenPositionStr = `((batch * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + d`;
            const flattenPositionStr = `(xR * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + d`;
            let initializationValue = '0.0';
            if (!isAvgPool) {
                // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
                initializationValue = '-1.0 / 1e-20';
            }
            if (computePositions) {
                const compareOp = '>=';
                this.userCode = `
      const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
      const ivec2 pads = ivec2(${padTop}, ${padLeft});

      void main() {
        ivec4 coords = getOutputCoords();
        int batch = coords[0];
        int d = coords[3];

        ivec2 xRCCorner = coords.yz * strides - pads;
        int xRCorner = xRCCorner.x;
        int xCCorner = xRCCorner.y;

        // max/min x(?, ?, d) to get y(yR, yC, d).
        // ? = to be determined
        float minMaxValue = 0.0;
        float minMaxValueFound = 0.0;
        int minMaxPosition = 0;
        float avgValue = 0.0;

        for (int wR = 0; wR < ${effectiveFilterHeight};
            wR += ${dilationHeight}) {
          int xR = xRCorner + wR;

          if (xR < 0 || xR >= ${convInfo.inHeight}) {
            continue;
          }

          for (int wC = 0; wC < ${effectiveFilterWidth};
              wC += ${dilationWidth}) {
            int xC = xCCorner + wC;

            if (xC < 0 || xC >= ${convInfo.inWidth}) {
              continue;
            }

            float value = getX(batch, xR, xC, d);

            // If a min / max value has already been found, use it. If not,
            // use the current value.
            float currMinMaxValue = mix(
                value, minMaxValue, minMaxValueFound);
            if (value ${compareOp} currMinMaxValue) {
              minMaxValue = value;
              minMaxValueFound = 1.0;
              minMaxPosition = ${flattenPositions ? (includeBatchInIndex ? batchFlattenPositionStr :
                flattenPositionStr) :
                `wR * ${effectiveFilterWidth} + wC`};
            }
          }
        }
        setOutput(float(minMaxPosition));
      }
    `;
                return;
            }
            const compareOp = 'max';
            let returnValue = `${poolType}(${poolType}(${poolType}(` +
                'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
            if (poolType === 'avg') {
                returnValue = `avgValue / max(count, 1.0)`;
            }
            // Process the filter row in vec4 chunks; the 0-3 leftover columns
            // are handled by the remainder branches below.
            const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
            const filterWidthVec4Remainder = filterWidth % 4;
            // `if (${isAvgPool})` bakes a compile-time true/false into the
            // shader, so only one of the two branches survives compilation.
            const updateSnippet = `
        if (${isAvgPool}) {
          avgValue += dot(values, ones);
        } else {
          minMaxValue = ${compareOp}(values, minMaxValue);
        }
      `;
            this.userCode = `
      const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
      const ivec2 pads = ivec2(${padTop}, ${padLeft});
      const float initializationValue = ${initializationValue};
      const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);

      float count = 0.0;

      float getValue(int batch, int xR, int xC, int d) {
        if (xC < 0 || xC >= ${convInfo.inWidth}) {
          return initializationValue;
        }
        count += 1.0;
        return getX(batch, xR, xC, d);
      }

      void main() {
        ivec4 coords = getOutputCoords();
        int batch = coords[0];
        int d = coords[3];

        ivec2 xRCCorner = coords.yz * strides - pads;
        int xRCorner = xRCCorner.x;
        int xCCorner = xRCCorner.y;

        // max/min x(?, ?, d) to get y(yR, yC, d).
        // ? = to be determined
        vec4 minMaxValue = vec4(${initializationValue});
        float avgValue = 0.0;
        count = 0.0;

        for (int wR = 0; wR < ${effectiveFilterHeight};
            wR += ${dilationHeight}) {
          int xR = xRCorner + wR;

          if (xR < 0 || xR >= ${convInfo.inHeight}) {
            continue;
          }

          for (int wC = 0; wC < ${filterWidthNearestVec4}; wC += 4) {
            int xC = xCCorner + wC * ${dilationWidth};

            vec4 values = vec4(
              getValue(batch, xR, xC, d),
              getValue(batch, xR, xC + ${dilationWidth}, d),
              getValue(batch, xR, xC + 2 * ${dilationWidth}, d),
              getValue(batch, xR, xC + 3 * ${dilationWidth}, d)
            );

            ${updateSnippet}
          }

          int xC = xCCorner + ${filterWidthNearestVec4};
          if (${filterWidthVec4Remainder === 1}) {
            vec4 values = vec4(
              getValue(batch, xR, xC, d),
              initializationValue,
              initializationValue,
              initializationValue
            );

            ${updateSnippet}
          } else if (${filterWidthVec4Remainder === 2}) {
            vec4 values = vec4(
              getValue(batch, xR, xC, d),
              getValue(batch, xR, xC + ${dilationWidth}, d),
              initializationValue,
              initializationValue
            );

            ${updateSnippet}
          } else if (${filterWidthVec4Remainder === 3}) {
            vec4 values = vec4(
              getValue(batch, xR, xC, d),
              getValue(batch, xR, xC + ${dilationWidth}, d),
              getValue(batch, xR, xC + 2 * ${dilationWidth}, d),
              initializationValue
            );

            ${updateSnippet}
          }
        }
        setOutput(${returnValue});
      }
    `;
        }
    }
    /**
     * GPGPU program that generates the GLSL shader for 3D max/avg pooling
     * over NDHWC input ([batch, depth, height, width, channels]).
     *
     * Structured like Pool2DProgram with an extra depth loop: one variant
     * emits winning-element positions (computePositions, max pool only), the
     * other emits pooled values using vec4 chunks of 4 columns plus a 0-3
     * column scalar remainder.
     */
    class Pool3DProgram {
        /**
         * @param convInfo Conv3DInfo describing shapes, strides, dilations
         *     and padding of the pooling op.
         * @param poolType 'max' or 'avg'.
         * @param computePositions When true, output argmax positions instead
         *     of pooled values.
         * @param flattenPositions When true, positions are flattened over the
         *     input tensor rather than relative to the pooling window.
         * @param includeBatchInIndex When true (and flattenPositions), the
         *     batch dimension participates in the flattened position.
         */
        constructor(convInfo, poolType, computePositions, flattenPositions = false, includeBatchInIndex = false) {
            this.variableNames = ['x'];
            if (poolType === 'avg' && computePositions) {
                throw new Error('Cannot compute positions for average pool.');
            }
            const filterWidth = convInfo.filterWidth;
            const strideDepth = convInfo.strideDepth;
            const strideHeight = convInfo.strideHeight;
            const strideWidth = convInfo.strideWidth;
            const dilationDepth = convInfo.dilationDepth;
            const dilationHeight = convInfo.dilationHeight;
            const dilationWidth = convInfo.dilationWidth;
            const effectiveFilterDepth = convInfo.effectiveFilterDepth;
            const effectiveFilterHeight = convInfo.effectiveFilterHeight;
            const effectiveFilterWidth = convInfo.effectiveFilterWidth;
            const padFront = convInfo.padInfo.front;
            const padTop = convInfo.padInfo.top;
            const padLeft = convInfo.padInfo.left;
            this.outputShape = convInfo.outShape;
            const isAvgPool = poolType === 'avg';
            let initializationValue = '0.0';
            if (!isAvgPool) {
                // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
                initializationValue = '-1.0 / 1e-20';
            }
            if (computePositions) {
                const compareOp = '>=';
                this.userCode = `
      const ivec3 strides =
          ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});
      const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});

      void main() {
        ivec5 coords = getOutputCoords();
        int batch = coords.x;
        int ch = coords.u;

        ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;
        int xDCorner = xCorner.x;
        int xRCorner = xCorner.y;
        int xCCorner = xCorner.z;

        // max/min x(?, ?, ?, ch) to get y(yD, yR, yC, ch).
        // ? = to be determined
        float minMaxValue = 0.0;
        float minMaxValueFound = 0.0;
        int minMaxPosition = 0;

        for (int wD = 0; wD < ${effectiveFilterDepth};
            wD += ${dilationDepth}) {
          int xD = xDCorner + wD;

          if (xD < 0 || xD >= ${convInfo.inDepth}) {
            continue;
          }

          for (int wR = 0; wR < ${effectiveFilterHeight};
              wR += ${dilationHeight}) {
            int xR = xRCorner + wR;

            if (xR < 0 || xR >= ${convInfo.inHeight}) {
              continue;
            }

            for (int wC = 0; wC < ${effectiveFilterWidth};
                wC += ${dilationWidth}) {
              int xC = xCCorner + wC;

              if (xC < 0 || xC >= ${convInfo.inWidth}) {
                continue;
              }

              float value = getX(batch, xD, xR, xC, ch);

              // If a min / max value has already been found, use it. If not,
              // use the current value.
              float currMinMaxValue = mix(
                  value, minMaxValue, minMaxValueFound);
              if (value ${compareOp} currMinMaxValue) {
                minMaxValue = value;
                minMaxValueFound = 1.0;
                minMaxPosition = ${flattenPositions ?
                    (includeBatchInIndex ?
                        `(((batch * ${convInfo.inDepth} + xD) * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + ch` :
                        `((xD * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + ch`) :
                    `wD * ${effectiveFilterHeight} * ${effectiveFilterWidth} +
                      wR * ${effectiveFilterWidth} + wC`};
              }
            }
          }
        }
        setOutput(float(minMaxPosition));
      }
    `;
                return;
            }
            const compareOp = 'max';
            let returnValue = `${poolType}(${poolType}(${poolType}(` +
                'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
            if (poolType === 'avg') {
                // Use `max(count, 1.0)` instead of `count` in case count === 0.0.
                // If count === 0.0, `avgValue` is always 0.0 and we change `count`'s
                // value to avoid dividing zero.
                returnValue = `avgValue / max(count, 1.0)`;
            }
            // Process the filter row in vec4 chunks; the 0-3 leftover columns
            // are handled by the remainder branches below.
            const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
            const filterWidthVec4Remainder = filterWidth % 4;
            // `if (${isAvgPool})` bakes a compile-time true/false into the
            // shader, so only one of the two branches survives compilation.
            const updateSnippet = `
        if (${isAvgPool}) {
          avgValue += dot(values, ones);
        } else {
          minMaxValue = ${compareOp}(values, minMaxValue);
        }
      `;
            this.userCode = `
      const ivec3 strides =
        ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});
      const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
      const float initializationValue = ${initializationValue};
      const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);

      float count = 0.0;

      float getValue(int batch, int xD, int xR, int xC, int ch) {
        if (xC < 0 || xC >= ${convInfo.inWidth}) {
          return initializationValue;
        }
        count += 1.0;
        return getX(batch, xD, xR, xC, ch);
      }

      void main() {
        ivec5 coords = getOutputCoords();
        int batch = coords.x;
        int ch = coords.u;

        ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;
        int xDCorner = xCorner.x;
        int xRCorner = xCorner.y;
        int xCCorner = xCorner.z;

        // max/min x(?, ?, ?, d) to get y(yD, yR, yC, ch).
        // ? = to be determined
        vec4 minMaxValue = vec4(${initializationValue});
        float avgValue = 0.0;
        count = 0.0;

        for (int wD = 0; wD < ${effectiveFilterDepth};
            wD += ${dilationDepth}) {
          int xD = xDCorner + wD;

          if (xD < 0 || xD >= ${convInfo.inDepth}) {
            continue;
          }

          for (int wR = 0; wR < ${effectiveFilterHeight};
              wR += ${dilationHeight}) {
            int xR = xRCorner + wR;

            if (xR < 0 || xR >= ${convInfo.inHeight}) {
              continue;
            }

            for (int wC = 0; wC < ${filterWidthNearestVec4}; wC += 4) {
              int xC = xCCorner + wC * ${dilationWidth};

              vec4 values = vec4(
                getValue(batch, xD, xR, xC, ch),
                getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
                getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch),
                getValue(batch, xD, xR, xC + 3 * ${dilationWidth}, ch)
              );

              ${updateSnippet}
            }

            int xC = xCCorner + ${filterWidthNearestVec4};
            if (${filterWidthVec4Remainder === 1}) {
              vec4 values = vec4(
                getValue(batch, xD, xR, xC, ch),
                initializationValue,
                initializationValue,
                initializationValue
              );

              ${updateSnippet}
            } else if (${filterWidthVec4Remainder === 2}) {
              vec4 values = vec4(
                getValue(batch, xD, xR, xC, ch),
                getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
                initializationValue,
                initializationValue
              );

              ${updateSnippet}
            } else if (${filterWidthVec4Remainder === 3}) {
              vec4 values = vec4(
                getValue(batch, xD, xR, xC, ch),
                getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
                getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch),
                initializationValue
              );

              ${updateSnippet}
            }
          }
        }
        setOutput(${returnValue});
      }
    `;
        }
    }
92099
92100 /**
92101 * @license
92102 * Copyright 2020 Google LLC. All Rights Reserved.
92103 * Licensed under the Apache License, Version 2.0 (the "License");
92104 * you may not use this file except in compliance with the License.
92105 * You may obtain a copy of the License at
92106 *
92107 * http://www.apache.org/licenses/LICENSE-2.0
92108 *
92109 * Unless required by applicable law or agreed to in writing, software
92110 * distributed under the License is distributed on an "AS IS" BASIS,
92111 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92112 * See the License for the specific language governing permissions and
92113 * limitations under the License.
92114 * =============================================================================
92115 */
92116 function avgPool(args) {
92117 const { inputs, backend, attrs } = args;
92118 const { x } = inputs;
92119 assertNotComplex(x, 'avgPool');
92120 const { filterSize, strides, pad, dimRoundingMode } = attrs;
92121 const dilations = 1;
92122 assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in avgPool: Either strides or dilations must be 1. ' +
92123 `Got strides ${strides} and dilations '${dilations}'`);
92124 const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
92125 if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
92126 arraysEqual(convInfo.inShape, convInfo.outShape)) {
92127 return identity({ inputs: { x }, backend });
92128 }
92129 const avgPoolProgram = new Pool2DProgram(convInfo, 'avg', false);
92130 return backend.runWebGLProgram(avgPoolProgram, [x], 'float32');
92131 }
92132 const avgPoolConfig = {
92133 kernelName: AvgPool,
92134 backendName: 'webgl',
92135 kernelFunc: avgPool
92136 };
92137
92138 /**
92139 * @license
92140 * Copyright 2020 Google LLC. All Rights Reserved.
92141 * Licensed under the Apache License, Version 2.0 (the "License");
92142 * you may not use this file except in compliance with the License.
92143 * You may obtain a copy of the License at
92144 *
92145 * http://www.apache.org/licenses/LICENSE-2.0
92146 *
92147 * Unless required by applicable law or agreed to in writing, software
92148 * distributed under the License is distributed on an "AS IS" BASIS,
92149 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92150 * See the License for the specific language governing permissions and
92151 * limitations under the License.
92152 * =============================================================================
92153 */
92154 function avgPool3D(args) {
92155 const { inputs, backend, attrs } = args;
92156 const { x } = inputs;
92157 const { filterSize, strides, pad, dimRoundingMode, dataFormat } = attrs;
92158 const dilations = [1, 1, 1];
92159 const convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat);
92160 const avgPoolProgram = new Pool3DProgram(convInfo, 'avg', false);
92161 return backend.runWebGLProgram(avgPoolProgram, [x], 'float32');
92162 }
92163 const avgPool3DConfig = {
92164 kernelName: AvgPool3D,
92165 backendName: 'webgl',
92166 kernelFunc: avgPool3D
92167 };
92168
92169 /**
92170 * @license
92171 * Copyright 2017 Google LLC. All Rights Reserved.
92172 * Licensed under the Apache License, Version 2.0 (the "License");
92173 * you may not use this file except in compliance with the License.
92174 * You may obtain a copy of the License at
92175 *
92176 * http://www.apache.org/licenses/LICENSE-2.0
92177 *
92178 * Unless required by applicable law or agreed to in writing, software
92179 * distributed under the License is distributed on an "AS IS" BASIS,
92180 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92181 * See the License for the specific language governing permissions and
92182 * limitations under the License.
92183 * =============================================================================
92184 */
    /**
     * GPGPU program that generates the GLSL shader for the gradient of 2D
     * average pooling: each input position accumulates the upstream gradients
     * `dy` of every output window that covered it, scaled by 1 / (filter
     * area).
     *
     * The shader iterates the (effective, i.e. dilated) filter extent around
     * each input position and uses `fract(...) > 0.0` to reject output
     * coordinates that do not land exactly on a stride step.
     */
    class AvgPool2DBackpropProgram {
        /**
         * @param convInfo Conv2DInfo of the forward avg-pool op; the output
         *     of this program has the forward op's input shape.
         */
        constructor(convInfo) {
            this.variableNames = ['dy'];
            this.outputShape = convInfo.inShape;
            const filterHeight = convInfo.filterHeight;
            const filterWidth = convInfo.filterWidth;
            const strideHeight = convInfo.strideHeight;
            const strideWidth = convInfo.strideWidth;
            const dilationHeight = convInfo.dilationHeight;
            const dilationWidth = convInfo.dilationWidth;
            const effectiveFilterHeight = convInfo.effectiveFilterHeight;
            const effectiveFilterWidth = convInfo.effectiveFilterWidth;
            // Padding is mirrored for the backward pass (transposed
            // convolution style).
            const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
            const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
            // Each dy value contributed equally to filterHeight*filterWidth
            // outputs in the forward pass, so gradients are scaled by the
            // reciprocal of the filter area.
            const avgMultiplier = 1 / (filterHeight * filterWidth);
            this.userCode = `
      const ivec2 pads = ivec2(${padTop}, ${padLeft});
      const float avgMultiplier = float(${avgMultiplier});

      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int d = coords[3];

        ivec2 dyRCCorner = coords.yz - pads;
        int dyRCorner = dyRCCorner.x;
        int dyCCorner = dyRCCorner.y;

        // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).
        // ? = to be determined. : = across all values in that axis.
        float dotProd = 0.0;
        for (int wR = 0; wR < ${effectiveFilterHeight};
            wR += ${dilationHeight}) {
          float dyR = float(dyRCorner + wR) / ${strideHeight}.0;

          if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
            continue;
          }
          int idyR = int(dyR);

          for (int wC = 0; wC < ${effectiveFilterWidth};
              wC+= ${dilationWidth}) {
            float dyC = float(dyCCorner + wC) / ${strideWidth}.0;

            if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
                fract(dyC) > 0.0) {
              continue;
            }
            int idyC = int(dyC);

            float dyValue = getDy(b, idyR, idyC, d);

            dotProd += dyValue * avgMultiplier;
          }
        }
        setOutput(dotProd);
      }
    `;
        }
    }
92245 class AvgPool3DBackpropProgram {
92246 constructor(convInfo) {
92247 this.variableNames = ['dy'];
92248 this.outputShape = convInfo.inShape;
92249 const filterDepth = convInfo.filterDepth;
92250 const filterHeight = convInfo.filterHeight;
92251 const filterWidth = convInfo.filterWidth;
92252 const strideDepth = convInfo.strideDepth;
92253 const strideHeight = convInfo.strideHeight;
92254 const strideWidth = convInfo.strideWidth;
92255 const dilationDepth = convInfo.dilationDepth;
92256 const dilationHeight = convInfo.dilationHeight;
92257 const dilationWidth = convInfo.dilationWidth;
92258 const effectiveFilterDepth = convInfo.effectiveFilterDepth;
92259 const effectiveFilterHeight = convInfo.effectiveFilterHeight;
92260 const effectiveFilterWidth = convInfo.effectiveFilterWidth;
92261 const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
92262 const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
92263 const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
92264 const avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth);
92265 this.userCode = `
92266 const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
92267 const float avgMultiplier = float(${avgMultiplier});
92268
92269 void main() {
92270 ivec5 coords = getOutputCoords();
92271 int batch = coords.x;
92272 int ch = coords.u;
92273
92274 ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;
92275 int dyDCorner = dyCorner.x;
92276 int dyRCorner = dyCorner.y;
92277 int dyCCorner = dyCorner.z;
92278
92279 // Convolve dy(?, ?, ?, d) with pos mask(:, :, :, ch) to get
92280 // dx(xD, xR, xC, ch).
92281 // ? = to be determined. : = across all values in that axis.
92282 float dotProd = 0.0;
92283
92284 for (int wD = 0; wD < ${effectiveFilterDepth};
92285 wD += ${dilationDepth}) {
92286 float dyD = float(dyDCorner + wD) / ${strideDepth}.0;
92287
92288 if (dyD < 0.0 || dyD >= ${convInfo.outDepth}.0 || fract(dyD) > 0.0) {
92289 continue;
92290 }
92291 int idyD = int(dyD);
92292
92293 for (int wR = 0; wR < ${effectiveFilterHeight};
92294 wR += ${dilationHeight}) {
92295 float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
92296
92297 if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 ||
92298 fract(dyR) > 0.0) {
92299 continue;
92300 }
92301 int idyR = int(dyR);
92302
92303 for (int wC = 0; wC < ${effectiveFilterWidth};
92304 wC += ${dilationWidth}) {
92305 float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
92306
92307 if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
92308 fract(dyC) > 0.0) {
92309 continue;
92310 }
92311 int idyC = int(dyC);
92312
92313 float dyValue = getDy(batch, idyD, idyR, idyC, ch);
92314
92315 dotProd += dyValue * avgMultiplier;
92316 }
92317 }
92318 }
92319 setOutput(dotProd);
92320 }
92321 `;
92322 }
92323 }
92324
92325 /**
92326 * @license
92327 * Copyright 2020 Google LLC. All Rights Reserved.
92328 * Licensed under the Apache License, Version 2.0 (the "License");
92329 * you may not use this file except in compliance with the License.
92330 * You may obtain a copy of the License at
92331 *
92332 * http://www.apache.org/licenses/LICENSE-2.0
92333 *
92334 * Unless required by applicable law or agreed to in writing, software
92335 * distributed under the License is distributed on an "AS IS" BASIS,
92336 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92337 * See the License for the specific language governing permissions and
92338 * limitations under the License.
92339 * =============================================================================
92340 */
92341 function avgPool3DGrad(args) {
92342 const { inputs, backend, attrs } = args;
92343 const { dy, input } = inputs;
92344 const x = input;
92345 const { filterSize, strides, pad, dimRoundingMode } = attrs;
92346 const dilations = [1, 1, 1];
92347 const convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
92348 const avgPoolBackpropProgram = new AvgPool3DBackpropProgram(convInfo);
92349 return backend.runWebGLProgram(avgPoolBackpropProgram, [dy], x.dtype);
92350 }
92351 const avgPool3DGradConfig = {
92352 kernelName: AvgPool3DGrad,
92353 backendName: 'webgl',
92354 kernelFunc: avgPool3DGrad
92355 };
92356
92357 /**
92358 * @license
92359 * Copyright 2020 Google LLC. All Rights Reserved.
92360 * Licensed under the Apache License, Version 2.0 (the "License");
92361 * you may not use this file except in compliance with the License.
92362 * You may obtain a copy of the License at
92363 *
92364 * http://www.apache.org/licenses/LICENSE-2.0
92365 *
92366 * Unless required by applicable law or agreed to in writing, software
92367 * distributed under the License is distributed on an "AS IS" BASIS,
92368 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92369 * See the License for the specific language governing permissions and
92370 * limitations under the License.
92371 * =============================================================================
92372 */
92373 function avgPoolGrad(args) {
92374 const { inputs, backend, attrs } = args;
92375 const { dy, input } = inputs;
92376 const x = input;
92377 assertNotComplex([dy, input], 'avgPoolGrad');
92378 const { filterSize, strides, pad } = attrs;
92379 const convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad);
92380 const avgPoolBackpropProgram = new AvgPool2DBackpropProgram(convInfo);
92381 return backend.runWebGLProgram(avgPoolBackpropProgram, [dy], x.dtype);
92382 }
92383 const avgPoolGradConfig = {
92384 kernelName: AvgPoolGrad,
92385 backendName: 'webgl',
92386 kernelFunc: avgPoolGrad
92387 };
92388
92389 /**
92390 * @license
92391 * Copyright 2020 Google LLC. All Rights Reserved.
92392 * Licensed under the Apache License, Version 2.0 (the "License");
92393 * you may not use this file except in compliance with the License.
92394 * You may obtain a copy of the License at
92395 *
92396 * http://www.apache.org/licenses/LICENSE-2.0
92397 *
92398 * Unless required by applicable law or agreed to in writing, software
92399 * distributed under the License is distributed on an "AS IS" BASIS,
92400 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92401 * See the License for the specific language governing permissions and
92402 * limitations under the License.
92403 * =============================================================================
92404 */
92405 function batchMatMul(args) {
92406 const { inputs, backend, attrs } = args;
92407 const { a, b } = inputs;
92408 const { transposeA, transposeB } = attrs;
92409 return batchMatMulImpl({ a, b, transposeA, transposeB, backend });
92410 }
92411 const batchMatMulConfig = {
92412 kernelName: BatchMatMul,
92413 backendName: 'webgl',
92414 kernelFunc: batchMatMul,
92415 };
92416
92417 /**
92418 * @license
92419 * Copyright 2017 Google LLC. All Rights Reserved.
92420 * Licensed under the Apache License, Version 2.0 (the "License");
92421 * you may not use this file except in compliance with the License.
92422 * You may obtain a copy of the License at
92423 *
92424 * http://www.apache.org/licenses/LICENSE-2.0
92425 *
92426 * Unless required by applicable law or agreed to in writing, software
92427 * distributed under the License is distributed on an "AS IS" BASIS,
92428 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92429 * See the License for the specific language governing permissions and
92430 * limitations under the License.
92431 * =============================================================================
92432 */
92433 class BatchNormProgram {
92434 constructor(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
92435 this.outputShape = [];
92436 this.variableNames = ['x', 'mean', 'variance'];
92437 assertAndGetBroadcastShape(xShape, meanShape);
92438 assertAndGetBroadcastShape(xShape, varianceShape);
92439 let offsetSnippet = '0.0';
92440 if (offsetShape != null) {
92441 assertAndGetBroadcastShape(xShape, offsetShape);
92442 this.variableNames.push('offset');
92443 offsetSnippet = 'getOffsetAtOutCoords()';
92444 }
92445 let scaleSnippet = '1.0';
92446 if (scaleShape != null) {
92447 assertAndGetBroadcastShape(xShape, scaleShape);
92448 this.variableNames.push('scale');
92449 scaleSnippet = 'getScaleAtOutCoords()';
92450 }
92451 this.outputShape = xShape;
92452 this.userCode = `
92453 void main() {
92454 float x = getXAtOutCoords();
92455 float mean = getMeanAtOutCoords();
92456 float variance = getVarianceAtOutCoords();
92457 float offset = ${offsetSnippet};
92458 float scale = ${scaleSnippet};
92459 float inv = scale * inversesqrt(variance + float(${varianceEpsilon}));
92460 setOutput(dot(vec3(x, -mean, offset), vec3(inv, inv, 1)));
92461 }
92462 `;
92463 }
92464 }
92465
92466 /**
92467 * @license
92468 * Copyright 2018 Google LLC. All Rights Reserved.
92469 * Licensed under the Apache License, Version 2.0 (the "License");
92470 * you may not use this file except in compliance with the License.
92471 * You may obtain a copy of the License at
92472 *
92473 * http://www.apache.org/licenses/LICENSE-2.0
92474 *
92475 * Unless required by applicable law or agreed to in writing, software
92476 * distributed under the License is distributed on an "AS IS" BASIS,
92477 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92478 * See the License for the specific language governing permissions and
92479 * limitations under the License.
92480 * =============================================================================
92481 */
92482 class BatchNormPackedProgram {
92483 constructor(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
92484 this.packedInputs = true;
92485 this.packedOutput = true;
92486 this.variableNames = ['x', 'mean', 'variance'];
92487 assertAndGetBroadcastShape(xShape, meanShape);
92488 assertAndGetBroadcastShape(xShape, varianceShape);
92489 let offsetSnippet = 'vec4(0.0)';
92490 if (offsetShape != null) {
92491 assertAndGetBroadcastShape(xShape, offsetShape);
92492 this.variableNames.push('offset');
92493 offsetSnippet = 'getOffsetAtOutCoords()';
92494 }
92495 let scaleSnippet = 'vec4(1.0)';
92496 if (scaleShape != null) {
92497 assertAndGetBroadcastShape(xShape, scaleShape);
92498 this.variableNames.push('scale');
92499 scaleSnippet = 'getScaleAtOutCoords()';
92500 }
92501 this.outputShape = xShape;
92502 this.userCode = `
92503 void main() {
92504 vec4 offset = ${offsetSnippet};
92505 vec4 scale = ${scaleSnippet};
92506
92507 vec4 x = getXAtOutCoords();
92508 vec4 mean = getMeanAtOutCoords();
92509 vec4 variance = getVarianceAtOutCoords();
92510
92511 vec4 inv = scale * inversesqrt(variance + vec4(${varianceEpsilon}));
92512
92513 setOutput((x - mean) * inv + offset);
92514 }
92515 `;
92516 }
92517 }
92518
92519 /**
92520 * @license
92521 * Copyright 2020 Google LLC. All Rights Reserved.
92522 * Licensed under the Apache License, Version 2.0 (the "License");
92523 * you may not use this file except in compliance with the License.
92524 * You may obtain a copy of the License at
92525 *
92526 * http://www.apache.org/licenses/LICENSE-2.0
92527 *
92528 * Unless required by applicable law or agreed to in writing, software
92529 * distributed under the License is distributed on an "AS IS" BASIS,
92530 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92531 * See the License for the specific language governing permissions and
92532 * limitations under the License.
92533 * =============================================================================
92534 */
92535 const batchNorm = ({ inputs, backend, attrs }) => {
92536 const { x, mean, variance, offset, scale } = inputs;
92537 assert$1(mean.shape.length === variance.shape.length, () => 'Batch normalization gradient requires mean and variance to have ' +
92538 'equal ranks.');
92539 assert$1(offset == null || mean.shape.length === offset.shape.length, () => 'Batch normalization gradient requires mean and offset to have ' +
92540 'equal ranks.');
92541 assert$1(scale == null || mean.shape.length === scale.shape.length, () => 'Batch normalization gradient requires mean and scale to have ' +
92542 'equal ranks.');
92543 let { varianceEpsilon } = attrs;
92544 if (varianceEpsilon == null) {
92545 varianceEpsilon = 0.001;
92546 }
92547 const finalInputs = [x, mean, variance];
92548 let offsetShape = null;
92549 if (offset != null) {
92550 offsetShape = offset.shape;
92551 finalInputs.push(offset);
92552 }
92553 let scaleShape = null;
92554 if (scale != null) {
92555 scaleShape = scale.shape;
92556 finalInputs.push(scale);
92557 }
92558 const program = env().getBool('WEBGL_PACK_NORMALIZATION') ?
92559 new BatchNormPackedProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon) :
92560 new BatchNormProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon);
92561 const output = backend.runWebGLProgram(program, finalInputs, finalInputs[0].dtype);
92562 return output;
92563 };
92564 const batchNormConfig = {
92565 kernelName: FusedBatchNorm,
92566 backendName: 'webgl',
92567 kernelFunc: batchNorm,
92568 };
92569
92570 /**
92571 * @license
92572 * Copyright 2017 Google LLC. All Rights Reserved.
92573 * Licensed under the Apache License, Version 2.0 (the "License");
92574 * you may not use this file except in compliance with the License.
92575 * You may obtain a copy of the License at
92576 *
92577 * http://www.apache.org/licenses/LICENSE-2.0
92578 *
92579 * Unless required by applicable law or agreed to in writing, software
92580 * distributed under the License is distributed on an "AS IS" BASIS,
92581 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92582 * See the License for the specific language governing permissions and
92583 * limitations under the License.
92584 * =============================================================================
92585 */
92586 class SliceProgram {
92587 constructor(destSize) {
92588 this.variableNames = ['source'];
92589 this.outputShape = destSize;
92590 this.rank = destSize.length;
92591 const dtype = getCoordsDataType(this.rank);
92592 this.customUniforms = [{ name: 'start', arrayIndex: this.rank, type: 'int' }];
92593 const sourceCoords = getCoords$1(this.rank);
92594 let body;
92595 const coordSum = destSize.map((_, i) => {
92596 return `sourceLoc.${coords[i]} = start[${i}] + coords.${coords[i]};`;
92597 });
92598 body = `
92599 ${dtype} sourceLoc;
92600 ${dtype} coords = getOutputCoords();
92601 ${coordSum.join('\n')}
92602 `;
92603 this.userCode = `
92604 void main() {
92605 ${body}
92606 setOutput(getSource(${sourceCoords}));
92607 }
92608 `;
92609 }
92610 }
92611 const coords = ['x', 'y', 'z', 'w', 'u', 'v'];
92612 function getCoords$1(rank) {
92613 if (rank === 1) {
92614 return 'sourceLoc';
92615 }
92616 else if (rank <= 6) {
92617 return coords.slice(0, rank).map(x => 'sourceLoc.' + x).join(',');
92618 }
92619 else {
92620 throw Error(`Slicing for rank ${rank} is not yet supported`);
92621 }
92622 }
92623
92624 /**
92625 * @license
92626 * Copyright 2019 Google LLC. All Rights Reserved.
92627 * Licensed under the Apache License, Version 2.0 (the "License");
92628 * you may not use this file except in compliance with the License.
92629 * You may obtain a copy of the License at
92630 *
92631 * http://www.apache.org/licenses/LICENSE-2.0
92632 *
92633 * Unless required by applicable law or agreed to in writing, software
92634 * distributed under the License is distributed on an "AS IS" BASIS,
92635 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92636 * See the License for the specific language governing permissions and
92637 * limitations under the License.
92638 * =============================================================================
92639 */
    /**
     * Packed WebGL program for Slice. Each output texel holds a 2x2 block
     * (result.x/y = upper row, result.z/w = lower row), so the shader reads
     * up to four source values per invocation, guarding each read against
     * the slice bounds in the last two dimensions.
     */
    class SlicePackedProgram {
        /** @param destSize Shape of the slice to extract. */
        constructor(destSize) {
            this.variableNames = ['source'];
            this.packedInputs = true;
            this.packedOutput = true;
            this.outputShape = destSize;
            this.rank = destSize.length;
            // Slice offset is a dispatch-time uniform, one int per dimension.
            this.customUniforms = [{ name: 'start', arrayIndex: this.rank, type: 'int' }];
            const dtype = getCoordsDataType(this.rank);
            const coords = getChannels('coords', this.rank);
            const sourceLoc = getChannels('sourceLoc', this.rank);
            // The last two dims select the channel within a packed texel.
            const innerDims = this.rank === 1 ? 'sourceLoc' : `vec2(${sourceLoc.slice(-2).join()})`;
            const getChannel = `getChannel(getSource(${sourceLoc.join()}), ${innerDims})`;
            // Upper row of the 2x2 block: result.x always; result.y only if
            // the next column is still inside the slice. sourceLoc is
            // incremented for the read and restored afterwards.
            const upperRow = `
        result.x = ${getChannel};
        if (++${coords[this.rank - 1]} < ${destSize[this.rank - 1]}) {
          ++${sourceLoc[this.rank - 1]};
          result.y = ${getChannel};
          --${sourceLoc[this.rank - 1]};
        }
      `;
            // Lower row (rank >= 2 only): step to the next row if in bounds,
            // then fill result.z and (column permitting) result.w.
            const lowerRow = this.rank === 1 ? '' : `
        --${coords[this.rank - 1]};
        if (++${coords[this.rank - 2]} < ${destSize[this.rank - 2]}) {
          ++${sourceLoc[this.rank - 2]};
          result.z = ${getChannel};
          if (++${coords[this.rank - 1]} < ${destSize[this.rank - 1]}) {
            ++${sourceLoc[this.rank - 1]};
            result.w = ${getChannel};
          }
        }
      `;
            // Rank <= 4 coordinate types support vector addition; higher
            // ranks assign each component individually.
            const sourceLocSetup = this.rank <= 4 ?
                `sourceLoc = coords +
            ${dtype}(${destSize.map((_, i) => `start[${i}]`).join()});` :
                destSize.map((_, i) => `${sourceLoc[i]} = ${coords[i]} + start[${i}];`)
                    .join('\n');
            this.userCode = `
      void main() {
        ${dtype} coords = getOutputCoords();
        ${dtype} sourceLoc;
        ${sourceLocSetup}
        vec4 result = vec4(0.);
        ${upperRow}
        ${lowerRow}
        setOutput(result);
      }
    `;
        }
    }
92690
92691 /**
92692 * @license
92693 * Copyright 2020 Google LLC. All Rights Reserved.
92694 * Licensed under the Apache License, Version 2.0 (the "License");
92695 * you may not use this file except in compliance with the License.
92696 * You may obtain a copy of the License at
92697 *
92698 * http://www.apache.org/licenses/LICENSE-2.0
92699 *
92700 * Unless required by applicable law or agreed to in writing, software
92701 * distributed under the License is distributed on an "AS IS" BASIS,
92702 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92703 * See the License for the specific language governing permissions and
92704 * limitations under the License.
92705 * =============================================================================
92706 */
92707 function shallowSlice(x, begin, size, backend) {
92708 const xTexData = backend.texData.get(x.dataId);
92709 const t = backend.makeTensorInfo(size, x.dtype);
92710 const newTexData = backend.texData.get(t.dataId);
92711 // Copy texture data from the original tensor.
92712 Object.assign(newTexData, xTexData);
92713 newTexData.refCount = 1;
92714 newTexData.shape = size;
92715 newTexData.dtype = x.dtype;
92716 let flatOffset = computeFlatOffset(begin, computeStrides(x.shape));
92717 if (xTexData.slice) {
92718 // We are slicing an already sliced tensor, so we have to accumulate
92719 // the offset.
92720 flatOffset += xTexData.slice.flatOffset;
92721 }
92722 newTexData.slice = {
92723 flatOffset,
92724 // Point to the original dataId, which is used to do ref counting.
92725 origDataId: xTexData.slice && xTexData.slice.origDataId || x.dataId
92726 };
92727 // Increase the ref count for that data bucket.
92728 const refCount = backend.dataRefCount.get(newTexData.slice.origDataId) || 1;
92729 backend.dataRefCount.set(newTexData.slice.origDataId, refCount + 1);
92730 return t;
92731 }
92732 function slice(args) {
92733 const { inputs, backend, attrs } = args;
92734 const { x } = inputs;
92735 const { begin, size } = attrs;
92736 const [$begin, $size] = parseSliceParams(x, begin, size);
92737 assertParamsValid(x, $begin, $size);
92738 if (sizeFromShape($size) === 0) {
92739 return backend.makeTensorInfo($size, x.dtype, []);
92740 }
92741 // Run on cpu if dtype is string. For string, the backend represents it
92742 // as Uint8Array[], where each Uint8Array is a character. Given that the
92743 // computation is only on the outer array, uploading the whole data onto
92744 // gpu is wasteful. Also, currently webgl doesn't have a design to
92745 // upload and retrieve Uint8Array[] between cpu and gpu. Therefore, we
92746 // just run the kernel on cpu if dtype is string.
92747 if (backend.shouldExecuteOnCPU([x]) || x.dtype === 'string') {
92748 const xTexData = backend.texData.get(x.dataId);
92749 const outValues = sliceImplCPU(xTexData.values, $begin, $size, x.shape, x.dtype);
92750 return backend.makeTensorInfo($size, x.dtype, outValues);
92751 }
92752 const { isPacked } = backend.texData.get(x.dataId);
92753 const isContinous = isSliceContinous(x.shape, $begin, $size);
92754 if (isPacked || !isContinous) {
92755 const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
92756 new SlicePackedProgram($size) :
92757 new SliceProgram($size);
92758 const customValues = [$begin];
92759 return backend.runWebGLProgram(program, [x], x.dtype, customValues);
92760 }
92761 backend.uploadToGPU(x.dataId);
92762 return shallowSlice(x, $begin, $size, backend);
92763 }
92764 const sliceConfig = {
92765 kernelName: Slice,
92766 backendName: 'webgl',
92767 kernelFunc: slice
92768 };
92769
92770 /**
92771 * @license
92772 * Copyright 2020 Google LLC. All Rights Reserved.
92773 * Licensed under the Apache License, Version 2.0 (the "License");
92774 * you may not use this file except in compliance with the License.
92775 * You may obtain a copy of the License at
92776 *
92777 * http://www.apache.org/licenses/LICENSE-2.0
92778 *
92779 * Unless required by applicable law or agreed to in writing, software
92780 * distributed under the License is distributed on an "AS IS" BASIS,
92781 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92782 * See the License for the specific language governing permissions and
92783 * limitations under the License.
92784 * =============================================================================
92785 */
92786 const batchToSpaceND = (args) => {
92787 const { inputs, backend, attrs } = args;
92788 const { x } = inputs;
92789 const { blockShape, crops } = attrs;
92790 assert$1(x.shape.length <= 4, () => 'batchToSpaceND for rank > 4 with a WebGL backend not ' +
92791 'implemented yet');
92792 const prod = blockShape.reduce((a, b) => a * b);
92793 const reshaped = getReshaped(x.shape, blockShape, prod);
92794 const permuted = getPermuted(reshaped.length, blockShape.length);
92795 const reshapedPermuted = getReshapedPermuted(x.shape, blockShape, prod);
92796 const sliceBeginCoords = getSliceBeginCoords(crops, blockShape.length);
92797 const sliceSize = getSliceSize(reshapedPermuted, crops, blockShape.length);
92798 const toDispose = [];
92799 const reshapedIntermediate = reshape({ inputs: { x }, backend, attrs: { shape: reshaped } });
92800 const transposedIntermediate = transpose({ inputs: { x: reshapedIntermediate }, backend, attrs: { perm: permuted } });
92801 const reshapedIntermediate2 = reshape({
92802 inputs: { x: transposedIntermediate },
92803 backend,
92804 attrs: { shape: reshapedPermuted }
92805 });
92806 const sliced = slice({
92807 inputs: { x: reshapedIntermediate2 },
92808 backend,
92809 attrs: { begin: sliceBeginCoords, size: sliceSize }
92810 });
92811 toDispose.push(reshapedIntermediate);
92812 toDispose.push(transposedIntermediate);
92813 toDispose.push(reshapedIntermediate2);
92814 toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
92815 return sliced;
92816 };
92817 const batchToSpaceNDConfig = {
92818 kernelName: BatchToSpaceND,
92819 backendName: 'webgl',
92820 kernelFunc: batchToSpaceND
92821 };
92822
92823 /**
92824 * @license
92825 * Copyright 2020 Google LLC. All Rights Reserved.
92826 * Licensed under the Apache License, Version 2.0 (the "License");
92827 * you may not use this file except in compliance with the License.
92828 * You may obtain a copy of the License at
92829 *
92830 * http://www.apache.org/licenses/LICENSE-2.0
92831 *
92832 * Unless required by applicable law or agreed to in writing, software
92833 * distributed under the License is distributed on an "AS IS" BASIS,
92834 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92835 * See the License for the specific language governing permissions and
92836 * limitations under the License.
92837 * =============================================================================
92838 */
92839 function bincount(args) {
92840 const { inputs, backend, attrs } = args;
92841 const { x, weights } = inputs;
92842 const { size } = attrs;
92843 const xVals = backend.readSync(x.dataId);
92844 const weightsVals = backend.readSync(weights.dataId);
92845 const outVals = bincountImplCPU(xVals, weightsVals, weights.dtype, weights.shape, size);
92846 return backend.makeTensorInfo([size], weights.dtype, outVals);
92847 }
92848 const bincountConfig = {
92849 kernelName: Bincount,
92850 backendName: 'webgl',
92851 kernelFunc: bincount
92852 };
92853
92854 /**
92855 * @license
92856 * Copyright 2023 Google LLC.
92857 * Licensed under the Apache License, Version 2.0 (the "License");
92858 * you may not use this file except in compliance with the License.
92859 * You may obtain a copy of the License at
92860 *
92861 * http://www.apache.org/licenses/LICENSE-2.0
92862 *
92863 * Unless required by applicable law or agreed to in writing, software
92864 * distributed under the License is distributed on an "AS IS" BASIS,
92865 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92866 * See the License for the specific language governing permissions and
92867 * limitations under the License.
92868 * =============================================================================
92869 */
92870 const BITWISEAND = `
92871 int r = int(a.r) & int(b.r);
92872 int g = int(a.g) & int(b.g);
92873 int rb = int(a.b) & int(b.b);
92874 int ra = int(a.a) & int(b.a);
92875 return vec4(r, g, rb, ra);
92876`;
92877 const BITWISEAND_UNPACKED = `
92878 return float(int(a.r) & int(b.r));
92879`;
92880 function bitwiseAnd(args) {
92881 const { inputs, backend } = args;
92882 const { a, b } = inputs;
92883 const shouldUsePackedProgram = env().getBool('WEBGL_PACK_BINARY_OPERATIONS');
92884 const versionNumber = env().getNumber('WEBGL_VERSION');
92885 // The type of a and b are ensured to be `int32` in core, therefore no need to
92886 // consider other type situations.
92887 if ((backend.shouldExecuteOnCPU([a, b])) || versionNumber === 1) {
92888 const aVals = backend.texData.get(a.dataId).values;
92889 const bVals = backend.texData.get(b.dataId).values;
92890 const [outValues, outShape] = bitwiseAndImplCPU(a.shape, b.shape, aVals, bVals, a.dtype);
92891 const out = backend.makeTensorInfo(outShape, a.dtype);
92892 const outData = backend.texData.get(out.dataId);
92893 outData.values = outValues;
92894 return out;
92895 }
92896 let program;
92897 if (shouldUsePackedProgram) {
92898 program = new BinaryOpPackedProgram(BITWISEAND, a.shape, b.shape, false);
92899 }
92900 else {
92901 program = new BinaryOpProgram(BITWISEAND_UNPACKED, a.shape, b.shape);
92902 }
92903 return backend.runWebGLProgram(program, [a, b], a.dtype);
92904 }
92905 const bitwiseAndConfig = {
92906 kernelName: BitwiseAnd,
92907 backendName: 'webgl',
92908 kernelFunc: bitwiseAnd
92909 };
92910
92911 /**
92912 * @license
92913 * Copyright 2021 Google LLC. All Rights Reserved.
92914 * Licensed under the Apache License, Version 2.0 (the "License");
92915 * you may not use this file except in compliance with the License.
92916 * You may obtain a copy of the License at
92917 *
92918 * http://www.apache.org/licenses/LICENSE-2.0
92919 *
92920 * Unless required by applicable law or agreed to in writing, software
92921 * distributed under the License is distributed on an "AS IS" BASIS,
92922 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92923 * See the License for the specific language governing permissions and
92924 * limitations under the License.
92925 * =============================================================================
92926 */
92927 function broadcastArgs(args) {
92928 const { inputs, backend } = args;
92929 const { s0, s1 } = inputs;
92930 const s0Vals = backend.readSync(s0.dataId);
92931 const s1Vals = backend.readSync(s1.dataId);
92932 const broadcastShape = assertAndGetBroadcastShape(Array.from(s0Vals), Array.from(s1Vals));
92933 return backend.makeTensorInfo([broadcastShape.length], 'int32', Int32Array.from(broadcastShape));
92934 }
92935 const broadcastArgsConfig = {
92936 kernelName: BroadcastArgs,
92937 backendName: 'webgl',
92938 kernelFunc: broadcastArgs
92939 };
92940
92941 /**
92942 * @license
92943 * Copyright 2020 Google LLC. All Rights Reserved.
92944 * Licensed under the Apache License, Version 2.0 (the "License");
92945 * you may not use this file except in compliance with the License.
92946 * You may obtain a copy of the License at
92947 *
92948 * http://www.apache.org/licenses/LICENSE-2.0
92949 *
92950 * Unless required by applicable law or agreed to in writing, software
92951 * distributed under the License is distributed on an "AS IS" BASIS,
92952 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92953 * See the License for the specific language governing permissions and
92954 * limitations under the License.
92955 * =============================================================================
92956 */
    // GLSL snippet for element-wise inequality; result encoded as 0.0/1.0.
    const NOT_EQUAL = `return float(a != b);`;
    // NotEqual kernel built from the shared binary-op helper; always returns
    // a bool tensor and forwards small/CPU-eligible inputs to the CPU impl.
    const notEqual = binaryKernelFunc({ opSnippet: NOT_EQUAL, cpuKernelImpl: notEqualImplCPU, dtype: 'bool' });
    const notEqualConfig = {
        kernelName: NotEqual,
        backendName: 'webgl',
        kernelFunc: notEqual,
    };
92964
92965 /**
92966 * @license
92967 * Copyright 2020 Google LLC. All Rights Reserved.
92968 * Licensed under the Apache License, Version 2.0 (the "License");
92969 * you may not use this file except in compliance with the License.
92970 * You may obtain a copy of the License at
92971 *
92972 * http://www.apache.org/licenses/LICENSE-2.0
92973 *
92974 * Unless required by applicable law or agreed to in writing, software
92975 * distributed under the License is distributed on an "AS IS" BASIS,
92976 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
92977 * See the License for the specific language governing permissions and
92978 * limitations under the License.
92979 * =============================================================================
92980 */
92981 function real(args) {
92982 const { inputs, backend } = args;
92983 const { input } = inputs;
92984 const inputData = backend.texData.get(input.dataId);
92985 return identity({ inputs: { x: inputData.complexTensorInfos.real }, backend });
92986 }
92987 const realConfig = {
92988 kernelName: Real,
92989 backendName: 'webgl',
92990 kernelFunc: real
92991 };
92992
92993 /**
92994 * @license
92995 * Copyright 2020 Google LLC. All Rights Reserved.
92996 * Licensed under the Apache License, Version 2.0 (the "License");
92997 * you may not use this file except in compliance with the License.
92998 * You may obtain a copy of the License at
92999 *
93000 * http://www.apache.org/licenses/LICENSE-2.0
93001 *
93002 * Unless required by applicable law or agreed to in writing, software
93003 * distributed under the License is distributed on an "AS IS" BASIS,
93004 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93005 * See the License for the specific language governing permissions and
93006 * limitations under the License.
93007 * =============================================================================
93008 */
    // GLSL snippet: truncate toward zero, matching JS `int()` cast semantics.
    const TO_INT = `return float(int(x));`;
    // Runs the truncation shader over `input` and returns a TensorInfo with
    // dtype forced to 'int32'. Used by cast() for float->int conversions.
    function int(input, backend) {
        const program = new UnaryOpProgram(input.shape, TO_INT);
        const output = backend.runWebGLProgram(program, [input], 'int32');
        return { dataId: output.dataId, shape: output.shape, dtype: output.dtype };
    }
93015
93016 /**
93017 * @license
93018 * Copyright 2020 Google LLC. All Rights Reserved.
93019 * Licensed under the Apache License, Version 2.0 (the "License");
93020 * you may not use this file except in compliance with the License.
93021 * You may obtain a copy of the License at
93022 *
93023 * http://www.apache.org/licenses/LICENSE-2.0
93024 *
93025 * Unless required by applicable law or agreed to in writing, software
93026 * distributed under the License is distributed on an "AS IS" BASIS,
93027 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93028 * See the License for the specific language governing permissions and
93029 * limitations under the License.
93030 * =============================================================================
93031 */
/**
 * WebGL Cast kernel: converts `x` to the requested `dtype`.
 *
 * Handles, in order: casts to complex64 (pair a float32 copy with a zero
 * imaginary part), casts from complex64 (drop the imaginary part), lossless
 * widening casts (relabel the dtype without touching data), CPU fallback,
 * and GPU int32/bool conversions. Throws for unsupported dtype pairs.
 */
function cast(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { dtype } = attrs;
    if (dtype === 'complex64') {
        if (x.dtype === 'complex64') {
            return identity({ inputs: { x }, backend });
        }
        // TODO(annxingyuan): Import kernel function once zeros is modularized.
        const imagZeros = zeros$2(x.shape);
        const realFloat = cast({ inputs: { x }, backend, attrs: { dtype: 'float32' } });
        const complexOut = complex({ inputs: { real: realFloat, imag: imagZeros }, backend });
        // `imagZeros` is a Tensor (not a TensorInfo), so it owns dispose().
        imagZeros.dispose();
        backend.disposeIntermediateTensorInfo(realFloat);
        return complexOut;
    }
    if (x.dtype === 'complex64') {
        // Casting from complex64: keep the real part, then cast it.
        const realPart = real({ inputs: { input: x }, backend });
        const casted = cast({ inputs: { x: realPart }, backend, attrs: { dtype } });
        backend.disposeIntermediateTensorInfo(realPart);
        return casted;
    }
    if (!hasEncodingLoss(x.dtype, dtype)) {
        // Widening cast: the underlying data is reused unchanged; only the
        // reported dtype differs.
        const out = identity({ inputs: { x }, backend });
        return { dataId: out.dataId, shape: out.shape, dtype };
    }
    if (backend.shouldExecuteOnCPU([x])) {
        const vals = backend.texData.get(x.dataId).values;
        const [outShape, outType, outData] = castImplCPU(vals, x.shape, x.dtype, dtype);
        return backend.makeTensorInfo(outShape, outType, outData);
    }
    if (dtype === 'int32') {
        return int(x, backend);
    }
    if (dtype === 'bool') {
        // bool(x) is implemented as x != 0 against a scalar zero.
        const scalarZero = backend.makeTensorInfo([], 'bool', getTypedArrayFromDType('bool', 1));
        const boolOut = notEqual({ inputs: { a: x, b: scalarZero }, backend });
        backend.disposeIntermediateTensorInfo(scalarZero);
        return boolOut;
    }
    throw new Error(`Error in Cast: failed to cast ${x.dtype} to ${dtype}`);
}
const castConfig = {
    kernelName: Cast,
    backendName: 'webgl',
    kernelFunc: cast
};
93084
93085 /**
93086 * @license
93087 * Copyright 2020 Google LLC. All Rights Reserved.
93088 * Licensed under the Apache License, Version 2.0 (the "License");
93089 * you may not use this file except in compliance with the License.
93090 * You may obtain a copy of the License at
93091 *
93092 * http://www.apache.org/licenses/LICENSE-2.0
93093 *
93094 * Unless required by applicable law or agreed to in writing, software
93095 * distributed under the License is distributed on an "AS IS" BASIS,
93096 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93097 * See the License for the specific language governing permissions and
93098 * limitations under the License.
93099 * =============================================================================
93100 */
// GLSL snippet for elementwise ceil. The same snippet serves both the
// unpacked (float x) and packed (vec4 x) program variants because GLSL's
// ceil() is overloaded for vectors.
const CEIL = `return ceil(x);`;
// Ceil kernel built from the generic unary-op helper, with a CPU
// implementation (ceilImplCPU) used when the backend prefers CPU execution.
const ceil = unaryKernelFunc({ opSnippet: CEIL, packedOpSnippet: CEIL, cpuKernelImpl: ceilImplCPU });
const ceilConfig = {
    kernelName: Ceil,
    backendName: 'webgl',
    kernelFunc: ceil
};
93108
93109 /**
93110 * @license
93111 * Copyright 2017 Google LLC. All Rights Reserved.
93112 * Licensed under the Apache License, Version 2.0 (the "License");
93113 * you may not use this file except in compliance with the License.
93114 * You may obtain a copy of the License at
93115 *
93116 * http://www.apache.org/licenses/LICENSE-2.0
93117 *
93118 * Unless required by applicable law or agreed to in writing, software
93119 * distributed under the License is distributed on an "AS IS" BASIS,
93120 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93121 * See the License for the specific language governing permissions and
93122 * limitations under the License.
93123 * =============================================================================
93124 */
/**
 * WebGL program clamping every element of input `A` into the inclusive
 * range [minVal, maxVal]. The bounds arrive as custom uniforms at dispatch
 * time (see clipByValue()), so one compiled shader serves all clip ranges.
 */
class ClipProgram {
    constructor(aShape) {
        this.variableNames = ['A'];
        // Filled per dispatch from customValues = [[min], [max]].
        this.customUniforms = [
            { name: 'minVal', type: 'float' },
            { name: 'maxVal', type: 'float' }
        ];
        // Elementwise op: output shape mirrors the input shape.
        this.outputShape = aShape;
        // NaN is passed through unchanged — presumably because GLSL clamp()
        // on NaN is not consistent across drivers; confirm before relying on
        // specific NaN semantics.
        this.userCode = `

      void main() {
        float value = getAAtOutCoords();
        if (isnan(value)) {
          setOutput(value);
          return;
        }

        setOutput(clamp(value, minVal, maxVal));
      }
    `;
    }
}
93147
93148 /**
93149 * @license
93150 * Copyright 2018 Google LLC. All Rights Reserved.
93151 * Licensed under the Apache License, Version 2.0 (the "License");
93152 * you may not use this file except in compliance with the License.
93153 * You may obtain a copy of the License at
93154 *
93155 * http://www.apache.org/licenses/LICENSE-2.0
93156 *
93157 * Unless required by applicable law or agreed to in writing, software
93158 * distributed under the License is distributed on an "AS IS" BASIS,
93159 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93160 * See the License for the specific language governing permissions and
93161 * limitations under the License.
93162 * =============================================================================
93163 */
/**
 * Packed variant of the clip program: clamps all four lanes of each 2x2
 * packed texel of input `A` into [minVal, maxVal]. Bounds are supplied as
 * custom uniforms per dispatch (see clipByValue()).
 */
class ClipPackedProgram {
    constructor(aShape) {
        this.variableNames = ['A'];
        // Both input and output use the packed 2x2 texel layout.
        this.packedInputs = true;
        this.packedOutput = true;
        this.customUniforms = [
            { name: 'minVal', type: 'float' },
            { name: 'maxVal', type: 'float' }
        ];
        this.outputShape = aShape;
        // NOTE(review): if ANY lane of the texel is NaN, the whole vec4 is
        // passed through unclamped, so the other three (non-NaN) lanes skip
        // clipping too — looks like an accepted trade-off for NaN
        // propagation; verify against the unpacked ClipProgram's behavior.
        this.userCode = `
      void main() {
        vec4 value = getAAtOutCoords();

        if (any(isnan(value))) {
          setOutput(value);
          return;
        }

        setOutput(clamp(value, vec4(minVal), vec4(maxVal)));
      }
    `;
    }
}
93188
93189 /**
93190 * @license
93191 * Copyright 2020 Google LLC. All Rights Reserved.
93192 * Licensed under the Apache License, Version 2.0 (the "License");
93193 * you may not use this file except in compliance with the License.
93194 * You may obtain a copy of the License at
93195 *
93196 * http://www.apache.org/licenses/LICENSE-2.0
93197 *
93198 * Unless required by applicable law or agreed to in writing, software
93199 * distributed under the License is distributed on an "AS IS" BASIS,
93200 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93201 * See the License for the specific language governing permissions and
93202 * limitations under the License.
93203 * =============================================================================
93204 */
/**
 * ClipByValue kernel: clamps every element of `x` into
 * [clipValueMin, clipValueMax].
 *
 * Selects the packed or unpacked clip program based on the
 * WEBGL_PACK_CLIP flag; the bounds travel as custom uniform values so the
 * shader need not be recompiled per clip range.
 */
function clipByValue(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { clipValueMin, clipValueMax } = attrs;
    const program = env().getBool('WEBGL_PACK_CLIP') ?
        new ClipPackedProgram(x.shape) :
        new ClipProgram(x.shape);
    // Order matches the programs' customUniforms: [minVal, maxVal].
    const customValues = [[clipValueMin], [clipValueMax]];
    return backend.runWebGLProgram(program, [x], x.dtype, customValues);
}
const clipByValueConfig = {
    kernelName: ClipByValue,
    backendName: 'webgl',
    kernelFunc: clipByValue
};
93224
93225 /**
93226 * @license
93227 * Copyright 2018 Google LLC. All Rights Reserved.
93228 * Licensed under the Apache License, Version 2.0 (the "License");
93229 * you may not use this file except in compliance with the License.
93230 * You may obtain a copy of the License at
93231 *
93232 * http://www.apache.org/licenses/LICENSE-2.0
93233 *
93234 * Unless required by applicable law or agreed to in writing, software
93235 * distributed under the License is distributed on an "AS IS" BASIS,
93236 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93237 * See the License for the specific language governing permissions and
93238 * limitations under the License.
93239 * =============================================================================
93240 */
/**
 * WebGL program computing the elementwise magnitude |a + bi| of a complex
 * tensor from separate 'real' and 'imag' inputs.
 *
 * Instead of length(vec2(re, im)) directly, the magnitude is factored as
 * mx * length(vec2(1, min/mx)) with mx = max(|re|, |im|), which avoids
 * underflow when squaring very small components (see in-shader comment).
 */
class ComplexAbsProgram {
    constructor(shape) {
        this.variableNames = ['real', 'imag'];
        this.outputShape = shape;
        this.userCode = `
      void main() {
        float re = abs(getRealAtOutCoords());
        float im = abs(getImagAtOutCoords());
        float mx = max(re, im);

        // sadly the length function in glsl is not underflow-safe
        // (at least not on Intel GPUs). So the safe solution is
        // to ensure underflow-safety in all cases.
        setOutput(
          mx == 0.0 ? 0.0 : mx * length(vec2(1, min(re, im)/mx))
        );
      }
    `;
    }
}
93261
93262 /**
93263 * @license
93264 * Copyright 2020 Google LLC. All Rights Reserved.
93265 * Licensed under the Apache License, Version 2.0 (the "License");
93266 * you may not use this file except in compliance with the License.
93267 * You may obtain a copy of the License at
93268 *
93269 * http://www.apache.org/licenses/LICENSE-2.0
93270 *
93271 * Unless required by applicable law or agreed to in writing, software
93272 * distributed under the License is distributed on an "AS IS" BASIS,
93273 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93274 * See the License for the specific language governing permissions and
93275 * limitations under the License.
93276 * =============================================================================
93277 */
// Builds a TensorInfo that carries the complex tensor's shape but the
// dataId/dtype of one of its underlying parts. Needed because reshaping a
// complex tensor is not reflected in the shapes of its stored parts.
function makeComplexComponentTensorInfo(complexTensor, complexPart) {
    const { dataId, dtype } = complexPart;
    return { dataId, dtype, shape: complexTensor.shape };
}
/**
 * ComplexAbs kernel: elementwise magnitude of a complex64 tensor.
 *
 * Fetches the stored real/imag parts from the backend's texture data,
 * re-labels them with the complex tensor's shape, and runs the
 * ComplexAbsProgram over the pair.
 */
function complexAbs(args) {
    const { inputs, backend } = args;
    const { x } = inputs;
    const { complexTensorInfos } = backend.texData.get(x.dataId);
    const programInputs = [
        makeComplexComponentTensorInfo(x, complexTensorInfos.real),
        makeComplexComponentTensorInfo(x, complexTensorInfos.imag),
    ];
    const program = new ComplexAbsProgram(x.shape);
    return backend.runWebGLProgram(program, programInputs, programInputs[0].dtype);
}
const complexAbsConfig = {
    kernelName: ComplexAbs,
    backendName: 'webgl',
    kernelFunc: complexAbs
};
93304
93305 /**
93306 * @license
93307 * Copyright 2017 Google LLC. All Rights Reserved.
93308 * Licensed under the Apache License, Version 2.0 (the "License");
93309 * you may not use this file except in compliance with the License.
93310 * You may obtain a copy of the License at
93311 *
93312 * http://www.apache.org/licenses/LICENSE-2.0
93313 *
93314 * Unless required by applicable law or agreed to in writing, software
93315 * distributed under the License is distributed on an "AS IS" BASIS,
93316 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93317 * See the License for the specific language governing permissions and
93318 * limitations under the License.
93319 * =============================================================================
93320 */
/**
 * WebGL program concatenating 2D tensors along axis = 1 (columns). Higher
 * rank concats are reduced to this case by the caller (see concatImpl /
 * computeTensors2D). One sampler T0..Tn-1 is bound per input, and a chain
 * of if/else-if statements routes each output column to the input that
 * covers it.
 */
class ConcatProgram {
    // Concats 2d tensors along axis=1. See comments in MathBackendWebGL.concat().
    constructor(shapes) {
        // Redundant initializer (TypeScript field-declaration artifact);
        // immediately overwritten below.
        this.outputShape = [];
        this.outputShape = computeOutShape$1(shapes, 1 /* axis */);
        this.variableNames = shapes.map((_, i) => `T${i}`);
        // offsets[i] is the exclusive upper bound of output columns covered
        // by inputs 0..i (cumulative widths).
        const offsets = new Array(shapes.length - 1);
        offsets[0] = shapes[0][1];
        for (let i = 1; i < offsets.length; i++) {
            offsets[i] = offsets[i - 1] + shapes[i][1];
        }
        const snippets = [`if (yC < ${offsets[0]}) setOutput(getT0(yR, yC));`];
        for (let i = 1; i < offsets.length; i++) {
            const shift = offsets[i - 1];
            snippets.push(`else if (yC < ${offsets[i]}) ` +
                `setOutput(getT${i}(yR, yC-${shift}));`);
        }
        // The final input is the fall-through case; no upper-bound test is
        // needed because yC is always within the output width.
        const lastIndex = offsets.length;
        const lastShift = offsets[offsets.length - 1];
        snippets.push(`else setOutput(getT${lastIndex}(yR, yC-${lastShift}));`);
        this.userCode = `
      void main() {
        ivec2 coords = getOutputCoords();
        int yR = coords.x;
        int yC = coords.y;

        ${snippets.join('\n        ')}
      }
    `;
    }
}
93352
93353 /**
93354 * @license
93355 * Copyright 2019 Google LLC. All Rights Reserved.
93356 * Licensed under the Apache License, Version 2.0 (the "License");
93357 * you may not use this file except in compliance with the License.
93358 * You may obtain a copy of the License at
93359 *
93360 * http://www.apache.org/licenses/LICENSE-2.0
93361 *
93362 * Unless required by applicable law or agreed to in writing, software
93363 * distributed under the License is distributed on an "AS IS" BASIS,
93364 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93365 * See the License for the specific language governing permissions and
93366 * limitations under the License.
93367 * =============================================================================
93368 */
/**
 * Packed WebGL concat program: concatenates any number of tensors along
 * `axis`, reading and writing 2x2 packed texels.
 *
 * A scalar helper getValue(...) is generated that, for one logical
 * coordinate, selects which input covers it (via cumulative offsets along
 * the concat axis) and fetches the correct channel. main() invokes it up
 * to four times to fill the output texel's r/g/a/b lanes.
 *
 * Assumes rank >= 2 (uses coords[rank-2]); callers gate on shape.length > 1.
 */
class ConcatPackedProgram {
    constructor(shapes, axis) {
        this.packedInputs = true;
        this.packedOutput = true;
        // Redundant initializer (TypeScript field-declaration artifact).
        this.outputShape = [];
        this.outputShape = computeOutShape$1(shapes, axis);
        const shape = this.outputShape;
        const rank = shape.length;
        const dtype = getCoordsDataType(rank);
        const coords = getChannels('coords', rank);
        // Per-dimension channel names; supports up to rank 6.
        const channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank);
        // One sampler per input: T0, T1, ...
        this.variableNames = shapes.map((_, i) => `T${i}`);
        // offsets[i]: exclusive upper bound along `axis` covered by inputs 0..i.
        const offsets = new Array(shapes.length - 1);
        offsets[0] = shapes[0][axis];
        for (let i = 1; i < offsets.length; i++) {
            offsets[i] = offsets[i - 1] + shapes[i][axis];
        }
        const channel = channels[axis];
        // The last two channels address within a packed 2x2 texel.
        const lastChannels = channels.slice(-2);
        const allChannels = channels.join();
        let getValueSnippet = `if (${channel} < ${offsets[0]}) {
        return getChannel(
            getT0(${allChannels}), vec2(${lastChannels.join()}));
        }`;
        for (let i = 1; i < offsets.length; i++) {
            const shift = offsets[i - 1];
            // Note: the >= comparison below may seem unnecessary given the check
            // above but is needed to workaround branch execution issues on some
            // devices. It makes all the conditions exclusive without relying on
            // execution order.
            getValueSnippet += `
        if (${channel} < ${offsets[i]} && ${channel} >= ${offsets[i - 1]}) {
          return getChannel(
            getT${i}(${shiftedChannels(channels, channel, shift)}),
            vec2(${shiftedChannels(lastChannels, channel, shift)}));
        }`;
        }
        // Fall-through: the coordinate belongs to the last input.
        const lastIndex = offsets.length;
        const shift = offsets[offsets.length - 1];
        getValueSnippet += `
        return getChannel(
          getT${lastIndex}(${shiftedChannels(channels, channel, shift)}),
          vec2(${shiftedChannels(lastChannels, channel, shift)}));`;
        // `${channels.map(...)}` relies on Array#toString producing a
        // comma-separated parameter list, e.g. "int x,int y".
        //
        // NOTE(review): the guard before `result.a` does not re-check
        // coords[rank-1] (still incremented at that point), unlike the
        // `result.b` guard — the out-of-range lane is texel padding and
        // presumably never read back; verify against the packing layout.
        this.userCode = `
      float getValue(${channels.map(x => 'int ' + x)}) {
        ${getValueSnippet}
      }

      void main() {
        ${dtype} coords = getOutputCoords();
        vec4 result = vec4(getValue(${coords}), 0., 0., 0.);

        ${coords[rank - 1]} = ${coords[rank - 1]} + 1;
        if (${coords[rank - 1]} < ${shape[rank - 1]}) {
          result.g = getValue(${coords});
        }

        ${coords[rank - 2]} = ${coords[rank - 2]} + 1;
        if (${coords[rank - 2]} < ${shape[rank - 2]}) {
          result.a = getValue(${coords});
        }

        ${coords[rank - 1]} = ${coords[rank - 1]} - 1;
        if (${coords[rank - 2]} < ${shape[rank - 2]} &&
            ${coords[rank - 1]} < ${shape[rank - 1]}) {
          result.b = getValue(${coords});
        }
        setOutput(result);
      }
    `;
    }
}
/**
 * Builds a comma-separated GLSL argument list in which exactly one channel
 * is offset by `shift`.
 *
 * @param channels the channel names to emit.
 * @param channel the single channel to shift; if it is not present in
 *     `channels`, nothing is shifted.
 * @param shift the amount subtracted from that channel.
 *
 * @returns a string such as 'x,y - 2,z' (channels ['x','y','z'],
 *     channel 'y', shift 2).
 */
function shiftedChannels(channels, channel, shift) {
    const target = channels.indexOf(channel);
    const parts = channels.map((c, idx) => idx === target ? `${c} - ${shift}` : c);
    return parts.join();
}
93464
93465 /**
93466 * @license
93467 * Copyright 2020 Google LLC. All Rights Reserved.
93468 * Licensed under the Apache License, Version 2.0 (the "License");
93469 * you may not use this file except in compliance with the License.
93470 * You may obtain a copy of the License at
93471 *
93472 * http://www.apache.org/licenses/LICENSE-2.0
93473 *
93474 * Unless required by applicable law or agreed to in writing, software
93475 * distributed under the License is distributed on an "AS IS" BASIS,
93476 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93477 * See the License for the specific language governing permissions and
93478 * limitations under the License.
93479 * =============================================================================
93480 */
/**
 * WebGL kernel extracting the imaginary component of a complex64 tensor.
 *
 * Mirrors the `real` kernel: the imaginary part already lives as its own
 * tensor info on the backend's texture data, so an identity over that part
 * suffices.
 */
function imag(args) {
    const { backend } = args;
    const { input } = args.inputs;
    const texData = backend.texData.get(input.dataId);
    return identity({ inputs: { x: texData.complexTensorInfos.imag }, backend });
}
const imagConfig = {
    kernelName: Imag,
    backendName: 'webgl',
    kernelFunc: imag
};
93492
93493 /**
93494 * @license
93495 * Copyright 2020 Google LLC. All Rights Reserved.
93496 * Licensed under the Apache License, Version 2.0 (the "License");
93497 * you may not use this file except in compliance with the License.
93498 * You may obtain a copy of the License at
93499 *
93500 * http://www.apache.org/licenses/LICENSE-2.0
93501 *
93502 * Unless required by applicable law or agreed to in writing, software
93503 * distributed under the License is distributed on an "AS IS" BASIS,
93504 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93505 * See the License for the specific language governing permissions and
93506 * limitations under the License.
93507 * =============================================================================
93508 */
/**
 * Concatenates `inputs` along `axis` on the WebGL backend.
 *
 * Strategy, in order:
 *  - complex64: concat the real and imaginary parts separately, recombine;
 *  - string dtype, or the backend prefers CPU: reshape to 2D and run the
 *    CPU implementation;
 *  - a single non-empty input: clone it;
 *  - more inputs than one shader can sample: concat in chunks, then concat
 *    the chunk results;
 *  - otherwise run a packed or unpacked concat program on the GPU.
 *
 * All intermediate tensors created here are disposed before returning.
 */
function concatImpl(inputs, axis, backend) {
    const dtype = inputs[0].dtype;
    if (dtype === 'complex64') {
        // Concat the real and imaginary components independently, then
        // recombine them into a complex tensor.
        const reals = inputs.map((t) => real({ inputs: { input: t }, backend }));
        const imags = inputs.map((t) => imag({ inputs: { input: t }, backend }));
        const realConcated = concatImpl(reals, axis, backend);
        const imagConcated = concatImpl(imags, axis, backend);
        const result = complex({ inputs: { real: realConcated, imag: imagConcated }, backend });
        reals.forEach(r => backend.disposeIntermediateTensorInfo(r));
        imags.forEach(i => backend.disposeIntermediateTensorInfo(i));
        backend.disposeIntermediateTensorInfo(realConcated);
        backend.disposeIntermediateTensorInfo(imagConcated);
        return result;
    }
    let runOnCpu = backend.shouldExecuteOnCPU(inputs);
    // Run on cpu if dtype is string. For string, the backend represents it
    // as Uint8Array[], where each Uint8Array is a character. Given that the
    // computation is only on the outer array, uploading the whole data onto
    // gpu is wasteful. Also, currently webgl doesn't have a design to
    // upload and retrieve Uint8Array[] between cpu and gpu. Therefore, we
    // just run the kernel on cpu if dtype is string.
    if (dtype === 'string') {
        runOnCpu = true;
    }
    if (runOnCpu) {
        // Any concat of n-dimensional tensors across any axis can be reduced
        // to a concatenation of two-dimensional tensors across the axis 1 by
        // first partitioning the axes of the original tensors into those
        // less than the axis to be concatenated and the rest. Then reshape
        // the tensors into a two-dimensional tensor by collapsing these two
        // sets of axes and concatenate the resulting matrices across the
        // axis 1, finally reshaping the result to have the proper shape.
        const tensors2D = inputs.map(t => {
            const innerSize = sizeFromShape(t.shape.slice(axis));
            const shape = [-1, innerSize];
            return reshape({ inputs: { x: t }, backend, attrs: { shape } });
        });
        const inputsValShapes = tensors2D.map(t => {
            return { vals: backend.readSync(t.dataId), shape: t.shape };
        });
        // Concats 2d tensors along axis=1.
        const outShape = computeOutShape$1(tensors2D.map(t => t.shape), 1 /* axis */);
        const simplyConcat = tensors2D[0].shape[0] === 1;
        const outVals = concatImplCPU(inputsValShapes, outShape, dtype, simplyConcat);
        const finalOutShape = computeOutShape$1(inputs.map(t => t.shape), axis);
        const outInfo = backend.makeTensorInfo(finalOutShape, dtype, outVals);
        tensors2D.forEach(t => backend.disposeIntermediateTensorInfo(t));
        return outInfo;
    }
    // Keep only non-empty tensors (ignore tensors with 0 in their shape).
    const $inputs = inputs.filter(t => sizeFromShape(t.shape) > 0);
    const shouldPack = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') &&
        $inputs[0].shape.length > 1;
    if ($inputs.length === 1) {
        // Clone the single non-empty tensor.
        // Fix: the two program constructors were previously swapped — the
        // unpacked UnaryOpProgram was chosen when `shouldPack` was true and
        // the packed one when it was false. Results were still correct
        // (runWebGLProgram packs/unpacks as needed) but the intended fast
        // path was missed. Also clone the filtered $inputs rather than the
        // raw inputs, matching the tensor selected by the check above.
        const program = shouldPack ?
            new UnaryOpPackedProgram($inputs[0].shape, CLONE) :
            new UnaryOpProgram($inputs[0].shape, CLONE);
        return backend.runWebGLProgram(program, $inputs, dtype);
    }
    const maxTexturesInShader = env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER');
    if ($inputs.length > maxTexturesInShader) {
        // Too many inputs to sample in one shader: concat fixed-size chunks
        // first, then concat the intermediate results.
        const reducedInputs = [];
        for (let i = 0; i < $inputs.length; i += maxTexturesInShader) {
            const subArray = $inputs.slice(i, i + maxTexturesInShader);
            reducedInputs.push(concatImpl(subArray, axis, backend));
        }
        const result = concatImpl(reducedInputs, axis, backend);
        for (const i of reducedInputs) {
            backend.disposeIntermediateTensorInfo(i);
        }
        return result;
    }
    if (shouldPack) {
        const program = new ConcatPackedProgram($inputs.map(t => t.shape), axis);
        return backend.runWebGLProgram(program, $inputs, dtype);
    }
    // Unpacked path: collapse to 2D, concat along axis 1, reshape back.
    const { tensors2D, outShape } = computeTensors2D($inputs, axis, backend);
    const program = new ConcatProgram(tensors2D.map(t => t.shape));
    const result = backend.runWebGLProgram(program, tensors2D, dtype);
    tensors2D.forEach(r => backend.disposeIntermediateTensorInfo(r));
    const reshapedResult = reshape({ inputs: { x: result }, attrs: { shape: outShape }, backend });
    backend.disposeIntermediateTensorInfo(result);
    return reshapedResult;
}
/**
 * Collapses each input to 2D for a column-wise concat: dimensions before
 * `axis` fold into rows (-1) and dimensions from `axis` onward fold into
 * columns. Also returns the final (uncollapsed) output shape so the caller
 * can reshape the 2D concat result back.
 */
function computeTensors2D(inputs, axis, backend) {
    const outShape = computeOutShape$1(inputs.map(t => t.shape), axis);
    const tensors2D = inputs.map(t => {
        const numCols = sizeFromShape(t.shape.slice(axis));
        return reshape({
            inputs: { x: t },
            attrs: { shape: [-1, numCols] },
            backend
        });
    });
    return { tensors2D, outShape };
}
93610
93611 /**
93612 * @license
93613 * Copyright 2020 Google LLC. All Rights Reserved.
93614 * Licensed under the Apache License, Version 2.0 (the "License");
93615 * you may not use this file except in compliance with the License.
93616 * You may obtain a copy of the License at
93617 *
93618 * http://www.apache.org/licenses/LICENSE-2.0
93619 *
93620 * Unless required by applicable law or agreed to in writing, software
93621 * distributed under the License is distributed on an "AS IS" BASIS,
93622 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93623 * See the License for the specific language governing permissions and
93624 * limitations under the License.
93625 * =============================================================================
93626 */
/**
 * Concat kernel entry point for the WebGL backend.
 *
 * Normalizes the axis, validates that all input shapes agree outside the
 * concat axis, short-circuits empty outputs and single-tensor concats, and
 * otherwise delegates to concatImpl with the non-empty inputs.
 */
function concat(args) {
    const { inputs, backend, attrs } = args;
    const { axis } = attrs;
    const $axis = parseAxisParam(axis, inputs[0].shape)[0];
    const shapes = inputs.map(t => t.shape);
    assertParamsConsistent(shapes, $axis);
    const outShape = computeOutShape$1(shapes, $axis);
    if (sizeFromShape(outShape) === 0) {
        // Nothing to compute; hand back an empty tensor of the right shape.
        return backend.makeTensorInfo(outShape, inputs[0].dtype, []);
    }
    // Tensors with a 0 in their shape contribute nothing to the output.
    const nonEmpty = inputs.filter(t => sizeFromShape(t.shape) > 0);
    if (nonEmpty.length === 1) {
        return identity({ inputs: { x: nonEmpty[0] }, backend });
    }
    return concatImpl(nonEmpty, $axis, backend);
}
const concatConfig = {
    kernelName: Concat,
    backendName: 'webgl',
    kernelFunc: concat
};
93649
93650 /**
93651 * @license
93652 * Copyright 2017 Google LLC. All Rights Reserved.
93653 * Licensed under the Apache License, Version 2.0 (the "License");
93654 * you may not use this file except in compliance with the License.
93655 * You may obtain a copy of the License at
93656 *
93657 * http://www.apache.org/licenses/LICENSE-2.0
93658 *
93659 * Unless required by applicable law or agreed to in writing, software
93660 * distributed under the License is distributed on an "AS IS" BASIS,
93661 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93662 * See the License for the specific language governing permissions and
93663 * limitations under the License.
93664 * =============================================================================
93665 */
/**
 * WebGL program for 2D convolution over NHWC ('channelsLast') or NCHW
 * input, with optional fused bias add and fused activation (plain, PRELU
 * via per-element weights, or LEAKYRELU via a per-element alpha).
 *
 * The channel reduction is manually vectorized: full groups of 4 input
 * channels are accumulated with vec4 dot products, and a remainder of 1-3
 * channels is handled by dedicated scalar/vec2/vec3 tails. Template
 * expressions such as `${isChannelsLast}` bake to compile-time constants
 * in the generated GLSL, so those branches are resolvable by the compiler.
 */
class Conv2DProgram {
    constructor(convInfo, addBias = false, activation = null, hasPreluActivationWeights = false, hasLeakyreluAlpha = false) {
        this.variableNames = ['x', 'W'];
        this.outputShape = convInfo.outShape;
        const padTop = convInfo.padInfo.top;
        const padLeft = convInfo.padInfo.left;
        const strideHeight = convInfo.strideHeight;
        const strideWidth = convInfo.strideWidth;
        const dilationHeight = convInfo.dilationHeight;
        const dilationWidth = convInfo.dilationWidth;
        const filterHeight = convInfo.filterHeight;
        const filterWidth = convInfo.filterWidth;
        // Largest multiple of 4 <= inChannels; the vec4 loop covers this
        // many channels, the remainder (0-3) is handled by the tails below.
        const inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
        const inputDepthVec4Remainder = convInfo.inChannels % 4;
        // Coordinate positions of row/col/channel differ between NHWC
        // (batch, row, col, channel) and NCHW (batch, channel, row, col).
        const isChannelsLast = convInfo.dataFormat === 'channelsLast';
        const rowDim = isChannelsLast ? 1 : 2;
        const colDim = isChannelsLast ? 2 : 3;
        const channelDim = isChannelsLast ? 3 : 1;
        let activationSnippet = '', applyActivationSnippet = '';
        if (activation) {
            // `activation` is a GLSL statement body operating on `a`; the
            // PRELU/LEAKYRELU variants additionally bind `b` to their
            // per-element parameter fetched at the output coordinates.
            if (hasPreluActivationWeights) {
                activationSnippet = `float activation(float a) {
          float b = getPreluActivationWeightsAtOutCoords();
          ${activation}
        }`;
            }
            else if (hasLeakyreluAlpha) {
                activationSnippet = `float activation(float a) {
          float b = getLeakyreluAlphaAtOutCoords();
          ${activation}
        }`;
            }
            else {
                activationSnippet = `
          float activation(float x) {
            ${activation}
          }
        `;
            }
            applyActivationSnippet = `result = activation(result);`;
        }
        const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
        // Extra samplers are appended only when the corresponding fused
        // operand is present; order must match runWebGLProgram's inputs.
        if (addBias) {
            this.variableNames.push('bias');
        }
        if (hasPreluActivationWeights) {
            this.variableNames.push('preluActivationWeights');
        }
        if (hasLeakyreluAlpha) {
            this.variableNames.push('leakyreluAlpha');
        }
        this.userCode = `
      ${activationSnippet}

      const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
      const ivec2 pads = ivec2(${padTop}, ${padLeft});

      void main() {
        ivec4 coords = getOutputCoords();
        int batch = coords[0];
        int d2 = coords[${channelDim}];

        ivec2 xRCCorner =
            ivec2(coords[${rowDim}], coords[${colDim}]) * strides - pads;
        int xRCorner = xRCCorner.x;
        int xCCorner = xRCCorner.y;

        // Convolve x(?, ?, d1) with w(:, :, d1, d2) to get y(yR, yC, d2).
        // ? = to be determined. : = across all values in that axis.
        float dotProd = 0.0;
        for (int wR = 0; wR < ${filterHeight}; wR++) {
          int xR = xRCorner + wR * ${dilationHeight};

          if (xR < 0 || xR >= ${convInfo.inHeight}) {
            continue;
          }

          for (int wC = 0; wC < ${filterWidth}; wC++) {
            int xC = xCCorner + wC * ${dilationWidth};

            if (xC < 0 || xC >= ${convInfo.inWidth}) {
              continue;
            }

            for (int d1 = 0; d1 < ${inputDepthNearestVec4}; d1 += 4) {
              vec4 wValues = vec4(
                getW(wR, wC, d1, d2),
                getW(wR, wC, d1 + 1, d2),
                getW(wR, wC, d1 + 2, d2),
                getW(wR, wC, d1 + 3, d2)
              );

              if (${isChannelsLast}) {
                vec4 xValues = vec4(
                  getX(batch, xR, xC, d1),
                  getX(batch, xR, xC, d1 + 1),
                  getX(batch, xR, xC, d1 + 2),
                  getX(batch, xR, xC, d1 + 3)
                );
                dotProd += dot(xValues, wValues);
              } else {
                vec4 xValues = vec4(
                  getX(batch, d1, xR, xC),
                  getX(batch, d1 + 1, xR, xC),
                  getX(batch, d1 + 2, xR, xC),
                  getX(batch, d1 + 3, xR, xC)
                );
                dotProd += dot(xValues, wValues);
              }
            }

            if (${inputDepthVec4Remainder === 1}) {

              if (${isChannelsLast}) {
                dotProd +=
                    getX(batch, xR, xC, ${inputDepthNearestVec4}) *
                    getW(wR, wC, ${inputDepthNearestVec4}, d2);
              } else {
                dotProd +=
                    getX(batch, ${inputDepthNearestVec4}, xR, xC) *
                    getW(wR, wC, ${inputDepthNearestVec4}, d2);
              }

            } else if (${inputDepthVec4Remainder === 2}) {
              vec2 wValues = vec2(
                getW(wR, wC, ${inputDepthNearestVec4}, d2),
                getW(wR, wC, ${inputDepthNearestVec4} + 1, d2)
              );

              if (${isChannelsLast}) {
                vec2 xValues = vec2(
                  getX(batch, xR, xC, ${inputDepthNearestVec4}),
                  getX(batch, xR, xC, ${inputDepthNearestVec4} + 1)
                );
                dotProd += dot(xValues, wValues);
              } else {
                vec2 xValues = vec2(
                  getX(batch, ${inputDepthNearestVec4}, xR, xC),
                  getX(batch, ${inputDepthNearestVec4} + 1, xR, xC)
                );
                dotProd += dot(xValues, wValues);
              }

            } else if (${inputDepthVec4Remainder === 3}) {
              vec3 wValues = vec3(
                getW(wR, wC, ${inputDepthNearestVec4}, d2),
                getW(wR, wC, ${inputDepthNearestVec4} + 1, d2),
                getW(wR, wC, ${inputDepthNearestVec4} + 2, d2)
              );

              if (${isChannelsLast}) {
                vec3 xValues = vec3(
                  getX(batch, xR, xC, ${inputDepthNearestVec4}),
                  getX(batch, xR, xC, ${inputDepthNearestVec4} + 1),
                  getX(batch, xR, xC, ${inputDepthNearestVec4} + 2)
                );
                dotProd += dot(xValues, wValues);
              } else {
                vec3 xValues = vec3(
                  getX(batch, ${inputDepthNearestVec4}, xR, xC),
                  getX(batch, ${inputDepthNearestVec4} + 1, xR, xC),
                  getX(batch, ${inputDepthNearestVec4} + 2, xR, xC)
                );
                dotProd += dot(xValues, wValues);
              }

            }
          }
        }

        float result = dotProd;
        ${addBiasSnippet}
        ${applyActivationSnippet}
        setOutput(result);
      }
    `;
    }
}
    /**
     * WebGL shader program that computes an unpacked 3D convolution.
     *
     * The input `x` is read as (batch, depth, row, col, channel) and the
     * filter `W` as (filterDepth, filterRow, filterCol, inChannel, outChannel),
     * as shown by the getX/getW calls in the generated shader, i.e. the NDHWC
     * layout. Each shader invocation produces one output value by accumulating
     * dot products over the filter volume, with the channel loop vectorized in
     * groups of 4 (vec4 dot products) and a 1-3 channel remainder handled by
     * dedicated branches.
     */
    class Conv3DProgram {
        /**
         * @param convInfo 3D convolution descriptor (out shape, pads, strides,
         *     dilations, filter dims, in/out channels) produced by op-level
         *     shape inference.
         */
        constructor(convInfo) {
            this.variableNames = ['x', 'W'];
            this.outputShape = convInfo.outShape;
            const padFront = convInfo.padInfo.front;
            const padTop = convInfo.padInfo.top;
            const padLeft = convInfo.padInfo.left;
            const strideDepth = convInfo.strideDepth;
            const strideHeight = convInfo.strideHeight;
            const strideWidth = convInfo.strideWidth;
            const dilationDepth = convInfo.dilationDepth;
            const dilationHeight = convInfo.dilationHeight;
            const dilationWidth = convInfo.dilationWidth;
            const filterDepth = convInfo.filterDepth;
            const filterHeight = convInfo.filterHeight;
            const filterWidth = convInfo.filterWidth;
            // Channel count rounded down to a multiple of 4: the main channel
            // loop processes channels in vec4 groups up to this bound.
            const inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
            // 0-3 leftover channels. `${inputDepthVec4Remainder === N}` below
            // interpolates a JS boolean, so each remainder branch becomes a
            // compile-time `if (true)`/`if (false)` and dead branches can be
            // eliminated by the GLSL compiler.
            const inputDepthVec4Remainder = convInfo.inChannels % 4;
            this.userCode = `
      const ivec3 strides = ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});
      const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});

      void main() {
        ivec5 coords = getOutputCoords();
        int batch = coords.x;
        int d2 = coords.u;

        ivec3 xFRCCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;
        int xFCorner = xFRCCorner.x;
        int xRCorner = xFRCCorner.y;
        int xCCorner = xFRCCorner.z;

        // Convolve x(?, ?, ?, d1) with w(:, :, :, d1, d2) to get
        // y(yF, yR, yC, d2). ? = to be determined. : = across all
        // values in that axis.
        float dotProd = 0.0;
        for (int wF = 0; wF < ${filterDepth}; wF++) {
          int xF = xFCorner + wF * ${dilationDepth};

          if (xF < 0 || xF >= ${convInfo.inDepth}) {
            continue;
          }

          for (int wR = 0; wR < ${filterHeight}; wR++) {
            int xR = xRCorner + wR * ${dilationHeight};

            if (xR < 0 || xR >= ${convInfo.inHeight}) {
              continue;
            }

            for (int wC = 0; wC < ${filterWidth}; wC++) {
              int xC = xCCorner + wC * ${dilationWidth};

              if (xC < 0 || xC >= ${convInfo.inWidth}) {
                continue;
              }

              for (int d1 = 0; d1 < ${inputDepthNearestVec4}; d1 += 4) {
                vec4 xValues = vec4(
                  getX(batch, xF, xR, xC, d1),
                  getX(batch, xF, xR, xC, d1 + 1),
                  getX(batch, xF, xR, xC, d1 + 2),
                  getX(batch, xF, xR, xC, d1 + 3)
                );
                vec4 wValues = vec4(
                  getW(wF, wR, wC, d1, d2),
                  getW(wF, wR, wC, d1 + 1, d2),
                  getW(wF, wR, wC, d1 + 2, d2),
                  getW(wF, wR, wC, d1 + 3, d2)
                );

                dotProd += dot(xValues, wValues);
              }

              if (${inputDepthVec4Remainder === 1}) {
                dotProd +=
                  getX(batch, xF, xR, xC, ${inputDepthNearestVec4}) *
                  getW(wF, wR, wC, ${inputDepthNearestVec4}, d2);
              } else if (${inputDepthVec4Remainder === 2}) {
                vec2 xValues = vec2(
                  getX(batch, xF, xR, xC, ${inputDepthNearestVec4}),
                  getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 1)
                );
                vec2 wValues = vec2(
                  getW(wF, wR, wC, ${inputDepthNearestVec4}, d2),
                  getW(wF, wR, wC, ${inputDepthNearestVec4} + 1, d2)
                );
                dotProd += dot(xValues, wValues);
              } else if (${inputDepthVec4Remainder === 3}) {
                vec3 xValues = vec3(
                  getX(batch, xF, xR, xC, ${inputDepthNearestVec4}),
                  getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 1),
                  getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 2)
                );
                vec3 wValues = vec3(
                  getW(wF, wR, wC, ${inputDepthNearestVec4}, d2),
                  getW(wF, wR, wC, ${inputDepthNearestVec4} + 1, d2),
                  getW(wF, wR, wC, ${inputDepthNearestVec4} + 2, d2)
                );
                dotProd += dot(xValues, wValues);
              }
            }
          }
        }
        setOutput(dotProd);
      }
    `;
        }
    }
93953
93954 /**
93955 * @license
93956 * Copyright 2022 Google LLC. All Rights Reserved.
93957 * Licensed under the Apache License, Version 2.0 (the "License");
93958 * you may not use this file except in compliance with the License.
93959 * You may obtain a copy of the License at
93960 *
93961 * http://www.apache.org/licenses/LICENSE-2.0
93962 *
93963 * Unless required by applicable law or agreed to in writing, software
93964 * distributed under the License is distributed on an "AS IS" BASIS,
93965 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93966 * See the License for the specific language governing permissions and
93967 * limitations under the License.
93968 * =============================================================================
93969 */
    /**
     * WebGL shader program that computes a 2D convolution on *packed*
     * textures, where one texel carries a 2x2 block of values (two spatial
     * columns x two channels, as read via getX(batch, xR, xC, d1)).
     *
     * The constructor is a code generator: it unrolls the filter columns and
     * emits GLSL that assembles, for every filter column c, a vec4 `xC{c}` of
     * input values with as few texture fetches as possible. Texels that were
     * already fetched are memoized via the `xTexelC{n}Ready` flags. The
     * generated code is specialized on compile-time-known quantities
     * (filterWidth, filterHeight, strideWidth, dilationWidth, padLeft parity);
     * pads/strides/dilations/input dims are still passed as runtime uniforms.
     *
     * Optionally fuses bias addition and an activation (plain, PReLU, or
     * leaky-ReLU) into the same shader.
     */
    class Conv2DPackedProgram {
        /**
         * @param convInfo Conv2D descriptor (shapes, pads, strides, dilations).
         * @param addBias Whether to fuse `result += bias` (adds a 'bias' input).
         * @param activation GLSL snippet for the fused activation, or null.
         * @param hasPreluActivation Whether the activation reads PReLU weights.
         * @param hasLeakyReluAlpha Whether the activation reads a leaky-ReLU
         *     alpha tensor.
         */
        constructor(convInfo, addBias = false, activation = null, hasPreluActivation = false, hasLeakyReluAlpha = false) {
            this.variableNames = ['x', 'W'];
            this.packedInputs = true;
            this.packedOutput = true;
            // Runtime uniforms: padding, strides, dilations and input H/W are
            // provided per call instead of being baked into the shader source.
            this.customUniforms = [
                { name: 'pads', type: 'ivec2' },
                { name: 'strides', type: 'ivec2' },
                { name: 'dilations', type: 'ivec2' },
                { name: 'inDims', type: 'ivec2' },
            ];
            this.outputShape = convInfo.outShape;
            this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
            const padLeft = convInfo.padInfo.left;
            const strideWidth = convInfo.strideWidth;
            const dilationWidth = convInfo.dilationWidth;
            const filterHeight = convInfo.filterHeight;
            const filterWidth = convInfo.filterWidth;
            const texelsAcross = filterWidth;
            // Scratch declarations, plus per filter column c: the two source
            // texels xTexelC{2c}/xTexelC{2c+1}, their fetched-flags, and the
            // assembled vec4 xC{c}.
            let mainLoop = `
       int xR; int xC; int xCOffset;
       vec4 wTexel; vec4 previous; vec4 final;`;
            for (let c = 0; c < filterWidth; c++) {
                mainLoop += `
           vec4 xTexelC${c * 2};
           int xTexelC${c * 2}Ready;
           vec4 xTexelC${c * 2 + 1};
           int xTexelC${c * 2 + 1}Ready;
           vec4 xC${c};`;
            }
            /**
             * This vectorized implementation works by gathering the values needed for
             * each output channel's dot product into vec4's and then multiplying them
             * all together (this happens in the final double for-loop below). Most of
             * the main loop consists of constructing these vec4's with the minimum
             * number of texture2D calls, which means making use of all four returned
             * values from a texture2D call at once.
             */
            mainLoop += `
       for (int r = 0; r < ${filterHeight}; r++) {
        for (int d1 = 0; d1 < ${convInfo.inChannels}; d1 += 2) {
       `;
            // Reset texels and flags at the top of every (r, d1) iteration.
            for (let c = 0; c < filterWidth; c++) {
                mainLoop += `
           xTexelC${c * 2} = vec4(0.0);
           xTexelC${c * 2}Ready = 0;
           xTexelC${c * 2 + 1} = vec4(0.0);
           xTexelC${c * 2 + 1}Ready = 0;
           xC${c} = vec4(0.0);`;
            }
            mainLoop += `
         xR = xRCorner + r * dilations[0];
         if (xR >=0 && xR < inDims[0]) {
       `;
            // Each texelC covers two filter columns (a packed texel holds two
            // adjacent input columns), hence colIndex = texelC * 2.
            for (let texelC = 0; texelC < (texelsAcross + 1) / 2; texelC++) {
                const colIndex = texelC * 2;
                mainLoop += `
           xC = xCCorner + ${colIndex * dilationWidth};
           `;
                if (strideWidth === 1) {
                    if (colIndex < filterWidth) {
                        // If padding is odd, the outer texels have to be composed.
                        if (padLeft % 2 === 1) {
                            // TODO: Ensure vec4 previous does not result in redundant sample,
                            // and avoid setting xTexelRC's that exceed the boundary in the
                            // first place rather than resetting them to vec4(0)).
                            // To compute xCOffset:
                            // - If padding is odd, we must add 1 to ensure we ask for an
                            //   even-numbered row.
                            // - We subtract 2 to access the previous texel.
                            mainLoop += `
             xCOffset = xC + 1;
             if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex}Ready == 0) {
               xTexelC${colIndex} = getX(batch, xR, xCOffset, d1);

               // Need to manually clear unused channels in case
               // we're reading from recycled texture.
               if (xCOffset + 1 >= inDims[1]) {
                 xTexelC${colIndex}.zw = vec2(0.0);
               }
               xTexelC${colIndex}Ready = 1;
             }
             `;
                            // This texel has been read in previous iteration if the dilation
                            // is 1.
                            if (dilationWidth === 1 && colIndex > 0) {
                                mainLoop += `
               xC${colIndex} = vec4(xTexelC${colIndex - 2}.zw, xTexelC${colIndex}.xy);
               `;
                            }
                            else {
                                mainLoop += `
                 xCOffset = xC + 1 - 2;

                 if (xCOffset >= 0 && xCOffset < inDims[1]) {
                   previous = getX(batch, xR, xCOffset, d1);

                   // Need to manually clear unused channels in case
                   // we're reading from recycled texture.
                   if (xCOffset + 1 >= inDims[1]) {
                     previous.zw = vec2(0.0);
                   }

                   xC${colIndex} = vec4(previous.zw, xTexelC${colIndex}.xy);
                 } else {
                   xC${colIndex} = vec4(0.0, 0.0, xTexelC${colIndex}.xy);
                 }
               `;
                            }
                        }
                        else {
                            // Padding is even, so xRC corresponds to a single texel.
                            mainLoop += `
               if (xC >= 0 && xC < inDims[1] && xTexelC${colIndex}Ready == 0) {
                 xTexelC${colIndex} = getX(batch, xR, xC, d1);
                 if (xC + 1 >= inDims[1]) {
                   xTexelC${colIndex}.zw = vec2(0.0);
                 }
                 xTexelC${colIndex}Ready = 1;
               }

               xC${colIndex} = xTexelC${colIndex};
               `;
                        }
                        if (colIndex + 1 < filterWidth) {
                            // If dilation is even, the second entry should match the first
                            // (either both are composed or both are single samples). But if
                            // dilation is odd, then the second entry should be the opposite
                            // of the first (if the first is composed, the second is a single
                            // sample, and vice versa.)
                            const nextTexelOffset = padLeft % 2 === 0 ?
                                nearestLargerEven(dilationWidth) :
                                dilationWidth;
                            if ((dilationWidth % 2 === 0 && padLeft % 2 === 1) ||
                                (dilationWidth % 2 !== 0 && padLeft % 2 !== 1)) {
                                mainLoop += `
                 xCOffset = xC + imod(pads[1], 2) + ${nextTexelOffset};

                 if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
                   xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);

                   // Need to manually clear unused channels in case
                   // we're reading from recycled texture.
                   if (xCOffset + 1 >= inDims[1]) {
                     xTexelC${colIndex + 1}.zw = vec2(0.0);
                   }
                   xTexelC${colIndex + 1}Ready = 1;
                 }
                 `;
                                // If dilation > 1 then the xRC's will not be able to share any
                                // values, so each xRC will require two unique calls to getX.
                                if (dilationWidth > 1) {
                                    mainLoop += `
                   xCOffset -= 2;
                   if (xCOffset >= 0 && xCOffset < inDims[1]) {
                     previous = getX(batch, xR, xCOffset, d1);
                     xC${colIndex + 1} = vec4(previous.zw, xTexelC${colIndex + 1}.xy);
                   } else {
                     xC${colIndex + 1} = vec4(0.0, 0.0, xTexelC${colIndex + 1}.xy);
                   }
                   `;
                                }
                                else {
                                    mainLoop += `
                   xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.xy);
                   `;
                                }
                            }
                            else {
                                // If dilation is 1 and padding is odd, we have already read the
                                // texel when constructing the previous x value. Here we can
                                // simply skip the texture read.
                                if (nextTexelOffset === 1) {
                                    mainLoop += `
                   xC${colIndex + 1} = xTexelC${colIndex};
                   `;
                                }
                                else {
                                    mainLoop += `
                   xCOffset = xC + ${nextTexelOffset};

                   if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
                     xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);
                     if (xCOffset + 1 >= inDims[1]) {
                       xTexelC${colIndex + 1}.zw = vec2(0.0);
                     }
                     xTexelC${colIndex + 1}Ready = 1;
                   }

                   xC${colIndex + 1} = xTexelC${colIndex + 1};
                   `;
                                }
                            }
                        }
                    }
                }
                else { // stride === 2
                    if (colIndex < filterWidth) {
                        // Depending on whether padLeft is even or odd, we want either the
                        // xy or zw channels from X texels for xC${colIndex}. If padLeft is
                        // even, xC${colIndex +1} is simply the zw channels of texels we've
                        // already sampled. But if padLeft is odd, xC{$c + 1}.zw will
                        // need to come from the xy channels of a new texel, hence the `
                        // vec4
                        // final` initialized below.
                        if (padLeft % 2 === 1) {
                            mainLoop += `
               xCOffset = xC + 1 - strides[1];
               if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex}Ready == 0) {
                 xTexelC${colIndex} = getX(batch, xR, xCOffset, d1);
                 // Need to manually clear unused channels in case
                 // we're reading from recycled texture.
                 if (xCOffset + 1 >= inDims[1]) {
                   xTexelC${colIndex}.zw = vec2(0.0);
                 }
                 xTexelC${colIndex}Ready = 1;
               }

               if(xC + 1 >= 0 && xC + 1 < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
                 xTexelC${colIndex + 1} = getX(batch, xR, xC + 1, d1);
                 // Need to manually clear unused channels in case
                 // we're reading from recycled texture.
                 if (xC + 2 >= inDims[1]) {
                   xTexelC${colIndex + 1}.zw = vec2(0.0);
                 }
                 xTexelC${colIndex + 1}Ready = 1;
               }

               xC${colIndex} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.zw);
               `;
                            if (colIndex + 1 < filterWidth) {
                                mainLoop += `
                 final = vec4(0.0);
                 xCOffset = xC + 1 + strides[1];
                 if(xCOffset >= 0 && xCOffset < inDims[1]) {
                   final = getX(batch, xR, xCOffset, d1);
                 }
                 xC${colIndex + 1} = vec4(xTexelC${colIndex + 1}.xy, final.xy);
                 `;
                            }
                        }
                        else {
                            mainLoop += `
               if(xC >= 0 && xC < inDims[1] && xTexelC${colIndex}Ready == 0) {
                 xTexelC${colIndex} = getX(batch, xR, xC, d1);
                 if (xC + 1 >= inDims[1]) {
                   xTexelC${colIndex}.zw = vec2(0.0);
                 }
                 xTexelC${colIndex}Ready = 1;
               }

               xCOffset = xC + strides[1];
               if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
                 xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);
                 if (xCOffset + 1 >= inDims[1]) {
                   xTexelC${colIndex + 1}.zw = vec2(0.);
                 }
                 xTexelC${colIndex + 1}Ready = 1;
               }

               xC${colIndex} = vec4(
                 xTexelC${colIndex}.xy, xTexelC${colIndex + 1}.xy);
               `;
                            if (colIndex + 1 < filterWidth) {
                                mainLoop += `
                 xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.zw);
                 `;
                            }
                        }
                    }
                }
                // localize the dotProd accumulation within the loop, the theory is for
                // GPU with limited cache, accumulating the sum across a large number of
                // variables will cause lots of cache misses. (i.e. 5x5 filter will have
                // 50 variables)
                if (colIndex < filterWidth) {
                    mainLoop += `
             wTexel = getW(r, ${colIndex}, d1, d2);
             dotProd += xC${colIndex}.xxzz * vec4(wTexel.xy, wTexel.xy);
             if(d1 + 1 < ${convInfo.inChannels}) {
               dotProd += xC${colIndex}.yyww * vec4(wTexel.zw, wTexel.zw);
             }
           `;
                    if (colIndex + 1 < filterWidth) {
                        mainLoop += `
               wTexel = getW(r, ${colIndex + 1}, d1, d2);
               dotProd += xC${colIndex + 1}.xxzz * vec4(wTexel.xy, wTexel.xy);
               if(d1 + 1 < ${convInfo.inChannels}) {
                 dotProd += xC${colIndex + 1}.yyww * vec4(wTexel.zw, wTexel.zw);
               }
             `;
                    }
                }
            }
            // Close the `if (xR ...)`, the d1 loop, and the r loop.
            mainLoop += `
         }
       `;
            mainLoop += `
         }
       `;
            mainLoop += `
       }
       `;
            let activationSnippet = '', applyActivationSnippet = '';
            if (activation) {
                if (hasPreluActivation) {
                    // PReLU: `b` holds the per-element slope weights.
                    activationSnippet = `vec4 activation(vec4 a) {
           vec4 b = getPreluActivationWeightsAtOutCoords();
           ${activation}
         }`;
                }
                else if (hasLeakyReluAlpha) {
                    // Leaky ReLU: `b` holds the alpha value.
                    activationSnippet = `vec4 activation(vec4 a) {
           vec4 b = getLeakyreluAlphaAtOutCoords();
           ${activation}
         }`;
                }
                else {
                    activationSnippet = `vec4 activation(vec4 x) {
           ${activation}
         }`;
                }
                applyActivationSnippet = `result = activation(result);`;
            }
            const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
            // Extra inputs are appended in the same order the backend binds them.
            if (addBias) {
                this.variableNames.push('bias');
            }
            if (hasPreluActivation) {
                this.variableNames.push('preluActivationWeights');
            }
            if (hasLeakyReluAlpha) {
                this.variableNames.push('leakyreluAlpha');
            }
            this.userCode = `
       ${activationSnippet}

       void main() {
         ivec4 coords = getOutputCoords();
         int batch = coords.x;
         ivec2 xRCCorner = coords.yz * strides - pads;
         int d2 = coords.w;
         int xRCorner = xRCCorner.x;
         int xCCorner = xRCCorner.y;

         //intialize dotProd with a small epsilon seems to reduce GPU accuracy loss.
         vec4 dotProd = vec4(0.000000000000001);

         ${mainLoop}

         vec4 result = dotProd - vec4(0.000000000000001);
         ${addBiasSnippet}
         ${applyActivationSnippet}
         setOutput(result);
       }
     `;
        }
    }
94328
94329 /**
94330 * @license
94331 * Copyright 2019 Google LLC. All Rights Reserved.
94332 * Licensed under the Apache License, Version 2.0 (the "License");
94333 * you may not use this file except in compliance with the License.
94334 * You may obtain a copy of the License at
94335 *
94336 * http://www.apache.org/licenses/LICENSE-2.0
94337 *
94338 * Unless required by applicable law or agreed to in writing, software
94339 * distributed under the License is distributed on an "AS IS" BASIS,
94340 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94341 * See the License for the specific language governing permissions and
94342 * limitations under the License.
94343 * =============================================================================
94344 */
    /**
     * WebGL program that performs the im2col transformation on packed
     * textures: it rearranges each receptive field of the input image into a
     * column of the output matrix, so a convolution can then be computed as a
     * matrix multiply (see conv2dWithIm2Row below in the original file).
     *
     * The output is a 3D [batch, sharedDim, numCols] tensor. Because the
     * output is packed, each invocation fills a 2x2 block of it; the four
     * (row, col) offsets are unrolled at shader-generation time.
     */
    class Im2ColPackedProgram {
        /**
         * @param outputShape The [batch, sharedDim, numCols] im2col shape.
         * @param convInfo Conv2D descriptor; only dataFormat is read here,
         *     the remaining geometry arrives via custom uniforms at run time.
         */
        constructor(outputShape, convInfo) {
            this.variableNames = ['A'];
            this.packedInputs = true;
            this.packedOutput = true;
            this.customUniforms = [
                { name: 'inputShape', type: 'ivec4' },
                { name: 'pad', type: 'ivec2' },
                { name: 'stride', type: 'ivec2' },
                { name: 'dilation', type: 'ivec2' },
                { name: 'inChannels', type: 'int' },
                { name: 'itemsPerBlockRow', type: 'int' },
                { name: 'outWidth', type: 'int' },
            ];
            this.outputShape = outputShape;
            this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
            const { dataFormat } = convInfo;
            const glsl = getGlslDifferences();
            const isChannelsLast = dataFormat === 'channelsLast';
            // Input dims holding rows/cols: NHWC -> 1/2, NCHW -> 2/3.
            const rowDim = isChannelsLast ? 1 : 2;
            const colDim = isChannelsLast ? 2 : 3;
            // Guard against the half-texel overhang of the packed 2x2 block.
            const boundsCheckingSnippet = this.enableShapeUniforms ?
                'if(blockIndex < outShape[2] && pos < outShape[1]) {' :
                `if(blockIndex < ${outputShape[2]} && pos < ${outputShape[1]}) {`;
            let unrolled = ``;
            // Unroll the four entries of the packed output texel.
            for (let row = 0; row <= 1; row++) {
                for (let col = 0; col <= 1; col++) {
                    unrolled += `
          blockIndex = rc.z + ${col};
          pos = rc.y + ${row};

          ${boundsCheckingSnippet}
            offsetY = int(blockIndex / outWidth) * stride[0] - pad[0];
            d0 = offsetY + dilation[0] * (pos / itemsPerBlockRow);

            if(d0 < inputShape[${rowDim}] && d0 >= 0) {
              // Use custom imod instead mod. On Intel GPU, mod may generate
              // unexpected value.
              // https://github.com/tensorflow/tfjs/issues/5447
              offsetX = imod(blockIndex, outWidth) * stride[1] - pad[1];
              d1 = offsetX + dilation[1] * (imod(pos, itemsPerBlockRow) /
                  inChannels);

              if(d1 < inputShape[${colDim}] && d1 >= 0) {

                ch = imod(pos, inChannels);

                if (${isChannelsLast}) {
                  innerDims = vec2(d1, ch);
                  result[${row * 2 + col}] = getChannel(
                    getA(rc.x, d0, int(innerDims.x),
                    int(innerDims.y)), innerDims);
                } else {
                  innerDims = vec2(d0, d1);
                  result[${row * 2 + col}] = getChannel(
                    getA(rc.x, ch, int(innerDims.x),
                    int(innerDims.y)), innerDims);
                }
              }
            }
          }
        `;
                }
            }
            this.userCode = `
      void main() {
        ivec3 rc = getOutputCoords();

        vec4 result = vec4(0);

        int blockIndex, pos, offsetY, d0, offsetX, d1, ch;
        vec2 innerDims;

        ${unrolled}

        ${glsl.output} = result;
      }
    `;
        }
    }
94425
94426 /**
94427 * @license
94428 * Copyright 2020 Google LLC. All Rights Reserved.
94429 * Licensed under the Apache License, Version 2.0 (the "License");
94430 * you may not use this file except in compliance with the License.
94431 * You may obtain a copy of the License at
94432 *
94433 * http://www.apache.org/licenses/LICENSE-2.0
94434 *
94435 * Unless required by applicable law or agreed to in writing, software
94436 * distributed under the License is distributed on an "AS IS" BASIS,
94437 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94438 * See the License for the specific language governing permissions and
94439 * limitations under the License.
94440 * =============================================================================
94441 */
94442 // Both conv2dByMatMul and conv2dWithIm2Row fuse height and width into one
94443 // dimension to compute batchMatMul, so bias and activation weights are also
94444 // supposed to fuse the two dimensions into one.
94445 //
94446 // This function computes the target shape for fusing height and width
94447 // dimensions. Returning null means the shape is already compatible.
94448 //
94449 // Even though the bias is not supposed to be a 3-D or a 4-D (including
 * batch) tensor and PReLU activation weights are not supposed to be a 4-D
94451 // tensor, we still need to support them, because we haven't disabled
94452 // them for NHWC format.
94453 // https://github.com/tensorflow/tfjs/blob/b53bd47e880367ae57493f0ea628abaf08db2d5d/tfjs-core/src/ops/fused/conv2d.ts#L181-L196
94454 function getShapeForBatchMatMul(shape, isChannelsLast) {
94455 const length = shape.length;
94456 if (length >= 3) {
94457 return isChannelsLast ?
94458 [
94459 ...shape.slice(0, -3) /* batch */,
94460 shape[length - 3] * shape[length - 2] /* height * width */,
94461 shape[length - 1] /* channel */
94462 ] :
94463 [
94464 ...shape.slice(0, -3) /* batch */, shape[length - 3] /* channel */,
94465 shape[length - 2] * shape[length - 1] /* height * width */
94466 ];
94467 }
94468 else if (!isChannelsLast && length === 1 && shape[0] > 1) {
94469 return [shape[0], 1];
94470 }
94471 else {
94472 return null;
94473 }
94474 }
94475 // For 1x1 kernels that iterate through every point in the input, convolution
94476 // can be expressed as matrix multiplication (without need for memory
94477 // remapping).
    /**
     * Computes a (possibly fused) pointwise conv2d as a batched matrix
     * multiply: the input is viewed as a 2D matrix of pixels x inChannels and
     * multiplied by the [inChannels, outChannels] filter.
     *
     * When possible (packed NHWC input already on the GPU with an odd column
     * count), it avoids an expensive packed reshape by temporarily padding the
     * column count to even — see the `canOptimize` branch.
     *
     * All intermediate tensors are collected and disposed before returning.
     */
    function conv2dByMatMul({ x, filter, convInfo, backend, bias = null, preluActivationWeights = null, leakyreluAlpha = 0, activation = null }) {
        // Reshapes conv2D input to 2D tensors, uses matMul and then reshape the
        // result from 2D to 4D.
        const xShape = x.shape;
        const xTexData = backend.texData.get(x.dataId);
        const sharedMatMulDim = convInfo.inChannels;
        const outerShapeX = xShape[0] * xShape[1] * xShape[2];
        const outerShapeFilter = convInfo.outChannels;
        const isChannelsLast = convInfo.dataFormat === 'channelsLast';
        const transposeA = false;
        const transposeB = false;
        let out;
        // Tensors to dispose once `out` has been produced.
        const intermediates = [];
        if (preluActivationWeights != null) {
            // Fuse H and W of the PReLU weights so they match the matmul layout.
            const targetShape = getShapeForBatchMatMul(preluActivationWeights.shape, isChannelsLast);
            if (targetShape != null) {
                preluActivationWeights = reshape({
                    inputs: { x: preluActivationWeights },
                    backend,
                    attrs: { shape: targetShape }
                });
                intermediates.push(preluActivationWeights);
            }
        }
        if (bias != null) {
            // Same layout fix-up for the bias.
            const targetShape = getShapeForBatchMatMul(bias.shape, isChannelsLast);
            if (targetShape != null) {
                bias = reshape({ inputs: { x: bias }, backend, attrs: { shape: targetShape } });
                intermediates.push(bias);
            }
        }
        // TODO: Once reduction ops are packed, batchMatMul will always be packed
        // and we can remove this condition.
        const batchMatMulWillBeUnpacked = (outerShapeX === 1 || outerShapeFilter === 1) &&
            sharedMatMulDim > MATMUL_SHARED_DIM_THRESHOLD;
        // The algorithm in the if condition assumes (1) the output will be packed,
        // (2) x is packed, (3) x isChannelsLast, (4) x's packed texture is already
        // on GPU, (5) col is odd, (6) the width, height and inChannels are the same
        // for xTexData.shape and xShape.
        const canOptimize = !batchMatMulWillBeUnpacked && xTexData.isPacked &&
            isChannelsLast && xTexData.texture != null && xShape[2] % 2 !== 0 &&
            arraysEqual(xTexData.shape.slice(-3), xShape.slice(-3));
        if (canOptimize) {
            // We avoid an expensive packed 2x2 reshape by padding col count to the
            // next even number. When col is odd, the result of packed batchMatMul
            // is the same (has the same texture layout and values in the texture)
            // as it is for the next even col. We make the odd-cols tensor look
            // like an even-cols tensor before the operation and, after the
            // batchMatMul, fix the even-cols result to have an odd number of cols.
            const targetShape = xShape[0] * xShape[1] * (xShape[2] + 1);
            // Aliases x's data without copying; only the logical shape differs.
            const xReshaped = {
                dataId: x.dataId,
                shape: [1, targetShape, convInfo.inChannels],
                dtype: x.dtype
            };
            // xTexData.shape gets referenced from GPGPUBinary.inShapeInfos.
            // Decrementing col count, after batchMatMul->...->compileProgram leads to
            // invalid col count within the reference in GPGPUBinary.inShapeInfos.
            // Alternative fix would be to provide a copy to GPGPUBinary.inShapeInfos
            // in compileProgram method, but that would affect compilation of all
            // programs - instead, provide a copy here, with even col count, before
            // calling batchMatMul->...->compileProgram and after that, the original
            // xTexData.shape is restored.
            const originalXTexDataShape = xTexData.shape;
            xTexData.shape = xTexData.shape.slice();
            xTexData.shape[xTexData.shape.length - 2]++;
            assert$1(isReshapeFree(xTexData.shape, xReshaped.shape), () => `packed reshape ${xTexData.shape} to ${xReshaped.shape} isn't free`);
            const filterReshaped = reshape({
                inputs: { x: filter },
                backend,
                attrs: { shape: [1, convInfo.inChannels, convInfo.outChannels] }
            });
            intermediates.push(filterReshaped);
            const pointwiseConv = batchMatMulImpl({
                a: xReshaped,
                b: filterReshaped,
                backend,
                transposeA,
                transposeB,
                bias,
                activation,
                preluActivationWeights,
                leakyreluAlpha
            });
            const pointwiseConvTexData = backend.texData.get(pointwiseConv.dataId);
            assert$1(pointwiseConvTexData.isPacked, () => 'batchMatMul result is expected to be packed');
            // Restore the input shape to original.
            xTexData.shape = originalXTexDataShape;
            // Set the output shape - there is no need for expensive reshape as data
            // layout is already correct.
            pointwiseConvTexData.shape = convInfo.outShape;
            out = identity({ inputs: { x: pointwiseConv }, backend });
            out.shape = convInfo.outShape;
            intermediates.push(pointwiseConv);
        }
        else {
            // General path: explicit reshape of x and filter, then batchMatMul.
            const numCols = convInfo.outHeight * convInfo.outWidth;
            const xReshaped = reshape({
                inputs: { x },
                backend,
                attrs: {
                    shape: isChannelsLast ?
                        [convInfo.batchSize, numCols, convInfo.inChannels] :
                        [convInfo.batchSize, convInfo.inChannels, numCols]
                }
            });
            const filterReshaped = reshape({
                inputs: { x: filter },
                backend,
                attrs: { shape: [1, convInfo.inChannels, convInfo.outChannels] }
            });
            // In channelsFirst mode the operands are swapped and A is transposed
            // so the channel dimension lines up for the multiply.
            const result = batchMatMulImpl({
                a: isChannelsLast ? xReshaped : filterReshaped,
                b: isChannelsLast ? filterReshaped : xReshaped,
                transposeA: !isChannelsLast,
                transposeB,
                backend,
                bias,
                activation,
                preluActivationWeights,
                leakyreluAlpha
            });
            out = reshape({ inputs: { x: result }, backend, attrs: { shape: convInfo.outShape } });
            intermediates.push(xReshaped);
            intermediates.push(filterReshaped);
            intermediates.push(result);
        }
        for (const i of intermediates) {
            backend.disposeIntermediateTensorInfo(i);
        }
        return out;
    }
94610 // Implements the im2row algorithm as outlined in "High Performance
94611 // Convolutional Neural Networks for Document Processing" (Suvisoft, 2006)
94612 function conv2dWithIm2Row({ x, filter, convInfo, backend, bias = null, preluActivationWeights = null, leakyreluAlpha = 0, activation = null }) {
94613 // Rearranges conv2d input so each block to be convolved over forms the
94614 // column of a new matrix with shape [filterWidth * filterHeight *
94615 // inChannels, outHeight * outWidth]. The filter is also rearranged so each
94616 // output channel forms a row of a new matrix with shape [outChannels,
94617 // filterWidth * filterHeight * inChannels]. The convolution is then
94618 // computed by multiplying these matrices and reshaping the result.
94619 const { filterWidth, filterHeight, inChannels, outWidth, outHeight, dataFormat } = convInfo;
94620 const isChannelsLast = dataFormat === 'channelsLast';
94621 const sharedDim = filterWidth * filterHeight * inChannels;
94622 const numCols = outHeight * outWidth;
94623 const x2ColShape = [convInfo.batchSize, sharedDim, numCols];
94624 const transposeA = true;
94625 const transposeB = false;
94626 const intermediates = [];
94627 if (preluActivationWeights != null) {
94628 const targetShape = getShapeForBatchMatMul(preluActivationWeights.shape, isChannelsLast);
94629 if (targetShape != null) {
94630 preluActivationWeights = reshape({
94631 inputs: { x: preluActivationWeights },
94632 backend,
94633 attrs: { shape: targetShape }
94634 });
94635 intermediates.push(preluActivationWeights);
94636 }
94637 }
94638 if (bias != null) {
94639 const targetShape = getShapeForBatchMatMul(bias.shape, isChannelsLast);
94640 if (targetShape != null) {
94641 bias = reshape({ inputs: { x: bias }, backend, attrs: { shape: targetShape } });
94642 intermediates.push(bias);
94643 }
94644 }
94645 const w2Row = reshape({
94646 inputs: { x: filter },
94647 backend,
94648 attrs: { shape: [1, sharedDim, sizeFromShape(filter.shape) / sharedDim] }
94649 });
94650 intermediates.push(w2Row);
94651 const im2ColProgram = new Im2ColPackedProgram(x2ColShape, convInfo);
94652 const customValues = [
94653 x.shape, [convInfo.padInfo.top, convInfo.padInfo.left],
94654 [convInfo.strideHeight, convInfo.strideWidth],
94655 [convInfo.dilationHeight, convInfo.dilationWidth], [convInfo.inChannels],
94656 [convInfo.filterWidth * convInfo.inChannels], [convInfo.outWidth]
94657 ];
94658 const im2Col = backend.runWebGLProgram(im2ColProgram, [x], 'float32', customValues);
94659 const im2ColReshaped = reshape({ inputs: { x: im2Col }, backend, attrs: { shape: x2ColShape } });
94660 intermediates.push(im2Col);
94661 intermediates.push(im2ColReshaped);
94662 const hasBias = bias != null;
94663 const hasPreluActivationWeights = preluActivationWeights != null;
94664 const hasLeakyreluAlpha = activation === 'leakyrelu';
94665 const fusedActivation = activation ? mapActivationToShaderProgram(activation, true) : null;
94666 const matmulProgram = new MatMulPackedProgram(isChannelsLast ? im2ColReshaped.shape :
94667 w2Row.shape, isChannelsLast ? w2Row.shape :
94668 im2ColReshaped.shape, isChannelsLast ? [convInfo.batchSize, numCols, convInfo.outChannels] :
94669 [convInfo.batchSize, convInfo.outChannels, numCols], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
94670 const inputs = isChannelsLast ? [im2ColReshaped, w2Row] : [w2Row, im2ColReshaped];
94671 if (bias) {
94672 inputs.push(bias);
94673 }
94674 if (hasPreluActivationWeights) {
94675 inputs.push(preluActivationWeights);
94676 }
94677 if (hasLeakyreluAlpha) {
94678 const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
94679 inputs.push($leakyreluAlpha);
94680 intermediates.push($leakyreluAlpha);
94681 }
94682 const product = backend.runWebGLProgram(matmulProgram, inputs, 'float32');
94683 const out = reshape({ inputs: { x: product }, backend, attrs: { shape: convInfo.outShape } });
94684 intermediates.push(product);
94685 for (const i of intermediates) {
94686 backend.disposeIntermediateTensorInfo(i);
94687 }
94688 return out;
94689 }
94690
94691 /**
94692 * @license
94693 * Copyright 2020 Google LLC. All Rights Reserved.
94694 * Licensed under the Apache License, Version 2.0 (the "License");
94695 * you may not use this file except in compliance with the License.
94696 * You may obtain a copy of the License at
94697 *
94698 * http://www.apache.org/licenses/LICENSE-2.0
94699 *
94700 * Unless required by applicable law or agreed to in writing, software
94701 * distributed under the License is distributed on an "AS IS" BASIS,
94702 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94703 * See the License for the specific language governing permissions and
94704 * limitations under the License.
94705 * =============================================================================
94706 */
/**
 * WebGL kernel implementation for Conv2D.
 *
 * Picks the fastest applicable program for the given configuration:
 *  - a plain matrix multiply when the convolution is effectively pointwise
 *    (1x1 filter, unit strides and dilations),
 *  - the experimental packed conv program when WEBGL_EXP_CONV is enabled,
 *    strideWidth <= 2 and the layout is channels-last,
 *  - an im2col-based matmul when WEBGL_CONV_IM2COL is enabled,
 *  - the general Conv2DProgram otherwise.
 *
 * @param args {inputs: {x, filter}, backend, attrs} kernel arguments.
 * @returns TensorInfo with the result reshaped to convInfo.outShape.
 */
function conv2d(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter } = inputs;
    const { strides, pad, dataFormat, dilations, dimRoundingMode } = attrs;
    const $dataFormat = convertConv2DDataFormat(dataFormat);
    const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
    // A 1x1 convolution with unit stride/dilation reduces to a matmul.
    const isPointwise = convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
        convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
        convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
        (convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID');
    let result;
    if (isPointwise) {
        result = conv2dByMatMul({ x, filter, convInfo, backend });
    }
    else if (convInfo.strideWidth <= 2 && $dataFormat === 'channelsLast' &&
        env().getBool('WEBGL_EXP_CONV')) {
        const customValues = [
            [convInfo.padInfo.top, convInfo.padInfo.left],
            [convInfo.strideHeight, convInfo.strideWidth],
            [convInfo.dilationHeight, convInfo.dilationWidth],
            [convInfo.inHeight, convInfo.inWidth],
        ];
        result = backend.runWebGLProgram(new Conv2DPackedProgram(convInfo), [x, filter], 'float32', customValues);
    }
    else if (env().getBool('WEBGL_CONV_IM2COL')) {
        result = conv2dWithIm2Row({ x, filter, convInfo, backend });
    }
    else {
        result = backend.runWebGLProgram(new Conv2DProgram(convInfo), [x, filter], 'float32');
    }
    // Normalize to the canonical output shape, then release the intermediate.
    const outReshaped = reshape({ inputs: { x: result }, backend, attrs: { shape: convInfo.outShape } });
    backend.disposeIntermediateTensorInfo(result);
    return outReshaped;
}
const conv2DConfig = {
    kernelName: Conv2D$1,
    backendName: 'webgl',
    kernelFunc: conv2d,
};
94748
94749 /**
94750 * @license
94751 * Copyright 2017 Google LLC. All Rights Reserved.
94752 * Licensed under the Apache License, Version 2.0 (the "License");
94753 * you may not use this file except in compliance with the License.
94754 * You may obtain a copy of the License at
94755 *
94756 * http://www.apache.org/licenses/LICENSE-2.0
94757 *
94758 * Unless required by applicable law or agreed to in writing, software
94759 * distributed under the License is distributed on an "AS IS" BASIS,
94760 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
94761 * See the License for the specific language governing permissions and
94762 * limitations under the License.
94763 * =============================================================================
94764 */
/**
 * WebGL program computing the gradient of a 2D convolution with respect to
 * its filter (dW). The output shape is convInfo.filterShape
 * [wR, wC, d1, d2]; each element accumulates x * dy over the batch and all
 * spatial output positions that the filter tap (wR, wC) contributed to,
 * honoring stride and padding. Supports both channels-last and
 * channels-first layouts via the isChannelsLast branch in the shader.
 */
class Conv2DDerFilterProgram {
    constructor(convInfo) {
        this.variableNames = ['x', 'dy'];
        this.outputShape = convInfo.filterShape;
        const strideHeight = convInfo.strideHeight;
        const strideWidth = convInfo.strideWidth;
        const padTop = convInfo.padInfo.top;
        const padLeft = convInfo.padInfo.left;
        // Chooses the argument order of getX/getDy in the shader below.
        const isChannelsLast = convInfo.dataFormat === 'channelsLast';
        this.userCode = `
      void main() {
        ivec4 coords = getOutputCoords();
        int wR = coords.x;
        int wC = coords.y;
        int d1 = coords.z;
        int d2 = coords.w;

        // Convolve x(?, ?, d1) with dy(:, :, d2) to get dw(wR, wC, d1, d2).
        // ? = to be determined. : = across all values in that axis.
        float dotProd = 0.0;

        for (int b = 0; b < ${convInfo.batchSize}; b++) {
          for (int yR = 0; yR < ${convInfo.outHeight}; yR++) {
            int xR = wR + yR * ${strideHeight} - ${padTop};

            if (xR < 0 || xR >= ${convInfo.inHeight}) {
              continue;
            }

            for (int yC = 0; yC < ${convInfo.outWidth}; yC++) {
              int xC = wC + yC * ${strideWidth} - ${padLeft};

              if (xC < 0 || xC >= ${convInfo.inWidth}) {
                continue;
              }

              ${isChannelsLast ?
            `float dyValue = getDy(b, yR, yC, d2);
              float xValue = getX(b, xR, xC, d1);
              dotProd += (xValue * dyValue);` :
            `float dyValue = getDy(b, d2, yR, yC);
              float xValue = getX(b, d1, xR, xC);
              dotProd += (xValue * dyValue);`}
            }
          }
        }
        setOutput(dotProd);
      }
    `;
    }
}
/**
 * WebGL program computing the gradient of a 2D convolution with respect to
 * its input (dx) — i.e. a transposed convolution of the upstream gradient
 * `dy` with the (spatially flipped) filter `W`. Output shape is
 * convInfo.inShape. Handles both channels-last and channels-first layouts.
 */
class Conv2DDerInputProgram {
    constructor(convInfo) {
        this.variableNames = ['dy', 'W'];
        this.outputShape = convInfo.inShape;
        const filterHeight = convInfo.filterHeight;
        const filterWidth = convInfo.filterWidth;
        const strideHeight = convInfo.strideHeight;
        const strideWidth = convInfo.strideWidth;
        const isChannelsLast = convInfo.dataFormat === 'channelsLast';
        // Effective padding of the transposed convolution.
        const padTop = filterHeight - 1 - convInfo.padInfo.top;
        const padLeft = filterWidth - 1 - convInfo.padInfo.left;
        // Coordinate positions of row / column / channel depend on layout.
        const rowDim = isChannelsLast ? 1 : 2;
        const colDim = isChannelsLast ? 2 : 3;
        const channelDim = isChannelsLast ? 3 : 1;
        this.userCode = `
      const ivec2 pads = ivec2(${padTop}, ${padLeft});

      void main() {
        ivec4 coords = getOutputCoords();
        int batch = coords[0];
        int d1 = coords[${channelDim}];

        ivec2 dyCorner = ivec2(coords[${rowDim}], coords[${colDim}]) - pads;
        int dyRCorner = dyCorner.x;
        int dyCCorner = dyCorner.y;

        // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
        // ? = to be determined. : = across all values in that axis.
        float dotProd = 0.0;
        for (int wR = 0; wR < ${filterHeight}; wR++) {
          float dyR = float(dyRCorner + wR) / ${strideHeight}.0;

          if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
            continue;
          }
          int idyR = int(dyR);

          int wRPerm = ${filterHeight} - 1 - wR;

          for (int wC = 0; wC < ${filterWidth}; wC++) {
            float dyC = float(dyCCorner + wC) / ${strideWidth}.0;

            if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
                fract(dyC) > 0.0) {
              continue;
            }
            int idyC = int(dyC);

            int wCPerm = ${filterWidth} - 1 - wC;

            for (int d2 = 0; d2 < ${convInfo.outChannels}; d2++) {

              if (${isChannelsLast}) {
                float xValue = getDy(batch, idyR, idyC, d2);
                float wValue = getW(wRPerm, wCPerm, d1, d2);
                dotProd += xValue * wValue;
              } else {
                float xValue = getDy(batch, d2, idyR, idyC);
                float wValue = getW(wRPerm, wCPerm, d1, d2);
                dotProd += xValue * wValue;
              }

            }
          }
        }
        setOutput(dotProd);
      }
    `;
    }
}
/**
 * WebGL program computing the gradient of a 3D convolution with respect to
 * its filter. Output shape is convInfo.filterShape [wF, wR, wC, d1, d2]
 * (depth, height, width, in-channel, out-channel); each element accumulates
 * x * dy over the batch and all output positions. The shader samples with
 * the channel index last (getDy(b, yF, yR, yC, d2)), i.e. NDHWC layout.
 */
class Conv3DDerFilterProgram {
    constructor(convInfo) {
        this.variableNames = ['x', 'dy'];
        this.outputShape = convInfo.filterShape;
        const strideDepth = convInfo.strideDepth;
        const strideHeight = convInfo.strideHeight;
        const strideWidth = convInfo.strideWidth;
        const padFront = convInfo.padInfo.front;
        const padTop = convInfo.padInfo.top;
        const padLeft = convInfo.padInfo.left;
        this.userCode = `
      void main() {
        ivec5 coords = getOutputCoords();
        int wF = coords.x;
        int wR = coords.y;
        int wC = coords.z;
        int d1 = coords.w;
        int d2 = coords.u;

        float dotProd = 0.0;

        for (int b = 0; b < ${convInfo.batchSize}; b++) {
          for (int yF = 0; yF < ${convInfo.outDepth}; yF++) {
            int xF = wF + yF * ${strideDepth} - ${padFront};

            if (xF < 0 || xF >= ${convInfo.inDepth}) {
              continue;
            }

            for (int yR = 0; yR < ${convInfo.outHeight}; yR++) {
              int xR = wR + yR * ${strideHeight} - ${padTop};

              if (xR < 0 || xR >= ${convInfo.inHeight}) {
                continue;
              }

              for (int yC = 0; yC < ${convInfo.outWidth}; yC++) {
                int xC = wC + yC * ${strideWidth} - ${padLeft};

                if (xC < 0 || xC >= ${convInfo.inWidth}) {
                  continue;
                }

                float dyValue = getDy(b, yF, yR, yC, d2);
                float xValue = getX(b, xF, xR, xC, d1);
                dotProd += (xValue * dyValue);
              }
            }
          }
        }
        setOutput(dotProd);
      }
    `;
    }
}
/**
 * WebGL program computing the gradient of a 3D convolution with respect to
 * its input (dx) — a transposed 3D convolution of `dy` with the spatially
 * flipped filter `W`. Output shape is convInfo.inShape; the shader samples
 * with the channel index last (NDHWC layout).
 */
class Conv3DDerInputProgram {
    constructor(convInfo) {
        this.variableNames = ['dy', 'W'];
        this.outputShape = convInfo.inShape;
        const filterDepth = convInfo.filterDepth;
        const filterHeight = convInfo.filterHeight;
        const filterWidth = convInfo.filterWidth;
        const strideDepth = convInfo.strideDepth;
        const strideHeight = convInfo.strideHeight;
        const strideWidth = convInfo.strideWidth;
        // Effective padding of the transposed convolution.
        const padFront = filterDepth - 1 - convInfo.padInfo.front;
        const padTop = filterHeight - 1 - convInfo.padInfo.top;
        const padLeft = filterWidth - 1 - convInfo.padInfo.left;
        this.userCode = `
      const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});

      void main() {
        ivec5 coords = getOutputCoords();
        int batch = coords.x;
        int d1 = coords.u;


        ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;
        int dyFCorner = dyCorner.x;
        int dyRCorner = dyCorner.y;
        int dyCCorner = dyCorner.z;

        float dotProd = 0.0;
        for (int wF = 0; wF < ${filterDepth}; wF++) {
          float dyF = float(dyFCorner + wF) / ${strideDepth}.0;

          if (dyF < 0.0 || dyF >= ${convInfo.outDepth}.0 || fract(dyF) > 0.0) {
            continue;
          }
          int idyF = int(dyF);

          int wFPerm = ${filterDepth} - 1 - wF;

          for (int wR = 0; wR < ${filterHeight}; wR++) {
            float dyR = float(dyRCorner + wR) / ${strideHeight}.0;

            if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 ||
              fract(dyR) > 0.0) {
              continue;
            }
            int idyR = int(dyR);

            int wRPerm = ${filterHeight} - 1 - wR;

            for (int wC = 0; wC < ${filterWidth}; wC++) {
              float dyC = float(dyCCorner + wC) / ${strideWidth}.0;

              if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
                  fract(dyC) > 0.0) {
                continue;
              }
              int idyC = int(dyC);

              int wCPerm = ${filterWidth} - 1 - wC;

              for (int d2 = 0; d2 < ${convInfo.outChannels}; d2++) {
                float xValue = getDy(batch, idyF, idyR, idyC, d2);
                float wValue = getW(wFPerm, wRPerm, wCPerm, d1, d2);
                dotProd += xValue * wValue;
              }
            }
          }
        }
        setOutput(dotProd);
      }
    `;
    }
}
95014
95015 /**
95016 * @license
95017 * Copyright 2020 Google LLC. All Rights Reserved.
95018 * Licensed under the Apache License, Version 2.0 (the "License");
95019 * you may not use this file except in compliance with the License.
95020 * You may obtain a copy of the License at
95021 *
95022 * http://www.apache.org/licenses/LICENSE-2.0
95023 *
95024 * Unless required by applicable law or agreed to in writing, software
95025 * distributed under the License is distributed on an "AS IS" BASIS,
95026 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95027 * See the License for the specific language governing permissions and
95028 * limitations under the License.
95029 * =============================================================================
95030 */
/**
 * WebGL kernel computing dConv2D/dFilter.
 *
 * @param args {inputs: {x, dy}, backend, attrs} with attrs
 *     {strides, pad, dataFormat, dimRoundingMode, filterShape}.
 * @returns TensorInfo of shape filterShape holding the filter gradient.
 */
function conv2DBackpropFilter(args) {
    const { inputs, backend, attrs } = args;
    const { x, dy } = inputs;
    const { strides, pad, dataFormat, dimRoundingMode, filterShape } = attrs;
    // Gradient kernels always run with dilation 1.
    const convInfo = computeConv2DInfo(x.shape, filterShape, strides, 1 /* dilations */, pad, dimRoundingMode, false /* depthwise */, convertConv2DDataFormat(dataFormat));
    return backend.runWebGLProgram(new Conv2DDerFilterProgram(convInfo), [x, dy], 'float32');
}
const conv2DBackpropFilterConfig = {
    kernelName: Conv2DBackpropFilter,
    backendName: 'webgl',
    kernelFunc: conv2DBackpropFilter,
};
95045
95046 /**
95047 * @license
95048 * Copyright 2023 Google LLC.
95049 * Licensed under the Apache License, Version 2.0 (the "License");
95050 * you may not use this file except in compliance with the License.
95051 * You may obtain a copy of the License at
95052 *
95053 * http://www.apache.org/licenses/LICENSE-2.0
95054 *
95055 * Unless required by applicable law or agreed to in writing, software
95056 * distributed under the License is distributed on an "AS IS" BASIS,
95057 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95058 * See the License for the specific language governing permissions and
95059 * limitations under the License.
95060 * =============================================================================
95061 */
/**
 * Packed (vec4) WebGL program for the gradient of a 2D convolution with
 * respect to its input (conv2d transpose), channels-last layout only.
 * Strides are supplied as a vec2 uniform so one compiled program covers all
 * stride values. Each invocation produces a 2x2 texel: result.xy / result.zw
 * hold two adjacent output columns, which is why two candidate dy columns
 * (idyC and idyC2) are evaluated for every filter tap.
 */
class Conv2DDerInputPackedProgram {
    constructor(convInfo) {
        this.variableNames = ['dy', 'W'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.customUniforms = [
            { name: 'strides', type: 'vec2' },
        ];
        this.outputShape = convInfo.inShape;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        const filterHeight = convInfo.filterHeight;
        const filterWidth = convInfo.filterWidth;
        // Effective padding of the transposed convolution.
        const padTop = filterHeight - 1 - convInfo.padInfo.top;
        const padLeft = filterWidth - 1 - convInfo.padInfo.left;
        this.userCode = `
      const ivec2 pads = ivec2(${padTop}, ${padLeft});

      void main() {
        ivec4 coords = getOutputCoords();
        int batch = coords[0];
        int d1 = coords[3];

        ivec2 dyCorner = ivec2(coords[1], coords[2]) - pads;
        int dyRCorner = dyCorner.x;
        int dyCCorner = dyCorner.y;

        vec4 result = vec4(0.);
        for (int wR = 0; wR < ${filterHeight}; wR++) {
          float dyR = float(dyRCorner + wR) / strides[0];
          if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
            continue;
          }
          int idyR = int(dyR);
          int wRPerm = ${filterHeight} - 1 - wR;

          for (int wC = 0; wC < ${filterWidth}; wC++) {
            int wCPerm = ${filterWidth} - 1 - wC;

            float dyC = float(dyCCorner + wC) / strides[1];
            bool idyCVal = (dyC >= 0.0) && (dyC < ${convInfo.outWidth}.0)
              && (fract(dyC) == 0.0);
            int idyC = int(dyC);

            float dyC2 = float(dyCCorner + wC + 1) / strides[1];
            bool idyCVal2 = (dyC2 >= 0.0) && (dyC2 < ${convInfo.outWidth}.0)
              && (fract(dyC2) == 0.0);
            int idyC2 = int(dyC2);

            if (idyCVal && idyCVal2) {
              for (int d2 = 0; d2 < ${convInfo.outChannels}; d2 += 2) {
                vec4 wValue = getW(wRPerm, wCPerm, d1, d2);
                vec4 dySample = getDy(batch, idyR, idyC, d2);
                vec4 dySample2 = (idyC / 2 == idyC2 / 2) ?
                  dySample : getDy(batch, idyR, idyC2, d2);

                vec2 dyValue = mod(float(idyC), 2.) == 0. ?
                  dySample.xy : dySample.zw;
                result.xy += vec2(dot(dyValue, wValue.xy),
                  dot(dyValue, wValue.zw));

                dyValue = mod(float(idyC2), 2.) == 0. ?
                  dySample2.xy : dySample2.zw;
                result.zw += vec2(dot(dyValue, wValue.xy),
                  dot(dyValue, wValue.zw));
              }
            } else if (idyCVal) {
              for (int d2 = 0; d2 < ${convInfo.outChannels}; d2 += 2) {
                vec4 wValue = getW(wRPerm, wCPerm, d1, d2);
                vec4 dySample = getDy(batch, idyR, idyC, d2);
                vec2 dyValue = mod(float(idyC), 2.) == 0. ?
                  dySample.xy : dySample.zw;
                result.xy += vec2(dot(dyValue, wValue.xy),
                  dot(dyValue, wValue.zw));
              }
            } else if (idyCVal2) {
              for (int d2 = 0; d2 < ${convInfo.outChannels}; d2 += 2) {
                vec4 wValue = getW(wRPerm, wCPerm, d1, d2);
                vec4 dySample = getDy(batch, idyR, idyC2, d2);
                vec2 dyValue = mod(float(idyC2), 2.) == 0. ?
                  dySample.xy : dySample.zw;
                result.zw += vec2(dot(dyValue, wValue.xy),
                  dot(dyValue, wValue.zw));
              }
            }
          }
        }
        setOutput(result);
      }
    `;
    }
}
95153
95154 /**
95155 * @license
95156 * Copyright 2020 Google LLC. All Rights Reserved.
95157 * Licensed under the Apache License, Version 2.0 (the "License");
95158 * you may not use this file except in compliance with the License.
95159 * You may obtain a copy of the License at
95160 *
95161 * http://www.apache.org/licenses/LICENSE-2.0
95162 *
95163 * Unless required by applicable law or agreed to in writing, software
95164 * distributed under the License is distributed on an "AS IS" BASIS,
95165 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95166 * See the License for the specific language governing permissions and
95167 * limitations under the License.
95168 * =============================================================================
95169 */
/**
 * WebGL kernel computing dConv2D/dInput (conv2d transpose).
 *
 * Uses the packed program when WEBGL_PACK_CONV2DTRANSPOSE is enabled and
 * the layout is channels-last; otherwise falls back to the unpacked
 * Conv2DDerInputProgram.
 *
 * @param args {inputs: {dy, filter}, backend, attrs} with attrs
 *     {inputShape, strides, pad, dataFormat, dimRoundingMode}.
 * @returns TensorInfo of shape inputShape holding the input gradient.
 */
function conv2DBackpropInput(args) {
    const { inputs, backend, attrs } = args;
    const { dy, filter } = inputs;
    const { inputShape, strides, pad, dataFormat, dimRoundingMode } = attrs;
    const $dataFormat = convertConv2DDataFormat(dataFormat);
    const convInfo = computeConv2DInfo(inputShape, filter.shape, strides, 1 /* dilations */, pad, dimRoundingMode, false, $dataFormat);
    const usePacked = env().getBool('WEBGL_PACK_CONV2DTRANSPOSE') &&
        $dataFormat === 'channelsLast';
    if (usePacked) {
        // Strides are fed to the packed program as a vec2 uniform.
        const customValues = [
            [convInfo.strideHeight, convInfo.strideWidth],
        ];
        return backend.runWebGLProgram(new Conv2DDerInputPackedProgram(convInfo), [dy, filter], 'float32', customValues);
    }
    return backend.runWebGLProgram(new Conv2DDerInputProgram(convInfo), [dy, filter], 'float32');
}
const conv2DBackpropInputConfig = {
    kernelName: Conv2DBackpropInput,
    backendName: 'webgl',
    kernelFunc: conv2DBackpropInput,
};
95194
95195 /**
95196 * @license
95197 * Copyright 2020 Google LLC. All Rights Reserved.
95198 * Licensed under the Apache License, Version 2.0 (the "License");
95199 * you may not use this file except in compliance with the License.
95200 * You may obtain a copy of the License at
95201 *
95202 * http://www.apache.org/licenses/LICENSE-2.0
95203 *
95204 * Unless required by applicable law or agreed to in writing, software
95205 * distributed under the License is distributed on an "AS IS" BASIS,
95206 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95207 * See the License for the specific language governing permissions and
95208 * limitations under the License.
95209 * =============================================================================
95210 */
/**
 * WebGL kernel implementation for Conv3D.
 *
 * @param args {inputs: {x, filter}, backend, attrs} with attrs
 *     {strides, pad, dilations}.
 * @returns TensorInfo holding the 3D convolution result.
 */
function conv3D(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter } = inputs;
    const { strides, pad, dilations } = attrs;
    const convInfo = computeConv3DInfo(x.shape, filter.shape, strides, dilations, pad);
    return backend.runWebGLProgram(new Conv3DProgram(convInfo), [x, filter], 'float32');
}
const conv3DConfig = {
    kernelName: Conv3D$1,
    backendName: 'webgl',
    kernelFunc: conv3D,
};
95224
95225 /**
95226 * @license
95227 * Copyright 2020 Google LLC. All Rights Reserved.
95228 * Licensed under the Apache License, Version 2.0 (the "License");
95229 * you may not use this file except in compliance with the License.
95230 * You may obtain a copy of the License at
95231 *
95232 * http://www.apache.org/licenses/LICENSE-2.0
95233 *
95234 * Unless required by applicable law or agreed to in writing, software
95235 * distributed under the License is distributed on an "AS IS" BASIS,
95236 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95237 * See the License for the specific language governing permissions and
95238 * limitations under the License.
95239 * =============================================================================
95240 */
/**
 * WebGL kernel computing dConv3D/dFilter.
 *
 * @param args {inputs: {x, dy}, backend, attrs} with attrs
 *     {strides, pad, filterShape}.
 * @returns TensorInfo of shape filterShape holding the filter gradient.
 */
function conv3DBackpropFilterV2(args) {
    const { inputs, backend, attrs } = args;
    const { x, dy } = inputs;
    const { strides, pad, filterShape } = attrs;
    // Gradient kernels always run with dilation 1.
    const convInfo = computeConv3DInfo(x.shape, filterShape, strides, 1 /* dilations */, pad);
    return backend.runWebGLProgram(new Conv3DDerFilterProgram(convInfo), [x, dy], 'float32');
}
const conv3DBackpropFilterV2Config = {
    kernelName: Conv3DBackpropFilterV2,
    backendName: 'webgl',
    kernelFunc: conv3DBackpropFilterV2
};
95254
95255 /**
95256 * @license
95257 * Copyright 2020 Google LLC. All Rights Reserved.
95258 * Licensed under the Apache License, Version 2.0 (the "License");
95259 * you may not use this file except in compliance with the License.
95260 * You may obtain a copy of the License at
95261 *
95262 * http://www.apache.org/licenses/LICENSE-2.0
95263 *
95264 * Unless required by applicable law or agreed to in writing, software
95265 * distributed under the License is distributed on an "AS IS" BASIS,
95266 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95267 * See the License for the specific language governing permissions and
95268 * limitations under the License.
95269 * =============================================================================
95270 */
/**
 * WebGL kernel computing dConv3D/dInput (conv3d transpose).
 *
 * @param args {inputs: {dy, filter}, backend, attrs} with attrs
 *     {pad, strides, inputShape}.
 * @returns TensorInfo of shape inputShape holding the input gradient.
 */
function conv3DBackpropInput(args) {
    const { inputs, backend, attrs } = args;
    const { dy, filter } = inputs;
    const { pad, strides, inputShape } = attrs;
    // Gradient kernels always run with dilation 1.
    const convInfo = computeConv3DInfo(inputShape, filter.shape, strides, 1 /* dilations */, pad);
    return backend.runWebGLProgram(new Conv3DDerInputProgram(convInfo), [dy, filter], 'float32');
}
const conv3DBackpropInputConfig = {
    kernelName: Conv3DBackpropInputV2,
    backendName: 'webgl',
    kernelFunc: conv3DBackpropInput,
};
95284
95285 /**
95286 * @license
95287 * Copyright 2020 Google LLC. All Rights Reserved.
95288 * Licensed under the Apache License, Version 2.0 (the "License");
95289 * you may not use this file except in compliance with the License.
95290 * You may obtain a copy of the License at
95291 *
95292 * http://www.apache.org/licenses/LICENSE-2.0
95293 *
95294 * Unless required by applicable law or agreed to in writing, software
95295 * distributed under the License is distributed on an "AS IS" BASIS,
95296 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95297 * See the License for the specific language governing permissions and
95298 * limitations under the License.
95299 * =============================================================================
95300 */
// GLSL body for the unpacked Cos kernel; CHECK_NAN_SNIPPET_UNARY is
// prepended so NaN inputs yield NaN outputs.
const COS = CHECK_NAN_SNIPPET_UNARY + `
  return cos(x);
`;
// Packed (vec4) variant with explicit per-lane NaN propagation via
// CHECK_NAN_SNIPPET_PACKED.
const COS_PACKED = `
  vec4 result = cos(x);
  bvec4 isNaN = isnan(x);
  ${CHECK_NAN_SNIPPET_PACKED}
  return result;
`;
// WebGL kernel and registration config for the Cos op.
const cos = unaryKernelFunc({ opSnippet: COS, packedOpSnippet: COS_PACKED });
const cosConfig = {
    kernelName: Cos,
    backendName: 'webgl',
    kernelFunc: cos,
};
95316
95317 /**
95318 * @license
95319 * Copyright 2020 Google LLC. All Rights Reserved.
95320 * Licensed under the Apache License, Version 2.0 (the "License");
95321 * you may not use this file except in compliance with the License.
95322 * You may obtain a copy of the License at
95323 *
95324 * http://www.apache.org/licenses/LICENSE-2.0
95325 *
95326 * Unless required by applicable law or agreed to in writing, software
95327 * distributed under the License is distributed on an "AS IS" BASIS,
95328 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95329 * See the License for the specific language governing permissions and
95330 * limitations under the License.
95331 * =============================================================================
95332 */
// GLSL body for the Cosh kernel. With e2x = exp(-x),
// (e2x + 1.0 / e2x) / 2.0 = (e^-x + e^x) / 2 = cosh(x).
const COSH = `
  float e2x = exp(-x);
  return (e2x + 1.0 / e2x) / 2.0;
`;
// WebGL kernel and registration config for the Cosh op (unpacked only).
const cosh = unaryKernelFunc({ opSnippet: COSH });
const coshConfig = {
    kernelName: Cosh,
    backendName: 'webgl',
    kernelFunc: cosh,
};
95343
95344 /**
95345 * @license
95346 * Copyright 2017 Google LLC. All Rights Reserved.
95347 * Licensed under the Apache License, Version 2.0 (the "License");
95348 * you may not use this file except in compliance with the License.
95349 * You may obtain a copy of the License at
95350 *
95351 * http://www.apache.org/licenses/LICENSE-2.0
95352 *
95353 * Unless required by applicable law or agreed to in writing, software
95354 * distributed under the License is distributed on an "AS IS" BASIS,
95355 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95356 * See the License for the specific language governing permissions and
95357 * limitations under the License.
95358 * =============================================================================
95359 */
/**
 * WebGL program implementing CropAndResize: extracts one crop per row of
 * `Boxes` (normalized [y1, x1, y2, x2] coordinates) from the image batch
 * selected by `BoxInd`, and resizes each crop to [cropHeight, cropWidth]
 * using bilinear (methodId 1) or nearest-neighbor (methodId 0) sampling.
 * Samples outside the source image produce `extrapolationValue`.
 */
class CropAndResizeProgram {
    constructor(imageShape, boxShape, cropSize, method, extrapolationValue) {
        this.variableNames = ['Image', 'Boxes', 'BoxInd'];
        this.outputShape = [];
        const [batch, imageHeight, imageWidth, depth] = imageShape;
        const [numBoxes,] = boxShape;
        const [cropHeight, cropWidth] = cropSize;
        this.outputShape = [numBoxes, cropHeight, cropWidth, depth];
        // 1 selects bilinear sampling in the shader; 0 selects nearest.
        const methodId = method === 'bilinear' ? 1 : 0;
        const [inputHeightFloat, inputWidthFloat] = [`${imageHeight - 1}.0`, `${imageWidth - 1}.0`];
        // When the crop has a single row there is no vertical interpolation
        // axis; sample the vertical center of the box instead.
        const [heightRatio, heightScale, inY] = cropHeight > 1 ?
            [
                `${(imageHeight - 1) / (cropHeight - 1)}`,
                '(y2-y1) * height_ratio',
                `y1*${inputHeightFloat} + float(y)*(height_scale)`,
            ] :
            [
                '0.0',
                '0.0',
                `0.5 * (y1+y2) * ${inputHeightFloat}`,
            ];
        // Same treatment for a single-column crop.
        const [widthRatio, widthScale, inX] = cropWidth > 1 ?
            [
                `${(imageWidth - 1) / (cropWidth - 1)}`,
                '(x2-x1) * width_ratio',
                `x1*${inputWidthFloat} + float(x)*(width_scale)`,
            ] :
            [
                '0.0',
                '0.0',
                `0.5 * (x1+x2) * ${inputWidthFloat}`,
            ];
        // Reference implementation
        // tslint:disable-next-line:max-line-length
        // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
        this.userCode = `
      const float height_ratio = float(${heightRatio});
      const float width_ratio = float(${widthRatio});
      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int y = coords[1];
        int x = coords[2];
        int d = coords[3];

        // get box vals
        float y1 = getBoxes(b,0);
        float x1 = getBoxes(b,1);
        float y2 = getBoxes(b,2);
        float x2 = getBoxes(b,3);

        // get image in batch index
        int bInd = round(getBoxInd(b));
        if(bInd < 0 || bInd >= ${batch}) {
          return;
        }

        float height_scale = ${heightScale};
        float width_scale = ${widthScale};

        float in_y = ${inY};
        if( in_y < 0.0 || in_y > ${inputHeightFloat} ) {
          setOutput(float(${extrapolationValue}));
          return;
        }
        float in_x = ${inX};
        if( in_x < 0.0 || in_x > ${inputWidthFloat} ) {
          setOutput(float(${extrapolationValue}));
          return;
        }

        vec2 sourceFracIndexCR = vec2(in_x,in_y);
        if(${methodId} == 1) {
          // Compute the four integer indices.
          ivec2 sourceFloorCR = ivec2(sourceFracIndexCR);
          ivec2 sourceCeilCR = ivec2(ceil(sourceFracIndexCR));

          float topLeft = getImage(b, sourceFloorCR.y, sourceFloorCR.x, d);
          float bottomLeft = getImage(b, sourceCeilCR.y, sourceFloorCR.x, d);
          float topRight = getImage(b, sourceFloorCR.y, sourceCeilCR.x, d);
          float bottomRight = getImage(b, sourceCeilCR.y, sourceCeilCR.x, d);

          vec2 fracCR = sourceFracIndexCR - vec2(sourceFloorCR);

          float top = topLeft + (topRight - topLeft) * fracCR.x;
          float bottom = bottomLeft + (bottomRight - bottomLeft) * fracCR.x;
          float newValue = top + (bottom - top) * fracCR.y;
          setOutput(newValue);
        } else {
          // Compute the coordinators of nearest neighbor point.
          ivec2 sourceNearestCR = ivec2(floor(
            sourceFracIndexCR + vec2(0.5,0.5)));
          float newValue = getImage(b, sourceNearestCR.y, sourceNearestCR.x, d);
          setOutput(newValue);
        }
      }
    `;
    }
}
95459
95460 /**
95461 * @license
95462 * Copyright 2020 Google LLC. All Rights Reserved.
95463 * Licensed under the Apache License, Version 2.0 (the "License");
95464 * you may not use this file except in compliance with the License.
95465 * You may obtain a copy of the License at
95466 *
95467 * http://www.apache.org/licenses/LICENSE-2.0
95468 *
95469 * Unless required by applicable law or agreed to in writing, software
95470 * distributed under the License is distributed on an "AS IS" BASIS,
95471 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95472 * See the License for the specific language governing permissions and
95473 * limitations under the License.
95474 * =============================================================================
95475 */
/**
 * WebGL kernel for CropAndResize: crops boxes out of a batch of images and
 * resizes each crop to the requested size.
 *
 * @param args {inputs: {image, boxes, boxInd}, backend, attrs} with attrs
 *     {cropSize, method, extrapolationValue}.
 * @returns TensorInfo of shape [numBoxes, cropH, cropW, depth].
 */
const cropAndResize = (args) => {
    const { inputs, backend, attrs } = args;
    const { image, boxes, boxInd } = inputs;
    const { cropSize, method, extrapolationValue } = attrs;
    return backend.runWebGLProgram(
        new CropAndResizeProgram(image.shape, boxes.shape, cropSize, method, extrapolationValue),
        [image, boxes, boxInd], 'float32');
};
const cropAndResizeConfig = {
    kernelName: CropAndResize,
    backendName: 'webgl',
    kernelFunc: cropAndResize
};
95488
// Cumulative op variants supported by CumProgram. The op token is spliced
// directly into the generated shader ('*' for cumprod, '+' for cumsum).
var CumOpType;
(function (CumOpType) {
    CumOpType["Prod"] = "*";
    CumOpType["Sum"] = "+";
})(CumOpType || (CumOpType = {}));
/**
 * WebGL program implementing one step of a parallel scan (cumsum/cumprod)
 * along the innermost axis. The host passes the current scan step through
 * the `index` uniform; the shader combines each element with the element
 * pow2 = 2^index positions away (or, in exclusive mode, simply shifts by
 * one position).
 */
class CumProgram {
    constructor(op, outputShape, exclusive, reverse) {
        this.op = op;
        this.outputShape = outputShape;
        this.variableNames = ['x'];
        this.customUniforms = [{ name: 'index', type: 'float' }];
        const rank = this.outputShape.length;
        // Identity element: multiplicative for cumprod, additive for cumsum.
        const initVal = this.op === CumOpType.Prod ? '1.0' : '0.0';
        const val = exclusive ? initVal : `getX(${getCoords(rank, 'coords', this.op)})`;
        // Scan runs along the last axis; `length` is its extent.
        const length = this.outputShape[this.outputShape.length - 1];
        let condition = '';
        let idxString = '';
        // When exclusive is set, the cum op becomes roll op that copies the
        // value from the previous index based on the direction specified by the
        // reverse flag.
        if (exclusive) {
            condition = reverse ? `end != ${length - 1}` : 'end != 0';
            idxString = reverse ? 'end + 1' : 'end - 1';
        }
        else {
            condition = reverse ? `end + pow2 < ${length}` : 'end >= pow2';
            idxString = (reverse ? 'end + pow2' : 'end - pow2');
        }
        this.userCode = `
      void main() {
        ${getCoordsDataType(rank)} coords = getOutputCoords();
        int end = ${getFinalCoord(rank, 'coords', this.op)};
        float val = ${val};
        int pow2 = int(pow(2.0, index));
        if (${condition}) {
          int idx = ${idxString};
          ${getFinalCoord(rank, 'coords', this.op)} = idx;
          val ${this.op}= getX(${getCoords(rank, 'coords', this.op)});
        }
        setOutput(val);
      }
    `;
    }
}
/**
 * Builds the comma-separated GLSL argument list addressing all components of
 * the coordinate variable `name` for the given rank (1-4).
 * Throws for unsupported ranks, naming the cumulative op in the message.
 */
function getCoords(rank, name, op) {
    switch (rank) {
        case 1:
            return `${name}`;
        case 2:
            return `${name}.x, ${name}.y`;
        case 3:
            return `${name}.x, ${name}.y, ${name}.z`;
        case 4:
            return `${name}.x, ${name}.y, ${name}.z, ${name}.w`;
        default:
            throw new Error(`Cumulative ${op} for rank ${rank} is not yet supported`);
    }
}
/**
 * Returns the GLSL expression for the inner-most component of the coordinate
 * variable `name` for the given rank (1-4): the whole variable for rank 1,
 * otherwise the last vector component (.y/.z/.w).
 * Throws for unsupported ranks, naming the cumulative op in the message.
 */
function getFinalCoord(rank, name, op) {
    const suffixByRank = { 1: '', 2: '.y', 3: '.z', 4: '.w' };
    const suffix = suffixByRank[rank];
    if (suffix === undefined) {
        throw new Error(`Cumulative ${op} for rank ${rank} is not yet supported`);
    }
    return `${name}${suffix}`;
}
95567
95568 /**
95569 * @license
95570 * Copyright 2022 Google LLC. All Rights Reserved.
95571 * Licensed under the Apache License, Version 2.0 (the "License");
95572 * you may not use this file except in compliance with the License.
95573 * You may obtain a copy of the License at
95574 *
95575 * http://www.apache.org/licenses/LICENSE-2.0
95576 *
95577 * Unless required by applicable law or agreed to in writing, software
95578 * distributed under the License is distributed on an "AS IS" BASIS,
95579 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95580 * See the License for the specific language governing permissions and
95581 * limitations under the License.
95582 * =============================================================================
95583 */
/**
 * Shared WebGL implementation for cumsum/cumprod.
 *
 * Transposes `axis` to the inner-most position if necessary, runs
 * ceil(log2(size)) passes of the CumProgram parallel scan, optionally runs one
 * extra shift pass for exclusive mode, then transposes back. Intermediate
 * tensors are disposed as soon as they are superseded.
 *
 * op        - CumOpType.Sum or CumOpType.Prod.
 * x         - input TensorInfo.
 * backend   - the WebGL backend owning the textures.
 * axis      - axis to scan along.
 * exclusive - whether the scan excludes the current element.
 * reverse   - whether to scan from the end towards the start.
 */
function cumImpl(op, x, backend, axis, exclusive, reverse) {
    const xRank = x.shape.length;
    // Move `axis` to the inner-most position; the shader only scans the
    // last axis. `permutation` is null when `axis` is already inner-most.
    const permutation = getAxesPermutation([axis], xRank);
    let permutedX = x;
    if (permutation != null) {
        permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutation } });
    }
    const permutedAxis = getInnerMostAxes(1, xRank)[0];
    if (permutedAxis !== xRank - 1) {
        // NOTE(review): message says "cumprod" but this path serves both
        // cumsum and cumprod.
        throw new Error(`WebGL cumprod shader expects an inner-most axis=${x.shape.length - 1} ` +
            `but got axis=${axis}`);
    }
    const size = permutedX.shape[permutedAxis];
    // Start from a copy so the input is never mutated in place.
    let result = identity({ inputs: { x: permutedX }, backend });
    // Use cum parallel algorithm, inspired by:
    // https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
    // Note: although the algorithm is called sum, it works for any associative
    // operator with an identity.
    for (let i = 0; i <= Math.ceil(Math.log2(size)) - 1; i++) {
        const program = new CumProgram(op, permutedX.shape, false, reverse);
        // Pass the step counter i as the `index` uniform (stride = 2^i).
        const customValues = [[i]];
        const prevResult = result;
        result =
            backend.runWebGLProgram(program, [result], result.dtype, customValues);
        backend.disposeIntermediateTensorInfo(prevResult);
    }
    // For exclusive cum, shift the end result in the direction of product or sum
    // and add 1 for product or 0 for sum to the front index.
    if (exclusive) {
        const program = new CumProgram(op, permutedX.shape, exclusive, reverse);
        const prevResult = result;
        result = backend.runWebGLProgram(program, [result], result.dtype);
        backend.disposeIntermediateTensorInfo(prevResult);
    }
    // Undo the transpose, releasing both intermediates.
    if (permutation != null) {
        const reversePermutation = getUndoAxesPermutation(permutation);
        const reverseTransposedResult = transpose({ inputs: { x: result }, backend, attrs: { perm: reversePermutation } });
        backend.disposeIntermediateTensorInfo(result);
        backend.disposeIntermediateTensorInfo(permutedX);
        return reverseTransposedResult;
    }
    return result;
}
95627
95628 /**
95629 * @license
95630 * Copyright 2022 Google LLC. All Rights Reserved.
95631 * Licensed under the Apache License, Version 2.0 (the "License");
95632 * you may not use this file except in compliance with the License.
95633 * You may obtain a copy of the License at
95634 *
95635 * http://www.apache.org/licenses/LICENSE-2.0
95636 *
95637 * Unless required by applicable law or agreed to in writing, software
95638 * distributed under the License is distributed on an "AS IS" BASIS,
95639 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95640 * See the License for the specific language governing permissions and
95641 * limitations under the License.
95642 * =============================================================================
95643 */
/**
 * WebGL kernel for Cumprod: cumulative product along `axis`,
 * delegating to the shared scan implementation with the Prod operator.
 */
function cumprod(args) {
    const { x } = args.inputs;
    const { axis, exclusive, reverse } = args.attrs;
    return cumImpl(CumOpType.Prod, x, args.backend, axis, exclusive, reverse);
}
// Kernel registration entry binding Cumprod to the WebGL backend.
const cumprodConfig = {
    kernelName: Cumprod,
    backendName: 'webgl',
    kernelFunc: cumprod
};
95655
95656 /**
95657 * @license
95658 * Copyright 2022 Google LLC. All Rights Reserved.
95659 * Licensed under the Apache License, Version 2.0 (the "License");
95660 * you may not use this file except in compliance with the License.
95661 * You may obtain a copy of the License at
95662 *
95663 * http://www.apache.org/licenses/LICENSE-2.0
95664 *
95665 * Unless required by applicable law or agreed to in writing, software
95666 * distributed under the License is distributed on an "AS IS" BASIS,
95667 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95668 * See the License for the specific language governing permissions and
95669 * limitations under the License.
95670 * =============================================================================
95671 */
/**
 * WebGL kernel for Cumsum: cumulative sum along `axis`,
 * delegating to the shared scan implementation with the Sum operator.
 */
function cumsum(args) {
    const { x } = args.inputs;
    const { axis, exclusive, reverse } = args.attrs;
    return cumImpl(CumOpType.Sum, x, args.backend, axis, exclusive, reverse);
}
// Kernel registration entry binding Cumsum to the WebGL backend.
const cumsumConfig = {
    kernelName: Cumsum,
    backendName: 'webgl',
    kernelFunc: cumsum
};
95683
95684 /**
95685 * @license
95686 * Copyright 2020 Google LLC. All Rights Reserved.
95687 * Licensed under the Apache License, Version 2.0 (the "License");
95688 * you may not use this file except in compliance with the License.
95689 * You may obtain a copy of the License at
95690 *
95691 * http://www.apache.org/licenses/LICENSE-2.0
95692 *
95693 * Unless required by applicable law or agreed to in writing, software
95694 * distributed under the License is distributed on an "AS IS" BASIS,
95695 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95696 * See the License for the specific language governing permissions and
95697 * limitations under the License.
95698 * =============================================================================
95699 */
/**
 * WebGL kernel for DenseBincount. Counts (optionally weighted) occurrences of
 * the integer values in `x` into `size` bins.
 *
 * Rank-1 input produces a [size] output; rank-2 input is binned row-wise.
 * Both paths read the data back and compute on the CPU (no GPU program).
 * Throws for inputs of rank > 2.
 */
function denseBincount(args) {
    const { inputs, backend, attrs } = args;
    const { x, weights } = inputs;
    const { size, binaryOutput } = attrs;
    if (x.shape.length === 1) {
        // Rank 1: download values and count on the CPU.
        const xVals = backend.readSync(x.dataId);
        const weightsVals = backend.readSync(weights.dataId);
        const outVals = bincountImplCPU(xVals, weightsVals, weights.dtype, weights.shape, size);
        return backend.makeTensorInfo([size], weights.dtype, outVals);
    }
    else if (x.shape.length === 2) {
        // Rank 2: bin each row independently on the CPU.
        const xBuf = backend.bufferSync(x);
        const weightsBuf = backend.bufferSync(weights);
        const outBuf = bincountReduceImplCPU(xBuf, weightsBuf, size, binaryOutput);
        return backend.makeTensorInfo(outBuf.shape, weights.dtype, outBuf.values);
    }
    // Fix: the original concatenation was missing a space and rendered as
    // "but got rank2."
    throw new Error(`Error in denseBincount: input must be at most rank 2, but got rank ` +
        `${x.shape.length}.`);
}
// Kernel registration entry binding DenseBincount to the WebGL backend.
const denseBincountConfig = {
    kernelName: DenseBincount,
    backendName: 'webgl',
    kernelFunc: denseBincount
};
95724
95725 /**
95726 * @license
95727 * Copyright 2018 Google LLC. All Rights Reserved.
95728 * Licensed under the Apache License, Version 2.0 (the "License");
95729 * you may not use this file except in compliance with the License.
95730 * You may obtain a copy of the License at
95731 *
95732 * http://www.apache.org/licenses/LICENSE-2.0
95733 *
95734 * Unless required by applicable law or agreed to in writing, software
95735 * distributed under the License is distributed on an "AS IS" BASIS,
95736 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95737 * See the License for the specific language governing permissions and
95738 * limitations under the License.
95739 * =============================================================================
95740 */
/**
 * GPGPU program that rearranges depth channels into spatial blocks
 * (the DepthToSpace op). Supports both NHWC and NCHW layouts; the helper
 * methods select the coordinate component holding each logical dimension
 * for the chosen layout.
 *
 * Fix: removed the dead `this.outputShape = []` assignment that was
 * immediately overwritten by the constructor argument.
 */
class DepthToSpaceProgram {
    constructor(outputShape, blockSize, dataFormat) {
        this.variableNames = ['x'];
        this.outputShape = outputShape;
        this.blockSize = blockSize;
        this.dataFormat = dataFormat;
        this.userCode = `
    void main() {
      ivec4 coords = getOutputCoords();
      int b = coords[0];
      int h = ${this.getHeightCoordString()};
      int w = ${this.getWidthCoordString()};
      int d = ${this.getDepthCoordString()};

      int in_h = h / ${blockSize};
      int offset_h = imod(h, ${blockSize});
      int in_w = w / ${blockSize};
      int offset_w = imod(w, ${blockSize});
      int offset_d = (offset_h * ${blockSize} + offset_w) *
        ${this.getOutputDepthSize()};
      int in_d = d + offset_d;

      float result = ${this.getInputSamplingString()};
      setOutput(result);
    }
  `;
    }
    // GLSL expression for the output height coordinate in this layout.
    getHeightCoordString() {
        if (this.dataFormat === 'NHWC') {
            return `coords[1]`;
        }
        else {
            return `coords[2]`;
        }
    }
    // GLSL expression for the output width coordinate in this layout.
    getWidthCoordString() {
        if (this.dataFormat === 'NHWC') {
            return `coords[2]`;
        }
        else {
            return `coords[3]`;
        }
    }
    // GLSL expression for the output depth coordinate in this layout.
    getDepthCoordString() {
        if (this.dataFormat === 'NHWC') {
            return `coords[3]`;
        }
        else {
            return `coords[1]`;
        }
    }
    // Output channel count, read from the layout-dependent shape position.
    getOutputDepthSize() {
        if (this.dataFormat === 'NHWC') {
            return this.outputShape[3];
        }
        else {
            return this.outputShape[1];
        }
    }
    // Input sampling call with arguments ordered for this layout.
    getInputSamplingString() {
        if (this.dataFormat === 'NHWC') {
            return `getX(b, in_h, in_w, in_d)`;
        }
        else {
            return `getX(b, in_d, in_h, in_w)`;
        }
    }
}
95810
95811 /**
95812 * @license
95813 * Copyright 2020 Google LLC. All Rights Reserved.
95814 * Licensed under the Apache License, Version 2.0 (the "License");
95815 * you may not use this file except in compliance with the License.
95816 * You may obtain a copy of the License at
95817 *
95818 * http://www.apache.org/licenses/LICENSE-2.0
95819 *
95820 * Unless required by applicable law or agreed to in writing, software
95821 * distributed under the License is distributed on an "AS IS" BASIS,
95822 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95823 * See the License for the specific language governing permissions and
95824 * limitations under the License.
95825 * =============================================================================
95826 */
/**
 * WebGL kernel for DepthToSpace: moves data from the channel dimension into
 * blockSize x blockSize spatial blocks. Computes the output shape for the
 * requested layout and runs the DepthToSpaceProgram.
 */
function depthToSpace(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { blockSize, dataFormat } = attrs;
    const isNHWC = dataFormat === 'NHWC';
    const batchSize = x.shape[0];
    const inputHeight = isNHWC ? x.shape[1] : x.shape[2];
    const inputWidth = isNHWC ? x.shape[2] : x.shape[3];
    const inputDepth = isNHWC ? x.shape[3] : x.shape[1];
    // Each output pixel block is blockSize x blockSize, so depth shrinks by
    // blockSize^2 while height and width grow by blockSize.
    const outputHeight = inputHeight * blockSize;
    const outputWidth = inputWidth * blockSize;
    const outputDepth = inputDepth / (blockSize * blockSize);
    const outputShape = isNHWC ?
        [batchSize, outputHeight, outputWidth, outputDepth] :
        [batchSize, outputDepth, outputHeight, outputWidth];
    const program = new DepthToSpaceProgram(outputShape, blockSize, dataFormat);
    return backend.runWebGLProgram(program, [x], x.dtype);
}
// Kernel registration entry binding DepthToSpace to the WebGL backend.
const depthToSpaceConfig = {
    kernelName: DepthToSpace,
    backendName: 'webgl',
    kernelFunc: depthToSpace
};
95849
95850 /**
95851 * @license
95852 * Copyright 2017 Google LLC. All Rights Reserved.
95853 * Licensed under the Apache License, Version 2.0 (the "License");
95854 * you may not use this file except in compliance with the License.
95855 * You may obtain a copy of the License at
95856 *
95857 * http://www.apache.org/licenses/LICENSE-2.0
95858 *
95859 * Unless required by applicable law or agreed to in writing, software
95860 * distributed under the License is distributed on an "AS IS" BASIS,
95861 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95862 * See the License for the specific language governing permissions and
95863 * limitations under the License.
95864 * =============================================================================
95865 */
/**
 * GPGPU program for an unpacked depthwise 2D convolution.
 *
 * Each shader invocation computes one output element y(batch, yR, yC, d2) by
 * convolving input channel d1 = d2 / channelMul with filter slice
 * w(:, :, d1, q), where q = d2 % channelMul. Optional bias addition and an
 * optional fused activation (plain, PReLU, or leaky-ReLU variants) are
 * spliced into the shader source.
 *
 * Padding, strides, dilations and the input spatial dims arrive as uniforms
 * so the same compiled shader can be reused across sizes.
 */
class DepthwiseConv2DProgram {
    constructor(convInfo, addBias = false, activation = null, hasPreluActivation = false, hasLeakyReluAlpha = false) {
        this.variableNames = ['x', 'W'];
        this.customUniforms = [
            { name: 'pads', type: 'ivec2' },
            { name: 'strides', type: 'ivec2' },
            { name: 'dilations', type: 'ivec2' },
            { name: 'inDims', type: 'ivec2' },
        ];
        this.outputShape = convInfo.outShape;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        const filterHeight = convInfo.filterHeight;
        const filterWidth = convInfo.filterWidth;
        // Depth multiplier: output channels produced per input channel.
        const channelMul = convInfo.outChannels / convInfo.inChannels;
        // Build the fused-activation GLSL. The PReLU/leaky-ReLU variants bind
        // a second operand `b` read from an extra input texture.
        let activationSnippet = '', applyActivationSnippet = '';
        if (activation) {
            if (hasPreluActivation) {
                activationSnippet = `float activation(float a) {
          float b = getPreluActivationWeightsAtOutCoords();
          ${activation}
        }`;
            }
            else if (hasLeakyReluAlpha) {
                activationSnippet = `float activation(float a) {
          float b = getLeakyreluAlphaAtOutCoords();
          ${activation}
        }`;
            }
            else {
                activationSnippet = `
          float activation(float x) {
            ${activation}
          }
        `;
            }
            applyActivationSnippet = `result = activation(result);`;
        }
        const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
        // Register the extra input textures in the same order the backend
        // will supply them.
        if (addBias) {
            this.variableNames.push('bias');
        }
        if (hasPreluActivation) {
            this.variableNames.push('preluActivationWeights');
        }
        if (hasLeakyReluAlpha) {
            this.variableNames.push('leakyreluAlpha');
        }
        this.userCode = `
      ${activationSnippet}

      void main() {
        ivec4 coords = getOutputCoords();
        int batch = coords.x;
        ivec2 xRCCorner = coords.yz * strides - pads;
        int d2 = coords.w;
        int d1 = d2 / ${channelMul};
        int q = d2 - d1 * ${channelMul};

        int xRCorner = xRCCorner.x;
        int xCCorner = xRCCorner.y;

        // Convolve x(?, ?, d1) with w(:, :, d1, q) to get y(yR, yC, d2).
        // ? = to be determined. : = across all values in that axis.
        float dotProd = 0.0;
        // TO DO(dsmilkov): Flatten the two for loops and vec4 the operations.
        for (int wR = 0; wR < ${filterHeight}; wR++) {
          int xR = xRCorner + wR * dilations[0];

          if (xR < 0 || xR >= inDims[0]) {
            continue;
          }

          for (int wC = 0; wC < ${filterWidth}; wC++) {
            int xC = xCCorner + wC * dilations[1];

            if (xC < 0 || xC >= inDims[1]) {
              continue;
            }

            float xVal = getX(batch, xR, xC, d1);
            float wVal = getW(wR, wC, d1, q);
            dotProd += xVal * wVal;
          }
        }

        float result = dotProd;
        ${addBiasSnippet}
        ${applyActivationSnippet}
        setOutput(result);
      }
    `;
    }
}
95959
95960 /**
95961 * @license
95962 * Copyright 2018 Google LLC. All Rights Reserved.
95963 * Licensed under the Apache License, Version 2.0 (the "License");
95964 * you may not use this file except in compliance with the License.
95965 * You may obtain a copy of the License at
95966 *
95967 * http://www.apache.org/licenses/LICENSE-2.0
95968 *
95969 * Unless required by applicable law or agreed to in writing, software
95970 * distributed under the License is distributed on an "AS IS" BASIS,
95971 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
95972 * See the License for the specific language governing permissions and
95973 * limitations under the License.
95974 * =============================================================================
95975 */
/**
 * GPGPU program for a packed (vec4) depthwise 2D convolution.
 *
 * Inputs and output use the packed texture layout where each texel holds a
 * 2x2 block of values. The constructor unrolls the filter loops at shader
 * compile time, generating code that shares texel fetches between adjacent
 * output columns wherever stride/dilation/padding parity allows. Only
 * strideWidth 1 and 2 are generated (callers gate on strideWidth <= 2).
 * Optional bias and fused activation snippets mirror the unpacked program.
 */
class DepthwiseConvPacked2DProgram {
    constructor(convInfo, addBias = false, activation = null, hasPreluActivation = false, hasLeakyReluAlpha = false) {
        this.variableNames = ['x', 'W'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.customUniforms = [
            { name: 'pads', type: 'ivec2' },
            { name: 'strides', type: 'ivec2' },
            { name: 'dilations', type: 'ivec2' },
            { name: 'inDims', type: 'ivec2' },
        ];
        this.outputShape = convInfo.outShape;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        // Depth multiplier; the packed path is only selected when this is 1.
        const channelMul = convInfo.outChannels / convInfo.inChannels;
        const padLeft = convInfo.padInfo.left;
        const strideWidth = convInfo.strideWidth;
        const dilationWidth = convInfo.dilationWidth;
        const filterHeight = convInfo.filterHeight;
        const filterWidth = convInfo.filterWidth;
        const texelsAcross = filterWidth;
        // Scratch declarations shared by all unrolled iterations.
        let mainLoop = `
      int xR; int xC; int xCOffset;
      vec4 wTexel; vec4 previous; vec4 final;`;
        for (let c = 0; c < filterWidth; c++) {
            mainLoop += `
          vec4 xTexelC${c * 2};
          int xTexelC${c * 2}Ready;
          vec4 xTexelC${c * 2 + 1};
          int xTexelC${c * 2 + 1}Ready;
          vec4 xC${c};`;
        }
        /**
         * This vectorized implementation works by gathering the values needed for
         * each output channel's dot product into vec4's and then multiplying them
         * all together (this happens in the final double for-loop below). Most of
         * the main loop consists of constructing these vec4's with the minimum
         * number of texture2D calls, which means making use of all four returned
         * values from a texture2D call at once.
         */
        mainLoop += `
    for (int r = 0; r < ${filterHeight}; r++) {
      `;
        // Reset per-row texel caches so stale values from the previous filter
        // row are never reused.
        for (let c = 0; c < filterWidth; c++) {
            mainLoop += `
          xTexelC${c * 2} = vec4(0.0);
          xTexelC${c * 2}Ready = 0;
          xTexelC${c * 2 + 1} = vec4(0.0);
          xTexelC${c * 2 + 1}Ready = 0;
          xC${c} = vec4(0.0);`;
        }
        mainLoop += `
        xR = xRCorner + r * dilations[0];
        if (xR >=0 && xR < inDims[0]) {
      `;
        // Each texelC iteration covers two adjacent filter columns.
        for (let texelC = 0; texelC < (texelsAcross + 1) / 2; texelC++) {
            const colIndex = texelC * 2;
            mainLoop += `
          xC = xCCorner + ${colIndex * dilationWidth};
          `;
            if (strideWidth === 1) {
                if (colIndex < filterWidth) {
                    // If padding is odd, the outer texels have to be composed.
                    if (padLeft % 2 === 1) {
                        // TODO: Ensure vec4 previous does not result in redundant sample,
                        // and avoid setting xTexelRC's that exceed the boundary in the
                        // first place rather than resetting them to vec4(0)).
                        // To compute xCOffset:
                        // - If padding is odd, we must add 1 to ensure we ask for an
                        // even-numbered row.
                        // - We subtract 2 to access the previous texel.
                        mainLoop += `
            xCOffset = xC + 1;
            if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex}Ready == 0) {
              xTexelC${colIndex} = getX(batch, xR, xCOffset, d1);

              // Need to manually clear unused channels in case
              // we're reading from recycled texture.
              if (xCOffset + 1 >= inDims[1]) {
                xTexelC${colIndex}.zw = vec2(0.0);
              }
              xTexelC${colIndex}Ready = 1;
            }
          `;
                        // This texel has been read in previous iteration if the dilation
                        // is 1.
                        if (dilationWidth === 1 && colIndex > 0) {
                            mainLoop += `
            xC${colIndex} = vec4(xTexelC${colIndex - 2}.zw, xTexelC${colIndex}.xy);
            `;
                        }
                        else {
                            mainLoop += `
              xCOffset = xC + 1 - 2;

              if (xCOffset >= 0 && xCOffset < inDims[1]) {
                previous = getX(batch, xR, xCOffset, d1);

                // Need to manually clear unused channels in case
                // we're reading from recycled texture.
                if (xCOffset + 1 >= inDims[1]) {
                  previous.zw = vec2(0.0);
                }

                xC${colIndex} = vec4(previous.zw, xTexelC${colIndex}.xy);
              } else {
                xC${colIndex} = vec4(0.0, 0.0, xTexelC${colIndex}.xy);
              }
              `;
                        }
                    }
                    else {
                        // Padding is even, so xRC corresponds to a single texel.
                        mainLoop += `
            if (xC >= 0 && xC < inDims[1] && xTexelC${colIndex}Ready == 0) {
              xTexelC${colIndex} = getX(batch, xR, xC, d1);
              if (xC + 1 >= inDims[1]) {
                xTexelC${colIndex}.zw = vec2(0.0);
              }
              xTexelC${colIndex}Ready = 1;
            }

            xC${colIndex} = xTexelC${colIndex};
            `;
                    }
                    if (colIndex + 1 < filterWidth) {
                        // If dilation is even, the second entry should match the first
                        // (either both are composed or both are single samples). But if
                        // dilation is odd, then the second entry should be the opposite
                        // of the first (if the first is composed, the second is a single
                        // sample, and vice versa.)
                        const nextTexelOffset = padLeft % 2 === 0 ?
                            nearestLargerEven(dilationWidth) :
                            dilationWidth;
                        if ((dilationWidth % 2 === 0 && padLeft % 2 === 1) ||
                            (dilationWidth % 2 !== 0 && padLeft % 2 !== 1)) {
                            mainLoop += `
              xCOffset = xC + imod(pads[1], 2) + ${nextTexelOffset};

              if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
                xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);

                // Need to manually clear unused channels in case
                // we're reading from recycled texture.
                if (xCOffset + 1 >= inDims[1]) {
                  xTexelC${colIndex + 1}.zw = vec2(0.0);
                }
                xTexelC${colIndex + 1}Ready = 1;
              }
              `;
                            // If dilation > 1 then the xRC's will not be able to share any
                            // values, so each xRC will require two unique calls to getX.
                            if (dilationWidth > 1) {
                                mainLoop += `
                xCOffset -= 2;
                if (xCOffset >= 0 && xCOffset < inDims[1]) {
                  previous = getX(batch, xR, xCOffset, d1);
                  xC${colIndex + 1} = vec4(previous.zw, xTexelC${colIndex + 1}.xy);
                } else {
                  xC${colIndex + 1} = vec4(0.0, 0.0, xTexelC${colIndex + 1}.xy);
                }
                `;
                            }
                            else {
                                mainLoop += `
                xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.xy);
                `;
                            }
                        }
                        else {
                            // If dilation is 1 and padding is odd, we have already read the
                            // texel when constructing the previous x value. Here we can
                            // simply skip the texture read.
                            if (nextTexelOffset === 1) {
                                mainLoop += `
                xC${colIndex + 1} = xTexelC${colIndex};
                `;
                            }
                            else {
                                mainLoop += `
                xCOffset = xC + ${nextTexelOffset};

                if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
                  xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);
                  if (xCOffset + 1 >= inDims[1]) {
                    xTexelC${colIndex + 1}.zw = vec2(0.0);
                  }
                  xTexelC${colIndex + 1}Ready = 1;
                }

                xC${colIndex + 1} = xTexelC${colIndex + 1};
                `;
                            }
                        }
                    }
                }
            }
            else { // stride === 2
                if (colIndex < filterWidth) {
                    // Depending on whether padLeft is even or odd, we want either the
                    // xy or zw channels from X texels for xC${colIndex}. If padLeft is
                    // even, xC${colIndex +1} is simply the zw channels of texels we've
                    // already sampled. But if padLeft is odd, xC{$c + 1}.zw will
                    // need to come from the xy channels of a new texel, hence the `
                    // vec4
                    // final` initialized below.
                    if (padLeft % 2 === 1) {
                        mainLoop += `
              xCOffset = xC + 1 - strides[1];
              if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex}Ready == 0) {
                xTexelC${colIndex} = getX(batch, xR, xCOffset, d1);
                // Need to manually clear unused channels in case
                // we're reading from recycled texture.
                if (xCOffset + 1 >= inDims[1]) {
                  xTexelC${colIndex}.zw = vec2(0.0);
                }
                xTexelC${colIndex}Ready = 1;
              }

              if(xC + 1 >= 0 && xC + 1 < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
                xTexelC${colIndex + 1} = getX(batch, xR, xC + 1, d1);
                // Need to manually clear unused channels in case
                // we're reading from recycled texture.
                if (xC + 2 >= inDims[1]) {
                  xTexelC${colIndex + 1}.zw = vec2(0.0);
                }
                xTexelC${colIndex + 1}Ready = 1;
              }

              xC${colIndex} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.zw);
              `;
                        if (colIndex + 1 < filterWidth) {
                            mainLoop += `
                final = vec4(0.0);
                xCOffset = xC + 1 + strides[1];
                if(xCOffset >= 0 && xCOffset < inDims[1]) {
                  final = getX(batch, xR, xCOffset, d1);
                }
                xC${colIndex + 1} = vec4(xTexelC${colIndex + 1}.xy, final.xy);
                `;
                        }
                    }
                    else {
                        mainLoop += `
              if(xC >= 0 && xC < inDims[1] && xTexelC${colIndex}Ready == 0) {
                xTexelC${colIndex} = getX(batch, xR, xC, d1);
                if (xC + 1 >= inDims[1]) {
                  xTexelC${colIndex}.zw = vec2(0.0);
                }
                xTexelC${colIndex}Ready = 1;
              }

              xCOffset = xC + strides[1];
              if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
                xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);
                if (xCOffset + 1 >= inDims[1]) {
                  xTexelC${colIndex + 1}.zw = vec2(0.);
                }
                xTexelC${colIndex + 1}Ready = 1;
              }

              xC${colIndex} = vec4(
                xTexelC${colIndex}.xy, xTexelC${colIndex + 1}.xy);
              `;
                        if (colIndex + 1 < filterWidth) {
                            mainLoop += `
                xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.zw);
                `;
                        }
                    }
                }
            }
            // localize the dotProd accumulation within the loop, the theory is for
            // GPU with limited cache, accumulate sum across large amount of
            // variables will cause lots of cache misses. (i.e. 5x5 filter will have
            // 50 variables)
            if (colIndex < filterWidth) {
                mainLoop += `
            wTexel = getW(r, ${colIndex}, d1, q);
            dotProd += xC${colIndex} * vec4(wTexel.xz, wTexel.xz);
          `;
                if (colIndex + 1 < filterWidth) {
                    mainLoop += `
              wTexel = getW(r, ${colIndex + 1}, d1, q);
              dotProd += xC${colIndex + 1} * vec4(wTexel.xz, wTexel.xz);
            `;
                }
            }
        }
        // Close the `if (xR ...)` bounds check and the filter-row loop.
        mainLoop += `
    }
  `;
        mainLoop += `
    }
  `;
        // Fused-activation GLSL; PReLU/leaky-ReLU variants bind an extra
        // operand `b` from an additional input texture.
        let activationSnippet = '', applyActivationSnippet = '';
        if (activation) {
            if (hasPreluActivation) {
                activationSnippet = `vec4 activation(vec4 a) {
          vec4 b = getPreluActivationWeightsAtOutCoords();
          ${activation}
        }`;
            }
            else if (hasLeakyReluAlpha) {
                activationSnippet = `vec4 activation(vec4 a) {
          vec4 b = getLeakyreluAlphaAtOutCoords();
          ${activation}
        }`;
            }
            else {
                activationSnippet = `vec4 activation(vec4 x) {
          ${activation}
        }`;
            }
            applyActivationSnippet = `result = activation(result);`;
        }
        const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
        // Register extra input textures in the order the backend supplies them.
        if (addBias) {
            this.variableNames.push('bias');
        }
        if (hasPreluActivation) {
            this.variableNames.push('preluActivationWeights');
        }
        if (hasLeakyReluAlpha) {
            this.variableNames.push('leakyreluAlpha');
        }
        this.userCode = `
      ${activationSnippet}

      void main() {
        ivec4 coords = getOutputCoords();
        int batch = coords.x;
        ivec2 xRCCorner = coords.yz * strides - pads;
        int d2 = coords.w;
        int d1 = d2 / ${channelMul};
        int q = d2 - d1 * ${channelMul};
        int xRCorner = xRCCorner.x;
        int xCCorner = xRCCorner.y;

        //intialize dotProd with a small epsilon seems to reduce GPU accuracy loss.
        vec4 dotProd = vec4(0.000000000000001);

        ${mainLoop}

        vec4 result = dotProd - vec4(0.000000000000001);
        ${addBiasSnippet}
        ${applyActivationSnippet}
        setOutput(result);
      }
    `;
    }
}
96327
96328 /**
96329 * @license
96330 * Copyright 2020 Google LLC. All Rights Reserved.
96331 * Licensed under the Apache License, Version 2.0 (the "License");
96332 * you may not use this file except in compliance with the License.
96333 * You may obtain a copy of the License at
96334 *
96335 * http://www.apache.org/licenses/LICENSE-2.0
96336 *
96337 * Unless required by applicable law or agreed to in writing, software
96338 * distributed under the License is distributed on an "AS IS" BASIS,
96339 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96340 * See the License for the specific language governing permissions and
96341 * limitations under the License.
96342 * =============================================================================
96343 */
/**
 * WebGL kernel for DepthwiseConv2dNative. Validates that strides and
 * dilations are not both > 1, computes the convolution geometry, and runs
 * either the packed or unpacked depthwise program.
 */
function depthwiseConv2dNative(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter } = inputs;
    const { strides, pad, dilations, dimRoundingMode } = attrs;
    const $dilations = dilations == null ? [1, 1] : dilations;
    assert$1(eitherStridesOrDilationsAreOne(strides, $dilations), () => 'Error in depthwiseConv2d: Either strides or dilations must be ' +
        `1. Got strides ${strides} and dilations '${$dilations}'`);
    const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true /* depthwise */);
    // The packed program only supports stride <= 2 and depth multiplier 1,
    // and is gated behind the WEBGL_PACK_DEPTHWISECONV flag.
    const usePacked = env().getBool('WEBGL_PACK_DEPTHWISECONV') &&
        convInfo.strideWidth <= 2 &&
        convInfo.outChannels / convInfo.inChannels === 1;
    const program = usePacked ?
        new DepthwiseConvPacked2DProgram(convInfo) :
        new DepthwiseConv2DProgram(convInfo);
    // Uniform values consumed by both program variants.
    const customValues = [
        [convInfo.padInfo.top, convInfo.padInfo.left],
        [convInfo.strideHeight, convInfo.strideWidth],
        [convInfo.dilationHeight, convInfo.dilationWidth],
        [convInfo.inHeight, convInfo.inWidth]
    ];
    return backend.runWebGLProgram(program, [x, filter], 'float32', customValues);
}
// Kernel registration entry binding DepthwiseConv2dNative to the WebGL backend.
const depthwiseConv2dNativeConfig = {
    kernelName: DepthwiseConv2dNative,
    backendName: 'webgl',
    kernelFunc: depthwiseConv2dNative,
};
96376
96377 /**
96378 * @license
96379 * Copyright 2018 Google LLC. All Rights Reserved.
96380 * Licensed under the Apache License, Version 2.0 (the "License");
96381 * you may not use this file except in compliance with the License.
96382 * You may obtain a copy of the License at
96383 *
96384 * http://www.apache.org/licenses/LICENSE-2.0
96385 *
96386 * Unless required by applicable law or agreed to in writing, software
96387 * distributed under the License is distributed on an "AS IS" BASIS,
96388 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96389 * See the License for the specific language governing permissions and
96390 * limitations under the License.
96391 * =============================================================================
96392 */
96393 class DepthwiseConv2DDerFilterProgram {
96394 constructor(convInfo) {
96395 this.variableNames = ['x', 'dy'];
96396 this.outputShape = convInfo.filterShape;
96397 const strideHeight = convInfo.strideHeight;
96398 const strideWidth = convInfo.strideWidth;
96399 const padTop = convInfo.padInfo.top;
96400 const padLeft = convInfo.padInfo.left;
96401 const channelMul = convInfo.outChannels / convInfo.inChannels;
96402 this.userCode = `
96403 void main() {
96404 ivec4 coords = getOutputCoords();
96405 int wR = coords.x;
96406 int wC = coords.y;
96407 int d1 = coords.z;
96408 int dm = coords.w;
96409 int d2 = d1 * ${channelMul} + dm;
96410
96411 float dotProd = 0.0;
96412
96413 // TO DO: Vec4 over the batch size
96414 for (int b = 0; b < ${convInfo.batchSize}; b++) {
96415 for (int yR = 0; yR < ${convInfo.outHeight}; yR++) {
96416 int xR = wR + yR * ${strideHeight} - ${padTop};
96417
96418 if (xR < 0 || xR >= ${convInfo.inHeight}) {
96419 continue;
96420 }
96421
96422 for (int yC = 0; yC < ${convInfo.outWidth}; yC++) {
96423 int xC = wC + yC * ${strideWidth} - ${padLeft};
96424
96425 if (xC < 0 || xC >= ${convInfo.inWidth}) {
96426 continue;
96427 }
96428
96429 float dyValue = getDy(b, yR, yC, d2);
96430 float xValue = getX(b, xR, xC, d1);
96431 dotProd += (xValue * dyValue);
96432 }
96433 }
96434 }
96435 setOutput(dotProd);
96436 }
96437 `;
96438 }
96439 }
96440 class DepthwiseConv2DDerInputProgram {
96441 constructor(convInfo) {
96442 this.variableNames = ['dy', 'W'];
96443 this.outputShape = convInfo.inShape;
96444 const filterHeight = convInfo.filterHeight;
96445 const filterWidth = convInfo.filterWidth;
96446 const strideHeight = convInfo.strideHeight;
96447 const strideWidth = convInfo.strideWidth;
96448 const padTop = filterHeight - 1 - convInfo.padInfo.top;
96449 const padLeft = filterWidth - 1 - convInfo.padInfo.left;
96450 const channelMul = convInfo.outChannels / convInfo.inChannels;
96451 this.userCode = `
96452 const ivec2 pads = ivec2(${padTop}, ${padLeft});
96453
96454 void main() {
96455 ivec4 coords = getOutputCoords();
96456 int batch = coords[0];
96457 int d1 = coords[3];
96458 ivec2 dyCorner = coords.yz - pads;
96459 int dyRCorner = dyCorner.x;
96460 int dyCCorner = dyCorner.y;
96461
96462 float dotProd = 0.0;
96463
96464 for (int wR = 0; wR < ${filterHeight}; wR++) {
96465 float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
96466
96467 if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
96468 continue;
96469 }
96470 int idyR = int(dyR);
96471
96472 int wRPerm = ${filterHeight} - 1 - wR;
96473
96474 for (int wC = 0; wC < ${filterWidth}; wC++) {
96475 float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
96476
96477 if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
96478 fract(dyC) > 0.0) {
96479 continue;
96480 }
96481 int idyC = int(dyC);
96482
96483 int wCPerm = ${filterWidth} - 1 - wC;
96484
96485 // TO DO: Vec4 over the channelMul
96486 for (int dm = 0; dm < ${channelMul}; dm++) {
96487 int d2 = d1 * ${channelMul} + dm;
96488 float xValue = getDy(batch, idyR, idyC, d2);
96489 float wValue = getW(wRPerm, wCPerm, d1, dm);
96490 dotProd += xValue * wValue;
96491 }
96492 }
96493 }
96494 setOutput(dotProd);
96495 }
96496 `;
96497 }
96498 }
96499
96500 /**
96501 * @license
96502 * Copyright 2020 Google LLC. All Rights Reserved.
96503 * Licensed under the Apache License, Version 2.0 (the "License");
96504 * you may not use this file except in compliance with the License.
96505 * You may obtain a copy of the License at
96506 *
96507 * http://www.apache.org/licenses/LICENSE-2.0
96508 *
96509 * Unless required by applicable law or agreed to in writing, software
96510 * distributed under the License is distributed on an "AS IS" BASIS,
96511 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96512 * See the License for the specific language governing permissions and
96513 * limitations under the License.
96514 * =============================================================================
96515 */
96516 function depthwiseConv2dNativeBackpropFilter(args) {
96517 const { inputs, backend, attrs } = args;
96518 const { x, dy } = inputs;
96519 const { strides, dilations, pad, dimRoundingMode, filterShape } = attrs;
96520 const convInfo = computeConv2DInfo(x.shape, filterShape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
96521 const program = new DepthwiseConv2DDerFilterProgram(convInfo);
96522 return backend.runWebGLProgram(program, [x, dy], 'float32');
96523 }
96524 const depthwiseConv2dNativeBackpropFilterConfig = {
96525 kernelName: DepthwiseConv2dNativeBackpropFilter,
96526 backendName: 'webgl',
96527 kernelFunc: depthwiseConv2dNativeBackpropFilter
96528 };
96529
96530 /**
96531 * @license
96532 * Copyright 2020 Google LLC. All Rights Reserved.
96533 * Licensed under the Apache License, Version 2.0 (the "License");
96534 * you may not use this file except in compliance with the License.
96535 * You may obtain a copy of the License at
96536 *
96537 * http://www.apache.org/licenses/LICENSE-2.0
96538 *
96539 * Unless required by applicable law or agreed to in writing, software
96540 * distributed under the License is distributed on an "AS IS" BASIS,
96541 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96542 * See the License for the specific language governing permissions and
96543 * limitations under the License.
96544 * =============================================================================
96545 */
96546 function depthwiseConv2dNativeBackpropInput(args) {
96547 const { inputs, backend, attrs } = args;
96548 const { dy, filter } = inputs;
96549 const { strides, dilations, pad, dimRoundingMode, inputShape } = attrs;
96550 const convInfo = computeConv2DInfo(inputShape, filter.shape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
96551 const program = new DepthwiseConv2DDerInputProgram(convInfo);
96552 return backend.runWebGLProgram(program, [dy, filter], 'float32');
96553 }
96554 const depthwiseConv2dNativeBackpropInputConfig = {
96555 kernelName: DepthwiseConv2dNativeBackpropInput,
96556 backendName: 'webgl',
96557 kernelFunc: depthwiseConv2dNativeBackpropInput
96558 };
96559
96560 /**
96561 * @license
96562 * Copyright 2019 Google LLC. All Rights Reserved.
96563 * Licensed under the Apache License, Version 2.0 (the "License");
96564 * you may not use this file except in compliance with the License.
96565 * You may obtain a copy of the License at
96566 *
96567 * http://www.apache.org/licenses/LICENSE-2.0
96568 *
96569 * Unless required by applicable law or agreed to in writing, software
96570 * distributed under the License is distributed on an "AS IS" BASIS,
96571 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96572 * See the License for the specific language governing permissions and
96573 * limitations under the License.
96574 * =============================================================================
96575 */
96576 class DiagProgram {
96577 constructor(size) {
96578 this.variableNames = ['X'];
96579 this.outputShape = [size, size];
96580 this.userCode = `
96581 void main() {
96582 ivec2 coords = getOutputCoords();
96583 float val = coords[0] == coords[1] ? getX(coords[0]) : 0.0;
96584 setOutput(val);
96585 }
96586 `;
96587 }
96588 }
96589
96590 /**
96591 * @license
96592 * Copyright 2020 Google LLC. All Rights Reserved.
96593 * Licensed under the Apache License, Version 2.0 (the "License");
96594 * you may not use this file except in compliance with the License.
96595 * You may obtain a copy of the License at
96596 *
96597 * http://www.apache.org/licenses/LICENSE-2.0
96598 *
96599 * Unless required by applicable law or agreed to in writing, software
96600 * distributed under the License is distributed on an "AS IS" BASIS,
96601 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96602 * See the License for the specific language governing permissions and
96603 * limitations under the License.
96604 * =============================================================================
96605 */
96606 function diag(args) {
96607 const { inputs, backend } = args;
96608 const { x } = inputs;
96609 const outShape = [...x.shape, ...x.shape];
96610 const xSize = sizeFromShape(x.shape);
96611 const flat = reshape({ inputs: { x }, backend, attrs: { shape: [xSize] } });
96612 const program = new DiagProgram(xSize);
96613 const res = backend.runWebGLProgram(program, [flat], flat.dtype);
96614 const out = reshape({ inputs: { x: res }, backend, attrs: { shape: outShape } });
96615 backend.disposeIntermediateTensorInfo(flat);
96616 backend.disposeIntermediateTensorInfo(res);
96617 return out;
96618 }
96619 const diagConfig = {
96620 kernelName: Diag,
96621 backendName: 'webgl',
96622 kernelFunc: diag
96623 };
96624
96625 /**
96626 * @license
96627 * Copyright 2017 Google LLC. All Rights Reserved.
96628 * Licensed under the Apache License, Version 2.0 (the "License");
96629 * you may not use this file except in compliance with the License.
96630 * You may obtain a copy of the License at
96631 *
96632 * http://www.apache.org/licenses/LICENSE-2.0
96633 *
96634 * Unless required by applicable law or agreed to in writing, software
96635 * distributed under the License is distributed on an "AS IS" BASIS,
96636 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96637 * See the License for the specific language governing permissions and
96638 * limitations under the License.
96639 * =============================================================================
96640 */
96641 class Dilation2DProgram {
96642 constructor(convInfo) {
96643 this.variableNames = ['x', 'W'];
96644 this.outputShape = convInfo.outShape;
96645 const { inHeight, inWidth, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth } = convInfo;
96646 const { top: padTop, left: padLeft } = padInfo;
96647 this.userCode = `
96648 const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
96649 const ivec2 pads = ivec2(${padTop}, ${padLeft});
96650 const float neg_infinity = -3.4e38;
96651
96652 void main() {
96653 ivec4 coords = getOutputCoords();
96654 int batch = coords.x;
96655 int d1 = coords.w;
96656 ivec2 outTopLeftCorner =
96657 coords.yz * strides - pads;
96658 int hBeg = outTopLeftCorner.x;
96659 int wBeg = outTopLeftCorner.y;
96660
96661 float curVal = neg_infinity;
96662 for (int h = 0; h < ${filterHeight}; h++) {
96663 int hIn = hBeg + h * ${dilationHeight};
96664
96665 if (hIn >= 0 && hIn < ${inHeight}) {
96666 for (int w = 0; w < ${filterWidth}; w++) {
96667 int wIn = wBeg + w * ${dilationWidth};
96668
96669 if (wIn >= 0 && wIn < ${inWidth}) {
96670 float xVal = getX(batch, hIn, wIn, d1);
96671 float wVal = getW(h, w, d1);
96672
96673 float val = xVal + wVal;
96674 if (val > curVal) {
96675 curVal = val;
96676 }
96677 }
96678 }
96679 }
96680 }
96681
96682 float result = curVal;
96683 setOutput(result);
96684 }
96685 `;
96686 }
96687 }
96688
96689 /**
96690 * @license
96691 * Copyright 2020 Google LLC. All Rights Reserved.
96692 * Licensed under the Apache License, Version 2.0 (the "License");
96693 * you may not use this file except in compliance with the License.
96694 * You may obtain a copy of the License at
96695 *
96696 * http://www.apache.org/licenses/LICENSE-2.0
96697 *
96698 * Unless required by applicable law or agreed to in writing, software
96699 * distributed under the License is distributed on an "AS IS" BASIS,
96700 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96701 * See the License for the specific language governing permissions and
96702 * limitations under the License.
96703 * =============================================================================
96704 */
96705 function dilation2D(args) {
96706 const { inputs, backend, attrs } = args;
96707 const { x, filter } = inputs;
96708 const { strides, pad, dilations } = attrs;
96709 const convInfo = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations);
96710 let out;
96711 const program = new Dilation2DProgram(convInfo);
96712 out = backend.runWebGLProgram(program, [x, filter], 'float32');
96713 const outReshaped = reshape({ inputs: { x: out }, backend, attrs: { shape: convInfo.outShape } });
96714 backend.disposeIntermediateTensorInfo(out);
96715 return outReshaped;
96716 }
96717 const dilation2DConfig = {
96718 kernelName: Dilation2D,
96719 backendName: 'webgl',
96720 kernelFunc: dilation2D,
96721 };
96722
96723 /**
96724 * @license
96725 * Copyright 2021 Google LLC. All Rights Reserved.
96726 * Licensed under the Apache License, Version 2.0 (the "License");
96727 * you may not use this file except in compliance with the License.
96728 * You may obtain a copy of the License at
96729 *
96730 * http://www.apache.org/licenses/LICENSE-2.0
96731 *
96732 * Unless required by applicable law or agreed to in writing, software
96733 * distributed under the License is distributed on an "AS IS" BASIS,
96734 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96735 * See the License for the specific language governing permissions and
96736 * limitations under the License.
96737 * =============================================================================
96738 */
    /**
     * WebGL kernel for Einsum. Decodes the equation, then evaluates it as a
     * sequence of transpose / reshape (to align dims) / multiply steps, with
     * summation over contracted dims between steps. All intermediates are
     * tracked and disposed at the end, except the tensor returned as `out`.
     */
    function einsum(args) {
        const { inputs, backend, attrs } = args;
        const { equation } = attrs;
        const tensors = inputs;
        const { allDims, summedDims, idDims } = decodeEinsumEquation(equation, tensors.length);
        checkEinsumDimSizes(allDims.length, idDims, tensors);
        // `path` lists which axis (if any) to sum out after each step.
        const { path, steps } = getEinsumComputePath(summedDims, idDims);
        const nSteps = steps.length;
        let out = null;
        let numDimsRemaining = allDims.length;
        const tensorsToDispose = [];
        for (let i = 0; i < nSteps; ++i) {
            for (const idTerm of steps[i]) {
                // Permute the operand's dims into the canonical order; skip the
                // transpose when the permutation is the identity.
                const { permutationIndices: perm, expandDims: dimsToExpand } = getEinsumPermutation(numDimsRemaining, idDims[idTerm]);
                let x;
                if (isIdentityPermutation(perm)) {
                    x = tensors[idTerm];
                }
                else {
                    x = transpose({ inputs: { x: tensors[idTerm] }, backend, attrs: { perm } });
                    tensorsToDispose.push(x);
                }
                // Insert size-1 dims so the operand broadcasts against `out`.
                const targetShape = x.shape.slice();
                for (let k = 0; k < dimsToExpand.length; ++k) {
                    targetShape.splice(dimsToExpand[k], 0, 1);
                }
                if (!arraysEqual(x.shape, targetShape)) {
                    x = reshape({ inputs: { x }, backend, attrs: { shape: targetShape } });
                    tensorsToDispose.push(x);
                }
                // Accumulate the elementwise (broadcasting) product of the
                // operands in this step.
                if (out === null) {
                    out = x;
                }
                else {
                    // tslint:disable-next-line: no-unnecessary-type-assertion
                    out = multiply({ inputs: { a: x, b: out }, backend });
                    tensorsToDispose.push(out);
                }
            }
            // Between steps, sum out the contracted axis (path[i] < 0 means no
            // summation is needed for this step).
            if (i < nSteps - 1) {
                if (path[i] >= 0) {
                    out = sum({
                        inputs: { x: out },
                        backend,
                        attrs: {
                            // Adjust the axis index for dims already summed away.
                            axis: path[i] - (allDims.length - numDimsRemaining),
                            keepDims: false
                        }
                    });
                    tensorsToDispose.push(out);
                }
                numDimsRemaining--;
            }
        }
        // Clean up intermediate tensors.
        for (const tensorInfo of tensorsToDispose) {
            if (tensorInfo === out) {
                continue;
            }
            backend.disposeIntermediateTensorInfo(tensorInfo);
        }
        return out;
    }
    /** Kernel registration record for the WebGL backend. */
    const einsumConfig = {
        kernelName: Einsum,
        backendName: 'webgl',
        kernelFunc: einsum
    };
96807
96808 /**
96809 * @license
96810 * Copyright 2020 Google LLC. All Rights Reserved.
96811 * Licensed under the Apache License, Version 2.0 (the "License");
96812 * you may not use this file except in compliance with the License.
96813 * You may obtain a copy of the License at
96814 *
96815 * http://www.apache.org/licenses/LICENSE-2.0
96816 *
96817 * Unless required by applicable law or agreed to in writing, software
96818 * distributed under the License is distributed on an "AS IS" BASIS,
96819 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96820 * See the License for the specific language governing permissions and
96821 * limitations under the License.
96822 * =============================================================================
96823 */
    // GLSL snippet for ELU: identity for x >= 0, exp(x) - 1 otherwise.
    const ELU = `return (x >= 0.0) ? x : (exp(x) - 1.0);`;
    // Packed (vec4) variant: applies the same ELU rule per channel.
    const ELU_PACKED = `
  vec4 result;

  result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);
  result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);
  result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);
  result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);

  return result;
`;
    // Unary elementwise ELU kernel built from the snippets above.
    const elu = unaryKernelFunc({ opSnippet: ELU, packedOpSnippet: ELU_PACKED });
    // Kernel registration record for the WebGL backend.
    const eluConfig = {
        kernelName: Elu$1,
        backendName: 'webgl',
        kernelFunc: elu
    };
96841
96842 /**
96843 * @license
96844 * Copyright 2020 Google LLC. All Rights Reserved.
96845 * Licensed under the Apache License, Version 2.0 (the "License");
96846 * you may not use this file except in compliance with the License.
96847 * You may obtain a copy of the License at
96848 *
96849 * http://www.apache.org/licenses/LICENSE-2.0
96850 *
96851 * Unless required by applicable law or agreed to in writing, software
96852 * distributed under the License is distributed on an "AS IS" BASIS,
96853 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96854 * See the License for the specific language governing permissions and
96855 * limitations under the License.
96856 * =============================================================================
96857 */
96858 const ELU_DER = `return (b >= 0.0) ? a : a * (b + 1.0);`;
96859 const ELU_DER_PACKED = `
96860 vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.)));
96861 return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))));
96862`;
96863 const eluGrad = (args) => {
96864 const { inputs, backend } = args;
96865 const { dy, y } = inputs;
96866 const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?
96867 new BinaryOpPackedProgram(ELU_DER_PACKED, dy.shape, y.shape) :
96868 new BinaryOpProgram(ELU_DER, dy.shape, y.shape);
96869 return backend.runWebGLProgram(program, [dy, y], dy.dtype);
96870 };
96871 const eluGradConfig = {
96872 kernelName: EluGrad,
96873 backendName: 'webgl',
96874 kernelFunc: eluGrad
96875 };
96876
96877 /**
96878 * @license
96879 * Copyright 2020 Google LLC. All Rights Reserved.
96880 * Licensed under the Apache License, Version 2.0 (the "License");
96881 * you may not use this file except in compliance with the License.
96882 * You may obtain a copy of the License at
96883 *
96884 * http://www.apache.org/licenses/LICENSE-2.0
96885 *
96886 * Unless required by applicable law or agreed to in writing, software
96887 * distributed under the License is distributed on an "AS IS" BASIS,
96888 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96889 * See the License for the specific language governing permissions and
96890 * limitations under the License.
96891 * =============================================================================
96892 */
    // Packed comparison: GLSL equal() yields a bvec4, cast to vec4 of 0/1.
    const PACKED_EQUAL = `
  return vec4(equal(a, b));
`;
    // Scalar comparison snippet; result is 0.0 or 1.0.
    const EQUAL = `return float(a == b);`;
    // Elementwise equality kernel producing a bool tensor, with a CPU
    // implementation used when the data lives on the CPU.
    const equal = binaryKernelFunc({
        opSnippet: EQUAL,
        packedOpSnippet: PACKED_EQUAL,
        dtype: 'bool',
        cpuKernelImpl: equalImplCPU,
    });
    // Kernel registration record for the WebGL backend.
    const equalConfig = {
        kernelName: Equal,
        backendName: 'webgl',
        kernelFunc: equal
    };
96908
96909 /**
96910 * @license
96911 * Copyright 2020 Google LLC. All Rights Reserved.
96912 * Licensed under the Apache License, Version 2.0 (the "License");
96913 * you may not use this file except in compliance with the License.
96914 * You may obtain a copy of the License at
96915 *
96916 * http://www.apache.org/licenses/LICENSE-2.0
96917 *
96918 * Unless required by applicable law or agreed to in writing, software
96919 * distributed under the License is distributed on an "AS IS" BASIS,
96920 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96921 * See the License for the specific language governing permissions and
96922 * limitations under the License.
96923 * =============================================================================
96924 */
    // GLSL snippet approximating erf(x) with the Abramowitz & Stegun
    // rational polynomial (formula 7.1.26). The ERF_P / ERF_A1..ERF_A5
    // constants are interpolated from shared definitions declared elsewhere
    // in this bundle.
    const ERF = `
  // Error function is calculated approximately with elementary function.
  // See "Handbook of Mathematical Functions with Formulas,
  // Graphs, and Mathematical Tables", Abramowitz and Stegun.
  float p = ${ERF_P};
  float a1 = ${ERF_A1};
  float a2 = ${ERF_A2};
  float a3 = ${ERF_A3};
  float a4 = ${ERF_A4};
  float a5 = ${ERF_A5};

  float sign = sign(x);
  x = abs(x);
  float t = 1.0 / (1.0 + p * x);
  return sign * (1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x));
`;
    // Unary elementwise erf kernel built from the snippet above.
    const erf = unaryKernelFunc({ opSnippet: ERF });
    // Kernel registration record for the WebGL backend.
    const erfConfig = {
        kernelName: Erf,
        backendName: 'webgl',
        kernelFunc: erf,
    };
96947
96948 /**
96949 * @license
96950 * Copyright 2020 Google LLC. All Rights Reserved.
96951 * Licensed under the Apache License, Version 2.0 (the "License");
96952 * you may not use this file except in compliance with the License.
96953 * You may obtain a copy of the License at
96954 *
96955 * http://www.apache.org/licenses/LICENSE-2.0
96956 *
96957 * Unless required by applicable law or agreed to in writing, software
96958 * distributed under the License is distributed on an "AS IS" BASIS,
96959 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
96960 * See the License for the specific language governing permissions and
96961 * limitations under the License.
96962 * =============================================================================
96963 */
    // exp(x) snippet; CHECK_NAN_SNIPPET_UNARY (defined elsewhere in this
    // bundle) is prepended — presumably to propagate NaN inputs unchanged,
    // matching the explicit NaN handling in the packed variant below.
    const EXP = CHECK_NAN_SNIPPET_UNARY + `
  return exp(x);
`;
    // Packed (vec4) variant: compute exp per channel, then restore the
    // original value for any NaN lanes so NaN propagates.
    const EXP_PACKED = `
  vec4 result = exp(x);
  bvec4 isNaN = isnan(x);
  result.r = isNaN.r ? x.r : result.r;
  result.g = isNaN.g ? x.g : result.g;
  result.b = isNaN.b ? x.b : result.b;
  result.a = isNaN.a ? x.a : result.a;

  return result;
`;
    // Unary elementwise exp kernel (float32 output) with a CPU fallback
    // implementation.
    const exp = unaryKernelFunc({
        opSnippet: EXP,
        packedOpSnippet: EXP_PACKED,
        cpuKernelImpl: expImplCPU,
        dtype: 'float32',
    });
    // Kernel registration record for the WebGL backend.
    const expConfig = {
        kernelName: Exp,
        backendName: 'webgl',
        kernelFunc: exp
    };
96988
96989 /**
96990 * @license
96991 * Copyright 2020 Google LLC. All Rights Reserved.
96992 * Licensed under the Apache License, Version 2.0 (the License);
96993 * you may not use this file except in compliance with the License.
96994 * You may obtain a copy of the License at
96995 *
96996 * http://www.apache.org/licenses/LICENSE-2.0
96997 *
96998 * Unless required by applicable law or agreed to in writing, software
96999 * distributed under the License is distributed on an AS IS BASIS,
97000 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97001 * See the License for the specific language governing permissions and
97002 * limitations under the License.
97003 * =============================================================================
97004 */
97005 function expandDims(args) {
97006 const { inputs, attrs, backend } = args;
97007 const { dim } = attrs;
97008 const { input } = inputs;
97009 const inputRank = input.shape.length;
97010 const newShape = input.shape.slice();
97011 let $dim = dim;
97012 if (dim < 0) {
97013 // Negative value is counted from the tail of rank.
97014 assert$1(-(inputRank + 1) <= dim, () => `Axis must be in the interval [${-(inputRank + 1)}, ${inputRank}]`);
97015 $dim = inputRank + dim + 1;
97016 }
97017 newShape.splice($dim, 0, 1);
97018 return reshape({ inputs: { x: input }, backend, attrs: { shape: newShape } });
97019 }
97020 const expandDimsConfig = {
97021 kernelName: ExpandDims,
97022 backendName: 'webgl',
97023 kernelFunc: expandDims,
97024 };
97025
97026 /**
97027 * @license
97028 * Copyright 2020 Google LLC. All Rights Reserved.
97029 * Licensed under the Apache License, Version 2.0 (the "License");
97030 * you may not use this file except in compliance with the License.
97031 * You may obtain a copy of the License at
97032 *
97033 * http://www.apache.org/licenses/LICENSE-2.0
97034 *
97035 * Unless required by applicable law or agreed to in writing, software
97036 * distributed under the License is distributed on an "AS IS" BASIS,
97037 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97038 * See the License for the specific language governing permissions and
97039 * limitations under the License.
97040 * =============================================================================
97041 */
    // GLSL snippet for expm1: exp(x) - 1. The same snippet is valid for
    // both the scalar and the packed (vec4) code paths.
    const EXPM1 = `return exp(x) - 1.0;`;
    // Unary elementwise expm1 kernel with a CPU fallback implementation.
    const expm1 = unaryKernelFunc({ opSnippet: EXPM1, packedOpSnippet: EXPM1, cpuKernelImpl: expm1ImplCPU });
    // Kernel registration record for the WebGL backend.
    const expm1Config = {
        kernelName: Expm1,
        backendName: 'webgl',
        kernelFunc: expm1
    };
97049
97050 /**
97051 * @license
97052 * Copyright 2018 Google LLC. All Rights Reserved.
97053 * Licensed under the Apache License, Version 2.0 (the "License");
97054 * you may not use this file except in compliance with the License.
97055 * You may obtain a copy of the License at
97056 *
97057 * http://www.apache.org/licenses/LICENSE-2.0
97058 *
97059 * Unless required by applicable law or agreed to in writing, software
97060 * distributed under the License is distributed on an "AS IS" BASIS,
97061 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97062 * See the License for the specific language governing permissions and
97063 * limitations under the License.
97064 * =============================================================================
97065 */
97066 class FFTProgram {
97067 constructor(component, inputShape, inverse) {
97068 this.variableNames = ['real', 'imag'];
97069 const innerDim = inputShape[1];
97070 this.outputShape = inputShape;
97071 const exponentMultiplierSnippet = inverse ? `2.0 * ${Math.PI}` : `-2.0 * ${Math.PI}`;
97072 const resultDenominator = inverse ? `${innerDim}.0` : '1.0';
97073 let opString;
97074 if (component === 'real') {
97075 opString = 'return real * expR - imag * expI;';
97076 }
97077 else if (component === 'imag') {
97078 opString = 'return real * expI + imag * expR;';
97079 }
97080 else {
97081 throw new Error(`FFT component must be either "real" or "imag", got ${component}.`);
97082 }
97083 this.userCode = `
97084 const float exponentMultiplier = ${exponentMultiplierSnippet};
97085
97086 float unaryOpComplex(float real, float expR, float imag, float expI) {
97087 ${opString}
97088 }
97089
97090 float mulMatDFT(int batch, int index) {
97091 float indexRatio = float(index) / float(${innerDim});
97092 float exponentMultiplierTimesIndexRatio =
97093 exponentMultiplier * indexRatio;
97094
97095 float result = 0.0;
97096
97097 for (int i = 0; i < ${innerDim}; i++) {
97098 // x = (-2|2 * PI / N) * index * i;
97099 float x = exponentMultiplierTimesIndexRatio * float(i);
97100 float expR = cos(x);
97101 float expI = sin(x);
97102 float real = getReal(batch, i);
97103 float imag = getImag(batch, i);
97104
97105 result +=
97106 unaryOpComplex(real, expR, imag, expI) / ${resultDenominator};
97107 }
97108
97109 return result;
97110 }
97111
97112 void main() {
97113 ivec2 coords = getOutputCoords();
97114 setOutput(mulMatDFT(coords[0], coords[1]));
97115 }
97116 `;
97117 }
97118 }
97119
97120 /**
97121 * @license
97122 * Copyright 2020 Google LLC. All Rights Reserved.
97123 * Licensed under the Apache License, Version 2.0 (the "License");
97124 * you may not use this file except in compliance with the License.
97125 * You may obtain a copy of the License at
97126 *
97127 * http://www.apache.org/licenses/LICENSE-2.0
97128 *
97129 * Unless required by applicable law or agreed to in writing, software
97130 * distributed under the License is distributed on an "AS IS" BASIS,
97131 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97132 * See the License for the specific language governing permissions and
97133 * limitations under the License.
97134 * =============================================================================
97135 */
97136 function fftImpl(x, inverse, backend) {
97137 const xData = backend.texData.get(x.dataId);
97138 const inputSize = sizeFromShape(x.shape);
97139 // Collapse all outer dimensions to a single batch dimension.
97140 const innerDimensionSize = x.shape[x.shape.length - 1];
97141 const batch = inputSize / innerDimensionSize;
97142 const input2D = reshape({ inputs: { x }, backend, attrs: { shape: [batch, innerDimensionSize] } });
97143 const xShape = input2D.shape;
97144 const realProgram = new FFTProgram('real', xShape, inverse);
97145 const imagProgram = new FFTProgram('imag', xShape, inverse);
97146 const inputs = [
97147 {
97148 dataId: xData.complexTensorInfos.real.dataId,
97149 dtype: xData.complexTensorInfos.real.dtype,
97150 shape: xShape
97151 },
97152 {
97153 dataId: xData.complexTensorInfos.imag.dataId,
97154 dtype: xData.complexTensorInfos.imag.dtype,
97155 shape: xShape
97156 }
97157 ];
97158 const realPart = backend.runWebGLProgram(realProgram, inputs, 'float32');
97159 const imagPart = backend.runWebGLProgram(imagProgram, inputs, 'float32');
97160 const complexOutput = complex({ inputs: { real: realPart, imag: imagPart }, backend });
97161 backend.disposeIntermediateTensorInfo(realPart);
97162 backend.disposeIntermediateTensorInfo(imagPart);
97163 const complexOutputReshaped = reshape({ inputs: { x: complexOutput }, backend, attrs: { shape: x.shape } });
97164 backend.disposeIntermediateTensorInfo(input2D);
97165 backend.disposeIntermediateTensorInfo(complexOutput);
97166 return complexOutputReshaped;
97167 }
97168
97169 /**
97170 * @license
97171 * Copyright 2020 Google LLC. All Rights Reserved.
97172 * Licensed under the Apache License, Version 2.0 (the "License");
97173 * you may not use this file except in compliance with the License.
97174 * You may obtain a copy of the License at
97175 *
97176 * http://www.apache.org/licenses/LICENSE-2.0
97177 *
97178 * Unless required by applicable law or agreed to in writing, software
97179 * distributed under the License is distributed on an "AS IS" BASIS,
97180 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97181 * See the License for the specific language governing permissions and
97182 * limitations under the License.
97183 * =============================================================================
97184 */
97185 function fft(args) {
97186 const { inputs, backend } = args;
97187 const { input } = inputs;
97188 return fftImpl(input, false /* inverse */, backend);
97189 }
97190 const fftConfig = {
97191 kernelName: FFT,
97192 backendName: 'webgl',
97193 kernelFunc: fft
97194 };
97195
97196 /**
97197 * @license
97198 * Copyright 2019 Google LLC. All Rights Reserved.
97199 * Licensed under the Apache License, Version 2.0 (the "License");
97200 * you may not use this file except in compliance with the License.
97201 * You may obtain a copy of the License at
97202 *
97203 * http://www.apache.org/licenses/LICENSE-2.0
97204 *
97205 * Unless required by applicable law or agreed to in writing, software
97206 * distributed under the License is distributed on an "AS IS" BASIS,
97207 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97208 * See the License for the specific language governing permissions and
97209 * limitations under the License.
97210 * =============================================================================
97211 */
/**
 * GPGPU program that writes a single scalar to every element of the output.
 *
 * The fill value is not baked into the shader: it is delivered at run time
 * through the `value` float uniform (see `customUniforms`), so one compiled
 * program can be reused for any fill value of the same shape.
 */
class FillProgram {
    constructor(shape, value) {
        this.outputShape = [];
        this.customUniforms = [{ name: 'value', type: 'float' }];
        // NOTE(review): `variableNames` declares an input 'x' and the `value`
        // constructor argument is unused here — the actual fill value is
        // supplied via `customValues` at dispatch time in fill(). Confirm the
        // 'x' variable is intentionally vestigial.
        this.variableNames = ['x'];
        this.outputShape = shape;
        this.userCode = `
    void main() {
      // Input can be obtained from uniform value.
      setOutput(value);
    }
  `;
    }
}
97226
97227 /**
97228 * @license
97229 * Copyright 2020 Google LLC. All Rights Reserved.
97230 * Licensed under the Apache License, Version 2.0 (the "License");
97231 * you may not use this file except in compliance with the License.
97232 * You may obtain a copy of the License at
97233 *
97234 * http://www.apache.org/licenses/LICENSE-2.0
97235 *
97236 * Unless required by applicable law or agreed to in writing, software
97237 * distributed under the License is distributed on an "AS IS" BASIS,
97238 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97239 * See the License for the specific language governing permissions and
97240 * limitations under the License.
97241 * =============================================================================
97242 */
/**
 * Fill kernel for the WebGL backend.
 *
 * Produces a tensor of `attrs.shape` where every element equals
 * `attrs.value`. String tensors are materialized on the CPU (they cannot
 * live in GPU textures); numeric dtypes run the FillProgram shader with the
 * value passed as a custom uniform.
 */
function fill(args) {
    const { backend, attrs } = args;
    const { shape, value } = attrs;
    // Fall back to the inferred dtype when none was provided.
    const dtype = attrs.dtype || inferDtype(value);
    if (dtype === 'string') {
        // String type should be handled in CPU memory.
        const values = getArrayFromDType(dtype, sizeFromShape(shape));
        values.fill(value);
        return backend.makeTensorInfo(shape, dtype, values);
    }
    const program = new FillProgram(shape, value);
    const customValues = [[value]];
    return backend.runWebGLProgram(program, [], dtype, customValues);
}
/** WebGL kernel registration for Fill. */
const fillConfig = {
    kernelName: Fill,
    backendName: 'webgl',
    kernelFunc: fill
};
97265
97266 /**
97267 * @license
97268 * Copyright 2020 Google LLC. All Rights Reserved.
97269 * Licensed under the Apache License, Version 2.0 (the "License");
97270 * you may not use this file except in compliance with the License.
97271 * You may obtain a copy of the License at
97272 *
97273 * http://www.apache.org/licenses/LICENSE-2.0
97274 *
97275 * Unless required by applicable law or agreed to in writing, software
97276 * distributed under the License is distributed on an "AS IS" BASIS,
97277 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97278 * See the License for the specific language governing permissions and
97279 * limitations under the License.
97280 * =============================================================================
97281 */
/**
 * GPGPU program that mirrors a batch of images along the width axis.
 *
 * `imageShape` is read as a rank-4 shape with width at index 2 (i.e.
 * [batch, height, width, channels]); the output shape equals the input
 * shape.
 */
class FlipLeftRightProgram {
    constructor(imageShape) {
        this.variableNames = ['Image'];
        this.outputShape = [];
        const imageWidth = imageShape[2];
        this.outputShape = imageShape;
        // coordX = width - x - 1 mirrors the column. For any in-range output
        // x (0 <= x < width) coordX is also in range, so the else branch is a
        // defensive identity fallback rather than a reachable path.
        this.userCode = `
    void main() {
      ivec4 coords = getOutputCoords();
      int x = coords[2];

      int coordX = ${imageWidth} - x - 1;
      float outputValue;
      if(coordX >= 0 && coordX < ${imageWidth}) {
        outputValue = getImage(coords[0], coords[1], coordX, coords[3]);
      } else {
        outputValue = getImage(coords[0], coords[1], coords[2], coords[3]);
      }
      setOutput(outputValue);
    }
  `;
    }
}
97305
97306 /**
97307 * @license
97308 * Copyright 2020 Google LLC. All Rights Reserved.
97309 * Licensed under the Apache License, Version 2.0 (the "License");
97310 * you may not use this file except in compliance with the License.
97311 * You may obtain a copy of the License at
97312 *
97313 * http://www.apache.org/licenses/LICENSE-2.0
97314 *
97315 * Unless required by applicable law or agreed to in writing, software
97316 * distributed under the License is distributed on an "AS IS" BASIS,
97317 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97318 * See the License for the specific language governing permissions and
97319 * limitations under the License.
97320 * =============================================================================
97321 */
/**
 * WebGL kernel registration for FlipLeftRight.
 *
 * Runs FlipLeftRightProgram on the input image and returns the mirrored
 * tensor with the input's dtype and shape.
 */
const flipLeftRightConfig = {
    kernelName: FlipLeftRight,
    backendName: 'webgl',
    kernelFunc: ({ inputs, backend }) => {
        const { image } = inputs;
        const webglBackend = backend;
        const program = new FlipLeftRightProgram(image.shape);
        return webglBackend.runWebGLProgram(program, [image], image.dtype);
    }
};
97333
97334 /**
97335 * @license
97336 * Copyright 2020 Google LLC. All Rights Reserved.
97337 * Licensed under the Apache License, Version 2.0 (the "License");
97338 * you may not use this file except in compliance with the License.
97339 * You may obtain a copy of the License at
97340 *
97341 * http://www.apache.org/licenses/LICENSE-2.0
97342 *
97343 * Unless required by applicable law or agreed to in writing, software
97344 * distributed under the License is distributed on an "AS IS" BASIS,
97345 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97346 * See the License for the specific language governing permissions and
97347 * limitations under the License.
97348 * =============================================================================
97349 */
// GLSL snippet for elementwise floor. floor() is component-wise in GLSL, so
// the same snippet serves both the unpacked and the packed (vec4) variants.
const FLOOR = `return floor(x);`;
const floor = unaryKernelFunc({ opSnippet: FLOOR, packedOpSnippet: FLOOR, cpuKernelImpl: floorImplCPU });
/** WebGL kernel registration for Floor. */
const floorConfig = {
    kernelName: Floor,
    backendName: 'webgl',
    kernelFunc: floor,
};
97357
97358 /**
97359 * @license
97360 * Copyright 2020 Google LLC. All Rights Reserved.
97361 * Licensed under the Apache License, Version 2.0 (the "License");
97362 * you may not use this file except in compliance with the License.
97363 * You may obtain a copy of the License at
97364 *
97365 * http://www.apache.org/licenses/LICENSE-2.0
97366 *
97367 * Unless required by applicable law or agreed to in writing, software
97368 * distributed under the License is distributed on an "AS IS" BASIS,
97369 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97370 * See the License for the specific language governing permissions and
97371 * limitations under the License.
97372 * =============================================================================
97373 */
// We use native integer division to deal with floating point imprecision. Since
// we implement floor division and glsl implements truncated division, we
// correct for this by subtracting 1 from result when the result is negative and
// there is a remainder.
// (The truncation-to-floor correction itself lives in the shared `idiv`
// shader helper; `s` carries the sign of the quotient for that correction.)
const INT_DIV = `
  float s = sign(a) * sign(b);
  int ia = round(a);
  int ib = round(b);
  if (ib != 0) {
    // Windows (D3D) wants guaranteed non-zero int division at compile-time.
    return float(idiv(ia, ib, s));
  } else {
    return NAN;
  }
`;
// Packed (vec4) variant: each of the four lanes is guarded separately so the
// divisor handed to idiv is provably non-zero; zero-divisor lanes yield 0.
const INT_DIV_PACKED = `
  ivec4 ia = round(a);
  ivec4 ib = round(b);
  bvec4 cond = notEqual(ib, ivec4(0));
  ivec4 result = ivec4(0);
  vec4 s = sign(a) * sign(b);

  // Windows (D3D) wants guaranteed non-zero int division at compile-time.
  if (cond[0]) {
    result[0] = idiv(ia[0], ib[0], s[0]);
  }
  if (cond[1]) {
    result[1] = idiv(ia[1], ib[1], s[1]);
  }
  if (cond[2]) {
    result[2] = idiv(ia[2], ib[2], s[2]);
  }
  if (cond[3]) {
    result[3] = idiv(ia[3], ib[3], s[3]);
  }
  return vec4(result);
`;
// FloorDiv always produces int32 output regardless of the input dtypes.
const floorDiv = binaryKernelFunc({ opSnippet: INT_DIV, packedOpSnippet: INT_DIV_PACKED, dtype: 'int32' });
/** WebGL kernel registration for FloorDiv. */
const floorDivConfig = {
    kernelName: FloorDiv,
    backendName: 'webgl',
    kernelFunc: floorDiv
};
97417
97418 /**
97419 * @license
97420 * Copyright 2018 Google LLC. All Rights Reserved.
97421 * Licensed under the Apache License, Version 2.0 (the "License");
97422 * you may not use this file except in compliance with the License.
97423 * You may obtain a copy of the License at
97424 *
97425 * http://www.apache.org/licenses/LICENSE-2.0
97426 *
97427 * Unless required by applicable law or agreed to in writing, software
97428 * distributed under the License is distributed on an "AS IS" BASIS,
97429 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97430 * See the License for the specific language governing permissions and
97431 * limitations under the License.
97432 * =============================================================================
97433 */
/**
 * GPGPU program that converts an uploaded RGBA pixel texture into a
 * [height, width, depth] tensor of integer channel values.
 *
 * Each output element samples the source texture at its (row, col) position
 * and selects the channel given by the depth coordinate (0=r, 1=g, 2=b,
 * 3=a). Sampled values are in [0, 1], so `value * 255.0 + 0.5` followed by
 * floor() rounds them back to 0..255 integers.
 */
class FromPixelsProgram {
    constructor(outputShape) {
        this.variableNames = ['A'];
        // glsl differences abstract over WebGL1/WebGL2 sampling syntax.
        const glsl = getGlslDifferences();
        const [height, width,] = outputShape;
        this.outputShape = outputShape;
        this.userCode = `
    void main() {
      ivec3 coords = getOutputCoords();
      int texR = coords[0];
      int texC = coords[1];
      int depth = coords[2];
      vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${width}.0, ${height}.0);

      vec4 values = ${glsl.texture2D}(A, uv);
      float value;
      if (depth == 0) {
        value = values.r;
      } else if (depth == 1) {
        value = values.g;
      } else if (depth == 2) {
        value = values.b;
      } else if (depth == 3) {
        value = values.a;
      }

      setOutput(floor(value * 255.0 + 0.5));
    }
  `;
    }
}
97465
97466 /**
97467 * @license
97468 * Copyright 2018 Google LLC. All Rights Reserved.
97469 * Licensed under the Apache License, Version 2.0 (the "License");
97470 * you may not use this file except in compliance with the License.
97471 * You may obtain a copy of the License at
97472 *
97473 * http://www.apache.org/licenses/LICENSE-2.0
97474 *
97475 * Unless required by applicable law or agreed to in writing, software
97476 * distributed under the License is distributed on an "AS IS" BASIS,
97477 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97478 * See the License for the specific language governing permissions and
97479 * limitations under the License.
97480 * =============================================================================
97481 */
/**
 * Packed variant of FromPixelsProgram: unpacked pixel texture in, packed
 * (vec4-per-texel) tensor out.
 *
 * Each shader invocation fills one output vec4 by iterating a 2x2
 * neighborhood over (texC, depth) — `row` advances the column coordinate
 * and `col` advances the channel coordinate — storing the rounded 0..255
 * channel value at lane `row * 2 + col`.
 */
class FromPixelsPackedProgram {
    constructor(outputShape) {
        this.variableNames = ['A'];
        this.packedInputs = false;
        this.packedOutput = true;
        const glsl = getGlslDifferences();
        const [height, width,] = outputShape;
        this.outputShape = outputShape;
        this.userCode = `
    void main() {
      ivec3 coords = getOutputCoords();
      int texR = coords[0];
      int texC = coords[1];
      int depth = coords[2];

      vec4 result = vec4(0.);

      for(int row=0; row<=1; row++) {
        for(int col=0; col<=1; col++) {
          texC = coords[1] + row;
          depth = coords[2] + col;

          vec2 uv = (vec2(texC, texR) + halfCR) /
                     vec2(${width}.0, ${height}.0);
          vec4 values = ${glsl.texture2D}(A, uv);
          float value;
          if (depth == 0) {
            value = values.r;
          } else if (depth == 1) {
            value = values.g;
          } else if (depth == 2) {
            value = values.b;
          } else if (depth == 3) {
            value = values.a;
          }

          result[row * 2 + col] = floor(value * 255.0 + 0.5);
        }
      }

      ${glsl.output} = result;
    }
  `;
    }
}
97527
97528 /**
97529 * @license
97530 * Copyright 2019 Google LLC. All Rights Reserved.
97531 * Licensed under the Apache License, Version 2.0 (the "License");
97532 * you may not use this file except in compliance with the License.
97533 * You may obtain a copy of the License at
97534 *
97535 * http://www.apache.org/licenses/LICENSE-2.0
97536 *
97537 * Unless required by applicable law or agreed to in writing, software
97538 * distributed under the License is distributed on an "AS IS" BASIS,
97539 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97540 * See the License for the specific language governing permissions and
97541 * limitations under the License.
97542 * =============================================================================
97543 */
/** WebGL kernel registration for FromPixels. */
const fromPixelsConfig = {
    kernelName: FromPixels,
    backendName: 'webgl',
    kernelFunc: fromPixels,
};
// Lazily created 2D canvas context, shared across fromPixels calls so video
// and image sources can be rasterized before texture upload.
let fromPixels2DContext;
// Cached flag value. `willReadFrequently` is fixed at context creation, so
// the context must be rebuilt if the environment flag changes.
let willReadFrequently = env().getBool('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU');
/**
 * FromPixels kernel for the WebGL backend: converts pixel-bearing sources
 * (ImageData, canvas, HTMLImageElement, HTMLVideoElement) into an int32
 * tensor of shape [height, width, numChannels].
 */
function fromPixels(args) {
    const { inputs, backend, attrs } = args;
    let { pixels } = inputs;
    const { numChannels } = attrs;
    // typeof guards keep this safe in environments (e.g. workers/Node) where
    // the DOM element constructors do not exist.
    const isVideo = typeof (HTMLVideoElement) !== 'undefined' &&
        pixels instanceof HTMLVideoElement;
    const isImage = typeof (HTMLImageElement) !== 'undefined' &&
        pixels instanceof HTMLImageElement;
    const [width, height] = isVideo ?
        [
            pixels.videoWidth,
            pixels.videoHeight
        ] :
        [pixels.width, pixels.height];
    const texShape = [height, width];
    const outShape = [height, width, numChannels];
    if (isImage || isVideo) {
        // Rasterize the element onto the shared canvas; the canvas then acts
        // as the pixel source for the texture upload below.
        const newWillReadFrequently = env().getBool('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU');
        if (fromPixels2DContext == null ||
            newWillReadFrequently !== willReadFrequently) {
            willReadFrequently = newWillReadFrequently;
            fromPixels2DContext =
                document.createElement('canvas').getContext('2d', { willReadFrequently });
        }
        fromPixels2DContext.canvas.width = width;
        fromPixels2DContext.canvas.height = height;
        fromPixels2DContext.drawImage(pixels, 0, 0, width, height);
        pixels = fromPixels2DContext.canvas;
    }
    const tempPixelHandle = backend.makeTensorInfo(texShape, 'int32');
    // This is a byte texture with pixels.
    backend.texData.get(tempPixelHandle.dataId).usage = TextureUsage.PIXELS;
    backend.gpgpu.uploadPixelDataToTexture(backend.getTexture(tempPixelHandle.dataId), pixels);
    const program = env().getBool('WEBGL_PACK') ?
        new FromPixelsPackedProgram(outShape) :
        new FromPixelsProgram(outShape);
    const res = backend.runWebGLProgram(program, [tempPixelHandle], 'int32');
    // The staging pixel texture is no longer needed once the program ran.
    backend.disposeData(tempPixelHandle.dataId);
    return res;
}
97591
97592 /**
97593 * @license
97594 * Copyright 2020 Google LLC. All Rights Reserved.
97595 * Licensed under the Apache License, Version 2.0 (the "License");
97596 * you may not use this file except in compliance with the License.
97597 * You may obtain a copy of the License at
97598 *
97599 * http://www.apache.org/licenses/LICENSE-2.0
97600 *
97601 * Unless required by applicable law or agreed to in writing, software
97602 * distributed under the License is distributed on an "AS IS" BASIS,
97603 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97604 * See the License for the specific language governing permissions and
97605 * limitations under the License.
97606 * =============================================================================
97607 */
/**
 * FusedConv2D kernel for the WebGL backend: conv2d with optional bias add
 * and activation (including PReLU weights and leakyrelu alpha) fused into a
 * single dispatch.
 *
 * Picks one of four execution strategies based on the conv geometry and
 * environment flags, then reshapes the result to `convInfo.outShape` and
 * disposes every intermediate tensor it created.
 */
function fusedConv2d(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter, bias, preluActivationWeights } = inputs;
    const { strides, pad, dataFormat, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;
    const $dataFormat = convertConv2DDataFormat(dataFormat);
    const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
    let out;
    const intermediates = [];
    const hasBias = bias != null;
    const hasPreluActivationWeights = preluActivationWeights != null;
    const hasLeakyreluAlpha = activation === 'leakyrelu';
    // Builds the program input list; note the local `inputs` intentionally
    // shadows the destructured kernel `inputs` above.
    const prepareInputs = () => {
        const inputs = [x, filter];
        // If the input is a 1-D tensor, align it with the channels.
        //
        // For fusedConv2d, the inputs (x, W, bias, preluActivationWeights) are
        // supposed to be aligned with the dataFormat. The 4-D tensor inputs or
        // scalar inputs are originally aligned, but the 1-D tensor inputs are
        // supposed to be aligned with the channels (only bias and PReLU activation
        // weights could be a 1-D tensor).
        const alignInputWithDataFormat = (input, dataFormat) => {
            if (dataFormat === 'NCHW' && input.shape.length === 1 &&
                input.shape[0] !== 1) {
                const alignedInput = reshape({
                    inputs: { x: input },
                    backend,
                    attrs: { shape: [input.shape[0], 1, 1] }
                });
                intermediates.push(alignedInput);
                return alignedInput;
            }
            return input;
        };
        if (hasBias) {
            inputs.push(alignInputWithDataFormat(bias, dataFormat));
        }
        if (hasPreluActivationWeights) {
            inputs.push(alignInputWithDataFormat(preluActivationWeights, dataFormat));
        }
        if (hasLeakyreluAlpha) {
            // leakyrelu alpha travels as a scalar tensor input to the shader.
            const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
            inputs.push($leakyreluAlpha);
            intermediates.push($leakyreluAlpha);
        }
        return inputs;
    };
    // Strategy 1: a 1x1, stride-1, dilation-1 convolution is just a matmul.
    if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
        convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
        convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
        (convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID')) {
        out = conv2dByMatMul({
            x,
            filter,
            convInfo,
            backend,
            bias,
            activation,
            preluActivationWeights,
            leakyreluAlpha
        });
    }
    // Strategy 2: experimental packed conv path (flag-gated, channelsLast,
    // strideWidth <= 2); conv parameters are passed as custom uniforms.
    else if (convInfo.strideWidth <= 2 && $dataFormat === 'channelsLast'
        && env().getBool('WEBGL_EXP_CONV')) {
        const fusedActivation = activation ? mapActivationToShaderProgram(activation, true) : null;
        const program = new Conv2DPackedProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
        const customValues = [
            [convInfo.padInfo.top, convInfo.padInfo.left],
            [convInfo.strideHeight, convInfo.strideWidth],
            [convInfo.dilationHeight, convInfo.dilationWidth],
            [convInfo.inHeight, convInfo.inWidth]
        ];
        const inputs = prepareInputs();
        out = backend.runWebGLProgram(program, inputs, 'float32', customValues);
    }
    // Strategy 3: im2col + matmul (flag-gated).
    else if (env().getBool('WEBGL_CONV_IM2COL')) {
        out = conv2dWithIm2Row({
            x,
            filter,
            convInfo,
            backend,
            bias,
            activation,
            preluActivationWeights,
            leakyreluAlpha
        });
    }
    // Strategy 4: general direct convolution program.
    else {
        const fusedActivation = activation ? mapActivationToShaderProgram(activation, false) : null;
        const program = new Conv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
        const inputs = prepareInputs();
        out = backend.runWebGLProgram(program, inputs, 'float32');
    }
    const outReshaped = reshape({ inputs: { x: out }, backend, attrs: { shape: convInfo.outShape } });
    intermediates.push(out);
    intermediates.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return outReshaped;
}
/** WebGL kernel registration for FusedConv2D. */
const fusedConv2DConfig = {
    kernelName: FusedConv2D,
    backendName: 'webgl',
    kernelFunc: fusedConv2d,
};
97710
97711 /**
97712 * @license
97713 * Copyright 2020 Google LLC. All Rights Reserved.
97714 * Licensed under the Apache License, Version 2.0 (the "License");
97715 * you may not use this file except in compliance with the License.
97716 * You may obtain a copy of the License at
97717 *
97718 * http://www.apache.org/licenses/LICENSE-2.0
97719 *
97720 * Unless required by applicable law or agreed to in writing, software
97721 * distributed under the License is distributed on an "AS IS" BASIS,
97722 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97723 * See the License for the specific language governing permissions and
97724 * limitations under the License.
97725 * =============================================================================
97726 */
/**
 * FusedDepthwiseConv2D kernel for the WebGL backend: depthwise conv2d with
 * optional fused bias add and activation (PReLU weights / leakyrelu alpha).
 *
 * Chooses between the packed and unpacked depthwise programs based on the
 * WEBGL_PACK_DEPTHWISECONV flag, stride width, and a channel multiplier of 1.
 */
function fusedDepthwiseConv2D(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter, bias, preluActivationWeights } = inputs;
    const { strides, pad, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;
    const intermediates = [];
    let $dilations = dilations;
    if ($dilations == null) {
        $dilations = [1, 1];
    }
    assert$1(eitherStridesOrDilationsAreOne(strides, $dilations), () => 'Error in depthwiseConv2d: Either strides or dilations must be ' +
        `1. Got strides ${strides} and dilations '${$dilations}'`);
    const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true /* depthwise */);
    // Packed path requires the flag, small stride, and channel multiplier 1.
    const shouldPackDepthwiseConv = env().getBool('WEBGL_PACK_DEPTHWISECONV') &&
        convInfo.strideWidth <= 2 &&
        convInfo.outChannels / convInfo.inChannels === 1;
    // The activation snippet must match the packedness of the chosen program.
    const fusedActivation = activation ?
        mapActivationToShaderProgram(activation, shouldPackDepthwiseConv) :
        null;
    const programInputs = [x, filter];
    const hasBias = bias != null;
    const hasPreluActivationWeights = preluActivationWeights != null;
    const hasLeakyreluAlpha = activation === 'leakyrelu';
    if (hasBias) {
        programInputs.push(bias);
    }
    if (hasPreluActivationWeights) {
        programInputs.push(preluActivationWeights);
    }
    if (hasLeakyreluAlpha) {
        // leakyrelu alpha travels as a scalar tensor input to the shader.
        const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
        programInputs.push($leakyreluAlpha);
        intermediates.push($leakyreluAlpha);
    }
    let program;
    if (shouldPackDepthwiseConv) {
        program = new DepthwiseConvPacked2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
    }
    else {
        program = new DepthwiseConv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
    }
    // Conv geometry is passed as custom uniforms rather than compiled in.
    const customValues = [
        [convInfo.padInfo.top, convInfo.padInfo.left],
        [convInfo.strideHeight, convInfo.strideWidth],
        [convInfo.dilationHeight, convInfo.dilationWidth],
        [convInfo.inHeight, convInfo.inWidth]
    ];
    const result = backend.runWebGLProgram(program, programInputs, 'float32', customValues);
    intermediates.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return result;
}
/** WebGL kernel registration for FusedDepthwiseConv2D. */
const fusedDepthwiseConv2DConfig = {
    kernelName: FusedDepthwiseConv2D,
    backendName: 'webgl',
    kernelFunc: fusedDepthwiseConv2D,
};
97782
/**
 * GPGPU program implementing gatherNd over a 2-D flattened view.
 *
 * Inputs are 'x' (params flattened to [rows, sliceSize]) and 'indices'
 * (flattened to [numSlices, sliceDim]). For each output row, the shader
 * reads sliceDim index components, folds them into a flat row index using
 * `strides`, and bounds-checks each component against `paramsShape`;
 * out-of-bounds rows produce 0.0 instead of sampling.
 */
class GatherNDProgram {
    constructor(sliceDim, strides, shape, paramsShape) {
        this.sliceDim = sliceDim;
        this.strides = strides;
        this.paramsShape = paramsShape;
        this.variableNames = ['x', 'indices'];
        this.outputShape = shape;
        const dtype = getCoordsDataType(shape.length);
        // Unrolled loop: one index fetch + bounds check + stride multiply per
        // slice dimension, baked into the shader source.
        let mainLoop = `
    int index;`;
        for (let j = 0; j < this.sliceDim; j++) {
            mainLoop += `
          index = round(getIndices(coords[0], ${j}));
          out_of_bounds = out_of_bounds || index < 0;
          out_of_bounds = out_of_bounds || index >= ${this.paramsShape[j]};
          flattenIndex += index * ${this.strides[j]};`;
        }
        this.userCode = `
         void main() {
          ${dtype} coords = getOutputCoords();
          int flattenIndex = 0;
          bool out_of_bounds = false;

          ${mainLoop}

          setOutput(out_of_bounds ? 0.0 : getX(flattenIndex, coords[1]));
        }
      `;
    }
}
97813
97814 /**
97815 * @license
97816 * Copyright 2020 Google LLC. All Rights Reserved.
97817 * Licensed under the Apache License, Version 2.0 (the "License");
97818 * you may not use this file except in compliance with the License.
97819 * You may obtain a copy of the License at
97820 *
97821 * http://www.apache.org/licenses/LICENSE-2.0
97822 *
97823 * Unless required by applicable law or agreed to in writing, software
97824 * distributed under the License is distributed on an "AS IS" BASIS,
97825 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97826 * See the License for the specific language governing permissions and
97827 * limitations under the License.
97828 * =============================================================================
97829 */
/**
 * GatherNd kernel for the WebGL backend.
 *
 * Gathers slices of `params` at the multi-dimensional positions given by
 * `indices`, where the last dimension of `indices` addresses the leading
 * `sliceRank` dimensions of `params`. Small inputs and string tensors are
 * forwarded to the CPU implementation.
 */
function gatherNd(args) {
    const { inputs, backend } = args;
    const { params, indices } = inputs;
    const indicesShape = indices.shape;
    const sliceRank = indicesShape[indicesShape.length - 1];
    const paramsSize = sizeFromShape(params.shape);
    const [resultShape, numSlices, sliceSize, strides] = prepareAndValidate(params, indices);
    // CPU forwarding is decided before any GPU intermediates are created, so
    // this early return cannot leak reshape outputs. (Previously the two
    // flatten reshapes were created first and never disposed on this path.)
    if (backend.shouldExecuteOnCPU([params, indices]) ||
        params.dtype === 'string') {
        const indicesData = backend.readSync(indices.dataId);
        const paramsBuf = backend.bufferSync(params);
        const outValue = gatherNdImplCPU(indicesData, paramsBuf, params.dtype, numSlices, sliceRank, sliceSize, strides, params.shape, paramsSize);
        return backend.makeTensorInfo(resultShape, params.dtype, outValue.values);
    }
    // Flatten indices to [numSlices, sliceRank] and params to a 2-D
    // [rows, sliceSize] view as expected by GatherNDProgram.
    const flattenIndices = reshape({ inputs: { x: indices }, backend, attrs: { shape: [numSlices, sliceRank] } });
    const flattenX = reshape({
        inputs: { x: params },
        backend,
        attrs: { shape: [paramsSize / sliceSize, sliceSize] }
    });
    const program = new GatherNDProgram(sliceRank, strides, [numSlices, sliceSize], params.shape);
    const res = backend.runWebGLProgram(program, [flattenX, flattenIndices], flattenX.dtype);
    const reshaped = reshape({ inputs: { x: res }, backend, attrs: { shape: resultShape } });
    backend.disposeIntermediateTensorInfo(flattenIndices);
    backend.disposeIntermediateTensorInfo(flattenX);
    backend.disposeIntermediateTensorInfo(res);
    return reshaped;
}
/** WebGL kernel registration for GatherNd. */
const gatherNdConfig = {
    kernelName: GatherNd,
    backendName: 'webgl',
    kernelFunc: gatherNd
};
97863
97864 /**
97865 * @license
97866 * Copyright 2017 Google LLC. All Rights Reserved.
97867 * Licensed under the Apache License, Version 2.0 (the "License");
97868 * you may not use this file except in compliance with the License.
97869 * You may obtain a copy of the License at
97870 *
97871 * http://www.apache.org/licenses/LICENSE-2.0
97872 *
97873 * Unless required by applicable law or agreed to in writing, software
97874 * distributed under the License is distributed on an "AS IS" BASIS,
97875 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97876 * See the License for the specific language governing permissions and
97877 * limitations under the License.
97878 * =============================================================================
97879 */
/**
 * GPGPU program implementing gatherV2 over rank-4 flattened views.
 *
 * Gathers along axis 2 of the flattened input: for each output coordinate,
 * the index tensor is read at (batch, indexPosition) and substituted for the
 * axis-2 coordinate of the source lookup. Out-of-bounds indices (checked
 * against aShape[2]) yield 0.0 via the `inBounds` multiplier.
 */
class GatherProgram {
    constructor(aShape, outputShape) {
        this.variableNames = ['A', 'indices'];
        this.outputShape = outputShape;
        this.rank = outputShape.length;
        const dtype = getCoordsDataType(this.rank);
        const sourceCoords = getSourceCoords$1(aShape, 2);
        this.userCode = `
      void main() {
        ${dtype} resRC = getOutputCoords();
        int index = int(getIndices(resRC.x, resRC.z));
        float inBounds = (index >= 0) && (index < ${aShape[2]}) ? 1.0 : 0.0;
        setOutput(inBounds * getA(${sourceCoords}));
      }
    `;
    }
}
// The input and output are always flattened into rank 4 tensors.
/**
 * Builds the comma-joined GLSL coordinate list used to sample the source
 * tensor in GatherProgram: positional coords (resRC.x/y/z/w) for every
 * dimension except dimension 2, which is replaced by the gathered `index`.
 * (`axis` is accepted for interface compatibility; the substituted position
 * is fixed at 2.)
 */
function getSourceCoords$1(aShape, axis) {
    const positional = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
    return aShape
        .map((_, dim) => (dim === 2 ? 'index' : positional[dim]))
        .join();
}
97911
97912 /**
97913 * @license
97914 * Copyright 2020 Google LLC. All Rights Reserved.
97915 * Licensed under the Apache License, Version 2.0 (the "License");
97916 * you may not use this file except in compliance with the License.
97917 * You may obtain a copy of the License at
97918 *
97919 * http://www.apache.org/licenses/LICENSE-2.0
97920 *
97921 * Unless required by applicable law or agreed to in writing, software
97922 * distributed under the License is distributed on an "AS IS" BASIS,
97923 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97924 * See the License for the specific language governing permissions and
97925 * limitations under the License.
97926 * =============================================================================
97927 */
/**
 * GatherV2 kernel for the WebGL backend.
 *
 * Gathers slices of `x` along `axis` at the positions in `indices`, with
 * optional leading `batchDims`. Both operands are flattened to rank-4 views
 * [batch, outer, dim/indices, slice] before dispatch, then the result is
 * reshaped to the canonical output shape.
 */
function gatherV2(args) {
    const { inputs, backend, attrs } = args;
    const { x, indices } = inputs;
    const { axis, batchDims } = attrs;
    const parsedAxis = parseAxisParam(axis, x.shape)[0];
    if (env().get('DEBUG')) {
        // In debug mode, throw error when any index is out of bound.
        // Otherwise, just fill out of bounds with zeroes.
        // (readSync forces a GPU download, so this check is debug-only.)
        const indicesVals = backend.readSync(indices.dataId);
        const axisDim = x.shape[parsedAxis];
        for (let i = 0; i < indicesVals.length; ++i) {
            const index = indicesVals[i];
            assert$1(index <= axisDim - 1 && index >= 0, () => `GatherV2: the index value ${index} is not in [0, ${axisDim - 1}]`);
        }
    }
    const shapeInfo = collectGatherOpShapeInfo(x, indices, parsedAxis, batchDims);
    const indicesSize = sizeFromShape(indices.shape);
    // Intermediates collected here are disposed on every exit path below.
    const toDispose = [];
    const flattenX = reshape({
        inputs: { x },
        backend,
        attrs: {
            shape: [
                shapeInfo.batchSize, shapeInfo.outerSize, shapeInfo.dimSize,
                shapeInfo.sliceSize
            ]
        }
    });
    const flattenIndex = reshape({
        inputs: { x: indices },
        backend,
        attrs: { shape: [shapeInfo.batchSize, indicesSize / shapeInfo.batchSize] }
    });
    toDispose.push(flattenX);
    toDispose.push(flattenIndex);
    const flattenOutputShape = [
        shapeInfo.batchSize, shapeInfo.outerSize, indicesSize / shapeInfo.batchSize,
        shapeInfo.sliceSize
    ];
    // CPU forwarding: small tensors and string dtype run on the CPU impl.
    if (backend.shouldExecuteOnCPU([x, indices]) || x.dtype === 'string') {
        const indicesBuf = backend.bufferSync(flattenIndex);
        const xBuf = backend.bufferSync(flattenX);
        const outBuf = gatherV2ImplCPU(xBuf, indicesBuf, flattenOutputShape);
        toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
        return backend.makeTensorInfo(shapeInfo.outputShape, outBuf.dtype, outBuf.values);
    }
    const program = new GatherProgram(flattenX.shape, flattenOutputShape);
    const res = backend.runWebGLProgram(program, [flattenX, flattenIndex], flattenX.dtype);
    toDispose.push(res);
    const reshaped = reshape({ inputs: { x: res }, backend, attrs: { shape: shapeInfo.outputShape } });
    toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return reshaped;
}
/** WebGL kernel registration for GatherV2. */
const gatherV2Config = {
    kernelName: GatherV2,
    backendName: 'webgl',
    kernelFunc: gatherV2
};
97986
97987 /**
97988 * @license
97989 * Copyright 2020 Google LLC. All Rights Reserved.
97990 * Licensed under the Apache License, Version 2.0 (the "License");
97991 * you may not use this file except in compliance with the License.
97992 * You may obtain a copy of the License at
97993 *
97994 * http://www.apache.org/licenses/LICENSE-2.0
97995 *
97996 * Unless required by applicable law or agreed to in writing, software
97997 * distributed under the License is distributed on an "AS IS" BASIS,
97998 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
97999 * See the License for the specific language governing permissions and
98000 * limitations under the License.
98001 * =============================================================================
98002 */
98003 const GREATER = `return float(a > b);`;
98004 const GREATER_PACKED = `
98005 return vec4(greaterThan(a, b));
98006`;
98007 const greater = binaryKernelFunc({
98008 opSnippet: GREATER,
98009 packedOpSnippet: GREATER_PACKED,
98010 cpuKernelImpl: greaterImplCPU,
98011 dtype: 'bool'
98012 });
98013 const greaterConfig = {
98014 kernelName: Greater,
98015 backendName: 'webgl',
98016 kernelFunc: greater
98017 };
98018
98019 /**
98020 * @license
98021 * Copyright 2020 Google LLC. All Rights Reserved.
98022 * Licensed under the Apache License, Version 2.0 (the "License");
98023 * you may not use this file except in compliance with the License.
98024 * You may obtain a copy of the License at
98025 *
98026 * http://www.apache.org/licenses/LICENSE-2.0
98027 *
98028 * Unless required by applicable law or agreed to in writing, software
98029 * distributed under the License is distributed on an "AS IS" BASIS,
98030 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98031 * See the License for the specific language governing permissions and
98032 * limitations under the License.
98033 * =============================================================================
98034 */
98035 const GREATER_EQUAL = `return float(a >= b);`;
98036 const GREATER_EQUAL_PACKED = `
98037 return vec4(greaterThanEqual(a, b));
98038`;
98039 const greaterEqual = binaryKernelFunc({
98040 opSnippet: GREATER_EQUAL,
98041 packedOpSnippet: GREATER_EQUAL_PACKED,
98042 dtype: 'bool',
98043 cpuKernelImpl: greaterEqualImplCPU
98044 });
98045 const greaterEqualConfig = {
98046 kernelName: GreaterEqual,
98047 backendName: 'webgl',
98048 kernelFunc: greaterEqual
98049 };
98050
98051 /**
98052 * @license
98053 * Copyright 2020 Google LLC. All Rights Reserved.
98054 * Licensed under the Apache License, Version 2.0 (the "License");
98055 * you may not use this file except in compliance with the License.
98056 * You may obtain a copy of the License at
98057 *
98058 * http://www.apache.org/licenses/LICENSE-2.0
98059 *
98060 * Unless required by applicable law or agreed to in writing, software
98061 * distributed under the License is distributed on an "AS IS" BASIS,
98062 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98063 * See the License for the specific language governing permissions and
98064 * limitations under the License.
98065 * =============================================================================
98066 */
98067 function ifft(args) {
98068 const { inputs, backend } = args;
98069 const { input } = inputs;
98070 return fftImpl(input, true /* inverse */, backend);
98071 }
98072 const ifftConfig = {
98073 kernelName: IFFT,
98074 backendName: 'webgl',
98075 kernelFunc: ifft
98076 };
98077
98078 /**
98079 * @license
98080 * Copyright 2020 Google LLC. All Rights Reserved.
98081 * Licensed under the Apache License, Version 2.0 (the "License");
98082 * you may not use this file except in compliance with the License.
98083 * You may obtain a copy of the License at
98084 *
98085 * http://www.apache.org/licenses/LICENSE-2.0
98086 *
98087 * Unless required by applicable law or agreed to in writing, software
98088 * distributed under the License is distributed on an "AS IS" BASIS,
98089 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98090 * See the License for the specific language governing permissions and
98091 * limitations under the License.
98092 * =============================================================================
98093 */
98094 const IS_FINITE = `return float(!isnan(x) && !isinf(x));`;
98095 const isFinite$1 = unaryKernelFunc({ opSnippet: IS_FINITE, dtype: 'bool' });
98096 const isFiniteConfig = {
98097 kernelName: IsFinite,
98098 backendName: 'webgl',
98099 kernelFunc: isFinite$1,
98100 };
98101
98102 /**
98103 * @license
98104 * Copyright 2020 Google LLC. All Rights Reserved.
98105 * Licensed under the Apache License, Version 2.0 (the "License");
98106 * you may not use this file except in compliance with the License.
98107 * You may obtain a copy of the License at
98108 *
98109 * http://www.apache.org/licenses/LICENSE-2.0
98110 *
98111 * Unless required by applicable law or agreed to in writing, software
98112 * distributed under the License is distributed on an "AS IS" BASIS,
98113 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98114 * See the License for the specific language governing permissions and
98115 * limitations under the License.
98116 * =============================================================================
98117 */
98118 const IS_INF = `return float(isinf(x));`;
98119 const isInf = unaryKernelFunc({ opSnippet: IS_INF, dtype: 'bool' });
98120 const isInfConfig = {
98121 kernelName: IsInf,
98122 backendName: 'webgl',
98123 kernelFunc: isInf,
98124 };
98125
98126 /**
98127 * @license
98128 * Copyright 2020 Google LLC. All Rights Reserved.
98129 * Licensed under the Apache License, Version 2.0 (the "License");
98130 * you may not use this file except in compliance with the License.
98131 * You may obtain a copy of the License at
98132 *
98133 * http://www.apache.org/licenses/LICENSE-2.0
98134 *
98135 * Unless required by applicable law or agreed to in writing, software
98136 * distributed under the License is distributed on an "AS IS" BASIS,
98137 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98138 * See the License for the specific language governing permissions and
98139 * limitations under the License.
98140 * =============================================================================
98141 */
98142 const IS_NAN = `return float(isnan(x));`;
98143 const isNaN$1 = unaryKernelFunc({ opSnippet: IS_NAN, dtype: 'bool' });
98144 const isNaNConfig = {
98145 kernelName: IsNan,
98146 backendName: 'webgl',
98147 kernelFunc: isNaN$1,
98148 };
98149
98150 /**
98151 * @license
98152 * Copyright 2020 Google LLC. All Rights Reserved.
98153 * Licensed under the Apache License, Version 2.0 (the "License");
98154 * you may not use this file except in compliance with the License.
98155 * You may obtain a copy of the License at
98156 *
98157 * http://www.apache.org/licenses/LICENSE-2.0
98158 *
98159 * Unless required by applicable law or agreed to in writing, software
98160 * distributed under the License is distributed on an "AS IS" BASIS,
98161 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98162 * See the License for the specific language governing permissions and
98163 * limitations under the License.
98164 * =============================================================================
98165 */
98166 const LESS = `return float(a < b);`;
98167 const LESS_PACKED = `
98168 return vec4(lessThan(a, b));
98169`;
98170 const less = binaryKernelFunc({
98171 opSnippet: LESS,
98172 packedOpSnippet: LESS_PACKED,
98173 cpuKernelImpl: lessImplCPU,
98174 dtype: 'bool'
98175 });
98176 const lessConfig = {
98177 kernelName: Less,
98178 backendName: 'webgl',
98179 kernelFunc: less
98180 };
98181
98182 /**
98183 * @license
98184 * Copyright 2020 Google LLC. All Rights Reserved.
98185 * Licensed under the Apache License, Version 2.0 (the "License");
98186 * you may not use this file except in compliance with the License.
98187 * You may obtain a copy of the License at
98188 *
98189 * http://www.apache.org/licenses/LICENSE-2.0
98190 *
98191 * Unless required by applicable law or agreed to in writing, software
98192 * distributed under the License is distributed on an "AS IS" BASIS,
98193 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98194 * See the License for the specific language governing permissions and
98195 * limitations under the License.
98196 * =============================================================================
98197 */
98198 const LESS_EQUAL = `return float(a <= b);`;
98199 const LESS_EQUAL_PACKED = `
98200 return vec4(lessThanEqual(a, b));
98201`;
98202 const lessEqual = binaryKernelFunc({
98203 opSnippet: LESS_EQUAL,
98204 packedOpSnippet: LESS_EQUAL_PACKED,
98205 cpuKernelImpl: lessEqualImplCPU,
98206 dtype: 'bool'
98207 });
98208 const lessEqualConfig = {
98209 kernelName: LessEqual,
98210 backendName: 'webgl',
98211 kernelFunc: lessEqual
98212 };
98213
98214 /**
98215 * @license
98216 * Copyright 2020 Google LLC. All Rights Reserved.
98217 * Licensed under the Apache License, Version 2.0 (the "License");
98218 * you may not use this file except in compliance with the License.
98219 * You may obtain a copy of the License at
98220 *
98221 * http://www.apache.org/licenses/LICENSE-2.0
98222 *
98223 * Unless required by applicable law or agreed to in writing, software
98224 * distributed under the License is distributed on an "AS IS" BASIS,
98225 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98226 * See the License for the specific language governing permissions and
98227 * limitations under the License.
98228 * =============================================================================
98229 */
98230 function linSpace(args) {
98231 const { backend, attrs } = args;
98232 const { start, stop, num } = attrs;
98233 // TODO: Use CPU implementation due to the precision problem in Safari.
98234 const outVals = linSpaceImplCPU(start, stop, num);
98235 return backend.makeTensorInfo([outVals.length], 'float32', outVals);
98236 }
98237 const linSpaceConfig = {
98238 kernelName: LinSpace,
98239 backendName: 'webgl',
98240 kernelFunc: linSpace
98241 };
98242
98243 /**
98244 * @license
98245 * Copyright 2020 Google LLC. All Rights Reserved.
98246 * Licensed under the Apache License, Version 2.0 (the "License");
98247 * you may not use this file except in compliance with the License.
98248 * You may obtain a copy of the License at
98249 *
98250 * http://www.apache.org/licenses/LICENSE-2.0
98251 *
98252 * Unless required by applicable law or agreed to in writing, software
98253 * distributed under the License is distributed on an "AS IS" BASIS,
98254 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98255 * See the License for the specific language governing permissions and
98256 * limitations under the License.
98257 * =============================================================================
98258 */
98259 // Windows chrome return 0 if the input is negative value. We will specifically
98260 // return NaN if the input is 0 to solve compatiblity issue.
98261 const LOG = CHECK_NAN_SNIPPET_UNARY + `
98262 return x < 0.0 ? 0./0. : log(x);
98263`;
98264 const LOG_PACKED = `
98265 vec4 result = log(x);
98266 bvec4 isNaN = isnan(x);
98267 result.r = isNaN.r ? x.r : (x.r < 0.0 ? 0./0. : result.r);
98268 result.g = isNaN.g ? x.g : (x.g < 0.0 ? 0./0. : result.g);
98269 result.b = isNaN.b ? x.b : (x.b < 0.0 ? 0./0. : result.b);
98270 result.a = isNaN.a ? x.a : (x.a < 0.0 ? 0./0. : result.a);
98271 return result;
98272`;
98273 const log = unaryKernelFunc({ opSnippet: LOG, packedOpSnippet: LOG_PACKED, cpuKernelImpl: logImplCPU });
98274 const logConfig = {
98275 kernelName: Log,
98276 backendName: 'webgl',
98277 kernelFunc: log
98278 };
98279
98280 /**
98281 * @license
98282 * Copyright 2020 Google LLC. All Rights Reserved.
98283 * Licensed under the Apache License, Version 2.0 (the "License");
98284 * you may not use this file except in compliance with the License.
98285 * You may obtain a copy of the License at
98286 *
98287 * http://www.apache.org/licenses/LICENSE-2.0
98288 *
98289 * Unless required by applicable law or agreed to in writing, software
98290 * distributed under the License is distributed on an "AS IS" BASIS,
98291 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98292 * See the License for the specific language governing permissions and
98293 * limitations under the License.
98294 * =============================================================================
98295 */
98296 const LOG1P = CHECK_NAN_SNIPPET_UNARY + `
98297 return log(1.0 + x);
98298`;
98299 const log1p = unaryKernelFunc({ opSnippet: LOG1P });
98300 const log1pConfig = {
98301 kernelName: Log1p,
98302 backendName: 'webgl',
98303 kernelFunc: log1p,
98304 };
98305
98306 /**
98307 * @license
98308 * Copyright 2020 Google LLC. All Rights Reserved.
98309 * Licensed under the Apache License, Version 2.0 (the "License");
98310 * you may not use this file except in compliance with the License.
98311 * You may obtain a copy of the License at
98312 *
98313 * http://www.apache.org/licenses/LICENSE-2.0
98314 *
98315 * Unless required by applicable law or agreed to in writing, software
98316 * distributed under the License is distributed on an "AS IS" BASIS,
98317 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98318 * See the License for the specific language governing permissions and
98319 * limitations under the License.
98320 * =============================================================================
98321 */
98322 const LOGICAL_AND = `return float(a >= 1.0 && b >= 1.0);`;
98323 const LOGICAL_AND_PACKED = `
98324 return vec4(
98325 vec4(greaterThanEqual(a, vec4(1.0))) *
98326 vec4(greaterThanEqual(b, vec4(1.0))));
98327`;
98328 const logicalAnd = binaryKernelFunc({
98329 opSnippet: LOGICAL_AND,
98330 packedOpSnippet: LOGICAL_AND_PACKED,
98331 dtype: 'bool'
98332 });
98333 const logicalAndConfig = {
98334 kernelName: LogicalAnd,
98335 backendName: 'webgl',
98336 kernelFunc: logicalAnd
98337 };
98338
98339 /**
98340 * @license
98341 * Copyright 2020 Google LLC. All Rights Reserved.
98342 * Licensed under the Apache License, Version 2.0 (the "License");
98343 * you may not use this file except in compliance with the License.
98344 * You may obtain a copy of the License at
98345 *
98346 * http://www.apache.org/licenses/LICENSE-2.0
98347 *
98348 * Unless required by applicable law or agreed to in writing, software
98349 * distributed under the License is distributed on an "AS IS" BASIS,
98350 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98351 * See the License for the specific language governing permissions and
98352 * limitations under the License.
98353 * =============================================================================
98354 */
98355 const LOGICAL_NOT = `return float(!(x >= 1.0));`;
98356 const logicalNot = unaryKernelFunc({ opSnippet: LOGICAL_NOT });
98357 const logicalNotConfig = {
98358 kernelName: LogicalNot,
98359 backendName: 'webgl',
98360 kernelFunc: logicalNot,
98361 };
98362
98363 /**
98364 * @license
98365 * Copyright 2020 Google LLC. All Rights Reserved.
98366 * Licensed under the Apache License, Version 2.0 (the "License");
98367 * you may not use this file except in compliance with the License.
98368 * You may obtain a copy of the License at
98369 *
98370 * http://www.apache.org/licenses/LICENSE-2.0
98371 *
98372 * Unless required by applicable law or agreed to in writing, software
98373 * distributed under the License is distributed on an "AS IS" BASIS,
98374 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98375 * See the License for the specific language governing permissions and
98376 * limitations under the License.
98377 * =============================================================================
98378 */
98379 const LOGICAL_OR = `return float(a >= 1.0 || b >= 1.0);`;
98380 const LOGICAL_OR_PACKED = `
98381 return min(
98382 vec4(greaterThanEqual(a, vec4(1.0))) +
98383 vec4(greaterThanEqual(b, vec4(1.0))),
98384 vec4(1.0));
98385`;
98386 const logicalOr = binaryKernelFunc({ opSnippet: LOGICAL_OR, packedOpSnippet: LOGICAL_OR_PACKED, dtype: 'bool' });
98387 const logicalOrConfig = {
98388 kernelName: LogicalOr,
98389 backendName: 'webgl',
98390 kernelFunc: logicalOr
98391 };
98392
98393 /**
98394 * @license
98395 * Copyright 2017 Google LLC. All Rights Reserved.
98396 * Licensed under the Apache License, Version 2.0 (the "License");
98397 * you may not use this file except in compliance with the License.
98398 * You may obtain a copy of the License at
98399 *
98400 * http://www.apache.org/licenses/LICENSE-2.0
98401 *
98402 * Unless required by applicable law or agreed to in writing, software
98403 * distributed under the License is distributed on an "AS IS" BASIS,
98404 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98405 * See the License for the specific language governing permissions and
98406 * limitations under the License.
98407 * =============================================================================
98408 */
98409 class LRNProgram {
98410 constructor(xShape, radius, bias, alpha, beta) {
98411 this.variableNames = ['x'];
98412 this.outputShape = [];
98413 const rad = radius;
98414 const maxD = xShape[3] - 1;
98415 this.outputShape = xShape;
98416 // optimize pow(bias + alpha * sum, -beta)
98417 // src: https://github.com/tensorflow/tensorflow/..
98418 // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/..
98419 // tensorflow/core/kernels/mkl_lrn_op.cc#L320
98420 let powOperator;
98421 const basis = `float(${bias}) + float(${alpha}) * sum`;
98422 if (beta === 0.5) {
98423 powOperator = `inversesqrt(${basis})`;
98424 }
98425 else if (beta === 1.0) {
98426 powOperator = `1.0/(${basis})`;
98427 }
98428 else {
98429 powOperator = `exp(log(${basis}) * float(-${beta}));`;
98430 }
98431 this.userCode = `
98432 void main() {
98433 ivec4 coords = getOutputCoords();
98434 int b = coords[0];
98435 int r = coords[1];
98436 int c = coords[2];
98437 int d = coords[3];
98438 float x = getX(b, r, c, d);
98439 float sum = 0.0;
98440 for (int j = -${rad}; j <= ${rad}; j++) {
98441 int idx = d + j;
98442 if (idx >= 0 && idx <= ${maxD}) {
98443 float z = getX(b, r, c, idx);
98444 sum += z * z;
98445 }
98446 }
98447 float val = x * ${powOperator};
98448 setOutput(val);
98449 }
98450 `;
98451 }
98452 }
98453
98454 /**
98455 * @license
98456 * Copyright 2019 Google LLC. All Rights Reserved.
98457 * Licensed under the Apache License, Version 2.0 (the "License");
98458 * you may not use this file except in compliance with the License.
98459 * You may obtain a copy of the License at
98460 *
98461 * http://www.apache.org/licenses/LICENSE-2.0
98462 *
98463 * Unless required by applicable law or agreed to in writing, software
98464 * distributed under the License is distributed on an "AS IS" BASIS,
98465 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98466 * See the License for the specific language governing permissions and
98467 * limitations under the License.
98468 * =============================================================================
98469 */
98470 class LRNPackedProgram {
98471 constructor(xShape, radius, bias, alpha, beta) {
98472 this.variableNames = ['x'];
98473 this.outputShape = [];
98474 this.packedInputs = true;
98475 this.packedOutput = true;
98476 const rad = radius;
98477 const maxD = xShape[3] - 1;
98478 this.outputShape = xShape;
98479 // optimize pow(bias + alpha * sum, -beta)
98480 // src: https://github.com/tensorflow/tensorflow/..
98481 // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/..
98482 // tensorflow/core/kernels/mkl_lrn_op.cc#L320
98483 let powOperator;
98484 const basis = `float(${bias}) + float(${alpha}) * sum`;
98485 if (beta === 0.5) {
98486 powOperator = `inversesqrt(${basis})`;
98487 }
98488 else if (beta === 1.0) {
98489 powOperator = `1.0/(${basis})`;
98490 }
98491 else {
98492 powOperator = `exp(log(${basis}) * float(-${beta}));`;
98493 }
98494 this.userCode = `
98495 void main() {
98496 ivec4 coords = getOutputCoords();
98497 int b = coords.x;
98498 int r = coords.y;
98499 int c = coords.z;
98500 int d = coords.w;
98501
98502 bool hasNextCol = d < ${this.outputShape[3]};
98503 bool hasNextRow = c < ${this.outputShape[2]};
98504
98505 vec4 sum = vec4(0.);
98506 vec4 xFragAtOutputCoords = getX(b, r, c, d);
98507
98508 vec4 xAtOutputCoords = vec4(
98509 getChannel(xFragAtOutputCoords, vec2(c, d)),
98510 hasNextCol ?
98511 getChannel(xFragAtOutputCoords, vec2(c, d + 1)) : 0.0,
98512 hasNextRow ?
98513 getChannel(xFragAtOutputCoords , vec2(c + 1, d)) : 0.0,
98514 (hasNextRow && hasNextCol) ?
98515 getChannel(xFragAtOutputCoords, vec2(c + 1, d + 1)) : 0.0
98516 );
98517
98518 int firstChannel = d - ${rad};
98519 vec2 cache = vec2(0.);
98520 if(firstChannel >= 0){
98521 vec4 firstChannelFrag = getX(b, r, c, firstChannel);
98522 cache.x = getChannel(firstChannelFrag, vec2(c, firstChannel));
98523 if(hasNextRow){
98524 cache.y = getChannel(firstChannelFrag, vec2(c + 1, firstChannel));
98525 }
98526 }
98527
98528 ivec2 depth = ivec2(d, d + 1);
98529 for (int j = - ${rad}; j <= ${rad}; j++) {
98530 ivec2 idx = depth + j;
98531 bvec2 aboveLowerBound = greaterThanEqual(idx, ivec2(0));
98532 bvec2 belowUpperBound = lessThanEqual(idx, ivec2(${maxD}));
98533
98534 bool depthInRange = aboveLowerBound.x && belowUpperBound.x;
98535 bool depthPlusOneInRange = aboveLowerBound.y && belowUpperBound.y;
98536
98537 if(depthInRange || depthPlusOneInRange){
98538 vec4 z = vec4(0.);
98539 vec4 xFragAtCurrentDepth;
98540 z.xz = cache.xy;
98541 if(depthPlusOneInRange && hasNextCol){
98542 xFragAtCurrentDepth = idx.y != d ?
98543 getX(b, r, c, idx.y) : xFragAtOutputCoords;
98544 z.y = getChannel(xFragAtCurrentDepth, vec2(c, idx.y));
98545 if(hasNextRow){
98546 z.w = getChannel(xFragAtCurrentDepth, vec2(c + 1, idx.y));
98547 }
98548 }
98549 cache.xy = z.yw;
98550 sum += z * z;
98551 }
98552 }
98553 vec4 result = xAtOutputCoords * ${powOperator};
98554 setOutput(result);
98555 }
98556 `;
98557 }
98558 }
98559
98560 /**
98561 * @license
98562 * Copyright 2020 Google LLC. All Rights Reserved.
98563 * Licensed under the Apache License, Version 2.0 (the "License");
98564 * you may not use this file except in compliance with the License.
98565 * You may obtain a copy of the License at
98566 *
98567 * http://www.apache.org/licenses/LICENSE-2.0
98568 *
98569 * Unless required by applicable law or agreed to in writing, software
98570 * distributed under the License is distributed on an "AS IS" BASIS,
98571 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98572 * See the License for the specific language governing permissions and
98573 * limitations under the License.
98574 * =============================================================================
98575 */
98576 const lrn = (args) => {
98577 const { inputs, backend, attrs } = args;
98578 const { x } = inputs;
98579 const { depthRadius, bias, alpha, beta } = attrs;
98580 const program = env().getBool('WEBGL_PACK_NORMALIZATION') ?
98581 new LRNPackedProgram(x.shape, depthRadius, bias, alpha, beta) :
98582 new LRNProgram(x.shape, depthRadius, bias, alpha, beta);
98583 return backend.runWebGLProgram(program, [x], x.dtype);
98584 };
98585 // tslint:disable-next-line: variable-name
98586 const LRNConfig = {
98587 kernelName: LRN,
98588 backendName: 'webgl',
98589 kernelFunc: lrn
98590 };
98591
98592 /**
98593 * @license
98594 * Copyright 2018 Google LLC. All Rights Reserved.
98595 * Licensed under the Apache License, Version 2.0 (the "License");
98596 * you may not use this file except in compliance with the License.
98597 * You may obtain a copy of the License at
98598 *
98599 * http://www.apache.org/licenses/LICENSE-2.0
98600 *
98601 * Unless required by applicable law or agreed to in writing, software
98602 * distributed under the License is distributed on an "AS IS" BASIS,
98603 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98604 * See the License for the specific language governing permissions and
98605 * limitations under the License.
98606 * =============================================================================
98607 */
98608 class LRNGradProgram {
98609 constructor(inputShape, depthRadius, bias, alpha, beta) {
98610 this.variableNames = ['inputImage', 'outputImage', 'dy'];
98611 this.outputShape = [];
98612 this.outputShape = inputShape;
98613 this.depth = inputShape[3];
98614 this.depthRadius = depthRadius;
98615 this.bias = bias;
98616 this.alpha = alpha;
98617 this.beta = beta;
98618 this.userCode = `
98619 void main() {
98620 ivec4 coords = getOutputCoords();
98621 int b = coords[0];
98622 int r = coords[1];
98623 int c = coords[2];
98624
98625 float result = 0.0;
98626 for (int d = 0; d < ${this.depth}; ++d) {
98627 int depthBegin = int(max(0.0, float(d - ${depthRadius})));
98628 int depthEnd = int(min(float(${this.depth}),
98629 float(d + ${depthRadius} + 1)));
98630
98631 const int MIN_DEPTH_BEGIN = 0;
98632 const int MAX_DEPTH_END = ${this.depth};
98633
98634 float norm = 0.0;
98635 for (int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k) {
98636 if (k < depthBegin){
98637 continue;
98638 }
98639 else if (k >= depthBegin && k < depthEnd) {
98640 norm += getInputImage(b, r, c, k) * getInputImage(b, r, c, k);
98641 }
98642 else {
98643 break;
98644 }
98645 }
98646
98647 norm = float(${alpha}) * norm + float(${bias});
98648
98649 for(int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k){
98650 if (k < depthBegin){
98651 continue;
98652 }
98653 else if (k >= depthBegin && k < depthEnd){
98654 float dyi = -2.0 * float(${alpha})
98655 * float(${beta})
98656 * getInputImage(b, r, c, k) * getOutputImage(b, r, c, d)
98657 / norm;
98658 if (k == d) {
98659 dyi += pow(norm, -1.0 * ${beta});
98660 }
98661 if (k == coords[3]) {
98662 dyi *= getDy(b, r, c, d);
98663 result += dyi;
98664 }
98665 }
98666 else {
98667 break;
98668 }
98669 }
98670 }
98671 setOutput(result);
98672 }
98673 `;
98674 }
98675 }
98676
98677 /**
98678 * @license
98679 * Copyright 2020 Google LLC. All Rights Reserved.
98680 * Licensed under the Apache License, Version 2.0 (the "License");
98681 * you may not use this file except in compliance with the License.
98682 * You may obtain a copy of the License at
98683 *
98684 * http://www.apache.org/licenses/LICENSE-2.0
98685 *
98686 * Unless required by applicable law or agreed to in writing, software
98687 * distributed under the License is distributed on an "AS IS" BASIS,
98688 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98689 * See the License for the specific language governing permissions and
98690 * limitations under the License.
98691 * =============================================================================
98692 */
98693 const lrnGrad = (args) => {
98694 const { inputs, backend, attrs } = args;
98695 const { x, y, dy } = inputs;
98696 const { depthRadius, bias, alpha, beta } = attrs;
98697 const program = new LRNGradProgram(x.shape, depthRadius, bias, alpha, beta);
98698 return backend.runWebGLProgram(program, [x, y, dy], x.dtype);
98699 };
98700 // tslint:disable-next-line: variable-name
98701 const LRNGradConfig = {
98702 kernelName: LRNGrad,
98703 backendName: 'webgl',
98704 kernelFunc: lrnGrad
98705 };
98706
98707 /**
98708 * @license
98709 * Copyright 2020 Google LLC. All Rights Reserved.
98710 * Licensed under the Apache License, Version 2.0 (the "License");
98711 * you may not use this file except in compliance with the License.
98712 * You may obtain a copy of the License at
98713 *
98714 * http://www.apache.org/licenses/LICENSE-2.0
98715 *
98716 * Unless required by applicable law or agreed to in writing, software
98717 * distributed under the License is distributed on an "AS IS" BASIS,
98718 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98719 * See the License for the specific language governing permissions and
98720 * limitations under the License.
98721 * =============================================================================
98722 */
98723 function maxImpl(x, reduceShape, outShape, backend) {
98724 const inSize = sizeFromShape(reduceShape);
98725 const xSize = sizeFromShape(x.shape);
98726 const batchSize = xSize / inSize;
98727 const reshapedInput = reshape({ inputs: { x }, attrs: { shape: [batchSize, inSize] }, backend });
98728 const reduced = reduce(reshapedInput, x.dtype, 'max', backend);
98729 const reshapedOutput = reshape({ inputs: { x: reduced }, attrs: { shape: outShape }, backend });
98730 backend.disposeIntermediateTensorInfo(reshapedInput);
98731 backend.disposeIntermediateTensorInfo(reduced);
98732 return reshapedOutput;
98733 }
98734
98735 /**
98736 * @license
98737 * Copyright 2020 Google LLC. All Rights Reserved.
98738 * Licensed under the Apache License, Version 2.0 (the "License");
98739 * you may not use this file except in compliance with the License.
98740 * You may obtain a copy of the License at
98741 *
98742 * http://www.apache.org/licenses/LICENSE-2.0
98743 *
98744 * Unless required by applicable law or agreed to in writing, software
98745 * distributed under the License is distributed on an "AS IS" BASIS,
98746 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98747 * See the License for the specific language governing permissions and
98748 * limitations under the License.
98749 * =============================================================================
98750 */
    /**
     * Max kernel for the WebGL backend.
     *
     * Reduces `x` over `attrs.reductionIndices`. If the reduced axes are not
     * already innermost, the input is first transposed so they are; the
     * reduction itself runs either on the CPU (when the backend elects to)
     * or via the WebGL reduce path in `maxImpl`.
     */
    function max(args) {
        const { inputs, backend, attrs } = args;
        const { x } = inputs;
        const { reductionIndices, keepDims } = attrs;
        const xRank = x.shape.length;
        const origAxes = parseAxisParam(reductionIndices, x.shape);
        let axes = origAxes;
        // Non-null permutation means the axes must be moved innermost first.
        const permutedAxes = getAxesPermutation(axes, xRank);
        const maxInputIsTransposed = permutedAxes != null;
        const shouldExecuteOnCPU = backend.shouldExecuteOnCPU([x]);
        let maxInput = x;
        if (maxInputIsTransposed) {
            if (shouldExecuteOnCPU) {
                // CPU path: transpose the raw values directly and stash them in
                // a fresh tensor info, avoiding a GPU round-trip.
                const xTexData = backend.texData.get(maxInput.dataId);
                const values = xTexData.values;
                const newShape = new Array(xRank);
                for (let i = 0; i < newShape.length; i++) {
                    newShape[i] = x.shape[permutedAxes[i]];
                }
                const maxInputValues = transposeImplCPU(values, x.shape, x.dtype, permutedAxes, newShape);
                maxInput = backend.makeTensorInfo(newShape, x.dtype);
                const maxInputData = backend.texData.get(maxInput.dataId);
                maxInputData.values = maxInputValues;
            }
            else {
                maxInput = transposeImpl(x, permutedAxes, backend);
            }
            // After the transpose the reduced axes are the innermost ones.
            axes = getInnerMostAxes(axes.length, xRank);
        }
        assertAxesAreInnerMostDims('max', axes, xRank);
        const [maxOutShape, reduceShape] = computeOutAndReduceShapes(maxInput.shape, axes);
        let outShape = maxOutShape;
        if (keepDims) {
            // rather than reshape at the end, set the target shape here.
            outShape = expandShapeToKeepDim(maxOutShape, origAxes);
        }
        let out;
        if (shouldExecuteOnCPU) {
            // CPU reduction over the (possibly transposed) values.
            const xTexData = backend.texData.get(maxInput.dataId);
            const values = xTexData.values;
            const outValues = maxImplCPU(values, sizeFromShape(reduceShape), outShape, x.dtype);
            out = backend.makeTensorInfo(outShape, x.dtype);
            const outData = backend.texData.get(out.dataId);
            outData.values = outValues;
        }
        else {
            out = maxImpl(maxInput, reduceShape, outShape, backend);
        }
        if (maxInputIsTransposed) {
            // The transposed copy was an intermediate; release it.
            backend.disposeIntermediateTensorInfo(maxInput);
        }
        return out;
    }
    /** Registration entry binding the Max kernel to the WebGL backend. */
    const maxConfig = {
        kernelName: Max,
        backendName: 'webgl',
        kernelFunc: max
    };
98809
98810 /**
98811 * @license
98812 * Copyright 2020 Google LLC. All Rights Reserved.
98813 * Licensed under the Apache License, Version 2.0 (the "License");
98814 * you may not use this file except in compliance with the License.
98815 * You may obtain a copy of the License at
98816 *
98817 * http://www.apache.org/licenses/LICENSE-2.0
98818 *
98819 * Unless required by applicable law or agreed to in writing, software
98820 * distributed under the License is distributed on an "AS IS" BASIS,
98821 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98822 * See the License for the specific language governing permissions and
98823 * limitations under the License.
98824 * =============================================================================
98825 */
// GLSL snippet for unpacked element-wise maximum; CHECK_NAN_SNIPPET makes the
// result NaN when either operand is NaN (WebGL1's max() otherwise ignores NaN).
98826 const MAXIMUM = CHECK_NAN_SNIPPET + `
98827 return max(a, b);
98828`;
// Packed (vec4, 4-lanes-at-a-time) variant: computes max per lane, then uses
// CHECK_NAN_SNIPPET_PACKED with the per-lane `isNaN` mask to propagate NaNs.
98829 const MAXIMUM_PACKED = `
98830 vec4 result = vec4(max(a, b));
98831 bvec4 isNaNA = isnan(a);
98832 bvec4 isNaNB = isnan(b);
98833 bvec4 isNaN = bvec4(isNaNA.x || isNaNB.x, isNaNA.y || isNaNB.y, isNaNA.z || isNaNB.z, isNaNA.w || isNaNB.w);
98834 ` +
98835 CHECK_NAN_SNIPPET_PACKED + `
98836 return result;
98837`;
// Binary kernel built from the snippets above, with a CPU fallback impl for
// tensors the backend decides to run on the CPU.
98838 const maximum = binaryKernelFunc({
98839 opSnippet: MAXIMUM,
98840 packedOpSnippet: MAXIMUM_PACKED,
98841 cpuKernelImpl: maximumImplCPU
98842 });
// Registration config for the Maximum kernel on the WebGL backend.
98843 const maximumConfig = {
98844 kernelName: Maximum$1,
98845 backendName: 'webgl',
98846 kernelFunc: maximum
98847 };
98848
98849 /**
98850 * @license
98851 * Copyright 2020 Google LLC. All Rights Reserved.
98852 * Licensed under the Apache License, Version 2.0 (the "License");
98853 * you may not use this file except in compliance with the License.
98854 * You may obtain a copy of the License at
98855 *
98856 * http://www.apache.org/licenses/LICENSE-2.0
98857 *
98858 * Unless required by applicable law or agreed to in writing, software
98859 * distributed under the License is distributed on an "AS IS" BASIS,
98860 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98861 * See the License for the specific language governing permissions and
98862 * limitations under the License.
98863 * =============================================================================
98864 */
/**
 * WebGL kernel for MaxPool: 2D max pooling over a rank-4 input.
 *
 * @param args.inputs.x Input tensor to pool (complex dtypes rejected).
 * @param args.backend The WebGL backend executing the program.
 * @param args.attrs `filterSize`, `strides`, `pad`, `dimRoundingMode`.
 * @returns TensorInfo with the pooled output, same dtype as `x`.
 */
function maxPool(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    // The pooling shader has no complex-number support.
    assertNotComplex(x, 'maxPool');
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    // Max pooling is always undilated; the op contract demands that either
    // strides or dilations be 1.
    const dilations = 1;
    assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. ' +
        `Got strides ${strides} and dilations '${dilations}'`);
    const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
    // A 1x1 window that leaves the shape unchanged is a no-op: forward `x`.
    const isPassThrough = convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
        arraysEqual(convInfo.inShape, convInfo.outShape);
    if (isPassThrough) {
        return identity({ inputs: { x }, backend });
    }
    const program = new Pool2DProgram(convInfo, 'max', false);
    return backend.runWebGLProgram(program, [x], x.dtype);
}
/** Registration config for the MaxPool kernel on the WebGL backend. */
const maxPoolConfig = {
    kernelName: MaxPool,
    backendName: 'webgl',
    kernelFunc: maxPool
};
98886
98887 /**
98888 * @license
98889 * Copyright 2020 Google LLC. All Rights Reserved.
98890 * Licensed under the Apache License, Version 2.0 (the "License");
98891 * you may not use this file except in compliance with the License.
98892 * You may obtain a copy of the License at
98893 *
98894 * http://www.apache.org/licenses/LICENSE-2.0
98895 *
98896 * Unless required by applicable law or agreed to in writing, software
98897 * distributed under the License is distributed on an "AS IS" BASIS,
98898 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98899 * See the License for the specific language governing permissions and
98900 * limitations under the License.
98901 * =============================================================================
98902 */
/**
 * WebGL kernel for MaxPool3D: 3D max pooling.
 *
 * @param args.inputs.x Input tensor to pool.
 * @param args.backend The WebGL backend executing the program.
 * @param args.attrs `filterSize`, `strides`, `pad`, `dataFormat`,
 *     `dimRoundingMode`.
 * @returns TensorInfo with the pooled output, same dtype as `x`.
 */
function maxPool3d(args) {
    const { backend, attrs } = args;
    const { x } = args.inputs;
    const { filterSize, strides, pad, dataFormat, dimRoundingMode } = attrs;
    // 3D max pooling is always undilated.
    const dilations = [1, 1, 1];
    const convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat);
    const program = new Pool3DProgram(convInfo, 'max', false);
    return backend.runWebGLProgram(program, [x], x.dtype);
}
/** Registration config for the MaxPool3D kernel on the WebGL backend. */
const maxPool3DConfig = {
    kernelName: MaxPool3D,
    backendName: 'webgl',
    kernelFunc: maxPool3d
};
98917
98918 /**
98919 * @license
98920 * Copyright 2017 Google LLC. All Rights Reserved.
98921 * Licensed under the Apache License, Version 2.0 (the "License");
98922 * you may not use this file except in compliance with the License.
98923 * You may obtain a copy of the License at
98924 *
98925 * http://www.apache.org/licenses/LICENSE-2.0
98926 *
98927 * Unless required by applicable law or agreed to in writing, software
98928 * distributed under the License is distributed on an "AS IS" BASIS,
98929 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98930 * See the License for the specific language governing permissions and
98931 * limitations under the License.
98932 * =============================================================================
98933 */
// GPGPU program computing the gradient of 2D max pooling: routes the upstream
// gradient `dy` back to the input positions recorded in `maxPos`.
98934 class MaxPool2DBackpropProgram {
98935 constructor(convInfo) {
// Inputs: upstream gradient and the flattened argmax tensor produced by
// Pool2DProgram(convInfo, 'max', /* getPositions= */ true).
98936 this.variableNames = ['dy', 'maxPos'];
// The gradient dx has the same shape as the pooling input.
98937 this.outputShape = convInfo.inShape;
98938 const strideHeight = convInfo.strideHeight;
98939 const strideWidth = convInfo.strideWidth;
98940 const dilationHeight = convInfo.dilationHeight;
98941 const effectiveFilterHeight = convInfo.effectiveFilterHeight;
98942 const effectiveFilterWidth = convInfo.effectiveFilterWidth;
// Padding is mirrored (filter-1 minus forward pad) because the backward pass
// convolves dy with the window traversed from the opposite corner.
98943 const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
98944 const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
// Highest flattened in-window index; used below to invert the stored argmax
// so it is comparable with `curPosValue`, which counts from this corner.
98945 const lastIndex = effectiveFilterHeight * effectiveFilterWidth - 1;
// NOTE(review): the inner wC loop advances by 1 while the wR loop advances by
// dilationHeight — presumably safe because max-pool gradients are only ever
// built with dilation 1 (see maxPoolGrad below); confirm before reusing this
// program with dilation > 1.
98946 this.userCode = `
98947 const ivec2 pads = ivec2(${padTop}, ${padLeft});
98948
98949 void main() {
98950 ivec4 coords = getOutputCoords();
98951 int b = coords[0];
98952 int d = coords[3];
98953
98954 ivec2 dyRCCorner = coords.yz - pads;
98955 int dyRCorner = dyRCCorner.x;
98956 int dyCCorner = dyRCCorner.y;
98957
98958 // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).
98959 // ? = to be determined. : = across all values in that axis.
98960 float dotProd = 0.0;
98961 for (int wR = 0; wR < ${effectiveFilterHeight};
98962 wR += ${dilationHeight}) {
98963 float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
98964
98965 if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
98966 continue;
98967 }
98968 int idyR = int(dyR);
98969
98970 for (int wC = 0; wC < ${effectiveFilterWidth}; wC++) {
98971 float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
98972
98973 if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
98974 fract(dyC) > 0.0) {
98975 continue;
98976 }
98977 int idyC = int(dyC);
98978
98979 float dyValue = getDy(b, idyR, idyC, d);
98980 int maxPosValue = ${lastIndex} - int(getMaxPos(b, idyR, idyC, d));
98981
98982 // Get the current value, check it against the value from the
98983 // position matrix.
98984 int curPosValue = wR * ${effectiveFilterWidth} + wC;
98985 float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);
98986
98987 dotProd += dyValue * mask;
98988 }
98989 }
98990 setOutput(dotProd);
98991 }
98992 `;
98993 }
98994 }
// GPGPU program computing the gradient of 3D max pooling: the 3D analogue of
// MaxPool2DBackpropProgram, with an extra depth loop.
98995 class MaxPool3DBackpropProgram {
98996 constructor(convInfo) {
// Inputs: upstream gradient and the flattened argmax tensor produced by
// Pool3DProgram(convInfo, 'max', /* getPositions= */ true).
98997 this.variableNames = ['dy', 'maxPos'];
// The gradient dx has the same shape as the pooling input.
98998 this.outputShape = convInfo.inShape;
98999 const strideDepth = convInfo.strideDepth;
99000 const strideHeight = convInfo.strideHeight;
99001 const strideWidth = convInfo.strideWidth;
99002 const dilationDepth = convInfo.dilationDepth;
99003 const dilationHeight = convInfo.dilationHeight;
99004 const dilationWidth = convInfo.dilationWidth;
99005 const effectiveFilterDepth = convInfo.effectiveFilterDepth;
99006 const effectiveFilterHeight = convInfo.effectiveFilterHeight;
99007 const effectiveFilterWidth = convInfo.effectiveFilterWidth;
// Mirrored padding for the backward convolution (filter-1 minus forward pad).
99008 const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
99009 const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
99010 const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
// Highest flattened in-window index; used to invert the stored argmax so it
// matches `curPosValue`, which counts from this corner of the window.
99011 const lastIndex = effectiveFilterDepth * effectiveFilterHeight * effectiveFilterWidth - 1;
99012 this.userCode = `
99013 const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
99014
99015 void main() {
99016 ivec5 coords = getOutputCoords();
99017 int batch = coords.x;
99018 int ch = coords.u;
99019
99020 ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;
99021 int dyDCorner = dyCorner.x;
99022 int dyRCorner = dyCorner.y;
99023 int dyCCorner = dyCorner.z;
99024
99025 // Convolve dy(?, ?, ?, ch) with pos mask(:, :, :, d) to get
99026 // dx(xD, xR, xC, ch).
99027 // ? = to be determined. : = across all values in that axis.
99028 float dotProd = 0.0;
99029
99030 for (int wD = 0; wD < ${effectiveFilterDepth};
99031 wD += ${dilationDepth}) {
99032 float dyD = float(dyDCorner + wD) / ${strideDepth}.0;
99033
99034 if (dyD < 0.0 || dyD >= ${convInfo.outDepth}.0 || fract(dyD) > 0.0) {
99035 continue;
99036 }
99037 int idyD = int(dyD);
99038
99039 for (int wR = 0; wR < ${effectiveFilterHeight};
99040 wR += ${dilationHeight}) {
99041 float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
99042
99043 if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 ||
99044 fract(dyR) > 0.0) {
99045 continue;
99046 }
99047 int idyR = int(dyR);
99048
99049 for (int wC = 0; wC < ${effectiveFilterWidth};
99050 wC += ${dilationWidth}) {
99051 float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
99052
99053 if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
99054 fract(dyC) > 0.0) {
99055 continue;
99056 }
99057 int idyC = int(dyC);
99058
99059 float dyValue = getDy(batch, idyD, idyR, idyC, ch);
99060 int maxPosValue = ${lastIndex} -
99061 int(getMaxPos(batch, idyD, idyR, idyC, ch));
99062
99063 // Get the current value, check it against the value from the
99064 // position matrix.
99065 int curPosValue =
99066 wD * ${effectiveFilterHeight} * ${effectiveFilterWidth} +
99067 wR * ${effectiveFilterWidth} + wC;
99068 float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);
99069
99070 dotProd += dyValue * mask;
99071 }
99072 }
99073 }
99074 setOutput(dotProd);
99075 }
99076 `;
99077 }
99078 }
99079
99080 /**
99081 * @license
99082 * Copyright 2020 Google LLC. All Rights Reserved.
99083 * Licensed under the Apache License, Version 2.0 (the "License");
99084 * you may not use this file except in compliance with the License.
99085 * You may obtain a copy of the License at
99086 *
99087 * http://www.apache.org/licenses/LICENSE-2.0
99088 *
99089 * Unless required by applicable law or agreed to in writing, software
99090 * distributed under the License is distributed on an "AS IS" BASIS,
99091 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99092 * See the License for the specific language governing permissions and
99093 * limitations under the License.
99094 * =============================================================================
99095 */
/**
 * WebGL kernel for MaxPool3DGrad: gradient of 3D max pooling.
 *
 * Recomputes the per-window argmax positions from the forward input, then
 * scatters the upstream gradient `dy` back to those positions.
 *
 * @param args.inputs.dy Upstream gradient.
 * @param args.inputs.input Forward-pass input tensor.
 * @param args.backend The WebGL backend.
 * @param args.attrs `filterSize`, `strides`, `pad`, `dimRoundingMode`.
 * @returns TensorInfo holding dx (same shape as `input`).
 */
function maxPool3DGrad(args) {
    const { inputs, backend, attrs } = args;
    const { dy, input } = inputs;
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const x = input;
    // Gradient is only computed for undilated pooling.
    const dilations = [1, 1, 1];
    const convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
    // Pass 1: recover where each window's max came from.
    const positionsProgram = new Pool3DProgram(convInfo, 'max', true /* get positions */);
    const positions = backend.runWebGLProgram(positionsProgram, [x], x.dtype);
    // Pass 2: route dy back through the recorded positions.
    const backpropProgram = new MaxPool3DBackpropProgram(convInfo);
    const result = backend.runWebGLProgram(backpropProgram, [dy, positions], x.dtype);
    backend.disposeIntermediateTensorInfo(positions);
    return result;
}
/** Registration config for the MaxPool3DGrad kernel on the WebGL backend. */
const maxPool3DGradConfig = {
    kernelName: MaxPool3DGrad,
    backendName: 'webgl',
    kernelFunc: maxPool3DGrad
};
99115
99116 /**
99117 * @license
99118 * Copyright 2020 Google LLC. All Rights Reserved.
99119 * Licensed under the Apache License, Version 2.0 (the "License");
99120 * you may not use this file except in compliance with the License.
99121 * You may obtain a copy of the License at
99122 *
99123 * http://www.apache.org/licenses/LICENSE-2.0
99124 *
99125 * Unless required by applicable law or agreed to in writing, software
99126 * distributed under the License is distributed on an "AS IS" BASIS,
99127 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99128 * See the License for the specific language governing permissions and
99129 * limitations under the License.
99130 * =============================================================================
99131 */
/**
 * WebGL kernel for MaxPoolGrad: gradient of 2D max pooling.
 *
 * Recomputes the per-window argmax positions from the forward input, then
 * scatters the upstream gradient `dy` back to those positions.
 *
 * @param args.inputs.dy Upstream gradient.
 * @param args.inputs.input Forward-pass input tensor.
 * @param args.inputs.output Forward-pass output (checked non-complex only).
 * @param args.backend The WebGL backend.
 * @param args.attrs `filterSize`, `strides`, `pad`, `dimRoundingMode`.
 * @returns TensorInfo holding dx (same shape as `input`).
 */
function maxPoolGrad(args) {
    const { inputs, backend, attrs } = args;
    const { dy, input, output } = inputs;
    const x = input;
    assertNotComplex([input, output], 'maxPoolGrad');
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
    // Pass 1: recover where each window's max came from.
    const positionsProgram = new Pool2DProgram(convInfo, 'max', true /* getPositions */);
    const positions = backend.runWebGLProgram(positionsProgram, [x], x.dtype);
    // Pass 2: route dy back through the recorded positions.
    const backpropProgram = new MaxPool2DBackpropProgram(convInfo);
    const result = backend.runWebGLProgram(backpropProgram, [dy, positions], x.dtype);
    backend.disposeIntermediateTensorInfo(positions);
    return result;
}
/** Registration config for the MaxPoolGrad kernel on the WebGL backend. */
const maxPoolGradConfig = {
    kernelName: MaxPoolGrad,
    backendName: 'webgl',
    kernelFunc: maxPoolGrad
};
99152
99153 /**
99154 * @license
99155 * Copyright 2020 Google LLC. All Rights Reserved.
99156 * Licensed under the Apache License, Version 2.0 (the "License");
99157 * you may not use this file except in compliance with the License.
99158 * You may obtain a copy of the License at
99159 *
99160 * http://www.apache.org/licenses/LICENSE-2.0
99161 *
99162 * Unless required by applicable law or agreed to in writing, software
99163 * distributed under the License is distributed on an "AS IS" BASIS,
99164 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99165 * See the License for the specific language governing permissions and
99166 * limitations under the License.
99167 * =============================================================================
99168 */
/**
 * Runs 2D max pooling twice over `x`: once to produce the pooled values and
 * once to produce the flattened argmax indices.
 *
 * @param x Input tensor.
 * @param includeBatchInIndex Whether indices are flattened across the batch.
 * @param convInfo Precomputed pooling geometry.
 * @param backend The WebGL backend.
 * @returns `[poolOutput, indexOutput]`; the caller owns both tensors.
 */
function maxPoolWithArgmaxImpl(x, includeBatchInIndex, convInfo, backend) {
    const valueProgram = new Pool2DProgram(convInfo, 'max', false);
    const poolOutput = backend.runWebGLProgram(valueProgram, [x], 'float32');
    const indexProgram = new Pool2DProgram(convInfo, 'max', true, true, includeBatchInIndex);
    const indexOutput = backend.runWebGLProgram(indexProgram, [x], 'float32');
    return [poolOutput, indexOutput];
}
99176
99177 /**
99178 * @license
99179 * Copyright 2020 Google LLC. All Rights Reserved.
99180 * Licensed under the Apache License, Version 2.0 (the "License");
99181 * you may not use this file except in compliance with the License.
99182 * You may obtain a copy of the License at
99183 *
99184 * http://www.apache.org/licenses/LICENSE-2.0
99185 *
99186 * Unless required by applicable law or agreed to in writing, software
99187 * distributed under the License is distributed on an "AS IS" BASIS,
99188 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99189 * See the License for the specific language governing permissions and
99190 * limitations under the License.
99191 * =============================================================================
99192 */
/**
 * Registration config for MaxPoolWithArgmax on the WebGL backend: returns
 * both the pooled values and the argmax indices as a two-element array.
 */
const maxPoolWithArgmaxConfig = {
    kernelName: MaxPoolWithArgmax,
    backendName: 'webgl',
    kernelFunc: ({ inputs, attrs, backend }) => {
        const { x } = inputs;
        const { filterSize, strides, pad, includeBatchInIndex } = attrs;
        const webglBackend = backend;
        // Only rank-4 (batched 2D) inputs are supported.
        assert$1(x.shape.length === 4, () => `Error in maxPool: input must be rank 4 but got rank ${x.shape.length}.`);
        // This op is always undilated.
        const dilations = [1, 1];
        assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. ' +
            `Got strides ${strides} and dilations '${dilations}'`);
        const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad);
        // The impl already returns [values, indices].
        return maxPoolWithArgmaxImpl(x, includeBatchInIndex, convInfo, webglBackend);
    }
};
99209
99210 /**
99211 * @license
99212 * Copyright 2020 Google LLC. All Rights Reserved.
99213 * Licensed under the Apache License, Version 2.0 (the "License");
99214 * you may not use this file except in compliance with the License.
99215 * You may obtain a copy of the License at
99216 *
99217 * http://www.apache.org/licenses/LICENSE-2.0
99218 *
99219 * Unless required by applicable law or agreed to in writing, software
99220 * distributed under the License is distributed on an "AS IS" BASIS,
99221 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99222 * See the License for the specific language governing permissions and
99223 * limitations under the License.
99224 * =============================================================================
99225 */
/**
 * Computes the mean of `x` over its innermost dimensions described by
 * `reduceShape`, producing a float32 tensor of shape `outShape`.
 *
 * @param x Input tensor whose trailing dims match `reduceShape`.
 * @param reduceShape Shape of the dims being averaged away.
 * @param outShape Desired output shape.
 * @param backend The WebGL backend.
 * @returns TensorInfo with the reduced result; intermediates are disposed.
 */
function meanImpl(x, reduceShape, outShape, backend) {
    const reduceSize = sizeFromShape(reduceShape);
    const batchSize = sizeFromShape(x.shape) / reduceSize;
    // Collapse to [batch, reduce] so the generic reducer can average axis 1.
    const input2D = reshape({ inputs: { x }, attrs: { shape: [batchSize, reduceSize] }, backend });
    const reduced = reduce(input2D, 'float32', 'mean', backend);
    const output = reshape({ inputs: { x: reduced }, attrs: { shape: outShape }, backend });
    backend.disposeIntermediateTensorInfo(input2D);
    backend.disposeIntermediateTensorInfo(reduced);
    return output;
}
99237
99238 /**
99239 * @license
99240 * Copyright 2020 Google LLC. All Rights Reserved.
99241 * Licensed under the Apache License, Version 2.0 (the "License");
99242 * you may not use this file except in compliance with the License.
99243 * You may obtain a copy of the License at
99244 *
99245 * http://www.apache.org/licenses/LICENSE-2.0
99246 *
99247 * Unless required by applicable law or agreed to in writing, software
99248 * distributed under the License is distributed on an "AS IS" BASIS,
99249 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99250 * See the License for the specific language governing permissions and
99251 * limitations under the License.
99252 * =============================================================================
99253 */
/**
 * Registration config for the Mean reduction kernel on the WebGL backend.
 *
 * Reduces `x` over `axis` (mean), optionally keeping the reduced dims when
 * `keepDims` is set. Non-innermost axes are first transposed inward — on the
 * CPU for small tensors, otherwise on the GPU.
 */
const meanConfig = {
    kernelName: Mean,
    backendName: 'webgl',
    kernelFunc: ({ inputs, attrs, backend }) => {
        const { x } = inputs;
        const { keepDims, axis } = attrs;
        const webglBackend = backend;
        const xRank = x.shape.length;
        const origAxes = parseAxisParam(axis, x.shape);
        let axes = origAxes;
        // Non-null permutation means the reduced axes are not innermost yet.
        const permutedAxes = getAxesPermutation(axes, xRank);
        const meanInputIsTransposed = permutedAxes != null;
        const shouldExecuteOnCPU = webglBackend.shouldExecuteOnCPU([x]);
        const intermediates = [];
        let meanInput = x;
        if (meanInputIsTransposed) {
            if (shouldExecuteOnCPU) {
                // Small tensor: transpose on the CPU and store the values on a
                // fresh tensor info instead of running a GPU program.
                const xTexData = webglBackend.texData.get(meanInput.dataId);
                const values = xTexData.values;
                const newShape = new Array(xRank);
                for (let i = 0; i < newShape.length; i++) {
                    newShape[i] = x.shape[permutedAxes[i]];
                }
                const meanInputValues = transposeImplCPU(values, x.shape, x.dtype, permutedAxes, newShape);
                meanInput = webglBackend.makeTensorInfo(newShape, x.dtype);
                const meanInputData = webglBackend.texData.get(meanInput.dataId);
                meanInputData.values = meanInputValues;
            }
            else {
                meanInput = transposeImpl(x, permutedAxes, webglBackend);
            }
            intermediates.push(meanInput);
            axes = getInnerMostAxes(axes.length, xRank);
        }
        // Fix: report 'mean' (was 'sum') so assertion failures name the
        // correct op, consistent with the 'max'/'min' kernels in this file.
        assertAxesAreInnerMostDims('mean', axes, xRank);
        const [meanOutShape, reduceShape] = computeOutAndReduceShapes(meanInput.shape, axes);
        let outShape = meanOutShape;
        if (keepDims) {
            // Rather than reshape at the end, set the target shape here.
            outShape = expandShapeToKeepDim(meanOutShape, origAxes);
        }
        const out = meanImpl(meanInput, reduceShape, outShape, webglBackend);
        for (const i of intermediates) {
            webglBackend.disposeIntermediateTensorInfo(i);
        }
        return out;
    }
};
99302
99303 /**
99304 * @license
99305 * Copyright 2020 Google LLC. All Rights Reserved.
99306 * Licensed under the Apache License, Version 2.0 (the "License");
99307 * you may not use this file except in compliance with the License.
99308 * You may obtain a copy of the License at
99309 *
99310 * http://www.apache.org/licenses/LICENSE-2.0
99311 *
99312 * Unless required by applicable law or agreed to in writing, software
99313 * distributed under the License is distributed on an "AS IS" BASIS,
99314 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99315 * See the License for the specific language governing permissions and
99316 * limitations under the License.
99317 * =============================================================================
99318 */
/**
 * WebGL kernel for Min: reduces `x` over `axis` taking the minimum,
 * optionally keeping the reduced dimensions when `keepDims` is set.
 *
 * @param args.inputs.x Input tensor.
 * @param args.backend The WebGL backend.
 * @param args.attrs `axis` (number | number[]) and `keepDims` flag.
 * @returns TensorInfo with the reduced result.
 */
function min(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    const xRank = x.shape.length;
    const origAxes = parseAxisParam(axis, x.shape);
    let axes = origAxes;
    // If the reduced axes are not innermost, transpose them inward first.
    const permutedAxes = getAxesPermutation(axes, xRank);
    let permutedX = x;
    if (permutedAxes != null) {
        permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
        axes = getInnerMostAxes(axes.length, xRank);
    }
    assertAxesAreInnerMostDims('min', axes, xRank);
    const [outShape, reduceShape] = computeOutAndReduceShapes(permutedX.shape, axes);
    const inSize = sizeFromShape(reduceShape);
    // Collapse to [batch, reduce] for the generic reducer.
    const a2D = reshape({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } });
    const reduced = reduce(a2D, a2D.dtype, 'min', backend);
    const finalShape = keepDims ? expandShapeToKeepDim(outShape, origAxes) : outShape;
    const res = reshape({ inputs: { x: reduced }, backend, attrs: { shape: finalShape } });
    backend.disposeIntermediateTensorInfo(a2D);
    backend.disposeIntermediateTensorInfo(reduced);
    if (permutedAxes != null) {
        backend.disposeIntermediateTensorInfo(permutedX);
    }
    return res;
}
/** Registration config for the Min kernel on the WebGL backend. */
const minConfig = {
    kernelName: Min,
    backendName: 'webgl',
    kernelFunc: min
};
99357
99358 /**
99359 * @license
99360 * Copyright 2020 Google LLC. All Rights Reserved.
99361 * Licensed under the Apache License, Version 2.0 (the "License");
99362 * you may not use this file except in compliance with the License.
99363 * You may obtain a copy of the License at
99364 *
99365 * http://www.apache.org/licenses/LICENSE-2.0
99366 *
99367 * Unless required by applicable law or agreed to in writing, software
99368 * distributed under the License is distributed on an "AS IS" BASIS,
99369 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99370 * See the License for the specific language governing permissions and
99371 * limitations under the License.
99372 * =============================================================================
99373 */
// GLSL snippet for unpacked element-wise minimum; CHECK_NAN_SNIPPET makes the
// result NaN when either operand is NaN (WebGL1's min() otherwise ignores NaN).
99374 const MINIMUM = CHECK_NAN_SNIPPET + `
99375 return min(a, b);
99376`;
// Packed (vec4, 4-lanes-at-a-time) variant: computes min per lane, then uses
// CHECK_NAN_SNIPPET_PACKED with the per-lane `isNaN` mask to propagate NaNs.
99377 const MINIMUM_PACKED = `
99378 vec4 result = vec4(min(a, b));
99379 bvec4 isNaNA = isnan(a);
99380 bvec4 isNaNB = isnan(b);
99381 bvec4 isNaN = bvec4(isNaNA.x || isNaNB.x, isNaNA.y || isNaNB.y, isNaNA.z || isNaNB.z, isNaNA.w || isNaNB.w);
99382 ` +
99383 CHECK_NAN_SNIPPET_PACKED + `
99384 return result;
99385`;
// Binary kernel built from the snippets above, with a CPU fallback impl for
// tensors the backend decides to run on the CPU.
99386 const minimum = binaryKernelFunc({
99387 opSnippet: MINIMUM,
99388 packedOpSnippet: MINIMUM_PACKED,
99389 cpuKernelImpl: minimumImplCPU
99390 });
// Registration config for the Minimum kernel on the WebGL backend.
99391 const minimumConfig = {
99392 kernelName: Minimum$1,
99393 backendName: 'webgl',
99394 kernelFunc: minimum
99395 };
99396
99397 /**
99398 * @license
99399 * Copyright 2020 Google LLC. All Rights Reserved.
99400 * Licensed under the Apache License, Version 2.0 (the "License");
99401 * you may not use this file except in compliance with the License.
99402 * You may obtain a copy of the License at
99403 *
99404 * http://www.apache.org/licenses/LICENSE-2.0
99405 *
99406 * Unless required by applicable law or agreed to in writing, software
99407 * distributed under the License is distributed on an "AS IS" BASIS,
99408 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99409 * See the License for the specific language governing permissions and
99410 * limitations under the License.
99411 * =============================================================================
99412 */
// GPGPU program for unpacked mirror padding ('reflect' excludes the border
// value, 'symmetric' includes it). Out-of-range output coords are reflected
// back into the input's range before sampling.
99413 class MirrorPadProgram {
99414 constructor(xShape, paddings, mode) {
99415 this.variableNames = ['x'];
// Output grows by beforePad + afterPad along each dimension.
99416 this.outputShape = paddings.map((p, i) => p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */);
99417 const rank = xShape.length;
99418 const dtype = getCoordsDataType(rank);
// `start`/`end` bound the unpadded region inside the output coordinates.
99419 const start = paddings.map(p => p[0]).join(',');
99420 const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
99421 const unpackedCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank);
// offset 0 = 'reflect' (border not repeated); offset 1 = 'symmetric'.
99422 const offset = mode === 'reflect' ? 0 : 1;
// Rank 1 uses a scalar coordinate and so gets its own shader body.
99423 if (rank === 1) {
99424 this.userCode = `
99425 int start = ${start};
99426 int end = ${end};
99427
99428 void main() {
99429 int outC = getOutputCoords();
99430 if (outC < start) {
99431 outC = start * 2 - outC - ${offset};
99432 } else if(outC >= end) {
99433 outC = (end - 1) * 2 - outC + ${offset};
99434 }
99435 setOutput(getX(outC - start));
99436 }
99437 `;
99438 return;
99439 }
// General case: reflect each coordinate component independently.
99440 this.userCode = `
99441 ${dtype} start = ${dtype}(${start});
99442 ${dtype} end = ${dtype}(${end});
99443
99444 void main() {
99445 ${dtype} outC = getOutputCoords();
99446 for (int i = 0; i < ${rank}; i++) {
99447 if (outC[i] < start[i]) {
99448 outC[i] = start[i] * 2 - outC[i] - ${offset};
99449 } else if(outC[i] >= end[i]) {
99450 outC[i] = (end[i] - 1) * 2 - outC[i] + ${offset};
99451 }
99452 }
99453 ${dtype} coords = outC - start;
99454 setOutput(getX(${unpackedCoords}));
99455 }
99456 `;
99457 }
99458 }
99459
99460 /**
99461 * @license
99462 * Copyright 2020 Google LLC. All Rights Reserved.
99463 * Licensed under the Apache License, Version 2.0 (the "License");
99464 * you may not use this file except in compliance with the License.
99465 * You may obtain a copy of the License at
99466 *
99467 * http://www.apache.org/licenses/LICENSE-2.0
99468 *
99469 * Unless required by applicable law or agreed to in writing, software
99470 * distributed under the License is distributed on an "AS IS" BASIS,
99471 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99472 * See the License for the specific language governing permissions and
99473 * limitations under the License.
99474 * =============================================================================
99475 */
99476 /**
99477 * Example shader code for
99478 * `mirrorPad(tf.tensor1d([1, 2, 3], 'int32'), [[2, 2]], 'reflect')`
99479 * ```
99480 * const int start = int(2);
99481 * const int end = int(5);
99482 *
99483 * void main() {
99484 * int outputLoc = getOutputCoords();
99485 * vec4 result = vec4(0.);
99486 *
99487 * int rc = outputLoc;
99488 *
99489 * int source = rc;
99490 * if (source < start) {
99491 * source = start * 2 - source - 0;
99492 * } else if (source >= end) {
99493 * source = (end - 1) * 2 - source + 0;
99494 * }
99495 * source -= start;
99496 *
99497 * result[0] = getChannel(getX(source), source);
99498 * rc += 1;
99499 * if(rc < 6) {
99500 * int source = rc;
99501 * if (source < start) {
99502 * source = start * 2 - source - 0;
99503 * } else if (source >= end) {
99504 * source = (end - 1) * 2 - source + 0;
99505 * }
99506 * source -= start;
99507 *
99508 * result[1] = getChannel(getX(source), source);
99509 * }
99510 *
99511 * setOutput(result);
99512 * }
99513 * ```
99514 */
// Packed (vec4) variant of MirrorPadProgram: each invocation fills a 2x2
// texel, so the reflection logic (padSetup) is stamped out once per lane.
// See the example shader in the comment block above this class.
99515 class MirrorPadPackedProgram {
99516 constructor(xShape, paddings, mode) {
99517 this.variableNames = ['x'];
99518 this.packedInputs = true;
99519 this.packedOutput = true;
// Output grows by beforePad + afterPad along each dimension.
99520 this.outputShape = paddings.map((p, i) => p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */);
99521 const rank = xShape.length;
99522 const dtype = getCoordsDataType(rank);
// `start`/`end` bound the unpadded region inside the output coordinates.
99523 const start = paddings.map(p => p[0]).join(',');
99524 const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
// Per-component coordinate variable names for the current ('rc') and
// reflected-source ('source') positions.
99525 const coords = getChannels('rc', rank);
99526 const source = getChannels('source', rank);
// Guard that the innermost coordinate is still inside the output.
99527 const cLimit = `${coords[rank - 1]} < ${this.outputShape[rank - 1]}`;
99528 const innerDims = rank === 1 ? 'source' : `vec2(${source.slice(-2).join()})`;
// offset 0 = 'reflect' (border not repeated); offset 1 = 'symmetric'.
99529 const offset = mode === 'reflect' ? 0 : 1;
99530 let mainLoop = '';
// Rank 1: only two lanes of the vec4 are used (result[0], result[1]).
99531 if (rank === 1) {
99532 const padSetup = `
99533 ${dtype} source = rc;
99534 if (source < start) {
99535 source = start * 2 - source - ${offset};
99536 } else if (source >= end) {
99537 source = (end - 1) * 2 - source + ${offset};
99538 }
99539 source -= start;
99540 `;
99541 mainLoop = `
99542 ${dtype} rc = outputLoc;
99543 ${padSetup}
99544 result[0] = getChannel(getX(${source.join()}), ${innerDims});
99545 ${coords[rank - 1]} += 1;
99546 if(${cLimit}) {
99547 ${padSetup}
99548 result[1] = getChannel(getX(${source.join()}), ${innerDims});
99549 }
99550 `;
99551 }
99552 else {
// Rank >= 2: branch-free reflection via lessThan/greaterThanEqual masks,
// filling all four lanes (two rows x two columns of the packed texel).
99553 const padSetup = `
99554 ${dtype} source = rc;
99555 ${dtype} lt = ${dtype}(lessThan(source, start));
99556 ${dtype} gte = ${dtype}(greaterThanEqual(source, end));
99557 ${dtype} orig = 1 - (lt + gte);
99558 source = orig * source +
99559 lt * (start * 2 - source - ${offset}) +
99560 gte * ((end - 1) * 2 - source + ${offset});
99561 source -= start;
99562 `;
99563 mainLoop = `
99564 ${dtype} rc = outputLoc;
99565 ${padSetup}
99566 result[0] = getChannel(getX(${source.join()}), ${innerDims});
99567 ${coords[rank - 1]} += 1;
99568 if(${cLimit}) {
99569 ${padSetup}
99570 result[1] = getChannel(getX(${source.join()}), ${innerDims});
99571 }
99572 rc = outputLoc;
99573 ${coords[rank - 2]} += 1;
99574 if(${coords[rank - 2]} < ${this.outputShape[rank - 2]}) {
99575 ${padSetup}
99576 result[2] = getChannel(getX(${source.join()}), ${innerDims});
99577 ${coords[rank - 1]} += 1;
99578 if(${cLimit}) {
99579 ${padSetup}
99580 result[3] = getChannel(getX(${source.join()}), ${innerDims});
99581 }
99582 }
99583 `;
99584 }
99585 this.userCode = `
99586 const ${dtype} start = ${dtype}(${start});
99587 const ${dtype} end = ${dtype}(${end});
99588
99589 void main() {
99590 ${dtype} outputLoc = getOutputCoords();
99591 vec4 result = vec4(0.);
99592 ${mainLoop}
99593 setOutput(result);
99594 }
99595 `;
99596 }
99597 }
99598
99599 /**
99600 * @license
99601 * Copyright 2020 Google LLC. All Rights Reserved.
99602 * Licensed under the Apache License, Version 2.0 (the "License");
99603 * you may not use this file except in compliance with the License.
99604 * You may obtain a copy of the License at
99605 *
99606 * http://www.apache.org/licenses/LICENSE-2.0
99607 *
99608 * Unless required by applicable law or agreed to in writing, software
99609 * distributed under the License is distributed on an "AS IS" BASIS,
99610 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99611 * See the License for the specific language governing permissions and
99612 * limitations under the License.
99613 * =============================================================================
99614 */
99615 const mirrorPadKernelFunc = ({ inputs, backend, attrs }) => {
99616 const { x } = inputs;
99617 const { paddings, mode } = attrs;
99618 const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
99619 new MirrorPadPackedProgram(x.shape, paddings, mode) :
99620 new MirrorPadProgram(x.shape, paddings, mode);
99621 const output = backend.runWebGLProgram(program, [x], x.dtype);
99622 return output;
99623 };
99624 const mirrorPadConfig = {
99625 kernelName: MirrorPad,
99626 backendName: 'webgl',
99627 kernelFunc: mirrorPadKernelFunc,
99628 };
99629
99630 /**
99631 * @license
99632 * Copyright 2020 Google LLC. All Rights Reserved.
99633 * Licensed under the Apache License, Version 2.0 (the "License");
99634 * you may not use this file except in compliance with the License.
99635 * You may obtain a copy of the License at
99636 *
99637 * http://www.apache.org/licenses/LICENSE-2.0
99638 *
99639 * Unless required by applicable law or agreed to in writing, software
99640 * distributed under the License is distributed on an "AS IS" BASIS,
99641 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99642 * See the License for the specific language governing permissions and
99643 * limitations under the License.
99644 * =============================================================================
99645 */
99646 const MOD = `if (b == 0.0) return NAN;
99647 return mod(a, b);`;
99648 const MOD_PACKED = `
99649 vec4 result = mod(a, b);
99650 bvec4 isNaN = equal(b, vec4(0.0));
99651 ` +
99652 CHECK_NAN_SNIPPET_PACKED + `
99653 return result;
99654`;
99655 const mod = binaryKernelFunc({
99656 opSnippet: MOD,
99657 packedOpSnippet: MOD_PACKED,
99658 });
99659 const modConfig = {
99660 kernelName: Mod,
99661 backendName: 'webgl',
99662 kernelFunc: mod
99663 };
99664
99665 /**
99666 * @license
99667 * Copyright 2017 Google LLC. All Rights Reserved.
99668 * Licensed under the Apache License, Version 2.0 (the "License");
99669 * you may not use this file except in compliance with the License.
99670 * You may obtain a copy of the License at
99671 *
99672 * http://www.apache.org/licenses/LICENSE-2.0
99673 *
99674 * Unless required by applicable law or agreed to in writing, software
99675 * distributed under the License is distributed on an "AS IS" BASIS,
99676 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99677 * See the License for the specific language governing permissions and
99678 * limitations under the License.
99679 * =============================================================================
99680 */
99681 class MultinomialProgram {
99682 constructor(batchSize, numOutcomes, numSamples) {
99683 this.variableNames = ['probs'];
99684 this.customUniforms = [{ name: 'seed', type: 'float' }];
99685 this.outputShape = [batchSize, numSamples];
99686 this.userCode = `
99687 void main() {
99688 ivec2 coords = getOutputCoords();
99689 int batch = coords[0];
99690
99691 float r = random(seed);
99692 float cdf = 0.0;
99693
99694 for (int i = 0; i < ${numOutcomes - 1}; i++) {
99695 cdf += getProbs(batch, i);
99696
99697 if (r < cdf) {
99698 setOutput(float(i));
99699 return;
99700 }
99701 }
99702
99703 // If no other event happened, last event happened.
99704 setOutput(float(${numOutcomes - 1}));
99705 }
99706 `;
99707 }
99708 }
99709
99710 /**
99711 * @license
99712 * Copyright 2020 Google LLC. All Rights Reserved.
99713 * Licensed under the Apache License, Version 2.0 (the "License");
99714 * you may not use this file except in compliance with the License.
99715 * You may obtain a copy of the License at
99716 *
99717 * http://www.apache.org/licenses/LICENSE-2.0
99718 *
99719 * Unless required by applicable law or agreed to in writing, software
99720 * distributed under the License is distributed on an "AS IS" BASIS,
99721 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99722 * See the License for the specific language governing permissions and
99723 * limitations under the License.
99724 * =============================================================================
99725 */
99726 // Without the equality check div produces 0.9999 for a = b, which when
99727 // floored can cause errors.
99728 const DIV = `
99729if (a == b) {
99730 return 1.0;
99731};
99732return a / b;`;
99733 // We do the same as in ./binaryop_gpu, with vec4 and ivec4.
99734 // On Linux, the vectorized implementation produces NaNs when a and b are 0.
99735 const DIV_PACKED = `
99736 // vec4 one = vec4(equal(a, b));
99737 // return one + (vec4(1.0) - one) * a / b;
99738 vec4 result = a / b;
99739 if(a.x == b.x) {
99740 result.x = 1.;
99741 }
99742 if(a.y == b.y) {
99743 result.y = 1.;
99744 }
99745 if(a.z == b.z) {
99746 result.z = 1.;
99747 }
99748 if(a.w == b.w) {
99749 result.w = 1.;
99750 }
99751
99752 return result;
99753`;
99754 const realDiv = binaryKernelFunc({ opSnippet: DIV, packedOpSnippet: DIV_PACKED, checkOutOfBounds: true });
99755 const realDivConfig = {
99756 kernelName: RealDiv,
99757 backendName: 'webgl',
99758 kernelFunc: realDiv,
99759 };
99760
99761 /**
99762 * @license
99763 * Copyright 2020 Google LLC. All Rights Reserved.
99764 * Licensed under the Apache License, Version 2.0 (the "License");
99765 * you may not use this file except in compliance with the License.
99766 * You may obtain a copy of the License at
99767 *
99768 * http://www.apache.org/licenses/LICENSE-2.0
99769 *
99770 * Unless required by applicable law or agreed to in writing, software
99771 * distributed under the License is distributed on an "AS IS" BASIS,
99772 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99773 * See the License for the specific language governing permissions and
99774 * limitations under the License.
99775 * =============================================================================
99776 */
99777 const SUB = 'return a - b;';
99778 const sub = binaryKernelFunc({
99779 opSnippet: SUB,
99780 packedOpSnippet: SUB,
99781 supportsComplex: true,
99782 cpuKernelImpl: subImplCPU
99783 });
99784 const subConfig = {
99785 kernelName: Sub,
99786 backendName: 'webgl',
99787 kernelFunc: sub
99788 };
99789
99790 /**
99791 * @license
99792 * Copyright 2020 Google LLC. All Rights Reserved.
99793 * Licensed under the Apache License, Version 2.0 (the "License");
99794 * you may not use this file except in compliance with the License.
99795 * You may obtain a copy of the License at
99796 *
99797 * http://www.apache.org/licenses/LICENSE-2.0
99798 *
99799 * Unless required by applicable law or agreed to in writing, software
99800 * distributed under the License is distributed on an "AS IS" BASIS,
99801 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99802 * See the License for the specific language governing permissions and
99803 * limitations under the License.
99804 * =============================================================================
99805 */
99806 function softmax(args) {
99807 const { inputs, backend, attrs } = args;
99808 const { logits } = inputs;
99809 const { dim } = attrs;
99810 const axes = parseAxisParam([dim], logits.shape);
99811 const maxLogit = max({
99812 inputs: { x: logits },
99813 backend,
99814 attrs: { reductionIndices: axes, keepDims: false }
99815 });
99816 const expandedShape = expandShapeToKeepDim(maxLogit.shape, axes);
99817 const maxLogitsReshaped = reshape({ inputs: { x: maxLogit }, backend, attrs: { shape: expandedShape } });
99818 const a = sub({ inputs: { a: logits, b: maxLogitsReshaped }, backend });
99819 const b = exp({ inputs: { x: a }, backend });
99820 const sumExp = sum({ inputs: { x: b }, backend, attrs: { axis: axes, keepDims: false } });
99821 const sumExpReshaped = reshape({ inputs: { x: sumExp }, backend, attrs: { shape: expandedShape } });
99822 const res = realDiv({ inputs: { a: b, b: sumExpReshaped }, backend });
99823 backend.disposeIntermediateTensorInfo(maxLogit);
99824 backend.disposeIntermediateTensorInfo(maxLogitsReshaped);
99825 backend.disposeIntermediateTensorInfo(a);
99826 backend.disposeIntermediateTensorInfo(b);
99827 backend.disposeIntermediateTensorInfo(sumExp);
99828 backend.disposeIntermediateTensorInfo(sumExpReshaped);
99829 return res;
99830 }
99831 const softmaxConfig = {
99832 kernelName: Softmax$2,
99833 backendName: 'webgl',
99834 kernelFunc: softmax
99835 };
99836
99837 /**
99838 * @license
99839 * Copyright 2020 Google LLC. All Rights Reserved.
99840 * Licensed under the Apache License, Version 2.0 (the "License");
99841 * you may not use this file except in compliance with the License.
99842 * You may obtain a copy of the License at
99843 *
99844 * http://www.apache.org/licenses/LICENSE-2.0
99845 *
99846 * Unless required by applicable law or agreed to in writing, software
99847 * distributed under the License is distributed on an "AS IS" BASIS,
99848 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99849 * See the License for the specific language governing permissions and
99850 * limitations under the License.
99851 * =============================================================================
99852 */
99853 function multinomial(args) {
99854 const { inputs, backend, attrs } = args;
99855 const { logits } = inputs;
99856 const { numSamples, seed, normalized } = attrs;
99857 const probs = normalized ?
99858 logits :
99859 softmax({ inputs: { logits }, backend, attrs: { dim: logits.shape.length - 1 } });
99860 const batchSize = probs.shape[0];
99861 const numOutcomes = probs.shape[1];
99862 const program = new MultinomialProgram(batchSize, numOutcomes, numSamples);
99863 const customValues = [[seed]];
99864 const res = backend.runWebGLProgram(program, [probs], 'int32', customValues);
99865 if (!normalized) {
99866 backend.disposeIntermediateTensorInfo(probs);
99867 }
99868 return res;
99869 }
99870 const multinomialConfig = {
99871 kernelName: Multinomial,
99872 backendName: 'webgl',
99873 kernelFunc: multinomial
99874 };
99875
99876 /**
99877 * @license
99878 * Copyright 2020 Google LLC. All Rights Reserved.
99879 * Licensed under the Apache License, Version 2.0 (the "License");
99880 * you may not use this file except in compliance with the License.
99881 * You may obtain a copy of the License at
99882 *
99883 * http://www.apache.org/licenses/LICENSE-2.0
99884 *
99885 * Unless required by applicable law or agreed to in writing, software
99886 * distributed under the License is distributed on an "AS IS" BASIS,
99887 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99888 * See the License for the specific language governing permissions and
99889 * limitations under the License.
99890 * =============================================================================
99891 */
99892 const NEG = CHECK_NAN_SNIPPET$1 + `
99893 return -x;
99894`;
99895 const NEG_PACKED = `
99896 vec4 result = -x;
99897 bvec4 isNaN = isnan(x);
99898
99899 result.r = isNaN.r ? x.r : result.r;
99900 result.g = isNaN.g ? x.g : result.g;
99901 result.b = isNaN.b ? x.b : result.b;
99902 result.a = isNaN.a ? x.a : result.a;
99903
99904 return result;
99905`;
99906 // This doesn't use unaryKernelFunc because negImplCPU is not of type
99907 // SimpleUnaryKernelImplCPU.
99908 function neg(args) {
99909 const { inputs, backend } = args;
99910 const { x } = inputs;
99911 if (backend.shouldExecuteOnCPU([x])) {
99912 const xData = backend.texData.get(x.dataId);
99913 const [outValues, newShape] = negImplCPU(xData.values, x.shape, x.dtype);
99914 return backend.makeTensorInfo(newShape, x.dtype, outValues);
99915 }
99916 let program;
99917 if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
99918 program = new UnaryOpPackedProgram(x.shape, NEG_PACKED);
99919 }
99920 else {
99921 program = new UnaryOpProgram(x.shape, NEG);
99922 }
99923 return backend.runWebGLProgram(program, [x], x.dtype);
99924 }
99925 const negConfig = {
99926 kernelName: Neg,
99927 backendName: 'webgl',
99928 kernelFunc: neg
99929 };
99930
99931 /**
99932 * @license
99933 * Copyright 2020 Google LLC. All Rights Reserved.
99934 * Licensed under the Apache License, Version 2.0 (the "License");
99935 * you may not use this file except in compliance with the License.
99936 * You may obtain a copy of the License at
99937 *
99938 * http://www.apache.org/licenses/LICENSE-2.0
99939 *
99940 * Unless required by applicable law or agreed to in writing, software
99941 * distributed under the License is distributed on an "AS IS" BASIS,
99942 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99943 * See the License for the specific language governing permissions and
99944 * limitations under the License.
99945 * =============================================================================
99946 */
99947 const nonMaxSuppressionV3Impl = nonMaxSuppressionV3Impl$2;
99948 function nonMaxSuppressionV3(args) {
99949 warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' +
99950 'Call tf.nonMaxSuppressionAsync() instead');
99951 const { inputs, backend, attrs } = args;
99952 const { boxes, scores } = inputs;
99953 const { maxOutputSize, iouThreshold, scoreThreshold } = attrs;
99954 const boxesVals = backend.readSync(boxes.dataId);
99955 const scoresVals = backend.readSync(scores.dataId);
99956 const { selectedIndices } = nonMaxSuppressionV3Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold);
99957 return backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices));
99958 }
99959 const nonMaxSuppressionV3Config = {
99960 kernelName: NonMaxSuppressionV3,
99961 backendName: 'webgl',
99962 kernelFunc: nonMaxSuppressionV3
99963 };
99964
99965 /**
99966 * @license
99967 * Copyright 2020 Google LLC. All Rights Reserved.
99968 * Licensed under the Apache License, Version 2.0 (the "License");
99969 * you may not use this file except in compliance with the License.
99970 * You may obtain a copy of the License at
99971 *
99972 * http://www.apache.org/licenses/LICENSE-2.0
99973 *
99974 * Unless required by applicable law or agreed to in writing, software
99975 * distributed under the License is distributed on an "AS IS" BASIS,
99976 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99977 * See the License for the specific language governing permissions and
99978 * limitations under the License.
99979 * =============================================================================
99980 */
99981 const nonMaxSuppressionV4Impl = nonMaxSuppressionV4Impl$2;
99982 function nonMaxSuppressionV4(args) {
99983 warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' +
99984 'Call tf.nonMaxSuppressionAsync() instead');
99985 const { inputs, backend, attrs } = args;
99986 const { boxes, scores } = inputs;
99987 const { maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize } = attrs;
99988 const boxesVals = backend.readSync(boxes.dataId);
99989 const scoresVals = backend.readSync(scores.dataId);
99990 const { selectedIndices, validOutputs } = nonMaxSuppressionV4Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize);
99991 return [
99992 backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)),
99993 backend.makeTensorInfo([], 'int32', new Int32Array([validOutputs]))
99994 ];
99995 }
99996 const nonMaxSuppressionV4Config = {
99997 kernelName: NonMaxSuppressionV4,
99998 backendName: 'webgl',
99999 kernelFunc: nonMaxSuppressionV4
100000 };
100001
100002 /**
100003 * @license
100004 * Copyright 2020 Google LLC. All Rights Reserved.
100005 * Licensed under the Apache License, Version 2.0 (the "License");
100006 * you may not use this file except in compliance with the License.
100007 * You may obtain a copy of the License at
100008 *
100009 * http://www.apache.org/licenses/LICENSE-2.0
100010 *
100011 * Unless required by applicable law or agreed to in writing, software
100012 * distributed under the License is distributed on an "AS IS" BASIS,
100013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100014 * See the License for the specific language governing permissions and
100015 * limitations under the License.
100016 * =============================================================================
100017 */
100018 const nonMaxSuppressionV5Impl = nonMaxSuppressionV5Impl$2;
100019 function nonMaxSuppressionV5(args) {
100020 warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' +
100021 'Call tf.nonMaxSuppressionAsync() instead');
100022 const { inputs, backend, attrs } = args;
100023 const { boxes, scores } = inputs;
100024 const { maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma } = attrs;
100025 const boxesVals = backend.readSync(boxes.dataId);
100026 const scoresVals = backend.readSync(scores.dataId);
100027 const maxOutputSizeVal = maxOutputSize;
100028 const iouThresholdVal = iouThreshold;
100029 const scoreThresholdVal = scoreThreshold;
100030 const softNmsSigmaVal = softNmsSigma;
100031 const { selectedIndices, selectedScores } = nonMaxSuppressionV5Impl(boxesVals, scoresVals, maxOutputSizeVal, iouThresholdVal, scoreThresholdVal, softNmsSigmaVal);
100032 return [
100033 backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)),
100034 backend.makeTensorInfo([selectedScores.length], 'float32', new Float32Array(selectedScores))
100035 ];
100036 }
100037 const nonMaxSuppressionV5Config = {
100038 kernelName: NonMaxSuppressionV5,
100039 backendName: 'webgl',
100040 kernelFunc: nonMaxSuppressionV5
100041 };
100042
100043 /**
100044 * @license
100045 * Copyright 2017 Google LLC. All Rights Reserved.
100046 * Licensed under the Apache License, Version 2.0 (the "License");
100047 * you may not use this file except in compliance with the License.
100048 * You may obtain a copy of the License at
100049 *
100050 * http://www.apache.org/licenses/LICENSE-2.0
100051 *
100052 * Unless required by applicable law or agreed to in writing, software
100053 * distributed under the License is distributed on an "AS IS" BASIS,
100054 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100055 * See the License for the specific language governing permissions and
100056 * limitations under the License.
100057 * =============================================================================
100058 */
100059 class OneHotProgram {
100060 constructor(numIndices, depth, onValue, offValue) {
100061 this.variableNames = ['indices'];
100062 this.outputShape = [numIndices, depth];
100063 this.userCode = `
100064 void main() {
100065 ivec2 coords = getOutputCoords();
100066 int index = round(getIndices(coords.x));
100067 setOutput(mix(float(${offValue}), float(${onValue}),
100068 float(index == coords.y)));
100069 }
100070 `;
100071 }
100072 }
100073
100074 /**
100075 * @license
100076 * Copyright 2020 Google LLC. All Rights Reserved.
100077 * Licensed under the Apache License, Version 2.0 (the "License");
100078 * you may not use this file except in compliance with the License.
100079 * You may obtain a copy of the License at
100080 *
100081 * http://www.apache.org/licenses/LICENSE-2.0
100082 *
100083 * Unless required by applicable law or agreed to in writing, software
100084 * distributed under the License is distributed on an "AS IS" BASIS,
100085 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100086 * See the License for the specific language governing permissions and
100087 * limitations under the License.
100088 * =============================================================================
100089 */
100090 const oneHot = (args) => {
100091 const { inputs, backend, attrs } = args;
100092 const { indices } = inputs;
100093 const { dtype, depth, onValue, offValue } = attrs;
100094 const indicesSize = sizeFromShape(indices.shape);
100095 const program = new OneHotProgram(indicesSize, depth, onValue, offValue);
100096 const reshaped = reshape({ inputs: { x: indices }, backend, attrs: { shape: [indicesSize] } });
100097 const result = backend.runWebGLProgram(program, [reshaped], dtype);
100098 backend.disposeIntermediateTensorInfo(reshaped);
100099 const outShape = [...indices.shape, depth];
100100 const out = reshape({ inputs: { x: result }, backend, attrs: { shape: outShape } });
100101 backend.disposeIntermediateTensorInfo(result);
100102 return out;
100103 };
100104 const oneHotConfig = {
100105 kernelName: OneHot,
100106 backendName: 'webgl',
100107 kernelFunc: oneHot
100108 };
100109
100110 /**
100111 * @license
100112 * Copyright 2020 Google LLC. All Rights Reserved.
100113 * Licensed under the Apache License, Version 2.0 (the "License");
100114 * you may not use this file except in compliance with the License.
100115 * You may obtain a copy of the License at
100116 *
100117 * http://www.apache.org/licenses/LICENSE-2.0
100118 *
100119 * Unless required by applicable law or agreed to in writing, software
100120 * distributed under the License is distributed on an "AS IS" BASIS,
100121 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100122 * See the License for the specific language governing permissions and
100123 * limitations under the License.
100124 * =============================================================================
100125 */
100126 function zerosLike(args) {
100127 const { inputs, backend } = args;
100128 const { x } = inputs;
100129 if (x.dtype === 'complex64') {
100130 const realPart = real({ inputs: { input: x }, backend });
100131 const r = zerosLike({ inputs: { x: realPart }, backend });
100132 const imagPart = imag({ inputs: { input: x }, backend });
100133 const i = zerosLike({ inputs: { x: imagPart }, backend });
100134 const result = complex({ inputs: { real: r, imag: i }, backend });
100135 backend.disposeIntermediateTensorInfo(realPart);
100136 backend.disposeIntermediateTensorInfo(r);
100137 backend.disposeIntermediateTensorInfo(imagPart);
100138 backend.disposeIntermediateTensorInfo(i);
100139 return result;
100140 }
100141 else {
100142 return fill({
100143 attrs: {
100144 shape: x.shape,
100145 dtype: x.dtype,
100146 value: x.dtype === 'string' ? '' : 0
100147 },
100148 backend
100149 });
100150 }
100151 }
100152 const zerosLikeConfig = {
100153 kernelName: ZerosLike,
100154 backendName: 'webgl',
100155 kernelFunc: zerosLike
100156 };
100157
100158 /**
100159 * @license
100160 * Copyright 2020 Google LLC. All Rights Reserved.
100161 * Licensed under the Apache License, Version 2.0 (the "License");
100162 * you may not use this file except in compliance with the License.
100163 * You may obtain a copy of the License at
100164 *
100165 * http://www.apache.org/licenses/LICENSE-2.0
100166 *
100167 * Unless required by applicable law or agreed to in writing, software
100168 * distributed under the License is distributed on an "AS IS" BASIS,
100169 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100170 * See the License for the specific language governing permissions and
100171 * limitations under the License.
100172 * =============================================================================
100173 */
100174 function onesLike(args) {
100175 const { inputs, backend } = args;
100176 const { x } = inputs;
100177 if (x.dtype === 'string') {
100178 throw new Error('onesLike is not supported under string dtype');
100179 }
100180 else if (x.dtype === 'complex64') {
100181 const realPart = real({ inputs: { input: x }, backend });
100182 const r = onesLike({ inputs: { x: realPart }, backend });
100183 const imagPart = imag({ inputs: { input: x }, backend });
100184 const i = zerosLike({ inputs: { x: imagPart }, backend });
100185 const result = complex({ inputs: { real: r, imag: i }, backend });
100186 backend.disposeIntermediateTensorInfo(realPart);
100187 backend.disposeIntermediateTensorInfo(r);
100188 backend.disposeIntermediateTensorInfo(imagPart);
100189 backend.disposeIntermediateTensorInfo(i);
100190 return result;
100191 }
100192 else {
100193 // TODO(cais, smilkov): Add WebGL shader for onesLike:
100194 // https://github.com/tensorflow/tfjs/issues/1293
100195 return fill({ attrs: { shape: x.shape, dtype: x.dtype, value: 1 }, backend });
100196 }
100197 }
100198 const onesLikeConfig = {
100199 kernelName: OnesLike,
100200 backendName: 'webgl',
100201 kernelFunc: onesLike
100202 };
100203
100204 /**
100205 * @license
100206 * Copyright 2020 Google LLC. All Rights Reserved.
100207 * Licensed under the Apache License, Version 2.0 (the "License");
100208 * you may not use this file except in compliance with the License.
100209 * You may obtain a copy of the License at
100210 *
100211 * http://www.apache.org/licenses/LICENSE-2.0
100212 *
100213 * Unless required by applicable law or agreed to in writing, software
100214 * distributed under the License is distributed on an "AS IS" BASIS,
100215 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100216 * See the License for the specific language governing permissions and
100217 * limitations under the License.
100218 * =============================================================================
100219 */
100220 function pack(args) {
100221 const { inputs, backend, attrs } = args;
100222 const { axis } = attrs;
100223 if (inputs.length === 1) {
100224 return expandDims({ inputs: { input: inputs[0] }, backend, attrs: { dim: axis } });
100225 }
100226 const shape = inputs[0].shape;
100227 const dtype = inputs[0].dtype;
100228 inputs.forEach(t => {
100229 assertShapesMatch(shape, t.shape, 'All tensors passed to stack must have matching shapes');
100230 assert$1(dtype === t.dtype, () => 'All tensors passed to stack must have matching dtypes');
100231 });
100232 const intermediateTensorInfos = [];
100233 const expandedTensors = inputs.map(t => {
100234 const expandedT = expandDims({ inputs: { input: t }, backend, attrs: { dim: axis } });
100235 intermediateTensorInfos.push(expandedT);
100236 return expandedT;
100237 });
100238 const result = concat({ inputs: expandedTensors, backend, attrs: { axis } });
100239 intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
100240 return result;
100241 }
100242 const packConfig = {
100243 kernelName: Pack,
100244 backendName: 'webgl',
100245 kernelFunc: pack
100246 };
100247
100248 /**
100249 * @license
100250 * Copyright 2017 Google LLC. All Rights Reserved.
100251 * Licensed under the Apache License, Version 2.0 (the "License");
100252 * you may not use this file except in compliance with the License.
100253 * You may obtain a copy of the License at
100254 *
100255 * http://www.apache.org/licenses/LICENSE-2.0
100256 *
100257 * Unless required by applicable law or agreed to in writing, software
100258 * distributed under the License is distributed on an "AS IS" BASIS,
100259 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100260 * See the License for the specific language governing permissions and
100261 * limitations under the License.
100262 * =============================================================================
100263 */
100264 class PadProgram {
100265 constructor(xShape, paddings, constantValue) {
100266 this.variableNames = ['x'];
100267 this.customUniforms = [{ name: 'value', type: 'float' }];
100268 this.outputShape = paddings.map((p, i) => p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */);
100269 const rank = xShape.length;
100270 const type = getCoordsDataType(rank);
100271 const start = paddings.map(p => p[0]).join(',');
100272 const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
100273 const unpackedCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank);
100274 if (rank === 1) {
100275 this.userCode = `
100276 int start = ${start};
100277 int end = ${end};
100278
100279 void main() {
100280 int outC = getOutputCoords();
100281 if (outC < start || outC >= end) {
100282 setOutput(value);
100283 } else {
100284 setOutput(getX(outC - start));
100285 }
100286 }
100287 `;
100288 return;
100289 }
100290 this.userCode = `
100291 ${type} start = ${type}(${start});
100292 ${type} end = ${type}(${end});
100293
100294 void main() {
100295 ${type} outC = getOutputCoords();
100296 if (any(lessThan(outC, start)) || any(greaterThanEqual(outC, end))) {
100297 setOutput(value);
100298 } else {
100299 ${type} coords = outC - start;
100300 setOutput(getX(${unpackedCoords}));
100301 }
100302 }
100303 `;
100304 }
100305 }
100306
100307 /**
100308 * @license
100309 * Copyright 2019 Google LLC. All Rights Reserved.
100310 * Licensed under the Apache License, Version 2.0 (the "License");
100311 * you may not use this file except in compliance with the License.
100312 * You may obtain a copy of the License at
100313 *
100314 * http://www.apache.org/licenses/LICENSE-2.0
100315 *
100316 * Unless required by applicable law or agreed to in writing, software
100317 * distributed under the License is distributed on an "AS IS" BASIS,
100318 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100319 * See the License for the specific language governing permissions and
100320 * limitations under the License.
100321 * =============================================================================
100322 */
    /**
     * Packed (vec4) variant of PadProgram: pads a tensor with a constant
     * `value` uniform, computing up to four adjacent output texels per shader
     * invocation (two for rank-1 inputs).
     */
    class PadPackedProgram {
        constructor(xShape, paddings, constantValue) {
            this.variableNames = ['x'];
            this.packedInputs = true;
            this.packedOutput = true;
            this.customUniforms = [{ name: 'value', type: 'float' }];
            // Each output dim = beforePad + inputDim + afterPad.
            this.outputShape = paddings.map((p, i) => p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */);
            const rank = xShape.length;
            const dtype = getCoordsDataType(rank);
            const start = paddings.map(p => p[0]).join(',');
            const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
            const coords = getChannels('rc', rank);
            const source = getChannels('source', rank);
            // Bounds guard for the innermost (fastest-varying) coordinate.
            const cLimit = `${coords[rank - 1]} < ${this.outputShape[rank - 1]}`;
            const innerDims = rank === 1 ? 'source' : `vec2(${source.slice(-2).join()})`;
            // GLSL fragments that step `rc` to each of the (up to) four texel
            // positions covered by one packed output. The fragments open/close
            // `if` blocks across loop iterations, so their exact text (including
            // unbalanced braces) is load-bearing for the assembled shader.
            const componentSetup = [
                `${dtype} rc = outputLoc;`, `${coords[rank - 1]} += 1;
       if(${cLimit}) {
      `,
                rank === 1 ? '' : `}
       rc = outputLoc;
       ${coords[rank - 2]} += 1;
       if(${coords[rank - 2]} < ${this.outputShape[rank - 2]}) {`,
                rank === 1 ? '' : ` ${coords[rank - 1]} += 1;
        if(${cLimit}) {`
            ];
            // Padding test for the current `rc`: scalar compare for rank 1,
            // component-wise compare otherwise.
            const paddingArea = rank === 1 ?
                'rc < start || rc >= end' :
                'any(lessThan(rc, start)) || any(greaterThanEqual(rc, end))';
            let mainLoop = '';
            // Rank 1 packs 2 components per texel; higher ranks pack 4.
            for (let i = 0, j = rank === 1 ? 2 : 4; i < j; i++) {
                mainLoop += `
        ${componentSetup[i]}
        if (${paddingArea}) {
          result[${i}] = float(value);
        } else {
          ${dtype} source = rc - start;
          result[${i}] = getChannel(getX(${source.join()}), ${innerDims});
        }
      `;
            }
            // Close the `if` blocks left open by the componentSetup fragments.
            mainLoop += (rank === 1 ? `} ` : `}}`);
            this.userCode = `
      const ${dtype} start = ${dtype}(${start});
      const ${dtype} end = ${dtype}(${end});

      void main() {
        ${dtype} outputLoc = getOutputCoords();
        vec4 result = vec4(0.);
        ${mainLoop}
        setOutput(result);
      }
    `;
        }
    }
100378
100379 /**
100380 * @license
100381 * Copyright 2020 Google LLC. All Rights Reserved.
100382 * Licensed under the Apache License, Version 2.0 (the "License");
100383 * you may not use this file except in compliance with the License.
100384 * You may obtain a copy of the License at
100385 *
100386 * http://www.apache.org/licenses/LICENSE-2.0
100387 *
100388 * Unless required by applicable law or agreed to in writing, software
100389 * distributed under the License is distributed on an "AS IS" BASIS,
100390 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100391 * See the License for the specific language governing permissions and
100392 * limitations under the License.
100393 * =============================================================================
100394 */
/**
 * WebGL kernel for PadV2: pads `x` with `constantValue` according to
 * `paddings`, an array of [beforePad, afterPad] pairs, one per dimension.
 *
 * @param args {inputs: {x}, backend, attrs: {paddings, constantValue}}
 * @returns TensorInfo for the padded tensor (same dtype as `x`).
 */
const padV2 = (args) => {
    const { inputs: { x }, backend, attrs: { paddings, constantValue } } = args;
    // An empty input holds no data; only its shape matters, so the result is
    // simply a constant fill of the padded shape.
    if (sizeFromShape(x.shape) === 0) {
        const paddedShape = paddings.map((pad, dim) => pad[0] /* beforePad */ + x.shape[dim] + pad[1] /* afterPad */);
        return fill({
            backend,
            attrs: { shape: paddedShape, value: constantValue, dtype: x.dtype }
        });
    }
    const usePackedProgram = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS');
    const program = usePackedProgram ?
        new PadPackedProgram(x.shape, paddings, constantValue) :
        new PadProgram(x.shape, paddings, constantValue);
    // The pad value is fed to the shader as a custom uniform.
    return backend.runWebGLProgram(program, [x], x.dtype, [[constantValue]]);
};
/** Kernel registration for PadV2 on the WebGL backend. */
const padV2Config = {
    kernelName: PadV2,
    backendName: 'webgl',
    kernelFunc: padV2
};
100419
100420 /**
100421 * @license
100422 * Copyright 2020 Google LLC. All Rights Reserved.
100423 * Licensed under the Apache License, Version 2.0 (the "License");
100424 * you may not use this file except in compliance with the License.
100425 * You may obtain a copy of the License at
100426 *
100427 * http://www.apache.org/licenses/LICENSE-2.0
100428 *
100429 * Unless required by applicable law or agreed to in writing, software
100430 * distributed under the License is distributed on an "AS IS" BASIS,
100431 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100432 * See the License for the specific language governing permissions and
100433 * limitations under the License.
100434 * =============================================================================
100435 */
// GLSL snippet for element-wise power a^b (unpacked path). Matches JS/TF
// semantics: a negative base with a non-integer exponent yields NaN,
// b == 0 yields 1 (including 0^0), and an odd integer exponent keeps the
// sign of the base (even exponents use the magnitude only).
const POW = `
if(a < 0.0 && floor(b) < b){
  return NAN;
}
if (b == 0.0) {
  return 1.0;
}
return (round(mod(b, 2.0)) != 1) ?
    pow(abs(a), b) : sign(a) * pow(abs(a), b);
`;
// Packed (vec4, four lanes at once) variant of POW: computes
// sign-corrected pow for all lanes, then patches the a^0 == 1 case and the
// negative-base/non-integer-exponent NaN case per component.
// CHECK_NAN_SNIPPET_PACKED consumes the `isNaN` bvec4 built below.
const POW_PACKED = `
  // isModRound1 has 1 for components with round(mod(b, 2.0)) == 1, 0 otherwise.
  vec4 isModRound1 = vec4(equal(round(mod(b, 2.0)), ivec4(1)));
  vec4 multiplier = sign(a) * isModRound1 + (vec4(1.0) - isModRound1);
  vec4 result = multiplier * pow(abs(a), b);

  // Ensure that a^0 = 1, including 0^0 = 1 as this correspond to TF and JS
  bvec4 isExpZero = equal(b, vec4(0.0));
  result.r = isExpZero.r ? 1.0 : result.r;
  result.g = isExpZero.g ? 1.0 : result.g;
  result.b = isExpZero.b ? 1.0 : result.b;
  result.a = isExpZero.a ? 1.0 : result.a;

  bvec4 isNaN1 = lessThan(a, vec4(0.0));
  bvec4 isNaN2 = lessThan(floor(b), b);
  bvec4 isNaN = bvec4(isNaN1.x && isNaN2.x, isNaN1.y && isNaN2.y, isNaN1.z && isNaN2.z, isNaN1.w && isNaN2.w);
  ` +
    CHECK_NAN_SNIPPET_PACKED + `
  return result;
`;
// Binary kernel assembled from the scalar and packed snippets above.
const pow = binaryKernelFunc({ opSnippet: POW, packedOpSnippet: POW_PACKED });
// Kernel registration for Pow on the WebGL backend.
const powConfig = {
    kernelName: Pow,
    backendName: 'webgl',
    kernelFunc: pow
};
100472
100473 /**
100474 * @license
100475 * Copyright 2020 Google LLC. All Rights Reserved.
100476 * Licensed under the Apache License, Version 2.0 (the "License");
100477 * you may not use this file except in compliance with the License.
100478 * You may obtain a copy of the License at
100479 *
100480 * http://www.apache.org/licenses/LICENSE-2.0
100481 *
100482 * Unless required by applicable law or agreed to in writing, software
100483 * distributed under the License is distributed on an "AS IS" BASIS,
100484 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100485 * See the License for the specific language governing permissions and
100486 * limitations under the License.
100487 * =============================================================================
100488 */
/**
 * WebGL kernel for Prod: multiplies the elements of `x` along `axis`.
 *
 * The reduced axes are first moved to the inner-most positions (via a
 * transpose, when needed); the reduction itself runs either on the CPU
 * (for small / CPU-resident tensors) or as a 2-D GPU reduction.
 *
 * @param args {inputs: {x}, backend, attrs: {axis, keepDims}}
 * @returns TensorInfo of the reduced tensor.
 */
function prod(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    const rank = x.shape.length;
    const intermediates = [];
    const origAxes = parseAxisParam(axis, x.shape);
    let reductionAxes = origAxes;
    const perm = getAxesPermutation(reductionAxes, rank);
    let input = x;
    if (perm != null) {
        // Move the axes being reduced to the inner-most dimensions.
        input = transpose({ inputs: { x }, backend, attrs: { perm } });
        reductionAxes = getInnerMostAxes(reductionAxes.length, rank);
        intermediates.push(input);
    }
    assertAxesAreInnerMostDims('prod', reductionAxes, rank);
    let res;
    if (backend.shouldExecuteOnCPU([input])) {
        // CPU fast path: read the values directly and reduce on the CPU.
        const values = backend.texData.get(input.dataId).values;
        const { outVals, outShape, outDtype } = prodImplCPU(input.shape, input.dtype, values, reductionAxes);
        res = backend.makeTensorInfo(outShape, outDtype, outVals);
    }
    else {
        // GPU path: collapse to [batch, reduceSize] and run a 2-D reduction.
        const [outShape, reduceShape] = computeOutAndReduceShapes(input.shape, reductionAxes);
        const inSize = sizeFromShape(reduceShape);
        const a2D = reshape({ inputs: { x: input }, backend, attrs: { shape: [-1, inSize] } });
        const reduced = reduce(a2D, sumOutType(x.dtype), 'prod', backend);
        res = reshape({ inputs: { x: reduced }, backend, attrs: { shape: outShape } });
        intermediates.push(a2D, reduced);
    }
    if (keepDims) {
        // Re-insert the reduced dimensions as size-1 axes.
        intermediates.push(res);
        const newShape = expandShapeToKeepDim(res.shape, origAxes);
        res = reshape({ inputs: { x: res }, backend, attrs: { shape: newShape } });
    }
    for (const t of intermediates) {
        backend.disposeIntermediateTensorInfo(t);
    }
    return res;
}
/** Kernel registration for Prod on the WebGL backend. */
const prodConfig = {
    kernelName: Prod,
    backendName: 'webgl',
    kernelFunc: prod
};
100534
100535 /**
100536 * @license
100537 * Copyright 2022 Google LLC. All Rights Reserved.
100538 * Licensed under the Apache License, Version 2.0 (the "License");
100539 * you may not use this file except in compliance with the License.
100540 * You may obtain a copy of the License at
100541 *
100542 * http://www.apache.org/licenses/LICENSE-2.0
100543 *
100544 * Unless required by applicable law or agreed to in writing, software
100545 * distributed under the License is distributed on an "AS IS" BASIS,
100546 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100547 * See the License for the specific language governing permissions and
100548 * limitations under the License.
100549 * =============================================================================
100550 */
/**
 * WebGL kernel for RaggedGather. The ragged inputs are read back to the
 * CPU synchronously and the gather is delegated to the shared CPU
 * implementation; the results are then re-uploaded as tensors.
 *
 * @param args {inputs: {paramsNestedSplits, paramsDenseValues, indices},
 *     backend, attrs: {outputRaggedRank}}
 * @returns Array of split tensors (int32) followed by the dense-values tensor.
 */
function raggedGather(args) {
    const { inputs, backend, attrs } = args;
    const { paramsNestedSplits, paramsDenseValues, indices } = inputs;
    const { outputRaggedRank } = attrs;
    // Synchronously download every input tensor's values.
    const nestedSplitsValues = paramsNestedSplits.map(t => backend.readSync(t.dataId));
    const nestedSplitsShapes = paramsNestedSplits.map(t => t.shape);
    const denseValues = backend.readSync(paramsDenseValues.dataId);
    const indicesValues = backend.readSync(indices.dataId);
    const [outputNestedSplits, outputDenseValues, outputDenseValuesShape] = raggedGatherImplCPU(nestedSplitsValues, nestedSplitsShapes, denseValues, paramsDenseValues.shape, paramsDenseValues.dtype, indicesValues, indices.shape, outputRaggedRank);
    // One int32 tensor per output splits vector, plus the dense values.
    const splitsTensors = outputNestedSplits.map((splits) => backend.makeTensorInfo([splits.length], 'int32', splits));
    const denseValuesTensor = backend.makeTensorInfo(outputDenseValuesShape, paramsDenseValues.dtype, outputDenseValues);
    return [...splitsTensors, denseValuesTensor];
}
/** Kernel registration for RaggedGather on the WebGL backend. */
const raggedGatherConfig = {
    kernelName: RaggedGather,
    backendName: 'webgl',
    kernelFunc: raggedGather,
};
100569
100570 /**
100571 * @license
100572 * Copyright 2022 Google LLC.
100573 * Licensed under the Apache License, Version 2.0 (the "License");
100574 * you may not use this file except in compliance with the License.
100575 * You may obtain a copy of the License at
100576 *
100577 * http://www.apache.org/licenses/LICENSE-2.0
100578 *
100579 * Unless required by applicable law or agreed to in writing, software
100580 * distributed under the License is distributed on an "AS IS" BASIS,
100581 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100582 * See the License for the specific language governing permissions and
100583 * limitations under the License.
100584 * =============================================================================
100585 */
/**
 * WebGL kernel for RaggedRange. Reads starts/limits/deltas back to the CPU,
 * builds the ragged range there, and uploads the two resulting tensors.
 *
 * @param args {inputs: {starts, limits, deltas}, backend}
 * @returns [rtNestedSplits (int32), rtDenseValues (dtype of `starts`)].
 */
function raggedRange(args) {
    const { inputs, backend } = args;
    const { starts, limits, deltas } = inputs;
    // Synchronous CPU readback of all three inputs.
    const startsValues = backend.readSync(starts.dataId);
    const limitsValues = backend.readSync(limits.dataId);
    const deltasValues = backend.readSync(deltas.dataId);
    const [rtNestedSplitsData, rtDenseValuesData] = raggedRangeImplCPU(startsValues, starts.shape, starts.dtype, limitsValues, limits.shape, deltasValues, deltas.shape);
    const rtNestedSplits = backend.makeTensorInfo([rtNestedSplitsData.length], 'int32', rtNestedSplitsData);
    const rtDenseValues = backend.makeTensorInfo([rtDenseValuesData.length], starts.dtype, rtDenseValuesData);
    return [rtNestedSplits, rtDenseValues];
}
/** Kernel registration for RaggedRange on the WebGL backend. */
const raggedRangeConfig = {
    kernelName: RaggedRange,
    backendName: 'webgl',
    kernelFunc: raggedRange,
};
100602
100603 /**
100604 * @license
100605 * Copyright 2022 Google LLC. All Rights Reserved.
100606 * Licensed under the Apache License, Version 2.0 (the "License");
100607 * you may not use this file except in compliance with the License.
100608 * You may obtain a copy of the License at
100609 *
100610 * http://www.apache.org/licenses/LICENSE-2.0
100611 *
100612 * Unless required by applicable law or agreed to in writing, software
100613 * distributed under the License is distributed on an "AS IS" BASIS,
100614 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100615 * See the License for the specific language governing permissions and
100616 * limitations under the License.
100617 * =============================================================================
100618 */
/**
 * WebGL kernel for RaggedTensorToTensor. All inputs are read back to the
 * CPU and the densification is delegated to the shared CPU implementation.
 *
 * @param args {inputs: {shape, values, defaultValue, rowPartitionTensors},
 *     backend, attrs: {rowPartitionTypes}}
 * @returns TensorInfo of the dense output (dtype of `values`).
 */
function raggedTensorToTensor(args) {
    const { inputs, backend, attrs } = args;
    const { shape, values, defaultValue, rowPartitionTensors } = inputs;
    const { rowPartitionTypes } = attrs;
    // Synchronous CPU readback of every input.
    const shapeValues = backend.readSync(shape.dataId);
    const denseValues = backend.readSync(values.dataId);
    const defaultValues = backend.readSync(defaultValue.dataId);
    const rowPartitionValues = rowPartitionTensors.map(t => backend.readSync(t.dataId));
    const rowPartitionValuesShapes = rowPartitionTensors.map(t => t.shape);
    const [outputShape, output] = raggedTensorToTensorImplCPU(shapeValues, shape.shape, denseValues, values.shape, values.dtype, defaultValues, defaultValue.shape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes);
    return backend.makeTensorInfo(outputShape, values.dtype, output);
}
/** Kernel registration for RaggedTensorToTensor on the WebGL backend. */
const raggedTensorToTensorConfig = {
    kernelName: RaggedTensorToTensor,
    backendName: 'webgl',
    kernelFunc: raggedTensorToTensor,
};
100636
100637 /**
100638 * @license
100639 * Copyright 2020 Google LLC. All Rights Reserved.
100640 * Licensed under the Apache License, Version 2.0 (the "License");
100641 * you may not use this file except in compliance with the License.
100642 * You may obtain a copy of the License at
100643 *
100644 * http://www.apache.org/licenses/LICENSE-2.0
100645 *
100646 * Unless required by applicable law or agreed to in writing, software
100647 * distributed under the License is distributed on an "AS IS" BASIS,
100648 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100649 * See the License for the specific language governing permissions and
100650 * limitations under the License.
100651 * =============================================================================
100652 */
/**
 * WebGL kernel for Range. The sequence is 1-D and typically tiny, so it is
 * generated by the shared CPU implementation and uploaded as a tensor.
 *
 * @param args {backend, attrs: {start, stop, step, dtype}}
 * @returns 1-D TensorInfo holding [start, stop) with increment `step`.
 */
const range = (args) => {
    const { backend, attrs: { start, stop, step, dtype } } = args;
    const values = rangeImplCPU(start, stop, step, dtype);
    return backend.makeTensorInfo([values.length], dtype, values);
};
/** Kernel registration for Range on the WebGL backend. */
const rangeConfig = {
    kernelName: Range,
    backendName: 'webgl',
    kernelFunc: range
};
100664
100665 /**
100666 * @license
100667 * Copyright 2020 Google LLC. All Rights Reserved.
100668 * Licensed under the Apache License, Version 2.0 (the "License");
100669 * you may not use this file except in compliance with the License.
100670 * You may obtain a copy of the License at
100671 *
100672 * http://www.apache.org/licenses/LICENSE-2.0
100673 *
100674 * Unless required by applicable law or agreed to in writing, software
100675 * distributed under the License is distributed on an "AS IS" BASIS,
100676 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100677 * See the License for the specific language governing permissions and
100678 * limitations under the License.
100679 * =============================================================================
100680 */
// GLSL snippet for element-wise reciprocal: 1 / x.
const RECIPROCAL = `return 1.0 / x;`;
// Unary kernel built from the snippet above (unpacked path only).
const reciprocal = unaryKernelFunc({ opSnippet: RECIPROCAL });
// Kernel registration for Reciprocal on the WebGL backend.
const reciprocalConfig = {
    kernelName: Reciprocal,
    backendName: 'webgl',
    kernelFunc: reciprocal,
};
100688
100689 /**
100690 * @license
100691 * Copyright 2020 Google LLC. All Rights Reserved.
100692 * Licensed under the Apache License, Version 2.0 (the "License");
100693 * you may not use this file except in compliance with the License.
100694 * You may obtain a copy of the License at
100695 *
100696 * http://www.apache.org/licenses/LICENSE-2.0
100697 *
100698 * Unless required by applicable law or agreed to in writing, software
100699 * distributed under the License is distributed on an "AS IS" BASIS,
100700 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100701 * See the License for the specific language governing permissions and
100702 * limitations under the License.
100703 * =============================================================================
100704 */
// GLSL for ReLU (unpacked): clamps negatives to 0. CHECK_NAN_SNIPPET$1 is
// prepended so NaN inputs return NaN instead of being clamped.
const RELU = CHECK_NAN_SNIPPET$1 + `
  return (x < 0.0) ? 0.0 : x;
`;
// Packed (vec4) ReLU: zero out negative lanes via a 0/1 mask, then restore
// NaN lanes component-by-component so NaNs propagate unchanged.
const RELU_PACKED = `
  vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));
  bvec4 isNaN = isnan(x);

  result.r = isNaN.r ? x.r : result.r;
  result.g = isNaN.g ? x.g : result.g;
  result.b = isNaN.b ? x.b : result.b;
  result.a = isNaN.a ? x.a : result.a;

  return result;
`;
// Unary kernel assembled from the scalar and packed snippets above.
const relu = unaryKernelFunc({ opSnippet: RELU, packedOpSnippet: RELU_PACKED });
// Kernel registration for Relu on the WebGL backend.
const reluConfig = {
    kernelName: Relu$1,
    backendName: 'webgl',
    kernelFunc: relu
};
100725
100726 /**
100727 * @license
100728 * Copyright 2020 Google LLC. All Rights Reserved.
100729 * Licensed under the Apache License, Version 2.0 (the "License");
100730 * you may not use this file except in compliance with the License.
100731 * You may obtain a copy of the License at
100732 *
100733 * http://www.apache.org/licenses/LICENSE-2.0
100734 *
100735 * Unless required by applicable law or agreed to in writing, software
100736 * distributed under the License is distributed on an "AS IS" BASIS,
100737 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100738 * See the License for the specific language governing permissions and
100739 * limitations under the License.
100740 * =============================================================================
100741 */
// GLSL for ReLU6 (unpacked): clamp to [0, 6]. CHECK_NAN_SNIPPET$1 is
// prepended so NaN inputs return NaN instead of being clamped.
const RELU6 = CHECK_NAN_SNIPPET$1 + `
  return (x < 0.0) ? 0.0 : min(6.0, x);
`;
// Packed (vec4) ReLU6: cap at 6, zero out negative lanes via a 0/1 mask,
// then restore NaN lanes component-by-component so NaNs propagate.
const RELU6_PACKED = `
  vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0)));
  bvec4 isNaN = isnan(x);

  result.r = isNaN.r ? x.r : result.r;
  result.g = isNaN.g ? x.g : result.g;
  result.b = isNaN.b ? x.b : result.b;
  result.a = isNaN.a ? x.a : result.a;

  return result;
`;
// Unary kernel assembled from the scalar and packed snippets above.
const relu6 = unaryKernelFunc({ opSnippet: RELU6, packedOpSnippet: RELU6_PACKED });
// Kernel registration for Relu6 on the WebGL backend.
const relu6Config = {
    kernelName: Relu6$1,
    backendName: 'webgl',
    kernelFunc: relu6
};
100762
100763 /**
100764 * @license
100765 * Copyright 2017 Google LLC. All Rights Reserved.
100766 * Licensed under the Apache License, Version 2.0 (the "License");
100767 * you may not use this file except in compliance with the License.
100768 * You may obtain a copy of the License at
100769 *
100770 * http://www.apache.org/licenses/LICENSE-2.0
100771 *
100772 * Unless required by applicable law or agreed to in writing, software
100773 * distributed under the License is distributed on an "AS IS" BASIS,
100774 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100775 * See the License for the specific language governing permissions and
100776 * limitations under the License.
100777 * =============================================================================
100778 */
/**
 * GPGPU program that resizes a batch of images with bilinear interpolation
 * (dense path: one output value per invocation).
 */
class ResizeBilinearProgram {
    /**
     * @param inputShape [batch, oldHeight, oldWidth, depth] of the input.
     * @param newHeight Output height.
     * @param newWidth Output width.
     * @param alignCorners When true, corner pixels of input and output are
     *     aligned, so the scale becomes (inSize - 1) / (outSize - 1).
     * @param halfPixelCenters When true, sampling uses half-pixel-center
     *     (TF2) semantics: src = (dst + 0.5) * scale - 0.5.
     */
    constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
        this.variableNames = ['A'];
        this.outputShape = [];
        const [batch, oldHeight, oldWidth, depth] = inputShape;
        this.outputShape = [batch, newHeight, newWidth, depth];
        // With alignCorners the endpoints map onto each other, so the
        // effective extent on each axis shrinks by one (when size > 1).
        const effectiveInSize = [
            (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
            (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
        ];
        const effectiveOutSize = [
            (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
            (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
        ];
        // GLSL expression mapping an output (row, col) to a fractional
        // source index; form depends on half-pixel-center semantics.
        let sourceFracIndexRC;
        if (halfPixelCenters) {
            sourceFracIndexRC =
                `(vec2(yRC) + vec2(0.5)) * effectiveInputOverOutputRatioRC` +
                    ` - vec2(0.5)`;
        }
        else {
            sourceFracIndexRC = `vec2(yRC) * effectiveInputOverOutputRatioRC`;
        }
        this.userCode = `
      const vec2 effectiveInputOverOutputRatioRC = vec2(
          ${effectiveInSize[0] / effectiveOutSize[0]},
          ${effectiveInSize[1] / effectiveOutSize[1]});
      const vec2 inputShapeRC = vec2(${oldHeight}.0, ${oldWidth}.0);

      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int d = coords[3];
        ivec2 yRC = coords.yz;

        // Fractional source index.
        vec2 sourceFracIndexRC = ${sourceFracIndexRC};

        // Compute the four integer indices.
        ivec2 sourceFloorRC = ivec2(max(sourceFracIndexRC, vec2(0.0)));
        ivec2 sourceCeilRC = ivec2(
          min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));

        float topLeft = getA(b, sourceFloorRC.x, sourceFloorRC.y, d);
        float bottomLeft = getA(b, sourceCeilRC.x, sourceFloorRC.y, d);
        float topRight = getA(b, sourceFloorRC.x, sourceCeilRC.y, d);
        float bottomRight = getA(b, sourceCeilRC.x, sourceCeilRC.y, d);

        vec2 fracRC = sourceFracIndexRC - vec2(sourceFloorRC);

        float top = topLeft + (topRight - topLeft) * fracRC.y;
        float bottom = bottomLeft + (bottomRight - bottomLeft) * fracRC.y;
        float newValue = top + (bottom - top) * fracRC.x;

        setOutput(newValue);
      }
    `;
    }
}
100838
100839 /**
100840 * @license
100841 * Copyright 2019 Google LLC. All Rights Reserved.
100842 * Licensed under the Apache License, Version 2.0 (the "License");
100843 * you may not use this file except in compliance with the License.
100844 * You may obtain a copy of the License at
100845 *
100846 * http://www.apache.org/licenses/LICENSE-2.0
100847 *
100848 * Unless required by applicable law or agreed to in writing, software
100849 * distributed under the License is distributed on an "AS IS" BASIS,
100850 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100851 * See the License for the specific language governing permissions and
100852 * limitations under the License.
100853 * =============================================================================
100854 */
/**
 * Packed variant of the bilinear-resize GPGPU program: each invocation
 * writes a vec4 covering a 2x2 packed cell (two adjacent channels x two
 * adjacent output columns).
 */
class ResizeBilinearPackedProgram {
    /**
     * @param inputShape [batch, oldHeight, oldWidth, depth] of the input.
     * @param newHeight Output height.
     * @param newWidth Output width.
     * @param alignCorners When true, corner pixels of input and output are
     *     aligned, so the scale becomes (inSize - 1) / (outSize - 1).
     * @param halfPixelCenters When true, sampling uses half-pixel-center
     *     (TF2) semantics: src = (dst + 0.5) * scale - 0.5.
     */
    constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = [];
        const [batch, oldHeight, oldWidth, depth] = inputShape;
        this.outputShape = [batch, newHeight, newWidth, depth];
        // With alignCorners the endpoints map onto each other, so the
        // effective extent on each axis shrinks by one (when size > 1).
        const effectiveInSize = [
            (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
            (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
        ];
        const effectiveOutSize = [
            (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
            (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
        ];
        // vec3 because yRC carries (row, col, col + 1): the z component
        // computes the neighboring column of the packed cell in parallel.
        let sourceFracIndexRC;
        if (halfPixelCenters) {
            sourceFracIndexRC = `(vec3(yRC) + vec3(0.5)) * ` +
                `effectiveInputOverOutputRatioRC - vec3(0.5)`;
        }
        else {
            sourceFracIndexRC = `vec3(yRC) * effectiveInputOverOutputRatioRC`;
        }
        this.userCode = `
      const vec3 effectiveInputOverOutputRatioRC = vec3(
          ${effectiveInSize[0] / effectiveOutSize[0]},
          ${effectiveInSize[1] / effectiveOutSize[1]},
          ${effectiveInSize[1] / effectiveOutSize[1]});
      const vec3 inputShapeRC = vec3(${oldHeight}.0, ${oldWidth}.0,
                                     ${oldWidth}.0);

      float getAValue(int b, int r, int c, int d) {
        return getChannel(getA(b, r, c, d), vec2(c, d));
      }

      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int d = coords[3];
        // Calculate values for next column in yRC.z.
        ivec3 yRC = coords.yzz + ivec3(0, 0, 1);

        // Fractional source index.
        vec3 sourceFracIndexRC = ${sourceFracIndexRC};

        // Compute the four integer indices.
        ivec3 sourceFloorRC = ivec3(max(sourceFracIndexRC, vec3(0.0)));
        ivec3 sourceCeilRC = ivec3(
          min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));

        // Should we calculate next column and row elements in 2x2 packed cell.
        bool hasNextCol = d < ${depth - 1};
        bool hasNextRow = coords.z < ${newWidth - 1};

        // In parallel, construct four corners for all four components in
        // packed 2x2 cell.
        vec4 topLeft = vec4(
          getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d),
          hasNextCol ? getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d + 1)
                     : 0.0,
          hasNextRow ? getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d)
                     : 0.0,
          (hasNextRow && hasNextCol) ?
            getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d + 1) : 0.0);

        vec4 bottomLeft = vec4(
          getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d),
          hasNextCol ? getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d + 1)
                     : 0.0,
          hasNextRow ? getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d)
                     : 0.0,
          (hasNextRow && hasNextCol) ?
            getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d + 1) : 0.0);

        vec4 topRight = vec4(
          getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d),
          hasNextCol ? getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d + 1)
                     : 0.0,
          hasNextRow ? getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d)
                     : 0.0,
          (hasNextRow && hasNextCol) ?
            getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d + 1) : 0.0);

        vec4 bottomRight = vec4(
          getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d),
          hasNextCol ? getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d + 1)
                     : 0.0,
          hasNextRow ? getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d)
                     : 0.0,
          (hasNextRow && hasNextCol) ?
            getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d + 1) : 0.0);

        vec3 fracRC = sourceFracIndexRC - vec3(sourceFloorRC);

        vec4 top = mix(topLeft, topRight, fracRC.yyzz);
        vec4 bottom = mix(bottomLeft, bottomRight, fracRC.yyzz);
        vec4 newValue = mix(top, bottom, fracRC.x);

        setOutput(newValue);
      }
    `;
    }
}
100959
100960 /**
100961 * @license
100962 * Copyright 2020 Google LLC. All Rights Reserved.
100963 * Licensed under the Apache License, Version 2.0 (the "License");
100964 * you may not use this file except in compliance with the License.
100965 * You may obtain a copy of the License at
100966 *
100967 * http://www.apache.org/licenses/LICENSE-2.0
100968 *
100969 * Unless required by applicable law or agreed to in writing, software
100970 * distributed under the License is distributed on an "AS IS" BASIS,
100971 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100972 * See the License for the specific language governing permissions and
100973 * limitations under the License.
100974 * =============================================================================
100975 */
/**
 * WebGL kernel for ResizeBilinear. Chooses the packed or dense program
 * based on the WEBGL_PACK_IMAGE_OPERATIONS flag; the output dtype is
 * always float32.
 *
 * @param args {inputs: {images}, backend,
 *     attrs: {alignCorners, halfPixelCenters, size}}
 * @returns TensorInfo of shape [batch, newHeight, newWidth, depth].
 */
function resizeBilinear(args) {
    const { inputs, backend, attrs } = args;
    const { images } = inputs;
    const { alignCorners, halfPixelCenters, size } = attrs;
    const [newHeight, newWidth] = size;
    const usePackedProgram = env().getBool('WEBGL_PACK_IMAGE_OPERATIONS');
    const program = usePackedProgram ?
        new ResizeBilinearPackedProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters) :
        new ResizeBilinearProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters);
    return backend.runWebGLProgram(program, [images], 'float32');
}
/** Kernel registration for ResizeBilinear on the WebGL backend. */
const resizeBilinearConfig = {
    kernelName: ResizeBilinear,
    backendName: 'webgl',
    kernelFunc: resizeBilinear
};
100991
100992 /**
100993 * @license
100994 * Copyright 2018 Google LLC. All Rights Reserved.
100995 * Licensed under the Apache License, Version 2.0 (the "License");
100996 * you may not use this file except in compliance with the License.
100997 * You may obtain a copy of the License at
100998 *
100999 * http://www.apache.org/licenses/LICENSE-2.0
101000 *
101001 * Unless required by applicable law or agreed to in writing, software
101002 * distributed under the License is distributed on an "AS IS" BASIS,
101003 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101004 * See the License for the specific language governing permissions and
101005 * limitations under the License.
101006 * =============================================================================
101007 */
/**
 * GPGPU program computing the gradient of bilinear resize: scatters each
 * dy pixel's interpolation coefficients back onto the input-sized output.
 */
class ResizeBilinearBackpropProgram {
    /**
     * @param dyShape [batch, yHeight, yWidth, depth] of the incoming gradient.
     * @param inputShape [batch, xHeight, xWidth, depth] of the forward input;
     *     also the output shape of this program.
     * @param alignCorners Must match the forward pass's alignCorners flag.
     */
    constructor(dyShape, inputShape, alignCorners) {
        this.variableNames = ['dy'];
        this.outputShape = [];
        this.outputShape = inputShape;
        const [, xHeight, xWidth,] = inputShape;
        const [, yHeight, yWidth] = dyShape;
        // In the backwards pass, we want to find the pixels that were generated for
        // each pixel in the input image the forward pass and add the corresponding
        // coefficient from dy to the gradient (with some interpolation).
        const effectiveXSize = [
            (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
            (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
        ];
        const effectiveYSize = [
            (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
            (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
        ];
        const heightScale = effectiveXSize[0] / effectiveYSize[0];
        const widthScale = effectiveXSize[1] / effectiveYSize[1];
        const invHeightScale = 1 / heightScale;
        const invWidthScale = 1 / widthScale;
        // This defines the size of the window of values around a particular
        // index in dy that we want to search for contributions to dx.
        const winHeight = (Math.ceil(invHeightScale) * 2) + 2;
        const winWidth = (Math.ceil(invWidthScale) * 2) + 2;
        this.userCode = `
      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int d = coords[3];
        int r = coords[1];
        int c = coords[2];

        float accumulator = 0.0;

        const float heightScale = float(${heightScale});
        const float widthScale = float(${widthScale});

        const float invHeightScale = float(${invHeightScale});
        const float invWidthScale = float(${invWidthScale});

        const int winHeight = int(${winHeight});
        const int winWidth = int(${winWidth});

        // Compute bounds for where in dy we will look
        float startRLerp = floor(float(r) * invHeightScale);
        int startDyR = int(startRLerp - float(winHeight / 2));

        float startCLerp = floor(float(c) * invWidthScale);
        int startDyC = int(startCLerp - float(winWidth / 2));

        // Loop over dy
        for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {
          int dyR = dyROffset + startDyR;

          // Guard against the window exceeding the bounds of dy
          if (dyR < 0 || dyR >= ${yHeight}) {
            continue;
          }

          for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {
            int dyC = dyCOffset + startDyC;

            // Guard against the window exceeding the bounds of dy
            if (dyC < 0 || dyC >= ${yWidth}) {
              continue;
            }

            float dxR = float(dyR) * heightScale;
            int topDxRIndex = int(floor(dxR));
            int bottomDxRIndex = int(min(ceil(dxR), ${xHeight - 1}.0));
            float dxRLerp = dxR - float(topDxRIndex);
            float inverseDxRLerp = 1.0 - dxRLerp;

            float dxC = float(dyC) * widthScale;
            int leftDxCIndex = int(floor(dxC));
            int rightDxCIndex = int(min(ceil(dxC), ${xWidth - 1}.0));
            float dxCLerp = dxC - float(leftDxCIndex);
            float inverseDxCLerp = 1.0 - dxCLerp;

            if (r == topDxRIndex && c == leftDxCIndex) {
              // topLeft
              accumulator +=
                getDy(b, dyR, dyC, d) * inverseDxRLerp * inverseDxCLerp;
            }

            if (r == topDxRIndex && c == rightDxCIndex) {
              // topRight
              accumulator += getDy(b, dyR, dyC, d) * inverseDxRLerp * dxCLerp;
            }

            if (r == bottomDxRIndex && c == leftDxCIndex) {
              // bottomLeft
              accumulator += getDy(b, dyR, dyC, d) * dxRLerp * inverseDxCLerp;
            }

            if (r == bottomDxRIndex && c == rightDxCIndex) {
              // bottomRight
              accumulator += getDy(b, dyR, dyC, d) * dxRLerp * dxCLerp;
            }
          }
        }
        // End loop over dy

        setOutput(accumulator);
      }
    `;
    }
}
101118
101119 /**
101120 * @license
101121 * Copyright 2020 Google LLC. All Rights Reserved.
101122 * Licensed under the Apache License, Version 2.0 (the "License");
101123 * you may not use this file except in compliance with the License.
101124 * You may obtain a copy of the License at
101125 *
101126 * http://www.apache.org/licenses/LICENSE-2.0
101127 *
101128 * Unless required by applicable law or agreed to in writing, software
101129 * distributed under the License is distributed on an "AS IS" BASIS,
101130 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101131 * See the License for the specific language governing permissions and
101132 * limitations under the License.
101133 * =============================================================================
101134 */
/**
 * WebGL kernel for the gradient of bilinear image resizing.
 *
 * Builds a ResizeBilinearBackpropProgram from the incoming-gradient shape
 * (dy) and the original image shape, then runs it over dy. The gradient
 * tensor keeps dy's dtype.
 */
function resizeBilinearGrad(args) {
    const { inputs: { images, dy }, backend, attrs: { alignCorners } } = args;
    const backpropProgram =
        new ResizeBilinearBackpropProgram(dy.shape, images.shape, alignCorners);
    return backend.runWebGLProgram(backpropProgram, [dy], dy.dtype);
}
// Kernel-registry entry binding the ResizeBilinearGrad kernel name to the
// WebGL implementation above.
const resizeBilinearGradConfig = {
    kernelName: ResizeBilinearGrad,
    backendName: 'webgl',
    kernelFunc: resizeBilinearGrad
};
101147
101148 /**
101149 * @license
101150 * Copyright 2018 Google LLC. All Rights Reserved.
101151 * Licensed under the Apache License, Version 2.0 (the "License");
101152 * you may not use this file except in compliance with the License.
101153 * You may obtain a copy of the License at
101154 *
101155 * http://www.apache.org/licenses/LICENSE-2.0
101156 *
101157 * Unless required by applicable law or agreed to in writing, software
101158 * distributed under the License is distributed on an "AS IS" BASIS,
101159 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101160 * See the License for the specific language governing permissions and
101161 * limitations under the License.
101162 * =============================================================================
101163 */
/**
 * Unpacked GLSL program for nearest-neighbor image resizing.
 *
 * Input and output are [batch, height, width, depth]. Each output pixel
 * samples the single nearest source pixel, with optional alignCorners and
 * half-pixel-center coordinate conventions.
 */
class ResizeNearestNeighborProgram {
    constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
        this.variableNames = ['A'];
        this.outputShape = [];
        const [batch, oldHeight, oldWidth, depth] = inputShape;
        this.outputShape = [batch, newHeight, newWidth, depth];
        // With alignCorners, corner pixels map exactly onto corner pixels, so
        // the effective extent of any axis longer than one pixel shrinks by 1.
        const effectiveInHeight =
            (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight;
        const effectiveInWidth =
            (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth;
        const effectiveOutHeight =
            (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight;
        const effectiveOutWidth =
            (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth;
        // Adding 0.5 before flooring rounds to nearest; adding 0.0 floors.
        let roundBase;
        if (alignCorners) {
            roundBase = '0.5';
        }
        else {
            roundBase = '0.0';
        }
        const sourceFracIndexRC = halfPixelCenters ?
            `max((vec2(yRC) + vec2(0.5)) * effectiveInputOverOutputRatioRC, vec2(0.0))` :
            `vec2(yRC) * effectiveInputOverOutputRatioRC`;
        this.userCode = `
      const vec2 effectiveInputOverOutputRatioRC = vec2(
          ${effectiveInHeight / effectiveOutHeight},
          ${effectiveInWidth / effectiveOutWidth});
      const vec2 inputShapeRC = vec2(${oldHeight}.0, ${oldWidth}.0);

      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int d = coords[3];
        ivec2 yRC = coords.yz;

        // Fractional source index.
        vec2 sourceFracIndexRC = ${sourceFracIndexRC};

        // Compute the coordinators of nearest neighbor point.
        ivec2 sourceNearestRC = ivec2(
          min(inputShapeRC - 1.0, floor(sourceFracIndexRC + ${roundBase})));
        float newValue = getA(b, sourceNearestRC.x, sourceNearestRC.y, d);

        setOutput(newValue);
      }
    `;
    }
}
101214
101215 /**
101216 * @license
101217 * Copyright 2019 Google LLC. All Rights Reserved.
101218 * Licensed under the Apache License, Version 2.0 (the "License");
101219 * you may not use this file except in compliance with the License.
101220 * You may obtain a copy of the License at
101221 *
101222 * http://www.apache.org/licenses/LICENSE-2.0
101223 *
101224 * Unless required by applicable law or agreed to in writing, software
101225 * distributed under the License is distributed on an "AS IS" BASIS,
101226 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101227 * See the License for the specific language governing permissions and
101228 * limitations under the License.
101229 * =============================================================================
101230 */
/**
 * Packed (rgba-texel) GLSL program for nearest-neighbor image resizing.
 *
 * Input/output layout is [batch, height, width, depth]. Each packed texel
 * holds channels (d, d+1) for two adjacent output columns, so the shader
 * computes source coordinates for the current column and the next one
 * (yRC.z) in a single invocation.
 */
class ResizeNearestNeighborPackedProgram {
    constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = [];
        const [batch, oldHeight, oldWidth, depth] = inputShape;
        this.outputShape = [batch, newHeight, newWidth, depth];
        // With alignCorners, corner pixels map exactly onto corner pixels, so
        // the effective extent of axes longer than one pixel shrinks by one.
        const effectiveInSize = [
            (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
            (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
        ];
        const effectiveOutSize = [
            (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
            (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
        ];
        // When align corners is false, round the value with floor (base 0.0);
        // otherwise add 0.5 before flooring to round to nearest.
        const roundBase = alignCorners ? '0.5' : '0.0';
        let sourceFracIndexRC;
        if (halfPixelCenters) {
            // Half-pixel centers: sample at pixel centers, clamped at zero.
            sourceFracIndexRC = `max((vec3(yRC) + vec3(0.5)) * ` +
                `effectiveInputOverOutputRatioRC, vec3(0.0))`;
        }
        else {
            sourceFracIndexRC = `vec3(yRC) * effectiveInputOverOutputRatioRC`;
        }
        // The third ratio/shape component intentionally repeats the width
        // values: yRC.z is the column index of the next packed column
        // (yRC = coords.yzz + ivec3(0, 0, 1) below).
        this.userCode = `
      const vec3 effectiveInputOverOutputRatioRC = vec3(
          ${effectiveInSize[0] / effectiveOutSize[0]},
          ${effectiveInSize[1] / effectiveOutSize[1]},
          ${effectiveInSize[1] / effectiveOutSize[1]});
      const vec3 inputShapeRC = vec3(${oldHeight}.0, ${oldWidth}.0,
                                     ${oldWidth}.0);

      float getAValue(int b, int r, int c, int d) {
        return getChannel(getA(b, r, c, d), vec2(c, d));
      }

      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int d = coords[3];
        // Calculate values for next column in yRC.z.
        ivec3 yRC = coords.yzz + ivec3(0, 0, 1);

        // Fractional source index.
        vec3 sourceFracIndexRC = ${sourceFracIndexRC};

        // Compute the coordinators of nearest neighbor point.
        ivec3 sourceNearestRC = ivec3(
          min(inputShapeRC - 1.0, floor(sourceFracIndexRC + ${roundBase})));

        // Should we calculate next column and row elements in 2x2 packed cell.
        bool hasNextCol = d < ${depth - 1};
        bool hasNextRow = coords.z < ${newWidth - 1};

        vec4 newValue = vec4(
          getAValue(b, sourceNearestRC.x, sourceNearestRC.y, d),
          hasNextCol ? getAValue(b, sourceNearestRC.x, sourceNearestRC.y, d + 1)
                     : 0.0,
          hasNextRow ? getAValue(b, sourceNearestRC.x, sourceNearestRC.z, d)
                     : 0.0,
          (hasNextRow && hasNextCol) ?
            getAValue(b, sourceNearestRC.x, sourceNearestRC.z, d + 1) : 0.0);

        setOutput(newValue);
      }
    `;
    }
}
101301
101302 /**
101303 * @license
101304 * Copyright 2020 Google LLC. All Rights Reserved.
101305 * Licensed under the Apache License, Version 2.0 (the "License");
101306 * you may not use this file except in compliance with the License.
101307 * You may obtain a copy of the License at
101308 *
101309 * http://www.apache.org/licenses/LICENSE-2.0
101310 *
101311 * Unless required by applicable law or agreed to in writing, software
101312 * distributed under the License is distributed on an "AS IS" BASIS,
101313 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101314 * See the License for the specific language governing permissions and
101315 * limitations under the License.
101316 * =============================================================================
101317 */
/**
 * WebGL kernel for nearest-neighbor image resizing.
 *
 * Selects the packed or unpacked shader program based on the
 * WEBGL_PACK_IMAGE_OPERATIONS flag and runs it over the images tensor,
 * preserving its dtype.
 */
function resizeNearestNeighbor(args) {
    const { inputs: { images }, backend, attrs } = args;
    const { alignCorners, halfPixelCenters, size } = attrs;
    const [newHeight, newWidth] = size;
    let program;
    if (env().getBool('WEBGL_PACK_IMAGE_OPERATIONS')) {
        program = new ResizeNearestNeighborPackedProgram(
            images.shape, newHeight, newWidth, alignCorners, halfPixelCenters);
    }
    else {
        program = new ResizeNearestNeighborProgram(
            images.shape, newHeight, newWidth, alignCorners, halfPixelCenters);
    }
    return backend.runWebGLProgram(program, [images], images.dtype);
}
// Kernel-registry entry binding the ResizeNearestNeighbor kernel name to the
// WebGL implementation above.
const resizeNearestNeighborConfig = {
    kernelName: ResizeNearestNeighbor,
    backendName: 'webgl',
    kernelFunc: resizeNearestNeighbor
};
101333
101334 /**
101335 * @license
101336 * Copyright 2018 Google LLC. All Rights Reserved.
101337 * Licensed under the Apache License, Version 2.0 (the "License");
101338 * you may not use this file except in compliance with the License.
101339 * You may obtain a copy of the License at
101340 *
101341 * http://www.apache.org/licenses/LICENSE-2.0
101342 *
101343 * Unless required by applicable law or agreed to in writing, software
101344 * distributed under the License is distributed on an "AS IS" BASIS,
101345 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101346 * See the License for the specific language governing permissions and
101347 * limitations under the License.
101348 * =============================================================================
101349 */
/**
 * GLSL program computing the gradient (dx) of nearest-neighbor resizing.
 *
 * The output has the original input-image shape. For each input pixel
 * (r, c) the shader scans a window of dy positions that could have sampled
 * this pixel in the forward pass and accumulates their gradients.
 *
 * NOTE(review): the missing 'h' in "Neigbor" is the name used at the
 * registration site below; renaming it would break that reference.
 */
class ResizeNearestNeigborBackpropProgram {
    constructor(dyShape, inputShape, alignCorners) {
        this.variableNames = ['dy'];
        this.outputShape = [];
        this.outputShape = inputShape;
        const [, xHeight, xWidth,] = inputShape;
        const [, yHeight, yWidth] = dyShape;
        // In the backwards pass, we want to find the pixels that were generated
        // for each pixel in the input image in the forward pass and add the
        // corresponding coefficient from dy to the gradient (with some
        // interpolation).
        const effectiveXSize = [
            (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
            (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
        ];
        const effectiveYSize = [
            (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
            (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
        ];
        const heightScale = effectiveXSize[0] / effectiveYSize[0];
        const widthScale = effectiveXSize[1] / effectiveYSize[1];
        const invHeightScale = 1 / heightScale;
        const invWidthScale = 1 / widthScale;
        // This defines the size of the window of values around a particular
        // index in dy that we want to search for contributions to dx.
        const winHeight = (Math.ceil(invHeightScale) * 2) + 2;
        const winWidth = (Math.ceil(invWidthScale) * 2) + 2;
        this.userCode = `
      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int d = coords[3];
        int r = coords[1];
        int c = coords[2];

        float accumulator = 0.0;

        const float heightScale = float(${heightScale});
        const float widthScale = float(${widthScale});

        const float invHeightScale = float(${invHeightScale});
        const float invWidthScale = float(${invWidthScale});

        const int winHeight = int(${winHeight});
        const int winWidth = int(${winWidth});

        // Compute bounds for where in dy we will look
        float startRLerp = floor(float(r) * invHeightScale);
        int startDyR = int(floor(startRLerp - float(winHeight / 2)));

        float startCLerp = floor(float(c) * invWidthScale);
        int startDyC = int(floor(startCLerp - float(winWidth / 2)));

        // Loop over dy
        for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {
          int dyR = dyROffset + startDyR;

          // Guard against the window exceeding the bounds of dy
          if (dyR < 0 || dyR >= ${yHeight}) {
            continue;
          }

          for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {
            int dyC = dyCOffset + startDyC;

            // Guard against the window exceeding the bounds of dy
            if (dyC < 0 || dyC >= ${yWidth}) {
              continue;
            }

            float sourceFracRow =
              float(${effectiveXSize[0]}) *
                (float(dyR) / float(${effectiveYSize[0]}));

            float sourceFracCol =
              float(${effectiveXSize[1]}) *
                (float(dyC) / float(${effectiveYSize[1]}));

            int sourceNearestRow = int(min(
              float(int(${xHeight}) - 1),
              ${alignCorners} ? float(round(sourceFracRow)) :
                                float(floor(sourceFracRow))));

            int sourceNearestCol = int(min(
              float(int(${xWidth}) - 1),
              ${alignCorners} ? float(round(sourceFracCol)) :
                                float(floor(sourceFracCol))));

            if (r == sourceNearestRow && c == sourceNearestCol) {
              accumulator += getDy(b, dyR, dyC, d);
            }
          }
        }
        // End loop over dy

        setOutput(accumulator);
      }
    `;
    }
}
101449
101450 /**
101451 * @license
101452 * Copyright 2020 Google LLC. All Rights Reserved.
101453 * Licensed under the Apache License, Version 2.0 (the "License");
101454 * you may not use this file except in compliance with the License.
101455 * You may obtain a copy of the License at
101456 *
101457 * http://www.apache.org/licenses/LICENSE-2.0
101458 *
101459 * Unless required by applicable law or agreed to in writing, software
101460 * distributed under the License is distributed on an "AS IS" BASIS,
101461 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101462 * See the License for the specific language governing permissions and
101463 * limitations under the License.
101464 * =============================================================================
101465 */
/**
 * WebGL kernel for the gradient of nearest-neighbor image resizing.
 *
 * Constructs the backprop shader from the incoming-gradient shape (dy) and
 * the original image shape, then runs it over dy, keeping dy's dtype.
 */
function resizeNearestNeighborGrad(args) {
    const { inputs: { images, dy }, backend, attrs: { alignCorners } } = args;
    const backpropProgram = new ResizeNearestNeigborBackpropProgram(
        dy.shape, images.shape, alignCorners);
    return backend.runWebGLProgram(backpropProgram, [dy], dy.dtype);
}
// Kernel-registry entry binding the ResizeNearestNeighborGrad kernel name to
// the WebGL implementation above.
const resizeNearestNeighborGradConfig = {
    kernelName: ResizeNearestNeighborGrad,
    backendName: 'webgl',
    kernelFunc: resizeNearestNeighborGrad
};
101478
101479 /**
101480 * @license
101481 * Copyright 2017 Google LLC. All Rights Reserved.
101482 * Licensed under the Apache License, Version 2.0 (the "License");
101483 * you may not use this file except in compliance with the License.
101484 * You may obtain a copy of the License at
101485 *
101486 * http://www.apache.org/licenses/LICENSE-2.0
101487 *
101488 * Unless required by applicable law or agreed to in writing, software
101489 * distributed under the License is distributed on an "AS IS" BASIS,
101490 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101491 * See the License for the specific language governing permissions and
101492 * limitations under the License.
101493 * =============================================================================
101494 */
/**
 * Unpacked GLSL program reversing a tensor along the given axes.
 *
 * Supports rank 1-4; axes listed in `axis` (and longer than one element)
 * read their input coordinate mirrored. Rank 1 gets a dedicated shader.
 */
class ReverseProgram {
    constructor(xShape, axis) {
        this.variableNames = ['x'];
        const rank = xShape.length;
        if (rank > 4) {
            throw new Error(`WebGL backend: Reverse of rank-${rank} tensor is not yet supported`);
        }
        this.outputShape = xShape;
        if (rank === 1) {
            this.userCode = `
        void main() {
          int coord = getOutputCoords();
          setOutput(getX(${xShape[0]} - coord - 1));
        }
      `;
            return;
        }
        // Mirror the coordinate on reversed axes; pass it through elsewhere.
        const inCoords = xShape
            .map((dim, i) => (axis.indexOf(i) !== -1 && dim !== 1) ?
                `${dim} - coords[${i}] - 1` :
                `coords[${i}]`)
            .join(',');
        const coordsType = getCoordsDataType(rank);
        this.userCode = `
      void main() {
        ${coordsType} coords = getOutputCoords();
        setOutput(getX(${inCoords}));
      }
    `;
    }
}
101528
101529 /**
101530 * @license
101531 * Copyright 2019 Google LLC. All Rights Reserved.
101532 * Licensed under the Apache License, Version 2.0 (the "License");
101533 * you may not use this file except in compliance with the License.
101534 * You may obtain a copy of the License at
101535 *
101536 * http://www.apache.org/licenses/LICENSE-2.0
101537 *
101538 * Unless required by applicable law or agreed to in writing, software
101539 * distributed under the License is distributed on an "AS IS" BASIS,
101540 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101541 * See the License for the specific language governing permissions and
101542 * limitations under the License.
101543 * =============================================================================
101544 */
/**
 * Packed (rgba-texel) GLSL program reversing a tensor along given axes.
 *
 * Each output texel holds a 2x2 cell of adjacent (row, column) elements, so
 * the shader computes the mirrored source coordinate for the current
 * element (r), the next column (g), the next row (b), and both (a).
 * Supports rank 1-4.
 */
class ReversePackedProgram {
    constructor(xShape, axis) {
        this.variableNames = ['x'];
        this.packedInputs = true;
        this.packedOutput = true;
        const rank = xShape.length;
        if (rank > 4) {
            throw new Error(`WebGL backend: Reverse of rank-${rank} tensor is not yet supported`);
        }
        this.outputShape = xShape;
        const channels = getChannels('rc', rank);
        // Guards: does the 2x2 cell's next column / next row exist in-bounds?
        const nextColumn = `${channels[rank - 1]} + 1 < ${this.outputShape[rank - 1]}`;
        const nextRow = `${channels[rank - 2]} + 1 < ${this.outputShape[rank - 2]}`;
        const type = getCoordsDataType(rank);
        if (rank === 1) {
            this.userCode = `
        void main(){
          int rc = getOutputCoords();
          vec4 result = vec4(0.);
          result.r = getChannel(getX(${xShape[0]} - rc - 1),
            ${xShape[0]} - rc - 1);
          if(${nextColumn}){
              result.g = getChannel(getX(${xShape[0]} - (rc + 1) - 1),
                ${xShape[0]} - (rc + 1) - 1);
          }
          setOutput(result);
        }
      `;
        }
        else {
            this.userCode = `
        void main() {
          ${type} rc = getOutputCoords();
          vec4 result = vec4(0.);
          result.r = ${getR(channels.slice())};
          if(${nextColumn}){
            result.g = ${getG(channels.slice())};
          }
          if(${nextRow}) {
            result.b = ${getB(channels.slice())};
            if(${nextColumn}) {
              result.a = ${getA(channels.slice())};
            }
          }
          setOutput(result);
        }
      `;
        }
        // Each helper below builds the source-lookup expression for one lane
        // of the packed texel; getG/getB/getA mutate their own copy of
        // `channels` to offset the column and/or row by one.
        function getR(channels) {
            return getChannel(channels);
        }
        function getG(channels) {
            channels[rank - 1] = '(' + channels[rank - 1] + ` + 1)`;
            return getChannel(channels);
        }
        function getB(channels) {
            channels[rank - 2] = '(' + channels[rank - 2] + ` + 1)`;
            return getChannel(channels);
        }
        function getA(channels) {
            channels[rank - 1] = '(' + channels[rank - 1] + ` + 1)`;
            channels[rank - 2] = '(' + channels[rank - 2] + ` + 1)`;
            return getChannel(channels);
        }
        // Builds `getChannel(getX(...), vec2(innerDims))` with per-axis
        // coordinates mirrored on the reversed axes.
        function getChannel(channels) {
            const inCoordsArray = xShape.map((_, i) => getInCoord(i, channels));
            const inCoords = inCoordsArray.join(',');
            const innerDims = inCoordsArray.slice(-2).join(',');
            return `getChannel(getX(${inCoords}), vec2(${innerDims}))`;
        }
        function getInCoord(i, channels1) {
            if (axis.indexOf(i) !== -1 && xShape[i] !== 1) {
                return `${xShape[i]} - ${channels1[i]} - 1`;
            }
            else {
                return `${channels1[i]}`;
            }
        }
    }
}
101625
101626 /**
101627 * @license
101628 * Copyright 2020 Google LLC. All Rights Reserved.
101629 * Licensed under the Apache License, Version 2.0 (the "License");
101630 * you may not use this file except in compliance with the License.
101631 * You may obtain a copy of the License at
101632 *
101633 * http://www.apache.org/licenses/LICENSE-2.0
101634 *
101635 * Unless required by applicable law or agreed to in writing, software
101636 * distributed under the License is distributed on an "AS IS" BASIS,
101637 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101638 * See the License for the specific language governing permissions and
101639 * limitations under the License.
101640 * =============================================================================
101641 */
/**
 * WebGL kernel for tf.reverse.
 *
 * Normalizes the requested axes against the input shape, short-circuits
 * scalars through identity, and otherwise dispatches to the packed or
 * unpacked reverse shader per the WEBGL_PACK_ARRAY_OPERATIONS flag.
 */
function reverse(args) {
    const { inputs: { x }, backend, attrs: { dims } } = args;
    const normalizedAxes = parseAxisParam(dims, x.shape);
    // A scalar has nothing to reverse.
    if (x.shape.length === 0) {
        return identity({ inputs: { x }, backend });
    }
    let program;
    if (env().getBool('WEBGL_PACK_ARRAY_OPERATIONS')) {
        program = new ReversePackedProgram(x.shape, normalizedAxes);
    }
    else {
        program = new ReverseProgram(x.shape, normalizedAxes);
    }
    return backend.runWebGLProgram(program, [x], x.dtype);
}
// Kernel-registry entry binding the Reverse kernel name to the WebGL
// implementation above.
const reverseConfig = {
    kernelName: Reverse,
    backendName: 'webgl',
    kernelFunc: reverse
};
101661
101662 /**
101663 * @license
101664 * Copyright 2020 Google LLC. All Rights Reserved.
101665 * Licensed under the Apache License, Version 2.0 (the "License");
101666 * you may not use this file except in compliance with the License.
101667 * You may obtain a copy of the License at
101668 *
101669 * http://www.apache.org/licenses/LICENSE-2.0
101670 *
101671 * Unless required by applicable law or agreed to in writing, software
101672 * distributed under the License is distributed on an "AS IS" BASIS,
101673 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101674 * See the License for the specific language governing permissions and
101675 * limitations under the License.
101676 * =============================================================================
101677 */
/**
 * GLSL program rotating a batch of images about a center point.
 *
 * The rotation parameters arrive at run time through the `params` vec4
 * uniform: (centerX, centerY, sin(angle), cos(angle)). Pixels whose source
 * coordinate falls outside the image are filled with `fillValue`, which is
 * either a single number or a per-channel [r, g, b] array.
 */
class RotateProgram {
    constructor(imageShape, fillValue) {
        this.variableNames = ['Image'];
        this.outputShape = [];
        this.customUniforms = [{ name: 'params', type: 'vec4' }];
        const [, imageHeight, imageWidth] = imageShape;
        this.outputShape = imageShape;
        // Scalar fill bakes a constant; array fill indexes by channel.
        const fillSnippet = (typeof fillValue === 'number') ?
            `float outputValue = ${fillValue.toFixed(2)};` :
            `
        vec3 fill = vec3(${fillValue.join(',')});
        float outputValue = fill[coords[3]];`;
        this.userCode = `
        void main() {
          ivec4 coords = getOutputCoords();
          int x = coords[2];
          int y = coords[1];
          float coordXFloat = (float(x) - params[0]) * params[3] -
            (float(y) - params[1]) * params[2];
          float coordYFloat = (float(x) - params[0]) * params[2] +
            (float(y) - params[1]) * params[3];
          int coordX = int(round(coordXFloat + params[0]));
          int coordY = int(round(coordYFloat + params[1]));
          ${fillSnippet}
          if(coordX >= 0 && coordX < ${imageWidth} && coordY >= 0 && coordY < ${imageHeight}) {
            outputValue = getImage(coords[0], coordY, coordX, coords[3]);
          }
          setOutput(outputValue);
        }
    `;
    }
}
101715
101716 /**
101717 * @license
101718 * Copyright 2020 Google LLC. All Rights Reserved.
101719 * Licensed under the Apache License, Version 2.0 (the "License");
101720 * you may not use this file except in compliance with the License.
101721 * You may obtain a copy of the License at
101722 *
101723 * http://www.apache.org/licenses/LICENSE-2.0
101724 *
101725 * Unless required by applicable law or agreed to in writing, software
101726 * distributed under the License is distributed on an "AS IS" BASIS,
101727 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101728 * See the License for the specific language governing permissions and
101729 * limitations under the License.
101730 * =============================================================================
101731 */
// Kernel-registry entry for RotateWithOffset on the WebGL backend. The
// kernelFunc is defined inline: it builds a RotateProgram and feeds the
// rotation parameters (center point plus sin/cos of the angle) to the
// shader's `params` vec4 uniform via customValues.
const rotateWithOffsetConfig = {
    kernelName: RotateWithOffset,
    backendName: 'webgl',
    kernelFunc: ({ inputs, attrs, backend }) => {
        const { image } = inputs;
        const { radians, fillValue, center } = attrs;
        const webglBackend = backend;
        const program = new RotateProgram(image.shape, fillValue);
        // Resolve `center` to pixel coordinates against height/width
        // (image.shape[1] / image.shape[2]).
        const [centerX, centerY] = getImageCenter(center, image.shape[1], image.shape[2]);
        // Matches the shader's params layout: (cx, cy, sin, cos).
        const customValues = [[centerX, centerY, Math.sin(radians), Math.cos(radians)]];
        const output = webglBackend.runWebGLProgram(program, [image], image.dtype, customValues);
        return output;
    }
};
101746
101747 /**
101748 * @license
101749 * Copyright 2020 Google LLC. All Rights Reserved.
101750 * Licensed under the Apache License, Version 2.0 (the "License");
101751 * you may not use this file except in compliance with the License.
101752 * You may obtain a copy of the License at
101753 *
101754 * http://www.apache.org/licenses/LICENSE-2.0
101755 *
101756 * Unless required by applicable law or agreed to in writing, software
101757 * distributed under the License is distributed on an "AS IS" BASIS,
101758 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101759 * See the License for the specific language governing permissions and
101760 * limitations under the License.
101761 * =============================================================================
101762 */
// GLSL snippet implementing round-half-to-even (banker's rounding) on a
// scalar `x`. A custom snippet is needed because, per the comment baked into
// the shader, this GL dialect lacks a built-in round().
const ROUND = `
  // OpenGL ES does not support round function.
  // The algorithm is based on banker's rounding.
  float base = floor(x);
  if ((x - base) < 0.5) {
    return floor(x);
  } else if ((x - base) > 0.5) {
    return ceil(x);
  } else {
    if (mod(base, 2.0) == 0.0) {
      return base;
    } else {
      return base + 1.0;
    }
  }
`;
// Elementwise kernel built from the snippet above (no CPU fallback given).
const round = unaryKernelFunc({ opSnippet: ROUND });
// Kernel-registry entry binding the Round kernel name to the WebGL kernel.
const roundConfig = {
    kernelName: Round,
    backendName: 'webgl',
    kernelFunc: round,
};
101785
101786 /**
101787 * @license
101788 * Copyright 2020 Google LLC. All Rights Reserved.
101789 * Licensed under the Apache License, Version 2.0 (the "License");
101790 * you may not use this file except in compliance with the License.
101791 * You may obtain a copy of the License at
101792 *
101793 * http://www.apache.org/licenses/LICENSE-2.0
101794 *
101795 * Unless required by applicable law or agreed to in writing, software
101796 * distributed under the License is distributed on an "AS IS" BASIS,
101797 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101798 * See the License for the specific language governing permissions and
101799 * limitations under the License.
101800 * =============================================================================
101801 */
// GLSL snippet: reciprocal square root via the built-in inversesqrt().
const RSQRT = `return inversesqrt(x);`;
// Elementwise kernel with a CPU implementation fallback (rsqrtImplCPU).
const rsqrt = unaryKernelFunc({ opSnippet: RSQRT, cpuKernelImpl: rsqrtImplCPU });
// Kernel-registry entry binding the Rsqrt kernel name to the WebGL kernel.
const rsqrtConfig = {
    kernelName: Rsqrt,
    backendName: 'webgl',
    kernelFunc: rsqrt
};
101809
101810 /**
101811 * @license
101812 * Copyright 2018 Google LLC. All Rights Reserved.
101813 * Licensed under the Apache License, Version 2.0 (the "License");
101814 * you may not use this file except in compliance with the License.
101815 * You may obtain a copy of the License at
101816 *
101817 * http://www.apache.org/licenses/LICENSE-2.0
101818 *
101819 * Unless required by applicable law or agreed to in writing, software
101820 * distributed under the License is distributed on an "AS IS" BASIS,
101821 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101822 * See the License for the specific language governing permissions and
101823 * limitations under the License.
101824 * =============================================================================
101825 */
/**
 * Unpacked GLSL program for scatter-style ops over a flattened output.
 *
 * For every output row (coords[0]) the shader linearly scans all
 * `updateSize` updates, flattens each update's index via `strides`, and sums
 * the updates whose flattened index matches the row. Rows with no match
 * emit the default value (scalar, or per-position when defaultIsTensor).
 *
 * NOTE(review): `summingDupeIndex` is accepted but never read in this body —
 * duplicates always sum here; presumably callers rely on that. Confirm
 * against the call sites before changing.
 */
class ScatterProgram {
    constructor(updateSize, sliceDim, indicesRank, updatesRank, strides, shape, summingDupeIndex = true, defaultIsTensor = false) {
        this.variableNames = ['updates', 'indices', 'defaultValue'];
        this.outputShape = shape;
        const stridesType = getCoordsDataType(strides.length);
        const dtype = getCoordsDataType(shape.length);
        // Rank 1 indices are a flat list; rank 2 adds the slice dimension j.
        let indicesString = '';
        if (indicesRank === 1) {
            indicesString = 'i';
        }
        else if (indicesRank === 2) {
            indicesString = 'i, j';
        }
        const indicesSnippet = `getIndices(${indicesString})`;
        let updatesString = '';
        if (updatesRank === 1) {
            updatesString = 'i';
        }
        else if (updatesRank === 2) {
            updatesString = 'i, coords[1]';
        }
        const updatesSnippet = `getUpdates(${updatesString})`;
        // Tensor defaults are looked up per output position; scalar defaults
        // take no arguments.
        let defaultValuesString = '';
        if (defaultIsTensor) {
            defaultValuesString = 'coords[0], coords[1]';
        }
        const defaultValueSnippet = `getDefaultValue(${defaultValuesString})`;
        // With sliceDim 1 `strides` is a scalar in GLSL and cannot be indexed.
        const strideString = sliceDim > 1 ? 'strides[j]' : 'strides';
        this.userCode = `
        ${stridesType} strides = ${stridesType}(${strides});

        void main() {
          ${dtype} coords = getOutputCoords();
          float sum = 0.0;
          bool found = false;
          for (int i = 0; i < ${updateSize}; i++) {
            int flattenedIndex = 0;
            for (int j = 0; j < ${sliceDim}; j++) {
              int index = round(${indicesSnippet});
              flattenedIndex += index * ${strideString};
            }
            if (flattenedIndex == coords[0]) {
              sum += ${updatesSnippet};
              found = true;
            }
          }
          setOutput(mix(${defaultValueSnippet}, sum, float(found)));
        }
      `;
    }
}
101877
101878 /**
101879 * @license
101880 * Copyright 2023 Google LLC.
101881 * Licensed under the Apache License, Version 2.0 (the "License");
101882 * you may not use this file except in compliance with the License.
101883 * You may obtain a copy of the License at
101884 *
101885 * http://www.apache.org/licenses/LICENSE-2.0
101886 *
101887 * Unless required by applicable law or agreed to in writing, software
101888 * distributed under the License is distributed on an "AS IS" BASIS,
101889 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101890 * See the License for the specific language governing permissions and
101891 * limitations under the License.
101892 * =============================================================================
101893 */
/**
 * Packed (rgba-texel) GLSL variant of ScatterProgram.
 *
 * Each output texel covers output rows coords[0] and coords[0] + 1 (lanes
 * xy and zw). Updates and slice dimensions are consumed two at a time
 * (i += 2, j += 2), with index.xz / index.yw carrying the pair of packed
 * index components; rows with no matching update fall back to the default
 * value via mix() on the per-lane `found` mask.
 *
 * NOTE(review): as in the unpacked version, `summingDupeIndex` is accepted
 * but never read here — duplicates always sum. Confirm against call sites.
 */
class ScatterPackedProgram {
    constructor(updateSize, sliceDim, indicesRank, updatesRank, strides, shape, summingDupeIndex = true, defaultIsTensor = false) {
        this.variableNames = ['updates', 'indices', 'defaultValue'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = shape;
        const stridesType = getCoordsDataType(strides.length);
        const dtype = getCoordsDataType(shape.length);
        let indicesString = '';
        if (indicesRank === 1) {
            indicesString = 'i';
        }
        else if (indicesRank === 2) {
            indicesString = 'i, j';
        }
        const indicesSnippet = `getIndices(${indicesString})`;
        let updatesString = '';
        if (updatesRank === 1) {
            updatesString = 'i';
        }
        else if (updatesRank === 2) {
            updatesString = 'i, coords[1]';
        }
        const updatesSnippet = `getUpdates(${updatesString})`;
        let defaultValuesString = '';
        if (defaultIsTensor) {
            defaultValuesString = 'coords[0], coords[1]';
        }
        const defaultValueSnippet = `getDefaultValue(${defaultValuesString})`;
        // With sliceDim 1 `strides` is a scalar in GLSL and cannot be indexed;
        // strideString2 addresses the second packed slice component (j + 1).
        const strideString = sliceDim > 1 ? 'strides[j]' : 'strides';
        const strideString2 = sliceDim > 1 ? 'strides[j + 1]' : 'strides';
        this.userCode = `
        ${stridesType} strides = ${stridesType}(${strides});

        void main() {
          ${dtype} coords = getOutputCoords();
          vec4 sum = vec4(0.);
          vec4 found = vec4(0.);
          for (int i = 0; i < ${updateSize}; i+=2) {
            ivec2 flattenedIndex = ivec2(0);
            for (int j = 0; j < ${sliceDim}; j+=2) {
              ivec4 index = round(${indicesSnippet});
              flattenedIndex += index.xz * ${strideString};
              if (j + 1 < ${sliceDim}) {
                flattenedIndex += index.yw * ${strideString2};
              }
            }
            if (flattenedIndex[0] == coords[0] || flattenedIndex[1] == coords[0] ||
                flattenedIndex[0] == coords[0] + 1 || flattenedIndex[1] == coords[0] + 1) {
              vec4 updVals = ${updatesSnippet};
              if (flattenedIndex[0] == coords[0]) {
                sum.xy += updVals.xy;
                found.xy = vec2(1.);
              } else if (flattenedIndex[0] == coords[0] + 1) {
                sum.zw += updVals.xy;
                found.zw = vec2(1.);
              }
              if (flattenedIndex[1] == coords[0]) {
                sum.xy += updVals.zw;
                found.xy = vec2(1.);
              } else if (flattenedIndex[1] == coords[0] + 1) {
                sum.zw += updVals.zw;
                found.zw = vec2(1.);
              }
            }
          }
          setOutput(mix(${defaultValueSnippet}, sum, found));
        }
      `;
    }
}
101965
101966 /**
101967 * @license
101968 * Copyright 2020 Google LLC. All Rights Reserved.
101969 * Licensed under the Apache License, Version 2.0 (the "License");
101970 * you may not use this file except in compliance with the License.
101971 * You may obtain a copy of the License at
101972 *
101973 * http://www.apache.org/licenses/LICENSE-2.0
101974 *
101975 * Unless required by applicable law or agreed to in writing, software
101976 * distributed under the License is distributed on an "AS IS" BASIS,
101977 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
101978 * See the License for the specific language governing permissions and
101979 * limitations under the License.
101980 * =============================================================================
101981 */
/**
 * WebGL kernel for ScatterNd: writes `updates` into a tensor of the given
 * `shape` at the positions named by `indices`, leaving unwritten positions
 * at the scalar default (zero).
 */
function scatterNd(args) {
    const { inputs, backend, attrs } = args;
    const { indices, updates } = inputs;
    const { shape } = attrs;
    // View the problem as numUpdates writes of contiguous sliceSize rows.
    const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(updates, indices, shape);
    const flattenShape = [outputSize / sliceSize, sliceSize];
    // Degenerate output: nothing to scatter into.
    if (outputSize === 0) {
        return backend.makeTensorInfo(shape, indices.dtype);
    }
    const flattenIndices = reshape({ inputs: { x: indices }, backend, attrs: { shape: [numUpdates, sliceRank] } });
    const flattenX = reshape({ inputs: { x: updates }, backend, attrs: { shape: [numUpdates, sliceSize] } });
    // Positions never touched by an update read this scalar zero.
    const defaultValue = backend.makeTensorInfo([], 'float32', new Float32Array([0])); // scalar(0)
    const program = env().getBool('WEBGL_PACK') ?
        new ScatterPackedProgram(numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape) :
        new ScatterProgram(numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape);
    const res = backend.runWebGLProgram(program, [flattenX, flattenIndices, defaultValue], flattenX.dtype);
    const reshaped = reshape({ inputs: { x: res }, backend, attrs: { shape } });
    for (const intermediate of [flattenIndices, flattenX, res, defaultValue]) {
        backend.disposeIntermediateTensorInfo(intermediate);
    }
    return reshaped;
}
const scatterNdConfig = {
    kernelName: ScatterNd,
    backendName: 'webgl',
    kernelFunc: scatterNd
};
102014
102015 /**
102016 * @license
102017 * Copyright 2022 Google LLC. All Rights Reserved.
102018 * Licensed under the Apache License, Version 2.0 (the "License");
102019 * you may not use this file except in compliance with the License.
102020 * You may obtain a copy of the License at
102021 *
102022 * http://www.apache.org/licenses/LICENSE-2.0
102023 *
102024 * Unless required by applicable law or agreed to in writing, software
102025 * distributed under the License is distributed on an "AS IS" BASIS,
102026 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102027 * See the License for the specific language governing permissions and
102028 * limitations under the License.
102029 * =============================================================================
102030 */
/**
 * WebGL program computing searchsorted: for every value in `values`, a
 * binary search over the corresponding row of `sortedSequence` returns the
 * insertion index that keeps the row sorted.
 *
 * @param batchSize Number of rows (batches).
 * @param numInputs Length of each sorted row; only used here to bound the
 *     WebGL1 loop (the live value arrives via the `numInputs` uniform).
 * @param numValues Number of search values per row.
 * @param side 'left' inserts before equal elements, otherwise after.
 */
class SearchSortedProgram {
    constructor(batchSize, numInputs, numValues, side) {
        this.variableNames = ['sortedSequence', 'values'];
        this.customUniforms = [{ name: 'numInputs', type: 'int' }];
        this.outputShape = [batchSize, numValues];
        const webGL2LoopHead = 'while (left < right) {';
        // WebGL1 doesn't accept non constant loop conditions, so upper bound loop
        // iterations.
        // ceil(log2(numInputs + 1)) iterations suffice for a binary search
        // over numInputs elements; the body still breaks early.
        const webGL1LoopHead = `for (int i = 0; i < ${Math.ceil(Math.log2(numInputs + 1))}; ++i) { if (left >= right) break;`;
        const loopHead = env().getNumber('WEBGL_VERSION') === 2 ? webGL2LoopHead :
            webGL1LoopHead;
        // left corresponds to lower bound and right to upper bound.
        const boundComparator = side === 'left' ? '<' : '<=';
        this.userCode = `
       int findBound(int batch, float value) {
         int left = 0;
         int right = numInputs;
         int mid;
         ${loopHead}
           mid = (left + right) / 2;
           if (getSortedSequence(batch, mid) ${boundComparator} value) {
             left = mid + 1;
           } else {
             right = mid;
           }
         }
         return right;
       }

       void main() {
         ivec2 coords = getOutputCoords();
         int batch = coords[0];
         int valueIndex = coords[1];

         float value = getValues(batch, valueIndex);

         setOutput(float(findBound(batch, value)));
       }
     `;
    }
}
102072
102073 /**
102074 * @license
102075 * Copyright 2022 Google LLC. All Rights Reserved.
102076 * Licensed under the Apache License, Version 2.0 (the "License");
102077 * you may not use this file except in compliance with the License.
102078 * You may obtain a copy of the License at
102079 *
102080 * http://www.apache.org/licenses/LICENSE-2.0
102081 *
102082 * Unless required by applicable law or agreed to in writing, software
102083 * distributed under the License is distributed on an "AS IS" BASIS,
102084 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102085 * See the License for the specific language governing permissions and
102086 * limitations under the License.
102087 * =============================================================================
102088 */
/**
 * WebGL kernel for SearchSorted: per-row binary search of `values` into
 * `sortedSequence`, returning int32 insertion indices.
 */
function searchSorted(args) {
    const { inputs, backend, attrs } = args;
    const { sortedSequence, values } = inputs;
    const { side } = attrs;
    const [batchSize, numInputs] = sortedSequence.shape;
    const numValues = values.shape[1];
    const program = new SearchSortedProgram(batchSize, numInputs, numValues, side);
    // The row length is fed to the shader as the `numInputs` uniform so the
    // binary search knows its upper bound at runtime.
    return backend.runWebGLProgram(program, [sortedSequence, values], 'int32', [[numInputs]]);
}
const searchSortedConfig = {
    kernelName: SearchSorted,
    backendName: 'webgl',
    kernelFunc: searchSorted,
};
102102
102103 /**
102104 * @license
102105 * Copyright 2017 Google LLC. All Rights Reserved.
102106 * Licensed under the Apache License, Version 2.0 (the "License");
102107 * you may not use this file except in compliance with the License.
102108 * You may obtain a copy of the License at
102109 *
102110 * http://www.apache.org/licenses/LICENSE-2.0
102111 *
102112 * Unless required by applicable law or agreed to in writing, software
102113 * distributed under the License is distributed on an "AS IS" BASIS,
102114 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102115 * See the License for the specific language governing permissions and
102116 * limitations under the License.
102117 * =============================================================================
102118 */
/**
 * WebGL program for select (tf.where): per element, emits `a` where the
 * condition `c` is >= 1.0 and `b` otherwise.
 *
 * The condition may have lower rank than a/b; in that case only the first
 * `cRank` output coordinates are used to sample `c` (rank-wise broadcast).
 *
 * @param cRank Rank of the condition tensor.
 * @param shape Output shape (same as a's and b's shape).
 * @param rank Output rank; ranks above 4 are rejected.
 */
class SelectProgram {
    constructor(cRank, shape, rank) {
        this.variableNames = ['c', 'a', 'b'];
        this.outputShape = shape;
        let cCoords;
        let abCoords;
        if (rank > 4) {
            throw Error(`Where for rank ${rank} is not yet supported`);
        }
        if (rank === 1) {
            // 1-D coords are a plain int, usable directly for all samplers.
            abCoords = `resRC`;
            cCoords = `resRC`;
        }
        else {
            // Build comma-separated coordinate lists; c only consumes the
            // leading cRank coordinates.
            const currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
            const cCoordVars = [];
            const abCoordVars = [];
            for (let i = 0; i < shape.length; i++) {
                abCoordVars.push(`${currentCoords[i]}`);
                if (i < cRank) {
                    cCoordVars.push(`${currentCoords[i]}`);
                }
            }
            cCoords = cCoordVars.join();
            abCoords = abCoordVars.join();
        }
        const dtype = getCoordsDataType(rank);
        this.userCode = `
      void main() {
        ${dtype} resRC = getOutputCoords();
        float cVal = getC(${cCoords});
        if (cVal >= 1.0) {
          setOutput(getA(${abCoords}));
        } else {
          setOutput(getB(${abCoords}));
        }
      }
    `;
    }
}
102159
102160 /**
102161 * @license
102162 * Copyright 2020 Google LLC. All Rights Reserved.
102163 * Licensed under the Apache License, Version 2.0 (the "License");
102164 * you may not use this file except in compliance with the License.
102165 * You may obtain a copy of the License at
102166 *
102167 * http://www.apache.org/licenses/LICENSE-2.0
102168 *
102169 * Unless required by applicable law or agreed to in writing, software
102170 * distributed under the License is distributed on an "AS IS" BASIS,
102171 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102172 * See the License for the specific language governing permissions and
102173 * limitations under the License.
102174 * =============================================================================
102175 */
/**
 * WebGL kernel for Select (tf.where): elementwise choice between `t` and `e`
 * driven by `condition`, with the result dtype upcast from both branches.
 */
function select(args) {
    const { inputs, backend } = args;
    const { condition, t, e } = inputs;
    const conditionRank = condition.shape.length;
    const outputRank = t.shape.length;
    const program = new SelectProgram(conditionRank, t.shape, outputRank);
    const outputDtype = upcastType(t.dtype, e.dtype);
    return backend.runWebGLProgram(program, [condition, t, e], outputDtype);
}
const selectConfig = {
    kernelName: Select,
    backendName: 'webgl',
    kernelFunc: select
};
102187
102188 /**
102189 * @license
102190 * Copyright 2020 Google LLC. All Rights Reserved.
102191 * Licensed under the Apache License, Version 2.0 (the "License");
102192 * you may not use this file except in compliance with the License.
102193 * You may obtain a copy of the License at
102194 *
102195 * http://www.apache.org/licenses/LICENSE-2.0
102196 *
102197 * Unless required by applicable law or agreed to in writing, software
102198 * distributed under the License is distributed on an "AS IS" BASIS,
102199 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102200 * See the License for the specific language governing permissions and
102201 * limitations under the License.
102202 * =============================================================================
102203 */
// GLSL snippet for the Selu kernel: scale * x for x >= 0, otherwise
// scaleAlpha * (exp(x) - 1). The scale constants are interpolated from
// SELU_SCALE / SELU_SCALEALPHA defined elsewhere in this bundle.
const SELU = `
  // Stable and Attracting Fixed Point (0, 1) for Normalized Weights.
  // see: https://arxiv.org/abs/1706.02515
  float scaleAlpha = ${SELU_SCALEALPHA};
  float scale = ${SELU_SCALE};
  return (x >= 0.0) ? scale * x : scaleAlpha * (exp(x) - 1.0);
`;
const selu = unaryKernelFunc({ opSnippet: SELU });
// Kernel registration record for the WebGL backend.
const seluConfig = {
    kernelName: Selu$1,
    backendName: 'webgl',
    kernelFunc: selu,
};
102217
102218 /**
102219 * @license
102220 * Copyright 2020 Google LLC. All Rights Reserved.
102221 * Licensed under the Apache License, Version 2.0 (the "License");
102222 * you may not use this file except in compliance with the License.
102223 * You may obtain a copy of the License at
102224 *
102225 * http://www.apache.org/licenses/LICENSE-2.0
102226 *
102227 * Unless required by applicable law or agreed to in writing, software
102228 * distributed under the License is distributed on an "AS IS" BASIS,
102229 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102230 * See the License for the specific language governing permissions and
102231 * limitations under the License.
102232 * =============================================================================
102233 */
// GLSL snippet for the Sigmoid kernel: 1 / (1 + e^-x), with the shared
// NaN-check prologue prepended for the unpacked path.
const SIGMOID = CHECK_NAN_SNIPPET_UNARY + `
  return 1.0 / (1.0 + exp(-1.0 * x));
`;
// Packed (vec4) variant: computes all four lanes, then restores NaN inputs
// lane-by-lane so NaN propagates through the op.
const SIGMOID_PACKED = `
  vec4 result = 1.0 / (1.0 + exp(-1.0 * x));
  bvec4 isNaN = isnan(x);

  result.r = isNaN.r ? x.r : result.r;
  result.g = isNaN.g ? x.g : result.g;
  result.b = isNaN.b ? x.b : result.b;
  result.a = isNaN.a ? x.a : result.a;

  return result;
`;
// A CPU implementation is supplied so small tensors can skip the GPU.
const sigmoid = unaryKernelFunc({
    opSnippet: SIGMOID,
    packedOpSnippet: SIGMOID_PACKED,
    cpuKernelImpl: sigmoidImplCPU
});
// Kernel registration record for the WebGL backend.
const sigmoidConfig = {
    kernelName: Sigmoid$1,
    backendName: 'webgl',
    kernelFunc: sigmoid,
};
102258
102259 /**
102260 * @license
102261 * Copyright 2020 Google LLC. All Rights Reserved.
102262 * Licensed under the Apache License, Version 2.0 (the "License");
102263 * you may not use this file except in compliance with the License.
102264 * You may obtain a copy of the License at
102265 *
102266 * http://www.apache.org/licenses/LICENSE-2.0
102267 *
102268 * Unless required by applicable law or agreed to in writing, software
102269 * distributed under the License is distributed on an "AS IS" BASIS,
102270 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102271 * See the License for the specific language governing permissions and
102272 * limitations under the License.
102273 * =============================================================================
102274 */
// Sign does not propagate NANs.
// GLSL snippet for the Sign kernel: NaN inputs map to 0, everything else to
// GLSL sign(x) (-1, 0, or 1).
const SIGN = `
  if (isnan(x)) { return 0.0; }
  return sign(x);
`;
const sign = unaryKernelFunc({ opSnippet: SIGN });
// Kernel registration record for the WebGL backend.
const signConfig = {
    kernelName: Sign,
    backendName: 'webgl',
    kernelFunc: sign,
};
102286
102287 /**
102288 * @license
102289 * Copyright 2020 Google LLC. All Rights Reserved.
102290 * Licensed under the Apache License, Version 2.0 (the "License");
102291 * you may not use this file except in compliance with the License.
102292 * You may obtain a copy of the License at
102293 *
102294 * http://www.apache.org/licenses/LICENSE-2.0
102295 *
102296 * Unless required by applicable law or agreed to in writing, software
102297 * distributed under the License is distributed on an "AS IS" BASIS,
102298 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102299 * See the License for the specific language governing permissions and
102300 * limitations under the License.
102301 * =============================================================================
102302 */
// GLSL snippet for the Sin kernel (unpacked), with the shared NaN-check
// prologue prepended.
const SIN = CHECK_NAN_SNIPPET_UNARY + `
  return sin(x);
`;
// Packed (vec4) variant: NaN lanes are restored by the shared packed
// NaN-check snippet, which reads `result` and `isNaN`.
const SIN_PACKED = `
  vec4 result = sin(x);
  bvec4 isNaN = isnan(x);
  ${CHECK_NAN_SNIPPET_PACKED}
  return result;
`;
const sin = unaryKernelFunc({ opSnippet: SIN, packedOpSnippet: SIN_PACKED });
// Kernel registration record for the WebGL backend.
const sinConfig = {
    kernelName: Sin,
    backendName: 'webgl',
    kernelFunc: sin,
};
102318
102319 /**
102320 * @license
102321 * Copyright 2020 Google LLC. All Rights Reserved.
102322 * Licensed under the Apache License, Version 2.0 (the "License");
102323 * you may not use this file except in compliance with the License.
102324 * You may obtain a copy of the License at
102325 *
102326 * http://www.apache.org/licenses/LICENSE-2.0
102327 *
102328 * Unless required by applicable law or agreed to in writing, software
102329 * distributed under the License is distributed on an "AS IS" BASIS,
102330 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102331 * See the License for the specific language governing permissions and
102332 * limitations under the License.
102333 * =============================================================================
102334 */
// GLSL snippet for the Sinh kernel. Note the local `e2x` actually holds
// exp(x); the expression computes (exp(x) - exp(-x)) / 2, i.e. sinh(x).
const SINH = `
  float e2x = exp(x);
  return (e2x - 1.0 / e2x) / 2.0;
`;
const sinh = unaryKernelFunc({ opSnippet: SINH });
// Kernel registration record for the WebGL backend.
const sinhConfig = {
    kernelName: Sinh,
    backendName: 'webgl',
    kernelFunc: sinh,
};
102345
102346 /**
102347 * @license
102348 * Copyright 2020 Google LLC. All Rights Reserved.
102349 * Licensed under the Apache License, Version 2.0 (the "License");
102350 * you may not use this file except in compliance with the License.
102351 * You may obtain a copy of the License at
102352 *
102353 * http://www.apache.org/licenses/LICENSE-2.0
102354 *
102355 * Unless required by applicable law or agreed to in writing, software
102356 * distributed under the License is distributed on an "AS IS" BASIS,
102357 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102358 * See the License for the specific language governing permissions and
102359 * limitations under the License.
102360 * =============================================================================
102361 */
// GLSL snippet for the Softplus kernel: log(1 + e^x), computed piecewise
// for numerical stability. For large x, log(1 + e^x) ~= x; for very
// negative x, it ~= e^x; only the middle range evaluates the full formula.
const SOFTPLUS = `
  float epsilon = 1.1920928955078125e-7;
  float threshold = log(epsilon) + 2.0;

  bool too_large = x > -threshold;
  bool too_small = x < threshold;

  float result;
  float exp_x = exp(x);

  if (too_large){
    result = x;
  }
  else if (too_small){
    result = exp_x;
  }
  else{
    result = log(exp_x + 1.0);
  }
  return result;
`;
const softplus = unaryKernelFunc({ opSnippet: SOFTPLUS });
// Kernel registration record for the WebGL backend.
const softplusConfig = {
    kernelName: Softplus$1,
    backendName: 'webgl',
    kernelFunc: softplus,
};
102389
102390 /**
102391 * @license
102392 * Copyright 2020 Google LLC. All Rights Reserved.
102393 * Licensed under the Apache License, Version 2.0 (the "License");
102394 * you may not use this file except in compliance with the License.
102395 * You may obtain a copy of the License at
102396 *
102397 * http://www.apache.org/licenses/LICENSE-2.0
102398 *
102399 * Unless required by applicable law or agreed to in writing, software
102400 * distributed under the License is distributed on an "AS IS" BASIS,
102401 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102402 * See the License for the specific language governing permissions and
102403 * limitations under the License.
102404 * =============================================================================
102405 */
/**
 * WebGL kernel for SpaceToBatchND, built from existing kernels:
 * pad -> reshape -> transpose -> reshape. Only ranks <= 4 are supported.
 */
const spaceToBatchND = (args) => {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { blockShape, paddings } = attrs;
    assert$1(x.shape.length <= 4, () => 'spaceToBatchND for rank > 4 with a WebGL backend not ' +
        'implemented yet');
    const blockProduct = blockShape.reduce((a, b) => a * b);
    // Zero-pad the batch dim, apply the caller's spatial paddings, and
    // zero-pad any trailing dims beyond the block dims.
    const allPaddings = [[0, 0], ...paddings];
    for (let dim = 1 + blockShape.length; dim < x.shape.length; ++dim) {
        allPaddings.push([0, 0]);
    }
    const xPadded = padV2({
        inputs: { x },
        backend,
        attrs: { paddings: allPaddings, constantValue: 0 }
    });
    const reshapedShape = getReshaped(xPadded.shape, blockShape, blockProduct, false);
    const perm = getPermuted(reshapedShape.length, blockShape.length, false);
    const flatShape = getReshapedPermuted(xPadded.shape, blockShape, blockProduct, false);
    const xReshaped = reshape({ inputs: { x: xPadded }, backend, attrs: { shape: reshapedShape } });
    const xTransposed = transpose({
        inputs: { x: xReshaped },
        backend,
        attrs: { perm }
    });
    const result = reshape({ inputs: { x: xTransposed }, backend, attrs: { shape: flatShape } });
    // Free every intermediate produced along the way.
    for (const intermediate of [xPadded, xReshaped, xTransposed]) {
        backend.disposeIntermediateTensorInfo(intermediate);
    }
    return result;
};
const spaceToBatchNDConfig = {
    kernelName: SpaceToBatchND,
    backendName: 'webgl',
    kernelFunc: spaceToBatchND
};
102445
102446 /**
102447 * @license
102448 * Copyright 2021 Google LLC. All Rights Reserved.
102449 * Licensed under the Apache License, Version 2.0 (the "License");
102450 * you may not use this file except in compliance with the License.
102451 * You may obtain a copy of the License at
102452 *
102453 * http://www.apache.org/licenses/LICENSE-2.0
102454 *
102455 * Unless required by applicable law or agreed to in writing, software
102456 * distributed under the License is distributed on an "AS IS" BASIS,
102457 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102458 * See the License for the specific language governing permissions and
102459 * limitations under the License.
102460 * =============================================================================
102461 */
/**
 * WebGL kernel for SparseFillEmptyRows. There is no GPU implementation of
 * the sparse ops; inputs are downloaded synchronously and the CPU impl does
 * the work, after shape validation.
 */
function sparseFillEmptyRows(args) {
    const { inputs, backend } = args;
    const { indices, values, denseShape, defaultValue } = inputs;
    if (denseShape.shape.length !== 1) {
        throw new Error(`Dense shape must be a vector, saw:
        ${denseShape.shape}`);
    }
    if (indices.shape.length !== 2) {
        throw new Error(`Indices must be a matrix, saw:
        ${indices.shape}`);
    }
    if (values.shape.length !== 1) {
        throw new Error(`Values must be a vector, saw:
        ${values.shape}`);
    }
    if (defaultValue.shape.length !== 0) {
        throw new Error(`Default value must be a scalar, saw:
        ${defaultValue.shape}`);
    }
    // Synchronous GPU->CPU download of every operand.
    const $indices = backend.readSync(indices.dataId);
    const $values = backend.readSync(values.dataId);
    const $denseShape = backend.readSync(denseShape.dataId);
    const $defaultValue = backend.readSync(defaultValue.dataId)[0];
    const [outputIndices, outputIndicesShape, outputValues, emptyRowIndicator, reverseIndexMap] = sparseFillEmptyRowsImplCPU($indices, indices.shape, indices.dtype, $values, values.dtype, $denseShape, $defaultValue);
    // Booleans are stored as uint8 on this backend.
    const indicatorData = new Uint8Array(emptyRowIndicator.map((flag) => Number(flag)));
    return [
        backend.makeTensorInfo(outputIndicesShape, indices.dtype, outputIndices),
        backend.makeTensorInfo([outputIndicesShape[0]], values.dtype, outputValues),
        backend.makeTensorInfo([emptyRowIndicator.length], 'bool', indicatorData),
        backend.makeTensorInfo([reverseIndexMap.length], indices.dtype, new Int32Array(reverseIndexMap)),
    ];
}
const sparseFillEmptyRowsConfig = {
    kernelName: SparseFillEmptyRows,
    backendName: 'webgl',
    kernelFunc: sparseFillEmptyRows,
};
102498
102499 /**
102500 * @license
102501 * Copyright 2021 Google LLC. All Rights Reserved.
102502 * Licensed under the Apache License, Version 2.0 (the "License");
102503 * you may not use this file except in compliance with the License.
102504 * You may obtain a copy of the License at
102505 *
102506 * http://www.apache.org/licenses/LICENSE-2.0
102507 *
102508 * Unless required by applicable law or agreed to in writing, software
102509 * distributed under the License is distributed on an "AS IS" BASIS,
102510 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102511 * See the License for the specific language governing permissions and
102512 * limitations under the License.
102513 * =============================================================================
102514 */
/**
 * WebGL kernel for SparseReshape. Sparse ops have no GPU path; operands are
 * read back synchronously and delegated to the CPU implementation.
 */
function sparseReshape(args) {
    const { inputs, backend } = args;
    const { inputIndices, inputShape, newShape } = inputs;
    if (inputIndices.shape.length !== 2) {
        throw new Error(`Input indices should be a matrix but received shape ${inputIndices.shape}`);
    }
    if (inputShape.shape.length !== 1) {
        throw new Error(`Input shape should be a vector but received shape ${inputShape.shape}`);
    }
    if (newShape.shape.length !== 1) {
        throw new Error(`Target shape should be a vector but received shape ${newShape.shape}`);
    }
    // Download operands; shapes become plain number arrays.
    const $inputIndices = backend.readSync(inputIndices.dataId);
    const $inputShape = Array.from(backend.readSync(inputShape.dataId));
    const targetShape = Array.from(backend.readSync(newShape.dataId));
    const cpuResult = sparseReshapeImplCPU($inputIndices, inputIndices.shape, inputIndices.dtype, $inputShape, targetShape);
    const [newIndices, indicesShape, outputShape] = cpuResult;
    return [
        backend.makeTensorInfo(indicesShape, inputIndices.dtype, newIndices),
        backend.makeTensorInfo([outputShape.length], newShape.dtype, new Int32Array(outputShape)),
    ];
}
const sparseReshapeConfig = {
    kernelName: SparseReshape,
    backendName: 'webgl',
    kernelFunc: sparseReshape,
};
102541
102542 /**
102543 * @license
102544 * Copyright 2021 Google LLC. All Rights Reserved.
102545 * Licensed under the Apache License, Version 2.0 (the "License");
102546 * you may not use this file except in compliance with the License.
102547 * You may obtain a copy of the License at
102548 *
102549 * http://www.apache.org/licenses/LICENSE-2.0
102550 *
102551 * Unless required by applicable law or agreed to in writing, software
102552 * distributed under the License is distributed on an "AS IS" BASIS,
102553 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102554 * See the License for the specific language governing permissions and
102555 * limitations under the License.
102556 * =============================================================================
102557 */
/**
 * WebGL kernel for SparseSegmentMean. Downloads operands synchronously and
 * runs the shared CPU segment-reduction impl with averaging enabled.
 */
function sparseSegmentMean(args) {
    const { inputs, backend } = args;
    const { data, indices, segmentIds } = inputs;
    if (data.shape.length < 1) {
        throw new Error(`Data should be at least 1 dimensional but received scalar`);
    }
    if (indices.shape.length !== 1) {
        throw new Error(`Indices should be a vector but received shape
             ${indices.shape}`);
    }
    if (segmentIds.shape.length !== 1) {
        throw new Error(`Segment ids should be a vector but received shape
             ${segmentIds.shape}`);
    }
    const $data = backend.readSync(data.dataId);
    const $indices = backend.readSync(indices.dataId);
    const $segmentIds = backend.readSync(segmentIds.dataId);
    // Final `true` selects mean (rather than sum) reduction.
    const cpuResult = sparseSegmentReductionImplCPU($data, data.shape, data.dtype, $indices, $segmentIds, true);
    const [outputData, outputDataShape] = cpuResult;
    return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
}
const sparseSegmentMeanConfig = {
    kernelName: SparseSegmentMean,
    backendName: 'webgl',
    kernelFunc: sparseSegmentMean,
};
102583
102584 /**
102585 * @license
102586 * Copyright 2021 Google LLC. All Rights Reserved.
102587 * Licensed under the Apache License, Version 2.0 (the "License");
102588 * you may not use this file except in compliance with the License.
102589 * You may obtain a copy of the License at
102590 *
102591 * http://www.apache.org/licenses/LICENSE-2.0
102592 *
102593 * Unless required by applicable law or agreed to in writing, software
102594 * distributed under the License is distributed on an "AS IS" BASIS,
102595 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102596 * See the License for the specific language governing permissions and
102597 * limitations under the License.
102598 * =============================================================================
102599 */
/**
 * WebGL kernel for SparseSegmentSum. Downloads operands synchronously and
 * runs the shared CPU segment-reduction impl (sum mode, the default).
 */
function sparseSegmentSum(args) {
    const { inputs, backend } = args;
    const { data, indices, segmentIds } = inputs;
    if (data.shape.length < 1) {
        throw new Error(`Data should be at least 1 dimensional but received scalar`);
    }
    if (indices.shape.length !== 1) {
        throw new Error(`Indices should be a vector but received shape
             ${indices.shape}`);
    }
    if (segmentIds.shape.length !== 1) {
        throw new Error(`Segment ids should be a vector but received shape
             ${segmentIds.shape}`);
    }
    const $data = backend.readSync(data.dataId);
    const $indices = backend.readSync(indices.dataId);
    const $segmentIds = backend.readSync(segmentIds.dataId);
    const cpuResult = sparseSegmentReductionImplCPU($data, data.shape, data.dtype, $indices, $segmentIds);
    const [outputData, outputDataShape] = cpuResult;
    return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
}
const sparseSegmentSumConfig = {
    kernelName: SparseSegmentSum,
    backendName: 'webgl',
    kernelFunc: sparseSegmentSum,
};
102625
102626 /**
102627 * @license
102628 * Copyright 2020 Google LLC. All Rights Reserved.
102629 * Licensed under the Apache License, Version 2.0 (the "License");
102630 * you may not use this file except in compliance with the License.
102631 * You may obtain a copy of the License at
102632 *
102633 * http://www.apache.org/licenses/LICENSE-2.0
102634 *
102635 * Unless required by applicable law or agreed to in writing, software
102636 * distributed under the License is distributed on an "AS IS" BASIS,
102637 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102638 * See the License for the specific language governing permissions and
102639 * limitations under the License.
102640 * =============================================================================
102641 */
/**
 * WebGL kernel for SparseToDense: scatters `sparseValues` into a dense
 * tensor of `outputShape`, filling unwritten cells with `defaultValue`.
 * String dtypes fall back to the CPU scatter impl since the GPU path cannot
 * hold strings.
 */
function sparseToDense(args) {
    const { inputs, backend, attrs } = args;
    const { sparseIndices, sparseValues, defaultValue } = inputs;
    const { outputShape } = attrs;
    const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(sparseValues, sparseIndices, outputShape);
    // SparseToDense overwrites rather than accumulating duplicates.
    const sumDupeIndices = false;
    if (sparseValues.dtype === 'string') {
        const indicesBuf = backend.bufferSync(sparseIndices);
        const updatesBuf = backend.bufferSync(sparseValues);
        const fillValue = decodeString(backend.readSync(defaultValue.dataId)[0]);
        const outBuf = scatterImplCPU(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, fillValue, sumDupeIndices);
        return backend.makeTensorInfo(outputShape, outBuf.dtype, outBuf.values);
    }
    const program = new ScatterProgram(numUpdates, sliceRank, sparseIndices.shape.length, sparseValues.shape.length, strides, [outputSize, 1], sumDupeIndices);
    const flatResult = backend.runWebGLProgram(program, [sparseValues, sparseIndices, defaultValue], sparseValues.dtype);
    const reshaped = reshape({ inputs: { x: flatResult }, backend, attrs: { shape: outputShape } });
    backend.disposeIntermediateTensorInfo(flatResult);
    return reshaped;
}
const sparseToDenseConfig = {
    kernelName: SparseToDense,
    backendName: 'webgl',
    kernelFunc: sparseToDense
};
102666
102667 /**
102668 * @license
102669 * Copyright 2020 Google LLC. All Rights Reserved.
102670 * Licensed under the Apache License, Version 2.0 (the "License");
102671 * you may not use this file except in compliance with the License.
102672 * You may obtain a copy of the License at
102673 *
102674 * http://www.apache.org/licenses/LICENSE-2.0
102675 *
102676 * Unless required by applicable law or agreed to in writing, software
102677 * distributed under the License is distributed on an "AS IS" BASIS,
102678 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102679 * See the License for the specific language governing permissions and
102680 * limitations under the License.
102681 * =============================================================================
102682 */
/**
 * SplitV WebGL kernel: splits `x` along `axis` into slices whose sizes are
 * given (or inferred) by `numOrSizeSplits`. Returns one TensorInfo per slice.
 */
function splitV(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { numOrSizeSplits, axis } = attrs;
    const $axis = parseAxisParam(axis, x.shape)[0];
    const splitSizes = prepareSplitSize(x, numOrSizeSplits, $axis);
    const rank = x.shape.length;
    // `begin` advances along the split axis after every slice.
    const begin = new Array(rank).fill(0);
    const fullSize = x.shape.slice();
    const outputs = [];
    for (const splitSize of splitSizes) {
        const sliceSize = fullSize.slice();
        sliceSize[$axis] = splitSize;
        outputs.push(slice({ inputs: { x }, backend, attrs: { begin, size: sliceSize } }));
        begin[$axis] += splitSize;
    }
    return outputs;
}
const splitVConfig = {
    kernelName: SplitV,
    backendName: 'webgl',
    kernelFunc: splitV
};
102705
102706 /**
102707 * @license
102708 * Copyright 2020 Google LLC. All Rights Reserved.
102709 * Licensed under the Apache License, Version 2.0 (the "License");
102710 * you may not use this file except in compliance with the License.
102711 * You may obtain a copy of the License at
102712 *
102713 * http://www.apache.org/licenses/LICENSE-2.0
102714 *
102715 * Unless required by applicable law or agreed to in writing, software
102716 * distributed under the License is distributed on an "AS IS" BASIS,
102717 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102718 * See the License for the specific language governing permissions and
102719 * limitations under the License.
102720 * =============================================================================
102721 */
// GLSL snippet for elementwise square root; the same snippet serves both the
// unpacked and packed paths, with a CPU fallback implementation.
const SQRT = `return sqrt(x);`;
const sqrt = unaryKernelFunc({ opSnippet: SQRT, packedOpSnippet: SQRT, cpuKernelImpl: sqrtImplCPU });
const sqrtConfig = {
    kernelName: Sqrt,
    backendName: 'webgl',
    kernelFunc: sqrt
};
102729
102730 /**
102731 * @license
102732 * Copyright 2019 Google LLC. All Rights Reserved.
102733 * Licensed under the Apache License, Version 2.0 (the "License");
102734 * you may not use this file except in compliance with the License.
102735 * You may obtain a copy of the License at
102736 *
102737 * http://www.apache.org/licenses/LICENSE-2.0
102738 *
102739 * Unless required by applicable law or agreed to in writing, software
102740 * distributed under the License is distributed on an "AS IS" BASIS,
102741 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102742 * See the License for the specific language governing permissions and
102743 * limitations under the License.
102744 * =============================================================================
102745 */
// GLSL snippet for elementwise square (x * x); unpacked path only.
const SQUARE = `return x * x;`;
const square = unaryKernelFunc({ opSnippet: SQUARE });
const squareConfig = {
    kernelName: Square,
    backendName: 'webgl',
    kernelFunc: square,
};
102753
102754 /**
102755 * @license
102756 * Copyright 2020 Google LLC. All Rights Reserved.
102757 * Licensed under the Apache License, Version 2.0 (the "License");
102758 * you may not use this file except in compliance with the License.
102759 * You may obtain a copy of the License at
102760 *
102761 * http://www.apache.org/licenses/LICENSE-2.0
102762 *
102763 * Unless required by applicable law or agreed to in writing, software
102764 * distributed under the License is distributed on an "AS IS" BASIS,
102765 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102766 * See the License for the specific language governing permissions and
102767 * limitations under the License.
102768 * =============================================================================
102769 */
// GLSL snippet for elementwise (a - b)^2; the expression is valid for both
// scalar (unpacked) and vec4 (packed) operands, so it is reused for both.
const SQUARED_DIFFERENCE = 'return (a - b) * (a - b);';
const squaredDifference = binaryKernelFunc({ opSnippet: SQUARED_DIFFERENCE, packedOpSnippet: SQUARED_DIFFERENCE });
const squaredDifferenceConfig = {
    kernelName: SquaredDifference,
    backendName: 'webgl',
    kernelFunc: squaredDifference,
};
102777
102778 /**
102779 * @license
102780 * Copyright 2023 Google LLC.
102781 * Licensed under the Apache License, Version 2.0 (the "License");
102782 * you may not use this file except in compliance with the License.
102783 * You may obtain a copy of the License at
102784 *
102785 * http://www.apache.org/licenses/LICENSE-2.0
102786 *
102787 * Unless required by applicable law or agreed to in writing, software
102788 * distributed under the License is distributed on an "AS IS" BASIS,
102789 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102790 * See the License for the specific language governing permissions and
102791 * limitations under the License.
102792 * =============================================================================
102793 */
/**
 * StaticRegexReplace WebGL kernel. String data always lives on the CPU, so
 * the tensor is read back, transformed by the CPU implementation, and
 * re-wrapped as a string TensorInfo of the same shape.
 * @throws Error when the input tensor is not of dtype 'string'.
 */
function staticRegexReplace(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    if (x.dtype !== 'string') {
        throw new Error('Input must be of datatype string');
    }
    const encoded = backend.readSync(x.dataId);
    const decoded = fromUint8ToStringArray(encoded);
    const replaced = staticRegexReplaceImplCPU(decoded, 'string', attrs);
    return backend.makeTensorInfo(x.shape, 'string', replaced);
}
const staticRegexReplaceConfig = {
    kernelName: StaticRegexReplace,
    backendName: 'webgl',
    kernelFunc: staticRegexReplace,
};
102810
102811 /**
102812 * @license
102813 * Copyright 2020 Google LLC. All Rights Reserved.
102814 * Licensed under the Apache License, Version 2.0 (the "License");
102815 * you may not use this file except in compliance with the License.
102816 * You may obtain a copy of the License at
102817 *
102818 * http://www.apache.org/licenses/LICENSE-2.0
102819 *
102820 * Unless required by applicable law or agreed to in writing, software
102821 * distributed under the License is distributed on an "AS IS" BASIS,
102822 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102823 * See the License for the specific language governing permissions and
102824 * limitations under the License.
102825 * =============================================================================
102826 */
/**
 * Step WebGL kernel: elementwise `x > 0 ? 1 : alpha`, with the prepended
 * CHECK_NAN_SNIPPET$1 handling NaN inputs before the comparison runs.
 */
function step({ inputs, attrs, backend }) {
    const { x } = inputs;
    // attrs.alpha is interpolated directly into the shader source, so each
    // distinct alpha value compiles (and caches) a separate program.
    const opSnippet = CHECK_NAN_SNIPPET$1 + `
    return x > 0.0 ? 1.0 : float(${attrs.alpha});
  `;
    const program = new UnaryOpProgram(x.shape, opSnippet);
    return backend.runWebGLProgram(program, [x], x.dtype);
}
const stepConfig = {
    kernelName: Step,
    backendName: 'webgl',
    kernelFunc: step,
};
102840
102841 /**
102842 * @license
102843 * Copyright 2017 Google LLC. All Rights Reserved.
102844 * Licensed under the Apache License, Version 2.0 (the "License");
102845 * you may not use this file except in compliance with the License.
102846 * You may obtain a copy of the License at
102847 *
102848 * http://www.apache.org/licenses/LICENSE-2.0
102849 *
102850 * Unless required by applicable law or agreed to in writing, software
102851 * distributed under the License is distributed on an "AS IS" BASIS,
102852 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102853 * See the License for the specific language governing permissions and
102854 * limitations under the License.
102855 * =============================================================================
102856 */
/**
 * WebGL program for StridedSlice: each output coordinate c maps to input
 * coordinate c * strides + begin (per axis).
 */
class StridedSliceProgram {
    /**
     * @param begin per-axis start offsets into the input.
     * @param strides per-axis step sizes.
     * @param size output shape (elements taken per axis).
     */
    constructor(begin, strides, size) {
        this.variableNames = ['x'];
        this.outputShape = size;
        const rank = size.length;
        // One coords type suffices: begin/strides and the output coords all
        // share the output rank. (The original declared the same type twice
        // as `inputDtype` and `dtype`.)
        const dtype = getCoordsDataType(rank);
        let newCoords = '';
        if (rank === 1) {
            newCoords = 'coords * strides + begin';
        }
        else {
            // rank > 1: translate each output component independently. The
            // original carried a dead `size.length === 1` ternary here (it can
            // never be taken inside this branch) and an `outputAxis` counter
            // that simply mirrored `i`; both are removed — the generated GLSL
            // is unchanged.
            newCoords = size
                .map((_, i) => `coords[${i}] * strides[${i}] + begin[${i}]`)
                .join(',');
        }
        this.userCode = `
      ${dtype} begin = ${dtype}(${begin});
      ${dtype} strides = ${dtype}(${strides});

      void main() {
        ${dtype} coords = getOutputCoords();
        setOutput(getX(${newCoords}));
      }
    `;
    }
}
102890
102891 /**
102892 * @license
102893 * Copyright 2020 Google LLC. All Rights Reserved.
102894 * Licensed under the Apache License, Version 2.0 (the "License");
102895 * you may not use this file except in compliance with the License.
102896 * You may obtain a copy of the License at
102897 *
102898 * http://www.apache.org/licenses/LICENSE-2.0
102899 *
102900 * Unless required by applicable law or agreed to in writing, software
102901 * distributed under the License is distributed on an "AS IS" BASIS,
102902 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102903 * See the License for the specific language governing permissions and
102904 * limitations under the License.
102905 * =============================================================================
102906 */
/**
 * StridedSlice WebGL kernel: extracts a strided slice of `x` according to
 * begin/end/strides and the bit-mask attrs, then reshapes to the final
 * (mask-adjusted) shape.
 */
function stridedSlice(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask } = attrs;
    // sliceInfo resolves the masks into dense begin/end/strides plus the
    // sparse shape (pre new/shrink axes) and the final output shape.
    const { finalShapeSparse, finalShape, isIdentity, sliceDim0, isSimpleSlice, begin: $begin, end: $end, strides: $strides } = sliceInfo(x.shape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
    let result;
    if (isIdentity) {
        // Optimization #1, slice is a no-op plus reshape
        result = reshape({ inputs: { x }, backend, attrs: { shape: finalShape } });
    }
    else if (sliceDim0 || isSimpleSlice) {
        // Optimization #2, slice is memory contiguous (only occurs in dim 0)
        assert$1(x.shape.length >= 1, () => `Input must have rank at least 1, got: ${x.shape.length}`);
        const size = computeOutShape$2($begin, $end, $strides);
        // To tolerate begin[0] > end[0] (a 0-output slice), we min(begin, end).
        const sliced = slice({ inputs: { x }, backend, attrs: { begin: $begin, size } });
        result =
            reshape({ inputs: { x: sliced }, backend, attrs: { shape: finalShape } });
        backend.disposeIntermediateTensorInfo(sliced);
    }
    else {
        const shouldExecuteOnCPU = backend.shouldExecuteOnCPU([x]);
        if (shouldExecuteOnCPU) {
            // CPU path: materialize the values and slice into a new buffer.
            // tslint:disable-next-line: no-unnecessary-type-assertion
            const values = backend.readSync(x.dataId);
            // tslint:disable-next-line: no-unnecessary-type-assertion
            const xBuf = buffer(x.shape, x.dtype, values);
            const resultValues = stridedSliceImplCPU(finalShapeSparse, xBuf, $strides, $begin);
            result = backend.makeTensorInfo(finalShape, x.dtype, resultValues.values);
        }
        else {
            const program = new StridedSliceProgram($begin, $strides, finalShapeSparse);
            result = backend.runWebGLProgram(program, [x], x.dtype);
        }
    }
    // All branches feed a final reshape to finalShape. NOTE(review): for the
    // isIdentity branch this reshape is redundant (result already has
    // finalShape) but harmless.
    const resultReshaped = reshape({ inputs: { x: result }, backend, attrs: { shape: finalShape } });
    backend.disposeIntermediateTensorInfo(result);
    return resultReshaped;
}
const stridedSliceConfig = {
    kernelName: StridedSlice,
    backendName: 'webgl',
    kernelFunc: stridedSlice
};
102951
102952 /**
102953 * @license
102954 * Copyright 2021 Google LLC. All Rights Reserved.
102955 * Licensed under the Apache License, Version 2.0 (the "License");
102956 * you may not use this file except in compliance with the License.
102957 * You may obtain a copy of the License at
102958 *
102959 * http://www.apache.org/licenses/LICENSE-2.0
102960 *
102961 * Unless required by applicable law or agreed to in writing, software
102962 * distributed under the License is distributed on an "AS IS" BASIS,
102963 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102964 * See the License for the specific language governing permissions and
102965 * limitations under the License.
102966 * =============================================================================
102967 */
/**
 * StringNGrams WebGL kernel. String work is CPU-only, so both ragged-tensor
 * components (`data` values and `dataSplits` row splits) are read back,
 * processed on the CPU, and re-wrapped.
 * @returns [nGrams values tensor, nGrams row-splits tensor].
 */
function stringNGrams(args) {
    const { inputs, backend, attrs } = args;
    const { separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences } = attrs;
    const { data, dataSplits } = inputs;
    const dataValues = backend.readSync(data.dataId);
    const splitsValues = backend.readSync(dataSplits.dataId);
    const [nGrams, nGramsSplits] = stringNGramsImplCPU(dataValues, splitsValues, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences);
    const valuesOut = backend.makeTensorInfo([nGrams.length], 'string', nGrams);
    const splitsOut = backend.makeTensorInfo(dataSplits.shape, 'int32', nGramsSplits);
    return [valuesOut, splitsOut];
}
const stringNGramsConfig = {
    kernelName: StringNGrams,
    backendName: 'webgl',
    kernelFunc: stringNGrams,
};
102985
102986 /**
102987 * @license
102988 * Copyright 2021 Google LLC. All Rights Reserved.
102989 * Licensed under the Apache License, Version 2.0 (the "License");
102990 * you may not use this file except in compliance with the License.
102991 * You may obtain a copy of the License at
102992 *
102993 * http://www.apache.org/licenses/LICENSE-2.0
102994 *
102995 * Unless required by applicable law or agreed to in writing, software
102996 * distributed under the License is distributed on an "AS IS" BASIS,
102997 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102998 * See the License for the specific language governing permissions and
102999 * limitations under the License.
103000 * =============================================================================
103001 */
/**
 * StringSplit WebGL kernel (CPU-backed). Splits each element of the 1-D
 * string `input` on the scalar `delimiter`, producing a sparse
 * (indices, values, dense_shape) triple.
 * @throws Error on non-string input, non-vector input, or non-scalar
 *     delimiter (checked in that order).
 */
function stringSplit(args) {
    const { inputs, backend, attrs } = args;
    const { skipEmpty } = attrs;
    const { input, delimiter } = inputs;
    if (input.dtype !== 'string') {
        throw new Error('Input must be of datatype string');
    }
    if (input.shape.length !== 1) {
        throw new Error(`Input must be a vector, got shape: ${input.shape}`);
    }
    if (delimiter.shape.length !== 0) {
        throw new Error(`Delimiter must be a scalar, got shape: ${delimiter.shape}`);
    }
    const inputVals = backend.readSync(input.dataId);
    const delimiterVal = backend.readSync(delimiter.dataId)[0];
    const [indices, values, shape] = stringSplitImplCPU(inputVals, delimiterVal, skipEmpty);
    const numValues = values.length;
    return [
        backend.makeTensorInfo([numValues, 2], 'int32', indices),
        backend.makeTensorInfo([numValues], 'string', values),
        backend.makeTensorInfo([2], 'int32', new Int32Array(shape)),
    ];
}
const stringSplitConfig = {
    kernelName: StringSplit,
    backendName: 'webgl',
    kernelFunc: stringSplit,
};
103030
103031 /**
103032 * @license
103033 * Copyright 2021 Google LLC. All Rights Reserved.
103034 * Licensed under the Apache License, Version 2.0 (the "License");
103035 * you may not use this file except in compliance with the License.
103036 * You may obtain a copy of the License at
103037 *
103038 * http://www.apache.org/licenses/LICENSE-2.0
103039 *
103040 * Unless required by applicable law or agreed to in writing, software
103041 * distributed under the License is distributed on an "AS IS" BASIS,
103042 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103043 * See the License for the specific language governing permissions and
103044 * limitations under the License.
103045 * =============================================================================
103046 */
/**
 * StringToHashBucketFast WebGL kernel (CPU-backed). Hashes each string
 * element into one of `numBuckets` buckets, returning int32 bucket ids of
 * the same shape as the input.
 * @throws Error on non-string input or non-positive bucket count.
 */
function stringToHashBucketFast(args) {
    const { inputs, backend, attrs } = args;
    const { numBuckets } = attrs;
    const { input } = inputs;
    if (input.dtype !== 'string') {
        throw new Error('Input must be of datatype string');
    }
    if (numBuckets <= 0) {
        throw new Error(`Number of buckets must be at least 1`);
    }
    const inputVals = backend.readSync(input.dataId);
    const bucketIds = stringToHashBucketFastImplCPU(inputVals, numBuckets);
    return backend.makeTensorInfo(input.shape, 'int32', bucketIds);
}
const stringToHashBucketFastConfig = {
    kernelName: StringToHashBucketFast,
    backendName: 'webgl',
    kernelFunc: stringToHashBucketFast,
};
103066
103067 /**
103068 * @license
103069 * Copyright 2020 Google LLC. All Rights Reserved.
103070 * Licensed under the Apache License, Version 2.0 (the "License");
103071 * you may not use this file except in compliance with the License.
103072 * You may obtain a copy of the License at
103073 *
103074 * http://www.apache.org/licenses/LICENSE-2.0
103075 *
103076 * Unless required by applicable law or agreed to in writing, software
103077 * distributed under the License is distributed on an "AS IS" BASIS,
103078 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103079 * See the License for the specific language governing permissions and
103080 * limitations under the License.
103081 * =============================================================================
103082 */
// GLSL snippet for elementwise tangent (unpacked path only).
const TAN = `return tan(x);`;
const tan = unaryKernelFunc({ opSnippet: TAN });
const tanConfig = {
    kernelName: Tan,
    backendName: 'webgl',
    kernelFunc: tan,
};
103090
103091 /**
103092 * @license
103093 * Copyright 2020 Google LLC. All Rights Reserved.
103094 * Licensed under the Apache License, Version 2.0 (the "License");
103095 * you may not use this file except in compliance with the License.
103096 * You may obtain a copy of the License at
103097 *
103098 * http://www.apache.org/licenses/LICENSE-2.0
103099 *
103100 * Unless required by applicable law or agreed to in writing, software
103101 * distributed under the License is distributed on an "AS IS" BASIS,
103102 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103103 * See the License for the specific language governing permissions and
103104 * limitations under the License.
103105 * =============================================================================
103106 */
// GLSL tanh via exp(-2|x|) so the exponential never overflows for large |x|:
// tanh(x) = sign(x) * (1 - e^(-2|x|)) / (1 + e^(-2|x|)).
const TANH = `
  float e2x = exp(-2.0 * abs(x));
  return sign(x) * (1.0 - e2x) / (1.0 + e2x);
`;
const tanh = unaryKernelFunc({ opSnippet: TANH });
const tanhConfig = {
    kernelName: Tanh$1,
    backendName: 'webgl',
    kernelFunc: tanh,
};
103117
103118 /**
103119 * @license
103120 * Copyright 2022 Google LLC. All Rights Reserved.
103121 * Licensed under the Apache License, Version 2.0 (the "License");
103122 * you may not use this file except in compliance with the License.
103123 * You may obtain a copy of the License at
103124 *
103125 * http://www.apache.org/licenses/LICENSE-2.0
103126 *
103127 * Unless required by applicable law or agreed to in writing, software
103128 * distributed under the License is distributed on an "AS IS" BASIS,
103129 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103130 * See the License for the specific language governing permissions and
103131 * limitations under the License.
103132 * =============================================================================
103133 */
/**
 * TensorScatterUpdate WebGL kernel: returns a copy of `tensor` with
 * `updates` written at the positions given by `indices`.
 */
function tensorScatterUpdate(args) {
    const { inputs, backend, attrs } = args;
    const { tensor, indices, updates } = inputs;
    // This kernel reads no attrs; the original's no-op `const {} = attrs;`
    // destructuring has been removed.
    const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(updates, indices, tensor.shape);
    const flattenShape = [outputSize / sliceSize, sliceSize];
    if (outputSize === 0) {
        // NOTE(review): the empty result is created with indices.dtype rather
        // than tensor.dtype — preserved as-is, but verify this is intended.
        return backend.makeTensorInfo(tensor.shape, indices.dtype);
    }
    // Flatten all three operands to 2-D so one scatter program handles any rank.
    const flattenIndices = reshape({ inputs: { x: indices }, backend, attrs: { shape: [numUpdates, sliceRank] } });
    const flattenX = reshape({ inputs: { x: updates }, backend, attrs: { shape: [numUpdates, sliceSize] } });
    const flattenTensor = reshape({ inputs: { x: tensor }, backend, attrs: { shape: flattenShape } });
    const program = new ScatterProgram(numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape, false, true);
    const res = backend.runWebGLProgram(program, [flattenX, flattenIndices, flattenTensor], flattenTensor.dtype);
    const reshaped = reshape({ inputs: { x: res }, backend, attrs: { shape: tensor.shape } });
    // Release every intermediate produced above.
    backend.disposeIntermediateTensorInfo(flattenIndices);
    backend.disposeIntermediateTensorInfo(flattenX);
    backend.disposeIntermediateTensorInfo(flattenTensor);
    backend.disposeIntermediateTensorInfo(res);
    return reshaped;
}
const tensorScatterUpdateConfig = {
    kernelName: TensorScatterUpdate,
    backendName: 'webgl',
    kernelFunc: tensorScatterUpdate
};
103160
103161 /**
103162 * @license
103163 * Copyright 2017 Google LLC. All Rights Reserved.
103164 * Licensed under the Apache License, Version 2.0 (the "License");
103165 * you may not use this file except in compliance with the License.
103166 * You may obtain a copy of the License at
103167 *
103168 * http://www.apache.org/licenses/LICENSE-2.0
103169 *
103170 * Unless required by applicable law or agreed to in writing, software
103171 * distributed under the License is distributed on an "AS IS" BASIS,
103172 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103173 * See the License for the specific language governing permissions and
103174 * limitations under the License.
103175 * =============================================================================
103176 */
/**
 * WebGL program for Tile: output coordinate maps back to the input via
 * modulo on every axis (see getSourceCoords). Supports rank <= 5.
 */
class TileProgram {
    constructor(aShape, reps) {
        this.variableNames = ['A'];
        // Output shape is the per-axis product of the input shape and reps.
        const outputShape = new Array(aShape.length);
        for (let i = 0; i < outputShape.length; i++) {
            outputShape[i] = aShape[i] * reps[i];
        }
        this.outputShape = outputShape;
        this.rank = outputShape.length;
        const dtype = getCoordsDataType(this.rank);
        const sourceCoords = getSourceCoords(aShape);
        this.userCode = `
      void main() {
        ${dtype} resRC = getOutputCoords();
        setOutput(getA(${sourceCoords}));
      }
    `;
    }
}
/**
 * Builds the GLSL expression list that maps an output coordinate `resRC`
 * back into the tiled input: one `imod(component, dim)` term per axis,
 * joined by commas. Rank 1 uses the scalar form `imod(resRC, dim)`.
 * @throws Error for rank > 5 (no coordinate components beyond .u).
 */
function getSourceCoords(aShape) {
    const rank = aShape.length;
    if (rank > 5) {
        throw Error(`Tile for rank ${rank} is not yet supported`);
    }
    if (rank === 1) {
        return `imod(resRC, ${aShape[0]})`;
    }
    const components = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u'];
    return aShape.map((dim, i) => `imod(${components[i]}, ${dim})`).join();
}
103211
103212 /**
103213 * @license
103214 * Copyright 2020 Google LLC. All Rights Reserved.
103215 * Licensed under the Apache License, Version 2.0 (the "License");
103216 * you may not use this file except in compliance with the License.
103217 * You may obtain a copy of the License at
103218 *
103219 * http://www.apache.org/licenses/LICENSE-2.0
103220 *
103221 * Unless required by applicable law or agreed to in writing, software
103222 * distributed under the License is distributed on an "AS IS" BASIS,
103223 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103224 * See the License for the specific language governing permissions and
103225 * limitations under the License.
103226 * =============================================================================
103227 */
/**
 * Tile WebGL kernel: repeats `x` `reps[i]` times along each axis i.
 * Falls back to the CPU implementation for string tensors (which always
 * live on the CPU) and for rank > 5 (unsupported by the GPU program).
 */
function tile(params) {
    const { inputs, backend, attrs } = params;
    const { x } = inputs;
    const { reps } = attrs;
    if (x.dtype === 'string' || x.shape.length > 5) {
        // Read back through the backend even for strings, to keep tensor data
        // access uniform.
        const data = backend.readSync(x.dataId);
        const value = x.dtype === 'string' ? data.map(d => decodeString(d)) : data;
        const inBuf = buffer(x.shape, x.dtype, value);
        const outBuf = tileImplCPU(inBuf, reps);
        return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
    }
    const program = new TileProgram(x.shape, reps);
    return backend.runWebGLProgram(program, [x], x.dtype);
}
const tileConfig = {
    kernelName: Tile,
    backendName: 'webgl',
    kernelFunc: tile,
};
103253
103254 // Based on Algorithm 2 of Bitonic Top K, ref:
103255 // https://anilshanbhag.in/static/papers/gputopk_sigmod18.pdf
103256 // The original algorithm is based on computing the top K only, however
103257 // since for TFJS we require the indices of the top K values as well then the
103258 // algorithm found here is a bit modified. Rather than producing the values
103259 // at each step, the indices containing the top K are generated instead.
103260 // The output values are not generated to reduce the number of outputs in the
103261 // GPU, the values can easily be retrieved from the indices using a gather
103262 // op.
// Bitonic top-K "swap" phase: compares index pairs (i, i + inc) within
// groups of size 2 * inc and writes the indices back in the direction
// dictated by `dir`. Only indices are produced; values are looked up via
// the `x` texture.
class SwapProgram {
    /**
     * @param shape desired output shape (can be larger than input shape, output
     * will be padded with -Infinity)
     */
    constructor(shape) {
        this.variableNames = ['x', 'indices'];
        // |n| Size of the original input of TopK.
        // |firstPass|indicates if this is the first time swap is being used which
        // means no indices input containing the top K is present yet.
        // |inc| Swaps pairs of indices (0, inc), (1, inc + 1), (2, inc + 2) ...
        this.customUniforms = [
            { name: 'n', type: 'int' },
            { name: 'firstPass', type: 'int' },
            { name: 'negativeInf', type: 'float' },
            { name: 'dir', type: 'int' },
            { name: 'inc', type: 'int' }
        ];
        this.outputShape = shape;
        this.userCode = `
       void main() {
         ivec2 coords = getOutputCoords();
         int batch = coords[0];
         int elemIdx = coords[1];

         // We compare elements pair-wise within a group of size 2 * inc.
         // The comparing rule for each group alternates between ascending
         // and descending. Within each group, we compare each pair at
         // positions i and i+inc. To decide whether an element at position i
         // is x0 or x1, we mod it by 2 * inc, if the result is smaller than
         // inc, it is in the first half of the group, we denote it as x0,
         // otherwise we denote it as x1.
         // For example, as shown in the Bitonic top K paper referenced above,
         // Figure5(a) shows that element[1] is in the
         // second half of the group when group size is 2, but it is in the
         // first half of the group when group size is 4.

         bool isFirstInPair = imod(elemIdx, 2 * inc) < inc;
         int i = isFirstInPair ? elemIdx : elemIdx - inc;

         int i0 = firstPass == 1 ? i : int(getIndices(batch, i));
         int i1 = firstPass == 1 ? i + inc : int(getIndices(batch, i + inc));
         float x0 = i0 < n ? getX(batch, i0) : negativeInf;
         float x1 = i1 < n ? getX(batch, i1) : negativeInf;

         // Denotes which direction indices are in (ascending or descending).
         bool reverse = imod(elemIdx, 2 * dir) >= dir;
         bool isGreater = x0 > x1 || (x0 == x1 && i1 > i0);
         if (reverse == isGreater) { // Elements in opposite order of direction
           int iTemp = i0;
           i0 = i1;
           i1 = iTemp;
         }
         if (isFirstInPair) {
            setOutput(float(i0));
         } else {
            setOutput(float(i1));
         }
       }
     `;
    }
}
// Bitonic top-K "merge" phase: halves the candidate sequence by keeping, for
// each pair (i, i + k), the index whose value is larger. Output width must
// be half the input width.
class MergeProgram {
    /**
     * @param shape desired output shape (must be half of the input size)
     */
    constructor(shape) {
        this.variableNames = ['x', 'indices'];
        // |n| Size of the original input of TopK
        // |firstPass| indicates if this is the first time swap is being used which
        // means no indices input containing the top K is present yet.
        // |k| Top k elements desired
        this.customUniforms = [
            { name: 'n', type: 'int' },
            { name: 'firstPass', type: 'int' },
            { name: 'k', type: 'int' }
        ];
        this.outputShape = shape;
        this.userCode = `
       void main() {
         // Takes max of indices (0, k), (1, k + 1), (2, k + 2) ...
         ivec2 coords = getOutputCoords();
         int batch = coords[0];
         int elemIdx = coords[1];

         // The output size is half of the previous size.
         // If the previous sequence is | | | | _ _ _ _ | | | | _ _ _ _ (k=4),
         // we only need to output the indices at positions |, the indices at
         // positions _ can be thrown away, see Figure5(b) After Phase 2
         // (Merge phase) in the Bitonic Top K paper referenced above.
         // For example, the paper shows we only need to output the orange bars.
         // The output sequence should look like this | | | | | | | |.
         // Because the sequence is halved, to map the output index back
         // to the previous sequence to find the corresponding value,
         // we need to double the index. When we double the index,
         // we basically interpolate a position, so 2i looks like
         // | _ | _ | _ | _ | _ | _ | _. We move the | to the first k position
         // of each 2k positions by - elemIdx % k. E.g. for output at
         // index 4,5,6,7, we want to get the corresponding element at
         // original index 8,9,10,11, for output at index 8,9,10,11,
         // we want to get the corresponding element at original index
         // 16,17,18,19, so on and so forth.

         int i = elemIdx < k ? elemIdx : (elemIdx * 2 - imod(elemIdx, k));
         int i0 = firstPass == 1 ? i : int(getIndices(batch, i));
         int i1 = firstPass == 1 ? i + k : int(getIndices(batch, i + k));

         float x0 = getX(batch, i0);
         float x1 = i1 < n ? getX(batch, i1) : x0;

         setOutput(x0 >= x1 ? float(i0) : float(i1));
       }
     `;
    }
}
103378
103379 /**
103380 * @license
103381 * Copyright 2020 Google LLC. All Rights Reserved.
103382 * Licensed under the Apache License, Version 2.0 (the "License");
103383 * you may not use this file except in compliance with the License.
103384 * You may obtain a copy of the License at
103385 *
103386 * http://www.apache.org/licenses/LICENSE-2.0
103387 *
103388 * Unless required by applicable law or agreed to in writing, software
103389 * distributed under the License is distributed on an "AS IS" BASIS,
103390 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103391 * See the License for the specific language governing permissions and
103392 * limitations under the License.
103393 * =============================================================================
103394 */
/**
 * Disposes an intermediate TensorInfo on the given backend, treating `null`
 * as "nothing to dispose". (Only `null` is skipped — matching the original
 * strict check — so `undefined` would still be passed through.)
 */
function disposeIntermediateTensorInfoOrNull(backend, tensorInfo) {
    if (tensorInfo === null) {
        return;
    }
    backend.disposeIntermediateTensorInfo(tensorInfo);
}
/**
 * Returns the smallest power of two that is >= num (1 for num <= 1).
 * Uses repeated doubling rather than bit shifts so values beyond 2^31
 * are handled exactly the same as the original implementation.
 */
function roundUpToPow2(num) {
    let result = 1;
    while (result < num) {
        result *= 2;
    }
    return result;
}
// Based on Algorithm 2 of Bitonic Top K, ref:
// https://anilshanbhag.in/static/papers/gputopk_sigmod18.pdf
// Computes the k largest values (and their indices) along the last dimension
// of `x`. Returns `[valuesTensorInfo, indicesTensorInfo]`.
function topK(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { k, sorted } = attrs;
    // Empirically determined constant used to determine last dim threshold for
    // handing off execution to the CPU.
    const TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD = env().getNumber('TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD');
    // Empirically determined constant used to determine k threshold for handing
    // off execution to the CPU.
    const TOPK_K_CPU_HANDOFF_THRESHOLD = env().getNumber('TOPK_K_CPU_HANDOFF_THRESHOLD');
    const xShape = x.shape;
    const lastDim = xShape[xShape.length - 1];
    // Fall back to the CPU implementation (synchronous readback) when the
    // last dimension is small or k is large, where the GPU passes don't pay off.
    if (backend.shouldExecuteOnCPU([x]) ||
        lastDim < TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD ||
        k > TOPK_K_CPU_HANDOFF_THRESHOLD) {
        const xVals = backend.readSync(x.dataId);
        const [allTopKVals, allTopKIndices] = topKImplCPU(xVals, xShape, x.dtype, k, sorted);
        return [
            backend.makeTensorInfo(allTopKVals.shape, allTopKVals.dtype, allTopKVals.values),
            backend.makeTensorInfo(allTopKIndices.shape, allTopKIndices.dtype, allTopKIndices.values)
        ];
    }
    // k === 0: empty outputs whose last dimension is 0.
    if (k === 0) {
        xShape[xShape.length - 1] = 0;
        return [
            backend.makeTensorInfo(xShape, x.dtype, []),
            backend.makeTensorInfo(xShape, 'int32', [])
        ];
    }
    // With a single element per row, `x` itself is the top-1 result and every
    // index is 0.
    if (lastDim === 1 /* firstPass */) {
        return [
            x, fill({ attrs: { shape: xShape, dtype: 'int32', value: 0 }, backend })
        ];
    }
    // Eagerly unpack x input since it is passed in to all the shaders which
    // require unpacked inputs.
    const xtexData = backend.texData.get(x.dataId);
    const xIsPacked = xtexData !== null && xtexData.isPacked;
    const xUnPacked = xIsPacked ? backend.unpackTensor(x) : x;
    // Reshape into a 2d tensor [batch, lastDim] and compute topk along lastDim.
    const xSize = sizeFromShape(xShape);
    const batch = xSize / lastDim;
    const x2D = reshape({ inputs: { x: xUnPacked }, attrs: { shape: [batch, lastDim] }, backend });
    if (xIsPacked) {
        disposeIntermediateTensorInfoOrNull(backend, xUnPacked);
    }
    // Bitonic passes need power-of-two extents.
    const kPow2 = roundUpToPow2(k);
    const lastDimPow2 = roundUpToPow2(lastDim);
    // Only the indices containing the top K are kept at every step to reduce
    // number of outputs in the GPU algorithms, so once the final set of indices
    // is computed then gather is used to grab the corresponding values
    // from the original input.
    let indices = null;
    // GPU algorithm always takes in an indices input but this input is not used
    // on the first run of a GPU algorithm, therefore if indices is null we simply
    // pass in x2D instead of it but the value will not actually be used
    const getInputs = () => indices === null ? [x2D, x2D] : [x2D, indices];
    // Runs one bitonic swap pass, replacing `indices` and disposing the
    // previous intermediate indices tensor.
    const runSwap = (dir, inc, shape) => {
        const inputs = getInputs();
        const program = new SwapProgram(shape);
        // NOTE(review): "fistPass" is a typo for "firstPass"; local only, kept as-is.
        const fistPass = indices === null ? 1 : 0;
        const customValues = [[lastDim], [fistPass], [Number.NEGATIVE_INFINITY], [dir], [inc]];
        const prevIndices = indices;
        indices = backend.runWebGLProgram(program, inputs, 'int32', customValues);
        disposeIntermediateTensorInfoOrNull(backend, prevIndices);
    };
    // Step 1: local sort
    for (let len = 1; len < kPow2; len *= 2) {
        const dir = len * 2;
        for (let inc = len; inc >= 1; inc /= 2) {
            runSwap(dir, inc, [batch, lastDimPow2]);
        }
    }
    // Step 2: merge — halves the candidate set each iteration until only
    // kPow2 indices per batch remain.
    for (let indicesSize = lastDimPow2; indicesSize > kPow2; indicesSize /= 2) {
        const inputs = getInputs();
        const mergeProgram = new MergeProgram([batch, indicesSize / 2]);
        const firstPass = indices === null ? 1 : 0;
        const customValues = [[lastDim], [firstPass], [kPow2]];
        const prevIndices = indices;
        indices =
            backend.runWebGLProgram(mergeProgram, inputs, 'int32', customValues);
        disposeIntermediateTensorInfoOrNull(backend, prevIndices);
        // Step 3: rebuild
        const len = kPow2 / 2;
        const dir = len * 2;
        for (let inc = len; inc >= 1; inc /= 2) {
            runSwap(dir, inc, indices.shape);
        }
    }
    // Keep only the requested top K results instead of kPow2
    let prevIndices = indices;
    indices = slice({ inputs: { x: indices }, backend, attrs: { begin: 0, size: [batch, k] } });
    disposeIntermediateTensorInfoOrNull(backend, prevIndices);
    // Gather values on last dimension
    let values = gatherV2({ inputs: { x: x2D, indices }, backend, attrs: { axis: 1, batchDims: 1 } });
    disposeIntermediateTensorInfoOrNull(backend, x2D);
    // Reshape back to the original input shape, except that the last
    // dimension is k.
    const newShape = xShape.slice(0, -1);
    newShape.push(k);
    prevIndices = indices;
    indices = reshape({ inputs: { x: indices }, attrs: { shape: newShape }, backend });
    disposeIntermediateTensorInfoOrNull(backend, prevIndices);
    const prevValues = values;
    values = reshape({ inputs: { x: values }, attrs: { shape: newShape }, backend });
    disposeIntermediateTensorInfoOrNull(backend, prevValues);
    return [values, indices];
}
// Registration config binding the TopK kernel to the WebGL backend.
const topKConfig = {
    kernelName: TopK,
    backendName: 'webgl',
    kernelFunc: topK
};
103523
103524 /**
103525 * @license
103526 * Copyright 2021 Google LLC. All Rights Reserved.
103527 * Licensed under the Apache License, Version 2.0 (the "License");
103528 * you may not use this file except in compliance with the License.
103529 * You may obtain a copy of the License at
103530 *
103531 * http://www.apache.org/licenses/LICENSE-2.0
103532 *
103533 * Unless required by applicable law or agreed to in writing, software
103534 * distributed under the License is distributed on an "AS IS" BASIS,
103535 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103536 * See the License for the specific language governing permissions and
103537 * limitations under the License.
103538 * =============================================================================
103539 */
/**
 * GPGPU program for the Transform kernel: applies a per-batch projective
 * transform to `Image`, reading 8 coefficients per batch from the
 * `Transforms` input. Sampling uses nearest or bilinear interpolation, and
 * out-of-range source coordinates are handled per `fillMode`.
 */
class TransformProgram {
    /**
     * @param imageHeight Input image height (baked into the shader).
     * @param imageWidth Input image width (baked into the shader).
     * @param interpolation 'nearest' selects nearest sampling; any other
     *     value selects bilinear.
     * @param fillMode One of 'constant' | 'reflect' | 'wrap' | 'nearest';
     *     unrecognized values fall back to 'constant'.
     * @param fillValue Value used for out-of-bounds reads in 'constant' mode
     *     and when the projection denominator is zero.
     * @param outShape Output shape [batch, outHeight, outWidth, channels].
     */
    constructor(imageHeight, imageWidth, interpolation, fillMode, fillValue, outShape) {
        this.variableNames = ['Image', 'Transforms'];
        this.outputShape = outShape;
        // Numeric mode ids are baked into the shader so the branches below
        // become compile-time constants: 1 = nearest, 2 = bilinear.
        const interpolationModeId = interpolation === 'nearest' ? 1 : 2;
        // 1 = constant, 2 = reflect, 3 = wrap, 4 = nearest (clamp to edge).
        let fillModeId;
        switch (fillMode) {
            case 'constant':
                fillModeId = 1;
                break;
            case 'reflect':
                fillModeId = 2;
                break;
            case 'wrap':
                fillModeId = 3;
                break;
            case 'nearest':
                fillModeId = 4;
                break;
            default:
                fillModeId = 1;
                break;
        }
        this.userCode = `
      float mapCoord(float outCoord, float len) {
        float inCoord = outCoord;
        if(${fillModeId} == 2) {
          if (inCoord < 0.0) {
            if (len <= 1.0) {
              inCoord = 0.0;
            } else {
              float sz2 = 2.0 * len;
              if (inCoord < sz2) {
                inCoord = sz2 * float(int(float(-inCoord / sz2))) +
                inCoord;
              }
              inCoord = inCoord < -len ? inCoord + sz2 : -inCoord - 1.0;
            }
          } else if (inCoord > len - 1.0) {
            if (len <= 1.0) {
              inCoord = 0.0;
            } else {
              float sz2 = 2.0 * len;
              inCoord -= sz2 * float(int(float(inCoord / sz2)));
              if (inCoord >= len) {
                inCoord = sz2 - inCoord - 1.0;
              }
            }
          }
          return clamp(inCoord, 0.0, len - 1.0);
        } else if (${fillModeId} == 3) {
          if (inCoord < 0.0) {
            if (len <= 1.0) {
              inCoord = 0.0;
            } else {
              float sz = len - 1.0;
              inCoord += len * (float(int(float(-inCoord / sz))) + 1.0);
            }
          } else if (inCoord > len - 1.0) {
            if (len <= 1.0) {
              inCoord = 0.0;
            } else {
              float sz = len - 1.0;
              inCoord -= len * float(int(float(inCoord / sz)));
            }
          }
          return clamp(inCoord, 0.0, len - 1.0);
        } else if (${fillModeId} == 4) {
          return clamp(outCoord, 0.0, len - 1.0);
        } else {
          return outCoord;
        }
      }

      float readWithFillValue(int batch, int coordY, int coordX,
        int channel) {
        float outputValue;
        if (0 <= coordY && coordY < ${imageHeight} && 0 <= coordX && coordX < ${imageWidth}) {
          outputValue = getImage(batch, coordY, coordX, channel);
        } else {
          outputValue = float(${fillValue});
        }
        return outputValue;
      }

      void main() {
        ivec4 coords = getOutputCoords();
        float outputValue;
        int batch = coords[0];
        int x = coords[2];
        int y = coords[1];
        int channel = coords[3];
        float xf = float(x);
        float yf = float(y);
        float a1 = getTransforms(batch, 0);
        float a2 = getTransforms(batch, 1);
        float a3 = getTransforms(batch, 2);
        float b1 = getTransforms(batch, 3);
        float b2 = getTransforms(batch, 4);
        float b3 = getTransforms(batch, 5);
        float c1 = getTransforms(batch, 6);
        float c2 = getTransforms(batch, 7);
        float projection = c1 * xf + c2 * yf + 1.0;
        if (projection == 0.0) {
          outputValue = float(${fillValue});
        } else {
          float inX = (a1 * xf + a2 * yf + a3) / projection;
          float inY = (b1 * xf + b2 * yf + b3) / projection;
          float mapX = mapCoord(inX, float(${imageWidth}));
          float mapY = mapCoord(inY, float(${imageHeight}));

          if (${interpolationModeId} == 1) {
            int coordY = int(round(mapY));
            int coordX = int(round(mapX));
            outputValue = readWithFillValue(batch, coordY, coordX,
              channel);
          } else {
            float yFloor = floor(mapY);
            float xFloor = floor(mapX);
            float yCeil = yFloor + 1.0;
            float xCeil = xFloor + 1.0;
            float valueYFloor = (xCeil - mapX) *
            readWithFillValue(batch, int(yFloor), int(xFloor), channel) +
            (mapX - xFloor) *
            readWithFillValue(batch, int(yFloor), int(xCeil), channel);
            float valueYCeil = (xCeil - mapX) *
            readWithFillValue(batch, int(yCeil), int(xFloor), channel) +
            (mapX - xFloor) *
            readWithFillValue(batch, int(yCeil), int(xCeil), channel);
            outputValue = (yCeil - mapY) * valueYFloor +
            (mapY - yFloor) * valueYCeil;
          }
        }
        setOutput(outputValue);
      }
    `;
    }
}
103678
103679 /**
103680 * @license
103681 * Copyright 2021 Google LLC. All Rights Reserved.
103682 * Licensed under the Apache License, Version 2.0 (the "License");
103683 * you may not use this file except in compliance with the License.
103684 * You may obtain a copy of the License at
103685 *
103686 * http://www.apache.org/licenses/LICENSE-2.0
103687 *
103688 * Unless required by applicable law or agreed to in writing, software
103689 * distributed under the License is distributed on an "AS IS" BASIS,
103690 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103691 * See the License for the specific language governing permissions and
103692 * limitations under the License.
103693 * =============================================================================
103694 */
/**
 * WebGL kernel func for Transform: builds a TransformProgram for the given
 * image/transforms pair and runs it, producing a float32 output.
 *
 * @param args.inputs.image 4D image tensor info [batch, height, width, channels].
 * @param args.inputs.transforms Per-batch transform coefficients tensor info.
 * @param args.attrs {interpolation, fillMode, fillValue, outputShape} where
 *     `outputShape`, when null/undefined, defaults to the input spatial size.
 * @returns TensorInfo of shape [batch, outHeight, outWidth, channels].
 */
function transform(args) {
    const { inputs: { image, transforms }, backend, attrs } = args;
    const { interpolation, fillMode, fillValue, outputShape } = attrs;
    const [batch, imageHeight, imageWidth, numChannels] = image.shape;
    // Default the output spatial dims to the input's when no shape was given.
    let outHeight = imageHeight;
    let outWidth = imageWidth;
    if (outputShape != null) {
        [outHeight, outWidth] = outputShape;
    }
    const outShape = [batch, outHeight, outWidth, numChannels];
    const program = new TransformProgram(imageHeight, imageWidth, interpolation, fillMode, fillValue, outShape);
    return backend.runWebGLProgram(program, [image, transforms], 'float32');
}
// Registration config binding the Transform kernel to the WebGL backend.
const transformConfig = {
    kernelName: Transform,
    backendName: 'webgl',
    kernelFunc: transform
};
103711
103712 /**
103713 * @license
103714 * Copyright 2020 Google LLC. All Rights Reserved.
103715 * Licensed under the Apache License, Version 2.0 (the License);
103716 * you may not use this file except in compliance with the License.
103717 * You may obtain a copy of the License at
103718 *
103719 * http://www.apache.org/licenses/LICENSE-2.0
103720 *
103721 * Unless required by applicable law or agreed to in writing, software
103722 * distributed under the License is distributed on an AS IS BASIS,
103723 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103724 * See the License for the specific language governing permissions and
103725 * limitations under the License.
103726 * =============================================================================
103727 */
/**
 * WebGL kernel func for Unique. The computation is delegated entirely to the
 * CPU implementation, which requires a synchronous download of the tensor
 * data (hence the warning below).
 *
 * @param args.inputs.x Input tensor info (complex dtypes rejected).
 * @param args.attrs.axis Axis along which uniqueness is computed.
 * @returns [valuesTensorInfo (x.dtype), indicesTensorInfo (int32)].
 */
function unique(args) {
    const { inputs: { x }, attrs: { axis }, backend } = args;
    assertNotComplex(x, 'unique');
    // For now, always forward calculation to the CPU backend.
    console.warn('WARNING: ', 'UI might be locked temporarily as data is being downloaded');
    const xVals = backend.readSync(x.dataId);
    const { outputValues, outputShape, indices } = uniqueImplCPU(xVals, axis, x.shape, x.dtype);
    const valuesInfo = backend.makeTensorInfo(outputShape, x.dtype, outputValues);
    const indicesInfo = backend.makeTensorInfo([indices.length], 'int32', indices);
    return [valuesInfo, indicesInfo];
}
// Registration config binding the Unique kernel to the WebGL backend.
const uniqueConfig = {
    kernelName: Unique,
    backendName: 'webgl',
    kernelFunc: unique,
};
103747
103748 /**
103749 * @license
103750 * Copyright 2020 Google LLC. All Rights Reserved.
103751 * Licensed under the Apache License, Version 2.0 (the "License");
103752 * you may not use this file except in compliance with the License.
103753 * You may obtain a copy of the License at
103754 *
103755 * http://www.apache.org/licenses/LICENSE-2.0
103756 *
103757 * Unless required by applicable law or agreed to in writing, software
103758 * distributed under the License is distributed on an "AS IS" BASIS,
103759 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103760 * See the License for the specific language governing permissions and
103761 * limitations under the License.
103762 * =============================================================================
103763 */
/**
 * WebGL kernel func for Unpack: splits `value` along `axis` into
 * `value.shape[axis]` tensors, each with that axis removed. Implemented as a
 * slice of width 1 per index followed by a reshape; the intermediate slices
 * are disposed before returning.
 *
 * @param args.inputs.value Tensor info to unpack.
 * @param args.attrs.axis Axis to unpack along (negative values wrap).
 * @returns Array of tensor infos, one per index along `axis`.
 */
function unpack(args) {
    const { inputs, backend, attrs } = args;
    const x = inputs.value;
    let { axis } = attrs;
    const xRank = x.shape.length;
    if (axis < 0) {
        axis += xRank;
    }
    const num = x.shape[axis];
    // Each output drops the unpacked axis from the input shape.
    const outShape = x.shape.filter((_, dim) => dim !== axis);
    // Slice window: full extent everywhere except width 1 along `axis`.
    const begin = new Array(xRank).fill(0);
    const size = x.shape.slice();
    size[axis] = 1;
    const toDispose = [];
    const res = [];
    for (let sliceIdx = 0; sliceIdx < num; sliceIdx++) {
        begin[axis] = sliceIdx;
        const sliced = slice({ inputs: { x }, backend, attrs: { begin, size } });
        toDispose.push(sliced);
        res.push(reshape({ inputs: { x: sliced }, backend, attrs: { shape: outShape } }));
    }
    toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return res;
}
// Registration config binding the Unpack kernel to the WebGL backend.
const unpackConfig = {
    kernelName: Unpack,
    backendName: 'webgl',
    kernelFunc: unpack
};
103801
103802 /**
103803 * @license
103804 * Copyright 2018 Google LLC. All Rights Reserved.
103805 * Licensed under the Apache License, Version 2.0 (the "License");
103806 * you may not use this file except in compliance with the License.
103807 * You may obtain a copy of the License at
103808 *
103809 * http://www.apache.org/licenses/LICENSE-2.0
103810 *
103811 * Unless required by applicable law or agreed to in writing, software
103812 * distributed under the License is distributed on an "AS IS" BASIS,
103813 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103814 * See the License for the specific language governing permissions and
103815 * limitations under the License.
103816 * =============================================================================
103817 */
/**
 * GPGPU program that reduces `x` over windows of `windowSize` elements,
 * summing only those elements whose segment id (from the `segmentIds` input)
 * equals the segment of the current output column. Repeated application
 * shrinks the window dimension toward one column per segment (see
 * unsortedSegmentSum below).
 */
class SegmentOpProgram {
    // NOTE(review): `segOpType` is unused here — the generated shader is
    // hard-wired to a sum reduction (initialization 0.0, dot-product update);
    // presumably the parameter anticipates other segment ops. Confirm before
    // relying on it.
    constructor(segOpInfo, segOpType) {
        this.variableNames = ['x', 'segmentIds'];
        const windowSize = segOpInfo.windowSize;
        const batchSize = segOpInfo.batchSize;
        const inSize = segOpInfo.inSize;
        const numSegments = segOpInfo.numSegments;
        // One output column per (window, segment) pair.
        const outSize = numSegments * Math.ceil(inSize / windowSize);
        this.outputShape = [batchSize, outSize];
        const initializationValue = '0.0';
        const returnValue = `sumValue`;
        // The window is consumed four elements at a time; the 0-3 leftover
        // elements are handled by the remainder branches after the main loop.
        const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
        const windowSizeVec4Remainder = windowSize % 4;
        const updateSnippet = `
        sumValue += dot(values, segFilter);
      `;
        // Out-of-bounds guards are only emitted when the last window can run
        // past the input (inSize not a multiple of windowSize).
        let checkValueOutOfBounds = '';
        if (inSize % windowSize > 0) {
            checkValueOutOfBounds = `
        if (inIdx < 0 || inIdx >= ${inSize}) {
          return initializationValue;
        }
      `;
        }
        let checkSegmentIdOutOfBounds = '';
        if (inSize % windowSize > 0) {
            // -1 never matches a valid segment, so out-of-bounds reads are
            // filtered out of the sum.
            checkSegmentIdOutOfBounds = `
        if (inIdx < 0 || inIdx >= ${inSize}) {
          return -1.0;
        }
      `;
        }
        this.userCode = `
      const float initializationValue = ${initializationValue};

      float getValue(int batch, int inIdx) {
        ${checkValueOutOfBounds}
        return getX(batch, inIdx);
      }

      float getSegmentIdAtIndex(int inIdx) {
        ${checkSegmentIdOutOfBounds}
        return getSegmentIds(inIdx);
      }

      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];
        int outIdx = coords[1];
        int inOffset = int(floor(float(outIdx) / float(
          ${numSegments})) * float(${windowSize}));
        int currentSeg = int(mod(float(outIdx), float(${numSegments})));

        float sumValue = 0.0;

        for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) {
          int inIdx = inOffset + i;
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            getValue(batch, inIdx + 3)
          );

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 3)) == currentSeg ? 1 : 0
          );

          ${updateSnippet}
        }

        int inIdx = inOffset + ${windowSizeNearestVec4};
        if (${windowSizeVec4Remainder === 1}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            initializationValue,
            initializationValue,
            initializationValue
          );

          int inIdxSeg = int(getSegmentIdAtIndex(inIdx));

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            0,
            0,
            0
          );

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 2}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            initializationValue,
            initializationValue
          );

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
            0,
            0
          );

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 3}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            initializationValue
          );

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,
            0
          );

          ${updateSnippet}
        }
        setOutput(${returnValue});
      }
    `;
    }
}
103949
103950 /**
103951 * @license
103952 * Copyright 2020 Google LLC. All Rights Reserved.
103953 * Licensed under the Apache License, Version 2.0 (the "License");
103954 * you may not use this file except in compliance with the License.
103955 * You may obtain a copy of the License at
103956 *
103957 * http://www.apache.org/licenses/LICENSE-2.0
103958 *
103959 * Unless required by applicable law or agreed to in writing, software
103960 * distributed under the License is distributed on an "AS IS" BASIS,
103961 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
103962 * See the License for the specific language governing permissions and
103963 * limitations under the License.
103964 * =============================================================================
103965 */
/**
 * WebGL kernel func for UnsortedSegmentSum: sums slices of `x` along axis 0
 * into `numSegments` output slices, grouped by `segmentIds`.
 *
 * The segment axis is first transposed to the innermost position, the tensor
 * flattened to 2D, then SegmentOpProgram is applied repeatedly until one
 * column per segment remains; finally the result is reshaped and the initial
 * transpose undone. Intermediates are tracked in `toDispose` and released at
 * the end.
 */
function unsortedSegmentSum(args) {
    const { inputs, backend, attrs } = args;
    const { x, segmentIds } = inputs;
    const { numSegments } = attrs;
    const xRank = x.shape.length;
    const toDispose = [];
    let axis = 0;
    // Move the segment axis (0) to the innermost position so the 2D
    // seg-op program can reduce along the last dimension.
    const permutation = getAxesPermutation([axis], xRank);
    let permutedX = x;
    if (permutation != null) {
        permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutation } });
        toDispose.push(permutedX);
        axis = getInnerMostAxes(1, xRank)[0];
    }
    const outShape = computeOutShape(permutedX.shape, axis, numSegments);
    const inSize = sizeFromShape([permutedX.shape[axis]]);
    const a2D = reshape({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } });
    toDispose.push(a2D);
    const outputDType = sumOutType(x.dtype);
    // Recursive reduction: each pass shrinks the window dimension until the
    // output has exactly one column per segment.
    const segOpCompute = (x, segOpType, segmentIds, dtype, numSegments) => {
        const batchSize = x.shape[0];
        const inSize = x.shape[1];
        const windowSize = segOpComputeOptimalWindowSize(inSize, numSegments);
        const segOpInfo = { windowSize, inSize, batchSize, numSegments };
        const program = new SegmentOpProgram(segOpInfo, segOpType);
        const output = backend.compileAndRun(program, [x, segmentIds], dtype);
        toDispose.push(output);
        // No need to run another GPGPU program.
        if (output.shape[1] === numSegments) {
            return output;
        }
        // Synthesize segment ids 0..numSegments-1, tiled across the remaining
        // windows, to drive the next reduction pass.
        const rangeInfo = range({
            backend,
            attrs: { start: 0, stop: numSegments, step: 1, dtype: 'float32' }
        });
        const tileInfo = tile({
            inputs: { x: rangeInfo },
            backend,
            attrs: { reps: [inSize / windowSize] }
        });
        toDispose.push(rangeInfo);
        toDispose.push(tileInfo);
        const result = segOpCompute(output, segOpType, tileInfo, dtype, numSegments);
        return result;
    };
    const segOpResult = segOpCompute(a2D, 'unsortedSegmentSum', segmentIds, outputDType, numSegments);
    const reshaped = reshape({ inputs: { x: segOpResult }, backend, attrs: { shape: outShape } });
    let result = reshaped;
    // Undo the initial transpose, if one was applied.
    if (permutation != null) {
        toDispose.push(reshaped);
        const perm = getUndoAxesPermutation(permutation);
        result = transpose({ inputs: { x: result }, backend, attrs: { perm } });
    }
    toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return result;
}
// Registration config binding the UnsortedSegmentSum kernel to the WebGL
// backend.
const unsortedSegmentSumConfig = {
    kernelName: UnsortedSegmentSum,
    backendName: 'webgl',
    kernelFunc: unsortedSegmentSum
};
104027
104028 /**
104029 * @license
104030 * Copyright 2020 Google LLC. All Rights Reserved.
104031 * Licensed under the Apache License, Version 2.0 (the "License");
104032 * you may not use this file except in compliance with the License.
104033 * You may obtain a copy of the License at
104034 *
104035 * http://www.apache.org/licenses/LICENSE-2.0
104036 *
104037 * Unless required by applicable law or agreed to in writing, software
104038 * distributed under the License is distributed on an "AS IS" BASIS,
104039 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
104040 * See the License for the specific language governing permissions and
104041 * limitations under the License.
104042 * =============================================================================
104043 */
// List all kernel configs here
const kernelConfigs = [
    _fusedMatMulConfig,
    absConfig,
    acosConfig,
    acoshConfig,
    addConfig,
    addNConfig,
    allConfig,
    anyConfig,
    argMaxConfig,
    argMinConfig,
    asinConfig,
    asinhConfig,
    atanConfig,
    atan2Config,
    atanhConfig,
    avgPoolConfig,
    avgPool3DConfig,
    avgPool3DGradConfig,
    avgPoolGradConfig,
    batchMatMulConfig,
    batchNormConfig,
    batchToSpaceNDConfig,
    bincountConfig,
    bitwiseAndConfig,
    broadcastArgsConfig,
    castConfig,
    ceilConfig,
    clipByValueConfig,
    complexConfig,
    complexAbsConfig,
    concatConfig,
    conv2DConfig,
    conv2DBackpropFilterConfig,
    conv2DBackpropInputConfig,
    conv3DConfig,
    conv3DBackpropFilterV2Config,
    conv3DBackpropInputConfig,
    cosConfig,
    coshConfig,
    cropAndResizeConfig,
    cumprodConfig,
    cumsumConfig,
    denseBincountConfig,
    depthToSpaceConfig,
    depthwiseConv2dNativeConfig,
    depthwiseConv2dNativeBackpropFilterConfig,
    depthwiseConv2dNativeBackpropInputConfig,
    diagConfig,
    dilation2DConfig,
    einsumConfig,
    eluConfig,
    eluGradConfig,
    equalConfig,
    erfConfig,
    expConfig,
    expandDimsConfig,
    expm1Config,
    fftConfig,
    fillConfig,
    flipLeftRightConfig,
    floorConfig,
    floorDivConfig,
    fromPixelsConfig,
    fusedConv2DConfig,
    fusedDepthwiseConv2DConfig,
    gatherNdConfig,
    gatherV2Config,
    greaterConfig,
    greaterEqualConfig,
    identityConfig,
    ifftConfig,
    imagConfig,
    isFiniteConfig,
    isInfConfig,
    isNaNConfig,
    leakyReluConfig,
    lessConfig,
    lessEqualConfig,
    linSpaceConfig,
    logConfig,
    log1pConfig,
    logicalAndConfig,
    logicalNotConfig,
    logicalOrConfig,
    LRNConfig,
    LRNGradConfig,
    maxConfig,
    maximumConfig,
    maxPoolConfig,
    maxPool3DConfig,
    maxPool3DGradConfig,
    maxPoolGradConfig,
    maxPoolWithArgmaxConfig,
    meanConfig,
    minConfig,
    minimumConfig,
    mirrorPadConfig,
    modConfig,
    multinomialConfig,
    multiplyConfig,
    negConfig,
    nonMaxSuppressionV3Config,
    nonMaxSuppressionV4Config,
    nonMaxSuppressionV5Config,
    notEqualConfig,
    oneHotConfig,
    onesLikeConfig,
    packConfig,
    padV2Config,
    powConfig,
    preluConfig,
    prodConfig,
    raggedGatherConfig,
    raggedRangeConfig,
    raggedTensorToTensorConfig,
    rangeConfig,
    realConfig,
    realDivConfig,
    reciprocalConfig,
    reluConfig,
    relu6Config,
    reshapeConfig,
    resizeBilinearConfig,
    resizeBilinearGradConfig,
    resizeNearestNeighborConfig,
    resizeNearestNeighborGradConfig,
    reverseConfig,
    rotateWithOffsetConfig,
    roundConfig,
    rsqrtConfig,
    scatterNdConfig,
    searchSortedConfig,
    selectConfig,
    seluConfig,
    sigmoidConfig,
    signConfig,
    sinConfig,
    sinhConfig,
    sliceConfig,
    softmaxConfig,
    softplusConfig,
    spaceToBatchNDConfig,
    sparseFillEmptyRowsConfig,
    sparseReshapeConfig,
    sparseSegmentMeanConfig,
    sparseSegmentSumConfig,
    sparseToDenseConfig,
    splitVConfig,
    sqrtConfig,
    squareConfig,
    squaredDifferenceConfig,
    staticRegexReplaceConfig,
    stepConfig,
    stridedSliceConfig,
    stringNGramsConfig,
    stringSplitConfig,
    stringToHashBucketFastConfig,
    subConfig,
    sumConfig,
    tanConfig,
    tanhConfig,
    tensorScatterUpdateConfig,
    tileConfig,
    topKConfig,
    transformConfig,
    transposeConfig,
    uniqueConfig,
    unpackConfig,
    unsortedSegmentSumConfig,
    zerosLikeConfig
];
// Register every WebGL kernel config with the global kernel registry at
// module load time.
for (const kernelConfig of kernelConfigs) {
    registerKernel(kernelConfig);
}
104220
104221 /**
104222 * @license
104223 * Copyright 2020 Google LLC. All Rights Reserved.
104224 * Licensed under the Apache License, Version 2.0 (the "License");
104225 * you may not use this file except in compliance with the License.
104226 * You may obtain a copy of the License at
104227 *
104228 * http://www.apache.org/licenses/LICENSE-2.0
104229 *
104230 * Unless required by applicable law or agreed to in writing, software
104231 * distributed under the License is distributed on an "AS IS" BASIS,
104232 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
104233 * See the License for the specific language governing permissions and
104234 * limitations under the License.
104235 * =============================================================================
104236 */
104237
/** @license See the LICENSE file. */
// This code is auto-generated, do not modify this file!
// Version string reported under the 'tfjs' key of the `version` map below.
const version$1 = '4.22.0';
104241
104242 /**
104243 * @license
104244 * Copyright 2018 Google LLC. All Rights Reserved.
104245 * Licensed under the Apache License, Version 2.0 (the "License");
104246 * you may not use this file except in compliance with the License.
104247 * You may obtain a copy of the License at
104248 *
104249 * http://www.apache.org/licenses/LICENSE-2.0
104250 *
104251 * Unless required by applicable law or agreed to in writing, software
104252 * distributed under the License is distributed on an "AS IS" BASIS,
104253 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
104254 * See the License for the specific language governing permissions and
104255 * limitations under the License.
104256 * =============================================================================
104257 */
// Aggregated version map of the bundled tfjs sub-packages, keyed by package
// name. The version$N constants are defined earlier in this bundle, one per
// sub-package.
const version = {
    'tfjs-core': version$7,
    'tfjs-backend-cpu': version$3,
    'tfjs-backend-webgl': version$2,
    'tfjs-data': version$4,
    'tfjs-layers': version$6,
    'tfjs-converter': version$5,
    'tfjs': version$1
};
104267
    // ------------------------------------------------------------------------
    // Public API surface of the tfjs union bundle (auto-generated by rollup).
    // Each assignment re-exports a binding defined earlier in this file onto
    // the UMD `exports` object (`global.tf` in a browser, `module.exports`
    // under CommonJS). `$N`-suffixed names are rollup's deduplication aliases
    // for identifiers that collide across the merged sub-packages.
    // Do not edit by hand — regenerate via the tfjs build instead.
    // ------------------------------------------------------------------------

    // PascalCase bindings: kernel names, optimizer/layer/model classes, core
    // types (Tensor, Variable, Environment), and backend implementations.
    exports.Abs = Abs;
    exports.Acos = Acos;
    exports.Acosh = Acosh;
    exports.AdadeltaOptimizer = AdadeltaOptimizer;
    exports.AdagradOptimizer = AdagradOptimizer;
    exports.AdamOptimizer = AdamOptimizer;
    exports.AdamaxOptimizer = AdamaxOptimizer;
    exports.Add = Add$1;
    exports.AddN = AddN;
    exports.All = All;
    exports.Any = Any;
    exports.ArgMax = ArgMax;
    exports.ArgMin = ArgMin;
    exports.Asin = Asin;
    exports.Asinh = Asinh;
    exports.Atan = Atan;
    exports.Atan2 = Atan2;
    exports.Atanh = Atanh;
    exports.AvgPool = AvgPool;
    exports.AvgPool3D = AvgPool3D;
    exports.AvgPool3DGrad = AvgPool3DGrad;
    exports.AvgPoolGrad = AvgPoolGrad;
    exports.BatchMatMul = BatchMatMul;
    exports.BatchToSpaceND = BatchToSpaceND;
    exports.Bincount = Bincount;
    exports.BitwiseAnd = BitwiseAnd;
    exports.BroadcastArgs = BroadcastArgs;
    exports.BroadcastTo = BroadcastTo;
    exports.Callback = Callback;
    exports.CallbackList = CallbackList;
    exports.Cast = Cast;
    exports.Ceil = Ceil;
    exports.ClipByValue = ClipByValue;
    exports.Complex = Complex;
    exports.ComplexAbs = ComplexAbs;
    exports.Concat = Concat;
    exports.Conv2D = Conv2D$1;
    exports.Conv2DBackpropFilter = Conv2DBackpropFilter;
    exports.Conv2DBackpropInput = Conv2DBackpropInput;
    exports.Conv3D = Conv3D$1;
    exports.Conv3DBackpropFilterV2 = Conv3DBackpropFilterV2;
    exports.Conv3DBackpropInputV2 = Conv3DBackpropInputV2;
    exports.Cos = Cos;
    exports.Cosh = Cosh;
    exports.CropAndResize = CropAndResize;
    exports.Cumprod = Cumprod;
    exports.Cumsum = Cumsum;
    exports.CustomCallback = CustomCallback;
    exports.DataStorage = DataStorage;
    exports.DenseBincount = DenseBincount;
    exports.DepthToSpace = DepthToSpace;
    exports.DepthwiseConv2dNative = DepthwiseConv2dNative;
    exports.DepthwiseConv2dNativeBackpropFilter = DepthwiseConv2dNativeBackpropFilter;
    exports.DepthwiseConv2dNativeBackpropInput = DepthwiseConv2dNativeBackpropInput;
    exports.Diag = Diag;
    exports.Dilation2D = Dilation2D;
    exports.Dilation2DBackpropFilter = Dilation2DBackpropFilter;
    exports.Dilation2DBackpropInput = Dilation2DBackpropInput;
    exports.Draw = Draw;
    exports.EarlyStopping = EarlyStopping;
    exports.Einsum = Einsum;
    exports.Elu = Elu$1;
    exports.EluGrad = EluGrad;
    exports.Environment = Environment;
    exports.Equal = Equal;
    exports.Erf = Erf;
    exports.Exp = Exp;
    exports.ExpandDims = ExpandDims;
    exports.Expm1 = Expm1;
    exports.FFT = FFT;
    exports.Fill = Fill;
    exports.FlipLeftRight = FlipLeftRight;
    exports.Floor = Floor;
    exports.FloorDiv = FloorDiv;
    exports.FromPixels = FromPixels;
    exports.FusedBatchNorm = FusedBatchNorm;
    exports.FusedConv2D = FusedConv2D;
    exports.FusedDepthwiseConv2D = FusedDepthwiseConv2D;
    exports.GPGPUContext = GPGPUContext;
    exports.GatherNd = GatherNd;
    exports.GatherV2 = GatherV2;
    exports.GraphModel = GraphModel;
    exports.Greater = Greater;
    exports.GreaterEqual = GreaterEqual;
    exports.History = History;
    exports.IFFT = IFFT;
    exports.Identity = Identity$1;
    exports.Imag = Imag;
    exports.InputSpec = InputSpec;
    exports.IsFinite = IsFinite;
    exports.IsInf = IsInf;
    exports.IsNan = IsNan;
    exports.KernelBackend = KernelBackend;
    exports.LRN = LRN;
    exports.LRNGrad = LRNGrad;
    exports.LayerVariable = LayerVariable;
    exports.LayersModel = LayersModel;
    exports.LeakyRelu = LeakyRelu;
    exports.Less = Less;
    exports.LessEqual = LessEqual;
    exports.LinSpace = LinSpace;
    exports.Log = Log;
    exports.Log1p = Log1p;
    exports.LogSoftmax = LogSoftmax$1;
    exports.LogicalAnd = LogicalAnd;
    exports.LogicalNot = LogicalNot;
    exports.LogicalOr = LogicalOr;
    exports.LogicalXor = LogicalXor;
    exports.LowerBound = LowerBound;
    exports.MathBackendCPU = MathBackendCPU;
    exports.MathBackendWebGL = MathBackendWebGL;
    exports.MatrixBandPart = MatrixBandPart;
    exports.Max = Max;
    exports.MaxPool = MaxPool;
    exports.MaxPool3D = MaxPool3D;
    exports.MaxPool3DGrad = MaxPool3DGrad;
    exports.MaxPoolGrad = MaxPoolGrad;
    exports.MaxPoolWithArgmax = MaxPoolWithArgmax;
    exports.Maximum = Maximum$1;
    exports.Mean = Mean;
    exports.Min = Min;
    exports.Minimum = Minimum$1;
    exports.MirrorPad = MirrorPad;
    exports.Mod = Mod;
    exports.MomentumOptimizer = MomentumOptimizer;
    exports.Multinomial = Multinomial;
    exports.Multiply = Multiply$1;
    exports.Neg = Neg;
    exports.NonMaxSuppressionV3 = NonMaxSuppressionV3;
    exports.NonMaxSuppressionV4 = NonMaxSuppressionV4;
    exports.NonMaxSuppressionV5 = NonMaxSuppressionV5;
    exports.NotEqual = NotEqual;
    exports.OP_SCOPE_SUFFIX = OP_SCOPE_SUFFIX;
    exports.OneHot = OneHot;
    exports.OnesLike = OnesLike;
    exports.Optimizer = Optimizer;
    exports.OptimizerConstructors = OptimizerConstructors;
    exports.Pack = Pack;
    exports.PadV2 = PadV2;
    exports.Pool = Pool;
    exports.Pow = Pow;
    exports.Prelu = Prelu;
    exports.Prod = Prod;
    exports.RMSPropOptimizer = RMSPropOptimizer;
    exports.RNN = RNN;
    exports.RaggedGather = RaggedGather;
    exports.RaggedRange = RaggedRange;
    exports.RaggedTensorToTensor = RaggedTensorToTensor;
    exports.Range = Range;
    exports.Real = Real;
    exports.RealDiv = RealDiv;
    exports.Reciprocal = Reciprocal;
    exports.Relu = Relu$1;
    exports.Relu6 = Relu6$1;
    exports.Reshape = Reshape$1;
    exports.ResizeBilinear = ResizeBilinear;
    exports.ResizeBilinearGrad = ResizeBilinearGrad;
    exports.ResizeNearestNeighbor = ResizeNearestNeighbor;
    exports.ResizeNearestNeighborGrad = ResizeNearestNeighborGrad;
    exports.Reverse = Reverse;
    exports.RotateWithOffset = RotateWithOffset;
    exports.Round = Round;
    exports.Rsqrt = Rsqrt;
    exports.SGDOptimizer = SGDOptimizer;
    exports.ScatterNd = ScatterNd;
    exports.SearchSorted = SearchSorted;
    exports.Select = Select;
    exports.Selu = Selu$1;
    exports.Sequential = Sequential;
    exports.Sigmoid = Sigmoid$1;
    exports.Sign = Sign;
    exports.Sin = Sin;
    exports.Sinh = Sinh;
    exports.Slice = Slice;
    exports.Softmax = Softmax$2;
    exports.Softplus = Softplus$1;
    exports.SpaceToBatchND = SpaceToBatchND;
    exports.SparseFillEmptyRows = SparseFillEmptyRows;
    exports.SparseReshape = SparseReshape;
    exports.SparseSegmentMean = SparseSegmentMean;
    exports.SparseSegmentSum = SparseSegmentSum;
    exports.SparseToDense = SparseToDense;
    exports.SplitV = SplitV;
    exports.Sqrt = Sqrt;
    exports.Square = Square;
    exports.SquaredDifference = SquaredDifference;
    exports.StaticRegexReplace = StaticRegexReplace;
    exports.Step = Step;
    exports.StridedSlice = StridedSlice;
    exports.StringNGrams = StringNGrams;
    exports.StringSplit = StringSplit;
    exports.StringToHashBucketFast = StringToHashBucketFast;
    exports.Sub = Sub;
    exports.Sum = Sum;
    exports.SymbolicTensor = SymbolicTensor;
    exports.Tan = Tan;
    exports.Tanh = Tanh$1;
    exports.Tensor = Tensor;
    exports.TensorBuffer = TensorBuffer;
    exports.TensorScatterUpdate = TensorScatterUpdate;
    exports.Tile = Tile;
    exports.TopK = TopK;
    exports.Transform = Transform;
    exports.Transpose = Transpose;
    exports.Unique = Unique;
    exports.Unpack = Unpack;
    exports.UnsortedSegmentSum = UnsortedSegmentSum;
    exports.UpperBound = UpperBound;
    exports.Variable = Variable;
    exports.ZerosLike = ZerosLike;
    exports._FusedMatMul = _FusedMatMul;

    // camelCase bindings: op functions, engine/environment controls, and
    // namespace objects (io, linalg, image, losses, util, ...).
    exports.abs = abs$2;
    exports.acos = acos$2;
    exports.acosh = acosh$2;
    exports.add = add$3;
    exports.addN = addN$2;
    exports.all = all$2;
    exports.any = any$2;
    exports.argMax = argMax$2;
    exports.argMin = argMin$2;
    exports.asin = asin$2;
    exports.asinh = asinh$2;
    exports.atan = atan$2;
    exports.atan2 = atan2$2;
    exports.atanh = atanh$2;
    exports.avgPool = avgPool$2;
    exports.avgPool3d = avgPool3d$1;
    exports.backend = backend$1;
    exports.backend_util = backend_util;
    exports.basicLSTMCell = basicLSTMCell;
    exports.batchNorm = batchNorm$2;
    exports.batchNorm2d = batchNorm2d;
    exports.batchNorm3d = batchNorm3d;
    exports.batchNorm4d = batchNorm4d;
    exports.batchToSpaceND = batchToSpaceND$2;
    exports.bincount = bincount$2;
    exports.bitwiseAnd = bitwiseAnd$2;
    exports.booleanMaskAsync = booleanMaskAsync;
    exports.broadcastArgs = broadcastArgs$2;
    exports.broadcastTo = broadcastTo;
    exports.broadcast_util = broadcast_util;
    exports.browser = browser;
    exports.buffer = buffer;
    exports.callbacks = callbacks;
    exports.cast = cast$3;
    exports.ceil = ceil$2;
    exports.clipByValue = clipByValue$2;
    exports.clone = clone;
    exports.complex = complex$2;
    exports.concat = concat$2;
    exports.concat1d = concat1d;
    exports.concat2d = concat2d;
    exports.concat3d = concat3d;
    exports.concat4d = concat4d;
    exports.constraints = exports_constraints;
    exports.conv1d = conv1d$2;
    exports.conv2d = conv2d$4;
    exports.conv2dTranspose = conv2dTranspose$1;
    exports.conv3d = conv3d$2;
    exports.conv3dTranspose = conv3dTranspose$1;
    exports.copyRegisteredKernels = copyRegisteredKernels;
    exports.cos = cos$2;
    exports.cosh = cosh$2;
    exports.cosineWindow = cosineWindow;
    exports.cumprod = cumprod$2;
    exports.cumsum = cumsum$2;
    exports.customGrad = customGrad;
    exports.data = index;
    exports.denseBincount = denseBincount$2;
    exports.deprecationWarn = deprecationWarn;
    exports.depthToSpace = depthToSpace$2;
    exports.depthwiseConv2d = depthwiseConv2d$3;
    exports.deregisterOp = deregisterOp;
    exports.device_util = device_util;
    exports.diag = diag$2;
    exports.dilation2d = dilation2d;
    exports.disableDeprecationWarnings = disableDeprecationWarnings;
    exports.dispose = dispose;
    exports.disposeVariables = disposeVariables;
    exports.div = div$1;
    exports.divNoNan = divNoNan;
    exports.dot = dot$2;
    exports.dropout = dropout$2;
    exports.einsum = einsum$2;
    exports.elu = elu$4;
    exports.enableDebugMode = enableDebugMode;
    exports.enableProdMode = enableProdMode;
    exports.enclosingPowerOfTwo = enclosingPowerOfTwo;
    exports.engine = engine;
    exports.ensureShape = ensureShape;
    exports.env = env;
    exports.equal = equal$2;
    exports.erf = erf$2;
    exports.euclideanNorm = euclideanNorm;
    exports.exp = exp$2;
    exports.expandDims = expandDims$3;
    exports.expm1 = expm1$2;
    exports.eye = eye;
    exports.fft = fft$2;
    exports.fill = fill$2;
    exports.findBackend = findBackend;
    exports.findBackendFactory = findBackendFactory;
    exports.floor = floor$2;
    exports.floorDiv = floorDiv$2;
    exports.forceHalfFloat = forceHalfFloat;
    exports.fused = fused_ops;
    exports.gather = gather$1;
    exports.gatherND = gatherND;
    exports.gather_util = gather_nd_util;
    exports.getBackend = getBackend$1;
    exports.getGradient = getGradient;
    exports.getKernel = getKernel;
    exports.getKernelsForBackend = getKernelsForBackend;
    exports.gpgpu_util = gpgpu_util;
    exports.grad = grad;
    exports.grads = grads;
    exports.greater = greater$3;
    exports.greaterEqual = greaterEqual$2;
    exports.ifft = ifft$2;
    exports.imag = imag$2;
    exports.image = image$1;
    exports.inTopKAsync = inTopKAsync;
    exports.initializers = exports_initializers;
    exports.input = input;
    exports.io = io;
    exports.irfft = irfft;
    exports.isFinite = isFinite$3;
    exports.isInf = isInf$2;
    exports.isNaN = isNaN$3;
    exports.keep = keep;
    exports.kernel_impls = kernel_impls;
    exports.layers = exports_layers;
    exports.leakyRelu = leakyRelu$2;
    exports.less = less$3;
    exports.lessEqual = lessEqual$2;
    exports.linalg = linalg;
    exports.linspace = linspace;
    exports.loadGraphModel = loadGraphModel;
    exports.loadGraphModelSync = loadGraphModelSync;
    exports.loadLayersModel = loadLayersModel;
    exports.localResponseNormalization = localResponseNormalization;
    exports.log = log$2;
    exports.log1p = log1p$2;
    exports.logSigmoid = logSigmoid;
    exports.logSoftmax = logSoftmax;
    exports.logSumExp = logSumExp;
    exports.logicalAnd = logicalAnd$2;
    exports.logicalNot = logicalNot$2;
    exports.logicalOr = logicalOr$2;
    exports.logicalXor = logicalXor;
    exports.losses = losses;
    exports.lowerBound = lowerBound$1;
    exports.matMul = matMul$1;
    exports.math = math;
    exports.max = max$3;
    exports.maxPool = maxPool$2;
    exports.maxPool3d = maxPool3d$1;
    exports.maxPoolWithArgmax = maxPoolWithArgmax;
    exports.maximum = maximum$4;
    exports.mean = mean$3;
    exports.memory = memory;
    exports.meshgrid = meshgrid;
    exports.metrics = exports_metrics;
    exports.min = min$3;
    exports.minimum = minimum$4;
    exports.mirrorPad = mirrorPad$1;
    exports.mod = mod$2;
    exports.model = model;
    exports.models = exports_models;
    exports.moments = moments;
    exports.movingAverage = movingAverage;
    exports.mul = mul;
    exports.multiRNNCell = multiRNNCell;
    exports.multinomial = multinomial$2;
    exports.neg = neg$2;
    exports.nextFrame = nextFrame;
    exports.norm = norm;
    exports.notEqual = notEqual$2;
    exports.oneHot = oneHot$3;
    exports.ones = ones$1;
    exports.onesLike = onesLike$3;
    exports.op = op;
    exports.outerProduct = outerProduct;
    exports.pad = pad;
    exports.pad1d = pad1d;
    exports.pad2d = pad2d;
    exports.pad3d = pad3d;
    exports.pad4d = pad4d;
    exports.pool = pool$1;
    exports.pow = pow$3;
    exports.prelu = prelu$3;
    exports.print = print;
    exports.prod = prod$2;
    exports.profile = profile;
    exports.raggedGather = raggedGather$2;
    exports.raggedRange = raggedRange$2;
    exports.raggedTensorToTensor = raggedTensorToTensor$2;
    exports.rand = rand;
    exports.randomGamma = randomGamma;
    exports.randomNormal = randomNormal$2;
    exports.randomStandardNormal = randomStandardNormal;
    exports.randomUniform = randomUniform$1;
    exports.randomUniformInt = randomUniformInt;
    exports.range = range$3;
    exports.ready = ready;
    exports.real = real$2;
    exports.reciprocal = reciprocal$2;
    exports.registerBackend = registerBackend;
    exports.registerCallbackConstructor = registerCallbackConstructor;
    exports.registerGradient = registerGradient;
    exports.registerKernel = registerKernel;
    exports.registerOp = registerOp;
    exports.regularizers = exports_regularizers;
    exports.relu = relu$2;
    exports.relu6 = relu6$2;
    exports.removeBackend = removeBackend;
    exports.reshape = reshape$3;
    exports.reverse = reverse$2;
    exports.reverse1d = reverse1d;
    exports.reverse2d = reverse2d;
    exports.reverse3d = reverse3d;
    exports.reverse4d = reverse4d;
    exports.rfft = rfft;
    exports.round = round$2;
    exports.rsqrt = rsqrt$2;
    exports.scalar = scalar;
    exports.scatterND = scatterND;
    exports.scatter_util = scatter_nd_util;
    exports.searchSorted = searchSorted$2;
    exports.selu = selu$2;
    exports.separableConv2d = separableConv2d$1;
    exports.sequential = sequential;
    exports.serialization = serialization;
    exports.setBackend = setBackend$1;
    exports.setPlatform = setPlatform;
    exports.setWebGLContext = setWebGLContext;
    exports.setdiff1dAsync = setdiff1dAsync;
    exports.shared = shared;
    exports.sigmoid = sigmoid$2;
    exports.sign = sign$3;
    exports.signal = signal;
    exports.sin = sin$2;
    exports.sinh = sinh$2;
    exports.slice = slice$2;
    exports.slice1d = slice1d;
    exports.slice2d = slice2d;
    exports.slice3d = slice3d;
    exports.slice4d = slice4d;
    exports.slice_util = slice_util;
    exports.softmax = softmax$3;
    exports.softplus = softplus$2;
    exports.spaceToBatchND = spaceToBatchND$2;
    exports.sparse = sparse$1;
    exports.sparseToDense = sparseToDense$2;
    exports.spectral = spectral$1;
    exports.split = split$3;
    exports.sqrt = sqrt$2;
    exports.square = square$2;
    exports.squaredDifference = squaredDifference$2;
    exports.squeeze = squeeze;
    exports.stack = stack;
    exports.step = step$2;
    exports.stridedSlice = stridedSlice$2;
    exports.string = string$1;
    exports.sub = sub$2;
    exports.sum = sum$3;
    exports.sumOutType = sumOutType;
    exports.tan = tan$2;
    exports.tanh = tanh$2;
    exports.tensor = tensor;
    exports.tensor1d = tensor1d;
    exports.tensor2d = tensor2d;
    exports.tensor3d = tensor3d;
    exports.tensor4d = tensor4d;
    exports.tensor5d = tensor5d;
    exports.tensor6d = tensor6d;
    exports.tensorScatterUpdate = tensorScatterUpdate$2;
    exports.tensor_util = tensor_util;
    exports.test_util = test_util;
    exports.tidy = tidy;
    exports.tile = tile$3;
    exports.time = time;
    exports.topk = topk;
    exports.train = train;
    exports.transpose = transpose$2;
    exports.truncatedNormal = truncatedNormal$1;
    exports.unique = unique$3;
    exports.unregisterGradient = unregisterGradient;
    exports.unregisterKernel = unregisterKernel;
    exports.unsortedSegmentSum = unsortedSegmentSum$2;
    exports.unstack = unstack;
    exports.upcastType = upcastType;
    exports.upperBound = upperBound$1;
    exports.util = util;
    exports.valueAndGrad = valueAndGrad;
    exports.valueAndGrads = valueAndGrads;
    exports.variable = variable$1;
    exports.variableGrads = variableGrads;

    // Per-package version strings: `version` is the aggregate map defined
    // above; the individual version_* entries expose each sub-package's
    // own version constant.
    exports.version = version;
    exports.version_converter = version$5;
    exports.version_core = version$7;
    exports.version_cpu = version$3;
    exports.version_layers = version$6;
    exports.version_webgl = version$2;

    // WebGL backend internals, intentionally exposed for advanced users.
    exports.webgl = webgl;
    exports.webgl_util = webgl_util;
    exports.where = where;
    exports.whereAsync = whereAsync;
    exports.zeros = zeros$2;
    exports.zerosLike = zerosLike$3;

104778
104779}));
104780//# sourceMappingURL=tf.es2017.js.map